255 lines
8.5 KiB
Python
255 lines
8.5 KiB
Python
import asyncio
|
|
import sys
|
|
import types
|
|
|
|
import pytest
|
|
from fastapi import HTTPException
|
|
from fastapi.responses import StreamingResponse
|
|
|
|
|
|
def install_test_stubs() -> None:
    """Register lightweight stand-ins for the ``mcp`` and ``sse_starlette`` packages.

    The stubs are placed in ``sys.modules`` so that a later import of
    ``memory_gateway.server`` resolves those package names to the fakes
    defined here. Each group is only installed when it is not already
    present in ``sys.modules``, so repeated calls are harmless.
    """
    if "mcp.server" not in sys.modules:
        mcp_module = types.ModuleType("mcp")
        # Give the parent a __path__ so it behaves like a package that owns
        # the ``server`` and ``types`` submodules registered below.
        mcp_module.__path__ = []
        mcp_server_module = types.ModuleType("mcp.server")
        mcp_types_module = types.ModuleType("mcp.types")

        class Server:
            """Minimal MCP server: its registration decorators return the function unchanged."""

            def __init__(self, name):
                self.name = name

            def list_tools(self):
                def decorator(func):
                    return func

                return decorator

            def call_tool(self):
                def decorator(func):
                    return func

                return decorator

        class Tool:
            """Plain record of a tool's name, description and input schema."""

            def __init__(self, name, description, inputSchema):
                self.name = name
                self.description = description
                self.inputSchema = inputSchema

        class TextContent:
            """Text payload with a ``model_dump`` mirroring the pydantic-style API."""

            def __init__(self, type, text):
                self.type = type
                self.text = text

            def model_dump(self):
                return {"type": self.type, "text": self.text}

        mcp_server_module.Server = Server
        mcp_types_module.Tool = Tool
        mcp_types_module.TextContent = TextContent

        # The import system does not attach submodules that are injected
        # directly into sys.modules. Without these attributes,
        # ``import mcp.server`` followed by ``mcp.server.Server`` would raise
        # AttributeError.
        mcp_module.server = mcp_server_module
        mcp_module.types = mcp_types_module

        sys.modules["mcp"] = mcp_module
        sys.modules["mcp.server"] = mcp_server_module
        sys.modules["mcp.types"] = mcp_types_module

    if "sse_starlette" not in sys.modules:
        sse_module = types.ModuleType("sse_starlette")

        class EventSourceResponse(StreamingResponse):
            """StreamingResponse that defaults to the ``text/event-stream`` media type."""

            def __init__(self, content, *args, **kwargs):
                # Use setdefault so a caller-supplied media_type wins, instead of
                # raising "got multiple values for argument 'media_type'".
                kwargs.setdefault("media_type", "text/event-stream")
                super().__init__(content, *args, **kwargs)

        sse_module.EventSourceResponse = EventSourceResponse
        sys.modules["sse_starlette"] = sse_module
|
|
|
|
|
|
# Register the mcp / sse_starlette stubs before memory_gateway.server is
# imported below, so that its imports (presumably of these packages) resolve
# to the fakes rather than to real installations.
install_test_stubs()
|
|
|
|
import memory_gateway.server as server
|
|
from memory_gateway.types import CommitSummaryRequest, Config, ObsidianConfig, SearchRequest, SearchResult, ServerConfig
|
|
|
|
|
|
class FakeOVClient:
    """Stand-in for the OpenViking client that answers every call with canned data.

    No I/O happens: each coroutine builds its reply from its arguments.
    """

    async def health_check(self):
        """Report a healthy fake backend."""
        return dict(status="ok", backend="fake")

    async def search(self, query, namespace=None, limit=None, uri=None):
        """Return exactly one hit whose abstract echoes *query*; the filters are ignored."""
        hit = dict(
            uri="viking://memory-gateway/test",
            abstract=query,
            score=1.0,
            context_type="memory",
        )
        return SearchResult(results=[hit], total=1)

    async def add_memory(self, content, namespace=None, memory_type="general"):
        """Echo the memory that would have been stored, with an ``ok`` status."""
        return dict(
            status="ok",
            content=content,
            namespace=namespace,
            memory_type=memory_type,
        )

    async def add_resource(self, uri, content, resource_type="text"):
        """Echo the resource that would have been stored, with an ``ok`` status."""
        return dict(
            status="ok",
            uri=uri,
            content=content,
            resource_type=resource_type,
        )

    async def list_memories(self, namespace=None, memory_type=None, limit=None):
        """Always report an empty memory list."""
        return list()

    async def list_resources(self, namespace=None, limit=None):
        """Always report an empty resource list."""
        return list()
