Add Memory Gateway agent plugin
This commit is contained in:
@ -0,0 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _load_cleanup():
|
||||
path = Path(__file__).resolve().parents[1] / "scripts" / "cleanup_test_memories.py"
|
||||
spec = importlib.util.spec_from_file_location("cleanup_test_memories_guard_test", path)
|
||||
assert spec and spec.loader
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[spec.name] = module
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
def test_cleanup_requires_test_user():
|
||||
module = _load_cleanup()
|
||||
|
||||
try:
|
||||
module.run("real_user")
|
||||
except ValueError as exc:
|
||||
assert "cleanup_refuses_non_test_user" in str(exc)
|
||||
else:
|
||||
raise AssertionError("cleanup accepted a non-test user")
|
||||
@ -0,0 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def test_cleanup_test_memories_imports():
|
||||
path = Path(__file__).resolve().parents[1] / "scripts" / "cleanup_test_memories.py"
|
||||
spec = importlib.util.spec_from_file_location("cleanup_test_memories", path)
|
||||
assert spec and spec.loader
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[spec.name] = module
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
assert callable(module.run)
|
||||
assert module.USER_ID.startswith("test_user_")
|
||||
105
plugins/memory-gateway-agent/tests/test_client.py
Normal file
105
plugins/memory-gateway-agent/tests/test_client.py
Normal file
@ -0,0 +1,105 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import json
|
||||
import urllib.error
|
||||
|
||||
from memory_gateway_plugin.client import MemoryGatewayClient
|
||||
from memory_gateway_plugin.config import PluginConfig
|
||||
|
||||
|
||||
class FakeResponse:
|
||||
status = 200
|
||||
|
||||
def __init__(self, payload):
|
||||
self.payload = payload
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
return False
|
||||
|
||||
def read(self):
|
||||
return json.dumps(self.payload).encode("utf-8")
|
||||
|
||||
|
||||
def test_client_search_success(monkeypatch):
|
||||
seen = {}
|
||||
|
||||
def fake_urlopen(request, timeout):
|
||||
seen["url"] = request.full_url
|
||||
seen["timeout"] = timeout
|
||||
return FakeResponse({"results": [], "total": 0})
|
||||
|
||||
monkeypatch.setattr("urllib.request.urlopen", fake_urlopen)
|
||||
client = MemoryGatewayClient(PluginConfig(gateway_url="http://gateway", timeout=7))
|
||||
result = client.search_memory({"user_id": "u", "query": "demo"})
|
||||
|
||||
assert result["ok"] is True
|
||||
assert result["endpoint"] == "/v1/memory/search"
|
||||
assert seen["url"] == "http://gateway/v1/memory/search"
|
||||
assert seen["timeout"] == 7
|
||||
|
||||
|
||||
def test_client_network_error(monkeypatch):
|
||||
def fake_urlopen(request, timeout):
|
||||
raise urllib.error.URLError("connection refused")
|
||||
|
||||
monkeypatch.setattr("urllib.request.urlopen", fake_urlopen)
|
||||
client = MemoryGatewayClient(PluginConfig(gateway_url="http://gateway"))
|
||||
result = client.search_memory({"user_id": "u", "query": "demo"})
|
||||
|
||||
assert result["ok"] is False
|
||||
assert result["status_code"] is None
|
||||
assert "connection refused" in result["error"]
|
||||
|
||||
|
||||
def test_commit_session_calls_correct_endpoint(monkeypatch):
|
||||
seen = {}
|
||||
|
||||
def fake_urlopen(request, timeout):
|
||||
seen["url"] = request.full_url
|
||||
return FakeResponse({"session_id": "sess_1"})
|
||||
|
||||
monkeypatch.setattr("urllib.request.urlopen", fake_urlopen)
|
||||
client = MemoryGatewayClient(PluginConfig(gateway_url="http://gateway"))
|
||||
result = client.commit_session("sess_1", {"user_id": "u", "session_id": "sess_1"})
|
||||
|
||||
assert result["ok"] is True
|
||||
assert seen["url"] == "http://gateway/v1/sessions/sess_1/commit"
|
||||
|
||||
|
||||
def test_feedback_calls_correct_endpoint(monkeypatch):
|
||||
seen = {}
|
||||
|
||||
def fake_urlopen(request, timeout):
|
||||
seen["url"] = request.full_url
|
||||
return FakeResponse({"status": "ok"})
|
||||
|
||||
monkeypatch.setattr("urllib.request.urlopen", fake_urlopen)
|
||||
client = MemoryGatewayClient(PluginConfig(gateway_url="http://gateway"))
|
||||
result = client.send_feedback("mem_1", {"user_id": "u", "feedback": "incorrect"})
|
||||
|
||||
assert result["ok"] is True
|
||||
assert seen["url"] == "http://gateway/v1/memory/mem_1/feedback"
|
||||
|
||||
|
||||
def test_client_http_error(monkeypatch):
|
||||
def fake_urlopen(request, timeout):
|
||||
raise urllib.error.HTTPError(
|
||||
url=request.full_url,
|
||||
code=401,
|
||||
msg="unauthorized",
|
||||
hdrs=None,
|
||||
fp=io.BytesIO(b'{"detail":"Invalid or missing API key"}'),
|
||||
)
|
||||
|
||||
monkeypatch.setattr("urllib.request.urlopen", fake_urlopen)
|
||||
client = MemoryGatewayClient(PluginConfig(gateway_url="http://gateway"))
|
||||
result = client.search_memory({"user_id": "u", "query": "demo"})
|
||||
|
||||
assert result["ok"] is False
|
||||
assert result["status_code"] == 401
|
||||
assert result["endpoint"] == "/v1/memory/search"
|
||||
|
||||
@ -0,0 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from memory_gateway_plugin.output import debug_raw_enabled, summarize_data
|
||||
|
||||
|
||||
def test_debug_raw_disabled_by_default(monkeypatch):
|
||||
monkeypatch.delenv("MEMORY_GATEWAY_PLUGIN_DEBUG_RAW", raising=False)
|
||||
|
||||
assert debug_raw_enabled() is False
|
||||
assert summarize_data({"results": [{"memory": {"content": "raw"}}], "total": 1}) == {
|
||||
"count": 1,
|
||||
"total": 1,
|
||||
"local_total": None,
|
||||
"openviking_total": None,
|
||||
"searched_namespaces": [],
|
||||
}
|
||||
@ -0,0 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def test_gateway_e2e_script_imports():
|
||||
path = Path(__file__).resolve().parents[1] / "scripts" / "gateway_e2e_check.py"
|
||||
spec = importlib.util.spec_from_file_location("gateway_e2e_check", path)
|
||||
assert spec and spec.loader
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[spec.name] = module
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
assert callable(module.run)
|
||||
assert module.USER_ID == "test_user_memory_gateway_plugin"
|
||||
@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def test_hermes_hook_probe_script_imports():
|
||||
path = Path(__file__).resolve().parents[1] / "scripts" / "hermes_hook_probe.py"
|
||||
spec = importlib.util.spec_from_file_location("hermes_hook_probe", path)
|
||||
assert spec and spec.loader
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
assert callable(module.run)
|
||||
assert module.SESSION_ID == "test_session_memory_gateway_plugin_001"
|
||||
@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from test_hermes_register_tools import FakeHermesContext, load_plugin_module
|
||||
|
||||
|
||||
def test_register_registers_expected_hooks():
|
||||
module = load_plugin_module()
|
||||
ctx = FakeHermesContext()
|
||||
|
||||
module.register(ctx)
|
||||
|
||||
assert [item[0] for item in ctx.registered_hooks] == [
|
||||
"on_session_start",
|
||||
"pre_llm_call",
|
||||
"post_llm_call",
|
||||
"on_session_end",
|
||||
]
|
||||
assert all(callable(item[1]) for item in ctx.registered_hooks)
|
||||
|
||||
|
||||
def test_hook_callbacks_accept_kwargs():
|
||||
module = load_plugin_module()
|
||||
|
||||
assert isinstance(module.on_session_start(session_id="s", extra="x"), dict)
|
||||
assert isinstance(module.pre_llm_call(user_message="", session_id="s", extra="x"), dict)
|
||||
assert module.post_llm_call(user_message="hi", assistant_response="hello", extra="x") is None
|
||||
assert module.on_session_end(session_id="s", extra="x") is None
|
||||
|
||||
@ -0,0 +1,63 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class FakeHermesContext:
|
||||
def __init__(self) -> None:
|
||||
self.registered_tools = []
|
||||
self.registered_hooks = []
|
||||
|
||||
def register_tool(self, name, toolset, schema, handler, **kwargs):
|
||||
self.registered_tools.append((name, toolset, schema, handler, kwargs))
|
||||
|
||||
def register_hook(self, hook_name, callback):
|
||||
self.registered_hooks.append((hook_name, callback))
|
||||
|
||||
|
||||
def load_plugin_module():
|
||||
plugin_dir = Path(__file__).resolve().parents[1]
|
||||
if "hermes_plugins" not in sys.modules:
|
||||
parent = types.ModuleType("hermes_plugins")
|
||||
parent.__path__ = []
|
||||
sys.modules["hermes_plugins"] = parent
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"hermes_plugins.memory_gateway_agent_test",
|
||||
plugin_dir / "__init__.py",
|
||||
submodule_search_locations=[str(plugin_dir)],
|
||||
)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
module.__package__ = "hermes_plugins.memory_gateway_agent_test"
|
||||
module.__path__ = [str(plugin_dir)]
|
||||
sys.modules["hermes_plugins.memory_gateway_agent_test"] = module
|
||||
assert spec.loader is not None
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
def test_register_registers_five_tools():
|
||||
module = load_plugin_module()
|
||||
ctx = FakeHermesContext()
|
||||
|
||||
module.register(ctx)
|
||||
|
||||
assert [item[0] for item in ctx.registered_tools] == [
|
||||
"memory_search",
|
||||
"memory_append_episode",
|
||||
"memory_commit_session",
|
||||
"memory_upsert",
|
||||
"memory_feedback",
|
||||
]
|
||||
assert all(item[1] == "memory_gateway" for item in ctx.registered_tools)
|
||||
|
||||
|
||||
def test_registered_handlers_are_callable():
|
||||
module = load_plugin_module()
|
||||
ctx = FakeHermesContext()
|
||||
|
||||
module.register(ctx)
|
||||
|
||||
assert all(callable(item[3]) for item in ctx.registered_tools)
|
||||
36
plugins/memory-gateway-agent/tests/test_hermes_schemas.py
Normal file
36
plugins/memory-gateway-agent/tests/test_hermes_schemas.py
Normal file
@ -0,0 +1,36 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from test_hermes_register_tools import load_plugin_module
|
||||
|
||||
|
||||
def test_tool_schemas_exist_for_all_tools():
|
||||
module = load_plugin_module()
|
||||
schemas = module.schemas.TOOL_SCHEMAS
|
||||
|
||||
assert set(schemas) == {
|
||||
"memory_search",
|
||||
"memory_append_episode",
|
||||
"memory_commit_session",
|
||||
"memory_upsert",
|
||||
"memory_feedback",
|
||||
}
|
||||
|
||||
|
||||
def test_tool_schemas_have_required_fields():
|
||||
module = load_plugin_module()
|
||||
schemas = module.schemas.TOOL_SCHEMAS
|
||||
|
||||
assert schemas["memory_search"]["parameters"]["required"] == ["query", "user_id", "agent_id"]
|
||||
assert schemas["memory_append_episode"]["parameters"]["required"] == ["content", "user_id", "agent_id", "session_id"]
|
||||
assert schemas["memory_commit_session"]["parameters"]["required"] == ["user_id", "agent_id", "session_id"]
|
||||
assert schemas["memory_upsert"]["parameters"]["required"] == ["user_id", "agent_id", "content", "memory_type"]
|
||||
assert schemas["memory_feedback"]["parameters"]["required"] == ["memory_id", "user_id", "agent_id", "feedback"]
|
||||
|
||||
|
||||
def test_upsert_schema_warns_high_risk():
|
||||
module = load_plugin_module()
|
||||
|
||||
description = module.schemas.TOOL_SCHEMAS["memory_upsert"]["description"].lower()
|
||||
assert "high-risk" in description
|
||||
assert "do not call automatically" in description
|
||||
|
||||
@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from memory_gateway_plugin.config import PluginConfig
|
||||
from memory_gateway_plugin.lifecycle import on_session_end
|
||||
|
||||
|
||||
class CountingClient:
|
||||
def __init__(self) -> None:
|
||||
self.commit_calls = 0
|
||||
|
||||
def commit_session(self, session_id, payload):
|
||||
self.commit_calls += 1
|
||||
return {"ok": True, "data": {"session_id": session_id, "payload": payload}}
|
||||
|
||||
|
||||
def test_hook_auto_commit_disabled_by_default():
|
||||
client = CountingClient()
|
||||
|
||||
result = on_session_end(
|
||||
{"user_id": "u", "agent_id": "a", "session_id": "s"},
|
||||
client=client,
|
||||
config=PluginConfig(auto_commit_session=False),
|
||||
)
|
||||
|
||||
assert result["ok"] is True
|
||||
assert result["committed"] is False
|
||||
assert result["reason"] == "auto_commit_disabled"
|
||||
assert client.commit_calls == 0
|
||||
@ -0,0 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from memory_gateway_plugin.config import PluginConfig
|
||||
from memory_gateway_plugin.lifecycle import after_user_message
|
||||
|
||||
|
||||
class RecordingClient:
|
||||
def __init__(self) -> None:
|
||||
self.payloads = []
|
||||
|
||||
def append_episode(self, payload):
|
||||
self.payloads.append(payload)
|
||||
return {"ok": True, "data": payload}
|
||||
|
||||
|
||||
def test_hook_post_llm_does_not_save_raw_transcript():
|
||||
client = RecordingClient()
|
||||
raw_transcript = "user: a\nassistant: b\nuser: c\nassistant: d"
|
||||
|
||||
result = after_user_message(
|
||||
{
|
||||
"user_id": "u",
|
||||
"agent_id": "a",
|
||||
"session_id": "s",
|
||||
"user_message": raw_transcript,
|
||||
"assistant_response": "请记住这个完整原始对话。",
|
||||
},
|
||||
client=client,
|
||||
config=PluginConfig(auto_append_episode=True),
|
||||
)
|
||||
|
||||
assert result["ok"] is True
|
||||
assert result["appended"] is False
|
||||
assert result["reason"] == "policy_skip"
|
||||
assert client.payloads == []
|
||||
@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _load_root_plugin():
|
||||
plugin_dir = Path(__file__).resolve().parents[1]
|
||||
if "hermes_plugins" not in sys.modules:
|
||||
parent = types.ModuleType("hermes_plugins")
|
||||
parent.__path__ = [] # type: ignore[attr-defined]
|
||||
sys.modules["hermes_plugins"] = parent
|
||||
module_name = "hermes_plugins.memory_gateway_agent_pre_llm_test"
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
module_name,
|
||||
plugin_dir / "__init__.py",
|
||||
submodule_search_locations=[str(plugin_dir)],
|
||||
)
|
||||
assert spec and spec.loader
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
module.__package__ = module_name
|
||||
module.__path__ = [str(plugin_dir)] # type: ignore[attr-defined]
|
||||
sys.modules[module_name] = module
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
def test_hook_pre_llm_search_failure_non_blocking(monkeypatch):
|
||||
plugin = _load_root_plugin()
|
||||
|
||||
class FailingLifecycle:
|
||||
@staticmethod
|
||||
def on_conversation_start(context):
|
||||
return {"ok": False, "error": "network_error"}
|
||||
|
||||
monkeypatch.setattr(plugin, "lifecycle", FailingLifecycle)
|
||||
|
||||
assert plugin.pre_llm_call(user_id="u", agent_id="a", session_id="s", user_message="search") == {}
|
||||
@ -0,0 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from memory_gateway_plugin.trace import trace_enabled
|
||||
|
||||
|
||||
def test_hook_trace_disabled_by_default(monkeypatch):
|
||||
monkeypatch.delenv("MEMORY_GATEWAY_PLUGIN_TRACE_HOOKS", raising=False)
|
||||
|
||||
assert trace_enabled() is False
|
||||
@ -0,0 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from memory_gateway_plugin.trace import trace_hook
|
||||
|
||||
|
||||
def test_hook_trace_does_not_log_api_key(monkeypatch, tmp_path):
|
||||
import memory_gateway_plugin.trace as trace_mod
|
||||
|
||||
path = tmp_path / "hook_trace.log"
|
||||
monkeypatch.setenv("MEMORY_GATEWAY_PLUGIN_TRACE_HOOKS", "true")
|
||||
monkeypatch.setenv("MEMORY_GATEWAY_API_KEY", "sk-should-not-appear")
|
||||
monkeypatch.setattr(trace_mod, "trace_path", lambda: path)
|
||||
|
||||
trace_hook("post_llm_call", session_id="s", gateway_action="append_episode", gateway_called=True, ok=True)
|
||||
|
||||
text = path.read_text(encoding="utf-8")
|
||||
assert "sk-should-not-appear" not in text
|
||||
assert "api_key" not in text.lower()
|
||||
@ -0,0 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from memory_gateway_plugin.trace import trace_hook
|
||||
|
||||
|
||||
def test_hook_trace_redacts_content(monkeypatch, tmp_path):
|
||||
import memory_gateway_plugin.trace as trace_mod
|
||||
|
||||
path = tmp_path / "hook_trace.log"
|
||||
monkeypatch.setenv("MEMORY_GATEWAY_PLUGIN_TRACE_HOOKS", "true")
|
||||
monkeypatch.setattr(trace_mod, "trace_path", lambda: path)
|
||||
|
||||
trace_hook("pre_llm_call", session_id="test_session_1234567890", gateway_action="memory_search", gateway_called=True, ok=True, reason="password=abc")
|
||||
|
||||
text = path.read_text(encoding="utf-8")
|
||||
assert "password=abc" not in text
|
||||
assert "pre_llm_call" in text
|
||||
assert "test_ses" in text
|
||||
@ -0,0 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def test_interactive_session_check_imports():
|
||||
path = Path(__file__).resolve().parents[1] / "scripts" / "hermes_interactive_session_check.py"
|
||||
spec = importlib.util.spec_from_file_location("hermes_interactive_session_check", path)
|
||||
assert spec and spec.loader
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[spec.name] = module
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
assert callable(module.run)
|
||||
assert module.SESSION_ID == "test_session_memory_gateway_plugin_interactive_002"
|
||||
70
plugins/memory-gateway-agent/tests/test_lifecycle.py
Normal file
70
plugins/memory-gateway-agent/tests/test_lifecycle.py
Normal file
@ -0,0 +1,70 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from memory_gateway_plugin import register
|
||||
from memory_gateway_plugin.config import PluginConfig
|
||||
from memory_gateway_plugin.lifecycle import after_user_message, on_conversation_start, on_session_end
|
||||
|
||||
|
||||
class FakeClient:
|
||||
def search_memory(self, payload):
|
||||
return {
|
||||
"ok": True,
|
||||
"data": {
|
||||
"results": [
|
||||
{
|
||||
"memory": {
|
||||
"id": "mem_1",
|
||||
"namespace": "user/u/long_term",
|
||||
"summary": "用户偏好中文输出。",
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
def append_episode(self, payload):
|
||||
return {"ok": True, "data": payload}
|
||||
|
||||
def commit_session(self, session_id, payload):
|
||||
return {"ok": True, "data": {"session_id": session_id}}
|
||||
|
||||
|
||||
def test_lifecycle_hooks_do_not_crash_when_ctx_missing_features():
|
||||
result = register(object())
|
||||
|
||||
assert result["ok"] is True
|
||||
assert result["mode"] == "manual"
|
||||
|
||||
|
||||
def test_lifecycle_search_returns_compact_context():
|
||||
result = on_conversation_start(
|
||||
{"user_id": "u", "agent_id": "a", "session_id": "s", "user_message": "之前偏好是什么?"},
|
||||
client=FakeClient(),
|
||||
config=PluginConfig(auto_search=True),
|
||||
)
|
||||
|
||||
assert result["ok"] is True
|
||||
assert "用户偏好中文输出" in result["memory_context"]
|
||||
|
||||
|
||||
def test_lifecycle_append_policy_accepts_stable_preference():
|
||||
result = after_user_message(
|
||||
{"user_id": "u", "agent_id": "a", "session_id": "s", "user_message": "请记住:我偏好中文。"},
|
||||
client=FakeClient(),
|
||||
config=PluginConfig(auto_append_episode=True),
|
||||
)
|
||||
|
||||
assert result["ok"] is True
|
||||
assert result["appended"] is True
|
||||
|
||||
|
||||
def test_lifecycle_session_end_auto_commit_disabled():
|
||||
result = on_session_end(
|
||||
{"user_id": "u", "agent_id": "a", "session_id": "s"},
|
||||
client=FakeClient(),
|
||||
config=PluginConfig(auto_commit_session=False),
|
||||
)
|
||||
|
||||
assert result["ok"] is True
|
||||
assert result["committed"] is False
|
||||
|
||||
24
plugins/memory-gateway-agent/tests/test_output_redaction.py
Normal file
24
plugins/memory-gateway-agent/tests/test_output_redaction.py
Normal file
@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from memory_gateway_plugin.output import dumps_safe, redact, short_id
|
||||
|
||||
|
||||
def test_output_redaction_hides_secret_fields():
|
||||
payload = {
|
||||
"api_key": "sk-test",
|
||||
"headers": {"Authorization": "Bearer abc"},
|
||||
"nested": {"cookie": "sid=abc"},
|
||||
"safe": "value",
|
||||
}
|
||||
|
||||
text = dumps_safe(payload)
|
||||
|
||||
assert "sk-test" not in text
|
||||
assert "Bearer abc" not in text
|
||||
assert "sid=abc" not in text
|
||||
assert "value" in text
|
||||
assert redact("password=abc") == "<redacted>"
|
||||
|
||||
|
||||
def test_output_redaction_shortens_memory_ids():
|
||||
assert short_id("mem_1234567890abcdef") == "mem_1234...cdef"
|
||||
22
plugins/memory-gateway-agent/tests/test_policy.py
Normal file
22
plugins/memory-gateway-agent/tests/test_policy.py
Normal file
@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from memory_gateway_plugin.config import PluginConfig
|
||||
from memory_gateway_plugin.policy import should_append_episode, should_commit_session, should_search_memory
|
||||
|
||||
|
||||
def test_policy_should_append_for_explicit_remember():
|
||||
assert should_append_episode("请记住:我偏好中文技术说明。", "", {}, PluginConfig())
|
||||
|
||||
|
||||
def test_policy_should_not_append_for_small_talk():
|
||||
assert not should_append_episode("你好", "", {}, PluginConfig())
|
||||
|
||||
|
||||
def test_policy_should_search_when_enabled():
|
||||
assert should_search_memory("这个项目之前有什么约束?", {}, PluginConfig(auto_search=True))
|
||||
|
||||
|
||||
def test_policy_should_commit_only_when_enabled_or_forced():
|
||||
assert not should_commit_session({}, PluginConfig(auto_commit_session=False))
|
||||
assert should_commit_session({"force_commit": True}, PluginConfig(auto_commit_session=False))
|
||||
|
||||
24
plugins/memory-gateway-agent/tests/test_safety.py
Normal file
24
plugins/memory-gateway-agent/tests/test_safety.py
Normal file
@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from memory_gateway_plugin.safety import detect_large_log, sanitize_memory_content, validate_memory_write
|
||||
|
||||
|
||||
def test_safety_rejects_private_key():
|
||||
content = "-----BEGIN PRIVATE KEY-----\nabc\n-----END PRIVATE KEY-----"
|
||||
result = validate_memory_write(content)
|
||||
|
||||
assert result["allowed"] is False
|
||||
assert result["reason"] == "secret_like_content"
|
||||
|
||||
|
||||
def test_safety_rejects_large_log():
|
||||
content = "\n".join(f"2026-05-06 10:00:{i:02d} ERROR failure" for i in range(10))
|
||||
blocked, reason = detect_large_log(content)
|
||||
|
||||
assert blocked is True
|
||||
assert reason == "large_or_raw_log"
|
||||
|
||||
|
||||
def test_safety_sanitizes_secret_when_called_directly():
|
||||
assert "sk-test" not in sanitize_memory_content("api_key=sk-test")
|
||||
|
||||
106
plugins/memory-gateway-agent/tests/test_tools.py
Normal file
106
plugins/memory-gateway-agent/tests/test_tools.py
Normal file
@ -0,0 +1,106 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from memory_gateway_plugin.tools import memory_append_episode, memory_feedback, memory_search, memory_upsert
|
||||
|
||||
|
||||
class FakeClient:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
def search_memory(self, payload):
|
||||
self.calls.append(("search", payload))
|
||||
return {"ok": True, "data": {"results": []}}
|
||||
|
||||
def append_episode(self, payload):
|
||||
self.calls.append(("append", payload))
|
||||
return {"ok": True, "data": payload}
|
||||
|
||||
def upsert_memory(self, payload):
|
||||
self.calls.append(("upsert", payload))
|
||||
return {"ok": True, "data": payload}
|
||||
|
||||
def send_feedback(self, memory_id, payload):
|
||||
self.calls.append(("feedback", memory_id, payload))
|
||||
return {"ok": True, "data": payload}
|
||||
|
||||
|
||||
def test_memory_search_empty_query_rejected():
|
||||
client = FakeClient()
|
||||
result = memory_search(query="", user_id="u", agent_id="a", client=client)
|
||||
|
||||
assert result["ok"] is False
|
||||
assert client.calls == []
|
||||
|
||||
|
||||
def test_append_episode_rejects_api_key():
|
||||
result = memory_append_episode(
|
||||
user_id="u",
|
||||
agent_id="a",
|
||||
session_id="s",
|
||||
episode_summary="api_key=sk-secret",
|
||||
client=FakeClient(),
|
||||
)
|
||||
|
||||
assert result["ok"] is False
|
||||
assert result["reason"] == "secret_like_content"
|
||||
|
||||
|
||||
def test_append_episode_rejects_password():
|
||||
result = memory_append_episode(
|
||||
user_id="u",
|
||||
agent_id="a",
|
||||
session_id="s",
|
||||
episode_summary="password=hunter2",
|
||||
client=FakeClient(),
|
||||
)
|
||||
|
||||
assert result["ok"] is False
|
||||
assert result["reason"] == "secret_like_content"
|
||||
|
||||
|
||||
def test_append_episode_rejects_raw_transcript():
|
||||
content = "\n".join(["User: hi", "Assistant: hello", "User: remember this", "Assistant: ok"])
|
||||
result = memory_append_episode(user_id="u", agent_id="a", session_id="s", episode_summary=content, client=FakeClient())
|
||||
|
||||
assert result["ok"] is False
|
||||
assert result["reason"] == "raw_chat_transcript"
|
||||
|
||||
|
||||
def test_append_episode_accepts_stable_preference():
|
||||
client = FakeClient()
|
||||
result = memory_append_episode(
|
||||
user_id="u",
|
||||
agent_id="a",
|
||||
session_id="s",
|
||||
episode_summary="用户稳定偏好:以后所有技术方案都使用中文输出。",
|
||||
tags=["preference"],
|
||||
client=client,
|
||||
)
|
||||
|
||||
assert result["ok"] is True
|
||||
assert client.calls[0][0] == "append"
|
||||
|
||||
|
||||
def test_upsert_uses_long_term_namespace_when_provided():
|
||||
client = FakeClient()
|
||||
namespace = "user/u/long_term"
|
||||
result = memory_upsert(
|
||||
user_id="u",
|
||||
agent_id="a",
|
||||
namespace=namespace,
|
||||
memory_type="preference",
|
||||
content="用户稳定偏好:使用中文输出。",
|
||||
client=client,
|
||||
)
|
||||
|
||||
assert result["ok"] is True
|
||||
assert client.calls[0][1]["namespace"] == namespace
|
||||
|
||||
|
||||
def test_feedback_calls_correct_endpoint():
|
||||
client = FakeClient()
|
||||
result = memory_feedback(user_id="u", agent_id="a", memory_id="mem_1", feedback="reject", client=client)
|
||||
|
||||
assert result["ok"] is True
|
||||
assert client.calls[0] == ("feedback", "mem_1", {"user_id": "u", "agent_id": "a", "feedback": "incorrect", "comment": None})
|
||||
|
||||
Reference in New Issue
Block a user