add multimodal memory proxy and API logging

This commit is contained in:
2026-06-12 11:04:53 +08:00
parent 8afb460883
commit a29009dc07
12 changed files with 2229 additions and 33 deletions

View File

@ -1,9 +1,15 @@
from __future__ import annotations
import json
import logging
import time
from datetime import datetime, timezone
from typing import Any, Literal
from urllib.parse import parse_qsl, quote, urlsplit, urlunsplit
from fastapi import APIRouter, FastAPI, File, Form, HTTPException, UploadFile
from fastapi import APIRouter, FastAPI, File, Form, HTTPException, Request, UploadFile
from pydantic import BaseModel, Field
from starlette.responses import Response
from .config import GatewayConfig
from .db import init_db
@ -12,6 +18,19 @@ from .repository import MemoryRepository
from .service import MemoryGatewayService, UnsupportedContentType, UploadTooLarge
API_LOGGER = logging.getLogger("memory_gateway.api")
MAX_LOG_BODY_BYTES = 4096
REDACTED = "[REDACTED]"
SENSITIVE_FIELD_NAMES = {
"api_key",
"authorization",
"password",
"secret",
"token",
"user_key",
}
class SearchMemoriesRequest(BaseModel):
user_id: str = Field(min_length=1)
user_key: str = Field(min_length=1)
@ -25,6 +44,30 @@ class SearchMemoriesRequest(BaseModel):
project_id: str = "default"
class AddMemoryMessage(BaseModel):
sender_id: str = Field(min_length=1)
role: Literal["user", "assistant", "tool"]
timestamp: int = Field(gt=0)
content: str | list[dict[str, Any]]
class AddMemoryRequest(BaseModel):
user_id: str = Field(min_length=1)
user_key: str = Field(min_length=1)
session_id: str = Field(min_length=1)
messages: list[AddMemoryMessage] = Field(min_length=1)
app_id: str = "default"
project_id: str = "default"
class FlushMemoryRequest(BaseModel):
user_id: str = Field(min_length=1)
user_key: str = Field(min_length=1)
session_id: str = Field(min_length=1)
app_id: str = "default"
project_id: str = "default"
class MemoryOverrideRequest(BaseModel):
user_id: str = Field(min_length=1)
user_key: str = Field(min_length=1)
@ -43,6 +86,98 @@ class UserCreateRequest(BaseModel):
user_id: str = Field(min_length=1)
def _is_sensitive_field(key: str) -> bool:
lowered = key.lower()
return lowered in SENSITIVE_FIELD_NAMES or lowered.endswith("_token")
def _redact(value: Any, key: str | None = None) -> Any:
if key is not None and _is_sensitive_field(key):
return REDACTED
if isinstance(value, dict):
return {
item_key: _redact(item_value, item_key)
for item_key, item_value in value.items()
}
if isinstance(value, list):
return [_redact(item) for item in value]
return value
def _redacted_query_params(items: list[tuple[str, str]]) -> dict[str, Any]:
result: dict[str, Any] = {}
for key, value in items:
safe_value = REDACTED if _is_sensitive_field(key) else value
if key in result:
existing = result[key]
if isinstance(existing, list):
existing.append(safe_value)
else:
result[key] = [existing, safe_value]
else:
result[key] = safe_value
return result
def _redacted_url(url: str) -> str:
parts = urlsplit(url)
query = "&".join(
f"{quote(key, safe='')}="
f"{quote(REDACTED if _is_sensitive_field(key) else value, safe='[]')}"
for key, value in parse_qsl(parts.query, keep_blank_values=True)
)
return urlunsplit((parts.scheme, parts.netloc, parts.path, query, parts.fragment))
def _should_capture_request_body(
content_type: str | None,
content_length: int | None,
) -> bool:
normalized = (content_type or "").lower()
if normalized.startswith("multipart/"):
return False
return content_length is None or content_length <= MAX_LOG_BODY_BYTES
def _uncaptured_body_for_log(
content_type: str | None,
content_length: int | None,
) -> dict[str, Any]:
normalized = (content_type or "").lower()
result: dict[str, Any] = {
"content_type": normalized,
"size_bytes": content_length,
}
if not normalized.startswith("multipart/"):
result["truncated"] = True
return result
def _body_for_log(body: bytes, content_type: str | None) -> Any:
if not body:
return None
content_type = (content_type or "").lower()
if content_type.startswith("multipart/"):
return {"content_type": content_type, "size_bytes": len(body)}
if len(body) > MAX_LOG_BODY_BYTES:
return {
"truncated": True,
"size_bytes": len(body),
"content_type": content_type,
}
text = body.decode("utf-8", errors="replace")
if "application/json" in content_type:
try:
return _redact(json.loads(text))
except json.JSONDecodeError:
return text
if "application/x-www-form-urlencoded" in content_type:
return _redacted_query_params(parse_qsl(text, keep_blank_values=True))
if content_type.startswith("text/"):
return text
return {"content_type": content_type, "size_bytes": len(body)}
def create_app(
*,
config: GatewayConfig | None = None,
@ -57,7 +192,7 @@ def create_app(
)
service = MemoryGatewayService(cfg, repository, client)
app = FastAPI(title="memory-gateway2", version="0.1.0")
app = FastAPI(title="memory-gateway", version="0.1.0")
app.state.config = cfg
app.state.repository = repository
app.state.everos_client = client
@ -65,6 +200,77 @@ def create_app(
router = APIRouter()
@app.middleware("http")
async def log_api_request(request: Request, call_next: Any) -> Response:
request_time = datetime.now(timezone.utc).isoformat()
started = time.perf_counter()
request_content_type = request.headers.get("content-type")
raw_content_length = request.headers.get("content-length")
try:
request_content_length = (
int(raw_content_length) if raw_content_length is not None else None
)
except ValueError:
request_content_length = None
capture_request_body = _should_capture_request_body(
request_content_type,
request_content_length,
)
request_body = await request.body() if capture_request_body else None
response_body = b""
status_code = 500
error: str | None = None
try:
response = await call_next(request)
status_code = response.status_code
async for chunk in response.body_iterator:
response_body += chunk
logged_response = Response(
content=response_body,
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.media_type,
background=response.background,
)
except Exception as exc:
error = str(exc)
raise
finally:
duration_ms = round((time.perf_counter() - started) * 1000, 3)
query_items = list(request.query_params.multi_items())
output: dict[str, Any] = {"status_code": status_code}
if error is not None:
output["error"] = error
else:
output["body"] = _body_for_log(
response_body,
response.headers.get("content-type"),
)
event = {
"request_time": request_time,
"duration_ms": duration_ms,
"method": request.method,
"path": request.url.path,
"url": _redacted_url(str(request.url)),
"client": request.client.host if request.client else None,
"input": {
"query_params": _redacted_query_params(query_items),
"body": (
_body_for_log(request_body, request_content_type)
if request_body is not None
else _uncaptured_body_for_log(
request_content_type,
request_content_length,
)
),
},
"output": output,
}
API_LOGGER.info(json.dumps(event, ensure_ascii=False, default=str))
return logged_response
def require_user(user_id: str, user_key: str) -> None:
if not service.authenticate_user(user_id, user_key):
raise HTTPException(status_code=401, detail="invalid user credentials")
@ -169,6 +375,29 @@ def create_app(
project_id=request.project_id,
)
@router.post("/memories/add")
async def add_memory(
request: AddMemoryRequest,
) -> dict[str, Any]:
require_user(request.user_id, request.user_key)
return await service.add_memory(
session_id=request.session_id,
app_id=request.app_id,
project_id=request.project_id,
messages=[message.model_dump() for message in request.messages],
)
@router.post("/memories/flush")
async def flush_memory(
request: FlushMemoryRequest,
) -> dict[str, Any]:
require_user(request.user_id, request.user_key)
return await service.flush_memory(
session_id=request.session_id,
app_id=request.app_id,
project_id=request.project_id,
)
@router.patch("/memories/{memory_id}")
async def patch_memory(
memory_id: str,