110 lines
4.1 KiB
Python
110 lines
4.1 KiB
Python
"""Historical replay case selection for skill draft evaluation."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
from beaver.memory.runs import RunRecord
|
|
from beaver.memory.skills import SkillLearningCandidate
|
|
|
|
# Upper bound on the number of historical runs returned as replay cases for
# a single skill-learning candidate.
MAX_REPLAY_CASES: int = 10
|
|
|
|
|
|
def select_replay_cases(candidate: SkillLearningCandidate, runs: list[RunRecord]) -> list[dict[str, Any]]:
    """Build at most ``MAX_REPLAY_CASES`` replay-case payloads for ``candidate``.

    Only runs passing ``_is_accepted`` are considered. The selection strategy
    depends on ``candidate.kind``: ``"revise_skill"`` and ``"merge_skills"``
    have dedicated selectors; any other kind is treated as a new skill.

    Args:
        candidate: The skill-learning candidate whose draft will be evaluated.
        runs: Historical run records to choose from.

    Returns:
        One payload dict (see ``_case_payload``) per selected run.
    """
    eligible = list(filter(_is_accepted, runs))

    if candidate.kind == "revise_skill":
        chooser = _select_revise
    elif candidate.kind == "merge_skills":
        chooser = _select_merge
    else:
        chooser = _select_new

    chosen = chooser(candidate, eligible)
    return [_case_payload(candidate, record) for record in chosen[:MAX_REPLAY_CASES]]
|
|
|
|
|
|
def _select_revise(candidate: SkillLearningCandidate, runs: list[RunRecord]) -> list[RunRecord]:
    """Select runs that activated the skill being revised.

    The skill under revision is the first entry of
    ``candidate.related_skill_names``. When ``candidate.evidence`` carries a
    truthy ``"skill_version"``, only runs that activated that exact version
    qualify.

    Returns an empty list when the candidate names no skill to revise.
    """
    target = candidate.related_skill_names[0] if candidate.related_skill_names else ""
    if not target:
        # Without a target, matching on ``skill_name == ""`` would pick up
        # unrelated runs; there is nothing meaningful to replay.
        return []

    version = str(candidate.evidence.get("skill_version") or "")

    def _activated_target(record: RunRecord) -> bool:
        # Both sides are compared as strings: the evidence value is coerced
        # with str() above, so the receipt's version must be too, otherwise
        # e.g. an int version 3 would never equal "3".
        return any(
            receipt.skill_name == target
            and (not version or str(receipt.skill_version) == version)
            for receipt in record.activated_skills
        )

    matches = [record for record in runs if _activated_target(record)]
    return _recent_diverse(matches)
|
|
|
|
|
|
def _select_merge(candidate: SkillLearningCandidate, runs: list[RunRecord]) -> list[RunRecord]:
    """Select runs in which every skill named for merging was activated.

    A candidate with no related skill names yields no matches.
    """
    required = set(candidate.related_skill_names)
    if not required:
        return _recent_diverse([])

    hits: list[RunRecord] = []
    for record in runs:
        activated = {receipt.skill_name for receipt in record.activated_skills}
        if required <= activated:
            hits.append(record)
    return _recent_diverse(hits)
|
|
|
|
|
|
def _select_new(candidate: SkillLearningCandidate, runs: list[RunRecord]) -> list[RunRecord]:
    """Select runs relevant to a brand-new skill draft.

    If the candidate lists ``source_run_ids``, exactly those runs are used.
    Otherwise runs are matched by a case-insensitive substring search of
    ``candidate.evidence["theme"]`` in the run's task text; an empty or
    missing theme matches nothing.
    """
    wanted_ids = set(candidate.source_run_ids)
    if wanted_ids:
        picked = [record for record in runs if record.run_id in wanted_ids]
        return _recent_diverse(picked)

    needle = str(candidate.evidence.get("theme") or "").lower().strip()
    if not needle:
        return _recent_diverse([])

    picked = [record for record in runs if needle in record.task_text.lower()]
    return _recent_diverse(picked)
|
|
|
|
|
|
def _case_payload(candidate: SkillLearningCandidate, record: RunRecord) -> dict[str, Any]:
    """Describe one replay case: a historical run plus what to compare it against.

    ``baseline_skill_names`` holds the existing skill(s) the draft competes
    with: the first related skill for a revision, all related skills for a
    merge, and nothing for any other candidate kind.
    """
    if candidate.kind == "revise_skill":
        baseline = list(candidate.related_skill_names[:1])
    elif candidate.kind == "merge_skills":
        baseline = list(candidate.related_skill_names)
    else:
        baseline = []

    return dict(
        run_id=record.run_id,
        task_id=record.task_id,
        session_id=record.session_id,
        task_text=record.task_text,
        baseline_skill_names=baseline,
        candidate_skill_name=candidate.draft_skill_name,
        accepted_score=_score(record),
    )
|
|
|
|
|
|
def _recent_diverse(runs: list[RunRecord]) -> list[RunRecord]:
    """Order runs by descending ``(started_at, run_id)`` and cap at ``MAX_REPLAY_CASES``.

    Descending ``started_at`` presumably means most-recent first (assumes it is
    a timestamp — confirm). When there are more runs than the cap, only the
    first run per task key (``task_id``, falling back to ``task_text``) is
    taken at first. If that leaves fewer than the cap, the skipped runs are
    appended afterwards in the same sorted order.
    """
    newest_first = sorted(runs, key=lambda run: (run.started_at, run.run_id), reverse=True)
    # Task-level de-duplication only matters when we have more runs than we can keep.
    dedupe = len(newest_first) > MAX_REPLAY_CASES

    picked: list[RunRecord] = []
    task_keys: set[str] = set()
    for run in newest_first:
        key = run.task_id or run.task_text
        if dedupe and key in task_keys:
            continue
        task_keys.add(key)
        picked.append(run)
        if len(picked) >= MAX_REPLAY_CASES:
            break

    if len(picked) < min(len(newest_first), MAX_REPLAY_CASES):
        # Back-fill with runs skipped as duplicate tasks.
        picked_ids = {run.run_id for run in picked}
        picked.extend(run for run in newest_first if run.run_id not in picked_ids)

    return picked[:MAX_REPLAY_CASES]
|
|
|
|
|
|
def _is_accepted(record: RunRecord) -> bool:
|
|
feedback = record.feedback or {}
|
|
acceptance = feedback.get("acceptance_type")
|
|
if acceptance is None and feedback.get("feedback_type") == "satisfied":
|
|
acceptance = "accept"
|
|
return bool(record.success) and acceptance == "accept"
|
|
|
|
|
|
def _score(record: RunRecord) -> float:
|
|
validation = record.validation_result or {}
|
|
value = validation.get("score") if isinstance(validation, dict) else None
|
|
if value is not None:
|
|
try:
|
|
return max(0.0, min(1.0, float(value)))
|
|
except (TypeError, ValueError):
|
|
pass
|
|
return 0.8 if record.success else 0.4
|