fix: fix voice interupt
This commit is contained in:
@ -20,6 +20,10 @@ CUSTOM_ASR_OUTPUT_LANGUAGE=zh
|
||||
CUSTOM_ASR_HOTWORDS=
|
||||
CUSTOM_ASR_ITN=
|
||||
CUSTOM_ASR_CHUNK_MODE=
|
||||
# Force a user turn if VAD/ASR never reaches end-of-speech. Set 0 to disable.
|
||||
CUSTOM_ASR_MAX_SPEECH_DURATION=12
|
||||
# Keep false if forced ASR turns should reply even while input audio continues.
|
||||
CUSTOM_ALLOW_INTERRUPTION_DURING_FORCED_ASR=false
|
||||
|
||||
# LLM backend: openai/openai-compatible, hermes_gateway/openclaw, or beaver.
|
||||
# Defaults come from CUSTOM_AGENT_PROFILE. Uncomment to override.
|
||||
|
||||
260
asr.py
260
asr.py
@ -1,11 +1,14 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Optional, Union
|
||||
from collections import deque
|
||||
from collections.abc import AsyncIterable
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
from livekit import rtc
|
||||
from livekit.agents import (
|
||||
DEFAULT_API_CONNECT_OPTIONS,
|
||||
NOT_GIVEN,
|
||||
APIConnectionError,
|
||||
APIConnectOptions,
|
||||
@ -17,9 +20,15 @@ from livekit.agents import (
|
||||
utils,
|
||||
)
|
||||
from livekit.agents.utils import is_given
|
||||
from livekit.agents.vad import VAD, VADEventType
|
||||
|
||||
logger = logging.getLogger("blackbox-asr")
|
||||
|
||||
DEFAULT_STREAM_ADAPTER_API_CONNECT_OPTIONS = APIConnectOptions(
|
||||
max_retry=0, timeout=DEFAULT_API_CONNECT_OPTIONS.timeout
|
||||
)
|
||||
STT_CAPABILITIES = stt.STTCapabilities
|
||||
|
||||
|
||||
class BlackboxSTT(stt.STT):
|
||||
def __init__(
|
||||
@ -27,13 +36,13 @@ class BlackboxSTT(stt.STT):
|
||||
url: str,
|
||||
*,
|
||||
model_name: str = "sensevoice",
|
||||
language: Optional[str] = "auto",
|
||||
language: str | None = "auto",
|
||||
output_language: str = "zh",
|
||||
hotwords: Optional[str] = None,
|
||||
itn: Optional[Union[bool, str]] = None,
|
||||
chunk_mode: Optional[Union[bool, str]] = None,
|
||||
hotwords: str | None = None,
|
||||
itn: bool | str | None = None,
|
||||
chunk_mode: bool | str | None = None,
|
||||
timeout: float = 30.0,
|
||||
http_session: Optional[aiohttp.ClientSession] = None,
|
||||
http_session: aiohttp.ClientSession | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
capabilities=stt.STTCapabilities(
|
||||
@ -148,7 +157,244 @@ def _extract_asr_text(payload: dict[str, Any]) -> str:
|
||||
raise APIConnectionError(f"Unsupported ASR blackbox response: {payload}")
|
||||
|
||||
|
||||
def _form_value(value: Union[bool, str]) -> str:
|
||||
def _form_value(value: bool | str) -> str:
|
||||
if isinstance(value, bool):
|
||||
return str(value).lower()
|
||||
return value
|
||||
|
||||
|
||||
class BoundedStreamAdapter(stt.STT):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
stt: stt.STT,
|
||||
vad: VAD,
|
||||
max_speech_duration: float | None = 12.0,
|
||||
pre_speech_duration: float = 0.5,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
capabilities=STT_CAPABILITIES(
|
||||
streaming=True,
|
||||
interim_results=False,
|
||||
diarization=False,
|
||||
)
|
||||
)
|
||||
self._vad = vad
|
||||
self._stt = stt
|
||||
self._max_speech_duration = max_speech_duration
|
||||
self._pre_speech_duration = pre_speech_duration
|
||||
self._stt.on("metrics_collected", self._on_metrics_collected)
|
||||
|
||||
@property
|
||||
def wrapped_stt(self) -> stt.STT:
|
||||
return self._stt
|
||||
|
||||
@property
|
||||
def model(self) -> str:
|
||||
return self._stt.model
|
||||
|
||||
@property
|
||||
def provider(self) -> str:
|
||||
return self._stt.provider
|
||||
|
||||
async def _recognize_impl(
|
||||
self,
|
||||
buffer: utils.AudioBuffer,
|
||||
*,
|
||||
language: NotGivenOr[str] = NOT_GIVEN,
|
||||
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
|
||||
) -> stt.SpeechEvent:
|
||||
return await self._stt.recognize(
|
||||
buffer=buffer, language=language, conn_options=conn_options
|
||||
)
|
||||
|
||||
def stream(
|
||||
self,
|
||||
*,
|
||||
language: NotGivenOr[str] = NOT_GIVEN,
|
||||
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
|
||||
) -> stt.RecognizeStream:
|
||||
return _BoundedStreamAdapterWrapper(
|
||||
self,
|
||||
vad=self._vad,
|
||||
wrapped_stt=self._stt,
|
||||
language=language,
|
||||
conn_options=conn_options,
|
||||
max_speech_duration=self._max_speech_duration,
|
||||
pre_speech_duration=self._pre_speech_duration,
|
||||
)
|
||||
|
||||
def _on_metrics_collected(self, *args: Any, **kwargs: Any) -> None:
|
||||
self.emit("metrics_collected", *args, **kwargs)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
self._stt.off("metrics_collected", self._on_metrics_collected)
|
||||
|
||||
|
||||
class _BoundedStreamAdapterWrapper(stt.RecognizeStream):
|
||||
def __init__(
|
||||
self,
|
||||
adapter: BoundedStreamAdapter,
|
||||
*,
|
||||
vad: VAD,
|
||||
wrapped_stt: stt.STT,
|
||||
language: NotGivenOr[str],
|
||||
conn_options: APIConnectOptions,
|
||||
max_speech_duration: float | None,
|
||||
pre_speech_duration: float,
|
||||
) -> None:
|
||||
super().__init__(stt=adapter, conn_options=DEFAULT_STREAM_ADAPTER_API_CONNECT_OPTIONS)
|
||||
self._vad = vad
|
||||
self._wrapped_stt = wrapped_stt
|
||||
self._wrapped_stt_conn_options = conn_options
|
||||
self._language = language
|
||||
self._max_speech_duration = max_speech_duration
|
||||
self._pre_speech_duration = pre_speech_duration
|
||||
|
||||
async def _metrics_monitor_task(self, event_aiter: AsyncIterable[stt.SpeechEvent]) -> None:
|
||||
async for _ in event_aiter:
|
||||
pass
|
||||
|
||||
async def _run(self) -> None:
|
||||
vad_stream = self._vad.stream()
|
||||
lock = asyncio.Lock()
|
||||
recognize_queue: asyncio.Queue[list[rtc.AudioFrame] | None] = asyncio.Queue()
|
||||
|
||||
speech_active = False
|
||||
segment_frames: list[rtc.AudioFrame] = []
|
||||
segment_duration = 0.0
|
||||
pre_roll_frames: deque[rtc.AudioFrame] = deque()
|
||||
pre_roll_duration = 0.0
|
||||
|
||||
def _frame_duration(frame: rtc.AudioFrame) -> float:
|
||||
return frame.samples_per_channel / frame.sample_rate
|
||||
|
||||
def _append_pre_roll(frame: rtc.AudioFrame) -> None:
|
||||
nonlocal pre_roll_duration
|
||||
pre_roll_frames.append(frame)
|
||||
pre_roll_duration += _frame_duration(frame)
|
||||
|
||||
while pre_roll_duration > self._pre_speech_duration and pre_roll_frames:
|
||||
pre_roll_duration -= _frame_duration(pre_roll_frames.popleft())
|
||||
|
||||
async def _enqueue_segment(frames: list[rtc.AudioFrame], *, forced: bool = False) -> None:
|
||||
if not frames:
|
||||
return
|
||||
|
||||
if forced:
|
||||
logger.info(
|
||||
"Forcing ASR segment after %.2fs of continuous speech",
|
||||
self._max_speech_duration,
|
||||
)
|
||||
|
||||
self._event_ch.send_nowait(stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH))
|
||||
await recognize_queue.put(frames)
|
||||
|
||||
async def _recognize_worker() -> None:
|
||||
while True:
|
||||
frames = await recognize_queue.get()
|
||||
if frames is None:
|
||||
return
|
||||
|
||||
merged_frames = utils.merge_frames(frames)
|
||||
try:
|
||||
t_event = await self._wrapped_stt.recognize(
|
||||
buffer=merged_frames,
|
||||
language=self._language,
|
||||
conn_options=self._wrapped_stt_conn_options,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("ASR segment recognition failed")
|
||||
continue
|
||||
|
||||
if not t_event.alternatives or not t_event.alternatives[0].text:
|
||||
continue
|
||||
|
||||
self._event_ch.send_nowait(
|
||||
stt.SpeechEvent(
|
||||
type=stt.SpeechEventType.FINAL_TRANSCRIPT,
|
||||
alternatives=[t_event.alternatives[0]],
|
||||
)
|
||||
)
|
||||
|
||||
async def _forward_input() -> None:
|
||||
nonlocal segment_duration, segment_frames, speech_active
|
||||
|
||||
async for input_frame in self._input_ch:
|
||||
if isinstance(input_frame, self._FlushSentinel):
|
||||
vad_stream.flush()
|
||||
continue
|
||||
|
||||
vad_stream.push_frame(input_frame)
|
||||
forced_frames: list[rtc.AudioFrame] = []
|
||||
|
||||
async with lock:
|
||||
if speech_active:
|
||||
segment_frames.append(input_frame)
|
||||
segment_duration += _frame_duration(input_frame)
|
||||
if (
|
||||
self._max_speech_duration is not None
|
||||
and segment_duration >= self._max_speech_duration
|
||||
):
|
||||
forced_frames = segment_frames
|
||||
segment_frames = []
|
||||
segment_duration = 0.0
|
||||
else:
|
||||
_append_pre_roll(input_frame)
|
||||
|
||||
if forced_frames:
|
||||
await _enqueue_segment(forced_frames, forced=True)
|
||||
|
||||
vad_stream.end_input()
|
||||
|
||||
final_frames: list[rtc.AudioFrame] = []
|
||||
async with lock:
|
||||
if speech_active and segment_frames:
|
||||
final_frames = segment_frames
|
||||
segment_frames = []
|
||||
segment_duration = 0.0
|
||||
speech_active = False
|
||||
|
||||
if final_frames:
|
||||
await _enqueue_segment(final_frames)
|
||||
|
||||
async def _recognize_from_vad() -> None:
|
||||
nonlocal pre_roll_duration, segment_duration, segment_frames, speech_active
|
||||
|
||||
async for event in vad_stream:
|
||||
if event.type == VADEventType.START_OF_SPEECH:
|
||||
self._event_ch.send_nowait(
|
||||
stt.SpeechEvent(stt.SpeechEventType.START_OF_SPEECH)
|
||||
)
|
||||
async with lock:
|
||||
if not speech_active:
|
||||
speech_active = True
|
||||
segment_frames = list(pre_roll_frames)
|
||||
segment_duration = sum(_frame_duration(f) for f in segment_frames)
|
||||
pre_roll_frames.clear()
|
||||
pre_roll_duration = 0.0
|
||||
continue
|
||||
|
||||
if event.type != VADEventType.END_OF_SPEECH:
|
||||
continue
|
||||
|
||||
async with lock:
|
||||
frames = segment_frames
|
||||
segment_frames = []
|
||||
segment_duration = 0.0
|
||||
speech_active = False
|
||||
|
||||
await _enqueue_segment(frames)
|
||||
|
||||
worker_task = asyncio.create_task(_recognize_worker(), name="bounded_asr_recognize")
|
||||
tasks = [
|
||||
asyncio.create_task(_forward_input(), name="bounded_asr_forward_input"),
|
||||
asyncio.create_task(_recognize_from_vad(), name="bounded_asr_vad"),
|
||||
]
|
||||
try:
|
||||
await asyncio.gather(*tasks)
|
||||
await recognize_queue.put(None)
|
||||
await worker_task
|
||||
finally:
|
||||
await utils.aio.cancel_and_wait(*tasks, worker_task)
|
||||
await vad_stream.aclose()
|
||||
|
||||
@ -109,6 +109,7 @@ class BeaverLLMStream(llm.LLMStream):
|
||||
) -> None:
|
||||
super().__init__(beaver_llm, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options)
|
||||
self._beaver_llm = beaver_llm
|
||||
self._request_id = shortuuid("beaver_")
|
||||
|
||||
async def _run(self) -> None:
|
||||
user_text = latest_user_text(self.chat_ctx)
|
||||
@ -116,9 +117,12 @@ class BeaverLLMStream(llm.LLMStream):
|
||||
reply = await self._beaver_llm._client.send_text(user_text)
|
||||
|
||||
if reply:
|
||||
self._event_ch.send_nowait(
|
||||
llm.ChatChunk(
|
||||
id=shortuuid("beaver_"),
|
||||
delta=llm.ChoiceDelta(role="assistant", content=reply),
|
||||
)
|
||||
self._send_text_chunk(reply)
|
||||
|
||||
def _send_text_chunk(self, text: str) -> None:
|
||||
self._event_ch.send_nowait(
|
||||
llm.ChatChunk(
|
||||
id=self._request_id,
|
||||
delta=llm.ChoiceDelta(role="assistant", content=text),
|
||||
)
|
||||
)
|
||||
|
||||
@ -28,6 +28,7 @@ class BeaverTerminalConnectionClosed(BeaverTerminalError):
|
||||
@dataclass
|
||||
class MessageIdGenerator:
|
||||
peer_id: str
|
||||
nonce: str | None = None
|
||||
initial_counter: int = 0
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
@ -35,6 +36,8 @@ class MessageIdGenerator:
|
||||
|
||||
def next_id(self) -> str:
|
||||
self.counter += 1
|
||||
if self.nonce:
|
||||
return f"{self.peer_id}-{self.nonce}-{self.counter:06d}"
|
||||
return f"{self.peer_id}-{self.counter:06d}"
|
||||
|
||||
|
||||
@ -77,15 +80,22 @@ class BeaverTerminalClient:
|
||||
async def connect(self) -> None:
|
||||
await self._close_websocket()
|
||||
session = self._ensure_http_session()
|
||||
self._ws = await session.ws_connect(self._url)
|
||||
await self._send_json(
|
||||
build_connect_frame(peer_id=self._peer_id, device_name=self._device_name)
|
||||
)
|
||||
frame = await self._receive_json()
|
||||
if frame.get("type") != "connected":
|
||||
raise BeaverTerminalError(f"expected connected frame, received {frame!r}")
|
||||
session_id = frame.get("session_id")
|
||||
self.session_id = session_id if isinstance(session_id, str) else None
|
||||
try:
|
||||
self._ws = await session.ws_connect(self._url)
|
||||
await self._send_json(
|
||||
build_connect_frame(peer_id=self._peer_id, device_name=self._device_name)
|
||||
)
|
||||
frame = await self._receive_json()
|
||||
if frame.get("type") != "connected":
|
||||
raise BeaverTerminalError(f"expected connected frame, received {frame!r}")
|
||||
session_id = frame.get("session_id")
|
||||
self.session_id = session_id if isinstance(session_id, str) else None
|
||||
except (aiohttp.ClientError, asyncio.TimeoutError, BeaverTerminalConnectionClosed) as exc:
|
||||
await self._cleanup_failed_connection()
|
||||
raise BeaverTerminalConnectionClosed("failed to connect to Beaver websocket") from exc
|
||||
except Exception:
|
||||
await self._cleanup_failed_connection()
|
||||
raise
|
||||
|
||||
async def send_text(self, text: str) -> str:
|
||||
for attempt in range(2):
|
||||
@ -153,6 +163,13 @@ class BeaverTerminalClient:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
|
||||
async def _cleanup_failed_connection(self) -> None:
|
||||
await self._close_websocket()
|
||||
self.session_id = None
|
||||
if self._owned_session and self._http_session is not None:
|
||||
await self._http_session.close()
|
||||
self._http_session = None
|
||||
|
||||
def _websocket_is_open(self) -> bool:
|
||||
return self._ws is not None and not self._ws.closed
|
||||
|
||||
|
||||
162
custom_agent.py
162
custom_agent.py
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
@ -9,12 +10,13 @@ from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from beaver_llm import BeaverLLM
|
||||
from beaver_terminal_client import BeaverTerminalError
|
||||
from dotenv import load_dotenv
|
||||
from hermes_gateway import GatewaySessionState, HermesGatewayLLM
|
||||
from memory import MemoryRecallClient
|
||||
from tts import BlackboxTTS
|
||||
|
||||
from asr import BlackboxSTT
|
||||
from asr import BlackboxSTT, BoundedStreamAdapter
|
||||
from livekit.agents import (
|
||||
Agent,
|
||||
AgentServer,
|
||||
@ -32,7 +34,6 @@ from livekit.agents import (
|
||||
llm,
|
||||
metrics,
|
||||
room_io,
|
||||
stt,
|
||||
)
|
||||
from livekit.agents.voice.generation import update_instructions as update_chat_instructions
|
||||
from livekit.plugins import openai, silero
|
||||
@ -61,7 +62,7 @@ GENERAL_INSTRUCTIONS = """
|
||||
|
||||
EMOTION_INSTRUCTIONS = """
|
||||
每次回复必须先输出一个情绪标签,格式严格为:<emotion=neutral>
|
||||
emotion 只能从 neutral、happy、sad、angry、surprised、fearful、calm、concerned 中选择。
|
||||
emotion 只能从 angry、confident、confused、cool、crying、delicious、embarrassed、funny、happy、kissy、laughing、loving、neutral、relaxed、sad、shocked、silly、sleepy、surprised、thinking、winking 中选择。
|
||||
情绪标签之后直接输出给用户的正常回复,不要解释标签。
|
||||
""".strip()
|
||||
|
||||
@ -70,6 +71,7 @@ GENERAL_MODE = "general"
|
||||
VOICE_INPUT_MODE = "voice"
|
||||
VISION_VOICE_INPUT_MODE = "vision_voice"
|
||||
AUTO_INPUT_MODE = "auto"
|
||||
INTERRUPT_TOPIC = "lk.interrupt"
|
||||
VISION_FRAME_TOPIC = "vision.frame"
|
||||
DEFAULT_AGENT_PROFILE = "normal"
|
||||
|
||||
@ -163,14 +165,27 @@ AGENT_NAME = os.getenv("CUSTOM_AGENT_NAME") or AGENT_PROFILE.agent_name
|
||||
|
||||
DEFAULT_EMOTION = "neutral"
|
||||
EMOTION_LABELS = {
|
||||
"neutral",
|
||||
"happy",
|
||||
"sad",
|
||||
"angry",
|
||||
"confident",
|
||||
"confused",
|
||||
"cool",
|
||||
"crying",
|
||||
"delicious",
|
||||
"embarrassed",
|
||||
"funny",
|
||||
"happy",
|
||||
"kissy",
|
||||
"laughing",
|
||||
"loving",
|
||||
"neutral",
|
||||
"relaxed",
|
||||
"sad",
|
||||
"shocked",
|
||||
"silly",
|
||||
"sleepy",
|
||||
"surprised",
|
||||
"fearful",
|
||||
"calm",
|
||||
"concerned",
|
||||
"thinking",
|
||||
"winking",
|
||||
}
|
||||
EMOTION_PREFIX_RE = re.compile(r"^\s*<emotion=([a-z_]+)>\s*", re.IGNORECASE)
|
||||
TTS_EMOTION_MARKUP_RE = re.compile(r"<\s*emotion\s*=\s*[^>]{1,80}>\s*", re.IGNORECASE)
|
||||
@ -380,6 +395,7 @@ class CustomAgent(Agent):
|
||||
yield chunk
|
||||
|
||||
return _stream()
|
||||
|
||||
async def _observe_emotion_prefix(
|
||||
self, chunk: llm.ChatChunk | str | FlushSentinel
|
||||
) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
|
||||
@ -635,7 +651,9 @@ def _with_memory_as_latest_user_message(chat_ctx: ChatContext, memory_context: s
|
||||
return chat_ctx
|
||||
|
||||
|
||||
def _with_vision_as_latest_user_message(chat_ctx: ChatContext, vision_frame: VisionFrame) -> ChatContext:
|
||||
def _with_vision_as_latest_user_message(
|
||||
chat_ctx: ChatContext, vision_frame: VisionFrame
|
||||
) -> ChatContext:
|
||||
chat_ctx = chat_ctx.copy()
|
||||
image_content = llm.ImageContent(
|
||||
image=vision_frame.image_data_url,
|
||||
@ -742,6 +760,7 @@ async def entrypoint(ctx: JobContext) -> None:
|
||||
ASR_MODEL = os.getenv("CUSTOM_ASR_MODEL", "sensevoice")
|
||||
ASR_LANGUAGE = os.getenv("CUSTOM_ASR_LANGUAGE", "auto")
|
||||
ASR_OUTPUT_LANGUAGE = os.getenv("CUSTOM_ASR_OUTPUT_LANGUAGE", "zh")
|
||||
ASR_MAX_SPEECH_DURATION = _env_float("CUSTOM_ASR_MAX_SPEECH_DURATION", 12.0)
|
||||
|
||||
LLM_BASE_URL = os.getenv("CUSTOM_LLM_BASE_URL")
|
||||
LLM_MODEL = os.getenv("CUSTOM_LLM_MODEL", "qwen-max")
|
||||
@ -749,7 +768,9 @@ async def entrypoint(ctx: JobContext) -> None:
|
||||
LLM_PROVIDER = os.getenv("CUSTOM_LLM_PROVIDER", AGENT_PROFILE.llm_provider).strip().lower()
|
||||
TEXT_LLM_MODEL = os.getenv("CUSTOM_TEXT_LLM_MODEL", LLM_MODEL)
|
||||
VISION_LLM_MODEL = os.getenv("CUSTOM_VISION_LLM_MODEL", LLM_MODEL)
|
||||
INPUT_MODE = _normalize_input_mode(os.getenv("CUSTOM_AGENT_INPUT_MODE", AGENT_PROFILE.input_mode))
|
||||
INPUT_MODE = _normalize_input_mode(
|
||||
os.getenv("CUSTOM_AGENT_INPUT_MODE", AGENT_PROFILE.input_mode)
|
||||
)
|
||||
if LLM_PROVIDER not in {
|
||||
"openai",
|
||||
"openai-compatible",
|
||||
@ -782,6 +803,7 @@ async def entrypoint(ctx: JobContext) -> None:
|
||||
MEMORY_TIMEOUT = _env_float("CUSTOM_MEMORY_TIMEOUT", 2.0)
|
||||
MEMORY_MAX_CHARS = _env_int("CUSTOM_MEMORY_MAX_CHARS", 2000)
|
||||
MEMORY_API_KEY = os.getenv("CUSTOM_MEMORY_API_KEY") or None
|
||||
beaver_warmup_text: str | None = None
|
||||
|
||||
blackbox_stt = BlackboxSTT(
|
||||
url=ASR_URL,
|
||||
@ -792,14 +814,27 @@ async def entrypoint(ctx: JobContext) -> None:
|
||||
itn=os.getenv("CUSTOM_ASR_ITN"),
|
||||
chunk_mode=os.getenv("CUSTOM_ASR_CHUNK_MODE"),
|
||||
)
|
||||
stt_stream = stt.StreamAdapter(stt=blackbox_stt, vad=ctx.proc.userdata["vad"])
|
||||
stt_stream = BoundedStreamAdapter(
|
||||
stt=blackbox_stt,
|
||||
vad=ctx.proc.userdata["vad"],
|
||||
max_speech_duration=ASR_MAX_SPEECH_DURATION if ASR_MAX_SPEECH_DURATION > 0 else None,
|
||||
)
|
||||
turn_detection = "stt" if ASR_MAX_SPEECH_DURATION > 0 else MultilingualModel()
|
||||
allow_interruptions = _env_bool(
|
||||
"CUSTOM_ALLOW_INTERRUPTION_DURING_FORCED_ASR",
|
||||
ASR_MAX_SPEECH_DURATION <= 0,
|
||||
)
|
||||
|
||||
if LLM_PROVIDER == "beaver":
|
||||
beaver_url = _first_env("CUSTOM_BEAVER_WS_URL", "BEAVER_WS_URL")
|
||||
if not beaver_url:
|
||||
raise RuntimeError(f"CUSTOM_BEAVER_WS_URL or BEAVER_WS_URL is not set in {CUSTOM_ENV_PATH}")
|
||||
raise RuntimeError(
|
||||
f"CUSTOM_BEAVER_WS_URL or BEAVER_WS_URL is not set in {CUSTOM_ENV_PATH}"
|
||||
)
|
||||
|
||||
beaver_peer_id = _first_env("CUSTOM_BEAVER_PEER_ID", "BEAVER_PEER_ID") or f"livekit-{ctx.room.name}"
|
||||
beaver_peer_id = (
|
||||
_first_env("CUSTOM_BEAVER_PEER_ID", "BEAVER_PEER_ID") or f"livekit-{ctx.room.name}"
|
||||
)
|
||||
beaver_device_name = (
|
||||
_first_env("CUSTOM_BEAVER_DEVICE_NAME", "BEAVER_DEVICE_NAME", "TERMINAL_DEVICE_NAME")
|
||||
or "livekit-custom-agent"
|
||||
@ -811,18 +846,15 @@ async def entrypoint(ctx: JobContext) -> None:
|
||||
model_name=os.getenv("CUSTOM_BEAVER_MODEL", "beaver-terminal"),
|
||||
)
|
||||
beaver_warmup_text = os.getenv("CUSTOM_BEAVER_WARMUP_TEXT")
|
||||
warmup_reply = await base_llm.connect(warmup_text=beaver_warmup_text)
|
||||
text_llm = base_llm
|
||||
vision_llm = base_llm
|
||||
logger.info(
|
||||
"Using Beaver gateway url=%s peer_id=%s device_name=%s room=%s session_id=%s warmup=%s warmup_reply_len=%s",
|
||||
"Using Beaver gateway url=%s peer_id=%s device_name=%s room=%s warmup=%s",
|
||||
beaver_url,
|
||||
beaver_peer_id,
|
||||
beaver_device_name,
|
||||
ctx.room.name,
|
||||
base_llm.session_id,
|
||||
bool(beaver_warmup_text and beaver_warmup_text.strip()),
|
||||
len(warmup_reply) if warmup_reply is not None else 0,
|
||||
)
|
||||
elif LLM_PROVIDER in {"hermes", "hermes_gateway", "openclaw"}:
|
||||
gateway_url = os.getenv("CUSTOM_HERMES_GATEWAY_URL", "").strip()
|
||||
@ -914,8 +946,9 @@ async def entrypoint(ctx: JobContext) -> None:
|
||||
# 4. Silero VAD
|
||||
vad=ctx.proc.userdata["vad"],
|
||||
turn_handling=TurnHandlingOptions(
|
||||
turn_detection=MultilingualModel(),
|
||||
turn_detection=turn_detection,
|
||||
interruption={
|
||||
"enabled": allow_interruptions,
|
||||
"resume_false_interruption": True,
|
||||
"false_interruption_timeout": 1.0,
|
||||
},
|
||||
@ -946,17 +979,53 @@ async def entrypoint(ctx: JobContext) -> None:
|
||||
@ctx.room.on("data_received")
|
||||
def _on_data_received(data_packet) -> None:
|
||||
packet_topic = getattr(data_packet, "topic", None)
|
||||
if packet_topic not in {None, "", VISION_FRAME_TOPIC}:
|
||||
|
||||
def _interrupt_done(fut: asyncio.Future[None]) -> None:
|
||||
try:
|
||||
fut.result()
|
||||
except Exception:
|
||||
logger.exception("Bridge interrupt failed")
|
||||
|
||||
def _handle_interrupt(payload: dict[str, object] | None = None) -> None:
|
||||
reason = None if payload is None else payload.get("reason")
|
||||
logger.info(
|
||||
"Received bridge interrupt: topic=%s reason=%s",
|
||||
packet_topic or "<default>",
|
||||
reason if isinstance(reason, str) and reason else "<unspecified>",
|
||||
)
|
||||
try:
|
||||
interrupt_fut = session.interrupt(force=True)
|
||||
except RuntimeError:
|
||||
logger.exception("Bridge interrupt received before AgentSession was running")
|
||||
return
|
||||
interrupt_fut.add_done_callback(_interrupt_done)
|
||||
|
||||
if packet_topic == INTERRUPT_TOPIC:
|
||||
payload: dict[str, object] | None = None
|
||||
try:
|
||||
decoded = json.loads(data_packet.data.decode("utf-8"))
|
||||
if isinstance(decoded, dict):
|
||||
payload = decoded
|
||||
except Exception:
|
||||
logger.exception("Failed to decode interrupt payload")
|
||||
_handle_interrupt(payload)
|
||||
return
|
||||
|
||||
if INPUT_MODE == VOICE_INPUT_MODE:
|
||||
logger.info("Ignoring vision frame because CUSTOM_AGENT_INPUT_MODE=%s", INPUT_MODE)
|
||||
if packet_topic not in {None, "", VISION_FRAME_TOPIC}:
|
||||
return
|
||||
|
||||
try:
|
||||
payload = json.loads(data_packet.data.decode("utf-8"))
|
||||
except Exception:
|
||||
logger.exception("Failed to decode vision frame payload")
|
||||
logger.exception("Failed to decode data payload")
|
||||
return
|
||||
|
||||
if payload.get("type") == "interrupt" or payload.get("topic") == INTERRUPT_TOPIC:
|
||||
_handle_interrupt(payload)
|
||||
return
|
||||
|
||||
if INPUT_MODE == VOICE_INPUT_MODE:
|
||||
logger.info("Ignoring vision frame because CUSTOM_AGENT_INPUT_MODE=%s", INPUT_MODE)
|
||||
return
|
||||
|
||||
if payload.get("type") != "vision_frame" and payload.get("topic") != VISION_FRAME_TOPIC:
|
||||
@ -1013,6 +1082,53 @@ async def entrypoint(ctx: JobContext) -> None:
|
||||
),
|
||||
record=_recording_options_from_env(),
|
||||
)
|
||||
if LLM_PROVIDER == "beaver" and isinstance(base_llm, BeaverLLM):
|
||||
_start_beaver_background_warmup(
|
||||
ctx=ctx,
|
||||
beaver_llm=base_llm,
|
||||
warmup_text=beaver_warmup_text,
|
||||
)
|
||||
|
||||
|
||||
def _start_beaver_background_warmup(
|
||||
*,
|
||||
ctx: JobContext,
|
||||
beaver_llm: BeaverLLM,
|
||||
warmup_text: str | None,
|
||||
) -> None:
|
||||
async def _warmup() -> None:
|
||||
try:
|
||||
warmup_reply = await beaver_llm.connect(warmup_text=warmup_text)
|
||||
except BeaverTerminalError:
|
||||
logger.warning(
|
||||
"Beaver background handshake failed; will retry on first user turn",
|
||||
exc_info=True,
|
||||
)
|
||||
return
|
||||
except Exception:
|
||||
logger.exception("Unexpected Beaver background handshake failure")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"Beaver background handshake completed room=%s session_id=%s warmup=%s warmup_reply_len=%s",
|
||||
ctx.room.name,
|
||||
beaver_llm.session_id,
|
||||
bool(warmup_text and warmup_text.strip()),
|
||||
len(warmup_reply) if warmup_reply is not None else 0,
|
||||
)
|
||||
|
||||
warmup_task = asyncio.create_task(_warmup(), name="beaver_background_warmup")
|
||||
|
||||
async def _cancel_warmup() -> None:
|
||||
if warmup_task.done():
|
||||
return
|
||||
warmup_task.cancel()
|
||||
try:
|
||||
await warmup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
ctx.add_shutdown_callback(_cancel_warmup)
|
||||
|
||||
|
||||
def _tts_params_from_env(model_name: str) -> dict[str, str]:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import aiohttp
|
||||
@ -96,3 +97,73 @@ async def test_beaver_llm_sends_latest_user_text_and_returns_reply(
|
||||
assert received[1]["message_id"].startswith("livekit-room-")
|
||||
assert received[1]["message_id"].endswith("-000001")
|
||||
assert received[1]["text"] == "hello beaver"
|
||||
|
||||
|
||||
async def test_beaver_llm_waits_for_slow_reply_without_placeholder(
|
||||
unused_tcp_port: int,
|
||||
) -> None:
|
||||
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
async for message in ws:
|
||||
assert message.type == aiohttp.WSMsgType.TEXT
|
||||
frame = json.loads(message.data)
|
||||
|
||||
if frame["type"] == "connect":
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "connected",
|
||||
"channel_id": "terminal-dev",
|
||||
"session_id": "terminal-dev:local:livekit-room",
|
||||
}
|
||||
)
|
||||
elif frame["type"] == "message":
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "ack",
|
||||
"message_id": frame["message_id"],
|
||||
"session_id": "terminal-dev:local:livekit-room",
|
||||
"accepted": True,
|
||||
}
|
||||
)
|
||||
await asyncio.sleep(0.05)
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"message_id": frame["message_id"],
|
||||
"run_id": "run-1",
|
||||
"text": "beaver reply",
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
)
|
||||
|
||||
return ws
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||
await site.start()
|
||||
|
||||
beaver_llm = BeaverLLM(
|
||||
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
|
||||
peer_id="livekit-room",
|
||||
device_name="livekit-custom-agent",
|
||||
)
|
||||
ctx = ChatContext.empty()
|
||||
ctx.add_message(role="user", content="hello beaver")
|
||||
|
||||
try:
|
||||
chunks: list[str] = []
|
||||
async with beaver_llm.chat(chat_ctx=ctx) as stream:
|
||||
async for chunk in stream:
|
||||
if chunk.delta and chunk.delta.content:
|
||||
chunks.append(chunk.delta.content)
|
||||
finally:
|
||||
await beaver_llm.aclose()
|
||||
await runner.cleanup()
|
||||
|
||||
assert chunks == ["beaver reply"]
|
||||
|
||||
@ -14,6 +14,7 @@ if __name__ == "__main__":
|
||||
try:
|
||||
from custom.beaver_terminal_client import (
|
||||
BeaverTerminalClient,
|
||||
BeaverTerminalConnectionClosed,
|
||||
BeaverTerminalError,
|
||||
MessageIdGenerator,
|
||||
build_connect_frame,
|
||||
@ -22,6 +23,7 @@ try:
|
||||
except ModuleNotFoundError:
|
||||
from beaver_terminal_client import (
|
||||
BeaverTerminalClient,
|
||||
BeaverTerminalConnectionClosed,
|
||||
BeaverTerminalError,
|
||||
MessageIdGenerator,
|
||||
build_connect_frame,
|
||||
@ -244,6 +246,36 @@ async def test_client_raises_on_error_frames(unused_tcp_port: int) -> None:
|
||||
await runner.cleanup()
|
||||
|
||||
|
||||
async def test_client_cleans_up_owned_session_when_connect_fails(
|
||||
unused_tcp_port: int,
|
||||
) -> None:
|
||||
async def websocket_handler(request: web.Request) -> web.Response:
|
||||
return web.Response(status=200, text="not a websocket")
|
||||
|
||||
app = web.Application()
|
||||
app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler)
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port)
|
||||
await site.start()
|
||||
|
||||
client = BeaverTerminalClient(
|
||||
url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws",
|
||||
peer_id="device-001",
|
||||
device_name="desk-terminal",
|
||||
message_ids=MessageIdGenerator(peer_id="device-001"),
|
||||
)
|
||||
|
||||
try:
|
||||
with pytest.raises(BeaverTerminalConnectionClosed, match="failed to connect"):
|
||||
await client.connect()
|
||||
assert client._http_session is None
|
||||
assert client._ws is None
|
||||
finally:
|
||||
await client.close()
|
||||
await runner.cleanup()
|
||||
|
||||
|
||||
async def test_client_treats_assistant_finish_reason_error_as_failed_turn(
|
||||
unused_tcp_port: int,
|
||||
) -> None:
|
||||
|
||||
7
tts.py
7
tts.py
@ -7,7 +7,6 @@ import time
|
||||
import wave
|
||||
from collections.abc import Mapping
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
@ -30,13 +29,13 @@ class BlackboxTTS(tts.TTS):
|
||||
*,
|
||||
url: str,
|
||||
model_name: str = "voxcpmtts",
|
||||
params: Optional[Mapping[str, object]] = None,
|
||||
prompt_wav_path: Optional[str] = None,
|
||||
params: Mapping[str, object] | None = None,
|
||||
prompt_wav_path: str | None = None,
|
||||
prompt_wav_field: str = "prompt_wav",
|
||||
sample_rate: int = 16000,
|
||||
num_channels: int = 1,
|
||||
timeout: float = 60.0,
|
||||
http_session: Optional[aiohttp.ClientSession] = None,
|
||||
http_session: aiohttp.ClientSession | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
capabilities=tts.TTSCapabilities(streaming=False),
|
||||
|
||||
Reference in New Issue
Block a user