from __future__ import annotations import json import logging import time import urllib.error import urllib.request from typing import Any from .config import PluginConfig, load_config _logger = logging.getLogger(__name__) def _short_error(value: Any, max_chars: int = 500) -> str: text = str(value).replace("\n", " ").strip() return text[:max_chars] class MemoryGatewayClient: def __init__(self, config: PluginConfig | None = None) -> None: self.config = config or load_config() def _headers(self) -> dict[str, str]: headers = {"Content-Type": "application/json"} if self.config.api_key: headers["X-API-Key"] = self.config.api_key return headers def _post(self, endpoint: str, payload: dict[str, Any], retries: int = 3, backoff: float = 1.0) -> dict[str, Any]: url = self.config.gateway_url.rstrip("/") + endpoint body = json.dumps(payload, ensure_ascii=False).encode("utf-8") last_error: Exception | None = None for attempt in range(retries): request = urllib.request.Request(url, data=body, headers=self._headers(), method="POST") try: with urllib.request.urlopen(request, timeout=self.config.timeout) as response: raw = response.read().decode("utf-8") data = json.loads(raw) if raw else {} return { "ok": True, "status_code": getattr(response, "status", 200), "endpoint": endpoint, "data": data, } except urllib.error.HTTPError as exc: # Typically, client errors (4xx) shouldn't be retried unless specifically handled. # Since HTTPError is a subclass of URLError, we catch it first. if exc.code < 500 and exc.code != 429: try: body_text = exc.read().decode("utf-8") except Exception: body_text = exc.reason _logger.error(f"HTTPError in _post to {endpoint}: {exc.code} {body_text}") return { "ok": False, "status_code": exc.code, "endpoint": endpoint, "error": _short_error(body_text), } last_error = exc except (urllib.error.URLError, TimeoutError, OSError) as exc: last_error = exc except Exception as exc: _logger.error("Unexpected error in _post to %s: %s", endpoint, exc, exc_info=True) return { "ok": False, "status_code": None, "endpoint": endpoint, "error": _short_error(exc), } if attempt < retries - 1: time.sleep(backoff * (2 ** attempt)) # Exhausted retries error_msg = str(last_error) if last_error else "Max retries exceeded" _logger.error("Failed _post to %s after %d attempts. Last error: %s", endpoint, retries, last_error) return { "ok": False, "status_code": None, "endpoint": endpoint, "error": error_msg, } def search_memory(self, payload: dict[str, Any]) -> dict[str, Any]: return self._post("/v1/memory/search", payload) def append_episode(self, payload: dict[str, Any]) -> dict[str, Any]: return self._post("/v1/episodes", payload) def commit_session(self, session_id: str, payload: dict[str, Any]) -> dict[str, Any]: return self._post(f"/v1/sessions/{session_id}/commit", payload) def upsert_memory(self, payload: dict[str, Any]) -> dict[str, Any]: return self._post("/v1/memory", payload) def send_feedback(self, memory_id: str, payload: dict[str, Any]) -> dict[str, Any]: return self._post(f"/v1/memory/{memory_id}/feedback", payload)