diff --git a/app-instance/backend/beaver/skills/learning/__init__.py b/app-instance/backend/beaver/skills/learning/__init__.py index 1a4f5f3..4bd4cd8 100644 --- a/app-instance/backend/beaver/skills/learning/__init__.py +++ b/app-instance/backend/beaver/skills/learning/__init__.py @@ -13,6 +13,7 @@ from .pipeline import SkillLearningPipelineService from .preservation import check_preservation from .replay import ReplayToolExecutor, ReplayToolPolicy, classify_tool_mode from .service import RunReceiptContext, SkillLearningService +from .surrogate import SurrogateToolEvaluator from .synthesizer import SkillDraftSynthesizer from .worker import SkillLearningWorker, SkillLearningWorkerConfig, SkillLearningWorkerResult @@ -31,6 +32,7 @@ __all__ = [ "ReplayToolExecutor", "ReplayToolPolicy", "classify_tool_mode", + "SurrogateToolEvaluator", "SkillDraftSynthesizer", "SkillLearningService", "SkillLearningWorker", diff --git a/app-instance/backend/beaver/skills/learning/surrogate.py b/app-instance/backend/beaver/skills/learning/surrogate.py new file mode 100644 index 0000000..ab2d3ff --- /dev/null +++ b/app-instance/backend/beaver/skills/learning/surrogate.py @@ -0,0 +1,53 @@ +"""Surrogate evaluation for replay tool calls that cannot execute safely.""" + +from __future__ import annotations + +from typing import Any + + +class SurrogateToolEvaluator: + async def evaluate(self, *, task_text: str, baseline: dict[str, Any], candidate: dict[str, Any]) -> dict[str, Any]: + baseline_score = _score_arm(task_text, baseline) + candidate_score = _score_arm(task_text, candidate) + surrogate_count = _mode_count(baseline, "surrogate") + _mode_count(candidate, "surrogate") + blocked_count = _mode_count(baseline, "blocked") + _mode_count(candidate, "blocked") + confidence = "low" if blocked_count else ("medium" if surrogate_count <= 2 else "low") + return { + "baseline_score": baseline_score, + "candidate_score": candidate_score, + "delta": round(candidate_score - baseline_score, 4), + "surrogate_tool_count": surrogate_count, + "blocked_tool_count": blocked_count, + "confidence": confidence, + "notes": [ + "Surrogate score is based on intended tool calls, schemas, arguments, and task relevance.", + ], + } + + +def _score_arm(task_text: str, arm: dict[str, Any]) -> float: + calls = [item for item in arm.get("tool_calls") or [] if isinstance(item, dict)] + if not calls: + return 0.5 + scores = [_score_call(task_text, call) for call in calls] + return round(sum(scores) / len(scores), 4) + + +def _score_call(task_text: str, call: dict[str, Any]) -> float: + if call.get("mode") == "blocked": + return 0.2 + if call.get("mode") == "executed": + result = call.get("result") if isinstance(call.get("result"), dict) else {} + return 0.85 if result.get("success") is not False else 0.35 + arguments = dict(call.get("arguments") or {}) + if not arguments: + return 0.45 + non_empty = sum(1 for value in arguments.values() if str(value).strip()) + completeness = non_empty / max(1, len(arguments)) + argument_text = " ".join(str(value).lower() for value in arguments.values()) + relevance = 0.15 if any(token and token in argument_text for token in task_text.lower().split()[:16]) else 0.0 + return round(min(0.9, 0.5 + 0.3 * completeness + relevance), 4) + + +def _mode_count(arm: dict[str, Any], mode: str) -> int: + return sum(1 for item in arm.get("tool_calls") or [] if isinstance(item, dict) and item.get("mode") == mode) diff --git a/app-instance/backend/tests/unit/test_skill_learning_surrogate.py b/app-instance/backend/tests/unit/test_skill_learning_surrogate.py new file mode 100644 index 0000000..47d6847 --- /dev/null +++ b/app-instance/backend/tests/unit/test_skill_learning_surrogate.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import asyncio + +from beaver.skills.learning.surrogate import SurrogateToolEvaluator + + +def test_surrogate_scores_complete_candidate_higher_than_missing_baseline() -> None: + evaluator = SurrogateToolEvaluator() + baseline = { + "arm": "baseline", + "tool_calls": [ + {"tool_name": "mcp_outlook_send_email", "mode": "surrogate", "arguments": {"to": "", "subject": ""}}, + ], + } + candidate = { + "arm": "candidate", + "tool_calls": [ + { + "tool_name": "mcp_outlook_send_email", + "mode": "surrogate", + "arguments": {"to": "ada@example.com", "subject": "Status", "body": "Done"}, + }, + ], + } + + result = asyncio.run(evaluator.evaluate(task_text="Send a status email to Ada.", baseline=baseline, candidate=candidate)) + + assert result["candidate_score"] > result["baseline_score"] + assert result["surrogate_tool_count"] == 2 + assert result["confidence"] in {"low", "medium"}