236 lines
7.7 KiB
Python
236 lines
7.7 KiB
Python
"""Provider runtime 的统一工厂入口。"""
|
||
|
||
from __future__ import annotations
|
||
|
||
from dataclasses import dataclass
|
||
from typing import Any
|
||
|
||
from .anthropic import AnthropicProvider
|
||
from .base import LLMProvider
|
||
from .chain import FallbackProviderChain
|
||
from .codex import OpenAICodexProvider
|
||
from .custom import CustomProvider
|
||
from .litellm import LiteLLMProvider
|
||
from .runtime import (
|
||
ProviderRoutingConfig,
|
||
ProviderRuntime,
|
||
ProviderTarget,
|
||
normalize_provider_target,
|
||
resolve_auxiliary_runtime,
|
||
resolve_embedding_runtime,
|
||
resolve_fallback_runtime,
|
||
resolve_provider_runtime,
|
||
)
|
||
|
||
|
||
@dataclass(slots=True)
|
||
class ProviderBundle:
|
||
"""一次运行所需的 provider 组合。
|
||
|
||
这里把三条常见链路收口到一起:
|
||
- `main`:主对话
|
||
- `fallback`:主链失败后的备用 provider
|
||
- `auxiliary`:搜索摘要、压缩、memory flush 等辅助任务
|
||
"""
|
||
|
||
main_runtime: ProviderRuntime
|
||
main_provider: LLMProvider
|
||
fallback_runtime: ProviderRuntime | None = None
|
||
fallback_provider: LLMProvider | None = None
|
||
auxiliary_runtime: ProviderRuntime | None = None
|
||
auxiliary_provider: LLMProvider | None = None
|
||
embedding_runtime: ProviderRuntime | None = None
|
||
|
||
|
||
def build_provider_runtime(**kwargs: Any) -> ProviderRuntime:
    """Build the unified provider runtime.

    Every keyword argument is forwarded unchanged to
    ``resolve_provider_runtime`` and its result is returned as-is.
    """
    resolved = resolve_provider_runtime(**kwargs)
    return resolved
|
||
|
||
|
||
def make_provider_from_runtime(runtime: ProviderRuntime) -> LLMProvider:
|
||
"""根据 runtime 创建具体 provider 实例。"""
|
||
|
||
if runtime.spec.provider_impl == "custom":
|
||
return CustomProvider(
|
||
api_key=runtime.api_key or "no-key",
|
||
api_base=runtime.api_base or "http://localhost:8000/v1",
|
||
default_model=runtime.default_model or runtime.model,
|
||
request_timeout_seconds=runtime.request_timeout_seconds,
|
||
)
|
||
|
||
if runtime.spec.provider_impl == "codex":
|
||
return OpenAICodexProvider(
|
||
default_model=runtime.default_model or runtime.model,
|
||
request_timeout_seconds=runtime.request_timeout_seconds,
|
||
)
|
||
|
||
if runtime.spec.provider_impl == "anthropic":
|
||
return AnthropicProvider(
|
||
api_key=runtime.api_key,
|
||
default_model=runtime.default_model or runtime.model,
|
||
api_base=runtime.api_base,
|
||
request_timeout_seconds=runtime.request_timeout_seconds,
|
||
)
|
||
|
||
return LiteLLMProvider(
|
||
api_key=runtime.api_key,
|
||
api_base=runtime.api_base,
|
||
default_model=runtime.default_model or runtime.model,
|
||
provider_name=runtime.provider_name,
|
||
extra_headers=runtime.extra_headers,
|
||
request_timeout_seconds=runtime.request_timeout_seconds,
|
||
routing=runtime.routing,
|
||
)
|
||
|
||
|
||
def make_main_provider(**kwargs: Any) -> tuple[ProviderRuntime, LLMProvider]:
    """Build the provider for the main conversation.

    ``fallback_target`` is taken from ``kwargs``; when it is absent or
    ``None`` and ``fallback_model`` is present, that value is used instead.
    The remaining keyword arguments go to ``build_provider_runtime``.

    Returns:
        ``(runtime, provider)``. When a fallback could be resolved, the
        provider is a ``FallbackProviderChain`` wrapping the primary and
        fallback providers; otherwise it is the primary provider itself.
    """
    backup_target = kwargs.pop("fallback_target", None)
    if backup_target is None and "fallback_model" in kwargs:
        backup_target = kwargs.pop("fallback_model")

    main_runtime = build_provider_runtime(
        auxiliary=False,
        fallback_target=backup_target,
        role="main",
        source="main_config",
        **kwargs,
    )
    primary = make_provider_from_runtime(main_runtime)

    backup = make_fallback_provider(main_runtime, backup_target)
    if backup is not None:
        backup_runtime, backup_provider = backup
        chained = FallbackProviderChain(main_runtime, primary, backup_runtime, backup_provider)
        return main_runtime, chained

    # No fallback available: hand back the bare primary provider.
    return main_runtime, primary
