_ndgriddata.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  1. """
  2. Convenience interface to N-D interpolation
  3. .. versionadded:: 0.9
  4. """
  5. import numpy as np
  6. from ._interpnd import (LinearNDInterpolator, NDInterpolatorBase,
  7. CloughTocher2DInterpolator, _ndim_coords_from_arrays)
  8. from scipy.spatial import cKDTree
  9. __all__ = ['griddata', 'NearestNDInterpolator', 'LinearNDInterpolator',
  10. 'CloughTocher2DInterpolator']
  11. #------------------------------------------------------------------------------
  12. # Nearest-neighbor interpolation
  13. #------------------------------------------------------------------------------
  14. class NearestNDInterpolator(NDInterpolatorBase):
  15. """Nearest-neighbor interpolator in N > 1 dimensions.
  16. Methods
  17. -------
  18. __call__
  19. Parameters
  20. ----------
  21. x : (npoints, ndims) 2-D ndarray of floats
  22. Data point coordinates.
  23. y : (npoints, ...) N-D ndarray of float or complex
  24. Data values. The length of `y` along the first axis must be equal to
  25. the length of `x`.
  26. rescale : boolean, optional
  27. Rescale points to unit cube before performing interpolation.
  28. This is useful if some of the input dimensions have
  29. incommensurable units and differ by many orders of magnitude.
  30. .. versionadded:: 0.14.0
  31. tree_options : dict, optional
  32. Options passed to the underlying ``cKDTree``.
  33. .. versionadded:: 0.17.0
  34. See Also
  35. --------
  36. griddata :
  37. Interpolate unstructured D-D data.
  38. LinearNDInterpolator :
  39. Piecewise linear interpolator in N dimensions.
  40. CloughTocher2DInterpolator :
  41. Piecewise cubic, C1 smooth, curvature-minimizing interpolator in 2D.
  42. interpn : Interpolation on a regular grid or rectilinear grid.
  43. RegularGridInterpolator : Interpolator on a regular or rectilinear grid
  44. in arbitrary dimensions (`interpn` wraps this
  45. class).
  46. Notes
  47. -----
  48. Uses ``scipy.spatial.cKDTree``
  49. .. note:: For data on a regular grid use `interpn` instead.
  50. Examples
  51. --------
  52. We can interpolate values on a 2D plane:
  53. >>> from scipy.interpolate import NearestNDInterpolator
  54. >>> import numpy as np
  55. >>> import matplotlib.pyplot as plt
  56. >>> rng = np.random.default_rng()
  57. >>> x = rng.random(10) - 0.5
  58. >>> y = rng.random(10) - 0.5
  59. >>> z = np.hypot(x, y)
  60. >>> X = np.linspace(min(x), max(x))
  61. >>> Y = np.linspace(min(y), max(y))
  62. >>> X, Y = np.meshgrid(X, Y) # 2D grid for interpolation
  63. >>> interp = NearestNDInterpolator(list(zip(x, y)), z)
  64. >>> Z = interp(X, Y)
  65. >>> plt.pcolormesh(X, Y, Z, shading='auto')
  66. >>> plt.plot(x, y, "ok", label="input point")
  67. >>> plt.legend()
  68. >>> plt.colorbar()
  69. >>> plt.axis("equal")
  70. >>> plt.show()
  71. """
  72. def __init__(self, x, y, rescale=False, tree_options=None):
  73. NDInterpolatorBase.__init__(self, x, y, rescale=rescale,
  74. need_contiguous=False,
  75. need_values=False)
  76. if tree_options is None:
  77. tree_options = dict()
  78. self.tree = cKDTree(self.points, **tree_options)
  79. self.values = np.asarray(y)
  80. def __call__(self, *args, **query_options):
  81. """
  82. Evaluate interpolator at given points.
  83. Parameters
  84. ----------
  85. x1, x2, ... xn : array-like of float
  86. Points where to interpolate data at.
  87. x1, x2, ... xn can be array-like of float with broadcastable shape.
  88. or x1 can be array-like of float with shape ``(..., ndim)``
  89. **query_options
  90. This allows ``eps``, ``p``, ``distance_upper_bound``, and ``workers``
  91. being passed to the cKDTree's query function to be explicitly set.
  92. See `scipy.spatial.cKDTree.query` for an overview of the different options.
  93. .. versionadded:: 1.12.0
  94. """
  95. # For the sake of enabling subclassing, NDInterpolatorBase._set_xi performs
  96. # some operations which are not required by NearestNDInterpolator.__call__,
  97. # hence here we operate on xi directly, without calling a parent class function.
  98. xi = _ndim_coords_from_arrays(args, ndim=self.points.shape[1])
  99. xi = self._check_call_shape(xi)
  100. xi = self._scale_x(xi)
  101. # We need to handle two important cases:
  102. # (1) the case where xi has trailing dimensions (..., ndim), and
  103. # (2) the case where y has trailing dimensions
  104. # We will first flatten xi to deal with case (1),
  105. # do the computation in flattened array while retaining y's dimensionality,
  106. # and then reshape the interpolated values back to match xi's shape.
  107. # Flatten xi for the query
  108. xi_flat = xi.reshape(-1, xi.shape[-1])
  109. original_shape = xi.shape
  110. flattened_shape = xi_flat.shape
  111. # if distance_upper_bound is set to not be infinite,
  112. # then we need to consider the case where cKDtree
  113. # does not find any points within distance_upper_bound to return.
  114. # It marks those points as having infinte distance, which is what will be used
  115. # below to mask the array and return only the points that were deemed
  116. # to have a close enough neighbor to return something useful.
  117. dist, i = self.tree.query(xi_flat, **query_options)
  118. valid_mask = np.isfinite(dist)
  119. # create a holder interp_values array and fill with nans.
  120. if self.values.ndim > 1:
  121. interp_shape = flattened_shape[:-1] + self.values.shape[1:]
  122. else:
  123. interp_shape = flattened_shape[:-1]
  124. if np.issubdtype(self.values.dtype, np.complexfloating):
  125. interp_values = np.full(interp_shape, np.nan, dtype=self.values.dtype)
  126. else:
  127. interp_values = np.full(interp_shape, np.nan)
  128. interp_values[valid_mask] = self.values[i[valid_mask], ...]
  129. if self.values.ndim > 1:
  130. new_shape = original_shape[:-1] + self.values.shape[1:]
  131. else:
  132. new_shape = original_shape[:-1]
  133. interp_values = interp_values.reshape(new_shape)
  134. return interp_values
  135. #------------------------------------------------------------------------------
  136. # Convenience interface function
  137. #------------------------------------------------------------------------------
  138. def griddata(points, values, xi, method='linear', fill_value=np.nan,
  139. rescale=False):
  140. """
  141. Convenience function for interpolating unstructured data in multiple dimensions.
  142. Parameters
  143. ----------
  144. points : 2-D ndarray of floats with shape (n, D), or length D tuple of 1-D ndarrays with shape (n,).
  145. Data point coordinates.
  146. values : ndarray of float or complex, shape (n,)
  147. Data values.
  148. xi : 2-D ndarray of floats with shape (m, D), or length D tuple of ndarrays broadcastable to the same shape.
  149. Points at which to interpolate data.
  150. method : {'linear', 'nearest', 'cubic'}, optional
  151. Method of interpolation. One of
  152. ``nearest``
  153. return the value at the data point closest to
  154. the point of interpolation. See `NearestNDInterpolator` for
  155. more details.
  156. ``linear``
  157. tessellate the input point set to N-D
  158. simplices, and interpolate linearly on each simplex. See
  159. `LinearNDInterpolator` for more details.
  160. ``cubic`` (1-D)
  161. return the value determined from a cubic
  162. spline.
  163. ``cubic`` (2-D)
  164. return the value determined from a
  165. piecewise cubic, continuously differentiable (C1), and
  166. approximately curvature-minimizing polynomial surface. See
  167. `CloughTocher2DInterpolator` for more details.
  168. fill_value : float, optional
  169. Value used to fill in for requested points outside of the
  170. convex hull of the input points. If not provided, then the
  171. default is ``nan``. This option has no effect for the
  172. 'nearest' method.
  173. rescale : bool, optional
  174. Rescale points to unit cube before performing interpolation.
  175. This is useful if some of the input dimensions have
  176. incommensurable units and differ by many orders of magnitude.
  177. .. versionadded:: 0.14.0
  178. Returns
  179. -------
  180. ndarray
  181. Array of interpolated values.
  182. See Also
  183. --------
  184. LinearNDInterpolator :
  185. Piecewise linear interpolator in N dimensions.
  186. NearestNDInterpolator :
  187. Nearest-neighbor interpolator in N dimensions.
  188. CloughTocher2DInterpolator :
  189. Piecewise cubic, C1 smooth, curvature-minimizing interpolator in 2D.
  190. interpn : Interpolation on a regular grid or rectilinear grid.
  191. RegularGridInterpolator : Interpolator on a regular or rectilinear grid
  192. in arbitrary dimensions (`interpn` wraps this
  193. class).
  194. Notes
  195. -----
  196. .. versionadded:: 0.9
  197. .. note:: For data on a regular grid use `interpn` instead.
  198. Examples
  199. --------
  200. Suppose we want to interpolate the 2-D function
  201. >>> import numpy as np
  202. >>> def func(x, y):
  203. ... return x*(1-x)*np.cos(4*np.pi*x) * np.sin(4*np.pi*y**2)**2
  204. on a grid in [0, 1]x[0, 1]
  205. >>> grid_x, grid_y = np.mgrid[0:1:100j, 0:1:200j]
  206. but we only know its values at 1000 data points:
  207. >>> rng = np.random.default_rng()
  208. >>> points = rng.random((1000, 2))
  209. >>> values = func(points[:,0], points[:,1])
  210. This can be done with `griddata` -- below we try out all of the
  211. interpolation methods:
  212. >>> from scipy.interpolate import griddata
  213. >>> grid_z0 = griddata(points, values, (grid_x, grid_y), method='nearest')
  214. >>> grid_z1 = griddata(points, values, (grid_x, grid_y), method='linear')
  215. >>> grid_z2 = griddata(points, values, (grid_x, grid_y), method='cubic')
  216. One can see that the exact result is reproduced by all of the
  217. methods to some degree, but for this smooth function the piecewise
  218. cubic interpolant gives the best results:
  219. >>> import matplotlib.pyplot as plt
  220. >>> plt.subplot(221)
  221. >>> plt.imshow(func(grid_x, grid_y).T, extent=(0,1,0,1), origin='lower')
  222. >>> plt.plot(points[:,0], points[:,1], 'k.', ms=1)
  223. >>> plt.title('Original')
  224. >>> plt.subplot(222)
  225. >>> plt.imshow(grid_z0.T, extent=(0,1,0,1), origin='lower')
  226. >>> plt.title('Nearest')
  227. >>> plt.subplot(223)
  228. >>> plt.imshow(grid_z1.T, extent=(0,1,0,1), origin='lower')
  229. >>> plt.title('Linear')
  230. >>> plt.subplot(224)
  231. >>> plt.imshow(grid_z2.T, extent=(0,1,0,1), origin='lower')
  232. >>> plt.title('Cubic')
  233. >>> plt.gcf().set_size_inches(6, 6)
  234. >>> plt.show()
  235. """ # numpy/numpydoc#87 # noqa: E501
  236. points = _ndim_coords_from_arrays(points)
  237. if points.ndim < 2:
  238. ndim = points.ndim
  239. else:
  240. ndim = points.shape[-1]
  241. if ndim == 1 and method in ('nearest', 'linear', 'cubic'):
  242. from ._interpolate import interp1d
  243. points = points.ravel()
  244. if isinstance(xi, tuple):
  245. if len(xi) != 1:
  246. raise ValueError("invalid number of dimensions in xi")
  247. xi, = xi
  248. # Sort points/values together, necessary as input for interp1d
  249. idx = np.argsort(points)
  250. points = points[idx]
  251. values = values[idx]
  252. if method == 'nearest':
  253. fill_value = 'extrapolate'
  254. ip = interp1d(points, values, kind=method, axis=0, bounds_error=False,
  255. fill_value=fill_value)
  256. return ip(xi)
  257. elif method == 'nearest':
  258. ip = NearestNDInterpolator(points, values, rescale=rescale)
  259. return ip(xi)
  260. elif method == 'linear':
  261. ip = LinearNDInterpolator(points, values, fill_value=fill_value,
  262. rescale=rescale)
  263. return ip(xi)
  264. elif method == 'cubic' and ndim == 2:
  265. ip = CloughTocher2DInterpolator(points, values, fill_value=fill_value,
  266. rescale=rescale)
  267. return ip(xi)
  268. else:
  269. raise ValueError(
  270. f"Unknown interpolation method {method!r} for {ndim} dimensional data"
  271. )