import asyncio
import logging
from typing import Any, Optional, Union

import aiohttp
from livekit import rtc
from livekit.agents import (
    NOT_GIVEN,
    APIConnectionError,
    APIConnectOptions,
    APIStatusError,
    APITimeoutError,
    LanguageCode,
    NotGivenOr,
    stt,
    utils,
)
from livekit.agents.utils import is_given

logger = logging.getLogger("blackbox-asr")

# Keys probed, in priority order, on the first entry of a ``result`` list.
_RESULT_TEXT_KEYS = ("clean_text", "text", "raw_text")


class BlackboxSTT(stt.STT):
    """Non-streaming STT that uploads audio to an HTTP ASR endpoint.

    Each recognition request combines the buffered audio frames into a WAV
    payload, POSTs it as multipart form data to ``url`` and converts the JSON
    reply into a single final transcript event.
    """

    def __init__(
        self,
        url: str,
        *,
        model_name: str = "sensevoice",
        language: Optional[str] = "auto",
        output_language: str = "zh",
        hotwords: Optional[str] = None,
        itn: Optional[Union[bool, str]] = None,
        chunk_mode: Optional[Union[bool, str]] = None,
        timeout: float = 30.0,
        http_session: Optional[aiohttp.ClientSession] = None,
    ) -> None:
        """Configure the client.

        Args:
            url: Endpoint the audio is POSTed to.
            model_name: Value sent in the ``model_name`` form field.
            language: Default value for the ``language`` form field; the
                field is omitted when this is falsy and no per-call language
                is given.
            output_language: Language code attached to the returned
                transcript.
            hotwords: Sent as the ``hotwords`` form field when non-empty.
            itn: Sent as the ``itn`` form field when not ``None``; booleans
                are serialized as ``"true"`` / ``"false"``.
            chunk_mode: Sent as the ``chunk_mode`` form field when not
                ``None``; booleans are serialized like ``itn``.
            timeout: Total request timeout passed to
                ``aiohttp.ClientTimeout``.
            http_session: Session to use; when omitted, one is obtained
                lazily from ``utils.http_context.http_session()``.
        """
        super().__init__(
            capabilities=stt.STTCapabilities(
                streaming=False,
                interim_results=False,
                diarization=False,
            )
        )
        self._url = url
        self._model_name = model_name
        self._language = language
        self._output_language = output_language
        self._timeout = timeout
        self._http_session = http_session

        # Optional form fields, pre-serialized once so every request can
        # simply copy them into the multipart body.
        self._extra_fields: dict[str, str] = {}
        if hotwords:
            self._extra_fields["hotwords"] = hotwords
        if itn is not None:
            self._extra_fields["itn"] = _form_value(itn)
        if chunk_mode is not None:
            self._extra_fields["chunk_mode"] = _form_value(chunk_mode)

    @property
    def model(self) -> str:
        """Return the configured model name."""
        return self._model_name

    @property
    def provider(self) -> str:
        """Return the fixed provider identifier."""
        return "asr-blackbox"

    def _ensure_session(self) -> aiohttp.ClientSession:
        """Return the HTTP session, fetching the shared one on first use."""
        if self._http_session is None:
            self._http_session = utils.http_context.http_session()
        return self._http_session

    async def _recognize_impl(
        self,
        buffer: utils.AudioBuffer,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions,
    ) -> stt.SpeechEvent:
        """Send ``buffer`` to the ASR endpoint and return the final transcript.

        Args:
            buffer: Audio frames to transcribe; they are merged into one WAV
                payload.
            language: Per-call override for the ``language`` form field.
            conn_options: Its ``timeout`` is used as the socket connect
                timeout.

        Returns:
            A ``FINAL_TRANSCRIPT`` event with one alternative tagged with the
            configured output language.

        Raises:
            APIStatusError: The endpoint answered with a non-200 status.
            APITimeoutError: The request timed out.
            APIConnectionError: Transport failure, undecodable or unsupported
                response body, or an empty transcript.
        """
        audio_data = rtc.combine_audio_frames(buffer).to_wav_bytes()

        form = aiohttp.FormData()
        form.add_field("audio", audio_data, filename="audio.wav", content_type="audio/wav")
        form.add_field("model_name", self._model_name)

        resolved_language = language if is_given(language) else self._language
        if resolved_language:
            form.add_field("language", resolved_language)

        for key, value in self._extra_fields.items():
            form.add_field(key, value)

        try:
            async with self._ensure_session().post(
                self._url,
                data=form,
                timeout=aiohttp.ClientTimeout(
                    total=self._timeout,
                    sock_connect=conn_options.timeout,
                ),
            ) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    raise APIStatusError(
                        message=f"ASR blackbox error: {error_text}",
                        status_code=resp.status,
                        request_id=None,
                        body=error_text,
                    )

                try:
                    payload = await resp.json()
                except ValueError as e:
                    # JSONDecodeError / UnicodeDecodeError are ValueErrors,
                    # not aiohttp.ClientError, so the outer handlers would
                    # let them escape as raw exceptions.
                    raise APIConnectionError("ASR blackbox returned invalid JSON") from e

                logger.info("ASR blackbox raw result: %s", payload)

                text = _extract_asr_text(payload)
                if not text:
                    # NOTE(review): an empty transcript is surfaced as an
                    # error; if the framework retries APIConnectionError,
                    # silent audio would be re-sent -- confirm this is
                    # intended.
                    raise APIConnectionError("ASR blackbox returned an empty transcript")

                logger.info("ASR blackbox result: %s", text)
                return stt.SpeechEvent(
                    type=stt.SpeechEventType.FINAL_TRANSCRIPT,
                    alternatives=[
                        stt.SpeechData(
                            text=text,
                            language=LanguageCode(self._output_language),
                        )
                    ],
                )
        except asyncio.TimeoutError as e:
            raise APITimeoutError("ASR blackbox request timed out") from e
        except aiohttp.ClientError as e:
            raise APIConnectionError(f"ASR blackbox connection error: {e}") from e


