| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754 |
- """This file contains utilities for initializing neural network parameters."""
- import math
- import warnings
- from collections.abc import Callable
- from typing import Literal, TypeVar
- from typing_extensions import ParamSpec
- import torch
- from torch import Tensor
# Names exported by ``from <module> import *``: the gain helper, the in-place
# initializers (trailing underscore), then the legacy underscore-free aliases.
__all__ = [
    # Gain helper
    "calculate_gain",
    # Distribution / constant fills
    "uniform_",
    "normal_",
    "trunc_normal_",
    "constant_",
    "ones_",
    "zeros_",
    # Structured fills
    "eye_",
    "dirac_",
    # Variance-scaling schemes
    "xavier_uniform_",
    "xavier_normal_",
    "kaiming_uniform_",
    "kaiming_normal_",
    # Matrix-structured schemes
    "orthogonal_",
    "sparse_",
    # Deprecated aliases (for backward compatibility)
    "uniform",
    "normal",
    "constant",
    "eye",
    "dirac",
    "xavier_uniform",
    "xavier_normal",
    "kaiming_uniform",
    "kaiming_normal",
    "orthogonal",
    "sparse",
]
- _R = TypeVar("_R")
- _P = ParamSpec("_P")
- _NonlinearityType = Literal[
- "linear",
- "conv1d",
- "conv2d",
- "conv3d",
- "conv_transpose1d",
- "conv_transpose2d",
- "conv_transpose3d",
- "sigmoid",
- "tanh",
- "relu",
- "leaky_relu",
- "selu",
- ]
- _FanMode = Literal["fan_in", "fan_out"]
- # These no_grad_* functions are necessary as wrappers around the parts of these
- # functions that use `with torch.no_grad()`. The JIT doesn't support context
- # managers, so these need to be implemented as builtins. Using these wrappers
- # lets us keep those builtins small and reusable.
- def _no_grad_uniform_(
- tensor: Tensor, a: float, b: float, generator: torch.Generator | None = None
- ) -> Tensor:
- with torch.no_grad():
- return tensor.uniform_(a, b, generator=generator)
- def _no_grad_normal_(
- tensor: Tensor,
- mean: float,
- std: float,
- generator: torch.Generator | None = None,
- ) -> Tensor:
- with torch.no_grad():
- return tensor.normal_(mean, std, generator=generator)
- def _no_grad_trunc_normal_(
- tensor: Tensor,
- mean: float,
- std: float,
- a: float,
- b: float,
- generator: torch.Generator | None = None,
- ) -> Tensor:
- # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
- def norm_cdf(x: float) -> float:
- # Computes standard normal cumulative distribution function
- return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
- if (mean < a - 2 * std) or (mean > b + 2 * std):
- warnings.warn(
- "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
- "The distribution of values may be incorrect.",
- stacklevel=2,
- )
- with torch.no_grad():
- # Values are generated by using a truncated uniform distribution and
- # then using the inverse CDF for the normal distribution.
- # Get upper and lower cdf values
- l = norm_cdf((a - mean) / std)
- u = norm_cdf((b - mean) / std)
- # Uniformly fill tensor with values from [l, u], then translate to
- # [2l-1, 2u-1].
- tensor.uniform_(2 * l - 1, 2 * u - 1, generator=generator)
- # Use inverse cdf transform for normal distribution to get truncated
- # standard normal
- tensor.erfinv_()
- # Transform to proper mean, std
- tensor.mul_(std * math.sqrt(2.0))
- tensor.add_(mean)
- # Clamp to ensure it's in the proper range
- tensor.clamp_(min=a, max=b)
- return tensor
- def _no_grad_fill_(tensor: Tensor, val: float) -> Tensor:
- with torch.no_grad():
- return tensor.fill_(val)
- def _no_grad_zero_(tensor: Tensor) -> Tensor:
- with torch.no_grad():
- return tensor.zero_()
def calculate_gain(
    nonlinearity: _NonlinearityType, param: int | float | None = None
) -> float:
    r"""Return the recommended gain value for the given nonlinearity function.

    The values are as follows:

    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    SELU              :math:`\frac{3}{4}`
    ================= ====================================================

    .. warning::
        In order to implement `Self-Normalizing Neural Networks`_ ,
        you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``.
        This gives the initial weights a variance of ``1 / N``,
        which is necessary to induce a stable fixed point in the forward pass.
        In contrast, the default gain for ``SELU`` sacrifices the normalization
        effect for more stable gradient flow in rectangular layers.

    Args:
        nonlinearity: the non-linear function (`nn.functional` name)
        param: optional parameter for the non-linear function; only read for
            ``'leaky_relu'``, where it is the negative slope (``0.01`` when
            omitted)

    Raises:
        ValueError: if ``nonlinearity`` is not one of the supported names, or
            if ``param`` for ``'leaky_relu'`` is neither ``None`` nor a
            non-bool ``int``/``float``.

    Examples:
        >>> gain = nn.init.calculate_gain(
        ...     "leaky_relu", 0.2
        ... )  # leaky_relu with negative_slope=0.2

    .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html
    """
    unit_gain_fns = (
        "linear",
        "conv1d",
        "conv2d",
        "conv3d",
        "conv_transpose1d",
        "conv_transpose2d",
        "conv_transpose3d",
        "sigmoid",
    )
    if nonlinearity in unit_gain_fns:
        return 1
    if nonlinearity == "tanh":
        return 5.0 / 3
    if nonlinearity == "relu":
        return math.sqrt(2.0)
    if nonlinearity == "selu":
        # Value found empirically (https://github.com/pytorch/pytorch/pull/50664)
        return 3.0 / 4
    if nonlinearity != "leaky_relu":
        raise ValueError(f"Unsupported nonlinearity {nonlinearity}")

    if param is None:
        negative_slope = 0.01
    elif isinstance(param, (int, float)) and not isinstance(param, bool):
        # bool is a subclass of int, so it has to be excluded explicitly
        negative_slope = param
    else:
        raise ValueError(f"negative_slope {param} not a valid number")
    return math.sqrt(2.0 / (1 + negative_slope**2))