|
|
|
|
|
|
async def fake_get_openviking_client():
    """Async factory used in place of ``get_openviking_client``; returns a new ``FakeOVClient``."""
    client = FakeOVClient()
    return client
|
|
|
|
|
|
async def fake_summarize_with_llm(content, **kwargs):
    """Deterministic replacement for ``summarize_with_llm``.

    The summary echoes the first 80 characters of *content*. The optional
    ``title`` and ``tags`` keyword arguments are used when truthy; otherwise
    fixed fallbacks are returned.
    """
    title = kwargs.get("title")
    if not title:
        title = "Fake LLM title"
    preview = content[:80]
    tags = kwargs.get("tags")
    if not tags:
        tags = ["fake"]
    return dict(
        title=title,
        summary=f"LLM summary: {preview}",
        key_points=["LLM key point", "Preserve IP 198.51.100.20"],
        tags=tags,
        llm=dict(provider="fake", model="fake-model"),
    )
|
|
|
|
|
|
class FakeUploadFile:
    """In-memory stand-in for FastAPI's ``UploadFile``.

    Provides only a ``filename`` attribute and an awaitable ``read``.
    """

    def __init__(self, filename: str, content: bytes) -> None:
        self.filename = filename
        self._content = content
        # Read cursor, so successive reads consume the content like a real file.
        self._position = 0

    async def read(self, size: int = -1) -> bytes:
        """Return up to *size* bytes from the current position.

        A negative *size* (the default) or ``None`` returns everything that
        remains. This mirrors ``UploadFile.read(size=-1)``, so endpoint code
        that reads in chunks works against the fake. An exhausted file
        yields ``b""``.
        """
        remaining = self._content[self._position:]
        chunk = remaining if size is None or size < 0 else remaining[:size]
        self._position += len(chunk)
        return chunk
|
|
|
|
def test_health_requires_api_key(monkeypatch):
    """A missing API key is rejected with 401; the right key passes and health reports OpenViking ok."""

    def secured_config():
        return Config(server=ServerConfig(api_key="secret"))

    def disabled_evermemos():
        return {"status": "disabled"}

    monkeypatch.setattr(server, "get_config", secured_config)
    monkeypatch.setattr(server, "get_openviking_client", fake_get_openviking_client)
    monkeypatch.setattr(server, "summarize_with_llm", fake_summarize_with_llm)
    monkeypatch.setattr(server.v1_service, "evermemos_health", disabled_evermemos)

    # No key supplied: the guard must refuse with HTTP 401.
    with pytest.raises(HTTPException) as rejected:
        server.verify_api_key()
    assert rejected.value.status_code == 401

    # Matching key supplied: must not raise.
    server.verify_api_key("secret")

    health = asyncio.run(server.health_check())
    assert health["openviking"]["status"] == "ok"
|
|
|
|
|
|
def test_mcp_rpc_lists_tools_with_api_key(monkeypatch):
    """With a valid key, the MCP tool listing exposes the core gateway tools."""

    def secured_config():
        return Config(server=ServerConfig(api_key="secret"))

    monkeypatch.setattr(server, "get_config", secured_config)
    monkeypatch.setattr(server, "get_openviking_client", fake_get_openviking_client)

    server.verify_api_key("secret")
    tools = asyncio.run(server.list_tools())
    assert len(tools) >= 7

    tool_names = [tool.name for tool in tools]
    assert "commit_summary" in tool_names
    assert "memory_search" in tool_names
