293 lines
9.0 KiB
Python
293 lines
9.0 KiB
Python
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
import re
|
||
from typing import Any
|
||
|
||
import aiohttp
|
||
|
||
from livekit.agents import APIConnectionError, APIStatusError, APITimeoutError, utils
|
||
|
||
logger = logging.getLogger("memory-recall")

# Filler words and phrases of a location question ("where is", "help me find",
# politeness markers, question particles). _query_terms strips them from the
# query and never uses one of them as a search term on its own.
_LOCATION_STOPWORDS = set(
    [
        # "where" / "location"
        "哪里", "在哪", "在哪里", "哪儿", "位置", "什么地方",
        # "help me find" / "find"
        "帮我找", "帮我寻找", "找一下", "找",
        # politeness markers and question particles
        "请问", "请", "吗", "呢",
    ]
)
|
||
|
||
|
||
class MemoryRecallClient:
    """Fetch the room graph from a memory service and render it as LLM context.

    Every ``recall`` issues a GET request to ``url``. The last room-graph dict
    received is cached. A timeout, a connection error, a non-200 response or an
    unparsable body is then answered from that snapshot instead of from nothing.
    """

    def __init__(
        self,
        *,
        url: str,
        timeout: float = 5.0,
        max_chars: int = 2000,
        api_key: str | None = None,
        http_session: aiohttp.ClientSession | None = None,
    ) -> None:
        """Configure the client.

        Args:
            url: Endpoint queried with GET on every recall.
            timeout: Total request timeout, in seconds.
            max_chars: Maximum length of the text returned by ``recall``.
            api_key: Optional token sent as ``Authorization: Bearer <api_key>``.
            http_session: Session to use. When omitted, one is obtained lazily
                from ``utils.http_context.http_session()`` on the first request.
        """
        self._url = url
        self._timeout = timeout
        self._max_chars = max_chars
        self._api_key = api_key
        self._http_session = http_session
        # Last room-graph dict received from the service (fallback snapshot).
        self._cached_payload: Any | None = None

    def _ensure_session(self) -> aiohttp.ClientSession:
        """Return the HTTP session, creating it from the livekit context on first use."""
        if self._http_session is None:
            self._http_session = utils.http_context.http_session()
        return self._http_session

    async def recall(self, query: str) -> str:
        """Return the formatted room memory relevant to ``query``.

        Returns ``""`` for a blank query. On a timeout, connection error,
        non-200 status or a body that is not valid JSON, the problem is logged
        and the cached snapshot is formatted instead. The result is ``""`` if
        nothing has been cached yet.
        """
        query = query.strip()
        if not query:
            return ""

        headers: dict[str, str] = {}
        if self._api_key:
            headers["Authorization"] = f"Bearer {self._api_key}"

        try:
            async with self._ensure_session().get(
                self._url,
                headers=headers,
                timeout=aiohttp.ClientTimeout(total=self._timeout),
            ) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    # Raised here and caught below so that a non-200 response
                    # shares the cached-snapshot fallback.
                    raise APIStatusError(
                        message=f"Memory recall error: {error_text}",
                        status_code=resp.status,
                        request_id=None,
                        body=error_text,
                    )

                try:
                    # content_type=None: also parse JSON served with a wrong
                    # Content-Type (e.g. text/plain) instead of keeping it as text.
                    data = await resp.json(content_type=None)
                except ValueError as e:
                    # json.JSONDecodeError / UnicodeDecodeError: the body is unusable.
                    logger.warning(
                        "Memory recall returned invalid JSON: %s, using cached room graph", e
                    )
                    return self._format_cached_memory(query)

                # Only a dict can be formatted. Caching anything else would
                # throw away the last usable snapshot.
                if isinstance(data, dict):
                    self._cached_payload = data
                return self._format_memory(data, query)
        except asyncio.TimeoutError:
            logger.warning(
                "Memory recall timed out after %.1fs, using cached room graph", self._timeout
            )
            return self._format_cached_memory(query)
        except aiohttp.ClientError as e:
            logger.warning("Memory recall connection error: %s, using cached room graph", e)
            return self._format_cached_memory(query)
        except (APIConnectionError, APIStatusError, APITimeoutError) as e:
            logger.warning("Memory recall failed: %s, using cached room graph", e)
            return self._format_cached_memory(query)

    def _format_memory(self, data: Any, query: str) -> str:
        """Format ``data`` for ``query`` and cut the text to at most ``max_chars`` characters."""
        memory = _format_room_graph_memory(data, query)
        if len(memory) > self._max_chars:
            memory = memory[: self._max_chars].rstrip()
        return memory

    def _format_cached_memory(self, query: str) -> str:
        """Format the cached snapshot for ``query``; return ``""`` when nothing is cached."""
        if self._cached_payload is None:
            return ""
        return self._format_memory(self._cached_payload, query)
