"""Replay execution helpers for skill draft evaluation.

A :class:`ReplayToolExecutor` wraps a real :class:`ToolExecutor` and decides,
per tool call, whether to run the tool for real ("executed"), to record the
call and return a canned success ("surrogate"), or to refuse it ("blocked").
Every call is appended to ``ReplayToolExecutor.traces`` for later evaluation.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Literal
from uuid import uuid4

from beaver.tools.base import ToolContext, ToolResult, ToolSpec
from beaver.tools.registry.tool_registry import ToolRegistry
from beaver.tools.runtime.executor import ToolExecutor

ToolExecutionMode = Literal["executed", "surrogate", "blocked"]

# Maximum number of characters of a real tool's output kept in a trace entry.
_TRACE_CONTENT_LIMIT = 2000


@dataclass(slots=True)
class ReplayToolPolicy:
    """Configuration consumed by :func:`classify_tool_mode`.

    Names, toolsets and metadata are lower-cased before comparison, so the
    entries here should be lower-case to ever match.
    """

    # Toolsets whose (non-destructive) tools are always run for real.
    safe_toolsets: set[str] = field(
        default_factory=lambda: {"filesystem", "user_files", "core", "web", "search"}
    )
    # Values of ``spec.metadata["transport"]`` that mark a tool as external.
    surrogate_transports: set[str] = field(default_factory=lambda: {"mcp", "connector"})
    # Substrings of a tool name that cause the call to be blocked outright.
    destructive_terms: tuple[str, ...] = (
        "delete",
        "remove",
        "destroy",
        "revoke",
        "permission",
        "credential",
        "payment",
        "pay",
    )
    # Substrings of an external tool's name that cause it to be surrogated
    # instead of executed (it would write to a third-party system).
    external_write_terms: tuple[str, ...] = (
        "send",
        "post",
        "publish",
        "create",
        "update",
        "invite",
        "reply",
        "forward",
    )
    # Toolset names that mark a tool as external (checked alongside
    # ``surrogate_transports``). Appended last to keep positional
    # construction of the dataclass backward compatible.
    external_toolsets: set[str] = field(
        default_factory=lambda: {"mcp", "connector", "external"}
    )


class ReplayToolExecutor:
    """Policy-gated wrapper around a real :class:`ToolExecutor`.

    Each call is classified with :func:`classify_tool_mode`:

    * ``"executed"``: forwarded to ``inner`` and its result returned as-is.
    * ``"surrogate"``: not run; a ``ToolResult`` with ``success=True`` and
      ``error="replay_surrogate"`` is returned.
    * ``"blocked"``: not run; a ``ToolResult`` with ``success=False`` and
      ``error="replay_blocked"`` is returned.

    One trace dict per call is appended to :attr:`traces`.
    """

    def __init__(
        self,
        inner: ToolExecutor,
        *,
        registry: ToolRegistry,
        policy: ReplayToolPolicy | None = None,
    ) -> None:
        self.inner = inner
        self.registry = registry
        self.policy = policy or ReplayToolPolicy()
        self.traces: list[dict[str, Any]] = []

    async def execute(
        self,
        tool_name: str,
        arguments: dict[str, Any] | None,
        *,
        context: ToolContext | None = None,
    ) -> ToolResult:
        """Classify, record and, when allowed, run a single tool call.

        Args:
            tool_name: Name used to look the tool up in ``self.registry``.
            arguments: Call arguments; ``None`` is treated as ``{}``.
            context: Passed through to the inner executor for executed calls.

        Returns:
            The inner executor's result for ``"executed"`` calls, otherwise a
            synthetic ``ToolResult`` whose ``raw_output`` is the trace dict.

        Raises:
            Whatever ``self.inner.execute`` raises. The failed call is still
            recorded in :attr:`traces` before the exception propagates.
        """
        tool = self.registry.get(tool_name)
        if tool is not None:
            spec = tool.spec
        else:
            # Unknown tools get a placeholder spec (toolset "unknown"), so they
            # are still classified and traced rather than failing the lookup.
            spec = ToolSpec(
                name=tool_name,
                description="unregistered tool",
                input_schema={"type": "object", "properties": {}},
                toolset="unknown",
            )
        mode = classify_tool_mode(spec, self.policy)
        trace: dict[str, Any] = {
            "trace_id": uuid4().hex,
            "tool_name": tool_name,
            "mode": mode,
            "arguments": dict(arguments or {}),
            "schema": dict(spec.input_schema),
            "toolset": spec.toolset,
            "metadata": dict(spec.metadata),
            "classification_reason": _classification_reason(spec, mode),
        }

        if mode == "executed":
            try:
                result = await self.inner.execute(tool_name, arguments or {}, context=context)
            except Exception as exc:
                # Keep the replay trace complete even when the real tool
                # crashes; the exception itself is re-raised unchanged.
                trace["result"] = {
                    "success": False,
                    "error": f"{type(exc).__name__}: {exc}",
                    "content": "",
                }
                self.traces.append(trace)
                raise
            trace["result"] = {
                "success": result.success,
                "error": result.error,
                "content": result.content[:_TRACE_CONTENT_LIMIT],
            }
            self.traces.append(trace)
            return result

        if mode == "surrogate":
            return self._record_policy_outcome(
                trace,
                success=True,
                error="replay_surrogate",
                content="Tool call recorded for surrogate evaluation.",
            )
        return self._record_policy_outcome(
            trace,
            success=False,
            error="replay_blocked",
            content="Tool call blocked by replay policy.",
        )

    def _record_policy_outcome(
        self,
        trace: dict[str, Any],
        *,
        success: bool,
        error: str,
        content: str,
    ) -> ToolResult:
        """Store a synthetic (non-executed) result in ``trace`` and return it as a ToolResult."""
        trace["result"] = {"success": success, "error": error, "content": content}
        self.traces.append(trace)
        return ToolResult(
            success=success,
            content=content,
            tool_name=trace["tool_name"],
            error=error,
            raw_output=trace,
        )

    async def execute_tool_call(self, tool_call: Any, *, context: ToolContext | None = None) -> ToolResult:
        """Normalize a raw tool-call object and dispatch it through :meth:`execute`.

        Normalization is delegated to ``ToolExecutor._normalize_tool_call``,
        which is expected to return a ``(tool_name, arguments)`` pair.
        """
        tool_name, arguments = ToolExecutor._normalize_tool_call(tool_call)
        return await self.execute(tool_name, arguments, context=context)


def classify_tool_mode(spec: ToolSpec, policy: ReplayToolPolicy | None = None) -> ToolExecutionMode:
    """Decide how a tool call should be handled during replay.

    Rules, checked in order:

    1. Any ``destructive_terms`` substring in the lower-cased name: ``"blocked"``.
    2. Toolset in ``safe_toolsets``: ``"executed"``.
    3. External tool (transport in ``surrogate_transports`` or toolset in
       ``external_toolsets``): ``"surrogate"`` if the name contains any
       ``external_write_terms`` substring, else ``"executed"``.
    4. Anything else, including unregistered tools: ``"surrogate"``.

    Matching is by plain substring, so it errs toward blocking or surrogating
    (e.g. a name containing "payload" matches the term "pay").
    """
    policy = policy or ReplayToolPolicy()
    name = spec.name.lower()
    toolset = spec.toolset.lower()
    metadata = {str(key).lower(): str(value).lower() for key, value in spec.metadata.items()}

    if any(term in name for term in policy.destructive_terms):
        return "blocked"
    if toolset in policy.safe_toolsets:
        return "executed"

    is_external = (
        metadata.get("transport") in policy.surrogate_transports
        or toolset in policy.external_toolsets
    )
    if is_external:
        if any(term in name for term in policy.external_write_terms):
            return "surrogate"
        return "executed"
    return "surrogate"


def _classification_reason(spec: ToolSpec, mode: ToolExecutionMode) -> str:
    """Build the human-readable classification note stored in each trace."""
    return f"{spec.name} classified as {mode} from toolset={spec.toolset} metadata={spec.metadata}"