Files
beaver_project/app-instance/backend/beaver/engine/providers/registry.py

250 lines
7.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Provider registry: 统一维护 provider 元数据与匹配规则。"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
@dataclass(frozen=True, slots=True)
class ProviderSpec:
"""单个 provider 的元数据定义。"""
name: str
keywords: tuple[str, ...]
env_key: str
display_name: str = ""
litellm_prefix: str = ""
skip_prefixes: tuple[str, ...] = ()
env_extras: tuple[tuple[str, str], ...] = ()
is_gateway: bool = False
is_local: bool = False
detect_by_key_prefix: str = ""
detect_by_base_keyword: str = ""
default_api_base: str = ""
strip_model_prefix: bool = False
model_overrides: tuple[tuple[str, dict[str, Any]], ...] = ()
is_oauth: bool = False
is_direct: bool = False
supports_prompt_caching: bool = False
api_mode: str = "chat_completions"
provider_impl: str = "litellm"
@property
def label(self) -> str:
return self.display_name or self.name.title()
# Registry of every known provider.
#
# Order is significant: find_by_model() and find_gateway() scan this tuple
# front to back and return the first spec that matches. For example,
# "gpt-5-codex" hits openai's "gpt" keyword before openai_codex's "codex"
# keyword because openai is listed first.
PROVIDERS: tuple[ProviderSpec, ...] = (
    # User-supplied endpoint. With no keywords it is never matched by keyword;
    # it is reachable by name or by an explicit "custom/" model prefix.
    ProviderSpec(
        name="custom",
        keywords=(),
        env_key="",
        display_name="Custom",
        is_direct=True,
        provider_impl="custom",
        api_mode="chat_completions",
    ),
    # --- Gateways: excluded from keyword matching in find_by_model(); detected
    # in find_gateway() by API-key prefix or api_base keyword.
    ProviderSpec(
        name="openrouter",
        keywords=("openrouter",),
        env_key="OPENROUTER_API_KEY",
        display_name="OpenRouter",
        litellm_prefix="openrouter",
        is_gateway=True,
        detect_by_key_prefix="sk-or-",
        detect_by_base_keyword="openrouter",
        default_api_base="https://openrouter.ai/api/v1",
        supports_prompt_caching=True,
    ),
    # aihubmix, siliconflow and volcengine read OPENAI_API_KEY; aihubmix and
    # siliconflow also share LiteLLM's "openai" prefix.
    ProviderSpec(
        name="aihubmix",
        keywords=("aihubmix",),
        env_key="OPENAI_API_KEY",
        display_name="AiHubMix",
        litellm_prefix="openai",
        is_gateway=True,
        detect_by_base_keyword="aihubmix",
        default_api_base="https://aihubmix.com/v1",
        strip_model_prefix=True,
    ),
    ProviderSpec(
        name="siliconflow",
        keywords=("siliconflow",),
        env_key="OPENAI_API_KEY",
        display_name="SiliconFlow",
        litellm_prefix="openai",
        is_gateway=True,
        detect_by_base_keyword="siliconflow",
        default_api_base="https://api.siliconflow.cn/v1",
    ),
    ProviderSpec(
        name="volcengine",
        keywords=("volcengine", "volces", "ark"),
        env_key="OPENAI_API_KEY",
        display_name="VolcEngine",
        litellm_prefix="volcengine",
        is_gateway=True,
        detect_by_base_keyword="volces",
        default_api_base="https://ark.cn-beijing.volces.com/api/v3",
    ),
    # --- Standard providers: matched by name, model prefix, or keyword.
    ProviderSpec(
        name="anthropic",
        keywords=("anthropic", "claude"),
        env_key="ANTHROPIC_API_KEY",
        display_name="Anthropic",
        supports_prompt_caching=True,
        api_mode="anthropic_messages",
        provider_impl="anthropic",
    ),
    ProviderSpec(
        name="openai",
        keywords=("openai", "gpt"),
        env_key="OPENAI_API_KEY",
        display_name="OpenAI",
    ),
    # OAuth-based (no API-key env var). Its "codex" base keyword is checked by
    # find_gateway() even though this is not a gateway.
    ProviderSpec(
        name="openai_codex",
        keywords=("openai-codex", "codex"),
        env_key="",
        display_name="OpenAI Codex",
        is_oauth=True,
        detect_by_base_keyword="codex",
        default_api_base="https://chatgpt.com/backend-api",
        api_mode="codex_responses",
        provider_impl="codex",
    ),
    ProviderSpec(
        name="github_copilot",
        keywords=("github_copilot", "copilot"),
        env_key="",
        display_name="GitHub Copilot",
        litellm_prefix="github_copilot",
        skip_prefixes=("github_copilot/",),
        is_oauth=True,
    ),
    ProviderSpec(
        name="deepseek",
        keywords=("deepseek",),
        env_key="DEEPSEEK_API_KEY",
        display_name="DeepSeek",
        litellm_prefix="deepseek",
        skip_prefixes=("deepseek/",),
    ),
    ProviderSpec(
        name="gemini",
        keywords=("gemini",),
        env_key="GEMINI_API_KEY",
        display_name="Gemini",
        litellm_prefix="gemini",
        skip_prefixes=("gemini/",),
    ),
    # NOTE(review): "{api_key}" is presumably substituted with the configured
    # key by the consumer of env_extras (not shown in this module).
    ProviderSpec(
        name="zhipu",
        keywords=("zhipu", "glm", "zai"),
        env_key="ZAI_API_KEY",
        display_name="Zhipu AI",
        litellm_prefix="zai",
        skip_prefixes=("zhipu/", "zai/", "openrouter/", "hosted_vllm/"),
        env_extras=(("ZHIPUAI_API_KEY", "{api_key}"),),
    ),
    ProviderSpec(
        name="dashscope",
        keywords=("qwen", "dashscope"),
        env_key="DASHSCOPE_API_KEY",
        display_name="DashScope",
        litellm_prefix="dashscope",
        skip_prefixes=("dashscope/", "openrouter/"),
    ),
    # NOTE(review): model_overrides presumably supplies per-model request
    # parameters (temperature=1.0 for kimi-k2.5); the consumer is not shown here.
    ProviderSpec(
        name="moonshot",
        keywords=("moonshot", "kimi"),
        env_key="MOONSHOT_API_KEY",
        display_name="Moonshot",
        litellm_prefix="moonshot",
        skip_prefixes=("moonshot/", "openrouter/"),
        env_extras=(("MOONSHOT_API_BASE", "{api_base}"),),
        default_api_base="https://api.moonshot.ai/v1",
        model_overrides=(("kimi-k2.5", {"temperature": 1.0}),),
    ),
    ProviderSpec(
        name="minimax",
        keywords=("minimax",),
        env_key="MINIMAX_API_KEY",
        display_name="MiniMax",
        litellm_prefix="minimax",
        skip_prefixes=("minimax/", "openrouter/"),
        default_api_base="https://api.minimax.io/v1",
    ),
    # --- Local: excluded from keyword matching; reachable via a "vllm/" or
    # "hosted_vllm/" model prefix, or via provider_name in find_gateway().
    ProviderSpec(
        name="vllm",
        keywords=("vllm",),
        env_key="HOSTED_VLLM_API_KEY",
        display_name="vLLM/Local",
        litellm_prefix="hosted_vllm",
        is_local=True,
    ),
    # Standard provider (listed after the local entry).
    ProviderSpec(
        name="groq",
        keywords=("groq",),
        env_key="GROQ_API_KEY",
        display_name="Groq",
        litellm_prefix="groq",
        skip_prefixes=("groq/",),
    ),
)
def find_by_name(name: str) -> ProviderSpec | None:
    """Look up a registered provider by its exact registry name.

    Args:
        name: Registry key such as ``"openrouter"``; compared case-sensitively.

    Returns:
        The first spec in ``PROVIDERS`` whose ``name`` equals *name*, or None.
    """
    return next((candidate for candidate in PROVIDERS if candidate.name == name), None)
