diff --git a/src/voxcpm/cli.py b/src/voxcpm/cli.py index 71007b0..232deb4 100644 --- a/src/voxcpm/cli.py +++ b/src/voxcpm/cli.py @@ -3,30 +3,22 @@ VoxCPM Command Line Interface Unified CLI for voice cloning, direct TTS synthesis, and batch processing. - -Usage examples: - # Direct synthesis (single sample) - voxcpm --text "Hello world" --output output.wav - - # Voice cloning (with reference audio and text) - voxcpm --text "Hello world" --prompt-audio voice.wav --prompt-text "reference text" --output output.wav --denoise - - # Batch processing (each line in the file is one sample) - voxcpm --input texts.txt --output-dir ./outputs/ """ import argparse import os import sys from pathlib import Path -from typing import Optional, List import soundfile as sf from voxcpm.core import VoxCPM +# ----------------------------- +# Validators +# ----------------------------- + def validate_file_exists(file_path: str, file_type: str = "file") -> Path: - """Validate that a file exists.""" path = Path(file_path) if not path.exists(): raise FileNotFoundError(f"{file_type} '{file_path}' does not exist") @@ -34,47 +26,68 @@ def validate_file_exists(file_path: str, file_type: str = "file") -> Path: def validate_output_path(output_path: str) -> Path: - """Validate the output path and create parent directories if needed.""" path = Path(output_path) path.parent.mkdir(parents=True, exist_ok=True) return path -def load_model(args) -> VoxCPM: - """Load VoxCPM model. +def validate_ranges(args, parser): + """Validate numeric argument ranges.""" + if not (0.1 <= args.cfg_value <= 10.0): + parser.error("--cfg-value must be between 0.1 and 10.0") - Prefer --model-path if provided; otherwise use from_pretrained (Hub). - """ + if not (1 <= args.inference_timesteps <= 100): + parser.error("--inference-timesteps must be between 1 and 100") + + if args.lora_r <= 0: + parser.error("--lora-r must be a positive integer") + + if args.lora_alpha <= 0: + parser.error("--lora-alpha must be a positive integer") + + if not (0.0 <= args.lora_dropout <= 1.0): + parser.error("--lora-dropout must be between 0.0 and 1.0") + + +# ----------------------------- +# Model loading +# ----------------------------- + +def load_model(args) -> VoxCPM: print("Loading VoxCPM model...", file=sys.stderr) - # 兼容旧参数:ZIPENHANCER_MODEL_PATH 环境变量作为默认 zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get( "ZIPENHANCER_MODEL_PATH", None ) - # Build LoRA config if lora_path is provided + # Build LoRA config if provided lora_config = None lora_weights_path = getattr(args, "lora_path", None) if lora_weights_path: from voxcpm.model.voxcpm import LoRAConfig - lora_config = LoRAConfig( - enable_lm=getattr(args, "lora_enable_lm", True), - enable_dit=getattr(args, "lora_enable_dit", True), - enable_proj=getattr(args, "lora_enable_proj", False), - r=getattr(args, "lora_r", 32), - alpha=getattr(args, "lora_alpha", 16), - dropout=getattr(args, "lora_dropout", 0.0), - ) - print(f"LoRA config: r={lora_config.r}, alpha={lora_config.alpha}, " - f"lm={lora_config.enable_lm}, dit={lora_config.enable_dit}, proj={lora_config.enable_proj}", file=sys.stderr) - # Load from local path if provided - if getattr(args, "model_path", None): + lora_config = LoRAConfig( + enable_lm=not args.lora_disable_lm, + enable_dit=not args.lora_disable_dit, + enable_proj=args.lora_enable_proj, + r=args.lora_r, + alpha=args.lora_alpha, + dropout=args.lora_dropout, + ) + + print( + f"LoRA config: r={lora_config.r}, alpha={lora_config.alpha}, " + f"lm={lora_config.enable_lm}, dit={lora_config.enable_dit}, proj={lora_config.enable_proj}", + file=sys.stderr, + ) + + # Load local model if specified + if args.model_path: try: model = VoxCPM( voxcpm_model_path=args.model_path, zipenhancer_model_path=zipenhancer_path, - enable_denoiser=not getattr(args, "no_denoiser", False), + enable_denoiser=not args.no_denoiser, lora_config=lora_config, lora_weights_path=lora_weights_path, ) @@ -84,14 +97,14 @@ def load_model(args) -> VoxCPM: print(f"Failed to load model (local): {e}", file=sys.stderr) sys.exit(1) - # Otherwise, try from_pretrained (Hub); exit on failure + # Load from Hugging Face Hub try: model = VoxCPM.from_pretrained( - hf_model_id=getattr(args, "hf_model_id", "openbmb/VoxCPM1.5"), - load_denoiser=not getattr(args, "no_denoiser", False), + hf_model_id=args.hf_model_id, + load_denoiser=not args.no_denoiser, zipenhancer_model_id=zipenhancer_path, - cache_dir=getattr(args, "cache_dir", None), - local_files_only=getattr(args, "local_files_only", False), + cache_dir=args.cache_dir, + local_files_only=args.local_files_only, lora_config=lora_config, lora_weights_path=lora_weights_path, ) @@ -102,33 +115,22 @@ def load_model(args) -> VoxCPM: sys.exit(1) +# ----------------------------- +# Commands +# ----------------------------- + def cmd_clone(args): - """Voice cloning command.""" - # Validate inputs if not args.text: - print("Error: Please provide text to synthesize (--text)", file=sys.stderr) - sys.exit(1) - - if not args.prompt_audio: - print("Error: Voice cloning requires a reference audio (--prompt-audio)", file=sys.stderr) - sys.exit(1) - - if not args.prompt_text: - print("Error: Voice cloning requires a reference text (--prompt-text)", file=sys.stderr) - sys.exit(1) - - # Validate files + sys.exit("Error: Please provide --text for synthesis") + + if not args.prompt_audio or not args.prompt_text: + sys.exit("Error: Voice cloning requires both --prompt-audio and --prompt-text") + prompt_audio_path = validate_file_exists(args.prompt_audio, "reference audio file") output_path = validate_output_path(args.output) - - # Load model + model = load_model(args) - - # Generate audio - print(f"Synthesizing text: {args.text}", file=sys.stderr) - print(f"Reference audio: {prompt_audio_path}", file=sys.stderr) - print(f"Reference text: {args.prompt_text}", file=sys.stderr) - + audio_array = model.generate( text=args.text, prompt_wav_path=str(prompt_audio_path), @@ -136,31 +138,22 @@ def cmd_clone(args): cfg_value=args.cfg_value, inference_timesteps=args.inference_timesteps, normalize=args.normalize, - denoise=args.denoise + denoise=args.denoise, ) - - # Save audio + sf.write(str(output_path), audio_array, model.tts_model.sample_rate) - print(f"Saved audio to: {output_path}", file=sys.stderr) - - # Stats + duration = len(audio_array) / model.tts_model.sample_rate - print(f"Duration: {duration:.2f}s", file=sys.stderr) + print(f"Saved audio to: {output_path} ({duration:.2f}s)", file=sys.stderr) def cmd_synthesize(args): - """Direct TTS synthesis command.""" - # Validate inputs if not args.text: - print("Error: Please provide text to synthesize (--text)", file=sys.stderr) - sys.exit(1) - # Validate output path + sys.exit("Error: Please provide --text for synthesis") + output_path = validate_output_path(args.output) - # Load model model = load_model(args) - # Generate audio - print(f"Synthesizing text: {args.text}", file=sys.stderr) - + audio_array = model.generate( text=args.text, prompt_wav_path=None, @@ -168,45 +161,35 @@ def cmd_synthesize(args): cfg_value=args.cfg_value, inference_timesteps=args.inference_timesteps, normalize=args.normalize, - denoise=False # 无参考音频时不需要降噪 + denoise=False, ) - - # Save audio + sf.write(str(output_path), audio_array, model.tts_model.sample_rate) - print(f"Saved audio to: {output_path}", file=sys.stderr) - - # Stats + duration = len(audio_array) / model.tts_model.sample_rate - print(f"Duration: {duration:.2f}s", file=sys.stderr) + print(f"Saved audio to: {output_path} ({duration:.2f}s)", file=sys.stderr) def cmd_batch(args): - """Batch synthesis command.""" - # Validate input file input_file = validate_file_exists(args.input, "input file") output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) - - try: - with open(input_file, 'r', encoding='utf-8') as f: - texts = [line.strip() for line in f if line.strip()] - except Exception as e: - print(f"Failed to read input file: {e}", file=sys.stderr) - sys.exit(1) + + with open(input_file, "r", encoding="utf-8") as f: + texts = [line.strip() for line in f if line.strip()] + if not texts: - print("Error: Input file is empty or contains no valid lines", file=sys.stderr) - sys.exit(1) - print(f"Found {len(texts)} lines to process", file=sys.stderr) - + sys.exit("Error: Input file is empty") + model = load_model(args) + prompt_audio_path = None if args.prompt_audio: prompt_audio_path = str(validate_file_exists(args.prompt_audio, "reference audio file")) - + success_count = 0 + for i, text in enumerate(texts, 1): - print(f"\nProcessing {i}/{len(texts)}: {text[:50]}...", file=sys.stderr) - try: audio_array = model.generate( text=text, @@ -215,112 +198,109 @@ def cmd_batch(args): cfg_value=args.cfg_value, inference_timesteps=args.inference_timesteps, normalize=args.normalize, - denoise=args.denoise and prompt_audio_path is not None + denoise=args.denoise and prompt_audio_path is not None, ) + output_file = output_dir / f"output_{i:03d}.wav" sf.write(str(output_file), audio_array, model.tts_model.sample_rate) - + duration = len(audio_array) / model.tts_model.sample_rate - print(f" Saved: {output_file} ({duration:.2f}s)", file=sys.stderr) + print(f"Saved: {output_file} ({duration:.2f}s)", file=sys.stderr) success_count += 1 - + except Exception as e: - print(f" Failed: {e}", file=sys.stderr) - continue - + print(f"Failed on line {i}: {e}", file=sys.stderr) + print(f"\nBatch finished: {success_count}/{len(texts)} succeeded", file=sys.stderr) + +# ----------------------------- +# Parser +# ----------------------------- + def _build_unified_parser(): - """Build unified argument parser (no subcommands, route by args).""" parser = argparse.ArgumentParser( - description="VoxCPM CLI (single parser) - voice cloning, direct TTS, and batch processing", + description="VoxCPM CLI - voice cloning, direct TTS, and batch processing", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: - # Direct synthesis (single sample) voxcpm --text "Hello world" --output out.wav - - # Voice cloning (reference audio + text) - voxcpm --text "Hello world" --prompt-audio voice.wav --prompt-text "reference text" --output out.wav --denoise - - # Batch processing + voxcpm --text "Hello" --prompt-audio ref.wav --prompt-text "hi" --output out.wav --denoise voxcpm --input texts.txt --output-dir ./outs - - # Select model (from Hub) - voxcpm --text "Hello" --output out.wav --hf-model-id openbmb/VoxCPM-0.5B - """ + """, ) - # Task selection (automatic routing by presence of args) - parser.add_argument("--input", "-i", help="Input text file (one line per sample)") - parser.add_argument("--output-dir", "-od", help="Output directory (for batch mode)") - parser.add_argument("--text", "-t", help="Text to synthesize (single-sample mode)") - parser.add_argument("--output", "-o", help="Output audio file path (single-sample mode)") + # Mode selection + parser.add_argument("--input", "-i", help="Input text file (batch mode only)") + parser.add_argument("--output-dir", "-od", help="Output directory (batch mode only)") + parser.add_argument("--text", "-t", help="Text to synthesize (single or clone mode)") + parser.add_argument("--output", "-o", help="Output audio file path (single or clone mode)") - # Prompt audio (for voice cloning) - parser.add_argument("--prompt-audio", "-pa", help="Reference audio file path") + # Prompt + parser.add_argument("--prompt-audio", "-pa", help="Reference audio file path (clone mode)") parser.add_argument("--prompt-text", "-pt", help="Reference text corresponding to the audio") - parser.add_argument("--prompt-file", "-pf", help="Reference text file corresponding to the audio") - parser.add_argument("--denoise", action="store_true", help="Enable prompt speech enhancement (denoising)") + parser.add_argument("--denoise", action="store_true", help="Enable prompt speech enhancement") # Generation parameters - parser.add_argument("--cfg-value", type=float, default=2.0, help="CFG guidance scale (default: 2.0)") - parser.add_argument("--inference-timesteps", type=int, default=10, help="Inference steps (default: 10)") + parser.add_argument("--cfg-value", type=float, default=2.0, + help="CFG guidance scale (float, recommended 0.5–5.0, default: 2.0)") + parser.add_argument("--inference-timesteps", type=int, default=10, + help="Inference steps (int, 1–100, default: 10)") parser.add_argument("--normalize", action="store_true", help="Enable text normalization") - # Model loading parameters - parser.add_argument("--model-path", type=str, help="Local VoxCPM model path (overrides Hub download)") - parser.add_argument("--hf-model-id", type=str, default="openbmb/VoxCPM1.5", help="Hugging Face repo id (e.g., openbmb/VoxCPM1.5 or openbmb/VoxCPM-0.5B)") + # Model loading + parser.add_argument("--model-path", type=str, help="Local VoxCPM model path") + parser.add_argument("--hf-model-id", type=str, default="openbmb/VoxCPM1.5", + help="Hugging Face repo id (default: openbmb/VoxCPM1.5)") parser.add_argument("--cache-dir", type=str, help="Cache directory for Hub downloads") - parser.add_argument("--local-files-only", action="store_true", help="Use only local files (no network)") + parser.add_argument("--local-files-only", action="store_true", help="Disable network access") parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading") - parser.add_argument("--zipenhancer-path", type=str, default="iic/speech_zipenhancer_ans_multiloss_16k_base", help="ZipEnhancer model id or local path (default reads from env)") + parser.add_argument("--zipenhancer-path", type=str, + help="ZipEnhancer model id or local path (or env ZIPENHANCER_MODEL_PATH)") - # LoRA parameters - parser.add_argument("--lora-path", type=str, help="Path to LoRA weights (.pth file or directory containing lora_weights.ckpt)") - parser.add_argument("--lora-r", type=int, default=32, help="LoRA rank (default: 32)") - parser.add_argument("--lora-alpha", type=int, default=16, help="LoRA alpha scaling factor (default: 16)") - parser.add_argument("--lora-dropout", type=float, default=0.0, help="LoRA dropout rate (default: 0.0)") - parser.add_argument("--lora-enable-lm", action="store_true", default=True, help="Apply LoRA to LM layers (default: True)") - parser.add_argument("--lora-enable-dit", action="store_true", default=True, help="Apply LoRA to DiT layers (default: True)") - parser.add_argument("--lora-enable-proj", action="store_true", default=False, help="Apply LoRA to projection layers (default: False)") + # LoRA + parser.add_argument("--lora-path", type=str, help="Path to LoRA weights") + parser.add_argument("--lora-r", type=int, default=32, help="LoRA rank (positive int, default: 32)") + parser.add_argument("--lora-alpha", type=int, default=16, help="LoRA alpha (positive int, default: 16)") + parser.add_argument("--lora-dropout", type=float, default=0.0, + help="LoRA dropout rate (0.0–1.0, default: 0.0)") + parser.add_argument("--lora-disable-lm", action="store_true", help="Disable LoRA on LM layers") + parser.add_argument("--lora-disable-dit", action="store_true", help="Disable LoRA on DiT layers") + parser.add_argument("--lora-enable-proj", action="store_true", help="Enable LoRA on projection layers") return parser +# ----------------------------- +# Entrypoint +# ----------------------------- + def main(): - """Unified CLI entrypoint: route by provided arguments.""" parser = _build_unified_parser() args = parser.parse_args() - # Routing: prefer batch → single (clone/direct) + # Validate ranges + validate_ranges(args, parser) + + # Mode conflict checks + if args.input and args.text: + parser.error("Use either batch mode (--input) or single mode (--text), not both.") + + # Batch mode if args.input: if not args.output_dir: - print("Error: Batch mode requires --output-dir", file=sys.stderr) - parser.print_help() - sys.exit(1) + parser.error("Batch mode requires --output-dir") return cmd_batch(args) - # Single-sample mode + # Single mode if not args.text or not args.output: - print("Error: Single-sample mode requires --text and --output", file=sys.stderr) - parser.print_help() - sys.exit(1) + parser.error("Single-sample mode requires --text and --output") - # If prompt audio+text provided → voice cloning + # Clone mode if args.prompt_audio or args.prompt_text: - if not args.prompt_text and args.prompt_file: - assert os.path.isfile(args.prompt_file), "Prompt file does not exist or is not accessible." - - with open(args.prompt_file, 'r', encoding='utf-8') as f: - args.prompt_text = f.read() - - if not args.prompt_audio or not args.prompt_text: - print("Error: Voice cloning requires both --prompt-audio and --prompt-text", file=sys.stderr) - sys.exit(1) return cmd_clone(args) - # Otherwise → direct synthesis + # Direct synthesis return cmd_synthesize(args)