# Listing metadata from the original code viewer: 388 lines, 12 KiB, Python.
"""Memory Gateway MCP Server.

A memory gateway service built on the Model Context Protocol that gives AI
agents on the local network a single, unified entry point to OpenViking.
"""
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
from contextlib import asynccontextmanager
|
||
from typing import Any, Optional
|
||
|
||
from fastapi import APIRouter, Depends, FastAPI, Header, HTTPException, Request, status
|
||
from fastapi.responses import JSONResponse
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from mcp.server import Server
|
||
from mcp.types import TextContent, Tool
|
||
from sse_starlette import EventSourceResponse
|
||
|
||
from .config import get_config, set_config, Config
|
||
from .openviking_client import get_openviking_client, close_openviking_client
|
||
from .types import SearchRequest, AddMemoryRequest, AddResourceRequest
|
||
|
||
# Logging setup: configure the root logger at import time (INFO level,
# timestamped format) and create this module's logger.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
||
|
||
|
||
# MCP server instance; the list_tools / call_tool handlers in this module are
# registered on it through its decorators.
mcp_server = Server("memory-gateway")
|
||
|
||
|
||
@mcp_server.list_tools()
async def list_tools() -> list[Tool]:
    """Return the MCP tools exposed by this gateway.

    Six tools are advertised: ``search``, ``add_memory``, ``add_resource``,
    ``get_status``, ``list_memories`` and ``list_resources``. Each one carries
    a JSON-Schema style ``inputSchema`` describing its arguments.

    Returns:
        A freshly built list of ``Tool`` definitions.
    """

    def text(description: str) -> dict:
        # Schema fragment for a string-typed argument.
        return {"type": "string", "description": description}

    def number(description: str) -> dict:
        # Schema fragment for an integer-typed argument.
        return {"type": "integer", "description": description}

    def schema(properties: dict, required: Optional[list] = None) -> dict:
        # Object schema; the "required" key is only present when given.
        spec = {"type": "object", "properties": properties}
        if required is not None:
            spec["required"] = required
        return spec

    search_tool = Tool(
        name="search",
        description="语义搜索记忆和资源",
        inputSchema=schema(
            {
                "query": text("搜索查询"),
                "namespace": text("命名空间(可选)"),
                "limit": number("返回结果数量(默认10)"),
                "uri": text("资源 URI(可选)"),
            },
            ["query"],
        ),
    )
    add_memory_tool = Tool(
        name="add_memory",
        description="添加新记忆",
        inputSchema=schema(
            {
                "content": text("记忆内容"),
                "namespace": text("命名空间(可选)"),
                "memory_type": text("记忆类型(默认general)"),
            },
            ["content"],
        ),
    )
    add_resource_tool = Tool(
        name="add_resource",
        description="添加资源",
        inputSchema=schema(
            {
                "uri": text("资源 URI"),
                "content": text("资源内容"),
                "resource_type": text("资源类型(默认text)"),
            },
            ["uri", "content"],
        ),
    )
    get_status_tool = Tool(
        name="get_status",
        description="检查系统状态",
        inputSchema=schema({}),
    )
    list_memories_tool = Tool(
        name="list_memories",
        description="列出已存储的记忆",
        inputSchema=schema(
            {
                "namespace": text("命名空间(可选)"),
                "memory_type": text("记忆类型(可选)"),
                "limit": number("返回数量(默认10)"),
            }
        ),
    )
    list_resources_tool = Tool(
        name="list_resources",
        description="列出已存储的资源",
        inputSchema=schema(
            {
                "namespace": text("命名空间(可选)"),
                "limit": number("返回数量(默认10)"),
            }
        ),
    )

    return [
        search_tool,
        add_memory_tool,
        add_resource_tool,
        get_status_tool,
        list_memories_tool,
        list_resources_tool,
    ]
