437 lines
14 KiB
Python
437 lines
14 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import time
|
|
from datetime import datetime, timezone
|
|
from typing import Any, Literal
|
|
from urllib.parse import parse_qsl, quote, urlsplit, urlunsplit
|
|
|
|
from fastapi import APIRouter, FastAPI, File, Form, HTTPException, Request, UploadFile
|
|
from pydantic import BaseModel, Field
|
|
from starlette.responses import Response
|
|
|
|
from .config import GatewayConfig
|
|
from .db import init_db
|
|
from .everos_client import EverOSClient
|
|
from .repository import MemoryRepository
|
|
from .service import MemoryGatewayService, UnsupportedContentType, UploadTooLarge
|
|
|
|
|
|
# Logger that receives one JSON line per handled HTTP request.
API_LOGGER = logging.getLogger("memory_gateway.api")

# Placeholder written to the log wherever a sensitive value was removed.
REDACTED = "[REDACTED]"

# Bodies above this size (in bytes) are summarised instead of being logged.
MAX_LOG_BODY_BYTES = 4096

# Lower-case field names whose values are always replaced by REDACTED.
SENSITIVE_FIELD_NAMES = {
    "api_key", "authorization", "password",
    "secret", "token", "user_key",
}
|
|
|
|
|
|
class SearchMemoriesRequest(BaseModel):
    # JSON body of POST /memories/search.

    # Caller credentials; the handler validates them with require_user().
    user_id: str = Field(min_length=1)
    user_key: str = Field(min_length=1)

    # Optional conversation identifier, passed through to the service as-is.
    conversation_id: str | None = None

    # Search text; must not be empty.
    query: str = Field(min_length=1)

    # Accepted scope names; a fresh ["current_chat", "resources"] list is the
    # default for every request.
    scope: list[Literal["current_chat", "resources", "all_user_memory"]] = Field(
        default_factory=lambda: ["current_chat", "resources"],
    )

    # Requested number of results, limited to 1..100 (default 8).
    top_k: int = Field(default=8, ge=1, le=100)

    # Application / project identifiers, each falling back to "default".
    app_id: str = "default"
    project_id: str = "default"
|
|
|
|
|
|
class AddMemoryMessage(BaseModel):
    # One entry of AddMemoryRequest.messages.

    # Non-empty identifier of whoever produced the message.
    sender_id: str = Field(min_length=1)

    # Only these three role names are accepted.
    role: Literal["user", "assistant", "tool"]

    # Strictly positive integer; the unit is not defined in this module.
    timestamp: int = Field(gt=0)

    # Either plain text or a list of free-form dictionaries.
    content: str | list[dict[str, Any]]
|
|
|
|
|
|
class AddMemoryRequest(BaseModel):
    # JSON body of POST /memories/add.

    # Caller credentials; the handler validates them with require_user().
    user_id: str = Field(min_length=1)
    user_key: str = Field(min_length=1)

    # Non-empty session identifier forwarded to the service.
    session_id: str = Field(min_length=1)

    # At least one message is required; the handler forwards each one as a
    # plain dict produced by model_dump().
    messages: list[AddMemoryMessage] = Field(min_length=1)

    # Application / project identifiers, each falling back to "default".
    app_id: str = "default"
    project_id: str = "default"
|
|
|
|
|
|
class FlushMemoryRequest(BaseModel):
    # JSON body of POST /memories/flush.

    # Caller credentials; the handler validates them with require_user().
    user_id: str = Field(min_length=1)
    user_key: str = Field(min_length=1)

    # Non-empty session identifier forwarded to the service.
    session_id: str = Field(min_length=1)

    # Application / project identifiers, each falling back to "default".
    app_id: str = "default"
    project_id: str = "default"
|
|
|
|
|
|
class MemoryOverrideRequest(BaseModel):
    # JSON body of PATCH /memories/{memory_id}.

    # Caller credentials; the handler validates them with require_user().
    user_id: str = Field(min_length=1)
    user_key: str = Field(min_length=1)

    # Session the handler checks for ownership before applying the override.
    session_id: str = Field(min_length=1)

    # Non-empty replacement text stored for the memory.
    override_text: str = Field(min_length=1)
