"""LiveKit voice agent backed by self-hosted ASR/TTS "blackbox" services and an
OpenAI-compatible LLM, with optional room-memory recall before each LLM call.

Endpoints and tuning knobs are read from environment variables, loaded from the
``.env`` file located next to this module.
"""

import logging
import os
import time
from collections.abc import AsyncIterable
from pathlib import Path

import httpx
from dotenv import load_dotenv
from livekit.agents import (
    Agent,
    AgentServer,
    AgentSession,
    ChatContext,
    ChatMessage,
    FlushSentinel,
    JobContext,
    JobProcess,
    MetricsCollectedEvent,
    ModelSettings,
    RecordingOptions,
    TurnHandlingOptions,
    cli,
    llm,
    metrics,
    room_io,
    stt,
)
from livekit.plugins import openai, silero
from livekit.plugins.turn_detector.multilingual import MultilingualModel
from openai import AsyncClient as OpenAIAsyncClient

from asr import BlackboxSTT
from memory import MemoryRecallClient
from tts import BlackboxTTS

logger = logging.getLogger("custom-agent")

CUSTOM_ENV_PATH = Path(__file__).with_name(".env")
load_dotenv(dotenv_path=CUSTOM_ENV_PATH)

AGENT_NAME = os.getenv("CUSTOM_AGENT_NAME", "")

# System prompt (Chinese). Summary: "You are a room item-location assistant.
# When asked where an item is: answer in one Chinese sentence, describe its
# position relative to other items, use no Markdown/emoji/lists/headings/
# coordinate-region labels, and do not explain your reasoning. Answer
# unrelated questions normally."
ROOM_LOCATOR_INSTRUCTIONS = """
你是一个房间物品定位助手。
当用户询问房间内某个物品的位置时:
- 只用一句中文回答
- 描述目标物品和其他物品的相对位置关系
- 不要使用 Markdown、emoji、列表、标题、坐标区域标签
- 不要解释推理过程
如果用户的问题与房间物品定位无关,则正常回答用户问题。
""".strip()

LlmOutput = llm.ChatChunk | str | FlushSentinel


class CustomAgent(Agent):
    """Agent that answers room item-location questions.

    Before each LLM call it optionally asks a ``MemoryRecallClient`` for
    context about the latest user utterance and injects that context into the
    chat history sent to the model. It also logs LLM latency.
    """

    def __init__(self, *, memory_client: MemoryRecallClient | None = None) -> None:
        super().__init__(instructions=ROOM_LOCATOR_INSTRUCTIONS)
        self._memory_client = memory_client

    async def on_enter(self) -> None:
        # Intentionally silent on entry; the agent waits for the user to speak.
        # To greet instead, use:
        #   self.session.generate_reply(instructions="greet the user and introduce yourself")
        pass

    async def llm_node(
        self,
        chat_ctx: ChatContext,
        tools: list[llm.Tool],
        model_settings: ModelSettings,
    ) -> AsyncIterable[LlmOutput]:
        """Run the default LLM node with memory injection and latency logging.

        Args:
            chat_ctx: Conversation history for this turn. It is not mutated; a
                modified copy is used when memory context is injected.
            tools: Tools exposed to the LLM, passed through unchanged.
            model_settings: Model settings, passed through unchanged.

        Returns:
            The default node's result. When it is async-iterable, it is wrapped
            so that timing is logged and the inner stream is closed if the
            consumer stops early.
        """
        started_at = time.perf_counter()

        memory_context = await self._recall_room_memory(chat_ctx)
        if memory_context:
            chat_ctx = _with_memory_as_latest_user_message(chat_ctx, memory_context)

        llm_result = Agent.default.llm_node(self, chat_ctx, tools, model_settings)
        if not hasattr(llm_result, "__aiter__"):
            logger.info(
                "LLM node completed without streaming in %.3fs",
                time.perf_counter() - started_at,
            )
            return llm_result

        return _instrument_llm_stream(llm_result, started_at)

    async def _recall_room_memory(self, chat_ctx: ChatContext) -> str:
        """Fetch memory context for the latest user message.

        Returns an empty string when no memory client is configured, when there
        is no user text, or when the recall fails. Failures are logged, never
        raised, so a memory outage does not block the reply.
        """
        if self._memory_client is None:
            return ""

        user_query = _latest_user_text(chat_ctx)
        if not user_query:
            return ""

        started_at = time.perf_counter()
        try:
            recalled = await self._memory_client.recall(user_query)
            logger.info(
                "Memory recall completed in %.3fs (query_len=%s, memory_len=%s)",
                time.perf_counter() - started_at,
                len(user_query),
                len(recalled),
            )
            return recalled
        except Exception:
            # Best-effort: log with traceback and continue without memory.
            logger.exception(
                "Unexpected memory recall failure after %.3fs",
                time.perf_counter() - started_at,
            )
            return ""


