from __future__ import annotations

import asyncio
import json
import logging
from typing import Any

import aiohttp
from livekit.agents import APIConnectionError, APIStatusError, APITimeoutError, utils

logger = logging.getLogger("memory-recall")

# Appended to a data section that had to be shortened to fit the character budget.
_TRUNCATION_MARKER = "…"


class MemoryRecallClient:
    """Fetch a room-graph snapshot over HTTP and render it as an item-location prompt.

    The most recent room graph (a JSON object) that was fetched successfully is
    kept in memory and reused whenever a later request fails: timeout,
    connection error, non-200 status, or a body that is not a JSON object.
    """

    def __init__(
        self,
        *,
        url: str,
        timeout: float = 5.0,
        max_chars: int = 2000,
        api_key: str | None = None,
        http_session: aiohttp.ClientSession | None = None,
    ) -> None:
        """
        Args:
            url: Endpoint that returns the room graph; requested with GET.
            timeout: Total request timeout in seconds.
            max_chars: Upper bound on the length of the prompt returned by ``recall``.
            api_key: Optional token sent as ``Authorization: Bearer <api_key>``.
            http_session: Session to reuse; when omitted, livekit's
                ``utils.http_context.http_session()`` is fetched lazily on first use.
        """
        self._url = url
        self._timeout = timeout
        self._max_chars = max_chars
        self._api_key = api_key
        self._http_session = http_session
        # Last successfully fetched room graph (always a dict); fallback on failures.
        self._cached_payload: Any | None = None

    def _ensure_session(self) -> aiohttp.ClientSession:
        """Return the HTTP session, lazily taking livekit's shared session if none was given."""
        if self._http_session is None:
            self._http_session = utils.http_context.http_session()
        return self._http_session

    async def recall(self, query: str) -> str:
        """Build a location-lookup prompt for ``query`` from the current room graph.

        Args:
            query: The item the user is looking for; surrounding whitespace is ignored.

        Returns:
            The prompt (at most ``max_chars`` characters), or "" when ``query`` is
            blank, the room graph is empty, or the fetch failed and nothing is cached.
            Fetch failures never raise; they fall back to the cached room graph.
        """
        query = query.strip()
        if not query:
            return ""

        headers: dict[str, str] = {}
        if self._api_key:
            headers["Authorization"] = f"Bearer {self._api_key}"

        try:
            async with self._ensure_session().get(
                self._url,
                headers=headers,
                timeout=aiohttp.ClientTimeout(total=self._timeout),
            ) as resp:
                if resp.status != 200:
                    # errors="replace": a badly encoded error page must not escape as
                    # UnicodeDecodeError instead of taking the fallback path below.
                    error_text = await resp.text(errors="replace")
                    raise APIStatusError(
                        message=f"Memory recall error: {error_text}",
                        status_code=resp.status,
                        request_id=None,
                        body=error_text,
                    )
                body = await resp.text(errors="replace")
        except asyncio.TimeoutError:
            logger.warning(
                "Memory recall timed out after %.1fs, using cached room graph", self._timeout
            )
            return self._format_cached_memory(query)
        except aiohttp.ClientError as e:
            logger.warning("Memory recall connection error: %s, using cached room graph", e)
            return self._format_cached_memory(query)
        except (APIConnectionError, APIStatusError, APITimeoutError) as e:
            logger.warning("Memory recall failed: %s, using cached room graph", e)
            return self._format_cached_memory(query)

        # Parse outside the `async with` so the connection is released first. Parsing
        # the text ourselves (instead of resp.json()) means a malformed body is handled
        # here rather than escaping as json.JSONDecodeError.
        try:
            data = json.loads(body)
        except ValueError:
            logger.warning("Memory recall returned a non-JSON body, using cached room graph")
            return self._format_cached_memory(query)

        if not isinstance(data, dict):
            # Never cache an unusable payload: it would overwrite a good room graph.
            logger.warning(
                "Unsupported room graph response type %s, using cached room graph",
                type(data).__name__,
            )
            return self._format_cached_memory(query)

        self._cached_payload = data
        return self._format_memory(data, query)

    def _format_memory(self, data: Any, query: str) -> str:
        """Render ``data`` as a prompt for ``query`` within ``self._max_chars`` characters."""
        return _format_room_graph_memory(data, query, max_chars=self._max_chars)

    def _format_cached_memory(self, query: str) -> str:
        """Render the cached room graph for ``query``, or "" when nothing is cached."""
        if self._cached_payload is None:
            return ""
        return self._format_memory(self._cached_payload, query)


