"""LiveKit voice agent wiring custom STT/TTS endpoints to OpenAI-compatible LLMs.

The agent picks a system prompt per turn (room-locator vs. general), optionally
replaces the latest user message with recalled memory, and optionally attaches
the most recent vision frame received over the room data channel.
"""

import base64
import json
import logging
import os
import time
from collections.abc import AsyncIterable
from dataclasses import dataclass
from pathlib import Path

from dotenv import load_dotenv
from memory import MemoryRecallClient
from tts import BlackboxTTS
from asr import BlackboxSTT

from livekit.agents import (
    Agent,
    AgentServer,
    AgentSession,
    ChatContext,
    ChatMessage,
    FlushSentinel,
    JobContext,
    JobProcess,
    MetricsCollectedEvent,
    ModelSettings,
    RecordingOptions,
    TurnHandlingOptions,
    cli,
    llm,
    metrics,
    room_io,
    stt,
)
from livekit.agents.voice.generation import update_instructions as update_chat_instructions
from livekit.plugins import openai, silero
from livekit.plugins.turn_detector.multilingual import MultilingualModel

logger = logging.getLogger("custom-agent")

# The .env file is expected next to this module; it is loaded at import time so
# that every os.getenv() call below sees its values.
CUSTOM_ENV_PATH = Path(__file__).with_name(".env")
load_dotenv(dotenv_path=CUSTOM_ENV_PATH)
AGENT_NAME = os.getenv("CUSTOM_AGENT_NAME", "")

# System prompt used when the latest user turn is classified as a room-locator query.
ROOM_LOCATOR_INSTRUCTIONS = """
你是一个房间物品定位助手。
当用户询问房间内某个物品的位置时:
- 只用一句中文回答
- 描述目标物品和其他物品的相对位置关系
- 不要使用 Markdown、emoji、列表、标题、坐标区域标签
- 不要解释推理过程
如果用户的问题与房间物品定位无关,则正常回答用户问题。
""".strip()

# System prompt used for every other turn.
GENERAL_INSTRUCTIONS = """
你是一个智能语音助手。
正常回答用户问题。
回答自然、简洁、准确。
""".strip()

# Agent modes returned by _select_mode().
ROOM_LOCATOR_MODE = "room_locator"
GENERAL_MODE = "general"

# Input modes returned by _normalize_input_mode().
VOICE_INPUT_MODE = "voice"
VISION_VOICE_INPUT_MODE = "vision_voice"
AUTO_INPUT_MODE = "auto"

# Data-channel topic on which vision frames are accepted.
VISION_FRAME_TOPIC = "vision.frame"


@dataclass
class VisionFrame:
    """A single image received from the room, ready to hand to an LLM."""

    image_data_url: str  # "data:<mime>;base64,<payload>" URL built by VisionFrameStore.update
    received_at: float  # time.monotonic() timestamp taken when the frame was stored
    mime_type: str
    saved_path: str | None = None  # path reported by the sender, if any


class VisionFrameStore:
    """Holds at most one vision frame and hands it out once, if still fresh."""

    def __init__(self, *, max_age_seconds: float) -> None:
        self._max_age_seconds = max_age_seconds
        self._latest_frame: VisionFrame | None = None

    def update(self, *, image: str, mime_type: str, saved_path: str | None = None) -> None:
        """Replace the stored frame with a new one built from base64 *image* data."""
        self._latest_frame = VisionFrame(
            image_data_url=f"data:{mime_type};base64,{image}",
            received_at=time.monotonic(),
            mime_type=mime_type,
            saved_path=saved_path,
        )

    def consume_fresh(self) -> VisionFrame | None:
        """Return the stored frame and clear the slot.

        Returns ``None`` when nothing is stored or when the stored frame is
        older than ``max_age_seconds``; a stale frame is discarded either way.
        """
        frame = self._latest_frame
        if frame is None:
            return None

        age = time.monotonic() - frame.received_at
        # The slot is cleared before the freshness check so a frame is used at most once.
        self._latest_frame = None
        if age > self._max_age_seconds:
            logger.info("Dropping stale vision frame: age=%.3fs", age)
            return None
        return frame


