feat(skill-learning): add surrogate tool evaluator
This commit is contained in:
@ -13,6 +13,7 @@ from .pipeline import SkillLearningPipelineService
|
|||||||
from .preservation import check_preservation
|
from .preservation import check_preservation
|
||||||
from .replay import ReplayToolExecutor, ReplayToolPolicy, classify_tool_mode
|
from .replay import ReplayToolExecutor, ReplayToolPolicy, classify_tool_mode
|
||||||
from .service import RunReceiptContext, SkillLearningService
|
from .service import RunReceiptContext, SkillLearningService
|
||||||
|
from .surrogate import SurrogateToolEvaluator
|
||||||
from .synthesizer import SkillDraftSynthesizer
|
from .synthesizer import SkillDraftSynthesizer
|
||||||
from .worker import SkillLearningWorker, SkillLearningWorkerConfig, SkillLearningWorkerResult
|
from .worker import SkillLearningWorker, SkillLearningWorkerConfig, SkillLearningWorkerResult
|
||||||
|
|
||||||
@ -31,6 +32,7 @@ __all__ = [
|
|||||||
"ReplayToolExecutor",
|
"ReplayToolExecutor",
|
||||||
"ReplayToolPolicy",
|
"ReplayToolPolicy",
|
||||||
"classify_tool_mode",
|
"classify_tool_mode",
|
||||||
|
"SurrogateToolEvaluator",
|
||||||
"SkillDraftSynthesizer",
|
"SkillDraftSynthesizer",
|
||||||
"SkillLearningService",
|
"SkillLearningService",
|
||||||
"SkillLearningWorker",
|
"SkillLearningWorker",
|
||||||
|
|||||||
53
app-instance/backend/beaver/skills/learning/surrogate.py
Normal file
53
app-instance/backend/beaver/skills/learning/surrogate.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
"""Surrogate evaluation for replay tool calls that cannot execute safely."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class SurrogateToolEvaluator:
|
||||||
|
async def evaluate(self, *, task_text: str, baseline: dict[str, Any], candidate: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
baseline_score = _score_arm(task_text, baseline)
|
||||||
|
candidate_score = _score_arm(task_text, candidate)
|
||||||
|
surrogate_count = _mode_count(baseline, "surrogate") + _mode_count(candidate, "surrogate")
|
||||||
|
blocked_count = _mode_count(baseline, "blocked") + _mode_count(candidate, "blocked")
|
||||||
|
confidence = "low" if blocked_count else ("medium" if surrogate_count <= 2 else "low")
|
||||||
|
return {
|
||||||
|
"baseline_score": baseline_score,
|
||||||
|
"candidate_score": candidate_score,
|
||||||
|
"delta": round(candidate_score - baseline_score, 4),
|
||||||
|
"surrogate_tool_count": surrogate_count,
|
||||||
|
"blocked_tool_count": blocked_count,
|
||||||
|
"confidence": confidence,
|
||||||
|
"notes": [
|
||||||
|
"Surrogate score is based on intended tool calls, schemas, arguments, and task relevance.",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _score_arm(task_text: str, arm: dict[str, Any]) -> float:
|
||||||
|
calls = [item for item in arm.get("tool_calls") or [] if isinstance(item, dict)]
|
||||||
|
if not calls:
|
||||||
|
return 0.5
|
||||||
|
scores = [_score_call(task_text, call) for call in calls]
|
||||||
|
return round(sum(scores) / len(scores), 4)
|
||||||
|
|
||||||
|
|
||||||
|
def _score_call(task_text: str, call: dict[str, Any]) -> float:
|
||||||
|
if call.get("mode") == "blocked":
|
||||||
|
return 0.2
|
||||||
|
if call.get("mode") == "executed":
|
||||||
|
result = call.get("result") if isinstance(call.get("result"), dict) else {}
|
||||||
|
return 0.85 if result.get("success") is not False else 0.35
|
||||||
|
arguments = dict(call.get("arguments") or {})
|
||||||
|
if not arguments:
|
||||||
|
return 0.45
|
||||||
|
non_empty = sum(1 for value in arguments.values() if str(value).strip())
|
||||||
|
completeness = non_empty / max(1, len(arguments))
|
||||||
|
argument_text = " ".join(str(value).lower() for value in arguments.values())
|
||||||
|
relevance = 0.15 if any(token and token in argument_text for token in task_text.lower().split()[:16]) else 0.0
|
||||||
|
return round(min(0.9, 0.5 + 0.3 * completeness + relevance), 4)
|
||||||
|
|
||||||
|
|
||||||
|
def _mode_count(arm: dict[str, Any], mode: str) -> int:
|
||||||
|
return sum(1 for item in arm.get("tool_calls") or [] if isinstance(item, dict) and item.get("mode") == mode)
|
||||||
@ -0,0 +1,31 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from beaver.skills.learning.surrogate import SurrogateToolEvaluator
|
||||||
|
|
||||||
|
|
||||||
|
def test_surrogate_scores_complete_candidate_higher_than_missing_baseline() -> None:
|
||||||
|
evaluator = SurrogateToolEvaluator()
|
||||||
|
baseline = {
|
||||||
|
"arm": "baseline",
|
||||||
|
"tool_calls": [
|
||||||
|
{"tool_name": "mcp_outlook_send_email", "mode": "surrogate", "arguments": {"to": "", "subject": ""}},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
candidate = {
|
||||||
|
"arm": "candidate",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"tool_name": "mcp_outlook_send_email",
|
||||||
|
"mode": "surrogate",
|
||||||
|
"arguments": {"to": "ada@example.com", "subject": "Status", "body": "Done"},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
result = asyncio.run(evaluator.evaluate(task_text="Send a status email to Ada.", baseline=baseline, candidate=candidate))
|
||||||
|
|
||||||
|
assert result["candidate_score"] > result["baseline_score"]
|
||||||
|
assert result["surrogate_tool_count"] == 2
|
||||||
|
assert result["confidence"] in {"low", "medium"}
|
||||||
Reference in New Issue
Block a user