feat: update model weights to VoxCPM1.5, tune model parameters, and add streaming support

This commit is contained in:
vera
2026-01-21 17:52:57 +08:00
parent 721c53fe87
commit 2b0c569b7a
4 changed files with 225 additions and 54 deletions

View File

@ -7,7 +7,7 @@ on:
# branches: [ main ]
env:
VERSION: 0.0.1
VERSION: 0.0.2
REGISTRY: https://harbor.bwgdi.com
REGISTRY_NAME: harbor.bwgdi.com
IMAGE_NAME: voxcpmtts

View File

@ -6,10 +6,17 @@ https://github.com/BoardWare-Genius/VoxCPM
| Version | Date | Summary |
|---------|------------|---------------------------------|
| 0.0.2 | 2026-01-21 | Supports streaming |
| 0.0.1 | 2026-01-20 | Initial version |
### 🔄 Version Details
#### 🆕 0.0.2 *2026-01-21*
-**Core Features**
- Updated model weights to VoxCPM1.5 and adjusted model parameters
- Supports streaming
#### 🆕 0.0.1 *2026-01-20*
-**Core Features**
@ -20,12 +27,14 @@ https://github.com/BoardWare-Genius/VoxCPM
# Start
```bash
docker pull harbor.bwgdi.com/library/voxcpmtts:0.0.1
docker pull harbor.bwgdi.com/library/voxcpmtts:0.0.2
docker run -d --restart always -p 5001:5000 --gpus all --mount type=bind,source=/Workspace/NAS11/model,target=/models harbor.bwgdi.com/library/voxcpmtts:0.0.1
docker run -d --restart always -p 5001:5000 --gpus all --mount type=bind,source=/Workspace/NAS11/model/Voice/VoxCPM,target=/models harbor.bwgdi.com/library/voxcpmtts:0.0.2
```
# Usage
## Non-streaming
```bash
curl --location 'http://localhost:5001/generate_tts' \
--form 'text="你好,这是一段测试文本"' \
@ -34,5 +43,23 @@ curl --location 'http://localhost:5001/generate_tts' \
--form 'inference_timesteps="10"' \
--form 'do_normalize="true"' \
--form 'denoise="true"' \
--form 'prompt_wav=@"/assets/2play16k_2.wav"'
--form 'retry_badcase="true"' \
--form 'retry_badcase_max_times="3"' \
--form 'retry_badcase_ratio_threshold="6.0"' \
--form 'prompt_wav=@"/assets/2food16k_2.wav"'
```
## Streaming
```bash
curl --location 'http://localhost:5001/generate_tts_streaming' \
--form 'text="你好,这是一段测试文本"' \
--form 'prompt_text="这是提示文本"' \
--form 'cfg_value="2.0"' \
--form 'inference_timesteps="10"' \
--form 'do_normalize="true"' \
--form 'denoise="true"' \
--form 'retry_badcase="true"' \
--form 'retry_badcase_max_times="3"' \
--form 'retry_badcase_ratio_threshold="6.0"' \
--form 'prompt_wav=@"/Workspace/NAS11/model/Voice/assets/2food16k_2.wav"'
```

View File

