Compare commits

..

11 Commits

Author SHA1 Message Date
41b5d321f6 fix: ci-cd (CI: Build container / build-docker succeeded in 29m27s) 2026-02-09 19:23:11 +08:00
49367d03a8 feat: Optimized configuration 2026-02-09 18:28:49 +08:00
cd4584ebae feat: Optimized configuration 2026-02-09 18:22:06 +08:00
76c8bdbcfc docs: docker gpu use 2026-01-21 17:55:36 +08:00
2b0c569b7a feat: Update Model Weights, use VoxCPM1.5 and model parameters, Supports streaming 2026-01-21 17:52:57 +08:00
721c53fe87 first commit 2026-01-20 17:37:43 +08:00
e8dd956fc2 Print all log messages to stderr instead of stdout 2026-01-12 15:30:45 +08:00
db75a7269b Merge pull request #141 from vytskalt/main (Print debug messages to stderr instead of stdout) 2026-01-12 15:06:51 +08:00
f2e203d5e2 print debug messages to stderr instead of stdout 2026-01-09 20:05:52 +02:00
6ecc00a5d3 Merge pull request #139 from lrjerryli/main (Add torchcodec to dependencies) 2026-01-04 16:08:09 +08:00
8cfd9d155a Add torchcodec to dependencies (ImportError: TorchCodec is required for load_with_torchcodec.) 2026-01-02 21:00:23 +08:00
18 changed files with 965 additions and 112 deletions

.github/workflows/ci-cd.yaml (new file, vendored, 33 lines)

@ -0,0 +1,33 @@
name: Build container
env:
VERSION: 0.0.3
REGISTRY: https://harbor.bwgdi.com
REGISTRY_NAME: harbor.bwgdi.com
REGISTRY_PATH: library
DOCKER_NAME: voxcpmtts
on:
push:
branches:
- main
workflow_dispatch:
jobs:
build-docker:
runs-on: builder-ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
username: ${{ secrets.BWGDI_NAME }}
password: ${{ secrets.BWGDI_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
- name: Build and push
uses: docker/build-push-action@v4
with:
context: .
file: ./Dockerfile
push: true
tags: ${{ env.REGISTRY_NAME }}/${{ env.REGISTRY_PATH }}/${{ env.DOCKER_NAME }}:${{ env.VERSION }}

.gitignore (vendored, 40 lines changed)

@ -1,4 +1,42 @@
launch.json
__pycache__
voxcpm.egg-info
.DS_Store
# Python-generated files
__pycache__/
*.py[cod]
*$py.class
# Distribution / packaging
build/
dist/
wheels/
*.egg-info/
# Unit test / coverage reports
.pytest_cache/
.coverage
htmlcov/
coverage.xml
# Logs
*.log
log/*.log
# Virtual environments
.venv/
venv/
env/
# IDE settings
.vscode/
.idea/
# OS generated files
.DS_Store
# Generated files
*.wav
*.pdf
*.lock

Dockerfile (new file, 17 lines)

@ -0,0 +1,17 @@
FROM python:3.10.12-slim
RUN apt-get update && apt-get install -y \
build-essential \
ffmpeg \
&& rm -rf /var/lib/apt/lists/*
# Create app directory
WORKDIR /app
COPY api_concurrent.py requirements.txt ./
RUN pip install -r requirements.txt
ENV VOXCPM_MODEL_ID="/models/VoxCPM1.5/" \
VOXCPM_CPU_WORKERS="2" \
VOXCPM_UVICORN_WORKERS="1" \
MAX_GPU_CONCURRENT="1"
EXPOSE 5000
CMD [ "python", "./api_concurrent.py" ]
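The ENV defaults baked into the image are read back by the application at startup. A minimal sketch of that lookup (the helper name is illustrative; the variable names and defaults come from the Dockerfile above):

```python
import os

def read_worker_config(env=os.environ):
    """Read the container tuning knobs, falling back to the Dockerfile defaults."""
    return {
        "model_id": env.get("VOXCPM_MODEL_ID", "/models/VoxCPM1.5/"),
        "cpu_workers": int(env.get("VOXCPM_CPU_WORKERS", "2")),
        "uvicorn_workers": int(env.get("VOXCPM_UVICORN_WORKERS", "1")),
        "max_gpu_concurrent": int(env.get("MAX_GPU_CONCURRENT", "1")),
    }

# Passing an explicit dict makes the fallback behaviour easy to check.
defaults = read_worker_config({})
```

Overriding any of these at `docker run` time (see the README below) changes the corresponding value without rebuilding the image.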

README_BW.md (new file, 83 lines)

@ -0,0 +1,83 @@
# VoxCPM-TTS
https://github.com/BoardWare-Genius/VoxCPM
## 📦 VoxCPM-TTS Version History
| Version | Date | Summary |
|---------|------------|---------------------------------|
| 0.0.3 | 2026-02-09 | Optimized configuration & Model support |
| 0.0.2 | 2026-01-21 | Supports streaming |
| 0.0.1 | 2026-01-20 | Initial version |
### 🔄 Version Details
#### 🆕 0.0.3 *2026-02-09*
- **Configuration & Deployment**
- Supports configuring model path via `VOXCPM_MODEL_ID`
- Supports configuring CPU workers via `VOXCPM_CPU_WORKERS`
- Supports configuring Uvicorn workers via `VOXCPM_UVICORN_WORKERS`
#### 🆕 0.0.2 *2026-01-21*
- **Core Features**
- Updated model weights to VoxCPM1.5 and adjusted model parameters
- Supports streaming
#### 🆕 0.0.1 *2026-01-20*
- **Core Features**
- Initial VoxCPM-TTS
---
# Quick Start
```bash
docker pull harbor.bwgdi.com/library/voxcpmtts:0.0.3
# Run with custom configuration
# -e VOXCPM_MODEL_ID: Path to the model directory inside container
# -e VOXCPM_CPU_WORKERS: Number of threads for CPU-bound tasks
# -e VOXCPM_UVICORN_WORKERS: Number of uvicorn workers
# -e MAX_GPU_CONCURRENT: Max concurrent GPU tasks
docker run -d --restart always -p 5001:5000 --gpus '"device=0"' \
-e VOXCPM_MODEL_ID="/models/VoxCPM1.5/" \
-e VOXCPM_CPU_WORKERS="2" \
-e VOXCPM_UVICORN_WORKERS="1" \
-e MAX_GPU_CONCURRENT="1" \
--mount type=bind,source=/Workspace/NAS11/model/Voice/VoxCPM,target=/models \
harbor.bwgdi.com/library/voxcpmtts:0.0.3
```
# Usage
## Non-streaming
```bash
curl --location 'http://localhost:5001/generate_tts' \
--form 'text="你好,这是一段测试文本"' \
--form 'prompt_text="这是提示文本"' \
--form 'cfg_value="2.0"' \
--form 'inference_timesteps="10"' \
--form 'do_normalize="true"' \
--form 'denoise="true"' \
--form 'retry_badcase="true"' \
--form 'retry_badcase_max_times="3"' \
--form 'retry_badcase_ratio_threshold="6.0"' \
--form 'prompt_wav=@"/assets/2food16k_2.wav"'
```
## Streaming
```bash
curl --location 'http://localhost:5001/generate_tts_streaming' \
--form 'text="你好,这是一段测试文本"' \
--form 'prompt_text="这是提示文本"' \
--form 'cfg_value="2.0"' \
--form 'inference_timesteps="10"' \
--form 'do_normalize="true"' \
--form 'denoise="true"' \
--form 'retry_badcase="true"' \
--form 'retry_badcase_max_times="3"' \
--form 'retry_badcase_ratio_threshold="6.0"' \
--form 'prompt_wav=@"/Workspace/NAS11/model/Voice/assets/2food16k_2.wav"'
```
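The streaming endpoint sends a WAV header with placeholder size fields before the PCM chunks (see `api_concurrent.py` in this diff), so a client that wants a well-formed file should patch the RIFF and `data` chunk sizes after the stream ends. A minimal sketch, assuming the standard 44-byte PCM header layout (the fix-up helper is illustrative, not part of this repo):

```python
import io
import struct
import wave

def fixup_streamed_wav(raw: bytes) -> bytes:
    """Patch RIFF/data chunk sizes in a WAV whose header was emitted before
    the data length was known. Assumes a standard 44-byte PCM header."""
    body_len = len(raw) - 44
    out = bytearray(raw)
    out[4:8] = struct.pack("<I", 36 + body_len)   # RIFF chunk size
    out[40:44] = struct.pack("<I", body_len)      # data chunk size
    return bytes(out)

# Simulate what the server emits: a frameless header, then raw 16-bit PCM.
buf = io.BytesIO()
with wave.open(buf, "wb") as wf:
    wf.setnchannels(1)
    wf.setsampwidth(2)       # 16-bit
    wf.setframerate(16000)
header = buf.getvalue()      # 44 bytes, sizes reflect zero frames
pcm = b"\x00\x00" * 16000    # one second of silence
fixed = fixup_streamed_wav(header + pcm)
```

After the fix-up, standard WAV readers report the true duration instead of zero frames.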

api_concurrent.py (new file, 272 lines)

@ -0,0 +1,272 @@
import os
import torch
import tempfile
import soundfile as sf
import numpy as np
import wave
from io import BytesIO
import asyncio
from typing import Optional
from fastapi import FastAPI, Form, UploadFile, BackgroundTasks
from fastapi.responses import StreamingResponse, JSONResponse
from concurrent.futures import ThreadPoolExecutor
import voxcpm
import uvicorn
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if os.environ.get("HF_REPO_ID", "").strip() == "":
os.environ["HF_REPO_ID"] = "/models/VoxCPM1.5/"
# ========== Model class ==========
class VoxCPMDemo:
def __init__(self) -> None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🚀 Running on device: {self.device}")
self.voxcpm_model: Optional[voxcpm.VoxCPM] = None
self.default_local_model_dir = os.environ.get("VOXCPM_MODEL_ID", "/models/VoxCPM1.5/")
def _resolve_model_dir(self) -> str:
if os.path.isdir(self.default_local_model_dir):
return self.default_local_model_dir
return "models"
def get_or_load_voxcpm(self) -> voxcpm.VoxCPM:
if self.voxcpm_model is None:
print("🔄 Loading VoxCPM model...")
model_dir = self._resolve_model_dir()
self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir)
print("✅ VoxCPM model loaded.")
return self.voxcpm_model
def tts_generate(
self,
text: str,
prompt_wav_path: Optional[str] = None,
prompt_text: Optional[str] = None,
cfg_value: float = 2.0,
inference_timesteps: int = 10,
do_normalize: bool = True,
denoise: bool = True,
retry_badcase: bool = True,
retry_badcase_max_times: int = 3,
retry_badcase_ratio_threshold: float = 6.0,
) -> str:
model = self.get_or_load_voxcpm()
wav = model.generate(
text=text,
prompt_text=prompt_text,
prompt_wav_path=prompt_wav_path,
cfg_value=float(cfg_value),
inference_timesteps=int(inference_timesteps),
normalize=do_normalize,
denoise=denoise,
retry_badcase=retry_badcase,
retry_badcase_max_times=retry_badcase_max_times,
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
)
tmp_wav = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
sf.write(tmp_wav.name, wav, model.tts_model.sample_rate)
torch.cuda.empty_cache()
return tmp_wav.name
def tts_generate_streaming(
self,
text: str,
prompt_wav_path: Optional[str] = None,
prompt_text: Optional[str] = None,
cfg_value: float = 2.0,
inference_timesteps: int = 10,
do_normalize: bool = True,
denoise: bool = True,
retry_badcase: bool = True,
retry_badcase_max_times: int = 3,
retry_badcase_ratio_threshold: float = 6.0,
):
"""Generates audio and yields it as a stream of WAV chunks."""
model = self.get_or_load_voxcpm()
# 1. Yield a WAV header first.
# The size fields will be 0, which is standard for streaming.
SAMPLE_RATE = model.tts_model.sample_rate
CHANNELS = 1
SAMPLE_WIDTH = 2 # 16-bit
header_buf = BytesIO()
with wave.open(header_buf, "wb") as wf:
wf.setnchannels(CHANNELS)
wf.setsampwidth(SAMPLE_WIDTH)
wf.setframerate(SAMPLE_RATE)
yield header_buf.getvalue()
# 2. Generate and yield audio chunks.
# NOTE: We assume a `generate_streaming` method exists on the model that yields audio chunks.
# You may need to change `generate_streaming` to the actual method name in your version of voxcpm.
try:
stream = model.generate_streaming(
text=text,
prompt_text=prompt_text,
prompt_wav_path=prompt_wav_path,
cfg_value=float(cfg_value),
inference_timesteps=int(inference_timesteps),
normalize=do_normalize,
denoise=denoise,
retry_badcase=retry_badcase,
retry_badcase_max_times=retry_badcase_max_times,
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
)
for chunk_np in stream: # Assuming it yields numpy arrays
# Ensure audio is in 16-bit PCM format for streaming
if chunk_np.dtype in [np.float32, np.float64]:
chunk_np = (chunk_np * 32767).astype(np.int16)
yield chunk_np.tobytes()
finally:
torch.cuda.empty_cache()
# ========== FastAPI ==========
app = FastAPI(title="VoxCPM API", version="1.0.0")
demo = VoxCPMDemo()
# --- Concurrency Control ---
# Use a semaphore to limit concurrent GPU tasks.
MAX_GPU_CONCURRENT = int(os.environ.get("MAX_GPU_CONCURRENT", "1"))
gpu_semaphore = asyncio.Semaphore(MAX_GPU_CONCURRENT)
# Use a thread pool for running blocking (CPU/GPU-bound) code.
MAX_CPU_WORKERS = int(os.environ.get("VOXCPM_CPU_WORKERS", "2"))
executor = ThreadPoolExecutor(max_workers=MAX_CPU_WORKERS)
@app.on_event("shutdown")
def shutdown_event():
print("Shutting down thread pool executor...")
executor.shutdown(wait=True)
# ---------- TTS API ----------
@app.post("/generate_tts")
async def generate_tts(
background_tasks: BackgroundTasks,
text: str = Form(...),
prompt_text: Optional[str] = Form(None),
cfg_value: float = Form(2.0),
inference_timesteps: int = Form(10),
do_normalize: bool = Form(True),
denoise: bool = Form(True),
retry_badcase: bool = Form(True),
retry_badcase_max_times: int = Form(3),
retry_badcase_ratio_threshold: float = Form(6.0),
prompt_wav: Optional[UploadFile] = None,
):
try:
prompt_path = None
if prompt_wav:
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
tmp.write(await prompt_wav.read())
prompt_path = tmp.name
background_tasks.add_task(os.remove, tmp.name)
# Submit to GPU via semaphore and executor
loop = asyncio.get_running_loop()
async with gpu_semaphore:
output_path = await loop.run_in_executor(
executor,
demo.tts_generate,
text,
prompt_path,
prompt_text,
cfg_value,
inference_timesteps,
do_normalize,
denoise,
retry_badcase,
retry_badcase_max_times,
retry_badcase_ratio_threshold,
)
# Delete the generated file in the background
background_tasks.add_task(os.remove, output_path)
return StreamingResponse(
open(output_path, "rb"),
media_type="audio/wav",
headers={"Content-Disposition": 'attachment; filename="output.wav"'}
)
except Exception as e:
return JSONResponse({"error": str(e)}, status_code=500)
@app.post("/generate_tts_streaming")
async def generate_tts_streaming(
background_tasks: BackgroundTasks,
text: str = Form(...),
prompt_text: Optional[str] = Form(None),
cfg_value: float = Form(2.0),
inference_timesteps: int = Form(10),
do_normalize: bool = Form(True),
denoise: bool = Form(True),
retry_badcase: bool = Form(True),
retry_badcase_max_times: int = Form(3),
retry_badcase_ratio_threshold: float = Form(6.0),
prompt_wav: Optional[UploadFile] = None,
):
try:
prompt_path = None
if prompt_wav:
# Save uploaded file to a temporary file
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
tmp.write(await prompt_wav.read())
prompt_path = tmp.name
# Ensure the temp file is deleted after the request is finished
background_tasks.add_task(os.remove, prompt_path)
async def stream_generator():
# This async generator consumes from a queue populated by a sync generator in a thread.
q = asyncio.Queue()
loop = asyncio.get_running_loop()
def producer():
# This runs in the executor thread and produces chunks.
try:
# This is a sync generator
for chunk in demo.tts_generate_streaming(
text, prompt_path, prompt_text, cfg_value,
inference_timesteps, do_normalize, denoise,
retry_badcase, retry_badcase_max_times, retry_badcase_ratio_threshold
):
loop.call_soon_threadsafe(q.put_nowait, chunk)
except Exception as e:
# Put the exception in the queue to be re-raised in the consumer
loop.call_soon_threadsafe(q.put_nowait, e)
finally:
# Signal the end of the stream
loop.call_soon_threadsafe(q.put_nowait, None)
# Acquire the GPU semaphore before starting the producer thread.
async with gpu_semaphore:
loop.run_in_executor(executor, producer)
while True:
chunk = await q.get()
if chunk is None:
break
if isinstance(chunk, Exception):
raise chunk
yield chunk
return StreamingResponse(stream_generator(), media_type="audio/wav")
except Exception as e:
return JSONResponse({"error": str(e)}, status_code=500)
@app.get("/")
async def root():
return {"message": "VoxCPM API running 🚀", "endpoints": ["/generate_tts", "/generate_tts_streaming"]}
if __name__ == "__main__":
uvicorn_workers = int(os.environ.get("VOXCPM_UVICORN_WORKERS", "1"))
uvicorn.run("api_concurrent:app", host="0.0.0.0", port=5000, workers=uvicorn_workers)
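The `stream_generator` above bridges a blocking synchronous generator running in a thread pool to an async consumer through an `asyncio.Queue`. The same pattern in isolation (all names here are illustrative stand-ins, not repo code):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

def sync_chunks():
    # Stand-in for a blocking generator such as model.generate_streaming(...).
    for i in range(3):
        yield f"chunk-{i}".encode()

async def collect_stream():
    q: asyncio.Queue = asyncio.Queue()
    loop = asyncio.get_running_loop()

    def producer():
        # Runs in a worker thread; hands each chunk to the event loop thread.
        try:
            for chunk in sync_chunks():
                loop.call_soon_threadsafe(q.put_nowait, chunk)
        except Exception as e:
            # Ship the exception across so the consumer can re-raise it.
            loop.call_soon_threadsafe(q.put_nowait, e)
        finally:
            loop.call_soon_threadsafe(q.put_nowait, None)  # end-of-stream marker

    with ThreadPoolExecutor(max_workers=1) as executor:
        loop.run_in_executor(executor, producer)
        out = []
        while True:
            item = await q.get()
            if item is None:
                break
            if isinstance(item, Exception):
                raise item
            out.append(item)
        return out

chunks = asyncio.run(collect_stream())
```

`call_soon_threadsafe` is the only safe way to touch the queue from the worker thread; the `None` sentinel and the re-raised exception mirror how the endpoint terminates or fails the HTTP stream.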

app.py (15 lines changed)

@ -1,4 +1,5 @@
import os
import sys
import numpy as np
import torch
import gradio as gr
@ -16,7 +17,7 @@ import voxcpm
class VoxCPMDemo:
def __init__(self) -> None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🚀 Running on device: {self.device}")
print(f"🚀 Running on device: {self.device}", file=sys.stderr)
# ASR model for prompt text recognition
self.asr_model_id = "iic/SenseVoiceSmall"
@ -49,10 +50,10 @@ class VoxCPMDemo:
try:
from huggingface_hub import snapshot_download # type: ignore
os.makedirs(target_dir, exist_ok=True)
print(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...")
print(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...", file=sys.stderr)
snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False)
except Exception as e:
print(f"Warning: HF download failed: {e}. Falling back to 'data'.")
print(f"Warning: HF download failed: {e}. Falling back to 'data'.", file=sys.stderr)
return "models"
return target_dir
return "models"
@ -60,11 +61,11 @@ class VoxCPMDemo:
def get_or_load_voxcpm(self) -> voxcpm.VoxCPM:
if self.voxcpm_model is not None:
return self.voxcpm_model
print("Model not loaded, initializing...")
print("Model not loaded, initializing...", file=sys.stderr)
model_dir = self._resolve_model_dir()
print(f"Using model dir: {model_dir}")
print(f"Using model dir: {model_dir}", file=sys.stderr)
self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir)
print("Model loaded successfully.")
print("Model loaded successfully.", file=sys.stderr)
return self.voxcpm_model
# ---------- Functional endpoints ----------
@ -98,7 +99,7 @@ class VoxCPMDemo:
prompt_wav_path = prompt_wav_path_input if prompt_wav_path_input else None
prompt_text = prompt_text_input if prompt_text_input else None
print(f"Generating audio for text: '{text[:60]}...'")
print(f"Generating audio for text: '{text[:60]}...'", file=sys.stderr)
wav = current_model.generate(
text=text,
prompt_text=prompt_text,

(unnamed file)

@ -104,7 +104,7 @@ def get_timestamp_str():
def get_or_load_asr_model():
global asr_model
if asr_model is None:
print("Loading ASR model (SenseVoiceSmall)...")
print("Loading ASR model (SenseVoiceSmall)...", file=sys.stderr)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
asr_model = AutoModel(
model="iic/SenseVoiceSmall",
@ -123,7 +123,7 @@ def recognize_audio(audio_path):
text = res[0]["text"].split('|>')[-1]
return text
except Exception as e:
print(f"ASR Error: {e}")
print(f"ASR Error: {e}", file=sys.stderr)
return ""
def scan_lora_checkpoints(root_dir="lora", with_info=False):
@ -181,7 +181,7 @@ def load_lora_config_from_checkpoint(lora_path):
if lora_cfg_dict:
return LoRAConfig(**lora_cfg_dict), lora_info.get("base_model")
except Exception as e:
print(f"Warning: Failed to load lora_config.json: {e}")
print(f"Warning: Failed to load lora_config.json: {e}", file=sys.stderr)
return None, None
def get_default_lora_config():
@ -197,7 +197,7 @@ def get_default_lora_config():
def load_model(pretrained_path, lora_path=None):
global current_model
print(f"Loading model from {pretrained_path}...")
print(f"Loading model from {pretrained_path}...", file=sys.stderr)
lora_config = None
lora_weights_path = None
@ -209,11 +209,11 @@ def load_model(pretrained_path, lora_path=None):
# Try to load LoRA config from lora_config.json
lora_config, _ = load_lora_config_from_checkpoint(full_lora_path)
if lora_config:
print(f"Loaded LoRA config from {full_lora_path}/lora_config.json")
print(f"Loaded LoRA config from {full_lora_path}/lora_config.json", file=sys.stderr)
else:
# Fallback to default config for old checkpoints
lora_config = get_default_lora_config()
print("Using default LoRA config (lora_config.json not found)")
print("Using default LoRA config (lora_config.json not found)", file=sys.stderr)
# Always init with a default LoRA config to allow hot-swapping later
if lora_config is None:
@ -251,36 +251,36 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
# Prefer the saved base_model path
if os.path.exists(saved_base_model):
base_model_path = saved_base_model
print(f"Using base model from LoRA config: {base_model_path}")
print(f"Using base model from LoRA config: {base_model_path}", file=sys.stderr)
else:
print(f"Warning: Saved base_model path not found: {saved_base_model}")
print(f"Falling back to default: {base_model_path}")
print(f"Warning: Saved base_model path not found: {saved_base_model}", file=sys.stderr)
print(f"Falling back to default: {base_model_path}", file=sys.stderr)
except Exception as e:
print(f"Warning: Failed to read base_model from LoRA config: {e}")
print(f"Warning: Failed to read base_model from LoRA config: {e}", file=sys.stderr)
# Load the model
try:
print(f"Loading base model: {base_model_path}")
print(f"Loading base model: {base_model_path}", file=sys.stderr)
status_msg = load_model(base_model_path)
if lora_selection and lora_selection != "None":
print(f"Model loaded for LoRA: {lora_selection}")
print(f"Model loaded for LoRA: {lora_selection}", file=sys.stderr)
except Exception as e:
error_msg = f"Failed to load model from {base_model_path}: {str(e)}"
print(error_msg)
print(error_msg, file=sys.stderr)
return None, error_msg
# Handle LoRA hot-swapping
if lora_selection and lora_selection != "None":
full_lora_path = os.path.join("lora", lora_selection)
print(f"Hot-loading LoRA: {full_lora_path}")
print(f"Hot-loading LoRA: {full_lora_path}", file=sys.stderr)
try:
current_model.load_lora(full_lora_path)
current_model.set_lora_enabled(True)
except Exception as e:
print(f"Error loading LoRA: {e}")
print(f"Error loading LoRA: {e}", file=sys.stderr)
return None, f"Error loading LoRA: {e}"
else:
print("Disabling LoRA")
print("Disabling LoRA", file=sys.stderr)
current_model.set_lora_enabled(False)
if seed != -1:
@ -297,11 +297,11 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
# If no reference text was provided, try automatic recognition
if not prompt_text or not prompt_text.strip():
print("参考音频已提供但缺少文本,自动识别中...")
print("参考音频已提供但缺少文本,自动识别中...", file=sys.stderr)
try:
final_prompt_text = recognize_audio(prompt_wav)
if final_prompt_text:
print(f"自动识别文本: {final_prompt_text}")
print(f"自动识别文本: {final_prompt_text}", file=sys.stderr)
else:
return None, "错误:无法识别参考音频内容,请手动填写参考文本"
except Exception as e:
@ -1114,12 +1114,12 @@ with gr.Blocks(
choices = ["None"] + [ckpt[0] for ckpt in checkpoints_with_info]
# Print debug info
print(f"刷新 LoRA 列表: 找到 {len(checkpoints_with_info)} 个检查点")
print(f"刷新 LoRA 列表: 找到 {len(checkpoints_with_info)} 个检查点", file=sys.stderr)
for ckpt_path, base_model in checkpoints_with_info:
if base_model:
print(f" - {ckpt_path} (Base Model: {base_model})")
print(f" - {ckpt_path} (Base Model: {base_model})", file=sys.stderr)
else:
print(f" - {ckpt_path}")
print(f" - {ckpt_path}", file=sys.stderr)
return gr.update(choices=choices, value="None")

pyproject.toml

@ -27,6 +27,7 @@ requires-python = ">=3.10"
dependencies = [
"torch>=2.5.0",
"torchaudio>=2.5.0",
"torchcodec",
"transformers>=4.36.2",
"einops",
"gradio<6",
@ -46,8 +47,7 @@ dependencies = [
"funasr",
"spaces",
"argbind",
"safetensors"
"safetensors",
]
[project.optional-dependencies]

requirements.txt (new file, 5 lines)

@ -0,0 +1,5 @@
soundfile
fastapi
uvicorn
voxcpm
torchcodec

(unnamed file)

@ -23,6 +23,7 @@ With voice cloning:
"""
import argparse
import sys
from pathlib import Path
import soundfile as sf
@ -92,7 +93,7 @@ def main():
args = parse_args()
# Load model from checkpoint directory (no denoiser)
print(f"[FT Inference] Loading model: {args.ckpt_dir}")
print(f"[FT Inference] Loading model: {args.ckpt_dir}", file=sys.stderr)
model = VoxCPM.from_pretrained(
hf_model_id=args.ckpt_dir,
load_denoiser=False,
@ -103,10 +104,10 @@ def main():
prompt_wav_path = args.prompt_audio if args.prompt_audio else None
prompt_text = args.prompt_text if args.prompt_text else None
print(f"[FT Inference] Synthesizing: text='{args.text}'")
print(f"[FT Inference] Synthesizing: text='{args.text}'", file=sys.stderr)
if prompt_wav_path:
print(f"[FT Inference] Using reference audio: {prompt_wav_path}")
print(f"[FT Inference] Reference text: {prompt_text}")
print(f"[FT Inference] Using reference audio: {prompt_wav_path}", file=sys.stderr)
print(f"[FT Inference] Reference text: {prompt_text}", file=sys.stderr)
audio_np = model.generate(
text=args.text,
@ -124,7 +125,7 @@ def main():
out_path.parent.mkdir(parents=True, exist_ok=True)
sf.write(str(out_path), audio_np, model.tts_model.sample_rate)
print(f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
print(f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
if __name__ == "__main__":

(unnamed file)

@ -24,6 +24,7 @@ Note: The script reads base_model path and lora_config from lora_config.json
import argparse
import json
import sys
from pathlib import Path
import soundfile as sf
@ -124,13 +125,13 @@ def main():
lora_cfg_dict = lora_info.get("lora_config", {})
lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None
print(f"Loaded config from: {lora_config_path}")
print(f" Base model: {pretrained_path}")
print(f" LoRA config: r={lora_cfg.r}, alpha={lora_cfg.alpha}" if lora_cfg else " LoRA config: None")
print(f"Loaded config from: {lora_config_path}", file=sys.stderr)
print(f" Base model: {pretrained_path}", file=sys.stderr)
print(f" LoRA config: r={lora_cfg.r}, alpha={lora_cfg.alpha}" if lora_cfg else " LoRA config: None", file=sys.stderr)
# 3. Load model with LoRA (no denoiser)
print(f"\n[1/2] Loading model with LoRA: {pretrained_path}")
print(f" LoRA weights: {ckpt_dir}")
print(f"\n[1/2] Loading model with LoRA: {pretrained_path}", file=sys.stderr)
print(f" LoRA weights: {ckpt_dir}", file=sys.stderr)
model = VoxCPM.from_pretrained(
hf_model_id=pretrained_path,
load_denoiser=False,
@ -145,10 +146,10 @@ def main():
out_path = Path(args.output)
out_path.parent.mkdir(parents=True, exist_ok=True)
print(f"\n[2/2] Starting synthesis tests...")
print(f"\n[2/2] Starting synthesis tests...", file=sys.stderr)
# === Test 1: With LoRA ===
print(f"\n [Test 1] Synthesize with LoRA...")
print(f"\n [Test 1] Synthesize with LoRA...", file=sys.stderr)
audio_np = model.generate(
text=args.text,
prompt_wav_path=prompt_wav_path,
@ -161,10 +162,10 @@ def main():
)
lora_output = out_path.with_stem(out_path.stem + "_with_lora")
sf.write(str(lora_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
print(f" Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
# === Test 2: Disable LoRA (via set_lora_enabled) ===
print(f"\n [Test 2] Disable LoRA (set_lora_enabled=False)...")
print(f"\n [Test 2] Disable LoRA (set_lora_enabled=False)...", file=sys.stderr)
model.set_lora_enabled(False)
audio_np = model.generate(
text=args.text,
@ -178,10 +179,10 @@ def main():
)
disabled_output = out_path.with_stem(out_path.stem + "_lora_disabled")
sf.write(str(disabled_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
print(f" Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
# === Test 3: Re-enable LoRA ===
print(f"\n [Test 3] Re-enable LoRA (set_lora_enabled=True)...")
print(f"\n [Test 3] Re-enable LoRA (set_lora_enabled=True)...", file=sys.stderr)
model.set_lora_enabled(True)
audio_np = model.generate(
text=args.text,
@ -195,10 +196,10 @@ def main():
)
reenabled_output = out_path.with_stem(out_path.stem + "_lora_reenabled")
sf.write(str(reenabled_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
print(f" Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
# === Test 4: Unload LoRA (reset_lora_weights) ===
print(f"\n [Test 4] Unload LoRA (unload_lora)...")
print(f"\n [Test 4] Unload LoRA (unload_lora)...", file=sys.stderr)
model.unload_lora()
audio_np = model.generate(
text=args.text,
@ -212,12 +213,12 @@ def main():
)
reset_output = out_path.with_stem(out_path.stem + "_lora_reset")
sf.write(str(reset_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
print(f" Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
# === Test 5: Hot-reload LoRA (load_lora) ===
print(f"\n [Test 5] Hot-reload LoRA (load_lora)...")
print(f"\n [Test 5] Hot-reload LoRA (load_lora)...", file=sys.stderr)
loaded, skipped = model.load_lora(ckpt_dir)
print(f" Reloaded {len(loaded)} parameters")
print(f" Reloaded {len(loaded)} parameters", file=sys.stderr)
audio_np = model.generate(
text=args.text,
prompt_wav_path=prompt_wav_path,
@ -230,14 +231,14 @@ def main():
)
reload_output = out_path.with_stem(out_path.stem + "_lora_reloaded")
sf.write(str(reload_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
print(f" Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
print(f"\n[Done] All tests completed!")
print(f" - with_lora: {lora_output}")
print(f" - lora_disabled: {disabled_output}")
print(f" - lora_reenabled: {reenabled_output}")
print(f" - lora_reset: {reset_output}")
print(f" - lora_reloaded: {reload_output}")
print(f"\n[Done] All tests completed!", file=sys.stderr)
print(f" - with_lora: {lora_output}", file=sys.stderr)
print(f" - lora_disabled: {disabled_output}", file=sys.stderr)
print(f" - lora_reenabled: {reenabled_output}", file=sys.stderr)
print(f" - lora_reset: {reset_output}", file=sys.stderr)
print(f" - lora_reloaded: {reload_output}", file=sys.stderr)
if __name__ == "__main__":

(unnamed file)

@ -24,7 +24,7 @@ try:
SAFETENSORS_AVAILABLE = True
except ImportError:
SAFETENSORS_AVAILABLE = False
print("Warning: safetensors not available, will use pytorch format")
print("Warning: safetensors not available, will use pytorch format", file=sys.stderr)
from voxcpm.model import VoxCPMModel
from voxcpm.model.voxcpm import LoRAConfig
@ -170,7 +170,7 @@ def train(
# Only print param info on rank 0 to avoid cluttered output
if accelerator.rank == 0:
for name, param in model.named_parameters():
print(name, param.requires_grad)
print(name, param.requires_grad, file=sys.stderr)
optimizer = AdamW(
(p for p in model.parameters() if p.requires_grad),
@ -210,12 +210,12 @@ def train(
cur_step = int(_resume.get("step", start_step))
except Exception:
cur_step = start_step
print(f"Signal {signum} received. Saving checkpoint at step {cur_step} ...")
print(f"Signal {signum} received. Saving checkpoint at step {cur_step} ...", file=sys.stderr)
try:
save_checkpoint(_model, _optim, _sched, _save_dir, cur_step, _pretrained, _hf_id, _dist)
print("Checkpoint saved. Exiting.")
print("Checkpoint saved. Exiting.", file=sys.stderr)
except Exception as e:
print(f"Error saving checkpoint on signal: {e}")
print(f"Error saving checkpoint on signal: {e}", file=sys.stderr)
os._exit(0)
signal.signal(signal.SIGTERM, _signal_handler)
@ -553,7 +553,7 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
# Load only lora weights
unwrapped.load_state_dict(state_dict, strict=False)
print(f"Loaded LoRA weights from {lora_weights_path}")
print(f"Loaded LoRA weights from {lora_weights_path}", file=sys.stderr)
else:
# Full finetune: load model.safetensors or pytorch_model.bin
model_path = latest_folder / "model.safetensors"
@ -569,26 +569,26 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
state_dict = ckpt.get("state_dict", ckpt)
unwrapped.load_state_dict(state_dict, strict=False)
print(f"Loaded model weights from {model_path}")
print(f"Loaded model weights from {model_path}", file=sys.stderr)
# Load optimizer state
optimizer_path = latest_folder / "optimizer.pth"
if optimizer_path.exists():
optimizer.load_state_dict(torch.load(optimizer_path, map_location="cpu"))
print(f"Loaded optimizer state from {optimizer_path}")
print(f"Loaded optimizer state from {optimizer_path}", file=sys.stderr)
# Load scheduler state
scheduler_path = latest_folder / "scheduler.pth"
if scheduler_path.exists():
scheduler.load_state_dict(torch.load(scheduler_path, map_location="cpu"))
print(f"Loaded scheduler state from {scheduler_path}")
print(f"Loaded scheduler state from {scheduler_path}", file=sys.stderr)
# Try to infer step from checkpoint folders
step_folders = [d for d in save_dir.iterdir() if d.is_dir() and d.name.startswith("step_")]
if step_folders:
steps = [int(d.name.split("_")[1]) for d in step_folders]
resume_step = max(steps)
print(f"Resuming from step {resume_step}")
print(f"Resuming from step {resume_step}", file=sys.stderr)
return resume_step
return 0
@ -670,7 +670,7 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
latest_link.unlink()
shutil.copytree(folder, latest_link)
except Exception:
print(f"Warning: failed to update latest checkpoint link at {latest_link}")
print(f"Warning: failed to update latest checkpoint link at {latest_link}", file=sys.stderr)
if __name__ == "__main__":

(unnamed file)

@ -45,7 +45,7 @@ def load_model(args) -> VoxCPM:
Prefer --model-path if provided; otherwise use from_pretrained (Hub).
"""
print("Loading VoxCPM model...")
print("Loading VoxCPM model...", file=sys.stderr)
# Backward compatibility: use the ZIPENHANCER_MODEL_PATH environment variable as the default
zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get(
@ -66,7 +66,7 @@ def load_model(args) -> VoxCPM:
dropout=getattr(args, "lora_dropout", 0.0),
)
print(f"LoRA config: r={lora_config.r}, alpha={lora_config.alpha}, "
f"lm={lora_config.enable_lm}, dit={lora_config.enable_dit}, proj={lora_config.enable_proj}")
f"lm={lora_config.enable_lm}, dit={lora_config.enable_dit}, proj={lora_config.enable_proj}", file=sys.stderr)
# Load from local path if provided
if getattr(args, "model_path", None):
@ -78,10 +78,10 @@ def load_model(args) -> VoxCPM:
lora_config=lora_config,
lora_weights_path=lora_weights_path,
)
print("Model loaded (local).")
print("Model loaded (local).", file=sys.stderr)
return model
except Exception as e:
print(f"Failed to load model (local): {e}")
print(f"Failed to load model (local): {e}", file=sys.stderr)
sys.exit(1)
# Otherwise, try from_pretrained (Hub); exit on failure
@@ -95,10 +95,10 @@ def load_model(args) -> VoxCPM:
lora_config=lora_config,
lora_weights_path=lora_weights_path,
)
-print("Model loaded (from_pretrained).")
+print("Model loaded (from_pretrained).", file=sys.stderr)
return model
except Exception as e:
-print(f"Failed to load model (from_pretrained): {e}")
+print(f"Failed to load model (from_pretrained): {e}", file=sys.stderr)
sys.exit(1)
@@ -106,15 +106,15 @@ def cmd_clone(args):
"""Voice cloning command."""
# Validate inputs
if not args.text:
-print("Error: Please provide text to synthesize (--text)")
+print("Error: Please provide text to synthesize (--text)", file=sys.stderr)
sys.exit(1)
if not args.prompt_audio:
-print("Error: Voice cloning requires a reference audio (--prompt-audio)")
+print("Error: Voice cloning requires a reference audio (--prompt-audio)", file=sys.stderr)
sys.exit(1)
if not args.prompt_text:
-print("Error: Voice cloning requires a reference text (--prompt-text)")
+print("Error: Voice cloning requires a reference text (--prompt-text)", file=sys.stderr)
sys.exit(1)
# Validate files
@@ -125,9 +125,9 @@ def cmd_clone(args):
model = load_model(args)
# Generate audio
-print(f"Synthesizing text: {args.text}")
-print(f"Reference audio: {prompt_audio_path}")
-print(f"Reference text: {args.prompt_text}")
+print(f"Synthesizing text: {args.text}", file=sys.stderr)
+print(f"Reference audio: {prompt_audio_path}", file=sys.stderr)
+print(f"Reference text: {args.prompt_text}", file=sys.stderr)
audio_array = model.generate(
text=args.text,
@@ -141,25 +141,25 @@ def cmd_clone(args):
# Save audio
sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
-print(f"Saved audio to: {output_path}")
+print(f"Saved audio to: {output_path}", file=sys.stderr)
# Stats
duration = len(audio_array) / model.tts_model.sample_rate
-print(f"Duration: {duration:.2f}s")
+print(f"Duration: {duration:.2f}s", file=sys.stderr)
def cmd_synthesize(args):
"""Direct TTS synthesis command."""
# Validate inputs
if not args.text:
-print("Error: Please provide text to synthesize (--text)")
+print("Error: Please provide text to synthesize (--text)", file=sys.stderr)
sys.exit(1)
# Validate output path
output_path = validate_output_path(args.output)
# Load model
model = load_model(args)
# Generate audio
-print(f"Synthesizing text: {args.text}")
+print(f"Synthesizing text: {args.text}", file=sys.stderr)
audio_array = model.generate(
text=args.text,
@@ -173,11 +173,11 @@ def cmd_synthesize(args):
# Save audio
sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
-print(f"Saved audio to: {output_path}")
+print(f"Saved audio to: {output_path}", file=sys.stderr)
# Stats
duration = len(audio_array) / model.tts_model.sample_rate
-print(f"Duration: {duration:.2f}s")
+print(f"Duration: {duration:.2f}s", file=sys.stderr)
def cmd_batch(args):
@@ -191,12 +191,12 @@ def cmd_batch(args):
with open(input_file, 'r', encoding='utf-8') as f:
texts = [line.strip() for line in f if line.strip()]
except Exception as e:
-print(f"Failed to read input file: {e}")
+print(f"Failed to read input file: {e}", file=sys.stderr)
sys.exit(1)
if not texts:
-print("Error: Input file is empty or contains no valid lines")
+print("Error: Input file is empty or contains no valid lines", file=sys.stderr)
sys.exit(1)
-print(f"Found {len(texts)} lines to process")
+print(f"Found {len(texts)} lines to process", file=sys.stderr)
model = load_model(args)
prompt_audio_path = None
@@ -205,7 +205,7 @@ def cmd_batch(args):
success_count = 0
for i, text in enumerate(texts, 1):
-print(f"\nProcessing {i}/{len(texts)}: {text[:50]}...")
+print(f"\nProcessing {i}/{len(texts)}: {text[:50]}...", file=sys.stderr)
try:
audio_array = model.generate(
@@ -221,14 +221,14 @@ def cmd_batch(args):
sf.write(str(output_file), audio_array, model.tts_model.sample_rate)
duration = len(audio_array) / model.tts_model.sample_rate
-print(f" Saved: {output_file} ({duration:.2f}s)")
+print(f" Saved: {output_file} ({duration:.2f}s)", file=sys.stderr)
success_count += 1
except Exception as e:
-print(f" Failed: {e}")
+print(f" Failed: {e}", file=sys.stderr)
continue
-print(f"\nBatch finished: {success_count}/{len(texts)} succeeded")
+print(f"\nBatch finished: {success_count}/{len(texts)} succeeded", file=sys.stderr)
def _build_unified_parser():
"""Build unified argument parser (no subcommands, route by args)."""
@@ -296,14 +296,14 @@ def main():
# Routing: prefer batch → single (clone/direct)
if args.input:
if not args.output_dir:
-print("Error: Batch mode requires --output-dir")
+print("Error: Batch mode requires --output-dir", file=sys.stderr)
parser.print_help()
sys.exit(1)
return cmd_batch(args)
# Single-sample mode
if not args.text or not args.output:
-print("Error: Single-sample mode requires --text and --output")
+print("Error: Single-sample mode requires --text and --output", file=sys.stderr)
parser.print_help()
sys.exit(1)
@@ -316,7 +316,7 @@ def main():
args.prompt_text = f.read()
if not args.prompt_audio or not args.prompt_text:
-print("Error: Voice cloning requires both --prompt-audio and --prompt-text")
+print("Error: Voice cloning requires both --prompt-audio and --prompt-text", file=sys.stderr)
sys.exit(1)
return cmd_clone(args)

View File

@@ -1,4 +1,5 @@
import os
+import sys
import re
import tempfile
import numpy as np
@@ -30,7 +31,7 @@ class VoxCPM:
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
containing lora_weights.ckpt). If provided, LoRA weights will be loaded.
"""
-print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}")
+print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}", file=sys.stderr)
# If lora_weights_path is provided but no lora_config, create a default one
if lora_weights_path is not None and lora_config is None:
@@ -39,15 +40,15 @@ class VoxCPM:
enable_dit=True,
enable_proj=False,
)
-print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}")
+print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}", file=sys.stderr)
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
# Load LoRA weights if path is provided
if lora_weights_path is not None:
-print(f"Loading LoRA weights from: {lora_weights_path}")
+print(f"Loading LoRA weights from: {lora_weights_path}", file=sys.stderr)
loaded_keys, skipped_keys = self.tts_model.load_lora_weights(lora_weights_path)
-print(f"Loaded {len(loaded_keys)} LoRA parameters, skipped {len(skipped_keys)}")
+print(f"Loaded {len(loaded_keys)} LoRA parameters, skipped {len(skipped_keys)}", file=sys.stderr)
self.text_normalizer = None
if enable_denoiser and zipenhancer_model_path is not None:
@@ -56,7 +57,7 @@ class VoxCPM:
else:
self.denoiser = None
if optimize:
-print("Warm up VoxCPMModel...")
+print("Warm up VoxCPMModel...", file=sys.stderr)
self.tts_model.generate(
target_text="Hello, this is the first test sentence.",
max_len=10,

View File

@@ -19,6 +19,7 @@ limitations under the License.
"""
import os
+import sys
from typing import Tuple, Union, Generator, List, Optional
import torch
@@ -120,7 +121,7 @@ class VoxCPMModel(nn.Module):
self.device = "mps"
else:
self.device = "cpu"
-print(f"Running on device: {self.device}, dtype: {self.config.dtype}")
+print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
# Text-Semantic LM
self.base_lm = MiniCPMModel(config.lm_config)
@@ -228,7 +229,7 @@ class VoxCPMModel(nn.Module):
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
except Exception as e:
-print(f"Warning: torch.compile disabled - {e}")
+print(f"Warning: torch.compile disabled - {e}", file=sys.stderr)
return self
def forward(
@@ -459,7 +460,7 @@ class VoxCPMModel(nn.Module):
latent_pred, pred_audio_feat = next(inference_result)
if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
-print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
+print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...", file=sys.stderr)
retry_badcase_times += 1
continue
else:
@@ -683,7 +684,7 @@ class VoxCPMModel(nn.Module):
latent_pred, pred_audio_feat = next(inference_result)
if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
-print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
+print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...", file=sys.stderr)
retry_badcase_times += 1
continue
else:
@@ -868,10 +869,10 @@ class VoxCPMModel(nn.Module):
pytorch_model_path = os.path.join(path, "pytorch_model.bin")
if os.path.exists(safetensors_path) and SAFETENSORS_AVAILABLE:
-print(f"Loading model from safetensors: {safetensors_path}")
+print(f"Loading model from safetensors: {safetensors_path}", file=sys.stderr)
model_state_dict = load_file(safetensors_path)
elif os.path.exists(pytorch_model_path):
-print(f"Loading model from pytorch_model.bin: {pytorch_model_path}")
+print(f"Loading model from pytorch_model.bin: {pytorch_model_path}", file=sys.stderr)
checkpoint = torch.load(
pytorch_model_path,
map_location="cpu",

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import contextlib
+import sys
import time
from pathlib import Path
from typing import Dict, Optional
@@ -36,7 +37,7 @@ class TrainingTracker:
# ------------------------------------------------------------------ #
def print(self, message: str):
if self.rank == 0:
-print(message, flush=True)
+print(message, flush=True, file=sys.stderr)
if self.log_file:
with self.log_file.open("a", encoding="utf-8") as f:
f.write(message + "\n")

365
test_concurrent.py Normal file
View File

@@ -0,0 +1,365 @@
import asyncio
import aiohttp
import time
import statistics
import csv
from typing import List, Dict
from datetime import datetime
# ---------------- Basic configuration ----------------
BASE_URL = "http://localhost:8880/generate_tts"
REQUEST_TEMPLATE = {
"text": "你这个新买的T-shirt好cool啊是哪个brand的周末我们去新开的mall里那家Starbucks喝杯coffee吧我听说他们的new season限定款Latte很OK。"
}
# REQUEST_TEMPLATE = {
# "text": "澳门在哪里啊",
# "cfg_value": "2.0",
# "inference_timesteps": "10",
# "normalize": "true",
# "denoise": "true",
# "prompt_text": "澳门有乜嘢好食嘅?"
# }
# # Audio path (make sure the file exists)
# PROMPT_WAV_PATH = "/home/verachen/Music/voice/2food.wav"
# headers usually need not be set; multipart sets them automatically
DEFAULT_HEADERS = {}
# ---------------- Request logic ----------------
async def tts_request(session: aiohttp.ClientSession, request_id: int) -> float:
"""Send one request and return the elapsed time in seconds."""
start = time.perf_counter()
try:
form = aiohttp.FormData()
for k, v in REQUEST_TEMPLATE.items():
form.add_field(k, v)
# form.add_field(
# "prompt_wav",
# open(PROMPT_WAV_PATH, "rb"),
# filename="2food.wav",
# content_type="audio/wav"
# )
async with session.post(BASE_URL, data=form, headers=DEFAULT_HEADERS) as resp:
await resp.read()
if resp.status != 200:
raise RuntimeError(f"HTTP {resp.status}")
except Exception as e:
print(f"[request {request_id}] failed: {e}")
return -1
return time.perf_counter() - start
# ---------------- Benchmark core ----------------
async def benchmark(concurrency: int, total_requests: int) -> Dict:
"""Issue total_requests requests at the given concurrency and collect performance stats."""
timings = []
errors = 0
conn = aiohttp.TCPConnector(limit=0, force_close=False)
timeout = aiohttp.ClientTimeout(total=None)
async with aiohttp.ClientSession(connector=conn, timeout=timeout) as session:
sem = asyncio.Semaphore(concurrency)
async def worker(i):
nonlocal errors
async with sem:
t = await tts_request(session, i)
if t > 0:
timings.append(t)
else:
errors += 1
tasks = [asyncio.create_task(worker(i)) for i in range(total_requests)]
await asyncio.gather(*tasks)
succ = len(timings)
total_time = sum(timings)
result = {
"concurrency": concurrency,
"total_requests": total_requests,
"successes": succ,
"errors": errors,
}
if succ > 0:
result.update({
"min": min(timings),
"max": max(timings),
"avg": statistics.mean(timings),
"median": statistics.median(timings),
"qps": succ / total_time if total_time > 0 else 0
})
return result
# ---------------- CSV output ----------------
def write_to_csv(filename: str, data: List[Dict]):
"""Write benchmark results to a CSV file."""
fieldnames = ["concurrency", "total_requests", "successes", "errors",
"min", "max", "avg", "median", "qps"]
with open(filename, "w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
for row in data:
writer.writerow(row)
print(f"\n✅ Results saved to: {filename}")
# ---------------- Main flow ----------------
async def run_tests(concurrency_list: List[int], total_requests: dict):
print("Starting load test against:", BASE_URL)
results = []
for c in concurrency_list:
print(f"\n=== Concurrency: {c} ===")
res = await benchmark(c, total_requests[c])
results.append(res)
print(f"Total requests: {res['total_requests']} | OK: {res['successes']} | Failed: {res['errors']}")
if "avg" in res:
print(f"Latency (s) → min={res['min']:.3f}, avg={res['avg']:.3f}, median={res['median']:.3f}, max={res['max']:.3f}")
print(f"Approx. QPS: {res['qps']:.2f}")
else:
print("All requests failed")
# Save results
timestamp = datetime.now().strftime("%Y%m%d_%H%M")
csv_filename = f"tts_benchmark_{timestamp}.csv"
write_to_csv(csv_filename, results)
if __name__ == "__main__":
# Tune these two parameters to your machine's capacity
concurrency_list = [1, 5, 10, 20, 50, 100, 150, 200]  # concurrency levels to test
total_requests = {
1: 50,
5: 100,
10: 200,
20: 400,
50: 1000,
100: 2000,
150: 3000,
200: 4000,
300: 6000,
500: 10000
} # requests per concurrency level
asyncio.run(run_tests(concurrency_list, total_requests))
# import asyncio
# import aiohttp
# import time
# import statistics
# from typing import List, Dict
# # 新的接口 URL
# BASE_URL = "http://127.0.0.1:8880/generate_tts"
# # Fixed parameters (audio excluded)
# # FORM_PARAMS = {
# # "text": "澳门在哪里啊",
# # "cfg_value": "2.0",
# # "inference_timesteps": "10",
# # "normalize": "true",
# # "denoise": "true",
# # "prompt_text": "澳门有乜嘢好食嘅?"
# # }
# FORM_PARAMS = {
# "text": "澳门在哪里啊"
# }
# # Audio path (make sure the file exists)
# # PROMPT_WAV_PATH = "/home/verachen/Music/voice/2food.wav"
# # headers usually need not be set; multipart sets them automatically
# DEFAULT_HEADERS = {}
# async def tts_request(session: aiohttp.ClientSession) -> float:
# """
# Send one multipart/form-data TTS request and return the elapsed time.
# """
# start = time.perf_counter()
# form = aiohttp.FormData()
# for k, v in FORM_PARAMS.items():
# form.add_field(k, v)
# # form.add_field(
# # "prompt_wav",
# # open(PROMPT_WAV_PATH, "rb"),
# # filename="2food.wav",
# # content_type="audio/wav"
# # )
# async with session.post(BASE_URL, data=form, headers=DEFAULT_HEADERS) as resp:
# data = await resp.read()
# if resp.status != 200:
# raise RuntimeError(f"HTTP {resp.status}, body: {data[:200]!r}")
# elapsed = time.perf_counter() - start
# return elapsed
# async def benchmark(concurrency: int, total_requests: int) -> Dict:
# timings: List[float] = []
# errors = 0
# conn = aiohttp.TCPConnector(limit=0)
# timeout = aiohttp.ClientTimeout(total=None)
# async with aiohttp.ClientSession(connector=conn, timeout=timeout) as session:
# sem = asyncio.Semaphore(concurrency)
# async def worker():
# nonlocal errors
# async with sem:
# try:
# t = await tts_request(session)
# timings.append(t)
# except Exception as e:
# errors += 1
# print("Request failed:", e)
# tasks = [asyncio.create_task(worker()) for _ in range(total_requests)]
# await asyncio.gather(*tasks)
# succ = len(timings)
# total_time = sum(timings) if timings else 0.0
# result = {
# "concurrency": concurrency,
# "total_requests": total_requests,
# "successes": succ,
# "errors": errors,
# }
# if succ > 0:
# result.update({
# "min": min(timings),
# "max": max(timings),
# "avg": statistics.mean(timings),
# "median": statistics.median(timings),
# "qps": succ / total_time if total_time > 0 else None
# })
# return result
# async def run_tests(concurrency_list: List[int], total_requests: int):
# print("Starting concurrency test against:", BASE_URL)
# for c in concurrency_list:
# print(f"\n--- concurrency = {c} ---")
# res = await benchmark(c, total_requests)
# print("Total requests:", res["total_requests"])
# print("OK:", res["successes"], "Failed:", res["errors"])
# if "avg" in res:
# print(f"Latency min / avg / median / max = "
# f"{res['min']:.3f} / {res['avg']:.3f} / {res['median']:.3f} / {res['max']:.3f} s")
# print(f"Approx. QPS = {res['qps']:.2f}")
# else:
# print("All requests failed")
# if __name__ == "__main__":
# # Tune these parameters as needed
# concurrency_list = [200] # concurrency levels to try
# total_requests = 100 # total requests per concurrency level
# asyncio.run(run_tests(concurrency_list, total_requests))
# import multiprocessing as mp
# import requests
# import time
# import statistics
# from typing import Dict, List
# BASE_URL = "http://127.0.0.1:8880/generate_tts"
# FORM_PARAMS = {
# "text": "澳门在哪里啊",
# "cfg_value": "2.0",
# "inference_timesteps": "10",
# "normalize": "true",
# "denoise": "true",
# "prompt_text": "澳门有乜嘢好食嘅?"
# }
# PROMPT_WAV_PATH = "/home/verachen/Music/voice/2food.wav"
# def single_request() -> float:
# """Send one synchronous multipart/form-data request via requests; return elapsed seconds or raise."""
# files = {
# "prompt_wav": ("2food.wav", open(PROMPT_WAV_PATH, "rb"), "audio/wav")
# }
# data = FORM_PARAMS.copy()
# start = time.perf_counter()
# resp = requests.post(BASE_URL, data=data, files=files, timeout=60)
# elapsed = time.perf_counter() - start
# if resp.status_code != 200:
# raise RuntimeError(f"HTTP {resp.status_code}, body: {resp.text[:200]!r}")
# return elapsed
# def worker_task(num_requests: int, return_list: mp.Manager().list, err_list: mp.Manager().list):
# """Child process issues num_requests requests, appending each elapsed time to return_list (a shared list) and error messages to err_list."""
# for _ in range(num_requests):
# try:
# t = single_request()
# return_list.append(t)
# except Exception as e:
# err_list.append(str(e))
# def run_multiproc(concurrency: int, total_requests: int) -> Dict:
# """
# Simulate concurrency with multiple processes:
# - each process sends total_requests // concurrency requests (floor division, roughly allocated)
# - or, more aggressively, let every process run total_requests requests
# """
# manager = mp.Manager()
# times = manager.list()
# errs = manager.list()
# procs = []
# Work split: each child process runs a share of the requests
# per = total_requests // concurrency
# if per < 1:
# per = 1
# for i in range(concurrency):
# p = mp.Process(target=worker_task, args=(per, times, errs))
# p.start()
# procs.append(p)
# for p in procs:
# p.join()
# timings = list(times)
# errors = list(errs)
# succ = len(timings)
# total_time = sum(timings) if timings else 0.0
# ret = {
# "concurrency": concurrency,
# "total_requests": total_requests,
# "successes": succ,
# "errors": len(errors),
# }
# if succ > 0:
# ret.update({
# "min": min(timings),
# "max": max(timings),
# "avg": statistics.mean(timings),
# "median": statistics.median(timings),
# "qps": succ / total_time if total_time > 0 else None
# })
# return ret
# if __name__ == "__main__":
# concurrency = 4
# total_requests = 40
# print("Starting multiprocess concurrency test")
# result = run_multiproc(concurrency, total_requests)
# print(result)
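The aggregation that `benchmark()` performs on the collected timings can be sketched as a standalone helper (an editor's illustrative sketch; `summarize` is a hypothetical name, not part of the script — note the "QPS" here is an approximation derived from summed per-request latencies, as in the original):

```python
import statistics
from typing import Dict, List

def summarize(timings: List[float], errors: int) -> Dict:
    # Mirror the stats block in benchmark(): count successes/errors,
    # then add latency percentiles and an approximate QPS.
    result: Dict = {"successes": len(timings), "errors": errors}
    if timings:
        total = sum(timings)
        result.update({
            "min": min(timings),
            "max": max(timings),
            "avg": statistics.mean(timings),
            "median": statistics.median(timings),
            "qps": len(timings) / total if total > 0 else 0,
        })
    return result

print(summarize([0.5, 1.0, 1.5], errors=1))
```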

34
test_streaming.py Normal file
View File

@@ -0,0 +1,34 @@
import asyncio
import httpx
import time
async def test_streaming():
url = "http://localhost:8880/generate_tts_streaming"
data = {
"text": "你好,这是一段流式输出的测试音频。",
"cfg_value": "2.0",
"inference_timesteps": "10",
"do_normalize": "True",
"denoise": "True"
}
start_time = time.time()
first_byte_received = False
async with httpx.AsyncClient() as client:
# Simulate a file upload (if a prompt_wav is provided)
# files = {'prompt_wav': open('test.wav', 'rb')}
async with client.stream("POST", url, data=data) as response:
print(f"Status code: {response.status_code}")
async for chunk in response.aiter_bytes():
if not first_byte_received:
ttfb = time.time() - start_time
print(f"🚀 First byte arrived (TTFB): {ttfb:.4f}")
first_byte_received = True
print(f"Received chunk: {len(chunk)} bytes")
if __name__ == "__main__":
asyncio.run(test_streaming())
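The TTFB measurement above can be isolated from the HTTP layer: the loop records the time to the first chunk and the total bytes received. A minimal sketch of that logic (`consume_stream` is an illustrative name, not part of the test script):

```python
import time
from typing import Iterable, Tuple

def consume_stream(chunks: Iterable[bytes]) -> Tuple[float, int]:
    # Like the aiter_bytes() loop in test_streaming(): note the time of
    # the first chunk relative to the start, and accumulate total bytes.
    start = time.perf_counter()
    ttfb = None
    total = 0
    for chunk in chunks:
        if ttfb is None:
            ttfb = time.perf_counter() - start
        total += len(chunk)
    # Return -1.0 for TTFB when the stream yielded nothing.
    return (ttfb if ttfb is not None else -1.0), total

ttfb, total = consume_stream([b"RIFF", b"\x00" * 1024])
print(f"TTFB: {ttfb:.6f}s, bytes: {total}")
```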