"""Direct OpenAI-compatible provider — bypasses LiteLLM.""" from __future__ import annotations from typing import Any from .base import LLMProvider, LLMResponse, ToolCallRequest try: # pragma: no cover - optional dependency import json_repair except ModuleNotFoundError: # pragma: no cover json_repair = None # type: ignore[assignment] try: # pragma: no cover - optional dependency from openai import AsyncOpenAI except ModuleNotFoundError: # pragma: no cover AsyncOpenAI = None # type: ignore[assignment] class CustomProvider(LLMProvider): """直接连接任意 OpenAI-compatible endpoint。""" def __init__( self, api_key: str = "no-key", api_base: str = "http://localhost:8000/v1", default_model: str = "default", request_timeout_seconds: float | None = None, ) -> None: super().__init__(api_key, api_base, request_timeout_seconds=request_timeout_seconds) self.default_model = default_model self._client = None def _client_or_raise(self): if AsyncOpenAI is None: raise RuntimeError("openai package is not installed") if self._client is None: self._client = AsyncOpenAI( api_key=self.api_key, base_url=self.api_base, timeout=self.request_timeout_seconds, ) return self._client async def chat( self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, thinking_enabled: bool | None = None, ) -> LLMResponse: client = self._client_or_raise() kwargs: dict[str, Any] = { "model": model or self.default_model, "messages": self.sanitize_empty_content(messages), "max_tokens": max(1, max_tokens), "temperature": temperature, } if tools: kwargs.update(tools=tools, tool_choice="auto") try: response = await client.chat.completions.create(**kwargs) except Exception as exc: return LLMResponse(content=f"Error: {exc}", finish_reason="error", provider_name="custom") choice = response.choices[0] message = choice.message parsed_tool_calls: list[ToolCallRequest] = [] for tool_call in message.tool_calls or []: raw_arguments = tool_call.function.arguments if isinstance(raw_arguments, str): if json_repair is not None: arguments = json_repair.loads(raw_arguments) else: import json arguments = json.loads(raw_arguments) else: arguments = raw_arguments parsed_tool_calls.append( ToolCallRequest( id=tool_call.id, name=tool_call.function.name, arguments=arguments, ) ) usage = getattr(response, "usage", None) usage_payload = {} if usage is not None: usage_payload = { "prompt_tokens": getattr(usage, "prompt_tokens", 0), "completion_tokens": getattr(usage, "completion_tokens", 0), "total_tokens": getattr(usage, "total_tokens", 0), } return LLMResponse( content=message.content, tool_calls=parsed_tool_calls, finish_reason=choice.finish_reason or "stop", usage=usage_payload, reasoning_content=getattr(message, "reasoning_content", None), provider_name="custom", model=model or self.default_model, ) def get_default_model(self) -> str: return self.default_model