174 lines
6.3 KiB
Python
174 lines
6.3 KiB
Python
"""Native Anthropic Messages API provider."""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
from typing import Any
|
||
|
||
from .base import LLMProvider, LLMResponse, ToolCallRequest
|
||
|
||
try: # pragma: no cover - optional dependency
|
||
import anthropic
|
||
except ModuleNotFoundError: # pragma: no cover
|
||
anthropic = None # type: ignore[assignment]
|
||
|
||
|
||
class AnthropicProvider(LLMProvider):
    """Provider that uses Anthropic's native Messages API.

    Requests go through the Anthropic SDK directly instead of being forced
    through an OpenAI-compatible path. OpenAI-style messages and tool
    definitions are translated by ``_convert_messages`` / ``_convert_tools``
    before each call, and the SDK response is mapped back onto ``LLMResponse``.
    """

    def __init__(
        self,
        api_key: str | None = None,
        default_model: str = "claude-sonnet-4-5",
        api_base: str | None = None,
        request_timeout_seconds: float | None = None,
    ) -> None:
        super().__init__(api_key, api_base, request_timeout_seconds=request_timeout_seconds)
        self.default_model = default_model
        # The SDK client is built lazily (see ``_client_or_raise``) so that
        # constructing the provider works even when ``anthropic`` is absent.
        self._client = None

    def _client_or_raise(self):
        """Return the cached ``AsyncAnthropic`` client, creating it on first use.

        Raises:
            RuntimeError: if the optional ``anthropic`` package is not installed.
        """
        if anthropic is None:
            raise RuntimeError("anthropic package is not installed")
        if self._client is None:
            client_kwargs: dict[str, Any] = {
                "api_key": self.api_key,
                "base_url": self.api_base,
            }
            # An explicit ``timeout=None`` tells the SDK to disable its request
            # timeout entirely; only forward a configured value so the SDK's
            # own default timeout applies otherwise.
            if self.request_timeout_seconds is not None:
                client_kwargs["timeout"] = self.request_timeout_seconds
            self._client = anthropic.AsyncAnthropic(**client_kwargs)
        return self._client

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """Send one Messages API request and normalize the reply.

        Args:
            messages: Chat messages using ``system`` / ``user`` / ``assistant`` /
                ``tool`` roles; converted with ``_convert_messages``.
            tools: Optional tool definitions; converted with ``_convert_tools``.
            model: Model name; falls back to ``default_model``.
            max_tokens: Output token cap, clamped to at least 1.
            temperature: Sampling temperature, forwarded unchanged.

        Returns:
            An ``LLMResponse``. Failures (SDK missing, request errors) are not
            raised: they are returned with ``finish_reason="error"`` and the
            error text in ``content``.
        """
        try:
            client = self._client_or_raise()
        except Exception as exc:
            return LLMResponse(content=f"Error: {exc}", finish_reason="error", provider_name="anthropic")

        resolved_model = model or self.default_model
        system_prompt, anthropic_messages = _convert_messages(messages)
        kwargs: dict[str, Any] = {
            "model": resolved_model,
            "messages": anthropic_messages,
            "max_tokens": max(1, max_tokens),
            "temperature": temperature,
        }
        # ``system`` is optional in the Messages API; omit it rather than send "".
        if system_prompt:
            kwargs["system"] = system_prompt
        # Skip ``tools`` when conversion dropped every definition (e.g. no names).
        converted_tools = _convert_tools(tools) if tools else []
        if converted_tools:
            kwargs["tools"] = converted_tools

        try:
            response = await client.messages.create(**kwargs)
        except Exception as exc:
            return LLMResponse(content=f"Error: {exc}", finish_reason="error", provider_name="anthropic")

        content_parts: list[str] = []
        tool_calls: list[ToolCallRequest] = []
        for block in response.content:
            if block.type == "text":
                content_parts.append(block.text)
            elif block.type == "tool_use":
                tool_calls.append(
                    ToolCallRequest(
                        id=block.id,
                        name=block.name,
                        arguments=block.input,
                    )
                )
            # Other block types are ignored.

        usage_payload = {}
        if getattr(response, "usage", None):
            usage_payload = {
                "input_tokens": getattr(response.usage, "input_tokens", 0),
                "output_tokens": getattr(response.usage, "output_tokens", 0),
            }

        return LLMResponse(
            content="".join(content_parts) or None,
            tool_calls=tool_calls,
            # Anthropic's stop_reason is passed through as-is (not mapped to
            # OpenAI names); "stop" is used only when it is missing/empty.
            finish_reason=getattr(response, "stop_reason", "stop") or "stop",
            usage=usage_payload,
            provider_name="anthropic",
            model=resolved_model,
        )

    def get_default_model(self) -> str:
        """Return the model used when ``chat`` is called without ``model``."""
        return self.default_model
|
||
|
||
|
||
def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
|
||
system_prompt = ""
|
||
converted: list[dict[str, Any]] = []
|
||
for message in messages:
|
||
role = message.get("role")
|
||
if role == "system":
|
||
content = message.get("content")
|
||
system_prompt = content if isinstance(content, str) else ""
|
||
continue
|
||
if role == "tool":
|
||
converted.append(
|
||
{
|
||
"role": "user",
|
||
"content": [
|
||
{
|
||
"type": "tool_result",
|
||
"tool_use_id": message.get("tool_call_id"),
|
||
"content": message.get("content") or "",
|
||
}
|
||
],
|
||
}
|
||
)
|
||
continue
|
||
if role == "assistant" and message.get("tool_calls"):
|
||
content_blocks: list[dict[str, Any]] = []
|
||
if message.get("content"):
|
||
content_blocks.append({"type": "text", "text": message["content"]})
|
||
for tool_call in message.get("tool_calls", []):
|
||
function = tool_call.get("function", tool_call)
|
||
arguments = function.get("arguments")
|
||
if isinstance(arguments, str):
|
||
try:
|
||
arguments = json.loads(arguments)
|
||
except json.JSONDecodeError:
|
||
arguments = {}
|
||
content_blocks.append(
|
||
{
|
||
"type": "tool_use",
|
||
"id": tool_call.get("id"),
|
||
"name": function.get("name"),
|
||
"input": arguments or {},
|
||
}
|
||
)
|
||
converted.append({"role": "assistant", "content": content_blocks})
|
||
continue
|
||
|
||
content = message.get("content")
|
||
if isinstance(content, list):
|
||
blocks = []
|
||
for item in content:
|
||
if isinstance(item, dict) and item.get("type") == "text":
|
||
blocks.append({"type": "text", "text": item.get("text", "")})
|
||
converted.append({"role": role, "content": blocks or [{"type": "text", "text": ""}]})
|
||
else:
|
||
converted.append({"role": role, "content": content or ""})
|
||
return system_prompt, converted
|
||
|
||
|
||
def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||
converted: list[dict[str, Any]] = []
|
||
for tool in tools:
|
||
fn = (tool.get("function") or {}) if tool.get("type") == "function" else tool
|
||
if not fn.get("name"):
|
||
continue
|
||
converted.append(
|
||
{
|
||
"name": fn["name"],
|
||
"description": fn.get("description") or "",
|
||
"input_schema": fn.get("parameters") or {"type": "object", "properties": {}},
|
||
}
|
||
)
|
||
return converted
|