Files
memory-gateway/core/api.py

467 lines
15 KiB
Python

from __future__ import annotations
import json
import logging
import time
from datetime import datetime, timezone
from typing import Any, Literal
from urllib.parse import parse_qsl, quote, urlsplit, urlunsplit
from fastapi import APIRouter, FastAPI, File, Form, HTTPException, Request, UploadFile
from pydantic import BaseModel, Field, field_validator
from starlette.responses import Response
from .config import GatewayConfig
from .db import init_db
from .backend_client import BackendClient
from .repository import MemoryRepository
from .service import (
InvalidAttachment,
MemoryGatewayService,
UnsupportedContentType,
UploadTooLarge,
)
# Logger that receives one JSON-encoded event per handled HTTP request.
API_LOGGER = logging.getLogger("memory_gateway.api")

# Request/response bodies above this size (bytes) are summarised rather than
# logged verbatim.
MAX_LOG_BODY_BYTES = 4 * 1024

# Placeholder substituted for sensitive values before anything is logged.
REDACTED = "[REDACTED]"

# Lower-cased field / query-parameter names whose values are always redacted.
SENSITIVE_FIELD_NAMES = {
    "authorization",
    "api_key",
    "user_key",
    "password",
    "secret",
    "token",
}
class SearchMemoriesRequest(BaseModel):
    """Request body for ``POST /memories/search``.

    ``top_k`` must be ``-1`` or an integer in ``1..100`` (inclusive).
    """

    user_id: str = Field(min_length=1)
    user_key: str = Field(min_length=1)
    agent_id: str | None = Field(default=None, min_length=1)
    conversation_id: str | None = None
    query: str = Field(min_length=1)
    scope: list[Literal["current_chat", "resources", "all_user_memory"]] = Field(
        default_factory=lambda: ["current_chat", "resources"]
    )
    method: Literal["keyword", "vector", "hybrid", "agentic"] = "hybrid"
    # NOTE(review): -1 presumably means "no limit" downstream — confirm in service.
    top_k: int = 8
    radius: float | None = Field(default=None, ge=0, le=1)
    include_profile: bool = True
    enable_llm_rerank: bool = True
    filters: dict[str, Any] | None = None
    app_id: str = "default"
    project_id: str = "default"

    @field_validator("top_k")
    @classmethod
    def validate_top_k(cls, value: int) -> int:
        """Accept ``-1`` or a value in ``1..100``; raise ``ValueError`` otherwise."""
        in_allowed_range = 1 <= value <= 100
        if value == -1 or in_allowed_range:
            return value
        raise ValueError("top_k must be -1 or in 1..100")
class AddMemoryMessage(BaseModel):
    """One chat message inside an ``AddMemoryRequest``."""

    sender_id: str = Field(min_length=1)
    role: Literal["user", "assistant", "tool"]
    # Must be strictly positive; unit (seconds vs. milliseconds) not visible here — TODO confirm.
    timestamp: int = Field(gt=0)
    # Plain text, or a list of structured parts; the part schema is defined
    # by the service layer (presumably including attachments) — verify there.
    content: str | list[dict[str, Any]]
class AddMemoryRequest(BaseModel):
    """Request body for ``POST /memories/add``; requires at least one message."""

    user_id: str = Field(min_length=1)
    user_key: str = Field(min_length=1)
    session_id: str = Field(min_length=1)
    messages: list[AddMemoryMessage] = Field(min_length=1)
    app_id: str = "default"
    project_id: str = "default"
class FlushMemoryRequest(BaseModel):
    """Request body for ``POST /memories/flush``."""

    user_id: str = Field(min_length=1)
    user_key: str = Field(min_length=1)
    session_id: str = Field(min_length=1)
    app_id: str = "default"
    project_id: str = "default"
class MemoryOverrideRequest(BaseModel):
    """Request body for ``PATCH /memories/{memory_id}`` (replacement text)."""

    user_id: str = Field(min_length=1)
    user_key: str = Field(min_length=1)
    # Session the memory belongs to; ownership is checked before the override.
    session_id: str = Field(min_length=1)
    override_text: str = Field(min_length=1)
