feat: support vlm chat

This commit is contained in:
0Xiao0
2026-05-25 17:17:38 +08:00
parent 2064db15dc
commit e097323176
3 changed files with 282 additions and 9 deletions

View File

@ -23,6 +23,10 @@ CUSTOM_LLM_BASE_URL=http://localhost/v1
CUSTOM_LLM_MODEL=Qwen-VL
CUSTOM_LLM_API_KEY=
CUSTOM_LLM_VERIFY_SSL=false
CUSTOM_SAVE_MODEL_IMAGES=false
# CUSTOM_TEXT_LLM_MODEL=
# CUSTOM_VISION_LLM_MODEL=
# CUSTOM_LLM_BASE_URL=https://api.deepseek.com
# CUSTOM_LLM_MODEL=deepseek-v4-flash
@ -31,7 +35,7 @@ CUSTOM_LLM_VERIFY_SSL=false
# TTS blackbox
CUSTOM_TTS_URL=http://localhost:5000/tts-blackbox
CUSTOM_TTS_URL=http://localhost:5050/tts-blackbox
CUSTOM_TTS_MODEL=voxcpmtts
# CUSTOM_TTS_PROMPT_WAV=/home/verachen/Workspace/livekit/agents/2food.wav
CUSTOM_TTS_STREAMING=true

3
.gitignore vendored Normal file
View File

@ -0,0 +1,3 @@
__pycache__/
.env
model_images/

View File

