feat(skill-learning): select replay eval cases
This commit is contained in:
@ -1,5 +1,6 @@
|
|||||||
"""Skill learning loop helpers."""
|
"""Skill learning loop helpers."""
|
||||||
|
|
||||||
|
from .case_selection import select_replay_cases
|
||||||
from .evidence import EvidencePacket, EvidenceSelector
|
from .evidence import EvidencePacket, EvidenceSelector
|
||||||
from .eval import SkillDraftEvaluator
|
from .eval import SkillDraftEvaluator
|
||||||
from .missing_skill import (
|
from .missing_skill import (
|
||||||
@ -15,6 +16,7 @@ from .synthesizer import SkillDraftSynthesizer
|
|||||||
from .worker import SkillLearningWorker, SkillLearningWorkerConfig, SkillLearningWorkerResult
|
from .worker import SkillLearningWorker, SkillLearningWorkerConfig, SkillLearningWorkerResult
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"select_replay_cases",
|
||||||
"EvidencePacket",
|
"EvidencePacket",
|
||||||
"EvidenceSelector",
|
"EvidenceSelector",
|
||||||
"SkillDraftEvaluator",
|
"SkillDraftEvaluator",
|
||||||
|
|||||||
109
app-instance/backend/beaver/skills/learning/case_selection.py
Normal file
109
app-instance/backend/beaver/skills/learning/case_selection.py
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
"""Historical replay case selection for skill draft evaluation."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from beaver.memory.runs import RunRecord
|
||||||
|
from beaver.memory.skills import SkillLearningCandidate
|
||||||
|
|
||||||
|
# Upper bound on the number of replay cases returned by a single selection.
MAX_REPLAY_CASES = 10
|
||||||
|
|
||||||
|
|
||||||
|
def select_replay_cases(candidate: SkillLearningCandidate, runs: list[RunRecord]) -> list[dict[str, Any]]:
    """Choose historical runs to replay when evaluating a skill draft.

    Only runs that pass ``_is_accepted`` are considered. The candidate's
    ``kind`` picks the matching strategy: ``"revise_skill"`` and
    ``"merge_skills"`` have dedicated selectors, every other kind is treated
    as a new-skill candidate.

    Returns:
        At most ``MAX_REPLAY_CASES`` case payload dicts, one per selected run.
    """
    eligible = list(filter(_is_accepted, runs))

    kind = candidate.kind
    if kind == "revise_skill":
        chosen = _select_revise(candidate, eligible)
    elif kind == "merge_skills":
        chosen = _select_merge(candidate, eligible)
    else:
        chosen = _select_new(candidate, eligible)

    payloads = []
    for run in chosen[:MAX_REPLAY_CASES]:
        payloads.append(_case_payload(candidate, run))
    return payloads
|
||||||
|
|
||||||
|
|
||||||
|
def _select_revise(candidate: SkillLearningCandidate, runs: list[RunRecord]) -> list[RunRecord]:
    """Select runs that activated the skill being revised.

    The target is the first entry of ``candidate.related_skill_names`` (or an
    empty string when there is none). If the candidate evidence carries a
    ``skill_version``, only activations of that exact version count;
    otherwise any version of the target skill matches.
    """
    names = candidate.related_skill_names
    wanted_name = names[0] if names else ""
    wanted_version = str(candidate.evidence.get("skill_version") or "")

    hits: list[RunRecord] = []
    for run in runs:
        for receipt in run.activated_skills:
            if receipt.skill_name == wanted_name and (
                not wanted_version or receipt.skill_version == wanted_version
            ):
                # One matching activation is enough to keep the run.
                hits.append(run)
                break
    return _recent_diverse(hits)
|
||||||
|
|
||||||
|
|
||||||
|
def _select_merge(candidate: SkillLearningCandidate, runs: list[RunRecord]) -> list[RunRecord]:
    """Select runs in which every skill proposed for merging was activated.

    A run matches only when the set of its activated skill names contains
    all of ``candidate.related_skill_names``. With no related skills, nothing
    matches.
    """
    required = set(candidate.related_skill_names)

    hits: list[RunRecord] = []
    if required:
        for run in runs:
            activated = {receipt.skill_name for receipt in run.activated_skills}
            if required <= activated:
                hits.append(run)
    return _recent_diverse(hits)
|
||||||
|
|
||||||
|
|
||||||
|
def _select_new(candidate: SkillLearningCandidate, runs: list[RunRecord]) -> list[RunRecord]:
    """Select runs that motivated a brand-new skill.

    When the candidate lists source run ids, exactly those runs are used.
    Otherwise runs are matched by a case-insensitive substring search of the
    evidence ``theme`` in the run's task text; an empty theme matches nothing.
    """
    wanted_ids = set(candidate.source_run_ids)
    if wanted_ids:
        return _recent_diverse([run for run in runs if run.run_id in wanted_ids])

    keyword = str(candidate.evidence.get("theme") or "").lower().strip()
    if not keyword:
        return _recent_diverse([])
    return _recent_diverse([run for run in runs if keyword in run.task_text.lower()])
|
||||||
|
|
||||||
|
|
||||||
|
def _case_payload(candidate: SkillLearningCandidate, record: RunRecord) -> dict[str, Any]:
    """Build the replay-case dict for one selected run.

    The baseline skills depend on the candidate kind: the first related skill
    for ``"revise_skill"``, all related skills for ``"merge_skills"``, and
    none for any other kind.
    """
    kind = candidate.kind
    if kind == "revise_skill":
        baseline = list(candidate.related_skill_names[:1])
    elif kind == "merge_skills":
        baseline = list(candidate.related_skill_names)
    else:
        baseline = []

    return dict(
        run_id=record.run_id,
        task_id=record.task_id,
        session_id=record.session_id,
        task_text=record.task_text,
        baseline_skill_names=baseline,
        candidate_skill_name=candidate.draft_skill_name,
        accepted_score=_score(record),
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _recent_diverse(runs: list[RunRecord]) -> list[RunRecord]:
    """Return up to ``MAX_REPLAY_CASES`` runs, newest first, favouring distinct tasks.

    Runs are ordered by ``(started_at, run_id)`` descending. When there are
    more runs than the cap, only the newest run per task (keyed by ``task_id``,
    falling back to ``task_text``) is taken in the first pass. If that leaves
    fewer runs than the cap, the remaining runs are appended in the same
    newest-first order.
    """
    ordered = list(runs)
    ordered.sort(key=lambda item: (item.started_at, item.run_id), reverse=True)

    # Task de-duplication only applies when there is a surplus to choose from.
    dedupe_tasks = len(ordered) > MAX_REPLAY_CASES

    picked: list[RunRecord] = []
    tasks_seen: set[str] = set()
    for run in ordered:
        key = run.task_id or run.task_text
        if dedupe_tasks and key in tasks_seen:
            continue
        tasks_seen.add(key)
        picked.append(run)
        if len(picked) >= MAX_REPLAY_CASES:
            break

    if len(picked) < min(len(ordered), MAX_REPLAY_CASES):
        # Too few distinct tasks: top up with the runs skipped above.
        already = {run.run_id for run in picked}
        picked += [run for run in ordered if run.run_id not in already]
    return picked[:MAX_REPLAY_CASES]
|
||||||
|
|
||||||
|
|
||||||
|
def _is_accepted(record: RunRecord) -> bool:
    """Tell whether a run succeeded and was accepted via its feedback.

    An explicit ``acceptance_type`` of ``"accept"`` counts as accepted. When
    no ``acceptance_type`` is present, a ``feedback_type`` of ``"satisfied"``
    is treated as acceptance. The run must also be marked successful.
    """
    info = record.feedback or {}
    verdict = info.get("acceptance_type")
    if verdict is None:
        accepted = info.get("feedback_type") == "satisfied"
    else:
        accepted = verdict == "accept"
    return bool(record.success) and accepted
|
||||||
|
|
||||||
|
|
||||||
|
def _score(record: RunRecord) -> float:
    """Return the run's accepted score as a float in ``[0.0, 1.0]``.

    The ``score`` entry of ``record.validation_result`` is used when the
    validation result is a dict and the value converts to a real number; it
    is clamped to the unit interval. In every other case (no validation
    result, no score, unparsable value, NaN, or an int too large for a
    float) the score falls back to 0.8 for a successful run and 0.4
    otherwise.
    """
    import math  # function-scope import: the module does not import math

    validation = record.validation_result or {}
    value = validation.get("score") if isinstance(validation, dict) else None
    if value is not None:
        try:
            score = float(value)
        except (TypeError, ValueError, OverflowError):
            # OverflowError: float() of an int beyond the float range.
            score = math.nan
        # NaN must not reach the clamp: min(1.0, nan) evaluates to 1.0, which
        # would silently turn an invalid score into a perfect one.
        if not math.isnan(score):
            return max(0.0, min(1.0, score))
    return 0.8 if record.success else 0.4
|
||||||
@ -0,0 +1,82 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from beaver.memory.runs import RunRecord
|
||||||
|
from beaver.memory.skills import SkillLearningCandidate
|
||||||
|
from beaver.skills.learning.case_selection import select_replay_cases
|
||||||
|
from beaver.skills.specs import SkillActivationReceipt
|
||||||
|
|
||||||
|
|
||||||
|
def _run(
    run_id: str,
    *,
    task_id: str = "task",
    session_id: str = "session",
    task_text: str = "debug task",
    skill_name: str | None = None,
    skill_version: str = "v0001",
) -> RunRecord:
    """Build a successful, accepted ``RunRecord`` fixture.

    The last two characters of ``run_id`` become the seconds field of
    ``started_at``, so ids with a higher numeric suffix sort as more recent.
    When ``skill_name`` is given, the run carries a single activation receipt
    for that skill and version; otherwise it has no activated skills.
    """
    activated = []
    if skill_name:
        receipt = SkillActivationReceipt(
            run_id=run_id,
            session_id=session_id,
            skill_name=skill_name,
            skill_version=skill_version,
            content_hash="hash",
            activated_at="now",
            activation_reason="selected",
        )
        activated.append(receipt)

    seconds = run_id[-2:]
    return RunRecord(
        run_id=run_id,
        session_id=session_id,
        task_id=task_id,
        attempt_index=1,
        task_text=task_text,
        started_at=f"2026-06-08T00:00:{seconds}+00:00",
        ended_at="end",
        success=True,
        finish_reason="stop",
        feedback={"acceptance_type": "accept"},
        activated_skills=activated,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def test_select_revise_cases_caps_at_ten_and_prefers_related_skill() -> None:
    """Revise selection keeps only runs of the target skill/version, capped at ten.

    Twelve runs activate ``debug`` at ``v0001``. Three decoys with the most
    recent timestamps (a different skill, a different version of ``debug``,
    and no skill at all) must be filtered out, otherwise they would take the
    top slots.
    """
    runs = [
        _run(f"run-{index:02d}", task_id=f"task-{index}", skill_name="debug", skill_version="v0001")
        for index in range(12)
    ]
    # Decoys sort as newer than every matching run, so a missing filter
    # would surface them first.
    decoys = [
        _run("run-97", task_id="task-no-skill"),
        _run("run-98", task_id="task-other-version", skill_name="debug", skill_version="v0002"),
        _run("run-99", task_id="task-other-skill", skill_name="other", skill_version="v0001"),
    ]
    candidate = SkillLearningCandidate(
        candidate_id="candidate-1",
        kind="revise_skill",
        source_run_ids=[],
        source_session_ids=[],
        related_skill_names=["debug"],
        reason="revise",
        evidence={"skill_version": "v0001"},
    )

    cases = select_replay_cases(candidate, runs + decoys)

    assert len(cases) == 10
    assert all(case["baseline_skill_names"] == ["debug"] for case in cases)
    assert cases[0]["run_id"] == "run-11"
    selected_ids = {case["run_id"] for case in cases}
    assert selected_ids.isdisjoint({"run-97", "run-98", "run-99"})
|
||||||
|
|
||||||
|
|
||||||
|
def test_select_new_skill_uses_all_available_source_runs_under_ten() -> None:
    """New-skill selection returns every listed source run, newest first."""
    run_ids = [f"run-{index:02d}" for index in range(3)]
    runs = [_run(run_id, task_id=f"task-{index}") for index, run_id in enumerate(run_ids)]
    candidate = SkillLearningCandidate(
        candidate_id="candidate-1",
        kind="new_skill",
        source_run_ids=list(run_ids),
        source_session_ids=["session"],
        related_skill_names=[],
        reason="new",
    )

    cases = select_replay_cases(candidate, runs)

    returned_ids = [case["run_id"] for case in cases]
    assert returned_ids == ["run-02", "run-01", "run-00"]
    for case in cases:
        assert case["baseline_skill_names"] == []
|
||||||
Reference in New Issue
Block a user