add multimodal memory proxy and API logging
This commit is contained in:
233
core/api.py
233
core/api.py
@ -1,9 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal
|
||||
from urllib.parse import parse_qsl, quote, urlsplit, urlunsplit
|
||||
|
||||
from fastapi import APIRouter, FastAPI, File, Form, HTTPException, UploadFile
|
||||
from fastapi import APIRouter, FastAPI, File, Form, HTTPException, Request, UploadFile
|
||||
from pydantic import BaseModel, Field
|
||||
from starlette.responses import Response
|
||||
|
||||
from .config import GatewayConfig
|
||||
from .db import init_db
|
||||
@ -12,6 +18,19 @@ from .repository import MemoryRepository
|
||||
from .service import MemoryGatewayService, UnsupportedContentType, UploadTooLarge
|
||||
|
||||
|
||||
# Logger that receives one JSON-encoded event per handled HTTP request.
API_LOGGER = logging.getLogger("memory_gateway.api")

# Bodies longer than this many bytes are summarised instead of logged verbatim.
MAX_LOG_BODY_BYTES = 4096

# Placeholder written to the log in place of a sensitive value.
REDACTED = "[REDACTED]"

# Lower-case field names whose values must never be written to the log.
# A frozenset keeps this module-level constant from being mutated by accident.
SENSITIVE_FIELD_NAMES = frozenset(
    {
        "api_key",
        "authorization",
        "password",
        "secret",
        "token",
        "user_key",
    }
)
|
||||
|
||||
|
||||
class SearchMemoriesRequest(BaseModel):
|
||||
user_id: str = Field(min_length=1)
|
||||
user_key: str = Field(min_length=1)
|
||||
@ -25,6 +44,30 @@ class SearchMemoriesRequest(BaseModel):
|
||||
project_id: str = "default"
|
||||
|
||||
|
||||
class AddMemoryMessage(BaseModel):
    """One message inside an add-memory request."""

    # Non-empty identifier of whoever produced the message.
    sender_id: str = Field(min_length=1)
    # Only these three role values pass validation.
    role: Literal["user", "assistant", "tool"]
    # Strictly positive integer; presumably epoch milliseconds -- TODO confirm.
    timestamp: int = Field(gt=0)
    # Either plain text or a list of structured content items.
    content: str | list[dict[str, Any]]
|
||||
|
||||
|
||||
class AddMemoryRequest(BaseModel):
    """Request body for adding messages to a session's memory."""

    # Credentials of the calling user; both must be non-empty.
    user_id: str = Field(min_length=1)
    user_key: str = Field(min_length=1)
    # Non-empty identifier of the session the messages belong to.
    session_id: str = Field(min_length=1)
    # At least one message is required.
    messages: list[AddMemoryMessage] = Field(min_length=1)
    # Both scoping identifiers fall back to "default" when omitted.
    app_id: str = "default"
    project_id: str = "default"
|
||||
|
||||
|
||||
class FlushMemoryRequest(BaseModel):
    """Request body for flushing a session's memory."""

    # Credentials of the calling user; both must be non-empty.
    user_id: str = Field(min_length=1)
    user_key: str = Field(min_length=1)
    # Non-empty identifier of the session to flush.
    session_id: str = Field(min_length=1)
    # Both scoping identifiers fall back to "default" when omitted.
    app_id: str = "default"
    project_id: str = "default"
|
||||
|
||||
|
||||
class MemoryOverrideRequest(BaseModel):
|
||||
user_id: str = Field(min_length=1)
|
||||
user_key: str = Field(min_length=1)
|
||||
@ -43,6 +86,98 @@ class UserCreateRequest(BaseModel):
|
||||
user_id: str = Field(min_length=1)
|
||||
|
||||
|
||||
def _is_sensitive_field(key: str) -> bool:
    """Return True when *key* names a field whose value must be redacted.

    Matching is case-insensitive and treats hyphens as underscores, so
    ``api-key`` and ``Access-Token`` are caught as well as ``api_key`` and
    ``access_token``. A name is sensitive when it is listed in
    ``SENSITIVE_FIELD_NAMES`` or ends in ``_token``.
    """
    # Normalise the separator so hyphenated spellings cannot bypass redaction.
    normalized = key.lower().replace("-", "_")
    return normalized in SENSITIVE_FIELD_NAMES or normalized.endswith("_token")
