- Implement HermesClient for interacting with the Hermes CLI.
- Create a judge module for grading QA outputs from Hermes memory.
- Develop LoCoMo dataset parsing and formatting utilities.
- Introduce a run_eval script for running memory evaluations on LoCoMo-style datasets.
187 lines
6.7 KiB
Python
187 lines
6.7 KiB
Python
"""Run Hermes memory evaluation using LoCoMo-style datasets."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import yaml
|
|
|
|
if not __package__:
    # Executed as a plain script (no parent package), so the absolute
    # ``eval.hermes_memory_eval`` imports below would not resolve on their own.
    # Prepend the directory two levels above this file's directory -- the one
    # that contains the ``eval`` package -- to the module search path.
    sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
|
|
|
|
from eval.hermes_memory_eval.hermes_client import HermesClient, HermesClientConfig
|
|
from eval.hermes_memory_eval.locomo import (
|
|
build_qas,
|
|
build_sessions,
|
|
load_samples,
|
|
parse_session_range,
|
|
sample_user_id,
|
|
)
|
|
|
|
|
|
def load_config(path: str | Path) -> dict[str, Any]:
    """Load the evaluation YAML config from *path*.

    An empty file (which ``yaml.safe_load`` parses as ``None``) yields an empty
    dict, so callers can use ``config.get(...)`` unconditionally.

    Raises:
        OSError: if the file cannot be opened.
        yaml.YAMLError: if the file is not valid YAML.
        TypeError: if the top-level YAML value is not a mapping.
    """
    with Path(path).open("r", encoding="utf-8") as file:
        data = yaml.safe_load(file)
    if data is None:
        return {}
    if not isinstance(data, dict):
        # Every caller does config.get(...); fail here with a clear message
        # instead of an AttributeError somewhere deep inside a run.
        raise TypeError(
            f"config file {path} must contain a YAML mapping at the top level, "
            f"got {type(data).__name__}"
        )
    return data
|
|
|
|
|
|
def memory_env(config: dict[str, Any]) -> dict[str, str]:
    """Translate the ``memory`` config section into ``MEMORY_SYSTEM_*`` env vars.

    Only the keys listed below are exported, and only when present with a
    non-``None`` value; each value is converted with ``str()``. A missing or
    empty (``null``) ``memory`` section yields an empty mapping.

    Args:
        config: Parsed evaluation config (as returned by ``load_config``).

    Returns:
        Mapping of environment-variable names to string values; the run
        functions pass it as ``env=`` to ``HermesClient.chat``.
    """
    # ``or {}`` rather than a .get default: a YAML ``memory:`` key with no
    # value parses to None, which .get("memory", {}) would return unchanged.
    memory = config.get("memory") or {}
    mappings = {
        "env_file": "MEMORY_SYSTEM_ENV_FILE",
        "endpoint": "MEMORY_SYSTEM_ENDPOINT",
        "api_key": "MEMORY_SYSTEM_API_KEY",
        "search_use_llm": "MEMORY_SYSTEM_SEARCH_USE_LLM",
        "commit_every_turns": "MEMORY_SYSTEM_COMMIT_EVERY_TURNS",
        "commit_interval_seconds": "MEMORY_SYSTEM_COMMIT_INTERVAL_SECONDS",
    }
    # NOTE(review): str() renders booleans as "True"/"False"; confirm that is
    # the spelling the memory system expects for SEARCH_USE_LLM.
    return {
        env_key: str(memory[key])
        for key, env_key in mappings.items()
        if memory.get(key) is not None
    }
|
|
|
|
|
|
def build_client(config: dict[str, Any]) -> HermesClient:
    """Construct a ``HermesClient`` from the ``hermes`` config section.

    Missing keys fall back to the defaults below (command ``hermes``, 600 s
    timeout, quiet mode on, source ``memory-eval``, no extra args). A missing
    or ``null`` section, or a ``null`` ``extra_args``, is treated as empty; a
    single string ``extra_args`` is treated as one argument.
    """
    # ``or {}``: a YAML ``hermes:`` key with no value parses to None.
    hermes = config.get("hermes") or {}
    extra_args = hermes.get("extra_args") or []
    if isinstance(extra_args, str):
        # Iterating a bare string would split it into single characters.
        extra_args = [extra_args]
    return HermesClient(
        HermesClientConfig(
            command=str(hermes.get("command", "hermes")),
            timeout_seconds=int(hermes.get("timeout_seconds", 600)),
            quiet=bool(hermes.get("quiet", True)),
            source=str(hermes.get("source", "memory-eval")),
            extra_args=[str(arg) for arg in extra_args],
        )
    )
|
|
|
|
|
|
# Default QA prompt (Chinese). It asks the agent to first query long-term
# memory with the ``memory_system_search`` tool, answer from what it retrieves,
# and say it does not know rather than make an answer up.
_DEFAULT_QA_PROMPT_TEMPLATE = (
    "请先使用 memory_system_search 查询长期记忆,再根据检索到的记忆回答问题。"
    "如果记忆中没有答案,请直接说不知道,不要编造。\n\n问题:{question}"
)


def qa_prompt(config: dict[str, Any], question: str) -> str:
    """Render the prompt sent to Hermes for a single QA item.

    The template is ``qa.prompt_template`` from the config, falling back to
    ``_DEFAULT_QA_PROMPT_TEMPLATE`` when the ``qa`` section or the key is
    missing or ``null``. It is rendered with ``str.format(question=...)``, so
    it may reference ``{question}`` and must escape literal braces as
    ``{{`` / ``}}``.

    Raises:
        KeyError / IndexError / ValueError: from ``str.format`` when the
            template has other placeholders or unbalanced braces.
    """
    # ``or {}``: a YAML ``qa:`` key with no value parses to None.
    qa_config = config.get("qa") or {}
    template = qa_config.get("prompt_template")
    if template is None:
        # An explicit ``prompt_template: null`` must not become the prompt "None".
        template = _DEFAULT_QA_PROMPT_TEMPLATE
    return str(template).format(question=question)
|
|
|
|
|
|
def write_jsonl(path: str | Path, records: list[dict[str, Any]]) -> None:
    """Write *records* to *path* in JSON Lines format, one object per line.

    Missing parent directories are created and an existing file is
    overwritten. Non-ASCII text is written verbatim as UTF-8 instead of
    being ASCII-escaped.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    lines = (json.dumps(item, ensure_ascii=False) + "\n" for item in records)
    with target.open("w", encoding="utf-8") as handle:
        handle.writelines(lines)
