feat: add Memory Gateway integration with async support for memory snapshots and user management

This commit is contained in:
2026-06-04 17:00:02 +08:00
parent 236ac19789
commit d93ca62990
13 changed files with 949 additions and 2 deletions

View File

@ -85,6 +85,49 @@ def test_config_loader_reads_channels(tmp_path) -> None:
assert channel.secrets == {"ignored_for_status": "secret-value"}
def test_config_loader_reads_memory_gateway_config(tmp_path) -> None:
config_path = tmp_path / "config.json"
config_path.write_text(
json.dumps(
{
"memory": {
"mode": "gateway",
"gateway": {
"baseUrl": "http://127.0.0.1:1934",
"apiKey": "gateway-key",
"defaultUserId": "default-user",
"timeoutSeconds": 12,
"snapshotSearchLimit": 7,
"commitOnRunComplete": False,
},
},
}
),
encoding="utf-8",
)
config = load_config(config_path=config_path)
assert config.memory.mode == "gateway"
assert config.memory.gateway.base_url == "http://127.0.0.1:1934"
assert config.memory.gateway.api_key == "gateway-key"
assert config.memory.gateway.default_user_id == "default-user"
assert config.memory.gateway.timeout_seconds == 12
assert config.memory.gateway.snapshot_search_limit == 7
assert config.memory.gateway.commit_on_run_complete is False
def test_config_loader_defaults_to_local_memory(tmp_path) -> None:
config_path = tmp_path / "config.json"
config_path.write_text(json.dumps({}), encoding="utf-8")
config = load_config(config_path=config_path)
assert config.memory.mode == "local"
assert config.memory.gateway.base_url == ""
assert config.memory.gateway.commit_on_run_complete is True
def test_provider_resolution_ignores_custom_and_disabled_overrides(tmp_path) -> None:
config_path = tmp_path / "config.json"
config_path.write_text(

View File

@ -0,0 +1,130 @@
import asyncio
from pathlib import Path
from types import SimpleNamespace
import pytest
from beaver.engine import AgentLoop, EngineLoader
from beaver.engine.providers.base import LLMProvider, LLMResponse
from beaver.engine.providers.factory import ProviderBundle
from beaver.memory.curated.snapshot import MemorySnapshot
from beaver.memory.curated.store import MemoryStore
class _RecordingProvider(LLMProvider):
def __init__(self, response_text: str = "done") -> None:
super().__init__()
self.response_text = response_text
self.messages: list[list[dict]] = []
async def chat(
self,
messages: list[dict],
tools: list[dict] | None = None,
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
thinking_enabled: bool | None = None,
) -> LLMResponse:
self.messages.append(messages)
return LLMResponse(
content=self.response_text,
finish_reason="stop",
provider_name="stub",
model="stub-model",
)
def get_default_model(self) -> str:
return "stub-model"
class _FakeMemoryService:
def __init__(self, root: Path) -> None:
self.store = MemoryStore(root)
self.snapshot_calls: list[dict] = []
self.archive_calls: list[dict] = []
self.archive_started: asyncio.Event | None = None
self.archive_release: asyncio.Event | None = None
def initialize(self) -> None:
self.store.load_from_disk()
def get_store(self) -> MemoryStore:
return self.store
async def capture_snapshot_for_run_async(self, **kwargs) -> MemorySnapshot:
self.snapshot_calls.append(kwargs)
return MemorySnapshot(memory_block="ASYNC SNAPSHOT FROM GATEWAY", user_block=None)
async def archive_run_async(self, **kwargs) -> dict:
self.archive_calls.append(kwargs)
if self.archive_started is not None:
self.archive_started.set()
if self.archive_release is not None:
await self.archive_release.wait()
return {"status": "success"}
def _bundle(provider: _RecordingProvider) -> ProviderBundle:
return ProviderBundle(
main_runtime=SimpleNamespace(model="stub-model", provider_name="stub"),
main_provider=provider,
)
@pytest.mark.asyncio
async def test_agent_loop_uses_async_memory_snapshot(tmp_path) -> None:
memory_service = _FakeMemoryService(tmp_path / "memory" / "curated")
provider = _RecordingProvider("final answer")
loop = AgentLoop(loader=EngineLoader(workspace=tmp_path, memory_service=memory_service))
result = await loop.process_direct(
"current task",
user_id="user-1",
session_id="session-1",
provider_bundle=_bundle(provider),
include_skill_assembly=False,
include_tools=False,
)
assert result.output_text == "final answer"
assert memory_service.snapshot_calls == [
{"user_id": "user-1", "session_id": "session-1", "query": "current task"}
]
assert "ASYNC SNAPSHOT FROM GATEWAY" in provider.messages[0][0]["content"]
loop.close()
@pytest.mark.asyncio
async def test_agent_loop_archives_memory_gateway_run_in_background(tmp_path) -> None:
memory_service = _FakeMemoryService(tmp_path / "memory" / "curated")
memory_service.archive_started = asyncio.Event()
memory_service.archive_release = asyncio.Event()
provider = _RecordingProvider("assistant final")
loop = AgentLoop(loader=EngineLoader(workspace=tmp_path, memory_service=memory_service))
result = await loop.process_direct(
"user asks",
user_id="user-1",
session_id="session-1",
provider_bundle=_bundle(provider),
include_skill_assembly=False,
include_tools=False,
)
assert result.finish_reason == "stop"
await asyncio.wait_for(memory_service.archive_started.wait(), timeout=1)
assert memory_service.archive_calls == [
{
"user_id": "user-1",
"session_id": "session-1",
"user_message": "user asks",
"assistant_message": "assistant final",
}
]
memory_service.archive_release.set()
await asyncio.sleep(0)
events = loop.boot().session_manager.get_run_event_records(result.session_id, result.run_id)
assert any(event.event_type == "memory_gateway_archive_completed" for event in events)
loop.close()

View File

@ -0,0 +1,100 @@
import json
import httpx
import pytest
from beaver.memory.gateway import MemoryGatewayClient, MemoryGatewayUserStore
@pytest.mark.asyncio
async def test_memory_gateway_client_uses_cached_user_key(tmp_path) -> None:
store = MemoryGatewayUserStore(tmp_path / "gateway.db")
store.save_user_key("user-1", "cached-key")
requests: list[httpx.Request] = []
async def handler(request: httpx.Request) -> httpx.Response:
requests.append(request)
return httpx.Response(500, json={"error": "should not be called"})
client = MemoryGatewayClient(
base_url="http://gateway.test",
store=store,
transport=httpx.MockTransport(handler),
)
user_key = await client.ensure_user("user-1")
assert user_key == "cached-key"
assert requests == []
@pytest.mark.asyncio
async def test_memory_gateway_client_creates_and_caches_user_key(tmp_path) -> None:
store = MemoryGatewayUserStore(tmp_path / "gateway.db")
requests: list[httpx.Request] = []
async def handler(request: httpx.Request) -> httpx.Response:
requests.append(request)
assert request.method == "POST"
assert request.url.path == "/memory-system/users"
assert json.loads(request.content) == {"user_id": "user-1"}
return httpx.Response(
200,
json={
"status": "success",
"account": {
"status": "ok",
"result": {"user_key": "created-key"},
},
},
)
client = MemoryGatewayClient(
base_url="http://gateway.test",
store=store,
transport=httpx.MockTransport(handler),
)
user_key = await client.ensure_user("user-1")
assert user_key == "created-key"
assert store.get_user_key("user-1") == "created-key"
assert len(requests) == 1
@pytest.mark.asyncio
async def test_memory_gateway_client_ingests_messages_with_user_key(tmp_path) -> None:
store = MemoryGatewayUserStore(tmp_path / "gateway.db")
requests: list[httpx.Request] = []
async def handler(request: httpx.Request) -> httpx.Response:
requests.append(request)
assert request.method == "POST"
assert request.url.path == "/memory-system/messages"
assert request.headers["X-API-Key"] == "gateway-api-key"
assert json.loads(request.content) == {
"user_id": "user-1",
"user_key": "user-key",
"session_id": "session-1",
"user_message": "hello",
"assistant_message": "hi",
}
return httpx.Response(200, json={"status": "success", "message_count": 2})
client = MemoryGatewayClient(
base_url="http://gateway.test",
api_key="gateway-api-key",
store=store,
transport=httpx.MockTransport(handler),
)
result = await client.ingest_messages(
user_id="user-1",
user_key="user-key",
session_id="session-1",
user_message="hello",
assistant_message="hi",
)
assert result["status"] == "success"
assert len(requests) == 1

View File

@ -0,0 +1,117 @@
import pytest
from beaver.engine import EngineLoader
from beaver.foundation.config.schema import MemoryGatewayConfig
from beaver.memory.gateway.service import GatewayAugmentedMemoryService
from beaver.services.memory_service import MemoryService
class _FakeGatewayClient:
async def ensure_user(self, user_id: str) -> str:
return f"{user_id}-key"
async def get_profile(self, **kwargs):
return {
"status": "success",
"profile": {"summary": "用户喜欢拿铁。"},
"items": [{"summary": "用户偏好中文回复。"}],
}
async def get_session_context(self, **kwargs):
return {
"status": "success",
"context": {"latest_archive_overview": "上次讨论了 memory gateway 接入。"},
"items": [{"summary": "用户要求保留本地 MEMORY.md。"}],
}
async def search(self, **kwargs):
return {
"status": "success",
"items": [
{
"source_backend": "openviking",
"text": "需要异步写入 /memory-system/messages。",
}
],
}
class _FailingGatewayClient(_FakeGatewayClient):
async def ensure_user(self, user_id: str) -> str:
raise RuntimeError("gateway unavailable")
@pytest.mark.asyncio
async def test_gateway_snapshot_keeps_local_memory_and_adds_gateway_sections(tmp_path) -> None:
local = MemoryService(tmp_path / "memory" / "curated")
local.initialize()
local.get_store().add("memory", "本地项目约定:默认用中文解释。")
local.get_store().add("user", "本地用户画像:用户关注记忆系统。")
service = GatewayAugmentedMemoryService(
local_service=local,
client=_FakeGatewayClient(),
config=MemoryGatewayConfig(snapshot_search_limit=3),
)
snapshot = await service.capture_snapshot_for_run_async(
user_id="user-1",
session_id="session-1",
query="如何接入 memory gateway",
)
prompt = "\n".join(snapshot.as_prompt_sections())
assert "本地项目约定:默认用中文解释。" in prompt
assert "本地用户画像:用户关注记忆系统。" in prompt
assert "GATEWAY USER PROFILE" in prompt
assert "用户喜欢拿铁。" in prompt
assert "GATEWAY SESSION CONTEXT" in prompt
assert "上次讨论了 memory gateway 接入。" in prompt
assert "GATEWAY SEARCH RESULTS" in prompt
assert "需要异步写入 /memory-system/messages。" in prompt
@pytest.mark.asyncio
async def test_gateway_snapshot_falls_back_to_local_memory_on_gateway_failure(tmp_path) -> None:
local = MemoryService(tmp_path / "memory" / "curated")
local.initialize()
local.get_store().add("memory", "本地记忆仍然可用。")
service = GatewayAugmentedMemoryService(
local_service=local,
client=_FailingGatewayClient(),
config=MemoryGatewayConfig(snapshot_search_limit=3),
)
snapshot = await service.capture_snapshot_for_run_async(
user_id="user-1",
session_id="session-1",
query="任何问题",
)
prompt = "\n".join(snapshot.as_prompt_sections())
assert "本地记忆仍然可用。" in prompt
assert "GATEWAY USER PROFILE" not in prompt
def test_engine_loader_uses_gateway_memory_service_without_replacing_tools(tmp_path) -> None:
config_path = tmp_path / "config.json"
config_path.write_text(
"""
{
"agents": {"defaults": {"workspace": "%s"}},
"memory": {
"mode": "gateway",
"gateway": {"baseUrl": "http://gateway.test", "defaultUserId": "default-user"}
}
}
"""
% str(tmp_path / "workspace"),
encoding="utf-8",
)
loader = EngineLoader(config_path=config_path)
loaded = loader.load()
assert isinstance(loaded.memory_service, GatewayAugmentedMemoryService)
assert "memory" in loaded.tools
assert "session_search" in loaded.tools
loaded.close()