diff --git a/app-instance/backend/beaver/engine/loader.py b/app-instance/backend/beaver/engine/loader.py index 270cd50..bb06cd1 100644 --- a/app-instance/backend/beaver/engine/loader.py +++ b/app-instance/backend/beaver/engine/loader.py @@ -14,6 +14,8 @@ from beaver.engine.session import SessionManager from beaver.foundation.config import BeaverConfig, load_config from beaver.integrations.mcp import MCPConnectionManager from beaver.memory.curated.store import MemoryStore +from beaver.memory.gateway import MemoryGatewayClient, MemoryGatewayUserStore +from beaver.memory.gateway.service import GatewayAugmentedMemoryService from beaver.memory.runs import RunMemoryStore from beaver.memory.skills import SkillLearningStore from beaver.services.memory_service import MemoryService @@ -206,7 +208,26 @@ class EngineLoader: curated_root = workspace / "memory" / "curated" curated_memory_store = self._curated_memory_store or MemoryStore(curated_root) - memory_service = self._memory_service or MemoryService(curated_root, store=curated_memory_store) + if self._memory_service is not None: + memory_service = self._memory_service + else: + local_memory_service = MemoryService(curated_root, store=curated_memory_store) + memory_cfg = self.config.memory + if memory_cfg.mode == "gateway" and memory_cfg.gateway.base_url: + gateway_store = MemoryGatewayUserStore(workspace / "memory" / "gateway" / "state.db") + gateway_client = MemoryGatewayClient( + base_url=memory_cfg.gateway.base_url, + api_key=memory_cfg.gateway.api_key, + timeout_seconds=memory_cfg.gateway.timeout_seconds, + store=gateway_store, + ) + memory_service = GatewayAugmentedMemoryService( + local_service=local_memory_service, + client=gateway_client, + config=memory_cfg.gateway, + ) + else: + memory_service = local_memory_service memory_service.initialize() run_memory_store = self._run_memory_store or RunMemoryStore(workspace / "memory" / "runs") skill_learning_store = self._skill_learning_store or SkillLearningStore(workspace / "memory" / "skills") diff --git a/app-instance/backend/beaver/engine/loop.py b/app-instance/backend/beaver/engine/loop.py index 50a32e9..119dd01 100644 --- a/app-instance/backend/beaver/engine/loop.py +++ b/app-instance/backend/beaver/engine/loop.py @@ -380,10 +380,15 @@ class AgentLoop: resolved_max_tool_iterations = ( self.profile.max_tool_iterations if max_tool_iterations is None else max_tool_iterations ) + resolved_memory_user_id = user_id or config.memory.gateway.default_user_id or None # 每个 run 都捕获自己的 frozen snapshot,不能依赖 MemoryService # 上的共享 `_snapshot`,否则 parallel team runs 会互相覆盖。 - memory_snapshot = memory_service.capture_snapshot_for_run() + memory_snapshot = await memory_service.capture_snapshot_for_run_async( + user_id=resolved_memory_user_id, + session_id=resolved_session_id, + query=task, + ) if parent_session_id: session_manager.ensure_session( @@ -834,6 +839,22 @@ class AgentLoop: model=final_model, user_id=user_id, ) + archive_fn = getattr(memory_service, "archive_run_async", None) + if archive_fn is not None and resolved_memory_user_id: + asyncio.create_task( + self._archive_memory_gateway_run( + archive_fn=archive_fn, + session_manager=session_manager, + session_id=resolved_session_id, + run_id=resolved_run_id, + user_id=resolved_memory_user_id, + user_message=task, + assistant_message=final_text, + source=source, + title=title, + model=final_model, + ) + ) self._record_run_receipts( skill_learning_service=skill_learning_service, session_manager=session_manager, @@ -1191,6 +1212,57 @@ class AgentLoop: context_visible=False, ) + @staticmethod + async def _archive_memory_gateway_run( + *, + archive_fn: Any, + session_manager: Any, + session_id: str, + run_id: str, + user_id: str | None, + user_message: str, + assistant_message: str, + source: str, + title: str | None, + model: str | None, + ) -> None: + try: + result = await archive_fn( + user_id=user_id, + session_id=session_id, + user_message=user_message, + assistant_message=assistant_message, + ) + except Exception as exc: # noqa: BLE001 - archive must not change completed run result + session_manager.append_message( + session_id, + run_id=run_id, + role="system", + event_type="memory_gateway_archive_failed", + event_payload={"error": str(exc)}, + content=f"Memory Gateway archive failed: {exc}", + context_visible=False, + source=source, + title=title, + model=model, + user_id=user_id, + ) + return + + session_manager.append_message( + session_id, + run_id=run_id, + role="system", + event_type="memory_gateway_archive_completed", + event_payload={"result": result}, + content="Memory Gateway archive completed.", + context_visible=False, + source=source, + title=title, + model=model, + user_id=user_id, + ) + @staticmethod def _utc_now() -> str: return datetime.now(timezone.utc).isoformat() diff --git a/app-instance/backend/beaver/foundation/config/loader.py b/app-instance/backend/beaver/foundation/config/loader.py index 3e71302..dd12ed5 100644 --- a/app-instance/backend/beaver/foundation/config/loader.py +++ b/app-instance/backend/beaver/foundation/config/loader.py @@ -15,6 +15,8 @@ from .schema import ( BeaverConfig, ChannelConfig, EmbeddingConfig, + MemoryConfig, + MemoryGatewayConfig, MCPServerConfig, ProviderConfig, ToolsConfig, @@ -76,6 +78,7 @@ def load_config( authz=_parse_authz(data.get("authz")), channels=_parse_channels(data.get("channels")), backend_identity=_parse_backend_identity(data.get("backend_identity") or data.get("backendIdentity")), + memory=_parse_memory(data.get("memory")), config_path=path, ) @@ -251,6 +254,35 @@ def _parse_backend_identity(raw: Any) -> BackendIdentityConfig: ) +def _parse_memory(raw: Any) -> MemoryConfig: + data = _as_dict(raw) + gateway = _as_dict(data.get("gateway")) + mode = (_string(data.get("mode")) or "local").lower() + if mode not in {"local", "gateway"}: + mode = "local" + return MemoryConfig( + mode=mode, + gateway=MemoryGatewayConfig( + base_url=_string(gateway.get("baseUrl") or gateway.get("base_url")) or "", + api_key=_string(gateway.get("apiKey") or gateway.get("api_key")) or "", + default_user_id=_string(gateway.get("defaultUserId") or gateway.get("default_user_id")) or "", + timeout_seconds=_float(gateway.get("timeoutSeconds") or gateway.get("timeout_seconds")) or 15.0, + snapshot_search_limit=int( + _float(gateway.get("snapshotSearchLimit") or gateway.get("snapshot_search_limit")) + or 5 + ), + commit_on_run_complete=_bool( + ( + gateway.get("commitOnRunComplete") + if "commitOnRunComplete" in gateway + else gateway.get("commit_on_run_complete") + ), + default=True, + ), + ), + ) + + def _as_dict(value: Any) -> dict[str, Any]: return value if isinstance(value, dict) else {} diff --git a/app-instance/backend/beaver/foundation/config/schema.py b/app-instance/backend/beaver/foundation/config/schema.py index 2c89f57..cac1758 100644 --- a/app-instance/backend/beaver/foundation/config/schema.py +++ b/app-instance/backend/beaver/foundation/config/schema.py @@ -115,6 +115,26 @@ class BackendIdentityConfig: public_base_url: str = "" +@dataclass(slots=True) +class MemoryGatewayConfig: + """Memory Gateway integration settings.""" + + base_url: str = "" + api_key: str = "" + default_user_id: str = "" + timeout_seconds: float = 15.0 + snapshot_search_limit: int = 5 + commit_on_run_complete: bool = True + + +@dataclass(slots=True) +class MemoryConfig: + """Runtime memory strategy configuration.""" + + mode: str = "local" + gateway: MemoryGatewayConfig = field(default_factory=MemoryGatewayConfig) + + @dataclass(slots=True) class BeaverConfig: """Config loaded once per backend sandbox instance.""" @@ -126,6 +146,7 @@ class BeaverConfig: authz: AuthzConfig = field(default_factory=AuthzConfig) channels: dict[str, ChannelConfig] = field(default_factory=dict) backend_identity: BackendIdentityConfig = field(default_factory=BackendIdentityConfig) + memory: MemoryConfig = field(default_factory=MemoryConfig) config_path: Path | None = None @property diff --git a/app-instance/backend/beaver/memory/gateway/__init__.py b/app-instance/backend/beaver/memory/gateway/__init__.py new file mode 100644 index 0000000..395d57f --- /dev/null +++ b/app-instance/backend/beaver/memory/gateway/__init__.py @@ -0,0 +1,6 @@ +"""Memory Gateway integration for Beaver memory runtime.""" + +from .client import MemoryGatewayClient +from .store import MemoryGatewayUserStore + +__all__ = ["MemoryGatewayClient", "MemoryGatewayUserStore"] diff --git a/app-instance/backend/beaver/memory/gateway/client.py b/app-instance/backend/beaver/memory/gateway/client.py new file mode 100644 index 0000000..1c9b6ac --- /dev/null +++ b/app-instance/backend/beaver/memory/gateway/client.py @@ -0,0 +1,150 @@ +"""HTTP client for Memory Gateway's `/memory-system` API.""" + +from __future__ import annotations + +from typing import Any + +import httpx + +from .store import MemoryGatewayUserStore + + +class MemoryGatewayClient: + """Small async client for the Memory Gateway business API.""" + + def __init__( + self, + *, + base_url: str, + store: MemoryGatewayUserStore, + api_key: str = "", + timeout_seconds: float = 15.0, + transport: httpx.AsyncBaseTransport | None = None, + ) -> None: + self.base_url = base_url.rstrip("/") + self.api_key = api_key + self.timeout_seconds = timeout_seconds + self.store = store + self.transport = transport + + async def ensure_user(self, user_id: str) -> str: + cached = self.store.get_user_key(user_id) + if cached: + return cached + + data = await self._post("/memory-system/users", {"user_id": user_id}) + user_key = self._extract_user_key(data) + if not user_key: + raise RuntimeError("Memory Gateway user creation response missing user_key") + self.store.save_user_key(user_id, user_key) + return user_key + + async def get_profile( + self, + *, + user_id: str, + user_key: str, + query: str = "用户画像", + limit: int = 5, + ) -> dict[str, Any]: + return await self._get( + f"/memory-system/users/{user_id}/profile", + {"user_key": user_key, "query": query, "limit": limit}, + ) + + async def get_session_context( + self, + *, + user_id: str, + user_key: str, + session_id: str, + query: str, + limit: int = 5, + ) -> dict[str, Any]: + return await self._post( + f"/memory-system/sessions/{session_id}/context", + {"user_id": user_id, "user_key": user_key, "query": query, "limit": limit}, + ) + + async def search( + self, + *, + user_id: str, + user_key: str, + session_id: str, + query: str, + limit: int = 5, + ) -> dict[str, Any]: + return await self._post( + "/memory-system/search", + { + "user_id": user_id, + "user_key": user_key, + "session_id": session_id, + "query": query, + "use_llm": False, + "limit": limit, + }, + ) + + async def ingest_messages( + self, + *, + user_id: str, + user_key: str, + session_id: str, + user_message: str | None, + assistant_message: str | None, + ) -> dict[str, Any]: + return await self._post( + "/memory-system/messages", + { + "user_id": user_id, + "user_key": user_key, + "session_id": session_id, + "user_message": user_message, + "assistant_message": assistant_message, + }, + ) + + async def commit_session( + self, + *, + user_id: str, + user_key: str, + session_id: str, + ) -> dict[str, Any]: + return await self._post( + f"/memory-system/sessions/{session_id}/commit", + {"user_id": user_id, "user_key": user_key}, + ) + + async def _get(self, path: str, params: dict[str, Any]) -> dict[str, Any]: + async with self._client() as client: + response = await client.get(path, params=params) + response.raise_for_status() + return response.json() + + async def _post(self, path: str, payload: dict[str, Any]) -> dict[str, Any]: + async with self._client() as client: + response = await client.post(path, json=payload) + response.raise_for_status() + return response.json() + + def _client(self) -> httpx.AsyncClient: + headers = {"Content-Type": "application/json"} + if self.api_key: + headers["X-API-Key"] = self.api_key + return httpx.AsyncClient( + base_url=self.base_url, + headers=headers, + timeout=self.timeout_seconds, + transport=self.transport, + ) + + @staticmethod + def _extract_user_key(data: dict[str, Any]) -> str | None: + account = data.get("account") + result = account.get("result") if isinstance(account, dict) else None + user_key = result.get("user_key") if isinstance(result, dict) else None + return str(user_key) if user_key else None diff --git a/app-instance/backend/beaver/memory/gateway/service.py b/app-instance/backend/beaver/memory/gateway/service.py new file mode 100644 index 0000000..b8e1038 --- /dev/null +++ b/app-instance/backend/beaver/memory/gateway/service.py @@ -0,0 +1,190 @@ +"""Memory service wrapper that augments local snapshots with Memory Gateway recall.""" + +from __future__ import annotations + +import asyncio +import logging +from typing import Any + +from beaver.foundation.config.schema import MemoryGatewayConfig +from beaver.memory.curated.snapshot import MemorySnapshot +from beaver.services.memory_service import MemoryService + +logger = logging.getLogger(__name__) + + +class GatewayAugmentedMemoryService: + """Keep local curated memory and add Memory Gateway recall when available.""" + + def __init__( + self, + *, + local_service: MemoryService, + client: Any, + config: MemoryGatewayConfig, + ) -> None: + self.local_service = local_service + self.client = client + self.config = config + + def initialize(self) -> None: + self.local_service.initialize() + + def get_store(self) -> Any: + return self.local_service.get_store() + + def get_snapshot(self) -> MemorySnapshot: + return self.local_service.get_snapshot() + + def capture_snapshot_for_run(self) -> MemorySnapshot: + return self.local_service.capture_snapshot_for_run() + + async def capture_snapshot_for_run_async( + self, + *, + user_id: str | None = None, + session_id: str | None = None, + query: str | None = None, + ) -> MemorySnapshot: + local_snapshot = self.local_service.capture_snapshot_for_run() + if not user_id or not session_id or not query: + return local_snapshot + + try: + user_key = await self.client.ensure_user(user_id) + limit = max(1, self.config.snapshot_search_limit) + profile, session_context, search = await asyncio.gather( + self.client.get_profile( + user_id=user_id, + user_key=user_key, + query="用户画像", + limit=limit, + ), + self.client.get_session_context( + user_id=user_id, + user_key=user_key, + session_id=session_id, + query=query, + limit=limit, + ), + self.client.search( + user_id=user_id, + user_key=user_key, + session_id=session_id, + query=query, + limit=limit, + ), + ) + except Exception as exc: # noqa: BLE001 - gateway recall must not break local memory + logger.warning("Memory Gateway snapshot augmentation failed: %s", exc, exc_info=True) + return local_snapshot + + user_block = self._join_blocks( + local_snapshot.user_block, + self._render_profile(profile), + ) + memory_block = self._join_blocks( + local_snapshot.memory_block, + self._render_session_context(session_context), + self._render_search(search), + ) + return MemorySnapshot(memory_block=memory_block, user_block=user_block) + + async def archive_run_async( + self, + *, + user_id: str | None, + session_id: str, + user_message: str, + assistant_message: str, + ) -> dict[str, Any]: + if not user_id: + return {"success": False, "error": "user_id is required for Memory Gateway archive."} + + user_key = await self.client.ensure_user(user_id) + result: dict[str, Any] = { + "messages": await self.client.ingest_messages( + user_id=user_id, + user_key=user_key, + session_id=session_id, + user_message=user_message, + assistant_message=assistant_message, + ) + } + if self.config.commit_on_run_complete: + result["commit"] = await self.client.commit_session( + user_id=user_id, + user_key=user_key, + session_id=session_id, + ) + return result + + @staticmethod + def _join_blocks(*blocks: str | None) -> str | None: + joined = "\n\n".join(block.strip() for block in blocks if block and block.strip()) + return joined or None + + @staticmethod + def _render_profile(payload: dict[str, Any]) -> str | None: + lines: list[str] = [] + profile = payload.get("profile") + if isinstance(profile, dict): + lines.extend(_compact_mapping(profile)) + elif profile: + lines.append(str(profile)) + lines.extend(_compact_items(payload.get("items"))) + return _section("GATEWAY USER PROFILE", lines) + + @staticmethod + def _render_session_context(payload: dict[str, Any]) -> str | None: + lines: list[str] = [] + context = payload.get("context") + if isinstance(context, dict): + overview = context.get("latest_archive_overview") + if overview: + lines.append(str(overview)) + messages = context.get("messages") + if isinstance(messages, list): + lines.extend(_compact_items(messages)) + lines.extend(_compact_items(payload.get("items"))) + return _section("GATEWAY SESSION CONTEXT", lines) + + @staticmethod + def _render_search(payload: dict[str, Any]) -> str | None: + return _section("GATEWAY SEARCH RESULTS", _compact_items(payload.get("items"))) + + +def _section(title: str, lines: list[str]) -> str | None: + cleaned = [line.strip() for line in lines if line and line.strip()] + if not cleaned: + return None + return f"{title}\n" + "\n".join(cleaned) + + +def _compact_mapping(payload: dict[str, Any]) -> list[str]: + result: list[str] = [] + for key in ("summary", "text", "content", "description"): + value = payload.get(key) + if value: + result.append(str(value)) + if not result and payload: + result.append(str(payload)) + return result + + +def _compact_items(items: Any) -> list[str]: + if not isinstance(items, list): + return [] + result: list[str] = [] + for item in items: + if not isinstance(item, dict): + if item: + result.append(f"- {item}") + continue + text = item.get("text") or item.get("summary") or item.get("content") + if not text: + text = str(item) + source = item.get("source_backend") + prefix = f"[{source}] " if source else "" + result.append(f"- {prefix}{text}") + return result diff --git a/app-instance/backend/beaver/memory/gateway/store.py b/app-instance/backend/beaver/memory/gateway/store.py new file mode 100644 index 0000000..acf54b1 --- /dev/null +++ b/app-instance/backend/beaver/memory/gateway/store.py @@ -0,0 +1,54 @@ +"""SQLite cache for Memory Gateway user credentials.""" + +from __future__ import annotations + +import sqlite3 +import time +from pathlib import Path + + +class MemoryGatewayUserStore: + """Persist `user_id -> user_key` mappings returned by Memory Gateway.""" + + def __init__(self, db_path: str | Path) -> None: + self.db_path = Path(db_path) + self.db_path.parent.mkdir(parents=True, exist_ok=True) + self._init_schema() + + def get_user_key(self, user_id: str) -> str | None: + with self._connect() as conn: + row = conn.execute( + "SELECT user_key FROM memory_gateway_users WHERE user_id = ?", + (user_id,), + ).fetchone() + return str(row[0]) if row else None + + def save_user_key(self, user_id: str, user_key: str) -> None: + now = time.time() + with self._connect() as conn: + conn.execute( + """ + INSERT INTO memory_gateway_users (user_id, user_key, created_at, updated_at) + VALUES (?, ?, ?, ?) + ON CONFLICT(user_id) DO UPDATE SET + user_key = excluded.user_key, + updated_at = excluded.updated_at + """, + (user_id, user_key, now, now), + ) + + def _init_schema(self) -> None: + with self._connect() as conn: + conn.execute( + """ + CREATE TABLE IF NOT EXISTS memory_gateway_users ( + user_id TEXT PRIMARY KEY, + user_key TEXT NOT NULL, + created_at REAL NOT NULL, + updated_at REAL NOT NULL + ) + """ + ) + + def _connect(self) -> sqlite3.Connection: + return sqlite3.connect(str(self.db_path)) diff --git a/app-instance/backend/beaver/services/memory_service.py b/app-instance/backend/beaver/services/memory_service.py index 91dd5b8..12c93b5 100644 --- a/app-instance/backend/beaver/services/memory_service.py +++ b/app-instance/backend/beaver/services/memory_service.py @@ -58,6 +58,17 @@ class MemoryService: store.load_from_disk() return capture_memory_snapshot(store) + async def capture_snapshot_for_run_async( + self, + *, + user_id: str | None = None, + session_id: str | None = None, + query: str | None = None, + ) -> MemorySnapshot: + """Async-compatible snapshot hook used by optional memory integrations.""" + + return self.capture_snapshot_for_run() + def get_snapshot(self) -> MemorySnapshot: """获取当前 run 应注入 system prompt 的 frozen snapshot。""" diff --git a/app-instance/backend/tests/unit/test_config_loader.py b/app-instance/backend/tests/unit/test_config_loader.py index 1f61cef..06cfd55 100644 --- a/app-instance/backend/tests/unit/test_config_loader.py +++ b/app-instance/backend/tests/unit/test_config_loader.py @@ -85,6 +85,49 @@ def test_config_loader_reads_channels(tmp_path) -> None: assert channel.secrets == {"ignored_for_status": "secret-value"} +def test_config_loader_reads_memory_gateway_config(tmp_path) -> None: + config_path = tmp_path / "config.json" + config_path.write_text( + json.dumps( + { + "memory": { + "mode": "gateway", + "gateway": { + "baseUrl": "http://127.0.0.1:1934", + "apiKey": "gateway-key", + "defaultUserId": "default-user", + "timeoutSeconds": 12, + "snapshotSearchLimit": 7, + "commitOnRunComplete": False, + }, + }, + } + ), + encoding="utf-8", + ) + + config = load_config(config_path=config_path) + + assert config.memory.mode == "gateway" + assert config.memory.gateway.base_url == "http://127.0.0.1:1934" + assert config.memory.gateway.api_key == "gateway-key" + assert config.memory.gateway.default_user_id == "default-user" + assert config.memory.gateway.timeout_seconds == 12 + assert config.memory.gateway.snapshot_search_limit == 7 + assert config.memory.gateway.commit_on_run_complete is False + + +def test_config_loader_defaults_to_local_memory(tmp_path) -> None: + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps({}), encoding="utf-8") + + config = load_config(config_path=config_path) + + assert config.memory.mode == "local" + assert config.memory.gateway.base_url == "" + assert config.memory.gateway.commit_on_run_complete is True + + def test_provider_resolution_ignores_custom_and_disabled_overrides(tmp_path) -> None: config_path = tmp_path / "config.json" config_path.write_text( diff --git a/app-instance/backend/tests/unit/test_memory_gateway_archive.py b/app-instance/backend/tests/unit/test_memory_gateway_archive.py new file mode 100644 index 0000000..c2c3817 --- /dev/null +++ b/app-instance/backend/tests/unit/test_memory_gateway_archive.py @@ -0,0 +1,130 @@ +import asyncio +from pathlib import Path +from types import SimpleNamespace + +import pytest + +from beaver.engine import AgentLoop, EngineLoader +from beaver.engine.providers.base import LLMProvider, LLMResponse +from beaver.engine.providers.factory import ProviderBundle +from beaver.memory.curated.snapshot import MemorySnapshot +from beaver.memory.curated.store import MemoryStore + + +class _RecordingProvider(LLMProvider): + def __init__(self, response_text: str = "done") -> None: + super().__init__() + self.response_text = response_text + self.messages: list[list[dict]] = [] + + async def chat( + self, + messages: list[dict], + tools: list[dict] | None = None, + model: str | None = None, + max_tokens: int = 4096, + temperature: float = 0.7, + thinking_enabled: bool | None = None, + ) -> LLMResponse: + self.messages.append(messages) + return LLMResponse( + content=self.response_text, + finish_reason="stop", + provider_name="stub", + model="stub-model", + ) + + def get_default_model(self) -> str: + return "stub-model" + + +class _FakeMemoryService: + def __init__(self, root: Path) -> None: + self.store = MemoryStore(root) + self.snapshot_calls: list[dict] = [] + self.archive_calls: list[dict] = [] + self.archive_started: asyncio.Event | None = None + self.archive_release: asyncio.Event | None = None + + def initialize(self) -> None: + self.store.load_from_disk() + + def get_store(self) -> MemoryStore: + return self.store + + async def capture_snapshot_for_run_async(self, **kwargs) -> MemorySnapshot: + self.snapshot_calls.append(kwargs) + return MemorySnapshot(memory_block="ASYNC SNAPSHOT FROM GATEWAY", user_block=None) + + async def archive_run_async(self, **kwargs) -> dict: + self.archive_calls.append(kwargs) + if self.archive_started is not None: + self.archive_started.set() + if self.archive_release is not None: + await self.archive_release.wait() + return {"status": "success"} + + +def _bundle(provider: _RecordingProvider) -> ProviderBundle: + return ProviderBundle( + main_runtime=SimpleNamespace(model="stub-model", provider_name="stub"), + main_provider=provider, + ) + + +@pytest.mark.asyncio +async def test_agent_loop_uses_async_memory_snapshot(tmp_path) -> None: + memory_service = _FakeMemoryService(tmp_path / "memory" / "curated") + provider = _RecordingProvider("final answer") + loop = AgentLoop(loader=EngineLoader(workspace=tmp_path, memory_service=memory_service)) + + result = await loop.process_direct( + "current task", + user_id="user-1", + session_id="session-1", + provider_bundle=_bundle(provider), + include_skill_assembly=False, + include_tools=False, + ) + + assert result.output_text == "final answer" + assert memory_service.snapshot_calls == [ + {"user_id": "user-1", "session_id": "session-1", "query": "current task"} + ] + assert "ASYNC SNAPSHOT FROM GATEWAY" in provider.messages[0][0]["content"] + loop.close() + + +@pytest.mark.asyncio +async def test_agent_loop_archives_memory_gateway_run_in_background(tmp_path) -> None: + memory_service = _FakeMemoryService(tmp_path / "memory" / "curated") + memory_service.archive_started = asyncio.Event() + memory_service.archive_release = asyncio.Event() + provider = _RecordingProvider("assistant final") + loop = AgentLoop(loader=EngineLoader(workspace=tmp_path, memory_service=memory_service)) + + result = await loop.process_direct( + "user asks", + user_id="user-1", + session_id="session-1", + provider_bundle=_bundle(provider), + include_skill_assembly=False, + include_tools=False, + ) + + assert result.finish_reason == "stop" + await asyncio.wait_for(memory_service.archive_started.wait(), timeout=1) + assert memory_service.archive_calls == [ + { + "user_id": "user-1", + "session_id": "session-1", + "user_message": "user asks", + "assistant_message": "assistant final", + } + ] + + memory_service.archive_release.set() + await asyncio.sleep(0) + events = loop.boot().session_manager.get_run_event_records(result.session_id, result.run_id) + assert any(event.event_type == "memory_gateway_archive_completed" for event in events) + loop.close() diff --git a/app-instance/backend/tests/unit/test_memory_gateway_client.py b/app-instance/backend/tests/unit/test_memory_gateway_client.py new file mode 100644 index 0000000..2a87fd3 --- /dev/null +++ b/app-instance/backend/tests/unit/test_memory_gateway_client.py @@ -0,0 +1,100 @@ +import json + +import httpx +import pytest + +from beaver.memory.gateway import MemoryGatewayClient, MemoryGatewayUserStore + + +@pytest.mark.asyncio +async def test_memory_gateway_client_uses_cached_user_key(tmp_path) -> None: + store = MemoryGatewayUserStore(tmp_path / "gateway.db") + store.save_user_key("user-1", "cached-key") + requests: list[httpx.Request] = [] + + async def handler(request: httpx.Request) -> httpx.Response: + requests.append(request) + return httpx.Response(500, json={"error": "should not be called"}) + + client = MemoryGatewayClient( + base_url="http://gateway.test", + store=store, + transport=httpx.MockTransport(handler), + ) + + user_key = await client.ensure_user("user-1") + + assert user_key == "cached-key" + assert requests == [] + + +@pytest.mark.asyncio +async def test_memory_gateway_client_creates_and_caches_user_key(tmp_path) -> None: + store = MemoryGatewayUserStore(tmp_path / "gateway.db") + requests: list[httpx.Request] = [] + + async def handler(request: httpx.Request) -> httpx.Response: + requests.append(request) + assert request.method == "POST" + assert request.url.path == "/memory-system/users" + assert json.loads(request.content) == {"user_id": "user-1"} + return httpx.Response( + 200, + json={ + "status": "success", + "account": { + "status": "ok", + "result": {"user_key": "created-key"}, + }, + }, + ) + + client = MemoryGatewayClient( + base_url="http://gateway.test", + store=store, + transport=httpx.MockTransport(handler), + ) + + user_key = await client.ensure_user("user-1") + + assert user_key == "created-key" + assert store.get_user_key("user-1") == "created-key" + assert len(requests) == 1 + + +@pytest.mark.asyncio +async def test_memory_gateway_client_ingests_messages_with_user_key(tmp_path) -> None: + store = MemoryGatewayUserStore(tmp_path / "gateway.db") + requests: list[httpx.Request] = [] + + async def handler(request: httpx.Request) -> httpx.Response: + requests.append(request) + assert request.method == "POST" + assert request.url.path == "/memory-system/messages" + assert request.headers["X-API-Key"] == "gateway-api-key" + assert json.loads(request.content) == { + "user_id": "user-1", + "user_key": "user-key", + "session_id": "session-1", + "user_message": "hello", + "assistant_message": "hi", + } + return httpx.Response(200, json={"status": "success", "message_count": 2}) + + client = MemoryGatewayClient( + base_url="http://gateway.test", + api_key="gateway-api-key", + store=store, + transport=httpx.MockTransport(handler), + ) + + result = await client.ingest_messages( + user_id="user-1", + user_key="user-key", + session_id="session-1", + user_message="hello", + assistant_message="hi", + ) + + assert result["status"] == "success" + assert len(requests) == 1 diff --git a/app-instance/backend/tests/unit/test_memory_gateway_snapshot.py b/app-instance/backend/tests/unit/test_memory_gateway_snapshot.py new file mode 100644 index 0000000..7088d7d --- /dev/null +++ b/app-instance/backend/tests/unit/test_memory_gateway_snapshot.py @@ -0,0 +1,117 @@ +import pytest + +from beaver.engine import EngineLoader +from beaver.foundation.config.schema import MemoryGatewayConfig +from beaver.memory.gateway.service import GatewayAugmentedMemoryService +from beaver.services.memory_service import MemoryService + + +class _FakeGatewayClient: + async def ensure_user(self, user_id: str) -> str: + return f"{user_id}-key" + + async def get_profile(self, **kwargs): + return { + "status": "success", + "profile": {"summary": "用户喜欢拿铁。"}, + "items": [{"summary": "用户偏好中文回复。"}], + } + + async def get_session_context(self, **kwargs): + return { + "status": "success", + "context": {"latest_archive_overview": "上次讨论了 memory gateway 接入。"}, + "items": [{"summary": "用户要求保留本地 MEMORY.md。"}], + } + + async def search(self, **kwargs): + return { + "status": "success", + "items": [ + { + "source_backend": "openviking", + "text": "需要异步写入 /memory-system/messages。", + } + ], + } + + +class _FailingGatewayClient(_FakeGatewayClient): + async def ensure_user(self, user_id: str) -> str: + raise RuntimeError("gateway unavailable") + + +@pytest.mark.asyncio +async def test_gateway_snapshot_keeps_local_memory_and_adds_gateway_sections(tmp_path) -> None: + local = MemoryService(tmp_path / "memory" / "curated") + local.initialize() + local.get_store().add("memory", "本地项目约定:默认用中文解释。") + local.get_store().add("user", "本地用户画像:用户关注记忆系统。") + service = GatewayAugmentedMemoryService( + local_service=local, + client=_FakeGatewayClient(), + config=MemoryGatewayConfig(snapshot_search_limit=3), + ) + + snapshot = await service.capture_snapshot_for_run_async( + user_id="user-1", + session_id="session-1", + query="如何接入 memory gateway", + ) + prompt = "\n".join(snapshot.as_prompt_sections()) + + assert "本地项目约定:默认用中文解释。" in prompt + assert "本地用户画像:用户关注记忆系统。" in prompt + assert "GATEWAY USER PROFILE" in prompt + assert "用户喜欢拿铁。" in prompt + assert "GATEWAY SESSION CONTEXT" in prompt + assert "上次讨论了 memory gateway 接入。" in prompt + assert "GATEWAY SEARCH RESULTS" in prompt + assert "需要异步写入 /memory-system/messages。" in prompt + + +@pytest.mark.asyncio +async def test_gateway_snapshot_falls_back_to_local_memory_on_gateway_failure(tmp_path) -> None: + local = MemoryService(tmp_path / "memory" / "curated") + local.initialize() + local.get_store().add("memory", "本地记忆仍然可用。") + service = GatewayAugmentedMemoryService( + local_service=local, + client=_FailingGatewayClient(), + config=MemoryGatewayConfig(snapshot_search_limit=3), + ) + + snapshot = await service.capture_snapshot_for_run_async( + user_id="user-1", + session_id="session-1", + query="任何问题", + ) + prompt = "\n".join(snapshot.as_prompt_sections()) + + assert "本地记忆仍然可用。" in prompt + assert "GATEWAY USER PROFILE" not in prompt + + +def test_engine_loader_uses_gateway_memory_service_without_replacing_tools(tmp_path) -> None: + config_path = tmp_path / "config.json" + config_path.write_text( + """ + { + "agents": {"defaults": {"workspace": "%s"}}, + "memory": { + "mode": "gateway", + "gateway": {"baseUrl": "http://gateway.test", "defaultUserId": "default-user"} + } + } + """ + % str(tmp_path / "workspace"), + encoding="utf-8", + ) + + loader = EngineLoader(config_path=config_path) + loaded = loader.load() + + assert isinstance(loaded.memory_service, GatewayAugmentedMemoryService) + assert "memory" in loaded.tools + assert "session_search" in loaded.tools + loaded.close()