"""Beaver tool executor.

This layer turns tool calls returned by a provider into real tool executions.
It does not care whether the provider is OpenAI, Anthropic or Codex; it only
cares about:

1. what the tool is called
2. what the arguments are
3. whether the registry can find it
4. how the execution result is normalized
"""

from __future__ import annotations

import hashlib
import json
from typing import TYPE_CHECKING, Any

from beaver.tools.base import ToolContext, ToolResult
from beaver.tools.registry.tool_registry import ToolRegistry

if TYPE_CHECKING:
    from beaver.engine.providers.base import ToolCallRequest


class ToolExecutor:
    """Execute a single tool call against a ``ToolRegistry``."""

    def __init__(self, registry: ToolRegistry) -> None:
        self.registry = registry

    async def execute(
        self,
        tool_name: str,
        arguments: dict[str, Any] | None,
        *,
        context: ToolContext | None = None,
    ) -> ToolResult:
        """Run one call of the tool registered under ``tool_name``.

        Tools whose lowercased name contains one of
        ``_EXTERNAL_WRITE_TOOL_TERMS`` are deduplicated. A repeat call with
        the same name and (normalized) arguments against the same ``context``
        does not invoke the tool again. Instead it returns a ``ToolResult``
        with ``success=True`` and
        ``error="duplicate_external_write_suppressed"``, carrying the earlier
        attempt's record in ``raw_output["previous"]``.

        Args:
            tool_name: Registry name of the tool to run.
            arguments: Tool arguments; ``None`` is treated as ``{}``. The
                tool receives a shallow copy.
            context: Execution context. Dedup bookkeeping is stored in
                ``context.metadata["external_write_attempts"]``, so the same
                context must be reused across calls for dedup to apply. When
                omitted, a fresh ``ToolContext()`` is created for this call
                only.

        Returns:
            The tool's own ``ToolResult``. A locally built result is returned
            instead when the tool is unknown (``error="tool_not_found"``) or
            the call was suppressed as a duplicate external write.
        """
        tool = self.registry.get(tool_name)
        if tool is None:
            return ToolResult(
                success=False,
                content=f"Tool {tool_name} is not registered.",
                tool_name=tool_name,
                error="tool_not_found",
            )

        normalized_arguments = dict(arguments or {})
        # Explicit None check: a caller-supplied context must never be
        # replaced by a fresh one, because it carries the dedup state.
        tool_context = context if context is not None else ToolContext()

        write_key = _external_write_key(tool_name, normalized_arguments)
        if write_key is None:
            # Not an external-write tool: no dedup bookkeeping.
            return await tool.invoke(normalized_arguments, tool_context)

        external_writes = _external_write_state(tool_context)
        previous = external_writes.get(write_key)
        if previous is not None:
            previous_content = str(previous.get("content") or "").strip()
            detail = f" Previous result: {previous_content}" if previous_content else ""
            return ToolResult(
                success=True,
                content=(
                    f"Duplicate external write suppressed for {tool_name}. "
                    "A matching write was already attempted in this run."
                    f"{detail}"
                ),
                tool_name=tool_name,
                error="duplicate_external_write_suppressed",
                raw_output={"duplicate": True, "previous": previous},
            )

        # Record the attempt *before* invoking. If ``invoke`` raises, the
        # entry stays "attempted", so a later identical call is still
        # suppressed (the write may have partially happened).
        external_writes[write_key] = {
            "tool_name": tool_name,
            "arguments": normalized_arguments,
            "status": "attempted",
            "content": "",
            "error": None,
        }
        result = await tool.invoke(normalized_arguments, tool_context)
        external_writes[write_key] = {
            "tool_name": tool_name,
            "arguments": normalized_arguments,
            "status": "done" if result.success else "error",
            "content": result.content,
            "error": result.error,
        }
        return result

    async def execute_tool_call(
        self,
        tool_call: ToolCallRequest | dict[str, Any],
        *,
        context: ToolContext | None = None,
    ) -> ToolResult:
        """Execute one structured tool call returned by a provider.

        Accepted input shapes:

        - a ``ToolCallRequest``-like object (``.name`` / ``.arguments``)
        - an OpenAI-style dict (``{"function": {"name", "arguments"}}``)
        - a flat dict (``{"name", "arguments"}``)

        A ``"__beaver_tool_argument_parse_error__"`` key in the arguments
        signals an upstream argument-parse failure. It is presumably set by
        the provider layer (TODO confirm). When present, the tool is not run
        and a ``tool_call_argument_parse_error`` result is returned, exposing
        ``arguments["__raw_arguments__"]`` as ``raw_output``.

        ``tool_call`` itself is never modified.
        """
        try:
            tool_name, arguments = self._normalize_tool_call(tool_call)
        except Exception as exc:
            return ToolResult(
                success=False,
                content=f"Tool call could not be parsed: {exc}",
                tool_name=self._extract_tool_name(tool_call),
                error="tool_call_parse_error",
            )

        # Safe to pop: ``_normalize_tool_call`` returns a copy, so the
        # provider's tool_call keeps its marker (a retry still reports the
        # parse error instead of running the tool with raw arguments).
        parse_error = arguments.pop("__beaver_tool_argument_parse_error__", None)
        if parse_error is not None:
            return ToolResult(
                success=False,
                content=f"Tool call arguments for {tool_name} could not be parsed: {parse_error}",
                tool_name=tool_name,
                error="tool_call_argument_parse_error",
                raw_output=arguments.get("__raw_arguments__"),
            )
        return await self.execute(tool_name, arguments, context=context)

    @staticmethod
    def _normalize_tool_call(tool_call: ToolCallRequest | dict[str, Any]) -> tuple[str, dict[str, Any]]:
        """Extract ``(name, arguments)`` from a provider tool call.

        String arguments are decoded as JSON. ``None``, a blank string or
        JSON ``null`` mean "no arguments" and become ``{}``, matching how
        ``execute`` treats ``arguments=None``.

        Returns:
            The tool name as ``str`` and a fresh shallow copy of the
            arguments dict. Mutating the copy never affects ``tool_call``.

        Raises:
            ValueError: the name is missing, the arguments string is not
                valid JSON, or the arguments are not a dict.
        """
        if isinstance(tool_call, dict):
            function = tool_call.get("function")
            source = function if isinstance(function, dict) else tool_call
            name = source.get("name")
            arguments = source.get("arguments", {})
        else:
            name = getattr(tool_call, "name", None)
            arguments = getattr(tool_call, "arguments", {})

        if not name:
            raise ValueError("Tool call is missing a tool name")
        if isinstance(arguments, str):
            if not arguments.strip():
                # Some providers send "" for tools that take no arguments.
                arguments = None
            else:
                try:
                    arguments = json.loads(arguments)
                except json.JSONDecodeError as exc:
                    raise ValueError(f"Tool call arguments for {name!r} are not valid JSON") from exc
        if arguments is None:
            arguments = {}
        if not isinstance(arguments, dict):
            raise ValueError(f"Tool call arguments for {name!r} must be a dict")
        return str(name), dict(arguments)

    @staticmethod
    def _extract_tool_name(tool_call: ToolCallRequest | dict[str, Any]) -> str:
        """Best-effort tool name for error reporting; ``"unknown"`` if absent."""
        if not isinstance(tool_call, dict):
            return str(getattr(tool_call, "name", None) or "unknown")
        function = tool_call.get("function")
        if isinstance(function, dict) and function.get("name"):
            return str(function["name"])
        if tool_call.get("name"):
            return str(tool_call["name"])
        return "unknown"


