From e0973231766a1dcd748d6ae633f031dffc7b4480 Mon Sep 17 00:00:00 2001 From: 0Xiao0 <511201264@qq.com> Date: Mon, 25 May 2026 17:17:38 +0800 Subject: [PATCH] feat: support vlm chat --- .env => .env.example | 6 +- .gitignore | 3 + custom_agent.py | 282 +++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 282 insertions(+), 9 deletions(-) rename .env => .env.example (92%) create mode 100644 .gitignore diff --git a/.env b/.env.example similarity index 92% rename from .env rename to .env.example index 76b2b95..8f023d6 100644 --- a/.env +++ b/.env.example @@ -23,6 +23,10 @@ CUSTOM_LLM_BASE_URL=http://localhost/v1 CUSTOM_LLM_MODEL=Qwen-VL CUSTOM_LLM_API_KEY= CUSTOM_LLM_VERIFY_SSL=false +CUSTOM_SAVE_MODEL_IMAGES=false + +# CUSTOM_TEXT_LLM_MODEL= +# CUSTOM_VISION_LLM_MODEL= # CUSTOM_LLM_BASE_URL=https://api.deepseek.com # CUSTOM_LLM_MODEL=deepseek-v4-flash @@ -31,7 +35,7 @@ CUSTOM_LLM_VERIFY_SSL=false # TTS blackbox -CUSTOM_TTS_URL=http://localhost:5000/tts-blackbox +CUSTOM_TTS_URL=http://localhost:5050/tts-blackbox CUSTOM_TTS_MODEL=voxcpmtts # CUSTOM_TTS_PROMPT_WAV=/home/verachen/Workspace/livekit/agents/2food.wav CUSTOM_TTS_STREAMING=true diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ef52c5f --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +__pycache__/ +.env +model_images/ \ No newline at end of file diff --git a/custom_agent.py b/custom_agent.py index 16a5728..8f274fd 100644 --- a/custom_agent.py +++ b/custom_agent.py @@ -1,7 +1,10 @@ +import base64 +import json import logging import os import time from collections.abc import AsyncIterable +from dataclasses import dataclass from pathlib import Path from dotenv import load_dotenv @@ -56,12 +59,65 @@ GENERAL_INSTRUCTIONS = """ ROOM_LOCATOR_MODE = "room_locator" GENERAL_MODE = "general" +VOICE_INPUT_MODE = "voice" +VISION_VOICE_INPUT_MODE = "vision_voice" +AUTO_INPUT_MODE = "auto" +VISION_FRAME_TOPIC = "vision.frame" + + +@dataclass +class VisionFrame: + image_data_url: str + received_at: float + mime_type: str + saved_path: str | None = None + + +class VisionFrameStore: + def __init__(self, *, max_age_seconds: float) -> None: + self._max_age_seconds = max_age_seconds + self._latest_frame: VisionFrame | None = None + + def update(self, *, image: str, mime_type: str, saved_path: str | None = None) -> None: + self._latest_frame = VisionFrame( + image_data_url=f"data:{mime_type};base64,{image}", + received_at=time.monotonic(), + mime_type=mime_type, + saved_path=saved_path, + ) + + def consume_fresh(self) -> VisionFrame | None: + frame = self._latest_frame + if frame is None: + return None + + age = time.monotonic() - frame.received_at + self._latest_frame = None + if age > self._max_age_seconds: + logger.info("Dropping stale vision frame: age=%.3fs", age) + return None + + return frame class CustomAgent(Agent): - def __init__(self, *, memory_client: MemoryRecallClient | None = None) -> None: + def __init__( + self, + *, + memory_client: MemoryRecallClient | None = None, + vision_store: VisionFrameStore | None = None, + input_mode: str = AUTO_INPUT_MODE, + text_llm: llm.LLM | None = None, + vision_llm: llm.LLM | None = None, + model_image_save_dir: Path | None = None, + ) -> None: super().__init__(instructions=GENERAL_INSTRUCTIONS) self._memory_client = memory_client + self._vision_store = vision_store + self._input_mode = input_mode + self._text_llm = text_llm + self._vision_llm = vision_llm + self._model_image_save_dir = model_image_save_dir async def on_enter(self) -> None: # self.session.generate_reply(instructions="greet the user and introduce yourself") @@ -77,7 +133,13 @@ class CustomAgent(Agent): user_query = _latest_user_text(chat_ctx) mode = _select_mode(user_query) - logger.info("Selected agent mode: %s", mode) + vision_frame = self._consume_vision_frame() + logger.info( + "Selected agent mode: %s input_mode=%s has_image=%s", + mode, + self._input_mode, + vision_frame is not None, + ) chat_ctx = chat_ctx.copy() update_chat_instructions( @@ -93,7 +155,16 @@ class CustomAgent(Agent): if memory_context: chat_ctx = _with_memory_as_latest_user_message(chat_ctx, memory_context) - llm_result = Agent.default.llm_node(self, chat_ctx, tools, model_settings) + if vision_frame is not None: + self._save_model_vision_frame(vision_frame) + chat_ctx = _with_vision_as_latest_user_message(chat_ctx, vision_frame) + + llm_result = self._run_selected_llm( + chat_ctx, + tools, + model_settings, + has_image=vision_frame is not None, + ) if not hasattr(llm_result, "__aiter__"): elapsed = time.perf_counter() - llm_node_started_at logger.info("LLM node completed without streaming in %.3fs", elapsed) @@ -123,6 +194,68 @@ class CustomAgent(Agent): return _instrumented_stream() + def _consume_vision_frame(self) -> VisionFrame | None: + if self._input_mode == VOICE_INPUT_MODE or self._vision_store is None: + return None + return self._vision_store.consume_fresh() + + def _save_model_vision_frame(self, vision_frame: VisionFrame) -> None: + if self._model_image_save_dir is None: + return + + try: + _, b64_data = vision_frame.image_data_url.split(",", 1) + image_bytes = base64.b64decode(b64_data, validate=True) + except Exception: + logger.exception("Failed to decode model vision frame for debug save") + return + + extension = _image_extension_from_mime_type(vision_frame.mime_type) + timestamp_ms = int(time.time() * 1000) + path = self._model_image_save_dir / f"{timestamp_ms}_model_input{extension}" + + try: + self._model_image_save_dir.mkdir(parents=True, exist_ok=True) + path.write_bytes(image_bytes) + except Exception: + logger.exception("Failed to save model vision frame: path=%s", path) + return + + logger.info( + "Saved model vision frame: path=%s bytes=%s source_path=%s", + path, + len(image_bytes), + vision_frame.saved_path, + ) + + def _run_selected_llm( + self, + chat_ctx: ChatContext, + tools: list[llm.Tool], + model_settings: ModelSettings, + *, + has_image: bool, + ) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]: + selected_llm = self._vision_llm if has_image else self._text_llm + if selected_llm is None: + return Agent.default.llm_node(self, chat_ctx, tools, model_settings) + + activity = self._get_activity_or_raise() + tool_choice = model_settings.tool_choice + conn_options = activity.session.conn_options.llm_conn_options + + async def _stream() -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]: + async with selected_llm.chat( + chat_ctx=chat_ctx, + tools=tools, + tool_choice=tool_choice, + conn_options=conn_options, + ) as stream: + async for chunk in stream: + yield chunk + + return _stream() + async def _recall_room_memory(self, chat_ctx: ChatContext) -> str: if self._memory_client is None: return "" @@ -269,6 +402,73 @@ def _with_memory_as_latest_user_message(chat_ctx: ChatContext, memory_context: s return chat_ctx +def _with_vision_as_latest_user_message(chat_ctx: ChatContext, vision_frame: VisionFrame) -> ChatContext: + chat_ctx = chat_ctx.copy() + image_content = llm.ImageContent( + image=vision_frame.image_data_url, + mime_type=vision_frame.mime_type, + inference_detail="auto", + ) + + for index in range(len(chat_ctx.items) - 1, -1, -1): + item = chat_ctx.items[index] + if isinstance(item, ChatMessage) and item.role == "user": + user_msg = item.model_copy(deep=True) + content = list(user_msg.content) + content.append(image_content) + user_msg.content = content + chat_ctx.items[index] = user_msg + return chat_ctx + + chat_ctx.items.append(ChatMessage(role="user", content=[image_content])) + return chat_ctx + + +def _normalize_input_mode(value: str | None) -> str: + if not value: + return AUTO_INPUT_MODE + + normalized = value.strip().lower().replace("-", "_") + aliases = { + "image_voice": VISION_VOICE_INPUT_MODE, + "image": VISION_VOICE_INPUT_MODE, + "vision": VISION_VOICE_INPUT_MODE, + "vision_voice": VISION_VOICE_INPUT_MODE, + "voice_image": VISION_VOICE_INPUT_MODE, + "audio": VOICE_INPUT_MODE, + "voice": VOICE_INPUT_MODE, + "auto": AUTO_INPUT_MODE, + } + mode = aliases.get(normalized) + if mode is not None: + return mode + + logger.warning("Invalid CUSTOM_AGENT_INPUT_MODE=%r, using %s", value, AUTO_INPUT_MODE) + return AUTO_INPUT_MODE + + +def _image_extension_from_mime_type(mime_type: str) -> str: + normalized = mime_type.strip().lower() + if normalized == "image/png": + return ".png" + if normalized == "image/webp": + return ".webp" + if normalized == "image/gif": + return ".gif" + return ".jpg" + + +def _model_image_save_dir_from_env() -> Path | None: + if not _env_bool("CUSTOM_SAVE_MODEL_IMAGES", True): + return None + + configured = os.getenv("CUSTOM_MODEL_IMAGE_SAVE_DIR") + if configured: + return Path(configured).expanduser() + + return Path(__file__).with_name("model_images") + + server = AgentServer() @@ -295,6 +495,9 @@ async def entrypoint(ctx: JobContext) -> None: LLM_BASE_URL = os.getenv("CUSTOM_LLM_BASE_URL") LLM_MODEL = os.getenv("CUSTOM_LLM_MODEL", "qwen-max") LLM_API_KEY = os.getenv("CUSTOM_LLM_API_KEY") + TEXT_LLM_MODEL = os.getenv("CUSTOM_TEXT_LLM_MODEL", LLM_MODEL) + VISION_LLM_MODEL = os.getenv("CUSTOM_VISION_LLM_MODEL", LLM_MODEL) + INPUT_MODE = _normalize_input_mode(os.getenv("CUSTOM_AGENT_INPUT_MODE")) if not LLM_API_KEY: raise RuntimeError(f"CUSTOM_LLM_API_KEY is not set in {CUSTOM_ENV_PATH}") @@ -339,14 +542,29 @@ async def entrypoint(ctx: JobContext) -> None: http_client=http_client, ) + base_llm = openai.LLM( + model=LLM_MODEL, + client=openai_client, + ) + text_llm = ( + openai.LLM(model=TEXT_LLM_MODEL, client=openai_client) + if TEXT_LLM_MODEL != LLM_MODEL + else base_llm + ) + vision_llm = ( + openai.LLM(model=VISION_LLM_MODEL, client=openai_client) + if VISION_LLM_MODEL != LLM_MODEL + else base_llm + ) + vision_store = VisionFrameStore( + max_age_seconds=_env_float("CUSTOM_VISION_FRAME_MAX_AGE_SECONDS", 8.0) + ) + session: AgentSession = AgentSession( # 1. Custom ASR blackbox with StreamAdapter stt=stt_stream, # 2. OpenAI-compatible LLM, e.g. MiniMax, Qwen, or OpenAI. - llm=openai.LLM( - model=LLM_MODEL, - client=openai_client, - ), + llm=base_llm, # 3. TTS blackbox tts=BlackboxTTS( url=TTS_URL, @@ -388,6 +606,47 @@ async def entrypoint(ctx: JobContext) -> None: elif item.role == "assistant" and item.metrics: logger.info("Assistant turn metrics: %s", item.metrics) + @ctx.room.on("data_received") + def _on_data_received(data_packet) -> None: + packet_topic = getattr(data_packet, "topic", None) + if packet_topic not in {None, "", VISION_FRAME_TOPIC}: + return + + if INPUT_MODE == VOICE_INPUT_MODE: + logger.info("Ignoring vision frame because CUSTOM_AGENT_INPUT_MODE=%s", INPUT_MODE) + return + + try: + payload = json.loads(data_packet.data.decode("utf-8")) + except Exception: + logger.exception("Failed to decode vision frame payload") + return + + if payload.get("type") != "vision_frame" and payload.get("topic") != VISION_FRAME_TOPIC: + return + + image = payload.get("image") + if not isinstance(image, str) or not image: + logger.warning("Received vision frame without image data") + return + + mime_type = payload.get("mime_type") + if not isinstance(mime_type, str) or not mime_type: + mime_type = "image/jpeg" + + saved_path = payload.get("saved_path") + vision_store.update( + image=image, + mime_type=mime_type, + saved_path=saved_path if isinstance(saved_path, str) else None, + ) + logger.info( + "Cached vision frame: mime_type=%s image_chars=%s saved_path=%s", + mime_type, + len(image), + saved_path, + ) + memory_client = ( MemoryRecallClient( url=MEMORY_URL, @@ -400,7 +659,14 @@ async def entrypoint(ctx: JobContext) -> None: ) await session.start( - agent=CustomAgent(memory_client=memory_client), + agent=CustomAgent( + memory_client=memory_client, + vision_store=vision_store, + input_mode=INPUT_MODE, + text_llm=text_llm, + vision_llm=vision_llm, + model_image_save_dir=_model_image_save_dir_from_env(), + ), room=ctx.room, room_options=room_io.RoomOptions( audio_output=room_io.AudioOutputOptions(