|
||
|
||
|
||
@mcp_server.call_tool()
async def call_tool(name: str, arguments: Any) -> list[TextContent]:
    """Execute one MCP tool against the OpenViking client.

    Args:
        name: Tool name; one of ``search``, ``add_memory``, ``add_resource``,
            ``get_status``, ``list_memories`` or ``list_resources``.
        arguments: Mapping of tool arguments. ``None`` is accepted and treated
            as an empty mapping so that tools whose arguments are all optional
            can be called without a payload.

    Returns:
        A single-element list holding a text item. On success the text is the
        ``str()`` of the backend result; on any failure (including an unknown
        tool name) it is ``"Error: <message>"`` -- errors are reported in-band
        rather than raised.
    """
    # A missing/null argument object must not crash the ``.get`` lookups below.
    args = {} if arguments is None else arguments

    try:
        ov_client = await get_openviking_client()

        if name == "search":
            result = await ov_client.search(
                query=args.get("query"),
                namespace=args.get("namespace"),
                limit=args.get("limit"),
                uri=args.get("uri"),
            )
            return [TextContent(type="text", text=str(result.results))]

        if name == "add_memory":
            result = await ov_client.add_memory(
                content=args.get("content"),
                namespace=args.get("namespace"),
                memory_type=args.get("memory_type", "general"),
            )
            return [TextContent(type="text", text=str(result))]

        if name == "add_resource":
            result = await ov_client.add_resource(
                uri=args.get("uri"),
                content=args.get("content"),
                resource_type=args.get("resource_type", "text"),
            )
            return [TextContent(type="text", text=str(result))]

        if name == "get_status":
            ov_status = await ov_client.health_check()
            return [TextContent(type="text", text=f"Memory Gateway: OK\nOpenViking: {ov_status}")]

        if name == "list_memories":
            memories = await ov_client.list_memories(
                namespace=args.get("namespace"),
                memory_type=args.get("memory_type"),
                limit=args.get("limit"),
            )
            return [TextContent(type="text", text=str([m.model_dump() for m in memories]))]

        if name == "list_resources":
            resources = await ov_client.list_resources(
                namespace=args.get("namespace"),
                limit=args.get("limit"),
            )
            return [TextContent(type="text", text=str([r.model_dump() for r in resources]))]

        raise ValueError(f"Unknown tool: {name}")

    except Exception as e:
        # Boundary handler: log with traceback, then report the failure to the
        # caller as text instead of propagating it.
        logger.exception("工具执行失败: %s", e)
        return [TextContent(type="text", text=f"Error: {str(e)}")]
