feat: extend memory search and attachment mapping
This commit is contained in:
268
core/service.py
268
core/service.py
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import binascii
|
||||
import hashlib
|
||||
import mimetypes
|
||||
import secrets
|
||||
@ -63,6 +64,10 @@ class UnsupportedContentType(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidAttachment(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
def _copy_upload(
|
||||
file: UploadFile,
|
||||
destination: Path,
|
||||
@ -180,6 +185,7 @@ class MemoryGatewayService:
|
||||
)
|
||||
if existing is not None:
|
||||
shutil.rmtree(stored_path.parent, ignore_errors=True)
|
||||
self._register_resource_attachment(existing)
|
||||
return self._resource_summary(existing)
|
||||
|
||||
internal_uri = stored_path.resolve().as_uri()
|
||||
@ -202,6 +208,7 @@ class MemoryGatewayService:
|
||||
status="ingesting",
|
||||
error_message=None,
|
||||
)
|
||||
self._register_resource_attachment(resource)
|
||||
|
||||
try:
|
||||
await self._retry_backend_call(
|
||||
@ -346,10 +353,16 @@ class MemoryGatewayService:
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
agent_id: str | None,
|
||||
query: str,
|
||||
conversation_id: str | None,
|
||||
scope: list[str],
|
||||
method: str,
|
||||
top_k: int,
|
||||
radius: float | None,
|
||||
include_profile: bool,
|
||||
enable_llm_rerank: bool,
|
||||
filters: dict[str, Any] | None,
|
||||
app_id: str,
|
||||
project_id: str,
|
||||
) -> dict[str, Any]:
|
||||
@ -359,11 +372,19 @@ class MemoryGatewayService:
|
||||
if "current_chat" in scope and conversation_id:
|
||||
payload = self._search_payload(
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
query=query,
|
||||
method=method,
|
||||
top_k=top_k,
|
||||
radius=radius,
|
||||
include_profile=include_profile,
|
||||
enable_llm_rerank=enable_llm_rerank,
|
||||
app_id=app_id,
|
||||
project_id=project_id,
|
||||
filters={"session_id": f"chat:{conversation_id}"},
|
||||
filters=_combine_filters(
|
||||
filters,
|
||||
{"session_id": f"chat:{conversation_id}"},
|
||||
),
|
||||
)
|
||||
results.extend(
|
||||
self._extract_results(
|
||||
@ -385,11 +406,19 @@ class MemoryGatewayService:
|
||||
for batch in _chunks(session_ids, self.config.resource_search_batch_size):
|
||||
payload = self._search_payload(
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
query=query,
|
||||
method=method,
|
||||
top_k=top_k,
|
||||
radius=radius,
|
||||
include_profile=include_profile,
|
||||
enable_llm_rerank=enable_llm_rerank,
|
||||
app_id=app_id,
|
||||
project_id=project_id,
|
||||
filters={"session_id": {"in": batch}},
|
||||
filters=_combine_filters(
|
||||
filters,
|
||||
{"session_id": {"in": batch}},
|
||||
),
|
||||
)
|
||||
results.extend(
|
||||
self._extract_results(
|
||||
@ -403,11 +432,16 @@ class MemoryGatewayService:
|
||||
if "all_user_memory" in scope:
|
||||
payload = self._search_payload(
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
query=query,
|
||||
method=method,
|
||||
top_k=top_k,
|
||||
radius=radius,
|
||||
include_profile=include_profile,
|
||||
enable_llm_rerank=enable_llm_rerank,
|
||||
app_id=app_id,
|
||||
project_id=project_id,
|
||||
filters=None,
|
||||
filters=filters,
|
||||
)
|
||||
results.extend(
|
||||
self._extract_results(
|
||||
@ -425,21 +459,126 @@ class MemoryGatewayService:
|
||||
async def add_memory(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
app_id: str,
|
||||
project_id: str,
|
||||
messages: list[dict[str, Any]],
|
||||
) -> dict[str, Any]:
|
||||
attachments, generated_paths = self._prepare_memory_attachments(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
app_id=app_id,
|
||||
project_id=project_id,
|
||||
messages=messages,
|
||||
)
|
||||
payload = {
|
||||
"session_id": session_id,
|
||||
"app_id": app_id,
|
||||
"project_id": project_id,
|
||||
"messages": messages,
|
||||
}
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"backend": await self.backend_client.add_memory(payload),
|
||||
}
|
||||
try:
|
||||
backend = await self.backend_client.add_memory(payload)
|
||||
for attachment in attachments:
|
||||
self.repository.create_attachment(**attachment)
|
||||
except Exception:
|
||||
for path in generated_paths:
|
||||
path.unlink(missing_ok=True)
|
||||
_remove_empty_parents(path.parent, stop_at=self.config.storage_dir)
|
||||
raise
|
||||
return {"session_id": session_id, "backend": backend}
|
||||
|
||||
def _register_resource_attachment(self, resource: dict[str, Any]) -> None:
|
||||
self.repository.create_attachment(
|
||||
user_id=resource["user_id"],
|
||||
app_id=resource["app_id"],
|
||||
project_id=resource["project_id"],
|
||||
session_id=resource["session_id"],
|
||||
resource_id=resource["id"],
|
||||
content_type=resource["content_type"],
|
||||
name=resource["original_filename"] or resource["id"],
|
||||
internal_uri=resource["uri"],
|
||||
source="resource_upload",
|
||||
sha256=resource["sha256"],
|
||||
)
|
||||
|
||||
def _prepare_memory_attachments(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
app_id: str,
|
||||
project_id: str,
|
||||
messages: list[dict[str, Any]],
|
||||
) -> tuple[list[dict[str, Any]], list[Path]]:
|
||||
attachments: list[dict[str, Any]] = []
|
||||
generated_paths: list[Path] = []
|
||||
try:
|
||||
for message in messages:
|
||||
content = message.get("content")
|
||||
if not isinstance(content, list):
|
||||
continue
|
||||
for item in content:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
uri = item.get("uri")
|
||||
encoded = item.get("base64")
|
||||
if not uri and not encoded:
|
||||
continue
|
||||
attachment_id = f"a_{uuid.uuid4().hex}"
|
||||
name = _attachment_name(item, str(uri) if uri else None)
|
||||
sha256: str | None = None
|
||||
if uri:
|
||||
internal_uri = str(uri)
|
||||
source = "memory_add_uri"
|
||||
else:
|
||||
try:
|
||||
data = base64.b64decode(str(encoded), validate=True)
|
||||
except (binascii.Error, ValueError) as exc:
|
||||
raise InvalidAttachment(
|
||||
f"invalid base64 attachment: {name}"
|
||||
) from exc
|
||||
if len(data) > self.config.max_upload_bytes:
|
||||
raise UploadTooLarge(
|
||||
f"attachment exceeds max size of "
|
||||
f"{self.config.max_upload_bytes} bytes"
|
||||
)
|
||||
sha256 = hashlib.sha256(data).hexdigest()
|
||||
path = (
|
||||
self.config.storage_dir
|
||||
/ user_id
|
||||
/ "memory_attachments"
|
||||
/ sha256
|
||||
/ name
|
||||
)
|
||||
if not path.exists():
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_bytes(data)
|
||||
generated_paths.append(path)
|
||||
internal_uri = path.resolve().as_uri()
|
||||
source = "memory_add_base64"
|
||||
attachments.append(
|
||||
{
|
||||
"id": attachment_id,
|
||||
"user_id": user_id,
|
||||
"app_id": app_id,
|
||||
"project_id": project_id,
|
||||
"session_id": session_id,
|
||||
"resource_id": None,
|
||||
"content_type": str(item.get("type") or "doc"),
|
||||
"name": name,
|
||||
"internal_uri": internal_uri,
|
||||
"source": source,
|
||||
"sha256": sha256,
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
for path in generated_paths:
|
||||
path.unlink(missing_ok=True)
|
||||
_remove_empty_parents(path.parent, stop_at=self.config.storage_dir)
|
||||
raise
|
||||
return attachments, generated_paths
|
||||
|
||||
async def flush_memory(
|
||||
self,
|
||||
@ -461,19 +600,29 @@ class MemoryGatewayService:
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
agent_id: str | None,
|
||||
query: str,
|
||||
method: str,
|
||||
top_k: int,
|
||||
radius: float | None,
|
||||
include_profile: bool,
|
||||
enable_llm_rerank: bool,
|
||||
app_id: str,
|
||||
project_id: str,
|
||||
filters: dict[str, Any] | None,
|
||||
) -> dict[str, Any]:
|
||||
payload: dict[str, Any] = {
|
||||
"user_id": user_id,
|
||||
"query": query,
|
||||
"method": method,
|
||||
"top_k": top_k,
|
||||
"include_profile": include_profile,
|
||||
"enable_llm_rerank": enable_llm_rerank,
|
||||
"app_id": app_id,
|
||||
"project_id": project_id,
|
||||
}
|
||||
payload["agent_id" if agent_id else "user_id"] = agent_id or user_id
|
||||
if radius is not None:
|
||||
payload["radius"] = radius
|
||||
if filters is not None:
|
||||
payload["filters"] = filters
|
||||
return payload
|
||||
@ -487,18 +636,22 @@ class MemoryGatewayService:
|
||||
user_id: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
data = response.get("data", {})
|
||||
raw_items: list[dict[str, Any]] = []
|
||||
for key in (
|
||||
"episodes",
|
||||
"profiles",
|
||||
"agent_cases",
|
||||
"agent_skills",
|
||||
"unprocessed_messages",
|
||||
):
|
||||
raw_items.extend(data.get(key, []) or [])
|
||||
raw_items: list[tuple[str, dict[str, Any]]] = []
|
||||
memory_types = {
|
||||
"episodes": "episode",
|
||||
"profiles": "profile",
|
||||
"agent_cases": "agent_case",
|
||||
"agent_skills": "agent_skill",
|
||||
"unprocessed_messages": "unprocessed_message",
|
||||
}
|
||||
for key, memory_type in memory_types.items():
|
||||
raw_items.extend(
|
||||
(memory_type, item) for item in (data.get(key, []) or [])
|
||||
)
|
||||
|
||||
normalized = []
|
||||
for raw in raw_items:
|
||||
attachment_cache: dict[str, list[dict[str, Any]]] = {}
|
||||
for memory_type, raw in raw_items:
|
||||
session_id = raw.get("session_id")
|
||||
resource = session_resource_map.get(session_id)
|
||||
if resource is None and isinstance(session_id, str):
|
||||
@ -506,9 +659,21 @@ class MemoryGatewayService:
|
||||
session_id,
|
||||
user_id,
|
||||
)
|
||||
attachments: list[dict[str, Any]] = []
|
||||
if isinstance(session_id, str):
|
||||
if session_id not in attachment_cache:
|
||||
attachment_cache[session_id] = (
|
||||
self.repository.list_attachments_for_session(
|
||||
user_id,
|
||||
session_id,
|
||||
)
|
||||
)
|
||||
session_attachments = attachment_cache[session_id]
|
||||
attachments = _matching_attachments(raw, session_attachments)
|
||||
normalized.append(
|
||||
{
|
||||
"id": raw.get("id"),
|
||||
"memory_type": memory_type,
|
||||
"session_id": session_id,
|
||||
"text": _display_text(raw),
|
||||
"score": raw.get("score"),
|
||||
@ -517,6 +682,7 @@ class MemoryGatewayService:
|
||||
"resource_uri": (
|
||||
public_resource_uri(user_id, resource["id"]) if resource else None
|
||||
),
|
||||
"attachments": attachments,
|
||||
"raw": raw,
|
||||
}
|
||||
)
|
||||
@ -623,6 +789,72 @@ class MemoryGatewayService:
|
||||
}
|
||||
|
||||
|
||||
def _combine_filters(
|
||||
custom_filters: dict[str, Any] | None,
|
||||
scope_filters: dict[str, Any] | None,
|
||||
) -> dict[str, Any] | None:
|
||||
if custom_filters is None:
|
||||
return scope_filters
|
||||
if scope_filters is None:
|
||||
return custom_filters
|
||||
return {"AND": [custom_filters, scope_filters]}
|
||||
|
||||
|
||||
def _attachment_name(item: dict[str, Any], uri: str | None) -> str:
|
||||
if item.get("name"):
|
||||
return _safe_filename(str(item["name"]))
|
||||
if uri:
|
||||
parsed = urlparse(uri)
|
||||
uri_name = Path(unquote(parsed.path)).name
|
||||
if uri_name:
|
||||
return _safe_filename(uri_name)
|
||||
extension = str(item.get("ext") or "bin").lstrip(".") or "bin"
|
||||
return f"attachment.{extension}"
|
||||
|
||||
|
||||
def _matching_attachments(
|
||||
raw: dict[str, Any],
|
||||
attachments: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
strings = [value.casefold() for value in _raw_string_values(raw)]
|
||||
matched: list[dict[str, Any]] = []
|
||||
seen_uris: set[str] = set()
|
||||
for attachment in attachments:
|
||||
name = str(attachment["name"])
|
||||
internal_uri = str(attachment["internal_uri"])
|
||||
if internal_uri in seen_uris:
|
||||
continue
|
||||
if not any(name.casefold() in value for value in strings):
|
||||
continue
|
||||
seen_uris.add(internal_uri)
|
||||
matched.append(
|
||||
{
|
||||
"type": attachment["content_type"],
|
||||
"name": name,
|
||||
"internal_uri": internal_uri,
|
||||
}
|
||||
)
|
||||
return matched
|
||||
|
||||
|
||||
def _raw_string_values(value: Any, key: str | None = None) -> list[str]:
|
||||
if key is not None and key.casefold() == "base64":
|
||||
return []
|
||||
if isinstance(value, str):
|
||||
return [value]
|
||||
if isinstance(value, dict):
|
||||
strings: list[str] = []
|
||||
for item_key, item_value in value.items():
|
||||
strings.extend(_raw_string_values(item_value, str(item_key)))
|
||||
return strings
|
||||
if isinstance(value, list):
|
||||
strings = []
|
||||
for item in value:
|
||||
strings.extend(_raw_string_values(item))
|
||||
return strings
|
||||
return []
|
||||
|
||||
|
||||
def _chunks(items: list[str], size: int) -> list[list[str]]:
|
||||
if not items:
|
||||
return []
|
||||
|
||||
Reference in New Issue
Block a user