|
|
|
|
|
|
def test_search_passes_through_gateway(monkeypatch):
    """api_search forwards the query to the OpenViking client and relays its results."""

    def open_config():
        return Config(server=ServerConfig(api_key=""))

    monkeypatch.setattr(server, "get_config", open_config)
    monkeypatch.setattr(server, "get_openviking_client", fake_get_openviking_client)

    request = SearchRequest(query="phishing")
    payload = asyncio.run(server.api_search(request))

    assert payload["total"] == 1
    first_hit = payload["results"][0]
    # The fake client echoes the query back as the abstract.
    assert first_hit["abstract"] == "phishing"
|
|
|
|
|
|
def test_summary_endpoint_builds_generic_artifact(monkeypatch):
    """commit_summary with persist_as="none" returns an LLM-built artifact and persists nothing."""

    def open_config():
        return Config(server=ServerConfig(api_key=""))

    monkeypatch.setattr(server, "get_config", open_config)
    monkeypatch.setattr(server, "get_openviking_client", fake_get_openviking_client)
    monkeypatch.setattr(server, "summarize_with_llm", fake_summarize_with_llm)

    # Non-ASCII sample body with three lines. In English: "Conclusion: this is
    # a high-value consolidation. / - Evidence: matched a historical case. /
    # - Recommendation: reuse this handling path later."
    request = CommitSummaryRequest(
        title="Demo investigation summary",
        content="结论:这是一次高价值沉淀。\n- 证据:命中历史 case。\n- 建议:后续复用该处置路径。",
        namespace="demo",
        memory_type="knowledge",
        tags=["demo", "summary"],
        persist_as="none",
    )
    payload = asyncio.run(server.api_commit_summary(request))

    assert payload["status"] == "ok"
    artifact = payload["artifact"]
    assert artifact["title"] == "Demo investigation summary"
    assert artifact["namespace"] == "demo"
    assert artifact["memory_type"] == "knowledge"
    assert artifact["summary"].startswith("LLM summary:")
    assert artifact["llm"]["provider"] == "fake"
    # persist_as="none": neither a memory nor a resource is written.
    assert payload["memory_result"] is None
    assert payload["resource_result"] is None
|
|
|
|
|
|
def test_knowledge_upload_converts_saves_and_commits(monkeypatch, tmp_path):
    """An upload is converted to Markdown, written into the Obsidian vault and committed as a resource."""
    vault_dir = tmp_path / "vault"

    def vault_config():
        return Config(
            server=ServerConfig(api_key=""),
            obsidian=ObsidianConfig(vault_path=str(vault_dir), knowledge_dir="01_Knowledge/Uploaded"),
        )

    def fake_convert(path):
        return "# Uploaded Doc\n\nImportant uploaded knowledge."

    async def run_inline(func, *args, **kwargs):
        # Execute the "threaded" work directly on the event-loop thread.
        return func(*args, **kwargs)

    monkeypatch.setattr(server, "get_config", vault_config)
    monkeypatch.setattr(server, "get_openviking_client", fake_get_openviking_client)
    monkeypatch.setattr(server, "summarize_with_llm", fake_summarize_with_llm)
    monkeypatch.setattr(server, "convert_file_to_markdown", fake_convert)
    # NOTE(review): server.asyncio is presumably the stdlib module, so this swaps
    # asyncio.to_thread process-wide; monkeypatch restores it after the test.
    monkeypatch.setattr(server.asyncio, "to_thread", run_inline)

    upload = FakeUploadFile(filename="sample.txt", content=b"hello")
    # Every parameter is passed explicitly; calling the endpoint directly
    # presumably bypasses FastAPI's Form/File default handling.
    payload = asyncio.run(
        server.api_upload_knowledge(
            file=upload,
            title="Uploaded Knowledge",
            namespace="demo",
            knowledge_type="playbook",
            tags="demo,upload",
            source=None,
            obsidian_dir=None,
            resource_uri=None,
            persist_as="resource",
            max_summary_chars=1000,
        )
    )

    assert payload["status"] == "ok"
    artifact = payload["artifact"]
    assert artifact["schema_version"] == "memory-gateway.knowledge_upload.v1"
    assert artifact["knowledge_type"] == "playbook"
    assert artifact["markdown_content"].startswith("# Uploaded Doc")
    assert payload["resource_result"]["status"] == "ok"
    saved_note = vault_dir / artifact["obsidian_relative_path"]
    assert saved_note.exists()
|