_helpers.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. """ Model creation / weight loading / state_dict helpers
  2. Hacked together by / Copyright 2020 Ross Wightman
  3. """
  4. import argparse
  5. import logging
  6. import os
  7. import pickle
  8. from typing import Any, Callable, Dict, Optional, Union
  9. import torch
# safetensors is an optional dependency: remember whether the import succeeded so
# code that needs it can check `_has_safetensors` before touching `safetensors`.
try:
    import safetensors.torch
    _has_safetensors = True
except ImportError:
    _has_safetensors = False

# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)

# Names exported by `from <this module> import *`; underscore-prefixed helpers are not listed.
__all__ = [
    'clean_state_dict',
    'load_checkpoint',
    'load_state_dict',
    'remap_state_dict',
    'resume_checkpoint',
]
  23. def _checkpoint_unsafe_globals(checkpoint_path: str) -> str:
  24. if not hasattr(torch.serialization, 'get_unsafe_globals_in_checkpoint'):
  25. return ''
  26. try:
  27. unsafe_globals = torch.serialization.get_unsafe_globals_in_checkpoint(str(checkpoint_path))
  28. except Exception:
  29. unsafe_globals = []
  30. return f" Unsupported globals: {', '.join(unsafe_globals)}." if unsafe_globals else ''
  31. def _torch_load(
  32. checkpoint_path: str,
  33. map_location: Union[str, torch.device] = 'cpu',
  34. weights_only: bool = True,
  35. ):
  36. use_safe_globals = weights_only and hasattr(torch.serialization, 'safe_globals')
  37. try:
  38. if use_safe_globals:
  39. # Compatibility: timm training checkpoints often include argparse.Namespace in `args`.
  40. with torch.serialization.safe_globals([argparse.Namespace]):
  41. return torch.load(checkpoint_path, map_location=map_location, weights_only=weights_only)
  42. return torch.load(checkpoint_path, map_location=map_location, weights_only=weights_only)
  43. except TypeError as e:
  44. if not weights_only:
  45. return torch.load(checkpoint_path, map_location=map_location)
  46. raise RuntimeError(
  47. f"weights_only=True is not supported by this PyTorch build (torch=={torch.__version__}). "
  48. "No automatic unsafe pickle fallback is performed. "
  49. "Upgrade PyTorch, or explicitly set weights_only=False only for trusted local checkpoints."
  50. ) from e
  51. except pickle.UnpicklingError as e:
  52. if not weights_only:
  53. raise
  54. raise RuntimeError(
  55. "weights_only=True blocked loading this checkpoint because it requires non-allowlisted pickle globals."
  56. f"{_checkpoint_unsafe_globals(checkpoint_path)} "
  57. "No automatic unsafe pickle fallback is performed. "
  58. "If this checkpoint is trusted, retry with weights_only=False."
  59. ) from e
  60. def _remove_prefix(text: str, prefix: str) -> str:
  61. # FIXME replace with 3.9 stdlib fn when min at 3.9
  62. if text.startswith(prefix):
  63. return text[len(prefix):]
  64. return text
  65. def clean_state_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
  66. # 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training
  67. cleaned_state_dict = {}
  68. to_remove = (
  69. 'module.', # DDP wrapper
  70. '_orig_mod.', # torchcompile dynamo wrapper
  71. )
  72. for k, v in state_dict.items():
  73. for r in to_remove:
  74. k = _remove_prefix(k, r)
  75. cleaned_state_dict[k] = v
  76. return cleaned_state_dict
  77. def load_state_dict(
  78. checkpoint_path: str,
  79. use_ema: bool = True,
  80. device: Union[str, torch.device] = 'cpu',
  81. weights_only: bool = True,
  82. ) -> Dict[str, Any]:
  83. """Load state dictionary from checkpoint file.
  84. Args:
  85. checkpoint_path: Path to checkpoint file.
  86. use_ema: Whether to use EMA weights if available.
  87. device: Device to load checkpoint to.
  88. weights_only: Whether to load only weights (torch.load parameter).
  89. Returns:
  90. State dictionary loaded from checkpoint.
  91. """
  92. if checkpoint_path and os.path.isfile(checkpoint_path):
  93. # Check if safetensors or not and load weights accordingly
  94. if str(checkpoint_path).endswith(".safetensors"):
  95. assert _has_safetensors, "`pip install safetensors` to use .safetensors"
  96. checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
  97. else:
  98. checkpoint = _torch_load(checkpoint_path, map_location=device, weights_only=weights_only)
  99. state_dict_key = ''
  100. if isinstance(checkpoint, dict):
  101. if use_ema and checkpoint.get('state_dict_ema', None) is not None:
  102. state_dict_key = 'state_dict_ema'
  103. elif use_ema and checkpoint.get('model_ema', None) is not None:
  104. state_dict_key = 'model_ema'
  105. elif 'state_dict' in checkpoint:
  106. state_dict_key = 'state_dict'
  107. elif 'model' in checkpoint:
  108. state_dict_key = 'model'
  109. state_dict = clean_state_dict(checkpoint[state_dict_key] if state_dict_key else checkpoint)
  110. _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
  111. return state_dict
  112. else:
  113. _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
  114. raise FileNotFoundError()
  115. def load_checkpoint(
  116. model: torch.nn.Module,
  117. checkpoint_path: str,
  118. use_ema: bool = True,
  119. device: Union[str, torch.device] = 'cpu',
  120. strict: bool = True,
  121. remap: bool = False,
  122. filter_fn: Optional[Callable] = None,
  123. weights_only: bool = True,
  124. ) -> Any:
  125. """Load checkpoint into model.
  126. Args:
  127. model: Model to load checkpoint into.
  128. checkpoint_path: Path to checkpoint file.
  129. use_ema: Whether to use EMA weights if available.
  130. device: Device to load checkpoint to.
  131. strict: Whether to strictly enforce state_dict keys match.
  132. remap: Whether to remap state dict keys by order.
  133. filter_fn: Optional function to filter state dict.
  134. weights_only: Whether to load only weights (torch.load parameter).
  135. Returns:
  136. Incompatible keys from model.load_state_dict().
  137. """
  138. if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
  139. # numpy checkpoint, try to load via model specific load_pretrained fn
  140. if hasattr(model, 'load_pretrained'):
  141. model.load_pretrained(checkpoint_path)
  142. else:
  143. raise NotImplementedError('Model cannot load numpy checkpoint')
  144. return
  145. state_dict = load_state_dict(checkpoint_path, use_ema, device=device, weights_only=weights_only)
  146. if remap:
  147. state_dict = remap_state_dict(state_dict, model)
  148. elif filter_fn:
  149. state_dict = filter_fn(state_dict, model)
  150. incompatible_keys = model.load_state_dict(state_dict, strict=strict)
  151. return incompatible_keys
  152. def remap_state_dict(
  153. state_dict: Dict[str, Any],
  154. model: torch.nn.Module,
  155. allow_reshape: bool = True
  156. ) -> Dict[str, Any]:
  157. """Remap checkpoint by iterating over state dicts in order (ignoring original keys).
  158. This assumes models (and originating state dict) were created with params registered in same order.
  159. Args:
  160. state_dict: State dict to remap.
  161. model: Model whose state dict keys to use.
  162. allow_reshape: Whether to allow reshaping tensors to match.
  163. Returns:
  164. Remapped state dictionary.
  165. """
  166. out_dict = {}
  167. for (ka, va), (kb, vb) in zip(model.state_dict().items(), state_dict.items()):
  168. assert va.numel() == vb.numel(), f'Tensor size mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
  169. if va.shape != vb.shape:
  170. if allow_reshape:
  171. vb = vb.reshape(va.shape)
  172. else:
  173. assert False, f'Tensor shape mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
  174. out_dict[ka] = vb
  175. return out_dict
  176. def resume_checkpoint(
  177. model: torch.nn.Module,
  178. checkpoint_path: str,
  179. optimizer: Optional[torch.optim.Optimizer] = None,
  180. loss_scaler: Optional[Any] = None,
  181. log_info: bool = True,
  182. weights_only: bool = True,
  183. ) -> Optional[int]:
  184. """Resume training from checkpoint.
  185. Args:
  186. model: Model to load checkpoint into.
  187. checkpoint_path: Path to checkpoint file.
  188. optimizer: Optional optimizer to restore state.
  189. loss_scaler: Optional AMP loss scaler to restore state.
  190. log_info: Whether to log loading info.
  191. weights_only: Whether to load only weights via torch.load.
  192. Returns:
  193. Resume epoch number if available, else None.
  194. """
  195. resume_epoch = None
  196. if os.path.isfile(checkpoint_path):
  197. checkpoint = _torch_load(checkpoint_path, map_location='cpu', weights_only=weights_only)
  198. if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
  199. if log_info:
  200. _logger.info('Restoring model state from checkpoint...')
  201. state_dict = clean_state_dict(checkpoint['state_dict'])
  202. model.load_state_dict(state_dict)
  203. if optimizer is not None and 'optimizer' in checkpoint:
  204. if log_info:
  205. _logger.info('Restoring optimizer state from checkpoint...')
  206. optimizer.load_state_dict(checkpoint['optimizer'])
  207. if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
  208. if log_info:
  209. _logger.info('Restoring AMP loss scaler state from checkpoint...')
  210. loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])
  211. if 'epoch' in checkpoint:
  212. resume_epoch = checkpoint['epoch']
  213. if 'version' in checkpoint and checkpoint['version'] > 1:
  214. resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save
  215. if log_info:
  216. _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
  217. else:
  218. model.load_state_dict(checkpoint)
  219. if log_info:
  220. _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
  221. return resume_epoch
  222. else:
  223. _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
  224. raise FileNotFoundError()