Files
livekit_agents/custom_agent.py
2026-06-03 17:25:08 +08:00

1023 lines
34 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import base64
import json
import logging
import os
import re
import time
from collections.abc import AsyncIterable
from dataclasses import dataclass
from pathlib import Path
from beaver_llm import BeaverLLM
from dotenv import load_dotenv
from hermes_gateway import GatewaySessionState, HermesGatewayLLM
from memory import MemoryRecallClient
from tts import BlackboxTTS
from asr import BlackboxSTT
from livekit.agents import (
Agent,
AgentServer,
AgentSession,
ChatContext,
ChatMessage,
FlushSentinel,
JobContext,
JobProcess,
MetricsCollectedEvent,
ModelSettings,
RecordingOptions,
TurnHandlingOptions,
cli,
llm,
metrics,
room_io,
stt,
)
from livekit.agents.voice.generation import update_instructions as update_chat_instructions
from livekit.plugins import openai, silero
from livekit.plugins.turn_detector.multilingual import MultilingualModel
# Module-wide logger; every helper below logs through it.
logger = logging.getLogger("custom-agent")

# Configuration is read from a ``.env`` file that sits next to this module.
CUSTOM_ENV_PATH = Path(__file__).with_name(".env")
load_dotenv(dotenv_path=CUSTOM_ENV_PATH)

# Name passed to ``server.rtc_session``; defaults to an empty string.
AGENT_NAME = os.getenv("CUSTOM_AGENT_NAME", "")

# System prompt used when the turn is classified as a room-locator question.
ROOM_LOCATOR_INSTRUCTIONS = """
你是一个房间物品定位助手。
当用户询问房间内某个物品的位置时:
- 只用一句中文回答
- 描述目标物品和其他物品的相对位置关系
- 不要使用 Markdown、emoji、列表、标题、坐标区域标签
- 不要解释推理过程
如果用户的问题与房间物品定位无关,则正常回答用户问题。
""".strip()

# System prompt used for every other turn.
GENERAL_INSTRUCTIONS = """
你是一个智能语音助手。
正常回答用户问题。
回答自然、简洁、准确。
""".strip()

# Appended to whichever prompt is active so the model starts each reply with
# an ``<emotion=...>`` tag.
EMOTION_INSTRUCTIONS = """
每次回复必须先输出一个情绪标签,格式严格为:<emotion=neutral>
emotion 只能从 neutral、happy、sad、angry、surprised、fearful、calm、concerned 中选择。
情绪标签之后直接输出给用户的正常回复,不要解释标签。
""".strip()

# Prompt modes chosen per turn by ``_select_mode``.
ROOM_LOCATOR_MODE = "room_locator"
GENERAL_MODE = "general"

# Input modes accepted through CUSTOM_AGENT_INPUT_MODE (see ``_normalize_input_mode``).
VOICE_INPUT_MODE = "voice"
VISION_VOICE_INPUT_MODE = "vision_voice"
AUTO_INPUT_MODE = "auto"

# Data-channel topic on which vision frames are expected.
VISION_FRAME_TOPIC = "vision.frame"

# Emotion used when the model returns a label outside EMOTION_LABELS.
DEFAULT_EMOTION = "neutral"
EMOTION_LABELS = {
    "neutral",
    "happy",
    "sad",
    "angry",
    "surprised",
    "fearful",
    "calm",
    "concerned",
}

# Strict form of the leading tag; group 1 is the label.
EMOTION_PREFIX_RE = re.compile(r"^\s*<emotion=([a-z_]+)>\s*", re.IGNORECASE)
# Looser form (tolerates inner whitespace) used to remove tags before TTS.
TTS_EMOTION_MARKUP_RE = re.compile(r"<\s*emotion\s*=\s*[^>]{1,80}>\s*", re.IGNORECASE)
# Plain-text variant such as ``emotion: happy`` or ``情绪=...`` at the start of a reply.
TTS_EMOTION_LINE_RE = re.compile(
    r"^\s*(?:emotion|情绪)\s*[:=]\s*[\w\u4e00-\u9fff-]{1,40}\s*[,,。.!\s-]*",
    re.IGNORECASE,
)
# Upper bound on how many characters are buffered while waiting for a prefix to complete.
MAX_EMOTION_PREFIX_CHARS = 80
@dataclass
class VisionFrame:
    """A single image frame cached for attachment to the next LLM request."""

    # Image as a ``data:<mime_type>;base64,<payload>`` URL (built in VisionFrameStore.update).
    image_data_url: str
    # ``time.monotonic()`` reading taken when the frame was cached.
    received_at: float
    # MIME type reported for the image, e.g. "image/jpeg".
    mime_type: str
    # Optional path reported by the sender in the frame payload; only logged here.
    saved_path: str | None = None
class VisionFrameStore:
    """Single-slot holder for the most recent vision frame.

    A frame can be consumed at most once; frames older than
    ``max_age_seconds`` are discarded instead of being returned.
    """

    def __init__(self, *, max_age_seconds: float) -> None:
        self._latest_frame: VisionFrame | None = None
        self._max_age_seconds = max_age_seconds

    def update(self, *, image: str, mime_type: str, saved_path: str | None = None) -> None:
        """Replace the cached frame with a new base64-encoded image."""
        data_url = f"data:{mime_type};base64,{image}"
        stamped_at = time.monotonic()
        self._latest_frame = VisionFrame(
            image_data_url=data_url,
            received_at=stamped_at,
            mime_type=mime_type,
            saved_path=saved_path,
        )

    def consume_fresh(self) -> VisionFrame | None:
        """Take the cached frame out of the store, returning it only if still fresh."""
        # Swap the slot out first: the frame is gone whether or not it is returned.
        pending, self._latest_frame = self._latest_frame, None
        if pending is None:
            return None

        elapsed = time.monotonic() - pending.received_at
        if elapsed > self._max_age_seconds:
            logger.info("Dropping stale vision frame: age=%.3fs", elapsed)
            return None
        return pending
