"""LiveKit Agents TTS adapter that posts text to an HTTP synthesis endpoint.

The endpoint is treated as a black box: text, model name and extra parameters
are sent as multipart form fields, and the response body is forwarded to the
audio emitter as it arrives.
"""

from __future__ import annotations

import asyncio
import logging
import os
import time
import wave
from collections.abc import Mapping
from io import BytesIO
from typing import BinaryIO, Optional

import aiohttp
from livekit.agents import (
    DEFAULT_API_CONNECT_OPTIONS,
    APIConnectionError,
    APIConnectOptions,
    APIStatusError,
    APITimeoutError,
    tts,
    utils,
)

logger = logging.getLogger("blackbox-tts")

# Stop trying to parse a WAV header once this many response bytes were buffered.
_WAV_PROBE_LIMIT_BYTES = 4096

# Content-type prefixes treated as WAV for the diagnostic header probe.
# "audio/wav" also matches "audio/wave" because the check is a prefix test.
_WAV_CONTENT_TYPE_PREFIXES = ("audio/wav", "audio/x-wav", "audio/vnd.wave")


class BlackboxTTS(tts.TTS):
    """Non-streaming TTS backed by a single HTTP POST per synthesis request."""

    def __init__(
        self,
        *,
        url: str,
        model_name: str = "voxcpmtts",
        params: Optional[Mapping[str, object]] = None,
        prompt_wav_path: Optional[str] = None,
        prompt_wav_field: str = "prompt_wav",
        sample_rate: int = 16000,
        num_channels: int = 1,
        timeout: float = 60.0,
        http_session: Optional[aiohttp.ClientSession] = None,
    ) -> None:
        """Store the endpoint configuration.

        Args:
            url: Endpoint that receives the multipart POST.
            model_name: Value sent in the ``model_name`` form field.
            params: Extra form fields; values are stringified once here
                (booleans become ``"true"``/``"false"``).
            prompt_wav_path: Optional path of a WAV file uploaded with each
                request.
            prompt_wav_field: Form field name used for the uploaded file.
            sample_rate: Sample rate reported to the base class and the
                audio emitter.
            num_channels: Channel count reported to the base class and the
                audio emitter.
            timeout: Total request timeout in seconds.
            http_session: Session to use; when omitted one is obtained from
                ``utils.http_context`` on first use.
        """
        super().__init__(
            capabilities=tts.TTSCapabilities(streaming=False),
            sample_rate=sample_rate,
            num_channels=num_channels,
        )
        self._url = url
        self._model_name = model_name
        self._params = {key: _form_value(value) for key, value in (params or {}).items()}
        self._prompt_wav_path = prompt_wav_path
        self._prompt_wav_field = prompt_wav_field
        self._timeout = timeout
        self._http_session = http_session

    @property
    def model(self) -> str:
        """Model name sent with every request."""
        return self._model_name

    @property
    def provider(self) -> str:
        """Fixed provider identifier for this adapter."""
        return "tts-blackbox"

    def _ensure_session(self) -> aiohttp.ClientSession:
        """Return the configured session, fetching the shared one lazily."""
        if self._http_session is None:
            self._http_session = utils.http_context.http_session()
        return self._http_session

    def synthesize(
        self,
        text: str,
        *,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> tts.ChunkedStream:
        """Create a chunked stream that synthesizes ``text``."""
        return BlackboxTTSStream(
            tts=self,
            input_text=text,
            conn_options=conn_options,
        )


class BlackboxTTSStream(tts.ChunkedStream):
    """Chunked stream that performs one POST and forwards the response body."""

    def __init__(
        self,
        *,
        tts: BlackboxTTS,
        input_text: str,
        conn_options: APIConnectOptions,
    ) -> None:
        super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
        self._tts: BlackboxTTS = tts

    def _open_prompt_file(self) -> Optional[BinaryIO]:
        """Open the configured prompt WAV, or return ``None`` if unavailable.

        A missing file is logged and skipped rather than treated as an error.
        Opening directly (instead of checking existence first) avoids a race
        between the check and the open.
        """
        path = self._tts._prompt_wav_path
        if not path:
            return None
        try:
            # NOTE: open() is a synchronous call made from the async _run().
            return open(path, "rb")
        except FileNotFoundError:
            logger.warning(
                "Prompt wav file not found at %s, skipping prompt_wav field",
                path,
            )
            return None

    def _build_form(self, prompt_file: Optional[BinaryIO]) -> aiohttp.FormData:
        """Assemble the multipart form: text, model name, params, prompt file."""
        form = aiohttp.FormData(default_to_multipart=True)
        form.add_field("text", self.input_text)
        form.add_field("model_name", self._tts._model_name)
        for key, value in self._tts._params.items():
            form.add_field(key, value)
        if prompt_file is not None:
            form.add_field(
                self._tts._prompt_wav_field,
                prompt_file,
                filename=os.path.basename(self._tts._prompt_wav_path),
                content_type="audio/wav",
            )
        return form

    async def _run(self, output_emitter: tts.AudioEmitter) -> None:
        """POST the request and push every response chunk to the emitter.

        Raises:
            APIStatusError: The endpoint answered with a non-200 status.
            APITimeoutError: The request timed out.
            APIConnectionError: aiohttp reported a client error.
        """
        started_at = time.perf_counter()
        prompt_file = self._open_prompt_file()
        try:
            # Built inside the try so the prompt file is closed even if
            # assembling the form fails.
            form = self._build_form(prompt_file)
            async with self._tts._ensure_session().post(
                self._tts._url,
                data=form,
                timeout=aiohttp.ClientTimeout(
                    total=self._tts._timeout,
                    sock_connect=self._conn_options.timeout,
                ),
            ) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    raise APIStatusError(
                        message=f"TTS blackbox error: {error_text}",
                        status_code=resp.status,
                        request_id=None,
                        body=error_text,
                    )

                content_type = resp.headers.get("Content-Type", "audio/wav")
                probe_done = False
                wav_header_probe = bytearray()
                first_audio_at: Optional[float] = None
                chunk_count = 0
                total_bytes = 0

                output_emitter.initialize(
                    request_id=utils.shortuuid(),
                    sample_rate=self._tts.sample_rate,
                    num_channels=self._tts.num_channels,
                    mime_type=content_type,
                )

                async for data, _ in resp.content.iter_chunks():
                    if not data:
                        continue
                    chunk_count += 1
                    total_bytes += len(data)

                    if first_audio_at is None:
                        first_audio_at = time.perf_counter()
                        logger.info(
                            "TTS first audio chunk after %.3fs (text_len=%s, bytes=%s)",
                            first_audio_at - started_at,
                            len(self.input_text),
                            len(data),
                        )

                    if not probe_done:
                        # Diagnostic only: accumulate leading bytes until the
                        # WAV header parses or the probe limit is exceeded.
                        wav_header_probe.extend(data)
                        probe_done = _log_wav_format(
                            bytes(wav_header_probe),
                            requested_sample_rate=self._tts.sample_rate,
                            requested_channels=self._tts.num_channels,
                            content_type=content_type,
                        )
                        if not probe_done and len(wav_header_probe) > _WAV_PROBE_LIMIT_BYTES:
                            logger.info(
                                "TTS blackbox WAV format probe incomplete after %s bytes",
                                len(wav_header_probe),
                            )
                            probe_done = True

                    output_emitter.push(data)

                output_emitter.flush()

                finished_at = time.perf_counter()
                logger.info(
                    "TTS stream completed in %.3fs (first_chunk=%.3fs, chunks=%s, bytes=%s, text_len=%s)",
                    finished_at - started_at,
                    # Compare against None: a float timestamp must not be
                    # judged by truthiness.
                    (first_audio_at - started_at) if first_audio_at is not None else -1.0,
                    chunk_count,
                    total_bytes,
                    len(self.input_text),
                )
        except asyncio.TimeoutError as e:
            raise APITimeoutError("TTS blackbox request timed out") from e
        except aiohttp.ClientError as e:
            raise APIConnectionError(f"TTS blackbox connection error: {e}") from e
        finally:
            if prompt_file is not None:
                prompt_file.close()


def _form_value(value: object) -> str:
    """Convert a parameter to its form-field string; bools become lowercase."""
    if isinstance(value, bool):
        return str(value).lower()
    return str(value)


def _log_wav_format(
    data: bytes,
    *,
    requested_sample_rate: int,
    requested_channels: int,
    content_type: str,
) -> bool:
    """Log the WAV header found in ``data`` next to the requested format.

    Args:
        data: Leading bytes of the response body received so far.
        requested_sample_rate: Sample rate configured on the TTS instance.
        requested_channels: Channel count configured on the TTS instance.
        content_type: ``Content-Type`` reported by the endpoint.

    Returns:
        ``True`` when probing is finished (the header was logged, or the
        content type is not WAV and only that was logged); ``False`` when
        ``data`` does not yet contain a parseable header.
    """
    if not content_type.lower().startswith(_WAV_CONTENT_TYPE_PREFIXES):
        logger.info("TTS blackbox returned content-type=%s", content_type)
        return True

    try:
        with wave.open(BytesIO(data), "rb") as wav:
            sample_rate = wav.getframerate()
            channels = wav.getnchannels()
            sample_width = wav.getsampwidth()
    except (EOFError, wave.Error):
        # Header incomplete or not parseable yet; the caller may retry with
        # more bytes.
        return False

    logger.info(
        "TTS blackbox WAV format: %sHz, %sch, %s-bit; output target: %sHz, %sch",
        sample_rate,
        channels,
        sample_width * 8,
        requested_sample_rate,
        requested_channels,
    )
    return True