1132 lines
37 KiB
Python
1132 lines
37 KiB
Python
import base64
|
||
import json
|
||
import logging
|
||
import os
|
||
import re
|
||
import time
|
||
from collections.abc import AsyncIterable
|
||
from dataclasses import dataclass
|
||
from pathlib import Path
|
||
|
||
from beaver_llm import BeaverLLM
|
||
from dotenv import load_dotenv
|
||
from hermes_gateway import GatewaySessionState, HermesGatewayLLM
|
||
from memory import MemoryRecallClient
|
||
from tts import BlackboxTTS
|
||
|
||
from asr import BlackboxSTT
|
||
from livekit.agents import (
|
||
Agent,
|
||
AgentServer,
|
||
AgentSession,
|
||
ChatContext,
|
||
ChatMessage,
|
||
FlushSentinel,
|
||
JobContext,
|
||
JobProcess,
|
||
MetricsCollectedEvent,
|
||
ModelSettings,
|
||
RecordingOptions,
|
||
TurnHandlingOptions,
|
||
cli,
|
||
llm,
|
||
metrics,
|
||
room_io,
|
||
stt,
|
||
)
|
||
from livekit.agents.voice.generation import update_instructions as update_chat_instructions
|
||
from livekit.plugins import openai, silero
|
||
from livekit.plugins.turn_detector.multilingual import MultilingualModel
|
||
|
||
logger = logging.getLogger("custom-agent")

# Settings are read from the .env file located next to this module.
CUSTOM_ENV_PATH = Path(__file__).with_name(".env")
load_dotenv(dotenv_path=CUSTOM_ENV_PATH)

# System prompt for room-locator turns. In English: "You are a room object-locating
# assistant. When the user asks where an object in the room is: answer in one Chinese
# sentence; describe the target's position relative to other objects; do not use
# Markdown, emoji, lists, headings or coordinate/region labels; do not explain your
# reasoning. If the question is unrelated to locating room objects, answer normally."
ROOM_LOCATOR_INSTRUCTIONS = """
你是一个房间物品定位助手。
当用户询问房间内某个物品的位置时:
- 只用一句中文回答
- 描述目标物品和其他物品的相对位置关系
- 不要使用 Markdown、emoji、列表、标题、坐标区域标签
- 不要解释推理过程
如果用户的问题与房间物品定位无关,则正常回答用户问题。
""".strip()

# System prompt for all other turns. In English: "You are an intelligent voice
# assistant. Answer the user's questions normally, naturally, concisely and accurately."
GENERAL_INSTRUCTIONS = """
你是一个智能语音助手。
正常回答用户问题。
回答自然、简洁、准确。
""".strip()

# Appended to every system prompt. In English: "Every reply must start with an emotion
# tag in exactly the format <emotion=neutral>; emotion must be one of neutral, happy,
# sad, angry, surprised, fearful, calm, concerned. After the tag, output the normal
# reply to the user directly and do not explain the tag."
EMOTION_INSTRUCTIONS = """
每次回复必须先输出一个情绪标签,格式严格为:<emotion=neutral>
emotion 只能从 neutral、happy、sad、angry、surprised、fearful、calm、concerned 中选择。
情绪标签之后直接输出给用户的正常回复,不要解释标签。
""".strip()

# Per-turn prompt modes selected by _select_mode.
ROOM_LOCATOR_MODE = "room_locator"
GENERAL_MODE = "general"
# Input modes: voice only, voice plus camera frames, or auto (images used when available).
VOICE_INPUT_MODE = "voice"
VISION_VOICE_INPUT_MODE = "vision_voice"
AUTO_INPUT_MODE = "auto"
# Data-channel topic on which camera frames are expected.
VISION_FRAME_TOPIC = "vision.frame"
# Key into AGENT_PROFILES used when no profile is configured or the configured one is invalid.
DEFAULT_AGENT_PROFILE = "normal"
|
||
|
||
|
||
@dataclass(frozen=True)
class AgentProfile:
    """Immutable preset bundling the worker's agent name, LLM backend and input mode."""

    # Agent name registered with the server when CUSTOM_AGENT_NAME is not set.
    agent_name: str
    # Default for CUSTOM_LLM_PROVIDER (e.g. "openai-compatible" or "beaver").
    llm_provider: str
    # Default for CUSTOM_AGENT_INPUT_MODE (one of the *_INPUT_MODE constants).
    input_mode: str
|
||
|
||
|
||
# Built-in presets, keyed by the value expected in CUSTOM_AGENT_PROFILE.
AGENT_PROFILES = {
    "normal": AgentProfile(
        agent_name="normal-agent",
        llm_provider="openai-compatible",
        input_mode=VOICE_INPUT_MODE,
    ),
    "beaver": AgentProfile(
        agent_name="beaver-agent",
        llm_provider="beaver",
        input_mode=VOICE_INPUT_MODE,
    ),
    "vision-normal": AgentProfile(
        agent_name="vision-normal-agent",
        llm_provider="openai-compatible",
        input_mode=VISION_VOICE_INPUT_MODE,
    ),
    "vision-beaver": AgentProfile(
        agent_name="vision-beaver-agent",
        llm_provider="beaver",
        input_mode=VISION_VOICE_INPUT_MODE,
    ),
}
# Alternative spellings accepted for CUSTOM_AGENT_PROFILE. Keys are matched after
# lower-casing and replacing "_" with "-" (see _normalize_agent_profile).
AGENT_PROFILE_ALIASES = {
    "default": "normal",
    "openai": "normal",
    "openai-compatible": "normal",
    "llm": "normal",
    "text": "normal",
    "voice": "normal",
    "vision": "vision-normal",
    "vision-llm": "vision-normal",
    "vision-openai": "vision-normal",
    "vision-openai-compatible": "vision-normal",
}
|
||
|
||
|
||
def _normalize_agent_profile(value: str | None) -> str:
    """Resolve a CUSTOM_AGENT_PROFILE value (or one of its aliases) to an AGENT_PROFILES key.

    Matching ignores case and surrounding whitespace and treats ``_`` like ``-``.
    A missing or blank value returns DEFAULT_AGENT_PROFILE quietly. An unknown value logs
    a warning and also returns DEFAULT_AGENT_PROFILE.
    """
    cleaned = (value or "").strip()
    if not cleaned:
        return DEFAULT_AGENT_PROFILE

    key = cleaned.lower().replace("_", "-")
    resolved = AGENT_PROFILE_ALIASES.get(key, key)
    if resolved not in AGENT_PROFILES:
        logger.warning(
            "Invalid CUSTOM_AGENT_PROFILE=%r, using %s",
            value,
            DEFAULT_AGENT_PROFILE,
        )
        resolved = DEFAULT_AGENT_PROFILE
    return resolved
