Files
memory-gateway/core/service.py

883 lines
29 KiB
Python

from __future__ import annotations
import asyncio
import base64
import binascii
import hashlib
import mimetypes
import secrets
import shutil
import time
import uuid
from pathlib import Path
from typing import Any
from urllib.parse import unquote, urlparse
from fastapi import UploadFile
from .config import GatewayConfig
from .repository import MemoryRepository
def new_resource_id() -> str:
    """Return a fresh random resource id: ``"r_"`` followed by 32 hex digits."""
    random_hex = uuid.uuid4().hex
    return "r_" + random_hex
def resource_session_id(user_id: str, resource_id: str) -> str:
    """Return the backend session id used for one uploaded resource.

    Format: ``resource:<user_id>:<resource_id>``.
    """
    prefix = "resource"
    return f"{prefix}:{user_id}:{resource_id}"
def public_resource_uri(user_id: str, resource_id: str) -> str:
    """Return the client-facing URI of a resource: ``resource://<user_id>/<resource_id>``."""
    scheme = "resource"
    return f"{scheme}://{user_id}/{resource_id}"
def current_timestamp_ms() -> int:
    """Return the current Unix time in whole milliseconds (fraction truncated)."""
    seconds = time.time()
    return int(seconds * 1000)
def infer_content_type(filename: str | None, mime_type: str | None) -> str:
    """Classify an upload as ``image``, ``audio``, ``pdf``, ``html``, ``text`` or ``doc``.

    The MIME type is taken from ``mime_type`` or, when that is empty, guessed
    from ``filename``. MIME parameters (e.g. ``; charset=utf-8``) and
    surrounding whitespace are ignored. For ``pdf``, ``html`` and ``text`` the
    filename suffix is accepted as an alternative signal. Anything
    unrecognised is ``doc``.

    Args:
        filename: Client filename, possibly ``None``.
        mime_type: Declared MIME type, possibly ``None``.

    Returns:
        The content-type label.
    """
    raw_mime = mime_type or mimetypes.guess_type(filename or "")[0] or ""
    # Strip parameters so "text/html; charset=utf-8" compares equal to "text/html".
    mime = raw_mime.split(";", 1)[0].strip().lower()
    suffix = Path(filename or "").suffix.lower()
    if mime.startswith("image/"):
        return "image"
    if mime.startswith("audio/"):
        return "audio"
    if mime == "application/pdf" or suffix == ".pdf":
        return "pdf"
    if mime in {"text/html", "application/xhtml+xml"} or suffix in {".html", ".htm"}:
        return "html"
    if mime == "text/plain" or suffix in {".txt", ".md", ".csv", ".log"}:
        return "text"
    return "doc"
def _safe_filename(filename: str | None) -> str:
name = Path(filename or "upload.bin").name
return name or "upload.bin"
class UploadTooLarge(ValueError):
    """Raised when an upload or inline attachment exceeds the configured size cap."""
class UnsupportedContentType(ValueError):
    """Raised when an upload's MIME type is not in the configured allow-list."""
class InvalidAttachment(ValueError):
    """Raised when an inline attachment in a memory message cannot be accepted."""
