from __future__ import annotations import json import logging import time from datetime import datetime, timezone from typing import Any, Literal from urllib.parse import parse_qsl, quote, urlsplit, urlunsplit from fastapi import APIRouter, FastAPI, File, Form, HTTPException, Request, UploadFile from pydantic import BaseModel, Field from starlette.responses import Response from .config import GatewayConfig from .db import init_db from .everos_client import EverOSClient from .repository import MemoryRepository from .service import MemoryGatewayService, UnsupportedContentType, UploadTooLarge API_LOGGER = logging.getLogger("memory_gateway.api") MAX_LOG_BODY_BYTES = 4096 REDACTED = "[REDACTED]" SENSITIVE_FIELD_NAMES = { "api_key", "authorization", "password", "secret", "token", "user_key", } class SearchMemoriesRequest(BaseModel): user_id: str = Field(min_length=1) user_key: str = Field(min_length=1) conversation_id: str | None = None query: str = Field(min_length=1) scope: list[Literal["current_chat", "resources", "all_user_memory"]] = Field( default_factory=lambda: ["current_chat", "resources"] ) top_k: int = Field(default=8, ge=1, le=100) app_id: str = "default" project_id: str = "default" class AddMemoryMessage(BaseModel): sender_id: str = Field(min_length=1) role: Literal["user", "assistant", "tool"] timestamp: int = Field(gt=0) content: str | list[dict[str, Any]] class AddMemoryRequest(BaseModel): user_id: str = Field(min_length=1) user_key: str = Field(min_length=1) session_id: str = Field(min_length=1) messages: list[AddMemoryMessage] = Field(min_length=1) app_id: str = "default" project_id: str = "default" class FlushMemoryRequest(BaseModel): user_id: str = Field(min_length=1) user_key: str = Field(min_length=1) session_id: str = Field(min_length=1) app_id: str = "default" project_id: str = "default" class MemoryOverrideRequest(BaseModel): user_id: str = Field(min_length=1) user_key: str = Field(min_length=1) session_id: str = Field(min_length=1) override_text: str = Field(min_length=1) class MemoryDeleteRequest(BaseModel): user_id: str = Field(min_length=1) user_key: str = Field(min_length=1) session_id: str = Field(min_length=1) reason: str | None = None class UserCreateRequest(BaseModel): user_id: str = Field(min_length=1) def _is_sensitive_field(key: str) -> bool: lowered = key.lower() return lowered in SENSITIVE_FIELD_NAMES or lowered.endswith("_token") def _redact(value: Any, key: str | None = None) -> Any: if key is not None and _is_sensitive_field(key): return REDACTED if isinstance(value, dict): return { item_key: _redact(item_value, item_key) for item_key, item_value in value.items() } if isinstance(value, list): return [_redact(item) for item in value] return value def _redacted_query_params(items: list[tuple[str, str]]) -> dict[str, Any]: result: dict[str, Any] = {} for key, value in items: safe_value = REDACTED if _is_sensitive_field(key) else value if key in result: existing = result[key] if isinstance(existing, list): existing.append(safe_value) else: result[key] = [existing, safe_value] else: result[key] = safe_value return result def _redacted_url(url: str) -> str: parts = urlsplit(url) query = "&".join( f"{quote(key, safe='')}=" f"{quote(REDACTED if _is_sensitive_field(key) else value, safe='[]')}" for key, value in parse_qsl(parts.query, keep_blank_values=True) ) return urlunsplit((parts.scheme, parts.netloc, parts.path, query, parts.fragment)) def _should_capture_request_body( content_type: str | None, content_length: int | None, ) -> bool: normalized = (content_type or "").lower() if normalized.startswith("multipart/"): return False return content_length is None or content_length <= MAX_LOG_BODY_BYTES def _uncaptured_body_for_log( content_type: str | None, content_length: int | None, ) -> dict[str, Any]: normalized = (content_type or "").lower() result: dict[str, Any] = { "content_type": normalized, "size_bytes": content_length, } if not normalized.startswith("multipart/"): result["truncated"] = True return result def _body_for_log(body: bytes, content_type: str | None) -> Any: if not body: return None content_type = (content_type or "").lower() if content_type.startswith("multipart/"): return {"content_type": content_type, "size_bytes": len(body)} if len(body) > MAX_LOG_BODY_BYTES: return { "truncated": True, "size_bytes": len(body), "content_type": content_type, } text = body.decode("utf-8", errors="replace") if "application/json" in content_type: try: return _redact(json.loads(text)) except json.JSONDecodeError: return text if "application/x-www-form-urlencoded" in content_type: return _redacted_query_params(parse_qsl(text, keep_blank_values=True)) if content_type.startswith("text/"): return text return {"content_type": content_type, "size_bytes": len(body)} def create_app( *, config: GatewayConfig | None = None, everos_client: Any | None = None, ) -> FastAPI: cfg = config or GatewayConfig.from_env() init_db(cfg.database_path) repository = MemoryRepository(cfg.database_path) client = everos_client or EverOSClient( cfg.everos_base_url, timeout=cfg.everos_timeout_seconds, ) service = MemoryGatewayService(cfg, repository, client) app = FastAPI(title="memory-gateway", version="0.1.0") app.state.config = cfg app.state.repository = repository app.state.everos_client = client app.state.gateway_service = service router = APIRouter() @app.middleware("http") async def log_api_request(request: Request, call_next: Any) -> Response: request_time = datetime.now(timezone.utc).isoformat() started = time.perf_counter() request_content_type = request.headers.get("content-type") raw_content_length = request.headers.get("content-length") try: request_content_length = ( int(raw_content_length) if raw_content_length is not None else None ) except ValueError: request_content_length = None capture_request_body = _should_capture_request_body( request_content_type, request_content_length, ) request_body = await request.body() if capture_request_body else None response_body = b"" status_code = 500 error: str | None = None try: response = await call_next(request) status_code = response.status_code async for chunk in response.body_iterator: response_body += chunk logged_response = Response( content=response_body, status_code=response.status_code, headers=dict(response.headers), media_type=response.media_type, background=response.background, ) except Exception as exc: error = str(exc) raise finally: duration_ms = round((time.perf_counter() - started) * 1000, 3) query_items = list(request.query_params.multi_items()) output: dict[str, Any] = {"status_code": status_code} if error is not None: output["error"] = error else: output["body"] = _body_for_log( response_body, response.headers.get("content-type"), ) event = { "request_time": request_time, "duration_ms": duration_ms, "method": request.method, "path": request.url.path, "url": _redacted_url(str(request.url)), "client": request.client.host if request.client else None, "input": { "query_params": _redacted_query_params(query_items), "body": ( _body_for_log(request_body, request_content_type) if request_body is not None else _uncaptured_body_for_log( request_content_type, request_content_length, ) ), }, "output": output, } API_LOGGER.info(json.dumps(event, ensure_ascii=False, default=str)) return logged_response def require_user(user_id: str, user_key: str) -> None: if not service.authenticate_user(user_id, user_key): raise HTTPException(status_code=401, detail="invalid user credentials") @router.get("/health") async def health() -> dict[str, Any]: try: everos_health = await client.health_check() except Exception as exc: return { "status": "degraded", "api": {"status": "ok"}, "everos": { "status": "unavailable", "base_url": cfg.everos_base_url, "error": str(exc), }, } return { "status": "ok", "api": {"status": "ok"}, "everos": { "status": "ok", "base_url": cfg.everos_base_url, "data": everos_health, }, } @router.post("/users") async def create_user(request: UserCreateRequest) -> dict[str, Any]: return service.create_user(request.user_id) @router.post("/resources") async def upload_resource( user_id: str = Form(...), user_key: str = Form(...), app_id: str = Form("default"), project_id: str = Form("default"), title: str | None = Form(None), description: str | None = Form(None), file: UploadFile = File(...), ) -> dict[str, Any]: require_user(user_id, user_key) try: return await service.upload_resource( user_id=user_id, app_id=app_id, project_id=project_id, file=file, title=title, description=description, ) except UploadTooLarge as exc: raise HTTPException(status_code=413, detail=str(exc)) from exc except UnsupportedContentType as exc: raise HTTPException(status_code=415, detail=str(exc)) from exc @router.get("/resources") async def list_resources( user_id: str, user_key: str, ) -> dict[str, Any]: require_user(user_id, user_key) return {"resources": service.list_resources(user_id)} @router.get("/resources/{resource_id}") async def get_resource( resource_id: str, user_id: str, user_key: str, ) -> dict[str, Any]: require_user(user_id, user_key) resource = service.get_resource_detail(resource_id, user_id) if resource is None: return {"resources": []} return {"resources": [resource]} @router.delete("/resources/{resource_id}") async def delete_resource( resource_id: str, user_id: str, user_key: str, ) -> dict[str, Any]: require_user(user_id, user_key) resource = service.delete_resource(resource_id, user_id) if resource is None: raise HTTPException(status_code=404, detail="resource not found") return resource @router.post("/memories/search") async def search_memories( request: SearchMemoriesRequest, ) -> dict[str, Any]: require_user(request.user_id, request.user_key) return await service.search_memories( user_id=request.user_id, query=request.query, conversation_id=request.conversation_id, scope=request.scope, top_k=request.top_k, app_id=request.app_id, project_id=request.project_id, ) @router.post("/memories/add") async def add_memory( request: AddMemoryRequest, ) -> dict[str, Any]: require_user(request.user_id, request.user_key) return await service.add_memory( session_id=request.session_id, app_id=request.app_id, project_id=request.project_id, messages=[message.model_dump() for message in request.messages], ) @router.post("/memories/flush") async def flush_memory( request: FlushMemoryRequest, ) -> dict[str, Any]: require_user(request.user_id, request.user_key) return await service.flush_memory( session_id=request.session_id, app_id=request.app_id, project_id=request.project_id, ) @router.patch("/memories/{memory_id}") async def patch_memory( memory_id: str, request: MemoryOverrideRequest, ) -> dict[str, Any]: require_user(request.user_id, request.user_key) try: service.assert_memory_session_owned(request.user_id, request.session_id) except PermissionError as exc: raise HTTPException(status_code=403, detail=str(exc)) from exc return service.upsert_override( user_id=request.user_id, memory_id=memory_id, session_id=request.session_id, override_text=request.override_text, ) @router.delete("/memories/{memory_id}") async def delete_memory( memory_id: str, request: MemoryDeleteRequest, ) -> dict[str, Any]: require_user(request.user_id, request.user_key) try: service.assert_memory_session_owned(request.user_id, request.session_id) except PermissionError as exc: raise HTTPException(status_code=403, detail=str(exc)) from exc return service.delete_memory( user_id=request.user_id, memory_id=memory_id, session_id=request.session_id, reason=request.reason, ) app.include_router(router) return app