feat: add Memory Gateway integration with async support for memory snapshots and user management

This commit is contained in:
2026-06-04 17:00:02 +08:00
parent 236ac19789
commit d93ca62990
13 changed files with 949 additions and 2 deletions

View File

@ -14,6 +14,8 @@ from beaver.engine.session import SessionManager
from beaver.foundation.config import BeaverConfig, load_config from beaver.foundation.config import BeaverConfig, load_config
from beaver.integrations.mcp import MCPConnectionManager from beaver.integrations.mcp import MCPConnectionManager
from beaver.memory.curated.store import MemoryStore from beaver.memory.curated.store import MemoryStore
from beaver.memory.gateway import MemoryGatewayClient, MemoryGatewayUserStore
from beaver.memory.gateway.service import GatewayAugmentedMemoryService
from beaver.memory.runs import RunMemoryStore from beaver.memory.runs import RunMemoryStore
from beaver.memory.skills import SkillLearningStore from beaver.memory.skills import SkillLearningStore
from beaver.services.memory_service import MemoryService from beaver.services.memory_service import MemoryService
@ -206,7 +208,26 @@ class EngineLoader:
curated_root = workspace / "memory" / "curated" curated_root = workspace / "memory" / "curated"
curated_memory_store = self._curated_memory_store or MemoryStore(curated_root) curated_memory_store = self._curated_memory_store or MemoryStore(curated_root)
memory_service = self._memory_service or MemoryService(curated_root, store=curated_memory_store) if self._memory_service is not None:
memory_service = self._memory_service
else:
local_memory_service = MemoryService(curated_root, store=curated_memory_store)
memory_cfg = self.config.memory
if memory_cfg.mode == "gateway" and memory_cfg.gateway.base_url:
gateway_store = MemoryGatewayUserStore(workspace / "memory" / "gateway" / "state.db")
gateway_client = MemoryGatewayClient(
base_url=memory_cfg.gateway.base_url,
api_key=memory_cfg.gateway.api_key,
timeout_seconds=memory_cfg.gateway.timeout_seconds,
store=gateway_store,
)
memory_service = GatewayAugmentedMemoryService(
local_service=local_memory_service,
client=gateway_client,
config=memory_cfg.gateway,
)
else:
memory_service = local_memory_service
memory_service.initialize() memory_service.initialize()
run_memory_store = self._run_memory_store or RunMemoryStore(workspace / "memory" / "runs") run_memory_store = self._run_memory_store or RunMemoryStore(workspace / "memory" / "runs")
skill_learning_store = self._skill_learning_store or SkillLearningStore(workspace / "memory" / "skills") skill_learning_store = self._skill_learning_store or SkillLearningStore(workspace / "memory" / "skills")

View File

