feat: supported different models
This commit is contained in:
201
tts.py
Normal file
201
tts.py
Normal file
@ -0,0 +1,201 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import wave
|
||||
from collections.abc import Mapping
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
from livekit.agents import (
|
||||
DEFAULT_API_CONNECT_OPTIONS,
|
||||
APIConnectionError,
|
||||
APIConnectOptions,
|
||||
APIStatusError,
|
||||
APITimeoutError,
|
||||
tts,
|
||||
utils,
|
||||
)
|
||||
|
||||
# Module-level logger; every diagnostic in this file is emitted under the
# "blackbox-tts" logger name.
logger = logging.getLogger("blackbox-tts")
|
||||
|
||||
|
||||
class BlackboxTTS(tts.TTS):
    """Non-streaming TTS backend that talks to an HTTP "blackbox" endpoint.

    The instance only stores configuration; the actual request is issued by
    the ``BlackboxTTSStream`` object returned from :meth:`synthesize`.
    """

    def __init__(
        self,
        *,
        url: str,
        model_name: str = "voxcpmtts",
        params: Optional[Mapping[str, object]] = None,
        prompt_wav_path: Optional[str] = None,
        prompt_wav_field: str = "prompt_wav",
        sample_rate: int = 16000,
        num_channels: int = 1,
        timeout: float = 60.0,
        http_session: Optional[aiohttp.ClientSession] = None,
    ) -> None:
        """Store the endpoint configuration.

        Args:
            url: Endpoint the synthesis request is POSTed to.
            model_name: Value sent in the ``model_name`` form field.
            params: Extra form fields; each value is converted to text with
                ``_form_value`` once, here, at construction time.
            prompt_wav_path: Optional path of a WAV file uploaded with each
                request.
            prompt_wav_field: Form field name used for that upload.
            sample_rate: Sample rate reported to the base class and to the
                audio emitter.
            num_channels: Channel count reported to the base class and to
                the audio emitter.
            timeout: Total request timeout in seconds.
            http_session: Optional session to reuse; when omitted one is
                obtained lazily from ``utils.http_context``.
        """
        super().__init__(
            capabilities=tts.TTSCapabilities(streaming=False),
            sample_rate=sample_rate,
            num_channels=num_channels,
        )

        # Pre-render the extra parameters so each request can add them to
        # the form without converting them again.
        form_params: dict[str, str] = {}
        if params:
            for field_name, raw_value in params.items():
                form_params[field_name] = _form_value(raw_value)

        self._url = url
        self._model_name = model_name
        self._params = form_params
        self._prompt_wav_path = prompt_wav_path
        self._prompt_wav_field = prompt_wav_field
        self._timeout = timeout
        self._http_session = http_session

    @property
    def model(self) -> str:
        """Model name sent with every request."""
        return self._model_name

    @property
    def provider(self) -> str:
        """Fixed provider identifier for this backend."""
        return "tts-blackbox"

    def _ensure_session(self) -> aiohttp.ClientSession:
        """Return the HTTP session, fetching and caching one on first use."""
        session = self._http_session
        if session is None:
            session = utils.http_context.http_session()
            self._http_session = session
        return session

    def synthesize(
        self,
        text: str,
        *,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> tts.ChunkedStream:
        """Create the chunked stream that will synthesize ``text``.

        Args:
            text: Text to convert to speech.
            conn_options: Connection options forwarded to the stream.

        Returns:
            A ``BlackboxTTSStream`` bound to this instance.
        """
        stream = BlackboxTTSStream(
            tts=self,
            input_text=text,
            conn_options=conn_options,
        )
        return stream
|
||||
|
||||
|
||||
class BlackboxTTSStream(tts.ChunkedStream):
    """Single synthesis request against the blackbox endpoint.

    The text, model name, extra parameters and optional prompt WAV are sent
    as one multipart POST; the response body is forwarded chunk by chunk to
    the audio emitter.
    """

    def __init__(
        self,
        *,
        tts: BlackboxTTS,
        input_text: str,
        conn_options: APIConnectOptions,
    ) -> None:
        """Bind the stream to its owning ``BlackboxTTS`` and input text."""
        super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
        self._tts: BlackboxTTS = tts

    async def _run(self, output_emitter: tts.AudioEmitter) -> None:
        """Build the multipart form, send it, and emit the returned audio.

        Args:
            output_emitter: Emitter that receives the response bytes.

        Raises:
            APIStatusError: The endpoint answered with a non-200 status.
            APITimeoutError: The request timed out.
            APIConnectionError: aiohttp reported a client/connection error.
        """
        form = aiohttp.FormData(default_to_multipart=True)
        form.add_field("text", self.input_text)
        form.add_field("model_name", self._tts._model_name)
        for key, value in self._tts._params.items():
            form.add_field(key, value)

        prompt_file = None
        # The file is opened inside the try so the finally below closes it
        # on every exit path, including a failure while attaching it to the
        # form (previously it was opened before the guarded region).
        try:
            prompt_path = self._tts._prompt_wav_path
            if prompt_path:
                if os.path.exists(prompt_path):
                    prompt_file = open(prompt_path, "rb")
                    form.add_field(
                        self._tts._prompt_wav_field,
                        prompt_file,
                        filename=os.path.basename(prompt_path),
                        content_type="audio/wav",
                    )
                else:
                    # Best effort: a missing prompt file does not abort the
                    # request, the field is simply left out.
                    logger.warning(
                        "Prompt wav file not found at %s, skipping prompt_wav field",
                        prompt_path,
                    )

            await self._post_and_emit(form, output_emitter)
        finally:
            if prompt_file is not None:
                prompt_file.close()

    async def _post_and_emit(
        self,
        form: aiohttp.FormData,
        output_emitter: tts.AudioEmitter,
    ) -> None:
        """POST ``form`` and push the response body into ``output_emitter``."""
        try:
            async with self._tts._ensure_session().post(
                self._tts._url,
                data=form,
                timeout=aiohttp.ClientTimeout(
                    total=self._tts._timeout,
                    sock_connect=self._conn_options.timeout,
                ),
            ) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    raise APIStatusError(
                        message=f"TTS blackbox error: {error_text}",
                        status_code=resp.status,
                        request_id=None,
                        body=error_text,
                    )

                # Fall back to WAV when the server sends no Content-Type.
                content_type = resp.headers.get("Content-Type", "audio/wav")
                output_emitter.initialize(
                    request_id=utils.shortuuid(),
                    sample_rate=self._tts.sample_rate,
                    num_channels=self._tts.num_channels,
                    mime_type=content_type,
                )

                # The leading bytes are accumulated only until the format
                # has been logged once (or the probe is abandoned).
                logged_wav_format = False
                wav_header_probe = bytearray()

                async for data, _ in resp.content.iter_chunks():
                    if not data:
                        continue
                    if not logged_wav_format:
                        wav_header_probe.extend(data)
                        logged_wav_format = _log_wav_format(
                            bytes(wav_header_probe),
                            requested_sample_rate=self._tts.sample_rate,
                            requested_channels=self._tts.num_channels,
                            content_type=content_type,
                        )
                        # Stop probing once more than 4096 bytes failed to
                        # yield a parseable header.
                        if not logged_wav_format and len(wav_header_probe) > 4096:
                            logger.info(
                                "TTS blackbox WAV format probe incomplete after %s bytes",
                                len(wav_header_probe),
                            )
                            logged_wav_format = True
                    output_emitter.push(data)

                output_emitter.flush()
        except asyncio.TimeoutError as e:
            raise APITimeoutError("TTS blackbox request timed out") from e
        except aiohttp.ClientError as e:
            raise APIConnectionError(f"TTS blackbox connection error: {e}") from e
|
||||
|
||||
|
||||
def _form_value(value: object) -> str:
|
||||
if isinstance(value, bool):
|
||||
return str(value).lower()
|
||||
return str(value)
|
||||
|
||||
|
||||
def _log_wav_format(
|
||||
data: bytes,
|
||||
*,
|
||||
requested_sample_rate: int,
|
||||
requested_channels: int,
|
||||
content_type: str,
|
||||
) -> bool:
|
||||
if not content_type.lower().startswith("audio/wav"):
|
||||
logger.info("TTS blackbox returned content-type=%s", content_type)
|
||||
return True
|
||||
|
||||
try:
|
||||
with wave.open(BytesIO(data), "rb") as wav:
|
||||
sample_rate = wav.getframerate()
|
||||
channels = wav.getnchannels()
|
||||
sample_width = wav.getsampwidth()
|
||||
except (EOFError, wave.Error):
|
||||
return False
|
||||
|
||||
logger.info(
|
||||
"TTS blackbox WAV format: %sHz, %sch, %s-bit; output target: %sHz, %sch",
|
||||
sample_rate,
|
||||
channels,
|
||||
sample_width * 8,
|
||||
requested_sample_rate,
|
||||
requested_channels,
|
||||
)
|
||||
return True
|
||||
Reference in New Issue
Block a user