_polyutils.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. """Partial replacements for numpy polynomial routines, with Array API compatibility.
This module contains both "old-style" routines, which mirror ``np.poly1d`` and
related functions from the main numpy namespace, and "new-style" routines, which
mirror ``np.polynomial.polynomial``. To distinguish the two sets, the names of the
"new-style" routines start with ``npp_``.
  5. """
  6. import warnings
  7. import scipy._lib.array_api_extra as xpx
  8. from scipy._lib._array_api import (
  9. xp_promote, xp_default_dtype, xp_size, xp_device, is_numpy
  10. )
  11. try:
  12. from numpy.exceptions import RankWarning
  13. except ImportError:
  14. # numpy 1.x
  15. from numpy import RankWarning
  16. def _sort_cmplx(arr, xp):
  17. # xp.sort is undefined for complex dtypes. Here we only need some
  18. # consistent way to sort a complex array, including equal magnitude elements.
  19. arr = xp.asarray(arr)
  20. if xp.isdtype(arr.dtype, 'complex floating'):
  21. sorter = abs(arr) + xp.real(arr) + xp.imag(arr)**3
  22. else:
  23. sorter = arr
  24. idxs = xp.argsort(sorter)
  25. return arr[idxs]
  26. def polyroots(coef, *, xp):
  27. """numpy.roots, best-effor replacement
  28. """
  29. if coef.shape[0] < 2:
  30. return xp.asarray([], dtype=coef.dtype)
  31. root_func = getattr(xp, 'roots', None)
  32. if root_func:
  33. # NB: cupy.roots is broken in CuPy 13.x, but CuPy is handled via delegation
  34. # so we never hit this code path with xp being cupy
  35. return root_func(coef)
  36. # companion matrix
  37. n = coef.shape[0]
  38. a = xp.eye(n - 1, n - 1, k=-1, dtype=coef.dtype)
  39. a[:, -1] = -xp.flip(coef[1:]) / coef[0]
  40. # non-symmetric eigenvalue problem is not in the spec but is available on e.g. torch
  41. if hasattr(xp.linalg, 'eigvals'):
  42. return xp.linalg.eigvals(a)
  43. else:
  44. import numpy as np
  45. return xp.asarray(np.linalg.eigvals(np.asarray(a)))
  46. # https://github.com/numpy/numpy/blob/v2.1.0/numpy/lib/_function_base_impl.py#L1874-L1925
  47. def _trim_zeros(filt, trim='fb'):
  48. first = 0
  49. trim = trim.upper()
  50. if 'F' in trim:
  51. for i in filt:
  52. if i != 0.:
  53. break
  54. else:
  55. first = first + 1
  56. last = filt.shape[0]
  57. if 'B' in trim:
  58. for i in filt[::-1]:
  59. if i != 0.:
  60. break
  61. else:
  62. last = last - 1
  63. return filt[first:last]
  64. # For numpy arrays, use scipy.linalg.lstsq;
  65. # For other backends,
  66. # - use xp.linalg.lstsq, if available (cupy, torch, jax.numpy);
  67. # - otherwise manually compute pseudoinverse via SVD factorization
  68. def _lstsq(a, b, xp=None, rcond=None):
  69. a, b = xp_promote(a, b, force_floating=True, xp=xp)
  70. if rcond is None:
  71. rcond = xp.finfo(a.dtype).eps * max(a.shape[-1], a.shape[-2])
  72. if is_numpy(xp):
  73. from scipy.linalg import lstsq as s_lstsq
  74. return s_lstsq(a, b, cond=rcond)
  75. elif lstsq_func := getattr(xp.linalg, "lstsq", None):
  76. # cupy, torch, jax.numpy all have xp.linalg.lstsq
  77. return lstsq_func(a, b, rcond=rcond)
  78. else:
  79. # unknown array library: LSQ solve via pseudoinverse
  80. u, s, vt = xp.linalg.svd(a, full_matrices=False)
  81. sing_val_mask = s > rcond
  82. s = xpx.apply_where(sing_val_mask, (s,), lambda x: 1. / x, fill_value=0.)
  83. sigma = xp.eye(s.shape[0]) * s # == np.diag(s)
  84. x = vt.T @ sigma @ u.T @ b
  85. rank = xp.count_nonzero(sing_val_mask)
  86. # XXX actually compute residuals, when there's a use case
  87. residuals = xp.asarray([])
  88. return x, residuals, rank, s
  89. # ### Old-style routines ###
  90. # https://github.com/numpy/numpy/blob/v2.2.0/numpy/lib/_polynomial_impl.py#L1232
  91. def _poly1d(c_or_r, *, xp):
  92. """ Constructor of np.poly1d object from an array of coefficients (r=False)
  93. """
  94. c_or_r = xpx.atleast_nd(c_or_r, ndim=1, xp=xp)
  95. if c_or_r.ndim > 1:
  96. raise ValueError("Polynomial must be 1d only.")
  97. c_or_r = _trim_zeros(c_or_r, trim='f')
  98. if c_or_r.shape[0] == 0:
  99. c_or_r = xp.asarray([0], dtype=c_or_r.dtype)
  100. return c_or_r
  101. # https://github.com/numpy/numpy/blob/v2.2.0/numpy/lib/_polynomial_impl.py#L702-L779
  102. def polyval(p, x, *, xp):
  103. """ Old-style polynomial, `np.polyval`
  104. """
  105. p = xp.asarray(p)
  106. x = xp.asarray(x)
  107. y = xp.zeros_like(x)
  108. # NB: cannot do `for pv in p` since array API iteration
  109. # is only defined for 1D arrays.
  110. for j in range(p.shape[0]):
  111. y = y * x + p[j, ...]
  112. return y
  113. # https://github.com/numpy/numpy/blob/v2.2.0/numpy/lib/_polynomial_impl.py#L34-L157
  114. def poly(seq_of_zeros, *, xp):
  115. # Only reproduce the 1D variant of np.poly
  116. seq_of_zeros = xp.asarray(seq_of_zeros)
  117. seq_of_zeros = xpx.atleast_nd(seq_of_zeros, ndim=1, xp=xp)
  118. if seq_of_zeros.shape[0] == 0:
  119. return xp.asarray(1.0, dtype=xp.real(seq_of_zeros).dtype)
  120. # prefer np.convolve etc, if available
  121. convolve_func = getattr(xp, 'convolve', None)
  122. if convolve_func is None:
  123. from scipy.signal import convolve as convolve_func
  124. dt = seq_of_zeros.dtype
  125. a = xp.ones((1,), dtype=dt)
  126. one = xp.ones_like(seq_of_zeros[0])
  127. for zero in seq_of_zeros:
  128. a = convolve_func(a, xp.stack((one, -zero)), mode='full')
  129. if xp.isdtype(a.dtype, 'complex floating'):
  130. # if complex roots are all complex conjugates, the roots are real.
  131. roots = xp.asarray(seq_of_zeros, dtype=xp.complex128)
  132. if xp.all(xp.sort(xp.imag(roots)) == xp.sort(xp.imag(xp.conj(roots)))):
  133. a = xp.asarray(xp.real(a), copy=True)
  134. return a
  135. # https://github.com/numpy/numpy/blob/v2.2.0/numpy/lib/_polynomial_impl.py#L912
  136. def polymul(a1, a2, *, xp):
  137. a1, a2 = _poly1d(a1, xp=xp), _poly1d(a2, xp=xp)
  138. # prefer np.convolve etc, if available
  139. convolve_func = getattr(xp, 'convolve', None)
  140. if convolve_func is None:
  141. from scipy.signal import convolve as convolve_func
  142. val = convolve_func(a1, a2)
  143. return val
  144. # https://github.com/numpy/numpy/blob/v2.3.3/numpy/lib/_polynomial_impl.py#L459
  145. def polyfit(x, y, deg, *, xp, rcond=None):
  146. # only reproduce the variant with full=False, w=None, cov=False
  147. order = int(deg) + 1
  148. x = xp.asarray(x)
  149. y = xp.asarray(y)
  150. x, y = xp_promote(x, y, force_floating=True, xp=xp)
  151. # check arguments.
  152. if deg < 0:
  153. raise ValueError("expected deg >= 0")
  154. if x.ndim != 1:
  155. raise TypeError("expected 1D vector for x")
  156. if xp_size(x) == 0:
  157. raise TypeError("expected non-empty vector for x")
  158. if y.ndim < 1 or y.ndim > 2:
  159. raise TypeError("expected 1D or 2D array for y")
  160. if x.shape[0] != y.shape[0]:
  161. raise TypeError("expected x and y to have same length")
  162. # set rcond
  163. if rcond is None:
  164. rcond = x.shape[0] * xp.finfo(x.dtype).eps
  165. # set up least squares equation for powers of x: lhs = vander(x, order)
  166. powers = xp.flip(xp.arange(order, dtype=x.dtype, device=xp_device(x)))
  167. lhs = x[:, None] ** powers[None, :]
  168. # scale lhs to improve condition number and solve
  169. scale = xp.sqrt(xp.sum(lhs * lhs, axis=0))
  170. lhs /= scale
  171. c, _, rank, _ = _lstsq(lhs, y, rcond=rcond, xp=xp)
  172. c = (c.T / scale).T # broadcast scale coefficients
  173. # warn on rank reduction, which indicates an ill conditioned matrix
  174. if rank != order:
  175. msg = "Polyfit may be poorly conditioned"
  176. warnings.warn(msg, RankWarning, stacklevel=2)
  177. return c
  178. # ### New-style routines ###
  179. # https://github.com/numpy/numpy/blob/v2.2.0/numpy/polynomial/polynomial.py#L663
  180. def npp_polyval(x, c, *, xp, tensor=True):
  181. if xp.isdtype(c.dtype, 'integral'):
  182. c = xp.astype(c, xp_default_dtype(xp))
  183. c = xpx.atleast_nd(c, ndim=1, xp=xp)
  184. if isinstance(x, tuple | list):
  185. x = xp.asarray(x)
  186. if tensor:
  187. c = xp.reshape(c, (c.shape + (1,)*x.ndim))
  188. c0, _ = xp_promote(c[-1, ...], x, broadcast=True, xp=xp)
  189. for i in range(2, c.shape[0] + 1):
  190. c0 = c[-i, ...] + c0*x
  191. return c0
  192. # https://github.com/numpy/numpy/blob/v2.2.0/numpy/polynomial/polynomial.py#L758-L842
  193. def npp_polyvalfromroots(x, r, *, xp, tensor=True):
  194. r = xpx.atleast_nd(r, ndim=1, xp=xp)
  195. # if r.dtype.char in '?bBhHiIlLqQpP':
  196. # r = r.astype(np.double)
  197. if isinstance(x, tuple | list):
  198. x = xp.asarray(x)
  199. if tensor:
  200. r = xp.reshape(r, r.shape + (1,) * x.ndim)
  201. elif x.ndim >= r.ndim:
  202. raise ValueError("x.ndim must be < r.ndim when tensor == False")
  203. return xp.prod(x - r, axis=0)