Files
memory-gateway/eval/hermes_memory_eval/locomo.py
tomtan c173fa45a7 Add Hermes memory evaluation framework with LoCoMo dataset support
- Implement HermesClient for interacting with the Hermes CLI.
- Create judge module for grading QA outputs from Hermes memory.
- Develop LoCoMo dataset parsing and formatting utilities.
- Introduce run_eval script to facilitate memory evaluation using LoCoMo-style datasets.
2026-05-27 17:06:26 +08:00

119 lines
3.6 KiB
Python

"""LoCoMo dataset parsing and formatting for Hermes memory evaluation."""
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any
@dataclass(frozen=True)
class LocomoSession:
    """One LoCoMo conversation session rendered as a single ingestion message.

    ``build_sessions`` in this module constructs these.
    """

    # String form of the owning sample's ``sample_id``.
    sample_id: str
    # Conversation key of the session, e.g. "session_3".
    session_key: str
    # Value of "<session_key>_date_time" from the conversation, or "" if absent.
    date_time: str
    # Full text sent for this session: a date header, the formatted turns and an
    # optional trailing instruction, separated by blank lines.
    message: str
@dataclass(frozen=True)
class LocomoQA:
    """One question/answer pair taken from a LoCoMo sample's ``qa`` list.

    ``build_qas`` in this module constructs these.
    """

    # String form of the owning sample's ``sample_id``.
    sample_id: str
    # Question text, converted with str().
    question: str
    # Gold answer, converted with str().
    expected: str
    # Dataset category converted with str() (e.g. "5"); "" when the item has none.
    category: str
    # The item's raw "evidence" value, passed through unchanged.
    # NOTE(review): presumably dialog ids such as "D1:3" -- confirm against data.
    # Because this is a list, hash() on an instance raises TypeError even though
    # the dataclass is frozen.
    evidence: list[Any]
def load_samples(path: str | Path, sample_index: int | None = None) -> list[dict[str, Any]]:
    """Load a LoCoMo JSON file and optionally select a single sample.

    Args:
        path: Path to a UTF-8 JSON file whose top-level value is a list.
        sample_index: When given, only the sample at this 0-based index is
            returned (negative indices are rejected, not counted from the end).

    Returns:
        The whole list when ``sample_index`` is None, otherwise a one-element
        list containing the selected sample.

    Raises:
        ValueError: if the top-level JSON value is not a list, or if
            ``sample_index`` is outside the list (including an empty list).
    """
    with Path(path).open("r", encoding="utf-8") as file:
        data = json.load(file)
    if not isinstance(data, list):
        raise ValueError("LoCoMo input must be a JSON list")
    if sample_index is None:
        return data
    if not data:
        # Without this branch the range message below would read "0--1".
        raise ValueError(f"sample index {sample_index} out of range: LoCoMo input is empty")
    if not 0 <= sample_index < len(data):
        raise ValueError(f"sample index {sample_index} out of range 0-{len(data) - 1}")
    return [data[sample_index]]
def parse_session_range(value: str | None) -> tuple[int, int] | None:
    """Parse an inclusive session selector such as ``"3"`` or ``"2-5"``.

    Args:
        value: ``"N"`` selects a single session and ``"START-END"`` an inclusive
            range. ``None``, ``""`` or a whitespace-only string means no filter.
            Whitespace around the numbers is ignored.

    Returns:
        ``(start, end)`` with ``start <= end``, or ``None`` when no filter is given.

    Raises:
        ValueError: if the value is not an integer or a ``START-END`` pair of
            integers, or if ``start > end`` (such a range would otherwise
            silently select no sessions).
    """
    if value is None or not value.strip():
        return None
    start_text, separator, end_text = value.strip().partition("-")
    try:
        start = int(start_text)
        end = int(end_text) if separator else start
    except ValueError as err:
        raise ValueError(
            f"invalid session range {value!r}: expected N or START-END"
        ) from err
    if start > end:
        raise ValueError(
            f"invalid session range {value!r}: start {start} is greater than end {end}"
        )
    return start, end
def format_message(message: dict[str, Any]) -> str:
    """Render one LoCoMo dialog turn as plain text.

    The first line is ``"<speaker>: <text>"``. Each image URL follows on its own
    line, suffixed with ``": <blip_caption>"`` when a caption exists. A caption
    with no URL is appended as ``"(<caption>)"``.

    Missing *or null* fields are treated as absent:

    * ``speaker`` defaults to ``"unknown"``;
    * ``text`` defaults to ``""``;
    * a null ``img_url`` means no images.

    ``img_url`` may be a single string or a sequence of strings, and empty URLs
    are skipped.
    """
    speaker = message.get("speaker")
    if speaker is None:
        speaker = "unknown"
    text = message.get("text")
    if text is None:
        text = ""
    lines = [f"{speaker}: {text}"]

    raw_urls = message.get("img_url")
    if raw_urls is None:
        # Present-but-null used to reach `for url in None` and raise TypeError.
        raw_urls = []
    elif isinstance(raw_urls, str):
        raw_urls = [raw_urls]
    image_urls = [url for url in raw_urls if url]

    caption = message.get("blip_caption") or ""
    suffix = f": {caption}" if caption else ""
    lines.extend(f"{url}{suffix}" for url in image_urls)
    if caption and not image_urls:
        lines.append(f"({caption})")
    return "\n".join(lines)
