fix: fix voice interupt

This commit is contained in:
0Xiao0
2026-06-12 11:17:12 +08:00
parent 78b9138c17
commit 820dc44053
8 changed files with 537 additions and 48 deletions

View File

@ -1,3 +1,4 @@
import asyncio
import base64
import json
import logging
@ -9,12 +10,13 @@ from dataclasses import dataclass
from pathlib import Path
from beaver_llm import BeaverLLM
from beaver_terminal_client import BeaverTerminalError
from dotenv import load_dotenv
from hermes_gateway import GatewaySessionState, HermesGatewayLLM
from memory import MemoryRecallClient
from tts import BlackboxTTS
from asr import BlackboxSTT
from asr import BlackboxSTT, BoundedStreamAdapter
from livekit.agents import (
Agent,
AgentServer,
@ -32,7 +34,6 @@ from livekit.agents import (
llm,
metrics,
room_io,
stt,
)
from livekit.agents.voice.generation import update_instructions as update_chat_instructions
from livekit.plugins import openai, silero
@ -61,7 +62,7 @@ GENERAL_INSTRUCTIONS = """
EMOTION_INSTRUCTIONS = """
每次回复必须先输出一个情绪标签,格式严格为:<emotion=neutral>
emotion 只能从 neutral、happy、sad、angry、surprised、fearful、calm、concerned 中选择。
emotion 只能从 angry、confident、confused、cool、crying、delicious、embarrassed、funny、happy、kissy、laughing、loving、neutral、relaxed、sad、shocked、silly、sleepy、surprised、thinking、winking 中选择。
情绪标签之后直接输出给用户的正常回复,不要解释标签。
""".strip()
@ -70,6 +71,7 @@ GENERAL_MODE = "general"
VOICE_INPUT_MODE = "voice"
VISION_VOICE_INPUT_MODE = "vision_voice"
AUTO_INPUT_MODE = "auto"
INTERRUPT_TOPIC = "lk.interrupt"
VISION_FRAME_TOPIC = "vision.frame"
DEFAULT_AGENT_PROFILE = "normal"
@ -163,14 +165,27 @@ AGENT_NAME = os.getenv("CUSTOM_AGENT_NAME") or AGENT_PROFILE.agent_name
DEFAULT_EMOTION = "neutral"
EMOTION_LABELS = {
"neutral",
"happy",
"sad",
"angry",
"confident",
"confused",
"cool",
"crying",
"delicious",
"embarrassed",
"funny",
"happy",
"kissy",
"laughing",
"loving",
"neutral",
"relaxed",
"sad",
"shocked",
"silly",
"sleepy",
"surprised",
"fearful",
"calm",
"concerned",
"thinking",
"winking",
}
EMOTION_PREFIX_RE = re.compile(r"^\s*<emotion=([a-z_]+)>\s*", re.IGNORECASE)
TTS_EMOTION_MARKUP_RE = re.compile(r"<\s*emotion\s*=\s*[^>]{1,80}>\s*", re.IGNORECASE)
@ -380,6 +395,7 @@ class CustomAgent(Agent):
yield chunk
return _stream()
async def _observe_emotion_prefix(
self, chunk: llm.ChatChunk | str | FlushSentinel
) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
@ -635,7 +651,9 @@ def _with_memory_as_latest_user_message(chat_ctx: ChatContext, memory_context: s
return chat_ctx
def _with_vision_as_latest_user_message(chat_ctx: ChatContext, vision_frame: VisionFrame) -> ChatContext:
def _with_vision_as_latest_user_message(
chat_ctx: ChatContext, vision_frame: VisionFrame
) -> ChatContext:
chat_ctx = chat_ctx.copy()
image_content = llm.ImageContent(
image=vision_frame.image_data_url,
@ -742,6 +760,7 @@ async def entrypoint(ctx: JobContext) -> None:
ASR_MODEL = os.getenv("CUSTOM_ASR_MODEL", "sensevoice")
ASR_LANGUAGE = os.getenv("CUSTOM_ASR_LANGUAGE", "auto")
ASR_OUTPUT_LANGUAGE = os.getenv("CUSTOM_ASR_OUTPUT_LANGUAGE", "zh")
ASR_MAX_SPEECH_DURATION = _env_float("CUSTOM_ASR_MAX_SPEECH_DURATION", 12.0)
LLM_BASE_URL = os.getenv("CUSTOM_LLM_BASE_URL")
LLM_MODEL = os.getenv("CUSTOM_LLM_MODEL", "qwen-max")
@ -749,7 +768,9 @@ async def entrypoint(ctx: JobContext) -> None:
LLM_PROVIDER = os.getenv("CUSTOM_LLM_PROVIDER", AGENT_PROFILE.llm_provider).strip().lower()
TEXT_LLM_MODEL = os.getenv("CUSTOM_TEXT_LLM_MODEL", LLM_MODEL)
VISION_LLM_MODEL = os.getenv("CUSTOM_VISION_LLM_MODEL", LLM_MODEL)
INPUT_MODE = _normalize_input_mode(os.getenv("CUSTOM_AGENT_INPUT_MODE", AGENT_PROFILE.input_mode))
INPUT_MODE = _normalize_input_mode(
os.getenv("CUSTOM_AGENT_INPUT_MODE", AGENT_PROFILE.input_mode)
)
if LLM_PROVIDER not in {
"openai",
"openai-compatible",
@ -782,6 +803,7 @@ async def entrypoint(ctx: JobContext) -> None:
MEMORY_TIMEOUT = _env_float("CUSTOM_MEMORY_TIMEOUT", 2.0)
MEMORY_MAX_CHARS = _env_int("CUSTOM_MEMORY_MAX_CHARS", 2000)
MEMORY_API_KEY = os.getenv("CUSTOM_MEMORY_API_KEY") or None
beaver_warmup_text: str | None = None
blackbox_stt = BlackboxSTT(
url=ASR_URL,
@ -792,14 +814,27 @@ async def entrypoint(ctx: JobContext) -> None:
itn=os.getenv("CUSTOM_ASR_ITN"),
chunk_mode=os.getenv("CUSTOM_ASR_CHUNK_MODE"),
)
stt_stream = stt.StreamAdapter(stt=blackbox_stt, vad=ctx.proc.userdata["vad"])
stt_stream = BoundedStreamAdapter(
stt=blackbox_stt,
vad=ctx.proc.userdata["vad"],
max_speech_duration=ASR_MAX_SPEECH_DURATION if ASR_MAX_SPEECH_DURATION > 0 else None,
)
turn_detection = "stt" if ASR_MAX_SPEECH_DURATION > 0 else MultilingualModel()
allow_interruptions = _env_bool(
"CUSTOM_ALLOW_INTERRUPTION_DURING_FORCED_ASR",
ASR_MAX_SPEECH_DURATION <= 0,
)
if LLM_PROVIDER == "beaver":
beaver_url = _first_env("CUSTOM_BEAVER_WS_URL", "BEAVER_WS_URL")
if not beaver_url:
raise RuntimeError(f"CUSTOM_BEAVER_WS_URL or BEAVER_WS_URL is not set in {CUSTOM_ENV_PATH}")
raise RuntimeError(
f"CUSTOM_BEAVER_WS_URL or BEAVER_WS_URL is not set in {CUSTOM_ENV_PATH}"
)
beaver_peer_id = _first_env("CUSTOM_BEAVER_PEER_ID", "BEAVER_PEER_ID") or f"livekit-{ctx.room.name}"
beaver_peer_id = (
_first_env("CUSTOM_BEAVER_PEER_ID", "BEAVER_PEER_ID") or f"livekit-{ctx.room.name}"
)
beaver_device_name = (
_first_env("CUSTOM_BEAVER_DEVICE_NAME", "BEAVER_DEVICE_NAME", "TERMINAL_DEVICE_NAME")
or "livekit-custom-agent"
@ -811,18 +846,15 @@ async def entrypoint(ctx: JobContext) -> None:
model_name=os.getenv("CUSTOM_BEAVER_MODEL", "beaver-terminal"),
)
beaver_warmup_text = os.getenv("CUSTOM_BEAVER_WARMUP_TEXT")
warmup_reply = await base_llm.connect(warmup_text=beaver_warmup_text)
text_llm = base_llm
vision_llm = base_llm
logger.info(
"Using Beaver gateway url=%s peer_id=%s device_name=%s room=%s session_id=%s warmup=%s warmup_reply_len=%s",
"Using Beaver gateway url=%s peer_id=%s device_name=%s room=%s warmup=%s",
beaver_url,
beaver_peer_id,
beaver_device_name,
ctx.room.name,
base_llm.session_id,
bool(beaver_warmup_text and beaver_warmup_text.strip()),
len(warmup_reply) if warmup_reply is not None else 0,
)
elif LLM_PROVIDER in {"hermes", "hermes_gateway", "openclaw"}:
gateway_url = os.getenv("CUSTOM_HERMES_GATEWAY_URL", "").strip()
@ -914,8 +946,9 @@ async def entrypoint(ctx: JobContext) -> None:
# 4. Silero VAD
vad=ctx.proc.userdata["vad"],
turn_handling=TurnHandlingOptions(
turn_detection=MultilingualModel(),
turn_detection=turn_detection,
interruption={
"enabled": allow_interruptions,
"resume_false_interruption": True,
"false_interruption_timeout": 1.0,
},
@ -946,17 +979,53 @@ async def entrypoint(ctx: JobContext) -> None:
@ctx.room.on("data_received")
def _on_data_received(data_packet) -> None:
packet_topic = getattr(data_packet, "topic", None)
if packet_topic not in {None, "", VISION_FRAME_TOPIC}:
def _interrupt_done(fut: asyncio.Future[None]) -> None:
try:
fut.result()
except Exception:
logger.exception("Bridge interrupt failed")
def _handle_interrupt(payload: dict[str, object] | None = None) -> None:
reason = None if payload is None else payload.get("reason")
logger.info(
"Received bridge interrupt: topic=%s reason=%s",
packet_topic or "<default>",
reason if isinstance(reason, str) and reason else "<unspecified>",
)
try:
interrupt_fut = session.interrupt(force=True)
except RuntimeError:
logger.exception("Bridge interrupt received before AgentSession was running")
return
interrupt_fut.add_done_callback(_interrupt_done)
if packet_topic == INTERRUPT_TOPIC:
payload: dict[str, object] | None = None
try:
decoded = json.loads(data_packet.data.decode("utf-8"))
if isinstance(decoded, dict):
payload = decoded
except Exception:
logger.exception("Failed to decode interrupt payload")
_handle_interrupt(payload)
return
if INPUT_MODE == VOICE_INPUT_MODE:
logger.info("Ignoring vision frame because CUSTOM_AGENT_INPUT_MODE=%s", INPUT_MODE)
if packet_topic not in {None, "", VISION_FRAME_TOPIC}:
return
try:
payload = json.loads(data_packet.data.decode("utf-8"))
except Exception:
logger.exception("Failed to decode vision frame payload")
logger.exception("Failed to decode data payload")
return
if payload.get("type") == "interrupt" or payload.get("topic") == INTERRUPT_TOPIC:
_handle_interrupt(payload)
return
if INPUT_MODE == VOICE_INPUT_MODE:
logger.info("Ignoring vision frame because CUSTOM_AGENT_INPUT_MODE=%s", INPUT_MODE)
return
if payload.get("type") != "vision_frame" and payload.get("topic") != VISION_FRAME_TOPIC:
@ -1013,6 +1082,53 @@ async def entrypoint(ctx: JobContext) -> None:
),
record=_recording_options_from_env(),
)
if LLM_PROVIDER == "beaver" and isinstance(base_llm, BeaverLLM):
_start_beaver_background_warmup(
ctx=ctx,
beaver_llm=base_llm,
warmup_text=beaver_warmup_text,
)
def _start_beaver_background_warmup(
*,
ctx: JobContext,
beaver_llm: BeaverLLM,
warmup_text: str | None,
) -> None:
async def _warmup() -> None:
try:
warmup_reply = await beaver_llm.connect(warmup_text=warmup_text)
except BeaverTerminalError:
logger.warning(
"Beaver background handshake failed; will retry on first user turn",
exc_info=True,
)
return
except Exception:
logger.exception("Unexpected Beaver background handshake failure")
return
logger.info(
"Beaver background handshake completed room=%s session_id=%s warmup=%s warmup_reply_len=%s",
ctx.room.name,
beaver_llm.session_id,
bool(warmup_text and warmup_text.strip()),
len(warmup_reply) if warmup_reply is not None else 0,
)
warmup_task = asyncio.create_task(_warmup(), name="beaver_background_warmup")
async def _cancel_warmup() -> None:
if warmup_task.done():
return
warmup_task.cancel()
try:
await warmup_task
except asyncio.CancelledError:
pass
ctx.add_shutdown_callback(_cancel_warmup)
def _tts_params_from_env(model_name: str) -> dict[str, str]: