feat: support vlm chat

This commit is contained in:
0Xiao0
2026-05-25 17:17:38 +08:00
parent 2064db15dc
commit e097323176
3 changed files with 282 additions and 9 deletions

View File

@ -23,6 +23,10 @@ CUSTOM_LLM_BASE_URL=http://localhost/v1
CUSTOM_LLM_MODEL=Qwen-VL CUSTOM_LLM_MODEL=Qwen-VL
CUSTOM_LLM_API_KEY= CUSTOM_LLM_API_KEY=
CUSTOM_LLM_VERIFY_SSL=false CUSTOM_LLM_VERIFY_SSL=false
CUSTOM_SAVE_MODEL_IMAGES=false
# CUSTOM_TEXT_LLM_MODEL=
# CUSTOM_VISION_LLM_MODEL=
# CUSTOM_LLM_BASE_URL=https://api.deepseek.com # CUSTOM_LLM_BASE_URL=https://api.deepseek.com
# CUSTOM_LLM_MODEL=deepseek-v4-flash # CUSTOM_LLM_MODEL=deepseek-v4-flash
@ -31,7 +35,7 @@ CUSTOM_LLM_VERIFY_SSL=false
# TTS blackbox # TTS blackbox
CUSTOM_TTS_URL=http://localhost:5000/tts-blackbox CUSTOM_TTS_URL=http://localhost:5050/tts-blackbox
CUSTOM_TTS_MODEL=voxcpmtts CUSTOM_TTS_MODEL=voxcpmtts
# CUSTOM_TTS_PROMPT_WAV=/home/verachen/Workspace/livekit/agents/2food.wav # CUSTOM_TTS_PROMPT_WAV=/home/verachen/Workspace/livekit/agents/2food.wav
CUSTOM_TTS_STREAMING=true CUSTOM_TTS_STREAMING=true

3
.gitignore vendored Normal file
View File

@ -0,0 +1,3 @@
__pycache__/
.env
model_images/

View File

