Files
beaver_project/app-instance/backend/beaver/tools/runtime/executor.py

204 lines
7.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Beaver 工具执行器。
这层专门负责把 provider 返回的 tool call 转成真正的工具执行。
它不关心 provider 是 OpenAI、Anthropic 还是 Codex，只关心：
1. 工具叫什么
2. 参数是什么
3. registry 能不能找到它
4. 执行结果怎么标准化
"""
from __future__ import annotations
import hashlib
import json
from typing import TYPE_CHECKING, Any
from beaver.tools.base import ToolContext, ToolResult
from beaver.tools.registry.tool_registry import ToolRegistry
if TYPE_CHECKING:
from beaver.engine.providers.base import ToolCallRequest
class ToolExecutor:
    """Execute individual tool calls against a ``ToolRegistry``.

    Turns the structured tool calls returned by a provider into real tool
    invocations. It does not care which provider produced the call; it only
    deals with the tool name, its arguments, whether the registry can resolve
    the tool, and how the outcome is reported as a ``ToolResult``.
    """

    # Container types accepted as an allow-list in
    # ``context.metadata["allowed_tool_names"]``. Accepting tuples and sets too
    # means an allow-list supplied in one of those forms is enforced instead of
    # being silently ignored. Any other value (or a missing key) disables the check.
    _ALLOWED_NAMES_TYPES: tuple[type, ...] = (list, tuple, set, frozenset)

    def __init__(self, registry: ToolRegistry) -> None:
        """Create an executor that resolves tools through ``registry.get(name)``."""
        self.registry = registry

    async def execute(
        self,
        tool_name: str,
        arguments: dict[str, Any] | None,
        *,
        context: ToolContext | None = None,
    ) -> ToolResult:
        """Execute one call of ``tool_name`` with ``arguments``.

        Steps:
        1. If ``context.metadata["allowed_tool_names"]`` is a list/tuple/set/
           frozenset and does not contain ``tool_name``, return a failed result
           with ``error="tool_not_allowed"``.
        2. Resolve the tool via ``self.registry.get``; if it returns ``None``,
           return a failed result with ``error="tool_not_found"``.
        3. If ``_external_write_key`` classifies the call as an external write,
           suppress an identical call already recorded in the context's
           ``external_write_attempts`` metadata and return a result with
           ``error="duplicate_external_write_suppressed"``.
        4. Otherwise return whatever ``tool.invoke`` returns.

        Args:
            tool_name: Registry name of the tool to run.
            arguments: Tool arguments; ``None`` is treated as ``{}``.
            context: Execution context. When omitted, a fresh ``ToolContext()``
                is used, so no allow-list applies and no write history is shared.

        Returns:
            The tool's ``ToolResult``, or a synthesized one for the cases above.
            Exceptions raised by ``tool.invoke`` propagate unchanged.
        """
        allowed = context.metadata.get("allowed_tool_names") if context is not None else None
        if isinstance(allowed, self._ALLOWED_NAMES_TYPES) and tool_name not in allowed:
            return ToolResult(
                success=False,
                content=f"Tool {tool_name} is not allowed for this node.",
                tool_name=tool_name,
                error="tool_not_allowed",
            )

        tool = self.registry.get(tool_name)
        if tool is None:
            return ToolResult(
                success=False,
                content=f"Tool {tool_name} is not registered.",
                tool_name=tool_name,
                error="tool_not_found",
            )

        # Shallow copy so the dict handed to the tool is not the caller's object.
        normalized_arguments = dict(arguments or {})
        tool_context = context or ToolContext()

        write_key = _external_write_key(tool_name, normalized_arguments)
        if write_key is None:
            # Not an external write: no de-duplication bookkeeping needed.
            return await tool.invoke(normalized_arguments, tool_context)

        external_writes = _external_write_state(tool_context)
        previous = external_writes.get(write_key)
        if previous is not None:
            # NOTE(review): the previous entry may have status "error" or
            # "attempted" (invoke raised), yet this still reports success=True.
            # Confirm that callers/agents should treat a suppressed retry of a
            # failed write as successful.
            previous_content = str(previous.get("content") or "").strip()
            detail = f" Previous result: {previous_content}" if previous_content else ""
            return ToolResult(
                success=True,
                content=(
                    f"Duplicate external write suppressed for {tool_name}. "
                    "A matching write was already attempted in this run."
                    f"{detail}"
                ),
                tool_name=tool_name,
                error="duplicate_external_write_suppressed",
                raw_output={"duplicate": True, "previous": previous},
            )

        # Record the attempt *before* invoking: if ``invoke`` raises, the entry
        # stays "attempted", so an identical retry with this context is suppressed
        # (presumably because the external side effect may already have happened).
        external_writes[write_key] = {
            "tool_name": tool_name,
            "arguments": normalized_arguments,
            "status": "attempted",
            "content": "",
            "error": None,
        }
        result = await tool.invoke(normalized_arguments, tool_context)
        external_writes[write_key] = {
            "tool_name": tool_name,
            "arguments": normalized_arguments,
            "status": "done" if result.success else "error",
            "content": result.content,
            "error": result.error,
        }
        return result

    async def execute_tool_call(
        self,
        tool_call: ToolCallRequest | dict[str, Any],
        *,
        context: ToolContext | None = None,
    ) -> ToolResult:
        """Execute one structured tool call returned by a provider.

        Accepted inputs:
        - a ``ToolCallRequest`` (any object with ``name``/``arguments`` attributes);
        - an OpenAI-style dict ``{"function": {"name": ..., "arguments": ...}}``;
        - a flat dict ``{"name": ..., "arguments": ...}``.

        Returns a failed ``ToolResult`` with ``error="tool_call_parse_error"``
        when the call cannot be normalized, or ``"tool_call_argument_parse_error"``
        when the arguments carry the ``__beaver_tool_argument_parse_error__``
        marker. Otherwise delegates to :meth:`execute`.
        """
        try:
            tool_name, arguments = self._normalize_tool_call(tool_call)
        except Exception as exc:  # boundary: report any malformed call as a result
            return ToolResult(
                success=False,
                content=f"Tool call could not be parsed: {exc}",
                tool_name=self._extract_tool_name(tool_call),
                error="tool_call_parse_error",
            )

        # ``arguments`` is a private copy (see ``_normalize_tool_call``), so
        # popping the marker does not mutate the caller's tool_call. The marker
        # is presumably injected upstream when raw argument text failed to
        # parse — TODO confirm against the provider adapters.
        parse_error = arguments.pop("__beaver_tool_argument_parse_error__", None)
        if parse_error is not None:
            return ToolResult(
                success=False,
                content=f"Tool call arguments for {tool_name} could not be parsed: {parse_error}",
                tool_name=tool_name,
                error="tool_call_argument_parse_error",
                raw_output=arguments.get("__raw_arguments__"),
            )
        return await self.execute(tool_name, arguments, context=context)

    @staticmethod
    def _normalize_tool_call(tool_call: ToolCallRequest | dict[str, Any]) -> tuple[str, dict[str, Any]]:
        """Extract ``(name, arguments)`` from a provider tool call.

        String arguments are decoded as JSON. ``None``, a blank string and JSON
        ``null`` all mean "no arguments" and yield ``{}``. The returned dict is
        a fresh shallow copy, so callers may mutate it without touching
        ``tool_call``.

        Raises:
            ValueError: if the tool name is missing/empty, the argument string
                is not valid JSON, or the arguments are not a dict.
        """
        if isinstance(tool_call, dict):
            function = tool_call.get("function")
            # OpenAI nests name/arguments under "function"; otherwise read them flat.
            source = function if isinstance(function, dict) else tool_call
            name = source.get("name")
            arguments = source.get("arguments", {})
        else:
            name = getattr(tool_call, "name", None)
            arguments = getattr(tool_call, "arguments", {})

        if not name:
            raise ValueError("Tool call is missing a tool name")

        if isinstance(arguments, str):
            if not arguments.strip():
                arguments = None
            else:
                try:
                    arguments = json.loads(arguments)
                except json.JSONDecodeError as exc:
                    raise ValueError(f"Tool call arguments for {name!r} are not valid JSON") from exc
        if arguments is None:
            arguments = {}
        if not isinstance(arguments, dict):
            raise ValueError(f"Tool call arguments for {name!r} must be a dict")
        return str(name), dict(arguments)

    @staticmethod
    def _extract_tool_name(tool_call: ToolCallRequest | dict[str, Any]) -> str:
        """Best-effort tool name for error reporting; ``"unknown"`` if none is found."""
        if not isinstance(tool_call, dict):
            return str(getattr(tool_call, "name", None) or "unknown")
        function = tool_call.get("function")
        if isinstance(function, dict) and function.get("name"):
            return str(function["name"])
        if tool_call.get("name"):
            return str(tool_call["name"])
        return "unknown"
# Substrings that mark a tool as an "external write" (sending, replying to,
# forwarding or moving mail; creating or updating calendar events).
# ``_external_write_key`` checks whether any of these appears inside the
# lower-cased tool name, so prefixed/namespaced tool names also qualify.
_EXTERNAL_WRITE_TOOL_TERMS: tuple[str, ...] = (
    "mail_send_email",
    "mail_reply_to_message",
    "mail_forward_message",
    "mail_move_message",
    "calendar_create_event",
    "calendar_update_event",
)
def _external_write_state(context: ToolContext) -> dict[str, dict[str, Any]]:
    """Return the external-write attempt registry stored on ``context``.

    The registry lives at ``context.metadata["external_write_attempts"]``. If
    that entry is missing or is not a dict, a new empty dict is stored there
    first. The stored dict itself is returned, so later calls with the same
    context observe any mutations.
    """
    metadata = context.metadata
    attempts = metadata.get("external_write_attempts")
    if not isinstance(attempts, dict):
        # Missing or corrupted entry: replace it with a fresh registry.
        attempts = {}
        metadata["external_write_attempts"] = attempts
    return attempts
def _external_write_key(tool_name: str, arguments: dict[str, Any]) -> str | None:
    """Return a de-duplication key for an external-write call, or ``None``.

    A tool is treated as an external write when its lower-cased name contains
    any entry of ``_EXTERNAL_WRITE_TOOL_TERMS``. The key has the form
    ``"<lower-cased tool name>:<sha256 hex digest>"``. The digest covers the
    compact, key-sorted JSON of ``_normalize_for_key(arguments)``, so calls
    whose normalized arguments are equal share a key.
    """
    name = tool_name.lower()
    is_external_write = any(term in name for term in _EXTERNAL_WRITE_TOOL_TERMS)
    if not is_external_write:
        return None

    canonical_json = json.dumps(
        _normalize_for_key(arguments),
        ensure_ascii=False,
        sort_keys=True,
        separators=(",", ":"),
    )
    fingerprint = hashlib.sha256(canonical_json.encode("utf-8")).hexdigest()
    return f"{name}:{fingerprint}"
def _normalize_for_key(value: Any) -> Any:
    """Convert ``value`` into a deterministic, JSON-serializable structure.

    Used to build de-duplication keys for external writes:

    - dict: keys are stringified and entries emitted in ``str``-sorted key order;
    - list / tuple: normalized element-wise into a list;
    - set / frozenset: normalized element-wise, then sorted by ``repr`` so that
      equal sets always produce the same order regardless of iteration order;
    - str / int / float / bool / None: returned unchanged;
    - anything else: falls back to ``str(value)``.
    """
    if isinstance(value, dict):
        return {str(key): _normalize_for_key(value[key]) for key in sorted(value, key=str)}
    if isinstance(value, (list, tuple)):
        return [_normalize_for_key(item) for item in value]
    if isinstance(value, (set, frozenset)):
        # str(set) depends on iteration order, which can differ between equal
        # sets; sort the normalized items so the key is stable.
        return sorted((_normalize_for_key(item) for item in value), key=repr)
    if isinstance(value, (str, int, float, bool)) or value is None:
        return value
    return str(value)