Add Memory Gateway agent plugin

This commit is contained in:
2026-05-06 16:10:04 +08:00
parent e65731a273
commit c44af407d4
48 changed files with 3111 additions and 0 deletions

View File

@ -0,0 +1,26 @@
from __future__ import annotations
import importlib.util
import sys
from pathlib import Path
def _load_cleanup():
path = Path(__file__).resolve().parents[1] / "scripts" / "cleanup_test_memories.py"
spec = importlib.util.spec_from_file_location("cleanup_test_memories_guard_test", path)
assert spec and spec.loader
module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module)
return module
def test_cleanup_requires_test_user():
module = _load_cleanup()
try:
module.run("real_user")
except ValueError as exc:
assert "cleanup_refuses_non_test_user" in str(exc)
else:
raise AssertionError("cleanup accepted a non-test user")

View File

@ -0,0 +1,17 @@
from __future__ import annotations
import importlib.util
import sys
from pathlib import Path
def test_cleanup_test_memories_imports():
path = Path(__file__).resolve().parents[1] / "scripts" / "cleanup_test_memories.py"
spec = importlib.util.spec_from_file_location("cleanup_test_memories", path)
assert spec and spec.loader
module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module)
assert callable(module.run)
assert module.USER_ID.startswith("test_user_")

View File

@ -0,0 +1,105 @@
from __future__ import annotations
import io
import json
import urllib.error
from memory_gateway_plugin.client import MemoryGatewayClient
from memory_gateway_plugin.config import PluginConfig
class FakeResponse:
status = 200
def __init__(self, payload):
self.payload = payload
def __enter__(self):
return self
def __exit__(self, *args):
return False
def read(self):
return json.dumps(self.payload).encode("utf-8")
def test_client_search_success(monkeypatch):
seen = {}
def fake_urlopen(request, timeout):
seen["url"] = request.full_url
seen["timeout"] = timeout
return FakeResponse({"results": [], "total": 0})
monkeypatch.setattr("urllib.request.urlopen", fake_urlopen)
client = MemoryGatewayClient(PluginConfig(gateway_url="http://gateway", timeout=7))
result = client.search_memory({"user_id": "u", "query": "demo"})
assert result["ok"] is True
assert result["endpoint"] == "/v1/memory/search"
assert seen["url"] == "http://gateway/v1/memory/search"
assert seen["timeout"] == 7
def test_client_network_error(monkeypatch):
def fake_urlopen(request, timeout):
raise urllib.error.URLError("connection refused")
monkeypatch.setattr("urllib.request.urlopen", fake_urlopen)
client = MemoryGatewayClient(PluginConfig(gateway_url="http://gateway"))
result = client.search_memory({"user_id": "u", "query": "demo"})
assert result["ok"] is False
assert result["status_code"] is None
assert "connection refused" in result["error"]
def test_commit_session_calls_correct_endpoint(monkeypatch):
seen = {}
def fake_urlopen(request, timeout):
seen["url"] = request.full_url
return FakeResponse({"session_id": "sess_1"})
monkeypatch.setattr("urllib.request.urlopen", fake_urlopen)
client = MemoryGatewayClient(PluginConfig(gateway_url="http://gateway"))
result = client.commit_session("sess_1", {"user_id": "u", "session_id": "sess_1"})
assert result["ok"] is True
assert seen["url"] == "http://gateway/v1/sessions/sess_1/commit"
def test_feedback_calls_correct_endpoint(monkeypatch):
seen = {}
def fake_urlopen(request, timeout):
seen["url"] = request.full_url
return FakeResponse({"status": "ok"})
monkeypatch.setattr("urllib.request.urlopen", fake_urlopen)
client = MemoryGatewayClient(PluginConfig(gateway_url="http://gateway"))
result = client.send_feedback("mem_1", {"user_id": "u", "feedback": "incorrect"})
assert result["ok"] is True
assert seen["url"] == "http://gateway/v1/memory/mem_1/feedback"
def test_client_http_error(monkeypatch):
def fake_urlopen(request, timeout):
raise urllib.error.HTTPError(
url=request.full_url,
code=401,
msg="unauthorized",
hdrs=None,
fp=io.BytesIO(b'{"detail":"Invalid or missing API key"}'),
)
monkeypatch.setattr("urllib.request.urlopen", fake_urlopen)
client = MemoryGatewayClient(PluginConfig(gateway_url="http://gateway"))
result = client.search_memory({"user_id": "u", "query": "demo"})
assert result["ok"] is False
assert result["status_code"] == 401
assert result["endpoint"] == "/v1/memory/search"

