Files
beaver_project/app-instance/backend/beaver/engine/providers/runtime.py

409 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Hermes 风格的 provider runtime resolution。"""
from __future__ import annotations
from dataclasses import dataclass, field, replace
from typing import Any
from .registry import ProviderSpec, find_by_model, find_by_name, find_gateway
@dataclass(slots=True)
class ProviderRoutingConfig:
"""OpenRouter provider routing 配置。"""
sort: str | None = None
only: list[str] = field(default_factory=list)
ignore: list[str] = field(default_factory=list)
order: list[str] = field(default_factory=list)
require_parameters: bool = False
data_collection: str | None = None
@dataclass(slots=True)
class ProviderTarget:
"""一次 provider 选路请求的标准化配置。
这层不是具体 runtime而是“调用方想要什么”
- 用哪个 provider
- 跑哪个 model
- 是否指定自定义 base_url
- 是否带额外 headers / routing
后面 `resolve_provider_runtime()` 会把它真正解析成可实例化的 runtime。
"""
provider_name: str | None = None
model: str | None = None
api_key: str | None = None
api_base: str | None = None
extra_headers: dict[str, str] = field(default_factory=dict)
request_timeout_seconds: float | None = None
routing: ProviderRoutingConfig | None = None
@dataclass(slots=True)
class ProviderRuntime:
"""运行时真正使用的 provider 解析结果。"""
spec: ProviderSpec
model: str
provider_name: str
api_mode: str
api_key: str | None = None
api_base: str | None = None
default_model: str | None = None
request_timeout_seconds: float | None = None
extra_headers: dict[str, str] = field(default_factory=dict)
routing: ProviderRoutingConfig | None = None
fallback_target: ProviderTarget | None = None
auxiliary: bool = False
role: str = "main"
source: str = "runtime"
def resolve_provider_runtime(
    *,
    model: str,
    provider_name: str | None = None,
    api_key: str | None = None,
    api_base: str | None = None,
    request_timeout_seconds: float | None = None,
    extra_headers: dict[str, str] | None = None,
    routing: ProviderRoutingConfig | None = None,
    fallback_target: ProviderTarget | dict[str, Any] | None = None,
    auxiliary: bool = False,
    role: str = "main",
    source: str = "runtime",
) -> ProviderRuntime:
    """Resolve caller-supplied settings into a unified ``ProviderRuntime``.

    The provider spec is chosen in this order:

    1. a gateway returned by ``find_gateway(provider_name, api_key, api_base)``;
    2. otherwise, when ``provider_name`` is given, the spec registered under
       that name (no model-based lookup is attempted in this case);
    3. otherwise the spec matched from ``model``;
    4. if nothing matched but ``api_base`` is set, the ``"custom"`` spec.

    Args:
        model: Model name as written by the caller; registry prefix rules are
            applied to it via ``_resolve_model_name``.
        provider_name: Explicit provider name, if any.
        api_key: Credential stored on the runtime unchanged.
        api_base: Explicit base URL; when falsy the spec's
            ``default_api_base`` is used (or ``None`` if that is falsy too).
        request_timeout_seconds: Stored on the runtime unchanged.
        extra_headers: Extra request headers; copied, never stored by
            reference.
        routing: Provider-routing options, stored unchanged.
        fallback_target: Fallback description; normalized with
            ``normalize_provider_target`` but not resolved here.
        auxiliary: Stored on the runtime unchanged.
        role: Stored on the runtime unchanged.
        source: Stored on the runtime unchanged.

    Returns:
        The resolved ``ProviderRuntime``.

    Raises:
        ValueError: If no provider spec can be resolved.
    """
    gateway = find_gateway(provider_name, api_key, api_base)
    if gateway is not None:
        spec = gateway
    elif provider_name:
        spec = find_by_name(provider_name)
    else:
        spec = find_by_model(model)

    # A caller-supplied base URL with no recognizable provider is treated as
    # a custom endpoint.
    if spec is None and api_base:
        spec = find_by_name("custom")
    # Also covers a registry without a "custom" entry, which would otherwise
    # surface later as an AttributeError on None.
    if spec is None:
        raise ValueError(f"Unable to resolve provider for model={model!r} provider_name={provider_name!r}")

    resolved_model = _resolve_model_name(spec, model, gateway_mode=(gateway is not None))
    resolved_api_base = api_base or spec.default_api_base or None

    return ProviderRuntime(
        spec=spec,
        model=resolved_model,
        provider_name=spec.name,
        api_mode=spec.api_mode,
        api_key=api_key,
        api_base=resolved_api_base,
        default_model=resolved_model,
        request_timeout_seconds=request_timeout_seconds,
        # Copy so later mutation of the runtime's headers cannot leak into the
        # caller's dict (or into another runtime built from the same dict).
        extra_headers=dict(extra_headers or {}),
        routing=routing,
        fallback_target=normalize_provider_target(fallback_target),
        auxiliary=auxiliary,
        role=role,
        source=source,
    )
def normalize_provider_target(target: ProviderTarget | dict[str, Any] | None) -> ProviderTarget | None:
    """Collapse a dict- or object-style provider config into a ``ProviderTarget``.

    ``None`` is returned as ``None`` and an existing ``ProviderTarget`` is
    returned as-is (same object, no copy). For mappings, several common
    spellings are accepted so CLI / config / gateway callers can be wired in
    later; the canonical key wins whenever its value is not ``None``:

    - ``provider_name`` or ``provider``
    - ``api_base`` or ``base_url``
    - ``extra_headers`` or ``headers``
    - ``request_timeout_seconds`` or ``timeout``

    A ``routing`` value given as a dict is expanded into a
    ``ProviderRoutingConfig`` via keyword arguments.
    """
    if target is None:
        return None
    if isinstance(target, ProviderTarget):
        return target

    def _lookup(canonical: str, alias: str) -> Any:
        # Fall back to the alias only when the canonical key is absent/None.
        value = target.get(canonical)
        return target.get(alias) if value is None else value

    name = _lookup("provider_name", "provider")
    base_url = _lookup("api_base", "base_url")
    headers = _lookup("extra_headers", "headers")
    timeout = _lookup("request_timeout_seconds", "timeout")

    routing_value = target.get("routing")
    if isinstance(routing_value, dict):
        routing_value = ProviderRoutingConfig(**routing_value)

    return ProviderTarget(
        provider_name=name,
        model=target.get("model"),
        api_key=target.get("api_key"),
        api_base=base_url,
        # Always hand the target its own dict, even when headers are missing.
        extra_headers=dict(headers or {}),
        request_timeout_seconds=timeout,
        routing=routing_value,
    )
def resolve_fallback_runtime(
    primary_runtime: ProviderRuntime,
    fallback_target: ProviderTarget | dict[str, Any] | None,
) -> ProviderRuntime | None:
    """Resolve a fallback configuration into its own standalone runtime.

    A Hermes fallback means "switch to another provider:model once the main
    provider fails". Only the resolution happens here; deciding when to
    activate the fallback is left to the chain/factory layer above.

    Returns ``None`` when there is no fallback target or it names no model.
    A provider of ``None``, ``""`` or ``"main"`` means "same provider as the
    primary runtime".
    """
    target = normalize_provider_target(fallback_target)
    if target is None or not target.model:
        return None

    provider = target.provider_name
    if provider in {None, "", "main"}:
        provider = primary_runtime.provider_name

    key = target.api_key
    base = target.api_base
    headers = dict(target.extra_headers)

    # Credentials and headers are inherited from the main chain only when the
    # fallback does not explicitly switch provider or base URL.
    stays_on_primary = provider == primary_runtime.provider_name and not base
    if stays_on_primary:
        key = key or primary_runtime.api_key
        base = primary_runtime.api_base
        if not headers:
            headers = dict(primary_runtime.extra_headers)

    timeout = target.request_timeout_seconds or primary_runtime.request_timeout_seconds
    return resolve_provider_runtime(
        model=target.model,
        provider_name=provider,
        api_key=key,
        api_base=base,
        request_timeout_seconds=timeout,
        extra_headers=headers,
        routing=target.routing,
        auxiliary=False,
        role="fallback",
        source="fallback_config",
    )
def resolve_auxiliary_runtime(
    primary_runtime: ProviderRuntime,
    target: ProviderTarget | dict[str, Any] | None = None,
    *,
    task_name: str = "auxiliary",
) -> ProviderRuntime:
    """Resolve the runtime used for an auxiliary task.

    Three kinds of input are supported:

    - ``None`` / ``provider=main``: reuse the main chain's provider
    - explicit ``provider + model``: use an independent provider
    - ``model`` only: match the provider automatically from the model name

    Args:
        primary_runtime: The main runtime to clone or inherit settings from.
        target: Optional auxiliary configuration (``ProviderTarget`` or dict).
        task_name: Stored as the ``role`` of the returned runtime.

    Returns:
        A ``ProviderRuntime`` with ``auxiliary=True``. Its ``extra_headers``
        is always a fresh dict, never shared with ``primary_runtime`` or with
        ``target``.

    Raises:
        ValueError: Propagated from ``resolve_provider_runtime`` when no
            provider can be resolved.
    """
    normalized = normalize_provider_target(target)
    if normalized is None:
        return _clone_runtime(
            primary_runtime,
            auxiliary=True,
            role=task_name,
            source="main_runtime",
        )

    provider_name = normalized.provider_name
    timeout = normalized.request_timeout_seconds or primary_runtime.request_timeout_seconds
    routing = normalized.routing or primary_runtime.routing

    # No provider switch, no base URL and no model: reuse the main runtime and
    # apply only the lightweight overrides (routing / headers / timeout).
    if provider_name in {None, "", "main"} and not normalized.api_base and not normalized.model:
        return _clone_runtime(
            primary_runtime,
            auxiliary=True,
            role=task_name,
            source="main_runtime",
            routing=routing,
            # Copy: passing the dict through would bypass the copy made by
            # `_clone_runtime` and alias the primary's (or the target's) headers.
            extra_headers=dict(normalized.extra_headers or primary_runtime.extra_headers),
            request_timeout_seconds=timeout,
        )

    # `main` with a different model and/or base URL: same provider, re-resolved.
    if provider_name == "main":
        return resolve_provider_runtime(
            model=normalized.model or primary_runtime.model,
            provider_name=primary_runtime.provider_name,
            api_key=normalized.api_key or primary_runtime.api_key,
            api_base=normalized.api_base or primary_runtime.api_base,
            request_timeout_seconds=timeout,
            extra_headers=dict(normalized.extra_headers or primary_runtime.extra_headers),
            routing=routing,
            auxiliary=True,
            role=task_name,
            source="main_runtime",
        )

    # In practice only `auto` reaches this branch: None / "" with no model and
    # no base URL were already handled above.
    if provider_name in {"auto", None, ""} and not normalized.api_base and normalized.model is None:
        return _clone_runtime(
            primary_runtime,
            auxiliary=True,
            role=task_name,
            source="auto->main",
        )

    resolved_model = normalized.model or primary_runtime.model
    resolved_provider = provider_name
    if resolved_provider in {"auto", "", None} and not normalized.api_base:
        # The first-phase implementation of `auto` stays conservative:
        # - with an explicit model, match the provider from the model name
        # - if nothing matches, fall back to the main chain's provider
        spec = find_by_model(resolved_model)
        resolved_provider = spec.name if spec is not None else primary_runtime.provider_name

    api_key = normalized.api_key
    api_base = normalized.api_base
    extra_headers = dict(normalized.extra_headers)
    # Inherit credentials/headers only when staying on the primary provider
    # without an explicit base URL.
    if resolved_provider == primary_runtime.provider_name and not api_base:
        api_key = api_key or primary_runtime.api_key
        api_base = api_base or primary_runtime.api_base
        if not extra_headers:
            extra_headers = dict(primary_runtime.extra_headers)

    return resolve_provider_runtime(
        model=resolved_model,
        provider_name=resolved_provider,
        api_key=api_key,
        api_base=api_base,
        request_timeout_seconds=timeout,
        extra_headers=extra_headers,
        routing=routing,
        auxiliary=True,
        role=task_name,
        source="auxiliary_config",
    )
def resolve_embedding_runtime(
    primary_runtime: ProviderRuntime,
    target: ProviderTarget | dict[str, Any] | None = None,
    *,
    default_model: str = "text-embedding-v4",
) -> ProviderRuntime | None:
    """Resolve the runtime dedicated to embeddings.

    The point is to keep "which model / api_base / api_key do embeddings
    use" inside the provider layer, so retrieval code higher up does not
    reach directly into the main/auxiliary provider's credentials.

    Returns ``None`` when no embedding target is given and the primary
    runtime is not OpenAI-compatible. Raises ``ValueError`` when an explicit
    target resolves to a runtime that is not OpenAI-compatible.
    """
    normalized = normalize_provider_target(target)
    primary_is_compatible_for_inherit = None  # computed lazily below

    if normalized is None:
        # Without an explicit embedding config, api_base/api_key are inherited
        # only if the main chain is itself OpenAI-compatible; otherwise no
        # guessing is done.
        if not _supports_openai_embeddings(primary_runtime):
            return None
        return resolve_provider_runtime(
            model=default_model,
            provider_name="openai",
            api_key=primary_runtime.api_key,
            api_base=primary_runtime.api_base,
            request_timeout_seconds=primary_runtime.request_timeout_seconds,
            extra_headers=dict(primary_runtime.extra_headers),
            routing=primary_runtime.routing,
            auxiliary=False,
            role="embedding",
            source="embedding_inherited",
        )

    provider = normalized.provider_name
    if provider in {None, "", "main", "auto"}:
        # No concrete provider: a custom base URL implies "custom", else "openai".
        provider = "custom" if normalized.api_base else "openai"

    key = normalized.api_key
    base = normalized.api_base
    headers = dict(normalized.extra_headers)

    # NOTE(review): this inheritance is not conditioned on the provider, so an
    # explicitly named provider without its own api_base still receives the
    # primary runtime's api_base -- confirm that is intended.
    if not base:
        primary_is_compatible_for_inherit = _supports_openai_embeddings(primary_runtime)
    if primary_is_compatible_for_inherit:
        key = key or primary_runtime.api_key
        base = primary_runtime.api_base
        if not headers:
            headers = dict(primary_runtime.extra_headers)

    embedding_runtime = resolve_provider_runtime(
        model=normalized.model or default_model,
        provider_name=provider,
        api_key=key,
        api_base=base,
        request_timeout_seconds=normalized.request_timeout_seconds or primary_runtime.request_timeout_seconds,
        extra_headers=headers,
        routing=normalized.routing,
        auxiliary=False,
        role="embedding",
        source="embedding_config",
    )
    if _supports_openai_embeddings(embedding_runtime):
        return embedding_runtime
    raise ValueError("Embedding runtime currently requires an OpenAI-compatible provider")
def _supports_openai_embeddings(runtime: ProviderRuntime) -> bool:
"""当前 embedding retriever 只支持 OpenAI-compatible `/v1/embeddings`。"""
return runtime.api_mode == "chat_completions" and runtime.spec.provider_impl in {"litellm", "custom"}
def _clone_runtime(
runtime: ProviderRuntime,
**changes: Any,
) -> ProviderRuntime:
"""基于现有 runtime 复制一个轻量变体。
用在 `provider=main` 这类场景,避免重复跑一次 registry 解析。
"""
payload = {
"extra_headers": dict(runtime.extra_headers),
"routing": runtime.routing,
"fallback_target": runtime.fallback_target,
}
payload.update(changes)
return replace(runtime, **payload)
def _resolve_model_name(spec: ProviderSpec, model: str, *, gateway_mode: bool) -> str:
"""根据 registry 规则应用必要前缀。"""
resolved = model
if gateway_mode:
prefix = spec.litellm_prefix
if spec.strip_model_prefix:
resolved = resolved.split("/")[-1]
if prefix and not resolved.startswith(f"{prefix}/"):
resolved = f"{prefix}/{resolved}"
return resolved
if spec.litellm_prefix:
resolved = _canonicalize_explicit_prefix(resolved, spec.name, spec.litellm_prefix)
if not any(resolved.startswith(item) for item in spec.skip_prefixes):
resolved = f"{spec.litellm_prefix}/{resolved}"
return resolved
def _canonicalize_explicit_prefix(model: str, spec_name: str, canonical_prefix: str) -> str:
if "/" not in model:
return model
prefix, remainder = model.split("/", 1)
if prefix.lower().replace("-", "_") != spec_name:
return model
return f"{canonical_prefix}/{remainder}"