def uniform_(
    tensor: Tensor,
    a: float = 0.0,
    b: float = 1.0,
    generator: torch.Generator | None = None,
) -> Tensor:
    r"""Fill the input Tensor with values drawn from the uniform distribution.

    :math:`\mathcal{U}(a, b)`.

    If ``torch.overrides.has_torch_function_variadic`` reports an override for
    ``tensor``, the call is forwarded through
    ``torch.overrides.handle_torch_function`` instead.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the lower bound of the uniform distribution
        b: the upper bound of the uniform distribution
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.uniform_(w)
    """
    if not torch.overrides.has_torch_function_variadic(tensor):
        return _no_grad_uniform_(tensor, a, b, generator)

    forwarded = {"tensor": tensor, "a": a, "b": b, "generator": generator}
    return torch.overrides.handle_torch_function(uniform_, (tensor,), **forwarded)
def normal_(
    tensor: Tensor,
    mean: float = 0.0,
    std: float = 1.0,
    generator: torch.Generator | None = None,
) -> Tensor:
    r"""Fill the input Tensor with values drawn from the normal distribution.

    :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.

    If ``torch.overrides.has_torch_function_variadic`` reports an override for
    ``tensor``, the call is forwarded through
    ``torch.overrides.handle_torch_function`` instead.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.normal_(w)
    """
    if not torch.overrides.has_torch_function_variadic(tensor):
        return _no_grad_normal_(tensor, mean, std, generator)

    forwarded = {"tensor": tensor, "mean": mean, "std": std, "generator": generator}
    return torch.overrides.handle_torch_function(normal_, (tensor,), **forwarded)
