修改了nanobot,往Hermes agent的风格走,进度1/3
This commit is contained in:
230
app-instance/backend/beaver/engine/providers/litellm.py
Normal file
230
app-instance/backend/beaver/engine/providers/litellm.py
Normal file
@ -0,0 +1,230 @@
|
||||
"""LiteLLM provider implementation for multi-provider support."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from .base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from .registry import find_by_model, find_gateway
|
||||
from .runtime import ProviderRoutingConfig
|
||||
|
||||
try: # pragma: no cover - optional dependency
|
||||
import json_repair
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
json_repair = None # type: ignore[assignment]
|
||||
|
||||
try: # pragma: no cover - optional dependency
|
||||
import litellm
|
||||
from litellm import acompletion
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
litellm = None # type: ignore[assignment]
|
||||
acompletion = None # type: ignore[assignment]
|
||||
|
||||
_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})
|
||||
|
||||
|
||||
class LiteLLMProvider(LLMProvider):
    """Provider that reaches most LLM backends through LiteLLM."""

    def __init__(
        self,
        api_key: str | None = None,
        api_base: str | None = None,
        default_model: str = "anthropic/claude-opus-4-5",
        extra_headers: dict[str, str] | None = None,
        provider_name: str | None = None,
        request_timeout_seconds: float | None = None,
        routing: ProviderRoutingConfig | None = None,
    ) -> None:
        """Store the connection settings and resolve the gateway spec.

        Args:
            api_key: Credential passed to LiteLLM on every request.
            api_base: Optional endpoint override passed to LiteLLM.
            default_model: Model used when ``chat`` is called without one.
            extra_headers: Extra HTTP headers forwarded with every request.
            provider_name: Name used for gateway lookup, OpenRouter routing
                and for tagging responses.
            request_timeout_seconds: Forwarded to the base class.
            routing: Optional routing preferences; only consulted when
                ``provider_name`` is ``"openrouter"``.
        """
        super().__init__(api_key, api_base, request_timeout_seconds=request_timeout_seconds)
        self.default_model = default_model
        self.extra_headers = extra_headers or {}
        self.routing = routing
        self.provider_name = provider_name
        self._gateway = find_gateway(provider_name, api_key, api_base)
        if litellm is not None:
            # These are attributes of the litellm module itself, so they
            # apply to every caller in this process, not only this instance.
            litellm.suppress_debug_info = True
            litellm.drop_params = True

    def _build_env_overrides(self, api_key: str | None, api_base: str | None, model: str) -> dict[str, str]:
        """Build the temporary environment variables LiteLLM needs for one request.

        LiteLLM still prefers environment variables for some providers. To
        keep different runtimes from polluting each other, this only computes
        the overrides needed for the current request; they are injected
        temporarily at call time.

        Args:
            api_key: Credential to expose; no overrides are built without it.
            api_base: Endpoint override, falling back to the spec's default.
            model: Model name used to look up the provider spec when no
                gateway was resolved.

        Returns:
            Mapping of environment variable name to value; empty when there
            is no key or no matching spec with an ``env_key``.
        """
        if not api_key:
            return {}
        spec = self._gateway or find_by_model(model)
        if spec is None or not spec.env_key:
            return {}
        overrides: dict[str, str] = {spec.env_key: api_key}
        # Fall back to "" so str.replace never receives None when neither an
        # explicit base nor a spec default is available.
        effective_base = api_base or spec.default_api_base or ""
        for env_name, env_value in spec.env_extras:
            resolved = env_value.replace("{api_key}", api_key).replace("{api_base}", effective_base)
            overrides[env_name] = resolved
        return overrides

    @contextmanager
    def _temporary_env(self, overrides: dict[str, str]):
        """Inject provider environment variables only for the current request.

        The previous value of each overridden variable is restored on exit,
        and variables that did not exist before are removed.

        NOTE(review): ``os.environ`` is process-wide and ``chat`` awaits while
        the overrides are in place, so overlapping requests could observe or
        restore each other's values. Confirm whether callers can run
        concurrently with different credentials.
        """
        if not overrides:
            yield
            return

        # A unique marker distinguishes "variable was unset" from any real value.
        sentinel = object()
        previous: dict[str, object] = {}
        for key, value in overrides.items():
            previous[key] = os.environ.get(key, sentinel)
            os.environ[key] = value
        try:
            yield
        finally:
            for key, old_value in previous.items():
                if old_value is sentinel:
                    os.environ.pop(key, None)
                else:
                    os.environ[key] = str(old_value)

    def _resolve_model(self, model: str) -> str:
        """Return the model name in the form LiteLLM should receive.

        With a gateway, the model is optionally reduced to its last
        ``/``-separated segment and then prefixed with the gateway's LiteLLM
        prefix. Without one, the prefix of the spec matched by model name is
        added unless the model already starts with one of its skip prefixes.
        """
        if self._gateway:
            prefix = self._gateway.litellm_prefix
            resolved = model.split("/")[-1] if self._gateway.strip_model_prefix else model
            if prefix and not resolved.startswith(f"{prefix}/"):
                resolved = f"{prefix}/{resolved}"
            return resolved
        spec = find_by_model(model)
        if spec and spec.litellm_prefix:
            if not any(model.startswith(prefix) for prefix in spec.skip_prefixes):
                model = f"{spec.litellm_prefix}/{model}"
        return model

    @staticmethod
    def _sanitize_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Return copies of ``messages`` limited to the allowed keys.

        Assistant messages that end up without a ``content`` key get
        ``content=None`` so the field is always present.
        """
        sanitized = []
        for message in messages:
            clean = {key: value for key, value in message.items() if key in _ALLOWED_MSG_KEYS}
            if clean.get("role") == "assistant" and "content" not in clean:
                clean["content"] = None
            sanitized.append(clean)
        return sanitized

    def _apply_model_overrides(self, original_model: str, kwargs: dict[str, Any]) -> None:
        """Merge the first matching per-model override into ``kwargs`` in place.

        A pattern matches when it is a substring of the lower-cased model
        name; only the first matching pattern is applied.
        """
        spec = find_by_model(original_model)
        if spec is None:
            return
        model_lower = original_model.lower()
        for pattern, overrides in spec.model_overrides:
            if pattern in model_lower:
                kwargs.update(overrides)
                return

    def _apply_openrouter_routing(self, kwargs: dict[str, Any]) -> None:
        """Add the ``provider`` routing payload to ``kwargs`` for OpenRouter.

        Does nothing unless ``provider_name`` is ``"openrouter"`` and a
        routing config is set. Only truthy routing fields are included, and
        the payload is omitted entirely when none of them are set.
        """
        if self.provider_name != "openrouter" or self.routing is None:
            return
        provider_payload: dict[str, Any] = {}
        if self.routing.sort:
            provider_payload["sort"] = self.routing.sort
        if self.routing.only:
            provider_payload["only"] = self.routing.only
        if self.routing.ignore:
            provider_payload["ignore"] = self.routing.ignore
        if self.routing.order:
            provider_payload["order"] = self.routing.order
        if self.routing.require_parameters:
            provider_payload["require_parameters"] = True
        if self.routing.data_collection:
            provider_payload["data_collection"] = self.routing.data_collection
        if provider_payload:
            kwargs["provider"] = provider_payload

    @staticmethod
    def _parse_tool_arguments(raw_arguments: Any) -> Any:
        """Decode one tool call's arguments, tolerating malformed JSON.

        Non-string values are returned unchanged. Strings are parsed with
        ``json_repair`` when it is installed, otherwise with ``json``. On a
        parse failure a marker dict carrying the error text and the raw
        string is returned instead of raising.
        """
        if not isinstance(raw_arguments, str):
            return raw_arguments
        try:
            if json_repair is not None:
                return json_repair.loads(raw_arguments)
            return json.loads(raw_arguments)
        except Exception as exc:
            # Do not fail the whole turn because a single tool call carries
            # broken arguments. The ToolExecutor downstream turns this marker
            # into a standard tool failure.
            return {
                "__beaver_tool_argument_parse_error__": str(exc),
                "__raw_arguments__": raw_arguments,
            }

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """Send one chat completion request through LiteLLM.

        Args:
            messages: Conversation history; sanitized before sending.
            tools: Optional tool definitions; when given, ``tool_choice`` is
                set to ``"auto"``.
            model: Model to use; defaults to ``default_model``.
            max_tokens: Completion token limit, clamped to at least 1.
            temperature: Sampling temperature.

        Returns:
            The parsed response. If LiteLLM is missing, the request raises,
            or the response has no choices, an ``LLMResponse`` with
            ``finish_reason="error"`` is returned instead of raising.
        """
        if acompletion is None:
            return LLMResponse(content="Error: litellm is not installed", finish_reason="error", provider_name=self.provider_name)

        original_model = model or self.default_model
        resolved_model = self._resolve_model(original_model)
        sanitized_messages = self._sanitize_messages(self.sanitize_empty_content(messages))
        kwargs: dict[str, Any] = {
            "model": resolved_model,
            "messages": sanitized_messages,
            "max_tokens": max(1, max_tokens),
            "temperature": temperature,
        }
        if self.api_key:
            kwargs["api_key"] = self.api_key
        if self.api_base:
            kwargs["api_base"] = self.api_base
        if self.extra_headers:
            kwargs["extra_headers"] = self.extra_headers
        if tools:
            kwargs["tools"] = tools
            kwargs["tool_choice"] = "auto"
        self._apply_model_overrides(original_model, kwargs)
        self._apply_openrouter_routing(kwargs)
        env_overrides = self._build_env_overrides(self.api_key, self.api_base, original_model)

        try:
            with self._temporary_env(env_overrides):
                response = await acompletion(**kwargs)
        except Exception as exc:
            return LLMResponse(content=f"Error: {exc}", finish_reason="error", provider_name=self.provider_name, model=resolved_model)

        # An empty ``choices`` list would otherwise raise IndexError here,
        # outside the try block; report it like every other failure.
        choices = getattr(response, "choices", None)
        if not choices:
            return LLMResponse(
                content="Error: provider returned no choices",
                finish_reason="error",
                provider_name=self.provider_name,
                model=resolved_model,
            )

        choice = choices[0]
        message = choice.message
        tool_calls: list[ToolCallRequest] = [
            ToolCallRequest(
                id=tool_call.id,
                name=tool_call.function.name,
                arguments=self._parse_tool_arguments(tool_call.function.arguments),
            )
            for tool_call in message.tool_calls or []
        ]

        usage = getattr(response, "usage", None)
        usage_payload = {}
        if usage is not None:
            usage_payload = {
                "prompt_tokens": getattr(usage, "prompt_tokens", 0),
                "completion_tokens": getattr(usage, "completion_tokens", 0),
                "total_tokens": getattr(usage, "total_tokens", 0),
            }
        return LLMResponse(
            content=getattr(message, "content", None),
            tool_calls=tool_calls,
            finish_reason=getattr(choice, "finish_reason", "stop") or "stop",
            usage=usage_payload,
            reasoning_content=getattr(message, "reasoning_content", None),
            provider_name=self.provider_name or "litellm",
            model=resolved_model,
        )

    def get_default_model(self) -> str:
        """Return the model used when ``chat`` is called without one."""
        return self.default_model
|
||||
Reference in New Issue
Block a user