@ -380,10 +380,15 @@ class AgentLoop:
resolved_max_tool_iterations = ( resolved_max_tool_iterations = (
self.profile.max_tool_iterations if max_tool_iterations is None else max_tool_iterations self.profile.max_tool_iterations if max_tool_iterations is None else max_tool_iterations
) )
resolved_memory_user_id = user_id or config.memory.gateway.default_user_id or None
# 每个 run 都捕获自己的 frozen snapshot不能依赖 MemoryService # 每个 run 都捕获自己的 frozen snapshot不能依赖 MemoryService
# 上的共享 `_snapshot`,否则 parallel team runs 会互相覆盖。 # 上的共享 `_snapshot`,否则 parallel team runs 会互相覆盖。
memory_snapshot = memory_service.capture_snapshot_for_run() memory_snapshot = await memory_service.capture_snapshot_for_run_async(
user_id=resolved_memory_user_id,
session_id=resolved_session_id,
query=task,
)
if parent_session_id: if parent_session_id:
session_manager.ensure_session( session_manager.ensure_session(
@ -834,6 +839,22 @@ class AgentLoop:
model=final_model, model=final_model,
user_id=user_id, user_id=user_id,
) )
archive_fn = getattr(memory_service, "archive_run_async", None)
if archive_fn is not None and resolved_memory_user_id:
asyncio.create_task(
self._archive_memory_gateway_run(
archive_fn=archive_fn,
session_manager=session_manager,
session_id=resolved_session_id,
run_id=resolved_run_id,
user_id=resolved_memory_user_id,
user_message=task,
assistant_message=final_text,
source=source,
title=title,
model=final_model,
)
)
self._record_run_receipts( self._record_run_receipts(
skill_learning_service=skill_learning_service, skill_learning_service=skill_learning_service,
session_manager=session_manager, session_manager=session_manager,
@ -1191,6 +1212,57 @@ class AgentLoop:
context_visible=False, context_visible=False,
) )
@staticmethod
async def _archive_memory_gateway_run(
*,
archive_fn: Any,
session_manager: Any,
session_id: str,
run_id: str,
user_id: str | None,
user_message: str,
assistant_message: str,
source: str,
title: str | None,
model: str | None,
) -> None:
try:
result = await archive_fn(
user_id=user_id,
session_id=session_id,
user_message=user_message,
assistant_message=assistant_message,
)
except Exception as exc: # noqa: BLE001 - archive must not change completed run result
session_manager.append_message(
session_id,
run_id=run_id,
role="system",
event_type="memory_gateway_archive_failed",
event_payload={"error": str(exc)},
content=f"Memory Gateway archive failed: {exc}",
context_visible=False,
source=source,
title=title,
model=model,
user_id=user_id,
)
return
session_manager.append_message(
session_id,
run_id=run_id,
role="system",
event_type="memory_gateway_archive_completed",
event_payload={"result": result},
content="Memory Gateway archive completed.",
context_visible=False,
source=source,
title=title,
model=model,
user_id=user_id,
)
@staticmethod @staticmethod
def _utc_now() -> str: def _utc_now() -> str:
return datetime.now(timezone.utc).isoformat() return datetime.now(timezone.utc).isoformat()

View File

@ -15,6 +15,8 @@ from .schema import (
BeaverConfig, BeaverConfig,
ChannelConfig, ChannelConfig,
EmbeddingConfig, EmbeddingConfig,
MemoryConfig,
MemoryGatewayConfig,
MCPServerConfig, MCPServerConfig,
ProviderConfig, ProviderConfig,
ToolsConfig, ToolsConfig,
@ -76,6 +78,7 @@ def load_config(
authz=_parse_authz(data.get("authz")), authz=_parse_authz(data.get("authz")),
channels=_parse_channels(data.get("channels")), channels=_parse_channels(data.get("channels")),
backend_identity=_parse_backend_identity(data.get("backend_identity") or data.get("backendIdentity")), backend_identity=_parse_backend_identity(data.get("backend_identity") or data.get("backendIdentity")),
memory=_parse_memory(data.get("memory")),
config_path=path, config_path=path,
) )
@ -251,6 +254,35 @@ def _parse_backend_identity(raw: Any) -> BackendIdentityConfig:
) )
def _parse_memory(raw: Any) -> MemoryConfig:
data = _as_dict(raw)
gateway = _as_dict(data.get("gateway"))
mode = (_string(data.get("mode")) or "local").lower()
if mode not in {"local", "gateway"}:
mode = "local"
return MemoryConfig(
mode=mode,
gateway=MemoryGatewayConfig(
base_url=_string(gateway.get("baseUrl") or gateway.get("base_url")) or "",
api_key=_string(gateway.get("apiKey") or gateway.get("api_key")) or "",
default_user_id=_string(gateway.get("defaultUserId") or gateway.get("default_user_id")) or "",
timeout_seconds=_float(gateway.get("timeoutSeconds") or gateway.get("timeout_seconds")) or 15.0,
snapshot_search_limit=int(
_float(gateway.get("snapshotSearchLimit") or gateway.get("snapshot_search_limit"))
or 5
),
commit_on_run_complete=_bool(
(
gateway.get("commitOnRunComplete")
if "commitOnRunComplete" in gateway
else gateway.get("commit_on_run_complete")
),
default=True,
),
),
)
def _as_dict(value: Any) -> dict[str, Any]: def _as_dict(value: Any) -> dict[str, Any]:
return value if isinstance(value, dict) else {} return value if isinstance(value, dict) else {}

View File

