activations.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360
  1. # Copyright 2020 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import functools
  15. import math
  16. from collections import OrderedDict
  17. import torch
  18. from torch import Tensor, nn
  19. from .integrations.hub_kernels import use_kernel_forward_from_hub
  20. from .utils import logging
  21. from .utils.import_utils import is_torchdynamo_compiling
# Module-level logger, obtained from the package's own `logging` utilities imported above.
logger = logging.get_logger(__name__)
  23. @use_kernel_forward_from_hub("GeluTanh")
  24. class GELUTanh(nn.Module):
  25. """
  26. A fast C implementation of the tanh approximation of the GeLU activation function. See
  27. https://huggingface.co/papers/1606.08415.
  28. This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical
  29. match due to rounding errors.
  30. """
  31. def __init__(self, use_gelu_tanh_python: bool = False):
  32. super().__init__()
  33. if use_gelu_tanh_python:
  34. self.act = self._gelu_tanh_python
  35. else:
  36. self.act = functools.partial(nn.functional.gelu, approximate="tanh")
  37. def _gelu_tanh_python(self, input: Tensor) -> Tensor:
  38. return input * 0.5 * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
  39. def forward(self, input: Tensor) -> Tensor:
  40. return self.act(input)
# Backward-compatibility alias: autoawq (now archived) imports `PytorchGELUTanh` from activations.py, so the old
# name is kept pointing at `GELUTanh`.
PytorchGELUTanh = GELUTanh
  43. @use_kernel_forward_from_hub("NewGELU")
  44. class NewGELUActivation(nn.Module):
  45. """
  46. Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
  47. the Gaussian Error Linear Units paper: https://huggingface.co/papers/1606.08415
  48. """
  49. def forward(self, input: Tensor) -> Tensor:
  50. return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
  51. @use_kernel_forward_from_hub("GeLU")
  52. class GELUActivation(nn.Module):
  53. """
  54. Original Implementation of the GELU activation function in Google BERT repo when initially created. For
  55. information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
  56. torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
  57. Also see the Gaussian Error Linear Units paper: https://huggingface.co/papers/1606.08415
  58. """
  59. def __init__(self, use_gelu_python: bool = False):
  60. super().__init__()
  61. if use_gelu_python:
  62. self.act = self._gelu_python
  63. else:
  64. self.act = nn.functional.gelu
  65. def _gelu_python(self, input: Tensor) -> Tensor:
  66. return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
  67. def forward(self, input: Tensor) -> Tensor:
  68. return self.act(input)
  69. @use_kernel_forward_from_hub("SiLU")
  70. class SiLUActivation(nn.Module):
  71. """
  72. See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
  73. Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
  74. Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
  75. Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
  76. later.
  77. """
  78. def forward(self, input: Tensor) -> Tensor:
  79. return nn.functional.silu(input)
  80. @use_kernel_forward_from_hub("FastGELU")
  81. class FastGELUActivation(nn.Module):
  82. """
  83. Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
  84. """
  85. def forward(self, input: Tensor) -> Tensor:
  86. return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
  87. @use_kernel_forward_from_hub("QuickGELU")
  88. class QuickGELUActivation(nn.Module):
  89. """
  90. Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
  91. """
  92. def forward(self, input: Tensor) -> Tensor:
  93. return input * torch.sigmoid(1.702 * input)
  94. class ClippedGELUActivation(nn.Module):
  95. """
  96. Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as
  97. it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to
  98. https://huggingface.co/papers/2004.09602.
  99. Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
  100. initially created.
  101. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
  102. torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://huggingface.co/papers/1606.08415
  103. """
  104. def __init__(self, min: float, max: float):
  105. if min > max:
  106. raise ValueError(f"min should be < max (got min: {min}, max: {max})")
  107. super().__init__()
  108. self.min = min
  109. self.max = max
  110. def forward(self, x: Tensor) -> Tensor:
  111. return torch.clip(gelu(x), self.min, self.max)
  112. class AccurateGELUActivation(nn.Module):
  113. """
  114. Applies GELU approximation that is faster than default and more accurate than QuickGELU. See:
  115. https://github.com/hendrycks/GELUs
  116. Implemented along with MEGA (Moving Average Equipped Gated Attention)
  117. """
  118. def __init__(self):
  119. super().__init__()
  120. self.precomputed_constant = math.sqrt(2 / math.pi)
  121. def forward(self, input: Tensor) -> Tensor:
  122. return 0.5 * input * (1 + torch.tanh(self.precomputed_constant * (input + 0.044715 * torch.pow(input, 3))))
  123. class MishActivation(nn.Module):
  124. """
  125. See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://huggingface.co/papers/1908.08681). Also
  126. visit the official repository for the paper: https://github.com/digantamisra98/Mish
  127. """
  128. def __init__(self):
  129. super().__init__()
  130. self.act = nn.functional.mish
  131. def _mish_python(self, input: Tensor) -> Tensor:
  132. return input * torch.tanh(nn.functional.softplus(input))
  133. def forward(self, input: Tensor) -> Tensor:
  134. return self.act(input)
  135. class LinearActivation(nn.Module):
  136. """
  137. Applies the linear activation function, i.e. forwarding input directly to output.
  138. """
  139. def forward(self, input: Tensor) -> Tensor:
  140. return input
  141. class LaplaceActivation(nn.Module):
  142. """
  143. Applies elementwise activation based on Laplace function, introduced in MEGA as an attention activation. See
  144. https://huggingface.co/papers/2209.10655
  145. Inspired by squared relu, but with bounded range and gradient for better stability
  146. """
  147. def forward(self, input, mu=0.707107, sigma=0.282095):
  148. input = (input - mu).div(sigma * math.sqrt(2.0))
  149. return 0.5 * (1.0 + torch.erf(input))
  150. class ReLUSquaredActivation(nn.Module):
  151. """
  152. Applies the relu^2 activation introduced in https://huggingface.co/papers/2109.08668
  153. """
  154. def forward(self, input):
  155. relu_applied = nn.functional.relu(input)
  156. squared = torch.square(relu_applied)
  157. return squared
  158. class ClassInstantier(OrderedDict):
  159. def __getitem__(self, key):
  160. content = super().__getitem__(key)
  161. cls, kwargs = content if isinstance(content, tuple) else (content, {})
  162. return cls(**kwargs)
