Files
memory-gateway/plugins/memory-gateway-agent/scripts/hermes_hook_probe.py

175 lines
6.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
from __future__ import annotations
import json
import os
import sys
import urllib.error
import urllib.request
from pathlib import Path
from typing import Any
# Fixed test identities shared by every gateway request and hook payload in
# this probe; _audit_count filters audit rows by USER_ID and AGENT_ID.
USER_ID: str = "test_user_memory_gateway_plugin"
AGENT_ID: str = "test_hermes_memory_gateway_plugin"
WORKSPACE_ID: str = "test_workspace_memory_gateway_plugin"
SESSION_ID: str = "test_session_memory_gateway_plugin_001"
def _ensure_paths() -> None:
    """Make the plugin package and the Hermes checkout importable.

    Prepends three directories to ``sys.path``, skipping any already present:

    - the plugin root, i.e. the directory above this script's directory;
    - the Hermes repository;
    - its ``hermes_cli`` subdirectory.

    Each one is inserted at index 0, so the last inserted (``hermes_cli``)
    ends up first in the search order.

    The Hermes repository comes from the ``HERMES_REPO`` environment
    variable. When that is unset or empty, it defaults to
    ``~/.hermes/hermes-agent`` in the current user's home directory.
    """
    plugin_root = Path(__file__).resolve().parents[1]
    # The default used to be hard-coded to one developer's home
    # (/home/tom/...). Expanding "~" keeps that machine working and makes the
    # probe usable for other users. Using `or` instead of a .get() default
    # means an empty HERMES_REPO no longer resolves to the current directory.
    hermes_repo = Path(
        os.environ.get("HERMES_REPO") or os.path.expanduser("~/.hermes/hermes-agent")
    ).expanduser()
    hermes_cli = hermes_repo / "hermes_cli"
    for path in (plugin_root, hermes_repo, hermes_cli):
        entry = str(path)
        if entry not in sys.path:
            sys.path.insert(0, entry)


# Runs at import time, before the project import that follows it, so that
# ``memory_gateway_plugin`` resolves. Presumably it is also what makes the
# Hermes ``plugins`` package (imported inside ``run``) resolvable.
_ensure_paths()
from memory_gateway_plugin.output import dumps_safe, summarize_data
def _request(
    method: str,
    url: str,
    payload: dict[str, Any] | None = None,
    api_key: str = "",
    timeout: float = 10.0,
) -> dict[str, Any]:
    """Send a JSON request and fold every outcome into a result envelope.

    This function never raises. Transport errors, HTTP errors, bad URLs,
    unserializable payloads and unparseable responses are all returned as
    ``{"ok": False, ...}``.

    Args:
        method: HTTP method, e.g. ``"GET"`` or ``"POST"``.
        url: Absolute request URL.
        payload: Optional JSON-serializable body. ``None`` sends no body.
        api_key: Sent as the ``X-API-Key`` header when non-empty.
        timeout: Timeout in seconds passed to ``urlopen`` (default 10).

    Returns:
        On success: ``{"ok": True, "status_code": ..., "data": parsed_json}``.
        ``data`` is ``{}`` when the body is empty.

        On failure: ``{"ok": False, "status_code": ..., "error": str}``,
        with ``error`` truncated to 500 characters. ``status_code`` is
        ``None`` when no HTTP response was obtained.
    """
    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["X-API-Key"] = api_key
    try:
        # Building the body and the Request object can raise too (for
        # example, non-serializable payloads or URLs without a scheme), so
        # they run inside the try to keep the never-raise contract.
        body = None if payload is None else json.dumps(payload, ensure_ascii=False).encode("utf-8")
        req = urllib.request.Request(url, data=body, headers=headers, method=method)
        with urllib.request.urlopen(req, timeout=timeout) as response:
            status_code = getattr(response, "status", 200)
            raw_bytes = response.read()
    except urllib.error.HTTPError as exc:
        try:
            body_text = exc.read().decode("utf-8", errors="replace")
        except Exception:
            body_text = str(exc.reason)
        return {"ok": False, "status_code": exc.code, "error": body_text[:500]}
    except Exception as exc:
        return {"ok": False, "status_code": None, "error": str(exc)[:500]}
    # Decode and parse outside the transport try-block. A non-JSON (or
    # non-UTF-8) response keeps its real status code instead of being
    # reported as a connection failure with status_code None.
    try:
        raw = raw_bytes.decode("utf-8")
        data = json.loads(raw) if raw else {}
    except ValueError as exc:  # covers JSONDecodeError and UnicodeDecodeError
        return {
            "ok": False,
            "status_code": status_code,
            "error": f"invalid JSON response: {exc}"[:500],
        }
    return {"ok": True, "status_code": status_code, "data": data}
