feat(skill-learning): run replay arms through agent loop

This commit is contained in:
2026-06-08 13:33:53 +08:00
parent 4c8bc53d33
commit cc1bf85517
3 changed files with 103 additions and 1 deletions

View File

@ -11,7 +11,7 @@ from .missing_skill import (
)
from .pipeline import SkillLearningPipelineService
from .preservation import check_preservation
from .replay import ReplayToolExecutor, ReplayToolPolicy, classify_tool_mode
from .replay import ReplayArmRequest, ReplayRunner, ReplayToolExecutor, ReplayToolPolicy, classify_tool_mode
from .service import RunReceiptContext, SkillLearningService
from .surrogate import SurrogateToolEvaluator
from .synthesizer import SkillDraftSynthesizer
@ -31,6 +31,8 @@ __all__ = [
"check_preservation",
"ReplayToolExecutor",
"ReplayToolPolicy",
"ReplayArmRequest",
"ReplayRunner",
"classify_tool_mode",
"SurrogateToolEvaluator",
"SkillDraftSynthesizer",

View File

@ -137,3 +137,67 @@ def classify_tool_mode(spec: ToolSpec, policy: ReplayToolPolicy | None = None) -
def _classification_reason(spec: ToolSpec, mode: ToolExecutionMode) -> str:
return f"{spec.name} classified as {mode} from toolset={spec.toolset} metadata={spec.metadata}"
@dataclass(slots=True)
class ReplayArmRequest:
case_id: str
arm: str
task_text: str
pinned_skill_names: list[str] = field(default_factory=list)
pinned_skill_contexts: list[Any] = field(default_factory=list)
provider_bundle: Any | None = None
model_settings: dict[str, Any] = field(default_factory=dict)
class ReplayRunner:
def __init__(self, *, agent_loop: Any, policy: ReplayToolPolicy | None = None) -> None:
self.agent_loop = agent_loop
self.policy = policy or ReplayToolPolicy()
async def run_arm(self, request: ReplayArmRequest) -> dict[str, Any]:
loaded = self.agent_loop.boot()
replay_executor = ReplayToolExecutor(
loaded.tool_executor,
registry=loaded.tool_registry,
policy=self.policy,
)
result = await self.agent_loop.process_direct(
request.task_text,
provider_bundle=request.provider_bundle,
include_skill_assembly=False,
include_tools=True,
pinned_skill_names=request.pinned_skill_names,
pinned_skill_contexts=request.pinned_skill_contexts,
max_tool_iterations=int(request.model_settings.get("max_tool_iterations") or 4),
temperature=float(request.model_settings.get("temperature") or 0.0),
source="skill_replay_eval",
tool_executor_override=replay_executor,
)
return {
"case_id": request.case_id,
"arm": request.arm,
"session_id": result.session_id,
"run_id": result.run_id,
"task_text": request.task_text,
"finish_reason": result.finish_reason,
"final_answer": result.output_text,
"tool_calls": list(replay_executor.traces),
"artifacts": [],
"side_effects": _side_effects_from_traces(replay_executor.traces),
}
def _side_effects_from_traces(traces: list[dict[str, Any]]) -> list[dict[str, Any]]:
effects: list[dict[str, Any]] = []
for trace in traces:
if trace.get("mode") in {"surrogate", "blocked"}:
effects.append(
{
"tool_name": trace.get("tool_name"),
"mode": trace.get("mode"),
"arguments": trace.get("arguments"),
"classification_reason": trace.get("classification_reason"),
}
)
return effects

View File

@ -0,0 +1,36 @@
from __future__ import annotations
import asyncio
from types import SimpleNamespace
from beaver.skills.learning.replay import ReplayArmRequest, ReplayRunner
class FakeAgentLoop:
def boot(self):
return SimpleNamespace(tool_executor=SimpleNamespace(), tool_registry=SimpleNamespace(get=lambda name: None))
async def process_direct(self, task: str, **kwargs):
executor = kwargs["tool_executor_override"]
await executor.execute("mcp_outlook_send_email", {"to": "ada@example.com"})
return SimpleNamespace(session_id="session-replay", run_id="run-replay", output_text="done", finish_reason="stop")
def test_replay_runner_returns_arm_report_with_tool_trace() -> None:
runner = ReplayRunner(agent_loop=FakeAgentLoop())
request = ReplayArmRequest(
case_id="case-1",
arm="candidate",
task_text="Send a status email to Ada.",
pinned_skill_names=[],
pinned_skill_contexts=[],
provider_bundle=object(),
model_settings={"max_tool_iterations": 2},
)
report = asyncio.run(runner.run_arm(request))
assert report["case_id"] == "case-1"
assert report["arm"] == "candidate"
assert report["finish_reason"] == "stop"
assert report["tool_calls"][0]["tool_name"] == "mcp_outlook_send_email"