|
|
|
|
|
|
class MemoryDeleteRequest(BaseModel):
    # JSON body of DELETE /memories/{memory_id}.

    # Caller credentials; the handler validates them with require_user().
    user_id: str = Field(min_length=1)
    user_key: str = Field(min_length=1)

    # Session the handler checks for ownership before deleting.
    session_id: str = Field(min_length=1)

    # Optional free-text reason, forwarded unchanged to the service.
    reason: str | None = None
|
|
|
|
|
|
class UserCreateRequest(BaseModel):
    # JSON body of POST /users; no credentials are required for this call.

    # Non-empty identifier of the user to create.
    user_id: str = Field(min_length=1)
|
|
|
|
|
|
def _is_sensitive_field(key: str) -> bool:
    """Tell whether the value stored under *key* has to be hidden in logs.

    The check is case-insensitive: a name matches when it appears in
    ``SENSITIVE_FIELD_NAMES`` or when it ends with ``_token``.
    """
    name = key.lower()
    if name in SENSITIVE_FIELD_NAMES:
        return True
    return name.endswith("_token")
|
|
|
|
|
|
def _redact(value: Any, key: str | None = None) -> Any:
    """Return *value* with every sensitive entry replaced by ``REDACTED``.

    Args:
        value: Decoded payload (scalar, list or dict, arbitrarily nested).
        key: Name under which *value* is stored in its parent dict, if any.

    Returns:
        ``REDACTED`` when *key* is sensitive; otherwise a rebuilt list or dict
        whose members were processed recursively, or *value* itself for any
        other type. List members are processed without a key.
    """
    if key is not None and _is_sensitive_field(key):
        return REDACTED
    if isinstance(value, list):
        return [_redact(element) for element in value]
    if isinstance(value, dict):
        cleaned = {}
        for name, nested in value.items():
            cleaned[name] = _redact(nested, name)
        return cleaned
    return value
|
|
|
|
|
|
def _redacted_query_params(items: list[tuple[str, str]]) -> dict[str, Any]:
    """Fold ``(name, value)`` pairs into a dict that is safe to log.

    Values of sensitive names are replaced by ``REDACTED``. A name seen once
    maps to a single string; a repeated name maps to a list holding its values
    in the order they appeared.
    """
    collected: dict[str, Any] = {}
    for name, raw in items:
        shown = REDACTED if _is_sensitive_field(name) else raw
        if name not in collected:
            collected[name] = shown
            continue
        previous = collected[name]
        if isinstance(previous, list):
            previous.append(shown)
        else:
            # Second occurrence: promote the stored scalar to a list.
            collected[name] = [previous, shown]
    return collected
|
|
|
|
|
|
def _redacted_url(url: str) -> str:
    """Return *url* with the values of sensitive query parameters masked.

    The query string is parsed (blank values are kept) and re-encoded pair by
    pair; scheme, host, path and fragment are copied unchanged. Square
    brackets are left unescaped in values so the ``REDACTED`` marker stays
    readable.
    """
    parts = urlsplit(url)
    encoded_pairs = []
    for name, raw in parse_qsl(parts.query, keep_blank_values=True):
        shown = REDACTED if _is_sensitive_field(name) else raw
        encoded_pairs.append(f"{quote(name, safe='')}={quote(shown, safe='[]')}")
    rebuilt_query = "&".join(encoded_pairs)
    return urlunsplit(
        (parts.scheme, parts.netloc, parts.path, rebuilt_query, parts.fragment)
    )
|
|
|
|
|
|
def _should_capture_request_body(
    content_type: str | None,
    content_length: int | None,
) -> bool:
    """Decide whether the request body may be read for logging.

    Multipart bodies are never captured. Any other body is captured when its
    declared length is unknown or does not exceed ``MAX_LOG_BODY_BYTES``.
    """
    if (content_type or "").lower().startswith("multipart/"):
        return False
    if content_length is None:
        return True
    return content_length <= MAX_LOG_BODY_BYTES