class CustomAgent(Agent):
    """Voice agent that selects a prompt per turn and tracks the reply emotion.

    For every LLM turn the agent:
      * classifies the latest user text as room-locator or general,
      * optionally swaps in recalled memory and attaches a cached vision frame,
      * routes the request to the text or vision LLM,
      * watches the streamed reply for a leading ``<emotion=...>`` tag and
        records it in ``current_emotion``.

    The tag itself is left in the LLM stream and removed only on the way to TTS.
    """

    def __init__(
        self,
        *,
        memory_client: MemoryRecallClient | None = None,
        vision_store: VisionFrameStore | None = None,
        input_mode: str = AUTO_INPUT_MODE,
        text_llm: llm.LLM | None = None,
        vision_llm: llm.LLM | None = None,
        model_image_save_dir: Path | None = None,
    ) -> None:
        """Store the collaborators; every argument is optional.

        Args:
            memory_client: Used to recall context for room-locator turns.
            vision_store: Source of cached camera frames.
            input_mode: One of the ``*_INPUT_MODE`` constants.
            text_llm: LLM for turns without an image; ``None`` falls back to
                the default LLM node.
            vision_llm: LLM for turns with an image; same fallback.
            model_image_save_dir: When set, frames sent to the model are also
                written to this directory.
        """
        super().__init__(instructions=_with_emotion_instructions(GENERAL_INSTRUCTIONS))
        self._memory_client = memory_client
        self._vision_store = vision_store
        self._input_mode = input_mode
        self._text_llm = text_llm
        self._vision_llm = vision_llm
        self._model_image_save_dir = model_image_save_dir
        self.current_emotion = DEFAULT_EMOTION
        # Prefix-scanning state; "done" until a stream starts so stray calls are no-ops.
        self._emotion_prefix_buffer = ""
        self._emotion_prefix_done = True

    async def on_enter(self) -> None:
        """Hook called when the agent becomes active; currently does nothing."""
        # self.session.generate_reply(instructions="greet the user and introduce yourself")
        pass

    async def llm_node(
        self,
        chat_ctx: ChatContext,
        tools: list[llm.Tool],
        model_settings: ModelSettings,
    ) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
        """Build the per-turn context, run the selected LLM and instrument its stream."""
        llm_node_started_at = time.perf_counter()
        user_query = _latest_user_text(chat_ctx)
        mode = _select_mode(user_query)
        vision_frame = self._consume_vision_frame()
        logger.info(
            "Selected agent mode: %s input_mode=%s has_image=%s",
            mode,
            self._input_mode,
            vision_frame is not None,
        )

        # Work on a copy so the per-turn prompt/memory/image edits do not leak
        # into the caller's chat context.
        chat_ctx = chat_ctx.copy()
        update_chat_instructions(
            chat_ctx,
            instructions=_with_emotion_instructions(
                ROOM_LOCATOR_INSTRUCTIONS if mode == ROOM_LOCATOR_MODE else GENERAL_INSTRUCTIONS
            ),
            add_if_missing=True,
        )
        if mode == ROOM_LOCATOR_MODE:
            memory_context = await self._recall_room_memory(chat_ctx)
            if memory_context:
                chat_ctx = _with_memory_as_latest_user_message(chat_ctx, memory_context)
        if vision_frame is not None:
            self._save_model_vision_frame(vision_frame)
            chat_ctx = _with_vision_as_latest_user_message(chat_ctx, vision_frame)

        llm_result = self._run_selected_llm(
            chat_ctx,
            tools,
            model_settings,
            has_image=vision_frame is not None,
        )
        if not hasattr(llm_result, "__aiter__"):
            elapsed = time.perf_counter() - llm_node_started_at
            logger.info("LLM node completed without streaming in %.3fs", elapsed)
            return llm_result

        async def _instrumented_stream() -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
            first_chunk_at: float | None = None
            chunk_count = 0
            # Start a fresh prefix scan for this reply.
            self._emotion_prefix_buffer = ""
            self._emotion_prefix_done = False
            try:
                async for chunk in llm_result:
                    chunk_count += 1
                    if first_chunk_at is None:
                        first_chunk_at = time.perf_counter()
                        logger.info(
                            "LLM first chunk after %.3fs",
                            first_chunk_at - llm_node_started_at,
                        )
                    async for output_chunk in self._observe_emotion_prefix(chunk):
                        yield output_chunk
            finally:
                finished_at = time.perf_counter()
                # Compare against None explicitly: the sentinel is "no chunk seen",
                # not "timestamp is falsy".
                first_chunk_latency = (
                    first_chunk_at - llm_node_started_at if first_chunk_at is not None else -1.0
                )
                logger.info(
                    "LLM stream completed in %.3fs (first_chunk=%.3fs, chunks=%s)",
                    finished_at - llm_node_started_at,
                    first_chunk_latency,
                    chunk_count,
                )

        return _instrumented_stream()

    def tts_node(self, text: AsyncIterable[str], model_settings: ModelSettings):
        """Run the default TTS node on the text with emotion markup removed."""
        return Agent.default.tts_node(self, _strip_emotion_for_tts(text), model_settings)

    def _consume_vision_frame(self) -> VisionFrame | None:
        """Return a fresh cached frame, or ``None`` in voice-only mode / without a store."""
        if self._input_mode == VOICE_INPUT_MODE or self._vision_store is None:
            return None
        return self._vision_store.consume_fresh()

    def _save_model_vision_frame(self, vision_frame: VisionFrame) -> None:
        """Best-effort dump of the image sent to the model; failures are only logged."""
        if self._model_image_save_dir is None:
            return
        try:
            _, b64_data = vision_frame.image_data_url.split(",", 1)
            image_bytes = base64.b64decode(b64_data, validate=True)
        except Exception:
            logger.exception("Failed to decode model vision frame for debug save")
            return
        extension = _image_extension_from_mime_type(vision_frame.mime_type)
        timestamp_ms = int(time.time() * 1000)
        path = self._model_image_save_dir / f"{timestamp_ms}_model_input{extension}"
        try:
            self._model_image_save_dir.mkdir(parents=True, exist_ok=True)
            path.write_bytes(image_bytes)
        except Exception:
            logger.exception("Failed to save model vision frame: path=%s", path)
            return
        logger.info(
            "Saved model vision frame: path=%s bytes=%s source_path=%s",
            path,
            len(image_bytes),
            vision_frame.saved_path,
        )

    def _run_selected_llm(
        self,
        chat_ctx: ChatContext,
        tools: list[llm.Tool],
        model_settings: ModelSettings,
        *,
        has_image: bool,
    ) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
        """Stream from the vision LLM when an image is attached, else the text LLM.

        Falls back to the default LLM node when the selected LLM is ``None``.
        """
        selected_llm = self._vision_llm if has_image else self._text_llm
        if selected_llm is None:
            return Agent.default.llm_node(self, chat_ctx, tools, model_settings)
        activity = self._get_activity_or_raise()
        tool_choice = model_settings.tool_choice
        conn_options = activity.session.conn_options.llm_conn_options

        async def _stream() -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
            async with selected_llm.chat(
                chat_ctx=chat_ctx,
                tools=tools,
                tool_choice=tool_choice,
                conn_options=conn_options,
            ) as stream:
                async for chunk in stream:
                    yield chunk

        return _stream()

    async def _observe_emotion_prefix(
        self, chunk: llm.ChatChunk | str | FlushSentinel
    ) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
        """Feed any text in ``chunk`` to the prefix scanner, then pass the chunk through unchanged."""
        if isinstance(chunk, str):
            self._consume_emotion_prefix(chunk)
            yield chunk
            return
        if isinstance(chunk, llm.ChatChunk) and chunk.delta and chunk.delta.content:
            self._consume_emotion_prefix(chunk.delta.content)
            yield chunk
            return
        yield chunk

    def _consume_emotion_prefix(self, content: str) -> None:
        """Accumulate streamed text until the leading emotion tag is resolved.

        Sets ``current_emotion`` to the parsed label, or to ``DEFAULT_EMOTION``
        when the label is unsupported or the reply carries no tag at all.
        """
        if self._emotion_prefix_done:
            return
        self._emotion_prefix_buffer += content
        match = EMOTION_PREFIX_RE.match(self._emotion_prefix_buffer)
        if match:
            emotion = match.group(1).lower()
            if emotion not in EMOTION_LABELS:
                logger.warning("LLM returned unsupported emotion=%s, using neutral", emotion)
                emotion = DEFAULT_EMOTION
            self.current_emotion = emotion
            self._emotion_prefix_done = True
            self._emotion_prefix_buffer = ""
            logger.info("LLM emotion selected: %s", emotion)
            return

        # No complete tag yet: keep buffering only while the text could still
        # turn into one (empty, a prefix of "<emotion=", or an unterminated tag).
        candidate = self._emotion_prefix_buffer.lstrip().lower()
        might_still_be_prefix = (
            not candidate
            or "<emotion=".startswith(candidate)
            or (candidate.startswith("<emotion=") and ">" not in candidate)
        )
        if might_still_be_prefix and len(candidate) <= MAX_EMOTION_PREFIX_CHARS:
            return

        # The reply has no tag. Fall back to the default instead of silently
        # keeping the previous turn's emotion.
        self.current_emotion = DEFAULT_EMOTION
        self._emotion_prefix_done = True
        self._emotion_prefix_buffer = ""
        logger.warning("LLM response did not start with an emotion prefix")

    async def _recall_room_memory(self, chat_ctx: ChatContext) -> str:
        """Recall memory for the latest user text; returns "" when unavailable or on failure."""
        if self._memory_client is None:
            return ""
        user_query = _latest_user_text(chat_ctx)
        if not user_query:
            return ""
        started_at = time.perf_counter()
        try:
            recalled = await self._memory_client.recall(user_query)
            elapsed = time.perf_counter() - started_at
            logger.info(
                "Memory recall completed in %.3fs (query_len=%s, memory_len=%s)",
                elapsed,
                len(user_query),
                len(recalled),
            )
            return recalled
        except Exception:
            # Memory is optional context; a failure must not break the turn.
            logger.exception(
                "Unexpected memory recall failure after %.3fs",
                time.perf_counter() - started_at,
            )
            return ""
