_rbfinterp_xp.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. """
  2. 'Generic' Array API backend for RBF interpolation.
The general logic is this: `_rbfinterp.py` implements the user API and calls
into either `_rbfinterp_np` (the "numpy backend") or `_rbfinterp_xp` (the
"generic backend").
  6. The numpy backend offloads performance-critical computations to the
  7. pythran-compiled `_rbfinterp_pythran` extension. This way, the call chain is
  8. _rbfinterp.py <-- _rbfinterp_np.py <-- _rbfinterp_pythran.py
  9. The "generic" backend here is a drop-in replacement of the API of
  10. `_rbfinterp_np.py` for use in `_rbfinterp.py` with non-numpy arrays.
  11. The implementation closely follows `_rbfinterp_np + _rbfinterp_pythran`, with
  12. the following differences:
- We use vectorized code instead of explicit loops in `_build_system` and
  `_build_evaluation_coefficients`; this is more torch/jax friendly.
- RBF kernels are also "vectorized" rather than scalar: they receive an
  array of norms instead of a single norm.
- RBF kernels accept an extra xp= argument (the array namespace).
In general, we would prefer less code duplication. The main blocker at the
moment is that pythran cannot compile functions taking an xp= argument with
xp set to numpy.
  20. """
  21. from numpy.linalg import LinAlgError
  22. from ._rbfinterp_common import _monomial_powers_impl
  23. def _monomial_powers(ndim, degree, xp):
  24. out = _monomial_powers_impl(ndim, degree)
  25. out = xp.asarray(out)
  26. if out.shape[0] == 0:
  27. out = xp.reshape(out, (0, ndim))
  28. return out
  29. def _build_and_solve_system(y, d, smoothing, kernel, epsilon, powers, xp):
  30. """Build and solve the RBF interpolation system of equations.
  31. Parameters
  32. ----------
  33. y : (P, N) float ndarray
  34. Data point coordinates.
  35. d : (P, S) float ndarray
  36. Data values at `y`.
  37. smoothing : (P,) float ndarray
  38. Smoothing parameter for each data point.
  39. kernel : str
  40. Name of the RBF.
  41. epsilon : float
  42. Shape parameter.
  43. powers : (R, N) int ndarray
  44. The exponents for each monomial in the polynomial.
  45. Returns
  46. -------
  47. coeffs : (P + R, S) float ndarray
  48. Coefficients for each RBF and monomial.
  49. shift : (N,) float ndarray
  50. Domain shift used to create the polynomial matrix.
  51. scale : (N,) float ndarray
  52. Domain scaling used to create the polynomial matrix.
  53. """
  54. lhs, rhs, shift, scale = _build_system(
  55. y, d, smoothing, kernel, epsilon, powers, xp
  56. )
  57. try:
  58. coeffs = xp.linalg.solve(lhs, rhs)
  59. except Exception:
  60. # Best-effort attempt to emit a helpful message.
  61. # `_rbfinterp_np` backend gives better diagnostics; it is hard to
  62. # match it in a backend-agnostic way: e.g. jax emits no error at all,
  63. # and instead returns an array of nans for a singular `lhs`.
  64. msg = "Singular matrix"
  65. nmonos = powers.shape[0]
  66. if nmonos > 0:
  67. pmat = polynomial_matrix((y - shift)/scale, powers, xp=xp)
  68. rank = xp.linalg.matrix_rank(pmat)
  69. if rank < nmonos:
  70. msg = (
  71. "Singular matrix. The matrix of monomials evaluated at "
  72. "the data point coordinates does not have full column "
  73. f"rank ({rank}/{nmonos})."
  74. )
  75. raise LinAlgError(msg)
  76. return shift, scale, coeffs
  77. def linear(r, xp):
  78. return -r
  79. def thin_plate_spline(r, xp):
  80. # NB: changed w.r.t. pythran, vectorized
  81. return xp.where(r == 0, 0, r**2 * xp.log(r))
  82. def cubic(r, xp):
  83. return r**3
  84. def quintic(r, xp):
  85. return -r**5
  86. def multiquadric(r, xp):
  87. return -xp.sqrt(r**2 + 1)
  88. def inverse_multiquadric(r, xp):
  89. return 1.0 / xp.sqrt(r**2 + 1.0)
  90. def inverse_quadratic(r, xp):
  91. return 1.0 / (r**2 + 1.0)
  92. def gaussian(r, xp):
  93. return xp.exp(-r**2)
  94. NAME_TO_FUNC = {
  95. "linear": linear,
  96. "thin_plate_spline": thin_plate_spline,
  97. "cubic": cubic,
  98. "quintic": quintic,
  99. "multiquadric": multiquadric,
  100. "inverse_multiquadric": inverse_multiquadric,
  101. "inverse_quadratic": inverse_quadratic,
  102. "gaussian": gaussian
  103. }
  104. def kernel_matrix(x, kernel_func, xp):
  105. """Evaluate RBFs, with centers at `x`, at `x`."""
  106. return kernel_func(
  107. xp.linalg.vector_norm(x[None, :, :] - x[:, None, :], axis=-1), xp
  108. )
  109. def polynomial_matrix(x, powers, xp):
  110. """Evaluate monomials, with exponents from `powers`, at `x`."""
  111. return xp.prod(x[:, None, :] ** powers, axis=-1)
  112. def _build_system(y, d, smoothing, kernel, epsilon, powers, xp):
  113. """Build the system used to solve for the RBF interpolant coefficients.
  114. Parameters
  115. ----------
  116. y : (P, N) float ndarray
  117. Data point coordinates.
  118. d : (P, S) float ndarray
  119. Data values at `y`.
  120. smoothing : (P,) float ndarray
  121. Smoothing parameter for each data point.
  122. kernel : str
  123. Name of the RBF.
  124. epsilon : float
  125. Shape parameter.
  126. powers : (R, N) int ndarray
  127. The exponents for each monomial in the polynomial.
  128. Returns
  129. -------
  130. lhs : (P + R, P + R) float ndarray
  131. Left-hand side matrix.
  132. rhs : (P + R, S) float ndarray
  133. Right-hand side matrix.
  134. shift : (N,) float ndarray
  135. Domain shift used to create the polynomial matrix.
  136. scale : (N,) float ndarray
  137. Domain scaling used to create the polynomial matrix.
  138. """
  139. s = d.shape[1]
  140. r = powers.shape[0]
  141. kernel_func = NAME_TO_FUNC[kernel]
  142. # Shift and scale the polynomial domain to be between -1 and 1
  143. mins = xp.min(y, axis=0)
  144. maxs = xp.max(y, axis=0)
  145. shift = (maxs + mins)/2
  146. scale = (maxs - mins)/2
  147. # The scale may be zero if there is a single point or all the points have
  148. # the same value for some dimension. Avoid division by zero by replacing
  149. # zeros with ones.
  150. scale = xp.where(scale == 0.0, 1.0, scale)
  151. yeps = y*epsilon
  152. yhat = (y - shift)/scale
  153. out_kernels = kernel_matrix(yeps, kernel_func, xp)
  154. out_poly = polynomial_matrix(yhat, powers, xp)
  155. lhs = xp.concat(
  156. [
  157. xp.concat((out_kernels, out_poly), axis=1),
  158. xp.concat((out_poly.T, xp.zeros((r, r))), axis=1)
  159. ]
  160. , axis=0) + xp.diag(xp.concat([smoothing, xp.zeros(r)]))
  161. rhs = xp.concat([d, xp.zeros((r, s))], axis=0)
  162. return lhs, rhs, shift, scale
  163. def _build_evaluation_coefficients(
  164. x, y, kernel, epsilon, powers, shift, scale, xp
  165. ):
  166. """Construct the coefficients needed to evaluate
  167. the RBF.
  168. Parameters
  169. ----------
  170. x : (Q, N) float ndarray
  171. Evaluation point coordinates.
  172. y : (P, N) float ndarray
  173. Data point coordinates.
  174. kernel : str
  175. Name of the RBF.
  176. epsilon : float
  177. Shape parameter.
  178. powers : (R, N) int ndarray
  179. The exponents for each monomial in the polynomial.
  180. shift : (N,) float ndarray
  181. Shifts the polynomial domain for numerical stability.
  182. scale : (N,) float ndarray
  183. Scales the polynomial domain for numerical stability.
  184. Returns
  185. -------
  186. (Q, P + R) float ndarray
  187. """
  188. kernel_func = NAME_TO_FUNC[kernel]
  189. yeps = y*epsilon
  190. xeps = x*epsilon
  191. xhat = (x - shift)/scale
  192. # NB: changed w.r.t. pythran
  193. vec = xp.concat(
  194. [
  195. kernel_func(
  196. xp.linalg.vector_norm(
  197. xeps[:, None, :] - yeps[None, :, :], axis=-1
  198. ), xp
  199. ),
  200. xp.prod(xhat[:, None, :] ** powers, axis=-1)
  201. ], axis=-1
  202. )
  203. return vec
  204. def compute_interpolation(x, y, kernel, epsilon, powers, shift, scale, coeffs, xp):
  205. vec = _build_evaluation_coefficients(
  206. x, y, kernel, epsilon, powers, shift, scale, xp
  207. )
  208. return vec @ coeffs