feat: add Memory Gateway integration with async support for memory snapshots and user management
This commit is contained in:
@ -14,6 +14,8 @@ from beaver.engine.session import SessionManager
|
|||||||
from beaver.foundation.config import BeaverConfig, load_config
|
from beaver.foundation.config import BeaverConfig, load_config
|
||||||
from beaver.integrations.mcp import MCPConnectionManager
|
from beaver.integrations.mcp import MCPConnectionManager
|
||||||
from beaver.memory.curated.store import MemoryStore
|
from beaver.memory.curated.store import MemoryStore
|
||||||
|
from beaver.memory.gateway import MemoryGatewayClient, MemoryGatewayUserStore
|
||||||
|
from beaver.memory.gateway.service import GatewayAugmentedMemoryService
|
||||||
from beaver.memory.runs import RunMemoryStore
|
from beaver.memory.runs import RunMemoryStore
|
||||||
from beaver.memory.skills import SkillLearningStore
|
from beaver.memory.skills import SkillLearningStore
|
||||||
from beaver.services.memory_service import MemoryService
|
from beaver.services.memory_service import MemoryService
|
||||||
@ -206,7 +208,26 @@ class EngineLoader:
|
|||||||
|
|
||||||
curated_root = workspace / "memory" / "curated"
|
curated_root = workspace / "memory" / "curated"
|
||||||
curated_memory_store = self._curated_memory_store or MemoryStore(curated_root)
|
curated_memory_store = self._curated_memory_store or MemoryStore(curated_root)
|
||||||
memory_service = self._memory_service or MemoryService(curated_root, store=curated_memory_store)
|
if self._memory_service is not None:
|
||||||
|
memory_service = self._memory_service
|
||||||
|
else:
|
||||||
|
local_memory_service = MemoryService(curated_root, store=curated_memory_store)
|
||||||
|
memory_cfg = self.config.memory
|
||||||
|
if memory_cfg.mode == "gateway" and memory_cfg.gateway.base_url:
|
||||||
|
gateway_store = MemoryGatewayUserStore(workspace / "memory" / "gateway" / "state.db")
|
||||||
|
gateway_client = MemoryGatewayClient(
|
||||||
|
base_url=memory_cfg.gateway.base_url,
|
||||||
|
api_key=memory_cfg.gateway.api_key,
|
||||||
|
timeout_seconds=memory_cfg.gateway.timeout_seconds,
|
||||||
|
store=gateway_store,
|
||||||
|
)
|
||||||
|
memory_service = GatewayAugmentedMemoryService(
|
||||||
|
local_service=local_memory_service,
|
||||||
|
client=gateway_client,
|
||||||
|
config=memory_cfg.gateway,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
memory_service = local_memory_service
|
||||||
memory_service.initialize()
|
memory_service.initialize()
|
||||||
run_memory_store = self._run_memory_store or RunMemoryStore(workspace / "memory" / "runs")
|
run_memory_store = self._run_memory_store or RunMemoryStore(workspace / "memory" / "runs")
|
||||||
skill_learning_store = self._skill_learning_store or SkillLearningStore(workspace / "memory" / "skills")
|
skill_learning_store = self._skill_learning_store or SkillLearningStore(workspace / "memory" / "skills")
|
||||||
|
|||||||
@ -380,10 +380,15 @@ class AgentLoop:
|
|||||||
resolved_max_tool_iterations = (
|
resolved_max_tool_iterations = (
|
||||||
self.profile.max_tool_iterations if max_tool_iterations is None else max_tool_iterations
|
self.profile.max_tool_iterations if max_tool_iterations is None else max_tool_iterations
|
||||||
)
|
)
|
||||||
|
resolved_memory_user_id = user_id or config.memory.gateway.default_user_id or None
|
||||||
|
|
||||||
# 每个 run 都捕获自己的 frozen snapshot,不能依赖 MemoryService
|
# 每个 run 都捕获自己的 frozen snapshot,不能依赖 MemoryService
|
||||||
# 上的共享 `_snapshot`,否则 parallel team runs 会互相覆盖。
|
# 上的共享 `_snapshot`,否则 parallel team runs 会互相覆盖。
|
||||||
memory_snapshot = memory_service.capture_snapshot_for_run()
|
memory_snapshot = await memory_service.capture_snapshot_for_run_async(
|
||||||
|
user_id=resolved_memory_user_id,
|
||||||
|
session_id=resolved_session_id,
|
||||||
|
query=task,
|
||||||
|
)
|
||||||
|
|
||||||
if parent_session_id:
|
if parent_session_id:
|
||||||
session_manager.ensure_session(
|
session_manager.ensure_session(
|
||||||
@ -834,6 +839,22 @@ class AgentLoop:
|
|||||||
model=final_model,
|
model=final_model,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
|
archive_fn = getattr(memory_service, "archive_run_async", None)
|
||||||
|
if archive_fn is not None and resolved_memory_user_id:
|
||||||
|
asyncio.create_task(
|
||||||
|
self._archive_memory_gateway_run(
|
||||||
|
archive_fn=archive_fn,
|
||||||
|
session_manager=session_manager,
|
||||||
|
session_id=resolved_session_id,
|
||||||
|
run_id=resolved_run_id,
|
||||||
|
user_id=resolved_memory_user_id,
|
||||||
|
user_message=task,
|
||||||
|
assistant_message=final_text,
|
||||||
|
source=source,
|
||||||
|
title=title,
|
||||||
|
model=final_model,
|
||||||
|
)
|
||||||
|
)
|
||||||
self._record_run_receipts(
|
self._record_run_receipts(
|
||||||
skill_learning_service=skill_learning_service,
|
skill_learning_service=skill_learning_service,
|
||||||
session_manager=session_manager,
|
session_manager=session_manager,
|
||||||
@ -1191,6 +1212,57 @@ class AgentLoop:
|
|||||||
context_visible=False,
|
context_visible=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _archive_memory_gateway_run(
|
||||||
|
*,
|
||||||
|
archive_fn: Any,
|
||||||
|
session_manager: Any,
|
||||||
|
session_id: str,
|
||||||
|
run_id: str,
|
||||||
|
user_id: str | None,
|
||||||
|
user_message: str,
|
||||||
|
assistant_message: str,
|
||||||
|
source: str,
|
||||||
|
title: str | None,
|
||||||
|
model: str | None,
|
||||||
|
) -> None:
|
||||||
|
try:
|
||||||
|
result = await archive_fn(
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=session_id,
|
||||||
|
user_message=user_message,
|
||||||
|
assistant_message=assistant_message,
|
||||||
|
)
|
||||||
|
except Exception as exc: # noqa: BLE001 - archive must not change completed run result
|
||||||
|
session_manager.append_message(
|
||||||
|
session_id,
|
||||||
|
run_id=run_id,
|
||||||
|
role="system",
|
||||||
|
event_type="memory_gateway_archive_failed",
|
||||||
|
event_payload={"error": str(exc)},
|
||||||
|
content=f"Memory Gateway archive failed: {exc}",
|
||||||
|
context_visible=False,
|
||||||
|
source=source,
|
||||||
|
title=title,
|
||||||
|
model=model,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
session_manager.append_message(
|
||||||
|
session_id,
|
||||||
|
run_id=run_id,
|
||||||
|
role="system",
|
||||||
|
event_type="memory_gateway_archive_completed",
|
||||||
|
event_payload={"result": result},
|
||||||
|
content="Memory Gateway archive completed.",
|
||||||
|
context_visible=False,
|
||||||
|
source=source,
|
||||||
|
title=title,
|
||||||
|
model=model,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _utc_now() -> str:
|
def _utc_now() -> str:
|
||||||
return datetime.now(timezone.utc).isoformat()
|
return datetime.now(timezone.utc).isoformat()
|
||||||
|
|||||||
@ -15,6 +15,8 @@ from .schema import (
|
|||||||
BeaverConfig,
|
BeaverConfig,
|
||||||
ChannelConfig,
|
ChannelConfig,
|
||||||
EmbeddingConfig,
|
EmbeddingConfig,
|
||||||
|
MemoryConfig,
|
||||||
|
MemoryGatewayConfig,
|
||||||
MCPServerConfig,
|
MCPServerConfig,
|
||||||
ProviderConfig,
|
ProviderConfig,
|
||||||
ToolsConfig,
|
ToolsConfig,
|
||||||
@ -76,6 +78,7 @@ def load_config(
|
|||||||
authz=_parse_authz(data.get("authz")),
|
authz=_parse_authz(data.get("authz")),
|
||||||
channels=_parse_channels(data.get("channels")),
|
channels=_parse_channels(data.get("channels")),
|
||||||
backend_identity=_parse_backend_identity(data.get("backend_identity") or data.get("backendIdentity")),
|
backend_identity=_parse_backend_identity(data.get("backend_identity") or data.get("backendIdentity")),
|
||||||
|
memory=_parse_memory(data.get("memory")),
|
||||||
config_path=path,
|
config_path=path,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -251,6 +254,35 @@ def _parse_backend_identity(raw: Any) -> BackendIdentityConfig:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_memory(raw: Any) -> MemoryConfig:
|
||||||
|
data = _as_dict(raw)
|
||||||
|
gateway = _as_dict(data.get("gateway"))
|
||||||
|
mode = (_string(data.get("mode")) or "local").lower()
|
||||||
|
if mode not in {"local", "gateway"}:
|
||||||
|
mode = "local"
|
||||||
|
return MemoryConfig(
|
||||||
|
mode=mode,
|
||||||
|
gateway=MemoryGatewayConfig(
|
||||||
|
base_url=_string(gateway.get("baseUrl") or gateway.get("base_url")) or "",
|
||||||
|
api_key=_string(gateway.get("apiKey") or gateway.get("api_key")) or "",
|
||||||
|
default_user_id=_string(gateway.get("defaultUserId") or gateway.get("default_user_id")) or "",
|
||||||
|
timeout_seconds=_float(gateway.get("timeoutSeconds") or gateway.get("timeout_seconds")) or 15.0,
|
||||||
|
snapshot_search_limit=int(
|
||||||
|
_float(gateway.get("snapshotSearchLimit") or gateway.get("snapshot_search_limit"))
|
||||||
|
or 5
|
||||||
|
),
|
||||||
|
commit_on_run_complete=_bool(
|
||||||
|
(
|
||||||
|
gateway.get("commitOnRunComplete")
|
||||||
|
if "commitOnRunComplete" in gateway
|
||||||
|
else gateway.get("commit_on_run_complete")
|
||||||
|
),
|
||||||
|
default=True,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _as_dict(value: Any) -> dict[str, Any]:
|
def _as_dict(value: Any) -> dict[str, Any]:
|
||||||
return value if isinstance(value, dict) else {}
|
return value if isinstance(value, dict) else {}
|
||||||
|
|
||||||
|
|||||||
@ -115,6 +115,26 @@ class BackendIdentityConfig:
|
|||||||
public_base_url: str = ""
|
public_base_url: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class MemoryGatewayConfig:
|
||||||
|
"""Memory Gateway integration settings."""
|
||||||
|
|
||||||
|
base_url: str = ""
|
||||||
|
api_key: str = ""
|
||||||
|
default_user_id: str = ""
|
||||||
|
timeout_seconds: float = 15.0
|
||||||
|
snapshot_search_limit: int = 5
|
||||||
|
commit_on_run_complete: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class MemoryConfig:
|
||||||
|
"""Runtime memory strategy configuration."""
|
||||||
|
|
||||||
|
mode: str = "local"
|
||||||
|
gateway: MemoryGatewayConfig = field(default_factory=MemoryGatewayConfig)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
class BeaverConfig:
|
class BeaverConfig:
|
||||||
"""Config loaded once per backend sandbox instance."""
|
"""Config loaded once per backend sandbox instance."""
|
||||||
@ -126,6 +146,7 @@ class BeaverConfig:
|
|||||||
authz: AuthzConfig = field(default_factory=AuthzConfig)
|
authz: AuthzConfig = field(default_factory=AuthzConfig)
|
||||||
channels: dict[str, ChannelConfig] = field(default_factory=dict)
|
channels: dict[str, ChannelConfig] = field(default_factory=dict)
|
||||||
backend_identity: BackendIdentityConfig = field(default_factory=BackendIdentityConfig)
|
backend_identity: BackendIdentityConfig = field(default_factory=BackendIdentityConfig)
|
||||||
|
memory: MemoryConfig = field(default_factory=MemoryConfig)
|
||||||
config_path: Path | None = None
|
config_path: Path | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
6
app-instance/backend/beaver/memory/gateway/__init__.py
Normal file
6
app-instance/backend/beaver/memory/gateway/__init__.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
"""Memory Gateway integration for Beaver memory runtime."""
|
||||||
|
|
||||||
|
from .client import MemoryGatewayClient
|
||||||
|
from .store import MemoryGatewayUserStore
|
||||||
|
|
||||||
|
__all__ = ["MemoryGatewayClient", "MemoryGatewayUserStore"]
|
||||||
150
app-instance/backend/beaver/memory/gateway/client.py
Normal file
150
app-instance/backend/beaver/memory/gateway/client.py
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
"""HTTP client for Memory Gateway's `/memory-system` API."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from .store import MemoryGatewayUserStore
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryGatewayClient:
|
||||||
|
"""Small async client for the Memory Gateway business API."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
base_url: str,
|
||||||
|
store: MemoryGatewayUserStore,
|
||||||
|
api_key: str = "",
|
||||||
|
timeout_seconds: float = 15.0,
|
||||||
|
transport: httpx.AsyncBaseTransport | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.base_url = base_url.rstrip("/")
|
||||||
|
self.api_key = api_key
|
||||||
|
self.timeout_seconds = timeout_seconds
|
||||||
|
self.store = store
|
||||||
|
self.transport = transport
|
||||||
|
|
||||||
|
async def ensure_user(self, user_id: str) -> str:
|
||||||
|
cached = self.store.get_user_key(user_id)
|
||||||
|
if cached:
|
||||||
|
return cached
|
||||||
|
|
||||||
|
data = await self._post("/memory-system/users", {"user_id": user_id})
|
||||||
|
user_key = self._extract_user_key(data)
|
||||||
|
if not user_key:
|
||||||
|
raise RuntimeError("Memory Gateway user creation response missing user_key")
|
||||||
|
self.store.save_user_key(user_id, user_key)
|
||||||
|
return user_key
|
||||||
|
|
||||||
|
async def get_profile(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
user_id: str,
|
||||||
|
user_key: str,
|
||||||
|
query: str = "用户画像",
|
||||||
|
limit: int = 5,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
return await self._get(
|
||||||
|
f"/memory-system/users/{user_id}/profile",
|
||||||
|
{"user_key": user_key, "query": query, "limit": limit},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_session_context(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
user_id: str,
|
||||||
|
user_key: str,
|
||||||
|
session_id: str,
|
||||||
|
query: str,
|
||||||
|
limit: int = 5,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
return await self._post(
|
||||||
|
f"/memory-system/sessions/{session_id}/context",
|
||||||
|
{"user_id": user_id, "user_key": user_key, "query": query, "limit": limit},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def search(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
user_id: str,
|
||||||
|
user_key: str,
|
||||||
|
session_id: str,
|
||||||
|
query: str,
|
||||||
|
limit: int = 5,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
return await self._post(
|
||||||
|
"/memory-system/search",
|
||||||
|
{
|
||||||
|
"user_id": user_id,
|
||||||
|
"user_key": user_key,
|
||||||
|
"session_id": session_id,
|
||||||
|
"query": query,
|
||||||
|
"use_llm": False,
|
||||||
|
"limit": limit,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def ingest_messages(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
user_id: str,
|
||||||
|
user_key: str,
|
||||||
|
session_id: str,
|
||||||
|
user_message: str | None,
|
||||||
|
assistant_message: str | None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
return await self._post(
|
||||||
|
"/memory-system/messages",
|
||||||
|
{
|
||||||
|
"user_id": user_id,
|
||||||
|
"user_key": user_key,
|
||||||
|
"session_id": session_id,
|
||||||
|
"user_message": user_message,
|
||||||
|
"assistant_message": assistant_message,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def commit_session(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
user_id: str,
|
||||||
|
user_key: str,
|
||||||
|
session_id: str,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
return await self._post(
|
||||||
|
f"/memory-system/sessions/{session_id}/commit",
|
||||||
|
{"user_id": user_id, "user_key": user_key},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _get(self, path: str, params: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
async with self._client() as client:
|
||||||
|
response = await client.get(path, params=params)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
async def _post(self, path: str, payload: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
async with self._client() as client:
|
||||||
|
response = await client.post(path, json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
def _client(self) -> httpx.AsyncClient:
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
if self.api_key:
|
||||||
|
headers["X-API-Key"] = self.api_key
|
||||||
|
return httpx.AsyncClient(
|
||||||
|
base_url=self.base_url,
|
||||||
|
headers=headers,
|
||||||
|
timeout=self.timeout_seconds,
|
||||||
|
transport=self.transport,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_user_key(data: dict[str, Any]) -> str | None:
|
||||||
|
account = data.get("account")
|
||||||
|
result = account.get("result") if isinstance(account, dict) else None
|
||||||
|
user_key = result.get("user_key") if isinstance(result, dict) else None
|
||||||
|
return str(user_key) if user_key else None
|
||||||
190
app-instance/backend/beaver/memory/gateway/service.py
Normal file
190
app-instance/backend/beaver/memory/gateway/service.py
Normal file
@ -0,0 +1,190 @@
|
|||||||
|
"""Memory service wrapper that augments local snapshots with Memory Gateway recall."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from beaver.foundation.config.schema import MemoryGatewayConfig
|
||||||
|
from beaver.memory.curated.snapshot import MemorySnapshot
|
||||||
|
from beaver.services.memory_service import MemoryService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class GatewayAugmentedMemoryService:
|
||||||
|
"""Keep local curated memory and add Memory Gateway recall when available."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
local_service: MemoryService,
|
||||||
|
client: Any,
|
||||||
|
config: MemoryGatewayConfig,
|
||||||
|
) -> None:
|
||||||
|
self.local_service = local_service
|
||||||
|
self.client = client
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def initialize(self) -> None:
|
||||||
|
self.local_service.initialize()
|
||||||
|
|
||||||
|
def get_store(self) -> Any:
|
||||||
|
return self.local_service.get_store()
|
||||||
|
|
||||||
|
def get_snapshot(self) -> MemorySnapshot:
|
||||||
|
return self.local_service.get_snapshot()
|
||||||
|
|
||||||
|
def capture_snapshot_for_run(self) -> MemorySnapshot:
|
||||||
|
return self.local_service.capture_snapshot_for_run()
|
||||||
|
|
||||||
|
async def capture_snapshot_for_run_async(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
user_id: str | None = None,
|
||||||
|
session_id: str | None = None,
|
||||||
|
query: str | None = None,
|
||||||
|
) -> MemorySnapshot:
|
||||||
|
local_snapshot = self.local_service.capture_snapshot_for_run()
|
||||||
|
if not user_id or not session_id or not query:
|
||||||
|
return local_snapshot
|
||||||
|
|
||||||
|
try:
|
||||||
|
user_key = await self.client.ensure_user(user_id)
|
||||||
|
limit = max(1, self.config.snapshot_search_limit)
|
||||||
|
profile, session_context, search = await asyncio.gather(
|
||||||
|
self.client.get_profile(
|
||||||
|
user_id=user_id,
|
||||||
|
user_key=user_key,
|
||||||
|
query="用户画像",
|
||||||
|
limit=limit,
|
||||||
|
),
|
||||||
|
self.client.get_session_context(
|
||||||
|
user_id=user_id,
|
||||||
|
user_key=user_key,
|
||||||
|
session_id=session_id,
|
||||||
|
query=query,
|
||||||
|
limit=limit,
|
||||||
|
),
|
||||||
|
self.client.search(
|
||||||
|
user_id=user_id,
|
||||||
|
user_key=user_key,
|
||||||
|
session_id=session_id,
|
||||||
|
query=query,
|
||||||
|
limit=limit,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception as exc: # noqa: BLE001 - gateway recall must not break local memory
|
||||||
|
logger.warning("Memory Gateway snapshot augmentation failed: %s", exc, exc_info=True)
|
||||||
|
return local_snapshot
|
||||||
|
|
||||||
|
user_block = self._join_blocks(
|
||||||
|
local_snapshot.user_block,
|
||||||
|
self._render_profile(profile),
|
||||||
|
)
|
||||||
|
memory_block = self._join_blocks(
|
||||||
|
local_snapshot.memory_block,
|
||||||
|
self._render_session_context(session_context),
|
||||||
|
self._render_search(search),
|
||||||
|
)
|
||||||
|
return MemorySnapshot(memory_block=memory_block, user_block=user_block)
|
||||||
|
|
||||||
|
async def archive_run_async(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
user_id: str | None,
|
||||||
|
session_id: str,
|
||||||
|
user_message: str,
|
||||||
|
assistant_message: str,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
if not user_id:
|
||||||
|
return {"success": False, "error": "user_id is required for Memory Gateway archive."}
|
||||||
|
|
||||||
|
user_key = await self.client.ensure_user(user_id)
|
||||||
|
result: dict[str, Any] = {
|
||||||
|
"messages": await self.client.ingest_messages(
|
||||||
|
user_id=user_id,
|
||||||
|
user_key=user_key,
|
||||||
|
session_id=session_id,
|
||||||
|
user_message=user_message,
|
||||||
|
assistant_message=assistant_message,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if self.config.commit_on_run_complete:
|
||||||
|
result["commit"] = await self.client.commit_session(
|
||||||
|
user_id=user_id,
|
||||||
|
user_key=user_key,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _join_blocks(*blocks: str | None) -> str | None:
|
||||||
|
joined = "\n\n".join(block.strip() for block in blocks if block and block.strip())
|
||||||
|
return joined or None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _render_profile(payload: dict[str, Any]) -> str | None:
|
||||||
|
lines: list[str] = []
|
||||||
|
profile = payload.get("profile")
|
||||||
|
if isinstance(profile, dict):
|
||||||
|
lines.extend(_compact_mapping(profile))
|
||||||
|
elif profile:
|
||||||
|
lines.append(str(profile))
|
||||||
|
lines.extend(_compact_items(payload.get("items")))
|
||||||
|
return _section("GATEWAY USER PROFILE", lines)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _render_session_context(payload: dict[str, Any]) -> str | None:
|
||||||
|
lines: list[str] = []
|
||||||
|
context = payload.get("context")
|
||||||
|
if isinstance(context, dict):
|
||||||
|
overview = context.get("latest_archive_overview")
|
||||||
|
if overview:
|
||||||
|
lines.append(str(overview))
|
||||||
|
messages = context.get("messages")
|
||||||
|
if isinstance(messages, list):
|
||||||
|
lines.extend(_compact_items(messages))
|
||||||
|
lines.extend(_compact_items(payload.get("items")))
|
||||||
|
return _section("GATEWAY SESSION CONTEXT", lines)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _render_search(payload: dict[str, Any]) -> str | None:
|
||||||
|
return _section("GATEWAY SEARCH RESULTS", _compact_items(payload.get("items")))
|
||||||
|
|
||||||
|
|
||||||
|
def _section(title: str, lines: list[str]) -> str | None:
|
||||||
|
cleaned = [line.strip() for line in lines if line and line.strip()]
|
||||||
|
if not cleaned:
|
||||||
|
return None
|
||||||
|
return f"{title}\n" + "\n".join(cleaned)
|
||||||
|
|
||||||
|
|
||||||
|
def _compact_mapping(payload: dict[str, Any]) -> list[str]:
|
||||||
|
result: list[str] = []
|
||||||
|
for key in ("summary", "text", "content", "description"):
|
||||||
|
value = payload.get(key)
|
||||||
|
if value:
|
||||||
|
result.append(str(value))
|
||||||
|
if not result and payload:
|
||||||
|
result.append(str(payload))
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _compact_items(items: Any) -> list[str]:
|
||||||
|
if not isinstance(items, list):
|
||||||
|
return []
|
||||||
|
result: list[str] = []
|
||||||
|
for item in items:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
if item:
|
||||||
|
result.append(f"- {item}")
|
||||||
|
continue
|
||||||
|
text = item.get("text") or item.get("summary") or item.get("content")
|
||||||
|
if not text:
|
||||||
|
text = str(item)
|
||||||
|
source = item.get("source_backend")
|
||||||
|
prefix = f"[{source}] " if source else ""
|
||||||
|
result.append(f"- {prefix}{text}")
|
||||||
|
return result
|
||||||
54
app-instance/backend/beaver/memory/gateway/store.py
Normal file
54
app-instance/backend/beaver/memory/gateway/store.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
"""SQLite cache for Memory Gateway user credentials."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryGatewayUserStore:
|
||||||
|
"""Persist `user_id -> user_key` mappings returned by Memory Gateway."""
|
||||||
|
|
||||||
|
def __init__(self, db_path: str | Path) -> None:
|
||||||
|
self.db_path = Path(db_path)
|
||||||
|
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._init_schema()
|
||||||
|
|
||||||
|
def get_user_key(self, user_id: str) -> str | None:
|
||||||
|
with self._connect() as conn:
|
||||||
|
row = conn.execute(
|
||||||
|
"SELECT user_key FROM memory_gateway_users WHERE user_id = ?",
|
||||||
|
(user_id,),
|
||||||
|
).fetchone()
|
||||||
|
return str(row[0]) if row else None
|
||||||
|
|
||||||
|
def save_user_key(self, user_id: str, user_key: str) -> None:
|
||||||
|
now = time.time()
|
||||||
|
with self._connect() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO memory_gateway_users (user_id, user_key, created_at, updated_at)
|
||||||
|
VALUES (?, ?, ?, ?)
|
||||||
|
ON CONFLICT(user_id) DO UPDATE SET
|
||||||
|
user_key = excluded.user_key,
|
||||||
|
updated_at = excluded.updated_at
|
||||||
|
""",
|
||||||
|
(user_id, user_key, now, now),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _init_schema(self) -> None:
|
||||||
|
with self._connect() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS memory_gateway_users (
|
||||||
|
user_id TEXT PRIMARY KEY,
|
||||||
|
user_key TEXT NOT NULL,
|
||||||
|
created_at REAL NOT NULL,
|
||||||
|
updated_at REAL NOT NULL
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
def _connect(self) -> sqlite3.Connection:
|
||||||
|
return sqlite3.connect(str(self.db_path))
|
||||||
@ -58,6 +58,17 @@ class MemoryService:
|
|||||||
store.load_from_disk()
|
store.load_from_disk()
|
||||||
return capture_memory_snapshot(store)
|
return capture_memory_snapshot(store)
|
||||||
|
|
||||||
|
async def capture_snapshot_for_run_async(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
user_id: str | None = None,
|
||||||
|
session_id: str | None = None,
|
||||||
|
query: str | None = None,
|
||||||
|
) -> MemorySnapshot:
|
||||||
|
"""Async-compatible snapshot hook used by optional memory integrations."""
|
||||||
|
|
||||||
|
return self.capture_snapshot_for_run()
|
||||||
|
|
||||||
def get_snapshot(self) -> MemorySnapshot:
|
def get_snapshot(self) -> MemorySnapshot:
|
||||||
"""获取当前 run 应注入 system prompt 的 frozen snapshot。"""
|
"""获取当前 run 应注入 system prompt 的 frozen snapshot。"""
|
||||||
|
|
||||||
|
|||||||
@ -85,6 +85,49 @@ def test_config_loader_reads_channels(tmp_path) -> None:
|
|||||||
assert channel.secrets == {"ignored_for_status": "secret-value"}
|
assert channel.secrets == {"ignored_for_status": "secret-value"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_loader_reads_memory_gateway_config(tmp_path) -> None:
|
||||||
|
config_path = tmp_path / "config.json"
|
||||||
|
config_path.write_text(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"memory": {
|
||||||
|
"mode": "gateway",
|
||||||
|
"gateway": {
|
||||||
|
"baseUrl": "http://127.0.0.1:1934",
|
||||||
|
"apiKey": "gateway-key",
|
||||||
|
"defaultUserId": "default-user",
|
||||||
|
"timeoutSeconds": 12,
|
||||||
|
"snapshotSearchLimit": 7,
|
||||||
|
"commitOnRunComplete": False,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
config = load_config(config_path=config_path)
|
||||||
|
|
||||||
|
assert config.memory.mode == "gateway"
|
||||||
|
assert config.memory.gateway.base_url == "http://127.0.0.1:1934"
|
||||||
|
assert config.memory.gateway.api_key == "gateway-key"
|
||||||
|
assert config.memory.gateway.default_user_id == "default-user"
|
||||||
|
assert config.memory.gateway.timeout_seconds == 12
|
||||||
|
assert config.memory.gateway.snapshot_search_limit == 7
|
||||||
|
assert config.memory.gateway.commit_on_run_complete is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_loader_defaults_to_local_memory(tmp_path) -> None:
|
||||||
|
config_path = tmp_path / "config.json"
|
||||||
|
config_path.write_text(json.dumps({}), encoding="utf-8")
|
||||||
|
|
||||||
|
config = load_config(config_path=config_path)
|
||||||
|
|
||||||
|
assert config.memory.mode == "local"
|
||||||
|
assert config.memory.gateway.base_url == ""
|
||||||
|
assert config.memory.gateway.commit_on_run_complete is True
|
||||||
|
|
||||||
|
|
||||||
def test_provider_resolution_ignores_custom_and_disabled_overrides(tmp_path) -> None:
|
def test_provider_resolution_ignores_custom_and_disabled_overrides(tmp_path) -> None:
|
||||||
config_path = tmp_path / "config.json"
|
config_path = tmp_path / "config.json"
|
||||||
config_path.write_text(
|
config_path.write_text(
|
||||||
|
|||||||
130
app-instance/backend/tests/unit/test_memory_gateway_archive.py
Normal file
130
app-instance/backend/tests/unit/test_memory_gateway_archive.py
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
import asyncio
|
||||||
|
from pathlib import Path
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from beaver.engine import AgentLoop, EngineLoader
|
||||||
|
from beaver.engine.providers.base import LLMProvider, LLMResponse
|
||||||
|
from beaver.engine.providers.factory import ProviderBundle
|
||||||
|
from beaver.memory.curated.snapshot import MemorySnapshot
|
||||||
|
from beaver.memory.curated.store import MemoryStore
|
||||||
|
|
||||||
|
|
||||||
|
class _RecordingProvider(LLMProvider):
|
||||||
|
def __init__(self, response_text: str = "done") -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.response_text = response_text
|
||||||
|
self.messages: list[list[dict]] = []
|
||||||
|
|
||||||
|
async def chat(
|
||||||
|
self,
|
||||||
|
messages: list[dict],
|
||||||
|
tools: list[dict] | None = None,
|
||||||
|
model: str | None = None,
|
||||||
|
max_tokens: int = 4096,
|
||||||
|
temperature: float = 0.7,
|
||||||
|
thinking_enabled: bool | None = None,
|
||||||
|
) -> LLMResponse:
|
||||||
|
self.messages.append(messages)
|
||||||
|
return LLMResponse(
|
||||||
|
content=self.response_text,
|
||||||
|
finish_reason="stop",
|
||||||
|
provider_name="stub",
|
||||||
|
model="stub-model",
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_default_model(self) -> str:
|
||||||
|
return "stub-model"
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeMemoryService:
|
||||||
|
def __init__(self, root: Path) -> None:
|
||||||
|
self.store = MemoryStore(root)
|
||||||
|
self.snapshot_calls: list[dict] = []
|
||||||
|
self.archive_calls: list[dict] = []
|
||||||
|
self.archive_started: asyncio.Event | None = None
|
||||||
|
self.archive_release: asyncio.Event | None = None
|
||||||
|
|
||||||
|
def initialize(self) -> None:
|
||||||
|
self.store.load_from_disk()
|
||||||
|
|
||||||
|
def get_store(self) -> MemoryStore:
|
||||||
|
return self.store
|
||||||
|
|
||||||
|
async def capture_snapshot_for_run_async(self, **kwargs) -> MemorySnapshot:
|
||||||
|
self.snapshot_calls.append(kwargs)
|
||||||
|
return MemorySnapshot(memory_block="ASYNC SNAPSHOT FROM GATEWAY", user_block=None)
|
||||||
|
|
||||||
|
async def archive_run_async(self, **kwargs) -> dict:
|
||||||
|
self.archive_calls.append(kwargs)
|
||||||
|
if self.archive_started is not None:
|
||||||
|
self.archive_started.set()
|
||||||
|
if self.archive_release is not None:
|
||||||
|
await self.archive_release.wait()
|
||||||
|
return {"status": "success"}
|
||||||
|
|
||||||
|
|
||||||
|
def _bundle(provider: _RecordingProvider) -> ProviderBundle:
|
||||||
|
return ProviderBundle(
|
||||||
|
main_runtime=SimpleNamespace(model="stub-model", provider_name="stub"),
|
||||||
|
main_provider=provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_agent_loop_uses_async_memory_snapshot(tmp_path) -> None:
|
||||||
|
memory_service = _FakeMemoryService(tmp_path / "memory" / "curated")
|
||||||
|
provider = _RecordingProvider("final answer")
|
||||||
|
loop = AgentLoop(loader=EngineLoader(workspace=tmp_path, memory_service=memory_service))
|
||||||
|
|
||||||
|
result = await loop.process_direct(
|
||||||
|
"current task",
|
||||||
|
user_id="user-1",
|
||||||
|
session_id="session-1",
|
||||||
|
provider_bundle=_bundle(provider),
|
||||||
|
include_skill_assembly=False,
|
||||||
|
include_tools=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.output_text == "final answer"
|
||||||
|
assert memory_service.snapshot_calls == [
|
||||||
|
{"user_id": "user-1", "session_id": "session-1", "query": "current task"}
|
||||||
|
]
|
||||||
|
assert "ASYNC SNAPSHOT FROM GATEWAY" in provider.messages[0][0]["content"]
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_agent_loop_archives_memory_gateway_run_in_background(tmp_path) -> None:
|
||||||
|
memory_service = _FakeMemoryService(tmp_path / "memory" / "curated")
|
||||||
|
memory_service.archive_started = asyncio.Event()
|
||||||
|
memory_service.archive_release = asyncio.Event()
|
||||||
|
provider = _RecordingProvider("assistant final")
|
||||||
|
loop = AgentLoop(loader=EngineLoader(workspace=tmp_path, memory_service=memory_service))
|
||||||
|
|
||||||
|
result = await loop.process_direct(
|
||||||
|
"user asks",
|
||||||
|
user_id="user-1",
|
||||||
|
session_id="session-1",
|
||||||
|
provider_bundle=_bundle(provider),
|
||||||
|
include_skill_assembly=False,
|
||||||
|
include_tools=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.finish_reason == "stop"
|
||||||
|
await asyncio.wait_for(memory_service.archive_started.wait(), timeout=1)
|
||||||
|
assert memory_service.archive_calls == [
|
||||||
|
{
|
||||||
|
"user_id": "user-1",
|
||||||
|
"session_id": "session-1",
|
||||||
|
"user_message": "user asks",
|
||||||
|
"assistant_message": "assistant final",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
memory_service.archive_release.set()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
events = loop.boot().session_manager.get_run_event_records(result.session_id, result.run_id)
|
||||||
|
assert any(event.event_type == "memory_gateway_archive_completed" for event in events)
|
||||||
|
loop.close()
|
||||||
100
app-instance/backend/tests/unit/test_memory_gateway_client.py
Normal file
100
app-instance/backend/tests/unit/test_memory_gateway_client.py
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from beaver.memory.gateway import MemoryGatewayClient, MemoryGatewayUserStore
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_memory_gateway_client_uses_cached_user_key(tmp_path) -> None:
|
||||||
|
store = MemoryGatewayUserStore(tmp_path / "gateway.db")
|
||||||
|
store.save_user_key("user-1", "cached-key")
|
||||||
|
requests: list[httpx.Request] = []
|
||||||
|
|
||||||
|
async def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
requests.append(request)
|
||||||
|
return httpx.Response(500, json={"error": "should not be called"})
|
||||||
|
|
||||||
|
client = MemoryGatewayClient(
|
||||||
|
base_url="http://gateway.test",
|
||||||
|
store=store,
|
||||||
|
transport=httpx.MockTransport(handler),
|
||||||
|
)
|
||||||
|
|
||||||
|
user_key = await client.ensure_user("user-1")
|
||||||
|
|
||||||
|
assert user_key == "cached-key"
|
||||||
|
assert requests == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_memory_gateway_client_creates_and_caches_user_key(tmp_path) -> None:
|
||||||
|
store = MemoryGatewayUserStore(tmp_path / "gateway.db")
|
||||||
|
requests: list[httpx.Request] = []
|
||||||
|
|
||||||
|
async def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
requests.append(request)
|
||||||
|
assert request.method == "POST"
|
||||||
|
assert request.url.path == "/memory-system/users"
|
||||||
|
assert json.loads(request.content) == {"user_id": "user-1"}
|
||||||
|
return httpx.Response(
|
||||||
|
200,
|
||||||
|
json={
|
||||||
|
"status": "success",
|
||||||
|
"account": {
|
||||||
|
"status": "ok",
|
||||||
|
"result": {"user_key": "created-key"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
client = MemoryGatewayClient(
|
||||||
|
base_url="http://gateway.test",
|
||||||
|
store=store,
|
||||||
|
transport=httpx.MockTransport(handler),
|
||||||
|
)
|
||||||
|
|
||||||
|
user_key = await client.ensure_user("user-1")
|
||||||
|
|
||||||
|
assert user_key == "created-key"
|
||||||
|
assert store.get_user_key("user-1") == "created-key"
|
||||||
|
assert len(requests) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_memory_gateway_client_ingests_messages_with_user_key(tmp_path) -> None:
|
||||||
|
store = MemoryGatewayUserStore(tmp_path / "gateway.db")
|
||||||
|
requests: list[httpx.Request] = []
|
||||||
|
|
||||||
|
async def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
requests.append(request)
|
||||||
|
assert request.method == "POST"
|
||||||
|
assert request.url.path == "/memory-system/messages"
|
||||||
|
assert request.headers["X-API-Key"] == "gateway-api-key"
|
||||||
|
assert json.loads(request.content) == {
|
||||||
|
"user_id": "user-1",
|
||||||
|
"user_key": "user-key",
|
||||||
|
"session_id": "session-1",
|
||||||
|
"user_message": "hello",
|
||||||
|
"assistant_message": "hi",
|
||||||
|
}
|
||||||
|
return httpx.Response(200, json={"status": "success", "message_count": 2})
|
||||||
|
|
||||||
|
client = MemoryGatewayClient(
|
||||||
|
base_url="http://gateway.test",
|
||||||
|
api_key="gateway-api-key",
|
||||||
|
store=store,
|
||||||
|
transport=httpx.MockTransport(handler),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await client.ingest_messages(
|
||||||
|
user_id="user-1",
|
||||||
|
user_key="user-key",
|
||||||
|
session_id="session-1",
|
||||||
|
user_message="hello",
|
||||||
|
assistant_message="hi",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["status"] == "success"
|
||||||
|
assert len(requests) == 1
|
||||||
117
app-instance/backend/tests/unit/test_memory_gateway_snapshot.py
Normal file
117
app-instance/backend/tests/unit/test_memory_gateway_snapshot.py
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from beaver.engine import EngineLoader
|
||||||
|
from beaver.foundation.config.schema import MemoryGatewayConfig
|
||||||
|
from beaver.memory.gateway.service import GatewayAugmentedMemoryService
|
||||||
|
from beaver.services.memory_service import MemoryService
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeGatewayClient:
|
||||||
|
async def ensure_user(self, user_id: str) -> str:
|
||||||
|
return f"{user_id}-key"
|
||||||
|
|
||||||
|
async def get_profile(self, **kwargs):
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"profile": {"summary": "用户喜欢拿铁。"},
|
||||||
|
"items": [{"summary": "用户偏好中文回复。"}],
|
||||||
|
}
|
||||||
|
|
||||||
|
async def get_session_context(self, **kwargs):
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"context": {"latest_archive_overview": "上次讨论了 memory gateway 接入。"},
|
||||||
|
"items": [{"summary": "用户要求保留本地 MEMORY.md。"}],
|
||||||
|
}
|
||||||
|
|
||||||
|
async def search(self, **kwargs):
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"items": [
|
||||||
|
{
|
||||||
|
"source_backend": "openviking",
|
||||||
|
"text": "需要异步写入 /memory-system/messages。",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class _FailingGatewayClient(_FakeGatewayClient):
|
||||||
|
async def ensure_user(self, user_id: str) -> str:
|
||||||
|
raise RuntimeError("gateway unavailable")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_gateway_snapshot_keeps_local_memory_and_adds_gateway_sections(tmp_path) -> None:
|
||||||
|
local = MemoryService(tmp_path / "memory" / "curated")
|
||||||
|
local.initialize()
|
||||||
|
local.get_store().add("memory", "本地项目约定:默认用中文解释。")
|
||||||
|
local.get_store().add("user", "本地用户画像:用户关注记忆系统。")
|
||||||
|
service = GatewayAugmentedMemoryService(
|
||||||
|
local_service=local,
|
||||||
|
client=_FakeGatewayClient(),
|
||||||
|
config=MemoryGatewayConfig(snapshot_search_limit=3),
|
||||||
|
)
|
||||||
|
|
||||||
|
snapshot = await service.capture_snapshot_for_run_async(
|
||||||
|
user_id="user-1",
|
||||||
|
session_id="session-1",
|
||||||
|
query="如何接入 memory gateway",
|
||||||
|
)
|
||||||
|
prompt = "\n".join(snapshot.as_prompt_sections())
|
||||||
|
|
||||||
|
assert "本地项目约定:默认用中文解释。" in prompt
|
||||||
|
assert "本地用户画像:用户关注记忆系统。" in prompt
|
||||||
|
assert "GATEWAY USER PROFILE" in prompt
|
||||||
|
assert "用户喜欢拿铁。" in prompt
|
||||||
|
assert "GATEWAY SESSION CONTEXT" in prompt
|
||||||
|
assert "上次讨论了 memory gateway 接入。" in prompt
|
||||||
|
assert "GATEWAY SEARCH RESULTS" in prompt
|
||||||
|
assert "需要异步写入 /memory-system/messages。" in prompt
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_gateway_snapshot_falls_back_to_local_memory_on_gateway_failure(tmp_path) -> None:
|
||||||
|
local = MemoryService(tmp_path / "memory" / "curated")
|
||||||
|
local.initialize()
|
||||||
|
local.get_store().add("memory", "本地记忆仍然可用。")
|
||||||
|
service = GatewayAugmentedMemoryService(
|
||||||
|
local_service=local,
|
||||||
|
client=_FailingGatewayClient(),
|
||||||
|
config=MemoryGatewayConfig(snapshot_search_limit=3),
|
||||||
|
)
|
||||||
|
|
||||||
|
snapshot = await service.capture_snapshot_for_run_async(
|
||||||
|
user_id="user-1",
|
||||||
|
session_id="session-1",
|
||||||
|
query="任何问题",
|
||||||
|
)
|
||||||
|
prompt = "\n".join(snapshot.as_prompt_sections())
|
||||||
|
|
||||||
|
assert "本地记忆仍然可用。" in prompt
|
||||||
|
assert "GATEWAY USER PROFILE" not in prompt
|
||||||
|
|
||||||
|
|
||||||
|
def test_engine_loader_uses_gateway_memory_service_without_replacing_tools(tmp_path) -> None:
|
||||||
|
config_path = tmp_path / "config.json"
|
||||||
|
config_path.write_text(
|
||||||
|
"""
|
||||||
|
{
|
||||||
|
"agents": {"defaults": {"workspace": "%s"}},
|
||||||
|
"memory": {
|
||||||
|
"mode": "gateway",
|
||||||
|
"gateway": {"baseUrl": "http://gateway.test", "defaultUserId": "default-user"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
% str(tmp_path / "workspace"),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
loader = EngineLoader(config_path=config_path)
|
||||||
|
loaded = loader.load()
|
||||||
|
|
||||||
|
assert isinstance(loaded.memory_service, GatewayAugmentedMemoryService)
|
||||||
|
assert "memory" in loaded.tools
|
||||||
|
assert "session_search" in loaded.tools
|
||||||
|
loaded.close()
|
||||||
Reference in New Issue
Block a user