def _select_mode(user_query: str) -> str:
    """Pick the prompt mode for a user utterance (general unless it looks like a locator query)."""
    compact = _normalize_text(user_query)
    # Empty text short-circuits, so the classifier only ever sees real content.
    wants_locator = bool(compact) and _is_room_locator_query(compact)
    return ROOM_LOCATOR_MODE if wants_locator else GENERAL_MODE
def _with_emotion_instructions(instructions: str) -> str:
    """Append the emotion-tag instructions to a prompt, separated by a blank line."""
    return "\n\n".join((instructions, EMOTION_INSTRUCTIONS))
async def _strip_emotion_for_tts(text: AsyncIterable[str]) -> AsyncIterable[str]:
    """Yield the streamed text with emotion markup removed.

    The start of the stream is buffered until it is clear whether it begins
    with an emotion prefix; after that each chunk is cleaned on its own.
    """
    pending = ""
    prefix_resolved = False

    async for piece in text:
        if not piece:
            continue

        if prefix_resolved:
            spoken = _strip_inline_tts_emotion(piece)
            if spoken:
                yield spoken
            continue

        pending += piece
        remainder, resolved = _strip_leading_tts_emotion(pending)
        if resolved:
            prefix_resolved = True
            pending = ""
            if remainder:
                yield _strip_inline_tts_emotion(remainder)

    # Stream ended while still undecided: force a decision on what was buffered.
    if not prefix_resolved and pending:
        remainder, _ = _strip_leading_tts_emotion(pending, force=True)
        remainder = _strip_inline_tts_emotion(remainder)
        if remainder:
            yield remainder
