移除了agents/registry.json中的所有内置agents配置,将agents数组清空。 为web应用添加了CORS中间件支持,允许指定的前端地址跨域访问。 重构了技能上传功能,增加了LLM重写机制,自动规范化上传的技能格式。 新增了工具名称提取逻辑,从技能正文中自动识别Required Tools段落。 更新了技能学习候选者和草稿的载荷结构,添加评估报告统计信息。 修改了意图路由技能的说明,改进任务状态管理逻辑。
217 lines
7.7 KiB
Python
217 lines
7.7 KiB
Python
"""Replay execution helpers for skill draft evaluation."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Literal
|
|
from uuid import uuid4
|
|
|
|
from beaver.tools.base import ToolContext, ToolResult, ToolSpec
|
|
from beaver.tools.registry.tool_registry import ToolRegistry
|
|
from beaver.tools.runtime.executor import ToolExecutor
|
|
|
|
# How a tool call is handled during replay:
#   "executed"  - forwarded to the real tool executor,
#   "surrogate" - not run; recorded and answered with a placeholder success result,
#   "blocked"   - not run; answered with a failed result.
ToolExecutionMode = Literal[
    "executed",
    "surrogate",
    "blocked",
]
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class ReplayToolPolicy:
|
|
safe_toolsets: set[str] = field(default_factory=lambda: {"filesystem", "user_files", "core", "web", "search"})
|
|
surrogate_transports: set[str] = field(default_factory=lambda: {"mcp", "connector"})
|
|
destructive_terms: tuple[str, ...] = (
|
|
"delete",
|
|
"remove",
|
|
"destroy",
|
|
"revoke",
|
|
"permission",
|
|
"credential",
|
|
"payment",
|
|
"pay",
|
|
)
|
|
external_write_terms: tuple[str, ...] = (
|
|
"send",
|
|
"post",
|
|
"publish",
|
|
"create",
|
|
"update",
|
|
"invite",
|
|
"reply",
|
|
"forward",
|
|
)
|
|
|
|
|
|
class ReplayToolExecutor:
    """Wrap a real ``ToolExecutor`` and apply a replay policy to every tool call.

    Each call is classified by ``classify_tool_mode`` and then handled in one of three ways:

    - "executed": forwarded to the wrapped executor;
    - "surrogate": answered with a placeholder success result;
    - "blocked": refused with a failed result.

    Every call is appended to ``self.traces``, whatever its outcome, so the replay run
    can be inspected afterwards.
    """

    def __init__(
        self,
        inner: ToolExecutor,
        *,
        registry: ToolRegistry,
        policy: ReplayToolPolicy | None = None,
    ) -> None:
        """Create the wrapper.

        Args:
            inner: Executor that actually runs tools classified as "executed".
            registry: Registry used to look up the ``ToolSpec`` of each called tool.
            policy: Classification policy; a default ``ReplayToolPolicy`` when omitted.
        """
        self.inner = inner
        self.registry = registry
        self.policy = policy or ReplayToolPolicy()
        # One trace dict per tool call, in call order.
        self.traces: list[dict[str, Any]] = []

    def _record(self, trace: dict[str, Any], *, success: Any, error: Any, content: Any) -> None:
        """Attach the call outcome to *trace* and append it to ``self.traces``."""
        trace["result"] = {
            "success": success,
            "error": error,
            "content": content,
        }
        self.traces.append(trace)

    async def execute(
        self,
        tool_name: str,
        arguments: dict[str, Any] | None,
        *,
        context: ToolContext | None = None,
    ) -> ToolResult:
        """Classify and handle a single tool call, recording a trace for it.

        Args:
            tool_name: Name of the tool being called.
            arguments: Call arguments; ``None`` is treated as an empty dict.
            context: Passed through unchanged to the wrapped executor.

        Returns:
            The wrapped executor's result for "executed" calls. Otherwise a synthetic
            ``ToolResult`` whose ``raw_output`` is the recorded trace dict:

            - surrogate calls: ``success=True, error="replay_surrogate"``;
            - blocked calls: ``success=False, error="replay_blocked"``.

        Raises:
            Whatever the wrapped executor raises for an "executed" call. The call is
            still recorded in ``self.traces`` (with ``success=False``) before the
            exception propagates.
        """
        tool = self.registry.get(tool_name)
        if tool is not None:
            spec = tool.spec
        else:
            # Placeholder spec for tools the registry does not know. Its toolset
            # "unknown" is not one of the default safe toolsets.
            spec = ToolSpec(
                name=tool_name,
                description="unregistered tool",
                input_schema={"type": "object", "properties": {}},
                toolset="unknown",
            )
        mode = classify_tool_mode(spec, self.policy)
        call_arguments = arguments or {}
        trace: dict[str, Any] = {
            "trace_id": uuid4().hex,
            "tool_name": tool_name,
            "mode": mode,
            # Copy so later mutation by the tool does not rewrite the recorded arguments.
            "arguments": dict(call_arguments),
            "schema": dict(spec.input_schema),
            "toolset": spec.toolset,
            "metadata": dict(spec.metadata),
            "classification_reason": _classification_reason(spec, mode),
        }

        if mode == "executed":
            try:
                result = await self.inner.execute(tool_name, call_arguments, context=context)
            except Exception as exc:
                # Record the attempt before propagating. The tool really ran (and may
                # have had side effects), so it must not vanish from the replay trace.
                self._record(trace, success=False, error=f"{type(exc).__name__}: {exc}", content="")
                raise
            # Output is truncated to keep traces small. ``or ""`` guards a None content.
            self._record(
                trace,
                success=result.success,
                error=result.error,
                content=(result.content or "")[:2000],
            )
            return result

        if mode == "surrogate":
            success, error, message = True, "replay_surrogate", "Tool call recorded for surrogate evaluation."
        else:
            success, error, message = False, "replay_blocked", "Tool call blocked by replay policy."
        self._record(trace, success=success, error=error, content=message)
        return ToolResult(
            success=success,
            content=message,
            tool_name=tool_name,
            error=error,
            raw_output=trace,
        )

    async def execute_tool_call(self, tool_call: Any, *, context: ToolContext | None = None) -> ToolResult:
        """Normalize a raw tool-call object into (name, arguments) and delegate to ``execute``.

        The accepted shapes of *tool_call* are whatever ``ToolExecutor._normalize_tool_call``
        supports; presumably provider tool-call objects or dicts (confirm against ToolExecutor).
        """
        tool_name, arguments = ToolExecutor._normalize_tool_call(tool_call)
        return await self.execute(tool_name, arguments, context=context)
|
|
|
|
|
|
def classify_tool_mode(spec: ToolSpec, policy: ReplayToolPolicy | None = None) -> ToolExecutionMode:
    """Decide how a tool call should be handled during replay.

    Rules are applied in order; the first match wins:

    1. "blocked" if any ``policy.destructive_terms`` entry is a substring of the tool name.
    2. "executed" if the toolset is in ``policy.safe_toolsets``.
    3. External tools are those whose metadata "transport" is in
       ``policy.surrogate_transports``, or whose toolset is "mcp", "connector" or
       "external". Such a tool is "surrogate" if its name contains any
       ``policy.external_write_terms`` entry, and "executed" otherwise.
    4. "surrogate" for everything else.

    Comparisons are case-insensitive. Both the spec fields and the policy entries
    are lowercased, so a policy written with uppercase letters still matches.

    Args:
        spec: Tool description; ``name``, ``toolset`` and ``metadata`` are read.
        policy: Classification policy; a default ``ReplayToolPolicy`` when omitted.

    Returns:
        One of "executed", "surrogate", "blocked".
    """
    policy = policy or ReplayToolPolicy()
    name = spec.name.lower()
    toolset = (spec.toolset or "").lower()
    metadata = {str(key).lower(): str(value).lower() for key, value in (spec.metadata or {}).items()}

    # Normalize the policy side too. The spec side is always lowercased, so a policy
    # entry containing uppercase letters would otherwise never match.
    destructive_terms = [str(term).lower() for term in policy.destructive_terms]
    external_write_terms = [str(term).lower() for term in policy.external_write_terms]
    safe_toolsets = {str(item).lower() for item in policy.safe_toolsets}
    surrogate_transports = {str(item).lower() for item in policy.surrogate_transports}

    if any(term in name for term in destructive_terms):
        return "blocked"
    if toolset in safe_toolsets:
        return "executed"
    is_external = metadata.get("transport") in surrogate_transports or toolset in {"mcp", "connector", "external"}
    if is_external:
        # External calls whose names look like writes are surrogated; the rest run for real.
        if any(term in name for term in external_write_terms):
            return "surrogate"
        return "executed"
    # Tools from any other toolset are never run for real.
    return "surrogate"
|
|
|
|
|
|
def _classification_reason(spec: ToolSpec, mode: ToolExecutionMode) -> str:
|
|
return f"{spec.name} classified as {mode} from toolset={spec.toolset} metadata={spec.metadata}"
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class ReplayArmRequest:
|
|
case_id: str
|
|
arm: str
|
|
task_text: str
|
|
pinned_skill_names: list[str] = field(default_factory=list)
|
|
pinned_skill_contexts: list[Any] = field(default_factory=list)
|
|
provider_bundle: Any | None = None
|
|
model_settings: dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
class ReplayRunner:
    """Run replay arms of a task through an agent loop with policy-controlled tools.

    Tool calls made during a run go through a ``ReplayToolExecutor``. They are
    executed, surrogated or blocked according to ``policy``, and each call is traced.
    """

    def __init__(self, *, agent_loop: Any, policy: ReplayToolPolicy | None = None) -> None:
        """Create a runner.

        Args:
            agent_loop: Object providing ``boot()`` and ``process_direct()``, and
                optionally ``submit_direct()``.
            policy: Tool classification policy; a default ``ReplayToolPolicy`` when omitted.
        """
        self.agent_loop = agent_loop
        self.policy = policy or ReplayToolPolicy()

    async def run_arm(self, request: ReplayArmRequest) -> dict[str, Any]:
        """Run one arm of a replay case and collect its answer and tool traces.

        ``process_direct`` is tried first. If it refuses because the loop is already
        running, and the loop offers ``submit_direct``, the task is submitted there instead.

        Returns:
            A dict with these keys:

            - ``case_id``, ``arm``, ``task_text``: echoed from the request;
            - ``session_id``, ``run_id``, ``finish_reason``: taken from the agent result;
            - ``final_answer``: the agent's ``output_text``;
            - ``tool_calls``: all traces;
            - ``artifacts``: always empty;
            - ``side_effects``: the surrogate/blocked calls.

        Raises:
            RuntimeError: from ``process_direct`` when it is not the "disabled while
                running" error, or when ``submit_direct`` is unavailable.
        """
        loaded = self.agent_loop.boot()
        replay_executor = ReplayToolExecutor(
            loaded.tool_executor,
            registry=loaded.tool_registry,
            policy=self.policy,
        )
        settings = request.model_settings or {}
        max_tool_iterations = settings.get("max_tool_iterations")
        temperature = settings.get("temperature")
        direct_kwargs = {
            "provider_bundle": request.provider_bundle,
            "include_skill_assembly": False,
            "include_tools": True,
            "pinned_skill_names": request.pinned_skill_names,
            "pinned_skill_contexts": request.pinned_skill_contexts,
            # Only a missing/None/"" setting falls back to the default. An explicit 0
            # is honoured; the previous `x or 4` silently turned 0 into 4.
            "max_tool_iterations": 4 if max_tool_iterations in (None, "") else int(max_tool_iterations),
            "temperature": 0.0 if temperature in (None, "") else float(temperature),
            "source": "skill_replay_eval",
            "tool_executor_override": replay_executor,
        }
        try:
            result = await self.agent_loop.process_direct(request.task_text, **direct_kwargs)
        except RuntimeError as exc:
            if not _is_process_direct_disabled_while_running(exc) or not hasattr(self.agent_loop, "submit_direct"):
                raise
            result = await self.agent_loop.submit_direct(request.task_text, **direct_kwargs)
        return {
            "case_id": request.case_id,
            "arm": request.arm,
            "session_id": result.session_id,
            "run_id": result.run_id,
            "task_text": request.task_text,
            "finish_reason": result.finish_reason,
            "final_answer": result.output_text,
            "tool_calls": list(replay_executor.traces),
            "artifacts": [],
            "side_effects": _side_effects_from_traces(replay_executor.traces),
        }
|
|
|
|
|
|
def _is_process_direct_disabled_while_running(exc: RuntimeError) -> bool:
|
|
message = str(exc)
|
|
return (
|
|
"AgentLoop.process_direct() is disabled while run() is active" in message
|
|
and "submit tasks via submit_direct() instead" in message
|
|
)
|
|
|
|
|
|
def _side_effects_from_traces(traces: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
|
effects: list[dict[str, Any]] = []
|
|
for trace in traces:
|
|
if trace.get("mode") in {"surrogate", "blocked"}:
|
|
effects.append(
|
|
{
|
|
"tool_name": trace.get("tool_name"),
|
|
"mode": trace.get("mode"),
|
|
"arguments": trace.get("arguments"),
|
|
"classification_reason": trace.get("classification_reason"),
|
|
}
|
|
)
|
|
return effects
|