init.py 24 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754
"""Utilities for initializing neural network parameters.

The public functions whose names end in an underscore (``uniform_``,
``xavier_normal_``, ...) overwrite the tensor passed to them in place. The
same names without the trailing underscore are deprecated aliases that emit a
``FutureWarning`` and forward to the underscore versions.
"""
  2. import math
  3. import warnings
  4. from collections.abc import Callable
  5. from typing import Literal, TypeVar
  6. from typing_extensions import ParamSpec
  7. import torch
  8. from torch import Tensor
  9. __all__ = [
  10. "calculate_gain",
  11. "uniform_",
  12. "normal_",
  13. "trunc_normal_",
  14. "constant_",
  15. "ones_",
  16. "zeros_",
  17. "eye_",
  18. "dirac_",
  19. "xavier_uniform_",
  20. "xavier_normal_",
  21. "kaiming_uniform_",
  22. "kaiming_normal_",
  23. "orthogonal_",
  24. "sparse_",
  25. # Deprecated aliases (for backward compatibility)
  26. "uniform",
  27. "normal",
  28. "constant",
  29. "eye",
  30. "dirac",
  31. "xavier_uniform",
  32. "xavier_normal",
  33. "kaiming_uniform",
  34. "kaiming_normal",
  35. "orthogonal",
  36. "sparse",
  37. ]
  38. _R = TypeVar("_R")
  39. _P = ParamSpec("_P")
  40. _NonlinearityType = Literal[
  41. "linear",
  42. "conv1d",
  43. "conv2d",
  44. "conv3d",
  45. "conv_transpose1d",
  46. "conv_transpose2d",
  47. "conv_transpose3d",
  48. "sigmoid",
  49. "tanh",
  50. "relu",
  51. "leaky_relu",
  52. "selu",
  53. ]
  54. _FanMode = Literal["fan_in", "fan_out"]
  55. # These no_grad_* functions are necessary as wrappers around the parts of these
  56. # functions that use `with torch.no_grad()`. The JIT doesn't support context
  57. # managers, so these need to be implemented as builtins. Using these wrappers
  58. # lets us keep those builtins small and reusable.
  59. def _no_grad_uniform_(
  60. tensor: Tensor, a: float, b: float, generator: torch.Generator | None = None
  61. ) -> Tensor:
  62. with torch.no_grad():
  63. return tensor.uniform_(a, b, generator=generator)
  64. def _no_grad_normal_(
  65. tensor: Tensor,
  66. mean: float,
  67. std: float,
  68. generator: torch.Generator | None = None,
  69. ) -> Tensor:
  70. with torch.no_grad():
  71. return tensor.normal_(mean, std, generator=generator)
  72. def _no_grad_trunc_normal_(
  73. tensor: Tensor,
  74. mean: float,
  75. std: float,
  76. a: float,
  77. b: float,
  78. generator: torch.Generator | None = None,
  79. ) -> Tensor:
  80. # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
  81. def norm_cdf(x: float) -> float:
  82. # Computes standard normal cumulative distribution function
  83. return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
  84. if (mean < a - 2 * std) or (mean > b + 2 * std):
  85. warnings.warn(
  86. "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
  87. "The distribution of values may be incorrect.",
  88. stacklevel=2,
  89. )
  90. with torch.no_grad():
  91. # Values are generated by using a truncated uniform distribution and
  92. # then using the inverse CDF for the normal distribution.
  93. # Get upper and lower cdf values
  94. l = norm_cdf((a - mean) / std)
  95. u = norm_cdf((b - mean) / std)
  96. # Uniformly fill tensor with values from [l, u], then translate to
  97. # [2l-1, 2u-1].
  98. tensor.uniform_(2 * l - 1, 2 * u - 1, generator=generator)
  99. # Use inverse cdf transform for normal distribution to get truncated
  100. # standard normal
  101. tensor.erfinv_()
  102. # Transform to proper mean, std
  103. tensor.mul_(std * math.sqrt(2.0))
  104. tensor.add_(mean)
  105. # Clamp to ensure it's in the proper range
  106. tensor.clamp_(min=a, max=b)
  107. return tensor
  108. def _no_grad_fill_(tensor: Tensor, val: float) -> Tensor:
  109. with torch.no_grad():
  110. return tensor.fill_(val)
  111. def _no_grad_zero_(tensor: Tensor) -> Tensor:
  112. with torch.no_grad():
  113. return tensor.zero_()
  114. def calculate_gain(
  115. nonlinearity: _NonlinearityType, param: int | float | None = None
  116. ) -> float:
  117. r"""Return the recommended gain value for the given nonlinearity function.
  118. The values are as follows:
  119. ================= ====================================================
  120. nonlinearity gain
  121. ================= ====================================================
  122. Linear / Identity :math:`1`
  123. Conv{1,2,3}D :math:`1`
  124. Sigmoid :math:`1`
  125. Tanh :math:`\frac{5}{3}`
  126. ReLU :math:`\sqrt{2}`
  127. Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
  128. SELU :math:`\frac{3}{4}`
  129. ================= ====================================================
  130. .. warning::
  131. In order to implement `Self-Normalizing Neural Networks`_ ,
  132. you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``.
  133. This gives the initial weights a variance of ``1 / N``,
  134. which is necessary to induce a stable fixed point in the forward pass.
  135. In contrast, the default gain for ``SELU`` sacrifices the normalization
  136. effect for more stable gradient flow in rectangular layers.
  137. Args:
  138. nonlinearity: the non-linear function (`nn.functional` name)
  139. param: optional parameter for the non-linear function
  140. Examples:
  141. >>> gain = nn.init.calculate_gain(
  142. ... "leaky_relu", 0.2
  143. ... ) # leaky_relu with negative_slope=0.2
  144. .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html
  145. """
  146. linear_fns = [
  147. "linear",
  148. "conv1d",
  149. "conv2d",
  150. "conv3d",
  151. "conv_transpose1d",
  152. "conv_transpose2d",
  153. "conv_transpose3d",
  154. ]
  155. if nonlinearity in linear_fns or nonlinearity == "sigmoid":
  156. return 1
  157. elif nonlinearity == "tanh":
  158. return 5.0 / 3
  159. elif nonlinearity == "relu":
  160. return math.sqrt(2.0)
  161. elif nonlinearity == "leaky_relu":
  162. if param is None:
  163. negative_slope = 0.01
  164. elif (
  165. not isinstance(param, bool)
  166. and isinstance(param, int)
  167. or isinstance(param, float)
  168. ):
  169. # True/False are instances of int, hence check above
  170. negative_slope = param
  171. else:
  172. raise ValueError(f"negative_slope {param} not a valid number")
  173. return math.sqrt(2.0 / (1 + negative_slope**2))
  174. elif nonlinearity == "selu":
  175. return (
  176. 3.0 / 4
  177. ) # Value found empirically (https://github.com/pytorch/pytorch/pull/50664)
  178. else:
  179. raise ValueError(f"Unsupported nonlinearity {nonlinearity}")
  180. def uniform_(
  181. tensor: Tensor,
  182. a: float = 0.0,
  183. b: float = 1.0,
  184. generator: torch.Generator | None = None,
  185. ) -> Tensor:
  186. r"""Fill the input Tensor with values drawn from the uniform distribution.
  187. :math:`\mathcal{U}(a, b)`.
  188. Args:
  189. tensor: an n-dimensional `torch.Tensor`
  190. a: the lower bound of the uniform distribution
  191. b: the upper bound of the uniform distribution
  192. generator: the torch Generator to sample from (default: None)
  193. Examples:
  194. >>> w = torch.empty(3, 5)
  195. >>> nn.init.uniform_(w)
  196. """
  197. if torch.overrides.has_torch_function_variadic(tensor):
  198. return torch.overrides.handle_torch_function(
  199. uniform_, (tensor,), tensor=tensor, a=a, b=b, generator=generator
  200. )
  201. return _no_grad_uniform_(tensor, a, b, generator)
  202. def normal_(
  203. tensor: Tensor,
  204. mean: float = 0.0,
  205. std: float = 1.0,
  206. generator: torch.Generator | None = None,
  207. ) -> Tensor:
  208. r"""Fill the input Tensor with values drawn from the normal distribution.
  209. :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.
  210. Args:
  211. tensor: an n-dimensional `torch.Tensor`
  212. mean: the mean of the normal distribution
  213. std: the standard deviation of the normal distribution
  214. generator: the torch Generator to sample from (default: None)
  215. Examples:
  216. >>> w = torch.empty(3, 5)
  217. >>> nn.init.normal_(w)
  218. """
  219. if torch.overrides.has_torch_function_variadic(tensor):
  220. return torch.overrides.handle_torch_function(
  221. normal_, (tensor,), tensor=tensor, mean=mean, std=std, generator=generator
  222. )
  223. return _no_grad_normal_(tensor, mean, std, generator)
  224. def trunc_normal_(
  225. tensor: Tensor,
  226. mean: float = 0.0,
  227. std: float = 1.0,
  228. a: float = -2.0,
  229. b: float = 2.0,
  230. generator: torch.Generator | None = None,
  231. ) -> Tensor:
  232. r"""Fill the input Tensor with values drawn from a truncated normal distribution.
  233. The values are effectively drawn from the
  234. normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
  235. with values outside :math:`[a, b]` redrawn until they are within
  236. the bounds. The method used for generating the random values works
  237. best when :math:`a \leq \text{mean} \leq b`.
  238. Args:
  239. tensor: an n-dimensional `torch.Tensor`
  240. mean: the mean of the normal distribution
  241. std: the standard deviation of the normal distribution
  242. a: the minimum cutoff value
  243. b: the maximum cutoff value
  244. generator: the torch Generator to sample from (default: None)
  245. Examples:
  246. >>> w = torch.empty(3, 5)
  247. >>> nn.init.trunc_normal_(w)
  248. """
  249. return _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=generator)
  250. def constant_(tensor: Tensor, val: float) -> Tensor:
  251. r"""Fill the input Tensor with the value :math:`\text{val}`.
  252. Args:
  253. tensor: an n-dimensional `torch.Tensor`
  254. val: the value to fill the tensor with
  255. Examples:
  256. >>> w = torch.empty(3, 5)
  257. >>> nn.init.constant_(w, 0.3)
  258. """
  259. if torch.overrides.has_torch_function_variadic(tensor):
  260. return torch.overrides.handle_torch_function(
  261. constant_, (tensor,), tensor=tensor, val=val
  262. )
  263. return _no_grad_fill_(tensor, val)
  264. def ones_(tensor: Tensor) -> Tensor:
  265. r"""Fill the input Tensor with the scalar value `1`.
  266. Args:
  267. tensor: an n-dimensional `torch.Tensor`
  268. Examples:
  269. >>> w = torch.empty(3, 5)
  270. >>> nn.init.ones_(w)
  271. """
  272. return _no_grad_fill_(tensor, 1.0)
  273. def zeros_(tensor: Tensor) -> Tensor:
  274. r"""Fill the input Tensor with the scalar value `0`.
  275. Args:
  276. tensor: an n-dimensional `torch.Tensor`
  277. Examples:
  278. >>> w = torch.empty(3, 5)
  279. >>> nn.init.zeros_(w)
  280. """
  281. return _no_grad_zero_(tensor)
  282. def eye_(tensor: Tensor) -> Tensor:
  283. r"""Fill the 2-dimensional input `Tensor` with the identity matrix.
  284. Preserves the identity of the inputs in `Linear` layers, where as
  285. many inputs are preserved as possible.
  286. Args:
  287. tensor: a 2-dimensional `torch.Tensor`
  288. Examples:
  289. >>> w = torch.empty(3, 5)
  290. >>> nn.init.eye_(w)
  291. """
  292. if tensor.ndimension() != 2:
  293. raise ValueError("Only tensors with 2 dimensions are supported")
  294. with torch.no_grad():
  295. torch.eye(*tensor.shape, out=tensor, requires_grad=tensor.requires_grad)
  296. return tensor
  297. def dirac_(tensor: Tensor, groups: int = 1) -> Tensor:
  298. r"""Fill the {3, 4, 5}-dimensional input `Tensor` with the Dirac delta function.
  299. Preserves the identity of the inputs in `Convolutional`
  300. layers, where as many input channels are preserved as possible. In case
  301. of groups>1, each group of channels preserves identity
  302. Args:
  303. tensor: a {3, 4, 5}-dimensional `torch.Tensor`
  304. groups (int, optional): number of groups in the conv layer (default: 1)
  305. Examples:
  306. >>> w = torch.empty(3, 16, 5, 5)
  307. >>> nn.init.dirac_(w)
  308. >>> w = torch.empty(3, 24, 5, 5)
  309. >>> nn.init.dirac_(w, 3)
  310. """
  311. dimensions = tensor.ndimension()
  312. if dimensions not in [3, 4, 5]:
  313. raise ValueError("Only tensors with 3, 4, or 5 dimensions are supported")
  314. sizes = tensor.size()
  315. if sizes[0] % groups != 0:
  316. raise ValueError("dim 0 must be divisible by groups")
  317. out_chans_per_grp = sizes[0] // groups
  318. min_dim = min(out_chans_per_grp, sizes[1])
  319. with torch.no_grad():
  320. tensor.zero_()
  321. for g in range(groups):
  322. for d in range(min_dim):
  323. if dimensions == 3: # Temporal convolution
  324. tensor[g * out_chans_per_grp + d, d, tensor.size(2) // 2] = 1
  325. elif dimensions == 4: # Spatial convolution
  326. tensor[
  327. g * out_chans_per_grp + d,
  328. d,
  329. tensor.size(2) // 2,
  330. tensor.size(3) // 2,
  331. ] = 1
  332. else: # Volumetric convolution
  333. tensor[
  334. g * out_chans_per_grp + d,
  335. d,
  336. tensor.size(2) // 2,
  337. tensor.size(3) // 2,
  338. tensor.size(4) // 2,
  339. ] = 1
  340. return tensor
  341. def _calculate_fan_in_and_fan_out(tensor: Tensor) -> tuple[int, int]:
  342. dimensions = tensor.dim()
  343. if dimensions < 2:
  344. raise ValueError(
  345. "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
  346. )
  347. num_input_fmaps = tensor.size(1)
  348. num_output_fmaps = tensor.size(0)
  349. receptive_field_size = 1
  350. if tensor.dim() > 2:
  351. # math.prod is not always available, accumulate the product manually
  352. # we could use functools.reduce but that is not supported by TorchScript
  353. for s in tensor.shape[2:]:
  354. receptive_field_size *= s
  355. fan_in = num_input_fmaps * receptive_field_size
  356. fan_out = num_output_fmaps * receptive_field_size
  357. return fan_in, fan_out
  358. def xavier_uniform_(
  359. tensor: Tensor,
  360. gain: float = 1.0,
  361. generator: torch.Generator | None = None,
  362. ) -> Tensor:
  363. r"""Fill the input `Tensor` with values using a Xavier uniform distribution.
  364. The method is described in `Understanding the difficulty of training
  365. deep feedforward neural networks` - Glorot, X. & Bengio, Y. (2010).
  366. The resulting tensor will have values sampled from
  367. :math:`\mathcal{U}(-a, a)` where
  368. .. math::
  369. a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}
  370. Also known as Glorot initialization.
  371. Args:
  372. tensor: an n-dimensional `torch.Tensor`
  373. gain: an optional scaling factor
  374. generator: the torch Generator to sample from (default: None)
  375. Examples:
  376. >>> w = torch.empty(3, 5)
  377. >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain("relu"))
  378. """
  379. fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
  380. std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
  381. a = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
  382. return _no_grad_uniform_(tensor, -a, a, generator)
  383. def xavier_normal_(
  384. tensor: Tensor,
  385. gain: float = 1.0,
  386. generator: torch.Generator | None = None,
  387. ) -> Tensor:
  388. r"""Fill the input `Tensor` with values using a Xavier normal distribution.
  389. The method is described in `Understanding the difficulty of training deep feedforward
  390. neural networks` - Glorot, X. & Bengio, Y. (2010). The resulting tensor
  391. will have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where
  392. .. math::
  393. \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}
  394. Also known as Glorot initialization.
  395. Args:
  396. tensor: an n-dimensional `torch.Tensor`
  397. gain: an optional scaling factor
  398. generator: the torch Generator to sample from (default: None)
  399. Examples:
  400. >>> w = torch.empty(3, 5)
  401. >>> nn.init.xavier_normal_(w)
  402. """
  403. fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
  404. std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
  405. return _no_grad_normal_(tensor, 0.0, std, generator)
  406. def _calculate_correct_fan(tensor: Tensor, mode: _FanMode) -> int:
  407. # pyrefly: ignore [bad-assignment]
  408. mode = mode.lower()
  409. valid_modes = ["fan_in", "fan_out"]
  410. if mode not in valid_modes:
  411. raise ValueError(f"Mode {mode} not supported, please use one of {valid_modes}")
  412. fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
  413. return fan_in if mode == "fan_in" else fan_out
  414. def kaiming_uniform_(
  415. tensor: Tensor,
  416. a: float = 0,
  417. mode: _FanMode = "fan_in",
  418. nonlinearity: _NonlinearityType = "leaky_relu",
  419. generator: torch.Generator | None = None,
  420. ) -> Tensor:
  421. r"""Fill the input `Tensor` with values using a Kaiming uniform distribution.
  422. The method is described in `Delving deep into rectifiers: Surpassing
  423. human-level performance on ImageNet classification` - He, K. et al. (2015).
  424. The resulting tensor will have values sampled from
  425. :math:`\mathcal{U}(-\text{bound}, \text{bound})` where
  426. .. math::
  427. \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}
  428. Also known as He initialization.
  429. Args:
  430. tensor: an n-dimensional `torch.Tensor`
  431. a: the negative slope of the rectifier used after this layer (only
  432. used with ``'leaky_relu'``)
  433. mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
  434. preserves the magnitude of the variance of the weights in the
  435. forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
  436. backwards pass.
  437. nonlinearity: the non-linear function (`nn.functional` name),
  438. recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
  439. generator: the torch Generator to sample from (default: None)
  440. Examples:
  441. >>> w = torch.empty(3, 5)
  442. >>> nn.init.kaiming_uniform_(w, mode="fan_in", nonlinearity="relu")
  443. Note:
  444. Be aware that ``fan_in`` and ``fan_out`` are calculated assuming
  445. that the weight matrix is used in a transposed manner,
  446. (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``).
  447. This is important for correct initialization.
  448. If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``,
  449. pass in a transposed weight matrix, i.e. ``nn.init.kaiming_uniform_(w.T, ...)``.
  450. """
  451. if torch.overrides.has_torch_function_variadic(tensor):
  452. return torch.overrides.handle_torch_function(
  453. kaiming_uniform_,
  454. (tensor,),
  455. tensor=tensor,
  456. a=a,
  457. mode=mode,
  458. nonlinearity=nonlinearity,
  459. generator=generator,
  460. )
  461. if 0 in tensor.shape:
  462. warnings.warn("Initializing zero-element tensors is a no-op", stacklevel=2)
  463. return tensor
  464. fan = _calculate_correct_fan(tensor, mode)
  465. gain = calculate_gain(nonlinearity, a)
  466. std = gain / math.sqrt(fan)
  467. bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
  468. with torch.no_grad():
  469. return tensor.uniform_(-bound, bound, generator=generator)
  470. def kaiming_normal_(
  471. tensor: Tensor,
  472. a: float = 0,
  473. mode: _FanMode = "fan_in",
  474. nonlinearity: _NonlinearityType = "leaky_relu",
  475. generator: torch.Generator | None = None,
  476. ) -> Tensor:
  477. r"""Fill the input `Tensor` with values using a Kaiming normal distribution.
  478. The method is described in `Delving deep into rectifiers: Surpassing
  479. human-level performance on ImageNet classification` - He, K. et al. (2015).
  480. The resulting tensor will have values sampled from
  481. :math:`\mathcal{N}(0, \text{std}^2)` where
  482. .. math::
  483. \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}
  484. Also known as He initialization.
  485. Args:
  486. tensor: an n-dimensional `torch.Tensor`
  487. a: the negative slope of the rectifier used after this layer (only
  488. used with ``'leaky_relu'``)
  489. mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
  490. preserves the magnitude of the variance of the weights in the
  491. forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
  492. backwards pass.
  493. nonlinearity: the non-linear function (`nn.functional` name),
  494. recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
  495. generator: the torch Generator to sample from (default: None)
  496. Examples:
  497. >>> w = torch.empty(3, 5)
  498. >>> nn.init.kaiming_normal_(w, mode="fan_out", nonlinearity="relu")
  499. Note:
  500. Be aware that ``fan_in`` and ``fan_out`` are calculated assuming
  501. that the weight matrix is used in a transposed manner,
  502. (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``).
  503. This is important for correct initialization.
  504. If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``,
  505. pass in a transposed weight matrix, i.e. ``nn.init.kaiming_normal_(w.T, ...)``.
  506. """
  507. if 0 in tensor.shape:
  508. warnings.warn("Initializing zero-element tensors is a no-op", stacklevel=2)
  509. return tensor
  510. fan = _calculate_correct_fan(tensor, mode)
  511. gain = calculate_gain(nonlinearity, a)
  512. std = gain / math.sqrt(fan)
  513. with torch.no_grad():
  514. return tensor.normal_(0, std, generator=generator)
  515. def orthogonal_(
  516. tensor: Tensor,
  517. gain: float = 1,
  518. generator: torch.Generator | None = None,
  519. ) -> Tensor:
  520. r"""Fill the input `Tensor` with a (semi) orthogonal matrix.
  521. Described in `Exact solutions to the nonlinear dynamics of learning in deep
  522. linear neural networks` - Saxe, A. et al. (2013). The input tensor must have
  523. at least 2 dimensions, and for tensors with more than 2 dimensions the
  524. trailing dimensions are flattened.
  525. Args:
  526. tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2`
  527. gain: optional scaling factor
  528. generator: the torch Generator to sample from (default: None)
  529. Examples:
  530. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
  531. >>> w = torch.empty(3, 5)
  532. >>> nn.init.orthogonal_(w)
  533. """
  534. if tensor.ndimension() < 2:
  535. raise ValueError("Only tensors with 2 or more dimensions are supported")
  536. if tensor.numel() == 0:
  537. # no-op
  538. return tensor
  539. rows = tensor.size(0)
  540. cols = tensor.numel() // rows
  541. flattened = tensor.new_empty((rows, cols)).normal_(0, 1, generator=generator)
  542. if rows < cols:
  543. flattened.t_()
  544. # Compute the qr factorization
  545. q, r = torch.linalg.qr(flattened)
  546. # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
  547. d = torch.diag(r, 0)
  548. ph = d.sign()
  549. q *= ph
  550. if rows < cols:
  551. q.t_()
  552. with torch.no_grad():
  553. tensor.view_as(q).copy_(q)
  554. tensor.mul_(gain)
  555. return tensor
  556. def sparse_(
  557. tensor: Tensor,
  558. sparsity: float,
  559. std: float = 0.01,
  560. generator: torch.Generator | None = None,
  561. ) -> Tensor:
  562. r"""Fill the 2D input `Tensor` as a sparse matrix.
  563. The non-zero elements will be drawn from the normal distribution
  564. :math:`\mathcal{N}(0, 0.01)`, as described in `Deep learning via
  565. Hessian-free optimization` - Martens, J. (2010).
  566. Args:
  567. tensor: an n-dimensional `torch.Tensor`
  568. sparsity: The fraction of elements in each column to be set to zero
  569. std: the standard deviation of the normal distribution used to generate
  570. the non-zero values
  571. generator: the torch Generator to sample from (default: None)
  572. Examples:
  573. >>> w = torch.empty(3, 5)
  574. >>> nn.init.sparse_(w, sparsity=0.1)
  575. """
  576. if tensor.ndimension() != 2:
  577. raise ValueError("Only tensors with 2 dimensions are supported")
  578. rows, cols = tensor.shape
  579. num_zeros = math.ceil(sparsity * rows)
  580. with torch.no_grad():
  581. tensor.normal_(0, std, generator=generator)
  582. for col_idx in range(cols):
  583. row_indices = torch.randperm(rows)
  584. zero_indices = row_indices[:num_zeros]
  585. tensor[zero_indices, col_idx] = 0
  586. return tensor
  587. # for backward compatibility
  588. def _make_deprecate(meth: Callable[_P, _R]) -> Callable[_P, _R]:
  589. new_name = meth.__name__
  590. old_name = new_name[:-1]
  591. def deprecated_init(*args: _P.args, **kwargs: _P.kwargs) -> _R:
  592. warnings.warn(
  593. f"`nn.init.{old_name}` is now deprecated in favor of `nn.init.{new_name}`.",
  594. FutureWarning,
  595. stacklevel=2,
  596. )
  597. return meth(*args, **kwargs)
  598. deprecated_init.__doc__ = rf"""
  599. {old_name}(...)
  600. .. warning::
  601. This method is now deprecated in favor of :func:`torch.nn.init.{new_name}`.
  602. See :func:`~torch.nn.init.{new_name}` for details."""
  603. deprecated_init.__name__ = old_name
  604. return deprecated_init
# Deprecated names without the trailing underscore, kept for backward
# compatibility. Each alias emits a FutureWarning when called and then forwards
# its arguments to the underscore version. The assignments are spelled out
# explicitly so static analyzers can see every exported name.
uniform = _make_deprecate(uniform_)
normal = _make_deprecate(normal_)
constant = _make_deprecate(constant_)
eye = _make_deprecate(eye_)
dirac = _make_deprecate(dirac_)
xavier_uniform = _make_deprecate(xavier_uniform_)
xavier_normal = _make_deprecate(xavier_normal_)
kaiming_uniform = _make_deprecate(kaiming_uniform_)
kaiming_normal = _make_deprecate(kaiming_normal_)
orthogonal = _make_deprecate(orthogonal_)
sparse = _make_deprecate(sparse_)