_cubic.py 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029
  1. """Interpolation algorithms using piecewise cubic polynomials."""
  2. from typing import Literal
  3. import numpy as np
  4. from scipy.linalg import solve, solve_banded
  5. from scipy._lib._array_api import array_namespace, xp_size, xp_capabilities
  6. from scipy._lib.array_api_compat import numpy as np_compat
  7. from . import PPoly
  8. from ._polyint import _isscalar
# Public names of this module, i.e. what ``from <this module> import *`` exports.
__all__ = ["CubicHermiteSpline", "PchipInterpolator", "pchip_interpolate",
           "Akima1DInterpolator", "CubicSpline"]
  11. def prepare_input(x, y, axis, dydx=None, xp=None):
  12. """Prepare input for cubic spline interpolators.
  13. All data are converted to numpy arrays and checked for correctness.
  14. Axes equal to `axis` of arrays `y` and `dydx` are moved to be the 0th
  15. axis. The value of `axis` is converted to lie in
  16. [0, number of dimensions of `y`).
  17. """
  18. x, y = map(xp.asarray, (x, y))
  19. if xp.isdtype(x.dtype, "complex floating"):
  20. raise ValueError("`x` must contain real values.")
  21. x = xp.astype(x, xp.float64)
  22. if xp.isdtype(y.dtype, "complex floating"):
  23. dtype = xp.complex128
  24. else:
  25. dtype = xp.float64
  26. if dydx is not None:
  27. dydx = xp.asarray(dydx)
  28. if y.shape != dydx.shape:
  29. raise ValueError("The shapes of `y` and `dydx` must be identical.")
  30. if xp.isdtype(dydx.dtype, "complex floating"):
  31. dtype = xp.complex128
  32. dydx = xp.astype(dydx, dtype, copy=False)
  33. y = xp.astype(y, dtype, copy=False)
  34. axis = axis % y.ndim
  35. if x.ndim != 1:
  36. raise ValueError("`x` must be 1-dimensional.")
  37. if x.shape[0] < 2:
  38. raise ValueError("`x` must contain at least 2 elements.")
  39. if x.shape[0] != y.shape[axis]:
  40. raise ValueError(f"The length of `y` along `axis`={axis} doesn't "
  41. "match the length of `x`")
  42. if not xp.all(xp.isfinite(x)):
  43. raise ValueError("`x` must contain only finite values.")
  44. if not xp.all(xp.isfinite(y)):
  45. raise ValueError("`y` must contain only finite values.")
  46. if dydx is not None and not xp.all(xp.isfinite(dydx)):
  47. raise ValueError("`dydx` must contain only finite values.")
  48. dx = xp.diff(x)
  49. if xp.any(dx <= 0):
  50. raise ValueError("`x` must be strictly increasing sequence.")
  51. y = xp.moveaxis(y, axis, 0)
  52. if dydx is not None:
  53. dydx = xp.moveaxis(dydx, axis, 0)
  54. return x, dx, y, axis, dydx
  55. @xp_capabilities(
  56. cpu_only=True, jax_jit=False,
  57. skip_backends=[
  58. ("dask.array",
  59. "https://github.com/data-apis/array-api-extra/issues/488")
  60. ]
  61. )
  62. class CubicHermiteSpline(PPoly):
  63. """Piecewise cubic interpolator to fit values and first derivatives (C1 smooth).
  64. The result is represented as a `PPoly` instance.
  65. Parameters
  66. ----------
  67. x : array_like, shape (n,)
  68. 1-D array containing values of the independent variable.
  69. Values must be real, finite and in strictly increasing order.
  70. y : array_like
  71. Array containing values of the dependent variable. It can have
  72. arbitrary number of dimensions, but the length along ``axis``
  73. (see below) must match the length of ``x``. Values must be finite.
  74. dydx : array_like
  75. Array containing derivatives of the dependent variable. It can have
  76. arbitrary number of dimensions, but the length along ``axis``
  77. (see below) must match the length of ``x``. Values must be finite.
  78. axis : int, optional
  79. Axis along which `y` is assumed to be varying. Meaning that for
  80. ``x[i]`` the corresponding values are ``np.take(y, i, axis=axis)``.
  81. Default is 0.
  82. extrapolate : {bool, 'periodic', None}, optional
  83. If bool, determines whether to extrapolate to out-of-bounds points
  84. based on first and last intervals, or to return NaNs. If 'periodic',
  85. periodic extrapolation is used. If None (default), it is set to True.
  86. Attributes
  87. ----------
  88. x : ndarray, shape (n,)
  89. Breakpoints. The same ``x`` which was passed to the constructor.
  90. c : ndarray, shape (4, n-1, ...)
  91. Coefficients of the polynomials on each segment. The trailing
  92. dimensions match the dimensions of `y`, excluding ``axis``.
  93. For example, if `y` is 1-D, then ``c[k, i]`` is a coefficient for
  94. ``(x-x[i])**(3-k)`` on the segment between ``x[i]`` and ``x[i+1]``.
  95. axis : int
  96. Interpolation axis. The same axis which was passed to the
  97. constructor.
  98. Methods
  99. -------
  100. __call__
  101. derivative
  102. antiderivative
  103. integrate
  104. solve
  105. roots
  106. See Also
  107. --------
  108. Akima1DInterpolator : Akima 1D interpolator.
  109. PchipInterpolator : PCHIP 1-D monotonic cubic interpolator.
  110. CubicSpline : Cubic spline data interpolator.
  111. PPoly : Piecewise polynomial in terms of coefficients and breakpoints
  112. Notes
  113. -----
  114. If you want to create a higher-order spline matching higher-order
  115. derivatives, use `BPoly.from_derivatives`.
  116. References
  117. ----------
  118. .. [1] `Cubic Hermite spline
  119. <https://en.wikipedia.org/wiki/Cubic_Hermite_spline>`_
  120. on Wikipedia.
  121. """
  122. def __init__(self, x, y, dydx, axis=0, extrapolate=None):
  123. xp = array_namespace(x, y, dydx)
  124. if extrapolate is None:
  125. extrapolate = True
  126. x, dx, y, axis, dydx = prepare_input(x, y, axis, dydx, xp=xp)
  127. dxr = xp.reshape(dx, (dx.shape[0], ) + (1, ) * (y.ndim - 1))
  128. slope = xp.diff(y, axis=0) / dxr
  129. t = (dydx[:-1, ...] + dydx[1:, ...] - 2 * slope) / dxr
  130. c = xp.stack((
  131. t / dxr,
  132. (slope - dydx[:-1, ...]) / dxr - t,
  133. dydx[:-1, ...],
  134. y[:-1, ...]
  135. ))
  136. super().__init__(c, x, extrapolate=extrapolate)
  137. self.axis = axis
  138. @xp_capabilities(
  139. cpu_only=True, jax_jit=False,
  140. skip_backends=[
  141. ("dask.array",
  142. "https://github.com/data-apis/array-api-extra/issues/488")
  143. ]
  144. )
  145. class PchipInterpolator(CubicHermiteSpline):
  146. r"""PCHIP shape-preserving interpolator (C1 smooth).
  147. ``x`` and ``y`` are arrays of values used to approximate some function f,
  148. with ``y = f(x)``. The interpolant uses monotonic cubic splines
  149. to find the value of new points. (PCHIP stands for Piecewise Cubic
  150. Hermite Interpolating Polynomial).
  151. Parameters
  152. ----------
  153. x : ndarray, shape (npoints, )
  154. A 1-D array of monotonically increasing real values. ``x`` cannot
  155. include duplicate values (otherwise f is overspecified)
  156. y : ndarray, shape (..., npoints, ...)
  157. An N-D array of real values. ``y``'s length along the interpolation
  158. axis must be equal to the length of ``x``. Use the ``axis``
  159. parameter to select the interpolation axis.
  160. axis : int, optional
  161. Axis in the ``y`` array corresponding to the x-coordinate values. Defaults
  162. to ``axis=0``.
  163. extrapolate : bool, optional
  164. Whether to extrapolate to out-of-bounds points based on first
  165. and last intervals, or to return NaNs.
  166. Methods
  167. -------
  168. __call__
  169. derivative
  170. antiderivative
  171. integrate
  172. solve
  173. roots
  174. See Also
  175. --------
  176. CubicHermiteSpline : Piecewise-cubic interpolator.
  177. Akima1DInterpolator : Akima 1D interpolator.
  178. CubicSpline : Cubic spline data interpolator.
  179. PPoly : Piecewise polynomial in terms of coefficients and breakpoints.
  180. Notes
  181. -----
  182. The interpolator preserves monotonicity in the interpolation data and does
  183. not overshoot if the data is not smooth.
  184. The first derivatives are guaranteed to be continuous, but the second
  185. derivatives may jump at :math:`x_k`.
  186. Determines the derivatives at the points :math:`x_k`, :math:`f'_k`,
  187. by using PCHIP algorithm [1]_.
  188. Let :math:`h_k = x_{k+1} - x_k`, and :math:`d_k = (y_{k+1} - y_k) / h_k`
  189. are the slopes at internal points :math:`x_k`.
  190. If the signs of :math:`d_k` and :math:`d_{k-1}` are different or either of
  191. them equals zero, then :math:`f'_k = 0`. Otherwise, it is given by the
  192. weighted harmonic mean
  193. .. math::
  194. \frac{w_1 + w_2}{f'_k} = \frac{w_1}{d_{k-1}} + \frac{w_2}{d_k}
  195. where :math:`w_1 = 2 h_k + h_{k-1}` and :math:`w_2 = h_k + 2 h_{k-1}`.
  196. The end slopes are set using a one-sided scheme [2]_.
  197. References
  198. ----------
  199. .. [1] F. N. Fritsch and J. Butland,
  200. A method for constructing local
  201. monotone piecewise cubic interpolants,
  202. SIAM J. Sci. Comput., 5(2), 300-304 (1984).
  203. :doi:`10.1137/0905021`.
  204. .. [2] C. Moler, Numerical Computing with Matlab, 2004.
  205. :doi:`10.1137/1.9780898717952`
  206. """
  207. # PchipInterpolator is not generic in scipy-stubs
  208. __class_getitem__ = None
  209. def __init__(self, x, y, axis=0, extrapolate=None):
  210. xp = array_namespace(x, y)
  211. x, _, y, axis, _ = prepare_input(x, y, axis, xp=xp)
  212. if xp.isdtype(y.dtype, "complex floating"):
  213. msg = ("`PchipInterpolator` only works with real values for `y`. "
  214. "If you are trying to use the real components of the passed array, "
  215. "use `np.real` on the array before passing to `PchipInterpolator`.")
  216. raise ValueError(msg)
  217. xv = xp.reshape(x, (x.shape[0],) + (1,)*(y.ndim-1))
  218. dk = self._find_derivatives(xv, y, xp=xp)
  219. super().__init__(x, y, dk, axis=0, extrapolate=extrapolate)
  220. self.axis = axis
  221. @staticmethod
  222. def _edge_case(h0, h1, m0, m1, xp):
  223. # one-sided three-point estimate for the derivative
  224. d = ((2*h0 + h1)*m0 - h0*m1) / (h0 + h1)
  225. # try to preserve shape
  226. mask = xp.sign(d) != xp.sign(m0)
  227. mask2 = (xp.sign(m0) != xp.sign(m1)) & (xp.abs(d) > 3.*xp.abs(m0))
  228. mmm = (~mask) & mask2
  229. d[mask] = 0.
  230. d[mmm] = 3.*m0[mmm]
  231. return d
  232. @staticmethod
  233. def _find_derivatives(x, y, xp):
  234. # Determine the derivatives at the points y_k, d_k, by using
  235. # PCHIP algorithm is:
  236. # We choose the derivatives at the point x_k by
  237. # Let m_k be the slope of the kth segment (between k and k+1)
  238. # If m_k=0 or m_{k-1}=0 or sgn(m_k) != sgn(m_{k-1}) then d_k == 0
  239. # else use weighted harmonic mean:
  240. # w_1 = 2h_k + h_{k-1}, w_2 = h_k + 2h_{k-1}
  241. # 1/d_k = 1/(w_1 + w_2)*(w_1 / m_k + w_2 / m_{k-1})
  242. # where h_k is the spacing between x_k and x_{k+1}
  243. y_shape = y.shape
  244. if y.ndim == 1:
  245. # So that _edge_case doesn't end up assigning to scalars
  246. x = x[:, None]
  247. y = y[:, None]
  248. hk = x[1:] - x[:-1]
  249. mk = (y[1:] - y[:-1]) / hk
  250. if y.shape[0] == 2:
  251. # edge case: only have two points, use linear interpolation
  252. dk = xp.zeros_like(y)
  253. dk[0] = mk
  254. dk[1] = mk
  255. return xp.reshape(dk, y_shape)
  256. smk = xp.sign(mk)
  257. condition = (smk[1:] != smk[:-1]) | (mk[1:] == 0) | (mk[:-1] == 0)
  258. w1 = 2*hk[1:] + hk[:-1]
  259. w2 = hk[1:] + 2*hk[:-1]
  260. # values where division by zero occurs will be excluded
  261. # by 'condition' afterwards
  262. with np.errstate(divide='ignore', invalid='ignore'):
  263. whmean = (w1/mk[:-1] + w2/mk[1:]) / (w1 + w2)
  264. dk = np.zeros_like(y)
  265. dk[1:-1][condition] = 0.0
  266. dk[1:-1][~condition] = 1.0 / whmean[~condition]
  267. # special case endpoints, as suggested in
  268. # Cleve Moler, Numerical Computing with MATLAB, Chap 3.6 (pchiptx.m)
  269. dk[0] = PchipInterpolator._edge_case(hk[0], hk[1], mk[0], mk[1], xp=xp)
  270. dk[-1] = PchipInterpolator._edge_case(hk[-1], hk[-2], mk[-1], mk[-2], xp=xp)
  271. return xp.reshape(dk, y_shape)
  272. def pchip_interpolate(xi, yi, x, der=0, axis=0):
  273. """
  274. Convenience function for pchip interpolation.
  275. xi and yi are arrays of values used to approximate some function f,
  276. with ``yi = f(xi)``. The interpolant uses monotonic cubic splines
  277. to find the value of new points x and the derivatives there.
  278. See `scipy.interpolate.PchipInterpolator` for details.
  279. Parameters
  280. ----------
  281. xi : array_like
  282. A sorted list of x-coordinates, of length N.
  283. yi : array_like
  284. A 1-D array of real values. `yi`'s length along the interpolation
  285. axis must be equal to the length of `xi`. If N-D array, use axis
  286. parameter to select correct axis.
  287. x : scalar or array_like
  288. Of length M.
  289. der : int or list, optional
  290. Derivatives to extract. The 0th derivative can be included to
  291. return the function value.
  292. axis : int, optional
  293. Axis in the yi array corresponding to the x-coordinate values.
  294. Returns
  295. -------
  296. y : scalar or array_like
  297. The result, of length R or length M or M by R.
  298. See Also
  299. --------
  300. PchipInterpolator : PCHIP 1-D monotonic cubic interpolator.
  301. Examples
  302. --------
  303. We can interpolate 2D observed data using pchip interpolation:
  304. >>> import numpy as np
  305. >>> import matplotlib.pyplot as plt
  306. >>> from scipy.interpolate import pchip_interpolate
  307. >>> x_observed = np.linspace(0.0, 10.0, 11)
  308. >>> y_observed = np.sin(x_observed)
  309. >>> x = np.linspace(min(x_observed), max(x_observed), num=100)
  310. >>> y = pchip_interpolate(x_observed, y_observed, x)
  311. >>> plt.plot(x_observed, y_observed, "o", label="observation")
  312. >>> plt.plot(x, y, label="pchip interpolation")
  313. >>> plt.legend()
  314. >>> plt.show()
  315. """
  316. P = PchipInterpolator(xi, yi, axis=axis)
  317. if der == 0:
  318. return P(x)
  319. elif _isscalar(der):
  320. return P.derivative(der)(x)
  321. else:
  322. return [P.derivative(nu)(x) for nu in der]
  323. @xp_capabilities(cpu_only=True, xfail_backends=[
  324. ("dask.array", "lacks nd fancy indexing"),
  325. ("jax.numpy", "immutable arrays"),
  326. ("array_api_strict", "fancy indexing __setitem__"),
  327. ])
  328. class Akima1DInterpolator(CubicHermiteSpline):
  329. r"""Akima "visually pleasing" interpolator (C1 smooth).
  330. Fit piecewise cubic polynomials, given vectors x and y. The interpolation
  331. method by Akima uses a continuously differentiable sub-spline built from
  332. piecewise cubic polynomials. The resultant curve passes through the given
  333. data points and will appear smooth and natural.
  334. Parameters
  335. ----------
  336. x : ndarray, shape (npoints, )
  337. 1-D array of monotonically increasing real values.
  338. y : ndarray, shape (..., npoints, ...)
  339. N-D array of real values. The length of ``y`` along the interpolation axis
  340. must be equal to the length of ``x``. Use the ``axis`` parameter to
  341. select the interpolation axis.
  342. axis : int, optional
  343. Axis in the ``y`` array corresponding to the x-coordinate values. Defaults
  344. to ``axis=0``.
  345. method : {'akima', 'makima'}, optional
  346. If ``"makima"``, use the modified Akima interpolation [2]_.
  347. Defaults to ``"akima"``, use the Akima interpolation [1]_.
  348. .. versionadded:: 1.13.0
  349. extrapolate : {bool, None}, optional
  350. If bool, determines whether to extrapolate to out-of-bounds points
  351. based on first and last intervals, or to return NaNs. If None,
  352. ``extrapolate`` is set to False.
  353. Methods
  354. -------
  355. __call__
  356. derivative
  357. antiderivative
  358. integrate
  359. solve
  360. roots
  361. See Also
  362. --------
  363. PchipInterpolator : PCHIP 1-D monotonic cubic interpolator.
  364. CubicSpline : Cubic spline data interpolator.
  365. PPoly : Piecewise polynomial in terms of coefficients and breakpoints
  366. Notes
  367. -----
  368. .. versionadded:: 0.14
  369. Use only for precise data, as the fitted curve passes through the given
  370. points exactly. This routine is useful for plotting a pleasingly smooth
  371. curve through a few given points for purposes of plotting.
  372. Let :math:`\delta_i = (y_{i+1} - y_i) / (x_{i+1} - x_i)` be the slopes of
  373. the interval :math:`\left[x_i, x_{i+1}\right)`. Akima's derivative at
  374. :math:`x_i` is defined as:
  375. .. math::
  376. d_i = \frac{w_1}{w_1 + w_2}\delta_{i-1} + \frac{w_2}{w_1 + w_2}\delta_i
  377. In the Akima interpolation [1]_ (``method="akima"``), the weights are:
  378. .. math::
  379. \begin{aligned}
  380. w_1 &= |\delta_{i+1} - \delta_i| \\
  381. w_2 &= |\delta_{i-1} - \delta_{i-2}|
  382. \end{aligned}
  383. In the modified Akima interpolation [2]_ (``method="makima"``),
  384. to eliminate overshoot and avoid edge cases of both numerator and
  385. denominator being equal to 0, the weights are modified as follows:
  386. .. math::
  387. \begin{align*}
  388. w_1 &= |\delta_{i+1} - \delta_i| + |\delta_{i+1} + \delta_i| / 2 \\
  389. w_2 &= |\delta_{i-1} - \delta_{i-2}| + |\delta_{i-1} + \delta_{i-2}| / 2
  390. \end{align*}
  391. Examples
  392. --------
  393. Comparison of ``method="akima"`` and ``method="makima"``:
  394. >>> import numpy as np
  395. >>> from scipy.interpolate import Akima1DInterpolator
  396. >>> import matplotlib.pyplot as plt
  397. >>> x = np.linspace(1, 7, 7)
  398. >>> y = np.array([-1, -1, -1, 0, 1, 1, 1])
  399. >>> xs = np.linspace(min(x), max(x), num=100)
  400. >>> y_akima = Akima1DInterpolator(x, y, method="akima")(xs)
  401. >>> y_makima = Akima1DInterpolator(x, y, method="makima")(xs)
  402. >>> fig, ax = plt.subplots()
  403. >>> ax.plot(x, y, "o", label="data")
  404. >>> ax.plot(xs, y_akima, label="akima")
  405. >>> ax.plot(xs, y_makima, label="makima")
  406. >>> ax.legend()
  407. >>> fig.show()
  408. The overshoot that occurred in ``"akima"`` has been avoided in ``"makima"``.
  409. References
  410. ----------
  411. .. [1] A new method of interpolation and smooth curve fitting based
  412. on local procedures. Hiroshi Akima, J. ACM, October 1970, 17(4),
  413. 589-602. :doi:`10.1145/321607.321609`
  414. .. [2] Makima Piecewise Cubic Interpolation. Cleve Moler and Cosmin Ionita, 2019.
  415. https://blogs.mathworks.com/cleve/2019/04/29/makima-piecewise-cubic-interpolation/
  416. """
  417. # PchipInterpolator is not generic in scipy-stubs
  418. __class_getitem__ = None
  419. def __init__(self, x, y, axis=0, *, method: Literal["akima", "makima"]="akima",
  420. extrapolate:bool | None = None):
  421. if method not in {"akima", "makima"}:
  422. raise NotImplementedError(f"`method`={method} is unsupported.")
  423. # Original implementation in MATLAB by N. Shamsundar (BSD licensed), see
  424. # https://www.mathworks.com/matlabcentral/fileexchange/1814-akima-interpolation
  425. xp = array_namespace(x, y)
  426. x, dx, y, axis, _ = prepare_input(x, y, axis, xp=xp)
  427. if xp.isdtype(y.dtype, "complex floating"):
  428. msg = ("`Akima1DInterpolator` only works with real values for `y`. "
  429. "If you are trying to use the real components of the passed array, "
  430. "use `np.real` on the array before passing to "
  431. "`Akima1DInterpolator`.")
  432. raise ValueError(msg)
  433. # Akima extrapolation historically False; parent class defaults to True.
  434. extrapolate = False if extrapolate is None else extrapolate
  435. if y.shape[0] == 2:
  436. # edge case: only have two points, use linear interpolation
  437. xv = xp.reshape(x, (x.shape[0],) + (1,)*(y.ndim-1))
  438. hk = xv[1:, ...] - xv[:-1, ...]
  439. mk = (y[1:, ...] - y[:-1, ...]) / hk
  440. t = xp.zeros_like(y)
  441. t[...] = mk
  442. else:
  443. # determine slopes between breakpoints
  444. m = xp.empty((x.shape[0] + 3, ) + y.shape[1:])
  445. dx = dx[(slice(None), ) + (None, ) * (y.ndim - 1)]
  446. m[2:-2, ...] = xp.diff(y, axis=0) / dx
  447. # add two additional points on the left ...
  448. m[1, ...] = 2. * m[2, ...] - m[3, ...]
  449. m[0, ...] = 2. * m[1, ...] - m[2, ...]
  450. # ... and on the right
  451. m[-2, ...] = 2. * m[-3, ...] - m[-4, ...]
  452. m[-1, ...] = 2. * m[-2, ...] - m[-3, ...]
  453. # if m1 == m2 != m3 == m4, the slope at the breakpoint is not
  454. # defined. This is the fill value:
  455. t = .5 * (m[3:, ...] + m[:-3, ...])
  456. # get the denominator of the slope t
  457. dm = xp.abs(xp.diff(m, axis=0))
  458. if method == "makima":
  459. pm = xp.abs(m[1:, ...] + m[:-1, ...])
  460. f1 = dm[2:, ...] + 0.5 * pm[2:, ...]
  461. f2 = dm[:-2, ...] + 0.5 * pm[:-2, ...]
  462. else:
  463. f1 = dm[2:, ...]
  464. f2 = dm[:-2, ...]
  465. # makima is more numerically stable for small f12,
  466. # so a finite cutoff should not improve any behavior
  467. # however, akima has a qualitative discontinuity near f12=0
  468. # a finite cutoff moves it, but cannot remove it.
  469. # the cutoff break_mult could be made a keyword argument
  470. # method='akima' also benefits from a check for m2=m3
  471. break_mult = 1.e-9
  472. f12 = f1 + f2
  473. # These are the mask of where the slope at breakpoint is defined:
  474. mmax = xp.max(f12) if xp_size(f12) > 0 else -xp.inf
  475. ind = xp.nonzero(f12 > break_mult * mmax)
  476. x_ind, y_ind = ind[0], ind[1:]
  477. # Set the slope at breakpoint
  478. t[ind] = m[(x_ind + 1,) + y_ind] + (
  479. (f2[ind] / f12[ind])
  480. * (m[(x_ind + 2,) + y_ind] - m[(x_ind + 1,) + y_ind])
  481. )
  482. super().__init__(x, y, t, axis=0, extrapolate=extrapolate)
  483. self.axis = axis
  484. def extend(self, c, x, right=True):
  485. raise NotImplementedError("Extending a 1-D Akima interpolator is not "
  486. "yet implemented")
  487. # These are inherited from PPoly, but they do not produce an Akima
  488. # interpolator. Hence stub them out.
  489. @classmethod
  490. def from_spline(cls, tck, extrapolate=None):
  491. raise NotImplementedError("This method does not make sense for "
  492. "an Akima interpolator.")
  493. @classmethod
  494. def from_bernstein_basis(cls, bp, extrapolate=None):
  495. raise NotImplementedError("This method does not make sense for "
  496. "an Akima interpolator.")
  497. @xp_capabilities(
  498. cpu_only=True, jax_jit=False,
  499. skip_backends=[
  500. ("dask.array",
  501. "https://github.com/data-apis/array-api-extra/issues/488")
  502. ]
  503. )
  504. class CubicSpline(CubicHermiteSpline):
  505. """Piecewise cubic interpolator to fit values (C2 smooth).
  506. Interpolate data with a piecewise cubic polynomial which is twice
  507. continuously differentiable [1]_. The result is represented as a `PPoly`
  508. instance with breakpoints matching the given data.
  509. Parameters
  510. ----------
  511. x : array_like, shape (n,)
  512. 1-D array containing values of the independent variable.
  513. Values must be real, finite and in strictly increasing order.
  514. y : array_like
  515. Array containing values of the dependent variable. It can have
  516. arbitrary number of dimensions, but the length along ``axis``
  517. (see below) must match the length of ``x``. Values must be finite.
  518. axis : int, optional
  519. Axis along which `y` is assumed to be varying. Meaning that for
  520. ``x[i]`` the corresponding values are ``np.take(y, i, axis=axis)``.
  521. Default is 0.
  522. bc_type : string or 2-tuple, optional
  523. Boundary condition type. Two additional equations, given by the
  524. boundary conditions, are required to determine all coefficients of
  525. polynomials on each segment [2]_.
  526. If `bc_type` is a string, then the specified condition will be applied
  527. at both ends of a spline. Available conditions are:
  528. * 'not-a-knot' (default): The first and second segment at a curve end
  529. are the same polynomial. It is a good default when there is no
  530. information on boundary conditions.
  531. * 'periodic': The interpolated functions is assumed to be periodic
  532. of period ``x[-1] - x[0]``. The first and last value of `y` must be
  533. identical: ``y[0] == y[-1]``. This boundary condition will result in
  534. ``y'[0] == y'[-1]`` and ``y''[0] == y''[-1]``.
  535. * 'clamped': The first derivative at curves ends are zero. Assuming
  536. a 1D `y`, ``bc_type=((1, 0.0), (1, 0.0))`` is the same condition.
  537. * 'natural': The second derivative at curve ends are zero. Assuming
  538. a 1D `y`, ``bc_type=((2, 0.0), (2, 0.0))`` is the same condition.
  539. If `bc_type` is a 2-tuple, the first and the second value will be
  540. applied at the curve start and end respectively. The tuple values can
  541. be one of the previously mentioned strings (except 'periodic') or a
  542. tuple ``(order, deriv_values)`` allowing to specify arbitrary
  543. derivatives at curve ends:
  544. * `order`: the derivative order, 1 or 2.
  545. * `deriv_value`: array_like containing derivative values, shape must
  546. be the same as `y`, excluding ``axis`` dimension. For example, if
  547. `y` is 1-D, then `deriv_value` must be a scalar. If `y` is 3-D with
  548. the shape (n0, n1, n2) and axis=2, then `deriv_value` must be 2-D
  549. and have the shape (n0, n1).
  550. extrapolate : {bool, 'periodic', None}, optional
  551. If bool, determines whether to extrapolate to out-of-bounds points
  552. based on first and last intervals, or to return NaNs. If 'periodic',
  553. periodic extrapolation is used. If None (default), ``extrapolate`` is
  554. set to 'periodic' for ``bc_type='periodic'`` and to True otherwise.
  555. Attributes
  556. ----------
  557. x : ndarray, shape (n,)
  558. Breakpoints. The same ``x`` which was passed to the constructor.
  559. c : ndarray, shape (4, n-1, ...)
  560. Coefficients of the polynomials on each segment. The trailing
  561. dimensions match the dimensions of `y`, excluding ``axis``.
  562. For example, if `y` is 1-d, then ``c[k, i]`` is a coefficient for
  563. ``(x-x[i])**(3-k)`` on the segment between ``x[i]`` and ``x[i+1]``.
  564. axis : int
  565. Interpolation axis. The same axis which was passed to the
  566. constructor.
  567. Methods
  568. -------
  569. __call__
  570. derivative
  571. antiderivative
  572. integrate
  573. solve
  574. roots
  575. See Also
  576. --------
  577. Akima1DInterpolator : Akima 1D interpolator.
  578. PchipInterpolator : PCHIP 1-D monotonic cubic interpolator.
  579. PPoly : Piecewise polynomial in terms of coefficients and breakpoints.
  580. Notes
  581. -----
  582. Parameters `bc_type` and ``extrapolate`` work independently, i.e. the
  583. former controls only construction of a spline, and the latter only
  584. evaluation.
  585. When a boundary condition is 'not-a-knot' and n = 2, it is replaced by
  586. a condition that the first derivative is equal to the linear interpolant
  587. slope. When both boundary conditions are 'not-a-knot' and n = 3, the
  588. solution is sought as a parabola passing through given points.
  589. When 'not-a-knot' boundary conditions is applied to both ends, the
  590. resulting spline will be the same as returned by `splrep` (with ``s=0``)
  591. and `InterpolatedUnivariateSpline`, but these two methods use a
  592. representation in B-spline basis.
  593. .. versionadded:: 0.18.0
  594. Examples
  595. --------
  596. In this example the cubic spline is used to interpolate a sampled sinusoid.
  597. You can see that the spline continuity property holds for the first and
  598. second derivatives and violates only for the third derivative.
  599. >>> import numpy as np
  600. >>> from scipy.interpolate import CubicSpline
  601. >>> import matplotlib.pyplot as plt
  602. >>> x = np.arange(10)
  603. >>> y = np.sin(x)
  604. >>> cs = CubicSpline(x, y)
  605. >>> xs = np.arange(-0.5, 9.6, 0.1)
  606. >>> fig, ax = plt.subplots(figsize=(6.5, 4))
  607. >>> ax.plot(x, y, 'o', label='data')
  608. >>> ax.plot(xs, np.sin(xs), label='true')
  609. >>> ax.plot(xs, cs(xs), label="S")
  610. >>> ax.plot(xs, cs(xs, 1), label="S'")
  611. >>> ax.plot(xs, cs(xs, 2), label="S''")
  612. >>> ax.plot(xs, cs(xs, 3), label="S'''")
  613. >>> ax.set_xlim(-0.5, 9.5)
  614. >>> ax.legend(loc='lower left', ncol=2)
  615. >>> plt.show()
  616. In the second example, the unit circle is interpolated with a spline. A
  617. periodic boundary condition is used. You can see that the first derivative
  618. values, ds/dx=0, ds/dy=1 at the periodic point (1, 0) are correctly
  619. computed. Note that a circle cannot be exactly represented by a cubic
  620. spline. To increase precision, more breakpoints would be required.
  621. >>> theta = 2 * np.pi * np.linspace(0, 1, 5)
  622. >>> y = np.c_[np.cos(theta), np.sin(theta)]
  623. >>> cs = CubicSpline(theta, y, bc_type='periodic')
  624. >>> print("ds/dx={:.1f} ds/dy={:.1f}".format(cs(0, 1)[0], cs(0, 1)[1]))
  625. ds/dx=0.0 ds/dy=1.0
  626. >>> xs = 2 * np.pi * np.linspace(0, 1, 100)
  627. >>> fig, ax = plt.subplots(figsize=(6.5, 4))
  628. >>> ax.plot(y[:, 0], y[:, 1], 'o', label='data')
  629. >>> ax.plot(np.cos(xs), np.sin(xs), label='true')
  630. >>> ax.plot(cs(xs)[:, 0], cs(xs)[:, 1], label='spline')
  631. >>> ax.axes.set_aspect('equal')
  632. >>> ax.legend(loc='center')
  633. >>> plt.show()
  634. The third example is the interpolation of a polynomial y = x**3 on the
    interval 0 <= x <= 1. A cubic spline can represent this function exactly.
    To achieve that, we need to specify values and first derivatives at the
    endpoints of the interval. Note that y' = 3 * x**2 and thus y'(0) = 0 and
  638. y'(1) = 3.
  639. >>> cs = CubicSpline([0, 1], [0, 1], bc_type=((1, 0), (1, 3)))
  640. >>> x = np.linspace(0, 1)
  641. >>> np.allclose(x**3, cs(x))
  642. True
  643. References
  644. ----------
  645. .. [1] `Cubic Spline Interpolation
  646. <https://en.wikiversity.org/wiki/Cubic_Spline_Interpolation>`_
  647. on Wikiversity.
  648. .. [2] Carl de Boor, "A Practical Guide to Splines", Springer-Verlag, 1978.
  649. """
    def __init__(self, x, y, axis=0, bc_type='not-a-knot', extrapolate=None):
        """Solve for the first derivatives at the breakpoints and build the spline.

        The derivative values ``s`` (one per breakpoint) are obtained from the
        boundary conditions together with the interior equations assembled
        below; ``x``, ``y`` and ``s`` are then passed to the parent
        constructor.

        Parameters
        ----------
        x, y : array_like
            Breakpoints and data values, forwarded to ``prepare_input``
            together with ``axis``.
        axis : int, optional
            Interpolation axis of ``y``. Stored as ``self.axis`` after the
            parent constructor has run.
        bc_type : str or 2-sequence, optional
            Boundary conditions; checked and normalized by ``_validate_bc``.
        extrapolate : optional
            If None, it becomes 'periodic' when the boundary condition is
            periodic and True otherwise; any other value is passed to the
            parent constructor unchanged.
        """
        # Array namespace of the caller's inputs; results are converted back
        # to it just before the parent constructor is called.
        xp = array_namespace(x, y)
        # The computation runs in ``np_compat`` (presumably a NumPy-compatible
        # namespace; defined elsewhere -- TODO confirm). From the usage below,
        # ``y`` comes back with the interpolation axis first (np.diff is taken
        # along axis 0), and ``dx`` holds the interval widths x[i+1] - x[i].
        x, dx, y, axis, _ = prepare_input(x, y, axis, xp=np_compat)
        n = len(x)

        # ``bc`` is a 2-element list [start, end] (mutated below for n == 2);
        # ``y`` may have been cast to complex by ``_validate_bc``.
        bc, y = self._validate_bc(bc_type, y, y.shape[1:], axis)

        if extrapolate is None:
            if bc[0] == 'periodic':
                extrapolate = 'periodic'
            else:
                extrapolate = True

        if y.size == 0:
            # bail out early for zero-sized arrays
            s = np.zeros_like(y)
        else:
            # Interval widths with trailing singleton axes, so they broadcast
            # against the trailing dimensions of ``y``.
            dxr = dx.reshape([dx.shape[0]] + [1] * (y.ndim - 1))
            # Secant slope of each interval.
            slope = np.diff(y, axis=0) / dxr

            # With n == 2 there is a single interval; 'not-a-knot' and
            # 'periodic' ends are replaced by a prescribed first derivative
            # equal to that interval's slope.
            # If bc is 'not-a-knot' this change is just a convention.
            # If bc is 'periodic' then we already checked that y[0] == y[-1],
            # and the spline is just a constant, we handle this case in the
            # same way by setting the first derivatives to slope, which is 0.
            if n == 2:
                if bc[0] in ['not-a-knot', 'periodic']:
                    bc[0] = (1, slope[0])
                if bc[1] in ['not-a-knot', 'periodic']:
                    bc[1] = (1, slope[0])

            # This is a special case, when both conditions are 'not-a-knot'
            # and n == 3. In this case 'not-a-knot' can't be handled regularly
            # as both conditions are identical. We handle this case by
            # constructing a parabola passing through the given points.
            if n == 3 and bc[0] == 'not-a-knot' and bc[1] == 'not-a-knot':
                A = np.zeros((3, 3))  # This is a standard (dense) matrix.
                b = np.empty((3,) + y.shape[1:], dtype=y.dtype)

                # Rows 0 and 2: for a parabola the mean of the end derivatives
                # of an interval equals its slope, i.e.
                # s[0] + s[1] = 2 * slope[0] and s[1] + s[2] = 2 * slope[1].
                # Row 1: the regular interior equation (same form as the
                # tridiagonal system in the general branch below).
                A[0, 0] = 1
                A[0, 1] = 1
                A[1, 0] = dx[1]
                A[1, 1] = 2 * (dx[0] + dx[1])
                A[1, 2] = dx[0]
                A[2, 1] = 1
                A[2, 2] = 1

                b[0] = 2 * slope[0]
                b[1] = 3 * (dxr[0] * slope[1] + dxr[1] * slope[0])
                b[2] = 2 * slope[1]

                # Trailing dimensions are flattened so that each column is an
                # independent right-hand side; the shape is restored after.
                m = b.shape[0]
                s = solve(A, b.reshape(m, -1), overwrite_a=True, overwrite_b=True,
                          check_finite=False).reshape(b.shape)
            elif n == 3 and bc[0] == 'periodic':
                # In case when number of points is 3 we compute the derivatives
                # manually: the same value t (the two slopes averaged with
                # weights 1 / dx) is used at every breakpoint.
                t = (slope / dxr).sum(0) / (1. / dxr).sum(0)
                s = np.broadcast_to(t, (n,) + y.shape[1:])
            else:
                # Find derivative values at each x[i] by solving a tridiagonal
                # system.
                # A uses the (1, 1)-banded layout expected by solve_banded:
                # row 0 = upper diagonal, row 1 = main diagonal,
                # row 2 = lower diagonal.
                A = np.zeros((3, n))  # This is a banded matrix representation.
                b = np.empty((n,) + y.shape[1:], dtype=y.dtype)

                # Filling the system for i=1..n-2 (continuity of the second
                # derivative at the interior breakpoint x[i]):
                #                         (x[i+1] - x[i]) * s[i-1] +\
                # 2 * ((x[i] - x[i-1]) + (x[i+1] - x[i])) * s[i]   +\
                #                         (x[i] - x[i-1]) * s[i+1] =\
                #       3 * ((x[i+1] - x[i])*(y[i] - y[i-1])/(x[i] - x[i-1]) +\
                #           (x[i] - x[i-1])*(y[i+1] - y[i])/(x[i+1] - x[i]))
                # In terms of dx and slope:
                #   dx[i] * s[i-1] + 2 * (dx[i-1] + dx[i]) * s[i]
                #     + dx[i-1] * s[i+1]
                #   = 3 * (dx[i] * slope[i-1] + dx[i-1] * slope[i])

                A[1, 1:-1] = 2 * (dx[:-1] + dx[1:])  # The diagonal
                A[0, 2:] = dx[:-1]                   # The upper diagonal
                A[-1, :-2] = dx[1:]                  # The lower diagonal

                b[1:-1] = 3 * (dxr[1:] * slope[:-1] + dxr[:-1] * slope[1:])

                # Rows 0 and n-1 (A[:, 0], A[:, -1], b[0], b[-1]) are filled
                # by the boundary conditions below.
                bc_start, bc_end = bc

                if bc_start == 'periodic':
                    # Due to the periodicity, and because y[-1] = y[0], the
                    # linear system has (n-1) unknowns/equations instead of n:
                    A = A[:, 0:-1]
                    A[1, 0] = 2 * (dx[-1] + dx[0])
                    A[0, 1] = dx[-1]

                    b = b[:-1]

                    # Also, due to the periodicity, the system is not tri-diagonal.
                    # We need to compute a "condensed" matrix of shape (n-2, n-2).
                    # See https://web.archive.org/web/20151220180652/http://www.cfm.brown.edu/people/gk/chap6/node14.html
                    # for more explanations.
                    # The condensed matrix is obtained by removing the last column
                    # and last row of the (n-1, n-1) system matrix. The removed
                    # values are saved in scalar variables with the (n-1, n-1)
                    # system matrix indices forming their names:
                    a_m1_0 = dx[-2]  # lower left corner value: A[-1, 0]
                    a_m1_m2 = dx[-1]
                    a_m1_m1 = 2 * (dx[-1] + dx[-2])
                    a_m2_m1 = dx[-3]
                    a_0_m1 = dx[0]

                    # First and last equations wrap around the period.
                    b[0] = 3 * (dxr[0] * slope[-1] + dxr[-1] * slope[0])
                    b[-1] = 3 * (dxr[-1] * slope[-2] + dxr[-2] * slope[-1])

                    Ac = A[:, :-1]
                    b1 = b[:-1]
                    # b2 is minus the removed last column of the (n-1, n-1)
                    # matrix; only its first and last entries are non-zero.
                    b2 = np.zeros_like(b1)
                    b2[0] = -a_0_m1
                    b2[-1] = -a_m2_m1

                    # s1 and s2 are the solutions of (n-2, n-2) system
                    m = b1.shape[0]
                    s1 = solve_banded((1, 1), Ac, b1.reshape(m, -1), overwrite_ab=False,
                                      overwrite_b=False, check_finite=False)
                    s1 = s1.reshape(b1.shape)
                    m = b2.shape[0]
                    s2 = solve_banded((1, 1), Ac, b2.reshape(m, -1), overwrite_ab=False,
                                      overwrite_b=False, check_finite=False)
                    s2 = s2.reshape(b2.shape)

                    # computing the s[n-2] solution from the removed last row,
                    # substituting s[:-2] = s1 + s_m1 * s2:
                    s_m1 = ((b[-1] - a_m1_0 * s1[0] - a_m1_m2 * s1[-1]) /
                            (a_m1_m1 + a_m1_0 * s2[0] + a_m1_m2 * s2[-1]))

                    # s is the solution of the (n, n) system:
                    s = np.empty((n,) + y.shape[1:], dtype=y.dtype)
                    s[:-2] = s1 + s_m1 * s2
                    s[-2] = s_m1
                    s[-1] = s[0]  # periodicity: same derivative at both ends
                else:
                    if bc_start == 'not-a-knot':
                        # Third derivative continuous at x[1].
                        A[1, 0] = dx[1]
                        A[0, 1] = x[2] - x[0]
                        d = x[2] - x[0]
                        b[0] = ((dxr[0] + 2*d) * dxr[1] * slope[0] +
                                dxr[0]**2 * slope[1]) / d
                    elif bc_start[0] == 1:
                        # Prescribed first derivative: s[0] = value.
                        A[1, 0] = 1
                        A[0, 1] = 0
                        b[0] = bc_start[1]
                    elif bc_start[0] == 2:
                        # Prescribed second derivative at x[0].
                        A[1, 0] = 2 * dx[0]
                        A[0, 1] = dx[0]
                        b[0] = -0.5 * bc_start[1] * dx[0]**2 + 3 * (y[1] - y[0])

                    if bc_end == 'not-a-knot':
                        # Third derivative continuous at x[-2].
                        A[1, -1] = dx[-2]
                        A[-1, -2] = x[-1] - x[-3]
                        d = x[-1] - x[-3]
                        b[-1] = ((dxr[-1]**2*slope[-2] +
                                 (2*d + dxr[-1])*dxr[-2]*slope[-1]) / d)
                    elif bc_end[0] == 1:
                        # Prescribed first derivative: s[-1] = value.
                        A[1, -1] = 1
                        A[-1, -2] = 0
                        b[-1] = bc_end[1]
                    elif bc_end[0] == 2:
                        # Prescribed second derivative at x[-1].
                        A[1, -1] = 2 * dx[-1]
                        A[-1, -2] = dx[-1]
                        b[-1] = 0.5 * bc_end[1] * dx[-1]**2 + 3 * (y[-1] - y[-2])

                    m = b.shape[0]
                    s = solve_banded((1, 1), A, b.reshape(m, -1), overwrite_ab=True,
                                     overwrite_b=True, check_finite=False)
                    s = s.reshape(b.shape)

        # Convert back to the caller's array namespace. The parent receives
        # axis=0 because y already has the interpolation axis first
        # (presumably arranged by prepare_input -- TODO confirm); the caller's
        # axis is recorded afterwards.
        x, y, s = map(xp.asarray, (x, y, s))
        super().__init__(x, y, s, axis=0, extrapolate=extrapolate)
        self.axis = axis
  796. @staticmethod
  797. def _validate_bc(bc_type, y, expected_deriv_shape, axis):
  798. """Validate and prepare boundary conditions.
  799. Returns
  800. -------
  801. validated_bc : 2-tuple
  802. Boundary conditions for a curve start and end.
  803. y : ndarray
  804. y casted to complex dtype if one of the boundary conditions has
  805. complex dtype.
  806. """
  807. if isinstance(bc_type, str):
  808. if bc_type == 'periodic':
  809. if not np.allclose(y[0], y[-1], rtol=1e-15, atol=1e-15):
  810. raise ValueError(
  811. f"The first and last `y` point along axis {axis} must "
  812. "be identical (within machine precision) when "
  813. "bc_type='periodic'.")
  814. bc_type = (bc_type, bc_type)
  815. else:
  816. if len(bc_type) != 2:
  817. raise ValueError("`bc_type` must contain 2 elements to "
  818. "specify start and end conditions.")
  819. if 'periodic' in bc_type:
  820. raise ValueError("'periodic' `bc_type` is defined for both "
  821. "curve ends and cannot be used with other "
  822. "boundary conditions.")
  823. validated_bc = []
  824. for bc in bc_type:
  825. if isinstance(bc, str):
  826. if bc == 'clamped':
  827. validated_bc.append((1, np.zeros(expected_deriv_shape)))
  828. elif bc == 'natural':
  829. validated_bc.append((2, np.zeros(expected_deriv_shape)))
  830. elif bc in ['not-a-knot', 'periodic']:
  831. validated_bc.append(bc)
  832. else:
  833. raise ValueError(f"bc_type={bc} is not allowed.")
  834. else:
  835. try:
  836. deriv_order, deriv_value = bc
  837. except Exception as e:
  838. raise ValueError(
  839. "A specified derivative value must be "
  840. "given in the form (order, value)."
  841. ) from e
  842. if deriv_order not in [1, 2]:
  843. raise ValueError("The specified derivative order must "
  844. "be 1 or 2.")
  845. deriv_value = np.asarray(deriv_value)
  846. if deriv_value.shape != expected_deriv_shape:
  847. raise ValueError(
  848. f"`deriv_value` shape {deriv_value.shape} is not "
  849. f"the expected one {expected_deriv_shape}."
  850. )
  851. if np.issubdtype(deriv_value.dtype, np.complexfloating):
  852. y = y.astype(complex, copy=False)
  853. validated_bc.append((deriv_order, deriv_value))
  854. return validated_bc, y