Fix: optimize save ckpt function
This commit is contained in:
@ -603,7 +603,7 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
|
||||
import shutil
|
||||
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
tag = "latest" if step == 0 else f"step_{step:07d}"
|
||||
tag = f"step_{step:07d}"
|
||||
folder = save_dir / tag
|
||||
folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@ -649,28 +649,14 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
|
||||
torch.save(optimizer.state_dict(), folder / "optimizer.pth")
|
||||
torch.save(scheduler.state_dict(), folder / "scheduler.pth")
|
||||
|
||||
# Update (or create) a `latest` symlink pointing to the most recent checkpoint folder
|
||||
# Update (or create) a `latest` folder by copying the most recent checkpoint
|
||||
latest_link = save_dir / "latest"
|
||||
try:
|
||||
if latest_link.exists() or latest_link.is_symlink():
|
||||
# remove existing link or directory
|
||||
if latest_link.is_dir() and not latest_link.is_symlink():
|
||||
shutil.rmtree(latest_link)
|
||||
else:
|
||||
latest_link.unlink()
|
||||
# Create a symlink pointing to the new folder
|
||||
os.symlink(str(folder), str(latest_link))
|
||||
if latest_link.exists():
|
||||
shutil.rmtree(latest_link)
|
||||
shutil.copytree(folder, latest_link)
|
||||
except Exception:
|
||||
# If symlink creation fails (e.g., on Windows or permission issues), fall back to copying
|
||||
try:
|
||||
if latest_link.exists():
|
||||
if latest_link.is_dir():
|
||||
shutil.rmtree(latest_link)
|
||||
else:
|
||||
latest_link.unlink()
|
||||
shutil.copytree(folder, latest_link)
|
||||
except Exception:
|
||||
print(f"Warning: failed to update latest checkpoint link at {latest_link}", file=sys.stderr)
|
||||
print(f"Warning: failed to update latest checkpoint at {latest_link}", file=sys.stderr)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user