cli: improve argument validation and help text for VoxCPM CLI

This commit is contained in:
Biriy
2026-01-20 14:33:58 +05:30
parent e72fb42c38
commit 8f3a91cac8

View File

@ -3,30 +3,22 @@
VoxCPM Command Line Interface VoxCPM Command Line Interface
Unified CLI for voice cloning, direct TTS synthesis, and batch processing. Unified CLI for voice cloning, direct TTS synthesis, and batch processing.
Usage examples:
# Direct synthesis (single sample)
voxcpm --text "Hello world" --output output.wav
# Voice cloning (with reference audio and text)
voxcpm --text "Hello world" --prompt-audio voice.wav --prompt-text "reference text" --output output.wav --denoise
# Batch processing (each line in the file is one sample)
voxcpm --input texts.txt --output-dir ./outputs/
""" """
import argparse import argparse
import os import os
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Optional, List
import soundfile as sf import soundfile as sf
from voxcpm.core import VoxCPM from voxcpm.core import VoxCPM
# -----------------------------
# Validators
# -----------------------------
def validate_file_exists(file_path: str, file_type: str = "file") -> Path: def validate_file_exists(file_path: str, file_type: str = "file") -> Path:
"""Validate that a file exists."""
path = Path(file_path) path = Path(file_path)
if not path.exists(): if not path.exists():
raise FileNotFoundError(f"{file_type} '{file_path}' does not exist") raise FileNotFoundError(f"{file_type} '{file_path}' does not exist")
@ -34,47 +26,68 @@ def validate_file_exists(file_path: str, file_type: str = "file") -> Path:
def validate_output_path(output_path: str) -> Path: def validate_output_path(output_path: str) -> Path:
"""Validate the output path and create parent directories if needed."""
path = Path(output_path) path = Path(output_path)
path.parent.mkdir(parents=True, exist_ok=True) path.parent.mkdir(parents=True, exist_ok=True)
return path return path
def load_model(args) -> VoxCPM: def validate_ranges(args, parser):
"""Load VoxCPM model. """Validate numeric argument ranges."""
if not (0.1 <= args.cfg_value <= 10.0):
parser.error("--cfg-value must be between 0.1 and 10.0")
Prefer --model-path if provided; otherwise use from_pretrained (Hub). if not (1 <= args.inference_timesteps <= 100):
""" parser.error("--inference-timesteps must be between 1 and 100")
if args.lora_r <= 0:
parser.error("--lora-r must be a positive integer")
if args.lora_alpha <= 0:
parser.error("--lora-alpha must be a positive integer")
if not (0.0 <= args.lora_dropout <= 1.0):
parser.error("--lora-dropout must be between 0.0 and 1.0")
# -----------------------------
# Model loading
# -----------------------------
def load_model(args) -> VoxCPM:
print("Loading VoxCPM model...", file=sys.stderr) print("Loading VoxCPM model...", file=sys.stderr)
# 兼容旧参数ZIPENHANCER_MODEL_PATH 环境变量作为默认
zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get( zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get(
"ZIPENHANCER_MODEL_PATH", None "ZIPENHANCER_MODEL_PATH", None
) )
# Build LoRA config if lora_path is provided # Build LoRA config if provided
lora_config = None lora_config = None
lora_weights_path = getattr(args, "lora_path", None) lora_weights_path = getattr(args, "lora_path", None)
if lora_weights_path: if lora_weights_path:
from voxcpm.model.voxcpm import LoRAConfig from voxcpm.model.voxcpm import LoRAConfig
lora_config = LoRAConfig(
enable_lm=getattr(args, "lora_enable_lm", True),
enable_dit=getattr(args, "lora_enable_dit", True),
enable_proj=getattr(args, "lora_enable_proj", False),
r=getattr(args, "lora_r", 32),
alpha=getattr(args, "lora_alpha", 16),
dropout=getattr(args, "lora_dropout", 0.0),
)
print(f"LoRA config: r={lora_config.r}, alpha={lora_config.alpha}, "
f"lm={lora_config.enable_lm}, dit={lora_config.enable_dit}, proj={lora_config.enable_proj}", file=sys.stderr)
# Load from local path if provided lora_config = LoRAConfig(
if getattr(args, "model_path", None): enable_lm=not args.lora_disable_lm,
enable_dit=not args.lora_disable_dit,
enable_proj=args.lora_enable_proj,
r=args.lora_r,
alpha=args.lora_alpha,
dropout=args.lora_dropout,
)
print(
f"LoRA config: r={lora_config.r}, alpha={lora_config.alpha}, "
f"lm={lora_config.enable_lm}, dit={lora_config.enable_dit}, proj={lora_config.enable_proj}",
file=sys.stderr,
)
# Load local model if specified
if args.model_path:
try: try:
model = VoxCPM( model = VoxCPM(
voxcpm_model_path=args.model_path, voxcpm_model_path=args.model_path,
zipenhancer_model_path=zipenhancer_path, zipenhancer_model_path=zipenhancer_path,
enable_denoiser=not getattr(args, "no_denoiser", False), enable_denoiser=not args.no_denoiser,
lora_config=lora_config, lora_config=lora_config,
lora_weights_path=lora_weights_path, lora_weights_path=lora_weights_path,
) )
@ -84,14 +97,14 @@ def load_model(args) -> VoxCPM:
print(f"Failed to load model (local): {e}", file=sys.stderr) print(f"Failed to load model (local): {e}", file=sys.stderr)
sys.exit(1) sys.exit(1)
# Otherwise, try from_pretrained (Hub); exit on failure # Load from Hugging Face Hub
try: try:
model = VoxCPM.from_pretrained( model = VoxCPM.from_pretrained(
hf_model_id=getattr(args, "hf_model_id", "openbmb/VoxCPM1.5"), hf_model_id=args.hf_model_id,
load_denoiser=not getattr(args, "no_denoiser", False), load_denoiser=not args.no_denoiser,
zipenhancer_model_id=zipenhancer_path, zipenhancer_model_id=zipenhancer_path,
cache_dir=getattr(args, "cache_dir", None), cache_dir=args.cache_dir,
local_files_only=getattr(args, "local_files_only", False), local_files_only=args.local_files_only,
lora_config=lora_config, lora_config=lora_config,
lora_weights_path=lora_weights_path, lora_weights_path=lora_weights_path,
) )
@ -102,33 +115,22 @@ def load_model(args) -> VoxCPM:
sys.exit(1) sys.exit(1)
# -----------------------------
# Commands
# -----------------------------
def cmd_clone(args): def cmd_clone(args):
"""Voice cloning command."""
# Validate inputs
if not args.text: if not args.text:
print("Error: Please provide text to synthesize (--text)", file=sys.stderr) sys.exit("Error: Please provide --text for synthesis")
sys.exit(1)
if not args.prompt_audio or not args.prompt_text:
if not args.prompt_audio: sys.exit("Error: Voice cloning requires both --prompt-audio and --prompt-text")
print("Error: Voice cloning requires a reference audio (--prompt-audio)", file=sys.stderr)
sys.exit(1)
if not args.prompt_text:
print("Error: Voice cloning requires a reference text (--prompt-text)", file=sys.stderr)
sys.exit(1)
# Validate files
prompt_audio_path = validate_file_exists(args.prompt_audio, "reference audio file") prompt_audio_path = validate_file_exists(args.prompt_audio, "reference audio file")
output_path = validate_output_path(args.output) output_path = validate_output_path(args.output)
# Load model
model = load_model(args) model = load_model(args)
# Generate audio
print(f"Synthesizing text: {args.text}", file=sys.stderr)
print(f"Reference audio: {prompt_audio_path}", file=sys.stderr)
print(f"Reference text: {args.prompt_text}", file=sys.stderr)
audio_array = model.generate( audio_array = model.generate(
text=args.text, text=args.text,
prompt_wav_path=str(prompt_audio_path), prompt_wav_path=str(prompt_audio_path),
@ -136,31 +138,22 @@ def cmd_clone(args):
cfg_value=args.cfg_value, cfg_value=args.cfg_value,
inference_timesteps=args.inference_timesteps, inference_timesteps=args.inference_timesteps,
normalize=args.normalize, normalize=args.normalize,
denoise=args.denoise denoise=args.denoise,
) )
# Save audio
sf.write(str(output_path), audio_array, model.tts_model.sample_rate) sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
print(f"Saved audio to: {output_path}", file=sys.stderr)
# Stats
duration = len(audio_array) / model.tts_model.sample_rate duration = len(audio_array) / model.tts_model.sample_rate
print(f"Duration: {duration:.2f}s", file=sys.stderr) print(f"Saved audio to: {output_path} ({duration:.2f}s)", file=sys.stderr)
def cmd_synthesize(args): def cmd_synthesize(args):
"""Direct TTS synthesis command."""
# Validate inputs
if not args.text: if not args.text:
print("Error: Please provide text to synthesize (--text)", file=sys.stderr) sys.exit("Error: Please provide --text for synthesis")
sys.exit(1)
# Validate output path
output_path = validate_output_path(args.output) output_path = validate_output_path(args.output)
# Load model
model = load_model(args) model = load_model(args)
# Generate audio
print(f"Synthesizing text: {args.text}", file=sys.stderr)
audio_array = model.generate( audio_array = model.generate(
text=args.text, text=args.text,
prompt_wav_path=None, prompt_wav_path=None,
@ -168,45 +161,35 @@ def cmd_synthesize(args):
cfg_value=args.cfg_value, cfg_value=args.cfg_value,
inference_timesteps=args.inference_timesteps, inference_timesteps=args.inference_timesteps,
normalize=args.normalize, normalize=args.normalize,
denoise=False # 无参考音频时不需要降噪 denoise=False,
) )
# Save audio
sf.write(str(output_path), audio_array, model.tts_model.sample_rate) sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
print(f"Saved audio to: {output_path}", file=sys.stderr)
# Stats
duration = len(audio_array) / model.tts_model.sample_rate duration = len(audio_array) / model.tts_model.sample_rate
print(f"Duration: {duration:.2f}s", file=sys.stderr) print(f"Saved audio to: {output_path} ({duration:.2f}s)", file=sys.stderr)
def cmd_batch(args): def cmd_batch(args):
"""Batch synthesis command."""
# Validate input file
input_file = validate_file_exists(args.input, "input file") input_file = validate_file_exists(args.input, "input file")
output_dir = Path(args.output_dir) output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
try: with open(input_file, "r", encoding="utf-8") as f:
with open(input_file, 'r', encoding='utf-8') as f: texts = [line.strip() for line in f if line.strip()]
texts = [line.strip() for line in f if line.strip()]
except Exception as e:
print(f"Failed to read input file: {e}", file=sys.stderr)
sys.exit(1)
if not texts: if not texts:
print("Error: Input file is empty or contains no valid lines", file=sys.stderr) sys.exit("Error: Input file is empty")
sys.exit(1)
print(f"Found {len(texts)} lines to process", file=sys.stderr)
model = load_model(args) model = load_model(args)
prompt_audio_path = None prompt_audio_path = None
if args.prompt_audio: if args.prompt_audio:
prompt_audio_path = str(validate_file_exists(args.prompt_audio, "reference audio file")) prompt_audio_path = str(validate_file_exists(args.prompt_audio, "reference audio file"))
success_count = 0 success_count = 0
for i, text in enumerate(texts, 1): for i, text in enumerate(texts, 1):
print(f"\nProcessing {i}/{len(texts)}: {text[:50]}...", file=sys.stderr)
try: try:
audio_array = model.generate( audio_array = model.generate(
text=text, text=text,
@ -215,112 +198,109 @@ def cmd_batch(args):
cfg_value=args.cfg_value, cfg_value=args.cfg_value,
inference_timesteps=args.inference_timesteps, inference_timesteps=args.inference_timesteps,
normalize=args.normalize, normalize=args.normalize,
denoise=args.denoise and prompt_audio_path is not None denoise=args.denoise and prompt_audio_path is not None,
) )
output_file = output_dir / f"output_{i:03d}.wav" output_file = output_dir / f"output_{i:03d}.wav"
sf.write(str(output_file), audio_array, model.tts_model.sample_rate) sf.write(str(output_file), audio_array, model.tts_model.sample_rate)
duration = len(audio_array) / model.tts_model.sample_rate duration = len(audio_array) / model.tts_model.sample_rate
print(f" Saved: {output_file} ({duration:.2f}s)", file=sys.stderr) print(f"Saved: {output_file} ({duration:.2f}s)", file=sys.stderr)
success_count += 1 success_count += 1
except Exception as e: except Exception as e:
print(f" Failed: {e}", file=sys.stderr) print(f"Failed on line {i}: {e}", file=sys.stderr)
continue
print(f"\nBatch finished: {success_count}/{len(texts)} succeeded", file=sys.stderr) print(f"\nBatch finished: {success_count}/{len(texts)} succeeded", file=sys.stderr)
# -----------------------------
# Parser
# -----------------------------
def _build_unified_parser(): def _build_unified_parser():
"""Build unified argument parser (no subcommands, route by args)."""
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="VoxCPM CLI (single parser) - voice cloning, direct TTS, and batch processing", description="VoxCPM CLI - voice cloning, direct TTS, and batch processing",
formatter_class=argparse.RawDescriptionHelpFormatter, formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=""" epilog="""
Examples: Examples:
# Direct synthesis (single sample)
voxcpm --text "Hello world" --output out.wav voxcpm --text "Hello world" --output out.wav
voxcpm --text "Hello" --prompt-audio ref.wav --prompt-text "hi" --output out.wav --denoise
# Voice cloning (reference audio + text)
voxcpm --text "Hello world" --prompt-audio voice.wav --prompt-text "reference text" --output out.wav --denoise
# Batch processing
voxcpm --input texts.txt --output-dir ./outs voxcpm --input texts.txt --output-dir ./outs
""",
# Select model (from Hub)
voxcpm --text "Hello" --output out.wav --hf-model-id openbmb/VoxCPM-0.5B
"""
) )
# Task selection (automatic routing by presence of args) # Mode selection
parser.add_argument("--input", "-i", help="Input text file (one line per sample)") parser.add_argument("--input", "-i", help="Input text file (batch mode only)")
parser.add_argument("--output-dir", "-od", help="Output directory (for batch mode)") parser.add_argument("--output-dir", "-od", help="Output directory (batch mode only)")
parser.add_argument("--text", "-t", help="Text to synthesize (single-sample mode)") parser.add_argument("--text", "-t", help="Text to synthesize (single or clone mode)")
parser.add_argument("--output", "-o", help="Output audio file path (single-sample mode)") parser.add_argument("--output", "-o", help="Output audio file path (single or clone mode)")
# Prompt audio (for voice cloning) # Prompt
parser.add_argument("--prompt-audio", "-pa", help="Reference audio file path") parser.add_argument("--prompt-audio", "-pa", help="Reference audio file path (clone mode)")
parser.add_argument("--prompt-text", "-pt", help="Reference text corresponding to the audio") parser.add_argument("--prompt-text", "-pt", help="Reference text corresponding to the audio")
parser.add_argument("--prompt-file", "-pf", help="Reference text file corresponding to the audio") parser.add_argument("--denoise", action="store_true", help="Enable prompt speech enhancement")
parser.add_argument("--denoise", action="store_true", help="Enable prompt speech enhancement (denoising)")
# Generation parameters # Generation parameters
parser.add_argument("--cfg-value", type=float, default=2.0, help="CFG guidance scale (default: 2.0)") parser.add_argument("--cfg-value", type=float, default=2.0,
parser.add_argument("--inference-timesteps", type=int, default=10, help="Inference steps (default: 10)") help="CFG guidance scale (float, recommended 0.55.0, default: 2.0)")
parser.add_argument("--inference-timesteps", type=int, default=10,
help="Inference steps (int, 1100, default: 10)")
parser.add_argument("--normalize", action="store_true", help="Enable text normalization") parser.add_argument("--normalize", action="store_true", help="Enable text normalization")
# Model loading parameters # Model loading
parser.add_argument("--model-path", type=str, help="Local VoxCPM model path (overrides Hub download)") parser.add_argument("--model-path", type=str, help="Local VoxCPM model path")
parser.add_argument("--hf-model-id", type=str, default="openbmb/VoxCPM1.5", help="Hugging Face repo id (e.g., openbmb/VoxCPM1.5 or openbmb/VoxCPM-0.5B)") parser.add_argument("--hf-model-id", type=str, default="openbmb/VoxCPM1.5",
help="Hugging Face repo id (default: openbmb/VoxCPM1.5)")
parser.add_argument("--cache-dir", type=str, help="Cache directory for Hub downloads") parser.add_argument("--cache-dir", type=str, help="Cache directory for Hub downloads")
parser.add_argument("--local-files-only", action="store_true", help="Use only local files (no network)") parser.add_argument("--local-files-only", action="store_true", help="Disable network access")
parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading") parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading")
parser.add_argument("--zipenhancer-path", type=str, default="iic/speech_zipenhancer_ans_multiloss_16k_base", help="ZipEnhancer model id or local path (default reads from env)") parser.add_argument("--zipenhancer-path", type=str,
help="ZipEnhancer model id or local path (or env ZIPENHANCER_MODEL_PATH)")
# LoRA parameters # LoRA
parser.add_argument("--lora-path", type=str, help="Path to LoRA weights (.pth file or directory containing lora_weights.ckpt)") parser.add_argument("--lora-path", type=str, help="Path to LoRA weights")
parser.add_argument("--lora-r", type=int, default=32, help="LoRA rank (default: 32)") parser.add_argument("--lora-r", type=int, default=32, help="LoRA rank (positive int, default: 32)")
parser.add_argument("--lora-alpha", type=int, default=16, help="LoRA alpha scaling factor (default: 16)") parser.add_argument("--lora-alpha", type=int, default=16, help="LoRA alpha (positive int, default: 16)")
parser.add_argument("--lora-dropout", type=float, default=0.0, help="LoRA dropout rate (default: 0.0)") parser.add_argument("--lora-dropout", type=float, default=0.0,
parser.add_argument("--lora-enable-lm", action="store_true", default=True, help="Apply LoRA to LM layers (default: True)") help="LoRA dropout rate (0.01.0, default: 0.0)")
parser.add_argument("--lora-enable-dit", action="store_true", default=True, help="Apply LoRA to DiT layers (default: True)") parser.add_argument("--lora-disable-lm", action="store_true", help="Disable LoRA on LM layers")
parser.add_argument("--lora-enable-proj", action="store_true", default=False, help="Apply LoRA to projection layers (default: False)") parser.add_argument("--lora-disable-dit", action="store_true", help="Disable LoRA on DiT layers")
parser.add_argument("--lora-enable-proj", action="store_true", help="Enable LoRA on projection layers")
return parser return parser
# -----------------------------
# Entrypoint
# -----------------------------
def main(): def main():
"""Unified CLI entrypoint: route by provided arguments."""
parser = _build_unified_parser() parser = _build_unified_parser()
args = parser.parse_args() args = parser.parse_args()
# Routing: prefer batch → single (clone/direct) # Validate ranges
validate_ranges(args, parser)
# Mode conflict checks
if args.input and args.text:
parser.error("Use either batch mode (--input) or single mode (--text), not both.")
# Batch mode
if args.input: if args.input:
if not args.output_dir: if not args.output_dir:
print("Error: Batch mode requires --output-dir", file=sys.stderr) parser.error("Batch mode requires --output-dir")
parser.print_help()
sys.exit(1)
return cmd_batch(args) return cmd_batch(args)
# Single-sample mode # Single mode
if not args.text or not args.output: if not args.text or not args.output:
print("Error: Single-sample mode requires --text and --output", file=sys.stderr) parser.error("Single-sample mode requires --text and --output")
parser.print_help()
sys.exit(1)
# If prompt audio+text provided → voice cloning # Clone mode
if args.prompt_audio or args.prompt_text: if args.prompt_audio or args.prompt_text:
if not args.prompt_text and args.prompt_file:
assert os.path.isfile(args.prompt_file), "Prompt file does not exist or is not accessible."
with open(args.prompt_file, 'r', encoding='utf-8') as f:
args.prompt_text = f.read()
if not args.prompt_audio or not args.prompt_text:
print("Error: Voice cloning requires both --prompt-audio and --prompt-text", file=sys.stderr)
sys.exit(1)
return cmd_clone(args) return cmd_clone(args)
# Otherwise → direct synthesis # Direct synthesis
return cmd_synthesize(args) return cmd_synthesize(args)