修改了nanobot,往Hermes agent的风格走,进度1/3
This commit is contained in:
33
app-instance/backend/beaver/engine/providers/__init__.py
Normal file
33
app-instance/backend/beaver/engine/providers/__init__.py
Normal file
@ -0,0 +1,33 @@
|
||||
"""LLM provider adapters."""
|
||||
|
||||
from .base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from .chain import FallbackProviderChain
|
||||
from .factory import (
|
||||
ProviderBundle,
|
||||
ProviderRoutingConfig,
|
||||
ProviderRuntime,
|
||||
ProviderTarget,
|
||||
build_provider_runtime,
|
||||
make_aux_provider,
|
||||
make_fallback_provider,
|
||||
make_main_provider,
|
||||
make_provider_bundle,
|
||||
make_provider_from_runtime,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"FallbackProviderChain",
|
||||
"LLMProvider",
|
||||
"LLMResponse",
|
||||
"ProviderBundle",
|
||||
"ProviderRoutingConfig",
|
||||
"ProviderRuntime",
|
||||
"ProviderTarget",
|
||||
"ToolCallRequest",
|
||||
"build_provider_runtime",
|
||||
"make_aux_provider",
|
||||
"make_fallback_provider",
|
||||
"make_main_provider",
|
||||
"make_provider_bundle",
|
||||
"make_provider_from_runtime",
|
||||
]
|
||||
173
app-instance/backend/beaver/engine/providers/anthropic.py
Normal file
173
app-instance/backend/beaver/engine/providers/anthropic.py
Normal file
@ -0,0 +1,173 @@
|
||||
"""Native Anthropic Messages API provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from .base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
try: # pragma: no cover - optional dependency
|
||||
import anthropic
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
anthropic = None # type: ignore[assignment]
|
||||
|
||||
|
||||
class AnthropicProvider(LLMProvider):
    """Provider that talks to Anthropic's native Messages API.

    Requests go through the ``anthropic`` SDK instead of being forced down an
    OpenAI-compatible path.
    """

    def __init__(
        self,
        api_key: str | None = None,
        default_model: str = "claude-sonnet-4-5",
        api_base: str | None = None,
        request_timeout_seconds: float | None = None,
    ) -> None:
        """Store connection settings; the SDK client is built on first use."""
        super().__init__(api_key, api_base, request_timeout_seconds=request_timeout_seconds)
        self.default_model = default_model
        self._client = None

    def _client_or_raise(self):
        """Return the cached ``AsyncAnthropic`` client, creating it if needed.

        Raises:
            RuntimeError: If the optional ``anthropic`` package is not installed.
        """
        if anthropic is None:
            raise RuntimeError("anthropic package is not installed")
        if self._client is None:
            client_kwargs: dict[str, Any] = {
                "api_key": self.api_key,
                "base_url": self.api_base,
            }
            # Forward the timeout only when one was configured: an explicit
            # ``timeout=None`` switches the SDK's default timeout off and lets
            # a stalled request hang indefinitely.
            if self.request_timeout_seconds is not None:
                client_kwargs["timeout"] = self.request_timeout_seconds
            self._client = anthropic.AsyncAnthropic(**client_kwargs)
        return self._client

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """Send one chat request and return it as a unified ``LLMResponse``.

        Client-creation and API failures are not raised; they are reported as
        a response whose ``finish_reason`` is ``"error"``.
        """
        try:
            client = self._client_or_raise()
        except Exception as exc:
            return LLMResponse(content=f"Error: {exc}", finish_reason="error", provider_name="anthropic")

        resolved_model = model or self.default_model
        system_prompt, anthropic_messages = _convert_messages(messages)
        kwargs: dict[str, Any] = {
            "model": resolved_model,
            "messages": anthropic_messages,
            "max_tokens": max(1, max_tokens),
            "temperature": temperature,
        }
        # Leave ``system`` out entirely when there is no system prompt rather
        # than sending an empty string.
        if system_prompt:
            kwargs["system"] = system_prompt
        if tools:
            kwargs["tools"] = _convert_tools(tools)

        try:
            response = await client.messages.create(**kwargs)
        except Exception as exc:
            return LLMResponse(content=f"Error: {exc}", finish_reason="error", provider_name="anthropic")

        content_parts: list[str] = []
        tool_calls: list[ToolCallRequest] = []
        for block in response.content:
            if block.type == "text":
                content_parts.append(block.text)
            elif block.type == "tool_use":
                tool_calls.append(
                    ToolCallRequest(id=block.id, name=block.name, arguments=block.input)
                )

        usage_payload: dict[str, Any] = {}
        usage = getattr(response, "usage", None)
        if usage:
            usage_payload = {
                "input_tokens": getattr(usage, "input_tokens", 0),
                "output_tokens": getattr(usage, "output_tokens", 0),
            }

        return LLMResponse(
            content="".join(content_parts) or None,
            tool_calls=tool_calls,
            # NOTE(review): Anthropic's stop_reason is passed through as-is;
            # confirm callers do not expect OpenAI-style finish reasons here.
            finish_reason=getattr(response, "stop_reason", "stop") or "stop",
            usage=usage_payload,
            provider_name="anthropic",
            model=resolved_model,
        )

    def get_default_model(self) -> str:
        """Return the model used when ``chat`` is called without ``model``."""
        return self.default_model
|
||||
|
||||
|
||||
def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
|
||||
system_prompt = ""
|
||||
converted: list[dict[str, Any]] = []
|
||||
for message in messages:
|
||||
role = message.get("role")
|
||||
if role == "system":
|
||||
content = message.get("content")
|
||||
system_prompt = content if isinstance(content, str) else ""
|
||||
continue
|
||||
if role == "tool":
|
||||
converted.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": message.get("tool_call_id"),
|
||||
"content": message.get("content") or "",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
continue
|
||||
if role == "assistant" and message.get("tool_calls"):
|
||||
content_blocks: list[dict[str, Any]] = []
|
||||
if message.get("content"):
|
||||
content_blocks.append({"type": "text", "text": message["content"]})
|
||||
for tool_call in message.get("tool_calls", []):
|
||||
function = tool_call.get("function", tool_call)
|
||||
arguments = function.get("arguments")
|
||||
if isinstance(arguments, str):
|
||||
try:
|
||||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
content_blocks.append(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": tool_call.get("id"),
|
||||
"name": function.get("name"),
|
||||
"input": arguments or {},
|
||||
}
|
||||
)
|
||||
converted.append({"role": "assistant", "content": content_blocks})
|
||||
continue
|
||||
|
||||
content = message.get("content")
|
||||
if isinstance(content, list):
|
||||
blocks = []
|
||||
for item in content:
|
||||
if isinstance(item, dict) and item.get("type") == "text":
|
||||
blocks.append({"type": "text", "text": item.get("text", "")})
|
||||
converted.append({"role": role, "content": blocks or [{"type": "text", "text": ""}]})
|
||||
else:
|
||||
converted.append({"role": role, "content": content or ""})
|
||||
return system_prompt, converted
|
||||
|
||||
|
||||
def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
converted: list[dict[str, Any]] = []
|
||||
for tool in tools:
|
||||
fn = (tool.get("function") or {}) if tool.get("type") == "function" else tool
|
||||
if not fn.get("name"):
|
||||
continue
|
||||
converted.append(
|
||||
{
|
||||
"name": fn["name"],
|
||||
"description": fn.get("description") or "",
|
||||
"input_schema": fn.get("parameters") or {"type": "object", "properties": {}},
|
||||
}
|
||||
)
|
||||
return converted
|
||||
98
app-instance/backend/beaver/engine/providers/base.py
Normal file
98
app-instance/backend/beaver/engine/providers/base.py
Normal file
@ -0,0 +1,98 @@
|
||||
"""Beaver provider 子系统的统一契约。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ToolCallRequest:
|
||||
"""模型返回的一次工具调用请求。"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
arguments: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class LLMResponse:
|
||||
"""统一的模型响应结构。"""
|
||||
|
||||
content: str | None
|
||||
tool_calls: list[ToolCallRequest] = field(default_factory=list)
|
||||
finish_reason: str = "stop"
|
||||
usage: dict[str, Any] = field(default_factory=dict)
|
||||
reasoning_content: str | None = None
|
||||
provider_name: str | None = None
|
||||
model: str | None = None
|
||||
|
||||
@property
|
||||
def has_tool_calls(self) -> bool:
|
||||
return bool(self.tool_calls)
|
||||
|
||||
|
||||
class LLMProvider(ABC):
    """Common interface that every provider implementation must follow."""

    def __init__(
        self,
        api_key: str | None = None,
        api_base: str | None = None,
        request_timeout_seconds: float | None = None,
    ) -> None:
        """Store credentials and endpoint; clamp a given timeout to >= 1 second."""
        self.api_key = api_key
        self.api_base = api_base
        if request_timeout_seconds is None:
            self.request_timeout_seconds = None
        else:
            self.request_timeout_seconds = max(1.0, float(request_timeout_seconds))

    @staticmethod
    def sanitize_empty_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Replace empty content, which providers commonly reject.

        An empty string, or a list whose text parts are all blank, becomes
        ``"(empty)"``. For an assistant message that carries tool calls it
        becomes ``None`` instead. Messages that need no change are returned as
        the same objects; changed ones are shallow copies.
        """
        text_kinds = ("text", "input_text", "output_text")

        def is_blank_text(part: Any) -> bool:
            return (
                isinstance(part, dict)
                and part.get("type") in text_kinds
                and not part.get("text")
            )

        sanitized: list[dict[str, Any]] = []
        for original in messages:
            body = original.get("content")
            assistant_with_tools = bool(
                original.get("role") == "assistant" and original.get("tool_calls")
            )

            if isinstance(body, str) and body == "":
                replacement: Any = None if assistant_with_tools else "(empty)"
            elif isinstance(body, list):
                kept = [part for part in body if not is_blank_text(part)]
                if len(kept) == len(body):
                    # Nothing was dropped, so keep the message untouched.
                    sanitized.append(original)
                    continue
                if kept:
                    replacement = kept
                elif assistant_with_tools:
                    replacement = None
                else:
                    replacement = "(empty)"
            else:
                sanitized.append(original)
                continue

            sanitized.append({**original, "content": replacement})
        return sanitized

    @abstractmethod
    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """Unified chat entry point."""

    @abstractmethod
    def get_default_model(self) -> str:
        """Return the provider's default model name."""
