138 lines
4.3 KiB
Python
138 lines
4.3 KiB
Python
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
from typing import Any
|
||
|
||
import aiohttp
|
||
|
||
from livekit.agents import APIConnectionError, APIStatusError, APITimeoutError, utils
|
||
|
||
# Module-wide logger. It is named explicitly (not via __name__), so handlers and
# levels for this client are configured under the "memory-recall" name.
logger: logging.Logger = logging.getLogger(name="memory-recall")
|
||
|
||
|
||
class MemoryRecallClient:
    """Fetch a room graph over HTTP and render it as an item-location prompt.

    Each ``recall`` performs a GET on ``url``. The most recent response that
    decoded to a JSON object (``dict``) is cached. Every failure path falls
    back to formatting that cached payload: timeout, connection error,
    non-200 status, undecodable body, or a non-object payload.
    """

    def __init__(
        self,
        *,
        url: str,
        timeout: float = 5.0,
        max_chars: int = 2000,
        api_key: str | None = None,
        http_session: aiohttp.ClientSession | None = None,
    ) -> None:
        """Store configuration; no network I/O happens here.

        Args:
            url: Endpoint fetched with GET on every ``recall``.
            timeout: Total per-request timeout, in seconds.
            max_chars: Length the formatted memory string is truncated to.
            api_key: Optional token sent as ``Authorization: Bearer <api_key>``.
            http_session: Session to reuse. When ``None``, one is obtained
                lazily from ``utils.http_context.http_session()``.
        """
        self._url = url
        self._timeout = timeout
        self._max_chars = max_chars
        self._api_key = api_key
        self._http_session = http_session
        # Last room-graph payload that decoded to a dict; served when a fresh
        # fetch fails or returns something unusable.
        self._cached_payload: Any | None = None

    def _ensure_session(self) -> aiohttp.ClientSession:
        """Return the HTTP session, obtaining one on first use.

        When no session was passed to ``__init__``, it is taken from
        ``utils.http_context.http_session()`` and reused afterwards.
        """
        if self._http_session is None:
            self._http_session = utils.http_context.http_session()
        return self._http_session

    async def recall(self, query: str) -> str:
        """Return a room-memory prompt for ``query``.

        Fetches the room graph and formats it. If the request fails (timeout,
        connection error, non-200 status, undecodable body) or the response is
        not a JSON object, the last cached payload is formatted instead.

        Args:
            query: The item the user is looking for. Surrounding whitespace is
                stripped. It is not sent with the request.

        Returns:
            The formatted memory (truncated to ``max_chars``), or ``""`` for a
            blank query or when no usable payload is available.
        """
        query = query.strip()
        if not query:
            return ""

        headers: dict[str, str] = {}
        if self._api_key:
            headers["Authorization"] = f"Bearer {self._api_key}"

        # Only the HTTP exchange lives in this try. Formatting happens after it,
        # so formatting bugs are not mistaken for network failures.
        try:
            async with self._ensure_session().get(
                self._url,
                headers=headers,
                timeout=aiohttp.ClientTimeout(total=self._timeout),
            ) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    # Raised here and caught below, so non-200 responses share
                    # the cache fallback with the other failure modes.
                    raise APIStatusError(
                        message=f"Memory recall error: {error_text}",
                        status_code=resp.status,
                        request_id=None,
                        body=error_text,
                    )

                try:
                    data = await resp.json()
                except aiohttp.ContentTypeError:
                    # Non-JSON content type: keep the raw text. It is rejected
                    # below as an unsupported payload.
                    data = await resp.text()
        except asyncio.TimeoutError:
            logger.warning(
                "Memory recall timed out after %.1fs, using cached room graph", self._timeout
            )
            return self._format_cached_memory(query)
        except aiohttp.ClientError as e:
            logger.warning("Memory recall connection error: %s, using cached room graph", e)
            return self._format_cached_memory(query)
        except (APIConnectionError, APIStatusError, APITimeoutError) as e:
            logger.warning("Memory recall failed: %s, using cached room graph", e)
            return self._format_cached_memory(query)
        except ValueError as e:
            # json.JSONDecodeError / UnicodeDecodeError from a malformed body.
            # Without this clause it would propagate to the caller.
            logger.warning(
                "Memory recall returned an undecodable body: %s, using cached room graph", e
            )
            return self._format_cached_memory(query)

        if not isinstance(data, dict):
            # Don't let an unusable body (e.g. plain text or a JSON array)
            # overwrite the last good room graph in the cache.
            logger.warning(
                "Unsupported room graph response type %s, using cached room graph",
                type(data).__name__,
            )
            return self._format_cached_memory(query)

        self._cached_payload = data
        return self._format_memory(data, query)

    def _format_memory(self, data: Any, query: str) -> str:
        """Format ``data`` for ``query`` and truncate the result to ``max_chars``."""
        memory = _format_room_graph_memory(data, query)
        if len(memory) > self._max_chars:
            # NOTE(review): _format_room_graph_memory places the query and the
            # answer instructions at the END of the prompt. Clipping the
            # finished string therefore drops them first for large room
            # graphs. Consider budgeting the objects/relations JSON instead.
            memory = memory[: self._max_chars].rstrip()
        return memory

    def _format_cached_memory(self, query: str) -> str:
        """Format the cached payload for ``query``; ``""`` when nothing is cached."""
        if self._cached_payload is None:
            return ""
        return self._format_memory(self._cached_payload, query)
|
||
|
||
|
||
def _format_room_graph_memory(payload: Any, query: str) -> str:
|
||
if not isinstance(payload, dict):
|
||
logger.warning("Unsupported room graph response: %s", payload)
|
||
return ""
|
||
objects = payload.get("objects", [])
|
||
relations = payload.get("relations", [])
|
||
summary = payload.get("summary", "")
|
||
usage_hint = payload.get("usage_hint", "")
|
||
|
||
if not objects and not relations and not summary:
|
||
return ""
|
||
|
||
objects_text = json.dumps(objects, ensure_ascii=False, indent=2)
|
||
relations_text = json.dumps(relations, ensure_ascii=False, indent=2)
|
||
|
||
prompt = f"""
|
||
你是一个物品定位助手。
|
||
|
||
我的房间内有以下物品信息:
|
||
|
||
{objects_text}
|
||
|
||
这些物品之间的空间关系如下:
|
||
|
||
{relations_text}
|
||
|
||
房间概览如下:
|
||
|
||
{summary}
|
||
|
||
现在我要找的目标物品是:{query}
|
||
|
||
请根据上面的 objects、relations 和 summary,告诉我它在哪里。
|
||
|
||
回答要求:
|
||
1. 只说明它和其他物品的位置关系。
|
||
2. 不要编造不存在的关系。
|
||
3. 如果信息不足,请说“根据当前房间记忆,无法确定准确位置”。
|
||
4. 回答尽量简短,例如:“黑色背包在透明塑料盒的左边,在显示器的左边。”
|
||
5. 如果用户当前输入不是找物品或问位置,可以忽略这段房间记忆。
|
||
""".strip()
|
||
|
||
if usage_hint:
|
||
prompt += f"\n\n接口使用提示:\n{usage_hint}"
|
||
|
||
return prompt
|