View File

@ -0,0 +1,16 @@
from __future__ import annotations
from memory_gateway_plugin.output import debug_raw_enabled, summarize_data
def test_debug_raw_disabled_by_default(monkeypatch):
monkeypatch.delenv("MEMORY_GATEWAY_PLUGIN_DEBUG_RAW", raising=False)
assert debug_raw_enabled() is False
assert summarize_data({"results": [{"memory": {"content": "raw"}}], "total": 1}) == {
"count": 1,
"total": 1,
"local_total": None,
"openviking_total": None,
"searched_namespaces": [],
}

View File

@ -0,0 +1,17 @@
from __future__ import annotations
import importlib.util
import sys
from pathlib import Path
def test_gateway_e2e_script_imports():
path = Path(__file__).resolve().parents[1] / "scripts" / "gateway_e2e_check.py"
spec = importlib.util.spec_from_file_location("gateway_e2e_check", path)
assert spec and spec.loader
module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module)
assert callable(module.run)
assert module.USER_ID == "test_user_memory_gateway_plugin"

View File

@ -0,0 +1,15 @@
from __future__ import annotations
import importlib.util
from pathlib import Path
def test_hermes_hook_probe_script_imports():
path = Path(__file__).resolve().parents[1] / "scripts" / "hermes_hook_probe.py"
spec = importlib.util.spec_from_file_location("hermes_hook_probe", path)
assert spec and spec.loader
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
assert callable(module.run)
assert module.SESSION_ID == "test_session_memory_gateway_plugin_001"

View File

@ -0,0 +1,28 @@
from __future__ import annotations
from test_hermes_register_tools import FakeHermesContext, load_plugin_module
def test_register_registers_expected_hooks():
module = load_plugin_module()
ctx = FakeHermesContext()
module.register(ctx)
assert [item[0] for item in ctx.registered_hooks] == [
"on_session_start",
"pre_llm_call",
"post_llm_call",
"on_session_end",
]
assert all(callable(item[1]) for item in ctx.registered_hooks)
def test_hook_callbacks_accept_kwargs():
module = load_plugin_module()
assert isinstance(module.on_session_start(session_id="s", extra="x"), dict)
assert isinstance(module.pre_llm_call(user_message="", session_id="s", extra="x"), dict)
assert module.post_llm_call(user_message="hi", assistant_response="hello", extra="x") is None
assert module.on_session_end(session_id="s", extra="x") is None

View File

@ -0,0 +1,63 @@
from __future__ import annotations
import importlib.util
import sys
import types
from pathlib import Path
class FakeHermesContext:
def __init__(self) -> None:
self.registered_tools = []
self.registered_hooks = []
def register_tool(self, name, toolset, schema, handler, **kwargs):
self.registered_tools.append((name, toolset, schema, handler, kwargs))
def register_hook(self, hook_name, callback):
self.registered_hooks.append((hook_name, callback))
def load_plugin_module():
plugin_dir = Path(__file__).resolve().parents[1]
if "hermes_plugins" not in sys.modules:
parent = types.ModuleType("hermes_plugins")
parent.__path__ = []
sys.modules["hermes_plugins"] = parent
spec = importlib.util.spec_from_file_location(
"hermes_plugins.memory_gateway_agent_test",
plugin_dir / "__init__.py",
submodule_search_locations=[str(plugin_dir)],
)
module = importlib.util.module_from_spec(spec)
module.__package__ = "hermes_plugins.memory_gateway_agent_test"
module.__path__ = [str(plugin_dir)]
sys.modules["hermes_plugins.memory_gateway_agent_test"] = module
assert spec.loader is not None
spec.loader.exec_module(module)
return module
def test_register_registers_five_tools():
module = load_plugin_module()
ctx = FakeHermesContext()
module.register(ctx)
assert [item[0] for item in ctx.registered_tools] == [
"memory_search",
"memory_append_episode",
"memory_commit_session",
"memory_upsert",
"memory_feedback",
]
assert all(item[1] == "memory_gateway" for item in ctx.registered_tools)
def test_registered_handlers_are_callable():
module = load_plugin_module()
ctx = FakeHermesContext()
module.register(ctx)
assert all(callable(item[3]) for item in ctx.registered_tools)

View File

