feat: add Memory Gateway integration with async support for memory snapshots and user management

This commit is contained in:
2026-06-04 17:00:02 +08:00
parent 236ac19789
commit d93ca62990
13 changed files with 949 additions and 2 deletions

View File

@ -14,6 +14,8 @@ from beaver.engine.session import SessionManager
from beaver.foundation.config import BeaverConfig, load_config
from beaver.integrations.mcp import MCPConnectionManager
from beaver.memory.curated.store import MemoryStore
from beaver.memory.gateway import MemoryGatewayClient, MemoryGatewayUserStore
from beaver.memory.gateway.service import GatewayAugmentedMemoryService
from beaver.memory.runs import RunMemoryStore
from beaver.memory.skills import SkillLearningStore
from beaver.services.memory_service import MemoryService
@ -206,7 +208,26 @@ class EngineLoader:
curated_root = workspace / "memory" / "curated"
curated_memory_store = self._curated_memory_store or MemoryStore(curated_root)
memory_service = self._memory_service or MemoryService(curated_root, store=curated_memory_store)
if self._memory_service is not None:
memory_service = self._memory_service
else:
local_memory_service = MemoryService(curated_root, store=curated_memory_store)
memory_cfg = self.config.memory
if memory_cfg.mode == "gateway" and memory_cfg.gateway.base_url:
gateway_store = MemoryGatewayUserStore(workspace / "memory" / "gateway" / "state.db")
gateway_client = MemoryGatewayClient(
base_url=memory_cfg.gateway.base_url,
api_key=memory_cfg.gateway.api_key,
timeout_seconds=memory_cfg.gateway.timeout_seconds,
store=gateway_store,
)
memory_service = GatewayAugmentedMemoryService(
local_service=local_memory_service,
client=gateway_client,
config=memory_cfg.gateway,
)
else:
memory_service = local_memory_service
memory_service.initialize()
run_memory_store = self._run_memory_store or RunMemoryStore(workspace / "memory" / "runs")
skill_learning_store = self._skill_learning_store or SkillLearningStore(workspace / "memory" / "skills")

View File

@ -380,10 +380,15 @@ class AgentLoop:
resolved_max_tool_iterations = (
self.profile.max_tool_iterations if max_tool_iterations is None else max_tool_iterations
)
resolved_memory_user_id = user_id or config.memory.gateway.default_user_id or None
# Each run captures its own frozen snapshot; it must not rely on the shared
# `_snapshot` on MemoryService, otherwise parallel team runs would overwrite each other.
memory_snapshot = memory_service.capture_snapshot_for_run()
memory_snapshot = await memory_service.capture_snapshot_for_run_async(
user_id=resolved_memory_user_id,
session_id=resolved_session_id,
query=task,
)
if parent_session_id:
session_manager.ensure_session(
@ -834,6 +839,22 @@ class AgentLoop:
model=final_model,
user_id=user_id,
)
archive_fn = getattr(memory_service, "archive_run_async", None)
if archive_fn is not None and resolved_memory_user_id:
asyncio.create_task(
self._archive_memory_gateway_run(
archive_fn=archive_fn,
session_manager=session_manager,
session_id=resolved_session_id,
run_id=resolved_run_id,
user_id=resolved_memory_user_id,
user_message=task,
assistant_message=final_text,
source=source,
title=title,
model=final_model,
)
)
self._record_run_receipts(
skill_learning_service=skill_learning_service,
session_manager=session_manager,
@ -1191,6 +1212,57 @@ class AgentLoop:
context_visible=False,
)
@staticmethod
async def _archive_memory_gateway_run(
*,
archive_fn: Any,
session_manager: Any,
session_id: str,
run_id: str,
user_id: str | None,
user_message: str,
assistant_message: str,
source: str,
title: str | None,
model: str | None,
) -> None:
try:
result = await archive_fn(
user_id=user_id,
session_id=session_id,
user_message=user_message,
assistant_message=assistant_message,
)
except Exception as exc: # noqa: BLE001 - archive must not change completed run result
session_manager.append_message(
session_id,
run_id=run_id,
role="system",
event_type="memory_gateway_archive_failed",
event_payload={"error": str(exc)},
content=f"Memory Gateway archive failed: {exc}",
context_visible=False,
source=source,
title=title,
model=model,
user_id=user_id,
)
return
session_manager.append_message(
session_id,
run_id=run_id,
role="system",
event_type="memory_gateway_archive_completed",
event_payload={"result": result},
content="Memory Gateway archive completed.",
context_visible=False,
source=source,
title=title,
model=model,
user_id=user_id,
)
@staticmethod
def _utc_now() -> str:
return datetime.now(timezone.utc).isoformat()

