_rbfinterp_np.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. import numpy as np
  2. from numpy.linalg import LinAlgError
  3. from scipy.linalg.lapack import dgesv # type: ignore[attr-defined]
  4. from ._rbfinterp_common import _monomial_powers_impl
  5. from ._rbfinterp_pythran import (
  6. _build_system as _pythran_build_system,
  7. _build_evaluation_coefficients as _pythran_build_evaluation_coefficients,
  8. _polynomial_matrix as _pythran_polynomial_matrix
  9. )
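# Trampolines for the Pythran-compiled functions: the wrappers around the
# `_pythran_*` callables accept the extra `xp` argument and drop it before
# forwarding the remaining arguments.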
  11. def _build_evaluation_coefficients(
  12. x, y, kernel, epsilon, powers, shift, scale, xp
  13. ):
  14. return _pythran_build_evaluation_coefficients(
  15. x, y, kernel, epsilon, powers, shift, scale
  16. )
  17. def polynomial_matrix(x, powers, xp):
  18. return _pythran_polynomial_matrix(x, powers)
  19. def _monomial_powers(ndim, degree, xp):
  20. out = _monomial_powers_impl(ndim, degree)
  21. out = np.asarray(out, dtype=np.int64)
  22. if len(out) == 0:
  23. out = out.reshape(0, ndim)
  24. return out
  25. def _build_system(y, d, smoothing, kernel, epsilon, powers, xp):
  26. return _pythran_build_system(y, d, smoothing, kernel, epsilon, powers)
  27. def _build_and_solve_system(y, d, smoothing, kernel, epsilon, powers, xp):
  28. """Build and solve the RBF interpolation system of equations.
  29. Parameters
  30. ----------
  31. y : (P, N) float ndarray
  32. Data point coordinates.
  33. d : (P, S) float ndarray
  34. Data values at `y`.
  35. smoothing : (P,) float ndarray
  36. Smoothing parameter for each data point.
  37. kernel : str
  38. Name of the RBF.
  39. epsilon : float
  40. Shape parameter.
  41. powers : (R, N) int ndarray
  42. The exponents for each monomial in the polynomial.
  43. Returns
  44. -------
  45. coeffs : (P + R, S) float ndarray
  46. Coefficients for each RBF and monomial.
  47. shift : (N,) float ndarray
  48. Domain shift used to create the polynomial matrix.
  49. scale : (N,) float ndarray
  50. Domain scaling used to create the polynomial matrix.
  51. """
  52. lhs, rhs, shift, scale = _build_system(
  53. y, d, smoothing, kernel, epsilon, powers, xp
  54. )
  55. _, _, coeffs, info = dgesv(lhs, rhs, overwrite_a=True, overwrite_b=True)
  56. if info < 0:
  57. raise ValueError(f"The {-info}-th argument had an illegal value.")
  58. elif info > 0:
  59. msg = "Singular matrix."
  60. nmonos = powers.shape[0]
  61. if nmonos > 0:
  62. pmat = polynomial_matrix((y - shift)/scale, powers, xp)
  63. rank = np.linalg.matrix_rank(pmat)
  64. if rank < nmonos:
  65. msg = (
  66. "Singular matrix. The matrix of monomials evaluated at "
  67. "the data point coordinates does not have full column "
  68. f"rank ({rank}/{nmonos})."
  69. )
  70. raise LinAlgError(msg)
  71. return shift, scale, coeffs
  72. def compute_interpolation(x, y, kernel, epsilon, powers, shift, scale, coeffs, xp):
  73. vec = _build_evaluation_coefficients(
  74. x, y, kernel, epsilon, powers, shift, scale, xp
  75. )
  76. return vec @ coeffs