|
||
|
||
|
||
def _agent_profile_from_name(agent_name: str | None) -> str | None:
    """Return the AGENT_PROFILES key whose ``agent_name`` equals ``agent_name``.

    The input is stripped, lower-cased and has ``_`` replaced by ``-`` before comparison.
    Returns ``None`` when the name is missing or blank, or when no profile matches.
    """
    key = (agent_name or "").strip().lower().replace("_", "-")
    if not key:
        return None
    return next(
        (name for name, profile in AGENT_PROFILES.items() if profile.agent_name == key),
        None,
    )
|
||
|
||
|
||
def _selected_agent_profile_name() -> str:
    """Choose the active profile key.

    A non-blank CUSTOM_AGENT_PROFILE wins. Otherwise the profile is inferred from
    CUSTOM_AGENT_NAME, and if neither applies DEFAULT_AGENT_PROFILE is returned.
    """
    explicit = os.getenv("CUSTOM_AGENT_PROFILE")
    if explicit is not None and explicit.strip():
        return _normalize_agent_profile(explicit)
    return _agent_profile_from_name(os.getenv("CUSTOM_AGENT_NAME")) or DEFAULT_AGENT_PROFILE
|
||
|
||
|
||
# Active profile: from CUSTOM_AGENT_PROFILE, else inferred from CUSTOM_AGENT_NAME, else the default.
AGENT_PROFILE_NAME = _selected_agent_profile_name()
AGENT_PROFILE = AGENT_PROFILES[AGENT_PROFILE_NAME]
# Agent name the worker registers under. A blank or whitespace-only CUSTOM_AGENT_NAME is
# treated as unset, matching how _agent_profile_from_name handles it, so the profile's
# own name is used instead of registering an all-whitespace name.
AGENT_NAME = (os.getenv("CUSTOM_AGENT_NAME") or "").strip() or AGENT_PROFILE.agent_name
|
||
|
||
# Emotion used when the model omits the prefix or emits a label outside EMOTION_LABELS.
DEFAULT_EMOTION = "neutral"
# Labels the model may choose from; keep in sync with EMOTION_INSTRUCTIONS.
EMOTION_LABELS = {
    "neutral",
    "happy",
    "sad",
    "angry",
    "surprised",
    "fearful",
    "calm",
    "concerned",
}
# Strict "<emotion=label>" prefix parsed from the start of the LLM reply.
EMOTION_PREFIX_RE = re.compile(r"^\s*<emotion=([a-z_]+)>\s*", re.IGNORECASE)
# Looser "< emotion = ... >" markup removed from any text sent to TTS.
TTS_EMOTION_MARKUP_RE = re.compile(r"<\s*emotion\s*=\s*[^>]{1,80}>\s*", re.IGNORECASE)
# Plain-text leading forms such as "emotion: happy," or "情绪：开心，". Both ASCII and
# full-width separators/punctuation are accepted, because Chinese replies normally use
# the full-width forms (：，！), which the ASCII-only character classes let through to TTS.
TTS_EMOTION_LINE_RE = re.compile(
    r"^\s*(?:emotion|情绪)\s*[:：=]\s*[\w\u4e00-\u9fff-]{1,40}\s*[,，。.!！\s-]*",
    re.IGNORECASE,
)
# Maximum number of characters buffered while deciding whether a reply starts with an emotion prefix.
MAX_EMOTION_PREFIX_CHARS = 80
|
||
|
||
|
||
@dataclass
class VisionFrame:
    """A camera frame received over the room data channel, ready to attach to an LLM request."""

    # "data:<mime_type>;base64,<payload>" URL built by VisionFrameStore.update.
    image_data_url: str
    # time.monotonic() timestamp taken when the frame was stored; used for staleness checks.
    received_at: float
    # MIME type reported by the sender ("image/jpeg" is substituted in entrypoint when missing).
    mime_type: str
    # Optional path from the sender's payload (presumably where the sender saved its own copy); only logged here.
    saved_path: str | None = None
|
||
|
||
|
||
class VisionFrameStore:
    """Single-slot holder for the most recently received camera frame.

    Every ``update`` replaces the previous frame. ``consume_fresh`` hands a frame out at
    most once, and discards it when it is older than ``max_age_seconds``.
    """

    def __init__(self, *, max_age_seconds: float) -> None:
        self._max_age_seconds = max_age_seconds
        self._latest_frame: VisionFrame | None = None

    def update(self, *, image: str, mime_type: str, saved_path: str | None = None) -> None:
        """Store base64 ``image`` as a data URL, stamped with the current monotonic time."""
        data_url = f"data:{mime_type};base64,{image}"
        self._latest_frame = VisionFrame(
            image_data_url=data_url,
            received_at=time.monotonic(),
            mime_type=mime_type,
            saved_path=saved_path,
        )

    def consume_fresh(self) -> VisionFrame | None:
        """Take the stored frame, leaving the slot empty.

        Returns ``None`` when nothing is stored or the frame is stale (stale frames are logged and dropped).
        """
        frame, self._latest_frame = self._latest_frame, None
        if frame is None:
            return None

        age = time.monotonic() - frame.received_at
        if age <= self._max_age_seconds:
            return frame

        logger.info("Dropping stale vision frame: age=%.3fs", age)
        return None