|
|
|
|
|
|
def run_ingest(args: argparse.Namespace) -> None:
    """Send every selected LoCoMo session to Hermes as a chat message (ingest mode).

    Each session returned by ``build_sessions`` is sent as one message under
    the sample's user id (``--user`` if given, otherwise derived from
    ``memory.user_prefix``). A chat that raises is logged to stderr and
    recorded with ``status="failed"`` instead of aborting the run.

    If ``--output`` is set, one JSONL record per processed session is written.
    The write happens in a ``finally`` block, so records collected before an
    interruption (Ctrl-C, or an error outside the per-session handler) are
    kept rather than discarded.
    """
    config = load_config(args.config)
    client = build_client(config)
    env = memory_env(config)
    samples = load_samples(args.input, args.sample)
    session_range = parse_session_range(args.sessions)
    # ``or {}``: a YAML ``memory:`` key with no value parses to None.
    memory_config = config.get("memory") or {}
    user_prefix = str(memory_config.get("user_prefix", "locomo-"))
    records: list[dict[str, Any]] = []

    try:
        for sample in samples:
            user_id = args.user or sample_user_id(user_prefix, sample)
            sessions = build_sessions(sample, session_range=session_range, tail=args.tail)
            print(
                f"=== Sample {sample['sample_id']} user={user_id} sessions={len(sessions)} ===",
                file=sys.stderr,
            )
            for session in sessions:
                try:
                    response = client.chat(session.message, user_id=user_id, env=env)
                    status = "success"
                except Exception as exc:  # record the failure and continue with the next session
                    response = str(exc)
                    status = "failed"
                print(f"[{session.sample_id}/{session.session_key}] {status}", file=sys.stderr)
                records.append(
                    {
                        "mode": "ingest",
                        "status": status,
                        "sample_id": session.sample_id,
                        "session": session.session_key,
                        "date_time": session.date_time,
                        "user_id": user_id,
                        "response": response,
                    }
                )
    finally:
        # KeyboardInterrupt is not an Exception subclass, so it escapes the
        # per-session handler; persist the partial results regardless.
        if args.output:
            write_jsonl(args.output, records)
            print(f"written: {args.output}", file=sys.stderr)
