Files
beaver_project/app-instance/backend/beaver/skills/learning/surrogate.py

54 lines
2.4 KiB
Python

"""Surrogate evaluation for replay tool calls that cannot execute safely."""
from __future__ import annotations
from typing import Any
class SurrogateToolEvaluator:
    """Compare a baseline arm and a candidate arm of replayed tool calls."""

    async def evaluate(
        self,
        *,
        task_text: str,
        baseline: dict[str, Any],
        candidate: dict[str, Any],
    ) -> dict[str, Any]:
        """Score both arms and summarize how they differ.

        Args:
            task_text: Task description forwarded to the per-arm scorer.
            baseline: Arm dict; its ``"tool_calls"`` entry is what gets scored.
            candidate: Arm dict with the same layout as ``baseline``.

        Returns:
            A dict holding both arm scores, their rounded difference
            (candidate minus baseline), the combined number of
            ``"surrogate"`` and ``"blocked"`` calls across both arms, a
            confidence label, and a list of explanatory notes.
        """
        arms = (baseline, candidate)
        score_before = _score_arm(task_text, baseline)
        score_after = _score_arm(task_text, candidate)
        surrogate_total = sum(_mode_count(arm, "surrogate") for arm in arms)
        blocked_total = sum(_mode_count(arm, "blocked") for arm in arms)

        # Any blocked call, or more than two surrogate calls overall, lowers
        # the confidence; this evaluator never reports above "medium".
        if blocked_total or surrogate_total > 2:
            confidence_label = "low"
        else:
            confidence_label = "medium"

        return {
            "baseline_score": score_before,
            "candidate_score": score_after,
            "delta": round(score_after - score_before, 4),
            "surrogate_tool_count": surrogate_total,
            "blocked_tool_count": blocked_total,
            "confidence": confidence_label,
            "notes": [
                "Surrogate score is based on intended tool calls, schemas, arguments, and task relevance.",
            ],
        }
def _score_arm(task_text: str, arm: dict[str, Any]) -> float:
calls = [item for item in arm.get("tool_calls") or [] if isinstance(item, dict)]
if not calls:
return 0.5
scores = [_score_call(task_text, call) for call in calls]
return round(sum(scores) / len(scores), 4)
def _score_call(task_text: str, call: dict[str, Any]) -> float:
if call.get("mode") == "blocked":
return 0.2
if call.get("mode") == "executed":
result = call.get("result") if isinstance(call.get("result"), dict) else {}
return 0.85 if result.get("success") is not False else 0.35
arguments = dict(call.get("arguments") or {})
if not arguments:
return 0.45
non_empty = sum(1 for value in arguments.values() if str(value).strip())
completeness = non_empty / max(1, len(arguments))
argument_text = " ".join(str(value).lower() for value in arguments.values())
relevance = 0.15 if any(token and token in argument_text for token in task_text.lower().split()[:16]) else 0.0
return round(min(0.9, 0.5 + 0.3 * completeness + relevance), 4)
def _mode_count(arm: dict[str, Any], mode: str) -> int:
return sum(1 for item in arm.get("tool_calls") or [] if isinstance(item, dict) and item.get("mode") == mode)