class MemoryDeleteRequest(BaseModel):
    """Request body for ``DELETE /memories/{memory_id}``."""

    user_id: str = Field(min_length=1)
    user_key: str = Field(min_length=1)
    # Session the memory belongs to; ownership is checked before deletion.
    session_id: str = Field(min_length=1)
    # Optional free-text justification forwarded to the service.
    reason: str | None = None
class UserCreateRequest(BaseModel):
    """Request body for ``POST /users``."""

    user_id: str = Field(min_length=1)
def _is_sensitive_field(key: str) -> bool:
    """Return True when the value stored under *key* must be redacted in logs.

    Matching is case-insensitive and treats ``-`` like ``_``. Header-style
    names such as ``X-Api-Key`` or ``Refresh-Token`` are therefore caught as
    well. A key is sensitive when its normalised form is listed in
    ``SENSITIVE_FIELD_NAMES`` or ends in one of the secret-like suffixes below.
    Every key that was sensitive before this normalisation is still sensitive.
    """
    normalized = key.lower().replace("-", "_")
    if normalized in SENSITIVE_FIELD_NAMES:
        return True
    # Suffixes catch prefixed variants: access_token, client_secret,
    # db_password, x_api_key, ...
    return normalized.endswith(("_token", "_secret", "_password", "_api_key"))
def _redact(value: Any, key: str | None = None) -> Any:
    """Return a copy of *value* with sensitive entries replaced by ``REDACTED``.

    If *key* (the name *value* was stored under) is sensitive, the whole value
    is masked. Otherwise dicts are rebuilt recursively, with each item checked
    against its own key. Lists are rebuilt element by element, and scalars are
    returned as-is.
    """
    if key is not None and _is_sensitive_field(key):
        return REDACTED
    if isinstance(value, list):
        return [_redact(element) for element in value]
    if isinstance(value, dict):
        return {name: _redact(inner, name) for name, inner in value.items()}
    return value
def _redacted_query_params(items: list[tuple[str, str]]) -> dict[str, Any]:
    """Fold ``(name, value)`` pairs into a dict suitable for logging.

    Sensitive names get ``REDACTED`` as their value. A name seen once maps to
    a single string. A repeated name maps to a list of its values, in order.
    """
    collected: dict[str, Any] = {}
    for name, raw in items:
        shown = REDACTED if _is_sensitive_field(name) else raw
        if name not in collected:
            collected[name] = shown
            continue
        previous = collected[name]
        if isinstance(previous, list):
            previous.append(shown)
        else:
            collected[name] = [previous, shown]
    return collected
def _redacted_url(url: str) -> str:
    """Return *url* with the values of sensitive query parameters masked.

    The query string is decoded, sensitive values are swapped for
    ``REDACTED`` and the string is re-encoded. Names are fully
    percent-encoded; values keep ``[`` and ``]`` so the placeholder stays
    readable. Scheme, host, path and fragment are preserved.
    """
    pieces = urlsplit(url)
    encoded_pairs = []
    for name, raw in parse_qsl(pieces.query, keep_blank_values=True):
        shown = REDACTED if _is_sensitive_field(name) else raw
        encoded_pairs.append(quote(name, safe="") + "=" + quote(shown, safe="[]"))
    return urlunsplit(
        (pieces.scheme, pieces.netloc, pieces.path, "&".join(encoded_pairs), pieces.fragment)
    )
def _should_capture_request_body(
    content_type: str | None,
    content_length: int | None,
) -> bool:
    """Decide whether the logging middleware should read the request body.

    Multipart bodies are never captured. Other bodies are captured when their
    declared length is at most ``MAX_LOG_BODY_BYTES`` or when no length was
    declared.
    """
    if (content_type or "").lower().startswith("multipart/"):
        return False
    # NOTE(review): an unknown length (e.g. chunked transfer) is captured in
    # full regardless of size; only the *logged* form is truncated later.
    if content_length is None:
        return True
    return content_length <= MAX_LOG_BODY_BYTES
