Files
beaver_project/app-instance/backend/beaver/skills/learning/case_selection.py

110 lines
4.1 KiB
Python

"""Historical replay case selection for skill draft evaluation."""
from __future__ import annotations
from typing import Any
from beaver.memory.runs import RunRecord
from beaver.memory.skills import SkillLearningCandidate
# Upper bound on how many replay cases are selected for a single candidate.
MAX_REPLAY_CASES = 10
def select_replay_cases(candidate: SkillLearningCandidate, runs: list[RunRecord]) -> list[dict[str, Any]]:
    """Build the replay-case payloads used to evaluate a skill draft.

    Only runs that pass ``_is_accepted`` are considered. The selection
    strategy depends on ``candidate.kind``: ``"revise_skill"`` and
    ``"merge_skills"`` have dedicated selectors, and every other kind is
    handled by ``_select_new``.

    Args:
        candidate: The learning candidate whose draft is being evaluated.
        runs: Historical run records to choose from.

    Returns:
        At most ``MAX_REPLAY_CASES`` payload dicts, one per selected run.
    """
    accepted_runs = list(filter(_is_accepted, runs))

    kind = candidate.kind
    if kind == "revise_skill":
        selector = _select_revise
    elif kind == "merge_skills":
        selector = _select_merge
    else:
        selector = _select_new

    shortlist = selector(candidate, accepted_runs)[:MAX_REPLAY_CASES]
    return [_case_payload(candidate, run) for run in shortlist]
def _select_revise(candidate: SkillLearningCandidate, runs: list[RunRecord]) -> list[RunRecord]:
    """Pick runs in which the skill under revision was activated.

    The skill is the first entry of ``candidate.related_skill_names`` (an
    empty string when there is none). When the candidate's evidence carries
    a ``skill_version``, only activations of that exact version count.

    Args:
        candidate: The revise-skill candidate.
        runs: Run records to filter.

    Returns:
        The matching runs after ``_recent_diverse`` has been applied.
    """
    related = candidate.related_skill_names
    skill_name = related[0] if related else ""
    pinned_version = str(candidate.evidence.get("skill_version") or "")

    def _used_target_skill(run: RunRecord) -> bool:
        for receipt in run.activated_skills:
            if receipt.skill_name == skill_name:
                # An empty pinned version means any version is acceptable.
                if not pinned_version or receipt.skill_version == pinned_version:
                    return True
        return False

    return _recent_diverse([run for run in runs if _used_target_skill(run)])
def _select_merge(candidate: SkillLearningCandidate, runs: list[RunRecord]) -> list[RunRecord]:
    """Pick runs that activated every skill named by the merge candidate.

    Args:
        candidate: The merge-skills candidate; its ``related_skill_names``
            are the skills being merged.
        runs: Run records to filter.

    Returns:
        The matching runs after ``_recent_diverse`` has been applied. No run
        matches when the candidate names no skills.
    """
    required = set(candidate.related_skill_names)
    matches: list[RunRecord] = []
    if required:
        for run in runs:
            activated = {receipt.skill_name for receipt in run.activated_skills}
            if required <= activated:
                matches.append(run)
    return _recent_diverse(matches)
def _select_new(candidate: SkillLearningCandidate, runs: list[RunRecord]) -> list[RunRecord]:
    """Pick runs for a candidate that is neither a revision nor a merge.

    When the candidate lists ``source_run_ids``, only runs with those ids
    are kept. Otherwise the ``theme`` from the candidate's evidence is
    matched as a case-insensitive substring of each run's task text; an
    empty theme matches nothing.

    Args:
        candidate: The learning candidate.
        runs: Run records to filter.

    Returns:
        The matching runs after ``_recent_diverse`` has been applied.
    """
    wanted_ids = set(candidate.source_run_ids)
    if wanted_ids:
        return _recent_diverse([run for run in runs if run.run_id in wanted_ids])

    needle = str(candidate.evidence.get("theme") or "").lower().strip()
    if not needle:
        return _recent_diverse([])
    return _recent_diverse([run for run in runs if needle in run.task_text.lower()])