class CustomAgent(Agent):
    """Agent that selects prompt, memory, image, and LLM per turn inside ``llm_node``."""

    def __init__(
        self,
        *,
        memory_client: MemoryRecallClient | None = None,
        vision_store: VisionFrameStore | None = None,
        input_mode: str = AUTO_INPUT_MODE,
        text_llm: llm.LLM | None = None,
        vision_llm: llm.LLM | None = None,
        model_image_save_dir: Path | None = None,
    ) -> None:
        """Store the optional collaborators; all of them may be ``None``.

        Args:
            memory_client: Used to recall context for room-locator queries.
            vision_store: Source of the latest vision frame.
            input_mode: One of the ``*_INPUT_MODE`` constants.
            text_llm: LLM used for turns without an image.
            vision_llm: LLM used for turns with an image.
            model_image_save_dir: Directory where frames sent to the model are saved.
        """
        super().__init__(instructions=GENERAL_INSTRUCTIONS)
        self._memory_client = memory_client
        self._vision_store = vision_store
        self._input_mode = input_mode
        self._text_llm = text_llm
        self._vision_llm = vision_llm
        self._model_image_save_dir = model_image_save_dir

    async def on_enter(self) -> None:
        """Agent entry hook; intentionally a no-op (no greeting is generated)."""

    async def llm_node(
        self,
        chat_ctx: ChatContext,
        tools: list[llm.Tool],
        model_settings: ModelSettings,
    ) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
        """Build the per-turn chat context and stream the selected LLM's output.

        Steps: classify the latest user text into a mode, consume a fresh vision
        frame if allowed, set the matching system instructions on a copy of the
        context, optionally swap in recalled memory, optionally attach the image,
        then run the LLM wrapped in a timing/logging generator.
        """
        llm_node_started_at = time.perf_counter()
        user_query = _latest_user_text(chat_ctx)
        mode = _select_mode(user_query)
        vision_frame = self._consume_vision_frame()
        has_image = vision_frame is not None
        logger.info(
            "Selected agent mode: %s input_mode=%s has_image=%s",
            mode,
            self._input_mode,
            has_image,
        )

        # Work on a copy so the caller's context is not modified.
        chat_ctx = chat_ctx.copy()
        update_chat_instructions(
            chat_ctx,
            instructions=(
                ROOM_LOCATOR_INSTRUCTIONS if mode == ROOM_LOCATOR_MODE else GENERAL_INSTRUCTIONS
            ),
            add_if_missing=True,
        )

        if mode == ROOM_LOCATOR_MODE:
            memory_context = await self._recall_room_memory(chat_ctx)
            if memory_context:
                chat_ctx = _with_memory_as_latest_user_message(chat_ctx, memory_context)

        if vision_frame is not None:
            self._save_model_vision_frame(vision_frame)
            chat_ctx = _with_vision_as_latest_user_message(chat_ctx, vision_frame)

        llm_result = self._run_selected_llm(
            chat_ctx,
            tools,
            model_settings,
            has_image=has_image,
        )

        if not hasattr(llm_result, "__aiter__"):
            elapsed = time.perf_counter() - llm_node_started_at
            logger.info("LLM node completed without streaming in %.3fs", elapsed)
            return llm_result

        async def _instrumented_stream() -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
            """Relay chunks from the LLM while logging first-chunk and total latency."""
            first_chunk_at: float | None = None
            chunk_count = 0
            try:
                async for chunk in llm_result:
                    chunk_count += 1
                    if first_chunk_at is None:
                        first_chunk_at = time.perf_counter()
                        logger.info(
                            "LLM first chunk after %.3fs",
                            first_chunk_at - llm_node_started_at,
                        )
                    yield chunk
            finally:
                finished_at = time.perf_counter()
                logger.info(
                    "LLM stream completed in %.3fs (first_chunk=%.3fs, chunks=%s)",
                    finished_at - llm_node_started_at,
                    # -1.0 marks "no chunk was ever received".
                    (first_chunk_at - llm_node_started_at) if first_chunk_at is not None else -1.0,
                    chunk_count,
                )
                # When this wrapper is closed early (consumer stops iterating),
                # the upstream generator is still suspended; close it explicitly
                # so its `async with` block exits now instead of at GC time.
                close_upstream = getattr(llm_result, "aclose", None)
                if close_upstream is not None:
                    await close_upstream()

        return _instrumented_stream()

    def _consume_vision_frame(self) -> VisionFrame | None:
        """Return a fresh frame unless input is voice-only or no store is configured."""
        if self._input_mode == VOICE_INPUT_MODE or self._vision_store is None:
            return None
        return self._vision_store.consume_fresh()

    def _save_model_vision_frame(self, vision_frame: VisionFrame) -> None:
        """Best-effort save of the frame sent to the model; failures are only logged.

        Does nothing when no save directory is configured.
        NOTE(review): this performs synchronous decode + disk I/O and is called
        from the async ``llm_node`` path.
        """
        if self._model_image_save_dir is None:
            return

        try:
            # The data URL is "data:<mime>;base64,<payload>"; keep only the payload.
            _, b64_data = vision_frame.image_data_url.split(",", 1)
            image_bytes = base64.b64decode(b64_data, validate=True)
        except Exception:
            logger.exception("Failed to decode model vision frame for debug save")
            return

        extension = _image_extension_from_mime_type(vision_frame.mime_type)
        timestamp_ms = int(time.time() * 1000)
        path = self._model_image_save_dir / f"{timestamp_ms}_model_input{extension}"
        try:
            self._model_image_save_dir.mkdir(parents=True, exist_ok=True)
            path.write_bytes(image_bytes)
        except Exception:
            logger.exception("Failed to save model vision frame: path=%s", path)
            return

        logger.info(
            "Saved model vision frame: path=%s bytes=%s source_path=%s",
            path,
            len(image_bytes),
            vision_frame.saved_path,
        )

    def _run_selected_llm(
        self,
        chat_ctx: ChatContext,
        tools: list[llm.Tool],
        model_settings: ModelSettings,
        *,
        has_image: bool,
    ) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
        """Stream from the vision LLM when *has_image*, otherwise from the text LLM.

        Falls back to ``Agent.default.llm_node`` when the chosen LLM is not configured.
        """
        selected_llm = self._vision_llm if has_image else self._text_llm
        if selected_llm is None:
            return Agent.default.llm_node(self, chat_ctx, tools, model_settings)

        activity = self._get_activity_or_raise()
        tool_choice = model_settings.tool_choice
        conn_options = activity.session.conn_options.llm_conn_options

        async def _stream() -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
            async with selected_llm.chat(
                chat_ctx=chat_ctx,
                tools=tools,
                tool_choice=tool_choice,
                conn_options=conn_options,
            ) as stream:
                async for chunk in stream:
                    yield chunk

        return _stream()

    async def _recall_room_memory(self, chat_ctx: ChatContext) -> str:
        """Recall memory for the latest user text; return ``""`` on any failure.

        Also returns ``""`` when no memory client is configured or there is no
        user text to query with.
        """
        if self._memory_client is None:
            return ""

        user_query = _latest_user_text(chat_ctx)
        if not user_query:
            return ""

        started_at = time.perf_counter()
        try:
            recalled = await self._memory_client.recall(user_query)
            elapsed = time.perf_counter() - started_at
            logger.info(
                "Memory recall completed in %.3fs (query_len=%s, memory_len=%s)",
                elapsed,
                len(user_query),
                len(recalled),
            )
            return recalled
        except Exception:
            # Memory is optional context; a failure must not break the turn.
            logger.exception(
                "Unexpected memory recall failure after %.3fs",
                time.perf_counter() - started_at,
            )
            return ""


