修改了nanobot,往Hermes agent的风格走,进度1/3
This commit is contained in:
408
app-instance/backend/beaver/engine/providers/runtime.py
Normal file
408
app-instance/backend/beaver/engine/providers/runtime.py
Normal file
@ -0,0 +1,408 @@
|
||||
"""Hermes 风格的 provider runtime resolution。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field, replace
|
||||
from typing import Any
|
||||
|
||||
from .registry import ProviderSpec, find_by_model, find_by_name, find_gateway
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ProviderRoutingConfig:
|
||||
"""OpenRouter provider routing 配置。"""
|
||||
|
||||
sort: str | None = None
|
||||
only: list[str] = field(default_factory=list)
|
||||
ignore: list[str] = field(default_factory=list)
|
||||
order: list[str] = field(default_factory=list)
|
||||
require_parameters: bool = False
|
||||
data_collection: str | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ProviderTarget:
|
||||
"""一次 provider 选路请求的标准化配置。
|
||||
|
||||
这层不是具体 runtime,而是“调用方想要什么”:
|
||||
- 用哪个 provider
|
||||
- 跑哪个 model
|
||||
- 是否指定自定义 base_url
|
||||
- 是否带额外 headers / routing
|
||||
|
||||
后面 `resolve_provider_runtime()` 会把它真正解析成可实例化的 runtime。
|
||||
"""
|
||||
|
||||
provider_name: str | None = None
|
||||
model: str | None = None
|
||||
api_key: str | None = None
|
||||
api_base: str | None = None
|
||||
extra_headers: dict[str, str] = field(default_factory=dict)
|
||||
request_timeout_seconds: float | None = None
|
||||
routing: ProviderRoutingConfig | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ProviderRuntime:
|
||||
"""运行时真正使用的 provider 解析结果。"""
|
||||
|
||||
spec: ProviderSpec
|
||||
model: str
|
||||
provider_name: str
|
||||
api_mode: str
|
||||
api_key: str | None = None
|
||||
api_base: str | None = None
|
||||
default_model: str | None = None
|
||||
request_timeout_seconds: float | None = None
|
||||
extra_headers: dict[str, str] = field(default_factory=dict)
|
||||
routing: ProviderRoutingConfig | None = None
|
||||
fallback_target: ProviderTarget | None = None
|
||||
auxiliary: bool = False
|
||||
role: str = "main"
|
||||
source: str = "runtime"
|
||||
|
||||
|
||||
def resolve_provider_runtime(
    *,
    model: str,
    provider_name: str | None = None,
    api_key: str | None = None,
    api_base: str | None = None,
    request_timeout_seconds: float | None = None,
    extra_headers: dict[str, str] | None = None,
    routing: ProviderRoutingConfig | None = None,
    fallback_target: ProviderTarget | dict[str, Any] | None = None,
    auxiliary: bool = False,
    role: str = "main",
    source: str = "runtime",
) -> ProviderRuntime:
    """Resolve caller-supplied configuration into a unified ``ProviderRuntime``.

    The provider spec is chosen in this order:

    1. whatever ``find_gateway(provider_name, api_key, api_base)`` returns,
       if it is not None;
    2. ``find_by_name(provider_name)`` when a provider name is given;
    3. ``find_by_model(model)`` otherwise;
    4. the ``"custom"`` registry entry, when nothing matched but an
       ``api_base`` was supplied.

    Args:
        model: Requested model name; registry prefix rules are applied to it.
        provider_name: Explicit provider name, if any.
        api_key: Credential, stored on the runtime unchanged.
        api_base: Explicit base URL; falls back to the spec's default.
        request_timeout_seconds: Stored on the runtime unchanged.
        extra_headers: Extra request headers; copied, never stored by reference.
        routing: Routing configuration, stored on the runtime unchanged.
        fallback_target: Fallback description; normalized to ``ProviderTarget``.
        auxiliary: Bookkeeping flag stored on the runtime.
        role: Bookkeeping tag stored on the runtime.
        source: Bookkeeping tag stored on the runtime.

    Returns:
        The resolved ``ProviderRuntime``.

    Raises:
        ValueError: If no provider spec can be resolved.
    """

    gateway = find_gateway(provider_name, api_key, api_base)
    if gateway is not None:
        spec = gateway
    elif provider_name:
        spec = find_by_name(provider_name)
    else:
        spec = find_by_model(model)

    if spec is None and api_base:
        # An explicit endpoint with no recognizable provider is treated as a
        # custom endpoint.
        spec = find_by_name("custom")
    if spec is None:
        # This also covers a registry without a "custom" entry, which would
        # otherwise surface later as an AttributeError on ``None``.
        raise ValueError(f"Unable to resolve provider for model={model!r} provider_name={provider_name!r}")

    resolved_model = _resolve_model_name(spec, model, gateway_mode=(gateway is not None))
    resolved_api_base = api_base or spec.default_api_base or None

    return ProviderRuntime(
        spec=spec,
        model=resolved_model,
        provider_name=spec.name,
        api_mode=spec.api_mode,
        api_key=api_key,
        api_base=resolved_api_base,
        default_model=resolved_model,
        request_timeout_seconds=request_timeout_seconds,
        # Copy so the runtime never shares a mutable dict with the caller
        # (for example with another runtime's ``extra_headers``).
        extra_headers=dict(extra_headers) if extra_headers else {},
        routing=routing,
        fallback_target=normalize_provider_target(fallback_target),
        auxiliary=auxiliary,
        role=role,
        source=source,
    )
|
||||
|
||||
|
||||
def normalize_provider_target(target: ProviderTarget | dict[str, Any] | None) -> ProviderTarget | None:
    """Collapse a dict- or object-shaped provider config into ``ProviderTarget``.

    ``None`` and existing ``ProviderTarget`` instances are returned unchanged.
    For dicts, several common spellings are accepted so CLI / config / gateway
    callers can be wired in later; the first name of each pair wins when its
    value is not None:

    - ``provider_name`` or ``provider``
    - ``api_base`` or ``base_url``
    - ``extra_headers`` or ``headers``
    - ``request_timeout_seconds`` or ``timeout``

    A dict-valued ``routing`` entry is expanded into ``ProviderRoutingConfig``.
    """

    if target is None or isinstance(target, ProviderTarget):
        return target

    def pick(preferred: str, alias: str) -> Any:
        # Use the canonical key unless it is missing or None, then the alias.
        value = target.get(preferred)
        return target.get(alias) if value is None else value

    routing_cfg = target.get("routing")
    if isinstance(routing_cfg, dict):
        routing_cfg = ProviderRoutingConfig(**routing_cfg)

    return ProviderTarget(
        provider_name=pick("provider_name", "provider"),
        model=target.get("model"),
        api_key=target.get("api_key"),
        api_base=pick("api_base", "base_url"),
        # Always a fresh dict, so the target never aliases the caller's mapping.
        extra_headers=dict(pick("extra_headers", "headers") or {}),
        request_timeout_seconds=pick("request_timeout_seconds", "timeout"),
        routing=routing_cfg,
    )
|
||||
|
||||
|
||||
def resolve_fallback_runtime(
    primary_runtime: ProviderRuntime,
    fallback_target: ProviderTarget | dict[str, Any] | None,
) -> ProviderRuntime | None:
    """Resolve a fallback configuration into its own independent runtime.

    A Hermes-style fallback means "switch to another provider:model after the
    main provider fails". This function only resolves the fallback; deciding
    when to activate it is left to the chain/factory layer above.

    Returns:
        The fallback runtime, or ``None`` when no target is given or the
        target names no model.
    """

    target = normalize_provider_target(fallback_target)
    if target is None or not target.model:
        return None

    # A missing provider, or the literal "main", means "same provider as the
    # primary chain".
    provider = target.provider_name
    if provider in {None, "", "main"}:
        provider = primary_runtime.provider_name

    key = target.api_key
    base = target.api_base
    headers = dict(target.extra_headers)

    # Credentials and headers are inherited from the primary chain only when
    # the fallback does not explicitly switch provider or base URL.
    stays_on_primary = provider == primary_runtime.provider_name and not base
    if stays_on_primary:
        key = key or primary_runtime.api_key
        base = base or primary_runtime.api_base
        headers = headers or dict(primary_runtime.extra_headers)

    timeout = target.request_timeout_seconds or primary_runtime.request_timeout_seconds
    return resolve_provider_runtime(
        model=target.model,
        provider_name=provider,
        api_key=key,
        api_base=base,
        request_timeout_seconds=timeout,
        extra_headers=headers,
        routing=target.routing,
        auxiliary=False,
        role="fallback",
        source="fallback_config",
    )
|
||||
|
||||
|
||||
def resolve_auxiliary_runtime(
    primary_runtime: ProviderRuntime,
    target: ProviderTarget | dict[str, Any] | None = None,
    *,
    task_name: str = "auxiliary",
) -> ProviderRuntime:
    """Resolve the runtime dedicated to an auxiliary task.

    Three kinds of input are supported:

    - ``None`` / ``provider=main``: reuse the primary chain's provider
    - explicit ``provider + model``: use an independent provider
    - only ``model``: pick the provider by matching the model name

    Args:
        primary_runtime: Runtime of the main chain; source of inherited values.
        target: Optional auxiliary configuration (dict or ``ProviderTarget``).
        task_name: Stored as the ``role`` of the returned runtime.

    Returns:
        A runtime with ``auxiliary=True``. It never shares its
        ``extra_headers`` dict with ``primary_runtime``.
    """

    normalized = normalize_provider_target(target)
    if normalized is None:
        return _clone_runtime(
            primary_runtime,
            auxiliary=True,
            role=task_name,
            source="main_runtime",
        )

    provider_name = normalized.provider_name
    timeout = normalized.request_timeout_seconds or primary_runtime.request_timeout_seconds
    routing = normalized.routing or primary_runtime.routing

    # No provider/base/model override: clone the primary runtime and apply
    # only the lightweight overrides (routing, headers, timeout).
    if provider_name in {None, "", "main"} and not normalized.api_base and not normalized.model:
        return _clone_runtime(
            primary_runtime,
            auxiliary=True,
            role=task_name,
            source="main_runtime",
            routing=routing,
            # Copy: passing the primary's dict itself would override the
            # defensive copy made by _clone_runtime and alias the two runtimes.
            extra_headers=dict(normalized.extra_headers or primary_runtime.extra_headers),
            request_timeout_seconds=timeout,
        )

    # "main" with a model and/or base override: stay on the primary provider
    # but re-run resolution so prefix rules apply to the new model.
    if provider_name == "main":
        return resolve_provider_runtime(
            model=normalized.model or primary_runtime.model,
            provider_name=primary_runtime.provider_name,
            api_key=normalized.api_key or primary_runtime.api_key,
            api_base=normalized.api_base or primary_runtime.api_base,
            request_timeout_seconds=timeout,
            extra_headers=dict(normalized.extra_headers or primary_runtime.extra_headers),
            routing=routing,
            auxiliary=True,
            role=task_name,
            source="main_runtime",
        )

    # "auto" with nothing else specified falls back to the primary runtime.
    # (The None / "" cases without base and model already returned above.)
    if provider_name in {"auto", None, ""} and not normalized.api_base and normalized.model is None:
        return _clone_runtime(
            primary_runtime,
            auxiliary=True,
            role=task_name,
            source="auto->main",
        )

    resolved_model = normalized.model or primary_runtime.model
    resolved_provider = normalized.provider_name
    if resolved_provider in {"auto", "", None} and not normalized.api_base:
        # First-stage ``auto`` is deliberately conservative:
        # - with an explicit model, match the provider by model name
        # - if nothing matches, fall back to the primary chain's provider
        spec = find_by_model(resolved_model)
        resolved_provider = spec.name if spec is not None else primary_runtime.provider_name

    api_key = normalized.api_key
    api_base = normalized.api_base
    extra_headers = dict(normalized.extra_headers)

    # Inherit credentials and headers only when staying on the primary
    # provider without an explicit base URL.
    if resolved_provider == primary_runtime.provider_name and not api_base:
        api_key = api_key or primary_runtime.api_key
        api_base = api_base or primary_runtime.api_base
        if not extra_headers:
            extra_headers = dict(primary_runtime.extra_headers)

    return resolve_provider_runtime(
        model=resolved_model,
        provider_name=resolved_provider,
        api_key=api_key,
        api_base=api_base,
        request_timeout_seconds=timeout,
        extra_headers=extra_headers,
        routing=routing,
        auxiliary=True,
        role=task_name,
        source="auxiliary_config",
    )
|
||||
|
||||
|
||||
def resolve_embedding_runtime(
    primary_runtime: ProviderRuntime,
    target: ProviderTarget | dict[str, Any] | None = None,
    *,
    default_model: str = "text-embedding-v4",
) -> ProviderRuntime | None:
    """Resolve the runtime dedicated to embeddings.

    The goal is to keep "which model / api_base / api_key embeddings use"
    inside the provider layer, so retrieval code above does not reach into the
    main/auxiliary provider's credentials on its own.

    Returns:
        The embedding runtime, or ``None`` when no embedding config is given
        and the primary runtime is not OpenAI-compatible.

    Raises:
        ValueError: If an explicit embedding config resolves to a runtime that
            is not OpenAI-compatible.
    """

    normalized = normalize_provider_target(target)

    if normalized is None:
        # Without explicit embedding config, api_base/api_key are inherited
        # only when the primary chain itself is OpenAI-compatible; otherwise
        # no guessing is done.
        if _supports_openai_embeddings(primary_runtime):
            return resolve_provider_runtime(
                model=default_model,
                provider_name="openai",
                api_key=primary_runtime.api_key,
                api_base=primary_runtime.api_base,
                request_timeout_seconds=primary_runtime.request_timeout_seconds,
                extra_headers=dict(primary_runtime.extra_headers),
                routing=primary_runtime.routing,
                auxiliary=False,
                role="embedding",
                source="embedding_inherited",
            )
        return None

    # Placeholder provider names are replaced by a concrete choice: a custom
    # endpoint when a base URL is given, "openai" otherwise.
    provider = normalized.provider_name
    if provider in {None, "", "main", "auto"}:
        provider = "custom" if normalized.api_base else "openai"

    key = normalized.api_key
    base = normalized.api_base
    headers = dict(normalized.extra_headers)

    # With no explicit base URL, borrow endpoint, key and headers from an
    # OpenAI-compatible primary runtime.
    if not base and _supports_openai_embeddings(primary_runtime):
        key = key or primary_runtime.api_key
        base = base or primary_runtime.api_base
        headers = headers or dict(primary_runtime.extra_headers)

    embedding_runtime = resolve_provider_runtime(
        model=normalized.model or default_model,
        provider_name=provider,
        api_key=key,
        api_base=base,
        request_timeout_seconds=normalized.request_timeout_seconds or primary_runtime.request_timeout_seconds,
        extra_headers=headers,
        routing=normalized.routing,
        auxiliary=False,
        role="embedding",
        source="embedding_config",
    )
    if _supports_openai_embeddings(embedding_runtime):
        return embedding_runtime
    raise ValueError("Embedding runtime currently requires an OpenAI-compatible provider")
|
||||
|
||||
|
||||
def _supports_openai_embeddings(runtime: ProviderRuntime) -> bool:
    """Tell whether ``runtime`` can serve the embedding retriever.

    The retriever currently supports only OpenAI-compatible ``/v1/embeddings``,
    which here means: ``api_mode`` is ``"chat_completions"`` and the spec's
    ``provider_impl`` is ``"litellm"`` or ``"custom"``.
    """

    if runtime.api_mode != "chat_completions":
        return False
    return runtime.spec.provider_impl in {"litellm", "custom"}
|
||||
|
||||
|
||||
def _clone_runtime(
    runtime: ProviderRuntime,
    **changes: Any,
) -> ProviderRuntime:
    """Copy ``runtime`` into a lightweight variant with ``changes`` applied.

    Used for cases such as ``provider=main``, where running registry
    resolution a second time would be wasted work. The copy gets its own
    ``extra_headers`` dict unless ``changes`` supplies one explicitly.
    """

    overrides: dict[str, Any] = {
        "extra_headers": dict(runtime.extra_headers),
        "routing": runtime.routing,
        "fallback_target": runtime.fallback_target,
        # Caller-supplied values win over the defaults above.
        **changes,
    }
    return replace(runtime, **overrides)
|
||||
|
||||
|
||||
def _resolve_model_name(spec: ProviderSpec, model: str, *, gateway_mode: bool) -> str:
    """Apply the registry's prefix rules to ``model``.

    In gateway mode the model is optionally reduced to its last ``/`` segment
    (``spec.strip_model_prefix``) and then given ``spec.litellm_prefix`` unless
    it already starts with it. Otherwise, when the spec has a prefix, an
    explicit provider prefix is canonicalized first and the prefix is added
    unless the name starts with one of ``spec.skip_prefixes``.
    """

    prefix = spec.litellm_prefix

    if gateway_mode:
        name = model.split("/")[-1] if spec.strip_model_prefix else model
        if prefix and not name.startswith(f"{prefix}/"):
            return f"{prefix}/{name}"
        return name

    if not prefix:
        # No prefix rule for this provider: the name passes through untouched.
        return model

    name = _canonicalize_explicit_prefix(model, spec.name, prefix)
    if any(name.startswith(skip) for skip in spec.skip_prefixes):
        return name
    return f"{prefix}/{name}"
|
||||
|
||||
|
||||
def _canonicalize_explicit_prefix(model: str, spec_name: str, canonical_prefix: str) -> str:
    """Rewrite an explicit provider prefix on ``model`` to its canonical form.

    The text before the first ``/`` is lower-cased and has ``-`` replaced by
    ``_``. If that equals ``spec_name``, the prefix is swapped for
    ``canonical_prefix``. Names without a ``/`` or with a different prefix are
    returned unchanged.
    """

    head, slash, tail = model.partition("/")
    if not slash:
        return model
    if head.lower().replace("-", "_") == spec_name:
        return f"{canonical_prefix}/{tail}"
    return model
|
||||
Reference in New Issue
Block a user