|
||||
145
app-instance/backend/beaver/engine/providers/chain.py
Normal file
145
app-instance/backend/beaver/engine/providers/chain.py
Normal file
@ -0,0 +1,145 @@
|
||||
"""Provider chain helpers.
|
||||
|
||||
这里先实现最小可用的 fallback chain:
|
||||
- 每次调用都先尝试主 provider
|
||||
- 本次调用主 provider 返回 `finish_reason=error` 时,再切到 fallback
|
||||
- fallback 只影响当前这一次调用,不会污染下一次 run 的首选链路
|
||||
|
||||
这样后面 `AgentLoop` 不需要自己处理“主模型挂了再换一个 provider”。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .base import LLMProvider, LLMResponse
|
||||
from .runtime import ProviderRuntime
|
||||
|
||||
|
||||
class FallbackProviderChain(LLMProvider):
    """Expose a primary provider plus an optional fallback as one ``LLMProvider``."""

    def __init__(
        self,
        primary_runtime: ProviderRuntime,
        primary_provider: LLMProvider,
        fallback_runtime: ProviderRuntime | None = None,
        fallback_provider: LLMProvider | None = None,
    ) -> None:
        """Adopt the primary runtime's connection settings and keep both links."""
        super().__init__(
            api_key=primary_runtime.api_key,
            api_base=primary_runtime.api_base,
            request_timeout_seconds=primary_runtime.request_timeout_seconds,
        )
        self.primary_runtime = primary_runtime
        self.primary_provider = primary_provider
        self.fallback_runtime = fallback_runtime
        self.fallback_provider = fallback_provider
        # These fields only record which link the most recent chat used, for
        # debugging and tests. Routing restarts from the primary on every call
        # and must never stick to the fallback across calls.
        self._remember(primary_provider, primary_runtime, used_fallback=False)

    def _remember(self, provider: LLMProvider, runtime: ProviderRuntime, *, used_fallback: bool) -> None:
        """Record which provider and runtime the current call is using."""
        self._last_runtime = runtime
        self._last_provider = provider
        self._last_call_used_fallback = used_fallback

    @property
    def fallback_activated(self) -> bool:
        """Whether the most recent chat actually used the fallback."""
        return self._last_call_used_fallback

    @property
    def active_runtime(self) -> ProviderRuntime:
        """Runtime that the most recent chat actually used."""
        return self._last_runtime

    async def chat(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """Try the primary provider, then the fallback if the primary errors.

        ``model`` applies to the primary attempt only; the fallback attempt
        always uses the fallback runtime's own model.
        """
        self._remember(self.primary_provider, self.primary_runtime, used_fallback=False)
        primary_response = self._decorate_response(
            await self._safe_chat(
                self.primary_provider,
                self.primary_runtime,
                messages=messages,
                tools=tools,
                model=model or self.primary_runtime.model,
                max_tokens=max_tokens,
                temperature=temperature,
            ),
            self.primary_runtime,
        )
        if not self._should_activate_fallback(primary_response):
            return primary_response

        assert self.fallback_provider is not None
        assert self.fallback_runtime is not None

        self._remember(self.fallback_provider, self.fallback_runtime, used_fallback=True)
        fallback_response = await self._safe_chat(
            self.fallback_provider,
            self.fallback_runtime,
            messages=messages,
            tools=tools,
            model=self.fallback_runtime.model,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        return self._decorate_response(fallback_response, self.fallback_runtime)

    def get_default_model(self) -> str:
        """Return the primary runtime's model."""
        return self.primary_runtime.model

    def _should_activate_fallback(self, response: LLMResponse) -> bool:
        """Fallback applies only when one is configured and the response is an error."""
        if self.fallback_provider is None or self.fallback_runtime is None:
            return False
        return response.finish_reason == "error"

    @staticmethod
    async def _safe_chat(
        provider: LLMProvider,
        runtime: ProviderRuntime,
        *,
        messages: list[dict],
        tools: list[dict] | None,
        model: str,
        max_tokens: int,
        temperature: float,
    ) -> LLMResponse:
        """Call ``provider.chat`` and turn a raised exception into an error response.

        This keeps the fallback trigger independent of whether each provider
        remembers to catch its own exceptions.
        """
        try:
            result = await provider.chat(
                messages=messages,
                tools=tools,
                model=model,
                max_tokens=max_tokens,
                temperature=temperature,
            )
        except Exception as exc:
            result = LLMResponse(
                content=f"Error: {exc}",
                finish_reason="error",
                provider_name=runtime.provider_name,
                model=runtime.model,
            )
        return result

    @staticmethod
    def _decorate_response(response: LLMResponse, runtime: ProviderRuntime) -> LLMResponse:
        """Fill in a missing provider name or model from ``runtime``, in place."""
        for attribute, value in (("provider_name", runtime.provider_name), ("model", runtime.model)):
            if getattr(response, attribute) is None:
                setattr(response, attribute, value)
        return response
|
||||
274
app-instance/backend/beaver/engine/providers/codex.py
Normal file
274
app-instance/backend/beaver/engine/providers/codex.py
Normal file
@ -0,0 +1,274 @@
|
||||
"""OpenAI Codex Responses provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
from .base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
try: # pragma: no cover - optional dependency
|
||||
import httpx
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
httpx = None # type: ignore[assignment]
|
||||
|
||||
try: # pragma: no cover - optional dependency
|
||||
from oauth_cli_kit import get_token as get_codex_token
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
get_codex_token = None # type: ignore[assignment]
|
||||
|
||||
DEFAULT_CODEX_URL = "https://chatgpt.com/backend-api/codex/responses"
|
||||
DEFAULT_ORIGINATOR = "beaver"
|
||||
|
||||
|
||||
class OpenAICodexProvider(LLMProvider):
    """Provider that calls the Responses API using Codex OAuth credentials."""

    def __init__(
        self,
        default_model: str = "openai-codex/gpt-5.1-codex",
        request_timeout_seconds: float | None = None,
    ) -> None:
        """No API key or base URL is stored; the token is fetched on each call."""
        super().__init__(api_key=None, api_base=None, request_timeout_seconds=request_timeout_seconds)
        self.default_model = default_model

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """Send one streaming request and return the aggregated result.

        ``max_tokens`` and ``temperature`` are part of the shared interface but
        are not placed in the request body. Missing dependencies, token
        retrieval errors, and request errors all come back as a response with
        ``finish_reason="error"`` rather than as exceptions.
        """
        if httpx is None or get_codex_token is None:
            return LLMResponse(content="Error: codex dependencies are not installed", finish_reason="error", provider_name="openai_codex")

        resolved_model = model or self.default_model
        system_prompt, input_items = _convert_messages(messages)
        body: dict[str, Any] = {
            "model": _strip_model_prefix(resolved_model),
            "store": False,
            "stream": True,
            "instructions": system_prompt,
            "input": input_items,
            "text": {"verbosity": "medium"},
            "include": ["reasoning.encrypted_content"],
            "prompt_cache_key": _prompt_cache_key(messages),
            "tool_choice": "auto",
            "parallel_tool_calls": True,
        }
        if tools:
            body["tools"] = _convert_tools(tools)

        try:
            # The token fetch lives inside the try so that an OAuth failure is
            # reported like any other failed call instead of escaping chat().
            token = await asyncio.to_thread(get_codex_token)
            headers = _build_headers(token.account_id, token.access)
            content, tool_calls, finish_reason = await _request_codex(
                DEFAULT_CODEX_URL,
                headers,
                body,
                verify=True,
                timeout_seconds=self.request_timeout_seconds or 600.0,
            )
        except Exception as exc:
            return LLMResponse(content=f"Error calling Codex: {exc}", finish_reason="error", provider_name="openai_codex")

        return LLMResponse(
            content=content,
            tool_calls=tool_calls,
            finish_reason=finish_reason,
            provider_name="openai_codex",
            model=resolved_model,
        )

    def get_default_model(self) -> str:
        """Return the model used when ``chat`` is called without ``model``."""
        return self.default_model
|
||||
|
||||
|
||||
def _strip_model_prefix(model: str) -> str:
|
||||
if model.startswith("openai-codex/") or model.startswith("openai_codex/"):
|
||||
return model.split("/", 1)[1]
|
||||
return model
|
||||
|
||||
|
||||
def _build_headers(account_id: str, token: str) -> dict[str, str]:
    """Assemble the HTTP headers for a streaming Codex request."""
    headers: dict[str, str] = {}
    # Authentication and account selection.
    headers["Authorization"] = "Bearer {}".format(token)
    headers["chatgpt-account-id"] = account_id
    # Client identification.
    headers["OpenAI-Beta"] = "responses=experimental"
    headers["originator"] = DEFAULT_ORIGINATOR
    headers["User-Agent"] = "beaver (python)"
    # JSON request body, server-sent-event response.
    headers["accept"] = "text/event-stream"
    headers["content-type"] = "application/json"
    return headers
|
||||
|
||||
|
||||
async def _request_codex(
    url: str,
    headers: dict[str, str],
    body: dict[str, Any],
    verify: bool,
    timeout_seconds: float,
) -> tuple[str, list[ToolCallRequest], str]:
    """POST ``body`` to ``url`` and consume the streamed reply.

    Returns:
        ``(content, tool_calls, finish_reason)`` as produced by ``_consume_sse``.

    Raises:
        RuntimeError: If the response status is not 200.
    """
    async with (
        httpx.AsyncClient(timeout=timeout_seconds, verify=verify) as client,
        client.stream("POST", url, headers=headers, json=body) as response,
    ):
        if response.status_code == 200:
            return await _consume_sse(response)
        # Read the body so the error message can include the server's reply.
        raw_body = await response.aread()
        raise RuntimeError(_friendly_error(response.status_code, raw_body.decode("utf-8", "ignore")))
|
||||
|
||||
|
||||
def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
converted: list[dict[str, Any]] = []
|
||||
for tool in tools:
|
||||
fn = (tool.get("function") or {}) if tool.get("type") == "function" else tool
|
||||
name = fn.get("name")
|
||||
if not name:
|
||||
continue
|
||||
params = fn.get("parameters") or {}
|
||||
converted.append(
|
||||
{
|
||||
"type": "function",
|
||||
"name": name,
|
||||
"description": fn.get("description") or "",
|
||||
"parameters": params if isinstance(params, dict) else {},
|
||||
}
|
||||
)
|
||||
return converted
|
||||
|
||||
|
||||
def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
    """Translate OpenAI-style chat messages into Responses API input items.

    Returns:
        ``(instructions, input_items)``. String system messages are joined
        with blank lines so that earlier ones are not silently lost.
    """
    system_parts: list[str] = []
    input_items: list[dict[str, Any]] = []
    for index, message in enumerate(messages):
        role = message.get("role")
        content = message.get("content")

        if role == "system":
            if isinstance(content, str) and content:
                system_parts.append(content)

        elif role == "user":
            input_items.append(_convert_user_message(content))

        elif role == "assistant":
            if isinstance(content, str) and content:
                input_items.append(
                    {
                        "type": "message",
                        "role": "assistant",
                        "content": [{"type": "output_text", "text": content}],
                        "status": "completed",
                        "id": f"msg_{index}",
                    }
                )
            for tool_call in message.get("tool_calls") or []:
                # Accept both {"function": {...}} and a flat tool-call dict.
                fn = tool_call.get("function") or tool_call
                call_id, item_id = _split_tool_call_id(tool_call.get("id"))
                arguments = fn.get("arguments")
                if not isinstance(arguments, str):
                    # The ``arguments`` field is sent as a JSON string, so
                    # serialize anything that is stored as a dict.
                    arguments = json.dumps(arguments, ensure_ascii=False) if arguments else ""
                input_items.append(
                    {
                        "type": "function_call",
                        "id": item_id or f"fc_{index}",
                        "call_id": call_id or f"call_{index}",
                        "name": fn.get("name"),
                        "arguments": arguments or "{}",
                    }
                )

        elif role == "tool":
            call_id, _ = _split_tool_call_id(message.get("tool_call_id"))
            output_text = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False)
            input_items.append(
                {
                    "type": "function_call_output",
                    "call_id": call_id,
                    "output": output_text,
                }
            )
    return "\n\n".join(system_parts), input_items
|
||||
|
||||
|
||||
def _convert_user_message(content: Any) -> dict[str, Any]:
|
||||
if isinstance(content, str):
|
||||
return {"role": "user", "content": [{"type": "input_text", "text": content}]}
|
||||
if isinstance(content, list):
|
||||
converted: list[dict[str, Any]] = []
|
||||
for item in content:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
if item.get("type") == "text":
|
||||
converted.append({"type": "input_text", "text": item.get("text", "")})
|
||||
elif item.get("type") == "image_url":
|
||||
url = (item.get("image_url") or {}).get("url")
|
||||
if url:
|
||||
converted.append({"type": "input_image", "image_url": url, "detail": "auto"})
|
||||
if converted:
|
||||
return {"role": "user", "content": converted}
|
||||
return {"role": "user", "content": [{"type": "input_text", "text": ""}]}
|
||||
|
||||
|
||||
def _split_tool_call_id(tool_call_id: Any) -> tuple[str, str | None]:
|
||||
if isinstance(tool_call_id, str) and tool_call_id:
|
||||
if "|" in tool_call_id:
|
||||
call_id, item_id = tool_call_id.split("|", 1)
|
||||
return call_id, item_id or None
|
||||
return tool_call_id, None
|
||||
return "call_0", None
|
||||
|
||||
|
||||
def _prompt_cache_key(messages: list[dict[str, Any]]) -> str:
|
||||
raw = json.dumps(messages, ensure_ascii=True, sort_keys=True)
|
||||
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
async def _iter_sse(response: Any) -> AsyncGenerator[dict[str, Any], None]:
|
||||
buffer: list[str] = []
|
||||
async for line in response.aiter_lines():
|
||||
if line == "":
|
||||
if buffer:
|
||||
data_lines = [item[5:].strip() for item in buffer if item.startswith("data:")]
|
||||
buffer = []
|
||||
if not data_lines:
|
||||
continue
|
||||
data = "\n".join(data_lines).strip()
|
||||
if not data or data == "[DONE]":
|
||||
continue
|
||||
try:
|
||||
yield json.loads(data)
|
||||
except Exception:
|
||||
continue
|
||||
continue
|
||||
buffer.append(line)
|
||||
|
||||
|
||||
# Maps a Responses API status onto the finish reasons used by ``LLMResponse``.
_CODEX_STATUS_TO_FINISH_REASON = {
    "completed": "stop",
    "incomplete": "length",
    "failed": "error",
    "cancelled": "error",
}


def _parse_tool_arguments(raw: Any) -> dict[str, Any]:
    """Decode tool-call arguments (JSON text or dict) into a dict; ``{}`` on failure."""
    if isinstance(raw, str):
        try:
            raw = json.loads(raw) if raw.strip() else {}
        except json.JSONDecodeError:
            return {}
    return raw if isinstance(raw, dict) else {}


async def _consume_sse(response: Any) -> tuple[str | None, list[ToolCallRequest], str]:
    """Aggregate a streamed Codex response into text, tool calls and a finish reason.

    Returns:
        ``(content, tool_calls, finish_reason)``; ``content`` is ``None`` when
        no text was produced.

    Raises:
        RuntimeError: If the stream carries an ``error`` or ``response.failed``
            event.
    """
    text_parts: list[str] = []
    # Function calls keyed by item id, kept in arrival order. A call is first
    # announced by ``output_item.added`` before its arguments have streamed,
    # so later events have to be merged in before the call can be built.
    pending: dict[str, dict[str, Any]] = {}
    finish_reason = "stop"

    async for event in _iter_sse(response):
        event_type = event.get("type")

        if event_type == "response.output_text.delta":
            text_parts.append(event.get("delta") or "")

        elif event_type in ("response.output_item.added", "response.output_item.done"):
            item = event.get("item") or {}
            if item.get("type") != "function_call":
                continue
            key = item.get("id") or item.get("call_id") or ""
            entry = pending.setdefault(key, {"streamed": []})
            # Later events win, but an empty value never overwrites a real one.
            for field_name in ("id", "call_id", "name", "arguments"):
                if item.get(field_name):
                    entry[field_name] = item[field_name]

        elif event_type == "response.function_call_arguments.delta":
            entry = pending.get(event.get("item_id") or event.get("call_id") or "")
            if entry is not None:
                entry["streamed"].append(event.get("delta") or "")

        elif event_type == "response.function_call_arguments.done":
            entry = pending.get(event.get("item_id") or event.get("call_id") or "")
            if entry is not None and event.get("arguments"):
                entry["arguments"] = event["arguments"]

        elif event_type in ("response.completed", "response.incomplete"):
            status = (event.get("response") or {}).get("status") or "completed"
            finish_reason = _CODEX_STATUS_TO_FINISH_REASON.get(status, "stop")

        elif event_type in ("error", "response.failed"):
            # Raise so the caller reports an error instead of treating a
            # failed stream as a normal, possibly truncated, answer.
            error = (event.get("response") or {}).get("error")
            detail = event.get("message") or (error.get("message") if isinstance(error, dict) else None)
            raise RuntimeError(f"Codex response failed: {detail}" if detail else "Codex response failed")

    tool_calls = [
        ToolCallRequest(
            id=f"{entry.get('call_id', 'call')}|{entry.get('id', '')}",
            name=entry.get("name", ""),
            # Prefer the final arguments; otherwise use the streamed deltas.
            arguments=_parse_tool_arguments(entry.get("arguments") or "".join(entry["streamed"])),
        )
        for entry in pending.values()
    ]
    return "".join(text_parts) or None, tool_calls, finish_reason
|
||||
|
||||
|
||||
def _friendly_error(status_code: int, body: str) -> str:
|
||||
return f"Codex API error ({status_code}): {body[:400]}"
|
||||
106
app-instance/backend/beaver/engine/providers/custom.py
Normal file
106
app-instance/backend/beaver/engine/providers/custom.py
Normal file
@ -0,0 +1,106 @@
|
||||
"""Direct OpenAI-compatible provider — bypasses LiteLLM."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
try: # pragma: no cover - optional dependency
|
||||
import json_repair
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
json_repair = None # type: ignore[assignment]
|
||||
|
||||
try: # pragma: no cover - optional dependency
|
||||
from openai import AsyncOpenAI
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
AsyncOpenAI = None # type: ignore[assignment]
|
||||
|
||||
|
||||
class CustomProvider(LLMProvider):
    """Provider that connects directly to any OpenAI-compatible endpoint."""

    def __init__(
        self,
        api_key: str = "no-key",
        api_base: str = "http://localhost:8000/v1",
        default_model: str = "default",
        request_timeout_seconds: float | None = None,
    ) -> None:
        """Store connection settings; the SDK client is built on first use."""
        super().__init__(api_key, api_base, request_timeout_seconds=request_timeout_seconds)
        self.default_model = default_model
        self._client = None

    def _client_or_raise(self):
        """Return the cached ``AsyncOpenAI`` client, creating it if needed.

        Raises:
            RuntimeError: If the optional ``openai`` package is not installed.
        """
        if AsyncOpenAI is None:
            raise RuntimeError("openai package is not installed")
        if self._client is None:
            client_kwargs: dict[str, Any] = {
                "api_key": self.api_key,
                "base_url": self.api_base,
            }
            # Forward the timeout only when one was configured: an explicit
            # ``timeout=None`` switches the SDK's default timeout off and lets
            # a stalled request hang indefinitely.
            if self.request_timeout_seconds is not None:
                client_kwargs["timeout"] = self.request_timeout_seconds
            self._client = AsyncOpenAI(**client_kwargs)
        return self._client

    @staticmethod
    def _parse_tool_arguments(raw_arguments: Any) -> dict[str, Any]:
        """Decode tool-call arguments into a dict, tolerating malformed JSON.

        ``json_repair`` is used when installed; otherwise the standard ``json``
        module is used. Anything that does not decode to an object yields ``{}``.
        """
        if not isinstance(raw_arguments, str):
            return raw_arguments if isinstance(raw_arguments, dict) else {}
        try:
            if json_repair is not None:
                parsed = json_repair.loads(raw_arguments)
            else:
                import json

                parsed = json.loads(raw_arguments)
        except (ValueError, TypeError):
            return {}
        return parsed if isinstance(parsed, dict) else {}

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """Send one chat-completions request and return a unified ``LLMResponse``.

        A missing SDK, a failed API call, or a reply without choices is
        reported as a response with ``finish_reason="error"``, not raised.
        """
        try:
            client = self._client_or_raise()
        except Exception as exc:
            return LLMResponse(content=f"Error: {exc}", finish_reason="error", provider_name="custom")

        resolved_model = model or self.default_model
        kwargs: dict[str, Any] = {
            "model": resolved_model,
            "messages": self.sanitize_empty_content(messages),
            "max_tokens": max(1, max_tokens),
            "temperature": temperature,
        }
        if tools:
            kwargs.update(tools=tools, tool_choice="auto")

        try:
            response = await client.chat.completions.create(**kwargs)
        except Exception as exc:
            return LLMResponse(content=f"Error: {exc}", finish_reason="error", provider_name="custom")

        choices = getattr(response, "choices", None)
        if not choices:
            # Report an empty choices list as an error, not as an IndexError.
            return LLMResponse(
                content="Error: provider returned no choices",
                finish_reason="error",
                provider_name="custom",
            )

        choice = choices[0]
        message = choice.message
        parsed_tool_calls = [
            ToolCallRequest(
                id=tool_call.id,
                name=tool_call.function.name,
                arguments=self._parse_tool_arguments(tool_call.function.arguments),
            )
            for tool_call in (message.tool_calls or [])
        ]

        usage = getattr(response, "usage", None)
        usage_payload: dict[str, Any] = {}
        if usage is not None:
            usage_payload = {
                "prompt_tokens": getattr(usage, "prompt_tokens", 0),
                "completion_tokens": getattr(usage, "completion_tokens", 0),
                "total_tokens": getattr(usage, "total_tokens", 0),
            }

        return LLMResponse(
            content=message.content,
            tool_calls=parsed_tool_calls,
            finish_reason=choice.finish_reason or "stop",
            usage=usage_payload,
            reasoning_content=getattr(message, "reasoning_content", None),
            provider_name="custom",
            model=resolved_model,
        )

    def get_default_model(self) -> str:
        """Return the model used when ``chat`` is called without ``model``."""
        return self.default_model
|
||||
235
app-instance/backend/beaver/engine/providers/factory.py
Normal file
235
app-instance/backend/beaver/engine/providers/factory.py
Normal file
@ -0,0 +1,235 @@
|
||||
"""Provider runtime 的统一工厂入口。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from .anthropic import AnthropicProvider
|
||||
from .base import LLMProvider
|
||||
from .chain import FallbackProviderChain
|
||||
from .codex import OpenAICodexProvider
|
||||
from .custom import CustomProvider
|
||||
from .litellm import LiteLLMProvider
|
||||
from .runtime import (
|
||||
ProviderRoutingConfig,
|
||||
ProviderRuntime,
|
||||
ProviderTarget,
|
||||
normalize_provider_target,
|
||||
resolve_auxiliary_runtime,
|
||||
resolve_embedding_runtime,
|
||||
resolve_fallback_runtime,
|
||||
resolve_provider_runtime,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ProviderBundle:
|
||||
"""一次运行所需的 provider 组合。
|
||||
|
||||
这里把三条常见链路收口到一起:
|
||||
- `main`:主对话
|
||||
- `fallback`:主链失败后的备用 provider
|
||||
- `auxiliary`:搜索摘要、压缩、memory flush 等辅助任务
|
||||
"""
|
||||
|
||||
main_runtime: ProviderRuntime
|
||||
main_provider: LLMProvider
|
||||
fallback_runtime: ProviderRuntime | None = None
|
||||
fallback_provider: LLMProvider | None = None
|
||||
auxiliary_runtime: ProviderRuntime | None = None
|
||||
auxiliary_provider: LLMProvider | None = None
|
||||
embedding_runtime: ProviderRuntime | None = None
|
||||
|
||||
|
||||
def build_provider_runtime(**kwargs: Any) -> ProviderRuntime:
    """Build a provider runtime by forwarding all keyword arguments to the resolver."""
    runtime = resolve_provider_runtime(**kwargs)
    return runtime
|
||||
|
||||
|
||||
def make_provider_from_runtime(runtime: ProviderRuntime) -> LLMProvider:
    """Create the concrete provider selected by ``runtime.spec.provider_impl``.

    ``"custom"``, ``"codex"`` and ``"anthropic"`` map to their dedicated
    providers; any other value falls through to ``LiteLLMProvider``.
    """
    implementation = runtime.spec.provider_impl
    model_name = runtime.default_model or runtime.model
    timeout = runtime.request_timeout_seconds

    if implementation == "custom":
        return CustomProvider(
            api_key=runtime.api_key or "no-key",
            api_base=runtime.api_base or "http://localhost:8000/v1",
            default_model=model_name,
            request_timeout_seconds=timeout,
        )

    if implementation == "codex":
        # The Codex provider takes no key or base URL.
        return OpenAICodexProvider(
            default_model=model_name,
            request_timeout_seconds=timeout,
        )

    if implementation == "anthropic":
        return AnthropicProvider(
            api_key=runtime.api_key,
            default_model=model_name,
            api_base=runtime.api_base,
            request_timeout_seconds=timeout,
        )

    return LiteLLMProvider(
        api_key=runtime.api_key,
        api_base=runtime.api_base,
        default_model=model_name,
        provider_name=runtime.provider_name,
        extra_headers=runtime.extra_headers,
        request_timeout_seconds=timeout,
        routing=runtime.routing,
    )
|
||||
|
||||
|
||||
def make_main_provider(**kwargs: Any) -> tuple[ProviderRuntime, LLMProvider]:
    """Build the provider used for the main conversation.

    Args:
        **kwargs: Forwarded to ``build_provider_runtime`` once the fallback
            options have been removed. Two fallback options are recognised:

            * ``fallback_target`` - a ``ProviderTarget`` or dict.
            * ``fallback_model`` - alias that is only consulted when
              ``fallback_target`` is ``None``. A bare string is treated as
              the fallback model name.

    Returns:
        ``(runtime, provider)``. When a fallback runtime can be resolved the
        provider is a ``FallbackProviderChain`` wrapping primary + fallback;
        otherwise it is the primary provider itself.
    """
    fallback_target = kwargs.pop("fallback_target", None)
    # Always consume the alias: the runtime resolver is keyword-only and does
    # not accept ``fallback_model``, so leaving it in ``kwargs`` (which used
    # to happen when both options were supplied) raised a TypeError.
    fallback_model = kwargs.pop("fallback_model", None)
    if fallback_target is None and fallback_model is not None:
        # A plain string cannot be normalised as a target mapping, so wrap it
        # into the dict form the resolver understands.
        if isinstance(fallback_model, str):
            fallback_target = {"model": fallback_model}
        else:
            fallback_target = fallback_model

    runtime = build_provider_runtime(
        auxiliary=False,
        fallback_target=fallback_target,
        role="main",
        source="main_config",
        **kwargs,
    )
    provider = make_provider_from_runtime(runtime)

    fallback_pair = make_fallback_provider(runtime, fallback_target)
    if fallback_pair is None:
        return runtime, provider

    fallback_runtime, fallback_provider = fallback_pair
    chain = FallbackProviderChain(runtime, provider, fallback_runtime, fallback_provider)
    return runtime, chain
|
||||
|
||||
|
||||
def make_fallback_provider(
    primary_runtime: ProviderRuntime,
    fallback_target: ProviderTarget | dict[str, Any] | None = None,
) -> tuple[ProviderRuntime, LLMProvider] | None:
    """Build the fallback provider for *primary_runtime*.

    An explicit (truthy) *fallback_target* wins; otherwise the target stored
    on the primary runtime is used.

    Returns:
        ``(fallback_runtime, fallback_provider)``, or ``None`` when no
        fallback runtime could be resolved.
    """
    effective_target = fallback_target or primary_runtime.fallback_target
    fallback_runtime = resolve_fallback_runtime(primary_runtime, effective_target)
    if fallback_runtime is not None:
        return fallback_runtime, make_provider_from_runtime(fallback_runtime)
    return None
|
||||
|
||||
|
||||
def make_aux_provider(
    main_runtime: ProviderRuntime | None = None,
    *,
    target: ProviderTarget | dict[str, Any] | None = None,
    task_name: str = "auxiliary",
    **kwargs: Any,
) -> tuple[ProviderRuntime, LLMProvider]:
    """Build the provider used for auxiliary tasks.

    Args:
        main_runtime: Main runtime to derive from. When given, resolution is
            delegated to ``resolve_auxiliary_runtime``.
        target: Explicit auxiliary target.
        task_name: Stored as the ``role`` of the resulting runtime.
        **kwargs: Used as the target mapping when *target* is ``None``;
            ignored when *target* is supplied.

    Returns:
        ``(runtime, provider)`` for the auxiliary task.

    Raises:
        ValueError: If there is no *main_runtime* and the target does not
            name a model.
    """
    if target is None and kwargs:
        target = kwargs

    if main_runtime is None:
        # Stand-alone mode: nothing to inherit from, so the target itself
        # must at least say which model to run.
        wanted = normalize_provider_target(target)
        if wanted is None or not wanted.model:
            raise ValueError("Auxiliary provider without main_runtime requires at least a model")
        runtime = build_provider_runtime(
            model=wanted.model,
            provider_name=wanted.provider_name,
            api_key=wanted.api_key,
            api_base=wanted.api_base,
            request_timeout_seconds=wanted.request_timeout_seconds,
            extra_headers=wanted.extra_headers,
            routing=wanted.routing,
            auxiliary=True,
            role=task_name,
            source="auxiliary_config",
        )
    else:
        runtime = resolve_auxiliary_runtime(main_runtime, target, task_name=task_name)

    return runtime, make_provider_from_runtime(runtime)
|
||||
|
||||
|
||||
def make_embedding_runtime(
    main_runtime: ProviderRuntime,
    *,
    target: ProviderTarget | dict[str, Any] | None = None,
    default_model: str = "text-embedding-v4",
) -> ProviderRuntime | None:
    """Build the runtime dedicated to embeddings.

    Delegates to ``resolve_embedding_runtime`` and returns whatever it
    yields (which may be ``None``).
    """
    embedding_runtime = resolve_embedding_runtime(
        main_runtime,
        target=target,
        default_model=default_model,
    )
    return embedding_runtime
|
||||
|
||||
|
||||
def make_provider_bundle(
    *,
    auxiliary_target: ProviderTarget | dict[str, Any] | None = None,
    auxiliary_task_name: str = "auxiliary",
    embedding_target: ProviderTarget | dict[str, Any] | None = None,
    embedding_model: str = "text-embedding-v4",
    **kwargs: Any,
) -> ProviderBundle:
    """Build the main, fallback and auxiliary providers plus the embedding runtime.

    Args:
        auxiliary_target: Target for auxiliary tasks. The auxiliary
            runtime/provider are only built when this is not ``None``.
        auxiliary_task_name: Role name given to the auxiliary runtime.
        embedding_target: Optional explicit embedding target.
        embedding_model: Default embedding model name.
        **kwargs: Forwarded to ``build_provider_runtime`` for the main
            runtime after ``fallback_target`` / ``fallback_model`` have been
            removed. ``fallback_model`` is an alias consulted only when
            ``fallback_target`` is ``None``; a bare string is treated as the
            fallback model name.

    Returns:
        A ``ProviderBundle``. ``main_provider`` is a ``FallbackProviderChain``
        when a fallback could be resolved, otherwise the primary provider.
    """
    runtime_kwargs = dict(kwargs)
    fallback_target = runtime_kwargs.pop("fallback_target", None)
    # Always consume the alias: the runtime resolver is keyword-only and does
    # not accept ``fallback_model``, so forwarding it (which used to happen
    # when both options were supplied) raised a TypeError.
    fallback_model = runtime_kwargs.pop("fallback_model", None)
    if fallback_target is None and fallback_model is not None:
        # A plain string cannot be normalised as a target mapping, so wrap it
        # into the dict form the resolver understands.
        if isinstance(fallback_model, str):
            fallback_target = {"model": fallback_model}
        else:
            fallback_target = fallback_model

    main_runtime = build_provider_runtime(
        auxiliary=False,
        fallback_target=fallback_target,
        role="main",
        source="main_config",
        **runtime_kwargs,
    )
    primary_provider = make_provider_from_runtime(main_runtime)

    fallback_runtime = None
    fallback_provider = None
    main_provider: LLMProvider = primary_provider
    fallback_pair = make_fallback_provider(main_runtime, fallback_target)
    if fallback_pair is not None:
        fallback_runtime, fallback_provider = fallback_pair
        main_provider = FallbackProviderChain(
            main_runtime,
            primary_provider,
            fallback_runtime,
            fallback_provider,
        )

    auxiliary_runtime = None
    auxiliary_provider = None
    if auxiliary_target is not None:
        auxiliary_runtime, auxiliary_provider = make_aux_provider(
            main_runtime,
            target=auxiliary_target,
            task_name=auxiliary_task_name,
        )

    embedding_runtime = make_embedding_runtime(
        main_runtime,
        target=embedding_target,
        default_model=embedding_model,
    )

    return ProviderBundle(
        main_runtime=main_runtime,
        main_provider=main_provider,
        fallback_runtime=fallback_runtime,
        fallback_provider=fallback_provider,
        auxiliary_runtime=auxiliary_runtime,
        auxiliary_provider=auxiliary_provider,
        embedding_runtime=embedding_runtime,
    )
|
||||
|
||||
|
||||
# Public API of the factory module: data types first, then builder functions.
__all__ = [
    # Data types
    "ProviderBundle",
    "ProviderRoutingConfig",
    "ProviderRuntime",
    "ProviderTarget",
    # Builders
    "build_provider_runtime",
    "make_aux_provider",
    "make_embedding_runtime",
    "make_fallback_provider",
    "make_main_provider",
    "make_provider_bundle",
    "make_provider_from_runtime",
]
|
||||
230
app-instance/backend/beaver/engine/providers/litellm.py
Normal file
230
app-instance/backend/beaver/engine/providers/litellm.py
Normal file
@ -0,0 +1,230 @@
|
||||
"""LiteLLM provider implementation for multi-provider support."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from .base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from .registry import find_by_model, find_gateway
|
||||
from .runtime import ProviderRoutingConfig
|
||||
|
||||
try: # pragma: no cover - optional dependency
|
||||
import json_repair
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
json_repair = None # type: ignore[assignment]
|
||||
|
||||
try: # pragma: no cover - optional dependency
|
||||
import litellm
|
||||
from litellm import acompletion
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
litellm = None # type: ignore[assignment]
|
||||
acompletion = None # type: ignore[assignment]
|
||||
|
||||
_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})
|
||||
|
||||
|
||||
class LiteLLMProvider(LLMProvider):
    """Provider that reaches most LLM backends through LiteLLM."""

    def __init__(
        self,
        api_key: str | None = None,
        api_base: str | None = None,
        default_model: str = "anthropic/claude-opus-4-5",
        extra_headers: dict[str, str] | None = None,
        provider_name: str | None = None,
        request_timeout_seconds: float | None = None,
        routing: ProviderRoutingConfig | None = None,
    ) -> None:
        """Store the connection settings and detect a gateway, if any.

        Args:
            api_key: Credential passed to LiteLLM on every request.
            api_base: Optional endpoint override.
            default_model: Model used when ``chat`` is called without one.
            extra_headers: Extra HTTP headers sent with every request.
            provider_name: Registry name of the provider.
            request_timeout_seconds: Forwarded to the base class.
            routing: Provider routing options; only applied when
                ``provider_name`` is ``"openrouter"``.
        """
        # NOTE(review): the timeout is handed to the base class, but chat()
        # below never adds a timeout entry to the acompletion kwargs -
        # confirm whether the base class applies it some other way.
        super().__init__(api_key, api_base, request_timeout_seconds=request_timeout_seconds)
        self.default_model = default_model
        self.extra_headers = extra_headers or {}
        self.routing = routing
        self.provider_name = provider_name
        self._gateway = find_gateway(provider_name, api_key, api_base)
        if litellm is not None:
            # Module-level LiteLLM switches; they are set on the shared
            # litellm module, not per instance.
            litellm.suppress_debug_info = True
            litellm.drop_params = True

    def _build_env_overrides(self, api_key: str | None, api_base: str | None, model: str) -> dict[str, str]:
        """Compute the environment variables LiteLLM needs for one request.

        LiteLLM still prefers environment variables for some providers. To
        keep different runtimes from polluting each other, this only builds
        the overrides for the current request; they are injected temporarily
        at call time.

        Returns:
            Mapping of variable name to value; empty when there is no API
            key or the matching spec declares no ``env_key``.
        """
        if not api_key:
            return {}
        spec = self._gateway or find_by_model(model)
        if spec is None or not spec.env_key:
            return {}

        overrides: dict[str, str] = {spec.env_key: api_key}
        effective_base = api_base or spec.default_api_base
        for env_name, env_template in spec.env_extras:
            # Templates may reference the key and/or the base URL.
            overrides[env_name] = env_template.replace("{api_key}", api_key).replace("{api_base}", effective_base)
        return overrides

    @contextmanager
    def _temporary_env(self, overrides: dict[str, str]):
        """Inject *overrides* into ``os.environ`` for the duration of the block.

        Previous values are restored on exit; variables that did not exist
        before are removed again.
        """
        # NOTE(review): chat() holds this context across an ``await``. If two
        # requests with overlapping variable names interleave, the second one
        # records the first one's temporary value as "previous" and may
        # restore it after the first has finished - confirm whether
        # concurrent use is expected.
        if not overrides:
            yield
            return

        sentinel = object()  # marks "variable was not set before"
        previous: dict[str, object] = {}
        for key, value in overrides.items():
            previous[key] = os.environ.get(key, sentinel)
            os.environ[key] = value
        try:
            yield
        finally:
            for key, old_value in previous.items():
                if old_value is sentinel:
                    os.environ.pop(key, None)
                else:
                    os.environ[key] = str(old_value)

    def _resolve_model(self, model: str) -> str:
        """Return the model identifier in the form LiteLLM expects.

        With a detected gateway, the gateway's ``litellm_prefix`` is applied
        (optionally after stripping any existing ``provider/`` prefix).
        Otherwise the prefix of the spec matched by model name is added,
        unless the model already starts with one of its ``skip_prefixes``.
        """
        if self._gateway:
            prefix = self._gateway.litellm_prefix
            resolved = model.split("/")[-1] if self._gateway.strip_model_prefix else model
            if prefix and not resolved.startswith(f"{prefix}/"):
                resolved = f"{prefix}/{resolved}"
            return resolved

        spec = find_by_model(model)
        if spec and spec.litellm_prefix:
            if not any(model.startswith(skip) for skip in spec.skip_prefixes):
                model = f"{spec.litellm_prefix}/{model}"
        return model

    @staticmethod
    def _sanitize_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Drop message keys outside ``_ALLOWED_MSG_KEYS``.

        Assistant messages that end up without a ``content`` key get
        ``content=None`` added.
        """
        sanitized = []
        for message in messages:
            clean = {key: value for key, value in message.items() if key in _ALLOWED_MSG_KEYS}
            if clean.get("role") == "assistant" and "content" not in clean:
                clean["content"] = None
            sanitized.append(clean)
        return sanitized

    def _apply_model_overrides(self, original_model: str, kwargs: dict[str, Any]) -> None:
        """Merge the first matching per-model override into *kwargs* in place.

        A pattern matches when it is a substring of the lower-cased model
        name; only the first match is applied.
        """
        spec = find_by_model(original_model)
        if spec is None:
            return
        model_lower = original_model.lower()
        for pattern, overrides in spec.model_overrides:
            if pattern in model_lower:
                kwargs.update(overrides)
                return

    def _apply_openrouter_routing(self, kwargs: dict[str, Any]) -> None:
        """Add the ``provider`` routing payload to *kwargs* when applicable.

        Does nothing unless ``provider_name`` is ``"openrouter"`` and a
        routing config is present. Only truthy routing fields are included.
        """
        if self.provider_name != "openrouter" or self.routing is None:
            return
        provider_payload: dict[str, Any] = {}
        if self.routing.sort:
            provider_payload["sort"] = self.routing.sort
        if self.routing.only:
            provider_payload["only"] = self.routing.only
        if self.routing.ignore:
            provider_payload["ignore"] = self.routing.ignore
        if self.routing.order:
            provider_payload["order"] = self.routing.order
        if self.routing.require_parameters:
            provider_payload["require_parameters"] = True
        if self.routing.data_collection:
            provider_payload["data_collection"] = self.routing.data_collection
        if provider_payload:
            kwargs["provider"] = provider_payload

    @staticmethod
    def _parse_tool_arguments(raw_arguments: Any) -> Any:
        """Decode one tool call's arguments.

        Non-string values are returned unchanged. Strings are parsed with
        ``json_repair`` when it is installed, otherwise with ``json``.
        """
        if not isinstance(raw_arguments, str):
            return raw_arguments
        try:
            if json_repair is not None:
                return json_repair.loads(raw_arguments)
            return json.loads(raw_arguments)
        except Exception as exc:
            # Do not let one malformed tool call blow up the whole turn. The
            # ToolExecutor later turns this marker into a standard tool
            # failure.
            return {
                "__beaver_tool_argument_parse_error__": str(exc),
                "__raw_arguments__": raw_arguments,
            }

    @staticmethod
    def _extract_usage(response: Any) -> dict[str, Any]:
        """Return the token usage of *response* as a plain dict (may be empty)."""
        usage = getattr(response, "usage", None)
        if usage is None:
            return {}
        return {
            "prompt_tokens": getattr(usage, "prompt_tokens", 0),
            "completion_tokens": getattr(usage, "completion_tokens", 0),
            "total_tokens": getattr(usage, "total_tokens", 0),
        }

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """Send one chat completion request through LiteLLM.

        Args:
            messages: Conversation messages; sanitized before sending.
            tools: Optional tool definitions; enables ``tool_choice="auto"``.
            model: Model override; defaults to ``default_model``.
            max_tokens: Completion limit, clamped to at least 1.
            temperature: Sampling temperature (a per-model override may
                replace it).

        Returns:
            The parsed ``LLMResponse``. Failures (LiteLLM missing, request
            exception, response without choices) are reported as a response
            with ``finish_reason="error"`` rather than raised.
        """
        if acompletion is None:
            return LLMResponse(content="Error: litellm is not installed", finish_reason="error", provider_name=self.provider_name)

        original_model = model or self.default_model
        resolved_model = self._resolve_model(original_model)
        sanitized_messages = self._sanitize_messages(self.sanitize_empty_content(messages))
        kwargs: dict[str, Any] = {
            "model": resolved_model,
            "messages": sanitized_messages,
            "max_tokens": max(1, max_tokens),
            "temperature": temperature,
        }
        if self.api_key:
            kwargs["api_key"] = self.api_key
        if self.api_base:
            kwargs["api_base"] = self.api_base
        if self.extra_headers:
            kwargs["extra_headers"] = self.extra_headers
        if tools:
            kwargs["tools"] = tools
            kwargs["tool_choice"] = "auto"
        self._apply_model_overrides(original_model, kwargs)
        self._apply_openrouter_routing(kwargs)
        env_overrides = self._build_env_overrides(self.api_key, self.api_base, original_model)

        try:
            with self._temporary_env(env_overrides):
                response = await acompletion(**kwargs)
        except Exception as exc:
            return LLMResponse(content=f"Error: {exc}", finish_reason="error", provider_name=self.provider_name, model=resolved_model)

        # A response without choices used to raise IndexError here; report it
        # the same way as every other failure of this method.
        choices = getattr(response, "choices", None)
        if not choices:
            return LLMResponse(
                content="Error: provider returned no choices",
                finish_reason="error",
                provider_name=self.provider_name,
                model=resolved_model,
            )

        choice = choices[0]
        message = choice.message
        tool_calls: list[ToolCallRequest] = [
            ToolCallRequest(
                id=tool_call.id,
                name=tool_call.function.name,
                arguments=self._parse_tool_arguments(tool_call.function.arguments),
            )
            for tool_call in message.tool_calls or []
        ]

        return LLMResponse(
            content=getattr(message, "content", None),
            tool_calls=tool_calls,
            finish_reason=getattr(choice, "finish_reason", "stop") or "stop",
            usage=self._extract_usage(response),
            reasoning_content=getattr(message, "reasoning_content", None),
            provider_name=self.provider_name or "litellm",
            model=resolved_model,
        )

    def get_default_model(self) -> str:
        """Return the model used when ``chat`` is called without one."""
        return self.default_model
|
||||
249
app-instance/backend/beaver/engine/providers/registry.py
Normal file
249
app-instance/backend/beaver/engine/providers/registry.py
Normal file
@ -0,0 +1,249 @@
|
||||
"""Provider registry: 统一维护 provider 元数据与匹配规则。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ProviderSpec:
|
||||
"""单个 provider 的元数据定义。"""
|
||||
|
||||
name: str
|
||||
keywords: tuple[str, ...]
|
||||
env_key: str
|
||||
display_name: str = ""
|
||||
litellm_prefix: str = ""
|
||||
skip_prefixes: tuple[str, ...] = ()
|
||||
env_extras: tuple[tuple[str, str], ...] = ()
|
||||
is_gateway: bool = False
|
||||
is_local: bool = False
|
||||
detect_by_key_prefix: str = ""
|
||||
detect_by_base_keyword: str = ""
|
||||
default_api_base: str = ""
|
||||
strip_model_prefix: bool = False
|
||||
model_overrides: tuple[tuple[str, dict[str, Any]], ...] = ()
|
||||
is_oauth: bool = False
|
||||
is_direct: bool = False
|
||||
supports_prompt_caching: bool = False
|
||||
api_mode: str = "chat_completions"
|
||||
provider_impl: str = "litellm"
|
||||
|
||||
@property
|
||||
def label(self) -> str:
|
||||
return self.display_name or self.name.title()
|
||||
|
||||
|
||||
# Registry of known providers. Order matters: the lookup helpers in this
# module return the first matching entry.
PROVIDERS: tuple[ProviderSpec, ...] = (
    # --- Direct / user-supplied endpoint ---
    ProviderSpec(
        name="custom",
        keywords=(),
        env_key="",
        display_name="Custom",
        is_direct=True,
        provider_impl="custom",
        api_mode="chat_completions",
    ),
    # --- Gateways ---
    ProviderSpec(
        name="openrouter",
        keywords=("openrouter",),
        env_key="OPENROUTER_API_KEY",
        display_name="OpenRouter",
        litellm_prefix="openrouter",
        is_gateway=True,
        detect_by_key_prefix="sk-or-",
        detect_by_base_keyword="openrouter",
        default_api_base="https://openrouter.ai/api/v1",
        supports_prompt_caching=True,
    ),
    ProviderSpec(
        name="aihubmix",
        keywords=("aihubmix",),
        env_key="OPENAI_API_KEY",
        display_name="AiHubMix",
        litellm_prefix="openai",
        is_gateway=True,
        detect_by_base_keyword="aihubmix",
        default_api_base="https://aihubmix.com/v1",
        strip_model_prefix=True,
    ),
    ProviderSpec(
        name="siliconflow",
        keywords=("siliconflow",),
        env_key="OPENAI_API_KEY",
        display_name="SiliconFlow",
        litellm_prefix="openai",
        is_gateway=True,
        detect_by_base_keyword="siliconflow",
        default_api_base="https://api.siliconflow.cn/v1",
    ),
    ProviderSpec(
        name="volcengine",
        keywords=("volcengine", "volces", "ark"),
        env_key="OPENAI_API_KEY",
        display_name="VolcEngine",
        litellm_prefix="volcengine",
        is_gateway=True,
        detect_by_base_keyword="volces",
        default_api_base="https://ark.cn-beijing.volces.com/api/v3",
    ),
    # --- Standard providers ---
    ProviderSpec(
        name="anthropic",
        keywords=("anthropic", "claude"),
        env_key="ANTHROPIC_API_KEY",
        display_name="Anthropic",
        supports_prompt_caching=True,
        api_mode="anthropic_messages",
        provider_impl="anthropic",
    ),
    ProviderSpec(
        name="openai",
        keywords=("openai", "gpt"),
        env_key="OPENAI_API_KEY",
        display_name="OpenAI",
    ),
    ProviderSpec(
        name="openai_codex",
        keywords=("openai-codex", "codex"),
        env_key="",
        display_name="OpenAI Codex",
        is_oauth=True,
        detect_by_base_keyword="codex",
        default_api_base="https://chatgpt.com/backend-api",
        api_mode="codex_responses",
        provider_impl="codex",
    ),
    ProviderSpec(
        name="github_copilot",
        keywords=("github_copilot", "copilot"),
        env_key="",
        display_name="Github Copilot",
        litellm_prefix="github_copilot",
        skip_prefixes=("github_copilot/",),
        is_oauth=True,
    ),
    ProviderSpec(
        name="deepseek",
        keywords=("deepseek",),
        env_key="DEEPSEEK_API_KEY",
        display_name="DeepSeek",
        litellm_prefix="deepseek",
        skip_prefixes=("deepseek/",),
    ),
    ProviderSpec(
        name="gemini",
        keywords=("gemini",),
        env_key="GEMINI_API_KEY",
        display_name="Gemini",
        litellm_prefix="gemini",
        skip_prefixes=("gemini/",),
    ),
    ProviderSpec(
        name="zhipu",
        keywords=("zhipu", "glm", "zai"),
        env_key="ZAI_API_KEY",
        display_name="Zhipu AI",
        litellm_prefix="zai",
        skip_prefixes=("zhipu/", "zai/", "openrouter/", "hosted_vllm/"),
        env_extras=(("ZHIPUAI_API_KEY", "{api_key}"),),
    ),
    ProviderSpec(
        name="dashscope",
        keywords=("qwen", "dashscope"),
        env_key="DASHSCOPE_API_KEY",
        display_name="DashScope",
        litellm_prefix="dashscope",
        skip_prefixes=("dashscope/", "openrouter/"),
    ),
    ProviderSpec(
        name="moonshot",
        keywords=("moonshot", "kimi"),
        env_key="MOONSHOT_API_KEY",
        display_name="Moonshot",
        litellm_prefix="moonshot",
        skip_prefixes=("moonshot/", "openrouter/"),
        env_extras=(("MOONSHOT_API_BASE", "{api_base}"),),
        default_api_base="https://api.moonshot.ai/v1",
        model_overrides=(("kimi-k2.5", {"temperature": 1.0}),),
    ),
    ProviderSpec(
        name="minimax",
        keywords=("minimax",),
        env_key="MINIMAX_API_KEY",
        display_name="MiniMax",
        litellm_prefix="minimax",
        skip_prefixes=("minimax/", "openrouter/"),
        default_api_base="https://api.minimax.io/v1",
    ),
    # --- Local deployment ---
    ProviderSpec(
        name="vllm",
        keywords=("vllm",),
        env_key="HOSTED_VLLM_API_KEY",
        display_name="vLLM/Local",
        litellm_prefix="hosted_vllm",
        is_local=True,
    ),
    ProviderSpec(
        name="groq",
        keywords=("groq",),
        env_key="GROQ_API_KEY",
        display_name="Groq",
        litellm_prefix="groq",
        skip_prefixes=("groq/",),
    ),
)
|
||||
|
||||
|
||||
def find_by_name(name: str) -> ProviderSpec | None:
    """Return the registry entry whose ``name`` equals *name*, or ``None``."""
    return next((candidate for candidate in PROVIDERS if candidate.name == name), None)