def _select_mode(user_query: str) -> str:
    """Return ``ROOM_LOCATOR_MODE`` for room-locator queries, else ``GENERAL_MODE``."""
    normalized = _normalize_text(user_query)
    if normalized and _is_room_locator_query(normalized):
        return ROOM_LOCATOR_MODE
    return GENERAL_MODE


def _is_room_locator_query(normalized_text: str) -> bool:
    """Keyword heuristic deciding whether *normalized_text* asks where an object is.

    Any software-related keyword vetoes the match. Otherwise the text matches
    when it contains a spatial keyword together with a room/object keyword, or
    a spatial keyword alone in a text of at most 12 characters.
    """
    room_context_hints = (
        "房间",
        "屋里",
        "屋子",
        "室内",
        "客厅",
        "卧室",
        "书房",
        "厨房",
        "餐厅",
        "沙发",
        "桌",
        "椅",
        "床",
        "门",
        "窗",
        "柜",
        "电视",
        "空调",
        "书架",
        "灯",
        "冰箱",
        "茶几",
        "电脑",
        "包",
        "瓶",
        "相机",
        "植物",
    )
    spatial_hints = (
        "在哪里",
        "在哪",
        "位置",
        "方位",
        "旁边",
        "左边",
        "右边",
        "前面",
        "后面",
        "上面",
        "下面",
        "附近",
        "对面",
        "靠近",
        "挨着",
        "隔着",
    )
    software_hints = (
        "python",
        "代码",
        "函数",
        "class",
        "bug",
        "日志",
        "logging",
        "api",
        "server",
        "agent",
        "prompt",
        "模型",
        "数据库",
        "git",
        "uv",
        "ruff",
        "mypy",
    )

    if any(hint in normalized_text for hint in software_hints):
        return False

    if not any(hint in normalized_text for hint in spatial_hints):
        return False

    if any(hint in normalized_text for hint in room_context_hints):
        return True
    # A short spatial question is treated as a locator query even without a
    # recognised room/object keyword.
    return len(normalized_text) <= 12


