# create_act.py

  1. """ Activation Factory
  2. Hacked together by / Copyright 2020 Ross Wightman
  3. """
  4. from typing import Callable, Optional, Type, Union
  5. from .activations import *
  6. from .activations_me import *
  7. from .config import is_exportable, is_scriptable
  8. from .typing import LayerType
  9. # PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7.
  10. # Also hardsigmoid, hardswish, and soon mish. This code will use native version if present.
  11. # Eventually, the custom SiLU, Mish, Hard*, layers will be removed and only native variants will be used.
  12. _has_silu = 'silu' in dir(torch.nn.functional)
  13. _has_hardswish = 'hardswish' in dir(torch.nn.functional)
  14. _has_hardsigmoid = 'hardsigmoid' in dir(torch.nn.functional)
  15. _has_mish = 'mish' in dir(torch.nn.functional)
  16. _ACT_FN_DEFAULT = dict(
  17. silu=F.silu if _has_silu else swish,
  18. swish=F.silu if _has_silu else swish,
  19. mish=F.mish if _has_mish else mish,
  20. relu=F.relu,
  21. relu6=F.relu6,
  22. leaky_relu=F.leaky_relu,
  23. elu=F.elu,
  24. celu=F.celu,
  25. selu=F.selu,
  26. gelu=gelu,
  27. gelu_tanh=gelu_tanh,
  28. quick_gelu=quick_gelu,
  29. sigmoid=sigmoid,
  30. tanh=tanh,
  31. hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid,
  32. hard_swish=F.hardswish if _has_hardswish else hard_swish,
  33. hard_mish=hard_mish,
  34. )
  35. _ACT_FN_ME = dict(
  36. silu=F.silu if _has_silu else swish_me,
  37. swish=F.silu if _has_silu else swish_me,
  38. mish=F.mish if _has_mish else mish_me,
  39. hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_me,
  40. hard_swish=F.hardswish if _has_hardswish else hard_swish_me,
  41. hard_mish=hard_mish_me,
  42. )
  43. _ACT_FNS = (_ACT_FN_ME, _ACT_FN_DEFAULT)
  44. for a in _ACT_FNS:
  45. a.setdefault('hardsigmoid', a.get('hard_sigmoid'))
  46. a.setdefault('hardswish', a.get('hard_swish'))
  47. _ACT_LAYER_DEFAULT = dict(
  48. silu=nn.SiLU if _has_silu else Swish,
  49. swish=nn.SiLU if _has_silu else Swish,
  50. mish=nn.Mish if _has_mish else Mish,
  51. relu=nn.ReLU,
  52. relu6=nn.ReLU6,
  53. leaky_relu=nn.LeakyReLU,
  54. elu=nn.ELU,
  55. prelu=PReLU,
  56. celu=nn.CELU,
  57. selu=nn.SELU,
  58. gelu=GELU,
  59. gelu_tanh=GELUTanh,
  60. quick_gelu=QuickGELU,
  61. sigmoid=Sigmoid,
  62. tanh=Tanh,
  63. hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoid,
  64. hard_swish=nn.Hardswish if _has_hardswish else HardSwish,
  65. hard_mish=HardMish,
  66. identity=nn.Identity,
  67. )
  68. _ACT_LAYER_ME = dict(
  69. silu=nn.SiLU if _has_silu else SwishMe,
  70. swish=nn.SiLU if _has_silu else SwishMe,
  71. mish=nn.Mish if _has_mish else MishMe,
  72. hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidMe,
  73. hard_swish=nn.Hardswish if _has_hardswish else HardSwishMe,
  74. hard_mish=HardMishMe,
  75. )
  76. _ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_DEFAULT)
  77. for a in _ACT_LAYERS:
  78. a.setdefault('hardsigmoid', a.get('hard_sigmoid'))
  79. a.setdefault('hardswish', a.get('hard_swish'))
  80. def get_act_fn(name: Optional[LayerType] = 'relu'):
  81. """ Activation Function Factory
  82. Fetching activation fns by name with this function allows export or torch script friendly
  83. functions to be returned dynamically based on current config.
  84. """
  85. if not name:
  86. return None
  87. if isinstance(name, Callable):
  88. return name
  89. name = name.lower()
  90. if not (is_exportable() or is_scriptable()):
  91. # If not exporting or scripting the model, first look for a memory-efficient version with
  92. # custom autograd, then fallback
  93. if name in _ACT_FN_ME:
  94. return _ACT_FN_ME[name]
  95. return _ACT_FN_DEFAULT[name]
  96. def get_act_layer(name: Optional[LayerType] = 'relu'):
  97. """ Activation Layer Factory
  98. Fetching activation layers by name with this function allows export or torch script friendly
  99. functions to be returned dynamically based on current config.
  100. """
  101. if name is None:
  102. return None
  103. if not isinstance(name, str):
  104. # callable, module, etc
  105. return name
  106. if not name:
  107. return None
  108. name = name.lower()
  109. if not (is_exportable() or is_scriptable()):
  110. if name in _ACT_LAYER_ME:
  111. return _ACT_LAYER_ME[name]
  112. return _ACT_LAYER_DEFAULT[name]
  113. def create_act_layer(
  114. name: Optional[LayerType],
  115. inplace: Optional[bool] = None,
  116. **kwargs
  117. ):
  118. act_layer = get_act_layer(name)
  119. if act_layer is None:
  120. return None
  121. if inplace is None:
  122. return act_layer(**kwargs)
  123. try:
  124. return act_layer(inplace=inplace, **kwargs)
  125. except TypeError:
  126. # recover if act layer doesn't have inplace arg
  127. return act_layer(**kwargs)