Print all log messages to stderr instead of stdout

This commit is contained in:
刘鑫
2026-01-12 15:30:45 +08:00
parent db75a7269b
commit e8dd956fc2
7 changed files with 100 additions and 96 deletions

View File

@ -24,7 +24,7 @@ try:
SAFETENSORS_AVAILABLE = True
except ImportError:
SAFETENSORS_AVAILABLE = False
print("Warning: safetensors not available, will use pytorch format")
print("Warning: safetensors not available, will use pytorch format", file=sys.stderr)
from voxcpm.model import VoxCPMModel
from voxcpm.model.voxcpm import LoRAConfig
@ -170,7 +170,7 @@ def train(
# Only print param info on rank 0 to avoid cluttered output
if accelerator.rank == 0:
for name, param in model.named_parameters():
print(name, param.requires_grad)
print(name, param.requires_grad, file=sys.stderr)
optimizer = AdamW(
(p for p in model.parameters() if p.requires_grad),
@ -210,12 +210,12 @@ def train(
cur_step = int(_resume.get("step", start_step))
except Exception:
cur_step = start_step
print(f"Signal {signum} received. Saving checkpoint at step {cur_step} ...")
print(f"Signal {signum} received. Saving checkpoint at step {cur_step} ...", file=sys.stderr)
try:
save_checkpoint(_model, _optim, _sched, _save_dir, cur_step, _pretrained, _hf_id, _dist)
print("Checkpoint saved. Exiting.")
print("Checkpoint saved. Exiting.", file=sys.stderr)
except Exception as e:
print(f"Error saving checkpoint on signal: {e}")
print(f"Error saving checkpoint on signal: {e}", file=sys.stderr)
os._exit(0)
signal.signal(signal.SIGTERM, _signal_handler)
@ -553,7 +553,7 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
# Load only lora weights
unwrapped.load_state_dict(state_dict, strict=False)
print(f"Loaded LoRA weights from {lora_weights_path}")
print(f"Loaded LoRA weights from {lora_weights_path}", file=sys.stderr)
else:
# Full finetune: load model.safetensors or pytorch_model.bin
model_path = latest_folder / "model.safetensors"
@ -569,26 +569,26 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
state_dict = ckpt.get("state_dict", ckpt)
unwrapped.load_state_dict(state_dict, strict=False)
print(f"Loaded model weights from {model_path}")
print(f"Loaded model weights from {model_path}", file=sys.stderr)
# Load optimizer state
optimizer_path = latest_folder / "optimizer.pth"
if optimizer_path.exists():
optimizer.load_state_dict(torch.load(optimizer_path, map_location="cpu"))
print(f"Loaded optimizer state from {optimizer_path}")
print(f"Loaded optimizer state from {optimizer_path}", file=sys.stderr)
# Load scheduler state
scheduler_path = latest_folder / "scheduler.pth"
if scheduler_path.exists():
scheduler.load_state_dict(torch.load(scheduler_path, map_location="cpu"))
print(f"Loaded scheduler state from {scheduler_path}")
print(f"Loaded scheduler state from {scheduler_path}", file=sys.stderr)
# Try to infer step from checkpoint folders
step_folders = [d for d in save_dir.iterdir() if d.is_dir() and d.name.startswith("step_")]
if step_folders:
steps = [int(d.name.split("_")[1]) for d in step_folders]
resume_step = max(steps)
print(f"Resuming from step {resume_step}")
print(f"Resuming from step {resume_step}", file=sys.stderr)
return resume_step
return 0
@ -670,7 +670,7 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
latest_link.unlink()
shutil.copytree(folder, latest_link)
except Exception:
print(f"Warning: failed to update latest checkpoint link at {latest_link}")
print(f"Warning: failed to update latest checkpoint link at {latest_link}", file=sys.stderr)
if __name__ == "__main__":