移除了所有Hermes相关的命名引用,包括: - 从.gitignore中清理相关构建缓存文件 - 将README中的beaver-home路径配置更新 - 完善backend/README.md文档说明Beaver后端主线实现 - 移除Hermes风格的相关注释和兼容性代码 - 清理nanobot环境变量兼容性处理 - 删除技能迁移和服务迁移相关功能代码 - 更新测试用例中相关命名和函数名 BREAKING CHANGE: 移除了Hermes迁移相关API和CLI命令,不再支持nanobot环境变量兼容性
409 lines
14 KiB
Python
409 lines
14 KiB
Python
"""Provider runtime resolution for Beaver."""
|
||
|
||
from __future__ import annotations
|
||
|
||
from dataclasses import dataclass, field, replace
|
||
from typing import Any
|
||
|
||
from .registry import ProviderSpec, find_by_model, find_by_name, find_gateway
|
||
|
||
|
||
@dataclass(slots=True)
|
||
class ProviderRoutingConfig:
|
||
"""OpenRouter provider routing 配置。"""
|
||
|
||
sort: str | None = None
|
||
only: list[str] = field(default_factory=list)
|
||
ignore: list[str] = field(default_factory=list)
|
||
order: list[str] = field(default_factory=list)
|
||
require_parameters: bool = False
|
||
data_collection: str | None = None
|
||
|
||
|
||
@dataclass(slots=True)
|
||
class ProviderTarget:
|
||
"""一次 provider 选路请求的标准化配置。
|
||
|
||
这层不是具体 runtime,而是“调用方想要什么”:
|
||
- 用哪个 provider
|
||
- 跑哪个 model
|
||
- 是否指定自定义 base_url
|
||
- 是否带额外 headers / routing
|
||
|
||
后面 `resolve_provider_runtime()` 会把它真正解析成可实例化的 runtime。
|
||
"""
|
||
|
||
provider_name: str | None = None
|
||
model: str | None = None
|
||
api_key: str | None = None
|
||
api_base: str | None = None
|
||
extra_headers: dict[str, str] = field(default_factory=dict)
|
||
request_timeout_seconds: float | None = None
|
||
routing: ProviderRoutingConfig | None = None
|
||
|
||
|
||
@dataclass(slots=True)
|
||
class ProviderRuntime:
|
||
"""运行时真正使用的 provider 解析结果。"""
|
||
|
||
spec: ProviderSpec
|
||
model: str
|
||
provider_name: str
|
||
api_mode: str
|
||
api_key: str | None = None
|
||
api_base: str | None = None
|
||
default_model: str | None = None
|
||
request_timeout_seconds: float | None = None
|
||
extra_headers: dict[str, str] = field(default_factory=dict)
|
||
routing: ProviderRoutingConfig | None = None
|
||
fallback_target: ProviderTarget | None = None
|
||
auxiliary: bool = False
|
||
role: str = "main"
|
||
source: str = "runtime"
|
||
|
||
|
||
def resolve_provider_runtime(
    *,
    model: str,
    provider_name: str | None = None,
    api_key: str | None = None,
    api_base: str | None = None,
    request_timeout_seconds: float | None = None,
    extra_headers: dict[str, str] | None = None,
    routing: ProviderRoutingConfig | None = None,
    fallback_target: ProviderTarget | dict[str, Any] | None = None,
    auxiliary: bool = False,
    role: str = "main",
    source: str = "runtime",
) -> ProviderRuntime:
    """Resolve caller-supplied settings into a unified ``ProviderRuntime``.

    Spec lookup order:

    1. ``find_gateway(provider_name, api_key, api_base)``; a hit wins and
       switches model-name resolution into gateway mode.
    2. ``find_by_name(provider_name)`` when a provider name was given.
    3. ``find_by_model(model)`` otherwise.
    4. The ``"custom"`` spec, but only when ``api_base`` is set.

    Args:
        model: Model name as supplied by the caller; it may be rewritten
            with a provider prefix by ``_resolve_model_name``.
        provider_name: Explicit provider name, if any.
        api_key: Credential stored on the runtime unchanged.
        api_base: Endpoint override; falls back to the spec's
            ``default_api_base`` when empty.
        request_timeout_seconds: Stored on the runtime unchanged.
        extra_headers: Extra request headers; copied so the runtime never
            shares a dict with the caller or with another runtime.
        routing: Routing options stored on the runtime unchanged.
        fallback_target: Fallback config, normalized through
            ``normalize_provider_target``.
        auxiliary: Stored on the runtime unchanged.
        role: Stored on the runtime unchanged.
        source: Stored on the runtime unchanged.

    Returns:
        The resolved ``ProviderRuntime``.

    Raises:
        ValueError: If no provider spec can be resolved.
    """

    gateway = find_gateway(provider_name, api_key, api_base)
    if gateway is not None:
        spec = gateway
    elif provider_name:
        spec = find_by_name(provider_name)
    else:
        spec = find_by_model(model)

    # An explicit endpoint lets us fall back to the generic "custom" spec.
    if spec is None and api_base:
        spec = find_by_name("custom")
    # Checked after the fallback as well, so a registry without a "custom"
    # entry produces the documented ValueError rather than an AttributeError
    # further down.
    if spec is None:
        raise ValueError(f"Unable to resolve provider for model={model!r} provider_name={provider_name!r}")

    resolved_model = _resolve_model_name(spec, model, gateway_mode=(gateway is not None))
    resolved_api_base = api_base or spec.default_api_base or None

    return ProviderRuntime(
        spec=spec,
        model=resolved_model,
        provider_name=spec.name,
        api_mode=spec.api_mode,
        api_key=api_key,
        api_base=resolved_api_base,
        default_model=resolved_model,
        request_timeout_seconds=request_timeout_seconds,
        # Copy: callers routinely pass another runtime's header dict here.
        extra_headers=dict(extra_headers or {}),
        routing=routing,
        fallback_target=normalize_provider_target(fallback_target),
        auxiliary=auxiliary,
        role=role,
        source=source,
    )
