feat(skill-learning): run replay arms through agent loop

This commit is contained in:
2026-06-08 13:33:53 +08:00
parent 4c8bc53d33
commit cc1bf85517
3 changed files with 103 additions and 1 deletions

View File

@ -137,3 +137,67 @@ def classify_tool_mode(spec: ToolSpec, policy: ReplayToolPolicy | None = None) -
def _classification_reason(spec: ToolSpec, mode: ToolExecutionMode) -> str:
return f"{spec.name} classified as {mode} from toolset={spec.toolset} metadata={spec.metadata}"
@dataclass(slots=True)
class ReplayArmRequest:
case_id: str
arm: str
task_text: str
pinned_skill_names: list[str] = field(default_factory=list)
pinned_skill_contexts: list[Any] = field(default_factory=list)
provider_bundle: Any | None = None
model_settings: dict[str, Any] = field(default_factory=dict)
class ReplayRunner:
def __init__(self, *, agent_loop: Any, policy: ReplayToolPolicy | None = None) -> None:
self.agent_loop = agent_loop
self.policy = policy or ReplayToolPolicy()
async def run_arm(self, request: ReplayArmRequest) -> dict[str, Any]:
loaded = self.agent_loop.boot()
replay_executor = ReplayToolExecutor(
loaded.tool_executor,
registry=loaded.tool_registry,
policy=self.policy,
)
result = await self.agent_loop.process_direct(
request.task_text,
provider_bundle=request.provider_bundle,
include_skill_assembly=False,
include_tools=True,
pinned_skill_names=request.pinned_skill_names,
pinned_skill_contexts=request.pinned_skill_contexts,
max_tool_iterations=int(request.model_settings.get("max_tool_iterations") or 4),
temperature=float(request.model_settings.get("temperature") or 0.0),
source="skill_replay_eval",
tool_executor_override=replay_executor,
)
return {
"case_id": request.case_id,
"arm": request.arm,
"session_id": result.session_id,
"run_id": result.run_id,
"task_text": request.task_text,
"finish_reason": result.finish_reason,
"final_answer": result.output_text,
"tool_calls": list(replay_executor.traces),
"artifacts": [],
"side_effects": _side_effects_from_traces(replay_executor.traces),
}
def _side_effects_from_traces(traces: list[dict[str, Any]]) -> list[dict[str, Any]]:
effects: list[dict[str, Any]] = []
for trace in traces:
if trace.get("mode") in {"surrogate", "blocked"}:
effects.append(
{
"tool_name": trace.get("tool_name"),
"mode": trace.get("mode"),
"arguments": trace.get("arguments"),
"classification_reason": trace.get("classification_reason"),
}
)
return effects