|
|
|
|
|
|
def _uncaptured_body_for_log(
    content_type: str | None,
    content_length: int | None,
) -> dict[str, Any]:
    """Describe a request body that was deliberately not read.

    Returns:
        A dict with the lower-cased content type and the declared size. For
        anything other than multipart content it also carries
        ``"truncated": True``.
    """
    media_type = (content_type or "").lower()
    summary: dict[str, Any] = {"content_type": media_type, "size_bytes": content_length}
    if media_type.startswith("multipart/"):
        return summary
    summary["truncated"] = True
    return summary
|
|
|
|
|
|
def _body_for_log(body: bytes, content_type: str | None) -> Any:
    """Convert a request or response body into a value that is safe to log.

    Args:
        body: Raw payload bytes.
        content_type: Value of the ``Content-Type`` header, if present.

    Returns:
        * ``None`` for an empty body.
        * A ``{"content_type", "size_bytes"}`` summary for multipart content
          and for content types that are not handled below.
        * A summary flagged ``"truncated": True`` when the body is larger
          than ``MAX_LOG_BODY_BYTES``.
        * The decoded and redacted document for JSON.
        * A summary flagged ``"unparsed_json": True`` when a JSON body cannot
          be parsed or is nested too deeply to process.
        * The redacted field mapping for URL-encoded forms.
        * The decoded text for ``text/*`` content.
    """
    if not body:
        return None
    content_type = (content_type or "").lower()
    summary = {"content_type": content_type, "size_bytes": len(body)}
    if content_type.startswith("multipart/"):
        return summary
    if len(body) > MAX_LOG_BODY_BYTES:
        return {
            "truncated": True,
            "size_bytes": len(body),
            "content_type": content_type,
        }
    text = body.decode("utf-8", errors="replace")
    if "application/json" in content_type:
        try:
            return _redact(json.loads(text))
        except (json.JSONDecodeError, RecursionError):
            # Redaction works on parsed keys. Logging the raw text of a
            # malformed payload would therefore write any credentials it
            # contains straight into the log, so only its size is recorded.
            return {**summary, "unparsed_json": True}
    if "application/x-www-form-urlencoded" in content_type:
        return _redacted_query_params(parse_qsl(text, keep_blank_values=True))
    if content_type.startswith("text/"):
        return text
    return summary