View File

@ -15,6 +15,8 @@ from .schema import (
BeaverConfig,
ChannelConfig,
EmbeddingConfig,
MemoryConfig,
MemoryGatewayConfig,
MCPServerConfig,
ProviderConfig,
ToolsConfig,
@ -76,6 +78,7 @@ def load_config(
authz=_parse_authz(data.get("authz")),
channels=_parse_channels(data.get("channels")),
backend_identity=_parse_backend_identity(data.get("backend_identity") or data.get("backendIdentity")),
memory=_parse_memory(data.get("memory")),
config_path=path,
)
@ -251,6 +254,35 @@ def _parse_backend_identity(raw: Any) -> BackendIdentityConfig:
)
def _parse_memory(raw: Any) -> MemoryConfig:
    """Build the memory configuration from the raw ``memory`` config section.

    Gateway settings are looked up under their camelCase key first and fall
    back to the snake_case key. An unrecognised ``mode`` degrades to
    ``"local"``. Missing or falsy gateway values fall back to the defaults
    spelled out below.
    """
    section = _as_dict(raw)
    gateway = _as_dict(section.get("gateway"))

    def pick(camel: str, snake: str) -> Any:
        # camelCase wins when truthy; otherwise use the snake_case value.
        return gateway.get(camel) or gateway.get(snake)

    mode = (_string(section.get("mode")) or "local").lower()
    if mode not in {"local", "gateway"}:
        mode = "local"

    # Key presence (not truthiness) decides here, so an explicit camelCase
    # `false` is not overridden by the snake_case key.
    if "commitOnRunComplete" in gateway:
        commit_raw = gateway.get("commitOnRunComplete")
    else:
        commit_raw = gateway.get("commit_on_run_complete")

    gateway_config = MemoryGatewayConfig(
        base_url=_string(pick("baseUrl", "base_url")) or "",
        api_key=_string(pick("apiKey", "api_key")) or "",
        default_user_id=_string(pick("defaultUserId", "default_user_id")) or "",
        timeout_seconds=_float(pick("timeoutSeconds", "timeout_seconds")) or 15.0,
        snapshot_search_limit=int(_float(pick("snapshotSearchLimit", "snapshot_search_limit")) or 5),
        commit_on_run_complete=_bool(commit_raw, default=True),
    )
    return MemoryConfig(mode=mode, gateway=gateway_config)
def _as_dict(value: Any) -> dict[str, Any]:
    """Return *value* unchanged when it is a dict, otherwise a new empty dict."""
    if isinstance(value, dict):
        return value
    return {}

View File

@ -115,6 +115,26 @@ class BackendIdentityConfig:
public_base_url: str = ""
@dataclass(slots=True)
class MemoryGatewayConfig:
"""Memory Gateway integration settings."""
base_url: str = ""
api_key: str = ""
default_user_id: str = ""
timeout_seconds: float = 15.0
snapshot_search_limit: int = 5
commit_on_run_complete: bool = True
@dataclass(slots=True)
class MemoryConfig:
    """Selects the runtime memory strategy and carries its gateway settings."""

    # Strategy name; the default is "local".
    mode: str = "local"
    # Gateway settings; each instance gets its own default-constructed copy.
    gateway: MemoryGatewayConfig = field(default_factory=MemoryGatewayConfig)
@dataclass(slots=True)
class BeaverConfig:
"""Config loaded once per backend sandbox instance."""
@ -126,6 +146,7 @@ class BeaverConfig:
authz: AuthzConfig = field(default_factory=AuthzConfig)
channels: dict[str, ChannelConfig] = field(default_factory=dict)
backend_identity: BackendIdentityConfig = field(default_factory=BackendIdentityConfig)
memory: MemoryConfig = field(default_factory=MemoryConfig)
config_path: Path | None = None
@property

View File