|
||
|
||
|
||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application startup and shutdown.

    Startup logs the configured listen address and OpenViking URL, then probes
    the OpenViking backend. A failed probe is logged as a warning and does not
    prevent the application from starting.

    Shutdown closes the shared OpenViking client. The cleanup runs in a
    ``finally`` clause so it still happens when an exception or cancellation
    is thrown into the context manager at the ``yield``.

    Args:
        app: The FastAPI application (not used directly here).
    """
    logger.info("Memory Gateway 启动中...")
    config = get_config()
    logger.info(f"配置加载完成: {config.server.host}:{config.server.port}")
    logger.info(f"OpenViking 后端: {config.openviking.url}")

    # Best-effort connectivity probe; the gateway starts even if it fails.
    try:
        ov_client = await get_openviking_client()
        # Named ov_status so the imported fastapi ``status`` is not shadowed.
        ov_status = await ov_client.health_check()
        logger.info(f"OpenViking 连接状态: {ov_status}")
    except Exception as e:
        logger.warning(f"OpenViking 连接失败: {e}")

    try:
        yield
    finally:
        logger.info("Memory Gateway 关闭中...")
        await close_openviking_client()
|
||
|
||
|
||
def verify_api_key(x_api_key: Optional[str] = Header(default=None)) -> None:
    """Validate the ``X-API-Key`` request header when an API key is configured.

    Args:
        x_api_key: Value of the ``X-API-Key`` header, or ``None`` when absent.

    Raises:
        HTTPException: 401 when a key is configured and the header is missing
            or does not match it.

    When no key is configured (empty or ``None``) every request is accepted.
    """
    # Function-scope import keeps this block self-contained.
    import hmac

    expected_key = get_config().server.api_key
    if not expected_key:
        return

    # Constant-time comparison so response timing does not reveal how much of
    # a guessed key was correct. Both sides are compared as bytes because
    # compare_digest rejects non-ASCII str; the configured value is passed
    # through str() so a non-string setting cannot raise here.
    supplied = (x_api_key or "").encode("utf-8")
    expected = str(expected_key).encode("utf-8")
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or missing API key",
        )
|
||
|
||
|
||
# FastAPI application, with startup/shutdown handled by ``lifespan``.
app = FastAPI(
    title="Memory Gateway",
    version="0.1.0",
    lifespan=lifespan,
)

# CORS is fully open: any origin, method and header.
# NOTE(review): a wildcard origin together with allow_credentials=True is a
# permissive combination -- confirm this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||
|
||
|
||
@app.get("/health", dependencies=[Depends(verify_api_key)])
async def health_check():
    """Report gateway health together with the OpenViking backend status.

    Returns:
        ``{"status": "ok", "gateway": ..., "openviking": <backend status>}``
        when the backend health check succeeds, otherwise
        ``{"status": "degraded", "gateway": ..., "error": <message>}``.
        Failures are reported in the body; no exception is raised.
    """
    try:
        client = await get_openviking_client()
        backend_state = await client.health_check()
    except Exception as exc:
        return {
            "status": "degraded",
            "gateway": "memory-gateway",
            "error": str(exc),
        }

    return {
        "status": "ok",
        "gateway": "memory-gateway",
        "openviking": backend_state,
    }
|
||
|
||
# Router that collects the MCP endpoints.
mcp_router = APIRouter()
|
||
|
||
|
||
async def mcp_server_events(request: Request, _: None = Depends(verify_api_key)):
    """Server-Sent Events endpoint for MCP clients.

    The stream emits one ``initialize`` event carrying the protocol version
    and then a ``ping`` event every 30 seconds to keep the connection open.
    No tool traffic is carried on this stream.

    Args:
        request: Incoming request; used only to log the peer on disconnect.
        _: Placeholder for the API-key dependency.

    Returns:
        An ``EventSourceResponse`` wrapping the event generator.
    """

    async def event_generator():
        # Initial handshake event.
        yield {"event": "initialize", "data": json.dumps({"protocolVersion": "2024-11-05"})}

        # Heartbeat loop.
        try:
            while True:
                await asyncio.sleep(30)
                yield {"event": "ping", "data": ""}
        except asyncio.CancelledError:
            # Cancellation signals that the stream is being torn down. Log it
            # and re-raise so the cancellation is not silently swallowed.
            logger.info("MCP SSE stream closed for client %s", request.client)
            raise

    return EventSourceResponse(event_generator())
|
||
|
||
|
||
# Expose the SSE stream as GET /sse on the MCP router.
mcp_router.add_api_route(
    "/sse",
    mcp_server_events,
    methods=["GET"],
)
|
||
|
||
|
||
# MCP JSON-RPC endpoint (simplified implementation)
def _jsonrpc_error(status_code: int, code: int, message: str, msg_id: Any) -> JSONResponse:
    """Build a JSON-RPC 2.0 error response with the given HTTP status."""
    return JSONResponse(
        status_code=status_code,
        content={"jsonrpc": "2.0", "error": {"code": code, "message": message}, "id": msg_id},
    )


async def mcp_rpc(request: Request, _: None = Depends(verify_api_key)):
    """Handle a single MCP JSON-RPC request.

    Supported methods are ``tools/list`` and ``tools/call``.

    Args:
        request: Incoming HTTP request whose body is a JSON-RPC object.
        _: Placeholder for the API-key dependency.

    Returns:
        A JSON-RPC result object on success, or a ``JSONResponse`` error:
        -32700 (HTTP 400) when the body is not valid JSON, -32600 (HTTP 400)
        when it is not a JSON object, -32601 (HTTP 400) for an unknown method,
        and -32603 (HTTP 500) for any unexpected failure.
    """
    # Malformed bodies are answered with JSON-RPC errors instead of escaping
    # as unhandled exceptions. JSON decode errors are ValueError subclasses.
    try:
        body = await request.json()
    except ValueError:
        return _jsonrpc_error(400, -32700, "Parse error", None)
    if not isinstance(body, dict):
        return _jsonrpc_error(400, -32600, "Invalid Request", None)

    method = body.get("method")
    # An explicit null for params/arguments is treated like an omitted value.
    params = body.get("params") or {}
    msg_id = body.get("id")

    try:
        if method == "tools/list":
            tools = await list_tools()
            result = {
                "tools": [
                    {
                        "name": t.name,
                        "description": t.description,
                        "inputSchema": t.inputSchema,
                    }
                    for t in tools
                ]
            }
        elif method == "tools/call":
            tool_name = params.get("name")
            tool_args = params.get("arguments") or {}
            result_content = await call_tool_tool(tool_name, tool_args)
            result = {"content": [c.model_dump() for c in result_content]}
        else:
            return _jsonrpc_error(400, -32601, f"Method not found: {method}", msg_id)

        return {"jsonrpc": "2.0", "result": result, "id": msg_id}

    except Exception as e:
        # Boundary handler: keep the traceback in the log, report -32603.
        logger.exception("MCP RPC 错误: %s", e)
        return _jsonrpc_error(500, -32603, str(e), msg_id)
|
||
|
||
|
||
async def call_tool_tool(name: str, arguments: dict) -> list[TextContent]:
    """Forward a tool invocation to ``call_tool`` and return its result.

    Args:
        name: Tool name.
        arguments: Tool arguments.

    Returns:
        Whatever ``call_tool`` returns for the same inputs.
    """
    outcome = await call_tool(name, arguments)
    return outcome
|
||
|
||
|
||
# Expose the JSON-RPC handler as POST /rpc on the MCP router.
mcp_router.add_api_route(
    "/rpc",
    mcp_rpc,
    methods=["POST"],
)

# Mount the MCP routes on the application under the /mcp prefix.
app.include_router(
    mcp_router,
    prefix="/mcp",
    tags=["mcp"],
)
|
||
|
||
|
||
@app.post("/api/search", dependencies=[Depends(verify_api_key)])
async def api_search(request: SearchRequest):
    """REST API: search.

    A falsy ``namespace`` or ``limit`` in the request is replaced by the
    configured ``memory.default_namespace`` / ``memory.search_limit``.

    Args:
        request: Search parameters (query, namespace, limit, uri).

    Returns:
        A dict with the backend's ``results`` and ``total``.
    """
    client = await get_openviking_client()

    # Configuration is only consulted when the request leaves a value unset.
    namespace = request.namespace
    if not namespace:
        namespace = get_config().memory.default_namespace

    limit = request.limit
    if not limit:
        limit = get_config().memory.search_limit

    found = await client.search(
        query=request.query,
        namespace=namespace,
        limit=limit,
        uri=request.uri,
    )
    return {"results": found.results, "total": found.total}
|
||
|
||
|
||
@app.post("/api/memory", dependencies=[Depends(verify_api_key)])
async def api_add_memory(request: AddMemoryRequest):
    """REST API: add a memory.

    A falsy ``namespace`` in the request is replaced by the configured
    ``memory.default_namespace``.

    Args:
        request: Memory payload (content, namespace, memory_type).

    Returns:
        The backend's result for the add operation, unchanged.
    """
    client = await get_openviking_client()

    # Configuration is only consulted when the request leaves it unset.
    namespace = request.namespace
    if not namespace:
        namespace = get_config().memory.default_namespace

    stored = await client.add_memory(
        content=request.content,
        namespace=namespace,
        memory_type=request.memory_type,
    )
    return stored
|
||
|
||
|
||
@app.post("/api/resource", dependencies=[Depends(verify_api_key)])
async def api_add_resource(request: AddResourceRequest):
    """REST API: add a resource.

    Args:
        request: Resource payload (uri, content, resource_type).

    Returns:
        The backend's result for the add operation, unchanged.
    """
    client = await get_openviking_client()
    stored = await client.add_resource(
        uri=request.uri,
        content=request.content,
        resource_type=request.resource_type,
    )
    return stored
|
||
|
||
|
||
def create_app(config: Optional[Config] = None) -> FastAPI:
    """Return the FastAPI application, optionally installing a configuration.

    Note that this returns the module-level ``app`` instance rather than
    building a new application on every call.

    Args:
        config: Configuration to install via ``set_config``. When ``None`` the
            current configuration is left untouched.

    Returns:
        The module-level ``app``.
    """
    # Explicit None check: any supplied Config object is installed, regardless
    # of how it evaluates in a boolean context.
    if config is not None:
        set_config(config)
    return app
|
||
|
||
|
||
# Entry point
def main():
    """Command-line entry point: load configuration and run the server.

    Command-line options:
        --config: Path to the configuration file (default ``config.yaml``).
        --host: Listen address; overrides the configured value when given.
        --port: Listen port; overrides the configured value when given.
    """
    import argparse
    import uvicorn

    parser = argparse.ArgumentParser(description="Memory Gateway MCP Server")
    parser.add_argument("--config", default="config.yaml", help="配置文件路径")
    parser.add_argument("--host", default=None, help="监听地址")
    parser.add_argument("--port", type=int, default=None, help="监听端口")
    args = parser.parse_args()

    # Load the configuration file.
    from .config import load_config as load

    config = load(args.config)
    # Compare against None rather than truthiness so an explicitly supplied
    # falsy value such as ``--port 0`` is honoured instead of being ignored.
    if args.host is not None:
        config.server.host = args.host
    if args.port is not None:
        config.server.port = args.port
    set_config(config)

    # Start the server.
    uvicorn.run(
        app,
        host=config.server.host,
        port=config.server.port,
        log_level=config.logging.level.lower(),
    )
|
||
|
||
|
||
# Run the command-line entry point when this file is executed as a script.
if __name__ == "__main__":
    main()
|