|
||||
|
||||
|
||||
def find_by_model(model: str) -> ProviderSpec | None:
    """Match a provider spec from a model name.

    Resolution order:

    1. Explicit ``prefix/`` equal to a spec ``name`` (any kind of spec).
    2. Explicit ``prefix/`` equal to a spec's ``litellm_prefix``.
    3. Keyword substring match, restricted to specs that are neither
       gateways nor local.

    Hyphens and underscores are treated as equivalent and matching is
    case-insensitive.

    Args:
        model: Model identifier, optionally prefixed with ``provider/``.

    Returns:
        The first matching spec, or ``None`` when nothing matches.
    """
    model_lower = model.lower()
    model_normalized = model_lower.replace("-", "_")
    model_prefix = model_lower.split("/", 1)[0] if "/" in model_lower else ""
    normalized_prefix = model_prefix.replace("-", "_")

    if model_prefix:
        # An explicit prefix has the highest priority, and it must not be
        # limited to standard providers:
        # - `openrouter/...` should hit openrouter directly
        # - `hosted_vllm/...` should map back to the local vllm provider
        # - `github_copilot/...codex` must not be mistaken for openai_codex
        #
        # The exact spec name is checked before the litellm prefix alias.
        # Several gateways use "openai" as their litellm prefix and are
        # listed before the real `openai` entry, so a single combined pass
        # resolved `openai/...` to the first such gateway instead of OpenAI.
        for spec in PROVIDERS:
            if spec.name == normalized_prefix:
                return spec
        for spec in PROVIDERS:
            if spec.litellm_prefix and spec.litellm_prefix.replace("-", "_") == normalized_prefix:
                return spec

    for spec in PROVIDERS:
        if spec.is_gateway or spec.is_local:
            continue
        if any(keyword in model_lower or keyword.replace("-", "_") in model_normalized for keyword in spec.keywords):
            return spec
    return None