def _normalize_text(text: str) -> str:
    """Remove all whitespace from *text* and lowercase it."""
    return "".join(text.split()).lower()


def _latest_user_text(chat_ctx: ChatContext) -> str:
    """Return the stripped text of the most recent user message, or ``""`` if none."""
    for item in reversed(chat_ctx.items):
        if isinstance(item, ChatMessage) and item.role == "user":
            return (item.text_content or "").strip()
    return ""


def _with_memory_as_latest_user_message(chat_ctx: ChatContext, memory_context: str) -> ChatContext:
    """Return a copy of *chat_ctx* whose latest user message content is *memory_context*.

    The latest user message's content is replaced (not extended). When the
    context has no user message, a new one is appended.
    """
    chat_ctx = chat_ctx.copy()
    for index in reversed(range(len(chat_ctx.items))):
        item = chat_ctx.items[index]
        if isinstance(item, ChatMessage) and item.role == "user":
            # Deep-copy before editing so the message object in the source
            # context is left untouched.
            user_msg = item.model_copy(deep=True)
            user_msg.content = [memory_context]
            chat_ctx.items[index] = user_msg
            return chat_ctx

    chat_ctx.items.append(ChatMessage(role="user", content=[memory_context]))
    return chat_ctx


def _with_vision_as_latest_user_message(
    chat_ctx: ChatContext, vision_frame: VisionFrame
) -> ChatContext:
    """Return a copy of *chat_ctx* with the frame appended to the latest user message.

    When the context has no user message, a new one containing only the image
    is appended.
    """
    chat_ctx = chat_ctx.copy()
    image_content = llm.ImageContent(
        image=vision_frame.image_data_url,
        mime_type=vision_frame.mime_type,
        inference_detail="auto",
    )
    for index in reversed(range(len(chat_ctx.items))):
        item = chat_ctx.items[index]
        if isinstance(item, ChatMessage) and item.role == "user":
            user_msg = item.model_copy(deep=True)
            user_msg.content = [*user_msg.content, image_content]
            chat_ctx.items[index] = user_msg
            return chat_ctx

    chat_ctx.items.append(ChatMessage(role="user", content=[image_content]))
    return chat_ctx


