feat: Update Model Weights, use VoxCPM1.5 and model parameters, Supports streaming
This commit is contained in:
2
.github/workflows/ci-cd.yaml
vendored
2
.github/workflows/ci-cd.yaml
vendored
@ -7,7 +7,7 @@ on:
|
||||
# branches: [ main ]
|
||||
|
||||
env:
|
||||
VERSION: 0.0.1
|
||||
VERSION: 0.0.2
|
||||
REGISTRY: https://harbor.bwgdi.com
|
||||
REGISTRY_NAME: harbor.bwgdi.com
|
||||
IMAGE_NAME: voxcpmtts
|
||||
|
||||
33
README_BW.md
33
README_BW.md
@ -6,10 +6,17 @@ https://github.com/BoardWare-Genius/VoxCPM
|
||||
|
||||
| Version | Date | Summary |
|
||||
|---------|------------|---------------------------------|
|
||||
| 0.0.2 | 2026-01-21 | Supports streaming |
|
||||
| 0.0.1 | 2026-01-20 | Initial version |
|
||||
|
||||
### 🔄 Version Details
|
||||
|
||||
#### 🆕 0.0.2 – *2026-01-21*
|
||||
|
||||
- ✅ **Core Features**
|
||||
- Update Model Weights, use VoxCPM1.5 and model parameters
|
||||
- Supports streaming
|
||||
|
||||
#### 🆕 0.0.1 – *2026-01-20*
|
||||
|
||||
- ✅ **Core Features**
|
||||
@ -20,12 +27,14 @@ https://github.com/BoardWare-Genius/VoxCPM
|
||||
|
||||
# Start
|
||||
```bash
|
||||
docker pull harbor.bwgdi.com/library/voxcpmtts:0.0.1
|
||||
docker pull harbor.bwgdi.com/library/voxcpmtts:0.0.2
|
||||
|
||||
docker run -d --restart always -p 5001:5000 --gpus all --mount type=bind,source=/Workspace/NAS11/model,target=/models harbor.bwgdi.com/library/voxcpmtts:0.0.1
|
||||
docker run -d --restart always -p 5001:5000 --gpus all --mount type=bind,source=/Workspace/NAS11/model/Voice/VoxCPM,target=/models harbor.bwgdi.com/library/voxcpmtts:0.0.2
|
||||
```
|
||||
|
||||
# Usage
|
||||
|
||||
## Non-streaming
|
||||
```bash
|
||||
curl --location 'http://localhost:5001/generate_tts' \
|
||||
--form 'text="你好,这是一段测试文本"' \
|
||||
@ -34,5 +43,23 @@ curl --location 'http://localhost:5001/generate_tts' \
|
||||
--form 'inference_timesteps="10"' \
|
||||
--form 'do_normalize="true"' \
|
||||
--form 'denoise="true"' \
|
||||
--form 'prompt_wav=@"/assets/2play16k_2.wav"'
|
||||
--form 'retry_badcase="true"' \
|
||||
--form 'retry_badcase_max_times="3"' \
|
||||
--form 'retry_badcase_ratio_threshold="6.0"' \
|
||||
--form 'prompt_wav=@"/assets/2food16k_2.wav"'
|
||||
```
|
||||
|
||||
## Streaming
|
||||
```bash
|
||||
curl --location 'http://localhost:5001/generate_tts_streaming' \
|
||||
--form 'text="你好,这是一段测试文本"' \
|
||||
--form 'prompt_text="这是提示文本"' \
|
||||
--form 'cfg_value="2.0"' \
|
||||
--form 'inference_timesteps="10"' \
|
||||
--form 'do_normalize="true"' \
|
||||
--form 'denoise="true"' \
|
||||
--form 'retry_badcase="true"' \
|
||||
--form 'retry_badcase_max_times="3"' \
|
||||
--form 'retry_badcase_ratio_threshold="6.0"' \
|
||||
--form 'prompt_wav=@"/Workspace/NAS11/model/Voice/assets/2food16k_2.wav"'
|
||||
```
|
||||
|
||||
@ -2,6 +2,9 @@ import os
|
||||
import torch
|
||||
import tempfile
|
||||
import soundfile as sf
|
||||
import numpy as np
|
||||
import wave
|
||||
from io import BytesIO
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
from fastapi import FastAPI, Form, UploadFile, BackgroundTasks
|
||||
@ -13,7 +16,7 @@ import uvicorn
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
if os.environ.get("HF_REPO_ID", "").strip() == "":
|
||||
os.environ["HF_REPO_ID"] = "/models/Voice/VoxCPM/VoxCPM1.5/"
|
||||
os.environ["HF_REPO_ID"] = "/models/VoxCPM1.5/"
|
||||
|
||||
|
||||
# ========== 模型类 ==========
|
||||
@ -22,7 +25,7 @@ class VoxCPMDemo:
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(f"🚀 Running on device: {self.device}")
|
||||
self.voxcpm_model: Optional[voxcpm.VoxCPM] = None
|
||||
self.default_local_model_dir = "/models/Voice/VoxCPM/VoxCPM1.5/"
|
||||
self.default_local_model_dir = "/models/VoxCPM1.5/"
|
||||
|
||||
def _resolve_model_dir(self) -> str:
|
||||
if os.path.isdir(self.default_local_model_dir):
|
||||
@ -46,6 +49,9 @@ class VoxCPMDemo:
|
||||
inference_timesteps: int = 10,
|
||||
do_normalize: bool = True,
|
||||
denoise: bool = True,
|
||||
retry_badcase: bool = True,
|
||||
retry_badcase_max_times: int = 3,
|
||||
retry_badcase_ratio_threshold: float = 6.0,
|
||||
) -> str:
|
||||
model = self.get_or_load_voxcpm()
|
||||
wav = model.generate(
|
||||
@ -56,54 +62,86 @@ class VoxCPMDemo:
|
||||
inference_timesteps=int(inference_timesteps),
|
||||
normalize=do_normalize,
|
||||
denoise=denoise,
|
||||
retry_badcase=retry_badcase,
|
||||
retry_badcase_max_times=retry_badcase_max_times,
|
||||
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
|
||||
)
|
||||
tmp_wav = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
|
||||
sf.write(tmp_wav.name, wav, 44100)
|
||||
sf.write(tmp_wav.name, wav, model.tts_model.sample_rate)
|
||||
torch.cuda.empty_cache()
|
||||
return tmp_wav.name
|
||||
|
||||
def tts_generate_streaming(
|
||||
self,
|
||||
text: str,
|
||||
prompt_wav_path: Optional[str] = None,
|
||||
prompt_text: Optional[str] = None,
|
||||
cfg_value: float = 2.0,
|
||||
inference_timesteps: int = 10,
|
||||
do_normalize: bool = True,
|
||||
denoise: bool = True,
|
||||
retry_badcase: bool = True,
|
||||
retry_badcase_max_times: int = 3,
|
||||
retry_badcase_ratio_threshold: float = 6.0,
|
||||
):
|
||||
"""Generates audio and yields it as a stream of WAV chunks."""
|
||||
model = self.get_or_load_voxcpm()
|
||||
|
||||
# 1. Yield a WAV header first.
|
||||
# The size fields will be 0, which is standard for streaming.
|
||||
SAMPLE_RATE = model.tts_model.sample_rate
|
||||
CHANNELS = 1
|
||||
SAMPLE_WIDTH = 2 # 16-bit
|
||||
|
||||
header_buf = BytesIO()
|
||||
with wave.open(header_buf, "wb") as wf:
|
||||
wf.setnchannels(CHANNELS)
|
||||
wf.setsampwidth(SAMPLE_WIDTH)
|
||||
wf.setframerate(SAMPLE_RATE)
|
||||
|
||||
yield header_buf.getvalue()
|
||||
|
||||
# 2. Generate and yield audio chunks.
|
||||
# NOTE: We assume a `generate_stream` method exists on the model that yields audio chunks.
|
||||
# You may need to change `generate_stream` to the actual method name in your version of voxcpm.
|
||||
try:
|
||||
stream = model.generate_streaming(
|
||||
text=text,
|
||||
prompt_text=prompt_text,
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
cfg_value=float(cfg_value),
|
||||
inference_timesteps=int(inference_timesteps),
|
||||
normalize=do_normalize,
|
||||
denoise=denoise,
|
||||
retry_badcase=retry_badcase,
|
||||
retry_badcase_max_times=retry_badcase_max_times,
|
||||
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
|
||||
)
|
||||
for chunk_np in stream: # Assuming it yields numpy arrays
|
||||
# Ensure audio is in 16-bit PCM format for streaming
|
||||
if chunk_np.dtype in [np.float32, np.float64]:
|
||||
chunk_np = (chunk_np * 32767).astype(np.int16)
|
||||
yield chunk_np.tobytes()
|
||||
finally:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
# ========== FastAPI ==========
|
||||
app = FastAPI(title="VoxCPM API", version="1.0.0")
|
||||
demo = VoxCPMDemo()
|
||||
|
||||
executor = ThreadPoolExecutor(max_workers=2) # CPU线程池,执行GPU任务
|
||||
gpu_queue = asyncio.Queue() # GPU并发队列
|
||||
MAX_GPU_CONCURRENT = 1 # 单GPU同时最多1个任务,可以调整
|
||||
# --- Concurrency Control ---
|
||||
# Use a semaphore to limit concurrent GPU tasks.
|
||||
MAX_GPU_CONCURRENT = int(os.environ.get("MAX_GPU_CONCURRENT", "1"))
|
||||
gpu_semaphore = asyncio.Semaphore(MAX_GPU_CONCURRENT)
|
||||
|
||||
# Use a thread pool for running blocking (CPU/GPU-bound) code.
|
||||
executor = ThreadPoolExecutor(max_workers=2)
|
||||
|
||||
# GPU队列消费者协程
|
||||
async def gpu_worker():
|
||||
while True:
|
||||
future, func, args, kwargs = await gpu_queue.get()
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
# 在线程池中执行GPU任务
|
||||
result = await loop.run_in_executor(executor, func, *args, **kwargs)
|
||||
future.set_result(result)
|
||||
except Exception as e:
|
||||
future.set_exception(e)
|
||||
finally:
|
||||
gpu_queue.task_done()
|
||||
|
||||
|
||||
# 启动GPU消费者协程
|
||||
async def start_gpu_workers():
|
||||
for _ in range(MAX_GPU_CONCURRENT):
|
||||
asyncio.create_task(gpu_worker())
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
await start_gpu_workers()
|
||||
|
||||
|
||||
# 封装GPU队列任务
|
||||
async def submit_to_gpu(func, *args, **kwargs):
|
||||
loop = asyncio.get_running_loop()
|
||||
future = loop.create_future()
|
||||
await gpu_queue.put((future, func, args, kwargs))
|
||||
return await future
|
||||
@app.on_event("shutdown")
|
||||
def shutdown_event():
|
||||
print("Shutting down thread pool executor...")
|
||||
executor.shutdown(wait=True)
|
||||
|
||||
|
||||
# ---------- TTS API ----------
|
||||
@ -116,6 +154,9 @@ async def generate_tts(
|
||||
inference_timesteps: int = Form(10),
|
||||
do_normalize: bool = Form(True),
|
||||
denoise: bool = Form(True),
|
||||
retry_badcase: bool = Form(True),
|
||||
retry_badcase_max_times: int = Form(3),
|
||||
retry_badcase_ratio_threshold: float = Form(6.0),
|
||||
prompt_wav: Optional[UploadFile] = None,
|
||||
):
|
||||
try:
|
||||
@ -126,17 +167,23 @@ async def generate_tts(
|
||||
prompt_path = tmp.name
|
||||
background_tasks.add_task(os.remove, tmp.name)
|
||||
|
||||
# 提交到GPU队列
|
||||
output_path = await submit_to_gpu(
|
||||
demo.tts_generate,
|
||||
text,
|
||||
prompt_path,
|
||||
prompt_text,
|
||||
cfg_value,
|
||||
inference_timesteps,
|
||||
do_normalize,
|
||||
denoise
|
||||
)
|
||||
# Submit to GPU via semaphore and executor
|
||||
loop = asyncio.get_running_loop()
|
||||
async with gpu_semaphore:
|
||||
output_path = await loop.run_in_executor(
|
||||
executor,
|
||||
demo.tts_generate,
|
||||
text,
|
||||
prompt_path,
|
||||
prompt_text,
|
||||
cfg_value,
|
||||
inference_timesteps,
|
||||
do_normalize,
|
||||
denoise,
|
||||
retry_badcase,
|
||||
retry_badcase_max_times,
|
||||
retry_badcase_ratio_threshold,
|
||||
)
|
||||
|
||||
# 后台删除生成的文件
|
||||
background_tasks.add_task(os.remove, output_path)
|
||||
@ -151,9 +198,72 @@ async def generate_tts(
|
||||
return JSONResponse({"error": str(e)}, status_code=500)
|
||||
|
||||
|
||||
@app.post("/generate_tts_streaming")
|
||||
async def generate_tts_streaming(
|
||||
background_tasks: BackgroundTasks,
|
||||
text: str = Form(...),
|
||||
prompt_text: Optional[str] = Form(None),
|
||||
cfg_value: float = Form(2.0),
|
||||
inference_timesteps: int = Form(10),
|
||||
do_normalize: bool = Form(True),
|
||||
denoise: bool = Form(True),
|
||||
retry_badcase: bool = Form(True),
|
||||
retry_badcase_max_times: int = Form(3),
|
||||
retry_badcase_ratio_threshold: float = Form(6.0),
|
||||
prompt_wav: Optional[UploadFile] = None,
|
||||
):
|
||||
try:
|
||||
prompt_path = None
|
||||
if prompt_wav:
|
||||
# Save uploaded file to a temporary file
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
|
||||
tmp.write(await prompt_wav.read())
|
||||
prompt_path = tmp.name
|
||||
# Ensure the temp file is deleted after the request is finished
|
||||
background_tasks.add_task(os.remove, prompt_path)
|
||||
|
||||
async def stream_generator():
|
||||
# This async generator consumes from a queue populated by a sync generator in a thread.
|
||||
q = asyncio.Queue()
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
def producer():
|
||||
# This runs in the executor thread and produces chunks.
|
||||
try:
|
||||
# This is a sync generator
|
||||
for chunk in demo.tts_generate_streaming(
|
||||
text, prompt_path, prompt_text, cfg_value,
|
||||
inference_timesteps, do_normalize, denoise,
|
||||
retry_badcase, retry_badcase_max_times, retry_badcase_ratio_threshold
|
||||
):
|
||||
loop.call_soon_threadsafe(q.put_nowait, chunk)
|
||||
except Exception as e:
|
||||
# Put the exception in the queue to be re-raised in the consumer
|
||||
loop.call_soon_threadsafe(q.put_nowait, e)
|
||||
finally:
|
||||
# Signal the end of the stream
|
||||
loop.call_soon_threadsafe(q.put_nowait, None)
|
||||
|
||||
# Acquire the GPU semaphore before starting the producer thread.
|
||||
async with gpu_semaphore:
|
||||
loop.run_in_executor(executor, producer)
|
||||
while True:
|
||||
chunk = await q.get()
|
||||
if chunk is None:
|
||||
break
|
||||
if isinstance(chunk, Exception):
|
||||
raise chunk
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(stream_generator(), media_type="audio/wav")
|
||||
|
||||
except Exception as e:
|
||||
return JSONResponse({"error": str(e)}, status_code=500)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"message": "VoxCPM API running 🚀", "endpoints": ["/generate_tts"]}
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run("api_concurrent:app", host="0.0.0.0", port=5000, workers=8)
|
||||
uvicorn.run("api_concurrent:app", host="0.0.0.0", port=5000, workers=4)
|
||||
34
test_streaming.py
Normal file
34
test_streaming.py
Normal file
@ -0,0 +1,34 @@
|
||||
import asyncio
|
||||
import httpx
|
||||
import time
|
||||
|
||||
async def test_streaming():
|
||||
url = "http://localhost:8880/generate_tts_streaming"
|
||||
data = {
|
||||
"text": "你好,这是一段流式输出的测试音频。",
|
||||
"cfg_value": "2.0",
|
||||
"inference_timesteps": "10",
|
||||
"do_normalize": "True",
|
||||
"denoise": "True"
|
||||
}
|
||||
|
||||
start_time = time.time()
|
||||
first_byte_received = False
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
# 模拟文件上传(如果有 prompt_wav)
|
||||
# files = {'prompt_wav': open('test.wav', 'rb')}
|
||||
|
||||
async with client.stream("POST", url, data=data) as response:
|
||||
print(f"状态码: {response.status_code}")
|
||||
|
||||
async for chunk in response.aiter_bytes():
|
||||
if not first_byte_received:
|
||||
ttfb = time.time() - start_time
|
||||
print(f"🚀 首包到达 (TTFB): {ttfb:.4f} 秒")
|
||||
first_byte_received = True
|
||||
|
||||
print(f"收到数据块: {len(chunk)} 字节")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_streaming())
|
||||
Reference in New Issue
Block a user