| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293 |
"""Randomized algorithms for low-rank approximations of matrices.

Provides :func:`svd_lowrank` (approximate truncated SVD) and
:func:`pca_lowrank` (PCA built on top of it). Both accept dense matrices,
batches of dense matrices, or sparse matrices. They are based on the
randomized range finder of Halko, Martinsson and Tropp (2009),
implemented here by ``get_approximate_basis``.
"""
__all__ = ["svd_lowrank", "pca_lowrank"]
- import torch
- from torch import _linalg_utils as _utils, Tensor
- from torch.overrides import handle_torch_function, has_torch_function
- def get_approximate_basis(
- A: Tensor,
- q: int,
- niter: int | None = 2,
- M: Tensor | None = None,
- ) -> Tensor:
- """Return tensor :math:`Q` with :math:`q` orthonormal columns such
- that :math:`Q Q^H A` approximates :math:`A`. If :math:`M` is
- specified, then :math:`Q` is such that :math:`Q Q^H (A - M)`
- approximates :math:`A - M`. without instantiating any tensors
- of the size of :math:`A` or :math:`M`.
- .. note:: The implementation is based on the Algorithm 4.4 from
- Halko et al., 2009.
- .. note:: For an adequate approximation of a k-rank matrix
- :math:`A`, where k is not known in advance but could be
- estimated, the number of :math:`Q` columns, q, can be
- chosen according to the following criteria: in general,
- :math:`k <= q <= min(2*k, m, n)`. For large low-rank
- matrices, take :math:`q = k + 5..10`. If k is
- relatively small compared to :math:`min(m, n)`, choosing
- :math:`q = k + 0..2` may be sufficient.
- .. note:: To obtain repeatable results, reset the seed for the
- pseudorandom number generator
- Args::
- A (Tensor): the input tensor of size :math:`(*, m, n)`
- q (int): the dimension of subspace spanned by :math:`Q`
- columns.
- niter (int, optional): the number of subspace iterations to
- conduct; ``niter`` must be a
- nonnegative integer. In most cases, the
- default value 2 is more than enough.
- M (Tensor, optional): the input tensor's mean of size
- :math:`(*, m, n)`.
- References::
- - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
- structure with randomness: probabilistic algorithms for
- constructing approximate matrix decompositions,
- arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
- `arXiv <http://arxiv.org/abs/0909.4061>`_).
- """
- niter = 2 if niter is None else niter
- dtype = _utils.get_floating_dtype(A) if not A.is_complex() else A.dtype
- matmul = _utils.matmul
- R = torch.randn(A.shape[-1], q, dtype=dtype, device=A.device)
- # The following code could be made faster using torch.geqrf + torch.ormqr
- # but geqrf is not differentiable
- X = matmul(A, R)
- if M is not None:
- X = X - matmul(M, R)
- Q = torch.linalg.qr(X).Q
- for _ in range(niter):
- X = matmul(A.mH, Q)
- if M is not None:
- X = X - matmul(M.mH, Q)
- Q = torch.linalg.qr(X).Q
- X = matmul(A, Q)
- if M is not None:
- X = X - matmul(M, Q)
- Q = torch.linalg.qr(X).Q
- return Q
def svd_lowrank(
    A: Tensor,
    q: int | None = 6,
    niter: int | None = 2,
    M: Tensor | None = None,
) -> tuple[Tensor, Tensor, Tensor]:
    r"""Return an approximate singular value decomposition ``(U, S, V)``.

    The input :math:`A` may be a matrix, a batch of matrices, or a sparse
    matrix. The result satisfies
    :math:`A \approx U \operatorname{diag}(S) V^{\text{H}}`. When :math:`M`
    is given, the decomposition is computed for :math:`A - M` instead.

    .. note:: The implementation follows Algorithm 5.1 of
              Halko et al., 2009.

    .. note:: For an adequate approximation of a rank-k matrix
              :math:`A`, where k is not known in advance but can be
              estimated, choose the number of columns q of :math:`Q`
              so that, in general, :math:`k <= q <= min(2*k, m, n)`.
              For large low-rank matrices, take :math:`q = k + 5..10`.
              If k is relatively small compared to :math:`min(m, n)`,
              :math:`q = k + 0..2` may be sufficient.

    .. note:: This is a randomized method. Set the seed of the
              pseudorandom number generator to obtain repeatable results.

    .. note:: For dense matrices, prefer the full-rank SVD
              :func:`torch.linalg.svd`, which has roughly 10x higher
              performance. The low-rank SVD is meant for huge sparse
              matrices that :func:`torch.linalg.svd` cannot handle.

    Args:
        A (Tensor): the input tensor of size :math:`(*, m, n)`.
        q (int, optional): a slightly overestimated rank of :math:`A`;
                           defaults to 6.
        niter (int, optional): the number of subspace iterations to
                               perform. It must be a nonnegative integer
                               and defaults to 2.
        M (Tensor, optional): the input tensor's mean, of size
                              :math:`(*, m, n)`. It is broadcast to the
                              size of :math:`A`.

    Returns:
        tuple: ``(U, S, V)`` approximating the SVD of :math:`A` (or of
        :math:`A - M`).

    References::
        - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
          structure with randomness: probabilistic algorithms for
          constructing approximate matrix decompositions,
          arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
          `arXiv <https://arxiv.org/abs/0909.4061>`_).
    """
    if not torch.jit.is_scripting():
        operands = (A, M)
        # Give __torch_function__ overrides a chance to intercept the call
        # when either operand is not a plain Tensor or None.
        has_foreign_type = any(
            type(operand) not in (torch.Tensor, type(None)) for operand in operands
        )
        if has_foreign_type and has_torch_function(operands):
            return handle_torch_function(
                svd_lowrank, operands, A, q=q, niter=niter, M=M
            )
    return _svd_lowrank(A, q=q, niter=niter, M=M)
