Compare commits

25 Commits

| SHA1 |
|---|
| 7aadc6c94e |
| 8f3a91cac8 |
| e72fb42c38 |
| 6dd63a534f |
| 79e75f259e |
| e8dd956fc2 |
| db75a7269b |
| f2e203d5e2 |
| 6ecc00a5d3 |
| 8cfd9d155a |
| 6499215204 |
| d57ac634f8 |
| de11c6a8cb |
| ee5f2567ac |
| b3a2d95fec |
| aabda60833 |
| a266c0a88d |
| 0779a93697 |
| a1f9d0c3b6 |
| aefba63f71 |
| 58717d7d82 |
| 1b0ff5693c |
| 762815a5b7 |
| 5b13a35ea6 |
| 3ba727a615 |
````diff
@@ -44,13 +44,13 @@ Unlike mainstream approaches that convert speech to discrete tokens, VoxCPM uses
 ### 📦 Model Versions
 See [Release Notes](docs/release_note.md) for details
 - **VoxCPM1.5** (Latest):
-    - Model Params: 750M
+    - Model Params: 800M
     - Sampling rate of AudioVAE: 44100
     - Token rate in LM Backbone: 6.25Hz (patch-size=4)
     - RTF in a single NVIDIA-RTX 4090 GPU: ~0.15
 
 - **VoxCPM-0.5B** (Original):
-    - Model Params: 600M
+    - Model Params: 640M
     - Sampling rate of AudioVAE: 16000
     - Token rate in LM Backbone: 12.5Hz (patch-size=2)
     - RTF in a single NVIDIA-RTX 4090 GPU: 0.17
@@ -210,6 +210,8 @@ We're excited to see the VoxCPM community growing! Here are some amazing project
 - **[VoxCPM-NanoVLLM](https://github.com/a710128/nanovllm-voxcpm)** NanoVLLM integration for VoxCPM for faster, high-throughput inference on GPU.
 - **[VoxCPM-ONNX](https://github.com/bluryar/VoxCPM-ONNX)** ONNX export for VoxCPM supports faster CPU inference.
 - **[VoxCPMANE](https://github.com/0seba/VoxCPMANE)** VoxCPM TTS with Apple Neural Engine backend server.
+- **[PR: LoRA finetune web UI (by Ayin1412)](https://github.com/OpenBMB/VoxCPM/pull/100)**
+- **[voxcpm_rs](https://github.com/madushan1000/voxcpm_rs)** A re-implementation of VoxCPM-0.5B in Rust.
 
 *Note: The projects are not officially maintained by OpenBMB.*
 
````
**app.py** (39 lines changed)

````diff
@@ -1,4 +1,5 @@
 import os
+import sys
 import numpy as np
 import torch
 import gradio as gr
@@ -16,7 +17,7 @@ import voxcpm
 class VoxCPMDemo:
     def __init__(self) -> None:
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        print(f"🚀 Running on device: {self.device}")
+        print(f"🚀 Running on device: {self.device}", file=sys.stderr)
 
         # ASR model for prompt text recognition
         self.asr_model_id = "iic/SenseVoiceSmall"
@@ -49,10 +50,10 @@ class VoxCPMDemo:
             try:
                 from huggingface_hub import snapshot_download  # type: ignore
                 os.makedirs(target_dir, exist_ok=True)
-                print(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...")
+                print(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...", file=sys.stderr)
                 snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False)
             except Exception as e:
-                print(f"Warning: HF download failed: {e}. Falling back to 'data'.")
+                print(f"Warning: HF download failed: {e}. Falling back to 'data'.", file=sys.stderr)
                 return "models"
             return target_dir
         return "models"
@@ -60,11 +61,11 @@ class VoxCPMDemo:
     def get_or_load_voxcpm(self) -> voxcpm.VoxCPM:
         if self.voxcpm_model is not None:
             return self.voxcpm_model
-        print("Model not loaded, initializing...")
+        print("Model not loaded, initializing...", file=sys.stderr)
         model_dir = self._resolve_model_dir()
-        print(f"Using model dir: {model_dir}")
+        print(f"Using model dir: {model_dir}", file=sys.stderr)
         self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir)
-        print("Model loaded successfully.")
+        print("Model loaded successfully.", file=sys.stderr)
         return self.voxcpm_model
 
     # ---------- Functional endpoints ----------
@@ -98,7 +99,7 @@ class VoxCPMDemo:
         prompt_wav_path = prompt_wav_path_input if prompt_wav_path_input else None
         prompt_text = prompt_text_input if prompt_text_input else None
 
-        print(f"Generating audio for text: '{text[:60]}...'")
+        print(f"Generating audio for text: '{text[:60]}...'", file=sys.stderr)
         wav = current_model.generate(
             text=text,
             prompt_text=prompt_text,
@@ -172,22 +173,22 @@ def create_demo_interface(demo: VoxCPMDemo):
         with gr.Accordion("💡 Pro Tips |使用建议", open=False, elem_id="acc_tips"):
             gr.Markdown("""
             ### Prompt Speech Enhancement|参考语音降噪
-            - **Enable** to remove background noise for a clean, studio-like voice, with an external ZipEnhancer component.
+            - **Enable** to remove background noise for a clean voice, with an external ZipEnhancer component. However, this will limit the audio sampling rate to 16kHz, restricting the cloning quality ceiling.
-                **启用**:通过 ZipEnhancer 组件消除背景噪音,获得更好的音质。
+                **启用**:通过 ZipEnhancer 组件消除背景噪音,但会将音频采样率限制在16kHz,限制克隆上限。
-            - **Disable** to preserve the original audio's background atmosphere.
+            - **Disable** to preserve the original audio's all information, including background atmosphere, and support audio cloning up to 44.1kHz sampling rate.
-                **禁用**:保留原始音频的背景环境声,如果想复刻相应声学环境。
+                **禁用**:保留原始音频的全部信息,包括背景环境声,最高支持44.1kHz的音频复刻。
 
             ### Text Normalization|文本正则化
             - **Enable** to process general text with an external WeTextProcessing component.
-                **启用**:使用 WeTextProcessing 组件,可处理常见文本。
+                **启用**:使用 WeTextProcessing 组件,可支持常见文本的正则化处理。
-            - **Disable** to use VoxCPM's native text understanding ability. For example, it supports phonemes input ({HH AH0 L OW1}), try it!
+            - **Disable** to use VoxCPM's native text understanding ability. For example, it supports phonemes input (For Chinese, phonemes are converted using pinyin, {ni3}{hao3}; For English, phonemes are converted using CMUDict, {HH AH0 L OW1}), try it!
-                **禁用**:将使用 VoxCPM 内置的文本理解能力。如,支持音素输入(如 {da4}{jia1}好)和公式符号合成,尝试一下!
+                **禁用**:将使用 VoxCPM 内置的文本理解能力。如,支持音素输入(如中文转拼音:{ni3}{hao3};英文转CMUDict:{HH AH0 L OW1})和公式符号合成,尝试一下!
 
             ### CFG Value|CFG 值
-            - **Lower CFG** if the voice prompt sounds strained or expressive.
+            - **Lower CFG** if the voice prompt sounds strained or expressive, or instability occurs with long text input.
-                **调低**:如果提示语音听起来不自然或过于夸张。
+                **调低**:如果提示语音听起来不自然或过于夸张,或者长文本输入出现稳定性问题。
-            - **Higher CFG** for better adherence to the prompt speech style or input text.
+            - **Higher CFG** for better adherence to the prompt speech style or input text, or instability occurs with too short text input.
-                **调高**:为更好地贴合提示音频的风格或输入文本。
+                **调高**:为更好地贴合提示音频的风格或输入文本, 或者极短文本输入出现稳定性问题。
 
             ### Inference Timesteps|推理时间步
             - **Lower** for faster synthesis speed.
@@ -267,7 +268,7 @@ def run_demo(server_name: str = "localhost", server_port: int = 7860, show_error
     demo = VoxCPMDemo()
     interface = create_demo_interface(demo)
     # Recommended to enable queue on Spaces for better throughput
-    interface.queue(max_size=10).launch(server_name=server_name, server_port=server_port, show_error=show_error)
+    interface.queue(max_size=10, default_concurrency_limit=1).launch(server_name=server_name, server_port=server_port, show_error=show_error)
 
 
 if __name__ == "__main__":
````
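The Pro Tips above map onto `generate()` arguments that also appear in the scripts changed later in this comparison (`cfg_value`, `inference_timesteps`, `normalize`). A minimal sketch of turning those knobs, assuming the `openbmb/VoxCPM1.5` checkpoint; the file paths, texts, and numeric values are placeholders, not recommended defaults:

```python
# Illustrative only: parameter names are taken from the generate() calls shown
# elsewhere in this comparison; the numeric values and paths are placeholders.
import soundfile as sf
from voxcpm.core import VoxCPM

model = VoxCPM.from_pretrained(hf_model_id="openbmb/VoxCPM1.5")

wav = model.generate(
    text="Hello, this is a quick smoke test.",
    prompt_wav_path="/path/to/reference.wav",    # optional prompt speech to clone
    prompt_text="Reference audio transcript",    # transcript of the prompt speech
    cfg_value=2.0,           # lower if the prompt sounds strained, raise for closer adherence
    inference_timesteps=10,  # lower for faster synthesis
    normalize=True,          # external text normalization (WeTextProcessing)
)
sf.write("demo.wav", wav, model.tts_model.sample_rate)
```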
````diff
@@ -19,6 +19,8 @@ tensorboard: /path/to/logs/finetune_lora
 lambdas:
   loss/diff: 1.0
   loss/stop: 1.0
+
+# LoRA configuration
 lora:
   enable_lm: true
   enable_dit: true
@@ -26,3 +28,9 @@ lora:
   r: 32
   alpha: 16
   dropout: 0.0
+
+# Distribution options (optional)
+# - If distribute=false (default): save pretrained_path as base_model in lora_config.json
+# - If distribute=true: save hf_model_id as base_model (hf_model_id is required)
+# hf_model_id: "openbmb/VoxCPM1.5"
+# distribute: true
````
````diff
@@ -19,6 +19,8 @@ tensorboard: /path/to/logs/finetune_lora
 lambdas:
   loss/diff: 1.0
   loss/stop: 1.0
+
+# LoRA configuration
 lora:
   enable_lm: true
   enable_dit: true
@@ -26,3 +28,9 @@ lora:
   r: 32
   alpha: 16
   dropout: 0.0
+
+# Distribution options (optional)
+# - If distribute=false (default): save pretrained_path as base_model in lora_config.json
+# - If distribute=true: save hf_model_id as base_model (hf_model_id is required)
+# hf_model_id: "openbmb/VoxCPM-0.5B"
+# distribute: true
````
**docs/finetune.md** (126 lines changed)

````diff
@@ -19,6 +19,7 @@ LoRA (Low-Rank Adaptation) is a parameter-efficient fine-tuning method that:
 
 ## Table of Contents
 
+- [Quick Start: WebUI](#quick-start-webui)
 - [Data Preparation](#data-preparation)
 - [Full Fine-tuning](#full-fine-tuning)
 - [LoRA Fine-tuning](#lora-fine-tuning)
@@ -28,6 +29,31 @@ LoRA (Low-Rank Adaptation) is a parameter-efficient fine-tuning method that:
 
 ---
 
+## Quick Start: WebUI
+
+For users who prefer a graphical interface, we provide `lora_ft_webui.py` - a comprehensive WebUI for training and inference:
+
+### Launch WebUI
+
+```bash
+python lora_ft_webui.py
+```
+
+Then open `http://localhost:7860` in your browser.
+
+### Features
+
+- **🚀 Training Tab**: Configure and start LoRA training with an intuitive interface
+  - Set training parameters (learning rate, batch size, LoRA rank, etc.)
+  - Monitor training progress in real-time
+  - Resume training from existing checkpoints
+
+- **🎵 Inference Tab**: Generate audio with trained models
+  - Automatic base model loading from LoRA checkpoint config
+  - Voice cloning with automatic ASR (reference text recognition)
+  - Hot-swap between multiple LoRA models
+  - Zero-shot TTS without reference audio
+
 ## Data Preparation
 
 Training data should be prepared as a JSONL manifest file, with one sample per line:
@@ -177,6 +203,10 @@ lora:
   # Target modules
   target_modules_lm: ["q_proj", "v_proj", "k_proj", "o_proj"]
   target_modules_dit: ["q_proj", "v_proj", "k_proj", "o_proj"]
+
+  # Distribution options (optional)
+  # hf_model_id: "openbmb/VoxCPM1.5"  # HuggingFace ID
+  # distribute: true  # If true, save hf_model_id in lora_config.json
 ```
 
 ### LoRA Parameters
@@ -189,6 +219,15 @@ lora:
 | `alpha` | Scaling factor, `scaling = alpha / r` | Usually `r/2` or `r` |
 | `target_modules_*` | Layer names to add LoRA | attention layers |
 
+### Distribution Options (Optional)
+
+| Parameter | Description | Default |
+|-----------|-------------|---------|
+| `hf_model_id` | HuggingFace model ID (e.g., `openbmb/VoxCPM1.5`) | `""` |
+| `distribute` | If `true`, save `hf_model_id` as `base_model` in checkpoint; otherwise save local `pretrained_path` | `false` |
+
+> **Note**: If `distribute: true`, `hf_model_id` is required.
+
 ### Training
 
 ```bash
@@ -202,16 +241,37 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 \
 
 ### Checkpoint Structure
 
-LoRA training saves only LoRA parameters:
+LoRA training saves LoRA parameters and configuration:
 
 ```
 checkpoints/finetune_lora/
 └── step_0002000/
     ├── lora_weights.safetensors   # Only lora_A, lora_B parameters
+    ├── lora_config.json           # LoRA config + base model path
     ├── optimizer.pth
     └── scheduler.pth
 ```
 
+The `lora_config.json` contains:
+```json
+{
+  "base_model": "/path/to/VoxCPM1.5/",
+  "lora_config": {
+    "enable_lm": true,
+    "enable_dit": true,
+    "r": 32,
+    "alpha": 16,
+    ...
+  }
+}
+```
+
+The `base_model` field contains:
+- Local path (default): when `distribute: false` or not set
+- HuggingFace ID: when `distribute: true` (e.g., `"openbmb/VoxCPM1.5"`)
+
+This allows loading LoRA checkpoints without the original training config file.
+
 ---
 
 ## Inference
@@ -240,11 +300,10 @@ python scripts/test_voxcpm_ft_infer.py \
 
 ### LoRA Inference
 
-LoRA inference requires the training config (for LoRA structure) and LoRA checkpoint:
+LoRA inference only requires the checkpoint directory (base model path and LoRA config are read from `lora_config.json`):
 
 ```bash
 python scripts/test_voxcpm_lora_infer.py \
-    --config_path conf/voxcpm_v1.5/voxcpm_finetune_lora.yaml \
     --lora_ckpt /path/to/checkpoints/finetune_lora/step_0002000 \
     --text "Hello, this is LoRA fine-tuned result." \
     --output lora_output.wav
@@ -254,7 +313,6 @@ With voice cloning:
 
 ```bash
 python scripts/test_voxcpm_lora_infer.py \
-    --config_path conf/voxcpm_v1.5/voxcpm_finetune_lora.yaml \
     --lora_ckpt /path/to/checkpoints/finetune_lora/step_0002000 \
     --text "This is voice cloning with LoRA." \
     --prompt_audio /path/to/reference.wav \
@@ -262,6 +320,16 @@ python scripts/test_voxcpm_lora_infer.py \
     --output cloned_output.wav
 ```
 
+Override base model path (optional):
+
+```bash
+python scripts/test_voxcpm_lora_infer.py \
+    --lora_ckpt /path/to/checkpoints/finetune_lora/step_0002000 \
+    --base_model /path/to/another/VoxCPM1.5 \
+    --text "Use different base model." \
+    --output output.wav
+```
+
 ---
 
 ## LoRA Hot-swapping
@@ -315,20 +383,39 @@ print(f"Loaded {len(loaded)} params, skipped {len(skipped)}")
 lora_state = model.get_lora_state_dict()
 ```
 
-### Simplified Usage (Auto LoRA Config)
+### Simplified Usage (Load from lora_config.json)
 
-If you only have LoRA weights and don't need custom config, just provide the path:
+If your checkpoint contains `lora_config.json` (saved by the training script), you can load everything automatically:
 
 ```python
+import json
 from voxcpm.core import VoxCPM
+from voxcpm.model.voxcpm import LoRAConfig
 
-# Auto-create default LoRAConfig when only lora_weights_path is provided
+# Load config from checkpoint
+lora_ckpt_dir = "/path/to/checkpoints/finetune_lora/step_0002000"
+with open(f"{lora_ckpt_dir}/lora_config.json") as f:
+    lora_info = json.load(f)
+
+base_model = lora_info["base_model"]
+lora_cfg = LoRAConfig(**lora_info["lora_config"])
+
+# Load model with LoRA
 model = VoxCPM.from_pretrained(
-    hf_model_id="openbmb/VoxCPM1.5",
-    lora_weights_path="/path/to/lora_checkpoint",  # Will auto-create LoRAConfig
+    hf_model_id=base_model,
+    lora_config=lora_cfg,
+    lora_weights_path=lora_ckpt_dir,
 )
 ```
 
+Or use the test script directly:
+
+```bash
+python scripts/test_voxcpm_lora_infer.py \
+    --lora_ckpt /path/to/checkpoints/finetune_lora/step_0002000 \
+    --text "Hello world"
+```
+
 ### Method Reference
 
 | Method | Description | torch.compile Compatible |
@@ -343,34 +430,39 @@ model = VoxCPM.from_pretrained(
 
 ## FAQ
 
-### 1. Out of Memory (OOM)
+### 1. How Much Data is Needed for LoRA Fine-tuning to Converge to a Single Voice?
+
+We have tested with 5 minutes and 10 minutes of data (all audio clips are 3-6s in length). In our experiments, both datasets converged to a single voice after 2000 training steps with default configurations. You can adjust the data amount and training configurations based on your available data and computational resources.
+
+### 2. Out of Memory (OOM)
 
 - Increase `grad_accum_steps` (gradient accumulation)
 - Decrease `batch_size`
 - Use LoRA fine-tuning instead of full fine-tuning
 - Decrease `max_batch_tokens` to filter long samples
 
-### 2. Poor LoRA Performance
+### 3. Poor LoRA Performance
 
 - Increase `r` (LoRA rank)
 - Adjust `alpha` (try `alpha = r/2` or `alpha = r`)
-- Ensure `enable_dit: true` (required for voice cloning)
 - Increase training steps
 - Add more target modules
 
-### 3. Training Not Converging
+### 4. Training Not Converging
 
 - Decrease `learning_rate`
 - Increase `warmup_steps`
 - Check data quality
 
-### 4. LoRA Not Taking Effect at Inference
+### 5. LoRA Not Taking Effect at Inference
 
-- Ensure inference config matches training config LoRA parameters
+- Check that `lora_config.json` exists in the checkpoint directory
 - Check `load_lora()` return value - `skipped_keys` should be empty
 - Verify `set_lora_enabled(True)` is called
 
-### 5. Checkpoint Loading Errors
+### 6. Checkpoint Loading Errors
 
 - Full fine-tuning: checkpoint directory should contain `model.safetensors` (or `pytorch_model.bin`), `config.json`, `audiovae.pth`
-- LoRA: checkpoint directory should contain `lora_weights.safetensors` (or `lora_weights.ckpt`)
+- LoRA: checkpoint directory should contain:
+  - `lora_weights.safetensors` (or `lora_weights.ckpt`) - LoRA weights
+  - `lora_config.json` - LoRA config and base model path
````
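The LoRA Parameters table above states `scaling = alpha / r`, and the example configs in this comparison use `r: 32`, `alpha: 16` (so scaling = 0.5). A generic sketch of how that scaling enters a LoRA-augmented linear layer; this is standard LoRA math for illustration, not the repository's actual implementation:

```python
# Generic LoRA forward-pass sketch (illustrative, not this repo's code):
# shows where `scaling = alpha / r` from the table above is applied.
import torch

def lora_linear(x, W, lora_A, lora_B, r=32, alpha=16):
    """y = x @ W^T + (alpha / r) * (x @ A^T) @ B^T"""
    scaling = alpha / r  # 16 / 32 = 0.5
    return x @ W.t() + scaling * (x @ lora_A.t()) @ lora_B.t()

d_in, d_out = 1024, 1024
x = torch.randn(2, d_in)
W = torch.randn(d_out, d_in)           # frozen base weight
lora_A = torch.randn(32, d_in) * 0.01  # low-rank factor, rank r = 32
lora_B = torch.zeros(d_out, 32)        # B starts at zero, so LoRA is a no-op initially
y = lora_linear(x, W, lora_A, lora_B)  # shape (2, d_out)
```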
````diff
@@ -32,6 +32,9 @@ We reduced the token rate in LM backbone from 12.5Hz to 6.25Hz (LocEnc&LocDiT pa
 - 📈 Provides a foundation for longer audio generation
 - 🏗️ Paves the way for training larger models in the future
 
+**Model Architecture Clarification**: The core architecture of VoxCPM1.5 remains unchanged from the technical report. The key modification is adjusting the patch size of the local modules (LocEnc & LocDiT) from 2 to 4, which reduces the LM processing rate from 12.5Hz to 6.25Hz. Since the local modules now need to handle longer contexts, we expanded their network depth, resulting in a slightly larger overall model parameter count.
+
+**Generation Speed Clarification**: Although the model parameters have increased, VoxCPM1.5 only requires 6.25 tokens to generate 1 second of audio (compared to 12.5 tokens in the previous version). While the displayed generation speed (xx it/s) may appear slower, the actual Real-Time Factor (RTF = audio duration / processing time) shows no difference or may even be faster.
 
 ## 🔧 Fine-tuning Support
 
@@ -82,7 +85,7 @@ We're continuously improving VoxCPM and working on exciting new features:
 
 ### Q: Has the stability issue been resolved?
 
-**A:** We have made stability optimizations in VoxCPM1.5, including improvements to the training data and model architecture. Based on community feedback, we collected some stability issues such as:
+**A:** We have made stability optimizations in VoxCPM1.5, including improvements to the inference code logic, training data, and model architecture. Based on community feedback, we collected some stability issues such as:
 - Increased noise and reverberation
 - Audio artifacts (e.g., howling/squealing)
 - Unstable speaking rate (speeding up)
@@ -90,7 +93,11 @@ We're continuously improving VoxCPM and working on exciting new features:
 - Noise artifacts at the beginning and end of audio
 - Synthesis issues with very short texts (e.g., "hello")
 
-While we have made improvements to these issues, they have not been completely resolved and may still occasionally occur, especially with very long or highly expressive inputs. We continue to work on further stability improvements in future versions.
+**What we've improved:**
+- By adjusting inference code logic and optimizing training data, we have largely fixed the beginning/ending artifacts.
+- By reducing the LM processing rate (12.5Hz → 6.25Hz), we have improved stability on longer speech generation cases.
+
+**What remains:** We acknowledge that long speech stability issues have not been completely resolved. Particularly for highly expressive or complex reference speech, error accumulation during autoregressive generation can still occur. We will continue to analyze and optimize this in future versions.
 
 ### Q: Does VoxCPM plan to support multilingual TTS?
 
````
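To make the Generation Speed Clarification above concrete: at 6.25 tokens per second of audio, VoxCPM1.5 needs half as many LM steps per second of output as the 12.5 Hz original, so a lower displayed it/s does not by itself mean slower end-to-end synthesis. A quick illustrative calculation using only the token rates quoted above:

```python
# Illustrative arithmetic for the clarification above.
audio_seconds = 10

tokens_v1_5 = 6.25 * audio_seconds  # 62.5 LM tokens for 10 s of audio (VoxCPM1.5)
tokens_v0_5 = 12.5 * audio_seconds  # 125  LM tokens for 10 s of audio (VoxCPM-0.5B)

print(tokens_v1_5, tokens_v0_5)  # 62.5 125.0
# Per second of generated audio the LM does half the steps, which is why the
# overall real-time factor can stay the same or improve even if it/s drops.
```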
````diff
@@ -23,8 +23,10 @@ This is the secret sauce that gives your audio its unique sound.
 
 ### 1. Cooking with a Prompt Speech (Following a Famous Recipe)
 - A prompt speech provides the desired acoustic characteristics for VoxCPM. The speaker's timbre, speaking style, and even the background sounds and ambiance will be replicated.
-- **For a Clean, Studio-Quality Voice:**
+- **For a Clean, Denoising Voice:**
-    - ✅ Enable "Prompt Speech Enhancement". This acts like a noise filter, removing background hiss and rumble to give you a pure, clean voice clone.
+    - ✅ Enable "Prompt Speech Enhancement". This acts like a noise filter, removing background hiss and rumble to give you a pure, clean voice clone. However, this will limit the audio sampling rate to 16kHz, restricting the cloning quality ceiling.
+- **For High-Quality Audio Cloning (Up to 44.1kHz):**
+    - ❌ Disable "Prompt Speech Enhancement" to preserve all original audio information, including background atmosphere, and support audio cloning up to 44.1kHz sampling rate.
 
 ### 2. Cooking au Naturel (Letting the Model Improvise)
 - If no reference is provided, VoxCPM becomes a creative chef! It will infer a fitting speaking style based on the text itself, thanks to the text-smartness of its foundation model, MiniCPM-4.
````
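In the Python API, the Enable/Disable choice above appears to correspond to the `denoise` flag passed to `generate()` in the scripts elsewhere in this comparison; that mapping is an assumption here, and the paths and texts below are placeholders:

```python
# Assumption: the "Prompt Speech Enhancement" toggle maps to the `denoise`
# argument seen in the generate() calls in this comparison.
from voxcpm.core import VoxCPM

model = VoxCPM.from_pretrained(hf_model_id="openbmb/VoxCPM1.5")

# Enhancement on: the prompt is denoised for a clean clone (16kHz ceiling on the prompt).
clean = model.generate(
    text="A clean, denoised clone of the reference voice.",
    prompt_wav_path="/path/to/noisy_reference.wav",
    prompt_text="Reference audio transcript",
    denoise=True,
)

# Enhancement off: keep the prompt as-is, including background ambience,
# supporting cloning up to the 44.1kHz sampling rate of VoxCPM1.5.
ambient = model.generate(
    text="Keeps the room tone and atmosphere of the reference.",
    prompt_wav_path="/path/to/noisy_reference.wav",
    prompt_text="Reference audio transcript",
    denoise=False,
)
```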
**lora_ft_webui.py** (new file, 1254 lines): file diff suppressed because it is too large.
````diff
@@ -27,6 +27,7 @@ requires-python = ">=3.10"
 dependencies = [
     "torch>=2.5.0",
     "torchaudio>=2.5.0",
+    "torchcodec",
     "transformers>=4.36.2",
     "einops",
     "gradio<6",
@@ -41,6 +42,8 @@ dependencies = [
     "simplejson",
     "sortedcontainers",
     "soundfile",
+    "librosa",
+    "matplotlib",
     "funasr",
     "spaces",
     "argbind",
````
````diff
@@ -23,6 +23,7 @@ With voice cloning:
 """
 
 import argparse
+import sys
 from pathlib import Path
 
 import soundfile as sf
@@ -92,7 +93,7 @@ def main():
     args = parse_args()
 
     # Load model from checkpoint directory (no denoiser)
-    print(f"[FT Inference] Loading model: {args.ckpt_dir}")
+    print(f"[FT Inference] Loading model: {args.ckpt_dir}", file=sys.stderr)
     model = VoxCPM.from_pretrained(
         hf_model_id=args.ckpt_dir,
         load_denoiser=False,
@@ -103,10 +104,10 @@ def main():
     prompt_wav_path = args.prompt_audio if args.prompt_audio else None
     prompt_text = args.prompt_text if args.prompt_text else None
 
-    print(f"[FT Inference] Synthesizing: text='{args.text}'")
+    print(f"[FT Inference] Synthesizing: text='{args.text}'", file=sys.stderr)
     if prompt_wav_path:
-        print(f"[FT Inference] Using reference audio: {prompt_wav_path}")
-        print(f"[FT Inference] Reference text: {prompt_text}")
+        print(f"[FT Inference] Using reference audio: {prompt_wav_path}", file=sys.stderr)
+        print(f"[FT Inference] Reference text: {prompt_text}", file=sys.stderr)
 
     audio_np = model.generate(
         text=args.text,
@@ -114,7 +115,7 @@ def main():
         prompt_text=prompt_text,
         cfg_value=args.cfg_value,
         inference_timesteps=args.inference_timesteps,
-        max_length=args.max_len,
+        max_len=args.max_len,
         normalize=args.normalize,
         denoise=False,
     )
@@ -124,7 +125,7 @@ def main():
     out_path.parent.mkdir(parents=True, exist_ok=True)
     sf.write(str(out_path), audio_np, model.tts_model.sample_rate)
 
-    print(f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
+    print(f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
 
 
 if __name__ == "__main__":
````
````diff
@@ -5,7 +5,6 @@ LoRA inference test script.
 Usage:
 
     python scripts/test_voxcpm_lora_infer.py \
-        --config_path conf/voxcpm/voxcpm_finetune_test.yaml \
         --lora_ckpt checkpoints/step_0002000 \
         --text "Hello, this is LoRA finetuned result." \
         --output lora_test.wav
@@ -13,37 +12,40 @@ Usage:
 With voice cloning:
 
     python scripts/test_voxcpm_lora_infer.py \
-        --config_path conf/voxcpm/voxcpm_finetune_test.yaml \
         --lora_ckpt checkpoints/step_0002000 \
         --text "This is voice cloning result." \
         --prompt_audio path/to/ref.wav \
         --prompt_text "Reference audio transcript" \
         --output lora_clone.wav
+
+Note: The script reads base_model path and lora_config from lora_config.json
+in the checkpoint directory (saved automatically during training).
 """
 
 import argparse
+import json
+import sys
 from pathlib import Path
 
 import soundfile as sf
 
 from voxcpm.core import VoxCPM
 from voxcpm.model.voxcpm import LoRAConfig
-from voxcpm.training.config import load_yaml_config
 
 
 def parse_args():
     parser = argparse.ArgumentParser("VoxCPM LoRA inference test")
-    parser.add_argument(
-        "--config_path",
-        type=str,
-        required=True,
-        help="Training YAML config path (contains pretrained_path and lora config)",
-    )
     parser.add_argument(
         "--lora_ckpt",
         type=str,
         required=True,
-        help="LoRA checkpoint directory (contains lora_weights.ckpt with lora_A/lora_B only)",
+        help="LoRA checkpoint directory (contains lora_weights.safetensors and lora_config.json)",
+    )
+    parser.add_argument(
+        "--base_model",
+        type=str,
+        default="",
+        help="Optional: override base model path (default: read from lora_config.json)",
     )
     parser.add_argument(
         "--text",
@@ -98,26 +100,44 @@ def parse_args():
 def main():
     args = parse_args()
 
-    # 1. Load YAML config
-    cfg = load_yaml_config(args.config_path)
-    pretrained_path = cfg["pretrained_path"]
-    lora_cfg_dict = cfg.get("lora", {}) or {}
-    lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None
-
-    # 2. Check LoRA checkpoint
-    ckpt_dir = args.lora_ckpt
-    if not Path(ckpt_dir).exists():
+    # 1. Check LoRA checkpoint directory
+    ckpt_dir = Path(args.lora_ckpt)
+    if not ckpt_dir.exists():
         raise FileNotFoundError(f"LoRA checkpoint not found: {ckpt_dir}")
 
+    # 2. Load lora_config.json from checkpoint
+    lora_config_path = ckpt_dir / "lora_config.json"
+    if not lora_config_path.exists():
+        raise FileNotFoundError(
+            f"lora_config.json not found in {ckpt_dir}. "
+            "Make sure the checkpoint was saved with the updated training script."
+        )
+
+    with open(lora_config_path, "r", encoding="utf-8") as f:
+        lora_info = json.load(f)
+
+    # Get base model path (command line arg overrides config)
+    pretrained_path = args.base_model if args.base_model else lora_info.get("base_model")
+    if not pretrained_path:
+        raise ValueError("base_model not found in lora_config.json and --base_model not provided")
+
+    # Get LoRA config
+    lora_cfg_dict = lora_info.get("lora_config", {})
+    lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None
+
+    print(f"Loaded config from: {lora_config_path}", file=sys.stderr)
+    print(f"  Base model: {pretrained_path}", file=sys.stderr)
+    print(f"  LoRA config: r={lora_cfg.r}, alpha={lora_cfg.alpha}" if lora_cfg else "  LoRA config: None", file=sys.stderr)
+
     # 3. Load model with LoRA (no denoiser)
-    print(f"[1/2] Loading model with LoRA: {pretrained_path}")
-    print(f"    LoRA weights: {ckpt_dir}")
+    print(f"\n[1/2] Loading model with LoRA: {pretrained_path}", file=sys.stderr)
+    print(f"    LoRA weights: {ckpt_dir}", file=sys.stderr)
     model = VoxCPM.from_pretrained(
         hf_model_id=pretrained_path,
         load_denoiser=False,
         optimize=True,
         lora_config=lora_cfg,
-        lora_weights_path=ckpt_dir,
+        lora_weights_path=str(ckpt_dir),
     )
 
     # 4. Synthesize audio
@@ -126,26 +146,26 @@ def main():
     out_path = Path(args.output)
     out_path.parent.mkdir(parents=True, exist_ok=True)
 
-    print(f"\n[2/2] Starting synthesis tests...")
+    print(f"\n[2/2] Starting synthesis tests...", file=sys.stderr)
 
     # === Test 1: With LoRA ===
-    print(f"\n  [Test 1] Synthesize with LoRA...")
+    print(f"\n  [Test 1] Synthesize with LoRA...", file=sys.stderr)
     audio_np = model.generate(
         text=args.text,
         prompt_wav_path=prompt_wav_path,
         prompt_text=prompt_text,
         cfg_value=args.cfg_value,
         inference_timesteps=args.inference_timesteps,
-        max_length=args.max_len,
+        max_len=args.max_len,
         normalize=args.normalize,
         denoise=False,
     )
     lora_output = out_path.with_stem(out_path.stem + "_with_lora")
     sf.write(str(lora_output), audio_np, model.tts_model.sample_rate)
-    print(f"    Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
+    print(f"    Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
 
     # === Test 2: Disable LoRA (via set_lora_enabled) ===
-    print(f"\n  [Test 2] Disable LoRA (set_lora_enabled=False)...")
+    print(f"\n  [Test 2] Disable LoRA (set_lora_enabled=False)...", file=sys.stderr)
     model.set_lora_enabled(False)
     audio_np = model.generate(
         text=args.text,
@@ -153,16 +173,16 @@ def main():
         prompt_text=prompt_text,
         cfg_value=args.cfg_value,
         inference_timesteps=args.inference_timesteps,
-        max_length=args.max_len,
+        max_len=args.max_len,
         normalize=args.normalize,
         denoise=False,
     )
     disabled_output = out_path.with_stem(out_path.stem + "_lora_disabled")
     sf.write(str(disabled_output), audio_np, model.tts_model.sample_rate)
-    print(f"    Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
+    print(f"    Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
 
     # === Test 3: Re-enable LoRA ===
-    print(f"\n  [Test 3] Re-enable LoRA (set_lora_enabled=True)...")
+    print(f"\n  [Test 3] Re-enable LoRA (set_lora_enabled=True)...", file=sys.stderr)
     model.set_lora_enabled(True)
     audio_np = model.generate(
         text=args.text,
@@ -170,16 +190,16 @@ def main():
         prompt_text=prompt_text,
         cfg_value=args.cfg_value,
         inference_timesteps=args.inference_timesteps,
-        max_length=args.max_len,
+        max_len=args.max_len,
         normalize=args.normalize,
         denoise=False,
     )
     reenabled_output = out_path.with_stem(out_path.stem + "_lora_reenabled")
     sf.write(str(reenabled_output), audio_np, model.tts_model.sample_rate)
-    print(f"    Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
+    print(f"    Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
 
     # === Test 4: Unload LoRA (reset_lora_weights) ===
-    print(f"\n  [Test 4] Unload LoRA (unload_lora)...")
+    print(f"\n  [Test 4] Unload LoRA (unload_lora)...", file=sys.stderr)
     model.unload_lora()
     audio_np = model.generate(
         text=args.text,
@@ -187,38 +207,38 @@ def main():
         prompt_text=prompt_text,
         cfg_value=args.cfg_value,
         inference_timesteps=args.inference_timesteps,
-        max_length=args.max_len,
+        max_len=args.max_len,
         normalize=args.normalize,
         denoise=False,
     )
     reset_output = out_path.with_stem(out_path.stem + "_lora_reset")
     sf.write(str(reset_output), audio_np, model.tts_model.sample_rate)
-    print(f"    Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
+    print(f"    Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
 
     # === Test 5: Hot-reload LoRA (load_lora) ===
-    print(f"\n  [Test 5] Hot-reload LoRA (load_lora)...")
+    print(f"\n  [Test 5] Hot-reload LoRA (load_lora)...", file=sys.stderr)
-    loaded, skipped = model.load_lora(str(ckpt_dir))
+    loaded, skipped = model.load_lora(ckpt_dir)
-    print(f"    Reloaded {len(loaded)} parameters")
+    print(f"    Reloaded {len(loaded)} parameters", file=sys.stderr)
     audio_np = model.generate(
         text=args.text,
         prompt_wav_path=prompt_wav_path,
         prompt_text=prompt_text,
         cfg_value=args.cfg_value,
         inference_timesteps=args.inference_timesteps,
-        max_length=args.max_len,
+        max_len=args.max_len,
         normalize=args.normalize,
         denoise=False,
     )
     reload_output = out_path.with_stem(out_path.stem + "_lora_reloaded")
     sf.write(str(reload_output), audio_np, model.tts_model.sample_rate)
-    print(f"    Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
+    print(f"    Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
 
-    print(f"\n[Done] All tests completed!")
-    print(f"    - with_lora: {lora_output}")
-    print(f"    - lora_disabled: {disabled_output}")
-    print(f"    - lora_reenabled: {reenabled_output}")
-    print(f"    - lora_reset: {reset_output}")
-    print(f"    - lora_reloaded: {reload_output}")
+    print(f"\n[Done] All tests completed!", file=sys.stderr)
+    print(f"    - with_lora: {lora_output}", file=sys.stderr)
+    print(f"    - lora_disabled: {disabled_output}", file=sys.stderr)
+    print(f"    - lora_reenabled: {reenabled_output}", file=sys.stderr)
+    print(f"    - lora_reset: {reset_output}", file=sys.stderr)
+    print(f"    - lora_reloaded: {reload_output}", file=sys.stderr)
 
 
 if __name__ == "__main__":
````
@ -14,13 +14,17 @@ import torch
|
|||||||
from tensorboardX import SummaryWriter
|
from tensorboardX import SummaryWriter
|
||||||
from torch.optim import AdamW
|
from torch.optim import AdamW
|
||||||
from transformers import get_cosine_schedule_with_warmup
|
from transformers import get_cosine_schedule_with_warmup
|
||||||
|
import signal
|
||||||
|
import os
|
||||||
|
|
||||||
|
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from safetensors.torch import save_file
|
from safetensors.torch import save_file
|
||||||
SAFETENSORS_AVAILABLE = True
|
SAFETENSORS_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
SAFETENSORS_AVAILABLE = False
|
SAFETENSORS_AVAILABLE = False
|
||||||
print("Warning: safetensors not available, will use pytorch format")
|
print("Warning: safetensors not available, will use pytorch format", file=sys.stderr)
|
||||||
|
|
||||||
from voxcpm.model import VoxCPMModel
|
from voxcpm.model import VoxCPMModel
|
||||||
from voxcpm.model.voxcpm import LoRAConfig
|
from voxcpm.model.voxcpm import LoRAConfig
|
||||||
@ -56,8 +60,16 @@ def train(
|
|||||||
lambdas: Dict[str, float] = {"loss/diff": 1.0, "loss/stop": 1.0},
|
lambdas: Dict[str, float] = {"loss/diff": 1.0, "loss/stop": 1.0},
|
||||||
lora: dict = None,
|
lora: dict = None,
|
||||||
config_path: str = "",
|
config_path: str = "",
|
||||||
|
# Distribution options (for LoRA checkpoints)
|
||||||
|
hf_model_id: str = "", # HuggingFace model ID (e.g., "openbmb/VoxCPM1.5")
|
||||||
|
distribute: bool = False, # If True, save hf_model_id as base_model; otherwise save pretrained_path
|
||||||
):
|
):
|
||||||
_ = config_path
|
_ = config_path
|
||||||
|
|
||||||
|
# Validate distribution options
|
||||||
|
if lora is not None and distribute and not hf_model_id:
|
||||||
|
raise ValueError("hf_model_id is required when distribute=True")
|
||||||
|
|
||||||
accelerator = Accelerator(amp=True)
|
accelerator = Accelerator(amp=True)
|
||||||
|
|
||||||
save_dir = Path(save_path)
|
save_dir = Path(save_path)
|
||||||
@ -87,7 +99,10 @@ def train(
|
|||||||
return {"text_ids": text_ids}
|
return {"text_ids": text_ids}
|
||||||
|
|
||||||
train_ds = train_ds.map(tokenize, batched=True, remove_columns=["text"])
|
train_ds = train_ds.map(tokenize, batched=True, remove_columns=["text"])
|
||||||
|
# Save original validation texts for audio generation display
|
||||||
|
val_texts = None
|
||||||
if val_ds is not None:
|
if val_ds is not None:
|
||||||
|
val_texts = list(val_ds["text"]) # Save original texts
|
||||||
val_ds = val_ds.map(tokenize, batched=True, remove_columns=["text"])
|
val_ds = val_ds.map(tokenize, batched=True, remove_columns=["text"])
|
||||||
|
|
||||||
dataset_cnt = int(max(train_ds["dataset_id"])) + 1 if "dataset_id" in train_ds.column_names else 1
|
dataset_cnt = int(max(train_ds["dataset_id"])) + 1 if "dataset_id" in train_ds.column_names else 1
|
||||||
@ -144,6 +159,8 @@ def train(
|
|||||||
dataset_cnt=dataset_cnt,
|
dataset_cnt=dataset_cnt,
|
||||||
device=accelerator.device,
|
device=accelerator.device,
|
||||||
)
|
)
|
||||||
|
# Save audio_vae for audio generation
|
||||||
|
audio_vae_for_gen = base_model.audio_vae
|
||||||
del base_model.audio_vae
|
del base_model.audio_vae
|
||||||
model = accelerator.prepare_model(base_model)
|
model = accelerator.prepare_model(base_model)
|
||||||
unwrapped_model = accelerator.unwrap(model)
|
unwrapped_model = accelerator.unwrap(model)
|
||||||
@ -153,7 +170,7 @@ def train(
|
|||||||
# Only print param info on rank 0 to avoid cluttered output
|
# Only print param info on rank 0 to avoid cluttered output
|
||||||
if accelerator.rank == 0:
|
if accelerator.rank == 0:
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
print(name, param.requires_grad)
|
print(name, param.requires_grad, file=sys.stderr)
|
||||||
|
|
||||||
optimizer = AdamW(
|
optimizer = AdamW(
|
||||||
(p for p in model.parameters() if p.requires_grad),
|
(p for p in model.parameters() if p.requires_grad),
|
||||||
@ -171,6 +188,39 @@ def train(
|
|||||||
num_training_steps=total_training_steps,
|
num_training_steps=total_training_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Try to load checkpoint and resume training
|
||||||
|
start_step = 0
|
||||||
|
if accelerator.rank == 0:
|
||||||
|
start_step = load_checkpoint(model, optimizer, scheduler, save_dir)
|
||||||
|
# Broadcast start_step to all processes
|
||||||
|
if hasattr(accelerator, 'all_reduce'):
|
||||||
|
start_step_tensor = torch.tensor(start_step, device=accelerator.device)
|
||||||
|
accelerator.all_reduce(start_step_tensor)
|
||||||
|
start_step = int(start_step_tensor.item())
|
||||||
|
|
||||||
|
if start_step > 0 and accelerator.rank == 0:
|
||||||
|
tracker.print(f"Resuming training from step {start_step}")
|
||||||
|
|
||||||
|
# Resume tracker for signal handler to read current step
|
||||||
|
resume = {"step": start_step}
|
||||||
|
|
||||||
|
# Register signal handler to save checkpoint on termination (SIGTERM/SIGINT)
|
||||||
|
def _signal_handler(signum, frame, _model=model, _optim=optimizer, _sched=scheduler, _save_dir=save_dir, _pretrained=pretrained_path, _hf_id=hf_model_id, _dist=distribute, _resume=resume):
|
||||||
|
try:
|
||||||
|
cur_step = int(_resume.get("step", start_step))
|
||||||
|
except Exception:
|
||||||
|
cur_step = start_step
|
||||||
|
print(f"Signal {signum} received. Saving checkpoint at step {cur_step} ...", file=sys.stderr)
|
||||||
|
try:
|
||||||
|
save_checkpoint(_model, _optim, _sched, _save_dir, cur_step, _pretrained, _hf_id, _dist)
|
||||||
|
print("Checkpoint saved. Exiting.", file=sys.stderr)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error saving checkpoint on signal: {e}", file=sys.stderr)
|
||||||
|
os._exit(0)
|
||||||
|
|
||||||
|
signal.signal(signal.SIGTERM, _signal_handler)
|
||||||
|
signal.signal(signal.SIGINT, _signal_handler)
|
||||||
|
|
||||||
# Manual epoch management instead of itertools.cycle to support DistributedSampler.set_epoch()
|
# Manual epoch management instead of itertools.cycle to support DistributedSampler.set_epoch()
|
||||||
grad_accum_steps = max(int(grad_accum_steps), 1)
|
grad_accum_steps = max(int(grad_accum_steps), 1)
|
||||||
data_epoch = 0
|
data_epoch = 0
|
||||||
@ -191,7 +241,9 @@ def train(
|
|||||||
return next(train_iter)
|
return next(train_iter)
|
||||||
|
|
||||||
with tracker.live():
|
with tracker.live():
|
||||||
for step in range(num_iters):
|
for step in range(start_step, num_iters):
|
||||||
|
# update resume step so signal handler can save current progress
|
||||||
|
resume["step"] = step
|
||||||
tracker.step = step
|
tracker.step = step
|
||||||
optimizer.zero_grad(set_to_none=True)
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
@ -242,7 +294,7 @@ def train(
|
|||||||
accelerator.update()
|
accelerator.update()
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
|
|
||||||
if step % log_interval == 0:
|
if step % log_interval == 0 or step == num_iters - 1:
|
||||||
loss_values = {k: v.item() if isinstance(v, torch.Tensor) else float(v) for k, v in loss_dict.items()}
|
loss_values = {k: v.item() if isinstance(v, torch.Tensor) else float(v) for k, v in loss_dict.items()}
|
||||||
loss_values["lr"] = float(optimizer.param_groups[0]["lr"])
|
loss_values["lr"] = float(optimizer.param_groups[0]["lr"])
|
||||||
# Approximate epoch: seen samples / total samples (considering grad_accum and batch_size)
|
# Approximate epoch: seen samples / total samples (considering grad_accum and batch_size)
|
||||||
@ -251,21 +303,31 @@ def train(
|
|||||||
loss_values["grad_norm"] = float(grad_norm)
|
loss_values["grad_norm"] = float(grad_norm)
|
||||||
tracker.log_metrics(loss_values, split="train")
|
tracker.log_metrics(loss_values, split="train")
|
||||||
|
|
||||||
if val_loader is not None and step % valid_interval == 0 and step != 0:
|
if val_loader is not None and (step % valid_interval == 0 or step == num_iters - 1):
|
||||||
validate(model, val_loader, batch_processor, accelerator, tracker, lambdas)
|
validate(model, val_loader, batch_processor, accelerator, tracker, lambdas,
|
||||||
|
writer=writer, step=step, val_ds=val_ds, audio_vae=audio_vae_for_gen,
|
||||||
|
sample_rate=sample_rate, val_texts=val_texts, tokenizer=tokenizer,
|
||||||
|
valid_interval=valid_interval)
|
||||||
|
|
||||||
if step % save_interval == 0 and accelerator.rank == 0:
|
if (step % save_interval == 0 or step == num_iters - 1) and accelerator.rank == 0:
|
||||||
save_checkpoint(model, optimizer, scheduler, save_dir, step, pretrained_path)
|
save_checkpoint(model, optimizer, scheduler, save_dir, step, pretrained_path, hf_model_id, distribute)
|
||||||
|
|
||||||
if accelerator.rank == 0:
|
if accelerator.rank == 0:
|
||||||
save_checkpoint(model, optimizer, scheduler, save_dir, num_iters, pretrained_path)
|
save_checkpoint(model, optimizer, scheduler, save_dir, num_iters, pretrained_path, hf_model_id, distribute)
|
||||||
if writer:
|
if writer:
|
||||||
writer.close()
|
writer.close()
|
||||||
|
|
||||||
|
|
||||||
def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas):
|
def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas,
|
||||||
|
writer=None, step=0, val_ds=None, audio_vae=None, sample_rate=22050,
|
||||||
|
val_texts=None, tokenizer=None, valid_interval=1000):
|
||||||
|
"""Validate and generate sample audio"""
|
||||||
|
import numpy as np
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
losses = []
|
total_losses = []
|
||||||
|
sub_losses = defaultdict(list) # Track individual sub-losses
|
||||||
num_batches = 0
|
num_batches = 0
|
||||||
max_val_batches = 10
|
max_val_batches = 10
|
||||||
|
|
||||||
@ -289,19 +351,250 @@ def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas):
|
|||||||
total = 0.0
|
total = 0.0
|
||||||
for key, value in outputs.items():
|
for key, value in outputs.items():
|
||||||
if key.startswith("loss/"):
|
if key.startswith("loss/"):
|
||||||
total += lambdas.get(key, 1.0) * value
|
weighted_loss = lambdas.get(key, 1.0) * value
|
||||||
losses.append(total.detach())
|
total += weighted_loss
|
||||||
|
sub_losses[key].append(value.detach())
|
||||||
|
total_losses.append(total.detach())
|
||||||
num_batches += 1
|
num_batches += 1
|
||||||
|
|
||||||
if losses:
|
if total_losses:
|
||||||
mean_loss = torch.stack(losses).mean()
|
# Compute mean total loss
|
||||||
# All-reduce validation loss across processes for global average
|
mean_total_loss = torch.stack(total_losses).mean()
|
||||||
accelerator.all_reduce(mean_loss)
|
accelerator.all_reduce(mean_total_loss)
|
||||||
tracker.log_metrics({"loss": mean_loss.item()}, split="val")
|
|
||||||
|
# Compute mean of each sub-loss
|
||||||
|
val_metrics = {"loss/total": mean_total_loss.item()}
|
||||||
|
for key, values in sub_losses.items():
|
||||||
|
mean_sub_loss = torch.stack(values).mean()
|
||||||
|
accelerator.all_reduce(mean_sub_loss)
|
||||||
|
val_metrics[key] = mean_sub_loss.item()
|
||||||
|
|
||||||
|
tracker.log_metrics(val_metrics, split="val")
|
||||||
|
|
||||||
|
# Generate sample audio for TensorBoard display
|
||||||
|
if writer is not None and val_ds is not None and audio_vae is not None and accelerator.rank == 0:
|
||||||
|
try:
|
||||||
|
generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, sample_rate,
|
||||||
|
val_texts=val_texts, tokenizer=tokenizer, valid_interval=valid_interval,
|
||||||
|
tracker=tracker)
|
||||||
|
except Exception as e:
|
||||||
|
tracker.print(f"[Warning] Failed to generate sample audio: {e}")
|
||||||
|
import traceback
|
||||||
|
import io
|
||||||
|
buf = io.StringIO()
|
||||||
|
traceback.print_exc(file=buf)
|
||||||
|
tracker.print(buf.getvalue())
|
||||||
|
else:
|
||||||
|
# Log why audio generation was skipped
|
||||||
|
missing = []
|
||||||
|
if writer is None:
|
||||||
|
missing.append("writer")
|
||||||
|
if val_ds is None:
|
||||||
|
missing.append("val_ds")
|
||||||
|
if audio_vae is None:
|
||||||
|
missing.append("audio_vae")
|
||||||
|
if missing and accelerator.rank == 0:
|
||||||
|
tracker.print(f"[Warning] Skip audio generation: missing {', '.join(missing)}")
|
||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
|
|
||||||
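The validation change above now tracks every `loss/*` term separately in addition to the weighted total. As a minimal sketch of that weighting (the loss keys and lambda values here are hypothetical, not taken from the actual training config):

```python
import torch
from collections import defaultdict

# Hypothetical per-term weights; the real values come from the training config's lambdas.
lambdas = {"loss/diff": 1.0, "loss/stop": 0.5}

# One batch worth of model outputs (illustrative tensors only).
outputs = {"loss/diff": torch.tensor(0.82), "loss/stop": torch.tensor(0.10)}

sub_losses = defaultdict(list)
total = 0.0
for key, value in outputs.items():
    if key.startswith("loss/"):
        weighted = lambdas.get(key, 1.0) * value   # default weight 1.0, as in validate()
        total += weighted
        sub_losses[key].append(value.detach())     # raw (unweighted) sub-loss kept for its own mean

# total == 1.0 * 0.82 + 0.5 * 0.10 == 0.87
print(float(total), {k: float(torch.stack(v).mean()) for k, v in sub_losses.items()})
```

In the real loop each per-key mean is additionally all-reduced across ranks before being logged under the `val` split.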
def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pretrained_path: str = None):
|
def compute_mel_spectrogram(audio_np, sample_rate, n_mels=128):
|
||||||
|
"""Compute Mel Spectrogram (dB) using librosa"""
|
||||||
|
import numpy as np
|
||||||
|
import librosa
|
||||||
|
audio_np = audio_np.flatten().astype(np.float32)
|
||||||
|
mel = librosa.feature.melspectrogram(y=audio_np, sr=sample_rate, n_mels=n_mels, fmax=sample_rate // 2)
|
||||||
|
return librosa.power_to_db(mel, ref=np.max)
|
||||||
|
|
||||||
|
|
||||||
|
def create_mel_figure(gen_audio_np, gen_mel, sample_rate, step=None, ref_audio_np=None, ref_mel=None):
|
||||||
|
"""
|
||||||
|
Create mel spectrogram figure: show comparison if reference audio exists, otherwise show generated only
|
||||||
|
"""
|
||||||
|
import matplotlib
|
||||||
|
matplotlib.use('Agg')
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import librosa.display
|
||||||
|
|
||||||
|
fmax = sample_rate // 2
|
||||||
|
step_str = f" @ Step {step}" if step is not None else ""
|
||||||
|
|
||||||
|
if ref_audio_np is not None and ref_mel is not None:
|
||||||
|
# Comparison mode: reference vs generated
|
||||||
|
fig, (ax_ref, ax_gen) = plt.subplots(2, 1, figsize=(12, 8))
|
||||||
|
|
||||||
|
img_ref = librosa.display.specshow(ref_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax_ref)
|
||||||
|
ax_ref.set_title(f'Reference (GT) - {len(ref_audio_np)/sample_rate:.2f}s{step_str}', fontsize=10, fontweight='bold', color='#28A745')
|
||||||
|
plt.colorbar(img_ref, ax=ax_ref, format='%+2.0f dB', pad=0.02)
|
||||||
|
|
||||||
|
img_gen = librosa.display.specshow(gen_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax_gen)
|
||||||
|
ax_gen.set_title(f'Generated - {len(gen_audio_np)/sample_rate:.2f}s', fontsize=10, fontweight='bold', color='#DC3545')
|
||||||
|
plt.colorbar(img_gen, ax=ax_gen, format='%+2.0f dB', pad=0.02)
|
||||||
|
else:
|
||||||
|
# Single figure mode: show generated only
|
||||||
|
fig, ax = plt.subplots(figsize=(12, 4))
|
||||||
|
img = librosa.display.specshow(gen_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax)
|
||||||
|
ax.set_title(f'Generated - {len(gen_audio_np)/sample_rate:.2f}s{step_str}', fontsize=11, fontweight='bold')
|
||||||
|
plt.colorbar(img, ax=ax, format='%+2.0f dB', pad=0.02)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_audio(audio_np):
|
||||||
|
"""Normalize audio to [-0.9, 0.9]"""
|
||||||
|
import numpy as np
|
||||||
|
max_val = np.abs(audio_np).max()
|
||||||
|
return audio_np / max_val * 0.9 if max_val > 0 else audio_np
|
||||||
|
|
||||||
|
|
||||||
|
def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, sample_rate=22050,
|
||||||
|
val_texts=None, tokenizer=None, pretrained_path=None, valid_interval=1000,
|
||||||
|
tracker=None):
|
||||||
|
"""Select 2 fixed validation samples, generate audio and log to TensorBoard"""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
log = tracker.print if tracker else print
|
||||||
|
num_samples = min(2, len(val_ds))
|
||||||
|
log(f"[Audio] Starting audio generation for {num_samples} samples at step {step}")
|
||||||
|
|
||||||
|
unwrapped_model = accelerator.unwrap(model)
|
||||||
|
|
||||||
|
for i in range(num_samples):
|
||||||
|
sample = val_ds[i]
|
||||||
|
text = val_texts[i] if val_texts and i < len(val_texts) else "Hello, this is a test."
|
||||||
|
|
||||||
|
# Load reference audio
|
||||||
|
ref_audio_np = None
|
||||||
|
try:
|
||||||
|
if "audio" in sample and isinstance(sample["audio"], dict) and "array" in sample["audio"]:
|
||||||
|
ref_audio_np = np.array(sample["audio"]["array"], dtype=np.float32)
|
||||||
|
ref_sr = sample["audio"].get("sampling_rate", sample_rate)
|
||||||
|
if ref_sr != sample_rate:
|
||||||
|
import torchaudio.functional as F
|
||||||
|
ref_audio_np = F.resample(torch.from_numpy(ref_audio_np).unsqueeze(0), ref_sr, sample_rate).squeeze(0).numpy()
|
||||||
|
log(f"[Audio] Loaded reference audio for sample {i}: duration={len(ref_audio_np)/sample_rate:.2f}s")
|
||||||
|
except Exception as e:
|
||||||
|
log(f"[Warning] Failed to load reference audio: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Inference setup
|
||||||
|
unwrapped_model.eval()
|
||||||
|
unwrapped_model.to(torch.bfloat16)
|
||||||
|
unwrapped_model.audio_vae = audio_vae.to(torch.float32)
|
||||||
|
|
||||||
|
log(f"[Audio] Generating sample {i} with text: '{text[:50]}...'")
|
||||||
|
with torch.no_grad():
|
||||||
|
generated = unwrapped_model.generate(target_text=text, inference_timesteps=10, cfg_value=2.0)
|
||||||
|
|
||||||
|
# Restore training setup
|
||||||
|
unwrapped_model.to(torch.float32)
|
||||||
|
unwrapped_model.audio_vae = None
|
||||||
|
|
||||||
|
if generated is None or len(generated) == 0:
|
||||||
|
log(f"[Warning] Generated audio is empty for sample {i}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Process generated audio
|
||||||
|
gen_audio_np = generated.cpu().float().numpy().flatten() if isinstance(generated, torch.Tensor) else np.array(generated, dtype=np.float32).flatten()
|
||||||
|
gen_audio_np = normalize_audio(gen_audio_np)
|
||||||
|
|
||||||
|
tag = f"val_sample_{i}"
|
||||||
|
writer.add_audio(f"{tag}/generated_audio", gen_audio_np, global_step=step, sample_rate=sample_rate)
|
||||||
|
log(f"[Audio] Generated audio for sample {i}: duration={len(gen_audio_np)/sample_rate:.2f}s")
|
||||||
|
|
||||||
|
# Log reference audio
|
||||||
|
if ref_audio_np is not None:
|
||||||
|
writer.add_audio(f"{tag}/reference_audio", normalize_audio(ref_audio_np), global_step=step, sample_rate=sample_rate)
|
||||||
|
|
||||||
|
# Generate mel spectrogram figure
|
||||||
|
try:
|
||||||
|
mel_gen = compute_mel_spectrogram(gen_audio_np, sample_rate)
|
||||||
|
mel_ref = compute_mel_spectrogram(ref_audio_np, sample_rate) if ref_audio_np is not None else None
|
||||||
|
fig = create_mel_figure(gen_audio_np, mel_gen, sample_rate, step, ref_audio_np, mel_ref)
|
||||||
|
writer.add_figure(f"{tag}/mel_spectrogram", fig, global_step=step)
|
||||||
|
log(f"[Audio] Created mel spectrogram figure for sample {i}")
|
||||||
|
except Exception as e:
|
||||||
|
log(f"[Warning] Failed to create mel spectrogram: {e}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
log(f"[Warning] Failed to generate audio for sample {i}: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
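`generate_sample_audio` above temporarily casts the training model to bfloat16 and attaches the float32 audio VAE for inference, then restores the float32 training setup. A hypothetical context-manager wrapper for that swap (the name and the final `model.train()` call are illustrative, not part of the codebase):

```python
import contextlib
import torch

@contextlib.contextmanager
def inference_setup(model, audio_vae):
    """Temporarily switch a training model into the inference setup used above."""
    model.eval()
    model.to(torch.bfloat16)
    model.audio_vae = audio_vae.to(torch.float32)
    try:
        yield model
    finally:
        # Restore the training setup: float32 weights, no attached VAE.
        model.to(torch.float32)
        model.audio_vae = None
        model.train()
```

Usage would then read `with inference_setup(unwrapped_model, audio_vae) as m: generated = m.generate(...)`, keeping the restore step in one place even when generation raises.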
|
def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
|
||||||
|
"""
|
||||||
|
Load the latest checkpoint if it exists.
|
||||||
|
Returns the step number to resume from, or 0 if no checkpoint found.
|
||||||
|
"""
|
||||||
|
latest_folder = save_dir / "latest"
|
||||||
|
if not latest_folder.exists():
|
||||||
|
return 0
|
||||||
|
|
||||||
|
unwrapped = model.module if hasattr(model, "module") else model
|
||||||
|
lora_cfg = unwrapped.lora_config
|
||||||
|
|
||||||
|
# Load model weights
|
||||||
|
if lora_cfg is not None:
|
||||||
|
# LoRA: load lora_weights
|
||||||
|
lora_weights_path = latest_folder / "lora_weights.safetensors"
|
||||||
|
if not lora_weights_path.exists():
|
||||||
|
lora_weights_path = latest_folder / "lora_weights.ckpt"
|
||||||
|
|
||||||
|
if lora_weights_path.exists():
|
||||||
|
if lora_weights_path.suffix == ".safetensors":
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
state_dict = load_file(str(lora_weights_path))
|
||||||
|
else:
|
||||||
|
ckpt = torch.load(lora_weights_path, map_location="cpu")
|
||||||
|
state_dict = ckpt.get("state_dict", ckpt)
|
||||||
|
|
||||||
|
# Load only lora weights
|
||||||
|
unwrapped.load_state_dict(state_dict, strict=False)
|
||||||
|
print(f"Loaded LoRA weights from {lora_weights_path}", file=sys.stderr)
|
||||||
|
else:
|
||||||
|
# Full finetune: load model.safetensors or pytorch_model.bin
|
||||||
|
model_path = latest_folder / "model.safetensors"
|
||||||
|
if not model_path.exists():
|
||||||
|
model_path = latest_folder / "pytorch_model.bin"
|
||||||
|
|
||||||
|
if model_path.exists():
|
||||||
|
if model_path.suffix == ".safetensors":
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
state_dict = load_file(str(model_path))
|
||||||
|
else:
|
||||||
|
ckpt = torch.load(model_path, map_location="cpu")
|
||||||
|
state_dict = ckpt.get("state_dict", ckpt)
|
||||||
|
|
||||||
|
unwrapped.load_state_dict(state_dict, strict=False)
|
||||||
|
print(f"Loaded model weights from {model_path}", file=sys.stderr)
|
||||||
|
|
||||||
|
# Load optimizer state
|
||||||
|
optimizer_path = latest_folder / "optimizer.pth"
|
||||||
|
if optimizer_path.exists():
|
||||||
|
optimizer.load_state_dict(torch.load(optimizer_path, map_location="cpu"))
|
||||||
|
print(f"Loaded optimizer state from {optimizer_path}", file=sys.stderr)
|
||||||
|
|
||||||
|
# Load scheduler state
|
||||||
|
scheduler_path = latest_folder / "scheduler.pth"
|
||||||
|
if scheduler_path.exists():
|
||||||
|
scheduler.load_state_dict(torch.load(scheduler_path, map_location="cpu"))
|
||||||
|
print(f"Loaded scheduler state from {scheduler_path}", file=sys.stderr)
|
||||||
|
|
||||||
|
# Try to infer step from checkpoint folders
|
||||||
|
step_folders = [d for d in save_dir.iterdir() if d.is_dir() and d.name.startswith("step_")]
|
||||||
|
if step_folders:
|
||||||
|
steps = [int(d.name.split("_")[1]) for d in step_folders]
|
||||||
|
resume_step = max(steps)
|
||||||
|
print(f"Resuming from step {resume_step}", file=sys.stderr)
|
||||||
|
return resume_step
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
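`load_checkpoint` above derives the resume step from the `step_XXXXXXX` folder names that sit next to `latest`. The same inference, as a small standalone sketch with a throwaway directory (the helper name is illustrative):

```python
import tempfile
from pathlib import Path

def infer_resume_step(save_dir: Path) -> int:
    """Return the highest step encoded in step_XXXXXXX folder names, or 0 if none exist."""
    if not save_dir.exists():
        return 0
    steps = [int(d.name.split("_")[1]) for d in save_dir.iterdir()
             if d.is_dir() and d.name.startswith("step_")]
    return max(steps, default=0)

# Tiny self-test mirroring the on-disk layout produced by save_checkpoint below:
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "step_0000500").mkdir()
    (root / "step_0001000").mkdir()
    (root / "latest").mkdir()
    assert infer_resume_step(root) == 1000
```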
|
def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pretrained_path: str = None, hf_model_id: str = "", distribute: bool = False):
|
||||||
"""
|
"""
|
||||||
Save checkpoint with different strategies for full finetune vs LoRA:
|
Save checkpoint with different strategies for full finetune vs LoRA:
|
||||||
- Full finetune: save non-vae weights to model.safetensors (or pytorch_model.bin if safetensors unavailable)
|
- Full finetune: save non-vae weights to model.safetensors (or pytorch_model.bin if safetensors unavailable)
|
||||||
@ -310,7 +603,7 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
|
|||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
save_dir.mkdir(parents=True, exist_ok=True)
|
save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
tag = "latest" if step == 0 else f"step_{step:07d}"
|
tag = f"step_{step:07d}"
|
||||||
folder = save_dir / tag
|
folder = save_dir / tag
|
||||||
folder.mkdir(parents=True, exist_ok=True)
|
folder.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
@ -325,6 +618,17 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
|
|||||||
save_file(state_dict, folder / "lora_weights.safetensors")
|
save_file(state_dict, folder / "lora_weights.safetensors")
|
||||||
else:
|
else:
|
||||||
torch.save({"state_dict": state_dict}, folder / "lora_weights.ckpt")
|
torch.save({"state_dict": state_dict}, folder / "lora_weights.ckpt")
|
||||||
|
|
||||||
|
# Save LoRA config and base model path to a separate JSON file
|
||||||
|
# If distribute=True, save hf_model_id; otherwise save local pretrained_path
|
||||||
|
import json
|
||||||
|
base_model_to_save = hf_model_id if distribute else (str(pretrained_path) if pretrained_path else None)
|
||||||
|
lora_info = {
|
||||||
|
"base_model": base_model_to_save,
|
||||||
|
"lora_config": lora_cfg.model_dump() if hasattr(lora_cfg, "model_dump") else vars(lora_cfg),
|
||||||
|
}
|
||||||
|
with open(folder / "lora_config.json", "w", encoding="utf-8") as f:
|
||||||
|
json.dump(lora_info, f, indent=2, ensure_ascii=False)
|
||||||
else:
|
else:
|
||||||
# Full finetune: save non-vae weights to model.safetensors
|
# Full finetune: save non-vae weights to model.safetensors
|
||||||
state_dict = {k: v for k, v in full_state.items() if not k.startswith("audio_vae.")}
|
state_dict = {k: v for k, v in full_state.items() if not k.startswith("audio_vae.")}
|
||||||
@ -345,6 +649,15 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
|
|||||||
torch.save(optimizer.state_dict(), folder / "optimizer.pth")
|
torch.save(optimizer.state_dict(), folder / "optimizer.pth")
|
||||||
torch.save(scheduler.state_dict(), folder / "scheduler.pth")
|
torch.save(scheduler.state_dict(), folder / "scheduler.pth")
|
||||||
|
|
||||||
|
# Update (or create) a `latest` folder by copying the most recent checkpoint
|
||||||
|
latest_link = save_dir / "latest"
|
||||||
|
try:
|
||||||
|
if latest_link.exists():
|
||||||
|
shutil.rmtree(latest_link)
|
||||||
|
shutil.copytree(folder, latest_link)
|
||||||
|
except Exception:
|
||||||
|
print(f"Warning: failed to update latest checkpoint at {latest_link}", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from voxcpm.training.config import load_yaml_config
|
from voxcpm.training.config import load_yaml_config
|
||||||
@ -359,4 +672,3 @@ if __name__ == "__main__":
|
|||||||
# Otherwise use command line args (parsed by argbind)
|
# Otherwise use command line args (parsed by argbind)
|
||||||
with argbind.scope(args):
|
with argbind.scope(args):
|
||||||
train()
|
train()
|
||||||
|
|
||||||
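The LoRA checkpoints written above now carry a `lora_config.json` with the base-model reference and the LoRA hyperparameters. A hedged sketch of reading that file back for inference (it assumes the JSON fields match the `LoRAConfig` constructor and that `base_model` is a local path; for a Hub repo id saved with `distribute=True`, `VoxCPM.from_pretrained` would be used instead):

```python
import json
from pathlib import Path

from voxcpm.core import VoxCPM
from voxcpm.model.voxcpm import LoRAConfig

ckpt_dir = Path("checkpoints/run1/latest")  # hypothetical checkpoint folder

with open(ckpt_dir / "lora_config.json", "r", encoding="utf-8") as f:
    lora_info = json.load(f)

lora_config = LoRAConfig(**lora_info["lora_config"])
base_model = lora_info["base_model"]  # local path, or HF repo id when saved with distribute=True

model = VoxCPM(
    voxcpm_model_path=base_model,        # assumes a local base-model directory here
    lora_config=lora_config,
    lora_weights_path=str(ckpt_dir),     # directory containing the saved LoRA weights
)
```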
|
|||||||
@ -3,30 +3,22 @@
|
|||||||
VoxCPM Command Line Interface
|
VoxCPM Command Line Interface
|
||||||
|
|
||||||
Unified CLI for voice cloning, direct TTS synthesis, and batch processing.
|
Unified CLI for voice cloning, direct TTS synthesis, and batch processing.
|
||||||
|
|
||||||
Usage examples:
|
|
||||||
# Direct synthesis (single sample)
|
|
||||||
voxcpm --text "Hello world" --output output.wav
|
|
||||||
|
|
||||||
# Voice cloning (with reference audio and text)
|
|
||||||
voxcpm --text "Hello world" --prompt-audio voice.wav --prompt-text "reference text" --output output.wav --denoise
|
|
||||||
|
|
||||||
# Batch processing (each line in the file is one sample)
|
|
||||||
voxcpm --input texts.txt --output-dir ./outputs/
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, List
|
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
|
||||||
from voxcpm.core import VoxCPM
|
from voxcpm.core import VoxCPM
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# Validators
|
||||||
|
# -----------------------------
|
||||||
|
|
||||||
def validate_file_exists(file_path: str, file_type: str = "file") -> Path:
|
def validate_file_exists(file_path: str, file_type: str = "file") -> Path:
|
||||||
"""Validate that a file exists."""
|
|
||||||
path = Path(file_path)
|
path = Path(file_path)
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
raise FileNotFoundError(f"{file_type} '{file_path}' does not exist")
|
raise FileNotFoundError(f"{file_type} '{file_path}' does not exist")
|
||||||
@ -34,101 +26,111 @@ def validate_file_exists(file_path: str, file_type: str = "file") -> Path:
|
|||||||
|
|
||||||
|
|
||||||
def validate_output_path(output_path: str) -> Path:
|
def validate_output_path(output_path: str) -> Path:
|
||||||
"""Validate the output path and create parent directories if needed."""
|
|
||||||
path = Path(output_path)
|
path = Path(output_path)
|
||||||
path.parent.mkdir(parents=True, exist_ok=True)
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
return path
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
def validate_ranges(args, parser):
|
||||||
|
"""Validate numeric argument ranges."""
|
||||||
|
if not (0.1 <= args.cfg_value <= 10.0):
|
||||||
|
parser.error("--cfg-value must be between 0.1 and 10.0")
|
||||||
|
|
||||||
|
if not (1 <= args.inference_timesteps <= 100):
|
||||||
|
parser.error("--inference-timesteps must be between 1 and 100")
|
||||||
|
|
||||||
|
if args.lora_r <= 0:
|
||||||
|
parser.error("--lora-r must be a positive integer")
|
||||||
|
|
||||||
|
if args.lora_alpha <= 0:
|
||||||
|
parser.error("--lora-alpha must be a positive integer")
|
||||||
|
|
||||||
|
if not (0.0 <= args.lora_dropout <= 1.0):
|
||||||
|
parser.error("--lora-dropout must be between 0.0 and 1.0")
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# Model loading
|
||||||
|
# -----------------------------
|
||||||
|
|
||||||
def load_model(args) -> VoxCPM:
|
def load_model(args) -> VoxCPM:
|
||||||
"""Load VoxCPM model.
|
print("Loading VoxCPM model...", file=sys.stderr)
|
||||||
|
|
||||||
Prefer --model-path if provided; otherwise use from_pretrained (Hub).
|
|
||||||
"""
|
|
||||||
print("Loading VoxCPM model...")
|
|
||||||
|
|
||||||
# Backward-compat: fall back to the ZIPENHANCER_MODEL_PATH environment variable as the default
|
|
||||||
zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get(
|
zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get(
|
||||||
"ZIPENHANCER_MODEL_PATH", None
|
"ZIPENHANCER_MODEL_PATH", None
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build LoRA config if lora_path is provided
|
# Build LoRA config if provided
|
||||||
lora_config = None
|
lora_config = None
|
||||||
lora_weights_path = getattr(args, "lora_path", None)
|
lora_weights_path = getattr(args, "lora_path", None)
|
||||||
if lora_weights_path:
|
if lora_weights_path:
|
||||||
from voxcpm.model.voxcpm import LoRAConfig
|
from voxcpm.model.voxcpm import LoRAConfig
|
||||||
lora_config = LoRAConfig(
|
|
||||||
enable_lm=getattr(args, "lora_enable_lm", True),
|
|
||||||
enable_dit=getattr(args, "lora_enable_dit", True),
|
|
||||||
enable_proj=getattr(args, "lora_enable_proj", False),
|
|
||||||
r=getattr(args, "lora_r", 32),
|
|
||||||
alpha=getattr(args, "lora_alpha", 16),
|
|
||||||
dropout=getattr(args, "lora_dropout", 0.0),
|
|
||||||
)
|
|
||||||
print(f"LoRA config: r={lora_config.r}, alpha={lora_config.alpha}, "
|
|
||||||
f"lm={lora_config.enable_lm}, dit={lora_config.enable_dit}, proj={lora_config.enable_proj}")
|
|
||||||
|
|
||||||
# Load from local path if provided
|
lora_config = LoRAConfig(
|
||||||
if getattr(args, "model_path", None):
|
enable_lm=not args.lora_disable_lm,
|
||||||
|
enable_dit=not args.lora_disable_dit,
|
||||||
|
enable_proj=args.lora_enable_proj,
|
||||||
|
r=args.lora_r,
|
||||||
|
alpha=args.lora_alpha,
|
||||||
|
dropout=args.lora_dropout,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"LoRA config: r={lora_config.r}, alpha={lora_config.alpha}, "
|
||||||
|
f"lm={lora_config.enable_lm}, dit={lora_config.enable_dit}, proj={lora_config.enable_proj}",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load local model if specified
|
||||||
|
if args.model_path:
|
||||||
try:
|
try:
|
||||||
model = VoxCPM(
|
model = VoxCPM(
|
||||||
voxcpm_model_path=args.model_path,
|
voxcpm_model_path=args.model_path,
|
||||||
zipenhancer_model_path=zipenhancer_path,
|
zipenhancer_model_path=zipenhancer_path,
|
||||||
enable_denoiser=not getattr(args, "no_denoiser", False),
|
enable_denoiser=not args.no_denoiser,
|
||||||
lora_config=lora_config,
|
lora_config=lora_config,
|
||||||
lora_weights_path=lora_weights_path,
|
lora_weights_path=lora_weights_path,
|
||||||
)
|
)
|
||||||
print("Model loaded (local).")
|
print("Model loaded (local).", file=sys.stderr)
|
||||||
return model
|
return model
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Failed to load model (local): {e}")
|
print(f"Failed to load model (local): {e}", file=sys.stderr)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
# Otherwise, try from_pretrained (Hub); exit on failure
|
# Load from Hugging Face Hub
|
||||||
try:
|
try:
|
||||||
model = VoxCPM.from_pretrained(
|
model = VoxCPM.from_pretrained(
|
||||||
hf_model_id=getattr(args, "hf_model_id", "openbmb/VoxCPM1.5"),
|
hf_model_id=args.hf_model_id,
|
||||||
load_denoiser=not getattr(args, "no_denoiser", False),
|
load_denoiser=not args.no_denoiser,
|
||||||
zipenhancer_model_id=zipenhancer_path,
|
zipenhancer_model_id=zipenhancer_path,
|
||||||
cache_dir=getattr(args, "cache_dir", None),
|
cache_dir=args.cache_dir,
|
||||||
local_files_only=getattr(args, "local_files_only", False),
|
local_files_only=args.local_files_only,
|
||||||
lora_config=lora_config,
|
lora_config=lora_config,
|
||||||
lora_weights_path=lora_weights_path,
|
lora_weights_path=lora_weights_path,
|
||||||
)
|
)
|
||||||
print("Model loaded (from_pretrained).")
|
print("Model loaded (from_pretrained).", file=sys.stderr)
|
||||||
return model
|
return model
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Failed to load model (from_pretrained): {e}")
|
print(f"Failed to load model (from_pretrained): {e}", file=sys.stderr)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# Commands
|
||||||
|
# -----------------------------
|
||||||
|
|
||||||
def cmd_clone(args):
|
def cmd_clone(args):
|
||||||
"""Voice cloning command."""
|
|
||||||
# Validate inputs
|
|
||||||
if not args.text:
|
if not args.text:
|
||||||
print("Error: Please provide text to synthesize (--text)")
|
sys.exit("Error: Please provide --text for synthesis")
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
if not args.prompt_audio:
|
if not args.prompt_audio or not args.prompt_text:
|
||||||
print("Error: Voice cloning requires a reference audio (--prompt-audio)")
|
sys.exit("Error: Voice cloning requires both --prompt-audio and --prompt-text")
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
if not args.prompt_text:
|
|
||||||
print("Error: Voice cloning requires a reference text (--prompt-text)")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
# Validate files
|
|
||||||
prompt_audio_path = validate_file_exists(args.prompt_audio, "reference audio file")
|
prompt_audio_path = validate_file_exists(args.prompt_audio, "reference audio file")
|
||||||
output_path = validate_output_path(args.output)
|
output_path = validate_output_path(args.output)
|
||||||
|
|
||||||
# Load model
|
|
||||||
model = load_model(args)
|
model = load_model(args)
|
||||||
|
|
||||||
# Generate audio
|
|
||||||
print(f"Synthesizing text: {args.text}")
|
|
||||||
print(f"Reference audio: {prompt_audio_path}")
|
|
||||||
print(f"Reference text: {args.prompt_text}")
|
|
||||||
|
|
||||||
audio_array = model.generate(
|
audio_array = model.generate(
|
||||||
text=args.text,
|
text=args.text,
|
||||||
prompt_wav_path=str(prompt_audio_path),
|
prompt_wav_path=str(prompt_audio_path),
|
||||||
@ -136,30 +138,21 @@ def cmd_clone(args):
|
|||||||
cfg_value=args.cfg_value,
|
cfg_value=args.cfg_value,
|
||||||
inference_timesteps=args.inference_timesteps,
|
inference_timesteps=args.inference_timesteps,
|
||||||
normalize=args.normalize,
|
normalize=args.normalize,
|
||||||
denoise=args.denoise
|
denoise=args.denoise,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save audio
|
|
||||||
sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
|
sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
|
||||||
print(f"Saved audio to: {output_path}")
|
|
||||||
|
|
||||||
# Stats
|
|
||||||
duration = len(audio_array) / model.tts_model.sample_rate
|
duration = len(audio_array) / model.tts_model.sample_rate
|
||||||
print(f"Duration: {duration:.2f}s")
|
print(f"Saved audio to: {output_path} ({duration:.2f}s)", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
def cmd_synthesize(args):
|
def cmd_synthesize(args):
|
||||||
"""Direct TTS synthesis command."""
|
|
||||||
# Validate inputs
|
|
||||||
if not args.text:
|
if not args.text:
|
||||||
print("Error: Please provide text to synthesize (--text)")
|
sys.exit("Error: Please provide --text for synthesis")
|
||||||
sys.exit(1)
|
|
||||||
# Validate output path
|
|
||||||
output_path = validate_output_path(args.output)
|
output_path = validate_output_path(args.output)
|
||||||
# Load model
|
|
||||||
model = load_model(args)
|
model = load_model(args)
|
||||||
# Generate audio
|
|
||||||
print(f"Synthesizing text: {args.text}")
|
|
||||||
|
|
||||||
audio_array = model.generate(
|
audio_array = model.generate(
|
||||||
text=args.text,
|
text=args.text,
|
||||||
@ -168,45 +161,35 @@ def cmd_synthesize(args):
|
|||||||
cfg_value=args.cfg_value,
|
cfg_value=args.cfg_value,
|
||||||
inference_timesteps=args.inference_timesteps,
|
inference_timesteps=args.inference_timesteps,
|
||||||
normalize=args.normalize,
|
normalize=args.normalize,
|
||||||
denoise=False # no denoising needed when there is no reference audio
|
denoise=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save audio
|
|
||||||
sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
|
sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
|
||||||
print(f"Saved audio to: {output_path}")
|
|
||||||
|
|
||||||
# Stats
|
|
||||||
duration = len(audio_array) / model.tts_model.sample_rate
|
duration = len(audio_array) / model.tts_model.sample_rate
|
||||||
print(f"Duration: {duration:.2f}s")
|
print(f"Saved audio to: {output_path} ({duration:.2f}s)", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
def cmd_batch(args):
|
def cmd_batch(args):
|
||||||
"""Batch synthesis command."""
|
|
||||||
# Validate input file
|
|
||||||
input_file = validate_file_exists(args.input, "input file")
|
input_file = validate_file_exists(args.input, "input file")
|
||||||
output_dir = Path(args.output_dir)
|
output_dir = Path(args.output_dir)
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
try:
|
with open(input_file, "r", encoding="utf-8") as f:
|
||||||
with open(input_file, 'r', encoding='utf-8') as f:
|
|
||||||
texts = [line.strip() for line in f if line.strip()]
|
texts = [line.strip() for line in f if line.strip()]
|
||||||
except Exception as e:
|
|
||||||
print(f"Failed to read input file: {e}")
|
|
||||||
sys.exit(1)
|
|
||||||
if not texts:
|
if not texts:
|
||||||
print("Error: Input file is empty or contains no valid lines")
|
sys.exit("Error: Input file is empty")
|
||||||
sys.exit(1)
|
|
||||||
print(f"Found {len(texts)} lines to process")
|
|
||||||
|
|
||||||
model = load_model(args)
|
model = load_model(args)
|
||||||
|
|
||||||
prompt_audio_path = None
|
prompt_audio_path = None
|
||||||
if args.prompt_audio:
|
if args.prompt_audio:
|
||||||
prompt_audio_path = str(validate_file_exists(args.prompt_audio, "reference audio file"))
|
prompt_audio_path = str(validate_file_exists(args.prompt_audio, "reference audio file"))
|
||||||
|
|
||||||
success_count = 0
|
success_count = 0
|
||||||
for i, text in enumerate(texts, 1):
|
|
||||||
print(f"\nProcessing {i}/{len(texts)}: {text[:50]}...")
|
|
||||||
|
|
||||||
|
for i, text in enumerate(texts, 1):
|
||||||
try:
|
try:
|
||||||
audio_array = model.generate(
|
audio_array = model.generate(
|
||||||
text=text,
|
text=text,
|
||||||
@ -215,112 +198,109 @@ def cmd_batch(args):
|
|||||||
cfg_value=args.cfg_value,
|
cfg_value=args.cfg_value,
|
||||||
inference_timesteps=args.inference_timesteps,
|
inference_timesteps=args.inference_timesteps,
|
||||||
normalize=args.normalize,
|
normalize=args.normalize,
|
||||||
denoise=args.denoise and prompt_audio_path is not None
|
denoise=args.denoise and prompt_audio_path is not None,
|
||||||
)
|
)
|
||||||
|
|
||||||
output_file = output_dir / f"output_{i:03d}.wav"
|
output_file = output_dir / f"output_{i:03d}.wav"
|
||||||
sf.write(str(output_file), audio_array, model.tts_model.sample_rate)
|
sf.write(str(output_file), audio_array, model.tts_model.sample_rate)
|
||||||
|
|
||||||
duration = len(audio_array) / model.tts_model.sample_rate
|
duration = len(audio_array) / model.tts_model.sample_rate
|
||||||
print(f" Saved: {output_file} ({duration:.2f}s)")
|
print(f"Saved: {output_file} ({duration:.2f}s)", file=sys.stderr)
|
||||||
success_count += 1
|
success_count += 1
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f" Failed: {e}")
|
print(f"Failed on line {i}: {e}", file=sys.stderr)
|
||||||
continue
|
|
||||||
|
|
||||||
print(f"\nBatch finished: {success_count}/{len(texts)} succeeded")
|
print(f"\nBatch finished: {success_count}/{len(texts)} succeeded", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# Parser
|
||||||
|
# -----------------------------
|
||||||
|
|
||||||
def _build_unified_parser():
|
def _build_unified_parser():
|
||||||
"""Build unified argument parser (no subcommands, route by args)."""
|
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="VoxCPM CLI (single parser) - voice cloning, direct TTS, and batch processing",
|
description="VoxCPM CLI - voice cloning, direct TTS, and batch processing",
|
||||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
epilog="""
|
epilog="""
|
||||||
Examples:
|
Examples:
|
||||||
# Direct synthesis (single sample)
|
|
||||||
voxcpm --text "Hello world" --output out.wav
|
voxcpm --text "Hello world" --output out.wav
|
||||||
|
voxcpm --text "Hello" --prompt-audio ref.wav --prompt-text "hi" --output out.wav --denoise
|
||||||
# Voice cloning (reference audio + text)
|
|
||||||
voxcpm --text "Hello world" --prompt-audio voice.wav --prompt-text "reference text" --output out.wav --denoise
|
|
||||||
|
|
||||||
# Batch processing
|
|
||||||
voxcpm --input texts.txt --output-dir ./outs
|
voxcpm --input texts.txt --output-dir ./outs
|
||||||
|
""",
|
||||||
# Select model (from Hub)
|
|
||||||
voxcpm --text "Hello" --output out.wav --hf-model-id openbmb/VoxCPM-0.5B
|
|
||||||
"""
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Task selection (automatic routing by presence of args)
|
# Mode selection
|
||||||
parser.add_argument("--input", "-i", help="Input text file (one line per sample)")
|
parser.add_argument("--input", "-i", help="Input text file (batch mode only)")
|
||||||
parser.add_argument("--output-dir", "-od", help="Output directory (for batch mode)")
|
parser.add_argument("--output-dir", "-od", help="Output directory (batch mode only)")
|
||||||
parser.add_argument("--text", "-t", help="Text to synthesize (single-sample mode)")
|
parser.add_argument("--text", "-t", help="Text to synthesize (single or clone mode)")
|
||||||
parser.add_argument("--output", "-o", help="Output audio file path (single-sample mode)")
|
parser.add_argument("--output", "-o", help="Output audio file path (single or clone mode)")
|
||||||
|
|
||||||
# Prompt audio (for voice cloning)
|
# Prompt
|
||||||
parser.add_argument("--prompt-audio", "-pa", help="Reference audio file path")
|
parser.add_argument("--prompt-audio", "-pa", help="Reference audio file path (clone mode)")
|
||||||
parser.add_argument("--prompt-text", "-pt", help="Reference text corresponding to the audio")
|
parser.add_argument("--prompt-text", "-pt", help="Reference text corresponding to the audio")
|
||||||
parser.add_argument("--prompt-file", "-pf", help="Reference text file corresponding to the audio")
|
parser.add_argument("--denoise", action="store_true", help="Enable prompt speech enhancement")
|
||||||
parser.add_argument("--denoise", action="store_true", help="Enable prompt speech enhancement (denoising)")
|
|
||||||
|
|
||||||
# Generation parameters
|
# Generation parameters
|
||||||
parser.add_argument("--cfg-value", type=float, default=2.0, help="CFG guidance scale (default: 2.0)")
|
parser.add_argument("--cfg-value", type=float, default=2.0,
|
||||||
parser.add_argument("--inference-timesteps", type=int, default=10, help="Inference steps (default: 10)")
|
help="CFG guidance scale (float, recommended 0.5–5.0, default: 2.0)")
|
||||||
|
parser.add_argument("--inference-timesteps", type=int, default=10,
|
||||||
|
help="Inference steps (int, 1–100, default: 10)")
|
||||||
parser.add_argument("--normalize", action="store_true", help="Enable text normalization")
|
parser.add_argument("--normalize", action="store_true", help="Enable text normalization")
|
||||||
|
|
||||||
# Model loading parameters
|
# Model loading
|
||||||
parser.add_argument("--model-path", type=str, help="Local VoxCPM model path (overrides Hub download)")
|
parser.add_argument("--model-path", type=str, help="Local VoxCPM model path")
|
||||||
parser.add_argument("--hf-model-id", type=str, default="openbmb/VoxCPM1.5", help="Hugging Face repo id (e.g., openbmb/VoxCPM1.5 or openbmb/VoxCPM-0.5B)")
|
parser.add_argument("--hf-model-id", type=str, default="openbmb/VoxCPM1.5",
|
||||||
|
help="Hugging Face repo id (default: openbmb/VoxCPM1.5)")
|
||||||
parser.add_argument("--cache-dir", type=str, help="Cache directory for Hub downloads")
|
parser.add_argument("--cache-dir", type=str, help="Cache directory for Hub downloads")
|
||||||
parser.add_argument("--local-files-only", action="store_true", help="Use only local files (no network)")
|
parser.add_argument("--local-files-only", action="store_true", help="Disable network access")
|
||||||
parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading")
|
parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading")
|
||||||
parser.add_argument("--zipenhancer-path", type=str, default="iic/speech_zipenhancer_ans_multiloss_16k_base", help="ZipEnhancer model id or local path (default reads from env)")
|
parser.add_argument("--zipenhancer-path", type=str,
|
||||||
|
help="ZipEnhancer model id or local path (or env ZIPENHANCER_MODEL_PATH)")
|
||||||
|
|
||||||
# LoRA parameters
|
# LoRA
|
||||||
parser.add_argument("--lora-path", type=str, help="Path to LoRA weights (.pth file or directory containing lora_weights.ckpt)")
|
parser.add_argument("--lora-path", type=str, help="Path to LoRA weights")
|
||||||
parser.add_argument("--lora-r", type=int, default=32, help="LoRA rank (default: 32)")
|
parser.add_argument("--lora-r", type=int, default=32, help="LoRA rank (positive int, default: 32)")
|
||||||
parser.add_argument("--lora-alpha", type=int, default=16, help="LoRA alpha scaling factor (default: 16)")
|
parser.add_argument("--lora-alpha", type=int, default=16, help="LoRA alpha (positive int, default: 16)")
|
||||||
parser.add_argument("--lora-dropout", type=float, default=0.0, help="LoRA dropout rate (default: 0.0)")
|
parser.add_argument("--lora-dropout", type=float, default=0.0,
|
||||||
parser.add_argument("--lora-enable-lm", action="store_true", default=True, help="Apply LoRA to LM layers (default: True)")
|
help="LoRA dropout rate (0.0–1.0, default: 0.0)")
|
||||||
parser.add_argument("--lora-enable-dit", action="store_true", default=True, help="Apply LoRA to DiT layers (default: True)")
|
parser.add_argument("--lora-disable-lm", action="store_true", help="Disable LoRA on LM layers")
|
||||||
parser.add_argument("--lora-enable-proj", action="store_true", default=False, help="Apply LoRA to projection layers (default: False)")
|
parser.add_argument("--lora-disable-dit", action="store_true", help="Disable LoRA on DiT layers")
|
||||||
|
parser.add_argument("--lora-enable-proj", action="store_true", help="Enable LoRA on projection layers")
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# Entrypoint
|
||||||
|
# -----------------------------
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""Unified CLI entrypoint: route by provided arguments."""
|
|
||||||
parser = _build_unified_parser()
|
parser = _build_unified_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Routing: prefer batch → single (clone/direct)
|
# Validate ranges
|
||||||
|
validate_ranges(args, parser)
|
||||||
|
|
||||||
|
# Mode conflict checks
|
||||||
|
if args.input and args.text:
|
||||||
|
parser.error("Use either batch mode (--input) or single mode (--text), not both.")
|
||||||
|
|
||||||
|
# Batch mode
|
||||||
if args.input:
|
if args.input:
|
||||||
if not args.output_dir:
|
if not args.output_dir:
|
||||||
print("Error: Batch mode requires --output-dir")
|
parser.error("Batch mode requires --output-dir")
|
||||||
parser.print_help()
|
|
||||||
sys.exit(1)
|
|
||||||
return cmd_batch(args)
|
return cmd_batch(args)
|
||||||
|
|
||||||
# Single-sample mode
|
# Single mode
|
||||||
if not args.text or not args.output:
|
if not args.text or not args.output:
|
||||||
print("Error: Single-sample mode requires --text and --output")
|
parser.error("Single-sample mode requires --text and --output")
|
||||||
parser.print_help()
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
# If prompt audio+text provided → voice cloning
|
# Clone mode
|
||||||
if args.prompt_audio or args.prompt_text:
|
if args.prompt_audio or args.prompt_text:
|
||||||
if not args.prompt_text and args.prompt_file:
|
|
||||||
assert os.path.isfile(args.prompt_file), "Prompt file does not exist or is not accessible."
|
|
||||||
|
|
||||||
with open(args.prompt_file, 'r', encoding='utf-8') as f:
|
|
||||||
args.prompt_text = f.read()
|
|
||||||
|
|
||||||
if not args.prompt_audio or not args.prompt_text:
|
|
||||||
print("Error: Voice cloning requires both --prompt-audio and --prompt-text")
|
|
||||||
sys.exit(1)
|
|
||||||
return cmd_clone(args)
|
return cmd_clone(args)
|
||||||
|
|
||||||
# Otherwise → direct synthesis
|
# Direct synthesis
|
||||||
return cmd_synthesize(args)
|
return cmd_synthesize(args)
|
||||||
|
|
||||||
|
|
||||||
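One CLI change worth spelling out: the old `--lora-enable-lm` / `--lora-enable-dit` flags combined `action="store_true"` with `default=True`, so they could never actually be switched off from the command line; the new `--lora-disable-lm` / `--lora-disable-dit` flags invert the sense instead. A small self-contained illustration of the difference:

```python
import argparse

old = argparse.ArgumentParser()
old.add_argument("--lora-enable-lm", action="store_true", default=True)
# With store_true + default=True the value is True whether or not the flag is passed:
assert old.parse_args([]).lora_enable_lm is True
assert old.parse_args(["--lora-enable-lm"]).lora_enable_lm is True

new = argparse.ArgumentParser()
new.add_argument("--lora-disable-lm", action="store_true")
# The inverted flag makes the "off" state expressible:
assert new.parse_args([]).lora_disable_lm is False
assert new.parse_args(["--lora-disable-lm"]).lora_disable_lm is True
# load_model() above then sets enable_lm = not args.lora_disable_lm.
```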
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
import re
|
import re
|
||||||
import tempfile
|
import tempfile
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -30,7 +31,7 @@ class VoxCPM:
|
|||||||
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
|
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
|
||||||
containing lora_weights.ckpt). If provided, LoRA weights will be loaded.
|
containing lora_weights.ckpt). If provided, LoRA weights will be loaded.
|
||||||
"""
|
"""
|
||||||
print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}")
|
print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}", file=sys.stderr)
|
||||||
|
|
||||||
# If lora_weights_path is provided but no lora_config, create a default one
|
# If lora_weights_path is provided but no lora_config, create a default one
|
||||||
if lora_weights_path is not None and lora_config is None:
|
if lora_weights_path is not None and lora_config is None:
|
||||||
@ -39,15 +40,15 @@ class VoxCPM:
|
|||||||
enable_dit=True,
|
enable_dit=True,
|
||||||
enable_proj=False,
|
enable_proj=False,
|
||||||
)
|
)
|
||||||
print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}")
|
print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}", file=sys.stderr)
|
||||||
|
|
||||||
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
|
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
|
||||||
|
|
||||||
# Load LoRA weights if path is provided
|
# Load LoRA weights if path is provided
|
||||||
if lora_weights_path is not None:
|
if lora_weights_path is not None:
|
||||||
print(f"Loading LoRA weights from: {lora_weights_path}")
|
print(f"Loading LoRA weights from: {lora_weights_path}", file=sys.stderr)
|
||||||
loaded_keys, skipped_keys = self.tts_model.load_lora_weights(lora_weights_path)
|
loaded_keys, skipped_keys = self.tts_model.load_lora_weights(lora_weights_path)
|
||||||
print(f"Loaded {len(loaded_keys)} LoRA parameters, skipped {len(skipped_keys)}")
|
print(f"Loaded {len(loaded_keys)} LoRA parameters, skipped {len(skipped_keys)}", file=sys.stderr)
|
||||||
|
|
||||||
self.text_normalizer = None
|
self.text_normalizer = None
|
||||||
if enable_denoiser and zipenhancer_model_path is not None:
|
if enable_denoiser and zipenhancer_model_path is not None:
|
||||||
@ -55,7 +56,8 @@ class VoxCPM:
|
|||||||
self.denoiser = ZipEnhancer(zipenhancer_model_path)
|
self.denoiser = ZipEnhancer(zipenhancer_model_path)
|
||||||
else:
|
else:
|
||||||
self.denoiser = None
|
self.denoiser = None
|
||||||
print("Warm up VoxCPMModel...")
|
if optimize:
|
||||||
|
print("Warm up VoxCPMModel...", file=sys.stderr)
|
||||||
self.tts_model.generate(
|
self.tts_model.generate(
|
||||||
target_text="Hello, this is the first test sentence.",
|
target_text="Hello, this is the first test sentence.",
|
||||||
max_len=10,
|
max_len=10,
|
||||||
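With the core.py changes above, passing only `lora_weights_path` auto-creates a default `LoRAConfig`, `load_lora_weights` reports how many parameters were applied, and the warm-up generation now runs only when `optimize` is enabled. A hedged usage sketch (the paths are placeholders, and the `optimize` keyword is assumed from its use inside `__init__`):

```python
from voxcpm.core import VoxCPM

model = VoxCPM(
    voxcpm_model_path="/path/to/VoxCPM1.5",        # placeholder local model directory
    enable_denoiser=False,                         # skip ZipEnhancer for a quick smoke test
    optimize=False,                                # assumed flag: skips torch.compile and the warm-up pass
    lora_weights_path="/path/to/lora_checkpoint",  # directory containing lora_weights.ckpt
)
# A default LoRAConfig is created internally, load_lora_weights() runs, and the
# loaded/skipped parameter counts are printed to stderr.
```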
|
|||||||
@ -19,6 +19,7 @@ limitations under the License.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
from typing import Tuple, Union, Generator, List, Optional
|
from typing import Tuple, Union, Generator, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -120,7 +121,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
self.device = "mps"
|
self.device = "mps"
|
||||||
else:
|
else:
|
||||||
self.device = "cpu"
|
self.device = "cpu"
|
||||||
print(f"Running on device: {self.device}, dtype: {self.config.dtype}")
|
print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
|
||||||
|
|
||||||
# Text-Semantic LM
|
# Text-Semantic LM
|
||||||
self.base_lm = MiniCPMModel(config.lm_config)
|
self.base_lm = MiniCPMModel(config.lm_config)
|
||||||
@ -228,7 +229,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
|
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
|
||||||
self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
|
self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: torch.compile disabled - {e}")
|
print(f"Warning: torch.compile disabled - {e}", file=sys.stderr)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -459,7 +460,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
latent_pred, pred_audio_feat = next(inference_result)
|
latent_pred, pred_audio_feat = next(inference_result)
|
||||||
if retry_badcase:
|
if retry_badcase:
|
||||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||||
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
|
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...", file=sys.stderr)
|
||||||
retry_badcase_times += 1
|
retry_badcase_times += 1
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
@ -583,6 +584,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
retry_badcase_max_times: int = 3,
|
retry_badcase_max_times: int = 3,
|
||||||
retry_badcase_ratio_threshold: float = 6.0,
|
retry_badcase_ratio_threshold: float = 6.0,
|
||||||
streaming: bool = False,
|
streaming: bool = False,
|
||||||
|
streaming_prefix_len: int = 3,
|
||||||
) -> Generator[Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
|
) -> Generator[Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
|
||||||
"""
|
"""
|
||||||
Generate audio using pre-built prompt cache.
|
Generate audio using pre-built prompt cache.
|
||||||
@ -598,6 +600,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
retry_badcase_max_times: Maximum retry attempts
|
retry_badcase_max_times: Maximum retry attempts
|
||||||
retry_badcase_ratio_threshold: Threshold for audio-to-text ratio
|
retry_badcase_ratio_threshold: Threshold for audio-to-text ratio
|
||||||
streaming: Whether to return a generator of audio chunks
|
streaming: Whether to return a generator of audio chunks
|
||||||
|
streaming_prefix_len: Number of prefix audio patches to use for streaming mode
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Generator of Tuple containing:
|
Generator of Tuple containing:
|
||||||
@ -664,6 +667,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
inference_timesteps=inference_timesteps,
|
inference_timesteps=inference_timesteps,
|
||||||
cfg_value=cfg_value,
|
cfg_value=cfg_value,
|
||||||
streaming=streaming,
|
streaming=streaming,
|
||||||
|
streaming_prefix_len=streaming_prefix_len,
|
||||||
)
|
)
|
||||||
if streaming:
|
if streaming:
|
||||||
patch_len = self.patch_size * self.chunk_size
|
patch_len = self.patch_size * self.chunk_size
|
||||||
@ -680,7 +684,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
latent_pred, pred_audio_feat = next(inference_result)
|
latent_pred, pred_audio_feat = next(inference_result)
|
||||||
if retry_badcase:
|
if retry_badcase:
|
||||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||||
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
|
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...", file=sys.stderr)
|
||||||
retry_badcase_times += 1
|
retry_badcase_times += 1
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
@ -688,8 +692,12 @@ class VoxCPMModel(nn.Module):
|
|||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
if not streaming:
|
if not streaming:
|
||||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
|
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
||||||
|
patch_len = self.patch_size * self.chunk_size
|
||||||
|
if audio_mask.sum().item() > 0:
|
||||||
|
decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1):].squeeze(1).cpu()
|
||||||
|
else:
|
||||||
|
decode_audio = decode_audio[..., :].squeeze(1).cpu()
|
||||||
yield (
|
yield (
|
||||||
decode_audio,
|
decode_audio,
|
||||||
target_text_token,
|
target_text_token,
|
||||||
@ -754,6 +762,17 @@ class VoxCPMModel(nn.Module):
|
|||||||
pred_feat_seq = [] # b, t, p, d
|
pred_feat_seq = [] # b, t, p, d
|
||||||
curr_embed = None
|
curr_embed = None
|
||||||
|
|
||||||
|
# Prepare prompt context patches for streaming mode
|
||||||
|
# When there's a prompt audio, use its last (streaming_prefix_len - 1) patches as initial context
|
||||||
|
prompt_context_patches = []
|
||||||
|
audio_patch_count = int(feat_mask.sum().item())
|
||||||
|
if audio_patch_count > 0:
|
||||||
|
context_len = min(streaming_prefix_len - 1, audio_patch_count)
|
||||||
|
# Take the last context_len patches from prompt audio as initial context
|
||||||
|
# Split into list of [b, 1, p, d] tensors to match pred_feat_seq format
|
||||||
|
prompt_context_patches = list(feat[:, -context_len:, :, :].split(1, dim=1))
|
||||||
|
pred_feat_seq = prompt_context_patches + pred_feat_seq
|
||||||
|
|
||||||
enc_outputs, kv_cache_tuple = self.base_lm(
|
enc_outputs, kv_cache_tuple = self.base_lm(
|
||||||
inputs_embeds=combined_embed,
|
inputs_embeds=combined_embed,
|
||||||
is_causal=True,
|
is_causal=True,
|
||||||
@ -850,10 +869,10 @@ class VoxCPMModel(nn.Module):
|
|||||||
pytorch_model_path = os.path.join(path, "pytorch_model.bin")
|
pytorch_model_path = os.path.join(path, "pytorch_model.bin")
|
||||||
|
|
||||||
if os.path.exists(safetensors_path) and SAFETENSORS_AVAILABLE:
|
if os.path.exists(safetensors_path) and SAFETENSORS_AVAILABLE:
|
||||||
print(f"Loading model from safetensors: {safetensors_path}")
|
print(f"Loading model from safetensors: {safetensors_path}", file=sys.stderr)
|
||||||
model_state_dict = load_file(safetensors_path)
|
model_state_dict = load_file(safetensors_path)
|
||||||
elif os.path.exists(pytorch_model_path):
|
elif os.path.exists(pytorch_model_path):
|
||||||
print(f"Loading model from pytorch_model.bin: {pytorch_model_path}")
|
print(f"Loading model from pytorch_model.bin: {pytorch_model_path}", file=sys.stderr)
|
||||||
checkpoint = torch.load(
|
checkpoint = torch.load(
|
||||||
pytorch_model_path,
|
pytorch_model_path,
|
||||||
map_location="cpu",
|
map_location="cpu",
|
||||||
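The streaming changes above prepend the last `(streaming_prefix_len - 1)` prompt-audio patches as context and, in the non-streaming path, trim the matching `patch_len * (streaming_prefix_len - 1)` samples from the decoded waveform. A small sketch of that bookkeeping (the `patch_size`, `chunk_size`, and feature shapes are illustrative, not the model's actual configuration):

```python
import torch

patch_size, chunk_size = 4, 320          # illustrative values only
streaming_prefix_len = 3                 # default in the new signature
patch_len = patch_size * chunk_size      # waveform samples produced per audio patch

# Prompt features shaped [batch, num_patches, patch_size, dim]
feat = torch.randn(1, 10, patch_size, 64)
context_len = min(streaming_prefix_len - 1, feat.shape[1])
prompt_context_patches = list(feat[:, -context_len:, :, :].split(1, dim=1))
assert len(prompt_context_patches) == 2
assert prompt_context_patches[0].shape == (1, 1, patch_size, 64)

# The decoded audio then starts patch_len * (streaming_prefix_len - 1) samples in,
# dropping the part that corresponds to the prepended prompt context.
trim = patch_len * (streaming_prefix_len - 1)   # 2560 samples with these numbers
decoded = torch.randn(1, 1, 12 * patch_len)
generated_only = decoded[..., trim:]
assert generated_only.shape[-1] == decoded.shape[-1] - trim
```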
|
|||||||
@ -70,25 +70,28 @@ def compute_sample_lengths(
|
|||||||
duration(s) * audio_vae_fps -> approximate VAE frame count t_vae
|
duration(s) * audio_vae_fps -> approximate VAE frame count t_vae
|
||||||
t_seq = ceil(t_vae / patch_size)
|
t_seq = ceil(t_vae / patch_size)
|
||||||
- total sequence length is approximately: text_len + t_seq + 2
|
- total sequence length is approximately: text_len + t_seq + 2
|
||||||
|
|
||||||
|
Optimized: Use batch column access instead of iterating item by item.
|
||||||
"""
|
"""
|
||||||
lengths: List[int] = []
|
# Batch access columns - much faster than per-item access
|
||||||
|
text_ids_list = ds["text_ids"]
|
||||||
|
text_lens = [len(t) for t in text_ids_list]
|
||||||
|
|
||||||
has_duration = "duration" in ds.column_names
|
has_duration = "duration" in ds.column_names
|
||||||
|
|
||||||
for i in range(len(ds)):
|
|
||||||
item = ds[i]
|
|
||||||
text_len = len(item["text_ids"])
|
|
||||||
|
|
||||||
# Audio duration (avoid decoding where possible; prefer the duration column if the manifest already provides it)
|
|
||||||
if has_duration:
|
if has_duration:
|
||||||
duration = float(item["duration"])
|
durations = ds["duration"]
|
||||||
else:
|
else:
|
||||||
audio = item[DEFAULT_AUDIO_COLUMN]
|
# Fallback: need to compute from audio (slow, but unavoidable without duration column)
|
||||||
duration = len(audio["array"]) / float(audio["sampling_rate"])
|
durations = []
|
||||||
|
for i in range(len(ds)):
|
||||||
|
audio = ds[i][DEFAULT_AUDIO_COLUMN]
|
||||||
|
durations.append(len(audio["array"]) / float(audio["sampling_rate"]))
|
||||||
|
|
||||||
t_vae = math.ceil(duration * audio_vae_fps)
|
# Vectorized length computation
|
||||||
|
lengths = []
|
||||||
|
for text_len, duration in zip(text_lens, durations):
|
||||||
|
t_vae = math.ceil(float(duration) * audio_vae_fps)
|
||||||
t_seq = math.ceil(t_vae / patch_size)
|
t_seq = math.ceil(t_vae / patch_size)
|
||||||
|
|
||||||
total_len = text_len + t_seq + 2
|
total_len = text_len + t_seq + 2
|
||||||
lengths.append(total_len)
|
lengths.append(total_len)
|
||||||
|
|
||||||
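The length estimate in `compute_sample_lengths` above is `total_len = text_len + ceil(ceil(duration * audio_vae_fps) / patch_size) + 2`. A quick worked example (the fps and patch_size values are assumptions for illustration, not read from any config):

```python
import math

def estimate_total_len(text_len: int, duration: float, audio_vae_fps: float, patch_size: int) -> int:
    t_vae = math.ceil(duration * audio_vae_fps)   # approximate number of VAE frames
    t_seq = math.ceil(t_vae / patch_size)         # LM sequence length after patching
    return text_len + t_seq + 2                   # + 2 extra positions, per the docstring's formula

# Assumed example: 50 text tokens, a 4.0 s clip, 25 VAE frames per second, patch_size 4
# -> t_vae = 100, t_seq = 25, total_len = 77
assert estimate_total_len(50, 4.0, 25, 4) == 77
```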
@ -211,4 +214,3 @@ def build_dataloader(
|
|||||||
collate_fn=HFVoxCPMDataset.collate_fn,
|
collate_fn=HFVoxCPMDataset.collate_fn,
|
||||||
drop_last=drop_last,
|
drop_last=drop_last,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
@ -36,7 +37,7 @@ class TrainingTracker:
|
|||||||
# ------------------------------------------------------------------ #
|
# ------------------------------------------------------------------ #
|
||||||
def print(self, message: str):
|
def print(self, message: str):
|
||||||
if self.rank == 0:
|
if self.rank == 0:
|
||||||
print(message, flush=True)
|
print(message, flush=True, file=sys.stderr)
|
||||||
if self.log_file:
|
if self.log_file:
|
||||||
with self.log_file.open("a", encoding="utf-8") as f:
|
with self.log_file.open("a", encoding="utf-8") as f:
|
||||||
f.write(message + "\n")
|
f.write(message + "\n")
|
||||||
|
|||||||