Files
livekit_agents/memory.py
2026-05-15 10:44:31 +08:00

293 lines
9.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
import asyncio
import json
import logging
import re
from typing import Any
import aiohttp
from livekit.agents import APIConnectionError, APIStatusError, APITimeoutError, utils
# Module-level logger shared by the client and the formatting helpers below.
logger = logging.getLogger("memory-recall")
# Chinese "where is it / help me find / please tell me" style phrases.
# _query_terms() removes them from a query and never emits them as standalone
# terms, so that what is left is (ideally) just the object being asked about.
# NOTE(review): several entries below are empty strings. An empty string is a
# no-op for str.replace() and can never equal a token of length >= 2, so they
# currently have no effect. Presumably these were single-character particles
# lost to an encoding/extraction problem -- confirm against the original file.
_LOCATION_STOPWORDS = {
"哪里",
"在哪",
"在哪里",
"哪儿",
"位置",
"什么地方",
"帮我找",
"帮我寻找",
"找一下",
"",
"请问",
"",
"",
"",
}
class MemoryRecallClient:
    """HTTP client that fetches a room graph and renders it as a recall prompt.

    Every successful fetch that yields a JSON object is remembered, so a later
    timeout, connection error, or unusable response can still be answered from
    the last known room graph.
    """

    def __init__(
        self,
        *,
        url: str,
        timeout: float = 5.0,
        max_chars: int = 2000,
        api_key: str | None = None,
        http_session: aiohttp.ClientSession | None = None,
    ) -> None:
        """Store connection settings; no network I/O happens here.

        Args:
            url: Endpoint queried with a GET request on every recall.
            timeout: Total request timeout in seconds.
            max_chars: Maximum length of the returned memory text.
            api_key: Optional token sent as ``Authorization: Bearer <key>``.
            http_session: Optional session to reuse; when omitted, one is
                obtained lazily from ``utils.http_context``.
        """
        self._url = url
        self._timeout = timeout
        self._max_chars = max_chars
        self._api_key = api_key
        self._http_session = http_session
        # Last payload that parsed to a JSON object; None until a fetch succeeds.
        self._cached_payload: Any | None = None

    def _ensure_session(self) -> aiohttp.ClientSession:
        """Return the injected session, creating the shared one on first use."""
        if self._http_session is None:
            self._http_session = utils.http_context.http_session()
        return self._http_session

    async def recall(self, query: str) -> str:
        """Fetch the room graph and format the part relevant to ``query``.

        Timeouts, connection errors, non-200 responses and bodies that do not
        decode to a JSON object are logged and answered from the cached room
        graph instead of being raised.

        Args:
            query: The user's utterance; surrounding whitespace is ignored.

        Returns:
            The formatted memory prompt, or ``""`` when the query is blank or
            neither a fresh nor a cached room graph is available.
        """
        query = query.strip()
        if not query:
            return ""

        headers: dict[str, str] = {}
        if self._api_key:
            headers["Authorization"] = f"Bearer {self._api_key}"

        try:
            async with self._ensure_session().get(
                self._url,
                headers=headers,
                timeout=aiohttp.ClientTimeout(total=self._timeout),
            ) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    raise APIStatusError(
                        message=f"Memory recall error: {error_text}",
                        status_code=resp.status,
                        request_id=None,
                        body=error_text,
                    )
                try:
                    data = await resp.json()
                except (aiohttp.ContentTypeError, json.JSONDecodeError):
                    # Wrong content type or malformed JSON: the body was
                    # already read, so fetching it as text is cheap.
                    data = await resp.text()

                if not isinstance(data, dict):
                    # Never let an unusable body replace the last good graph;
                    # treat it like any other failed fetch.
                    logger.warning(
                        "Memory recall returned unsupported %s payload, using cached room graph",
                        type(data).__name__,
                    )
                    return self._format_cached_memory(query)

                self._cached_payload = data
                return self._format_memory(data, query)
        except asyncio.TimeoutError:
            logger.warning(
                "Memory recall timed out after %.1fs, using cached room graph", self._timeout
            )
            return self._format_cached_memory(query)
        except aiohttp.ClientError as e:
            logger.warning("Memory recall connection error: %s, using cached room graph", e)
            return self._format_cached_memory(query)
        except (APIConnectionError, APIStatusError, APITimeoutError) as e:
            logger.warning("Memory recall failed: %s, using cached room graph", e)
            return self._format_cached_memory(query)

    def _format_memory(self, data: Any, query: str) -> str:
        """Render ``data`` for ``query`` and cap the result at ``max_chars``."""
        memory = _format_room_graph_memory(data, query)
        if len(memory) > self._max_chars:
            # Hard cut: everything past the limit is dropped, then trailing
            # whitespace left by the cut is removed.
            memory = memory[: self._max_chars].rstrip()
        return memory

    def _format_cached_memory(self, query: str) -> str:
        """Render the last good payload for ``query``; ``""`` if none exists."""
        if self._cached_payload is None:
            return ""
        return self._format_memory(self._cached_payload, query)
