feat(learning): 添加技能学习候选者合成锁定机制 添加了 DraftSynthesisInProgress 和 DraftHasNoChanges 异常来处理并发场景, 确保同一技能学习候选者的合成过程不会重复执行。实现了 claim_learning_candidate_for_synthesis 方法来原子性地锁定候选者进行合成。 fix(web): 为技能草案创建端点添加适当的HTTP状态码 当草案没有变化或正在合成时,现在正确返回409状态码而不是内部错误。 feat(skills): 实现技能修订内容比较以检测无变化情况 添加了 _is_noop_revision 方法来比较基础技能和提议的修订, 如果内容没有实际变化则抛出 NoDraftChanges 异常。 refactor(process): 修复任务证据记录后根运行状态更新逻辑 将任务证据记录事件后的状态从 waiting 更改为 done,并设置 finished_at 时间戳。 feat(tools): 防止在同一运行中重复执行外部写入操作 为邮件发送、日历创建等外部写入工具添加去重机制,避免重复的外部操作。 test: 添加技能学习和工具执行的单元测试 增加测试用例验证并发草案合成、重复外部写入抑制和无变化修订检测等功能。
196 lines
6.9 KiB
Python
196 lines
6.9 KiB
Python
"""Beaver 工具执行器。

这层专门负责把 provider 返回的 tool call 转成真正的工具执行。

它不关心 provider 是 OpenAI、Anthropic 还是 Codex,只关心:

1. 工具叫什么
2. 参数是什么
3. registry 能不能找到它
4. 执行结果怎么标准化
5. 同一次运行中重复的外部写入(发邮件、建/改日程等)要被抑制
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import hashlib
|
||
import json
|
||
from typing import TYPE_CHECKING, Any
|
||
|
||
from beaver.tools.base import ToolContext, ToolResult
|
||
from beaver.tools.registry.tool_registry import ToolRegistry
|
||
|
||
if TYPE_CHECKING:
|
||
from beaver.engine.providers.base import ToolCallRequest
|
||
|
||
|
||
class ToolExecutor:
    """Execute single tool calls against a ``ToolRegistry``.

    Calls to tools whose lowercased name contains one of the
    ``_EXTERNAL_WRITE_TOOL_TERMS`` substrings (mail send/reply/forward/move,
    calendar create/update) are de-duplicated per ``ToolContext``. A repeated
    call with the same tool name and arguments is not executed. Instead it
    returns a ``duplicate_external_write_suppressed`` result.
    """

    def __init__(self, registry: ToolRegistry) -> None:
        """Bind the executor to the registry used to resolve tool names."""
        self.registry = registry

    async def execute(
        self,
        tool_name: str,
        arguments: dict[str, Any] | None,
        *,
        context: ToolContext | None = None,
    ) -> ToolResult:
        """Run one call of ``tool_name`` with ``arguments``.

        Args:
            tool_name: Registry name of the tool to invoke.
            arguments: Keyword arguments for the tool. ``None`` means no
                arguments. The mapping is shallow-copied before use.
            context: Tool context for the run. Its ``metadata`` holds the
                ledger of external writes used for duplicate suppression, so
                pass the same context to every call of one run. When omitted,
                a fresh ``ToolContext()`` is created for this call only, which
                means no de-duplication against earlier calls.

        Returns:
            One of three results:
            - the tool's own ``ToolResult``;
            - a ``tool_not_found`` failure when the name is not registered;
            - a ``duplicate_external_write_suppressed`` result when an
              identical external write was already attempted with this context.
        """
        tool = self.registry.get(tool_name)
        if tool is None:
            return ToolResult(
                success=False,
                content=f"Tool {tool_name} is not registered.",
                tool_name=tool_name,
                error="tool_not_found",
            )

        normalized_arguments = dict(arguments or {})
        # Compare with None explicitly so a caller-supplied context object is
        # always honored, even if it happens to evaluate as falsy.
        tool_context = context if context is not None else ToolContext()

        write_key = _external_write_key(tool_name, normalized_arguments)
        if write_key is None:
            # Not an external write: no de-duplication bookkeeping.
            return await tool.invoke(normalized_arguments, tool_context)

        external_writes = _external_write_state(tool_context)
        previous = external_writes.get(write_key)
        if previous is not None:
            return self._duplicate_write_result(tool_name, previous)

        # Record the attempt *before* awaiting the tool. No await occurs between
        # the lookup above and this assignment, so a concurrent identical call
        # sees this entry. If invoke() raises, the entry stays "attempted" and
        # later identical calls are still suppressed.
        # NOTE(review): this looks like deliberate at-most-once semantics for
        # external side effects (the write may have partially happened) — confirm.
        external_writes[write_key] = {
            "tool_name": tool_name,
            "arguments": normalized_arguments,
            "status": "attempted",
            "content": "",
            "error": None,
        }
        result = await tool.invoke(normalized_arguments, tool_context)
        external_writes[write_key] = {
            "tool_name": tool_name,
            "arguments": normalized_arguments,
            "status": "done" if result.success else "error",
            "content": result.content,
            "error": result.error,
        }
        return result

    @staticmethod
    def _duplicate_write_result(tool_name: str, previous: dict[str, Any]) -> ToolResult:
        """Build the result returned in place of a repeated external write.

        The previous attempt's stripped ``content`` (if any) is echoed into the
        message. The raw ledger entry is exposed under ``raw_output["previous"]``.
        """
        previous_content = str(previous.get("content") or "").strip()
        detail = f" Previous result: {previous_content}" if previous_content else ""
        return ToolResult(
            success=True,
            content=(
                f"Duplicate external write suppressed for {tool_name}. "
                "A matching write was already attempted in this run."
                f"{detail}"
            ),
            tool_name=tool_name,
            error="duplicate_external_write_suppressed",
            raw_output={"duplicate": True, "previous": previous},
        )

    async def execute_tool_call(
        self,
        tool_call: ToolCallRequest | dict[str, Any],
        *,
        context: ToolContext | None = None,
    ) -> ToolResult:
        """Execute one structured tool call as returned by a provider.

        Accepts either a ``ToolCallRequest`` or an OpenAI-style dict, with
        name/arguments at the top level or nested under ``"function"``.

        Parse problems are returned as failed ``ToolResult`` objects, not raised:
        - ``tool_call_parse_error`` when the call cannot be normalized at all;
        - ``tool_call_argument_parse_error`` when the arguments contain the
          ``__beaver_tool_argument_parse_error__`` marker key.
        """
        try:
            tool_name, arguments = self._normalize_tool_call(tool_call)
        except Exception as exc:
            return ToolResult(
                success=False,
                content=f"Tool call could not be parsed: {exc}",
                tool_name=self._extract_tool_name(tool_call),
                error="tool_call_parse_error",
            )

        # ``arguments`` is a private copy returned by _normalize_tool_call, so
        # popping the marker does not strip it from the caller's tool-call object.
        parse_error = arguments.pop("__beaver_tool_argument_parse_error__", None)
        if parse_error is not None:
            return ToolResult(
                success=False,
                content=f"Tool call arguments for {tool_name} could not be parsed: {parse_error}",
                tool_name=tool_name,
                error="tool_call_argument_parse_error",
                raw_output=arguments.get("__raw_arguments__"),
            )
        return await self.execute(tool_name, arguments, context=context)

    @staticmethod
    def _normalize_tool_call(tool_call: ToolCallRequest | dict[str, Any]) -> tuple[str, dict[str, Any]]:
        """Extract ``(name, arguments)`` from a tool call.

        Arguments given as a JSON string are decoded. ``None`` or a blank
        string means "no arguments" and becomes ``{}``. The returned dict is
        always a fresh shallow copy, so callers may mutate it without touching
        the original tool-call object.

        Raises:
            ValueError: the name is missing, the argument string is not valid
                JSON, or the arguments do not decode to a dict.
        """
        if isinstance(tool_call, dict):
            function = tool_call.get("function")
            # OpenAI-style calls nest name/arguments under "function".
            source = function if isinstance(function, dict) else tool_call
            name = source.get("name")
            arguments = source.get("arguments", {})
        else:
            name = getattr(tool_call, "name", None)
            arguments = getattr(tool_call, "arguments", {})

        if not name:
            raise ValueError("Tool call is missing a tool name")
        if arguments is None:
            arguments = {}
        elif isinstance(arguments, str):
            if not arguments.strip():
                # Zero-parameter tools can arrive with an empty argument string.
                arguments = {}
            else:
                try:
                    arguments = json.loads(arguments)
                except json.JSONDecodeError as exc:
                    raise ValueError(f"Tool call arguments for {name!r} are not valid JSON") from exc
        if not isinstance(arguments, dict):
            raise ValueError(f"Tool call arguments for {name!r} must be a dict")
        return str(name), dict(arguments)

    @staticmethod
    def _extract_tool_name(tool_call: ToolCallRequest | dict[str, Any]) -> str:
        """Best-effort tool name for error reporting; ``"unknown"`` if absent."""
        if not isinstance(tool_call, dict):
            return str(getattr(tool_call, "name", None) or "unknown")
        function = tool_call.get("function")
        if isinstance(function, dict) and function.get("name"):
            return str(function["name"])
        if tool_call.get("name"):
            return str(tool_call["name"])
        return "unknown"
|
||
|
||
|
||
# Substrings identifying tools whose calls change state outside Beaver.
# A tool counts as an external write when its lowercased name contains any
# of these terms; such calls are de-duplicated within one tool context.
_MAIL_WRITE_TERMS = (
    "mail_send_email",
    "mail_reply_to_message",
    "mail_forward_message",
    "mail_move_message",
)
_CALENDAR_WRITE_TERMS = (
    "calendar_create_event",
    "calendar_update_event",
)
_EXTERNAL_WRITE_TOOL_TERMS = _MAIL_WRITE_TERMS + _CALENDAR_WRITE_TERMS
|
||
|
||
|
||
def _external_write_state(context: ToolContext) -> dict[str, dict[str, Any]]:
    """Return the external-write ledger stored in ``context.metadata``.

    The ledger lives under the ``"external_write_attempts"`` key and is created
    on first use. If that key already holds something that is not a dict, the
    value is replaced with a new empty dict. The returned dict is the live
    object stored in the metadata, so mutations are visible to later calls.
    """
    metadata = context.metadata
    ledger = metadata.setdefault("external_write_attempts", {})
    if isinstance(ledger, dict):
        return ledger
    replacement: dict[str, dict[str, Any]] = {}
    metadata["external_write_attempts"] = replacement
    return replacement
|
||
|
||
|
||
def _external_write_key(tool_name: str, arguments: dict[str, Any]) -> str | None:
    """Return a de-duplication key for an external-write tool call, else None.

    Tools whose lowercased name contains none of ``_EXTERNAL_WRITE_TOOL_TERMS``
    yield ``None``. Otherwise the key has the form
    ``"<lowercased tool name>:<sha256 hex of canonical JSON arguments>"``.
    The arguments are canonicalized via ``_normalize_for_key`` and serialized
    with sorted keys and compact separators, so calls with structurally equal
    arguments map to the same key regardless of dict key order.
    """
    name_lower = tool_name.lower()
    is_external_write = any(term in name_lower for term in _EXTERNAL_WRITE_TOOL_TERMS)
    if not is_external_write:
        return None

    canonical = json.dumps(
        _normalize_for_key(arguments),
        ensure_ascii=False,
        sort_keys=True,
        separators=(",", ":"),
    )
    fingerprint = hashlib.sha256(canonical.encode("utf-8")).hexdigest()
    return f"{name_lower}:{fingerprint}"
|
||
|
||
|
||
def _normalize_for_key(value: Any) -> Any:
    """Convert ``value`` into a JSON-serializable form suitable for hashing.

    The conversion applies these rules, in order:
    - Dicts get ``str`` keys, built in order of ``str(key)``.
    - Lists and tuples both become lists.
    - ``str``/``int``/``float``/``bool``/``None`` are returned unchanged.
    - Anything else is replaced by ``str(value)``.
    The function recurses into container elements.
    """
    if isinstance(value, dict):
        ordered_keys = sorted(value, key=str)
        return {str(k): _normalize_for_key(value[k]) for k in ordered_keys}
    if isinstance(value, (list, tuple)):
        return [_normalize_for_key(element) for element in value]
    is_json_scalar = value is None or isinstance(value, (str, int, float, bool))
    return value if is_json_scalar else str(value)
|