@ -1,7 +1,10 @@
import base64
import json
import logging import logging
import os import os
import time import time
from collections.abc import AsyncIterable from collections.abc import AsyncIterable
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from dotenv import load_dotenv from dotenv import load_dotenv
@ -56,12 +59,65 @@ GENERAL_INSTRUCTIONS = """
ROOM_LOCATOR_MODE = "room_locator" ROOM_LOCATOR_MODE = "room_locator"
GENERAL_MODE = "general" GENERAL_MODE = "general"
VOICE_INPUT_MODE = "voice"
VISION_VOICE_INPUT_MODE = "vision_voice"
AUTO_INPUT_MODE = "auto"
VISION_FRAME_TOPIC = "vision.frame"
@dataclass
class VisionFrame:
image_data_url: str
received_at: float
mime_type: str
saved_path: str | None = None
class VisionFrameStore:
def __init__(self, *, max_age_seconds: float) -> None:
self._max_age_seconds = max_age_seconds
self._latest_frame: VisionFrame | None = None
def update(self, *, image: str, mime_type: str, saved_path: str | None = None) -> None:
self._latest_frame = VisionFrame(
image_data_url=f"data:{mime_type};base64,{image}",
received_at=time.monotonic(),
mime_type=mime_type,
saved_path=saved_path,
)
def consume_fresh(self) -> VisionFrame | None:
frame = self._latest_frame
if frame is None:
return None
age = time.monotonic() - frame.received_at
self._latest_frame = None
if age > self._max_age_seconds:
logger.info("Dropping stale vision frame: age=%.3fs", age)
return None
return frame
class CustomAgent(Agent): class CustomAgent(Agent):
def __init__(self, *, memory_client: MemoryRecallClient | None = None) -> None: def __init__(
self,
*,
memory_client: MemoryRecallClient | None = None,
vision_store: VisionFrameStore | None = None,
input_mode: str = AUTO_INPUT_MODE,
text_llm: llm.LLM | None = None,
vision_llm: llm.LLM | None = None,
model_image_save_dir: Path | None = None,
) -> None:
super().__init__(instructions=GENERAL_INSTRUCTIONS) super().__init__(instructions=GENERAL_INSTRUCTIONS)
self._memory_client = memory_client self._memory_client = memory_client
self._vision_store = vision_store
self._input_mode = input_mode
self._text_llm = text_llm
self._vision_llm = vision_llm
self._model_image_save_dir = model_image_save_dir
async def on_enter(self) -> None: async def on_enter(self) -> None:
# self.session.generate_reply(instructions="greet the user and introduce yourself") # self.session.generate_reply(instructions="greet the user and introduce yourself")
@ -77,7 +133,13 @@ class CustomAgent(Agent):
user_query = _latest_user_text(chat_ctx) user_query = _latest_user_text(chat_ctx)
mode = _select_mode(user_query) mode = _select_mode(user_query)
logger.info("Selected agent mode: %s", mode) vision_frame = self._consume_vision_frame()
logger.info(
"Selected agent mode: %s input_mode=%s has_image=%s",
mode,
self._input_mode,
vision_frame is not None,
)
chat_ctx = chat_ctx.copy() chat_ctx = chat_ctx.copy()
update_chat_instructions( update_chat_instructions(
@ -93,7 +155,16 @@ class CustomAgent(Agent):
if memory_context: if memory_context:
chat_ctx = _with_memory_as_latest_user_message(chat_ctx, memory_context) chat_ctx = _with_memory_as_latest_user_message(chat_ctx, memory_context)
llm_result = Agent.default.llm_node(self, chat_ctx, tools, model_settings) if vision_frame is not None:
self._save_model_vision_frame(vision_frame)
chat_ctx = _with_vision_as_latest_user_message(chat_ctx, vision_frame)
llm_result = self._run_selected_llm(
chat_ctx,
tools,
model_settings,
has_image=vision_frame is not None,
)
if not hasattr(llm_result, "__aiter__"): if not hasattr(llm_result, "__aiter__"):
elapsed = time.perf_counter() - llm_node_started_at elapsed = time.perf_counter() - llm_node_started_at
logger.info("LLM node completed without streaming in %.3fs", elapsed) logger.info("LLM node completed without streaming in %.3fs", elapsed)
@ -123,6 +194,68 @@ class CustomAgent(Agent):
return _instrumented_stream() return _instrumented_stream()
def _consume_vision_frame(self) -> VisionFrame | None:
if self._input_mode == VOICE_INPUT_MODE or self._vision_store is None:
return None
return self._vision_store.consume_fresh()
def _save_model_vision_frame(self, vision_frame: VisionFrame) -> None:
if self._model_image_save_dir is None:
return
try:
_, b64_data = vision_frame.image_data_url.split(",", 1)
image_bytes = base64.b64decode(b64_data, validate=True)
except Exception:
logger.exception("Failed to decode model vision frame for debug save")
return
extension = _image_extension_from_mime_type(vision_frame.mime_type)
timestamp_ms = int(time.time() * 1000)
path = self._model_image_save_dir / f"{timestamp_ms}_model_input{extension}"
try:
self._model_image_save_dir.mkdir(parents=True, exist_ok=True)
path.write_bytes(image_bytes)
except Exception:
logger.exception("Failed to save model vision frame: path=%s", path)
return
logger.info(
"Saved model vision frame: path=%s bytes=%s source_path=%s",
path,
len(image_bytes),
vision_frame.saved_path,
)
def _run_selected_llm(
self,
chat_ctx: ChatContext,
tools: list[llm.Tool],
model_settings: ModelSettings,
*,
has_image: bool,
) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
selected_llm = self._vision_llm if has_image else self._text_llm
if selected_llm is None:
return Agent.default.llm_node(self, chat_ctx, tools, model_settings)
activity = self._get_activity_or_raise()
tool_choice = model_settings.tool_choice
conn_options = activity.session.conn_options.llm_conn_options
async def _stream() -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
async with selected_llm.chat(
chat_ctx=chat_ctx,
tools=tools,
tool_choice=tool_choice,
conn_options=conn_options,
) as stream:
async for chunk in stream:
yield chunk
return _stream()
async def _recall_room_memory(self, chat_ctx: ChatContext) -> str: async def _recall_room_memory(self, chat_ctx: ChatContext) -> str:
if self._memory_client is None: if self._memory_client is None:
return "" return ""
@ -269,6 +402,73 @@ def _with_memory_as_latest_user_message(chat_ctx: ChatContext, memory_context: s
return chat_ctx return chat_ctx
def _with_vision_as_latest_user_message(chat_ctx: ChatContext, vision_frame: VisionFrame) -> ChatContext:
chat_ctx = chat_ctx.copy()
image_content = llm.ImageContent(
image=vision_frame.image_data_url,
mime_type=vision_frame.mime_type,
inference_detail="auto",
)
for index in range(len(chat_ctx.items) - 1, -1, -1):
item = chat_ctx.items[index]
if isinstance(item, ChatMessage) and item.role == "user":
user_msg = item.model_copy(deep=True)
content = list(user_msg.content)
content.append(image_content)
user_msg.content = content
chat_ctx.items[index] = user_msg
return chat_ctx
chat_ctx.items.append(ChatMessage(role="user", content=[image_content]))
return chat_ctx
def _normalize_input_mode(value: str | None) -> str:
if not value:
return AUTO_INPUT_MODE
normalized = value.strip().lower().replace("-", "_")
aliases = {
"image_voice": VISION_VOICE_INPUT_MODE,
"image": VISION_VOICE_INPUT_MODE,
"vision": VISION_VOICE_INPUT_MODE,
"vision_voice": VISION_VOICE_INPUT_MODE,
"voice_image": VISION_VOICE_INPUT_MODE,
"audio": VOICE_INPUT_MODE,
"voice": VOICE_INPUT_MODE,
"auto": AUTO_INPUT_MODE,
}
mode = aliases.get(normalized)
if mode is not None:
return mode
logger.warning("Invalid CUSTOM_AGENT_INPUT_MODE=%r, using %s", value, AUTO_INPUT_MODE)
return AUTO_INPUT_MODE
def _image_extension_from_mime_type(mime_type: str) -> str:
normalized = mime_type.strip().lower()
if normalized == "image/png":
return ".png"
if normalized == "image/webp":
return ".webp"
if normalized == "image/gif":
return ".gif"
return ".jpg"
def _model_image_save_dir_from_env() -> Path | None:
if not _env_bool("CUSTOM_SAVE_MODEL_IMAGES", True):
return None
configured = os.getenv("CUSTOM_MODEL_IMAGE_SAVE_DIR")
if configured:
return Path(configured).expanduser()
return Path(__file__).with_name("model_images")
server = AgentServer() server = AgentServer()
@ -295,6 +495,9 @@ async def entrypoint(ctx: JobContext) -> None:
LLM_BASE_URL = os.getenv("CUSTOM_LLM_BASE_URL") LLM_BASE_URL = os.getenv("CUSTOM_LLM_BASE_URL")
LLM_MODEL = os.getenv("CUSTOM_LLM_MODEL", "qwen-max") LLM_MODEL = os.getenv("CUSTOM_LLM_MODEL", "qwen-max")
LLM_API_KEY = os.getenv("CUSTOM_LLM_API_KEY") LLM_API_KEY = os.getenv("CUSTOM_LLM_API_KEY")
TEXT_LLM_MODEL = os.getenv("CUSTOM_TEXT_LLM_MODEL", LLM_MODEL)
VISION_LLM_MODEL = os.getenv("CUSTOM_VISION_LLM_MODEL", LLM_MODEL)
INPUT_MODE = _normalize_input_mode(os.getenv("CUSTOM_AGENT_INPUT_MODE"))
if not LLM_API_KEY: if not LLM_API_KEY:
raise RuntimeError(f"CUSTOM_LLM_API_KEY is not set in {CUSTOM_ENV_PATH}") raise RuntimeError(f"CUSTOM_LLM_API_KEY is not set in {CUSTOM_ENV_PATH}")
@ -339,14 +542,29 @@ async def entrypoint(ctx: JobContext) -> None:
http_client=http_client, http_client=http_client,
) )
base_llm = openai.LLM(
model=LLM_MODEL,
client=openai_client,
)
text_llm = (
openai.LLM(model=TEXT_LLM_MODEL, client=openai_client)
if TEXT_LLM_MODEL != LLM_MODEL
else base_llm
)
vision_llm = (
openai.LLM(model=VISION_LLM_MODEL, client=openai_client)
if VISION_LLM_MODEL != LLM_MODEL
else base_llm
)
vision_store = VisionFrameStore(
max_age_seconds=_env_float("CUSTOM_VISION_FRAME_MAX_AGE_SECONDS", 8.0)
)
session: AgentSession = AgentSession( session: AgentSession = AgentSession(
# 1. Custom ASR blackbox with StreamAdapter # 1. Custom ASR blackbox with StreamAdapter
stt=stt_stream, stt=stt_stream,
# 2. OpenAI-compatible LLM, e.g. MiniMax, Qwen, or OpenAI. # 2. OpenAI-compatible LLM, e.g. MiniMax, Qwen, or OpenAI.
llm=openai.LLM( llm=base_llm,
model=LLM_MODEL,
client=openai_client,
),
# 3. TTS blackbox # 3. TTS blackbox
tts=BlackboxTTS( tts=BlackboxTTS(
url=TTS_URL, url=TTS_URL,
@ -388,6 +606,47 @@ async def entrypoint(ctx: JobContext) -> None:
elif item.role == "assistant" and item.metrics: elif item.role == "assistant" and item.metrics:
logger.info("Assistant turn metrics: %s", item.metrics) logger.info("Assistant turn metrics: %s", item.metrics)
@ctx.room.on("data_received")
def _on_data_received(data_packet) -> None:
packet_topic = getattr(data_packet, "topic", None)
if packet_topic not in {None, "", VISION_FRAME_TOPIC}:
return
if INPUT_MODE == VOICE_INPUT_MODE:
logger.info("Ignoring vision frame because CUSTOM_AGENT_INPUT_MODE=%s", INPUT_MODE)
return
try:
payload = json.loads(data_packet.data.decode("utf-8"))
except Exception:
logger.exception("Failed to decode vision frame payload")
return
if payload.get("type") != "vision_frame" and payload.get("topic") != VISION_FRAME_TOPIC:
return
image = payload.get("image")
if not isinstance(image, str) or not image:
logger.warning("Received vision frame without image data")
return
mime_type = payload.get("mime_type")
if not isinstance(mime_type, str) or not mime_type:
mime_type = "image/jpeg"
saved_path = payload.get("saved_path")
vision_store.update(
image=image,
mime_type=mime_type,
saved_path=saved_path if isinstance(saved_path, str) else None,
)
logger.info(
"Cached vision frame: mime_type=%s image_chars=%s saved_path=%s",
mime_type,
len(image),
saved_path,
)
memory_client = ( memory_client = (
MemoryRecallClient( MemoryRecallClient(
url=MEMORY_URL, url=MEMORY_URL,
@ -400,7 +659,14 @@ async def entrypoint(ctx: JobContext) -> None:
) )
await session.start( await session.start(
agent=CustomAgent(memory_client=memory_client), agent=CustomAgent(
memory_client=memory_client,
vision_store=vision_store,
input_mode=INPUT_MODE,
text_llm=text_llm,
vision_llm=vision_llm,
model_image_save_dir=_model_image_save_dir_from_env(),
),
room=ctx.room, room=ctx.room,
room_options=room_io.RoomOptions( room_options=room_io.RoomOptions(
audio_output=room_io.AudioOutputOptions( audio_output=room_io.AudioOutputOptions(