feat(skill-learning): add surrogate tool evaluator

This commit is contained in:
2026-06-08 13:33:02 +08:00
parent 70014c0f70
commit 4c8bc53d33
3 changed files with 86 additions and 0 deletions

View File

@ -13,6 +13,7 @@ from .pipeline import SkillLearningPipelineService
from .preservation import check_preservation from .preservation import check_preservation
from .replay import ReplayToolExecutor, ReplayToolPolicy, classify_tool_mode from .replay import ReplayToolExecutor, ReplayToolPolicy, classify_tool_mode
from .service import RunReceiptContext, SkillLearningService from .service import RunReceiptContext, SkillLearningService
from .surrogate import SurrogateToolEvaluator
from .synthesizer import SkillDraftSynthesizer from .synthesizer import SkillDraftSynthesizer
from .worker import SkillLearningWorker, SkillLearningWorkerConfig, SkillLearningWorkerResult from .worker import SkillLearningWorker, SkillLearningWorkerConfig, SkillLearningWorkerResult
@ -31,6 +32,7 @@ __all__ = [
"ReplayToolExecutor", "ReplayToolExecutor",
"ReplayToolPolicy", "ReplayToolPolicy",
"classify_tool_mode", "classify_tool_mode",
"SurrogateToolEvaluator",
"SkillDraftSynthesizer", "SkillDraftSynthesizer",
"SkillLearningService", "SkillLearningService",
"SkillLearningWorker", "SkillLearningWorker",

View File

@ -0,0 +1,53 @@
"""Surrogate evaluation for replay tool calls that cannot execute safely."""
from __future__ import annotations
from typing import Any
class SurrogateToolEvaluator:
async def evaluate(self, *, task_text: str, baseline: dict[str, Any], candidate: dict[str, Any]) -> dict[str, Any]:
baseline_score = _score_arm(task_text, baseline)
candidate_score = _score_arm(task_text, candidate)
surrogate_count = _mode_count(baseline, "surrogate") + _mode_count(candidate, "surrogate")
blocked_count = _mode_count(baseline, "blocked") + _mode_count(candidate, "blocked")
confidence = "low" if blocked_count else ("medium" if surrogate_count <= 2 else "low")
return {
"baseline_score": baseline_score,
"candidate_score": candidate_score,
"delta": round(candidate_score - baseline_score, 4),
"surrogate_tool_count": surrogate_count,
"blocked_tool_count": blocked_count,
"confidence": confidence,
"notes": [
"Surrogate score is based on intended tool calls, schemas, arguments, and task relevance.",
],
}
def _score_arm(task_text: str, arm: dict[str, Any]) -> float:
calls = [item for item in arm.get("tool_calls") or [] if isinstance(item, dict)]
if not calls:
return 0.5
scores = [_score_call(task_text, call) for call in calls]
return round(sum(scores) / len(scores), 4)
def _score_call(task_text: str, call: dict[str, Any]) -> float:
if call.get("mode") == "blocked":
return 0.2
if call.get("mode") == "executed":
result = call.get("result") if isinstance(call.get("result"), dict) else {}
return 0.85 if result.get("success") is not False else 0.35
arguments = dict(call.get("arguments") or {})
if not arguments:
return 0.45
non_empty = sum(1 for value in arguments.values() if str(value).strip())
completeness = non_empty / max(1, len(arguments))
argument_text = " ".join(str(value).lower() for value in arguments.values())
relevance = 0.15 if any(token and token in argument_text for token in task_text.lower().split()[:16]) else 0.0
return round(min(0.9, 0.5 + 0.3 * completeness + relevance), 4)
def _mode_count(arm: dict[str, Any], mode: str) -> int:
return sum(1 for item in arm.get("tool_calls") or [] if isinstance(item, dict) and item.get("mode") == mode)

View File

@ -0,0 +1,31 @@
from __future__ import annotations
import asyncio
from beaver.skills.learning.surrogate import SurrogateToolEvaluator
def test_surrogate_scores_complete_candidate_higher_than_missing_baseline() -> None:
evaluator = SurrogateToolEvaluator()
baseline = {
"arm": "baseline",
"tool_calls": [
{"tool_name": "mcp_outlook_send_email", "mode": "surrogate", "arguments": {"to": "", "subject": ""}},
],
}
candidate = {
"arm": "candidate",
"tool_calls": [
{
"tool_name": "mcp_outlook_send_email",
"mode": "surrogate",
"arguments": {"to": "ada@example.com", "subject": "Status", "body": "Done"},
},
],
}
result = asyncio.run(evaluator.evaluate(task_text="Send a status email to Ada.", baseline=baseline, candidate=candidate))
assert result["candidate_score"] > result["baseline_score"]
assert result["surrogate_tool_count"] == 2
assert result["confidence"] in {"low", "medium"}