Files
beaver_project/app-instance/backend/beaver/engine/providers/litellm.py
steven_li 30ab74ffb2 feat(engine): 添加MCP连接管理和工具集成功能
- 集成MCP连接管理器,支持MCP服务器连接
- 添加多种内置工具:ClarifyTool、CronTool、DelegateTool、ExecuteCodeTool、
  PatchFileTool、ProcessTool、SendMessageTool、SpawnTool、TerminalTool、
  TodoTool、WebFetchTool、WebSearchTool、WriteFileTool等
- 实现工具注册和装配功能
- 添加技能选择上下文参数
- 支持思考模式控制参数thinking_enabled

feat(coordinator): 重构任务执行计划器参数命名

- 将learning_candidate_enabled重命名为allow_candidate_generation
- 更新TeamGraphScheduler中的参数传递
- 修改LocalAgentRunner中的相关参数处理
- 更新README文档中的相应描述

refactor(context): 标准化工具调用参数格式

- 添加_json导入用于参数序列化
- 实现_provider_tool_calls方法标准化OpenAI兼容的工具调用载荷
- 修复工具调用中参数非字符串类型的序列化问题

refactor(session): 优化消息历史记录过滤逻辑

- 修改get_messages_as_conversation为基于运行状态过滤消息
- 排除未完成、失败或错误结束的运行记录
- 改进对话历史的可见性控制机制

fix(store): 修复FTS索引重建逻辑

- 添加异常处理防止FTS索引创建失败
- 实现_rebuild_fts_index方法重新构建全文搜索索引
- 优化索引触发器和表的维护流程
2026-05-14 09:43:48 +08:00

278 lines
12 KiB
Python

"""LiteLLM provider implementation for multi-provider support."""
from __future__ import annotations
from contextlib import contextmanager
import json
import os
from typing import Any
from .base import LLMProvider, LLMResponse, ToolCallRequest
from .registry import find_by_model, find_by_name, find_gateway
from .runtime import ProviderRoutingConfig
# json_repair is optional: when it is missing, tool-call arguments are parsed
# with the strict standard-library json module instead.
try:  # pragma: no cover - optional dependency
    import json_repair
except ModuleNotFoundError:  # pragma: no cover
    json_repair = None  # type: ignore[assignment]

# litellm is optional as well: with `acompletion` left as None, the provider
# answers every chat request with an error response instead of failing at
# import time.
try:  # pragma: no cover - optional dependency
    import litellm
    from litellm import acompletion
except ModuleNotFoundError:  # pragma: no cover
    litellm = acompletion = None  # type: ignore[assignment]

