651 lines
20 KiB
Python
651 lines
20 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import base64
|
|
import hashlib
|
|
import mimetypes
|
|
import secrets
|
|
import shutil
|
|
import time
|
|
import uuid
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from urllib.parse import unquote, urlparse
|
|
|
|
from fastapi import UploadFile
|
|
|
|
from .config import GatewayConfig
|
|
from .repository import MemoryRepository
|
|
|
|
|
|
def new_resource_id() -> str:
    """Return a new resource id: the prefix ``r_`` plus a random UUID4 in hex."""
    token = uuid.uuid4().hex
    return f"r_{token}"
|
|
|
|
|
|
def resource_session_id(user_id: str, resource_id: str) -> str:
    """Build the session id used for a resource: ``resource:<user_id>:<resource_id>``."""
    owner_and_resource = f"{user_id}:{resource_id}"
    return f"resource:{owner_and_resource}"
|
|
|
|
|
|
def public_resource_uri(user_id: str, resource_id: str) -> str:
    """Build the externally visible URI of a resource: ``resource://<user_id>/<resource_id>``."""
    location = f"{user_id}/{resource_id}"
    return f"resource://{location}"
|
|
|
|
|
|
def current_timestamp_ms() -> int:
    """Return the current wall-clock time as whole milliseconds since the epoch."""
    seconds_since_epoch = time.time()
    return int(seconds_since_epoch * 1000)
|
|
|
|
|
|
def infer_content_type(filename: str | None, mime_type: str | None) -> str:
    """Classify an upload into a coarse content category.

    The MIME type is taken from ``mime_type`` when given, otherwise guessed
    from ``filename``. The file suffix is used as a secondary hint for the
    pdf, html and text categories.

    Args:
        filename: Name of the uploaded file, or ``None`` if unknown.
        mime_type: Declared MIME type, optionally carrying parameters such as
            ``"; charset=utf-8"``, or ``None`` if unknown.

    Returns:
        One of ``"image"``, ``"audio"``, ``"pdf"``, ``"html"``, ``"text"`` or
        ``"doc"`` (the fallback when nothing else matches).
    """
    declared = mime_type or mimetypes.guess_type(filename or "")[0] or ""
    # Content-Type values may carry parameters ("text/html; charset=utf-8").
    # Compare only the bare media type so the exact-match checks below work.
    mime = declared.split(";", 1)[0].strip().lower()
    suffix = Path(filename or "").suffix.lower()

    if mime.startswith("image/"):
        return "image"
    if mime.startswith("audio/"):
        return "audio"
    if mime == "application/pdf" or suffix == ".pdf":
        return "pdf"
    if mime in {"text/html", "application/xhtml+xml"} or suffix in {".html", ".htm"}:
        return "html"
    if mime.startswith("text/plain") or suffix in {".txt", ".md", ".csv", ".log"}:
        return "text"
    return "doc"
|
|
|
|
|
|
def _safe_filename(filename: str | None) -> str:
|
|
name = Path(filename or "upload.bin").name
|
|
return name or "upload.bin"
|
|
|
|
|
|
class UploadTooLarge(ValueError):
    """Raised by ``_copy_upload`` when an upload exceeds the configured byte limit."""
|
|
|
|
|
|
class UnsupportedContentType(ValueError):
    """Raised when an upload's MIME type is not in the allowed list."""
