harden memory edits and uploads

This commit is contained in:
2026-06-11 11:06:35 +08:00
parent 7155704b73
commit 8afb460883
7 changed files with 469 additions and 72 deletions

View File

@ -1,11 +1,14 @@
from __future__ import annotations
import asyncio
import hashlib
import mimetypes
import secrets
import shutil
import uuid
from pathlib import Path
from typing import Any
from urllib.parse import unquote, urlparse
from fastapi import UploadFile
@ -46,19 +49,71 @@ def _safe_filename(filename: str | None) -> str:
return name or "upload.bin"
class UploadTooLarge(ValueError):
    """Raised when an upload exceeds the configured maximum size."""


class UnsupportedContentType(ValueError):
    """Raised when an upload's MIME type is not accepted."""


def _copy_upload(
    file: UploadFile,
    destination: Path,
    max_upload_bytes: int,
    cleanup_root: Path,
) -> tuple[str, int]:
    """Stream an upload to ``destination``, hashing it and enforcing a size cap.

    The upload is read in 1 MiB chunks so that an oversized body is rejected
    without ever being held in memory as a whole.

    Args:
        file: Upload whose ``file`` attribute is a readable binary stream.
        destination: Path to write; missing parent directories are created.
        max_upload_bytes: Largest accepted size in bytes (inclusive).
        cleanup_root: Directory at which empty-parent cleanup stops after a
            failed copy.

    Returns:
        A ``(sha256_hexdigest, size_in_bytes)`` tuple for the written data.

    Raises:
        UploadTooLarge: If more than ``max_upload_bytes`` bytes are read.
        OSError: If the destination cannot be created or written.
    """
    sha256 = hashlib.sha256()
    size = 0
    destination.parent.mkdir(parents=True, exist_ok=True)
    try:
        with destination.open("wb") as out:
            while chunk := file.file.read(1024 * 1024):
                size += len(chunk)
                # Check before writing so the limit-breaking chunk never
                # reaches the disk.
                if size > max_upload_bytes:
                    raise UploadTooLarge(
                        f"upload exceeds max size of {max_upload_bytes} bytes"
                    )
                sha256.update(chunk)
                out.write(chunk)
        return sha256.hexdigest(), size
    except BaseException:
        # BaseException rather than Exception: an interrupt in the middle of a
        # copy must not leave a truncated file behind either. The error is
        # always re-raised after cleanup.
        destination.unlink(missing_ok=True)
        _remove_empty_parents(destination.parent, stop_at=cleanup_root)
        raise
def _mime_allowed(mime_type: str | None, allowed_mime_types: tuple[str, ...]) -> bool:
    """Return whether ``mime_type`` matches an entry of ``allowed_mime_types``.

    Matching is case-insensitive. An allow-list entry ending in ``/*`` accepts
    every subtype of that top-level type; any other entry must equal the
    type/subtype exactly. Media-type parameters (for example
    ``; charset=utf-8``) and surrounding whitespace are ignored, so
    ``"text/plain; charset=utf-8"`` matches the entry ``"text/plain"``.

    Args:
        mime_type: Media type to check; ``None`` or empty is never allowed.
        allowed_mime_types: Accepted media types and ``type/*`` wildcards.

    Returns:
        ``True`` if the media type is accepted, ``False`` otherwise.
    """
    # Only the type/subtype essence takes part in matching; without this an
    # exact entry would reject a value that a wildcard entry accepts.
    mime = (mime_type or "").partition(";")[0].strip().lower()
    if not mime:
        return False
    for allowed in allowed_mime_types:
        item = allowed.strip().lower()
        if item.endswith("/*"):
            # Keep the trailing slash in the prefix so "image/*" cannot match
            # a different top-level type such as "imagex/png".
            if mime.startswith(item[:-1]):
                return True
        elif item == mime:
            return True
    return False
def _remove_empty_parents(path: Path, stop_at: Path | None = None) -> None:
    """Remove ``path`` and then each ancestor in turn while they are empty.

    The walk ends at the first directory that cannot be removed (``rmdir``
    raises ``OSError``, for instance because it is not empty or does not
    exist), on reaching ``stop_at`` (which is left in place), or after the
    topmost ancestor has been tried.

    Args:
        path: Directory at which to start removing.
        stop_at: Optional boundary directory; compared after resolving both
            sides, and never removed.
    """
    boundary = None if stop_at is None else stop_at.resolve()
    # ``path.parents`` lists the ancestors nearest-first, which is the same
    # order as repeatedly taking ``.parent`` until it stops changing.
    for directory in (path, *path.parents):
        try:
            here = directory.resolve()
            if boundary is not None and here == boundary:
                return
            directory.rmdir()
        except OSError:
            return