async def _instrument_llm_stream(
    stream: AsyncIterable[LlmOutput],
    started_at: float,
) -> AsyncIterable[LlmOutput]:
    """Re-yield ``stream`` while logging time-to-first-chunk and total time.

    Args:
        stream: The LLM output stream to forward.
        started_at: ``time.perf_counter()`` timestamp taken when the LLM node
            started. All logged durations are relative to it.
    """
    first_chunk_at: float | None = None
    chunk_count = 0
    try:
        async for chunk in stream:
            chunk_count += 1
            if first_chunk_at is None:
                first_chunk_at = time.perf_counter()
                logger.info("LLM first chunk after %.3fs", first_chunk_at - started_at)
            yield chunk
    finally:
        finished_at = time.perf_counter()
        logger.info(
            "LLM stream completed in %.3fs (first_chunk=%.3fs, chunks=%s)",
            finished_at - started_at,
            (first_chunk_at - started_at) if first_chunk_at is not None else -1.0,
            chunk_count,
        )
        # `async for` does not close its source when this wrapper is closed
        # early (e.g. the reply is interrupted). Close it explicitly so any
        # cleanup in the upstream generator (presumably releasing the LLM
        # request) runs now rather than at garbage collection.
        aclose = getattr(stream, "aclose", None)
        if aclose is not None:
            await aclose()


def _latest_user_text(chat_ctx: ChatContext) -> str:
    """Return the stripped text of the most recent user message, or ``""``."""
    for item in reversed(chat_ctx.items):
        if isinstance(item, ChatMessage) and item.role == "user":
            return (item.text_content or "").strip()
    return ""


def _with_memory_as_latest_user_message(chat_ctx: ChatContext, memory_context: str) -> ChatContext:
    """Return a copy of ``chat_ctx`` whose latest user message is ``memory_context``.

    The original context and its messages are not modified. If there is no user
    message, ``memory_context`` is appended as a new user message.

    NOTE(review): this *replaces* the user's text with ``memory_context``. That
    is presumably intentional because the memory service returns a prompt that
    embeds the original question. Confirm against ``MemoryRecallClient.recall``.
    """
    chat_ctx = chat_ctx.copy()
    for index in reversed(range(len(chat_ctx.items))):
        item = chat_ctx.items[index]
        if isinstance(item, ChatMessage) and item.role == "user":
            # Deep-copy so the message object in the caller's context is untouched.
            user_msg = item.model_copy(deep=True)
            user_msg.content = [memory_context]
            chat_ctx.items[index] = user_msg
            return chat_ctx

    chat_ctx.items.append(ChatMessage(role="user", content=[memory_context]))
    return chat_ctx


def _env_int(name: str, default: int) -> int:
    """Read an int env var; fall back to ``default`` if unset, empty, or invalid."""
    value = os.getenv(name)
    if not value:
        return default
    try:
        return int(value)
    except ValueError:
        logger.warning("Invalid integer for %s=%r, using %s", name, value, default)
        return default