def _uncaptured_body_for_log(
content_type: str | None,
content_length: int | None,
) -> dict[str, Any]:
normalized = (content_type or "").lower()
result: dict[str, Any] = {
"content_type": normalized,
"size_bytes": content_length,
}
if not normalized.startswith("multipart/"):
result["truncated"] = True
return result
def _body_for_log(body: bytes, content_type: str | None) -> Any:
    """Convert a raw HTTP body into a log-safe representation.

    Returns:
        ``None`` for an empty body.
        A size summary for multipart bodies, for bodies larger than
        ``MAX_LOG_BODY_BYTES``, and for content types not handled below.
        Parsed and redacted data for JSON and form-urlencoded bodies.
        Raw text for ``text/*`` bodies.

    A body labelled as JSON that fails to parse is summarised rather than
    logged verbatim, since it cannot be redacted. Logging it as text would
    leak fields such as ``user_key``.
    """
    if not body:
        return None
    content_type = (content_type or "").lower()
    if content_type.startswith("multipart/"):
        return {"content_type": content_type, "size_bytes": len(body)}
    if len(body) > MAX_LOG_BODY_BYTES:
        return {
            "truncated": True,
            "size_bytes": len(body),
            "content_type": content_type,
        }
    text = body.decode("utf-8", errors="replace")
    if "application/json" in content_type:
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            # Unparseable JSON cannot be redacted, so never emit its text.
            return {
                "content_type": content_type,
                "size_bytes": len(body),
                "invalid_json": True,
            }
        return _redact(parsed)
    if "application/x-www-form-urlencoded" in content_type:
        return _redacted_query_params(parse_qsl(text, keep_blank_values=True))
    if content_type.startswith("text/"):
        return text
    return {"content_type": content_type, "size_bytes": len(body)}
