_rbfinterp.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540
  1. """Module for RBF interpolation."""
  2. import warnings
  3. from types import GenericAlias
  4. import numpy as np
  5. from scipy.spatial import KDTree
  6. from . import _rbfinterp_np
  7. from . import _rbfinterp_xp
  8. from scipy._lib._array_api import (
  9. _asarray, array_namespace, xp_size, is_numpy, xp_capabilities
  10. )
  11. import scipy._lib.array_api_extra as xpx
  12. __all__ = ["RBFInterpolator"]
  13. # These RBFs are implemented.
  14. _AVAILABLE = {
  15. "linear",
  16. "thin_plate_spline",
  17. "cubic",
  18. "quintic",
  19. "multiquadric",
  20. "inverse_multiquadric",
  21. "inverse_quadratic",
  22. "gaussian"
  23. }
  24. # The shape parameter does not need to be specified when using these RBFs.
  25. _SCALE_INVARIANT = {"linear", "thin_plate_spline", "cubic", "quintic"}
  26. # For RBFs that are conditionally positive definite of order m, the interpolant
  27. # should include polynomial terms with degree >= m - 1. Define the minimum
  28. # degrees here. These values are from Chapter 8 of Fasshauer's "Meshfree
  29. # Approximation Methods with MATLAB". The RBFs that are not in this dictionary
  30. # are positive definite and do not need polynomial terms.
  31. _NAME_TO_MIN_DEGREE = {
  32. "multiquadric": 0,
  33. "linear": 0,
  34. "thin_plate_spline": 1,
  35. "cubic": 1,
  36. "quintic": 2
  37. }
  38. def _get_backend(xp):
  39. if is_numpy(xp):
  40. return _rbfinterp_np
  41. return _rbfinterp_xp
  42. extra_note="""Only the default ``neighbors=None`` is Array API compatible.
  43. If a non-default value of ``neighbors`` is given, the behavior is NumPy -only.
  44. """
  45. @xp_capabilities(
  46. skip_backends=[
  47. ("dask.array", "linalg.lu is broken; array_api_extra#488"),
  48. ("array_api_strict", "array-api#977, diag, view")
  49. ],
  50. extra_note=extra_note
  51. )
  52. class RBFInterpolator:
  53. """Radial basis function interpolator in N ≥ 1 dimensions.
  54. Parameters
  55. ----------
  56. y : (npoints, ndims) array_like
  57. 2-D array of data point coordinates.
  58. d : (npoints, ...) array_like
  59. N-D array of data values at `y`. The length of `d` along the first
  60. axis must be equal to the length of `y`. Unlike some interpolators, the
  61. interpolation axis cannot be changed.
  62. neighbors : int, optional
  63. If specified, the value of the interpolant at each evaluation point
  64. will be computed using only this many nearest data points. All the data
  65. points are used by default.
  66. smoothing : float or (npoints, ) array_like, optional
  67. Smoothing parameter. The interpolant perfectly fits the data when this
  68. is set to 0. For large values, the interpolant approaches a least
  69. squares fit of a polynomial with the specified degree. Default is 0.
  70. kernel : str, optional
  71. Type of RBF. This should be one of
  72. - 'linear' : ``-r``
  73. - 'thin_plate_spline' : ``r**2 * log(r)``
  74. - 'cubic' : ``r**3``
  75. - 'quintic' : ``-r**5``
  76. - 'multiquadric' : ``-sqrt(1 + r**2)``
  77. - 'inverse_multiquadric' : ``1/sqrt(1 + r**2)``
  78. - 'inverse_quadratic' : ``1/(1 + r**2)``
  79. - 'gaussian' : ``exp(-r**2)``
  80. Default is 'thin_plate_spline'.
  81. epsilon : float, optional
  82. Shape parameter that scales the input to the RBF. If `kernel` is
  83. 'linear', 'thin_plate_spline', 'cubic', or 'quintic', this defaults to
  84. 1 and can be ignored because it has the same effect as scaling the
  85. smoothing parameter. Otherwise, this must be specified.
  86. degree : int, optional
  87. Degree of the added polynomial. For some RBFs the interpolant may not
  88. be well-posed if the polynomial degree is too small. Those RBFs and
  89. their corresponding minimum degrees are
  90. - 'multiquadric' : 0
  91. - 'linear' : 0
  92. - 'thin_plate_spline' : 1
  93. - 'cubic' : 1
  94. - 'quintic' : 2
  95. The default value is the minimum degree for `kernel` or 0 if there is
  96. no minimum degree. Set this to -1 for no added polynomial.
  97. Notes
  98. -----
  99. An RBF is a scalar valued function in N-dimensional space whose value at
  100. :math:`x` can be expressed in terms of :math:`r=||x - c||`, where :math:`c`
  101. is the center of the RBF.
  102. An RBF interpolant for the vector of data values :math:`d`, which are from
  103. locations :math:`y`, is a linear combination of RBFs centered at :math:`y`
  104. plus a polynomial with a specified degree. The RBF interpolant is written
  105. as
  106. .. math::
  107. f(x) = K(x, y) a + P(x) b,
  108. where :math:`K(x, y)` is a matrix of RBFs with centers at :math:`y`
  109. evaluated at the points :math:`x`, and :math:`P(x)` is a matrix of
  110. monomials, which span polynomials with the specified degree, evaluated at
  111. :math:`x`. The coefficients :math:`a` and :math:`b` are the solution to the
  112. linear equations
  113. .. math::
  114. (K(x, y) + \\lambda I) a + P(y) b = d
  115. and
  116. .. math::
  117. P(y)^T a = 0,
  118. where :math:`\\lambda` is a non-negative smoothing parameter that controls
  119. how well we want to fit the data. The data are fit exactly when the
  120. smoothing parameter is 0.
  121. The above system is uniquely solvable if the following requirements are
  122. met:
  123. - :math:`P(y)` must have full column rank. :math:`P(y)` always has full
  124. column rank when `degree` is -1 or 0. When `degree` is 1,
  125. :math:`P(y)` has full column rank if the data point locations are not
  126. all collinear (N=2), coplanar (N=3), etc.
  127. - If `kernel` is 'multiquadric', 'linear', 'thin_plate_spline',
  128. 'cubic', or 'quintic', then `degree` must not be lower than the
  129. minimum value listed above.
  130. - If `smoothing` is 0, then each data point location must be distinct.
  131. When using an RBF that is not scale invariant ('multiquadric',
  132. 'inverse_multiquadric', 'inverse_quadratic', or 'gaussian'), an appropriate
  133. shape parameter must be chosen (e.g., through cross validation). Smaller
  134. values for the shape parameter correspond to wider RBFs. The problem can
  135. become ill-conditioned or singular when the shape parameter is too small.
  136. The memory required to solve for the RBF interpolation coefficients
  137. increases quadratically with the number of data points, which can become
  138. impractical when interpolating more than about a thousand data points.
  139. To overcome memory limitations for large interpolation problems, the
  140. `neighbors` argument can be specified to compute an RBF interpolant for
  141. each evaluation point using only the nearest data points.
  142. .. versionadded:: 1.7.0
  143. See Also
  144. --------
  145. NearestNDInterpolator
  146. LinearNDInterpolator
  147. CloughTocher2DInterpolator
  148. References
  149. ----------
  150. .. [1] Fasshauer, G., 2007. Meshfree Approximation Methods with Matlab.
  151. World Scientific Publishing Co.
  152. .. [2] http://amadeus.math.iit.edu/~fass/603_ch3.pdf
  153. .. [3] Wahba, G., 1990. Spline Models for Observational Data. SIAM.
  154. .. [4] http://pages.stat.wisc.edu/~wahba/stat860public/lect/lect8/lect8.pdf
  155. Examples
  156. --------
  157. Demonstrate interpolating scattered data to a grid in 2-D.
  158. >>> import numpy as np
  159. >>> import matplotlib.pyplot as plt
  160. >>> from scipy.interpolate import RBFInterpolator
  161. >>> from scipy.stats.qmc import Halton
  162. >>> rng = np.random.default_rng()
  163. >>> xobs = 2*Halton(2, seed=rng).random(100) - 1
  164. >>> yobs = np.sum(xobs, axis=1)*np.exp(-6*np.sum(xobs**2, axis=1))
  165. >>> x1 = np.linspace(-1, 1, 50)
  166. >>> xgrid = np.asarray(np.meshgrid(x1, x1, indexing='ij'))
  167. >>> xflat = xgrid.reshape(2, -1).T # make it a 2-D array
  168. >>> yflat = RBFInterpolator(xobs, yobs)(xflat)
  169. >>> ygrid = yflat.reshape(50, 50)
  170. >>> fig, ax = plt.subplots()
  171. >>> ax.pcolormesh(*xgrid, ygrid, vmin=-0.25, vmax=0.25, shading='gouraud')
  172. >>> p = ax.scatter(*xobs.T, c=yobs, s=50, ec='k', vmin=-0.25, vmax=0.25)
  173. >>> fig.colorbar(p)
  174. >>> plt.show()
  175. """
  176. # generic type compatibility with scipy-stubs
  177. __class_getitem__ = classmethod(GenericAlias)
  178. def __init__(self, y, d,
  179. neighbors=None,
  180. smoothing=0.0,
  181. kernel="thin_plate_spline",
  182. epsilon=None,
  183. degree=None):
  184. xp = array_namespace(y, d, smoothing)
  185. _backend = _get_backend(xp)
  186. if neighbors is not None:
  187. if not is_numpy(xp):
  188. raise NotImplementedError(
  189. "neighbors not None is numpy-only because it relies on KDTree"
  190. )
  191. y = _asarray(y, dtype=xp.float64, order="C", xp=xp)
  192. if y.ndim != 2:
  193. raise ValueError("`y` must be a 2-dimensional array.")
  194. ny, ndim = y.shape
  195. d = xp.asarray(d)
  196. if xp.isdtype(d.dtype, 'complex floating'):
  197. d_dtype = xp.complex128
  198. else:
  199. d_dtype = xp.float64
  200. d = _asarray(d, dtype=d_dtype, order="C", xp=xp)
  201. if d.shape[0] != ny:
  202. raise ValueError(
  203. f"Expected the first axis of `d` to have length {ny}."
  204. )
  205. d_shape = d.shape[1:]
  206. d = xp.reshape(d, (ny, -1))
  207. # If `d` is complex, convert it to a float array with twice as many
  208. # columns. Otherwise, the LHS matrix would need to be converted to
  209. # complex and take up 2x more memory than necessary.
  210. d = d.view(float) # NB not Array API compliant (and jax copies)
  211. if isinstance(smoothing, int | float) or smoothing.shape == ():
  212. smoothing = xp.full(ny, smoothing, dtype=xp.float64)
  213. else:
  214. smoothing = _asarray(smoothing, dtype=float, order="C", xp=xp)
  215. if smoothing.shape != (ny,):
  216. raise ValueError(
  217. "Expected `smoothing` to be a scalar or have shape "
  218. f"({ny},)."
  219. )
  220. kernel = kernel.lower()
  221. if kernel not in _AVAILABLE:
  222. raise ValueError(f"`kernel` must be one of {_AVAILABLE}.")
  223. if epsilon is None:
  224. if kernel in _SCALE_INVARIANT:
  225. epsilon = 1.0
  226. else:
  227. raise ValueError(
  228. "`epsilon` must be specified if `kernel` is not one of "
  229. f"{_SCALE_INVARIANT}."
  230. )
  231. else:
  232. epsilon = float(epsilon)
  233. min_degree = _NAME_TO_MIN_DEGREE.get(kernel, -1)
  234. if degree is None:
  235. degree = max(min_degree, 0)
  236. else:
  237. degree = int(degree)
  238. if degree < -1:
  239. raise ValueError("`degree` must be at least -1.")
  240. elif -1 < degree < min_degree:
  241. warnings.warn(
  242. f"`degree` should not be below {min_degree} except -1 "
  243. f"when `kernel` is '{kernel}'."
  244. f"The interpolant may not be uniquely "
  245. f"solvable, and the smoothing parameter may have an "
  246. f"unintuitive effect.",
  247. UserWarning, stacklevel=2
  248. )
  249. if neighbors is None:
  250. nobs = ny
  251. else:
  252. # Make sure the number of nearest neighbors used for interpolation
  253. # does not exceed the number of observations.
  254. neighbors = int(min(neighbors, ny))
  255. nobs = neighbors
  256. powers = _backend._monomial_powers(ndim, degree, xp)
  257. # The polynomial matrix must have full column rank in order for the
  258. # interpolant to be well-posed, which is not possible if there are
  259. # fewer observations than monomials.
  260. if powers.shape[0] > nobs:
  261. raise ValueError(
  262. f"At least {powers.shape[0]} data points are required when "
  263. f"`degree` is {degree} and the number of dimensions is {ndim}."
  264. )
  265. if neighbors is None:
  266. shift, scale, coeffs = _backend._build_and_solve_system(
  267. y, d, smoothing, kernel, epsilon, powers,
  268. xp
  269. )
  270. # Make these attributes private since they do not always exist.
  271. self._shift = shift
  272. self._scale = scale
  273. self._coeffs = coeffs
  274. else:
  275. self._tree = KDTree(y)
  276. self.y = y
  277. self.d = d
  278. self.d_shape = d_shape
  279. self.d_dtype = d_dtype
  280. self.neighbors = neighbors
  281. self.smoothing = smoothing
  282. self.kernel = kernel
  283. self.epsilon = epsilon
  284. self.powers = powers
  285. self._xp = xp
  286. def __setstate__(self, state):
  287. tpl1, tpl2 = state
  288. (self.y, self.d, self.d_shape, self.d_dtype, self.neighbors,
  289. self.smoothing, self.kernel, self.epsilon, self.powers) = tpl1
  290. if self.neighbors is None:
  291. self._shift, self._scale, self._coeffs = tpl2
  292. else:
  293. self._tree, = tpl2
  294. self._xp = array_namespace(self.y, self.d, self.smoothing)
  295. def __getstate__(self):
  296. tpl = (self.y, self.d, self.d_shape, self.d_dtype, self.neighbors,
  297. self.smoothing, self.kernel, self.epsilon, self.powers
  298. )
  299. if self.neighbors is None:
  300. tpl2 = (self._shift, self._scale, self._coeffs)
  301. else:
  302. tpl2 = (self._tree,)
  303. return (tpl, tpl2)
  304. def _chunk_evaluator(
  305. self,
  306. x,
  307. y,
  308. shift,
  309. scale,
  310. coeffs,
  311. memory_budget=1000000
  312. ):
  313. """
  314. Evaluate the interpolation while controlling memory consumption.
  315. We chunk the input if we need more memory than specified.
  316. Parameters
  317. ----------
  318. x : (Q, N) float ndarray
  319. array of points on which to evaluate
  320. y: (P, N) float ndarray
  321. array of points on which we know function values
  322. shift: (N, ) ndarray
  323. Domain shift used to create the polynomial matrix.
  324. scale : (N,) float ndarray
  325. Domain scaling used to create the polynomial matrix.
  326. coeffs: (P+R, S) float ndarray
  327. Coefficients in front of basis functions
  328. memory_budget: int
  329. Total amount of memory (in units of sizeof(float)) we wish
  330. to devote for storing the array of coefficients for
  331. interpolated points. If we need more memory than that, we
  332. chunk the input.
  333. Returns
  334. -------
  335. (Q, S) float ndarray
  336. Interpolated array
  337. """
  338. _backend = _get_backend(self._xp)
  339. nx, ndim = x.shape
  340. if self.neighbors is None:
  341. nnei = y.shape[0]
  342. else:
  343. nnei = self.neighbors
  344. # in each chunk we consume the same space we already occupy
  345. chunksize = memory_budget // (self.powers.shape[0] + nnei) + 1
  346. if chunksize <= nx:
  347. out = self._xp.empty((nx, self.d.shape[1]), dtype=self._xp.float64)
  348. for i in range(0, nx, chunksize):
  349. chunk = _backend.compute_interpolation(
  350. x[i:i + chunksize, :],
  351. y,
  352. self.kernel,
  353. self.epsilon,
  354. self.powers,
  355. shift,
  356. scale,
  357. coeffs,
  358. self._xp
  359. )
  360. out = xpx.at(out, (slice(i, i + chunksize), slice(None,))).set(chunk)
  361. else:
  362. out = _backend.compute_interpolation(
  363. x,
  364. y,
  365. self.kernel,
  366. self.epsilon,
  367. self.powers,
  368. shift,
  369. scale,
  370. coeffs,
  371. self._xp
  372. )
  373. return out
  374. def __call__(self, x):
  375. """Evaluate the interpolant at `x`.
  376. Parameters
  377. ----------
  378. x : (npts, ndim) array_like
  379. Evaluation point coordinates.
  380. Returns
  381. -------
  382. ndarray, shape (npts, )
  383. Values of the interpolant at `x`.
  384. """
  385. x = _asarray(x, dtype=self._xp.float64, order="C", xp=self._xp)
  386. if x.ndim != 2:
  387. raise ValueError("`x` must be a 2-dimensional array.")
  388. nx, ndim = x.shape
  389. if ndim != self.y.shape[1]:
  390. raise ValueError("Expected the second axis of `x` to have length "
  391. f"{self.y.shape[1]}.")
  392. # Our memory budget for storing RBF coefficients is
  393. # based on how many floats in memory we already occupy
  394. # If this number is below 1e6 we just use 1e6
  395. # This memory budget is used to decide how we chunk
  396. # the inputs
  397. memory_budget = max(xp_size(x) + xp_size(self.y) + xp_size(self.d), 1_000_000)
  398. if self.neighbors is None:
  399. out = self._chunk_evaluator(
  400. x,
  401. self.y,
  402. self._shift,
  403. self._scale,
  404. self._coeffs,
  405. memory_budget=memory_budget)
  406. else:
  407. # XXX: this relies on KDTree, hence is numpy-only until KDTree is converted
  408. _build_and_solve_system = _get_backend(np)._build_and_solve_system
  409. # Get the indices of the k nearest observation points to each
  410. # evaluation point.
  411. _, yindices = self._tree.query(x, self.neighbors)
  412. if self.neighbors == 1:
  413. # `KDTree` squeezes the output when neighbors=1.
  414. yindices = yindices[:, None]
  415. # Multiple evaluation points may have the same neighborhood of
  416. # observation points. Make the neighborhoods unique so that we only
  417. # compute the interpolation coefficients once for each
  418. # neighborhood.
  419. yindices = np.sort(yindices, axis=1)
  420. yindices, inv = np.unique(yindices, return_inverse=True, axis=0)
  421. inv = np.reshape(inv, (-1,)) # flatten, we need 1-D indices
  422. # `inv` tells us which neighborhood will be used by each evaluation
  423. # point. Now we find which evaluation points will be using each
  424. # neighborhood.
  425. xindices = [[] for _ in range(len(yindices))]
  426. for i, j in enumerate(inv):
  427. xindices[j].append(i)
  428. out = np.empty((nx, self.d.shape[1]), dtype=float)
  429. for xidx, yidx in zip(xindices, yindices):
  430. # `yidx` are the indices of the observations in this
  431. # neighborhood. `xidx` are the indices of the evaluation points
  432. # that are using this neighborhood.
  433. xnbr = x[xidx]
  434. ynbr = self.y[yidx]
  435. dnbr = self.d[yidx]
  436. snbr = self.smoothing[yidx]
  437. shift, scale, coeffs = _build_and_solve_system(
  438. ynbr,
  439. dnbr,
  440. snbr,
  441. self.kernel,
  442. self.epsilon,
  443. self.powers,
  444. np
  445. )
  446. out[xidx] = self._chunk_evaluator(
  447. xnbr,
  448. ynbr,
  449. shift,
  450. scale,
  451. coeffs,
  452. memory_budget=memory_budget)
  453. out = out.view(self.d_dtype) # NB not Array API compliant (and jax copies)
  454. out = self._xp.reshape(out, (nx, ) + self.d_shape)
  455. return out