def _extract_asr_text(payload: Any) -> str:
    """Pull the transcript out of a decoded ASR response.

    Accepted shapes, tried in order:

    * ``{"text": "..."}``
    * ``{"result": [{"clean_text" | "text" | "raw_text": "..."}, ...]}``
      (first entry only, first non-blank key wins)
    * ``{"result": ["...", ...]}`` (first entry only)

    Args:
        payload: Decoded JSON body of the response.

    Returns:
        The stripped transcript. An empty string is returned when a
        recognized text field was present but every candidate was blank.

    Raises:
        APIConnectionError: ``payload`` is not a JSON object or matches none
            of the accepted shapes.
    """
    if not isinstance(payload, dict):
        raise APIConnectionError(f"Unsupported ASR blackbox response: {payload}")

    # Tracks whether any recognized text field existed, so a blank
    # transcript can be told apart from an unknown response layout.
    saw_text_field = False

    text = payload.get("text")
    if isinstance(text, str):
        stripped = text.strip()
        if stripped:
            return stripped
        # Blank top-level text: still give ``result`` a chance below.
        saw_text_field = True

    result = payload.get("result")
    if isinstance(result, list) and result:
        first = result[0]
        if isinstance(first, dict):
            for key in _RESULT_TEXT_KEYS:
                value = first.get(key)
                if isinstance(value, str):
                    stripped = value.strip()
                    if stripped:
                        return stripped
                    saw_text_field = True
        elif isinstance(first, str):
            return first.strip()

    if saw_text_field:
        return ""
    raise APIConnectionError(f"Unsupported ASR blackbox response: {payload}")


def _form_value(value: Union[bool, str]) -> str:
    """Serialize an option for a form field.

    Booleans become ``"true"`` / ``"false"``; strings pass through unchanged.
    """
    if isinstance(value, bool):
        return str(value).lower()
    return value