|
||
|
||
|
||
def normalize_provider_target(target: ProviderTarget | dict[str, Any] | None) -> ProviderTarget | None:
    """Collapse a dict- or object-shaped provider config into one structure.

    ``None`` and existing ``ProviderTarget`` instances are returned as-is.
    For dicts, several common spellings are accepted so CLI / config /
    gateway callers can share this entry point; the first name of each pair
    wins unless its value is ``None``:

    - ``provider_name`` or ``provider``
    - ``api_base`` or ``base_url``
    - ``extra_headers`` or ``headers``
    - ``request_timeout_seconds`` or ``timeout``

    A dict under ``routing`` is expanded into ``ProviderRoutingConfig``.
    """

    if target is None or isinstance(target, ProviderTarget):
        return target

    def pick(canonical_key: str, alias_key: str) -> Any:
        # Only a missing / None canonical value falls through to the alias;
        # other falsy values (e.g. "") are kept.
        value = target.get(canonical_key)
        return target.get(alias_key) if value is None else value

    routing_value = target.get("routing")
    if isinstance(routing_value, dict):
        routing_value = ProviderRoutingConfig(**routing_value)

    header_source = pick("extra_headers", "headers")

    return ProviderTarget(
        provider_name=pick("provider_name", "provider"),
        model=target.get("model"),
        api_key=target.get("api_key"),
        api_base=pick("api_base", "base_url"),
        extra_headers=dict(header_source or {}),
        request_timeout_seconds=pick("request_timeout_seconds", "timeout"),
        routing=routing_value,
    )
|
||
|
||
|
||
def resolve_fallback_runtime(
    primary_runtime: ProviderRuntime,
    fallback_target: ProviderTarget | dict[str, Any] | None,
) -> ProviderRuntime | None:
    """Resolve a fallback config into its own independent runtime.

    A fallback means "switch to another provider:model once the main provider
    fails". Only the resolution happens here; deciding when to activate the
    fallback is left to the chain / factory layer above.

    Returns ``None`` when no fallback is configured or it names no model.
    """

    target = normalize_provider_target(fallback_target)
    if target is None or not target.model:
        return None

    # A missing provider, or the literal "main", means "same provider as primary".
    provider = target.provider_name
    if provider in {None, "", "main"}:
        provider = primary_runtime.provider_name

    key = target.api_key
    base = target.api_base
    headers = dict(target.extra_headers)

    # Credentials and headers are inherited from the primary chain only when
    # the fallback did not explicitly switch provider or base URL.
    if provider == primary_runtime.provider_name and not base:
        key = key or primary_runtime.api_key
        base = primary_runtime.api_base
        headers = headers or dict(primary_runtime.extra_headers)

    timeout = target.request_timeout_seconds or primary_runtime.request_timeout_seconds

    return resolve_provider_runtime(
        model=target.model,
        provider_name=provider,
        api_key=key,
        api_base=base,
        request_timeout_seconds=timeout,
        extra_headers=headers,
        routing=target.routing,
        auxiliary=False,
        role="fallback",
        source="fallback_config",
    )
