feat(memory-gateway): 引入 Memory Gateway 配置、凭据存储和服务编排
* 新增 MemoryGatewayConfig 和 MemoryConfig dataclass,用于配置管理。 * 实现 MemoryGatewayUserCredential 和 MemoryGatewayCredentialStore,用于处理用户凭据。 * 创建 MemoryGatewayService,用于管理与 Memory Gateway 的交互。 * 开发用于记忆设置的 JSON 配置文件。 * 增强单元测试,覆盖新功能,包括凭据存储和服务行为。 * 更新 entrypoint 和实例创建脚本,以初始化 Memory Gateway 用户存储。
This commit is contained in:
@ -12,6 +12,39 @@ from beaver.interfaces.web.app import create_app, _reload_agent_config
|
||||
from beaver.services.agent_service import AgentService
|
||||
|
||||
|
||||
def test_load_config_reads_shared_memory_config(tmp_path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
config_path = tmp_path / "config.json"
|
||||
config_path.write_text(json.dumps({}), encoding="utf-8")
|
||||
memory_config_path = tmp_path / "memory-config.json"
|
||||
memory_config_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"memory": {
|
||||
"mode": "hybrid",
|
||||
"gateway": {
|
||||
"baseUrl": "http://172.19.207.37:8010",
|
||||
"appId": "default",
|
||||
"projectId": "default",
|
||||
"scope": ["current_chat", "resources", "all_user_memory"],
|
||||
"topK": 8,
|
||||
"timeoutSeconds": 10,
|
||||
},
|
||||
}
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
monkeypatch.setenv("BEAVER_MEMORY_CONFIG_PATH", str(memory_config_path))
|
||||
|
||||
config = load_config(config_path=config_path)
|
||||
|
||||
assert config.memory.mode == "hybrid"
|
||||
assert config.memory.gateway.base_url == "http://172.19.207.37:8010"
|
||||
assert config.memory.gateway.scope == ["current_chat", "resources", "all_user_memory"]
|
||||
assert config.memory.gateway.top_k == 8
|
||||
assert config.memory.gateway.timeout_seconds == 10
|
||||
|
||||
|
||||
def test_load_config_reads_current_instance_shape(tmp_path) -> None:
|
||||
config_path = tmp_path / "config.json"
|
||||
config_path.write_text(
|
||||
@ -477,17 +510,25 @@ def test_load_config_adds_managed_local_mcp_servers(tmp_path) -> None:
|
||||
assert "beaver.interfaces.mcp.tools_server" in local.args
|
||||
|
||||
|
||||
def test_missing_memory_config_defaults_to_implicit_hybrid(tmp_path) -> None:
|
||||
def test_missing_memory_config_defaults_to_implicit_hybrid(
|
||||
tmp_path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
monkeypatch.setenv("BEAVER_MEMORY_CONFIG_PATH", str(tmp_path / "missing-memory.json"))
|
||||
config = load_config(config_path=tmp_path / "missing.json")
|
||||
|
||||
assert config.memory.mode == "hybrid"
|
||||
assert config.memory.explicit is False
|
||||
assert config.memory.gateway.scope == ["current_chat", "resources"]
|
||||
assert config.memory.gateway.scope == ["current_chat", "resources", "all_user_memory"]
|
||||
|
||||
|
||||
def test_load_config_reads_explicit_curated_memory_mode(tmp_path) -> None:
|
||||
def test_load_config_reads_explicit_curated_memory_mode(
|
||||
tmp_path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
config_path = tmp_path / "config.json"
|
||||
config_path.write_text(json.dumps({"memory": {"mode": "curated"}}), encoding="utf-8")
|
||||
config_path.write_text(json.dumps({}), encoding="utf-8")
|
||||
memory_config_path = tmp_path / "memory-config.json"
|
||||
memory_config_path.write_text(json.dumps({"memory": {"mode": "curated"}}), encoding="utf-8")
|
||||
monkeypatch.setenv("BEAVER_MEMORY_CONFIG_PATH", str(memory_config_path))
|
||||
|
||||
config = load_config(config_path=config_path)
|
||||
|
||||
@ -495,17 +536,19 @@ def test_load_config_reads_explicit_curated_memory_mode(tmp_path) -> None:
|
||||
assert config.memory.explicit is True
|
||||
|
||||
|
||||
def test_load_config_reads_explicit_hybrid_gateway_settings(tmp_path) -> None:
|
||||
def test_load_config_reads_explicit_hybrid_gateway_settings(
|
||||
tmp_path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
config_path = tmp_path / "config.json"
|
||||
config_path.write_text(
|
||||
config_path.write_text(json.dumps({}), encoding="utf-8")
|
||||
memory_config_path = tmp_path / "memory-config.json"
|
||||
memory_config_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"memory": {
|
||||
"mode": "hybrid",
|
||||
"gateway": {
|
||||
"baseUrl": "http://127.0.0.1:8010",
|
||||
"userId": "gateway-user",
|
||||
"userKey": "uk_secret",
|
||||
"appId": "beaver",
|
||||
"projectId": "sandbox",
|
||||
"scope": ["current_chat", "resources"],
|
||||
@ -517,14 +560,13 @@ def test_load_config_reads_explicit_hybrid_gateway_settings(tmp_path) -> None:
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
monkeypatch.setenv("BEAVER_MEMORY_CONFIG_PATH", str(memory_config_path))
|
||||
|
||||
config = load_config(config_path=config_path)
|
||||
|
||||
assert config.memory.mode == "hybrid"
|
||||
assert config.memory.explicit is True
|
||||
assert config.memory.gateway.base_url == "http://127.0.0.1:8010"
|
||||
assert config.memory.gateway.user_id == "gateway-user"
|
||||
assert config.memory.gateway.user_key == "uk_secret"
|
||||
assert config.memory.gateway.app_id == "beaver"
|
||||
assert config.memory.gateway.project_id == "sandbox"
|
||||
assert config.memory.gateway.scope == ["current_chat", "resources"]
|
||||
@ -532,41 +574,33 @@ def test_load_config_reads_explicit_hybrid_gateway_settings(tmp_path) -> None:
|
||||
assert config.memory.gateway.timeout_seconds == 12.5
|
||||
|
||||
|
||||
def test_explicit_hybrid_requires_gateway_credentials_without_leaking_secret(tmp_path) -> None:
|
||||
def test_explicit_hybrid_requires_gateway_base_url(tmp_path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
config_path = tmp_path / "config.json"
|
||||
config_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"memory": {
|
||||
"mode": "hybrid",
|
||||
"gateway": {
|
||||
"baseUrl": "http://127.0.0.1:8010",
|
||||
"userKey": "uk_super_secret",
|
||||
},
|
||||
}
|
||||
}
|
||||
),
|
||||
config_path.write_text(json.dumps({}), encoding="utf-8")
|
||||
memory_config_path = tmp_path / "memory-config.json"
|
||||
memory_config_path.write_text(
|
||||
json.dumps({"memory": {"mode": "hybrid", "gateway": {"appId": "beaver"}}}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
monkeypatch.setenv("BEAVER_MEMORY_CONFIG_PATH", str(memory_config_path))
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
load_config(config_path=config_path)
|
||||
|
||||
assert "userId" in str(exc_info.value)
|
||||
assert "uk_super_secret" not in str(exc_info.value)
|
||||
assert "baseUrl" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_hybrid_memory_rejects_unknown_scope(tmp_path) -> None:
|
||||
def test_hybrid_memory_rejects_unknown_scope(tmp_path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
config_path = tmp_path / "config.json"
|
||||
config_path.write_text(
|
||||
config_path.write_text(json.dumps({}), encoding="utf-8")
|
||||
memory_config_path = tmp_path / "memory-config.json"
|
||||
memory_config_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"memory": {
|
||||
"mode": "hybrid",
|
||||
"gateway": {
|
||||
"baseUrl": "http://127.0.0.1:8010",
|
||||
"userId": "gateway-user",
|
||||
"userKey": "uk_secret",
|
||||
"scope": ["current_chat", "unknown"],
|
||||
},
|
||||
}
|
||||
@ -574,22 +608,23 @@ def test_hybrid_memory_rejects_unknown_scope(tmp_path) -> None:
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
monkeypatch.setenv("BEAVER_MEMORY_CONFIG_PATH", str(memory_config_path))
|
||||
|
||||
with pytest.raises(ValueError, match="scope"):
|
||||
load_config(config_path=config_path)
|
||||
|
||||
|
||||
def test_hybrid_memory_rejects_empty_scope(tmp_path) -> None:
|
||||
def test_hybrid_memory_rejects_empty_scope(tmp_path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
config_path = tmp_path / "config.json"
|
||||
config_path.write_text(
|
||||
config_path.write_text(json.dumps({}), encoding="utf-8")
|
||||
memory_config_path = tmp_path / "memory-config.json"
|
||||
memory_config_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"memory": {
|
||||
"mode": "hybrid",
|
||||
"gateway": {
|
||||
"baseUrl": "http://127.0.0.1:8010",
|
||||
"userId": "gateway-user",
|
||||
"userKey": "uk_secret",
|
||||
"scope": [],
|
||||
},
|
||||
}
|
||||
@ -597,6 +632,7 @@ def test_hybrid_memory_rejects_empty_scope(tmp_path) -> None:
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
monkeypatch.setenv("BEAVER_MEMORY_CONFIG_PATH", str(memory_config_path))
|
||||
|
||||
with pytest.raises(ValueError, match="scope"):
|
||||
load_config(config_path=config_path)
|
||||
@ -610,18 +646,21 @@ def test_hybrid_memory_rejects_empty_scope(tmp_path) -> None:
|
||||
({"timeoutSeconds": 0}, "timeoutSeconds"),
|
||||
],
|
||||
)
|
||||
def test_hybrid_memory_rejects_invalid_limits(tmp_path, gateway_override, expected_error) -> None:
|
||||
def test_hybrid_memory_rejects_invalid_limits(
|
||||
tmp_path, gateway_override, expected_error, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
config_path = tmp_path / "config.json"
|
||||
config_path.write_text(json.dumps({}), encoding="utf-8")
|
||||
gateway = {
|
||||
"baseUrl": "http://127.0.0.1:8010",
|
||||
"userId": "gateway-user",
|
||||
"userKey": "uk_secret",
|
||||
**gateway_override,
|
||||
}
|
||||
config_path.write_text(
|
||||
memory_config_path = tmp_path / "memory-config.json"
|
||||
memory_config_path.write_text(
|
||||
json.dumps({"memory": {"mode": "hybrid", "gateway": gateway}}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
monkeypatch.setenv("BEAVER_MEMORY_CONFIG_PATH", str(memory_config_path))
|
||||
|
||||
with pytest.raises(ValueError, match=expected_error):
|
||||
load_config(config_path=config_path)
|
||||
|
||||
@ -8,8 +8,13 @@ from beaver.engine import AgentLoop, EngineLoader
|
||||
from beaver.engine.providers.base import LLMProvider, LLMResponse
|
||||
from beaver.engine.providers.factory import ProviderBundle
|
||||
from beaver.foundation.config import BeaverConfig, MemoryConfig, MemoryGatewayConfig
|
||||
from beaver.integrations.memory_gateway import MemoryGatewayClientError
|
||||
from beaver.services.memory_gateway_service import GatewayPersistOutcome, GatewayRecallOutcome
|
||||
from beaver.memory.gateway import (
|
||||
GatewayPersistOutcome,
|
||||
GatewayRecallOutcome,
|
||||
MemoryGatewayClientError,
|
||||
MemoryGatewayCredentialStore,
|
||||
MemoryGatewayUserCredential,
|
||||
)
|
||||
|
||||
|
||||
class RecordingProvider(LLMProvider):
|
||||
@ -74,8 +79,6 @@ def _hybrid_config() -> BeaverConfig:
|
||||
explicit=True,
|
||||
gateway=MemoryGatewayConfig(
|
||||
base_url="http://gateway.test",
|
||||
user_id="gateway-user",
|
||||
user_key="uk_secret",
|
||||
scope=["current_chat", "resources"],
|
||||
),
|
||||
)
|
||||
@ -93,11 +96,24 @@ def _write_curated_user_memory(workspace: Path) -> None:
|
||||
(root / "USER.md").write_text("The user prefers concise answers.", encoding="utf-8")
|
||||
|
||||
|
||||
def _run(loop: AgentLoop, provider: LLMProvider, *, session_id: str = "web:gateway-test"):
|
||||
def _gateway_store(tmp_path: Path) -> MemoryGatewayCredentialStore:
|
||||
store = MemoryGatewayCredentialStore(tmp_path / "memory_gateway_users.json")
|
||||
store.save("tom", MemoryGatewayUserCredential(user_id="gateway-user", user_key="uk_secret"))
|
||||
return store
|
||||
|
||||
|
||||
def _run(
|
||||
loop: AgentLoop,
|
||||
provider: LLMProvider,
|
||||
*,
|
||||
session_id: str = "web:gateway-test",
|
||||
gateway_user_id: str | None = "tom",
|
||||
):
|
||||
return asyncio.run(
|
||||
loop.process_direct(
|
||||
"What should I remember?",
|
||||
session_id=session_id,
|
||||
gateway_user_id=gateway_user_id,
|
||||
provider_bundle=_bundle(provider),
|
||||
include_skill_assembly=False,
|
||||
include_tools=False,
|
||||
@ -134,7 +150,8 @@ def test_hybrid_run_keeps_curated_context_and_persists_gateway_turn(tmp_path: Pa
|
||||
loader=EngineLoader(
|
||||
workspace=tmp_path,
|
||||
config=_hybrid_config(),
|
||||
memory_gateway_service=gateway,
|
||||
memory_gateway_credentials=_gateway_store(tmp_path),
|
||||
memory_gateway_service_factory=lambda _config, _credential: gateway,
|
||||
)
|
||||
)
|
||||
|
||||
@ -182,7 +199,8 @@ def test_gateway_recall_failure_is_audited_without_changing_result(tmp_path: Pat
|
||||
loader=EngineLoader(
|
||||
workspace=tmp_path,
|
||||
config=_hybrid_config(),
|
||||
memory_gateway_service=gateway,
|
||||
memory_gateway_credentials=_gateway_store(tmp_path),
|
||||
memory_gateway_service_factory=lambda _config, _credential: gateway,
|
||||
)
|
||||
)
|
||||
|
||||
@ -210,7 +228,8 @@ def test_gateway_add_failure_skips_flush_audit_and_preserves_result(tmp_path: Pa
|
||||
loader=EngineLoader(
|
||||
workspace=tmp_path,
|
||||
config=_hybrid_config(),
|
||||
memory_gateway_service=gateway,
|
||||
memory_gateway_credentials=_gateway_store(tmp_path),
|
||||
memory_gateway_service_factory=lambda _config, _credential: gateway,
|
||||
)
|
||||
)
|
||||
|
||||
@ -235,7 +254,8 @@ def test_gateway_flush_failure_records_add_success_and_flush_failure(tmp_path: P
|
||||
loader=EngineLoader(
|
||||
workspace=tmp_path,
|
||||
config=_hybrid_config(),
|
||||
memory_gateway_service=gateway,
|
||||
memory_gateway_credentials=_gateway_store(tmp_path),
|
||||
memory_gateway_service_factory=lambda _config, _credential: gateway,
|
||||
)
|
||||
)
|
||||
|
||||
@ -276,7 +296,8 @@ def test_failed_run_is_not_persisted_to_gateway(tmp_path: Path) -> None:
|
||||
loader=EngineLoader(
|
||||
workspace=tmp_path,
|
||||
config=_hybrid_config(),
|
||||
memory_gateway_service=gateway,
|
||||
memory_gateway_credentials=_gateway_store(tmp_path),
|
||||
memory_gateway_service_factory=lambda _config, _credential: gateway,
|
||||
)
|
||||
)
|
||||
|
||||
@ -286,3 +307,23 @@ def test_failed_run_is_not_persisted_to_gateway(tmp_path: Path) -> None:
|
||||
assert gateway.recall_calls
|
||||
assert gateway.persist_calls == []
|
||||
loop.close()
|
||||
|
||||
|
||||
def test_missing_gateway_identity_skips_gateway_calls(tmp_path: Path) -> None:
|
||||
gateway = FakeGatewayService()
|
||||
provider = RecordingProvider(LLMResponse(content="Curated only.", finish_reason="stop"))
|
||||
loop = AgentLoop(
|
||||
loader=EngineLoader(
|
||||
workspace=tmp_path,
|
||||
config=_hybrid_config(),
|
||||
memory_gateway_credentials=_gateway_store(tmp_path),
|
||||
memory_gateway_service_factory=lambda _config, _credential: gateway,
|
||||
)
|
||||
)
|
||||
|
||||
result = _run(loop, provider, session_id="web:no-gateway-user", gateway_user_id=None)
|
||||
|
||||
assert result.output_text == "Curated only."
|
||||
assert gateway.recall_calls == []
|
||||
assert gateway.persist_calls == []
|
||||
loop.close()
|
||||
|
||||
@ -0,0 +1,58 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import stat
|
||||
|
||||
from beaver.memory.gateway import (
|
||||
MemoryGatewayCredentialStore,
|
||||
MemoryGatewayUserCredential,
|
||||
)
|
||||
|
||||
|
||||
def test_credential_store_returns_none_for_missing_user(tmp_path) -> None:
|
||||
store = MemoryGatewayCredentialStore(tmp_path / "memory_gateway_users.json")
|
||||
|
||||
assert store.get("tom") is None
|
||||
|
||||
|
||||
def test_credential_store_round_trips_multiple_users(tmp_path) -> None:
|
||||
path = tmp_path / "memory_gateway_users.json"
|
||||
store = MemoryGatewayCredentialStore(path)
|
||||
|
||||
store.save("tom", MemoryGatewayUserCredential(user_id="tom", user_key="uk_tom"))
|
||||
store.save("alice", MemoryGatewayUserCredential(user_id="alice", user_key="uk_alice"))
|
||||
|
||||
assert store.get("tom") == MemoryGatewayUserCredential(user_id="tom", user_key="uk_tom")
|
||||
assert store.get("alice") == MemoryGatewayUserCredential(user_id="alice", user_key="uk_alice")
|
||||
|
||||
payload = json.loads(path.read_text(encoding="utf-8"))
|
||||
assert payload == {
|
||||
"users": {
|
||||
"alice": {"userId": "alice", "userKey": "uk_alice"},
|
||||
"tom": {"userId": "tom", "userKey": "uk_tom"},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def test_credential_store_update_preserves_other_users(tmp_path) -> None:
|
||||
path = tmp_path / "memory_gateway_users.json"
|
||||
store = MemoryGatewayCredentialStore(path)
|
||||
store.save("tom", MemoryGatewayUserCredential(user_id="tom", user_key="uk_old"))
|
||||
store.save("alice", MemoryGatewayUserCredential(user_id="alice", user_key="uk_alice"))
|
||||
|
||||
store.save("tom", MemoryGatewayUserCredential(user_id="tom", user_key="uk_new"))
|
||||
|
||||
assert store.get("tom") == MemoryGatewayUserCredential(user_id="tom", user_key="uk_new")
|
||||
assert store.get("alice") == MemoryGatewayUserCredential(user_id="alice", user_key="uk_alice")
|
||||
|
||||
|
||||
def test_credential_store_masks_secret_in_repr_and_uses_private_mode(tmp_path) -> None:
|
||||
path = tmp_path / "memory_gateway_users.json"
|
||||
credential = MemoryGatewayUserCredential(user_id="tom", user_key="uk_super_secret")
|
||||
store = MemoryGatewayCredentialStore(path)
|
||||
|
||||
store.save("tom", credential)
|
||||
|
||||
assert "uk_super_secret" not in repr(credential)
|
||||
assert stat.S_IMODE(path.stat().st_mode) == 0o600
|
||||
assert not any(child.suffix == ".tmp" for child in tmp_path.iterdir())
|
||||
@ -6,6 +6,7 @@ import pytest
|
||||
|
||||
from beaver.engine import EngineLoader
|
||||
from beaver.foundation.config import BeaverConfig, MemoryConfig, MemoryGatewayConfig
|
||||
from beaver.memory.gateway import MemoryGatewayCredentialStore, MemoryGatewayUserCredential
|
||||
|
||||
|
||||
def test_loader_keeps_curated_memory_in_explicit_curated_mode(tmp_path) -> None:
|
||||
@ -14,7 +15,9 @@ def test_loader_keeps_curated_memory_in_explicit_curated_mode(tmp_path) -> None:
|
||||
loaded = EngineLoader(workspace=tmp_path, config=config).load()
|
||||
|
||||
try:
|
||||
assert loaded.memory_gateway_service is None
|
||||
assert loaded.memory_gateway_config is None
|
||||
assert loaded.memory_gateway_credentials is None
|
||||
assert loaded.memory_gateway_service_factory is None
|
||||
assert loaded.curated_memory_store is not None
|
||||
assert loaded.memory_service is not None
|
||||
assert "memory" in loaded.tools
|
||||
@ -26,22 +29,30 @@ def test_loader_keeps_curated_memory_in_explicit_curated_mode(tmp_path) -> None:
|
||||
def test_loader_adds_gateway_service_without_disabling_curated_memory(tmp_path) -> None:
|
||||
gateway_config = MemoryGatewayConfig(
|
||||
base_url="http://gateway.test",
|
||||
user_id="gateway-user",
|
||||
user_key="uk_secret",
|
||||
)
|
||||
config = BeaverConfig(
|
||||
memory=MemoryConfig(mode="hybrid", explicit=True, gateway=gateway_config)
|
||||
)
|
||||
credential_store = MemoryGatewayCredentialStore(tmp_path / "memory_gateway_users.json")
|
||||
fake_gateway_service = object()
|
||||
|
||||
loaded = EngineLoader(
|
||||
workspace=tmp_path,
|
||||
config=config,
|
||||
memory_gateway_service=fake_gateway_service,
|
||||
memory_gateway_credentials=credential_store,
|
||||
memory_gateway_service_factory=lambda cfg, credential: fake_gateway_service,
|
||||
).load()
|
||||
|
||||
try:
|
||||
assert loaded.memory_gateway_service is fake_gateway_service
|
||||
assert loaded.memory_gateway_config == gateway_config
|
||||
assert loaded.memory_gateway_credentials is credential_store
|
||||
assert loaded.memory_gateway_service_factory is not None
|
||||
assert (
|
||||
loaded.memory_gateway_service_factory(
|
||||
MemoryGatewayUserCredential(user_id="gateway-user", user_key="uk_secret")
|
||||
)
|
||||
is fake_gateway_service
|
||||
)
|
||||
assert loaded.curated_memory_store is not None
|
||||
assert loaded.memory_service is not None
|
||||
assert "memory" in loaded.tools
|
||||
@ -60,7 +71,7 @@ def test_loader_implicit_hybrid_without_credentials_warns_and_degrades(
|
||||
loaded = EngineLoader(workspace=tmp_path, config=config).load()
|
||||
|
||||
try:
|
||||
assert loaded.memory_gateway_service is None
|
||||
assert loaded.memory_gateway_config is None
|
||||
assert loaded.curated_memory_store is not None
|
||||
assert "memory" in loaded.tools
|
||||
assert "continuing with curated memory only" in caplog.text
|
||||
@ -76,7 +87,7 @@ def test_loader_explicit_hybrid_without_credentials_fails_before_opening_session
|
||||
memory=MemoryConfig(
|
||||
mode="hybrid",
|
||||
explicit=True,
|
||||
gateway=MemoryGatewayConfig(user_key="uk_super_secret"),
|
||||
gateway=MemoryGatewayConfig(),
|
||||
)
|
||||
)
|
||||
|
||||
@ -89,4 +100,3 @@ def test_loader_explicit_hybrid_without_credentials_fails_before_opening_session
|
||||
EngineLoader(workspace=tmp_path, config=config).load()
|
||||
|
||||
assert "Memory Gateway" in str(exc_info.value)
|
||||
assert "uk_super_secret" not in str(exc_info.value)
|
||||
|
||||
@ -0,0 +1,123 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from beaver.interfaces.web.app import create_app
|
||||
from beaver.memory.gateway import (
|
||||
MemoryGatewayClientError,
|
||||
MemoryGatewayCredentialStore,
|
||||
)
|
||||
from beaver.services.agent_service import AgentService
|
||||
|
||||
|
||||
class FakeGatewayClient:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
response: dict[str, str] | None = None,
|
||||
error: MemoryGatewayClientError | None = None,
|
||||
) -> None:
|
||||
self.response = response or {"user_id": "tom", "user_key": "uk_tom"}
|
||||
self.error = error
|
||||
self.calls: list[str] = []
|
||||
|
||||
async def create_user(self, user_id: str) -> dict[str, str]:
|
||||
self.calls.append(user_id)
|
||||
if self.error is not None:
|
||||
raise self.error
|
||||
return dict(self.response)
|
||||
|
||||
|
||||
def _service(tmp_path) -> AgentService:
|
||||
config_path = tmp_path / "config.json"
|
||||
config_path.write_text(json.dumps({}), encoding="utf-8")
|
||||
return AgentService(config_path=config_path)
|
||||
|
||||
|
||||
def _write_memory_config(tmp_path) -> None:
|
||||
memory_config_path = tmp_path / "memory-config.json"
|
||||
memory_config_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"memory": {
|
||||
"mode": "hybrid",
|
||||
"gateway": {
|
||||
"baseUrl": "http://172.19.207.37:8010",
|
||||
"appId": "default",
|
||||
"projectId": "default",
|
||||
"scope": ["current_chat", "resources", "all_user_memory"],
|
||||
"topK": 8,
|
||||
"timeoutSeconds": 10,
|
||||
},
|
||||
}
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
|
||||
def test_register_provisions_gateway_user_and_hides_key(
|
||||
tmp_path, monkeypatch
|
||||
) -> None:
|
||||
auth_path = tmp_path / "web_auth_users.json"
|
||||
users_path = tmp_path / "memory_gateway_users.json"
|
||||
monkeypatch.setenv("BEAVER_AUTH_FILE", str(auth_path))
|
||||
monkeypatch.setenv("BEAVER_MEMORY_GATEWAY_USERS_PATH", str(users_path))
|
||||
monkeypatch.setenv("BEAVER_MEMORY_CONFIG_PATH", str(tmp_path / "memory-config.json"))
|
||||
_write_memory_config(tmp_path)
|
||||
|
||||
service = _service(tmp_path)
|
||||
app = create_app(service=service, manage_service_lifecycle=False)
|
||||
fake_client = FakeGatewayClient(response={"user_id": "tom", "user_key": "uk_tom"})
|
||||
app.state.memory_gateway_client_factory = lambda _config: fake_client
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.post(
|
||||
"/api/auth/register",
|
||||
json={"username": "tom", "password": "pw"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert fake_client.calls == ["tom"]
|
||||
body = response.json()
|
||||
assert "user_key" not in json.dumps(body)
|
||||
assert MemoryGatewayCredentialStore(users_path).get("tom") is not None
|
||||
assert MemoryGatewayCredentialStore(users_path).get("tom").user_key == "uk_tom"
|
||||
service.close()
|
||||
|
||||
|
||||
def test_register_keeps_local_user_and_logs_when_gateway_provisioning_fails(
|
||||
tmp_path, monkeypatch, caplog
|
||||
) -> None:
|
||||
auth_path = tmp_path / "web_auth_users.json"
|
||||
users_path = tmp_path / "memory_gateway_users.json"
|
||||
monkeypatch.setenv("BEAVER_AUTH_FILE", str(auth_path))
|
||||
monkeypatch.setenv("BEAVER_MEMORY_GATEWAY_USERS_PATH", str(users_path))
|
||||
monkeypatch.setenv("BEAVER_MEMORY_CONFIG_PATH", str(tmp_path / "memory-config.json"))
|
||||
_write_memory_config(tmp_path)
|
||||
|
||||
service = _service(tmp_path)
|
||||
app = create_app(service=service, manage_service_lifecycle=False)
|
||||
app.state.memory_gateway_client_factory = lambda _config: FakeGatewayClient(
|
||||
error=MemoryGatewayClientError("create_user", "network")
|
||||
)
|
||||
|
||||
with caplog.at_level(logging.WARNING, logger="beaver.interfaces.web.app"):
|
||||
with TestClient(app) as client:
|
||||
response = client.post(
|
||||
"/api/auth/register",
|
||||
json={"username": "tom", "password": "pw"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
auth_payload = json.loads(auth_path.read_text(encoding="utf-8"))
|
||||
assert auth_payload == {"users": [{"username": "tom", "password": "pw"}]}
|
||||
assert MemoryGatewayCredentialStore(users_path).get("tom") is None
|
||||
assert "Memory Gateway user provisioning failed" in caplog.text
|
||||
assert "operation=create_user" in caplog.text
|
||||
assert "category=network" in caplog.text
|
||||
assert "user_key" not in caplog.text
|
||||
service.close()
|
||||
@ -5,16 +5,18 @@ import json
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from beaver.foundation.config import MemoryGatewayConfig
|
||||
from beaver.integrations.memory_gateway import MemoryGatewayClient, MemoryGatewayClientError
|
||||
from beaver.services.memory_gateway_service import MemoryGatewayService
|
||||
from beaver.memory.gateway import (
|
||||
MemoryGatewayClient,
|
||||
MemoryGatewayClientError,
|
||||
MemoryGatewayConfig,
|
||||
MemoryGatewayService,
|
||||
MemoryGatewayUserCredential,
|
||||
)
|
||||
|
||||
|
||||
def _config() -> MemoryGatewayConfig:
|
||||
return MemoryGatewayConfig(
|
||||
base_url="http://gateway.test",
|
||||
user_id="gateway-user",
|
||||
user_key="uk_super_secret",
|
||||
app_id="beaver",
|
||||
project_id="sandbox",
|
||||
scope=["current_chat", "resources"],
|
||||
@ -23,6 +25,10 @@ def _config() -> MemoryGatewayConfig:
|
||||
)
|
||||
|
||||
|
||||
def _credential() -> MemoryGatewayUserCredential:
|
||||
return MemoryGatewayUserCredential(user_id="gateway-user", user_key="uk_super_secret")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_uses_exact_gateway_paths_and_payloads() -> None:
|
||||
requests: list[httpx.Request] = []
|
||||
@ -113,7 +119,7 @@ async def test_recall_sanitizes_results_and_builds_reference_message() -> None:
|
||||
]
|
||||
}
|
||||
)
|
||||
service = MemoryGatewayService(_config(), client=client)
|
||||
service = MemoryGatewayService(_config(), _credential(), client=client)
|
||||
|
||||
outcome = await service.recall_before_run(session_id="web:alpha", query="contract")
|
||||
|
||||
@ -146,6 +152,7 @@ async def test_recall_sanitizes_results_and_builds_reference_message() -> None:
|
||||
async def test_recall_rejects_malformed_results_shape() -> None:
|
||||
service = MemoryGatewayService(
|
||||
_config(),
|
||||
_credential(),
|
||||
client=FakeGatewayClient(search_response={"results": {"not": "a list"}}),
|
||||
)
|
||||
|
||||
@ -160,7 +167,7 @@ async def test_recall_rejects_malformed_results_shape() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_persist_after_run_adds_two_messages_then_flushes() -> None:
|
||||
client = FakeGatewayClient()
|
||||
service = MemoryGatewayService(_config(), client=client)
|
||||
service = MemoryGatewayService(_config(), _credential(), client=client)
|
||||
|
||||
outcome = await service.persist_after_run(
|
||||
session_id="web:alpha",
|
||||
@ -206,7 +213,7 @@ async def test_persist_after_run_adds_two_messages_then_flushes() -> None:
|
||||
async def test_add_failure_skips_flush() -> None:
|
||||
add_error = MemoryGatewayClientError("add", "http_status", status_code=503)
|
||||
client = FakeGatewayClient(add_error=add_error)
|
||||
service = MemoryGatewayService(_config(), client=client)
|
||||
service = MemoryGatewayService(_config(), _credential(), client=client)
|
||||
|
||||
outcome = await service.persist_after_run(
|
||||
session_id="web:alpha",
|
||||
@ -226,7 +233,7 @@ async def test_add_failure_skips_flush() -> None:
|
||||
async def test_flush_failure_preserves_successful_add() -> None:
|
||||
flush_error = MemoryGatewayClientError("flush", "network")
|
||||
client = FakeGatewayClient(flush_error=flush_error)
|
||||
service = MemoryGatewayService(_config(), client=client)
|
||||
service = MemoryGatewayService(_config(), _credential(), client=client)
|
||||
|
||||
outcome = await service.persist_after_run(
|
||||
session_id="web:alpha",
|
||||
|
||||
@ -88,6 +88,7 @@ def test_websocket_message_returns_chat_metadata_and_session_updated() -> None:
|
||||
"session_id": "web:alpha",
|
||||
"source": "websocket",
|
||||
"user_id": None,
|
||||
"gateway_user_id": None,
|
||||
"title": None,
|
||||
"execution_context": None,
|
||||
"prompt_locale": "zh-Hant",
|
||||
@ -134,6 +135,7 @@ def test_websocket_message_uses_direct_processing_when_loop_is_not_running() ->
|
||||
"session_id": "web:alpha",
|
||||
"source": "websocket",
|
||||
"user_id": None,
|
||||
"gateway_user_id": None,
|
||||
"title": None,
|
||||
"execution_context": None,
|
||||
"prompt_locale": None,
|
||||
@ -164,6 +166,7 @@ def test_rest_chat_uses_direct_processing_when_loop_is_not_running() -> None:
|
||||
"session_id": "web:alpha",
|
||||
"source": "web",
|
||||
"user_id": None,
|
||||
"gateway_user_id": None,
|
||||
"title": None,
|
||||
"execution_context": None,
|
||||
"prompt_locale": "en",
|
||||
@ -181,6 +184,72 @@ def test_rest_chat_uses_direct_processing_when_loop_is_not_running() -> None:
|
||||
assert response.json()["output_text"] == "echo:hello"
|
||||
|
||||
|
||||
def test_rest_chat_uses_authenticated_user_for_gateway_identity() -> None:
|
||||
service = DirectModeOnlyAgentService()
|
||||
app = create_app(service=service, manage_service_lifecycle=False)
|
||||
app.state.auth_tokens["token-1"] = "tom"
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.post(
|
||||
"/api/chat",
|
||||
headers={"Authorization": "Bearer token-1"},
|
||||
json={"session_id": "web:alpha", "message": "hello", "user_id": "other"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert service.calls == [
|
||||
{
|
||||
"message": "hello",
|
||||
"session_id": "web:alpha",
|
||||
"source": "web",
|
||||
"user_id": "other",
|
||||
"gateway_user_id": "tom",
|
||||
"title": None,
|
||||
"execution_context": None,
|
||||
"prompt_locale": None,
|
||||
"model": None,
|
||||
"provider_name": None,
|
||||
"embedding_model": None,
|
||||
"temperature": None,
|
||||
"max_tokens": None,
|
||||
"max_tool_iterations": None,
|
||||
"fallback_target": None,
|
||||
"auxiliary_target": None,
|
||||
"embedding_target": None,
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_websocket_uses_authenticated_user_for_gateway_identity() -> None:
|
||||
service = StubAgentService()
|
||||
app = create_app(service=service, manage_service_lifecycle=False)
|
||||
app.state.auth_tokens["token-1"] = "tom"
|
||||
|
||||
with TestClient(app) as client:
|
||||
with client.websocket_connect("/ws/web:alpha?token=token-1") as websocket:
|
||||
websocket.send_json({"type": "message", "content": "hello", "user_id": "other"})
|
||||
assert websocket.receive_json() == {"type": "status", "status": "thinking"}
|
||||
websocket.receive_json()
|
||||
websocket.receive_json()
|
||||
|
||||
assert service.calls == [
|
||||
{
|
||||
"message": "hello",
|
||||
"session_id": "web:alpha",
|
||||
"source": "websocket",
|
||||
"user_id": "other",
|
||||
"gateway_user_id": "tom",
|
||||
"title": None,
|
||||
"execution_context": None,
|
||||
"prompt_locale": None,
|
||||
"model": None,
|
||||
"provider_name": None,
|
||||
"embedding_model": None,
|
||||
"max_tool_iterations": None,
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_websocket_empty_content_returns_error_without_runtime_call() -> None:
|
||||
service = StubAgentService()
|
||||
app = create_app(service=service, manage_service_lifecycle=False)
|
||||
|
||||
Reference in New Issue
Block a user