|
||||
|
||||
|
||||
def find_gateway(
    provider_name: str | None = None,
    api_key: str | None = None,
    api_base: str | None = None,
) -> ProviderSpec | None:
    """Identify a gateway or local provider from config name, key or base URL.

    A spec found by *provider_name* is returned only if it is flagged as
    gateway or local. Otherwise the registry is scanned in order and the
    first spec whose ``detect_by_key_prefix`` prefixes *api_key*, or whose
    ``detect_by_base_keyword`` occurs in *api_base*, is returned.
    """
    named = find_by_name(provider_name) if provider_name else None
    if named is not None and (named.is_gateway or named.is_local):
        return named

    for candidate in PROVIDERS:
        key_prefix = candidate.detect_by_key_prefix
        if key_prefix and api_key and api_key.startswith(key_prefix):
            return candidate
        base_keyword = candidate.detect_by_base_keyword
        if base_keyword and api_base and base_keyword in api_base:
            return candidate
    return None
|
||||
408
app-instance/backend/beaver/engine/providers/runtime.py
Normal file
408
app-instance/backend/beaver/engine/providers/runtime.py
Normal file
@ -0,0 +1,408 @@
|
||||
"""Hermes 风格的 provider runtime resolution。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field, replace
|
||||
from typing import Any
|
||||
|
||||
from .registry import ProviderSpec, find_by_model, find_by_name, find_gateway
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ProviderRoutingConfig:
|
||||
"""OpenRouter provider routing 配置。"""
|
||||
|
||||
sort: str | None = None
|
||||
only: list[str] = field(default_factory=list)
|
||||
ignore: list[str] = field(default_factory=list)
|
||||
order: list[str] = field(default_factory=list)
|
||||
require_parameters: bool = False
|
||||
data_collection: str | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ProviderTarget:
|
||||
"""一次 provider 选路请求的标准化配置。
|
||||
|
||||
这层不是具体 runtime,而是“调用方想要什么”:
|
||||
- 用哪个 provider
|
||||
- 跑哪个 model
|
||||
- 是否指定自定义 base_url
|
||||
- 是否带额外 headers / routing
|
||||
|
||||
后面 `resolve_provider_runtime()` 会把它真正解析成可实例化的 runtime。
|
||||
"""
|
||||
|
||||
provider_name: str | None = None
|
||||
model: str | None = None
|
||||
api_key: str | None = None
|
||||
api_base: str | None = None
|
||||
extra_headers: dict[str, str] = field(default_factory=dict)
|
||||
request_timeout_seconds: float | None = None
|
||||
routing: ProviderRoutingConfig | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ProviderRuntime:
|
||||
"""运行时真正使用的 provider 解析结果。"""
|
||||
|
||||
spec: ProviderSpec
|
||||
model: str
|
||||
provider_name: str
|
||||
api_mode: str
|
||||
api_key: str | None = None
|
||||
api_base: str | None = None
|
||||
default_model: str | None = None
|
||||
request_timeout_seconds: float | None = None
|
||||
extra_headers: dict[str, str] = field(default_factory=dict)
|
||||
routing: ProviderRoutingConfig | None = None
|
||||
fallback_target: ProviderTarget | None = None
|
||||
auxiliary: bool = False
|
||||
role: str = "main"
|
||||
source: str = "runtime"
|
||||
|
||||
|
||||
def resolve_provider_runtime(
    *,
    model: str,
    provider_name: str | None = None,
    api_key: str | None = None,
    api_base: str | None = None,
    request_timeout_seconds: float | None = None,
    extra_headers: dict[str, str] | None = None,
    routing: ProviderRoutingConfig | None = None,
    fallback_target: ProviderTarget | dict[str, Any] | None = None,
    auxiliary: bool = False,
    role: str = "main",
    source: str = "runtime",
) -> ProviderRuntime:
    """Resolve caller-supplied settings into a unified runtime.

    Spec selection: a detected gateway wins; otherwise the spec is looked up
    by *provider_name* when given, else by *model*. If nothing matches, the
    ``custom`` spec is used provided an *api_base* was supplied.

    Raises:
        ValueError: If no spec matches and no *api_base* was given.
    """
    gateway = find_gateway(provider_name, api_key, api_base)
    via_gateway = gateway is not None

    if via_gateway:
        spec = gateway
    else:
        spec = find_by_name(provider_name) if provider_name else find_by_model(model)

    if spec is None:
        if not api_base:
            raise ValueError(f"Unable to resolve provider for model={model!r} provider_name={provider_name!r}")
        # Unknown provider but an explicit endpoint: treat it as custom.
        spec = find_by_name("custom")

    final_model = _resolve_model_name(spec, model, gateway_mode=via_gateway)
    final_api_base = api_base or spec.default_api_base or None

    return ProviderRuntime(
        spec=spec,
        model=final_model,
        provider_name=spec.name,
        api_mode=spec.api_mode,
        api_key=api_key,
        api_base=final_api_base,
        default_model=final_model,
        request_timeout_seconds=request_timeout_seconds,
        extra_headers=extra_headers or {},
        routing=routing,
        fallback_target=normalize_provider_target(fallback_target),
        auxiliary=auxiliary,
        role=role,
        source=source,
    )
|
||||
|
||||
|
||||
def normalize_provider_target(target: ProviderTarget | dict[str, Any] | None) -> ProviderTarget | None:
    """Collapse a dict/object provider config into a ``ProviderTarget``.

    ``None`` stays ``None`` and an existing ``ProviderTarget`` is returned
    as-is. For mappings, several common spellings are accepted (the first
    one wins when its value is not ``None``):

    - ``provider_name`` or ``provider``
    - ``api_base`` or ``base_url``
    - ``extra_headers`` or ``headers``
    - ``request_timeout_seconds`` or ``timeout``

    A ``routing`` dict is converted to ``ProviderRoutingConfig``.
    """
    if target is None:
        return None
    if isinstance(target, ProviderTarget):
        return target

    def _pick(primary: str, alias: str) -> Any:
        # Prefer the canonical key; consult the alias only when it is None.
        value = target.get(primary)
        return target.get(alias) if value is None else value

    name = _pick("provider_name", "provider")
    base = _pick("api_base", "base_url")
    headers = _pick("extra_headers", "headers")
    timeout = _pick("request_timeout_seconds", "timeout")

    routing = target.get("routing")
    if isinstance(routing, dict):
        routing = ProviderRoutingConfig(**routing)

    return ProviderTarget(
        provider_name=name,
        model=target.get("model"),
        api_key=target.get("api_key"),
        api_base=base,
        extra_headers=dict(headers or {}),
        request_timeout_seconds=timeout,
        routing=routing,
    )
|
||||
|
||||
|
||||
def resolve_fallback_runtime(
    primary_runtime: ProviderRuntime,
    fallback_target: ProviderTarget | dict[str, Any] | None,
) -> ProviderRuntime | None:
    """Resolve a fallback configuration into its own runtime.

    Hermes-style fallback means "switch to another provider:model after the
    main provider fails". This function only resolves the fallback; deciding
    when to activate it is left to the chain/factory layer above.

    Returns:
        The fallback runtime, or ``None`` when there is no target or the
        target names no model.
    """
    wanted = normalize_provider_target(fallback_target)
    if wanted is None or not wanted.model:
        return None

    provider = wanted.provider_name
    if provider in {None, "", "main"}:
        provider = primary_runtime.provider_name

    key = wanted.api_key
    base = wanted.api_base
    headers = dict(wanted.extra_headers)

    # Inherit the main chain's credentials and headers only when the fallback
    # does not explicitly switch provider or base URL.
    same_endpoint = provider == primary_runtime.provider_name and not base
    if same_endpoint:
        key = key or primary_runtime.api_key
        base = base or primary_runtime.api_base
        if not headers:
            headers = dict(primary_runtime.extra_headers)

    timeout = wanted.request_timeout_seconds or primary_runtime.request_timeout_seconds
    return resolve_provider_runtime(
        model=wanted.model,
        provider_name=provider,
        api_key=key,
        api_base=base,
        request_timeout_seconds=timeout,
        extra_headers=headers,
        routing=wanted.routing,
        auxiliary=False,
        role="fallback",
        source="fallback_config",
    )
|
||||
|
||||
|
||||
def resolve_auxiliary_runtime(
    primary_runtime: ProviderRuntime,
    target: ProviderTarget | dict[str, Any] | None = None,
    *,
    task_name: str = "auxiliary",
) -> ProviderRuntime:
    """Resolve the runtime dedicated to an auxiliary task.

    Three kinds of input are supported:

    - ``None`` / ``provider=main``: reuse the main chain's provider
    - explicit ``provider + model``: use an independent provider
    - only ``model``: match the provider automatically from the model name

    The resulting runtime always has ``auxiliary=True`` and
    ``role=task_name``.
    """
    wanted = normalize_provider_target(target)
    if wanted is None:
        return _clone_runtime(
            primary_runtime,
            auxiliary=True,
            role=task_name,
            source="main_runtime",
        )

    provider_name = wanted.provider_name
    has_base = bool(wanted.api_base)

    # Per-field "target value, else main value" choices shared by the
    # branches below.
    timeout = wanted.request_timeout_seconds or primary_runtime.request_timeout_seconds
    routing = wanted.routing or primary_runtime.routing

    # No provider switch, no endpoint, no model: clone the main runtime and
    # only layer the optional overrides on top.
    if provider_name in {None, "", "main"} and not has_base and not wanted.model:
        return _clone_runtime(
            primary_runtime,
            auxiliary=True,
            role=task_name,
            source="main_runtime",
            routing=routing,
            extra_headers=wanted.extra_headers or primary_runtime.extra_headers,
            request_timeout_seconds=timeout,
        )

    # Explicit "main" with a model and/or endpoint: stay on the main
    # provider, filling every unset field from the main runtime.
    if provider_name == "main":
        return resolve_provider_runtime(
            model=wanted.model or primary_runtime.model,
            provider_name=primary_runtime.provider_name,
            api_key=wanted.api_key or primary_runtime.api_key,
            api_base=wanted.api_base or primary_runtime.api_base,
            request_timeout_seconds=timeout,
            extra_headers=wanted.extra_headers or primary_runtime.extra_headers,
            routing=routing,
            auxiliary=True,
            role=task_name,
            source="main_runtime",
        )

    if provider_name in {"auto", None, ""} and not has_base and wanted.model is None:
        return _clone_runtime(
            primary_runtime,
            auxiliary=True,
            role=task_name,
            source="auto->main",
        )

    model = wanted.model or primary_runtime.model
    provider = wanted.provider_name
    if provider in {"auto", "", None} and not has_base:
        # The first-stage implementation of `auto` stays conservative:
        # - with an explicit model, match the provider from the model name
        # - if nothing matches, fall back to the main chain's provider
        matched = find_by_model(model)
        provider = matched.name if matched is not None else primary_runtime.provider_name

    key = wanted.api_key
    base = wanted.api_base
    headers = dict(wanted.extra_headers)

    # Same provider and no explicit endpoint: inherit credentials/headers.
    if provider == primary_runtime.provider_name and not base:
        key = key or primary_runtime.api_key
        base = base or primary_runtime.api_base
        if not headers:
            headers = dict(primary_runtime.extra_headers)

    return resolve_provider_runtime(
        model=model,
        provider_name=provider,
        api_key=key,
        api_base=base,
        request_timeout_seconds=timeout,
        extra_headers=headers,
        routing=routing,
        auxiliary=True,
        role=task_name,
        source="auxiliary_config",
    )
|
||||
|
||||
|
||||
def resolve_embedding_runtime(
    primary_runtime: ProviderRuntime,
    target: ProviderTarget | dict[str, Any] | None = None,
    *,
    default_model: str = "text-embedding-v4",
) -> ProviderRuntime | None:
    """Resolve the runtime dedicated to embedding requests.

    The decision of which model / api_base / api_key embeddings use is kept
    inside the provider layer so that retrieval code higher up does not have
    to borrow credentials from the main or auxiliary provider on its own.

    Args:
        primary_runtime: Runtime of the main chain; used as the source of
            inherited credentials, headers and timeout.
        target: Optional explicit embedding target (object or plain dict).
            It is passed through ``normalize_provider_target`` first.
        default_model: Embedding model used when the target names none.

    Returns:
        The resolved embedding runtime, or ``None`` when there is no explicit
        target and the primary runtime is not OpenAI-compatible.

    Raises:
        ValueError: If an explicit target resolves to a runtime that is not
            OpenAI-compatible.
    """

    embedding_target = normalize_provider_target(target)

    if embedding_target is None:
        # Without explicit embedding configuration, api_base/api_key are only
        # inherited when the main chain itself is OpenAI-compatible.
        # Otherwise nothing is guessed.
        if _supports_openai_embeddings(primary_runtime):
            return resolve_provider_runtime(
                model=default_model,
                provider_name="openai",
                api_key=primary_runtime.api_key,
                api_base=primary_runtime.api_base,
                request_timeout_seconds=primary_runtime.request_timeout_seconds,
                extra_headers=dict(primary_runtime.extra_headers),
                routing=primary_runtime.routing,
                auxiliary=False,
                role="embedding",
                source="embedding_inherited",
            )
        return None

    embedding_model = embedding_target.model or default_model

    # Placeholder provider names are mapped onto a concrete one: a custom
    # endpoint when an api_base was given, plain "openai" otherwise.
    placeholder_names = {None, "", "main", "auto"}
    provider = embedding_target.provider_name
    if provider in placeholder_names:
        provider = "custom" if embedding_target.api_base else "openai"

    key = embedding_target.api_key
    base = embedding_target.api_base
    headers = dict(embedding_target.extra_headers)

    # No explicit endpoint: fall back to the primary runtime's endpoint,
    # credentials and headers, but only if it is OpenAI-compatible.
    if not base and _supports_openai_embeddings(primary_runtime):
        key = key or primary_runtime.api_key
        base = primary_runtime.api_base
        headers = headers or dict(primary_runtime.extra_headers)

    timeout = (
        embedding_target.request_timeout_seconds
        or primary_runtime.request_timeout_seconds
    )
    embedding_runtime = resolve_provider_runtime(
        model=embedding_model,
        provider_name=provider,
        api_key=key,
        api_base=base,
        request_timeout_seconds=timeout,
        extra_headers=headers,
        routing=embedding_target.routing,
        auxiliary=False,
        role="embedding",
        source="embedding_config",
    )

    if _supports_openai_embeddings(embedding_runtime):
        return embedding_runtime
    raise ValueError("Embedding runtime currently requires an OpenAI-compatible provider")
|
||||
|
||||
|
||||
def _supports_openai_embeddings(runtime: ProviderRuntime) -> bool:
|
||||
"""当前 embedding retriever 只支持 OpenAI-compatible `/v1/embeddings`。"""
|
||||
|
||||
return runtime.api_mode == "chat_completions" and runtime.spec.provider_impl in {"litellm", "custom"}
|
||||
|
||||
|
||||
def _clone_runtime(
|
||||
runtime: ProviderRuntime,
|
||||
**changes: Any,
|
||||
) -> ProviderRuntime:
|
||||
"""基于现有 runtime 复制一个轻量变体。
|
||||
|
||||
用在 `provider=main` 这类场景,避免重复跑一次 registry 解析。
|
||||
"""
|
||||
|
||||
payload = {
|
||||
"extra_headers": dict(runtime.extra_headers),
|
||||
"routing": runtime.routing,
|
||||
"fallback_target": runtime.fallback_target,
|
||||
}
|
||||
payload.update(changes)
|
||||
return replace(runtime, **payload)
|
||||
|
||||
|
||||
def _resolve_model_name(spec: ProviderSpec, model: str, *, gateway_mode: bool) -> str:
|
||||
"""根据 registry 规则应用必要前缀。"""
|
||||
|
||||
resolved = model
|
||||
if gateway_mode:
|
||||
prefix = spec.litellm_prefix
|
||||
if spec.strip_model_prefix:
|
||||
resolved = resolved.split("/")[-1]
|
||||
if prefix and not resolved.startswith(f"{prefix}/"):
|
||||
resolved = f"{prefix}/{resolved}"
|
||||
return resolved
|
||||
|
||||
if spec.litellm_prefix:
|
||||
resolved = _canonicalize_explicit_prefix(resolved, spec.name, spec.litellm_prefix)
|
||||
if not any(resolved.startswith(item) for item in spec.skip_prefixes):
|
||||
resolved = f"{spec.litellm_prefix}/{resolved}"
|
||||
return resolved
|
||||
|
||||
|
||||
def _canonicalize_explicit_prefix(model: str, spec_name: str, canonical_prefix: str) -> str:
|
||||
if "/" not in model:
|
||||
return model
|
||||
prefix, remainder = model.split("/", 1)
|
||||
if prefix.lower().replace("-", "_") != spec_name:
|
||||
return model
|
||||
return f"{canonical_prefix}/{remainder}"
|
||||
Reference in New Issue
Block a user