diff --git a/app-instance/backend/beaver/skills/learning/__init__.py b/app-instance/backend/beaver/skills/learning/__init__.py
index 2e7b9f8..1a4f5f3 100644
--- a/app-instance/backend/beaver/skills/learning/__init__.py
+++ b/app-instance/backend/beaver/skills/learning/__init__.py
@@ -11,6 +11,7 @@ from .missing_skill import (
 )
 from .pipeline import SkillLearningPipelineService
 from .preservation import check_preservation
+from .replay import ReplayToolExecutor, ReplayToolPolicy, classify_tool_mode
 from .service import RunReceiptContext, SkillLearningService
 from .synthesizer import SkillDraftSynthesizer
 from .worker import SkillLearningWorker, SkillLearningWorkerConfig, SkillLearningWorkerResult
@@ -27,6 +28,9 @@ __all__ = [
     "RunReceiptContext",
     "SkillLearningPipelineService",
     "check_preservation",
+    "ReplayToolExecutor",
+    "ReplayToolPolicy",
+    "classify_tool_mode",
     "SkillDraftSynthesizer",
     "SkillLearningService",
     "SkillLearningWorker",
diff --git a/app-instance/backend/beaver/skills/learning/replay.py b/app-instance/backend/beaver/skills/learning/replay.py
new file mode 100644
index 0000000..bce6cd1
--- /dev/null
+++ b/app-instance/backend/beaver/skills/learning/replay.py
@@ -0,0 +1,187 @@
+"""Replay execution helpers for skill draft evaluation."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Any, Literal
+from uuid import uuid4
+
+from beaver.tools.base import ToolContext, ToolResult, ToolSpec
+from beaver.tools.registry.tool_registry import ToolRegistry
+from beaver.tools.runtime.executor import ToolExecutor
+
+ToolExecutionMode = Literal["executed", "surrogate", "blocked"]
+
+
+@dataclass(slots=True)
+class ReplayToolPolicy:
+    """Rules deciding whether a replayed tool call runs, is stubbed, or is refused.
+
+    Terms are matched as case-insensitive substrings of the tool name, so a
+    short term such as ``"pay"`` also matches names that merely contain it.
+    """
+
+    # Toolsets whose tools are run for real during replay.
+    safe_toolsets: set[str] = field(default_factory=lambda: {"filesystem", "user_files", "core", "web", "search"})
+    # Values of the ``transport`` metadata key that mark a tool as external.
+    surrogate_transports: set[str] = field(default_factory=lambda: {"mcp", "connector"})
+    # Name fragments that block a call; checked before every other rule.
+    destructive_terms: tuple[str, ...] = (
+        "delete",
+        "remove",
+        "destroy",
+        "revoke",
+        "permission",
+        "credential",
+        "payment",
+        "pay",
+    )
+    # Name fragments that turn an external tool call into a recorded surrogate.
+    external_write_terms: tuple[str, ...] = (
+        "send",
+        "post",
+        "publish",
+        "create",
+        "update",
+        "invite",
+        "reply",
+        "forward",
+    )
+
+
+class ReplayToolExecutor:
+    """Tool executor wrapper that applies a replay policy and records a trace per call.
+
+    Only calls classified as ``"executed"`` reach the wrapped executor; every
+    call, whatever its mode, appends one entry to ``traces``.
+    """
+
+    def __init__(
+        self,
+        inner: ToolExecutor,
+        *,
+        registry: ToolRegistry,
+        policy: ReplayToolPolicy | None = None,
+    ) -> None:
+        self.inner = inner
+        self.registry = registry
+        self.policy = policy or ReplayToolPolicy()
+        self.traces: list[dict[str, Any]] = []
+
+    async def execute(
+        self,
+        tool_name: str,
+        arguments: dict[str, Any] | None,
+        *,
+        context: ToolContext | None = None,
+    ) -> ToolResult:
+        """Run, stub, or refuse one tool call according to the policy.
+
+        Returns the wrapped executor's result for executed calls, and a synthetic
+        ``ToolResult`` (``error`` set to ``replay_surrogate`` or ``replay_blocked``)
+        otherwise. An exception raised by the wrapped executor is re-raised after
+        the attempt has been added to ``traces``.
+        """
+        tool = self.registry.get(tool_name)
+        # Unregistered names get a placeholder spec so they are still classified and traced.
+        spec = tool.spec if tool is not None else ToolSpec(
+            name=tool_name,
+            description="unregistered tool",
+            input_schema={"type": "object", "properties": {}},
+            toolset="unknown",
+        )
+        mode = classify_tool_mode(spec, self.policy)
+        trace = {
+            "trace_id": uuid4().hex,
+            "tool_name": tool_name,
+            "mode": mode,
+            "arguments": dict(arguments or {}),
+            "schema": dict(spec.input_schema),
+            "toolset": spec.toolset,
+            "metadata": dict(spec.metadata),
+            "classification_reason": _classification_reason(spec, mode),
+        }
+        if mode == "executed":
+            try:
+                result = await self.inner.execute(tool_name, arguments or {}, context=context)
+            except Exception as exc:
+                # Keep the failed attempt in the trace log before propagating the error.
+                trace["result"] = {
+                    "success": False,
+                    "error": f"{type(exc).__name__}: {exc}",
+                    "content": "",
+                }
+                self.traces.append(trace)
+                raise
+            trace["result"] = {
+                "success": result.success,
+                "error": result.error,
+                "content": result.content[:2000],
+            }
+            self.traces.append(trace)
+            return result
+        if mode == "surrogate":
+            trace["result"] = {
+                "success": True,
+                "error": "replay_surrogate",
+                "content": "Tool call recorded for surrogate evaluation.",
+            }
+            self.traces.append(trace)
+            return ToolResult(
+                success=True,
+                content="Tool call recorded for surrogate evaluation.",
+                tool_name=tool_name,
+                error="replay_surrogate",
+                raw_output=trace,
+            )
+        trace["result"] = {
+            "success": False,
+            "error": "replay_blocked",
+            "content": "Tool call blocked by replay policy.",
+        }
+        self.traces.append(trace)
+        return ToolResult(
+            success=False,
+            content="Tool call blocked by replay policy.",
+            tool_name=tool_name,
+            error="replay_blocked",
+            raw_output=trace,
+        )
+
+    async def execute_tool_call(self, tool_call: Any, *, context: ToolContext | None = None) -> ToolResult:
+        """Split a tool-call object into name and arguments and route it through ``execute``."""
+        tool_name, arguments = ToolExecutor._normalize_tool_call(tool_call)
+        return await self.execute(tool_name, arguments, context=context)
+
+
+def classify_tool_mode(spec: ToolSpec, policy: ReplayToolPolicy | None = None) -> ToolExecutionMode:
+    """Classify a tool as ``"executed"``, ``"surrogate"`` or ``"blocked"`` for replay.
+
+    Rules are applied in order: a destructive term in the name blocks the call;
+    a safe toolset executes it; an external tool (surrogate transport, or toolset
+    ``mcp``/``connector``/``external``) is a surrogate when its name contains an
+    external-write term and executed otherwise; anything else is a surrogate.
+    Policy entries and tool attributes are compared case-insensitively.
+    """
+    policy = policy or ReplayToolPolicy()
+    name = spec.name.lower()
+    toolset = spec.toolset.lower()
+    metadata = {str(key).lower(): str(value).lower() for key, value in spec.metadata.items()}
+    # The tool's own fields are lower-cased above, so policy entries are folded the same
+    # way; otherwise a custom entry such as "Delete" or "MCP" could never match.
+    safe_toolsets = {entry.lower() for entry in policy.safe_toolsets}
+    surrogate_transports = {entry.lower() for entry in policy.surrogate_transports}
+    if any(term.lower() in name for term in policy.destructive_terms):
+        return "blocked"
+    if toolset in safe_toolsets:
+        return "executed"
+    if metadata.get("transport") in surrogate_transports or toolset in {"mcp", "connector", "external"}:
+        if any(term.lower() in name for term in policy.external_write_terms):
+            return "surrogate"
+        return "executed"
+    return "surrogate"
+
+
+def _classification_reason(spec: ToolSpec, mode: ToolExecutionMode) -> str:
+    """Build the human-readable explanation stored in each trace entry."""
+    return f"{spec.name} classified as {mode} from toolset={spec.toolset} metadata={spec.metadata}"
diff --git a/app-instance/backend/tests/unit/test_skill_learning_replay.py b/app-instance/backend/tests/unit/test_skill_learning_replay.py
new file mode 100644
index 0000000..ac67fa0
--- /dev/null
+++ b/app-instance/backend/tests/unit/test_skill_learning_replay.py
@@ -0,0 +1,98 @@
+from __future__ import annotations
+
+import asyncio
+
+from beaver.skills.learning.replay import ReplayToolExecutor, ReplayToolPolicy, classify_tool_mode
+from beaver.tools.base import BaseTool, ToolContext, ToolResult, ToolSpec
+from beaver.tools.registry.tool_registry import ToolRegistry
+from beaver.tools.runtime.executor import ToolExecutor
+
+
+class FakeTool(BaseTool):
+    def __init__(self, name: str, *, toolset: str = "filesystem", metadata: dict | None = None) -> None:
+        self._spec = ToolSpec(
+            name=name,
+            description=f"{name} tool",
+            input_schema={"type": "object", "properties": {"path": {"type": "string"}}},
+            toolset=toolset,
+            metadata=metadata or {},
+        )
+
+    @property
+    def spec(self) -> ToolSpec:
+        return self._spec
+
+    async def invoke(self, arguments: dict, context: ToolContext) -> ToolResult:
+        return ToolResult(success=True, content=f"executed:{arguments}", tool_name=self.spec.name)
+
+
+class ExplodingExecutor:
+    """Stand-in inner executor whose every call fails."""
+
+    async def execute(self, tool_name: str, arguments: dict, *, context: ToolContext | None = None) -> ToolResult:
+        raise RuntimeError("boom")
+
+
+def _executor(*tools: FakeTool) -> ReplayToolExecutor:
+    registry = ToolRegistry()
+    for tool in tools:
+        registry.register(tool)
+    return ReplayToolExecutor(ToolExecutor(registry), registry=registry, policy=ReplayToolPolicy())
+
+
+def test_classify_tool_modes_from_spec() -> None:
+    assert classify_tool_mode(FakeTool("read_file").spec) == "executed"
+    assert classify_tool_mode(FakeTool("write_file").spec) == "executed"
+    assert classify_tool_mode(FakeTool("mcp_outlook_send_email", toolset="mcp", metadata={"transport": "mcp"}).spec) == "surrogate"
+    assert classify_tool_mode(FakeTool("delete_account", toolset="mcp", metadata={"transport": "mcp"}).spec) == "blocked"
+
+
+def test_classify_tool_mode_matches_policy_entries_case_insensitively() -> None:
+    policy = ReplayToolPolicy(safe_toolsets={"FileSystem"}, destructive_terms=("Delete",))
+
+    assert classify_tool_mode(FakeTool("read_file").spec, policy) == "executed"
+    assert classify_tool_mode(FakeTool("delete_account", toolset="mcp").spec, policy) == "blocked"
+
+
+def test_replay_executor_executes_safe_tool_and_records_trace() -> None:
+    executor = _executor(FakeTool("write_file"))
+
+    result = asyncio.run(executor.execute("write_file", {"path": "a.txt"}, context=ToolContext(workspace="/tmp/replay")))
+
+    assert result.success is True
+    assert result.content.startswith("executed:")
+    assert executor.traces[0]["mode"] == "executed"
+    assert executor.traces[0]["tool_name"] == "write_file"
+
+
+def test_replay_executor_records_trace_when_inner_executor_raises() -> None:
+    registry = ToolRegistry()
+    registry.register(FakeTool("read_file"))
+    executor = ReplayToolExecutor(ExplodingExecutor(), registry=registry)
+
+    try:
+        asyncio.run(executor.execute("read_file", {"path": "a.txt"}))
+    except RuntimeError as exc:
+        assert str(exc) == "boom"
+    else:
+        raise AssertionError("inner executor error was swallowed")
+
+    assert [trace["mode"] for trace in executor.traces] == ["executed"]
+    assert executor.traces[0]["result"]["success"] is False
+    assert executor.traces[0]["result"]["error"] == "RuntimeError: boom"
+
+
+def test_replay_executor_surrogates_external_write_and_blocks_destructive() -> None:
+    executor = _executor(
+        FakeTool("mcp_outlook_send_email", toolset="mcp", metadata={"transport": "mcp"}),
+        FakeTool("delete_account", toolset="mcp", metadata={"transport": "mcp"}),
+    )
+
+    send = asyncio.run(executor.execute("mcp_outlook_send_email", {"to": "ada@example.com"}, context=ToolContext()))
+    delete = asyncio.run(executor.execute("delete_account", {"id": "1"}, context=ToolContext()))
+
+    assert send.success is True
+    assert send.error == "replay_surrogate"
+    assert delete.success is False
+    assert delete.error == "replay_blocked"
+    assert [trace["mode"] for trace in executor.traces] == ["surrogate", "blocked"]