"""Replay execution helpers for skill draft evaluation.""" from __future__ import annotations from dataclasses import dataclass, field from typing import Any, Literal from uuid import uuid4 from beaver.tools.base import ToolContext, ToolResult, ToolSpec from beaver.tools.registry.tool_registry import ToolRegistry from beaver.tools.runtime.executor import ToolExecutor ToolExecutionMode = Literal["executed", "surrogate", "blocked"] @dataclass(slots=True) class ReplayToolPolicy: safe_toolsets: set[str] = field(default_factory=lambda: {"filesystem", "user_files", "core", "web", "search"}) surrogate_transports: set[str] = field(default_factory=lambda: {"mcp", "connector"}) destructive_terms: tuple[str, ...] = ( "delete", "remove", "destroy", "revoke", "permission", "credential", "payment", "pay", ) external_write_terms: tuple[str, ...] = ( "send", "post", "publish", "create", "update", "invite", "reply", "forward", ) class ReplayToolExecutor: def __init__( self, inner: ToolExecutor, *, registry: ToolRegistry, policy: ReplayToolPolicy | None = None, ) -> None: self.inner = inner self.registry = registry self.policy = policy or ReplayToolPolicy() self.traces: list[dict[str, Any]] = [] async def execute( self, tool_name: str, arguments: dict[str, Any] | None, *, context: ToolContext | None = None, ) -> ToolResult: tool = self.registry.get(tool_name) spec = tool.spec if tool is not None else ToolSpec( name=tool_name, description="unregistered tool", input_schema={"type": "object", "properties": {}}, toolset="unknown", ) mode = classify_tool_mode(spec, self.policy) trace = { "trace_id": uuid4().hex, "tool_name": tool_name, "mode": mode, "arguments": dict(arguments or {}), "schema": dict(spec.input_schema), "toolset": spec.toolset, "metadata": dict(spec.metadata), "classification_reason": _classification_reason(spec, mode), } if mode == "executed": result = await self.inner.execute(tool_name, arguments or {}, context=context) trace["result"] = { "success": result.success, "error": result.error, "content": result.content[:2000], } self.traces.append(trace) return result if mode == "surrogate": trace["result"] = { "success": True, "error": "replay_surrogate", "content": "Tool call recorded for surrogate evaluation.", } self.traces.append(trace) return ToolResult( success=True, content="Tool call recorded for surrogate evaluation.", tool_name=tool_name, error="replay_surrogate", raw_output=trace, ) trace["result"] = { "success": False, "error": "replay_blocked", "content": "Tool call blocked by replay policy.", } self.traces.append(trace) return ToolResult( success=False, content="Tool call blocked by replay policy.", tool_name=tool_name, error="replay_blocked", raw_output=trace, ) async def execute_tool_call(self, tool_call: Any, *, context: ToolContext | None = None) -> ToolResult: tool_name, arguments = ToolExecutor._normalize_tool_call(tool_call) return await self.execute(tool_name, arguments, context=context) def classify_tool_mode(spec: ToolSpec, policy: ReplayToolPolicy | None = None) -> ToolExecutionMode: policy = policy or ReplayToolPolicy() name = spec.name.lower() toolset = spec.toolset.lower() metadata = {str(key).lower(): str(value).lower() for key, value in spec.metadata.items()} if any(term in name for term in policy.destructive_terms): return "blocked" if toolset in policy.safe_toolsets: return "executed" if metadata.get("transport") in policy.surrogate_transports or toolset in {"mcp", "connector", "external"}: if any(term in name for term in policy.external_write_terms): return "surrogate" return "executed" return "surrogate" def _classification_reason(spec: ToolSpec, mode: ToolExecutionMode) -> str: return f"{spec.name} classified as {mode} from toolset={spec.toolset} metadata={spec.metadata}" @dataclass(slots=True) class ReplayArmRequest: case_id: str arm: str task_text: str pinned_skill_names: list[str] = field(default_factory=list) pinned_skill_contexts: list[Any] = field(default_factory=list) provider_bundle: Any | None = None model_settings: dict[str, Any] = field(default_factory=dict) class ReplayRunner: def __init__(self, *, agent_loop: Any, policy: ReplayToolPolicy | None = None) -> None: self.agent_loop = agent_loop self.policy = policy or ReplayToolPolicy() async def run_arm(self, request: ReplayArmRequest) -> dict[str, Any]: loaded = self.agent_loop.boot() replay_executor = ReplayToolExecutor( loaded.tool_executor, registry=loaded.tool_registry, policy=self.policy, ) result = await self.agent_loop.process_direct( request.task_text, provider_bundle=request.provider_bundle, include_skill_assembly=False, include_tools=True, pinned_skill_names=request.pinned_skill_names, pinned_skill_contexts=request.pinned_skill_contexts, max_tool_iterations=int(request.model_settings.get("max_tool_iterations") or 4), temperature=float(request.model_settings.get("temperature") or 0.0), source="skill_replay_eval", tool_executor_override=replay_executor, ) return { "case_id": request.case_id, "arm": request.arm, "session_id": result.session_id, "run_id": result.run_id, "task_text": request.task_text, "finish_reason": result.finish_reason, "final_answer": result.output_text, "tool_calls": list(replay_executor.traces), "artifacts": [], "side_effects": _side_effects_from_traces(replay_executor.traces), } def _side_effects_from_traces(traces: list[dict[str, Any]]) -> list[dict[str, Any]]: effects: list[dict[str, Any]] = [] for trace in traces: if trace.get("mode") in {"surrogate", "blocked"}: effects.append( { "tool_name": trace.get("tool_name"), "mode": trace.get("mode"), "arguments": trace.get("arguments"), "classification_reason": trace.get("classification_reason"), } ) return effects