activation.py 60 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906
  1. # mypy: allow-untyped-defs
  2. import warnings
  3. import torch
  4. import torch.nn.functional as F
  5. from torch import Tensor
  6. from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
  7. from torch.nn.parameter import Parameter
  8. from .linear import NonDynamicallyQuantizableLinear
  9. from .module import Module
  10. __all__ = [
  11. "Threshold",
  12. "ReLU",
  13. "RReLU",
  14. "Hardtanh",
  15. "ReLU6",
  16. "Sigmoid",
  17. "Hardsigmoid",
  18. "Tanh",
  19. "SiLU",
  20. "Mish",
  21. "Hardswish",
  22. "ELU",
  23. "CELU",
  24. "SELU",
  25. "GLU",
  26. "GELU",
  27. "Hardshrink",
  28. "LeakyReLU",
  29. "LogSigmoid",
  30. "Softplus",
  31. "Softshrink",
  32. "MultiheadAttention",
  33. "PReLU",
  34. "Softsign",
  35. "Tanhshrink",
  36. "Softmin",
  37. "Softmax",
  38. "Softmax2d",
  39. "LogSoftmax",
  40. ]
  41. class Threshold(Module):
  42. r"""Thresholds each element of the input Tensor.
  43. Threshold is defined as:
  44. .. math::
  45. y =
  46. \begin{cases}
  47. x, &\text{ if } x > \text{threshold} \\
  48. \text{value}, &\text{ otherwise }
  49. \end{cases}
  50. Args:
  51. threshold: The value to threshold at
  52. value: The value to replace with
  53. inplace: can optionally do the operation in-place. Default: ``False``
  54. Shape:
  55. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  56. - Output: :math:`(*)`, same shape as the input.
  57. .. image:: ../scripts/activation_images/Threshold.png
  58. Examples::
  59. >>> m = nn.Threshold(0, 0.5)
  60. >>> input = torch.arange(-3, 3)
  61. >>> output = m(input)
  62. """
  63. __constants__ = ["threshold", "value", "inplace"]
  64. threshold: float
  65. value: float
  66. inplace: bool
  67. def __init__(self, threshold: float, value: float, inplace: bool = False) -> None:
  68. super().__init__()
  69. self.threshold = threshold
  70. self.value = value
  71. self.inplace = inplace
  72. # TODO: check in THNN (if inplace == True, then assert value <= threshold)
  73. def forward(self, input: Tensor) -> Tensor:
  74. """
  75. Runs the forward pass.
  76. """
  77. return F.threshold(input, self.threshold, self.value, self.inplace)
  78. def extra_repr(self) -> str:
  79. """
  80. Return the extra representation of the module.
  81. """
  82. inplace_str = ", inplace=True" if self.inplace else ""
  83. return f"threshold={self.threshold}, value={self.value}{inplace_str}"
  84. class ReLU(Module):
  85. r"""Applies the rectified linear unit function element-wise.
  86. :math:`\text{ReLU}(x) = (x)^+ = \max(0, x)`
  87. Args:
  88. inplace: can optionally do the operation in-place. Default: ``False``
  89. Shape:
  90. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  91. - Output: :math:`(*)`, same shape as the input.
  92. .. image:: ../scripts/activation_images/ReLU.png
  93. Examples::
  94. >>> m = nn.ReLU()
  95. >>> input = torch.randn(2)
  96. >>> output = m(input)
  97. An implementation of CReLU - https://arxiv.org/abs/1603.05201
  98. >>> m = nn.ReLU()
  99. >>> input = torch.randn(2).unsqueeze(0)
  100. >>> output = torch.cat((m(input), m(-input)))
  101. """
  102. __constants__ = ["inplace"]
  103. inplace: bool
  104. def __init__(self, inplace: bool = False) -> None:
  105. super().__init__()
  106. self.inplace = inplace
  107. def forward(self, input: Tensor) -> Tensor:
  108. """
  109. Runs the forward pass.
  110. """
  111. return F.relu(input, inplace=self.inplace)
  112. def extra_repr(self) -> str:
  113. """
  114. Return the extra representation of the module.
  115. """
  116. inplace_str = "inplace=True" if self.inplace else ""
  117. return inplace_str
  118. class RReLU(Module):
  119. r"""Applies the randomized leaky rectified linear unit function, element-wise.
  120. Method described in the paper:
  121. `Empirical Evaluation of Rectified Activations in Convolutional Network <https://arxiv.org/abs/1505.00853>`_.
  122. The function is defined as:
  123. .. math::
  124. \text{RReLU}(x) =
  125. \begin{cases}
  126. x & \text{if } x \geq 0 \\
  127. ax & \text{ otherwise }
  128. \end{cases}
  129. where :math:`a` is randomly sampled from uniform distribution
  130. :math:`\mathcal{U}(\text{lower}, \text{upper})` during training while during
  131. evaluation :math:`a` is fixed with :math:`a = \frac{\text{lower} + \text{upper}}{2}`.
  132. Args:
  133. lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`
  134. upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`
  135. inplace: can optionally do the operation in-place. Default: ``False``
  136. Shape:
  137. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  138. - Output: :math:`(*)`, same shape as the input.
  139. .. image:: ../scripts/activation_images/RReLU.png
  140. Examples::
  141. >>> m = nn.RReLU(0.1, 0.3)
  142. >>> input = torch.randn(2)
  143. >>> output = m(input)
  144. """
  145. __constants__ = ["lower", "upper", "inplace"]
  146. lower: float
  147. upper: float
  148. inplace: bool
  149. def __init__(
  150. self, lower: float = 1.0 / 8, upper: float = 1.0 / 3, inplace: bool = False
  151. ) -> None:
  152. super().__init__()
  153. self.lower = lower
  154. self.upper = upper
  155. self.inplace = inplace
  156. def forward(self, input: Tensor) -> Tensor:
  157. """
  158. Runs the forward pass.
  159. """
  160. return F.rrelu(input, self.lower, self.upper, self.training, self.inplace)
  161. def extra_repr(self) -> str:
  162. """
  163. Return the extra representation of the module.
  164. """
  165. inplace_str = ", inplace=True" if self.inplace else ""
  166. return f"lower={self.lower}, upper={self.upper}{inplace_str}"
  167. class Hardtanh(Module):
  168. r"""Applies the HardTanh function element-wise.
  169. HardTanh is defined as:
  170. .. math::
  171. \text{HardTanh}(x) = \begin{cases}
  172. \text{max\_val} & \text{ if } x > \text{ max\_val } \\
  173. \text{min\_val} & \text{ if } x < \text{ min\_val } \\
  174. x & \text{ otherwise } \\
  175. \end{cases}
  176. Args:
  177. min_val: minimum value of the linear region range. Default: -1
  178. max_val: maximum value of the linear region range. Default: 1
  179. inplace: can optionally do the operation in-place. Default: ``False``
  180. Keyword arguments :attr:`min_value` and :attr:`max_value`
  181. have been deprecated in favor of :attr:`min_val` and :attr:`max_val`.
  182. Shape:
  183. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  184. - Output: :math:`(*)`, same shape as the input.
  185. .. image:: ../scripts/activation_images/Hardtanh.png
  186. Examples::
  187. >>> m = nn.Hardtanh(-2, 2)
  188. >>> input = torch.randn(2)
  189. >>> output = m(input)
  190. """
  191. __constants__ = ["min_val", "max_val", "inplace"]
  192. min_val: float
  193. max_val: float
  194. inplace: bool
  195. def __init__(
  196. self,
  197. min_val: float = -1.0,
  198. max_val: float = 1.0,
  199. inplace: bool = False,
  200. min_value: float | None = None,
  201. max_value: float | None = None,
  202. ) -> None:
  203. super().__init__()
  204. if min_value is not None:
  205. warnings.warn(
  206. "keyword argument `min_value` is deprecated and rename to `min_val`",
  207. FutureWarning,
  208. stacklevel=2,
  209. )
  210. min_val = min_value
  211. if max_value is not None:
  212. warnings.warn(
  213. "keyword argument `max_value` is deprecated and rename to `max_val`",
  214. FutureWarning,
  215. stacklevel=2,
  216. )
  217. max_val = max_value
  218. self.min_val = min_val
  219. self.max_val = max_val
  220. self.inplace = inplace
  221. if self.max_val <= self.min_val:
  222. raise AssertionError(
  223. f"max_val ({self.max_val}) must be greater than min_val ({self.min_val})"
  224. )
  225. def forward(self, input: Tensor) -> Tensor:
  226. """
  227. Runs the forward pass.
  228. """
  229. return F.hardtanh(input, self.min_val, self.max_val, self.inplace)
  230. def extra_repr(self) -> str:
  231. """
  232. Return the extra representation of the module.
  233. """
  234. inplace_str = ", inplace=True" if self.inplace else ""
  235. return f"min_val={self.min_val}, max_val={self.max_val}{inplace_str}"
  236. class ReLU6(Hardtanh):
  237. r"""Applies the ReLU6 function element-wise.
  238. .. math::
  239. \text{ReLU6}(x) = \min(\max(0,x), 6)
  240. Args:
  241. inplace: can optionally do the operation in-place. Default: ``False``
  242. Shape:
  243. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  244. - Output: :math:`(*)`, same shape as the input.
  245. .. image:: ../scripts/activation_images/ReLU6.png
  246. Examples::
  247. >>> m = nn.ReLU6()
  248. >>> input = torch.randn(2)
  249. >>> output = m(input)
  250. """
  251. def __init__(self, inplace: bool = False) -> None:
  252. super().__init__(0.0, 6.0, inplace)
  253. def extra_repr(self) -> str:
  254. """
  255. Return the extra representation of the module.
  256. """
  257. inplace_str = "inplace=True" if self.inplace else ""
  258. return inplace_str
  259. class Sigmoid(Module):
  260. r"""Applies the Sigmoid function element-wise.
  261. .. math::
  262. \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
  263. Shape:
  264. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  265. - Output: :math:`(*)`, same shape as the input.
  266. .. image:: ../scripts/activation_images/Sigmoid.png
  267. Examples::
  268. >>> m = nn.Sigmoid()
  269. >>> input = torch.randn(2)
  270. >>> output = m(input)
  271. """
  272. def forward(self, input: Tensor) -> Tensor:
  273. """
  274. Runs the forward pass.
  275. """
  276. return torch.sigmoid(input)
  277. class Hardsigmoid(Module):
  278. r"""Applies the Hardsigmoid function element-wise.
  279. Hardsigmoid is defined as:
  280. .. math::
  281. \text{Hardsigmoid}(x) = \begin{cases}
  282. 0 & \text{if~} x \le -3, \\
  283. 1 & \text{if~} x \ge +3, \\
  284. x / 6 + 1 / 2 & \text{otherwise}
  285. \end{cases}
  286. Args:
  287. inplace: can optionally do the operation in-place. Default: ``False``
  288. Shape:
  289. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  290. - Output: :math:`(*)`, same shape as the input.
  291. .. image:: ../scripts/activation_images/Hardsigmoid.png
  292. Examples::
  293. >>> m = nn.Hardsigmoid()
  294. >>> input = torch.randn(2)
  295. >>> output = m(input)
  296. """
  297. __constants__ = ["inplace"]
  298. inplace: bool
  299. def __init__(self, inplace: bool = False) -> None:
  300. super().__init__()
  301. self.inplace = inplace
  302. def forward(self, input: Tensor) -> Tensor:
  303. """
  304. Runs the forward pass.
  305. """
  306. return F.hardsigmoid(input, self.inplace)
  307. class Tanh(Module):
  308. r"""Applies the Hyperbolic Tangent (Tanh) function element-wise.
  309. Tanh is defined as:
  310. .. math::
  311. \text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)} {\exp(x) + \exp(-x)}
  312. Shape:
  313. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  314. - Output: :math:`(*)`, same shape as the input.
  315. .. image:: ../scripts/activation_images/Tanh.png
  316. Examples::
  317. >>> m = nn.Tanh()
  318. >>> input = torch.randn(2)
  319. >>> output = m(input)
  320. """
  321. def forward(self, input: Tensor) -> Tensor:
  322. """
  323. Runs the forward pass.
  324. """
  325. return torch.tanh(input)
  326. class SiLU(Module):
  327. r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise.
  328. The SiLU function is also known as the swish function.
  329. .. math::
  330. \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.}
  331. .. note::
  332. See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_
  333. where the SiLU (Sigmoid Linear Unit) was originally coined, and see
  334. `Sigmoid-Weighted Linear Units for Neural Network Function Approximation
  335. in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish:
  336. a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_
  337. where the SiLU was experimented with later.
  338. Shape:
  339. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  340. - Output: :math:`(*)`, same shape as the input.
  341. .. image:: ../scripts/activation_images/SiLU.png
  342. Examples::
  343. >>> m = nn.SiLU()
  344. >>> input = torch.randn(2)
  345. >>> output = m(input)
  346. """
  347. __constants__ = ["inplace"]
  348. inplace: bool
  349. def __init__(self, inplace: bool = False) -> None:
  350. super().__init__()
  351. self.inplace = inplace
  352. def forward(self, input: Tensor) -> Tensor:
  353. """
  354. Runs the forward pass.
  355. """
  356. return F.silu(input, inplace=self.inplace)
  357. def extra_repr(self) -> str:
  358. """
  359. Return the extra representation of the module.
  360. """
  361. inplace_str = "inplace=True" if self.inplace else ""
  362. return inplace_str
  363. class Mish(Module):
  364. r"""Applies the Mish function, element-wise.
  365. Mish: A Self Regularized Non-Monotonic Neural Activation Function.
  366. .. math::
  367. \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
  368. .. note::
  369. See `Mish: A Self Regularized Non-Monotonic Neural Activation Function <https://arxiv.org/abs/1908.08681>`_
  370. Shape:
  371. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  372. - Output: :math:`(*)`, same shape as the input.
  373. .. image:: ../scripts/activation_images/Mish.png
  374. Examples::
  375. >>> m = nn.Mish()
  376. >>> input = torch.randn(2)
  377. >>> output = m(input)
  378. """
  379. __constants__ = ["inplace"]
  380. inplace: bool
  381. def __init__(self, inplace: bool = False) -> None:
  382. super().__init__()
  383. self.inplace = inplace
  384. def forward(self, input: Tensor) -> Tensor:
  385. """
  386. Runs the forward pass.
  387. """
  388. return F.mish(input, inplace=self.inplace)
  389. def extra_repr(self) -> str:
  390. """
  391. Return the extra representation of the module.
  392. """
  393. inplace_str = "inplace=True" if self.inplace else ""
  394. return inplace_str
  395. class Hardswish(Module):
  396. r"""Applies the Hardswish function, element-wise.
  397. Method described in the paper: `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`_.
  398. Hardswish is defined as:
  399. .. math::
  400. \text{Hardswish}(x) = \begin{cases}
  401. 0 & \text{if~} x \le -3, \\
  402. x & \text{if~} x \ge +3, \\
  403. x \cdot (x + 3) /6 & \text{otherwise}
  404. \end{cases}
  405. Args:
  406. inplace: can optionally do the operation in-place. Default: ``False``
  407. Shape:
  408. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  409. - Output: :math:`(*)`, same shape as the input.
  410. .. image:: ../scripts/activation_images/Hardswish.png
  411. Examples::
  412. >>> m = nn.Hardswish()
  413. >>> input = torch.randn(2)
  414. >>> output = m(input)
  415. """
  416. __constants__ = ["inplace"]
  417. inplace: bool
  418. def __init__(self, inplace: bool = False) -> None:
  419. super().__init__()
  420. self.inplace = inplace
  421. def forward(self, input: Tensor) -> Tensor:
  422. """
  423. Runs the forward pass.
  424. """
  425. return F.hardswish(input, self.inplace)
  426. class ELU(Module):
  427. r"""Applies the Exponential Linear Unit (ELU) function, element-wise.
  428. Method described in the paper: `Fast and Accurate Deep Network Learning by Exponential Linear
  429. Units (ELUs) <https://arxiv.org/abs/1511.07289>`__.
  430. ELU is defined as:
  431. .. math::
  432. \text{ELU}(x) = \begin{cases}
  433. x, & \text{ if } x > 0\\
  434. \alpha * (\exp(x) - 1), & \text{ if } x \leq 0
  435. \end{cases}
  436. Args:
  437. alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
  438. inplace: can optionally do the operation in-place. Default: ``False``
  439. Shape:
  440. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  441. - Output: :math:`(*)`, same shape as the input.
  442. .. image:: ../scripts/activation_images/ELU.png
  443. Examples::
  444. >>> m = nn.ELU()
  445. >>> input = torch.randn(2)
  446. >>> output = m(input)
  447. """
  448. __constants__ = ["alpha", "inplace"]
  449. alpha: float
  450. inplace: bool
  451. def __init__(self, alpha: float = 1.0, inplace: bool = False) -> None:
  452. super().__init__()
  453. self.alpha = alpha
  454. self.inplace = inplace
  455. def forward(self, input: Tensor) -> Tensor:
  456. """
  457. Runs the forward pass.
  458. """
  459. return F.elu(input, self.alpha, self.inplace)
  460. def extra_repr(self) -> str:
  461. """
  462. Return the extra representation of the module.
  463. """
  464. inplace_str = ", inplace=True" if self.inplace else ""
  465. return f"alpha={self.alpha}{inplace_str}"
  466. class CELU(Module):
  467. r"""Applies the CELU function element-wise.
  468. .. math::
  469. \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))
  470. More details can be found in the paper `Continuously Differentiable Exponential Linear Units`_ .
  471. Args:
  472. alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0
  473. inplace: can optionally do the operation in-place. Default: ``False``
  474. Shape:
  475. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  476. - Output: :math:`(*)`, same shape as the input.
  477. .. image:: ../scripts/activation_images/CELU.png
  478. Examples::
  479. >>> m = nn.CELU()
  480. >>> input = torch.randn(2)
  481. >>> output = m(input)
  482. .. _`Continuously Differentiable Exponential Linear Units`:
  483. https://arxiv.org/abs/1704.07483
  484. """
  485. __constants__ = ["alpha", "inplace"]
  486. alpha: float
  487. inplace: bool
  488. def __init__(self, alpha: float = 1.0, inplace: bool = False) -> None:
  489. super().__init__()
  490. self.alpha = alpha
  491. self.inplace = inplace
  492. def forward(self, input: Tensor) -> Tensor:
  493. """
  494. Runs the forward pass.
  495. """
  496. return F.celu(input, self.alpha, self.inplace)
  497. def extra_repr(self) -> str:
  498. """
  499. Return the extra representation of the module.
  500. """
  501. inplace_str = ", inplace=True" if self.inplace else ""
  502. return f"alpha={self.alpha}{inplace_str}"
  503. class SELU(Module):
  504. r"""Applies the SELU function element-wise.
  505. .. math::
  506. \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))
  507. with :math:`\alpha = 1.6732632423543772848170429916717` and
  508. :math:`\text{scale} = 1.0507009873554804934193349852946`.
  509. .. warning::
  510. When using ``kaiming_normal`` or ``kaiming_normal_`` for initialisation,
  511. ``nonlinearity='linear'`` should be used instead of ``nonlinearity='selu'``
  512. in order to get `Self-Normalizing Neural Networks`_.
  513. See :func:`torch.nn.init.calculate_gain` for more information.
  514. More details can be found in the paper `Self-Normalizing Neural Networks`_ .
  515. Args:
  516. inplace (bool, optional): can optionally do the operation in-place. Default: ``False``
  517. Shape:
  518. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  519. - Output: :math:`(*)`, same shape as the input.
  520. .. image:: ../scripts/activation_images/SELU.png
  521. Examples::
  522. >>> m = nn.SELU()
  523. >>> input = torch.randn(2)
  524. >>> output = m(input)
  525. .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
  526. """
  527. __constants__ = ["inplace"]
  528. inplace: bool
  529. def __init__(self, inplace: bool = False) -> None:
  530. super().__init__()
  531. self.inplace = inplace
  532. def forward(self, input: Tensor) -> Tensor:
  533. """
  534. Runs the forward pass.
  535. """
  536. return F.selu(input, self.inplace)
  537. def extra_repr(self) -> str:
  538. """
  539. Return the extra representation of the module.
  540. """
  541. inplace_str = "inplace=True" if self.inplace else ""
  542. return inplace_str
  543. class GLU(Module):
  544. r"""Applies the gated linear unit function.
  545. :math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half
  546. of the input matrices and :math:`b` is the second half.
  547. Args:
  548. dim (int): the dimension on which to split the input. Default: -1
  549. Shape:
  550. - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional
  551. dimensions
  552. - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`
  553. Examples::
  554. >>> m = nn.GLU()
  555. >>> input = torch.randn(4, 2)
  556. >>> output = m(input)
  557. """
  558. __constants__ = ["dim"]
  559. dim: int
  560. def __init__(self, dim: int = -1) -> None:
  561. super().__init__()
  562. self.dim = dim
  563. def forward(self, input: Tensor) -> Tensor:
  564. """
  565. Runs the forward pass.
  566. """
  567. return F.glu(input, self.dim)
  568. def extra_repr(self) -> str:
  569. """
  570. Return the extra representation of the module.
  571. """
  572. return f"dim={self.dim}"
  573. class GELU(Module):
  574. r"""Applies the Gaussian Error Linear Units function.
  575. .. math:: \text{GELU}(x) = x * \Phi(x)
  576. where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.
  577. When the approximate argument is 'tanh', Gelu is estimated with:
  578. .. math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2 / \pi} * (x + 0.044715 * x^3)))
  579. Args:
  580. approximate (str, optional): the gelu approximation algorithm to use:
  581. ``'none'`` | ``'tanh'``. Default: ``'none'``
  582. Shape:
  583. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  584. - Output: :math:`(*)`, same shape as the input.
  585. .. image:: ../scripts/activation_images/GELU.png
  586. Examples::
  587. >>> m = nn.GELU()
  588. >>> input = torch.randn(2)
  589. >>> output = m(input)
  590. """
  591. __constants__ = ["approximate"]
  592. approximate: str
  593. def __init__(self, approximate: str = "none") -> None:
  594. super().__init__()
  595. self.approximate = approximate
  596. def forward(self, input: Tensor) -> Tensor:
  597. """
  598. Runs the forward pass.
  599. """
  600. return F.gelu(input, approximate=self.approximate)
  601. def extra_repr(self) -> str:
  602. """
  603. Return the extra representation of the module.
  604. """
  605. return f"approximate={repr(self.approximate)}"
  606. class Hardshrink(Module):
  607. r"""Applies the Hard Shrinkage (Hardshrink) function element-wise.
  608. Hardshrink is defined as:
  609. .. math::
  610. \text{HardShrink}(x) =
  611. \begin{cases}
  612. x, & \text{ if } x > \lambda \\
  613. x, & \text{ if } x < -\lambda \\
  614. 0, & \text{ otherwise }
  615. \end{cases}
  616. Args:
  617. lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5
  618. Shape:
  619. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  620. - Output: :math:`(*)`, same shape as the input.
  621. .. image:: ../scripts/activation_images/Hardshrink.png
  622. Examples::
  623. >>> m = nn.Hardshrink()
  624. >>> input = torch.randn(2)
  625. >>> output = m(input)
  626. """
  627. __constants__ = ["lambd"]
  628. lambd: float
  629. def __init__(self, lambd: float = 0.5) -> None:
  630. super().__init__()
  631. self.lambd = lambd
  632. def forward(self, input: Tensor) -> Tensor:
  633. """
  634. Run forward pass.
  635. """
  636. return F.hardshrink(input, self.lambd)
  637. def extra_repr(self) -> str:
  638. """
  639. Return the extra representation of the module.
  640. """
  641. return f"{self.lambd}"
  642. class LeakyReLU(Module):
  643. r"""Applies the LeakyReLU function element-wise.
  644. .. math::
  645. \text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)
  646. or
  647. .. math::
  648. \text{LeakyReLU}(x) =
  649. \begin{cases}
  650. x, & \text{ if } x \geq 0 \\
  651. \text{negative\_slope} \times x, & \text{ otherwise }
  652. \end{cases}
  653. Args:
  654. negative_slope: Controls the angle of the negative slope (which is used for
  655. negative input values). Default: 1e-2
  656. inplace: can optionally do the operation in-place. Default: ``False``
  657. Shape:
  658. - Input: :math:`(*)` where `*` means, any number of additional
  659. dimensions
  660. - Output: :math:`(*)`, same shape as the input
  661. .. image:: ../scripts/activation_images/LeakyReLU.png
  662. Examples::
  663. >>> m = nn.LeakyReLU(0.1)
  664. >>> input = torch.randn(2)
  665. >>> output = m(input)
  666. """
  667. __constants__ = ["inplace", "negative_slope"]
  668. inplace: bool
  669. negative_slope: float
  670. def __init__(self, negative_slope: float = 1e-2, inplace: bool = False) -> None:
  671. super().__init__()
  672. self.negative_slope = negative_slope
  673. self.inplace = inplace
  674. def forward(self, input: Tensor) -> Tensor:
  675. """
  676. Run forward pass.
  677. """
  678. return F.leaky_relu(input, self.negative_slope, self.inplace)
  679. def extra_repr(self) -> str:
  680. """
  681. Return the extra representation of the module.
  682. """
  683. inplace_str = ", inplace=True" if self.inplace else ""
  684. return f"negative_slope={self.negative_slope}{inplace_str}"
  685. class LogSigmoid(Module):
  686. r"""Applies the Logsigmoid function element-wise.
  687. .. math::
  688. \text{LogSigmoid}(x) = \log\left(\frac{ 1 }{ 1 + \exp(-x)}\right)
  689. Shape:
  690. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  691. - Output: :math:`(*)`, same shape as the input.
  692. .. image:: ../scripts/activation_images/LogSigmoid.png
  693. Examples::
  694. >>> m = nn.LogSigmoid()
  695. >>> input = torch.randn(2)
  696. >>> output = m(input)
  697. """
  698. def forward(self, input: Tensor) -> Tensor:
  699. """
  700. Run forward pass.
  701. """
  702. return F.logsigmoid(input)
  703. class Softplus(Module):
  704. r"""Applies the Softplus function element-wise.
  705. .. math::
  706. \text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))
  707. SoftPlus is a smooth approximation to the ReLU function and can be used
  708. to constrain the output of a machine to always be positive.
  709. For numerical stability the implementation reverts to the linear function
  710. when :math:`input \times \beta > threshold`.
  711. Args:
  712. beta: the :math:`\beta` value for the Softplus formulation. Default: 1
  713. threshold: values above this revert to a linear function. Default: 20
  714. Shape:
  715. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  716. - Output: :math:`(*)`, same shape as the input.
  717. .. image:: ../scripts/activation_images/Softplus.png
  718. Examples::
  719. >>> m = nn.Softplus()
  720. >>> input = torch.randn(2)
  721. >>> output = m(input)
  722. """
  723. __constants__ = ["beta", "threshold"]
  724. beta: float
  725. threshold: float
  726. def __init__(self, beta: float = 1.0, threshold: float = 20.0) -> None:
  727. super().__init__()
  728. self.beta = beta
  729. self.threshold = threshold
  730. def forward(self, input: Tensor) -> Tensor:
  731. """
  732. Run forward pass.
  733. """
  734. return F.softplus(input, self.beta, self.threshold)
  735. def extra_repr(self) -> str:
  736. """
  737. Return the extra representation of the module.
  738. """
  739. return f"beta={self.beta}, threshold={self.threshold}"
  740. class Softshrink(Module):
  741. r"""Applies the soft shrinkage function element-wise.
  742. .. math::
  743. \text{SoftShrinkage}(x) =
  744. \begin{cases}
  745. x - \lambda, & \text{ if } x > \lambda \\
  746. x + \lambda, & \text{ if } x < -\lambda \\
  747. 0, & \text{ otherwise }
  748. \end{cases}
  749. Args:
  750. lambd: the :math:`\lambda` (must be no less than zero) value for the Softshrink formulation. Default: 0.5
  751. Shape:
  752. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  753. - Output: :math:`(*)`, same shape as the input.
  754. .. image:: ../scripts/activation_images/Softshrink.png
  755. Examples::
  756. >>> m = nn.Softshrink()
  757. >>> input = torch.randn(2)
  758. >>> output = m(input)
  759. """
  760. __constants__ = ["lambd"]
  761. lambd: float
  762. def __init__(self, lambd: float = 0.5) -> None:
  763. super().__init__()
  764. self.lambd = lambd
  765. def forward(self, input: Tensor) -> Tensor:
  766. """
  767. Run forward pass.
  768. """
  769. return F.softshrink(input, self.lambd)
  770. def extra_repr(self) -> str:
  771. """
  772. Return the extra representation of the module.
  773. """
  774. return str(self.lambd)
  775. def _check_arg_device(x: torch.Tensor | None) -> bool:
  776. if x is not None:
  777. return x.device.type in [
  778. "cpu",
  779. "cuda",
  780. torch.utils.backend_registration._privateuse1_backend_name,
  781. ]
  782. return True
  783. def _arg_requires_grad(x: torch.Tensor | None) -> bool:
  784. if x is not None:
  785. return x.requires_grad
  786. return False
def _is_make_fx_tracing() -> bool:
    """Report whether a tracing/export context is active that should bypass the MHA fast path.

    Returns ``True`` when an instance of exactly ``ProxyTorchDispatchMode``
    (from ``torch.fx.experimental.proxy_tensor``) is on the current Python
    dispatch-mode stack, or when ``torch.compiler.is_exporting()`` is true.
    When running under TorchScript, ``False`` is returned without inspecting
    either.
    """
    # NOTE(review): the eager-only inspection is kept inside this
    # ``if not torch.jit.is_scripting()`` branch, presumably so TorchScript
    # never has to compile the dispatch-mode / torch.compiler calls — keep
    # this branch shape when editing.
    if not torch.jit.is_scripting():
        torch_dispatch_mode_stack = (
            torch.utils._python_dispatch._get_current_dispatch_mode_stack()
        )
        # this can be triggered when dynamo inlining the module too.
        return (
            # Exact type match: subclasses of ProxyTorchDispatchMode do not count.
            any(
                type(x) is torch.fx.experimental.proxy_tensor.ProxyTorchDispatchMode
                for x in torch_dispatch_mode_stack
            )
            or torch.compiler.is_exporting()
        )
    else:
        return False
  802. class MultiheadAttention(Module):
  803. r"""Allows the model to jointly attend to information from different representation subspaces.
  804. This MultiheadAttention layer implements the original architecture described
  805. in the `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_ paper. The
  806. intent of this layer is as a reference implementation for foundational understanding
  807. and thus it contains only limited features relative to newer architectures.
  808. Given the fast pace of innovation in transformer-like architectures, we recommend
  809. exploring this `tutorial <https://pytorch.org/tutorials/intermediate/transformer_building_blocks.html>`_
  810. to build efficient layers from building blocks in core or using higher
  811. level libraries from the `PyTorch Ecosystem <https://landscape.pytorch.org/>`_.
  812. Multi-Head Attention is defined as:
  813. .. math::
  814. \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1,\dots,\text{head}_h)W^O
  815. where :math:`\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
  816. ``nn.MultiheadAttention`` will use the optimized implementations of
  817. ``scaled_dot_product_attention()`` when possible.
  818. In addition to support for the new ``scaled_dot_product_attention()``
  819. function, for speeding up Inference, MHA will use
  820. fastpath inference with support for Nested Tensors, iff:
  821. - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor).
  822. - inputs are batched (3D) with ``batch_first==True``
  823. - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
  824. - training is disabled (using ``.eval()``)
  825. - ``add_bias_kv`` is ``False``
  826. - ``add_zero_attn`` is ``False``
  827. - ``kdim`` and ``vdim`` are equal to ``embed_dim``
  828. - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
  829. nor ``attn_mask`` is passed
  830. - autocast is disabled
  831. If the optimized inference fastpath implementation is in use, a
  832. `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
  833. ``query``/``key``/``value`` to represent padding more efficiently than using a
  834. padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
  835. will be returned, and an additional speedup proportional to the fraction of the input
  836. that is padding can be expected.
  837. Args:
  838. embed_dim: Total dimension of the model.
  839. num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
  840. across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
  841. dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
  842. bias: If specified, adds bias to input / output projection layers. Default: ``True``.
  843. add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
  844. add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
  845. Default: ``False``.
  846. kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
  847. vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
  848. batch_first: If ``True``, then the input and output tensors are provided
  849. as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
  850. Examples::
  851. >>> # xdoctest: +SKIP
  852. >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
  853. >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
  854. .. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`:
  855. https://arxiv.org/abs/2205.14135
  856. """
  857. __constants__ = ["batch_first"]
  858. bias_k: torch.Tensor | None
  859. bias_v: torch.Tensor | None
  860. def __init__(
  861. self,
  862. embed_dim,
  863. num_heads,
  864. dropout=0.0,
  865. bias=True,
  866. add_bias_kv=False,
  867. add_zero_attn=False,
  868. kdim=None,
  869. vdim=None,
  870. batch_first=False,
  871. device=None,
  872. dtype=None,
  873. ) -> None:
  874. if embed_dim <= 0 or num_heads <= 0:
  875. raise ValueError(
  876. f"embed_dim and num_heads must be greater than 0,"
  877. f" got embed_dim={embed_dim} and num_heads={num_heads} instead"
  878. )
  879. factory_kwargs = {"device": device, "dtype": dtype}
  880. super().__init__()
  881. self.embed_dim = embed_dim
  882. self.kdim = kdim if kdim is not None else embed_dim
  883. self.vdim = vdim if vdim is not None else embed_dim
  884. self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
  885. self.num_heads = num_heads
  886. self.dropout = dropout
  887. self.batch_first = batch_first
  888. self.head_dim = embed_dim // num_heads
  889. if self.head_dim * num_heads != self.embed_dim:
  890. raise AssertionError("embed_dim must be divisible by num_heads")
  891. if not self._qkv_same_embed_dim:
  892. self.q_proj_weight = Parameter(
  893. torch.empty((embed_dim, embed_dim), **factory_kwargs)
  894. )
  895. self.k_proj_weight = Parameter(
  896. torch.empty((embed_dim, self.kdim), **factory_kwargs)
  897. )
  898. self.v_proj_weight = Parameter(
  899. torch.empty((embed_dim, self.vdim), **factory_kwargs)
  900. )
  901. self.register_parameter("in_proj_weight", None)
  902. else:
  903. self.in_proj_weight = Parameter(
  904. torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
  905. )
  906. self.register_parameter("q_proj_weight", None)
  907. self.register_parameter("k_proj_weight", None)
  908. self.register_parameter("v_proj_weight", None)
  909. if bias:
  910. self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
  911. else:
  912. self.register_parameter("in_proj_bias", None)
  913. self.out_proj = NonDynamicallyQuantizableLinear(
  914. embed_dim, embed_dim, bias=bias, **factory_kwargs
  915. )
  916. if add_bias_kv:
  917. self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
  918. self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
  919. else:
  920. self.bias_k = self.bias_v = None
  921. self.add_zero_attn = add_zero_attn
  922. self._reset_parameters()
  923. def _reset_parameters(self) -> None:
  924. if self._qkv_same_embed_dim:
  925. xavier_uniform_(self.in_proj_weight)
  926. else:
  927. xavier_uniform_(self.q_proj_weight)
  928. xavier_uniform_(self.k_proj_weight)
  929. xavier_uniform_(self.v_proj_weight)
  930. if self.in_proj_bias is not None:
  931. constant_(self.in_proj_bias, 0.0)
  932. constant_(self.out_proj.bias, 0.0)
  933. if self.bias_k is not None:
  934. xavier_normal_(self.bias_k)
  935. if self.bias_v is not None:
  936. xavier_normal_(self.bias_v)
  937. def __setstate__(self, state):
  938. # Support loading old MultiheadAttention checkpoints generated by v1.1.0
  939. if "_qkv_same_embed_dim" not in state:
  940. state["_qkv_same_embed_dim"] = True
  941. super().__setstate__(state)
  942. def forward(
  943. self,
  944. query: Tensor,
  945. key: Tensor,
  946. value: Tensor,
  947. key_padding_mask: Tensor | None = None,
  948. need_weights: bool = True,
  949. attn_mask: Tensor | None = None,
  950. average_attn_weights: bool = True,
  951. is_causal: bool = False,
  952. ) -> tuple[Tensor, Tensor | None]:
  953. r"""Compute attention outputs using query, key, and value embeddings.
  954. Supports optional parameters for padding, masks and attention weights.
  955. Args:
  956. query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
  957. or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
  958. :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
  959. Queries are compared against key-value pairs to produce the output.
  960. See "Attention Is All You Need" for more details.
  961. key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
  962. or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
  963. :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
  964. See "Attention Is All You Need" for more details.
  965. value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
  966. ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
  967. sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
  968. See "Attention Is All You Need" for more details.
  969. key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
  970. to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
  971. Binary and float masks are supported.
  972. For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
  973. the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
  974. need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
  975. Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention``
  976. and achieve the best performance for MHA.
  977. Default: ``True``.
  978. attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
  979. :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
  980. :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
  981. broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
  982. Binary and float masks are supported. For a binary mask, a ``True`` value indicates that the
  983. corresponding position is not allowed to attend. For a float mask, the mask values will be added to
  984. the attention weight.
  985. If both attn_mask and key_padding_mask are supplied, their types should match.
  986. average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
  987. heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
  988. effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
  989. is_causal: If specified, applies a causal mask as attention mask.
  990. Default: ``False``.
  991. Warning:
  992. ``is_causal`` provides a hint that ``attn_mask`` is the
  993. causal mask. Providing incorrect hints can result in
  994. incorrect execution, including forward and backward
  995. compatibility.
  996. Outputs:
  997. - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
  998. :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
  999. where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
  1000. embedding dimension ``embed_dim``.
  1001. - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
  1002. returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
  1003. :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
  1004. :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
  1005. head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
  1006. .. note::
  1007. `batch_first` argument is ignored for unbatched inputs.
  1008. """ # noqa: B950
  1009. why_not_fast_path = ""
  1010. if (
  1011. (attn_mask is not None and torch.is_floating_point(attn_mask))
  1012. or (key_padding_mask is not None)
  1013. and torch.is_floating_point(key_padding_mask)
  1014. ):
  1015. why_not_fast_path = "floating-point masks are not supported for fast path."
  1016. is_batched = query.dim() == 3
  1017. key_padding_mask = F._canonical_mask(
  1018. mask=key_padding_mask,
  1019. mask_name="key_padding_mask",
  1020. other_type=F._none_or_dtype(attn_mask),
  1021. other_name="attn_mask",
  1022. target_type=query.dtype,
  1023. )
  1024. attn_mask = F._canonical_mask(
  1025. mask=attn_mask,
  1026. mask_name="attn_mask",
  1027. other_type=None,
  1028. other_name="",
  1029. target_type=query.dtype,
  1030. check_other=False,
  1031. )
  1032. is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()
  1033. if not is_fastpath_enabled:
  1034. why_not_fast_path = "torch.backends.mha.get_fastpath_enabled() was not True"
  1035. elif not is_batched:
  1036. why_not_fast_path = (
  1037. f"input not batched; expected query.dim() of 3 but got {query.dim()}"
  1038. )
  1039. elif query is not key or key is not value:
  1040. # When lifting this restriction, don't forget to either
  1041. # enforce that the dtypes all match or test cases where
  1042. # they don't!
  1043. why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
  1044. elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
  1045. why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
  1046. elif self.in_proj_weight is None:
  1047. why_not_fast_path = "in_proj_weight was None"
  1048. elif query.dtype != self.in_proj_weight.dtype:
  1049. # this case will fail anyway, but at least they'll get a useful error message.
  1050. why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
  1051. elif self.training:
  1052. why_not_fast_path = "training is enabled"
  1053. elif (self.num_heads % 2) != 0:
  1054. why_not_fast_path = "self.num_heads is not even"
  1055. elif not self.batch_first:
  1056. why_not_fast_path = "batch_first was not True"
  1057. elif self.bias_k is not None:
  1058. why_not_fast_path = "self.bias_k was not None"
  1059. elif self.bias_v is not None:
  1060. why_not_fast_path = "self.bias_v was not None"
  1061. elif self.add_zero_attn:
  1062. why_not_fast_path = "add_zero_attn was enabled"
  1063. elif not self._qkv_same_embed_dim:
  1064. why_not_fast_path = "_qkv_same_embed_dim was not True"
  1065. elif query.is_nested and (
  1066. key_padding_mask is not None or attn_mask is not None
  1067. ):
  1068. why_not_fast_path = (
  1069. "supplying both src_key_padding_mask and src_mask at the same time \
  1070. is not supported with NestedTensor input"
  1071. )
  1072. elif torch.is_autocast_enabled():
  1073. why_not_fast_path = "autocast is enabled"
  1074. if not why_not_fast_path:
  1075. tensor_args = (
  1076. query,
  1077. key,
  1078. value,
  1079. self.in_proj_weight,
  1080. self.in_proj_bias,
  1081. self.out_proj.weight,
  1082. self.out_proj.bias,
  1083. )
  1084. # We have to use list comprehensions below because TorchScript does not support
  1085. # generator expressions.
  1086. if torch.overrides.has_torch_function(tensor_args):
  1087. why_not_fast_path = "some Tensor argument has_torch_function"
  1088. elif _is_make_fx_tracing():
  1089. why_not_fast_path = "we are running make_fx tracing"
  1090. elif not all(_check_arg_device(x) for x in tensor_args):
  1091. why_not_fast_path = (
  1092. "some Tensor argument's device is neither one of "
  1093. f"cpu, cuda or {torch.utils.backend_registration._privateuse1_backend_name}"
  1094. )
  1095. elif torch.is_grad_enabled() and any(
  1096. _arg_requires_grad(x) for x in tensor_args
  1097. ):
  1098. why_not_fast_path = (
  1099. "grad is enabled and at least one of query or the "
  1100. "input/output projection weights or biases requires_grad"
  1101. )
  1102. if not why_not_fast_path:
  1103. merged_mask, mask_type = self.merge_masks(
  1104. attn_mask, key_padding_mask, query
  1105. )
  1106. if self.in_proj_bias is not None and self.in_proj_weight is not None:
  1107. return torch._native_multi_head_attention(
  1108. query,
  1109. key,
  1110. value,
  1111. self.embed_dim,
  1112. self.num_heads,
  1113. self.in_proj_weight,
  1114. self.in_proj_bias,
  1115. self.out_proj.weight,
  1116. self.out_proj.bias,
  1117. merged_mask,
  1118. need_weights,
  1119. average_attn_weights,
  1120. mask_type,
  1121. )
  1122. any_nested = query.is_nested or key.is_nested or value.is_nested
  1123. if any_nested:
  1124. raise AssertionError(
  1125. "MultiheadAttention does not support NestedTensor outside of its fast path. "
  1126. + f"The fast path was not hit because {why_not_fast_path}"
  1127. )
  1128. if self.batch_first and is_batched:
  1129. # make sure that the transpose op does not affect the "is" property
  1130. if key is value:
  1131. if query is key:
  1132. query = key = value = query.transpose(1, 0)
  1133. else:
  1134. query, key = (x.transpose(1, 0) for x in (query, key))
  1135. value = key
  1136. else:
  1137. query, key, value = (x.transpose(1, 0) for x in (query, key, value))
  1138. if not self._qkv_same_embed_dim:
  1139. attn_output, attn_output_weights = F.multi_head_attention_forward(
  1140. query,
  1141. key,
  1142. value,
  1143. self.embed_dim,
  1144. self.num_heads,
  1145. self.in_proj_weight,
  1146. self.in_proj_bias,
  1147. self.bias_k,
  1148. self.bias_v,
  1149. self.add_zero_attn,
  1150. self.dropout,
  1151. self.out_proj.weight,
  1152. self.out_proj.bias,
  1153. training=self.training,
  1154. key_padding_mask=key_padding_mask,
  1155. need_weights=need_weights,
  1156. attn_mask=attn_mask,
  1157. use_separate_proj_weight=True,
  1158. q_proj_weight=self.q_proj_weight,
  1159. k_proj_weight=self.k_proj_weight,
  1160. v_proj_weight=self.v_proj_weight,
  1161. average_attn_weights=average_attn_weights,
  1162. is_causal=is_causal,
  1163. )
  1164. else:
  1165. attn_output, attn_output_weights = F.multi_head_attention_forward(
  1166. query,
  1167. key,
  1168. value,
  1169. self.embed_dim,
  1170. self.num_heads,
  1171. self.in_proj_weight,
  1172. self.in_proj_bias,
  1173. self.bias_k,
  1174. self.bias_v,
  1175. self.add_zero_attn,
  1176. self.dropout,
  1177. self.out_proj.weight,
  1178. self.out_proj.bias,
  1179. training=self.training,
  1180. key_padding_mask=key_padding_mask,
  1181. need_weights=need_weights,
  1182. attn_mask=attn_mask,
  1183. average_attn_weights=average_attn_weights,
  1184. is_causal=is_causal,
  1185. )
  1186. if self.batch_first and is_batched:
  1187. return attn_output.transpose(1, 0), attn_output_weights
  1188. else:
  1189. return attn_output, attn_output_weights
  1190. def merge_masks(
  1191. self,
  1192. attn_mask: Tensor | None,
  1193. key_padding_mask: Tensor | None,
  1194. query: Tensor,
  1195. ) -> tuple[Tensor | None, int | None]:
  1196. r"""Determine mask type and combine masks if necessary.
  1197. If only one mask is provided, that mask
  1198. and the corresponding mask type will be returned. If both masks are provided, they will be both
  1199. expanded to shape ``(batch_size, num_heads, seq_len, seq_len)``, combined with logical ``or``
  1200. and mask type 2 will be returned
  1201. Args:
  1202. attn_mask: attention mask of shape ``(seq_len, seq_len)``, mask type 0
  1203. key_padding_mask: padding mask of shape ``(batch_size, seq_len)``, mask type 1
  1204. query: query embeddings of shape ``(batch_size, seq_len, embed_dim)``
  1205. Returns:
  1206. merged_mask: merged mask
  1207. mask_type: merged mask type (0, 1, or 2)
  1208. """
  1209. mask_type: int | None = None
  1210. merged_mask: Tensor | None = None
  1211. if key_padding_mask is not None:
  1212. mask_type = 1
  1213. merged_mask = key_padding_mask
  1214. if attn_mask is not None:
  1215. # In this branch query can't be a nested tensor, so it has a shape
  1216. batch_size, seq_len, _ = query.shape
  1217. mask_type = 2
  1218. # Always expands attn_mask to 4D
  1219. if attn_mask.dim() == 3:
  1220. attn_mask_expanded = attn_mask.view(batch_size, -1, seq_len, seq_len)
  1221. else: # attn_mask.dim() == 2:
  1222. attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand(
  1223. batch_size, self.num_heads, -1, -1
  1224. )
  1225. merged_mask = attn_mask_expanded
  1226. if key_padding_mask is not None:
  1227. key_padding_mask_expanded = key_padding_mask.view(
  1228. batch_size, 1, 1, seq_len
  1229. ).expand(-1, self.num_heads, -1, -1)
  1230. merged_mask = attn_mask_expanded + key_padding_mask_expanded
  1231. # no attn_mask and no key_padding_mask, returns None, None
  1232. return merged_mask, mask_type
  1233. class PReLU(Module):
  1234. r"""Applies the element-wise PReLU function.
  1235. .. math::
  1236. \text{PReLU}(x) = \max(0,x) + a * \min(0,x)
  1237. or
  1238. .. math::
  1239. \text{PReLU}(x) =
  1240. \begin{cases}
  1241. x, & \text{ if } x \ge 0 \\
  1242. ax, & \text{ otherwise }
  1243. \end{cases}
  1244. Here :math:`a` is a learnable parameter. When called without arguments, `nn.PReLU()` uses a single
  1245. parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`,
  1246. a separate :math:`a` is used for each input channel.
  1247. .. note::
  1248. weight decay should not be used when learning :math:`a` for good performance.
  1249. .. note::
  1250. Channel dim is the 2nd dim of input. When input has dims < 2, then there is
  1251. no channel dim and the number of channels = 1.
  1252. Args:
  1253. num_parameters (int): number of :math:`a` to learn.
  1254. Although it takes an int as input, there is only two values are legitimate:
  1255. 1, or the number of channels at input. Default: 1
  1256. init (float): the initial value of :math:`a`. Default: 0.25
  1257. Shape:
  1258. - Input: :math:`( *)` where `*` means, any number of additional
  1259. dimensions.
  1260. - Output: :math:`(*)`, same shape as the input.
  1261. Attributes:
  1262. weight (Tensor): the learnable weights of shape (:attr:`num_parameters`).
  1263. .. image:: ../scripts/activation_images/PReLU.png
  1264. Examples::
  1265. >>> m = nn.PReLU()
  1266. >>> input = torch.randn(2)
  1267. >>> output = m(input)
  1268. """
  1269. __constants__ = ["num_parameters"]
  1270. num_parameters: int
  1271. def __init__(
  1272. self, num_parameters: int = 1, init: float = 0.25, device=None, dtype=None
  1273. ) -> None:
  1274. factory_kwargs = {"device": device, "dtype": dtype}
  1275. self.num_parameters = num_parameters
  1276. super().__init__()
  1277. self.init = init
  1278. self.weight = Parameter(torch.empty(num_parameters, **factory_kwargs))
  1279. self.reset_parameters()
  1280. def reset_parameters(self) -> None:
  1281. """
  1282. Resets parameters based on their initialization used in ``__init__``.
  1283. """
  1284. torch.nn.init.constant_(self.weight, self.init)
  1285. def forward(self, input: Tensor) -> Tensor:
  1286. """
  1287. Runs the forward pass.
  1288. """
  1289. return F.prelu(input, self.weight)
  1290. def extra_repr(self) -> str:
  1291. """
  1292. Return the extra representation of the module.
  1293. """
  1294. return f"num_parameters={self.num_parameters}"
  1295. class Softsign(Module):
  1296. r"""Applies the element-wise Softsign function.
  1297. .. math::
  1298. \text{SoftSign}(x) = \frac{x}{ 1 + |x|}
  1299. Shape:
  1300. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  1301. - Output: :math:`(*)`, same shape as the input.
  1302. .. image:: ../scripts/activation_images/Softsign.png
  1303. Examples::
  1304. >>> m = nn.Softsign()
  1305. >>> input = torch.randn(2)
  1306. >>> output = m(input)
  1307. """
  1308. def forward(self, input: Tensor) -> Tensor:
  1309. """
  1310. Runs the forward pass.
  1311. """
  1312. return F.softsign(input)
  1313. class Tanhshrink(Module):
  1314. r"""Applies the element-wise Tanhshrink function.
  1315. .. math::
  1316. \text{Tanhshrink}(x) = x - \tanh(x)
  1317. Shape:
  1318. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  1319. - Output: :math:`(*)`, same shape as the input.
  1320. .. image:: ../scripts/activation_images/Tanhshrink.png
  1321. Examples::
  1322. >>> m = nn.Tanhshrink()
  1323. >>> input = torch.randn(2)
  1324. >>> output = m(input)
  1325. """
  1326. def forward(self, input: Tensor) -> Tensor:
  1327. """
  1328. Runs the forward pass.
  1329. """
  1330. return F.tanhshrink(input)
  1331. class Softmin(Module):
  1332. r"""Applies the Softmin function to an n-dimensional input Tensor.
  1333. Rescales them so that the elements of the n-dimensional output Tensor
  1334. lie in the range `[0, 1]` and sum to 1.
  1335. Softmin is defined as:
  1336. .. math::
  1337. \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}
  1338. Shape:
  1339. - Input: :math:`(*)` where `*` means, any number of additional
  1340. dimensions
  1341. - Output: :math:`(*)`, same shape as the input
  1342. Args:
  1343. dim (int): A dimension along which Softmin will be computed (so every slice
  1344. along dim will sum to 1).
  1345. Returns:
  1346. a Tensor of the same dimension and shape as the input, with
  1347. values in the range [0, 1]
  1348. Examples::
  1349. >>> m = nn.Softmin(dim=1)
  1350. >>> input = torch.randn(2, 3)
  1351. >>> output = m(input)
  1352. """
  1353. __constants__ = ["dim"]
  1354. dim: int | None
  1355. def __init__(self, dim: int | None = None) -> None:
  1356. super().__init__()
  1357. self.dim = dim
  1358. def __setstate__(self, state):
  1359. super().__setstate__(state)
  1360. if not hasattr(self, "dim"):
  1361. self.dim = None
  1362. def forward(self, input: Tensor) -> Tensor:
  1363. """
  1364. Runs the forward pass.
  1365. """
  1366. return F.softmin(input, self.dim, _stacklevel=5)
  1367. def extra_repr(self) -> str:
  1368. """
  1369. Return the extra representation of the module.
  1370. """
  1371. return f"dim={self.dim}"
  1372. class Softmax(Module):
  1373. r"""Applies the Softmax function to an n-dimensional input Tensor.
  1374. Rescales them so that the elements of the n-dimensional output Tensor
  1375. lie in the range [0,1] and sum to 1.
  1376. Softmax is defined as:
  1377. .. math::
  1378. \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
  1379. When the input Tensor is a sparse tensor then the unspecified
  1380. values are treated as ``-inf``.
  1381. Shape:
  1382. - Input: :math:`(*)` where `*` means, any number of additional
  1383. dimensions
  1384. - Output: :math:`(*)`, same shape as the input
  1385. Returns:
  1386. a Tensor of the same dimension and shape as the input with
  1387. values in the range [0, 1]
  1388. Args:
  1389. dim (int): A dimension along which Softmax will be computed (so every slice
  1390. along dim will sum to 1).
  1391. .. note::
  1392. This module doesn't work directly with NLLLoss,
  1393. which expects the Log to be computed between the Softmax and itself.
  1394. Use `LogSoftmax` instead (it's faster and has better numerical properties).
  1395. Examples::
  1396. >>> m = nn.Softmax(dim=1)
  1397. >>> input = torch.randn(2, 3)
  1398. >>> output = m(input)
  1399. """
  1400. __constants__ = ["dim"]
  1401. dim: int | None
  1402. def __init__(self, dim: int | None = None) -> None:
  1403. super().__init__()
  1404. self.dim = dim
  1405. def __setstate__(self, state):
  1406. super().__setstate__(state)
  1407. if not hasattr(self, "dim"):
  1408. self.dim = None
  1409. def forward(self, input: Tensor) -> Tensor:
  1410. """
  1411. Runs the forward pass.
  1412. """
  1413. return F.softmax(input, self.dim, _stacklevel=5)
  1414. def extra_repr(self) -> str:
  1415. """
  1416. Return the extra representation of the module.
  1417. """
  1418. return f"dim={self.dim}"
  1419. class Softmax2d(Module):
  1420. r"""Applies SoftMax over features to each spatial location.
  1421. When given an image of ``Channels x Height x Width``, it will
  1422. apply `Softmax` to each location :math:`(Channels, h_i, w_j)`
  1423. Shape:
  1424. - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`.
  1425. - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input)
  1426. Returns:
  1427. a Tensor of the same dimension and shape as the input with
  1428. values in the range [0, 1]
  1429. Examples::
  1430. >>> m = nn.Softmax2d()
  1431. >>> # you softmax over the 2nd dimension
  1432. >>> input = torch.randn(2, 3, 12, 13)
  1433. >>> output = m(input)
  1434. """
  1435. def forward(self, input: Tensor) -> Tensor:
  1436. """
  1437. Runs the forward pass.
  1438. """
  1439. if input.dim() not in (3, 4):
  1440. raise ValueError(
  1441. f"Softmax2d: expected input to be 3D or 4D, got {input.dim()}D instead"
  1442. )
  1443. return F.softmax(input, -3, _stacklevel=5)
  1444. class LogSoftmax(Module):
  1445. r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional input Tensor.
  1446. The LogSoftmax formulation can be simplified as:
  1447. .. math::
  1448. \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)
  1449. Shape:
  1450. - Input: :math:`(*)` where `*` means, any number of additional
  1451. dimensions
  1452. - Output: :math:`(*)`, same shape as the input
  1453. Args:
  1454. dim (int): A dimension along which LogSoftmax will be computed.
  1455. Returns:
  1456. a Tensor of the same dimension and shape as the input with
  1457. values in the range [-inf, 0)
  1458. Examples::
  1459. >>> m = nn.LogSoftmax(dim=1)
  1460. >>> input = torch.randn(2, 3)
  1461. >>> output = m(input)
  1462. """
  1463. __constants__ = ["dim"]
  1464. dim: int | None
  1465. def __init__(self, dim: int | None = None) -> None:
  1466. super().__init__()
  1467. self.dim = dim
  1468. def __setstate__(self, state):
  1469. super().__setstate__(state)
  1470. if not hasattr(self, "dim"):
  1471. self.dim = None
  1472. def forward(self, input: Tensor) -> Tensor:
  1473. """
  1474. Runs the forward pass.
  1475. """
  1476. return F.log_softmax(input, self.dim, _stacklevel=5)
  1477. def extra_repr(self) -> str:
  1478. """
  1479. Return the extra representation of the module.
  1480. """
  1481. return f"dim={self.dim}"