harden memory edits and uploads
This commit is contained in:
164
core/service.py
164
core/service.py
@ -1,11 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import mimetypes
|
||||
import secrets
|
||||
import shutil
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
from fastapi import UploadFile
|
||||
|
||||
@ -46,19 +49,71 @@ def _safe_filename(filename: str | None) -> str:
|
||||
return name or "upload.bin"
|
||||
|
||||
|
||||
def _copy_upload(file: UploadFile, destination: Path) -> tuple[str, int]:
|
||||
class UploadTooLarge(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class UnsupportedContentType(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
def _copy_upload(
|
||||
file: UploadFile,
|
||||
destination: Path,
|
||||
max_upload_bytes: int,
|
||||
cleanup_root: Path,
|
||||
) -> tuple[str, int]:
|
||||
sha256 = hashlib.sha256()
|
||||
size = 0
|
||||
destination.parent.mkdir(parents=True, exist_ok=True)
|
||||
with destination.open("wb") as out:
|
||||
while True:
|
||||
chunk = file.file.read(1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
size += len(chunk)
|
||||
sha256.update(chunk)
|
||||
out.write(chunk)
|
||||
return sha256.hexdigest(), size
|
||||
try:
|
||||
with destination.open("wb") as out:
|
||||
while True:
|
||||
chunk = file.file.read(1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
size += len(chunk)
|
||||
if size > max_upload_bytes:
|
||||
raise UploadTooLarge(
|
||||
f"upload exceeds max size of {max_upload_bytes} bytes"
|
||||
)
|
||||
sha256.update(chunk)
|
||||
out.write(chunk)
|
||||
return sha256.hexdigest(), size
|
||||
except Exception:
|
||||
destination.unlink(missing_ok=True)
|
||||
_remove_empty_parents(destination.parent, stop_at=cleanup_root)
|
||||
raise
|
||||
|
||||
|
||||
def _mime_allowed(mime_type: str | None, allowed_mime_types: tuple[str, ...]) -> bool:
|
||||
mime = (mime_type or "").lower()
|
||||
if not mime:
|
||||
return False
|
||||
for allowed in allowed_mime_types:
|
||||
item = allowed.lower()
|
||||
if item.endswith("/*") and mime.startswith(item[:-1]):
|
||||
return True
|
||||
if item == mime:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _remove_empty_parents(path: Path, stop_at: Path | None = None) -> None:
|
||||
current = path
|
||||
stop = stop_at.resolve() if stop_at is not None else None
|
||||
while True:
|
||||
try:
|
||||
resolved = current.resolve()
|
||||
if stop is not None and resolved == stop:
|
||||
return
|
||||
current.rmdir()
|
||||
except OSError:
|
||||
return
|
||||
parent = current.parent
|
||||
if parent == current:
|
||||
return
|
||||
current = parent
|
||||
|
||||
|
||||
class MemoryGatewayService:
|
||||
@ -101,9 +156,26 @@ class MemoryGatewayService:
|
||||
session_id = resource_session_id(user_id, resource_id)
|
||||
original_filename = _safe_filename(file.filename)
|
||||
mime_type = file.content_type or mimetypes.guess_type(original_filename)[0]
|
||||
if not _mime_allowed(mime_type, self.config.allowed_mime_types):
|
||||
raise UnsupportedContentType(f"unsupported content type: {mime_type}")
|
||||
content_type = infer_content_type(original_filename, mime_type)
|
||||
stored_path = self.config.storage_dir / user_id / resource_id / original_filename
|
||||
sha256, size_bytes = _copy_upload(file, stored_path)
|
||||
sha256, size_bytes = _copy_upload(
|
||||
file,
|
||||
stored_path,
|
||||
self.config.max_upload_bytes,
|
||||
self.config.storage_dir,
|
||||
)
|
||||
existing = self.repository.find_active_resource_by_sha256(
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
project_id=project_id,
|
||||
sha256=sha256,
|
||||
)
|
||||
if existing is not None:
|
||||
shutil.rmtree(stored_path.parent, ignore_errors=True)
|
||||
return self._resource_summary(existing)
|
||||
|
||||
internal_uri = stored_path.resolve().as_uri()
|
||||
|
||||
resource = self.repository.create_resource(
|
||||
@ -126,16 +198,20 @@ class MemoryGatewayService:
|
||||
)
|
||||
|
||||
try:
|
||||
await self.everos_client.add_memory(
|
||||
self._build_add_payload(
|
||||
resource=resource,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
project_id=project_id,
|
||||
filename=original_filename,
|
||||
await self._retry_everos_call(
|
||||
lambda: self.everos_client.add_memory(
|
||||
self._build_add_payload(
|
||||
resource=resource,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
project_id=project_id,
|
||||
filename=original_filename,
|
||||
)
|
||||
)
|
||||
)
|
||||
await self.everos_client.flush_memory(session_id, app_id, project_id)
|
||||
await self._retry_everos_call(
|
||||
lambda: self.everos_client.flush_memory(session_id, app_id, project_id)
|
||||
)
|
||||
except Exception as exc:
|
||||
failed = self.repository.update_resource_status(
|
||||
resource_id,
|
||||
@ -147,6 +223,23 @@ class MemoryGatewayService:
|
||||
extracted = self.repository.update_resource_status(resource_id, "extracted")
|
||||
return self._resource_summary(extracted or resource)
|
||||
|
||||
async def _retry_everos_call(self, operation: Any) -> Any:
|
||||
attempts = max(1, self.config.everos_ingest_attempts)
|
||||
last_error: Exception | None = None
|
||||
for attempt in range(attempts):
|
||||
try:
|
||||
return await operation()
|
||||
except Exception as exc:
|
||||
last_error = exc
|
||||
if attempt == attempts - 1:
|
||||
break
|
||||
delay = self.config.everos_retry_delay_seconds
|
||||
if delay > 0:
|
||||
await asyncio.sleep(delay)
|
||||
if last_error is None:
|
||||
raise RuntimeError("EverOS operation failed")
|
||||
raise last_error
|
||||
|
||||
def _build_add_payload(
|
||||
self,
|
||||
*,
|
||||
@ -199,8 +292,29 @@ class MemoryGatewayService:
|
||||
if before is None:
|
||||
return None
|
||||
resource = self.repository.soft_delete_resource(resource_id, user_id)
|
||||
self._cleanup_resource_file(before)
|
||||
return self._resource_summary(resource)
|
||||
|
||||
def _cleanup_resource_file(self, resource: dict[str, Any]) -> None:
|
||||
uri = str(resource.get("uri") or "")
|
||||
if not uri.startswith("file://"):
|
||||
return
|
||||
parsed = urlparse(uri)
|
||||
if parsed.scheme != "file":
|
||||
return
|
||||
try:
|
||||
path = Path(unquote(parsed.path)).resolve()
|
||||
storage_root = self.config.storage_dir.resolve()
|
||||
except OSError:
|
||||
return
|
||||
if not path.is_relative_to(storage_root):
|
||||
return
|
||||
try:
|
||||
path.unlink(missing_ok=True)
|
||||
except OSError:
|
||||
return
|
||||
_remove_empty_parents(path.parent, stop_at=storage_root)
|
||||
|
||||
async def search_memories(
|
||||
self,
|
||||
*,
|
||||
@ -394,6 +508,18 @@ class MemoryGatewayService:
|
||||
)
|
||||
return {"memory_id": memory_id, "override_id": override["id"], "status": "active"}
|
||||
|
||||
def assert_memory_session_owned(self, user_id: str, session_id: str) -> None:
|
||||
if session_id == f"memory_edit:{user_id}":
|
||||
return
|
||||
if session_id.startswith("resource:"):
|
||||
resource = self.repository.get_resource_by_session_for_user(
|
||||
session_id,
|
||||
user_id,
|
||||
)
|
||||
if resource is not None:
|
||||
return
|
||||
raise PermissionError("memory session does not belong to user")
|
||||
|
||||
def delete_memory(
|
||||
self,
|
||||
*,
|
||||
|
||||
Reference in New Issue
Block a user