def _normalize_input_mode(value: str | None) -> str:
    """Map a user-supplied mode string to one of the ``*_INPUT_MODE`` constants.

    Matching ignores case, surrounding whitespace, and ``-`` vs ``_``. Empty or
    unrecognised values yield ``AUTO_INPUT_MODE`` (the latter with a warning).
    """
    if not value:
        return AUTO_INPUT_MODE

    normalized = value.strip().lower().replace("-", "_")
    aliases = {
        "image_voice": VISION_VOICE_INPUT_MODE,
        "image": VISION_VOICE_INPUT_MODE,
        "vision": VISION_VOICE_INPUT_MODE,
        "vision_voice": VISION_VOICE_INPUT_MODE,
        "voice_image": VISION_VOICE_INPUT_MODE,
        "audio": VOICE_INPUT_MODE,
        "voice": VOICE_INPUT_MODE,
        "auto": AUTO_INPUT_MODE,
    }
    mode = aliases.get(normalized)
    if mode is not None:
        return mode

    logger.warning("Invalid CUSTOM_AGENT_INPUT_MODE=%r, using %s", value, AUTO_INPUT_MODE)
    return AUTO_INPUT_MODE


def _image_extension_from_mime_type(mime_type: str) -> str:
    """Return the file extension for *mime_type*; anything unrecognised maps to ``.jpg``."""
    known_extensions = {
        "image/png": ".png",
        "image/webp": ".webp",
        "image/gif": ".gif",
    }
    return known_extensions.get(mime_type.strip().lower(), ".jpg")


def _model_image_save_dir_from_env() -> Path | None:
    """Return the directory for saved model images, or ``None`` when saving is disabled.

    Saving is on unless ``CUSTOM_SAVE_MODEL_IMAGES`` is false. The directory is
    ``CUSTOM_MODEL_IMAGE_SAVE_DIR`` when set, else ``model_images`` next to this file.
    """
    if not _env_bool("CUSTOM_SAVE_MODEL_IMAGES", True):
        return None

    configured = os.getenv("CUSTOM_MODEL_IMAGE_SAVE_DIR")
    if configured:
        return Path(configured).expanduser()
    return Path(__file__).with_name("model_images")


server = AgentServer()


def prewarm(proc: JobProcess) -> None:
    """Load the Silero VAD once per process and stash it in ``proc.userdata``."""
    # Load Silero VAD as requested
    proc.userdata["vad"] = silero.VAD.load()


server.setup_fnc = prewarm