@ -0,0 +1,36 @@
from __future__ import annotations
from test_hermes_register_tools import load_plugin_module
def test_tool_schemas_exist_for_all_tools():
module = load_plugin_module()
schemas = module.schemas.TOOL_SCHEMAS
assert set(schemas) == {
"memory_search",
"memory_append_episode",
"memory_commit_session",
"memory_upsert",
"memory_feedback",
}
def test_tool_schemas_have_required_fields():
module = load_plugin_module()
schemas = module.schemas.TOOL_SCHEMAS
assert schemas["memory_search"]["parameters"]["required"] == ["query", "user_id", "agent_id"]
assert schemas["memory_append_episode"]["parameters"]["required"] == ["content", "user_id", "agent_id", "session_id"]
assert schemas["memory_commit_session"]["parameters"]["required"] == ["user_id", "agent_id", "session_id"]
assert schemas["memory_upsert"]["parameters"]["required"] == ["user_id", "agent_id", "content", "memory_type"]
assert schemas["memory_feedback"]["parameters"]["required"] == ["memory_id", "user_id", "agent_id", "feedback"]
def test_upsert_schema_warns_high_risk():
module = load_plugin_module()
description = module.schemas.TOOL_SCHEMAS["memory_upsert"]["description"].lower()
assert "high-risk" in description
assert "do not call automatically" in description

View File

@ -0,0 +1,28 @@
from __future__ import annotations
from memory_gateway_plugin.config import PluginConfig
from memory_gateway_plugin.lifecycle import on_session_end
class CountingClient:
def __init__(self) -> None:
self.commit_calls = 0
def commit_session(self, session_id, payload):
self.commit_calls += 1
return {"ok": True, "data": {"session_id": session_id, "payload": payload}}
def test_hook_auto_commit_disabled_by_default():
client = CountingClient()
result = on_session_end(
{"user_id": "u", "agent_id": "a", "session_id": "s"},
client=client,
config=PluginConfig(auto_commit_session=False),
)
assert result["ok"] is True
assert result["committed"] is False
assert result["reason"] == "auto_commit_disabled"
assert client.commit_calls == 0

View File

@ -0,0 +1,35 @@
from __future__ import annotations
from memory_gateway_plugin.config import PluginConfig
from memory_gateway_plugin.lifecycle import after_user_message
class RecordingClient:
def __init__(self) -> None:
self.payloads = []
def append_episode(self, payload):
self.payloads.append(payload)
return {"ok": True, "data": payload}
def test_hook_post_llm_does_not_save_raw_transcript():
client = RecordingClient()
raw_transcript = "user: a\nassistant: b\nuser: c\nassistant: d"
result = after_user_message(
{
"user_id": "u",
"agent_id": "a",
"session_id": "s",
"user_message": raw_transcript,
"assistant_response": "请记住这个完整原始对话。",
},
client=client,
config=PluginConfig(auto_append_episode=True),
)
assert result["ok"] is True
assert result["appended"] is False
assert result["reason"] == "policy_skip"
assert client.payloads == []

View File

@ -0,0 +1,40 @@
from __future__ import annotations
import importlib.util
import sys
import types
from pathlib import Path
def _load_root_plugin():
plugin_dir = Path(__file__).resolve().parents[1]
if "hermes_plugins" not in sys.modules:
parent = types.ModuleType("hermes_plugins")
parent.__path__ = [] # type: ignore[attr-defined]
sys.modules["hermes_plugins"] = parent
module_name = "hermes_plugins.memory_gateway_agent_pre_llm_test"
spec = importlib.util.spec_from_file_location(
module_name,
plugin_dir / "__init__.py",
submodule_search_locations=[str(plugin_dir)],
)
assert spec and spec.loader
module = importlib.util.module_from_spec(spec)
module.__package__ = module_name
module.__path__ = [str(plugin_dir)] # type: ignore[attr-defined]
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module
def test_hook_pre_llm_search_failure_non_blocking(monkeypatch):
plugin = _load_root_plugin()
class FailingLifecycle:
@staticmethod
def on_conversation_start(context):
return {"ok": False, "error": "network_error"}
monkeypatch.setattr(plugin, "lifecycle", FailingLifecycle)
assert plugin.pre_llm_call(user_id="u", agent_id="a", session_id="s", user_message="search") == {}

View File

@ -0,0 +1,9 @@
from __future__ import annotations
from memory_gateway_plugin.trace import trace_enabled
def test_hook_trace_disabled_by_default(monkeypatch):
monkeypatch.delenv("MEMORY_GATEWAY_PLUGIN_TRACE_HOOKS", raising=False)
assert trace_enabled() is False

View File