@ -0,0 +1,6 @@
"""Memory Gateway integration for Beaver memory runtime."""
from .client import MemoryGatewayClient
from .store import MemoryGatewayUserStore
__all__ = ["MemoryGatewayClient", "MemoryGatewayUserStore"]

View File

@ -0,0 +1,150 @@
"""HTTP client for Memory Gateway's `/memory-system` API."""
from __future__ import annotations
from typing import Any
import httpx
from .store import MemoryGatewayUserStore
class MemoryGatewayClient:
    """Small async client for the Memory Gateway business API.

    Every request opens a short-lived ``httpx.AsyncClient`` configured with
    the base URL, the optional ``X-API-Key`` header, the timeout and the
    transport given at construction time. Caller-supplied identifiers that
    become part of a URL path are percent-encoded first, so characters such
    as ``/``, ``?`` or ``#`` cannot change which endpoint is addressed.
    """

    def __init__(
        self,
        *,
        base_url: str,
        store: MemoryGatewayUserStore,
        api_key: str = "",
        timeout_seconds: float = 15.0,
        transport: httpx.AsyncBaseTransport | None = None,
    ) -> None:
        """Store connection settings; no request is made here.

        Args:
            base_url: Gateway root URL; trailing slashes are stripped.
            store: Cache used to look up and persist ``user_id -> user_key``.
            api_key: Sent as ``X-API-Key`` when non-empty.
            timeout_seconds: Timeout handed to the HTTP client.
            transport: Optional custom httpx transport.
        """
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.timeout_seconds = timeout_seconds
        self.store = store
        self.transport = transport

    async def ensure_user(self, user_id: str) -> str:
        """Return the gateway ``user_key`` for *user_id*, creating the user if needed.

        A key found in the store is returned without a network call.
        Otherwise the user is created via the gateway and the returned key is
        saved to the store.

        Raises:
            RuntimeError: If the creation response carries no ``user_key``.
        """
        cached = self.store.get_user_key(user_id)
        if cached:
            return cached
        data = await self._post("/memory-system/users", {"user_id": user_id})
        user_key = self._extract_user_key(data)
        if not user_key:
            raise RuntimeError("Memory Gateway user creation response missing user_key")
        self.store.save_user_key(user_id, user_key)
        return user_key

    async def get_profile(
        self,
        *,
        user_id: str,
        user_key: str,
        query: str = "用户画像",
        limit: int = 5,
    ) -> dict[str, Any]:
        """GET the profile endpoint for *user_id* and return the decoded JSON."""
        return await self._get(
            f"/memory-system/users/{self._path_segment(user_id)}/profile",
            {"user_key": user_key, "query": query, "limit": limit},
        )

    async def get_session_context(
        self,
        *,
        user_id: str,
        user_key: str,
        session_id: str,
        query: str,
        limit: int = 5,
    ) -> dict[str, Any]:
        """POST to the session context endpoint and return the decoded JSON."""
        return await self._post(
            f"/memory-system/sessions/{self._path_segment(session_id)}/context",
            {"user_id": user_id, "user_key": user_key, "query": query, "limit": limit},
        )

    async def search(
        self,
        *,
        user_id: str,
        user_key: str,
        session_id: str,
        query: str,
        limit: int = 5,
    ) -> dict[str, Any]:
        """POST a search request (``use_llm`` is always ``False``) and return the JSON."""
        return await self._post(
            "/memory-system/search",
            {
                "user_id": user_id,
                "user_key": user_key,
                "session_id": session_id,
                "query": query,
                "use_llm": False,
                "limit": limit,
            },
        )

    async def ingest_messages(
        self,
        *,
        user_id: str,
        user_key: str,
        session_id: str,
        user_message: str | None,
        assistant_message: str | None,
    ) -> dict[str, Any]:
        """POST one user/assistant message pair to the messages endpoint."""
        return await self._post(
            "/memory-system/messages",
            {
                "user_id": user_id,
                "user_key": user_key,
                "session_id": session_id,
                "user_message": user_message,
                "assistant_message": assistant_message,
            },
        )

    async def commit_session(
        self,
        *,
        user_id: str,
        user_key: str,
        session_id: str,
    ) -> dict[str, Any]:
        """POST to the commit endpoint of *session_id* and return the JSON."""
        return await self._post(
            f"/memory-system/sessions/{self._path_segment(session_id)}/commit",
            {"user_id": user_id, "user_key": user_key},
        )

    async def _get(self, path: str, params: dict[str, Any]) -> dict[str, Any]:
        """Issue a GET, raise on HTTP error status, and return the decoded body."""
        async with self._client() as client:
            response = await client.get(path, params=params)
            response.raise_for_status()
            return response.json()

    async def _post(self, path: str, payload: dict[str, Any]) -> dict[str, Any]:
        """Issue a JSON POST, raise on HTTP error status, and return the decoded body."""
        async with self._client() as client:
            response = await client.post(path, json=payload)
            response.raise_for_status()
            return response.json()

    def _client(self) -> httpx.AsyncClient:
        """Build a fresh ``httpx.AsyncClient`` from the stored settings."""
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["X-API-Key"] = self.api_key
        return httpx.AsyncClient(
            base_url=self.base_url,
            headers=headers,
            timeout=self.timeout_seconds,
            transport=self.transport,
        )

    @staticmethod
    def _path_segment(value: str) -> str:
        """Percent-encode *value* so it is safe as a single URL path segment."""
        # Imported locally so this block stays self-contained.
        from urllib.parse import quote

        # safe="" also encodes "/", which quote() leaves untouched by default.
        return quote(str(value), safe="")

    @staticmethod
    def _extract_user_key(data: dict[str, Any]) -> str | None:
        """Return ``data["account"]["result"]["user_key"]`` as a string, or ``None``.

        Any missing level, non-dict level (including a non-dict *data*) or
        falsy key yields ``None`` instead of raising.
        """
        account = data.get("account") if isinstance(data, dict) else None
        result = account.get("result") if isinstance(account, dict) else None
        user_key = result.get("user_key") if isinstance(result, dict) else None
        return str(user_key) if user_key else None

