_manipulate.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. import collections.abc
  2. import math
  3. import re
  4. from collections import defaultdict
  5. from itertools import chain
  6. from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type, Union
  7. import torch
  8. import torch.utils.checkpoint
  9. from torch import nn as nn
  10. from torch import Tensor
  11. from timm.layers import use_reentrant_ckpt
  12. __all__ = ['model_parameters', 'named_apply', 'named_modules', 'named_modules_with_params', 'adapt_input_conv',
  13. 'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq', 'checkpoint',
  14. 'reinit_non_persistent_buffers']
  15. def model_parameters(model: nn.Module, exclude_head: bool = False):
  16. if exclude_head:
  17. # FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering
  18. return [p for p in model.parameters()][:-2]
  19. else:
  20. return model.parameters()
  21. def named_apply(
  22. fn: Callable,
  23. module: nn.Module, name='',
  24. depth_first: bool = True,
  25. include_root: bool = False,
  26. ) -> nn.Module:
  27. if not depth_first and include_root:
  28. fn(module=module, name=name)
  29. for child_name, child_module in module.named_children():
  30. child_name = '.'.join((name, child_name)) if name else child_name
  31. named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
  32. if depth_first and include_root:
  33. fn(module=module, name=name)
  34. return module
  35. def named_modules(
  36. module: nn.Module,
  37. name: str = '',
  38. depth_first: bool = True,
  39. include_root: bool = False,
  40. ):
  41. if not depth_first and include_root:
  42. yield name, module
  43. for child_name, child_module in module.named_children():
  44. child_name = '.'.join((name, child_name)) if name else child_name
  45. yield from named_modules(
  46. module=child_module, name=child_name, depth_first=depth_first, include_root=True)
  47. if depth_first and include_root:
  48. yield name, module
  49. def named_modules_with_params(
  50. module: nn.Module,
  51. name: str = '',
  52. depth_first: bool = True,
  53. include_root: bool = False,
  54. ):
  55. if module._parameters and not depth_first and include_root:
  56. yield name, module
  57. for child_name, child_module in module.named_children():
  58. child_name = '.'.join((name, child_name)) if name else child_name
  59. yield from named_modules_with_params(
  60. module=child_module, name=child_name, depth_first=depth_first, include_root=True)
  61. if module._parameters and depth_first and include_root:
  62. yield name, module
  63. MATCH_PREV_GROUP = (99999,)
  64. def group_with_matcher(
  65. named_objects: Iterator[Tuple[str, Any]],
  66. group_matcher: Union[Dict, Callable],
  67. return_values: bool = False,
  68. reverse: bool = False
  69. ):
  70. if isinstance(group_matcher, dict):
  71. # dictionary matcher contains a dict of raw-string regex expr that must be compiled
  72. compiled = []
  73. for group_ordinal, (group_name, mspec) in enumerate(group_matcher.items()):
  74. if mspec is None:
  75. continue
  76. # map all matching specifications into 3-tuple (compiled re, prefix, suffix)
  77. if isinstance(mspec, (tuple, list)):
  78. # multi-entry match specifications require each sub-spec to be a 2-tuple (re, suffix)
  79. for sspec in mspec:
  80. compiled += [(re.compile(sspec[0]), (group_ordinal,), sspec[1])]
  81. else:
  82. compiled += [(re.compile(mspec), (group_ordinal,), None)]
  83. group_matcher = compiled
  84. def _get_grouping(name):
  85. if isinstance(group_matcher, (list, tuple)):
  86. for match_fn, prefix, suffix in group_matcher:
  87. r = match_fn.match(name)
  88. if r:
  89. parts = (prefix, r.groups(), suffix)
  90. # map all tuple elem to int for numeric sort, filter out None entries
  91. return tuple(map(float, chain.from_iterable(filter(None, parts))))
  92. return float('inf'), # un-matched layers (neck, head) mapped to largest ordinal
  93. else:
  94. ord = group_matcher(name)
  95. if not isinstance(ord, collections.abc.Iterable):
  96. return ord,
  97. return tuple(ord)
  98. # map layers into groups via ordinals (ints or tuples of ints) from matcher
  99. grouping = defaultdict(list)
  100. for k, v in named_objects:
  101. grouping[_get_grouping(k)].append(v if return_values else k)
  102. # remap to integers
  103. layer_id_to_param = defaultdict(list)
  104. lid = -1
  105. for k in sorted(filter(lambda x: x is not None, grouping.keys())):
  106. if lid < 0 or k[-1] != MATCH_PREV_GROUP[0]:
  107. lid += 1
  108. layer_id_to_param[lid].extend(grouping[k])
  109. if reverse:
  110. assert not return_values, "reverse mapping only sensible for name output"
  111. # output reverse mapping
  112. param_to_layer_id = {}
  113. for lid, lm in layer_id_to_param.items():
  114. for n in lm:
  115. param_to_layer_id[n] = lid
  116. return param_to_layer_id
  117. return layer_id_to_param
  118. def group_parameters(
  119. module: nn.Module,
  120. group_matcher,
  121. return_values: bool = False,
  122. reverse: bool = False,
  123. ):
  124. return group_with_matcher(
  125. module.named_parameters(), group_matcher, return_values=return_values, reverse=reverse)
  126. def group_modules(
  127. module: nn.Module,
  128. group_matcher,
  129. return_values: bool = False,
  130. reverse: bool = False,
  131. ):
  132. return group_with_matcher(
  133. named_modules_with_params(module), group_matcher, return_values=return_values, reverse=reverse)
  134. def flatten_modules(
  135. named_modules: Iterator[Tuple[str, nn.Module]],
  136. depth: int = 1,
  137. prefix: Union[str, Tuple[str, ...]] = '',
  138. module_types: Union[str, Tuple[Type[nn.Module]]] = 'sequential',
  139. ):
  140. prefix_is_tuple = isinstance(prefix, tuple)
  141. if isinstance(module_types, str):
  142. if module_types == 'container':
  143. module_types = (nn.Sequential, nn.ModuleList, nn.ModuleDict)
  144. else:
  145. module_types = (nn.Sequential,)
  146. for name, module in named_modules:
  147. if depth and isinstance(module, module_types):
  148. yield from flatten_modules(
  149. module.named_children(),
  150. depth - 1,
  151. prefix=(name,) if prefix_is_tuple else name,
  152. module_types=module_types,
  153. )
  154. else:
  155. if prefix_is_tuple:
  156. name = prefix + (name,)
  157. yield name, module
  158. else:
  159. if prefix:
  160. name = '.'.join([prefix, name])
  161. yield name, module
  162. def checkpoint(
  163. function,
  164. *args,
  165. use_reentrant: Optional[bool] = None,
  166. **kwargs,
  167. ):
  168. """ checkpoint wrapper fn
  169. A thin wrapper around torch.utils.checkpoint.checkpoint to default
  170. use_reentrant to False
  171. """
  172. if use_reentrant is None:
  173. use_reentrant = use_reentrant_ckpt()
  174. return torch.utils.checkpoint.checkpoint(
  175. function,
  176. *args,
  177. use_reentrant=use_reentrant,
  178. **kwargs,
  179. )
  180. def checkpoint_seq(
  181. functions,
  182. x,
  183. every: int = 1,
  184. flatten: bool = False,
  185. skip_last: bool = False,
  186. use_reentrant: Optional[bool] = None,
  187. ):
  188. r"""A helper function for checkpointing sequential models.
  189. Sequential models execute a list of modules/functions in order
  190. (sequentially). Therefore, we can divide such a sequence into segments
  191. and checkpoint each segment. All segments except run in :func:`torch.no_grad`
  192. manner, i.e., not storing the intermediate activations. The inputs of each
  193. checkpointed segment will be saved for re-running the segment in the backward pass.
  194. See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.
  195. .. warning::
  196. Checkpointing currently only supports :func:`torch.autograd.backward`
  197. and only if its `inputs` argument is not passed. :func:`torch.autograd.grad`
  198. is not supported.
  199. .. warning:
  200. At least one of the inputs needs to have :code:`requires_grad=True` if
  201. grads are needed for model inputs, otherwise the checkpointed part of the
  202. model won't have gradients.
  203. Args:
  204. functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially.
  205. x: A Tensor that is input to :attr:`functions`
  206. every: checkpoint every-n functions (default: 1)
  207. flatten: flatten nn.Sequential of nn.Sequentials
  208. skip_last: skip checkpointing the last function in the sequence if True
  209. use_reentrant: Use re-entrant checkpointing
  210. Returns:
  211. Output of running :attr:`functions` sequentially on :attr:`*inputs`
  212. Example:
  213. >>> model = nn.Sequential(...)
  214. >>> input_var = checkpoint_seq(model, input_var, every=2)
  215. """
  216. if use_reentrant is None:
  217. use_reentrant = use_reentrant_ckpt()
  218. def run_function(start, end, functions):
  219. def forward(_x):
  220. for j in range(start, end + 1):
  221. _x = functions[j](_x)
  222. return _x
  223. return forward
  224. if isinstance(functions, torch.nn.Sequential):
  225. functions = functions.children()
  226. if flatten:
  227. functions = chain.from_iterable(functions)
  228. if not isinstance(functions, (tuple, list)):
  229. functions = tuple(functions)
  230. num_checkpointed = len(functions)
  231. if skip_last:
  232. num_checkpointed -= 1
  233. end = -1
  234. for start in range(0, num_checkpointed, every):
  235. end = min(start + every - 1, num_checkpointed - 1)
  236. x = torch.utils.checkpoint.checkpoint(
  237. run_function(start, end, functions),
  238. x,
  239. use_reentrant=use_reentrant,
  240. )
  241. if skip_last:
  242. return run_function(end + 1, len(functions) - 1, functions)(x)
  243. return x
  244. def adapt_input_conv(in_chans: int, conv_weight: Tensor) -> Tensor:
  245. conv_type = conv_weight.dtype
  246. conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU
  247. O, I, J, K = conv_weight.shape
  248. if in_chans == 1:
  249. if I > 3:
  250. assert conv_weight.shape[1] % 3 == 0
  251. # For models with space2depth stems
  252. conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)
  253. conv_weight = conv_weight.sum(dim=2, keepdim=False)
  254. else:
  255. conv_weight = conv_weight.sum(dim=1, keepdim=True)
  256. elif in_chans != 3:
  257. if I != 3:
  258. raise NotImplementedError('Weight format not supported by conversion.')
  259. else:
  260. # NOTE this strategy should be better than random init, but there could be other combinations of
  261. # the original RGB input layer weights that'd work better for specific cases.
  262. repeat = int(math.ceil(in_chans / 3))
  263. conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
  264. conv_weight *= (3 / float(in_chans))
  265. conv_weight = conv_weight.to(conv_type)
  266. return conv_weight
  267. def reinit_non_persistent_buffers(model: nn.Module) -> List[str]:
  268. """Walk model and call init_non_persistent_buffers() on modules that have it.
  269. This reinitializes computed buffers (like RoPE frequencies, attention bias indices)
  270. that are marked as non-persistent and thus not saved in checkpoints. These buffers
  271. are typically computed from module configuration and need to be reinitialized after
  272. loading a checkpoint.
  273. Args:
  274. model: Model to reinitialize buffers for
  275. Returns:
  276. List of module names that were reinitialized
  277. Example:
  278. >>> model = create_model('vit_base', pretrained=True)
  279. >>> # After loading checkpoint or moving to new device
  280. >>> reinitialized = reinit_non_persistent_buffers(model)
  281. >>> print(f"Reinitialized {len(reinitialized)} modules")
  282. """
  283. reinitialized = []
  284. for name, module in model.named_modules():
  285. if hasattr(module, 'init_non_persistent_buffers'):
  286. module.init_non_persistent_buffers()
  287. reinitialized.append(name if name else '(root)')
  288. return reinitialized