feat(skill-learning): run replay arms through agent loop
This commit is contained in:
@ -11,7 +11,7 @@ from .missing_skill import (
|
||||
)
|
||||
from .pipeline import SkillLearningPipelineService
|
||||
from .preservation import check_preservation
|
||||
from .replay import ReplayToolExecutor, ReplayToolPolicy, classify_tool_mode
|
||||
from .replay import ReplayArmRequest, ReplayRunner, ReplayToolExecutor, ReplayToolPolicy, classify_tool_mode
|
||||
from .service import RunReceiptContext, SkillLearningService
|
||||
from .surrogate import SurrogateToolEvaluator
|
||||
from .synthesizer import SkillDraftSynthesizer
|
||||
@ -31,6 +31,8 @@ __all__ = [
|
||||
"check_preservation",
|
||||
"ReplayToolExecutor",
|
||||
"ReplayToolPolicy",
|
||||
"ReplayArmRequest",
|
||||
"ReplayRunner",
|
||||
"classify_tool_mode",
|
||||
"SurrogateToolEvaluator",
|
||||
"SkillDraftSynthesizer",
|
||||
|
||||
@ -137,3 +137,67 @@ def classify_tool_mode(spec: ToolSpec, policy: ReplayToolPolicy | None = None) -
|
||||
|
||||
def _classification_reason(spec: ToolSpec, mode: ToolExecutionMode) -> str:
    """Build the human-readable explanation recorded for a tool classification.

    Args:
        spec: Tool description; its ``name``, ``toolset`` and ``metadata``
            attributes are interpolated into the message.
        mode: Execution mode that was chosen for the tool.

    Returns:
        A single-line string naming the tool, the chosen mode, and the
        toolset/metadata the decision was derived from.
    """
    subject = f"{spec.name} classified as {mode}"
    provenance = f"toolset={spec.toolset} metadata={spec.metadata}"
    return f"{subject} from {provenance}"
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ReplayArmRequest:
|
||||
case_id: str
|
||||
arm: str
|
||||
task_text: str
|
||||
pinned_skill_names: list[str] = field(default_factory=list)
|
||||
pinned_skill_contexts: list[Any] = field(default_factory=list)
|
||||
provider_bundle: Any | None = None
|
||||
model_settings: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class ReplayRunner:
    """Run replay arms through an agent loop with a replay-wrapped tool executor.

    For each arm the runner boots the agent loop, wraps the booted tool
    executor in a ``ReplayToolExecutor`` governed by ``policy``, drives one
    ``process_direct`` call, and packages the outcome as a report dict.
    """

    def __init__(self, *, agent_loop: Any, policy: ReplayToolPolicy | None = None) -> None:
        """Store the agent loop and the replay policy.

        Args:
            agent_loop: Object providing ``boot()`` and an awaitable
                ``process_direct(...)``.
            policy: Replay tool policy; a default ``ReplayToolPolicy()`` is
                created only when this is ``None``.
        """
        self.agent_loop = agent_loop
        # Compare against None explicitly so that a caller-supplied policy is
        # never discarded just because it happens to evaluate as falsy.
        self.policy = policy if policy is not None else ReplayToolPolicy()

    async def run_arm(self, request: ReplayArmRequest) -> dict[str, Any]:
        """Execute one replay arm and return its report.

        Args:
            request: Case/arm identity, task text, pinned skills, provider
                bundle and model settings for this run.

        Returns:
            A dict with the keys ``case_id``, ``arm``, ``session_id``,
            ``run_id``, ``task_text``, ``finish_reason``, ``final_answer``,
            ``tool_calls``, ``artifacts`` (always an empty list here) and
            ``side_effects``.
        """
        loaded = self.agent_loop.boot()
        replay_executor = ReplayToolExecutor(
            loaded.tool_executor,
            registry=loaded.tool_registry,
            policy=self.policy,
        )

        settings = request.model_settings
        # ``or`` means a missing, None, or zero setting falls back to the
        # default (4 tool iterations, temperature 0.0).
        max_tool_iterations = int(settings.get("max_tool_iterations") or 4)
        temperature = float(settings.get("temperature") or 0.0)

        result = await self.agent_loop.process_direct(
            request.task_text,
            provider_bundle=request.provider_bundle,
            include_skill_assembly=False,
            include_tools=True,
            pinned_skill_names=request.pinned_skill_names,
            pinned_skill_contexts=request.pinned_skill_contexts,
            max_tool_iterations=max_tool_iterations,
            temperature=temperature,
            source="skill_replay_eval",
            # Route every tool call through the replay executor so it is
            # traced under the replay policy.
            tool_executor_override=replay_executor,
        )

        traces = replay_executor.traces
        return {
            "case_id": request.case_id,
            "arm": request.arm,
            "session_id": result.session_id,
            "run_id": result.run_id,
            "task_text": request.task_text,
            "finish_reason": result.finish_reason,
            "final_answer": result.output_text,
            # Copy so the report does not alias the executor's own trace list.
            "tool_calls": list(traces),
            "artifacts": [],
            "side_effects": _side_effects_from_traces(traces),
        }
|
||||
|
||||
|
||||
def _side_effects_from_traces(traces: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Extract a reduced record for every surrogate or blocked tool trace.

    Args:
        traces: Tool-call trace dicts; each is inspected via ``.get`` so
            missing keys are tolerated.

    Returns:
        One dict per trace whose ``"mode"`` is ``"surrogate"`` or
        ``"blocked"``, in input order, holding only ``tool_name``, ``mode``,
        ``arguments`` and ``classification_reason`` (``None`` for any key the
        trace lacks). Traces with any other mode are omitted.
    """
    reported_modes = {"surrogate", "blocked"}
    # Tuple order fixes the key order of each emitted record.
    kept_keys = ("tool_name", "mode", "arguments", "classification_reason")
    return [
        {key: trace.get(key) for key in kept_keys}
        for trace in traces
        if trace.get("mode") in reported_modes
    ]
|
||||
|
||||
@ -0,0 +1,36 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
|
||||
from beaver.skills.learning.replay import ReplayArmRequest, ReplayRunner
|
||||
|
||||
|
||||
class FakeAgentLoop:
    """Minimal agent-loop stand-in exposing ``boot`` and ``process_direct``."""

    def boot(self):
        """Return a namespace with an empty tool executor and a registry stub.

        The registry's ``get`` accepts a name and always returns ``None``.
        """
        stub_executor = SimpleNamespace()
        stub_registry = SimpleNamespace(get=lambda name: None)
        return SimpleNamespace(tool_executor=stub_executor, tool_registry=stub_registry)

    async def process_direct(self, task: str, **kwargs):
        """Make one email tool call via the override executor, then finish.

        Requires the ``tool_executor_override`` keyword argument; ``task`` and
        all other keyword arguments are ignored.
        """
        override_executor = kwargs["tool_executor_override"]
        email_arguments = {"to": "ada@example.com"}
        await override_executor.execute("mcp_outlook_send_email", email_arguments)
        return SimpleNamespace(
            session_id="session-replay",
            run_id="run-replay",
            output_text="done",
            finish_reason="stop",
        )
|
||||
|
||||
|
||||
def test_replay_runner_returns_arm_report_with_tool_trace() -> None:
    """The arm report echoes the request identity and records the loop's tool call."""
    arm_request = ReplayArmRequest(
        case_id="case-1",
        arm="candidate",
        task_text="Send a status email to Ada.",
        pinned_skill_names=[],
        pinned_skill_contexts=[],
        provider_bundle=object(),
        model_settings={"max_tool_iterations": 2},
    )
    replay_runner = ReplayRunner(agent_loop=FakeAgentLoop())

    # run_arm is a coroutine; drive it to completion on a fresh event loop.
    arm_report = asyncio.run(replay_runner.run_arm(arm_request))

    assert arm_report["case_id"] == "case-1"
    assert arm_report["arm"] == "candidate"
    assert arm_report["finish_reason"] == "stop"
    # The fake loop issues exactly one tool call through the replay executor.
    first_tool_call = arm_report["tool_calls"][0]
    assert first_tool_call["tool_name"] == "mcp_outlook_send_email"
|
||||
Reference in New Issue
Block a user