feat: add Memory Gateway integration with async support for memory snapshots and user management
This commit is contained in:
6
app-instance/backend/beaver/memory/gateway/__init__.py
Normal file
6
app-instance/backend/beaver/memory/gateway/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
"""Memory Gateway integration for Beaver memory runtime."""
|
||||
|
||||
from .client import MemoryGatewayClient
|
||||
from .store import MemoryGatewayUserStore
|
||||
|
||||
__all__ = ["MemoryGatewayClient", "MemoryGatewayUserStore"]
|
||||
150
app-instance/backend/beaver/memory/gateway/client.py
Normal file
150
app-instance/backend/beaver/memory/gateway/client.py
Normal file
@ -0,0 +1,150 @@
|
||||
"""HTTP client for Memory Gateway's `/memory-system` API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from .store import MemoryGatewayUserStore
|
||||
|
||||
|
||||
class MemoryGatewayClient:
|
||||
"""Small async client for the Memory Gateway business API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
base_url: str,
|
||||
store: MemoryGatewayUserStore,
|
||||
api_key: str = "",
|
||||
timeout_seconds: float = 15.0,
|
||||
transport: httpx.AsyncBaseTransport | None = None,
|
||||
) -> None:
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.api_key = api_key
|
||||
self.timeout_seconds = timeout_seconds
|
||||
self.store = store
|
||||
self.transport = transport
|
||||
|
||||
async def ensure_user(self, user_id: str) -> str:
|
||||
cached = self.store.get_user_key(user_id)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
data = await self._post("/memory-system/users", {"user_id": user_id})
|
||||
user_key = self._extract_user_key(data)
|
||||
if not user_key:
|
||||
raise RuntimeError("Memory Gateway user creation response missing user_key")
|
||||
self.store.save_user_key(user_id, user_key)
|
||||
return user_key
|
||||
|
||||
async def get_profile(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
user_key: str,
|
||||
query: str = "用户画像",
|
||||
limit: int = 5,
|
||||
) -> dict[str, Any]:
|
||||
return await self._get(
|
||||
f"/memory-system/users/{user_id}/profile",
|
||||
{"user_key": user_key, "query": query, "limit": limit},
|
||||
)
|
||||
|
||||
async def get_session_context(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
user_key: str,
|
||||
session_id: str,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
) -> dict[str, Any]:
|
||||
return await self._post(
|
||||
f"/memory-system/sessions/{session_id}/context",
|
||||
{"user_id": user_id, "user_key": user_key, "query": query, "limit": limit},
|
||||
)
|
||||
|
||||
async def search(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
user_key: str,
|
||||
session_id: str,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
) -> dict[str, Any]:
|
||||
return await self._post(
|
||||
"/memory-system/search",
|
||||
{
|
||||
"user_id": user_id,
|
||||
"user_key": user_key,
|
||||
"session_id": session_id,
|
||||
"query": query,
|
||||
"use_llm": False,
|
||||
"limit": limit,
|
||||
},
|
||||
)
|
||||
|
||||
async def ingest_messages(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
user_key: str,
|
||||
session_id: str,
|
||||
user_message: str | None,
|
||||
assistant_message: str | None,
|
||||
) -> dict[str, Any]:
|
||||
return await self._post(
|
||||
"/memory-system/messages",
|
||||
{
|
||||
"user_id": user_id,
|
||||
"user_key": user_key,
|
||||
"session_id": session_id,
|
||||
"user_message": user_message,
|
||||
"assistant_message": assistant_message,
|
||||
},
|
||||
)
|
||||
|
||||
async def commit_session(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
user_key: str,
|
||||
session_id: str,
|
||||
) -> dict[str, Any]:
|
||||
return await self._post(
|
||||
f"/memory-system/sessions/{session_id}/commit",
|
||||
{"user_id": user_id, "user_key": user_key},
|
||||
)
|
||||
|
||||
async def _get(self, path: str, params: dict[str, Any]) -> dict[str, Any]:
|
||||
async with self._client() as client:
|
||||
response = await client.get(path, params=params)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def _post(self, path: str, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
async with self._client() as client:
|
||||
response = await client.post(path, json=payload)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def _client(self) -> httpx.AsyncClient:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self.api_key:
|
||||
headers["X-API-Key"] = self.api_key
|
||||
return httpx.AsyncClient(
|
||||
base_url=self.base_url,
|
||||
headers=headers,
|
||||
timeout=self.timeout_seconds,
|
||||
transport=self.transport,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_user_key(data: dict[str, Any]) -> str | None:
|
||||
account = data.get("account")
|
||||
result = account.get("result") if isinstance(account, dict) else None
|
||||
user_key = result.get("user_key") if isinstance(result, dict) else None
|
||||
return str(user_key) if user_key else None
|
||||
190
app-instance/backend/beaver/memory/gateway/service.py
Normal file
190
app-instance/backend/beaver/memory/gateway/service.py
Normal file
@ -0,0 +1,190 @@
|
||||
"""Memory service wrapper that augments local snapshots with Memory Gateway recall."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from beaver.foundation.config.schema import MemoryGatewayConfig
|
||||
from beaver.memory.curated.snapshot import MemorySnapshot
|
||||
from beaver.services.memory_service import MemoryService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GatewayAugmentedMemoryService:
|
||||
"""Keep local curated memory and add Memory Gateway recall when available."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
local_service: MemoryService,
|
||||
client: Any,
|
||||
config: MemoryGatewayConfig,
|
||||
) -> None:
|
||||
self.local_service = local_service
|
||||
self.client = client
|
||||
self.config = config
|
||||
|
||||
def initialize(self) -> None:
|
||||
self.local_service.initialize()
|
||||
|
||||
def get_store(self) -> Any:
|
||||
return self.local_service.get_store()
|
||||
|
||||
def get_snapshot(self) -> MemorySnapshot:
|
||||
return self.local_service.get_snapshot()
|
||||
|
||||
def capture_snapshot_for_run(self) -> MemorySnapshot:
|
||||
return self.local_service.capture_snapshot_for_run()
|
||||
|
||||
async def capture_snapshot_for_run_async(
|
||||
self,
|
||||
*,
|
||||
user_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
query: str | None = None,
|
||||
) -> MemorySnapshot:
|
||||
local_snapshot = self.local_service.capture_snapshot_for_run()
|
||||
if not user_id or not session_id or not query:
|
||||
return local_snapshot
|
||||
|
||||
try:
|
||||
user_key = await self.client.ensure_user(user_id)
|
||||
limit = max(1, self.config.snapshot_search_limit)
|
||||
profile, session_context, search = await asyncio.gather(
|
||||
self.client.get_profile(
|
||||
user_id=user_id,
|
||||
user_key=user_key,
|
||||
query="用户画像",
|
||||
limit=limit,
|
||||
),
|
||||
self.client.get_session_context(
|
||||
user_id=user_id,
|
||||
user_key=user_key,
|
||||
session_id=session_id,
|
||||
query=query,
|
||||
limit=limit,
|
||||
),
|
||||
self.client.search(
|
||||
user_id=user_id,
|
||||
user_key=user_key,
|
||||
session_id=session_id,
|
||||
query=query,
|
||||
limit=limit,
|
||||
),
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001 - gateway recall must not break local memory
|
||||
logger.warning("Memory Gateway snapshot augmentation failed: %s", exc, exc_info=True)
|
||||
return local_snapshot
|
||||
|
||||
user_block = self._join_blocks(
|
||||
local_snapshot.user_block,
|
||||
self._render_profile(profile),
|
||||
)
|
||||
memory_block = self._join_blocks(
|
||||
local_snapshot.memory_block,
|
||||
self._render_session_context(session_context),
|
||||
self._render_search(search),
|
||||
)
|
||||
return MemorySnapshot(memory_block=memory_block, user_block=user_block)
|
||||
|
||||
async def archive_run_async(
|
||||
self,
|
||||
*,
|
||||
user_id: str | None,
|
||||
session_id: str,
|
||||
user_message: str,
|
||||
assistant_message: str,
|
||||
) -> dict[str, Any]:
|
||||
if not user_id:
|
||||
return {"success": False, "error": "user_id is required for Memory Gateway archive."}
|
||||
|
||||
user_key = await self.client.ensure_user(user_id)
|
||||
result: dict[str, Any] = {
|
||||
"messages": await self.client.ingest_messages(
|
||||
user_id=user_id,
|
||||
user_key=user_key,
|
||||
session_id=session_id,
|
||||
user_message=user_message,
|
||||
assistant_message=assistant_message,
|
||||
)
|
||||
}
|
||||
if self.config.commit_on_run_complete:
|
||||
result["commit"] = await self.client.commit_session(
|
||||
user_id=user_id,
|
||||
user_key=user_key,
|
||||
session_id=session_id,
|
||||
)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _join_blocks(*blocks: str | None) -> str | None:
|
||||
joined = "\n\n".join(block.strip() for block in blocks if block and block.strip())
|
||||
return joined or None
|
||||
|
||||
@staticmethod
|
||||
def _render_profile(payload: dict[str, Any]) -> str | None:
|
||||
lines: list[str] = []
|
||||
profile = payload.get("profile")
|
||||
if isinstance(profile, dict):
|
||||
lines.extend(_compact_mapping(profile))
|
||||
elif profile:
|
||||
lines.append(str(profile))
|
||||
lines.extend(_compact_items(payload.get("items")))
|
||||
return _section("GATEWAY USER PROFILE", lines)
|
||||
|
||||
@staticmethod
|
||||
def _render_session_context(payload: dict[str, Any]) -> str | None:
|
||||
lines: list[str] = []
|
||||
context = payload.get("context")
|
||||
if isinstance(context, dict):
|
||||
overview = context.get("latest_archive_overview")
|
||||
if overview:
|
||||
lines.append(str(overview))
|
||||
messages = context.get("messages")
|
||||
if isinstance(messages, list):
|
||||
lines.extend(_compact_items(messages))
|
||||
lines.extend(_compact_items(payload.get("items")))
|
||||
return _section("GATEWAY SESSION CONTEXT", lines)
|
||||
|
||||
@staticmethod
|
||||
def _render_search(payload: dict[str, Any]) -> str | None:
|
||||
return _section("GATEWAY SEARCH RESULTS", _compact_items(payload.get("items")))
|
||||
|
||||
|
||||
def _section(title: str, lines: list[str]) -> str | None:
|
||||
cleaned = [line.strip() for line in lines if line and line.strip()]
|
||||
if not cleaned:
|
||||
return None
|
||||
return f"{title}\n" + "\n".join(cleaned)
|
||||
|
||||
|
||||
def _compact_mapping(payload: dict[str, Any]) -> list[str]:
|
||||
result: list[str] = []
|
||||
for key in ("summary", "text", "content", "description"):
|
||||
value = payload.get(key)
|
||||
if value:
|
||||
result.append(str(value))
|
||||
if not result and payload:
|
||||
result.append(str(payload))
|
||||
return result
|
||||
|
||||
|
||||
def _compact_items(items: Any) -> list[str]:
|
||||
if not isinstance(items, list):
|
||||
return []
|
||||
result: list[str] = []
|
||||
for item in items:
|
||||
if not isinstance(item, dict):
|
||||
if item:
|
||||
result.append(f"- {item}")
|
||||
continue
|
||||
text = item.get("text") or item.get("summary") or item.get("content")
|
||||
if not text:
|
||||
text = str(item)
|
||||
source = item.get("source_backend")
|
||||
prefix = f"[{source}] " if source else ""
|
||||
result.append(f"- {prefix}{text}")
|
||||
return result
|
||||
54
app-instance/backend/beaver/memory/gateway/store.py
Normal file
54
app-instance/backend/beaver/memory/gateway/store.py
Normal file
@ -0,0 +1,54 @@
|
||||
"""SQLite cache for Memory Gateway user credentials."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class MemoryGatewayUserStore:
|
||||
"""Persist `user_id -> user_key` mappings returned by Memory Gateway."""
|
||||
|
||||
def __init__(self, db_path: str | Path) -> None:
|
||||
self.db_path = Path(db_path)
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._init_schema()
|
||||
|
||||
def get_user_key(self, user_id: str) -> str | None:
|
||||
with self._connect() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT user_key FROM memory_gateway_users WHERE user_id = ?",
|
||||
(user_id,),
|
||||
).fetchone()
|
||||
return str(row[0]) if row else None
|
||||
|
||||
def save_user_key(self, user_id: str, user_key: str) -> None:
|
||||
now = time.time()
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO memory_gateway_users (user_id, user_key, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
ON CONFLICT(user_id) DO UPDATE SET
|
||||
user_key = excluded.user_key,
|
||||
updated_at = excluded.updated_at
|
||||
""",
|
||||
(user_id, user_key, now, now),
|
||||
)
|
||||
|
||||
def _init_schema(self) -> None:
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS memory_gateway_users (
|
||||
user_id TEXT PRIMARY KEY,
|
||||
user_key TEXT NOT NULL,
|
||||
created_at REAL NOT NULL,
|
||||
updated_at REAL NOT NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
def _connect(self) -> sqlite3.Connection:
|
||||
return sqlite3.connect(str(self.db_path))
|
||||
Reference in New Issue
Block a user