feat: supported different models

This commit is contained in:
0Xiao0
2026-05-11 11:22:01 +08:00
parent ac81d4a9eb
commit 409c7c9de0
6 changed files with 558 additions and 228 deletions

154
asr.py Normal file
View File

@ -0,0 +1,154 @@
import asyncio
import logging
from typing import Any, Optional, Union
import aiohttp
from livekit import rtc
from livekit.agents import (
NOT_GIVEN,
APIConnectionError,
APIConnectOptions,
APIStatusError,
APITimeoutError,
LanguageCode,
NotGivenOr,
stt,
utils,
)
from livekit.agents.utils import is_given
logger = logging.getLogger("blackbox-asr")
class BlackboxSTT(stt.STT):
    """Non-streaming speech-to-text client for an HTTP "ASR blackbox" endpoint.

    Every recognition call uploads the buffered audio as a WAV file in a
    multipart form, together with the model name, the language and any
    optional extra fields, and reads the transcript out of the JSON reply.
    """

    def __init__(
        self,
        url: str,
        *,
        model_name: str = "sensevoice",
        language: Optional[str] = "auto",
        output_language: str = "zh",
        hotwords: Optional[str] = None,
        itn: Optional[Union[bool, str]] = None,
        chunk_mode: Optional[Union[bool, str]] = None,
        timeout: float = 30.0,
        http_session: Optional[aiohttp.ClientSession] = None,
    ) -> None:
        """Configure the client.

        Args:
            url: Endpoint that receives the multipart POST.
            model_name: Value of the ``model_name`` form field.
            language: Default value of the ``language`` form field; the field
                is omitted when this is falsy. A language passed to an
                individual recognition call takes precedence.
            output_language: Language code attached to returned transcripts.
            hotwords: Optional ``hotwords`` form field; skipped when empty.
            itn: Optional ``itn`` form field; booleans are sent as
                ``"true"``/``"false"``, strings are sent unchanged.
            chunk_mode: Optional ``chunk_mode`` form field, encoded like
                ``itn``.
            timeout: Total timeout handed to ``aiohttp.ClientTimeout``.
            http_session: Session to reuse. When omitted, one is obtained
                from ``utils.http_context`` on first use.
        """
        super().__init__(
            capabilities=stt.STTCapabilities(
                streaming=False,
                interim_results=False,
                diarization=False,
            )
        )
        self._url = url
        self._model_name = model_name
        self._language = language
        self._output_language = output_language
        self._timeout = timeout
        self._http_session = http_session

        # Optional form fields are only sent when they were configured.
        self._extra_fields: dict[str, str] = {}
        if hotwords:
            self._extra_fields["hotwords"] = hotwords
        if itn is not None:
            self._extra_fields["itn"] = _form_value(itn)
        if chunk_mode is not None:
            self._extra_fields["chunk_mode"] = _form_value(chunk_mode)

    @property
    def model(self) -> str:
        """Name of the model requested from the blackbox."""
        return self._model_name

    @property
    def provider(self) -> str:
        """Fixed provider identifier for this client."""
        return "asr-blackbox"

    def _ensure_session(self) -> aiohttp.ClientSession:
        """Return the HTTP session, fetching the shared one on first use."""
        if self._http_session is None:
            self._http_session = utils.http_context.http_session()
        return self._http_session

    async def _recognize_impl(
        self,
        buffer: utils.AudioBuffer,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions,
    ) -> stt.SpeechEvent:
        """Send ``buffer`` to the blackbox and return the final transcript.

        Args:
            buffer: Audio frames to recognise; they are merged and encoded
                as a single WAV upload.
            language: Per-call language; falls back to the configured default
                when not given.
            conn_options: Supplies the socket connect timeout.

        Returns:
            A ``FINAL_TRANSCRIPT`` event. It carries one alternative with the
            recognised text, or no alternatives when the service recognised
            nothing.

        Raises:
            APIStatusError: the service answered with a non-200 status.
            APITimeoutError: the request timed out.
            APIConnectionError: an aiohttp client error occurred, or the
                response payload could not be interpreted.
        """
        audio_data = rtc.combine_audio_frames(buffer).to_wav_bytes()

        form = aiohttp.FormData()
        form.add_field("audio", audio_data, filename="audio.wav", content_type="audio/wav")
        form.add_field("model_name", self._model_name)
        resolved_language = language if is_given(language) else self._language
        if resolved_language:
            form.add_field("language", resolved_language)
        for key, value in self._extra_fields.items():
            form.add_field(key, value)

        try:
            async with self._ensure_session().post(
                self._url,
                data=form,
                timeout=aiohttp.ClientTimeout(
                    total=self._timeout,
                    sock_connect=conn_options.timeout,
                ),
            ) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    raise APIStatusError(
                        message=f"ASR blackbox error: {error_text}",
                        status_code=resp.status,
                        request_id=None,
                        body=error_text,
                    )
                payload = await resp.json()
                logger.info("ASR blackbox raw result: %s", payload)

                text = _extract_asr_text(payload)
                if not text:
                    # Silence or noise legitimately produces no transcript.
                    # Reporting it as a connection error would make a normal
                    # outcome look like a transport failure, so hand back an
                    # event without alternatives instead.
                    logger.debug("ASR blackbox returned an empty transcript")
                    return stt.SpeechEvent(type=stt.SpeechEventType.FINAL_TRANSCRIPT)

                logger.info("ASR blackbox result: %s", text)
                return stt.SpeechEvent(
                    type=stt.SpeechEventType.FINAL_TRANSCRIPT,
                    alternatives=[
                        stt.SpeechData(
                            text=text,
                            language=LanguageCode(self._output_language),
                        )
                    ],
                )
        except asyncio.TimeoutError as e:
            raise APITimeoutError("ASR blackbox request timed out") from e
        except aiohttp.ClientError as e:
            raise APIConnectionError(f"ASR blackbox connection error: {e}") from e
