"""Provider runtime 的统一工厂入口。""" from __future__ import annotations from dataclasses import dataclass from typing import Any from .anthropic import AnthropicProvider from .base import LLMProvider from .chain import FallbackProviderChain from .codex import OpenAICodexProvider from .custom import CustomProvider from .litellm import LiteLLMProvider from .runtime import ( ProviderRoutingConfig, ProviderRuntime, ProviderTarget, normalize_provider_target, resolve_auxiliary_runtime, resolve_embedding_runtime, resolve_fallback_runtime, resolve_provider_runtime, ) @dataclass(slots=True) class ProviderBundle: """一次运行所需的 provider 组合。 这里把三条常见链路收口到一起: - `main`:主对话 - `fallback`:主链失败后的备用 provider - `auxiliary`:搜索摘要、压缩、memory flush 等辅助任务 """ main_runtime: ProviderRuntime main_provider: LLMProvider fallback_runtime: ProviderRuntime | None = None fallback_provider: LLMProvider | None = None auxiliary_runtime: ProviderRuntime | None = None auxiliary_provider: LLMProvider | None = None embedding_runtime: ProviderRuntime | None = None def build_provider_runtime(**kwargs: Any) -> ProviderRuntime: """构建统一 provider runtime。""" return resolve_provider_runtime(**kwargs) def make_provider_from_runtime(runtime: ProviderRuntime) -> LLMProvider: """根据 runtime 创建具体 provider 实例。""" if runtime.spec.provider_impl == "custom": return CustomProvider( api_key=runtime.api_key or "no-key", api_base=runtime.api_base or "http://localhost:8000/v1", default_model=runtime.default_model or runtime.model, request_timeout_seconds=runtime.request_timeout_seconds, ) if runtime.spec.provider_impl == "codex": return OpenAICodexProvider( default_model=runtime.default_model or runtime.model, request_timeout_seconds=runtime.request_timeout_seconds, ) if runtime.spec.provider_impl == "anthropic": return AnthropicProvider( api_key=runtime.api_key, default_model=runtime.default_model or runtime.model, api_base=runtime.api_base, request_timeout_seconds=runtime.request_timeout_seconds, ) return LiteLLMProvider( api_key=runtime.api_key, api_base=runtime.api_base, default_model=runtime.default_model or runtime.model, provider_name=runtime.provider_name, extra_headers=runtime.extra_headers, request_timeout_seconds=runtime.request_timeout_seconds, routing=runtime.routing, ) def make_main_provider(**kwargs: Any) -> tuple[ProviderRuntime, LLMProvider]: """构建主对话 provider。""" fallback_target = kwargs.pop("fallback_target", None) if fallback_target is None and "fallback_model" in kwargs: fallback_target = kwargs.pop("fallback_model") runtime = build_provider_runtime( auxiliary=False, fallback_target=fallback_target, role="main", source="main_config", **kwargs, ) provider = make_provider_from_runtime(runtime) fallback_pair = make_fallback_provider(runtime, fallback_target) if fallback_pair is None: return runtime, provider fallback_runtime, fallback_provider = fallback_pair return runtime, FallbackProviderChain(runtime, provider, fallback_runtime, fallback_provider) def make_fallback_provider( primary_runtime: ProviderRuntime, fallback_target: ProviderTarget | dict[str, Any] | None = None, ) -> tuple[ProviderRuntime, LLMProvider] | None: """构建 fallback provider。""" runtime = resolve_fallback_runtime(primary_runtime, fallback_target or primary_runtime.fallback_target) if runtime is None: return None return runtime, make_provider_from_runtime(runtime) def make_aux_provider( main_runtime: ProviderRuntime | None = None, *, target: ProviderTarget | dict[str, Any] | None = None, task_name: str = "auxiliary", **kwargs: Any, ) -> tuple[ProviderRuntime, LLMProvider]: """构建辅助任务 provider。""" if target is None and kwargs: target = kwargs if main_runtime is not None: runtime = resolve_auxiliary_runtime(main_runtime, target, task_name=task_name) else: normalized = normalize_provider_target(target) if normalized is None or not normalized.model: raise ValueError("Auxiliary provider without main_runtime requires at least a model") runtime = build_provider_runtime( model=normalized.model, provider_name=normalized.provider_name, api_key=normalized.api_key, api_base=normalized.api_base, request_timeout_seconds=normalized.request_timeout_seconds, extra_headers=normalized.extra_headers, routing=normalized.routing, auxiliary=True, role=task_name, source="auxiliary_config", ) return runtime, make_provider_from_runtime(runtime) def make_embedding_runtime( main_runtime: ProviderRuntime, *, target: ProviderTarget | dict[str, Any] | None = None, default_model: str = "text-embedding-v4", ) -> ProviderRuntime | None: """构建 embedding 专用 runtime。""" return resolve_embedding_runtime(main_runtime, target=target, default_model=default_model) def make_provider_bundle( *, auxiliary_target: ProviderTarget | dict[str, Any] | None = None, auxiliary_task_name: str = "auxiliary", embedding_target: ProviderTarget | dict[str, Any] | None = None, embedding_model: str = "text-embedding-v4", **kwargs: Any, ) -> ProviderBundle: """一次性构建 main/fallback/aux 三条 provider 链。""" runtime_kwargs = dict(kwargs) fallback_target = runtime_kwargs.pop("fallback_target", None) if fallback_target is None and "fallback_model" in kwargs: fallback_target = runtime_kwargs.pop("fallback_model") main_runtime = build_provider_runtime( auxiliary=False, fallback_target=fallback_target, role="main", source="main_config", **runtime_kwargs, ) primary_provider = make_provider_from_runtime(main_runtime) fallback_pair = make_fallback_provider(main_runtime, fallback_target) if fallback_pair is None: main_provider: LLMProvider = primary_provider fallback_runtime = None fallback_provider = None else: fallback_runtime, fallback_provider = fallback_pair main_provider = FallbackProviderChain(main_runtime, primary_provider, fallback_runtime, fallback_provider) auxiliary_runtime = None auxiliary_provider = None if auxiliary_target is not None: auxiliary_runtime, auxiliary_provider = make_aux_provider( main_runtime, target=auxiliary_target, task_name=auxiliary_task_name, ) embedding_runtime = make_embedding_runtime( main_runtime, target=embedding_target, default_model=embedding_model, ) return ProviderBundle( main_runtime=main_runtime, main_provider=main_provider, fallback_runtime=fallback_runtime, fallback_provider=fallback_provider, auxiliary_runtime=auxiliary_runtime, auxiliary_provider=auxiliary_provider, embedding_runtime=embedding_runtime, ) __all__ = [ "ProviderBundle", "ProviderRoutingConfig", "ProviderRuntime", "ProviderTarget", "build_provider_runtime", "make_aux_provider", "make_embedding_runtime", "make_fallback_provider", "make_main_provider", "make_provider_bundle", "make_provider_from_runtime", ]