def _env_float(name: str, default: float) -> float:
    """Read a float env var; fall back to ``default`` if unset, empty, or invalid."""
    value = os.getenv(name)
    if not value:
        return default
    try:
        return float(value)
    except ValueError:
        logger.warning("Invalid float for %s=%r, using %s", name, value, default)
        return default


def _env_bool(name: str, default: bool) -> bool:
    """Read a boolean env var (1/true/yes/on or 0/false/no/off, case-insensitive).

    Returns ``default`` when the variable is unset. Unrecognized values also
    yield ``default``, with a warning.
    """
    value = os.getenv(name)
    if value is None:
        return default
    normalized = value.strip().lower()
    if normalized in {"1", "true", "yes", "on"}:
        return True
    if normalized in {"0", "false", "no", "off"}:
        return False
    logger.warning("Invalid boolean for %s=%r, using %s", name, value, default)
    return default


def _set_if_present(params: dict[str, str], key: str, value: str | None) -> None:
    """Store ``value`` under ``key`` only when it is a non-empty string."""
    if value:
        params[key] = value


def _tts_params_from_env(model_name: str) -> dict[str, str]:
    """Build model-specific TTS request parameters from environment variables.

    ``model_name`` is matched case-insensitively against voxcpmtts, melotts,
    cosyvoicetts and sovitstts. Unknown models get no extra parameters. Unset
    or empty variables are omitted.
    """
    params: dict[str, str] = {}
    model_name = model_name.lower()

    if model_name == "voxcpmtts":
        _set_if_present(params, "streaming", os.getenv("CUSTOM_TTS_STREAMING"))
        _set_if_present(
            params,
            "prompt_text",
            os.getenv("CUSTOM_TTS_PROMPT_TEXT") or os.getenv("VOXCPM_PROMPT_TEXT"),
        )
        _set_if_present(params, "cfg_value", os.getenv("VOXCPM_CFG_VALUE"))
        _set_if_present(params, "inference_timesteps", os.getenv("VOXCPM_INFERENCE_TIMESTEPS"))
        _set_if_present(params, "do_normalize", os.getenv("VOXCPM_DO_NORMALIZE"))
        _set_if_present(params, "denoise", os.getenv("VOXCPM_DENOISE"))
        _set_if_present(params, "retry_badcase", os.getenv("VOXCPM_RETRY_BADCASE"))
        _set_if_present(
            params,
            "retry_badcase_max_times",
            os.getenv("VOXCPM_RETRY_BADCASE_MAX_TIMES"),
        )
        _set_if_present(
            params,
            "retry_badcase_ratio_threshold",
            os.getenv("VOXCPM_RETRY_BADCASE_RATIO_THRESHOLD"),
        )
    elif model_name == "melotts":
        _set_if_present(params, "speed", os.getenv("CUSTOM_TTS_SPEED"))
    elif model_name == "cosyvoicetts":
        _set_if_present(params, "spk_id", os.getenv("CUSTOM_TTS_SPK_ID"))
        # NOTE(review): env var is *_MODE but the key is "model". Confirm which
        # field the cosyvoice blackbox expects ("mode" would match the env name).
        _set_if_present(params, "model", os.getenv("CUSTOM_TTS_MODE"))
        _set_if_present(params, "prompt_text", os.getenv("CUSTOM_TTS_PROMPT_TEXT"))
        _set_if_present(params, "instruct_text", os.getenv("CUSTOM_TTS_INSTRUCT_TEXT"))
    elif model_name == "sovitstts":
        _set_if_present(params, "text_lang", os.getenv("CUSTOM_TTS_TEXT_LANG"))
        _set_if_present(params, "prompt_lang", os.getenv("CUSTOM_TTS_PROMPT_LANG"))
        _set_if_present(params, "text_split_method", os.getenv("CUSTOM_TTS_TEXT_SPLIT_METHOD"))
        _set_if_present(params, "batch_size", os.getenv("CUSTOM_TTS_BATCH_SIZE"))
        _set_if_present(params, "media_type", os.getenv("CUSTOM_TTS_MEDIA_TYPE"))
        _set_if_present(params, "streaming_mode", os.getenv("CUSTOM_TTS_STREAMING"))
        _set_if_present(params, "ref_audio_path", os.getenv("CUSTOM_TTS_REF_AUDIO_PATH"))
        _set_if_present(params, "prompt_text", os.getenv("CUSTOM_TTS_PROMPT_TEXT"))

    return params


