修改了nanobot,往Hermes agent的风格走,进度1/3

This commit is contained in:
2026-04-20 18:11:14 +08:00
parent cdfc222c9f
commit 36882a7d7b
261 changed files with 12659 additions and 604 deletions

View File

@ -0,0 +1,408 @@
"""Hermes 风格的 provider runtime resolution。"""
from __future__ import annotations
from dataclasses import dataclass, field, replace
from typing import Any
from .registry import ProviderSpec, find_by_model, find_by_name, find_gateway
@dataclass(slots=True)
class ProviderRoutingConfig:
"""OpenRouter provider routing 配置。"""
sort: str | None = None
only: list[str] = field(default_factory=list)
ignore: list[str] = field(default_factory=list)
order: list[str] = field(default_factory=list)
require_parameters: bool = False
data_collection: str | None = None
@dataclass(slots=True)
class ProviderTarget:
"""一次 provider 选路请求的标准化配置。
这层不是具体 runtime而是“调用方想要什么”
- 用哪个 provider
- 跑哪个 model
- 是否指定自定义 base_url
- 是否带额外 headers / routing
后面 `resolve_provider_runtime()` 会把它真正解析成可实例化的 runtime。
"""
provider_name: str | None = None
model: str | None = None
api_key: str | None = None
api_base: str | None = None
extra_headers: dict[str, str] = field(default_factory=dict)
request_timeout_seconds: float | None = None
routing: ProviderRoutingConfig | None = None
@dataclass(slots=True)
class ProviderRuntime:
"""运行时真正使用的 provider 解析结果。"""
spec: ProviderSpec
model: str
provider_name: str
api_mode: str
api_key: str | None = None
api_base: str | None = None
default_model: str | None = None
request_timeout_seconds: float | None = None
extra_headers: dict[str, str] = field(default_factory=dict)
routing: ProviderRoutingConfig | None = None
fallback_target: ProviderTarget | None = None
auxiliary: bool = False
role: str = "main"
source: str = "runtime"
def resolve_provider_runtime(
    *,
    model: str,
    provider_name: str | None = None,
    api_key: str | None = None,
    api_base: str | None = None,
    request_timeout_seconds: float | None = None,
    extra_headers: dict[str, str] | None = None,
    routing: ProviderRoutingConfig | None = None,
    fallback_target: ProviderTarget | dict[str, Any] | None = None,
    auxiliary: bool = False,
    role: str = "main",
    source: str = "runtime",
) -> ProviderRuntime:
    """Resolve caller-supplied settings into one unified ``ProviderRuntime``.

    The provider spec is looked up in this order:

    1. ``find_gateway(provider_name, api_key, api_base)``; a hit also makes
       the model name be resolved in gateway mode.
    2. ``find_by_name(provider_name)`` when a provider name was given.
    3. ``find_by_model(model)`` otherwise.
    4. The ``"custom"`` spec, but only when ``api_base`` is set.

    Args:
        model: Model name as supplied by the caller; registry prefix rules
            are applied to it before it is stored on the runtime.
        provider_name: Explicit provider name, if any.
        api_key: Credential stored on the runtime as-is.
        api_base: Explicit base URL; falls back to the spec's
            ``default_api_base`` when empty.
        request_timeout_seconds: Stored on the runtime as-is.
        extra_headers: Extra headers; copied so the runtime never shares a
            dict with the caller.
        routing: Routing preferences, stored as-is.
        fallback_target: Fallback description; normalized through
            ``normalize_provider_target``.
        auxiliary: Stored on the runtime as-is.
        role: Stored on the runtime as-is.
        source: Stored on the runtime as-is.

    Returns:
        The resolved runtime. ``provider_name`` and ``api_mode`` are taken
        from the resolved spec, and ``default_model`` equals the resolved
        model name.

    Raises:
        ValueError: If no provider spec can be resolved.
    """
    gateway = find_gateway(provider_name, api_key, api_base)
    if gateway is not None:
        spec = gateway
    elif provider_name:
        spec = find_by_name(provider_name)
    else:
        spec = find_by_model(model)

    # An unknown provider is still usable when the caller supplied its own
    # endpoint: treat it as the generic "custom" provider.
    if spec is None and api_base:
        spec = find_by_name("custom")
    # Also covers a registry without a "custom" entry, which would otherwise
    # surface later as an AttributeError on ``None``.
    if spec is None:
        raise ValueError(f"Unable to resolve provider for model={model!r} provider_name={provider_name!r}")

    resolved_model = _resolve_model_name(spec, model, gateway_mode=(gateway is not None))
    resolved_api_base = api_base or spec.default_api_base or None

    return ProviderRuntime(
        spec=spec,
        model=resolved_model,
        provider_name=spec.name,
        api_mode=spec.api_mode,
        api_key=api_key,
        api_base=resolved_api_base,
        default_model=resolved_model,
        request_timeout_seconds=request_timeout_seconds,
        # Copy so later mutation of the runtime's headers cannot leak back
        # into the caller's dict (or into another runtime's headers).
        extra_headers=dict(extra_headers) if extra_headers else {},
        routing=routing,
        fallback_target=normalize_provider_target(fallback_target),
        auxiliary=auxiliary,
        role=role,
        source=source,
    )
