Files
livekit_agents/tts.py
2026-05-11 11:22:01 +08:00

202 lines
6.5 KiB
Python

from __future__ import annotations
import asyncio
import logging
import os
import wave
from collections.abc import Mapping
from io import BytesIO
from typing import Optional
import aiohttp
from livekit.agents import (
DEFAULT_API_CONNECT_OPTIONS,
APIConnectionError,
APIConnectOptions,
APIStatusError,
APITimeoutError,
tts,
utils,
)
# Module-level logger shared by the TTS adapter, its stream, and the WAV-probe
# helper below; all records are emitted under the "blackbox-tts" name.
logger = logging.getLogger("blackbox-tts")
class BlackboxTTS(tts.TTS):
    """Non-streaming TTS adapter that sends text to an HTTP "blackbox" endpoint.

    The instance only stores configuration; the HTTP request itself is issued
    by the ``BlackboxTTSStream`` returned from :meth:`synthesize`.
    """

    def __init__(
        self,
        *,
        url: str,
        model_name: str = "voxcpmtts",
        params: Optional[Mapping[str, object]] = None,
        prompt_wav_path: Optional[str] = None,
        prompt_wav_field: str = "prompt_wav",
        sample_rate: int = 16000,
        num_channels: int = 1,
        timeout: float = 60.0,
        http_session: Optional[aiohttp.ClientSession] = None,
    ) -> None:
        """Store the endpoint configuration.

        Args:
            url: Endpoint the synthesis request is POSTed to.
            model_name: Value sent in the ``model_name`` form field.
            params: Extra form fields; every value is converted to a string
                once, here, via ``_form_value``.
            prompt_wav_path: Optional path of a WAV file uploaded with each
                request.
            prompt_wav_field: Form field name used for the uploaded WAV.
            sample_rate: Sample rate reported to the base class and emitter.
            num_channels: Channel count reported to the base class and emitter.
            timeout: Total request timeout, in seconds.
            http_session: Optional session to reuse; when omitted, one is
                obtained lazily from ``utils.http_context``.
        """
        super().__init__(
            capabilities=tts.TTSCapabilities(streaming=False),
            sample_rate=sample_rate,
            num_channels=num_channels,
        )

        # Normalise the extra form fields to strings up front so the request
        # path can add them to the multipart form as-is.
        extra_fields = {}
        for field_name, raw_value in (params or {}).items():
            extra_fields[field_name] = _form_value(raw_value)

        self._url = url
        self._model_name = model_name
        self._params = extra_fields
        self._prompt_wav_path = prompt_wav_path
        self._prompt_wav_field = prompt_wav_field
        self._timeout = timeout
        self._http_session = http_session

    @property
    def model(self) -> str:
        """Model name sent with every request."""
        return self._model_name

    @property
    def provider(self) -> str:
        """Fixed provider identifier for this adapter."""
        return "tts-blackbox"

    def _ensure_session(self) -> aiohttp.ClientSession:
        """Return the HTTP session, fetching and caching one on first use."""
        session = self._http_session
        if session is None:
            session = utils.http_context.http_session()
            self._http_session = session
        return session

    def synthesize(
        self,
        text: str,
        *,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> tts.ChunkedStream:
        """Create the chunked stream that will synthesize ``text``."""
        stream = BlackboxTTSStream(
            tts=self,
            input_text=text,
            conn_options=conn_options,
        )
        return stream
class BlackboxTTSStream(tts.ChunkedStream):
    """Single synthesis request against the blackbox endpoint.

    Builds a multipart form (text, model name, extra params and an optional
    prompt WAV), POSTs it, and forwards the response body to the audio emitter
    chunk by chunk.
    """

    def __init__(
        self,
        *,
        tts: BlackboxTTS,
        input_text: str,
        conn_options: APIConnectOptions,
    ) -> None:
        super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
        # Re-annotated with the concrete type so the private BlackboxTTS
        # attributes read in _run are visible to type checkers.
        self._tts: BlackboxTTS = tts

    @staticmethod
    def _open_prompt_wav(path: str):
        """Open the prompt WAV for binary reading.

        Returns the open file object, or ``None`` (after logging a warning)
        when the file does not exist. The caller owns the returned handle and
        must close it.
        """
        if os.path.exists(path):
            try:
                return open(path, "rb")
            except FileNotFoundError:
                # The file disappeared between the existence check and the
                # open; fall through to the same warn-and-skip path.
                pass
        logger.warning(
            "Prompt wav file not found at %s, skipping prompt_wav field",
            path,
        )
        return None

    async def _run(self, output_emitter: tts.AudioEmitter) -> None:
        """Send the request and push the response audio into ``output_emitter``.

        Raises:
            APIStatusError: The endpoint answered with a non-200 status.
            APITimeoutError: The request hit ``asyncio.TimeoutError``.
            APIConnectionError: aiohttp reported a client error.
        """
        owner = self._tts

        form = aiohttp.FormData(default_to_multipart=True)
        form.add_field("text", self.input_text)
        form.add_field("model_name", owner._model_name)
        for key, value in owner._params.items():
            form.add_field(key, value)

        prompt_file = None
        try:
            # The prompt file is opened *inside* the try block so that the
            # finally clause closes it even if attaching it to the form (or
            # anything later) raises.
            prompt_path = owner._prompt_wav_path
            if prompt_path:
                prompt_file = self._open_prompt_wav(prompt_path)
            if prompt_file is not None:
                form.add_field(
                    owner._prompt_wav_field,
                    prompt_file,
                    filename=os.path.basename(prompt_path),
                    content_type="audio/wav",
                )

            request_timeout = aiohttp.ClientTimeout(
                total=owner._timeout,
                sock_connect=self._conn_options.timeout,
            )
            async with owner._ensure_session().post(
                owner._url,
                data=form,
                timeout=request_timeout,
            ) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    raise APIStatusError(
                        message=f"TTS blackbox error: {error_text}",
                        status_code=resp.status,
                        request_id=None,
                        body=error_text,
                    )

                # NOTE(review): the raw Content-Type header (including any
                # parameters) is forwarded as mime_type — confirm the emitter
                # accepts parameterised values.
                content_type = resp.headers.get("Content-Type", "audio/wav")
                output_emitter.initialize(
                    request_id=utils.shortuuid(),
                    sample_rate=owner.sample_rate,
                    num_channels=owner.num_channels,
                    mime_type=content_type,
                )

                # Accumulate leading bytes only until the WAV header has been
                # logged (or the probe gives up); the audio itself is pushed
                # unmodified either way.
                logged_wav_format = False
                wav_header_probe = bytearray()
                async for data, _ in resp.content.iter_chunks():
                    if not data:
                        continue
                    if not logged_wav_format:
                        wav_header_probe.extend(data)
                        logged_wav_format = _log_wav_format(
                            bytes(wav_header_probe),
                            requested_sample_rate=owner.sample_rate,
                            requested_channels=owner.num_channels,
                            content_type=content_type,
                        )
                        if not logged_wav_format and len(wav_header_probe) > 4096:
                            # Stop probing once more than 4096 bytes have been
                            # buffered without a parseable header.
                            logger.info(
                                "TTS blackbox WAV format probe incomplete after %s bytes",
                                len(wav_header_probe),
                            )
                            logged_wav_format = True
                    output_emitter.push(data)
                output_emitter.flush()
        except asyncio.TimeoutError as e:
            raise APITimeoutError("TTS blackbox request timed out") from e
        except aiohttp.ClientError as e:
            raise APIConnectionError(f"TTS blackbox connection error: {e}") from e
        finally:
            if prompt_file is not None:
                prompt_file.close()
def _form_value(value: object) -> str:
if isinstance(value, bool):
return str(value).lower()
return str(value)
def _log_wav_format(
data: bytes,
*,
requested_sample_rate: int,
requested_channels: int,
content_type: str,
) -> bool:
if not content_type.lower().startswith("audio/wav"):
logger.info("TTS blackbox returned content-type=%s", content_type)
return True
try:
with wave.open(BytesIO(data), "rb") as wav:
sample_rate = wav.getframerate()
channels = wav.getnchannels()
sample_width = wav.getsampwidth()
except (EOFError, wave.Error):
return False
logger.info(
"TTS blackbox WAV format: %sHz, %sch, %s-bit; output target: %sHz, %sch",
sample_rate,
channels,
sample_width * 8,
requested_sample_rate,
requested_channels,
)
return True