def trunc_normal_(
    tensor: Tensor,
    mean: float = 0.0,
    std: float = 1.0,
    a: float = -2.0,
    b: float = 2.0,
    generator: torch.Generator | None = None,
) -> Tensor:
    r"""Fill the input Tensor with values drawn from a truncated normal distribution.

    The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    # All of the work, including the out-of-range warning, lives in the helper.
    return _no_grad_trunc_normal_(
        tensor,
        mean,
        std,
        a,
        b,
        generator=generator,
    )
def constant_(tensor: Tensor, val: float) -> Tensor:
    r"""Fill the input Tensor with the value :math:`\text{val}`.

    If ``torch.overrides.has_torch_function_variadic`` reports an override for
    ``tensor``, the call is forwarded through
    ``torch.overrides.handle_torch_function`` instead.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        val: the value to fill the tensor with

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.constant_(w, 0.3)
    """
    if not torch.overrides.has_torch_function_variadic(tensor):
        return _no_grad_fill_(tensor, val)

    forwarded = {"tensor": tensor, "val": val}
    return torch.overrides.handle_torch_function(constant_, (tensor,), **forwarded)
def ones_(tensor: Tensor) -> Tensor:
    r"""Fill the input Tensor with the scalar value `1`.

    Args:
        tensor: an n-dimensional `torch.Tensor`

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.ones_(w)
    """
    fill_value = 1.0
    return _no_grad_fill_(tensor, fill_value)
def zeros_(tensor: Tensor) -> Tensor:
    r"""Fill the input Tensor with the scalar value `0`.

    Args:
        tensor: an n-dimensional `torch.Tensor`

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.zeros_(w)
    """
    zeroed = _no_grad_zero_(tensor)
    return zeroed
def eye_(tensor: Tensor) -> Tensor:
    r"""Fill the 2-dimensional input `Tensor` with the identity matrix.

    Preserves the identity of the inputs in `Linear` layers, where as
    many inputs are preserved as possible.

    Args:
        tensor: a 2-dimensional `torch.Tensor`

    Raises:
        ValueError: if ``tensor`` does not have exactly 2 dimensions.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.eye_(w)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    n_rows, n_cols = tensor.shape
    with torch.no_grad():
        # Write the identity straight into the caller's storage via ``out=``.
        torch.eye(n_rows, n_cols, out=tensor, requires_grad=tensor.requires_grad)
    return tensor
