Merge pull request #161 from s3ldc/cli-arg-validation
Improve CLI argument validation and help text
This commit is contained in:
@ -3,30 +3,22 @@
|
|||||||
VoxCPM Command Line Interface
|
VoxCPM Command Line Interface
|
||||||
|
|
||||||
Unified CLI for voice cloning, direct TTS synthesis, and batch processing.
|
Unified CLI for voice cloning, direct TTS synthesis, and batch processing.
|
||||||
|
|
||||||
Usage examples:
|
|
||||||
# Direct synthesis (single sample)
|
|
||||||
voxcpm --text "Hello world" --output output.wav
|
|
||||||
|
|
||||||
# Voice cloning (with reference audio and text)
|
|
||||||
voxcpm --text "Hello world" --prompt-audio voice.wav --prompt-text "reference text" --output output.wav --denoise
|
|
||||||
|
|
||||||
# Batch processing (each line in the file is one sample)
|
|
||||||
voxcpm --input texts.txt --output-dir ./outputs/
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, List
|
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
|
||||||
from voxcpm.core import VoxCPM
|
from voxcpm.core import VoxCPM
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# Validators
|
||||||
|
# -----------------------------
|
||||||
|
|
||||||
def validate_file_exists(file_path: str, file_type: str = "file") -> Path:
|
def validate_file_exists(file_path: str, file_type: str = "file") -> Path:
|
||||||
"""Validate that a file exists."""
|
|
||||||
path = Path(file_path)
|
path = Path(file_path)
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
raise FileNotFoundError(f"{file_type} '{file_path}' does not exist")
|
raise FileNotFoundError(f"{file_type} '{file_path}' does not exist")
|
||||||
@ -34,47 +26,68 @@ def validate_file_exists(file_path: str, file_type: str = "file") -> Path:
|
|||||||
|
|
||||||
|
|
||||||
def validate_output_path(output_path: str) -> Path:
|
def validate_output_path(output_path: str) -> Path:
|
||||||
"""Validate the output path and create parent directories if needed."""
|
|
||||||
path = Path(output_path)
|
path = Path(output_path)
|
||||||
path.parent.mkdir(parents=True, exist_ok=True)
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
return path
|
return path
|
||||||
|
|
||||||
|
|
||||||
def load_model(args) -> VoxCPM:
|
def validate_ranges(args, parser):
|
||||||
"""Load VoxCPM model.
|
"""Validate numeric argument ranges."""
|
||||||
|
if not (0.1 <= args.cfg_value <= 10.0):
|
||||||
|
parser.error("--cfg-value must be between 0.1 and 10.0")
|
||||||
|
|
||||||
Prefer --model-path if provided; otherwise use from_pretrained (Hub).
|
if not (1 <= args.inference_timesteps <= 100):
|
||||||
"""
|
parser.error("--inference-timesteps must be between 1 and 100")
|
||||||
|
|
||||||
|
if args.lora_r <= 0:
|
||||||
|
parser.error("--lora-r must be a positive integer")
|
||||||
|
|
||||||
|
if args.lora_alpha <= 0:
|
||||||
|
parser.error("--lora-alpha must be a positive integer")
|
||||||
|
|
||||||
|
if not (0.0 <= args.lora_dropout <= 1.0):
|
||||||
|
parser.error("--lora-dropout must be between 0.0 and 1.0")
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# Model loading
|
||||||
|
# -----------------------------
|
||||||
|
|
||||||
|
def load_model(args) -> VoxCPM:
|
||||||
print("Loading VoxCPM model...", file=sys.stderr)
|
print("Loading VoxCPM model...", file=sys.stderr)
|
||||||
|
|
||||||
# 兼容旧参数:ZIPENHANCER_MODEL_PATH 环境变量作为默认
|
|
||||||
zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get(
|
zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get(
|
||||||
"ZIPENHANCER_MODEL_PATH", None
|
"ZIPENHANCER_MODEL_PATH", None
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build LoRA config if lora_path is provided
|
# Build LoRA config if provided
|
||||||
lora_config = None
|
lora_config = None
|
||||||
lora_weights_path = getattr(args, "lora_path", None)
|
lora_weights_path = getattr(args, "lora_path", None)
|
||||||
if lora_weights_path:
|
if lora_weights_path:
|
||||||
from voxcpm.model.voxcpm import LoRAConfig
|
from voxcpm.model.voxcpm import LoRAConfig
|
||||||
lora_config = LoRAConfig(
|
|
||||||
enable_lm=getattr(args, "lora_enable_lm", True),
|
|
||||||
enable_dit=getattr(args, "lora_enable_dit", True),
|
|
||||||
enable_proj=getattr(args, "lora_enable_proj", False),
|
|
||||||
r=getattr(args, "lora_r", 32),
|
|
||||||
alpha=getattr(args, "lora_alpha", 16),
|
|
||||||
dropout=getattr(args, "lora_dropout", 0.0),
|
|
||||||
)
|
|
||||||
print(f"LoRA config: r={lora_config.r}, alpha={lora_config.alpha}, "
|
|
||||||
f"lm={lora_config.enable_lm}, dit={lora_config.enable_dit}, proj={lora_config.enable_proj}", file=sys.stderr)
|
|
||||||
|
|
||||||
# Load from local path if provided
|
lora_config = LoRAConfig(
|
||||||
if getattr(args, "model_path", None):
|
enable_lm=not args.lora_disable_lm,
|
||||||
|
enable_dit=not args.lora_disable_dit,
|
||||||
|
enable_proj=args.lora_enable_proj,
|
||||||
|
r=args.lora_r,
|
||||||
|
alpha=args.lora_alpha,
|
||||||
|
dropout=args.lora_dropout,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"LoRA config: r={lora_config.r}, alpha={lora_config.alpha}, "
|
||||||
|
f"lm={lora_config.enable_lm}, dit={lora_config.enable_dit}, proj={lora_config.enable_proj}",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load local model if specified
|
||||||
|
if args.model_path:
|
||||||
try:
|
try:
|
||||||
model = VoxCPM(
|
model = VoxCPM(
|
||||||
voxcpm_model_path=args.model_path,
|
voxcpm_model_path=args.model_path,
|
||||||
zipenhancer_model_path=zipenhancer_path,
|
zipenhancer_model_path=zipenhancer_path,
|
||||||
enable_denoiser=not getattr(args, "no_denoiser", False),
|
enable_denoiser=not args.no_denoiser,
|
||||||
lora_config=lora_config,
|
lora_config=lora_config,
|
||||||
lora_weights_path=lora_weights_path,
|
lora_weights_path=lora_weights_path,
|
||||||
)
|
)
|
||||||
@ -84,14 +97,14 @@ def load_model(args) -> VoxCPM:
|
|||||||
print(f"Failed to load model (local): {e}", file=sys.stderr)
|
print(f"Failed to load model (local): {e}", file=sys.stderr)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
# Otherwise, try from_pretrained (Hub); exit on failure
|
# Load from Hugging Face Hub
|
||||||
try:
|
try:
|
||||||
model = VoxCPM.from_pretrained(
|
model = VoxCPM.from_pretrained(
|
||||||
hf_model_id=getattr(args, "hf_model_id", "openbmb/VoxCPM1.5"),
|
hf_model_id=args.hf_model_id,
|
||||||
load_denoiser=not getattr(args, "no_denoiser", False),
|
load_denoiser=not args.no_denoiser,
|
||||||
zipenhancer_model_id=zipenhancer_path,
|
zipenhancer_model_id=zipenhancer_path,
|
||||||
cache_dir=getattr(args, "cache_dir", None),
|
cache_dir=args.cache_dir,
|
||||||
local_files_only=getattr(args, "local_files_only", False),
|
local_files_only=args.local_files_only,
|
||||||
lora_config=lora_config,
|
lora_config=lora_config,
|
||||||
lora_weights_path=lora_weights_path,
|
lora_weights_path=lora_weights_path,
|
||||||
)
|
)
|
||||||
@ -102,33 +115,22 @@ def load_model(args) -> VoxCPM:
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# Commands
|
||||||
|
# -----------------------------
|
||||||
|
|
||||||
def cmd_clone(args):
|
def cmd_clone(args):
|
||||||
"""Voice cloning command."""
|
|
||||||
# Validate inputs
|
|
||||||
if not args.text:
|
if not args.text:
|
||||||
print("Error: Please provide text to synthesize (--text)", file=sys.stderr)
|
sys.exit("Error: Please provide --text for synthesis")
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
if not args.prompt_audio:
|
if not args.prompt_audio or not args.prompt_text:
|
||||||
print("Error: Voice cloning requires a reference audio (--prompt-audio)", file=sys.stderr)
|
sys.exit("Error: Voice cloning requires both --prompt-audio and --prompt-text")
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
if not args.prompt_text:
|
|
||||||
print("Error: Voice cloning requires a reference text (--prompt-text)", file=sys.stderr)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
# Validate files
|
|
||||||
prompt_audio_path = validate_file_exists(args.prompt_audio, "reference audio file")
|
prompt_audio_path = validate_file_exists(args.prompt_audio, "reference audio file")
|
||||||
output_path = validate_output_path(args.output)
|
output_path = validate_output_path(args.output)
|
||||||
|
|
||||||
# Load model
|
|
||||||
model = load_model(args)
|
model = load_model(args)
|
||||||
|
|
||||||
# Generate audio
|
|
||||||
print(f"Synthesizing text: {args.text}", file=sys.stderr)
|
|
||||||
print(f"Reference audio: {prompt_audio_path}", file=sys.stderr)
|
|
||||||
print(f"Reference text: {args.prompt_text}", file=sys.stderr)
|
|
||||||
|
|
||||||
audio_array = model.generate(
|
audio_array = model.generate(
|
||||||
text=args.text,
|
text=args.text,
|
||||||
prompt_wav_path=str(prompt_audio_path),
|
prompt_wav_path=str(prompt_audio_path),
|
||||||
@ -136,30 +138,21 @@ def cmd_clone(args):
|
|||||||
cfg_value=args.cfg_value,
|
cfg_value=args.cfg_value,
|
||||||
inference_timesteps=args.inference_timesteps,
|
inference_timesteps=args.inference_timesteps,
|
||||||
normalize=args.normalize,
|
normalize=args.normalize,
|
||||||
denoise=args.denoise
|
denoise=args.denoise,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save audio
|
|
||||||
sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
|
sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
|
||||||
print(f"Saved audio to: {output_path}", file=sys.stderr)
|
|
||||||
|
|
||||||
# Stats
|
|
||||||
duration = len(audio_array) / model.tts_model.sample_rate
|
duration = len(audio_array) / model.tts_model.sample_rate
|
||||||
print(f"Duration: {duration:.2f}s", file=sys.stderr)
|
print(f"Saved audio to: {output_path} ({duration:.2f}s)", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
def cmd_synthesize(args):
|
def cmd_synthesize(args):
|
||||||
"""Direct TTS synthesis command."""
|
|
||||||
# Validate inputs
|
|
||||||
if not args.text:
|
if not args.text:
|
||||||
print("Error: Please provide text to synthesize (--text)", file=sys.stderr)
|
sys.exit("Error: Please provide --text for synthesis")
|
||||||
sys.exit(1)
|
|
||||||
# Validate output path
|
|
||||||
output_path = validate_output_path(args.output)
|
output_path = validate_output_path(args.output)
|
||||||
# Load model
|
|
||||||
model = load_model(args)
|
model = load_model(args)
|
||||||
# Generate audio
|
|
||||||
print(f"Synthesizing text: {args.text}", file=sys.stderr)
|
|
||||||
|
|
||||||
audio_array = model.generate(
|
audio_array = model.generate(
|
||||||
text=args.text,
|
text=args.text,
|
||||||
@ -168,45 +161,35 @@ def cmd_synthesize(args):
|
|||||||
cfg_value=args.cfg_value,
|
cfg_value=args.cfg_value,
|
||||||
inference_timesteps=args.inference_timesteps,
|
inference_timesteps=args.inference_timesteps,
|
||||||
normalize=args.normalize,
|
normalize=args.normalize,
|
||||||
denoise=False # 无参考音频时不需要降噪
|
denoise=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save audio
|
|
||||||
sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
|
sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
|
||||||
print(f"Saved audio to: {output_path}", file=sys.stderr)
|
|
||||||
|
|
||||||
# Stats
|
|
||||||
duration = len(audio_array) / model.tts_model.sample_rate
|
duration = len(audio_array) / model.tts_model.sample_rate
|
||||||
print(f"Duration: {duration:.2f}s", file=sys.stderr)
|
print(f"Saved audio to: {output_path} ({duration:.2f}s)", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
def cmd_batch(args):
|
def cmd_batch(args):
|
||||||
"""Batch synthesis command."""
|
|
||||||
# Validate input file
|
|
||||||
input_file = validate_file_exists(args.input, "input file")
|
input_file = validate_file_exists(args.input, "input file")
|
||||||
output_dir = Path(args.output_dir)
|
output_dir = Path(args.output_dir)
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
try:
|
with open(input_file, "r", encoding="utf-8") as f:
|
||||||
with open(input_file, 'r', encoding='utf-8') as f:
|
|
||||||
texts = [line.strip() for line in f if line.strip()]
|
texts = [line.strip() for line in f if line.strip()]
|
||||||
except Exception as e:
|
|
||||||
print(f"Failed to read input file: {e}", file=sys.stderr)
|
|
||||||
sys.exit(1)
|
|
||||||
if not texts:
|
if not texts:
|
||||||
print("Error: Input file is empty or contains no valid lines", file=sys.stderr)
|
sys.exit("Error: Input file is empty")
|
||||||
sys.exit(1)
|
|
||||||
print(f"Found {len(texts)} lines to process", file=sys.stderr)
|
|
||||||
|
|
||||||
model = load_model(args)
|
model = load_model(args)
|
||||||
|
|
||||||
prompt_audio_path = None
|
prompt_audio_path = None
|
||||||
if args.prompt_audio:
|
if args.prompt_audio:
|
||||||
prompt_audio_path = str(validate_file_exists(args.prompt_audio, "reference audio file"))
|
prompt_audio_path = str(validate_file_exists(args.prompt_audio, "reference audio file"))
|
||||||
|
|
||||||
success_count = 0
|
success_count = 0
|
||||||
for i, text in enumerate(texts, 1):
|
|
||||||
print(f"\nProcessing {i}/{len(texts)}: {text[:50]}...", file=sys.stderr)
|
|
||||||
|
|
||||||
|
for i, text in enumerate(texts, 1):
|
||||||
try:
|
try:
|
||||||
audio_array = model.generate(
|
audio_array = model.generate(
|
||||||
text=text,
|
text=text,
|
||||||
@ -215,112 +198,109 @@ def cmd_batch(args):
|
|||||||
cfg_value=args.cfg_value,
|
cfg_value=args.cfg_value,
|
||||||
inference_timesteps=args.inference_timesteps,
|
inference_timesteps=args.inference_timesteps,
|
||||||
normalize=args.normalize,
|
normalize=args.normalize,
|
||||||
denoise=args.denoise and prompt_audio_path is not None
|
denoise=args.denoise and prompt_audio_path is not None,
|
||||||
)
|
)
|
||||||
|
|
||||||
output_file = output_dir / f"output_{i:03d}.wav"
|
output_file = output_dir / f"output_{i:03d}.wav"
|
||||||
sf.write(str(output_file), audio_array, model.tts_model.sample_rate)
|
sf.write(str(output_file), audio_array, model.tts_model.sample_rate)
|
||||||
|
|
||||||
duration = len(audio_array) / model.tts_model.sample_rate
|
duration = len(audio_array) / model.tts_model.sample_rate
|
||||||
print(f" Saved: {output_file} ({duration:.2f}s)", file=sys.stderr)
|
print(f"Saved: {output_file} ({duration:.2f}s)", file=sys.stderr)
|
||||||
success_count += 1
|
success_count += 1
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f" Failed: {e}", file=sys.stderr)
|
print(f"Failed on line {i}: {e}", file=sys.stderr)
|
||||||
continue
|
|
||||||
|
|
||||||
print(f"\nBatch finished: {success_count}/{len(texts)} succeeded", file=sys.stderr)
|
print(f"\nBatch finished: {success_count}/{len(texts)} succeeded", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# Parser
|
||||||
|
# -----------------------------
|
||||||
|
|
||||||
def _build_unified_parser():
|
def _build_unified_parser():
|
||||||
"""Build unified argument parser (no subcommands, route by args)."""
|
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="VoxCPM CLI (single parser) - voice cloning, direct TTS, and batch processing",
|
description="VoxCPM CLI - voice cloning, direct TTS, and batch processing",
|
||||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
epilog="""
|
epilog="""
|
||||||
Examples:
|
Examples:
|
||||||
# Direct synthesis (single sample)
|
|
||||||
voxcpm --text "Hello world" --output out.wav
|
voxcpm --text "Hello world" --output out.wav
|
||||||
|
voxcpm --text "Hello" --prompt-audio ref.wav --prompt-text "hi" --output out.wav --denoise
|
||||||
# Voice cloning (reference audio + text)
|
|
||||||
voxcpm --text "Hello world" --prompt-audio voice.wav --prompt-text "reference text" --output out.wav --denoise
|
|
||||||
|
|
||||||
# Batch processing
|
|
||||||
voxcpm --input texts.txt --output-dir ./outs
|
voxcpm --input texts.txt --output-dir ./outs
|
||||||
|
""",
|
||||||
# Select model (from Hub)
|
|
||||||
voxcpm --text "Hello" --output out.wav --hf-model-id openbmb/VoxCPM-0.5B
|
|
||||||
"""
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Task selection (automatic routing by presence of args)
|
# Mode selection
|
||||||
parser.add_argument("--input", "-i", help="Input text file (one line per sample)")
|
parser.add_argument("--input", "-i", help="Input text file (batch mode only)")
|
||||||
parser.add_argument("--output-dir", "-od", help="Output directory (for batch mode)")
|
parser.add_argument("--output-dir", "-od", help="Output directory (batch mode only)")
|
||||||
parser.add_argument("--text", "-t", help="Text to synthesize (single-sample mode)")
|
parser.add_argument("--text", "-t", help="Text to synthesize (single or clone mode)")
|
||||||
parser.add_argument("--output", "-o", help="Output audio file path (single-sample mode)")
|
parser.add_argument("--output", "-o", help="Output audio file path (single or clone mode)")
|
||||||
|
|
||||||
# Prompt audio (for voice cloning)
|
# Prompt
|
||||||
parser.add_argument("--prompt-audio", "-pa", help="Reference audio file path")
|
parser.add_argument("--prompt-audio", "-pa", help="Reference audio file path (clone mode)")
|
||||||
parser.add_argument("--prompt-text", "-pt", help="Reference text corresponding to the audio")
|
parser.add_argument("--prompt-text", "-pt", help="Reference text corresponding to the audio")
|
||||||
parser.add_argument("--prompt-file", "-pf", help="Reference text file corresponding to the audio")
|
parser.add_argument("--denoise", action="store_true", help="Enable prompt speech enhancement")
|
||||||
parser.add_argument("--denoise", action="store_true", help="Enable prompt speech enhancement (denoising)")
|
|
||||||
|
|
||||||
# Generation parameters
|
# Generation parameters
|
||||||
parser.add_argument("--cfg-value", type=float, default=2.0, help="CFG guidance scale (default: 2.0)")
|
parser.add_argument("--cfg-value", type=float, default=2.0,
|
||||||
parser.add_argument("--inference-timesteps", type=int, default=10, help="Inference steps (default: 10)")
|
help="CFG guidance scale (float, recommended 0.5–5.0, default: 2.0)")
|
||||||
|
parser.add_argument("--inference-timesteps", type=int, default=10,
|
||||||
|
help="Inference steps (int, 1–100, default: 10)")
|
||||||
parser.add_argument("--normalize", action="store_true", help="Enable text normalization")
|
parser.add_argument("--normalize", action="store_true", help="Enable text normalization")
|
||||||
|
|
||||||
# Model loading parameters
|
# Model loading
|
||||||
parser.add_argument("--model-path", type=str, help="Local VoxCPM model path (overrides Hub download)")
|
parser.add_argument("--model-path", type=str, help="Local VoxCPM model path")
|
||||||
parser.add_argument("--hf-model-id", type=str, default="openbmb/VoxCPM1.5", help="Hugging Face repo id (e.g., openbmb/VoxCPM1.5 or openbmb/VoxCPM-0.5B)")
|
parser.add_argument("--hf-model-id", type=str, default="openbmb/VoxCPM1.5",
|
||||||
|
help="Hugging Face repo id (default: openbmb/VoxCPM1.5)")
|
||||||
parser.add_argument("--cache-dir", type=str, help="Cache directory for Hub downloads")
|
parser.add_argument("--cache-dir", type=str, help="Cache directory for Hub downloads")
|
||||||
parser.add_argument("--local-files-only", action="store_true", help="Use only local files (no network)")
|
parser.add_argument("--local-files-only", action="store_true", help="Disable network access")
|
||||||
parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading")
|
parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading")
|
||||||
parser.add_argument("--zipenhancer-path", type=str, default="iic/speech_zipenhancer_ans_multiloss_16k_base", help="ZipEnhancer model id or local path (default reads from env)")
|
parser.add_argument("--zipenhancer-path", type=str,
|
||||||
|
help="ZipEnhancer model id or local path (or env ZIPENHANCER_MODEL_PATH)")
|
||||||
|
|
||||||
# LoRA parameters
|
# LoRA
|
||||||
parser.add_argument("--lora-path", type=str, help="Path to LoRA weights (.pth file or directory containing lora_weights.ckpt)")
|
parser.add_argument("--lora-path", type=str, help="Path to LoRA weights")
|
||||||
parser.add_argument("--lora-r", type=int, default=32, help="LoRA rank (default: 32)")
|
parser.add_argument("--lora-r", type=int, default=32, help="LoRA rank (positive int, default: 32)")
|
||||||
parser.add_argument("--lora-alpha", type=int, default=16, help="LoRA alpha scaling factor (default: 16)")
|
parser.add_argument("--lora-alpha", type=int, default=16, help="LoRA alpha (positive int, default: 16)")
|
||||||
parser.add_argument("--lora-dropout", type=float, default=0.0, help="LoRA dropout rate (default: 0.0)")
|
parser.add_argument("--lora-dropout", type=float, default=0.0,
|
||||||
parser.add_argument("--lora-enable-lm", action="store_true", default=True, help="Apply LoRA to LM layers (default: True)")
|
help="LoRA dropout rate (0.0–1.0, default: 0.0)")
|
||||||
parser.add_argument("--lora-enable-dit", action="store_true", default=True, help="Apply LoRA to DiT layers (default: True)")
|
parser.add_argument("--lora-disable-lm", action="store_true", help="Disable LoRA on LM layers")
|
||||||
parser.add_argument("--lora-enable-proj", action="store_true", default=False, help="Apply LoRA to projection layers (default: False)")
|
parser.add_argument("--lora-disable-dit", action="store_true", help="Disable LoRA on DiT layers")
|
||||||
|
parser.add_argument("--lora-enable-proj", action="store_true", help="Enable LoRA on projection layers")
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# Entrypoint
|
||||||
|
# -----------------------------
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""Unified CLI entrypoint: route by provided arguments."""
|
|
||||||
parser = _build_unified_parser()
|
parser = _build_unified_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Routing: prefer batch → single (clone/direct)
|
# Validate ranges
|
||||||
|
validate_ranges(args, parser)
|
||||||
|
|
||||||
|
# Mode conflict checks
|
||||||
|
if args.input and args.text:
|
||||||
|
parser.error("Use either batch mode (--input) or single mode (--text), not both.")
|
||||||
|
|
||||||
|
# Batch mode
|
||||||
if args.input:
|
if args.input:
|
||||||
if not args.output_dir:
|
if not args.output_dir:
|
||||||
print("Error: Batch mode requires --output-dir", file=sys.stderr)
|
parser.error("Batch mode requires --output-dir")
|
||||||
parser.print_help()
|
|
||||||
sys.exit(1)
|
|
||||||
return cmd_batch(args)
|
return cmd_batch(args)
|
||||||
|
|
||||||
# Single-sample mode
|
# Single mode
|
||||||
if not args.text or not args.output:
|
if not args.text or not args.output:
|
||||||
print("Error: Single-sample mode requires --text and --output", file=sys.stderr)
|
parser.error("Single-sample mode requires --text and --output")
|
||||||
parser.print_help()
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
# If prompt audio+text provided → voice cloning
|
# Clone mode
|
||||||
if args.prompt_audio or args.prompt_text:
|
if args.prompt_audio or args.prompt_text:
|
||||||
if not args.prompt_text and args.prompt_file:
|
|
||||||
assert os.path.isfile(args.prompt_file), "Prompt file does not exist or is not accessible."
|
|
||||||
|
|
||||||
with open(args.prompt_file, 'r', encoding='utf-8') as f:
|
|
||||||
args.prompt_text = f.read()
|
|
||||||
|
|
||||||
if not args.prompt_audio or not args.prompt_text:
|
|
||||||
print("Error: Voice cloning requires both --prompt-audio and --prompt-text", file=sys.stderr)
|
|
||||||
sys.exit(1)
|
|
||||||
return cmd_clone(args)
|
return cmd_clone(args)
|
||||||
|
|
||||||
# Otherwise → direct synthesis
|
# Direct synthesis
|
||||||
return cmd_synthesize(args)
|
return cmd_synthesize(args)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user