|
||
|
||
|
||
class CustomAgent(Agent):
    """LiveKit voice agent with per-turn prompt routing, optional vision input and emotion tags.

    For each user turn ``llm_node``:

    * picks the room-locator or general system prompt from the latest user utterance;
    * for room-locator turns, recalls room memory and substitutes it into the latest user message;
    * attaches the latest fresh camera frame unless ``input_mode`` is voice-only;
    * streams the reply from ``vision_llm`` (turns with an image) or ``text_llm`` (otherwise),
      falling back to the session's default LLM node when the chosen LLM is ``None``.

    Every reply is expected to start with ``<emotion=...>``. The parsed label is exposed
    as ``current_emotion``, and the markup is removed from the text sent to TTS.
    """

    def __init__(
        self,
        *,
        memory_client: MemoryRecallClient | None = None,
        vision_store: VisionFrameStore | None = None,
        input_mode: str = AUTO_INPUT_MODE,
        text_llm: llm.LLM | None = None,
        vision_llm: llm.LLM | None = None,
        model_image_save_dir: Path | None = None,
    ) -> None:
        """Create the agent.

        Args:
            memory_client: Client used to recall room memory for room-locator turns.
            vision_store: Store holding the latest camera frame from the data channel.
            input_mode: One of the ``*_INPUT_MODE`` constants; ``VOICE_INPUT_MODE`` disables images.
            text_llm: LLM for turns without an image; ``None`` uses the session's default LLM node.
            vision_llm: LLM for turns with an image; ``None`` uses the session's default LLM node.
            model_image_save_dir: If set, each image sent to the model is also written here for debugging.
        """
        super().__init__(instructions=_with_emotion_instructions(GENERAL_INSTRUCTIONS))
        self._memory_client = memory_client
        self._vision_store = vision_store
        self._input_mode = input_mode
        self._text_llm = text_llm
        self._vision_llm = vision_llm
        self._model_image_save_dir = model_image_save_dir
        self.current_emotion = DEFAULT_EMOTION
        # Streaming parser state for the emotion prefix; reset at the start of each LLM stream.
        self._emotion_prefix_buffer = ""
        self._emotion_prefix_done = True

    async def on_enter(self) -> None:
        """Hook run when the agent becomes active; intentionally speaks no greeting."""
        # To greet on join, enable:
        # self.session.generate_reply(instructions="greet the user and introduce yourself")
        pass

    async def llm_node(
        self,
        chat_ctx: ChatContext,
        tools: list[llm.Tool],
        model_settings: ModelSettings,
    ) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
        """Build this turn's prompt and return the selected LLM's reply stream.

        The incoming ``chat_ctx`` is copied, not mutated. Streaming results are wrapped to
        log first-chunk/total latency and to parse the leading emotion tag.
        """
        llm_node_started_at = time.perf_counter()

        user_query = _latest_user_text(chat_ctx)
        mode = _select_mode(user_query)
        # NOTE(review): the frame is removed from the store here. If this call is a
        # preemptive generation that the session later discards, the re-run will not see
        # the image; confirm whether that matters with CUSTOM_PREEMPTIVE_GENERATION on.
        vision_frame = self._consume_vision_frame()
        logger.info(
            "Selected agent mode: %s input_mode=%s has_image=%s",
            mode,
            self._input_mode,
            vision_frame is not None,
        )

        chat_ctx = chat_ctx.copy()
        update_chat_instructions(
            chat_ctx,
            instructions=_with_emotion_instructions(
                ROOM_LOCATOR_INSTRUCTIONS if mode == ROOM_LOCATOR_MODE else GENERAL_INSTRUCTIONS
            ),
            add_if_missing=True,
        )

        if mode == ROOM_LOCATOR_MODE:
            # Reuse the query already extracted above instead of rescanning the context.
            memory_context = await self._recall_room_memory(chat_ctx, user_query=user_query)
            if memory_context:
                chat_ctx = _with_memory_as_latest_user_message(chat_ctx, memory_context)

        if vision_frame is not None:
            self._save_model_vision_frame(vision_frame)
            chat_ctx = _with_vision_as_latest_user_message(chat_ctx, vision_frame)

        llm_result = self._run_selected_llm(
            chat_ctx,
            tools,
            model_settings,
            has_image=vision_frame is not None,
        )
        if not hasattr(llm_result, "__aiter__"):
            elapsed = time.perf_counter() - llm_node_started_at
            logger.info("LLM node completed without streaming in %.3fs", elapsed)
            return llm_result

        async def _instrumented_stream() -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
            first_chunk_at: float | None = None
            chunk_count = 0
            self._emotion_prefix_buffer = ""
            self._emotion_prefix_done = False
            try:
                async for chunk in llm_result:
                    chunk_count += 1
                    if first_chunk_at is None:
                        first_chunk_at = time.perf_counter()
                        logger.info(
                            "LLM first chunk after %.3fs",
                            first_chunk_at - llm_node_started_at,
                        )
                    async for output_chunk in self._observe_emotion_prefix(chunk):
                        yield output_chunk
            finally:
                finished_at = time.perf_counter()
                logger.info(
                    "LLM stream completed in %.3fs (first_chunk=%.3fs, chunks=%s)",
                    finished_at - llm_node_started_at,
                    # Compare against None: a timestamp is never a meaningful "false" value.
                    (first_chunk_at - llm_node_started_at) if first_chunk_at is not None else -1.0,
                    chunk_count,
                )

        return _instrumented_stream()

    def tts_node(self, text: AsyncIterable[str], model_settings: ModelSettings):
        """Run the default TTS node on ``text`` after removing emotion markup."""
        return Agent.default.tts_node(self, _strip_emotion_for_tts(text), model_settings)

    def _consume_vision_frame(self) -> VisionFrame | None:
        """Take the latest fresh frame; ``None`` in voice-only mode or when no store is configured."""
        if self._input_mode == VOICE_INPUT_MODE or self._vision_store is None:
            return None
        return self._vision_store.consume_fresh()

    def _save_model_vision_frame(self, vision_frame: VisionFrame) -> None:
        """Best-effort debug dump of the image sent to the model into ``model_image_save_dir``.

        Decode or write failures are logged and swallowed so they never break the turn.
        """
        if self._model_image_save_dir is None:
            return

        try:
            _, b64_data = vision_frame.image_data_url.split(",", 1)
            image_bytes = base64.b64decode(b64_data, validate=True)
        except Exception:
            logger.exception("Failed to decode model vision frame for debug save")
            return

        extension = _image_extension_from_mime_type(vision_frame.mime_type)
        timestamp_ms = int(time.time() * 1000)
        path = self._model_image_save_dir / f"{timestamp_ms}_model_input{extension}"

        # NOTE(review): synchronous disk I/O called from llm_node; consider
        # asyncio.to_thread if large frames add noticeable first-token latency.
        try:
            self._model_image_save_dir.mkdir(parents=True, exist_ok=True)
            path.write_bytes(image_bytes)
        except Exception:
            logger.exception("Failed to save model vision frame: path=%s", path)
            return

        logger.info(
            "Saved model vision frame: path=%s bytes=%s source_path=%s",
            path,
            len(image_bytes),
            vision_frame.saved_path,
        )

    def _run_selected_llm(
        self,
        chat_ctx: ChatContext,
        tools: list[llm.Tool],
        model_settings: ModelSettings,
        *,
        has_image: bool,
    ) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
        """Return the reply for this turn from ``vision_llm`` (image) or ``text_llm`` (no image).

        When the chosen LLM is ``None``, the result of the session's default ``llm_node`` is returned.
        """
        selected_llm = self._vision_llm if has_image else self._text_llm
        if selected_llm is None:
            return Agent.default.llm_node(self, chat_ctx, tools, model_settings)

        # Resolved eagerly so a missing activity raises here, not lazily inside the stream.
        activity = self._get_activity_or_raise()
        tool_choice = model_settings.tool_choice
        conn_options = activity.session.conn_options.llm_conn_options

        async def _stream() -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
            # The async-with closes the LLM stream when this generator finishes or is closed.
            async with selected_llm.chat(
                chat_ctx=chat_ctx,
                tools=tools,
                tool_choice=tool_choice,
                conn_options=conn_options,
            ) as stream:
                async for chunk in stream:
                    yield chunk

        return _stream()

    async def _observe_emotion_prefix(
        self, chunk: llm.ChatChunk | str | FlushSentinel
    ) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
        """Feed any text in ``chunk`` to the emotion-prefix parser, then pass the chunk through unchanged."""
        if isinstance(chunk, str):
            self._consume_emotion_prefix(chunk)
        elif isinstance(chunk, llm.ChatChunk) and chunk.delta and chunk.delta.content:
            self._consume_emotion_prefix(chunk.delta.content)
        yield chunk

    def _consume_emotion_prefix(self, content: str) -> None:
        """Accumulate streamed text until the leading ``<emotion=...>`` tag is parsed or ruled out.

        On a match, ``current_emotion`` is updated (unsupported labels fall back to
        DEFAULT_EMOTION). Parsing stops once the buffer can no longer be a prefix, or once
        it exceeds MAX_EMOTION_PREFIX_CHARS.
        """
        if self._emotion_prefix_done:
            return

        self._emotion_prefix_buffer += content
        match = EMOTION_PREFIX_RE.match(self._emotion_prefix_buffer)
        if match:
            emotion = match.group(1).lower()
            if emotion not in EMOTION_LABELS:
                logger.warning("LLM returned unsupported emotion=%s, using neutral", emotion)
                emotion = DEFAULT_EMOTION

            self.current_emotion = emotion
            self._emotion_prefix_done = True
            self._emotion_prefix_buffer = ""
            logger.info("LLM emotion selected: %s", emotion)
            return

        candidate = self._emotion_prefix_buffer.lstrip().lower()
        might_still_be_prefix = (
            not candidate
            or "<emotion=".startswith(candidate)
            or (candidate.startswith("<emotion=") and ">" not in candidate)
        )
        if might_still_be_prefix and len(candidate) <= MAX_EMOTION_PREFIX_CHARS:
            return

        self._emotion_prefix_done = True
        self._emotion_prefix_buffer = ""
        logger.warning("LLM response did not start with an emotion prefix")

    async def _recall_room_memory(
        self,
        chat_ctx: ChatContext,
        *,
        user_query: str | None = None,
    ) -> str:
        """Recall room memory for the latest user query; returns "" when unavailable or on failure.

        Args:
            chat_ctx: Conversation used to find the latest user text when ``user_query`` is not given.
            user_query: Already-extracted latest user text, to avoid rescanning ``chat_ctx``.
        """
        if self._memory_client is None:
            return ""

        if user_query is None:
            user_query = _latest_user_text(chat_ctx)
        if not user_query:
            return ""

        started_at = time.perf_counter()
        try:
            recalled = await self._memory_client.recall(user_query)
            elapsed = time.perf_counter() - started_at
            logger.info(
                "Memory recall completed in %.3fs (query_len=%s, memory_len=%s)",
                elapsed,
                len(user_query),
                len(recalled),
            )
            return recalled
        except Exception:
            # Memory is optional context: log and continue the turn without it.
            logger.exception(
                "Unexpected memory recall failure after %.3fs",
                time.perf_counter() - started_at,
            )
            return ""