class MemoryGatewayService:
@ -101,9 +156,26 @@ class MemoryGatewayService:
session_id = resource_session_id(user_id, resource_id)
original_filename = _safe_filename(file.filename)
mime_type = file.content_type or mimetypes.guess_type(original_filename)[0]
if not _mime_allowed(mime_type, self.config.allowed_mime_types):
raise UnsupportedContentType(f"unsupported content type: {mime_type}")
content_type = infer_content_type(original_filename, mime_type)
stored_path = self.config.storage_dir / user_id / resource_id / original_filename
sha256, size_bytes = _copy_upload(file, stored_path)
sha256, size_bytes = _copy_upload(
file,
stored_path,
self.config.max_upload_bytes,
self.config.storage_dir,
)
existing = self.repository.find_active_resource_by_sha256(
user_id=user_id,
app_id=app_id,
project_id=project_id,
sha256=sha256,
)
if existing is not None:
shutil.rmtree(stored_path.parent, ignore_errors=True)
return self._resource_summary(existing)
internal_uri = stored_path.resolve().as_uri()
resource = self.repository.create_resource(
@ -126,16 +198,20 @@ class MemoryGatewayService:
)
try:
await self.everos_client.add_memory(
self._build_add_payload(
resource=resource,
user_id=user_id,
app_id=app_id,
project_id=project_id,
filename=original_filename,
await self._retry_everos_call(
lambda: self.everos_client.add_memory(
self._build_add_payload(
resource=resource,
user_id=user_id,
app_id=app_id,
project_id=project_id,
filename=original_filename,
)
)
)
await self.everos_client.flush_memory(session_id, app_id, project_id)
await self._retry_everos_call(
lambda: self.everos_client.flush_memory(session_id, app_id, project_id)
)
except Exception as exc:
failed = self.repository.update_resource_status(
resource_id,
@ -147,6 +223,23 @@ class MemoryGatewayService:
extracted = self.repository.update_resource_status(resource_id, "extracted")
return self._resource_summary(extracted or resource)
async def _retry_everos_call(self, operation: Any) -> Any:
    """Await ``operation()`` until it succeeds or the attempts run out.

    The attempt count is ``self.config.everos_ingest_attempts``, with a floor
    of one. After each failed attempt except the last, the coroutine sleeps
    for ``self.config.everos_retry_delay_seconds`` when that value is
    positive.

    Args:
        operation: Zero-argument callable returning an awaitable; it is
            called afresh for every attempt.

    Returns:
        Whatever the first successful ``operation()`` call returns.

    Raises:
        Exception: The error raised by the final failed attempt.
    """
    remaining = max(1, self.config.everos_ingest_attempts)
    failure: Exception | None = None
    while remaining > 0:
        remaining -= 1
        try:
            return await operation()
        except Exception as exc:
            failure = exc
        if remaining == 0:
            # Out of attempts: no point sleeping before giving up.
            break
        pause = self.config.everos_retry_delay_seconds
        if pause > 0:
            await asyncio.sleep(pause)
    if failure is not None:
        raise failure
    raise RuntimeError("EverOS operation failed")
def _build_add_payload(
self,
*,
@ -199,8 +292,29 @@ class MemoryGatewayService:
if before is None:
return None
resource = self.repository.soft_delete_resource(resource_id, user_id)
self._cleanup_resource_file(before)
return self._resource_summary(resource)
def _cleanup_resource_file(self, resource: dict[str, Any]) -> None:
    """Best-effort removal of the stored file behind a resource record.

    Only ``file://`` URIs that resolve to a location inside
    ``self.config.storage_dir`` are touched; anything else is ignored. After
    the file is unlinked, now-empty parent directories are removed up to (but
    not including) the storage root. Filesystem errors are swallowed on
    purpose so that cleanup can never fail the caller.

    Args:
        resource: Resource record; its ``"uri"`` entry is read.
    """
    # Imported locally so this method does not depend on the module's
    # top-level import list.
    from urllib.request import url2pathname

    uri = str(resource.get("uri") or "")
    if not uri.startswith("file://"):
        return
    parsed = urlparse(uri)
    if parsed.scheme != "file":
        return
    try:
        # url2pathname percent-decodes like unquote() on POSIX, and on Windows
        # also converts "/C:/dir/file" into "C:\\dir\\file". A plain
        # Path(unquote(...)) mis-parses such drive-letter URIs, so the file
        # would never be found and removed there.
        path = Path(url2pathname(parsed.path)).resolve()
        storage_root = self.config.storage_dir.resolve()
    except OSError:
        return
    # Never delete anything outside the storage directory, whatever the
    # record claims.
    if not path.is_relative_to(storage_root):
        return
    try:
        path.unlink(missing_ok=True)
    except OSError:
        return
    _remove_empty_parents(path.parent, stop_at=storage_root)
async def search_memories(
self,
*,
@ -394,6 +508,18 @@ class MemoryGatewayService:
)
return {"memory_id": memory_id, "override_id": override["id"], "status": "active"}
def assert_memory_session_owned(self, user_id: str, session_id: str) -> None:
    """Raise unless ``session_id`` is a memory session owned by ``user_id``.

    Two kinds of session are accepted: the user's own edit session, named
    ``memory_edit:<user_id>``, and a ``resource:``-prefixed session for which
    the repository returns a resource for this user.

    Args:
        user_id: Identifier of the user making the request.
        session_id: Memory session identifier to check.

    Raises:
        PermissionError: If neither ownership rule matches.
    """
    own_edit_session = f"memory_edit:{user_id}"
    if session_id == own_edit_session:
        return
    # The repository is only consulted for resource sessions.
    owns_resource_session = (
        session_id.startswith("resource:")
        and self.repository.get_resource_by_session_for_user(
            session_id,
            user_id,
        )
        is not None
    )
    if not owns_resource_session:
        raise PermissionError("memory session does not belong to user")
def delete_memory(
self,
*,