|
|
|
|
|
|
def _copy_upload(
|
|
file: UploadFile,
|
|
destination: Path,
|
|
max_upload_bytes: int,
|
|
cleanup_root: Path,
|
|
) -> tuple[str, int]:
|
|
sha256 = hashlib.sha256()
|
|
size = 0
|
|
destination.parent.mkdir(parents=True, exist_ok=True)
|
|
try:
|
|
with destination.open("wb") as out:
|
|
while True:
|
|
chunk = file.file.read(1024 * 1024)
|
|
if not chunk:
|
|
break
|
|
size += len(chunk)
|
|
if size > max_upload_bytes:
|
|
raise UploadTooLarge(
|
|
f"upload exceeds max size of {max_upload_bytes} bytes"
|
|
)
|
|
sha256.update(chunk)
|
|
out.write(chunk)
|
|
return sha256.hexdigest(), size
|
|
except Exception:
|
|
destination.unlink(missing_ok=True)
|
|
_remove_empty_parents(destination.parent, stop_at=cleanup_root)
|
|
raise
|
|
|
|
|
|
def _mime_allowed(mime_type: str | None, allowed_mime_types: tuple[str, ...]) -> bool:
|
|
mime = (mime_type or "").lower()
|
|
if not mime:
|
|
return False
|
|
for allowed in allowed_mime_types:
|
|
item = allowed.lower()
|
|
if item.endswith("/*") and mime.startswith(item[:-1]):
|
|
return True
|
|
if item == mime:
|
|
return True
|
|
return False
|
|
|
|
|
|
def _remove_empty_parents(path: Path, stop_at: Path | None = None) -> None:
|
|
current = path
|
|
stop = stop_at.resolve() if stop_at is not None else None
|
|
while True:
|
|
try:
|
|
resolved = current.resolve()
|
|
if stop is not None and resolved == stop:
|
|
return
|
|
current.rmdir()
|
|
except OSError:
|
|
return
|
|
parent = current.parent
|
|
if parent == current:
|
|
return
|
|
current = parent
|
|
|
|
|
|
class MemoryGatewayService:
    """Application service sitting between the HTTP layer and the memory backend.

    It combines three collaborators:

    * ``config`` - gateway settings (storage directory, upload limits, allowed
      MIME types, retry and batching parameters).
    * ``repository`` - local persistence for users, resources, tombstones and
      overrides.
    * ``backend_client`` - async client of the upstream memory service
      (``add_memory``, ``flush_memory``, ``search_memory``).
    """

    def __init__(
        self,
        config: GatewayConfig,
        repository: MemoryRepository,
        backend_client: Any,
    ) -> None:
        """Store the collaborators; no I/O is performed here."""
        self.config = config
        self.repository = repository
        self.backend_client = backend_client

    def create_user(self, user_id: str) -> dict[str, Any]:
        """Create a user with a freshly generated secret key.

        Returns:
            A dict with ``user_id``, ``user_key`` and ``created_at`` taken from
            the record returned by the repository.
        """
        user_key = f"uk_{secrets.token_urlsafe(32)}"
        user = self.repository.create_user(user_id, user_key)
        return {
            "user_id": user["id"],
            "user_key": user["user_key"],
            "created_at": user["created_at"],
        }

    def authenticate_user(self, user_id: str, user_key: str) -> bool:
        """Return ``True`` if ``user_key`` matches the key stored for ``user_id``.

        Unknown users yield ``False``. The comparison is constant-time.
        """
        user = self.repository.get_user(user_id)
        if user is None:
            return False
        # compare_digest() rejects non-ASCII str arguments with TypeError, so a
        # caller-supplied key with such characters would crash instead of
        # failing authentication. Comparing bytes accepts any input;
        # "surrogatepass" keeps lone surrogates from raising during encoding.
        stored = str(user["user_key"]).encode("utf-8", "surrogatepass")
        supplied = user_key.encode("utf-8", "surrogatepass")
        return secrets.compare_digest(stored, supplied)

    async def upload_resource(
        self,
        *,
        user_id: str,
        app_id: str,
        project_id: str,
        file: UploadFile,
        title: str | None,
        description: str | None,
    ) -> dict[str, Any]:
        """Store an uploaded file and ingest it into the memory backend.

        Steps: validate the MIME type, copy the upload under
        ``storage_dir/<user_id>/<resource_id>/``, de-duplicate by SHA-256
        against the user's active resources, register the resource as
        ``"ingesting"``, then send it to the backend and flush. The resource
        ends up ``"extracted"`` on success or ``"failed"`` (with the error
        text) when the backend calls fail.

        Returns:
            The resource summary (see ``_resource_summary``). For a duplicate
            upload this is the summary of the already existing resource.

        Raises:
            UnsupportedContentType: If the MIME type is not allowed.
            UploadTooLarge: If the upload exceeds ``config.max_upload_bytes``.
            ValueError: If the computed storage path would fall outside
                ``config.storage_dir``.
        """
        resource_id = new_resource_id()
        session_id = resource_session_id(user_id, resource_id)
        original_filename = _safe_filename(file.filename)
        mime_type = file.content_type or mimetypes.guess_type(original_filename)[0]
        if not _mime_allowed(mime_type, self.config.allowed_mime_types):
            raise UnsupportedContentType(f"unsupported content type: {mime_type}")
        content_type = infer_content_type(original_filename, mime_type)
        stored_path = self.config.storage_dir / user_id / resource_id / original_filename

        # user_id becomes a path component. Refuse to write anywhere outside
        # the storage root, mirroring the check in _cleanup_resource_file.
        storage_root = self.config.storage_dir.resolve()
        if not stored_path.resolve().is_relative_to(storage_root):
            raise ValueError("upload path escapes the storage directory")

        # NOTE(review): this copy is synchronous file I/O inside a coroutine
        # and blocks the event loop for the duration of the write.
        sha256, size_bytes = _copy_upload(
            file,
            stored_path,
            self.config.max_upload_bytes,
            self.config.storage_dir,
        )
        existing = self.repository.find_active_resource_by_sha256(
            user_id=user_id,
            app_id=app_id,
            project_id=project_id,
            sha256=sha256,
        )
        if existing is not None:
            # Same content already stored for this user/app/project: discard
            # the new copy and report the existing resource.
            shutil.rmtree(stored_path.parent, ignore_errors=True)
            return self._resource_summary(existing)

        internal_uri = stored_path.resolve().as_uri()

        resource = self.repository.create_resource(
            id=resource_id,
            user_id=user_id,
            app_id=app_id,
            project_id=project_id,
            session_id=session_id,
            original_filename=original_filename,
            mime_type=mime_type,
            content_type=content_type,
            uri=internal_uri,
            uri_public=False,
            sha256=sha256,
            size_bytes=size_bytes,
            title=title,
            description=description,
            status="ingesting",
            error_message=None,
        )

        try:
            # Build the payload once. It reads and encodes the whole file, and
            # doing that inside the retried callable repeats the work on every
            # attempt.
            add_payload = self._build_add_payload(
                resource=resource,
                user_id=user_id,
                app_id=app_id,
                project_id=project_id,
                filename=original_filename,
            )
            await self._retry_backend_call(
                lambda: self.backend_client.add_memory(add_payload)
            )
            await self._retry_backend_call(
                lambda: self.backend_client.flush_memory(session_id, app_id, project_id)
            )
        except Exception as exc:
            # Ingestion failures are reported through the resource status
            # rather than raised to the caller.
            failed = self.repository.update_resource_status(
                resource_id,
                "failed",
                str(exc),
            )
            return self._resource_summary(failed or resource)

        extracted = self.repository.update_resource_status(resource_id, "extracted")
        return self._resource_summary(extracted or resource)

    async def _retry_backend_call(self, operation: Any) -> Any:
        """Await ``operation()`` with retries.

        ``operation`` is a zero-argument callable returning an awaitable. It is
        tried up to ``config.backend_ingest_attempts`` times (at least once),
        sleeping ``config.backend_retry_delay_seconds`` between attempts when
        that value is positive.

        Returns:
            The result of the first successful attempt.

        Raises:
            Exception: The error raised by the last attempt if all fail.
        """
        attempts = max(1, self.config.backend_ingest_attempts)
        last_error: Exception | None = None
        for attempt in range(attempts):
            try:
                return await operation()
            except Exception as exc:
                last_error = exc
            if attempt == attempts - 1:
                break
            delay = self.config.backend_retry_delay_seconds
            if delay > 0:
                await asyncio.sleep(delay)
        if last_error is None:
            # Defensive: the loop always runs at least once, so this is not
            # expected to be reached.
            raise RuntimeError("upstream memory service operation failed")
        raise last_error

    def _build_add_payload(
        self,
        *,
        resource: dict[str, Any],
        user_id: str,
        app_id: str,
        project_id: str,
        filename: str,
    ) -> dict[str, Any]:
        """Build the ``add_memory`` request body for an uploaded resource.

        The file is wrapped as a single user message, timestamped with the
        current time in milliseconds, in the resource's own session.
        """
        content_item = self._build_content_item(resource=resource, filename=filename)
        return {
            "session_id": resource["session_id"],
            "app_id": app_id,
            "project_id": project_id,
            "messages": [
                {
                    "sender_id": user_id,
                    "role": "user",
                    "timestamp": current_timestamp_ms(),
                    "content": [content_item],
                }
            ],
        }

    def _build_content_item(
        self,
        *,
        resource: dict[str, Any],
        filename: str,
    ) -> dict[str, Any]:
        """Read the stored file and wrap it as one message content item.

        ``"text"`` resources are embedded as UTF-8 text, with undecodable
        bytes replaced. Every other content type is embedded as base64.

        Raises:
            ValueError: If the resource URI is not a ``file`` URI.
            OSError: If the stored file is missing or unreadable.
        """
        content_type = str(resource["content_type"])
        path = self._resource_file_path(resource)
        content = path.read_bytes()
        item: dict[str, Any] = {
            "type": content_type,
            "name": filename,
            "ext": Path(filename).suffix.lstrip(".") or None,
            "extras": {
                "resource_id": resource["id"],
                "source": "user_upload",
            },
        }
        if content_type == "text":
            item["text"] = content.decode("utf-8", errors="replace")
        else:
            item["base64"] = base64.b64encode(content).decode("ascii")
        return item

    def _resource_file_path(self, resource: dict[str, Any]) -> Path:
        """Return the resolved filesystem path behind a resource's ``file`` URI.

        Raises:
            ValueError: If the URI scheme is not ``file``.
            OSError: If the path does not exist (strict resolution).
        """
        uri = str(resource["uri"])
        parsed = urlparse(uri)
        if parsed.scheme != "file":
            raise ValueError(f"unsupported resource uri scheme: {parsed.scheme}")
        return Path(unquote(parsed.path)).resolve(strict=True)

    def list_resources(self, user_id: str) -> list[dict[str, Any]]:
        """Return the detail view of every resource the repository lists for the user."""
        return [self._resource_detail(item) for item in self.repository.list_resources(user_id)]

    def get_resource_detail(
        self,
        resource_id: str,
        user_id: str,
    ) -> dict[str, Any] | None:
        """Return the detail view of one resource, or ``None`` if the user has no such resource."""
        resource = self.repository.get_resource_for_user(resource_id, user_id)
        if resource is None:
            return None
        return self._resource_detail(resource)

    def delete_resource(self, resource_id: str, user_id: str) -> dict[str, Any] | None:
        """Soft-delete a resource and remove its stored file.

        Returns:
            The resource summary, or ``None`` if the user has no such resource.
        """
        before = self.repository.get_resource_for_user(resource_id, user_id)
        if before is None:
            return None
        resource = self.repository.soft_delete_resource(resource_id, user_id)
        self._cleanup_resource_file(before)
        # Fall back to the pre-delete record if the repository returns nothing,
        # matching the "updated or original" handling in upload_resource.
        # NOTE(review): the fallback summary carries the pre-delete status.
        return self._resource_summary(resource or before)

    def _cleanup_resource_file(self, resource: dict[str, Any]) -> None:
        """Best-effort removal of a resource's stored file and empty parent dirs.

        Only ``file://`` URIs that resolve inside ``config.storage_dir`` are
        touched. Filesystem errors are ignored.
        """
        uri = str(resource.get("uri") or "")
        if not uri.startswith("file://"):
            return
        parsed = urlparse(uri)
        if parsed.scheme != "file":
            return
        try:
            path = Path(unquote(parsed.path)).resolve()
            storage_root = self.config.storage_dir.resolve()
        except OSError:
            return
        if not path.is_relative_to(storage_root):
            # Never delete anything outside the storage directory.
            return
        try:
            path.unlink(missing_ok=True)
        except OSError:
            return
        _remove_empty_parents(path.parent, stop_at=storage_root)

    async def search_memories(
        self,
        *,
        user_id: str,
        query: str,
        conversation_id: str | None,
        scope: list[str],
        top_k: int,
        app_id: str,
        project_id: str,
    ) -> dict[str, Any]:
        """Search the backend across the requested scopes and post-filter results.

        Supported ``scope`` entries:

        * ``"current_chat"`` - the session ``chat:<conversation_id>`` (skipped
          when ``conversation_id`` is empty).
        * ``"resources"`` - the sessions of the user's extracted resources,
          queried in batches of ``config.resource_search_batch_size``.
        * ``"all_user_memory"`` - no session filter.

        Results from all scopes are concatenated in that order, tombstoned
        entries are dropped and active overrides are applied.

        Returns:
            ``{"results": [...]}`` with the normalized result dicts.
        """
        results: list[dict[str, Any]] = []
        session_resource_map: dict[str, dict[str, Any]] = {}

        if "current_chat" in scope and conversation_id:
            results.extend(
                await self._search_scope(
                    user_id=user_id,
                    query=query,
                    top_k=top_k,
                    app_id=app_id,
                    project_id=project_id,
                    filters={"session_id": f"chat:{conversation_id}"},
                    source_scope="current_chat",
                    session_resource_map=session_resource_map,
                )
            )

        if "resources" in scope:
            resources = self.repository.list_extracted_resources(
                user_id,
                app_id,
                project_id,
            )
            session_resource_map.update({item["session_id"]: item for item in resources})
            session_ids = [item["session_id"] for item in resources]
            for batch in _chunks(session_ids, self.config.resource_search_batch_size):
                results.extend(
                    await self._search_scope(
                        user_id=user_id,
                        query=query,
                        top_k=top_k,
                        app_id=app_id,
                        project_id=project_id,
                        filters={"session_id": {"in": batch}},
                        source_scope="resources",
                        session_resource_map=session_resource_map,
                    )
                )

        if "all_user_memory" in scope:
            results.extend(
                await self._search_scope(
                    user_id=user_id,
                    query=query,
                    top_k=top_k,
                    app_id=app_id,
                    project_id=project_id,
                    filters=None,
                    source_scope="all_user_memory",
                    session_resource_map=session_resource_map,
                )
            )

        filtered = self._apply_tombstones(user_id, results)
        overridden = self._apply_overrides(user_id, filtered)
        return {"results": overridden}

    async def _search_scope(
        self,
        *,
        user_id: str,
        query: str,
        top_k: int,
        app_id: str,
        project_id: str,
        filters: dict[str, Any] | None,
        source_scope: str,
        session_resource_map: dict[str, dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Run one backend search and normalize its results for ``source_scope``."""
        payload = self._search_payload(
            user_id=user_id,
            query=query,
            top_k=top_k,
            app_id=app_id,
            project_id=project_id,
            filters=filters,
        )
        response = await self.backend_client.search_memory(payload)
        return self._extract_results(
            response,
            source_scope=source_scope,
            session_resource_map=session_resource_map,
            user_id=user_id,
        )

    async def add_memory(
        self,
        *,
        session_id: str,
        app_id: str,
        project_id: str,
        messages: list[dict[str, Any]],
    ) -> dict[str, Any]:
        """Forward messages to the backend; return the session id and backend reply."""
        payload = {
            "session_id": session_id,
            "app_id": app_id,
            "project_id": project_id,
            "messages": messages,
        }
        return {
            "session_id": session_id,
            "backend": await self.backend_client.add_memory(payload),
        }

    async def flush_memory(
        self,
        *,
        session_id: str,
        app_id: str,
        project_id: str,
    ) -> dict[str, Any]:
        """Ask the backend to flush a session; return the session id and backend reply."""
        return {
            "session_id": session_id,
            "backend": await self.backend_client.flush_memory(
                session_id,
                app_id,
                project_id,
            ),
        }

    def _search_payload(
        self,
        *,
        user_id: str,
        query: str,
        top_k: int,
        app_id: str,
        project_id: str,
        filters: dict[str, Any] | None,
    ) -> dict[str, Any]:
        """Build a backend search request; ``filters`` is included only when not ``None``."""
        payload: dict[str, Any] = {
            "user_id": user_id,
            "query": query,
            "top_k": top_k,
            "app_id": app_id,
            "project_id": project_id,
        }
        if filters is not None:
            payload["filters"] = filters
        return payload

    def _extract_results(
        self,
        response: dict[str, Any],
        *,
        source_scope: str,
        session_resource_map: dict[str, dict[str, Any]],
        user_id: str,
    ) -> list[dict[str, Any]]:
        """Flatten a backend search response into normalized result dicts.

        Items are collected from the ``episodes``, ``profiles``,
        ``agent_cases``, ``agent_skills`` and ``unprocessed_messages`` lists
        under ``response["data"]``. Each item is linked to a resource when its
        session id is in ``session_resource_map`` or the repository knows a
        resource of this user for that session.

        Returns:
            One dict per item with ``id``, ``session_id``, ``text``, ``score``,
            ``source_scope``, ``resource_id``, ``resource_uri`` and ``raw``.
        """
        # Tolerate a missing response or an explicit "data": null.
        data = (response or {}).get("data") or {}
        raw_items: list[dict[str, Any]] = []
        for key in (
            "episodes",
            "profiles",
            "agent_cases",
            "agent_skills",
            "unprocessed_messages",
        ):
            raw_items.extend(data.get(key, []) or [])

        # Many items usually share a session. Remember repository answers
        # (including misses) so each session id is looked up once per response.
        looked_up: dict[str, dict[str, Any] | None] = {}
        normalized = []
        for raw in raw_items:
            session_id = raw.get("session_id")
            resource = session_resource_map.get(session_id)
            if resource is None and isinstance(session_id, str):
                if session_id not in looked_up:
                    looked_up[session_id] = self.repository.get_resource_by_session_for_user(
                        session_id,
                        user_id,
                    )
                resource = looked_up[session_id]
            normalized.append(
                {
                    "id": raw.get("id"),
                    "session_id": session_id,
                    "text": _display_text(raw),
                    "score": raw.get("score"),
                    "source_scope": source_scope,
                    "resource_id": resource["id"] if resource else None,
                    "resource_uri": (
                        public_resource_uri(user_id, resource["id"]) if resource else None
                    ),
                    "raw": raw,
                }
            )
        return normalized

    def _apply_tombstones(
        self,
        user_id: str,
        results: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Drop results whose memory id or session id is tombstoned for the user."""
        tombstones = self.repository.get_tombstones(user_id)
        memory_ids = {item["memory_id"] for item in tombstones if item["memory_id"]}
        session_ids = {item["session_id"] for item in tombstones if item["session_id"]}
        return [
            item
            for item in results
            if item.get("id") not in memory_ids
            and item.get("session_id") not in session_ids
        ]

    def _apply_overrides(
        self,
        user_id: str,
        results: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Replace the text of results that have an active override.

        Matching results are modified in place: ``text`` becomes the override
        text and ``override_id`` is added. The same list is returned.
        """
        overrides = {
            item["memory_id"]: item
            for item in self.repository.get_active_overrides(user_id)
            if item["memory_id"]
        }
        for result in results:
            override = overrides.get(result.get("id"))
            if override:
                result["text"] = override["override_text"]
                result["override_id"] = override["id"]
        return results

    def upsert_override(
        self,
        *,
        user_id: str,
        memory_id: str,
        session_id: str | None,
        override_text: str,
    ) -> dict[str, Any]:
        """Create or update the override text for a memory; report it as active."""
        override = self.repository.upsert_override(
            user_id,
            memory_id,
            session_id,
            override_text,
        )
        return {"memory_id": memory_id, "override_id": override["id"], "status": "active"}

    def assert_memory_session_owned(self, user_id: str, session_id: str) -> None:
        """Verify that ``session_id`` belongs to ``user_id``.

        Accepted sessions are the user's ``memory_edit:<user_id>`` session and
        any ``resource:`` session for which the repository finds a resource of
        this user.

        Raises:
            PermissionError: If the session is not recognised as the user's.
        """
        if session_id == f"memory_edit:{user_id}":
            return
        if session_id.startswith("resource:"):
            resource = self.repository.get_resource_by_session_for_user(
                session_id,
                user_id,
            )
            if resource is not None:
                return
        raise PermissionError("memory session does not belong to user")

    def delete_memory(
        self,
        *,
        user_id: str,
        memory_id: str,
        session_id: str | None,
        reason: str | None,
    ) -> dict[str, Any]:
        """Record a tombstone for a memory so searches hide it; report it as deleted."""
        tombstone = self.repository.add_tombstone(
            user_id,
            memory_id,
            session_id,
            reason,
        )
        return {"memory_id": memory_id, "tombstone_id": tombstone["id"], "status": "deleted"}

    def _resource_summary(self, resource: dict[str, Any]) -> dict[str, Any]:
        """Return the short public view of a resource (id, session, public URI, status)."""
        return {
            "resource_id": resource["id"],
            "session_id": resource["session_id"],
            "uri": public_resource_uri(resource["user_id"], resource["id"]),
            "status": resource["status"],
        }

    def _resource_detail(self, resource: dict[str, Any]) -> dict[str, Any]:
        """Return the full public view of a resource.

        The internal storage URI is not exposed; ``uri`` is the public
        ``resource://`` form.
        """
        return {
            "resource_id": resource["id"],
            "user_id": resource["user_id"],
            "filename": resource["original_filename"],
            "content_type": resource["content_type"],
            "mime_type": resource["mime_type"],
            "uri": public_resource_uri(resource["user_id"], resource["id"]),
            "session_id": resource["session_id"],
            "status": resource["status"],
            "title": resource["title"],
            "description": resource["description"],
            "created_at": resource["created_at"],
            "updated_at": resource["updated_at"],
        }
|
|
|
|
|
|
def _chunks(items: list[str], size: int) -> list[list[str]]:
|
|
if not items:
|
|
return []
|
|
return [items[index : index + size] for index in range(0, len(items), size)]
|
|
|
|
|
|
def _display_text(raw: dict[str, Any]) -> str:
|
|
for key in (
|
|
"episode",
|
|
"summary",
|
|
"content",
|
|
"profile_data",
|
|
"task_intent",
|
|
"approach",
|
|
"key_insight",
|
|
"name",
|
|
"description",
|
|
):
|
|
value = raw.get(key)
|
|
if value is None:
|
|
continue
|
|
if isinstance(value, str):
|
|
return value
|
|
return str(value)
|
|
return ""
|