From 2b0c569b7a75f32a6ba05e687fea63fb9b7413b7 Mon Sep 17 00:00:00 2001 From: vera <511201264@qq.com> Date: Wed, 21 Jan 2026 17:52:57 +0800 Subject: [PATCH] feat: Update Model Weights, use VoxCPM1.5 and model parameters, Supports streaming --- .github/workflows/ci-cd.yaml | 2 +- README_BW.md | 33 +++++- api_concurrent.py | 210 ++++++++++++++++++++++++++--------- test_streaming.py | 34 ++++++ 4 files changed, 225 insertions(+), 54 deletions(-) create mode 100644 test_streaming.py diff --git a/.github/workflows/ci-cd.yaml b/.github/workflows/ci-cd.yaml index 7dfdde0..b4b739f 100644 --- a/.github/workflows/ci-cd.yaml +++ b/.github/workflows/ci-cd.yaml @@ -7,7 +7,7 @@ on: # branches: [ main ] env: - VERSION: 0.0.1 + VERSION: 0.0.2 REGISTRY: https://harbor.bwgdi.com REGISTRY_NAME: harbor.bwgdi.com IMAGE_NAME: voxcpmtts diff --git a/README_BW.md b/README_BW.md index 8aee73d..c761e71 100644 --- a/README_BW.md +++ b/README_BW.md @@ -6,10 +6,17 @@ https://github.com/BoardWare-Genius/VoxCPM | Version | Date | Summary | |---------|------------|---------------------------------| +| 0.0.2 | 2026-01-21 | Supports streaming | | 0.0.1 | 2026-01-20 | Initial version | ### 🔄 Version Details +#### 🆕 0.0.2 – *2026-01-21* + +- ✅ **Core Features** + - Update Model Weights, use VoxCPM1.5 and model parameters + - Supports streaming + #### 🆕 0.0.1 – *2026-01-20* - ✅ **Core Features** @@ -20,12 +27,14 @@ https://github.com/BoardWare-Genius/VoxCPM # Start ```bash -docker pull harbor.bwgdi.com/library/voxcpmtts:0.0.1 +docker pull harbor.bwgdi.com/library/voxcpmtts:0.0.2 -docker run -d --restart always -p 5001:5000 --gpus all --mount type=bind,source=/Workspace/NAS11/model,target=/models harbor.bwgdi.com/library/voxcpmtts:0.0.1 +docker run -d --restart always -p 5001:5000 --gpus all --mount type=bind,source=/Workspace/NAS11/model/Voice/VoxCPM,target=/models harbor.bwgdi.com/library/voxcpmtts:0.0.2 ``` # Usage + +## Non-streaming ```bash curl --location 'http://localhost:5001/generate_tts' \ --form 'text="你好,这是一段测试文本"' \ @@ -34,5 +43,23 @@ curl --location 'http://localhost:5001/generate_tts' \ --form 'inference_timesteps="10"' \ --form 'do_normalize="true"' \ --form 'denoise="true"' \ ---form 'prompt_wav=@"/assets/2play16k_2.wav"' +--form 'retry_badcase="true"' \ +--form 'retry_badcase_max_times="3"' \ +--form 'retry_badcase_ratio_threshold="6.0"' \ +--form 'prompt_wav=@"/assets/2food16k_2.wav"' +``` + +## Streaming +```bash +curl --location 'http://localhost:5001/generate_tts_streaming' \ +--form 'text="你好,这是一段测试文本"' \ +--form 'prompt_text="这是提示文本"' \ +--form 'cfg_value="2.0"' \ +--form 'inference_timesteps="10"' \ +--form 'do_normalize="true"' \ +--form 'denoise="true"' \ +--form 'retry_badcase="true"' \ +--form 'retry_badcase_max_times="3"' \ +--form 'retry_badcase_ratio_threshold="6.0"' \ +--form 'prompt_wav=@"/Workspace/NAS11/model/Voice/assets/2food16k_2.wav"' ``` diff --git a/api_concurrent.py b/api_concurrent.py index 3908b87..3443277 100644 --- a/api_concurrent.py +++ b/api_concurrent.py @@ -2,6 +2,9 @@ import os import torch import tempfile import soundfile as sf +import numpy as np +import wave +from io import BytesIO import asyncio from typing import Optional from fastapi import FastAPI, Form, UploadFile, BackgroundTasks @@ -13,7 +16,7 @@ import uvicorn os.environ["TOKENIZERS_PARALLELISM"] = "false" if os.environ.get("HF_REPO_ID", "").strip() == "": - os.environ["HF_REPO_ID"] = "/models/Voice/VoxCPM/VoxCPM1.5/" + os.environ["HF_REPO_ID"] = "/models/VoxCPM1.5/" # ========== 模型类 ========== @@ -22,7 +25,7 @@ class VoxCPMDemo: self.device = "cuda" if torch.cuda.is_available() else "cpu" print(f"🚀 Running on device: {self.device}") self.voxcpm_model: Optional[voxcpm.VoxCPM] = None - self.default_local_model_dir = "/models/Voice/VoxCPM/VoxCPM1.5/" + self.default_local_model_dir = "/models/VoxCPM1.5/" def _resolve_model_dir(self) -> str: if os.path.isdir(self.default_local_model_dir): @@ -46,6 +49,9 @@ class VoxCPMDemo: inference_timesteps: int = 10, do_normalize: bool = True, denoise: bool = True, + retry_badcase: bool = True, + retry_badcase_max_times: int = 3, + retry_badcase_ratio_threshold: float = 6.0, ) -> str: model = self.get_or_load_voxcpm() wav = model.generate( @@ -56,54 +62,86 @@ class VoxCPMDemo: inference_timesteps=int(inference_timesteps), normalize=do_normalize, denoise=denoise, + retry_badcase=retry_badcase, + retry_badcase_max_times=retry_badcase_max_times, + retry_badcase_ratio_threshold=retry_badcase_ratio_threshold, ) tmp_wav = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") - sf.write(tmp_wav.name, wav, 44100) + sf.write(tmp_wav.name, wav, model.tts_model.sample_rate) torch.cuda.empty_cache() return tmp_wav.name + def tts_generate_streaming( + self, + text: str, + prompt_wav_path: Optional[str] = None, + prompt_text: Optional[str] = None, + cfg_value: float = 2.0, + inference_timesteps: int = 10, + do_normalize: bool = True, + denoise: bool = True, + retry_badcase: bool = True, + retry_badcase_max_times: int = 3, + retry_badcase_ratio_threshold: float = 6.0, + ): + """Generates audio and yields it as a stream of WAV chunks.""" + model = self.get_or_load_voxcpm() + + # 1. Yield a WAV header first. + # The size fields will be 0, which is standard for streaming. + SAMPLE_RATE = model.tts_model.sample_rate + CHANNELS = 1 + SAMPLE_WIDTH = 2 # 16-bit + + header_buf = BytesIO() + with wave.open(header_buf, "wb") as wf: + wf.setnchannels(CHANNELS) + wf.setsampwidth(SAMPLE_WIDTH) + wf.setframerate(SAMPLE_RATE) + + yield header_buf.getvalue() + + # 2. Generate and yield audio chunks. + # NOTE: We assume a `generate_stream` method exists on the model that yields audio chunks. + # You may need to change `generate_stream` to the actual method name in your version of voxcpm. + try: + stream = model.generate_streaming( + text=text, + prompt_text=prompt_text, + prompt_wav_path=prompt_wav_path, + cfg_value=float(cfg_value), + inference_timesteps=int(inference_timesteps), + normalize=do_normalize, + denoise=denoise, + retry_badcase=retry_badcase, + retry_badcase_max_times=retry_badcase_max_times, + retry_badcase_ratio_threshold=retry_badcase_ratio_threshold, + ) + for chunk_np in stream: # Assuming it yields numpy arrays + # Ensure audio is in 16-bit PCM format for streaming + if chunk_np.dtype in [np.float32, np.float64]: + chunk_np = (chunk_np * 32767).astype(np.int16) + yield chunk_np.tobytes() + finally: + torch.cuda.empty_cache() + # ========== FastAPI ========== app = FastAPI(title="VoxCPM API", version="1.0.0") demo = VoxCPMDemo() -executor = ThreadPoolExecutor(max_workers=2) # CPU线程池,执行GPU任务 -gpu_queue = asyncio.Queue() # GPU并发队列 -MAX_GPU_CONCURRENT = 1 # 单GPU同时最多1个任务,可以调整 +# --- Concurrency Control --- +# Use a semaphore to limit concurrent GPU tasks. +MAX_GPU_CONCURRENT = int(os.environ.get("MAX_GPU_CONCURRENT", "1")) +gpu_semaphore = asyncio.Semaphore(MAX_GPU_CONCURRENT) +# Use a thread pool for running blocking (CPU/GPU-bound) code. +executor = ThreadPoolExecutor(max_workers=2) -# GPU队列消费者协程 -async def gpu_worker(): - while True: - future, func, args, kwargs = await gpu_queue.get() - loop = asyncio.get_running_loop() - try: - # 在线程池中执行GPU任务 - result = await loop.run_in_executor(executor, func, *args, **kwargs) - future.set_result(result) - except Exception as e: - future.set_exception(e) - finally: - gpu_queue.task_done() - - -# 启动GPU消费者协程 -async def start_gpu_workers(): - for _ in range(MAX_GPU_CONCURRENT): - asyncio.create_task(gpu_worker()) - - -@app.on_event("startup") -async def startup_event(): - await start_gpu_workers() - - -# 封装GPU队列任务 -async def submit_to_gpu(func, *args, **kwargs): - loop = asyncio.get_running_loop() - future = loop.create_future() - await gpu_queue.put((future, func, args, kwargs)) - return await future +@app.on_event("shutdown") +def shutdown_event(): + print("Shutting down thread pool executor...") + executor.shutdown(wait=True) # ---------- TTS API ---------- @@ -116,6 +154,9 @@ async def generate_tts( inference_timesteps: int = Form(10), do_normalize: bool = Form(True), denoise: bool = Form(True), + retry_badcase: bool = Form(True), + retry_badcase_max_times: int = Form(3), + retry_badcase_ratio_threshold: float = Form(6.0), prompt_wav: Optional[UploadFile] = None, ): try: @@ -126,17 +167,23 @@ async def generate_tts( prompt_path = tmp.name background_tasks.add_task(os.remove, tmp.name) - # 提交到GPU队列 - output_path = await submit_to_gpu( - demo.tts_generate, - text, - prompt_path, - prompt_text, - cfg_value, - inference_timesteps, - do_normalize, - denoise - ) + # Submit to GPU via semaphore and executor + loop = asyncio.get_running_loop() + async with gpu_semaphore: + output_path = await loop.run_in_executor( + executor, + demo.tts_generate, + text, + prompt_path, + prompt_text, + cfg_value, + inference_timesteps, + do_normalize, + denoise, + retry_badcase, + retry_badcase_max_times, + retry_badcase_ratio_threshold, + ) # 后台删除生成的文件 background_tasks.add_task(os.remove, output_path) @@ -151,9 +198,72 @@ async def generate_tts( return JSONResponse({"error": str(e)}, status_code=500) +@app.post("/generate_tts_streaming") +async def generate_tts_streaming( + background_tasks: BackgroundTasks, + text: str = Form(...), + prompt_text: Optional[str] = Form(None), + cfg_value: float = Form(2.0), + inference_timesteps: int = Form(10), + do_normalize: bool = Form(True), + denoise: bool = Form(True), + retry_badcase: bool = Form(True), + retry_badcase_max_times: int = Form(3), + retry_badcase_ratio_threshold: float = Form(6.0), + prompt_wav: Optional[UploadFile] = None, +): + try: + prompt_path = None + if prompt_wav: + # Save uploaded file to a temporary file + with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp: + tmp.write(await prompt_wav.read()) + prompt_path = tmp.name + # Ensure the temp file is deleted after the request is finished + background_tasks.add_task(os.remove, prompt_path) + + async def stream_generator(): + # This async generator consumes from a queue populated by a sync generator in a thread. + q = asyncio.Queue() + loop = asyncio.get_running_loop() + + def producer(): + # This runs in the executor thread and produces chunks. + try: + # This is a sync generator + for chunk in demo.tts_generate_streaming( + text, prompt_path, prompt_text, cfg_value, + inference_timesteps, do_normalize, denoise, + retry_badcase, retry_badcase_max_times, retry_badcase_ratio_threshold + ): + loop.call_soon_threadsafe(q.put_nowait, chunk) + except Exception as e: + # Put the exception in the queue to be re-raised in the consumer + loop.call_soon_threadsafe(q.put_nowait, e) + finally: + # Signal the end of the stream + loop.call_soon_threadsafe(q.put_nowait, None) + + # Acquire the GPU semaphore before starting the producer thread. + async with gpu_semaphore: + loop.run_in_executor(executor, producer) + while True: + chunk = await q.get() + if chunk is None: + break + if isinstance(chunk, Exception): + raise chunk + yield chunk + + return StreamingResponse(stream_generator(), media_type="audio/wav") + + except Exception as e: + return JSONResponse({"error": str(e)}, status_code=500) + + @app.get("/") async def root(): return {"message": "VoxCPM API running 🚀", "endpoints": ["/generate_tts"]} if __name__ == "__main__": - uvicorn.run("api_concurrent:app", host="0.0.0.0", port=5000, workers=8) \ No newline at end of file + uvicorn.run("api_concurrent:app", host="0.0.0.0", port=5000, workers=4) \ No newline at end of file diff --git a/test_streaming.py b/test_streaming.py new file mode 100644 index 0000000..f10166c --- /dev/null +++ b/test_streaming.py @@ -0,0 +1,34 @@ +import asyncio +import httpx +import time + +async def test_streaming(): + url = "http://localhost:8880/generate_tts_streaming" + data = { + "text": "你好,这是一段流式输出的测试音频。", + "cfg_value": "2.0", + "inference_timesteps": "10", + "do_normalize": "True", + "denoise": "True" + } + + start_time = time.time() + first_byte_received = False + + async with httpx.AsyncClient() as client: + # 模拟文件上传(如果有 prompt_wav) + # files = {'prompt_wav': open('test.wav', 'rb')} + + async with client.stream("POST", url, data=data) as response: + print(f"状态码: {response.status_code}") + + async for chunk in response.aiter_bytes(): + if not first_byte_received: + ttfb = time.time() - start_time + print(f"🚀 首包到达 (TTFB): {ttfb:.4f} 秒") + first_byte_received = True + + print(f"收到数据块: {len(chunk)} 字节") + +if __name__ == "__main__": + asyncio.run(test_streaming()) \ No newline at end of file