"""Memory service wrapper that augments local snapshots with Memory Gateway recall."""

from __future__ import annotations

import asyncio
import logging
from typing import Any

from beaver.foundation.config.schema import MemoryGatewayConfig
from beaver.memory.curated.snapshot import MemorySnapshot
from beaver.services.memory_service import MemoryService

logger = logging.getLogger(__name__)


class GatewayAugmentedMemoryService:
    """Keep local curated memory and add Memory Gateway recall when available.

    Synchronous store/snapshot operations are delegated unchanged to
    ``local_service``. ``capture_snapshot_for_run_async`` additionally queries
    ``client`` and appends the rendered gateway results to the local snapshot
    blocks; any gateway failure falls back to the plain local snapshot.
    """

    def __init__(
        self,
        *,
        local_service: MemoryService,
        client: Any,
        config: MemoryGatewayConfig,
    ) -> None:
        """Store the collaborators used by this wrapper.

        Args:
            local_service: Local curated memory service; all synchronous calls
                are delegated to it.
            client: Memory Gateway client. Assumed (not verified here) to expose
                awaitable ``ensure_user``, ``get_profile``,
                ``get_session_context``, ``search``, ``ingest_messages`` and
                ``commit_session`` methods.
            config: Gateway settings; this class reads
                ``snapshot_search_limit`` and ``commit_on_run_complete``.
        """
        self.local_service = local_service
        self.client = client
        self.config = config

    def initialize(self) -> None:
        """Initialize the wrapped local memory service."""
        self.local_service.initialize()

    def get_store(self) -> Any:
        """Return the wrapped local service's store."""
        return self.local_service.get_store()

    def get_snapshot(self) -> MemorySnapshot:
        """Return the local snapshot, without gateway augmentation."""
        return self.local_service.get_snapshot()

    def capture_snapshot_for_run(self) -> MemorySnapshot:
        """Capture a local-only snapshot for a run (no gateway calls)."""
        return self.local_service.capture_snapshot_for_run()

    async def capture_snapshot_for_run_async(
        self,
        *,
        user_id: str | None = None,
        session_id: str | None = None,
        query: str | None = None,
    ) -> MemorySnapshot:
        """Capture the local snapshot and append Memory Gateway recall to it.

        The gateway is only consulted when ``user_id``, ``session_id`` and
        ``query`` are all truthy. The rendered profile is appended to the user
        block; the session context and search results are appended to the
        memory block.

        Args:
            user_id: Gateway user identifier.
            session_id: Gateway session identifier.
            query: Text used for session-context and search recall.

        Returns:
            The augmented snapshot, or the unmodified local snapshot when
            identifiers are missing or any gateway call fails.
        """
        local_snapshot = self.local_service.capture_snapshot_for_run()
        # Gateway recall needs all three identifiers; otherwise stay local-only.
        if not user_id or not session_id or not query:
            return local_snapshot

        try:
            user_key = await self.client.ensure_user(user_id)
            limit = max(1, self.config.snapshot_search_limit)
            # The three recall requests are independent, so issue them concurrently.
            profile, session_context, search = await asyncio.gather(
                self.client.get_profile(
                    user_id=user_id,
                    user_key=user_key,
                    # "用户画像" is Chinese for "user profile"; sent to the gateway verbatim.
                    query="用户画像",
                    limit=limit,
                ),
                self.client.get_session_context(
                    user_id=user_id,
                    user_key=user_key,
                    session_id=session_id,
                    query=query,
                    limit=limit,
                ),
                self.client.search(
                    user_id=user_id,
                    user_key=user_key,
                    session_id=session_id,
                    query=query,
                    limit=limit,
                ),
            )
        except Exception as exc:  # noqa: BLE001 - gateway recall must not break local memory
            logger.warning("Memory Gateway snapshot augmentation failed: %s", exc, exc_info=True)
            return local_snapshot

        # Renderers tolerate malformed (non-dict) payloads by returning None, so
        # an unexpected gateway response cannot break local memory here either.
        user_block = self._join_blocks(
            local_snapshot.user_block,
            self._render_profile(profile),
        )
        memory_block = self._join_blocks(
            local_snapshot.memory_block,
            self._render_session_context(session_context),
            self._render_search(search),
        )
        return MemorySnapshot(memory_block=memory_block, user_block=user_block)

    async def archive_run_async(
        self,
        *,
        user_id: str | None,
        session_id: str,
        user_message: str,
        assistant_message: str,
    ) -> dict[str, Any]:
        """Send one completed user/assistant exchange to the Memory Gateway.

        Args:
            user_id: Gateway user identifier; required.
            session_id: Gateway session the messages belong to.
            user_message: The user's message for this run.
            assistant_message: The assistant's reply for this run.

        Returns:
            ``{"success": False, "error": ...}`` when ``user_id`` is missing.
            Otherwise ``{"messages": <ingest result>}``, plus a ``"commit"``
            entry when ``config.commit_on_run_complete`` is set.

        Note:
            Unlike snapshot capture, client errors are not caught here and
            propagate to the caller.
        """
        if not user_id:
            return {"success": False, "error": "user_id is required for Memory Gateway archive."}

        user_key = await self.client.ensure_user(user_id)
        result: dict[str, Any] = {
            "messages": await self.client.ingest_messages(
                user_id=user_id,
                user_key=user_key,
                session_id=session_id,
                user_message=user_message,
                assistant_message=assistant_message,
            )
        }
        if self.config.commit_on_run_complete:
            result["commit"] = await self.client.commit_session(
                user_id=user_id,
                user_key=user_key,
                session_id=session_id,
            )
        return result

    @staticmethod
    def _join_blocks(*blocks: str | None) -> str | None:
        """Join non-blank blocks (each stripped) with a blank line; None if none remain."""
        joined = "\n\n".join(block.strip() for block in blocks if block and block.strip())
        return joined or None

    @staticmethod
    def _render_profile(payload: Any) -> str | None:
        """Render a gateway profile payload as a titled section.

        Reads ``payload["profile"]`` (a mapping is compacted, any other truthy
        value is stringified) and ``payload["items"]``. Returns None when the
        payload is not a dict or nothing renderable is found.
        """
        if not isinstance(payload, dict):
            return None
        lines: list[str] = []
        profile = payload.get("profile")
        if isinstance(profile, dict):
            lines.extend(_compact_mapping(profile))
        elif profile:
            lines.append(str(profile))
        lines.extend(_compact_items(payload.get("items")))
        return _section("GATEWAY USER PROFILE", lines)

    @staticmethod
    def _render_session_context(payload: Any) -> str | None:
        """Render a gateway session-context payload as a titled section.

        Uses ``context.latest_archive_overview``, the ``context.messages`` list
        and top-level ``items``. Returns None when the payload is not a dict or
        nothing renderable is found.
        """
        if not isinstance(payload, dict):
            return None
        lines: list[str] = []
        context = payload.get("context")
        if isinstance(context, dict):
            overview = context.get("latest_archive_overview")
            if overview:
                lines.append(str(overview))
            # _compact_items ignores anything that is not a list.
            lines.extend(_compact_items(context.get("messages")))
        lines.extend(_compact_items(payload.get("items")))
        return _section("GATEWAY SESSION CONTEXT", lines)

    @staticmethod
    def _render_search(payload: Any) -> str | None:
        """Render gateway search ``items`` as a titled section; None for non-dict payloads."""
        if not isinstance(payload, dict):
            return None
        return _section("GATEWAY SEARCH RESULTS", _compact_items(payload.get("items")))