@ -115,6 +115,26 @@ class BackendIdentityConfig:
public_base_url: str = "" public_base_url: str = ""
@dataclass(slots=True)
class MemoryGatewayConfig:
"""Memory Gateway integration settings."""
base_url: str = ""
api_key: str = ""
default_user_id: str = ""
timeout_seconds: float = 15.0
snapshot_search_limit: int = 5
commit_on_run_complete: bool = True
@dataclass(slots=True)
class MemoryConfig:
"""Runtime memory strategy configuration."""
mode: str = "local"
gateway: MemoryGatewayConfig = field(default_factory=MemoryGatewayConfig)
@dataclass(slots=True) @dataclass(slots=True)
class BeaverConfig: class BeaverConfig:
"""Config loaded once per backend sandbox instance.""" """Config loaded once per backend sandbox instance."""
@ -126,6 +146,7 @@ class BeaverConfig:
authz: AuthzConfig = field(default_factory=AuthzConfig) authz: AuthzConfig = field(default_factory=AuthzConfig)
channels: dict[str, ChannelConfig] = field(default_factory=dict) channels: dict[str, ChannelConfig] = field(default_factory=dict)
backend_identity: BackendIdentityConfig = field(default_factory=BackendIdentityConfig) backend_identity: BackendIdentityConfig = field(default_factory=BackendIdentityConfig)
memory: MemoryConfig = field(default_factory=MemoryConfig)
config_path: Path | None = None config_path: Path | None = None
@property @property

View File

@ -0,0 +1,6 @@
"""Memory Gateway integration for Beaver memory runtime."""
from .client import MemoryGatewayClient
from .store import MemoryGatewayUserStore
__all__ = ["MemoryGatewayClient", "MemoryGatewayUserStore"]

View File

@ -0,0 +1,150 @@
"""HTTP client for Memory Gateway's `/memory-system` API."""
from __future__ import annotations
from typing import Any
import httpx
from .store import MemoryGatewayUserStore
class MemoryGatewayClient:
"""Small async client for the Memory Gateway business API."""
def __init__(
self,
*,
base_url: str,
store: MemoryGatewayUserStore,
api_key: str = "",
timeout_seconds: float = 15.0,
transport: httpx.AsyncBaseTransport | None = None,
) -> None:
self.base_url = base_url.rstrip("/")
self.api_key = api_key
self.timeout_seconds = timeout_seconds
self.store = store
self.transport = transport
async def ensure_user(self, user_id: str) -> str:
cached = self.store.get_user_key(user_id)
if cached:
return cached
data = await self._post("/memory-system/users", {"user_id": user_id})
user_key = self._extract_user_key(data)
if not user_key:
raise RuntimeError("Memory Gateway user creation response missing user_key")
self.store.save_user_key(user_id, user_key)
return user_key
async def get_profile(
self,
*,
user_id: str,
user_key: str,
query: str = "用户画像",
limit: int = 5,
) -> dict[str, Any]:
return await self._get(
f"/memory-system/users/{user_id}/profile",
{"user_key": user_key, "query": query, "limit": limit},
)
async def get_session_context(
self,
*,
user_id: str,
user_key: str,
session_id: str,
query: str,
limit: int = 5,
) -> dict[str, Any]:
return await self._post(
f"/memory-system/sessions/{session_id}/context",
{"user_id": user_id, "user_key": user_key, "query": query, "limit": limit},
)
async def search(
self,
*,
user_id: str,
user_key: str,
session_id: str,
query: str,
limit: int = 5,
) -> dict[str, Any]:
return await self._post(
"/memory-system/search",
{
"user_id": user_id,
"user_key": user_key,
"session_id": session_id,
"query": query,
"use_llm": False,
"limit": limit,
},
)
async def ingest_messages(
self,
*,
user_id: str,
user_key: str,
session_id: str,
user_message: str | None,
assistant_message: str | None,
) -> dict[str, Any]:
return await self._post(
"/memory-system/messages",
{
"user_id": user_id,
"user_key": user_key,
"session_id": session_id,
"user_message": user_message,
"assistant_message": assistant_message,
},
)
async def commit_session(
self,
*,
user_id: str,
user_key: str,
session_id: str,
) -> dict[str, Any]:
return await self._post(
f"/memory-system/sessions/{session_id}/commit",
{"user_id": user_id, "user_key": user_key},
)
async def _get(self, path: str, params: dict[str, Any]) -> dict[str, Any]:
async with self._client() as client:
response = await client.get(path, params=params)
response.raise_for_status()
return response.json()
async def _post(self, path: str, payload: dict[str, Any]) -> dict[str, Any]:
async with self._client() as client:
response = await client.post(path, json=payload)
response.raise_for_status()
return response.json()
def _client(self) -> httpx.AsyncClient:
headers = {"Content-Type": "application/json"}
if self.api_key:
headers["X-API-Key"] = self.api_key
return httpx.AsyncClient(
base_url=self.base_url,
headers=headers,
timeout=self.timeout_seconds,
transport=self.transport,
)
@staticmethod
def _extract_user_key(data: dict[str, Any]) -> str | None:
account = data.get("account")
result = account.get("result") if isinstance(account, dict) else None
user_key = result.get("user_key") if isinstance(result, dict) else None
return str(user_key) if user_key else None

