diff --git a/memory_system_api/clients.py b/memory_system_api/clients.py index 4ee0232..b84e61c 100644 --- a/memory_system_api/clients.py +++ b/memory_system_api/clients.py @@ -1,6 +1,7 @@ """Async clients for OpenViking and EverOS used by the lightweight API.""" from __future__ import annotations +from dataclasses import dataclass from datetime import datetime, timezone from typing import Any @@ -10,6 +11,13 @@ from .config import get_config from .store import OpenVikingUserKeyStore +@dataclass(frozen=True) +class OpenVikingCredential: + api_key: str + account_id: str | None = None + user_id: str | None = None + + class OpenVikingMemorySystemClient: def __init__(self, store: OpenVikingUserKeyStore | None = None) -> None: config = get_config() @@ -25,35 +33,42 @@ class OpenVikingMemorySystemClient: response.raise_for_status() return response.json() - async def ensure_user(self, user_id: str) -> str: + async def ensure_user(self, user_id: str) -> OpenVikingCredential: existing = self.store.get_user_key(user_id) if existing: - return existing + return OpenVikingCredential(api_key=existing) async with self._client(self.root_key) as client: response = await client.post( "/api/v1/admin/accounts", json={"account_id": user_id, "admin_user_id": user_id}, ) + if response.status_code == 409: + return self.root_credential(user_id) response.raise_for_status() data = response.json() user_key = self._extract_user_key(data) if not user_key: - raise RuntimeError("OpenViking did not return user_key") + return self.root_credential(user_id) self.store.save_user_key(user_id, user_key) - return user_key + return OpenVikingCredential(api_key=user_key) - async def ensure_session(self, user_key: str, session_id: str) -> dict[str, Any]: - async with self._client(user_key) as client: + def root_credential(self, user_id: str) -> OpenVikingCredential: + return OpenVikingCredential(api_key=self.root_key, account_id=user_id, user_id=user_id) + + async def ensure_session(self, credential: OpenVikingCredential | str, session_id: str) -> dict[str, Any]: + async with self._credential_client(credential) as client: response = await client.post("/api/v1/sessions", json={"session_id": session_id}) if response.status_code in {409, 422}: return {"session_id": session_id, "status": "exists"} response.raise_for_status() return response.json() - async def append_message(self, user_key: str, session_id: str, role: str, content: str) -> dict[str, Any]: - async with self._client(user_key) as client: + async def append_message( + self, credential: OpenVikingCredential | str, session_id: str, role: str, content: str + ) -> dict[str, Any]: + async with self._credential_client(credential) as client: response = await client.post( f"/api/v1/sessions/{session_id}/messages", json={"role": role, "content": content}, @@ -61,26 +76,28 @@ class OpenVikingMemorySystemClient: response.raise_for_status() return response.json() - async def commit_session(self, user_key: str, session_id: str) -> dict[str, Any]: - async with self._client(user_key) as client: + async def commit_session(self, credential: OpenVikingCredential | str, session_id: str) -> dict[str, Any]: + async with self._credential_client(credential) as client: response = await client.post(f"/api/v1/sessions/{session_id}/commit") response.raise_for_status() return response.json() - async def extract_session(self, user_key: str, session_id: str) -> dict[str, Any]: - async with self._client(user_key) as client: + async def extract_session(self, credential: OpenVikingCredential | str, session_id: str) -> dict[str, Any]: + async with self._credential_client(credential) as client: response = await client.post(f"/api/v1/sessions/{session_id}/extract") response.raise_for_status() return response.json() - async def get_task(self, user_key: str, task_id: str) -> dict[str, Any]: - async with self._client(user_key) as client: + async def get_task(self, credential: OpenVikingCredential | str, task_id: str) -> dict[str, Any]: + async with self._credential_client(credential) as client: response = await client.get(f"/api/v1/tasks/{task_id}") response.raise_for_status() return response.json() - async def find(self, user_key: str, user_id: str, query: str, limit: int) -> dict[str, Any]: - async with self._client(user_key) as client: + async def find( + self, credential: OpenVikingCredential | str, user_id: str, query: str, limit: int + ) -> dict[str, Any]: + async with self._credential_client(credential) as client: response = await client.post( "/api/v1/search/find", json={ @@ -92,19 +109,34 @@ class OpenVikingMemorySystemClient: response.raise_for_status() return response.json() - async def search(self, user_key: str, session_id: str | None, query: str, limit: int) -> dict[str, Any]: + async def search( + self, credential: OpenVikingCredential | str, session_id: str | None, query: str, limit: int + ) -> dict[str, Any]: payload: dict[str, Any] = {"query": query, "limit": limit} if session_id: payload["session_id"] = session_id - async with self._client(user_key) as client: + async with self._credential_client(credential) as client: response = await client.post("/api/v1/search/search", json=payload) response.raise_for_status() return response.json() - def _client(self, api_key: str) -> httpx.AsyncClient: + def _credential_client(self, credential: OpenVikingCredential | str) -> httpx.AsyncClient: + if isinstance(credential, str): + return self._client(credential) + headers = {} + if credential.account_id: + headers["X-OpenViking-Account"] = credential.account_id + if credential.user_id: + headers["X-OpenViking-User"] = credential.user_id + return self._client(credential.api_key, headers) + + def _client(self, api_key: str, extra_headers: dict[str, str] | None = None) -> httpx.AsyncClient: + headers = {"X-API-Key": api_key, "Content-Type": "application/json"} + if extra_headers: + headers.update(extra_headers) return httpx.AsyncClient( base_url=self.base_url, - headers={"X-API-Key": api_key, "Content-Type": "application/json"}, + headers=headers, timeout=self.timeout, verify=self.verify_ssl, ) diff --git a/memory_system_api/config.py b/memory_system_api/config.py index 6b74a51..8976863 100644 --- a/memory_system_api/config.py +++ b/memory_system_api/config.py @@ -25,7 +25,7 @@ class OpenVikingConfig(BaseModel): class EverOSConfig(BaseModel): url: str = "http://127.0.0.1:1995" api_key: str = "" - timeout: int = 30 + timeout: int = 180 verify_ssl: bool = True health_path: str = "/health" diff --git a/memory_system_api/service.py b/memory_system_api/service.py index 053fcfb..6dc7093 100644 --- a/memory_system_api/service.py +++ b/memory_system_api/service.py @@ -27,16 +27,14 @@ class MemorySystemService: if not messages: raise ValueError("at least one message is required") - user_key = await self.openviking.ensure_user(request.user_id) - await self.openviking.ensure_session(user_key, request.session_id) - - print("user_key:", user_key) # Debugging line to check the user_key value + credential = await self.openviking.ensure_user(request.user_id) + await self.openviking.ensure_session(credential, request.session_id) async def write_openviking() -> list[dict[str, Any]]: results = [] for message in messages: results.append( - await self.openviking.append_message(user_key, request.session_id, message["role"], message["content"]) + await self.openviking.append_message(credential, request.session_id, message["role"], message["content"]) ) return results @@ -97,6 +95,7 @@ class MemorySystemService: ) backends = await self._run_backends(openviking=search_openviking, everos=search_everos) + backends = self._remove_vectors_from_backends(backends) items = self._merge_search_items(backends) return SearchResponse(status=self._aggregate_status(backends), items=items[: request.limit], backends=backends) @@ -126,7 +125,9 @@ class MemorySystemService: try: return BackendStatus(status="success", result=await call()) except Exception as exc: # noqa: BLE001 - return BackendStatus(status="failed", error=str(exc)) + message = str(exc) + error = f"{type(exc).__name__}: {message}" if message else type(exc).__name__ + return BackendStatus(status="failed", error=error) def _aggregate_status(self, backends: dict[str, BackendStatus]) -> str: statuses = {backend.status for backend in backends.values()} @@ -163,3 +164,16 @@ class MemorySystemService: if "source_backend" in item: return item return {"source_backend": backend_name, **item} + + def _remove_vectors_from_backends(self, backends: dict[str, BackendStatus]) -> dict[str, BackendStatus]: + return { + name: backend.model_copy(update={"result": self._remove_vectors(backend.result)}) + for name, backend in backends.items() + } + + def _remove_vectors(self, value: Any) -> Any: + if isinstance(value, dict): + return {key: self._remove_vectors(item) for key, item in value.items() if key != "vector"} + if isinstance(value, list): + return [self._remove_vectors(item) for item in value] + return value diff --git a/tests/test_memory_system_clients.py b/tests/test_memory_system_clients.py index a4f2f33..e933b2f 100644 --- a/tests/test_memory_system_clients.py +++ b/tests/test_memory_system_clients.py @@ -1,4 +1,96 @@ -from memory_system_api.clients import EverOSMemorySystemClient +import asyncio + +from memory_system_api.clients import EverOSMemorySystemClient, OpenVikingMemorySystemClient + + +class FakeStore: + def __init__(self): + self.saved = {} + + def get_user_key(self, user_id: str) -> str | None: + return self.saved.get(user_id) + + def save_user_key(self, user_id: str, user_key: str) -> None: + self.saved[user_id] = user_key + + +class FakeResponse: + def __init__(self, status_code: int, data: dict): + self.status_code = status_code + self._data = data + + def json(self) -> dict: + return self._data + + def raise_for_status(self) -> None: + if self.status_code >= 400: + raise AssertionError(f"unexpected status {self.status_code}") + + +class FakeAsyncClient: + def __init__(self, calls: list, responses: list[FakeResponse], api_key: str, headers: dict): + self.calls = calls + self.responses = responses + self.api_key = api_key + self.headers = headers + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def post(self, path: str, json: dict | None = None) -> FakeResponse: + self.calls.append(("post", self.api_key, self.headers, path, json)) + return self.responses.pop(0) + + +def test_openviking_uses_root_identity_when_account_already_exists(): + store = FakeStore() + client = OpenVikingMemorySystemClient(store=store) + client.root_key = "root-key" + calls = [] + responses = [FakeResponse(409, {"status": "error", "error": {"code": "CONFLICT"}})] + client._client = lambda api_key, extra_headers=None: FakeAsyncClient( # type: ignore[method-assign] + calls, + responses, + api_key, + extra_headers or {}, + ) + + credential = asyncio.run(client.ensure_user("tom")) + + assert credential.api_key == "root-key" + assert credential.account_id == "tom" + assert credential.user_id == "tom" + assert store.saved == {} + + +def test_openviking_root_identity_headers_are_sent_for_session_create(): + client = OpenVikingMemorySystemClient(store=FakeStore()) + client.root_key = "root-key" + calls = [] + responses = [FakeResponse(200, {"status": "ok", "result": {"session_id": "sess-2"}})] + client._client = lambda api_key, extra_headers=None: FakeAsyncClient( # type: ignore[method-assign] + calls, + responses, + api_key, + extra_headers or {}, + ) + credential = client.root_credential("tom") + + result = asyncio.run(client.ensure_session(credential, "sess-2")) + + assert result == {"status": "ok", "result": {"session_id": "sess-2"}} + assert calls == [ + ( + "post", + "root-key", + {"X-OpenViking-Account": "tom", "X-OpenViking-User": "tom"}, + "/api/v1/sessions", + {"session_id": "sess-2"}, + ) + ] def test_everos_assistant_payload_does_not_use_user_id_as_sender(): diff --git a/tests/test_memory_system_service.py b/tests/test_memory_system_service.py index ece25b4..b42dd6b 100644 --- a/tests/test_memory_system_service.py +++ b/tests/test_memory_system_service.py @@ -51,6 +51,62 @@ class FakeEverOS: return {"items": [{"source": f"everos-{method}"}]} +class FakeEverOSWithVector(FakeEverOS): + async def search(self, user_id: str, session_id: str | None, query: str, method: str, limit: int) -> dict: + self.calls.append(("search", user_id, session_id, query, method, limit)) + return { + "data": { + "episodes": [{"id": "episode-1", "vector": [0.1, 0.2]}], + "original_data": { + "episodes": { + "episode-1": { + "summary": "喜欢拿铁", + "vector": [0.1, 0.2], + "nested": {"vector": [0.3]}, + } + } + }, + } + } + + +def test_capture_includes_exception_type_when_message_is_empty(): + service = MemorySystemService(openviking=FakeOpenViking(), everos=FakeEverOS()) + + class EmptyError(Exception): + pass + + async def fail(): + raise EmptyError() + + response = asyncio.run(service._capture(fail)) + + assert response.status == "failed" + assert response.error == "EmptyError" + + +def test_search_removes_vectors_from_items_and_backend_results(): + service = MemorySystemService(openviking=FakeOpenViking(), everos=FakeEverOSWithVector()) + + response = asyncio.run(service.search( + SearchRequest(user_id="tom", session_id="sess-1", query="咖啡偏好", use_llm=False, limit=5) + )) + + assert response.items == [ + {"source_backend": "openviking", "source": "openviking-find"}, + {"source_backend": "everos", "id": "episode-1"}, + ] + assert not _has_key(response.backends["everos"].result, "vector") + + +def _has_key(value, key: str) -> bool: + if isinstance(value, dict): + return key in value or any(_has_key(item, key) for item in value.values()) + if isinstance(value, list): + return any(_has_key(item, key) for item in value) + return False + + def test_ingest_splits_user_and_assistant_messages(): openviking = FakeOpenViking() everos = FakeEverOS()