_rgi.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812
  1. __all__ = ['RegularGridInterpolator', 'interpn']
  2. import itertools
  3. from types import GenericAlias
  4. import numpy as np
  5. import scipy.sparse.linalg as ssl
  6. from scipy._lib._array_api import array_namespace, xp_capabilities
  7. from scipy._lib.array_api_compat import is_array_api_obj
  8. from ._interpnd import _ndim_coords_from_arrays
  9. from ._cubic import PchipInterpolator
  10. from ._rgi_cython import evaluate_linear_2d, find_indices
  11. from ._bsplines import make_interp_spline
  12. from ._fitpack2 import RectBivariateSpline
  13. from ._ndbspline import make_ndbspl
  14. def _check_points(points):
  15. descending_dimensions = []
  16. grid = []
  17. for i, p in enumerate(points):
  18. # early make points float
  19. # see https://github.com/scipy/scipy/pull/17230
  20. p = np.asarray(p, dtype=float)
  21. if not np.all(p[1:] > p[:-1]):
  22. if np.all(p[1:] < p[:-1]):
  23. # input is descending, so make it ascending
  24. descending_dimensions.append(i)
  25. p = np.flip(p)
  26. else:
  27. raise ValueError(
  28. f"The points in dimension {i} must be strictly ascending or "
  29. f"descending"
  30. )
  31. # see https://github.com/scipy/scipy/issues/17716
  32. p = np.ascontiguousarray(p)
  33. grid.append(p)
  34. return tuple(grid), tuple(descending_dimensions)
  35. def _check_dimensionality(points, values):
  36. if len(points) > values.ndim:
  37. raise ValueError(
  38. f"There are {len(points)} point arrays, but values has "
  39. f"{values.ndim} dimensions"
  40. )
  41. for i, p in enumerate(points):
  42. if not np.asarray(p).ndim == 1:
  43. raise ValueError(f"The points in dimension {i} must be 1-dimensional")
  44. if not values.shape[i] == len(p):
  45. raise ValueError(
  46. f"There are {len(p)} points and {values.shape[i]} values in "
  47. f"dimension {i}"
  48. )
  49. @xp_capabilities(
  50. cpu_only=True, jax_jit=False,
  51. skip_backends=[
  52. ("dask.array",
  53. "https://github.com/data-apis/array-api-extra/issues/488")
  54. ]
  55. )
  56. class RegularGridInterpolator:
  57. """Interpolator of specified order on a rectilinear grid in N ≥ 1 dimensions.
  58. The data must be defined on a rectilinear grid; that is, a rectangular
  59. grid with even or uneven spacing. Linear, nearest-neighbor, spline
  60. interpolations are supported. After setting up the interpolator object,
  61. the interpolation method may be chosen at each evaluation.
  62. Parameters
  63. ----------
  64. points : tuple of ndarray of float, with shapes (m1, ), ..., (mn, )
  65. The points defining the regular grid in n dimensions. The points in
  66. each dimension (i.e. every elements of the points tuple) must be
  67. strictly ascending or descending.
  68. values : array_like, shape (m1, ..., mn, ...)
  69. The data on the regular grid in n dimensions. Complex data is
  70. accepted.
  71. method : str, optional
  72. The method of interpolation to perform. Supported are "linear",
  73. "nearest", "slinear", "cubic", "quintic" and "pchip". This
  74. parameter will become the default for the object's ``__call__``
  75. method. Default is "linear".
  76. bounds_error : bool, optional
  77. If True, when interpolated values are requested outside of the
  78. domain of the input data, a ValueError is raised.
  79. If False, then `fill_value` is used.
  80. Default is True.
  81. fill_value : float or None, optional
  82. The value to use for points outside of the interpolation domain.
  83. If None, values outside the domain are extrapolated.
  84. Default is ``np.nan``.
  85. solver : callable, optional
  86. Only used for methods "slinear", "cubic" and "quintic".
  87. Sparse linear algebra solver for construction of the NdBSpline instance.
  88. Default is the iterative solver `scipy.sparse.linalg.gcrotmk`.
  89. .. versionadded:: 1.13
  90. solver_args: dict, optional
  91. Additional arguments to pass to `solver`, if any.
  92. .. versionadded:: 1.13
  93. Methods
  94. -------
  95. __call__
  96. Attributes
  97. ----------
  98. grid : tuple of ndarrays
  99. The points defining the regular grid in n dimensions.
  100. This tuple defines the full grid via
  101. ``np.meshgrid(*grid, indexing='ij')``
  102. values : ndarray
  103. Data values at the grid.
  104. method : str
  105. Interpolation method.
  106. fill_value : float or ``None``
  107. Use this value for out-of-bounds arguments to `__call__`.
  108. bounds_error : bool
  109. If ``True``, out-of-bounds argument raise a ``ValueError``.
  110. Notes
  111. -----
  112. Contrary to `LinearNDInterpolator` and `NearestNDInterpolator`, this class
  113. avoids expensive triangulation of the input data by taking advantage of the
  114. regular grid structure.
  115. In other words, this class assumes that the data is defined on a
  116. *rectilinear* grid.
  117. .. versionadded:: 0.14
  118. The 'slinear'(k=1), 'cubic'(k=3), and 'quintic'(k=5) methods are
  119. tensor-product spline interpolators, where `k` is the spline degree,
  120. If any dimension has fewer points than `k` + 1, an error will be raised.
  121. .. versionadded:: 1.9
  122. If the input data is such that dimensions have incommensurate
  123. units and differ by many orders of magnitude, the interpolant may have
  124. numerical artifacts. Consider rescaling the data before interpolating.
  125. **Choosing a solver for spline methods**
  126. Spline methods, "slinear", "cubic" and "quintic" involve solving a
  127. large sparse linear system at instantiation time. Depending on data,
  128. the default solver may or may not be adequate. When it is not, you may
  129. need to experiment with an optional `solver` argument, where you may
  130. choose between the direct solver (`scipy.sparse.linalg.spsolve`) or
  131. iterative solvers from `scipy.sparse.linalg`. You may need to supply
  132. additional parameters via the optional `solver_args` parameter (for instance,
  133. you may supply the starting value or target tolerance). See the
  134. `scipy.sparse.linalg` documentation for the full list of available options.
  135. Alternatively, you may instead use the legacy methods, "slinear_legacy",
  136. "cubic_legacy" and "quintic_legacy". These methods allow faster construction
  137. but evaluations will be much slower.
  138. **Rounding rule at half points with `nearest` method**
  139. The rounding rule with the `nearest` method at half points is rounding *down*.
  140. Examples
  141. --------
  142. **Evaluate a function on the points of a 3-D grid**
  143. As a first example, we evaluate a simple example function on the points of
  144. a 3-D grid:
  145. >>> from scipy.interpolate import RegularGridInterpolator
  146. >>> import numpy as np
  147. >>> def f(x, y, z):
  148. ... return 2 * x**3 + 3 * y**2 - z
  149. >>> x = np.linspace(1, 4, 11)
  150. >>> y = np.linspace(4, 7, 22)
  151. >>> z = np.linspace(7, 9, 33)
  152. >>> xg, yg ,zg = np.meshgrid(x, y, z, indexing='ij', sparse=True)
  153. >>> data = f(xg, yg, zg)
  154. ``data`` is now a 3-D array with ``data[i, j, k] = f(x[i], y[j], z[k])``.
  155. Next, define an interpolating function from this data:
  156. >>> interp = RegularGridInterpolator((x, y, z), data)
  157. Evaluate the interpolating function at the two points
  158. ``(x,y,z) = (2.1, 6.2, 8.3)`` and ``(3.3, 5.2, 7.1)``:
  159. >>> pts = np.array([[2.1, 6.2, 8.3],
  160. ... [3.3, 5.2, 7.1]])
  161. >>> interp(pts)
  162. array([ 125.80469388, 146.30069388])
  163. which is indeed a close approximation to
  164. >>> f(2.1, 6.2, 8.3), f(3.3, 5.2, 7.1)
  165. (125.54200000000002, 145.894)
  166. **Interpolate and extrapolate a 2D dataset**
  167. As a second example, we interpolate and extrapolate a 2D data set:
  168. >>> x, y = np.array([-2, 0, 4]), np.array([-2, 0, 2, 5])
  169. >>> def ff(x, y):
  170. ... return x**2 + y**2
  171. >>> xg, yg = np.meshgrid(x, y, indexing='ij')
  172. >>> data = ff(xg, yg)
  173. >>> interp = RegularGridInterpolator((x, y), data,
  174. ... bounds_error=False, fill_value=None)
  175. >>> import matplotlib.pyplot as plt
  176. >>> fig = plt.figure()
  177. >>> ax = fig.add_subplot(projection='3d')
  178. >>> ax.scatter(xg.ravel(), yg.ravel(), data.ravel(),
  179. ... s=60, c='k', label='data')
  180. Evaluate and plot the interpolator on a finer grid
  181. >>> xx = np.linspace(-4, 9, 31)
  182. >>> yy = np.linspace(-4, 9, 31)
  183. >>> X, Y = np.meshgrid(xx, yy, indexing='ij')
  184. >>> # interpolator
  185. >>> ax.plot_wireframe(X, Y, interp((X, Y)), rstride=3, cstride=3,
  186. ... alpha=0.4, color='m', label='linear interp')
  187. >>> # ground truth
  188. >>> ax.plot_wireframe(X, Y, ff(X, Y), rstride=3, cstride=3,
  189. ... alpha=0.4, label='ground truth')
  190. >>> plt.legend()
  191. >>> plt.show()
  192. Other examples are given
  193. :ref:`in the tutorial <tutorial-interpolate_regular_grid_interpolator>`.
  194. See Also
  195. --------
  196. NearestNDInterpolator : Nearest neighbor interpolator on *unstructured*
  197. data in N dimensions
  198. LinearNDInterpolator : Piecewise linear interpolator on *unstructured* data
  199. in N dimensions
  200. interpn : a convenience function which wraps `RegularGridInterpolator`
  201. scipy.ndimage.map_coordinates : interpolation on grids with equal spacing
  202. (suitable for e.g., N-D image resampling)
  203. References
  204. ----------
  205. .. [1] Python package *regulargrid* by Johannes Buchner, see
  206. https://pypi.python.org/pypi/regulargrid/
  207. .. [2] Wikipedia, "Trilinear interpolation",
  208. https://en.wikipedia.org/wiki/Trilinear_interpolation
  209. .. [3] Weiser, Alan, and Sergio E. Zarantonello. "A note on piecewise linear
  210. and multilinear table interpolation in many dimensions." MATH.
  211. COMPUT. 50.181 (1988): 189-196.
  212. https://www.ams.org/journals/mcom/1988-50-181/S0025-5718-1988-0917826-0/S0025-5718-1988-0917826-0.pdf
  213. :doi:`10.1090/S0025-5718-1988-0917826-0`
  214. """
  215. # this class is based on code originally programmed by Johannes Buchner,
  216. # see https://github.com/JohannesBuchner/regulargrid
  217. _SPLINE_DEGREE_MAP = {"slinear": 1, "cubic": 3, "quintic": 5, 'pchip': 3,
  218. "slinear_legacy": 1, "cubic_legacy": 3, "quintic_legacy": 5,}
  219. _SPLINE_METHODS_recursive = {"slinear_legacy", "cubic_legacy",
  220. "quintic_legacy", "pchip"}
  221. _SPLINE_METHODS_ndbspl = {"slinear", "cubic", "quintic"}
  222. _SPLINE_METHODS = list(_SPLINE_DEGREE_MAP.keys())
  223. _ALL_METHODS = ["linear", "nearest"] + _SPLINE_METHODS
  224. # generic type compatibility with scipy-stubs
  225. __class_getitem__ = classmethod(GenericAlias)
  226. def __init__(self, points, values, method="linear", bounds_error=True,
  227. fill_value=np.nan, *, solver=None, solver_args=None):
  228. if method not in self._ALL_METHODS:
  229. raise ValueError(f"Method '{method}' is not defined")
  230. elif method in self._SPLINE_METHODS:
  231. self._validate_grid_dimensions(points, method) # NB: uses np.atleast_1d
  232. try:
  233. xp = array_namespace(*points, values)
  234. except Exception as e:
  235. # either "duck-type" values or a user error?
  236. xp = array_namespace(*points) # still forbid mixed namespaces in `points`
  237. try:
  238. xp_v = array_namespace(values)
  239. except Exception:
  240. # "duck-type" values indeed, continue with `xp` as the namespace
  241. pass
  242. else:
  243. # both `points` and `values` are array API objects, check consistency
  244. if xp_v != xp:
  245. raise e
  246. self._asarray = xp.asarray
  247. self.method = method
  248. self._spline = None
  249. self.bounds_error = bounds_error
  250. self._grid, self._descending_dimensions = _check_points(points)
  251. self._values = self._check_values(values)
  252. self._check_dimensionality(self._grid, self._values)
  253. self.fill_value = self._check_fill_value(self._values, fill_value)
  254. if self._descending_dimensions:
  255. self._values = np.flip(values, axis=self._descending_dimensions)
  256. if self.method == "pchip" and np.iscomplexobj(self._values):
  257. msg = ("`PchipInterpolator` only works with real values. If you are trying "
  258. "to use the real components of the passed array, use `np.real` on "
  259. "the array before passing to `RegularGridInterpolator`.")
  260. raise ValueError(msg)
  261. if method in self._SPLINE_METHODS_ndbspl:
  262. if solver_args is None:
  263. solver_args = {}
  264. self._spline = self._construct_spline(method, solver, **solver_args)
  265. else:
  266. if solver is not None or solver_args:
  267. raise ValueError(
  268. f"{method =} does not accept the 'solver' argument. Got "
  269. f" {solver = } and with arguments {solver_args}."
  270. )
  271. def _construct_spline(self, method, solver=None, **solver_args):
  272. if solver is None:
  273. solver = ssl.gcrotmk
  274. spl = make_ndbspl(
  275. self._grid, self._values, self._SPLINE_DEGREE_MAP[method],
  276. solver=solver, **solver_args
  277. )
  278. return spl
  279. def _check_dimensionality(self, grid, values):
  280. _check_dimensionality(grid, values)
  281. def _check_points(self, points):
  282. return _check_points(points)
  283. def _check_values(self, values):
  284. if is_array_api_obj(values):
  285. values = np.asarray(values)
  286. if not hasattr(values, 'ndim'):
  287. # allow reasonable duck-typed values
  288. values = np.asarray(values)
  289. if hasattr(values, 'dtype') and hasattr(values, 'astype'):
  290. if not np.issubdtype(values.dtype, np.inexact):
  291. values = values.astype(float)
  292. return values
  293. def _check_fill_value(self, values, fill_value):
  294. if fill_value is not None:
  295. fill_value_dtype = np.asarray(fill_value).dtype
  296. if (hasattr(values, 'dtype') and not
  297. np.can_cast(fill_value_dtype, values.dtype,
  298. casting='same_kind')):
  299. raise ValueError("fill_value must be either 'None' or "
  300. "of a type compatible with values")
  301. return fill_value
  302. def __call__(self, xi, method=None, *, nu=None):
  303. """
  304. Interpolation at coordinates.
  305. Parameters
  306. ----------
  307. xi : ndarray of shape (..., ndim)
  308. The coordinates to evaluate the interpolator at.
  309. method : str, optional
  310. The method of interpolation to perform. Supported are "linear",
  311. "nearest", "slinear", "cubic", "quintic" and "pchip". Default is
  312. the method chosen when the interpolator was created.
  313. nu : sequence of ints, length ndim, optional
  314. If not None, the orders of the derivatives to evaluate.
  315. Each entry must be non-negative.
  316. Only allowed for methods "slinear", "cubic" and "quintic".
  317. .. versionadded:: 1.13
  318. Returns
  319. -------
  320. values_x : ndarray, shape xi.shape[:-1] + values.shape[ndim:]
  321. Interpolated values at `xi`. See notes for behaviour when
  322. ``xi.ndim == 1``.
  323. Notes
  324. -----
  325. In the case that ``xi.ndim == 1`` a new axis is inserted into
  326. the 0 position of the returned array, values_x, so its shape is
  327. instead ``(1,) + values.shape[ndim:]``.
  328. Examples
  329. --------
  330. Here we define a nearest-neighbor interpolator of a simple function
  331. >>> import numpy as np
  332. >>> x, y = np.array([0, 1, 2]), np.array([1, 3, 7])
  333. >>> def f(x, y):
  334. ... return x**2 + y**2
  335. >>> data = f(*np.meshgrid(x, y, indexing='ij', sparse=True))
  336. >>> from scipy.interpolate import RegularGridInterpolator
  337. >>> interp = RegularGridInterpolator((x, y), data, method='nearest')
  338. By construction, the interpolator uses the nearest-neighbor
  339. interpolation
  340. >>> interp([[1.5, 1.3], [0.3, 4.5]])
  341. array([2., 9.])
  342. We can however evaluate the linear interpolant by overriding the
  343. `method` parameter
  344. >>> interp([[1.5, 1.3], [0.3, 4.5]], method='linear')
  345. array([ 4.7, 24.3])
  346. """
  347. _spline = self._spline
  348. method = self.method if method is None else method
  349. is_method_changed = self.method != method
  350. if method not in self._ALL_METHODS:
  351. raise ValueError(f"Method '{method}' is not defined")
  352. if is_method_changed and method in self._SPLINE_METHODS_ndbspl:
  353. _spline = self._construct_spline(method)
  354. if nu is not None and method not in self._SPLINE_METHODS_ndbspl:
  355. raise ValueError(
  356. f"Can only compute derivatives for methods "
  357. f"{self._SPLINE_METHODS_ndbspl}, got {method =}."
  358. )
  359. xi, xi_shape, ndim, nans, out_of_bounds = self._prepare_xi(xi)
  360. if method == "linear":
  361. indices, norm_distances = self._find_indices(xi.T)
  362. if (ndim == 2 and hasattr(self._values, 'dtype') and
  363. self._values.ndim == 2 and self._values.flags.writeable and
  364. self._values.dtype in (np.float64, np.complex128) and
  365. self._values.dtype.byteorder == '='):
  366. # until cython supports const fused types, the fast path
  367. # cannot support non-writeable values
  368. # a fast path
  369. out = np.empty(indices.shape[1], dtype=self._values.dtype)
  370. result = evaluate_linear_2d(self._values,
  371. indices,
  372. norm_distances,
  373. self._grid,
  374. out)
  375. else:
  376. result = self._evaluate_linear(indices, norm_distances)
  377. elif method == "nearest":
  378. indices, norm_distances = self._find_indices(xi.T)
  379. result = self._evaluate_nearest(indices, norm_distances)
  380. elif method in self._SPLINE_METHODS:
  381. if is_method_changed:
  382. self._validate_grid_dimensions(self._grid, method)
  383. if method in self._SPLINE_METHODS_recursive:
  384. result = self._evaluate_spline(xi, method)
  385. else:
  386. result = _spline(xi, nu=nu)
  387. if not self.bounds_error and self.fill_value is not None:
  388. result[out_of_bounds] = self.fill_value
  389. # f(nan) = nan, if any
  390. if np.any(nans):
  391. result[nans] = np.nan
  392. return self._asarray(result.reshape(xi_shape[:-1] + self._values.shape[ndim:]))
  393. @property
  394. def grid(self):
  395. return tuple(self._asarray(p) for p in self._grid)
  396. @property
  397. def values(self):
  398. return self._asarray(self._values)
  399. def _prepare_xi(self, xi):
  400. ndim = len(self._grid)
  401. xi = _ndim_coords_from_arrays(xi, ndim=ndim)
  402. if xi.shape[-1] != ndim:
  403. raise ValueError("The requested sample points xi have dimension "
  404. f"{xi.shape[-1]} but this "
  405. f"RegularGridInterpolator has dimension {ndim}")
  406. xi_shape = xi.shape
  407. xi = xi.reshape(-1, xi_shape[-1])
  408. xi = np.asarray(xi, dtype=float)
  409. # find nans in input
  410. nans = np.any(np.isnan(xi), axis=-1)
  411. if self.bounds_error:
  412. for i, p in enumerate(xi.T):
  413. if not np.logical_and(np.all(self._grid[i][0] <= p),
  414. np.all(p <= self._grid[i][-1])):
  415. raise ValueError(
  416. f"One of the requested xi is out of bounds in dimension {i}"
  417. )
  418. out_of_bounds = None
  419. else:
  420. out_of_bounds = self._find_out_of_bounds(xi.T)
  421. return xi, xi_shape, ndim, nans, out_of_bounds
  422. def _evaluate_linear(self, indices, norm_distances):
  423. # slice for broadcasting over trailing dimensions in self._values
  424. vslice = (slice(None),) + (None,)*(self._values.ndim - len(indices))
  425. # Compute shifting up front before zipping everything together
  426. shift_norm_distances = [1 - yi for yi in norm_distances]
  427. shift_indices = [i + 1 for i in indices]
  428. # The formula for linear interpolation in 2d takes the form:
  429. # values = self._values[(i0, i1)] * (1 - y0) * (1 - y1) + \
  430. # self._values[(i0, i1 + 1)] * (1 - y0) * y1 + \
  431. # self._values[(i0 + 1, i1)] * y0 * (1 - y1) + \
  432. # self._values[(i0 + 1, i1 + 1)] * y0 * y1
  433. # We pair i with 1 - yi (zipped1) and i + 1 with yi (zipped2)
  434. zipped1 = zip(indices, shift_norm_distances)
  435. zipped2 = zip(shift_indices, norm_distances)
  436. # Take all products of zipped1 and zipped2 and iterate over them
  437. # to get the terms in the above formula. This corresponds to iterating
  438. # over the vertices of a hypercube.
  439. hypercube = itertools.product(*zip(zipped1, zipped2))
  440. value = np.array([0.])
  441. for h in hypercube:
  442. edge_indices, weights = zip(*h)
  443. weight = np.array([1.])
  444. for w in weights:
  445. weight = weight * w
  446. term = np.asarray(self._values[edge_indices]) * weight[vslice]
  447. value = value + term # cannot use += because broadcasting
  448. return value
  449. def _evaluate_nearest(self, indices, norm_distances):
  450. idx_res = [np.where(yi <= .5, i, i + 1)
  451. for i, yi in zip(indices, norm_distances)]
  452. return self._values[tuple(idx_res)]
  453. def _validate_grid_dimensions(self, points, method):
  454. k = self._SPLINE_DEGREE_MAP[method]
  455. for i, point in enumerate(points):
  456. ndim = len(np.atleast_1d(point))
  457. if ndim <= k:
  458. raise ValueError(f"There are {ndim} points in dimension {i},"
  459. f" but method {method} requires at least "
  460. f" {k+1} points per dimension.")
  461. def _evaluate_spline(self, xi, method):
  462. # ensure xi is 2D list of points to evaluate (`m` is the number of
  463. # points and `n` is the number of interpolation dimensions,
  464. # ``n == len(self._grid)``.)
  465. if xi.ndim == 1:
  466. xi = xi.reshape((1, xi.size))
  467. m, n = xi.shape
  468. # Reorder the axes: n-dimensional process iterates over the
  469. # interpolation axes from the last axis downwards: E.g. for a 4D grid
  470. # the order of axes is 3, 2, 1, 0. Each 1D interpolation works along
  471. # the 0th axis of its argument array (for 1D routine it's its ``y``
  472. # array). Thus permute the interpolation axes of `values` *and keep
  473. # trailing dimensions trailing*.
  474. axes = tuple(range(self._values.ndim))
  475. axx = axes[:n][::-1] + axes[n:]
  476. values = self._values.transpose(axx)
  477. if method == 'pchip':
  478. _eval_func = self._do_pchip
  479. else:
  480. _eval_func = self._do_spline_fit
  481. k = self._SPLINE_DEGREE_MAP[method]
  482. # Non-stationary procedure: difficult to vectorize this part entirely
  483. # into numpy-level operations. Unfortunately this requires explicit
  484. # looping over each point in xi.
  485. # can at least vectorize the first pass across all points in the
  486. # last variable of xi.
  487. last_dim = n - 1
  488. first_values = _eval_func(self._grid[last_dim],
  489. values,
  490. xi[:, last_dim],
  491. k)
  492. # the rest of the dimensions have to be on a per point-in-xi basis
  493. shape = (m, *self._values.shape[n:])
  494. result = np.empty(shape, dtype=self._values.dtype)
  495. for j in range(m):
  496. # Main process: Apply 1D interpolate in each dimension
  497. # sequentially, starting with the last dimension.
  498. # These are then "folded" into the next dimension in-place.
  499. folded_values = first_values[j, ...]
  500. for i in range(last_dim-1, -1, -1):
  501. # Interpolate for each 1D from the last dimensions.
  502. # This collapses each 1D sequence into a scalar.
  503. folded_values = _eval_func(self._grid[i],
  504. folded_values,
  505. xi[j, i],
  506. k)
  507. result[j, ...] = folded_values
  508. return result
  509. @staticmethod
  510. def _do_spline_fit(x, y, pt, k):
  511. local_interp = make_interp_spline(x, y, k=k, axis=0)
  512. values = local_interp(pt)
  513. return values
  514. @staticmethod
  515. def _do_pchip(x, y, pt, k):
  516. local_interp = PchipInterpolator(x, y, axis=0)
  517. values = local_interp(pt)
  518. return values
  519. def _find_indices(self, xi):
  520. return find_indices(self._grid, xi)
  521. def _find_out_of_bounds(self, xi):
  522. # check for out of bounds xi
  523. out_of_bounds = np.zeros((xi.shape[1]), dtype=bool)
  524. # iterate through dimensions
  525. for x, grid in zip(xi, self._grid):
  526. out_of_bounds += x < grid[0]
  527. out_of_bounds += x > grid[-1]
  528. return out_of_bounds
  529. def interpn(points, values, xi, method="linear", bounds_error=True,
  530. fill_value=np.nan):
  531. """
  532. Multidimensional interpolation on regular or rectilinear grids.
  533. Strictly speaking, not all regular grids are supported - this function
  534. works on *rectilinear* grids, that is, a rectangular grid with even or
  535. uneven spacing.
  536. Parameters
  537. ----------
  538. points : tuple of ndarray of float, with shapes (m1, ), ..., (mn, )
  539. The points defining the regular grid in n dimensions. The points in
  540. each dimension (i.e. every elements of the points tuple) must be
  541. strictly ascending or descending.
  542. values : array_like, shape (m1, ..., mn, ...)
  543. The data on the regular grid in n dimensions. Complex data is
  544. accepted.
  545. xi : ndarray of shape (..., ndim)
  546. The coordinates to sample the gridded data at
  547. method : str, optional
  548. The method of interpolation to perform. Supported are "linear",
  549. "nearest", "slinear", "cubic", "quintic", "pchip", and "splinef2d".
  550. "splinef2d" is only supported for 2-dimensional data.
  551. bounds_error : bool, optional
  552. If True, when interpolated values are requested outside of the
  553. domain of the input data, a ValueError is raised.
  554. If False, then `fill_value` is used.
  555. fill_value : number, optional
  556. If provided, the value to use for points outside of the
  557. interpolation domain. If None, values outside
  558. the domain are extrapolated. Extrapolation is not supported by method
  559. "splinef2d".
  560. Returns
  561. -------
  562. values_x : ndarray, shape xi.shape[:-1] + values.shape[ndim:]
  563. Interpolated values at `xi`. See notes for behaviour when
  564. ``xi.ndim == 1``.
  565. See Also
  566. --------
  567. NearestNDInterpolator : Nearest neighbor interpolation on unstructured
  568. data in N dimensions
  569. LinearNDInterpolator : Piecewise linear interpolant on unstructured data
  570. in N dimensions
  571. RegularGridInterpolator : interpolation on a regular or rectilinear grid
  572. in arbitrary dimensions (`interpn` wraps this
  573. class).
  574. RectBivariateSpline : Bivariate spline approximation over a rectangular mesh
  575. scipy.ndimage.map_coordinates : interpolation on grids with equal spacing
  576. (suitable for e.g., N-D image resampling)
  577. Notes
  578. -----
  579. .. versionadded:: 0.14
  580. In the case that ``xi.ndim == 1`` a new axis is inserted into
  581. the 0 position of the returned array, values_x, so its shape is
  582. instead ``(1,) + values.shape[ndim:]``.
  583. If the input data is such that input dimensions have incommensurate
  584. units and differ by many orders of magnitude, the interpolant may have
  585. numerical artifacts. Consider rescaling the data before interpolation.
  586. Examples
  587. --------
  588. Evaluate a simple example function on the points of a regular 3-D grid:
  589. >>> import numpy as np
  590. >>> from scipy.interpolate import interpn
  591. >>> def value_func_3d(x, y, z):
  592. ... return 2 * x + 3 * y - z
  593. >>> x = np.linspace(0, 4, 5)
  594. >>> y = np.linspace(0, 5, 6)
  595. >>> z = np.linspace(0, 6, 7)
  596. >>> points = (x, y, z)
  597. >>> values = value_func_3d(*np.meshgrid(*points, indexing='ij'))
  598. Evaluate the interpolating function at a point
  599. >>> point = np.array([2.21, 3.12, 1.15])
  600. >>> print(interpn(points, values, point))
  601. [12.63]
  602. Compare with value at point by function
  603. >>> value_func_3d(*point)
  604. 12.63 # up to rounding
  605. Since the function is linear, the interpolation is exact using linear method.
  606. """
  607. # sanity check 'method' kwarg
  608. if method not in ["linear", "nearest", "cubic", "quintic", "pchip",
  609. "splinef2d", "slinear",
  610. "slinear_legacy", "cubic_legacy", "quintic_legacy"]:
  611. raise ValueError("interpn only understands the methods 'linear', "
  612. "'nearest', 'slinear', 'cubic', 'quintic', 'pchip', "
  613. f"and 'splinef2d'. You provided {method}.")
  614. if not hasattr(values, 'ndim'):
  615. values = np.asarray(values)
  616. ndim = values.ndim
  617. if ndim > 2 and method == "splinef2d":
  618. raise ValueError("The method splinef2d can only be used for "
  619. "2-dimensional input data")
  620. if not bounds_error and fill_value is None and method == "splinef2d":
  621. raise ValueError("The method splinef2d does not support extrapolation.")
  622. # sanity check consistency of input dimensions
  623. if len(points) > ndim:
  624. raise ValueError(
  625. f"There are {len(points)} point arrays, but values has {ndim} dimensions"
  626. )
  627. if len(points) != ndim and method == 'splinef2d':
  628. raise ValueError("The method splinef2d can only be used for "
  629. "scalar data with one point per coordinate")
  630. grid, descending_dimensions = _check_points(points)
  631. _check_dimensionality(grid, values)
  632. # sanity check requested xi
  633. xi = _ndim_coords_from_arrays(xi, ndim=len(grid))
  634. if xi.shape[-1] != len(grid):
  635. raise ValueError(
  636. f"The requested sample points xi have dimension {xi.shape[-1]}, "
  637. f"but this RegularGridInterpolator has dimension {len(grid)}"
  638. )
  639. if bounds_error:
  640. for i, p in enumerate(xi.T):
  641. if not np.logical_and(np.all(grid[i][0] <= p),
  642. np.all(p <= grid[i][-1])):
  643. raise ValueError(
  644. f"One of the requested xi is out of bounds in dimension {i}"
  645. )
  646. # perform interpolation
  647. if method in RegularGridInterpolator._ALL_METHODS:
  648. interp = RegularGridInterpolator(points, values, method=method,
  649. bounds_error=bounds_error,
  650. fill_value=fill_value)
  651. return interp(xi)
  652. elif method == "splinef2d":
  653. xi_shape = xi.shape
  654. xi = xi.reshape(-1, xi.shape[-1])
  655. # RectBivariateSpline doesn't support fill_value; we need to wrap here
  656. idx_valid = np.all((grid[0][0] <= xi[:, 0], xi[:, 0] <= grid[0][-1],
  657. grid[1][0] <= xi[:, 1], xi[:, 1] <= grid[1][-1]),
  658. axis=0)
  659. result = np.empty_like(xi[:, 0])
  660. # make a copy of values for RectBivariateSpline
  661. interp = RectBivariateSpline(points[0], points[1], values[:])
  662. result[idx_valid] = interp.ev(xi[idx_valid, 0], xi[idx_valid, 1])
  663. result[np.logical_not(idx_valid)] = fill_value
  664. return result.reshape(xi_shape[:-1])
  665. else:
  666. raise ValueError(f"unknown {method = }")