Files
livekit_agents/memory.py
2026-05-14 10:16:08 +08:00

138 lines
4.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
import asyncio
import json
import logging
from typing import Any
import aiohttp
from livekit.agents import APIConnectionError, APIStatusError, APITimeoutError, utils
# Module logger; recall failures and unsupported room-graph responses are reported here as warnings.
logger = logging.getLogger("memory-recall")
class MemoryRecallClient:
    """Fetch a room-graph snapshot over HTTP and render it as an item-locating prompt.

    Every successful ``recall`` caches the decoded room graph. Later failures
    (timeouts, connection errors, non-200 statuses, unusable bodies) then
    answer from the last good snapshot instead of returning nothing.
    """

    def __init__(
        self,
        *,
        url: str,
        timeout: float = 5.0,
        max_chars: int = 2000,
        api_key: str | None = None,
        http_session: aiohttp.ClientSession | None = None,
    ) -> None:
        """Configure the client.

        Args:
            url: Endpoint fetched with a plain GET; the query itself is not sent.
            timeout: Total request timeout in seconds.
            max_chars: Upper bound on the length of the returned prompt.
            api_key: Optional bearer token sent in the ``Authorization`` header.
            http_session: Session to reuse; when omitted, one is obtained lazily
                from ``utils.http_context.http_session()``.
        """
        self._url = url
        self._timeout = timeout
        self._max_chars = max_chars
        self._api_key = api_key
        self._http_session = http_session
        # Last usable (dict) room-graph payload, used as the fallback on failure.
        self._cached_payload: Any | None = None

    def _ensure_session(self) -> aiohttp.ClientSession:
        """Return the configured HTTP session, lazily fetching the shared one."""
        if self._http_session is None:
            self._http_session = utils.http_context.http_session()
        return self._http_session

    async def recall(self, query: str) -> str:
        """Return an item-locating prompt for ``query`` built from the room graph.

        The room graph is fetched on every call. If the request times out,
        fails to connect, returns a non-200 status, or yields a body that is
        not a JSON object, the last cached room graph is used instead
        (``""`` if nothing has been cached yet).

        Args:
            query: The item the user is looking for; surrounding whitespace
                is ignored.

        Returns:
            The formatted prompt, at most ``max_chars`` characters long, or
            ``""`` for a blank query or when no usable room graph exists.
        """
        query = query.strip()
        if not query:
            return ""

        try:
            body = await self._fetch_body()
        except asyncio.TimeoutError:
            logger.warning(
                "Memory recall timed out after %.1fs, using cached room graph", self._timeout
            )
            return self._format_cached_memory(query)
        except aiohttp.ClientError as e:
            logger.warning("Memory recall connection error: %s, using cached room graph", e)
            return self._format_cached_memory(query)
        except (APIConnectionError, APIStatusError, APITimeoutError) as e:
            logger.warning("Memory recall failed: %s, using cached room graph", e)
            return self._format_cached_memory(query)

        data = self._decode_payload(body)
        if not isinstance(data, dict):
            # A malformed or non-object body is handled like any other failure:
            # it must neither raise out of recall() nor replace the cached graph.
            logger.warning(
                "Memory recall returned an unusable room graph (%.200r), using cached room graph",
                body,
            )
            return self._format_cached_memory(query)

        self._cached_payload = data
        return self._format_memory(data, query)

    async def _fetch_body(self) -> str:
        """GET the room-graph endpoint and return the response body as text.

        Raises:
            APIStatusError: if the endpoint answers with a status other than 200.
            Timeout and client errors raised by aiohttp propagate unchanged.
        """
        headers: dict[str, str] = {}
        if self._api_key:
            headers["Authorization"] = f"Bearer {self._api_key}"

        async with self._ensure_session().get(
            self._url,
            headers=headers,
            timeout=aiohttp.ClientTimeout(total=self._timeout),
        ) as resp:
            body = await resp.text()
            if resp.status != 200:
                raise APIStatusError(
                    message=f"Memory recall error: {body}",
                    status_code=resp.status,
                    request_id=None,
                    body=body,
                )
            return body

    @staticmethod
    def _decode_payload(body: str) -> Any | None:
        """Parse ``body`` as JSON, returning ``None`` when it is not valid JSON."""
        try:
            return json.loads(body)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            return None

    def _format_memory(self, data: Any, query: str) -> str:
        """Render ``data`` into the locating prompt, capped at ``max_chars``."""
        memory = _format_room_graph_memory(data, query)
        if len(memory) > self._max_chars:
            # NOTE(review): the prompt lists objects/relations before the query and
            # the answer instructions, so head-truncating a large room graph can drop
            # the query entirely — confirm max_chars suits real room sizes.
            memory = memory[: self._max_chars].rstrip()
        return memory

    def _format_cached_memory(self, query: str) -> str:
        """Render the last usable room graph, or ``""`` if none has been cached."""
        if self._cached_payload is None:
            return ""
        return self._format_memory(self._cached_payload, query)
def _format_room_graph_memory(payload: Any, query: str) -> str:
if not isinstance(payload, dict):
logger.warning("Unsupported room graph response: %s", payload)
return ""
objects = payload.get("objects", [])
relations = payload.get("relations", [])
summary = payload.get("summary", "")
usage_hint = payload.get("usage_hint", "")
if not objects and not relations and not summary:
return ""
objects_text = json.dumps(objects, ensure_ascii=False, indent=2)
relations_text = json.dumps(relations, ensure_ascii=False, indent=2)
prompt = f"""
你是一个物品定位助手。
我的房间内有以下物品信息:
{objects_text}
这些物品之间的空间关系如下:
{relations_text}
房间概览如下:
{summary}
现在我要找的目标物品是:{query}
请根据上面的 objects、relations 和 summary告诉我它在哪里。
回答要求:
1. 只说明它和其他物品的位置关系。
2. 不要编造不存在的关系。
3. 如果信息不足,请说“根据当前房间记忆,无法确定准确位置”。
4. 回答尽量简短,例如:“黑色背包在透明塑料盒的左边,在显示器的左边。”
5. 如果用户当前输入不是找物品或问位置,可以忽略这段房间记忆。
""".strip()
if usage_hint:
prompt += f"\n\n接口使用提示:\n{usage_hint}"
return prompt