feat: add multipart memory uploads
This commit is contained in:
83
core/api.py
83
core/api.py
@ -9,7 +9,9 @@ from urllib.parse import parse_qsl, quote, urlsplit, urlunsplit
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, FastAPI, File, Form, HTTPException, Request, UploadFile
|
||||
from pydantic import ValidationError
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from starlette.datastructures import UploadFile as StarletteUploadFile
|
||||
from starlette.responses import Response
|
||||
|
||||
from .config import GatewayConfig
|
||||
@ -220,6 +222,60 @@ def _backend_http_error_detail(exc: httpx.HTTPStatusError) -> Any:
|
||||
return exc.response.text
|
||||
|
||||
|
||||
def _form_text(form: Any, field: str, default: str | None = None) -> str:
|
||||
value = form.get(field)
|
||||
if value is None:
|
||||
if default is not None:
|
||||
return default
|
||||
raise HTTPException(status_code=422, detail=f"missing form field: {field}")
|
||||
if isinstance(value, StarletteUploadFile):
|
||||
raise HTTPException(status_code=422, detail=f"form field must be text: {field}")
|
||||
return str(value)
|
||||
|
||||
|
||||
async def _form_json_text(form: Any, field: str) -> str:
    """Return the raw text of a form field that may arrive as text or a file.

    The field can be sent either as a plain form value or as an uploaded
    file part; in the latter case the file is read fully and decoded as
    UTF-8.

    Args:
        form: Parsed form data; only its ``get`` method is used.
        field: Name of the form field to read.

    Returns:
        The field content as a string (not yet parsed).

    Raises:
        HTTPException: 422 when the field is absent; 400 when an uploaded
            part is not valid UTF-8.
    """
    value = form.get(field)
    if value is None:
        raise HTTPException(status_code=422, detail=f"missing form field: {field}")
    if not isinstance(value, StarletteUploadFile):
        return str(value)
    raw = await value.read()
    try:
        return raw.decode("utf-8")
    except UnicodeDecodeError as exc:
        # Undecodable bytes are a malformed client payload, not a server
        # fault; report them instead of letting the error become a 500.
        raise HTTPException(
            status_code=400,
            detail=f"form field must be UTF-8 encoded: {field}",
        ) from exc
|
||||
|
||||
|
||||
def _upload_files_from_form(form: Any) -> dict[str, UploadFile]:
|
||||
files: dict[str, UploadFile] = {}
|
||||
for key, value in form.multi_items():
|
||||
if not isinstance(value, StarletteUploadFile):
|
||||
continue
|
||||
if key == "messages":
|
||||
continue
|
||||
if key in files:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"duplicate upload file field: {key}",
|
||||
)
|
||||
files[key] = value
|
||||
return files
|
||||
|
||||
|
||||
async def _multipart_messages(form: Any) -> list[dict[str, Any]]:
    """Parse and validate the ``messages`` field of a multipart request.

    The field is read via ``_form_json_text``, parsed as JSON, required to
    be an array, and each element is validated against ``AddMemoryMessage``.

    Args:
        form: Parsed form data for the request.

    Returns:
        One dict per message, produced by ``model_dump()`` on the validated
        model.

    Raises:
        HTTPException: 400 when the field is not valid JSON or not a JSON
            array; 422 when an element fails model validation.
    """
    raw = await _form_json_text(form, "messages")
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError as exc:
        raise HTTPException(
            status_code=400,
            detail=f"invalid messages JSON: {exc.msg}",
        ) from exc
    if not isinstance(parsed, list):
        raise HTTPException(status_code=400, detail="messages must be a JSON array")
    try:
        return [AddMemoryMessage.model_validate(item).model_dump() for item in parsed]
    except ValidationError as exc:
        # exc.errors() may carry raw exception objects (e.g. under "ctx"),
        # which cannot be JSON-encoded when the HTTPException is rendered.
        # Round-tripping through exc.json() yields plain JSON-safe data.
        raise HTTPException(status_code=422, detail=json.loads(exc.json())) from exc
