Initial SOC memory POC implementation
This commit is contained in:
387
memory_gateway/server.py
Normal file
387
memory_gateway/server.py
Normal file
@ -0,0 +1,387 @@
|
||||
"""Memory Gateway MCP Server.
|
||||
|
||||
基于 Model Context Protocol 的记忆网关服务,为局域网内的 AI Agent 提供统一的 OpenViking 访问入口。
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, FastAPI, Header, HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from mcp.server import Server
|
||||
from mcp.types import TextContent, Tool
|
||||
from sse_starlette import EventSourceResponse
|
||||
|
||||
from .config import get_config, set_config, Config
|
||||
from .openviking_client import get_openviking_client, close_openviking_client
|
||||
from .types import SearchRequest, AddMemoryRequest, AddResourceRequest
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# 创建 MCP Server
|
||||
mcp_server = Server("memory-gateway")
|
||||
|
||||
|
||||
@mcp_server.list_tools()
|
||||
async def list_tools() -> list[Tool]:
|
||||
"""列出可用的 MCP 工具"""
|
||||
return [
|
||||
Tool(
|
||||
name="search",
|
||||
description="语义搜索记忆和资源",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "搜索查询"},
|
||||
"namespace": {"type": "string", "description": "命名空间(可选)"},
|
||||
"limit": {"type": "integer", "description": "返回结果数量(默认10)"},
|
||||
"uri": {"type": "string", "description": "资源 URI(可选)"},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
),
|
||||
Tool(
|
||||
name="add_memory",
|
||||
description="添加新记忆",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {"type": "string", "description": "记忆内容"},
|
||||
"namespace": {"type": "string", "description": "命名空间(可选)"},
|
||||
"memory_type": {"type": "string", "description": "记忆类型(默认general)"},
|
||||
},
|
||||
"required": ["content"],
|
||||
},
|
||||
),
|
||||
Tool(
|
||||
name="add_resource",
|
||||
description="添加资源",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"uri": {"type": "string", "description": "资源 URI"},
|
||||
"content": {"type": "string", "description": "资源内容"},
|
||||
"resource_type": {"type": "string", "description": "资源类型(默认text)"},
|
||||
},
|
||||
"required": ["uri", "content"],
|
||||
},
|
||||
),
|
||||
Tool(
|
||||
name="get_status",
|
||||
description="检查系统状态",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
},
|
||||
),
|
||||
Tool(
|
||||
name="list_memories",
|
||||
description="列出已存储的记忆",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"namespace": {"type": "string", "description": "命名空间(可选)"},
|
||||
"memory_type": {"type": "string", "description": "记忆类型(可选)"},
|
||||
"limit": {"type": "integer", "description": "返回数量(默认10)"},
|
||||
},
|
||||
},
|
||||
),
|
||||
Tool(
|
||||
name="list_resources",
|
||||
description="列出已存储的资源",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"namespace": {"type": "string", "description": "命名空间(可选)"},
|
||||
"limit": {"type": "integer", "description": "返回数量(默认10)"},
|
||||
},
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@mcp_server.call_tool()
|
||||
async def call_tool(name: str, arguments: Any) -> list[TextContent]:
|
||||
"""调用 MCP 工具"""
|
||||
try:
|
||||
ov_client = await get_openviking_client()
|
||||
|
||||
if name == "search":
|
||||
result = await ov_client.search(
|
||||
query=arguments.get("query"),
|
||||
namespace=arguments.get("namespace"),
|
||||
limit=arguments.get("limit"),
|
||||
uri=arguments.get("uri"),
|
||||
)
|
||||
return [TextContent(type="text", text=str(result.results))]
|
||||
|
||||
elif name == "add_memory":
|
||||
result = await ov_client.add_memory(
|
||||
content=arguments.get("content"),
|
||||
namespace=arguments.get("namespace"),
|
||||
memory_type=arguments.get("memory_type", "general"),
|
||||
)
|
||||
return [TextContent(type="text", text=str(result))]
|
||||
|
||||
elif name == "add_resource":
|
||||
result = await ov_client.add_resource(
|
||||
uri=arguments.get("uri"),
|
||||
content=arguments.get("content"),
|
||||
resource_type=arguments.get("resource_type", "text"),
|
||||
)
|
||||
return [TextContent(type="text", text=str(result))]
|
||||
|
||||
elif name == "get_status":
|
||||
ov_status = await ov_client.health_check()
|
||||
return [TextContent(type="text", text=f"Memory Gateway: OK\nOpenViking: {ov_status}")]
|
||||
|
||||
elif name == "list_memories":
|
||||
memories = await ov_client.list_memories(
|
||||
namespace=arguments.get("namespace"),
|
||||
memory_type=arguments.get("memory_type"),
|
||||
limit=arguments.get("limit"),
|
||||
)
|
||||
return [TextContent(type="text", text=str([m.model_dump() for m in memories]))]
|
||||
|
||||
elif name == "list_resources":
|
||||
resources = await ov_client.list_resources(
|
||||
namespace=arguments.get("namespace"),
|
||||
limit=arguments.get("limit"),
|
||||
)
|
||||
return [TextContent(type="text", text=str([r.model_dump() for r in resources]))]
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown tool: {name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具执行失败: {e}")
|
||||
return [TextContent(type="text", text=f"Error: {str(e)}")]
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理"""
|
||||
logger.info("Memory Gateway 启动中...")
|
||||
config = get_config()
|
||||
logger.info(f"配置加载完成: {config.server.host}:{config.server.port}")
|
||||
logger.info(f"OpenViking 后端: {config.openviking.url}")
|
||||
|
||||
# 测试 OpenViking 连接
|
||||
try:
|
||||
ov_client = await get_openviking_client()
|
||||
status = await ov_client.health_check()
|
||||
logger.info(f"OpenViking 连接状态: {status}")
|
||||
except Exception as e:
|
||||
logger.warning(f"OpenViking 连接失败: {e}")
|
||||
|
||||
yield
|
||||
|
||||
logger.info("Memory Gateway 关闭中...")
|
||||
await close_openviking_client()
|
||||
|
||||
|
||||
def verify_api_key(x_api_key: Optional[str] = Header(default=None)) -> None:
|
||||
"""在配置了 API Key 时校验请求头。"""
|
||||
expected_key = get_config().server.api_key
|
||||
if not expected_key:
|
||||
return
|
||||
if x_api_key != expected_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or missing API key",
|
||||
)
|
||||
|
||||
|
||||
# FastAPI 应用
|
||||
app = FastAPI(title="Memory Gateway", version="0.1.0", lifespan=lifespan)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
@app.get("/health", dependencies=[Depends(verify_api_key)])
|
||||
async def health_check():
|
||||
"""健康检查"""
|
||||
try:
|
||||
ov_client = await get_openviking_client()
|
||||
ov_status = await ov_client.health_check()
|
||||
return {
|
||||
"status": "ok",
|
||||
"gateway": "memory-gateway",
|
||||
"openviking": ov_status,
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "degraded",
|
||||
"gateway": "memory-gateway",
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
mcp_router = APIRouter()
|
||||
|
||||
|
||||
async def mcp_server_events(request: Request, _: None = Depends(verify_api_key)):
|
||||
"""MCP Server-Sent Events 端点 - 使用 stdio 模式模拟"""
|
||||
async def event_generator():
|
||||
# 发送初始化消息
|
||||
yield {"event": "initialize", "data": json.dumps({"protocolVersion": "2024-11-05"})}
|
||||
|
||||
# 保持连接
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(30)
|
||||
yield {"event": "ping", "data": ""}
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
return EventSourceResponse(event_generator())
|
||||
|
||||
|
||||
mcp_router.add_api_route("/sse", mcp_server_events, methods=["GET"])
|
||||
|
||||
|
||||
# MCP JSON-RPC 端点(简化实现)
|
||||
async def mcp_rpc(request: Request, _: None = Depends(verify_api_key)):
|
||||
"""处理 MCP JSON-RPC 请求"""
|
||||
body = await request.json()
|
||||
|
||||
method = body.get("method")
|
||||
params = body.get("params", {})
|
||||
msg_id = body.get("id")
|
||||
|
||||
try:
|
||||
if method == "tools/list":
|
||||
tools = await list_tools()
|
||||
result = {
|
||||
"tools": [
|
||||
{
|
||||
"name": t.name,
|
||||
"description": t.description,
|
||||
"inputSchema": t.inputSchema,
|
||||
}
|
||||
for t in tools
|
||||
]
|
||||
}
|
||||
elif method == "tools/call":
|
||||
tool_name = params.get("name")
|
||||
tool_args = params.get("arguments", {})
|
||||
result_content = await call_tool_tool(tool_name, tool_args)
|
||||
result = {"content": [c.model_dump() for c in result_content]}
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={"jsonrpc": "2.0", "error": {"code": -32601, "message": f"Method not found: {method}"}, "id": msg_id}
|
||||
)
|
||||
|
||||
return {"jsonrpc": "2.0", "result": result, "id": msg_id}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"MCP RPC 错误: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"jsonrpc": "2.0", "error": {"code": -32603, "message": str(e)}, "id": msg_id}
|
||||
)
|
||||
|
||||
|
||||
async def call_tool_tool(name: str, arguments: dict) -> list[TextContent]:
|
||||
"""调用工具的内部函数"""
|
||||
return await call_tool(name, arguments)
|
||||
|
||||
|
||||
mcp_router.add_api_route("/rpc", mcp_rpc, methods=["POST"])
|
||||
|
||||
|
||||
# 注册 MCP 路由
|
||||
app.include_router(mcp_router, prefix="/mcp", tags=["mcp"])
|
||||
|
||||
|
||||
@app.post("/api/search", dependencies=[Depends(verify_api_key)])
|
||||
async def api_search(request: SearchRequest):
|
||||
"""REST API: 搜索"""
|
||||
ov_client = await get_openviking_client()
|
||||
result = await ov_client.search(
|
||||
query=request.query,
|
||||
namespace=request.namespace or get_config().memory.default_namespace,
|
||||
limit=request.limit or get_config().memory.search_limit,
|
||||
uri=request.uri,
|
||||
)
|
||||
return {"results": result.results, "total": result.total}
|
||||
|
||||
|
||||
@app.post("/api/memory", dependencies=[Depends(verify_api_key)])
|
||||
async def api_add_memory(request: AddMemoryRequest):
|
||||
"""REST API: 添加记忆"""
|
||||
ov_client = await get_openviking_client()
|
||||
result = await ov_client.add_memory(
|
||||
content=request.content,
|
||||
namespace=request.namespace or get_config().memory.default_namespace,
|
||||
memory_type=request.memory_type,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@app.post("/api/resource", dependencies=[Depends(verify_api_key)])
|
||||
async def api_add_resource(request: AddResourceRequest):
|
||||
"""REST API: 添加资源"""
|
||||
ov_client = await get_openviking_client()
|
||||
result = await ov_client.add_resource(
|
||||
uri=request.uri,
|
||||
content=request.content,
|
||||
resource_type=request.resource_type,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def create_app(config: Optional[Config] = None) -> FastAPI:
|
||||
"""创建 FastAPI 应用"""
|
||||
if config:
|
||||
set_config(config)
|
||||
return app
|
||||
|
||||
|
||||
# 入口点
|
||||
def main():
|
||||
"""主入口"""
|
||||
import argparse
|
||||
import uvicorn
|
||||
|
||||
parser = argparse.ArgumentParser(description="Memory Gateway MCP Server")
|
||||
parser.add_argument("--config", default="config.yaml", help="配置文件路径")
|
||||
parser.add_argument("--host", default=None, help="监听地址")
|
||||
parser.add_argument("--port", type=int, default=None, help="监听端口")
|
||||
args = parser.parse_args()
|
||||
|
||||
# 加载配置
|
||||
from .config import load_config as load
|
||||
config = load(args.config)
|
||||
if args.host:
|
||||
config.server.host = args.host
|
||||
if args.port:
|
||||
config.server.port = args.port
|
||||
set_config(config)
|
||||
|
||||
# 启动服务
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=config.server.host,
|
||||
port=config.server.port,
|
||||
log_level=config.logging.level.lower(),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user