"""Provider chain helpers. 这里先实现最小可用的 fallback chain: - 每次调用都先尝试主 provider - 本次调用主 provider 返回 `finish_reason=error` 时,再切到 fallback - fallback 只影响当前这一次调用,不会污染下一次 run 的首选链路 这样后面 `AgentLoop` 不需要自己处理“主模型挂了再换一个 provider”。 """ from __future__ import annotations from .base import LLMProvider, LLMResponse from .runtime import ProviderRuntime class FallbackProviderChain(LLMProvider): """把 primary/fallback provider 封装成一个统一的 LLMProvider。""" def __init__( self, primary_runtime: ProviderRuntime, primary_provider: LLMProvider, fallback_runtime: ProviderRuntime | None = None, fallback_provider: LLMProvider | None = None, ) -> None: super().__init__( api_key=primary_runtime.api_key, api_base=primary_runtime.api_base, request_timeout_seconds=primary_runtime.request_timeout_seconds, ) self.primary_runtime = primary_runtime self.primary_provider = primary_provider self.fallback_runtime = fallback_runtime self.fallback_provider = fallback_provider # 这里只记录“最近一次 chat 实际用了哪条链”,用于调试和测试。 # 真正的选路决策必须按调用粒度重新从 primary 开始,不能跨调用粘住 fallback。 self._last_runtime = primary_runtime self._last_provider = primary_provider self._last_call_used_fallback = False @property def fallback_activated(self) -> bool: """最近一次 chat 是否实际用到了 fallback。""" return self._last_call_used_fallback @property def active_runtime(self) -> ProviderRuntime: """最近一次 chat 实际使用的 runtime。""" return self._last_runtime async def chat( self, messages: list[dict], tools: list[dict] | None = None, model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, thinking_enabled: bool | None = None, ) -> LLMResponse: self._last_provider = self.primary_provider self._last_runtime = self.primary_runtime self._last_call_used_fallback = False response = await self._safe_chat( self.primary_provider, self.primary_runtime, messages=messages, tools=tools, model=model or self.primary_runtime.model, max_tokens=max_tokens, temperature=temperature, thinking_enabled=thinking_enabled, ) response = self._decorate_response(response, self.primary_runtime) if not self._should_activate_fallback(response): return response assert self.fallback_provider is not None assert self.fallback_runtime is not None self._last_provider = self.fallback_provider self._last_runtime = self.fallback_runtime self._last_call_used_fallback = True response = await self._safe_chat( self.fallback_provider, self.fallback_runtime, messages=messages, tools=tools, model=self.fallback_runtime.model, max_tokens=max_tokens, temperature=temperature, thinking_enabled=thinking_enabled, ) return self._decorate_response(response, self.fallback_runtime) def get_default_model(self) -> str: return self.primary_runtime.model def _should_activate_fallback(self, response: LLMResponse) -> bool: return ( self.fallback_provider is not None and self.fallback_runtime is not None and response.finish_reason == "error" ) @staticmethod async def _safe_chat( provider: LLMProvider, runtime: ProviderRuntime, *, messages: list[dict], tools: list[dict] | None, model: str, max_tokens: int, temperature: float, thinking_enabled: bool | None, ) -> LLMResponse: """把 provider 抛出的异常也收敛成统一 error response。 这样 fallback 的触发条件就不依赖“每个 provider 都记得自己 catch 异常”。 """ try: kwargs = { "messages": messages, "tools": tools, "model": model, "max_tokens": max_tokens, "temperature": temperature, } if thinking_enabled is not None: kwargs["thinking_enabled"] = thinking_enabled return await provider.chat(**kwargs) except Exception as exc: return LLMResponse( content=f"Error: {exc}", finish_reason="error", provider_name=runtime.provider_name, model=runtime.model, ) @staticmethod def _decorate_response(response: LLMResponse, runtime: ProviderRuntime) -> LLMResponse: if response.provider_name is None: response.provider_name = runtime.provider_name if response.model is None: response.model = runtime.model return response