View File

@ -0,0 +1,190 @@
"""Memory service wrapper that augments local snapshots with Memory Gateway recall."""
from __future__ import annotations
import asyncio
import logging
from typing import Any
from beaver.foundation.config.schema import MemoryGatewayConfig
from beaver.memory.curated.snapshot import MemorySnapshot
from beaver.services.memory_service import MemoryService
logger = logging.getLogger(__name__)
class GatewayAugmentedMemoryService:
    """Keep local curated memory and add Memory Gateway recall when available.

    Synchronous methods delegate straight to the wrapped local service. The
    async snapshot hook additionally queries the gateway and appends rendered
    sections to the local snapshot's blocks. Gateway failures and unusable
    gateway payloads never prevent the local snapshot from being returned.
    """

    def __init__(
        self,
        *,
        local_service: MemoryService,
        client: Any,
        config: MemoryGatewayConfig,
    ) -> None:
        """Keep references to the local service, gateway client and gateway config."""
        self.local_service = local_service
        self.client = client
        self.config = config

    def initialize(self) -> None:
        """Initialize the wrapped local service."""
        self.local_service.initialize()

    def get_store(self) -> Any:
        """Return the wrapped local service's store."""
        return self.local_service.get_store()

    def get_snapshot(self) -> MemorySnapshot:
        """Return the wrapped local service's current snapshot."""
        return self.local_service.get_snapshot()

    def capture_snapshot_for_run(self) -> MemorySnapshot:
        """Capture a local-only snapshot (no gateway access)."""
        return self.local_service.capture_snapshot_for_run()

    async def capture_snapshot_for_run_async(
        self,
        *,
        user_id: str | None = None,
        session_id: str | None = None,
        query: str | None = None,
    ) -> MemorySnapshot:
        """Capture a local snapshot and augment it with gateway recall.

        The local snapshot is returned unchanged when any of *user_id*,
        *session_id* or *query* is missing/empty, or when any gateway call
        raises (the failure is logged as a warning). Otherwise the profile is
        appended to the user block and the session context plus search results
        are appended to the memory block.
        """
        local_snapshot = self.local_service.capture_snapshot_for_run()
        if not user_id or not session_id or not query:
            return local_snapshot
        try:
            user_key = await self.client.ensure_user(user_id)
            # A configured limit below 1 is raised to 1.
            limit = max(1, self.config.snapshot_search_limit)
            profile, session_context, search = await asyncio.gather(
                self.client.get_profile(
                    user_id=user_id,
                    user_key=user_key,
                    query="用户画像",
                    limit=limit,
                ),
                self.client.get_session_context(
                    user_id=user_id,
                    user_key=user_key,
                    session_id=session_id,
                    query=query,
                    limit=limit,
                ),
                self.client.search(
                    user_id=user_id,
                    user_key=user_key,
                    session_id=session_id,
                    query=query,
                    limit=limit,
                ),
            )
        except Exception as exc:  # noqa: BLE001 - gateway recall must not break local memory
            logger.warning("Memory Gateway snapshot augmentation failed: %s", exc, exc_info=True)
            return local_snapshot

        # The renderers tolerate non-dict payloads, so an unexpected JSON body
        # contributes nothing instead of raising outside the try block above.
        user_block = self._join_blocks(
            local_snapshot.user_block,
            self._render_profile(profile),
        )
        memory_block = self._join_blocks(
            local_snapshot.memory_block,
            self._render_session_context(session_context),
            self._render_search(search),
        )
        return MemorySnapshot(memory_block=memory_block, user_block=user_block)

    async def archive_run_async(
        self,
        *,
        user_id: str | None,
        session_id: str,
        user_message: str,
        assistant_message: str,
    ) -> dict[str, Any]:
        """Send one user/assistant exchange to the gateway.

        Returns a dict with the ingest response under ``"messages"`` and, when
        ``config.commit_on_run_complete`` is truthy, the commit response under
        ``"commit"``. Without a *user_id* an error dict is returned and no
        request is made. Exceptions from the client propagate to the caller.
        """
        if not user_id:
            return {"success": False, "error": "user_id is required for Memory Gateway archive."}
        user_key = await self.client.ensure_user(user_id)
        result: dict[str, Any] = {
            "messages": await self.client.ingest_messages(
                user_id=user_id,
                user_key=user_key,
                session_id=session_id,
                user_message=user_message,
                assistant_message=assistant_message,
            )
        }
        if self.config.commit_on_run_complete:
            result["commit"] = await self.client.commit_session(
                user_id=user_id,
                user_key=user_key,
                session_id=session_id,
            )
        return result

    @staticmethod
    def _join_blocks(*blocks: str | None) -> str | None:
        """Join the non-blank blocks (stripped) with blank lines; ``None`` if none remain."""
        joined = "\n\n".join(block.strip() for block in blocks if block and block.strip())
        return joined or None

    @staticmethod
    def _render_profile(payload: dict[str, Any]) -> str | None:
        """Render a profile response as a titled section, or ``None`` when empty."""
        if not isinstance(payload, dict):
            return None
        lines: list[str] = []
        profile = payload.get("profile")
        if isinstance(profile, dict):
            lines.extend(_compact_mapping(profile))
        elif profile:
            lines.append(str(profile))
        lines.extend(_compact_items(payload.get("items")))
        return _section("GATEWAY USER PROFILE", lines)

    @staticmethod
    def _render_session_context(payload: dict[str, Any]) -> str | None:
        """Render a session-context response as a titled section, or ``None`` when empty."""
        if not isinstance(payload, dict):
            return None
        lines: list[str] = []
        context = payload.get("context")
        if isinstance(context, dict):
            overview = context.get("latest_archive_overview")
            if overview:
                lines.append(str(overview))
            messages = context.get("messages")
            if isinstance(messages, list):
                lines.extend(_compact_items(messages))
        lines.extend(_compact_items(payload.get("items")))
        return _section("GATEWAY SESSION CONTEXT", lines)

    @staticmethod
    def _render_search(payload: dict[str, Any]) -> str | None:
        """Render a search response's ``items`` as a titled section, or ``None`` when empty."""
        if not isinstance(payload, dict):
            return None
        return _section("GATEWAY SEARCH RESULTS", _compact_items(payload.get("items")))
