harden memory edits and uploads
This commit is contained in:
40
core/api.py
40
core/api.py
@ -9,7 +9,7 @@ from .config import GatewayConfig
|
||||
from .db import init_db
|
||||
from .everos_client import EverOSClient
|
||||
from .repository import MemoryRepository
|
||||
from .service import MemoryGatewayService
|
||||
from .service import MemoryGatewayService, UnsupportedContentType, UploadTooLarge
|
||||
|
||||
|
||||
class SearchMemoriesRequest(BaseModel):
|
||||
@ -28,14 +28,14 @@ class SearchMemoriesRequest(BaseModel):
|
||||
class MemoryOverrideRequest(BaseModel):
|
||||
user_id: str = Field(min_length=1)
|
||||
user_key: str = Field(min_length=1)
|
||||
session_id: str | None = None
|
||||
session_id: str = Field(min_length=1)
|
||||
override_text: str = Field(min_length=1)
|
||||
|
||||
|
||||
class MemoryDeleteRequest(BaseModel):
|
||||
user_id: str = Field(min_length=1)
|
||||
user_key: str = Field(min_length=1)
|
||||
session_id: str | None = None
|
||||
session_id: str = Field(min_length=1)
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
@ -51,7 +51,10 @@ def create_app(
|
||||
cfg = config or GatewayConfig.from_env()
|
||||
init_db(cfg.database_path)
|
||||
repository = MemoryRepository(cfg.database_path)
|
||||
client = everos_client or EverOSClient(cfg.everos_base_url)
|
||||
client = everos_client or EverOSClient(
|
||||
cfg.everos_base_url,
|
||||
timeout=cfg.everos_timeout_seconds,
|
||||
)
|
||||
service = MemoryGatewayService(cfg, repository, client)
|
||||
|
||||
app = FastAPI(title="memory-gateway2", version="0.1.0")
|
||||
@ -105,14 +108,19 @@ def create_app(
|
||||
file: UploadFile = File(...),
|
||||
) -> dict[str, Any]:
|
||||
require_user(user_id, user_key)
|
||||
return await service.upload_resource(
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
project_id=project_id,
|
||||
file=file,
|
||||
title=title,
|
||||
description=description,
|
||||
)
|
||||
try:
|
||||
return await service.upload_resource(
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
project_id=project_id,
|
||||
file=file,
|
||||
title=title,
|
||||
description=description,
|
||||
)
|
||||
except UploadTooLarge as exc:
|
||||
raise HTTPException(status_code=413, detail=str(exc)) from exc
|
||||
except UnsupportedContentType as exc:
|
||||
raise HTTPException(status_code=415, detail=str(exc)) from exc
|
||||
|
||||
@router.get("/resources")
|
||||
async def list_resources(
|
||||
@ -167,6 +175,10 @@ def create_app(
|
||||
request: MemoryOverrideRequest,
|
||||
) -> dict[str, Any]:
|
||||
require_user(request.user_id, request.user_key)
|
||||
try:
|
||||
service.assert_memory_session_owned(request.user_id, request.session_id)
|
||||
except PermissionError as exc:
|
||||
raise HTTPException(status_code=403, detail=str(exc)) from exc
|
||||
return service.upsert_override(
|
||||
user_id=request.user_id,
|
||||
memory_id=memory_id,
|
||||
@ -180,6 +192,10 @@ def create_app(
|
||||
request: MemoryDeleteRequest,
|
||||
) -> dict[str, Any]:
|
||||
require_user(request.user_id, request.user_key)
|
||||
try:
|
||||
service.assert_memory_session_owned(request.user_id, request.session_id)
|
||||
except PermissionError as exc:
|
||||
raise HTTPException(status_code=403, detail=str(exc)) from exc
|
||||
return service.delete_memory(
|
||||
user_id=request.user_id,
|
||||
memory_id=memory_id,
|
||||
|
||||
@ -6,6 +6,23 @@ from pathlib import Path
|
||||
|
||||
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
_DEFAULT_ALLOWED_MIME_TYPES = (
|
||||
"image/*",
|
||||
"audio/*",
|
||||
"application/pdf",
|
||||
"text/html",
|
||||
"application/xhtml+xml",
|
||||
"text/plain",
|
||||
"text/markdown",
|
||||
"text/csv",
|
||||
"application/json",
|
||||
"application/msword",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
"application/vnd.ms-excel",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
"application/vnd.ms-powerpoint",
|
||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@ -14,9 +31,22 @@ class GatewayConfig:
|
||||
database_path: Path = _PROJECT_ROOT / "data" / "memory_gateway.sqlite3"
|
||||
storage_dir: Path = _PROJECT_ROOT / "data" / "storage"
|
||||
resource_search_batch_size: int = 50
|
||||
max_upload_bytes: int = 25 * 1024 * 1024
|
||||
allowed_mime_types: tuple[str, ...] = _DEFAULT_ALLOWED_MIME_TYPES
|
||||
everos_ingest_attempts: int = 3
|
||||
everos_retry_delay_seconds: float = 0.25
|
||||
everos_timeout_seconds: float = 120.0
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> GatewayConfig:
|
||||
allowed_mime_types = tuple(
|
||||
item.strip()
|
||||
for item in os.environ.get(
|
||||
"MEMORY_GATEWAY_ALLOWED_MIME_TYPES",
|
||||
",".join(_DEFAULT_ALLOWED_MIME_TYPES),
|
||||
).split(",")
|
||||
if item.strip()
|
||||
)
|
||||
return cls(
|
||||
everos_base_url=os.environ.get(
|
||||
"EVEROS_BASE_URL",
|
||||
@ -37,4 +67,17 @@ class GatewayConfig:
|
||||
resource_search_batch_size=int(
|
||||
os.environ.get("MEMORY_GATEWAY_RESOURCE_SEARCH_BATCH_SIZE", "50")
|
||||
),
|
||||
max_upload_bytes=int(
|
||||
os.environ.get("MEMORY_GATEWAY_MAX_UPLOAD_BYTES", str(25 * 1024 * 1024))
|
||||
),
|
||||
allowed_mime_types=allowed_mime_types,
|
||||
everos_ingest_attempts=int(
|
||||
os.environ.get("MEMORY_GATEWAY_EVEROS_INGEST_ATTEMPTS", "3")
|
||||
),
|
||||
everos_retry_delay_seconds=float(
|
||||
os.environ.get("MEMORY_GATEWAY_EVEROS_RETRY_DELAY_SECONDS", "0.25")
|
||||
),
|
||||
everos_timeout_seconds=float(
|
||||
os.environ.get("MEMORY_GATEWAY_EVEROS_TIMEOUT_SECONDS", "120")
|
||||
),
|
||||
)
|
||||
|
||||
@ -157,6 +157,31 @@ class MemoryRepository:
|
||||
).fetchone()
|
||||
return _row_to_dict(row)
|
||||
|
||||
def find_active_resource_by_sha256(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
app_id: str,
|
||||
project_id: str,
|
||||
sha256: str,
|
||||
) -> dict[str, Any] | None:
|
||||
with connect(self.db_path) as conn:
|
||||
row = conn.execute(
|
||||
"""
|
||||
SELECT * FROM user_resources
|
||||
WHERE user_id = ?
|
||||
AND app_id = ?
|
||||
AND project_id = ?
|
||||
AND sha256 = ?
|
||||
AND deleted_at IS NULL
|
||||
AND status IN ('ingesting', 'extracted')
|
||||
ORDER BY created_at ASC
|
||||
LIMIT 1
|
||||
""",
|
||||
(user_id, app_id, project_id, sha256),
|
||||
).fetchone()
|
||||
return _row_to_dict(row)
|
||||
|
||||
def list_resources(self, user_id: str) -> list[dict[str, Any]]:
|
||||
with connect(self.db_path) as conn:
|
||||
rows = conn.execute(
|
||||
|
||||
164
core/service.py
164
core/service.py
@ -1,11 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import mimetypes
|
||||
import secrets
|
||||
import shutil
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
from fastapi import UploadFile
|
||||
|
||||
@ -46,19 +49,71 @@ def _safe_filename(filename: str | None) -> str:
|
||||
return name or "upload.bin"
|
||||
|
||||
|
||||
def _copy_upload(file: UploadFile, destination: Path) -> tuple[str, int]:
|
||||
class UploadTooLarge(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class UnsupportedContentType(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
def _copy_upload(
|
||||
file: UploadFile,
|
||||
destination: Path,
|
||||
max_upload_bytes: int,
|
||||
cleanup_root: Path,
|
||||
) -> tuple[str, int]:
|
||||
sha256 = hashlib.sha256()
|
||||
size = 0
|
||||
destination.parent.mkdir(parents=True, exist_ok=True)
|
||||
with destination.open("wb") as out:
|
||||
while True:
|
||||
chunk = file.file.read(1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
size += len(chunk)
|
||||
sha256.update(chunk)
|
||||
out.write(chunk)
|
||||
return sha256.hexdigest(), size
|
||||
try:
|
||||
with destination.open("wb") as out:
|
||||
while True:
|
||||
chunk = file.file.read(1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
size += len(chunk)
|
||||
if size > max_upload_bytes:
|
||||
raise UploadTooLarge(
|
||||
f"upload exceeds max size of {max_upload_bytes} bytes"
|
||||
)
|
||||
sha256.update(chunk)
|
||||
out.write(chunk)
|
||||
return sha256.hexdigest(), size
|
||||
except Exception:
|
||||
destination.unlink(missing_ok=True)
|
||||
_remove_empty_parents(destination.parent, stop_at=cleanup_root)
|
||||
raise
|
||||
|
||||
|
||||
def _mime_allowed(mime_type: str | None, allowed_mime_types: tuple[str, ...]) -> bool:
|
||||
mime = (mime_type or "").lower()
|
||||
if not mime:
|
||||
return False
|
||||
for allowed in allowed_mime_types:
|
||||
item = allowed.lower()
|
||||
if item.endswith("/*") and mime.startswith(item[:-1]):
|
||||
return True
|
||||
if item == mime:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _remove_empty_parents(path: Path, stop_at: Path | None = None) -> None:
|
||||
current = path
|
||||
stop = stop_at.resolve() if stop_at is not None else None
|
||||
while True:
|
||||
try:
|
||||
resolved = current.resolve()
|
||||
if stop is not None and resolved == stop:
|
||||
return
|
||||
current.rmdir()
|
||||
except OSError:
|
||||
return
|
||||
parent = current.parent
|
||||
if parent == current:
|
||||
return
|
||||
current = parent
|
||||
|
||||
|
||||
class MemoryGatewayService:
|
||||
@ -101,9 +156,26 @@ class MemoryGatewayService:
|
||||
session_id = resource_session_id(user_id, resource_id)
|
||||
original_filename = _safe_filename(file.filename)
|
||||
mime_type = file.content_type or mimetypes.guess_type(original_filename)[0]
|
||||
if not _mime_allowed(mime_type, self.config.allowed_mime_types):
|
||||
raise UnsupportedContentType(f"unsupported content type: {mime_type}")
|
||||
content_type = infer_content_type(original_filename, mime_type)
|
||||
stored_path = self.config.storage_dir / user_id / resource_id / original_filename
|
||||
sha256, size_bytes = _copy_upload(file, stored_path)
|
||||
sha256, size_bytes = _copy_upload(
|
||||
file,
|
||||
stored_path,
|
||||
self.config.max_upload_bytes,
|
||||
self.config.storage_dir,
|
||||
)
|
||||
existing = self.repository.find_active_resource_by_sha256(
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
project_id=project_id,
|
||||
sha256=sha256,
|
||||
)
|
||||
if existing is not None:
|
||||
shutil.rmtree(stored_path.parent, ignore_errors=True)
|
||||
return self._resource_summary(existing)
|
||||
|
||||
internal_uri = stored_path.resolve().as_uri()
|
||||
|
||||
resource = self.repository.create_resource(
|
||||
@ -126,16 +198,20 @@ class MemoryGatewayService:
|
||||
)
|
||||
|
||||
try:
|
||||
await self.everos_client.add_memory(
|
||||
self._build_add_payload(
|
||||
resource=resource,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
project_id=project_id,
|
||||
filename=original_filename,
|
||||
await self._retry_everos_call(
|
||||
lambda: self.everos_client.add_memory(
|
||||
self._build_add_payload(
|
||||
resource=resource,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
project_id=project_id,
|
||||
filename=original_filename,
|
||||
)
|
||||
)
|
||||
)
|
||||
await self.everos_client.flush_memory(session_id, app_id, project_id)
|
||||
await self._retry_everos_call(
|
||||
lambda: self.everos_client.flush_memory(session_id, app_id, project_id)
|
||||
)
|
||||
except Exception as exc:
|
||||
failed = self.repository.update_resource_status(
|
||||
resource_id,
|
||||
@ -147,6 +223,23 @@ class MemoryGatewayService:
|
||||
extracted = self.repository.update_resource_status(resource_id, "extracted")
|
||||
return self._resource_summary(extracted or resource)
|
||||
|
||||
async def _retry_everos_call(self, operation: Any) -> Any:
|
||||
attempts = max(1, self.config.everos_ingest_attempts)
|
||||
last_error: Exception | None = None
|
||||
for attempt in range(attempts):
|
||||
try:
|
||||
return await operation()
|
||||
except Exception as exc:
|
||||
last_error = exc
|
||||
if attempt == attempts - 1:
|
||||
break
|
||||
delay = self.config.everos_retry_delay_seconds
|
||||
if delay > 0:
|
||||
await asyncio.sleep(delay)
|
||||
if last_error is None:
|
||||
raise RuntimeError("EverOS operation failed")
|
||||
raise last_error
|
||||
|
||||
def _build_add_payload(
|
||||
self,
|
||||
*,
|
||||
@ -199,8 +292,29 @@ class MemoryGatewayService:
|
||||
if before is None:
|
||||
return None
|
||||
resource = self.repository.soft_delete_resource(resource_id, user_id)
|
||||
self._cleanup_resource_file(before)
|
||||
return self._resource_summary(resource)
|
||||
|
||||
def _cleanup_resource_file(self, resource: dict[str, Any]) -> None:
|
||||
uri = str(resource.get("uri") or "")
|
||||
if not uri.startswith("file://"):
|
||||
return
|
||||
parsed = urlparse(uri)
|
||||
if parsed.scheme != "file":
|
||||
return
|
||||
try:
|
||||
path = Path(unquote(parsed.path)).resolve()
|
||||
storage_root = self.config.storage_dir.resolve()
|
||||
except OSError:
|
||||
return
|
||||
if not path.is_relative_to(storage_root):
|
||||
return
|
||||
try:
|
||||
path.unlink(missing_ok=True)
|
||||
except OSError:
|
||||
return
|
||||
_remove_empty_parents(path.parent, stop_at=storage_root)
|
||||
|
||||
async def search_memories(
|
||||
self,
|
||||
*,
|
||||
@ -394,6 +508,18 @@ class MemoryGatewayService:
|
||||
)
|
||||
return {"memory_id": memory_id, "override_id": override["id"], "status": "active"}
|
||||
|
||||
def assert_memory_session_owned(self, user_id: str, session_id: str) -> None:
|
||||
if session_id == f"memory_edit:{user_id}":
|
||||
return
|
||||
if session_id.startswith("resource:"):
|
||||
resource = self.repository.get_resource_by_session_for_user(
|
||||
session_id,
|
||||
user_id,
|
||||
)
|
||||
if resource is not None:
|
||||
return
|
||||
raise PermissionError("memory session does not belong to user")
|
||||
|
||||
def delete_memory(
|
||||
self,
|
||||
*,
|
||||
|
||||
Reference in New Issue
Block a user