feat: add Memory Gateway integration with async support for memory snapshots and user management
This commit is contained in:
@ -85,6 +85,49 @@ def test_config_loader_reads_channels(tmp_path) -> None:
|
||||
assert channel.secrets == {"ignored_for_status": "secret-value"}
|
||||
|
||||
|
||||
def test_config_loader_reads_memory_gateway_config(tmp_path) -> None:
|
||||
config_path = tmp_path / "config.json"
|
||||
config_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"memory": {
|
||||
"mode": "gateway",
|
||||
"gateway": {
|
||||
"baseUrl": "http://127.0.0.1:1934",
|
||||
"apiKey": "gateway-key",
|
||||
"defaultUserId": "default-user",
|
||||
"timeoutSeconds": 12,
|
||||
"snapshotSearchLimit": 7,
|
||||
"commitOnRunComplete": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
config = load_config(config_path=config_path)
|
||||
|
||||
assert config.memory.mode == "gateway"
|
||||
assert config.memory.gateway.base_url == "http://127.0.0.1:1934"
|
||||
assert config.memory.gateway.api_key == "gateway-key"
|
||||
assert config.memory.gateway.default_user_id == "default-user"
|
||||
assert config.memory.gateway.timeout_seconds == 12
|
||||
assert config.memory.gateway.snapshot_search_limit == 7
|
||||
assert config.memory.gateway.commit_on_run_complete is False
|
||||
|
||||
|
||||
def test_config_loader_defaults_to_local_memory(tmp_path) -> None:
|
||||
config_path = tmp_path / "config.json"
|
||||
config_path.write_text(json.dumps({}), encoding="utf-8")
|
||||
|
||||
config = load_config(config_path=config_path)
|
||||
|
||||
assert config.memory.mode == "local"
|
||||
assert config.memory.gateway.base_url == ""
|
||||
assert config.memory.gateway.commit_on_run_complete is True
|
||||
|
||||
|
||||
def test_provider_resolution_ignores_custom_and_disabled_overrides(tmp_path) -> None:
|
||||
config_path = tmp_path / "config.json"
|
||||
config_path.write_text(
|
||||
|
||||
130
app-instance/backend/tests/unit/test_memory_gateway_archive.py
Normal file
130
app-instance/backend/tests/unit/test_memory_gateway_archive.py
Normal file
@ -0,0 +1,130 @@
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from beaver.engine import AgentLoop, EngineLoader
|
||||
from beaver.engine.providers.base import LLMProvider, LLMResponse
|
||||
from beaver.engine.providers.factory import ProviderBundle
|
||||
from beaver.memory.curated.snapshot import MemorySnapshot
|
||||
from beaver.memory.curated.store import MemoryStore
|
||||
|
||||
|
||||
class _RecordingProvider(LLMProvider):
|
||||
def __init__(self, response_text: str = "done") -> None:
|
||||
super().__init__()
|
||||
self.response_text = response_text
|
||||
self.messages: list[list[dict]] = []
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict],
|
||||
tools: list[dict] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
thinking_enabled: bool | None = None,
|
||||
) -> LLMResponse:
|
||||
self.messages.append(messages)
|
||||
return LLMResponse(
|
||||
content=self.response_text,
|
||||
finish_reason="stop",
|
||||
provider_name="stub",
|
||||
model="stub-model",
|
||||
)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return "stub-model"
|
||||
|
||||
|
||||
class _FakeMemoryService:
|
||||
def __init__(self, root: Path) -> None:
|
||||
self.store = MemoryStore(root)
|
||||
self.snapshot_calls: list[dict] = []
|
||||
self.archive_calls: list[dict] = []
|
||||
self.archive_started: asyncio.Event | None = None
|
||||
self.archive_release: asyncio.Event | None = None
|
||||
|
||||
def initialize(self) -> None:
|
||||
self.store.load_from_disk()
|
||||
|
||||
def get_store(self) -> MemoryStore:
|
||||
return self.store
|
||||
|
||||
async def capture_snapshot_for_run_async(self, **kwargs) -> MemorySnapshot:
|
||||
self.snapshot_calls.append(kwargs)
|
||||
return MemorySnapshot(memory_block="ASYNC SNAPSHOT FROM GATEWAY", user_block=None)
|
||||
|
||||
async def archive_run_async(self, **kwargs) -> dict:
|
||||
self.archive_calls.append(kwargs)
|
||||
if self.archive_started is not None:
|
||||
self.archive_started.set()
|
||||
if self.archive_release is not None:
|
||||
await self.archive_release.wait()
|
||||
return {"status": "success"}
|
||||
|
||||
|
||||
def _bundle(provider: _RecordingProvider) -> ProviderBundle:
|
||||
return ProviderBundle(
|
||||
main_runtime=SimpleNamespace(model="stub-model", provider_name="stub"),
|
||||
main_provider=provider,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_loop_uses_async_memory_snapshot(tmp_path) -> None:
|
||||
memory_service = _FakeMemoryService(tmp_path / "memory" / "curated")
|
||||
provider = _RecordingProvider("final answer")
|
||||
loop = AgentLoop(loader=EngineLoader(workspace=tmp_path, memory_service=memory_service))
|
||||
|
||||
result = await loop.process_direct(
|
||||
"current task",
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
provider_bundle=_bundle(provider),
|
||||
include_skill_assembly=False,
|
||||
include_tools=False,
|
||||
)
|
||||
|
||||
assert result.output_text == "final answer"
|
||||
assert memory_service.snapshot_calls == [
|
||||
{"user_id": "user-1", "session_id": "session-1", "query": "current task"}
|
||||
]
|
||||
assert "ASYNC SNAPSHOT FROM GATEWAY" in provider.messages[0][0]["content"]
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_loop_archives_memory_gateway_run_in_background(tmp_path) -> None:
|
||||
memory_service = _FakeMemoryService(tmp_path / "memory" / "curated")
|
||||
memory_service.archive_started = asyncio.Event()
|
||||
memory_service.archive_release = asyncio.Event()
|
||||
provider = _RecordingProvider("assistant final")
|
||||
loop = AgentLoop(loader=EngineLoader(workspace=tmp_path, memory_service=memory_service))
|
||||
|
||||
result = await loop.process_direct(
|
||||
"user asks",
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
provider_bundle=_bundle(provider),
|
||||
include_skill_assembly=False,
|
||||
include_tools=False,
|
||||
)
|
||||
|
||||
assert result.finish_reason == "stop"
|
||||
await asyncio.wait_for(memory_service.archive_started.wait(), timeout=1)
|
||||
assert memory_service.archive_calls == [
|
||||
{
|
||||
"user_id": "user-1",
|
||||
"session_id": "session-1",
|
||||
"user_message": "user asks",
|
||||
"assistant_message": "assistant final",
|
||||
}
|
||||
]
|
||||
|
||||
memory_service.archive_release.set()
|
||||
await asyncio.sleep(0)
|
||||
events = loop.boot().session_manager.get_run_event_records(result.session_id, result.run_id)
|
||||
assert any(event.event_type == "memory_gateway_archive_completed" for event in events)
|
||||
loop.close()
|
||||
100
app-instance/backend/tests/unit/test_memory_gateway_client.py
Normal file
100
app-instance/backend/tests/unit/test_memory_gateway_client.py
Normal file
@ -0,0 +1,100 @@
|
||||
import json
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from beaver.memory.gateway import MemoryGatewayClient, MemoryGatewayUserStore
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_gateway_client_uses_cached_user_key(tmp_path) -> None:
|
||||
store = MemoryGatewayUserStore(tmp_path / "gateway.db")
|
||||
store.save_user_key("user-1", "cached-key")
|
||||
requests: list[httpx.Request] = []
|
||||
|
||||
async def handler(request: httpx.Request) -> httpx.Response:
|
||||
requests.append(request)
|
||||
return httpx.Response(500, json={"error": "should not be called"})
|
||||
|
||||
client = MemoryGatewayClient(
|
||||
base_url="http://gateway.test",
|
||||
store=store,
|
||||
transport=httpx.MockTransport(handler),
|
||||
)
|
||||
|
||||
user_key = await client.ensure_user("user-1")
|
||||
|
||||
assert user_key == "cached-key"
|
||||
assert requests == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_gateway_client_creates_and_caches_user_key(tmp_path) -> None:
|
||||
store = MemoryGatewayUserStore(tmp_path / "gateway.db")
|
||||
requests: list[httpx.Request] = []
|
||||
|
||||
async def handler(request: httpx.Request) -> httpx.Response:
|
||||
requests.append(request)
|
||||
assert request.method == "POST"
|
||||
assert request.url.path == "/memory-system/users"
|
||||
assert json.loads(request.content) == {"user_id": "user-1"}
|
||||
return httpx.Response(
|
||||
200,
|
||||
json={
|
||||
"status": "success",
|
||||
"account": {
|
||||
"status": "ok",
|
||||
"result": {"user_key": "created-key"},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
client = MemoryGatewayClient(
|
||||
base_url="http://gateway.test",
|
||||
store=store,
|
||||
transport=httpx.MockTransport(handler),
|
||||
)
|
||||
|
||||
user_key = await client.ensure_user("user-1")
|
||||
|
||||
assert user_key == "created-key"
|
||||
assert store.get_user_key("user-1") == "created-key"
|
||||
assert len(requests) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_gateway_client_ingests_messages_with_user_key(tmp_path) -> None:
|
||||
store = MemoryGatewayUserStore(tmp_path / "gateway.db")
|
||||
requests: list[httpx.Request] = []
|
||||
|
||||
async def handler(request: httpx.Request) -> httpx.Response:
|
||||
requests.append(request)
|
||||
assert request.method == "POST"
|
||||
assert request.url.path == "/memory-system/messages"
|
||||
assert request.headers["X-API-Key"] == "gateway-api-key"
|
||||
assert json.loads(request.content) == {
|
||||
"user_id": "user-1",
|
||||
"user_key": "user-key",
|
||||
"session_id": "session-1",
|
||||
"user_message": "hello",
|
||||
"assistant_message": "hi",
|
||||
}
|
||||
return httpx.Response(200, json={"status": "success", "message_count": 2})
|
||||
|
||||
client = MemoryGatewayClient(
|
||||
base_url="http://gateway.test",
|
||||
api_key="gateway-api-key",
|
||||
store=store,
|
||||
transport=httpx.MockTransport(handler),
|
||||
)
|
||||
|
||||
result = await client.ingest_messages(
|
||||
user_id="user-1",
|
||||
user_key="user-key",
|
||||
session_id="session-1",
|
||||
user_message="hello",
|
||||
assistant_message="hi",
|
||||
)
|
||||
|
||||
assert result["status"] == "success"
|
||||
assert len(requests) == 1
|
||||
117
app-instance/backend/tests/unit/test_memory_gateway_snapshot.py
Normal file
117
app-instance/backend/tests/unit/test_memory_gateway_snapshot.py
Normal file
@ -0,0 +1,117 @@
|
||||
import pytest
|
||||
|
||||
from beaver.engine import EngineLoader
|
||||
from beaver.foundation.config.schema import MemoryGatewayConfig
|
||||
from beaver.memory.gateway.service import GatewayAugmentedMemoryService
|
||||
from beaver.services.memory_service import MemoryService
|
||||
|
||||
|
||||
class _FakeGatewayClient:
|
||||
async def ensure_user(self, user_id: str) -> str:
|
||||
return f"{user_id}-key"
|
||||
|
||||
async def get_profile(self, **kwargs):
|
||||
return {
|
||||
"status": "success",
|
||||
"profile": {"summary": "用户喜欢拿铁。"},
|
||||
"items": [{"summary": "用户偏好中文回复。"}],
|
||||
}
|
||||
|
||||
async def get_session_context(self, **kwargs):
|
||||
return {
|
||||
"status": "success",
|
||||
"context": {"latest_archive_overview": "上次讨论了 memory gateway 接入。"},
|
||||
"items": [{"summary": "用户要求保留本地 MEMORY.md。"}],
|
||||
}
|
||||
|
||||
async def search(self, **kwargs):
|
||||
return {
|
||||
"status": "success",
|
||||
"items": [
|
||||
{
|
||||
"source_backend": "openviking",
|
||||
"text": "需要异步写入 /memory-system/messages。",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class _FailingGatewayClient(_FakeGatewayClient):
|
||||
async def ensure_user(self, user_id: str) -> str:
|
||||
raise RuntimeError("gateway unavailable")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gateway_snapshot_keeps_local_memory_and_adds_gateway_sections(tmp_path) -> None:
|
||||
local = MemoryService(tmp_path / "memory" / "curated")
|
||||
local.initialize()
|
||||
local.get_store().add("memory", "本地项目约定:默认用中文解释。")
|
||||
local.get_store().add("user", "本地用户画像:用户关注记忆系统。")
|
||||
service = GatewayAugmentedMemoryService(
|
||||
local_service=local,
|
||||
client=_FakeGatewayClient(),
|
||||
config=MemoryGatewayConfig(snapshot_search_limit=3),
|
||||
)
|
||||
|
||||
snapshot = await service.capture_snapshot_for_run_async(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
query="如何接入 memory gateway",
|
||||
)
|
||||
prompt = "\n".join(snapshot.as_prompt_sections())
|
||||
|
||||
assert "本地项目约定:默认用中文解释。" in prompt
|
||||
assert "本地用户画像:用户关注记忆系统。" in prompt
|
||||
assert "GATEWAY USER PROFILE" in prompt
|
||||
assert "用户喜欢拿铁。" in prompt
|
||||
assert "GATEWAY SESSION CONTEXT" in prompt
|
||||
assert "上次讨论了 memory gateway 接入。" in prompt
|
||||
assert "GATEWAY SEARCH RESULTS" in prompt
|
||||
assert "需要异步写入 /memory-system/messages。" in prompt
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gateway_snapshot_falls_back_to_local_memory_on_gateway_failure(tmp_path) -> None:
|
||||
local = MemoryService(tmp_path / "memory" / "curated")
|
||||
local.initialize()
|
||||
local.get_store().add("memory", "本地记忆仍然可用。")
|
||||
service = GatewayAugmentedMemoryService(
|
||||
local_service=local,
|
||||
client=_FailingGatewayClient(),
|
||||
config=MemoryGatewayConfig(snapshot_search_limit=3),
|
||||
)
|
||||
|
||||
snapshot = await service.capture_snapshot_for_run_async(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
query="任何问题",
|
||||
)
|
||||
prompt = "\n".join(snapshot.as_prompt_sections())
|
||||
|
||||
assert "本地记忆仍然可用。" in prompt
|
||||
assert "GATEWAY USER PROFILE" not in prompt
|
||||
|
||||
|
||||
def test_engine_loader_uses_gateway_memory_service_without_replacing_tools(tmp_path) -> None:
|
||||
config_path = tmp_path / "config.json"
|
||||
config_path.write_text(
|
||||
"""
|
||||
{
|
||||
"agents": {"defaults": {"workspace": "%s"}},
|
||||
"memory": {
|
||||
"mode": "gateway",
|
||||
"gateway": {"baseUrl": "http://gateway.test", "defaultUserId": "default-user"}
|
||||
}
|
||||
}
|
||||
"""
|
||||
% str(tmp_path / "workspace"),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
loader = EngineLoader(config_path=config_path)
|
||||
loaded = loader.load()
|
||||
|
||||
assert isinstance(loaded.memory_service, GatewayAugmentedMemoryService)
|
||||
assert "memory" in loaded.tools
|
||||
assert "session_search" in loaded.tools
|
||||
loaded.close()
|
||||
Reference in New Issue
Block a user