diff --git a/app-instance/backend/beaver/skills/learning/__init__.py b/app-instance/backend/beaver/skills/learning/__init__.py
index 4bd4cd8..564aedb 100644
--- a/app-instance/backend/beaver/skills/learning/__init__.py
+++ b/app-instance/backend/beaver/skills/learning/__init__.py
@@ -11,7 +11,7 @@ from .missing_skill import (
 )
 from .pipeline import SkillLearningPipelineService
 from .preservation import check_preservation
-from .replay import ReplayToolExecutor, ReplayToolPolicy, classify_tool_mode
+from .replay import ReplayArmRequest, ReplayRunner, ReplayToolExecutor, ReplayToolPolicy, classify_tool_mode
 from .service import RunReceiptContext, SkillLearningService
 from .surrogate import SurrogateToolEvaluator
 from .synthesizer import SkillDraftSynthesizer
@@ -31,6 +31,8 @@ __all__ = [
     "check_preservation",
     "ReplayToolExecutor",
     "ReplayToolPolicy",
+    "ReplayArmRequest",
+    "ReplayRunner",
     "classify_tool_mode",
     "SurrogateToolEvaluator",
     "SkillDraftSynthesizer",
diff --git a/app-instance/backend/beaver/skills/learning/replay.py b/app-instance/backend/beaver/skills/learning/replay.py
index bce6cd1..e0a84d0 100644
--- a/app-instance/backend/beaver/skills/learning/replay.py
+++ b/app-instance/backend/beaver/skills/learning/replay.py
@@ -137,3 +137,86 @@ def classify_tool_mode(spec: ToolSpec, policy: ReplayToolPolicy | None = None) -
 
 def _classification_reason(spec: ToolSpec, mode: ToolExecutionMode) -> str:
     return f"{spec.name} classified as {mode} from toolset={spec.toolset} metadata={spec.metadata}"
+
+
+@dataclass(slots=True)
+class ReplayArmRequest:
+    """Inputs for replaying a single arm of one replay case."""
+
+    # Identifiers copied verbatim into the arm report.
+    case_id: str
+    arm: str
+    # Prompt passed to the agent loop; also echoed in the report.
+    task_text: str
+    # Forwarded unchanged to ``agent_loop.process_direct``.
+    pinned_skill_names: list[str] = field(default_factory=list)
+    pinned_skill_contexts: list[Any] = field(default_factory=list)
+    provider_bundle: Any | None = None
+    # Optional overrides; ``max_tool_iterations`` and ``temperature`` are read.
+    model_settings: dict[str, Any] = field(default_factory=dict)
+
+
+class ReplayRunner:
+    """Run one replay arm through the agent loop behind a ``ReplayToolExecutor``."""
+
+    def __init__(self, *, agent_loop: Any, policy: ReplayToolPolicy | None = None) -> None:
+        self.agent_loop = agent_loop
+        # Test against None so a supplied policy is never replaced by the
+        # default merely because it evaluates as falsy.
+        self.policy = ReplayToolPolicy() if policy is None else policy
+
+    async def run_arm(self, request: ReplayArmRequest) -> dict[str, Any]:
+        """Execute ``request`` and return its arm report.
+
+        The booted loop's tool executor is wrapped in a ``ReplayToolExecutor``
+        and handed to ``process_direct`` as ``tool_executor_override``; that
+        wrapper's ``traces`` supply ``tool_calls`` and ``side_effects``.
+        """
+        loaded = self.agent_loop.boot()
+        replay_executor = ReplayToolExecutor(
+            loaded.tool_executor,
+            registry=loaded.tool_registry,
+            policy=self.policy,
+        )
+        settings = request.model_settings
+        result = await self.agent_loop.process_direct(
+            request.task_text,
+            provider_bundle=request.provider_bundle,
+            include_skill_assembly=False,
+            include_tools=True,
+            pinned_skill_names=request.pinned_skill_names,
+            pinned_skill_contexts=request.pinned_skill_contexts,
+            # ``or`` rather than a ``.get`` default: None/empty values fall back too.
+            max_tool_iterations=int(settings.get("max_tool_iterations") or 4),
+            temperature=float(settings.get("temperature") or 0.0),
+            source="skill_replay_eval",
+            tool_executor_override=replay_executor,
+        )
+        # Snapshot once so both report fields describe the same set of calls.
+        traces = list(replay_executor.traces)
+        return {
+            "case_id": request.case_id,
+            "arm": request.arm,
+            "session_id": result.session_id,
+            "run_id": result.run_id,
+            "task_text": request.task_text,
+            "finish_reason": result.finish_reason,
+            "final_answer": result.output_text,
+            "tool_calls": traces,
+            "artifacts": [],
+            "side_effects": _side_effects_from_traces(traces),
+        }
+
+
+def _side_effects_from_traces(traces: list[dict[str, Any]]) -> list[dict[str, Any]]:
+    """Return a summary entry for each ``surrogate`` or ``blocked`` trace."""
+    return [
+        {
+            "tool_name": trace.get("tool_name"),
+            "mode": trace.get("mode"),
+            "arguments": trace.get("arguments"),
+            "classification_reason": trace.get("classification_reason"),
+        }
+        for trace in traces
+        if trace.get("mode") in {"surrogate", "blocked"}
+    ]
diff --git a/app-instance/backend/tests/unit/test_skill_learning_replay_runner.py b/app-instance/backend/tests/unit/test_skill_learning_replay_runner.py
new file mode 100644
index 0000000..38648d1
--- /dev/null
+++ b/app-instance/backend/tests/unit/test_skill_learning_replay_runner.py
@@ -0,0 +1,56 @@
+from __future__ import annotations
+
+import asyncio
+from types import SimpleNamespace
+
+from beaver.skills.learning.replay import ReplayArmRequest, ReplayRunner, _side_effects_from_traces
+
+
+class FakeAgentLoop:
+    """Agent-loop stand-in that issues one tool call through the override executor."""
+
+    def boot(self):
+        return SimpleNamespace(tool_executor=SimpleNamespace(), tool_registry=SimpleNamespace(get=lambda name: None))
+
+    async def process_direct(self, task: str, **kwargs):
+        executor = kwargs["tool_executor_override"]
+        await executor.execute("mcp_outlook_send_email", {"to": "ada@example.com"})
+        return SimpleNamespace(session_id="session-replay", run_id="run-replay", output_text="done", finish_reason="stop")
+
+
+def test_replay_runner_returns_arm_report_with_tool_trace() -> None:
+    runner = ReplayRunner(agent_loop=FakeAgentLoop())
+    request = ReplayArmRequest(
+        case_id="case-1",
+        arm="candidate",
+        task_text="Send a status email to Ada.",
+        pinned_skill_names=[],
+        pinned_skill_contexts=[],
+        provider_bundle=object(),
+        model_settings={"max_tool_iterations": 2},
+    )
+
+    report = asyncio.run(runner.run_arm(request))
+
+    assert report["case_id"] == "case-1"
+    assert report["arm"] == "candidate"
+    assert report["session_id"] == "session-replay"
+    assert report["run_id"] == "run-replay"
+    assert report["task_text"] == "Send a status email to Ada."
+    assert report["finish_reason"] == "stop"
+    assert report["final_answer"] == "done"
+    assert report["artifacts"] == []
+    assert report["tool_calls"][0]["tool_name"] == "mcp_outlook_send_email"
+
+
+def test_side_effects_only_cover_surrogate_and_blocked_traces() -> None:
+    traces = [
+        {"tool_name": "read", "mode": "real", "arguments": {}},
+        {"tool_name": "send", "mode": "surrogate", "arguments": {"to": "ada"}, "classification_reason": "writes"},
+        {"tool_name": "wipe", "mode": "blocked"},
+    ]
+
+    assert _side_effects_from_traces(traces) == [
+        {"tool_name": "send", "mode": "surrogate", "arguments": {"to": "ada"}, "classification_reason": "writes"},
+        {"tool_name": "wipe", "mode": "blocked", "arguments": None, "classification_reason": None},
+    ]