View File

@ -0,0 +1,190 @@
"""Memory service wrapper that augments local snapshots with Memory Gateway recall."""
from __future__ import annotations
import asyncio
import logging
from typing import Any
from beaver.foundation.config.schema import MemoryGatewayConfig
from beaver.memory.curated.snapshot import MemorySnapshot
from beaver.services.memory_service import MemoryService
logger = logging.getLogger(__name__)
class GatewayAugmentedMemoryService:
"""Keep local curated memory and add Memory Gateway recall when available."""
def __init__(
self,
*,
local_service: MemoryService,
client: Any,
config: MemoryGatewayConfig,
) -> None:
self.local_service = local_service
self.client = client
self.config = config
def initialize(self) -> None:
self.local_service.initialize()
def get_store(self) -> Any:
return self.local_service.get_store()
def get_snapshot(self) -> MemorySnapshot:
return self.local_service.get_snapshot()
def capture_snapshot_for_run(self) -> MemorySnapshot:
return self.local_service.capture_snapshot_for_run()
async def capture_snapshot_for_run_async(
self,
*,
user_id: str | None = None,
session_id: str | None = None,
query: str | None = None,
) -> MemorySnapshot:
local_snapshot = self.local_service.capture_snapshot_for_run()
if not user_id or not session_id or not query:
return local_snapshot
try:
user_key = await self.client.ensure_user(user_id)
limit = max(1, self.config.snapshot_search_limit)
profile, session_context, search = await asyncio.gather(
self.client.get_profile(
user_id=user_id,
user_key=user_key,
query="用户画像",
limit=limit,
),
self.client.get_session_context(
user_id=user_id,
user_key=user_key,
session_id=session_id,
query=query,
limit=limit,
),
self.client.search(
user_id=user_id,
user_key=user_key,
session_id=session_id,
query=query,
limit=limit,
),
)
except Exception as exc: # noqa: BLE001 - gateway recall must not break local memory
logger.warning("Memory Gateway snapshot augmentation failed: %s", exc, exc_info=True)
return local_snapshot
user_block = self._join_blocks(
local_snapshot.user_block,
self._render_profile(profile),
)
memory_block = self._join_blocks(
local_snapshot.memory_block,
self._render_session_context(session_context),
self._render_search(search),
)
return MemorySnapshot(memory_block=memory_block, user_block=user_block)
async def archive_run_async(
self,
*,
user_id: str | None,
session_id: str,
user_message: str,
assistant_message: str,
) -> dict[str, Any]:
if not user_id:
return {"success": False, "error": "user_id is required for Memory Gateway archive."}
user_key = await self.client.ensure_user(user_id)
result: dict[str, Any] = {
"messages": await self.client.ingest_messages(
user_id=user_id,
user_key=user_key,
session_id=session_id,
user_message=user_message,
assistant_message=assistant_message,
)
}
if self.config.commit_on_run_complete:
result["commit"] = await self.client.commit_session(
user_id=user_id,
user_key=user_key,
session_id=session_id,
)
return result
@staticmethod
def _join_blocks(*blocks: str | None) -> str | None:
joined = "\n\n".join(block.strip() for block in blocks if block and block.strip())
return joined or None
@staticmethod
def _render_profile(payload: dict[str, Any]) -> str | None:
lines: list[str] = []
profile = payload.get("profile")
if isinstance(profile, dict):
lines.extend(_compact_mapping(profile))
elif profile:
lines.append(str(profile))
lines.extend(_compact_items(payload.get("items")))
return _section("GATEWAY USER PROFILE", lines)
@staticmethod
def _render_session_context(payload: dict[str, Any]) -> str | None:
lines: list[str] = []
context = payload.get("context")
if isinstance(context, dict):
overview = context.get("latest_archive_overview")
if overview:
lines.append(str(overview))
messages = context.get("messages")
if isinstance(messages, list):
lines.extend(_compact_items(messages))
lines.extend(_compact_items(payload.get("items")))
return _section("GATEWAY SESSION CONTEXT", lines)
@staticmethod
def _render_search(payload: dict[str, Any]) -> str | None:
return _section("GATEWAY SEARCH RESULTS", _compact_items(payload.get("items")))
def _section(title: str, lines: list[str]) -> str | None:
cleaned = [line.strip() for line in lines if line and line.strip()]
if not cleaned:
return None
return f"{title}\n" + "\n".join(cleaned)
def _compact_mapping(payload: dict[str, Any]) -> list[str]:
result: list[str] = []
for key in ("summary", "text", "content", "description"):
value = payload.get(key)
if value:
result.append(str(value))
if not result and payload:
result.append(str(payload))
return result
def _compact_items(items: Any) -> list[str]:
if not isinstance(items, list):
return []
result: list[str] = []
for item in items:
if not isinstance(item, dict):
if item:
result.append(f"- {item}")
continue
text = item.get("text") or item.get("summary") or item.get("content")
if not text:
text = str(item)
source = item.get("source_backend")
prefix = f"[{source}] " if source else ""
result.append(f"- {prefix}{text}")
return result

