model.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. """ Model / state_dict utils
  2. Hacked together by / Copyright 2020 Ross Wightman
  3. """
  4. import fnmatch
  5. from copy import deepcopy
  6. import torch
  7. from torchvision.ops.misc import FrozenBatchNorm2d
  8. from timm.layers import BatchNormAct2d, SyncBatchNormAct, FrozenBatchNormAct2d,\
  9. freeze_batch_norm_2d, unfreeze_batch_norm_2d
  10. from .model_ema import ModelEma
  11. def unwrap_model(model):
  12. if isinstance(model, ModelEma):
  13. return unwrap_model(model.ema)
  14. else:
  15. if hasattr(model, 'module'):
  16. return unwrap_model(model.module)
  17. elif hasattr(model, '_orig_mod'):
  18. return unwrap_model(model._orig_mod)
  19. else:
  20. return model
  21. def get_state_dict(model, unwrap_fn=unwrap_model):
  22. return unwrap_fn(model).state_dict()
  23. def avg_sq_ch_mean(model, input, output):
  24. """ calculate average channel square mean of output activations
  25. """
  26. return torch.mean(output.mean(axis=[0, 2, 3]) ** 2).item()
  27. def avg_ch_var(model, input, output):
  28. """ calculate average channel variance of output activations
  29. """
  30. return torch.mean(output.var(axis=[0, 2, 3])).item()
  31. def avg_ch_var_residual(model, input, output):
  32. """ calculate average channel variance of output activations
  33. """
  34. return torch.mean(output.var(axis=[0, 2, 3])).item()
  35. class ActivationStatsHook:
  36. """Iterates through each of `model`'s modules and matches modules using unix pattern
  37. matching based on `hook_fn_locs` and registers `hook_fn` to the module if there is
  38. a match.
  39. Arguments:
  40. model (nn.Module): model from which we will extract the activation stats
  41. hook_fn_locs (List[str]): List of `hook_fn` locations based on Unix type string
  42. matching with the name of model's modules.
  43. hook_fns (List[Callable]): List of hook functions to be registered at every
  44. module in `layer_names`.
  45. Inspiration from https://docs.fast.ai/callback.hook.html.
  46. Refer to https://gist.github.com/amaarora/6e56942fcb46e67ba203f3009b30d950 for an example
  47. on how to plot Signal Propagation Plots using `ActivationStatsHook`.
  48. """
  49. def __init__(self, model, hook_fn_locs, hook_fns):
  50. self.model = model
  51. self.hook_fn_locs = hook_fn_locs
  52. self.hook_fns = hook_fns
  53. if len(hook_fn_locs) != len(hook_fns):
  54. raise ValueError("Please provide `hook_fns` for each `hook_fn_locs`, \
  55. their lengths are different.")
  56. self.stats = dict((hook_fn.__name__, []) for hook_fn in hook_fns)
  57. for hook_fn_loc, hook_fn in zip(hook_fn_locs, hook_fns):
  58. self.register_hook(hook_fn_loc, hook_fn)
  59. def _create_hook(self, hook_fn):
  60. def append_activation_stats(module, input, output):
  61. out = hook_fn(module, input, output)
  62. self.stats[hook_fn.__name__].append(out)
  63. return append_activation_stats
  64. def register_hook(self, hook_fn_loc, hook_fn):
  65. for name, module in self.model.named_modules():
  66. if not fnmatch.fnmatch(name, hook_fn_loc):
  67. continue
  68. module.register_forward_hook(self._create_hook(hook_fn))
  69. def extract_spp_stats(
  70. model,
  71. hook_fn_locs,
  72. hook_fns,
  73. input_shape=[8, 3, 224, 224]):
  74. """Extract average square channel mean and variance of activations during
  75. forward pass to plot Signal Propagation Plots (SPP).
  76. Paper: https://arxiv.org/abs/2101.08692
  77. Example Usage: https://gist.github.com/amaarora/6e56942fcb46e67ba203f3009b30d950
  78. """
  79. x = torch.normal(0., 1., input_shape)
  80. hook = ActivationStatsHook(model, hook_fn_locs=hook_fn_locs, hook_fns=hook_fns)
  81. _ = model(x)
  82. return hook.stats
  83. def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True, mode='freeze'):
  84. """
  85. Freeze or unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is
  86. done in place.
  87. Args:
  88. root_module (nn.Module, optional): Root module relative to which the `submodules` are referenced.
  89. submodules (list[str]): List of modules for which the parameters will be (un)frozen. They are to be provided as
  90. named modules relative to the root module (accessible via `root_module.named_modules()`). An empty list
  91. means that the whole root module will be (un)frozen. Defaults to []
  92. include_bn_running_stats (bool): Whether to also (un)freeze the running statistics of batch norm 2d layers.
  93. Defaults to `True`.
  94. mode (bool): Whether to freeze ("freeze") or unfreeze ("unfreeze"). Defaults to `"freeze"`.
  95. """
  96. assert mode in ["freeze", "unfreeze"], '`mode` must be one of "freeze" or "unfreeze"'
  97. if isinstance(root_module, (
  98. torch.nn.modules.batchnorm.BatchNorm2d,
  99. torch.nn.modules.batchnorm.SyncBatchNorm,
  100. BatchNormAct2d,
  101. SyncBatchNormAct,
  102. )):
  103. # Raise assertion here because we can't convert it in place
  104. raise AssertionError(
  105. "You have provided a batch norm layer as the `root module`. Please use "
  106. "`timm.utils.model.freeze_batch_norm_2d` or `timm.utils.model.unfreeze_batch_norm_2d` instead.")
  107. if isinstance(submodules, str):
  108. submodules = [submodules]
  109. named_modules = submodules
  110. submodules = [root_module.get_submodule(m) for m in submodules]
  111. if not len(submodules):
  112. named_modules, submodules = list(zip(*root_module.named_children()))
  113. for n, m in zip(named_modules, submodules):
  114. # (Un)freeze parameters
  115. for p in m.parameters():
  116. p.requires_grad = False if mode == 'freeze' else True
  117. if include_bn_running_stats:
  118. # Helper to add submodule specified as a named_module
  119. def _add_submodule(module, name, submodule):
  120. split = name.rsplit('.', 1)
  121. if len(split) > 1:
  122. module.get_submodule(split[0]).add_module(split[1], submodule)
  123. else:
  124. module.add_module(name, submodule)
  125. # Freeze batch norm
  126. if mode == 'freeze':
  127. res = freeze_batch_norm_2d(m)
  128. # It's possible that `m` is a type of BatchNorm in itself, in which case `unfreeze_batch_norm_2d` won't
  129. # convert it in place, but will return the converted result. In this case `res` holds the converted
  130. # result and we may try to re-assign the named module
  131. if isinstance(m, (
  132. torch.nn.modules.batchnorm.BatchNorm2d,
  133. torch.nn.modules.batchnorm.SyncBatchNorm,
  134. BatchNormAct2d,
  135. SyncBatchNormAct,
  136. )):
  137. _add_submodule(root_module, n, res)
  138. # Unfreeze batch norm
  139. else:
  140. res = unfreeze_batch_norm_2d(m)
  141. # Ditto. See note above in mode == 'freeze' branch
  142. if isinstance(m, (FrozenBatchNorm2d, FrozenBatchNormAct2d)):
  143. _add_submodule(root_module, n, res)
  144. def freeze(root_module, submodules=[], include_bn_running_stats=True):
  145. """
  146. Freeze parameters of the specified modules and those of all their hierarchical descendants. This is done in place.
  147. Args:
  148. root_module (nn.Module): Root module relative to which `submodules` are referenced.
  149. submodules (list[str]): List of modules for which the parameters will be frozen. They are to be provided as
  150. named modules relative to the root module (accessible via `root_module.named_modules()`). An empty list
  151. means that the whole root module will be frozen. Defaults to `[]`.
  152. include_bn_running_stats (bool): Whether to also freeze the running statistics of `BatchNorm2d` and
  153. `SyncBatchNorm` layers. These will be converted to `FrozenBatchNorm2d` in place. Hint: During fine tuning,
  154. it's good practice to freeze batch norm stats. And note that these are different to the affine parameters
  155. which are just normal PyTorch parameters. Defaults to `True`.
  156. Hint: If you want to freeze batch norm ONLY, use `timm.utils.model.freeze_batch_norm_2d`.
  157. Examples::
  158. >>> model = timm.create_model('resnet18')
  159. >>> # Freeze up to and including layer2
  160. >>> submodules = [n for n, _ in model.named_children()]
  161. >>> print(submodules)
  162. ['conv1', 'bn1', 'act1', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'global_pool', 'fc']
  163. >>> freeze(model, submodules[:submodules.index('layer2') + 1])
  164. >>> # Check for yourself that it works as expected
  165. >>> print(model.layer2[0].conv1.weight.requires_grad)
  166. False
  167. >>> print(model.layer3[0].conv1.weight.requires_grad)
  168. True
  169. >>> # Unfreeze
  170. >>> unfreeze(model)
  171. """
  172. _freeze_unfreeze(root_module, submodules, include_bn_running_stats=include_bn_running_stats, mode="freeze")
  173. def unfreeze(root_module, submodules=[], include_bn_running_stats=True):
  174. """
  175. Unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is done in place.
  176. Args:
  177. root_module (nn.Module): Root module relative to which `submodules` are referenced.
  178. submodules (list[str]): List of submodules for which the parameters will be (un)frozen. They are to be provided
  179. as named modules relative to the root module (accessible via `root_module.named_modules()`). An empty
  180. list means that the whole root module will be unfrozen. Defaults to `[]`.
  181. include_bn_running_stats (bool): Whether to also unfreeze the running statistics of `FrozenBatchNorm2d` layers.
  182. These will be converted to `BatchNorm2d` in place. Defaults to `True`.
  183. See example in docstring for `freeze`.
  184. """
  185. _freeze_unfreeze(root_module, submodules, include_bn_running_stats=include_bn_running_stats, mode="unfreeze")
  186. def reparameterize_model(model: torch.nn.Module, inplace=False) -> torch.nn.Module:
  187. if not inplace:
  188. model = deepcopy(model)
  189. def _fuse(m):
  190. for child_name, child in m.named_children():
  191. if hasattr(child, 'fuse'):
  192. setattr(m, child_name, child.fuse())
  193. elif hasattr(child, "reparameterize"):
  194. child.reparameterize()
  195. elif hasattr(child, "switch_to_deploy"):
  196. child.switch_to_deploy()
  197. _fuse(child)
  198. _fuse(model)
  199. return model