"""LoCoMo dataset parsing and formatting for Hermes memory evaluation."""

from __future__ import annotations

import json
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any

# Matches only the conversation keys that name a session's message list
# ("session_3"). Sibling keys such as "session_3_date_time" or
# "session_summary" must not be mistaken for sessions.
_SESSION_KEY_RE = re.compile(r"session_(\d+)")


@dataclass(frozen=True)
class LocomoSession:
    """One conversation session rendered as a single text message."""

    sample_id: str
    session_key: str
    date_time: str
    message: str


@dataclass(frozen=True)
class LocomoQA:
    """One question/answer entry taken from a sample's ``qa`` list."""

    sample_id: str
    question: str
    expected: str
    category: str
    evidence: list[Any]


def load_samples(path: str | Path, sample_index: int | None = None) -> list[dict[str, Any]]:
    """Load samples from a JSON file whose top-level value is a list.

    Args:
        path: Location of the JSON file (read as UTF-8).
        sample_index: When given, return only the sample at this zero-based
            position, wrapped in a one-element list.

    Returns:
        Every sample in the file, or a single-element list when
        ``sample_index`` is given.

    Raises:
        ValueError: If the top-level JSON value is not a list, or if
            ``sample_index`` is outside the available range.
    """
    with Path(path).open("r", encoding="utf-8") as file:
        data = json.load(file)
    if not isinstance(data, list):
        raise ValueError("LoCoMo input must be a JSON list")
    if sample_index is None:
        return data
    if not data:
        # Avoid the nonsensical "0--1" range the generic message would print.
        raise ValueError(
            f"sample index {sample_index} out of range: input contains no samples"
        )
    if not 0 <= sample_index < len(data):
        raise ValueError(f"sample index {sample_index} out of range 0-{len(data) - 1}")
    return [data[sample_index]]


def parse_session_range(value: str | None) -> tuple[int, int] | None:
    """Parse a session selector of the form ``"N"`` or ``"START-END"``.

    Args:
        value: The selector text; ``None`` or an empty string means
            "no restriction".

    Returns:
        ``None`` when ``value`` is empty, otherwise a ``(start, end)`` tuple.
        A single number ``N`` yields ``(N, N)``.

    Raises:
        ValueError: If either part of the selector is not an integer.
    """
    if not value:
        return None
    # Split on the first "-" only, so "3-5-7" is rejected by int("5-7").
    start_text, separator, end_text = value.partition("-")
    try:
        start = int(start_text)
        end = int(end_text) if separator else start
    except ValueError as err:
        raise ValueError(
            f"invalid session range {value!r}; expected 'N' or 'START-END'"
        ) from err
    return start, end


def format_message(message: dict[str, Any]) -> str:
    """Render one message dict as ``"speaker: text"`` plus image lines.

    Each image URL from ``img_url`` (a string or a list of strings) is added
    on its own line, followed by ``": <caption>"`` when ``blip_caption`` is
    set. A caption without any image URL is added as ``"(<caption>)"``.

    Args:
        message: Message dict; reads the optional keys ``speaker``, ``text``,
            ``img_url`` and ``blip_caption``.

    Returns:
        The formatted message, lines separated by ``"\\n"``.
    """
    speaker = message.get("speaker", "unknown")
    text = message.get("text", "")
    caption = message.get("blip_caption", "")

    image_urls = message.get("img_url")
    if image_urls is None:
        # A JSON ``null`` (or a missing key) means the message has no images.
        image_urls = []
    elif isinstance(image_urls, str):
        image_urls = [image_urls]

    lines = [f"{speaker}: {text}"]
    suffix = f": {caption}" if caption else ""
    lines.extend(f"{url}{suffix}" for url in image_urls)
    if caption and not image_urls:
        lines.append(f"({caption})")
    return "\n".join(lines)


def build_sessions(
    sample: dict[str, Any],
    session_range: tuple[int, int] | None = None,
    tail: str = "请记住以上历史对话,只回复 OK。",
) -> list[LocomoSession]:
    """Turn a sample's conversation into one text message per session.

    Sessions are the ``session_<number>`` entries of ``sample["conversation"]``
    and are returned in ascending numeric order. Each message starts with a
    ``[group chat conversation: <date_time>]`` header, followed by the
    formatted messages and, if non-empty, ``tail``; the parts are separated by
    blank lines.

    Args:
        sample: Sample dict providing ``conversation`` and ``sample_id``.
        session_range: Optional inclusive ``(start, end)`` bounds on the
            session number; sessions outside the bounds are skipped.
        tail: Text appended after the last message of every session. The
            default is a Chinese instruction asking the reader to remember the
            conversation history and reply only "OK". Pass ``""`` to omit it.

    Returns:
        The selected sessions as ``LocomoSession`` objects.
    """
    conversation = sample["conversation"]

    # Pair each real session key with its number. fullmatch() skips keys that
    # merely start with "session_" (date-time entries, summaries, ...).
    numbered_keys: list[tuple[int, str]] = []
    for key in conversation:
        match = _SESSION_KEY_RE.fullmatch(key)
        if match:
            numbered_keys.append((int(match.group(1)), key))
    # Sort by number only, so equal numbers keep their original order.
    numbered_keys.sort(key=lambda item: item[0])

    sessions: list[LocomoSession] = []
    for session_number, session_key in numbered_keys:
        if session_range:
            start, end = session_range
            if not start <= session_number <= end:
                continue
        date_time = conversation.get(f"{session_key}_date_time", "")
        parts = [f"[group chat conversation: {date_time}]"]
        parts.extend(format_message(message) for message in conversation[session_key])
        if tail:
            parts.append(tail)
        sessions.append(
            LocomoSession(
                sample_id=str(sample["sample_id"]),
                session_key=session_key,
                date_time=date_time,
                message="\n\n".join(parts),
            )
        )
    return sessions


def build_qas(sample: dict[str, Any], *, include_category_5: bool = False) -> list[LocomoQA]:
    """Collect the question/answer entries of a sample.

    Args:
        sample: Sample dict; reads the optional ``qa`` list and ``sample_id``.
        include_category_5: Keep entries whose category is ``5``; they are
            skipped by default.

    Returns:
        One ``LocomoQA`` per retained entry, in input order, with question,
        answer and category converted to strings.

    Raises:
        KeyError: If a retained entry lacks ``question`` or ``answer``, or the
            sample lacks ``sample_id``.
    """
    qas: list[LocomoQA] = []
    for qa in sample.get("qa", []):
        category = str(qa.get("category", ""))
        if category == "5" and not include_category_5:
            continue
        # NOTE(review): entries without an "answer" key raise KeyError here;
        # confirm category-5 entries carry one before enabling the flag.
        qas.append(
            LocomoQA(
                sample_id=str(sample["sample_id"]),
                question=str(qa["question"]),
                expected=str(qa["answer"]),
                category=category,
                evidence=qa.get("evidence", []),
            )
        )
    return qas


def sample_user_id(prefix: str, sample: dict[str, Any]) -> str:
    """Return ``prefix`` immediately followed by the sample's ``sample_id``."""
    return f"{prefix}{sample['sample_id']}"