add multimodal memory proxy and API logging
This commit is contained in:
233
core/api.py
233
core/api.py
@ -1,9 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal
|
||||
from urllib.parse import parse_qsl, quote, urlsplit, urlunsplit
|
||||
|
||||
from fastapi import APIRouter, FastAPI, File, Form, HTTPException, UploadFile
|
||||
from fastapi import APIRouter, FastAPI, File, Form, HTTPException, Request, UploadFile
|
||||
from pydantic import BaseModel, Field
|
||||
from starlette.responses import Response
|
||||
|
||||
from .config import GatewayConfig
|
||||
from .db import init_db
|
||||
@ -12,6 +18,19 @@ from .repository import MemoryRepository
|
||||
from .service import MemoryGatewayService, UnsupportedContentType, UploadTooLarge
|
||||
|
||||
|
||||
API_LOGGER = logging.getLogger("memory_gateway.api")
|
||||
MAX_LOG_BODY_BYTES = 4096
|
||||
REDACTED = "[REDACTED]"
|
||||
SENSITIVE_FIELD_NAMES = {
|
||||
"api_key",
|
||||
"authorization",
|
||||
"password",
|
||||
"secret",
|
||||
"token",
|
||||
"user_key",
|
||||
}
|
||||
|
||||
|
||||
class SearchMemoriesRequest(BaseModel):
|
||||
user_id: str = Field(min_length=1)
|
||||
user_key: str = Field(min_length=1)
|
||||
@ -25,6 +44,30 @@ class SearchMemoriesRequest(BaseModel):
|
||||
project_id: str = "default"
|
||||
|
||||
|
||||
class AddMemoryMessage(BaseModel):
|
||||
sender_id: str = Field(min_length=1)
|
||||
role: Literal["user", "assistant", "tool"]
|
||||
timestamp: int = Field(gt=0)
|
||||
content: str | list[dict[str, Any]]
|
||||
|
||||
|
||||
class AddMemoryRequest(BaseModel):
|
||||
user_id: str = Field(min_length=1)
|
||||
user_key: str = Field(min_length=1)
|
||||
session_id: str = Field(min_length=1)
|
||||
messages: list[AddMemoryMessage] = Field(min_length=1)
|
||||
app_id: str = "default"
|
||||
project_id: str = "default"
|
||||
|
||||
|
||||
class FlushMemoryRequest(BaseModel):
|
||||
user_id: str = Field(min_length=1)
|
||||
user_key: str = Field(min_length=1)
|
||||
session_id: str = Field(min_length=1)
|
||||
app_id: str = "default"
|
||||
project_id: str = "default"
|
||||
|
||||
|
||||
class MemoryOverrideRequest(BaseModel):
|
||||
user_id: str = Field(min_length=1)
|
||||
user_key: str = Field(min_length=1)
|
||||
@ -43,6 +86,98 @@ class UserCreateRequest(BaseModel):
|
||||
user_id: str = Field(min_length=1)
|
||||
|
||||
|
||||
def _is_sensitive_field(key: str) -> bool:
|
||||
lowered = key.lower()
|
||||
return lowered in SENSITIVE_FIELD_NAMES or lowered.endswith("_token")
|
||||
|
||||
|
||||
def _redact(value: Any, key: str | None = None) -> Any:
|
||||
if key is not None and _is_sensitive_field(key):
|
||||
return REDACTED
|
||||
if isinstance(value, dict):
|
||||
return {
|
||||
item_key: _redact(item_value, item_key)
|
||||
for item_key, item_value in value.items()
|
||||
}
|
||||
if isinstance(value, list):
|
||||
return [_redact(item) for item in value]
|
||||
return value
|
||||
|
||||
|
||||
def _redacted_query_params(items: list[tuple[str, str]]) -> dict[str, Any]:
|
||||
result: dict[str, Any] = {}
|
||||
for key, value in items:
|
||||
safe_value = REDACTED if _is_sensitive_field(key) else value
|
||||
if key in result:
|
||||
existing = result[key]
|
||||
if isinstance(existing, list):
|
||||
existing.append(safe_value)
|
||||
else:
|
||||
result[key] = [existing, safe_value]
|
||||
else:
|
||||
result[key] = safe_value
|
||||
return result
|
||||
|
||||
|
||||
def _redacted_url(url: str) -> str:
|
||||
parts = urlsplit(url)
|
||||
query = "&".join(
|
||||
f"{quote(key, safe='')}="
|
||||
f"{quote(REDACTED if _is_sensitive_field(key) else value, safe='[]')}"
|
||||
for key, value in parse_qsl(parts.query, keep_blank_values=True)
|
||||
)
|
||||
return urlunsplit((parts.scheme, parts.netloc, parts.path, query, parts.fragment))
|
||||
|
||||
|
||||
def _should_capture_request_body(
|
||||
content_type: str | None,
|
||||
content_length: int | None,
|
||||
) -> bool:
|
||||
normalized = (content_type or "").lower()
|
||||
if normalized.startswith("multipart/"):
|
||||
return False
|
||||
return content_length is None or content_length <= MAX_LOG_BODY_BYTES
|
||||
|
||||
|
||||
def _uncaptured_body_for_log(
|
||||
content_type: str | None,
|
||||
content_length: int | None,
|
||||
) -> dict[str, Any]:
|
||||
normalized = (content_type or "").lower()
|
||||
result: dict[str, Any] = {
|
||||
"content_type": normalized,
|
||||
"size_bytes": content_length,
|
||||
}
|
||||
if not normalized.startswith("multipart/"):
|
||||
result["truncated"] = True
|
||||
return result
|
||||
|
||||
|
||||
def _body_for_log(body: bytes, content_type: str | None) -> Any:
|
||||
if not body:
|
||||
return None
|
||||
content_type = (content_type or "").lower()
|
||||
if content_type.startswith("multipart/"):
|
||||
return {"content_type": content_type, "size_bytes": len(body)}
|
||||
if len(body) > MAX_LOG_BODY_BYTES:
|
||||
return {
|
||||
"truncated": True,
|
||||
"size_bytes": len(body),
|
||||
"content_type": content_type,
|
||||
}
|
||||
text = body.decode("utf-8", errors="replace")
|
||||
if "application/json" in content_type:
|
||||
try:
|
||||
return _redact(json.loads(text))
|
||||
except json.JSONDecodeError:
|
||||
return text
|
||||
if "application/x-www-form-urlencoded" in content_type:
|
||||
return _redacted_query_params(parse_qsl(text, keep_blank_values=True))
|
||||
if content_type.startswith("text/"):
|
||||
return text
|
||||
return {"content_type": content_type, "size_bytes": len(body)}
|
||||
|
||||
|
||||
def create_app(
|
||||
*,
|
||||
config: GatewayConfig | None = None,
|
||||
@ -57,7 +192,7 @@ def create_app(
|
||||
)
|
||||
service = MemoryGatewayService(cfg, repository, client)
|
||||
|
||||
app = FastAPI(title="memory-gateway2", version="0.1.0")
|
||||
app = FastAPI(title="memory-gateway", version="0.1.0")
|
||||
app.state.config = cfg
|
||||
app.state.repository = repository
|
||||
app.state.everos_client = client
|
||||
@ -65,6 +200,77 @@ def create_app(
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@app.middleware("http")
|
||||
async def log_api_request(request: Request, call_next: Any) -> Response:
|
||||
request_time = datetime.now(timezone.utc).isoformat()
|
||||
started = time.perf_counter()
|
||||
request_content_type = request.headers.get("content-type")
|
||||
raw_content_length = request.headers.get("content-length")
|
||||
try:
|
||||
request_content_length = (
|
||||
int(raw_content_length) if raw_content_length is not None else None
|
||||
)
|
||||
except ValueError:
|
||||
request_content_length = None
|
||||
capture_request_body = _should_capture_request_body(
|
||||
request_content_type,
|
||||
request_content_length,
|
||||
)
|
||||
request_body = await request.body() if capture_request_body else None
|
||||
response_body = b""
|
||||
status_code = 500
|
||||
error: str | None = None
|
||||
|
||||
try:
|
||||
response = await call_next(request)
|
||||
status_code = response.status_code
|
||||
async for chunk in response.body_iterator:
|
||||
response_body += chunk
|
||||
logged_response = Response(
|
||||
content=response_body,
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
media_type=response.media_type,
|
||||
background=response.background,
|
||||
)
|
||||
except Exception as exc:
|
||||
error = str(exc)
|
||||
raise
|
||||
finally:
|
||||
duration_ms = round((time.perf_counter() - started) * 1000, 3)
|
||||
query_items = list(request.query_params.multi_items())
|
||||
output: dict[str, Any] = {"status_code": status_code}
|
||||
if error is not None:
|
||||
output["error"] = error
|
||||
else:
|
||||
output["body"] = _body_for_log(
|
||||
response_body,
|
||||
response.headers.get("content-type"),
|
||||
)
|
||||
event = {
|
||||
"request_time": request_time,
|
||||
"duration_ms": duration_ms,
|
||||
"method": request.method,
|
||||
"path": request.url.path,
|
||||
"url": _redacted_url(str(request.url)),
|
||||
"client": request.client.host if request.client else None,
|
||||
"input": {
|
||||
"query_params": _redacted_query_params(query_items),
|
||||
"body": (
|
||||
_body_for_log(request_body, request_content_type)
|
||||
if request_body is not None
|
||||
else _uncaptured_body_for_log(
|
||||
request_content_type,
|
||||
request_content_length,
|
||||
)
|
||||
),
|
||||
},
|
||||
"output": output,
|
||||
}
|
||||
API_LOGGER.info(json.dumps(event, ensure_ascii=False, default=str))
|
||||
|
||||
return logged_response
|
||||
|
||||
def require_user(user_id: str, user_key: str) -> None:
|
||||
if not service.authenticate_user(user_id, user_key):
|
||||
raise HTTPException(status_code=401, detail="invalid user credentials")
|
||||
@ -169,6 +375,29 @@ def create_app(
|
||||
project_id=request.project_id,
|
||||
)
|
||||
|
||||
@router.post("/memories/add")
|
||||
async def add_memory(
|
||||
request: AddMemoryRequest,
|
||||
) -> dict[str, Any]:
|
||||
require_user(request.user_id, request.user_key)
|
||||
return await service.add_memory(
|
||||
session_id=request.session_id,
|
||||
app_id=request.app_id,
|
||||
project_id=request.project_id,
|
||||
messages=[message.model_dump() for message in request.messages],
|
||||
)
|
||||
|
||||
@router.post("/memories/flush")
|
||||
async def flush_memory(
|
||||
request: FlushMemoryRequest,
|
||||
) -> dict[str, Any]:
|
||||
require_user(request.user_id, request.user_key)
|
||||
return await service.flush_memory(
|
||||
session_id=request.session_id,
|
||||
app_id=request.app_id,
|
||||
project_id=request.project_id,
|
||||
)
|
||||
|
||||
@router.patch("/memories/{memory_id}")
|
||||
async def patch_memory(
|
||||
memory_id: str,
|
||||
|
||||
Reference in New Issue
Block a user