Fix: optimize save ckpt function
This commit is contained in:
@ -603,7 +603,7 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
|
|||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
save_dir.mkdir(parents=True, exist_ok=True)
|
save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
tag = "latest" if step == 0 else f"step_{step:07d}"
|
tag = f"step_{step:07d}"
|
||||||
folder = save_dir / tag
|
folder = save_dir / tag
|
||||||
folder.mkdir(parents=True, exist_ok=True)
|
folder.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
@ -649,28 +649,14 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
|
|||||||
torch.save(optimizer.state_dict(), folder / "optimizer.pth")
|
torch.save(optimizer.state_dict(), folder / "optimizer.pth")
|
||||||
torch.save(scheduler.state_dict(), folder / "scheduler.pth")
|
torch.save(scheduler.state_dict(), folder / "scheduler.pth")
|
||||||
|
|
||||||
# Update (or create) a `latest` symlink pointing to the most recent checkpoint folder
|
# Update (or create) a `latest` folder by copying the most recent checkpoint
|
||||||
latest_link = save_dir / "latest"
|
latest_link = save_dir / "latest"
|
||||||
try:
|
try:
|
||||||
if latest_link.exists() or latest_link.is_symlink():
|
if latest_link.exists():
|
||||||
# remove existing link or directory
|
shutil.rmtree(latest_link)
|
||||||
if latest_link.is_dir() and not latest_link.is_symlink():
|
shutil.copytree(folder, latest_link)
|
||||||
shutil.rmtree(latest_link)
|
|
||||||
else:
|
|
||||||
latest_link.unlink()
|
|
||||||
# Create a symlink pointing to the new folder
|
|
||||||
os.symlink(str(folder), str(latest_link))
|
|
||||||
except Exception:
|
except Exception:
|
||||||
# If symlink creation fails (e.g., on Windows or permission issues), fall back to copying
|
print(f"Warning: failed to update latest checkpoint at {latest_link}", file=sys.stderr)
|
||||||
try:
|
|
||||||
if latest_link.exists():
|
|
||||||
if latest_link.is_dir():
|
|
||||||
shutil.rmtree(latest_link)
|
|
||||||
else:
|
|
||||||
latest_link.unlink()
|
|
||||||
shutil.copytree(folder, latest_link)
|
|
||||||
except Exception:
|
|
||||||
print(f"Warning: failed to update latest checkpoint link at {latest_link}", file=sys.stderr)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user