feat(skill-learning): add replay tool policy

This commit is contained in:
2026-06-08 13:31:13 +08:00
parent 7287e93f87
commit eb69bb168a
3 changed files with 210 additions and 0 deletions

View File

@ -11,6 +11,7 @@ from .missing_skill import (
) )
from .pipeline import SkillLearningPipelineService from .pipeline import SkillLearningPipelineService
from .preservation import check_preservation from .preservation import check_preservation
from .replay import ReplayToolExecutor, ReplayToolPolicy, classify_tool_mode
from .service import RunReceiptContext, SkillLearningService from .service import RunReceiptContext, SkillLearningService
from .synthesizer import SkillDraftSynthesizer from .synthesizer import SkillDraftSynthesizer
from .worker import SkillLearningWorker, SkillLearningWorkerConfig, SkillLearningWorkerResult from .worker import SkillLearningWorker, SkillLearningWorkerConfig, SkillLearningWorkerResult
@ -27,6 +28,9 @@ __all__ = [
"RunReceiptContext", "RunReceiptContext",
"SkillLearningPipelineService", "SkillLearningPipelineService",
"check_preservation", "check_preservation",
"ReplayToolExecutor",
"ReplayToolPolicy",
"classify_tool_mode",
"SkillDraftSynthesizer", "SkillDraftSynthesizer",
"SkillLearningService", "SkillLearningService",
"SkillLearningWorker", "SkillLearningWorker",

View File

@ -0,0 +1,139 @@
"""Replay execution helpers for skill draft evaluation."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Literal
from uuid import uuid4
from beaver.tools.base import ToolContext, ToolResult, ToolSpec
from beaver.tools.registry.tool_registry import ToolRegistry
from beaver.tools.runtime.executor import ToolExecutor
ToolExecutionMode = Literal["executed", "surrogate", "blocked"]
@dataclass(slots=True)
class ReplayToolPolicy:
safe_toolsets: set[str] = field(default_factory=lambda: {"filesystem", "user_files", "core", "web", "search"})
surrogate_transports: set[str] = field(default_factory=lambda: {"mcp", "connector"})
destructive_terms: tuple[str, ...] = (
"delete",
"remove",
"destroy",
"revoke",
"permission",
"credential",
"payment",
"pay",
)
external_write_terms: tuple[str, ...] = (
"send",
"post",
"publish",
"create",
"update",
"invite",
"reply",
"forward",
)
class ReplayToolExecutor:
def __init__(
self,
inner: ToolExecutor,
*,
registry: ToolRegistry,
policy: ReplayToolPolicy | None = None,
) -> None:
self.inner = inner
self.registry = registry
self.policy = policy or ReplayToolPolicy()
self.traces: list[dict[str, Any]] = []
async def execute(
self,
tool_name: str,
arguments: dict[str, Any] | None,
*,
context: ToolContext | None = None,
) -> ToolResult:
tool = self.registry.get(tool_name)
spec = tool.spec if tool is not None else ToolSpec(
name=tool_name,
description="unregistered tool",
input_schema={"type": "object", "properties": {}},
toolset="unknown",
)
mode = classify_tool_mode(spec, self.policy)
trace = {
"trace_id": uuid4().hex,
"tool_name": tool_name,
"mode": mode,
"arguments": dict(arguments or {}),
"schema": dict(spec.input_schema),
"toolset": spec.toolset,
"metadata": dict(spec.metadata),
"classification_reason": _classification_reason(spec, mode),
}
if mode == "executed":
result = await self.inner.execute(tool_name, arguments or {}, context=context)
trace["result"] = {
"success": result.success,
"error": result.error,
"content": result.content[:2000],
}
self.traces.append(trace)
return result
if mode == "surrogate":
trace["result"] = {
"success": True,
"error": "replay_surrogate",
"content": "Tool call recorded for surrogate evaluation.",
}
self.traces.append(trace)
return ToolResult(
success=True,
content="Tool call recorded for surrogate evaluation.",
tool_name=tool_name,
error="replay_surrogate",
raw_output=trace,
)
trace["result"] = {
"success": False,
"error": "replay_blocked",
"content": "Tool call blocked by replay policy.",
}
self.traces.append(trace)
return ToolResult(
success=False,
content="Tool call blocked by replay policy.",
tool_name=tool_name,
error="replay_blocked",
raw_output=trace,
)
async def execute_tool_call(self, tool_call: Any, *, context: ToolContext | None = None) -> ToolResult:
tool_name, arguments = ToolExecutor._normalize_tool_call(tool_call)
return await self.execute(tool_name, arguments, context=context)
def classify_tool_mode(spec: ToolSpec, policy: ReplayToolPolicy | None = None) -> ToolExecutionMode:
policy = policy or ReplayToolPolicy()
name = spec.name.lower()
toolset = spec.toolset.lower()
metadata = {str(key).lower(): str(value).lower() for key, value in spec.metadata.items()}
if any(term in name for term in policy.destructive_terms):
return "blocked"
if toolset in policy.safe_toolsets:
return "executed"
if metadata.get("transport") in policy.surrogate_transports or toolset in {"mcp", "connector", "external"}:
if any(term in name for term in policy.external_write_terms):
return "surrogate"
return "executed"
return "surrogate"
def _classification_reason(spec: ToolSpec, mode: ToolExecutionMode) -> str:
return f"{spec.name} classified as {mode} from toolset={spec.toolset} metadata={spec.metadata}"

View File

@ -0,0 +1,67 @@
from __future__ import annotations
import asyncio
from beaver.skills.learning.replay import ReplayToolExecutor, ReplayToolPolicy, classify_tool_mode
from beaver.tools.base import BaseTool, ToolContext, ToolResult, ToolSpec
from beaver.tools.registry.tool_registry import ToolRegistry
from beaver.tools.runtime.executor import ToolExecutor
class FakeTool(BaseTool):
def __init__(self, name: str, *, toolset: str = "filesystem", metadata: dict | None = None) -> None:
self._spec = ToolSpec(
name=name,
description=f"{name} tool",
input_schema={"type": "object", "properties": {"path": {"type": "string"}}},
toolset=toolset,
metadata=metadata or {},
)
@property
def spec(self) -> ToolSpec:
return self._spec
async def invoke(self, arguments: dict, context: ToolContext) -> ToolResult:
return ToolResult(success=True, content=f"executed:{arguments}", tool_name=self.spec.name)
def _executor(*tools: FakeTool) -> ReplayToolExecutor:
registry = ToolRegistry()
for tool in tools:
registry.register(tool)
return ReplayToolExecutor(ToolExecutor(registry), registry=registry, policy=ReplayToolPolicy())
def test_classify_tool_modes_from_spec() -> None:
assert classify_tool_mode(FakeTool("read_file").spec) == "executed"
assert classify_tool_mode(FakeTool("write_file").spec) == "executed"
assert classify_tool_mode(FakeTool("mcp_outlook_send_email", toolset="mcp", metadata={"transport": "mcp"}).spec) == "surrogate"
assert classify_tool_mode(FakeTool("delete_account", toolset="mcp", metadata={"transport": "mcp"}).spec) == "blocked"
def test_replay_executor_executes_safe_tool_and_records_trace() -> None:
executor = _executor(FakeTool("write_file"))
result = asyncio.run(executor.execute("write_file", {"path": "a.txt"}, context=ToolContext(workspace="/tmp/replay")))
assert result.success is True
assert result.content.startswith("executed:")
assert executor.traces[0]["mode"] == "executed"
assert executor.traces[0]["tool_name"] == "write_file"
def test_replay_executor_surrogates_external_write_and_blocks_destructive() -> None:
executor = _executor(
FakeTool("mcp_outlook_send_email", toolset="mcp", metadata={"transport": "mcp"}),
FakeTool("delete_account", toolset="mcp", metadata={"transport": "mcp"}),
)
send = asyncio.run(executor.execute("mcp_outlook_send_email", {"to": "ada@example.com"}, context=ToolContext()))
delete = asyncio.run(executor.execute("delete_account", {"id": "1"}, context=ToolContext()))
assert send.success is True
assert send.error == "replay_surrogate"
assert delete.success is False
assert delete.error == "replay_blocked"
assert [trace["mode"] for trace in executor.traces] == ["surrogate", "blocked"]