@server.rtc_session(agent_name=AGENT_NAME)
async def entrypoint(ctx: JobContext) -> None:
    """Configure STT, LLMs, TTS, vision/memory plumbing, and start the session.

    Raises:
        RuntimeError: If ``CUSTOM_LLM_API_KEY`` is not set.
    """
    ctx.log_context_fields = {
        "room": ctx.room.name,
    }

    # Configuration for custom local endpoints. These can be set in your .env file.
    asr_url = os.getenv("CUSTOM_ASR_URL", "http://10.6.80.21:5003/asr-blackbox")
    asr_model = os.getenv("CUSTOM_ASR_MODEL", "sensevoice")
    asr_language = os.getenv("CUSTOM_ASR_LANGUAGE", "auto")
    asr_output_language = os.getenv("CUSTOM_ASR_OUTPUT_LANGUAGE", "zh")

    llm_base_url = os.getenv("CUSTOM_LLM_BASE_URL")
    llm_model = os.getenv("CUSTOM_LLM_MODEL", "qwen-max")
    llm_api_key = os.getenv("CUSTOM_LLM_API_KEY")
    # Text and vision models default to the base model unless overridden.
    text_llm_model = os.getenv("CUSTOM_TEXT_LLM_MODEL", llm_model)
    vision_llm_model = os.getenv("CUSTOM_VISION_LLM_MODEL", llm_model)
    input_mode = _normalize_input_mode(os.getenv("CUSTOM_AGENT_INPUT_MODE"))

    if not llm_api_key:
        raise RuntimeError(f"CUSTOM_LLM_API_KEY is not set in {CUSTOM_ENV_PATH}")

    # CUSTOM_* settings take precedence over the VOXCPM_* names.
    tts_url = os.getenv("CUSTOM_TTS_URL") or os.getenv(
        "VOXCPM_TTS_URL", "http://localhost:5000/tts-blackbox"
    )
    tts_model = os.getenv("CUSTOM_TTS_MODEL") or os.getenv("VOXCPM_TTS_MODEL", "voxcpmtts")
    tts_sample_rate = _env_int("CUSTOM_TTS_SAMPLE_RATE", 16000)
    tts_num_channels = _env_int("CUSTOM_TTS_NUM_CHANNELS", 1)
    output_sample_rate = _env_int("CUSTOM_OUTPUT_SAMPLE_RATE", tts_sample_rate)

    memory_url = os.getenv("CUSTOM_MEMORY_URL", "").strip()
    memory_timeout = _env_float("CUSTOM_MEMORY_TIMEOUT", 2.0)
    memory_max_chars = _env_int("CUSTOM_MEMORY_MAX_CHARS", 2000)
    memory_api_key = os.getenv("CUSTOM_MEMORY_API_KEY") or None

    blackbox_stt = BlackboxSTT(
        url=asr_url,
        model_name=asr_model,
        language=asr_language,
        output_language=asr_output_language,
        hotwords=os.getenv("CUSTOM_ASR_HOTWORDS"),
        itn=os.getenv("CUSTOM_ASR_ITN"),
        chunk_mode=os.getenv("CUSTOM_ASR_CHUNK_MODE"),
    )
    stt_stream = stt.StreamAdapter(stt=blackbox_stt, vad=ctx.proc.userdata["vad"])

    import httpx
    from openai import AsyncClient as OpenAIAsyncClient

    # OpenAI-compatible endpoints can be used by setting CUSTOM_LLM_BASE_URL.
    # NOTE(review): TLS certificate verification is OFF unless
    # CUSTOM_LLM_VERIFY_SSL is set to a true value - confirm this default is
    # intended for every deployment.
    http_client = httpx.AsyncClient(verify=_env_bool("CUSTOM_LLM_VERIFY_SSL", False))

    async def _close_http_client() -> None:
        """Release the HTTP connection pool created for this job."""
        await http_client.aclose()

    # The client is created by this function, so it is closed here as well,
    # when the job shuts down.
    ctx.add_shutdown_callback(_close_http_client)

    if llm_base_url:
        openai_client = OpenAIAsyncClient(
            api_key=llm_api_key,
            base_url=llm_base_url,
            http_client=http_client,
        )
    else:
        openai_client = OpenAIAsyncClient(
            api_key=llm_api_key,
            http_client=http_client,
        )

    base_llm = openai.LLM(
        model=llm_model,
        client=openai_client,
    )
    # Reuse the base LLM object when the text/vision model names match it.
    text_llm = (
        openai.LLM(model=text_llm_model, client=openai_client)
        if text_llm_model != llm_model
        else base_llm
    )
    vision_llm = (
        openai.LLM(model=vision_llm_model, client=openai_client)
        if vision_llm_model != llm_model
        else base_llm
    )
    vision_store = VisionFrameStore(
        max_age_seconds=_env_float("CUSTOM_VISION_FRAME_MAX_AGE_SECONDS", 8.0)
    )

    session: AgentSession = AgentSession(
        # 1. Custom ASR blackbox with StreamAdapter
        stt=stt_stream,
        # 2. OpenAI-compatible LLM, e.g. MiniMax, Qwen, or OpenAI.
        llm=base_llm,
        # 3. TTS blackbox
        tts=BlackboxTTS(
            url=tts_url,
            model_name=tts_model,
            params=_tts_params_from_env(tts_model),
            prompt_wav_path=_tts_prompt_wav_from_env(tts_model),
            sample_rate=tts_sample_rate,
            num_channels=tts_num_channels,
        ),
        # 4. Silero VAD
        vad=ctx.proc.userdata["vad"],
        turn_handling=TurnHandlingOptions(
            turn_detection=MultilingualModel(),
            interruption={
                "resume_false_interruption": True,
                "false_interruption_timeout": 1.0,
            },
        ),
        preemptive_generation=_env_bool("CUSTOM_PREEMPTIVE_GENERATION", True),
        aec_warmup_duration=3.0,
        tts_text_transforms=[
            "filter_emoji",
            "filter_markdown",
        ],
    )

    @session.on("metrics_collected")
    def _on_metrics_collected(ev: MetricsCollectedEvent) -> None:
        metrics.log_metrics(ev.metrics)

    @session.on("conversation_item_added")
    def _on_conversation_item_added(event) -> None:
        """Log per-turn metrics for user and assistant messages that carry them."""
        item = getattr(event, "item", None)
        if not isinstance(item, ChatMessage):
            return

        if item.role == "user" and item.metrics:
            logger.info("User turn metrics: %s", item.metrics)
        elif item.role == "assistant" and item.metrics:
            logger.info("Assistant turn metrics: %s", item.metrics)

    @ctx.room.on("data_received")
    def _on_data_received(data_packet) -> None:
        """Cache a vision frame carried by a JSON data packet.

        Packets with no topic or the vision topic are inspected; anything that
        is not a JSON object tagged as a vision frame is ignored.
        """
        packet_topic = getattr(data_packet, "topic", None)
        if packet_topic not in {None, "", VISION_FRAME_TOPIC}:
            return

        if input_mode == VOICE_INPUT_MODE:
            logger.info("Ignoring vision frame because CUSTOM_AGENT_INPUT_MODE=%s", input_mode)
            return

        try:
            payload = json.loads(data_packet.data.decode("utf-8"))
        except Exception:
            logger.exception("Failed to decode vision frame payload")
            return

        # Valid JSON is not necessarily an object; a list/str/number payload
        # has no .get() and cannot be a vision frame.
        if not isinstance(payload, dict):
            return

        if payload.get("type") != "vision_frame" and payload.get("topic") != VISION_FRAME_TOPIC:
            return

        image = payload.get("image")
        if not isinstance(image, str) or not image:
            logger.warning("Received vision frame without image data")
            return

        mime_type = payload.get("mime_type")
        if not isinstance(mime_type, str) or not mime_type:
            mime_type = "image/jpeg"

        saved_path = payload.get("saved_path")
        vision_store.update(
            image=image,
            mime_type=mime_type,
            saved_path=saved_path if isinstance(saved_path, str) else None,
        )
        logger.info(
            "Cached vision frame: mime_type=%s image_chars=%s saved_path=%s",
            mime_type,
            len(image),
            saved_path,
        )

    # Memory recall is enabled only when a URL is configured.
    memory_client = (
        MemoryRecallClient(
            url=memory_url,
            timeout=memory_timeout,
            max_chars=memory_max_chars,
            api_key=memory_api_key,
        )
        if memory_url
        else None
    )

    await session.start(
        agent=CustomAgent(
            memory_client=memory_client,
            vision_store=vision_store,
            input_mode=input_mode,
            text_llm=text_llm,
            vision_llm=vision_llm,
            model_image_save_dir=_model_image_save_dir_from_env(),
        ),
        room=ctx.room,
        room_options=room_io.RoomOptions(
            audio_output=room_io.AudioOutputOptions(
                sample_rate=output_sample_rate,
                num_channels=tts_num_channels,
            ),
        ),
        record=_recording_options_from_env(),
    )


