feat: add multipart memory uploads
This commit is contained in:
83
core/api.py
83
core/api.py
@ -9,7 +9,9 @@ from urllib.parse import parse_qsl, quote, urlsplit, urlunsplit
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, FastAPI, File, Form, HTTPException, Request, UploadFile
|
||||
from pydantic import ValidationError
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from starlette.datastructures import UploadFile as StarletteUploadFile
|
||||
from starlette.responses import Response
|
||||
|
||||
from .config import GatewayConfig
|
||||
@ -220,6 +222,60 @@ def _backend_http_error_detail(exc: httpx.HTTPStatusError) -> Any:
|
||||
return exc.response.text
|
||||
|
||||
|
||||
def _form_text(form: Any, field: str, default: str | None = None) -> str:
|
||||
value = form.get(field)
|
||||
if value is None:
|
||||
if default is not None:
|
||||
return default
|
||||
raise HTTPException(status_code=422, detail=f"missing form field: {field}")
|
||||
if isinstance(value, StarletteUploadFile):
|
||||
raise HTTPException(status_code=422, detail=f"form field must be text: {field}")
|
||||
return str(value)
|
||||
|
||||
|
||||
async def _form_json_text(form: Any, field: str) -> str:
|
||||
value = form.get(field)
|
||||
if value is None:
|
||||
raise HTTPException(status_code=422, detail=f"missing form field: {field}")
|
||||
if isinstance(value, StarletteUploadFile):
|
||||
raw = await value.read()
|
||||
return raw.decode("utf-8")
|
||||
return str(value)
|
||||
|
||||
|
||||
def _upload_files_from_form(form: Any) -> dict[str, UploadFile]:
|
||||
files: dict[str, UploadFile] = {}
|
||||
for key, value in form.multi_items():
|
||||
if not isinstance(value, StarletteUploadFile):
|
||||
continue
|
||||
if key == "messages":
|
||||
continue
|
||||
if key in files:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"duplicate upload file field: {key}",
|
||||
)
|
||||
files[key] = value
|
||||
return files
|
||||
|
||||
|
||||
async def _multipart_messages(form: Any) -> list[dict[str, Any]]:
|
||||
raw = await _form_json_text(form, "messages")
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"invalid messages JSON: {exc.msg}",
|
||||
) from exc
|
||||
if not isinstance(parsed, list):
|
||||
raise HTTPException(status_code=400, detail="messages must be a JSON array")
|
||||
try:
|
||||
return [AddMemoryMessage.model_validate(item).model_dump() for item in parsed]
|
||||
except ValidationError as exc:
|
||||
raise HTTPException(status_code=422, detail=exc.errors()) from exc
|
||||
|
||||
|
||||
def create_app(
|
||||
*,
|
||||
config: GatewayConfig | None = None,
|
||||
@ -466,6 +522,33 @@ def create_app(
|
||||
except InvalidAttachment as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
|
||||
@router.post("/memories/add/multipart")
|
||||
async def add_memory_multipart(request: Request) -> dict[str, Any]:
|
||||
form = await request.form()
|
||||
user_id = _form_text(form, "user_id")
|
||||
user_key = _form_text(form, "user_key")
|
||||
require_user(user_id, user_key)
|
||||
try:
|
||||
return await service.add_memory_with_uploads(
|
||||
user_id=user_id,
|
||||
session_id=_form_text(form, "session_id"),
|
||||
app_id=_form_text(form, "app_id", "default"),
|
||||
project_id=_form_text(form, "project_id", "default"),
|
||||
messages=await _multipart_messages(form),
|
||||
upload_files=_upload_files_from_form(form),
|
||||
)
|
||||
except httpx.HTTPStatusError as exc:
|
||||
raise HTTPException(
|
||||
status_code=exc.response.status_code,
|
||||
detail=_backend_http_error_detail(exc),
|
||||
) from exc
|
||||
except UploadTooLarge as exc:
|
||||
raise HTTPException(status_code=413, detail=str(exc)) from exc
|
||||
except UnsupportedContentType as exc:
|
||||
raise HTTPException(status_code=415, detail=str(exc)) from exc
|
||||
except InvalidAttachment as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
|
||||
@router.post("/memories/flush")
|
||||
async def flush_memory(
|
||||
request: FlushMemoryRequest,
|
||||
|
||||
Reference in New Issue
Block a user