def _render_room_graph_prompt(
    objects_text: str, relations_text: str, summary: str, query: str
) -> str:
    """Fill the fixed (Chinese) item-location prompt template with the given sections."""
    return f"""
你是一个物品定位助手。

我的房间内有以下物品信息:
{objects_text}

这些物品之间的空间关系如下:
{relations_text}

房间概览如下:
{summary}

现在我要找的目标物品是:{query}

请根据上面的 objects、relations 和 summary,告诉我它在哪里。

回答要求:
1. 只说明它和其他物品的位置关系。
2. 不要编造不存在的关系。
3. 如果信息不足,请说“根据当前房间记忆,无法确定准确位置”。
4. 回答尽量简短,例如:“黑色背包在透明塑料盒的左边,在显示器的左边。”
5. 不要输出 Markdown、emoji、标题、列表、项目符号、坐标区域标签、水平/深度/高度分析或解释过程。
6. 不要回答 right-near-low、left-far-high 这类区域标签,只回答“在……的左边/右边/上方/下方/前面/后面/附近”等相对关系。
7. 如果用户当前输入不是找物品或问位置,可以忽略这段房间记忆。
""".strip()


def _truncate_section(text: str, limit: int) -> str:
    """Shorten ``text`` to at most ``limit`` characters, marking any cut with an ellipsis."""
    if len(text) <= limit:
        return text
    if limit <= 0:
        return ""
    return text[: limit - len(_TRUNCATION_MARKER)].rstrip() + _TRUNCATION_MARKER


def _format_room_graph_memory(payload: Any, query: str, max_chars: int | None = None) -> str:
    """Build the item-location prompt from a room-graph payload.

    Args:
        payload: Expected to be a dict with optional ``objects``, ``relations`` and
            ``summary`` keys; anything else is logged and yields "".
        query: The item being searched for; embedded verbatim in the prompt.
        max_chars: Optional length cap. When exceeded, only the objects / relations /
            summary sections are shortened (proportionally to their size), so the
            query and answer rules at the end of the prompt are never cut off. If
            even the fixed instructions exceed the cap, the prompt is hard-cut.

    Returns:
        The prompt, or "" when the payload is unsupported or carries no data.
    """
    if not isinstance(payload, dict):
        logger.warning("Unsupported room graph response: %s", payload)
        return ""

    # `or` also maps explicit JSON nulls to empty values (otherwise rendered as "None").
    objects = payload.get("objects") or []
    relations = payload.get("relations") or []
    summary = payload.get("summary") or ""
    if not objects and not relations and not summary:
        return ""

    sections = [
        json.dumps(objects, ensure_ascii=False, indent=2),
        json.dumps(relations, ensure_ascii=False, indent=2),
        str(summary),
    ]
    prompt = _render_room_graph_prompt(*sections, query)
    if max_chars is None or len(prompt) <= max_chars:
        return prompt

    # The sections are interior to the template, so the prompt length is the
    # template length plus the section lengths; give the sections what is left.
    budget = max_chars - len(_render_room_graph_prompt("", "", "", query))
    total = sum(len(section) for section in sections)
    if budget <= 0 or total == 0:
        # Not even the fixed instructions fit: fall back to a plain hard cut.
        return prompt[:max_chars].rstrip()

    shrunk = [_truncate_section(section, budget * len(section) // total) for section in sections]
    prompt = _render_room_graph_prompt(*shrunk, query)
    if len(prompt) > max_chars:  # defensive; the budget arithmetic should prevent this
        prompt = prompt[:max_chars].rstrip()
    return prompt