"""Replay execution helpers for skill draft evaluation.""" from __future__ import annotations from dataclasses import dataclass, field from time import perf_counter from typing import Any, Callable, Literal from uuid import uuid4 from beaver.tools.base import ToolContext, ToolResult, ToolSpec from beaver.tools.registry.tool_registry import ToolRegistry from beaver.tools.runtime.executor import ToolExecutor ToolExecutionMode = Literal["executed", "surrogate", "blocked"] @dataclass(slots=True) class ReplayToolPolicy: safe_toolsets: set[str] = field(default_factory=lambda: {"filesystem", "user_files", "core", "web", "search"}) surrogate_transports: set[str] = field(default_factory=lambda: {"mcp", "connector"}) destructive_terms: tuple[str, ...] = ( "delete", "remove", "destroy", "revoke", "permission", "credential", "payment", "pay", ) external_write_terms: tuple[str, ...] = ( "send", "post", "publish", "create", "update", "invite", "reply", "forward", ) class ReplayToolExecutor: def __init__( self, inner: ToolExecutor, *, registry: ToolRegistry, policy: ReplayToolPolicy | None = None, ) -> None: self.inner = inner self.registry = registry self.policy = policy or ReplayToolPolicy() self.traces: list[dict[str, Any]] = [] async def execute( self, tool_name: str, arguments: dict[str, Any] | None, *, context: ToolContext | None = None, ) -> ToolResult: started_at = perf_counter() tool = self.registry.get(tool_name) spec = tool.spec if tool is not None else ToolSpec( name=tool_name, description="unregistered tool", input_schema={"type": "object", "properties": {}}, toolset="unknown", ) mode = classify_tool_mode(spec, self.policy) trace = { "trace_id": uuid4().hex, "tool_name": tool_name, "mode": mode, "arguments": dict(arguments or {}), "schema": dict(spec.input_schema), "toolset": spec.toolset, "metadata": dict(spec.metadata), "classification_reason": _classification_reason(spec, mode), } if mode == "executed": result = await self.inner.execute(tool_name, arguments or {}, context=context) trace["result"] = { "success": result.success, "error": result.error, "content": result.content[:2000], } trace["duration_ms"] = round((perf_counter() - started_at) * 1000, 2) self.traces.append(trace) return result if mode == "surrogate": trace["result"] = { "success": True, "error": "replay_surrogate", "content": "Tool call recorded for surrogate evaluation.", } trace["duration_ms"] = round((perf_counter() - started_at) * 1000, 2) self.traces.append(trace) return ToolResult( success=True, content="Tool call recorded for surrogate evaluation.", tool_name=tool_name, error="replay_surrogate", raw_output=trace, ) trace["result"] = { "success": False, "error": "replay_blocked", "content": "Tool call blocked by replay policy.", } trace["duration_ms"] = round((perf_counter() - started_at) * 1000, 2) self.traces.append(trace) return ToolResult( success=False, content="Tool call blocked by replay policy.", tool_name=tool_name, error="replay_blocked", raw_output=trace, ) async def execute_tool_call(self, tool_call: Any, *, context: ToolContext | None = None) -> ToolResult: tool_name, arguments = ToolExecutor._normalize_tool_call(tool_call) return await self.execute(tool_name, arguments, context=context) def classify_tool_mode(spec: ToolSpec, policy: ReplayToolPolicy | None = None) -> ToolExecutionMode: policy = policy or ReplayToolPolicy() name = spec.name.lower() toolset = spec.toolset.lower() metadata = {str(key).lower(): str(value).lower() for key, value in spec.metadata.items()} if any(term in name for term in policy.destructive_terms): return "blocked" if toolset in policy.safe_toolsets: return "executed" if metadata.get("transport") in policy.surrogate_transports or toolset in {"mcp", "connector", "external"}: if any(term in name for term in policy.external_write_terms): return "surrogate" return "executed" return "surrogate" def _classification_reason(spec: ToolSpec, mode: ToolExecutionMode) -> str: return f"{spec.name} classified as {mode} from toolset={spec.toolset} metadata={spec.metadata}" @dataclass(slots=True) class ReplayArmRequest: case_id: str arm: str task_text: str pinned_skill_names: list[str] = field(default_factory=list) pinned_skill_contexts: list[Any] = field(default_factory=list) provider_bundle: Any | None = None model_settings: dict[str, Any] = field(default_factory=dict) class ReplayRunner: def __init__( self, *, agent_loop: Any, policy: ReplayToolPolicy | None = None, isolated_loop_factory: Callable[[], Any] | None = None, ) -> None: self.agent_loop = agent_loop self.policy = policy or ReplayToolPolicy() self.isolated_loop_factory = isolated_loop_factory async def run_arm(self, request: ReplayArmRequest) -> dict[str, Any]: target_loop = self.isolated_loop_factory() if self.isolated_loop_factory is not None else self.agent_loop loaded = target_loop.boot() replay_executor = ReplayToolExecutor( loaded.tool_executor, registry=loaded.tool_registry, policy=self.policy, ) direct_kwargs = { "provider_bundle": request.provider_bundle, "include_skill_assembly": False, "include_tools": True, "pinned_skill_names": request.pinned_skill_names, "pinned_skill_contexts": request.pinned_skill_contexts, "max_tool_iterations": int(request.model_settings.get("max_tool_iterations") or 4), "temperature": float(request.model_settings.get("temperature") or 0.0), "source": "skill_replay_eval", "tool_executor_override": replay_executor, } try: try: result = await target_loop.process_direct(request.task_text, **direct_kwargs) except RuntimeError as exc: if not _is_process_direct_disabled_while_running(exc) or not hasattr(target_loop, "submit_direct"): raise result = await target_loop.submit_direct(request.task_text, **direct_kwargs) session_manager = getattr(loaded, "session_manager", None) if session_manager is not None and hasattr(session_manager, "end_session"): session_manager.end_session(result.session_id, "evaluation_complete") return { "case_id": request.case_id, "arm": request.arm, "session_id": result.session_id, "run_id": result.run_id, "task_text": request.task_text, "finish_reason": result.finish_reason, "final_answer": result.output_text, "tool_calls": list(replay_executor.traces), "artifacts": [], "side_effects": _side_effects_from_traces(replay_executor.traces), } finally: if target_loop is not self.agent_loop and hasattr(target_loop, "close"): mcp_manager = getattr(loaded, "mcp_manager", None) if mcp_manager is not None and hasattr(mcp_manager, "close"): try: await mcp_manager.close() finally: closeables = getattr(loaded, "closeables", None) if isinstance(closeables, list): loaded.closeables = [ (name, close_fn) for name, close_fn in closeables if name != "mcp_manager" ] target_loop.close() def _is_process_direct_disabled_while_running(exc: RuntimeError) -> bool: message = str(exc) return ( "AgentLoop.process_direct() is disabled while run() is active" in message and "submit tasks via submit_direct() instead" in message ) def _side_effects_from_traces(traces: list[dict[str, Any]]) -> list[dict[str, Any]]: effects: list[dict[str, Any]] = [] for trace in traces: if trace.get("mode") in {"surrogate", "blocked"}: effects.append( { "tool_name": trace.get("tool_name"), "mode": trace.get("mode"), "arguments": trace.get("arguments"), "classification_reason": trace.get("classification_reason"), } ) return effects