@ -0,0 +1,18 @@
from __future__ import annotations
from memory_gateway_plugin.trace import trace_hook
def test_hook_trace_does_not_log_api_key(monkeypatch, tmp_path):
import memory_gateway_plugin.trace as trace_mod
path = tmp_path / "hook_trace.log"
monkeypatch.setenv("MEMORY_GATEWAY_PLUGIN_TRACE_HOOKS", "true")
monkeypatch.setenv("MEMORY_GATEWAY_API_KEY", "sk-should-not-appear")
monkeypatch.setattr(trace_mod, "trace_path", lambda: path)
trace_hook("post_llm_call", session_id="s", gateway_action="append_episode", gateway_called=True, ok=True)
text = path.read_text(encoding="utf-8")
assert "sk-should-not-appear" not in text
assert "api_key" not in text.lower()

View File

@ -0,0 +1,18 @@
from __future__ import annotations
from memory_gateway_plugin.trace import trace_hook
def test_hook_trace_redacts_content(monkeypatch, tmp_path):
import memory_gateway_plugin.trace as trace_mod
path = tmp_path / "hook_trace.log"
monkeypatch.setenv("MEMORY_GATEWAY_PLUGIN_TRACE_HOOKS", "true")
monkeypatch.setattr(trace_mod, "trace_path", lambda: path)
trace_hook("pre_llm_call", session_id="test_session_1234567890", gateway_action="memory_search", gateway_called=True, ok=True, reason="password=abc")
text = path.read_text(encoding="utf-8")
assert "password=abc" not in text
assert "pre_llm_call" in text
assert "test_ses" in text

View File

@ -0,0 +1,17 @@
from __future__ import annotations
import importlib.util
import sys
from pathlib import Path
def test_interactive_session_check_imports():
path = Path(__file__).resolve().parents[1] / "scripts" / "hermes_interactive_session_check.py"
spec = importlib.util.spec_from_file_location("hermes_interactive_session_check", path)
assert spec and spec.loader
module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module)
assert callable(module.run)
assert module.SESSION_ID == "test_session_memory_gateway_plugin_interactive_002"

View File

@ -0,0 +1,70 @@
from __future__ import annotations
from memory_gateway_plugin import register
from memory_gateway_plugin.config import PluginConfig
from memory_gateway_plugin.lifecycle import after_user_message, on_conversation_start, on_session_end
class FakeClient:
def search_memory(self, payload):
return {
"ok": True,
"data": {
"results": [
{
"memory": {
"id": "mem_1",
"namespace": "user/u/long_term",
"summary": "用户偏好中文输出。",
}
}
]
},
}
def append_episode(self, payload):
return {"ok": True, "data": payload}
def commit_session(self, session_id, payload):
return {"ok": True, "data": {"session_id": session_id}}
def test_lifecycle_hooks_do_not_crash_when_ctx_missing_features():
result = register(object())
assert result["ok"] is True
assert result["mode"] == "manual"
def test_lifecycle_search_returns_compact_context():
result = on_conversation_start(
{"user_id": "u", "agent_id": "a", "session_id": "s", "user_message": "之前偏好是什么?"},
client=FakeClient(),
config=PluginConfig(auto_search=True),
)
assert result["ok"] is True
assert "用户偏好中文输出" in result["memory_context"]
def test_lifecycle_append_policy_accepts_stable_preference():
result = after_user_message(
{"user_id": "u", "agent_id": "a", "session_id": "s", "user_message": "请记住:我偏好中文。"},
client=FakeClient(),
config=PluginConfig(auto_append_episode=True),
)
assert result["ok"] is True
assert result["appended"] is True
def test_lifecycle_session_end_auto_commit_disabled():
result = on_session_end(
{"user_id": "u", "agent_id": "a", "session_id": "s"},
client=FakeClient(),
config=PluginConfig(auto_commit_session=False),
)
assert result["ok"] is True
assert result["committed"] is False

View File

@ -0,0 +1,24 @@
from __future__ import annotations
from memory_gateway_plugin.output import dumps_safe, redact, short_id
def test_output_redaction_hides_secret_fields():
payload = {
"api_key": "sk-test",
"headers": {"Authorization": "Bearer abc"},
"nested": {"cookie": "sid=abc"},
"safe": "value",
}
text = dumps_safe(payload)
assert "sk-test" not in text
assert "Bearer abc" not in text
assert "sid=abc" not in text
assert "value" in text
assert redact("password=abc") == "<redacted>"
def test_output_redaction_shortens_memory_ids():
assert short_id("mem_1234567890abcdef") == "mem_1234...cdef"

