Files
beaver_project/app-instance/backend/beaver/engine/providers/runtime.py
steven_li 3b0af173cc refactor(beaver): 移除Hermes相关引用和迁移代码,完善Beaver后端主线实现
移除了所有Hermes相关的命名引用,包括:
- 从.gitignore中清理相关构建缓存文件
- 将README中的beaver-home路径配置更新
- 完善backend/README.md文档说明Beaver后端主线实现
- 移除Hermes风格的相关注释和兼容性代码
- 清理nanobot环境变量兼容性处理
- 删除技能迁移和服务迁移相关功能代码
- 更新测试用例中相关命名和函数名

BREAKING CHANGE: 移除了Hermes迁移相关API和CLI命令,不再支持nanobot环境变量兼容性
2026-05-14 17:20:32 +08:00

409 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Provider runtime resolution for Beaver."""
from __future__ import annotations
from dataclasses import dataclass, field, replace
from typing import Any
from .registry import ProviderSpec, find_by_model, find_by_name, find_gateway
@dataclass(slots=True)
class ProviderRoutingConfig:
"""OpenRouter provider routing 配置。"""
sort: str | None = None
only: list[str] = field(default_factory=list)
ignore: list[str] = field(default_factory=list)
order: list[str] = field(default_factory=list)
require_parameters: bool = False
data_collection: str | None = None
@dataclass(slots=True)
class ProviderTarget:
"""一次 provider 选路请求的标准化配置。
这层不是具体 runtime而是“调用方想要什么”
- 用哪个 provider
- 跑哪个 model
- 是否指定自定义 base_url
- 是否带额外 headers / routing
后面 `resolve_provider_runtime()` 会把它真正解析成可实例化的 runtime。
"""
provider_name: str | None = None
model: str | None = None
api_key: str | None = None
api_base: str | None = None
extra_headers: dict[str, str] = field(default_factory=dict)
request_timeout_seconds: float | None = None
routing: ProviderRoutingConfig | None = None
@dataclass(slots=True)
class ProviderRuntime:
"""运行时真正使用的 provider 解析结果。"""
spec: ProviderSpec
model: str
provider_name: str
api_mode: str
api_key: str | None = None
api_base: str | None = None
default_model: str | None = None
request_timeout_seconds: float | None = None
extra_headers: dict[str, str] = field(default_factory=dict)
routing: ProviderRoutingConfig | None = None
fallback_target: ProviderTarget | None = None
auxiliary: bool = False
role: str = "main"
source: str = "runtime"
def resolve_provider_runtime(
    *,
    model: str,
    provider_name: str | None = None,
    api_key: str | None = None,
    api_base: str | None = None,
    request_timeout_seconds: float | None = None,
    extra_headers: dict[str, str] | None = None,
    routing: ProviderRoutingConfig | None = None,
    fallback_target: ProviderTarget | dict[str, Any] | None = None,
    auxiliary: bool = False,
    role: str = "main",
    source: str = "runtime",
) -> ProviderRuntime:
    """Resolve caller-supplied settings into a unified ``ProviderRuntime``.

    The provider spec is looked up in this order:

    1. a gateway matched by ``find_gateway(provider_name, api_key, api_base)``;
    2. otherwise ``find_by_name(provider_name)`` when a provider name is given;
    3. otherwise ``find_by_model(model)``.

    If none of these yields a spec and ``api_base`` is set, the ``"custom"``
    spec is used instead.

    Returns:
        A ``ProviderRuntime`` whose ``model`` (and ``default_model``) carry the
        prefix rules of the resolved spec, and whose ``api_base`` falls back to
        the spec's ``default_api_base`` when the caller gave none.

    Raises:
        ValueError: no provider spec could be resolved.
    """
    gateway = find_gateway(provider_name, api_key, api_base)
    if gateway is not None:
        spec = gateway
    elif provider_name:
        spec = find_by_name(provider_name)
    else:
        spec = find_by_model(model)

    if spec is None and api_base:
        # An explicit base URL without a known provider is treated as a
        # generic custom endpoint.
        spec = find_by_name("custom")
    if spec is None:
        # Also covers a registry without a "custom" entry, which would
        # otherwise surface later as an AttributeError on `spec.name`.
        raise ValueError(f"Unable to resolve provider for model={model!r} provider_name={provider_name!r}")

    resolved_model = _resolve_model_name(spec, model, gateway_mode=(gateway is not None))
    resolved_api_base = api_base or spec.default_api_base or None
    return ProviderRuntime(
        spec=spec,
        model=resolved_model,
        provider_name=spec.name,
        api_mode=spec.api_mode,
        api_key=api_key,
        api_base=resolved_api_base,
        default_model=resolved_model,
        request_timeout_seconds=request_timeout_seconds,
        # Copy so the runtime never shares a mutable dict with the caller
        # (or with another runtime the headers were taken from).
        extra_headers=dict(extra_headers or {}),
        routing=routing,
        fallback_target=normalize_provider_target(fallback_target),
        auxiliary=auxiliary,
        role=role,
        source=source,
    )
def normalize_provider_target(target: ProviderTarget | dict[str, Any] | None) -> ProviderTarget | None:
    """Collapse a dict- or object-style provider config into a ``ProviderTarget``.

    ``None`` and existing ``ProviderTarget`` instances are returned unchanged.
    For dict input, several common spellings are accepted so that CLI /
    config / gateway callers can be wired in later; the first key of each
    pair wins unless its value is ``None``:

    - ``provider_name`` or ``provider``
    - ``api_base`` or ``base_url``
    - ``extra_headers`` or ``headers``
    - ``request_timeout_seconds`` or ``timeout``

    A ``routing`` value given as a dict is expanded into
    ``ProviderRoutingConfig``; headers are copied into a new dict.
    """
    if target is None or isinstance(target, ProviderTarget):
        return target

    def pick(primary_key: str, alias_key: str) -> Any:
        # Only a missing/None primary value falls through to the alias;
        # other falsy values (e.g. "") are kept as given.
        value = target.get(primary_key)
        return target.get(alias_key) if value is None else value

    routing_cfg = target.get("routing")
    if isinstance(routing_cfg, dict):
        routing_cfg = ProviderRoutingConfig(**routing_cfg)

    return ProviderTarget(
        provider_name=pick("provider_name", "provider"),
        model=target.get("model"),
        api_key=target.get("api_key"),
        api_base=pick("api_base", "base_url"),
        extra_headers=dict(pick("extra_headers", "headers") or {}),
        request_timeout_seconds=pick("request_timeout_seconds", "timeout"),
        routing=routing_cfg,
    )
def resolve_fallback_runtime(
    primary_runtime: ProviderRuntime,
    fallback_target: ProviderTarget | dict[str, Any] | None,
) -> ProviderRuntime | None:
    """Resolve a fallback config into its own independent runtime.

    A fallback means "switch to another provider:model after the primary
    provider fails". Only the resolution happens here; deciding when to
    activate it is left to the chain/factory layer above.

    Returns:
        ``None`` when there is no fallback target or it names no model,
        otherwise a runtime with ``role="fallback"`` and
        ``source="fallback_config"``.
    """
    normalized = normalize_provider_target(fallback_target)
    if normalized is None or not normalized.model:
        return None

    # A missing provider, "" or "main" all mean "same provider as the primary".
    provider = normalized.provider_name
    if provider in {None, "", "main"}:
        provider = primary_runtime.provider_name

    key = normalized.api_key
    base = normalized.api_base
    headers = dict(normalized.extra_headers)

    # Credentials and headers are inherited from the primary chain only when
    # the fallback does not explicitly switch provider or base URL.
    stays_on_primary = provider == primary_runtime.provider_name and not base
    if stays_on_primary:
        key = key or primary_runtime.api_key
        base = base or primary_runtime.api_base
        headers = headers or dict(primary_runtime.extra_headers)

    timeout = normalized.request_timeout_seconds or primary_runtime.request_timeout_seconds
    return resolve_provider_runtime(
        model=normalized.model,
        provider_name=provider,
        api_key=key,
        api_base=base,
        request_timeout_seconds=timeout,
        extra_headers=headers,
        routing=normalized.routing,
        auxiliary=False,
        role="fallback",
        source="fallback_config",
    )
def resolve_auxiliary_runtime(
    primary_runtime: ProviderRuntime,
    target: ProviderTarget | dict[str, Any] | None = None,
    *,
    task_name: str = "auxiliary",
) -> ProviderRuntime:
    """Resolve the runtime used for an auxiliary task.

    Three kinds of input are supported:

    - ``None`` / ``provider=main``: reuse the primary provider;
    - explicit ``provider`` + ``model``: use an independent provider;
    - ``model`` only: match the provider from the model name.

    Args:
        primary_runtime: The already-resolved main runtime.
        target: Optional auxiliary config (object or dict form).
        task_name: Stored as ``role`` on the returned runtime.

    Returns:
        A runtime with ``auxiliary=True``. Its ``extra_headers`` is always a
        dict of its own, never the same object as the primary runtime's.
    """
    normalized = normalize_provider_target(target)
    if normalized is None:
        return _clone_runtime(
            primary_runtime,
            auxiliary=True,
            role=task_name,
            source="main_runtime",
        )

    provider_name = normalized.provider_name
    timeout = normalized.request_timeout_seconds or primary_runtime.request_timeout_seconds

    # No provider/base/model switch: clone the primary runtime and apply only
    # the lightweight per-task overrides.
    if provider_name in {None, "", "main"} and not normalized.api_base and not normalized.model:
        return _clone_runtime(
            primary_runtime,
            auxiliary=True,
            role=task_name,
            source="main_runtime",
            routing=normalized.routing or primary_runtime.routing,
            # Copy: passing the primary's dict here would override the copy
            # made inside _clone_runtime and leave both runtimes sharing it.
            extra_headers=dict(normalized.extra_headers or primary_runtime.extra_headers),
            request_timeout_seconds=timeout,
        )

    # `main` with its own model and/or base URL: stay on the primary provider
    # but run a full resolution so model prefixes are applied again.
    if provider_name == "main":
        return resolve_provider_runtime(
            model=normalized.model or primary_runtime.model,
            provider_name=primary_runtime.provider_name,
            api_key=normalized.api_key or primary_runtime.api_key,
            api_base=normalized.api_base or primary_runtime.api_base,
            request_timeout_seconds=timeout,
            extra_headers=dict(normalized.extra_headers or primary_runtime.extra_headers),
            routing=normalized.routing or primary_runtime.routing,
            auxiliary=True,
            role=task_name,
            source="main_runtime",
        )

    # In practice only `auto` reaches this branch: a missing/empty provider
    # without model and base URL has already returned above.
    # NOTE(review): unlike the clone branch above, per-target routing /
    # headers / timeout overrides are dropped here -- confirm this is intended.
    if provider_name in {"auto", None, ""} and not normalized.api_base and normalized.model is None:
        return _clone_runtime(
            primary_runtime,
            auxiliary=True,
            role=task_name,
            source="auto->main",
        )

    resolved_model = normalized.model or primary_runtime.model
    resolved_provider = normalized.provider_name
    if resolved_provider in {"auto", "", None} and not normalized.api_base:
        # First-stage `auto` handling is deliberately conservative:
        # - with an explicit model, match the provider from the model name;
        # - if nothing matches, fall back to the primary provider.
        spec = find_by_model(resolved_model)
        resolved_provider = spec.name if spec is not None else primary_runtime.provider_name

    api_key = normalized.api_key
    api_base = normalized.api_base
    extra_headers = dict(normalized.extra_headers)
    # Inherit credentials and headers only when staying on the primary
    # provider without an explicit base URL.
    if resolved_provider == primary_runtime.provider_name and not api_base:
        api_key = api_key or primary_runtime.api_key
        api_base = api_base or primary_runtime.api_base
        if not extra_headers:
            extra_headers = dict(primary_runtime.extra_headers)

    return resolve_provider_runtime(
        model=resolved_model,
        provider_name=resolved_provider,
        api_key=api_key,
        api_base=api_base,
        request_timeout_seconds=timeout,
        extra_headers=extra_headers,
        routing=normalized.routing or primary_runtime.routing,
        auxiliary=True,
        role=task_name,
        source="auxiliary_config",
    )
def resolve_embedding_runtime(
    primary_runtime: ProviderRuntime,
    target: ProviderTarget | dict[str, Any] | None = None,
    *,
    default_model: str = "text-embedding-v4",
) -> ProviderRuntime | None:
    """Resolve the runtime dedicated to embeddings.

    The point is to keep "which model / api_base / api_key do embeddings use"
    inside the provider layer, so retrieval code above does not reach into
    the main or auxiliary provider's credentials on its own.

    Returns:
        ``None`` when no embedding config is given and the primary runtime is
        not OpenAI-compatible; otherwise a runtime with ``role="embedding"``.

    Raises:
        ValueError: an explicit embedding config resolves to a runtime that
            is not OpenAI-compatible.
    """
    embedding_target = normalize_provider_target(target)

    if embedding_target is None:
        # Without an explicit embedding config, api_base/api_key are inherited
        # only if the primary chain itself is OpenAI-compatible. Otherwise no
        # guess is made.
        if _supports_openai_embeddings(primary_runtime):
            return resolve_provider_runtime(
                model=default_model,
                provider_name="openai",
                api_key=primary_runtime.api_key,
                api_base=primary_runtime.api_base,
                request_timeout_seconds=primary_runtime.request_timeout_seconds,
                extra_headers=dict(primary_runtime.extra_headers),
                routing=primary_runtime.routing,
                auxiliary=False,
                role="embedding",
                source="embedding_inherited",
            )
        return None

    # Placeholder provider names collapse to "custom" (explicit base URL) or
    # "openai" (no base URL).
    provider = embedding_target.provider_name
    if provider in {None, "", "main", "auto"}:
        provider = "custom" if embedding_target.api_base else "openai"

    key = embedding_target.api_key
    base = embedding_target.api_base
    headers = dict(embedding_target.extra_headers)
    # Fill the gaps from the primary runtime only when no base URL was given
    # and the primary runtime is OpenAI-compatible.
    if not base and _supports_openai_embeddings(primary_runtime):
        key = key or primary_runtime.api_key
        base = base or primary_runtime.api_base
        headers = headers or dict(primary_runtime.extra_headers)

    embedding_runtime = resolve_provider_runtime(
        model=embedding_target.model or default_model,
        provider_name=provider,
        api_key=key,
        api_base=base,
        request_timeout_seconds=(
            embedding_target.request_timeout_seconds or primary_runtime.request_timeout_seconds
        ),
        extra_headers=headers,
        routing=embedding_target.routing,
        auxiliary=False,
        role="embedding",
        source="embedding_config",
    )
    if _supports_openai_embeddings(embedding_runtime):
        return embedding_runtime
    raise ValueError("Embedding runtime currently requires an OpenAI-compatible provider")
def _supports_openai_embeddings(runtime: ProviderRuntime) -> bool:
    """Tell whether *runtime* can serve OpenAI-compatible ``/v1/embeddings``.

    That is the only kind of endpoint the embedding retriever supports for
    now: the runtime must use the ``chat_completions`` API mode and its spec
    must be implemented by ``litellm`` or ``custom``.
    """
    if runtime.api_mode != "chat_completions":
        return False
    return runtime.spec.provider_impl in {"litellm", "custom"}
def _clone_runtime(
    runtime: ProviderRuntime,
    **changes: Any,
) -> ProviderRuntime:
    """Return a lightweight variant of an existing runtime.

    Used for cases such as ``provider=main`` so the registry resolution does
    not have to run a second time. The clone gets its own copy of
    ``extra_headers``; ``routing`` and ``fallback_target`` are carried over
    as-is. Anything passed in *changes* wins over these defaults.
    """
    overrides: dict[str, Any] = {
        "extra_headers": dict(runtime.extra_headers),
        "routing": runtime.routing,
        "fallback_target": runtime.fallback_target,
        **changes,
    }
    return replace(runtime, **overrides)
def _resolve_model_name(spec: ProviderSpec, model: str, *, gateway_mode: bool) -> str:
    """Apply the registry's prefix rules to a model name.

    Gateway mode: optionally keep only the part after the last ``/``, then
    make sure the name starts with ``<litellm_prefix>/``.

    Direct mode: if the spec has a ``litellm_prefix``, canonicalize an
    explicit provider prefix and prepend ``<litellm_prefix>/`` unless the
    name already starts with one of ``spec.skip_prefixes``. Without a
    ``litellm_prefix`` the name is returned untouched.
    """
    prefix = spec.litellm_prefix

    if gateway_mode:
        name = model.rsplit("/", 1)[-1] if spec.strip_model_prefix else model
        if prefix and not name.startswith(f"{prefix}/"):
            return f"{prefix}/{name}"
        return name

    if not prefix:
        return model

    name = _canonicalize_explicit_prefix(model, spec.name, prefix)
    # str.startswith accepts a tuple of candidates; an empty tuple never matches.
    if name.startswith(tuple(spec.skip_prefixes)):
        return name
    return f"{prefix}/{name}"
def _canonicalize_explicit_prefix(model: str, spec_name: str, canonical_prefix: str) -> str:
    """Rewrite an explicit provider prefix to its canonical spelling.

    If *model* has the form ``<prefix>/<rest>`` and ``<prefix>``, lower-cased
    with ``-`` replaced by ``_``, equals *spec_name*, the result is
    ``<canonical_prefix>/<rest>``. Any other input is returned unchanged.
    """
    head, slash, tail = model.partition("/")
    if not slash:
        # No "/" at all: there is no explicit prefix to rewrite.
        return model
    if head.lower().replace("-", "_") == spec_name:
        return f"{canonical_prefix}/{tail}"
    return model