from __future__ import annotations import asyncio import base64 import binascii import hashlib import mimetypes import secrets import shutil import time import uuid from pathlib import Path from typing import Any from urllib.parse import unquote, urlparse from fastapi import UploadFile from .config import GatewayConfig from .repository import MemoryRepository def new_resource_id() -> str: return f"r_{uuid.uuid4().hex}" def resource_session_id(user_id: str, resource_id: str) -> str: return f"resource:{user_id}:{resource_id}" def public_resource_uri(user_id: str, resource_id: str) -> str: return f"resource://{user_id}/{resource_id}" def current_timestamp_ms() -> int: return int(time.time() * 1000) def infer_content_type(filename: str | None, mime_type: str | None) -> str: mime = (mime_type or mimetypes.guess_type(filename or "")[0] or "").lower() suffix = Path(filename or "").suffix.lower() if mime.startswith("image/"): return "image" if mime.startswith("audio/"): return "audio" if mime == "application/pdf" or suffix == ".pdf": return "pdf" if mime in {"text/html", "application/xhtml+xml"} or suffix in {".html", ".htm"}: return "html" if mime.startswith("text/plain") or suffix in {".txt", ".md", ".csv", ".log"}: return "text" return "doc" def _safe_filename(filename: str | None) -> str: name = Path(filename or "upload.bin").name return name or "upload.bin" class UploadTooLarge(ValueError): pass class UnsupportedContentType(ValueError): pass class InvalidAttachment(ValueError): pass def _copy_upload( file: UploadFile, destination: Path, max_upload_bytes: int, cleanup_root: Path, ) -> tuple[str, int]: sha256 = hashlib.sha256() size = 0 destination.parent.mkdir(parents=True, exist_ok=True) try: with destination.open("wb") as out: while True: chunk = file.file.read(1024 * 1024) if not chunk: break size += len(chunk) if size > max_upload_bytes: raise UploadTooLarge( f"upload exceeds max size of {max_upload_bytes} bytes" ) sha256.update(chunk) out.write(chunk) return sha256.hexdigest(), size except Exception: destination.unlink(missing_ok=True) _remove_empty_parents(destination.parent, stop_at=cleanup_root) raise def _mime_allowed(mime_type: str | None, allowed_mime_types: tuple[str, ...]) -> bool: mime = (mime_type or "").lower() if not mime: return False for allowed in allowed_mime_types: item = allowed.lower() if item.endswith("/*") and mime.startswith(item[:-1]): return True if item == mime: return True return False def _remove_empty_parents(path: Path, stop_at: Path | None = None) -> None: current = path stop = stop_at.resolve() if stop_at is not None else None while True: try: resolved = current.resolve() if stop is not None and resolved == stop: return current.rmdir() except OSError: return parent = current.parent if parent == current: return current = parent class MemoryGatewayService: def __init__( self, config: GatewayConfig, repository: MemoryRepository, backend_client: Any, ) -> None: self.config = config self.repository = repository self.backend_client = backend_client def create_user(self, user_id: str) -> dict[str, Any]: user_key = f"uk_{secrets.token_urlsafe(32)}" user = self.repository.create_user(user_id, user_key) return { "user_id": user["id"], "user_key": user["user_key"], "created_at": user["created_at"], } def authenticate_user(self, user_id: str, user_key: str) -> bool: user = self.repository.get_user(user_id) if user is None: return False return secrets.compare_digest(str(user["user_key"]), user_key) async def upload_resource( self, *, user_id: str, app_id: str, project_id: str, file: UploadFile, title: str | None, description: str | None, ) -> dict[str, Any]: resource_id = new_resource_id() session_id = resource_session_id(user_id, resource_id) original_filename = _safe_filename(file.filename) mime_type = file.content_type or mimetypes.guess_type(original_filename)[0] if not _mime_allowed(mime_type, self.config.allowed_mime_types): raise UnsupportedContentType(f"unsupported content type: {mime_type}") content_type = infer_content_type(original_filename, mime_type) stored_path = self.config.storage_dir / user_id / resource_id / original_filename sha256, size_bytes = _copy_upload( file, stored_path, self.config.max_upload_bytes, self.config.storage_dir, ) existing = self.repository.find_active_resource_by_sha256( user_id=user_id, app_id=app_id, project_id=project_id, sha256=sha256, ) if existing is not None: shutil.rmtree(stored_path.parent, ignore_errors=True) self._register_resource_attachment(existing) return self._resource_summary(existing) internal_uri = stored_path.resolve().as_uri() resource = self.repository.create_resource( id=resource_id, user_id=user_id, app_id=app_id, project_id=project_id, session_id=session_id, original_filename=original_filename, mime_type=mime_type, content_type=content_type, uri=internal_uri, uri_public=False, sha256=sha256, size_bytes=size_bytes, title=title, description=description, status="ingesting", error_message=None, ) self._register_resource_attachment(resource) try: await self._retry_backend_call( lambda: self.backend_client.add_memory( self._build_add_payload( resource=resource, user_id=user_id, app_id=app_id, project_id=project_id, filename=original_filename, ) ) ) await self._retry_backend_call( lambda: self.backend_client.flush_memory(session_id, app_id, project_id) ) except Exception as exc: failed = self.repository.update_resource_status( resource_id, "failed", str(exc), ) return self._resource_summary(failed or resource) extracted = self.repository.update_resource_status(resource_id, "extracted") return self._resource_summary(extracted or resource) async def _retry_backend_call(self, operation: Any) -> Any: attempts = max(1, self.config.backend_ingest_attempts) last_error: Exception | None = None for attempt in range(attempts): try: return await operation() except Exception as exc: last_error = exc if attempt == attempts - 1: break delay = self.config.backend_retry_delay_seconds if delay > 0: await asyncio.sleep(delay) if last_error is None: raise RuntimeError("upstream memory service operation failed") raise last_error def _build_add_payload( self, *, resource: dict[str, Any], user_id: str, app_id: str, project_id: str, filename: str, ) -> dict[str, Any]: content_item = self._build_content_item(resource=resource, filename=filename) return { "session_id": resource["session_id"], "app_id": app_id, "project_id": project_id, "messages": [ { "sender_id": user_id, "role": "user", "timestamp": current_timestamp_ms(), "content": [content_item], } ], } def _build_content_item( self, *, resource: dict[str, Any], filename: str, ) -> dict[str, Any]: content_type = str(resource["content_type"]) path = self._resource_file_path(resource) content = path.read_bytes() item = { "type": content_type, "name": filename, "ext": Path(filename).suffix.lstrip(".") or None, "extras": { "resource_id": resource["id"], "source": "user_upload", }, } if content_type == "text": item["text"] = content.decode("utf-8", errors="replace") else: item["base64"] = base64.b64encode(content).decode("ascii") return item def _resource_file_path(self, resource: dict[str, Any]) -> Path: uri = str(resource["uri"]) parsed = urlparse(uri) if parsed.scheme != "file": raise ValueError(f"unsupported resource uri scheme: {parsed.scheme}") return Path(unquote(parsed.path)).resolve(strict=True) def list_resources(self, user_id: str) -> list[dict[str, Any]]: return [self._resource_detail(item) for item in self.repository.list_resources(user_id)] def get_resource_detail( self, resource_id: str, user_id: str, ) -> dict[str, Any] | None: resource = self.repository.get_resource_for_user(resource_id, user_id) if resource is None: return None return self._resource_detail(resource) def delete_resource(self, resource_id: str, user_id: str) -> dict[str, Any] | None: before = self.repository.get_resource_for_user(resource_id, user_id) if before is None: return None resource = self.repository.soft_delete_resource(resource_id, user_id) self._cleanup_resource_file(before) return self._resource_summary(resource) def _cleanup_resource_file(self, resource: dict[str, Any]) -> None: uri = str(resource.get("uri") or "") if not uri.startswith("file://"): return parsed = urlparse(uri) if parsed.scheme != "file": return try: path = Path(unquote(parsed.path)).resolve() storage_root = self.config.storage_dir.resolve() except OSError: return if not path.is_relative_to(storage_root): return try: path.unlink(missing_ok=True) except OSError: return _remove_empty_parents(path.parent, stop_at=storage_root) async def search_memories( self, *, user_id: str, agent_id: str | None, query: str, conversation_id: str | None, scope: list[str], method: str, top_k: int, radius: float | None, include_profile: bool, enable_llm_rerank: bool, filters: dict[str, Any] | None, app_id: str, project_id: str, ) -> dict[str, Any]: results: list[dict[str, Any]] = [] session_resource_map: dict[str, dict[str, Any]] = {} if "current_chat" in scope and conversation_id: payload = self._search_payload( user_id=user_id, agent_id=agent_id, query=query, method=method, top_k=top_k, radius=radius, include_profile=include_profile, enable_llm_rerank=enable_llm_rerank, app_id=app_id, project_id=project_id, filters=_combine_filters( filters, {"session_id": f"chat:{conversation_id}"}, ), ) results.extend( self._extract_results( await self.backend_client.search_memory(payload), source_scope="current_chat", session_resource_map=session_resource_map, user_id=user_id, ) ) if "resources" in scope: resources = self.repository.list_extracted_resources( user_id, app_id, project_id, ) session_resource_map.update({item["session_id"]: item for item in resources}) session_ids = [item["session_id"] for item in resources] for batch in _chunks(session_ids, self.config.resource_search_batch_size): payload = self._search_payload( user_id=user_id, agent_id=agent_id, query=query, method=method, top_k=top_k, radius=radius, include_profile=include_profile, enable_llm_rerank=enable_llm_rerank, app_id=app_id, project_id=project_id, filters=_combine_filters( filters, {"session_id": {"in": batch}}, ), ) results.extend( self._extract_results( await self.backend_client.search_memory(payload), source_scope="resources", session_resource_map=session_resource_map, user_id=user_id, ) ) if "all_user_memory" in scope: payload = self._search_payload( user_id=user_id, agent_id=agent_id, query=query, method=method, top_k=top_k, radius=radius, include_profile=include_profile, enable_llm_rerank=enable_llm_rerank, app_id=app_id, project_id=project_id, filters=filters, ) results.extend( self._extract_results( await self.backend_client.search_memory(payload), source_scope="all_user_memory", session_resource_map=session_resource_map, user_id=user_id, ) ) filtered = self._apply_tombstones(user_id, results) overridden = self._apply_overrides(user_id, filtered) return {"results": overridden} async def add_memory( self, *, user_id: str, session_id: str, app_id: str, project_id: str, messages: list[dict[str, Any]], ) -> dict[str, Any]: attachments, generated_paths = self._prepare_memory_attachments( user_id=user_id, session_id=session_id, app_id=app_id, project_id=project_id, messages=messages, ) payload = { "session_id": session_id, "app_id": app_id, "project_id": project_id, "messages": messages, } try: backend = await self.backend_client.add_memory(payload) for attachment in attachments: self.repository.create_attachment(**attachment) except Exception: for path in generated_paths: path.unlink(missing_ok=True) _remove_empty_parents(path.parent, stop_at=self.config.storage_dir) raise return {"session_id": session_id, "backend": backend} def _register_resource_attachment(self, resource: dict[str, Any]) -> None: self.repository.create_attachment( user_id=resource["user_id"], app_id=resource["app_id"], project_id=resource["project_id"], session_id=resource["session_id"], resource_id=resource["id"], content_type=resource["content_type"], name=resource["original_filename"] or resource["id"], internal_uri=resource["uri"], source="resource_upload", sha256=resource["sha256"], ) def _prepare_memory_attachments( self, *, user_id: str, session_id: str, app_id: str, project_id: str, messages: list[dict[str, Any]], ) -> tuple[list[dict[str, Any]], list[Path]]: attachments: list[dict[str, Any]] = [] generated_paths: list[Path] = [] try: for message in messages: content = message.get("content") if not isinstance(content, list): continue for item in content: if not isinstance(item, dict): continue uri = item.get("uri") encoded = item.get("base64") if not uri and not encoded: continue attachment_id = f"a_{uuid.uuid4().hex}" name = _attachment_name(item, str(uri) if uri else None) sha256: str | None = None if uri: internal_uri = str(uri) source = "memory_add_uri" else: try: data = base64.b64decode(str(encoded), validate=True) except (binascii.Error, ValueError) as exc: raise InvalidAttachment( f"invalid base64 attachment: {name}" ) from exc if len(data) > self.config.max_upload_bytes: raise UploadTooLarge( f"attachment exceeds max size of " f"{self.config.max_upload_bytes} bytes" ) sha256 = hashlib.sha256(data).hexdigest() path = ( self.config.storage_dir / user_id / "memory_attachments" / sha256 / name ) if not path.exists(): path.parent.mkdir(parents=True, exist_ok=True) path.write_bytes(data) generated_paths.append(path) internal_uri = path.resolve().as_uri() source = "memory_add_base64" attachments.append( { "id": attachment_id, "user_id": user_id, "app_id": app_id, "project_id": project_id, "session_id": session_id, "resource_id": None, "content_type": str(item.get("type") or "doc"), "name": name, "internal_uri": internal_uri, "source": source, "sha256": sha256, } ) except Exception: for path in generated_paths: path.unlink(missing_ok=True) _remove_empty_parents(path.parent, stop_at=self.config.storage_dir) raise return attachments, generated_paths async def flush_memory( self, *, session_id: str, app_id: str, project_id: str, ) -> dict[str, Any]: return { "session_id": session_id, "backend": await self.backend_client.flush_memory( session_id, app_id, project_id, ), } def _search_payload( self, *, user_id: str, agent_id: str | None, query: str, method: str, top_k: int, radius: float | None, include_profile: bool, enable_llm_rerank: bool, app_id: str, project_id: str, filters: dict[str, Any] | None, ) -> dict[str, Any]: payload: dict[str, Any] = { "query": query, "method": method, "top_k": top_k, "include_profile": include_profile, "enable_llm_rerank": enable_llm_rerank, "app_id": app_id, "project_id": project_id, } payload["agent_id" if agent_id else "user_id"] = agent_id or user_id if radius is not None: payload["radius"] = radius if filters is not None: payload["filters"] = filters return payload def _extract_results( self, response: dict[str, Any], *, source_scope: str, session_resource_map: dict[str, dict[str, Any]], user_id: str, ) -> list[dict[str, Any]]: data = response.get("data", {}) raw_items: list[tuple[str, dict[str, Any]]] = [] memory_types = { "episodes": "episode", "profiles": "profile", "agent_cases": "agent_case", "agent_skills": "agent_skill", "unprocessed_messages": "unprocessed_message", } for key, memory_type in memory_types.items(): raw_items.extend( (memory_type, item) for item in (data.get(key, []) or []) ) normalized = [] attachment_cache: dict[str, list[dict[str, Any]]] = {} for memory_type, raw in raw_items: session_id = raw.get("session_id") resource = session_resource_map.get(session_id) if resource is None and isinstance(session_id, str): resource = self.repository.get_resource_by_session_for_user( session_id, user_id, ) attachments: list[dict[str, Any]] = [] if isinstance(session_id, str): if session_id not in attachment_cache: attachment_cache[session_id] = ( self.repository.list_attachments_for_session( user_id, session_id, ) ) session_attachments = attachment_cache[session_id] attachments = _matching_attachments(raw, session_attachments) normalized.append( { "id": raw.get("id"), "memory_type": memory_type, "session_id": session_id, "text": _display_text(raw), "score": raw.get("score"), "source_scope": source_scope, "resource_id": resource["id"] if resource else None, "resource_uri": ( public_resource_uri(user_id, resource["id"]) if resource else None ), "attachments": attachments, "raw": raw, } ) return normalized def _apply_tombstones( self, user_id: str, results: list[dict[str, Any]], ) -> list[dict[str, Any]]: tombstones = self.repository.get_tombstones(user_id) memory_ids = {item["memory_id"] for item in tombstones if item["memory_id"]} session_ids = {item["session_id"] for item in tombstones if item["session_id"]} return [ item for item in results if item.get("id") not in memory_ids and item.get("session_id") not in session_ids ] def _apply_overrides( self, user_id: str, results: list[dict[str, Any]], ) -> list[dict[str, Any]]: overrides = { item["memory_id"]: item for item in self.repository.get_active_overrides(user_id) if item["memory_id"] } for result in results: override = overrides.get(result.get("id")) if override: result["text"] = override["override_text"] result["override_id"] = override["id"] return results def upsert_override( self, *, user_id: str, memory_id: str, session_id: str | None, override_text: str, ) -> dict[str, Any]: override = self.repository.upsert_override( user_id, memory_id, session_id, override_text, ) return {"memory_id": memory_id, "override_id": override["id"], "status": "active"} def assert_memory_session_owned(self, user_id: str, session_id: str) -> None: if session_id == f"memory_edit:{user_id}": return if session_id.startswith("resource:"): resource = self.repository.get_resource_by_session_for_user( session_id, user_id, ) if resource is not None: return raise PermissionError("memory session does not belong to user") def delete_memory( self, *, user_id: str, memory_id: str, session_id: str | None, reason: str | None, ) -> dict[str, Any]: tombstone = self.repository.add_tombstone( user_id, memory_id, session_id, reason, ) return {"memory_id": memory_id, "tombstone_id": tombstone["id"], "status": "deleted"} def _resource_summary(self, resource: dict[str, Any]) -> dict[str, Any]: return { "resource_id": resource["id"], "session_id": resource["session_id"], "uri": public_resource_uri(resource["user_id"], resource["id"]), "status": resource["status"], } def _resource_detail(self, resource: dict[str, Any]) -> dict[str, Any]: return { "resource_id": resource["id"], "user_id": resource["user_id"], "filename": resource["original_filename"], "content_type": resource["content_type"], "mime_type": resource["mime_type"], "uri": public_resource_uri(resource["user_id"], resource["id"]), "session_id": resource["session_id"], "status": resource["status"], "title": resource["title"], "description": resource["description"], "created_at": resource["created_at"], "updated_at": resource["updated_at"], } def _combine_filters( custom_filters: dict[str, Any] | None, scope_filters: dict[str, Any] | None, ) -> dict[str, Any] | None: if custom_filters is None: return scope_filters if scope_filters is None: return custom_filters return {"AND": [custom_filters, scope_filters]} def _attachment_name(item: dict[str, Any], uri: str | None) -> str: if item.get("name"): return _safe_filename(str(item["name"])) if uri: parsed = urlparse(uri) uri_name = Path(unquote(parsed.path)).name if uri_name: return _safe_filename(uri_name) extension = str(item.get("ext") or "bin").lstrip(".") or "bin" return f"attachment.{extension}" def _matching_attachments( raw: dict[str, Any], attachments: list[dict[str, Any]], ) -> list[dict[str, Any]]: strings = [value.casefold() for value in _raw_string_values(raw)] matched: list[dict[str, Any]] = [] seen_uris: set[str] = set() for attachment in attachments: name = str(attachment["name"]) internal_uri = str(attachment["internal_uri"]) if internal_uri in seen_uris: continue if not any(name.casefold() in value for value in strings): continue seen_uris.add(internal_uri) matched.append( { "type": attachment["content_type"], "name": name, "internal_uri": internal_uri, } ) return matched def _raw_string_values(value: Any, key: str | None = None) -> list[str]: if key is not None and key.casefold() == "base64": return [] if isinstance(value, str): return [value] if isinstance(value, dict): strings: list[str] = [] for item_key, item_value in value.items(): strings.extend(_raw_string_values(item_value, str(item_key))) return strings if isinstance(value, list): strings = [] for item in value: strings.extend(_raw_string_values(item)) return strings return [] def _chunks(items: list[str], size: int) -> list[list[str]]: if not items: return [] return [items[index : index + size] for index in range(0, len(items), size)] def _display_text(raw: dict[str, Any]) -> str: for key in ( "episode", "summary", "content", "profile_data", "task_intent", "approach", "key_insight", "name", "description", ): value = raw.get(key) if value is None: continue if isinstance(value, str): return value return str(value) return ""