115 lines
3.9 KiB
Python
115 lines
3.9 KiB
Python
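"""Beaver tool executor.

This layer is dedicated to turning the tool calls returned by a provider into
actual tool executions.

It does not care whether the provider is OpenAI, Anthropic, or Codex; it only
cares about:

1. what the tool is called
2. what its arguments are
3. whether the registry can find it
4. how the execution result is normalized
"""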
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
from typing import Any
|
||
|
||
from beaver.engine.providers.base import ToolCallRequest
|
||
from beaver.tools.base import ToolContext, ToolResult
|
||
from beaver.tools.registry.tool_registry import ToolRegistry
|
||
|
||
|
||
class ToolExecutor:
    """Execute a single tool call against a ``ToolRegistry``.

    Provider-agnostic: it only needs a tool name and its arguments, looks the
    tool up in the registry, and returns a ``ToolResult``.
    """

    # Sentinel keys that may be present in a tool call's arguments when
    # argument parsing already failed upstream (presumably injected by the
    # provider layer -- TODO confirm the producer).
    _PARSE_ERROR_KEY = "__beaver_tool_argument_parse_error__"
    _RAW_ARGUMENTS_KEY = "__raw_arguments__"

    def __init__(self, registry: ToolRegistry) -> None:
        self.registry = registry

    async def execute(
        self,
        tool_name: str,
        arguments: dict[str, Any] | None,
        *,
        context: ToolContext | None = None,
    ) -> ToolResult:
        """Look up ``tool_name`` in the registry and invoke it once.

        Args:
            tool_name: Name the tool is registered under.
            arguments: Arguments passed to the tool; ``None`` (or any falsy
                value) is replaced by ``{}``.
            context: Execution context; a fresh ``ToolContext()`` is used when
                omitted.

        Returns:
            Whatever ``tool.invoke`` returns, or a failed ``ToolResult`` with
            ``error="tool_not_found"`` when the name is not registered.
        """
        tool = self.registry.get(tool_name)
        if tool is None:
            return ToolResult(
                success=False,
                content=f"Tool {tool_name} is not registered.",
                tool_name=tool_name,
                error="tool_not_found",
            )
        return await tool.invoke(arguments or {}, context or ToolContext())

    async def execute_tool_call(
        self,
        tool_call: ToolCallRequest | dict[str, Any],
        *,
        context: ToolContext | None = None,
    ) -> ToolResult:
        """Execute one structured tool call returned by a provider.

        Accepts either a ``ToolCallRequest`` or an OpenAI-style mapping, nested
        (``{"function": {"name": ..., "arguments": ...}}``) or flat
        (``{"name": ..., "arguments": ...}``). Malformed calls are reported as
        failed ``ToolResult`` objects rather than raised.

        Returns:
            A failed result with ``error="tool_call_parse_error"`` when the call
            itself cannot be normalized, a failed result with
            ``error="tool_call_argument_parse_error"`` when the arguments carry
            an upstream parse-error sentinel, otherwise the result of
            ``execute``.
        """
        try:
            tool_name, arguments = self._normalize_tool_call(tool_call)
        except Exception as exc:  # boundary: convert any parse failure into a ToolResult
            return ToolResult(
                success=False,
                content=f"Tool call could not be parsed: {exc}",
                tool_name=self._extract_tool_name(tool_call),
                error="tool_call_parse_error",
            )

        # `arguments` is a private copy made by _normalize_tool_call, so popping
        # the sentinel does not mutate the caller's tool call.
        parse_error = arguments.pop(self._PARSE_ERROR_KEY, None)
        if parse_error is not None:
            return ToolResult(
                success=False,
                content=f"Tool call arguments for {tool_name} could not be parsed: {parse_error}",
                tool_name=tool_name,
                error="tool_call_argument_parse_error",
                raw_output=arguments.get(self._RAW_ARGUMENTS_KEY),
            )
        return await self.execute(tool_name, arguments, context=context)

    @staticmethod
    def _normalize_tool_call(tool_call: ToolCallRequest | dict[str, Any]) -> tuple[str, dict[str, Any]]:
        """Return ``(tool_name, arguments)`` for a ``ToolCallRequest`` or mapping.

        The returned arguments dict is always a new copy, so callers may mutate
        it without touching ``tool_call``. For mapping input, string arguments
        are decoded as JSON, and ``None`` or a blank string means "no
        arguments".

        Raises:
            TypeError: ``tool_call`` is neither a mapping nor a
                ``ToolCallRequest`` (errors from copying
                ``ToolCallRequest.arguments`` also propagate).
            ValueError: the tool name is missing, or the arguments are not
                valid JSON / do not decode to a dict.
        """
        from collections.abc import Mapping

        if isinstance(tool_call, Mapping):
            function = tool_call.get("function")
            if isinstance(function, dict):
                name = function.get("name")
                arguments = function.get("arguments", {})
            else:
                name = tool_call.get("name")
                arguments = tool_call.get("arguments", {})
        elif isinstance(tool_call, ToolCallRequest):
            return tool_call.name, dict(tool_call.arguments)
        else:
            raise TypeError(
                f"Unsupported tool call type: {type(tool_call).__name__}"
            )

        if not name:
            raise ValueError("Tool call is missing a tool name")
        if arguments is None or (isinstance(arguments, str) and not arguments.strip()):
            # A tool without parameters may arrive with null or "" arguments;
            # treat both as an empty argument object instead of a parse error.
            arguments = {}
        elif isinstance(arguments, str):
            try:
                arguments = json.loads(arguments)
            except json.JSONDecodeError as exc:
                raise ValueError(f"Tool call arguments for {name!r} are not valid JSON") from exc
        if not isinstance(arguments, dict):
            raise ValueError(f"Tool call arguments for {name!r} must be a dict")
        # Copy so the caller's payload is never mutated downstream.
        return str(name), dict(arguments)

    @staticmethod
    def _extract_tool_name(tool_call: ToolCallRequest | dict[str, Any]) -> str:
        """Best-effort tool name for error reporting.

        Used when normalization has already failed, so it tolerates malformed
        input: objects that are not mappings and have no truthy ``name``
        attribute (e.g. ``None``) yield ``"unknown"``.
        """
        from collections.abc import Mapping

        if isinstance(tool_call, Mapping):
            function = tool_call.get("function")
            if isinstance(function, dict) and function.get("name"):
                return str(function["name"])
            if tool_call.get("name"):
                return str(tool_call["name"])
            return "unknown"
        # ToolCallRequest (or any object exposing `.name`); anything else falls back.
        name = getattr(tool_call, "name", None)
        return str(name or "unknown")
|