_interpolate.py 80 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312
  1. __all__ = ['interp1d', 'interp2d', 'lagrange', 'PPoly', 'BPoly', 'NdPPoly']
  2. from math import prod
  3. from types import GenericAlias
  4. import numpy as np
  5. from numpy import array, asarray, intp, poly1d, searchsorted
  6. import scipy.special as spec
  7. from scipy._lib._util import copy_if_needed
  8. from scipy.special import comb
  9. from scipy._lib._array_api import array_namespace, xp_capabilities
  10. from . import _fitpack_py
  11. from ._polyint import _Interpolator1D
  12. from . import _ppoly
  13. from ._interpnd import _ndim_coords_from_arrays
  14. from ._bsplines import make_interp_spline, BSpline
  15. def lagrange(x, w):
  16. r"""
  17. Return a Lagrange interpolating polynomial.
  18. Given two 1-D arrays `x` and `w,` returns the Lagrange interpolating
  19. polynomial through the points ``(x, w)``.
  20. Warning: This implementation is numerically unstable. Do not expect to
  21. be able to use more than about 20 points even if they are chosen optimally.
  22. Parameters
  23. ----------
  24. x : array_like
  25. `x` represents the x-coordinates of a set of datapoints.
  26. w : array_like
  27. `w` represents the y-coordinates of a set of datapoints, i.e., f(`x`).
  28. Returns
  29. -------
  30. lagrange : `numpy.poly1d` instance
  31. The Lagrange interpolating polynomial.
  32. Notes
  33. -----
  34. The name of this function refers to the fact that the returned object represents
  35. a Lagrange polynomial, the unique polynomial of lowest degree that interpolates
  36. a given set of data [1]_. It computes the polynomial using Newton's divided
  37. differences formula [2]_; that is, it works with Newton basis polynomials rather
  38. than Lagrange basis polynomials. For numerical calculations, the barycentric form
  39. of Lagrange interpolation (`scipy.interpolate.BarycentricInterpolator`) is
  40. typically more appropriate.
  41. References
  42. ----------
  43. .. [1] Lagrange polynomial. *Wikipedia*.
  44. https://en.wikipedia.org/wiki/Lagrange_polynomial
  45. .. [2] Newton polynomial. *Wikipedia*.
  46. https://en.wikipedia.org/wiki/Newton_polynomial
  47. Examples
  48. --------
  49. Interpolate :math:`f(x) = x^3` by 3 points.
  50. >>> import numpy as np
  51. >>> from scipy.interpolate import lagrange
  52. >>> x = np.array([0, 1, 2])
  53. >>> y = x**3
  54. >>> poly = lagrange(x, y)
  55. Since there are only 3 points, the Lagrange polynomial has degree 2. Explicitly,
  56. it is given by
  57. .. math::
  58. \begin{aligned}
  59. L(x) &= 1\times \frac{x (x - 2)}{-1} + 8\times \frac{x (x-1)}{2} \\
  60. &= x (-2 + 3x)
  61. \end{aligned}
  62. >>> from numpy.polynomial.polynomial import Polynomial
  63. >>> Polynomial(poly.coef[::-1]).coef
  64. array([ 0., -2., 3.])
  65. >>> import matplotlib.pyplot as plt
  66. >>> x_new = np.arange(0, 2.1, 0.1)
  67. >>> plt.scatter(x, y, label='data')
  68. >>> plt.plot(x_new, Polynomial(poly.coef[::-1])(x_new), label='Polynomial')
  69. >>> plt.plot(x_new, 3*x_new**2 - 2*x_new + 0*x_new,
  70. ... label=r"$3 x^2 - 2 x$", linestyle='-.')
  71. >>> plt.legend()
  72. >>> plt.show()
  73. """
  74. M = len(x)
  75. p = poly1d(0.0)
  76. for j in range(M):
  77. pt = poly1d(w[j])
  78. for k in range(M):
  79. if k == j:
  80. continue
  81. fac = x[j]-x[k]
  82. pt *= poly1d([1.0, -x[k]])/fac
  83. p += pt
  84. return p
# NOTE(review): the two "!!" lines below look stale; nothing in this part of
# the module uses an `initialize` argument. Confirm and drop them.
# !! Need to find argument for keeping initialize. If it isn't
# !! found, get rid of it!