View File

@ -0,0 +1,54 @@
"""SQLite cache for Memory Gateway user credentials."""
from __future__ import annotations
import sqlite3
import time
from pathlib import Path
class MemoryGatewayUserStore:
"""Persist `user_id -> user_key` mappings returned by Memory Gateway."""
def __init__(self, db_path: str | Path) -> None:
self.db_path = Path(db_path)
self.db_path.parent.mkdir(parents=True, exist_ok=True)
self._init_schema()
def get_user_key(self, user_id: str) -> str | None:
with self._connect() as conn:
row = conn.execute(
"SELECT user_key FROM memory_gateway_users WHERE user_id = ?",
(user_id,),
).fetchone()
return str(row[0]) if row else None
def save_user_key(self, user_id: str, user_key: str) -> None:
now = time.time()
with self._connect() as conn:
conn.execute(
"""
INSERT INTO memory_gateway_users (user_id, user_key, created_at, updated_at)
VALUES (?, ?, ?, ?)
ON CONFLICT(user_id) DO UPDATE SET
user_key = excluded.user_key,
updated_at = excluded.updated_at
""",
(user_id, user_key, now, now),
)
def _init_schema(self) -> None:
with self._connect() as conn:
conn.execute(
"""
CREATE TABLE IF NOT EXISTS memory_gateway_users (
user_id TEXT PRIMARY KEY,
user_key TEXT NOT NULL,
created_at REAL NOT NULL,
updated_at REAL NOT NULL
)
"""
)
def _connect(self) -> sqlite3.Connection:
return sqlite3.connect(str(self.db_path))

View File

@ -58,6 +58,17 @@ class MemoryService:
store.load_from_disk() store.load_from_disk()
return capture_memory_snapshot(store) return capture_memory_snapshot(store)
async def capture_snapshot_for_run_async(
self,
*,
user_id: str | None = None,
session_id: str | None = None,
query: str | None = None,
) -> MemorySnapshot:
"""Async-compatible snapshot hook used by optional memory integrations."""
return self.capture_snapshot_for_run()
def get_snapshot(self) -> MemorySnapshot: def get_snapshot(self) -> MemorySnapshot:
"""获取当前 run 应注入 system prompt 的 frozen snapshot。""" """获取当前 run 应注入 system prompt 的 frozen snapshot。"""