def _tts_params_from_env(model_name: str) -> dict[str, str]:
    """Collect model-specific TTS request parameters from environment variables.

    Only variables that are set to a non-empty value are included. Model names
    other than voxcpmtts, melotts, cosyvoicetts, and sovitstts yield ``{}``.
    """
    params: dict[str, str] = {}
    model_name = model_name.lower()

    if model_name == "voxcpmtts":
        _set_if_present(params, "streaming", os.getenv("CUSTOM_TTS_STREAMING"))
        _set_if_present(
            params,
            "prompt_text",
            os.getenv("CUSTOM_TTS_PROMPT_TEXT") or os.getenv("VOXCPM_PROMPT_TEXT"),
        )
        _set_if_present(params, "cfg_value", os.getenv("VOXCPM_CFG_VALUE"))
        _set_if_present(params, "inference_timesteps", os.getenv("VOXCPM_INFERENCE_TIMESTEPS"))
        _set_if_present(params, "do_normalize", os.getenv("VOXCPM_DO_NORMALIZE"))
        _set_if_present(params, "denoise", os.getenv("VOXCPM_DENOISE"))
        _set_if_present(params, "retry_badcase", os.getenv("VOXCPM_RETRY_BADCASE"))
        _set_if_present(
            params,
            "retry_badcase_max_times",
            os.getenv("VOXCPM_RETRY_BADCASE_MAX_TIMES"),
        )
        _set_if_present(
            params,
            "retry_badcase_ratio_threshold",
            os.getenv("VOXCPM_RETRY_BADCASE_RATIO_THRESHOLD"),
        )
    elif model_name == "melotts":
        _set_if_present(params, "speed", os.getenv("CUSTOM_TTS_SPEED"))
    elif model_name == "cosyvoicetts":
        _set_if_present(params, "spk_id", os.getenv("CUSTOM_TTS_SPK_ID"))
        _set_if_present(params, "model", os.getenv("CUSTOM_TTS_MODE"))
        _set_if_present(params, "prompt_text", os.getenv("CUSTOM_TTS_PROMPT_TEXT"))
        _set_if_present(params, "instruct_text", os.getenv("CUSTOM_TTS_INSTRUCT_TEXT"))
    elif model_name == "sovitstts":
        _set_if_present(params, "text_lang", os.getenv("CUSTOM_TTS_TEXT_LANG"))
        _set_if_present(params, "prompt_lang", os.getenv("CUSTOM_TTS_PROMPT_LANG"))
        _set_if_present(params, "text_split_method", os.getenv("CUSTOM_TTS_TEXT_SPLIT_METHOD"))
        _set_if_present(params, "batch_size", os.getenv("CUSTOM_TTS_BATCH_SIZE"))
        _set_if_present(params, "media_type", os.getenv("CUSTOM_TTS_MEDIA_TYPE"))
        _set_if_present(params, "streaming_mode", os.getenv("CUSTOM_TTS_STREAMING"))
        _set_if_present(params, "ref_audio_path", os.getenv("CUSTOM_TTS_REF_AUDIO_PATH"))
        _set_if_present(params, "prompt_text", os.getenv("CUSTOM_TTS_PROMPT_TEXT"))

    return params