def _session_number(key: str) -> int | None:
    """Return N for a conversation key of the exact form ``session_<N>``, else None."""
    prefix = "session_"
    if not key.startswith(prefix):
        return None
    suffix = key[len(prefix):]
    # isdecimal() rejects "1_date_time", "summary" and characters such as "²"
    # that int() cannot parse.
    return int(suffix) if suffix.isdecimal() else None


def build_sessions(
    sample: dict[str, Any],
    session_range: tuple[int, int] | None = None,
    tail: str = "请记住以上历史对话,只回复 OK。",
) -> list[LocomoSession]:
    """Render each selected conversation session of a LoCoMo sample as one message.

    Args:
        sample: One LoCoMo sample. It must contain ``"sample_id"`` and a
            ``"conversation"`` mapping:

            * ``session_<N>`` keys hold the list of turns of session N;
            * ``session_<N>_date_time`` keys hold that session's timestamp;
            * every other key is ignored.
        session_range: Optional inclusive ``(start, end)`` filter on session
            numbers, e.g. as returned by ``parse_session_range``.
        tail: Paragraph appended to the end of every message; pass ``""`` to
            omit it. The default (Chinese) text means "Please remember the
            conversation history above and reply only with OK."

    Returns:
        One ``LocomoSession`` per selected session, ordered by session number.
        Each message is these paragraphs separated by blank lines:

        * a ``[group chat conversation: <date_time>]`` header;
        * one ``format_message`` rendering per turn;
        * ``tail``, when non-empty.
    """
    conversation = sample["conversation"]

    selected: list[tuple[int, str]] = []
    for key in conversation:
        number = _session_number(key)
        if number is None:
            continue
        if session_range:
            start, end = session_range
            if not start <= number <= end:
                continue
        selected.append((number, key))
    selected.sort()

    sessions: list[LocomoSession] = []
    for _, session_key in selected:
        date_time = conversation.get(f"{session_key}_date_time", "")
        parts = [f"[group chat conversation: {date_time}]"]
        parts.extend(format_message(message) for message in conversation[session_key])
        if tail:
            parts.append(tail)
        sessions.append(
            LocomoSession(
                sample_id=str(sample["sample_id"]),
                session_key=session_key,
                date_time=date_time,
                message="\n\n".join(parts),
            )
        )
    return sessions
def build_qas(
    sample: dict[str, Any],
    *,
    include_category_5: bool = False,
    adversarial_expected: str = "Not mentioned in the conversation",
) -> list[LocomoQA]:
    """Collect the QA items of a LoCoMo sample as ``LocomoQA`` records.

    Args:
        sample: One LoCoMo sample. It must contain ``"sample_id"``; its optional
            ``"qa"`` list holds dicts with ``"question"``, an answer, and
            optional ``"category"`` and ``"evidence"`` entries.
        include_category_5: Keep items whose category is 5 (skipped by default).
        adversarial_expected: Expected answer used for items that have no
            ``"answer"`` key but do have an ``"adversarial_answer"`` key.
            NOTE(review): in LoCoMo such items are presumably the category-5
            adversarial questions, whose ``adversarial_answer`` is the
            distractor rather than the truth. Confirm the judge's expectations
            for this text.

    Returns:
        One record per kept item, in dataset order.

    Raises:
        KeyError: if an item lacks ``"question"``, or lacks both ``"answer"``
            and ``"adversarial_answer"``.
    """
    qas: list[LocomoQA] = []
    for qa in sample.get("qa", []):
        category = str(qa.get("category", ""))
        if category == "5" and not include_category_5:
            continue
        if "answer" not in qa and "adversarial_answer" in qa:
            # Previously qa["answer"] raised KeyError here, so
            # include_category_5=True could not get past such an item.
            expected = adversarial_expected
        else:
            expected = str(qa["answer"])
        qas.append(
            LocomoQA(
                sample_id=str(sample["sample_id"]),
                question=str(qa["question"]),
                expected=expected,
                category=category,
                evidence=qa.get("evidence", []),
            )
        )
    return qas
def sample_user_id(prefix: str, sample: dict[str, Any]) -> str:
    """Return the user id for *sample*: ``prefix`` immediately followed by its ``sample_id``."""
    sample_id = sample["sample_id"]
    return "{}{}".format(prefix, sample_id)