def _copy_upload(
file: UploadFile,
destination: Path,
max_upload_bytes: int,
cleanup_root: Path,
) -> tuple[str, int]:
sha256 = hashlib.sha256()
size = 0
destination.parent.mkdir(parents=True, exist_ok=True)
try:
with destination.open("wb") as out:
while True:
chunk = file.file.read(1024 * 1024)
if not chunk:
break
size += len(chunk)
if size > max_upload_bytes:
raise UploadTooLarge(
f"upload exceeds max size of {max_upload_bytes} bytes"
)
sha256.update(chunk)
out.write(chunk)
return sha256.hexdigest(), size
except Exception:
destination.unlink(missing_ok=True)
_remove_empty_parents(destination.parent, stop_at=cleanup_root)
raise
def _mime_allowed(mime_type: str | None, allowed_mime_types: tuple[str, ...]) -> bool:
mime = (mime_type or "").lower()
if not mime:
return False
for allowed in allowed_mime_types:
item = allowed.lower()
if item.endswith("/*") and mime.startswith(item[:-1]):
return True
if item == mime:
return True
return False
def _remove_empty_parents(path: Path, stop_at: Path | None = None) -> None:
current = path
stop = stop_at.resolve() if stop_at is not None else None
while True:
try:
resolved = current.resolve()
if stop is not None and resolved == stop:
return
current.rmdir()
except OSError:
return
parent = current.parent
if parent == current:
return
current = parent
class MemoryGatewayService:
    """Gateway service for users, uploaded resources and memory operations.

    Gateway-side state (users, resources, attachments, tombstones,
    overrides) is read and written through ``repository``. Memory add,
    flush and search requests are forwarded to ``backend_client``, whose
    ``add_memory``, ``flush_memory`` and ``search_memory`` methods are
    awaited here.
    """

    def __init__(
        self,
        config: GatewayConfig,
        repository: MemoryRepository,
        backend_client: Any,
    ) -> None:
        self.config = config
        self.repository = repository
        self.backend_client = backend_client

    def create_user(self, user_id: str) -> dict[str, Any]:
        """Create ``user_id`` with a freshly generated ``uk_``-prefixed secret key.

        Returns the stored ``user_id``, ``user_key`` and ``created_at``.
        """
        user_key = f"uk_{secrets.token_urlsafe(32)}"
        user = self.repository.create_user(user_id, user_key)
        return {
            "user_id": user["id"],
            "user_key": user["user_key"],
            "created_at": user["created_at"],
        }

    def authenticate_user(self, user_id: str, user_key: str) -> bool:
        """Return whether ``user_key`` matches the stored key of ``user_id``.

        Unknown users yield ``False``. Keys are compared as UTF-8 bytes with
        ``secrets.compare_digest``. Comparing ``str`` values directly raises
        ``TypeError`` for non-ASCII input, which would turn a wrong key into
        a server error instead of a rejection.
        """
        user = self.repository.get_user(user_id)
        if user is None:
            return False
        expected = str(user["user_key"]).encode("utf-8")
        return secrets.compare_digest(expected, user_key.encode("utf-8"))

    async def upload_resource(
        self,
        *,
        user_id: str,
        app_id: str,
        project_id: str,
        file: UploadFile,
        title: str | None,
        description: str | None,
    ) -> dict[str, Any]:
        """Store an uploaded file as a resource and ingest it into memory.

        Steps:

        1. Validate the MIME type against ``allowed_mime_types``.
        2. Copy the upload to ``storage_dir/<user_id>/<resource_id>/<filename>``.
        3. If an active resource with the same SHA-256 exists, discard the
           new copy and return that resource.
        4. Otherwise create a resource row in status ``"ingesting"``.
        5. Send the content to the backend (add, then flush, each retried).

        The resource ends as ``"extracted"`` on success, or as ``"failed"``
        (with the error text) when the backend calls fail.

        Returns:
            The resource summary (id, session id, public URI, status).

        Raises:
            UnsupportedContentType: If the MIME type is not allowed.
            UploadTooLarge: If the upload exceeds ``max_upload_bytes``.
        """
        resource_id = new_resource_id()
        session_id = resource_session_id(user_id, resource_id)
        original_filename = _safe_filename(file.filename)
        mime_type = file.content_type or mimetypes.guess_type(original_filename)[0]
        if not _mime_allowed(mime_type, self.config.allowed_mime_types):
            raise UnsupportedContentType(f"unsupported content type: {mime_type}")
        content_type = infer_content_type(original_filename, mime_type)
        # NOTE(review): user_id is used as a path component here; presumably it
        # is validated upstream (no "/" or ".."). Confirm.
        stored_path = self.config.storage_dir / user_id / resource_id / original_filename
        # NOTE(review): _copy_upload does blocking file I/O inside this
        # coroutine; consider offloading it for large uploads.
        sha256, size_bytes = _copy_upload(
            file,
            stored_path,
            self.config.max_upload_bytes,
            self.config.storage_dir,
        )
        existing = self.repository.find_active_resource_by_sha256(
            user_id=user_id,
            app_id=app_id,
            project_id=project_id,
            sha256=sha256,
        )
        if existing is not None:
            # Same content already stored: drop the new copy, reuse the old resource.
            shutil.rmtree(stored_path.parent, ignore_errors=True)
            self._register_resource_attachment(existing)
            return self._resource_summary(existing)
        internal_uri = stored_path.resolve().as_uri()
        try:
            resource = self.repository.create_resource(
                id=resource_id,
                user_id=user_id,
                app_id=app_id,
                project_id=project_id,
                session_id=session_id,
                original_filename=original_filename,
                mime_type=mime_type,
                content_type=content_type,
                uri=internal_uri,
                uri_public=False,
                sha256=sha256,
                size_bytes=size_bytes,
                title=title,
                description=description,
                status="ingesting",
                error_message=None,
            )
        except Exception:
            # Without a resource row nothing references the stored file;
            # remove it rather than leaving an orphan on disk.
            shutil.rmtree(stored_path.parent, ignore_errors=True)
            raise
        self._register_resource_attachment(resource)
        try:
            await self._retry_backend_call(
                lambda: self.backend_client.add_memory(
                    self._build_add_payload(
                        resource=resource,
                        user_id=user_id,
                        app_id=app_id,
                        project_id=project_id,
                        filename=original_filename,
                    )
                )
            )
            await self._retry_backend_call(
                lambda: self.backend_client.flush_memory(session_id, app_id, project_id)
            )
        except Exception as exc:
            failed = self.repository.update_resource_status(
                resource_id,
                "failed",
                str(exc),
            )
            return self._resource_summary(failed or resource)
        extracted = self.repository.update_resource_status(resource_id, "extracted")
        return self._resource_summary(extracted or resource)

    async def _retry_backend_call(self, operation: Any) -> Any:
        """Await ``operation()`` up to ``backend_ingest_attempts`` times (at least once).

        ``operation`` is a zero-argument callable returning an awaitable; it
        is invoked afresh on every attempt. Between failed attempts this
        sleeps ``backend_retry_delay_seconds`` when that is positive. The
        last exception is re-raised once every attempt has failed.
        """
        attempts = max(1, self.config.backend_ingest_attempts)
        last_error: Exception | None = None
        for attempt in range(attempts):
            try:
                return await operation()
            except Exception as exc:
                last_error = exc
                if attempt == attempts - 1:
                    break
                delay = self.config.backend_retry_delay_seconds
                if delay > 0:
                    await asyncio.sleep(delay)
        if last_error is None:
            raise RuntimeError("upstream memory service operation failed")
        raise last_error

    def _build_add_payload(
        self,
        *,
        resource: dict[str, Any],
        user_id: str,
        app_id: str,
        project_id: str,
        filename: str,
    ) -> dict[str, Any]:
        """Build the backend ``add_memory`` payload for ``resource``.

        The payload holds a single user message, timestamped now, whose
        content is the one item produced by ``_build_content_item``.
        """
        content_item = self._build_content_item(resource=resource, filename=filename)
        return {
            "session_id": resource["session_id"],
            "app_id": app_id,
            "project_id": project_id,
            "messages": [
                {
                    "sender_id": user_id,
                    "role": "user",
                    "timestamp": current_timestamp_ms(),
                    "content": [content_item],
                }
            ],
        }

    def _build_content_item(
        self,
        *,
        resource: dict[str, Any],
        filename: str,
    ) -> dict[str, Any]:
        """Read a resource's stored file and wrap it as a message content item.

        ``text`` resources are decoded as UTF-8 (invalid bytes replaced) into
        ``"text"``. Every other content type is base64-encoded into
        ``"base64"``.
        """
        content_type = str(resource["content_type"])
        path = self._resource_file_path(resource)
        content = path.read_bytes()
        item: dict[str, Any] = {
            "type": content_type,
            "name": filename,
            "ext": Path(filename).suffix.lstrip(".") or None,
            "extras": {
                "resource_id": resource["id"],
                "source": "user_upload",
            },
        }
        if content_type == "text":
            item["text"] = content.decode("utf-8", errors="replace")
        else:
            item["base64"] = base64.b64encode(content).decode("ascii")
        return item

    def _resource_file_path(self, resource: dict[str, Any]) -> Path:
        """Return the resolved local path behind a resource's ``file://`` URI.

        Raises:
            ValueError: If the URI scheme is not ``file``.
            FileNotFoundError: If the file does not exist (strict resolve).
        """
        uri = str(resource["uri"])
        parsed = urlparse(uri)
        if parsed.scheme != "file":
            raise ValueError(f"unsupported resource uri scheme: {parsed.scheme}")
        return Path(unquote(parsed.path)).resolve(strict=True)

    def list_resources(self, user_id: str) -> list[dict[str, Any]]:
        """Return detail views of every resource the repository lists for ``user_id``."""
        return [self._resource_detail(item) for item in self.repository.list_resources(user_id)]

    def get_resource_detail(
        self,
        resource_id: str,
        user_id: str,
    ) -> dict[str, Any] | None:
        """Return the detail view of one of the user's resources, or ``None`` if not found."""
        resource = self.repository.get_resource_for_user(resource_id, user_id)
        if resource is None:
            return None
        return self._resource_detail(resource)

    def delete_resource(self, resource_id: str, user_id: str) -> dict[str, Any] | None:
        """Soft-delete one of the user's resources and remove its stored file.

        Returns:
            The summary of the deleted resource. Returns ``None`` when the
            resource is not found for ``user_id``, or when the soft delete
            returns nothing.

        The file is only removed after the soft delete succeeded.
        """
        before = self.repository.get_resource_for_user(resource_id, user_id)
        if before is None:
            return None
        resource = self.repository.soft_delete_resource(resource_id, user_id)
        if resource is None:
            # NOTE(review): presumably the row was deleted concurrently; leave
            # file cleanup to that deletion rather than crashing on None.
            return None
        self._cleanup_resource_file(before)
        return self._resource_summary(resource)

    def _cleanup_resource_file(self, resource: dict[str, Any]) -> None:
        """Best-effort deletion of a resource's stored file.

        Only ``file://`` URIs that resolve inside ``storage_dir`` are
        touched. Emptied parent directories are pruned up to (not including)
        ``storage_dir``. Filesystem errors are ignored.
        """
        uri = str(resource.get("uri") or "")
        if not uri.startswith("file://"):
            return
        parsed = urlparse(uri)
        if parsed.scheme != "file":
            return
        try:
            path = Path(unquote(parsed.path)).resolve()
            storage_root = self.config.storage_dir.resolve()
        except OSError:
            return
        if not path.is_relative_to(storage_root):
            return
        try:
            path.unlink(missing_ok=True)
        except OSError:
            return
        _remove_empty_parents(path.parent, stop_at=storage_root)

    async def search_memories(
        self,
        *,
        user_id: str,
        agent_id: str | None,
        query: str,
        conversation_id: str | None,
        scope: list[str],
        method: str,
        top_k: int,
        radius: float | None,
        include_profile: bool,
        enable_llm_rerank: bool,
        filters: dict[str, Any] | None,
        app_id: str,
        project_id: str,
    ) -> dict[str, Any]:
        """Search backend memory across the requested scopes and merge the results.

        ``scope`` may contain:

        - ``"current_chat"``: searched only when ``conversation_id`` is set,
          restricted to session ``chat:<conversation_id>``;
        - ``"resources"``: the session ids of the user's extracted resources
          for this app/project, searched in batches of
          ``resource_search_batch_size``;
        - ``"all_user_memory"``: searched with only ``filters``.

        Results are concatenated in that order. Tombstoned results are then
        dropped and active overrides applied.

        Returns:
            ``{"results": [...]}``.
        """
        search_args: dict[str, Any] = {
            "user_id": user_id,
            "agent_id": agent_id,
            "query": query,
            "method": method,
            "top_k": top_k,
            "radius": radius,
            "include_profile": include_profile,
            "enable_llm_rerank": enable_llm_rerank,
            "app_id": app_id,
            "project_id": project_id,
        }
        results: list[dict[str, Any]] = []
        session_resource_map: dict[str, dict[str, Any]] = {}

        async def run_search(source_scope: str, scoped_filters: dict[str, Any] | None) -> None:
            # One backend search call; normalized results are appended in call order.
            payload = self._search_payload(**search_args, filters=scoped_filters)
            response = await self.backend_client.search_memory(payload)
            results.extend(
                self._extract_results(
                    response,
                    source_scope=source_scope,
                    session_resource_map=session_resource_map,
                    user_id=user_id,
                )
            )

        if "current_chat" in scope and conversation_id:
            await run_search(
                "current_chat",
                _combine_filters(filters, {"session_id": f"chat:{conversation_id}"}),
            )
        if "resources" in scope:
            resources = self.repository.list_extracted_resources(
                user_id,
                app_id,
                project_id,
            )
            session_resource_map.update({item["session_id"]: item for item in resources})
            session_ids = [item["session_id"] for item in resources]
            for batch in _chunks(session_ids, self.config.resource_search_batch_size):
                await run_search(
                    "resources",
                    _combine_filters(filters, {"session_id": {"in": batch}}),
                )
        if "all_user_memory" in scope:
            await run_search("all_user_memory", filters)
        filtered = self._apply_tombstones(user_id, results)
        overridden = self._apply_overrides(user_id, filtered)
        return {"results": overridden}

    async def add_memory(
        self,
        *,
        user_id: str,
        session_id: str,
        app_id: str,
        project_id: str,
        messages: list[dict[str, Any]],
    ) -> dict[str, Any]:
        """Forward ``messages`` to the backend and record their attachments.

        Inline base64 attachments are written to storage first. Attachment
        rows are only created after the backend call succeeded. If the
        backend call or row creation fails, files written by this call are
        removed (best effort) and the error is re-raised.

        Returns:
            ``{"session_id": ..., "backend": <backend response>}``.
        """
        attachments, generated_paths = self._prepare_memory_attachments(
            user_id=user_id,
            session_id=session_id,
            app_id=app_id,
            project_id=project_id,
            messages=messages,
        )
        payload = {
            "session_id": session_id,
            "app_id": app_id,
            "project_id": project_id,
            "messages": messages,
        }
        try:
            backend = await self.backend_client.add_memory(payload)
            for attachment in attachments:
                self.repository.create_attachment(**attachment)
        except Exception:
            self._discard_generated_files(generated_paths)
            raise
        return {"session_id": session_id, "backend": backend}

    def _discard_generated_files(self, paths: list[Path]) -> None:
        """Best-effort removal of files written during a failed request."""
        for path in paths:
            try:
                path.unlink(missing_ok=True)
            except OSError:
                # Keep going: cleanup must not mask the error being re-raised.
                continue
            _remove_empty_parents(path.parent, stop_at=self.config.storage_dir)

    def _register_resource_attachment(self, resource: dict[str, Any]) -> None:
        """Record a resource's stored file as an attachment of its session."""
        self.repository.create_attachment(
            user_id=resource["user_id"],
            app_id=resource["app_id"],
            project_id=resource["project_id"],
            session_id=resource["session_id"],
            resource_id=resource["id"],
            content_type=resource["content_type"],
            name=resource["original_filename"] or resource["id"],
            internal_uri=resource["uri"],
            source="resource_upload",
            sha256=resource["sha256"],
        )

    def _prepare_memory_attachments(
        self,
        *,
        user_id: str,
        session_id: str,
        app_id: str,
        project_id: str,
        messages: list[dict[str, Any]],
    ) -> tuple[list[dict[str, Any]], list[Path]]:
        """Build attachment records for message content items carrying a file.

        Content items with a ``uri`` are recorded with that URI as given.
        Items with ``base64`` data are decoded (strictly), size-checked, and
        written once to ``storage_dir/<user_id>/memory_attachments/<sha256>/<name>``.
        Existing files are reused.

        Returns:
            ``(attachment_records, newly_written_paths)``.

        Raises:
            InvalidAttachment: On invalid base64 or an unsafe attachment name.
            UploadTooLarge: If decoded data exceeds ``max_upload_bytes``.

        Files written by this call are removed when any item fails.
        """
        attachments: list[dict[str, Any]] = []
        generated_paths: list[Path] = []
        try:
            for message in messages:
                content = message.get("content")
                if not isinstance(content, list):
                    continue
                for item in content:
                    if not isinstance(item, dict):
                        continue
                    uri = item.get("uri")
                    encoded = item.get("base64")
                    if not uri and not encoded:
                        continue
                    attachment_id = f"a_{uuid.uuid4().hex}"
                    name = _attachment_name(item, str(uri) if uri else None)
                    sha256: str | None = None
                    if uri:
                        # NOTE(review): client-supplied URIs are stored verbatim as
                        # internal URIs; confirm nothing dereferences them unchecked.
                        internal_uri = str(uri)
                        source = "memory_add_uri"
                    else:
                        try:
                            data = base64.b64decode(str(encoded), validate=True)
                        except (binascii.Error, ValueError) as exc:
                            raise InvalidAttachment(
                                f"invalid base64 attachment: {name}"
                            ) from exc
                        if len(data) > self.config.max_upload_bytes:
                            raise UploadTooLarge(
                                f"attachment exceeds max size of "
                                f"{self.config.max_upload_bytes} bytes"
                            )
                        # The name derives from client-controlled fields; it must be
                        # one plain path component so the write stays inside the
                        # per-hash directory (no "/", "..", or ".").
                        if name in {".", ".."} or Path(name).name != name:
                            raise InvalidAttachment(f"invalid attachment name: {name}")
                        sha256 = hashlib.sha256(data).hexdigest()
                        path = (
                            self.config.storage_dir
                            / user_id
                            / "memory_attachments"
                            / sha256
                            / name
                        )
                        if not path.exists():
                            path.parent.mkdir(parents=True, exist_ok=True)
                            path.write_bytes(data)
                            generated_paths.append(path)
                        internal_uri = path.resolve().as_uri()
                        source = "memory_add_base64"
                    attachments.append(
                        {
                            "id": attachment_id,
                            "user_id": user_id,
                            "app_id": app_id,
                            "project_id": project_id,
                            "session_id": session_id,
                            "resource_id": None,
                            "content_type": str(item.get("type") or "doc"),
                            "name": name,
                            "internal_uri": internal_uri,
                            "source": source,
                            "sha256": sha256,
                        }
                    )
        except Exception:
            self._discard_generated_files(generated_paths)
            raise
        return attachments, generated_paths

    async def flush_memory(
        self,
        *,
        session_id: str,
        app_id: str,
        project_id: str,
    ) -> dict[str, Any]:
        """Ask the backend to flush ``session_id`` and return its response."""
        return {
            "session_id": session_id,
            "backend": await self.backend_client.flush_memory(
                session_id,
                app_id,
                project_id,
            ),
        }

    def _search_payload(
        self,
        *,
        user_id: str,
        agent_id: str | None,
        query: str,
        method: str,
        top_k: int,
        radius: float | None,
        include_profile: bool,
        enable_llm_rerank: bool,
        app_id: str,
        project_id: str,
        filters: dict[str, Any] | None,
    ) -> dict[str, Any]:
        """Build a backend ``search_memory`` payload.

        ``agent_id`` is sent instead of ``user_id`` when it is truthy.
        ``radius`` and ``filters`` are only included when not ``None``.
        """
        payload: dict[str, Any] = {
            "query": query,
            "method": method,
            "top_k": top_k,
            "include_profile": include_profile,
            "enable_llm_rerank": enable_llm_rerank,
            "app_id": app_id,
            "project_id": project_id,
        }
        # NOTE(review): with agent_id set, user_id is not sent at all; confirm
        # the backend still scopes such searches to the calling user.
        payload["agent_id" if agent_id else "user_id"] = agent_id or user_id
        if radius is not None:
            payload["radius"] = radius
        if filters is not None:
            payload["filters"] = filters
        return payload

    def _extract_results(
        self,
        response: dict[str, Any],
        *,
        source_scope: str,
        session_resource_map: dict[str, dict[str, Any]],
        user_id: str,
    ) -> list[dict[str, Any]]:
        """Normalize a backend search response into gateway result dicts.

        Items are read from ``response["data"]`` under ``episodes``,
        ``profiles``, ``agent_cases``, ``agent_skills`` and
        ``unprocessed_messages`` (in that order). A missing or null ``data``
        yields no results.

        For each item, the owning resource is looked up in
        ``session_resource_map`` first and then in the repository for
        ``user_id``. The attachments whose names occur in the item are
        attached, along with the raw item itself.
        """
        data = response.get("data") or {}
        memory_types = {
            "episodes": "episode",
            "profiles": "profile",
            "agent_cases": "agent_case",
            "agent_skills": "agent_skill",
            "unprocessed_messages": "unprocessed_message",
        }
        raw_items = [
            (memory_type, item)
            for key, memory_type in memory_types.items()
            for item in (data.get(key, []) or [])
        ]
        normalized: list[dict[str, Any]] = []
        # Per-call caches: results often share sessions, so hit the repository once each.
        resource_cache: dict[str, dict[str, Any] | None] = {}
        attachment_cache: dict[str, list[dict[str, Any]]] = {}
        for memory_type, raw in raw_items:
            session_id = raw.get("session_id")
            resource = session_resource_map.get(session_id)
            attachments: list[dict[str, Any]] = []
            if isinstance(session_id, str):
                if resource is None:
                    if session_id not in resource_cache:
                        resource_cache[session_id] = (
                            self.repository.get_resource_by_session_for_user(
                                session_id,
                                user_id,
                            )
                        )
                    resource = resource_cache[session_id]
                if session_id not in attachment_cache:
                    attachment_cache[session_id] = (
                        self.repository.list_attachments_for_session(
                            user_id,
                            session_id,
                        )
                    )
                attachments = _matching_attachments(raw, attachment_cache[session_id])
            normalized.append(
                {
                    "id": raw.get("id"),
                    "memory_type": memory_type,
                    "session_id": session_id,
                    "text": _display_text(raw),
                    "score": raw.get("score"),
                    "source_scope": source_scope,
                    "resource_id": resource["id"] if resource else None,
                    "resource_uri": (
                        public_resource_uri(user_id, resource["id"]) if resource else None
                    ),
                    "attachments": attachments,
                    "raw": raw,
                }
            )
        return normalized

    def _apply_tombstones(
        self,
        user_id: str,
        results: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Drop results whose ``id`` or ``session_id`` is tombstoned for ``user_id``."""
        tombstones = self.repository.get_tombstones(user_id)
        memory_ids = {item["memory_id"] for item in tombstones if item["memory_id"]}
        session_ids = {item["session_id"] for item in tombstones if item["session_id"]}
        return [
            item
            for item in results
            if item.get("id") not in memory_ids
            and item.get("session_id") not in session_ids
        ]

    def _apply_overrides(
        self,
        user_id: str,
        results: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Apply the user's active text overrides to matching results, in place.

        A result whose ``id`` has an override gets ``text`` replaced and
        ``override_id`` set. The same list is returned.
        """
        overrides = {
            item["memory_id"]: item
            for item in self.repository.get_active_overrides(user_id)
            if item["memory_id"]
        }
        for result in results:
            override = overrides.get(result.get("id"))
            if override:
                result["text"] = override["override_text"]
                result["override_id"] = override["id"]
        return results

    def upsert_override(
        self,
        *,
        user_id: str,
        memory_id: str,
        session_id: str | None,
        override_text: str,
    ) -> dict[str, Any]:
        """Create or replace the user's text override for ``memory_id``."""
        override = self.repository.upsert_override(
            user_id,
            memory_id,
            session_id,
            override_text,
        )
        return {"memory_id": memory_id, "override_id": override["id"], "status": "active"}

    def assert_memory_session_owned(self, user_id: str, session_id: str) -> None:
        """Raise ``PermissionError`` unless ``session_id`` belongs to ``user_id``.

        Accepted: the user's ``memory_edit:<user_id>`` session, or a
        ``resource:`` session the repository resolves to one of the user's
        resources.
        """
        if session_id == f"memory_edit:{user_id}":
            return
        if session_id.startswith("resource:"):
            resource = self.repository.get_resource_by_session_for_user(
                session_id,
                user_id,
            )
            if resource is not None:
                return
        raise PermissionError("memory session does not belong to user")

    def delete_memory(
        self,
        *,
        user_id: str,
        memory_id: str,
        session_id: str | None,
        reason: str | None,
    ) -> dict[str, Any]:
        """Hide a memory for the user by recording a tombstone."""
        tombstone = self.repository.add_tombstone(
            user_id,
            memory_id,
            session_id,
            reason,
        )
        return {"memory_id": memory_id, "tombstone_id": tombstone["id"], "status": "deleted"}

    def _resource_summary(self, resource: dict[str, Any]) -> dict[str, Any]:
        """Short client view of a resource: id, session id, public URI, status."""
        return {
            "resource_id": resource["id"],
            "session_id": resource["session_id"],
            "uri": public_resource_uri(resource["user_id"], resource["id"]),
            "status": resource["status"],
        }

    def _resource_detail(self, resource: dict[str, Any]) -> dict[str, Any]:
        """Full client view of a resource; the internal file URI is not exposed."""
        return {
            "resource_id": resource["id"],
            "user_id": resource["user_id"],
            "filename": resource["original_filename"],
            "content_type": resource["content_type"],
            "mime_type": resource["mime_type"],
            "uri": public_resource_uri(resource["user_id"], resource["id"]),
            "session_id": resource["session_id"],
            "status": resource["status"],
            "title": resource["title"],
            "description": resource["description"],
            "created_at": resource["created_at"],
            "updated_at": resource["updated_at"],
        }
def _combine_filters(
custom_filters: dict[str, Any] | None,
scope_filters: dict[str, Any] | None,
) -> dict[str, Any] | None:
if custom_filters is None:
return scope_filters
if scope_filters is None:
return custom_filters
return {"AND": [custom_filters, scope_filters]}
def _attachment_name(item: dict[str, Any], uri: str | None) -> str:
if item.get("name"):
return _safe_filename(str(item["name"]))
if uri:
parsed = urlparse(uri)
uri_name = Path(unquote(parsed.path)).name
if uri_name:
return _safe_filename(uri_name)
extension = str(item.get("ext") or "bin").lstrip(".") or "bin"
return f"attachment.{extension}"
def _matching_attachments(
    raw: dict[str, Any],
    attachments: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Select the attachments whose name appears in a raw search item.

    The match is a case-insensitive substring test against every string
    found in ``raw`` (``base64`` fields excluded by ``_raw_string_values``).
    Attachments sharing an ``internal_uri`` are reported once (first wins).
    Each match is projected to ``type``, ``name`` and ``internal_uri``.
    """
    haystack = [text.casefold() for text in _raw_string_values(raw)]
    already_seen: set[str] = set()
    selected: list[dict[str, Any]] = []
    for entry in attachments:
        label = str(entry["name"])
        location = str(entry["internal_uri"])
        if location in already_seen:
            continue
        needle = label.casefold()
        if any(needle in text for text in haystack):
            already_seen.add(location)
            selected.append(
                {
                    "type": entry["content_type"],
                    "name": label,
                    "internal_uri": location,
                }
            )
    return selected
def _raw_string_values(value: Any, key: str | None = None) -> list[str]:
if key is not None and key.casefold() == "base64":
return []
if isinstance(value, str):
return [value]
if isinstance(value, dict):
strings: list[str] = []
for item_key, item_value in value.items():
strings.extend(_raw_string_values(item_value, str(item_key)))
return strings
if isinstance(value, list):
strings = []
for item in value:
strings.extend(_raw_string_values(item))
return strings
return []
def _chunks(items: list[str], size: int) -> list[list[str]]:
if not items:
return []
return [items[index : index + size] for index in range(0, len(items), size)]
def _display_text(raw: dict[str, Any]) -> str:
for key in (
"episode",
"summary",
"content",
"profile_data",
"task_intent",
"approach",
"key_insight",
"name",
"description",
):
value = raw.get(key)
if value is None:
continue
if isinstance(value, str):
return value
return str(value)
return ""