|
||
|
||
|
||
def _select_mode(user_query: str) -> str:
    """Return ROOM_LOCATOR_MODE for object-location questions, GENERAL_MODE otherwise (including empty input)."""
    text = _normalize_text(user_query)
    return ROOM_LOCATOR_MODE if text and _is_room_locator_query(text) else GENERAL_MODE
|
||
|
||
|
||
def _with_emotion_instructions(instructions: str) -> str:
    """Append EMOTION_INSTRUCTIONS to ``instructions``, separated by a blank line."""
    return "\n\n".join((instructions, EMOTION_INSTRUCTIONS))
|
||
|
||
|
||
async def _strip_emotion_for_tts(text: AsyncIterable[str]) -> AsyncIterable[str]:
    """Remove emotion markup from a streamed reply before it reaches TTS.

    The start of the stream is buffered until _strip_leading_tts_emotion decides whether
    it is an emotion prefix. After that, each chunk only has inline ``<emotion=...>``
    tags removed. Empty input chunks are skipped. If the stream ends while still
    undecided, the buffer is flushed with ``force=True``.
    """
    pending = ""
    prefix_resolved = False

    async for piece in text:
        if not piece:
            continue

        if prefix_resolved:
            stripped = _strip_inline_tts_emotion(piece)
            if stripped:
                yield stripped
            continue

        pending += piece
        remainder, resolved = _strip_leading_tts_emotion(pending)
        if not resolved:
            # Still ambiguous; keep buffering until the prefix can be classified.
            continue

        prefix_resolved = True
        pending = ""
        if remainder:
            yield _strip_inline_tts_emotion(remainder)

    if not prefix_resolved and pending:
        remainder, _ = _strip_leading_tts_emotion(pending, force=True)
        remainder = _strip_inline_tts_emotion(remainder)
        if remainder:
            yield remainder
|
||
|
||
|
||
def _strip_leading_tts_emotion(text: str, *, force: bool = False) -> tuple[str, bool]:
    """Strip a leading emotion marker from ``text``.

    Recognises ``<emotion=...>`` markup and plain-text forms such as ``emotion: happy``
    or ``情绪：开心``.

    Returns:
        ``(remaining_text, True)`` once the prefix is removed or ruled out.
        ``("", False)`` while the text could still grow into a marker; this continues
        until MAX_EMOTION_PREFIX_CHARS is exceeded or ``force`` is set, and the caller
        should keep buffering meanwhile.
    """
    for pattern in (TTS_EMOTION_MARKUP_RE, TTS_EMOTION_LINE_RE):
        match = pattern.match(text)
        if match:
            return text[match.end() :], True

    candidate = text.lstrip().lower()
    might_still_be_emotion = (
        not candidate
        or "<emotion=".startswith(candidate)
        or (candidate.startswith("<emotion") and ">" not in candidate)
        or "emotion".startswith(candidate)
        or candidate.startswith("emotion")
        or "情绪".startswith(candidate)
        # Keep buffering partial Chinese markers such as "情绪:" until the label arrives,
        # exactly as the "emotion" branch above does; previously these were released to TTS.
        or candidate.startswith("情绪")
    )
    if not force and might_still_be_emotion and len(candidate) <= MAX_EMOTION_PREFIX_CHARS:
        return "", False

    return text, True
|
||
|
||
|
||
def _strip_inline_tts_emotion(text: str) -> str:
    """Delete every ``<emotion=...>`` tag (and the whitespace after it) anywhere in ``text``."""
    stripped, _count = TTS_EMOTION_MARKUP_RE.subn("", text)
    return stripped