|
||
|
||
|
||
def make_fallback_provider(
    primary_runtime: ProviderRuntime,
    fallback_target: ProviderTarget | dict[str, Any] | None = None,
) -> tuple[ProviderRuntime, LLMProvider] | None:
    """Build the fallback provider for ``primary_runtime``.

    A falsy ``fallback_target`` is replaced by
    ``primary_runtime.fallback_target`` before resolution.

    Returns:
        ``(runtime, provider)`` for the fallback, or ``None`` when
        ``resolve_fallback_runtime`` yields no runtime.
    """
    effective_target = fallback_target or primary_runtime.fallback_target
    fallback_runtime = resolve_fallback_runtime(primary_runtime, effective_target)
    if fallback_runtime is not None:
        return fallback_runtime, make_provider_from_runtime(fallback_runtime)
    return None
|
||
|
||
|
||
def make_aux_provider(
    main_runtime: ProviderRuntime | None = None,
    *,
    target: ProviderTarget | dict[str, Any] | None = None,
    task_name: str = "auxiliary",
    **kwargs: Any,
) -> tuple[ProviderRuntime, LLMProvider]:
    """Build the provider used for auxiliary tasks.

    Args:
        main_runtime: When given, the auxiliary runtime is resolved against
            it via ``resolve_auxiliary_runtime``.
        target: Explicit auxiliary target.
        task_name: Passed as ``task_name`` when resolving against the main
            runtime, or as ``role`` when building a standalone runtime.
        **kwargs: Used as the target only when ``target`` is ``None``;
            otherwise they are not used.

    Returns:
        ``(runtime, provider)`` for the auxiliary task.

    Raises:
        ValueError: If ``main_runtime`` is ``None`` and the normalised
            target is missing or has no model.
    """
    # Loose keyword arguments stand in for the target only when no explicit
    # target was supplied.
    effective_target = kwargs if (target is None and kwargs) else target

    if main_runtime is not None:
        aux_runtime = resolve_auxiliary_runtime(
            main_runtime,
            effective_target,
            task_name=task_name,
        )
        return aux_runtime, make_provider_from_runtime(aux_runtime)

    # Standalone path: with no main runtime, the target must at least name a model.
    spec = normalize_provider_target(effective_target)
    if not (spec is not None and spec.model):
        raise ValueError("Auxiliary provider without main_runtime requires at least a model")

    aux_runtime = build_provider_runtime(
        model=spec.model,
        provider_name=spec.provider_name,
        api_key=spec.api_key,
        api_base=spec.api_base,
        request_timeout_seconds=spec.request_timeout_seconds,
        extra_headers=spec.extra_headers,
        routing=spec.routing,
        auxiliary=True,
        role=task_name,
        source="auxiliary_config",
    )
    return aux_runtime, make_provider_from_runtime(aux_runtime)
