Print all log messages to stderr instead of stdout

This commit is contained in:
刘鑫
2026-01-12 15:30:45 +08:00
parent db75a7269b
commit e8dd956fc2
7 changed files with 100 additions and 96 deletions

View File

@ -24,6 +24,7 @@ Note: The script reads base_model path and lora_config from lora_config.json
import argparse
import json
import sys
from pathlib import Path
import soundfile as sf
@ -124,13 +125,13 @@ def main():
lora_cfg_dict = lora_info.get("lora_config", {})
lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None
print(f"Loaded config from: {lora_config_path}")
print(f" Base model: {pretrained_path}")
print(f" LoRA config: r={lora_cfg.r}, alpha={lora_cfg.alpha}" if lora_cfg else " LoRA config: None")
print(f"Loaded config from: {lora_config_path}", file=sys.stderr)
print(f" Base model: {pretrained_path}", file=sys.stderr)
print(f" LoRA config: r={lora_cfg.r}, alpha={lora_cfg.alpha}" if lora_cfg else " LoRA config: None", file=sys.stderr)
# 3. Load model with LoRA (no denoiser)
print(f"\n[1/2] Loading model with LoRA: {pretrained_path}")
print(f" LoRA weights: {ckpt_dir}")
print(f"\n[1/2] Loading model with LoRA: {pretrained_path}", file=sys.stderr)
print(f" LoRA weights: {ckpt_dir}", file=sys.stderr)
model = VoxCPM.from_pretrained(
hf_model_id=pretrained_path,
load_denoiser=False,
@ -145,10 +146,10 @@ def main():
out_path = Path(args.output)
out_path.parent.mkdir(parents=True, exist_ok=True)
print(f"\n[2/2] Starting synthesis tests...")
print(f"\n[2/2] Starting synthesis tests...", file=sys.stderr)
# === Test 1: With LoRA ===
print(f"\n [Test 1] Synthesize with LoRA...")
print(f"\n [Test 1] Synthesize with LoRA...", file=sys.stderr)
audio_np = model.generate(
text=args.text,
prompt_wav_path=prompt_wav_path,
@ -161,10 +162,10 @@ def main():
)
lora_output = out_path.with_stem(out_path.stem + "_with_lora")
sf.write(str(lora_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
print(f" Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
# === Test 2: Disable LoRA (via set_lora_enabled) ===
print(f"\n [Test 2] Disable LoRA (set_lora_enabled=False)...")
print(f"\n [Test 2] Disable LoRA (set_lora_enabled=False)...", file=sys.stderr)
model.set_lora_enabled(False)
audio_np = model.generate(
text=args.text,
@ -178,10 +179,10 @@ def main():
)
disabled_output = out_path.with_stem(out_path.stem + "_lora_disabled")
sf.write(str(disabled_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
print(f" Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
# === Test 3: Re-enable LoRA ===
print(f"\n [Test 3] Re-enable LoRA (set_lora_enabled=True)...")
print(f"\n [Test 3] Re-enable LoRA (set_lora_enabled=True)...", file=sys.stderr)
model.set_lora_enabled(True)
audio_np = model.generate(
text=args.text,
@ -195,10 +196,10 @@ def main():
)
reenabled_output = out_path.with_stem(out_path.stem + "_lora_reenabled")
sf.write(str(reenabled_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
print(f" Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
# === Test 4: Unload LoRA (reset_lora_weights) ===
print(f"\n [Test 4] Unload LoRA (unload_lora)...")
print(f"\n [Test 4] Unload LoRA (unload_lora)...", file=sys.stderr)
model.unload_lora()
audio_np = model.generate(
text=args.text,
@ -212,12 +213,12 @@ def main():
)
reset_output = out_path.with_stem(out_path.stem + "_lora_reset")
sf.write(str(reset_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
print(f" Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
# === Test 5: Hot-reload LoRA (load_lora) ===
print(f"\n [Test 5] Hot-reload LoRA (load_lora)...")
print(f"\n [Test 5] Hot-reload LoRA (load_lora)...", file=sys.stderr)
loaded, skipped = model.load_lora(ckpt_dir)
print(f" Reloaded {len(loaded)} parameters")
print(f" Reloaded {len(loaded)} parameters", file=sys.stderr)
audio_np = model.generate(
text=args.text,
prompt_wav_path=prompt_wav_path,
@ -230,14 +231,14 @@ def main():
)
reload_output = out_path.with_stem(out_path.stem + "_lora_reloaded")
sf.write(str(reload_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
print(f" Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
print(f"\n[Done] All tests completed!")
print(f" - with_lora: {lora_output}")
print(f" - lora_disabled: {disabled_output}")
print(f" - lora_reenabled: {reenabled_output}")
print(f" - lora_reset: {reset_output}")
print(f" - lora_reloaded: {reload_output}")
print(f"\n[Done] All tests completed!", file=sys.stderr)
print(f" - with_lora: {lora_output}", file=sys.stderr)
print(f" - lora_disabled: {disabled_output}", file=sys.stderr)
print(f" - lora_reenabled: {reenabled_output}", file=sys.stderr)
print(f" - lora_reset: {reset_output}", file=sys.stderr)
print(f" - lora_reloaded: {reload_output}", file=sys.stderr)
if __name__ == "__main__":