feat(skill-learning): select replay eval cases
This commit is contained in:
109
app-instance/backend/beaver/skills/learning/case_selection.py
Normal file
109
app-instance/backend/beaver/skills/learning/case_selection.py
Normal file
@ -0,0 +1,109 @@
|
||||
"""Historical replay case selection for skill draft evaluation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from beaver.memory.runs import RunRecord
|
||||
from beaver.memory.skills import SkillLearningCandidate
|
||||
|
||||
MAX_REPLAY_CASES = 10
|
||||
|
||||
|
||||
def select_replay_cases(candidate: SkillLearningCandidate, runs: list[RunRecord]) -> list[dict[str, Any]]:
    """Choose historical runs to replay when evaluating a skill draft.

    Only runs for which ``_is_accepted`` is true are considered. The
    candidate's ``kind`` picks the selection strategy: ``"revise_skill"`` and
    ``"merge_skills"`` have dedicated selectors; every other kind goes through
    ``_select_new``.

    Args:
        candidate: The skill-learning candidate being evaluated.
        runs: Run records to choose from.

    Returns:
        At most ``MAX_REPLAY_CASES`` payload dicts (see ``_case_payload``),
        in the order produced by the selector.
    """
    eligible = [run for run in runs if _is_accepted(run)]

    if candidate.kind == "revise_skill":
        chooser = _select_revise
    elif candidate.kind == "merge_skills":
        chooser = _select_merge
    else:
        chooser = _select_new

    chosen = chooser(candidate, eligible)
    payloads: list[dict[str, Any]] = []
    for run in chosen[:MAX_REPLAY_CASES]:
        payloads.append(_case_payload(candidate, run))
    return payloads
|
||||
|
||||
|
||||
def _select_revise(candidate: SkillLearningCandidate, runs: list[RunRecord]) -> list[RunRecord]:
|
||||
target = candidate.related_skill_names[0] if candidate.related_skill_names else ""
|
||||
version = str(candidate.evidence.get("skill_version") or "")
|
||||
matches = [
|
||||
record
|
||||
for record in runs
|
||||
if any(
|
||||
receipt.skill_name == target and (not version or receipt.skill_version == version)
|
||||
for receipt in record.activated_skills
|
||||
)
|
||||
]
|
||||
return _recent_diverse(matches)
|
||||
|
||||
|
||||
def _select_merge(candidate: SkillLearningCandidate, runs: list[RunRecord]) -> list[RunRecord]:
    """Select runs that activated every skill a ``merge_skills`` candidate names.

    A run matches when the set of ``skill_name`` values on its activation
    receipts contains all of ``candidate.related_skill_names``. With no
    related skills, nothing matches.

    Returns:
        Matching runs as ordered and capped by ``_recent_diverse``.
    """
    wanted = set(candidate.related_skill_names)
    hits: list[RunRecord] = []
    if wanted:
        for run in runs:
            active = {receipt.skill_name for receipt in run.activated_skills}
            if wanted <= active:
                hits.append(run)
    return _recent_diverse(hits)
|
||||
|
||||
|
||||
def _select_new(candidate: SkillLearningCandidate, runs: list[RunRecord]) -> list[RunRecord]:
    """Select runs for a candidate that is neither a revision nor a merge.

    If the candidate lists ``source_run_ids``, only runs with those ids
    match. Otherwise runs are matched by a case-insensitive substring search
    for ``candidate.evidence["theme"]`` in the run's ``task_text``; an empty
    theme matches nothing.

    Returns:
        Matching runs as ordered and capped by ``_recent_diverse``.
    """
    wanted_ids = set(candidate.source_run_ids)
    if wanted_ids:
        hits = [run for run in runs if run.run_id in wanted_ids]
        return _recent_diverse(hits)

    keyword = str(candidate.evidence.get("theme") or "").lower().strip()
    hits = [run for run in runs if keyword in run.task_text.lower()] if keyword else []
    return _recent_diverse(hits)
|
||||
|
||||
|
||||
def _case_payload(candidate: SkillLearningCandidate, record: RunRecord) -> dict[str, Any]:
    """Build the replay-case dict for one selected run.

    ``baseline_skill_names`` depends on the candidate kind: the first related
    skill for ``"revise_skill"``, all related skills for ``"merge_skills"``,
    and an empty list otherwise.

    Returns:
        A dict with the run's identifiers and task text, the baseline skill
        names, the candidate's draft skill name, and the run's score from
        ``_score``.
    """
    if candidate.kind == "revise_skill":
        baseline = list(candidate.related_skill_names[:1])
    elif candidate.kind == "merge_skills":
        baseline = list(candidate.related_skill_names)
    else:
        baseline = []

    payload: dict[str, Any] = {}
    payload["run_id"] = record.run_id
    payload["task_id"] = record.task_id
    payload["session_id"] = record.session_id
    payload["task_text"] = record.task_text
    payload["baseline_skill_names"] = baseline
    payload["candidate_skill_name"] = candidate.draft_skill_name
    payload["accepted_score"] = _score(record)
    return payload
|
||||
|
||||
|
||||
def _recent_diverse(runs: list[RunRecord]) -> list[RunRecord]:
    """Return up to ``MAX_REPLAY_CASES`` runs, preferring distinct tasks.

    Runs are ordered by ``(started_at, run_id)`` descending. If they all fit
    within the cap they are returned as-is. Otherwise only the first run per
    task (keyed by ``task_id``, falling back to ``task_text``) is kept. If
    that leaves fewer than the cap, the skipped runs are appended in the same
    order until the cap is reached.
    """
    ordered = sorted(runs, key=lambda item: (item.started_at, item.run_id), reverse=True)
    if len(ordered) <= MAX_REPLAY_CASES:
        # Everything fits, so there is no reason to drop repeat tasks.
        return ordered

    picked: list[RunRecord] = []
    tasks_taken: set[str] = set()
    for run in ordered:
        key = run.task_id or run.task_text
        if key in tasks_taken:
            continue
        tasks_taken.add(key)
        picked.append(run)
        if len(picked) >= MAX_REPLAY_CASES:
            return picked

    # Too few distinct tasks: top up with the runs skipped as repeats.
    picked_ids = {run.run_id for run in picked}
    picked.extend(run for run in ordered if run.run_id not in picked_ids)
    return picked[:MAX_REPLAY_CASES]
|
||||
|
||||
|
||||
def _is_accepted(record: RunRecord) -> bool:
|
||||
feedback = record.feedback or {}
|
||||
acceptance = feedback.get("acceptance_type")
|
||||
if acceptance is None and feedback.get("feedback_type") == "satisfied":
|
||||
acceptance = "accept"
|
||||
return bool(record.success) and acceptance == "accept"
|
||||
|
||||
|
||||
def _score(record: RunRecord) -> float:
|
||||
validation = record.validation_result or {}
|
||||
value = validation.get("score") if isinstance(validation, dict) else None
|
||||
if value is not None:
|
||||
try:
|
||||
return max(0.0, min(1.0, float(value)))
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
return 0.8 if record.success else 0.4
|
||||
Reference in New Issue
Block a user