# Message keys that are forwarded to the provider; any other key is dropped
# when messages are sanitized.
_ALLOWED_MSG_KEYS = frozenset(("role", "content", "tool_calls", "tool_call_id", "name"))
class LiteLLMProvider(LLMProvider):
    """Provider that reaches most LLM backends through LiteLLM.

    The provider resolves the LiteLLM model identifier from the configured
    gateway / provider spec, sanitizes outgoing messages, and converts the
    LiteLLM response into an ``LLMResponse``. Failures of the completion call
    are reported as an ``LLMResponse`` with ``finish_reason="error"`` rather
    than raised.
    """

    # Timeout (seconds) used when no request timeout was configured.
    _DEFAULT_TIMEOUT_SECONDS = 45.0

    # Marker keys placed in a tool call's arguments when they cannot be parsed.
    _ARG_PARSE_ERROR_KEY = "__beaver_tool_argument_parse_error__"
    _ARG_RAW_KEY = "__raw_arguments__"

    def __init__(
        self,
        api_key: str | None = None,
        api_base: str | None = None,
        default_model: str = "anthropic/claude-opus-4-5",
        extra_headers: dict[str, str] | None = None,
        provider_name: str | None = None,
        request_timeout_seconds: float | None = None,
        routing: ProviderRoutingConfig | None = None,
    ) -> None:
        """Configure the provider.

        Args:
            api_key: Credential forwarded to LiteLLM when set.
            api_base: Custom endpoint forwarded to LiteLLM when set.
            default_model: Model used when ``chat`` is called without one.
            extra_headers: Extra HTTP headers sent with every request.
            provider_name: Registry name of the provider, if known.
            request_timeout_seconds: Per-request timeout; a 45 second
                fallback is used when unset.
            routing: Optional routing preferences; only applied when
                ``provider_name`` is ``"openrouter"``.
        """
        super().__init__(api_key, api_base, request_timeout_seconds=request_timeout_seconds)
        self.default_model = default_model
        self.extra_headers = extra_headers or {}
        self.routing = routing
        self.provider_name = provider_name
        self._gateway = find_gateway(provider_name, api_key, api_base)
        if litellm is not None:
            # NOTE: these are module-level LiteLLM settings, so they affect
            # every LiteLLM caller in the process, not just this instance.
            litellm.suppress_debug_info = True
            litellm.drop_params = True

    def _build_env_overrides(self, api_key: str | None, api_base: str | None, model: str) -> dict[str, str]:
        """Build the temporary environment variables LiteLLM needs for one request.

        LiteLLM still prefers environment variables for some providers. To keep
        different runtimes from polluting each other, this only computes the
        overrides required by the current request; they are injected
        temporarily at call time.

        Args:
            api_key: Credential to expose; no overrides are produced without it.
            api_base: Endpoint substituted for ``{api_base}`` placeholders.
            model: Model name used to look up a provider spec as a last resort.

        Returns:
            Mapping of environment variable name to value; empty when there is
            no key, no matching spec, or the spec declares no ``env_key``.
        """
        if not api_key:
            return {}

        # Spec lookup order: detected gateway, then provider name, then model.
        spec = self._gateway
        if spec is None and self.provider_name:
            spec = find_by_name(self.provider_name)
        if spec is None:
            spec = find_by_model(model)
        if spec is None or not spec.env_key:
            return {}

        overrides: dict[str, str] = {spec.env_key: api_key}
        # Fall back to "" so str.replace never receives None when neither an
        # explicit base nor a spec default is available.
        effective_base = api_base or spec.default_api_base or ""
        for env_name, env_template in spec.env_extras:
            overrides[env_name] = (
                env_template.replace("{api_key}", api_key).replace("{api_base}", effective_base)
            )
        return overrides

    @contextmanager
    def _temporary_env(self, overrides: dict[str, str]):
        """Inject provider environment variables only for the enclosed block.

        Previous values are restored on exit, and variables that did not exist
        before are removed again.

        NOTE(review): ``os.environ`` is process-global and ``chat`` awaits the
        completion call inside this context, so overlapping requests could
        observe or restore each other's values — confirm whether callers
        serialize requests that need different overrides.
        """
        if not overrides:
            yield
            return

        missing = object()
        # Snapshot first so the ``finally`` clause can always restore every
        # key, even if applying the overrides fails part-way through.
        previous: dict[str, object] = {key: os.environ.get(key, missing) for key in overrides}
        try:
            for key, value in overrides.items():
                os.environ[key] = value
            yield
        finally:
            for key, old_value in previous.items():
                if old_value is missing:
                    os.environ.pop(key, None)
                else:
                    os.environ[key] = str(old_value)

    def _resolve_model(self, model: str) -> str:
        """Translate a configured model name into the identifier passed to LiteLLM.

        Resolution order:
        1. A detected gateway: optionally strip the model's own prefix, then
           add the gateway's LiteLLM prefix unless it is already present.
        2. An explicitly named provider that is neither a gateway nor local:
           add its LiteLLM prefix unless the model starts with one of the
           spec's skip prefixes; bare OpenAI models get an ``openai/`` prefix.
        3. Otherwise a spec looked up from the model name, prefixed the same way.
        """
        if self._gateway:
            prefix = self._gateway.litellm_prefix
            resolved = model.split("/")[-1] if self._gateway.strip_model_prefix else model
            if prefix and not resolved.startswith(f"{prefix}/"):
                resolved = f"{prefix}/{resolved}"
            return resolved

        if self.provider_name:
            spec = find_by_name(self.provider_name)
            if spec is not None and not spec.is_gateway and not spec.is_local:
                resolved = model
                if spec.litellm_prefix and not any(resolved.startswith(skip) for skip in spec.skip_prefixes):
                    resolved = f"{spec.litellm_prefix}/{resolved}"
                elif spec.name == "openai" and "/" not in resolved:
                    resolved = f"openai/{resolved}"
                return resolved

        spec = find_by_model(model)
        if spec and spec.litellm_prefix:
            if not any(model.startswith(skip) for skip in spec.skip_prefixes):
                model = f"{spec.litellm_prefix}/{model}"
        return model

    @staticmethod
    def _sanitize_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Return copies of ``messages`` restricted to the allowed message keys.

        Assistant messages without a ``content`` key get ``content=None``, and
        any ``tool_calls`` list is normalized via ``_sanitize_tool_calls``.
        """
        sanitized = []
        for message in messages:
            clean = {key: value for key, value in message.items() if key in _ALLOWED_MSG_KEYS}
            if clean.get("role") == "assistant" and "content" not in clean:
                clean["content"] = None
            if isinstance(clean.get("tool_calls"), list):
                clean["tool_calls"] = LiteLLMProvider._sanitize_tool_calls(clean["tool_calls"])
            sanitized.append(clean)
        return sanitized

    @staticmethod
    def _sanitize_tool_calls(tool_calls: list[Any]) -> list[dict[str, Any]]:
        """Normalize outgoing tool calls so ``function.arguments`` is a string.

        Non-dict entries are dropped. Non-string arguments are serialized to
        JSON; falsy arguments are serialized as an empty object.
        """
        sanitized: list[dict[str, Any]] = []
        for tool_call in tool_calls:
            if not isinstance(tool_call, dict):
                continue
            clean = dict(tool_call)
            function = clean.get("function")
            if isinstance(function, dict):
                clean_function = dict(function)
                arguments = clean_function.get("arguments")
                if not isinstance(arguments, str):
                    clean_function["arguments"] = json.dumps(arguments or {}, ensure_ascii=False, default=str)
                clean["function"] = clean_function
            sanitized.append(clean)
        return sanitized

    def _apply_model_overrides(self, original_model: str, kwargs: dict[str, Any]) -> None:
        """Merge the first matching per-model override of the spec into ``kwargs``.

        A pattern matches when it is a substring of the lower-cased model name.
        """
        spec = find_by_model(original_model)
        if spec is None:
            return
        model_lower = original_model.lower()
        for pattern, overrides in spec.model_overrides:
            if pattern in model_lower:
                kwargs.update(overrides)
                return

    def _apply_openrouter_routing(self, kwargs: dict[str, Any]) -> None:
        """Add a ``provider`` routing payload to ``kwargs`` for OpenRouter.

        Does nothing unless ``provider_name`` is ``"openrouter"`` and a routing
        config is present; only truthy routing fields are included.
        """
        if self.provider_name != "openrouter" or self.routing is None:
            return
        provider_payload: dict[str, Any] = {}
        if self.routing.sort:
            provider_payload["sort"] = self.routing.sort
        if self.routing.only:
            provider_payload["only"] = self.routing.only
        if self.routing.ignore:
            provider_payload["ignore"] = self.routing.ignore
        if self.routing.order:
            provider_payload["order"] = self.routing.order
        if self.routing.require_parameters:
            provider_payload["require_parameters"] = True
        if self.routing.data_collection:
            provider_payload["data_collection"] = self.routing.data_collection
        if provider_payload:
            kwargs["provider"] = provider_payload

    def _apply_thinking_mode(
        self,
        original_model: str,
        resolved_model: str,
        kwargs: dict[str, Any],
        enabled: bool | None,
    ) -> None:
        """Set ``enable_thinking`` in the chat-template kwargs for Qwen models.

        Nothing happens when ``enabled`` is ``None`` or when neither model name
        contains ``"qwen"``. Existing ``extra_body`` content is copied, not
        mutated in place.
        """
        if enabled is None:
            return
        model_key = f"{original_model} {resolved_model}".lower()
        if "qwen" not in model_key:
            return
        extra_body = dict(kwargs.get("extra_body") or {})
        chat_template_kwargs = dict(extra_body.get("chat_template_kwargs") or {})
        chat_template_kwargs["enable_thinking"] = bool(enabled)
        extra_body["chat_template_kwargs"] = chat_template_kwargs
        kwargs["extra_body"] = extra_body

    @staticmethod
    def _parse_tool_arguments(raw_arguments: Any) -> Any:
        """Decode the arguments of one tool call returned by the model.

        Missing, blank, or JSON ``null`` arguments are treated as ``{}``,
        since that is how a call without parameters can be represented.
        Non-string values are passed through unchanged. Strings are decoded
        with ``json_repair`` when it is installed, otherwise with ``json``.

        A single broken tool call must not blow up the whole turn: when the
        string cannot be decoded into a JSON object, a marker dict is returned
        instead, which the ToolExecutor later turns into a standard tool
        failure.
        """
        if raw_arguments is None:
            return {}
        if not isinstance(raw_arguments, str):
            return raw_arguments
        if not raw_arguments.strip():
            return {}

        try:
            if json_repair is not None:
                parsed = json_repair.loads(raw_arguments)
            else:
                parsed = json.loads(raw_arguments)
        except Exception as exc:  # deliberately broad: any decoder failure becomes a tool failure
            return {
                LiteLLMProvider._ARG_PARSE_ERROR_KEY: str(exc),
                LiteLLMProvider._ARG_RAW_KEY: raw_arguments,
            }

        if parsed is None:
            return {}
        if not isinstance(parsed, dict):
            # A lenient decoder may turn garbage into "" or a list without
            # raising; report that the same way as a decode error.
            return {
                LiteLLMProvider._ARG_PARSE_ERROR_KEY: (
                    f"expected a JSON object, got {type(parsed).__name__}"
                ),
                LiteLLMProvider._ARG_RAW_KEY: raw_arguments,
            }
        return parsed

    @staticmethod
    def _extract_usage(response: Any) -> dict[str, int]:
        """Return the token usage of ``response``; empty when none is reported.

        Counters that are missing or ``None`` are reported as ``0``.
        """
        usage = getattr(response, "usage", None)
        if usage is None:
            return {}
        return {
            field: getattr(usage, field, 0) or 0
            for field in ("prompt_tokens", "completion_tokens", "total_tokens")
        }

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        thinking_enabled: bool | None = None,
    ) -> LLMResponse:
        """Send one chat completion request through LiteLLM.

        Args:
            messages: Conversation messages; they are sanitized before sending.
            tools: Optional tool definitions; when given, ``tool_choice`` is
                set to ``"auto"``.
            model: Model to use; defaults to ``default_model``.
            max_tokens: Completion token limit, clamped to at least 1.
            temperature: Sampling temperature.
            thinking_enabled: Tri-state thinking switch; ``None`` leaves the
                request untouched (only Qwen models are affected otherwise).

        Returns:
            The parsed ``LLMResponse``. A missing LiteLLM installation, an
            exception from the completion call, or a response without choices
            yields a response with ``finish_reason="error"`` instead of raising.
        """
        if acompletion is None:
            return LLMResponse(
                content="Error: litellm is not installed",
                finish_reason="error",
                provider_name=self.provider_name,
            )

        original_model = model or self.default_model
        resolved_model = self._resolve_model(original_model)
        sanitized_messages = self._sanitize_messages(self.sanitize_empty_content(messages))

        kwargs: dict[str, Any] = {
            "model": resolved_model,
            "messages": sanitized_messages,
            "max_tokens": max(1, max_tokens),
            "temperature": temperature,
            "timeout": self.request_timeout_seconds or self._DEFAULT_TIMEOUT_SECONDS,
        }
        if self.api_key:
            kwargs["api_key"] = self.api_key
        if self.api_base:
            kwargs["api_base"] = self.api_base
        if self.extra_headers:
            kwargs["extra_headers"] = self.extra_headers
        if tools:
            kwargs["tools"] = tools
            kwargs["tool_choice"] = "auto"

        self._apply_model_overrides(original_model, kwargs)
        self._apply_openrouter_routing(kwargs)
        self._apply_thinking_mode(original_model, resolved_model, kwargs, thinking_enabled)

        env_overrides = self._build_env_overrides(self.api_key, self.api_base, original_model)
        try:
            with self._temporary_env(env_overrides):
                response = await acompletion(**kwargs)
        except Exception as exc:  # provider errors are surfaced as an error response
            return LLMResponse(
                content=f"Error: {exc}",
                finish_reason="error",
                provider_name=self.provider_name,
                model=resolved_model,
            )

        choices = getattr(response, "choices", None)
        if not choices:
            # Keep the error-as-response contract instead of raising IndexError.
            return LLMResponse(
                content="Error: provider returned no choices",
                finish_reason="error",
                provider_name=self.provider_name,
                model=resolved_model,
            )

        choice = choices[0]
        message = choice.message
        tool_calls: list[ToolCallRequest] = [
            ToolCallRequest(
                id=tool_call.id,
                name=tool_call.function.name,
                arguments=self._parse_tool_arguments(tool_call.function.arguments),
            )
            for tool_call in getattr(message, "tool_calls", None) or []
        ]

        return LLMResponse(
            content=getattr(message, "content", None),
            tool_calls=tool_calls,
            finish_reason=getattr(choice, "finish_reason", "stop") or "stop",
            usage=self._extract_usage(response),
            reasoning_content=getattr(message, "reasoning_content", None),
            provider_name=self.provider_name or "litellm",
            model=resolved_model,
        )

    def get_default_model(self) -> str:
        """Return the model used when ``chat`` is called without one."""
        return self.default_model