def _format_room_graph_memory(payload: Any, query: str) -> str:
if not isinstance(payload, dict):
logger.warning("Unsupported room graph response: %s", payload)
return ""
objects = payload.get("objects", [])
relations = payload.get("relations", [])
summary = payload.get("summary", "")
if not objects and not relations and not summary:
return ""
query_terms = _query_terms(query)
relevant_objects, relevant_relations = _relevant_room_graph(
objects=objects,
relations=relations,
query_terms=query_terms,
)
objects_text = json.dumps(
relevant_objects or _compact_items(objects, limit=12),
ensure_ascii=False,
separators=(",", ":"),
)
relations_text = json.dumps(
relevant_relations or _compact_items(relations, limit=24),
ensure_ascii=False,
separators=(",", ":"),
)
prompt = f"""
你是一个物品定位助手。
目标物品:{query}
相关物品:{objects_text}
相关空间关系:{relations_text}
房间概览:{summary}
回答要求:
1. 只说明它和其他物品的位置关系。
2. 不要编造不存在的关系。
3. 如果信息不足,请说“根据当前房间记忆,无法确定准确位置”。
4. 回答尽量简短,例如:“黑色背包在透明塑料盒的左边,在显示器的左边。”
5. 不要输出 Markdown、emoji、标题、列表、项目符号、坐标区域标签、水平/深度/高度分析或解释过程。
6. 不要回答 right-near-low、left-far-high 这类区域标签,只回答“在……的左边/右边/上方/下方/前面/后面/附近”等相对关系。
7. 如果用户当前输入不是找物品或问位置,可以忽略这段房间记忆。
""".strip()
logger.info(
"Formatted room memory: query_terms=%s, objects=%s/%s, relations=%s/%s, chars=%s",
query_terms,
len(relevant_objects),
len(objects) if isinstance(objects, list) else 0,
len(relevant_relations),
len(relations) if isinstance(relations, list) else 0,
len(prompt),
)
return prompt
def _query_terms(query: str) -> list[str]:
normalized = re.sub(r"[\s?。!,、,.!]", "", query)
for word in _LOCATION_STOPWORDS:
normalized = normalized.replace(word, "")
terms = [normalized] if normalized else []
for token in re.findall(r"[\u4e00-\u9fffA-Za-z0-9_-]{2,}", query):
if token not in _LOCATION_STOPWORDS and token not in terms:
terms.append(token)
return terms[:4]
def _relevant_room_graph(
*,
objects: Any,
relations: Any,
query_terms: list[str],
) -> tuple[list[Any], list[Any]]:
if not isinstance(objects, list) or not isinstance(relations, list) or not query_terms:
return [], []
matched_ids: set[str] = set()
matched_objects: list[Any] = []
object_by_id: dict[str, Any] = {}
for obj in objects:
obj_id = _object_id(obj)
if obj_id:
object_by_id[obj_id] = obj
obj_text = _compact_text(obj)
if any(term and term in obj_text for term in query_terms):
matched_objects.append(obj)
if obj_id:
matched_ids.add(obj_id)
relevant_relations: list[Any] = []
related_ids: set[str] = set(matched_ids)
for relation in relations:
relation_text = _compact_text(relation)
relation_ids = _ids_in_value(relation)
if (
any(term and term in relation_text for term in query_terms)
or bool(matched_ids.intersection(relation_ids))
):
relevant_relations.append(relation)
related_ids.update(relation_ids)
relevant_objects = list(matched_objects)
seen_object_keys = {_object_key(obj) for obj in relevant_objects}
for obj_id in related_ids:
obj = object_by_id.get(obj_id)
key = _object_key(obj)
if obj is not None and key not in seen_object_keys:
relevant_objects.append(obj)
seen_object_keys.add(key)
return _compact_items(relevant_objects, limit=16), _compact_items(relevant_relations, limit=32)
def _compact_items(items: Any, *, limit: int) -> list[Any]:
if not isinstance(items, list):
return []
return [_compact_item(item) for item in items[:limit]]
def _compact_item(item: Any) -> Any:
if not isinstance(item, dict):
return item
preferred_keys = (
"id",
"name",
"label",
"class",
"category",
"type",
"text",
"source",
"target",
"subject",
"object",
"relation",
"predicate",
"description",
)
compact = {key: item[key] for key in preferred_keys if key in item and item[key] not in (None, "")}
return compact or item
def _object_id(obj: Any) -> str | None:
if not isinstance(obj, dict):
return None
for key in ("id", "object_id", "uuid", "name", "label"):
value = obj.get(key)
if isinstance(value, (str, int)):
return str(value)
return None
def _object_key(obj: Any) -> str:
    """Return a de-duplication key: the object's id, else its compact JSON."""
    identifier = _object_id(obj)
    if identifier:
        return identifier
    return _compact_text(obj)
def _ids_in_value(value: Any) -> set[str]:
ids: set[str] = set()
if isinstance(value, dict):
for key, item in value.items():
if key in {"id", "object_id", "source", "target", "subject", "object", "from", "to"}:
if isinstance(item, (str, int)):
ids.add(str(item))
elif isinstance(item, dict):
obj_id = _object_id(item)
if obj_id:
ids.add(obj_id)
ids.update(_ids_in_value(item))
elif isinstance(value, list):
for item in value:
ids.update(_ids_in_value(item))
return ids
def _compact_text(value: Any) -> str:
return json.dumps(value, ensure_ascii=False, separators=(",", ":"))