|
|
|
|
|
|
def run_qa(args: argparse.Namespace) -> None:
    """Ask Hermes each selected LoCoMo question and record its answers (QA mode).

    For each sample, the questions from ``build_qas`` (optionally truncated to
    ``--count``) are rendered with ``qa_prompt`` and sent under the sample's
    user id. A chat that raises is logged and recorded with
    ``status="failed"``; an invalid prompt template, however, stops the run,
    because it is a configuration error rather than a per-question failure.

    If ``--output`` is set, one JSONL record per processed question (with the
    expected answer, category and evidence) is written. The write happens in a
    ``finally`` block, so partial results survive an interruption.
    """
    config = load_config(args.config)
    client = build_client(config)
    env = memory_env(config)
    samples = load_samples(args.input, args.sample)
    # ``or {}``: a YAML ``memory:`` key with no value parses to None.
    memory_config = config.get("memory") or {}
    user_prefix = str(memory_config.get("user_prefix", "locomo-"))
    records: list[dict[str, Any]] = []

    try:
        for sample in samples:
            user_id = args.user or sample_user_id(user_prefix, sample)
            qas = build_qas(sample, include_category_5=args.include_category_5)
            if args.count is not None:
                qas = qas[: args.count]
            print(f"=== Sample {sample['sample_id']} user={user_id} qa={len(qas)} ===", file=sys.stderr)
            for index, qa in enumerate(qas, start=1):
                # Render outside the try: a broken prompt_template should fail
                # fast instead of marking every question as "failed".
                prompt = qa_prompt(config, qa.question)
                try:
                    response = client.chat(prompt, user_id=user_id, env=env)
                    status = "success"
                except Exception as exc:  # record the failure and continue with the next question
                    response = str(exc)
                    status = "failed"
                print(f"[{qa.sample_id}] Q{index}/{len(qas)} {status}", file=sys.stderr)
                records.append(
                    {
                        "mode": "qa",
                        "status": status,
                        "sample_id": qa.sample_id,
                        "user_id": user_id,
                        "qi": index,
                        "question": qa.question,
                        "expected": qa.expected,
                        "response": response,
                        "category": qa.category,
                        "evidence": qa.evidence,
                    }
                )
    finally:
        # KeyboardInterrupt is not an Exception subclass, so it escapes the
        # per-question handler; persist the partial results regardless.
        if args.output:
            write_jsonl(args.output, records)
            print(f"written: {args.output}", file=sys.stderr)
|
|
|
|
|
|
def main() -> None:
    """Parse command-line arguments and dispatch to ingest or QA mode."""
    # Resolve the example config relative to this script rather than the
    # current working directory, so the default works from any directory.
    default_config = Path(__file__).resolve().parent / "config.example.yaml"

    parser = argparse.ArgumentParser(description="Evaluate Hermes memory with LoCoMo-style datasets")
    parser.add_argument(
        "mode",
        choices=["ingest", "qa"],
        help="ingest: send dataset sessions to Hermes; qa: ask the dataset questions",
    )
    parser.add_argument("input", help="Path to LoCoMo JSON dataset")
    parser.add_argument(
        "--config",
        default=str(default_config),
        help="Path to the YAML config (default: %(default)s)",
    )
    parser.add_argument("--output", default=None, help="Write per-session / per-question results to this JSONL file")
    parser.add_argument("--sample", type=int, default=None, help="Sample selector passed to load_samples")
    parser.add_argument(
        "--user",
        default=None,
        help="Use this memory user id for every sample (default: derived from memory.user_prefix)",
    )
    parser.add_argument("--sessions", default=None, help="Ingest session range, for example 1-4")
    parser.add_argument(
        "--tail",
        default="请记住以上历史对话,只回复 OK。",
        help="Ingest only: tail text passed to build_sessions",
    )
    parser.add_argument("--count", type=int, default=None, help="QA only: ask at most this many questions per sample")
    parser.add_argument("--include-category-5", action="store_true", help="QA only: include category-5 questions")
    args = parser.parse_args()

    if args.mode == "ingest":
        run_ingest(args)
    else:
        run_qa(args)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the command-line interface only when executed directly, not on import.
    main()
|