test_rbf.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. # Created by John Travers, Robert Hetland, 2007
  2. """ Test functions for rbf module """
  3. import numpy as np
  4. from scipy._lib._array_api import assert_array_almost_equal, assert_almost_equal
  5. from numpy import linspace, sin, cos, exp, allclose
  6. from scipy.interpolate._rbf import Rbf
  7. from scipy._lib._testutils import _run_concurrent_barrier
  8. FUNCTIONS = ('multiquadric', 'inverse multiquadric', 'gaussian',
  9. 'cubic', 'quintic', 'thin-plate', 'linear')
  10. def check_rbf1d_interpolation(function):
  11. # Check that the Rbf function interpolates through the nodes (1D)
  12. x = linspace(0,10,9)
  13. y = sin(x)
  14. rbf = Rbf(x, y, function=function)
  15. yi = rbf(x)
  16. assert_array_almost_equal(y, yi)
  17. assert_almost_equal(rbf(float(x[0])), y[0], check_0d=False)
  18. def check_rbf2d_interpolation(function):
  19. # Check that the Rbf function interpolates through the nodes (2D).
  20. rng = np.random.RandomState(1234)
  21. x = rng.rand(50,1)*4-2
  22. y = rng.rand(50,1)*4-2
  23. z = x*exp(-x**2-1j*y**2)
  24. rbf = Rbf(x, y, z, epsilon=2, function=function)
  25. zi = rbf(x, y)
  26. zi = zi.reshape(x.shape)
  27. assert_array_almost_equal(z, zi)
  28. def check_rbf3d_interpolation(function):
  29. # Check that the Rbf function interpolates through the nodes (3D).
  30. rng = np.random.RandomState(1234)
  31. x = rng.rand(50, 1)*4 - 2
  32. y = rng.rand(50, 1)*4 - 2
  33. z = rng.rand(50, 1)*4 - 2
  34. d = x*exp(-x**2 - y**2)
  35. rbf = Rbf(x, y, z, d, epsilon=2, function=function)
  36. di = rbf(x, y, z)
  37. di = di.reshape(x.shape)
  38. assert_array_almost_equal(di, d)
  39. def test_rbf_interpolation():
  40. for function in FUNCTIONS:
  41. check_rbf1d_interpolation(function)
  42. check_rbf2d_interpolation(function)
  43. check_rbf3d_interpolation(function)
  44. def check_2drbf1d_interpolation(function):
  45. # Check that the 2-D Rbf function interpolates through the nodes (1D)
  46. x = linspace(0, 10, 9)
  47. y0 = sin(x)
  48. y1 = cos(x)
  49. y = np.vstack([y0, y1]).T
  50. rbf = Rbf(x, y, function=function, mode='N-D')
  51. yi = rbf(x)
  52. assert_array_almost_equal(y, yi)
  53. assert_almost_equal(rbf(float(x[0])), y[0])
  54. def check_2drbf2d_interpolation(function):
  55. # Check that the 2-D Rbf function interpolates through the nodes (2D).
  56. rng = np.random.RandomState(1234)
  57. x = rng.rand(50, ) * 4 - 2
  58. y = rng.rand(50, ) * 4 - 2
  59. z0 = x * exp(-x ** 2 - 1j * y ** 2)
  60. z1 = y * exp(-y ** 2 - 1j * x ** 2)
  61. z = np.vstack([z0, z1]).T
  62. rbf = Rbf(x, y, z, epsilon=2, function=function, mode='N-D')
  63. zi = rbf(x, y)
  64. zi = zi.reshape(z.shape)
  65. assert_array_almost_equal(z, zi)
  66. def check_2drbf3d_interpolation(function):
  67. # Check that the 2-D Rbf function interpolates through the nodes (3D).
  68. rng = np.random.RandomState(1234)
  69. x = rng.rand(50, ) * 4 - 2
  70. y = rng.rand(50, ) * 4 - 2
  71. z = rng.rand(50, ) * 4 - 2
  72. d0 = x * exp(-x ** 2 - y ** 2)
  73. d1 = y * exp(-y ** 2 - x ** 2)
  74. d = np.vstack([d0, d1]).T
  75. rbf = Rbf(x, y, z, d, epsilon=2, function=function, mode='N-D')
  76. di = rbf(x, y, z)
  77. di = di.reshape(d.shape)
  78. assert_array_almost_equal(di, d)
  79. def test_2drbf_interpolation():
  80. for function in FUNCTIONS:
  81. check_2drbf1d_interpolation(function)
  82. check_2drbf2d_interpolation(function)
  83. check_2drbf3d_interpolation(function)
  84. def check_rbf1d_regularity(function, atol):
  85. # Check that the Rbf function approximates a smooth function well away
  86. # from the nodes.
  87. x = linspace(0, 10, 9)
  88. y = sin(x)
  89. rbf = Rbf(x, y, function=function)
  90. xi = linspace(0, 10, 100)
  91. yi = rbf(xi)
  92. msg = f"abs-diff: {abs(yi - sin(xi)).max():f}"
  93. assert allclose(yi, sin(xi), atol=atol), msg
  94. def test_rbf_regularity():
  95. tolerances = {
  96. 'multiquadric': 0.1,
  97. 'inverse multiquadric': 0.15,
  98. 'gaussian': 0.15,
  99. 'cubic': 0.15,
  100. 'quintic': 0.1,
  101. 'thin-plate': 0.1,
  102. 'linear': 0.2
  103. }
  104. for function in FUNCTIONS:
  105. check_rbf1d_regularity(function, tolerances.get(function, 1e-2))
  106. def check_2drbf1d_regularity(function, atol):
  107. # Check that the 2-D Rbf function approximates a smooth function well away
  108. # from the nodes.
  109. x = linspace(0, 10, 9)
  110. y0 = sin(x)
  111. y1 = cos(x)
  112. y = np.vstack([y0, y1]).T
  113. rbf = Rbf(x, y, function=function, mode='N-D')
  114. xi = linspace(0, 10, 100)
  115. yi = rbf(xi)
  116. msg = f"abs-diff: {abs(yi - np.vstack([sin(xi), cos(xi)]).T).max():f}"
  117. assert allclose(yi, np.vstack([sin(xi), cos(xi)]).T, atol=atol), msg
  118. def test_2drbf_regularity():
  119. tolerances = {
  120. 'multiquadric': 0.1,
  121. 'inverse multiquadric': 0.15,
  122. 'gaussian': 0.15,
  123. 'cubic': 0.15,
  124. 'quintic': 0.1,
  125. 'thin-plate': 0.15,
  126. 'linear': 0.2
  127. }
  128. for function in FUNCTIONS:
  129. check_2drbf1d_regularity(function, tolerances.get(function, 1e-2))
  130. def check_rbf1d_stability(function):
  131. # Check that the Rbf function with default epsilon is not subject
  132. # to overshoot. Regression for issue #4523.
  133. #
  134. # Generate some data (fixed random seed hence deterministic)
  135. rng = np.random.RandomState(1234)
  136. x = np.linspace(0, 10, 50)
  137. z = x + 4.0 * rng.randn(len(x))
  138. rbf = Rbf(x, z, function=function)
  139. xi = np.linspace(0, 10, 1000)
  140. yi = rbf(xi)
  141. # subtract the linear trend and make sure there no spikes
  142. assert np.abs(yi-xi).max() / np.abs(z-x).max() < 1.1
  143. def test_rbf_stability():
  144. for function in FUNCTIONS:
  145. check_rbf1d_stability(function)
  146. def test_default_construction():
  147. # Check that the Rbf class can be constructed with the default
  148. # multiquadric basis function. Regression test for ticket #1228.
  149. x = linspace(0,10,9)
  150. y = sin(x)
  151. rbf = Rbf(x, y)
  152. yi = rbf(x)
  153. assert_array_almost_equal(y, yi)
  154. def test_function_is_callable():
  155. # Check that the Rbf class can be constructed with function=callable.
  156. x = linspace(0,10,9)
  157. y = sin(x)
  158. def linfunc(x):
  159. return x
  160. rbf = Rbf(x, y, function=linfunc)
  161. yi = rbf(x)
  162. assert_array_almost_equal(y, yi)
  163. def test_two_arg_function_is_callable():
  164. # Check that the Rbf class can be constructed with a two argument
  165. # function=callable.
  166. def _func(self, r):
  167. return self.epsilon + r
  168. x = linspace(0,10,9)
  169. y = sin(x)
  170. rbf = Rbf(x, y, function=_func)
  171. yi = rbf(x)
  172. assert_array_almost_equal(y, yi)
  173. def test_rbf_epsilon_none():
  174. x = linspace(0, 10, 9)
  175. y = sin(x)
  176. Rbf(x, y, epsilon=None)
  177. def test_rbf_epsilon_none_collinear():
  178. # Check that collinear points in one dimension doesn't cause an error
  179. # due to epsilon = 0
  180. x = [1, 2, 3]
  181. y = [4, 4, 4]
  182. z = [5, 6, 7]
  183. rbf = Rbf(x, y, z, epsilon=None)
  184. assert rbf.epsilon > 0
  185. def test_rbf_concurrency():
  186. x = linspace(0, 10, 100)
  187. y0 = sin(x)
  188. y1 = cos(x)
  189. y = np.vstack([y0, y1]).T
  190. rbf = Rbf(x, y, mode='N-D')
  191. def worker_fn(_, interp, xp):
  192. interp(xp)
  193. _run_concurrent_barrier(10, worker_fn, rbf, x)