def _ensure_user() -> dict[str, Any]:
    """POST the probe's test user record to the gateway's ``/v1/users`` endpoint.

    The gateway base URL comes from ``MEMORY_GATEWAY_URL`` (default
    ``http://127.0.0.1:1934``, trailing slashes stripped). The API key comes
    from ``MEMORY_GATEWAY_API_KEY``.

    Returns:
        The result envelope produced by ``_request``.
    """
    base_url = os.environ.get("MEMORY_GATEWAY_URL", "http://127.0.0.1:1934").rstrip("/")
    user_record = {
        "user_id": USER_ID,
        "display_name": "Memory Gateway Hook Probe",
        "preferences": {"purpose": "hook_probe"},
    }
    key = os.environ.get("MEMORY_GATEWAY_API_KEY", "")
    return _request("POST", base_url + "/v1/users", user_record, api_key=key)
def _count_audit_rows(rows: list[Any], action: str, user_id: str, agent_id: str) -> int:
    """Return how many audit rows match *action*, *user_id* and *agent_id*.

    Rows that are not dicts are skipped instead of raising ``AttributeError``.
    """
    return sum(
        1
        for row in rows
        if isinstance(row, dict)
        and row.get("action") == action
        and row.get("actor_user_id") == user_id
        and row.get("actor_agent_id") == agent_id
    )


def _audit_count(action: str, limit: int = 1000) -> int:
    """Count this probe's audit-log entries for *action* on the Memory Gateway.

    Fetches ``/v1/audit?limit=<limit>``. The base URL comes from
    ``MEMORY_GATEWAY_URL`` (default ``http://127.0.0.1:1934``) and the API
    key from ``MEMORY_GATEWAY_API_KEY``. It then counts rows whose
    ``action``, ``actor_user_id`` and ``actor_agent_id`` equal *action*,
    ``USER_ID`` and ``AGENT_ID``.

    Args:
        action: Audit action name to count, e.g. ``"memory_search"``.
        limit: Maximum number of audit rows to request (default 1000).
            NOTE(review): if the log holds more rows than this, the count
            may saturate depending on the endpoint's ordering. Confirm.

    Returns:
        The match count, or -1 when it cannot be determined (the request
        failed or the body is not a JSON list).
    """
    gateway_url = os.environ.get("MEMORY_GATEWAY_URL", "http://127.0.0.1:1934").rstrip("/")
    api_key = os.environ.get("MEMORY_GATEWAY_API_KEY", "")
    result = _request("GET", f"{gateway_url}/v1/audit?limit={limit}", api_key=api_key)
    if not result.get("ok"):
        return -1
    rows = result.get("data") or []
    if not isinstance(rows, list):
        # A non-list body used to be iterated as its keys and crash on
        # ``str.get``. Report "unknown" instead.
        return -1
    return _count_audit_rows(rows, action, USER_ID, AGENT_ID)
def _hook_report(manager: Any, hook_name: str, payload: dict[str, Any], audit_action: str = "") -> dict[str, Any]:
    """Invoke one plugin hook and describe the outcome.

    Args:
        manager: Hermes plugin manager. Its private ``_hooks`` mapping is
            inspected to decide whether *hook_name* has any callbacks, and
            ``manager.invoke_hook`` is called.
        hook_name: Name of the hook to invoke.
        payload: Keyword arguments forwarded to ``invoke_hook``.
        audit_action: Gateway audit action whose count is sampled before and
            after the call. An empty string disables the audit check.

    Returns:
        A dict with keys ``registered``, ``invoked``, ``result_type``,
        ``result``, ``audit_action``, ``audit_delta`` and ``error``.
        ``audit_delta`` is ``None`` when either audit count is unavailable.
        ``error`` holds the exception text (truncated to 500 characters) when
        the hook, or summarizing its result, raised.
    """
    hooks = getattr(manager, "_hooks", None) or {}
    registered = hook_name in hooks and bool(hooks[hook_name])

    def _count() -> int:
        # -1 means "unknown". Audit sampling is best-effort: a failure here
        # must neither abort the probe nor be reported as a hook failure.
        if not audit_action:
            return -1
        try:
            return _audit_count(audit_action)
        except Exception:
            return -1

    report: dict[str, Any] = {
        "registered": registered,
        "invoked": False,
        "result_type": "",
        "result": None,
        "audit_action": audit_action,
        "audit_delta": None,
        "error": "",
    }
    before = _count()
    try:
        result = manager.invoke_hook(hook_name, **payload)
        summary = summarize_data(result)
    except Exception as exc:
        report["error"] = str(exc)[:500]
        return report
    after = _count()
    report["invoked"] = True
    report["result_type"] = type(result).__name__
    report["result"] = summary
    report["audit_delta"] = (after - before) if before >= 0 and after >= 0 else None
    return report
