linalg.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. # mypy: ignore-errors
  2. from __future__ import annotations
  3. import functools
  4. import math
  5. from typing import TYPE_CHECKING
  6. import torch
  7. from . import _dtypes_impl, _util
  8. from ._normalizations import ArrayLike, KeepDims, normalizer
  9. if TYPE_CHECKING:
  10. from collections.abc import Sequence
  11. class LinAlgError(Exception):
  12. pass
  13. def _atleast_float_1(a):
  14. if not (a.dtype.is_floating_point or a.dtype.is_complex):
  15. a = a.to(_dtypes_impl.default_dtypes().float_dtype)
  16. return a
  17. def _atleast_float_2(a, b):
  18. dtyp = _dtypes_impl.result_type_impl(a, b)
  19. if not (dtyp.is_floating_point or dtyp.is_complex):
  20. dtyp = _dtypes_impl.default_dtypes().float_dtype
  21. a = _util.cast_if_needed(a, dtyp)
  22. b = _util.cast_if_needed(b, dtyp)
  23. return a, b
  24. def linalg_errors(func):
  25. @functools.wraps(func)
  26. def wrapped(*args, **kwds):
  27. try:
  28. return func(*args, **kwds)
  29. except torch._C._LinAlgError as e:
  30. raise LinAlgError(*e.args) # noqa: B904
  31. return wrapped
# ### Matrix and vector products ###
  33. @normalizer
  34. @linalg_errors
  35. def matrix_power(a: ArrayLike, n):
  36. a = _atleast_float_1(a)
  37. return torch.linalg.matrix_power(a, n)
  38. @normalizer
  39. @linalg_errors
  40. def multi_dot(inputs: Sequence[ArrayLike], *, out=None):
  41. return torch.linalg.multi_dot(inputs)
# ### Solving equations and inverting matrices ###
  43. @normalizer
  44. @linalg_errors
  45. def solve(a: ArrayLike, b: ArrayLike):
  46. a, b = _atleast_float_2(a, b)
  47. return torch.linalg.solve(a, b)
  48. @normalizer
  49. @linalg_errors
  50. def lstsq(a: ArrayLike, b: ArrayLike, rcond=None):
  51. a, b = _atleast_float_2(a, b)
  52. # NumPy is using gelsd: https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/umath_linalg.cpp#L3991
  53. # on CUDA, only `gels` is available though, so use it instead
  54. driver = "gels" if a.is_cuda or b.is_cuda else "gelsd"
  55. return torch.linalg.lstsq(a, b, rcond=rcond, driver=driver)
  56. @normalizer
  57. @linalg_errors
  58. def inv(a: ArrayLike):
  59. a = _atleast_float_1(a)
  60. result = torch.linalg.inv(a)
  61. return result
  62. @normalizer
  63. @linalg_errors
  64. def pinv(a: ArrayLike, rcond=1e-15, hermitian=False):
  65. a = _atleast_float_1(a)
  66. return torch.linalg.pinv(a, rtol=rcond, hermitian=hermitian)
  67. @normalizer
  68. @linalg_errors
  69. def tensorsolve(a: ArrayLike, b: ArrayLike, axes=None):
  70. a, b = _atleast_float_2(a, b)
  71. return torch.linalg.tensorsolve(a, b, dims=axes)
  72. @normalizer
  73. @linalg_errors
  74. def tensorinv(a: ArrayLike, ind=2):
  75. a = _atleast_float_1(a)
  76. return torch.linalg.tensorinv(a, ind=ind)
# ### Norms and other numbers ###
  78. @normalizer
  79. @linalg_errors
  80. def det(a: ArrayLike):
  81. a = _atleast_float_1(a)
  82. return torch.linalg.det(a)
  83. @normalizer
  84. @linalg_errors
  85. def slogdet(a: ArrayLike):
  86. a = _atleast_float_1(a)
  87. return torch.linalg.slogdet(a)
  88. @normalizer
  89. @linalg_errors
  90. def cond(x: ArrayLike, p=None):
  91. x = _atleast_float_1(x)
  92. # check if empty
  93. # cf: https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1744
  94. if x.numel() == 0 and math.prod(x.shape[-2:]) == 0:
  95. raise LinAlgError("cond is not defined on empty arrays")
  96. result = torch.linalg.cond(x, p=p)
  97. # Convert nans to infs (numpy does it in a data-dependent way, depending on
  98. # whether the input array has nans or not)
  99. # XXX: NumPy does this: https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1744
  100. return torch.where(torch.isnan(result), float("inf"), result)
  101. @normalizer
  102. @linalg_errors
  103. def matrix_rank(a: ArrayLike, tol=None, hermitian=False):
  104. a = _atleast_float_1(a)
  105. if a.ndim < 2:
  106. return int((a != 0).any())
  107. if tol is None:
  108. # follow https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1885
  109. atol = 0
  110. rtol = max(a.shape[-2:]) * torch.finfo(a.dtype).eps
  111. else:
  112. atol, rtol = tol, 0
  113. return torch.linalg.matrix_rank(a, atol=atol, rtol=rtol, hermitian=hermitian)
  114. @normalizer
  115. @linalg_errors
  116. def norm(x: ArrayLike, ord=None, axis=None, keepdims: KeepDims = False):
  117. x = _atleast_float_1(x)
  118. return torch.linalg.norm(x, ord=ord, dim=axis)
# ### Decompositions ###
  120. @normalizer
  121. @linalg_errors
  122. def cholesky(a: ArrayLike):
  123. a = _atleast_float_1(a)
  124. return torch.linalg.cholesky(a)
  125. @normalizer
  126. @linalg_errors
  127. def qr(a: ArrayLike, mode="reduced"):
  128. a = _atleast_float_1(a)
  129. result = torch.linalg.qr(a, mode=mode)
  130. if mode == "r":
  131. # match NumPy
  132. result = result.R
  133. return result
  134. @normalizer
  135. @linalg_errors
  136. def svd(a: ArrayLike, full_matrices=True, compute_uv=True, hermitian=False):
  137. a = _atleast_float_1(a)
  138. if not compute_uv:
  139. return torch.linalg.svdvals(a)
  140. # NB: ignore the hermitian= argument (no pytorch equivalent)
  141. result = torch.linalg.svd(a, full_matrices=full_matrices)
  142. return result
# ### Eigenvalues and eigenvectors ###
  144. @normalizer
  145. @linalg_errors
  146. def eig(a: ArrayLike):
  147. a = _atleast_float_1(a)
  148. w, vt = torch.linalg.eig(a)
  149. if not a.is_complex() and w.is_complex() and (w.imag == 0).all():
  150. w = w.real
  151. vt = vt.real
  152. return w, vt
  153. @normalizer
  154. @linalg_errors
  155. def eigh(a: ArrayLike, UPLO="L"):
  156. a = _atleast_float_1(a)
  157. return torch.linalg.eigh(a, UPLO=UPLO)
  158. @normalizer
  159. @linalg_errors
  160. def eigvals(a: ArrayLike):
  161. a = _atleast_float_1(a)
  162. result = torch.linalg.eigvals(a)
  163. if not a.is_complex() and result.is_complex() and (result.imag == 0).all():
  164. result = result.real
  165. return result
  166. @normalizer
  167. @linalg_errors
  168. def eigvalsh(a: ArrayLike, UPLO="L"):
  169. a = _atleast_float_1(a)
  170. return torch.linalg.eigvalsh(a, UPLO=UPLO)