Initial SOC memory POC implementation
This commit is contained in:
170
tests/test_server.py
Normal file
170
tests/test_server.py
Normal file
@ -0,0 +1,170 @@
|
||||
import sys
|
||||
import types
|
||||
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
def install_test_stubs() -> None:
|
||||
if "mcp.server" not in sys.modules:
|
||||
mcp_module = types.ModuleType("mcp")
|
||||
mcp_server_module = types.ModuleType("mcp.server")
|
||||
mcp_types_module = types.ModuleType("mcp.types")
|
||||
|
||||
class Server:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def list_tools(self):
|
||||
def decorator(func):
|
||||
return func
|
||||
return decorator
|
||||
|
||||
def call_tool(self):
|
||||
def decorator(func):
|
||||
return func
|
||||
return decorator
|
||||
|
||||
class Tool:
|
||||
def __init__(self, name, description, inputSchema):
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.inputSchema = inputSchema
|
||||
|
||||
class TextContent:
|
||||
def __init__(self, type, text):
|
||||
self.type = type
|
||||
self.text = text
|
||||
|
||||
def model_dump(self):
|
||||
return {"type": self.type, "text": self.text}
|
||||
|
||||
mcp_server_module.Server = Server
|
||||
mcp_types_module.Tool = Tool
|
||||
mcp_types_module.TextContent = TextContent
|
||||
sys.modules["mcp"] = mcp_module
|
||||
sys.modules["mcp.server"] = mcp_server_module
|
||||
sys.modules["mcp.types"] = mcp_types_module
|
||||
|
||||
if "sse_starlette" not in sys.modules:
|
||||
sse_module = types.ModuleType("sse_starlette")
|
||||
|
||||
class EventSourceResponse(StreamingResponse):
|
||||
def __init__(self, content, *args, **kwargs):
|
||||
super().__init__(content, media_type="text/event-stream", *args, **kwargs)
|
||||
|
||||
sse_module.EventSourceResponse = EventSourceResponse
|
||||
sys.modules["sse_starlette"] = sse_module
|
||||
|
||||
|
||||
install_test_stubs()
|
||||
|
||||
from memory_gateway.server import app
|
||||
from memory_gateway.types import Config, SearchResult, ServerConfig
|
||||
|
||||
|
||||
class FakeOVClient:
|
||||
async def health_check(self):
|
||||
return {"status": "ok", "backend": "fake"}
|
||||
|
||||
async def search(self, query, namespace=None, limit=None, uri=None):
|
||||
return SearchResult(
|
||||
results=[
|
||||
{
|
||||
"uri": "viking://soc/test",
|
||||
"abstract": query,
|
||||
"score": 1.0,
|
||||
"context_type": "memory",
|
||||
}
|
||||
],
|
||||
total=1,
|
||||
)
|
||||
|
||||
async def add_memory(self, content, namespace=None, memory_type="general"):
|
||||
return {
|
||||
"status": "ok",
|
||||
"content": content,
|
||||
"namespace": namespace,
|
||||
"memory_type": memory_type,
|
||||
}
|
||||
|
||||
async def add_resource(self, uri, content, resource_type="text"):
|
||||
return {
|
||||
"status": "ok",
|
||||
"uri": uri,
|
||||
"content": content,
|
||||
"resource_type": resource_type,
|
||||
}
|
||||
|
||||
async def list_memories(self, namespace=None, memory_type=None, limit=None):
|
||||
return []
|
||||
|
||||
async def list_resources(self, namespace=None, limit=None):
|
||||
return []
|
||||
|
||||
|
||||
async def fake_get_openviking_client():
|
||||
return FakeOVClient()
|
||||
|
||||
|
||||
def build_headers(api_key: str | None):
|
||||
return {"x-api-key": api_key} if api_key is not None else {}
|
||||
|
||||
|
||||
def test_health_requires_api_key(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"memory_gateway.server.get_config",
|
||||
lambda: Config(server=ServerConfig(api_key="secret")),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"memory_gateway.server.get_openviking_client",
|
||||
fake_get_openviking_client,
|
||||
)
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/health")
|
||||
assert response.status_code == 401
|
||||
|
||||
response = client.get("/health", headers=build_headers("secret"))
|
||||
assert response.status_code == 200
|
||||
assert response.json()["openviking"]["status"] == "ok"
|
||||
|
||||
|
||||
def test_mcp_rpc_lists_tools_with_api_key(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"memory_gateway.server.get_config",
|
||||
lambda: Config(server=ServerConfig(api_key="secret")),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"memory_gateway.server.get_openviking_client",
|
||||
fake_get_openviking_client,
|
||||
)
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.post(
|
||||
"/mcp/rpc",
|
||||
json={"jsonrpc": "2.0", "id": 1, "method": "tools/list", "params": {}},
|
||||
headers=build_headers("secret"),
|
||||
)
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["jsonrpc"] == "2.0"
|
||||
assert len(payload["result"]["tools"]) == 6
|
||||
|
||||
|
||||
def test_search_passes_through_gateway(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"memory_gateway.server.get_config",
|
||||
lambda: Config(server=ServerConfig(api_key="")),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"memory_gateway.server.get_openviking_client",
|
||||
fake_get_openviking_client,
|
||||
)
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.post("/api/search", json={"query": "phishing"})
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["total"] == 1
|
||||
assert payload["results"][0]["abstract"] == "phishing"
|
||||
Reference in New Issue
Block a user