feat: extend memory search and attachment mapping

This commit is contained in:
2026-06-15 17:25:44 +08:00
parent 15462a95cb
commit e5cd87789f
9 changed files with 1194 additions and 54 deletions

View File

@ -8,14 +8,19 @@ from typing import Any, Literal
from urllib.parse import parse_qsl, quote, urlsplit, urlunsplit
from fastapi import APIRouter, FastAPI, File, Form, HTTPException, Request, UploadFile
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from starlette.responses import Response
from .config import GatewayConfig
from .db import init_db
from .backend_client import BackendClient
from .repository import MemoryRepository
from .service import MemoryGatewayService, UnsupportedContentType, UploadTooLarge
from .service import (
InvalidAttachment,
MemoryGatewayService,
UnsupportedContentType,
UploadTooLarge,
)
API_LOGGER = logging.getLogger("memory_gateway.api")
@ -34,15 +39,28 @@ SENSITIVE_FIELD_NAMES = {
class SearchMemoriesRequest(BaseModel):
user_id: str = Field(min_length=1)
user_key: str = Field(min_length=1)
agent_id: str | None = Field(default=None, min_length=1)
conversation_id: str | None = None
query: str = Field(min_length=1)
scope: list[Literal["current_chat", "resources", "all_user_memory"]] = Field(
default_factory=lambda: ["current_chat", "resources"]
)
top_k: int = Field(default=8, ge=1, le=100)
method: Literal["keyword", "vector", "hybrid", "agentic"] = "hybrid"
top_k: int = 8
radius: float | None = Field(default=None, ge=0, le=1)
include_profile: bool = True
enable_llm_rerank: bool = True
filters: dict[str, Any] | None = None
app_id: str = "default"
project_id: str = "default"
@field_validator("top_k")
@classmethod
def validate_top_k(cls, value: int) -> int:
if value != -1 and not 1 <= value <= 100:
raise ValueError("top_k must be -1 or in 1..100")
return value
class AddMemoryMessage(BaseModel):
sender_id: str = Field(min_length=1)
@ -367,10 +385,16 @@ def create_app(
require_user(request.user_id, request.user_key)
return await service.search_memories(
user_id=request.user_id,
agent_id=request.agent_id,
query=request.query,
conversation_id=request.conversation_id,
scope=request.scope,
method=request.method,
top_k=request.top_k,
radius=request.radius,
include_profile=request.include_profile,
enable_llm_rerank=request.enable_llm_rerank,
filters=request.filters,
app_id=request.app_id,
project_id=request.project_id,
)
@ -380,12 +404,18 @@ def create_app(
request: AddMemoryRequest,
) -> dict[str, Any]:
require_user(request.user_id, request.user_key)
return await service.add_memory(
session_id=request.session_id,
app_id=request.app_id,
project_id=request.project_id,
messages=[message.model_dump() for message in request.messages],
)
try:
return await service.add_memory(
user_id=request.user_id,
session_id=request.session_id,
app_id=request.app_id,
project_id=request.project_id,
messages=[message.model_dump() for message in request.messages],
)
except UploadTooLarge as exc:
raise HTTPException(status_code=413, detail=str(exc)) from exc
except InvalidAttachment as exc:
raise HTTPException(status_code=422, detail=str(exc)) from exc
@router.post("/memories/flush")
async def flush_memory(

View File

@ -43,6 +43,62 @@ ON user_resources (session_id);
CREATE INDEX IF NOT EXISTS idx_user_resources_user_id
ON user_resources (user_id);
CREATE TABLE IF NOT EXISTS memory_attachments (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL,
app_id TEXT NOT NULL DEFAULT 'default',
project_id TEXT NOT NULL DEFAULT 'default',
session_id TEXT NOT NULL,
resource_id TEXT,
content_type TEXT NOT NULL,
name TEXT NOT NULL,
internal_uri TEXT NOT NULL,
source TEXT NOT NULL,
sha256 TEXT,
created_at TIMESTAMP NOT NULL,
deleted_at TIMESTAMP
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_memory_attachments_unique_uri
ON memory_attachments (user_id, session_id, internal_uri);
CREATE INDEX IF NOT EXISTS idx_memory_attachments_user_session
ON memory_attachments (user_id, session_id, deleted_at);
CREATE INDEX IF NOT EXISTS idx_memory_attachments_resource
ON memory_attachments (resource_id, deleted_at);
INSERT OR IGNORE INTO memory_attachments (
id,
user_id,
app_id,
project_id,
session_id,
resource_id,
content_type,
name,
internal_uri,
source,
sha256,
created_at,
deleted_at
)
SELECT
'a_resource_' || id,
user_id,
app_id,
project_id,
session_id,
id,
content_type,
COALESCE(original_filename, id),
uri,
'resource_upload',
sha256,
created_at,
deleted_at
FROM user_resources;
CREATE TABLE IF NOT EXISTS memory_tombstones (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL,

View File

@ -96,9 +96,13 @@ class MemoryRepository:
now = utc_now()
where = "id = ? AND deleted_at IS NULL"
params: tuple[Any, ...] = (now, now, resource_id)
attachment_where = "resource_id = ? AND deleted_at IS NULL"
attachment_params: tuple[Any, ...] = (now, resource_id)
if user_id is not None:
where += " AND user_id = ?"
params = (now, now, resource_id, user_id)
attachment_where += " AND user_id = ?"
attachment_params = (now, resource_id, user_id)
with connect(self.db_path) as conn:
conn.execute(
f"""
@ -108,6 +112,14 @@ class MemoryRepository:
""",
params,
)
conn.execute(
f"""
UPDATE memory_attachments
SET deleted_at = ?
WHERE {attachment_where}
""",
attachment_params,
)
conn.commit()
return self.get_resource(resource_id)
@ -215,6 +227,62 @@ class MemoryRepository:
).fetchall()
return [dict(row) for row in rows]
def create_attachment(self, **values: Any) -> dict[str, Any]:
attachment_id = str(values.get("id") or f"a_{uuid.uuid4().hex}")
payload = {
"id": attachment_id,
"created_at": utc_now(),
"deleted_at": None,
**values,
}
with connect(self.db_path) as conn:
conn.execute(
"""
INSERT OR IGNORE INTO memory_attachments (
id, user_id, app_id, project_id, session_id, resource_id,
content_type, name, internal_uri, source, sha256,
created_at, deleted_at
) VALUES (
:id, :user_id, :app_id, :project_id, :session_id, :resource_id,
:content_type, :name, :internal_uri, :source, :sha256,
:created_at, :deleted_at
)
""",
payload,
)
row = conn.execute(
"""
SELECT * FROM memory_attachments
WHERE user_id = ? AND session_id = ? AND internal_uri = ?
""",
(
payload["user_id"],
payload["session_id"],
payload["internal_uri"],
),
).fetchone()
conn.commit()
attachment = _row_to_dict(row)
if attachment is None:
raise RuntimeError("created attachment could not be read back")
return attachment
def list_attachments_for_session(
self,
user_id: str,
session_id: str,
) -> list[dict[str, Any]]:
with connect(self.db_path) as conn:
rows = conn.execute(
"""
SELECT * FROM memory_attachments
WHERE user_id = ? AND session_id = ? AND deleted_at IS NULL
ORDER BY created_at ASC, id ASC
""",
(user_id, session_id),
).fetchall()
return [dict(row) for row in rows]
def add_tombstone(
self,
user_id: str,

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
import base64
import binascii
import hashlib
import mimetypes
import secrets
@ -63,6 +64,10 @@ class UnsupportedContentType(ValueError):
pass
class InvalidAttachment(ValueError):
pass
def _copy_upload(
file: UploadFile,
destination: Path,
@ -180,6 +185,7 @@ class MemoryGatewayService:
)
if existing is not None:
shutil.rmtree(stored_path.parent, ignore_errors=True)
self._register_resource_attachment(existing)
return self._resource_summary(existing)
internal_uri = stored_path.resolve().as_uri()
@ -202,6 +208,7 @@ class MemoryGatewayService:
status="ingesting",
error_message=None,
)
self._register_resource_attachment(resource)
try:
await self._retry_backend_call(
@ -346,10 +353,16 @@ class MemoryGatewayService:
self,
*,
user_id: str,
agent_id: str | None,
query: str,
conversation_id: str | None,
scope: list[str],
method: str,
top_k: int,
radius: float | None,
include_profile: bool,
enable_llm_rerank: bool,
filters: dict[str, Any] | None,
app_id: str,
project_id: str,
) -> dict[str, Any]:
@ -359,11 +372,19 @@ class MemoryGatewayService:
if "current_chat" in scope and conversation_id:
payload = self._search_payload(
user_id=user_id,
agent_id=agent_id,
query=query,
method=method,
top_k=top_k,
radius=radius,
include_profile=include_profile,
enable_llm_rerank=enable_llm_rerank,
app_id=app_id,
project_id=project_id,
filters={"session_id": f"chat:{conversation_id}"},
filters=_combine_filters(
filters,
{"session_id": f"chat:{conversation_id}"},
),
)
results.extend(
self._extract_results(
@ -385,11 +406,19 @@ class MemoryGatewayService:
for batch in _chunks(session_ids, self.config.resource_search_batch_size):
payload = self._search_payload(
user_id=user_id,
agent_id=agent_id,
query=query,
method=method,
top_k=top_k,
radius=radius,
include_profile=include_profile,
enable_llm_rerank=enable_llm_rerank,
app_id=app_id,
project_id=project_id,
filters={"session_id": {"in": batch}},
filters=_combine_filters(
filters,
{"session_id": {"in": batch}},
),
)
results.extend(
self._extract_results(
@ -403,11 +432,16 @@ class MemoryGatewayService:
if "all_user_memory" in scope:
payload = self._search_payload(
user_id=user_id,
agent_id=agent_id,
query=query,
method=method,
top_k=top_k,
radius=radius,
include_profile=include_profile,
enable_llm_rerank=enable_llm_rerank,
app_id=app_id,
project_id=project_id,
filters=None,
filters=filters,
)
results.extend(
self._extract_results(
@ -425,21 +459,126 @@ class MemoryGatewayService:
async def add_memory(
self,
*,
user_id: str,
session_id: str,
app_id: str,
project_id: str,
messages: list[dict[str, Any]],
) -> dict[str, Any]:
attachments, generated_paths = self._prepare_memory_attachments(
user_id=user_id,
session_id=session_id,
app_id=app_id,
project_id=project_id,
messages=messages,
)
payload = {
"session_id": session_id,
"app_id": app_id,
"project_id": project_id,
"messages": messages,
}
return {
"session_id": session_id,
"backend": await self.backend_client.add_memory(payload),
}
try:
backend = await self.backend_client.add_memory(payload)
for attachment in attachments:
self.repository.create_attachment(**attachment)
except Exception:
for path in generated_paths:
path.unlink(missing_ok=True)
_remove_empty_parents(path.parent, stop_at=self.config.storage_dir)
raise
return {"session_id": session_id, "backend": backend}
def _register_resource_attachment(self, resource: dict[str, Any]) -> None:
self.repository.create_attachment(
user_id=resource["user_id"],
app_id=resource["app_id"],
project_id=resource["project_id"],
session_id=resource["session_id"],
resource_id=resource["id"],
content_type=resource["content_type"],
name=resource["original_filename"] or resource["id"],
internal_uri=resource["uri"],
source="resource_upload",
sha256=resource["sha256"],
)
def _prepare_memory_attachments(
self,
*,
user_id: str,
session_id: str,
app_id: str,
project_id: str,
messages: list[dict[str, Any]],
) -> tuple[list[dict[str, Any]], list[Path]]:
attachments: list[dict[str, Any]] = []
generated_paths: list[Path] = []
try:
for message in messages:
content = message.get("content")
if not isinstance(content, list):
continue
for item in content:
if not isinstance(item, dict):
continue
uri = item.get("uri")
encoded = item.get("base64")
if not uri and not encoded:
continue
attachment_id = f"a_{uuid.uuid4().hex}"
name = _attachment_name(item, str(uri) if uri else None)
sha256: str | None = None
if uri:
internal_uri = str(uri)
source = "memory_add_uri"
else:
try:
data = base64.b64decode(str(encoded), validate=True)
except (binascii.Error, ValueError) as exc:
raise InvalidAttachment(
f"invalid base64 attachment: {name}"
) from exc
if len(data) > self.config.max_upload_bytes:
raise UploadTooLarge(
f"attachment exceeds max size of "
f"{self.config.max_upload_bytes} bytes"
)
sha256 = hashlib.sha256(data).hexdigest()
path = (
self.config.storage_dir
/ user_id
/ "memory_attachments"
/ sha256
/ name
)
if not path.exists():
path.parent.mkdir(parents=True, exist_ok=True)
path.write_bytes(data)
generated_paths.append(path)
internal_uri = path.resolve().as_uri()
source = "memory_add_base64"
attachments.append(
{
"id": attachment_id,
"user_id": user_id,
"app_id": app_id,
"project_id": project_id,
"session_id": session_id,
"resource_id": None,
"content_type": str(item.get("type") or "doc"),
"name": name,
"internal_uri": internal_uri,
"source": source,
"sha256": sha256,
}
)
except Exception:
for path in generated_paths:
path.unlink(missing_ok=True)
_remove_empty_parents(path.parent, stop_at=self.config.storage_dir)
raise
return attachments, generated_paths
async def flush_memory(
self,
@ -461,19 +600,29 @@ class MemoryGatewayService:
self,
*,
user_id: str,
agent_id: str | None,
query: str,
method: str,
top_k: int,
radius: float | None,
include_profile: bool,
enable_llm_rerank: bool,
app_id: str,
project_id: str,
filters: dict[str, Any] | None,
) -> dict[str, Any]:
payload: dict[str, Any] = {
"user_id": user_id,
"query": query,
"method": method,
"top_k": top_k,
"include_profile": include_profile,
"enable_llm_rerank": enable_llm_rerank,
"app_id": app_id,
"project_id": project_id,
}
payload["agent_id" if agent_id else "user_id"] = agent_id or user_id
if radius is not None:
payload["radius"] = radius
if filters is not None:
payload["filters"] = filters
return payload
@ -487,18 +636,22 @@ class MemoryGatewayService:
user_id: str,
) -> list[dict[str, Any]]:
data = response.get("data", {})
raw_items: list[dict[str, Any]] = []
for key in (
"episodes",
"profiles",
"agent_cases",
"agent_skills",
"unprocessed_messages",
):
raw_items.extend(data.get(key, []) or [])
raw_items: list[tuple[str, dict[str, Any]]] = []
memory_types = {
"episodes": "episode",
"profiles": "profile",
"agent_cases": "agent_case",
"agent_skills": "agent_skill",
"unprocessed_messages": "unprocessed_message",
}
for key, memory_type in memory_types.items():
raw_items.extend(
(memory_type, item) for item in (data.get(key, []) or [])
)
normalized = []
for raw in raw_items:
attachment_cache: dict[str, list[dict[str, Any]]] = {}
for memory_type, raw in raw_items:
session_id = raw.get("session_id")
resource = session_resource_map.get(session_id)
if resource is None and isinstance(session_id, str):
@ -506,9 +659,21 @@ class MemoryGatewayService:
session_id,
user_id,
)
attachments: list[dict[str, Any]] = []
if isinstance(session_id, str):
if session_id not in attachment_cache:
attachment_cache[session_id] = (
self.repository.list_attachments_for_session(
user_id,
session_id,
)
)
session_attachments = attachment_cache[session_id]
attachments = _matching_attachments(raw, session_attachments)
normalized.append(
{
"id": raw.get("id"),
"memory_type": memory_type,
"session_id": session_id,
"text": _display_text(raw),
"score": raw.get("score"),
@ -517,6 +682,7 @@ class MemoryGatewayService:
"resource_uri": (
public_resource_uri(user_id, resource["id"]) if resource else None
),
"attachments": attachments,
"raw": raw,
}
)
@ -623,6 +789,72 @@ class MemoryGatewayService:
}
def _combine_filters(
custom_filters: dict[str, Any] | None,
scope_filters: dict[str, Any] | None,
) -> dict[str, Any] | None:
if custom_filters is None:
return scope_filters
if scope_filters is None:
return custom_filters
return {"AND": [custom_filters, scope_filters]}
def _attachment_name(item: dict[str, Any], uri: str | None) -> str:
if item.get("name"):
return _safe_filename(str(item["name"]))
if uri:
parsed = urlparse(uri)
uri_name = Path(unquote(parsed.path)).name
if uri_name:
return _safe_filename(uri_name)
extension = str(item.get("ext") or "bin").lstrip(".") or "bin"
return f"attachment.{extension}"
def _matching_attachments(
raw: dict[str, Any],
attachments: list[dict[str, Any]],
) -> list[dict[str, Any]]:
strings = [value.casefold() for value in _raw_string_values(raw)]
matched: list[dict[str, Any]] = []
seen_uris: set[str] = set()
for attachment in attachments:
name = str(attachment["name"])
internal_uri = str(attachment["internal_uri"])
if internal_uri in seen_uris:
continue
if not any(name.casefold() in value for value in strings):
continue
seen_uris.add(internal_uri)
matched.append(
{
"type": attachment["content_type"],
"name": name,
"internal_uri": internal_uri,
}
)
return matched
def _raw_string_values(value: Any, key: str | None = None) -> list[str]:
if key is not None and key.casefold() == "base64":
return []
if isinstance(value, str):
return [value]
if isinstance(value, dict):
strings: list[str] = []
for item_key, item_value in value.items():
strings.extend(_raw_string_values(item_value, str(item_key)))
return strings
if isinstance(value, list):
strings = []
for item in value:
strings.extend(_raw_string_values(item))
return strings
return []
def _chunks(items: list[str], size: int) -> list[list[str]]:
if not items:
return []