feat(skill-learning): run replay arms through agent loop
This commit is contained in:
@ -137,3 +137,67 @@ def classify_tool_mode(spec: ToolSpec, policy: ReplayToolPolicy | None = None) -
|
||||
|
||||
def _classification_reason(spec: ToolSpec, mode: ToolExecutionMode) -> str:
|
||||
return f"{spec.name} classified as {mode} from toolset={spec.toolset} metadata={spec.metadata}"
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ReplayArmRequest:
|
||||
case_id: str
|
||||
arm: str
|
||||
task_text: str
|
||||
pinned_skill_names: list[str] = field(default_factory=list)
|
||||
pinned_skill_contexts: list[Any] = field(default_factory=list)
|
||||
provider_bundle: Any | None = None
|
||||
model_settings: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class ReplayRunner:
|
||||
def __init__(self, *, agent_loop: Any, policy: ReplayToolPolicy | None = None) -> None:
|
||||
self.agent_loop = agent_loop
|
||||
self.policy = policy or ReplayToolPolicy()
|
||||
|
||||
async def run_arm(self, request: ReplayArmRequest) -> dict[str, Any]:
|
||||
loaded = self.agent_loop.boot()
|
||||
replay_executor = ReplayToolExecutor(
|
||||
loaded.tool_executor,
|
||||
registry=loaded.tool_registry,
|
||||
policy=self.policy,
|
||||
)
|
||||
result = await self.agent_loop.process_direct(
|
||||
request.task_text,
|
||||
provider_bundle=request.provider_bundle,
|
||||
include_skill_assembly=False,
|
||||
include_tools=True,
|
||||
pinned_skill_names=request.pinned_skill_names,
|
||||
pinned_skill_contexts=request.pinned_skill_contexts,
|
||||
max_tool_iterations=int(request.model_settings.get("max_tool_iterations") or 4),
|
||||
temperature=float(request.model_settings.get("temperature") or 0.0),
|
||||
source="skill_replay_eval",
|
||||
tool_executor_override=replay_executor,
|
||||
)
|
||||
return {
|
||||
"case_id": request.case_id,
|
||||
"arm": request.arm,
|
||||
"session_id": result.session_id,
|
||||
"run_id": result.run_id,
|
||||
"task_text": request.task_text,
|
||||
"finish_reason": result.finish_reason,
|
||||
"final_answer": result.output_text,
|
||||
"tool_calls": list(replay_executor.traces),
|
||||
"artifacts": [],
|
||||
"side_effects": _side_effects_from_traces(replay_executor.traces),
|
||||
}
|
||||
|
||||
|
||||
def _side_effects_from_traces(traces: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
effects: list[dict[str, Any]] = []
|
||||
for trace in traces:
|
||||
if trace.get("mode") in {"surrogate", "blocked"}:
|
||||
effects.append(
|
||||
{
|
||||
"tool_name": trace.get("tool_name"),
|
||||
"mode": trace.get("mode"),
|
||||
"arguments": trace.get("arguments"),
|
||||
"classification_reason": trace.get("classification_reason"),
|
||||
}
|
||||
)
|
||||
return effects
|
||||
|
||||
Reference in New Issue
Block a user