def normalize_provider_target(target: ProviderTarget | dict[str, Any] | None) -> ProviderTarget | None:
    """Collapse a dict- or object-style provider config into a ``ProviderTarget``.

    ``None`` and existing ``ProviderTarget`` instances are returned unchanged.
    For dicts, several common spellings are accepted so that CLI / config /
    gateway inputs can be plugged in later; the first key of each pair wins
    unless its value is ``None``:

    - ``provider_name`` or ``provider``
    - ``api_base`` or ``base_url``
    - ``extra_headers`` or ``headers``
    - ``request_timeout_seconds`` or ``timeout``

    A dict under ``routing`` is expanded into ``ProviderRoutingConfig``.
    """
    if target is None or isinstance(target, ProviderTarget):
        return target

    def pick(primary: str, alias: str) -> Any:
        # Only ``None`` triggers the alias; other falsy values are kept.
        value = target.get(primary)
        return target.get(alias) if value is None else value

    routing_cfg = target.get("routing")
    if isinstance(routing_cfg, dict):
        routing_cfg = ProviderRoutingConfig(**routing_cfg)

    return ProviderTarget(
        provider_name=pick("provider_name", "provider"),
        model=target.get("model"),
        api_key=target.get("api_key"),
        api_base=pick("api_base", "base_url"),
        # Always hand the target its own dict.
        extra_headers=dict(pick("extra_headers", "headers") or {}),
        request_timeout_seconds=pick("request_timeout_seconds", "timeout"),
        routing=routing_cfg,
    )
def resolve_fallback_runtime(
    primary_runtime: ProviderRuntime,
    fallback_target: ProviderTarget | dict[str, Any] | None,
) -> ProviderRuntime | None:
    """Resolve a fallback configuration into its own independent runtime.

    A Hermes-style fallback means "switch to another provider:model after the
    main provider fails". This function only resolves the fallback; deciding
    when to activate it is left to the chain/factory layer above.

    Returns ``None`` when there is no fallback target or it names no model.
    A provider of ``None``, ``""`` or ``"main"`` means the primary provider.
    """
    target = normalize_provider_target(fallback_target)
    if target is None or not target.model:
        return None

    provider = target.provider_name
    if provider in {None, "", "main"}:
        provider = primary_runtime.provider_name

    headers = dict(target.extra_headers)
    if provider == primary_runtime.provider_name and not target.api_base:
        # Credentials and headers are inherited from the main chain only when
        # the fallback did not explicitly switch provider or base URL.
        key = target.api_key or primary_runtime.api_key
        base = primary_runtime.api_base
        headers = headers or dict(primary_runtime.extra_headers)
    else:
        key = target.api_key
        base = target.api_base

    timeout = target.request_timeout_seconds or primary_runtime.request_timeout_seconds
    return resolve_provider_runtime(
        model=target.model,
        provider_name=provider,
        api_key=key,
        api_base=base,
        request_timeout_seconds=timeout,
        extra_headers=headers,
        routing=target.routing,
        auxiliary=False,
        role="fallback",
        source="fallback_config",
    )
