Add Hermes memory evaluation framework with LoCoMo dataset support
- Implement HermesClient for interacting with the Hermes CLI. - Create judge module for grading QA outputs from Hermes memory. - Develop LoCoMo dataset parsing and formatting utilities. - Introduce run_eval script to facilitate memory evaluation using LoCoMo-style datasets.
This commit is contained in:
118
eval/hermes_memory_eval/locomo.py
Normal file
118
eval/hermes_memory_eval/locomo.py
Normal file
@ -0,0 +1,118 @@
|
||||
"""LoCoMo dataset parsing and formatting for Hermes memory evaluation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LocomoSession:
|
||||
sample_id: str
|
||||
session_key: str
|
||||
date_time: str
|
||||
message: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LocomoQA:
|
||||
sample_id: str
|
||||
question: str
|
||||
expected: str
|
||||
category: str
|
||||
evidence: list[Any]
|
||||
|
||||
|
||||
def load_samples(path: str | Path, sample_index: int | None = None) -> list[dict[str, Any]]:
|
||||
with Path(path).open("r", encoding="utf-8") as file:
|
||||
data = json.load(file)
|
||||
if not isinstance(data, list):
|
||||
raise ValueError("LoCoMo input must be a JSON list")
|
||||
if sample_index is None:
|
||||
return data
|
||||
if sample_index < 0 or sample_index >= len(data):
|
||||
raise ValueError(f"sample index {sample_index} out of range 0-{len(data) - 1}")
|
||||
return [data[sample_index]]
|
||||
|
||||
|
||||
def parse_session_range(value: str | None) -> tuple[int, int] | None:
|
||||
if not value:
|
||||
return None
|
||||
if "-" in value:
|
||||
start, end = value.split("-", 1)
|
||||
return int(start), int(end)
|
||||
number = int(value)
|
||||
return number, number
|
||||
|
||||
|
||||
def format_message(message: dict[str, Any]) -> str:
|
||||
speaker = message.get("speaker", "unknown")
|
||||
text = message.get("text", "")
|
||||
line = f"{speaker}: {text}"
|
||||
image_urls = message.get("img_url", [])
|
||||
if isinstance(image_urls, str):
|
||||
image_urls = [image_urls]
|
||||
caption = message.get("blip_caption", "")
|
||||
for url in image_urls:
|
||||
suffix = f": {caption}" if caption else ""
|
||||
line += f"\n{url}{suffix}"
|
||||
if caption and not image_urls:
|
||||
line += f"\n({caption})"
|
||||
return line
|
||||
|
||||
|
||||
def build_sessions(
|
||||
sample: dict[str, Any],
|
||||
session_range: tuple[int, int] | None = None,
|
||||
tail: str = "请记住以上历史对话,只回复 OK。",
|
||||
) -> list[LocomoSession]:
|
||||
conversation = sample["conversation"]
|
||||
session_keys = sorted(
|
||||
[key for key in conversation if key.startswith("session_") and not key.endswith("_date_time")],
|
||||
key=lambda key: int(key.split("_")[1]),
|
||||
)
|
||||
sessions: list[LocomoSession] = []
|
||||
for session_key in session_keys:
|
||||
session_number = int(session_key.split("_")[1])
|
||||
if session_range:
|
||||
start, end = session_range
|
||||
if session_number < start or session_number > end:
|
||||
continue
|
||||
date_time = conversation.get(f"{session_key}_date_time", "")
|
||||
parts = [f"[group chat conversation: {date_time}]"]
|
||||
parts.extend(format_message(message) for message in conversation[session_key])
|
||||
if tail:
|
||||
parts.append(tail)
|
||||
sessions.append(
|
||||
LocomoSession(
|
||||
sample_id=str(sample["sample_id"]),
|
||||
session_key=session_key,
|
||||
date_time=date_time,
|
||||
message="\n\n".join(parts),
|
||||
)
|
||||
)
|
||||
return sessions
|
||||
|
||||
|
||||
def build_qas(sample: dict[str, Any], *, include_category_5: bool = False) -> list[LocomoQA]:
|
||||
qas: list[LocomoQA] = []
|
||||
for qa in sample.get("qa", []):
|
||||
category = str(qa.get("category", ""))
|
||||
if category == "5" and not include_category_5:
|
||||
continue
|
||||
qas.append(
|
||||
LocomoQA(
|
||||
sample_id=str(sample["sample_id"]),
|
||||
question=str(qa["question"]),
|
||||
expected=str(qa["answer"]),
|
||||
category=category,
|
||||
evidence=qa.get("evidence", []),
|
||||
)
|
||||
)
|
||||
return qas
|
||||
|
||||
|
||||
def sample_user_id(prefix: str, sample: dict[str, Any]) -> str:
|
||||
return f"{prefix}{sample['sample_id']}"
|
||||
Reference in New Issue
Block a user