From 3a2f5c425227d84856d4ba78930b757773a6096e Mon Sep 17 00:00:00 2001
From: 0Xiao0 <511201264@qq.com>
Date: Thu, 14 May 2026 10:16:08 +0800
Subject: [PATCH] feat: memory recall function

---
Reviewer note: the line structure of this patch was rebuilt from the hunk
line counts, and the index blob ids below predate the review amendments.

 custom_agent.py | 100 ++++++++++++++++++++---
 memory.py       | 207 ++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 298 insertions(+), 9 deletions(-)
 create mode 100644 memory.py

diff --git a/custom_agent.py b/custom_agent.py
index 920d128..5161f07 100644
--- a/custom_agent.py
+++ b/custom_agent.py
@@ -1,21 +1,27 @@
 import logging
 import os
+from collections.abc import AsyncIterable
 from pathlib import Path
-from typing import Optional
 
 from dotenv import load_dotenv
 
+from memory import MemoryRecallClient
 from asr import BlackboxSTT
 from livekit.agents import (
     Agent,
     AgentServer,
     AgentSession,
+    ChatContext,
+    ChatMessage,
+    FlushSentinel,
     JobContext,
     JobProcess,
     MetricsCollectedEvent,
+    ModelSettings,
     RecordingOptions,
     TurnHandlingOptions,
     cli,
+    llm,
     metrics,
     room_io,
     stt,
@@ -32,17 +38,66 @@ AGENT_NAME = os.getenv("CUSTOM_AGENT_NAME", "")
 
 
 class CustomAgent(Agent):
-    def __init__(self) -> None:
-        super().__init__(
-            instructions="Your name is Kelly, built by LiveKit. You are a helpful assistant."
-            "Keep your responses concise and friendly."
-            "You are interacting with the user via a local ASR and LLM pipeline.",
-        )
+    def __init__(self, *, memory_client: MemoryRecallClient | None = None) -> None:
+        super().__init__(instructions="")
+        self._memory_client = memory_client
 
     async def on_enter(self) -> None:
         # self.session.generate_reply(instructions="greet the user and introduce yourself")
         pass
 
+    async def llm_node(
+        self,
+        chat_ctx: ChatContext,
+        tools: list[llm.Tool],
+        model_settings: ModelSettings,
+    ) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
+        """Run the default LLM node after swapping room memory into the latest user turn."""
+        memory_context = await self._recall_room_memory(chat_ctx)
+        if memory_context:
+            chat_ctx = _with_memory_as_latest_user_message(chat_ctx, memory_context)
+
+        return Agent.default.llm_node(self, chat_ctx, tools, model_settings)
+
+    async def _recall_room_memory(self, chat_ctx: ChatContext) -> str:
+        """Return the memory prompt for the latest user text, or "" if unavailable."""
+        if self._memory_client is None:
+            return ""
+
+        user_query = _latest_user_text(chat_ctx)
+        if not user_query:
+            return ""
+
+        try:
+            return await self._memory_client.recall(user_query)
+        except Exception:
+            logger.exception("Unexpected memory recall failure")
+            return ""
+
+
+def _latest_user_text(chat_ctx: ChatContext) -> str:
+    """Return the stripped text of the most recent user message, or ""."""
+    for item in reversed(chat_ctx.items):
+        if isinstance(item, ChatMessage) and item.role == "user":
+            return (item.text_content or "").strip()
+    return ""
+
+
+def _with_memory_as_latest_user_message(chat_ctx: ChatContext, memory_context: str) -> ChatContext:
+    """Return a copy of *chat_ctx* whose latest user message holds *memory_context*."""
+    chat_ctx = chat_ctx.copy()
+    for index in range(len(chat_ctx.items) - 1, -1, -1):
+        item = chat_ctx.items[index]
+        if isinstance(item, ChatMessage) and item.role == "user":
+            user_msg = item.model_copy(deep=True)
+            user_msg.content = [memory_context]
+            chat_ctx.items[index] = user_msg
+            return chat_ctx
+
+    chat_ctx.items.append(ChatMessage(role="user", content=[memory_context]))
+    return chat_ctx
+
+
 
 server = AgentServer()
 
@@ -79,6 +134,10 @@ async def entrypoint(ctx: JobContext) -> None:
     TTS_SAMPLE_RATE = _env_int("CUSTOM_TTS_SAMPLE_RATE", 16000)
     TTS_NUM_CHANNELS = _env_int("CUSTOM_TTS_NUM_CHANNELS", 1)
     OUTPUT_SAMPLE_RATE = _env_int("CUSTOM_OUTPUT_SAMPLE_RATE", TTS_SAMPLE_RATE)
+    MEMORY_URL = os.getenv("CUSTOM_MEMORY_URL", "").strip()
+    MEMORY_TIMEOUT = _env_float("CUSTOM_MEMORY_TIMEOUT", 10.0)
+    MEMORY_MAX_CHARS = _env_int("CUSTOM_MEMORY_MAX_CHARS", 8000)
+    MEMORY_API_KEY = os.getenv("CUSTOM_MEMORY_API_KEY") or None
 
     blackbox_stt = BlackboxSTT(
         url=ASR_URL,
@@ -147,8 +206,19 @@ async def entrypoint(ctx: JobContext) -> None:
     def _on_metrics_collected(ev: MetricsCollectedEvent) -> None:
         metrics.log_metrics(ev.metrics)
 
+    memory_client = (
+        MemoryRecallClient(
+            url=MEMORY_URL,
+            timeout=MEMORY_TIMEOUT,
+            max_chars=MEMORY_MAX_CHARS,
+            api_key=MEMORY_API_KEY,
+        )
+        if MEMORY_URL
+        else None
+    )
+
     await session.start(
-        agent=CustomAgent(),
+        agent=CustomAgent(memory_client=memory_client),
         room=ctx.room,
         room_options=room_io.RoomOptions(
             audio_output=room_io.AudioOutputOptions(
@@ -207,7 +277,7 @@ def _tts_params_from_env(model_name: str) -> dict[str, str]:
     return params
 
 
-def _set_if_present(params: dict[str, str], key: str, value: Optional[str]) -> None:
+def _set_if_present(params: dict[str, str], key: str, value: str | None) -> None:
     if value:
         params[key] = value
 
@@ -223,6 +293,18 @@ def _env_int(name: str, default: int) -> int:
         return default
 
 
+def _env_float(name: str, default: float) -> float:
+    """Return env var *name* as a float, or *default* when unset or invalid."""
+    value = os.getenv(name)
+    if not value:
+        return default
+    try:
+        return float(value)
+    except ValueError:
+        logger.warning("Invalid float for %s=%r, using %s", name, value, default)
+        return default
+
+
 def _env_bool(name: str, default: bool) -> bool:
     value = os.getenv(name)
     if value is None:
diff --git a/memory.py b/memory.py
new file mode 100644
index 0000000..b47ead1
--- /dev/null
+++ b/memory.py
@@ -0,0 +1,207 @@
+"""Recall room-graph memory over HTTP and render it as an LLM prompt."""
+
+from __future__ import annotations
+
+import asyncio
+import json
+import logging
+from typing import Any
+
+import aiohttp
+
+from livekit.agents import APIConnectionError, APIStatusError, APITimeoutError, utils
+
+logger = logging.getLogger("memory-recall")
+
+
+class MemoryRecallClient:
+    """Fetch the room graph from a memory service and format it for the LLM.
+
+    The most recent room graph that was fetched successfully is kept in memory
+    and reused whenever a later request fails.
+    """
+
+    def __init__(
+        self,
+        *,
+        url: str,
+        timeout: float = 5.0,
+        max_chars: int = 2000,
+        api_key: str | None = None,
+        http_session: aiohttp.ClientSession | None = None,
+    ) -> None:
+        """Store the connection settings; no request is made here.
+
+        Args:
+            url: Endpoint that is queried with GET for the room graph.
+            timeout: Total request timeout in seconds.
+            max_chars: Upper bound on the length of the prompt returned.
+            api_key: Optional bearer token for the Authorization header.
+            http_session: Optional session to reuse; when omitted, one is
+                taken from ``utils.http_context`` on first use.
+        """
+        self._url = url
+        self._timeout = timeout
+        self._max_chars = max_chars
+        self._api_key = api_key
+        self._http_session = http_session
+        self._cached_payload: dict[str, Any] | None = None
+
+    def _ensure_session(self) -> aiohttp.ClientSession:
+        """Return the HTTP session, obtaining one on first use."""
+        if self._http_session is None:
+            self._http_session = utils.http_context.http_session()
+        return self._http_session
+
+    async def recall(self, query: str) -> str:
+        """Return the memory prompt for *query*, or "" when none is available.
+
+        Timeouts, connection errors, non-200 statuses and bodies that are not
+        a JSON object fall back to the cached room graph instead of raising.
+        """
+        query = query.strip()
+        if not query:
+            return ""
+
+        headers = {}
+        if self._api_key:
+            headers["Authorization"] = f"Bearer {self._api_key}"
+
+        try:
+            async with self._ensure_session().get(
+                self._url,
+                headers=headers,
+                timeout=aiohttp.ClientTimeout(total=self._timeout),
+            ) as resp:
+                if resp.status != 200:
+                    error_text = await resp.text()
+                    raise APIStatusError(
+                        message=f"Memory recall error: {error_text}",
+                        status_code=resp.status,
+                        request_id=None,
+                        body=error_text,
+                    )
+
+                try:
+                    data = await resp.json()
+                except (aiohttp.ContentTypeError, ValueError) as e:
+                    # Wrong content type or malformed JSON: not a usable room graph.
+                    logger.warning(
+                        "Memory recall got a non-JSON body: %s, using cached room graph", e
+                    )
+                    return self._format_cached_memory(query)
+        except asyncio.TimeoutError:
+            logger.warning(
+                "Memory recall timed out after %.1fs, using cached room graph", self._timeout
+            )
+            return self._format_cached_memory(query)
+        except aiohttp.ClientError as e:
+            logger.warning("Memory recall connection error: %s, using cached room graph", e)
+            return self._format_cached_memory(query)
+        except (APIConnectionError, APIStatusError, APITimeoutError) as e:
+            logger.warning("Memory recall failed: %s, using cached room graph", e)
+            return self._format_cached_memory(query)
+
+        if not isinstance(data, dict):
+            # Never let an unusable body replace a previously cached room graph.
+            logger.warning(
+                "Unsupported room graph response type %s, using cached room graph",
+                type(data).__name__,
+            )
+            return self._format_cached_memory(query)
+
+        self._cached_payload = data
+        return self._format_memory(data, query)
+
+    def _format_memory(self, data: Any, query: str) -> str:
+        """Render *data* for *query*, never exceeding ``max_chars`` characters."""
+        memory = _format_room_graph_memory(data, query, max_chars=self._max_chars)
+        # Last resort: the text around the room data alone exceeds the budget.
+        if len(memory) > self._max_chars:
+            memory = memory[: self._max_chars].rstrip()
+        return memory
+
+    def _format_cached_memory(self, query: str) -> str:
+        """Render the cached room graph for *query*, or "" when nothing is cached."""
+        if self._cached_payload is None:
+            return ""
+        return self._format_memory(self._cached_payload, query)
+
+
+def _shorten(text: str, overflow: int) -> tuple[str, int]:
+    """Drop up to *overflow* trailing characters; return the text and what is still owed."""
+    cut = min(len(text), overflow)
+    return text[: len(text) - cut], overflow - cut
+
+
+def _render_prompt(
+    objects_text: str, relations_text: str, summary: str, query: str, usage_hint: Any
+) -> str:
+    """Fill the item-location prompt template with the given room data."""
+    prompt = f"""
+你是一个物品定位助手。
+
+我的房间内有以下物品信息:
+
+{objects_text}
+
+这些物品之间的空间关系如下:
+
+{relations_text}
+
+房间概览如下:
+
+{summary}
+
+现在我要找的目标物品是:{query}
+
+请根据上面的 objects、relations 和 summary,告诉我它在哪里。
+
+回答要求:
+1. 只说明它和其他物品的位置关系。
+2. 不要编造不存在的关系。
+3. 如果信息不足,请说“根据当前房间记忆,无法确定准确位置”。
+4. 回答尽量简短,例如:“黑色背包在透明塑料盒的左边,在显示器的左边。”
+5. 如果用户当前输入不是找物品或问位置,可以忽略这段房间记忆。
+""".strip()
+
+    if usage_hint:
+        prompt += f"\n\n接口使用提示:\n{usage_hint}"
+
+    return prompt
+
+
+def _format_room_graph_memory(payload: Any, query: str, max_chars: int | None = None) -> str:
+    """Build the item-location prompt from a room-graph payload.
+
+    Args:
+        payload: Decoded service response; anything but a dict yields "".
+        query: The user's request, embedded in the prompt as the target item.
+        max_chars: Optional length budget. When the prompt exceeds it, the
+            relations, objects and summary data are shortened (in that order)
+            so that the query and the answer rules placed after them survive.
+    """
+    if not isinstance(payload, dict):
+        logger.warning("Unsupported room graph response: %s", payload)
+        return ""
+    objects = payload.get("objects", [])
+    relations = payload.get("relations", [])
+    summary = payload.get("summary", "")
+    usage_hint = payload.get("usage_hint", "")
+
+    if not objects and not relations and not summary:
+        return ""
+
+    objects_text = json.dumps(objects, ensure_ascii=False, indent=2)
+    relations_text = json.dumps(relations, ensure_ascii=False, indent=2)
+    summary_text = f"{summary}"
+
+    prompt = _render_prompt(objects_text, relations_text, summary_text, query, usage_hint)
+    overflow = 0 if max_chars is None else len(prompt) - max_chars
+    if overflow > 0:
+        # Trim the room data, not the tail, so the request itself is never cut off.
+        relations_text, overflow = _shorten(relations_text, overflow)
+        objects_text, overflow = _shorten(objects_text, overflow)
+        summary_text, overflow = _shorten(summary_text, overflow)
+        prompt = _render_prompt(objects_text, relations_text, summary_text, query, usage_hint)
+    return prompt