def resolve_auxiliary_runtime(
    primary_runtime: ProviderRuntime,
    target: ProviderTarget | dict[str, Any] | None = None,
    *,
    task_name: str = "auxiliary",
) -> ProviderRuntime:
    """Resolve the runtime used by an auxiliary task.

    Three kinds of input are supported:

    - ``None`` / ``provider=main``: reuse the main-chain provider.
    - Explicit ``provider + model``: use an independent provider.
    - Only ``model``: match the provider automatically from the model name.

    Args:
        primary_runtime: The main-chain runtime to reuse or inherit from.
        target: Auxiliary configuration (dict or ``ProviderTarget``).
        task_name: Stored as the ``role`` of the returned runtime.

    Returns:
        A runtime with ``auxiliary=True``. It never shares its
        ``extra_headers`` dict with ``primary_runtime``.
    """
    normalized = normalize_provider_target(target)
    if normalized is None:
        return _clone_runtime(
            primary_runtime,
            auxiliary=True,
            role=task_name,
            source="main_runtime",
        )

    provider_name = normalized.provider_name
    timeout = normalized.request_timeout_seconds or primary_runtime.request_timeout_seconds
    routing = normalized.routing or primary_runtime.routing

    # No provider switch, no custom endpoint, no model: same runtime as the
    # main chain, with only routing / headers / timeout overridable.
    if provider_name in {None, "", "main"} and not normalized.api_base and not normalized.model:
        return _clone_runtime(
            primary_runtime,
            auxiliary=True,
            role=task_name,
            source="main_runtime",
            routing=routing,
            # Copy: the clone must not alias the primary runtime's headers.
            extra_headers=dict(normalized.extra_headers or primary_runtime.extra_headers),
            request_timeout_seconds=timeout,
        )

    # ``provider=main`` with a model and/or endpoint override: re-resolve
    # against the main provider, inheriting whatever was not overridden.
    if provider_name == "main":
        return resolve_provider_runtime(
            model=normalized.model or primary_runtime.model,
            provider_name=primary_runtime.provider_name,
            api_key=normalized.api_key or primary_runtime.api_key,
            api_base=normalized.api_base or primary_runtime.api_base,
            request_timeout_seconds=timeout,
            # Copy: do not hand the primary runtime's dict to a new runtime.
            extra_headers=dict(normalized.extra_headers or primary_runtime.extra_headers),
            routing=routing,
            auxiliary=True,
            role=task_name,
            source="main_runtime",
        )

    # ``auto`` with nothing else specified falls back to the main runtime.
    if provider_name in {"auto", None, ""} and not normalized.api_base and normalized.model is None:
        return _clone_runtime(
            primary_runtime,
            auxiliary=True,
            role=task_name,
            source="auto->main",
        )

    resolved_model = normalized.model or primary_runtime.model
    resolved_provider = normalized.provider_name
    if resolved_provider in {"auto", "", None} and not normalized.api_base:
        # The first-stage ``auto`` implementation stays conservative:
        # - with an explicit model, match the provider by model name
        # - if nothing matches, fall back to the main-chain provider
        spec = find_by_model(resolved_model)
        resolved_provider = spec.name if spec is not None else primary_runtime.provider_name

    api_key = normalized.api_key
    api_base = normalized.api_base
    extra_headers = dict(normalized.extra_headers)
    # Inherit credentials and headers only when staying on the main provider
    # without a custom endpoint.
    if resolved_provider == primary_runtime.provider_name and not api_base:
        api_key = api_key or primary_runtime.api_key
        api_base = api_base or primary_runtime.api_base
        if not extra_headers:
            extra_headers = dict(primary_runtime.extra_headers)

    return resolve_provider_runtime(
        model=resolved_model,
        provider_name=resolved_provider,
        api_key=api_key,
        api_base=api_base,
        request_timeout_seconds=timeout,
        extra_headers=extra_headers,
        routing=routing,
        auxiliary=True,
        role=task_name,
        source="auxiliary_config",
    )