|
||||
|
||||
|
||||
def _redact(value: Any, key: str | None = None) -> Any:
    """Return a copy of *value* with sensitive fields replaced by ``REDACTED``.

    Args:
        value: Any decoded JSON-like value (dict, list or scalar).
        key: Name under which *value* was stored in its parent dict, if any.

    Returns:
        ``REDACTED`` when *key* is sensitive; otherwise dicts and lists are
        rebuilt recursively and every other value is returned unchanged.
    """
    if key is not None and _is_sensitive_field(key):
        return REDACTED
    if isinstance(value, list):
        # List elements carry no field name of their own.
        return [_redact(element) for element in value]
    if not isinstance(value, dict):
        return value
    cleaned = {}
    for name, nested in value.items():
        cleaned[name] = _redact(nested, name)
    return cleaned
|
||||
|
||||
|
||||
def _redacted_query_params(items: list[tuple[str, str]]) -> dict[str, Any]:
    """Collapse ``(name, value)`` pairs into a dict that is safe to log.

    Values of sensitive names are replaced by ``REDACTED``. A name that
    occurs once maps to its single value; a repeated name maps to the list
    of its values in order of appearance.
    """
    grouped: dict[str, list[Any]] = {}
    for name, raw in items:
        shown = REDACTED if _is_sensitive_field(name) else raw
        grouped.setdefault(name, []).append(shown)
    return {
        name: values[0] if len(values) == 1 else values
        for name, values in grouped.items()
    }
|
||||
|
||||
|
||||
def _redacted_url(url: str) -> str:
    """Return *url* with the values of sensitive query parameters redacted.

    The query string is parsed (blank values are kept), every name and value
    is percent-encoded again, and the URL is reassembled with its original
    scheme, netloc, path and fragment.
    """
    scheme, netloc, path, raw_query, fragment = urlsplit(url)
    encoded_pairs = []
    for name, raw in parse_qsl(raw_query, keep_blank_values=True):
        shown = REDACTED if _is_sensitive_field(name) else raw
        # "[]" stays unescaped so the REDACTED marker remains readable.
        encoded_pairs.append(f"{quote(name, safe='')}={quote(shown, safe='[]')}")
    return urlunsplit((scheme, netloc, path, "&".join(encoded_pairs), fragment))
|
||||
|
||||
|
||||
def _should_capture_request_body(
    content_type: str | None,
    content_length: int | None,
) -> bool:
    """Decide whether the request body should be read for logging.

    Multipart bodies are never captured. Any other body is captured when its
    declared length is unknown or does not exceed ``MAX_LOG_BODY_BYTES``.
    """
    is_multipart = (content_type or "").lower().startswith("multipart/")
    if is_multipart:
        return False
    if content_length is None:
        return True
    return content_length <= MAX_LOG_BODY_BYTES
