diff --git a/.env.example b/.env.example index 82d9251..001ecc9 100644 --- a/.env.example +++ b/.env.example @@ -20,6 +20,10 @@ CUSTOM_ASR_OUTPUT_LANGUAGE=zh CUSTOM_ASR_HOTWORDS= CUSTOM_ASR_ITN= CUSTOM_ASR_CHUNK_MODE= +# Force a user turn if VAD/ASR never reaches end-of-speech. Set 0 to disable. +CUSTOM_ASR_MAX_SPEECH_DURATION=12 +# Keep false if forced ASR turns should reply even while input audio continues. +CUSTOM_ALLOW_INTERRUPTION_DURING_FORCED_ASR=false # LLM backend: openai/openai-compatible, hermes_gateway/openclaw, or beaver. # Defaults come from CUSTOM_AGENT_PROFILE. Uncomment to override. diff --git a/asr.py b/asr.py index 4052321..b8c6d76 100644 --- a/asr.py +++ b/asr.py @@ -1,11 +1,14 @@ import asyncio import logging -from typing import Any, Optional, Union +from collections import deque +from collections.abc import AsyncIterable +from typing import Any import aiohttp from livekit import rtc from livekit.agents import ( + DEFAULT_API_CONNECT_OPTIONS, NOT_GIVEN, APIConnectionError, APIConnectOptions, @@ -17,9 +20,15 @@ from livekit.agents import ( utils, ) from livekit.agents.utils import is_given +from livekit.agents.vad import VAD, VADEventType logger = logging.getLogger("blackbox-asr") +DEFAULT_STREAM_ADAPTER_API_CONNECT_OPTIONS = APIConnectOptions( + max_retry=0, timeout=DEFAULT_API_CONNECT_OPTIONS.timeout +) +STT_CAPABILITIES = stt.STTCapabilities + class BlackboxSTT(stt.STT): def __init__( @@ -27,13 +36,13 @@ class BlackboxSTT(stt.STT): url: str, *, model_name: str = "sensevoice", - language: Optional[str] = "auto", + language: str | None = "auto", output_language: str = "zh", - hotwords: Optional[str] = None, - itn: Optional[Union[bool, str]] = None, - chunk_mode: Optional[Union[bool, str]] = None, + hotwords: str | None = None, + itn: bool | str | None = None, + chunk_mode: bool | str | None = None, timeout: float = 30.0, - http_session: Optional[aiohttp.ClientSession] = None, + http_session: aiohttp.ClientSession | None = None, ) -> None: super().__init__( capabilities=stt.STTCapabilities( @@ -148,7 +157,244 @@ def _extract_asr_text(payload: dict[str, Any]) -> str: raise APIConnectionError(f"Unsupported ASR blackbox response: {payload}") -def _form_value(value: Union[bool, str]) -> str: +def _form_value(value: bool | str) -> str: if isinstance(value, bool): return str(value).lower() return value + + +class BoundedStreamAdapter(stt.STT): + def __init__( + self, + *, + stt: stt.STT, + vad: VAD, + max_speech_duration: float | None = 12.0, + pre_speech_duration: float = 0.5, + ) -> None: + super().__init__( + capabilities=STT_CAPABILITIES( + streaming=True, + interim_results=False, + diarization=False, + ) + ) + self._vad = vad + self._stt = stt + self._max_speech_duration = max_speech_duration + self._pre_speech_duration = pre_speech_duration + self._stt.on("metrics_collected", self._on_metrics_collected) + + @property + def wrapped_stt(self) -> stt.STT: + return self._stt + + @property + def model(self) -> str: + return self._stt.model + + @property + def provider(self) -> str: + return self._stt.provider + + async def _recognize_impl( + self, + buffer: utils.AudioBuffer, + *, + language: NotGivenOr[str] = NOT_GIVEN, + conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, + ) -> stt.SpeechEvent: + return await self._stt.recognize( + buffer=buffer, language=language, conn_options=conn_options + ) + + def stream( + self, + *, + language: NotGivenOr[str] = NOT_GIVEN, + conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, + ) -> stt.RecognizeStream: + return _BoundedStreamAdapterWrapper( + self, + vad=self._vad, + wrapped_stt=self._stt, + language=language, + conn_options=conn_options, + max_speech_duration=self._max_speech_duration, + pre_speech_duration=self._pre_speech_duration, + ) + + def _on_metrics_collected(self, *args: Any, **kwargs: Any) -> None: + self.emit("metrics_collected", *args, **kwargs) + + async def aclose(self) -> None: + self._stt.off("metrics_collected", self._on_metrics_collected) + + +class _BoundedStreamAdapterWrapper(stt.RecognizeStream): + def __init__( + self, + adapter: BoundedStreamAdapter, + *, + vad: VAD, + wrapped_stt: stt.STT, + language: NotGivenOr[str], + conn_options: APIConnectOptions, + max_speech_duration: float | None, + pre_speech_duration: float, + ) -> None: + super().__init__(stt=adapter, conn_options=DEFAULT_STREAM_ADAPTER_API_CONNECT_OPTIONS) + self._vad = vad + self._wrapped_stt = wrapped_stt + self._wrapped_stt_conn_options = conn_options + self._language = language + self._max_speech_duration = max_speech_duration + self._pre_speech_duration = pre_speech_duration + + async def _metrics_monitor_task(self, event_aiter: AsyncIterable[stt.SpeechEvent]) -> None: + async for _ in event_aiter: + pass + + async def _run(self) -> None: + vad_stream = self._vad.stream() + lock = asyncio.Lock() + recognize_queue: asyncio.Queue[list[rtc.AudioFrame] | None] = asyncio.Queue() + + speech_active = False + segment_frames: list[rtc.AudioFrame] = [] + segment_duration = 0.0 + pre_roll_frames: deque[rtc.AudioFrame] = deque() + pre_roll_duration = 0.0 + + def _frame_duration(frame: rtc.AudioFrame) -> float: + return frame.samples_per_channel / frame.sample_rate + + def _append_pre_roll(frame: rtc.AudioFrame) -> None: + nonlocal pre_roll_duration + pre_roll_frames.append(frame) + pre_roll_duration += _frame_duration(frame) + + while pre_roll_duration > self._pre_speech_duration and pre_roll_frames: + pre_roll_duration -= _frame_duration(pre_roll_frames.popleft()) + + async def _enqueue_segment(frames: list[rtc.AudioFrame], *, forced: bool = False) -> None: + if not frames: + return + + if forced: + logger.info( + "Forcing ASR segment after %.2fs of continuous speech", + self._max_speech_duration, + ) + + self._event_ch.send_nowait(stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH)) + await recognize_queue.put(frames) + + async def _recognize_worker() -> None: + while True: + frames = await recognize_queue.get() + if frames is None: + return + + merged_frames = utils.merge_frames(frames) + try: + t_event = await self._wrapped_stt.recognize( + buffer=merged_frames, + language=self._language, + conn_options=self._wrapped_stt_conn_options, + ) + except Exception: + logger.exception("ASR segment recognition failed") + continue + + if not t_event.alternatives or not t_event.alternatives[0].text: + continue + + self._event_ch.send_nowait( + stt.SpeechEvent( + type=stt.SpeechEventType.FINAL_TRANSCRIPT, + alternatives=[t_event.alternatives[0]], + ) + ) + + async def _forward_input() -> None: + nonlocal segment_duration, segment_frames, speech_active + + async for input_frame in self._input_ch: + if isinstance(input_frame, self._FlushSentinel): + vad_stream.flush() + continue + + vad_stream.push_frame(input_frame) + forced_frames: list[rtc.AudioFrame] = [] + + async with lock: + if speech_active: + segment_frames.append(input_frame) + segment_duration += _frame_duration(input_frame) + if ( + self._max_speech_duration is not None + and segment_duration >= self._max_speech_duration + ): + forced_frames = segment_frames + segment_frames = [] + segment_duration = 0.0 + else: + _append_pre_roll(input_frame) + + if forced_frames: + await _enqueue_segment(forced_frames, forced=True) + + vad_stream.end_input() + + final_frames: list[rtc.AudioFrame] = [] + async with lock: + if speech_active and segment_frames: + final_frames = segment_frames + segment_frames = [] + segment_duration = 0.0 + speech_active = False + + if final_frames: + await _enqueue_segment(final_frames) + + async def _recognize_from_vad() -> None: + nonlocal pre_roll_duration, segment_duration, segment_frames, speech_active + + async for event in vad_stream: + if event.type == VADEventType.START_OF_SPEECH: + self._event_ch.send_nowait( + stt.SpeechEvent(stt.SpeechEventType.START_OF_SPEECH) + ) + async with lock: + if not speech_active: + speech_active = True + segment_frames = list(pre_roll_frames) + segment_duration = sum(_frame_duration(f) for f in segment_frames) + pre_roll_frames.clear() + pre_roll_duration = 0.0 + continue + + if event.type != VADEventType.END_OF_SPEECH: + continue + + async with lock: + frames = segment_frames + segment_frames = [] + segment_duration = 0.0 + speech_active = False + + await _enqueue_segment(frames) + + worker_task = asyncio.create_task(_recognize_worker(), name="bounded_asr_recognize") + tasks = [ + asyncio.create_task(_forward_input(), name="bounded_asr_forward_input"), + asyncio.create_task(_recognize_from_vad(), name="bounded_asr_vad"), + ] + try: + await asyncio.gather(*tasks) + await recognize_queue.put(None) + await worker_task + finally: + await utils.aio.cancel_and_wait(*tasks, worker_task) + await vad_stream.aclose() diff --git a/beaver_llm.py b/beaver_llm.py index 0017aa4..1050f3e 100644 --- a/beaver_llm.py +++ b/beaver_llm.py @@ -109,6 +109,7 @@ class BeaverLLMStream(llm.LLMStream): ) -> None: super().__init__(beaver_llm, chat_ctx=chat_ctx, tools=tools, conn_options=conn_options) self._beaver_llm = beaver_llm + self._request_id = shortuuid("beaver_") async def _run(self) -> None: user_text = latest_user_text(self.chat_ctx) @@ -116,9 +117,12 @@ class BeaverLLMStream(llm.LLMStream): reply = await self._beaver_llm._client.send_text(user_text) if reply: - self._event_ch.send_nowait( - llm.ChatChunk( - id=shortuuid("beaver_"), - delta=llm.ChoiceDelta(role="assistant", content=reply), - ) + self._send_text_chunk(reply) + + def _send_text_chunk(self, text: str) -> None: + self._event_ch.send_nowait( + llm.ChatChunk( + id=self._request_id, + delta=llm.ChoiceDelta(role="assistant", content=text), ) + ) diff --git a/beaver_terminal_client.py b/beaver_terminal_client.py index dc14eb7..b209da6 100644 --- a/beaver_terminal_client.py +++ b/beaver_terminal_client.py @@ -28,6 +28,7 @@ class BeaverTerminalConnectionClosed(BeaverTerminalError): @dataclass class MessageIdGenerator: peer_id: str + nonce: str | None = None initial_counter: int = 0 def __post_init__(self) -> None: @@ -35,6 +36,8 @@ class MessageIdGenerator: def next_id(self) -> str: self.counter += 1 + if self.nonce: + return f"{self.peer_id}-{self.nonce}-{self.counter:06d}" return f"{self.peer_id}-{self.counter:06d}" @@ -77,15 +80,22 @@ class BeaverTerminalClient: async def connect(self) -> None: await self._close_websocket() session = self._ensure_http_session() - self._ws = await session.ws_connect(self._url) - await self._send_json( - build_connect_frame(peer_id=self._peer_id, device_name=self._device_name) - ) - frame = await self._receive_json() - if frame.get("type") != "connected": - raise BeaverTerminalError(f"expected connected frame, received {frame!r}") - session_id = frame.get("session_id") - self.session_id = session_id if isinstance(session_id, str) else None + try: + self._ws = await session.ws_connect(self._url) + await self._send_json( + build_connect_frame(peer_id=self._peer_id, device_name=self._device_name) + ) + frame = await self._receive_json() + if frame.get("type") != "connected": + raise BeaverTerminalError(f"expected connected frame, received {frame!r}") + session_id = frame.get("session_id") + self.session_id = session_id if isinstance(session_id, str) else None + except (aiohttp.ClientError, asyncio.TimeoutError, BeaverTerminalConnectionClosed) as exc: + await self._cleanup_failed_connection() + raise BeaverTerminalConnectionClosed("failed to connect to Beaver websocket") from exc + except Exception: + await self._cleanup_failed_connection() + raise async def send_text(self, text: str) -> str: for attempt in range(2): @@ -153,6 +163,13 @@ class BeaverTerminalClient: await self._ws.close() self._ws = None + async def _cleanup_failed_connection(self) -> None: + await self._close_websocket() + self.session_id = None + if self._owned_session and self._http_session is not None: + await self._http_session.close() + self._http_session = None + def _websocket_is_open(self) -> bool: return self._ws is not None and not self._ws.closed diff --git a/custom_agent.py b/custom_agent.py index 93a4f35..840efe1 100644 --- a/custom_agent.py +++ b/custom_agent.py @@ -1,3 +1,4 @@ +import asyncio import base64 import json import logging @@ -9,12 +10,13 @@ from dataclasses import dataclass from pathlib import Path from beaver_llm import BeaverLLM +from beaver_terminal_client import BeaverTerminalError from dotenv import load_dotenv from hermes_gateway import GatewaySessionState, HermesGatewayLLM from memory import MemoryRecallClient from tts import BlackboxTTS -from asr import BlackboxSTT +from asr import BlackboxSTT, BoundedStreamAdapter from livekit.agents import ( Agent, AgentServer, @@ -32,7 +34,6 @@ from livekit.agents import ( llm, metrics, room_io, - stt, ) from livekit.agents.voice.generation import update_instructions as update_chat_instructions from livekit.plugins import openai, silero @@ -61,7 +62,7 @@ GENERAL_INSTRUCTIONS = """ EMOTION_INSTRUCTIONS = """ 每次回复必须先输出一个情绪标签,格式严格为: -emotion 只能从 neutral、happy、sad、angry、surprised、fearful、calm、concerned 中选择。 +emotion 只能从 angry、confident、confused、cool、crying、delicious、embarrassed、funny、happy、kissy、laughing、loving、neutral、relaxed、sad、shocked、silly、sleepy、surprised、thinking、winking 中选择。 情绪标签之后直接输出给用户的正常回复,不要解释标签。 """.strip() @@ -70,6 +71,7 @@ GENERAL_MODE = "general" VOICE_INPUT_MODE = "voice" VISION_VOICE_INPUT_MODE = "vision_voice" AUTO_INPUT_MODE = "auto" +INTERRUPT_TOPIC = "lk.interrupt" VISION_FRAME_TOPIC = "vision.frame" DEFAULT_AGENT_PROFILE = "normal" @@ -163,14 +165,27 @@ AGENT_NAME = os.getenv("CUSTOM_AGENT_NAME") or AGENT_PROFILE.agent_name DEFAULT_EMOTION = "neutral" EMOTION_LABELS = { - "neutral", - "happy", - "sad", "angry", + "confident", + "confused", + "cool", + "crying", + "delicious", + "embarrassed", + "funny", + "happy", + "kissy", + "laughing", + "loving", + "neutral", + "relaxed", + "sad", + "shocked", + "silly", + "sleepy", "surprised", - "fearful", - "calm", - "concerned", + "thinking", + "winking", } EMOTION_PREFIX_RE = re.compile(r"^\s*\s*", re.IGNORECASE) TTS_EMOTION_MARKUP_RE = re.compile(r"<\s*emotion\s*=\s*[^>]{1,80}>\s*", re.IGNORECASE) @@ -380,6 +395,7 @@ class CustomAgent(Agent): yield chunk return _stream() + async def _observe_emotion_prefix( self, chunk: llm.ChatChunk | str | FlushSentinel ) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]: @@ -635,7 +651,9 @@ def _with_memory_as_latest_user_message(chat_ctx: ChatContext, memory_context: s return chat_ctx -def _with_vision_as_latest_user_message(chat_ctx: ChatContext, vision_frame: VisionFrame) -> ChatContext: +def _with_vision_as_latest_user_message( + chat_ctx: ChatContext, vision_frame: VisionFrame +) -> ChatContext: chat_ctx = chat_ctx.copy() image_content = llm.ImageContent( image=vision_frame.image_data_url, @@ -742,6 +760,7 @@ async def entrypoint(ctx: JobContext) -> None: ASR_MODEL = os.getenv("CUSTOM_ASR_MODEL", "sensevoice") ASR_LANGUAGE = os.getenv("CUSTOM_ASR_LANGUAGE", "auto") ASR_OUTPUT_LANGUAGE = os.getenv("CUSTOM_ASR_OUTPUT_LANGUAGE", "zh") + ASR_MAX_SPEECH_DURATION = _env_float("CUSTOM_ASR_MAX_SPEECH_DURATION", 12.0) LLM_BASE_URL = os.getenv("CUSTOM_LLM_BASE_URL") LLM_MODEL = os.getenv("CUSTOM_LLM_MODEL", "qwen-max") @@ -749,7 +768,9 @@ async def entrypoint(ctx: JobContext) -> None: LLM_PROVIDER = os.getenv("CUSTOM_LLM_PROVIDER", AGENT_PROFILE.llm_provider).strip().lower() TEXT_LLM_MODEL = os.getenv("CUSTOM_TEXT_LLM_MODEL", LLM_MODEL) VISION_LLM_MODEL = os.getenv("CUSTOM_VISION_LLM_MODEL", LLM_MODEL) - INPUT_MODE = _normalize_input_mode(os.getenv("CUSTOM_AGENT_INPUT_MODE", AGENT_PROFILE.input_mode)) + INPUT_MODE = _normalize_input_mode( + os.getenv("CUSTOM_AGENT_INPUT_MODE", AGENT_PROFILE.input_mode) + ) if LLM_PROVIDER not in { "openai", "openai-compatible", @@ -782,6 +803,7 @@ async def entrypoint(ctx: JobContext) -> None: MEMORY_TIMEOUT = _env_float("CUSTOM_MEMORY_TIMEOUT", 2.0) MEMORY_MAX_CHARS = _env_int("CUSTOM_MEMORY_MAX_CHARS", 2000) MEMORY_API_KEY = os.getenv("CUSTOM_MEMORY_API_KEY") or None + beaver_warmup_text: str | None = None blackbox_stt = BlackboxSTT( url=ASR_URL, @@ -792,14 +814,27 @@ async def entrypoint(ctx: JobContext) -> None: itn=os.getenv("CUSTOM_ASR_ITN"), chunk_mode=os.getenv("CUSTOM_ASR_CHUNK_MODE"), ) - stt_stream = stt.StreamAdapter(stt=blackbox_stt, vad=ctx.proc.userdata["vad"]) + stt_stream = BoundedStreamAdapter( + stt=blackbox_stt, + vad=ctx.proc.userdata["vad"], + max_speech_duration=ASR_MAX_SPEECH_DURATION if ASR_MAX_SPEECH_DURATION > 0 else None, + ) + turn_detection = "stt" if ASR_MAX_SPEECH_DURATION > 0 else MultilingualModel() + allow_interruptions = _env_bool( + "CUSTOM_ALLOW_INTERRUPTION_DURING_FORCED_ASR", + ASR_MAX_SPEECH_DURATION <= 0, + ) if LLM_PROVIDER == "beaver": beaver_url = _first_env("CUSTOM_BEAVER_WS_URL", "BEAVER_WS_URL") if not beaver_url: - raise RuntimeError(f"CUSTOM_BEAVER_WS_URL or BEAVER_WS_URL is not set in {CUSTOM_ENV_PATH}") + raise RuntimeError( + f"CUSTOM_BEAVER_WS_URL or BEAVER_WS_URL is not set in {CUSTOM_ENV_PATH}" + ) - beaver_peer_id = _first_env("CUSTOM_BEAVER_PEER_ID", "BEAVER_PEER_ID") or f"livekit-{ctx.room.name}" + beaver_peer_id = ( + _first_env("CUSTOM_BEAVER_PEER_ID", "BEAVER_PEER_ID") or f"livekit-{ctx.room.name}" + ) beaver_device_name = ( _first_env("CUSTOM_BEAVER_DEVICE_NAME", "BEAVER_DEVICE_NAME", "TERMINAL_DEVICE_NAME") or "livekit-custom-agent" @@ -811,18 +846,15 @@ async def entrypoint(ctx: JobContext) -> None: model_name=os.getenv("CUSTOM_BEAVER_MODEL", "beaver-terminal"), ) beaver_warmup_text = os.getenv("CUSTOM_BEAVER_WARMUP_TEXT") - warmup_reply = await base_llm.connect(warmup_text=beaver_warmup_text) text_llm = base_llm vision_llm = base_llm logger.info( - "Using Beaver gateway url=%s peer_id=%s device_name=%s room=%s session_id=%s warmup=%s warmup_reply_len=%s", + "Using Beaver gateway url=%s peer_id=%s device_name=%s room=%s warmup=%s", beaver_url, beaver_peer_id, beaver_device_name, ctx.room.name, - base_llm.session_id, bool(beaver_warmup_text and beaver_warmup_text.strip()), - len(warmup_reply) if warmup_reply is not None else 0, ) elif LLM_PROVIDER in {"hermes", "hermes_gateway", "openclaw"}: gateway_url = os.getenv("CUSTOM_HERMES_GATEWAY_URL", "").strip() @@ -914,8 +946,9 @@ async def entrypoint(ctx: JobContext) -> None: # 4. Silero VAD vad=ctx.proc.userdata["vad"], turn_handling=TurnHandlingOptions( - turn_detection=MultilingualModel(), + turn_detection=turn_detection, interruption={ + "enabled": allow_interruptions, "resume_false_interruption": True, "false_interruption_timeout": 1.0, }, @@ -946,17 +979,53 @@ async def entrypoint(ctx: JobContext) -> None: @ctx.room.on("data_received") def _on_data_received(data_packet) -> None: packet_topic = getattr(data_packet, "topic", None) - if packet_topic not in {None, "", VISION_FRAME_TOPIC}: + + def _interrupt_done(fut: asyncio.Future[None]) -> None: + try: + fut.result() + except Exception: + logger.exception("Bridge interrupt failed") + + def _handle_interrupt(payload: dict[str, object] | None = None) -> None: + reason = None if payload is None else payload.get("reason") + logger.info( + "Received bridge interrupt: topic=%s reason=%s", + packet_topic or "", + reason if isinstance(reason, str) and reason else "", + ) + try: + interrupt_fut = session.interrupt(force=True) + except RuntimeError: + logger.exception("Bridge interrupt received before AgentSession was running") + return + interrupt_fut.add_done_callback(_interrupt_done) + + if packet_topic == INTERRUPT_TOPIC: + payload: dict[str, object] | None = None + try: + decoded = json.loads(data_packet.data.decode("utf-8")) + if isinstance(decoded, dict): + payload = decoded + except Exception: + logger.exception("Failed to decode interrupt payload") + _handle_interrupt(payload) return - if INPUT_MODE == VOICE_INPUT_MODE: - logger.info("Ignoring vision frame because CUSTOM_AGENT_INPUT_MODE=%s", INPUT_MODE) + if packet_topic not in {None, "", VISION_FRAME_TOPIC}: return try: payload = json.loads(data_packet.data.decode("utf-8")) except Exception: - logger.exception("Failed to decode vision frame payload") + logger.exception("Failed to decode data payload") + return + + if payload.get("type") == "interrupt" or payload.get("topic") == INTERRUPT_TOPIC: + _handle_interrupt(payload) + return + + if INPUT_MODE == VOICE_INPUT_MODE: + logger.info("Ignoring vision frame because CUSTOM_AGENT_INPUT_MODE=%s", INPUT_MODE) return if payload.get("type") != "vision_frame" and payload.get("topic") != VISION_FRAME_TOPIC: @@ -1013,6 +1082,53 @@ async def entrypoint(ctx: JobContext) -> None: ), record=_recording_options_from_env(), ) + if LLM_PROVIDER == "beaver" and isinstance(base_llm, BeaverLLM): + _start_beaver_background_warmup( + ctx=ctx, + beaver_llm=base_llm, + warmup_text=beaver_warmup_text, + ) + + +def _start_beaver_background_warmup( + *, + ctx: JobContext, + beaver_llm: BeaverLLM, + warmup_text: str | None, +) -> None: + async def _warmup() -> None: + try: + warmup_reply = await beaver_llm.connect(warmup_text=warmup_text) + except BeaverTerminalError: + logger.warning( + "Beaver background handshake failed; will retry on first user turn", + exc_info=True, + ) + return + except Exception: + logger.exception("Unexpected Beaver background handshake failure") + return + + logger.info( + "Beaver background handshake completed room=%s session_id=%s warmup=%s warmup_reply_len=%s", + ctx.room.name, + beaver_llm.session_id, + bool(warmup_text and warmup_text.strip()), + len(warmup_reply) if warmup_reply is not None else 0, + ) + + warmup_task = asyncio.create_task(_warmup(), name="beaver_background_warmup") + + async def _cancel_warmup() -> None: + if warmup_task.done(): + return + warmup_task.cancel() + try: + await warmup_task + except asyncio.CancelledError: + pass + + ctx.add_shutdown_callback(_cancel_warmup) def _tts_params_from_env(model_name: str) -> dict[str, str]: diff --git a/test_beaver_llm.py b/test_beaver_llm.py index e4a2b07..fe6f6f9 100644 --- a/test_beaver_llm.py +++ b/test_beaver_llm.py @@ -1,3 +1,4 @@ +import asyncio import json import aiohttp @@ -96,3 +97,73 @@ async def test_beaver_llm_sends_latest_user_text_and_returns_reply( assert received[1]["message_id"].startswith("livekit-room-") assert received[1]["message_id"].endswith("-000001") assert received[1]["text"] == "hello beaver" + + +async def test_beaver_llm_waits_for_slow_reply_without_placeholder( + unused_tcp_port: int, +) -> None: + async def websocket_handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + + async for message in ws: + assert message.type == aiohttp.WSMsgType.TEXT + frame = json.loads(message.data) + + if frame["type"] == "connect": + await ws.send_json( + { + "type": "connected", + "channel_id": "terminal-dev", + "session_id": "terminal-dev:local:livekit-room", + } + ) + elif frame["type"] == "message": + await ws.send_json( + { + "type": "ack", + "message_id": frame["message_id"], + "session_id": "terminal-dev:local:livekit-room", + "accepted": True, + } + ) + await asyncio.sleep(0.05) + await ws.send_json( + { + "type": "message", + "role": "assistant", + "message_id": frame["message_id"], + "run_id": "run-1", + "text": "beaver reply", + "finish_reason": "stop", + } + ) + + return ws + + app = web.Application() + app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port) + await site.start() + + beaver_llm = BeaverLLM( + url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws", + peer_id="livekit-room", + device_name="livekit-custom-agent", + ) + ctx = ChatContext.empty() + ctx.add_message(role="user", content="hello beaver") + + try: + chunks: list[str] = [] + async with beaver_llm.chat(chat_ctx=ctx) as stream: + async for chunk in stream: + if chunk.delta and chunk.delta.content: + chunks.append(chunk.delta.content) + finally: + await beaver_llm.aclose() + await runner.cleanup() + + assert chunks == ["beaver reply"] diff --git a/test_beaver_terminal_client.py b/test_beaver_terminal_client.py index 9276ec1..1f8a755 100644 --- a/test_beaver_terminal_client.py +++ b/test_beaver_terminal_client.py @@ -14,6 +14,7 @@ if __name__ == "__main__": try: from custom.beaver_terminal_client import ( BeaverTerminalClient, + BeaverTerminalConnectionClosed, BeaverTerminalError, MessageIdGenerator, build_connect_frame, @@ -22,6 +23,7 @@ try: except ModuleNotFoundError: from beaver_terminal_client import ( BeaverTerminalClient, + BeaverTerminalConnectionClosed, BeaverTerminalError, MessageIdGenerator, build_connect_frame, @@ -244,6 +246,36 @@ async def test_client_raises_on_error_frames(unused_tcp_port: int) -> None: await runner.cleanup() +async def test_client_cleans_up_owned_session_when_connect_fails( + unused_tcp_port: int, +) -> None: + async def websocket_handler(request: web.Request) -> web.Response: + return web.Response(status=200, text="not a websocket") + + app = web.Application() + app.router.add_get("/api/channels/terminal-dev/ws", websocket_handler) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "127.0.0.1", unused_tcp_port) + await site.start() + + client = BeaverTerminalClient( + url=f"http://127.0.0.1:{unused_tcp_port}/api/channels/terminal-dev/ws", + peer_id="device-001", + device_name="desk-terminal", + message_ids=MessageIdGenerator(peer_id="device-001"), + ) + + try: + with pytest.raises(BeaverTerminalConnectionClosed, match="failed to connect"): + await client.connect() + assert client._http_session is None + assert client._ws is None + finally: + await client.close() + await runner.cleanup() + + async def test_client_treats_assistant_finish_reason_error_as_failed_turn( unused_tcp_port: int, ) -> None: diff --git a/tts.py b/tts.py index a3cbc58..1afaa2b 100644 --- a/tts.py +++ b/tts.py @@ -7,7 +7,6 @@ import time import wave from collections.abc import Mapping from io import BytesIO -from typing import Optional import aiohttp @@ -30,13 +29,13 @@ class BlackboxTTS(tts.TTS): *, url: str, model_name: str = "voxcpmtts", - params: Optional[Mapping[str, object]] = None, - prompt_wav_path: Optional[str] = None, + params: Mapping[str, object] | None = None, + prompt_wav_path: str | None = None, prompt_wav_field: str = "prompt_wav", sample_rate: int = 16000, num_channels: int = 1, timeout: float = 60.0, - http_session: Optional[aiohttp.ClientSession] = None, + http_session: aiohttp.ClientSession | None = None, ) -> None: super().__init__( capabilities=tts.TTSCapabilities(streaming=False),