|
||
|
||
|
||
def _format_room_graph_memory(payload: Any, query: str) -> str:
    """Build the item-locating prompt (Chinese) for ``query`` from a room-graph payload.

    Args:
        payload: Response from the memory service. Only a dict is supported;
            its ``objects``, ``relations`` and ``summary`` entries are read.
        query: The user's question. It is placed in the prompt verbatim and is
            also used to select the objects and relations that mention it.

    Returns:
        The prompt text. Returns ``""`` when ``payload`` is not a dict or has no
        objects, relations or summary.
    """
    if not isinstance(payload, dict):
        logger.warning("Unsupported room graph response: %s", payload)
        return ""

    objects = payload.get("objects", [])
    relations = payload.get("relations", [])
    summary = payload.get("summary", "")

    if not objects and not relations and not summary:
        return ""

    query_terms = _query_terms(query)
    relevant_objects, relevant_relations = _relevant_room_graph(
        objects=objects,
        relations=relations,
        query_terms=query_terms,
    )

    # When nothing matched the query, fall back to a capped prefix of the whole graph.
    objects_text = json.dumps(
        relevant_objects or _compact_items(objects, limit=12),
        ensure_ascii=False,
        separators=(",", ":"),
    )
    relations_text = json.dumps(
        relevant_relations or _compact_items(relations, limit=24),
        ensure_ascii=False,
        separators=(",", ":"),
    )

    # The answering rules come before the bulky JSON data. The client caps the
    # prompt by cutting its tail (max_chars), and with the rules placed last a
    # large room graph would silently drop them.
    prompt = "\n".join(
        (
            "你是一个物品定位助手。",
            "",
            f"目标物品:{query}",
            "",
            "回答要求:",
            "1. 只说明它和其他物品的位置关系。",
            "2. 不要编造不存在的关系。",
            "3. 如果信息不足,请说“根据当前房间记忆,无法确定准确位置”。",
            "4. 回答尽量简短,例如:“黑色背包在透明塑料盒的左边,在显示器的左边。”",
            "5. 不要输出 Markdown、emoji、标题、列表、项目符号、坐标区域标签、水平/深度/高度分析或解释过程。",
            "6. 不要回答 right-near-low、left-far-high 这类区域标签,只回答“在……的左边/右边/上方/下方/前面/后面/附近”等相对关系。",
            "7. 如果用户当前输入不是找物品或问位置,可以忽略这段房间记忆。",
            "",
            f"相关物品:{objects_text}",
            f"相关空间关系:{relations_text}",
            f"房间概览:{summary}",
        )
    ).strip()

    logger.info(
        "Formatted room memory: query_terms=%s, objects=%s/%s, relations=%s/%s, chars=%s",
        query_terms,
        len(relevant_objects),
        len(objects) if isinstance(objects, list) else 0,
        len(relevant_relations),
        len(relations) if isinstance(relations, list) else 0,
        len(prompt),
    )
    return prompt
|
||
|
||
|
||
def _query_terms(query: str) -> list[str]:
    """Extract up to four search terms from a free-form location question.

    The first term is the whole query with whitespace, punctuation and the
    ``_LOCATION_STOPWORDS`` filler phrases removed. For example,
    ``"帮我找钥匙在哪里？"`` becomes ``"钥匙"``. It is followed by each run of
    two or more CJK/ASCII word characters in the raw query that is not itself a
    stopword and not already a term.
    """
    # Both half-width and full-width punctuation, since Chinese input mostly uses the latter.
    normalized = re.sub(r"[\s?？。!！,，、.]", "", query)
    # Remove longer phrases first, in a fixed order. With plain set iteration,
    # "找" or "哪里" could be stripped before "帮我找"/"在哪里" and leave
    # fragments such as "帮我" or "在". That order also changes from process to
    # process because of string hash randomization.
    for word in sorted(_LOCATION_STOPWORDS, key=lambda w: (-len(w), w)):
        normalized = normalized.replace(word, "")

    terms = [normalized] if normalized else []
    for token in re.findall(r"[\u4e00-\u9fffA-Za-z0-9_-]{2,}", query):
        if token not in _LOCATION_STOPWORDS and token not in terms:
            terms.append(token)
    return terms[:4]
