fix(providers): avoid chat template body for vllm mistral
This commit is contained in:
@ -3,9 +3,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from ipaddress import ip_address
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from urllib.parse import urlsplit
|
||||||
|
|
||||||
from .base import LLMProvider, LLMResponse, ToolCallRequest
|
from .base import LLMProvider, LLMResponse, ToolCallRequest
|
||||||
from .registry import find_by_model, find_by_name, find_gateway
|
from .registry import find_by_model, find_by_name, find_gateway
|
||||||
@ -26,6 +28,23 @@ except ModuleNotFoundError: # pragma: no cover
|
|||||||
_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"})
|
_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"})
|
||||||
|
|
||||||
|
|
||||||
|
def _looks_like_local_vllm_api_base(api_base: str | None) -> bool:
|
||||||
|
if not api_base:
|
||||||
|
return False
|
||||||
|
lowered = api_base.lower()
|
||||||
|
if "vllm" in lowered or "localhost" in lowered:
|
||||||
|
return True
|
||||||
|
|
||||||
|
host = urlsplit(lowered).hostname or ""
|
||||||
|
if host in {"127.0.0.1", "::1", "0.0.0.0"}:
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
parsed_host = ip_address(host)
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
return parsed_host.is_private or parsed_host.is_loopback
|
||||||
|
|
||||||
|
|
||||||
class LiteLLMProvider(LLMProvider):
|
class LiteLLMProvider(LLMProvider):
|
||||||
"""通过 LiteLLM 统一访问大多数 provider。"""
|
"""通过 LiteLLM 统一访问大多数 provider。"""
|
||||||
|
|
||||||
@ -200,10 +219,12 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
kwargs["extra_body"] = extra_body
|
kwargs["extra_body"] = extra_body
|
||||||
|
|
||||||
def _uses_mistral_reasoning_parser(self, original_model: str, resolved_model: str) -> bool:
    """Report whether this request targets a Mistral model on a vLLM-style backend.

    The check is case-insensitive and passes only when ``mistral`` appears in
    either model name. The provider must then either be ``vllm`` itself, or be
    ``openai`` / ``custom`` with an ``api_base`` accepted by
    ``_looks_like_local_vllm_api_base``.
    """
    is_mistral = "mistral" in "{} {}".format(original_model, resolved_model).lower()
    if not is_mistral:
        return False

    provider = self.provider_name
    if provider == "vllm":
        return True
    if provider not in {"openai", "custom"}:
        return False
    # OpenAI-compatible providers only qualify when the endpoint looks self-hosted.
    return _looks_like_local_vllm_api_base(self.api_base)
|
||||||
|
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -253,6 +253,91 @@ def test_mistral_vllm_omits_reasoning_body_when_thinking_mode_is_unspecified(
|
|||||||
assert "extra_body" not in captured
|
assert "extra_body" not in captured
|
||||||
|
|
||||||
|
|
||||||
|
def test_mistral_openai_compatible_private_vllm_uses_reasoning_effort(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """With thinking disabled, only ``reasoning_effort`` is sent in ``extra_body``.

    The provider is configured as ``openai`` against a private-IP ``api_base``
    with a Mistral model name, and ``acompletion`` is replaced by a recorder.
    """
    captured: dict = {}

    # Minimal stand-ins for the completion response object graph.
    class StubMessage:
        content = "ok"
        reasoning_content = None
        tool_calls = []

    class StubChoice:
        message = StubMessage()
        finish_reason = "stop"

    class StubResponse:
        choices = [StubChoice()]
        usage = None

    async def record_call(**call_kwargs):
        # Keep every keyword argument so the assertions can inspect the request.
        captured.update(call_kwargs)
        return StubResponse()

    monkeypatch.setattr("beaver.engine.providers.litellm.acompletion", record_call)
    monkeypatch.setattr("beaver.engine.providers.litellm.litellm", SimpleNamespace())

    provider = LiteLLMProvider(
        api_key="EMPTY",
        api_base="http://172.19.207.103/v1",
        default_model="Mistral-Medium-3.5-128B",
        provider_name="openai",
    )
    request = provider.chat(
        [{"role": "user", "content": "reply ok"}],
        model="Mistral-Medium-3.5-128B",
        thinking_enabled=False,
    )
    asyncio.run(request)

    assert captured["extra_body"] == {"reasoning_effort": "none"}
    assert "chat_template_kwargs" not in captured["extra_body"]
    assert "thinking" not in captured["extra_body"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_mistral_openai_compatible_private_vllm_omits_body_when_unspecified(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """When ``thinking_enabled`` is not passed, no ``extra_body`` is sent.

    The provider is configured as ``openai`` against a private-IP ``api_base``
    with a Mistral model name, and ``acompletion`` is replaced by a recorder.
    """
    captured: dict = {}

    # Minimal stand-ins for the completion response object graph.
    class StubMessage:
        content = "ok"
        reasoning_content = None
        tool_calls = []

    class StubChoice:
        message = StubMessage()
        finish_reason = "stop"

    class StubResponse:
        choices = [StubChoice()]
        usage = None

    async def record_call(**call_kwargs):
        # Keep every keyword argument so the assertion can inspect the request.
        captured.update(call_kwargs)
        return StubResponse()

    monkeypatch.setattr("beaver.engine.providers.litellm.acompletion", record_call)
    monkeypatch.setattr("beaver.engine.providers.litellm.litellm", SimpleNamespace())

    provider = LiteLLMProvider(
        api_key="EMPTY",
        api_base="http://172.19.207.103/v1",
        default_model="Mistral-Medium-3.5-128B",
        provider_name="openai",
    )
    request = provider.chat(
        [{"role": "user", "content": "reply ok"}],
        model="Mistral-Medium-3.5-128B",
    )
    asyncio.run(request)

    assert "extra_body" not in captured
|
||||||
|
|
||||||
|
|
||||||
def test_litellm_provider_sanitizes_tool_call_arguments(monkeypatch: pytest.MonkeyPatch) -> None:
|
def test_litellm_provider_sanitizes_tool_call_arguments(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
captured: dict = {}
|
captured: dict = {}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user