"""LiteLLM provider implementation for multi-provider support.""" from __future__ import annotations from contextlib import contextmanager import json import os from typing import Any from .base import LLMProvider, LLMResponse, ToolCallRequest from .registry import find_by_model, find_by_name, find_gateway from .runtime import ProviderRoutingConfig try: # pragma: no cover - optional dependency import json_repair except ModuleNotFoundError: # pragma: no cover json_repair = None # type: ignore[assignment] try: # pragma: no cover - optional dependency import litellm from litellm import acompletion except ModuleNotFoundError: # pragma: no cover litellm = None # type: ignore[assignment] acompletion = None # type: ignore[assignment] _ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"}) class LiteLLMProvider(LLMProvider): """通过 LiteLLM 统一访问大多数 provider。""" def __init__( self, api_key: str | None = None, api_base: str | None = None, default_model: str = "anthropic/claude-opus-4-5", extra_headers: dict[str, str] | None = None, provider_name: str | None = None, request_timeout_seconds: float | None = None, routing: ProviderRoutingConfig | None = None, ) -> None: super().__init__(api_key, api_base, request_timeout_seconds=request_timeout_seconds) self.default_model = default_model self.extra_headers = extra_headers or {} self.routing = routing self.provider_name = provider_name self._gateway = find_gateway(provider_name, api_key, api_base) if litellm is not None: litellm.suppress_debug_info = True litellm.drop_params = True def _build_env_overrides(self, api_key: str | None, api_base: str | None, model: str) -> dict[str, str]: """为当前请求生成 LiteLLM 依赖的临时环境变量。 LiteLLM 对部分 provider 仍然优先读取环境变量。为了避免不同 runtime 之间互相污染,这里只生成“本次请求需要的 env 覆盖”,真正调用时再临时注入。 """ if not api_key: return {} spec = self._gateway if spec is None and self.provider_name: spec = find_by_name(self.provider_name) if spec is None: spec = find_by_model(model) if spec is None or not spec.env_key: return {} overrides: dict[str, str] = {spec.env_key: api_key} effective_base = api_base or spec.default_api_base for env_name, env_value in spec.env_extras: resolved = env_value.replace("{api_key}", api_key).replace("{api_base}", effective_base) overrides[env_name] = resolved return overrides @contextmanager def _temporary_env(self, overrides: dict[str, str]): """只在当前请求期间注入 provider 需要的环境变量。""" if not overrides: yield return sentinel = object() previous: dict[str, object] = {} for key, value in overrides.items(): previous[key] = os.environ.get(key, sentinel) os.environ[key] = value try: yield finally: for key, old_value in previous.items(): if old_value is sentinel: os.environ.pop(key, None) else: os.environ[key] = str(old_value) def _resolve_model(self, model: str) -> str: if self._gateway: prefix = self._gateway.litellm_prefix resolved = model.split("/")[-1] if self._gateway.strip_model_prefix else model if prefix and not resolved.startswith(f"{prefix}/"): resolved = f"{prefix}/{resolved}" return resolved if self.provider_name: spec = find_by_name(self.provider_name) if spec is not None and not spec.is_gateway and not spec.is_local: resolved = model if spec.litellm_prefix and not any(resolved.startswith(prefix) for prefix in spec.skip_prefixes): resolved = f"{spec.litellm_prefix}/{resolved}" elif spec.name == "openai" and "/" not in resolved: resolved = f"openai/{resolved}" return resolved spec = find_by_model(model) if spec and spec.litellm_prefix: if not any(model.startswith(prefix) for prefix in spec.skip_prefixes): model = f"{spec.litellm_prefix}/{model}" return model @staticmethod def _sanitize_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: sanitized = [] system_contents: list[str] = [] for message in messages: clean = {key: value for key, value in message.items() if key in _ALLOWED_MSG_KEYS} if clean.get("role") == "system": content = clean.get("content") if isinstance(content, str) and content.strip(): system_contents.append(content.strip()) elif content is not None: system_contents.append(str(content)) continue if clean.get("role") == "assistant" and "content" not in clean: clean["content"] = None if isinstance(clean.get("tool_calls"), list): clean["tool_calls"] = LiteLLMProvider._sanitize_tool_calls(clean["tool_calls"]) sanitized.append(clean) if system_contents: sanitized.insert(0, {"role": "system", "content": "\n\n".join(system_contents)}) return sanitized @staticmethod def _sanitize_tool_calls(tool_calls: list[Any]) -> list[dict[str, Any]]: sanitized: list[dict[str, Any]] = [] for tool_call in tool_calls: if not isinstance(tool_call, dict): continue clean = dict(tool_call) function = clean.get("function") if isinstance(function, dict): clean_function = dict(function) arguments = clean_function.get("arguments") if not isinstance(arguments, str): clean_function["arguments"] = json.dumps(arguments or {}, ensure_ascii=False, default=str) clean["function"] = clean_function sanitized.append(clean) return sanitized def _apply_model_overrides(self, original_model: str, kwargs: dict[str, Any]) -> None: spec = find_by_model(original_model) if spec is None: return model_lower = original_model.lower() for pattern, overrides in spec.model_overrides: if pattern in model_lower: kwargs.update(overrides) return def _apply_openrouter_routing(self, kwargs: dict[str, Any]) -> None: if self.provider_name != "openrouter" or self.routing is None: return provider_payload: dict[str, Any] = {} if self.routing.sort: provider_payload["sort"] = self.routing.sort if self.routing.only: provider_payload["only"] = self.routing.only if self.routing.ignore: provider_payload["ignore"] = self.routing.ignore if self.routing.order: provider_payload["order"] = self.routing.order if self.routing.require_parameters: provider_payload["require_parameters"] = True if self.routing.data_collection: provider_payload["data_collection"] = self.routing.data_collection if provider_payload: kwargs["provider"] = provider_payload def _apply_thinking_mode(self, original_model: str, resolved_model: str, kwargs: dict[str, Any], enabled: bool | None) -> None: extra_body = dict(kwargs.get("extra_body") or {}) chat_template_kwargs = dict(extra_body.get("chat_template_kwargs") or {}) chat_template_kwargs["enable_thinking"] = False extra_body["chat_template_kwargs"] = chat_template_kwargs extra_body["thinking"] = {"type": "disabled"} kwargs["extra_body"] = extra_body async def chat( self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, thinking_enabled: bool | None = None, ) -> LLMResponse: if acompletion is None: return LLMResponse(content="Error: litellm is not installed", finish_reason="error", provider_name=self.provider_name) original_model = model or self.default_model resolved_model = self._resolve_model(original_model) sanitized_messages = self._sanitize_messages(self.sanitize_empty_content(messages)) kwargs: dict[str, Any] = { "model": resolved_model, "messages": sanitized_messages, "max_tokens": max(1, max_tokens), "temperature": temperature, "timeout": self.request_timeout_seconds or 45.0, } if self.api_key: kwargs["api_key"] = self.api_key if self.api_base: kwargs["api_base"] = self.api_base if self.extra_headers: kwargs["extra_headers"] = self.extra_headers if tools: kwargs["tools"] = tools kwargs["tool_choice"] = "auto" self._apply_model_overrides(original_model, kwargs) self._apply_openrouter_routing(kwargs) self._apply_thinking_mode(original_model, resolved_model, kwargs, thinking_enabled) env_overrides = self._build_env_overrides(self.api_key, self.api_base, original_model) try: with self._temporary_env(env_overrides): response = await acompletion(**kwargs) except Exception as exc: return LLMResponse(content=f"Error: {exc}", finish_reason="error", provider_name=self.provider_name, model=resolved_model) choice = response.choices[0] message = choice.message tool_calls: list[ToolCallRequest] = [] for tool_call in message.tool_calls or []: raw_arguments = tool_call.function.arguments if isinstance(raw_arguments, str): try: if json_repair is not None: arguments = json_repair.loads(raw_arguments) else: arguments = json.loads(raw_arguments) except Exception as exc: # 这里不要因为单个 tool_call 参数坏掉而直接炸掉整轮请求。 # 后面的 ToolExecutor 会把这个标记转换成一条标准 tool failure。 arguments = { "__beaver_tool_argument_parse_error__": str(exc), "__raw_arguments__": raw_arguments, } else: arguments = raw_arguments tool_calls.append( ToolCallRequest( id=tool_call.id, name=tool_call.function.name, arguments=arguments, ) ) usage = getattr(response, "usage", None) usage_payload = {} if usage is not None: usage_payload = { "prompt_tokens": getattr(usage, "prompt_tokens", 0), "completion_tokens": getattr(usage, "completion_tokens", 0), "total_tokens": getattr(usage, "total_tokens", 0), } return LLMResponse( content=getattr(message, "content", None), tool_calls=tool_calls, finish_reason=getattr(choice, "finish_reason", "stop") or "stop", usage=usage_payload, reasoning_content=getattr(message, "reasoning_content", None), provider_name=self.provider_name or "litellm", model=resolved_model, ) def get_default_model(self) -> str: return self.default_model