319 lines
13 KiB
Python
319 lines
13 KiB
Python
"""LiteLLM provider implementation for multi-provider support."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from contextlib import contextmanager
|
|
from ipaddress import ip_address
|
|
import json
|
|
import os
|
|
from typing import Any
|
|
from urllib.parse import urlsplit
|
|
|
|
from .base import LLMProvider, LLMResponse, ToolCallRequest
|
|
from .registry import find_by_model, find_by_name, find_gateway
|
|
from .runtime import ProviderRoutingConfig
|
|
|
|
try: # pragma: no cover - optional dependency
|
|
import json_repair
|
|
except ModuleNotFoundError: # pragma: no cover
|
|
json_repair = None # type: ignore[assignment]
|
|
|
|
try: # pragma: no cover - optional dependency
|
|
import litellm
|
|
from litellm import acompletion
|
|
except ModuleNotFoundError: # pragma: no cover
|
|
litellm = None # type: ignore[assignment]
|
|
acompletion = None # type: ignore[assignment]
|
|
|
|
# Whitelist of per-message fields kept by `LiteLLMProvider._sanitize_messages`;
# every other key on a message dict is dropped before the request is sent.
_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"})
|
|
|
|
|
|
def _looks_like_local_vllm_api_base(api_base: str | None) -> bool:
|
|
if not api_base:
|
|
return False
|
|
lowered = api_base.lower()
|
|
if "vllm" in lowered or "localhost" in lowered:
|
|
return True
|
|
|
|
host = urlsplit(lowered).hostname or ""
|
|
if host in {"127.0.0.1", "::1", "0.0.0.0"}:
|
|
return True
|
|
try:
|
|
parsed_host = ip_address(host)
|
|
except ValueError:
|
|
return False
|
|
return parsed_host.is_private or parsed_host.is_loopback
|
|
|
|
|
|
class LiteLLMProvider(LLMProvider):
|
|
"""Access most providers through LiteLLM's unified interface."""
|
|
|
|
def __init__(
    self,
    api_key: str | None = None,
    api_base: str | None = None,
    default_model: str = "anthropic/claude-opus-4-5",
    extra_headers: dict[str, str] | None = None,
    provider_name: str | None = None,
    request_timeout_seconds: float | None = None,
    routing: ProviderRoutingConfig | None = None,
) -> None:
    """Store the provider configuration and set LiteLLM's module-level flags.

    Args:
        api_key: Credential passed to the base class.
        api_base: Endpoint override passed to the base class.
        default_model: Model used by ``chat`` when the caller names none.
        extra_headers: Extra HTTP headers; ``None`` or empty becomes ``{}``.
        provider_name: Explicit provider identifier, also used for the
            gateway lookup.
        request_timeout_seconds: Timeout passed to the base class.
        routing: Optional routing preferences stored on the instance.
    """
    super().__init__(api_key, api_base, request_timeout_seconds=request_timeout_seconds)
    self.provider_name = provider_name
    self.routing = routing
    self.default_model = default_model
    self.extra_headers = extra_headers if extra_headers else {}
    # Looked up once here and reused by model and env resolution.
    self._gateway = find_gateway(provider_name, api_key, api_base)
    if litellm is None:
        # Optional dependency is missing; there is nothing to configure.
        return
    litellm.suppress_debug_info = True
    litellm.drop_params = True
