import asyncio
import logging
from collections import deque
from collections.abc import AsyncIterable
from typing import Any

import aiohttp
from livekit import rtc
from livekit.agents import (
    DEFAULT_API_CONNECT_OPTIONS,
    NOT_GIVEN,
    APIConnectionError,
    APIConnectOptions,
    APIStatusError,
    APITimeoutError,
    LanguageCode,
    NotGivenOr,
    stt,
    utils,
)
from livekit.agents.utils import is_given
from livekit.agents.vad import VAD, VADEventType

logger = logging.getLogger("blackbox-asr")

# Connect options passed to the RecognizeStream base class by the stream
# wrapper below: retries disabled, default timeout kept.
DEFAULT_STREAM_ADAPTER_API_CONNECT_OPTIONS = APIConnectOptions(
    max_retry=0, timeout=DEFAULT_API_CONNECT_OPTIONS.timeout
)

# Module-level alias: ``BoundedStreamAdapter.__init__`` takes a parameter named
# ``stt`` which shadows the ``stt`` module inside that method's body.
STT_CAPABILITIES = stt.STTCapabilities


class BlackboxSTT(stt.STT):
    """Non-streaming STT that uploads a WAV file to an HTTP ASR endpoint."""

    def __init__(
        self,
        url: str,
        *,
        model_name: str = "sensevoice",
        language: str | None = "auto",
        output_language: str = "zh",
        hotwords: str | None = None,
        itn: bool | str | None = None,
        chunk_mode: bool | str | None = None,
        timeout: float = 30.0,
        http_session: aiohttp.ClientSession | None = None,
    ) -> None:
        """Configure the client.

        Args:
            url: Endpoint the multipart form is POSTed to.
            model_name: Sent as the ``model_name`` form field.
            language: Default ``language`` form field; omitted when falsy.
            output_language: Language tag attached to the returned transcript.
            hotwords: Sent as the ``hotwords`` form field when truthy.
            itn: Sent as the ``itn`` form field when not ``None``; booleans are
                serialized as ``"true"``/``"false"``.
            chunk_mode: Sent as the ``chunk_mode`` form field when not ``None``;
                serialized like ``itn``.
            timeout: Total request timeout in seconds.
            http_session: Session to use; when ``None`` one is obtained lazily
                from ``utils.http_context``.
        """
        super().__init__(
            capabilities=stt.STTCapabilities(
                streaming=False,
                interim_results=False,
                diarization=False,
            )
        )
        self._url = url
        self._model_name = model_name
        self._language = language
        self._output_language = output_language
        self._timeout = timeout
        self._http_session = http_session

        # Optional form fields, pre-serialized once so every request reuses them.
        self._extra_fields: dict[str, str] = {}
        if hotwords:
            self._extra_fields["hotwords"] = hotwords
        if itn is not None:
            self._extra_fields["itn"] = _form_value(itn)
        if chunk_mode is not None:
            self._extra_fields["chunk_mode"] = _form_value(chunk_mode)

    @property
    def model(self) -> str:
        """Name of the ASR model requested from the endpoint."""
        return self._model_name

    @property
    def provider(self) -> str:
        """Fixed provider identifier for this STT."""
        return "asr-blackbox"

    def _ensure_session(self) -> aiohttp.ClientSession:
        """Return the HTTP session, fetching the shared one on first use."""
        if self._http_session is None:
            self._http_session = utils.http_context.http_session()
        return self._http_session

    async def _recognize_impl(
        self,
        buffer: utils.AudioBuffer,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions,
    ) -> stt.SpeechEvent:
        """Transcribe ``buffer`` with a single HTTP request.

        Args:
            buffer: Audio to transcribe; combined and encoded as WAV.
            language: Per-call override of the configured language.
            conn_options: Its ``timeout`` is used as the socket connect timeout.

        Returns:
            A ``FINAL_TRANSCRIPT`` event with one alternative.

        Raises:
            APIStatusError: The endpoint answered with a non-200 status.
            APITimeoutError: The request timed out.
            APIConnectionError: Transport failure, undecodable or unsupported
                response body, or an empty transcript.
        """
        audio_data = rtc.combine_audio_frames(buffer).to_wav_bytes()

        form = aiohttp.FormData()
        form.add_field("audio", audio_data, filename="audio.wav", content_type="audio/wav")
        form.add_field("model_name", self._model_name)
        resolved_language = language if is_given(language) else self._language
        if resolved_language:
            form.add_field("language", resolved_language)
        for key, value in self._extra_fields.items():
            form.add_field(key, value)

        try:
            async with self._ensure_session().post(
                self._url,
                data=form,
                timeout=aiohttp.ClientTimeout(
                    total=self._timeout,
                    sock_connect=conn_options.timeout,
                ),
            ) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    raise APIStatusError(
                        message=f"ASR blackbox error: {error_text}",
                        status_code=resp.status,
                        request_id=None,
                        body=error_text,
                    )

                try:
                    payload = await resp.json()
                except ValueError as e:
                    # A malformed JSON body raises ValueError, which is not an
                    # aiohttp.ClientError; surface it as an API error instead
                    # of letting a raw decode error escape.
                    raise APIConnectionError(
                        f"ASR blackbox returned invalid JSON: {e}"
                    ) from e

                logger.info("ASR blackbox raw result: %s", payload)
                text = _extract_asr_text(payload)
                if not text:
                    # NOTE(review): an empty transcript is reported as
                    # APIConnectionError; confirm whether the caller's retry
                    # policy is meant to apply to this case.
                    raise APIConnectionError("ASR blackbox returned an empty transcript")
                logger.info("ASR blackbox result: %s", text)
                return stt.SpeechEvent(
                    type=stt.SpeechEventType.FINAL_TRANSCRIPT,
                    alternatives=[
                        stt.SpeechData(
                            text=text,
                            language=LanguageCode(self._output_language),
                        )
                    ],
                )
        except asyncio.TimeoutError as e:
            raise APITimeoutError("ASR blackbox request timed out") from e
        except aiohttp.ClientError as e:
            raise APIConnectionError(f"ASR blackbox connection error: {e}") from e


