linear.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. # mypy: allow-untyped-defs
  2. import math
  3. from typing import Any
  4. import torch
  5. from torch import Tensor
  6. from torch.nn import functional as F, init
  7. from torch.nn.parameter import Parameter, UninitializedParameter
  8. from .lazy import LazyModuleMixin
  9. from .module import Module
  10. __all__ = [
  11. "Bilinear",
  12. "Identity",
  13. "LazyLinear",
  14. "Linear",
  15. ]
  16. class Identity(Module):
  17. r"""A placeholder identity operator that is argument-insensitive.
  18. Args:
  19. args: any argument (unused)
  20. kwargs: any keyword argument (unused)
  21. Shape:
  22. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  23. - Output: :math:`(*)`, same shape as the input.
  24. Examples::
  25. >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False)
  26. >>> input = torch.randn(128, 20)
  27. >>> output = m(input)
  28. >>> print(output.size())
  29. torch.Size([128, 20])
  30. """
  31. def __init__(self, *args: Any, **kwargs: Any) -> None:
  32. super().__init__()
  33. def forward(self, input: Tensor) -> Tensor:
  34. """
  35. Runs the forward pass.
  36. """
  37. return input
  38. class Linear(Module):
  39. r"""Applies an affine linear transformation to the incoming data: :math:`y = xA^T + b`.
  40. This module supports :ref:`TensorFloat32<tf32_on_ampere>`.
  41. On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.
  42. Args:
  43. in_features: size of each input sample
  44. out_features: size of each output sample
  45. bias: If set to ``False``, the layer will not learn an additive bias.
  46. Default: ``True``
  47. Shape:
  48. - Input: :math:`(*, H_\text{in})` where :math:`*` means any number of
  49. dimensions including none and :math:`H_\text{in} = \text{in\_features}`.
  50. - Output: :math:`(*, H_\text{out})` where all but the last dimension
  51. are the same shape as the input and :math:`H_\text{out} = \text{out\_features}`.
  52. Attributes:
  53. weight: the learnable weights of the module of shape
  54. :math:`(\text{out\_features}, \text{in\_features})`. The values are
  55. initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
  56. :math:`k = \frac{1}{\text{in\_features}}`
  57. bias: the learnable bias of the module of shape :math:`(\text{out\_features})`.
  58. If :attr:`bias` is ``True``, the values are initialized from
  59. :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
  60. :math:`k = \frac{1}{\text{in\_features}}`
  61. Examples::
  62. >>> m = nn.Linear(20, 30)
  63. >>> input = torch.randn(128, 20)
  64. >>> output = m(input)
  65. >>> print(output.size())
  66. torch.Size([128, 30])
  67. """
  68. __constants__ = ["in_features", "out_features"]
  69. in_features: int
  70. out_features: int
  71. weight: Tensor
  72. def __init__(
  73. self,
  74. in_features: int,
  75. out_features: int,
  76. bias: bool = True,
  77. device=None,
  78. dtype=None,
  79. ) -> None:
  80. factory_kwargs = {"device": device, "dtype": dtype}
  81. super().__init__()
  82. self.in_features = in_features
  83. self.out_features = out_features
  84. self.weight = Parameter(
  85. torch.empty((out_features, in_features), **factory_kwargs)
  86. )
  87. if bias:
  88. self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
  89. else:
  90. self.register_parameter("bias", None)
  91. self.reset_parameters()
  92. def reset_parameters(self) -> None:
  93. """
  94. Resets parameters based on their initialization used in ``__init__``.
  95. """
  96. # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
  97. # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
  98. # https://github.com/pytorch/pytorch/issues/57109
  99. init.kaiming_uniform_(self.weight, a=math.sqrt(5))
  100. if self.bias is not None:
  101. fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
  102. bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
  103. init.uniform_(self.bias, -bound, bound)
  104. def forward(self, input: Tensor) -> Tensor:
  105. """
  106. Runs the forward pass.
  107. """
  108. return F.linear(input, self.weight, self.bias)
  109. def extra_repr(self) -> str:
  110. """
  111. Return the extra representation of the module.
  112. """
  113. return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"
  114. # This class exists solely to avoid triggering an obscure error when scripting
  115. # an improperly quantized attention layer. See this issue for details:
  116. # https://github.com/pytorch/pytorch/issues/58969
  117. # TODO: fail fast on quantization API usage error, then remove this class
  118. # and replace uses of it with plain Linear
  119. class NonDynamicallyQuantizableLinear(Linear):
  120. def __init__(
  121. self,
  122. in_features: int,
  123. out_features: int,
  124. bias: bool = True,
  125. device=None,
  126. dtype=None,
  127. ) -> None:
  128. super().__init__(
  129. in_features, out_features, bias=bias, device=device, dtype=dtype
  130. )
  131. class Bilinear(Module):
  132. r"""Applies a bilinear transformation to the incoming data: :math:`y = x_1^T A x_2 + b`.
  133. Args:
  134. in1_features: size of each first input sample, must be > 0
  135. in2_features: size of each second input sample, must be > 0
  136. out_features: size of each output sample, must be > 0
  137. bias: If set to ``False``, the layer will not learn an additive bias.
  138. Default: ``True``
  139. Shape:
  140. - Input1: :math:`(*, H_\text{in1})` where :math:`H_\text{in1}=\text{in1\_features}` and
  141. :math:`*` means any number of additional dimensions including none. All but the last dimension
  142. of the inputs should be the same.
  143. - Input2: :math:`(*, H_\text{in2})` where :math:`H_\text{in2}=\text{in2\_features}`.
  144. - Output: :math:`(*, H_\text{out})` where :math:`H_\text{out}=\text{out\_features}`
  145. and all but the last dimension are the same shape as the input.
  146. Attributes:
  147. weight: the learnable weights of the module of shape
  148. :math:`(\text{out\_features}, \text{in1\_features}, \text{in2\_features})`.
  149. The values are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
  150. :math:`k = \frac{1}{\text{in1\_features}}`
  151. bias: the learnable bias of the module of shape :math:`(\text{out\_features})`.
  152. If :attr:`bias` is ``True``, the values are initialized from
  153. :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
  154. :math:`k = \frac{1}{\text{in1\_features}}`
  155. Examples::
  156. >>> m = nn.Bilinear(20, 30, 40)
  157. >>> input1 = torch.randn(128, 20)
  158. >>> input2 = torch.randn(128, 30)
  159. >>> output = m(input1, input2)
  160. >>> print(output.size())
  161. torch.Size([128, 40])
  162. """
  163. __constants__ = ["in1_features", "in2_features", "out_features"]
  164. in1_features: int
  165. in2_features: int
  166. out_features: int
  167. weight: Tensor
  168. def __init__(
  169. self,
  170. in1_features: int,
  171. in2_features: int,
  172. out_features: int,
  173. bias: bool = True,
  174. device=None,
  175. dtype=None,
  176. ) -> None:
  177. factory_kwargs = {"device": device, "dtype": dtype}
  178. super().__init__()
  179. self.in1_features = in1_features
  180. self.in2_features = in2_features
  181. self.out_features = out_features
  182. self.weight = Parameter(
  183. torch.empty((out_features, in1_features, in2_features), **factory_kwargs)
  184. )
  185. if bias:
  186. self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
  187. else:
  188. self.register_parameter("bias", None)
  189. self.reset_parameters()
  190. def reset_parameters(self) -> None:
  191. """
  192. Resets parameters based on their initialization used in ``__init__``.
  193. """
  194. if self.in1_features <= 0:
  195. raise ValueError(
  196. f"in1_features must be > 0, but got (in1_features={self.in1_features})"
  197. )
  198. bound = 1 / math.sqrt(self.weight.size(1))
  199. init.uniform_(self.weight, -bound, bound)
  200. if self.bias is not None:
  201. init.uniform_(self.bias, -bound, bound)
  202. def forward(self, input1: Tensor, input2: Tensor) -> Tensor:
  203. """
  204. Runs the forward pass.
  205. """
  206. return F.bilinear(input1, input2, self.weight, self.bias)
  207. def extra_repr(self) -> str:
  208. """
  209. Return the extra representation of the module.
  210. """
  211. return (
  212. f"in1_features={self.in1_features}, in2_features={self.in2_features}, "
  213. f"out_features={self.out_features}, bias={self.bias is not None}"
  214. )
  215. class LazyLinear(LazyModuleMixin, Linear):
  216. r"""A :class:`torch.nn.Linear` module where `in_features` is inferred.
  217. In this module, the `weight` and `bias` are of :class:`torch.nn.UninitializedParameter`
  218. class. They will be initialized after the first call to ``forward`` is done and the
  219. module will become a regular :class:`torch.nn.Linear` module. The ``in_features`` argument
  220. of the :class:`Linear` is inferred from the ``input.shape[-1]``.
  221. Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
  222. on lazy modules and their limitations.
  223. Args:
  224. out_features: size of each output sample
  225. bias: If set to ``False``, the layer will not learn an additive bias.
  226. Default: ``True``
  227. Attributes:
  228. weight: the learnable weights of the module of shape
  229. :math:`(\text{out\_features}, \text{in\_features})`. The values are
  230. initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
  231. :math:`k = \frac{1}{\text{in\_features}}`
  232. bias: the learnable bias of the module of shape :math:`(\text{out\_features})`.
  233. If :attr:`bias` is ``True``, the values are initialized from
  234. :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
  235. :math:`k = \frac{1}{\text{in\_features}}`
  236. """
  237. cls_to_become = Linear # type: ignore[assignment]
  238. # pyrefly: ignore [bad-override]
  239. weight: UninitializedParameter
  240. bias: UninitializedParameter # type: ignore[assignment]
  241. def __init__(
  242. self, out_features: int, bias: bool = True, device=None, dtype=None
  243. ) -> None:
  244. factory_kwargs = {"device": device, "dtype": dtype}
  245. # bias is hardcoded to False to avoid creating tensor
  246. # that will soon be overwritten.
  247. # pyrefly: ignore [bad-argument-type]
  248. super().__init__(0, 0, False)
  249. # pyrefly: ignore [bad-argument-type, unexpected-keyword]
  250. self.weight = UninitializedParameter(**factory_kwargs)
  251. self.out_features = out_features
  252. if bias:
  253. # pyrefly: ignore [bad-argument-type, unexpected-keyword]
  254. self.bias = UninitializedParameter(**factory_kwargs)
  255. def reset_parameters(self) -> None:
  256. """
  257. Resets parameters based on their initialization used in ``__init__``.
  258. """
  259. # pyrefly: ignore [bad-argument-type]
  260. if not self.has_uninitialized_params() and self.in_features != 0:
  261. super().reset_parameters()
  262. def initialize_parameters(self, input) -> None: # type: ignore[override]
  263. """
  264. Infers ``in_features`` based on ``input`` and initializes parameters.
  265. """
  266. # pyrefly: ignore [bad-argument-type]
  267. if self.has_uninitialized_params():
  268. with torch.no_grad():
  269. self.in_features = input.shape[-1]
  270. self.weight.materialize((self.out_features, self.in_features))
  271. if self.bias is not None:
  272. self.bias.materialize((self.out_features,))
  273. self.reset_parameters()
  274. if self.in_features == 0:
  275. if input.shape[-1] != self.weight.shape[-1]:
  276. raise AssertionError(
  277. f"The in_features inferred from input: {input.shape[-1]} "
  278. f"is not equal to in_features from self.weight: "
  279. f"{self.weight.shape[-1]}"
  280. )
  281. self.in_features = input.shape[-1]
  282. # TODO: PartialLinear - maybe in sparse?