from __future__ import annotations import asyncio import json import logging import re from typing import Any import aiohttp from livekit.agents import APIConnectionError, APIStatusError, APITimeoutError, utils logger = logging.getLogger("memory-recall") _LOCATION_STOPWORDS = { "哪里", "在哪", "在哪里", "哪儿", "位置", "什么地方", "帮我找", "帮我寻找", "找一下", "找", "请问", "请", "吗", "呢", } class MemoryRecallClient: def __init__( self, *, url: str, timeout: float = 5.0, max_chars: int = 2000, api_key: str | None = None, http_session: aiohttp.ClientSession | None = None, ) -> None: self._url = url self._timeout = timeout self._max_chars = max_chars self._api_key = api_key self._http_session = http_session self._cached_payload: Any | None = None def _ensure_session(self) -> aiohttp.ClientSession: if self._http_session is None: self._http_session = utils.http_context.http_session() return self._http_session async def recall(self, query: str) -> str: query = query.strip() if not query: return "" headers = {} if self._api_key: headers["Authorization"] = f"Bearer {self._api_key}" try: async with self._ensure_session().get( self._url, headers=headers, timeout=aiohttp.ClientTimeout(total=self._timeout), ) as resp: if resp.status != 200: error_text = await resp.text() raise APIStatusError( message=f"Memory recall error: {error_text}", status_code=resp.status, request_id=None, body=error_text, ) try: data = await resp.json() except aiohttp.ContentTypeError: data = await resp.text() self._cached_payload = data return self._format_memory(data, query) except asyncio.TimeoutError: logger.warning( "Memory recall timed out after %.1fs, using cached room graph", self._timeout ) return self._format_cached_memory(query) except aiohttp.ClientError as e: logger.warning("Memory recall connection error: %s, using cached room graph", e) return self._format_cached_memory(query) except (APIConnectionError, APIStatusError, APITimeoutError) as e: logger.warning("Memory recall failed: %s, using cached room graph", e) return self._format_cached_memory(query) def _format_memory(self, data: Any, query: str) -> str: memory = _format_room_graph_memory(data, query) if len(memory) > self._max_chars: memory = memory[: self._max_chars].rstrip() return memory def _format_cached_memory(self, query: str) -> str: if self._cached_payload is None: return "" return self._format_memory(self._cached_payload, query) def _format_room_graph_memory(payload: Any, query: str) -> str: if not isinstance(payload, dict): logger.warning("Unsupported room graph response: %s", payload) return "" objects = payload.get("objects", []) relations = payload.get("relations", []) summary = payload.get("summary", "") if not objects and not relations and not summary: return "" query_terms = _query_terms(query) relevant_objects, relevant_relations = _relevant_room_graph( objects=objects, relations=relations, query_terms=query_terms, ) objects_text = json.dumps( relevant_objects or _compact_items(objects, limit=12), ensure_ascii=False, separators=(",", ":"), ) relations_text = json.dumps( relevant_relations or _compact_items(relations, limit=24), ensure_ascii=False, separators=(",", ":"), ) prompt = f""" 你是一个物品定位助手。 目标物品:{query} 相关物品:{objects_text} 相关空间关系:{relations_text} 房间概览:{summary} 回答要求: 1. 只说明它和其他物品的位置关系。 2. 不要编造不存在的关系。 3. 如果信息不足,请说“根据当前房间记忆,无法确定准确位置”。 4. 回答尽量简短,例如:“黑色背包在透明塑料盒的左边,在显示器的左边。” 5. 不要输出 Markdown、emoji、标题、列表、项目符号、坐标区域标签、水平/深度/高度分析或解释过程。 6. 不要回答 right-near-low、left-far-high 这类区域标签,只回答“在……的左边/右边/上方/下方/前面/后面/附近”等相对关系。 7. 如果用户当前输入不是找物品或问位置,可以忽略这段房间记忆。 """.strip() logger.info( "Formatted room memory: query_terms=%s, objects=%s/%s, relations=%s/%s, chars=%s", query_terms, len(relevant_objects), len(objects) if isinstance(objects, list) else 0, len(relevant_relations), len(relations) if isinstance(relations, list) else 0, len(prompt), ) return prompt def _query_terms(query: str) -> list[str]: normalized = re.sub(r"[\s??。!,、,.!]", "", query) for word in _LOCATION_STOPWORDS: normalized = normalized.replace(word, "") terms = [normalized] if normalized else [] for token in re.findall(r"[\u4e00-\u9fffA-Za-z0-9_-]{2,}", query): if token not in _LOCATION_STOPWORDS and token not in terms: terms.append(token) return terms[:4] def _relevant_room_graph( *, objects: Any, relations: Any, query_terms: list[str], ) -> tuple[list[Any], list[Any]]: if not isinstance(objects, list) or not isinstance(relations, list) or not query_terms: return [], [] matched_ids: set[str] = set() matched_objects: list[Any] = [] object_by_id: dict[str, Any] = {} for obj in objects: obj_id = _object_id(obj) if obj_id: object_by_id[obj_id] = obj obj_text = _compact_text(obj) if any(term and term in obj_text for term in query_terms): matched_objects.append(obj) if obj_id: matched_ids.add(obj_id) relevant_relations: list[Any] = [] related_ids: set[str] = set(matched_ids) for relation in relations: relation_text = _compact_text(relation) relation_ids = _ids_in_value(relation) if ( any(term and term in relation_text for term in query_terms) or bool(matched_ids.intersection(relation_ids)) ): relevant_relations.append(relation) related_ids.update(relation_ids) relevant_objects = list(matched_objects) seen_object_keys = {_object_key(obj) for obj in relevant_objects} for obj_id in related_ids: obj = object_by_id.get(obj_id) key = _object_key(obj) if obj is not None and key not in seen_object_keys: relevant_objects.append(obj) seen_object_keys.add(key) return _compact_items(relevant_objects, limit=16), _compact_items(relevant_relations, limit=32) def _compact_items(items: Any, *, limit: int) -> list[Any]: if not isinstance(items, list): return [] return [_compact_item(item) for item in items[:limit]] def _compact_item(item: Any) -> Any: if not isinstance(item, dict): return item preferred_keys = ( "id", "name", "label", "class", "category", "type", "text", "source", "target", "subject", "object", "relation", "predicate", "description", ) compact = {key: item[key] for key in preferred_keys if key in item and item[key] not in (None, "")} return compact or item def _object_id(obj: Any) -> str | None: if not isinstance(obj, dict): return None for key in ("id", "object_id", "uuid", "name", "label"): value = obj.get(key) if isinstance(value, (str, int)): return str(value) return None def _object_key(obj: Any) -> str: return _object_id(obj) or _compact_text(obj) def _ids_in_value(value: Any) -> set[str]: ids: set[str] = set() if isinstance(value, dict): for key, item in value.items(): if key in {"id", "object_id", "source", "target", "subject", "object", "from", "to"}: if isinstance(item, (str, int)): ids.add(str(item)) elif isinstance(item, dict): obj_id = _object_id(item) if obj_id: ids.add(obj_id) ids.update(_ids_in_value(item)) elif isinstance(value, list): for item in value: ids.update(_ids_in_value(item)) return ids def _compact_text(value: Any) -> str: return json.dumps(value, ensure_ascii=False, separators=(",", ":"))