def _extract_asr_text(payload: dict[str, Any]) -> str:
    """Pull the transcript out of an ASR response payload.

    A non-blank top-level ``"text"`` wins. Otherwise the first entry of
    ``"result"`` is inspected: for a dict, the first non-blank of
    ``clean_text``/``text``/``raw_text``; for a string, the string itself.
    All returned values are stripped.

    Returns:
        The transcript, or ``""`` when only a blank top-level ``"text"`` exists.

    Raises:
        APIConnectionError: The payload is not a dict or has no recognized shape.
    """
    if not isinstance(payload, dict):
        # The decoded JSON may be a list/str/number; ``.get`` would fail on those.
        raise APIConnectionError(f"Unsupported ASR blackbox response: {payload}")

    text = payload.get("text")
    if isinstance(text, str) and text.strip():
        return text.strip()

    result = payload.get("result")
    if isinstance(result, list) and result:
        first = result[0]
        if isinstance(first, dict):
            for key in ("clean_text", "text", "raw_text"):
                value = first.get(key)
                if isinstance(value, str) and value.strip():
                    return value.strip()
        if isinstance(first, str):
            return first.strip()

    if isinstance(text, str):
        # Blank top-level text and nothing usable under "result": report an
        # empty transcript and let the caller decide how to handle it.
        return ""

    raise APIConnectionError(f"Unsupported ASR blackbox response: {payload}")


def _form_value(value: bool | str) -> str:
    """Serialize a form option: booleans become ``"true"``/``"false"``."""
    if isinstance(value, bool):
        return str(value).lower()
    return value


