202 lines
6.5 KiB
Python
202 lines
6.5 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import os
|
|
import wave
|
|
from collections.abc import Mapping
|
|
from io import BytesIO
|
|
from typing import Optional
|
|
|
|
import aiohttp
|
|
|
|
from livekit.agents import (
|
|
DEFAULT_API_CONNECT_OPTIONS,
|
|
APIConnectionError,
|
|
APIConnectOptions,
|
|
APIStatusError,
|
|
APITimeoutError,
|
|
tts,
|
|
utils,
|
|
)
|
|
|
|
# Module-level logger; the warning and info messages in this file are emitted under the "blackbox-tts" name.
logger = logging.getLogger("blackbox-tts")
|
|
|
|
|
|
class BlackboxTTS(tts.TTS):
    """Non-streaming TTS front end for an HTTP "blackbox" synthesis endpoint.

    The instance only holds configuration; the actual request is performed by
    the ``BlackboxTTSStream`` returned from :meth:`synthesize`.
    """

    def __init__(
        self,
        *,
        url: str,
        model_name: str = "voxcpmtts",
        params: Optional[Mapping[str, object]] = None,
        prompt_wav_path: Optional[str] = None,
        prompt_wav_field: str = "prompt_wav",
        sample_rate: int = 16000,
        num_channels: int = 1,
        timeout: float = 60.0,
        http_session: Optional[aiohttp.ClientSession] = None,
    ) -> None:
        """Store the endpoint configuration.

        Args:
            url: Endpoint the synthesis request is POSTed to.
            model_name: Value sent in the ``model_name`` form field.
            params: Extra form fields; each value is converted to a string
                with ``_form_value`` once, at construction time.
            prompt_wav_path: Optional path of a WAV file uploaded with every
                request.
            prompt_wav_field: Form field name used for the uploaded WAV file.
            sample_rate: Sample rate reported to the base class and to the
                audio emitter.
            num_channels: Channel count reported to the base class and to the
                audio emitter.
            timeout: Total request timeout in seconds.
            http_session: Optional session to reuse; when omitted one is
                fetched lazily by ``_ensure_session``.
        """
        super().__init__(
            capabilities=tts.TTSCapabilities(streaming=False),
            sample_rate=sample_rate,
            num_channels=num_channels,
        )
        self._url = url
        self._model_name = model_name

        # Form fields must be strings, so normalise the extras up front.
        extra_fields = {}
        for name, raw in (params or {}).items():
            extra_fields[name] = _form_value(raw)
        self._params = extra_fields

        self._prompt_wav_path = prompt_wav_path
        self._prompt_wav_field = prompt_wav_field
        self._timeout = timeout
        self._http_session = http_session

    @property
    def model(self) -> str:
        """Model name sent with every request."""
        return self._model_name

    @property
    def provider(self) -> str:
        """Fixed provider identifier for this plugin."""
        return "tts-blackbox"

    def _ensure_session(self) -> aiohttp.ClientSession:
        """Return the HTTP session, fetching and caching one on first use."""
        session = self._http_session
        if session is None:
            session = utils.http_context.http_session()
            self._http_session = session
        return session

    def synthesize(
        self,
        text: str,
        *,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> tts.ChunkedStream:
        """Create the chunked stream that will synthesize ``text``."""
        return BlackboxTTSStream(tts=self, input_text=text, conn_options=conn_options)
|
|
|
|
|
|
class BlackboxTTSStream(tts.ChunkedStream):
    """Single synthesis request: POST the form, then forward the audio chunks."""

    # Stop trying to parse a WAV header once more than this many bytes were buffered.
    _WAV_PROBE_LIMIT = 4096

    def __init__(
        self,
        *,
        tts: BlackboxTTS,
        input_text: str,
        conn_options: APIConnectOptions,
    ) -> None:
        """Bind the stream to its owning ``BlackboxTTS`` and the text to speak."""
        super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
        self._tts: BlackboxTTS = tts

    async def _run(self, output_emitter: tts.AudioEmitter) -> None:
        """Send the synthesis request and push the response audio to the emitter.

        Raises:
            APIStatusError: The endpoint answered with a status other than 200.
            APITimeoutError: The request timed out.
            APIConnectionError: aiohttp reported a client/connection error.
        """
        prompt_file = self._open_prompt_file()
        # The handle is guarded from the moment it exists, so a failure while
        # building the form can no longer leak it.
        try:
            form = self._build_form(prompt_file)
            await self._post_and_emit(form, output_emitter)
        finally:
            if prompt_file is not None:
                prompt_file.close()

    def _open_prompt_file(self):
        """Open the configured prompt WAV for reading, or return None if absent.

        A missing file is not an error: a warning is logged and the request is
        sent without the prompt field.
        """
        path = self._tts._prompt_wav_path
        if not path:
            return None
        if os.path.exists(path):
            try:
                return open(path, "rb")
            except FileNotFoundError:
                # Removed between the existence check and open(); treat it
                # exactly like a file that was never there.
                pass
        logger.warning(
            "Prompt wav file not found at %s, skipping prompt_wav field",
            path,
        )
        return None

    def _build_form(self, prompt_file) -> aiohttp.FormData:
        """Assemble the multipart form: text, model name, extras, optional prompt."""
        form = aiohttp.FormData(default_to_multipart=True)
        form.add_field("text", self.input_text)
        form.add_field("model_name", self._tts._model_name)
        for key, value in self._tts._params.items():
            form.add_field(key, value)
        if prompt_file is not None:
            form.add_field(
                self._tts._prompt_wav_field,
                prompt_file,
                filename=os.path.basename(self._tts._prompt_wav_path),
                content_type="audio/wav",
            )
        return form

    async def _post_and_emit(
        self, form: aiohttp.FormData, output_emitter: tts.AudioEmitter
    ) -> None:
        """POST the form and stream the response body into ``output_emitter``."""
        request_timeout = aiohttp.ClientTimeout(
            total=self._tts._timeout,
            sock_connect=self._conn_options.timeout,
        )
        # Only the HTTP exchange is wrapped, so local file errors are never
        # misreported as API timeouts or connection errors.
        try:
            async with self._tts._ensure_session().post(
                self._tts._url,
                data=form,
                timeout=request_timeout,
            ) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    raise APIStatusError(
                        message=f"TTS blackbox error: {error_text}",
                        status_code=resp.status,
                        request_id=None,
                        body=error_text,
                    )

                content_type = resp.headers.get("Content-Type", "audio/wav")
                output_emitter.initialize(
                    request_id=utils.shortuuid(),
                    sample_rate=self._tts.sample_rate,
                    num_channels=self._tts.num_channels,
                    mime_type=content_type,
                )

                # Leading bytes are buffered only until the format was logged
                # once (or the probe limit was exceeded); audio is always
                # forwarded unchanged.
                logged_wav_format = False
                wav_header_probe = bytearray()
                async for data, _ in resp.content.iter_chunks():
                    if not data:
                        continue
                    if not logged_wav_format:
                        wav_header_probe.extend(data)
                        logged_wav_format = _log_wav_format(
                            bytes(wav_header_probe),
                            requested_sample_rate=self._tts.sample_rate,
                            requested_channels=self._tts.num_channels,
                            content_type=content_type,
                        )
                        if (
                            not logged_wav_format
                            and len(wav_header_probe) > self._WAV_PROBE_LIMIT
                        ):
                            logger.info(
                                "TTS blackbox WAV format probe incomplete after %s bytes",
                                len(wav_header_probe),
                            )
                            logged_wav_format = True
                    output_emitter.push(data)
                output_emitter.flush()
        except asyncio.TimeoutError as e:
            raise APITimeoutError("TTS blackbox request timed out") from e
        except aiohttp.ClientError as e:
            raise APIConnectionError(f"TTS blackbox connection error: {e}") from e
|
|
|
|
|
|
def _form_value(value: object) -> str:
|
|
if isinstance(value, bool):
|
|
return str(value).lower()
|
|
return str(value)
|
|
|
|
|
|
def _log_wav_format(
|
|
data: bytes,
|
|
*,
|
|
requested_sample_rate: int,
|
|
requested_channels: int,
|
|
content_type: str,
|
|
) -> bool:
|
|
if not content_type.lower().startswith("audio/wav"):
|
|
logger.info("TTS blackbox returned content-type=%s", content_type)
|
|
return True
|
|
|
|
try:
|
|
with wave.open(BytesIO(data), "rb") as wav:
|
|
sample_rate = wav.getframerate()
|
|
channels = wav.getnchannels()
|
|
sample_width = wav.getsampwidth()
|
|
except (EOFError, wave.Error):
|
|
return False
|
|
|
|
logger.info(
|
|
"TTS blackbox WAV format: %sHz, %sch, %s-bit; output target: %sHz, %sch",
|
|
sample_rate,
|
|
channels,
|
|
sample_width * 8,
|
|
requested_sample_rate,
|
|
requested_channels,
|
|
)
|
|
return True
|