"""Hermes 风格的 provider runtime resolution。""" from __future__ import annotations from dataclasses import dataclass, field, replace from typing import Any from .registry import ProviderSpec, find_by_model, find_by_name, find_gateway @dataclass(slots=True) class ProviderRoutingConfig: """OpenRouter provider routing 配置。""" sort: str | None = None only: list[str] = field(default_factory=list) ignore: list[str] = field(default_factory=list) order: list[str] = field(default_factory=list) require_parameters: bool = False data_collection: str | None = None @dataclass(slots=True) class ProviderTarget: """一次 provider 选路请求的标准化配置。 这层不是具体 runtime,而是“调用方想要什么”: - 用哪个 provider - 跑哪个 model - 是否指定自定义 base_url - 是否带额外 headers / routing 后面 `resolve_provider_runtime()` 会把它真正解析成可实例化的 runtime。 """ provider_name: str | None = None model: str | None = None api_key: str | None = None api_base: str | None = None extra_headers: dict[str, str] = field(default_factory=dict) request_timeout_seconds: float | None = None routing: ProviderRoutingConfig | None = None @dataclass(slots=True) class ProviderRuntime: """运行时真正使用的 provider 解析结果。""" spec: ProviderSpec model: str provider_name: str api_mode: str api_key: str | None = None api_base: str | None = None default_model: str | None = None request_timeout_seconds: float | None = None extra_headers: dict[str, str] = field(default_factory=dict) routing: ProviderRoutingConfig | None = None fallback_target: ProviderTarget | None = None auxiliary: bool = False role: str = "main" source: str = "runtime" def resolve_provider_runtime( *, model: str, provider_name: str | None = None, api_key: str | None = None, api_base: str | None = None, request_timeout_seconds: float | None = None, extra_headers: dict[str, str] | None = None, routing: ProviderRoutingConfig | None = None, fallback_target: ProviderTarget | dict[str, Any] | None = None, auxiliary: bool = False, role: str = "main", source: str = "runtime", ) -> ProviderRuntime: """把调用侧传入的配置解析成统一 runtime。""" gateway = find_gateway(provider_name, api_key, api_base) if gateway is not None: spec = gateway elif provider_name: spec = find_by_name(provider_name) else: spec = find_by_model(model) if spec is None: if api_base: spec = find_by_name("custom") else: raise ValueError(f"Unable to resolve provider for model={model!r} provider_name={provider_name!r}") resolved_model = _resolve_model_name(spec, model, gateway_mode=(gateway is not None)) resolved_api_base = api_base or spec.default_api_base or None return ProviderRuntime( spec=spec, model=resolved_model, provider_name=spec.name, api_mode=spec.api_mode, api_key=api_key, api_base=resolved_api_base, default_model=resolved_model, request_timeout_seconds=request_timeout_seconds, extra_headers=extra_headers or {}, routing=routing, fallback_target=normalize_provider_target(fallback_target), auxiliary=auxiliary, role=role, source=source, ) def normalize_provider_target(target: ProviderTarget | dict[str, Any] | None) -> ProviderTarget | None: """把 dict/对象形式的 provider 配置收敛成统一结构。 这里兼容几种常见写法,便于后续接 CLI / config / gateway: - `provider` 或 `provider_name` - `base_url` 或 `api_base` - `headers` 或 `extra_headers` - `timeout` 或 `request_timeout_seconds` """ if target is None: return None if isinstance(target, ProviderTarget): return target provider_name = target.get("provider_name") if provider_name is None: provider_name = target.get("provider") api_base = target.get("api_base") if api_base is None: api_base = target.get("base_url") extra_headers = target.get("extra_headers") if extra_headers is None: extra_headers = target.get("headers") request_timeout_seconds = target.get("request_timeout_seconds") if request_timeout_seconds is None: request_timeout_seconds = target.get("timeout") routing = target.get("routing") if isinstance(routing, dict): routing = ProviderRoutingConfig(**routing) return ProviderTarget( provider_name=provider_name, model=target.get("model"), api_key=target.get("api_key"), api_base=api_base, extra_headers=dict(extra_headers or {}), request_timeout_seconds=request_timeout_seconds, routing=routing, ) def resolve_fallback_runtime( primary_runtime: ProviderRuntime, fallback_target: ProviderTarget | dict[str, Any] | None, ) -> ProviderRuntime | None: """把 fallback 配置解析成独立 runtime。 Hermes 的 fallback 是“主 provider 失败后切换到另一个 provider:model”。 这里先把 fallback 解析独立出来,具体何时激活交给上层 chain/factory。 """ target = normalize_provider_target(fallback_target) if target is None or not target.model: return None inferred_provider = target.provider_name if inferred_provider in {None, "", "main"}: inferred_provider = primary_runtime.provider_name api_key = target.api_key api_base = target.api_base extra_headers = dict(target.extra_headers) # 只有在 fallback 没明确切换 provider/base 时,才继承主链的凭据与 headers。 if inferred_provider == primary_runtime.provider_name and not api_base: api_key = api_key or primary_runtime.api_key api_base = api_base or primary_runtime.api_base if not extra_headers: extra_headers = dict(primary_runtime.extra_headers) return resolve_provider_runtime( model=target.model, provider_name=inferred_provider, api_key=api_key, api_base=api_base, request_timeout_seconds=target.request_timeout_seconds or primary_runtime.request_timeout_seconds, extra_headers=extra_headers, routing=target.routing, auxiliary=False, role="fallback", source="fallback_config", ) def resolve_auxiliary_runtime( primary_runtime: ProviderRuntime, target: ProviderTarget | dict[str, Any] | None = None, *, task_name: str = "auxiliary", ) -> ProviderRuntime: """解析辅助任务专用 runtime。 支持三类输入: - `None` / `provider=main`:直接复用主链 provider - 显式 `provider + model`:走独立 provider - 仅给 `model`:按模型名自动匹配 provider """ normalized = normalize_provider_target(target) if normalized is None: return _clone_runtime( primary_runtime, auxiliary=True, role=task_name, source="main_runtime", ) provider_name = normalized.provider_name if provider_name in {None, "", "main"} and not normalized.api_base and not normalized.model: return _clone_runtime( primary_runtime, auxiliary=True, role=task_name, source="main_runtime", routing=normalized.routing or primary_runtime.routing, extra_headers=normalized.extra_headers or primary_runtime.extra_headers, request_timeout_seconds=normalized.request_timeout_seconds or primary_runtime.request_timeout_seconds, ) if provider_name == "main": return resolve_provider_runtime( model=normalized.model or primary_runtime.model, provider_name=primary_runtime.provider_name, api_key=normalized.api_key or primary_runtime.api_key, api_base=normalized.api_base or primary_runtime.api_base, request_timeout_seconds=normalized.request_timeout_seconds or primary_runtime.request_timeout_seconds, extra_headers=normalized.extra_headers or primary_runtime.extra_headers, routing=normalized.routing or primary_runtime.routing, auxiliary=True, role=task_name, source="main_runtime", ) if provider_name in {"auto", None, ""} and not normalized.api_base and normalized.model is None: return _clone_runtime( primary_runtime, auxiliary=True, role=task_name, source="auto->main", ) resolved_model = normalized.model or primary_runtime.model resolved_provider = normalized.provider_name if resolved_provider in {"auto", "", None} and not normalized.api_base: # `auto` 的第一阶段实现保持保守: # - 有显式 model 时按 model 匹配 provider # - 匹配不到则回退主链 provider spec = find_by_model(resolved_model) resolved_provider = spec.name if spec is not None else primary_runtime.provider_name api_key = normalized.api_key api_base = normalized.api_base extra_headers = dict(normalized.extra_headers) if resolved_provider == primary_runtime.provider_name and not api_base: api_key = api_key or primary_runtime.api_key api_base = api_base or primary_runtime.api_base if not extra_headers: extra_headers = dict(primary_runtime.extra_headers) return resolve_provider_runtime( model=resolved_model, provider_name=resolved_provider, api_key=api_key, api_base=api_base, request_timeout_seconds=normalized.request_timeout_seconds or primary_runtime.request_timeout_seconds, extra_headers=extra_headers, routing=normalized.routing or primary_runtime.routing, auxiliary=True, role=task_name, source="auxiliary_config", ) def resolve_embedding_runtime( primary_runtime: ProviderRuntime, target: ProviderTarget | dict[str, Any] | None = None, *, default_model: str = "text-embedding-v4", ) -> ProviderRuntime | None: """解析 embedding 专用 runtime。 目标是把“embedding 用哪个 model / api_base / api_key”也收进 provider 层, 避免上层检索逻辑直接偷拿 main/aux provider 的凭据。 """ normalized = normalize_provider_target(target) if normalized is None: # 没有显式 embedding 配置时,只允许在主链本身就是 OpenAI-compatible # 的情况下,继承它的 api_base/api_key。否则不做模糊猜测。 if not _supports_openai_embeddings(primary_runtime): return None return resolve_provider_runtime( model=default_model, provider_name="openai", api_key=primary_runtime.api_key, api_base=primary_runtime.api_base, request_timeout_seconds=primary_runtime.request_timeout_seconds, extra_headers=dict(primary_runtime.extra_headers), routing=primary_runtime.routing, auxiliary=False, role="embedding", source="embedding_inherited", ) resolved_model = normalized.model or default_model resolved_provider = normalized.provider_name if resolved_provider in {None, "", "main", "auto"}: resolved_provider = "custom" if normalized.api_base else "openai" api_key = normalized.api_key api_base = normalized.api_base extra_headers = dict(normalized.extra_headers) if not api_base and _supports_openai_embeddings(primary_runtime): api_key = api_key or primary_runtime.api_key api_base = api_base or primary_runtime.api_base if not extra_headers: extra_headers = dict(primary_runtime.extra_headers) runtime = resolve_provider_runtime( model=resolved_model, provider_name=resolved_provider, api_key=api_key, api_base=api_base, request_timeout_seconds=normalized.request_timeout_seconds or primary_runtime.request_timeout_seconds, extra_headers=extra_headers, routing=normalized.routing, auxiliary=False, role="embedding", source="embedding_config", ) if not _supports_openai_embeddings(runtime): raise ValueError("Embedding runtime currently requires an OpenAI-compatible provider") return runtime def _supports_openai_embeddings(runtime: ProviderRuntime) -> bool: """当前 embedding retriever 只支持 OpenAI-compatible `/v1/embeddings`。""" return runtime.api_mode == "chat_completions" and runtime.spec.provider_impl in {"litellm", "custom"} def _clone_runtime( runtime: ProviderRuntime, **changes: Any, ) -> ProviderRuntime: """基于现有 runtime 复制一个轻量变体。 用在 `provider=main` 这类场景,避免重复跑一次 registry 解析。 """ payload = { "extra_headers": dict(runtime.extra_headers), "routing": runtime.routing, "fallback_target": runtime.fallback_target, } payload.update(changes) return replace(runtime, **payload) def _resolve_model_name(spec: ProviderSpec, model: str, *, gateway_mode: bool) -> str: """根据 registry 规则应用必要前缀。""" resolved = model if gateway_mode: prefix = spec.litellm_prefix if spec.strip_model_prefix: resolved = resolved.split("/")[-1] if prefix and not resolved.startswith(f"{prefix}/"): resolved = f"{prefix}/{resolved}" return resolved if spec.litellm_prefix: resolved = _canonicalize_explicit_prefix(resolved, spec.name, spec.litellm_prefix) if not any(resolved.startswith(item) for item in spec.skip_prefixes): resolved = f"{spec.litellm_prefix}/{resolved}" return resolved def _canonicalize_explicit_prefix(model: str, spec_name: str, canonical_prefix: str) -> str: if "/" not in model: return model prefix, remainder = model.split("/", 1) if prefix.lower().replace("-", "_") != spec_name: return model return f"{canonical_prefix}/{remainder}"