|
||
|
||
|
||
def _is_room_locator_query(normalized_text: str) -> bool:
|
||
room_context_hints = (
|
||
"房间",
|
||
"屋里",
|
||
"屋子",
|
||
"室内",
|
||
"客厅",
|
||
"卧室",
|
||
"书房",
|
||
"厨房",
|
||
"餐厅",
|
||
"沙发",
|
||
"桌",
|
||
"椅",
|
||
"床",
|
||
"门",
|
||
"窗",
|
||
"柜",
|
||
"电视",
|
||
"空调",
|
||
"书架",
|
||
"灯",
|
||
"冰箱",
|
||
"茶几",
|
||
"电脑",
|
||
"包",
|
||
"瓶",
|
||
"相机",
|
||
"植物",
|
||
)
|
||
spatial_hints = (
|
||
"在哪里",
|
||
"在哪",
|
||
"位置",
|
||
"方位",
|
||
"旁边",
|
||
"左边",
|
||
"右边",
|
||
"前面",
|
||
"后面",
|
||
"上面",
|
||
"下面",
|
||
"附近",
|
||
"对面",
|
||
"靠近",
|
||
"挨着",
|
||
"隔着",
|
||
)
|
||
software_hints = (
|
||
"python",
|
||
"代码",
|
||
"函数",
|
||
"class",
|
||
"bug",
|
||
"日志",
|
||
"logging",
|
||
"api",
|
||
"server",
|
||
"agent",
|
||
"prompt",
|
||
"模型",
|
||
"数据库",
|
||
"git",
|
||
"uv",
|
||
"ruff",
|
||
"mypy",
|
||
)
|
||
|
||
if any(hint in normalized_text for hint in software_hints):
|
||
return False
|
||
|
||
has_spatial_hint = any(hint in normalized_text for hint in spatial_hints)
|
||
has_room_context_hint = any(hint in normalized_text for hint in room_context_hints)
|
||
|
||
if has_spatial_hint and has_room_context_hint:
|
||
return True
|
||
|
||
if has_spatial_hint and len(normalized_text) <= 12:
|
||
return True
|
||
|
||
return False
|
||
|
||
|
||
def _normalize_text(text: str) -> str:
|
||
return "".join(text.split()).lower()
|
||
|
||
|
||
def _latest_user_text(chat_ctx: ChatContext) -> str:
    """Return the stripped text of the most recent user message, or "" if there is none."""
    latest = next(
        (
            item
            for item in reversed(chat_ctx.items)
            if isinstance(item, ChatMessage) and item.role == "user"
        ),
        None,
    )
    if latest is None:
        return ""
    return (latest.text_content or "").strip()
|
||
|
||
|
||
def _with_memory_as_latest_user_message(chat_ctx: ChatContext, memory_context: str) -> ChatContext:
    """Return a copy of ``chat_ctx`` whose latest user message content is ``memory_context``.

    The latest user message's content is replaced entirely, not appended to. If there is
    no user message, a new one containing ``memory_context`` is appended.
    """
    updated = chat_ctx.copy()
    for position in reversed(range(len(updated.items))):
        candidate = updated.items[position]
        if not isinstance(candidate, ChatMessage) or candidate.role != "user":
            continue
        replacement = candidate.model_copy(deep=True)
        replacement.content = [memory_context]
        updated.items[position] = replacement
        return updated

    updated.items.append(ChatMessage(role="user", content=[memory_context]))
    return updated
|
||
|
||
|
||
def _with_vision_as_latest_user_message(chat_ctx: ChatContext, vision_frame: VisionFrame) -> ChatContext:
    """Return a copy of ``chat_ctx`` with ``vision_frame`` attached to the latest user message.

    The image goes after the message's existing content. If there is no user message, a
    new user message holding only the image is appended.
    """
    updated = chat_ctx.copy()
    image_part = llm.ImageContent(
        image=vision_frame.image_data_url,
        mime_type=vision_frame.mime_type,
        inference_detail="auto",
    )

    for position in reversed(range(len(updated.items))):
        candidate = updated.items[position]
        if not isinstance(candidate, ChatMessage) or candidate.role != "user":
            continue
        replacement = candidate.model_copy(deep=True)
        replacement.content = [*replacement.content, image_part]
        updated.items[position] = replacement
        return updated

    updated.items.append(ChatMessage(role="user", content=[image_part]))
    return updated
|
||
|
||
|
||
def _normalize_input_mode(value: str | None) -> str:
    """Map a CUSTOM_AGENT_INPUT_MODE value (with aliases) to one of the ``*_INPUT_MODE`` constants.

    Matching ignores case and surrounding whitespace and treats ``-`` like ``_``. A
    missing or empty value returns AUTO_INPUT_MODE; an unrecognised value logs a warning
    and also returns AUTO_INPUT_MODE.
    """
    if not value:
        return AUTO_INPUT_MODE

    key = value.strip().lower().replace("-", "_")
    if key in {"image_voice", "image", "vision", "vision_voice", "voice_image"}:
        return VISION_VOICE_INPUT_MODE
    if key in {"audio", "voice"}:
        return VOICE_INPUT_MODE
    if key == "auto":
        return AUTO_INPUT_MODE

    logger.warning("Invalid CUSTOM_AGENT_INPUT_MODE=%r, using %s", value, AUTO_INPUT_MODE)
    return AUTO_INPUT_MODE
|
||
|
||
|
||
def _image_extension_from_mime_type(mime_type: str) -> str:
|
||
normalized = mime_type.strip().lower()
|
||
if normalized == "image/png":
|
||
return ".png"
|
||
if normalized == "image/webp":
|
||
return ".webp"
|
||
if normalized == "image/gif":
|
||
return ".gif"
|
||
return ".jpg"
|
||
|
||
|
||
def _model_image_save_dir_from_env() -> Path | None:
    """Return the directory for debug copies of model input images, or None when disabled.

    Saving is on unless CUSTOM_SAVE_MODEL_IMAGES is false. The directory comes from
    CUSTOM_MODEL_IMAGE_SAVE_DIR (with ``~`` expanded); by default it is ``model_images``
    next to this module.
    """
    if _env_bool("CUSTOM_SAVE_MODEL_IMAGES", True):
        override = os.getenv("CUSTOM_MODEL_IMAGE_SAVE_DIR")
        if override:
            return Path(override).expanduser()
        return Path(__file__).with_name("model_images")
    return None
|
||
|
||
|
||
def _agent_server_from_env() -> AgentServer:
    """Create the AgentServer, applying CUSTOM_AGENT_HTTP_PORT when it is set.

    An unset or blank value leaves AgentServer's own default port in place. This matches
    how _env_int treats empty values, instead of warning and forcing port 0. A
    non-integer value, or one outside 0-65535, is logged and replaced with 0
    (presumably an OS-assigned port; confirm against AgentServer).
    """
    configured_port = os.getenv("CUSTOM_AGENT_HTTP_PORT")
    if configured_port is None or not configured_port.strip():
        return AgentServer()

    try:
        port = int(configured_port)
    except ValueError:
        logger.warning("Invalid integer for CUSTOM_AGENT_HTTP_PORT=%r, using 0", configured_port)
        return AgentServer(port=0)

    if not 0 <= port <= 65535:
        logger.warning("Invalid CUSTOM_AGENT_HTTP_PORT=%r, using 0", configured_port)
        port = 0

    return AgentServer(port=port)
|
||
|
||
|
||
# Module-level server instance; its HTTP port comes from CUSTOM_AGENT_HTTP_PORT via _agent_server_from_env.
server = _agent_server_from_env()
|
||
|
||
|
||
def prewarm(proc: JobProcess) -> None:
    """Process setup hook: load the Silero VAD into ``proc.userdata["vad"]`` for entrypoint to reuse."""
    vad_model = silero.VAD.load()
    proc.userdata["vad"] = vad_model
