Add generic memory gateway v1
This commit is contained in:
@ -24,6 +24,16 @@ from .config import get_config, set_config, Config
|
||||
from .openviking_client import get_openviking_client, close_openviking_client
|
||||
from .document_ingest import convert_file_to_markdown, save_markdown_to_obsidian, slugify
|
||||
from .llm import LLMConfigurationError, LLMSummaryError, summarize_with_llm
|
||||
from .mcp_tools_v1 import MEMORY_GATEWAY_MCP_TOOLS
|
||||
from .schemas import (
|
||||
AccessContext,
|
||||
CommitSessionRequest,
|
||||
EpisodeAppendRequest,
|
||||
MemoryFeedbackRequest,
|
||||
MemorySearchRequest,
|
||||
MemoryUpsertRequest,
|
||||
)
|
||||
from .services import service as v1_service
|
||||
from .types import SearchRequest, AddMemoryRequest, AddResourceRequest, CommitSummaryRequest
|
||||
|
||||
# 配置日志
|
||||
@ -41,7 +51,7 @@ mcp_server = Server("memory-gateway")
|
||||
@mcp_server.list_tools()
|
||||
async def list_tools() -> list[Tool]:
|
||||
"""列出可用的 MCP 工具"""
|
||||
return [
|
||||
legacy_tools = [
|
||||
Tool(
|
||||
name="search",
|
||||
description="语义搜索记忆和资源",
|
||||
@ -135,12 +145,25 @@ async def list_tools() -> list[Tool]:
|
||||
},
|
||||
),
|
||||
]
|
||||
v1_tools = [
|
||||
Tool(
|
||||
name=definition["name"],
|
||||
description=definition["description"],
|
||||
inputSchema=definition["inputSchema"],
|
||||
)
|
||||
for definition in MEMORY_GATEWAY_MCP_TOOLS
|
||||
]
|
||||
return legacy_tools + v1_tools
|
||||
|
||||
|
||||
@mcp_server.call_tool()
|
||||
async def call_tool(name: str, arguments: Any) -> list[TextContent]:
|
||||
"""调用 MCP 工具"""
|
||||
try:
|
||||
if name.startswith("memory_"):
|
||||
result = await call_v1_memory_tool(name, arguments or {})
|
||||
return [TextContent(type="text", text=json.dumps(result, ensure_ascii=False, default=str))]
|
||||
|
||||
ov_client = await get_openviking_client()
|
||||
|
||||
if name == "search":
|
||||
@ -200,6 +223,60 @@ async def call_tool(name: str, arguments: Any) -> list[TextContent]:
|
||||
return [TextContent(type="text", text=f"Error: {str(e)}")]
|
||||
|
||||
|
||||
async def call_v1_memory_tool(name: str, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Dispatch v1 Memory Gateway MCP tools to the same service used by /v1."""
|
||||
if name == "memory_search":
|
||||
return _jsonable(await v1_service.search_memory_with_openviking(MemorySearchRequest(**arguments)))
|
||||
if name == "memory_upsert":
|
||||
return v1_service.upsert_memory(MemoryUpsertRequest(**arguments)).model_dump(mode="json")
|
||||
if name == "memory_append_episode":
|
||||
return v1_service.append_episode(EpisodeAppendRequest(**arguments)).model_dump(mode="json")
|
||||
if name == "memory_commit_session":
|
||||
session_id = arguments.get("session_id")
|
||||
if not session_id:
|
||||
raise ValueError("session_id is required")
|
||||
return _jsonable(v1_service.commit_session(session_id, CommitSessionRequest(**arguments)))
|
||||
if name == "memory_get_profile":
|
||||
return v1_service.get_profile(arguments["user_id"]).model_dump(mode="json")
|
||||
if name == "memory_list_namespaces":
|
||||
return {
|
||||
"namespaces": [
|
||||
item.model_dump(mode="json")
|
||||
for item in v1_service.list_namespaces(
|
||||
AccessContext(
|
||||
user_id=arguments["user_id"],
|
||||
agent_id=arguments.get("agent_id"),
|
||||
workspace_id=arguments.get("workspace_id"),
|
||||
session_id=arguments.get("session_id"),
|
||||
)
|
||||
)
|
||||
]
|
||||
}
|
||||
if name == "memory_delete":
|
||||
return v1_service.delete_memory(
|
||||
arguments["memory_id"],
|
||||
AccessContext(
|
||||
user_id=arguments["user_id"],
|
||||
agent_id=arguments.get("agent_id"),
|
||||
workspace_id=arguments.get("workspace_id"),
|
||||
session_id=arguments.get("session_id"),
|
||||
),
|
||||
)
|
||||
if name == "memory_feedback":
|
||||
return v1_service.add_feedback(arguments["memory_id"], MemoryFeedbackRequest(**arguments))
|
||||
raise ValueError(f"Unknown v1 memory tool: {name}")
|
||||
|
||||
|
||||
def _jsonable(value: Any) -> Any:
|
||||
if hasattr(value, "model_dump"):
|
||||
return value.model_dump(mode="json")
|
||||
if isinstance(value, list):
|
||||
return [_jsonable(item) for item in value]
|
||||
if isinstance(value, dict):
|
||||
return {key: _jsonable(item) for key, item in value.items()}
|
||||
return value
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理"""
|
||||
@ -401,10 +478,12 @@ async def health_check():
|
||||
try:
|
||||
ov_client = await get_openviking_client()
|
||||
ov_status = await ov_client.health_check()
|
||||
evermemos_status = v1_service.evermemos_health()
|
||||
return {
|
||||
"status": "ok",
|
||||
"gateway": "memory-gateway",
|
||||
"openviking": ov_status,
|
||||
"evermemos": evermemos_status,
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
@ -490,6 +569,12 @@ mcp_router.add_api_route("/rpc", mcp_rpc, methods=["POST"])
|
||||
# 注册 MCP 路由
|
||||
app.include_router(mcp_router, prefix="/mcp", tags=["mcp"])
|
||||
|
||||
# Generic Memory Gateway v1 routes are imported lazily here to avoid changing
|
||||
# the existing legacy /api and /mcp startup path.
|
||||
from .api_v1 import router as api_v1_router # noqa: E402
|
||||
|
||||
app.include_router(api_v1_router)
|
||||
|
||||
|
||||
@app.post("/api/search", dependencies=[Depends(verify_api_key)])
|
||||
async def api_search(request: SearchRequest):
|
||||
|
||||
Reference in New Issue
Block a user