|
||||
|
||||
|
||||
def create_app(
|
||||
*,
|
||||
config: GatewayConfig | None = None,
|
||||
@ -466,6 +522,33 @@ def create_app(
|
||||
except InvalidAttachment as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
|
||||
@router.post("/memories/add/multipart")
async def add_memory_multipart(request: Request) -> dict[str, Any]:
    """Handle a multipart request that adds memory messages with file uploads.

    Reads ``user_id`` and ``user_key`` from the form and passes them to
    ``require_user`` before the remaining fields are parsed, then delegates
    to ``service.add_memory_with_uploads``.

    Raises:
        HTTPException: with the backend's status code on
            ``httpx.HTTPStatusError``; 413 on ``UploadTooLarge``; 415 on
            ``UnsupportedContentType``; 422 on ``InvalidAttachment``.
            The form helpers raise their own HTTPExceptions for missing or
            malformed fields.
    """
    form = await request.form()
    user_id = _form_text(form, "user_id")
    user_key = _form_text(form, "user_key")
    require_user(user_id, user_key)

    # Parse the remaining fields in a fixed order so the first invalid
    # field determines the error the client sees.
    session_id = _form_text(form, "session_id")
    app_id = _form_text(form, "app_id", "default")
    project_id = _form_text(form, "project_id", "default")
    messages = await _multipart_messages(form)
    upload_files = _upload_files_from_form(form)

    try:
        return await service.add_memory_with_uploads(
            user_id=user_id,
            session_id=session_id,
            app_id=app_id,
            project_id=project_id,
            messages=messages,
            upload_files=upload_files,
        )
    except httpx.HTTPStatusError as exc:
        # Relay the backend's own status code and error detail.
        raise HTTPException(
            status_code=exc.response.status_code,
            detail=_backend_http_error_detail(exc),
        ) from exc
    except UploadTooLarge as exc:
        raise HTTPException(status_code=413, detail=str(exc)) from exc
    except UnsupportedContentType as exc:
        raise HTTPException(status_code=415, detail=str(exc)) from exc
    except InvalidAttachment as exc:
        raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
|
||||
@router.post("/memories/flush")
|
||||
async def flush_memory(
|
||||
request: FlushMemoryRequest,
|
||||
|
||||
167
core/service.py
167
core/service.py
@ -138,6 +138,25 @@ def _remove_empty_parents(path: Path, stop_at: Path | None = None) -> None:
|
||||
current = parent
|
||||
|
||||
|
||||
def _read_upload_bytes(
|
||||
file: UploadFile,
|
||||
max_upload_bytes: int,
|
||||
) -> tuple[bytes, str, int]:
|
||||
sha256 = hashlib.sha256()
|
||||
size = 0
|
||||
chunks: list[bytes] = []
|
||||
while True:
|
||||
chunk = file.file.read(1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
size += len(chunk)
|
||||
if size > max_upload_bytes:
|
||||
raise UploadTooLarge(f"upload exceeds max size of {max_upload_bytes} bytes")
|
||||
sha256.update(chunk)
|
||||
chunks.append(chunk)
|
||||
return b"".join(chunks), sha256.hexdigest(), size
|
||||
|
||||
|
||||
class MemoryGatewayService:
|
||||
def __init__(
|
||||
self,
|
||||
@ -617,6 +636,41 @@ class MemoryGatewayService:
|
||||
raise
|
||||
return {"session_id": session_id, "backend": backend}
|
||||
|
||||
async def add_memory_with_uploads(
    self,
    *,
    user_id: str,
    session_id: str,
    app_id: str,
    project_id: str,
    messages: list[dict[str, Any]],
    upload_files: dict[str, UploadFile],
) -> dict[str, Any]:
    """Store uploaded files, send the messages to the backend, record attachments.

    Upload references in ``messages`` are resolved by
    ``_prepare_uploaded_memory_files``; the resulting messages are posted
    through ``backend_client.add_memory`` and each prepared attachment is
    then passed to ``repository.create_attachment``.

    Returns:
        ``{"session_id": ..., "backend": ...}`` where ``backend`` is the
        value returned by the backend client.

    Raises:
        Exception: anything raised by the backend call or attachment
            creation is re-raised after the files written for this request
            have been removed.
    """
    prepared = self._prepare_uploaded_memory_files(
        user_id=user_id,
        session_id=session_id,
        app_id=app_id,
        project_id=project_id,
        messages=messages,
        upload_files=upload_files,
    )
    messages, attachments, generated_paths = prepared
    payload = {
        "session_id": session_id,
        "app_id": app_id,
        "project_id": project_id,
        "messages": messages,
    }
    try:
        backend = await self.backend_client.add_memory(payload)
        for record in attachments:
            self.repository.create_attachment(**record)
    except Exception:
        # Roll back only the files this request wrote.
        # NOTE(review): attachment rows created before a mid-loop failure
        # are not undone here — confirm whether the repository handles that.
        for written in generated_paths:
            written.unlink(missing_ok=True)
            _remove_empty_parents(written.parent, stop_at=self.config.storage_dir)
        raise
    else:
        return {"session_id": session_id, "backend": backend}
|
||||
|
||||
def _register_resource_attachment(
|
||||
self,
|
||||
resource: dict[str, Any],
|
||||
@ -713,6 +767,119 @@ class MemoryGatewayService:
|
||||
raise
|
||||
return attachments, generated_paths
|
||||
|
||||
def _prepare_uploaded_memory_files(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
app_id: str,
|
||||
project_id: str,
|
||||
messages: list[dict[str, Any]],
|
||||
upload_files: dict[str, UploadFile],
|
||||
) -> tuple[list[dict[str, Any]], list[dict[str, Any]], list[Path]]:
|
||||
attachments: list[dict[str, Any]] = []
|
||||
generated_paths: list[Path] = []
|
||||
used_upload_ids: set[str] = set()
|
||||
try:
|
||||
for message in messages:
|
||||
content = message.get("content")
|
||||
if not isinstance(content, list):
|
||||
continue
|
||||
for index, item in enumerate(content):
|
||||
if not isinstance(item, dict) or "upload_id" not in item:
|
||||
continue
|
||||
upload_id = str(item.get("upload_id") or "").strip()
|
||||
if not upload_id:
|
||||
raise InvalidAttachment("upload_id must not be empty")
|
||||
if upload_id in used_upload_ids:
|
||||
raise InvalidAttachment(f"duplicate upload_id: {upload_id}")
|
||||
file = upload_files.get(upload_id)
|
||||
if file is None:
|
||||
raise InvalidAttachment(
|
||||
f"missing upload file for upload_id: {upload_id}"
|
||||
)
|
||||
used_upload_ids.add(upload_id)
|
||||
content[index] = self._materialize_uploaded_content_item(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
app_id=app_id,
|
||||
project_id=project_id,
|
||||
item=item,
|
||||
file=file,
|
||||
attachments=attachments,
|
||||
generated_paths=generated_paths,
|
||||
)
|
||||
unused_upload_ids = sorted(set(upload_files) - used_upload_ids)
|
||||
if unused_upload_ids:
|
||||
raise InvalidAttachment(
|
||||
f"unused upload file field: {unused_upload_ids[0]}"
|
||||
)
|
||||
except Exception:
|
||||
for path in generated_paths:
|
||||
path.unlink(missing_ok=True)
|
||||
_remove_empty_parents(path.parent, stop_at=self.config.storage_dir)
|
||||
raise
|
||||
return messages, attachments, generated_paths
|
||||
|
||||
def _materialize_uploaded_content_item(
    self,
    *,
    user_id: str,
    session_id: str,
    app_id: str,
    project_id: str,
    item: dict[str, Any],
    file: UploadFile,
    attachments: list[dict[str, Any]],
    generated_paths: list[Path],
) -> dict[str, Any]:
    """Persist one uploaded file and build the content item that replaces it.

    The upload is checked against the configured MIME allow-list, read with
    the configured size cap, and written under
    ``<storage_dir>/<user_id>/memory_attachments/<sha256>/<name>`` unless a
    file already exists at that path. An attachment record is appended to
    ``attachments`` and any newly written path to ``generated_paths``.

    Args:
        user_id, session_id, app_id, project_id: Copied into the attachment
            record; ``user_id`` also selects the storage directory.
        item: Original content item. Its ``name`` and ``type`` are used as
            hints; ``upload_id`` and ``uri`` are dropped from the result.
        file: The uploaded file referenced by the item.
        attachments: Output list; receives one attachment record.
        generated_paths: Output list; receives the path if this call
            creates the file, so the caller can remove it on failure.

    Returns:
        A new content item carrying ``type``, ``name``, ``ext`` and either
        ``text`` (for content type ``"text"``) or ``base64``.

    Raises:
        UnsupportedContentType: when the MIME type is not allowed.
        UploadTooLarge: propagated from ``_read_upload_bytes``.
    """
    name = _safe_filename(str(item.get("name") or file.filename or "upload.bin"))
    # Prefer the MIME type supplied with the upload; otherwise guess from name.
    mime_type = file.content_type or mimetypes.guess_type(name)[0]
    if not _mime_allowed(mime_type, self.config.allowed_mime_types):
        raise UnsupportedContentType(f"unsupported content type: {mime_type}")
    content_type = normalize_content_type(
        name,
        mime_type,
        str(item.get("type") or ""),
    )
    data, sha256, _ = _read_upload_bytes(
        file,
        self.config.max_upload_bytes,
    )
    path = self.config.storage_dir / user_id / "memory_attachments" / sha256 / name
    if not path.exists():
        path.parent.mkdir(parents=True, exist_ok=True)
        # Register the path BEFORE writing: if the write fails part-way the
        # truncated file is still rolled back by the caller. Otherwise it
        # would survive and be reused later by the exists() check above.
        generated_paths.append(path)
        path.write_bytes(data)

    content_item = {
        key: value for key, value in item.items() if key not in {"upload_id", "uri"}
    }
    content_item["type"] = content_type
    content_item["name"] = name
    # Fall back to a caller-supplied ext when the name has no suffix.
    content_item["ext"] = Path(name).suffix.lstrip(".") or content_item.get("ext")
    if content_type == "text":
        content_item.pop("base64", None)
        content_item["text"] = data.decode("utf-8", errors="replace")
    else:
        content_item.pop("text", None)
        content_item["base64"] = base64.b64encode(data).decode("ascii")

    attachments.append(
        {
            "id": f"a_{uuid.uuid4().hex}",
            "user_id": user_id,
            "app_id": app_id,
            "project_id": project_id,
            "session_id": session_id,
            "resource_id": None,
            "content_type": content_type,
            "name": name,
            "internal_uri": path.resolve().as_uri(),
            "source": "memory_add_upload",
            "sha256": sha256,
        }
    )
    return content_item
|
||||
|
||||
async def flush_memory(
|
||||
self,
|
||||
*,
|
||||
|
||||
Reference in New Issue
Block a user