harden memory edits and uploads

This commit is contained in:
2026-06-11 11:06:35 +08:00
parent 7155704b73
commit 8afb460883
7 changed files with 469 additions and 72 deletions

View File

@ -9,7 +9,7 @@ from .config import GatewayConfig
from .db import init_db
from .everos_client import EverOSClient
from .repository import MemoryRepository
from .service import MemoryGatewayService
from .service import MemoryGatewayService, UnsupportedContentType, UploadTooLarge
class SearchMemoriesRequest(BaseModel):
@ -28,14 +28,14 @@ class SearchMemoriesRequest(BaseModel):
class MemoryOverrideRequest(BaseModel):
user_id: str = Field(min_length=1)
user_key: str = Field(min_length=1)
session_id: str | None = None
session_id: str = Field(min_length=1)
override_text: str = Field(min_length=1)
class MemoryDeleteRequest(BaseModel):
user_id: str = Field(min_length=1)
user_key: str = Field(min_length=1)
session_id: str | None = None
session_id: str = Field(min_length=1)
reason: str | None = None
@ -51,7 +51,10 @@ def create_app(
cfg = config or GatewayConfig.from_env()
init_db(cfg.database_path)
repository = MemoryRepository(cfg.database_path)
client = everos_client or EverOSClient(cfg.everos_base_url)
client = everos_client or EverOSClient(
cfg.everos_base_url,
timeout=cfg.everos_timeout_seconds,
)
service = MemoryGatewayService(cfg, repository, client)
app = FastAPI(title="memory-gateway2", version="0.1.0")
@ -105,14 +108,19 @@ def create_app(
file: UploadFile = File(...),
) -> dict[str, Any]:
require_user(user_id, user_key)
return await service.upload_resource(
user_id=user_id,
app_id=app_id,
project_id=project_id,
file=file,
title=title,
description=description,
)
try:
return await service.upload_resource(
user_id=user_id,
app_id=app_id,
project_id=project_id,
file=file,
title=title,
description=description,
)
except UploadTooLarge as exc:
raise HTTPException(status_code=413, detail=str(exc)) from exc
except UnsupportedContentType as exc:
raise HTTPException(status_code=415, detail=str(exc)) from exc
@router.get("/resources")
async def list_resources(
@ -167,6 +175,10 @@ def create_app(
request: MemoryOverrideRequest,
) -> dict[str, Any]:
require_user(request.user_id, request.user_key)
try:
service.assert_memory_session_owned(request.user_id, request.session_id)
except PermissionError as exc:
raise HTTPException(status_code=403, detail=str(exc)) from exc
return service.upsert_override(
user_id=request.user_id,
memory_id=memory_id,
@ -180,6 +192,10 @@ def create_app(
request: MemoryDeleteRequest,
) -> dict[str, Any]:
require_user(request.user_id, request.user_key)
try:
service.assert_memory_session_owned(request.user_id, request.session_id)
except PermissionError as exc:
raise HTTPException(status_code=403, detail=str(exc)) from exc
return service.delete_memory(
user_id=request.user_id,
memory_id=memory_id,

View File

@ -6,6 +6,23 @@ from pathlib import Path
_PROJECT_ROOT = Path(__file__).resolve().parents[1]
# Default allow-list of MIME types accepted for uploads.
# Entries ending in "/*" are wildcards covering a whole top-level type (the
# service's `_mime_allowed` helper treats them as a prefix match); every other
# entry must match the upload's MIME type exactly (case-insensitively).
# `GatewayConfig.from_env` uses this tuple unless the comma-separated
# MEMORY_GATEWAY_ALLOWED_MIME_TYPES environment variable overrides it.
_DEFAULT_ALLOWED_MIME_TYPES = (
"image/*",
"audio/*",
"application/pdf",
"text/html",
"application/xhtml+xml",
"text/plain",
"text/markdown",
"text/csv",
"application/json",
"application/msword",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"application/vnd.ms-excel",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
"application/vnd.ms-powerpoint",
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
)
@dataclass(frozen=True)
@ -14,9 +31,22 @@ class GatewayConfig:
database_path: Path = _PROJECT_ROOT / "data" / "memory_gateway.sqlite3"
storage_dir: Path = _PROJECT_ROOT / "data" / "storage"
resource_search_batch_size: int = 50
max_upload_bytes: int = 25 * 1024 * 1024
allowed_mime_types: tuple[str, ...] = _DEFAULT_ALLOWED_MIME_TYPES
everos_ingest_attempts: int = 3
everos_retry_delay_seconds: float = 0.25
everos_timeout_seconds: float = 120.0
@classmethod
def from_env(cls) -> GatewayConfig:
allowed_mime_types = tuple(
item.strip()
for item in os.environ.get(
"MEMORY_GATEWAY_ALLOWED_MIME_TYPES",
",".join(_DEFAULT_ALLOWED_MIME_TYPES),
).split(",")
if item.strip()
)
return cls(
everos_base_url=os.environ.get(
"EVEROS_BASE_URL",
@ -37,4 +67,17 @@ class GatewayConfig:
resource_search_batch_size=int(
os.environ.get("MEMORY_GATEWAY_RESOURCE_SEARCH_BATCH_SIZE", "50")
),
max_upload_bytes=int(
os.environ.get("MEMORY_GATEWAY_MAX_UPLOAD_BYTES", str(25 * 1024 * 1024))
),
allowed_mime_types=allowed_mime_types,
everos_ingest_attempts=int(
os.environ.get("MEMORY_GATEWAY_EVEROS_INGEST_ATTEMPTS", "3")
),
everos_retry_delay_seconds=float(
os.environ.get("MEMORY_GATEWAY_EVEROS_RETRY_DELAY_SECONDS", "0.25")
),
everos_timeout_seconds=float(
os.environ.get("MEMORY_GATEWAY_EVEROS_TIMEOUT_SECONDS", "120")
),
)

View File

@ -157,6 +157,31 @@ class MemoryRepository:
).fetchone()
return _row_to_dict(row)
def find_active_resource_by_sha256(
    self,
    *,
    user_id: str,
    app_id: str,
    project_id: str,
    sha256: str,
) -> dict[str, Any] | None:
    """Look up an existing, still-usable resource with the same content hash.

    The lookup is scoped to one user/app/project combination and only
    considers rows that are not soft-deleted and whose status is
    ``ingesting`` or ``extracted``. When several rows qualify, the one with
    the earliest ``created_at`` is returned.

    Returns:
        The matching row converted by ``_row_to_dict``, or ``None`` when no
        row matches (annotated return type; the conversion helper lives
        elsewhere in this module).
    """
    lookup_args = (user_id, app_id, project_id, sha256)
    with connect(self.db_path) as conn:
        cursor = conn.execute(
            """
            SELECT * FROM user_resources
            WHERE user_id = ?
            AND app_id = ?
            AND project_id = ?
            AND sha256 = ?
            AND deleted_at IS NULL
            AND status IN ('ingesting', 'extracted')
            ORDER BY created_at ASC
            LIMIT 1
            """,
            lookup_args,
        )
        match = cursor.fetchone()
        return _row_to_dict(match)
def list_resources(self, user_id: str) -> list[dict[str, Any]]:
with connect(self.db_path) as conn:
rows = conn.execute(

View File

@ -1,11 +1,14 @@
from __future__ import annotations
import asyncio
import hashlib
import mimetypes
import secrets
import shutil
import uuid
from pathlib import Path
from typing import Any
from urllib.parse import unquote, urlparse
from fastapi import UploadFile
@ -46,19 +49,71 @@ def _safe_filename(filename: str | None) -> str:
return name or "upload.bin"
def _copy_upload(file: UploadFile, destination: Path) -> tuple[str, int]:
class UploadTooLarge(ValueError):
    """Raised when an upload's byte count exceeds the configured maximum."""
class UnsupportedContentType(ValueError):
    """Raised when an upload's MIME type is not in the configured allow-list."""
def _copy_upload(
file: UploadFile,
destination: Path,
max_upload_bytes: int,
cleanup_root: Path,
) -> tuple[str, int]:
sha256 = hashlib.sha256()
size = 0
destination.parent.mkdir(parents=True, exist_ok=True)
with destination.open("wb") as out:
while True:
chunk = file.file.read(1024 * 1024)
if not chunk:
break
size += len(chunk)
sha256.update(chunk)
out.write(chunk)
return sha256.hexdigest(), size
try:
with destination.open("wb") as out:
while True:
chunk = file.file.read(1024 * 1024)
if not chunk:
break
size += len(chunk)
if size > max_upload_bytes:
raise UploadTooLarge(
f"upload exceeds max size of {max_upload_bytes} bytes"
)
sha256.update(chunk)
out.write(chunk)
return sha256.hexdigest(), size
except Exception:
destination.unlink(missing_ok=True)
_remove_empty_parents(destination.parent, stop_at=cleanup_root)
raise
def _mime_allowed(mime_type: str | None, allowed_mime_types: tuple[str, ...]) -> bool:
    """Return whether *mime_type* is permitted by *allowed_mime_types*.

    Matching is case-insensitive. Each allowed entry is either an exact media
    type (``"application/pdf"``), a top-level wildcard (``"image/*"``) that
    matches every subtype of that type, or ``"*/*"`` which matches any
    well-formed ``type/subtype`` value.

    Args:
        mime_type: The media type reported for the upload. May be ``None`` or
            carry Content-Type parameters such as ``"; charset=utf-8"``.
        allowed_mime_types: Allowed patterns as described above.

    Returns:
        ``True`` if the media type matches at least one pattern; ``False``
        otherwise, including when *mime_type* is missing or empty.
    """
    # A Content-Type value may carry parameters ("text/plain; charset=utf-8");
    # only the bare media type takes part in the allow-list comparison.
    mime = (mime_type or "").split(";", 1)[0].strip().lower()
    if not mime:
        return False
    for allowed in allowed_mime_types:
        pattern = allowed.strip().lower()
        if pattern == "*/*":
            if "/" in mime:
                return True
        elif pattern.endswith("/*"):
            # Keep the trailing slash so "image/*" cannot match "imagex/png".
            if mime.startswith(pattern[:-1]):
                return True
        elif pattern == mime:
            return True
    return False
def _remove_empty_parents(path: Path, stop_at: Path | None = None) -> None:
    """Remove *path* and its ancestors for as long as they are empty directories.

    The walk goes upward from *path* and ends at the first directory that
    cannot be removed (non-empty, missing, permission denied, ...).

    Args:
        path: Directory to start removing from.
        stop_at: Optional boundary directory. It is never removed itself, and
            nothing is removed unless the directory being considered lies
            inside it. This keeps the cleanup confined to the boundary even
            when *path* was built from untrusted segments that point
            elsewhere.
    """
    stop = stop_at.resolve() if stop_at is not None else None
    current = path
    while True:
        try:
            resolved = current.resolve()
            if stop is not None and (
                resolved == stop or not resolved.is_relative_to(stop)
            ):
                return
            current.rmdir()
        except (OSError, RuntimeError):
            # OSError: directory is non-empty, missing or not removable.
            # RuntimeError: resolve() hit a symlink loop (Python < 3.13).
            return
        parent = current.parent
        if parent == current:
            # Reached the filesystem root; nothing further to climb.
            return
        current = parent
class MemoryGatewayService:
@ -101,9 +156,26 @@ class MemoryGatewayService:
session_id = resource_session_id(user_id, resource_id)
original_filename = _safe_filename(file.filename)
mime_type = file.content_type or mimetypes.guess_type(original_filename)[0]
if not _mime_allowed(mime_type, self.config.allowed_mime_types):
raise UnsupportedContentType(f"unsupported content type: {mime_type}")
content_type = infer_content_type(original_filename, mime_type)
stored_path = self.config.storage_dir / user_id / resource_id / original_filename
sha256, size_bytes = _copy_upload(file, stored_path)
sha256, size_bytes = _copy_upload(
file,
stored_path,
self.config.max_upload_bytes,
self.config.storage_dir,
)
existing = self.repository.find_active_resource_by_sha256(
user_id=user_id,
app_id=app_id,
project_id=project_id,
sha256=sha256,
)
if existing is not None:
shutil.rmtree(stored_path.parent, ignore_errors=True)
return self._resource_summary(existing)
internal_uri = stored_path.resolve().as_uri()
resource = self.repository.create_resource(
@ -126,16 +198,20 @@ class MemoryGatewayService:
)
try:
await self.everos_client.add_memory(
self._build_add_payload(
resource=resource,
user_id=user_id,
app_id=app_id,
project_id=project_id,
filename=original_filename,
await self._retry_everos_call(
lambda: self.everos_client.add_memory(
self._build_add_payload(
resource=resource,
user_id=user_id,
app_id=app_id,
project_id=project_id,
filename=original_filename,
)
)
)
await self.everos_client.flush_memory(session_id, app_id, project_id)
await self._retry_everos_call(
lambda: self.everos_client.flush_memory(session_id, app_id, project_id)
)
except Exception as exc:
failed = self.repository.update_resource_status(
resource_id,
@ -147,6 +223,23 @@ class MemoryGatewayService:
extracted = self.repository.update_resource_status(resource_id, "extracted")
return self._resource_summary(extracted or resource)
async def _retry_everos_call(self, operation: Any) -> Any:
    """Await ``operation()`` until it succeeds or the attempt budget runs out.

    Args:
        operation: Zero-argument callable returning an awaitable; it is
            invoked afresh for every attempt.

    Returns:
        Whatever the first successful attempt returns.

    Raises:
        Exception: The error raised by the final attempt once
            ``config.everos_ingest_attempts`` tries (at least one) have all
            failed.
    """
    total_tries = max(1, self.config.everos_ingest_attempts)
    tries_left = total_tries
    while True:
        tries_left -= 1
        try:
            return await operation()
        except Exception:
            if tries_left == 0:
                # Budget exhausted: surface the most recent failure as-is.
                raise
            pause = self.config.everos_retry_delay_seconds
            if pause > 0:
                await asyncio.sleep(pause)
def _build_add_payload(
self,
*,
@ -199,8 +292,29 @@ class MemoryGatewayService:
if before is None:
return None
resource = self.repository.soft_delete_resource(resource_id, user_id)
self._cleanup_resource_file(before)
return self._resource_summary(resource)
def _cleanup_resource_file(self, resource: dict[str, Any]) -> None:
    """Best-effort removal of the stored file behind a resource record.

    Only ``file://`` URIs that refer to the local machine and resolve to a
    path strictly inside ``config.storage_dir`` are touched. After the file
    is removed, directories left empty are pruned up to (but not including)
    the storage root. Every failure is swallowed on purpose: cleanup must
    never turn an otherwise successful operation into an error.

    Args:
        resource: Resource record; its ``"uri"`` entry (if any) names the
            stored file.
    """
    uri = str(resource.get("uri") or "")
    if not uri.startswith("file://"):
        return
    parsed = urlparse(uri)
    # "file://otherhost/path" names a file on another machine; its path must
    # not be interpreted against the local filesystem.
    if parsed.netloc not in ("", "localhost"):
        return
    try:
        path = Path(unquote(parsed.path)).resolve()
        storage_root = self.config.storage_dir.resolve()
    except (OSError, RuntimeError, ValueError):
        # RuntimeError: symlink loop (Python < 3.13); ValueError: e.g. an
        # embedded NUL decoded from "%00" in the URI.
        return
    if path == storage_root or not path.is_relative_to(storage_root):
        return
    try:
        path.unlink(missing_ok=True)
    except OSError:
        return
    _remove_empty_parents(path.parent, stop_at=storage_root)
async def search_memories(
self,
*,
@ -394,6 +508,18 @@ class MemoryGatewayService:
)
return {"memory_id": memory_id, "override_id": override["id"], "status": "active"}
def assert_memory_session_owned(self, user_id: str, session_id: str) -> None:
    """Ensure *session_id* is a memory session that belongs to *user_id*.

    Two kinds of session pass the check: the user's own edit session
    (``memory_edit:<user_id>``) and any ``resource:``-prefixed session for
    which the repository finds a resource belonging to that user.

    Raises:
        PermissionError: If the session matches neither case.
    """
    if session_id != f"memory_edit:{user_id}":
        owned_resource = (
            self.repository.get_resource_by_session_for_user(session_id, user_id)
            if session_id.startswith("resource:")
            else None
        )
        if owned_resource is None:
            raise PermissionError("memory session does not belong to user")
def delete_memory(
self,
*,