Print all log messages to stderr instead of stdout
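The change is mechanical: every diagnostic print gains a `file=sys.stderr` argument, leaving stdout free for actual program output (piped data, shell capture). A minimal sketch of the pattern, with illustrative strings and paths:

    import sys

    # Diagnostics go to stderr; stdout stays reserved for machine-readable output.
    print("Loading model...", file=sys.stderr)  # log line, visible even when stdout is piped
    print("out/result.wav")                     # payload on stdout, safe to pipe or redirect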
@@ -1,4 +1,5 @@
 import os
+import sys
 import numpy as np
 import torch
 import gradio as gr
@@ -16,7 +17,7 @@ import voxcpm
 class VoxCPMDemo:
     def __init__(self) -> None:
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        print(f"🚀 Running on device: {self.device}")
+        print(f"🚀 Running on device: {self.device}", file=sys.stderr)
 
         # ASR model for prompt text recognition
         self.asr_model_id = "iic/SenseVoiceSmall"
@@ -49,10 +50,10 @@ class VoxCPMDemo:
             try:
                 from huggingface_hub import snapshot_download  # type: ignore
                 os.makedirs(target_dir, exist_ok=True)
-                print(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...")
+                print(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...", file=sys.stderr)
                 snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False)
             except Exception as e:
-                print(f"Warning: HF download failed: {e}. Falling back to 'data'.")
+                print(f"Warning: HF download failed: {e}. Falling back to 'data'.", file=sys.stderr)
                 return "models"
             return target_dir
         return "models"
@@ -60,11 +61,11 @@ class VoxCPMDemo:
     def get_or_load_voxcpm(self) -> voxcpm.VoxCPM:
         if self.voxcpm_model is not None:
             return self.voxcpm_model
-        print("Model not loaded, initializing...")
+        print("Model not loaded, initializing...", file=sys.stderr)
         model_dir = self._resolve_model_dir()
-        print(f"Using model dir: {model_dir}")
+        print(f"Using model dir: {model_dir}", file=sys.stderr)
         self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir)
-        print("Model loaded successfully.")
+        print("Model loaded successfully.", file=sys.stderr)
         return self.voxcpm_model
 
     # ---------- Functional endpoints ----------
@@ -98,7 +99,7 @@ class VoxCPMDemo:
         prompt_wav_path = prompt_wav_path_input if prompt_wav_path_input else None
         prompt_text = prompt_text_input if prompt_text_input else None
 
-        print(f"Generating audio for text: '{text[:60]}...'")
+        print(f"Generating audio for text: '{text[:60]}...'", file=sys.stderr)
         wav = current_model.generate(
             text=text,
             prompt_text=prompt_text,
@@ -104,7 +104,7 @@ def get_timestamp_str():
 def get_or_load_asr_model():
     global asr_model
     if asr_model is None:
-        print("Loading ASR model (SenseVoiceSmall)...")
+        print("Loading ASR model (SenseVoiceSmall)...", file=sys.stderr)
         device = "cuda:0" if torch.cuda.is_available() else "cpu"
         asr_model = AutoModel(
             model="iic/SenseVoiceSmall",
@@ -123,7 +123,7 @@ def recognize_audio(audio_path):
         text = res[0]["text"].split('|>')[-1]
         return text
     except Exception as e:
-        print(f"ASR Error: {e}")
+        print(f"ASR Error: {e}", file=sys.stderr)
         return ""
 
 def scan_lora_checkpoints(root_dir="lora", with_info=False):
@@ -181,7 +181,7 @@ def load_lora_config_from_checkpoint(lora_path):
         if lora_cfg_dict:
             return LoRAConfig(**lora_cfg_dict), lora_info.get("base_model")
     except Exception as e:
-        print(f"Warning: Failed to load lora_config.json: {e}")
+        print(f"Warning: Failed to load lora_config.json: {e}", file=sys.stderr)
     return None, None
 
 def get_default_lora_config():
@@ -197,7 +197,7 @@ def get_default_lora_config():
 
 def load_model(pretrained_path, lora_path=None):
     global current_model
-    print(f"Loading model from {pretrained_path}...")
+    print(f"Loading model from {pretrained_path}...", file=sys.stderr)
 
     lora_config = None
     lora_weights_path = None
@@ -209,11 +209,11 @@ def load_model(pretrained_path, lora_path=None):
         # Try to load LoRA config from lora_config.json
         lora_config, _ = load_lora_config_from_checkpoint(full_lora_path)
         if lora_config:
-            print(f"Loaded LoRA config from {full_lora_path}/lora_config.json")
+            print(f"Loaded LoRA config from {full_lora_path}/lora_config.json", file=sys.stderr)
         else:
             # Fallback to default config for old checkpoints
             lora_config = get_default_lora_config()
-            print("Using default LoRA config (lora_config.json not found)")
+            print("Using default LoRA config (lora_config.json not found)", file=sys.stderr)
 
     # Always init with a default LoRA config to allow hot-swapping later
     if lora_config is None:
@@ -251,36 +251,36 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
             # Prefer the base_model path saved in the LoRA config
             if os.path.exists(saved_base_model):
                 base_model_path = saved_base_model
-                print(f"Using base model from LoRA config: {base_model_path}")
+                print(f"Using base model from LoRA config: {base_model_path}", file=sys.stderr)
             else:
-                print(f"Warning: Saved base_model path not found: {saved_base_model}")
-                print(f"Falling back to default: {base_model_path}")
+                print(f"Warning: Saved base_model path not found: {saved_base_model}", file=sys.stderr)
+                print(f"Falling back to default: {base_model_path}", file=sys.stderr)
         except Exception as e:
-            print(f"Warning: Failed to read base_model from LoRA config: {e}")
+            print(f"Warning: Failed to read base_model from LoRA config: {e}", file=sys.stderr)
 
     # Load the model
     try:
-        print(f"Loading base model: {base_model_path}")
+        print(f"Loading base model: {base_model_path}", file=sys.stderr)
         status_msg = load_model(base_model_path)
         if lora_selection and lora_selection != "None":
-            print(f"Model loaded for LoRA: {lora_selection}")
+            print(f"Model loaded for LoRA: {lora_selection}", file=sys.stderr)
     except Exception as e:
         error_msg = f"Failed to load model from {base_model_path}: {str(e)}"
-        print(error_msg)
+        print(error_msg, file=sys.stderr)
         return None, error_msg
 
     # Handle LoRA hot-swapping
     if lora_selection and lora_selection != "None":
         full_lora_path = os.path.join("lora", lora_selection)
-        print(f"Hot-loading LoRA: {full_lora_path}")
+        print(f"Hot-loading LoRA: {full_lora_path}", file=sys.stderr)
         try:
             current_model.load_lora(full_lora_path)
             current_model.set_lora_enabled(True)
         except Exception as e:
-            print(f"Error loading LoRA: {e}")
+            print(f"Error loading LoRA: {e}", file=sys.stderr)
             return None, f"Error loading LoRA: {e}"
     else:
-        print("Disabling LoRA")
+        print("Disabling LoRA", file=sys.stderr)
         current_model.set_lora_enabled(False)
 
     if seed != -1:
@@ -297,11 +297,11 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
 
     # If no reference text was provided, try to recognize it automatically
     if not prompt_text or not prompt_text.strip():
-        print("Reference audio provided but no text; recognizing automatically...")
+        print("Reference audio provided but no text; recognizing automatically...", file=sys.stderr)
         try:
             final_prompt_text = recognize_audio(prompt_wav)
             if final_prompt_text:
-                print(f"Auto-recognized text: {final_prompt_text}")
+                print(f"Auto-recognized text: {final_prompt_text}", file=sys.stderr)
             else:
                 return None, "Error: could not recognize the reference audio; please enter the reference text manually"
         except Exception as e:
@@ -1114,12 +1114,12 @@ with gr.Blocks(
         choices = ["None"] + [ckpt[0] for ckpt in checkpoints_with_info]
 
         # Print debug info
-        print(f"Refreshing LoRA list: found {len(checkpoints_with_info)} checkpoints")
+        print(f"Refreshing LoRA list: found {len(checkpoints_with_info)} checkpoints", file=sys.stderr)
         for ckpt_path, base_model in checkpoints_with_info:
             if base_model:
-                print(f" - {ckpt_path} (Base Model: {base_model})")
+                print(f" - {ckpt_path} (Base Model: {base_model})", file=sys.stderr)
             else:
-                print(f" - {ckpt_path}")
+                print(f" - {ckpt_path}", file=sys.stderr)
 
         return gr.update(choices=choices, value="None")
 
@@ -23,6 +23,7 @@ With voice cloning:
 """
 
 import argparse
+import sys
 from pathlib import Path
 
 import soundfile as sf
@@ -92,7 +93,7 @@ def main():
     args = parse_args()
 
     # Load model from checkpoint directory (no denoiser)
-    print(f"[FT Inference] Loading model: {args.ckpt_dir}")
+    print(f"[FT Inference] Loading model: {args.ckpt_dir}", file=sys.stderr)
     model = VoxCPM.from_pretrained(
         hf_model_id=args.ckpt_dir,
         load_denoiser=False,
@@ -103,10 +104,10 @@ def main():
     prompt_wav_path = args.prompt_audio if args.prompt_audio else None
     prompt_text = args.prompt_text if args.prompt_text else None
 
-    print(f"[FT Inference] Synthesizing: text='{args.text}'")
+    print(f"[FT Inference] Synthesizing: text='{args.text}'", file=sys.stderr)
     if prompt_wav_path:
-        print(f"[FT Inference] Using reference audio: {prompt_wav_path}")
-        print(f"[FT Inference] Reference text: {prompt_text}")
+        print(f"[FT Inference] Using reference audio: {prompt_wav_path}", file=sys.stderr)
+        print(f"[FT Inference] Reference text: {prompt_text}", file=sys.stderr)
 
     audio_np = model.generate(
         text=args.text,
@@ -124,7 +125,7 @@ def main():
     out_path.parent.mkdir(parents=True, exist_ok=True)
     sf.write(str(out_path), audio_np, model.tts_model.sample_rate)
 
-    print(f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
+    print(f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
 
 
 if __name__ == "__main__":
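Because the scripts in this commit now log to stderr only, a caller can capture diagnostics separately from any stdout payload. A hedged sketch of that benefit (the script name is hypothetical):

    import subprocess
    import sys

    # Run one of the scripts above; its logs arrive on result.stderr,
    # while result.stdout stays free for piped data.
    result = subprocess.run(
        [sys.executable, "infer.py", "--text", "hello"],  # hypothetical script name
        capture_output=True,
        text=True,
    )
    sys.stderr.write(result.stderr)  # forward the captured logs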
@@ -24,6 +24,7 @@ Note: The script reads base_model path and lora_config from lora_config.json
 
 import argparse
 import json
+import sys
 from pathlib import Path
 
 import soundfile as sf
@@ -124,13 +125,13 @@ def main():
     lora_cfg_dict = lora_info.get("lora_config", {})
     lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None
 
-    print(f"Loaded config from: {lora_config_path}")
-    print(f" Base model: {pretrained_path}")
-    print(f" LoRA config: r={lora_cfg.r}, alpha={lora_cfg.alpha}" if lora_cfg else " LoRA config: None")
+    print(f"Loaded config from: {lora_config_path}", file=sys.stderr)
+    print(f" Base model: {pretrained_path}", file=sys.stderr)
+    print(f" LoRA config: r={lora_cfg.r}, alpha={lora_cfg.alpha}" if lora_cfg else " LoRA config: None", file=sys.stderr)
 
     # 3. Load model with LoRA (no denoiser)
-    print(f"\n[1/2] Loading model with LoRA: {pretrained_path}")
-    print(f" LoRA weights: {ckpt_dir}")
+    print(f"\n[1/2] Loading model with LoRA: {pretrained_path}", file=sys.stderr)
+    print(f" LoRA weights: {ckpt_dir}", file=sys.stderr)
     model = VoxCPM.from_pretrained(
         hf_model_id=pretrained_path,
         load_denoiser=False,
@@ -145,10 +146,10 @@ def main():
     out_path = Path(args.output)
     out_path.parent.mkdir(parents=True, exist_ok=True)
 
-    print(f"\n[2/2] Starting synthesis tests...")
+    print(f"\n[2/2] Starting synthesis tests...", file=sys.stderr)
 
     # === Test 1: With LoRA ===
-    print(f"\n [Test 1] Synthesize with LoRA...")
+    print(f"\n [Test 1] Synthesize with LoRA...", file=sys.stderr)
     audio_np = model.generate(
         text=args.text,
         prompt_wav_path=prompt_wav_path,
@@ -161,10 +162,10 @@ def main():
     )
     lora_output = out_path.with_stem(out_path.stem + "_with_lora")
     sf.write(str(lora_output), audio_np, model.tts_model.sample_rate)
-    print(f" Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
+    print(f" Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
 
     # === Test 2: Disable LoRA (via set_lora_enabled) ===
-    print(f"\n [Test 2] Disable LoRA (set_lora_enabled=False)...")
+    print(f"\n [Test 2] Disable LoRA (set_lora_enabled=False)...", file=sys.stderr)
     model.set_lora_enabled(False)
     audio_np = model.generate(
         text=args.text,
@@ -178,10 +179,10 @@ def main():
     )
     disabled_output = out_path.with_stem(out_path.stem + "_lora_disabled")
     sf.write(str(disabled_output), audio_np, model.tts_model.sample_rate)
-    print(f" Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
+    print(f" Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
 
     # === Test 3: Re-enable LoRA ===
-    print(f"\n [Test 3] Re-enable LoRA (set_lora_enabled=True)...")
+    print(f"\n [Test 3] Re-enable LoRA (set_lora_enabled=True)...", file=sys.stderr)
     model.set_lora_enabled(True)
     audio_np = model.generate(
         text=args.text,
@@ -195,10 +196,10 @@ def main():
     )
     reenabled_output = out_path.with_stem(out_path.stem + "_lora_reenabled")
     sf.write(str(reenabled_output), audio_np, model.tts_model.sample_rate)
-    print(f" Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
+    print(f" Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
 
     # === Test 4: Unload LoRA (reset_lora_weights) ===
-    print(f"\n [Test 4] Unload LoRA (unload_lora)...")
+    print(f"\n [Test 4] Unload LoRA (unload_lora)...", file=sys.stderr)
     model.unload_lora()
     audio_np = model.generate(
         text=args.text,
@@ -212,12 +213,12 @@ def main():
     )
     reset_output = out_path.with_stem(out_path.stem + "_lora_reset")
     sf.write(str(reset_output), audio_np, model.tts_model.sample_rate)
-    print(f" Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
+    print(f" Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
 
     # === Test 5: Hot-reload LoRA (load_lora) ===
-    print(f"\n [Test 5] Hot-reload LoRA (load_lora)...")
+    print(f"\n [Test 5] Hot-reload LoRA (load_lora)...", file=sys.stderr)
     loaded, skipped = model.load_lora(ckpt_dir)
-    print(f" Reloaded {len(loaded)} parameters")
+    print(f" Reloaded {len(loaded)} parameters", file=sys.stderr)
     audio_np = model.generate(
         text=args.text,
         prompt_wav_path=prompt_wav_path,
@@ -230,14 +231,14 @@ def main():
     )
     reload_output = out_path.with_stem(out_path.stem + "_lora_reloaded")
     sf.write(str(reload_output), audio_np, model.tts_model.sample_rate)
-    print(f" Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
+    print(f" Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
 
-    print(f"\n[Done] All tests completed!")
-    print(f" - with_lora: {lora_output}")
-    print(f" - lora_disabled: {disabled_output}")
-    print(f" - lora_reenabled: {reenabled_output}")
-    print(f" - lora_reset: {reset_output}")
-    print(f" - lora_reloaded: {reload_output}")
+    print(f"\n[Done] All tests completed!", file=sys.stderr)
+    print(f" - with_lora: {lora_output}", file=sys.stderr)
+    print(f" - lora_disabled: {disabled_output}", file=sys.stderr)
+    print(f" - lora_reenabled: {reenabled_output}", file=sys.stderr)
+    print(f" - lora_reset: {reset_output}", file=sys.stderr)
+    print(f" - lora_reloaded: {reload_output}", file=sys.stderr)
 
 
 if __name__ == "__main__":
@@ -24,7 +24,7 @@ try:
     SAFETENSORS_AVAILABLE = True
 except ImportError:
     SAFETENSORS_AVAILABLE = False
-    print("Warning: safetensors not available, will use pytorch format")
+    print("Warning: safetensors not available, will use pytorch format", file=sys.stderr)
 
 from voxcpm.model import VoxCPMModel
 from voxcpm.model.voxcpm import LoRAConfig
@@ -170,7 +170,7 @@ def train(
     # Only print param info on rank 0 to avoid cluttered output
     if accelerator.rank == 0:
         for name, param in model.named_parameters():
-            print(name, param.requires_grad)
+            print(name, param.requires_grad, file=sys.stderr)
 
     optimizer = AdamW(
         (p for p in model.parameters() if p.requires_grad),
@@ -210,12 +210,12 @@ def train(
             cur_step = int(_resume.get("step", start_step))
         except Exception:
             cur_step = start_step
-        print(f"Signal {signum} received. Saving checkpoint at step {cur_step} ...")
+        print(f"Signal {signum} received. Saving checkpoint at step {cur_step} ...", file=sys.stderr)
         try:
             save_checkpoint(_model, _optim, _sched, _save_dir, cur_step, _pretrained, _hf_id, _dist)
-            print("Checkpoint saved. Exiting.")
+            print("Checkpoint saved. Exiting.", file=sys.stderr)
         except Exception as e:
-            print(f"Error saving checkpoint on signal: {e}")
+            print(f"Error saving checkpoint on signal: {e}", file=sys.stderr)
         os._exit(0)
 
     signal.signal(signal.SIGTERM, _signal_handler)
@@ -553,7 +553,7 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
 
         # Load only lora weights
         unwrapped.load_state_dict(state_dict, strict=False)
-        print(f"Loaded LoRA weights from {lora_weights_path}")
+        print(f"Loaded LoRA weights from {lora_weights_path}", file=sys.stderr)
     else:
         # Full finetune: load model.safetensors or pytorch_model.bin
         model_path = latest_folder / "model.safetensors"
@@ -569,26 +569,26 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
         state_dict = ckpt.get("state_dict", ckpt)
 
         unwrapped.load_state_dict(state_dict, strict=False)
-        print(f"Loaded model weights from {model_path}")
+        print(f"Loaded model weights from {model_path}", file=sys.stderr)
 
     # Load optimizer state
     optimizer_path = latest_folder / "optimizer.pth"
     if optimizer_path.exists():
         optimizer.load_state_dict(torch.load(optimizer_path, map_location="cpu"))
-        print(f"Loaded optimizer state from {optimizer_path}")
+        print(f"Loaded optimizer state from {optimizer_path}", file=sys.stderr)
 
     # Load scheduler state
     scheduler_path = latest_folder / "scheduler.pth"
     if scheduler_path.exists():
         scheduler.load_state_dict(torch.load(scheduler_path, map_location="cpu"))
-        print(f"Loaded scheduler state from {scheduler_path}")
+        print(f"Loaded scheduler state from {scheduler_path}", file=sys.stderr)
 
     # Try to infer step from checkpoint folders
     step_folders = [d for d in save_dir.iterdir() if d.is_dir() and d.name.startswith("step_")]
     if step_folders:
         steps = [int(d.name.split("_")[1]) for d in step_folders]
         resume_step = max(steps)
-        print(f"Resuming from step {resume_step}")
+        print(f"Resuming from step {resume_step}", file=sys.stderr)
         return resume_step
 
     return 0
@@ -670,7 +670,7 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
             latest_link.unlink()
         shutil.copytree(folder, latest_link)
     except Exception:
-        print(f"Warning: failed to update latest checkpoint link at {latest_link}")
+        print(f"Warning: failed to update latest checkpoint link at {latest_link}", file=sys.stderr)
 
 
 if __name__ == "__main__":
@@ -45,7 +45,7 @@ def load_model(args) -> VoxCPM:
 
     Prefer --model-path if provided; otherwise use from_pretrained (Hub).
     """
-    print("Loading VoxCPM model...")
+    print("Loading VoxCPM model...", file=sys.stderr)
 
     # Backward compatibility: the ZIPENHANCER_MODEL_PATH environment variable serves as the default
     zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get(
@@ -66,7 +66,7 @@ def load_model(args) -> VoxCPM:
         dropout=getattr(args, "lora_dropout", 0.0),
     )
     print(f"LoRA config: r={lora_config.r}, alpha={lora_config.alpha}, "
-          f"lm={lora_config.enable_lm}, dit={lora_config.enable_dit}, proj={lora_config.enable_proj}")
+          f"lm={lora_config.enable_lm}, dit={lora_config.enable_dit}, proj={lora_config.enable_proj}", file=sys.stderr)
 
     # Load from local path if provided
     if getattr(args, "model_path", None):
@@ -78,10 +78,10 @@ def load_model(args) -> VoxCPM:
             lora_config=lora_config,
             lora_weights_path=lora_weights_path,
         )
-        print("Model loaded (local).")
+        print("Model loaded (local).", file=sys.stderr)
         return model
     except Exception as e:
-        print(f"Failed to load model (local): {e}")
+        print(f"Failed to load model (local): {e}", file=sys.stderr)
         sys.exit(1)
 
     # Otherwise, try from_pretrained (Hub); exit on failure
@@ -95,10 +95,10 @@ def load_model(args) -> VoxCPM:
             lora_config=lora_config,
             lora_weights_path=lora_weights_path,
         )
-        print("Model loaded (from_pretrained).")
+        print("Model loaded (from_pretrained).", file=sys.stderr)
         return model
     except Exception as e:
-        print(f"Failed to load model (from_pretrained): {e}")
+        print(f"Failed to load model (from_pretrained): {e}", file=sys.stderr)
         sys.exit(1)
 
 
@@ -106,15 +106,15 @@ def cmd_clone(args):
     """Voice cloning command."""
     # Validate inputs
     if not args.text:
-        print("Error: Please provide text to synthesize (--text)")
+        print("Error: Please provide text to synthesize (--text)", file=sys.stderr)
         sys.exit(1)
 
     if not args.prompt_audio:
-        print("Error: Voice cloning requires a reference audio (--prompt-audio)")
+        print("Error: Voice cloning requires a reference audio (--prompt-audio)", file=sys.stderr)
         sys.exit(1)
 
     if not args.prompt_text:
-        print("Error: Voice cloning requires a reference text (--prompt-text)")
+        print("Error: Voice cloning requires a reference text (--prompt-text)", file=sys.stderr)
         sys.exit(1)
 
     # Validate files
@@ -125,9 +125,9 @@ def cmd_clone(args):
     model = load_model(args)
 
     # Generate audio
-    print(f"Synthesizing text: {args.text}")
-    print(f"Reference audio: {prompt_audio_path}")
-    print(f"Reference text: {args.prompt_text}")
+    print(f"Synthesizing text: {args.text}", file=sys.stderr)
+    print(f"Reference audio: {prompt_audio_path}", file=sys.stderr)
+    print(f"Reference text: {args.prompt_text}", file=sys.stderr)
 
     audio_array = model.generate(
         text=args.text,
@@ -141,25 +141,25 @@ def cmd_clone(args):
 
     # Save audio
     sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
-    print(f"Saved audio to: {output_path}")
+    print(f"Saved audio to: {output_path}", file=sys.stderr)
 
     # Stats
     duration = len(audio_array) / model.tts_model.sample_rate
-    print(f"Duration: {duration:.2f}s")
+    print(f"Duration: {duration:.2f}s", file=sys.stderr)
 
 
 def cmd_synthesize(args):
     """Direct TTS synthesis command."""
     # Validate inputs
     if not args.text:
-        print("Error: Please provide text to synthesize (--text)")
+        print("Error: Please provide text to synthesize (--text)", file=sys.stderr)
         sys.exit(1)
     # Validate output path
     output_path = validate_output_path(args.output)
     # Load model
     model = load_model(args)
     # Generate audio
-    print(f"Synthesizing text: {args.text}")
+    print(f"Synthesizing text: {args.text}", file=sys.stderr)
 
     audio_array = model.generate(
         text=args.text,
@@ -173,11 +173,11 @@ def cmd_synthesize(args):
 
     # Save audio
     sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
-    print(f"Saved audio to: {output_path}")
+    print(f"Saved audio to: {output_path}", file=sys.stderr)
 
     # Stats
     duration = len(audio_array) / model.tts_model.sample_rate
-    print(f"Duration: {duration:.2f}s")
+    print(f"Duration: {duration:.2f}s", file=sys.stderr)
 
 
 def cmd_batch(args):
@@ -191,12 +191,12 @@ def cmd_batch(args):
         with open(input_file, 'r', encoding='utf-8') as f:
             texts = [line.strip() for line in f if line.strip()]
     except Exception as e:
-        print(f"Failed to read input file: {e}")
+        print(f"Failed to read input file: {e}", file=sys.stderr)
         sys.exit(1)
     if not texts:
-        print("Error: Input file is empty or contains no valid lines")
+        print("Error: Input file is empty or contains no valid lines", file=sys.stderr)
         sys.exit(1)
-    print(f"Found {len(texts)} lines to process")
+    print(f"Found {len(texts)} lines to process", file=sys.stderr)
 
     model = load_model(args)
     prompt_audio_path = None
@@ -205,7 +205,7 @@ def cmd_batch(args):
 
     success_count = 0
     for i, text in enumerate(texts, 1):
-        print(f"\nProcessing {i}/{len(texts)}: {text[:50]}...")
+        print(f"\nProcessing {i}/{len(texts)}: {text[:50]}...", file=sys.stderr)
 
         try:
             audio_array = model.generate(
@@ -221,14 +221,14 @@ def cmd_batch(args):
             sf.write(str(output_file), audio_array, model.tts_model.sample_rate)
 
             duration = len(audio_array) / model.tts_model.sample_rate
-            print(f" Saved: {output_file} ({duration:.2f}s)")
+            print(f" Saved: {output_file} ({duration:.2f}s)", file=sys.stderr)
             success_count += 1
 
         except Exception as e:
-            print(f" Failed: {e}")
+            print(f" Failed: {e}", file=sys.stderr)
             continue
 
-    print(f"\nBatch finished: {success_count}/{len(texts)} succeeded")
+    print(f"\nBatch finished: {success_count}/{len(texts)} succeeded", file=sys.stderr)
 
 def _build_unified_parser():
     """Build unified argument parser (no subcommands, route by args)."""
@@ -296,14 +296,14 @@ def main():
     # Routing: prefer batch → single (clone/direct)
     if args.input:
         if not args.output_dir:
-            print("Error: Batch mode requires --output-dir")
+            print("Error: Batch mode requires --output-dir", file=sys.stderr)
             parser.print_help()
             sys.exit(1)
         return cmd_batch(args)
 
     # Single-sample mode
     if not args.text or not args.output:
-        print("Error: Single-sample mode requires --text and --output")
+        print("Error: Single-sample mode requires --text and --output", file=sys.stderr)
         parser.print_help()
         sys.exit(1)
 
@@ -316,7 +316,7 @@ def main():
             args.prompt_text = f.read()
 
         if not args.prompt_audio or not args.prompt_text:
-            print("Error: Voice cloning requires both --prompt-audio and --prompt-text")
+            print("Error: Voice cloning requires both --prompt-audio and --prompt-text", file=sys.stderr)
             sys.exit(1)
         return cmd_clone(args)
 
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import contextlib
+import sys
 import time
 from pathlib import Path
 from typing import Dict, Optional
@@ -36,7 +37,7 @@ class TrainingTracker:
     # ------------------------------------------------------------------ #
     def print(self, message: str):
         if self.rank == 0:
-            print(message, flush=True)
+            print(message, flush=True, file=sys.stderr)
         if self.log_file:
             with self.log_file.open("a", encoding="utf-8") as f:
                 f.write(message + "\n")
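Every call site now repeats the same `file=sys.stderr` argument; a follow-up could centralize it in one helper. A sketch only — the `log` name is hypothetical and not part of this commit:

    import functools
    import sys

    # Hypothetical helper: a print bound to stderr once, reused everywhere.
    log = functools.partial(print, file=sys.stderr)

    log("Loading model...")  # same effect as print("Loading model...", file=sys.stderr)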