def _svd_lowrank(
    A: Tensor,
    q: int | None = 6,
    niter: int | None = 2,
    M: Tensor | None = None,
) -> tuple[Tensor, Tensor, Tensor]:
    """Compute the randomized low-rank SVD behind :func:`svd_lowrank`.

    This follows Algorithm 5.1 of Halko et al., 2009 and performs no
    ``__torch_function__`` dispatch. ``q=None`` means 6. When ``M`` is given,
    it is broadcast to ``A.size()`` and the SVD of ``A - M`` is returned.
    Wide inputs (``m < n``) are handled by working on the conjugate
    transpose and swapping ``U`` and ``V`` at the end.
    """
    rank = 6 if q is None else q
    m, n = A.shape[-2:]
    wide = m < n

    if M is not None:
        # Broadcast before any transposition so that M lines up with A.
        M = M.broadcast_to(A.size())
    if wide:
        # The algorithm below assumes a tall matrix, so factor A^H instead.
        A = A.mH
        if M is not None:
            M = M.mH

    Q = get_approximate_basis(A, rank, niter=niter, M=M)
    Qh = Q.mH
    # Project onto the basis: B = Q^H (A - M), a small (q x n) problem.
    B = _utils.matmul(Qh, A)
    if M is not None:
        B = B - _utils.matmul(Qh, M)

    U_small, S, Vh = torch.linalg.svd(B, full_matrices=False)
    U = Q.matmul(U_small)
    V = Vh.mH
    # Undo the transposition for wide inputs: the SVD of A^H has U and V swapped.
    return (V, S, U) if wide else (U, S, V)