# Migration message raised by the `interp2d` placeholder class. It points
# users of the removed API to its replacements.
err_mesg = """\
`interp2d` has been removed in SciPy 1.14.0.
For legacy code, nearly bug-for-bug compatible replacements are
`RectBivariateSpline` on regular grids, and `bisplrep`/`bisplev` for
scattered 2D data.
In new code, for regular grids use `RegularGridInterpolator` instead.
For scattered data, prefer `LinearNDInterpolator` or
`CloughTocher2DInterpolator`.
For more details see
https://scipy.github.io/devdocs/tutorial/interpolate/interp_transition_guide.html
"""
  98. class interp2d:
  99. """
  100. interp2d(x, y, z, kind='linear', copy=True, bounds_error=False,
  101. fill_value=None)
  102. Class for 2D interpolation (deprecated and removed)
  103. .. versionremoved:: 1.14.0
  104. `interp2d` has been removed in SciPy 1.14.0.
  105. For legacy code, nearly bug-for-bug compatible replacements are
  106. `RectBivariateSpline` on regular grids, and `bisplrep`/`bisplev` for
  107. scattered 2D data.
  108. In new code, for regular grids use `RegularGridInterpolator` instead.
  109. For scattered data, prefer `LinearNDInterpolator` or
  110. `CloughTocher2DInterpolator`.
  111. For more details see :ref:`interp-transition-guide`.
  112. """
  113. def __init__(self, x, y, z, kind='linear', copy=True, bounds_error=False,
  114. fill_value=None):
  115. raise NotImplementedError(err_mesg)
  116. def _check_broadcast_up_to(arr_from, shape_to, name):
  117. """Helper to check that arr_from broadcasts up to shape_to"""
  118. shape_from = arr_from.shape
  119. if len(shape_to) >= len(shape_from):
  120. for t, f in zip(shape_to[::-1], shape_from[::-1]):
  121. if f != 1 and f != t:
  122. break
  123. else: # all checks pass, do the upcasting that we need later
  124. if arr_from.size != 1 and arr_from.shape != shape_to:
  125. arr_from = np.ones(shape_to, arr_from.dtype) * arr_from
  126. return arr_from.ravel()
  127. # at least one check failed
  128. raise ValueError(f'{name} argument must be able to broadcast up '
  129. f'to shape {shape_to} but had shape {shape_from}')
  130. def _do_extrapolate(fill_value):
  131. """Helper to check if fill_value == "extrapolate" without warnings"""
  132. return (isinstance(fill_value, str) and
  133. fill_value == 'extrapolate')
  134. @xp_capabilities(out_of_scope=True)
  135. class interp1d(_Interpolator1D):
  136. """
  137. Interpolate a 1-D function (legacy).
  138. .. legacy:: class
  139. For a guide to the intended replacements for `interp1d` see
  140. :ref:`tutorial-interpolate_1Dsection`.
  141. `x` and `y` are arrays of values used to approximate some function f:
  142. ``y = f(x)``. This class returns a function whose call method uses
  143. interpolation to find the value of new points.
  144. Parameters
  145. ----------
  146. x : (npoints, ) array_like
  147. A 1-D array of real values.
  148. y : (..., npoints, ...) array_like
  149. An N-D array of real values. The length of `y` along the interpolation
  150. axis must be equal to the length of `x`. Use the ``axis`` parameter
  151. to select correct axis. Unlike other interpolators, the default
  152. interpolation axis is the last axis of `y`.
  153. kind : str or int, optional
  154. Specifies the kind of interpolation as a string or as an integer
  155. specifying the order of the spline interpolator to use.
  156. The string has to be one of 'linear', 'nearest', 'nearest-up', 'zero',
  157. 'slinear', 'quadratic', 'cubic', 'previous', or 'next'. 'zero',
  158. 'slinear', 'quadratic' and 'cubic' refer to a spline interpolation of
  159. zeroth, first, second or third order; 'previous' and 'next' simply
  160. return the previous or next value of the point; 'nearest-up' and
  161. 'nearest' differ when interpolating half-integers (e.g. 0.5, 1.5)
  162. in that 'nearest-up' rounds up and 'nearest' rounds down. Default
  163. is 'linear'.
  164. axis : int, optional
  165. Axis in the ``y`` array corresponding to the x-coordinate values. Unlike
  166. other interpolators, defaults to ``axis=-1``.
  167. copy : bool, optional
  168. If ``True``, the class makes internal copies of x and y. If ``False``,
  169. references to ``x`` and ``y`` are used if possible. The default is to copy.
  170. bounds_error : bool, optional
  171. If True, a ValueError is raised any time interpolation is attempted on
  172. a value outside of the range of x (where extrapolation is
  173. necessary). If False, out of bounds values are assigned `fill_value`.
  174. By default, an error is raised unless ``fill_value="extrapolate"``.
  175. fill_value : array-like or (array-like, array_like) or "extrapolate", optional
  176. - if a ndarray (or float), this value will be used to fill in for
  177. requested points outside of the data range. If not provided, then
  178. the default is NaN. The array-like must broadcast properly to the
  179. dimensions of the non-interpolation axes.
  180. - If a two-element tuple, then the first element is used as a
  181. fill value for ``x_new < x[0]`` and the second element is used for
  182. ``x_new > x[-1]``. Anything that is not a 2-element tuple (e.g.,
  183. list or ndarray, regardless of shape) is taken to be a single
  184. array-like argument meant to be used for both bounds as
  185. ``below, above = fill_value, fill_value``. Using a two-element tuple
  186. or ndarray requires ``bounds_error=False``.
  187. .. versionadded:: 0.17.0
  188. - If "extrapolate", then points outside the data range will be
  189. extrapolated.
  190. .. versionadded:: 0.17.0
  191. assume_sorted : bool, optional
  192. If False, values of `x` can be in any order and they are sorted first.
  193. If True, `x` has to be an array of monotonically increasing values.
  194. Attributes
  195. ----------
  196. fill_value
  197. Methods
  198. -------
  199. __call__
  200. See Also
  201. --------
  202. splrep, splev
  203. Spline interpolation/smoothing based on FITPACK.
  204. UnivariateSpline : An object-oriented wrapper of the FITPACK routines.
  205. interp2d : 2-D interpolation
  206. Notes
  207. -----
  208. Calling `interp1d` with NaNs present in input values results in
  209. undefined behaviour.
  210. Input values `x` and `y` must be convertible to `float` values like
  211. `int` or `float`.
  212. If the values in `x` are not unique, the resulting behavior is
  213. undefined and specific to the choice of `kind`, i.e., changing
  214. `kind` will change the behavior for duplicates.
  215. Examples
  216. --------
  217. >>> import numpy as np
  218. >>> import matplotlib.pyplot as plt
  219. >>> from scipy import interpolate
  220. >>> x = np.arange(0, 10)
  221. >>> y = np.exp(-x/3.0)
  222. >>> f = interpolate.interp1d(x, y)
  223. >>> xnew = np.arange(0, 9, 0.1)
  224. >>> ynew = f(xnew) # use interpolation function returned by `interp1d`
  225. >>> plt.plot(x, y, 'o', xnew, ynew, '-')
  226. >>> plt.show()
  227. """
  228. def __init__(self, x, y, kind='linear', axis=-1,
  229. copy=True, bounds_error=None, fill_value=np.nan,
  230. assume_sorted=False):
  231. """ Initialize a 1-D linear interpolation class."""
  232. _Interpolator1D.__init__(self, x, y, axis=axis)
  233. self.bounds_error = bounds_error # used by fill_value setter
  234. # `copy` keyword semantics changed in NumPy 2.0, once that is
  235. # the minimum version this can use `copy=None`.
  236. self.copy = copy
  237. if not copy:
  238. self.copy = copy_if_needed
  239. if kind in ['zero', 'slinear', 'quadratic', 'cubic']:
  240. order = {'zero': 0, 'slinear': 1,
  241. 'quadratic': 2, 'cubic': 3}[kind]
  242. kind = 'spline'
  243. elif isinstance(kind, int):
  244. order = kind
  245. kind = 'spline'
  246. elif kind not in ('linear', 'nearest', 'nearest-up', 'previous',
  247. 'next'):
  248. raise NotImplementedError(f"{kind} is unsupported: Use fitpack "
  249. "routines for other types.")
  250. x = array(x, copy=self.copy)
  251. y = array(y, copy=self.copy)
  252. if not assume_sorted:
  253. ind = np.argsort(x, kind="mergesort")
  254. x = x[ind]
  255. y = np.take(y, ind, axis=axis)
  256. if x.ndim != 1:
  257. raise ValueError("the x array must have exactly one dimension.")
  258. if y.ndim == 0:
  259. raise ValueError("the y array must have at least one dimension.")
  260. # Force-cast y to a floating-point type, if it's not yet one
  261. if not issubclass(y.dtype.type, np.inexact):
  262. y = y.astype(np.float64)
  263. # Backward compatibility
  264. self.axis = axis % y.ndim
  265. # Interpolation goes internally along the first axis
  266. self.y = y
  267. self._y = self._reshape_yi(self.y)
  268. self.x = x
  269. del y, x # clean up namespace to prevent misuse; use attributes
  270. self._kind = kind
  271. # Adjust to interpolation kind; store reference to *unbound*
  272. # interpolation methods, in order to avoid circular references to self
  273. # stored in the bound instance methods, and therefore delayed garbage
  274. # collection. See: https://docs.python.org/reference/datamodel.html
  275. if kind in ('linear', 'nearest', 'nearest-up', 'previous', 'next'):
  276. # Make a "view" of the y array that is rotated to the interpolation
  277. # axis.
  278. minval = 1
  279. if kind == 'nearest':
  280. # Do division before addition to prevent possible integer
  281. # overflow
  282. self._side = 'left'
  283. self.x_bds = self.x / 2.0
  284. self.x_bds = self.x_bds[1:] + self.x_bds[:-1]
  285. self._call = self.__class__._call_nearest
  286. elif kind == 'nearest-up':
  287. # Do division before addition to prevent possible integer
  288. # overflow
  289. self._side = 'right'
  290. self.x_bds = self.x / 2.0
  291. self.x_bds = self.x_bds[1:] + self.x_bds[:-1]
  292. self._call = self.__class__._call_nearest
  293. elif kind == 'previous':
  294. # Side for np.searchsorted and index for clipping
  295. self._side = 'left'
  296. self._ind = 0
  297. # Move x by one floating point value to the left
  298. self._x_shift = np.nextafter(self.x, -np.inf)
  299. self._call = self.__class__._call_previousnext
  300. if _do_extrapolate(fill_value):
  301. self._check_and_update_bounds_error_for_extrapolation()
  302. # assume y is sorted by x ascending order here.
  303. fill_value = (np.nan, np.take(self.y, -1, axis))
  304. elif kind == 'next':
  305. self._side = 'right'
  306. self._ind = 1
  307. # Move x by one floating point value to the right
  308. self._x_shift = np.nextafter(self.x, np.inf)
  309. self._call = self.__class__._call_previousnext
  310. if _do_extrapolate(fill_value):
  311. self._check_and_update_bounds_error_for_extrapolation()
  312. # assume y is sorted by x ascending order here.
  313. fill_value = (np.take(self.y, 0, axis), np.nan)
  314. else:
  315. # Check if we can delegate to numpy.interp (2x-10x faster).
  316. np_dtypes = (np.dtype(np.float64), np.dtype(int))
  317. cond = self.x.dtype in np_dtypes and self.y.dtype in np_dtypes
  318. cond = cond and self.y.ndim == 1
  319. cond = cond and not _do_extrapolate(fill_value)
  320. if cond:
  321. self._call = self.__class__._call_linear_np
  322. else:
  323. self._call = self.__class__._call_linear
  324. else:
  325. minval = order + 1
  326. rewrite_nan = False
  327. xx, yy = self.x, self._y
  328. if order > 1:
  329. # Quadratic or cubic spline. If input contains even a single
  330. # nan, then the output is all nans. We cannot just feed data
  331. # with nans to make_interp_spline because it calls LAPACK.
  332. # So, we make up a bogus x and y with no nans and use it
  333. # to get the correct shape of the output, which we then fill
  334. # with nans.
  335. # For slinear or zero order spline, we just pass nans through.
  336. mask = np.isnan(self.x)
  337. if mask.any():
  338. sx = self.x[~mask]
  339. if sx.size == 0:
  340. raise ValueError("`x` array is all-nan")
  341. xx = np.linspace(np.nanmin(self.x),
  342. np.nanmax(self.x),
  343. len(self.x))
  344. rewrite_nan = True
  345. if np.isnan(self._y).any():
  346. yy = np.ones_like(self._y)
  347. rewrite_nan = True
  348. self._spline = make_interp_spline(xx, yy, k=order,
  349. check_finite=False)
  350. if rewrite_nan:
  351. self._call = self.__class__._call_nan_spline
  352. else:
  353. self._call = self.__class__._call_spline
  354. if len(self.x) < minval:
  355. raise ValueError(f"x and y arrays must have at least {minval} entries")
  356. self.fill_value = fill_value # calls the setter, can modify bounds_err
  357. @property
  358. def fill_value(self):
  359. """The fill value."""
  360. # backwards compat: mimic a public attribute
  361. return self._fill_value_orig
  362. @fill_value.setter
  363. def fill_value(self, fill_value):
  364. # extrapolation only works for nearest neighbor and linear methods
  365. if _do_extrapolate(fill_value):
  366. self._check_and_update_bounds_error_for_extrapolation()
  367. self._extrapolate = True
  368. else:
  369. broadcast_shape = (self.y.shape[:self.axis] +
  370. self.y.shape[self.axis + 1:])
  371. if len(broadcast_shape) == 0:
  372. broadcast_shape = (1,)
  373. # it's either a pair (_below_range, _above_range) or a single value
  374. # for both above and below range
  375. if isinstance(fill_value, tuple) and len(fill_value) == 2:
  376. below_above = [np.asarray(fill_value[0]),
  377. np.asarray(fill_value[1])]
  378. names = ('fill_value (below)', 'fill_value (above)')
  379. for ii in range(2):
  380. below_above[ii] = _check_broadcast_up_to(
  381. below_above[ii], broadcast_shape, names[ii])
  382. else:
  383. fill_value = np.asarray(fill_value)
  384. below_above = [_check_broadcast_up_to(
  385. fill_value, broadcast_shape, 'fill_value')] * 2
  386. self._fill_value_below, self._fill_value_above = below_above
  387. self._extrapolate = False
  388. if self.bounds_error is None:
  389. self.bounds_error = True
  390. # backwards compat: fill_value was a public attr; make it writeable
  391. self._fill_value_orig = fill_value
  392. def _check_and_update_bounds_error_for_extrapolation(self):
  393. if self.bounds_error:
  394. raise ValueError("Cannot extrapolate and raise "
  395. "at the same time.")
  396. self.bounds_error = False
  397. def _call_linear_np(self, x_new):
  398. # Note that out-of-bounds values are taken care of in self._evaluate
  399. return np.interp(x_new, self.x, self.y)
  400. def _call_linear(self, x_new):
  401. # 2. Find where in the original data, the values to interpolate
  402. # would be inserted.
  403. # Note: If x_new[n] == x[m], then m is returned by searchsorted.
  404. x_new_indices = searchsorted(self.x, x_new)
  405. # 3. Clip x_new_indices so that they are within the range of
  406. # self.x indices and at least 1. Removes mis-interpolation
  407. # of x_new[n] = x[0]
  408. x_new_indices = x_new_indices.clip(1, len(self.x)-1).astype(int)
  409. # 4. Calculate the slope of regions that each x_new value falls in.
  410. lo = x_new_indices - 1
  411. hi = x_new_indices
  412. x_lo = self.x[lo]
  413. x_hi = self.x[hi]
  414. y_lo = self._y[lo]
  415. y_hi = self._y[hi]
  416. # Note that the following two expressions rely on the specifics of the
  417. # broadcasting semantics.
  418. slope = (y_hi - y_lo) / (x_hi - x_lo)[:, None]
  419. # 5. Calculate the actual value for each entry in x_new.
  420. y_new = slope*(x_new - x_lo)[:, None] + y_lo
  421. return y_new
  422. def _call_nearest(self, x_new):
  423. """ Find nearest neighbor interpolated y_new = f(x_new)."""
  424. # 2. Find where in the averaged data the values to interpolate
  425. # would be inserted.
  426. # Note: use side='left' (right) to searchsorted() to define the
  427. # halfway point to be nearest to the left (right) neighbor
  428. x_new_indices = searchsorted(self.x_bds, x_new, side=self._side)
  429. # 3. Clip x_new_indices so that they are within the range of x indices.
  430. x_new_indices = x_new_indices.clip(0, len(self.x)-1).astype(intp)
  431. # 4. Calculate the actual value for each entry in x_new.
  432. y_new = self._y[x_new_indices]
  433. return y_new
  434. def _call_previousnext(self, x_new):
  435. """Use previous/next neighbor of x_new, y_new = f(x_new)."""
  436. # 1. Get index of left/right value
  437. x_new_indices = searchsorted(self._x_shift, x_new, side=self._side)
  438. # 2. Clip x_new_indices so that they are within the range of x indices.
  439. x_new_indices = x_new_indices.clip(1-self._ind,
  440. len(self.x)-self._ind).astype(intp)
  441. # 3. Calculate the actual value for each entry in x_new.
  442. y_new = self._y[x_new_indices+self._ind-1]
  443. return y_new
  444. def _call_spline(self, x_new):
  445. return self._spline(x_new)
  446. def _call_nan_spline(self, x_new):
  447. out = self._spline(x_new)
  448. out[...] = np.nan
  449. return out
  450. def _evaluate(self, x_new):
  451. # 1. Handle values in x_new that are outside of x. Throw error,
  452. # or return a list of mask array indicating the outofbounds values.
  453. # The behavior is set by the bounds_error variable.
  454. x_new = asarray(x_new)
  455. y_new = self._call(self, x_new)
  456. if not self._extrapolate:
  457. below_bounds, above_bounds = self._check_bounds(x_new)
  458. if len(y_new) > 0:
  459. # Note fill_value must be broadcast up to the proper size
  460. # and flattened to work here
  461. y_new[below_bounds] = self._fill_value_below
  462. y_new[above_bounds] = self._fill_value_above
  463. return y_new
  464. def _check_bounds(self, x_new):
  465. """Check the inputs for being in the bounds of the interpolated data.
  466. Parameters
  467. ----------
  468. x_new : array
  469. Returns
  470. -------
  471. out_of_bounds : bool array
  472. The mask on x_new of values that are out of the bounds.
  473. """
  474. # If self.bounds_error is True, we raise an error if any x_new values
  475. # fall outside the range of x. Otherwise, we return an array indicating
  476. # which values are outside the boundary region.
  477. below_bounds = x_new < self.x[0]
  478. above_bounds = x_new > self.x[-1]
  479. if self.bounds_error and below_bounds.any():
  480. below_bounds_value = x_new[np.argmax(below_bounds)]
  481. raise ValueError(f"A value ({below_bounds_value}) in x_new is below "
  482. f"the interpolation range's minimum value ({self.x[0]}).")
  483. if self.bounds_error and above_bounds.any():
  484. above_bounds_value = x_new[np.argmax(above_bounds)]
  485. raise ValueError(f"A value ({above_bounds_value}) in x_new is above "
  486. f"the interpolation range's maximum value ({self.x[-1]}).")
  487. # !! Should we emit a warning if some values are out of bounds?
  488. # !! matlab does not.
  489. return below_bounds, above_bounds
  490. class _PPolyBase:
  491. """Base class for piecewise polynomials."""
  492. __slots__ = ('_c', '_x', 'extrapolate', 'axis', '_asarray')
  493. # generic type compatibility with scipy-stubs
  494. __class_getitem__ = classmethod(GenericAlias)
  495. def __init__(self, c, x, extrapolate=None, axis=0):
  496. self._asarray = array_namespace(c, x).asarray
  497. self._c = np.asarray(c)
  498. self._x = np.ascontiguousarray(x, dtype=np.float64)
  499. if extrapolate is None:
  500. extrapolate = True
  501. elif extrapolate != 'periodic':
  502. extrapolate = bool(extrapolate)
  503. self.extrapolate = extrapolate
  504. if self._c.ndim < 2:
  505. raise ValueError("Coefficients array must be at least "
  506. "2-dimensional.")
  507. if not (0 <= axis < self._c.ndim - 1):
  508. raise ValueError(f"axis={axis} must be between 0 and {self._c.ndim-1}")
  509. self.axis = axis
  510. if axis != 0:
  511. # move the interpolation axis to be the first one in self.c
  512. # More specifically, the target shape for self.c is (k, m, ...),
  513. # and axis !=0 means that we have c.shape (..., k, m, ...)
  514. # ^
  515. # axis
  516. # So we roll two of them.
  517. self._c = np.moveaxis(self._c, axis+1, 0)
  518. self._c = np.moveaxis(self._c, axis+1, 0)
  519. if self._x.ndim != 1:
  520. raise ValueError("x must be 1-dimensional")
  521. if self._x.size < 2:
  522. raise ValueError("at least 2 breakpoints are needed")
  523. if self._c.ndim < 2:
  524. raise ValueError("c must have at least 2 dimensions")
  525. if self._c.shape[0] == 0:
  526. raise ValueError("polynomial must be at least of order 0")
  527. if self._c.shape[1] != self._x.size-1:
  528. raise ValueError("number of coefficients != len(x)-1")
  529. dx = np.diff(self._x)
  530. if not (np.all(dx >= 0) or np.all(dx <= 0)):
  531. raise ValueError("`x` must be strictly increasing or decreasing.")
  532. dtype = self._get_dtype(self._c.dtype)
  533. self._c = np.ascontiguousarray(self._c, dtype=dtype)
  534. def _get_dtype(self, dtype):
  535. if np.issubdtype(dtype, np.complexfloating) \
  536. or np.issubdtype(self._c.dtype, np.complexfloating):
  537. return np.complex128
  538. else:
  539. return np.float64
  540. @property
  541. def x(self):
  542. return self._asarray(self._x)
  543. @x.setter
  544. def x(self, xval):
  545. self._x = np.asarray(xval)
  546. @property
  547. def c(self):
  548. return self._asarray(self._c)
  549. @c.setter
  550. def c(self, cval):
  551. self._c = np.asarray(cval)
  552. @classmethod
  553. def construct_fast(cls, c, x, extrapolate=None, axis=0):
  554. """
  555. Construct the piecewise polynomial without making checks.
  556. Takes the same parameters as the constructor. Input arguments
  557. ``c`` and ``x`` must be arrays of the correct shape and type. The
  558. ``c`` array can only be of dtypes float and complex, and ``x``
  559. array must have dtype float.
  560. """
  561. self = object.__new__(cls)
  562. self._c = np.asarray(c)
  563. self._x = np.asarray(x)
  564. self.axis = axis
  565. if extrapolate is None:
  566. extrapolate = True
  567. self.extrapolate = extrapolate
  568. self._asarray = array_namespace(c, x).asarray
  569. return self
  570. def _ensure_c_contiguous(self):
  571. """
  572. c and x may be modified by the user. The Cython code expects
  573. that they are C contiguous.
  574. """
  575. if not self._x.flags.c_contiguous:
  576. self._x = self._x.copy()
  577. if not self._c.flags.c_contiguous:
  578. self._c = self._c.copy()
  579. def extend(self, c, x):
  580. """
  581. Add additional breakpoints and coefficients to the polynomial.
  582. Parameters
  583. ----------
  584. c : ndarray, size (k, m, ...)
  585. Additional coefficients for polynomials in intervals. Note that
  586. the first additional interval will be formed using one of the
  587. ``self.x`` end points.
  588. x : ndarray, size (m,)
  589. Additional breakpoints. Must be sorted in the same order as
  590. ``self.x`` and either to the right or to the left of the current
  591. breakpoints.
  592. Notes
  593. -----
  594. This method is not thread safe and must not be executed concurrently
  595. with other methods available in this class. Doing so may cause
  596. unexpected errors or numerical output mismatches.
  597. """
  598. c = np.asarray(c)
  599. x = np.asarray(x)
  600. if c.ndim < 2:
  601. raise ValueError("invalid dimensions for c")
  602. if x.ndim != 1:
  603. raise ValueError("invalid dimensions for x")
  604. if x.shape[0] != c.shape[1]:
  605. raise ValueError(f"Shapes of x {x.shape} and c {c.shape} are incompatible")
  606. if c.shape[2:] != self._c.shape[2:] or c.ndim != self._c.ndim:
  607. raise ValueError(
  608. f"Shapes of c {c.shape} and self._c {self._c.shape} are incompatible"
  609. )
  610. if c.size == 0:
  611. return
  612. dx = np.diff(x)
  613. if not (np.all(dx >= 0) or np.all(dx <= 0)):
  614. raise ValueError("`x` is not sorted.")
  615. if self._x[-1] >= self._x[0]:
  616. if not x[-1] >= x[0]:
  617. raise ValueError("`x` is in the different order "
  618. "than `self.x`.")
  619. if x[0] >= self._x[-1]:
  620. action = 'append'
  621. elif x[-1] <= self._x[0]:
  622. action = 'prepend'
  623. else:
  624. raise ValueError("`x` is neither on the left or on the right "
  625. "from `self.x`.")
  626. else:
  627. if not x[-1] <= x[0]:
  628. raise ValueError("`x` is in the different order "
  629. "than `self.x`.")
  630. if x[0] <= self._x[-1]:
  631. action = 'append'
  632. elif x[-1] >= self._x[0]:
  633. action = 'prepend'
  634. else:
  635. raise ValueError("`x` is neither on the left or on the right "
  636. "from `self.x`.")
  637. dtype = self._get_dtype(c.dtype)
  638. k2 = max(c.shape[0], self._c.shape[0])
  639. c2 = np.zeros((k2, self._c.shape[1] + c.shape[1]) + self._c.shape[2:],
  640. dtype=dtype)
  641. if action == 'append':
  642. c2[k2-self._c.shape[0]:, :self._c.shape[1]] = self._c
  643. c2[k2-c.shape[0]:, self._c.shape[1]:] = c
  644. self._x = np.r_[self._x, x]
  645. elif action == 'prepend':
  646. c2[k2-self._c.shape[0]:, :c.shape[1]] = c
  647. c2[k2-c.shape[0]:, c.shape[1]:] = self._c
  648. self._x = np.r_[x, self._x]
  649. self._c = c2
  650. def __call__(self, x, nu=0, extrapolate=None):
  651. """
  652. Evaluate the piecewise polynomial or its derivative.
  653. Parameters
  654. ----------
  655. x : array_like
  656. Points to evaluate the interpolant at.
  657. nu : int, optional
  658. Order of derivative to evaluate. Must be non-negative.
  659. extrapolate : {bool, 'periodic', None}, optional
  660. If bool, determines whether to extrapolate to out-of-bounds points
  661. based on first and last intervals, or to return NaNs.
  662. If 'periodic', periodic extrapolation is used.
  663. If None (default), use `self.extrapolate`.
  664. Returns
  665. -------
  666. y : array_like
  667. Interpolated values. Shape is determined by replacing
  668. the interpolation axis in the original array with the shape of x.
  669. Notes
  670. -----
  671. Derivatives are evaluated piecewise for each polynomial
  672. segment, even if the polynomial is not differentiable at the
  673. breakpoints. The polynomial intervals are considered half-open,
  674. ``[a, b)``, except for the last interval which is closed
  675. ``[a, b]``.
  676. """
  677. if extrapolate is None:
  678. extrapolate = self.extrapolate
  679. x = np.asarray(x)
  680. x_shape, x_ndim = x.shape, x.ndim
  681. x = np.ascontiguousarray(x.ravel(), dtype=np.float64)
  682. # With periodic extrapolation we map x to the segment
  683. # [self.x[0], self.x[-1]].
  684. if extrapolate == 'periodic':
  685. x = self._x[0] + (x - self._x[0]) % (self._x[-1] - self._x[0])
  686. extrapolate = False
  687. out = np.empty((len(x), prod(self._c.shape[2:])), dtype=self._c.dtype)
  688. self._ensure_c_contiguous()
  689. self._evaluate(x, nu, extrapolate, out)
  690. out = out.reshape(x_shape + self._c.shape[2:])
  691. if self.axis != 0:
  692. # transpose to move the calculated values to the interpolation axis
  693. l = list(range(out.ndim))
  694. l = l[x_ndim:x_ndim+self.axis] + l[:x_ndim] + l[x_ndim+self.axis:]
  695. out = out.transpose(l)
  696. return self._asarray(out)
  697. @xp_capabilities(
  698. cpu_only=True, jax_jit=False,
  699. skip_backends=[
  700. ("dask.array",
  701. "https://github.com/data-apis/array-api-extra/issues/488")
  702. ]
  703. )
  704. class PPoly(_PPolyBase):
  705. """Piecewise polynomial in the power basis.
  706. The polynomial between ``x[i]`` and ``x[i + 1]`` is written in the
  707. local power basis::
  708. S = sum(c[m, i] * (xp - x[i])**(k-m) for m in range(k+1))
  709. where ``k`` is the degree of the polynomial.
  710. Parameters
  711. ----------
  712. c : ndarray, shape (k+1, m, ...)
  713. Polynomial coefficients, degree `k` and `m` intervals.
  714. x : ndarray, shape (m+1,)
  715. Polynomial breakpoints. Must be sorted in either increasing or
  716. decreasing order.
  717. extrapolate : bool or 'periodic', optional
  718. If bool, determines whether to extrapolate to out-of-bounds points
  719. based on first and last intervals, or to return NaNs. If 'periodic',
  720. periodic extrapolation is used. Default is True.
  721. axis : int, optional
  722. Interpolation axis. Default is zero.
  723. Attributes
  724. ----------
  725. x : ndarray
  726. Breakpoints.
  727. c : ndarray
  728. Coefficients of the polynomials. They are reshaped
  729. to a 3-D array with the last dimension representing
  730. the trailing dimensions of the original coefficient array.
  731. axis : int
  732. Interpolation axis.
  733. Methods
  734. -------
  735. __call__
  736. derivative
  737. antiderivative
  738. integrate
  739. solve
  740. roots
  741. extend
  742. from_spline
  743. from_bernstein_basis
  744. construct_fast
  745. See also
  746. --------
  747. BPoly : piecewise polynomials in the Bernstein basis
  748. Notes
  749. -----
  750. High-order polynomials in the power basis can be numerically
  751. unstable. Precision problems can start to appear for orders
  752. larger than 20-30.
  753. """
  754. def _evaluate(self, x, nu, extrapolate, out):
  755. _ppoly.evaluate(self._c.reshape(self._c.shape[0], self._c.shape[1], -1),
  756. self._x, x, nu, bool(extrapolate), out)
  757. def derivative(self, nu=1):
  758. """
  759. Construct a new piecewise polynomial representing the derivative.
  760. Parameters
  761. ----------
  762. nu : int, optional
  763. Order of derivative to evaluate. Default is 1, i.e., compute the
  764. first derivative. If negative, the antiderivative is returned.
  765. Returns
  766. -------
  767. pp : PPoly
  768. Piecewise polynomial of order k2 = k - n representing the derivative
  769. of this polynomial.
  770. Notes
  771. -----
  772. Derivatives are evaluated piecewise for each polynomial
  773. segment, even if the polynomial is not differentiable at the
  774. breakpoints. The polynomial intervals are considered half-open,
  775. ``[a, b)``, except for the last interval which is closed
  776. ``[a, b]``.
  777. """
  778. if nu < 0:
  779. return self.antiderivative(-nu)
  780. # reduce order
  781. if nu == 0:
  782. c2 = self._c.copy()
  783. else:
  784. c2 = self._c[:-nu, :].copy()
  785. if c2.shape[0] == 0:
  786. # derivative of order 0 is zero
  787. c2 = np.zeros((1,) + c2.shape[1:], dtype=c2.dtype)
  788. # multiply by the correct rising factorials
  789. factor = spec.poch(np.arange(c2.shape[0], 0, -1), nu)
  790. c2 *= factor[(slice(None),) + (None,)*(c2.ndim-1)]
  791. # construct a compatible polynomial
  792. c2 = self._asarray(c2)
  793. return self.construct_fast(c2, self.x, self.extrapolate, self.axis)
  794. def antiderivative(self, nu=1):
  795. """
  796. Construct a new piecewise polynomial representing the antiderivative.
  797. Antiderivative is also the indefinite integral of the function,
  798. and derivative is its inverse operation.
  799. Parameters
  800. ----------
  801. nu : int, optional
  802. Order of antiderivative to evaluate. Default is 1, i.e., compute
  803. the first integral. If negative, the derivative is returned.
  804. Returns
  805. -------
  806. pp : PPoly
  807. Piecewise polynomial of order k2 = k + n representing
  808. the antiderivative of this polynomial.
  809. Notes
  810. -----
  811. The antiderivative returned by this function is continuous and
  812. continuously differentiable to order n-1, up to floating point
  813. rounding error.
  814. If antiderivative is computed and ``self.extrapolate='periodic'``,
  815. it will be set to False for the returned instance. This is done because
  816. the antiderivative is no longer periodic and its correct evaluation
  817. outside of the initially given x interval is difficult.
  818. """
  819. if nu <= 0:
  820. return self.derivative(-nu)
  821. c = np.zeros((self._c.shape[0] + nu, self._c.shape[1]) + self._c.shape[2:],
  822. dtype=self._c.dtype)
  823. c[:-nu] = self._c
  824. # divide by the correct rising factorials
  825. factor = spec.poch(np.arange(self._c.shape[0], 0, -1), nu)
  826. c[:-nu] /= factor[(slice(None),) + (None,)*(c.ndim-1)]
  827. # fix continuity of added degrees of freedom
  828. self._ensure_c_contiguous()
  829. _ppoly.fix_continuity(c.reshape(c.shape[0], c.shape[1], -1),
  830. self._x, nu - 1)
  831. if self.extrapolate == 'periodic':
  832. extrapolate = False
  833. else:
  834. extrapolate = self.extrapolate
  835. # construct a compatible polynomial
  836. c = self._asarray(c)
  837. return self.construct_fast(c, self.x, extrapolate, self.axis)
  838. def integrate(self, a, b, extrapolate=None):
  839. """
  840. Compute a definite integral over a piecewise polynomial.
  841. Parameters
  842. ----------
  843. a : float
  844. Lower integration bound
  845. b : float
  846. Upper integration bound
  847. extrapolate : {bool, 'periodic', None}, optional
  848. If bool, determines whether to extrapolate to out-of-bounds points
  849. based on first and last intervals, or to return NaNs.
  850. If 'periodic', periodic extrapolation is used.
  851. If None (default), use `self.extrapolate`.
  852. Returns
  853. -------
  854. ig : array_like
  855. Definite integral of the piecewise polynomial over [a, b]
  856. """
  857. if extrapolate is None:
  858. extrapolate = self.extrapolate
  859. # Swap integration bounds if needed
  860. sign = 1
  861. if b < a:
  862. a, b = b, a
  863. sign = -1
  864. range_int = np.empty((prod(self._c.shape[2:]),), dtype=self._c.dtype)
  865. self._ensure_c_contiguous()
  866. # Compute the integral.
  867. if extrapolate == 'periodic':
  868. # Split the integral into the part over period (can be several
  869. # of them) and the remaining part.
  870. xs, xe = self._x[0], self._x[-1]
  871. period = xe - xs
  872. interval = b - a
  873. n_periods, left = divmod(interval, period)
  874. if n_periods > 0:
  875. _ppoly.integrate(
  876. self._c.reshape(self._c.shape[0], self._c.shape[1], -1),
  877. self._x, xs, xe, False, out=range_int)
  878. range_int *= n_periods
  879. else:
  880. range_int.fill(0)
  881. # Map a to [xs, xe], b is always a + left.
  882. a = xs + (a - xs) % period
  883. b = a + left
  884. # If b <= xe then we need to integrate over [a, b], otherwise
  885. # over [a, xe] and from xs to what is remained.
  886. remainder_int = np.empty_like(range_int)
  887. if b <= xe:
  888. _ppoly.integrate(
  889. self._c.reshape(self._c.shape[0], self._c.shape[1], -1),
  890. self._x, a, b, False, out=remainder_int)
  891. range_int += remainder_int
  892. else:
  893. _ppoly.integrate(
  894. self._c.reshape(self._c.shape[0], self._c.shape[1], -1),
  895. self._x, a, xe, False, out=remainder_int)
  896. range_int += remainder_int
  897. _ppoly.integrate(
  898. self._c.reshape(self._c.shape[0], self._c.shape[1], -1),
  899. self._x, xs, xs + left + a - xe, False, out=remainder_int)
  900. range_int += remainder_int
  901. else:
  902. _ppoly.integrate(
  903. self._c.reshape(self._c.shape[0], self._c.shape[1], -1),
  904. self._x, a, b, bool(extrapolate), out=range_int)
  905. # Return
  906. range_int *= sign
  907. return self._asarray(range_int.reshape(self._c.shape[2:]))
  908. def solve(self, y=0., discontinuity=True, extrapolate=None):
  909. """
  910. Find real solutions of the equation ``pp(x) == y``.
  911. Parameters
  912. ----------
  913. y : float, optional
  914. Right-hand side. Default is zero.
  915. discontinuity : bool, optional
  916. Whether to report sign changes across discontinuities at
  917. breakpoints as roots.
  918. extrapolate : {bool, 'periodic', None}, optional
  919. If bool, determines whether to return roots from the polynomial
  920. extrapolated based on first and last intervals, 'periodic' works
  921. the same as False. If None (default), use `self.extrapolate`.
  922. Returns
  923. -------
  924. roots : ndarray
  925. Roots of the polynomial(s).
  926. If the PPoly object describes multiple polynomials, the
  927. return value is an object array whose each element is an
  928. ndarray containing the roots.
  929. Notes
  930. -----
  931. This routine works only on real-valued polynomials.
  932. If the piecewise polynomial contains sections that are
  933. identically zero, the root list will contain the start point
  934. of the corresponding interval, followed by a ``nan`` value.
  935. If the polynomial is discontinuous across a breakpoint, and
  936. there is a sign change across the breakpoint, this is reported
  937. if the `discont` parameter is True.
  938. Examples
  939. --------
  940. Finding roots of ``[x**2 - 1, (x - 1)**2]`` defined on intervals
  941. ``[-2, 1], [1, 2]``:
  942. >>> import numpy as np
  943. >>> from scipy.interpolate import PPoly
  944. >>> pp = PPoly(np.array([[1, -4, 3], [1, 0, 0]]).T, [-2, 1, 2])
  945. >>> pp.solve()
  946. array([-1., 1.])
  947. """
  948. if extrapolate is None:
  949. extrapolate = self.extrapolate
  950. self._ensure_c_contiguous()
  951. if np.issubdtype(self._c.dtype, np.complexfloating):
  952. raise ValueError("Root finding is only for "
  953. "real-valued polynomials")
  954. y = float(y)
  955. r = _ppoly.real_roots(self._c.reshape(self._c.shape[0], self._c.shape[1], -1),
  956. self._x, y, bool(discontinuity),
  957. bool(extrapolate))
  958. if self._c.ndim == 2:
  959. return r[0]
  960. else:
  961. r2 = np.empty(prod(self._c.shape[2:]), dtype=object)
  962. # this for-loop is equivalent to ``r2[...] = r``, but that's broken
  963. # in NumPy 1.6.0
  964. for ii, root in enumerate(r):
  965. r2[ii] = root
  966. return r2.reshape(self._c.shape[2:])
  967. def roots(self, discontinuity=True, extrapolate=None):
  968. """
  969. Find real roots of the piecewise polynomial.
  970. Parameters
  971. ----------
  972. discontinuity : bool, optional
  973. Whether to report sign changes across discontinuities at
  974. breakpoints as roots.
  975. extrapolate : {bool, 'periodic', None}, optional
  976. If bool, determines whether to return roots from the polynomial
  977. extrapolated based on first and last intervals, 'periodic' works
  978. the same as False. If None (default), use `self.extrapolate`.
  979. Returns
  980. -------
  981. roots : ndarray
  982. Roots of the polynomial(s).
  983. If the PPoly object describes multiple polynomials, the
  984. return value is an object array whose each element is an
  985. ndarray containing the roots.
  986. See Also
  987. --------
  988. PPoly.solve
  989. """
  990. return self.solve(0, discontinuity, extrapolate)
  991. @classmethod
  992. def from_spline(cls, tck, extrapolate=None):
  993. """
  994. Construct a piecewise polynomial from a spline
  995. Parameters
  996. ----------
  997. tck
  998. A spline, as returned by `splrep` or a BSpline object.
  999. extrapolate : bool or 'periodic', optional
  1000. If bool, determines whether to extrapolate to out-of-bounds points
  1001. based on first and last intervals, or to return NaNs.
  1002. If 'periodic', periodic extrapolation is used. Default is True.
  1003. Examples
  1004. --------
  1005. Construct an interpolating spline and convert it to a `PPoly` instance
  1006. >>> import numpy as np
  1007. >>> from scipy.interpolate import splrep, PPoly
  1008. >>> x = np.linspace(0, 1, 11)
  1009. >>> y = np.sin(2*np.pi*x)
  1010. >>> tck = splrep(x, y, s=0)
  1011. >>> p = PPoly.from_spline(tck)
  1012. >>> isinstance(p, PPoly)
  1013. True
  1014. Note that this function only supports 1D splines out of the box.
  1015. If the ``tck`` object represents a parametric spline (e.g. constructed
  1016. by `splprep` or a `BSpline` with ``c.ndim > 1``), you will need to loop
  1017. over the dimensions manually.
  1018. >>> from scipy.interpolate import splprep, splev
  1019. >>> t = np.linspace(0, 1, 11)
  1020. >>> x = np.sin(2*np.pi*t)
  1021. >>> y = np.cos(2*np.pi*t)
  1022. >>> (t, c, k), u = splprep([x, y], s=0)
  1023. Note that ``c`` is a list of two arrays of length 11.
  1024. >>> unew = np.arange(0, 1.01, 0.01)
  1025. >>> out = splev(unew, (t, c, k))
  1026. To convert this spline to the power basis, we convert each
  1027. component of the list of b-spline coefficients, ``c``, into the
  1028. corresponding cubic polynomial.
  1029. >>> polys = [PPoly.from_spline((t, cj, k)) for cj in c]
  1030. >>> polys[0].c.shape
  1031. (4, 14)
  1032. Note that the coefficients of the polynomials `polys` are in the
  1033. power basis and their dimensions reflect just that: here 4 is the order
  1034. (degree+1), and 14 is the number of intervals---which is nothing but
  1035. the length of the knot array of the original `tck` minus one.
  1036. Optionally, we can stack the components into a single `PPoly` along
  1037. the third dimension:
  1038. >>> cc = np.dstack([p.c for p in polys]) # has shape = (4, 14, 2)
  1039. >>> poly = PPoly(cc, polys[0].x)
  1040. >>> np.allclose(poly(unew).T, # note the transpose to match `splev`
  1041. ... out, atol=1e-15)
  1042. True
  1043. """
  1044. if isinstance(tck, BSpline):
  1045. t, c, k = tck._t, tck._c, tck.k
  1046. _asarray = tck._asarray
  1047. if extrapolate is None:
  1048. extrapolate = tck.extrapolate
  1049. else:
  1050. t, c, k = tck
  1051. _asarray = np.asarray
  1052. cvals = np.empty((k + 1, len(t)-1), dtype=c.dtype)
  1053. for m in range(k, -1, -1):
  1054. y = _fitpack_py.splev(t[:-1], (t, c, k), der=m)
  1055. cvals[k - m, :] = y / spec.gamma(m+1)
  1056. return cls.construct_fast(_asarray(cvals), _asarray(t), extrapolate)
  1057. @classmethod
  1058. def from_bernstein_basis(cls, bp, extrapolate=None):
  1059. """
  1060. Construct a piecewise polynomial in the power basis
  1061. from a polynomial in Bernstein basis.
  1062. Parameters
  1063. ----------
  1064. bp : BPoly
  1065. A Bernstein basis polynomial, as created by BPoly
  1066. extrapolate : bool or 'periodic', optional
  1067. If bool, determines whether to extrapolate to out-of-bounds points
  1068. based on first and last intervals, or to return NaNs.
  1069. If 'periodic', periodic extrapolation is used. Default is True.
  1070. """
  1071. if not isinstance(bp, BPoly):
  1072. raise TypeError(f".from_bernstein_basis only accepts BPoly instances. "
  1073. f"Got {type(bp)} instead.")
  1074. dx = np.diff(bp._x)
  1075. k = bp._c.shape[0] - 1 # polynomial order
  1076. rest = (None,)*(bp.c.ndim-2)
  1077. c = np.zeros_like(bp._c)
  1078. for a in range(k+1):
  1079. factor = (-1)**a * comb(k, a) * bp._c[a]
  1080. for s in range(a, k+1):
  1081. val = comb(k-a, s-a) * (-1)**s
  1082. c[k-s] += factor * val / dx[(slice(None),)+rest]**s
  1083. if extrapolate is None:
  1084. extrapolate = bp.extrapolate
  1085. return cls.construct_fast(bp._asarray(c), bp.x, extrapolate, bp.axis)
  1086. @xp_capabilities(
  1087. cpu_only=True, jax_jit=False,
  1088. skip_backends=[
  1089. ("dask.array",
  1090. "https://github.com/data-apis/array-api-extra/issues/488")
  1091. ]
  1092. )
  1093. class BPoly(_PPolyBase):
  1094. """Piecewise polynomial in the Bernstein basis.
  1095. The polynomial between ``x[i]`` and ``x[i + 1]`` is written in the
  1096. Bernstein polynomial basis::
  1097. S = sum(c[a, i] * b(a, k; x) for a in range(k+1)),
  1098. where ``k`` is the degree of the polynomial, and::
  1099. b(a, k; x) = binom(k, a) * t**a * (1 - t)**(k - a),
  1100. with ``t = (x - x[i]) / (x[i+1] - x[i])`` and ``binom`` is the binomial
  1101. coefficient.
  1102. Parameters
  1103. ----------
  1104. c : ndarray, shape (k, m, ...)
  1105. Polynomial coefficients, order `k` and `m` intervals
  1106. x : ndarray, shape (m+1,)
  1107. Polynomial breakpoints. Must be sorted in either increasing or
  1108. decreasing order.
  1109. extrapolate : bool, optional
  1110. If bool, determines whether to extrapolate to out-of-bounds points
  1111. based on first and last intervals, or to return NaNs. If 'periodic',
  1112. periodic extrapolation is used. Default is True.
  1113. axis : int, optional
  1114. Interpolation axis. Default is zero.
  1115. Attributes
  1116. ----------
  1117. x : ndarray
  1118. Breakpoints.
  1119. c : ndarray
  1120. Coefficients of the polynomials. They are reshaped
  1121. to a 3-D array with the last dimension representing
  1122. the trailing dimensions of the original coefficient array.
  1123. axis : int
  1124. Interpolation axis.
  1125. Methods
  1126. -------
  1127. __call__
  1128. extend
  1129. derivative
  1130. antiderivative
  1131. integrate
  1132. construct_fast
  1133. from_power_basis
  1134. from_derivatives
  1135. See also
  1136. --------
  1137. PPoly : piecewise polynomials in the power basis
  1138. Notes
  1139. -----
  1140. Properties of Bernstein polynomials are well documented in the literature,
  1141. see for example [1]_ [2]_ [3]_.
  1142. References
  1143. ----------
  1144. .. [1] https://en.wikipedia.org/wiki/Bernstein_polynomial
  1145. .. [2] Kenneth I. Joy, Bernstein polynomials,
  1146. http://www.idav.ucdavis.edu/education/CAGDNotes/Bernstein-Polynomials.pdf
  1147. .. [3] E. H. Doha, A. H. Bhrawy, and M. A. Saker, Boundary Value Problems,
  1148. vol 2011, article ID 829546, :doi:`10.1155/2011/829543`.
  1149. Examples
  1150. --------
  1151. >>> from scipy.interpolate import BPoly
  1152. >>> x = [0, 1]
  1153. >>> c = [[1], [2], [3]]
  1154. >>> bp = BPoly(c, x)
  1155. This creates a 2nd order polynomial
  1156. .. math::
  1157. B(x) = 1 \\times b_{0, 2}(x) + 2 \\times b_{1, 2}(x) + 3
  1158. \\times b_{2, 2}(x) \\\\
  1159. = 1 \\times (1-x)^2 + 2 \\times 2 x (1 - x) + 3 \\times x^2
  1160. """ # noqa: E501
  1161. def _evaluate(self, x, nu, extrapolate, out):
  1162. _ppoly.evaluate_bernstein(
  1163. self._c.reshape(self._c.shape[0], self._c.shape[1], -1),
  1164. self._x, x, nu, bool(extrapolate), out)
  1165. def derivative(self, nu=1):
  1166. """
  1167. Construct a new piecewise polynomial representing the derivative.
  1168. Parameters
  1169. ----------
  1170. nu : int, optional
  1171. Order of derivative to evaluate. Default is 1, i.e., compute the
  1172. first derivative. If negative, the antiderivative is returned.
  1173. Returns
  1174. -------
  1175. bp : BPoly
  1176. Piecewise polynomial of order k - nu representing the derivative of
  1177. this polynomial.
  1178. """
  1179. if nu < 0:
  1180. return self.antiderivative(-nu)
  1181. if nu > 1:
  1182. bp = self
  1183. for k in range(nu):
  1184. bp = bp.derivative()
  1185. return bp
  1186. # reduce order
  1187. if nu == 0:
  1188. c2 = self._c.copy()
  1189. else:
  1190. # For a polynomial
  1191. # B(x) = \sum_{a=0}^{k} c_a b_{a, k}(x),
  1192. # we use the fact that
  1193. # b'_{a, k} = k ( b_{a-1, k-1} - b_{a, k-1} ),
  1194. # which leads to
  1195. # B'(x) = \sum_{a=0}^{k-1} (c_{a+1} - c_a) b_{a, k-1}
  1196. #
  1197. # finally, for an interval [y, y + dy] with dy != 1,
  1198. # we need to correct for an extra power of dy
  1199. rest = (None,)*(self._c.ndim-2)
  1200. k = self._c.shape[0] - 1
  1201. dx = np.diff(self._x)[(None, slice(None))+rest]
  1202. c2 = k * np.diff(self._c, axis=0) / dx
  1203. if c2.shape[0] == 0:
  1204. # derivative of order 0 is zero
  1205. c2 = np.zeros((1,) + c2.shape[1:], dtype=c2.dtype)
  1206. # construct a compatible polynomial
  1207. c2 = self._asarray(c2)
  1208. return self.construct_fast(c2, self.x, self.extrapolate, self.axis)
  1209. def antiderivative(self, nu=1):
  1210. """
  1211. Construct a new piecewise polynomial representing the antiderivative.
  1212. Parameters
  1213. ----------
  1214. nu : int, optional
  1215. Order of antiderivative to evaluate. Default is 1, i.e., compute
  1216. the first integral. If negative, the derivative is returned.
  1217. Returns
  1218. -------
  1219. bp : BPoly
  1220. Piecewise polynomial of order k + nu representing the
  1221. antiderivative of this polynomial.
  1222. Notes
  1223. -----
  1224. If antiderivative is computed and ``self.extrapolate='periodic'``,
  1225. it will be set to False for the returned instance. This is done because
  1226. the antiderivative is no longer periodic and its correct evaluation
  1227. outside of the initially given x interval is difficult.
  1228. """
  1229. if nu <= 0:
  1230. return self.derivative(-nu)
  1231. if nu > 1:
  1232. bp = self
  1233. for k in range(nu):
  1234. bp = bp.antiderivative()
  1235. return bp
  1236. # Construct the indefinite integrals on individual intervals
  1237. c, x = self._c, self._x
  1238. k = c.shape[0]
  1239. c2 = np.zeros((k+1,) + c.shape[1:], dtype=c.dtype)
  1240. c2[1:, ...] = np.cumsum(c, axis=0) / k
  1241. delta = x[1:] - x[:-1]
  1242. c2 *= delta[(None, slice(None)) + (None,)*(c.ndim-2)]
  1243. # Now fix continuity: on the very first interval, take the integration
  1244. # constant to be zero; on an interval [x_j, x_{j+1}) with j>0,
  1245. # the integration constant is then equal to the jump of the `bp` at x_j.
  1246. # The latter is given by the coefficient of B_{n+1, n+1}
  1247. # *on the previous interval* (other B. polynomials are zero at the
  1248. # breakpoint). Finally, use the fact that BPs form a partition of unity.
  1249. c2[:,1:] += np.cumsum(c2[k, :], axis=0)[:-1]
  1250. if self.extrapolate == 'periodic':
  1251. extrapolate = False
  1252. else:
  1253. extrapolate = self.extrapolate
  1254. c2 = self._asarray(c2)
  1255. return self.construct_fast(c2, self.x, extrapolate, axis=self.axis)
  1256. def integrate(self, a, b, extrapolate=None):
  1257. """
  1258. Compute a definite integral over a piecewise polynomial.
  1259. Parameters
  1260. ----------
  1261. a : float
  1262. Lower integration bound
  1263. b : float
  1264. Upper integration bound
  1265. extrapolate : {bool, 'periodic', None}, optional
  1266. Whether to extrapolate to out-of-bounds points based on first
  1267. and last intervals, or to return NaNs. If 'periodic', periodic
  1268. extrapolation is used. If None (default), use `self.extrapolate`.
  1269. Returns
  1270. -------
  1271. array_like
  1272. Definite integral of the piecewise polynomial over [a, b]
  1273. """
  1274. # XXX: can probably use instead the fact that
  1275. # \int_0^{1} B_{j, n}(x) \dx = 1/(n+1)
  1276. ib = self.antiderivative()
  1277. if extrapolate is None:
  1278. extrapolate = self.extrapolate
  1279. # ib.extrapolate shouldn't be 'periodic', it is converted to
  1280. # False for 'periodic. in antiderivative() call.
  1281. if extrapolate != 'periodic':
  1282. ib.extrapolate = extrapolate
  1283. if extrapolate == 'periodic':
  1284. # Split the integral into the part over period (can be several
  1285. # of them) and the remaining part.
  1286. # For simplicity and clarity convert to a <= b case.
  1287. if a <= b:
  1288. sign = 1
  1289. else:
  1290. a, b = b, a
  1291. sign = -1
  1292. xs, xe = self._x[0], self._x[-1]
  1293. period = xe - xs
  1294. interval = b - a
  1295. n_periods, left = divmod(interval, period)
  1296. res = n_periods * (ib(xe) - ib(xs))
  1297. # Map a and b to [xs, xe].
  1298. a = xs + (a - xs) % period
  1299. b = a + left
  1300. # If b <= xe then we need to integrate over [a, b], otherwise
  1301. # over [a, xe] and from xs to what is remained.
  1302. if b <= xe:
  1303. res += ib(b) - ib(a)
  1304. else:
  1305. res += ib(xe) - ib(a) + ib(xs + left + a - xe) - ib(xs)
  1306. return self._asarray(sign * res)
  1307. else:
  1308. return ib(b) - ib(a)
  1309. def extend(self, c, x):
  1310. k = max(self._c.shape[0], c.shape[0])
  1311. self._c = self._raise_degree(self._c, k - self._c.shape[0])
  1312. c = self._raise_degree(c, k - c.shape[0])
  1313. return _PPolyBase.extend(self, c, x)
  1314. extend.__doc__ = _PPolyBase.extend.__doc__
  1315. @classmethod
  1316. def from_power_basis(cls, pp, extrapolate=None):
  1317. """
  1318. Construct a piecewise polynomial in Bernstein basis
  1319. from a power basis polynomial.
  1320. Parameters
  1321. ----------
  1322. pp : PPoly
  1323. A piecewise polynomial in the power basis
  1324. extrapolate : bool or 'periodic', optional
  1325. If bool, determines whether to extrapolate to out-of-bounds points
  1326. based on first and last intervals, or to return NaNs.
  1327. If 'periodic', periodic extrapolation is used. Default is True.
  1328. """
  1329. if not isinstance(pp, PPoly):
  1330. raise TypeError(f".from_power_basis only accepts PPoly instances. "
  1331. f"Got {type(pp)} instead.")
  1332. dx = np.diff(pp._x)
  1333. k = pp._c.shape[0] - 1 # polynomial order
  1334. rest = (None,)*(pp._c.ndim-2)
  1335. c = np.zeros_like(pp._c)
  1336. for a in range(k+1):
  1337. factor = pp._c[a] / comb(k, k-a) * dx[(slice(None),) + rest]**(k-a)
  1338. for j in range(k-a, k+1):
  1339. c[j] += factor * comb(j, k-a)
  1340. if extrapolate is None:
  1341. extrapolate = pp.extrapolate
  1342. return cls.construct_fast(pp._asarray(c), pp.x, extrapolate, pp.axis)
  1343. @classmethod
  1344. def from_derivatives(cls, xi, yi, orders=None, extrapolate=None):
  1345. """Construct a piecewise polynomial in the Bernstein basis,
  1346. compatible with the specified values and derivatives at breakpoints.
  1347. Parameters
  1348. ----------
  1349. xi : array_like
  1350. sorted 1-D array of x-coordinates
  1351. yi : array_like or list of array_likes
  1352. ``yi[i][j]`` is the ``j``\\ th derivative known at ``xi[i]``
  1353. orders : None or int or array_like of ints. Default: None.
  1354. Specifies the degree of local polynomials. If not None, some
  1355. derivatives are ignored.
  1356. extrapolate : bool or 'periodic', optional
  1357. If bool, determines whether to extrapolate to out-of-bounds points
  1358. based on first and last intervals, or to return NaNs.
  1359. If 'periodic', periodic extrapolation is used. Default is True.
  1360. Notes
  1361. -----
  1362. If ``k`` derivatives are specified at a breakpoint ``x``, the
  1363. constructed polynomial is exactly ``k`` times continuously
  1364. differentiable at ``x``, unless the ``order`` is provided explicitly.
  1365. In the latter case, the smoothness of the polynomial at
  1366. the breakpoint is controlled by the ``order``.
  1367. Deduces the number of derivatives to match at each end
  1368. from ``order`` and the number of derivatives available. If
  1369. possible it uses the same number of derivatives from
  1370. each end; if the number is odd it tries to take the
  1371. extra one from y2. In any case if not enough derivatives
  1372. are available at one end or another it draws enough to
  1373. make up the total from the other end.
  1374. If the order is too high and not enough derivatives are available,
  1375. an exception is raised.
  1376. Examples
  1377. --------
  1378. >>> from scipy.interpolate import BPoly
  1379. >>> BPoly.from_derivatives([0, 1], [[1, 2], [3, 4]])
  1380. Creates a polynomial `f(x)` of degree 3, defined on ``[0, 1]``
  1381. such that `f(0) = 1, df/dx(0) = 2, f(1) = 3, df/dx(1) = 4`
  1382. >>> BPoly.from_derivatives([0, 1, 2], [[0, 1], [0], [2]])
  1383. Creates a piecewise polynomial `f(x)`, such that
  1384. `f(0) = f(1) = 0`, `f(2) = 2`, and `df/dx(0) = 1`.
  1385. Based on the number of derivatives provided, the order of the
  1386. local polynomials is 2 on ``[0, 1]`` and 1 on ``[1, 2]``.
  1387. Notice that no restriction is imposed on the derivatives at
  1388. ``x = 1`` and ``x = 2``.
  1389. Indeed, the explicit form of the polynomial is::
  1390. f(x) = | x * (1 - x), 0 <= x < 1
  1391. | 2 * (x - 1), 1 <= x <= 2
  1392. So that f'(1-0) = -1 and f'(1+0) = 2
  1393. """
  1394. xi = np.asarray(xi)
  1395. if len(xi) != len(yi):
  1396. raise ValueError("xi and yi need to have the same length")
  1397. if np.any(xi[1:] - xi[:1] <= 0):
  1398. raise ValueError("x coordinates are not in increasing order")
  1399. # number of intervals
  1400. m = len(xi) - 1
  1401. # global poly order is k-1, local orders are <=k and can vary
  1402. try:
  1403. k = max(len(yi[i]) + len(yi[i+1]) for i in range(m))
  1404. except TypeError as e:
  1405. raise ValueError(
  1406. "Using a 1-D array for y? Please .reshape(-1, 1)."
  1407. ) from e
  1408. if orders is None:
  1409. orders = [None] * m
  1410. else:
  1411. if isinstance(orders, int | np.integer):
  1412. orders = [orders] * m
  1413. k = max(k, max(orders))
  1414. if any(o <= 0 for o in orders):
  1415. raise ValueError("Orders must be positive.")
  1416. c = []
  1417. for i in range(m):
  1418. y1, y2 = yi[i], yi[i+1]
  1419. if orders[i] is None:
  1420. n1, n2 = len(y1), len(y2)
  1421. else:
  1422. n = orders[i]+1
  1423. n1 = min(n//2, len(y1))
  1424. n2 = min(n - n1, len(y2))
  1425. n1 = min(n - n2, len(y2))
  1426. if n1 + n2 != n:
  1427. mesg = (
  1428. f"Point {xi[i]} has {len(y1)} derivatives, point {xi[i+1]} has "
  1429. f"{len(y2)} derivatives, but order {orders[i]} requested"
  1430. )
  1431. raise ValueError(mesg)
  1432. if not (n1 <= len(y1) and n2 <= len(y2)):
  1433. raise ValueError("`order` input incompatible with"
  1434. " length y1 or y2.")
  1435. b = BPoly._construct_from_derivatives(xi[i], xi[i+1],
  1436. y1[:n1], y2[:n2])
  1437. if len(b) < k:
  1438. b = BPoly._raise_degree(b, k - len(b))
  1439. c.append(b)
  1440. c = np.asarray(c)
  1441. return cls(c.swapaxes(0, 1), xi, extrapolate)
    @staticmethod
    def _construct_from_derivatives(xa, xb, ya, yb):
        r"""Compute the coefficients of a polynomial in the Bernstein basis
        given the values and derivatives at the edges.

        Return the coefficients of a polynomial in the Bernstein basis
        defined on ``[xa, xb]`` and having the values and derivatives at the
        endpoints `xa` and `xb` as specified by `ya` and `yb`.
        The polynomial constructed is of the minimal possible degree, i.e.,
        if the lengths of `ya` and `yb` are `na` and `nb`, the degree
        of the polynomial is ``na + nb - 1``.

        Parameters
        ----------
        xa : float
            Left-hand end point of the interval
        xb : float
            Right-hand end point of the interval
        ya : array_like
            Derivatives at `xa`. ``ya[0]`` is the value of the function, and
            ``ya[i]`` for ``i > 0`` is the value of the ``i``\ th derivative.
        yb : array_like
            Derivatives at `xb`.

        Returns
        -------
        array
            coefficient array of a polynomial having specified derivatives,
            of shape ``(na + nb,) + ya.shape[1:]`` and dtype complex128 if
            either input is complex, float64 otherwise.

        Raises
        ------
        ValueError
            If ``ya.shape[1:]`` and ``yb.shape[1:]`` differ.

        Notes
        -----
        This uses several facts about Bernstein basis functions.
        First of all,

        .. math:: b'_{a, n} = n (b_{a-1, n-1} - b_{a, n-1})

        If B(x) is a linear combination of the form

        .. math:: B(x) = \sum_{a=0}^{n} c_a b_{a, n},

        then :math:`B'(x) = n \sum_{a=0}^{n-1} (c_{a+1} - c_{a}) b_{a, n-1}`.
        Iterating the latter one, one finds for the q-th derivative

        .. math:: B^{q}(x) = n!/(n-q)! \sum_{a=0}^{n-q} Q_a b_{a, n-q},

        with

        .. math:: Q_a = \sum_{j=0}^{q} (-)^{j+q} comb(q, j) c_{j+a}

        This way, only `a=0` contributes to :math:`B^{q}(x = xa)`, and
        `c_q` are found one by one by iterating `q = 0, ..., na - 1`.

        At ``x = xb`` it's the same with ``a = n - q``.

        In the formulas above `n` is the degree on the unit interval. In the
        code, `n` is instead the number of coefficients ``na + nb``, so the
        degree is ``n - 1`` and ``spec.poch(n - q, q)`` equals
        ``(n-1)! / (n-1-q)!``. The factor ``(xb - xa)**q`` rescales the
        q-th derivative from ``[xa, xb]`` to the unit interval.
        """
        ya, yb = np.asarray(ya), np.asarray(yb)
        if ya.shape[1:] != yb.shape[1:]:
            raise ValueError(
                f"Shapes of ya {ya.shape} and yb {yb.shape} are incompatible"
            )

        # Output is complex128 if either side is complex, float64 otherwise.
        dta, dtb = ya.dtype, yb.dtype
        if (np.issubdtype(dta, np.complexfloating) or
                np.issubdtype(dtb, np.complexfloating)):
            dt = np.complex128
        else:
            dt = np.float64

        na, nb = len(ya), len(yb)
        n = na + nb  # number of coefficients; the degree is n - 1

        c = np.empty((na+nb,) + ya.shape[1:], dtype=dt)

        # compute coefficients of a polynomial degree na+nb-1
        # walk left-to-right: c[0], ..., c[na-1] from the derivatives at xa;
        # each c[q] depends only on the already-computed c[0..q-1].
        for q in range(0, na):
            c[q] = ya[q] / spec.poch(n - q, q) * (xb - xa)**q
            for j in range(0, q):
                c[q] -= (-1)**(j+q) * comb(q, j) * c[j]

        # now walk right-to-left: c[-1], ..., c[-nb] from the derivatives at
        # xb; each c[-q-1] uses only c[-q], ..., c[-1] filled in earlier
        # iterations of this loop (the two loops fill disjoint slots).
        for q in range(0, nb):
            c[-q-1] = yb[q] / spec.poch(n - q, q) * (-1)**q * (xb - xa)**q
            for j in range(0, q):
                c[-q-1] -= (-1)**(j+1) * comb(q, j+1) * c[-q+j]

        return c
  1509. @staticmethod
  1510. def _raise_degree(c, d):
  1511. r"""Raise a degree of a polynomial in the Bernstein basis.
  1512. Given the coefficients of a polynomial degree `k`, return (the
  1513. coefficients of) the equivalent polynomial of degree `k+d`.
  1514. Parameters
  1515. ----------
  1516. c : array_like
  1517. coefficient array, 1-D
  1518. d : integer
  1519. Returns
  1520. -------
  1521. array
  1522. coefficient array, 1-D array of length `c.shape[0] + d`
  1523. Notes
  1524. -----
  1525. This uses the fact that a Bernstein polynomial `b_{a, k}` can be
  1526. identically represented as a linear combination of polynomials of
  1527. a higher degree `k+d`:
  1528. .. math:: b_{a, k} = comb(k, a) \sum_{j=0}^{d} b_{a+j, k+d} \
  1529. comb(d, j) / comb(k+d, a+j)
  1530. """
  1531. if d == 0:
  1532. return c
  1533. k = c.shape[0] - 1
  1534. out = np.zeros((c.shape[0] + d,) + c.shape[1:], dtype=c.dtype)
  1535. for a in range(c.shape[0]):
  1536. f = c[a] * comb(k, a)
  1537. for j in range(d+1):
  1538. out[a+j] += f * comb(d, j) / comb(k+d, a+j)
  1539. return out
  1540. class NdPPoly:
  1541. """
  1542. Piecewise tensor product polynomial
  1543. The value at point ``xp = (x', y', z', ...)`` is evaluated by first
  1544. computing the interval indices `i` such that::
  1545. x[0][i[0]] <= x' < x[0][i[0]+1]
  1546. x[1][i[1]] <= y' < x[1][i[1]+1]
  1547. ...
  1548. and then computing::
  1549. S = sum(c[k0-m0-1,...,kn-mn-1,i[0],...,i[n]]
  1550. * (xp[0] - x[0][i[0]])**m0
  1551. * ...
  1552. * (xp[n] - x[n][i[n]])**mn
  1553. for m0 in range(k[0]+1)
  1554. ...
  1555. for mn in range(k[n]+1))
  1556. where ``k[j]`` is the degree of the polynomial in dimension j. This
  1557. representation is the piecewise multivariate power basis.
  1558. Parameters
  1559. ----------
  1560. c : ndarray, shape (k0, ..., kn, m0, ..., mn, ...)
  1561. Polynomial coefficients, with polynomial order `kj` and
  1562. `mj+1` intervals for each dimension `j`.
  1563. x : ndim-tuple of ndarrays, shapes (mj+1,)
  1564. Polynomial breakpoints for each dimension. These must be
  1565. sorted in increasing order.
  1566. extrapolate : bool, optional
  1567. Whether to extrapolate to out-of-bounds points based on first
  1568. and last intervals, or to return NaNs. Default: True.
  1569. Attributes
  1570. ----------
  1571. x : tuple of ndarrays
  1572. Breakpoints.
  1573. c : ndarray
  1574. Coefficients of the polynomials.
  1575. Methods
  1576. -------
  1577. __call__
  1578. derivative
  1579. antiderivative
  1580. integrate
  1581. integrate_1d
  1582. construct_fast
  1583. See also
  1584. --------
  1585. PPoly : piecewise polynomials in 1D
  1586. Notes
  1587. -----
  1588. High-order polynomials in the power basis can be numerically
  1589. unstable.
  1590. """
  1591. def __init__(self, c, x, extrapolate=None):
  1592. self.x = tuple(np.ascontiguousarray(v, dtype=np.float64) for v in x)
  1593. self.c = np.asarray(c)
  1594. if extrapolate is None:
  1595. extrapolate = True
  1596. self.extrapolate = bool(extrapolate)
  1597. ndim = len(self.x)
  1598. if any(v.ndim != 1 for v in self.x):
  1599. raise ValueError("x arrays must all be 1-dimensional")
  1600. if any(v.size < 2 for v in self.x):
  1601. raise ValueError("x arrays must all contain at least 2 points")
  1602. if c.ndim < 2*ndim:
  1603. raise ValueError("c must have at least 2*len(x) dimensions")
  1604. if any(np.any(v[1:] - v[:-1] < 0) for v in self.x):
  1605. raise ValueError("x-coordinates are not in increasing order")
  1606. if any(a != b.size - 1 for a, b in zip(c.shape[ndim:2*ndim], self.x)):
  1607. raise ValueError("x and c do not agree on the number of intervals")
  1608. dtype = self._get_dtype(self.c.dtype)
  1609. self.c = np.ascontiguousarray(self.c, dtype=dtype)
  1610. @classmethod
  1611. def construct_fast(cls, c, x, extrapolate=None):
  1612. """
  1613. Construct the piecewise polynomial without making checks.
  1614. Takes the same parameters as the constructor. Input arguments
  1615. ``c`` and ``x`` must be arrays of the correct shape and type. The
  1616. ``c`` array can only be of dtypes float and complex, and ``x``
  1617. array must have dtype float.
  1618. """
  1619. self = object.__new__(cls)
  1620. self.c = c
  1621. self.x = x
  1622. if extrapolate is None:
  1623. extrapolate = True
  1624. self.extrapolate = extrapolate
  1625. return self
  1626. def _get_dtype(self, dtype):
  1627. if np.issubdtype(dtype, np.complexfloating) \
  1628. or np.issubdtype(self.c.dtype, np.complexfloating):
  1629. return np.complex128
  1630. else:
  1631. return np.float64
  1632. def _ensure_c_contiguous(self):
  1633. if not self.c.flags.c_contiguous:
  1634. self.c = self.c.copy()
  1635. if not isinstance(self.x, tuple):
  1636. self.x = tuple(self.x)
  1637. def __call__(self, x, nu=None, extrapolate=None):
  1638. """
  1639. Evaluate the piecewise polynomial or its derivative
  1640. Parameters
  1641. ----------
  1642. x : array-like
  1643. Points to evaluate the interpolant at.
  1644. nu : tuple, optional
  1645. Orders of derivatives to evaluate. Each must be non-negative.
  1646. extrapolate : bool, optional
  1647. Whether to extrapolate to out-of-bounds points based on first
  1648. and last intervals, or to return NaNs.
  1649. Returns
  1650. -------
  1651. y : array-like
  1652. Interpolated values. Shape is determined by replacing
  1653. the interpolation axis in the original array with the shape of x.
  1654. Notes
  1655. -----
  1656. Derivatives are evaluated piecewise for each polynomial
  1657. segment, even if the polynomial is not differentiable at the
  1658. breakpoints. The polynomial intervals are considered half-open,
  1659. ``[a, b)``, except for the last interval which is closed
  1660. ``[a, b]``.
  1661. """
  1662. if extrapolate is None:
  1663. extrapolate = self.extrapolate
  1664. else:
  1665. extrapolate = bool(extrapolate)
  1666. ndim = len(self.x)
  1667. x = _ndim_coords_from_arrays(x)
  1668. x_shape = x.shape
  1669. x = np.ascontiguousarray(x.reshape(-1, x.shape[-1]), dtype=np.float64)
  1670. if nu is None:
  1671. nu = np.zeros((ndim,), dtype=np.intc)
  1672. else:
  1673. nu = np.asarray(nu, dtype=np.intc)
  1674. if nu.ndim != 1 or nu.shape[0] != ndim:
  1675. raise ValueError("invalid number of derivative orders nu")
  1676. dim1 = prod(self.c.shape[:ndim])
  1677. dim2 = prod(self.c.shape[ndim:2*ndim])
  1678. dim3 = prod(self.c.shape[2*ndim:])
  1679. ks = np.array(self.c.shape[:ndim], dtype=np.intc)
  1680. out = np.empty((x.shape[0], dim3), dtype=self.c.dtype)
  1681. self._ensure_c_contiguous()
  1682. _ppoly.evaluate_nd(self.c.reshape(dim1, dim2, dim3),
  1683. self.x,
  1684. ks,
  1685. x,
  1686. nu,
  1687. bool(extrapolate),
  1688. out)
  1689. return out.reshape(x_shape[:-1] + self.c.shape[2*ndim:])
  1690. def _derivative_inplace(self, nu, axis):
  1691. """
  1692. Compute 1-D derivative along a selected dimension in-place
  1693. May result to non-contiguous c array.
  1694. """
  1695. if nu < 0:
  1696. return self._antiderivative_inplace(-nu, axis)
  1697. ndim = len(self.x)
  1698. axis = axis % ndim
  1699. # reduce order
  1700. if nu == 0:
  1701. # noop
  1702. return
  1703. else:
  1704. sl = [slice(None)]*ndim
  1705. sl[axis] = slice(None, -nu, None)
  1706. c2 = self.c[tuple(sl)]
  1707. if c2.shape[axis] == 0:
  1708. # derivative of order 0 is zero
  1709. shp = list(c2.shape)
  1710. shp[axis] = 1
  1711. c2 = np.zeros(shp, dtype=c2.dtype)
  1712. # multiply by the correct rising factorials
  1713. factor = spec.poch(np.arange(c2.shape[axis], 0, -1), nu)
  1714. sl = [None]*c2.ndim
  1715. sl[axis] = slice(None)
  1716. c2 *= factor[tuple(sl)]
  1717. self.c = c2
  1718. def _antiderivative_inplace(self, nu, axis):
  1719. """
  1720. Compute 1-D antiderivative along a selected dimension
  1721. May result to non-contiguous c array.
  1722. """
  1723. if nu <= 0:
  1724. return self._derivative_inplace(-nu, axis)
  1725. ndim = len(self.x)
  1726. axis = axis % ndim
  1727. perm = list(range(ndim))
  1728. perm[0], perm[axis] = perm[axis], perm[0]
  1729. perm = perm + list(range(ndim, self.c.ndim))
  1730. c = self.c.transpose(perm)
  1731. c2 = np.zeros((c.shape[0] + nu,) + c.shape[1:],
  1732. dtype=c.dtype)
  1733. c2[:-nu] = c
  1734. # divide by the correct rising factorials
  1735. factor = spec.poch(np.arange(c.shape[0], 0, -1), nu)
  1736. c2[:-nu] /= factor[(slice(None),) + (None,)*(c.ndim-1)]
  1737. # fix continuity of added degrees of freedom
  1738. perm2 = list(range(c2.ndim))
  1739. perm2[1], perm2[ndim+axis] = perm2[ndim+axis], perm2[1]
  1740. c2 = c2.transpose(perm2)
  1741. c2 = c2.copy()
  1742. _ppoly.fix_continuity(c2.reshape(c2.shape[0], c2.shape[1], -1),
  1743. self.x[axis], nu-1)
  1744. c2 = c2.transpose(perm2)
  1745. c2 = c2.transpose(perm)
  1746. # Done
  1747. self.c = c2
  1748. def derivative(self, nu):
  1749. """
  1750. Construct a new piecewise polynomial representing the derivative.
  1751. Parameters
  1752. ----------
  1753. nu : ndim-tuple of int
  1754. Order of derivatives to evaluate for each dimension.
  1755. If negative, the antiderivative is returned.
  1756. Returns
  1757. -------
  1758. pp : NdPPoly
  1759. Piecewise polynomial of orders (k[0] - nu[0], ..., k[n] - nu[n])
  1760. representing the derivative of this polynomial.
  1761. Notes
  1762. -----
  1763. Derivatives are evaluated piecewise for each polynomial
  1764. segment, even if the polynomial is not differentiable at the
  1765. breakpoints. The polynomial intervals in each dimension are
  1766. considered half-open, ``[a, b)``, except for the last interval
  1767. which is closed ``[a, b]``.
  1768. """
  1769. p = self.construct_fast(self.c.copy(), self.x, self.extrapolate)
  1770. for axis, n in enumerate(nu):
  1771. p._derivative_inplace(n, axis)
  1772. p._ensure_c_contiguous()
  1773. return p
  1774. def antiderivative(self, nu):
  1775. """
  1776. Construct a new piecewise polynomial representing the antiderivative.
  1777. Antiderivative is also the indefinite integral of the function,
  1778. and derivative is its inverse operation.
  1779. Parameters
  1780. ----------
  1781. nu : ndim-tuple of int
  1782. Order of derivatives to evaluate for each dimension.
  1783. If negative, the derivative is returned.
  1784. Returns
  1785. -------
  1786. pp : PPoly
  1787. Piecewise polynomial of order k2 = k + n representing
  1788. the antiderivative of this polynomial.
  1789. Notes
  1790. -----
  1791. The antiderivative returned by this function is continuous and
  1792. continuously differentiable to order n-1, up to floating point
  1793. rounding error.
  1794. """
  1795. p = self.construct_fast(self.c.copy(), self.x, self.extrapolate)
  1796. for axis, n in enumerate(nu):
  1797. p._antiderivative_inplace(n, axis)
  1798. p._ensure_c_contiguous()
  1799. return p
  1800. def integrate_1d(self, a, b, axis, extrapolate=None):
  1801. r"""
  1802. Compute NdPPoly representation for one dimensional definite integral
  1803. The result is a piecewise polynomial representing the integral:
  1804. .. math::
  1805. p(y, z, ...) = \int_a^b dx\, p(x, y, z, ...)
  1806. where the dimension integrated over is specified with the
  1807. `axis` parameter.
  1808. Parameters
  1809. ----------
  1810. a, b : float
  1811. Lower and upper bound for integration.
  1812. axis : int
  1813. Dimension over which to compute the 1-D integrals
  1814. extrapolate : bool, optional
  1815. Whether to extrapolate to out-of-bounds points based on first
  1816. and last intervals, or to return NaNs.
  1817. Returns
  1818. -------
  1819. ig : NdPPoly or array-like
  1820. Definite integral of the piecewise polynomial over [a, b].
  1821. If the polynomial was 1D, an array is returned,
  1822. otherwise, an NdPPoly object.
  1823. """
  1824. if extrapolate is None:
  1825. extrapolate = self.extrapolate
  1826. else:
  1827. extrapolate = bool(extrapolate)
  1828. ndim = len(self.x)
  1829. axis = int(axis) % ndim
  1830. # reuse 1-D integration routines
  1831. c = self.c
  1832. swap = list(range(c.ndim))
  1833. swap.insert(0, swap[axis])
  1834. del swap[axis + 1]
  1835. swap.insert(1, swap[ndim + axis])
  1836. del swap[ndim + axis + 1]
  1837. c = c.transpose(swap)
  1838. p = PPoly.construct_fast(c.reshape(c.shape[0], c.shape[1], -1),
  1839. self.x[axis],
  1840. extrapolate=extrapolate)
  1841. out = p.integrate(a, b, extrapolate=extrapolate)
  1842. # Construct result
  1843. if ndim == 1:
  1844. return out.reshape(c.shape[2:])
  1845. else:
  1846. c = out.reshape(c.shape[2:])
  1847. x = self.x[:axis] + self.x[axis+1:]
  1848. return self.construct_fast(c, x, extrapolate=extrapolate)
  1849. def integrate(self, ranges, extrapolate=None):
  1850. """
  1851. Compute a definite integral over a piecewise polynomial.
  1852. Parameters
  1853. ----------
  1854. ranges : ndim-tuple of 2-tuples float
  1855. Sequence of lower and upper bounds for each dimension,
  1856. ``[(a[0], b[0]), ..., (a[ndim-1], b[ndim-1])]``
  1857. extrapolate : bool, optional
  1858. Whether to extrapolate to out-of-bounds points based on first
  1859. and last intervals, or to return NaNs.
  1860. Returns
  1861. -------
  1862. ig : array_like
  1863. Definite integral of the piecewise polynomial over
  1864. [a[0], b[0]] x ... x [a[ndim-1], b[ndim-1]]
  1865. """
  1866. ndim = len(self.x)
  1867. if extrapolate is None:
  1868. extrapolate = self.extrapolate
  1869. else:
  1870. extrapolate = bool(extrapolate)
  1871. if not hasattr(ranges, '__len__') or len(ranges) != ndim:
  1872. raise ValueError("Range not a sequence of correct length")
  1873. self._ensure_c_contiguous()
  1874. # Reuse 1D integration routine
  1875. c = self.c
  1876. for n, (a, b) in enumerate(ranges):
  1877. swap = list(range(c.ndim))
  1878. swap.insert(1, swap[ndim - n])
  1879. del swap[ndim - n + 1]
  1880. c = c.transpose(swap)
  1881. p = PPoly.construct_fast(c, self.x[n], extrapolate=extrapolate)
  1882. out = p.integrate(a, b, extrapolate=extrapolate)
  1883. c = out.reshape(c.shape[2:])
  1884. return c