View File

@ -0,0 +1,22 @@
from __future__ import annotations
from memory_gateway_plugin.config import PluginConfig
from memory_gateway_plugin.policy import should_append_episode, should_commit_session, should_search_memory
def test_policy_should_append_for_explicit_remember():
assert should_append_episode("请记住:我偏好中文技术说明。", "", {}, PluginConfig())
def test_policy_should_not_append_for_small_talk():
assert not should_append_episode("你好", "", {}, PluginConfig())
def test_policy_should_search_when_enabled():
assert should_search_memory("这个项目之前有什么约束?", {}, PluginConfig(auto_search=True))
def test_policy_should_commit_only_when_enabled_or_forced():
assert not should_commit_session({}, PluginConfig(auto_commit_session=False))
assert should_commit_session({"force_commit": True}, PluginConfig(auto_commit_session=False))

View File

@ -0,0 +1,24 @@
from __future__ import annotations
from memory_gateway_plugin.safety import detect_large_log, sanitize_memory_content, validate_memory_write
def test_safety_rejects_private_key():
content = "-----BEGIN PRIVATE KEY-----\nabc\n-----END PRIVATE KEY-----"
result = validate_memory_write(content)
assert result["allowed"] is False
assert result["reason"] == "secret_like_content"
def test_safety_rejects_large_log():
content = "\n".join(f"2026-05-06 10:00:{i:02d} ERROR failure" for i in range(10))
blocked, reason = detect_large_log(content)
assert blocked is True
assert reason == "large_or_raw_log"
def test_safety_sanitizes_secret_when_called_directly():
assert "sk-test" not in sanitize_memory_content("api_key=sk-test")

View File

@ -0,0 +1,106 @@
from __future__ import annotations
from memory_gateway_plugin.tools import memory_append_episode, memory_feedback, memory_search, memory_upsert
class FakeClient:
def __init__(self):
self.calls = []
def search_memory(self, payload):
self.calls.append(("search", payload))
return {"ok": True, "data": {"results": []}}
def append_episode(self, payload):
self.calls.append(("append", payload))
return {"ok": True, "data": payload}
def upsert_memory(self, payload):
self.calls.append(("upsert", payload))
return {"ok": True, "data": payload}
def send_feedback(self, memory_id, payload):
self.calls.append(("feedback", memory_id, payload))
return {"ok": True, "data": payload}
def test_memory_search_empty_query_rejected():
client = FakeClient()
result = memory_search(query="", user_id="u", agent_id="a", client=client)
assert result["ok"] is False
assert client.calls == []
def test_append_episode_rejects_api_key():
result = memory_append_episode(
user_id="u",
agent_id="a",
session_id="s",
episode_summary="api_key=sk-secret",
client=FakeClient(),
)
assert result["ok"] is False
assert result["reason"] == "secret_like_content"
def test_append_episode_rejects_password():
result = memory_append_episode(
user_id="u",
agent_id="a",
session_id="s",
episode_summary="password=hunter2",
client=FakeClient(),
)
assert result["ok"] is False
assert result["reason"] == "secret_like_content"
def test_append_episode_rejects_raw_transcript():
content = "\n".join(["User: hi", "Assistant: hello", "User: remember this", "Assistant: ok"])
result = memory_append_episode(user_id="u", agent_id="a", session_id="s", episode_summary=content, client=FakeClient())
assert result["ok"] is False
assert result["reason"] == "raw_chat_transcript"
def test_append_episode_accepts_stable_preference():
client = FakeClient()
result = memory_append_episode(
user_id="u",
agent_id="a",
session_id="s",
episode_summary="用户稳定偏好:以后所有技术方案都使用中文输出。",
tags=["preference"],
client=client,
)
assert result["ok"] is True
assert client.calls[0][0] == "append"
def test_upsert_uses_long_term_namespace_when_provided():
client = FakeClient()
namespace = "user/u/long_term"
result = memory_upsert(
user_id="u",
agent_id="a",
namespace=namespace,
memory_type="preference",
content="用户稳定偏好:使用中文输出。",
client=client,
)
assert result["ok"] is True
assert client.calls[0][1]["namespace"] == namespace
def test_feedback_calls_correct_endpoint():
client = FakeClient()
result = memory_feedback(user_id="u", agent_id="a", memory_id="mem_1", feedback="reject", client=client)
assert result["ok"] is True
assert client.calls[0] == ("feedback", "mem_1", {"user_id": "u", "agent_id": "a", "feedback": "incorrect", "comment": None})