def _tts_prompt_wav_from_env(model_name: str) -> str | None:
    """Return the prompt WAV path for voxcpmtts, or ``None`` for other models/unset."""
    if model_name.lower() != "voxcpmtts":
        return None
    return os.getenv("CUSTOM_TTS_PROMPT_WAV") or os.getenv("VOXCPM_PROMPT_WAV") or None


def _recording_options_from_env() -> RecordingOptions:
    """Build session recording options; every category is off unless enabled."""
    return RecordingOptions(
        audio=_env_bool("CUSTOM_RECORD_AUDIO", False),
        traces=_env_bool("CUSTOM_RECORD_TRACES", False),
        logs=_env_bool("CUSTOM_RECORD_LOGS", False),
        transcript=_env_bool("CUSTOM_RECORD_TRANSCRIPT", False),
    )


server = AgentServer()


def prewarm(proc: JobProcess) -> None:
    """Prewarm hook: load the Silero VAD into ``proc.userdata`` for reuse by jobs."""
    proc.userdata["vad"] = silero.VAD.load()


server.setup_fnc = prewarm


@server.rtc_session(agent_name=AGENT_NAME)
async def entrypoint(ctx: JobContext) -> None:
    """Configure and start an ``AgentSession`` for one room job.

    Raises:
        RuntimeError: if ``CUSTOM_LLM_API_KEY`` is not set.
    """
    ctx.log_context_fields = {
        "room": ctx.room.name,
    }

    # Configuration for custom local endpoints. These can be set in your .env file.
    asr_url = os.getenv("CUSTOM_ASR_URL", "http://10.6.80.21:5003/asr-blackbox")
    asr_model = os.getenv("CUSTOM_ASR_MODEL", "sensevoice")
    asr_language = os.getenv("CUSTOM_ASR_LANGUAGE", "auto")
    asr_output_language = os.getenv("CUSTOM_ASR_OUTPUT_LANGUAGE", "zh")

    llm_base_url = os.getenv("CUSTOM_LLM_BASE_URL")
    llm_model = os.getenv("CUSTOM_LLM_MODEL", "qwen-max")
    llm_api_key = os.getenv("CUSTOM_LLM_API_KEY")
    if not llm_api_key:
        raise RuntimeError(f"CUSTOM_LLM_API_KEY is not set in {CUSTOM_ENV_PATH}")

    tts_url = os.getenv("CUSTOM_TTS_URL") or os.getenv(
        "VOXCPM_TTS_URL", "http://localhost:5050/tts-blackbox"
    )
    tts_model = os.getenv("CUSTOM_TTS_MODEL") or os.getenv("VOXCPM_TTS_MODEL", "voxcpmtts")
    tts_sample_rate = _env_int("CUSTOM_TTS_SAMPLE_RATE", 16000)
    tts_num_channels = _env_int("CUSTOM_TTS_NUM_CHANNELS", 1)
    output_sample_rate = _env_int("CUSTOM_OUTPUT_SAMPLE_RATE", tts_sample_rate)

    memory_url = os.getenv("CUSTOM_MEMORY_URL", "").strip()
    memory_timeout = _env_float("CUSTOM_MEMORY_TIMEOUT", 2.0)
    memory_max_chars = _env_int("CUSTOM_MEMORY_MAX_CHARS", 2000)
    memory_api_key = os.getenv("CUSTOM_MEMORY_API_KEY") or None

    vad = ctx.proc.userdata["vad"]

    blackbox_stt = BlackboxSTT(
        url=asr_url,
        model_name=asr_model,
        language=asr_language,
        output_language=asr_output_language,
        hotwords=os.getenv("CUSTOM_ASR_HOTWORDS"),
        itn=os.getenv("CUSTOM_ASR_ITN"),
        chunk_mode=os.getenv("CUSTOM_ASR_CHUNK_MODE"),
    )
    # The blackbox ASR is driven through a VAD-segmenting StreamAdapter.
    stt_stream = stt.StreamAdapter(stt=blackbox_stt, vad=vad)

    verify_ssl = _env_bool("CUSTOM_LLM_VERIFY_SSL", False)
    if not verify_ssl:
        # Default kept for compatibility with self-signed internal endpoints,
        # but make the insecure setting visible in the logs.
        logger.warning(
            "TLS certificate verification is disabled for the LLM client; "
            "set CUSTOM_LLM_VERIFY_SSL=true to enable it"
        )
    http_client = httpx.AsyncClient(verify=verify_ssl)
    # A new client is created per job; close it on shutdown so its connection
    # pool is not leaked across sessions.
    ctx.add_shutdown_callback(http_client.aclose)

    # OpenAI-compatible endpoints can be used by setting CUSTOM_LLM_BASE_URL.
    # base_url=None (unset or empty) keeps the OpenAI SDK's default endpoint.
    openai_client = OpenAIAsyncClient(
        api_key=llm_api_key,
        base_url=llm_base_url or None,
        http_client=http_client,
    )

    session: AgentSession = AgentSession(
        # 1. Custom ASR blackbox with StreamAdapter
        stt=stt_stream,
        # 2. OpenAI-compatible LLM, e.g. MiniMax, Qwen, or OpenAI.
        llm=openai.LLM(
            model=llm_model,
            client=openai_client,
        ),
        # 3. TTS blackbox
        tts=BlackboxTTS(
            url=tts_url,
            model_name=tts_model,
            params=_tts_params_from_env(tts_model),
            prompt_wav_path=_tts_prompt_wav_from_env(tts_model),
            sample_rate=tts_sample_rate,
            num_channels=tts_num_channels,
        ),
        # 4. Silero VAD
        vad=vad,
        turn_handling=TurnHandlingOptions(
            turn_detection=MultilingualModel(),
            interruption={
                "resume_false_interruption": True,
                "false_interruption_timeout": 1.0,
            },
        ),
        preemptive_generation=_env_bool("CUSTOM_PREEMPTIVE_GENERATION", True),
        aec_warmup_duration=3.0,
        tts_text_transforms=[
            "filter_emoji",
            "filter_markdown",
        ],
    )

    @session.on("metrics_collected")
    def _on_metrics_collected(ev: MetricsCollectedEvent) -> None:
        metrics.log_metrics(ev.metrics)

    @session.on("conversation_item_added")
    def _on_conversation_item_added(event) -> None:
        # Log per-turn metrics attached to user/assistant chat messages.
        item = getattr(event, "item", None)
        if not isinstance(item, ChatMessage):
            return
        if item.role == "user" and item.metrics:
            logger.info("User turn metrics: %s", item.metrics)
        elif item.role == "assistant" and item.metrics:
            logger.info("Assistant turn metrics: %s", item.metrics)

    # Memory recall is optional: enabled only when CUSTOM_MEMORY_URL is set.
    memory_client = (
        MemoryRecallClient(
            url=memory_url,
            timeout=memory_timeout,
            max_chars=memory_max_chars,
            api_key=memory_api_key,
        )
        if memory_url
        else None
    )

    await session.start(
        agent=CustomAgent(memory_client=memory_client),
        room=ctx.room,
        room_options=room_io.RoomOptions(
            audio_output=room_io.AudioOutputOptions(
                sample_rate=output_sample_rate,
                num_channels=tts_num_channels,
            ),
        ),
        record=_recording_options_from_env(),
    )


if __name__ == "__main__":
    cli.run_app(server)