def _strip_leading_tts_emotion(text: str, *, force: bool = False) -> tuple[str, bool]:
    """Remove a leading emotion prefix from the buffered start of a reply.

    Args:
        text: Everything buffered so far from the start of the stream.
        force: Decide now (used at end of stream) instead of waiting for more text.

    Returns:
        ``(remaining_text, done)``. ``done`` is ``False`` when more input is
        needed before the prefix can be resolved; ``remaining_text`` is then "".
    """
    match = TTS_EMOTION_MARKUP_RE.match(text)
    if match:
        # The tag form is closed by ">", so a match is always complete.
        return text[match.end() :], True

    match = TTS_EMOTION_LINE_RE.match(text)
    if match:
        # The "emotion: label" form has no terminator. If the match runs to the
        # end of the buffer, the label may be only partly streamed
        # ("emotion: ha" + "ppy ..."); accepting it now would send the rest of
        # the label to TTS. Wait for more text unless forced or over the limit.
        reaches_buffer_end = match.end() == len(text)
        if force or not reaches_buffer_end or len(text) > MAX_EMOTION_PREFIX_CHARS:
            return text[match.end() :], True
        return "", False

    candidate = text.lstrip().lower()
    might_still_be_emotion = (
        not candidate
        or "<emotion=".startswith(candidate)
        or (candidate.startswith("<emotion") and ">" not in candidate)
        or "emotion".startswith(candidate)
        or (candidate.startswith("emotion") and len(candidate) <= MAX_EMOTION_PREFIX_CHARS)
        or "情绪".startswith(candidate)
    )
    if not force and might_still_be_emotion and len(candidate) <= MAX_EMOTION_PREFIX_CHARS:
        return "", False
    return text, True
def _strip_inline_tts_emotion(text: str) -> str:
    """Delete every ``<emotion=...>`` tag (and trailing whitespace) found in ``text``."""
    return re.sub(TTS_EMOTION_MARKUP_RE, "", text)
def _is_room_locator_query(normalized_text: str) -> bool:
    """Heuristically decide whether a query asks where something is in a room.

    Args:
        normalized_text: Lower-cased query with all whitespace removed.

    Returns:
        ``True`` when the text contains a spatial word and either a room or
        furniture word or is short (at most 12 characters). Always ``False``
        when it mentions a software-related term.
    """
    # NOTE: never put "" in these tuples: the empty string is a substring of
    # every string, so one empty hint makes the whole category match everything.
    room_context_hints = (
        "房间",
        "屋里",
        "屋子",
        "室内",
        "客厅",
        "卧室",
        "书房",
        "厨房",
        "餐厅",
        "沙发",
        "电视",
        "空调",
        "书架",
        "冰箱",
        "茶几",
        "电脑",
        "相机",
        "植物",
    )
    spatial_hints = (
        "在哪里",
        "在哪",
        "位置",
        "方位",
        "旁边",
        "左边",
        "右边",
        "前面",
        "后面",
        "上面",
        "下面",
        "附近",
        "对面",
        "靠近",
        "挨着",
        "隔着",
    )
    software_hints = (
        "python",
        "代码",
        "函数",
        "class",
        "bug",
        "日志",
        "logging",
        "api",
        "server",
        "agent",
        "prompt",
        "模型",
        "数据库",
        "git",
        "uv",
        "ruff",
        "mypy",
    )

    def mentions_any(hints: tuple[str, ...]) -> bool:
        # ``hint and ...`` skips empty entries defensively (see NOTE above).
        return any(hint and hint in normalized_text for hint in hints)

    # Software talk ("where is the bug") is never a room question.
    if mentions_any(software_hints):
        return False
    if not mentions_any(spatial_hints):
        return False
    # A spatial word plus room context, or a very short "where is X" question.
    return mentions_any(room_context_hints) or len(normalized_text) <= 12
def _normalize_text(text: str) -> str:
    """Return ``text`` lower-cased with every run of whitespace removed."""
    pieces = text.split()
    squeezed = "".join(pieces)
    return squeezed.lower()
def _latest_user_text(chat_ctx: ChatContext) -> str:
    """Return the stripped text of the most recent user message, or "" if there is none."""
    newest_first = reversed(chat_ctx.items)
    latest = next(
        (
            entry
            for entry in newest_first
            if isinstance(entry, ChatMessage) and entry.role == "user"
        ),
        None,
    )
    if latest is None:
        return ""
    return (latest.text_content or "").strip()
