151 lines
4.4 KiB
Python
151 lines
4.4 KiB
Python
"""HTTP client for Memory Gateway's `/memory-system` API."""
|
|
|
|
from __future__ import annotations

from typing import Any
from urllib.parse import quote

import httpx

from .store import MemoryGatewayUserStore
|
|
|
|
|
|
class MemoryGatewayClient:
    """Small async client for the Memory Gateway business API.

    Each request opens a short-lived ``httpx.AsyncClient`` configured with the
    gateway base URL, the optional ``X-API-Key`` header, the timeout and the
    optional transport. User keys returned by the gateway are cached through
    ``store`` so ``ensure_user`` can skip the network for users already known.
    """

    def __init__(
        self,
        *,
        base_url: str,
        store: MemoryGatewayUserStore,
        api_key: str = "",
        timeout_seconds: float = 15.0,
        transport: httpx.AsyncBaseTransport | None = None,
    ) -> None:
        """Store the connection settings; no network I/O happens here.

        Args:
            base_url: Gateway root URL; trailing slashes are stripped.
            store: Object providing ``get_user_key`` / ``save_user_key``.
            api_key: Sent as ``X-API-Key`` when non-empty.
            timeout_seconds: Timeout handed to every ``httpx.AsyncClient``.
            transport: Optional httpx transport handed to every client.
        """
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.timeout_seconds = timeout_seconds
        self.store = store
        self.transport = transport

    async def ensure_user(self, user_id: str) -> str:
        """Return the gateway user key for *user_id*, creating the user if needed.

        A truthy key already present in the store is returned without any
        request. Otherwise the user is created via ``POST /memory-system/users``
        and the returned key is saved to the store before being returned.

        Raises:
            RuntimeError: If the creation response carries no ``user_key``.
        """
        cached = self.store.get_user_key(user_id)
        if cached:
            return cached

        data = await self._post("/memory-system/users", {"user_id": user_id})
        user_key = self._extract_user_key(data)
        if not user_key:
            raise RuntimeError("Memory Gateway user creation response missing user_key")
        self.store.save_user_key(user_id, user_key)
        return user_key

    async def get_profile(
        self,
        *,
        user_id: str,
        user_key: str,
        query: str = "用户画像",
        limit: int = 5,
    ) -> dict[str, Any]:
        """GET the profile of *user_id*.

        ``user_key``, ``query`` and ``limit`` are sent as query parameters.
        The default ``query`` is Chinese text meaning roughly "user profile".
        """
        return await self._get(
            f"/memory-system/users/{self._path_segment(user_id)}/profile",
            {"user_key": user_key, "query": query, "limit": limit},
        )

    async def get_session_context(
        self,
        *,
        user_id: str,
        user_key: str,
        session_id: str,
        query: str,
        limit: int = 5,
    ) -> dict[str, Any]:
        """POST a context request for *session_id* and return the decoded JSON."""
        return await self._post(
            f"/memory-system/sessions/{self._path_segment(session_id)}/context",
            {"user_id": user_id, "user_key": user_key, "query": query, "limit": limit},
        )

    async def search(
        self,
        *,
        user_id: str,
        user_key: str,
        session_id: str,
        query: str,
        limit: int = 5,
    ) -> dict[str, Any]:
        """POST a search request; ``use_llm`` is always sent as ``False``."""
        return await self._post(
            "/memory-system/search",
            {
                "user_id": user_id,
                "user_key": user_key,
                "session_id": session_id,
                "query": query,
                "use_llm": False,
                "limit": limit,
            },
        )

    async def ingest_messages(
        self,
        *,
        user_id: str,
        user_key: str,
        session_id: str,
        user_message: str | None,
        assistant_message: str | None,
    ) -> dict[str, Any]:
        """POST one user/assistant message pair to ``/memory-system/messages``.

        Either message may be ``None``; it is forwarded as JSON ``null``.
        """
        return await self._post(
            "/memory-system/messages",
            {
                "user_id": user_id,
                "user_key": user_key,
                "session_id": session_id,
                "user_message": user_message,
                "assistant_message": assistant_message,
            },
        )

    async def commit_session(
        self,
        *,
        user_id: str,
        user_key: str,
        session_id: str,
    ) -> dict[str, Any]:
        """POST a commit request for *session_id* and return the decoded JSON."""
        return await self._post(
            f"/memory-system/sessions/{self._path_segment(session_id)}/commit",
            {"user_id": user_id, "user_key": user_key},
        )

    async def _get(self, path: str, params: dict[str, Any]) -> dict[str, Any]:
        """Send a GET to *path* with *params* and return the decoded JSON body.

        ``raise_for_status`` is called, so error status codes raise.
        """
        async with self._client() as client:
            response = await client.get(path, params=params)
            response.raise_for_status()
            return response.json()

    async def _post(self, path: str, payload: dict[str, Any]) -> dict[str, Any]:
        """Send *payload* as JSON to *path* and return the decoded JSON body.

        ``raise_for_status`` is called, so error status codes raise.
        """
        async with self._client() as client:
            response = await client.post(path, json=payload)
            response.raise_for_status()
            return response.json()

    def _client(self) -> httpx.AsyncClient:
        """Build a new ``httpx.AsyncClient`` from the stored settings."""
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["X-API-Key"] = self.api_key
        return httpx.AsyncClient(
            base_url=self.base_url,
            headers=headers,
            timeout=self.timeout_seconds,
            transport=self.transport,
        )

    @staticmethod
    def _path_segment(value: str) -> str:
        """Percent-encode *value* for use as one URL path segment.

        Reserved characters such as ``/``, ``?`` and ``#`` are escaped so an
        identifier cannot add path segments or a query string to the request
        URL. Identifiers made only of unreserved characters are unchanged.
        """
        return quote(str(value), safe="")

    @staticmethod
    def _extract_user_key(data: dict[str, Any]) -> str | None:
        """Return ``data["account"]["result"]["user_key"]`` as ``str``, or ``None``.

        Every level is type-checked, including *data* itself, so a response of
        an unexpected shape yields ``None`` rather than an ``AttributeError``.
        A falsy key (missing, ``None``, empty string) also yields ``None``.
        """
        account = data.get("account") if isinstance(data, dict) else None
        result = account.get("result") if isinstance(account, dict) else None
        user_key = result.get("user_key") if isinstance(result, dict) else None
        return str(user_key) if user_key else None
|