def _section(title: str, lines: list[str]) -> str | None:
    """Return ``title`` followed by the stripped, non-blank ``lines``; None if none remain."""
    cleaned = [line.strip() for line in lines if line and line.strip()]
    if not cleaned:
        return None
    return f"{title}\n" + "\n".join(cleaned)


def _compact_mapping(payload: dict[str, Any]) -> list[str]:
    """Extract human-readable text from a mapping.

    Collects the truthy values of ``summary``, ``text``, ``content`` and
    ``description`` (in that order). If none are present, falls back to the
    mapping's ``str()`` for a non-empty mapping.
    """
    result: list[str] = []
    for key in ("summary", "text", "content", "description"):
        value = payload.get(key)
        if value:
            result.append(str(value))
    if not result and payload:
        result.append(str(payload))
    return result


def _compact_items(items: Any) -> list[str]:
    """Render recall items as ``- `` bullet lines.

    Non-list input yields ``[]``. Falsy non-dict items are skipped; other
    non-dict items are stringified. Dict items use the first truthy of
    ``text``/``summary``/``content`` (falling back to the dict's ``str()``)
    and get a ``[source_backend] `` prefix when that key is set.
    """
    if not isinstance(items, list):
        return []
    result: list[str] = []
    for item in items:
        if not isinstance(item, dict):
            if item:
                result.append(f"- {item}")
            continue
        text = item.get("text") or item.get("summary") or item.get("content")
        if not text:
            text = str(item)
        source = item.get("source_backend")
        prefix = f"[{source}] " if source else ""
        result.append(f"- {prefix}{text}")
    return result