- Implement HermesClient for interacting with the Hermes CLI. - Create judge module for grading QA outputs from Hermes memory. - Develop LoCoMo dataset parsing and formatting utilities. - Introduce run_eval script to facilitate memory evaluation using LoCoMo-style datasets.
119 lines
3.6 KiB
Python
119 lines
3.6 KiB
Python
"""LoCoMo dataset parsing and formatting for Hermes memory evaluation."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
|
|
@dataclass(frozen=True)
class LocomoSession:
    """A single LoCoMo conversation session rendered as one prompt message.

    Instances are immutable because the dataclass is declared ``frozen=True``.
    """

    # Identifier of the LoCoMo sample the session was taken from.
    sample_id: str
    # Conversation key of the session (keys starting with "session_" as
    # selected by build_sessions, e.g. "session_1").
    session_key: str
    # Date/time string recorded for the session; build_sessions stores ""
    # when the sample has no matching "<session_key>_date_time" entry.
    date_time: str
    # Fully rendered text to send for this session.
    message: str
|
|
|
|
|
|
@dataclass(frozen=True)
class LocomoQA:
    """One LoCoMo question together with its reference answer.

    The dataclass is ``frozen=True`` but ``evidence`` is a list, so calling
    ``hash()`` on an instance raises ``TypeError``. Equality comparison still
    works.
    """

    # Identifier of the LoCoMo sample the question belongs to.
    sample_id: str
    # Question text posed to the system under evaluation.
    question: str
    # Reference answer, stringified.
    expected: str
    # Stringified category label ("" when the source entry has none).
    category: str
    # Evidence references copied from the source entry.
    # NOTE(review): presumably dialogue ids such as "D1:3"; confirm against the dataset.
    evidence: list[Any]
|
|
|
|
|
|
def load_samples(path: str | Path, sample_index: int | None = None) -> list[dict[str, Any]]:
    """Load LoCoMo samples from a JSON file.

    Args:
        path: Location of a UTF-8 JSON file whose top-level value is a list.
        sample_index: When given, select only the sample at this 0-based position.

    Returns:
        The whole decoded list, or a one-element list holding the selected sample.

    Raises:
        ValueError: If the top-level JSON value is not a list, or if
            ``sample_index`` lies outside the list. Malformed JSON also surfaces
            as ``json.JSONDecodeError``, which is a ``ValueError`` subclass.
    """
    with Path(path).open("r", encoding="utf-8") as file:
        data = json.load(file)
    if not isinstance(data, list):
        raise ValueError("LoCoMo input must be a JSON list")
    if sample_index is None:
        return data
    if not data:
        # Without this check the generic message below would print the
        # nonsensical range "0--1" for an empty input file.
        raise ValueError(f"sample index {sample_index} out of range: LoCoMo input contains no samples")
    if not 0 <= sample_index < len(data):
        raise ValueError(f"sample index {sample_index} out of range 0-{len(data) - 1}")
    return [data[sample_index]]
|
|
|
|
|
|
def parse_session_range(value: str | None) -> tuple[int, int] | None:
    """Parse a session selector such as ``"3"`` or ``"2-5"``.

    Args:
        value: Either a single session number ``"N"`` or an inclusive range
            ``"START-END"``. Surrounding whitespace around each number is
            accepted (``int()`` tolerates it). ``None`` or ``""`` means "no filter".

    Returns:
        ``None`` for an empty selector. Otherwise an inclusive ``(start, end)``
        tuple; a single number ``N`` yields ``(N, N)``.

    Raises:
        ValueError: If a bound is not an integer, or if ``start > end``. Without
            the second check a reversed range would silently select no sessions.
    """
    if not value:
        return None
    start_text, separator, end_text = value.partition("-")
    try:
        start = int(start_text)
        end = int(end_text) if separator else start
    except ValueError as exc:
        raise ValueError(f"invalid session range {value!r}: expected 'N' or 'START-END'") from exc
    if start > end:
        raise ValueError(f"invalid session range {value!r}: start {start} is greater than end {end}")
    return start, end
|
|
|
|
|
|
def format_message(message: dict[str, Any]) -> str:
    """Render a single LoCoMo chat turn as text.

    The first line is ``"<speaker>: <text>"``. Each image URL then follows on its
    own line, suffixed with ``": <caption>"`` when a ``blip_caption`` is present.
    A caption with no image URL is appended as ``"(<caption>)"``.

    Args:
        message: One turn dict. Reads the optional keys ``speaker`` (default
            ``"unknown"``), ``text`` (default ``""``), ``img_url`` (a string or an
            iterable of strings) and ``blip_caption``.

    Returns:
        The rendered turn as newline-separated lines.
    """
    speaker = message.get("speaker")
    if speaker is None:
        speaker = "unknown"
    text = message.get("text")
    if text is None:
        text = ""

    image_urls = message.get("img_url")
    # Accept a missing or explicit-null ``img_url`` (iterating None would raise
    # TypeError) as well as a bare string URL.
    if image_urls is None:
        image_urls = []
    elif isinstance(image_urls, str):
        image_urls = [image_urls]
    caption = message.get("blip_caption", "")

    lines = [f"{speaker}: {text}"]
    suffix = f": {caption}" if caption else ""
    lines.extend(f"{url}{suffix}" for url in image_urls)
    if caption and not image_urls:
        lines.append(f"({caption})")
    return "\n".join(lines)
|
|
|
|
|
|
def _session_number(key: str) -> int | None:
    """Return ``N`` when *key* is exactly ``session_<N>`` (decimal digits), else None.

    This is stricter than a prefix check, so sibling keys such as
    ``session_1_date_time`` (or any other ``session_<N>_<suffix>`` key) are never
    mistaken for a session's message list.
    """
    prefix = "session_"
    if not key.startswith(prefix):
        return None
    digits = key[len(prefix):]
    return int(digits) if digits.isdecimal() else None


def build_sessions(
    sample: dict[str, Any],
    session_range: tuple[int, int] | None = None,
    # Default tail (Chinese): "Please remember the chat history above; reply only OK."
    tail: str = "请记住以上历史对话,只回复 OK。",
) -> list[LocomoSession]:
    """Render each conversation session of a LoCoMo sample as one message.

    The message for a session is a ``[group chat conversation: <date_time>]``
    header, one block per turn (rendered by ``format_message``), and ``tail`` if
    it is non-empty, all joined by blank lines.

    Args:
        sample: A LoCoMo sample. Reads ``sample_id`` and ``conversation``.
            ``conversation`` maps ``session_<N>`` keys to lists of turn dicts and
            ``session_<N>_date_time`` keys to date strings.
        session_range: Optional inclusive ``(start, end)`` filter on session numbers.
        tail: Text appended as the final paragraph of every message; pass ``""``
            to omit it.

    Returns:
        One ``LocomoSession`` per selected session, ordered by session number.

    Raises:
        KeyError: If ``sample`` lacks ``conversation`` or ``sample_id``.
    """
    conversation = sample["conversation"]
    sample_id = str(sample["sample_id"])

    # Collect (session number, key) pairs that pass the optional range filter.
    selected: list[tuple[int, str]] = []
    for key in conversation:
        number = _session_number(key)
        if number is None:
            continue
        if session_range:
            start, end = session_range
            if not start <= number <= end:
                continue
        selected.append((number, key))
    # Order numerically (session_2 before session_10); the sort is stable on ties.
    selected.sort(key=lambda pair: pair[0])

    sessions: list[LocomoSession] = []
    for _, session_key in selected:
        date_time = conversation.get(f"{session_key}_date_time", "")
        parts = [f"[group chat conversation: {date_time}]"]
        parts.extend(format_message(message) for message in conversation[session_key])
        if tail:
            parts.append(tail)
        sessions.append(
            LocomoSession(
                sample_id=sample_id,
                session_key=session_key,
                date_time=date_time,
                message="\n\n".join(parts),
            )
        )
    return sessions
|
|
|
|
|
|
def build_qas(
    sample: dict[str, Any],
    *,
    include_category_5: bool = False,
    unanswerable_expected: str = "Not mentioned in the conversation",
) -> list[LocomoQA]:
    """Extract the question/answer pairs of a LoCoMo sample.

    Args:
        sample: A LoCoMo sample. Reads ``sample_id`` and the optional ``qa`` list.
        include_category_5: Keep entries whose ``category`` stringifies to
            ``"5"``. These are skipped by default.
        unanswerable_expected: Expected answer recorded for a kept category-5
            entry that has no ``answer`` key.

    Returns:
        One ``LocomoQA`` per kept entry, in input order. ``category`` is the
        entry's category stringified, or ``""`` when it is missing.

    Raises:
        KeyError: If ``sample`` lacks ``sample_id``, a kept entry lacks
            ``question``, or a non-category-5 entry lacks ``answer``.
    """
    sample_id = str(sample["sample_id"])
    qas: list[LocomoQA] = []
    for qa in sample.get("qa", []):
        category = str(qa.get("category", ""))
        is_category_5 = category == "5"
        if is_category_5 and not include_category_5:
            continue

        if "answer" in qa:
            expected = str(qa["answer"])
        elif is_category_5:
            # NOTE(review): in the public LoCoMo release, category-5 (adversarial)
            # entries usually carry ``adversarial_answer`` (the misleading answer)
            # instead of ``answer``. Reading qa["answer"] therefore raised KeyError
            # whenever include_category_5=True. These questions are meant to be
            # unanswerable from the conversation; confirm that this placeholder
            # matches what the judge expects.
            expected = unanswerable_expected
        else:
            raise KeyError(f"QA entry in sample {sample_id!r} has no 'answer': {qa.get('question')!r}")

        qas.append(
            LocomoQA(
                sample_id=sample_id,
                question=str(qa["question"]),
                expected=expected,
                category=category,
                evidence=qa.get("evidence", []),
            )
        )
    return qas
|
|
|
|
|
|
def sample_user_id(prefix: str, sample: dict[str, Any]) -> str:
    """Return a per-sample user identifier: *prefix* followed by the sample's ``sample_id``."""
    identifier = sample["sample_id"]
    return "{}{}".format(prefix, identifier)
|