|
||
|
||
|
||
def resolve_auxiliary_runtime(
    primary_runtime: ProviderRuntime,
    target: ProviderTarget | dict[str, Any] | None = None,
    *,
    task_name: str = "auxiliary",
) -> ProviderRuntime:
    """Resolve the runtime dedicated to an auxiliary task.

    Three kinds of input are supported:

    - ``None`` / ``provider=main``: reuse the main-chain provider
    - explicit ``provider`` + ``model``: use an independent provider
    - only ``model``: pick the provider automatically from the model name

    Args:
        primary_runtime: The main-chain runtime to reuse or inherit from.
        target: Optional auxiliary config (dict or ``ProviderTarget``).
        task_name: Stored as the ``role`` of the returned runtime.

    Returns:
        A ``ProviderRuntime`` with ``auxiliary=True``. Its ``extra_headers``
        dict is never shared with ``primary_runtime``.
    """

    normalized = normalize_provider_target(target)
    if normalized is None:
        return _clone_runtime(
            primary_runtime,
            auxiliary=True,
            role=task_name,
            source="main_runtime",
        )

    provider_name = normalized.provider_name
    if provider_name in {None, "", "main"} and not normalized.api_base and not normalized.model:
        # Same provider and model as the main chain; only light overrides apply.
        return _clone_runtime(
            primary_runtime,
            auxiliary=True,
            role=task_name,
            source="main_runtime",
            routing=normalized.routing or primary_runtime.routing,
            # Copy: this override replaces the defensive copy made inside
            # `_clone_runtime`, so passing the primary dict directly would
            # make both runtimes share (and mutate) the same headers.
            extra_headers=dict(normalized.extra_headers or primary_runtime.extra_headers),
            request_timeout_seconds=normalized.request_timeout_seconds or primary_runtime.request_timeout_seconds,
        )

    if provider_name == "main":
        # Main provider, but with a different model and/or base URL.
        return resolve_provider_runtime(
            model=normalized.model or primary_runtime.model,
            provider_name=primary_runtime.provider_name,
            api_key=normalized.api_key or primary_runtime.api_key,
            api_base=normalized.api_base or primary_runtime.api_base,
            request_timeout_seconds=normalized.request_timeout_seconds or primary_runtime.request_timeout_seconds,
            # Copy for the same reason as above: never alias the primary headers.
            extra_headers=dict(normalized.extra_headers or primary_runtime.extra_headers),
            routing=normalized.routing or primary_runtime.routing,
            auxiliary=True,
            role=task_name,
            source="main_runtime",
        )

    if provider_name in {"auto", None, ""} and not normalized.api_base and normalized.model is None:
        # `auto` with nothing else to go on degrades to the main runtime.
        return _clone_runtime(
            primary_runtime,
            auxiliary=True,
            role=task_name,
            source="auto->main",
        )

    resolved_model = normalized.model or primary_runtime.model
    resolved_provider = normalized.provider_name
    if resolved_provider in {"auto", "", None} and not normalized.api_base:
        # First-stage `auto` handling is deliberately conservative:
        # - with an explicit model, match the provider by model name
        # - if nothing matches, fall back to the main-chain provider
        spec = find_by_model(resolved_model)
        resolved_provider = spec.name if spec is not None else primary_runtime.provider_name

    api_key = normalized.api_key
    api_base = normalized.api_base
    extra_headers = dict(normalized.extra_headers)

    # Inherit main-chain credentials / headers only when staying on the same
    # provider without an explicit base URL.
    if resolved_provider == primary_runtime.provider_name and not api_base:
        api_key = api_key or primary_runtime.api_key
        api_base = api_base or primary_runtime.api_base
        if not extra_headers:
            extra_headers = dict(primary_runtime.extra_headers)

    return resolve_provider_runtime(
        model=resolved_model,
        provider_name=resolved_provider,
        api_key=api_key,
        api_base=api_base,
        request_timeout_seconds=normalized.request_timeout_seconds or primary_runtime.request_timeout_seconds,
        extra_headers=extra_headers,
        routing=normalized.routing or primary_runtime.routing,
        auxiliary=True,
        role=task_name,
        source="auxiliary_config",
    )
