perf: improve speed
This commit is contained in:
24
tts.py
24
tts.py
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import wave
|
||||
from collections.abc import Mapping
|
||||
from io import BytesIO
|
||||
@ -88,6 +89,7 @@ class BlackboxTTSStream(tts.ChunkedStream):
|
||||
self._tts: BlackboxTTS = tts
|
||||
|
||||
async def _run(self, output_emitter: tts.AudioEmitter) -> None:
|
||||
started_at = time.perf_counter()
|
||||
form = aiohttp.FormData(default_to_multipart=True)
|
||||
form.add_field("text", self.input_text)
|
||||
form.add_field("model_name", self._tts._model_name)
|
||||
@ -131,6 +133,9 @@ class BlackboxTTSStream(tts.ChunkedStream):
|
||||
content_type = resp.headers.get("Content-Type", "audio/wav")
|
||||
logged_wav_format = False
|
||||
wav_header_probe = bytearray()
|
||||
first_audio_at: float | None = None
|
||||
chunk_count = 0
|
||||
total_bytes = 0
|
||||
output_emitter.initialize(
|
||||
request_id=utils.shortuuid(),
|
||||
sample_rate=self._tts.sample_rate,
|
||||
@ -140,6 +145,16 @@ class BlackboxTTSStream(tts.ChunkedStream):
|
||||
|
||||
async for data, _ in resp.content.iter_chunks():
|
||||
if data:
|
||||
chunk_count += 1
|
||||
total_bytes += len(data)
|
||||
if first_audio_at is None:
|
||||
first_audio_at = time.perf_counter()
|
||||
logger.info(
|
||||
"TTS first audio chunk after %.3fs (text_len=%s, bytes=%s)",
|
||||
first_audio_at - started_at,
|
||||
len(self.input_text),
|
||||
len(data),
|
||||
)
|
||||
if not logged_wav_format:
|
||||
wav_header_probe.extend(data)
|
||||
logged_wav_format = _log_wav_format(
|
||||
@ -156,6 +171,15 @@ class BlackboxTTSStream(tts.ChunkedStream):
|
||||
logged_wav_format = True
|
||||
output_emitter.push(data)
|
||||
output_emitter.flush()
|
||||
finished_at = time.perf_counter()
|
||||
logger.info(
|
||||
"TTS stream completed in %.3fs (first_chunk=%.3fs, chunks=%s, bytes=%s, text_len=%s)",
|
||||
finished_at - started_at,
|
||||
(first_audio_at - started_at) if first_audio_at else -1.0,
|
||||
chunk_count,
|
||||
total_bytes,
|
||||
len(self.input_text),
|
||||
)
|
||||
except asyncio.TimeoutError as e:
|
||||
raise APITimeoutError("TTS blackbox request timed out") from e
|
||||
except aiohttp.ClientError as e:
|
||||
|
||||
Reference in New Issue
Block a user