"""LLM judge for Hermes memory QA outputs.""" from __future__ import annotations import argparse import asyncio import json import os from pathlib import Path from typing import Any import httpx import yaml def load_answers(path: str | Path) -> list[dict[str, Any]]: input_path = Path(path) if input_path.suffix == ".jsonl": with input_path.open("r", encoding="utf-8") as file: return [json.loads(line) for line in file if line.strip()] with input_path.open("r", encoding="utf-8") as file: data = json.load(file) if isinstance(data, dict): return data.get("results", data.get("grades", [])) if isinstance(data, list): return data raise ValueError("answers file must be JSON list, JSONL, or object with results") def load_config(path: str | Path | None) -> dict[str, Any]: if not path: return {} config_path = Path(path) if not config_path.exists(): return {} with config_path.open("r", encoding="utf-8") as file: return yaml.safe_load(file) or {} def resolve_judge_config(args: argparse.Namespace) -> dict[str, Any]: config = load_config(args.config) judge = config.get("judge", {}) base_url = args.base_url or judge.get("base_url") or os.environ.get("OPENAI_BASE_URL") or "https://api.openai.com/v1" model = args.model or judge.get("model") or "gpt-4o-mini" api_key_env = args.api_key_env or judge.get("api_key_env") or "OPENAI_API_KEY" api_key = args.api_key or judge.get("api_key") or os.environ.get(api_key_env, "") parallel = args.parallel if args.parallel is not None else int(judge.get("parallel", 4)) timeout_seconds = args.timeout_seconds if args.timeout_seconds is not None else int(judge.get("timeout_seconds", 120)) return { "base_url": str(base_url), "model": str(model), "api_key": str(api_key), "api_key_env": str(api_key_env), "parallel": int(parallel), "timeout_seconds": int(timeout_seconds), } def judge_prompt(question: str, expected: str, response: str) -> list[dict[str, str]]: return [ { "role": "system", "content": "You are an expert grader for long-term memory QA. Return JSON only.", }, { "role": "user", "content": ( "Decide whether the generated answer matches the gold answer.\n" "Be generous: count it correct if it refers to the same fact, topic, person, place, or date.\n" "Return exactly JSON: {\"is_correct\":\"CORRECT\" or \"WRONG\", \"reasoning\":\"short reason\"}.\n\n" f"Question: {question}\n" f"Gold answer: {expected}\n" f"Generated answer: {response}" ), }, ] async def grade_one( client: httpx.AsyncClient, *, base_url: str, api_key: str, model: str, item: dict[str, Any], ) -> dict[str, Any]: payload = { "model": model, "temperature": 0, "messages": judge_prompt(item["question"], item["expected"], item["response"]), } response = await client.post( f"{base_url.rstrip('/')}/chat/completions", headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}, json=payload, ) response.raise_for_status() content = response.json()["choices"][0]["message"]["content"] parsed = json.loads(content) label = str(parsed.get("is_correct", parsed.get("label", "WRONG"))).strip().lower() return { **item, "grade": label == "correct", "judge_reasoning": parsed.get("reasoning", ""), } async def grade_answers( answers: list[dict[str, Any]], *, base_url: str, api_key: str, model: str, timeout_seconds: int = 120, parallel: int = 4, ) -> list[dict[str, Any]]: limits = httpx.Limits(max_connections=max(1, parallel)) async with httpx.AsyncClient(timeout=timeout_seconds, limits=limits) as client: semaphore = asyncio.Semaphore(max(1, parallel)) async def _grade(item: dict[str, Any]) -> dict[str, Any]: async with semaphore: return await grade_one(client, base_url=base_url, api_key=api_key, model=model, item=item) return await asyncio.gather(*[_grade(item) for item in answers]) def summarize(grades: list[dict[str, Any]]) -> dict[str, Any]: correct = sum(1 for item in grades if item.get("grade")) total = len(grades) categories: dict[str, dict[str, int]] = {} for item in grades: category = str(item.get("category", "unknown")) categories.setdefault(category, {"correct": 0, "total": 0}) categories[category]["total"] += 1 if item.get("grade"): categories[category]["correct"] += 1 return { "score": correct / total if total else 0.0, "correct": correct, "total": total, "categories": categories, } def main() -> None: parser = argparse.ArgumentParser(description="Judge Hermes memory QA answers") parser.add_argument("input", help="QA JSONL or JSON file") parser.add_argument("--config", default="eval/hermes_memory_eval/config.yaml") parser.add_argument("--output", default=None) parser.add_argument("--base-url", default=None) parser.add_argument("--api-key", default=None) parser.add_argument("--api-key-env", default=None) parser.add_argument("--model", default=None) parser.add_argument("--parallel", type=int, default=None) parser.add_argument("--timeout-seconds", type=int, default=None) args = parser.parse_args() judge_config = resolve_judge_config(args) if not judge_config["api_key"]: raise SystemExit(f"missing --api-key or {judge_config['api_key_env']}") answers = load_answers(args.input) grades = asyncio.run( grade_answers( answers, base_url=judge_config["base_url"], api_key=judge_config["api_key"], model=judge_config["model"], parallel=judge_config["parallel"], timeout_seconds=judge_config["timeout_seconds"], ) ) summary = summarize(grades) print(f"score: {summary['correct']}/{summary['total']} ({summary['score']:.2%})") for category, stats in sorted(summary["categories"].items()): total = stats["total"] score = stats["correct"] / total if total else 0.0 print(f"category {category}: {stats['correct']}/{total} ({score:.2%})") if args.output: output = {"summary": summary, "grades": grades} Path(args.output).parent.mkdir(parents=True, exist_ok=True) with Path(args.output).open("w", encoding="utf-8") as file: json.dump(output, file, indent=2, ensure_ascii=False) if __name__ == "__main__": main()