Initial SOC memory POC implementation
This commit is contained in:
1
memory_gateway/__init__.py
Normal file
1
memory_gateway/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""Memory Gateway 核心模块"""
|
||||
55
memory_gateway/config.py
Normal file
55
memory_gateway/config.py
Normal file
@ -0,0 +1,55 @@
|
||||
"""配置加载模块"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import yaml
|
||||
from pydantic import ValidationError
|
||||
|
||||
from .types import Config, ServerConfig, OpenVikingConfig, MemoryConfig, LoggingConfig
|
||||
|
||||
|
||||
def load_config(config_path: Optional[str] = None) -> Config:
|
||||
"""加载配置文件"""
|
||||
if config_path is None:
|
||||
config_path = os.environ.get("MEMORY_GATEWAY_CONFIG", "config.yaml")
|
||||
|
||||
config_file = Path(config_path)
|
||||
|
||||
if not config_file.exists():
|
||||
# 返回默认配置
|
||||
return Config()
|
||||
|
||||
try:
|
||||
with open(config_file, "r", encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
|
||||
if data is None:
|
||||
return Config()
|
||||
|
||||
return Config(
|
||||
server=ServerConfig(**data.get("server", {})),
|
||||
openviking=OpenVikingConfig(**data.get("openviking", {})),
|
||||
memory=MemoryConfig(**data.get("memory", {})),
|
||||
logging=LoggingConfig(**data.get("logging", {})),
|
||||
)
|
||||
except (ValidationError, yaml.YAMLError) as e:
|
||||
print(f"配置文件解析错误: {e}")
|
||||
return Config()
|
||||
|
||||
|
||||
def get_config() -> Config:
|
||||
"""获取全局配置(单例)"""
|
||||
global _config
|
||||
if _config is None:
|
||||
_config = load_config()
|
||||
return _config
|
||||
|
||||
|
||||
def set_config(config: Config) -> None:
|
||||
"""设置全局配置"""
|
||||
global _config
|
||||
_config = config
|
||||
|
||||
|
||||
_config: Optional[Config] = None
|
||||
302
memory_gateway/openviking_client.py
Normal file
302
memory_gateway/openviking_client.py
Normal file
@ -0,0 +1,302 @@
|
||||
"""OpenViking client wrapper used by the SOC Memory POC."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from .config import get_config
|
||||
from .types import MemoryEntry, ResourceEntry, SearchResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenVikingClient:
|
||||
"""Thin async client for the OpenViking HTTP API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
timeout: int = 30,
|
||||
account: str = "default",
|
||||
user: str = "default",
|
||||
):
|
||||
self.config = get_config()
|
||||
self.base_url = base_url or self.config.openviking.url
|
||||
self.api_key = api_key or self.config.openviking.api_key or "your-secret-root-key"
|
||||
self.timeout = timeout
|
||||
self.account = account
|
||||
self.user = user
|
||||
self._client: Optional[httpx.AsyncClient] = None
|
||||
|
||||
def _get_headers(self) -> dict[str, str]:
|
||||
headers = {}
|
||||
if self.api_key:
|
||||
headers["X-API-Key"] = self.api_key
|
||||
headers["X-OpenViking-Account"] = self.account
|
||||
headers["X-OpenViking-User"] = self.user
|
||||
return headers
|
||||
|
||||
async def _get_client(self) -> httpx.AsyncClient:
|
||||
if self._client is None:
|
||||
self._client = httpx.AsyncClient(
|
||||
base_url=self.base_url,
|
||||
headers=self._get_headers(),
|
||||
timeout=self.timeout,
|
||||
)
|
||||
return self._client
|
||||
|
||||
async def close(self):
|
||||
if self._client:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
async def health_check(self) -> dict[str, Any]:
|
||||
client = await self._get_client()
|
||||
try:
|
||||
response = await client.get("/health")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"OpenViking 健康检查失败: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
namespace: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
uri: Optional[str] = None,
|
||||
) -> SearchResult:
|
||||
"""Semantic search against OpenViking resources/memories."""
|
||||
client = await self._get_client()
|
||||
|
||||
payload: dict[str, Any] = {"query": query}
|
||||
if limit:
|
||||
payload["limit"] = limit
|
||||
|
||||
if uri:
|
||||
payload["uri"] = uri
|
||||
elif namespace:
|
||||
payload["uri"] = f"viking://{namespace}"
|
||||
|
||||
try:
|
||||
response = await client.post("/api/v1/search/search", json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if data.get("status") != "ok":
|
||||
logger.warning(f"搜索返回错误: {data.get('error')}")
|
||||
return SearchResult(results=[], total=0)
|
||||
|
||||
result = data.get("result", {})
|
||||
memories = result.get("memories", [])
|
||||
resources = result.get("resources", [])
|
||||
|
||||
all_results = []
|
||||
for m in memories + resources:
|
||||
all_results.append(
|
||||
{
|
||||
"uri": m.get("uri"),
|
||||
"abstract": m.get("abstract"),
|
||||
"score": m.get("score"),
|
||||
"context_type": m.get("context_type"),
|
||||
}
|
||||
)
|
||||
|
||||
return SearchResult(results=all_results, total=result.get("total", len(all_results)))
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"搜索失败: {e}")
|
||||
return SearchResult(results=[], total=0)
|
||||
|
||||
async def add_memory(
|
||||
self,
|
||||
content: str,
|
||||
namespace: Optional[str] = None,
|
||||
memory_type: str = "general",
|
||||
) -> dict[str, Any]:
|
||||
"""Add memory via session commit flow."""
|
||||
client = await self._get_client()
|
||||
ns = namespace or self.config.memory.default_namespace or "user/default/memories"
|
||||
|
||||
try:
|
||||
response = await client.post("/api/v1/sessions", json={"mode": "interactive"})
|
||||
response.raise_for_status()
|
||||
session_data = response.json()
|
||||
|
||||
if session_data.get("status") != "ok":
|
||||
return session_data
|
||||
|
||||
session_id = session_data["result"]["session_id"]
|
||||
commit_response = await client.post(
|
||||
f"/api/v1/sessions/{session_id}/commit",
|
||||
json={
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"[{ns}/{memory_type}] {content}",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
commit_response.raise_for_status()
|
||||
return commit_response.json()
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"添加记忆失败: {e}")
|
||||
raise
|
||||
|
||||
async def _upload_temp_file(self, file_path: str | Path) -> str:
|
||||
client = await self._get_client()
|
||||
file_path = Path(file_path)
|
||||
mime_type = mimetypes.guess_type(file_path.name)[0] or "application/octet-stream"
|
||||
|
||||
with file_path.open("rb") as f:
|
||||
response = await client.post(
|
||||
"/api/v1/resources/temp_upload",
|
||||
files={"file": (file_path.name, f, mime_type)},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
result = data.get("result", {})
|
||||
if "temp_path" in result:
|
||||
return result["temp_path"]
|
||||
if "temp_file_id" in result:
|
||||
return result["temp_file_id"]
|
||||
raise KeyError(f"Unexpected temp upload response: {data}")
|
||||
|
||||
async def add_resource(
|
||||
self,
|
||||
uri: str,
|
||||
content: str,
|
||||
resource_type: str = "text",
|
||||
wait: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Add a text/json resource by uploading a temporary file first.
|
||||
|
||||
OpenViking HTTP API does not accept raw `uri + content` directly. The
|
||||
client must upload a temp file and then create the resource with `to`.
|
||||
"""
|
||||
client = await self._get_client()
|
||||
suffix_map = {
|
||||
"json": ".json",
|
||||
"text": ".txt",
|
||||
"markdown": ".md",
|
||||
"md": ".md",
|
||||
}
|
||||
suffix = suffix_map.get(resource_type, ".txt")
|
||||
|
||||
with tempfile.NamedTemporaryFile("w", encoding="utf-8", suffix=suffix, delete=False) as tmp:
|
||||
tmp.write(content)
|
||||
tmp_path = Path(tmp.name)
|
||||
|
||||
try:
|
||||
temp_ref = await self._upload_temp_file(tmp_path)
|
||||
payload = {
|
||||
"temp_path": temp_ref,
|
||||
"to": uri,
|
||||
"wait": wait,
|
||||
"source_name": Path(uri).name or tmp_path.name,
|
||||
"strict": False,
|
||||
}
|
||||
response = await client.post("/api/v1/resources", json=payload)
|
||||
if response.status_code >= 400:
|
||||
logger.error("添加资源失败响应: %s", response.text)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"添加资源失败: {e}")
|
||||
raise
|
||||
finally:
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
|
||||
async def list_memories(
|
||||
self,
|
||||
namespace: Optional[str] = None,
|
||||
memory_type: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> list[MemoryEntry]:
|
||||
client = await self._get_client()
|
||||
|
||||
ns = namespace or "user/default/memories"
|
||||
if memory_type:
|
||||
ns = f"{ns}/{memory_type}"
|
||||
|
||||
try:
|
||||
response = await client.post(
|
||||
"/api/v1/search/search",
|
||||
json={"query": "", "uri": f"viking://{ns}", "limit": limit or 10},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if data.get("status") == "ok":
|
||||
result = data.get("result", {})
|
||||
memories = result.get("memories", [])
|
||||
return [
|
||||
MemoryEntry(
|
||||
id=m.get("uri", ""),
|
||||
content=m.get("abstract", ""),
|
||||
namespace=ns,
|
||||
memory_type=memory_type or "general",
|
||||
)
|
||||
for m in memories
|
||||
]
|
||||
return []
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"列出记忆失败: {e}")
|
||||
return []
|
||||
|
||||
async def list_resources(
|
||||
self,
|
||||
namespace: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> list[ResourceEntry]:
|
||||
client = await self._get_client()
|
||||
|
||||
uri = f"viking://{namespace}" if namespace else "viking://resources"
|
||||
try:
|
||||
response = await client.post(
|
||||
"/api/v1/search/search",
|
||||
json={"query": "", "uri": uri, "limit": limit or 10},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if data.get("status") == "ok":
|
||||
result = data.get("result", {})
|
||||
resources = result.get("resources", [])
|
||||
return [
|
||||
ResourceEntry(
|
||||
uri=r.get("uri", ""),
|
||||
content=r.get("abstract", ""),
|
||||
resource_type="text",
|
||||
)
|
||||
for r in resources
|
||||
]
|
||||
return []
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"列出资源失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
_client: Optional[OpenVikingClient] = None
|
||||
|
||||
|
||||
async def get_openviking_client() -> OpenVikingClient:
|
||||
global _client
|
||||
if _client is None:
|
||||
_client = OpenVikingClient()
|
||||
return _client
|
||||
|
||||
|
||||
async def close_openviking_client():
|
||||
global _client
|
||||
if _client:
|
||||
await _client.close()
|
||||
_client = None
|
||||
387
memory_gateway/server.py
Normal file
387
memory_gateway/server.py
Normal file
@ -0,0 +1,387 @@
|
||||
"""Memory Gateway MCP Server.
|
||||
|
||||
基于 Model Context Protocol 的记忆网关服务,为局域网内的 AI Agent 提供统一的 OpenViking 访问入口。
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, FastAPI, Header, HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from mcp.server import Server
|
||||
from mcp.types import TextContent, Tool
|
||||
from sse_starlette import EventSourceResponse
|
||||
|
||||
from .config import get_config, set_config, Config
|
||||
from .openviking_client import get_openviking_client, close_openviking_client
|
||||
from .types import SearchRequest, AddMemoryRequest, AddResourceRequest
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# 创建 MCP Server
|
||||
mcp_server = Server("memory-gateway")
|
||||
|
||||
|
||||
@mcp_server.list_tools()
|
||||
async def list_tools() -> list[Tool]:
|
||||
"""列出可用的 MCP 工具"""
|
||||
return [
|
||||
Tool(
|
||||
name="search",
|
||||
description="语义搜索记忆和资源",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "搜索查询"},
|
||||
"namespace": {"type": "string", "description": "命名空间(可选)"},
|
||||
"limit": {"type": "integer", "description": "返回结果数量(默认10)"},
|
||||
"uri": {"type": "string", "description": "资源 URI(可选)"},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
),
|
||||
Tool(
|
||||
name="add_memory",
|
||||
description="添加新记忆",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {"type": "string", "description": "记忆内容"},
|
||||
"namespace": {"type": "string", "description": "命名空间(可选)"},
|
||||
"memory_type": {"type": "string", "description": "记忆类型(默认general)"},
|
||||
},
|
||||
"required": ["content"],
|
||||
},
|
||||
),
|
||||
Tool(
|
||||
name="add_resource",
|
||||
description="添加资源",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"uri": {"type": "string", "description": "资源 URI"},
|
||||
"content": {"type": "string", "description": "资源内容"},
|
||||
"resource_type": {"type": "string", "description": "资源类型(默认text)"},
|
||||
},
|
||||
"required": ["uri", "content"],
|
||||
},
|
||||
),
|
||||
Tool(
|
||||
name="get_status",
|
||||
description="检查系统状态",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
},
|
||||
),
|
||||
Tool(
|
||||
name="list_memories",
|
||||
description="列出已存储的记忆",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"namespace": {"type": "string", "description": "命名空间(可选)"},
|
||||
"memory_type": {"type": "string", "description": "记忆类型(可选)"},
|
||||
"limit": {"type": "integer", "description": "返回数量(默认10)"},
|
||||
},
|
||||
},
|
||||
),
|
||||
Tool(
|
||||
name="list_resources",
|
||||
description="列出已存储的资源",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"namespace": {"type": "string", "description": "命名空间(可选)"},
|
||||
"limit": {"type": "integer", "description": "返回数量(默认10)"},
|
||||
},
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@mcp_server.call_tool()
|
||||
async def call_tool(name: str, arguments: Any) -> list[TextContent]:
|
||||
"""调用 MCP 工具"""
|
||||
try:
|
||||
ov_client = await get_openviking_client()
|
||||
|
||||
if name == "search":
|
||||
result = await ov_client.search(
|
||||
query=arguments.get("query"),
|
||||
namespace=arguments.get("namespace"),
|
||||
limit=arguments.get("limit"),
|
||||
uri=arguments.get("uri"),
|
||||
)
|
||||
return [TextContent(type="text", text=str(result.results))]
|
||||
|
||||
elif name == "add_memory":
|
||||
result = await ov_client.add_memory(
|
||||
content=arguments.get("content"),
|
||||
namespace=arguments.get("namespace"),
|
||||
memory_type=arguments.get("memory_type", "general"),
|
||||
)
|
||||
return [TextContent(type="text", text=str(result))]
|
||||
|
||||
elif name == "add_resource":
|
||||
result = await ov_client.add_resource(
|
||||
uri=arguments.get("uri"),
|
||||
content=arguments.get("content"),
|
||||
resource_type=arguments.get("resource_type", "text"),
|
||||
)
|
||||
return [TextContent(type="text", text=str(result))]
|
||||
|
||||
elif name == "get_status":
|
||||
ov_status = await ov_client.health_check()
|
||||
return [TextContent(type="text", text=f"Memory Gateway: OK\nOpenViking: {ov_status}")]
|
||||
|
||||
elif name == "list_memories":
|
||||
memories = await ov_client.list_memories(
|
||||
namespace=arguments.get("namespace"),
|
||||
memory_type=arguments.get("memory_type"),
|
||||
limit=arguments.get("limit"),
|
||||
)
|
||||
return [TextContent(type="text", text=str([m.model_dump() for m in memories]))]
|
||||
|
||||
elif name == "list_resources":
|
||||
resources = await ov_client.list_resources(
|
||||
namespace=arguments.get("namespace"),
|
||||
limit=arguments.get("limit"),
|
||||
)
|
||||
return [TextContent(type="text", text=str([r.model_dump() for r in resources]))]
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown tool: {name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具执行失败: {e}")
|
||||
return [TextContent(type="text", text=f"Error: {str(e)}")]
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理"""
|
||||
logger.info("Memory Gateway 启动中...")
|
||||
config = get_config()
|
||||
logger.info(f"配置加载完成: {config.server.host}:{config.server.port}")
|
||||
logger.info(f"OpenViking 后端: {config.openviking.url}")
|
||||
|
||||
# 测试 OpenViking 连接
|
||||
try:
|
||||
ov_client = await get_openviking_client()
|
||||
status = await ov_client.health_check()
|
||||
logger.info(f"OpenViking 连接状态: {status}")
|
||||
except Exception as e:
|
||||
logger.warning(f"OpenViking 连接失败: {e}")
|
||||
|
||||
yield
|
||||
|
||||
logger.info("Memory Gateway 关闭中...")
|
||||
await close_openviking_client()
|
||||
|
||||
|
||||
def verify_api_key(x_api_key: Optional[str] = Header(default=None)) -> None:
|
||||
"""在配置了 API Key 时校验请求头。"""
|
||||
expected_key = get_config().server.api_key
|
||||
if not expected_key:
|
||||
return
|
||||
if x_api_key != expected_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or missing API key",
|
||||
)
|
||||
|
||||
|
||||
# FastAPI 应用
|
||||
app = FastAPI(title="Memory Gateway", version="0.1.0", lifespan=lifespan)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
@app.get("/health", dependencies=[Depends(verify_api_key)])
|
||||
async def health_check():
|
||||
"""健康检查"""
|
||||
try:
|
||||
ov_client = await get_openviking_client()
|
||||
ov_status = await ov_client.health_check()
|
||||
return {
|
||||
"status": "ok",
|
||||
"gateway": "memory-gateway",
|
||||
"openviking": ov_status,
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "degraded",
|
||||
"gateway": "memory-gateway",
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
mcp_router = APIRouter()
|
||||
|
||||
|
||||
async def mcp_server_events(request: Request, _: None = Depends(verify_api_key)):
|
||||
"""MCP Server-Sent Events 端点 - 使用 stdio 模式模拟"""
|
||||
async def event_generator():
|
||||
# 发送初始化消息
|
||||
yield {"event": "initialize", "data": json.dumps({"protocolVersion": "2024-11-05"})}
|
||||
|
||||
# 保持连接
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(30)
|
||||
yield {"event": "ping", "data": ""}
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
return EventSourceResponse(event_generator())
|
||||
|
||||
|
||||
mcp_router.add_api_route("/sse", mcp_server_events, methods=["GET"])
|
||||
|
||||
|
||||
# MCP JSON-RPC 端点(简化实现)
|
||||
async def mcp_rpc(request: Request, _: None = Depends(verify_api_key)):
|
||||
"""处理 MCP JSON-RPC 请求"""
|
||||
body = await request.json()
|
||||
|
||||
method = body.get("method")
|
||||
params = body.get("params", {})
|
||||
msg_id = body.get("id")
|
||||
|
||||
try:
|
||||
if method == "tools/list":
|
||||
tools = await list_tools()
|
||||
result = {
|
||||
"tools": [
|
||||
{
|
||||
"name": t.name,
|
||||
"description": t.description,
|
||||
"inputSchema": t.inputSchema,
|
||||
}
|
||||
for t in tools
|
||||
]
|
||||
}
|
||||
elif method == "tools/call":
|
||||
tool_name = params.get("name")
|
||||
tool_args = params.get("arguments", {})
|
||||
result_content = await call_tool_tool(tool_name, tool_args)
|
||||
result = {"content": [c.model_dump() for c in result_content]}
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={"jsonrpc": "2.0", "error": {"code": -32601, "message": f"Method not found: {method}"}, "id": msg_id}
|
||||
)
|
||||
|
||||
return {"jsonrpc": "2.0", "result": result, "id": msg_id}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"MCP RPC 错误: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"jsonrpc": "2.0", "error": {"code": -32603, "message": str(e)}, "id": msg_id}
|
||||
)
|
||||
|
||||
|
||||
async def call_tool_tool(name: str, arguments: dict) -> list[TextContent]:
|
||||
"""调用工具的内部函数"""
|
||||
return await call_tool(name, arguments)
|
||||
|
||||
|
||||
mcp_router.add_api_route("/rpc", mcp_rpc, methods=["POST"])
|
||||
|
||||
|
||||
# 注册 MCP 路由
|
||||
app.include_router(mcp_router, prefix="/mcp", tags=["mcp"])
|
||||
|
||||
|
||||
@app.post("/api/search", dependencies=[Depends(verify_api_key)])
|
||||
async def api_search(request: SearchRequest):
|
||||
"""REST API: 搜索"""
|
||||
ov_client = await get_openviking_client()
|
||||
result = await ov_client.search(
|
||||
query=request.query,
|
||||
namespace=request.namespace or get_config().memory.default_namespace,
|
||||
limit=request.limit or get_config().memory.search_limit,
|
||||
uri=request.uri,
|
||||
)
|
||||
return {"results": result.results, "total": result.total}
|
||||
|
||||
|
||||
@app.post("/api/memory", dependencies=[Depends(verify_api_key)])
|
||||
async def api_add_memory(request: AddMemoryRequest):
|
||||
"""REST API: 添加记忆"""
|
||||
ov_client = await get_openviking_client()
|
||||
result = await ov_client.add_memory(
|
||||
content=request.content,
|
||||
namespace=request.namespace or get_config().memory.default_namespace,
|
||||
memory_type=request.memory_type,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@app.post("/api/resource", dependencies=[Depends(verify_api_key)])
|
||||
async def api_add_resource(request: AddResourceRequest):
|
||||
"""REST API: 添加资源"""
|
||||
ov_client = await get_openviking_client()
|
||||
result = await ov_client.add_resource(
|
||||
uri=request.uri,
|
||||
content=request.content,
|
||||
resource_type=request.resource_type,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def create_app(config: Optional[Config] = None) -> FastAPI:
|
||||
"""创建 FastAPI 应用"""
|
||||
if config:
|
||||
set_config(config)
|
||||
return app
|
||||
|
||||
|
||||
# 入口点
|
||||
def main():
|
||||
"""主入口"""
|
||||
import argparse
|
||||
import uvicorn
|
||||
|
||||
parser = argparse.ArgumentParser(description="Memory Gateway MCP Server")
|
||||
parser.add_argument("--config", default="config.yaml", help="配置文件路径")
|
||||
parser.add_argument("--host", default=None, help="监听地址")
|
||||
parser.add_argument("--port", type=int, default=None, help="监听端口")
|
||||
args = parser.parse_args()
|
||||
|
||||
# 加载配置
|
||||
from .config import load_config as load
|
||||
config = load(args.config)
|
||||
if args.host:
|
||||
config.server.host = args.host
|
||||
if args.port:
|
||||
config.server.port = args.port
|
||||
set_config(config)
|
||||
|
||||
# 启动服务
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=config.server.host,
|
||||
port=config.server.port,
|
||||
log_level=config.logging.level.lower(),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
82
memory_gateway/types.py
Normal file
82
memory_gateway/types.py
Normal file
@ -0,0 +1,82 @@
|
||||
"""类型定义"""
|
||||
from typing import Optional, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ServerConfig(BaseModel):
|
||||
"""服务器配置"""
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 1934
|
||||
api_key: str = ""
|
||||
|
||||
|
||||
class OpenVikingConfig(BaseModel):
|
||||
"""OpenViking 后端配置"""
|
||||
url: str = "http://localhost:1933"
|
||||
api_key: str = ""
|
||||
timeout: int = 30
|
||||
|
||||
|
||||
class MemoryConfig(BaseModel):
|
||||
"""记忆配置"""
|
||||
default_namespace: str = "soc"
|
||||
search_limit: int = 10
|
||||
|
||||
|
||||
class LoggingConfig(BaseModel):
|
||||
"""日志配置"""
|
||||
level: str = "INFO"
|
||||
format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
|
||||
|
||||
class Config(BaseModel):
|
||||
"""完整配置"""
|
||||
server: ServerConfig = Field(default_factory=ServerConfig)
|
||||
openviking: OpenVikingConfig = Field(default_factory=OpenVikingConfig)
|
||||
memory: MemoryConfig = Field(default_factory=MemoryConfig)
|
||||
logging: LoggingConfig = Field(default_factory=LoggingConfig)
|
||||
|
||||
|
||||
class SearchRequest(BaseModel):
|
||||
"""搜索请求"""
|
||||
query: str
|
||||
namespace: Optional[str] = None
|
||||
limit: Optional[int] = None
|
||||
uri: Optional[str] = None
|
||||
|
||||
|
||||
class AddMemoryRequest(BaseModel):
|
||||
"""添加记忆请求"""
|
||||
content: str
|
||||
namespace: Optional[str] = None
|
||||
memory_type: Optional[str] = "general"
|
||||
|
||||
|
||||
class AddResourceRequest(BaseModel):
|
||||
"""添加资源请求"""
|
||||
uri: str
|
||||
content: str
|
||||
resource_type: Optional[str] = "text"
|
||||
|
||||
|
||||
class SearchResult(BaseModel):
|
||||
"""搜索结果"""
|
||||
results: list[dict[str, Any]]
|
||||
total: int
|
||||
|
||||
|
||||
class MemoryEntry(BaseModel):
|
||||
"""记忆条目"""
|
||||
id: str
|
||||
content: str
|
||||
namespace: str
|
||||
memory_type: str
|
||||
created_at: Optional[str] = None
|
||||
|
||||
|
||||
class ResourceEntry(BaseModel):
|
||||
"""资源条目"""
|
||||
uri: str
|
||||
content: str
|
||||
resource_type: str
|
||||
created_at: Optional[str] = None
|
||||
Reference in New Issue
Block a user