fix: fix voice interupt
This commit is contained in:
162
custom_agent.py
162
custom_agent.py
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
@ -9,12 +10,13 @@ from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from beaver_llm import BeaverLLM
|
||||
from beaver_terminal_client import BeaverTerminalError
|
||||
from dotenv import load_dotenv
|
||||
from hermes_gateway import GatewaySessionState, HermesGatewayLLM
|
||||
from memory import MemoryRecallClient
|
||||
from tts import BlackboxTTS
|
||||
|
||||
from asr import BlackboxSTT
|
||||
from asr import BlackboxSTT, BoundedStreamAdapter
|
||||
from livekit.agents import (
|
||||
Agent,
|
||||
AgentServer,
|
||||
@ -32,7 +34,6 @@ from livekit.agents import (
|
||||
llm,
|
||||
metrics,
|
||||
room_io,
|
||||
stt,
|
||||
)
|
||||
from livekit.agents.voice.generation import update_instructions as update_chat_instructions
|
||||
from livekit.plugins import openai, silero
|
||||
@ -61,7 +62,7 @@ GENERAL_INSTRUCTIONS = """
|
||||
|
||||
EMOTION_INSTRUCTIONS = """
|
||||
每次回复必须先输出一个情绪标签,格式严格为:<emotion=neutral>
|
||||
emotion 只能从 neutral、happy、sad、angry、surprised、fearful、calm、concerned 中选择。
|
||||
emotion 只能从 angry、confident、confused、cool、crying、delicious、embarrassed、funny、happy、kissy、laughing、loving、neutral、relaxed、sad、shocked、silly、sleepy、surprised、thinking、winking 中选择。
|
||||
情绪标签之后直接输出给用户的正常回复,不要解释标签。
|
||||
""".strip()
|
||||
|
||||
@ -70,6 +71,7 @@ GENERAL_MODE = "general"
|
||||
VOICE_INPUT_MODE = "voice"
|
||||
VISION_VOICE_INPUT_MODE = "vision_voice"
|
||||
AUTO_INPUT_MODE = "auto"
|
||||
INTERRUPT_TOPIC = "lk.interrupt"
|
||||
VISION_FRAME_TOPIC = "vision.frame"
|
||||
DEFAULT_AGENT_PROFILE = "normal"
|
||||
|
||||
@ -163,14 +165,27 @@ AGENT_NAME = os.getenv("CUSTOM_AGENT_NAME") or AGENT_PROFILE.agent_name
|
||||
|
||||
DEFAULT_EMOTION = "neutral"
|
||||
EMOTION_LABELS = {
|
||||
"neutral",
|
||||
"happy",
|
||||
"sad",
|
||||
"angry",
|
||||
"confident",
|
||||
"confused",
|
||||
"cool",
|
||||
"crying",
|
||||
"delicious",
|
||||
"embarrassed",
|
||||
"funny",
|
||||
"happy",
|
||||
"kissy",
|
||||
"laughing",
|
||||
"loving",
|
||||
"neutral",
|
||||
"relaxed",
|
||||
"sad",
|
||||
"shocked",
|
||||
"silly",
|
||||
"sleepy",
|
||||
"surprised",
|
||||
"fearful",
|
||||
"calm",
|
||||
"concerned",
|
||||
"thinking",
|
||||
"winking",
|
||||
}
|
||||
EMOTION_PREFIX_RE = re.compile(r"^\s*<emotion=([a-z_]+)>\s*", re.IGNORECASE)
|
||||
TTS_EMOTION_MARKUP_RE = re.compile(r"<\s*emotion\s*=\s*[^>]{1,80}>\s*", re.IGNORECASE)
|
||||
@ -380,6 +395,7 @@ class CustomAgent(Agent):
|
||||
yield chunk
|
||||
|
||||
return _stream()
|
||||
|
||||
async def _observe_emotion_prefix(
|
||||
self, chunk: llm.ChatChunk | str | FlushSentinel
|
||||
) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
|
||||
@ -635,7 +651,9 @@ def _with_memory_as_latest_user_message(chat_ctx: ChatContext, memory_context: s
|
||||
return chat_ctx
|
||||
|
||||
|
||||
def _with_vision_as_latest_user_message(chat_ctx: ChatContext, vision_frame: VisionFrame) -> ChatContext:
|
||||
def _with_vision_as_latest_user_message(
|
||||
chat_ctx: ChatContext, vision_frame: VisionFrame
|
||||
) -> ChatContext:
|
||||
chat_ctx = chat_ctx.copy()
|
||||
image_content = llm.ImageContent(
|
||||
image=vision_frame.image_data_url,
|
||||
@ -742,6 +760,7 @@ async def entrypoint(ctx: JobContext) -> None:
|
||||
ASR_MODEL = os.getenv("CUSTOM_ASR_MODEL", "sensevoice")
|
||||
ASR_LANGUAGE = os.getenv("CUSTOM_ASR_LANGUAGE", "auto")
|
||||
ASR_OUTPUT_LANGUAGE = os.getenv("CUSTOM_ASR_OUTPUT_LANGUAGE", "zh")
|
||||
ASR_MAX_SPEECH_DURATION = _env_float("CUSTOM_ASR_MAX_SPEECH_DURATION", 12.0)
|
||||
|
||||
LLM_BASE_URL = os.getenv("CUSTOM_LLM_BASE_URL")
|
||||
LLM_MODEL = os.getenv("CUSTOM_LLM_MODEL", "qwen-max")
|
||||
@ -749,7 +768,9 @@ async def entrypoint(ctx: JobContext) -> None:
|
||||
LLM_PROVIDER = os.getenv("CUSTOM_LLM_PROVIDER", AGENT_PROFILE.llm_provider).strip().lower()
|
||||
TEXT_LLM_MODEL = os.getenv("CUSTOM_TEXT_LLM_MODEL", LLM_MODEL)
|
||||
VISION_LLM_MODEL = os.getenv("CUSTOM_VISION_LLM_MODEL", LLM_MODEL)
|
||||
INPUT_MODE = _normalize_input_mode(os.getenv("CUSTOM_AGENT_INPUT_MODE", AGENT_PROFILE.input_mode))
|
||||
INPUT_MODE = _normalize_input_mode(
|
||||
os.getenv("CUSTOM_AGENT_INPUT_MODE", AGENT_PROFILE.input_mode)
|
||||
)
|
||||
if LLM_PROVIDER not in {
|
||||
"openai",
|
||||
"openai-compatible",
|
||||
@ -782,6 +803,7 @@ async def entrypoint(ctx: JobContext) -> None:
|
||||
MEMORY_TIMEOUT = _env_float("CUSTOM_MEMORY_TIMEOUT", 2.0)
|
||||
MEMORY_MAX_CHARS = _env_int("CUSTOM_MEMORY_MAX_CHARS", 2000)
|
||||
MEMORY_API_KEY = os.getenv("CUSTOM_MEMORY_API_KEY") or None
|
||||
beaver_warmup_text: str | None = None
|
||||
|
||||
blackbox_stt = BlackboxSTT(
|
||||
url=ASR_URL,
|
||||
@ -792,14 +814,27 @@ async def entrypoint(ctx: JobContext) -> None:
|
||||
itn=os.getenv("CUSTOM_ASR_ITN"),
|
||||
chunk_mode=os.getenv("CUSTOM_ASR_CHUNK_MODE"),
|
||||
)
|
||||
stt_stream = stt.StreamAdapter(stt=blackbox_stt, vad=ctx.proc.userdata["vad"])
|
||||
stt_stream = BoundedStreamAdapter(
|
||||
stt=blackbox_stt,
|
||||
vad=ctx.proc.userdata["vad"],
|
||||
max_speech_duration=ASR_MAX_SPEECH_DURATION if ASR_MAX_SPEECH_DURATION > 0 else None,
|
||||
)
|
||||
turn_detection = "stt" if ASR_MAX_SPEECH_DURATION > 0 else MultilingualModel()
|
||||
allow_interruptions = _env_bool(
|
||||
"CUSTOM_ALLOW_INTERRUPTION_DURING_FORCED_ASR",
|
||||
ASR_MAX_SPEECH_DURATION <= 0,
|
||||
)
|
||||
|
||||
if LLM_PROVIDER == "beaver":
|
||||
beaver_url = _first_env("CUSTOM_BEAVER_WS_URL", "BEAVER_WS_URL")
|
||||
if not beaver_url:
|
||||
raise RuntimeError(f"CUSTOM_BEAVER_WS_URL or BEAVER_WS_URL is not set in {CUSTOM_ENV_PATH}")
|
||||
raise RuntimeError(
|
||||
f"CUSTOM_BEAVER_WS_URL or BEAVER_WS_URL is not set in {CUSTOM_ENV_PATH}"
|
||||
)
|
||||
|
||||
beaver_peer_id = _first_env("CUSTOM_BEAVER_PEER_ID", "BEAVER_PEER_ID") or f"livekit-{ctx.room.name}"
|
||||
beaver_peer_id = (
|
||||
_first_env("CUSTOM_BEAVER_PEER_ID", "BEAVER_PEER_ID") or f"livekit-{ctx.room.name}"
|
||||
)
|
||||
beaver_device_name = (
|
||||
_first_env("CUSTOM_BEAVER_DEVICE_NAME", "BEAVER_DEVICE_NAME", "TERMINAL_DEVICE_NAME")
|
||||
or "livekit-custom-agent"
|
||||
@ -811,18 +846,15 @@ async def entrypoint(ctx: JobContext) -> None:
|
||||
model_name=os.getenv("CUSTOM_BEAVER_MODEL", "beaver-terminal"),
|
||||
)
|
||||
beaver_warmup_text = os.getenv("CUSTOM_BEAVER_WARMUP_TEXT")
|
||||
warmup_reply = await base_llm.connect(warmup_text=beaver_warmup_text)
|
||||
text_llm = base_llm
|
||||
vision_llm = base_llm
|
||||
logger.info(
|
||||
"Using Beaver gateway url=%s peer_id=%s device_name=%s room=%s session_id=%s warmup=%s warmup_reply_len=%s",
|
||||
"Using Beaver gateway url=%s peer_id=%s device_name=%s room=%s warmup=%s",
|
||||
beaver_url,
|
||||
beaver_peer_id,
|
||||
beaver_device_name,
|
||||
ctx.room.name,
|
||||
base_llm.session_id,
|
||||
bool(beaver_warmup_text and beaver_warmup_text.strip()),
|
||||
len(warmup_reply) if warmup_reply is not None else 0,
|
||||
)
|
||||
elif LLM_PROVIDER in {"hermes", "hermes_gateway", "openclaw"}:
|
||||
gateway_url = os.getenv("CUSTOM_HERMES_GATEWAY_URL", "").strip()
|
||||
@ -914,8 +946,9 @@ async def entrypoint(ctx: JobContext) -> None:
|
||||
# 4. Silero VAD
|
||||
vad=ctx.proc.userdata["vad"],
|
||||
turn_handling=TurnHandlingOptions(
|
||||
turn_detection=MultilingualModel(),
|
||||
turn_detection=turn_detection,
|
||||
interruption={
|
||||
"enabled": allow_interruptions,
|
||||
"resume_false_interruption": True,
|
||||
"false_interruption_timeout": 1.0,
|
||||
},
|
||||
@ -946,17 +979,53 @@ async def entrypoint(ctx: JobContext) -> None:
|
||||
@ctx.room.on("data_received")
|
||||
def _on_data_received(data_packet) -> None:
|
||||
packet_topic = getattr(data_packet, "topic", None)
|
||||
if packet_topic not in {None, "", VISION_FRAME_TOPIC}:
|
||||
|
||||
def _interrupt_done(fut: asyncio.Future[None]) -> None:
|
||||
try:
|
||||
fut.result()
|
||||
except Exception:
|
||||
logger.exception("Bridge interrupt failed")
|
||||
|
||||
def _handle_interrupt(payload: dict[str, object] | None = None) -> None:
|
||||
reason = None if payload is None else payload.get("reason")
|
||||
logger.info(
|
||||
"Received bridge interrupt: topic=%s reason=%s",
|
||||
packet_topic or "<default>",
|
||||
reason if isinstance(reason, str) and reason else "<unspecified>",
|
||||
)
|
||||
try:
|
||||
interrupt_fut = session.interrupt(force=True)
|
||||
except RuntimeError:
|
||||
logger.exception("Bridge interrupt received before AgentSession was running")
|
||||
return
|
||||
interrupt_fut.add_done_callback(_interrupt_done)
|
||||
|
||||
if packet_topic == INTERRUPT_TOPIC:
|
||||
payload: dict[str, object] | None = None
|
||||
try:
|
||||
decoded = json.loads(data_packet.data.decode("utf-8"))
|
||||
if isinstance(decoded, dict):
|
||||
payload = decoded
|
||||
except Exception:
|
||||
logger.exception("Failed to decode interrupt payload")
|
||||
_handle_interrupt(payload)
|
||||
return
|
||||
|
||||
if INPUT_MODE == VOICE_INPUT_MODE:
|
||||
logger.info("Ignoring vision frame because CUSTOM_AGENT_INPUT_MODE=%s", INPUT_MODE)
|
||||
if packet_topic not in {None, "", VISION_FRAME_TOPIC}:
|
||||
return
|
||||
|
||||
try:
|
||||
payload = json.loads(data_packet.data.decode("utf-8"))
|
||||
except Exception:
|
||||
logger.exception("Failed to decode vision frame payload")
|
||||
logger.exception("Failed to decode data payload")
|
||||
return
|
||||
|
||||
if payload.get("type") == "interrupt" or payload.get("topic") == INTERRUPT_TOPIC:
|
||||
_handle_interrupt(payload)
|
||||
return
|
||||
|
||||
if INPUT_MODE == VOICE_INPUT_MODE:
|
||||
logger.info("Ignoring vision frame because CUSTOM_AGENT_INPUT_MODE=%s", INPUT_MODE)
|
||||
return
|
||||
|
||||
if payload.get("type") != "vision_frame" and payload.get("topic") != VISION_FRAME_TOPIC:
|
||||
@ -1013,6 +1082,53 @@ async def entrypoint(ctx: JobContext) -> None:
|
||||
),
|
||||
record=_recording_options_from_env(),
|
||||
)
|
||||
if LLM_PROVIDER == "beaver" and isinstance(base_llm, BeaverLLM):
|
||||
_start_beaver_background_warmup(
|
||||
ctx=ctx,
|
||||
beaver_llm=base_llm,
|
||||
warmup_text=beaver_warmup_text,
|
||||
)
|
||||
|
||||
|
||||
def _start_beaver_background_warmup(
|
||||
*,
|
||||
ctx: JobContext,
|
||||
beaver_llm: BeaverLLM,
|
||||
warmup_text: str | None,
|
||||
) -> None:
|
||||
async def _warmup() -> None:
|
||||
try:
|
||||
warmup_reply = await beaver_llm.connect(warmup_text=warmup_text)
|
||||
except BeaverTerminalError:
|
||||
logger.warning(
|
||||
"Beaver background handshake failed; will retry on first user turn",
|
||||
exc_info=True,
|
||||
)
|
||||
return
|
||||
except Exception:
|
||||
logger.exception("Unexpected Beaver background handshake failure")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"Beaver background handshake completed room=%s session_id=%s warmup=%s warmup_reply_len=%s",
|
||||
ctx.room.name,
|
||||
beaver_llm.session_id,
|
||||
bool(warmup_text and warmup_text.strip()),
|
||||
len(warmup_reply) if warmup_reply is not None else 0,
|
||||
)
|
||||
|
||||
warmup_task = asyncio.create_task(_warmup(), name="beaver_background_warmup")
|
||||
|
||||
async def _cancel_warmup() -> None:
|
||||
if warmup_task.done():
|
||||
return
|
||||
warmup_task.cancel()
|
||||
try:
|
||||
await warmup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
ctx.add_shutdown_callback(_cancel_warmup)
|
||||
|
||||
|
||||
def _tts_params_from_env(model_name: str) -> dict[str, str]:
|
||||
|
||||
Reference in New Issue
Block a user