create_norm_act.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. """ NormAct (Normalization + Activation Layer) Factory
  2. Create norm + act combo modules that attempt to be backwards compatible with separate norm + act
  3. instances in models. Where these are used it will be possible to swap separate BN + act layers with
  4. combined modules like IABN or EvoNorms.
  5. Hacked together by / Copyright 2020 Ross Wightman
  6. """
  7. import types
  8. import functools
  9. from typing import Optional
  10. from .evo_norm import *
  11. from .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d
  12. from .norm_act import (
  13. BatchNormAct2d,
  14. GroupNormAct,
  15. GroupNorm1Act,
  16. LayerNormAct,
  17. LayerNormActFp32,
  18. LayerNormAct2d,
  19. LayerNormAct2dFp32,
  20. RmsNormAct,
  21. RmsNormActFp32,
  22. RmsNormAct2d,
  23. RmsNormAct2dFp32,
  24. )
  25. from .inplace_abn import InplaceAbn
  26. from .typing import LayerType
  27. _NORM_ACT_MAP = dict(
  28. batchnorm=BatchNormAct2d,
  29. batchnorm2d=BatchNormAct2d,
  30. groupnorm=GroupNormAct,
  31. groupnorm1=GroupNorm1Act,
  32. layernorm=LayerNormAct,
  33. layernorm2d=LayerNormAct2d,
  34. layernormfp32=LayerNormActFp32,
  35. layernorm2dfp32=LayerNormAct2dFp32,
  36. evonormb0=EvoNorm2dB0,
  37. evonormb1=EvoNorm2dB1,
  38. evonormb2=EvoNorm2dB2,
  39. evonorms0=EvoNorm2dS0,
  40. evonorms0a=EvoNorm2dS0a,
  41. evonorms1=EvoNorm2dS1,
  42. evonorms1a=EvoNorm2dS1a,
  43. evonorms2=EvoNorm2dS2,
  44. evonorms2a=EvoNorm2dS2a,
  45. frn=FilterResponseNormAct2d,
  46. frntlu=FilterResponseNormTlu2d,
  47. inplaceabn=InplaceAbn,
  48. iabn=InplaceAbn,
  49. rmsnorm=RmsNormAct,
  50. rmsnorm2d=RmsNormAct2d,
  51. rmsnormfp32=RmsNormActFp32,
  52. rmsnorm2dfp32=RmsNormAct2dFp32,
  53. )
  54. _NORM_ACT_TYPES = {m for n, m in _NORM_ACT_MAP.items()}
  55. # Reverse map from base norm layer names to norm+act layer classes
  56. _NORM_TO_NORM_ACT_MAP = dict(
  57. batchnorm=BatchNormAct2d,
  58. batchnorm2d=BatchNormAct2d,
  59. groupnorm=GroupNormAct,
  60. groupnorm1=GroupNorm1Act,
  61. layernorm=LayerNormAct,
  62. layernorm2d=LayerNormAct2d,
  63. layernormfp32=LayerNormActFp32,
  64. layernorm2dfp32=LayerNormAct2dFp32,
  65. rmsnorm=RmsNormAct,
  66. rmsnorm2d=RmsNormAct2d,
  67. rmsnormfp32=RmsNormActFp32,
  68. rmsnorm2dfp32=RmsNormAct2dFp32,
  69. )
  70. # has act_layer arg to define act type
  71. _NORM_ACT_REQUIRES_ARG = {
  72. BatchNormAct2d,
  73. GroupNormAct,
  74. GroupNorm1Act,
  75. LayerNormAct,
  76. LayerNormAct2d,
  77. LayerNormActFp32,
  78. LayerNormAct2dFp32,
  79. FilterResponseNormAct2d,
  80. InplaceAbn,
  81. RmsNormAct,
  82. RmsNormAct2d,
  83. RmsNormActFp32,
  84. RmsNormAct2dFp32,
  85. }
  86. def create_norm_act_layer(
  87. layer_name: LayerType,
  88. num_features: int,
  89. act_layer: Optional[LayerType] = None,
  90. apply_act: bool = True,
  91. jit: bool = False,
  92. **kwargs,
  93. ):
  94. layer = get_norm_act_layer(layer_name, act_layer=act_layer)
  95. layer_instance = layer(num_features, apply_act=apply_act, **kwargs)
  96. if jit:
  97. layer_instance = torch.jit.script(layer_instance)
  98. return layer_instance
  99. def get_norm_act_layer(
  100. norm_layer: LayerType,
  101. act_layer: Optional[LayerType] = None,
  102. ):
  103. if norm_layer is None:
  104. return None
  105. assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
  106. assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial))
  107. norm_act_kwargs = {}
  108. # unbind partial fn, so args can be rebound later
  109. if isinstance(norm_layer, functools.partial):
  110. norm_act_kwargs.update(norm_layer.keywords)
  111. norm_layer = norm_layer.func
  112. if isinstance(norm_layer, str):
  113. if not norm_layer:
  114. return None
  115. layer_name = norm_layer.replace('_', '').lower().split('-')[0]
  116. norm_act_layer = _NORM_ACT_MAP[layer_name]
  117. elif norm_layer in _NORM_ACT_TYPES:
  118. norm_act_layer = norm_layer
  119. elif isinstance(norm_layer, types.FunctionType):
  120. # if function type, must be a lambda/fn that creates a norm_act layer
  121. norm_act_layer = norm_layer
  122. else:
  123. # Use reverse map to find the corresponding norm+act layer
  124. type_name = norm_layer.__name__.lower()
  125. norm_act_layer = _NORM_TO_NORM_ACT_MAP.get(type_name, None)
  126. assert norm_act_layer is not None, f"No equivalent norm_act layer for {type_name}"
  127. if norm_act_layer in _NORM_ACT_REQUIRES_ARG:
  128. # pass `act_layer` through for backwards compat where `act_layer=None` implies no activation.
  129. # In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types
  130. norm_act_kwargs.setdefault('act_layer', act_layer)
  131. if norm_act_kwargs:
  132. norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs) # bind/rebind args
  133. return norm_act_layer