| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960 |
- import os
- import torch
- from torch.nn.parallel.data_parallel import DataParallel
- from torch.nn.parallel.distributed import DistributedDataParallel
- from loguru import logger
- import gc
- import romatch
class CheckPoint:
    """Save and restore the most recent training state in a single file.

    The checkpoint is stored at ``<dir>/<name>_latest.pth`` and contains the
    model, optimizer and LR-scheduler state dicts together with the step
    counter ``n``. Saving and loading both happen only when
    ``romatch.RANK == 0``.
    """

    def __init__(self, dir=None, name="tmp"):
        """Record the checkpoint location and create the directory if needed.

        Args:
            dir: Directory that holds the checkpoint file (``str`` or
                ``os.PathLike``). ``None`` or ``""`` selects the current
                working directory.
            name: Prefix of the checkpoint file name.
        """
        self.name = name
        # os.makedirs(None) raises TypeError, so the declared default of None
        # is mapped to "" (current working directory) rather than crashing.
        self.dir = "" if dir is None else dir
        if self.dir:
            os.makedirs(self.dir, exist_ok=True)

    def _latest_path(self):
        """Return the path of the ``<name>_latest.pth`` file inside ``self.dir``."""
        # os.path.join inserts the separator when ``dir`` lacks a trailing one,
        # so the file always lands inside the directory created in __init__.
        return os.path.join(self.dir, f"{self.name}_latest.pth")

    @staticmethod
    def _unwrap(model):
        """Return the module inside a DataParallel/DDP wrapper, else ``model``."""
        if isinstance(model, (DataParallel, DistributedDataParallel)):
            return model.module
        return model

    def save(
        self,
        model,
        optimizer,
        lr_scheduler,
        n,
    ):
        """Write the current training state to the latest-checkpoint file.

        Does nothing unless ``romatch.RANK == 0``.

        Args:
            model: Model to save; a ``DataParallel`` /
                ``DistributedDataParallel`` wrapper is unwrapped first, so the
                stored keys carry no ``module.`` prefix.
            optimizer: Object whose ``state_dict()`` is stored under
                ``"optimizer"``.
            lr_scheduler: Object whose ``state_dict()`` is stored under
                ``"lr_scheduler"``.
            n: Step counter stored under ``"n"``.

        Raises:
            AssertionError: If ``model`` is ``None`` (on rank 0).
        """
        if romatch.RANK != 0:
            return
        assert model is not None
        states = {
            "model": self._unwrap(model).state_dict(),
            "n": n,
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),
        }
        torch.save(states, self._latest_path())
        logger.info(f"Saved states {list(states.keys())}, at step {n}")

    def load(
        self,
        model,
        optimizer,
        lr_scheduler,
        n,
    ):
        """Restore training state from the latest-checkpoint file, if present.

        When the file does not exist, or ``romatch.RANK != 0``, the arguments
        are returned untouched.
        NOTE(review): non-zero ranks never read the file here; confirm the
        caller synchronises state across ranks if that is required.

        Args:
            model: Model to restore into. A ``DataParallel`` /
                ``DistributedDataParallel`` wrapper is unwrapped for loading,
                mirroring ``save``, which stores the unwrapped state dict.
            optimizer: Optimizer to restore; a failure here is logged and
                otherwise ignored.
            lr_scheduler: LR scheduler to restore.
            n: Fallback step counter, used when the stored value is missing
                or falsy (e.g. ``0`` or ``None``).

        Returns:
            The tuple ``(model, optimizer, lr_scheduler, n)``; ``model`` is
            the same object that was passed in (wrapper included).
        """
        path = self._latest_path()
        if not (os.path.exists(path) and romatch.RANK == 0):
            return model, optimizer, lr_scheduler, n

        states = torch.load(path)
        if "model" in states:
            # The saved keys come from the unwrapped module, so load into the
            # unwrapped module too; otherwise a wrapped model reports
            # missing/unexpected "module."-prefixed keys.
            self._unwrap(model).load_state_dict(states["model"])
        if "n" in states:
            n = states["n"] if states["n"] else n
        if "optimizer" in states:
            # Best effort: an incompatible optimizer state must not prevent
            # the model weights and step counter from being restored.
            try:
                optimizer.load_state_dict(states["optimizer"])
            except Exception as e:
                logger.warning(
                    f"Failed to load states for optimizer, with error {e}"
                )
        if "lr_scheduler" in states:
            lr_scheduler.load_state_dict(states["lr_scheduler"])
        logger.info(f"Loaded states {list(states.keys())}, at step {n}")

        # Drop the loaded copy of the state before continuing, then ask
        # Python and the CUDA caching allocator to release what they can.
        del states
        gc.collect()
        torch.cuda.empty_cache()
        return model, optimizer, lr_scheduler, n
|