# Lowercase substrings identifying tools with external side effects
# (mail / calendar writes). Matched against the lowercased tool name, so
# prefixed variants such as "x_mail_send_email" also match.
_EXTERNAL_WRITE_TOOL_TERMS = (
    "mail_send_email",
    "mail_reply_to_message",
    "mail_forward_message",
    "mail_move_message",
    "calendar_create_event",
    "calendar_update_event",
)


def _external_write_state(context: ToolContext) -> dict[str, dict[str, Any]]:
    """Return the per-context dedup table, creating or resetting it if needed.

    The table is stored under ``context.metadata["external_write_attempts"]``.
    A non-dict value found there is replaced with a new empty dict.
    """
    state = context.metadata.setdefault("external_write_attempts", {})
    if not isinstance(state, dict):
        state = {}
        context.metadata["external_write_attempts"] = state
    return state


def _external_write_key(tool_name: str, arguments: dict[str, Any]) -> str | None:
    """Build the dedup key for an external-write call, or ``None`` otherwise.

    The key is ``"<lowercased tool name>:<sha256 of canonical JSON args>"``.
    Argument key order therefore does not affect the key.
    """
    lowered = tool_name.lower()
    if not any(term in lowered for term in _EXTERNAL_WRITE_TOOL_TERMS):
        return None
    payload = json.dumps(_normalize_for_key(arguments), ensure_ascii=False, sort_keys=True, separators=(",", ":"))
    digest = hashlib.sha256(payload.encode("utf-8")).hexdigest()
    return f"{lowered}:{digest}"


def _normalize_for_key(value: Any) -> Any:
    """Convert ``value`` into a JSON-serializable, order-stable structure.

    - Dicts get ``str`` keys in sorted order.
    - Lists and tuples become lists.
    - JSON scalars pass through unchanged.
    - Anything else is replaced by ``str(value)``.
    """
    if isinstance(value, dict):
        return {str(key): _normalize_for_key(value[key]) for key in sorted(value, key=str)}
    if isinstance(value, (list, tuple)):
        return [_normalize_for_key(item) for item in value]
    if isinstance(value, (str, int, float, bool)) or value is None:
        return value
    return str(value)