feat: extend memory search and attachment mapping

This commit is contained in:
2026-06-15 17:25:44 +08:00
parent 15462a95cb
commit e5cd87789f
9 changed files with 1194 additions and 54 deletions

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
import base64
import binascii
import hashlib
import mimetypes
import secrets
@ -63,6 +64,10 @@ class UnsupportedContentType(ValueError):
pass
class InvalidAttachment(ValueError):
pass
def _copy_upload(
file: UploadFile,
destination: Path,
@ -180,6 +185,7 @@ class MemoryGatewayService:
)
if existing is not None:
shutil.rmtree(stored_path.parent, ignore_errors=True)
self._register_resource_attachment(existing)
return self._resource_summary(existing)
internal_uri = stored_path.resolve().as_uri()
@ -202,6 +208,7 @@ class MemoryGatewayService:
status="ingesting",
error_message=None,
)
self._register_resource_attachment(resource)
try:
await self._retry_backend_call(
@ -346,10 +353,16 @@ class MemoryGatewayService:
self,
*,
user_id: str,
agent_id: str | None,
query: str,
conversation_id: str | None,
scope: list[str],
method: str,
top_k: int,
radius: float | None,
include_profile: bool,
enable_llm_rerank: bool,
filters: dict[str, Any] | None,
app_id: str,
project_id: str,
) -> dict[str, Any]:
@ -359,11 +372,19 @@ class MemoryGatewayService:
if "current_chat" in scope and conversation_id:
payload = self._search_payload(
user_id=user_id,
agent_id=agent_id,
query=query,
method=method,
top_k=top_k,
radius=radius,
include_profile=include_profile,
enable_llm_rerank=enable_llm_rerank,
app_id=app_id,
project_id=project_id,
filters={"session_id": f"chat:{conversation_id}"},
filters=_combine_filters(
filters,
{"session_id": f"chat:{conversation_id}"},
),
)
results.extend(
self._extract_results(
@ -385,11 +406,19 @@ class MemoryGatewayService:
for batch in _chunks(session_ids, self.config.resource_search_batch_size):
payload = self._search_payload(
user_id=user_id,
agent_id=agent_id,
query=query,
method=method,
top_k=top_k,
radius=radius,
include_profile=include_profile,
enable_llm_rerank=enable_llm_rerank,
app_id=app_id,
project_id=project_id,
filters={"session_id": {"in": batch}},
filters=_combine_filters(
filters,
{"session_id": {"in": batch}},
),
)
results.extend(
self._extract_results(
@ -403,11 +432,16 @@ class MemoryGatewayService:
if "all_user_memory" in scope:
payload = self._search_payload(
user_id=user_id,
agent_id=agent_id,
query=query,
method=method,
top_k=top_k,
radius=radius,
include_profile=include_profile,
enable_llm_rerank=enable_llm_rerank,
app_id=app_id,
project_id=project_id,
filters=None,
filters=filters,
)
results.extend(
self._extract_results(
@ -425,21 +459,126 @@ class MemoryGatewayService:
async def add_memory(
self,
*,
user_id: str,
session_id: str,
app_id: str,
project_id: str,
messages: list[dict[str, Any]],
) -> dict[str, Any]:
attachments, generated_paths = self._prepare_memory_attachments(
user_id=user_id,
session_id=session_id,
app_id=app_id,
project_id=project_id,
messages=messages,
)
payload = {
"session_id": session_id,
"app_id": app_id,
"project_id": project_id,
"messages": messages,
}
return {
"session_id": session_id,
"backend": await self.backend_client.add_memory(payload),
}
try:
backend = await self.backend_client.add_memory(payload)
for attachment in attachments:
self.repository.create_attachment(**attachment)
except Exception:
for path in generated_paths:
path.unlink(missing_ok=True)
_remove_empty_parents(path.parent, stop_at=self.config.storage_dir)
raise
return {"session_id": session_id, "backend": backend}
def _register_resource_attachment(self, resource: dict[str, Any]) -> None:
self.repository.create_attachment(
user_id=resource["user_id"],
app_id=resource["app_id"],
project_id=resource["project_id"],
session_id=resource["session_id"],
resource_id=resource["id"],
content_type=resource["content_type"],
name=resource["original_filename"] or resource["id"],
internal_uri=resource["uri"],
source="resource_upload",
sha256=resource["sha256"],
)
def _prepare_memory_attachments(
self,
*,
user_id: str,
session_id: str,
app_id: str,
project_id: str,
messages: list[dict[str, Any]],
) -> tuple[list[dict[str, Any]], list[Path]]:
attachments: list[dict[str, Any]] = []
generated_paths: list[Path] = []
try:
for message in messages:
content = message.get("content")
if not isinstance(content, list):
continue
for item in content:
if not isinstance(item, dict):
continue
uri = item.get("uri")
encoded = item.get("base64")
if not uri and not encoded:
continue
attachment_id = f"a_{uuid.uuid4().hex}"
name = _attachment_name(item, str(uri) if uri else None)
sha256: str | None = None
if uri:
internal_uri = str(uri)
source = "memory_add_uri"
else:
try:
data = base64.b64decode(str(encoded), validate=True)
except (binascii.Error, ValueError) as exc:
raise InvalidAttachment(
f"invalid base64 attachment: {name}"
) from exc
if len(data) > self.config.max_upload_bytes:
raise UploadTooLarge(
f"attachment exceeds max size of "
f"{self.config.max_upload_bytes} bytes"
)
sha256 = hashlib.sha256(data).hexdigest()
path = (
self.config.storage_dir
/ user_id
/ "memory_attachments"
/ sha256
/ name
)
if not path.exists():
path.parent.mkdir(parents=True, exist_ok=True)
path.write_bytes(data)
generated_paths.append(path)
internal_uri = path.resolve().as_uri()
source = "memory_add_base64"
attachments.append(
{
"id": attachment_id,
"user_id": user_id,
"app_id": app_id,
"project_id": project_id,
"session_id": session_id,
"resource_id": None,
"content_type": str(item.get("type") or "doc"),
"name": name,
"internal_uri": internal_uri,
"source": source,
"sha256": sha256,
}
)
except Exception:
for path in generated_paths:
path.unlink(missing_ok=True)
_remove_empty_parents(path.parent, stop_at=self.config.storage_dir)
raise
return attachments, generated_paths
async def flush_memory(
self,
@ -461,19 +600,29 @@ class MemoryGatewayService:
self,
*,
user_id: str,
agent_id: str | None,
query: str,
method: str,
top_k: int,
radius: float | None,
include_profile: bool,
enable_llm_rerank: bool,
app_id: str,
project_id: str,
filters: dict[str, Any] | None,
) -> dict[str, Any]:
payload: dict[str, Any] = {
"user_id": user_id,
"query": query,
"method": method,
"top_k": top_k,
"include_profile": include_profile,
"enable_llm_rerank": enable_llm_rerank,
"app_id": app_id,
"project_id": project_id,
}
payload["agent_id" if agent_id else "user_id"] = agent_id or user_id
if radius is not None:
payload["radius"] = radius
if filters is not None:
payload["filters"] = filters
return payload
@ -487,18 +636,22 @@ class MemoryGatewayService:
user_id: str,
) -> list[dict[str, Any]]:
data = response.get("data", {})
raw_items: list[dict[str, Any]] = []
for key in (
"episodes",
"profiles",
"agent_cases",
"agent_skills",
"unprocessed_messages",
):
raw_items.extend(data.get(key, []) or [])
raw_items: list[tuple[str, dict[str, Any]]] = []
memory_types = {
"episodes": "episode",
"profiles": "profile",
"agent_cases": "agent_case",
"agent_skills": "agent_skill",
"unprocessed_messages": "unprocessed_message",
}
for key, memory_type in memory_types.items():
raw_items.extend(
(memory_type, item) for item in (data.get(key, []) or [])
)
normalized = []
for raw in raw_items:
attachment_cache: dict[str, list[dict[str, Any]]] = {}
for memory_type, raw in raw_items:
session_id = raw.get("session_id")
resource = session_resource_map.get(session_id)
if resource is None and isinstance(session_id, str):
@ -506,9 +659,21 @@ class MemoryGatewayService:
session_id,
user_id,
)
attachments: list[dict[str, Any]] = []
if isinstance(session_id, str):
if session_id not in attachment_cache:
attachment_cache[session_id] = (
self.repository.list_attachments_for_session(
user_id,
session_id,
)
)
session_attachments = attachment_cache[session_id]
attachments = _matching_attachments(raw, session_attachments)
normalized.append(
{
"id": raw.get("id"),
"memory_type": memory_type,
"session_id": session_id,
"text": _display_text(raw),
"score": raw.get("score"),
@ -517,6 +682,7 @@ class MemoryGatewayService:
"resource_uri": (
public_resource_uri(user_id, resource["id"]) if resource else None
),
"attachments": attachments,
"raw": raw,
}
)
@ -623,6 +789,72 @@ class MemoryGatewayService:
}
def _combine_filters(
custom_filters: dict[str, Any] | None,
scope_filters: dict[str, Any] | None,
) -> dict[str, Any] | None:
if custom_filters is None:
return scope_filters
if scope_filters is None:
return custom_filters
return {"AND": [custom_filters, scope_filters]}
def _attachment_name(item: dict[str, Any], uri: str | None) -> str:
if item.get("name"):
return _safe_filename(str(item["name"]))
if uri:
parsed = urlparse(uri)
uri_name = Path(unquote(parsed.path)).name
if uri_name:
return _safe_filename(uri_name)
extension = str(item.get("ext") or "bin").lstrip(".") or "bin"
return f"attachment.{extension}"
def _matching_attachments(
raw: dict[str, Any],
attachments: list[dict[str, Any]],
) -> list[dict[str, Any]]:
strings = [value.casefold() for value in _raw_string_values(raw)]
matched: list[dict[str, Any]] = []
seen_uris: set[str] = set()
for attachment in attachments:
name = str(attachment["name"])
internal_uri = str(attachment["internal_uri"])
if internal_uri in seen_uris:
continue
if not any(name.casefold() in value for value in strings):
continue
seen_uris.add(internal_uri)
matched.append(
{
"type": attachment["content_type"],
"name": name,
"internal_uri": internal_uri,
}
)
return matched
def _raw_string_values(value: Any, key: str | None = None) -> list[str]:
if key is not None and key.casefold() == "base64":
return []
if isinstance(value, str):
return [value]
if isinstance(value, dict):
strings: list[str] = []
for item_key, item_value in value.items():
strings.extend(_raw_string_values(item_value, str(item_key)))
return strings
if isinstance(value, list):
strings = []
for item in value:
strings.extend(_raw_string_values(item))
return strings
return []
def _chunks(items: list[str], size: int) -> list[list[str]]:
if not items:
return []