Compare commits
11 Commits
6499215204
...
41b5d321f6
| Author | SHA1 | Date | |
|---|---|---|---|
| 41b5d321f6 | |||
| 49367d03a8 | |||
| cd4584ebae | |||
| 76c8bdbcfc | |||
| 2b0c569b7a | |||
| 721c53fe87 | |||
| e8dd956fc2 | |||
| db75a7269b | |||
| f2e203d5e2 | |||
| 6ecc00a5d3 | |||
| 8cfd9d155a |
33
.github/workflows/ci-cd.yaml
vendored
Normal file
33
.github/workflows/ci-cd.yaml
vendored
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
name: Build container
|
||||||
|
env:
|
||||||
|
VERSION: 0.0.3
|
||||||
|
REGISTRY: https://harbor.bwgdi.com
|
||||||
|
REGISTRY_NAME: harbor.bwgdi.com
|
||||||
|
REGISTRY_PATH: library
|
||||||
|
DOCKER_NAME: voxcpmtts
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
workflow_dispatch:
|
||||||
|
jobs:
|
||||||
|
build-docker:
|
||||||
|
runs-on: builder-ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
- name: Login to Docker Hub
|
||||||
|
uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
registry: ${{ env.REGISTRY }}
|
||||||
|
username: ${{ secrets.BWGDI_NAME }}
|
||||||
|
password: ${{ secrets.BWGDI_TOKEN }}
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v2
|
||||||
|
- name: Build and push
|
||||||
|
uses: docker/build-push-action@v4
|
||||||
|
with:
|
||||||
|
context: .
|
||||||
|
file: ./Dockerfile
|
||||||
|
push: true
|
||||||
|
tags: ${{ env.REGISTRY_NAME }}/${{ env.REGISTRY_PATH }}/${{ env.DOCKER_NAME }}:${{ env.VERSION }}
|
||||||
42
.gitignore
vendored
42
.gitignore
vendored
@ -1,4 +1,42 @@
|
|||||||
launch.json
|
launch.json
|
||||||
__pycache__
|
|
||||||
voxcpm.egg-info
|
voxcpm.egg-info
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
|
||||||
|
# Python-generated files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
build/
|
||||||
|
dist/
|
||||||
|
wheels/
|
||||||
|
*.egg-info/
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
.pytest_cache/
|
||||||
|
.coverage
|
||||||
|
htmlcov/
|
||||||
|
coverage.xml
|
||||||
|
|
||||||
|
# Logs
|
||||||
|
*.log
|
||||||
|
log/*.log
|
||||||
|
|
||||||
|
# Virtual environments
|
||||||
|
.venv/
|
||||||
|
venv/
|
||||||
|
env/
|
||||||
|
|
||||||
|
# IDE settings
|
||||||
|
.vscode/
|
||||||
|
.idea/
|
||||||
|
|
||||||
|
# OS generated files
|
||||||
|
.DS_Store
|
||||||
|
|
||||||
|
# Generated files
|
||||||
|
*.wav
|
||||||
|
*.pdf
|
||||||
|
|
||||||
|
*.lock
|
||||||
17
Dockerfile
Normal file
17
Dockerfile
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
FROM python:3.10.12-slim
|
||||||
|
RUN apt-get update && apt-get install -y \
|
||||||
|
build-essential \
|
||||||
|
ffmpeg \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
# Create app directory
|
||||||
|
WORKDIR /app
|
||||||
|
COPY api_concurrent.py requirements.txt ./
|
||||||
|
RUN pip install -r requirements.txt
|
||||||
|
ENV VOXCPM_MODEL_ID="/models/VoxCPM1.5/" \
|
||||||
|
VOXCPM_CPU_WORKERS="2" \
|
||||||
|
VOXCPM_UVICORN_WORKERS="1" \
|
||||||
|
MAX_GPU_CONCURRENT="1"
|
||||||
|
|
||||||
|
EXPOSE 5000
|
||||||
|
CMD [ "python", "./api_concurrent.py" ]
|
||||||
|
|
||||||
83
README_BW.md
Normal file
83
README_BW.md
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
# VoxCPM-TTS
|
||||||
|
|
||||||
|
https://github.com/BoardWare-Genius/VoxCPM
|
||||||
|
|
||||||
|
## 📦 VoxCPM-TTS Version History
|
||||||
|
|
||||||
|
| Version | Date | Summary |
|
||||||
|
|---------|------------|---------------------------------|
|
||||||
|
| 0.0.3 | 2026-02-09 | Optimized configuration & Model support |
|
||||||
|
| 0.0.2 | 2026-01-21 | Supports streaming |
|
||||||
|
| 0.0.1 | 2026-01-20 | Initial version |
|
||||||
|
|
||||||
|
### 🔄 Version Details
|
||||||
|
|
||||||
|
#### 🆕 0.0.3 – *2026-02-09*
|
||||||
|
- ✅ **Configuration & Deployment**
|
||||||
|
- Supports configuring model path via `VOXCPM_MODEL_ID`
|
||||||
|
- Supports configuring CPU workers via `VOXCPM_CPU_WORKERS`
|
||||||
|
- Supports configuring Uvicorn workers via `VOXCPM_UVICORN_WORKERS`
|
||||||
|
|
||||||
|
#### 🆕 0.0.2 – *2026-01-21*
|
||||||
|
|
||||||
|
- ✅ **Core Features**
|
||||||
|
- Update Model Weights, use VoxCPM1.5 and model parameters
|
||||||
|
- Supports streaming
|
||||||
|
|
||||||
|
#### 🆕 0.0.1 – *2026-01-20*
|
||||||
|
|
||||||
|
- ✅ **Core Features**
|
||||||
|
- Initial VoxCPM-TTS
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
|
||||||
|
# Start
|
||||||
|
```bash
|
||||||
|
docker pull harbor.bwgdi.com/library/voxcpmtts:0.0.3
|
||||||
|
|
||||||
|
# Run with custom configuration
|
||||||
|
# -e VOXCPM_MODEL_ID: Path to the model directory inside container
|
||||||
|
# -e VOXCPM_CPU_WORKERS: Number of threads for CPU-bound tasks
|
||||||
|
# -e VOXCPM_UVICORN_WORKERS: Number of uvicorn workers
|
||||||
|
# -e MAX_GPU_CONCURRENT: Max concurrent GPU tasks
|
||||||
|
docker run -d --restart always -p 5001:5000 --gpus '"device=0"' \
|
||||||
|
-e VOXCPM_MODEL_ID="/models/VoxCPM1.5/" \
|
||||||
|
-e VOXCPM_CPU_WORKERS="2" \
|
||||||
|
-e VOXCPM_UVICORN_WORKERS="1" \
|
||||||
|
-e MAX_GPU_CONCURRENT="1" \
|
||||||
|
--mount type=bind,source=/Workspace/NAS11/model/Voice/VoxCPM,target=/models \
|
||||||
|
harbor.bwgdi.com/library/voxcpmtts:0.0.3
|
||||||
|
```
|
||||||
|
|
||||||
|
# Usage
|
||||||
|
|
||||||
|
## Non-streaming
|
||||||
|
```bash
|
||||||
|
curl --location 'http://localhost:5001/generate_tts' \
|
||||||
|
--form 'text="你好,这是一段测试文本"' \
|
||||||
|
--form 'prompt_text="这是提示文本"' \
|
||||||
|
--form 'cfg_value="2.0"' \
|
||||||
|
--form 'inference_timesteps="10"' \
|
||||||
|
--form 'do_normalize="true"' \
|
||||||
|
--form 'denoise="true"' \
|
||||||
|
--form 'retry_badcase="true"' \
|
||||||
|
--form 'retry_badcase_max_times="3"' \
|
||||||
|
--form 'retry_badcase_ratio_threshold="6.0"' \
|
||||||
|
--form 'prompt_wav=@"/assets/2food16k_2.wav"'
|
||||||
|
```
|
||||||
|
|
||||||
|
## Streaming
|
||||||
|
```bash
|
||||||
|
curl --location 'http://localhost:5001/generate_tts_streaming' \
|
||||||
|
--form 'text="你好,这是一段测试文本"' \
|
||||||
|
--form 'prompt_text="这是提示文本"' \
|
||||||
|
--form 'cfg_value="2.0"' \
|
||||||
|
--form 'inference_timesteps="10"' \
|
||||||
|
--form 'do_normalize="true"' \
|
||||||
|
--form 'denoise="true"' \
|
||||||
|
--form 'retry_badcase="true"' \
|
||||||
|
--form 'retry_badcase_max_times="3"' \
|
||||||
|
--form 'retry_badcase_ratio_threshold="6.0"' \
|
||||||
|
--form 'prompt_wav=@"/Workspace/NAS11/model/Voice/assets/2food16k_2.wav"'
|
||||||
|
```
|
||||||
272
api_concurrent.py
Normal file
272
api_concurrent.py
Normal file
@ -0,0 +1,272 @@
|
|||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import tempfile
|
||||||
|
import soundfile as sf
|
||||||
|
import numpy as np
|
||||||
|
import wave
|
||||||
|
from io import BytesIO
|
||||||
|
import asyncio
|
||||||
|
from typing import Optional
|
||||||
|
from fastapi import FastAPI, Form, UploadFile, BackgroundTasks
|
||||||
|
from fastapi.responses import StreamingResponse, JSONResponse
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
|
import voxcpm
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
if os.environ.get("HF_REPO_ID", "").strip() == "":
|
||||||
|
os.environ["HF_REPO_ID"] = "/models/VoxCPM1.5/"
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 模型类 ==========
|
||||||
|
class VoxCPMDemo:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
print(f"🚀 Running on device: {self.device}")
|
||||||
|
self.voxcpm_model: Optional[voxcpm.VoxCPM] = None
|
||||||
|
self.default_local_model_dir = os.environ.get("VOXCPM_MODEL_ID", "/models/VoxCPM1.5/")
|
||||||
|
|
||||||
|
def _resolve_model_dir(self) -> str:
|
||||||
|
if os.path.isdir(self.default_local_model_dir):
|
||||||
|
return self.default_local_model_dir
|
||||||
|
return "models"
|
||||||
|
|
||||||
|
def get_or_load_voxcpm(self) -> voxcpm.VoxCPM:
|
||||||
|
if self.voxcpm_model is None:
|
||||||
|
print("🔄 Loading VoxCPM model...")
|
||||||
|
model_dir = self._resolve_model_dir()
|
||||||
|
self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir)
|
||||||
|
print("✅ VoxCPM model loaded.")
|
||||||
|
return self.voxcpm_model
|
||||||
|
|
||||||
|
def tts_generate(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
prompt_wav_path: Optional[str] = None,
|
||||||
|
prompt_text: Optional[str] = None,
|
||||||
|
cfg_value: float = 2.0,
|
||||||
|
inference_timesteps: int = 10,
|
||||||
|
do_normalize: bool = True,
|
||||||
|
denoise: bool = True,
|
||||||
|
retry_badcase: bool = True,
|
||||||
|
retry_badcase_max_times: int = 3,
|
||||||
|
retry_badcase_ratio_threshold: float = 6.0,
|
||||||
|
) -> str:
|
||||||
|
model = self.get_or_load_voxcpm()
|
||||||
|
wav = model.generate(
|
||||||
|
text=text,
|
||||||
|
prompt_text=prompt_text,
|
||||||
|
prompt_wav_path=prompt_wav_path,
|
||||||
|
cfg_value=float(cfg_value),
|
||||||
|
inference_timesteps=int(inference_timesteps),
|
||||||
|
normalize=do_normalize,
|
||||||
|
denoise=denoise,
|
||||||
|
retry_badcase=retry_badcase,
|
||||||
|
retry_badcase_max_times=retry_badcase_max_times,
|
||||||
|
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
|
||||||
|
)
|
||||||
|
tmp_wav = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
|
||||||
|
sf.write(tmp_wav.name, wav, model.tts_model.sample_rate)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
return tmp_wav.name
|
||||||
|
|
||||||
|
def tts_generate_streaming(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
prompt_wav_path: Optional[str] = None,
|
||||||
|
prompt_text: Optional[str] = None,
|
||||||
|
cfg_value: float = 2.0,
|
||||||
|
inference_timesteps: int = 10,
|
||||||
|
do_normalize: bool = True,
|
||||||
|
denoise: bool = True,
|
||||||
|
retry_badcase: bool = True,
|
||||||
|
retry_badcase_max_times: int = 3,
|
||||||
|
retry_badcase_ratio_threshold: float = 6.0,
|
||||||
|
):
|
||||||
|
"""Generates audio and yields it as a stream of WAV chunks."""
|
||||||
|
model = self.get_or_load_voxcpm()
|
||||||
|
|
||||||
|
# 1. Yield a WAV header first.
|
||||||
|
# The size fields will be 0, which is standard for streaming.
|
||||||
|
SAMPLE_RATE = model.tts_model.sample_rate
|
||||||
|
CHANNELS = 1
|
||||||
|
SAMPLE_WIDTH = 2 # 16-bit
|
||||||
|
|
||||||
|
header_buf = BytesIO()
|
||||||
|
with wave.open(header_buf, "wb") as wf:
|
||||||
|
wf.setnchannels(CHANNELS)
|
||||||
|
wf.setsampwidth(SAMPLE_WIDTH)
|
||||||
|
wf.setframerate(SAMPLE_RATE)
|
||||||
|
|
||||||
|
yield header_buf.getvalue()
|
||||||
|
|
||||||
|
# 2. Generate and yield audio chunks.
|
||||||
|
# NOTE: We assume a `generate_stream` method exists on the model that yields audio chunks.
|
||||||
|
# You may need to change `generate_stream` to the actual method name in your version of voxcpm.
|
||||||
|
try:
|
||||||
|
stream = model.generate_streaming(
|
||||||
|
text=text,
|
||||||
|
prompt_text=prompt_text,
|
||||||
|
prompt_wav_path=prompt_wav_path,
|
||||||
|
cfg_value=float(cfg_value),
|
||||||
|
inference_timesteps=int(inference_timesteps),
|
||||||
|
normalize=do_normalize,
|
||||||
|
denoise=denoise,
|
||||||
|
retry_badcase=retry_badcase,
|
||||||
|
retry_badcase_max_times=retry_badcase_max_times,
|
||||||
|
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
|
||||||
|
)
|
||||||
|
for chunk_np in stream: # Assuming it yields numpy arrays
|
||||||
|
# Ensure audio is in 16-bit PCM format for streaming
|
||||||
|
if chunk_np.dtype in [np.float32, np.float64]:
|
||||||
|
chunk_np = (chunk_np * 32767).astype(np.int16)
|
||||||
|
yield chunk_np.tobytes()
|
||||||
|
finally:
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
# ========== FastAPI ==========
|
||||||
|
app = FastAPI(title="VoxCPM API", version="1.0.0")
|
||||||
|
demo = VoxCPMDemo()
|
||||||
|
|
||||||
|
# --- Concurrency Control ---
|
||||||
|
# Use a semaphore to limit concurrent GPU tasks.
|
||||||
|
MAX_GPU_CONCURRENT = int(os.environ.get("MAX_GPU_CONCURRENT", "1"))
|
||||||
|
gpu_semaphore = asyncio.Semaphore(MAX_GPU_CONCURRENT)
|
||||||
|
|
||||||
|
# Use a thread pool for running blocking (CPU/GPU-bound) code.
|
||||||
|
MAX_CPU_WORKERS = int(os.environ.get("VOXCPM_CPU_WORKERS", "2"))
|
||||||
|
executor = ThreadPoolExecutor(max_workers=MAX_CPU_WORKERS)
|
||||||
|
|
||||||
|
@app.on_event("shutdown")
|
||||||
|
def shutdown_event():
|
||||||
|
print("Shutting down thread pool executor...")
|
||||||
|
executor.shutdown(wait=True)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- TTS API ----------
|
||||||
|
@app.post("/generate_tts")
|
||||||
|
async def generate_tts(
|
||||||
|
background_tasks: BackgroundTasks,
|
||||||
|
text: str = Form(...),
|
||||||
|
prompt_text: Optional[str] = Form(None),
|
||||||
|
cfg_value: float = Form(2.0),
|
||||||
|
inference_timesteps: int = Form(10),
|
||||||
|
do_normalize: bool = Form(True),
|
||||||
|
denoise: bool = Form(True),
|
||||||
|
retry_badcase: bool = Form(True),
|
||||||
|
retry_badcase_max_times: int = Form(3),
|
||||||
|
retry_badcase_ratio_threshold: float = Form(6.0),
|
||||||
|
prompt_wav: Optional[UploadFile] = None,
|
||||||
|
):
|
||||||
|
|
||||||
|
try:
|
||||||
|
prompt_path = None
|
||||||
|
if prompt_wav:
|
||||||
|
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
|
||||||
|
tmp.write(await prompt_wav.read())
|
||||||
|
prompt_path = tmp.name
|
||||||
|
background_tasks.add_task(os.remove, tmp.name)
|
||||||
|
|
||||||
|
# Submit to GPU via semaphore and executor
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
async with gpu_semaphore:
|
||||||
|
output_path = await loop.run_in_executor(
|
||||||
|
executor,
|
||||||
|
demo.tts_generate,
|
||||||
|
text,
|
||||||
|
prompt_path,
|
||||||
|
prompt_text,
|
||||||
|
cfg_value,
|
||||||
|
inference_timesteps,
|
||||||
|
do_normalize,
|
||||||
|
denoise,
|
||||||
|
retry_badcase,
|
||||||
|
retry_badcase_max_times,
|
||||||
|
retry_badcase_ratio_threshold,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 后台删除生成的文件
|
||||||
|
background_tasks.add_task(os.remove, output_path)
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
open(output_path, "rb"),
|
||||||
|
media_type="audio/wav",
|
||||||
|
headers={"Content-Disposition": 'attachment; filename="output.wav"'}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return JSONResponse({"error": str(e)}, status_code=500)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/generate_tts_streaming")
|
||||||
|
async def generate_tts_streaming(
|
||||||
|
background_tasks: BackgroundTasks,
|
||||||
|
text: str = Form(...),
|
||||||
|
prompt_text: Optional[str] = Form(None),
|
||||||
|
cfg_value: float = Form(2.0),
|
||||||
|
inference_timesteps: int = Form(10),
|
||||||
|
do_normalize: bool = Form(True),
|
||||||
|
denoise: bool = Form(True),
|
||||||
|
retry_badcase: bool = Form(True),
|
||||||
|
retry_badcase_max_times: int = Form(3),
|
||||||
|
retry_badcase_ratio_threshold: float = Form(6.0),
|
||||||
|
prompt_wav: Optional[UploadFile] = None,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
prompt_path = None
|
||||||
|
if prompt_wav:
|
||||||
|
# Save uploaded file to a temporary file
|
||||||
|
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
|
||||||
|
tmp.write(await prompt_wav.read())
|
||||||
|
prompt_path = tmp.name
|
||||||
|
# Ensure the temp file is deleted after the request is finished
|
||||||
|
background_tasks.add_task(os.remove, prompt_path)
|
||||||
|
|
||||||
|
async def stream_generator():
|
||||||
|
# This async generator consumes from a queue populated by a sync generator in a thread.
|
||||||
|
q = asyncio.Queue()
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
|
def producer():
|
||||||
|
# This runs in the executor thread and produces chunks.
|
||||||
|
try:
|
||||||
|
# This is a sync generator
|
||||||
|
for chunk in demo.tts_generate_streaming(
|
||||||
|
text, prompt_path, prompt_text, cfg_value,
|
||||||
|
inference_timesteps, do_normalize, denoise,
|
||||||
|
retry_badcase, retry_badcase_max_times, retry_badcase_ratio_threshold
|
||||||
|
):
|
||||||
|
loop.call_soon_threadsafe(q.put_nowait, chunk)
|
||||||
|
except Exception as e:
|
||||||
|
# Put the exception in the queue to be re-raised in the consumer
|
||||||
|
loop.call_soon_threadsafe(q.put_nowait, e)
|
||||||
|
finally:
|
||||||
|
# Signal the end of the stream
|
||||||
|
loop.call_soon_threadsafe(q.put_nowait, None)
|
||||||
|
|
||||||
|
# Acquire the GPU semaphore before starting the producer thread.
|
||||||
|
async with gpu_semaphore:
|
||||||
|
loop.run_in_executor(executor, producer)
|
||||||
|
while True:
|
||||||
|
chunk = await q.get()
|
||||||
|
if chunk is None:
|
||||||
|
break
|
||||||
|
if isinstance(chunk, Exception):
|
||||||
|
raise chunk
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
return StreamingResponse(stream_generator(), media_type="audio/wav")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return JSONResponse({"error": str(e)}, status_code=500)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
async def root():
|
||||||
|
return {"message": "VoxCPM API running 🚀", "endpoints": ["/generate_tts"]}
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
uvicorn_workers = int(os.environ.get("VOXCPM_UVICORN_WORKERS", "1"))
|
||||||
|
uvicorn.run("api_concurrent:app", host="0.0.0.0", port=5000, workers=uvicorn_workers)
|
||||||
15
app.py
15
app.py
@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
@ -16,7 +17,7 @@ import voxcpm
|
|||||||
class VoxCPMDemo:
|
class VoxCPMDemo:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
print(f"🚀 Running on device: {self.device}")
|
print(f"🚀 Running on device: {self.device}", file=sys.stderr)
|
||||||
|
|
||||||
# ASR model for prompt text recognition
|
# ASR model for prompt text recognition
|
||||||
self.asr_model_id = "iic/SenseVoiceSmall"
|
self.asr_model_id = "iic/SenseVoiceSmall"
|
||||||
@ -49,10 +50,10 @@ class VoxCPMDemo:
|
|||||||
try:
|
try:
|
||||||
from huggingface_hub import snapshot_download # type: ignore
|
from huggingface_hub import snapshot_download # type: ignore
|
||||||
os.makedirs(target_dir, exist_ok=True)
|
os.makedirs(target_dir, exist_ok=True)
|
||||||
print(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...")
|
print(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...", file=sys.stderr)
|
||||||
snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False)
|
snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: HF download failed: {e}. Falling back to 'data'.")
|
print(f"Warning: HF download failed: {e}. Falling back to 'data'.", file=sys.stderr)
|
||||||
return "models"
|
return "models"
|
||||||
return target_dir
|
return target_dir
|
||||||
return "models"
|
return "models"
|
||||||
@ -60,11 +61,11 @@ class VoxCPMDemo:
|
|||||||
def get_or_load_voxcpm(self) -> voxcpm.VoxCPM:
|
def get_or_load_voxcpm(self) -> voxcpm.VoxCPM:
|
||||||
if self.voxcpm_model is not None:
|
if self.voxcpm_model is not None:
|
||||||
return self.voxcpm_model
|
return self.voxcpm_model
|
||||||
print("Model not loaded, initializing...")
|
print("Model not loaded, initializing...", file=sys.stderr)
|
||||||
model_dir = self._resolve_model_dir()
|
model_dir = self._resolve_model_dir()
|
||||||
print(f"Using model dir: {model_dir}")
|
print(f"Using model dir: {model_dir}", file=sys.stderr)
|
||||||
self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir)
|
self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir)
|
||||||
print("Model loaded successfully.")
|
print("Model loaded successfully.", file=sys.stderr)
|
||||||
return self.voxcpm_model
|
return self.voxcpm_model
|
||||||
|
|
||||||
# ---------- Functional endpoints ----------
|
# ---------- Functional endpoints ----------
|
||||||
@ -98,7 +99,7 @@ class VoxCPMDemo:
|
|||||||
prompt_wav_path = prompt_wav_path_input if prompt_wav_path_input else None
|
prompt_wav_path = prompt_wav_path_input if prompt_wav_path_input else None
|
||||||
prompt_text = prompt_text_input if prompt_text_input else None
|
prompt_text = prompt_text_input if prompt_text_input else None
|
||||||
|
|
||||||
print(f"Generating audio for text: '{text[:60]}...'")
|
print(f"Generating audio for text: '{text[:60]}...'", file=sys.stderr)
|
||||||
wav = current_model.generate(
|
wav = current_model.generate(
|
||||||
text=text,
|
text=text,
|
||||||
prompt_text=prompt_text,
|
prompt_text=prompt_text,
|
||||||
|
|||||||
@ -104,7 +104,7 @@ def get_timestamp_str():
|
|||||||
def get_or_load_asr_model():
|
def get_or_load_asr_model():
|
||||||
global asr_model
|
global asr_model
|
||||||
if asr_model is None:
|
if asr_model is None:
|
||||||
print("Loading ASR model (SenseVoiceSmall)...")
|
print("Loading ASR model (SenseVoiceSmall)...", file=sys.stderr)
|
||||||
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||||
asr_model = AutoModel(
|
asr_model = AutoModel(
|
||||||
model="iic/SenseVoiceSmall",
|
model="iic/SenseVoiceSmall",
|
||||||
@ -123,7 +123,7 @@ def recognize_audio(audio_path):
|
|||||||
text = res[0]["text"].split('|>')[-1]
|
text = res[0]["text"].split('|>')[-1]
|
||||||
return text
|
return text
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"ASR Error: {e}")
|
print(f"ASR Error: {e}", file=sys.stderr)
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def scan_lora_checkpoints(root_dir="lora", with_info=False):
|
def scan_lora_checkpoints(root_dir="lora", with_info=False):
|
||||||
@ -181,7 +181,7 @@ def load_lora_config_from_checkpoint(lora_path):
|
|||||||
if lora_cfg_dict:
|
if lora_cfg_dict:
|
||||||
return LoRAConfig(**lora_cfg_dict), lora_info.get("base_model")
|
return LoRAConfig(**lora_cfg_dict), lora_info.get("base_model")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: Failed to load lora_config.json: {e}")
|
print(f"Warning: Failed to load lora_config.json: {e}", file=sys.stderr)
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
def get_default_lora_config():
|
def get_default_lora_config():
|
||||||
@ -197,7 +197,7 @@ def get_default_lora_config():
|
|||||||
|
|
||||||
def load_model(pretrained_path, lora_path=None):
|
def load_model(pretrained_path, lora_path=None):
|
||||||
global current_model
|
global current_model
|
||||||
print(f"Loading model from {pretrained_path}...")
|
print(f"Loading model from {pretrained_path}...", file=sys.stderr)
|
||||||
|
|
||||||
lora_config = None
|
lora_config = None
|
||||||
lora_weights_path = None
|
lora_weights_path = None
|
||||||
@ -209,11 +209,11 @@ def load_model(pretrained_path, lora_path=None):
|
|||||||
# Try to load LoRA config from lora_config.json
|
# Try to load LoRA config from lora_config.json
|
||||||
lora_config, _ = load_lora_config_from_checkpoint(full_lora_path)
|
lora_config, _ = load_lora_config_from_checkpoint(full_lora_path)
|
||||||
if lora_config:
|
if lora_config:
|
||||||
print(f"Loaded LoRA config from {full_lora_path}/lora_config.json")
|
print(f"Loaded LoRA config from {full_lora_path}/lora_config.json", file=sys.stderr)
|
||||||
else:
|
else:
|
||||||
# Fallback to default config for old checkpoints
|
# Fallback to default config for old checkpoints
|
||||||
lora_config = get_default_lora_config()
|
lora_config = get_default_lora_config()
|
||||||
print("Using default LoRA config (lora_config.json not found)")
|
print("Using default LoRA config (lora_config.json not found)", file=sys.stderr)
|
||||||
|
|
||||||
# Always init with a default LoRA config to allow hot-swapping later
|
# Always init with a default LoRA config to allow hot-swapping later
|
||||||
if lora_config is None:
|
if lora_config is None:
|
||||||
@ -251,36 +251,36 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
|
|||||||
# 优先使用保存的 base_model 路径
|
# 优先使用保存的 base_model 路径
|
||||||
if os.path.exists(saved_base_model):
|
if os.path.exists(saved_base_model):
|
||||||
base_model_path = saved_base_model
|
base_model_path = saved_base_model
|
||||||
print(f"Using base model from LoRA config: {base_model_path}")
|
print(f"Using base model from LoRA config: {base_model_path}", file=sys.stderr)
|
||||||
else:
|
else:
|
||||||
print(f"Warning: Saved base_model path not found: {saved_base_model}")
|
print(f"Warning: Saved base_model path not found: {saved_base_model}", file=sys.stderr)
|
||||||
print(f"Falling back to default: {base_model_path}")
|
print(f"Falling back to default: {base_model_path}", file=sys.stderr)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: Failed to read base_model from LoRA config: {e}")
|
print(f"Warning: Failed to read base_model from LoRA config: {e}", file=sys.stderr)
|
||||||
|
|
||||||
# 加载模型
|
# 加载模型
|
||||||
try:
|
try:
|
||||||
print(f"Loading base model: {base_model_path}")
|
print(f"Loading base model: {base_model_path}", file=sys.stderr)
|
||||||
status_msg = load_model(base_model_path)
|
status_msg = load_model(base_model_path)
|
||||||
if lora_selection and lora_selection != "None":
|
if lora_selection and lora_selection != "None":
|
||||||
print(f"Model loaded for LoRA: {lora_selection}")
|
print(f"Model loaded for LoRA: {lora_selection}", file=sys.stderr)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Failed to load model from {base_model_path}: {str(e)}"
|
error_msg = f"Failed to load model from {base_model_path}: {str(e)}"
|
||||||
print(error_msg)
|
print(error_msg, file=sys.stderr)
|
||||||
return None, error_msg
|
return None, error_msg
|
||||||
|
|
||||||
# Handle LoRA hot-swapping
|
# Handle LoRA hot-swapping
|
||||||
if lora_selection and lora_selection != "None":
|
if lora_selection and lora_selection != "None":
|
||||||
full_lora_path = os.path.join("lora", lora_selection)
|
full_lora_path = os.path.join("lora", lora_selection)
|
||||||
print(f"Hot-loading LoRA: {full_lora_path}")
|
print(f"Hot-loading LoRA: {full_lora_path}", file=sys.stderr)
|
||||||
try:
|
try:
|
||||||
current_model.load_lora(full_lora_path)
|
current_model.load_lora(full_lora_path)
|
||||||
current_model.set_lora_enabled(True)
|
current_model.set_lora_enabled(True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error loading LoRA: {e}")
|
print(f"Error loading LoRA: {e}", file=sys.stderr)
|
||||||
return None, f"Error loading LoRA: {e}"
|
return None, f"Error loading LoRA: {e}"
|
||||||
else:
|
else:
|
||||||
print("Disabling LoRA")
|
print("Disabling LoRA", file=sys.stderr)
|
||||||
current_model.set_lora_enabled(False)
|
current_model.set_lora_enabled(False)
|
||||||
|
|
||||||
if seed != -1:
|
if seed != -1:
|
||||||
@ -297,11 +297,11 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
|
|||||||
|
|
||||||
# 如果没有提供参考文本,尝试自动识别
|
# 如果没有提供参考文本,尝试自动识别
|
||||||
if not prompt_text or not prompt_text.strip():
|
if not prompt_text or not prompt_text.strip():
|
||||||
print("参考音频已提供但缺少文本,自动识别中...")
|
print("参考音频已提供但缺少文本,自动识别中...", file=sys.stderr)
|
||||||
try:
|
try:
|
||||||
final_prompt_text = recognize_audio(prompt_wav)
|
final_prompt_text = recognize_audio(prompt_wav)
|
||||||
if final_prompt_text:
|
if final_prompt_text:
|
||||||
print(f"自动识别文本: {final_prompt_text}")
|
print(f"自动识别文本: {final_prompt_text}", file=sys.stderr)
|
||||||
else:
|
else:
|
||||||
return None, "错误:无法识别参考音频内容,请手动填写参考文本"
|
return None, "错误:无法识别参考音频内容,请手动填写参考文本"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -1114,12 +1114,12 @@ with gr.Blocks(
|
|||||||
choices = ["None"] + [ckpt[0] for ckpt in checkpoints_with_info]
|
choices = ["None"] + [ckpt[0] for ckpt in checkpoints_with_info]
|
||||||
|
|
||||||
# 输出调试信息
|
# 输出调试信息
|
||||||
print(f"刷新 LoRA 列表: 找到 {len(checkpoints_with_info)} 个检查点")
|
print(f"刷新 LoRA 列表: 找到 {len(checkpoints_with_info)} 个检查点", file=sys.stderr)
|
||||||
for ckpt_path, base_model in checkpoints_with_info:
|
for ckpt_path, base_model in checkpoints_with_info:
|
||||||
if base_model:
|
if base_model:
|
||||||
print(f" - {ckpt_path} (Base Model: {base_model})")
|
print(f" - {ckpt_path} (Base Model: {base_model})", file=sys.stderr)
|
||||||
else:
|
else:
|
||||||
print(f" - {ckpt_path}")
|
print(f" - {ckpt_path}", file=sys.stderr)
|
||||||
|
|
||||||
return gr.update(choices=choices, value="None")
|
return gr.update(choices=choices, value="None")
|
||||||
|
|
||||||
|
|||||||
@ -27,6 +27,7 @@ requires-python = ">=3.10"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"torch>=2.5.0",
|
"torch>=2.5.0",
|
||||||
"torchaudio>=2.5.0",
|
"torchaudio>=2.5.0",
|
||||||
|
"torchcodec",
|
||||||
"transformers>=4.36.2",
|
"transformers>=4.36.2",
|
||||||
"einops",
|
"einops",
|
||||||
"gradio<6",
|
"gradio<6",
|
||||||
@ -46,8 +47,7 @@ dependencies = [
|
|||||||
"funasr",
|
"funasr",
|
||||||
"spaces",
|
"spaces",
|
||||||
"argbind",
|
"argbind",
|
||||||
"safetensors"
|
"safetensors",
|
||||||
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
|||||||
5
requirements.txt
Normal file
5
requirements.txt
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
soundfile
|
||||||
|
fastapi
|
||||||
|
uvicorn
|
||||||
|
voxcpm
|
||||||
|
torchcodec
|
||||||
@ -23,6 +23,7 @@ With voice cloning:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
@ -92,7 +93,7 @@ def main():
|
|||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
# Load model from checkpoint directory (no denoiser)
|
# Load model from checkpoint directory (no denoiser)
|
||||||
print(f"[FT Inference] Loading model: {args.ckpt_dir}")
|
print(f"[FT Inference] Loading model: {args.ckpt_dir}", file=sys.stderr)
|
||||||
model = VoxCPM.from_pretrained(
|
model = VoxCPM.from_pretrained(
|
||||||
hf_model_id=args.ckpt_dir,
|
hf_model_id=args.ckpt_dir,
|
||||||
load_denoiser=False,
|
load_denoiser=False,
|
||||||
@ -103,10 +104,10 @@ def main():
|
|||||||
prompt_wav_path = args.prompt_audio if args.prompt_audio else None
|
prompt_wav_path = args.prompt_audio if args.prompt_audio else None
|
||||||
prompt_text = args.prompt_text if args.prompt_text else None
|
prompt_text = args.prompt_text if args.prompt_text else None
|
||||||
|
|
||||||
print(f"[FT Inference] Synthesizing: text='{args.text}'")
|
print(f"[FT Inference] Synthesizing: text='{args.text}'", file=sys.stderr)
|
||||||
if prompt_wav_path:
|
if prompt_wav_path:
|
||||||
print(f"[FT Inference] Using reference audio: {prompt_wav_path}")
|
print(f"[FT Inference] Using reference audio: {prompt_wav_path}", file=sys.stderr)
|
||||||
print(f"[FT Inference] Reference text: {prompt_text}")
|
print(f"[FT Inference] Reference text: {prompt_text}", file=sys.stderr)
|
||||||
|
|
||||||
audio_np = model.generate(
|
audio_np = model.generate(
|
||||||
text=args.text,
|
text=args.text,
|
||||||
@ -124,7 +125,7 @@ def main():
|
|||||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
sf.write(str(out_path), audio_np, model.tts_model.sample_rate)
|
sf.write(str(out_path), audio_np, model.tts_model.sample_rate)
|
||||||
|
|
||||||
print(f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
|
print(f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@ -24,6 +24,7 @@ Note: The script reads base_model path and lora_config from lora_config.json
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
@ -124,13 +125,13 @@ def main():
|
|||||||
lora_cfg_dict = lora_info.get("lora_config", {})
|
lora_cfg_dict = lora_info.get("lora_config", {})
|
||||||
lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None
|
lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None
|
||||||
|
|
||||||
print(f"Loaded config from: {lora_config_path}")
|
print(f"Loaded config from: {lora_config_path}", file=sys.stderr)
|
||||||
print(f" Base model: {pretrained_path}")
|
print(f" Base model: {pretrained_path}", file=sys.stderr)
|
||||||
print(f" LoRA config: r={lora_cfg.r}, alpha={lora_cfg.alpha}" if lora_cfg else " LoRA config: None")
|
print(f" LoRA config: r={lora_cfg.r}, alpha={lora_cfg.alpha}" if lora_cfg else " LoRA config: None", file=sys.stderr)
|
||||||
|
|
||||||
# 3. Load model with LoRA (no denoiser)
|
# 3. Load model with LoRA (no denoiser)
|
||||||
print(f"\n[1/2] Loading model with LoRA: {pretrained_path}")
|
print(f"\n[1/2] Loading model with LoRA: {pretrained_path}", file=sys.stderr)
|
||||||
print(f" LoRA weights: {ckpt_dir}")
|
print(f" LoRA weights: {ckpt_dir}", file=sys.stderr)
|
||||||
model = VoxCPM.from_pretrained(
|
model = VoxCPM.from_pretrained(
|
||||||
hf_model_id=pretrained_path,
|
hf_model_id=pretrained_path,
|
||||||
load_denoiser=False,
|
load_denoiser=False,
|
||||||
@ -145,10 +146,10 @@ def main():
|
|||||||
out_path = Path(args.output)
|
out_path = Path(args.output)
|
||||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
print(f"\n[2/2] Starting synthesis tests...")
|
print(f"\n[2/2] Starting synthesis tests...", file=sys.stderr)
|
||||||
|
|
||||||
# === Test 1: With LoRA ===
|
# === Test 1: With LoRA ===
|
||||||
print(f"\n [Test 1] Synthesize with LoRA...")
|
print(f"\n [Test 1] Synthesize with LoRA...", file=sys.stderr)
|
||||||
audio_np = model.generate(
|
audio_np = model.generate(
|
||||||
text=args.text,
|
text=args.text,
|
||||||
prompt_wav_path=prompt_wav_path,
|
prompt_wav_path=prompt_wav_path,
|
||||||
@ -161,10 +162,10 @@ def main():
|
|||||||
)
|
)
|
||||||
lora_output = out_path.with_stem(out_path.stem + "_with_lora")
|
lora_output = out_path.with_stem(out_path.stem + "_with_lora")
|
||||||
sf.write(str(lora_output), audio_np, model.tts_model.sample_rate)
|
sf.write(str(lora_output), audio_np, model.tts_model.sample_rate)
|
||||||
print(f" Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
|
print(f" Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
||||||
|
|
||||||
# === Test 2: Disable LoRA (via set_lora_enabled) ===
|
# === Test 2: Disable LoRA (via set_lora_enabled) ===
|
||||||
print(f"\n [Test 2] Disable LoRA (set_lora_enabled=False)...")
|
print(f"\n [Test 2] Disable LoRA (set_lora_enabled=False)...", file=sys.stderr)
|
||||||
model.set_lora_enabled(False)
|
model.set_lora_enabled(False)
|
||||||
audio_np = model.generate(
|
audio_np = model.generate(
|
||||||
text=args.text,
|
text=args.text,
|
||||||
@ -178,10 +179,10 @@ def main():
|
|||||||
)
|
)
|
||||||
disabled_output = out_path.with_stem(out_path.stem + "_lora_disabled")
|
disabled_output = out_path.with_stem(out_path.stem + "_lora_disabled")
|
||||||
sf.write(str(disabled_output), audio_np, model.tts_model.sample_rate)
|
sf.write(str(disabled_output), audio_np, model.tts_model.sample_rate)
|
||||||
print(f" Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
|
print(f" Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
||||||
|
|
||||||
# === Test 3: Re-enable LoRA ===
|
# === Test 3: Re-enable LoRA ===
|
||||||
print(f"\n [Test 3] Re-enable LoRA (set_lora_enabled=True)...")
|
print(f"\n [Test 3] Re-enable LoRA (set_lora_enabled=True)...", file=sys.stderr)
|
||||||
model.set_lora_enabled(True)
|
model.set_lora_enabled(True)
|
||||||
audio_np = model.generate(
|
audio_np = model.generate(
|
||||||
text=args.text,
|
text=args.text,
|
||||||
@ -195,10 +196,10 @@ def main():
|
|||||||
)
|
)
|
||||||
reenabled_output = out_path.with_stem(out_path.stem + "_lora_reenabled")
|
reenabled_output = out_path.with_stem(out_path.stem + "_lora_reenabled")
|
||||||
sf.write(str(reenabled_output), audio_np, model.tts_model.sample_rate)
|
sf.write(str(reenabled_output), audio_np, model.tts_model.sample_rate)
|
||||||
print(f" Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
|
print(f" Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
||||||
|
|
||||||
# === Test 4: Unload LoRA (reset_lora_weights) ===
|
# === Test 4: Unload LoRA (reset_lora_weights) ===
|
||||||
print(f"\n [Test 4] Unload LoRA (unload_lora)...")
|
print(f"\n [Test 4] Unload LoRA (unload_lora)...", file=sys.stderr)
|
||||||
model.unload_lora()
|
model.unload_lora()
|
||||||
audio_np = model.generate(
|
audio_np = model.generate(
|
||||||
text=args.text,
|
text=args.text,
|
||||||
@ -212,12 +213,12 @@ def main():
|
|||||||
)
|
)
|
||||||
reset_output = out_path.with_stem(out_path.stem + "_lora_reset")
|
reset_output = out_path.with_stem(out_path.stem + "_lora_reset")
|
||||||
sf.write(str(reset_output), audio_np, model.tts_model.sample_rate)
|
sf.write(str(reset_output), audio_np, model.tts_model.sample_rate)
|
||||||
print(f" Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
|
print(f" Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
||||||
|
|
||||||
# === Test 5: Hot-reload LoRA (load_lora) ===
|
# === Test 5: Hot-reload LoRA (load_lora) ===
|
||||||
print(f"\n [Test 5] Hot-reload LoRA (load_lora)...")
|
print(f"\n [Test 5] Hot-reload LoRA (load_lora)...", file=sys.stderr)
|
||||||
loaded, skipped = model.load_lora(ckpt_dir)
|
loaded, skipped = model.load_lora(ckpt_dir)
|
||||||
print(f" Reloaded {len(loaded)} parameters")
|
print(f" Reloaded {len(loaded)} parameters", file=sys.stderr)
|
||||||
audio_np = model.generate(
|
audio_np = model.generate(
|
||||||
text=args.text,
|
text=args.text,
|
||||||
prompt_wav_path=prompt_wav_path,
|
prompt_wav_path=prompt_wav_path,
|
||||||
@ -230,14 +231,14 @@ def main():
|
|||||||
)
|
)
|
||||||
reload_output = out_path.with_stem(out_path.stem + "_lora_reloaded")
|
reload_output = out_path.with_stem(out_path.stem + "_lora_reloaded")
|
||||||
sf.write(str(reload_output), audio_np, model.tts_model.sample_rate)
|
sf.write(str(reload_output), audio_np, model.tts_model.sample_rate)
|
||||||
print(f" Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
|
print(f" Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
||||||
|
|
||||||
print(f"\n[Done] All tests completed!")
|
print(f"\n[Done] All tests completed!", file=sys.stderr)
|
||||||
print(f" - with_lora: {lora_output}")
|
print(f" - with_lora: {lora_output}", file=sys.stderr)
|
||||||
print(f" - lora_disabled: {disabled_output}")
|
print(f" - lora_disabled: {disabled_output}", file=sys.stderr)
|
||||||
print(f" - lora_reenabled: {reenabled_output}")
|
print(f" - lora_reenabled: {reenabled_output}", file=sys.stderr)
|
||||||
print(f" - lora_reset: {reset_output}")
|
print(f" - lora_reset: {reset_output}", file=sys.stderr)
|
||||||
print(f" - lora_reloaded: {reload_output}")
|
print(f" - lora_reloaded: {reload_output}", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@ -24,7 +24,7 @@ try:
|
|||||||
SAFETENSORS_AVAILABLE = True
|
SAFETENSORS_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
SAFETENSORS_AVAILABLE = False
|
SAFETENSORS_AVAILABLE = False
|
||||||
print("Warning: safetensors not available, will use pytorch format")
|
print("Warning: safetensors not available, will use pytorch format", file=sys.stderr)
|
||||||
|
|
||||||
from voxcpm.model import VoxCPMModel
|
from voxcpm.model import VoxCPMModel
|
||||||
from voxcpm.model.voxcpm import LoRAConfig
|
from voxcpm.model.voxcpm import LoRAConfig
|
||||||
@ -170,7 +170,7 @@ def train(
|
|||||||
# Only print param info on rank 0 to avoid cluttered output
|
# Only print param info on rank 0 to avoid cluttered output
|
||||||
if accelerator.rank == 0:
|
if accelerator.rank == 0:
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
print(name, param.requires_grad)
|
print(name, param.requires_grad, file=sys.stderr)
|
||||||
|
|
||||||
optimizer = AdamW(
|
optimizer = AdamW(
|
||||||
(p for p in model.parameters() if p.requires_grad),
|
(p for p in model.parameters() if p.requires_grad),
|
||||||
@ -210,12 +210,12 @@ def train(
|
|||||||
cur_step = int(_resume.get("step", start_step))
|
cur_step = int(_resume.get("step", start_step))
|
||||||
except Exception:
|
except Exception:
|
||||||
cur_step = start_step
|
cur_step = start_step
|
||||||
print(f"Signal {signum} received. Saving checkpoint at step {cur_step} ...")
|
print(f"Signal {signum} received. Saving checkpoint at step {cur_step} ...", file=sys.stderr)
|
||||||
try:
|
try:
|
||||||
save_checkpoint(_model, _optim, _sched, _save_dir, cur_step, _pretrained, _hf_id, _dist)
|
save_checkpoint(_model, _optim, _sched, _save_dir, cur_step, _pretrained, _hf_id, _dist)
|
||||||
print("Checkpoint saved. Exiting.")
|
print("Checkpoint saved. Exiting.", file=sys.stderr)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error saving checkpoint on signal: {e}")
|
print(f"Error saving checkpoint on signal: {e}", file=sys.stderr)
|
||||||
os._exit(0)
|
os._exit(0)
|
||||||
|
|
||||||
signal.signal(signal.SIGTERM, _signal_handler)
|
signal.signal(signal.SIGTERM, _signal_handler)
|
||||||
@ -553,7 +553,7 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
|
|||||||
|
|
||||||
# Load only lora weights
|
# Load only lora weights
|
||||||
unwrapped.load_state_dict(state_dict, strict=False)
|
unwrapped.load_state_dict(state_dict, strict=False)
|
||||||
print(f"Loaded LoRA weights from {lora_weights_path}")
|
print(f"Loaded LoRA weights from {lora_weights_path}", file=sys.stderr)
|
||||||
else:
|
else:
|
||||||
# Full finetune: load model.safetensors or pytorch_model.bin
|
# Full finetune: load model.safetensors or pytorch_model.bin
|
||||||
model_path = latest_folder / "model.safetensors"
|
model_path = latest_folder / "model.safetensors"
|
||||||
@ -569,26 +569,26 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
|
|||||||
state_dict = ckpt.get("state_dict", ckpt)
|
state_dict = ckpt.get("state_dict", ckpt)
|
||||||
|
|
||||||
unwrapped.load_state_dict(state_dict, strict=False)
|
unwrapped.load_state_dict(state_dict, strict=False)
|
||||||
print(f"Loaded model weights from {model_path}")
|
print(f"Loaded model weights from {model_path}", file=sys.stderr)
|
||||||
|
|
||||||
# Load optimizer state
|
# Load optimizer state
|
||||||
optimizer_path = latest_folder / "optimizer.pth"
|
optimizer_path = latest_folder / "optimizer.pth"
|
||||||
if optimizer_path.exists():
|
if optimizer_path.exists():
|
||||||
optimizer.load_state_dict(torch.load(optimizer_path, map_location="cpu"))
|
optimizer.load_state_dict(torch.load(optimizer_path, map_location="cpu"))
|
||||||
print(f"Loaded optimizer state from {optimizer_path}")
|
print(f"Loaded optimizer state from {optimizer_path}", file=sys.stderr)
|
||||||
|
|
||||||
# Load scheduler state
|
# Load scheduler state
|
||||||
scheduler_path = latest_folder / "scheduler.pth"
|
scheduler_path = latest_folder / "scheduler.pth"
|
||||||
if scheduler_path.exists():
|
if scheduler_path.exists():
|
||||||
scheduler.load_state_dict(torch.load(scheduler_path, map_location="cpu"))
|
scheduler.load_state_dict(torch.load(scheduler_path, map_location="cpu"))
|
||||||
print(f"Loaded scheduler state from {scheduler_path}")
|
print(f"Loaded scheduler state from {scheduler_path}", file=sys.stderr)
|
||||||
|
|
||||||
# Try to infer step from checkpoint folders
|
# Try to infer step from checkpoint folders
|
||||||
step_folders = [d for d in save_dir.iterdir() if d.is_dir() and d.name.startswith("step_")]
|
step_folders = [d for d in save_dir.iterdir() if d.is_dir() and d.name.startswith("step_")]
|
||||||
if step_folders:
|
if step_folders:
|
||||||
steps = [int(d.name.split("_")[1]) for d in step_folders]
|
steps = [int(d.name.split("_")[1]) for d in step_folders]
|
||||||
resume_step = max(steps)
|
resume_step = max(steps)
|
||||||
print(f"Resuming from step {resume_step}")
|
print(f"Resuming from step {resume_step}", file=sys.stderr)
|
||||||
return resume_step
|
return resume_step
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
@ -670,7 +670,7 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
|
|||||||
latest_link.unlink()
|
latest_link.unlink()
|
||||||
shutil.copytree(folder, latest_link)
|
shutil.copytree(folder, latest_link)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Warning: failed to update latest checkpoint link at {latest_link}")
|
print(f"Warning: failed to update latest checkpoint link at {latest_link}", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@ -45,7 +45,7 @@ def load_model(args) -> VoxCPM:
|
|||||||
|
|
||||||
Prefer --model-path if provided; otherwise use from_pretrained (Hub).
|
Prefer --model-path if provided; otherwise use from_pretrained (Hub).
|
||||||
"""
|
"""
|
||||||
print("Loading VoxCPM model...")
|
print("Loading VoxCPM model...", file=sys.stderr)
|
||||||
|
|
||||||
# 兼容旧参数:ZIPENHANCER_MODEL_PATH 环境变量作为默认
|
# 兼容旧参数:ZIPENHANCER_MODEL_PATH 环境变量作为默认
|
||||||
zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get(
|
zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get(
|
||||||
@ -66,7 +66,7 @@ def load_model(args) -> VoxCPM:
|
|||||||
dropout=getattr(args, "lora_dropout", 0.0),
|
dropout=getattr(args, "lora_dropout", 0.0),
|
||||||
)
|
)
|
||||||
print(f"LoRA config: r={lora_config.r}, alpha={lora_config.alpha}, "
|
print(f"LoRA config: r={lora_config.r}, alpha={lora_config.alpha}, "
|
||||||
f"lm={lora_config.enable_lm}, dit={lora_config.enable_dit}, proj={lora_config.enable_proj}")
|
f"lm={lora_config.enable_lm}, dit={lora_config.enable_dit}, proj={lora_config.enable_proj}", file=sys.stderr)
|
||||||
|
|
||||||
# Load from local path if provided
|
# Load from local path if provided
|
||||||
if getattr(args, "model_path", None):
|
if getattr(args, "model_path", None):
|
||||||
@ -78,10 +78,10 @@ def load_model(args) -> VoxCPM:
|
|||||||
lora_config=lora_config,
|
lora_config=lora_config,
|
||||||
lora_weights_path=lora_weights_path,
|
lora_weights_path=lora_weights_path,
|
||||||
)
|
)
|
||||||
print("Model loaded (local).")
|
print("Model loaded (local).", file=sys.stderr)
|
||||||
return model
|
return model
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Failed to load model (local): {e}")
|
print(f"Failed to load model (local): {e}", file=sys.stderr)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
# Otherwise, try from_pretrained (Hub); exit on failure
|
# Otherwise, try from_pretrained (Hub); exit on failure
|
||||||
@ -95,10 +95,10 @@ def load_model(args) -> VoxCPM:
|
|||||||
lora_config=lora_config,
|
lora_config=lora_config,
|
||||||
lora_weights_path=lora_weights_path,
|
lora_weights_path=lora_weights_path,
|
||||||
)
|
)
|
||||||
print("Model loaded (from_pretrained).")
|
print("Model loaded (from_pretrained).", file=sys.stderr)
|
||||||
return model
|
return model
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Failed to load model (from_pretrained): {e}")
|
print(f"Failed to load model (from_pretrained): {e}", file=sys.stderr)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
@ -106,15 +106,15 @@ def cmd_clone(args):
|
|||||||
"""Voice cloning command."""
|
"""Voice cloning command."""
|
||||||
# Validate inputs
|
# Validate inputs
|
||||||
if not args.text:
|
if not args.text:
|
||||||
print("Error: Please provide text to synthesize (--text)")
|
print("Error: Please provide text to synthesize (--text)", file=sys.stderr)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
if not args.prompt_audio:
|
if not args.prompt_audio:
|
||||||
print("Error: Voice cloning requires a reference audio (--prompt-audio)")
|
print("Error: Voice cloning requires a reference audio (--prompt-audio)", file=sys.stderr)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
if not args.prompt_text:
|
if not args.prompt_text:
|
||||||
print("Error: Voice cloning requires a reference text (--prompt-text)")
|
print("Error: Voice cloning requires a reference text (--prompt-text)", file=sys.stderr)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
# Validate files
|
# Validate files
|
||||||
@ -125,9 +125,9 @@ def cmd_clone(args):
|
|||||||
model = load_model(args)
|
model = load_model(args)
|
||||||
|
|
||||||
# Generate audio
|
# Generate audio
|
||||||
print(f"Synthesizing text: {args.text}")
|
print(f"Synthesizing text: {args.text}", file=sys.stderr)
|
||||||
print(f"Reference audio: {prompt_audio_path}")
|
print(f"Reference audio: {prompt_audio_path}", file=sys.stderr)
|
||||||
print(f"Reference text: {args.prompt_text}")
|
print(f"Reference text: {args.prompt_text}", file=sys.stderr)
|
||||||
|
|
||||||
audio_array = model.generate(
|
audio_array = model.generate(
|
||||||
text=args.text,
|
text=args.text,
|
||||||
@ -141,25 +141,25 @@ def cmd_clone(args):
|
|||||||
|
|
||||||
# Save audio
|
# Save audio
|
||||||
sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
|
sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
|
||||||
print(f"Saved audio to: {output_path}")
|
print(f"Saved audio to: {output_path}", file=sys.stderr)
|
||||||
|
|
||||||
# Stats
|
# Stats
|
||||||
duration = len(audio_array) / model.tts_model.sample_rate
|
duration = len(audio_array) / model.tts_model.sample_rate
|
||||||
print(f"Duration: {duration:.2f}s")
|
print(f"Duration: {duration:.2f}s", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
def cmd_synthesize(args):
|
def cmd_synthesize(args):
|
||||||
"""Direct TTS synthesis command."""
|
"""Direct TTS synthesis command."""
|
||||||
# Validate inputs
|
# Validate inputs
|
||||||
if not args.text:
|
if not args.text:
|
||||||
print("Error: Please provide text to synthesize (--text)")
|
print("Error: Please provide text to synthesize (--text)", file=sys.stderr)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
# Validate output path
|
# Validate output path
|
||||||
output_path = validate_output_path(args.output)
|
output_path = validate_output_path(args.output)
|
||||||
# Load model
|
# Load model
|
||||||
model = load_model(args)
|
model = load_model(args)
|
||||||
# Generate audio
|
# Generate audio
|
||||||
print(f"Synthesizing text: {args.text}")
|
print(f"Synthesizing text: {args.text}", file=sys.stderr)
|
||||||
|
|
||||||
audio_array = model.generate(
|
audio_array = model.generate(
|
||||||
text=args.text,
|
text=args.text,
|
||||||
@ -173,11 +173,11 @@ def cmd_synthesize(args):
|
|||||||
|
|
||||||
# Save audio
|
# Save audio
|
||||||
sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
|
sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
|
||||||
print(f"Saved audio to: {output_path}")
|
print(f"Saved audio to: {output_path}", file=sys.stderr)
|
||||||
|
|
||||||
# Stats
|
# Stats
|
||||||
duration = len(audio_array) / model.tts_model.sample_rate
|
duration = len(audio_array) / model.tts_model.sample_rate
|
||||||
print(f"Duration: {duration:.2f}s")
|
print(f"Duration: {duration:.2f}s", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
def cmd_batch(args):
|
def cmd_batch(args):
|
||||||
@ -191,12 +191,12 @@ def cmd_batch(args):
|
|||||||
with open(input_file, 'r', encoding='utf-8') as f:
|
with open(input_file, 'r', encoding='utf-8') as f:
|
||||||
texts = [line.strip() for line in f if line.strip()]
|
texts = [line.strip() for line in f if line.strip()]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Failed to read input file: {e}")
|
print(f"Failed to read input file: {e}", file=sys.stderr)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
if not texts:
|
if not texts:
|
||||||
print("Error: Input file is empty or contains no valid lines")
|
print("Error: Input file is empty or contains no valid lines", file=sys.stderr)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
print(f"Found {len(texts)} lines to process")
|
print(f"Found {len(texts)} lines to process", file=sys.stderr)
|
||||||
|
|
||||||
model = load_model(args)
|
model = load_model(args)
|
||||||
prompt_audio_path = None
|
prompt_audio_path = None
|
||||||
@ -205,7 +205,7 @@ def cmd_batch(args):
|
|||||||
|
|
||||||
success_count = 0
|
success_count = 0
|
||||||
for i, text in enumerate(texts, 1):
|
for i, text in enumerate(texts, 1):
|
||||||
print(f"\nProcessing {i}/{len(texts)}: {text[:50]}...")
|
print(f"\nProcessing {i}/{len(texts)}: {text[:50]}...", file=sys.stderr)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
audio_array = model.generate(
|
audio_array = model.generate(
|
||||||
@ -221,14 +221,14 @@ def cmd_batch(args):
|
|||||||
sf.write(str(output_file), audio_array, model.tts_model.sample_rate)
|
sf.write(str(output_file), audio_array, model.tts_model.sample_rate)
|
||||||
|
|
||||||
duration = len(audio_array) / model.tts_model.sample_rate
|
duration = len(audio_array) / model.tts_model.sample_rate
|
||||||
print(f" Saved: {output_file} ({duration:.2f}s)")
|
print(f" Saved: {output_file} ({duration:.2f}s)", file=sys.stderr)
|
||||||
success_count += 1
|
success_count += 1
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f" Failed: {e}")
|
print(f" Failed: {e}", file=sys.stderr)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
print(f"\nBatch finished: {success_count}/{len(texts)} succeeded")
|
print(f"\nBatch finished: {success_count}/{len(texts)} succeeded", file=sys.stderr)
|
||||||
|
|
||||||
def _build_unified_parser():
|
def _build_unified_parser():
|
||||||
"""Build unified argument parser (no subcommands, route by args)."""
|
"""Build unified argument parser (no subcommands, route by args)."""
|
||||||
@ -296,14 +296,14 @@ def main():
|
|||||||
# Routing: prefer batch → single (clone/direct)
|
# Routing: prefer batch → single (clone/direct)
|
||||||
if args.input:
|
if args.input:
|
||||||
if not args.output_dir:
|
if not args.output_dir:
|
||||||
print("Error: Batch mode requires --output-dir")
|
print("Error: Batch mode requires --output-dir", file=sys.stderr)
|
||||||
parser.print_help()
|
parser.print_help()
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
return cmd_batch(args)
|
return cmd_batch(args)
|
||||||
|
|
||||||
# Single-sample mode
|
# Single-sample mode
|
||||||
if not args.text or not args.output:
|
if not args.text or not args.output:
|
||||||
print("Error: Single-sample mode requires --text and --output")
|
print("Error: Single-sample mode requires --text and --output", file=sys.stderr)
|
||||||
parser.print_help()
|
parser.print_help()
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
@ -316,7 +316,7 @@ def main():
|
|||||||
args.prompt_text = f.read()
|
args.prompt_text = f.read()
|
||||||
|
|
||||||
if not args.prompt_audio or not args.prompt_text:
|
if not args.prompt_audio or not args.prompt_text:
|
||||||
print("Error: Voice cloning requires both --prompt-audio and --prompt-text")
|
print("Error: Voice cloning requires both --prompt-audio and --prompt-text", file=sys.stderr)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
return cmd_clone(args)
|
return cmd_clone(args)
|
||||||
|
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
import re
|
import re
|
||||||
import tempfile
|
import tempfile
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -30,7 +31,7 @@ class VoxCPM:
|
|||||||
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
|
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
|
||||||
containing lora_weights.ckpt). If provided, LoRA weights will be loaded.
|
containing lora_weights.ckpt). If provided, LoRA weights will be loaded.
|
||||||
"""
|
"""
|
||||||
print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}")
|
print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}", file=sys.stderr)
|
||||||
|
|
||||||
# If lora_weights_path is provided but no lora_config, create a default one
|
# If lora_weights_path is provided but no lora_config, create a default one
|
||||||
if lora_weights_path is not None and lora_config is None:
|
if lora_weights_path is not None and lora_config is None:
|
||||||
@ -39,15 +40,15 @@ class VoxCPM:
|
|||||||
enable_dit=True,
|
enable_dit=True,
|
||||||
enable_proj=False,
|
enable_proj=False,
|
||||||
)
|
)
|
||||||
print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}")
|
print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}", file=sys.stderr)
|
||||||
|
|
||||||
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
|
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
|
||||||
|
|
||||||
# Load LoRA weights if path is provided
|
# Load LoRA weights if path is provided
|
||||||
if lora_weights_path is not None:
|
if lora_weights_path is not None:
|
||||||
print(f"Loading LoRA weights from: {lora_weights_path}")
|
print(f"Loading LoRA weights from: {lora_weights_path}", file=sys.stderr)
|
||||||
loaded_keys, skipped_keys = self.tts_model.load_lora_weights(lora_weights_path)
|
loaded_keys, skipped_keys = self.tts_model.load_lora_weights(lora_weights_path)
|
||||||
print(f"Loaded {len(loaded_keys)} LoRA parameters, skipped {len(skipped_keys)}")
|
print(f"Loaded {len(loaded_keys)} LoRA parameters, skipped {len(skipped_keys)}", file=sys.stderr)
|
||||||
|
|
||||||
self.text_normalizer = None
|
self.text_normalizer = None
|
||||||
if enable_denoiser and zipenhancer_model_path is not None:
|
if enable_denoiser and zipenhancer_model_path is not None:
|
||||||
@ -56,7 +57,7 @@ class VoxCPM:
|
|||||||
else:
|
else:
|
||||||
self.denoiser = None
|
self.denoiser = None
|
||||||
if optimize:
|
if optimize:
|
||||||
print("Warm up VoxCPMModel...")
|
print("Warm up VoxCPMModel...", file=sys.stderr)
|
||||||
self.tts_model.generate(
|
self.tts_model.generate(
|
||||||
target_text="Hello, this is the first test sentence.",
|
target_text="Hello, this is the first test sentence.",
|
||||||
max_len=10,
|
max_len=10,
|
||||||
@ -278,4 +279,4 @@ class VoxCPM:
|
|||||||
@property
|
@property
|
||||||
def lora_enabled(self) -> bool:
|
def lora_enabled(self) -> bool:
|
||||||
"""Check if LoRA is currently configured."""
|
"""Check if LoRA is currently configured."""
|
||||||
return self.tts_model.lora_config is not None
|
return self.tts_model.lora_config is not None
|
||||||
|
|||||||
@ -19,6 +19,7 @@ limitations under the License.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
from typing import Tuple, Union, Generator, List, Optional
|
from typing import Tuple, Union, Generator, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -120,7 +121,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
self.device = "mps"
|
self.device = "mps"
|
||||||
else:
|
else:
|
||||||
self.device = "cpu"
|
self.device = "cpu"
|
||||||
print(f"Running on device: {self.device}, dtype: {self.config.dtype}")
|
print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
|
||||||
|
|
||||||
# Text-Semantic LM
|
# Text-Semantic LM
|
||||||
self.base_lm = MiniCPMModel(config.lm_config)
|
self.base_lm = MiniCPMModel(config.lm_config)
|
||||||
@ -228,7 +229,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
|
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
|
||||||
self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
|
self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: torch.compile disabled - {e}")
|
print(f"Warning: torch.compile disabled - {e}", file=sys.stderr)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -459,7 +460,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
latent_pred, pred_audio_feat = next(inference_result)
|
latent_pred, pred_audio_feat = next(inference_result)
|
||||||
if retry_badcase:
|
if retry_badcase:
|
||||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||||
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
|
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...", file=sys.stderr)
|
||||||
retry_badcase_times += 1
|
retry_badcase_times += 1
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
@ -683,7 +684,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
latent_pred, pred_audio_feat = next(inference_result)
|
latent_pred, pred_audio_feat = next(inference_result)
|
||||||
if retry_badcase:
|
if retry_badcase:
|
||||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||||
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
|
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...", file=sys.stderr)
|
||||||
retry_badcase_times += 1
|
retry_badcase_times += 1
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
@ -868,10 +869,10 @@ class VoxCPMModel(nn.Module):
|
|||||||
pytorch_model_path = os.path.join(path, "pytorch_model.bin")
|
pytorch_model_path = os.path.join(path, "pytorch_model.bin")
|
||||||
|
|
||||||
if os.path.exists(safetensors_path) and SAFETENSORS_AVAILABLE:
|
if os.path.exists(safetensors_path) and SAFETENSORS_AVAILABLE:
|
||||||
print(f"Loading model from safetensors: {safetensors_path}")
|
print(f"Loading model from safetensors: {safetensors_path}", file=sys.stderr)
|
||||||
model_state_dict = load_file(safetensors_path)
|
model_state_dict = load_file(safetensors_path)
|
||||||
elif os.path.exists(pytorch_model_path):
|
elif os.path.exists(pytorch_model_path):
|
||||||
print(f"Loading model from pytorch_model.bin: {pytorch_model_path}")
|
print(f"Loading model from pytorch_model.bin: {pytorch_model_path}", file=sys.stderr)
|
||||||
checkpoint = torch.load(
|
checkpoint = torch.load(
|
||||||
pytorch_model_path,
|
pytorch_model_path,
|
||||||
map_location="cpu",
|
map_location="cpu",
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
@ -36,7 +37,7 @@ class TrainingTracker:
|
|||||||
# ------------------------------------------------------------------ #
|
# ------------------------------------------------------------------ #
|
||||||
def print(self, message: str):
|
def print(self, message: str):
|
||||||
if self.rank == 0:
|
if self.rank == 0:
|
||||||
print(message, flush=True)
|
print(message, flush=True, file=sys.stderr)
|
||||||
if self.log_file:
|
if self.log_file:
|
||||||
with self.log_file.open("a", encoding="utf-8") as f:
|
with self.log_file.open("a", encoding="utf-8") as f:
|
||||||
f.write(message + "\n")
|
f.write(message + "\n")
|
||||||
|
|||||||
365
test_concurrent.py
Normal file
365
test_concurrent.py
Normal file
@ -0,0 +1,365 @@
|
|||||||
|
import asyncio
|
||||||
|
import aiohttp
|
||||||
|
import time
|
||||||
|
import statistics
|
||||||
|
import csv
|
||||||
|
from typing import List, Dict
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
# ---------------- 基本配置 ----------------
|
||||||
|
BASE_URL = "http://localhost:8880/generate_tts"
|
||||||
|
|
||||||
|
REQUEST_TEMPLATE = {
|
||||||
|
"text": "哇,你这个新买的T-shirt好cool啊!是哪个brand的?周末我们去新开的mall里那家Starbucks喝杯coffee吧?我听说他们的new season限定款Latte很OK。"
|
||||||
|
}
|
||||||
|
|
||||||
|
# REQUEST_TEMPLATE = {
|
||||||
|
# "text": "澳门在哪里啊",
|
||||||
|
# "cfg_value": "2.0",
|
||||||
|
# "inference_timesteps": "10",
|
||||||
|
# "normalize": "true",
|
||||||
|
# "denoise": "true",
|
||||||
|
# "prompt_text": "澳门有乜嘢好食嘅?"
|
||||||
|
# }
|
||||||
|
|
||||||
|
# # 音频路径(确保文件存在)
|
||||||
|
# PROMPT_WAV_PATH = "/home/verachen/Music/voice/2food.wav"
|
||||||
|
|
||||||
|
# headers 一般不必指定 multipart,会自动设置
|
||||||
|
DEFAULT_HEADERS = {}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------- 请求逻辑 ----------------
|
||||||
|
async def tts_request(session: aiohttp.ClientSession, request_id: int) -> float:
|
||||||
|
"""执行一次请求,返回耗时(秒)"""
|
||||||
|
start = time.perf_counter()
|
||||||
|
try:
|
||||||
|
form = aiohttp.FormData()
|
||||||
|
for k, v in REQUEST_TEMPLATE.items():
|
||||||
|
form.add_field(k, v)
|
||||||
|
# form.add_field(
|
||||||
|
# "prompt_wav",
|
||||||
|
# open(PROMPT_WAV_PATH, "rb"),
|
||||||
|
# filename="2food.wav",
|
||||||
|
# content_type="audio/wav"
|
||||||
|
# )
|
||||||
|
|
||||||
|
async with session.post(BASE_URL, data=form, headers=DEFAULT_HEADERS) as resp:
|
||||||
|
await resp.read()
|
||||||
|
if resp.status != 200:
|
||||||
|
raise RuntimeError(f"HTTP {resp.status}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[请求 {request_id}] 出错: {e}")
|
||||||
|
return -1
|
||||||
|
return time.perf_counter() - start
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------- 并发测试核心 ----------------
|
||||||
|
async def benchmark(concurrency: int, total_requests: int) -> Dict:
|
||||||
|
"""在指定并发下发起多次请求,统计性能指标"""
|
||||||
|
timings = []
|
||||||
|
errors = 0
|
||||||
|
|
||||||
|
conn = aiohttp.TCPConnector(limit=0, force_close=False)
|
||||||
|
timeout = aiohttp.ClientTimeout(total=None)
|
||||||
|
async with aiohttp.ClientSession(connector=conn, timeout=timeout) as session:
|
||||||
|
sem = asyncio.Semaphore(concurrency)
|
||||||
|
|
||||||
|
async def worker(i):
|
||||||
|
nonlocal errors
|
||||||
|
async with sem:
|
||||||
|
t = await tts_request(session, i)
|
||||||
|
if t > 0:
|
||||||
|
timings.append(t)
|
||||||
|
else:
|
||||||
|
errors += 1
|
||||||
|
|
||||||
|
tasks = [asyncio.create_task(worker(i)) for i in range(total_requests)]
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
succ = len(timings)
|
||||||
|
total_time = sum(timings)
|
||||||
|
result = {
|
||||||
|
"concurrency": concurrency,
|
||||||
|
"total_requests": total_requests,
|
||||||
|
"successes": succ,
|
||||||
|
"errors": errors,
|
||||||
|
}
|
||||||
|
|
||||||
|
if succ > 0:
|
||||||
|
result.update({
|
||||||
|
"min": min(timings),
|
||||||
|
"max": max(timings),
|
||||||
|
"avg": statistics.mean(timings),
|
||||||
|
"median": statistics.median(timings),
|
||||||
|
"qps": succ / total_time if total_time > 0 else 0
|
||||||
|
})
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------- CSV 写入 ----------------
|
||||||
|
def write_to_csv(filename: str, data: List[Dict]):
|
||||||
|
"""把测试结果写入 CSV 文件"""
|
||||||
|
fieldnames = ["concurrency", "total_requests", "successes", "errors",
|
||||||
|
"min", "max", "avg", "median", "qps"]
|
||||||
|
with open(filename, "w", newline="", encoding="utf-8") as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
||||||
|
writer.writeheader()
|
||||||
|
for row in data:
|
||||||
|
writer.writerow(row)
|
||||||
|
print(f"\n✅ 测试结果已保存到: {filename}")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------- 主流程 ----------------
|
||||||
|
async def run_tests(concurrency_list: List[int], total_requests: dict):
|
||||||
|
print("开始压测接口:", BASE_URL)
|
||||||
|
results = []
|
||||||
|
for c in concurrency_list:
|
||||||
|
print(f"\n=== 并发数: {c} ===")
|
||||||
|
res = await benchmark(c, total_requests[c])
|
||||||
|
results.append(res)
|
||||||
|
|
||||||
|
print(f"请求总数: {res['total_requests']} | 成功: {res['successes']} | 失败: {res['errors']}")
|
||||||
|
if "avg" in res:
|
||||||
|
print(f"耗时 (s) → min={res['min']:.3f}, avg={res['avg']:.3f}, median={res['median']:.3f}, max={res['max']:.3f}")
|
||||||
|
print(f"近似 QPS: {res['qps']:.2f}")
|
||||||
|
else:
|
||||||
|
print("所有请求均失败")
|
||||||
|
|
||||||
|
# 保存结果
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M")
|
||||||
|
csv_filename = f"tts_benchmark_{timestamp}.csv"
|
||||||
|
write_to_csv(csv_filename, results)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 你可以根据机器能力调整这两个参数
|
||||||
|
concurrency_list = [1, 5, 10, 20, 50, 100, 150, 200] # 并发测试范围
|
||||||
|
total_requests = {
|
||||||
|
1: 50,
|
||||||
|
5: 100,
|
||||||
|
10: 200,
|
||||||
|
20: 400,
|
||||||
|
50: 1000,
|
||||||
|
100: 2000,
|
||||||
|
150: 3000,
|
||||||
|
200: 4000,
|
||||||
|
300: 6000,
|
||||||
|
500: 10000
|
||||||
|
} # 每个并发等级下的请求数
|
||||||
|
asyncio.run(run_tests(concurrency_list, total_requests))
|
||||||
|
|
||||||
|
|
||||||
|
# import asyncio
|
||||||
|
# import aiohttp
|
||||||
|
# import time
|
||||||
|
# import statistics
|
||||||
|
# from typing import List, Dict
|
||||||
|
|
||||||
|
# # 新的接口 URL
|
||||||
|
# BASE_URL = "http://127.0.0.1:8880/generate_tts"
|
||||||
|
|
||||||
|
# # # 固定参数(不包括音频)
|
||||||
|
# # FORM_PARAMS = {
|
||||||
|
# # "text": "澳门在哪里啊",
|
||||||
|
# # "cfg_value": "2.0",
|
||||||
|
# # "inference_timesteps": "10",
|
||||||
|
# # "normalize": "true",
|
||||||
|
# # "denoise": "true",
|
||||||
|
# # "prompt_text": "澳门有乜嘢好食嘅?"
|
||||||
|
# # }
|
||||||
|
# FORM_PARAMS = {
|
||||||
|
# "text": "澳门在哪里啊"
|
||||||
|
# }
|
||||||
|
|
||||||
|
|
||||||
|
# # 音频路径(确保文件存在)
|
||||||
|
# # PROMPT_WAV_PATH = "/home/verachen/Music/voice/2food.wav"
|
||||||
|
|
||||||
|
# # headers 一般不必指定 multipart,会自动设置
|
||||||
|
# DEFAULT_HEADERS = {}
|
||||||
|
|
||||||
|
|
||||||
|
# async def tts_request(session: aiohttp.ClientSession) -> float:
|
||||||
|
# """
|
||||||
|
# 发起一次 multipart/form-data 格式的 TTS 请求,返回耗时。
|
||||||
|
# """
|
||||||
|
# start = time.perf_counter()
|
||||||
|
# form = aiohttp.FormData()
|
||||||
|
# for k, v in FORM_PARAMS.items():
|
||||||
|
# form.add_field(k, v)
|
||||||
|
# # form.add_field(
|
||||||
|
# # "prompt_wav",
|
||||||
|
# # open(PROMPT_WAV_PATH, "rb"),
|
||||||
|
# # filename="2food.wav",
|
||||||
|
# # content_type="audio/wav"
|
||||||
|
# # )
|
||||||
|
|
||||||
|
# async with session.post(BASE_URL, data=form, headers=DEFAULT_HEADERS) as resp:
|
||||||
|
# data = await resp.read()
|
||||||
|
# if resp.status != 200:
|
||||||
|
# raise RuntimeError(f"HTTP {resp.status}, body: {data[:200]!r}")
|
||||||
|
|
||||||
|
# elapsed = time.perf_counter() - start
|
||||||
|
# return elapsed
|
||||||
|
|
||||||
|
|
||||||
|
# async def benchmark(concurrency: int, total_requests: int) -> Dict:
|
||||||
|
# timings: List[float] = []
|
||||||
|
# errors = 0
|
||||||
|
|
||||||
|
# conn = aiohttp.TCPConnector(limit=0)
|
||||||
|
# timeout = aiohttp.ClientTimeout(total=None)
|
||||||
|
|
||||||
|
# async with aiohttp.ClientSession(connector=conn, timeout=timeout) as session:
|
||||||
|
# sem = asyncio.Semaphore(concurrency)
|
||||||
|
|
||||||
|
# async def worker():
|
||||||
|
# nonlocal errors
|
||||||
|
# async with sem:
|
||||||
|
# try:
|
||||||
|
# t = await tts_request(session)
|
||||||
|
# timings.append(t)
|
||||||
|
# except Exception as e:
|
||||||
|
# errors += 1
|
||||||
|
# print("请求出错:", e)
|
||||||
|
|
||||||
|
# tasks = [asyncio.create_task(worker()) for _ in range(total_requests)]
|
||||||
|
# await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
# succ = len(timings)
|
||||||
|
# total_time = sum(timings) if timings else 0.0
|
||||||
|
|
||||||
|
# result = {
|
||||||
|
# "concurrency": concurrency,
|
||||||
|
# "total_requests": total_requests,
|
||||||
|
# "successes": succ,
|
||||||
|
# "errors": errors,
|
||||||
|
# }
|
||||||
|
# if succ > 0:
|
||||||
|
# result.update({
|
||||||
|
# "min": min(timings),
|
||||||
|
# "max": max(timings),
|
||||||
|
# "avg": statistics.mean(timings),
|
||||||
|
# "median": statistics.median(timings),
|
||||||
|
# "qps": succ / total_time if total_time > 0 else None
|
||||||
|
# })
|
||||||
|
# return result
|
||||||
|
|
||||||
|
|
||||||
|
# async def run_tests(concurrency_list: List[int], total_requests: int):
|
||||||
|
# print("开始并发测试,目标接口:", BASE_URL)
|
||||||
|
# for c in concurrency_list:
|
||||||
|
# print(f"\n--- 并发 = {c} ---")
|
||||||
|
# res = await benchmark(c, total_requests)
|
||||||
|
# print("总请求:", res["total_requests"])
|
||||||
|
# print("成功:", res["successes"], "失败:", res["errors"])
|
||||||
|
# if "avg" in res:
|
||||||
|
# print(f"耗时 min / avg / median / max = "
|
||||||
|
# f"{res['min']:.3f} / {res['avg']:.3f} / {res['median']:.3f} / {res['max']:.3f} 秒")
|
||||||
|
# print(f"近似 QPS = {res['qps']:.2f}")
|
||||||
|
# else:
|
||||||
|
# print("所有请求都失败了")
|
||||||
|
|
||||||
|
|
||||||
|
# if __name__ == "__main__":
|
||||||
|
# # 你可以调整这些参数
|
||||||
|
# concurrency_list = [200] # 要试的并发数列表
|
||||||
|
# total_requests = 100 # 每个并发等级下总请求数
|
||||||
|
|
||||||
|
# asyncio.run(run_tests(concurrency_list, total_requests))
|
||||||
|
|
||||||
|
|
||||||
|
# import multiprocessing as mp
|
||||||
|
# import requests
|
||||||
|
# import time
|
||||||
|
# import statistics
|
||||||
|
# from typing import Dict, List
|
||||||
|
|
||||||
|
# BASE_URL = "http://127.0.0.1:8880/generate_tts"
|
||||||
|
# FORM_PARAMS = {
|
||||||
|
# "text": "澳门在哪里啊",
|
||||||
|
# "cfg_value": "2.0",
|
||||||
|
# "inference_timesteps": "10",
|
||||||
|
# "normalize": "true",
|
||||||
|
# "denoise": "true",
|
||||||
|
# "prompt_text": "澳门有乜嘢好食嘅?"
|
||||||
|
# }
|
||||||
|
# PROMPT_WAV_PATH = "/home/verachen/Music/voice/2food.wav"
|
||||||
|
|
||||||
|
# def single_request() -> float:
|
||||||
|
# """用 requests 同步发一次 multipart/form-data 请求,返回耗时(秒)或抛异常。"""
|
||||||
|
# files = {
|
||||||
|
# "prompt_wav": ("2food.wav", open(PROMPT_WAV_PATH, "rb"), "audio/wav")
|
||||||
|
# }
|
||||||
|
# data = FORM_PARAMS.copy()
|
||||||
|
# start = time.perf_counter()
|
||||||
|
# resp = requests.post(BASE_URL, data=data, files=files, timeout=60)
|
||||||
|
# elapsed = time.perf_counter() - start
|
||||||
|
# if resp.status_code != 200:
|
||||||
|
# raise RuntimeError(f"HTTP {resp.status_code}, body: {resp.text[:200]!r}")
|
||||||
|
# return elapsed
|
||||||
|
|
||||||
|
# def worker_task(num_requests: int, return_list: mp.Manager().list, err_list: mp.Manager().list):
|
||||||
|
# """子进程做 num_requests 次请求,将各次耗时记录到 return_list(共享 list),错误次数记录到 err_list。"""
|
||||||
|
# for _ in range(num_requests):
|
||||||
|
# try:
|
||||||
|
# t = single_request()
|
||||||
|
# return_list.append(t)
|
||||||
|
# except Exception as e:
|
||||||
|
# err_list.append(str(e))
|
||||||
|
|
||||||
|
# def run_multiproc(concurrency: int, total_requests: int) -> Dict:
|
||||||
|
# """
|
||||||
|
# 用多个进程模拟并发:
|
||||||
|
# - 每个进程发 total_requests/concurrency 次请求(向下取整或略分配)
|
||||||
|
# - 或者简单地让每个进程跑 total_requests 次(更激进)
|
||||||
|
# """
|
||||||
|
# manager = mp.Manager()
|
||||||
|
# times = manager.list()
|
||||||
|
# errs = manager.list()
|
||||||
|
|
||||||
|
# procs = []
|
||||||
|
# # 任务分配:每个子进程跑一部分请求
|
||||||
|
# per = total_requests // concurrency
|
||||||
|
# if per < 1:
|
||||||
|
# per = 1
|
||||||
|
|
||||||
|
# for i in range(concurrency):
|
||||||
|
# p = mp.Process(target=worker_task, args=(per, times, errs))
|
||||||
|
# p.start()
|
||||||
|
# procs.append(p)
|
||||||
|
|
||||||
|
# for p in procs:
|
||||||
|
# p.join()
|
||||||
|
|
||||||
|
# timings = list(times)
|
||||||
|
# errors = list(errs)
|
||||||
|
# succ = len(timings)
|
||||||
|
# total_time = sum(timings) if timings else 0.0
|
||||||
|
|
||||||
|
# ret = {
|
||||||
|
# "concurrency": concurrency,
|
||||||
|
# "total_requests": total_requests,
|
||||||
|
# "successes": succ,
|
||||||
|
# "errors": len(errors),
|
||||||
|
# }
|
||||||
|
# if succ > 0:
|
||||||
|
# ret.update({
|
||||||
|
# "min": min(timings),
|
||||||
|
# "max": max(timings),
|
||||||
|
# "avg": statistics.mean(timings),
|
||||||
|
# "median": statistics.median(timings),
|
||||||
|
# "qps": succ / total_time if total_time > 0 else None
|
||||||
|
# })
|
||||||
|
# return ret
|
||||||
|
|
||||||
|
# if __name__ == "__main__":
|
||||||
|
# concurrency = 4
|
||||||
|
# total_requests = 40
|
||||||
|
# print("开始多进程并发测试")
|
||||||
|
# result = run_multiproc(concurrency, total_requests)
|
||||||
|
# print(result)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
34
test_streaming.py
Normal file
34
test_streaming.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
import asyncio
|
||||||
|
import httpx
|
||||||
|
import time
|
||||||
|
|
||||||
|
async def test_streaming():
|
||||||
|
url = "http://localhost:8880/generate_tts_streaming"
|
||||||
|
data = {
|
||||||
|
"text": "你好,这是一段流式输出的测试音频。",
|
||||||
|
"cfg_value": "2.0",
|
||||||
|
"inference_timesteps": "10",
|
||||||
|
"do_normalize": "True",
|
||||||
|
"denoise": "True"
|
||||||
|
}
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
first_byte_received = False
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
# 模拟文件上传(如果有 prompt_wav)
|
||||||
|
# files = {'prompt_wav': open('test.wav', 'rb')}
|
||||||
|
|
||||||
|
async with client.stream("POST", url, data=data) as response:
|
||||||
|
print(f"状态码: {response.status_code}")
|
||||||
|
|
||||||
|
async for chunk in response.aiter_bytes():
|
||||||
|
if not first_byte_received:
|
||||||
|
ttfb = time.time() - start_time
|
||||||
|
print(f"🚀 首包到达 (TTFB): {ttfb:.4f} 秒")
|
||||||
|
first_byte_received = True
|
||||||
|
|
||||||
|
print(f"收到数据块: {len(chunk)} 字节")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(test_streaming())
|
||||||
Reference in New Issue
Block a user