|
|
|
|
def _build_env_overrides(self, api_key: str | None, api_base: str | None, model: str) -> dict[str, str]:
|
|
"""为当前请求生成 LiteLLM 依赖的临时环境变量。
|
|
|
|
LiteLLM 对部分 provider 仍然优先读取环境变量。为了避免不同 runtime
|
|
之间互相污染,这里只生成“本次请求需要的 env 覆盖”,真正调用时再临时注入。
|
|
"""
|
|
|
|
if not api_key:
|
|
return {}
|
|
spec = self._gateway
|
|
if spec is None and self.provider_name:
|
|
spec = find_by_name(self.provider_name)
|
|
if spec is None:
|
|
spec = find_by_model(model)
|
|
if spec is None or not spec.env_key:
|
|
return {}
|
|
overrides: dict[str, str] = {spec.env_key: api_key}
|
|
effective_base = api_base or spec.default_api_base
|
|
for env_name, env_value in spec.env_extras:
|
|
resolved = env_value.replace("{api_key}", api_key).replace("{api_base}", effective_base)
|
|
overrides[env_name] = resolved
|
|
return overrides
|
|
|
|
@contextmanager
|
|
def _temporary_env(self, overrides: dict[str, str]):
|
|
"""只在当前请求期间注入 provider 需要的环境变量。"""
|
|
|
|
if not overrides:
|
|
yield
|
|
return
|
|
|
|
sentinel = object()
|
|
previous: dict[str, object] = {}
|
|
for key, value in overrides.items():
|
|
previous[key] = os.environ.get(key, sentinel)
|
|
os.environ[key] = value
|
|
try:
|
|
yield
|
|
finally:
|
|
for key, old_value in previous.items():
|
|
if old_value is sentinel:
|
|
os.environ.pop(key, None)
|
|
else:
|
|
os.environ[key] = str(old_value)
|
|
|
|
def _resolve_model(self, model: str) -> str:
|
|
if self._gateway:
|
|
prefix = self._gateway.litellm_prefix
|
|
resolved = model.split("/")[-1] if self._gateway.strip_model_prefix else model
|
|
if prefix and not resolved.startswith(f"{prefix}/"):
|
|
resolved = f"{prefix}/{resolved}"
|
|
return resolved
|
|
if self.provider_name:
|
|
spec = find_by_name(self.provider_name)
|
|
if spec is not None and not spec.is_gateway and not spec.is_local:
|
|
resolved = model
|
|
if spec.litellm_prefix and not any(resolved.startswith(prefix) for prefix in spec.skip_prefixes):
|
|
resolved = f"{spec.litellm_prefix}/{resolved}"
|
|
elif spec.name == "openai" and "/" not in resolved:
|
|
resolved = f"openai/{resolved}"
|
|
return resolved
|
|
spec = find_by_model(model)
|
|
if spec and spec.litellm_prefix:
|
|
if not any(model.startswith(prefix) for prefix in spec.skip_prefixes):
|
|
model = f"{spec.litellm_prefix}/{model}"
|
|
return model
|
|
|
|
@staticmethod
def _sanitize_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Reduce messages to the allowed keys and merge all system prompts into one.

    For every message only the keys in ``_ALLOWED_MSG_KEYS`` are kept.
    System messages are removed from the sequence; their contents are joined
    with blank lines into a single system message placed first. Blank or
    whitespace-only system strings are skipped so they cannot add empty
    fragments to the merged prompt. Assistant messages without a ``content``
    key get ``content=None``, and list-valued ``tool_calls`` are normalised
    via ``_sanitize_tool_calls``.

    Args:
        messages: Chat messages as dictionaries.

    Returns:
        A new list of cleaned message dictionaries; the input dictionaries
        are not modified.
    """
    sanitized: list[dict[str, Any]] = []
    system_contents: list[str] = []
    for message in messages:
        clean = {key: value for key, value in message.items() if key in _ALLOWED_MSG_KEYS}
        role = clean.get("role")
        if role == "system":
            content = clean.get("content")
            if isinstance(content, str):
                stripped = content.strip()
                if stripped:
                    system_contents.append(stripped)
            elif content is not None:
                # Non-string content (e.g. structured blocks) is stringified as-is.
                system_contents.append(str(content))
            continue
        if role == "assistant" and "content" not in clean:
            clean["content"] = None
        if isinstance(clean.get("tool_calls"), list):
            clean["tool_calls"] = LiteLLMProvider._sanitize_tool_calls(clean["tool_calls"])
        sanitized.append(clean)
    if system_contents:
        sanitized.insert(0, {"role": "system", "content": "\n\n".join(system_contents)})
    return sanitized
|
|
|
|
@staticmethod
|
|
def _sanitize_tool_calls(tool_calls: list[Any]) -> list[dict[str, Any]]:
|
|
sanitized: list[dict[str, Any]] = []
|
|
for tool_call in tool_calls:
|
|
if not isinstance(tool_call, dict):
|
|
continue
|
|
clean = dict(tool_call)
|
|
function = clean.get("function")
|
|
if isinstance(function, dict):
|
|
clean_function = dict(function)
|
|
arguments = clean_function.get("arguments")
|
|
if not isinstance(arguments, str):
|
|
clean_function["arguments"] = json.dumps(arguments or {}, ensure_ascii=False, default=str)
|
|
clean["function"] = clean_function
|
|
sanitized.append(clean)
|
|
return sanitized
|
|
|
|
def _apply_model_overrides(self, original_model: str, kwargs: dict[str, Any]) -> None:
    """Merge the first matching per-model override set into ``kwargs``.

    The spec is looked up from the model name. Its ``model_overrides`` pairs
    are scanned in order, and the first pattern that is a substring of the
    lower-cased model name has its overrides applied. Nothing happens when
    no spec or no pattern matches.

    Args:
        original_model: Model name as requested, before prefix resolution.
        kwargs: Request keyword arguments, updated in place.
    """
    spec = find_by_model(original_model)
    if spec is None:
        return
    lowered = original_model.lower()
    for needle, extra in spec.model_overrides:
        if needle not in lowered:
            continue
        kwargs.update(extra)
        break
|
|
|
|
def _apply_openrouter_routing(self, kwargs: dict[str, Any]) -> None:
|
|
if self.provider_name != "openrouter" or self.routing is None:
|
|
return
|
|
provider_payload: dict[str, Any] = {}
|
|
if self.routing.sort:
|
|
provider_payload["sort"] = self.routing.sort
|
|
if self.routing.only:
|
|
provider_payload["only"] = self.routing.only
|
|
if self.routing.ignore:
|
|
provider_payload["ignore"] = self.routing.ignore
|
|
if self.routing.order:
|
|
provider_payload["order"] = self.routing.order
|
|
if self.routing.require_parameters:
|
|
provider_payload["require_parameters"] = True
|
|
if self.routing.data_collection:
|
|
provider_payload["data_collection"] = self.routing.data_collection
|
|
if provider_payload:
|
|
kwargs["provider"] = provider_payload
|
|
|
|
def _apply_thinking_mode(self, original_model: str, resolved_model: str, kwargs: dict[str, Any], enabled: bool | None) -> None:
|
|
if self._uses_mistral_reasoning_parser(original_model, resolved_model):
|
|
if enabled is not None:
|
|
extra_body = dict(kwargs.get("extra_body") or {})
|
|
extra_body["reasoning_effort"] = "high" if enabled else "none"
|
|
kwargs["extra_body"] = extra_body
|
|
return
|
|
|
|
extra_body = dict(kwargs.get("extra_body") or {})
|
|
chat_template_kwargs = dict(extra_body.get("chat_template_kwargs") or {})
|
|
chat_template_kwargs["enable_thinking"] = False
|
|
extra_body["chat_template_kwargs"] = chat_template_kwargs
|
|
extra_body["thinking"] = {"type": "disabled"}
|
|
kwargs["extra_body"] = extra_body
|
|
|
|
def _uses_mistral_reasoning_parser(self, original_model: str, resolved_model: str) -> bool:
|
|
model_names = f"{original_model} {resolved_model}".lower()
|
|
if "mistral" not in model_names:
|
|
return False
|
|
if self.provider_name == "vllm":
|
|
return True
|
|
return self.provider_name in {"openai", "custom"} and _looks_like_local_vllm_api_base(self.api_base)
|
|
|
|
async def chat(
    self,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None = None,
    model: str | None = None,
    max_tokens: int | None = None,
    temperature: float = 0.7,
    thinking_enabled: bool | None = None,
) -> LLMResponse:
    """Send one chat-completion request through LiteLLM and normalise the reply.

    Failures are reported as an ``LLMResponse`` with ``finish_reason="error"``
    rather than raised. That covers a missing ``litellm`` install, any
    exception raised by the completion call, and a response that carries no
    choices.

    Args:
        messages: Conversation messages; they are sanitised before sending.
        tools: Optional tool definitions; when given, ``tool_choice`` is
            set to ``"auto"``.
        model: Model to use; defaults to ``self.default_model``.
        max_tokens: Optional completion limit, clamped to at least 1.
        temperature: Sampling temperature forwarded to LiteLLM.
        thinking_enabled: Thinking preference forwarded to
            ``_apply_thinking_mode``.

    Returns:
        The assistant content, parsed tool calls, finish reason, token usage
        and reasoning content wrapped in an ``LLMResponse``.
    """
    if acompletion is None:
        return LLMResponse(content="Error: litellm is not installed", finish_reason="error", provider_name=self.provider_name)

    original_model = model or self.default_model
    resolved_model = self._resolve_model(original_model)
    sanitized_messages = self._sanitize_messages(self.sanitize_empty_content(messages))
    kwargs: dict[str, Any] = {
        "model": resolved_model,
        "messages": sanitized_messages,
        "temperature": temperature,
        "timeout": self.request_timeout_seconds or 45.0,
    }
    if max_tokens is not None:
        kwargs["max_tokens"] = max(1, max_tokens)
    if self.api_key:
        kwargs["api_key"] = self.api_key
    if self.api_base:
        kwargs["api_base"] = self.api_base
    if self.extra_headers:
        kwargs["extra_headers"] = self.extra_headers
    if tools:
        kwargs["tools"] = tools
        kwargs["tool_choice"] = "auto"
    self._apply_model_overrides(original_model, kwargs)
    self._apply_openrouter_routing(kwargs)
    self._apply_thinking_mode(original_model, resolved_model, kwargs, thinking_enabled)
    env_overrides = self._build_env_overrides(self.api_key, self.api_base, original_model)

    try:
        # Environment overrides stay in place only while the request runs.
        with self._temporary_env(env_overrides):
            response = await acompletion(**kwargs)
    except Exception as exc:
        return LLMResponse(content=f"Error: {exc}", finish_reason="error", provider_name=self.provider_name, model=resolved_model)

    # An empty or missing choices list would otherwise raise here, breaking
    # the "errors come back as an LLMResponse" behaviour of this method.
    choices = getattr(response, "choices", None)
    if not choices:
        return LLMResponse(
            content="Error: provider returned no choices",
            finish_reason="error",
            provider_name=self.provider_name,
            model=resolved_model,
        )

    choice = choices[0]
    message = choice.message
    tool_calls: list[ToolCallRequest] = []
    for tool_call in getattr(message, "tool_calls", None) or []:
        raw_arguments = tool_call.function.arguments
        if isinstance(raw_arguments, str):
            try:
                if json_repair is not None:
                    arguments = json_repair.loads(raw_arguments)
                else:
                    arguments = json.loads(raw_arguments)
            except Exception as exc:
                # Do not let one tool call's malformed arguments blow up the
                # whole turn. The downstream ToolExecutor turns this marker
                # into a standard tool failure.
                arguments = {
                    "__beaver_tool_argument_parse_error__": str(exc),
                    "__raw_arguments__": raw_arguments,
                }
        else:
            arguments = raw_arguments
        tool_calls.append(
            ToolCallRequest(
                id=tool_call.id,
                name=tool_call.function.name,
                arguments=arguments,
            )
        )

    usage = getattr(response, "usage", None)
    usage_payload = {}
    if usage is not None:
        usage_payload = {
            "prompt_tokens": getattr(usage, "prompt_tokens", 0),
            "completion_tokens": getattr(usage, "completion_tokens", 0),
            "total_tokens": getattr(usage, "total_tokens", 0),
        }
    return LLMResponse(
        content=getattr(message, "content", None),
        tool_calls=tool_calls,
        finish_reason=getattr(choice, "finish_reason", "stop") or "stop",
        usage=usage_payload,
        reasoning_content=getattr(message, "reasoning_content", None),
        provider_name=self.provider_name or "litellm",
        model=resolved_model,
    )
|
|
|
|
def get_default_model(self) -> str:
    """Return the model name this provider uses when none is requested."""
    return self.default_model
|