|
||||
|
||||
|
||||
def _uncaptured_body_for_log(
|
||||
content_type: str | None,
|
||||
content_length: int | None,
|
||||
) -> dict[str, Any]:
|
||||
normalized = (content_type or "").lower()
|
||||
result: dict[str, Any] = {
|
||||
"content_type": normalized,
|
||||
"size_bytes": content_length,
|
||||
}
|
||||
if not normalized.startswith("multipart/"):
|
||||
result["truncated"] = True
|
||||
return result
|
||||
|
||||
|
||||
def _body_for_log(body: bytes, content_type: str | None) -> Any:
    """Convert a captured HTTP body into a value that is safe to log.

    Args:
        body: Raw body bytes.
        content_type: Value of the ``Content-Type`` header, if any.

    Returns:
        * ``None`` for an empty body.
        * A ``{"content_type", "size_bytes"}`` summary for multipart bodies
          and for content types that are neither JSON, form-encoded nor text.
        * The same summary plus ``"truncated": True`` when the body exceeds
          ``MAX_LOG_BODY_BYTES``.
        * The redacted decoded document for JSON, the redacted parameters
          for form-encoded bodies, and the decoded text for ``text/*``.
    """
    if not body:
        return None
    content_type = (content_type or "").lower()
    if content_type.startswith("multipart/"):
        return {"content_type": content_type, "size_bytes": len(body)}
    if len(body) > MAX_LOG_BODY_BYTES:
        return {
            "truncated": True,
            "size_bytes": len(body),
            "content_type": content_type,
        }
    text = body.decode("utf-8", errors="replace")
    if "application/json" in content_type:
        try:
            return _redact(json.loads(text))
        except (ValueError, RecursionError):
            # ValueError covers json.JSONDecodeError. RecursionError is raised
            # for deeply nested documents, either while parsing or while
            # redacting; a logging helper must not turn that into a failure
            # of the request itself.
            # NOTE(review): this fallback logs the raw text unredacted, so a
            # malformed JSON body containing a secret still reaches the log.
            return text
    if "application/x-www-form-urlencoded" in content_type:
        return _redacted_query_params(parse_qsl(text, keep_blank_values=True))
    if content_type.startswith("text/"):
        return text
    return {"content_type": content_type, "size_bytes": len(body)}
