"""Surrogate evaluation for replay tool calls that cannot execute safely."""

from __future__ import annotations

from typing import Any


class SurrogateToolEvaluator:
    """Heuristically score the tool calls recorded for two replay arms."""

    async def evaluate(
        self,
        *,
        task_text: str,
        baseline: dict[str, Any],
        candidate: dict[str, Any],
    ) -> dict[str, Any]:
        """Compare the tool-call quality of a baseline and a candidate arm.

        Args:
            task_text: Task description; its leading words are matched against
                the arguments of calls that are neither blocked nor executed.
            baseline: Arm payload; only its ``"tool_calls"`` entry is read.
            candidate: Arm payload; only its ``"tool_calls"`` entry is read.

        Returns:
            A report holding each arm's score (duplicated under the
            ``*_tool_execution_score`` keys), the rounded ``delta`` of
            candidate minus baseline, the number of ``"surrogate"`` and
            ``"blocked"`` calls summed over both arms, a confidence label and
            a note stating the score is diagnostic only.
        """
        baseline_score = _score_arm(task_text, baseline)
        candidate_score = _score_arm(task_text, candidate)
        surrogate_count = _mode_count(baseline, "surrogate") + _mode_count(candidate, "surrogate")
        blocked_count = _mode_count(baseline, "blocked") + _mode_count(candidate, "blocked")

        # Any blocked call, or more than two surrogate calls, lowers confidence.
        if blocked_count or surrogate_count > 2:
            confidence = "low"
        else:
            confidence = "medium"

        return {
            "baseline_score": baseline_score,
            "candidate_score": candidate_score,
            "baseline_tool_execution_score": baseline_score,
            "candidate_tool_execution_score": candidate_score,
            "delta": round(candidate_score - baseline_score, 4),
            "surrogate_tool_count": surrogate_count,
            "blocked_tool_count": blocked_count,
            "score_role": "diagnostic_only",
            "confidence": confidence,
            "notes": [
                "Tool execution score is diagnostic only and is not the main task ability score.",
            ],
        }


def _tool_calls(arm: dict[str, Any]) -> list[dict[str, Any]]:
    """Return the dict-shaped entries of ``arm["tool_calls"]``, skipping anything else."""
    return [item for item in arm.get("tool_calls") or [] if isinstance(item, dict)]


def _coerce_arguments(call: dict[str, Any]) -> dict[Any, Any]:
    """Return the call's arguments as a dict, or ``{}`` when they are missing or unusable."""
    raw = call.get("arguments")
    if not raw:
        return {}
    try:
        return dict(raw)
    except (TypeError, ValueError):
        # A truthy non-mapping payload (e.g. a raw string) cannot be inspected
        # per argument; score it like a call without arguments instead of
        # aborting the whole evaluation.
        return {}


def _score_arm(task_text: str, arm: dict[str, Any]) -> float:
    """Return the mean call score of an arm, or 0.5 when it has no usable calls."""
    calls = _tool_calls(arm)
    if not calls:
        return 0.5
    scores = [_score_call(task_text, call) for call in calls]
    return round(sum(scores) / len(scores), 4)


def _score_call(task_text: str, call: dict[str, Any]) -> float:
    """Score a single tool call.

    Blocked calls score 0.2. Executed calls score 0.35 when their result dict
    reports ``success`` as ``False`` and 0.85 otherwise. Any other call is
    scored from its arguments: 0.45 without usable arguments, else
    ``0.5 + 0.3 * completeness`` plus 0.15 when one of the first 16 task words
    occurs in the argument text, capped at 0.9.
    """
    mode = call.get("mode")
    if mode == "blocked":
        return 0.2
    if mode == "executed":
        result = call.get("result")
        failed = isinstance(result, dict) and result.get("success") is False
        return 0.35 if failed else 0.85

    arguments = _coerce_arguments(call)
    if not arguments:
        return 0.45

    # None means the argument was left unset: it must not count as filled in,
    # and its "None" spelling must not take part in relevance matching.
    values = [str(value) for value in arguments.values() if value is not None]
    non_empty = sum(1 for text in values if text.strip())
    completeness = non_empty / len(arguments)

    argument_text = " ".join(text.lower() for text in values)
    task_tokens = (task_text or "").lower().split()[:16]
    relevance = 0.15 if any(token in argument_text for token in task_tokens) else 0.0
    return round(min(0.9, 0.5 + 0.3 * completeness + relevance), 4)


def _mode_count(arm: dict[str, Any], mode: str) -> int:
    """Count the arm's dict-shaped tool calls whose ``"mode"`` equals ``mode``."""
    return sum(1 for item in _tool_calls(arm) if item.get("mode") == mode)