def create_app(
    *,
    config: GatewayConfig | None = None,
    backend_client: Any | None = None,
) -> FastAPI:
    """Build the memory-gateway FastAPI application.

    Initialises the database at ``cfg.database_path`` and wires the
    repository, backend client and gateway service. These are exposed on
    ``app.state``. The function also installs a structured request-logging
    middleware and registers all HTTP routes.

    Args:
        config: Gateway configuration; loaded via ``GatewayConfig.from_env()``
            when omitted.
        backend_client: Optional pre-built backend client; when omitted a
            ``BackendClient`` is created from ``cfg.backend_base_url`` and
            ``cfg.backend_timeout_seconds``.

    Returns:
        The configured ``FastAPI`` application.
    """
    cfg = config or GatewayConfig.from_env()
    init_db(cfg.database_path)
    repository = MemoryRepository(cfg.database_path)
    client = backend_client or BackendClient(
        cfg.backend_base_url,
        timeout=cfg.backend_timeout_seconds,
    )
    service = MemoryGatewayService(cfg, repository, client)
    app = FastAPI(title="memory-gateway", version="0.1.0")
    app.state.config = cfg
    app.state.repository = repository
    app.state.backend_client = client
    app.state.gateway_service = service
    router = APIRouter()

    @app.middleware("http")
    async def log_api_request(request: Request, call_next: Any) -> Response:
        """Log one JSON event per request with redacted input and output.

        The response body is fully buffered so it can be logged, then replayed
        on the original response object. Status and all headers, including
        repeated ones such as ``Set-Cookie``, are therefore sent unchanged.
        """
        request_time = datetime.now(timezone.utc).isoformat()
        started = time.perf_counter()
        request_content_type = request.headers.get("content-type")
        raw_content_length = request.headers.get("content-length")
        try:
            request_content_length = (
                int(raw_content_length) if raw_content_length is not None else None
            )
        except ValueError:
            request_content_length = None
        capture_request_body = _should_capture_request_body(
            request_content_type,
            request_content_length,
        )
        # Relies on Starlette caching the body so downstream handlers can
        # still read it after the middleware has.
        request_body = await request.body() if capture_request_body else None
        response: Any = None
        body_chunks: list[bytes] = []
        status_code = 500
        error: str | None = None
        try:
            response = await call_next(request)
            status_code = response.status_code
            async for chunk in response.body_iterator:
                body_chunks.append(chunk)
        except BaseException as exc:
            # BaseException also covers asyncio.CancelledError. The handler
            # used to catch only Exception; a cancellation then left `response`
            # unbound, and the logging below raised in place of the real error.
            error = str(exc) or type(exc).__name__
            raise
        finally:
            response_body = b"".join(body_chunks)
            duration_ms = round((time.perf_counter() - started) * 1000, 3)
            query_items = list(request.query_params.multi_items())
            output: dict[str, Any] = {"status_code": status_code}
            if error is not None:
                output["error"] = error
            elif response is not None:
                output["body"] = _body_for_log(
                    response_body,
                    response.headers.get("content-type"),
                )
            event = {
                "request_time": request_time,
                "duration_ms": duration_ms,
                "method": request.method,
                "path": request.url.path,
                "url": _redacted_url(str(request.url)),
                "client": request.client.host if request.client else None,
                "input": {
                    "query_params": _redacted_query_params(query_items),
                    "body": (
                        _body_for_log(request_body, request_content_type)
                        if request_body is not None
                        else _uncaptured_body_for_log(
                            request_content_type,
                            request_content_length,
                        )
                    ),
                },
                "output": output,
            }
            API_LOGGER.info(json.dumps(event, ensure_ascii=False, default=str))

        async def replay_body() -> Any:
            yield response_body

        # Re-use the original response object instead of rebuilding it with
        # dict(response.headers), which would collapse duplicate headers.
        response.body_iterator = replay_body()
        return response

    def require_user(user_id: str, user_key: str) -> None:
        """Raise HTTP 401 unless *user_key* authenticates *user_id*."""
        # NOTE(review): called synchronously from async handlers; if
        # authenticate_user does blocking I/O it runs on the event loop — confirm.
        if not service.authenticate_user(user_id, user_key):
            raise HTTPException(status_code=401, detail="invalid user credentials")

    @router.get("/health")
    async def health() -> dict[str, Any]:
        """Report API status; backend failures yield ``degraded`` instead of raising."""
        try:
            backend_health = await client.health_check()
        except Exception as exc:
            # NOTE(review): exposes backend URL and exception text to
            # unauthenticated callers — confirm this is intended.
            return {
                "status": "degraded",
                "api": {"status": "ok"},
                "backend": {
                    "status": "unavailable",
                    "base_url": cfg.backend_base_url,
                    "error": str(exc),
                },
            }
        return {
            "status": "ok",
            "api": {"status": "ok"},
            "backend": {
                "status": "ok",
                "base_url": cfg.backend_base_url,
                "data": backend_health,
            },
        }

    @router.post("/users")
    async def create_user(request: UserCreateRequest) -> dict[str, Any]:
        """Create a user via the service (no credentials required)."""
        return service.create_user(request.user_id)

    @router.post("/resources")
    async def upload_resource(
        user_id: str = Form(...),
        user_key: str = Form(...),
        app_id: str = Form("default"),
        project_id: str = Form("default"),
        title: str | None = Form(None),
        description: str | None = Form(None),
        file: UploadFile = File(...),
    ) -> dict[str, Any]:
        """Upload a file resource; maps size/type errors to HTTP 413/415."""
        require_user(user_id, user_key)
        try:
            return await service.upload_resource(
                user_id=user_id,
                app_id=app_id,
                project_id=project_id,
                file=file,
                title=title,
                description=description,
            )
        except UploadTooLarge as exc:
            raise HTTPException(status_code=413, detail=str(exc)) from exc
        except UnsupportedContentType as exc:
            raise HTTPException(status_code=415, detail=str(exc)) from exc

    @router.get("/resources")
    async def list_resources(
        user_id: str,
        user_key: str,
    ) -> dict[str, Any]:
        """List the authenticated user's resources."""
        require_user(user_id, user_key)
        return {"resources": service.list_resources(user_id)}

    @router.get("/resources/{resource_id}")
    async def get_resource(
        resource_id: str,
        user_id: str,
        user_key: str,
    ) -> dict[str, Any]:
        """Return ``{"resources": [detail]}``, or an empty list if not found."""
        require_user(user_id, user_key)
        resource = service.get_resource_detail(resource_id, user_id)
        # NOTE(review): a missing resource yields 200 with an empty list,
        # whereas DELETE on the same path returns 404 — confirm intended.
        if resource is None:
            return {"resources": []}
        return {"resources": [resource]}

    @router.delete("/resources/{resource_id}")
    async def delete_resource(
        resource_id: str,
        user_id: str,
        user_key: str,
    ) -> dict[str, Any]:
        """Delete a resource owned by the user; HTTP 404 when not found."""
        require_user(user_id, user_key)
        resource = service.delete_resource(resource_id, user_id)
        if resource is None:
            raise HTTPException(status_code=404, detail="resource not found")
        return resource

    @router.post("/memories/search")
    async def search_memories(
        request: SearchMemoriesRequest,
    ) -> dict[str, Any]:
        """Authenticate, then forward the search parameters to the service."""
        require_user(request.user_id, request.user_key)
        return await service.search_memories(
            user_id=request.user_id,
            agent_id=request.agent_id,
            query=request.query,
            conversation_id=request.conversation_id,
            scope=request.scope,
            method=request.method,
            top_k=request.top_k,
            radius=request.radius,
            include_profile=request.include_profile,
            enable_llm_rerank=request.enable_llm_rerank,
            filters=request.filters,
            app_id=request.app_id,
            project_id=request.project_id,
        )

    @router.post("/memories/add")
    async def add_memory(
        request: AddMemoryRequest,
    ) -> dict[str, Any]:
        """Add messages to a session; maps attachment errors to HTTP 413/422."""
        require_user(request.user_id, request.user_key)
        try:
            return await service.add_memory(
                user_id=request.user_id,
                session_id=request.session_id,
                app_id=request.app_id,
                project_id=request.project_id,
                messages=[message.model_dump() for message in request.messages],
            )
        except UploadTooLarge as exc:
            raise HTTPException(status_code=413, detail=str(exc)) from exc
        except InvalidAttachment as exc:
            raise HTTPException(status_code=422, detail=str(exc)) from exc

    @router.post("/memories/flush")
    async def flush_memory(
        request: FlushMemoryRequest,
    ) -> dict[str, Any]:
        """Flush a session's pending memory after authenticating the caller."""
        require_user(request.user_id, request.user_key)
        # NOTE(review): unlike PATCH/DELETE, no session-ownership check is made
        # and user_id is not passed on — confirm the service enforces ownership.
        return await service.flush_memory(
            session_id=request.session_id,
            app_id=request.app_id,
            project_id=request.project_id,
        )

    @router.patch("/memories/{memory_id}")
    async def patch_memory(
        memory_id: str,
        request: MemoryOverrideRequest,
    ) -> dict[str, Any]:
        """Store override text for a memory; HTTP 403 if the session isn't owned."""
        require_user(request.user_id, request.user_key)
        try:
            service.assert_memory_session_owned(request.user_id, request.session_id)
        except PermissionError as exc:
            raise HTTPException(status_code=403, detail=str(exc)) from exc
        return service.upsert_override(
            user_id=request.user_id,
            memory_id=memory_id,
            session_id=request.session_id,
            override_text=request.override_text,
        )

    @router.delete("/memories/{memory_id}")
    async def delete_memory(
        memory_id: str,
        request: MemoryDeleteRequest,
    ) -> dict[str, Any]:
        """Delete a memory; HTTP 403 if the session isn't owned by the user."""
        require_user(request.user_id, request.user_key)
        try:
            service.assert_memory_session_owned(request.user_id, request.session_id)
        except PermissionError as exc:
            raise HTTPException(status_code=403, detail=str(exc)) from exc
        return service.delete_memory(
            user_id=request.user_id,
            memory_id=memory_id,
            session_id=request.session_id,
            reason=request.reason,
        )

    app.include_router(router)
    return app