def _extract_asr_text(payload: dict[str, Any]) -> str:
    """Pull the transcript out of an ASR blackbox JSON response.

    Two response shapes are understood:

    * ``{"text": "..."}`` -- the transcript is the top-level string.
    * ``{"result": [first, ...]}`` -- ``first`` is either a plain string or a
      mapping whose ``clean_text``, ``text`` or ``raw_text`` entry holds the
      transcript (checked in that order; the first non-blank one wins).

    Returns:
        The transcript with surrounding whitespace removed. An empty string
        means the response had a recognised shape but carried no text
        (an empty ``result`` list, or only blank text fields).

    Raises:
        APIConnectionError: the payload matches neither shape.
    """
    # A non-mapping body (list, string, null) has no ``.get``; report it the
    # same way as any other unsupported response instead of leaking
    # AttributeError to the caller.
    if not isinstance(payload, dict):
        raise APIConnectionError(f"Unsupported ASR blackbox response: {payload}")

    text = payload.get("text")
    if isinstance(text, str):
        return text.strip()

    result = payload.get("result")
    if isinstance(result, list):
        if not result:
            # Well-formed response with nothing recognised.
            return ""
        first = result[0]
        if isinstance(first, str):
            return first.strip()
        if isinstance(first, dict):
            saw_text_field = False
            for key in ("clean_text", "text", "raw_text"):
                value = first.get(key)
                if isinstance(value, str):
                    stripped = value.strip()
                    if stripped:
                        return stripped
                    saw_text_field = True
            if saw_text_field:
                # Known fields are present but blank: an empty transcript,
                # not an unsupported payload.
                return ""

    raise APIConnectionError(f"Unsupported ASR blackbox response: {payload}")
def _form_value(value: Union[bool, str]) -> str:
    """Encode an optional ASR setting for a multipart form field.

    Booleans are rendered as lowercase ``"true"``/``"false"``; strings are
    returned unchanged.
    """
    return str(value).lower() if isinstance(value, bool) else value

View File

