_param_groups.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. import fnmatch
  2. import logging
  3. from itertools import islice
  4. from typing import Collection, Optional
  5. from torch import nn as nn
  6. from timm.models import group_parameters
  7. _logger = logging.getLogger(__name__)
  8. def _matches_pattern(name: str, patterns: Collection[str]) -> bool:
  9. """Check if parameter name matches any pattern (supports wildcards)."""
  10. return any(fnmatch.fnmatch(name, pattern) for pattern in patterns)
  11. def param_groups_weight_decay(
  12. model: nn.Module,
  13. weight_decay: float = 1e-5,
  14. no_weight_decay_list: Collection[str] = (),
  15. fallback_list: Collection[str] = (),
  16. fallback_no_weight_decay: bool = False,
  17. ):
  18. # Merge no_weight_decay into fallback_list if requested
  19. if fallback_no_weight_decay:
  20. fallback_list = set(fallback_list) | set(no_weight_decay_list)
  21. decay = []
  22. decay_fallback = []
  23. no_decay = []
  24. no_decay_fallback = []
  25. for name, param in model.named_parameters():
  26. if not param.requires_grad:
  27. continue
  28. # Determine if this is a "fallback" parameter for fallback optimizer (if available)
  29. is_fallback = _matches_pattern(name, fallback_list)
  30. # Determine weight decay
  31. matches_pattern = _matches_pattern(name, no_weight_decay_list)
  32. if param.ndim <= 1 or name.endswith(".bias") or matches_pattern:
  33. # No weight decay
  34. if is_fallback:
  35. no_decay_fallback.append(param)
  36. else:
  37. no_decay.append(param)
  38. else:
  39. # With weight decay
  40. if is_fallback:
  41. decay_fallback.append(param)
  42. else:
  43. decay.append(param)
  44. groups = []
  45. if no_decay:
  46. groups.append({'params': no_decay, 'weight_decay': 0.})
  47. if decay:
  48. groups.append({'params': decay, 'weight_decay': weight_decay})
  49. if no_decay_fallback:
  50. groups.append({'params': no_decay_fallback, 'weight_decay': 0., 'use_fallback': True})
  51. if decay_fallback:
  52. groups.append({'params': decay_fallback, 'weight_decay': weight_decay, 'use_fallback': True})
  53. return groups
  54. def _group(it, size):
  55. it = iter(it)
  56. return iter(lambda: tuple(islice(it, size)), ())
  57. def auto_group_layers(model, layers_per_group=12, num_groups=None):
  58. def _in_head(n, hp):
  59. if not hp:
  60. return True
  61. elif isinstance(hp, (tuple, list)):
  62. return any([n.startswith(hpi) for hpi in hp])
  63. else:
  64. return n.startswith(hp)
  65. head_prefix = getattr(model, 'pretrained_cfg', {}).get('classifier', None)
  66. names_trunk = []
  67. names_head = []
  68. for n, _ in model.named_parameters():
  69. names_head.append(n) if _in_head(n, head_prefix) else names_trunk.append(n)
  70. # group non-head layers
  71. num_trunk_layers = len(names_trunk)
  72. if num_groups is not None:
  73. layers_per_group = -(num_trunk_layers // -num_groups)
  74. names_trunk = list(_group(names_trunk, layers_per_group))
  75. num_trunk_groups = len(names_trunk)
  76. layer_map = {n: i for i, l in enumerate(names_trunk) for n in l}
  77. layer_map.update({n: num_trunk_groups for n in names_head})
  78. return layer_map
  79. _layer_map = auto_group_layers # backward compat
  80. def param_groups_layer_decay(
  81. model: nn.Module,
  82. weight_decay: float = 0.05,
  83. no_weight_decay_list: Collection[str] = (),
  84. fallback_list: Collection[str] = (),
  85. fallback_no_weight_decay: bool = False,
  86. weight_decay_exclude_1d: bool = True,
  87. layer_decay: float = .75,
  88. min_scale: float = 0.,
  89. no_opt_scale: Optional[float] = None,
  90. verbose: bool = False,
  91. ):
  92. """
  93. Parameter groups for layer-wise lr decay & weight decay
  94. Based on BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
  95. """
  96. # Merge no_weight_decay into fallback_list if requested
  97. if fallback_no_weight_decay:
  98. fallback_list = set(fallback_list) | set(no_weight_decay_list)
  99. param_group_names = {} # NOTE for debugging
  100. param_groups = {}
  101. if hasattr(model, 'group_matcher'):
  102. # FIXME interface needs more work
  103. layer_map = group_parameters(model, model.group_matcher(coarse=False), reverse=True)
  104. else:
  105. # fallback
  106. layer_map = auto_group_layers(model)
  107. num_layers = max(layer_map.values()) + 1
  108. layer_max = num_layers - 1
  109. layer_scales = list(max(min_scale, layer_decay ** (layer_max - i)) for i in range(num_layers))
  110. for name, param in model.named_parameters():
  111. if not param.requires_grad:
  112. continue
  113. # Determine if this is a "fallback" parameter for fallback optimizer (if available)
  114. is_fallback = _matches_pattern(name, fallback_list)
  115. # Determine weight decay
  116. if (weight_decay_exclude_1d and param.ndim <= 1) or _matches_pattern(name, no_weight_decay_list):
  117. # no weight decay for 1D parameters and model specific ones
  118. g_decay = "no_decay"
  119. this_decay = 0.
  120. else:
  121. g_decay = "decay"
  122. this_decay = weight_decay
  123. layer_id = layer_map.get(name, layer_max)
  124. this_scale = layer_scales[layer_id]
  125. if no_opt_scale and this_scale < no_opt_scale:
  126. # if the calculated layer scale is below this, exclude from optimization
  127. param.requires_grad = False
  128. continue
  129. fallback_suffix = "_fallback" if is_fallback else ""
  130. group_name = "layer_%d_%s%s" % (layer_id, g_decay, fallback_suffix)
  131. if group_name not in param_groups:
  132. param_group_names[group_name] = {
  133. "lr_scale": this_scale,
  134. "weight_decay": this_decay,
  135. "use_fallback": is_fallback,
  136. "param_names": [],
  137. }
  138. param_groups[group_name] = {
  139. "lr_scale": this_scale,
  140. "weight_decay": this_decay,
  141. "params": [],
  142. }
  143. if is_fallback:
  144. param_groups[group_name]["use_fallback"] = True
  145. param_group_names[group_name]["param_names"].append(name)
  146. param_groups[group_name]["params"].append(param)
  147. if verbose:
  148. import json
  149. _logger.info("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
  150. return list(param_groups.values())