def find_by_model(model: str) -> ProviderSpec | None:
    """Match a model identifier to a provider spec.

    Resolution order (within each step the first spec in ``PROVIDERS`` wins):

    1. Explicit prefix equal to a spec ``name``. The text before the first
       "/" is lowercased, with "-" read as "_". Every spec is eligible,
       including gateways and local providers, so ``openrouter/...`` resolves
       to openrouter.
    2. The same prefix equal to a spec's ``litellm_prefix``, for example
       ``hosted_vllm/...`` -> vllm or ``zai/...`` -> zhipu. This step runs
       only after step 1 because several gateways (aihubmix, siliconflow)
       reuse the "openai" prefix. Checking aliases in the same pass as names
       made ``openai/gpt-4o`` resolve to the aihubmix gateway instead of
       the openai provider.
    3. Keyword substring match against standard specs only, i.e. specs that
       are neither gateways nor local.

    Args:
        model: Model identifier, optionally provider-prefixed
            (e.g. ``"deepseek/deepseek-chat"``).

    Returns:
        The matching spec, or None when nothing matches.
    """
    model_lower = model.lower()
    model_normalized = model_lower.replace("-", "_")
    # Only text before a "/" counts as an explicit provider prefix.
    prefix = model_normalized.split("/", 1)[0] if "/" in model_normalized else ""

    # Explicit prefixes take the highest priority and are checked against all
    # specs, not just standard ones:
    # - `openrouter/...` must resolve directly to openrouter
    # - `hosted_vllm/...` must resolve back to the local vllm provider
    # - `github_copilot/...codex` must not be misclassified as openai_codex
    if prefix:
        for spec in PROVIDERS:
            if spec.name == prefix:
                return spec
        for spec in PROVIDERS:
            if spec.litellm_prefix and spec.litellm_prefix.replace("-", "_") == prefix:
                return spec

    # Keyword matching deliberately skips gateways and local providers; they
    # are identified by explicit prefix above or by find_gateway().
    standard_specs = [spec for spec in PROVIDERS if not spec.is_gateway and not spec.is_local]
    for spec in standard_specs:
        if any(
            keyword in model_lower or keyword.replace("-", "_") in model_normalized
            for keyword in spec.keywords
        ):
            return spec
    return None
