feat(skill-learning): add replay tool policy
This commit is contained in:
@ -11,6 +11,7 @@ from .missing_skill import (
|
|||||||
)
|
)
|
||||||
from .pipeline import SkillLearningPipelineService
|
from .pipeline import SkillLearningPipelineService
|
||||||
from .preservation import check_preservation
|
from .preservation import check_preservation
|
||||||
|
from .replay import ReplayToolExecutor, ReplayToolPolicy, classify_tool_mode
|
||||||
from .service import RunReceiptContext, SkillLearningService
|
from .service import RunReceiptContext, SkillLearningService
|
||||||
from .synthesizer import SkillDraftSynthesizer
|
from .synthesizer import SkillDraftSynthesizer
|
||||||
from .worker import SkillLearningWorker, SkillLearningWorkerConfig, SkillLearningWorkerResult
|
from .worker import SkillLearningWorker, SkillLearningWorkerConfig, SkillLearningWorkerResult
|
||||||
@ -27,6 +28,9 @@ __all__ = [
|
|||||||
"RunReceiptContext",
|
"RunReceiptContext",
|
||||||
"SkillLearningPipelineService",
|
"SkillLearningPipelineService",
|
||||||
"check_preservation",
|
"check_preservation",
|
||||||
|
"ReplayToolExecutor",
|
||||||
|
"ReplayToolPolicy",
|
||||||
|
"classify_tool_mode",
|
||||||
"SkillDraftSynthesizer",
|
"SkillDraftSynthesizer",
|
||||||
"SkillLearningService",
|
"SkillLearningService",
|
||||||
"SkillLearningWorker",
|
"SkillLearningWorker",
|
||||||
|
|||||||
139
app-instance/backend/beaver/skills/learning/replay.py
Normal file
139
app-instance/backend/beaver/skills/learning/replay.py
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
"""Replay execution helpers for skill draft evaluation."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Literal
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from beaver.tools.base import ToolContext, ToolResult, ToolSpec
|
||||||
|
from beaver.tools.registry.tool_registry import ToolRegistry
|
||||||
|
from beaver.tools.runtime.executor import ToolExecutor
|
||||||
|
|
||||||
|
# How a tool call is handled during replay: run through the wrapped executor
# ("executed"), recorded and answered with a synthetic success ("surrogate"),
# or refused with a synthetic failure ("blocked").
ToolExecutionMode = Literal["executed", "surrogate", "blocked"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class ReplayToolPolicy:
|
||||||
|
safe_toolsets: set[str] = field(default_factory=lambda: {"filesystem", "user_files", "core", "web", "search"})
|
||||||
|
surrogate_transports: set[str] = field(default_factory=lambda: {"mcp", "connector"})
|
||||||
|
destructive_terms: tuple[str, ...] = (
|
||||||
|
"delete",
|
||||||
|
"remove",
|
||||||
|
"destroy",
|
||||||
|
"revoke",
|
||||||
|
"permission",
|
||||||
|
"credential",
|
||||||
|
"payment",
|
||||||
|
"pay",
|
||||||
|
)
|
||||||
|
external_write_terms: tuple[str, ...] = (
|
||||||
|
"send",
|
||||||
|
"post",
|
||||||
|
"publish",
|
||||||
|
"create",
|
||||||
|
"update",
|
||||||
|
"invite",
|
||||||
|
"reply",
|
||||||
|
"forward",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ReplayToolExecutor:
    """Policy-gated wrapper around a ``ToolExecutor`` for replay runs.

    Each call is classified with ``classify_tool_mode`` and is then either
    forwarded to the wrapped executor, answered with a synthetic "surrogate"
    success, or refused. A trace dict describing every call is appended to
    ``self.traces`` in call order.
    """

    def __init__(
        self,
        inner: ToolExecutor,
        *,
        registry: ToolRegistry,
        policy: ReplayToolPolicy | None = None,
    ) -> None:
        """Store the wrapped executor, the registry used for spec lookup, and the policy.

        Args:
            inner: Executor that runs tools classified as "executed".
            registry: Registry queried by tool name to obtain the tool's spec.
            policy: Classification rules; a default ``ReplayToolPolicy`` is
                created when omitted.
        """
        self.inner = inner
        self.registry = registry
        self.policy = policy or ReplayToolPolicy()
        self.traces: list[dict[str, Any]] = []

    async def execute(
        self,
        tool_name: str,
        arguments: dict[str, Any] | None,
        *,
        context: ToolContext | None = None,
    ) -> ToolResult:
        """Classify one tool call, handle it per the policy, and record a trace.

        Args:
            tool_name: Name used to look the tool up in the registry.
            arguments: Call arguments; ``None`` is treated as an empty dict.
            context: Passed through to the wrapped executor for executed calls.

        Returns:
            The wrapped executor's result for "executed" calls, otherwise a
            synthetic ``ToolResult`` whose ``error`` is ``"replay_surrogate"``
            or ``"replay_blocked"`` and whose ``raw_output`` is the trace.
        """
        tool = self.registry.get(tool_name)
        if tool is not None:
            spec = tool.spec
        else:
            # Unregistered tools get a placeholder spec so the call is still
            # classified and traced instead of failing on the lookup.
            spec = ToolSpec(
                name=tool_name,
                description="unregistered tool",
                input_schema={"type": "object", "properties": {}},
                toolset="unknown",
            )
        mode = classify_tool_mode(spec, self.policy)
        trace: dict[str, Any] = {
            "trace_id": uuid4().hex,
            "tool_name": tool_name,
            "mode": mode,
            "arguments": dict(arguments or {}),
            "schema": dict(spec.input_schema),
            "toolset": spec.toolset,
            "metadata": dict(spec.metadata),
            "classification_reason": _classification_reason(spec, mode),
        }

        if mode == "executed":
            result = await self.inner.execute(tool_name, arguments or {}, context=context)
            trace["result"] = {
                "success": result.success,
                "error": result.error,
                # Only the first 2000 elements of the content are kept in the
                # trace (assumes content is sliceable, presumably a str - TODO confirm).
                "content": result.content[:2000],
            }
            self.traces.append(trace)
            return result

        if mode == "surrogate":
            return self._synthetic_result(
                trace,
                success=True,
                error="replay_surrogate",
                content="Tool call recorded for surrogate evaluation.",
            )

        return self._synthetic_result(
            trace,
            success=False,
            error="replay_blocked",
            content="Tool call blocked by replay policy.",
        )

    def _synthetic_result(
        self,
        trace: dict[str, Any],
        *,
        success: bool,
        error: str,
        content: str,
    ) -> ToolResult:
        """Record a non-executed outcome on ``trace`` and build the matching result.

        Keeps the trace entry and the returned ``ToolResult`` in sync by
        deriving both from the same ``success``/``error``/``content`` values.
        """
        trace["result"] = {
            "success": success,
            "error": error,
            "content": content,
        }
        self.traces.append(trace)
        return ToolResult(
            success=success,
            content=content,
            tool_name=trace["tool_name"],
            error=error,
            raw_output=trace,
        )

    async def execute_tool_call(self, tool_call: Any, *, context: ToolContext | None = None) -> ToolResult:
        """Unpack a tool-call object into name and arguments and run ``execute``.

        Relies on ``ToolExecutor._normalize_tool_call`` returning a
        ``(tool_name, arguments)`` pair.
        """
        tool_name, arguments = ToolExecutor._normalize_tool_call(tool_call)
        return await self.execute(tool_name, arguments, context=context)
|
||||||
|
|
||||||
|
|
||||||
|
def classify_tool_mode(spec: ToolSpec, policy: ReplayToolPolicy | None = None) -> ToolExecutionMode:
    """Decide whether a tool is executed, surrogated, or blocked during replay.

    Rules are applied in order:

    1. A tool whose name contains any destructive term is ``"blocked"``.
    2. A tool in a safe toolset is ``"executed"``.
    3. A tool that is external (its ``transport`` metadata is a surrogate
       transport, or its toolset is ``mcp``/``connector``/``external``) is
       ``"surrogate"`` when its name contains an external-write term and
       ``"executed"`` otherwise.
    4. Anything else is ``"surrogate"``.

    All comparisons are case-insensitive: both the spec values and the policy
    values are lowercased, so a custom policy written with mixed-case terms
    still matches.

    Args:
        spec: Tool spec providing ``name``, ``toolset`` and ``metadata``.
        policy: Classification rules; a default ``ReplayToolPolicy`` is used
            when omitted.

    Returns:
        One of ``"executed"``, ``"surrogate"`` or ``"blocked"``.
    """
    if policy is None:
        policy = ReplayToolPolicy()

    name = spec.name.lower()
    toolset = spec.toolset.lower()
    metadata = {str(key).lower(): str(value).lower() for key, value in spec.metadata.items()}

    # Lowercase the policy side too; the spec side is already lowercased, so
    # without this a policy term such as "Delete" could never match.
    destructive_terms = [term.lower() for term in policy.destructive_terms]
    external_write_terms = [term.lower() for term in policy.external_write_terms]
    safe_toolsets = {entry.lower() for entry in policy.safe_toolsets}
    surrogate_transports = {entry.lower() for entry in policy.surrogate_transports}

    # Destructive names are refused first, regardless of toolset.
    if any(term in name for term in destructive_terms):
        return "blocked"
    if toolset in safe_toolsets:
        return "executed"
    if metadata.get("transport") in surrogate_transports or toolset in {"mcp", "connector", "external"}:
        if any(term in name for term in external_write_terms):
            return "surrogate"
        return "executed"
    # Tools that match none of the rules are recorded rather than run.
    return "surrogate"
|
||||||
|
|
||||||
|
|
||||||
|
def _classification_reason(spec: ToolSpec, mode: ToolExecutionMode) -> str:
|
||||||
|
return f"{spec.name} classified as {mode} from toolset={spec.toolset} metadata={spec.metadata}"
|
||||||
@ -0,0 +1,67 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from beaver.skills.learning.replay import ReplayToolExecutor, ReplayToolPolicy, classify_tool_mode
|
||||||
|
from beaver.tools.base import BaseTool, ToolContext, ToolResult, ToolSpec
|
||||||
|
from beaver.tools.registry.tool_registry import ToolRegistry
|
||||||
|
from beaver.tools.runtime.executor import ToolExecutor
|
||||||
|
|
||||||
|
|
||||||
|
class FakeTool(BaseTool):
    """In-memory tool with a configurable spec, used to drive replay tests."""

    def __init__(self, name: str, *, toolset: str = "filesystem", metadata: dict | None = None) -> None:
        """Build the spec for this fake tool.

        Args:
            name: Tool name; also used in the generated description.
            toolset: Toolset recorded on the spec.
            metadata: Spec metadata; an empty dict is used when omitted.
        """
        self._spec = ToolSpec(
            name=name,
            description=f"{name} tool",
            input_schema={"type": "object", "properties": {"path": {"type": "string"}}},
            toolset=toolset,
            metadata=metadata or {},
        )

    @property
    def spec(self) -> ToolSpec:
        """Return the spec built in ``__init__``."""
        return self._spec

    async def invoke(self, arguments: dict, context: ToolContext) -> ToolResult:
        """Return a successful result whose content echoes the arguments."""
        return ToolResult(success=True, content=f"executed:{arguments}", tool_name=self.spec.name)
|
||||||
|
|
||||||
|
|
||||||
|
def _executor(*tools: FakeTool) -> ReplayToolExecutor:
    """Build a ``ReplayToolExecutor`` over a fresh registry holding ``tools``.

    The same registry backs both the wrapped ``ToolExecutor`` and the replay
    executor's spec lookup; the default ``ReplayToolPolicy`` is used.
    """
    registry = ToolRegistry()
    for tool in tools:
        registry.register(tool)
    return ReplayToolExecutor(ToolExecutor(registry), registry=registry, policy=ReplayToolPolicy())
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_tool_modes_from_spec() -> None:
    """Check the classification of safe, external-write, and destructive tools."""
    # FakeTool defaults to the "filesystem" toolset, which the default policy executes.
    assert classify_tool_mode(FakeTool("read_file").spec) == "executed"
    assert classify_tool_mode(FakeTool("write_file").spec) == "executed"
    # External transport plus a write term ("send") is recorded, not run.
    assert classify_tool_mode(FakeTool("mcp_outlook_send_email", toolset="mcp", metadata={"transport": "mcp"}).spec) == "surrogate"
    # A destructive term ("delete") in the name is blocked.
    assert classify_tool_mode(FakeTool("delete_account", toolset="mcp", metadata={"transport": "mcp"}).spec) == "blocked"
|
||||||
|
|
||||||
|
|
||||||
|
def test_replay_executor_executes_safe_tool_and_records_trace() -> None:
    """A safe-toolset tool is run through the wrapped executor and traced."""
    executor = _executor(FakeTool("write_file"))

    result = asyncio.run(executor.execute("write_file", {"path": "a.txt"}, context=ToolContext(workspace="/tmp/replay")))

    # The content prefix comes from FakeTool.invoke, showing the real tool ran.
    assert result.success is True
    assert result.content.startswith("executed:")
    assert executor.traces[0]["mode"] == "executed"
    assert executor.traces[0]["tool_name"] == "write_file"
|
||||||
|
|
||||||
|
|
||||||
|
def test_replay_executor_surrogates_external_write_and_blocks_destructive() -> None:
    """External writes get a surrogate result and destructive calls are blocked."""
    executor = _executor(
        FakeTool("mcp_outlook_send_email", toolset="mcp", metadata={"transport": "mcp"}),
        FakeTool("delete_account", toolset="mcp", metadata={"transport": "mcp"}),
    )

    send = asyncio.run(executor.execute("mcp_outlook_send_email", {"to": "ada@example.com"}, context=ToolContext()))
    delete = asyncio.run(executor.execute("delete_account", {"id": "1"}, context=ToolContext()))

    # A surrogate reports success but carries the "replay_surrogate" marker.
    assert send.success is True
    assert send.error == "replay_surrogate"
    assert delete.success is False
    assert delete.error == "replay_blocked"
    # Traces are appended in call order.
    assert [trace["mode"] for trace in executor.traces] == ["surrogate", "blocked"]
|
||||||
Reference in New Issue
Block a user