feat: memory recall fuction
This commit is contained in:
@ -1,21 +1,27 @@
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import AsyncIterable
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from memory import MemoryRecallClient
|
||||
|
||||
from asr import BlackboxSTT
|
||||
from livekit.agents import (
|
||||
Agent,
|
||||
AgentServer,
|
||||
AgentSession,
|
||||
ChatContext,
|
||||
ChatMessage,
|
||||
FlushSentinel,
|
||||
JobContext,
|
||||
JobProcess,
|
||||
MetricsCollectedEvent,
|
||||
ModelSettings,
|
||||
RecordingOptions,
|
||||
TurnHandlingOptions,
|
||||
cli,
|
||||
llm,
|
||||
metrics,
|
||||
room_io,
|
||||
stt,
|
||||
@ -32,17 +38,62 @@ AGENT_NAME = os.getenv("CUSTOM_AGENT_NAME", "")
|
||||
|
||||
|
||||
class CustomAgent(Agent):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
instructions="Your name is Kelly, built by LiveKit. You are a helpful assistant."
|
||||
"Keep your responses concise and friendly."
|
||||
"You are interacting with the user via a local ASR and LLM pipeline.",
|
||||
)
|
||||
def __init__(self, *, memory_client: MemoryRecallClient | None = None) -> None:
|
||||
super().__init__(instructions="")
|
||||
self._memory_client = memory_client
|
||||
|
||||
async def on_enter(self) -> None:
|
||||
# self.session.generate_reply(instructions="greet the user and introduce yourself")
|
||||
pass
|
||||
|
||||
async def llm_node(
|
||||
self,
|
||||
chat_ctx: ChatContext,
|
||||
tools: list[llm.Tool],
|
||||
model_settings: ModelSettings,
|
||||
) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
|
||||
memory_context = await self._recall_room_memory(chat_ctx)
|
||||
if memory_context:
|
||||
chat_ctx = _with_memory_as_latest_user_message(chat_ctx, memory_context)
|
||||
|
||||
return Agent.default.llm_node(self, chat_ctx, tools, model_settings)
|
||||
|
||||
async def _recall_room_memory(self, chat_ctx: ChatContext) -> str:
|
||||
if self._memory_client is None:
|
||||
return ""
|
||||
|
||||
user_query = _latest_user_text(chat_ctx)
|
||||
if not user_query:
|
||||
return ""
|
||||
|
||||
try:
|
||||
return await self._memory_client.recall(user_query)
|
||||
except Exception:
|
||||
logger.exception("Unexpected memory recall failure")
|
||||
return ""
|
||||
|
||||
|
||||
def _latest_user_text(chat_ctx: ChatContext) -> str:
|
||||
for item in reversed(chat_ctx.items):
|
||||
if isinstance(item, ChatMessage) and item.role == "user":
|
||||
return (item.text_content or "").strip()
|
||||
return ""
|
||||
|
||||
|
||||
def _with_memory_as_latest_user_message(chat_ctx: ChatContext, memory_context: str) -> ChatContext:
|
||||
chat_ctx = chat_ctx.copy()
|
||||
for index in range(len(chat_ctx.items) - 1, -1, -1):
|
||||
item = chat_ctx.items[index]
|
||||
if isinstance(item, ChatMessage) and item.role == "user":
|
||||
user_msg = item.model_copy(deep=True)
|
||||
user_msg.content = [memory_context]
|
||||
chat_ctx.items[index] = user_msg
|
||||
return chat_ctx
|
||||
|
||||
chat_ctx.items.append(ChatMessage(role="user", content=[memory_context]))
|
||||
return chat_ctx
|
||||
|
||||
|
||||
server = AgentServer()
|
||||
|
||||
|
||||
@ -79,6 +130,10 @@ async def entrypoint(ctx: JobContext) -> None:
|
||||
TTS_SAMPLE_RATE = _env_int("CUSTOM_TTS_SAMPLE_RATE", 16000)
|
||||
TTS_NUM_CHANNELS = _env_int("CUSTOM_TTS_NUM_CHANNELS", 1)
|
||||
OUTPUT_SAMPLE_RATE = _env_int("CUSTOM_OUTPUT_SAMPLE_RATE", TTS_SAMPLE_RATE)
|
||||
MEMORY_URL = os.getenv("CUSTOM_MEMORY_URL", "").strip()
|
||||
MEMORY_TIMEOUT = _env_float("CUSTOM_MEMORY_TIMEOUT", 10.0)
|
||||
MEMORY_MAX_CHARS = _env_int("CUSTOM_MEMORY_MAX_CHARS", 8000)
|
||||
MEMORY_API_KEY = os.getenv("CUSTOM_MEMORY_API_KEY") or None
|
||||
|
||||
blackbox_stt = BlackboxSTT(
|
||||
url=ASR_URL,
|
||||
@ -147,8 +202,19 @@ async def entrypoint(ctx: JobContext) -> None:
|
||||
def _on_metrics_collected(ev: MetricsCollectedEvent) -> None:
|
||||
metrics.log_metrics(ev.metrics)
|
||||
|
||||
memory_client = (
|
||||
MemoryRecallClient(
|
||||
url=MEMORY_URL,
|
||||
timeout=MEMORY_TIMEOUT,
|
||||
max_chars=MEMORY_MAX_CHARS,
|
||||
api_key=MEMORY_API_KEY,
|
||||
)
|
||||
if MEMORY_URL
|
||||
else None
|
||||
)
|
||||
|
||||
await session.start(
|
||||
agent=CustomAgent(),
|
||||
agent=CustomAgent(memory_client=memory_client),
|
||||
room=ctx.room,
|
||||
room_options=room_io.RoomOptions(
|
||||
audio_output=room_io.AudioOutputOptions(
|
||||
@ -207,7 +273,7 @@ def _tts_params_from_env(model_name: str) -> dict[str, str]:
|
||||
return params
|
||||
|
||||
|
||||
def _set_if_present(params: dict[str, str], key: str, value: Optional[str]) -> None:
|
||||
def _set_if_present(params: dict[str, str], key: str, value: str | None) -> None:
|
||||
if value:
|
||||
params[key] = value
|
||||
|
||||
@ -223,6 +289,17 @@ def _env_int(name: str, default: int) -> int:
|
||||
return default
|
||||
|
||||
|
||||
def _env_float(name: str, default: float) -> float:
|
||||
value = os.getenv(name)
|
||||
if not value:
|
||||
return default
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
logger.warning("Invalid float for %s=%r, using %s", name, value, default)
|
||||
return default
|
||||
|
||||
|
||||
def _env_bool(name: str, default: bool) -> bool:
|
||||
value = os.getenv(name)
|
||||
if value is None:
|
||||
|
||||
Reference in New Issue
Block a user