@ -2,6 +2,9 @@ import os
import torch
import tempfile
import soundfile as sf
import numpy as np
import wave
from io import BytesIO
import asyncio
from typing import Optional
from fastapi import FastAPI, Form, UploadFile, BackgroundTasks
@ -13,7 +16,7 @@ import uvicorn
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if os.environ.get("HF_REPO_ID", "").strip() == "":
os.environ["HF_REPO_ID"] = "/models/Voice/VoxCPM/VoxCPM1.5/"
os.environ["HF_REPO_ID"] = "/models/VoxCPM1.5/"
# ========== 模型类 ==========
@ -22,7 +25,7 @@ class VoxCPMDemo:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🚀 Running on device: {self.device}")
self.voxcpm_model: Optional[voxcpm.VoxCPM] = None
self.default_local_model_dir = "/models/Voice/VoxCPM/VoxCPM1.5/"
self.default_local_model_dir = "/models/VoxCPM1.5/"
def _resolve_model_dir(self) -> str:
if os.path.isdir(self.default_local_model_dir):
@ -46,6 +49,9 @@ class VoxCPMDemo:
inference_timesteps: int = 10,
do_normalize: bool = True,
denoise: bool = True,
retry_badcase: bool = True,
retry_badcase_max_times: int = 3,
retry_badcase_ratio_threshold: float = 6.0,
) -> str:
model = self.get_or_load_voxcpm()
wav = model.generate(
@ -56,54 +62,86 @@ class VoxCPMDemo:
inference_timesteps=int(inference_timesteps),
normalize=do_normalize,
denoise=denoise,
retry_badcase=retry_badcase,
retry_badcase_max_times=retry_badcase_max_times,
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
)
tmp_wav = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
sf.write(tmp_wav.name, wav, 44100)
sf.write(tmp_wav.name, wav, model.tts_model.sample_rate)
torch.cuda.empty_cache()
return tmp_wav.name
def tts_generate_streaming(
    self,
    text: str,
    prompt_wav_path: Optional[str] = None,
    prompt_text: Optional[str] = None,
    cfg_value: float = 2.0,
    inference_timesteps: int = 10,
    do_normalize: bool = True,
    denoise: bool = True,
    retry_badcase: bool = True,
    retry_badcase_max_times: int = 3,
    retry_badcase_ratio_threshold: float = 6.0,
):
    """Generate speech for *text* and yield it as a stream of WAV bytes.

    Yields a WAV header first, then raw 16-bit mono PCM chunks as the
    model produces them.

    Args:
        text: Text to synthesize.
        prompt_wav_path: Optional path to a reference/prompt wav file.
        prompt_text: Optional transcript of the prompt audio.
        cfg_value: Classifier-free-guidance strength.
        inference_timesteps: Number of diffusion/inference steps.
        do_normalize: Whether to text-normalize the input.
        denoise: Whether to denoise the prompt audio.
        retry_badcase: Retry generation when the model flags a bad case.
        retry_badcase_max_times: Max number of bad-case retries.
        retry_badcase_ratio_threshold: Ratio threshold for bad-case detection.

    Yields:
        bytes: first a WAV header, then 16-bit PCM audio chunks.
    """
    model = self.get_or_load_voxcpm()

    # 1. Emit a WAV header up front so clients can begin playback
    #    immediately. The RIFF/data size fields are 0 because the total
    #    length is unknown while streaming; this is the usual convention
    #    for streamed WAV.
    sample_rate = model.tts_model.sample_rate
    channels = 1
    sample_width = 2  # bytes per sample -> 16-bit PCM

    header_buf = BytesIO()
    with wave.open(header_buf, "wb") as wf:
        wf.setnchannels(channels)
        wf.setsampwidth(sample_width)
        wf.setframerate(sample_rate)
    yield header_buf.getvalue()

    # 2. Generate and yield audio chunks from the model's streaming API.
    try:
        stream = model.generate_streaming(
            text=text,
            prompt_text=prompt_text,
            prompt_wav_path=prompt_wav_path,
            cfg_value=float(cfg_value),
            inference_timesteps=int(inference_timesteps),
            normalize=do_normalize,
            denoise=denoise,
            retry_badcase=retry_badcase,
            retry_badcase_max_times=retry_badcase_max_times,
            retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
        )
        # assumes the stream yields 1-D numpy arrays — TODO confirm against
        # the installed voxcpm version.
        for chunk_np in stream:
            if chunk_np.dtype in (np.float32, np.float64):
                # Clip before scaling: out-of-range float samples would
                # otherwise wrap around when cast to int16 and produce
                # loud clicking artifacts.
                chunk_np = (np.clip(chunk_np, -1.0, 1.0) * 32767).astype(np.int16)
            yield chunk_np.tobytes()
    finally:
        # Release cached GPU memory even if the consumer abandons the stream.
        torch.cuda.empty_cache()
