191 lines
6.3 KiB
Python
191 lines
6.3 KiB
Python
"""Memory service wrapper that augments local snapshots with Memory Gateway recall."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
from typing import Any
|
|
|
|
from beaver.foundation.config.schema import MemoryGatewayConfig
|
|
from beaver.memory.curated.snapshot import MemorySnapshot
|
|
from beaver.services.memory_service import MemoryService
|
|
|
|
# Module-level logger; the gateway wrapper reports snapshot-augmentation failures here as warnings.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class GatewayAugmentedMemoryService:
    """Keep local curated memory and add Memory Gateway recall when available.

    Store and snapshot accessors delegate directly to ``local_service``. Only
    :meth:`capture_snapshot_for_run_async` consults the gateway ``client``. It
    appends recalled profile, session-context and search results to the local
    snapshot. It falls back to the plain local snapshot whenever the gateway is
    unavailable or has nothing to add.
    """

    def __init__(
        self,
        *,
        local_service: MemoryService,
        client: Any,
        config: MemoryGatewayConfig,
    ) -> None:
        # ``client`` is duck-typed: this class awaits its ensure_user,
        # get_profile, get_session_context, search, ingest_messages and
        # commit_session methods.
        self.local_service = local_service
        self.client = client
        self.config = config

    def initialize(self) -> None:
        """Initialize the wrapped local memory service."""
        self.local_service.initialize()

    def get_store(self) -> Any:
        """Return the wrapped local service's store."""
        return self.local_service.get_store()

    def get_snapshot(self) -> MemorySnapshot:
        """Return the local service's snapshot (no gateway recall)."""
        return self.local_service.get_snapshot()

    def capture_snapshot_for_run(self) -> MemorySnapshot:
        """Capture a local-only snapshot; use the async variant for gateway recall."""
        return self.local_service.capture_snapshot_for_run()

    async def capture_snapshot_for_run_async(
        self,
        *,
        user_id: str | None = None,
        session_id: str | None = None,
        query: str | None = None,
    ) -> MemorySnapshot:
        """Capture a local snapshot and augment it with Memory Gateway recall.

        Args:
            user_id: Gateway user id; recall is skipped when empty.
            session_id: Gateway session id; recall is skipped when empty.
            query: Text used for session-context and search recall; recall is
                skipped when empty.

        Returns:
            The local snapshot unchanged in any of these cases:
            - recall is skipped;
            - the gateway fails before any recall results arrive;
            - no recall call yields anything renderable.
            Otherwise, a new ``MemorySnapshot``. Its ``user_block`` is extended
            with the gateway profile. Its ``memory_block`` is extended with
            session context and search hits.

        Gateway ``Exception`` errors are logged and never raised. A failure
        in one recall call only drops that call's contribution.
        """
        local_snapshot = self.local_service.capture_snapshot_for_run()
        if not user_id or not session_id or not query:
            return local_snapshot

        try:
            user_key = await self.client.ensure_user(user_id)
            limit = max(1, self.config.snapshot_search_limit)
            # return_exceptions=True avoids two problems with a plain gather().
            # It would discard the other recall results on the first error.
            # It would also leave the sibling calls running, unobserved.
            results = await asyncio.gather(
                self.client.get_profile(
                    user_id=user_id,
                    user_key=user_key,
                    # Fixed Chinese query meaning "user profile".
                    query="用户画像",
                    limit=limit,
                ),
                self.client.get_session_context(
                    user_id=user_id,
                    user_key=user_key,
                    session_id=session_id,
                    query=query,
                    limit=limit,
                ),
                self.client.search(
                    user_id=user_id,
                    user_key=user_key,
                    session_id=session_id,
                    query=query,
                    limit=limit,
                ),
                return_exceptions=True,
            )
        except Exception as exc:  # noqa: BLE001 - gateway recall must not break local memory
            logger.warning("Memory Gateway snapshot augmentation failed: %s", exc, exc_info=True)
            return local_snapshot

        payloads: list[Any] = []
        for source, result in zip(("profile", "session context", "search"), results):
            if isinstance(result, Exception):
                logger.warning(
                    "Memory Gateway %s recall failed: %s", source, result, exc_info=result
                )
                payloads.append(None)
            elif isinstance(result, BaseException):
                # Non-Exception errors (e.g. a child's CancelledError) are not
                # recall failures; let them propagate.
                raise result
            else:
                payloads.append(result)
        profile, session_context, search = payloads

        gateway_profile = self._render_profile(profile)
        gateway_context = self._render_session_context(session_context)
        gateway_search = self._render_search(search)
        if gateway_profile is None and gateway_context is None and gateway_search is None:
            # Nothing recalled: hand back the local snapshot as-is, like the
            # early-return paths, rather than rebuilding it.
            return local_snapshot

        user_block = self._join_blocks(local_snapshot.user_block, gateway_profile)
        memory_block = self._join_blocks(
            local_snapshot.memory_block,
            gateway_context,
            gateway_search,
        )
        return MemorySnapshot(memory_block=memory_block, user_block=user_block)

    async def archive_run_async(
        self,
        *,
        user_id: str | None,
        session_id: str,
        user_message: str,
        assistant_message: str,
    ) -> dict[str, Any]:
        """Send one user/assistant exchange to the gateway.

        If ``config.commit_on_run_complete`` is set, the session is then
        committed.

        Returns:
            ``{"success": False, "error": ...}`` when ``user_id`` is empty.
            Otherwise ``{"messages": <ingest result>}``, plus
            ``"commit": <commit result>`` when a commit was performed.

        Client exceptions are not caught here and propagate to the caller.
        """
        if not user_id:
            return {"success": False, "error": "user_id is required for Memory Gateway archive."}

        user_key = await self.client.ensure_user(user_id)
        result: dict[str, Any] = {
            "messages": await self.client.ingest_messages(
                user_id=user_id,
                user_key=user_key,
                session_id=session_id,
                user_message=user_message,
                assistant_message=assistant_message,
            )
        }
        if self.config.commit_on_run_complete:
            result["commit"] = await self.client.commit_session(
                user_id=user_id,
                user_key=user_key,
                session_id=session_id,
            )
        return result

    @staticmethod
    def _join_blocks(*blocks: str | None) -> str | None:
        """Join the non-blank *blocks* with blank lines; return ``None`` if none remain.

        Each block is stripped before joining.
        """
        joined = "\n\n".join(block.strip() for block in blocks if block and block.strip())
        return joined or None

    @staticmethod
    def _render_profile(payload: Any) -> str | None:
        """Render a ``get_profile`` response as a "GATEWAY USER PROFILE" section.

        Uses the ``profile`` entry first: a mapping is compacted, and any
        other truthy value is stringified. The ``items`` list follows.
        Non-mapping payloads render as ``None``.
        """
        if not isinstance(payload, dict):
            return None
        lines: list[str] = []
        profile = payload.get("profile")
        if isinstance(profile, dict):
            lines.extend(_compact_mapping(profile))
        elif profile:
            lines.append(str(profile))
        lines.extend(_compact_items(payload.get("items")))
        return _section("GATEWAY USER PROFILE", lines)

    @staticmethod
    def _render_session_context(payload: Any) -> str | None:
        """Render a ``get_session_context`` response as a "GATEWAY SESSION CONTEXT" section.

        When ``context`` is a mapping, its ``latest_archive_overview`` and its
        ``messages`` list are included. The payload's ``items`` list follows.
        Non-mapping payloads render as ``None``.
        """
        if not isinstance(payload, dict):
            return None
        lines: list[str] = []
        context = payload.get("context")
        if isinstance(context, dict):
            overview = context.get("latest_archive_overview")
            if overview:
                lines.append(str(overview))
            messages = context.get("messages")
            if isinstance(messages, list):
                lines.extend(_compact_items(messages))
        lines.extend(_compact_items(payload.get("items")))
        return _section("GATEWAY SESSION CONTEXT", lines)

    @staticmethod
    def _render_search(payload: Any) -> str | None:
        """Render a ``search`` response's ``items`` as a "GATEWAY SEARCH RESULTS" section.

        Non-mapping payloads render as ``None``.
        """
        if not isinstance(payload, dict):
            return None
        return _section("GATEWAY SEARCH RESULTS", _compact_items(payload.get("items")))
|
|
|
|
|
|
def _section(title: str, lines: list[str]) -> str | None:
|
|
cleaned = [line.strip() for line in lines if line and line.strip()]
|
|
if not cleaned:
|
|
return None
|
|
return f"{title}\n" + "\n".join(cleaned)
|
|
|
|
|
|
def _compact_mapping(payload: dict[str, Any]) -> list[str]:
|
|
result: list[str] = []
|
|
for key in ("summary", "text", "content", "description"):
|
|
value = payload.get(key)
|
|
if value:
|
|
result.append(str(value))
|
|
if not result and payload:
|
|
result.append(str(payload))
|
|
return result
|
|
|
|
|
|
def _compact_items(items: Any) -> list[str]:
|
|
if not isinstance(items, list):
|
|
return []
|
|
result: list[str] = []
|
|
for item in items:
|
|
if not isinstance(item, dict):
|
|
if item:
|
|
result.append(f"- {item}")
|
|
continue
|
|
text = item.get("text") or item.get("summary") or item.get("content")
|
|
if not text:
|
|
text = str(item)
|
|
source = item.get("source_backend")
|
|
prefix = f"[{source}] " if source else ""
|
|
result.append(f"- {prefix}{text}")
|
|
return result
|