def dirac_(tensor: Tensor, groups: int = 1) -> Tensor:
    r"""Fill the {3, 4, 5}-dimensional input `Tensor` with the Dirac delta function.

    Preserves the identity of the inputs in `Convolutional`
    layers, where as many input channels are preserved as possible. In case
    of groups>1, each group of channels preserves identity.

    Args:
        tensor: a {3, 4, 5}-dimensional `torch.Tensor`
        groups (int, optional): number of groups in the conv layer (default: 1)

    Raises:
        ValueError: if ``tensor`` is not 3-, 4- or 5-dimensional, or if its
            first dimension is not divisible by ``groups``.

    Examples:
        >>> w = torch.empty(3, 16, 5, 5)
        >>> nn.init.dirac_(w)
        >>> w = torch.empty(3, 24, 5, 5)
        >>> nn.init.dirac_(w, 3)
    """
    if tensor.ndimension() not in [3, 4, 5]:
        raise ValueError("Only tensors with 3, 4, or 5 dimensions are supported")

    shape = tensor.size()
    if shape[0] % groups != 0:
        raise ValueError("dim 0 must be divisible by groups")

    rows_per_group = shape[0] // groups
    identity_len = min(rows_per_group, shape[1])
    # Middle index of every trailing (kernel) dimension; one entry for
    # temporal, two for spatial and three for volumetric convolutions.
    centre = tuple(extent // 2 for extent in shape[2:])

    with torch.no_grad():
        tensor.zero_()
        for group in range(groups):
            first_row = group * rows_per_group
            for offset in range(identity_len):
                # Row (first_row + offset) passes channel `offset` through
                # at the kernel centre.
                tensor[(first_row + offset, offset, *centre)] = 1
    return tensor
- def _calculate_fan_in_and_fan_out(tensor: Tensor) -> tuple[int, int]:
- dimensions = tensor.dim()
- if dimensions < 2:
- raise ValueError(
- "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
- )
- num_input_fmaps = tensor.size(1)
- num_output_fmaps = tensor.size(0)
- receptive_field_size = 1
- if tensor.dim() > 2:
- # math.prod is not always available, accumulate the product manually
- # we could use functools.reduce but that is not supported by TorchScript
- for s in tensor.shape[2:]:
- receptive_field_size *= s
- fan_in = num_input_fmaps * receptive_field_size
- fan_out = num_output_fmaps * receptive_field_size
- return fan_in, fan_out
def xavier_uniform_(
    tensor: Tensor,
    gain: float = 1.0,
    generator: torch.Generator | None = None,
) -> Tensor:
    r"""Fill the input `Tensor` with values using a Xavier uniform distribution.

    The method is described in `Understanding the difficulty of training
    deep feedforward neural networks` - Glorot, X. & Bengio, Y. (2010).
    The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-a, a)` where

    .. math::
        a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: an optional scaling factor
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain("relu"))
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    fan_total = float(fan_in + fan_out)
    target_std = gain * math.sqrt(2.0 / fan_total)
    # A uniform distribution on (-limit, limit) has std = limit / sqrt(3).
    limit = math.sqrt(3.0) * target_std
    return _no_grad_uniform_(tensor, -limit, limit, generator)
def xavier_normal_(
    tensor: Tensor,
    gain: float = 1.0,
    generator: torch.Generator | None = None,
) -> Tensor:
    r"""Fill the input `Tensor` with values using a Xavier normal distribution.

    The method is described in `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010). The resulting tensor
    will have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: an optional scaling factor
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_normal_(w)
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    fan_total = float(fan_in + fan_out)
    target_std = gain * math.sqrt(2.0 / fan_total)
    return _no_grad_normal_(tensor, 0.0, target_std, generator)
def _calculate_correct_fan(tensor: Tensor, mode: _FanMode) -> int:
    """Return the fan of ``tensor`` selected by ``mode`` (case-insensitive).

    Raises:
        ValueError: if the lower-cased ``mode`` is neither ``"fan_in"`` nor
            ``"fan_out"``.
    """
    valid_modes = ["fan_in", "fan_out"]
    requested = mode.lower()
    if requested not in valid_modes:
        raise ValueError(
            f"Mode {requested} not supported, please use one of {valid_modes}"
        )

    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if requested == "fan_in":
        return fan_in
    return fan_out
def kaiming_uniform_(
    tensor: Tensor,
    a: float = 0,
    mode: _FanMode = "fan_in",
    nonlinearity: _NonlinearityType = "leaky_relu",
    generator: torch.Generator | None = None,
) -> Tensor:
    r"""Fill the input `Tensor` with values using a Kaiming uniform distribution.

    The method is described in `Delving deep into rectifiers: Surpassing
    human-level performance on ImageNet classification` - He, K. et al. (2015).
    The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

    Also known as He initialization.

    Tensors with a zero-sized dimension are returned untouched after a warning.
    If ``torch.overrides.has_torch_function_variadic`` reports an override for
    ``tensor``, the call is forwarded through
    ``torch.overrides.handle_torch_function`` instead.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_uniform_(w, mode="fan_in", nonlinearity="relu")

    Note:
        Be aware that ``fan_in`` and ``fan_out`` are calculated assuming
        that the weight matrix is used in a transposed manner,
        (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``).
        This is important for correct initialization.
        If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``,
        pass in a transposed weight matrix, i.e. ``nn.init.kaiming_uniform_(w.T, ...)``.
    """
    if torch.overrides.has_torch_function_variadic(tensor):
        forwarded = {
            "tensor": tensor,
            "a": a,
            "mode": mode,
            "nonlinearity": nonlinearity,
            "generator": generator,
        }
        return torch.overrides.handle_torch_function(
            kaiming_uniform_, (tensor,), **forwarded
        )

    if 0 in tensor.shape:
        warnings.warn("Initializing zero-element tensors is a no-op", stacklevel=2)
        return tensor

    fan_size = _calculate_correct_fan(tensor, mode)
    target_std = calculate_gain(nonlinearity, a) / math.sqrt(fan_size)
    # A uniform distribution on (-limit, limit) has std = limit / sqrt(3).
    limit = math.sqrt(3.0) * target_std
    with torch.no_grad():
        filled = tensor.uniform_(-limit, limit, generator=generator)
    return filled
def kaiming_normal_(
    tensor: Tensor,
    a: float = 0,
    mode: _FanMode = "fan_in",
    nonlinearity: _NonlinearityType = "leaky_relu",
    generator: torch.Generator | None = None,
) -> Tensor:
    r"""Fill the input `Tensor` with values using a Kaiming normal distribution.

    The method is described in `Delving deep into rectifiers: Surpassing
    human-level performance on ImageNet classification` - He, K. et al. (2015).
    The resulting tensor will have values sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}

    Also known as He initialization.

    Tensors with a zero-sized dimension are returned untouched after a warning.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_normal_(w, mode="fan_out", nonlinearity="relu")

    Note:
        Be aware that ``fan_in`` and ``fan_out`` are calculated assuming
        that the weight matrix is used in a transposed manner,
        (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``).
        This is important for correct initialization.
        If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``,
        pass in a transposed weight matrix, i.e. ``nn.init.kaiming_normal_(w.T, ...)``.
    """
    if 0 in tensor.shape:
        warnings.warn("Initializing zero-element tensors is a no-op", stacklevel=2)
        return tensor

    fan_size = _calculate_correct_fan(tensor, mode)
    target_std = calculate_gain(nonlinearity, a) / math.sqrt(fan_size)
    with torch.no_grad():
        filled = tensor.normal_(0, target_std, generator=generator)
    return filled
- def orthogonal_(
- tensor: Tensor,
- gain: float = 1,
- generator: torch.Generator | None = None,
- ) -> Tensor:
- r"""Fill the input `Tensor` with a (semi) orthogonal matrix.
- Described in `Exact solutions to the nonlinear dynamics of learning in deep
- linear neural networks` - Saxe, A. et al. (2013). The input tensor must have
- at least 2 dimensions, and for tensors with more than 2 dimensions the
- trailing dimensions are flattened.
- Args:
- tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2`
- gain: optional scaling factor
- generator: the torch Generator to sample from (default: None)
- Examples:
- >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
- >>> w = torch.empty(3, 5)
- >>> nn.init.orthogonal_(w)
- """
- if tensor.ndimension() < 2:
- raise ValueError("Only tensors with 2 or more dimensions are supported")
- if tensor.numel() == 0:
- # no-op
- return tensor
- rows = tensor.size(0)
- cols = tensor.numel() // rows
- flattened = tensor.new_empty((rows, cols)).normal_(0, 1, generator=generator)
- if rows < cols:
- flattened.t_()
- # Compute the qr factorization
- q, r = torch.linalg.qr(flattened)
- # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
- d = torch.diag(r, 0)
- ph = d.sign()
- q *= ph
- if rows < cols:
- q.t_()
- with torch.no_grad():
- tensor.view_as(q).copy_(q)
- tensor.mul_(gain)
- return tensor
- def sparse_(
- tensor: Tensor,
- sparsity: float,
- std: float = 0.01,
- generator: torch.Generator | None = None,
- ) -> Tensor:
- r"""Fill the 2D input `Tensor` as a sparse matrix.
- The non-zero elements will be drawn from the normal distribution
- :math:`\mathcal{N}(0, 0.01)`, as described in `Deep learning via
- Hessian-free optimization` - Martens, J. (2010).
- Args:
- tensor: an n-dimensional `torch.Tensor`
- sparsity: The fraction of elements in each column to be set to zero
- std: the standard deviation of the normal distribution used to generate
- the non-zero values
- generator: the torch Generator to sample from (default: None)
- Examples:
- >>> w = torch.empty(3, 5)
- >>> nn.init.sparse_(w, sparsity=0.1)
- """
- if tensor.ndimension() != 2:
- raise ValueError("Only tensors with 2 dimensions are supported")
- rows, cols = tensor.shape
- num_zeros = math.ceil(sparsity * rows)
- with torch.no_grad():
- tensor.normal_(0, std, generator=generator)
- for col_idx in range(cols):
- row_indices = torch.randperm(rows)
- zero_indices = row_indices[:num_zeros]
- tensor[zero_indices, col_idx] = 0
- return tensor
- # for backward compatibility
- def _make_deprecate(meth: Callable[_P, _R]) -> Callable[_P, _R]:
- new_name = meth.__name__
- old_name = new_name[:-1]
- def deprecated_init(*args: _P.args, **kwargs: _P.kwargs) -> _R:
- warnings.warn(
- f"`nn.init.{old_name}` is now deprecated in favor of `nn.init.{new_name}`.",
- FutureWarning,
- stacklevel=2,
- )
- return meth(*args, **kwargs)
- deprecated_init.__doc__ = rf"""
- {old_name}(...)
- .. warning::
- This method is now deprecated in favor of :func:`torch.nn.init.{new_name}`.
- See :func:`~torch.nn.init.{new_name}` for details."""
- deprecated_init.__name__ = old_name
- return deprecated_init
# Legacy underscore-free names. Each is a wrapper around the in-place
# initializer of the same name that emits a FutureWarning when called.
uniform = _make_deprecate(uniform_)
normal = _make_deprecate(normal_)
constant = _make_deprecate(constant_)

eye = _make_deprecate(eye_)
dirac = _make_deprecate(dirac_)

xavier_uniform = _make_deprecate(xavier_uniform_)
xavier_normal = _make_deprecate(xavier_normal_)
kaiming_uniform = _make_deprecate(kaiming_uniform_)
kaiming_normal = _make_deprecate(kaiming_normal_)

orthogonal = _make_deprecate(orthogonal_)
sparse = _make_deprecate(sparse_)
|