feat(memory-gateway): 引入 Memory Gateway 配置、凭据存储和服务编排
* 新增 MemoryGatewayConfig 和 MemoryConfig dataclass,用于配置管理。 * 实现 MemoryGatewayUserCredential 和 MemoryGatewayCredentialStore,用于处理用户凭据。 * 创建 MemoryGatewayService,用于管理与 Memory Gateway 的交互。 * 开发用于记忆设置的 JSON 配置文件。 * 增强单元测试,覆盖新功能,包括凭据存储和服务行为。 * 更新 entrypoint 和实例创建脚本,以初始化 Memory Gateway 用户存储。
This commit is contained in:
@ -15,10 +15,16 @@ from beaver.engine.session import SessionManager
|
||||
from beaver.foundation.config import BeaverConfig, load_config
|
||||
from beaver.integrations.mcp import MCPConnectionManager
|
||||
from beaver.memory.curated.store import MemoryStore
|
||||
from beaver.memory.gateway import (
|
||||
MemoryGatewayConfig,
|
||||
MemoryGatewayCredentialStore,
|
||||
MemoryGatewayService,
|
||||
MemoryGatewayUserCredential,
|
||||
default_memory_gateway_users_path,
|
||||
)
|
||||
from beaver.memory.runs import RunMemoryStore
|
||||
from beaver.memory.skills import SkillLearningStore
|
||||
from beaver.services.memory_service import MemoryService
|
||||
from beaver.services.memory_gateway_service import MemoryGatewayService
|
||||
from beaver.skills.drafts import DraftService
|
||||
from beaver.skills.learning import EvidenceSelector, SkillDraftSynthesizer, SkillLearningPipelineService, SkillLearningService
|
||||
from beaver.skills.learning.safety import SkillDraftSafetyChecker
|
||||
@ -84,7 +90,9 @@ class EngineLoadResult:
|
||||
session_manager: SessionManager | None = None
|
||||
curated_memory_store: MemoryStore | None = None
|
||||
memory_service: MemoryService | None = None
|
||||
memory_gateway_service: MemoryGatewayService | None = None
|
||||
memory_gateway_config: MemoryGatewayConfig | None = None
|
||||
memory_gateway_credentials: MemoryGatewayCredentialStore | None = None
|
||||
memory_gateway_service_factory: Callable[[MemoryGatewayUserCredential], MemoryGatewayService] | None = None
|
||||
run_memory_store: RunMemoryStore | None = None
|
||||
skill_learning_store: SkillLearningStore | None = None
|
||||
tool_registry: ToolRegistry | None = None
|
||||
@ -160,7 +168,8 @@ class EngineLoader:
|
||||
session_manager: SessionManager | None = None,
|
||||
curated_memory_store: MemoryStore | None = None,
|
||||
memory_service: MemoryService | None = None,
|
||||
memory_gateway_service: MemoryGatewayService | None = None,
|
||||
memory_gateway_credentials: MemoryGatewayCredentialStore | None = None,
|
||||
memory_gateway_service_factory: Callable[[MemoryGatewayConfig, MemoryGatewayUserCredential], MemoryGatewayService] | None = None,
|
||||
run_memory_store: RunMemoryStore | None = None,
|
||||
skill_learning_store: SkillLearningStore | None = None,
|
||||
tool_registry: ToolRegistry | None = None,
|
||||
@ -186,7 +195,8 @@ class EngineLoader:
|
||||
self._session_manager = session_manager
|
||||
self._curated_memory_store = curated_memory_store
|
||||
self._memory_service = memory_service
|
||||
self._memory_gateway_service = memory_gateway_service
|
||||
self._memory_gateway_credentials = memory_gateway_credentials
|
||||
self._memory_gateway_service_factory = memory_gateway_service_factory
|
||||
self._run_memory_store = run_memory_store
|
||||
self._skill_learning_store = skill_learning_store
|
||||
self._tool_registry = tool_registry
|
||||
@ -209,7 +219,11 @@ class EngineLoader:
|
||||
"""装配当前主链需要的最小 runtime 对象。"""
|
||||
|
||||
workspace = self.workspace
|
||||
memory_gateway_service = self._resolve_memory_gateway_service()
|
||||
(
|
||||
memory_gateway_config,
|
||||
memory_gateway_credentials,
|
||||
memory_gateway_service_factory,
|
||||
) = self._resolve_memory_gateway_components()
|
||||
session_manager = self._session_manager or SessionManager(workspace)
|
||||
|
||||
curated_root = workspace / "memory" / "curated"
|
||||
@ -306,12 +320,14 @@ class EngineLoader:
|
||||
config=self.config,
|
||||
tools=[spec.name for spec in tool_registry.list_specs()],
|
||||
skills=[record.name for record in skills_loader.list_skills(filter_unavailable=False)],
|
||||
memory_stores=["curated", *(["memory_gateway"] if memory_gateway_service is not None else [])],
|
||||
memory_stores=["curated", *(["memory_gateway"] if memory_gateway_service_factory is not None else [])],
|
||||
permissions=[],
|
||||
session_manager=session_manager,
|
||||
curated_memory_store=memory_service.get_store(),
|
||||
memory_service=memory_service,
|
||||
memory_gateway_service=memory_gateway_service,
|
||||
memory_gateway_config=memory_gateway_config,
|
||||
memory_gateway_credentials=memory_gateway_credentials,
|
||||
memory_gateway_service_factory=memory_gateway_service_factory,
|
||||
run_memory_store=run_memory_store,
|
||||
skill_learning_store=skill_learning_store,
|
||||
tool_registry=tool_registry,
|
||||
@ -337,10 +353,16 @@ class EngineLoader:
|
||||
result.register_closeable("mcp_manager", lambda: _close_mcp_manager(mcp_manager))
|
||||
return result
|
||||
|
||||
def _resolve_memory_gateway_service(self) -> MemoryGatewayService | None:
|
||||
def _resolve_memory_gateway_components(
|
||||
self,
|
||||
) -> tuple[
|
||||
MemoryGatewayConfig | None,
|
||||
MemoryGatewayCredentialStore | None,
|
||||
Callable[[MemoryGatewayUserCredential], MemoryGatewayService] | None,
|
||||
]:
|
||||
memory_config = self.config.memory
|
||||
if memory_config.mode == "curated":
|
||||
return None
|
||||
return None, None, None
|
||||
|
||||
gateway_config = memory_config.gateway
|
||||
if memory_config.explicit and not gateway_config.is_configured:
|
||||
@ -351,8 +373,18 @@ class EngineLoader:
|
||||
logger.warning(
|
||||
"Memory Gateway is not configured; continuing with curated memory only"
|
||||
)
|
||||
return None
|
||||
return self._memory_gateway_service or MemoryGatewayService(gateway_config)
|
||||
return None, None, None
|
||||
|
||||
credential_store = self._memory_gateway_credentials or MemoryGatewayCredentialStore(
|
||||
default_memory_gateway_users_path()
|
||||
)
|
||||
|
||||
def factory(credential: MemoryGatewayUserCredential) -> MemoryGatewayService:
|
||||
if self._memory_gateway_service_factory is not None:
|
||||
return self._memory_gateway_service_factory(gateway_config, credential)
|
||||
return MemoryGatewayService(gateway_config, credential)
|
||||
|
||||
return gateway_config, credential_store, factory
|
||||
|
||||
|
||||
def _close_mcp_manager(manager: MCPConnectionManager) -> None:
|
||||
|
||||
@ -227,6 +227,7 @@ class AgentLoop:
|
||||
session_id: str | None = None,
|
||||
source: str = "direct",
|
||||
user_id: str | None = None,
|
||||
gateway_user_id: str | None = None,
|
||||
title: str | None = None,
|
||||
execution_context: str | None = None,
|
||||
skill_selection_context: str | None = None,
|
||||
@ -279,6 +280,7 @@ class AgentLoop:
|
||||
session_id=session_id,
|
||||
source=source,
|
||||
user_id=user_id,
|
||||
gateway_user_id=gateway_user_id,
|
||||
title=title,
|
||||
execution_context=execution_context,
|
||||
skill_selection_context=skill_selection_context,
|
||||
@ -319,6 +321,7 @@ class AgentLoop:
|
||||
session_id: str | None = None,
|
||||
source: str = "direct",
|
||||
user_id: str | None = None,
|
||||
gateway_user_id: str | None = None,
|
||||
title: str | None = None,
|
||||
execution_context: str | None = None,
|
||||
skill_selection_context: str | None = None,
|
||||
@ -360,6 +363,13 @@ class AgentLoop:
|
||||
"""
|
||||
|
||||
loaded = self.boot()
|
||||
memory_gateway_service = None
|
||||
gateway_credential_store = getattr(loaded, "memory_gateway_credentials", None)
|
||||
gateway_service_factory = getattr(loaded, "memory_gateway_service_factory", None)
|
||||
if gateway_user_id and gateway_credential_store is not None and gateway_service_factory is not None:
|
||||
gateway_credential = gateway_credential_store.get(gateway_user_id)
|
||||
if gateway_credential is not None:
|
||||
memory_gateway_service = gateway_service_factory(gateway_credential)
|
||||
session_manager = self._require_loaded("session_manager")
|
||||
memory_service = self._require_loaded("memory_service")
|
||||
context_builder = self._require_loaded("context_builder")
|
||||
@ -482,7 +492,6 @@ class AgentLoop:
|
||||
final_model: str | None = resolved_model
|
||||
run_started_at = self._utc_now()
|
||||
activated_receipts: list[SkillActivationReceipt] = []
|
||||
memory_gateway_service = getattr(loaded, "memory_gateway_service", None)
|
||||
try:
|
||||
bundle = provider_bundle or make_provider_bundle(
|
||||
model=resolved_model,
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
"""Configuration models and loaders."""
|
||||
|
||||
from .loader import default_config_path, load_config
|
||||
from .loader import default_config_path, default_memory_config_path, load_config
|
||||
from .schema import (
|
||||
AgentDefaultsConfig,
|
||||
AuthzConfig,
|
||||
@ -26,5 +26,6 @@ __all__ = [
|
||||
"ProviderConfig",
|
||||
"ToolsConfig",
|
||||
"default_config_path",
|
||||
"default_memory_config_path",
|
||||
"load_config",
|
||||
]
|
||||
|
||||
@ -55,6 +55,16 @@ def default_config_path(*, workspace: str | Path | None = None) -> Path:
|
||||
return root / ".beaver" / "config.json"
|
||||
|
||||
|
||||
def default_memory_config_path() -> Path:
|
||||
"""Resolve the shared Memory Gateway config path."""
|
||||
|
||||
explicit = os.getenv("BEAVER_MEMORY_CONFIG_PATH")
|
||||
if explicit:
|
||||
return Path(explicit).expanduser()
|
||||
|
||||
return Path(__file__).resolve().parents[3] / "memory" / "config.json"
|
||||
|
||||
|
||||
def load_config(
|
||||
*,
|
||||
workspace: str | Path | None = None,
|
||||
@ -63,24 +73,38 @@ def load_config(
|
||||
"""Load backend config; missing config is treated as an empty config."""
|
||||
|
||||
path = Path(config_path).expanduser() if config_path is not None else default_config_path(workspace=workspace)
|
||||
data: dict[str, Any] | None = None
|
||||
if path.exists():
|
||||
loaded = json.loads(path.read_text(encoding="utf-8"))
|
||||
if not isinstance(loaded, dict):
|
||||
raise ValueError(f"Beaver config must be a JSON object: {path}")
|
||||
data = loaded
|
||||
memory_data = _load_memory_config_data()
|
||||
|
||||
return BeaverConfig(
|
||||
agents_defaults=_parse_agent_defaults(data or {}),
|
||||
providers=_parse_providers((data or {}).get("providers")),
|
||||
embedding=_parse_embedding(data or {}),
|
||||
tools=_parse_tools((data or {}).get("tools")) if data is not None else ToolsConfig(),
|
||||
authz=_parse_authz((data or {}).get("authz")),
|
||||
channels=_parse_channels((data or {}).get("channels")),
|
||||
backend_identity=_parse_backend_identity(
|
||||
(data or {}).get("backend_identity") or (data or {}).get("backendIdentity")
|
||||
),
|
||||
memory=_parse_memory(memory_data),
|
||||
config_path=path,
|
||||
)
|
||||
|
||||
|
||||
def _load_memory_config_data() -> dict[str, Any]:
|
||||
path = default_memory_config_path()
|
||||
if not path.exists():
|
||||
return BeaverConfig(config_path=path)
|
||||
return {}
|
||||
|
||||
data = json.loads(path.read_text(encoding="utf-8"))
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError(f"Beaver config must be a JSON object: {path}")
|
||||
|
||||
return BeaverConfig(
|
||||
agents_defaults=_parse_agent_defaults(data),
|
||||
providers=_parse_providers(data.get("providers")),
|
||||
embedding=_parse_embedding(data),
|
||||
tools=_parse_tools(data.get("tools")),
|
||||
authz=_parse_authz(data.get("authz")),
|
||||
channels=_parse_channels(data.get("channels")),
|
||||
backend_identity=_parse_backend_identity(data.get("backend_identity") or data.get("backendIdentity")),
|
||||
memory=_parse_memory(data),
|
||||
config_path=path,
|
||||
)
|
||||
raise ValueError(f"Beaver memory config must be a JSON object: {path}")
|
||||
return data
|
||||
|
||||
|
||||
def _parse_agent_defaults(data: dict[str, Any]) -> AgentDefaultsConfig:
|
||||
@ -269,12 +293,10 @@ def _parse_memory(data: dict[str, Any]) -> MemoryConfig:
|
||||
scope = (
|
||||
_string_list(gateway_raw.get("scope"))
|
||||
if "scope" in gateway_raw
|
||||
else ["current_chat", "resources"]
|
||||
else MemoryGatewayConfig().scope
|
||||
)
|
||||
gateway = MemoryGatewayConfig(
|
||||
base_url=_string(gateway_raw.get("baseUrl") or gateway_raw.get("base_url")) or "",
|
||||
user_id=_string(gateway_raw.get("userId") or gateway_raw.get("user_id")) or "",
|
||||
user_key=_string(gateway_raw.get("userKey") or gateway_raw.get("user_key")) or "",
|
||||
app_id=_string(gateway_raw.get("appId") or gateway_raw.get("app_id")) or "default",
|
||||
project_id=_string(gateway_raw.get("projectId") or gateway_raw.get("project_id")) or "default",
|
||||
scope=scope,
|
||||
@ -283,15 +305,8 @@ def _parse_memory(data: dict[str, Any]) -> MemoryConfig:
|
||||
)
|
||||
|
||||
if mode == "hybrid" and explicit:
|
||||
missing: list[str] = []
|
||||
if not gateway.base_url:
|
||||
missing.append("baseUrl")
|
||||
if not gateway.user_id:
|
||||
missing.append("userId")
|
||||
if not gateway.user_key:
|
||||
missing.append("userKey")
|
||||
if missing:
|
||||
raise ValueError(f"Explicit hybrid memory requires gateway fields: {', '.join(missing)}")
|
||||
raise ValueError("Explicit hybrid memory requires gateway.baseUrl")
|
||||
allowed_scopes = {"current_chat", "resources", "all_user_memory"}
|
||||
if not gateway.scope or any(scope not in allowed_scopes for scope in gateway.scope):
|
||||
raise ValueError("memory.gateway.scope contains an unsupported value")
|
||||
|
||||
@ -6,6 +6,8 @@ from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from beaver.memory.gateway import MemoryConfig, MemoryGatewayConfig
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ProviderConfig:
|
||||
@ -115,33 +117,6 @@ class BackendIdentityConfig:
|
||||
public_base_url: str = ""
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class MemoryGatewayConfig:
|
||||
"""Fixed Memory Gateway settings for one Beaver instance."""
|
||||
|
||||
base_url: str = ""
|
||||
user_id: str = ""
|
||||
user_key: str = field(default="", repr=False)
|
||||
app_id: str = "default"
|
||||
project_id: str = "default"
|
||||
scope: list[str] = field(default_factory=lambda: ["current_chat", "resources"])
|
||||
top_k: int = 8
|
||||
timeout_seconds: float = 10.0
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
return bool(_clean(self.base_url) and _clean(self.user_id) and _clean(self.user_key))
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class MemoryConfig:
|
||||
"""Curated baseline plus optional Memory Gateway layer."""
|
||||
|
||||
mode: str = "hybrid"
|
||||
explicit: bool = False
|
||||
gateway: MemoryGatewayConfig = field(default_factory=MemoryGatewayConfig)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class BeaverConfig:
|
||||
"""Config loaded once per backend sandbox instance."""
|
||||
|
||||
@ -1,5 +0,0 @@
|
||||
"""Memory Gateway HTTP integration."""
|
||||
|
||||
from .client import MemoryGatewayClient, MemoryGatewayClientError
|
||||
|
||||
__all__ = ["MemoryGatewayClient", "MemoryGatewayClientError"]
|
||||
@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import json
|
||||
import asyncio
|
||||
import io
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
@ -21,6 +22,13 @@ from typing import Any
|
||||
from beaver.engine.providers.registry import PROVIDERS, find_by_name
|
||||
from beaver.foundation.config import default_config_path, load_config
|
||||
from beaver.foundation.events import ChannelIdentity, InboundMessage
|
||||
from beaver.memory.gateway import (
|
||||
MemoryGatewayClient,
|
||||
MemoryGatewayClientError,
|
||||
MemoryGatewayCredentialStore,
|
||||
MemoryGatewayUserCredential,
|
||||
default_memory_gateway_users_path,
|
||||
)
|
||||
from beaver.interfaces.channels.runtime import ChannelRuntime
|
||||
from beaver.interfaces.channels.connections import (
|
||||
ChannelConnectionStore,
|
||||
@ -97,6 +105,8 @@ from .schemas import (
|
||||
WebStatusResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from fastapi import FastAPI, File, Form, Header, HTTPException, Request, UploadFile, WebSocket, WebSocketDisconnect
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
@ -588,6 +598,10 @@ def create_app(
|
||||
app.state.auth_tokens = {}
|
||||
app.state.handoff_codes = {}
|
||||
app.state.auth_file = Path(os.getenv("BEAVER_AUTH_FILE") or "")
|
||||
app.state.memory_gateway_credential_store = MemoryGatewayCredentialStore(
|
||||
default_memory_gateway_users_path()
|
||||
)
|
||||
app.state.memory_gateway_client_factory = lambda config: MemoryGatewayClient(config)
|
||||
max_file_size = 50 * 1024 * 1024
|
||||
max_user_file_upload_size = _int_env("BEAVER_USER_FILES_MAX_UPLOAD_BYTES", 5 * 1024 * 1024 * 1024)
|
||||
user_file_upload_part_size = _int_env("BEAVER_USER_FILES_UPLOAD_PART_SIZE", 10 * 1024 * 1024)
|
||||
@ -1103,6 +1117,30 @@ def create_app(
|
||||
users[username] = password
|
||||
_save_auth_users(auth_file, users)
|
||||
|
||||
if config.memory.mode == "hybrid" and config.memory.gateway.is_configured:
|
||||
try:
|
||||
gateway_client = app.state.memory_gateway_client_factory(config.memory.gateway)
|
||||
gateway_payload = await gateway_client.create_user(username)
|
||||
gateway_user_id = _clean_text(gateway_payload.get("user_id"))
|
||||
gateway_user_key = _clean_text(gateway_payload.get("user_key"))
|
||||
if not gateway_user_id or not gateway_user_key:
|
||||
raise MemoryGatewayClientError("create_user", "invalid_response")
|
||||
app.state.memory_gateway_credential_store.save(
|
||||
username,
|
||||
MemoryGatewayUserCredential(
|
||||
user_id=gateway_user_id,
|
||||
user_key=gateway_user_key,
|
||||
),
|
||||
)
|
||||
except MemoryGatewayClientError as exc:
|
||||
logger.warning(
|
||||
"Memory Gateway user provisioning failed for Beaver user %s: operation=%s category=%s status_code=%s",
|
||||
username,
|
||||
exc.operation,
|
||||
exc.category,
|
||||
exc.status_code,
|
||||
)
|
||||
|
||||
token = _issue_web_token(app, username)
|
||||
handoff_code, handoff_expires_at = _issue_handoff_code(app, username, token)
|
||||
backend_connection = {
|
||||
@ -2445,7 +2483,11 @@ def create_app(
|
||||
503: {"model": WebErrorResponse},
|
||||
},
|
||||
)
|
||||
async def chat(request: Request, payload: WebChatRequest) -> WebChatResponse:
|
||||
async def chat(
|
||||
request: Request,
|
||||
payload: WebChatRequest,
|
||||
authorization: str | None = Header(default=None),
|
||||
) -> WebChatResponse:
|
||||
agent_service = get_agent_service(request)
|
||||
message = payload.message.strip()
|
||||
if not message:
|
||||
@ -2496,10 +2538,12 @@ def create_app(
|
||||
embedding_target = _model_dump(payload.embedding_target)
|
||||
|
||||
try:
|
||||
gateway_user_id = _optional_web_user(app, authorization)
|
||||
direct_kwargs = {
|
||||
"session_id": payload.session_id,
|
||||
"source": "web",
|
||||
"user_id": payload.user_id,
|
||||
"gateway_user_id": gateway_user_id,
|
||||
"title": payload.title,
|
||||
"execution_context": payload.execution_context,
|
||||
"prompt_locale": payload.prompt_locale,
|
||||
@ -2558,6 +2602,7 @@ def create_app(
|
||||
await websocket.send_json({"type": "error", "error": "AgentService is not ready"})
|
||||
await websocket.close(code=1011)
|
||||
return
|
||||
gateway_user_id = _web_user_from_token(app, websocket.query_params.get("token"))
|
||||
|
||||
while True:
|
||||
try:
|
||||
@ -2616,6 +2661,7 @@ def create_app(
|
||||
"session_id": session_id,
|
||||
"source": "websocket",
|
||||
"user_id": _clean_text(payload.get("user_id")) or None,
|
||||
"gateway_user_id": gateway_user_id,
|
||||
"title": _clean_text(payload.get("title")) or None,
|
||||
"execution_context": _clean_text(payload.get("execution_context")) or None,
|
||||
"prompt_locale": _clean_text(payload.get("prompt_locale")) or None,
|
||||
@ -3680,6 +3726,22 @@ def _require_web_user(app: FastAPI, authorization: str | None) -> str:
|
||||
return username
|
||||
|
||||
|
||||
def _optional_web_user(app: FastAPI, authorization: str | None) -> str | None:
|
||||
if not authorization:
|
||||
return None
|
||||
prefix = "bearer "
|
||||
if not authorization.lower().startswith(prefix):
|
||||
return None
|
||||
return _web_user_from_token(app, authorization[len(prefix):].strip())
|
||||
|
||||
|
||||
def _web_user_from_token(app: FastAPI, token: str | None) -> str | None:
|
||||
cleaned = _clean_text(token)
|
||||
if not cleaned:
|
||||
return None
|
||||
return app.state.auth_tokens.get(cleaned)
|
||||
|
||||
|
||||
def _backend_connection_view(request: Request) -> dict[str, Any]:
|
||||
public_base_url = (
|
||||
os.getenv("BEAVER_BACKEND_IDENTITY__PUBLIC_BASE_URL")
|
||||
|
||||
23
app-instance/backend/beaver/memory/gateway/__init__.py
Normal file
23
app-instance/backend/beaver/memory/gateway/__init__.py
Normal file
@ -0,0 +1,23 @@
|
||||
"""Memory Gateway support."""
|
||||
|
||||
from .client import MemoryGatewayClient, MemoryGatewayClientError
|
||||
from .config import MemoryConfig, MemoryGatewayConfig
|
||||
from .credentials import (
|
||||
MemoryGatewayCredentialStore,
|
||||
MemoryGatewayUserCredential,
|
||||
default_memory_gateway_users_path,
|
||||
)
|
||||
from .service import GatewayPersistOutcome, GatewayRecallOutcome, MemoryGatewayService
|
||||
|
||||
__all__ = [
|
||||
"GatewayPersistOutcome",
|
||||
"GatewayRecallOutcome",
|
||||
"MemoryConfig",
|
||||
"MemoryGatewayCredentialStore",
|
||||
"MemoryGatewayClient",
|
||||
"MemoryGatewayClientError",
|
||||
"MemoryGatewayConfig",
|
||||
"MemoryGatewayService",
|
||||
"MemoryGatewayUserCredential",
|
||||
"default_memory_gateway_users_path",
|
||||
]
|
||||
@ -6,7 +6,7 @@ from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from beaver.foundation.config import MemoryGatewayConfig
|
||||
from .config import MemoryGatewayConfig
|
||||
|
||||
|
||||
class MemoryGatewayClientError(RuntimeError):
|
||||
@ -21,7 +21,7 @@ class MemoryGatewayClientError(RuntimeError):
|
||||
|
||||
|
||||
class MemoryGatewayClient:
|
||||
"""HTTP transport for search, add, and flush operations."""
|
||||
"""HTTP transport for search, add, flush, and provisioning operations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -32,6 +32,9 @@ class MemoryGatewayClient:
|
||||
self.config = config
|
||||
self.transport = transport
|
||||
|
||||
async def create_user(self, user_id: str) -> dict[str, Any]:
|
||||
return await self._post("create_user", "/users", {"user_id": user_id})
|
||||
|
||||
async def search(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
return await self._post("search", "/memories/search", payload)
|
||||
|
||||
32
app-instance/backend/beaver/memory/gateway/config.py
Normal file
32
app-instance/backend/beaver/memory/gateway/config.py
Normal file
@ -0,0 +1,32 @@
|
||||
"""Configuration models for the Memory Gateway layer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class MemoryGatewayConfig:
|
||||
"""Shared non-secret Memory Gateway settings."""
|
||||
|
||||
base_url: str = ""
|
||||
app_id: str = "default"
|
||||
project_id: str = "default"
|
||||
scope: list[str] = field(
|
||||
default_factory=lambda: ["current_chat", "resources", "all_user_memory"]
|
||||
)
|
||||
top_k: int = 8
|
||||
timeout_seconds: float = 10.0
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
return bool(self.base_url.strip())
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class MemoryConfig:
|
||||
"""Curated baseline plus optional Memory Gateway layer."""
|
||||
|
||||
mode: str = "hybrid"
|
||||
explicit: bool = False
|
||||
gateway: MemoryGatewayConfig = field(default_factory=MemoryGatewayConfig)
|
||||
75
app-instance/backend/beaver/memory/gateway/credentials.py
Normal file
75
app-instance/backend/beaver/memory/gateway/credentials.py
Normal file
@ -0,0 +1,75 @@
|
||||
"""Per-instance credential storage for Memory Gateway users."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class MemoryGatewayUserCredential:
|
||||
user_id: str
|
||||
user_key: str = field(repr=False)
|
||||
|
||||
|
||||
class MemoryGatewayCredentialStore:
|
||||
"""Persist Beaver username -> Gateway credential mappings."""
|
||||
|
||||
def __init__(self, path: str | Path) -> None:
|
||||
self.path = Path(path)
|
||||
|
||||
def get(self, username: str) -> MemoryGatewayUserCredential | None:
|
||||
users = self._load_users()
|
||||
payload = users.get(username)
|
||||
if not isinstance(payload, dict):
|
||||
return None
|
||||
user_id = str(payload.get("userId") or "").strip()
|
||||
user_key = str(payload.get("userKey") or "").strip()
|
||||
if not user_id or not user_key:
|
||||
return None
|
||||
return MemoryGatewayUserCredential(user_id=user_id, user_key=user_key)
|
||||
|
||||
def save(self, username: str, credential: MemoryGatewayUserCredential) -> None:
|
||||
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||
users = self._load_users()
|
||||
users[username] = {
|
||||
"userId": credential.user_id,
|
||||
"userKey": credential.user_key,
|
||||
}
|
||||
payload = {"users": dict(sorted(users.items()))}
|
||||
fd, tmp_name = tempfile.mkstemp(
|
||||
prefix=f".{self.path.name}.",
|
||||
suffix=".tmp",
|
||||
dir=str(self.path.parent),
|
||||
)
|
||||
tmp_path = Path(tmp_name)
|
||||
try:
|
||||
with os.fdopen(fd, "w", encoding="utf-8") as handle:
|
||||
json.dump(payload, handle, ensure_ascii=False, indent=2)
|
||||
handle.write("\n")
|
||||
os.chmod(tmp_path, 0o600)
|
||||
os.replace(tmp_path, self.path)
|
||||
os.chmod(self.path, 0o600)
|
||||
finally:
|
||||
if tmp_path.exists():
|
||||
tmp_path.unlink()
|
||||
|
||||
def _load_users(self) -> dict[str, Any]:
|
||||
if not self.path.exists():
|
||||
return {}
|
||||
data = json.loads(self.path.read_text(encoding="utf-8"))
|
||||
if not isinstance(data, dict):
|
||||
return {}
|
||||
users = data.get("users")
|
||||
return users if isinstance(users, dict) else {}
|
||||
|
||||
|
||||
def default_memory_gateway_users_path() -> Path:
|
||||
raw = os.getenv("BEAVER_MEMORY_GATEWAY_USERS_PATH")
|
||||
if raw:
|
||||
return Path(raw)
|
||||
return Path.home() / ".beaver" / "memory_gateway_users.json"
|
||||
@ -6,8 +6,9 @@ import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from beaver.foundation.config import MemoryGatewayConfig
|
||||
from beaver.integrations.memory_gateway import MemoryGatewayClient, MemoryGatewayClientError
|
||||
from .client import MemoryGatewayClient, MemoryGatewayClientError
|
||||
from .config import MemoryGatewayConfig
|
||||
from .credentials import MemoryGatewayUserCredential
|
||||
|
||||
_RECALL_FIELDS = ("id", "session_id", "text", "score", "source_scope", "resource_uri")
|
||||
|
||||
@ -33,16 +34,18 @@ class MemoryGatewayService:
|
||||
def __init__(
|
||||
self,
|
||||
config: MemoryGatewayConfig,
|
||||
credential: MemoryGatewayUserCredential,
|
||||
*,
|
||||
client: MemoryGatewayClient | None = None,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.credential = credential
|
||||
self.client = client or MemoryGatewayClient(config)
|
||||
|
||||
async def recall_before_run(self, *, session_id: str, query: str) -> GatewayRecallOutcome:
|
||||
payload = {
|
||||
"user_id": self.config.user_id,
|
||||
"user_key": self.config.user_key,
|
||||
"user_id": self.credential.user_id,
|
||||
"user_key": self.credential.user_key,
|
||||
"conversation_id": session_id,
|
||||
"query": query,
|
||||
"scope": list(self.config.scope),
|
||||
@ -90,8 +93,8 @@ class MemoryGatewayService:
|
||||
) -> GatewayPersistOutcome:
|
||||
gateway_session_id = f"chat:{session_id}"
|
||||
common = {
|
||||
"user_id": self.config.user_id,
|
||||
"user_key": self.config.user_key,
|
||||
"user_id": self.credential.user_id,
|
||||
"user_key": self.credential.user_key,
|
||||
"session_id": gateway_session_id,
|
||||
"app_id": self.config.app_id,
|
||||
"project_id": self.config.project_id,
|
||||
@ -100,7 +103,7 @@ class MemoryGatewayService:
|
||||
**common,
|
||||
"messages": [
|
||||
{
|
||||
"sender_id": self.config.user_id,
|
||||
"sender_id": self.credential.user_id,
|
||||
"role": "user",
|
||||
"timestamp": user_timestamp_ms,
|
||||
"content": user_text,
|
||||
@ -1,6 +1,6 @@
|
||||
"""Application services for Beaver."""
|
||||
|
||||
__all__ = ["AgentService", "CronService", "MemoryGatewayService", "MemoryService"]
|
||||
__all__ = ["AgentService", "CronService", "MemoryService"]
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
@ -12,10 +12,6 @@ def __getattr__(name: str):
|
||||
from .memory_service import MemoryService
|
||||
|
||||
return MemoryService
|
||||
if name == "MemoryGatewayService":
|
||||
from .memory_gateway_service import MemoryGatewayService
|
||||
|
||||
return MemoryGatewayService
|
||||
if name == "CronService":
|
||||
from .cron_service import CronService
|
||||
|
||||
|
||||
Reference in New Issue
Block a user