Print all log messages to stderr instead of stdout
This commit is contained in:
@ -23,6 +23,7 @@ With voice cloning:
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import soundfile as sf
|
||||
@ -92,7 +93,7 @@ def main():
|
||||
args = parse_args()
|
||||
|
||||
# Load model from checkpoint directory (no denoiser)
|
||||
print(f"[FT Inference] Loading model: {args.ckpt_dir}")
|
||||
print(f"[FT Inference] Loading model: {args.ckpt_dir}", file=sys.stderr)
|
||||
model = VoxCPM.from_pretrained(
|
||||
hf_model_id=args.ckpt_dir,
|
||||
load_denoiser=False,
|
||||
@ -103,10 +104,10 @@ def main():
|
||||
prompt_wav_path = args.prompt_audio if args.prompt_audio else None
|
||||
prompt_text = args.prompt_text if args.prompt_text else None
|
||||
|
||||
print(f"[FT Inference] Synthesizing: text='{args.text}'")
|
||||
print(f"[FT Inference] Synthesizing: text='{args.text}'", file=sys.stderr)
|
||||
if prompt_wav_path:
|
||||
print(f"[FT Inference] Using reference audio: {prompt_wav_path}")
|
||||
print(f"[FT Inference] Reference text: {prompt_text}")
|
||||
print(f"[FT Inference] Using reference audio: {prompt_wav_path}", file=sys.stderr)
|
||||
print(f"[FT Inference] Reference text: {prompt_text}", file=sys.stderr)
|
||||
|
||||
audio_np = model.generate(
|
||||
text=args.text,
|
||||
@ -124,7 +125,7 @@ def main():
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
sf.write(str(out_path), audio_np, model.tts_model.sample_rate)
|
||||
|
||||
print(f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
|
||||
print(f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -24,6 +24,7 @@ Note: The script reads base_model path and lora_config from lora_config.json
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import soundfile as sf
|
||||
@ -124,13 +125,13 @@ def main():
|
||||
lora_cfg_dict = lora_info.get("lora_config", {})
|
||||
lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None
|
||||
|
||||
print(f"Loaded config from: {lora_config_path}")
|
||||
print(f" Base model: {pretrained_path}")
|
||||
print(f" LoRA config: r={lora_cfg.r}, alpha={lora_cfg.alpha}" if lora_cfg else " LoRA config: None")
|
||||
print(f"Loaded config from: {lora_config_path}", file=sys.stderr)
|
||||
print(f" Base model: {pretrained_path}", file=sys.stderr)
|
||||
print(f" LoRA config: r={lora_cfg.r}, alpha={lora_cfg.alpha}" if lora_cfg else " LoRA config: None", file=sys.stderr)
|
||||
|
||||
# 3. Load model with LoRA (no denoiser)
|
||||
print(f"\n[1/2] Loading model with LoRA: {pretrained_path}")
|
||||
print(f" LoRA weights: {ckpt_dir}")
|
||||
print(f"\n[1/2] Loading model with LoRA: {pretrained_path}", file=sys.stderr)
|
||||
print(f" LoRA weights: {ckpt_dir}", file=sys.stderr)
|
||||
model = VoxCPM.from_pretrained(
|
||||
hf_model_id=pretrained_path,
|
||||
load_denoiser=False,
|
||||
@ -145,10 +146,10 @@ def main():
|
||||
out_path = Path(args.output)
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"\n[2/2] Starting synthesis tests...")
|
||||
print(f"\n[2/2] Starting synthesis tests...", file=sys.stderr)
|
||||
|
||||
# === Test 1: With LoRA ===
|
||||
print(f"\n [Test 1] Synthesize with LoRA...")
|
||||
print(f"\n [Test 1] Synthesize with LoRA...", file=sys.stderr)
|
||||
audio_np = model.generate(
|
||||
text=args.text,
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
@ -161,10 +162,10 @@ def main():
|
||||
)
|
||||
lora_output = out_path.with_stem(out_path.stem + "_with_lora")
|
||||
sf.write(str(lora_output), audio_np, model.tts_model.sample_rate)
|
||||
print(f" Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
|
||||
print(f" Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
||||
|
||||
# === Test 2: Disable LoRA (via set_lora_enabled) ===
|
||||
print(f"\n [Test 2] Disable LoRA (set_lora_enabled=False)...")
|
||||
print(f"\n [Test 2] Disable LoRA (set_lora_enabled=False)...", file=sys.stderr)
|
||||
model.set_lora_enabled(False)
|
||||
audio_np = model.generate(
|
||||
text=args.text,
|
||||
@ -178,10 +179,10 @@ def main():
|
||||
)
|
||||
disabled_output = out_path.with_stem(out_path.stem + "_lora_disabled")
|
||||
sf.write(str(disabled_output), audio_np, model.tts_model.sample_rate)
|
||||
print(f" Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
|
||||
print(f" Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
||||
|
||||
# === Test 3: Re-enable LoRA ===
|
||||
print(f"\n [Test 3] Re-enable LoRA (set_lora_enabled=True)...")
|
||||
print(f"\n [Test 3] Re-enable LoRA (set_lora_enabled=True)...", file=sys.stderr)
|
||||
model.set_lora_enabled(True)
|
||||
audio_np = model.generate(
|
||||
text=args.text,
|
||||
@ -195,10 +196,10 @@ def main():
|
||||
)
|
||||
reenabled_output = out_path.with_stem(out_path.stem + "_lora_reenabled")
|
||||
sf.write(str(reenabled_output), audio_np, model.tts_model.sample_rate)
|
||||
print(f" Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
|
||||
print(f" Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
||||
|
||||
# === Test 4: Unload LoRA (unload_lora) ===
|
||||
print(f"\n [Test 4] Unload LoRA (unload_lora)...")
|
||||
print(f"\n [Test 4] Unload LoRA (unload_lora)...", file=sys.stderr)
|
||||
model.unload_lora()
|
||||
audio_np = model.generate(
|
||||
text=args.text,
|
||||
@ -212,12 +213,12 @@ def main():
|
||||
)
|
||||
reset_output = out_path.with_stem(out_path.stem + "_lora_reset")
|
||||
sf.write(str(reset_output), audio_np, model.tts_model.sample_rate)
|
||||
print(f" Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
|
||||
print(f" Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
||||
|
||||
# === Test 5: Hot-reload LoRA (load_lora) ===
|
||||
print(f"\n [Test 5] Hot-reload LoRA (load_lora)...")
|
||||
print(f"\n [Test 5] Hot-reload LoRA (load_lora)...", file=sys.stderr)
|
||||
loaded, skipped = model.load_lora(ckpt_dir)
|
||||
print(f" Reloaded {len(loaded)} parameters")
|
||||
print(f" Reloaded {len(loaded)} parameters", file=sys.stderr)
|
||||
audio_np = model.generate(
|
||||
text=args.text,
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
@ -230,14 +231,14 @@ def main():
|
||||
)
|
||||
reload_output = out_path.with_stem(out_path.stem + "_lora_reloaded")
|
||||
sf.write(str(reload_output), audio_np, model.tts_model.sample_rate)
|
||||
print(f" Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
|
||||
print(f" Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
||||
|
||||
print(f"\n[Done] All tests completed!")
|
||||
print(f" - with_lora: {lora_output}")
|
||||
print(f" - lora_disabled: {disabled_output}")
|
||||
print(f" - lora_reenabled: {reenabled_output}")
|
||||
print(f" - lora_reset: {reset_output}")
|
||||
print(f" - lora_reloaded: {reload_output}")
|
||||
print(f"\n[Done] All tests completed!", file=sys.stderr)
|
||||
print(f" - with_lora: {lora_output}", file=sys.stderr)
|
||||
print(f" - lora_disabled: {disabled_output}", file=sys.stderr)
|
||||
print(f" - lora_reenabled: {reenabled_output}", file=sys.stderr)
|
||||
print(f" - lora_reset: {reset_output}", file=sys.stderr)
|
||||
print(f" - lora_reloaded: {reload_output}", file=sys.stderr)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -24,7 +24,7 @@ try:
|
||||
SAFETENSORS_AVAILABLE = True
|
||||
except ImportError:
|
||||
SAFETENSORS_AVAILABLE = False
|
||||
print("Warning: safetensors not available, will use pytorch format")
|
||||
print("Warning: safetensors not available, will use pytorch format", file=sys.stderr)
|
||||
|
||||
from voxcpm.model import VoxCPMModel
|
||||
from voxcpm.model.voxcpm import LoRAConfig
|
||||
@ -170,7 +170,7 @@ def train(
|
||||
# Only print param info on rank 0 to avoid cluttered output
|
||||
if accelerator.rank == 0:
|
||||
for name, param in model.named_parameters():
|
||||
print(name, param.requires_grad)
|
||||
print(name, param.requires_grad, file=sys.stderr)
|
||||
|
||||
optimizer = AdamW(
|
||||
(p for p in model.parameters() if p.requires_grad),
|
||||
@ -210,12 +210,12 @@ def train(
|
||||
cur_step = int(_resume.get("step", start_step))
|
||||
except Exception:
|
||||
cur_step = start_step
|
||||
print(f"Signal {signum} received. Saving checkpoint at step {cur_step} ...")
|
||||
print(f"Signal {signum} received. Saving checkpoint at step {cur_step} ...", file=sys.stderr)
|
||||
try:
|
||||
save_checkpoint(_model, _optim, _sched, _save_dir, cur_step, _pretrained, _hf_id, _dist)
|
||||
print("Checkpoint saved. Exiting.")
|
||||
print("Checkpoint saved. Exiting.", file=sys.stderr)
|
||||
except Exception as e:
|
||||
print(f"Error saving checkpoint on signal: {e}")
|
||||
print(f"Error saving checkpoint on signal: {e}", file=sys.stderr)
|
||||
os._exit(0)
|
||||
|
||||
signal.signal(signal.SIGTERM, _signal_handler)
|
||||
@ -553,7 +553,7 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
|
||||
|
||||
# Load only lora weights
|
||||
unwrapped.load_state_dict(state_dict, strict=False)
|
||||
print(f"Loaded LoRA weights from {lora_weights_path}")
|
||||
print(f"Loaded LoRA weights from {lora_weights_path}", file=sys.stderr)
|
||||
else:
|
||||
# Full finetune: load model.safetensors or pytorch_model.bin
|
||||
model_path = latest_folder / "model.safetensors"
|
||||
@ -569,26 +569,26 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
|
||||
state_dict = ckpt.get("state_dict", ckpt)
|
||||
|
||||
unwrapped.load_state_dict(state_dict, strict=False)
|
||||
print(f"Loaded model weights from {model_path}")
|
||||
print(f"Loaded model weights from {model_path}", file=sys.stderr)
|
||||
|
||||
# Load optimizer state
|
||||
optimizer_path = latest_folder / "optimizer.pth"
|
||||
if optimizer_path.exists():
|
||||
optimizer.load_state_dict(torch.load(optimizer_path, map_location="cpu"))
|
||||
print(f"Loaded optimizer state from {optimizer_path}")
|
||||
print(f"Loaded optimizer state from {optimizer_path}", file=sys.stderr)
|
||||
|
||||
# Load scheduler state
|
||||
scheduler_path = latest_folder / "scheduler.pth"
|
||||
if scheduler_path.exists():
|
||||
scheduler.load_state_dict(torch.load(scheduler_path, map_location="cpu"))
|
||||
print(f"Loaded scheduler state from {scheduler_path}")
|
||||
print(f"Loaded scheduler state from {scheduler_path}", file=sys.stderr)
|
||||
|
||||
# Try to infer step from checkpoint folders
|
||||
step_folders = [d for d in save_dir.iterdir() if d.is_dir() and d.name.startswith("step_")]
|
||||
if step_folders:
|
||||
steps = [int(d.name.split("_")[1]) for d in step_folders]
|
||||
resume_step = max(steps)
|
||||
print(f"Resuming from step {resume_step}")
|
||||
print(f"Resuming from step {resume_step}", file=sys.stderr)
|
||||
return resume_step
|
||||
|
||||
return 0
|
||||
@ -670,7 +670,7 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
|
||||
latest_link.unlink()
|
||||
shutil.copytree(folder, latest_link)
|
||||
except Exception:
|
||||
print(f"Warning: failed to update latest checkpoint link at {latest_link}")
|
||||
print(f"Warning: failed to update latest checkpoint link at {latest_link}", file=sys.stderr)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user