def _plugin_summary(plugin: Any) -> dict[str, Any]:
    """Describe a plugin record from ``PluginManager._plugins`` for the report.

    ``plugin`` is ``None`` when the plugin was not found. Its attributes are
    read with ``getattr`` because their presence cannot be confirmed from
    this script. Missing or ``None`` tool/hook lists are reported as ``[]``.
    """
    if plugin is None:
        return {
            "enabled": False,
            "tools_registered": [],
            "hooks_registered": [],
            "error": "plugin_not_found",
        }
    return {
        "enabled": bool(getattr(plugin, "enabled", False)),
        "tools_registered": sorted(getattr(plugin, "tools_registered", None) or []),
        "hooks_registered": sorted(getattr(plugin, "hooks_registered", None) or []),
        "error": getattr(plugin, "error", None),
    }


def run(auto_commit: bool = False) -> dict[str, Any]:
    """Load the Hermes plugins, fire the Memory Gateway hooks once, and report.

    Side effects:
        - Mutates ``os.environ`` for the rest of the process.
          ``MEMORY_GATEWAY_URL`` gets a default if unset. The default
          user/agent/workspace IDs and ``MEMORY_GATEWAY_AUTO_COMMIT_SESSION``
          are overwritten with the probe's values.
        - POSTs the probe user to the gateway before loading plugins.

    Args:
        auto_commit: Exported as ``MEMORY_GATEWAY_AUTO_COMMIT_SESSION``
            (``"true"``/``"false"``) before the plugins are loaded. It is
            presumably read by the plugin; confirm.

    Returns:
        A JSON-friendly report. ``ok`` is true only when every probed hook
        was both registered and invoked without raising.
    """
    os.environ.setdefault("MEMORY_GATEWAY_URL", "http://127.0.0.1:1934")
    os.environ["MEMORY_GATEWAY_DEFAULT_USER_ID"] = USER_ID
    os.environ["MEMORY_GATEWAY_DEFAULT_AGENT_ID"] = AGENT_ID
    os.environ["MEMORY_GATEWAY_DEFAULT_WORKSPACE_ID"] = WORKSPACE_ID
    os.environ["MEMORY_GATEWAY_AUTO_COMMIT_SESSION"] = "true" if auto_commit else "false"
    # Imported here, after _ensure_paths() has extended sys.path and the
    # environment above is in place.
    from plugins import PluginManager

    ensure_user = _ensure_user()
    manager = PluginManager()
    manager.discover_and_load()
    base = {
        "user_id": USER_ID,
        "agent_id": AGENT_ID,
        "workspace_id": WORKSPACE_ID,
        "session_id": SESSION_ID,
        "task_id": SESSION_ID,
        "model": "hook-probe",
        "platform": "cli",
    }
    # Hooks are invoked in insertion order: session start -> pre-LLM ->
    # post-LLM -> session end.
    hooks = {
        "on_session_start": dict(base),
        "pre_llm_call": {
            **base,
            "user_message": "Memory Gateway plugin integration test memory preference",
            "conversation_history": [],
            "is_first_turn": True,
        },
        "post_llm_call": {
            **base,
            "user_message": "请记住Memory Gateway plugin hook probe 偏好保存简短摘要型 episode。",
            "assistant_response": "已记录为候选摘要,后续由 session commit 判断是否提升为长期记忆。",
        },
        "on_session_end": dict(base),
    }
    # Audit action expected to be logged by each hook. on_session_start has
    # none, so its audit check is skipped.
    audit_actions = {
        "pre_llm_call": "memory_search",
        "post_llm_call": "append_episode",
        "on_session_end": "commit_session",
    }
    reports = {name: _hook_report(manager, name, payload, audit_actions.get(name, "")) for name, payload in hooks.items()}
    # Read defensively, like the `_hooks` lookup in _hook_report: a manager
    # without `_plugins` now yields "plugin_not_found" instead of crashing.
    plugin = (getattr(manager, "_plugins", None) or {}).get("memory-gateway-agent")
    return {
        # NOTE(review): ensure_user does not affect `ok`, presumably because
        # a repeat run may legitimately fail to re-create the user. Confirm.
        "ok": all(item["registered"] and item["invoked"] for item in reports.values()),
        "auto_commit": auto_commit,
        "ensure_user": {
            "ok": ensure_user.get("ok"),
            "status_code": ensure_user.get("status_code"),
            "data": summarize_data(ensure_user.get("data")),
        },
        "plugin": _plugin_summary(plugin),
        "hooks": reports,
    }
def main() -> int:
    """Run the probe once and print its report via ``dumps_safe``.

    Auto-commit is enabled when ``MEMORY_GATEWAY_AUTO_COMMIT_SESSION`` is one
    of ``1``/``true``/``yes``/``on``. The check is case-insensitive and
    ignores surrounding whitespace.

    Returns:
        Process exit status: 0 when the report's ``ok`` flag is truthy,
        otherwise 1.
    """
    flag = os.environ.get("MEMORY_GATEWAY_AUTO_COMMIT_SESSION", "")
    report = run(auto_commit=flag.strip().lower() in ("1", "true", "yes", "on"))
    print(dumps_safe(report))
    if report.get("ok"):
        return 0
    return 1


if __name__ == "__main__":
    sys.exit(main())