# ========== FastAPI ==========
app = FastAPI(title="VoxCPM API", version="1.0.0")
demo = VoxCPMDemo()

# NOTE(review): diff artifact — the three lines below are the old
# queue-based concurrency setup and are superseded by the semaphore-based
# setup further down; only one of the two variants should survive.
executor = ThreadPoolExecutor(max_workers=2)  # CPU thread pool that runs the GPU tasks
gpu_queue = asyncio.Queue()  # GPU task queue
MAX_GPU_CONCURRENT = 1  # at most 1 concurrent task on a single GPU; adjustable

# --- Concurrency Control ---
# Use a semaphore to limit concurrent GPU tasks.
MAX_GPU_CONCURRENT = int(os.environ.get("MAX_GPU_CONCURRENT", "1"))
gpu_semaphore = asyncio.Semaphore(MAX_GPU_CONCURRENT)
# Use a thread pool for running blocking (CPU/GPU-bound) code.
executor = ThreadPoolExecutor(max_workers=2)
# GPU queue consumer coroutine
async def gpu_worker():
    """Forever pull (future, func, args, kwargs) jobs off ``gpu_queue``,
    run them on the thread pool, and resolve the future with the outcome."""
    while True:
        pending_future, job, job_args, job_kwargs = await gpu_queue.get()
        event_loop = asyncio.get_running_loop()
        try:
            # Execute the GPU-bound job on the thread pool so the event
            # loop stays responsive.
            pending_future.set_result(
                await event_loop.run_in_executor(executor, job, *job_args, **job_kwargs)
            )
        except Exception as exc:
            # Propagate the failure to whoever awaits the future.
            pending_future.set_exception(exc)
        finally:
            gpu_queue.task_done()
# Launch the GPU consumer coroutines
async def start_gpu_workers():
    """Spawn one ``gpu_worker`` task per allowed concurrent GPU job."""
    for _worker_index in range(MAX_GPU_CONCURRENT):
        asyncio.create_task(gpu_worker())
@app.on_event("startup")
async def startup_event():
await start_gpu_workers()
# Wrapper for submitting a job to the GPU queue
async def submit_to_gpu(func, *args, **kwargs):
    """Enqueue *func* for the GPU workers and await its result."""
    result_future = asyncio.get_running_loop().create_future()
    await gpu_queue.put((result_future, func, args, kwargs))
    return await result_future
