test_ndgriddata.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. import numpy as np
  2. from scipy._lib._array_api import (
  3. xp_assert_equal, xp_assert_close
  4. )
  5. import pytest
  6. from pytest import raises as assert_raises
  7. from scipy.interpolate import (griddata, NearestNDInterpolator,
  8. LinearNDInterpolator,
  9. CloughTocher2DInterpolator)
  10. from scipy._lib._testutils import _run_concurrent_barrier
  11. parametrize_interpolators = pytest.mark.parametrize(
  12. "interpolator", [NearestNDInterpolator, LinearNDInterpolator,
  13. CloughTocher2DInterpolator]
  14. )
  15. parametrize_methods = pytest.mark.parametrize(
  16. 'method',
  17. ('nearest', 'linear', 'cubic'),
  18. )
  19. parametrize_rescale = pytest.mark.parametrize(
  20. 'rescale',
  21. (True, False),
  22. )
  23. class TestGriddata:
  24. def test_fill_value(self):
  25. x = [(0,0), (0,1), (1,0)]
  26. y = [1, 2, 3]
  27. yi = griddata(x, y, [(1,1), (1,2), (0,0)], fill_value=-1)
  28. xp_assert_equal(yi, [-1., -1, 1])
  29. yi = griddata(x, y, [(1,1), (1,2), (0,0)])
  30. xp_assert_equal(yi, [np.nan, np.nan, 1])
  31. @parametrize_methods
  32. @parametrize_rescale
  33. def test_alternative_call(self, method, rescale):
  34. x = np.array([(0,0), (-0.5,-0.5), (-0.5,0.5), (0.5, 0.5), (0.25, 0.3)],
  35. dtype=np.float64)
  36. y = (np.arange(x.shape[0], dtype=np.float64)[:,None]
  37. + np.array([0,1])[None,:])
  38. msg = repr((method, rescale))
  39. yi = griddata((x[:,0], x[:,1]), y, (x[:,0], x[:,1]), method=method,
  40. rescale=rescale)
  41. xp_assert_close(y, yi, atol=1e-14, err_msg=msg)
  42. @parametrize_methods
  43. @parametrize_rescale
  44. def test_multivalue_2d(self, method, rescale):
  45. x = np.array([(0,0), (-0.5,-0.5), (-0.5,0.5), (0.5, 0.5), (0.25, 0.3)],
  46. dtype=np.float64)
  47. y = (np.arange(x.shape[0], dtype=np.float64)[:,None]
  48. + np.array([0,1])[None,:])
  49. msg = repr((method, rescale))
  50. yi = griddata(x, y, x, method=method, rescale=rescale)
  51. xp_assert_close(y, yi, atol=1e-14, err_msg=msg)
  52. @parametrize_methods
  53. @parametrize_rescale
  54. def test_multipoint_2d(self, method, rescale):
  55. x = np.array([(0,0), (-0.5,-0.5), (-0.5,0.5), (0.5, 0.5), (0.25, 0.3)],
  56. dtype=np.float64)
  57. y = np.arange(x.shape[0], dtype=np.float64)
  58. xi = x[:,None,:] + np.array([0,0,0])[None,:,None]
  59. msg = repr((method, rescale))
  60. yi = griddata(x, y, xi, method=method, rescale=rescale)
  61. assert yi.shape == (5, 3), msg
  62. xp_assert_close(yi, np.tile(y[:,None], (1, 3)),
  63. atol=1e-14, err_msg=msg)
  64. @parametrize_methods
  65. @parametrize_rescale
  66. def test_complex_2d(self, method, rescale):
  67. x = np.array([(0,0), (-0.5,-0.5), (-0.5,0.5), (0.5, 0.5), (0.25, 0.3)],
  68. dtype=np.float64)
  69. y = np.arange(x.shape[0], dtype=np.float64)
  70. y = y - 2j*y[::-1]
  71. xi = x[:,None,:] + np.array([0,0,0])[None,:,None]
  72. msg = repr((method, rescale))
  73. yi = griddata(x, y, xi, method=method, rescale=rescale)
  74. assert yi.shape == (5, 3)
  75. xp_assert_close(yi, np.tile(y[:,None], (1, 3)),
  76. atol=1e-14, err_msg=msg)
  77. @parametrize_methods
  78. def test_1d(self, method):
  79. x = np.array([1, 2.5, 3, 4.5, 5, 6])
  80. y = np.array([1, 2, 0, 3.9, 2, 1])
  81. xp_assert_close(griddata(x, y, x, method=method), y,
  82. err_msg=method, atol=1e-14)
  83. xp_assert_close(griddata(x.reshape(6, 1), y, x, method=method), y,
  84. err_msg=method, atol=1e-14)
  85. xp_assert_close(griddata((x,), y, (x,), method=method), y,
  86. err_msg=method, atol=1e-14)
  87. def test_1d_borders(self):
  88. # Test for nearest neighbor case with xi outside
  89. # the range of the values.
  90. x = np.array([1, 2.5, 3, 4.5, 5, 6])
  91. y = np.array([1, 2, 0, 3.9, 2, 1])
  92. xi = np.array([0.9, 6.5])
  93. yi_should = np.array([1.0, 1.0])
  94. method = 'nearest'
  95. xp_assert_close(griddata(x, y, xi,
  96. method=method), yi_should,
  97. err_msg=method,
  98. atol=1e-14)
  99. xp_assert_close(griddata(x.reshape(6, 1), y, xi,
  100. method=method), yi_should,
  101. err_msg=method,
  102. atol=1e-14)
  103. xp_assert_close(griddata((x, ), y, (xi, ),
  104. method=method), yi_should,
  105. err_msg=method,
  106. atol=1e-14)
  107. @parametrize_methods
  108. def test_1d_unsorted(self, method):
  109. x = np.array([2.5, 1, 4.5, 5, 6, 3])
  110. y = np.array([1, 2, 0, 3.9, 2, 1])
  111. xp_assert_close(griddata(x, y, x, method=method), y,
  112. err_msg=method, atol=1e-10)
  113. xp_assert_close(griddata(x.reshape(6, 1), y, x, method=method), y,
  114. err_msg=method, atol=1e-10)
  115. xp_assert_close(griddata((x,), y, (x,), method=method), y,
  116. err_msg=method, atol=1e-10)
  117. @parametrize_methods
  118. def test_square_rescale_manual(self, method):
  119. points = np.array([(0,0), (0,100), (10,100), (10,0), (1, 5)], dtype=np.float64)
  120. points_rescaled = np.array([(0,0), (0,1), (1,1), (1,0), (0.1, 0.05)],
  121. dtype=np.float64)
  122. values = np.array([1., 2., -3., 5., 9.], dtype=np.float64)
  123. xx, yy = np.broadcast_arrays(np.linspace(0, 10, 14)[:,None],
  124. np.linspace(0, 100, 14)[None,:])
  125. xx = xx.ravel()
  126. yy = yy.ravel()
  127. xi = np.array([xx, yy]).T.copy()
  128. msg = method
  129. zi = griddata(points_rescaled, values, xi/np.array([10, 100.]),
  130. method=method)
  131. zi_rescaled = griddata(points, values, xi, method=method,
  132. rescale=True)
  133. xp_assert_close(zi, zi_rescaled, err_msg=msg,
  134. atol=1e-12)
  135. @parametrize_methods
  136. def test_xi_1d(self, method):
  137. # Check that 1-D xi is interpreted as a coordinate
  138. x = np.array([(0,0), (-0.5,-0.5), (-0.5,0.5), (0.5, 0.5), (0.25, 0.3)],
  139. dtype=np.float64)
  140. y = np.arange(x.shape[0], dtype=np.float64)
  141. y = y - 2j*y[::-1]
  142. xi = np.array([0.5, 0.5])
  143. p1 = griddata(x, y, xi, method=method)
  144. p2 = griddata(x, y, xi[None,:], method=method)
  145. xp_assert_close(p1, p2, err_msg=method)
  146. xi1 = np.array([0.5])
  147. xi3 = np.array([0.5, 0.5, 0.5])
  148. assert_raises(ValueError, griddata, x, y, xi1,
  149. method=method)
  150. assert_raises(ValueError, griddata, x, y, xi3,
  151. method=method)
  152. class TestNearestNDInterpolator:
  153. def test_nearest_options(self):
  154. # smoke test that NearestNDInterpolator accept cKDTree options
  155. npts, nd = 4, 3
  156. x = np.arange(npts*nd).reshape((npts, nd))
  157. y = np.arange(npts)
  158. nndi = NearestNDInterpolator(x, y)
  159. opts = {'balanced_tree': False, 'compact_nodes': False}
  160. nndi_o = NearestNDInterpolator(x, y, tree_options=opts)
  161. xp_assert_close(nndi(x), nndi_o(x), atol=1e-14)
  162. def test_nearest_list_argument(self):
  163. nd = np.array([[0, 0, 0, 0, 1, 0, 1],
  164. [0, 0, 0, 0, 0, 1, 1],
  165. [0, 0, 0, 0, 1, 1, 2]])
  166. d = nd[:, 3:]
  167. # z is np.array
  168. NI = NearestNDInterpolator((d[0], d[1]), d[2])
  169. xp_assert_equal(NI([0.1, 0.9], [0.1, 0.9]), [0.0, 2.0])
  170. # z is list
  171. NI = NearestNDInterpolator((d[0], d[1]), list(d[2]))
  172. xp_assert_equal(NI([0.1, 0.9], [0.1, 0.9]), [0.0, 2.0])
  173. def test_nearest_query_options(self):
  174. nd = np.array([[0, 0.5, 0, 1],
  175. [0, 0, 0.5, 1],
  176. [0, 1, 1, 2]])
  177. delta = 0.1
  178. query_points = [0 + delta, 1 + delta], [0 + delta, 1 + delta]
  179. # case 1 - query max_dist is smaller than
  180. # the query points' nearest distance to nd.
  181. NI = NearestNDInterpolator((nd[0], nd[1]), nd[2])
  182. distance_upper_bound = np.sqrt(delta ** 2 + delta ** 2) - 1e-7
  183. xp_assert_equal(NI(query_points, distance_upper_bound=distance_upper_bound),
  184. [np.nan, np.nan])
  185. # case 2 - query p is inf, will return [0, 2]
  186. distance_upper_bound = np.sqrt(delta ** 2 + delta ** 2) - 1e-7
  187. p = np.inf
  188. xp_assert_equal(
  189. NI(query_points, distance_upper_bound=distance_upper_bound, p=p),
  190. [0.0, 2.0]
  191. )
  192. # case 3 - query max_dist is larger, so should return non np.nan
  193. distance_upper_bound = np.sqrt(delta ** 2 + delta ** 2) + 1e-7
  194. xp_assert_equal(
  195. NI(query_points, distance_upper_bound=distance_upper_bound),
  196. [0.0, 2.0]
  197. )
  198. def test_nearest_query_valid_inputs(self):
  199. nd = np.array([[0, 1, 0, 1],
  200. [0, 0, 1, 1],
  201. [0, 1, 1, 2]])
  202. NI = NearestNDInterpolator((nd[0], nd[1]), nd[2])
  203. with assert_raises(TypeError):
  204. NI([0.5, 0.5], query_options="not a dictionary")
  205. def test_concurrency(self):
  206. npts, nd = 50, 3
  207. x = np.arange(npts * nd).reshape((npts, nd))
  208. y = np.arange(npts)
  209. nndi = NearestNDInterpolator(x, y)
  210. def worker_fn(_, spl):
  211. spl(x)
  212. _run_concurrent_barrier(10, worker_fn, nndi)
  213. class TestNDInterpolators:
  214. @parametrize_interpolators
  215. def test_broadcastable_input(self, interpolator):
  216. # input data
  217. rng = np.random.RandomState(0)
  218. x = rng.random(10)
  219. y = rng.random(10)
  220. z = np.hypot(x, y)
  221. # x-y grid for interpolation
  222. X = np.linspace(min(x), max(x))
  223. Y = np.linspace(min(y), max(y))
  224. X, Y = np.meshgrid(X, Y)
  225. XY = np.vstack((X.ravel(), Y.ravel())).T
  226. interp = interpolator(list(zip(x, y)), z)
  227. # single array input
  228. interp_points0 = interp(XY)
  229. # tuple input
  230. interp_points1 = interp((X, Y))
  231. interp_points2 = interp((X, 0.0))
  232. # broadcastable input
  233. interp_points3 = interp(X, Y)
  234. interp_points4 = interp(X, 0.0)
  235. assert (interp_points0.size ==
  236. interp_points1.size ==
  237. interp_points2.size ==
  238. interp_points3.size ==
  239. interp_points4.size)
  240. @parametrize_interpolators
  241. def test_read_only(self, interpolator):
  242. # input data
  243. rng = np.random.RandomState(0)
  244. xy = rng.random((10, 2))
  245. x, y = xy[:, 0], xy[:, 1]
  246. z = np.hypot(x, y)
  247. # interpolation points
  248. XY = rng.random((50, 2))
  249. xy.setflags(write=False)
  250. z.setflags(write=False)
  251. XY.setflags(write=False)
  252. interp = interpolator(xy, z)
  253. interp(XY)