checkpoint.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. import os
  2. import torch
  3. from torch.nn.parallel.data_parallel import DataParallel
  4. from torch.nn.parallel.distributed import DistributedDataParallel
  5. from loguru import logger
  6. import gc
  7. import romatch
  8. class CheckPoint:
  9. def __init__(self, dir=None, name="tmp"):
  10. self.name = name
  11. self.dir = dir
  12. os.makedirs(self.dir, exist_ok=True)
  13. def save(
  14. self,
  15. model,
  16. optimizer,
  17. lr_scheduler,
  18. n,
  19. ):
  20. if romatch.RANK == 0:
  21. assert model is not None
  22. if isinstance(model, (DataParallel, DistributedDataParallel)):
  23. model = model.module
  24. states = {
  25. "model": model.state_dict(),
  26. "n": n,
  27. "optimizer": optimizer.state_dict(),
  28. "lr_scheduler": lr_scheduler.state_dict(),
  29. }
  30. torch.save(states, self.dir + self.name + f"_latest.pth")
  31. logger.info(f"Saved states {list(states.keys())}, at step {n}")
  32. def load(
  33. self,
  34. model,
  35. optimizer,
  36. lr_scheduler,
  37. n,
  38. ):
  39. if os.path.exists(self.dir + self.name + f"_latest.pth") and romatch.RANK == 0:
  40. states = torch.load(self.dir + self.name + f"_latest.pth")
  41. if "model" in states:
  42. model.load_state_dict(states["model"])
  43. if "n" in states:
  44. n = states["n"] if states["n"] else n
  45. if "optimizer" in states:
  46. try:
  47. optimizer.load_state_dict(states["optimizer"])
  48. except Exception as e:
  49. print(f"Failed to load states for optimizer, with error {e}")
  50. if "lr_scheduler" in states:
  51. lr_scheduler.load_state_dict(states["lr_scheduler"])
  52. print(f"Loaded states {list(states.keys())}, at step {n}")
  53. del states
  54. gc.collect()
  55. torch.cuda.empty_cache()
  56. return model, optimizer, lr_scheduler, n