fix(providers): avoid chat template body for vllm mistral

This commit is contained in:
2026-06-09 13:19:09 +08:00
parent 9e2c02a333
commit dc4c6f313d
2 changed files with 109 additions and 3 deletions

View File

@ -3,9 +3,11 @@
from __future__ import annotations from __future__ import annotations
from contextlib import contextmanager from contextlib import contextmanager
from ipaddress import ip_address
import json import json
import os import os
from typing import Any from typing import Any
from urllib.parse import urlsplit
from .base import LLMProvider, LLMResponse, ToolCallRequest from .base import LLMProvider, LLMResponse, ToolCallRequest
from .registry import find_by_model, find_by_name, find_gateway from .registry import find_by_model, find_by_name, find_gateway
@ -26,6 +28,23 @@ except ModuleNotFoundError: # pragma: no cover
_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"}) _ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"})
def _looks_like_local_vllm_api_base(api_base: str | None) -> bool:
if not api_base:
return False
lowered = api_base.lower()
if "vllm" in lowered or "localhost" in lowered:
return True
host = urlsplit(lowered).hostname or ""
if host in {"127.0.0.1", "::1", "0.0.0.0"}:
return True
try:
parsed_host = ip_address(host)
except ValueError:
return False
return parsed_host.is_private or parsed_host.is_loopback
class LiteLLMProvider(LLMProvider): class LiteLLMProvider(LLMProvider):
"""通过 LiteLLM 统一访问大多数 provider。""" """通过 LiteLLM 统一访问大多数 provider。"""
@ -200,10 +219,12 @@ class LiteLLMProvider(LLMProvider):
kwargs["extra_body"] = extra_body kwargs["extra_body"] = extra_body
def _uses_mistral_reasoning_parser(self, original_model: str, resolved_model: str) -> bool:
    """Return whether this request should be treated as a Mistral model on vLLM.

    The check passes only when ``"mistral"`` appears (case-insensitively) in
    either model name, and additionally either the provider is ``"vllm"`` or
    the provider is ``"openai"``/``"custom"`` with an ``api_base`` that
    ``_looks_like_local_vllm_api_base`` accepts.

    Args:
        original_model: Model name as requested by the caller.
        resolved_model: Model name after provider resolution.
    """
    model_names = f"{original_model} {resolved_model}".lower()
    if "mistral" not in model_names:
        return False
    if self.provider_name == "vllm":
        return True
    # OpenAI-compatible providers only qualify when the endpoint itself
    # looks like a local/private deployment.
    if self.provider_name not in {"openai", "custom"}:
        return False
    return _looks_like_local_vllm_api_base(self.api_base)
async def chat( async def chat(
self, self,

View File

@ -253,6 +253,91 @@ def test_mistral_vllm_omits_reasoning_body_when_thinking_mode_is_unspecified(
assert "extra_body" not in captured assert "extra_body" not in captured
def test_mistral_openai_compatible_private_vllm_uses_reasoning_effort(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """With ``thinking_enabled=False`` the request body is exactly ``reasoning_effort: none``.

    Uses an ``openai`` provider whose ``api_base`` is a private IP address and
    a Mistral model name, and inspects the kwargs handed to ``acompletion``.
    """
    model_name = "Mistral-Medium-3.5-128B"
    sent_kwargs: dict = {}

    class _StubMessage:
        content = "ok"
        reasoning_content = None
        tool_calls = []

    class _StubChoice:
        message = _StubMessage()
        finish_reason = "stop"

    class _StubResponse:
        choices = [_StubChoice()]
        usage = None

    async def _record_acompletion(**call_kwargs):
        # Remember what the provider sent, then answer with a canned response.
        sent_kwargs.update(call_kwargs)
        return _StubResponse()

    monkeypatch.setattr("beaver.engine.providers.litellm.acompletion", _record_acompletion)
    monkeypatch.setattr("beaver.engine.providers.litellm.litellm", SimpleNamespace())

    provider = LiteLLMProvider(
        api_key="EMPTY",
        api_base="http://172.19.207.103/v1",
        default_model=model_name,
        provider_name="openai",
    )
    conversation = [{"role": "user", "content": "reply ok"}]

    asyncio.run(provider.chat(conversation, model=model_name, thinking_enabled=False))

    extra_body = sent_kwargs["extra_body"]
    assert extra_body == {"reasoning_effort": "none"}
    assert "chat_template_kwargs" not in extra_body
    assert "thinking" not in extra_body
def test_mistral_openai_compatible_private_vllm_omits_body_when_unspecified(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Without a ``thinking_enabled`` argument no ``extra_body`` is sent at all.

    Uses an ``openai`` provider whose ``api_base`` is a private IP address and
    a Mistral model name, and inspects the kwargs handed to ``acompletion``.
    """
    model_name = "Mistral-Medium-3.5-128B"
    sent_kwargs: dict = {}

    class _StubMessage:
        content = "ok"
        reasoning_content = None
        tool_calls = []

    class _StubChoice:
        message = _StubMessage()
        finish_reason = "stop"

    class _StubResponse:
        choices = [_StubChoice()]
        usage = None

    async def _record_acompletion(**call_kwargs):
        # Remember what the provider sent, then answer with a canned response.
        sent_kwargs.update(call_kwargs)
        return _StubResponse()

    monkeypatch.setattr("beaver.engine.providers.litellm.acompletion", _record_acompletion)
    monkeypatch.setattr("beaver.engine.providers.litellm.litellm", SimpleNamespace())

    provider = LiteLLMProvider(
        api_key="EMPTY",
        api_base="http://172.19.207.103/v1",
        default_model=model_name,
        provider_name="openai",
    )
    conversation = [{"role": "user", "content": "reply ok"}]

    asyncio.run(provider.chat(conversation, model=model_name))

    assert "extra_body" not in sent_kwargs
def test_litellm_provider_sanitizes_tool_call_arguments(monkeypatch: pytest.MonkeyPatch) -> None: def test_litellm_provider_sanitizes_tool_call_arguments(monkeypatch: pytest.MonkeyPatch) -> None:
captured: dict = {} captured: dict = {}