|
|
|
|
|
|
def create_app(
    *,
    config: GatewayConfig | None = None,
    everos_client: Any | None = None,
) -> FastAPI:
    """Build and wire the memory-gateway FastAPI application.

    Args:
        config: Gateway settings; loaded with ``GatewayConfig.from_env()``
            when omitted.
        everos_client: Ready-made upstream client. When omitted, an
            ``EverOSClient`` is created from ``cfg.everos_base_url`` and
            ``cfg.everos_timeout_seconds``.

    Returns:
        The configured application. The config, repository, client and
        service are also exposed on ``app.state``.
    """
    cfg = config or GatewayConfig.from_env()
    init_db(cfg.database_path)
    repository = MemoryRepository(cfg.database_path)
    client = everos_client or EverOSClient(
        cfg.everos_base_url,
        timeout=cfg.everos_timeout_seconds,
    )
    service = MemoryGatewayService(cfg, repository, client)

    app = FastAPI(title="memory-gateway", version="0.1.0")
    app.state.config = cfg
    app.state.repository = repository
    app.state.everos_client = client
    app.state.gateway_service = service

    router = APIRouter()

    @app.middleware("http")
    async def log_api_request(request: Request, call_next: Any) -> Response:
        """Emit one JSON log line per request with redacted input and output.

        The downstream response is drained so its body can be logged and is
        then re-sent to the client as a plain ``Response``. Exceptions raised
        by the downstream application are logged and re-raised unchanged.
        """
        request_time = datetime.now(timezone.utc).isoformat()
        started = time.perf_counter()
        request_content_type = request.headers.get("content-type")
        raw_content_length = request.headers.get("content-length")
        try:
            request_content_length = (
                int(raw_content_length) if raw_content_length is not None else None
            )
        except ValueError:
            # A malformed header is treated the same as a missing one.
            request_content_length = None
        capture_request_body = _should_capture_request_body(
            request_content_type,
            request_content_length,
        )
        request_body = await request.body() if capture_request_body else None

        # Bound up front so the ``finally`` block can always inspect them,
        # including when ``call_next`` is interrupted by something that is
        # not an ``Exception`` (for example task cancellation).
        response: Response | None = None
        response_chunks: list[bytes] = []
        response_body = b""
        status_code = 500
        error: str | None = None

        try:
            response = await call_next(request)
            status_code = response.status_code
            async for chunk in response.body_iterator:
                response_chunks.append(chunk)
            # Joined once instead of growing a bytes object chunk by chunk.
            response_body = b"".join(response_chunks)
            # NOTE(review): dict(...) keeps a single value per header name,
            # so repeated headers such as Set-Cookie would be collapsed here.
            logged_response = Response(
                content=response_body,
                status_code=response.status_code,
                headers=dict(response.headers),
                media_type=response.media_type,
                background=response.background,
            )
        except Exception as exc:
            error = str(exc)
            raise
        finally:
            # Writing the log entry must never replace the real outcome of
            # the request, so a failure here is reported and then dropped.
            try:
                duration_ms = round((time.perf_counter() - started) * 1000, 3)
                query_items = list(request.query_params.multi_items())
                output: dict[str, Any] = {"status_code": status_code}
                if error is not None:
                    output["error"] = error
                elif response is not None:
                    output["body"] = _body_for_log(
                        # Falls back to the partial chunks when the stream
                        # was interrupted before the join above ran.
                        response_body or b"".join(response_chunks),
                        response.headers.get("content-type"),
                    )
                else:
                    output["error"] = "request aborted before a response was produced"
                event = {
                    "request_time": request_time,
                    "duration_ms": duration_ms,
                    "method": request.method,
                    "path": request.url.path,
                    "url": _redacted_url(str(request.url)),
                    "client": request.client.host if request.client else None,
                    "input": {
                        "query_params": _redacted_query_params(query_items),
                        "body": (
                            _body_for_log(request_body, request_content_type)
                            if request_body is not None
                            else _uncaptured_body_for_log(
                                request_content_type,
                                request_content_length,
                            )
                        ),
                    },
                    "output": output,
                }
                API_LOGGER.info(json.dumps(event, ensure_ascii=False, default=str))
            except Exception:
                API_LOGGER.exception("failed to write API request log entry")

        return logged_response

    def require_user(user_id: str, user_key: str) -> None:
        """Raise HTTP 401 unless the service accepts the given credentials."""
        if not service.authenticate_user(user_id, user_key):
            raise HTTPException(status_code=401, detail="invalid user credentials")

    # Route handlers below are documented with comments rather than
    # docstrings, because FastAPI publishes handler docstrings in the
    # generated OpenAPI description.

    @router.get("/health")
    async def health() -> dict[str, Any]:
        # Reports "degraded" instead of failing when the upstream health
        # check raises.
        try:
            everos_health = await client.health_check()
        except Exception as exc:
            return {
                "status": "degraded",
                "api": {"status": "ok"},
                "everos": {
                    "status": "unavailable",
                    "base_url": cfg.everos_base_url,
                    "error": str(exc),
                },
            }
        return {
            "status": "ok",
            "api": {"status": "ok"},
            "everos": {
                "status": "ok",
                "base_url": cfg.everos_base_url,
                "data": everos_health,
            },
        }

    @router.post("/users")
    async def create_user(request: UserCreateRequest) -> dict[str, Any]:
        # The only endpoint that does not call require_user().
        return service.create_user(request.user_id)

    @router.post("/resources")
    async def upload_resource(
        user_id: str = Form(...),
        user_key: str = Form(...),
        app_id: str = Form("default"),
        project_id: str = Form("default"),
        title: str | None = Form(None),
        description: str | None = Form(None),
        file: UploadFile = File(...),
    ) -> dict[str, Any]:
        # Multipart upload; service-level size and type errors are mapped to
        # HTTP 413 and 415 respectively.
        require_user(user_id, user_key)
        try:
            return await service.upload_resource(
                user_id=user_id,
                app_id=app_id,
                project_id=project_id,
                file=file,
                title=title,
                description=description,
            )
        except UploadTooLarge as exc:
            raise HTTPException(status_code=413, detail=str(exc)) from exc
        except UnsupportedContentType as exc:
            raise HTTPException(status_code=415, detail=str(exc)) from exc

    @router.get("/resources")
    async def list_resources(
        user_id: str,
        user_key: str,
    ) -> dict[str, Any]:
        require_user(user_id, user_key)
        return {"resources": service.list_resources(user_id)}

    @router.get("/resources/{resource_id}")
    async def get_resource(
        resource_id: str,
        user_id: str,
        user_key: str,
    ) -> dict[str, Any]:
        require_user(user_id, user_key)
        resource = service.get_resource_detail(resource_id, user_id)
        # NOTE(review): an unknown id yields an empty list here, whereas
        # DELETE on the same path answers 404 - confirm this is intended.
        if resource is None:
            return {"resources": []}
        return {"resources": [resource]}

    @router.delete("/resources/{resource_id}")
    async def delete_resource(
        resource_id: str,
        user_id: str,
        user_key: str,
    ) -> dict[str, Any]:
        require_user(user_id, user_key)
        resource = service.delete_resource(resource_id, user_id)
        if resource is None:
            raise HTTPException(status_code=404, detail="resource not found")
        return resource

    @router.post("/memories/search")
    async def search_memories(
        request: SearchMemoriesRequest,
    ) -> dict[str, Any]:
        require_user(request.user_id, request.user_key)
        return await service.search_memories(
            user_id=request.user_id,
            query=request.query,
            conversation_id=request.conversation_id,
            scope=request.scope,
            top_k=request.top_k,
            app_id=request.app_id,
            project_id=request.project_id,
        )

    @router.post("/memories/add")
    async def add_memory(
        request: AddMemoryRequest,
    ) -> dict[str, Any]:
        require_user(request.user_id, request.user_key)
        return await service.add_memory(
            session_id=request.session_id,
            app_id=request.app_id,
            project_id=request.project_id,
            messages=[message.model_dump() for message in request.messages],
        )

    @router.post("/memories/flush")
    async def flush_memory(
        request: FlushMemoryRequest,
    ) -> dict[str, Any]:
        require_user(request.user_id, request.user_key)
        return await service.flush_memory(
            session_id=request.session_id,
            app_id=request.app_id,
            project_id=request.project_id,
        )

    @router.patch("/memories/{memory_id}")
    async def patch_memory(
        memory_id: str,
        request: MemoryOverrideRequest,
    ) -> dict[str, Any]:
        require_user(request.user_id, request.user_key)
        # A PermissionError from the ownership check becomes HTTP 403.
        try:
            service.assert_memory_session_owned(request.user_id, request.session_id)
        except PermissionError as exc:
            raise HTTPException(status_code=403, detail=str(exc)) from exc
        return service.upsert_override(
            user_id=request.user_id,
            memory_id=memory_id,
            session_id=request.session_id,
            override_text=request.override_text,
        )

    @router.delete("/memories/{memory_id}")
    async def delete_memory(
        memory_id: str,
        request: MemoryDeleteRequest,
    ) -> dict[str, Any]:
        require_user(request.user_id, request.user_key)
        # A PermissionError from the ownership check becomes HTTP 403.
        try:
            service.assert_memory_session_owned(request.user_id, request.session_id)
        except PermissionError as exc:
            raise HTTPException(status_code=403, detail=str(exc)) from exc
        return service.delete_memory(
            user_id=request.user_id,
            memory_id=memory_id,
            session_id=request.session_id,
            reason=request.reason,
        )

    app.include_router(router)
    return app
|