69 lines
2.4 KiB
Python
69 lines
2.4 KiB
Python
"""Small asynchronous client for the Memory Gateway API."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
import httpx
|
|
|
|
from beaver.foundation.config import MemoryGatewayConfig
|
|
|
|
|
|
class MemoryGatewayClientError(RuntimeError):
    """Sanitized Gateway transport or response failure.

    The message and attributes carry only the operation name, a short failure
    category and, when known, the HTTP status code.

    Attributes:
        operation: Gateway operation that failed (e.g. ``"search"``).
        category: Short machine-readable failure category.
        status_code: HTTP status code of the failed response, or ``None`` when
            the failure was not an HTTP status error.
    """

    def __init__(self, operation: str, category: str, *, status_code: int | None = None) -> None:
        # Append the status only when one was supplied.
        suffix = "" if status_code is None else f" status={status_code}"
        super().__init__(f"Memory Gateway {operation} failed: {category}{suffix}")
        self.operation = operation
        self.category = category
        self.status_code = status_code
|
|
|
|
|
|
class MemoryGatewayClient:
    """HTTP transport for search, add, and flush operations.

    Each call opens a short-lived ``httpx.AsyncClient`` against
    ``config.base_url``, POSTs the payload as JSON, and returns the decoded
    JSON object. Gateway failures are raised as ``MemoryGatewayClientError``
    with exception chaining suppressed (``from None``), so the underlying
    httpx exception is not attached to the raised error.
    """

    def __init__(
        self,
        config: MemoryGatewayConfig,
        *,
        transport: httpx.AsyncBaseTransport | None = None,
    ) -> None:
        """Store the gateway configuration and an optional custom transport.

        Args:
            config: Gateway settings; ``base_url`` and ``timeout_seconds`` are
                read on every request.
            transport: Optional httpx transport passed to each client; ``None``
                lets httpx use its default transport.
        """
        self.config = config
        self.transport = transport

    async def search(self, payload: dict[str, Any]) -> dict[str, Any]:
        """POST ``payload`` to ``/memories/search`` and return the JSON object."""
        return await self._post("search", "/memories/search", payload)

    async def add(self, payload: dict[str, Any]) -> dict[str, Any]:
        """POST ``payload`` to ``/memories/add`` and return the JSON object."""
        return await self._post("add", "/memories/add", payload)

    async def flush(self, payload: dict[str, Any]) -> dict[str, Any]:
        """POST ``payload`` to ``/memories/flush`` and return the JSON object."""
        return await self._post("flush", "/memories/flush", payload)

    async def _post(self, operation: str, path: str, payload: dict[str, Any]) -> dict[str, Any]:
        """POST ``payload`` as JSON to ``path`` and return the decoded JSON object.

        Args:
            operation: Label used in error messages (``search``/``add``/``flush``).
            path: Request path, resolved relative to ``config.base_url``.
            payload: JSON-serializable request body.

        Returns:
            The response body decoded as a JSON object (``dict``).

        Raises:
            MemoryGatewayClientError: category ``http_status`` for a non-success
                status (``status_code`` set), ``network`` for httpx request or
                transport errors, ``invalid_json`` when the response body is not
                valid JSON, and ``invalid_response`` when it is valid JSON but
                not an object.
            TypeError, ValueError: if ``payload`` cannot be encoded as JSON.
                This is a caller bug, not a gateway failure, so it is not
                wrapped.
        """
        try:
            async with httpx.AsyncClient(
                base_url=self.config.base_url.rstrip("/"),
                timeout=self.config.timeout_seconds,
                transport=self.transport,
                # Ignore environment-based configuration (proxies, certs) so
                # behaviour depends only on the explicit config.
                trust_env=False,
            ) as client:
                response = await client.post(path, json=payload)
                response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            raise MemoryGatewayClientError(
                operation,
                "http_status",
                status_code=exc.response.status_code,
            ) from None
        except httpx.RequestError:
            raise MemoryGatewayClientError(operation, "network") from None

        # Decode in a separate, narrow ``try``: only a malformed *response* body
        # should be reported as ``invalid_json``. Previously a request-side
        # ValueError (e.g. a circular reference in ``payload``) raised while
        # httpx encoded the body was mislabelled as ``invalid_json``.
        try:
            data = response.json()
        except ValueError:
            raise MemoryGatewayClientError(operation, "invalid_json") from None

        if not isinstance(data, dict):
            raise MemoryGatewayClientError(operation, "invalid_response")
        return data
|