From 409c7c9de05caa4505443679466050d034c61b48 Mon Sep 17 00:00:00 2001 From: 0Xiao0 <511201264@qq.com> Date: Mon, 11 May 2026 11:22:01 +0800 Subject: [PATCH] feat: supported different models --- asr.py | 154 +++++++++++++++++++++++++++++++++ custom_agent.py | 225 ++++++++++++++++++++++++++++++++---------------- test_asr.py | 34 ++++---- test_voxcpm.py | 54 ++++++++---- tts.py | 201 ++++++++++++++++++++++++++++++++++++++++++ tts_voxcpm.py | 118 ------------------------- 6 files changed, 558 insertions(+), 228 deletions(-) create mode 100644 asr.py create mode 100644 tts.py delete mode 100644 tts_voxcpm.py diff --git a/asr.py b/asr.py new file mode 100644 index 0000000..4052321 --- /dev/null +++ b/asr.py @@ -0,0 +1,154 @@ +import asyncio +import logging +from typing import Any, Optional, Union + +import aiohttp + +from livekit import rtc +from livekit.agents import ( + NOT_GIVEN, + APIConnectionError, + APIConnectOptions, + APIStatusError, + APITimeoutError, + LanguageCode, + NotGivenOr, + stt, + utils, +) +from livekit.agents.utils import is_given + +logger = logging.getLogger("blackbox-asr") + + +class BlackboxSTT(stt.STT): + def __init__( + self, + url: str, + *, + model_name: str = "sensevoice", + language: Optional[str] = "auto", + output_language: str = "zh", + hotwords: Optional[str] = None, + itn: Optional[Union[bool, str]] = None, + chunk_mode: Optional[Union[bool, str]] = None, + timeout: float = 30.0, + http_session: Optional[aiohttp.ClientSession] = None, + ) -> None: + super().__init__( + capabilities=stt.STTCapabilities( + streaming=False, + interim_results=False, + diarization=False, + ) + ) + self._url = url + self._model_name = model_name + self._language = language + self._output_language = output_language + self._timeout = timeout + self._http_session = http_session + self._extra_fields: dict[str, str] = {} + + if hotwords: + self._extra_fields["hotwords"] = hotwords + if itn is not None: + self._extra_fields["itn"] = _form_value(itn) + if chunk_mode is not None: + self._extra_fields["chunk_mode"] = _form_value(chunk_mode) + + @property + def model(self) -> str: + return self._model_name + + @property + def provider(self) -> str: + return "asr-blackbox" + + def _ensure_session(self) -> aiohttp.ClientSession: + if self._http_session is None: + self._http_session = utils.http_context.http_session() + return self._http_session + + async def _recognize_impl( + self, + buffer: utils.AudioBuffer, + *, + language: NotGivenOr[str] = NOT_GIVEN, + conn_options: APIConnectOptions, + ) -> stt.SpeechEvent: + audio_data = rtc.combine_audio_frames(buffer).to_wav_bytes() + + form = aiohttp.FormData() + form.add_field("audio", audio_data, filename="audio.wav", content_type="audio/wav") + form.add_field("model_name", self._model_name) + + resolved_language = language if is_given(language) else self._language + if resolved_language: + form.add_field("language", resolved_language) + for key, value in self._extra_fields.items(): + form.add_field(key, value) + + try: + async with self._ensure_session().post( + self._url, + data=form, + timeout=aiohttp.ClientTimeout( + total=self._timeout, + sock_connect=conn_options.timeout, + ), + ) as resp: + if resp.status != 200: + error_text = await resp.text() + raise APIStatusError( + message=f"ASR blackbox error: {error_text}", + status_code=resp.status, + request_id=None, + body=error_text, + ) + + payload = await resp.json() + logger.info("ASR blackbox raw result: %s", payload) + text = _extract_asr_text(payload) + if not text: + raise APIConnectionError("ASR blackbox returned an empty transcript") + + logger.info("ASR blackbox result: %s", text) + return stt.SpeechEvent( + type=stt.SpeechEventType.FINAL_TRANSCRIPT, + alternatives=[ + stt.SpeechData( + text=text, + language=LanguageCode(self._output_language), + ) + ], + ) + except asyncio.TimeoutError as e: + raise APITimeoutError("ASR blackbox request timed out") from e + except aiohttp.ClientError as e: + raise APIConnectionError(f"ASR blackbox connection error: {e}") from e + + +def _extract_asr_text(payload: dict[str, Any]) -> str: + text = payload.get("text") + if isinstance(text, str): + return text.strip() + + result = payload.get("result") + if isinstance(result, list) and result: + first = result[0] + if isinstance(first, dict): + for key in ("clean_text", "text", "raw_text"): + value = first.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + if isinstance(first, str): + return first.strip() + + raise APIConnectionError(f"Unsupported ASR blackbox response: {payload}") + + +def _form_value(value: Union[bool, str]) -> str: + if isinstance(value, bool): + return str(value).lower() + return value diff --git a/custom_agent.py b/custom_agent.py index eef6927..75dea47 100644 --- a/custom_agent.py +++ b/custom_agent.py @@ -1,78 +1,35 @@ import logging import os -import aiohttp +from pathlib import Path +from typing import Optional + from dotenv import load_dotenv -from livekit import rtc + +from asr import BlackboxSTT from livekit.agents import ( Agent, AgentServer, AgentSession, - APIConnectOptions, JobContext, JobProcess, - LanguageCode, MetricsCollectedEvent, - NOT_GIVEN, - NotGivenOr, + RecordingOptions, TurnHandlingOptions, cli, metrics, room_io, stt, - text_transforms, - utils, ) -from livekit.plugins import silero, openai +from livekit.plugins import openai, silero from livekit.plugins.turn_detector.multilingual import MultilingualModel +from tts import BlackboxTTS logger = logging.getLogger("custom-agent") -load_dotenv() +CUSTOM_ENV_PATH = Path(__file__).with_name(".env") +load_dotenv(dotenv_path=CUSTOM_ENV_PATH) +AGENT_NAME = os.getenv("CUSTOM_AGENT_NAME", "") -class SenseVoiceSTT(stt.STT): - def __init__(self, url: str): - super().__init__(capabilities=stt.STTCapabilities(streaming=False, interim_results=False, diarization=False)) - self._url = url - - @property - def model(self) -> str: - return "sensevoice" - - async def _recognize_impl( - self, - buffer: utils.AudioBuffer, - *, - language: NotGivenOr[str] = NOT_GIVEN, - conn_options: APIConnectOptions, - ) -> stt.SpeechEvent: - audio_data = rtc.combine_audio_frames(buffer).to_wav_bytes() - - async with aiohttp.ClientSession() as session: - data = aiohttp.FormData() - data.add_field('audio', audio_data, filename='audio.wav', content_type='audio/wav') - data.add_field('model_name', 'sensevoice') - - lang = language if language is not NOT_GIVEN else 'auto' - data.add_field('language', lang) - - try: - async with session.post(self._url, data=data, timeout=30) as resp: - if resp.status != 200: - raise Exception(f"ASR server returned status {resp.status}") - - result = await resp.json() - if not result.get("result"): - return stt.SpeechEvent(type=stt.SpeechEventType.FINAL_TRANSCRIPT) - - text = result["result"][0].get("clean_text", "") - logger.info(f"SenseVoice ASR Result: {text}") - return stt.SpeechEvent( - type=stt.SpeechEventType.FINAL_TRANSCRIPT, - alternatives=[stt.SpeechData(text=text, language=LanguageCode("zh"))], - ) - except Exception as e: - logger.error(f"SenseVoice ASR error: {e}") - raise class CustomAgent(Agent): def __init__(self) -> None: @@ -83,63 +40,86 @@ class CustomAgent(Agent): ) async def on_enter(self) -> None: - self.session.generate_reply(instructions="greet the user and introduce yourself") + # self.session.generate_reply(instructions="greet the user and introduce yourself") + pass server = AgentServer() + def prewarm(proc: JobProcess) -> None: # Load Silero VAD as requested proc.userdata["vad"] = silero.VAD.load() + server.setup_fnc = prewarm -@server.rtc_session(agent_name="my-agent") + +@server.rtc_session(agent_name=AGENT_NAME) async def entrypoint(ctx: JobContext) -> None: ctx.log_context_fields = { "room": ctx.room.name, } - # Configuration for custom local endpoints - # These can be set in your .env file + # Configuration for custom local endpoints. These can be set in your .env file. ASR_URL = os.getenv("CUSTOM_ASR_URL", "http://10.6.80.21:5003/asr-blackbox") - + ASR_MODEL = os.getenv("CUSTOM_ASR_MODEL", "sensevoice") + ASR_LANGUAGE = os.getenv("CUSTOM_ASR_LANGUAGE", "auto") + ASR_OUTPUT_LANGUAGE = os.getenv("CUSTOM_ASR_OUTPUT_LANGUAGE", "zh") + MINIMAX_BASE_URL = os.getenv("MINIMAX_LLM_BASE_URL", "https://oai.bwgdi.com/v1") MINIMAX_MODEL = os.getenv("MINIMAX_LLM_MODEL", "qwen-max") - - VOXCPM_URL = os.getenv("VOXCPM_TTS_URL", "http://localhost:5050/tts-blackbox") - PROMPT_WAV = os.getenv("VOXCPM_PROMPT_WAV", "/assets/2food16k_2.wav") + MINIMAX_API_KEY = os.getenv("MINIMAX_API_KEY") + if not MINIMAX_API_KEY: + raise RuntimeError(f"MINIMAX_API_KEY is not set in {CUSTOM_ENV_PATH}") - # Initialize SenseVoice STT and wrap with StreamAdapter - sensevoice_stt = SenseVoiceSTT(url=ASR_URL) - stt_stream = stt.StreamAdapter(stt=sensevoice_stt, vad=ctx.proc.userdata["vad"]) + TTS_URL = os.getenv("CUSTOM_TTS_URL") or os.getenv( + "VOXCPM_TTS_URL", "http://localhost:5050/tts-blackbox" + ) + TTS_MODEL = os.getenv("CUSTOM_TTS_MODEL") or os.getenv("VOXCPM_TTS_MODEL", "voxcpmtts") + TTS_SAMPLE_RATE = _env_int("CUSTOM_TTS_SAMPLE_RATE", 16000) + TTS_NUM_CHANNELS = _env_int("CUSTOM_TTS_NUM_CHANNELS", 1) + OUTPUT_SAMPLE_RATE = _env_int("CUSTOM_OUTPUT_SAMPLE_RATE", TTS_SAMPLE_RATE) + + blackbox_stt = BlackboxSTT( + url=ASR_URL, + model_name=ASR_MODEL, + language=ASR_LANGUAGE, + output_language=ASR_OUTPUT_LANGUAGE, + hotwords=os.getenv("CUSTOM_ASR_HOTWORDS"), + itn=os.getenv("CUSTOM_ASR_ITN"), + chunk_mode=os.getenv("CUSTOM_ASR_CHUNK_MODE"), + ) + stt_stream = stt.StreamAdapter(stt=blackbox_stt, vad=ctx.proc.userdata["vad"]) import httpx from openai import AsyncClient as OpenAIAsyncClient # Create a custom HTTP client that disables SSL verification http_client = httpx.AsyncClient(verify=False) - + # Create the OpenAI AsyncClient with the custom HTTP client openai_client = OpenAIAsyncClient( - api_key="sk-orez64WkG1NkfksB5j_hGA", + api_key=MINIMAX_API_KEY, base_url=MINIMAX_BASE_URL, http_client=http_client, ) - from tts_voxcpm import VoxCPMTTS - session: AgentSession = AgentSession( - # 1. Custom SenseVoice ASR (STT) with StreamAdapter + # 1. Custom ASR blackbox with StreamAdapter stt=stt_stream, # 2. Minimax LLM - Using OpenAI plugin with local base_url llm=openai.LLM( model=MINIMAX_MODEL, client=openai_client, ), - # 3. VoxCPM TTS - Custom implementation for blackbox API - tts=VoxCPMTTS( - url=VOXCPM_URL, - prompt_wav_path=PROMPT_WAV, + # 3. TTS blackbox + tts=BlackboxTTS( + url=TTS_URL, + model_name=TTS_MODEL, + params=_tts_params_from_env(TTS_MODEL), + prompt_wav_path=os.getenv("CUSTOM_TTS_PROMPT_WAV") or os.getenv("VOXCPM_PROMPT_WAV"), + sample_rate=TTS_SAMPLE_RATE, + num_channels=TTS_NUM_CHANNELS, ), # 4. Silero VAD vad=ctx.proc.userdata["vad"], @@ -150,7 +130,7 @@ async def entrypoint(ctx: JobContext) -> None: "false_interruption_timeout": 1.0, }, ), - preemptive_generation=True, + preemptive_generation=False, aec_warmup_duration=3.0, tts_text_transforms=[ "filter_emoji", @@ -165,7 +145,102 @@ async def entrypoint(ctx: JobContext) -> None: await session.start( agent=CustomAgent(), room=ctx.room, + room_options=room_io.RoomOptions( + audio_output=room_io.AudioOutputOptions( + sample_rate=OUTPUT_SAMPLE_RATE, + num_channels=TTS_NUM_CHANNELS, + ), + ), + record=_recording_options_from_env(), ) + +def _tts_params_from_env(model_name: str) -> dict[str, str]: + params: dict[str, str] = {} + model_name = model_name.lower() + + if model_name == "voxcpmtts": + params.update( + { + "streaming": os.getenv("CUSTOM_TTS_STREAMING", "false"), + "prompt_text": os.getenv( + "CUSTOM_TTS_PROMPT_TEXT", + os.getenv("VOXCPM_PROMPT_TEXT", "澳门有乜嘢好食嘅"), + ), + "cfg_value": os.getenv("VOXCPM_CFG_VALUE", "2.0"), + "inference_timesteps": os.getenv("VOXCPM_INFERENCE_TIMESTEPS", "10"), + "do_normalize": os.getenv("VOXCPM_DO_NORMALIZE", "true"), + "denoise": os.getenv("VOXCPM_DENOISE", "true"), + "retry_badcase": os.getenv("VOXCPM_RETRY_BADCASE", "true"), + "retry_badcase_max_times": os.getenv("VOXCPM_RETRY_BADCASE_MAX_TIMES", "3"), + "retry_badcase_ratio_threshold": os.getenv( + "VOXCPM_RETRY_BADCASE_RATIO_THRESHOLD", "6.0" + ), + } + ) + elif model_name == "melotts": + params["speed"] = os.getenv("CUSTOM_TTS_SPEED", "1.0") + elif model_name == "cosyvoicetts": + _set_if_present(params, "spk_id", os.getenv("CUSTOM_TTS_SPK_ID")) + _set_if_present(params, "model", os.getenv("CUSTOM_TTS_MODE")) + _set_if_present(params, "prompt_text", os.getenv("CUSTOM_TTS_PROMPT_TEXT")) + _set_if_present(params, "instruct_text", os.getenv("CUSTOM_TTS_INSTRUCT_TEXT")) + elif model_name == "sovitstts": + params.update( + { + "text_lang": os.getenv("CUSTOM_TTS_TEXT_LANG", "zh"), + "prompt_lang": os.getenv("CUSTOM_TTS_PROMPT_LANG", "zh"), + "text_split_method": os.getenv("CUSTOM_TTS_TEXT_SPLIT_METHOD", "cut0"), + "batch_size": os.getenv("CUSTOM_TTS_BATCH_SIZE", "1"), + "media_type": os.getenv("CUSTOM_TTS_MEDIA_TYPE", "wav"), + "streaming_mode": os.getenv("CUSTOM_TTS_STREAMING", "false"), + } + ) + _set_if_present(params, "ref_audio_path", os.getenv("CUSTOM_TTS_REF_AUDIO_PATH")) + _set_if_present(params, "prompt_text", os.getenv("CUSTOM_TTS_PROMPT_TEXT")) + + return params + + +def _set_if_present(params: dict[str, str], key: str, value: Optional[str]) -> None: + if value: + params[key] = value + + +def _env_int(name: str, default: int) -> int: + value = os.getenv(name) + if not value: + return default + try: + return int(value) + except ValueError: + logger.warning("Invalid integer for %s=%r, using %s", name, value, default) + return default + + +def _env_bool(name: str, default: bool) -> bool: + value = os.getenv(name) + if value is None: + return default + + normalized = value.strip().lower() + if normalized in {"1", "true", "yes", "on"}: + return True + if normalized in {"0", "false", "no", "off"}: + return False + + logger.warning("Invalid boolean for %s=%r, using %s", name, value, default) + return default + + +def _recording_options_from_env() -> RecordingOptions: + return RecordingOptions( + audio=_env_bool("CUSTOM_RECORD_AUDIO", False), + traces=_env_bool("CUSTOM_RECORD_TRACES", False), + logs=_env_bool("CUSTOM_RECORD_LOGS", False), + transcript=_env_bool("CUSTOM_RECORD_TRANSCRIPT", False), + ) + + if __name__ == "__main__": cli.run_app(server) diff --git a/test_asr.py b/test_asr.py index 64d41c1..2af0b62 100644 --- a/test_asr.py +++ b/test_asr.py @@ -1,53 +1,55 @@ import asyncio import logging import wave -from custom_agent import SenseVoiceSTT + +from asr import BlackboxSTT from livekit import rtc -from livekit.agents import utils # 设置日志级别以查看输出 logging.basicConfig(level=logging.INFO) logger = logging.getLogger("test-asr") + async def test(): # 替换为你本地的一个音频文件路径 - audio_path = "/home/verachen/Music/voice/2food.wav" - + audio_path = "/home/verachen/Music/voice/2food.wav" + # 初始化 ASR - stt = SenseVoiceSTT(url="http://10.6.80.21:5003/asr-blackbox") - + stt = BlackboxSTT(url="http://10.6.80.21:5003/asr-blackbox", model_name="sensevoice") + print(f"Testing ASR connectivity with file: {audio_path}") - + try: # 读取音频文件 - with wave.open(audio_path, 'rb') as wf: + with wave.open(audio_path, "rb") as wf: frames = wf.readframes(wf.getnframes()) # 简单构造一个 AudioBuffer (假设是单声道 16kHz) - # 实际上 SenseVoiceSTT._recognize_impl 会用 combine_audio_frames(buffer).to_wav_bytes() + # 实际上 BlackboxSTT._recognize_impl 会用 combine_audio_frames(buffer).to_wav_bytes() # 所以我们需要传递一个包含 AudioFrame 的 list - + # 这里我们模拟一个 Frame frame = rtc.AudioFrame( data=frames, sample_rate=wf.getframerate(), num_channels=wf.getnchannels(), - samples_per_channel=wf.getnframes() + samples_per_channel=wf.getnframes(), ) - + # 调用 recognize result = await stt.recognize(buffer=[frame]) - + if result.alternatives: - print(f"\n--- ASR Result ---") + print("\n--- ASR Result ---") print(f"Text: {result.alternatives[0].text}") - print(f"------------------\n") + print("------------------\n") else: print("ASR returned no text.") - + except FileNotFoundError: print(f"Error: Audio file not found at {audio_path}") except Exception as e: print(f"An error occurred: {e}") + if __name__ == "__main__": asyncio.run(test()) diff --git a/test_voxcpm.py b/test_voxcpm.py index f172bc4..59e58c7 100644 --- a/test_voxcpm.py +++ b/test_voxcpm.py @@ -1,50 +1,66 @@ import asyncio -import os import logging -from tts_voxcpm import VoxCPMTTS -from livekit.agents import tts +import os + +from tts import BlackboxTTS logging.basicConfig(level=logging.INFO) + async def test_tts(): # Use the URL from the user's curl command url = "http://10.6.80.21:5002/tts-blackbox" - + # Check if we have a real wav file to test with # In the earlier find_by_name, we found tests/change-sophie.wav - prompt_wav = "/home/verachen/Music/voice/2food.wav" + prompt_wav = "/home/verachen/Music/voice/2food.wav" if not os.path.exists(prompt_wav): - prompt_wav = "/home/verachen/Music/voice/2food.wav" # fallback to the one in curl + prompt_wav = "/home/verachen/Music/voice/2food.wav" # fallback to the one in curl - print(f"Testing VoxCPMTTS with URL: {url}") + print(f"Testing BlackboxTTS with URL: {url}") print(f"Using prompt wav: {prompt_wav}") - - vox_tts = VoxCPMTTS( + + blackbox_tts = BlackboxTTS( url=url, - prompt_wav_path=prompt_wav + model_name="voxcpmtts", + prompt_wav_path=prompt_wav, + params={ + "streaming": "false", + "prompt_text": "澳门有乜嘢好食嘅", + "cfg_value": "2.0", + "inference_timesteps": "10", + "do_normalize": "true", + "denoise": "true", + "retry_badcase": "true", + "retry_badcase_max_times": "3", + "retry_badcase_ratio_threshold": "6.0", + }, ) - + text = "你好,这是一段测试文本" print(f"Synthesizing text: {text}") - + try: - stream = vox_tts.synthesize(text) + stream = blackbox_tts.synthesize(text) audio_frame = await stream.collect() - - print(f"Successfully synthesized audio!") - print(f"Audio duration: {audio_frame.sample_rate * len(audio_frame.data) / (audio_frame.num_channels * 2)} samples?") + + print("Successfully synthesized audio!") + print( + f"Audio duration: {audio_frame.sample_rate * len(audio_frame.data) / (audio_frame.num_channels * 2)} samples?" + ) # Actually AudioFrame has duration or samples print(f"Samples: {len(audio_frame.data) // 2}") - + # Save to file for manual check if possible with open("test_output.wav", "wb") as f: - # This won't be a valid WAV yet if it's just raw PCM, + # This won't be a valid WAV yet if it's just raw PCM, # but if collect() returns combined frames, we can use to_wav_bytes() f.write(audio_frame.to_wav_bytes()) print("Saved output to test_output.wav") - + except Exception as e: print(f"TTS test failed: {e}") + if __name__ == "__main__": asyncio.run(test_tts()) diff --git a/tts.py b/tts.py new file mode 100644 index 0000000..b374f03 --- /dev/null +++ b/tts.py @@ -0,0 +1,201 @@ +from __future__ import annotations + +import asyncio +import logging +import os +import wave +from collections.abc import Mapping +from io import BytesIO +from typing import Optional + +import aiohttp + +from livekit.agents import ( + DEFAULT_API_CONNECT_OPTIONS, + APIConnectionError, + APIConnectOptions, + APIStatusError, + APITimeoutError, + tts, + utils, +) + +logger = logging.getLogger("blackbox-tts") + + +class BlackboxTTS(tts.TTS): + def __init__( + self, + *, + url: str, + model_name: str = "voxcpmtts", + params: Optional[Mapping[str, object]] = None, + prompt_wav_path: Optional[str] = None, + prompt_wav_field: str = "prompt_wav", + sample_rate: int = 16000, + num_channels: int = 1, + timeout: float = 60.0, + http_session: Optional[aiohttp.ClientSession] = None, + ) -> None: + super().__init__( + capabilities=tts.TTSCapabilities(streaming=False), + sample_rate=sample_rate, + num_channels=num_channels, + ) + self._url = url + self._model_name = model_name + self._params = {key: _form_value(value) for key, value in (params or {}).items()} + self._prompt_wav_path = prompt_wav_path + self._prompt_wav_field = prompt_wav_field + self._timeout = timeout + self._http_session = http_session + + @property + def model(self) -> str: + return self._model_name + + @property + def provider(self) -> str: + return "tts-blackbox" + + def _ensure_session(self) -> aiohttp.ClientSession: + if self._http_session is None: + self._http_session = utils.http_context.http_session() + return self._http_session + + def synthesize( + self, + text: str, + *, + conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, + ) -> tts.ChunkedStream: + return BlackboxTTSStream( + tts=self, + input_text=text, + conn_options=conn_options, + ) + + +class BlackboxTTSStream(tts.ChunkedStream): + def __init__( + self, + *, + tts: BlackboxTTS, + input_text: str, + conn_options: APIConnectOptions, + ) -> None: + super().__init__(tts=tts, input_text=input_text, conn_options=conn_options) + self._tts: BlackboxTTS = tts + + async def _run(self, output_emitter: tts.AudioEmitter) -> None: + form = aiohttp.FormData(default_to_multipart=True) + form.add_field("text", self.input_text) + form.add_field("model_name", self._tts._model_name) + for key, value in self._tts._params.items(): + form.add_field(key, value) + + prompt_file = None + if self._tts._prompt_wav_path: + if os.path.exists(self._tts._prompt_wav_path): + prompt_file = open(self._tts._prompt_wav_path, "rb") + form.add_field( + self._tts._prompt_wav_field, + prompt_file, + filename=os.path.basename(self._tts._prompt_wav_path), + content_type="audio/wav", + ) + else: + logger.warning( + "Prompt wav file not found at %s, skipping prompt_wav field", + self._tts._prompt_wav_path, + ) + + try: + async with self._tts._ensure_session().post( + self._tts._url, + data=form, + timeout=aiohttp.ClientTimeout( + total=self._tts._timeout, + sock_connect=self._conn_options.timeout, + ), + ) as resp: + if resp.status != 200: + error_text = await resp.text() + raise APIStatusError( + message=f"TTS blackbox error: {error_text}", + status_code=resp.status, + request_id=None, + body=error_text, + ) + + content_type = resp.headers.get("Content-Type", "audio/wav") + logged_wav_format = False + wav_header_probe = bytearray() + output_emitter.initialize( + request_id=utils.shortuuid(), + sample_rate=self._tts.sample_rate, + num_channels=self._tts.num_channels, + mime_type=content_type, + ) + + async for data, _ in resp.content.iter_chunks(): + if data: + if not logged_wav_format: + wav_header_probe.extend(data) + logged_wav_format = _log_wav_format( + bytes(wav_header_probe), + requested_sample_rate=self._tts.sample_rate, + requested_channels=self._tts.num_channels, + content_type=content_type, + ) + if not logged_wav_format and len(wav_header_probe) > 4096: + logger.info( + "TTS blackbox WAV format probe incomplete after %s bytes", + len(wav_header_probe), + ) + logged_wav_format = True + output_emitter.push(data) + output_emitter.flush() + except asyncio.TimeoutError as e: + raise APITimeoutError("TTS blackbox request timed out") from e + except aiohttp.ClientError as e: + raise APIConnectionError(f"TTS blackbox connection error: {e}") from e + finally: + if prompt_file is not None: + prompt_file.close() + + +def _form_value(value: object) -> str: + if isinstance(value, bool): + return str(value).lower() + return str(value) + + +def _log_wav_format( + data: bytes, + *, + requested_sample_rate: int, + requested_channels: int, + content_type: str, +) -> bool: + if not content_type.lower().startswith("audio/wav"): + logger.info("TTS blackbox returned content-type=%s", content_type) + return True + + try: + with wave.open(BytesIO(data), "rb") as wav: + sample_rate = wav.getframerate() + channels = wav.getnchannels() + sample_width = wav.getsampwidth() + except (EOFError, wave.Error): + return False + + logger.info( + "TTS blackbox WAV format: %sHz, %sch, %s-bit; output target: %sHz, %sch", + sample_rate, + channels, + sample_width * 8, + requested_sample_rate, + requested_channels, + ) + return True diff --git a/tts_voxcpm.py b/tts_voxcpm.py deleted file mode 100644 index 25291ac..0000000 --- a/tts_voxcpm.py +++ /dev/null @@ -1,118 +0,0 @@ -import aiohttp -import logging -import os -from livekit.agents import tts, utils, APIConnectOptions, DEFAULT_API_CONNECT_OPTIONS - -logger = logging.getLogger("voxcpm-tts") - -class VoxCPMTTS(tts.TTS): - def __init__( - self, - *, - url: str, - model_name: str = "voxcpmtts", - prompt_text: str = "澳门有乜嘢好食嘅", - prompt_wav_path: str = "/home/verachen/Music/voice/2food16k_2.wav", - cfg_value: str = "2.0", - inference_timesteps: str = "10", - do_normalize: str = "true", - denoise: str = "true", - retry_badcase: str = "true", - retry_badcase_max_times: str = "3", - retry_badcase_ratio_threshold: str = "6.0", - sample_rate: int = 16000, - ): - super().__init__( - capabilities=tts.TTSCapabilities(streaming=False), - sample_rate=sample_rate, - num_channels=1, - ) - self._url = url - self._opts = { - "model_name": model_name, - "streaming": "false", - "prompt_text": prompt_text, - "cfg_value": str(cfg_value), - "inference_timesteps": str(inference_timesteps), - "do_normalize": str(do_normalize), - "denoise": str(denoise), - "retry_badcase": str(retry_badcase), - "retry_badcase_max_times": str(retry_badcase_max_times), - "retry_badcase_ratio_threshold": str(retry_badcase_ratio_threshold), - } - self._prompt_wav_path = prompt_wav_path - - @property - def model(self) -> str: - return self._opts["model_name"] - - def synthesize( - self, - text: str, - *, - conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, - ) -> tts.ChunkedStream: - return VoxCPMStream( - self, text, self._url, self._opts, self._prompt_wav_path, conn_options=conn_options - ) - -class VoxCPMStream(tts.ChunkedStream): - def __init__( - self, - tts: VoxCPMTTS, - text: str, - url: str, - opts: dict, - prompt_wav_path: str, - conn_options: APIConnectOptions, - ): - super().__init__(tts=tts, input_text=text, conn_options=conn_options) - self._url = url - self._opts = opts - self._prompt_wav_path = prompt_wav_path - - async def _run(self, output_emitter: tts.AudioEmitter) -> None: - # Initialize emitter early to avoid "AudioEmitter isn't started" error on failure - output_emitter.initialize( - request_id="", - sample_rate=self._tts.sample_rate, - num_channels=self._tts.num_channels, - mime_type="audio/wav", - ) - - async with aiohttp.ClientSession() as session: - data = aiohttp.FormData() - data.add_field("text", self.input_text) - for k, v in self._opts.items(): - data.add_field(k, v) - - # Open the prompt wav file if it exists - f = None - if os.path.exists(self._prompt_wav_path): - f = open(self._prompt_wav_path, "rb") - data.add_field("prompt_wav", f, filename="prompt.wav", content_type="audio/wav") - else: - logger.warning( - f"Prompt wav file not found at {self._prompt_wav_path}, skipping prompt_wav field" - ) - - try: - # Set a reasonable timeout for synthesis - async with session.post( - self._url, data=data, timeout=aiohttp.ClientTimeout(total=60) - ) as resp: - if resp.status != 200: - err_text = await resp.text() - logger.error(f"VoxCPM TTS error: {resp.status} {err_text}") - return - - # Read the entire audio data (since streaming=false) - audio_data = await resp.read() - - output_emitter.push(audio_data) - output_emitter.flush() - except Exception as e: - logger.error(f"VoxCPM TTS request failed: {e}") - finally: - if f: - f.close()