Print all log messages to stderr instead of stdout

This commit is contained in:
刘鑫
2026-01-12 15:30:45 +08:00
parent db75a7269b
commit e8dd956fc2
7 changed files with 100 additions and 96 deletions

15
app.py
View File

@ -1,4 +1,5 @@
import os
import sys
import numpy as np
import torch
import gradio as gr
@ -16,7 +17,7 @@ import voxcpm
class VoxCPMDemo: class VoxCPMDemo:
def __init__(self) -> None: def __init__(self) -> None:
self.device = "cuda" if torch.cuda.is_available() else "cpu" self.device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🚀 Running on device: {self.device}") print(f"🚀 Running on device: {self.device}", file=sys.stderr)
# ASR model for prompt text recognition # ASR model for prompt text recognition
self.asr_model_id = "iic/SenseVoiceSmall" self.asr_model_id = "iic/SenseVoiceSmall"
@ -49,10 +50,10 @@ class VoxCPMDemo:
try: try:
from huggingface_hub import snapshot_download # type: ignore from huggingface_hub import snapshot_download # type: ignore
os.makedirs(target_dir, exist_ok=True) os.makedirs(target_dir, exist_ok=True)
print(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...") print(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...", file=sys.stderr)
snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False) snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False)
except Exception as e: except Exception as e:
print(f"Warning: HF download failed: {e}. Falling back to 'data'.") print(f"Warning: HF download failed: {e}. Falling back to 'data'.", file=sys.stderr)
return "models" return "models"
return target_dir return target_dir
return "models" return "models"
@ -60,11 +61,11 @@ class VoxCPMDemo:
def get_or_load_voxcpm(self) -> voxcpm.VoxCPM: def get_or_load_voxcpm(self) -> voxcpm.VoxCPM:
if self.voxcpm_model is not None: if self.voxcpm_model is not None:
return self.voxcpm_model return self.voxcpm_model
print("Model not loaded, initializing...") print("Model not loaded, initializing...", file=sys.stderr)
model_dir = self._resolve_model_dir() model_dir = self._resolve_model_dir()
print(f"Using model dir: {model_dir}") print(f"Using model dir: {model_dir}", file=sys.stderr)
self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir) self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir)
print("Model loaded successfully.") print("Model loaded successfully.", file=sys.stderr)
return self.voxcpm_model return self.voxcpm_model
# ---------- Functional endpoints ---------- # ---------- Functional endpoints ----------
@ -98,7 +99,7 @@ class VoxCPMDemo:
prompt_wav_path = prompt_wav_path_input if prompt_wav_path_input else None prompt_wav_path = prompt_wav_path_input if prompt_wav_path_input else None
prompt_text = prompt_text_input if prompt_text_input else None prompt_text = prompt_text_input if prompt_text_input else None
print(f"Generating audio for text: '{text[:60]}...'") print(f"Generating audio for text: '{text[:60]}...'", file=sys.stderr)
wav = current_model.generate( wav = current_model.generate(
text=text, text=text,
prompt_text=prompt_text, prompt_text=prompt_text,

View File

@ -104,7 +104,7 @@ def get_timestamp_str():
def get_or_load_asr_model(): def get_or_load_asr_model():
global asr_model global asr_model
if asr_model is None: if asr_model is None:
print("Loading ASR model (SenseVoiceSmall)...") print("Loading ASR model (SenseVoiceSmall)...", file=sys.stderr)
device = "cuda:0" if torch.cuda.is_available() else "cpu" device = "cuda:0" if torch.cuda.is_available() else "cpu"
asr_model = AutoModel( asr_model = AutoModel(
model="iic/SenseVoiceSmall", model="iic/SenseVoiceSmall",
@ -123,7 +123,7 @@ def recognize_audio(audio_path):
text = res[0]["text"].split('|>')[-1] text = res[0]["text"].split('|>')[-1]
return text return text
except Exception as e: except Exception as e:
print(f"ASR Error: {e}") print(f"ASR Error: {e}", file=sys.stderr)
return "" return ""
def scan_lora_checkpoints(root_dir="lora", with_info=False): def scan_lora_checkpoints(root_dir="lora", with_info=False):
@ -181,7 +181,7 @@ def load_lora_config_from_checkpoint(lora_path):
if lora_cfg_dict: if lora_cfg_dict:
return LoRAConfig(**lora_cfg_dict), lora_info.get("base_model") return LoRAConfig(**lora_cfg_dict), lora_info.get("base_model")
except Exception as e: except Exception as e:
print(f"Warning: Failed to load lora_config.json: {e}") print(f"Warning: Failed to load lora_config.json: {e}", file=sys.stderr)
return None, None return None, None
def get_default_lora_config(): def get_default_lora_config():
@ -197,7 +197,7 @@ def get_default_lora_config():
def load_model(pretrained_path, lora_path=None): def load_model(pretrained_path, lora_path=None):
global current_model global current_model
print(f"Loading model from {pretrained_path}...") print(f"Loading model from {pretrained_path}...", file=sys.stderr)
lora_config = None lora_config = None
lora_weights_path = None lora_weights_path = None
@ -209,11 +209,11 @@ def load_model(pretrained_path, lora_path=None):
# Try to load LoRA config from lora_config.json # Try to load LoRA config from lora_config.json
lora_config, _ = load_lora_config_from_checkpoint(full_lora_path) lora_config, _ = load_lora_config_from_checkpoint(full_lora_path)
if lora_config: if lora_config:
print(f"Loaded LoRA config from {full_lora_path}/lora_config.json") print(f"Loaded LoRA config from {full_lora_path}/lora_config.json", file=sys.stderr)
else: else:
# Fallback to default config for old checkpoints # Fallback to default config for old checkpoints
lora_config = get_default_lora_config() lora_config = get_default_lora_config()
print("Using default LoRA config (lora_config.json not found)") print("Using default LoRA config (lora_config.json not found)", file=sys.stderr)
# Always init with a default LoRA config to allow hot-swapping later # Always init with a default LoRA config to allow hot-swapping later
if lora_config is None: if lora_config is None:
@ -251,36 +251,36 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
# 优先使用保存的 base_model 路径 # 优先使用保存的 base_model 路径
if os.path.exists(saved_base_model): if os.path.exists(saved_base_model):
base_model_path = saved_base_model base_model_path = saved_base_model
print(f"Using base model from LoRA config: {base_model_path}") print(f"Using base model from LoRA config: {base_model_path}", file=sys.stderr)
else: else:
print(f"Warning: Saved base_model path not found: {saved_base_model}") print(f"Warning: Saved base_model path not found: {saved_base_model}", file=sys.stderr)
print(f"Falling back to default: {base_model_path}") print(f"Falling back to default: {base_model_path}", file=sys.stderr)
except Exception as e: except Exception as e:
print(f"Warning: Failed to read base_model from LoRA config: {e}") print(f"Warning: Failed to read base_model from LoRA config: {e}", file=sys.stderr)
# 加载模型 # 加载模型
try: try:
print(f"Loading base model: {base_model_path}") print(f"Loading base model: {base_model_path}", file=sys.stderr)
status_msg = load_model(base_model_path) status_msg = load_model(base_model_path)
if lora_selection and lora_selection != "None": if lora_selection and lora_selection != "None":
print(f"Model loaded for LoRA: {lora_selection}") print(f"Model loaded for LoRA: {lora_selection}", file=sys.stderr)
except Exception as e: except Exception as e:
error_msg = f"Failed to load model from {base_model_path}: {str(e)}" error_msg = f"Failed to load model from {base_model_path}: {str(e)}"
print(error_msg) print(error_msg, file=sys.stderr)
return None, error_msg return None, error_msg
# Handle LoRA hot-swapping # Handle LoRA hot-swapping
if lora_selection and lora_selection != "None": if lora_selection and lora_selection != "None":
full_lora_path = os.path.join("lora", lora_selection) full_lora_path = os.path.join("lora", lora_selection)
print(f"Hot-loading LoRA: {full_lora_path}") print(f"Hot-loading LoRA: {full_lora_path}", file=sys.stderr)
try: try:
current_model.load_lora(full_lora_path) current_model.load_lora(full_lora_path)
current_model.set_lora_enabled(True) current_model.set_lora_enabled(True)
except Exception as e: except Exception as e:
print(f"Error loading LoRA: {e}") print(f"Error loading LoRA: {e}", file=sys.stderr)
return None, f"Error loading LoRA: {e}" return None, f"Error loading LoRA: {e}"
else: else:
print("Disabling LoRA") print("Disabling LoRA", file=sys.stderr)
current_model.set_lora_enabled(False) current_model.set_lora_enabled(False)
if seed != -1: if seed != -1:
@ -297,11 +297,11 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
# 如果没有提供参考文本,尝试自动识别 # 如果没有提供参考文本,尝试自动识别
if not prompt_text or not prompt_text.strip(): if not prompt_text or not prompt_text.strip():
print("参考音频已提供但缺少文本,自动识别中...") print("参考音频已提供但缺少文本,自动识别中...", file=sys.stderr)
try: try:
final_prompt_text = recognize_audio(prompt_wav) final_prompt_text = recognize_audio(prompt_wav)
if final_prompt_text: if final_prompt_text:
print(f"自动识别文本: {final_prompt_text}") print(f"自动识别文本: {final_prompt_text}", file=sys.stderr)
else: else:
return None, "错误:无法识别参考音频内容,请手动填写参考文本" return None, "错误:无法识别参考音频内容,请手动填写参考文本"
except Exception as e: except Exception as e:
@ -1114,12 +1114,12 @@ with gr.Blocks(
choices = ["None"] + [ckpt[0] for ckpt in checkpoints_with_info] choices = ["None"] + [ckpt[0] for ckpt in checkpoints_with_info]
# 输出调试信息 # 输出调试信息
print(f"刷新 LoRA 列表: 找到 {len(checkpoints_with_info)} 个检查点") print(f"刷新 LoRA 列表: 找到 {len(checkpoints_with_info)} 个检查点", file=sys.stderr)
for ckpt_path, base_model in checkpoints_with_info: for ckpt_path, base_model in checkpoints_with_info:
if base_model: if base_model:
print(f" - {ckpt_path} (Base Model: {base_model})") print(f" - {ckpt_path} (Base Model: {base_model})", file=sys.stderr)
else: else:
print(f" - {ckpt_path}") print(f" - {ckpt_path}", file=sys.stderr)
return gr.update(choices=choices, value="None") return gr.update(choices=choices, value="None")

View File

@ -23,6 +23,7 @@ With voice cloning:
""" """
import argparse import argparse
import sys
from pathlib import Path from pathlib import Path
import soundfile as sf import soundfile as sf
@ -92,7 +93,7 @@ def main():
args = parse_args() args = parse_args()
# Load model from checkpoint directory (no denoiser) # Load model from checkpoint directory (no denoiser)
print(f"[FT Inference] Loading model: {args.ckpt_dir}") print(f"[FT Inference] Loading model: {args.ckpt_dir}", file=sys.stderr)
model = VoxCPM.from_pretrained( model = VoxCPM.from_pretrained(
hf_model_id=args.ckpt_dir, hf_model_id=args.ckpt_dir,
load_denoiser=False, load_denoiser=False,
@ -103,10 +104,10 @@ def main():
prompt_wav_path = args.prompt_audio if args.prompt_audio else None prompt_wav_path = args.prompt_audio if args.prompt_audio else None
prompt_text = args.prompt_text if args.prompt_text else None prompt_text = args.prompt_text if args.prompt_text else None
print(f"[FT Inference] Synthesizing: text='{args.text}'") print(f"[FT Inference] Synthesizing: text='{args.text}'", file=sys.stderr)
if prompt_wav_path: if prompt_wav_path:
print(f"[FT Inference] Using reference audio: {prompt_wav_path}") print(f"[FT Inference] Using reference audio: {prompt_wav_path}", file=sys.stderr)
print(f"[FT Inference] Reference text: {prompt_text}") print(f"[FT Inference] Reference text: {prompt_text}", file=sys.stderr)
audio_np = model.generate( audio_np = model.generate(
text=args.text, text=args.text,
@ -124,7 +125,7 @@ def main():
out_path.parent.mkdir(parents=True, exist_ok=True) out_path.parent.mkdir(parents=True, exist_ok=True)
sf.write(str(out_path), audio_np, model.tts_model.sample_rate) sf.write(str(out_path), audio_np, model.tts_model.sample_rate)
print(f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s") print(f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -24,6 +24,7 @@ Note: The script reads base_model path and lora_config from lora_config.json
import argparse
import json
import sys
from pathlib import Path
import soundfile as sf
@ -124,13 +125,13 @@ def main():
lora_cfg_dict = lora_info.get("lora_config", {}) lora_cfg_dict = lora_info.get("lora_config", {})
lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None
print(f"Loaded config from: {lora_config_path}") print(f"Loaded config from: {lora_config_path}", file=sys.stderr)
print(f" Base model: {pretrained_path}") print(f" Base model: {pretrained_path}", file=sys.stderr)
print(f" LoRA config: r={lora_cfg.r}, alpha={lora_cfg.alpha}" if lora_cfg else " LoRA config: None") print(f" LoRA config: r={lora_cfg.r}, alpha={lora_cfg.alpha}" if lora_cfg else " LoRA config: None", file=sys.stderr)
# 3. Load model with LoRA (no denoiser) # 3. Load model with LoRA (no denoiser)
print(f"\n[1/2] Loading model with LoRA: {pretrained_path}") print(f"\n[1/2] Loading model with LoRA: {pretrained_path}", file=sys.stderr)
print(f" LoRA weights: {ckpt_dir}") print(f" LoRA weights: {ckpt_dir}", file=sys.stderr)
model = VoxCPM.from_pretrained( model = VoxCPM.from_pretrained(
hf_model_id=pretrained_path, hf_model_id=pretrained_path,
load_denoiser=False, load_denoiser=False,
@ -145,10 +146,10 @@ def main():
out_path = Path(args.output) out_path = Path(args.output)
out_path.parent.mkdir(parents=True, exist_ok=True) out_path.parent.mkdir(parents=True, exist_ok=True)
print(f"\n[2/2] Starting synthesis tests...") print(f"\n[2/2] Starting synthesis tests...", file=sys.stderr)
# === Test 1: With LoRA === # === Test 1: With LoRA ===
print(f"\n [Test 1] Synthesize with LoRA...") print(f"\n [Test 1] Synthesize with LoRA...", file=sys.stderr)
audio_np = model.generate( audio_np = model.generate(
text=args.text, text=args.text,
prompt_wav_path=prompt_wav_path, prompt_wav_path=prompt_wav_path,
@ -161,10 +162,10 @@ def main():
) )
lora_output = out_path.with_stem(out_path.stem + "_with_lora") lora_output = out_path.with_stem(out_path.stem + "_with_lora")
sf.write(str(lora_output), audio_np, model.tts_model.sample_rate) sf.write(str(lora_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s") print(f" Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
# === Test 2: Disable LoRA (via set_lora_enabled) === # === Test 2: Disable LoRA (via set_lora_enabled) ===
print(f"\n [Test 2] Disable LoRA (set_lora_enabled=False)...") print(f"\n [Test 2] Disable LoRA (set_lora_enabled=False)...", file=sys.stderr)
model.set_lora_enabled(False) model.set_lora_enabled(False)
audio_np = model.generate( audio_np = model.generate(
text=args.text, text=args.text,
@ -178,10 +179,10 @@ def main():
) )
disabled_output = out_path.with_stem(out_path.stem + "_lora_disabled") disabled_output = out_path.with_stem(out_path.stem + "_lora_disabled")
sf.write(str(disabled_output), audio_np, model.tts_model.sample_rate) sf.write(str(disabled_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s") print(f" Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
# === Test 3: Re-enable LoRA === # === Test 3: Re-enable LoRA ===
print(f"\n [Test 3] Re-enable LoRA (set_lora_enabled=True)...") print(f"\n [Test 3] Re-enable LoRA (set_lora_enabled=True)...", file=sys.stderr)
model.set_lora_enabled(True) model.set_lora_enabled(True)
audio_np = model.generate( audio_np = model.generate(
text=args.text, text=args.text,
@ -195,10 +196,10 @@ def main():
) )
reenabled_output = out_path.with_stem(out_path.stem + "_lora_reenabled") reenabled_output = out_path.with_stem(out_path.stem + "_lora_reenabled")
sf.write(str(reenabled_output), audio_np, model.tts_model.sample_rate) sf.write(str(reenabled_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s") print(f" Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
# === Test 4: Unload LoRA (reset_lora_weights) === # === Test 4: Unload LoRA (reset_lora_weights) ===
print(f"\n [Test 4] Unload LoRA (unload_lora)...") print(f"\n [Test 4] Unload LoRA (unload_lora)...", file=sys.stderr)
model.unload_lora() model.unload_lora()
audio_np = model.generate( audio_np = model.generate(
text=args.text, text=args.text,
@ -212,12 +213,12 @@ def main():
) )
reset_output = out_path.with_stem(out_path.stem + "_lora_reset") reset_output = out_path.with_stem(out_path.stem + "_lora_reset")
sf.write(str(reset_output), audio_np, model.tts_model.sample_rate) sf.write(str(reset_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s") print(f" Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
# === Test 5: Hot-reload LoRA (load_lora) === # === Test 5: Hot-reload LoRA (load_lora) ===
print(f"\n [Test 5] Hot-reload LoRA (load_lora)...") print(f"\n [Test 5] Hot-reload LoRA (load_lora)...", file=sys.stderr)
loaded, skipped = model.load_lora(ckpt_dir) loaded, skipped = model.load_lora(ckpt_dir)
print(f" Reloaded {len(loaded)} parameters") print(f" Reloaded {len(loaded)} parameters", file=sys.stderr)
audio_np = model.generate( audio_np = model.generate(
text=args.text, text=args.text,
prompt_wav_path=prompt_wav_path, prompt_wav_path=prompt_wav_path,
@ -230,14 +231,14 @@ def main():
) )
reload_output = out_path.with_stem(out_path.stem + "_lora_reloaded") reload_output = out_path.with_stem(out_path.stem + "_lora_reloaded")
sf.write(str(reload_output), audio_np, model.tts_model.sample_rate) sf.write(str(reload_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s") print(f" Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
print(f"\n[Done] All tests completed!") print(f"\n[Done] All tests completed!", file=sys.stderr)
print(f" - with_lora: {lora_output}") print(f" - with_lora: {lora_output}", file=sys.stderr)
print(f" - lora_disabled: {disabled_output}") print(f" - lora_disabled: {disabled_output}", file=sys.stderr)
print(f" - lora_reenabled: {reenabled_output}") print(f" - lora_reenabled: {reenabled_output}", file=sys.stderr)
print(f" - lora_reset: {reset_output}") print(f" - lora_reset: {reset_output}", file=sys.stderr)
print(f" - lora_reloaded: {reload_output}") print(f" - lora_reloaded: {reload_output}", file=sys.stderr)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -24,7 +24,7 @@ try:
SAFETENSORS_AVAILABLE = True SAFETENSORS_AVAILABLE = True
except ImportError: except ImportError:
SAFETENSORS_AVAILABLE = False SAFETENSORS_AVAILABLE = False
print("Warning: safetensors not available, will use pytorch format") print("Warning: safetensors not available, will use pytorch format", file=sys.stderr)
from voxcpm.model import VoxCPMModel from voxcpm.model import VoxCPMModel
from voxcpm.model.voxcpm import LoRAConfig from voxcpm.model.voxcpm import LoRAConfig
@ -170,7 +170,7 @@ def train(
# Only print param info on rank 0 to avoid cluttered output # Only print param info on rank 0 to avoid cluttered output
if accelerator.rank == 0: if accelerator.rank == 0:
for name, param in model.named_parameters(): for name, param in model.named_parameters():
print(name, param.requires_grad) print(name, param.requires_grad, file=sys.stderr)
optimizer = AdamW( optimizer = AdamW(
(p for p in model.parameters() if p.requires_grad), (p for p in model.parameters() if p.requires_grad),
@ -210,12 +210,12 @@ def train(
cur_step = int(_resume.get("step", start_step)) cur_step = int(_resume.get("step", start_step))
except Exception: except Exception:
cur_step = start_step cur_step = start_step
print(f"Signal {signum} received. Saving checkpoint at step {cur_step} ...") print(f"Signal {signum} received. Saving checkpoint at step {cur_step} ...", file=sys.stderr)
try: try:
save_checkpoint(_model, _optim, _sched, _save_dir, cur_step, _pretrained, _hf_id, _dist) save_checkpoint(_model, _optim, _sched, _save_dir, cur_step, _pretrained, _hf_id, _dist)
print("Checkpoint saved. Exiting.") print("Checkpoint saved. Exiting.", file=sys.stderr)
except Exception as e: except Exception as e:
print(f"Error saving checkpoint on signal: {e}") print(f"Error saving checkpoint on signal: {e}", file=sys.stderr)
os._exit(0) os._exit(0)
signal.signal(signal.SIGTERM, _signal_handler) signal.signal(signal.SIGTERM, _signal_handler)
@ -553,7 +553,7 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
# Load only lora weights # Load only lora weights
unwrapped.load_state_dict(state_dict, strict=False) unwrapped.load_state_dict(state_dict, strict=False)
print(f"Loaded LoRA weights from {lora_weights_path}") print(f"Loaded LoRA weights from {lora_weights_path}", file=sys.stderr)
else: else:
# Full finetune: load model.safetensors or pytorch_model.bin # Full finetune: load model.safetensors or pytorch_model.bin
model_path = latest_folder / "model.safetensors" model_path = latest_folder / "model.safetensors"
@ -569,26 +569,26 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
state_dict = ckpt.get("state_dict", ckpt) state_dict = ckpt.get("state_dict", ckpt)
unwrapped.load_state_dict(state_dict, strict=False) unwrapped.load_state_dict(state_dict, strict=False)
print(f"Loaded model weights from {model_path}") print(f"Loaded model weights from {model_path}", file=sys.stderr)
# Load optimizer state # Load optimizer state
optimizer_path = latest_folder / "optimizer.pth" optimizer_path = latest_folder / "optimizer.pth"
if optimizer_path.exists(): if optimizer_path.exists():
optimizer.load_state_dict(torch.load(optimizer_path, map_location="cpu")) optimizer.load_state_dict(torch.load(optimizer_path, map_location="cpu"))
print(f"Loaded optimizer state from {optimizer_path}") print(f"Loaded optimizer state from {optimizer_path}", file=sys.stderr)
# Load scheduler state # Load scheduler state
scheduler_path = latest_folder / "scheduler.pth" scheduler_path = latest_folder / "scheduler.pth"
if scheduler_path.exists(): if scheduler_path.exists():
scheduler.load_state_dict(torch.load(scheduler_path, map_location="cpu")) scheduler.load_state_dict(torch.load(scheduler_path, map_location="cpu"))
print(f"Loaded scheduler state from {scheduler_path}") print(f"Loaded scheduler state from {scheduler_path}", file=sys.stderr)
# Try to infer step from checkpoint folders # Try to infer step from checkpoint folders
step_folders = [d for d in save_dir.iterdir() if d.is_dir() and d.name.startswith("step_")] step_folders = [d for d in save_dir.iterdir() if d.is_dir() and d.name.startswith("step_")]
if step_folders: if step_folders:
steps = [int(d.name.split("_")[1]) for d in step_folders] steps = [int(d.name.split("_")[1]) for d in step_folders]
resume_step = max(steps) resume_step = max(steps)
print(f"Resuming from step {resume_step}") print(f"Resuming from step {resume_step}", file=sys.stderr)
return resume_step return resume_step
return 0 return 0
@ -670,7 +670,7 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
latest_link.unlink() latest_link.unlink()
shutil.copytree(folder, latest_link) shutil.copytree(folder, latest_link)
except Exception: except Exception:
print(f"Warning: failed to update latest checkpoint link at {latest_link}") print(f"Warning: failed to update latest checkpoint link at {latest_link}", file=sys.stderr)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -45,7 +45,7 @@ def load_model(args) -> VoxCPM:
Prefer --model-path if provided; otherwise use from_pretrained (Hub). Prefer --model-path if provided; otherwise use from_pretrained (Hub).
""" """
print("Loading VoxCPM model...") print("Loading VoxCPM model...", file=sys.stderr)
# 兼容旧参数ZIPENHANCER_MODEL_PATH 环境变量作为默认 # 兼容旧参数ZIPENHANCER_MODEL_PATH 环境变量作为默认
zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get( zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get(
@ -66,7 +66,7 @@ def load_model(args) -> VoxCPM:
dropout=getattr(args, "lora_dropout", 0.0), dropout=getattr(args, "lora_dropout", 0.0),
) )
print(f"LoRA config: r={lora_config.r}, alpha={lora_config.alpha}, " print(f"LoRA config: r={lora_config.r}, alpha={lora_config.alpha}, "
f"lm={lora_config.enable_lm}, dit={lora_config.enable_dit}, proj={lora_config.enable_proj}") f"lm={lora_config.enable_lm}, dit={lora_config.enable_dit}, proj={lora_config.enable_proj}", file=sys.stderr)
# Load from local path if provided # Load from local path if provided
if getattr(args, "model_path", None): if getattr(args, "model_path", None):
@ -78,10 +78,10 @@ def load_model(args) -> VoxCPM:
lora_config=lora_config, lora_config=lora_config,
lora_weights_path=lora_weights_path, lora_weights_path=lora_weights_path,
) )
print("Model loaded (local).") print("Model loaded (local).", file=sys.stderr)
return model return model
except Exception as e: except Exception as e:
print(f"Failed to load model (local): {e}") print(f"Failed to load model (local): {e}", file=sys.stderr)
sys.exit(1) sys.exit(1)
# Otherwise, try from_pretrained (Hub); exit on failure # Otherwise, try from_pretrained (Hub); exit on failure
@ -95,10 +95,10 @@ def load_model(args) -> VoxCPM:
lora_config=lora_config, lora_config=lora_config,
lora_weights_path=lora_weights_path, lora_weights_path=lora_weights_path,
) )
print("Model loaded (from_pretrained).") print("Model loaded (from_pretrained).", file=sys.stderr)
return model return model
except Exception as e: except Exception as e:
print(f"Failed to load model (from_pretrained): {e}") print(f"Failed to load model (from_pretrained): {e}", file=sys.stderr)
sys.exit(1) sys.exit(1)
@ -106,15 +106,15 @@ def cmd_clone(args):
"""Voice cloning command.""" """Voice cloning command."""
# Validate inputs # Validate inputs
if not args.text: if not args.text:
print("Error: Please provide text to synthesize (--text)") print("Error: Please provide text to synthesize (--text)", file=sys.stderr)
sys.exit(1) sys.exit(1)
if not args.prompt_audio: if not args.prompt_audio:
print("Error: Voice cloning requires a reference audio (--prompt-audio)") print("Error: Voice cloning requires a reference audio (--prompt-audio)", file=sys.stderr)
sys.exit(1) sys.exit(1)
if not args.prompt_text: if not args.prompt_text:
print("Error: Voice cloning requires a reference text (--prompt-text)") print("Error: Voice cloning requires a reference text (--prompt-text)", file=sys.stderr)
sys.exit(1) sys.exit(1)
# Validate files # Validate files
@ -125,9 +125,9 @@ def cmd_clone(args):
model = load_model(args) model = load_model(args)
# Generate audio # Generate audio
print(f"Synthesizing text: {args.text}") print(f"Synthesizing text: {args.text}", file=sys.stderr)
print(f"Reference audio: {prompt_audio_path}") print(f"Reference audio: {prompt_audio_path}", file=sys.stderr)
print(f"Reference text: {args.prompt_text}") print(f"Reference text: {args.prompt_text}", file=sys.stderr)
audio_array = model.generate( audio_array = model.generate(
text=args.text, text=args.text,
@ -141,25 +141,25 @@ def cmd_clone(args):
# Save audio # Save audio
sf.write(str(output_path), audio_array, model.tts_model.sample_rate) sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
print(f"Saved audio to: {output_path}") print(f"Saved audio to: {output_path}", file=sys.stderr)
# Stats # Stats
duration = len(audio_array) / model.tts_model.sample_rate duration = len(audio_array) / model.tts_model.sample_rate
print(f"Duration: {duration:.2f}s") print(f"Duration: {duration:.2f}s", file=sys.stderr)
def cmd_synthesize(args): def cmd_synthesize(args):
"""Direct TTS synthesis command.""" """Direct TTS synthesis command."""
# Validate inputs # Validate inputs
if not args.text: if not args.text:
print("Error: Please provide text to synthesize (--text)") print("Error: Please provide text to synthesize (--text)", file=sys.stderr)
sys.exit(1) sys.exit(1)
# Validate output path # Validate output path
output_path = validate_output_path(args.output) output_path = validate_output_path(args.output)
# Load model # Load model
model = load_model(args) model = load_model(args)
# Generate audio # Generate audio
print(f"Synthesizing text: {args.text}") print(f"Synthesizing text: {args.text}", file=sys.stderr)
audio_array = model.generate( audio_array = model.generate(
text=args.text, text=args.text,
@ -173,11 +173,11 @@ def cmd_synthesize(args):
# Save audio # Save audio
sf.write(str(output_path), audio_array, model.tts_model.sample_rate) sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
print(f"Saved audio to: {output_path}") print(f"Saved audio to: {output_path}", file=sys.stderr)
# Stats # Stats
duration = len(audio_array) / model.tts_model.sample_rate duration = len(audio_array) / model.tts_model.sample_rate
print(f"Duration: {duration:.2f}s") print(f"Duration: {duration:.2f}s", file=sys.stderr)
def cmd_batch(args): def cmd_batch(args):
@ -191,12 +191,12 @@ def cmd_batch(args):
with open(input_file, 'r', encoding='utf-8') as f: with open(input_file, 'r', encoding='utf-8') as f:
texts = [line.strip() for line in f if line.strip()] texts = [line.strip() for line in f if line.strip()]
except Exception as e: except Exception as e:
print(f"Failed to read input file: {e}") print(f"Failed to read input file: {e}", file=sys.stderr)
sys.exit(1) sys.exit(1)
if not texts: if not texts:
print("Error: Input file is empty or contains no valid lines") print("Error: Input file is empty or contains no valid lines", file=sys.stderr)
sys.exit(1) sys.exit(1)
print(f"Found {len(texts)} lines to process") print(f"Found {len(texts)} lines to process", file=sys.stderr)
model = load_model(args) model = load_model(args)
prompt_audio_path = None prompt_audio_path = None
@ -205,7 +205,7 @@ def cmd_batch(args):
success_count = 0 success_count = 0
for i, text in enumerate(texts, 1): for i, text in enumerate(texts, 1):
print(f"\nProcessing {i}/{len(texts)}: {text[:50]}...") print(f"\nProcessing {i}/{len(texts)}: {text[:50]}...", file=sys.stderr)
try: try:
audio_array = model.generate( audio_array = model.generate(
@ -221,14 +221,14 @@ def cmd_batch(args):
sf.write(str(output_file), audio_array, model.tts_model.sample_rate) sf.write(str(output_file), audio_array, model.tts_model.sample_rate)
duration = len(audio_array) / model.tts_model.sample_rate duration = len(audio_array) / model.tts_model.sample_rate
print(f" Saved: {output_file} ({duration:.2f}s)") print(f" Saved: {output_file} ({duration:.2f}s)", file=sys.stderr)
success_count += 1 success_count += 1
except Exception as e: except Exception as e:
print(f" Failed: {e}") print(f" Failed: {e}", file=sys.stderr)
continue continue
print(f"\nBatch finished: {success_count}/{len(texts)} succeeded") print(f"\nBatch finished: {success_count}/{len(texts)} succeeded", file=sys.stderr)
def _build_unified_parser(): def _build_unified_parser():
"""Build unified argument parser (no subcommands, route by args).""" """Build unified argument parser (no subcommands, route by args)."""
@ -296,14 +296,14 @@ def main():
# Routing: prefer batch → single (clone/direct) # Routing: prefer batch → single (clone/direct)
if args.input: if args.input:
if not args.output_dir: if not args.output_dir:
print("Error: Batch mode requires --output-dir") print("Error: Batch mode requires --output-dir", file=sys.stderr)
parser.print_help() parser.print_help()
sys.exit(1) sys.exit(1)
return cmd_batch(args) return cmd_batch(args)
# Single-sample mode # Single-sample mode
if not args.text or not args.output: if not args.text or not args.output:
print("Error: Single-sample mode requires --text and --output") print("Error: Single-sample mode requires --text and --output", file=sys.stderr)
parser.print_help() parser.print_help()
sys.exit(1) sys.exit(1)
@ -316,7 +316,7 @@ def main():
args.prompt_text = f.read() args.prompt_text = f.read()
if not args.prompt_audio or not args.prompt_text: if not args.prompt_audio or not args.prompt_text:
print("Error: Voice cloning requires both --prompt-audio and --prompt-text") print("Error: Voice cloning requires both --prompt-audio and --prompt-text", file=sys.stderr)
sys.exit(1) sys.exit(1)
return cmd_clone(args) return cmd_clone(args)

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import contextlib
import sys
import time
from pathlib import Path
from typing import Dict, Optional
@ -36,7 +37,7 @@ class TrainingTracker:
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
def print(self, message: str): def print(self, message: str):
if self.rank == 0: if self.rank == 0:
print(message, flush=True) print(message, flush=True, file=sys.stderr)
if self.log_file: if self.log_file:
with self.log_file.open("a", encoding="utf-8") as f: with self.log_file.open("a", encoding="utf-8") as f:
f.write(message + "\n") f.write(message + "\n")