# Source listing metadata: 68 lines, 2.8 KiB, Python.
from __future__ import annotations
|
|
|
|
import asyncio
|
|
|
|
from beaver.skills.learning.replay import ReplayToolExecutor, ReplayToolPolicy, classify_tool_mode
|
|
from beaver.tools.base import BaseTool, ToolContext, ToolResult, ToolSpec
|
|
from beaver.tools.registry.tool_registry import ToolRegistry
|
|
from beaver.tools.runtime.executor import ToolExecutor
|
|
|
|
|
|
class FakeTool(BaseTool):
    """Minimal stand-in tool whose invocation simply echoes its arguments."""

    def __init__(self, name: str, *, toolset: str = "filesystem", metadata: dict | None = None) -> None:
        """Build and store the tool's spec.

        Args:
            name: Tool name; also used to derive the description.
            toolset: Toolset label recorded on the spec.
            metadata: Optional spec metadata; a missing or empty value
                becomes a new empty dict.
        """
        spec_metadata = metadata or {}
        # Every fake tool advertises the same single-string ``path`` schema.
        path_schema = {"type": "object", "properties": {"path": {"type": "string"}}}
        self._spec = ToolSpec(
            name=name,
            description=f"{name} tool",
            input_schema=path_schema,
            toolset=toolset,
            metadata=spec_metadata,
        )

    @property
    def spec(self) -> ToolSpec:
        """Return the spec created in ``__init__``."""
        return self._spec

    async def invoke(self, arguments: dict, context: ToolContext) -> ToolResult:
        """Return a successful result echoing ``arguments``; ``context`` is unused."""
        echoed = f"executed:{arguments}"
        return ToolResult(success=True, content=echoed, tool_name=self.spec.name)
|
|
|
|
|
|
def _executor(*tools: FakeTool) -> ReplayToolExecutor:
    """Return a ReplayToolExecutor over a new registry containing ``tools``.

    The same registry instance backs both the wrapped ToolExecutor and the
    replay executor, and a default-constructed ReplayToolPolicy is used.
    """
    tool_registry = ToolRegistry()
    for fake in tools:
        tool_registry.register(fake)
    wrapped = ToolExecutor(tool_registry)
    return ReplayToolExecutor(wrapped, registry=tool_registry, policy=ReplayToolPolicy())
|
|
|
|
|
|
def test_classify_tool_modes_from_spec() -> None:
    """Check the mode ``classify_tool_mode`` reports for four sample specs."""
    cases = [
        (FakeTool("read_file"), "executed"),
        (FakeTool("write_file"), "executed"),
        (FakeTool("mcp_outlook_send_email", toolset="mcp", metadata={"transport": "mcp"}), "surrogate"),
        (FakeTool("delete_account", toolset="mcp", metadata={"transport": "mcp"}), "blocked"),
    ]
    for tool, expected_mode in cases:
        assert classify_tool_mode(tool.spec) == expected_mode
|
|
|
|
|
|
def test_replay_executor_executes_safe_tool_and_records_trace() -> None:
    """A registered ``write_file`` tool is invoked and traced with mode ``executed``."""
    replay = _executor(FakeTool("write_file"))
    ctx = ToolContext(workspace="/tmp/replay")

    outcome = asyncio.run(replay.execute("write_file", {"path": "a.txt"}, context=ctx))

    # The ``executed:`` prefix is what FakeTool.invoke puts in its content.
    assert outcome.success is True
    assert outcome.content.startswith("executed:")

    first_trace = replay.traces[0]
    assert first_trace["mode"] == "executed"
    assert first_trace["tool_name"] == "write_file"
|
|
|
|
|
|
def test_replay_executor_surrogates_external_write_and_blocks_destructive() -> None:
    """An MCP send-email call is surrogated and an MCP delete-account call is blocked."""
    email_tool = FakeTool("mcp_outlook_send_email", toolset="mcp", metadata={"transport": "mcp"})
    delete_tool = FakeTool("delete_account", toolset="mcp", metadata={"transport": "mcp"})
    replay = _executor(email_tool, delete_tool)

    # Each call runs in its own event loop, in this order, so the trace order is fixed.
    send_result = asyncio.run(
        replay.execute("mcp_outlook_send_email", {"to": "ada@example.com"}, context=ToolContext())
    )
    delete_result = asyncio.run(
        replay.execute("delete_account", {"id": "1"}, context=ToolContext())
    )

    assert send_result.success is True
    assert send_result.error == "replay_surrogate"
    assert delete_result.success is False
    assert delete_result.error == "replay_blocked"

    recorded_modes = [entry["mode"] for entry in replay.traces]
    assert recorded_modes == ["surrogate", "blocked"]
|