# _lowrank.py (10 KB)
  1. """Implement various linear algebra algorithms for low rank matrices."""
  2. __all__ = ["svd_lowrank", "pca_lowrank"]
  3. import torch
  4. from torch import _linalg_utils as _utils, Tensor
  5. from torch.overrides import handle_torch_function, has_torch_function
  6. def get_approximate_basis(
  7. A: Tensor,
  8. q: int,
  9. niter: int | None = 2,
  10. M: Tensor | None = None,
  11. ) -> Tensor:
  12. """Return tensor :math:`Q` with :math:`q` orthonormal columns such
  13. that :math:`Q Q^H A` approximates :math:`A`. If :math:`M` is
  14. specified, then :math:`Q` is such that :math:`Q Q^H (A - M)`
  15. approximates :math:`A - M`. without instantiating any tensors
  16. of the size of :math:`A` or :math:`M`.
  17. .. note:: The implementation is based on the Algorithm 4.4 from
  18. Halko et al., 2009.
  19. .. note:: For an adequate approximation of a k-rank matrix
  20. :math:`A`, where k is not known in advance but could be
  21. estimated, the number of :math:`Q` columns, q, can be
  22. chosen according to the following criteria: in general,
  23. :math:`k <= q <= min(2*k, m, n)`. For large low-rank
  24. matrices, take :math:`q = k + 5..10`. If k is
  25. relatively small compared to :math:`min(m, n)`, choosing
  26. :math:`q = k + 0..2` may be sufficient.
  27. .. note:: To obtain repeatable results, reset the seed for the
  28. pseudorandom number generator
  29. Args::
  30. A (Tensor): the input tensor of size :math:`(*, m, n)`
  31. q (int): the dimension of subspace spanned by :math:`Q`
  32. columns.
  33. niter (int, optional): the number of subspace iterations to
  34. conduct; ``niter`` must be a
  35. nonnegative integer. In most cases, the
  36. default value 2 is more than enough.
  37. M (Tensor, optional): the input tensor's mean of size
  38. :math:`(*, m, n)`.
  39. References::
  40. - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
  41. structure with randomness: probabilistic algorithms for
  42. constructing approximate matrix decompositions,
  43. arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
  44. `arXiv <http://arxiv.org/abs/0909.4061>`_).
  45. """
  46. niter = 2 if niter is None else niter
  47. dtype = _utils.get_floating_dtype(A) if not A.is_complex() else A.dtype
  48. matmul = _utils.matmul
  49. R = torch.randn(A.shape[-1], q, dtype=dtype, device=A.device)
  50. # The following code could be made faster using torch.geqrf + torch.ormqr
  51. # but geqrf is not differentiable
  52. X = matmul(A, R)
  53. if M is not None:
  54. X = X - matmul(M, R)
  55. Q = torch.linalg.qr(X).Q
  56. for _ in range(niter):
  57. X = matmul(A.mH, Q)
  58. if M is not None:
  59. X = X - matmul(M.mH, Q)
  60. Q = torch.linalg.qr(X).Q
  61. X = matmul(A, Q)
  62. if M is not None:
  63. X = X - matmul(M, Q)
  64. Q = torch.linalg.qr(X).Q
  65. return Q
  66. def svd_lowrank(
  67. A: Tensor,
  68. q: int | None = 6,
  69. niter: int | None = 2,
  70. M: Tensor | None = None,
  71. ) -> tuple[Tensor, Tensor, Tensor]:
  72. r"""Return the singular value decomposition ``(U, S, V)`` of a matrix,
  73. batches of matrices, or a sparse matrix :math:`A` such that
  74. :math:`A \approx U \operatorname{diag}(S) V^{\text{H}}`. In case :math:`M` is given, then
  75. SVD is computed for the matrix :math:`A - M`.
  76. .. note:: The implementation is based on the Algorithm 5.1 from
  77. Halko et al., 2009.
  78. .. note:: For an adequate approximation of a k-rank matrix
  79. :math:`A`, where k is not known in advance but could be
  80. estimated, the number of :math:`Q` columns, q, can be
  81. chosen according to the following criteria: in general,
  82. :math:`k <= q <= min(2*k, m, n)`. For large low-rank
  83. matrices, take :math:`q = k + 5..10`. If k is
  84. relatively small compared to :math:`min(m, n)`, choosing
  85. :math:`q = k + 0..2` may be sufficient.
  86. .. note:: This is a randomized method. To obtain repeatable results,
  87. set the seed for the pseudorandom number generator
  88. .. note:: In general, use the full-rank SVD implementation
  89. :func:`torch.linalg.svd` for dense matrices due to its 10x
  90. higher performance characteristics. The low-rank SVD
  91. will be useful for huge sparse matrices that
  92. :func:`torch.linalg.svd` cannot handle.
  93. Args::
  94. A (Tensor): the input tensor of size :math:`(*, m, n)`
  95. q (int, optional): a slightly overestimated rank of A.
  96. niter (int, optional): the number of subspace iterations to
  97. conduct; niter must be a nonnegative
  98. integer, and defaults to 2
  99. M (Tensor, optional): the input tensor's mean of size
  100. :math:`(*, m, n)`, which will be broadcasted
  101. to the size of A in this function.
  102. References::
  103. - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
  104. structure with randomness: probabilistic algorithms for
  105. constructing approximate matrix decompositions,
  106. arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
  107. `arXiv <https://arxiv.org/abs/0909.4061>`_).
  108. """
  109. if not torch.jit.is_scripting():
  110. tensor_ops = (A, M)
  111. if not set(map(type, tensor_ops)).issubset(
  112. (torch.Tensor, type(None))
  113. ) and has_torch_function(tensor_ops):
  114. return handle_torch_function(
  115. svd_lowrank, tensor_ops, A, q=q, niter=niter, M=M
  116. )
  117. return _svd_lowrank(A, q=q, niter=niter, M=M)
  118. def _svd_lowrank(
  119. A: Tensor,
  120. q: int | None = 6,
  121. niter: int | None = 2,
  122. M: Tensor | None = None,
  123. ) -> tuple[Tensor, Tensor, Tensor]:
  124. # Algorithm 5.1 in Halko et al., 2009
  125. q = 6 if q is None else q
  126. m, n = A.shape[-2:]
  127. matmul = _utils.matmul
  128. if M is not None:
  129. M = M.broadcast_to(A.size())
  130. # Assume that A is tall
  131. if m < n:
  132. A = A.mH
  133. if M is not None:
  134. M = M.mH
  135. Q = get_approximate_basis(A, q, niter=niter, M=M)
  136. B = matmul(Q.mH, A)
  137. if M is not None:
  138. B = B - matmul(Q.mH, M)
  139. U, S, Vh = torch.linalg.svd(B, full_matrices=False)
  140. V = Vh.mH
  141. U = Q.matmul(U)
  142. if m < n:
  143. U, V = V, U
  144. return U, S, V
  145. def pca_lowrank(
  146. A: Tensor,
  147. q: int | None = None,
  148. center: bool = True,
  149. niter: int = 2,
  150. ) -> tuple[Tensor, Tensor, Tensor]:
  151. r"""Performs linear Principal Component Analysis (PCA) on a low-rank
  152. matrix, batches of such matrices, or sparse matrix.
  153. This function returns a namedtuple ``(U, S, V)`` which is the
  154. nearly optimal approximation of a singular value decomposition of
  155. a centered matrix :math:`A` such that :math:`A \approx U \operatorname{diag}(S) V^{\text{H}}`
  156. .. note:: The relation of ``(U, S, V)`` to PCA is as follows:
  157. - :math:`A` is a data matrix with ``m`` samples and
  158. ``n`` features
  159. - the :math:`V` columns represent the principal directions
  160. - :math:`S ** 2 / (m - 1)` contains the eigenvalues of
  161. :math:`A^T A / (m - 1)` which is the covariance of
  162. ``A`` when ``center=True`` is provided.
  163. - ``matmul(A, V[:, :k])`` projects data to the first k
  164. principal components
  165. .. note:: Different from the standard SVD, the size of returned
  166. matrices depend on the specified rank and q
  167. values as follows:
  168. - :math:`U` is m x q matrix
  169. - :math:`S` is q-vector
  170. - :math:`V` is n x q matrix
  171. .. note:: To obtain repeatable results, reset the seed for the
  172. pseudorandom number generator
  173. Args:
  174. A (Tensor): the input tensor of size :math:`(*, m, n)`
  175. q (int, optional): a slightly overestimated rank of
  176. :math:`A`. By default, ``q = min(6, m,
  177. n)``.
  178. center (bool, optional): if True, center the input tensor,
  179. otherwise, assume that the input is
  180. centered.
  181. niter (int, optional): the number of subspace iterations to
  182. conduct; niter must be a nonnegative
  183. integer, and defaults to 2.
  184. References::
  185. - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
  186. structure with randomness: probabilistic algorithms for
  187. constructing approximate matrix decompositions,
  188. arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
  189. `arXiv <http://arxiv.org/abs/0909.4061>`_).
  190. """
  191. if not torch.jit.is_scripting():
  192. if type(A) is not torch.Tensor and has_torch_function((A,)):
  193. return handle_torch_function(
  194. pca_lowrank, (A,), A, q=q, center=center, niter=niter
  195. )
  196. (m, n) = A.shape[-2:]
  197. if q is None:
  198. q = min(6, m, n)
  199. elif not (q >= 0 and q <= min(m, n)):
  200. raise ValueError(
  201. f"q(={q}) must be non-negative integer and not greater than min(m, n)={min(m, n)}"
  202. )
  203. if not (niter >= 0):
  204. raise ValueError(f"niter(={niter}) must be non-negative integer")
  205. dtype = _utils.get_floating_dtype(A)
  206. if not center:
  207. return _svd_lowrank(A, q, niter=niter, M=None)
  208. if _utils.is_sparse(A):
  209. if len(A.shape) != 2:
  210. raise ValueError("pca_lowrank input is expected to be 2-dimensional tensor")
  211. c = torch.sparse.sum(A, dim=(-2,)) / m
  212. # reshape c
  213. column_indices = c.indices()[0]
  214. indices = torch.zeros(
  215. 2,
  216. len(column_indices),
  217. dtype=column_indices.dtype,
  218. device=column_indices.device,
  219. )
  220. indices[0] = column_indices
  221. C_t = torch.sparse_coo_tensor(
  222. indices, c.values(), (n, 1), dtype=dtype, device=A.device
  223. )
  224. ones_m1_t = torch.ones(A.shape[:-2] + (1, m), dtype=dtype, device=A.device)
  225. M = torch.sparse.mm(C_t, ones_m1_t).mT
  226. return _svd_lowrank(A, q, niter=niter, M=M)
  227. else:
  228. C = A.mean(dim=(-2,), keepdim=True)
  229. return _svd_lowrank(A - C, q, niter=niter, M=None)