@app.on_event("shutdown")
def shutdown_event():
print("Shutting down thread pool executor...")
executor.shutdown(wait=True)
# ---------- TTS API ----------
@ -116,6 +154,9 @@ async def generate_tts(
inference_timesteps: int = Form(10),
do_normalize: bool = Form(True),
denoise: bool = Form(True),
retry_badcase: bool = Form(True),
retry_badcase_max_times: int = Form(3),
retry_badcase_ratio_threshold: float = Form(6.0),
prompt_wav: Optional[UploadFile] = None,
):
try:
@ -126,17 +167,23 @@ async def generate_tts(
prompt_path = tmp.name
background_tasks.add_task(os.remove, tmp.name)
# 提交到GPU队列
output_path = await submit_to_gpu(
demo.tts_generate,
text,
prompt_path,
prompt_text,
cfg_value,
inference_timesteps,
do_normalize,
denoise
)
# Submit to GPU via semaphore and executor
loop = asyncio.get_running_loop()
async with gpu_semaphore:
output_path = await loop.run_in_executor(
executor,
demo.tts_generate,
text,
prompt_path,
prompt_text,
cfg_value,
inference_timesteps,
do_normalize,
denoise,
retry_badcase,
retry_badcase_max_times,
retry_badcase_ratio_threshold,
)
# 后台删除生成的文件
background_tasks.add_task(os.remove, output_path)
@ -151,9 +198,72 @@ async def generate_tts(
return JSONResponse({"error": str(e)}, status_code=500)
@app.post("/generate_tts_streaming")
async def generate_tts_streaming(
background_tasks: BackgroundTasks,
text: str = Form(...),
prompt_text: Optional[str] = Form(None),
cfg_value: float = Form(2.0),
inference_timesteps: int = Form(10),
do_normalize: bool = Form(True),
denoise: bool = Form(True),
retry_badcase: bool = Form(True),
retry_badcase_max_times: int = Form(3),
retry_badcase_ratio_threshold: float = Form(6.0),
prompt_wav: Optional[UploadFile] = None,
):
try:
prompt_path = None
if prompt_wav:
# Save uploaded file to a temporary file
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
tmp.write(await prompt_wav.read())
prompt_path = tmp.name
# Ensure the temp file is deleted after the request is finished
background_tasks.add_task(os.remove, prompt_path)
async def stream_generator():
# This async generator consumes from a queue populated by a sync generator in a thread.
q = asyncio.Queue()
loop = asyncio.get_running_loop()
def producer():
# This runs in the executor thread and produces chunks.
try:
# This is a sync generator
for chunk in demo.tts_generate_streaming(
text, prompt_path, prompt_text, cfg_value,
inference_timesteps, do_normalize, denoise,
retry_badcase, retry_badcase_max_times, retry_badcase_ratio_threshold
):
loop.call_soon_threadsafe(q.put_nowait, chunk)
except Exception as e:
# Put the exception in the queue to be re-raised in the consumer
loop.call_soon_threadsafe(q.put_nowait, e)
finally:
# Signal the end of the stream
loop.call_soon_threadsafe(q.put_nowait, None)
# Acquire the GPU semaphore before starting the producer thread.
async with gpu_semaphore:
loop.run_in_executor(executor, producer)
while True:
chunk = await q.get()
if chunk is None:
break
if isinstance(chunk, Exception):
raise chunk
yield chunk
return StreamingResponse(stream_generator(), media_type="audio/wav")
except Exception as e:
return JSONResponse({"error": str(e)}, status_code=500)
@app.get("/")
async def root():
return {"message": "VoxCPM API running 🚀", "endpoints": ["/generate_tts"]}
if __name__ == "__main__":
uvicorn.run("api_concurrent:app", host="0.0.0.0", port=5000, workers=8)
uvicorn.run("api_concurrent:app", host="0.0.0.0", port=5000, workers=4)

34
test_streaming.py Normal file
View File

@ -0,0 +1,34 @@
import asyncio
import httpx
import time
async def test_streaming():
    """Smoke-test the /generate_tts_streaming endpoint and report TTFB.

    Posts a form-encoded request and prints the time-to-first-byte and the
    size of every chunk received from the streaming response.
    """
    # NOTE(review): the README publishes the service on host port 5001
    # (docker -p 5001:5000); the original script pointed at 8880 — confirm
    # which port your deployment actually exposes.
    url = "http://localhost:5001/generate_tts_streaming"
    data = {
        "text": "你好,这是一段流式输出的测试音频。",
        "cfg_value": "2.0",
        "inference_timesteps": "10",
        "do_normalize": "True",
        "denoise": "True",
    }

    start_time = time.time()
    first_byte_received = False

    async with httpx.AsyncClient() as client:
        # Simulated file upload (if a prompt_wav is needed):
        # files = {'prompt_wav': open('test.wav', 'rb')}
        async with client.stream("POST", url, data=data) as response:
            print(f"状态码: {response.status_code}")
            async for chunk in response.aiter_bytes():
                if not first_byte_received:
                    # Time-to-first-byte: measures streaming latency, the
                    # whole point of the streaming endpoint.
                    ttfb = time.time() - start_time
                    print(f"🚀 首包到达 (TTFB): {ttfb:.4f}")
                    first_byte_received = True
                print(f"收到数据块: {len(chunk)} 字节")
if __name__ == "__main__":
asyncio.run(test_streaming())