def _tts_prompt_wav_from_env(model_name: str) -> str | None:
    """Return the configured prompt WAV path for voxcpmtts, else ``None``."""
    if model_name.lower() != "voxcpmtts":
        return None
    return os.getenv("CUSTOM_TTS_PROMPT_WAV") or os.getenv("VOXCPM_PROMPT_WAV") or None


def _set_if_present(params: dict[str, str], key: str, value: str | None) -> None:
    """Store *value* under *key* in *params* unless it is ``None`` or empty."""
    if value:
        params[key] = value


def _env_int(name: str, default: int) -> int:
    """Read env var *name* as an int; unset, empty, or invalid values yield *default*."""
    value = os.getenv(name)
    if not value:
        return default
    try:
        return int(value)
    except ValueError:
        logger.warning("Invalid integer for %s=%r, using %s", name, value, default)
        return default


def _env_float(name: str, default: float) -> float:
    """Read env var *name* as a float; unset, empty, or invalid values yield *default*."""
    value = os.getenv(name)
    if not value:
        return default
    try:
        return float(value)
    except ValueError:
        logger.warning("Invalid float for %s=%r, using %s", name, value, default)
        return default


def _env_bool(name: str, default: bool) -> bool:
    """Read env var *name* as a boolean.

    Accepts 1/true/yes/on and 0/false/no/off (case-insensitive, trimmed).
    An unset variable yields *default*; any other value yields *default* with
    a warning.
    """
    value = os.getenv(name)
    if value is None:
        return default

    normalized = value.strip().lower()
    if normalized in {"1", "true", "yes", "on"}:
        return True
    if normalized in {"0", "false", "no", "off"}:
        return False

    logger.warning("Invalid boolean for %s=%r, using %s", name, value, default)
    return default


def _recording_options_from_env() -> RecordingOptions:
    """Build recording options from ``CUSTOM_RECORD_*`` variables; all default to off."""
    return RecordingOptions(
        audio=_env_bool("CUSTOM_RECORD_AUDIO", False),
        traces=_env_bool("CUSTOM_RECORD_TRACES", False),
        logs=_env_bool("CUSTOM_RECORD_LOGS", False),
        transcript=_env_bool("CUSTOM_RECORD_TRANSCRIPT", False),
    )


if __name__ == "__main__":
    cli.run_app(server)