|
||
|
||
|
||
def make_embedding_runtime(
    main_runtime: ProviderRuntime,
    *,
    target: ProviderTarget | dict[str, Any] | None = None,
    default_model: str = "text-embedding-v4",
) -> ProviderRuntime | None:
    """Build the runtime dedicated to embeddings.

    Delegates to ``resolve_embedding_runtime`` with the same arguments and
    returns whatever it produces (possibly ``None``, per the annotation).
    """
    embedding_runtime = resolve_embedding_runtime(
        main_runtime,
        target=target,
        default_model=default_model,
    )
    return embedding_runtime
|
||
|
||
|
||
def make_provider_bundle(
    *,
    auxiliary_target: ProviderTarget | dict[str, Any] | None = None,
    auxiliary_task_name: str = "auxiliary",
    embedding_target: ProviderTarget | dict[str, Any] | None = None,
    embedding_model: str = "text-embedding-v4",
    **kwargs: Any,
) -> ProviderBundle:
    """Build the main, fallback and auxiliary provider chains in one call.

    Args:
        auxiliary_target: Target for the auxiliary provider; the auxiliary
            chain is only built when this is not ``None``.
        auxiliary_task_name: Task name forwarded to ``make_aux_provider``.
        embedding_target: Target forwarded to ``make_embedding_runtime``.
        embedding_model: Default embedding model forwarded to
            ``make_embedding_runtime``.
        **kwargs: Main runtime options. ``fallback_target`` is extracted;
            when it is absent or ``None`` and ``fallback_model`` is present,
            that value is used instead. The rest go to
            ``build_provider_runtime``.

    Returns:
        A ``ProviderBundle``. Its ``main_provider`` is a
        ``FallbackProviderChain`` when a fallback was resolved, otherwise
        the primary provider.
    """
    main_options = dict(kwargs)
    backup_target = main_options.pop("fallback_target", None)
    if backup_target is None and "fallback_model" in kwargs:
        backup_target = main_options.pop("fallback_model")

    main_runtime = build_provider_runtime(
        auxiliary=False,
        fallback_target=backup_target,
        role="main",
        source="main_config",
        **main_options,
    )
    primary_provider = make_provider_from_runtime(main_runtime)

    # Start from the "no fallback" shape and upgrade it if one resolves.
    main_provider: LLMProvider = primary_provider
    fallback_runtime = None
    fallback_provider = None
    backup = make_fallback_provider(main_runtime, backup_target)
    if backup is not None:
        fallback_runtime, fallback_provider = backup
        main_provider = FallbackProviderChain(
            main_runtime,
            primary_provider,
            fallback_runtime,
            fallback_provider,
        )

    # The auxiliary chain is optional and only built on explicit request.
    if auxiliary_target is None:
        auxiliary_runtime = None
        auxiliary_provider = None
    else:
        auxiliary_runtime, auxiliary_provider = make_aux_provider(
            main_runtime,
            target=auxiliary_target,
            task_name=auxiliary_task_name,
        )

    # The embedding runtime is always requested; the resolver may return None.
    embedding_runtime = make_embedding_runtime(
        main_runtime,
        target=embedding_target,
        default_model=embedding_model,
    )

    return ProviderBundle(
        main_runtime=main_runtime,
        main_provider=main_provider,
        fallback_runtime=fallback_runtime,
        fallback_provider=fallback_provider,
        auxiliary_runtime=auxiliary_runtime,
        auxiliary_provider=auxiliary_provider,
        embedding_runtime=embedding_runtime,
    )
|
||
|
||
|
||
# Public API of this module: the names exported by a star-import.
# ProviderRoutingConfig, ProviderRuntime and ProviderTarget are re-exported
# from `.runtime`; the remaining names are defined in this file.
__all__ = [
    "ProviderBundle",
    "ProviderRoutingConfig",
    "ProviderRuntime",
    "ProviderTarget",
    "build_provider_runtime",
    "make_aux_provider",
    "make_embedding_runtime",
    "make_fallback_provider",
    "make_main_provider",
    "make_provider_bundle",
    "make_provider_from_runtime",
]
|