activations_me.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
""" Activations (memory-efficient w/ custom autograd)
A collection of activation functions and modules with a common interface so that they can
be swapped easily. All have an `inplace` arg even if it is not used.
These activations are not compatible with jit scripting or ONNX export of the model; please use
the basic versions of the activations in those cases.
Hacked together by / Copyright 2020 Ross Wightman
"""
  8. import torch
  9. from torch import nn as nn
  10. from torch.nn import functional as F
  11. def swish_fwd(x):
  12. return x.mul(torch.sigmoid(x))
  13. def swish_bwd(x, grad_output):
  14. x_sigmoid = torch.sigmoid(x)
  15. return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))
  16. class SwishAutoFn(torch.autograd.Function):
  17. """ optimised Swish w/ memory-efficient checkpoint
  18. Inspired by conversation btw Jeremy Howard & Adam Pazske
  19. https://twitter.com/jeremyphoward/status/1188251041835315200
  20. """
  21. @staticmethod
  22. def symbolic(g, x):
  23. return g.op("Mul", x, g.op("Sigmoid", x))
  24. @staticmethod
  25. def forward(ctx, x):
  26. ctx.save_for_backward(x)
  27. return swish_fwd(x)
  28. @staticmethod
  29. def backward(ctx, grad_output):
  30. x = ctx.saved_tensors[0]
  31. return swish_bwd(x, grad_output)
  32. def swish_me(x, inplace=False):
  33. return SwishAutoFn.apply(x)
  34. class SwishMe(nn.Module):
  35. def __init__(self, inplace: bool = False):
  36. super().__init__()
  37. def forward(self, x):
  38. return SwishAutoFn.apply(x)
  39. def mish_fwd(x):
  40. return x.mul(torch.tanh(F.softplus(x)))
  41. def mish_bwd(x, grad_output):
  42. x_sigmoid = torch.sigmoid(x)
  43. x_tanh_sp = F.softplus(x).tanh()
  44. return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))
  45. class MishAutoFn(torch.autograd.Function):
  46. """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
  47. A memory efficient variant of Mish
  48. """
  49. @staticmethod
  50. def forward(ctx, x):
  51. ctx.save_for_backward(x)
  52. return mish_fwd(x)
  53. @staticmethod
  54. def backward(ctx, grad_output):
  55. x = ctx.saved_tensors[0]
  56. return mish_bwd(x, grad_output)
  57. def mish_me(x, inplace=False):
  58. return MishAutoFn.apply(x)
  59. class MishMe(nn.Module):
  60. def __init__(self, inplace: bool = False):
  61. super().__init__()
  62. def forward(self, x):
  63. return MishAutoFn.apply(x)
  64. def hard_sigmoid_fwd(x, inplace: bool = False):
  65. return (x + 3).clamp(min=0, max=6).div(6.)
  66. def hard_sigmoid_bwd(x, grad_output):
  67. m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6.
  68. return grad_output * m
  69. class HardSigmoidAutoFn(torch.autograd.Function):
  70. @staticmethod
  71. def forward(ctx, x):
  72. ctx.save_for_backward(x)
  73. return hard_sigmoid_fwd(x)
  74. @staticmethod
  75. def backward(ctx, grad_output):
  76. x = ctx.saved_tensors[0]
  77. return hard_sigmoid_bwd(x, grad_output)
  78. def hard_sigmoid_me(x, inplace: bool = False):
  79. return HardSigmoidAutoFn.apply(x)
  80. class HardSigmoidMe(nn.Module):
  81. def __init__(self, inplace: bool = False):
  82. super().__init__()
  83. def forward(self, x):
  84. return HardSigmoidAutoFn.apply(x)
  85. def hard_swish_fwd(x):
  86. return x * (x + 3).clamp(min=0, max=6).div(6.)
  87. def hard_swish_bwd(x, grad_output):
  88. m = torch.ones_like(x) * (x >= 3.)
  89. m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m)
  90. return grad_output * m
  91. class HardSwishAutoFn(torch.autograd.Function):
  92. """A memory efficient HardSwish activation"""
  93. @staticmethod
  94. def forward(ctx, x):
  95. ctx.save_for_backward(x)
  96. return hard_swish_fwd(x)
  97. @staticmethod
  98. def backward(ctx, grad_output):
  99. x = ctx.saved_tensors[0]
  100. return hard_swish_bwd(x, grad_output)
  101. @staticmethod
  102. def symbolic(g, self):
  103. input = g.op("Add", self, g.op('Constant', value_t=torch.tensor(3, dtype=torch.float)))
  104. hardtanh_ = g.op("Clip", input, g.op('Constant', value_t=torch.tensor(0, dtype=torch.float)), g.op('Constant', value_t=torch.tensor(6, dtype=torch.float)))
  105. hardtanh_ = g.op("Div", hardtanh_, g.op('Constant', value_t=torch.tensor(6, dtype=torch.float)))
  106. return g.op("Mul", self, hardtanh_)
  107. def hard_swish_me(x, inplace=False):
  108. return HardSwishAutoFn.apply(x)
  109. class HardSwishMe(nn.Module):
  110. def __init__(self, inplace: bool = False):
  111. super().__init__()
  112. def forward(self, x):
  113. return HardSwishAutoFn.apply(x)
  114. def hard_mish_fwd(x):
  115. return 0.5 * x * (x + 2).clamp(min=0, max=2)
  116. def hard_mish_bwd(x, grad_output):
  117. m = torch.ones_like(x) * (x >= -2.)
  118. m = torch.where((x >= -2.) & (x <= 0.), x + 1., m)
  119. return grad_output * m
  120. class HardMishAutoFn(torch.autograd.Function):
  121. """ A memory efficient variant of Hard Mish
  122. Experimental, based on notes by Mish author Diganta Misra at
  123. https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
  124. """
  125. @staticmethod
  126. def forward(ctx, x):
  127. ctx.save_for_backward(x)
  128. return hard_mish_fwd(x)
  129. @staticmethod
  130. def backward(ctx, grad_output):
  131. x = ctx.saved_tensors[0]
  132. return hard_mish_bwd(x, grad_output)
  133. def hard_mish_me(x, inplace: bool = False):
  134. return HardMishAutoFn.apply(x)
  135. class HardMishMe(nn.Module):
  136. def __init__(self, inplace: bool = False):
  137. super().__init__()
  138. def forward(self, x):
  139. return HardMishAutoFn.apply(x)