|
||
|
||
|
||
def _relevant_room_graph(
    *,
    objects: Any,
    relations: Any,
    query_terms: list[str],
) -> tuple[list[Any], list[Any]]:
    """Select the objects and relations of a room graph that relate to the query.

    An object matches when any query term occurs in its compact JSON text. A
    relation is kept when a term occurs in its JSON text or when it references
    the id of a matched object. Objects referenced by kept relations are then
    added as context, after the directly matched ones.

    Returns:
        ``(objects, relations)``, compacted and capped at 16 objects and
        32 relations. Both are empty when ``objects`` or ``relations`` is not a
        list, or when there are no query terms.
    """
    if not isinstance(objects, list) or not isinstance(relations, list) or not query_terms:
        return [], []

    matched_ids: set[str] = set()
    matched_objects: list[Any] = []
    # Insertion order = first occurrence of each id; for duplicate ids the last object wins.
    object_by_id: dict[str, Any] = {}

    for obj in objects:
        obj_id = _object_id(obj)
        if obj_id:
            object_by_id[obj_id] = obj

        obj_text = _compact_text(obj)
        if any(term and term in obj_text for term in query_terms):
            matched_objects.append(obj)
            if obj_id:
                matched_ids.add(obj_id)

    relevant_relations: list[Any] = []
    related_ids: set[str] = set(matched_ids)
    for relation in relations:
        relation_text = _compact_text(relation)
        relation_ids = _ids_in_value(relation)
        if any(term and term in relation_text for term in query_terms) or (
            matched_ids & relation_ids
        ):
            relevant_relations.append(relation)
            related_ids.update(relation_ids)

    relevant_objects = list(matched_objects)
    seen_object_keys = {_object_key(obj) for obj in relevant_objects}
    # Walk the ordered id map rather than the ``related_ids`` set. Set order
    # varies between processes, and with the cap of 16 below that would keep a
    # different subset of neighbours on every run.
    for obj_id, obj in object_by_id.items():
        if obj_id not in related_ids:
            continue
        key = _object_key(obj)
        if key not in seen_object_keys:
            relevant_objects.append(obj)
            seen_object_keys.add(key)

    return _compact_items(relevant_objects, limit=16), _compact_items(relevant_relations, limit=32)
|
||
|
||
|
||
def _compact_items(items: Any, *, limit: int) -> list[Any]:
    """Compact the first ``limit`` entries of ``items`` with ``_compact_item``.

    Returns ``[]`` when ``items`` is not a list.
    """
    if isinstance(items, list):
        return list(map(_compact_item, items[:limit]))
    return []
|
||
|
||
|
||
def _compact_item(item: Any) -> Any:
    """Reduce a dict item to its descriptive and relational keys.

    Non-dict items are returned unchanged. For a dict, the result keeps, in the
    fixed key order below, every preferred key whose value is neither ``None``
    nor ``""``. If no key qualifies, the original dict object is returned.
    """
    if not isinstance(item, dict):
        return item

    wanted = (
        "id", "name", "label", "class", "category", "type", "text",
        "source", "target", "subject", "object", "relation", "predicate",
        "description",
    )
    compact: dict[Any, Any] = {}
    for key in wanted:
        if key not in item:
            continue
        value = item[key]
        if value in (None, ""):
            continue
        compact[key] = value

    if compact:
        return compact
    return item
|
||
|
||
|
||
def _object_id(obj: Any) -> str | None:
    """Return the first usable identifier of a dict object, as a string.

    Checks ``id``, ``object_id``, ``uuid``, ``name`` and ``label`` in that
    order. Returns the first value that is a non-empty string or an integer
    (booleans excluded). Returns ``None`` for non-dicts or when no key qualifies.
    """
    if not isinstance(obj, dict):
        return None
    for key in ("id", "object_id", "uuid", "name", "label"):
        value = obj.get(key)
        # bool is an int subclass but never a meaningful id. An empty string
        # used to end the search before ``name``/``label`` were tried.
        if isinstance(value, bool):
            continue
        if isinstance(value, int) or (isinstance(value, str) and value):
            return str(value)
    return None
|
||
|
||
|
||
def _object_key(obj: Any) -> str:
    """Return a de-duplication key for ``obj``: its id if it has one, else its compact JSON."""
    identifier = _object_id(obj)
    if identifier:
        return identifier
    return _compact_text(obj)
|
||
|
||
|
||
def _ids_in_value(value: Any) -> set[str]:
    """Collect the object identifiers referenced anywhere inside ``value``.

    Lists and dicts are walked recursively. Under the keys ``id``,
    ``object_id``, ``source``, ``target``, ``subject``, ``object``, ``from`` and
    ``to``, a str/int value is taken as an id, and a nested dict also
    contributes its ``_object_id``. Other scalars contribute nothing.
    """
    if isinstance(value, list):
        return set().union(*(_ids_in_value(element) for element in value))
    if not isinstance(value, dict):
        return set()

    reference_keys = {"id", "object_id", "source", "target", "subject", "object", "from", "to"}
    found: set[str] = set()
    for key, child in value.items():
        if key in reference_keys:
            if isinstance(child, (str, int)):
                found.add(str(child))
            elif isinstance(child, dict):
                child_id = _object_id(child)
                if child_id:
                    found.add(child_id)
        found |= _ids_in_value(child)
    return found
|
||
|
||
|
||
def _compact_text(value: Any) -> str:
    """Serialize ``value`` as minified JSON, keeping non-ASCII text readable."""
    encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"))
    return encoder.encode(value)
|