def find_gateway(
    provider_name: str | None = None,
    api_key: str | None = None,
    api_base: str | None = None,
) -> ProviderSpec | None:
    """Identify a gateway or local provider from config key, API key, or API base.

    Resolution order:

    1. *provider_name* is looked up with ``find_by_name``. The spec is
       returned only if it is a gateway or local provider; otherwise
       detection continues with the other arguments.
    2. ``PROVIDERS`` is scanned in order. The first spec is returned whose
       ``detect_by_key_prefix`` starts *api_key*, or whose
       ``detect_by_base_keyword`` occurs in *api_base*. The api_base check is
       case-insensitive, because URL schemes and host names are
       case-insensitive (RFC 3986).

    Note: step 2 scans every spec, not only gateway/local ones. A standard
    spec with a detect keyword (e.g. openai_codex's "codex") can therefore
    also be returned.

    Args:
        provider_name: Provider key from configuration, if any.
        api_key: Configured API key, if any.
        api_base: Configured API base URL, if any.

    Returns:
        The detected spec, or None.
    """
    if provider_name:
        spec = find_by_name(provider_name)
        if spec and (spec.is_gateway or spec.is_local):
            return spec

    # Lowercased once outside the loop; None/"" stays falsy, so no base matching.
    base_lower = api_base.lower() if api_base else ""
    for spec in PROVIDERS:
        key_prefix = spec.detect_by_key_prefix
        if key_prefix and api_key and api_key.startswith(key_prefix):
            return spec
        base_keyword = spec.detect_by_base_keyword
        if base_keyword and base_lower and base_keyword.lower() in base_lower:
            return spec
    return None