@ -1,78 +1,35 @@
import logging import logging
import os import os
import aiohttp from pathlib import Path
from typing import Optional
from dotenv import load_dotenv from dotenv import load_dotenv
from livekit import rtc
from asr import BlackboxSTT
from livekit.agents import ( from livekit.agents import (
Agent, Agent,
AgentServer, AgentServer,
AgentSession, AgentSession,
APIConnectOptions,
JobContext, JobContext,
JobProcess, JobProcess,
LanguageCode,
MetricsCollectedEvent, MetricsCollectedEvent,
NOT_GIVEN, RecordingOptions,
NotGivenOr,
TurnHandlingOptions, TurnHandlingOptions,
cli, cli,
metrics, metrics,
room_io, room_io,
stt, stt,
text_transforms,
utils,
) )
from livekit.plugins import silero, openai from livekit.plugins import openai, silero
from livekit.plugins.turn_detector.multilingual import MultilingualModel from livekit.plugins.turn_detector.multilingual import MultilingualModel
from tts import BlackboxTTS
logger = logging.getLogger("custom-agent") logger = logging.getLogger("custom-agent")
load_dotenv() CUSTOM_ENV_PATH = Path(__file__).with_name(".env")
load_dotenv(dotenv_path=CUSTOM_ENV_PATH)
AGENT_NAME = os.getenv("CUSTOM_AGENT_NAME", "")
class SenseVoiceSTT(stt.STT):
def __init__(self, url: str):
super().__init__(capabilities=stt.STTCapabilities(streaming=False, interim_results=False, diarization=False))
self._url = url
@property
def model(self) -> str:
return "sensevoice"
async def _recognize_impl(
self,
buffer: utils.AudioBuffer,
*,
language: NotGivenOr[str] = NOT_GIVEN,
conn_options: APIConnectOptions,
) -> stt.SpeechEvent:
audio_data = rtc.combine_audio_frames(buffer).to_wav_bytes()
async with aiohttp.ClientSession() as session:
data = aiohttp.FormData()
data.add_field('audio', audio_data, filename='audio.wav', content_type='audio/wav')
data.add_field('model_name', 'sensevoice')
lang = language if language is not NOT_GIVEN else 'auto'
data.add_field('language', lang)
try:
async with session.post(self._url, data=data, timeout=30) as resp:
if resp.status != 200:
raise Exception(f"ASR server returned status {resp.status}")
result = await resp.json()
if not result.get("result"):
return stt.SpeechEvent(type=stt.SpeechEventType.FINAL_TRANSCRIPT)
text = result["result"][0].get("clean_text", "")
logger.info(f"SenseVoice ASR Result: {text}")
return stt.SpeechEvent(
type=stt.SpeechEventType.FINAL_TRANSCRIPT,
alternatives=[stt.SpeechData(text=text, language=LanguageCode("zh"))],
)
except Exception as e:
logger.error(f"SenseVoice ASR error: {e}")
raise
class CustomAgent(Agent): class CustomAgent(Agent):
def __init__(self) -> None: def __init__(self) -> None:
@ -83,63 +40,86 @@ class CustomAgent(Agent):
) )
async def on_enter(self) -> None: async def on_enter(self) -> None:
self.session.generate_reply(instructions="greet the user and introduce yourself") # self.session.generate_reply(instructions="greet the user and introduce yourself")
pass
server = AgentServer() server = AgentServer()
def prewarm(proc: JobProcess) -> None: def prewarm(proc: JobProcess) -> None:
# Load Silero VAD as requested # Load Silero VAD as requested
proc.userdata["vad"] = silero.VAD.load() proc.userdata["vad"] = silero.VAD.load()
server.setup_fnc = prewarm server.setup_fnc = prewarm
@server.rtc_session(agent_name="my-agent")
@server.rtc_session(agent_name=AGENT_NAME)
async def entrypoint(ctx: JobContext) -> None: async def entrypoint(ctx: JobContext) -> None:
ctx.log_context_fields = { ctx.log_context_fields = {
"room": ctx.room.name, "room": ctx.room.name,
} }
# Configuration for custom local endpoints # Configuration for custom local endpoints. These can be set in your .env file.
# These can be set in your .env file
ASR_URL = os.getenv("CUSTOM_ASR_URL", "http://10.6.80.21:5003/asr-blackbox") ASR_URL = os.getenv("CUSTOM_ASR_URL", "http://10.6.80.21:5003/asr-blackbox")
ASR_MODEL = os.getenv("CUSTOM_ASR_MODEL", "sensevoice")
ASR_LANGUAGE = os.getenv("CUSTOM_ASR_LANGUAGE", "auto")
ASR_OUTPUT_LANGUAGE = os.getenv("CUSTOM_ASR_OUTPUT_LANGUAGE", "zh")
MINIMAX_BASE_URL = os.getenv("MINIMAX_LLM_BASE_URL", "https://oai.bwgdi.com/v1") MINIMAX_BASE_URL = os.getenv("MINIMAX_LLM_BASE_URL", "https://oai.bwgdi.com/v1")
MINIMAX_MODEL = os.getenv("MINIMAX_LLM_MODEL", "qwen-max") MINIMAX_MODEL = os.getenv("MINIMAX_LLM_MODEL", "qwen-max")
MINIMAX_API_KEY = os.getenv("MINIMAX_API_KEY")
VOXCPM_URL = os.getenv("VOXCPM_TTS_URL", "http://localhost:5050/tts-blackbox") if not MINIMAX_API_KEY:
PROMPT_WAV = os.getenv("VOXCPM_PROMPT_WAV", "/assets/2food16k_2.wav") raise RuntimeError(f"MINIMAX_API_KEY is not set in {CUSTOM_ENV_PATH}")
# Initialize SenseVoice STT and wrap with StreamAdapter TTS_URL = os.getenv("CUSTOM_TTS_URL") or os.getenv(
sensevoice_stt = SenseVoiceSTT(url=ASR_URL) "VOXCPM_TTS_URL", "http://localhost:5050/tts-blackbox"
stt_stream = stt.StreamAdapter(stt=sensevoice_stt, vad=ctx.proc.userdata["vad"]) )
TTS_MODEL = os.getenv("CUSTOM_TTS_MODEL") or os.getenv("VOXCPM_TTS_MODEL", "voxcpmtts")
TTS_SAMPLE_RATE = _env_int("CUSTOM_TTS_SAMPLE_RATE", 16000)
TTS_NUM_CHANNELS = _env_int("CUSTOM_TTS_NUM_CHANNELS", 1)
OUTPUT_SAMPLE_RATE = _env_int("CUSTOM_OUTPUT_SAMPLE_RATE", TTS_SAMPLE_RATE)
blackbox_stt = BlackboxSTT(
url=ASR_URL,
model_name=ASR_MODEL,
language=ASR_LANGUAGE,
output_language=ASR_OUTPUT_LANGUAGE,
hotwords=os.getenv("CUSTOM_ASR_HOTWORDS"),
itn=os.getenv("CUSTOM_ASR_ITN"),
chunk_mode=os.getenv("CUSTOM_ASR_CHUNK_MODE"),
)
stt_stream = stt.StreamAdapter(stt=blackbox_stt, vad=ctx.proc.userdata["vad"])
import httpx import httpx
from openai import AsyncClient as OpenAIAsyncClient from openai import AsyncClient as OpenAIAsyncClient
# Create a custom HTTP client that disables SSL verification # Create a custom HTTP client that disables SSL verification
http_client = httpx.AsyncClient(verify=False) http_client = httpx.AsyncClient(verify=False)
# Create the OpenAI AsyncClient with the custom HTTP client # Create the OpenAI AsyncClient with the custom HTTP client
openai_client = OpenAIAsyncClient( openai_client = OpenAIAsyncClient(
api_key="sk-orez64WkG1NkfksB5j_hGA", api_key=MINIMAX_API_KEY,
base_url=MINIMAX_BASE_URL, base_url=MINIMAX_BASE_URL,
http_client=http_client, http_client=http_client,
) )
from tts_voxcpm import VoxCPMTTS
session: AgentSession = AgentSession( session: AgentSession = AgentSession(
# 1. Custom SenseVoice ASR (STT) with StreamAdapter # 1. Custom ASR blackbox with StreamAdapter
stt=stt_stream, stt=stt_stream,
# 2. Minimax LLM - Using OpenAI plugin with local base_url # 2. Minimax LLM - Using OpenAI plugin with local base_url
llm=openai.LLM( llm=openai.LLM(
model=MINIMAX_MODEL, model=MINIMAX_MODEL,
client=openai_client, client=openai_client,
), ),
# 3. VoxCPM TTS - Custom implementation for blackbox API # 3. TTS blackbox
tts=VoxCPMTTS( tts=BlackboxTTS(
url=VOXCPM_URL, url=TTS_URL,
prompt_wav_path=PROMPT_WAV, model_name=TTS_MODEL,
params=_tts_params_from_env(TTS_MODEL),
prompt_wav_path=os.getenv("CUSTOM_TTS_PROMPT_WAV") or os.getenv("VOXCPM_PROMPT_WAV"),
sample_rate=TTS_SAMPLE_RATE,
num_channels=TTS_NUM_CHANNELS,
), ),
# 4. Silero VAD # 4. Silero VAD
vad=ctx.proc.userdata["vad"], vad=ctx.proc.userdata["vad"],
@ -150,7 +130,7 @@ async def entrypoint(ctx: JobContext) -> None:
"false_interruption_timeout": 1.0, "false_interruption_timeout": 1.0,
}, },
), ),
preemptive_generation=True, preemptive_generation=False,
aec_warmup_duration=3.0, aec_warmup_duration=3.0,
tts_text_transforms=[ tts_text_transforms=[
"filter_emoji", "filter_emoji",
@ -165,7 +145,102 @@ async def entrypoint(ctx: JobContext) -> None:
await session.start( await session.start(
agent=CustomAgent(), agent=CustomAgent(),
room=ctx.room, room=ctx.room,
room_options=room_io.RoomOptions(
audio_output=room_io.AudioOutputOptions(
sample_rate=OUTPUT_SAMPLE_RATE,
num_channels=TTS_NUM_CHANNELS,
),
),
record=_recording_options_from_env(),
) )
def _tts_params_from_env(model_name: str) -> dict[str, str]:
    """Collect the model-specific TTS form parameters from the environment.

    The model name is matched case-insensitively. Models that are not listed
    here get an empty mapping.

    Args:
        model_name: TTS model identifier, e.g. ``"voxcpmtts"``.

    Returns:
        Form field names mapped to their string values, in the order the
        fields are declared below.
    """
    env = os.getenv
    kind = model_name.lower()

    if kind == "voxcpmtts":
        return {
            "streaming": env("CUSTOM_TTS_STREAMING", "false"),
            # CUSTOM_TTS_PROMPT_TEXT wins over the legacy VOXCPM_PROMPT_TEXT.
            "prompt_text": env(
                "CUSTOM_TTS_PROMPT_TEXT",
                env("VOXCPM_PROMPT_TEXT", "澳门有乜嘢好食嘅"),
            ),
            "cfg_value": env("VOXCPM_CFG_VALUE", "2.0"),
            "inference_timesteps": env("VOXCPM_INFERENCE_TIMESTEPS", "10"),
            "do_normalize": env("VOXCPM_DO_NORMALIZE", "true"),
            "denoise": env("VOXCPM_DENOISE", "true"),
            "retry_badcase": env("VOXCPM_RETRY_BADCASE", "true"),
            "retry_badcase_max_times": env("VOXCPM_RETRY_BADCASE_MAX_TIMES", "3"),
            "retry_badcase_ratio_threshold": env(
                "VOXCPM_RETRY_BADCASE_RATIO_THRESHOLD", "6.0"
            ),
        }

    if kind == "melotts":
        return {"speed": env("CUSTOM_TTS_SPEED", "1.0")}

    if kind == "cosyvoicetts":
        # Every CosyVoice field is optional and only sent when configured.
        optional: dict[str, str] = {}
        _set_if_present(optional, "spk_id", env("CUSTOM_TTS_SPK_ID"))
        # NOTE(review): the form key is "model" while the variable is
        # CUSTOM_TTS_MODE -- confirm which name the endpoint expects.
        _set_if_present(optional, "model", env("CUSTOM_TTS_MODE"))
        _set_if_present(optional, "prompt_text", env("CUSTOM_TTS_PROMPT_TEXT"))
        _set_if_present(optional, "instruct_text", env("CUSTOM_TTS_INSTRUCT_TEXT"))
        return optional

    if kind == "sovitstts":
        fields = {
            "text_lang": env("CUSTOM_TTS_TEXT_LANG", "zh"),
            "prompt_lang": env("CUSTOM_TTS_PROMPT_LANG", "zh"),
            "text_split_method": env("CUSTOM_TTS_TEXT_SPLIT_METHOD", "cut0"),
            "batch_size": env("CUSTOM_TTS_BATCH_SIZE", "1"),
            "media_type": env("CUSTOM_TTS_MEDIA_TYPE", "wav"),
            "streaming_mode": env("CUSTOM_TTS_STREAMING", "false"),
        }
        _set_if_present(fields, "ref_audio_path", env("CUSTOM_TTS_REF_AUDIO_PATH"))
        _set_if_present(fields, "prompt_text", env("CUSTOM_TTS_PROMPT_TEXT"))
        return fields

    return {}
def _set_if_present(params: dict[str, str], key: str, value: Optional[str]) -> None:
    """Store ``value`` under ``key`` in ``params`` unless it is ``None`` or empty."""
    if not value:
        return
    params[key] = value
def _env_int(name: str, default: int) -> int:
    """Read environment variable ``name`` as an integer.

    Returns ``default`` when the variable is unset or empty, and also (after
    logging a warning) when its value cannot be parsed by ``int``.
    """
    raw = os.getenv(name)
    if raw:
        try:
            return int(raw)
        except ValueError:
            logger.warning("Invalid integer for %s=%r, using %s", name, raw, default)
    return default
def _env_bool(name: str, default: bool) -> bool:
    """Read environment variable ``name`` as a boolean flag.

    ``1/true/yes/on`` and ``0/false/no/off`` are accepted, ignoring case and
    surrounding whitespace. An unset variable yields ``default``; any other
    value logs a warning and also yields ``default``.
    """
    raw = os.getenv(name)
    if raw is None:
        return default

    recognised = {
        "1": True,
        "true": True,
        "yes": True,
        "on": True,
        "0": False,
        "false": False,
        "no": False,
        "off": False,
    }
    token = raw.strip().lower()
    if token in recognised:
        return recognised[token]

    logger.warning("Invalid boolean for %s=%r, using %s", name, raw, default)
    return default
def _recording_options_from_env() -> RecordingOptions:
    """Build the session recording options from ``CUSTOM_RECORD_*`` variables.

    Each of the four switches is off unless its variable enables it.
    """
    record_audio = _env_bool("CUSTOM_RECORD_AUDIO", False)
    record_traces = _env_bool("CUSTOM_RECORD_TRACES", False)
    record_logs = _env_bool("CUSTOM_RECORD_LOGS", False)
    record_transcript = _env_bool("CUSTOM_RECORD_TRANSCRIPT", False)
    return RecordingOptions(
        audio=record_audio,
        traces=record_traces,
        logs=record_logs,
        transcript=record_transcript,
    )
if __name__ == "__main__": if __name__ == "__main__":
cli.run_app(server) cli.run_app(server)

View File

@ -1,53 +1,55 @@
import asyncio import asyncio
import logging import logging
import wave import wave
from custom_agent import SenseVoiceSTT
from asr import BlackboxSTT
from livekit import rtc from livekit import rtc
from livekit.agents import utils
# 设置日志级别以查看输出 # 设置日志级别以查看输出
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("test-asr") logger = logging.getLogger("test-asr")
async def test(): async def test():
# 替换为你本地的一个音频文件路径 # 替换为你本地的一个音频文件路径
audio_path = "/home/verachen/Music/voice/2food.wav" audio_path = "/home/verachen/Music/voice/2food.wav"
# 初始化 ASR # 初始化 ASR
stt = SenseVoiceSTT(url="http://10.6.80.21:5003/asr-blackbox") stt = BlackboxSTT(url="http://10.6.80.21:5003/asr-blackbox", model_name="sensevoice")
print(f"Testing ASR connectivity with file: {audio_path}") print(f"Testing ASR connectivity with file: {audio_path}")
try: try:
# 读取音频文件 # 读取音频文件
with wave.open(audio_path, 'rb') as wf: with wave.open(audio_path, "rb") as wf:
frames = wf.readframes(wf.getnframes()) frames = wf.readframes(wf.getnframes())
# 简单构造一个 AudioBuffer (假设是单声道 16kHz) # 简单构造一个 AudioBuffer (假设是单声道 16kHz)
# 实际上 SenseVoiceSTT._recognize_impl 会用 combine_audio_frames(buffer).to_wav_bytes() # 实际上 BlackboxSTT._recognize_impl 会用 combine_audio_frames(buffer).to_wav_bytes()
# 所以我们需要传递一个包含 AudioFrame 的 list # 所以我们需要传递一个包含 AudioFrame 的 list
# 这里我们模拟一个 Frame # 这里我们模拟一个 Frame
frame = rtc.AudioFrame( frame = rtc.AudioFrame(
data=frames, data=frames,
sample_rate=wf.getframerate(), sample_rate=wf.getframerate(),
num_channels=wf.getnchannels(), num_channels=wf.getnchannels(),
samples_per_channel=wf.getnframes() samples_per_channel=wf.getnframes(),
) )
# 调用 recognize # 调用 recognize
result = await stt.recognize(buffer=[frame]) result = await stt.recognize(buffer=[frame])
if result.alternatives: if result.alternatives:
print(f"\n--- ASR Result ---") print("\n--- ASR Result ---")
print(f"Text: {result.alternatives[0].text}") print(f"Text: {result.alternatives[0].text}")
print(f"------------------\n") print("------------------\n")
else: else:
print("ASR returned no text.") print("ASR returned no text.")
except FileNotFoundError: except FileNotFoundError:
print(f"Error: Audio file not found at {audio_path}") print(f"Error: Audio file not found at {audio_path}")
except Exception as e: except Exception as e:
print(f"An error occurred: {e}") print(f"An error occurred: {e}")
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(test()) asyncio.run(test())

View File

@ -1,50 +1,66 @@
import asyncio import asyncio
import os
import logging import logging
from tts_voxcpm import VoxCPMTTS import os
from livekit.agents import tts
from tts import BlackboxTTS
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
async def test_tts(): async def test_tts():
# Use the URL from the user's curl command # Use the URL from the user's curl command
url = "http://10.6.80.21:5002/tts-blackbox" url = "http://10.6.80.21:5002/tts-blackbox"
# Check if we have a real wav file to test with # Check if we have a real wav file to test with
# In the earlier find_by_name, we found tests/change-sophie.wav # In the earlier find_by_name, we found tests/change-sophie.wav
prompt_wav = "/home/verachen/Music/voice/2food.wav" prompt_wav = "/home/verachen/Music/voice/2food.wav"
if not os.path.exists(prompt_wav): if not os.path.exists(prompt_wav):
prompt_wav = "/home/verachen/Music/voice/2food.wav" # fallback to the one in curl prompt_wav = "/home/verachen/Music/voice/2food.wav" # fallback to the one in curl
print(f"Testing VoxCPMTTS with URL: {url}") print(f"Testing BlackboxTTS with URL: {url}")
print(f"Using prompt wav: {prompt_wav}") print(f"Using prompt wav: {prompt_wav}")
vox_tts = VoxCPMTTS( blackbox_tts = BlackboxTTS(
url=url, url=url,
prompt_wav_path=prompt_wav model_name="voxcpmtts",
prompt_wav_path=prompt_wav,
params={
"streaming": "false",
"prompt_text": "澳门有乜嘢好食嘅",
"cfg_value": "2.0",
"inference_timesteps": "10",
"do_normalize": "true",
"denoise": "true",
"retry_badcase": "true",
"retry_badcase_max_times": "3",
"retry_badcase_ratio_threshold": "6.0",
},
) )
text = "你好,这是一段测试文本" text = "你好,这是一段测试文本"
print(f"Synthesizing text: {text}") print(f"Synthesizing text: {text}")
try: try:
stream = vox_tts.synthesize(text) stream = blackbox_tts.synthesize(text)
audio_frame = await stream.collect() audio_frame = await stream.collect()
print(f"Successfully synthesized audio!") print("Successfully synthesized audio!")
print(f"Audio duration: {audio_frame.sample_rate * len(audio_frame.data) / (audio_frame.num_channels * 2)} samples?") print(
f"Audio duration: {audio_frame.sample_rate * len(audio_frame.data) / (audio_frame.num_channels * 2)} samples?"
)
# Actually AudioFrame has duration or samples # Actually AudioFrame has duration or samples
print(f"Samples: {len(audio_frame.data) // 2}") print(f"Samples: {len(audio_frame.data) // 2}")
# Save to file for manual check if possible # Save to file for manual check if possible
with open("test_output.wav", "wb") as f: with open("test_output.wav", "wb") as f:
# This won't be a valid WAV yet if it's just raw PCM, # This won't be a valid WAV yet if it's just raw PCM,
# but if collect() returns combined frames, we can use to_wav_bytes() # but if collect() returns combined frames, we can use to_wav_bytes()
f.write(audio_frame.to_wav_bytes()) f.write(audio_frame.to_wav_bytes())
print("Saved output to test_output.wav") print("Saved output to test_output.wav")
except Exception as e: except Exception as e:
print(f"TTS test failed: {e}") print(f"TTS test failed: {e}")
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(test_tts()) asyncio.run(test_tts())

201
tts.py Normal file
View File

@ -0,0 +1,201 @@
from __future__ import annotations
import asyncio
import logging
import os
import wave
from collections.abc import Mapping
from io import BytesIO
from typing import Optional
import aiohttp
from livekit.agents import (
DEFAULT_API_CONNECT_OPTIONS,
APIConnectionError,
APIConnectOptions,
APIStatusError,
APITimeoutError,
tts,
utils,
)
logger = logging.getLogger("blackbox-tts")
class BlackboxTTS(tts.TTS):
    """Non-streaming text-to-speech client for an HTTP "TTS blackbox" endpoint.

    The instance only holds configuration; the request itself is performed
    by the ``BlackboxTTSStream`` returned from ``synthesize``.
    """

    def __init__(
        self,
        *,
        url: str,
        model_name: str = "voxcpmtts",
        params: Optional[Mapping[str, object]] = None,
        prompt_wav_path: Optional[str] = None,
        prompt_wav_field: str = "prompt_wav",
        sample_rate: int = 16000,
        num_channels: int = 1,
        timeout: float = 60.0,
        http_session: Optional[aiohttp.ClientSession] = None,
    ) -> None:
        """Configure the client.

        Args:
            url: Endpoint that receives the multipart POST.
            model_name: Value of the ``model_name`` form field.
            params: Extra form fields; every value is converted to a string
                when the instance is created.
            prompt_wav_path: Optional path of a prompt WAV file to upload.
            prompt_wav_field: Form field name used for the prompt file.
            sample_rate: Sample rate reported to the base class.
            num_channels: Channel count reported to the base class.
            timeout: Total request timeout handed to ``aiohttp.ClientTimeout``.
            http_session: Session to reuse. When omitted, one is obtained
                from ``utils.http_context`` on first use.
        """
        super().__init__(
            capabilities=tts.TTSCapabilities(streaming=False),
            sample_rate=sample_rate,
            num_channels=num_channels,
        )
        self._url = url
        self._model_name = model_name

        # Form fields must be strings, so normalise them once up front.
        form_params: dict[str, str] = {}
        for field_name, raw in (params or {}).items():
            form_params[field_name] = _form_value(raw)
        self._params = form_params

        self._prompt_wav_path = prompt_wav_path
        self._prompt_wav_field = prompt_wav_field
        self._timeout = timeout
        self._http_session = http_session

    @property
    def model(self) -> str:
        """Name of the model requested from the blackbox."""
        return self._model_name

    @property
    def provider(self) -> str:
        """Fixed provider identifier for this client."""
        return "tts-blackbox"

    def _ensure_session(self) -> aiohttp.ClientSession:
        """Return the HTTP session, fetching the shared one on first use."""
        session = self._http_session
        if session is None:
            session = utils.http_context.http_session()
            self._http_session = session
        return session

    def synthesize(
        self,
        text: str,
        *,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> tts.ChunkedStream:
        """Create the chunked stream that will synthesise ``text``."""
        return BlackboxTTSStream(
            tts=self,
            input_text=text,
            conn_options=conn_options,
        )
class BlackboxTTSStream(tts.ChunkedStream):
    """Single synthesis request against the TTS blackbox endpoint.

    The text, model name, extra parameters and optional prompt WAV are sent
    as one multipart POST; the response body is forwarded to the audio
    emitter chunk by chunk as it arrives.
    """

    def __init__(
        self,
        *,
        tts: BlackboxTTS,
        input_text: str,
        conn_options: APIConnectOptions,
    ) -> None:
        """Bind the stream to its owning ``BlackboxTTS`` and input text."""
        super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
        self._tts: BlackboxTTS = tts

    async def _run(self, output_emitter: tts.AudioEmitter) -> None:
        """Perform the request and push the returned audio to ``output_emitter``.

        Raises:
            APIStatusError: the service answered with a non-200 status.
            APITimeoutError: the request timed out.
            APIConnectionError: an aiohttp client error occurred.
        """
        form = aiohttp.FormData(default_to_multipart=True)
        form.add_field("text", self.input_text)
        form.add_field("model_name", self._tts._model_name)
        for key, value in self._tts._params.items():
            form.add_field(key, value)

        prompt_path = self._tts._prompt_wav_path
        prompt_file = None
        try:
            # The prompt file is opened inside the guarded region so the
            # ``finally`` clause closes it on every exit path, including a
            # failure while the form is still being assembled.
            if prompt_path:
                if os.path.exists(prompt_path):
                    prompt_file = open(prompt_path, "rb")
                    form.add_field(
                        self._tts._prompt_wav_field,
                        prompt_file,
                        filename=os.path.basename(prompt_path),
                        content_type="audio/wav",
                    )
                else:
                    logger.warning(
                        "Prompt wav file not found at %s, skipping prompt_wav field",
                        prompt_path,
                    )

            async with self._tts._ensure_session().post(
                self._tts._url,
                data=form,
                timeout=aiohttp.ClientTimeout(
                    total=self._tts._timeout,
                    sock_connect=self._conn_options.timeout,
                ),
            ) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    raise APIStatusError(
                        message=f"TTS blackbox error: {error_text}",
                        status_code=resp.status,
                        request_id=None,
                        body=error_text,
                    )

                content_type = resp.headers.get("Content-Type", "audio/wav")
                # The first bytes are accumulated until the WAV header can be
                # parsed (or the probe gives up); this only feeds logging.
                logged_wav_format = False
                wav_header_probe = bytearray()

                output_emitter.initialize(
                    request_id=utils.shortuuid(),
                    sample_rate=self._tts.sample_rate,
                    num_channels=self._tts.num_channels,
                    mime_type=content_type,
                )

                async for data, _ in resp.content.iter_chunks():
                    if data:
                        if not logged_wav_format:
                            wav_header_probe.extend(data)
                            logged_wav_format = _log_wav_format(
                                bytes(wav_header_probe),
                                requested_sample_rate=self._tts.sample_rate,
                                requested_channels=self._tts.num_channels,
                                content_type=content_type,
                            )
                            # Stop probing once more than 4096 bytes have
                            # been collected without a parsable header.
                            if not logged_wav_format and len(wav_header_probe) > 4096:
                                logger.info(
                                    "TTS blackbox WAV format probe incomplete after %s bytes",
                                    len(wav_header_probe),
                                )
                                logged_wav_format = True
                        output_emitter.push(data)
                output_emitter.flush()
        except asyncio.TimeoutError as e:
            raise APITimeoutError("TTS blackbox request timed out") from e
        except aiohttp.ClientError as e:
            raise APIConnectionError(f"TTS blackbox connection error: {e}") from e
        finally:
            if prompt_file is not None:
                prompt_file.close()
def _form_value(value: object) -> str:
    """Convert a TTS parameter to the string sent in the multipart form.

    Booleans are rendered lowercase (``"true"``/``"false"``); every other
    value goes through ``str``.
    """
    rendered = str(value)
    return rendered.lower() if isinstance(value, bool) else rendered
def _log_wav_format(
    data: bytes,
    *,
    requested_sample_rate: int,
    requested_channels: int,
    content_type: str,
) -> bool:
    """Log the audio format announced by the start of a TTS response.

    Args:
        data: Bytes received so far, starting at the beginning of the body.
        requested_sample_rate: Output sample rate, logged for comparison.
        requested_channels: Output channel count, logged for comparison.
        content_type: ``Content-Type`` of the response.

    Returns:
        ``True`` once something has been logged (either the parsed WAV format
        or, for a content type not starting with ``audio/wav``, the content
        type itself). ``False`` when ``data`` does not yet parse as a WAV
        header, so the caller can retry with more bytes.
    """
    is_wav = content_type.lower().startswith("audio/wav")
    if not is_wav:
        logger.info("TTS blackbox returned content-type=%s", content_type)
        return True

    try:
        with wave.open(BytesIO(data), "rb") as header:
            parsed = (
                header.getframerate(),
                header.getnchannels(),
                header.getsampwidth(),
            )
    except (EOFError, wave.Error):
        # Header incomplete or not understood.
        return False

    rate_hz, channel_count, width_bytes = parsed
    logger.info(
        "TTS blackbox WAV format: %sHz, %sch, %s-bit; output target: %sHz, %sch",
        rate_hz,
        channel_count,
        width_bytes * 8,
        requested_sample_rate,
        requested_channels,
    )
    return True

View File

@ -1,118 +0,0 @@
import aiohttp
import logging
import os
from livekit.agents import tts, utils, APIConnectOptions, DEFAULT_API_CONNECT_OPTIONS
logger = logging.getLogger("voxcpm-tts")
class VoxCPMTTS(tts.TTS):
    """Non-streaming TTS client that posts synthesis requests to a VoxCPM blackbox URL."""

    def __init__(
        self,
        *,
        url: str,
        model_name: str = "voxcpmtts",
        prompt_text: str = "澳门有乜嘢好食嘅",
        prompt_wav_path: str = "/home/verachen/Music/voice/2food16k_2.wav",
        cfg_value: str = "2.0",
        inference_timesteps: str = "10",
        do_normalize: str = "true",
        denoise: str = "true",
        retry_badcase: str = "true",
        retry_badcase_max_times: str = "3",
        retry_badcase_ratio_threshold: str = "6.0",
        sample_rate: int = 16000,
    ):
        """Store the endpoint, the prompt file path and the form options.

        All tuning arguments are converted with ``str`` and kept, together
        with the model name and prompt text, as the form fields sent with
        every request. Output is declared as single-channel audio.
        """
        super().__init__(
            capabilities=tts.TTSCapabilities(streaming=False),
            sample_rate=sample_rate,
            num_channels=1,
        )
        self._url = url
        self._prompt_wav_path = prompt_wav_path

        # Fixed fields first, then the tunables, preserving field order.
        form_fields = {
            "model_name": model_name,
            "streaming": "false",
            "prompt_text": prompt_text,
        }
        tunables = (
            ("cfg_value", cfg_value),
            ("inference_timesteps", inference_timesteps),
            ("do_normalize", do_normalize),
            ("denoise", denoise),
            ("retry_badcase", retry_badcase),
            ("retry_badcase_max_times", retry_badcase_max_times),
            ("retry_badcase_ratio_threshold", retry_badcase_ratio_threshold),
        )
        for field_name, setting in tunables:
            form_fields[field_name] = str(setting)
        self._opts = form_fields

    @property
    def model(self) -> str:
        """Model name sent with each request."""
        return self._opts["model_name"]

    def synthesize(
        self,
        text: str,
        *,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> tts.ChunkedStream:
        """Create the chunked stream that will synthesise ``text``."""
        return VoxCPMStream(
            tts=self,
            text=text,
            url=self._url,
            opts=self._opts,
            prompt_wav_path=self._prompt_wav_path,
            conn_options=conn_options,
        )
class VoxCPMStream(tts.ChunkedStream):
    """One best-effort synthesis request against the VoxCPM blackbox endpoint."""

    def __init__(
        self,
        tts: VoxCPMTTS,
        text: str,
        url: str,
        opts: dict,
        prompt_wav_path: str,
        conn_options: APIConnectOptions,
    ):
        """Remember the endpoint, the form options and the prompt file path."""
        super().__init__(tts=tts, input_text=text, conn_options=conn_options)
        self._url = url
        self._opts = opts
        self._prompt_wav_path = prompt_wav_path

    async def _run(self, output_emitter: tts.AudioEmitter) -> None:
        """Post the request and push the whole response body as audio.

        Failures (non-200 status or any exception during the request) are
        logged and swallowed; nothing is pushed in that case.
        """
        # The emitter is initialised before any I/O to avoid the
        # "AudioEmitter isn't started" error when the request fails.
        output_emitter.initialize(
            request_id="",
            sample_rate=self._tts.sample_rate,
            num_channels=self._tts.num_channels,
            mime_type="audio/wav",
        )

        async with aiohttp.ClientSession() as session:
            form = aiohttp.FormData()
            form.add_field("text", self.input_text)
            for field_name, field_value in self._opts.items():
                form.add_field(field_name, field_value)

            # Attach the prompt wav only when the file is actually there.
            prompt_file = None
            if not os.path.exists(self._prompt_wav_path):
                logger.warning(
                    f"Prompt wav file not found at {self._prompt_wav_path}, skipping prompt_wav field"
                )
            else:
                prompt_file = open(self._prompt_wav_path, "rb")
                form.add_field(
                    "prompt_wav", prompt_file, filename="prompt.wav", content_type="audio/wav"
                )

            try:
                # 60 seconds overall for the synthesis request.
                request_timeout = aiohttp.ClientTimeout(total=60)
                async with session.post(
                    self._url, data=form, timeout=request_timeout
                ) as resp:
                    if resp.status != 200:
                        err_text = await resp.text()
                        logger.error(f"VoxCPM TTS error: {resp.status} {err_text}")
                        return

                    # streaming is "false", so the body is read in one piece.
                    audio_data = await resp.read()
                    output_emitter.push(audio_data)
                    output_emitter.flush()
            except Exception as e:
                logger.error(f"VoxCPM TTS request failed: {e}")
            finally:
                if prompt_file is not None:
                    prompt_file.close()