class BoundedStreamAdapter(stt.STT):
    """Expose a non-streaming STT as a streaming one by segmenting with a VAD.

    Unlike a plain VAD-driven adapter, a speech segment is also cut once it
    reaches ``max_speech_duration`` seconds, and up to ``pre_speech_duration``
    seconds of audio received before speech start is prepended to a segment.
    """

    def __init__(
        self,
        *,
        stt: stt.STT,
        vad: VAD,
        max_speech_duration: float | None = 12.0,
        pre_speech_duration: float = 0.5,
    ) -> None:
        """Wrap ``stt`` and forward its ``metrics_collected`` events.

        Args:
            stt: The STT used to recognize each segment.
            vad: Voice activity detector used to find segment boundaries.
            max_speech_duration: Seconds of continuous speech after which a
                segment is force-submitted; ``None`` disables the limit.
            pre_speech_duration: Seconds of pre-speech audio kept as pre-roll.
        """
        # ``stt`` is the parameter here, hence the module-level alias.
        super().__init__(
            capabilities=STT_CAPABILITIES(
                streaming=True,
                interim_results=False,
                diarization=False,
            )
        )
        self._vad = vad
        self._stt = stt
        self._max_speech_duration = max_speech_duration
        self._pre_speech_duration = pre_speech_duration
        self._stt.on("metrics_collected", self._on_metrics_collected)

    @property
    def wrapped_stt(self) -> stt.STT:
        """The underlying STT."""
        return self._stt

    @property
    def model(self) -> str:
        """Model name reported by the wrapped STT."""
        return self._stt.model

    @property
    def provider(self) -> str:
        """Provider name reported by the wrapped STT."""
        return self._stt.provider

    async def _recognize_impl(
        self,
        buffer: utils.AudioBuffer,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> stt.SpeechEvent:
        """Delegate one-shot recognition to the wrapped STT."""
        return await self._stt.recognize(
            buffer=buffer, language=language, conn_options=conn_options
        )

    def stream(
        self,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> stt.RecognizeStream:
        """Create a stream that segments audio and recognizes each segment."""
        return _BoundedStreamAdapterWrapper(
            self,
            vad=self._vad,
            wrapped_stt=self._stt,
            language=language,
            conn_options=conn_options,
            max_speech_duration=self._max_speech_duration,
            pre_speech_duration=self._pre_speech_duration,
        )

    def _on_metrics_collected(self, *args: Any, **kwargs: Any) -> None:
        """Re-emit the wrapped STT's metrics from this adapter."""
        self.emit("metrics_collected", *args, **kwargs)

    async def aclose(self) -> None:
        """Stop forwarding metrics from the wrapped STT."""
        self._stt.off("metrics_collected", self._on_metrics_collected)


class _BoundedStreamAdapterWrapper(stt.RecognizeStream):
    """Stream that buffers audio per speech segment and recognizes each one."""

    def __init__(
        self,
        adapter: BoundedStreamAdapter,
        *,
        vad: VAD,
        wrapped_stt: stt.STT,
        language: NotGivenOr[str],
        conn_options: APIConnectOptions,
        max_speech_duration: float | None,
        pre_speech_duration: float,
    ) -> None:
        """Store the segmentation settings.

        ``conn_options`` is kept for the wrapped STT's ``recognize`` calls; the
        base stream itself is given the no-retry adapter options.
        """
        super().__init__(stt=adapter, conn_options=DEFAULT_STREAM_ADAPTER_API_CONNECT_OPTIONS)
        self._vad = vad
        self._wrapped_stt = wrapped_stt
        self._wrapped_stt_conn_options = conn_options
        self._language = language
        self._max_speech_duration = max_speech_duration
        self._pre_speech_duration = pre_speech_duration

    async def _metrics_monitor_task(self, event_aiter: AsyncIterable[stt.SpeechEvent]) -> None:
        """Drain the event iterator without producing metrics."""
        async for _ in event_aiter:
            pass

    async def _run(self) -> None:
        """Pump input into the VAD, cut segments and recognize them in order.

        Three tasks cooperate: one forwards input frames to the VAD and
        collects them into the current segment (or the pre-roll buffer), one
        reacts to VAD events, and one worker recognizes queued segments
        sequentially.
        """
        vad_stream = self._vad.stream()
        lock = asyncio.Lock()
        # ``None`` is the shutdown sentinel for the worker.
        recognize_queue: asyncio.Queue[list[rtc.AudioFrame] | None] = asyncio.Queue()

        speech_active = False
        segment_frames: list[rtc.AudioFrame] = []
        segment_duration = 0.0
        pre_roll_frames: deque[rtc.AudioFrame] = deque()
        pre_roll_duration = 0.0

        def _frame_duration(frame: rtc.AudioFrame) -> float:
            """Duration of ``frame`` in seconds."""
            return frame.samples_per_channel / frame.sample_rate

        def _append_pre_roll(frame: rtc.AudioFrame) -> None:
            """Add ``frame`` to the pre-roll, evicting the oldest over budget."""
            nonlocal pre_roll_duration
            pre_roll_frames.append(frame)
            pre_roll_duration += _frame_duration(frame)
            while pre_roll_duration > self._pre_speech_duration and pre_roll_frames:
                pre_roll_duration -= _frame_duration(pre_roll_frames.popleft())

        async def _enqueue_segment(frames: list[rtc.AudioFrame], *, forced: bool = False) -> None:
            """Emit END_OF_SPEECH and queue ``frames``; no-op when empty."""
            if not frames:
                return
            if forced:
                logger.info(
                    "Forcing ASR segment after %.2fs of continuous speech",
                    self._max_speech_duration,
                )
            self._event_ch.send_nowait(stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH))
            await recognize_queue.put(frames)

        async def _recognize_worker() -> None:
            """Recognize queued segments one at a time until the sentinel."""
            while True:
                frames = await recognize_queue.get()
                if frames is None:
                    return
                merged_frames = utils.merge_frames(frames)
                try:
                    t_event = await self._wrapped_stt.recognize(
                        buffer=merged_frames,
                        language=self._language,
                        conn_options=self._wrapped_stt_conn_options,
                    )
                except Exception:
                    # Best effort: a failed segment is logged and dropped so
                    # later segments are still processed.
                    logger.exception("ASR segment recognition failed")
                    continue
                if not t_event.alternatives or not t_event.alternatives[0].text:
                    continue
                self._event_ch.send_nowait(
                    stt.SpeechEvent(
                        type=stt.SpeechEventType.FINAL_TRANSCRIPT,
                        alternatives=[t_event.alternatives[0]],
                    )
                )

        async def _forward_input() -> None:
            """Feed the VAD and accumulate frames; cut over-long segments."""
            nonlocal segment_duration, segment_frames, speech_active
            async for input_frame in self._input_ch:
                if isinstance(input_frame, self._FlushSentinel):
                    vad_stream.flush()
                    continue

                vad_stream.push_frame(input_frame)
                forced_frames: list[rtc.AudioFrame] = []
                async with lock:
                    if speech_active:
                        segment_frames.append(input_frame)
                        segment_duration += _frame_duration(input_frame)
                        if (
                            self._max_speech_duration is not None
                            and segment_duration >= self._max_speech_duration
                        ):
                            # Hand the segment off; speech stays active so the
                            # following frames start a new segment.
                            forced_frames = segment_frames
                            segment_frames = []
                            segment_duration = 0.0
                    else:
                        _append_pre_roll(input_frame)
                # Queueing happens outside the lock.
                if forced_frames:
                    await _enqueue_segment(forced_frames, forced=True)

            vad_stream.end_input()

            # Input ended mid-speech: submit whatever has been collected.
            final_frames: list[rtc.AudioFrame] = []
            async with lock:
                if speech_active and segment_frames:
                    final_frames = segment_frames
                    segment_frames = []
                    segment_duration = 0.0
                    speech_active = False
            if final_frames:
                await _enqueue_segment(final_frames)

        async def _recognize_from_vad() -> None:
            """Open and close segments according to VAD events."""
            nonlocal pre_roll_duration, segment_duration, segment_frames, speech_active
            async for event in vad_stream:
                if event.type == VADEventType.START_OF_SPEECH:
                    self._event_ch.send_nowait(
                        stt.SpeechEvent(stt.SpeechEventType.START_OF_SPEECH)
                    )
                    async with lock:
                        if not speech_active:
                            # Seed the new segment with the pre-roll audio.
                            speech_active = True
                            segment_frames = list(pre_roll_frames)
                            segment_duration = sum(_frame_duration(f) for f in segment_frames)
                            pre_roll_frames.clear()
                            pre_roll_duration = 0.0
                    continue

                if event.type != VADEventType.END_OF_SPEECH:
                    continue

                async with lock:
                    frames = segment_frames
                    segment_frames = []
                    segment_duration = 0.0
                    speech_active = False
                await _enqueue_segment(frames)

        worker_task = asyncio.create_task(_recognize_worker(), name="bounded_asr_recognize")
        tasks = [
            asyncio.create_task(_forward_input(), name="bounded_asr_forward_input"),
            asyncio.create_task(_recognize_from_vad(), name="bounded_asr_vad"),
        ]
        try:
            await asyncio.gather(*tasks)
            # Producers are done: let the worker drain the queue, then stop.
            await recognize_queue.put(None)
            await worker_task
        finally:
            try:
                await utils.aio.cancel_and_wait(*tasks, worker_task)
            finally:
                # Close the VAD stream even if waiting on the tasks is
                # interrupted (e.g. by cancellation).
                await vad_stream.aclose()