def _with_memory_as_latest_user_message(chat_ctx: ChatContext, memory_context: str) -> ChatContext:
    """Return a copy of the context whose latest user message content is ``memory_context``.

    If the context holds no user message, one is appended instead.
    """
    updated = chat_ctx.copy()
    # Walk a snapshot newest-first; the first user message found is replaced.
    for position, entry in reversed(list(enumerate(updated.items))):
        if not isinstance(entry, ChatMessage) or entry.role != "user":
            continue
        replacement = entry.model_copy(deep=True)
        replacement.content = [memory_context]
        updated.items[position] = replacement
        return updated

    updated.items.append(ChatMessage(role="user", content=[memory_context]))
    return updated
def _with_vision_as_latest_user_message(chat_ctx: ChatContext, vision_frame: VisionFrame) -> ChatContext:
    """Return a copy of the context with the frame appended to the latest user message.

    If the context holds no user message, an image-only user message is appended.
    """
    updated = chat_ctx.copy()
    attachment = llm.ImageContent(
        image=vision_frame.image_data_url,
        mime_type=vision_frame.mime_type,
        inference_detail="auto",
    )
    # Walk a snapshot newest-first; the first user message found gets the image.
    for position, entry in reversed(list(enumerate(updated.items))):
        if not isinstance(entry, ChatMessage) or entry.role != "user":
            continue
        replacement = entry.model_copy(deep=True)
        replacement.content = [*replacement.content, attachment]
        updated.items[position] = replacement
        return updated

    updated.items.append(ChatMessage(role="user", content=[attachment]))
    return updated
def _normalize_input_mode(value: str | None) -> str:
    """Map a user-supplied input-mode string onto one of the ``*_INPUT_MODE`` constants.

    Matching ignores case, surrounding whitespace and ``-`` versus ``_``.
    Missing or unrecognized values resolve to ``AUTO_INPUT_MODE``; unrecognized
    ones also log a warning.
    """
    if not value:
        return AUTO_INPUT_MODE

    key = value.strip().lower().replace("-", "_")
    if key in {"image_voice", "image", "vision", "vision_voice", "voice_image"}:
        return VISION_VOICE_INPUT_MODE
    if key in {"audio", "voice"}:
        return VOICE_INPUT_MODE
    if key == "auto":
        return AUTO_INPUT_MODE

    logger.warning("Invalid CUSTOM_AGENT_INPUT_MODE=%r, using %s", value, AUTO_INPUT_MODE)
    return AUTO_INPUT_MODE
def _image_extension_from_mime_type(mime_type: str) -> str:
    """Return the file extension for an image MIME type, defaulting to ``.jpg``."""
    known_extensions = {
        "image/png": ".png",
        "image/webp": ".webp",
        "image/gif": ".gif",
    }
    return known_extensions.get(mime_type.strip().lower(), ".jpg")
def _model_image_save_dir_from_env() -> Path | None:
    """Resolve where model-input images are saved, or ``None`` when saving is disabled.

    Saving is on unless CUSTOM_SAVE_MODEL_IMAGES is set to a false value. The
    directory comes from CUSTOM_MODEL_IMAGE_SAVE_DIR, falling back to a
    ``model_images`` folder next to this file.
    """
    saving_enabled = _env_bool("CUSTOM_SAVE_MODEL_IMAGES", True)
    if not saving_enabled:
        return None

    override = os.getenv("CUSTOM_MODEL_IMAGE_SAVE_DIR")
    if not override:
        return Path(__file__).with_name("model_images")
    return Path(override).expanduser()
# Agent server instance; the session entrypoint below is registered on it.
server = AgentServer()


def prewarm(proc: JobProcess) -> None:
    """Load the Silero VAD model and stash it in ``proc.userdata["vad"]``."""
    # Load Silero VAD as requested
    proc.userdata["vad"] = silero.VAD.load()