|
||
|
||
|
||
# Register prewarm as the process setup hook; entrypoint reads the VAD it stores in proc.userdata["vad"].
server.setup_fnc = prewarm
|
||
|
||
|
||
@server.rtc_session(agent_name=AGENT_NAME)
async def entrypoint(ctx: JobContext) -> None:
    """Configure and start the voice agent session for one LiveKit room job.

    Reads endpoint/model settings from the environment and builds the following:

    * the blackbox ASR, wrapped in a VAD-driven StreamAdapter;
    * the LLM backend chosen by CUSTOM_LLM_PROVIDER: Beaver, the Hermes/OpenClaw
      gateway, or an OpenAI-compatible API;
    * the blackbox TTS.

    It then caches camera frames received on the room data channel and starts an
    AgentSession running CustomAgent.

    Raises:
        RuntimeError: If the provider is unsupported or a setting it requires is missing.
    """
    ctx.log_context_fields = {
        "room": ctx.room.name,
    }

    # Configuration for custom local endpoints. These can be set in your .env file.
    ASR_URL = os.getenv("CUSTOM_ASR_URL", "http://10.6.80.21:5003/asr-blackbox")
    ASR_MODEL = os.getenv("CUSTOM_ASR_MODEL", "sensevoice")
    ASR_LANGUAGE = os.getenv("CUSTOM_ASR_LANGUAGE", "auto")
    ASR_OUTPUT_LANGUAGE = os.getenv("CUSTOM_ASR_OUTPUT_LANGUAGE", "zh")

    LLM_BASE_URL = os.getenv("CUSTOM_LLM_BASE_URL")
    LLM_MODEL = os.getenv("CUSTOM_LLM_MODEL", "qwen-max")
    LLM_API_KEY = os.getenv("CUSTOM_LLM_API_KEY")
    LLM_PROVIDER = os.getenv("CUSTOM_LLM_PROVIDER", AGENT_PROFILE.llm_provider).strip().lower()
    # Separate text/vision model names are only used by the OpenAI-compatible branch below.
    TEXT_LLM_MODEL = os.getenv("CUSTOM_TEXT_LLM_MODEL", LLM_MODEL)
    VISION_LLM_MODEL = os.getenv("CUSTOM_VISION_LLM_MODEL", LLM_MODEL)
    INPUT_MODE = _normalize_input_mode(os.getenv("CUSTOM_AGENT_INPUT_MODE", AGENT_PROFILE.input_mode))
    if LLM_PROVIDER not in {
        "openai",
        "openai-compatible",
        "hermes",
        "hermes_gateway",
        "openclaw",
        "beaver",
    }:
        raise RuntimeError(f"Unsupported CUSTOM_LLM_PROVIDER={LLM_PROVIDER!r}")
    if LLM_PROVIDER in {"openai", "openai-compatible"} and not LLM_API_KEY:
        raise RuntimeError(f"CUSTOM_LLM_API_KEY is not set in {CUSTOM_ENV_PATH}")
    logger.info(
        "Using agent profile=%s agent_name=%s input_mode=%s llm_provider=%s model=%s base_url=%s",
        AGENT_PROFILE_NAME,
        AGENT_NAME or "<automatic>",
        INPUT_MODE,
        LLM_PROVIDER,
        LLM_MODEL,
        LLM_BASE_URL or "OpenAI default",
    )

    TTS_URL = os.getenv("CUSTOM_TTS_URL") or os.getenv(
        "VOXCPM_TTS_URL", "http://localhost:5000/tts-blackbox"
    )
    TTS_MODEL = os.getenv("CUSTOM_TTS_MODEL") or os.getenv("VOXCPM_TTS_MODEL", "voxcpmtts")
    TTS_SAMPLE_RATE = _env_int("CUSTOM_TTS_SAMPLE_RATE", 16000)
    TTS_NUM_CHANNELS = _env_int("CUSTOM_TTS_NUM_CHANNELS", 1)
    OUTPUT_SAMPLE_RATE = _env_int("CUSTOM_OUTPUT_SAMPLE_RATE", TTS_SAMPLE_RATE)
    MEMORY_URL = os.getenv("CUSTOM_MEMORY_URL", "").strip()
    MEMORY_TIMEOUT = _env_float("CUSTOM_MEMORY_TIMEOUT", 2.0)
    MEMORY_MAX_CHARS = _env_int("CUSTOM_MEMORY_MAX_CHARS", 2000)
    MEMORY_API_KEY = os.getenv("CUSTOM_MEMORY_API_KEY") or None

    blackbox_stt = BlackboxSTT(
        url=ASR_URL,
        model_name=ASR_MODEL,
        language=ASR_LANGUAGE,
        output_language=ASR_OUTPUT_LANGUAGE,
        hotwords=os.getenv("CUSTOM_ASR_HOTWORDS"),
        itn=os.getenv("CUSTOM_ASR_ITN"),
        chunk_mode=os.getenv("CUSTOM_ASR_CHUNK_MODE"),
    )
    # StreamAdapter uses the prewarmed VAD to segment audio for the (presumably non-streaming) ASR.
    stt_stream = stt.StreamAdapter(stt=blackbox_stt, vad=ctx.proc.userdata["vad"])

    if LLM_PROVIDER == "beaver":
        beaver_url = _first_env("CUSTOM_BEAVER_WS_URL", "BEAVER_WS_URL")
        if not beaver_url:
            raise RuntimeError(f"CUSTOM_BEAVER_WS_URL or BEAVER_WS_URL is not set in {CUSTOM_ENV_PATH}")

        beaver_peer_id = _first_env("CUSTOM_BEAVER_PEER_ID", "BEAVER_PEER_ID") or f"livekit-{ctx.room.name}"
        beaver_device_name = (
            _first_env("CUSTOM_BEAVER_DEVICE_NAME", "BEAVER_DEVICE_NAME", "TERMINAL_DEVICE_NAME")
            or "livekit-custom-agent"
        )
        base_llm = BeaverLLM(
            url=beaver_url,
            peer_id=beaver_peer_id,
            device_name=beaver_device_name,
            model_name=os.getenv("CUSTOM_BEAVER_MODEL", "beaver-terminal"),
        )
        beaver_warmup_text = os.getenv("CUSTOM_BEAVER_WARMUP_TEXT")
        # NOTE(review): connect() opens a gateway connection; confirm whether BeaverLLM
        # needs an explicit close on job shutdown (e.g. via ctx.add_shutdown_callback).
        warmup_reply = await base_llm.connect(warmup_text=beaver_warmup_text)
        text_llm = base_llm
        vision_llm = base_llm
        logger.info(
            "Using Beaver gateway url=%s peer_id=%s device_name=%s room=%s session_id=%s warmup=%s warmup_reply_len=%s",
            beaver_url,
            beaver_peer_id,
            beaver_device_name,
            ctx.room.name,
            base_llm.session_id,
            bool(beaver_warmup_text and beaver_warmup_text.strip()),
            len(warmup_reply) if warmup_reply is not None else 0,
        )
    elif LLM_PROVIDER in {"hermes", "hermes_gateway", "openclaw"}:
        gateway_url = os.getenv("CUSTOM_HERMES_GATEWAY_URL", "").strip()
        if not gateway_url:
            raise RuntimeError(f"CUSTOM_HERMES_GATEWAY_URL is not set in {CUSTOM_ENV_PATH}")

        hermes_agent_id = os.getenv("CUSTOM_HERMES_AGENT_ID") or None
        hermes_session_mode = os.getenv("CUSTOM_HERMES_SESSION_MODE", "per_room").strip().lower()
        if hermes_session_mode != "per_room":
            raise RuntimeError("CUSTOM_HERMES_SESSION_MODE must be per_room")
        hermes_token = (
            os.getenv("CUSTOM_HERMES_API_KEY")
            or os.getenv("CUSTOM_HERMES_TOKEN")
            or LLM_API_KEY
            or None
        )
        hermes_state = GatewaySessionState(
            room_name=ctx.room.name,
            agent_id=hermes_agent_id,
            session_mode=hermes_session_mode,
        )
        base_llm = HermesGatewayLLM(
            url=gateway_url,
            token=hermes_token,
            state=hermes_state,
            agent_id=hermes_agent_id,
            model_name=os.getenv("CUSTOM_HERMES_MODEL", "hermes-agent"),
            request_timeout=_env_float("CUSTOM_HERMES_REQUEST_TIMEOUT", 30.0),
        )
        text_llm = base_llm
        vision_llm = base_llm
        logger.info(
            "Using Hermes/OpenClaw gateway url=%s agent_id=%s session_key=%s",
            gateway_url,
            hermes_agent_id or "default",
            hermes_state.session_key,
        )
    else:
        import httpx
        from openai import AsyncClient as OpenAIAsyncClient

        # OpenAI-compatible endpoints can be used by setting CUSTOM_LLM_BASE_URL.
        verify_ssl = _env_bool("CUSTOM_LLM_VERIFY_SSL", False)
        if not verify_ssl:
            # Verification is off by default; make that visible because the API key travels over this connection.
            logger.warning(
                "TLS certificate verification is disabled for the LLM endpoint "
                "(set CUSTOM_LLM_VERIFY_SSL=true to enable it)"
            )
        http_client = httpx.AsyncClient(verify=verify_ssl)
        # Close the connection pool when the job ends so sockets are not leaked across jobs.
        ctx.add_shutdown_callback(http_client.aclose)

        if LLM_BASE_URL:
            openai_client = OpenAIAsyncClient(
                api_key=LLM_API_KEY,
                base_url=LLM_BASE_URL,
                http_client=http_client,
            )
        else:
            openai_client = OpenAIAsyncClient(
                api_key=LLM_API_KEY,
                http_client=http_client,
            )

        base_llm = openai.LLM(
            model=LLM_MODEL,
            client=openai_client,
        )
        # Reuse base_llm unless a distinct model is configured for text or vision turns.
        text_llm = (
            openai.LLM(model=TEXT_LLM_MODEL, client=openai_client)
            if TEXT_LLM_MODEL != LLM_MODEL
            else base_llm
        )
        vision_llm = (
            openai.LLM(model=VISION_LLM_MODEL, client=openai_client)
            if VISION_LLM_MODEL != LLM_MODEL
            else base_llm
        )
    vision_store = VisionFrameStore(
        max_age_seconds=_env_float("CUSTOM_VISION_FRAME_MAX_AGE_SECONDS", 8.0)
    )

    session: AgentSession = AgentSession(
        # 1. Custom ASR blackbox with StreamAdapter
        stt=stt_stream,
        # 2. LLM backend, OpenAI-compatible or Hermes/OpenClaw gateway.
        llm=base_llm,
        # 3. TTS blackbox
        tts=BlackboxTTS(
            url=TTS_URL,
            model_name=TTS_MODEL,
            params=_tts_params_from_env(TTS_MODEL),
            prompt_wav_path=_tts_prompt_wav_from_env(TTS_MODEL),
            sample_rate=TTS_SAMPLE_RATE,
            num_channels=TTS_NUM_CHANNELS,
        ),
        # 4. Silero VAD
        vad=ctx.proc.userdata["vad"],
        turn_handling=TurnHandlingOptions(
            turn_detection=MultilingualModel(),
            interruption={
                "resume_false_interruption": True,
                "false_interruption_timeout": 1.0,
            },
        ),
        # Preemptive generation defaults to off for the Beaver provider, on otherwise.
        preemptive_generation=_env_bool("CUSTOM_PREEMPTIVE_GENERATION", LLM_PROVIDER != "beaver"),
        aec_warmup_duration=3.0,
        tts_text_transforms=[
            "filter_emoji",
            "filter_markdown",
        ],
    )

    @session.on("metrics_collected")
    def _on_metrics_collected(ev: MetricsCollectedEvent) -> None:
        metrics.log_metrics(ev.metrics)

    @session.on("conversation_item_added")
    def _on_conversation_item_added(event) -> None:
        item = getattr(event, "item", None)
        if not isinstance(item, ChatMessage):
            return

        if item.role == "user" and item.metrics:
            logger.info("User turn metrics: %s", item.metrics)
        elif item.role == "assistant" and item.metrics:
            logger.info("Assistant turn metrics: %s", item.metrics)

    @ctx.room.on("data_received")
    def _on_data_received(data_packet) -> None:
        # Accept packets on the vision topic or with no topic; everything else is not ours.
        packet_topic = getattr(data_packet, "topic", None)
        if packet_topic not in {None, "", VISION_FRAME_TOPIC}:
            return

        if INPUT_MODE == VOICE_INPUT_MODE:
            logger.info("Ignoring vision frame because CUSTOM_AGENT_INPUT_MODE=%s", INPUT_MODE)
            return

        try:
            payload = json.loads(data_packet.data.decode("utf-8"))
        except Exception:
            logger.exception("Failed to decode vision frame payload")
            return

        # Payloads come from other participants: valid JSON that is not an object
        # (list, string, number) would otherwise raise AttributeError on payload.get below.
        if not isinstance(payload, dict):
            logger.warning("Ignoring data packet whose JSON payload is not an object")
            return

        if payload.get("type") != "vision_frame" and payload.get("topic") != VISION_FRAME_TOPIC:
            return

        image = payload.get("image")
        if not isinstance(image, str) or not image:
            logger.warning("Received vision frame without image data")
            return

        mime_type = payload.get("mime_type")
        if not isinstance(mime_type, str) or not mime_type:
            mime_type = "image/jpeg"

        saved_path = payload.get("saved_path")
        vision_store.update(
            image=image,
            mime_type=mime_type,
            saved_path=saved_path if isinstance(saved_path, str) else None,
        )
        logger.info(
            "Cached vision frame: mime_type=%s image_chars=%s saved_path=%s",
            mime_type,
            len(image),
            saved_path,
        )

    memory_client = (
        MemoryRecallClient(
            url=MEMORY_URL,
            timeout=MEMORY_TIMEOUT,
            max_chars=MEMORY_MAX_CHARS,
            api_key=MEMORY_API_KEY,
        )
        if MEMORY_URL
        else None
    )

    await session.start(
        agent=CustomAgent(
            memory_client=memory_client,
            vision_store=vision_store,
            input_mode=INPUT_MODE,
            text_llm=text_llm,
            vision_llm=vision_llm,
            model_image_save_dir=_model_image_save_dir_from_env(),
        ),
        room=ctx.room,
        room_options=room_io.RoomOptions(
            audio_output=room_io.AudioOutputOptions(
                sample_rate=OUTPUT_SAMPLE_RATE,
                num_channels=TTS_NUM_CHANNELS,
            ),
        ),
        record=_recording_options_from_env(),
    )