def _section(title: str, lines: list[str]) -> str | None:
    """Return *title* followed by the stripped, non-blank *lines*, one per row.

    Returns ``None`` when no non-blank line remains.
    """
    stripped = (line.strip() for line in lines if line)
    body = "\n".join(text for text in stripped if text)
    if not body:
        return None
    return f"{title}\n" + body
def _compact_mapping(payload: dict[str, Any]) -> list[str]:
    """Return the truthy ``summary``/``text``/``content``/``description`` values as strings.

    When none of those keys holds a truthy value, a non-empty mapping is
    rendered whole via ``str``; an empty mapping yields an empty list.
    """
    preferred = ("summary", "text", "content", "description")
    picked = [str(value) for key in preferred if (value := payload.get(key))]
    if picked:
        return picked
    return [str(payload)] if payload else []
def _compact_items(items: Any) -> list[str]:
    """Render a list of gateway items as ``- ...`` bullet lines.

    Anything that is not a list yields an empty list. Falsy non-dict items are
    skipped. Dict items use their first truthy ``text``/``summary``/``content``
    value (or the whole dict via ``str``), prefixed with ``[source_backend]``
    when that key is truthy.
    """
    if not isinstance(items, list):
        return []

    def bullet(entry: dict[str, Any]) -> str:
        # Render one dict item as a bullet line.
        body = entry.get("text") or entry.get("summary") or entry.get("content") or str(entry)
        backend = entry.get("source_backend")
        tag = f"[{backend}] " if backend else ""
        return f"- {tag}{body}"

    rendered: list[str] = []
    for entry in items:
        if isinstance(entry, dict):
            rendered.append(bullet(entry))
        elif entry:
            rendered.append(f"- {entry}")
    return rendered