|
||
|
||
|
||
def resolve_embedding_runtime(
    primary_runtime: ProviderRuntime,
    target: ProviderTarget | dict[str, Any] | None = None,
    *,
    default_model: str = "text-embedding-v4",
) -> ProviderRuntime | None:
    """Resolve the runtime dedicated to embeddings.

    The point is to keep "which model / api_base / api_key do embeddings use"
    inside the provider layer, so retrieval code above does not reach into
    the main / auxiliary provider credentials on its own.

    Returns ``None`` when no embedding config is given and the primary
    runtime is not OpenAI-compatible.

    Raises:
        ValueError: If an explicit config resolves to a runtime that is not
            OpenAI-compatible.
    """

    normalized = normalize_provider_target(target)

    if normalized is None:
        # Without explicit embedding config, inherit api_base / api_key only
        # when the main chain itself is OpenAI-compatible; otherwise do not guess.
        if not _supports_openai_embeddings(primary_runtime):
            return None
        return resolve_provider_runtime(
            model=default_model,
            provider_name="openai",
            api_key=primary_runtime.api_key,
            api_base=primary_runtime.api_base,
            request_timeout_seconds=primary_runtime.request_timeout_seconds,
            extra_headers=dict(primary_runtime.extra_headers),
            routing=primary_runtime.routing,
            auxiliary=False,
            role="embedding",
            source="embedding_inherited",
        )

    model_name = normalized.model or default_model

    # Placeholder provider names map to "custom" when a base URL is given,
    # and to "openai" otherwise.
    provider = normalized.provider_name
    if provider in {None, "", "main", "auto"}:
        provider = "custom" if normalized.api_base else "openai"

    key = normalized.api_key
    base = normalized.api_base
    headers = dict(normalized.extra_headers)

    # NOTE(review): unlike the fallback / auxiliary resolvers, this
    # inheritance is not gated on the provider matching the primary one --
    # confirm that is intended.
    if not base and _supports_openai_embeddings(primary_runtime):
        key = key or primary_runtime.api_key
        base = primary_runtime.api_base
        headers = headers or dict(primary_runtime.extra_headers)

    timeout = normalized.request_timeout_seconds or primary_runtime.request_timeout_seconds

    runtime = resolve_provider_runtime(
        model=model_name,
        provider_name=provider,
        api_key=key,
        api_base=base,
        request_timeout_seconds=timeout,
        extra_headers=headers,
        routing=normalized.routing,
        auxiliary=False,
        role="embedding",
        source="embedding_config",
    )
    if _supports_openai_embeddings(runtime):
        return runtime
    raise ValueError("Embedding runtime currently requires an OpenAI-compatible provider")
|
||
|
||
|
||
def _supports_openai_embeddings(runtime: ProviderRuntime) -> bool:
    """Tell whether ``runtime`` can serve OpenAI-compatible ``/v1/embeddings``.

    The embedding retriever currently supports only that API shape, which
    here means ``chat_completions`` mode on a ``litellm`` or ``custom``
    provider implementation.
    """

    if runtime.api_mode != "chat_completions":
        return False
    return runtime.spec.provider_impl in {"litellm", "custom"}
|
||
|
||
|
||
def _clone_runtime(
    runtime: ProviderRuntime,
    **changes: Any,
) -> ProviderRuntime:
    """Copy an existing runtime into a lightweight variant.

    Used for cases such as ``provider=main`` so the registry lookup does not
    have to run a second time. ``extra_headers`` is shallow-copied by
    default; any keyword in ``changes`` overrides the copied value.
    """

    overrides = {
        "extra_headers": dict(runtime.extra_headers),
        "routing": runtime.routing,
        "fallback_target": runtime.fallback_target,
        **changes,
    }
    return replace(runtime, **overrides)
|
||
|
||
|
||
def _resolve_model_name(spec: ProviderSpec, model: str, *, gateway_mode: bool) -> str:
    """Apply the prefix rules from the registry spec to ``model``.

    In gateway mode the model is optionally reduced to its last ``/``
    segment (``spec.strip_model_prefix``) and then prefixed with
    ``spec.litellm_prefix`` unless it already starts with that prefix.

    Outside gateway mode a spec without a prefix leaves the model untouched.
    Otherwise an explicit provider-name prefix is canonicalized first, and
    the prefix is added unless the name already starts with one of
    ``spec.skip_prefixes``.
    """

    prefix = spec.litellm_prefix

    if gateway_mode:
        name = model.rsplit("/", 1)[-1] if spec.strip_model_prefix else model
        if prefix and not name.startswith(f"{prefix}/"):
            return f"{prefix}/{name}"
        return name

    if not prefix:
        return model

    name = _canonicalize_explicit_prefix(model, spec.name, prefix)
    for skipped in spec.skip_prefixes:
        if name.startswith(skipped):
            # Already carries an accepted prefix; leave it alone.
            return name
    return f"{prefix}/{name}"
|
||
|
||
|
||
def _canonicalize_explicit_prefix(model: str, spec_name: str, canonical_prefix: str) -> str:
    """Rewrite a leading provider-name prefix to the canonical prefix.

    The text before the first ``/`` is compared with ``spec_name`` after
    lower-casing and replacing ``-`` with ``_``. On a match that segment is
    replaced by ``canonical_prefix``; otherwise (or when ``model`` has no
    ``/``) the model is returned unchanged.
    """

    head, slash, tail = model.partition("/")
    if not slash:
        return model
    if head.lower().replace("-", "_") == spec_name:
        return f"{canonical_prefix}/{tail}"
    return model
|