"""Tests for the memory_gateway FastAPI app, run against a fake OpenViking client."""

import sys
import types

from fastapi.responses import StreamingResponse
from fastapi.testclient import TestClient


def install_test_stubs() -> None:
    """Register minimal stand-ins for ``mcp`` and ``sse_starlette`` in ``sys.modules``.

    This runs before ``memory_gateway.server`` is imported below, so the server
    module picks up these stubs. A stub is only installed when the
    corresponding module has not already been imported.
    """
    if "mcp.server" not in sys.modules:
        mcp_module = types.ModuleType("mcp")
        # Mark the stub as a package so submodule imports resolve like a real one.
        mcp_module.__path__ = []
        mcp_server_module = types.ModuleType("mcp.server")
        mcp_types_module = types.ModuleType("mcp.types")

        class Server:
            """Stub MCP server whose decorators register nothing."""

            def __init__(self, name):
                self.name = name

            def list_tools(self):
                def decorator(func):
                    return func

                return decorator

            def call_tool(self):
                def decorator(func):
                    return func

                return decorator

        class Tool:
            """Plain attribute holder mirroring the fields read from ``mcp.types.Tool``."""

            def __init__(self, name, description, inputSchema):
                self.name = name
                self.description = description
                self.inputSchema = inputSchema

        class TextContent:
            """Plain attribute holder for ``mcp.types.TextContent`` with a ``model_dump``."""

            def __init__(self, type, text):
                self.type = type
                self.text = text

            def model_dump(self):
                return {"type": self.type, "text": self.text}

        mcp_server_module.Server = Server
        mcp_types_module.Tool = Tool
        mcp_types_module.TextContent = TextContent
        # Expose submodules as attributes of the parent too, so both
        # ``from mcp import types`` and ``mcp.server.Server`` work.
        mcp_module.server = mcp_server_module
        mcp_module.types = mcp_types_module
        sys.modules["mcp"] = mcp_module
        sys.modules["mcp.server"] = mcp_server_module
        sys.modules["mcp.types"] = mcp_types_module

    if "sse_starlette" not in sys.modules:
        sse_module = types.ModuleType("sse_starlette")

        # Keyword options the real EventSourceResponse accepts but
        # StreamingResponse does not; dropped so callers using them don't fail.
        sse_only_options = (
            "ping",
            "sep",
            "ping_message_factory",
            "data_sender_callable",
            "send_timeout",
        )

        class EventSourceResponse(StreamingResponse):
            """StreamingResponse that defaults to the ``text/event-stream`` media type."""

            def __init__(self, content, *args, **kwargs):
                for option in sse_only_options:
                    kwargs.pop(option, None)
                # setdefault instead of an explicit keyword avoids a duplicate
                # ``media_type`` TypeError when the caller passes one.
                kwargs.setdefault("media_type", "text/event-stream")
                super().__init__(content, *args, **kwargs)

        sse_module.EventSourceResponse = EventSourceResponse
        sys.modules["sse_starlette"] = sse_module


install_test_stubs()

# These imports must follow install_test_stubs() so the server sees the stubs.
from memory_gateway.server import app  # noqa: E402
from memory_gateway.types import Config, SearchResult, ServerConfig  # noqa: E402


class FakeOVClient:
    """In-memory stand-in for the client returned by ``get_openviking_client``."""

    async def health_check(self):
        return {"status": "ok", "backend": "fake"}

    async def search(self, query, namespace=None, limit=None, uri=None):
        # Echo the query back as the abstract so tests can check pass-through.
        return SearchResult(
            results=[
                {
                    "uri": "viking://soc/test",
                    "abstract": query,
                    "score": 1.0,
                    "context_type": "memory",
                }
            ],
            total=1,
        )

    async def add_memory(self, content, namespace=None, memory_type="general"):
        return {
            "status": "ok",
            "content": content,
            "namespace": namespace,
            "memory_type": memory_type,
        }

    async def add_resource(self, uri, content, resource_type="text"):
        return {
            "status": "ok",
            "uri": uri,
            "content": content,
            "resource_type": resource_type,
        }

    async def list_memories(self, namespace=None, memory_type=None, limit=None):
        return []

    async def list_resources(self, namespace=None, limit=None):
        return []


async def fake_get_openviking_client():
    """Async factory used to replace ``memory_gateway.server.get_openviking_client``."""
    return FakeOVClient()


def build_headers(api_key: str | None) -> dict[str, str]:
    """Return an ``x-api-key`` header dict, or no headers when *api_key* is None."""
    return {"x-api-key": api_key} if api_key is not None else {}


def _patch_gateway(monkeypatch, api_key: str) -> None:
    """Point the server at a Config with *api_key* and at the fake OpenViking client."""
    monkeypatch.setattr(
        "memory_gateway.server.get_config",
        lambda: Config(server=ServerConfig(api_key=api_key)),
    )
    monkeypatch.setattr(
        "memory_gateway.server.get_openviking_client",
        fake_get_openviking_client,
    )


def test_health_requires_api_key(monkeypatch):
    _patch_gateway(monkeypatch, "secret")

    with TestClient(app) as client:
        response = client.get("/health")
        assert response.status_code == 401

        response = client.get("/health", headers=build_headers("secret"))
        assert response.status_code == 200
        assert response.json()["openviking"]["status"] == "ok"


def test_mcp_rpc_lists_tools_with_api_key(monkeypatch):
    _patch_gateway(monkeypatch, "secret")

    with TestClient(app) as client:
        response = client.post(
            "/mcp/rpc",
            json={"jsonrpc": "2.0", "id": 1, "method": "tools/list", "params": {}},
            headers=build_headers("secret"),
        )
        assert response.status_code == 200
        payload = response.json()
        assert payload["jsonrpc"] == "2.0"
        assert len(payload["result"]["tools"]) == 6


def test_search_passes_through_gateway(monkeypatch):
    # An empty api_key is used here, and the request sends no header.
    _patch_gateway(monkeypatch, "")

    with TestClient(app) as client:
        response = client.post("/api/search", json={"query": "phishing"})
        assert response.status_code == 200
        payload = response.json()
        assert payload["total"] == 1
        assert payload["results"][0]["abstract"] == "phishing"