View File

@ -0,0 +1,54 @@
"""SQLite cache for Memory Gateway user credentials."""
from __future__ import annotations
import sqlite3
import time
from pathlib import Path
class MemoryGatewayUserStore:
    """Persist `user_id -> user_key` mappings returned by Memory Gateway.

    Each operation opens its own SQLite connection and closes it explicitly
    when done. Using a ``sqlite3.Connection`` as a context manager only commits
    or rolls back the transaction; it does not close the connection.
    """

    def __init__(self, db_path: str | Path) -> None:
        """Create the parent directory and the table if they do not exist yet."""
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._init_schema()

    def get_user_key(self, user_id: str) -> str | None:
        """Return the stored key for *user_id*, or ``None`` when there is none."""
        conn = self._connect()
        try:
            row = conn.execute(
                "SELECT user_key FROM memory_gateway_users WHERE user_id = ?",
                (user_id,),
            ).fetchone()
        finally:
            conn.close()
        return str(row[0]) if row else None

    def save_user_key(self, user_id: str, user_key: str) -> None:
        """Insert the mapping, or update the key and ``updated_at`` of an existing row."""
        now = time.time()
        self._write(
            """
            INSERT INTO memory_gateway_users (user_id, user_key, created_at, updated_at)
            VALUES (?, ?, ?, ?)
            ON CONFLICT(user_id) DO UPDATE SET
                user_key = excluded.user_key,
                updated_at = excluded.updated_at
            """,
            (user_id, user_key, now, now),
        )

    def _init_schema(self) -> None:
        """Create the ``memory_gateway_users`` table when it is missing."""
        self._write(
            """
            CREATE TABLE IF NOT EXISTS memory_gateway_users (
                user_id TEXT PRIMARY KEY,
                user_key TEXT NOT NULL,
                created_at REAL NOT NULL,
                updated_at REAL NOT NULL
            )
            """
        )

    def _write(self, sql: str, params: tuple = ()) -> None:
        """Run one statement in its own transaction and always close the connection."""
        conn = self._connect()
        try:
            # `with conn` commits on success and rolls back on error.
            with conn:
                conn.execute(sql, params)
        finally:
            conn.close()

    def _connect(self) -> sqlite3.Connection:
        """Open a new connection to the database file."""
        return sqlite3.connect(str(self.db_path))

View File

@ -58,6 +58,17 @@ class MemoryService:
store.load_from_disk()
return capture_memory_snapshot(store)
async def capture_snapshot_for_run_async(
    self,
    *,
    user_id: str | None = None,
    session_id: str | None = None,
    query: str | None = None,
) -> MemorySnapshot:
    """Async counterpart of ``capture_snapshot_for_run``.

    ``user_id``, ``session_id`` and ``query`` are accepted but not used here;
    the result is whatever the synchronous capture returns.
    """
    snapshot = self.capture_snapshot_for_run()
    return snapshot
def get_snapshot(self) -> MemorySnapshot:
"""获取当前 run 应注入 system prompt 的 frozen snapshot。"""