# Register ``prewarm`` as the server's setup function.
server.setup_fnc = prewarm
@server.rtc_session(agent_name=AGENT_NAME)
async def entrypoint(ctx: JobContext) -> None:
    """Configure STT, LLM, TTS and vision handling from the environment, then start the session.

    Raises:
        RuntimeError: When the LLM provider is unsupported or a setting the
            selected provider requires is missing.
    """
    ctx.log_context_fields = {
        "room": ctx.room.name,
    }

    # Configuration for custom local endpoints. These can be set in your .env file.
    ASR_URL = os.getenv("CUSTOM_ASR_URL", "http://10.6.80.21:5003/asr-blackbox")
    ASR_MODEL = os.getenv("CUSTOM_ASR_MODEL", "sensevoice")
    ASR_LANGUAGE = os.getenv("CUSTOM_ASR_LANGUAGE", "auto")
    ASR_OUTPUT_LANGUAGE = os.getenv("CUSTOM_ASR_OUTPUT_LANGUAGE", "zh")
    LLM_BASE_URL = os.getenv("CUSTOM_LLM_BASE_URL")
    LLM_MODEL = os.getenv("CUSTOM_LLM_MODEL", "qwen-max")
    LLM_API_KEY = os.getenv("CUSTOM_LLM_API_KEY")
    LLM_PROVIDER = os.getenv("CUSTOM_LLM_PROVIDER", "openai").strip().lower()
    TEXT_LLM_MODEL = os.getenv("CUSTOM_TEXT_LLM_MODEL", LLM_MODEL)
    VISION_LLM_MODEL = os.getenv("CUSTOM_VISION_LLM_MODEL", LLM_MODEL)
    INPUT_MODE = _normalize_input_mode(os.getenv("CUSTOM_AGENT_INPUT_MODE"))

    if LLM_PROVIDER not in {
        "openai",
        "openai-compatible",
        "hermes",
        "hermes_gateway",
        "openclaw",
        "beaver",
    }:
        raise RuntimeError(f"Unsupported CUSTOM_LLM_PROVIDER={LLM_PROVIDER!r}")
    if LLM_PROVIDER in {"openai", "openai-compatible"} and not LLM_API_KEY:
        raise RuntimeError(f"CUSTOM_LLM_API_KEY is not set in {CUSTOM_ENV_PATH}")
    logger.info(
        "Using LLM provider=%s model=%s base_url=%s",
        LLM_PROVIDER,
        LLM_MODEL,
        LLM_BASE_URL or "OpenAI default",
    )

    TTS_URL = os.getenv("CUSTOM_TTS_URL") or os.getenv(
        "VOXCPM_TTS_URL", "http://localhost:5000/tts-blackbox"
    )
    TTS_MODEL = os.getenv("CUSTOM_TTS_MODEL") or os.getenv("VOXCPM_TTS_MODEL", "voxcpmtts")
    TTS_SAMPLE_RATE = _env_int("CUSTOM_TTS_SAMPLE_RATE", 16000)
    TTS_NUM_CHANNELS = _env_int("CUSTOM_TTS_NUM_CHANNELS", 1)
    OUTPUT_SAMPLE_RATE = _env_int("CUSTOM_OUTPUT_SAMPLE_RATE", TTS_SAMPLE_RATE)
    MEMORY_URL = os.getenv("CUSTOM_MEMORY_URL", "").strip()
    MEMORY_TIMEOUT = _env_float("CUSTOM_MEMORY_TIMEOUT", 2.0)
    MEMORY_MAX_CHARS = _env_int("CUSTOM_MEMORY_MAX_CHARS", 2000)
    MEMORY_API_KEY = os.getenv("CUSTOM_MEMORY_API_KEY") or None

    blackbox_stt = BlackboxSTT(
        url=ASR_URL,
        model_name=ASR_MODEL,
        language=ASR_LANGUAGE,
        output_language=ASR_OUTPUT_LANGUAGE,
        hotwords=os.getenv("CUSTOM_ASR_HOTWORDS"),
        itn=os.getenv("CUSTOM_ASR_ITN"),
        chunk_mode=os.getenv("CUSTOM_ASR_CHUNK_MODE"),
    )
    stt_stream = stt.StreamAdapter(stt=blackbox_stt, vad=ctx.proc.userdata["vad"])

    if LLM_PROVIDER == "beaver":
        beaver_url = _first_env("CUSTOM_BEAVER_WS_URL", "BEAVER_WS_URL")
        if not beaver_url:
            raise RuntimeError(f"CUSTOM_BEAVER_WS_URL or BEAVER_WS_URL is not set in {CUSTOM_ENV_PATH}")
        beaver_peer_id = _first_env("CUSTOM_BEAVER_PEER_ID", "BEAVER_PEER_ID") or f"livekit-{ctx.room.name}"
        beaver_device_name = (
            _first_env("CUSTOM_BEAVER_DEVICE_NAME", "BEAVER_DEVICE_NAME", "TERMINAL_DEVICE_NAME")
            or "livekit-custom-agent"
        )
        base_llm = BeaverLLM(
            url=beaver_url,
            peer_id=beaver_peer_id,
            device_name=beaver_device_name,
            model_name=os.getenv("CUSTOM_BEAVER_MODEL", "beaver-terminal"),
        )
        beaver_warmup_text = os.getenv("CUSTOM_BEAVER_WARMUP_TEXT")
        warmup_reply = await base_llm.connect(warmup_text=beaver_warmup_text)
        # A single gateway connection serves both text and vision turns.
        text_llm = base_llm
        vision_llm = base_llm
        logger.info(
            "Using Beaver gateway url=%s peer_id=%s device_name=%s room=%s session_id=%s warmup=%s warmup_reply_len=%s",
            beaver_url,
            beaver_peer_id,
            beaver_device_name,
            ctx.room.name,
            base_llm.session_id,
            bool(beaver_warmup_text and beaver_warmup_text.strip()),
            len(warmup_reply) if warmup_reply is not None else 0,
        )
    elif LLM_PROVIDER in {"hermes", "hermes_gateway", "openclaw"}:
        gateway_url = os.getenv("CUSTOM_HERMES_GATEWAY_URL", "").strip()
        if not gateway_url:
            raise RuntimeError(f"CUSTOM_HERMES_GATEWAY_URL is not set in {CUSTOM_ENV_PATH}")
        hermes_agent_id = os.getenv("CUSTOM_HERMES_AGENT_ID") or None
        hermes_session_mode = os.getenv("CUSTOM_HERMES_SESSION_MODE", "per_room").strip().lower()
        if hermes_session_mode != "per_room":
            raise RuntimeError("CUSTOM_HERMES_SESSION_MODE must be per_room")
        hermes_token = (
            os.getenv("CUSTOM_HERMES_API_KEY")
            or os.getenv("CUSTOM_HERMES_TOKEN")
            or LLM_API_KEY
            or None
        )
        hermes_state = GatewaySessionState(
            room_name=ctx.room.name,
            agent_id=hermes_agent_id,
            session_mode=hermes_session_mode,
        )
        base_llm = HermesGatewayLLM(
            url=gateway_url,
            token=hermes_token,
            state=hermes_state,
            agent_id=hermes_agent_id,
            model_name=os.getenv("CUSTOM_HERMES_MODEL", "hermes-agent"),
            request_timeout=_env_float("CUSTOM_HERMES_REQUEST_TIMEOUT", 30.0),
        )
        text_llm = base_llm
        vision_llm = base_llm
        logger.info(
            "Using Hermes/OpenClaw gateway url=%s agent_id=%s session_key=%s",
            gateway_url,
            hermes_agent_id or "default",
            hermes_state.session_key,
        )
    else:
        import httpx
        from openai import AsyncClient as OpenAIAsyncClient

        # OpenAI-compatible endpoints can be used by setting CUSTOM_LLM_BASE_URL.
        # NOTE(review): TLS certificate verification is OFF unless
        # CUSTOM_LLM_VERIFY_SSL is set to a true value. Confirm this default is
        # intended before pointing the agent at a non-local endpoint.
        http_client = httpx.AsyncClient(verify=_env_bool("CUSTOM_LLM_VERIFY_SSL", False))
        if LLM_BASE_URL:
            openai_client = OpenAIAsyncClient(
                api_key=LLM_API_KEY,
                base_url=LLM_BASE_URL,
                http_client=http_client,
            )
        else:
            openai_client = OpenAIAsyncClient(
                api_key=LLM_API_KEY,
                http_client=http_client,
            )
        base_llm = openai.LLM(
            model=LLM_MODEL,
            client=openai_client,
        )
        # Separate LLM objects are created only when a distinct model is configured.
        text_llm = (
            openai.LLM(model=TEXT_LLM_MODEL, client=openai_client)
            if TEXT_LLM_MODEL != LLM_MODEL
            else base_llm
        )
        vision_llm = (
            openai.LLM(model=VISION_LLM_MODEL, client=openai_client)
            if VISION_LLM_MODEL != LLM_MODEL
            else base_llm
        )

    vision_store = VisionFrameStore(
        max_age_seconds=_env_float("CUSTOM_VISION_FRAME_MAX_AGE_SECONDS", 8.0)
    )
    session: AgentSession = AgentSession(
        # 1. Custom ASR blackbox with StreamAdapter
        stt=stt_stream,
        # 2. LLM backend, OpenAI-compatible or Hermes/OpenClaw gateway.
        llm=base_llm,
        # 3. TTS blackbox
        tts=BlackboxTTS(
            url=TTS_URL,
            model_name=TTS_MODEL,
            params=_tts_params_from_env(TTS_MODEL),
            prompt_wav_path=_tts_prompt_wav_from_env(TTS_MODEL),
            sample_rate=TTS_SAMPLE_RATE,
            num_channels=TTS_NUM_CHANNELS,
        ),
        # 4. Silero VAD
        vad=ctx.proc.userdata["vad"],
        turn_handling=TurnHandlingOptions(
            turn_detection=MultilingualModel(),
            interruption={
                "resume_false_interruption": True,
                "false_interruption_timeout": 1.0,
            },
        ),
        preemptive_generation=_env_bool("CUSTOM_PREEMPTIVE_GENERATION", LLM_PROVIDER != "beaver"),
        aec_warmup_duration=3.0,
        tts_text_transforms=[
            "filter_emoji",
            "filter_markdown",
        ],
    )

    @session.on("metrics_collected")
    def _on_metrics_collected(ev: MetricsCollectedEvent) -> None:
        metrics.log_metrics(ev.metrics)

    @session.on("conversation_item_added")
    def _on_conversation_item_added(event) -> None:
        item = getattr(event, "item", None)
        if not isinstance(item, ChatMessage):
            return
        if item.role == "user" and item.metrics:
            logger.info("User turn metrics: %s", item.metrics)
        elif item.role == "assistant" and item.metrics:
            logger.info("Assistant turn metrics: %s", item.metrics)

    @ctx.room.on("data_received")
    def _on_data_received(data_packet) -> None:
        # Untagged packets are inspected too; the payload check below decides.
        packet_topic = getattr(data_packet, "topic", None)
        if packet_topic not in {None, "", VISION_FRAME_TOPIC}:
            return
        if INPUT_MODE == VOICE_INPUT_MODE:
            logger.info("Ignoring vision frame because CUSTOM_AGENT_INPUT_MODE=%s", INPUT_MODE)
            return
        try:
            payload = json.loads(data_packet.data.decode("utf-8"))
        except Exception:
            logger.exception("Failed to decode vision frame payload")
            return
        # Valid JSON is not necessarily an object (list, string, number, ...).
        # Anything else cannot be a vision frame and would break ``payload.get``.
        if not isinstance(payload, dict):
            return
        if payload.get("type") != "vision_frame" and payload.get("topic") != VISION_FRAME_TOPIC:
            return
        image = payload.get("image")
        if not isinstance(image, str) or not image:
            logger.warning("Received vision frame without image data")
            return
        mime_type = payload.get("mime_type")
        if not isinstance(mime_type, str) or not mime_type:
            mime_type = "image/jpeg"
        saved_path = payload.get("saved_path")
        vision_store.update(
            image=image,
            mime_type=mime_type,
            saved_path=saved_path if isinstance(saved_path, str) else None,
        )
        logger.info(
            "Cached vision frame: mime_type=%s image_chars=%s saved_path=%s",
            mime_type,
            len(image),
            saved_path,
        )

    # Memory recall is optional and enabled only when a URL is configured.
    memory_client = (
        MemoryRecallClient(
            url=MEMORY_URL,
            timeout=MEMORY_TIMEOUT,
            max_chars=MEMORY_MAX_CHARS,
            api_key=MEMORY_API_KEY,
        )
        if MEMORY_URL
        else None
    )
    await session.start(
        agent=CustomAgent(
            memory_client=memory_client,
            vision_store=vision_store,
            input_mode=INPUT_MODE,
            text_llm=text_llm,
            vision_llm=vision_llm,
            model_image_save_dir=_model_image_save_dir_from_env(),
        ),
        room=ctx.room,
        room_options=room_io.RoomOptions(
            audio_output=room_io.AudioOutputOptions(
                sample_rate=OUTPUT_SAMPLE_RATE,
                num_channels=TTS_NUM_CHANNELS,
            ),
        ),
        record=_recording_options_from_env(),
    )
