checkpoint_saver.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. """ Checkpoint Saver
  2. Track top-n training checkpoints and maintain recovery checkpoints on specified intervals.
  3. Hacked together by / Copyright 2020 Ross Wightman
  4. """
  5. import glob
  6. import logging
  7. import operator
  8. import os
  9. import shutil
  10. import torch
  11. from .model import unwrap_model, get_state_dict
# Module-level logger used by CheckpointSaver for checkpoint/recovery bookkeeping messages.
_logger = logging.getLogger(__name__)
  13. class CheckpointSaver:
  14. def __init__(
  15. self,
  16. model,
  17. optimizer,
  18. args=None,
  19. model_ema=None,
  20. amp_scaler=None,
  21. checkpoint_prefix='checkpoint',
  22. recovery_prefix='recovery',
  23. checkpoint_dir='',
  24. recovery_dir='',
  25. decreasing=False,
  26. max_history=10,
  27. unwrap_fn=unwrap_model
  28. ):
  29. # objects to save state_dicts of
  30. self.model = model
  31. self.optimizer = optimizer
  32. self.args = args
  33. self.model_ema = model_ema
  34. self.amp_scaler = amp_scaler
  35. # state
  36. self.checkpoint_files = [] # (filename, metric) tuples in order of decreasing betterness
  37. self.best_epoch = None
  38. self.best_metric = None
  39. self.curr_recovery_file = ''
  40. self.prev_recovery_file = ''
  41. self.can_hardlink = True
  42. # config
  43. self.checkpoint_dir = checkpoint_dir
  44. self.recovery_dir = recovery_dir
  45. self.save_prefix = checkpoint_prefix
  46. self.recovery_prefix = recovery_prefix
  47. self.extension = '.pth.tar'
  48. self.decreasing = decreasing # a lower metric is better if True
  49. self.cmp = operator.lt if decreasing else operator.gt # True if lhs better than rhs
  50. self.max_history = max_history
  51. self.unwrap_fn = unwrap_fn
  52. assert self.max_history >= 1
  53. def _replace(self, src, dst):
  54. if self.can_hardlink:
  55. try:
  56. if os.path.exists(dst):
  57. os.unlink(dst) # required for Windows support.
  58. except (OSError, NotImplementedError) as e:
  59. self.can_hardlink = False
  60. os.replace(src, dst)
  61. def _duplicate(self, src, dst):
  62. if self.can_hardlink:
  63. try:
  64. if os.path.exists(dst):
  65. # for Windows
  66. os.unlink(dst)
  67. os.link(src, dst)
  68. return
  69. except (OSError, NotImplementedError) as e:
  70. self.can_hardlink = False
  71. shutil.copy2(src, dst)
  72. def _save(self, save_path, epoch, metric=None):
  73. save_state = {
  74. 'epoch': epoch,
  75. 'arch': type(self.model).__name__.lower(),
  76. 'state_dict': get_state_dict(self.model, self.unwrap_fn),
  77. 'optimizer': self.optimizer.state_dict(),
  78. 'version': 2, # version < 2 increments epoch before save
  79. }
  80. if self.args is not None:
  81. save_state['arch'] = self.args.model
  82. save_state['args'] = self.args
  83. if self.amp_scaler is not None:
  84. save_state[self.amp_scaler.state_dict_key] = self.amp_scaler.state_dict()
  85. if self.model_ema is not None:
  86. save_state['state_dict_ema'] = get_state_dict(self.model_ema, self.unwrap_fn)
  87. if metric is not None:
  88. save_state['metric'] = metric
  89. torch.save(save_state, save_path)
  90. def _cleanup_checkpoints(self, trim=0):
  91. trim = min(len(self.checkpoint_files), trim)
  92. delete_index = self.max_history - trim
  93. if delete_index < 0 or len(self.checkpoint_files) <= delete_index:
  94. return
  95. to_delete = self.checkpoint_files[delete_index:]
  96. for d in to_delete:
  97. try:
  98. _logger.debug("Cleaning checkpoint: {}".format(d))
  99. os.remove(d[0])
  100. except Exception as e:
  101. _logger.error("Exception '{}' while deleting checkpoint".format(e))
  102. self.checkpoint_files = self.checkpoint_files[:delete_index]
  103. def save_checkpoint(self, epoch, metric=None):
  104. assert epoch >= 0
  105. tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
  106. last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
  107. self._save(tmp_save_path, epoch, metric)
  108. self._replace(tmp_save_path, last_save_path)
  109. worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None
  110. if (
  111. len(self.checkpoint_files) < self.max_history
  112. or metric is None
  113. or self.cmp(metric, worst_file[1])
  114. ):
  115. if len(self.checkpoint_files) >= self.max_history:
  116. self._cleanup_checkpoints(1)
  117. filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension
  118. save_path = os.path.join(self.checkpoint_dir, filename)
  119. self._duplicate(last_save_path, save_path)
  120. self.checkpoint_files.append((save_path, metric))
  121. self.checkpoint_files = sorted(
  122. self.checkpoint_files,
  123. key=lambda x: x[1],
  124. reverse=not self.decreasing # sort in descending order if a lower metric is not better
  125. )
  126. checkpoints_str = "Current checkpoints:\n"
  127. for c in self.checkpoint_files:
  128. checkpoints_str += ' {}\n'.format(c)
  129. _logger.info(checkpoints_str)
  130. if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)):
  131. self.best_epoch = epoch
  132. self.best_metric = metric
  133. best_save_path = os.path.join(self.checkpoint_dir, 'model_best' + self.extension)
  134. self._duplicate(last_save_path, best_save_path)
  135. return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)
  136. def save_recovery(self, epoch, batch_idx=0):
  137. assert epoch >= 0
  138. tmp_save_path = os.path.join(self.recovery_dir, 'recovery_tmp' + self.extension)
  139. self._save(tmp_save_path, epoch)
  140. filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
  141. save_path = os.path.join(self.recovery_dir, filename)
  142. self._replace(tmp_save_path, save_path)
  143. if os.path.exists(self.prev_recovery_file):
  144. try:
  145. _logger.debug("Cleaning recovery: {}".format(self.prev_recovery_file))
  146. os.remove(self.prev_recovery_file)
  147. except Exception as e:
  148. _logger.error("Exception '{}' while removing {}".format(e, self.prev_recovery_file))
  149. self.prev_recovery_file = self.curr_recovery_file
  150. self.curr_recovery_file = save_path
  151. def find_recovery(self):
  152. recovery_path = os.path.join(self.recovery_dir, self.recovery_prefix)
  153. files = glob.glob(recovery_path + '*' + self.extension)
  154. files = sorted(files)
  155. return files[0] if len(files) else ''