class XIELUActivation(nn.Module):
    """
    Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010

    If the user has installed the nickjbrowning/XIELU wheel, we import xIELU CUDA and use it for CUDA inputs when
    torch._dynamo is not compiling. Otherwise, we emit a single warning and use xIELU Python.

    Args:
        alpha_p_init: Initial value of the positive-branch coefficient. It is stored as the learnable parameter
            `alpha_p` in inverse-softplus form, `log(expm1(alpha_p_init))`.
        alpha_n_init: Initial value of the negative-branch coefficient. It is stored as the learnable parameter
            `alpha_n` in inverse-softplus form, `log(expm1(alpha_n_init - beta))`.
        beta: Coefficient of the linear term `beta * x`, registered as a buffer.
        eps: Upper bound applied to `x` inside `expm1` on the negative branch, registered as a buffer.
        dtype: dtype used to create the parameters and buffers.
        with_vector_loads: Flag stored on the module and forwarded unchanged to the CUDA kernel.
    """

    def __init__(
        self,
        alpha_p_init=0.8,
        alpha_n_init=0.8,
        beta=0.5,
        eps=-1e-6,
        dtype=torch.bfloat16,
        with_vector_loads=False,
    ):
        super().__init__()
        # The parameters hold log(expm1(v)), the inverse of softplus, so softplus(param) gives back v.
        self.alpha_p = nn.Parameter(torch.log(torch.expm1(torch.tensor(alpha_p_init, dtype=dtype))).unsqueeze(0))
        # NOTE(review): log(expm1(.)) is only finite for a positive argument, so this presumably requires
        # alpha_n_init > beta - confirm against callers.
        self.alpha_n = nn.Parameter(
            torch.log(torch.expm1(torch.tensor(alpha_n_init - beta, dtype=dtype))).unsqueeze(0)
        )
        self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
        self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
        self.with_vector_loads = with_vector_loads
        # Temporary until xIELU CUDA fully implemented
        # Plain Python floats are kept so the CUDA path can pass scalars without calling .item() on the buffers.
        self._beta_scalar = float(beta)
        self._eps_scalar = float(eps)

        # Stays None when the CUDA extension is unavailable; `forward` checks it to pick the implementation.
        self._xielu_cuda_obj = None
        try:
            import xielu.ops  # noqa: F401

            self._xielu_cuda_obj = torch.classes.xielu.XIELU()
            msg = "Using experimental xIELU CUDA."
            try:
                from torch.compiler import allow_in_graph

                self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
                msg += " Enabled torch._dynamo for xIELU CUDA."
            except Exception as err:
                # Best effort: fall back to calling the CUDA wrapper directly, without allow_in_graph.
                msg += f" Could not enable torch._dynamo for xIELU ({err}) - this may result in slower performance."
                self._xielu_cuda_fn = self._xielu_cuda
            logger.warning_once(msg)
        except Exception as err:
            # Any failure while importing or constructing the CUDA object leaves `_xielu_cuda_obj` as None.
            logger.warning_once(
                f"CUDA-fused xIELU not available ({err}) – falling back to a Python version.\n"
                "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`"
            )

    def _xielu_python(self, x: Tensor) -> Tensor:
        """Evaluate xIELU with elementwise torch ops, selecting the branch per element on `x > 0`."""
        # softplus undoes the inverse-softplus storage used in `__init__`.
        alpha_p = nn.functional.softplus(self.alpha_p)
        alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
        return torch.where(
            x > 0,
            alpha_p * x * x + self.beta * x,
            (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + self.beta * x,
        )

    def _xielu_cuda(self, x: Tensor) -> Tensor:
        """Firewall function to prevent torch.compile from seeing .item() calls"""
        original_shape = x.shape
        # CUDA kernel expects 3D tensors, reshape if needed
        # Inputs with fewer than 3 dims get leading singleton dims.
        while x.dim() < 3:
            x = x.unsqueeze(0)
        # Inputs with more than 3 dims are flattened to (-1, 1, last_dim).
        if x.dim() > 3:
            x = x.view(-1, 1, x.size(-1))
        if original_shape != x.shape:
            logger.warning_once(
                "Warning: xIELU input tensor expects 3 dimensions but got (shape: %s). Reshaping to (shape: %s).",
                original_shape,
                x.shape,
            )
        result = self._xielu_cuda_obj.forward(
            x,
            self.alpha_p.to(x.dtype),
            self.alpha_n.to(x.dtype),
            # Temporary until xIELU CUDA fully implemented -> self.{beta,eps}.item()
            self._beta_scalar,
            self._eps_scalar,
            self.with_vector_loads,
        )
        # Restore the caller's original shape.
        return result.view(original_shape)

    def forward(self, input: Tensor) -> Tensor:
        """Apply xIELU, using the CUDA implementation when it is available and usable, else the Python one."""
        if self._xielu_cuda_obj is not None and input.is_cuda:
            if not is_torchdynamo_compiling():
                return self._xielu_cuda_fn(input)
            else:
                # While torch._dynamo is compiling, fall through to the Python implementation below.
                logger.warning_once("torch._dynamo is compiling, using Python version of xIELU.")
        return self._xielu_python(input)
# Maps each activation name to either a module class or a `(class, kwargs)` tuple. `ClassInstantier` unpacks the
# tuple form and passes the kwargs to the class constructor.
ACT2CLS = {
    "gelu": GELUActivation,
    "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
    "gelu_fast": FastGELUActivation,
    "gelu_new": NewGELUActivation,
    "gelu_python": (GELUActivation, {"use_gelu_python": True}),
    "gelu_pytorch_tanh": GELUTanh,
    "gelu_python_tanh": (GELUTanh, {"use_gelu_tanh_python": True}),
    "gelu_accurate": AccurateGELUActivation,
    "hardswish": nn.Hardswish,
    "laplace": LaplaceActivation,
    "leaky_relu": nn.LeakyReLU,
    "linear": LinearActivation,
    "mish": MishActivation,
    "quick_gelu": QuickGELUActivation,
    "relu": nn.ReLU,
    "relu2": ReLUSquaredActivation,
    "relu6": nn.ReLU6,
    "sigmoid": nn.Sigmoid,
    "silu": SiLUActivation,
    "swish": nn.SiLU,
    "tanh": nn.Tanh,
    "prelu": nn.PReLU,
    "xielu": XIELUActivation,
}
# Each `ACT2FN[name]` lookup builds a new module instance from the matching ACT2CLS entry.
ACT2FN = ClassInstantier(ACT2CLS)
  272. def get_activation(activation_string):
  273. if activation_string in ACT2FN:
  274. return ACT2FN[activation_string]
  275. else:
  276. raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
# For backwards compatibility with: from activations import gelu_python
# Each name below is bound to one activation instance created at import time through `get_activation`.
gelu_python = get_activation("gelu_python")
gelu_new = get_activation("gelu_new")
gelu = get_activation("gelu")
gelu_fast = get_activation("gelu_fast")
quick_gelu = get_activation("quick_gelu")
silu = get_activation("silu")
mish = get_activation("mish")
linear_act = get_activation("linear")