@ -1,7 +1,10 @@
import base64
import json
import logging
import os
import time
from collections.abc import AsyncIterable
from dataclasses import dataclass
from pathlib import Path
from dotenv import load_dotenv
@ -56,12 +59,65 @@ GENERAL_INSTRUCTIONS = """
ROOM_LOCATOR_MODE = "room_locator"
GENERAL_MODE = "general"
VOICE_INPUT_MODE = "voice"
VISION_VOICE_INPUT_MODE = "vision_voice"
AUTO_INPUT_MODE = "auto"
VISION_FRAME_TOPIC = "vision.frame"
@dataclass
class VisionFrame:
image_data_url: str
received_at: float
mime_type: str
saved_path: str | None = None
class VisionFrameStore:
def __init__(self, *, max_age_seconds: float) -> None:
self._max_age_seconds = max_age_seconds
self._latest_frame: VisionFrame | None = None
def update(self, *, image: str, mime_type: str, saved_path: str | None = None) -> None:
self._latest_frame = VisionFrame(
image_data_url=f"data:{mime_type};base64,{image}",
received_at=time.monotonic(),
mime_type=mime_type,
saved_path=saved_path,
)
def consume_fresh(self) -> VisionFrame | None:
frame = self._latest_frame
if frame is None:
return None
age = time.monotonic() - frame.received_at
self._latest_frame = None
if age > self._max_age_seconds:
logger.info("Dropping stale vision frame: age=%.3fs", age)
return None
return frame
class CustomAgent(Agent):
def __init__(self, *, memory_client: MemoryRecallClient | None = None) -> None:
def __init__(
self,
*,
memory_client: MemoryRecallClient | None = None,
vision_store: VisionFrameStore | None = None,
input_mode: str = AUTO_INPUT_MODE,
text_llm: llm.LLM | None = None,
vision_llm: llm.LLM | None = None,
model_image_save_dir: Path | None = None,
) -> None:
super().__init__(instructions=GENERAL_INSTRUCTIONS)
self._memory_client = memory_client
self._vision_store = vision_store
self._input_mode = input_mode
self._text_llm = text_llm
self._vision_llm = vision_llm
self._model_image_save_dir = model_image_save_dir
async def on_enter(self) -> None:
# self.session.generate_reply(instructions="greet the user and introduce yourself")
@ -77,7 +133,13 @@ class CustomAgent(Agent):
user_query = _latest_user_text(chat_ctx)
mode = _select_mode(user_query)
logger.info("Selected agent mode: %s", mode)
vision_frame = self._consume_vision_frame()
logger.info(
"Selected agent mode: %s input_mode=%s has_image=%s",
mode,
self._input_mode,
vision_frame is not None,
)
chat_ctx = chat_ctx.copy()
update_chat_instructions(
@ -93,7 +155,16 @@ class CustomAgent(Agent):
if memory_context:
chat_ctx = _with_memory_as_latest_user_message(chat_ctx, memory_context)
llm_result = Agent.default.llm_node(self, chat_ctx, tools, model_settings)
if vision_frame is not None:
self._save_model_vision_frame(vision_frame)
chat_ctx = _with_vision_as_latest_user_message(chat_ctx, vision_frame)
llm_result = self._run_selected_llm(
chat_ctx,
tools,
model_settings,
has_image=vision_frame is not None,
)
if not hasattr(llm_result, "__aiter__"):
elapsed = time.perf_counter() - llm_node_started_at
logger.info("LLM node completed without streaming in %.3fs", elapsed)
@ -123,6 +194,68 @@ class CustomAgent(Agent):
return _instrumented_stream()
def _consume_vision_frame(self) -> VisionFrame | None:
if self._input_mode == VOICE_INPUT_MODE or self._vision_store is None:
return None
return self._vision_store.consume_fresh()
def _save_model_vision_frame(self, vision_frame: VisionFrame) -> None:
if self._model_image_save_dir is None:
return
try:
_, b64_data = vision_frame.image_data_url.split(",", 1)
image_bytes = base64.b64decode(b64_data, validate=True)
except Exception:
logger.exception("Failed to decode model vision frame for debug save")
return
extension = _image_extension_from_mime_type(vision_frame.mime_type)
timestamp_ms = int(time.time() * 1000)
path = self._model_image_save_dir / f"{timestamp_ms}_model_input{extension}"
try:
self._model_image_save_dir.mkdir(parents=True, exist_ok=True)
path.write_bytes(image_bytes)
except Exception:
logger.exception("Failed to save model vision frame: path=%s", path)
return
logger.info(
"Saved model vision frame: path=%s bytes=%s source_path=%s",
path,
len(image_bytes),
vision_frame.saved_path,
)
def _run_selected_llm(
self,
chat_ctx: ChatContext,
tools: list[llm.Tool],
model_settings: ModelSettings,
*,
has_image: bool,
) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
selected_llm = self._vision_llm if has_image else self._text_llm
if selected_llm is None:
return Agent.default.llm_node(self, chat_ctx, tools, model_settings)
activity = self._get_activity_or_raise()
tool_choice = model_settings.tool_choice
conn_options = activity.session.conn_options.llm_conn_options
async def _stream() -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
async with selected_llm.chat(
chat_ctx=chat_ctx,
tools=tools,
tool_choice=tool_choice,
conn_options=conn_options,
) as stream:
async for chunk in stream:
yield chunk
return _stream()
async def _recall_room_memory(self, chat_ctx: ChatContext) -> str:
if self._memory_client is None:
return ""
@ -269,6 +402,73 @@ def _with_memory_as_latest_user_message(chat_ctx: ChatContext, memory_context: s
return chat_ctx
def _with_vision_as_latest_user_message(chat_ctx: ChatContext, vision_frame: VisionFrame) -> ChatContext:
chat_ctx = chat_ctx.copy()
image_content = llm.ImageContent(
image=vision_frame.image_data_url,
mime_type=vision_frame.mime_type,
inference_detail="auto",
)
for index in range(len(chat_ctx.items) - 1, -1, -1):
item = chat_ctx.items[index]
if isinstance(item, ChatMessage) and item.role == "user":
user_msg = item.model_copy(deep=True)
content = list(user_msg.content)
content.append(image_content)
user_msg.content = content
chat_ctx.items[index] = user_msg
return chat_ctx
chat_ctx.items.append(ChatMessage(role="user", content=[image_content]))
return chat_ctx
def _normalize_input_mode(value: str | None) -> str:
if not value:
return AUTO_INPUT_MODE
normalized = value.strip().lower().replace("-", "_")
aliases = {
"image_voice": VISION_VOICE_INPUT_MODE,
"image": VISION_VOICE_INPUT_MODE,
"vision": VISION_VOICE_INPUT_MODE,
"vision_voice": VISION_VOICE_INPUT_MODE,
"voice_image": VISION_VOICE_INPUT_MODE,
"audio": VOICE_INPUT_MODE,
"voice": VOICE_INPUT_MODE,
"auto": AUTO_INPUT_MODE,
}
mode = aliases.get(normalized)
if mode is not None:
return mode
logger.warning("Invalid CUSTOM_AGENT_INPUT_MODE=%r, using %s", value, AUTO_INPUT_MODE)
return AUTO_INPUT_MODE
def _image_extension_from_mime_type(mime_type: str) -> str:
normalized = mime_type.strip().lower()
if normalized == "image/png":
return ".png"
if normalized == "image/webp":
return ".webp"
if normalized == "image/gif":
return ".gif"
return ".jpg"
def _model_image_save_dir_from_env() -> Path | None:
if not _env_bool("CUSTOM_SAVE_MODEL_IMAGES", True):
return None
configured = os.getenv("CUSTOM_MODEL_IMAGE_SAVE_DIR")
if configured:
return Path(configured).expanduser()
return Path(__file__).with_name("model_images")
server = AgentServer()
@ -295,6 +495,9 @@ async def entrypoint(ctx: JobContext) -> None:
LLM_BASE_URL = os.getenv("CUSTOM_LLM_BASE_URL")
LLM_MODEL = os.getenv("CUSTOM_LLM_MODEL", "qwen-max")
LLM_API_KEY = os.getenv("CUSTOM_LLM_API_KEY")
TEXT_LLM_MODEL = os.getenv("CUSTOM_TEXT_LLM_MODEL", LLM_MODEL)
VISION_LLM_MODEL = os.getenv("CUSTOM_VISION_LLM_MODEL", LLM_MODEL)
INPUT_MODE = _normalize_input_mode(os.getenv("CUSTOM_AGENT_INPUT_MODE"))
if not LLM_API_KEY:
raise RuntimeError(f"CUSTOM_LLM_API_KEY is not set in {CUSTOM_ENV_PATH}")
@ -339,14 +542,29 @@ async def entrypoint(ctx: JobContext) -> None:
http_client=http_client,
)
base_llm = openai.LLM(
model=LLM_MODEL,
client=openai_client,
)
text_llm = (
openai.LLM(model=TEXT_LLM_MODEL, client=openai_client)
if TEXT_LLM_MODEL != LLM_MODEL
else base_llm
)
vision_llm = (
openai.LLM(model=VISION_LLM_MODEL, client=openai_client)
if VISION_LLM_MODEL != LLM_MODEL
else base_llm
)
vision_store = VisionFrameStore(
max_age_seconds=_env_float("CUSTOM_VISION_FRAME_MAX_AGE_SECONDS", 8.0)
)
session: AgentSession = AgentSession(
# 1. Custom ASR blackbox with StreamAdapter
stt=stt_stream,
# 2. OpenAI-compatible LLM, e.g. MiniMax, Qwen, or OpenAI.
llm=openai.LLM(
model=LLM_MODEL,
client=openai_client,
),
llm=base_llm,
# 3. TTS blackbox
tts=BlackboxTTS(
url=TTS_URL,
@ -388,6 +606,47 @@ async def entrypoint(ctx: JobContext) -> None:
elif item.role == "assistant" and item.metrics:
logger.info("Assistant turn metrics: %s", item.metrics)
@ctx.room.on("data_received")
def _on_data_received(data_packet) -> None:
packet_topic = getattr(data_packet, "topic", None)
if packet_topic not in {None, "", VISION_FRAME_TOPIC}:
return
if INPUT_MODE == VOICE_INPUT_MODE:
logger.info("Ignoring vision frame because CUSTOM_AGENT_INPUT_MODE=%s", INPUT_MODE)
return
try:
payload = json.loads(data_packet.data.decode("utf-8"))
except Exception:
logger.exception("Failed to decode vision frame payload")
return
if payload.get("type") != "vision_frame" and payload.get("topic") != VISION_FRAME_TOPIC:
return
image = payload.get("image")
if not isinstance(image, str) or not image:
logger.warning("Received vision frame without image data")
return
mime_type = payload.get("mime_type")
if not isinstance(mime_type, str) or not mime_type:
mime_type = "image/jpeg"
saved_path = payload.get("saved_path")
vision_store.update(
image=image,
mime_type=mime_type,
saved_path=saved_path if isinstance(saved_path, str) else None,
)
logger.info(
"Cached vision frame: mime_type=%s image_chars=%s saved_path=%s",
mime_type,
len(image),
saved_path,
)
memory_client = (
MemoryRecallClient(
url=MEMORY_URL,
@ -400,7 +659,14 @@ async def entrypoint(ctx: JobContext) -> None:
)
await session.start(
agent=CustomAgent(memory_client=memory_client),
agent=CustomAgent(
memory_client=memory_client,
vision_store=vision_store,
input_mode=INPUT_MODE,
text_llm=text_llm,
vision_llm=vision_llm,
model_image_save_dir=_model_image_save_dir_from_env(),
),
room=ctx.room,
room_options=room_io.RoomOptions(
audio_output=room_io.AudioOutputOptions(