def _tts_params_from_env(model_name: str) -> dict[str, str]:
    """Collect the TTS request parameters for ``model_name`` from environment variables.

    Only variables with a non-empty value contribute. Where several variables
    are listed for one parameter, the first non-empty one wins. Unknown model
    names yield an empty dict.
    """
    # model -> ordered (parameter name, candidate environment variables)
    env_sources: dict[str, tuple[tuple[str, tuple[str, ...]], ...]] = {
        "voxcpmtts": (
            ("streaming", ("CUSTOM_TTS_STREAMING",)),
            ("prompt_text", ("CUSTOM_TTS_PROMPT_TEXT", "VOXCPM_PROMPT_TEXT")),
            ("cfg_value", ("VOXCPM_CFG_VALUE",)),
            ("inference_timesteps", ("VOXCPM_INFERENCE_TIMESTEPS",)),
            ("do_normalize", ("VOXCPM_DO_NORMALIZE",)),
            ("denoise", ("VOXCPM_DENOISE",)),
            ("retry_badcase", ("VOXCPM_RETRY_BADCASE",)),
            ("retry_badcase_max_times", ("VOXCPM_RETRY_BADCASE_MAX_TIMES",)),
            ("retry_badcase_ratio_threshold", ("VOXCPM_RETRY_BADCASE_RATIO_THRESHOLD",)),
        ),
        "melotts": (
            ("speed", ("CUSTOM_TTS_SPEED",)),
        ),
        "cosyvoicetts": (
            ("spk_id", ("CUSTOM_TTS_SPK_ID",)),
            ("model", ("CUSTOM_TTS_MODE",)),
            ("prompt_text", ("CUSTOM_TTS_PROMPT_TEXT",)),
            ("instruct_text", ("CUSTOM_TTS_INSTRUCT_TEXT",)),
        ),
        "sovitstts": (
            ("text_lang", ("CUSTOM_TTS_TEXT_LANG",)),
            ("prompt_lang", ("CUSTOM_TTS_PROMPT_LANG",)),
            ("text_split_method", ("CUSTOM_TTS_TEXT_SPLIT_METHOD",)),
            ("batch_size", ("CUSTOM_TTS_BATCH_SIZE",)),
            ("media_type", ("CUSTOM_TTS_MEDIA_TYPE",)),
            ("streaming_mode", ("CUSTOM_TTS_STREAMING",)),
            ("ref_audio_path", ("CUSTOM_TTS_REF_AUDIO_PATH",)),
            ("prompt_text", ("CUSTOM_TTS_PROMPT_TEXT",)),
        ),
    }

    collected: dict[str, str] = {}
    for param_name, env_names in env_sources.get(model_name.lower(), ()):
        # First non-empty candidate, or None when none of them is set.
        chosen = next((found for env in env_names if (found := os.getenv(env))), None)
        _set_if_present(collected, param_name, chosen)
    return collected