def pca_lowrank(
    A: Tensor,
    q: int | None = None,
    center: bool = True,
    niter: int = 2,
) -> tuple[Tensor, Tensor, Tensor]:
    r"""Performs linear Principal Component Analysis (PCA) on a low-rank
    matrix, batches of such matrices, or sparse matrix.

    This function returns a tuple ``(U, S, V)`` which is the nearly
    optimal approximation of a singular value decomposition of a centered
    matrix :math:`A` such that
    :math:`A \approx U \operatorname{diag}(S) V^{\text{H}}`.

    .. note:: The relation of ``(U, S, V)`` to PCA is as follows:

                - :math:`A` is a data matrix with ``m`` samples and
                  ``n`` features
                - the :math:`V` columns represent the principal directions
                - :math:`S ** 2 / (m - 1)` contains the eigenvalues of
                  :math:`A^T A / (m - 1)` which is the covariance of
                  ``A`` when ``center=True`` is provided.
                - ``matmul(A, V[:, :k])`` projects data to the first k
                  principal components

    .. note:: Unlike the standard SVD, the sizes of the returned
              matrices depend on the specified rank and q
              values as follows:

                - :math:`U` is m x q matrix
                - :math:`S` is q-vector
                - :math:`V` is n x q matrix

    .. note:: To obtain repeatable results, reset the seed for the
              pseudorandom number generator.

    Args:
        A (Tensor): the input tensor of size :math:`(*, m, n)`. Sparse
                    input with ``center=True`` must be 2-dimensional.
        q (int, optional): a slightly overestimated rank of
                           :math:`A`. By default, ``q = min(6, m, n)``.
        center (bool, optional): if True, center the input tensor,
                                 otherwise, assume that the input is
                                 centered.
        niter (int, optional): the number of subspace iterations to
                               conduct; niter must be a nonnegative
                               integer, and defaults to 2.

    Returns:
        tuple: ``(U, S, V)`` as described above.

    Raises:
        ValueError: if ``q`` is negative or greater than ``min(m, n)``,
            if ``niter`` is negative, or if ``A`` is sparse, ``center``
            is True and ``A`` is not 2-dimensional.

    References::

        - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
          structure with randomness: probabilistic algorithms for
          constructing approximate matrix decompositions,
          arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
          `arXiv <http://arxiv.org/abs/0909.4061>`_).
    """
    if not torch.jit.is_scripting():
        if type(A) is not torch.Tensor and has_torch_function((A,)):
            return handle_torch_function(
                pca_lowrank, (A,), A, q=q, center=center, niter=niter
            )

    (m, n) = A.shape[-2:]

    if q is None:
        q = min(6, m, n)
    elif not (q >= 0 and q <= min(m, n)):
        raise ValueError(
            f"q(={q}) must be non-negative integer and not greater than min(m, n)={min(m, n)}"
        )
    if not (niter >= 0):
        raise ValueError(f"niter(={niter}) must be non-negative integer")

    if not center:
        return _svd_lowrank(A, q, niter=niter, M=None)

    if _utils.is_sparse(A):
        if len(A.shape) != 2:
            raise ValueError("pca_lowrank input is expected to be 2-dimensional tensor")
        # Use the same dtype rule as get_approximate_basis: complex inputs keep
        # their dtype. Otherwise the mean matrix M built below would not share
        # A's complex dtype, and the later matmul(M, R) against a complex R
        # would fail.
        dtype = A.dtype if A.is_complex() else _utils.get_floating_dtype(A)
        # Column means, kept as a 1-D sparse tensor of length n.
        c = torch.sparse.sum(A, dim=(-2,)) / m
        # Re-index c as a sparse (n, 1) column: row index = column of A,
        # second index = 0.
        column_indices = c.indices()[0]
        indices = torch.zeros(
            2,
            len(column_indices),
            dtype=column_indices.dtype,
            device=column_indices.device,
        )
        indices[0] = column_indices
        C_t = torch.sparse_coo_tensor(
            indices, c.values(), (n, 1), dtype=dtype, device=A.device
        )
        # (n, 1) @ (1, m) -> (n, m); transposing gives the (m, n) matrix M whose
        # rows all equal the column means.
        # NOTE(review): this product is materialized as a full m x n tensor.
        ones_m1_t = torch.ones(A.shape[:-2] + (1, m), dtype=dtype, device=A.device)
        M = torch.sparse.mm(C_t, ones_m1_t).mT
        return _svd_lowrank(A, q, niter=niter, M=M)
    else:
        C = A.mean(dim=(-2,), keepdim=True)
        return _svd_lowrank(A - C, q, niter=niter, M=None)
|