feat: memory recall fuction
This commit is contained in:
@ -1,21 +1,27 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from collections.abc import AsyncIterable
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
from memory import MemoryRecallClient
|
||||||
|
|
||||||
from asr import BlackboxSTT
|
from asr import BlackboxSTT
|
||||||
from livekit.agents import (
|
from livekit.agents import (
|
||||||
Agent,
|
Agent,
|
||||||
AgentServer,
|
AgentServer,
|
||||||
AgentSession,
|
AgentSession,
|
||||||
|
ChatContext,
|
||||||
|
ChatMessage,
|
||||||
|
FlushSentinel,
|
||||||
JobContext,
|
JobContext,
|
||||||
JobProcess,
|
JobProcess,
|
||||||
MetricsCollectedEvent,
|
MetricsCollectedEvent,
|
||||||
|
ModelSettings,
|
||||||
RecordingOptions,
|
RecordingOptions,
|
||||||
TurnHandlingOptions,
|
TurnHandlingOptions,
|
||||||
cli,
|
cli,
|
||||||
|
llm,
|
||||||
metrics,
|
metrics,
|
||||||
room_io,
|
room_io,
|
||||||
stt,
|
stt,
|
||||||
@ -32,17 +38,62 @@ AGENT_NAME = os.getenv("CUSTOM_AGENT_NAME", "")
|
|||||||
|
|
||||||
|
|
||||||
class CustomAgent(Agent):
|
class CustomAgent(Agent):
|
||||||
def __init__(self) -> None:
|
def __init__(self, *, memory_client: MemoryRecallClient | None = None) -> None:
|
||||||
super().__init__(
|
super().__init__(instructions="")
|
||||||
instructions="Your name is Kelly, built by LiveKit. You are a helpful assistant."
|
self._memory_client = memory_client
|
||||||
"Keep your responses concise and friendly."
|
|
||||||
"You are interacting with the user via a local ASR and LLM pipeline.",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def on_enter(self) -> None:
|
async def on_enter(self) -> None:
|
||||||
# self.session.generate_reply(instructions="greet the user and introduce yourself")
|
# self.session.generate_reply(instructions="greet the user and introduce yourself")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def llm_node(
|
||||||
|
self,
|
||||||
|
chat_ctx: ChatContext,
|
||||||
|
tools: list[llm.Tool],
|
||||||
|
model_settings: ModelSettings,
|
||||||
|
) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
|
||||||
|
memory_context = await self._recall_room_memory(chat_ctx)
|
||||||
|
if memory_context:
|
||||||
|
chat_ctx = _with_memory_as_latest_user_message(chat_ctx, memory_context)
|
||||||
|
|
||||||
|
return Agent.default.llm_node(self, chat_ctx, tools, model_settings)
|
||||||
|
|
||||||
|
async def _recall_room_memory(self, chat_ctx: ChatContext) -> str:
|
||||||
|
if self._memory_client is None:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
user_query = _latest_user_text(chat_ctx)
|
||||||
|
if not user_query:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await self._memory_client.recall(user_query)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Unexpected memory recall failure")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _latest_user_text(chat_ctx: ChatContext) -> str:
|
||||||
|
for item in reversed(chat_ctx.items):
|
||||||
|
if isinstance(item, ChatMessage) and item.role == "user":
|
||||||
|
return (item.text_content or "").strip()
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _with_memory_as_latest_user_message(chat_ctx: ChatContext, memory_context: str) -> ChatContext:
|
||||||
|
chat_ctx = chat_ctx.copy()
|
||||||
|
for index in range(len(chat_ctx.items) - 1, -1, -1):
|
||||||
|
item = chat_ctx.items[index]
|
||||||
|
if isinstance(item, ChatMessage) and item.role == "user":
|
||||||
|
user_msg = item.model_copy(deep=True)
|
||||||
|
user_msg.content = [memory_context]
|
||||||
|
chat_ctx.items[index] = user_msg
|
||||||
|
return chat_ctx
|
||||||
|
|
||||||
|
chat_ctx.items.append(ChatMessage(role="user", content=[memory_context]))
|
||||||
|
return chat_ctx
|
||||||
|
|
||||||
|
|
||||||
server = AgentServer()
|
server = AgentServer()
|
||||||
|
|
||||||
|
|
||||||
@ -79,6 +130,10 @@ async def entrypoint(ctx: JobContext) -> None:
|
|||||||
TTS_SAMPLE_RATE = _env_int("CUSTOM_TTS_SAMPLE_RATE", 16000)
|
TTS_SAMPLE_RATE = _env_int("CUSTOM_TTS_SAMPLE_RATE", 16000)
|
||||||
TTS_NUM_CHANNELS = _env_int("CUSTOM_TTS_NUM_CHANNELS", 1)
|
TTS_NUM_CHANNELS = _env_int("CUSTOM_TTS_NUM_CHANNELS", 1)
|
||||||
OUTPUT_SAMPLE_RATE = _env_int("CUSTOM_OUTPUT_SAMPLE_RATE", TTS_SAMPLE_RATE)
|
OUTPUT_SAMPLE_RATE = _env_int("CUSTOM_OUTPUT_SAMPLE_RATE", TTS_SAMPLE_RATE)
|
||||||
|
MEMORY_URL = os.getenv("CUSTOM_MEMORY_URL", "").strip()
|
||||||
|
MEMORY_TIMEOUT = _env_float("CUSTOM_MEMORY_TIMEOUT", 10.0)
|
||||||
|
MEMORY_MAX_CHARS = _env_int("CUSTOM_MEMORY_MAX_CHARS", 8000)
|
||||||
|
MEMORY_API_KEY = os.getenv("CUSTOM_MEMORY_API_KEY") or None
|
||||||
|
|
||||||
blackbox_stt = BlackboxSTT(
|
blackbox_stt = BlackboxSTT(
|
||||||
url=ASR_URL,
|
url=ASR_URL,
|
||||||
@ -147,8 +202,19 @@ async def entrypoint(ctx: JobContext) -> None:
|
|||||||
def _on_metrics_collected(ev: MetricsCollectedEvent) -> None:
|
def _on_metrics_collected(ev: MetricsCollectedEvent) -> None:
|
||||||
metrics.log_metrics(ev.metrics)
|
metrics.log_metrics(ev.metrics)
|
||||||
|
|
||||||
|
memory_client = (
|
||||||
|
MemoryRecallClient(
|
||||||
|
url=MEMORY_URL,
|
||||||
|
timeout=MEMORY_TIMEOUT,
|
||||||
|
max_chars=MEMORY_MAX_CHARS,
|
||||||
|
api_key=MEMORY_API_KEY,
|
||||||
|
)
|
||||||
|
if MEMORY_URL
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
await session.start(
|
await session.start(
|
||||||
agent=CustomAgent(),
|
agent=CustomAgent(memory_client=memory_client),
|
||||||
room=ctx.room,
|
room=ctx.room,
|
||||||
room_options=room_io.RoomOptions(
|
room_options=room_io.RoomOptions(
|
||||||
audio_output=room_io.AudioOutputOptions(
|
audio_output=room_io.AudioOutputOptions(
|
||||||
@ -207,7 +273,7 @@ def _tts_params_from_env(model_name: str) -> dict[str, str]:
|
|||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
def _set_if_present(params: dict[str, str], key: str, value: Optional[str]) -> None:
|
def _set_if_present(params: dict[str, str], key: str, value: str | None) -> None:
|
||||||
if value:
|
if value:
|
||||||
params[key] = value
|
params[key] = value
|
||||||
|
|
||||||
@ -223,6 +289,17 @@ def _env_int(name: str, default: int) -> int:
|
|||||||
return default
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
def _env_float(name: str, default: float) -> float:
|
||||||
|
value = os.getenv(name)
|
||||||
|
if not value:
|
||||||
|
return default
|
||||||
|
try:
|
||||||
|
return float(value)
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("Invalid float for %s=%r, using %s", name, value, default)
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
def _env_bool(name: str, default: bool) -> bool:
|
def _env_bool(name: str, default: bool) -> bool:
|
||||||
value = os.getenv(name)
|
value = os.getenv(name)
|
||||||
if value is None:
|
if value is None:
|
||||||
|
|||||||
137
memory.py
Normal file
137
memory.py
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from livekit.agents import APIConnectionError, APIStatusError, APITimeoutError, utils
|
||||||
|
|
||||||
|
logger = logging.getLogger("memory-recall")
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryRecallClient:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
url: str,
|
||||||
|
timeout: float = 5.0,
|
||||||
|
max_chars: int = 2000,
|
||||||
|
api_key: str | None = None,
|
||||||
|
http_session: aiohttp.ClientSession | None = None,
|
||||||
|
) -> None:
|
||||||
|
self._url = url
|
||||||
|
self._timeout = timeout
|
||||||
|
self._max_chars = max_chars
|
||||||
|
self._api_key = api_key
|
||||||
|
self._http_session = http_session
|
||||||
|
self._cached_payload: Any | None = None
|
||||||
|
|
||||||
|
def _ensure_session(self) -> aiohttp.ClientSession:
|
||||||
|
if self._http_session is None:
|
||||||
|
self._http_session = utils.http_context.http_session()
|
||||||
|
return self._http_session
|
||||||
|
|
||||||
|
async def recall(self, query: str) -> str:
|
||||||
|
query = query.strip()
|
||||||
|
if not query:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
headers = {}
|
||||||
|
if self._api_key:
|
||||||
|
headers["Authorization"] = f"Bearer {self._api_key}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with self._ensure_session().get(
|
||||||
|
self._url,
|
||||||
|
headers=headers,
|
||||||
|
timeout=aiohttp.ClientTimeout(total=self._timeout),
|
||||||
|
) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
error_text = await resp.text()
|
||||||
|
raise APIStatusError(
|
||||||
|
message=f"Memory recall error: {error_text}",
|
||||||
|
status_code=resp.status,
|
||||||
|
request_id=None,
|
||||||
|
body=error_text,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = await resp.json()
|
||||||
|
except aiohttp.ContentTypeError:
|
||||||
|
data = await resp.text()
|
||||||
|
|
||||||
|
self._cached_payload = data
|
||||||
|
return self._format_memory(data, query)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning(
|
||||||
|
"Memory recall timed out after %.1fs, using cached room graph", self._timeout
|
||||||
|
)
|
||||||
|
return self._format_cached_memory(query)
|
||||||
|
except aiohttp.ClientError as e:
|
||||||
|
logger.warning("Memory recall connection error: %s, using cached room graph", e)
|
||||||
|
return self._format_cached_memory(query)
|
||||||
|
except (APIConnectionError, APIStatusError, APITimeoutError) as e:
|
||||||
|
logger.warning("Memory recall failed: %s, using cached room graph", e)
|
||||||
|
return self._format_cached_memory(query)
|
||||||
|
|
||||||
|
def _format_memory(self, data: Any, query: str) -> str:
|
||||||
|
memory = _format_room_graph_memory(data, query)
|
||||||
|
if len(memory) > self._max_chars:
|
||||||
|
memory = memory[: self._max_chars].rstrip()
|
||||||
|
return memory
|
||||||
|
|
||||||
|
def _format_cached_memory(self, query: str) -> str:
|
||||||
|
if self._cached_payload is None:
|
||||||
|
return ""
|
||||||
|
return self._format_memory(self._cached_payload, query)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_room_graph_memory(payload: Any, query: str) -> str:
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
logger.warning("Unsupported room graph response: %s", payload)
|
||||||
|
return ""
|
||||||
|
objects = payload.get("objects", [])
|
||||||
|
relations = payload.get("relations", [])
|
||||||
|
summary = payload.get("summary", "")
|
||||||
|
usage_hint = payload.get("usage_hint", "")
|
||||||
|
|
||||||
|
if not objects and not relations and not summary:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
objects_text = json.dumps(objects, ensure_ascii=False, indent=2)
|
||||||
|
relations_text = json.dumps(relations, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
prompt = f"""
|
||||||
|
你是一个物品定位助手。
|
||||||
|
|
||||||
|
我的房间内有以下物品信息:
|
||||||
|
|
||||||
|
{objects_text}
|
||||||
|
|
||||||
|
这些物品之间的空间关系如下:
|
||||||
|
|
||||||
|
{relations_text}
|
||||||
|
|
||||||
|
房间概览如下:
|
||||||
|
|
||||||
|
{summary}
|
||||||
|
|
||||||
|
现在我要找的目标物品是:{query}
|
||||||
|
|
||||||
|
请根据上面的 objects、relations 和 summary,告诉我它在哪里。
|
||||||
|
|
||||||
|
回答要求:
|
||||||
|
1. 只说明它和其他物品的位置关系。
|
||||||
|
2. 不要编造不存在的关系。
|
||||||
|
3. 如果信息不足,请说“根据当前房间记忆,无法确定准确位置”。
|
||||||
|
4. 回答尽量简短,例如:“黑色背包在透明塑料盒的左边,在显示器的左边。”
|
||||||
|
5. 如果用户当前输入不是找物品或问位置,可以忽略这段房间记忆。
|
||||||
|
""".strip()
|
||||||
|
|
||||||
|
if usage_hint:
|
||||||
|
prompt += f"\n\n接口使用提示:\n{usage_hint}"
|
||||||
|
|
||||||
|
return prompt
|
||||||
Reference in New Issue
Block a user