Fix: optimize save ckpt function

This commit is contained in:
刘鑫
2026-01-16 16:22:34 +08:00
parent e8dd956fc2
commit 79e75f259e

View File

@ -603,7 +603,7 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
import shutil import shutil
save_dir.mkdir(parents=True, exist_ok=True) save_dir.mkdir(parents=True, exist_ok=True)
tag = "latest" if step == 0 else f"step_{step:07d}" tag = f"step_{step:07d}"
folder = save_dir / tag folder = save_dir / tag
folder.mkdir(parents=True, exist_ok=True) folder.mkdir(parents=True, exist_ok=True)
@ -649,28 +649,14 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
torch.save(optimizer.state_dict(), folder / "optimizer.pth") torch.save(optimizer.state_dict(), folder / "optimizer.pth")
torch.save(scheduler.state_dict(), folder / "scheduler.pth") torch.save(scheduler.state_dict(), folder / "scheduler.pth")
# Update (or create) a `latest` symlink pointing to the most recent checkpoint folder # Update (or create) a `latest` folder by copying the most recent checkpoint
latest_link = save_dir / "latest" latest_link = save_dir / "latest"
try: try:
if latest_link.exists() or latest_link.is_symlink(): if latest_link.exists():
# remove existing link or directory shutil.rmtree(latest_link)
if latest_link.is_dir() and not latest_link.is_symlink(): shutil.copytree(folder, latest_link)
shutil.rmtree(latest_link)
else:
latest_link.unlink()
# Create a symlink pointing to the new folder
os.symlink(str(folder), str(latest_link))
except Exception: except Exception:
# If symlink creation fails (e.g., on Windows or permission issues), fall back to copying print(f"Warning: failed to update latest checkpoint at {latest_link}", file=sys.stderr)
try:
if latest_link.exists():
if latest_link.is_dir():
shutil.rmtree(latest_link)
else:
latest_link.unlink()
shutil.copytree(folder, latest_link)
except Exception:
print(f"Warning: failed to update latest checkpoint link at {latest_link}", file=sys.stderr)
if __name__ == "__main__": if __name__ == "__main__":