View File

@ -85,6 +85,49 @@ def test_config_loader_reads_channels(tmp_path) -> None:
assert channel.secrets == {"ignored_for_status": "secret-value"} assert channel.secrets == {"ignored_for_status": "secret-value"}
def test_config_loader_reads_memory_gateway_config(tmp_path) -> None:
config_path = tmp_path / "config.json"
config_path.write_text(
json.dumps(
{
"memory": {
"mode": "gateway",
"gateway": {
"baseUrl": "http://127.0.0.1:1934",
"apiKey": "gateway-key",
"defaultUserId": "default-user",
"timeoutSeconds": 12,
"snapshotSearchLimit": 7,
"commitOnRunComplete": False,
},
},
}
),
encoding="utf-8",
)
config = load_config(config_path=config_path)
assert config.memory.mode == "gateway"
assert config.memory.gateway.base_url == "http://127.0.0.1:1934"
assert config.memory.gateway.api_key == "gateway-key"
assert config.memory.gateway.default_user_id == "default-user"
assert config.memory.gateway.timeout_seconds == 12
assert config.memory.gateway.snapshot_search_limit == 7
assert config.memory.gateway.commit_on_run_complete is False
def test_config_loader_defaults_to_local_memory(tmp_path) -> None:
config_path = tmp_path / "config.json"
config_path.write_text(json.dumps({}), encoding="utf-8")
config = load_config(config_path=config_path)
assert config.memory.mode == "local"
assert config.memory.gateway.base_url == ""
assert config.memory.gateway.commit_on_run_complete is True
def test_provider_resolution_ignores_custom_and_disabled_overrides(tmp_path) -> None: def test_provider_resolution_ignores_custom_and_disabled_overrides(tmp_path) -> None:
config_path = tmp_path / "config.json" config_path = tmp_path / "config.json"
config_path.write_text( config_path.write_text(

View File

@ -0,0 +1,130 @@
import asyncio
from pathlib import Path
from types import SimpleNamespace
import pytest
from beaver.engine import AgentLoop, EngineLoader
from beaver.engine.providers.base import LLMProvider, LLMResponse
from beaver.engine.providers.factory import ProviderBundle
from beaver.memory.curated.snapshot import MemorySnapshot
from beaver.memory.curated.store import MemoryStore
class _RecordingProvider(LLMProvider):
def __init__(self, response_text: str = "done") -> None:
super().__init__()
self.response_text = response_text
self.messages: list[list[dict]] = []
async def chat(
self,
messages: list[dict],
tools: list[dict] | None = None,
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
thinking_enabled: bool | None = None,
) -> LLMResponse:
self.messages.append(messages)
return LLMResponse(
content=self.response_text,
finish_reason="stop",
provider_name="stub",
model="stub-model",
)
def get_default_model(self) -> str:
return "stub-model"
class _FakeMemoryService:
def __init__(self, root: Path) -> None:
self.store = MemoryStore(root)
self.snapshot_calls: list[dict] = []
self.archive_calls: list[dict] = []
self.archive_started: asyncio.Event | None = None
self.archive_release: asyncio.Event | None = None
def initialize(self) -> None:
self.store.load_from_disk()
def get_store(self) -> MemoryStore:
return self.store
async def capture_snapshot_for_run_async(self, **kwargs) -> MemorySnapshot:
self.snapshot_calls.append(kwargs)
return MemorySnapshot(memory_block="ASYNC SNAPSHOT FROM GATEWAY", user_block=None)
async def archive_run_async(self, **kwargs) -> dict:
self.archive_calls.append(kwargs)
if self.archive_started is not None:
self.archive_started.set()
if self.archive_release is not None:
await self.archive_release.wait()
return {"status": "success"}
def _bundle(provider: _RecordingProvider) -> ProviderBundle:
return ProviderBundle(
main_runtime=SimpleNamespace(model="stub-model", provider_name="stub"),
main_provider=provider,
)
@pytest.mark.asyncio
async def test_agent_loop_uses_async_memory_snapshot(tmp_path) -> None:
memory_service = _FakeMemoryService(tmp_path / "memory" / "curated")
provider = _RecordingProvider("final answer")
loop = AgentLoop(loader=EngineLoader(workspace=tmp_path, memory_service=memory_service))
result = await loop.process_direct(
"current task",
user_id="user-1",
session_id="session-1",
provider_bundle=_bundle(provider),
include_skill_assembly=False,
include_tools=False,
)
assert result.output_text == "final answer"
assert memory_service.snapshot_calls == [
{"user_id": "user-1", "session_id": "session-1", "query": "current task"}
]
assert "ASYNC SNAPSHOT FROM GATEWAY" in provider.messages[0][0]["content"]
loop.close()
@pytest.mark.asyncio
async def test_agent_loop_archives_memory_gateway_run_in_background(tmp_path) -> None:
memory_service = _FakeMemoryService(tmp_path / "memory" / "curated")
memory_service.archive_started = asyncio.Event()
memory_service.archive_release = asyncio.Event()
provider = _RecordingProvider("assistant final")
loop = AgentLoop(loader=EngineLoader(workspace=tmp_path, memory_service=memory_service))
result = await loop.process_direct(
"user asks",
user_id="user-1",
session_id="session-1",
provider_bundle=_bundle(provider),
include_skill_assembly=False,
include_tools=False,
)
assert result.finish_reason == "stop"
await asyncio.wait_for(memory_service.archive_started.wait(), timeout=1)
assert memory_service.archive_calls == [
{
"user_id": "user-1",
"session_id": "session-1",
"user_message": "user asks",
"assistant_message": "assistant final",
}
]
memory_service.archive_release.set()
await asyncio.sleep(0)
events = loop.boot().session_manager.get_run_event_records(result.session_id, result.run_id)
assert any(event.event_type == "memory_gateway_archive_completed" for event in events)
loop.close()

View File

@ -0,0 +1,100 @@
import json
import httpx
import pytest
from beaver.memory.gateway import MemoryGatewayClient, MemoryGatewayUserStore
@pytest.mark.asyncio
async def test_memory_gateway_client_uses_cached_user_key(tmp_path) -> None:
store = MemoryGatewayUserStore(tmp_path / "gateway.db")
store.save_user_key("user-1", "cached-key")
requests: list[httpx.Request] = []
async def handler(request: httpx.Request) -> httpx.Response:
requests.append(request)
return httpx.Response(500, json={"error": "should not be called"})
client = MemoryGatewayClient(
base_url="http://gateway.test",
store=store,
transport=httpx.MockTransport(handler),
)
user_key = await client.ensure_user("user-1")
assert user_key == "cached-key"
assert requests == []
@pytest.mark.asyncio
async def test_memory_gateway_client_creates_and_caches_user_key(tmp_path) -> None:
store = MemoryGatewayUserStore(tmp_path / "gateway.db")
requests: list[httpx.Request] = []
async def handler(request: httpx.Request) -> httpx.Response:
requests.append(request)
assert request.method == "POST"
assert request.url.path == "/memory-system/users"
assert json.loads(request.content) == {"user_id": "user-1"}
return httpx.Response(
200,
json={
"status": "success",
"account": {
"status": "ok",
"result": {"user_key": "created-key"},
},
},
)
client = MemoryGatewayClient(
base_url="http://gateway.test",
store=store,
transport=httpx.MockTransport(handler),
)
user_key = await client.ensure_user("user-1")
assert user_key == "created-key"
assert store.get_user_key("user-1") == "created-key"
assert len(requests) == 1
@pytest.mark.asyncio
async def test_memory_gateway_client_ingests_messages_with_user_key(tmp_path) -> None:
store = MemoryGatewayUserStore(tmp_path / "gateway.db")
requests: list[httpx.Request] = []
async def handler(request: httpx.Request) -> httpx.Response:
requests.append(request)
assert request.method == "POST"
assert request.url.path == "/memory-system/messages"
assert request.headers["X-API-Key"] == "gateway-api-key"
assert json.loads(request.content) == {
"user_id": "user-1",
"user_key": "user-key",
"session_id": "session-1",
"user_message": "hello",
"assistant_message": "hi",
}
return httpx.Response(200, json={"status": "success", "message_count": 2})
client = MemoryGatewayClient(
base_url="http://gateway.test",
api_key="gateway-api-key",
store=store,
transport=httpx.MockTransport(handler),
)
result = await client.ingest_messages(
user_id="user-1",
user_key="user-key",
session_id="session-1",
user_message="hello",
assistant_message="hi",
)
assert result["status"] == "success"
assert len(requests) == 1

View File

@ -0,0 +1,117 @@
import pytest
from beaver.engine import EngineLoader
from beaver.foundation.config.schema import MemoryGatewayConfig
from beaver.memory.gateway.service import GatewayAugmentedMemoryService
from beaver.services.memory_service import MemoryService
class _FakeGatewayClient:
async def ensure_user(self, user_id: str) -> str:
return f"{user_id}-key"
async def get_profile(self, **kwargs):
return {
"status": "success",
"profile": {"summary": "用户喜欢拿铁。"},
"items": [{"summary": "用户偏好中文回复。"}],
}
async def get_session_context(self, **kwargs):
return {
"status": "success",
"context": {"latest_archive_overview": "上次讨论了 memory gateway 接入。"},
"items": [{"summary": "用户要求保留本地 MEMORY.md。"}],
}
async def search(self, **kwargs):
return {
"status": "success",
"items": [
{
"source_backend": "openviking",
"text": "需要异步写入 /memory-system/messages。",
}
],
}
class _FailingGatewayClient(_FakeGatewayClient):
async def ensure_user(self, user_id: str) -> str:
raise RuntimeError("gateway unavailable")
@pytest.mark.asyncio
async def test_gateway_snapshot_keeps_local_memory_and_adds_gateway_sections(tmp_path) -> None:
local = MemoryService(tmp_path / "memory" / "curated")
local.initialize()
local.get_store().add("memory", "本地项目约定:默认用中文解释。")
local.get_store().add("user", "本地用户画像:用户关注记忆系统。")
service = GatewayAugmentedMemoryService(
local_service=local,
client=_FakeGatewayClient(),
config=MemoryGatewayConfig(snapshot_search_limit=3),
)
snapshot = await service.capture_snapshot_for_run_async(
user_id="user-1",
session_id="session-1",
query="如何接入 memory gateway",
)
prompt = "\n".join(snapshot.as_prompt_sections())
assert "本地项目约定:默认用中文解释。" in prompt
assert "本地用户画像:用户关注记忆系统。" in prompt
assert "GATEWAY USER PROFILE" in prompt
assert "用户喜欢拿铁。" in prompt
assert "GATEWAY SESSION CONTEXT" in prompt
assert "上次讨论了 memory gateway 接入。" in prompt
assert "GATEWAY SEARCH RESULTS" in prompt
assert "需要异步写入 /memory-system/messages。" in prompt
@pytest.mark.asyncio
async def test_gateway_snapshot_falls_back_to_local_memory_on_gateway_failure(tmp_path) -> None:
local = MemoryService(tmp_path / "memory" / "curated")
local.initialize()
local.get_store().add("memory", "本地记忆仍然可用。")
service = GatewayAugmentedMemoryService(
local_service=local,
client=_FailingGatewayClient(),
config=MemoryGatewayConfig(snapshot_search_limit=3),
)
snapshot = await service.capture_snapshot_for_run_async(
user_id="user-1",
session_id="session-1",
query="任何问题",
)
prompt = "\n".join(snapshot.as_prompt_sections())
assert "本地记忆仍然可用。" in prompt
assert "GATEWAY USER PROFILE" not in prompt
def test_engine_loader_uses_gateway_memory_service_without_replacing_tools(tmp_path) -> None:
config_path = tmp_path / "config.json"
config_path.write_text(
"""
{
"agents": {"defaults": {"workspace": "%s"}},
"memory": {
"mode": "gateway",
"gateway": {"baseUrl": "http://gateway.test", "defaultUserId": "default-user"}
}
}
"""
% str(tmp_path / "workspace"),
encoding="utf-8",
)
loader = EngineLoader(config_path=config_path)
loaded = loader.load()
assert isinstance(loaded.memory_service, GatewayAugmentedMemoryService)
assert "memory" in loaded.tools
assert "session_search" in loaded.tools
loaded.close()