def _case_payload(candidate: SkillLearningCandidate, record: RunRecord) -> dict[str, Any]:
    """Describe one replay case as a plain dict.

    ``baseline_skill_names`` holds the first related skill for a
    ``"revise_skill"`` candidate, all related skills for a ``"merge_skills"``
    candidate, and is empty for any other kind.

    Args:
        candidate: The learning candidate being evaluated.
        record: The historical run to replay.

    Returns:
        A dict with the run's identifiers and task text, the baseline skill
        names, the candidate's draft skill name, and the run's score as
        computed by ``_score``.
    """
    kind = candidate.kind
    if kind == "revise_skill":
        baseline = list(candidate.related_skill_names[:1])
    elif kind == "merge_skills":
        baseline = list(candidate.related_skill_names)
    else:
        baseline = []

    # Fields copied verbatim from the run record, in payload order.
    payload: dict[str, Any] = {
        field: getattr(record, field)
        for field in ("run_id", "task_id", "session_id", "task_text")
    }
    payload["baseline_skill_names"] = baseline
    payload["candidate_skill_name"] = candidate.draft_skill_name
    payload["accepted_score"] = _score(record)
    return payload
def _recent_diverse(runs: list[RunRecord]) -> list[RunRecord]:
    """Return up to ``MAX_REPLAY_CASES`` runs, preferring distinct tasks.

    Runs are ordered by descending ``(started_at, run_id)``. A run's task is
    identified by ``task_id``, falling back to ``task_text``. Repeated tasks
    are skipped only when there are more runs than the cap; if that leaves
    fewer than the cap, the skipped runs are appended in the same order to
    fill the remaining slots.

    Args:
        runs: Candidate run records.

    Returns:
        The selected runs, at most ``MAX_REPLAY_CASES`` of them.
    """
    ordered = sorted(runs, key=lambda item: (item.started_at, item.run_id), reverse=True)
    # De-duplicating by task only makes sense when something must be dropped.
    enforce_diversity = len(ordered) > MAX_REPLAY_CASES

    picked: list[RunRecord] = []
    task_keys: set[str] = set()
    for run in ordered:
        key = run.task_id or run.task_text
        if enforce_diversity and key in task_keys:
            continue
        task_keys.add(key)
        picked.append(run)
        if len(picked) >= MAX_REPLAY_CASES:
            break

    if len(picked) < min(len(ordered), MAX_REPLAY_CASES):
        # Too few distinct tasks: backfill with the runs skipped above.
        chosen_ids = {run.run_id for run in picked}
        picked.extend(run for run in ordered if run.run_id not in chosen_ids)
    return picked[:MAX_REPLAY_CASES]
def _is_accepted(record: RunRecord) -> bool:
    """Tell whether a run succeeded and its feedback accepted the result.

    The feedback's ``acceptance_type`` must equal ``"accept"``. When that
    key is missing or ``None``, a ``feedback_type`` of ``"satisfied"`` is
    treated as an acceptance.

    Args:
        record: The run record to inspect.

    Returns:
        ``True`` only for a successful run with accepting feedback.
    """
    feedback = record.feedback or {}
    verdict = feedback.get("acceptance_type")
    if verdict is None:
        satisfied = feedback.get("feedback_type") == "satisfied"
        verdict = "accept" if satisfied else None
    if not record.success:
        return False
    return verdict == "accept"
def _score(record: RunRecord) -> float:
    """Return the run's validation score clamped to the range [0.0, 1.0].

    The score is read from ``record.validation_result["score"]`` when the
    validation result is a dict. If there is no usable score, the result
    falls back to 0.8 for a successful run and 0.4 otherwise.

    A score is unusable when it is missing, cannot be converted to a float
    (including integers too large for a float), or converts to NaN. NaN must
    be rejected explicitly: ``min(1.0, nan)`` evaluates to 1.0, so clamping
    alone would report an undefined score as a perfect one.

    Args:
        record: The run record to score.

    Returns:
        A float between 0.0 and 1.0 inclusive.
    """
    import math

    validation = record.validation_result or {}
    value = validation.get("score") if isinstance(validation, dict) else None
    if value is not None:
        try:
            parsed = float(value)
        except (TypeError, ValueError, OverflowError):
            # OverflowError: float() of an int beyond the float range.
            parsed = math.nan
        if not math.isnan(parsed):
            return max(0.0, min(1.0, parsed))
    return 0.8 if record.success else 0.4