|
||
|
||
|
||
def _tts_params_from_env(model_name: str) -> dict[str, str]:
|
||
params: dict[str, str] = {}
|
||
model_name = model_name.lower()
|
||
|
||
if model_name == "voxcpmtts":
|
||
_set_if_present(params, "streaming", os.getenv("CUSTOM_TTS_STREAMING"))
|
||
_set_if_present(
|
||
params,
|
||
"prompt_text",
|
||
os.getenv("CUSTOM_TTS_PROMPT_TEXT") or os.getenv("VOXCPM_PROMPT_TEXT"),
|
||
)
|
||
_set_if_present(params, "cfg_value", os.getenv("VOXCPM_CFG_VALUE"))
|
||
_set_if_present(params, "inference_timesteps", os.getenv("VOXCPM_INFERENCE_TIMESTEPS"))
|
||
_set_if_present(params, "do_normalize", os.getenv("VOXCPM_DO_NORMALIZE"))
|
||
_set_if_present(params, "denoise", os.getenv("VOXCPM_DENOISE"))
|
||
_set_if_present(params, "retry_badcase", os.getenv("VOXCPM_RETRY_BADCASE"))
|
||
_set_if_present(
|
||
params,
|
||
"retry_badcase_max_times",
|
||
os.getenv("VOXCPM_RETRY_BADCASE_MAX_TIMES"),
|
||
)
|
||
_set_if_present(
|
||
params,
|
||
"retry_badcase_ratio_threshold",
|
||
os.getenv("VOXCPM_RETRY_BADCASE_RATIO_THRESHOLD"),
|
||
)
|
||
elif model_name == "melotts":
|
||
_set_if_present(params, "speed", os.getenv("CUSTOM_TTS_SPEED"))
|
||
elif model_name == "cosyvoicetts":
|
||
_set_if_present(params, "spk_id", os.getenv("CUSTOM_TTS_SPK_ID"))
|
||
_set_if_present(params, "model", os.getenv("CUSTOM_TTS_MODE"))
|
||
_set_if_present(params, "prompt_text", os.getenv("CUSTOM_TTS_PROMPT_TEXT"))
|
||
_set_if_present(params, "instruct_text", os.getenv("CUSTOM_TTS_INSTRUCT_TEXT"))
|
||
elif model_name == "sovitstts":
|
||
_set_if_present(params, "text_lang", os.getenv("CUSTOM_TTS_TEXT_LANG"))
|
||
_set_if_present(params, "prompt_lang", os.getenv("CUSTOM_TTS_PROMPT_LANG"))
|
||
_set_if_present(params, "text_split_method", os.getenv("CUSTOM_TTS_TEXT_SPLIT_METHOD"))
|
||
_set_if_present(params, "batch_size", os.getenv("CUSTOM_TTS_BATCH_SIZE"))
|
||
_set_if_present(params, "media_type", os.getenv("CUSTOM_TTS_MEDIA_TYPE"))
|
||
_set_if_present(params, "streaming_mode", os.getenv("CUSTOM_TTS_STREAMING"))
|
||
_set_if_present(params, "ref_audio_path", os.getenv("CUSTOM_TTS_REF_AUDIO_PATH"))
|
||
_set_if_present(params, "prompt_text", os.getenv("CUSTOM_TTS_PROMPT_TEXT"))
|
||
|
||
return params
|
||
|
||
|
||
def _tts_prompt_wav_from_env(model_name: str) -> str | None:
|
||
if model_name.lower() != "voxcpmtts":
|
||
return None
|
||
|
||
return os.getenv("CUSTOM_TTS_PROMPT_WAV") or os.getenv("VOXCPM_PROMPT_WAV") or None
|
||
|
||
|
||
def _set_if_present(params: dict[str, str], key: str, value: str | None) -> None:
|
||
if value:
|
||
params[key] = value
|
||
|
||
|
||
def _env_int(name: str, default: int) -> int:
|
||
value = os.getenv(name)
|
||
if not value:
|
||
return default
|
||
try:
|
||
return int(value)
|
||
except ValueError:
|
||
logger.warning("Invalid integer for %s=%r, using %s", name, value, default)
|
||
return default
|
||
|
||
|
||
def _env_float(name: str, default: float) -> float:
|
||
value = os.getenv(name)
|
||
if not value:
|
||
return default
|
||
try:
|
||
return float(value)
|
||
except ValueError:
|
||
logger.warning("Invalid float for %s=%r, using %s", name, value, default)
|
||
return default
|
||
|
||
|
||
def _env_bool(name: str, default: bool) -> bool:
|
||
value = os.getenv(name)
|
||
if value is None:
|
||
return default
|
||
|
||
normalized = value.strip().lower()
|
||
if normalized in {"1", "true", "yes", "on"}:
|
||
return True
|
||
if normalized in {"0", "false", "no", "off"}:
|
||
return False
|
||
|
||
logger.warning("Invalid boolean for %s=%r, using %s", name, value, default)
|
||
return default
|
||
|
||
|
||
def _first_env(*names: str) -> str | None:
|
||
for name in names:
|
||
value = os.getenv(name)
|
||
if value and value.strip():
|
||
return value.strip()
|
||
return None
|
||
|
||
|
||
def _recording_options_from_env() -> RecordingOptions:
    """Build session recording options from CUSTOM_RECORD_* booleans (all default to off)."""
    env_flags = (
        ("audio", "CUSTOM_RECORD_AUDIO"),
        ("traces", "CUSTOM_RECORD_TRACES"),
        ("logs", "CUSTOM_RECORD_LOGS"),
        ("transcript", "CUSTOM_RECORD_TRANSCRIPT"),
    )
    return RecordingOptions(
        **{option: _env_bool(env_name, False) for option, env_name in env_flags}
    )
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: hand the configured server to the LiveKit agents CLI.
    cli.run_app(server)
|