_ndbspline.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518
  1. import itertools
  2. import functools
  3. import operator
  4. import numpy as np
  5. from math import prod
  6. from types import GenericAlias
  7. from . import _dierckx # type: ignore[attr-defined]
  8. import scipy.sparse.linalg as ssl
  9. from scipy.sparse import csr_array
  10. from scipy._lib._array_api import array_namespace, xp_capabilities
  11. from ._bsplines import _not_a_knot, BSpline
  12. __all__ = ["NdBSpline"]
  13. def _get_dtype(dtype):
  14. """Return np.complex128 for complex dtypes, np.float64 otherwise."""
  15. if np.issubdtype(dtype, np.complexfloating):
  16. return np.complex128
  17. else:
  18. return np.float64
  19. @xp_capabilities(
  20. cpu_only=True, jax_jit=False,
  21. skip_backends=[
  22. ("dask.array",
  23. "https://github.com/data-apis/array-api-extra/issues/488")
  24. ]
  25. )
  26. class NdBSpline:
  27. """Tensor product spline object.
  28. The value at point ``xp = (x1, x2, ..., xN)`` is evaluated as a linear
  29. combination of products of one-dimensional b-splines in each of the ``N``
  30. dimensions::
  31. c[i1, i2, ..., iN] * B(x1; i1, t1) * B(x2; i2, t2) * ... * B(xN; iN, tN)
  32. Here ``B(x; i, t)`` is the ``i``-th b-spline defined by the knot vector
  33. ``t`` evaluated at ``x``.
  34. Parameters
  35. ----------
  36. t : tuple of 1D ndarrays
  37. knot vectors in directions 1, 2, ... N,
  38. ``len(t[i]) == n[i] + k + 1``
  39. c : ndarray, shape (n1, n2, ..., nN, ...)
  40. b-spline coefficients
  41. k : int or length-d tuple of integers
  42. spline degrees.
  43. A single integer is interpreted as having this degree for
  44. all dimensions.
  45. extrapolate : bool, optional
  46. Whether to extrapolate out-of-bounds inputs, or return `nan`.
  47. Default is to extrapolate.
  48. Attributes
  49. ----------
  50. t : tuple of ndarrays
  51. Knots vectors.
  52. c : ndarray
  53. Coefficients of the tensor-product spline.
  54. k : tuple of integers
  55. Degrees for each dimension.
  56. extrapolate : bool, optional
  57. Whether to extrapolate or return nans for out-of-bounds inputs.
  58. Defaults to true.
  59. Methods
  60. -------
  61. __call__
  62. derivative
  63. design_matrix
  64. See Also
  65. --------
  66. BSpline : a one-dimensional B-spline object
  67. NdPPoly : an N-dimensional piecewise tensor product polynomial
  68. """
  69. # generic type compatibility with scipy-stubs
  70. __class_getitem__ = classmethod(GenericAlias)
  71. def __init__(self, t, c, k, *, extrapolate=None):
  72. self._k, self._indices_k1d, (self._t, self._len_t) = _preprocess_inputs(k, t)
  73. self._asarray = array_namespace(c, *t).asarray
  74. if extrapolate is None:
  75. extrapolate = True
  76. self.extrapolate = bool(extrapolate)
  77. self._c = np.asarray(c)
  78. ndim = self._t.shape[0] # == len(self.t)
  79. if self._c.ndim < ndim:
  80. raise ValueError(f"Coefficients must be at least {ndim}-dimensional.")
  81. for d in range(ndim):
  82. td = self.t[d]
  83. kd = self.k[d]
  84. n = td.shape[0] - kd - 1
  85. if self._c.shape[d] != n:
  86. raise ValueError(f"Knots, coefficients and degree in dimension"
  87. f" {d} are inconsistent:"
  88. f" got {self._c.shape[d]} coefficients for"
  89. f" {len(td)} knots, need at least {n} for"
  90. f" k={k}.")
  91. dt = _get_dtype(self._c.dtype)
  92. self._c = np.ascontiguousarray(self._c, dtype=dt)
  93. @property
  94. def k(self):
  95. return tuple(self._k)
  96. @property
  97. def t(self):
  98. # repack the knots into a tuple
  99. return tuple(
  100. self._asarray(self._t[d, :self._len_t[d]]) for d in range(self._t.shape[0])
  101. )
  102. @property
  103. def c(self):
  104. return self._asarray(self._c)
  105. def __call__(self, xi, *, nu=None, extrapolate=None):
  106. """Evaluate the tensor product b-spline at ``xi``.
  107. Parameters
  108. ----------
  109. xi : array_like, shape(..., ndim)
  110. The coordinates to evaluate the interpolator at.
  111. This can be a list or tuple of ndim-dimensional points
  112. or an array with the shape (num_points, ndim).
  113. nu : sequence of length ``ndim``, optional
  114. Orders of derivatives to evaluate. Each must be non-negative.
  115. Defaults to the zeroth derivivative.
  116. extrapolate : bool, optional
  117. Whether to exrapolate based on first and last intervals in each
  118. dimension, or return `nan`. Default is to ``self.extrapolate``.
  119. Returns
  120. -------
  121. values : ndarray, shape ``xi.shape[:-1] + self.c.shape[ndim:]``
  122. Interpolated values at ``xi``
  123. """
  124. ndim = self._t.shape[0] # == len(self.t)
  125. if extrapolate is None:
  126. extrapolate = self.extrapolate
  127. extrapolate = bool(extrapolate)
  128. if nu is None:
  129. nu = np.zeros((ndim,), dtype=np.int64)
  130. else:
  131. nu = np.asarray(nu, dtype=np.int64)
  132. if nu.ndim != 1 or nu.shape[0] != ndim:
  133. raise ValueError(
  134. f"invalid number of derivative orders {nu = } for "
  135. f"ndim = {len(self.t)}.")
  136. if any(nu < 0):
  137. raise ValueError(f"derivatives must be positive, got {nu = }")
  138. # prepare xi : shape (..., m1, ..., md) -> (1, m1, ..., md)
  139. xi = np.asarray(xi, dtype=float)
  140. xi_shape = xi.shape
  141. xi = xi.reshape(-1, xi_shape[-1])
  142. xi = np.ascontiguousarray(xi)
  143. if xi_shape[-1] != ndim:
  144. raise ValueError(f"Shapes: xi.shape={xi_shape} and ndim={ndim}")
  145. # complex -> double
  146. was_complex = self._c.dtype.kind == 'c'
  147. cc = self._c
  148. if was_complex and self._c.ndim == ndim:
  149. # make sure that core dimensions are intact, and complex->float
  150. # size doubling only adds a trailing dimension
  151. cc = self._c[..., None]
  152. cc = cc.view(float)
  153. # prepare the coefficients: flatten the trailing dimensions
  154. c1 = cc.reshape(cc.shape[:ndim] + (-1,))
  155. c1r = c1.ravel()
  156. # replacement for np.ravel_multi_index for indexing of `c1`:
  157. _strides_c1 = np.asarray([s // c1.dtype.itemsize
  158. for s in c1.strides], dtype=np.int64)
  159. num_c_tr = c1.shape[-1] # # of trailing coefficients
  160. out = _dierckx.evaluate_ndbspline(xi,
  161. self._t,
  162. self._len_t,
  163. self._k,
  164. nu,
  165. extrapolate,
  166. c1r,
  167. num_c_tr,
  168. _strides_c1,
  169. self._indices_k1d,
  170. )
  171. out = out.view(self._c.dtype)
  172. out = out.reshape(xi_shape[:-1] + self._c.shape[ndim:])
  173. return self._asarray(out)
  174. @classmethod
  175. def design_matrix(cls, xvals, t, k, extrapolate=True):
  176. """Construct the design matrix as a CSR format sparse array.
  177. Parameters
  178. ----------
  179. xvals : ndarray, shape(npts, ndim)
  180. Data points. ``xvals[j, :]`` gives the ``j``-th data point as an
  181. ``ndim``-dimensional array.
  182. t : tuple of 1D ndarrays, length-ndim
  183. Knot vectors in directions 1, 2, ... ndim,
  184. k : int
  185. B-spline degree.
  186. extrapolate : bool, optional
  187. Whether to extrapolate out-of-bounds values of raise a `ValueError`
  188. Returns
  189. -------
  190. design_matrix : a CSR array
  191. Each row of the design matrix corresponds to a value in `xvals` and
  192. contains values of b-spline basis elements which are non-zero
  193. at this value.
  194. """
  195. xvals = np.asarray(xvals, dtype=float)
  196. ndim = xvals.shape[-1]
  197. if len(t) != ndim:
  198. raise ValueError(
  199. f"Data and knots are inconsistent: len(t) = {len(t)} for "
  200. f" {ndim = }."
  201. )
  202. # tabulate the flat indices for iterating over the (k+1)**ndim subarray
  203. k, _indices_k1d, (_t, len_t) = _preprocess_inputs(k, t)
  204. # Precompute the shape and strides of the 'coefficients array'.
  205. # This would have been the NdBSpline coefficients; in the present context
  206. # this is a helper to compute the indices into the colocation matrix.
  207. c_shape = tuple(len_t[d] - k[d] - 1 for d in range(ndim))
  208. # The strides of the coeffs array: the computation is equivalent to
  209. # >>> cstrides = [s // 8 for s in np.empty(c_shape).strides]
  210. cs = c_shape[1:] + (1,)
  211. cstrides = np.cumprod(cs[::-1], dtype=np.int64)[::-1].copy()
  212. # heavy lifting happens here
  213. data, indices, indptr = _dierckx._coloc_nd(xvals,
  214. _t, len_t, k, _indices_k1d, cstrides)
  215. return csr_array((data, indices, indptr))
  216. def _bspline_derivative_along_axis(self, c, t, k, axis, nu=1):
  217. # Move the selected axis to front
  218. c = np.moveaxis(c, axis, 0)
  219. n = c.shape[0]
  220. trailing_shape = c.shape[1:]
  221. c_flat = c.reshape(n, -1)
  222. new_c_list = []
  223. new_t = None
  224. for i in range(c_flat.shape[1]):
  225. if k >= nu:
  226. b = BSpline.construct_fast(t, c_flat[:, i], k)
  227. db = b.derivative(nu)
  228. # truncate coefficients to match new knot/degree size
  229. db.c = db.c[:len(db.t) - db.k - 1]
  230. else:
  231. db = BSpline.construct_fast(t, np.zeros(len(t) - 1), 0)
  232. if new_t is None:
  233. new_t = db.t
  234. new_c_list.append(db.c)
  235. new_c = np.stack(new_c_list, axis=1).reshape(
  236. (len(new_c_list[0]),) + trailing_shape)
  237. new_c = np.moveaxis(new_c, 0, axis)
  238. return new_c, new_t
  239. def derivative(self, nu):
  240. """
  241. Construct a new NdBSpline representing the partial derivative.
  242. Parameters
  243. ----------
  244. nu : array_like of shape (ndim,)
  245. Orders of the partial derivatives to compute along each dimension.
  246. Returns
  247. -------
  248. NdBSpline
  249. A new NdBSpline representing the partial derivative of the original spline.
  250. """
  251. nu_arr = np.asarray(nu, dtype=np.int64)
  252. ndim = len(self.t)
  253. if nu_arr.ndim != 1 or nu_arr.shape[0] != ndim:
  254. raise ValueError(
  255. f"invalid number of derivative orders {nu = } for "
  256. f"ndim = {len(self.t)}.")
  257. if any(nu_arr < 0):
  258. raise ValueError(f"derivative orders must be positive, got {nu = }")
  259. # extract t and c as numpy arrays
  260. t_new = [self._t[d, :self._len_t[d]] for d in range(self._t.shape[0])]
  261. k_new = list(self.k)
  262. c_new = self._c.copy()
  263. for axis, n in enumerate(nu_arr):
  264. if n == 0:
  265. continue
  266. c_new, t_new[axis] = self._bspline_derivative_along_axis(
  267. c_new, t_new[axis], k_new[axis], axis, nu=n
  268. )
  269. k_new[axis] = max(k_new[axis] - n, 0)
  270. return NdBSpline(tuple(self._asarray(t) for t in t_new),
  271. self._asarray(c_new),
  272. tuple(k_new),
  273. extrapolate=self.extrapolate
  274. )
  275. def _preprocess_inputs(k, t_tpl):
  276. """Helpers: validate and preprocess NdBSpline inputs.
  277. Parameters
  278. ----------
  279. k : int or tuple
  280. Spline orders
  281. t_tpl : tuple or array-likes
  282. Knots.
  283. """
  284. # 1. Make sure t_tpl is a tuple
  285. if not isinstance(t_tpl, tuple):
  286. raise ValueError(f"Expect `t` to be a tuple of array-likes. "
  287. f"Got {t_tpl} instead."
  288. )
  289. # 2. Make ``k`` a tuple of integers
  290. ndim = len(t_tpl)
  291. try:
  292. len(k)
  293. except TypeError:
  294. # make k a tuple
  295. k = (k,)*ndim
  296. k = np.asarray([operator.index(ki) for ki in k], dtype=np.int64)
  297. if len(k) != ndim:
  298. raise ValueError(f"len(t) = {len(t_tpl)} != {len(k) = }.")
  299. # 3. Validate inputs
  300. ndim = len(t_tpl)
  301. for d in range(ndim):
  302. td = np.asarray(t_tpl[d])
  303. kd = k[d]
  304. n = td.shape[0] - kd - 1
  305. if kd < 0:
  306. raise ValueError(f"Spline degree in dimension {d} cannot be"
  307. f" negative.")
  308. if td.ndim != 1:
  309. raise ValueError(f"Knot vector in dimension {d} must be"
  310. f" one-dimensional.")
  311. if n < kd + 1:
  312. raise ValueError(f"Need at least {2*kd + 2} knots for degree"
  313. f" {kd} in dimension {d}.")
  314. if (np.diff(td) < 0).any():
  315. raise ValueError(f"Knots in dimension {d} must be in a"
  316. f" non-decreasing order.")
  317. if len(np.unique(td[kd:n + 1])) < 2:
  318. raise ValueError(f"Need at least two internal knots in"
  319. f" dimension {d}.")
  320. if not np.isfinite(td).all():
  321. raise ValueError(f"Knots in dimension {d} should not have"
  322. f" nans or infs.")
  323. # 4. tabulate the flat indices for iterating over the (k+1)**ndim subarray
  324. # non-zero b-spline elements
  325. shape = tuple(kd + 1 for kd in k)
  326. indices = np.unravel_index(np.arange(prod(shape)), shape)
  327. _indices_k1d = np.asarray(indices, dtype=np.int64).T.copy()
  328. # 5. pack the knots into a single array:
  329. # ([1, 2, 3, 4], [5, 6], (7, 8, 9)) -->
  330. # array([[1, 2, 3, 4],
  331. # [5, 6, nan, nan],
  332. # [7, 8, 9, nan]])
  333. t_tpl = [np.asarray(t) for t in t_tpl]
  334. ndim = len(t_tpl)
  335. len_t = [len(ti) for ti in t_tpl]
  336. _t = np.empty((ndim, max(len_t)), dtype=float)
  337. _t.fill(np.nan)
  338. for d in range(ndim):
  339. _t[d, :len(t_tpl[d])] = t_tpl[d]
  340. len_t = np.asarray(len_t, dtype=np.int64)
  341. return k, _indices_k1d, (_t, len_t)
  342. def _iter_solve(a, b, solver=ssl.gcrotmk, **solver_args):
  343. # work around iterative solvers not accepting multiple r.h.s.
  344. # also work around a.dtype == float64 and b.dtype == complex128
  345. # cf https://github.com/scipy/scipy/issues/19644
  346. if np.issubdtype(b.dtype, np.complexfloating):
  347. real = _iter_solve(a, b.real, solver, **solver_args)
  348. imag = _iter_solve(a, b.imag, solver, **solver_args)
  349. return real + 1j*imag
  350. if b.ndim == 2 and b.shape[1] !=1:
  351. res = np.empty_like(b)
  352. for j in range(b.shape[1]):
  353. res[:, j], info = solver(a, b[:, j], **solver_args)
  354. if info != 0:
  355. raise ValueError(f"{solver = } returns {info =} for column {j}.")
  356. return res
  357. else:
  358. res, info = solver(a, b, **solver_args)
  359. if info != 0:
  360. raise ValueError(f"{solver = } returns {info = }.")
  361. return res
  362. def make_ndbspl(points, values, k=3, *, solver=ssl.gcrotmk, **solver_args):
  363. """Construct an interpolating NdBspline.
  364. Parameters
  365. ----------
  366. points : tuple of ndarrays of float, with shapes (m1,), ... (mN,)
  367. The points defining the regular grid in N dimensions. The points in
  368. each dimension (i.e. every element of the `points` tuple) must be
  369. strictly ascending or descending.
  370. values : ndarray of float, shape (m1, ..., mN, ...)
  371. The data on the regular grid in n dimensions.
  372. k : int, optional
  373. The spline degree. Must be odd. Default is cubic, k=3
  374. solver : a `scipy.sparse.linalg` solver (iterative or direct), optional.
  375. An iterative solver from `scipy.sparse.linalg` or a direct one,
  376. `sparse.sparse.linalg.spsolve`.
  377. Used to solve the sparse linear system
  378. ``design_matrix @ coefficients = rhs`` for the coefficients.
  379. Default is `scipy.sparse.linalg.gcrotmk`
  380. solver_args : dict, optional
  381. Additional arguments for the solver. The call signature is
  382. ``solver(csr_array, rhs_vector, **solver_args)``
  383. Returns
  384. -------
  385. spl : NdBSpline object
  386. Notes
  387. -----
  388. Boundary conditions are not-a-knot in all dimensions.
  389. """
  390. ndim = len(points)
  391. xi_shape = tuple(len(x) for x in points)
  392. try:
  393. len(k)
  394. except TypeError:
  395. # make k a tuple
  396. k = (k,)*ndim
  397. for d, point in enumerate(points):
  398. numpts = len(np.atleast_1d(point))
  399. if numpts <= k[d]:
  400. raise ValueError(f"There are {numpts} points in dimension {d},"
  401. f" but order {k[d]} requires at least "
  402. f" {k[d]+1} points per dimension.")
  403. t = tuple(_not_a_knot(np.asarray(points[d], dtype=float), k[d])
  404. for d in range(ndim))
  405. xvals = np.asarray([xv for xv in itertools.product(*points)], dtype=float)
  406. # construct the colocation matrix
  407. matr = NdBSpline.design_matrix(xvals, t, k)
  408. # Remove zeros from the sparse matrix
  409. # If k=1, then solve() doesn't take long enough for this to help
  410. if k[0] >= 3:
  411. matr.eliminate_zeros()
  412. # Solve for the coefficients given `values`.
  413. # Trailing dimensions: first ndim dimensions are data, the rest are batch
  414. # dimensions, so stack `values` into a 2D array for `spsolve` to undestand.
  415. v_shape = values.shape
  416. vals_shape = (prod(v_shape[:ndim]), prod(v_shape[ndim:]))
  417. vals = values.reshape(vals_shape)
  418. if solver != ssl.spsolve:
  419. solver = functools.partial(_iter_solve, solver=solver)
  420. if "atol" not in solver_args:
  421. # avoid a DeprecationWarning, grumble grumble
  422. solver_args["atol"] = 1e-6
  423. coef = solver(matr, vals, **solver_args)
  424. coef = coef.reshape(xi_shape + v_shape[ndim:])
  425. return NdBSpline(t, coef, k)