initial commit
This commit is contained in:
118
tts_voxcpm.py
Normal file
118
tts_voxcpm.py
Normal file
@ -0,0 +1,118 @@
|
||||
import aiohttp
|
||||
import logging
|
||||
import os
|
||||
from livekit.agents import tts, utils, APIConnectOptions, DEFAULT_API_CONNECT_OPTIONS
|
||||
|
||||
# Module-level logger; the VoxCPM stream uses it to report missing prompt
# files, non-200 responses, and failed requests.
logger: logging.Logger = logging.getLogger("voxcpm-tts")
|
||||
|
||||
class VoxCPMTTS(tts.TTS):
    """Non-streaming TTS adapter that sends text to a VoxCPM HTTP endpoint.

    The constructor only records configuration; the HTTP request itself is
    performed by the ``VoxCPMStream`` returned from :meth:`synthesize`.
    """

    def __init__(
        self,
        *,
        url: str,
        model_name: str = "voxcpmtts",
        prompt_text: str = "澳门有乜嘢好食嘅",
        prompt_wav_path: str = "/home/verachen/Music/voice/2food16k_2.wav",
        cfg_value: str = "2.0",
        inference_timesteps: str = "10",
        do_normalize: str = "true",
        denoise: str = "true",
        retry_badcase: str = "true",
        retry_badcase_max_times: str = "3",
        retry_badcase_ratio_threshold: str = "6.0",
        sample_rate: int = 16000,
    ):
        """Store the endpoint URL and the form fields sent with every request.

        Args:
            url: Endpoint that synthesis requests are POSTed to.
            model_name: Sent as the ``model_name`` form field and exposed via
                the :attr:`model` property.
            prompt_text: Sent as the ``prompt_text`` form field.
            prompt_wav_path: Local path of the WAV file uploaded as the
                ``prompt_wav`` form field.
            cfg_value: Sent as the ``cfg_value`` form field.
            inference_timesteps: Sent as the ``inference_timesteps`` form field.
            do_normalize: Sent as the ``do_normalize`` form field.
            denoise: Sent as the ``denoise`` form field.
            retry_badcase: Sent as the ``retry_badcase`` form field.
            retry_badcase_max_times: Sent as the ``retry_badcase_max_times``
                form field.
            retry_badcase_ratio_threshold: Sent as the
                ``retry_badcase_ratio_threshold`` form field.
            sample_rate: Sample rate reported to the base ``tts.TTS`` class.

        The option fields are forwarded as strings; their exact meaning is
        defined by the server, not by this class. Non-string values are
        accepted and converted with :meth:`_form_value`.
        """
        super().__init__(
            capabilities=tts.TTSCapabilities(streaming=False),
            sample_rate=sample_rate,
            num_channels=1,
        )
        self._url = url
        self._opts = {
            "model_name": model_name,
            "streaming": "false",
            "prompt_text": prompt_text,
            "cfg_value": self._form_value(cfg_value),
            "inference_timesteps": self._form_value(inference_timesteps),
            "do_normalize": self._form_value(do_normalize),
            "denoise": self._form_value(denoise),
            "retry_badcase": self._form_value(retry_badcase),
            "retry_badcase_max_times": self._form_value(retry_badcase_max_times),
            "retry_badcase_ratio_threshold": self._form_value(
                retry_badcase_ratio_threshold
            ),
        }
        self._prompt_wav_path = prompt_wav_path

    @staticmethod
    def _form_value(value: object) -> str:
        """Convert an option value to the string sent in the form.

        Booleans become ``"true"`` / ``"false"`` so they match the lowercase
        spelling of the string defaults; plain ``str(True)`` would send
        ``"True"``. Every other value goes through ``str()``.
        """
        if isinstance(value, bool):
            return "true" if value else "false"
        return str(value)

    @property
    def model(self) -> str:
        """Return the configured ``model_name``."""
        return self._opts["model_name"]

    def synthesize(
        self,
        text: str,
        *,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> tts.ChunkedStream:
        """Create a ``VoxCPMStream`` that will synthesize ``text``.

        Args:
            text: Text to synthesize.
            conn_options: Connection options handed to the stream.

        Returns:
            A ``VoxCPMStream`` bound to this instance's URL, form options,
            and prompt WAV path.
        """
        return VoxCPMStream(
            self,
            text,
            self._url,
            self._opts,
            self._prompt_wav_path,
            conn_options=conn_options,
        )
|
||||
|
||||
class VoxCPMStream(tts.ChunkedStream):
    """Single synthesis request against a VoxCPM HTTP endpoint.

    The input text, the configured options, and (when available) a prompt
    WAV file are sent as one multipart form POST. The whole response body is
    pushed to the audio emitter as ``audio/wav`` data.
    """

    def __init__(
        self,
        tts: VoxCPMTTS,
        text: str,
        url: str,
        opts: dict,
        prompt_wav_path: str,
        conn_options: APIConnectOptions,
    ):
        """Bind the stream to its request parameters.

        Args:
            tts: The owning TTS instance, passed to the base class.
            text: Text to synthesize, passed to the base class as input text.
            url: Endpoint the request is POSTed to.
            opts: Form fields (name -> string value) added to the request.
            prompt_wav_path: Local path of the prompt WAV file to upload.
            conn_options: Connection options passed to the base class.
        """
        super().__init__(tts=tts, input_text=text, conn_options=conn_options)
        self._url = url
        self._opts = opts
        self._prompt_wav_path = prompt_wav_path

    def _open_prompt_wav(self):
        """Open the prompt WAV for upload; return ``None`` if it cannot be opened.

        Opening directly (instead of checking ``os.path.exists`` first) avoids
        the window in which the file can disappear between check and open.
        The caller owns the returned file object and must close it.
        """
        try:
            return open(self._prompt_wav_path, "rb")
        except FileNotFoundError:
            logger.warning(
                f"Prompt wav file not found at {self._prompt_wav_path}, skipping prompt_wav field"
            )
        except OSError as e:
            # Unreadable path (permissions, is a directory, ...): treated the
            # same as a missing file so synthesis can still be attempted.
            logger.warning(
                f"Prompt wav file at {self._prompt_wav_path} could not be opened ({e}), "
                "skipping prompt_wav field"
            )
        return None

    async def _run(self, output_emitter: tts.AudioEmitter) -> None:
        """POST the synthesis request and push the returned audio.

        Request failures (non-200 status or any exception raised while
        sending or reading) are logged and not re-raised; in that case no
        audio is pushed to ``output_emitter``.

        Args:
            output_emitter: Emitter that receives the response body.
        """
        # Initialize emitter early to avoid "AudioEmitter isn't started" error on failure
        output_emitter.initialize(
            # A fresh ID per request instead of a shared empty string.
            request_id=utils.shortuuid(),
            sample_rate=self._tts.sample_rate,
            num_channels=self._tts.num_channels,
            mime_type="audio/wav",
        )

        form = aiohttp.FormData()
        form.add_field("text", self.input_text)
        for key, value in self._opts.items():
            form.add_field(key, value)

        prompt_file = self._open_prompt_wav()
        try:
            # Everything that touches the open file sits inside this try so
            # the ``finally`` below always closes it.
            if prompt_file is not None:
                form.add_field(
                    "prompt_wav",
                    prompt_file,
                    filename="prompt.wav",
                    content_type="audio/wav",
                )

            async with aiohttp.ClientSession() as session:
                # Set a reasonable timeout for synthesis
                async with session.post(
                    self._url, data=form, timeout=aiohttp.ClientTimeout(total=60)
                ) as resp:
                    if resp.status != 200:
                        err_text = await resp.text()
                        logger.error(f"VoxCPM TTS error: {resp.status} {err_text}")
                        return

                    # Read the entire audio data (since streaming=false)
                    audio_data = await resp.read()

                    output_emitter.push(audio_data)
                    output_emitter.flush()
        except Exception as e:
            # logger.exception keeps the traceback, which a bare error
            # message would drop.
            logger.exception(f"VoxCPM TTS request failed: {e}")
        finally:
            if prompt_file is not None:
                prompt_file.close()
|
||||
Reference in New Issue
Block a user