feat: supported different models
This commit is contained in:
154
asr.py
Normal file
154
asr.py
Normal file
@ -0,0 +1,154 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from livekit import rtc
|
||||||
|
from livekit.agents import (
|
||||||
|
NOT_GIVEN,
|
||||||
|
APIConnectionError,
|
||||||
|
APIConnectOptions,
|
||||||
|
APIStatusError,
|
||||||
|
APITimeoutError,
|
||||||
|
LanguageCode,
|
||||||
|
NotGivenOr,
|
||||||
|
stt,
|
||||||
|
utils,
|
||||||
|
)
|
||||||
|
from livekit.agents.utils import is_given
|
||||||
|
|
||||||
|
logger = logging.getLogger("blackbox-asr")
|
||||||
|
|
||||||
|
|
||||||
|
class BlackboxSTT(stt.STT):
    """Non-streaming speech-to-text client for an HTTP "ASR blackbox" endpoint.

    Each recognition call uploads the buffered audio as a WAV file in a
    multipart form, together with the model name, the language and any
    optional extra form fields, then extracts the transcript from the JSON
    reply.
    """

    def __init__(
        self,
        url: str,
        *,
        model_name: str = "sensevoice",
        language: Optional[str] = "auto",
        output_language: str = "zh",
        hotwords: Optional[str] = None,
        itn: Optional[Union[bool, str]] = None,
        chunk_mode: Optional[Union[bool, str]] = None,
        timeout: float = 30.0,
        http_session: Optional[aiohttp.ClientSession] = None,
    ) -> None:
        """Configure the client.

        Args:
            url: Endpoint that receives the multipart POST.
            model_name: Sent as the ``model_name`` form field and reported by
                the ``model`` property.
            language: Default ``language`` form field, used when a call does
                not supply one; a falsy value omits the field.
            output_language: Language code attached to every returned
                transcript.
            hotwords: Optional ``hotwords`` form field; omitted when empty.
            itn: Optional ``itn`` form field; booleans are sent as
                ``"true"`` / ``"false"``.
            chunk_mode: Optional ``chunk_mode`` form field, converted like
                ``itn``.
            timeout: Total request timeout in seconds.
            http_session: Session to reuse. When omitted, one is obtained
                from ``utils.http_context`` on first use.
        """
        super().__init__(
            capabilities=stt.STTCapabilities(
                streaming=False,
                interim_results=False,
                diarization=False,
            )
        )
        self._url = url
        self._model_name = model_name
        self._language = language
        self._output_language = output_language
        self._timeout = timeout
        self._http_session = http_session
        # Optional form fields, already converted to their wire format.
        self._extra_fields: dict[str, str] = {}

        if hotwords:
            self._extra_fields["hotwords"] = hotwords
        if itn is not None:
            self._extra_fields["itn"] = _form_value(itn)
        if chunk_mode is not None:
            self._extra_fields["chunk_mode"] = _form_value(chunk_mode)

    @property
    def model(self) -> str:
        """Name of the model requested from the endpoint."""
        return self._model_name

    @property
    def provider(self) -> str:
        """Fixed provider identifier for this client."""
        return "asr-blackbox"

    def _ensure_session(self) -> aiohttp.ClientSession:
        """Return the HTTP session, fetching one from the http context if needed."""
        if self._http_session is None:
            self._http_session = utils.http_context.http_session()
        return self._http_session

    async def _recognize_impl(
        self,
        buffer: utils.AudioBuffer,
        *,
        language: NotGivenOr[str] = NOT_GIVEN,
        conn_options: APIConnectOptions,
    ) -> stt.SpeechEvent:
        """Send ``buffer`` to the endpoint and return the final transcript.

        Args:
            buffer: Audio to transcribe; combined into one WAV upload.
            language: Per-call language; falls back to the configured default.
            conn_options: Supplies the socket-connect timeout.

        Returns:
            A ``FINAL_TRANSCRIPT`` event with a single alternative.

        Raises:
            APIStatusError: The endpoint answered with a non-200 status.
            APITimeoutError: The request timed out.
            APIConnectionError: Transport failure, a body that is not valid
                JSON, an unsupported payload, or an empty transcript.
        """
        audio_data = rtc.combine_audio_frames(buffer).to_wav_bytes()

        form = aiohttp.FormData()
        form.add_field("audio", audio_data, filename="audio.wav", content_type="audio/wav")
        form.add_field("model_name", self._model_name)

        resolved_language = language if is_given(language) else self._language
        if resolved_language:
            form.add_field("language", resolved_language)
        for key, value in self._extra_fields.items():
            form.add_field(key, value)

        try:
            async with self._ensure_session().post(
                self._url,
                data=form,
                timeout=aiohttp.ClientTimeout(
                    total=self._timeout,
                    sock_connect=conn_options.timeout,
                ),
            ) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    raise APIStatusError(
                        message=f"ASR blackbox error: {error_text}",
                        status_code=resp.status,
                        request_id=None,
                        body=error_text,
                    )

                try:
                    payload = await resp.json()
                except ValueError as e:
                    # Malformed JSON / undecodable body is not an
                    # aiohttp.ClientError, so map it to an API error here
                    # instead of letting a raw ValueError escape.
                    raise APIConnectionError(
                        "ASR blackbox returned a response that is not valid JSON"
                    ) from e

                logger.info("ASR blackbox raw result: %s", payload)
                text = _extract_asr_text(payload)
                # NOTE(review): an empty transcript is reported as a connection
                # error; confirm this is intended for silent audio segments.
                if not text:
                    raise APIConnectionError("ASR blackbox returned an empty transcript")

                logger.info("ASR blackbox result: %s", text)
                return stt.SpeechEvent(
                    type=stt.SpeechEventType.FINAL_TRANSCRIPT,
                    alternatives=[
                        stt.SpeechData(
                            text=text,
                            language=LanguageCode(self._output_language),
                        )
                    ],
                )
        except asyncio.TimeoutError as e:
            raise APITimeoutError("ASR blackbox request timed out") from e
        except aiohttp.ClientError as e:
            raise APIConnectionError(f"ASR blackbox connection error: {e}") from e
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_asr_text(payload: dict[str, Any]) -> str:
|
||||||
|
text = payload.get("text")
|
||||||
|
if isinstance(text, str):
|
||||||
|
return text.strip()
|
||||||
|
|
||||||
|
result = payload.get("result")
|
||||||
|
if isinstance(result, list) and result:
|
||||||
|
first = result[0]
|
||||||
|
if isinstance(first, dict):
|
||||||
|
for key in ("clean_text", "text", "raw_text"):
|
||||||
|
value = first.get(key)
|
||||||
|
if isinstance(value, str) and value.strip():
|
||||||
|
return value.strip()
|
||||||
|
if isinstance(first, str):
|
||||||
|
return first.strip()
|
||||||
|
|
||||||
|
raise APIConnectionError(f"Unsupported ASR blackbox response: {payload}")
|
||||||
|
|
||||||
|
|
||||||
|
def _form_value(value: Union[bool, str]) -> str:
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return str(value).lower()
|
||||||
|
return value
|
||||||
225
custom_agent.py
225
custom_agent.py
@ -1,78 +1,35 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import aiohttp
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from livekit import rtc
|
|
||||||
|
from asr import BlackboxSTT
|
||||||
from livekit.agents import (
|
from livekit.agents import (
|
||||||
Agent,
|
Agent,
|
||||||
AgentServer,
|
AgentServer,
|
||||||
AgentSession,
|
AgentSession,
|
||||||
APIConnectOptions,
|
|
||||||
JobContext,
|
JobContext,
|
||||||
JobProcess,
|
JobProcess,
|
||||||
LanguageCode,
|
|
||||||
MetricsCollectedEvent,
|
MetricsCollectedEvent,
|
||||||
NOT_GIVEN,
|
RecordingOptions,
|
||||||
NotGivenOr,
|
|
||||||
TurnHandlingOptions,
|
TurnHandlingOptions,
|
||||||
cli,
|
cli,
|
||||||
metrics,
|
metrics,
|
||||||
room_io,
|
room_io,
|
||||||
stt,
|
stt,
|
||||||
text_transforms,
|
|
||||||
utils,
|
|
||||||
)
|
)
|
||||||
from livekit.plugins import silero, openai
|
from livekit.plugins import openai, silero
|
||||||
from livekit.plugins.turn_detector.multilingual import MultilingualModel
|
from livekit.plugins.turn_detector.multilingual import MultilingualModel
|
||||||
|
from tts import BlackboxTTS
|
||||||
|
|
||||||
logger = logging.getLogger("custom-agent")
|
logger = logging.getLogger("custom-agent")
|
||||||
|
|
||||||
load_dotenv()
|
CUSTOM_ENV_PATH = Path(__file__).with_name(".env")
|
||||||
|
load_dotenv(dotenv_path=CUSTOM_ENV_PATH)
|
||||||
|
AGENT_NAME = os.getenv("CUSTOM_AGENT_NAME", "")
|
||||||
|
|
||||||
class SenseVoiceSTT(stt.STT):
|
|
||||||
def __init__(self, url: str):
|
|
||||||
super().__init__(capabilities=stt.STTCapabilities(streaming=False, interim_results=False, diarization=False))
|
|
||||||
self._url = url
|
|
||||||
|
|
||||||
@property
|
|
||||||
def model(self) -> str:
|
|
||||||
return "sensevoice"
|
|
||||||
|
|
||||||
async def _recognize_impl(
|
|
||||||
self,
|
|
||||||
buffer: utils.AudioBuffer,
|
|
||||||
*,
|
|
||||||
language: NotGivenOr[str] = NOT_GIVEN,
|
|
||||||
conn_options: APIConnectOptions,
|
|
||||||
) -> stt.SpeechEvent:
|
|
||||||
audio_data = rtc.combine_audio_frames(buffer).to_wav_bytes()
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
data = aiohttp.FormData()
|
|
||||||
data.add_field('audio', audio_data, filename='audio.wav', content_type='audio/wav')
|
|
||||||
data.add_field('model_name', 'sensevoice')
|
|
||||||
|
|
||||||
lang = language if language is not NOT_GIVEN else 'auto'
|
|
||||||
data.add_field('language', lang)
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with session.post(self._url, data=data, timeout=30) as resp:
|
|
||||||
if resp.status != 200:
|
|
||||||
raise Exception(f"ASR server returned status {resp.status}")
|
|
||||||
|
|
||||||
result = await resp.json()
|
|
||||||
if not result.get("result"):
|
|
||||||
return stt.SpeechEvent(type=stt.SpeechEventType.FINAL_TRANSCRIPT)
|
|
||||||
|
|
||||||
text = result["result"][0].get("clean_text", "")
|
|
||||||
logger.info(f"SenseVoice ASR Result: {text}")
|
|
||||||
return stt.SpeechEvent(
|
|
||||||
type=stt.SpeechEventType.FINAL_TRANSCRIPT,
|
|
||||||
alternatives=[stt.SpeechData(text=text, language=LanguageCode("zh"))],
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"SenseVoice ASR error: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
class CustomAgent(Agent):
|
class CustomAgent(Agent):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
@ -83,63 +40,86 @@ class CustomAgent(Agent):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def on_enter(self) -> None:
|
async def on_enter(self) -> None:
|
||||||
self.session.generate_reply(instructions="greet the user and introduce yourself")
|
# self.session.generate_reply(instructions="greet the user and introduce yourself")
|
||||||
|
pass
|
||||||
|
|
||||||
server = AgentServer()
|
server = AgentServer()
|
||||||
|
|
||||||
|
|
||||||
def prewarm(proc: JobProcess) -> None:
|
def prewarm(proc: JobProcess) -> None:
|
||||||
# Load Silero VAD as requested
|
# Load Silero VAD as requested
|
||||||
proc.userdata["vad"] = silero.VAD.load()
|
proc.userdata["vad"] = silero.VAD.load()
|
||||||
|
|
||||||
|
|
||||||
server.setup_fnc = prewarm
|
server.setup_fnc = prewarm
|
||||||
|
|
||||||
@server.rtc_session(agent_name="my-agent")
|
|
||||||
|
@server.rtc_session(agent_name=AGENT_NAME)
|
||||||
async def entrypoint(ctx: JobContext) -> None:
|
async def entrypoint(ctx: JobContext) -> None:
|
||||||
ctx.log_context_fields = {
|
ctx.log_context_fields = {
|
||||||
"room": ctx.room.name,
|
"room": ctx.room.name,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Configuration for custom local endpoints
|
# Configuration for custom local endpoints. These can be set in your .env file.
|
||||||
# These can be set in your .env file
|
|
||||||
ASR_URL = os.getenv("CUSTOM_ASR_URL", "http://10.6.80.21:5003/asr-blackbox")
|
ASR_URL = os.getenv("CUSTOM_ASR_URL", "http://10.6.80.21:5003/asr-blackbox")
|
||||||
|
ASR_MODEL = os.getenv("CUSTOM_ASR_MODEL", "sensevoice")
|
||||||
|
ASR_LANGUAGE = os.getenv("CUSTOM_ASR_LANGUAGE", "auto")
|
||||||
|
ASR_OUTPUT_LANGUAGE = os.getenv("CUSTOM_ASR_OUTPUT_LANGUAGE", "zh")
|
||||||
|
|
||||||
MINIMAX_BASE_URL = os.getenv("MINIMAX_LLM_BASE_URL", "https://oai.bwgdi.com/v1")
|
MINIMAX_BASE_URL = os.getenv("MINIMAX_LLM_BASE_URL", "https://oai.bwgdi.com/v1")
|
||||||
MINIMAX_MODEL = os.getenv("MINIMAX_LLM_MODEL", "qwen-max")
|
MINIMAX_MODEL = os.getenv("MINIMAX_LLM_MODEL", "qwen-max")
|
||||||
|
MINIMAX_API_KEY = os.getenv("MINIMAX_API_KEY")
|
||||||
VOXCPM_URL = os.getenv("VOXCPM_TTS_URL", "http://localhost:5050/tts-blackbox")
|
if not MINIMAX_API_KEY:
|
||||||
PROMPT_WAV = os.getenv("VOXCPM_PROMPT_WAV", "/assets/2food16k_2.wav")
|
raise RuntimeError(f"MINIMAX_API_KEY is not set in {CUSTOM_ENV_PATH}")
|
||||||
|
|
||||||
# Initialize SenseVoice STT and wrap with StreamAdapter
|
TTS_URL = os.getenv("CUSTOM_TTS_URL") or os.getenv(
|
||||||
sensevoice_stt = SenseVoiceSTT(url=ASR_URL)
|
"VOXCPM_TTS_URL", "http://localhost:5050/tts-blackbox"
|
||||||
stt_stream = stt.StreamAdapter(stt=sensevoice_stt, vad=ctx.proc.userdata["vad"])
|
)
|
||||||
|
TTS_MODEL = os.getenv("CUSTOM_TTS_MODEL") or os.getenv("VOXCPM_TTS_MODEL", "voxcpmtts")
|
||||||
|
TTS_SAMPLE_RATE = _env_int("CUSTOM_TTS_SAMPLE_RATE", 16000)
|
||||||
|
TTS_NUM_CHANNELS = _env_int("CUSTOM_TTS_NUM_CHANNELS", 1)
|
||||||
|
OUTPUT_SAMPLE_RATE = _env_int("CUSTOM_OUTPUT_SAMPLE_RATE", TTS_SAMPLE_RATE)
|
||||||
|
|
||||||
|
blackbox_stt = BlackboxSTT(
|
||||||
|
url=ASR_URL,
|
||||||
|
model_name=ASR_MODEL,
|
||||||
|
language=ASR_LANGUAGE,
|
||||||
|
output_language=ASR_OUTPUT_LANGUAGE,
|
||||||
|
hotwords=os.getenv("CUSTOM_ASR_HOTWORDS"),
|
||||||
|
itn=os.getenv("CUSTOM_ASR_ITN"),
|
||||||
|
chunk_mode=os.getenv("CUSTOM_ASR_CHUNK_MODE"),
|
||||||
|
)
|
||||||
|
stt_stream = stt.StreamAdapter(stt=blackbox_stt, vad=ctx.proc.userdata["vad"])
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from openai import AsyncClient as OpenAIAsyncClient
|
from openai import AsyncClient as OpenAIAsyncClient
|
||||||
|
|
||||||
# Create a custom HTTP client that disables SSL verification
|
# Create a custom HTTP client that disables SSL verification
|
||||||
http_client = httpx.AsyncClient(verify=False)
|
http_client = httpx.AsyncClient(verify=False)
|
||||||
|
|
||||||
# Create the OpenAI AsyncClient with the custom HTTP client
|
# Create the OpenAI AsyncClient with the custom HTTP client
|
||||||
openai_client = OpenAIAsyncClient(
|
openai_client = OpenAIAsyncClient(
|
||||||
api_key="sk-orez64WkG1NkfksB5j_hGA",
|
api_key=MINIMAX_API_KEY,
|
||||||
base_url=MINIMAX_BASE_URL,
|
base_url=MINIMAX_BASE_URL,
|
||||||
http_client=http_client,
|
http_client=http_client,
|
||||||
)
|
)
|
||||||
|
|
||||||
from tts_voxcpm import VoxCPMTTS
|
|
||||||
|
|
||||||
session: AgentSession = AgentSession(
|
session: AgentSession = AgentSession(
|
||||||
# 1. Custom SenseVoice ASR (STT) with StreamAdapter
|
# 1. Custom ASR blackbox with StreamAdapter
|
||||||
stt=stt_stream,
|
stt=stt_stream,
|
||||||
# 2. Minimax LLM - Using OpenAI plugin with local base_url
|
# 2. Minimax LLM - Using OpenAI plugin with local base_url
|
||||||
llm=openai.LLM(
|
llm=openai.LLM(
|
||||||
model=MINIMAX_MODEL,
|
model=MINIMAX_MODEL,
|
||||||
client=openai_client,
|
client=openai_client,
|
||||||
),
|
),
|
||||||
# 3. VoxCPM TTS - Custom implementation for blackbox API
|
# 3. TTS blackbox
|
||||||
tts=VoxCPMTTS(
|
tts=BlackboxTTS(
|
||||||
url=VOXCPM_URL,
|
url=TTS_URL,
|
||||||
prompt_wav_path=PROMPT_WAV,
|
model_name=TTS_MODEL,
|
||||||
|
params=_tts_params_from_env(TTS_MODEL),
|
||||||
|
prompt_wav_path=os.getenv("CUSTOM_TTS_PROMPT_WAV") or os.getenv("VOXCPM_PROMPT_WAV"),
|
||||||
|
sample_rate=TTS_SAMPLE_RATE,
|
||||||
|
num_channels=TTS_NUM_CHANNELS,
|
||||||
),
|
),
|
||||||
# 4. Silero VAD
|
# 4. Silero VAD
|
||||||
vad=ctx.proc.userdata["vad"],
|
vad=ctx.proc.userdata["vad"],
|
||||||
@ -150,7 +130,7 @@ async def entrypoint(ctx: JobContext) -> None:
|
|||||||
"false_interruption_timeout": 1.0,
|
"false_interruption_timeout": 1.0,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
preemptive_generation=True,
|
preemptive_generation=False,
|
||||||
aec_warmup_duration=3.0,
|
aec_warmup_duration=3.0,
|
||||||
tts_text_transforms=[
|
tts_text_transforms=[
|
||||||
"filter_emoji",
|
"filter_emoji",
|
||||||
@ -165,7 +145,102 @@ async def entrypoint(ctx: JobContext) -> None:
|
|||||||
await session.start(
|
await session.start(
|
||||||
agent=CustomAgent(),
|
agent=CustomAgent(),
|
||||||
room=ctx.room,
|
room=ctx.room,
|
||||||
|
room_options=room_io.RoomOptions(
|
||||||
|
audio_output=room_io.AudioOutputOptions(
|
||||||
|
sample_rate=OUTPUT_SAMPLE_RATE,
|
||||||
|
num_channels=TTS_NUM_CHANNELS,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
record=_recording_options_from_env(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _tts_params_from_env(model_name: str) -> dict[str, str]:
|
||||||
|
params: dict[str, str] = {}
|
||||||
|
model_name = model_name.lower()
|
||||||
|
|
||||||
|
if model_name == "voxcpmtts":
|
||||||
|
params.update(
|
||||||
|
{
|
||||||
|
"streaming": os.getenv("CUSTOM_TTS_STREAMING", "false"),
|
||||||
|
"prompt_text": os.getenv(
|
||||||
|
"CUSTOM_TTS_PROMPT_TEXT",
|
||||||
|
os.getenv("VOXCPM_PROMPT_TEXT", "澳门有乜嘢好食嘅"),
|
||||||
|
),
|
||||||
|
"cfg_value": os.getenv("VOXCPM_CFG_VALUE", "2.0"),
|
||||||
|
"inference_timesteps": os.getenv("VOXCPM_INFERENCE_TIMESTEPS", "10"),
|
||||||
|
"do_normalize": os.getenv("VOXCPM_DO_NORMALIZE", "true"),
|
||||||
|
"denoise": os.getenv("VOXCPM_DENOISE", "true"),
|
||||||
|
"retry_badcase": os.getenv("VOXCPM_RETRY_BADCASE", "true"),
|
||||||
|
"retry_badcase_max_times": os.getenv("VOXCPM_RETRY_BADCASE_MAX_TIMES", "3"),
|
||||||
|
"retry_badcase_ratio_threshold": os.getenv(
|
||||||
|
"VOXCPM_RETRY_BADCASE_RATIO_THRESHOLD", "6.0"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif model_name == "melotts":
|
||||||
|
params["speed"] = os.getenv("CUSTOM_TTS_SPEED", "1.0")
|
||||||
|
elif model_name == "cosyvoicetts":
|
||||||
|
_set_if_present(params, "spk_id", os.getenv("CUSTOM_TTS_SPK_ID"))
|
||||||
|
_set_if_present(params, "model", os.getenv("CUSTOM_TTS_MODE"))
|
||||||
|
_set_if_present(params, "prompt_text", os.getenv("CUSTOM_TTS_PROMPT_TEXT"))
|
||||||
|
_set_if_present(params, "instruct_text", os.getenv("CUSTOM_TTS_INSTRUCT_TEXT"))
|
||||||
|
elif model_name == "sovitstts":
|
||||||
|
params.update(
|
||||||
|
{
|
||||||
|
"text_lang": os.getenv("CUSTOM_TTS_TEXT_LANG", "zh"),
|
||||||
|
"prompt_lang": os.getenv("CUSTOM_TTS_PROMPT_LANG", "zh"),
|
||||||
|
"text_split_method": os.getenv("CUSTOM_TTS_TEXT_SPLIT_METHOD", "cut0"),
|
||||||
|
"batch_size": os.getenv("CUSTOM_TTS_BATCH_SIZE", "1"),
|
||||||
|
"media_type": os.getenv("CUSTOM_TTS_MEDIA_TYPE", "wav"),
|
||||||
|
"streaming_mode": os.getenv("CUSTOM_TTS_STREAMING", "false"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
_set_if_present(params, "ref_audio_path", os.getenv("CUSTOM_TTS_REF_AUDIO_PATH"))
|
||||||
|
_set_if_present(params, "prompt_text", os.getenv("CUSTOM_TTS_PROMPT_TEXT"))
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def _set_if_present(params: dict[str, str], key: str, value: Optional[str]) -> None:
|
||||||
|
if value:
|
||||||
|
params[key] = value
|
||||||
|
|
||||||
|
|
||||||
|
def _env_int(name: str, default: int) -> int:
|
||||||
|
value = os.getenv(name)
|
||||||
|
if not value:
|
||||||
|
return default
|
||||||
|
try:
|
||||||
|
return int(value)
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("Invalid integer for %s=%r, using %s", name, value, default)
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
def _env_bool(name: str, default: bool) -> bool:
|
||||||
|
value = os.getenv(name)
|
||||||
|
if value is None:
|
||||||
|
return default
|
||||||
|
|
||||||
|
normalized = value.strip().lower()
|
||||||
|
if normalized in {"1", "true", "yes", "on"}:
|
||||||
|
return True
|
||||||
|
if normalized in {"0", "false", "no", "off"}:
|
||||||
|
return False
|
||||||
|
|
||||||
|
logger.warning("Invalid boolean for %s=%r, using %s", name, value, default)
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
def _recording_options_from_env() -> RecordingOptions:
    """Build ``RecordingOptions`` from the ``CUSTOM_RECORD_*`` environment flags.

    Every flag defaults to ``False`` when its variable is unset or invalid.
    """
    # RecordingOptions field -> environment variable controlling it.
    flag_variables = {
        "audio": "CUSTOM_RECORD_AUDIO",
        "traces": "CUSTOM_RECORD_TRACES",
        "logs": "CUSTOM_RECORD_LOGS",
        "transcript": "CUSTOM_RECORD_TRANSCRIPT",
    }
    flags = {
        field: _env_bool(variable, False) for field, variable in flag_variables.items()
    }
    return RecordingOptions(**flags)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
cli.run_app(server)
|
cli.run_app(server)
|
||||||
|
|||||||
34
test_asr.py
34
test_asr.py
@ -1,53 +1,55 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import wave
|
import wave
|
||||||
from custom_agent import SenseVoiceSTT
|
|
||||||
|
from asr import BlackboxSTT
|
||||||
from livekit import rtc
|
from livekit import rtc
|
||||||
from livekit.agents import utils
|
|
||||||
|
|
||||||
# 设置日志级别以查看输出
|
# 设置日志级别以查看输出
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger("test-asr")
|
logger = logging.getLogger("test-asr")
|
||||||
|
|
||||||
|
|
||||||
async def test():
|
async def test():
|
||||||
# 替换为你本地的一个音频文件路径
|
# 替换为你本地的一个音频文件路径
|
||||||
audio_path = "/home/verachen/Music/voice/2food.wav"
|
audio_path = "/home/verachen/Music/voice/2food.wav"
|
||||||
|
|
||||||
# 初始化 ASR
|
# 初始化 ASR
|
||||||
stt = SenseVoiceSTT(url="http://10.6.80.21:5003/asr-blackbox")
|
stt = BlackboxSTT(url="http://10.6.80.21:5003/asr-blackbox", model_name="sensevoice")
|
||||||
|
|
||||||
print(f"Testing ASR connectivity with file: {audio_path}")
|
print(f"Testing ASR connectivity with file: {audio_path}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 读取音频文件
|
# 读取音频文件
|
||||||
with wave.open(audio_path, 'rb') as wf:
|
with wave.open(audio_path, "rb") as wf:
|
||||||
frames = wf.readframes(wf.getnframes())
|
frames = wf.readframes(wf.getnframes())
|
||||||
# 简单构造一个 AudioBuffer (假设是单声道 16kHz)
|
# 简单构造一个 AudioBuffer (假设是单声道 16kHz)
|
||||||
# 实际上 SenseVoiceSTT._recognize_impl 会用 combine_audio_frames(buffer).to_wav_bytes()
|
# 实际上 BlackboxSTT._recognize_impl 会用 combine_audio_frames(buffer).to_wav_bytes()
|
||||||
# 所以我们需要传递一个包含 AudioFrame 的 list
|
# 所以我们需要传递一个包含 AudioFrame 的 list
|
||||||
|
|
||||||
# 这里我们模拟一个 Frame
|
# 这里我们模拟一个 Frame
|
||||||
frame = rtc.AudioFrame(
|
frame = rtc.AudioFrame(
|
||||||
data=frames,
|
data=frames,
|
||||||
sample_rate=wf.getframerate(),
|
sample_rate=wf.getframerate(),
|
||||||
num_channels=wf.getnchannels(),
|
num_channels=wf.getnchannels(),
|
||||||
samples_per_channel=wf.getnframes()
|
samples_per_channel=wf.getnframes(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 调用 recognize
|
# 调用 recognize
|
||||||
result = await stt.recognize(buffer=[frame])
|
result = await stt.recognize(buffer=[frame])
|
||||||
|
|
||||||
if result.alternatives:
|
if result.alternatives:
|
||||||
print(f"\n--- ASR Result ---")
|
print("\n--- ASR Result ---")
|
||||||
print(f"Text: {result.alternatives[0].text}")
|
print(f"Text: {result.alternatives[0].text}")
|
||||||
print(f"------------------\n")
|
print("------------------\n")
|
||||||
else:
|
else:
|
||||||
print("ASR returned no text.")
|
print("ASR returned no text.")
|
||||||
|
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
print(f"Error: Audio file not found at {audio_path}")
|
print(f"Error: Audio file not found at {audio_path}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"An error occurred: {e}")
|
print(f"An error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(test())
|
asyncio.run(test())
|
||||||
|
|||||||
@ -1,50 +1,66 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
|
||||||
import logging
|
import logging
|
||||||
from tts_voxcpm import VoxCPMTTS
|
import os
|
||||||
from livekit.agents import tts
|
|
||||||
|
from tts import BlackboxTTS
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
async def test_tts():
|
async def test_tts():
|
||||||
# Use the URL from the user's curl command
|
# Use the URL from the user's curl command
|
||||||
url = "http://10.6.80.21:5002/tts-blackbox"
|
url = "http://10.6.80.21:5002/tts-blackbox"
|
||||||
|
|
||||||
# Check if we have a real wav file to test with
|
# Check if we have a real wav file to test with
|
||||||
# In the earlier find_by_name, we found tests/change-sophie.wav
|
# In the earlier find_by_name, we found tests/change-sophie.wav
|
||||||
prompt_wav = "/home/verachen/Music/voice/2food.wav"
|
prompt_wav = "/home/verachen/Music/voice/2food.wav"
|
||||||
if not os.path.exists(prompt_wav):
|
if not os.path.exists(prompt_wav):
|
||||||
prompt_wav = "/home/verachen/Music/voice/2food.wav" # fallback to the one in curl
|
prompt_wav = "/home/verachen/Music/voice/2food.wav" # fallback to the one in curl
|
||||||
|
|
||||||
print(f"Testing VoxCPMTTS with URL: {url}")
|
print(f"Testing BlackboxTTS with URL: {url}")
|
||||||
print(f"Using prompt wav: {prompt_wav}")
|
print(f"Using prompt wav: {prompt_wav}")
|
||||||
|
|
||||||
vox_tts = VoxCPMTTS(
|
blackbox_tts = BlackboxTTS(
|
||||||
url=url,
|
url=url,
|
||||||
prompt_wav_path=prompt_wav
|
model_name="voxcpmtts",
|
||||||
|
prompt_wav_path=prompt_wav,
|
||||||
|
params={
|
||||||
|
"streaming": "false",
|
||||||
|
"prompt_text": "澳门有乜嘢好食嘅",
|
||||||
|
"cfg_value": "2.0",
|
||||||
|
"inference_timesteps": "10",
|
||||||
|
"do_normalize": "true",
|
||||||
|
"denoise": "true",
|
||||||
|
"retry_badcase": "true",
|
||||||
|
"retry_badcase_max_times": "3",
|
||||||
|
"retry_badcase_ratio_threshold": "6.0",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
text = "你好,这是一段测试文本"
|
text = "你好,这是一段测试文本"
|
||||||
print(f"Synthesizing text: {text}")
|
print(f"Synthesizing text: {text}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
stream = vox_tts.synthesize(text)
|
stream = blackbox_tts.synthesize(text)
|
||||||
audio_frame = await stream.collect()
|
audio_frame = await stream.collect()
|
||||||
|
|
||||||
print(f"Successfully synthesized audio!")
|
print("Successfully synthesized audio!")
|
||||||
print(f"Audio duration: {audio_frame.sample_rate * len(audio_frame.data) / (audio_frame.num_channels * 2)} samples?")
|
print(
|
||||||
|
f"Audio duration: {audio_frame.sample_rate * len(audio_frame.data) / (audio_frame.num_channels * 2)} samples?"
|
||||||
|
)
|
||||||
# Actually AudioFrame has duration or samples
|
# Actually AudioFrame has duration or samples
|
||||||
print(f"Samples: {len(audio_frame.data) // 2}")
|
print(f"Samples: {len(audio_frame.data) // 2}")
|
||||||
|
|
||||||
# Save to file for manual check if possible
|
# Save to file for manual check if possible
|
||||||
with open("test_output.wav", "wb") as f:
|
with open("test_output.wav", "wb") as f:
|
||||||
# This won't be a valid WAV yet if it's just raw PCM,
|
# This won't be a valid WAV yet if it's just raw PCM,
|
||||||
# but if collect() returns combined frames, we can use to_wav_bytes()
|
# but if collect() returns combined frames, we can use to_wav_bytes()
|
||||||
f.write(audio_frame.to_wav_bytes())
|
f.write(audio_frame.to_wav_bytes())
|
||||||
print("Saved output to test_output.wav")
|
print("Saved output to test_output.wav")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"TTS test failed: {e}")
|
print(f"TTS test failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(test_tts())
|
asyncio.run(test_tts())
|
||||||
|
|||||||
201
tts.py
Normal file
201
tts.py
Normal file
@ -0,0 +1,201 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import wave
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from io import BytesIO
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from livekit.agents import (
|
||||||
|
DEFAULT_API_CONNECT_OPTIONS,
|
||||||
|
APIConnectionError,
|
||||||
|
APIConnectOptions,
|
||||||
|
APIStatusError,
|
||||||
|
APITimeoutError,
|
||||||
|
tts,
|
||||||
|
utils,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger("blackbox-tts")
|
||||||
|
|
||||||
|
|
||||||
|
class BlackboxTTS(tts.TTS):
    """Non-streaming text-to-speech client for an HTTP "TTS blackbox" endpoint.

    Holds the endpoint URL, model name, extra form parameters and optional
    prompt-audio settings; ``synthesize`` hands them to a
    ``BlackboxTTSStream`` that performs the request.
    """

    def __init__(
        self,
        *,
        url: str,
        model_name: str = "voxcpmtts",
        params: Optional[Mapping[str, object]] = None,
        prompt_wav_path: Optional[str] = None,
        prompt_wav_field: str = "prompt_wav",
        sample_rate: int = 16000,
        num_channels: int = 1,
        timeout: float = 60.0,
        http_session: Optional[aiohttp.ClientSession] = None,
    ) -> None:
        """Configure the client.

        Args:
            url: Endpoint that receives the multipart POST.
            model_name: Sent as the ``model_name`` form field and reported by
                the ``model`` property.
            params: Extra form fields; each value is converted to text with
                ``_form_value``.
            prompt_wav_path: Optional path of a prompt audio file to upload.
            prompt_wav_field: Form field name used for the prompt audio file.
            sample_rate: Sample rate declared to the base class and emitter.
            num_channels: Channel count declared to the base class and emitter.
            timeout: Total request timeout in seconds.
            http_session: Session to reuse. When omitted, one is obtained
                from ``utils.http_context`` on first use.
        """
        super().__init__(
            capabilities=tts.TTSCapabilities(streaming=False),
            sample_rate=sample_rate,
            num_channels=num_channels,
        )

        # Convert the extra parameters to their wire format once, up front.
        form_params = {}
        if params:
            for field, raw in params.items():
                form_params[field] = _form_value(raw)

        self._url = url
        self._model_name = model_name
        self._params = form_params
        self._prompt_wav_path = prompt_wav_path
        self._prompt_wav_field = prompt_wav_field
        self._timeout = timeout
        self._http_session = http_session

    @property
    def model(self) -> str:
        """Name of the model requested from the endpoint."""
        return self._model_name

    @property
    def provider(self) -> str:
        """Fixed provider identifier for this client."""
        return "tts-blackbox"

    def _ensure_session(self) -> aiohttp.ClientSession:
        """Return the HTTP session, fetching one from the http context if needed."""
        session = self._http_session
        if session is None:
            session = utils.http_context.http_session()
            self._http_session = session
        return session

    def synthesize(
        self,
        text: str,
        *,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> tts.ChunkedStream:
        """Create the stream object that will synthesize ``text``."""
        return BlackboxTTSStream(tts=self, input_text=text, conn_options=conn_options)
|
||||||
|
|
||||||
|
|
||||||
|
class BlackboxTTSStream(tts.ChunkedStream):
    """One synthesis request against the TTS blackbox endpoint.

    Posts the text and the owning ``BlackboxTTS`` settings as a multipart
    form and forwards the response body to the audio emitter chunk by chunk.
    """

    def __init__(
        self,
        *,
        tts: BlackboxTTS,
        input_text: str,
        conn_options: APIConnectOptions,
    ) -> None:
        """Bind the stream to its ``BlackboxTTS`` owner and input text."""
        super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
        self._tts: BlackboxTTS = tts

    async def _run(self, output_emitter: tts.AudioEmitter) -> None:
        """Perform the HTTP request and push the returned audio to the emitter.

        Raises:
            APIStatusError: The endpoint answered with a non-200 status.
            APITimeoutError: The request timed out.
            APIConnectionError: An aiohttp client error occurred.
        """
        form = aiohttp.FormData(default_to_multipart=True)
        form.add_field("text", self.input_text)
        form.add_field("model_name", self._tts._model_name)
        for key, value in self._tts._params.items():
            form.add_field(key, value)

        prompt_file = None
        try:
            # The prompt file is opened inside the try/finally so the handle is
            # always closed, even if building the form or the request fails.
            prompt_path = self._tts._prompt_wav_path
            if prompt_path:
                try:
                    # Open directly instead of exists()+open(): no window in
                    # which the file can vanish between the check and the use.
                    prompt_file = open(prompt_path, "rb")
                except FileNotFoundError:
                    logger.warning(
                        "Prompt wav file not found at %s, skipping prompt_wav field",
                        prompt_path,
                    )
                else:
                    form.add_field(
                        self._tts._prompt_wav_field,
                        prompt_file,
                        filename=os.path.basename(prompt_path),
                        content_type="audio/wav",
                    )

            async with self._tts._ensure_session().post(
                self._tts._url,
                data=form,
                timeout=aiohttp.ClientTimeout(
                    total=self._tts._timeout,
                    sock_connect=self._conn_options.timeout,
                ),
            ) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    raise APIStatusError(
                        message=f"TTS blackbox error: {error_text}",
                        status_code=resp.status,
                        request_id=None,
                        body=error_text,
                    )

                content_type = resp.headers.get("Content-Type", "audio/wav")
                # Diagnostic only: accumulate leading bytes until the WAV
                # header can be parsed and logged once.
                logged_wav_format = False
                wav_header_probe = bytearray()
                output_emitter.initialize(
                    request_id=utils.shortuuid(),
                    sample_rate=self._tts.sample_rate,
                    num_channels=self._tts.num_channels,
                    mime_type=content_type,
                )

                async for data, _ in resp.content.iter_chunks():
                    if data:
                        if not logged_wav_format:
                            wav_header_probe.extend(data)
                            logged_wav_format = _log_wav_format(
                                bytes(wav_header_probe),
                                requested_sample_rate=self._tts.sample_rate,
                                requested_channels=self._tts.num_channels,
                                content_type=content_type,
                            )
                            # Give up probing once more than 4096 bytes were
                            # collected without a parsable header.
                            if not logged_wav_format and len(wav_header_probe) > 4096:
                                logger.info(
                                    "TTS blackbox WAV format probe incomplete after %s bytes",
                                    len(wav_header_probe),
                                )
                                logged_wav_format = True
                        output_emitter.push(data)
                output_emitter.flush()
        except asyncio.TimeoutError as e:
            raise APITimeoutError("TTS blackbox request timed out") from e
        except aiohttp.ClientError as e:
            raise APIConnectionError(f"TTS blackbox connection error: {e}") from e
        finally:
            if prompt_file is not None:
                prompt_file.close()
|
||||||
|
|
||||||
|
|
||||||
|
def _form_value(value: object) -> str:
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return str(value).lower()
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
|
||||||
|
def _log_wav_format(
    data: bytes,
    *,
    requested_sample_rate: int,
    requested_channels: int,
    content_type: str,
) -> bool:
    """Log the audio format declared by a WAV header, if one can be parsed.

    Args:
        data: Bytes received so far from the start of the response body.
        requested_sample_rate: Sample rate the output is targeting; only
            included in the log line.
        requested_channels: Channel count the output is targeting; only
            included in the log line.
        content_type: Value of the response ``Content-Type`` header.

    Returns:
        ``True`` when something was logged and no further probing is needed:
        either the body is not labelled as WAV (the content type is logged
        instead) or the header was parsed and its format logged. ``False``
        when ``data`` cannot be parsed as a WAV header (truncated or
        invalid), so the caller may retry with more bytes.
    """
    # WAV bodies are served under several media-type spellings; the
    # "audio/wav" prefix also covers "audio/wave" and parameterised values
    # such as "audio/wav; codecs=1".
    wav_mime_prefixes = ("audio/wav", "audio/x-wav", "audio/vnd.wave")
    if not content_type.lower().startswith(wav_mime_prefixes):
        logger.info("TTS blackbox returned content-type=%s", content_type)
        return True

    try:
        with wave.open(BytesIO(data), "rb") as wav:
            sample_rate = wav.getframerate()
            channels = wav.getnchannels()
            sample_width = wav.getsampwidth()
    except (EOFError, wave.Error):
        # Header incomplete or not understood by the ``wave`` module.
        return False

    logger.info(
        "TTS blackbox WAV format: %sHz, %sch, %s-bit; output target: %sHz, %sch",
        sample_rate,
        channels,
        sample_width * 8,  # bytes per sample -> bits per sample
        requested_sample_rate,
        requested_channels,
    )
    return True
|
||||||
118
tts_voxcpm.py
118
tts_voxcpm.py
@ -1,118 +0,0 @@
|
|||||||
import aiohttp
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from livekit.agents import tts, utils, APIConnectOptions, DEFAULT_API_CONNECT_OPTIONS
|
|
||||||
|
|
||||||
# Module-level logger shared by the VoxCPM TTS classes defined in this file.
logger = logging.getLogger("voxcpm-tts")
|
|
||||||
|
|
||||||
class VoxCPMTTS(tts.TTS):
    """Non-streaming TTS front end that hands requests to ``VoxCPMStream``.

    The constructor stores the endpoint URL, the prompt WAV path and a dict
    of string-valued option fields; ``synthesize`` passes all of them to a
    new ``VoxCPMStream`` for each call.
    """

    def __init__(
        self,
        *,
        url: str,
        model_name: str = "voxcpmtts",
        prompt_text: str = "澳门有乜嘢好食嘅",
        prompt_wav_path: str = "/home/verachen/Music/voice/2food16k_2.wav",
        cfg_value: str = "2.0",
        inference_timesteps: str = "10",
        do_normalize: str = "true",
        denoise: str = "true",
        retry_badcase: str = "true",
        retry_badcase_max_times: str = "3",
        retry_badcase_ratio_threshold: str = "6.0",
        sample_rate: int = 16000,
    ):
        """Record the endpoint and build the option fields sent per request.

        ``model_name`` and ``prompt_text`` are stored as given; every tuning
        option is converted with ``str()``. ``streaming`` is always
        ``"false"`` and the channel count is always 1.
        """
        super().__init__(
            capabilities=tts.TTSCapabilities(streaming=False),
            sample_rate=sample_rate,
            num_channels=1,
        )

        # Options that are stringified before being stored, in field order.
        tunables = {
            "cfg_value": cfg_value,
            "inference_timesteps": inference_timesteps,
            "do_normalize": do_normalize,
            "denoise": denoise,
            "retry_badcase": retry_badcase,
            "retry_badcase_max_times": retry_badcase_max_times,
            "retry_badcase_ratio_threshold": retry_badcase_ratio_threshold,
        }
        form_fields = {
            "model_name": model_name,
            "streaming": "false",
            "prompt_text": prompt_text,
        }
        form_fields.update((field, str(raw)) for field, raw in tunables.items())

        self._opts = form_fields
        self._prompt_wav_path = prompt_wav_path
        self._url = url

    @property
    def model(self) -> str:
        """Return the configured ``model_name`` option."""
        return self._opts["model_name"]

    def synthesize(
        self,
        text: str,
        *,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
    ) -> tts.ChunkedStream:
        """Create a ``VoxCPMStream`` for *text* using this instance's settings."""
        stream = VoxCPMStream(
            self,
            text,
            self._url,
            self._opts,
            self._prompt_wav_path,
            conn_options=conn_options,
        )
        return stream
|
|
||||||
|
|
||||||
class VoxCPMStream(tts.ChunkedStream):
    """Single synthesis request posted to the configured URL.

    ``_run`` sends the input text, the option fields and (when it can be
    opened) the prompt WAV as form data, then pushes the complete response
    body to the audio emitter. Request failures and non-200 responses are
    logged, not raised, so a failed request produces no audio.
    """

    def __init__(
        self,
        tts: VoxCPMTTS,
        text: str,
        url: str,
        opts: dict,
        prompt_wav_path: str,
        conn_options: APIConnectOptions,
    ):
        """Store the request parameters.

        Args:
            tts: The owning ``VoxCPMTTS`` instance.
            text: Text to synthesize.
            url: Endpoint the form is posted to.
            opts: Form fields added after ``text``, in iteration order.
            prompt_wav_path: Path of the prompt WAV to attach, if readable.
            conn_options: Connection options forwarded to the base class.
        """
        super().__init__(tts=tts, input_text=text, conn_options=conn_options)
        self._url = url
        self._opts = opts
        self._prompt_wav_path = prompt_wav_path

    def _open_prompt_wav(self):
        """Open the prompt WAV for binary reading; return ``None`` if it cannot be opened.

        The file is opened directly instead of being checked with
        ``os.path.exists`` first, so a file that disappears or is unreadable
        is handled the same way as a missing one: a warning is logged and the
        request goes out without the ``prompt_wav`` field.
        """
        try:
            return open(self._prompt_wav_path, "rb")
        except FileNotFoundError:
            logger.warning(
                "Prompt wav file not found at %s, skipping prompt_wav field",
                self._prompt_wav_path,
            )
        except (OSError, ValueError) as exc:
            # e.g. permission denied, path is a directory, embedded NUL byte.
            logger.warning(
                "Prompt wav file at %s could not be opened (%s), skipping prompt_wav field",
                self._prompt_wav_path,
                exc,
            )
        return None

    async def _run(self, output_emitter: tts.AudioEmitter) -> None:
        """Post the synthesis form and push the returned audio to *output_emitter*."""
        # Initialize emitter early to avoid "AudioEmitter isn't started" error on failure
        output_emitter.initialize(
            request_id="",
            sample_rate=self._tts.sample_rate,
            num_channels=self._tts.num_channels,
            mime_type="audio/wav",
        )

        async with aiohttp.ClientSession() as session:
            form = aiohttp.FormData()
            form.add_field("text", self.input_text)
            for key, value in self._opts.items():
                form.add_field(key, value)

            prompt_file = self._open_prompt_wav()
            try:
                # Everything that runs while the file is open sits inside the
                # try/finally so the handle is always released.
                if prompt_file is not None:
                    form.add_field(
                        "prompt_wav",
                        prompt_file,
                        filename="prompt.wav",
                        content_type="audio/wav",
                    )

                # Set a reasonable timeout for synthesis
                async with session.post(
                    self._url, data=form, timeout=aiohttp.ClientTimeout(total=60)
                ) as resp:
                    if resp.status != 200:
                        err_text = await resp.text()
                        # NOTE(review): the failure is only logged; the caller
                        # receives an empty result, not an exception.
                        logger.error("VoxCPM TTS error: %s %s", resp.status, err_text)
                        return

                    # Read the entire audio data (since streaming=false)
                    audio_data = await resp.read()

                    output_emitter.push(audio_data)
                    output_emitter.flush()
            except Exception as e:
                # Best-effort by design: log with traceback and emit no audio.
                logger.exception("VoxCPM TTS request failed: %s", e)
            finally:
                if prompt_file is not None:
                    prompt_file.close()
|
|
||||||
Reference in New Issue
Block a user