def resolve_embedding_runtime(
    primary_runtime: ProviderRuntime,
    target: ProviderTarget | dict[str, Any] | None = None,
    *,
    default_model: str = "text-embedding-v4",
) -> ProviderRuntime | None:
    """Resolve the runtime dedicated to embeddings.

    The point is to keep "which model / api_base / api_key do embeddings use"
    inside the provider layer, so that retrieval code above does not reach
    into the main or auxiliary provider's credentials on its own.

    Returns ``None`` when no embedding config is given and the main runtime
    is not OpenAI-compatible.

    Raises:
        ValueError: If an explicit config resolves to a runtime that is not
            OpenAI-compatible.
    """
    normalized = normalize_provider_target(target)

    if normalized is None:
        # Without explicit embedding config, the main chain's api_base and
        # api_key are inherited only if the main chain is itself
        # OpenAI-compatible. Otherwise nothing is guessed.
        if not _supports_openai_embeddings(primary_runtime):
            return None
        return resolve_provider_runtime(
            model=default_model,
            provider_name="openai",
            api_key=primary_runtime.api_key,
            api_base=primary_runtime.api_base,
            request_timeout_seconds=primary_runtime.request_timeout_seconds,
            extra_headers=dict(primary_runtime.extra_headers),
            routing=primary_runtime.routing,
            auxiliary=False,
            role="embedding",
            source="embedding_inherited",
        )

    provider = normalized.provider_name
    if provider in {None, "", "main", "auto"}:
        # No concrete provider named: a custom endpoint means "custom",
        # anything else is treated as "openai".
        provider = "custom" if normalized.api_base else "openai"

    key = normalized.api_key
    base = normalized.api_base
    headers = dict(normalized.extra_headers)
    if not base and _supports_openai_embeddings(primary_runtime):
        key = key or primary_runtime.api_key
        base = primary_runtime.api_base
        headers = headers or dict(primary_runtime.extra_headers)

    timeout = normalized.request_timeout_seconds or primary_runtime.request_timeout_seconds
    runtime = resolve_provider_runtime(
        model=normalized.model or default_model,
        provider_name=provider,
        api_key=key,
        api_base=base,
        request_timeout_seconds=timeout,
        extra_headers=headers,
        routing=normalized.routing,
        auxiliary=False,
        role="embedding",
        source="embedding_config",
    )
    if _supports_openai_embeddings(runtime):
        return runtime
    raise ValueError("Embedding runtime currently requires an OpenAI-compatible provider")
def _supports_openai_embeddings(runtime: ProviderRuntime) -> bool:
    """Tell whether *runtime* can be used for OpenAI-compatible ``/v1/embeddings``.

    The embedding retriever currently supports nothing else. A runtime
    qualifies when its API mode is ``chat_completions`` and its spec is
    implemented by ``litellm`` or ``custom``.
    """
    if runtime.api_mode != "chat_completions":
        return False
    return runtime.spec.provider_impl in {"litellm", "custom"}
def _clone_runtime(
    runtime: ProviderRuntime,
    **changes: Any,
) -> ProviderRuntime:
    """Return a lightweight variant of an existing runtime.

    Used for cases such as ``provider=main`` so the registry lookup does not
    have to run a second time. The copy gets its own ``extra_headers`` dict
    unless *changes* supplies one; any field named in *changes* wins.
    """
    overrides: dict[str, Any] = {
        "extra_headers": dict(runtime.extra_headers),
        "routing": runtime.routing,
        "fallback_target": runtime.fallback_target,
        **changes,
    }
    return replace(runtime, **overrides)
def _resolve_model_name(spec: ProviderSpec, model: str, *, gateway_mode: bool) -> str:
    """Apply the registry's prefix rules to a model name.

    In gateway mode the name is optionally reduced to its last ``/`` segment
    (``spec.strip_model_prefix``) and then given the spec's prefix unless it
    already starts with it. Otherwise, when the spec has a prefix, an explicit
    provider prefix is canonicalized first and the spec's prefix is added
    unless the name starts with one of ``spec.skip_prefixes``.
    """
    prefix = spec.litellm_prefix

    if gateway_mode:
        name = model.split("/")[-1] if spec.strip_model_prefix else model
        if not prefix or name.startswith(f"{prefix}/"):
            return name
        return f"{prefix}/{name}"

    if not prefix:
        return model

    name = _canonicalize_explicit_prefix(model, spec.name, prefix)
    if any(name.startswith(skip) for skip in spec.skip_prefixes):
        return name
    return f"{prefix}/{name}"
def _canonicalize_explicit_prefix(model: str, spec_name: str, canonical_prefix: str) -> str:
    """Rewrite an explicit provider prefix on *model* to its canonical form.

    The text before the first ``/`` is lower-cased and has ``-`` replaced by
    ``_``; if that equals *spec_name*, the prefix is swapped for
    *canonical_prefix*. Names without a ``/`` or with a different prefix are
    returned unchanged.
    """
    head, slash, tail = model.partition("/")
    if not slash:
        return model
    if head.lower().replace("-", "_") == spec_name:
        return f"{canonical_prefix}/{tail}"
    return model