|
||||
|
||||
|
||||
def create_app(
|
||||
*,
|
||||
config: GatewayConfig | None = None,
|
||||
@ -57,7 +192,7 @@ def create_app(
|
||||
)
|
||||
service = MemoryGatewayService(cfg, repository, client)
|
||||
|
||||
app = FastAPI(title="memory-gateway2", version="0.1.0")
|
||||
app = FastAPI(title="memory-gateway", version="0.1.0")
|
||||
app.state.config = cfg
|
||||
app.state.repository = repository
|
||||
app.state.everos_client = client
|
||||
@ -65,6 +200,77 @@ def create_app(
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@app.middleware("http")
async def log_api_request(request: Request, call_next: Any) -> Response:
    """Log one JSON event per HTTP request and pass the response through.

    The request body is read up front only when
    ``_should_capture_request_body`` allows it. The downstream response body
    is drained into memory so it can be logged, and an equivalent
    ``Response`` is returned to the client. Sensitive query parameters and
    body fields are redacted before the event is written to ``API_LOGGER``.

    Args:
        request: Incoming request.
        call_next: Callable that runs the rest of the application.

    Returns:
        A buffered copy of the downstream response.

    Raises:
        Exception: Any exception raised downstream is logged as the event's
            ``error`` and re-raised unchanged.
    """
    request_time = datetime.now(timezone.utc).isoformat()
    started = time.perf_counter()
    request_content_type = request.headers.get("content-type")
    raw_content_length = request.headers.get("content-length")
    try:
        request_content_length = (
            int(raw_content_length) if raw_content_length is not None else None
        )
    except ValueError:
        # A malformed Content-Length is treated as "unknown".
        request_content_length = None
    capture_request_body = _should_capture_request_body(
        request_content_type,
        request_content_length,
    )
    request_body = await request.body() if capture_request_body else None
    # Bound before the try block so the finally clause can always inspect it,
    # even when call_next is interrupted by a non-Exception BaseException
    # such as asyncio.CancelledError.
    response: Any = None
    response_chunks: list[bytes] = []
    status_code = 500
    error: str | None = None

    try:
        response = await call_next(request)
        status_code = response.status_code
        # Collect chunks in a list; repeated bytes concatenation is quadratic.
        async for chunk in response.body_iterator:
            response_chunks.append(chunk)
        logged_response = Response(
            content=b"".join(response_chunks),
            status_code=response.status_code,
            headers=dict(response.headers),
            media_type=response.media_type,
            background=response.background,
        )
        # dict() keeps only the first value of a repeated header (for example
        # several Set-Cookie lines); re-append the later values.
        seen_headers: set[str] = set()
        for header_name, header_value in response.headers.items():
            if header_name in seen_headers:
                logged_response.headers.append(header_name, header_value)
            else:
                seen_headers.add(header_name)
    except Exception as exc:
        error = str(exc)
        raise
    finally:
        duration_ms = round((time.perf_counter() - started) * 1000, 3)
        query_items = list(request.query_params.multi_items())
        output: dict[str, Any] = {"status_code": status_code}
        if error is not None:
            output["error"] = error
        elif response is None:
            # Interrupted before any response object existed.
            output["error"] = "request aborted before a response was produced"
        else:
            output["body"] = _body_for_log(
                b"".join(response_chunks),
                response.headers.get("content-type"),
            )
        event = {
            "request_time": request_time,
            "duration_ms": duration_ms,
            "method": request.method,
            "path": request.url.path,
            "url": _redacted_url(str(request.url)),
            "client": request.client.host if request.client else None,
            "input": {
                "query_params": _redacted_query_params(query_items),
                "body": (
                    _body_for_log(request_body, request_content_type)
                    if request_body is not None
                    else _uncaptured_body_for_log(
                        request_content_type,
                        request_content_length,
                    )
                ),
            },
            "output": output,
        }
        API_LOGGER.info(json.dumps(event, ensure_ascii=False, default=str))

    return logged_response
|
||||
|
||||
def require_user(user_id: str, user_key: str) -> None:
    """Raise HTTP 401 unless the service accepts the given credentials."""
    authenticated = service.authenticate_user(user_id, user_key)
    if authenticated:
        return
    raise HTTPException(status_code=401, detail="invalid user credentials")
|
||||
@ -169,6 +375,29 @@ def create_app(
|
||||
project_id=request.project_id,
|
||||
)
|
||||
|
||||
@router.post("/memories/add")
async def add_memory(
    request: AddMemoryRequest,
) -> dict[str, Any]:
    """Authenticate the caller, then forward the messages to the service."""
    require_user(request.user_id, request.user_key)
    # The service receives plain dicts, not pydantic models.
    serialized_messages = [entry.model_dump() for entry in request.messages]
    result = await service.add_memory(
        session_id=request.session_id,
        app_id=request.app_id,
        project_id=request.project_id,
        messages=serialized_messages,
    )
    return result
|
||||
|
||||
@router.post("/memories/flush")
async def flush_memory(
    request: FlushMemoryRequest,
) -> dict[str, Any]:
    """Authenticate the caller, then ask the service to flush the session."""
    require_user(request.user_id, request.user_key)
    result = await service.flush_memory(
        session_id=request.session_id,
        app_id=request.app_id,
        project_id=request.project_id,
    )
    return result
|
||||
|
||||
@router.patch("/memories/{memory_id}")
|
||||
async def patch_memory(
|
||||
memory_id: str,
|
||||
|
||||
@ -27,7 +27,7 @@ _DEFAULT_ALLOWED_MIME_TYPES = (
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GatewayConfig:
|
||||
everos_base_url: str = "http://127.0.0.1:8000"
|
||||
everos_base_url: str = "http://127.0.0.1:1995"
|
||||
database_path: Path = _PROJECT_ROOT / "data" / "memory_gateway.sqlite3"
|
||||
storage_dir: Path = _PROJECT_ROOT / "data" / "storage"
|
||||
resource_search_batch_size: int = 50
|
||||
@ -50,7 +50,7 @@ class GatewayConfig:
|
||||
return cls(
|
||||
everos_base_url=os.environ.get(
|
||||
"EVEROS_BASE_URL",
|
||||
"http://127.0.0.1:8000",
|
||||
"http://127.0.0.1:1995",
|
||||
).rstrip("/"),
|
||||
database_path=Path(
|
||||
os.environ.get(
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import mimetypes
|
||||
import secrets
|
||||
import shutil
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@ -28,6 +30,10 @@ def public_resource_uri(user_id: str, resource_id: str) -> str:
|
||||
return f"resource://{user_id}/{resource_id}"
|
||||
|
||||
|
||||
def current_timestamp_ms() -> int:
    """Return the current wall-clock time as whole milliseconds since the epoch."""
    seconds = time.time()
    return int(seconds * 1000)
|
||||
|
||||
|
||||
def infer_content_type(filename: str | None, mime_type: str | None) -> str:
|
||||
mime = (mime_type or mimetypes.guess_type(filename or "")[0] or "").lower()
|
||||
suffix = Path(filename or "").suffix.lower()
|
||||
@ -249,6 +255,7 @@ class MemoryGatewayService:
|
||||
project_id: str,
|
||||
filename: str,
|
||||
) -> dict[str, Any]:
|
||||
content_item = self._build_content_item(resource=resource, filename=filename)
|
||||
return {
|
||||
"session_id": resource["session_id"],
|
||||
"app_id": app_id,
|
||||
@ -257,23 +264,43 @@ class MemoryGatewayService:
|
||||
{
|
||||
"sender_id": user_id,
|
||||
"role": "user",
|
||||
"timestamp": 1781068800000,
|
||||
"content": [
|
||||
{
|
||||
"type": resource["content_type"],
|
||||
"uri": resource["uri"],
|
||||
"name": filename,
|
||||
"ext": Path(filename).suffix.lstrip(".") or None,
|
||||
"extras": {
|
||||
"resource_id": resource["id"],
|
||||
"source": "user_upload",
|
||||
},
|
||||
}
|
||||
],
|
||||
"timestamp": current_timestamp_ms(),
|
||||
"content": [content_item],
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
def _build_content_item(
|
||||
self,
|
||||
*,
|
||||
resource: dict[str, Any],
|
||||
filename: str,
|
||||
) -> dict[str, Any]:
|
||||
content_type = str(resource["content_type"])
|
||||
path = self._resource_file_path(resource)
|
||||
content = path.read_bytes()
|
||||
item = {
|
||||
"type": content_type,
|
||||
"name": filename,
|
||||
"ext": Path(filename).suffix.lstrip(".") or None,
|
||||
"extras": {
|
||||
"resource_id": resource["id"],
|
||||
"source": "user_upload",
|
||||
},
|
||||
}
|
||||
if content_type == "text":
|
||||
item["text"] = content.decode("utf-8", errors="replace")
|
||||
else:
|
||||
item["base64"] = base64.b64encode(content).decode("ascii")
|
||||
return item
|
||||
|
||||
def _resource_file_path(self, resource: dict[str, Any]) -> Path:
|
||||
uri = str(resource["uri"])
|
||||
parsed = urlparse(uri)
|
||||
if parsed.scheme != "file":
|
||||
raise ValueError(f"unsupported resource uri scheme: {parsed.scheme}")
|
||||
return Path(unquote(parsed.path)).resolve(strict=True)
|
||||
|
||||
def list_resources(self, user_id: str) -> list[dict[str, Any]]:
    """Return the detail view of every resource the repository lists for *user_id*."""
    records = self.repository.list_resources(user_id)
    return list(map(self._resource_detail, records))
|
||||
|
||||
@ -395,6 +422,41 @@ class MemoryGatewayService:
|
||||
overridden = self._apply_overrides(user_id, filtered)
|
||||
return {"results": overridden}
|
||||
|
||||
async def add_memory(
    self,
    *,
    session_id: str,
    app_id: str,
    project_id: str,
    messages: list[dict[str, Any]],
) -> dict[str, Any]:
    """Send *messages* to the EverOS client and wrap its reply.

    Returns:
        Dict with the ``session_id`` and, under ``"everos"``, whatever the
        client's ``add_memory`` call returned.
    """
    everos_reply = await self.everos_client.add_memory(
        {
            "session_id": session_id,
            "app_id": app_id,
            "project_id": project_id,
            "messages": messages,
        }
    )
    return {"session_id": session_id, "everos": everos_reply}
|
||||
|
||||
async def flush_memory(
    self,
    *,
    session_id: str,
    app_id: str,
    project_id: str,
) -> dict[str, Any]:
    """Ask the EverOS client to flush a session and wrap its reply.

    Returns:
        Dict with the ``session_id`` and, under ``"everos"``, whatever the
        client's ``flush_memory`` call returned.
    """
    everos_reply = await self.everos_client.flush_memory(
        session_id,
        app_id,
        project_id,
    )
    return {"session_id": session_id, "everos": everos_reply}
|
||||
|
||||
def _search_payload(
|
||||
self,
|
||||
*,
|
||||
|
||||
Reference in New Issue
Block a user