def _tts_prompt_wav_from_env(model_name: str) -> str | None:
if model_name.lower() != "voxcpmtts":
return None
return os.getenv("CUSTOM_TTS_PROMPT_WAV") or os.getenv("VOXCPM_PROMPT_WAV") or None
def _set_if_present(params: dict[str, str], key: str, value: str | None) -> None:
if value:
params[key] = value
def _env_int(name: str, default: int) -> int:
    """Read an integer from the environment; unset, empty or invalid values give ``default``."""
    raw = os.getenv(name)
    if raw:
        try:
            return int(raw)
        except ValueError:
            logger.warning("Invalid integer for %s=%r, using %s", name, raw, default)
    return default
def _env_float(name: str, default: float) -> float:
    """Read a float from the environment; unset, empty or invalid values give ``default``."""
    raw = os.getenv(name)
    if not raw:
        return default
    try:
        parsed = float(raw)
    except ValueError:
        logger.warning("Invalid float for %s=%r, using %s", name, raw, default)
        return default
    return parsed
def _env_bool(name: str, default: bool) -> bool:
    """Read a boolean flag from the environment.

    Accepts 1/true/yes/on and 0/false/no/off, ignoring case and surrounding
    whitespace. An unset variable gives ``default``; any other value gives
    ``default`` and logs a warning.
    """
    raw = os.getenv(name)
    if raw is None:
        return default

    token = raw.strip().lower()
    recognized = (
        (True, {"1", "true", "yes", "on"}),
        (False, {"0", "false", "no", "off"}),
    )
    for outcome, spellings in recognized:
        if token in spellings:
            return outcome

    logger.warning("Invalid boolean for %s=%r, using %s", name, raw, default)
    return default
def _first_env(*names: str) -> str | None:
for name in names:
value = os.getenv(name)
if value and value.strip():
return value.strip()
return None
def _recording_options_from_env() -> RecordingOptions:
    """Build the recording options from CUSTOM_RECORD_* flags; every flag defaults to off."""
    flag_sources = {
        "audio": "CUSTOM_RECORD_AUDIO",
        "traces": "CUSTOM_RECORD_TRACES",
        "logs": "CUSTOM_RECORD_LOGS",
        "transcript": "CUSTOM_RECORD_TRANSCRIPT",
    }
    flags = {option: _env_bool(env_name, False) for option, env_name in flag_sources.items()}
    return RecordingOptions(**flags)
# Script entry point: hand the configured server to the LiveKit agents CLI.
if __name__ == "__main__":
    cli.run_app(server)