test_interpolate.py 100 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692
  1. from scipy._lib._array_api import (
  2. xp_assert_equal, xp_assert_close, assert_almost_equal, assert_array_almost_equal,
  3. make_xp_test_case
  4. )
  5. from pytest import raises as assert_raises
  6. import pytest
  7. from numpy import mgrid, pi, sin, poly1d
  8. import numpy as np
  9. from scipy.interpolate import (interp1d, interp2d, lagrange, PPoly, BPoly,
  10. splrep, splev, splantider, splint, sproot, Akima1DInterpolator,
  11. NdPPoly, BSpline, PchipInterpolator)
  12. from scipy.special import poch, gamma
  13. from scipy.interpolate import _ppoly
  14. from scipy._lib._gcutils import assert_deallocated, IS_PYPY
  15. from scipy._lib._testutils import _run_concurrent_barrier
  16. from scipy.integrate import nquad
  17. from scipy.special import binom
  18. skip_xp_backends = pytest.mark.skip_xp_backends
  19. xfail_xp_backends = pytest.mark.xfail_xp_backends
  20. class TestInterp2D:
  21. def test_interp2d(self):
  22. y, x = mgrid[0:2:20j, 0:pi:21j]
  23. z = sin(x+0.5*y)
  24. with assert_raises(NotImplementedError):
  25. interp2d(x, y, z)
  26. class TestInterp1D:
  27. def setup_method(self):
  28. self.x5 = np.arange(5.)
  29. self.x10 = np.arange(10.)
  30. self.y10 = np.arange(10.)
  31. self.x25 = self.x10.reshape((2,5))
  32. self.x2 = np.arange(2.)
  33. self.y2 = np.arange(2.)
  34. self.x1 = np.array([0.])
  35. self.y1 = np.array([0.])
  36. self.y210 = np.arange(20.).reshape((2, 10))
  37. self.y102 = np.arange(20.).reshape((10, 2))
  38. self.y225 = np.arange(20.).reshape((2, 2, 5))
  39. self.y25 = np.arange(10.).reshape((2, 5))
  40. self.y235 = np.arange(30.).reshape((2, 3, 5))
  41. self.y325 = np.arange(30.).reshape((3, 2, 5))
  42. # Edge updated test matrix 1
  43. # array([[ 30, 1, 2, 3, 4, 5, 6, 7, 8, -30],
  44. # [ 30, 11, 12, 13, 14, 15, 16, 17, 18, -30]])
  45. self.y210_edge_updated = np.arange(20.).reshape((2, 10))
  46. self.y210_edge_updated[:, 0] = 30
  47. self.y210_edge_updated[:, -1] = -30
  48. # Edge updated test matrix 2
  49. # array([[ 30, 30],
  50. # [ 2, 3],
  51. # [ 4, 5],
  52. # [ 6, 7],
  53. # [ 8, 9],
  54. # [ 10, 11],
  55. # [ 12, 13],
  56. # [ 14, 15],
  57. # [ 16, 17],
  58. # [-30, -30]])
  59. self.y102_edge_updated = np.arange(20.).reshape((10, 2))
  60. self.y102_edge_updated[0, :] = 30
  61. self.y102_edge_updated[-1, :] = -30
  62. self.fill_value = -100.0
  63. def test_validation(self):
  64. # Make sure that appropriate exceptions are raised when invalid values
  65. # are given to the constructor.
  66. # These should all work.
  67. for kind in ('nearest', 'nearest-up', 'zero', 'linear', 'slinear',
  68. 'quadratic', 'cubic', 'previous', 'next'):
  69. interp1d(self.x10, self.y10, kind=kind)
  70. interp1d(self.x10, self.y10, kind=kind, fill_value="extrapolate")
  71. interp1d(self.x10, self.y10, kind='linear', fill_value=(-1, 1))
  72. interp1d(self.x10, self.y10, kind='linear',
  73. fill_value=np.array([-1]))
  74. interp1d(self.x10, self.y10, kind='linear',
  75. fill_value=(-1,))
  76. interp1d(self.x10, self.y10, kind='linear',
  77. fill_value=-1)
  78. interp1d(self.x10, self.y10, kind='linear',
  79. fill_value=(-1, -1))
  80. interp1d(self.x10, self.y10, kind=0)
  81. interp1d(self.x10, self.y10, kind=1)
  82. interp1d(self.x10, self.y10, kind=2)
  83. interp1d(self.x10, self.y10, kind=3)
  84. interp1d(self.x10, self.y210, kind='linear', axis=-1,
  85. fill_value=(-1, -1))
  86. interp1d(self.x2, self.y210, kind='linear', axis=0,
  87. fill_value=np.ones(10))
  88. interp1d(self.x2, self.y210, kind='linear', axis=0,
  89. fill_value=(np.ones(10), np.ones(10)))
  90. interp1d(self.x2, self.y210, kind='linear', axis=0,
  91. fill_value=(np.ones(10), -1))
  92. # x array must be 1D.
  93. assert_raises(ValueError, interp1d, self.x25, self.y10)
  94. # y array cannot be a scalar.
  95. assert_raises(ValueError, interp1d, self.x10, np.array(0))
  96. # Check for x and y arrays having the same length.
  97. assert_raises(ValueError, interp1d, self.x10, self.y2)
  98. assert_raises(ValueError, interp1d, self.x2, self.y10)
  99. assert_raises(ValueError, interp1d, self.x10, self.y102)
  100. interp1d(self.x10, self.y210)
  101. interp1d(self.x10, self.y102, axis=0)
  102. # Check for x and y having at least 1 element.
  103. assert_raises(ValueError, interp1d, self.x1, self.y10)
  104. assert_raises(ValueError, interp1d, self.x10, self.y1)
  105. # Bad fill values
  106. assert_raises(ValueError, interp1d, self.x10, self.y10, kind='linear',
  107. fill_value=(-1, -1, -1)) # doesn't broadcast
  108. assert_raises(ValueError, interp1d, self.x10, self.y10, kind='linear',
  109. fill_value=[-1, -1, -1]) # doesn't broadcast
  110. assert_raises(ValueError, interp1d, self.x10, self.y10, kind='linear',
  111. fill_value=np.array((-1, -1, -1))) # doesn't broadcast
  112. assert_raises(ValueError, interp1d, self.x10, self.y10, kind='linear',
  113. fill_value=[[-1]]) # doesn't broadcast
  114. assert_raises(ValueError, interp1d, self.x10, self.y10, kind='linear',
  115. fill_value=[-1, -1]) # doesn't broadcast
  116. assert_raises(ValueError, interp1d, self.x10, self.y10, kind='linear',
  117. fill_value=np.array([])) # doesn't broadcast
  118. assert_raises(ValueError, interp1d, self.x10, self.y10, kind='linear',
  119. fill_value=()) # doesn't broadcast
  120. assert_raises(ValueError, interp1d, self.x2, self.y210, kind='linear',
  121. axis=0, fill_value=[-1, -1]) # doesn't broadcast
  122. assert_raises(ValueError, interp1d, self.x2, self.y210, kind='linear',
  123. axis=0, fill_value=(0., [-1, -1])) # above doesn't bc
  124. def test_init(self):
  125. # Check that the attributes are initialized appropriately by the
  126. # constructor.
  127. assert interp1d(self.x10, self.y10).copy
  128. assert not interp1d(self.x10, self.y10, copy=False).copy
  129. assert interp1d(self.x10, self.y10).bounds_error
  130. assert not interp1d(self.x10, self.y10, bounds_error=False).bounds_error
  131. assert np.isnan(interp1d(self.x10, self.y10).fill_value)
  132. assert interp1d(self.x10, self.y10, fill_value=3.0).fill_value == 3.0
  133. assert (interp1d(self.x10, self.y10, fill_value=(1.0, 2.0)).fill_value ==
  134. (1.0, 2.0)
  135. )
  136. assert interp1d(self.x10, self.y10).axis == 0
  137. assert interp1d(self.x10, self.y210).axis == 1
  138. assert interp1d(self.x10, self.y102, axis=0).axis == 0
  139. xp_assert_equal(interp1d(self.x10, self.y10).x, self.x10)
  140. xp_assert_equal(interp1d(self.x10, self.y10).y, self.y10)
  141. xp_assert_equal(interp1d(self.x10, self.y210).y, self.y210)
  142. def test_assume_sorted(self):
  143. # Check for unsorted arrays
  144. interp10 = interp1d(self.x10, self.y10)
  145. interp10_unsorted = interp1d(self.x10[::-1], self.y10[::-1])
  146. assert_array_almost_equal(interp10_unsorted(self.x10), self.y10)
  147. assert_array_almost_equal(interp10_unsorted(1.2), np.array(1.2))
  148. assert_array_almost_equal(interp10_unsorted([2.4, 5.6, 6.0]),
  149. interp10([2.4, 5.6, 6.0]))
  150. # Check assume_sorted keyword (defaults to False)
  151. interp10_assume_kw = interp1d(self.x10[::-1], self.y10[::-1],
  152. assume_sorted=False)
  153. assert_array_almost_equal(interp10_assume_kw(self.x10), self.y10)
  154. interp10_assume_kw2 = interp1d(self.x10[::-1], self.y10[::-1],
  155. assume_sorted=True)
  156. # Should raise an error for unsorted input if assume_sorted=True
  157. assert_raises(ValueError, interp10_assume_kw2, self.x10)
  158. # Check that if y is a 2-D array, things are still consistent
  159. interp10_y_2d = interp1d(self.x10, self.y210)
  160. interp10_y_2d_unsorted = interp1d(self.x10[::-1], self.y210[:, ::-1])
  161. assert_array_almost_equal(interp10_y_2d(self.x10),
  162. interp10_y_2d_unsorted(self.x10))
  163. def test_linear(self):
  164. for kind in ['linear', 'slinear']:
  165. self._check_linear(kind)
  166. def _check_linear(self, kind):
  167. # Check the actual implementation of linear interpolation.
  168. interp10 = interp1d(self.x10, self.y10, kind=kind)
  169. assert_array_almost_equal(interp10(self.x10), self.y10)
  170. assert_array_almost_equal(interp10(1.2), np.array(1.2))
  171. assert_array_almost_equal(interp10([2.4, 5.6, 6.0]),
  172. np.array([2.4, 5.6, 6.0]))
  173. # test fill_value="extrapolate"
  174. extrapolator = interp1d(self.x10, self.y10, kind=kind,
  175. fill_value='extrapolate')
  176. xp_assert_close(extrapolator([-1., 0, 9, 11]),
  177. np.asarray([-1.0, 0, 9, 11]), rtol=1e-14)
  178. opts = dict(kind=kind,
  179. fill_value='extrapolate',
  180. bounds_error=True)
  181. assert_raises(ValueError, interp1d, self.x10, self.y10, **opts)
  182. def test_linear_dtypes(self):
  183. # regression test for gh-5898, where 1D linear interpolation has been
  184. # delegated to numpy.interp for all float dtypes, and the latter was
  185. # not handling e.g. np.float128.
  186. for dtyp in [np.float16,
  187. np.float32,
  188. np.float64,
  189. np.longdouble]:
  190. x = np.arange(8, dtype=dtyp)
  191. y = x
  192. yp = interp1d(x, y, kind='linear')(x)
  193. assert yp.dtype == dtyp
  194. xp_assert_close(yp, y, atol=1e-15)
  195. # regression test for gh-14531, where 1D linear interpolation has been
  196. # has been extended to delegate to numpy.interp for integer dtypes
  197. x = [0, 1, 2]
  198. y = [np.nan, 0, 1]
  199. yp = interp1d(x, y)(x)
  200. xp_assert_close(yp, y, atol=1e-15)
  201. def test_slinear_dtypes(self):
  202. # regression test for gh-7273: 1D slinear interpolation fails with
  203. # float32 inputs
  204. dt_r = [np.float16, np.float32, np.float64]
  205. dt_rc = dt_r + [np.complex64, np.complex128]
  206. spline_kinds = ['slinear', 'zero', 'quadratic', 'cubic']
  207. for dtx in dt_r:
  208. x = np.arange(0, 10, dtype=dtx)
  209. for dty in dt_rc:
  210. y = np.exp(-x/3.0).astype(dty)
  211. for dtn in dt_r:
  212. xnew = x.astype(dtn)
  213. for kind in spline_kinds:
  214. f = interp1d(x, y, kind=kind, bounds_error=False)
  215. xp_assert_close(f(xnew), y, atol=1e-7,
  216. check_dtype=False,
  217. err_msg=f"{dtx}, {dty} {dtn}")
  218. def test_cubic(self):
  219. # Check the actual implementation of spline interpolation.
  220. interp10 = interp1d(self.x10, self.y10, kind='cubic')
  221. assert_array_almost_equal(interp10(self.x10), self.y10)
  222. assert_array_almost_equal(interp10(1.2), np.array(1.2))
  223. assert_array_almost_equal(interp10(1.5), np.array(1.5))
  224. assert_array_almost_equal(interp10([2.4, 5.6, 6.0]),
  225. np.array([2.4, 5.6, 6.0]),)
  226. def test_nearest(self):
  227. # Check the actual implementation of nearest-neighbour interpolation.
  228. # Nearest asserts that half-integer case (1.5) rounds down to 1
  229. interp10 = interp1d(self.x10, self.y10, kind='nearest')
  230. assert_array_almost_equal(interp10(self.x10), self.y10)
  231. assert_array_almost_equal(interp10(1.2), np.array(1.))
  232. assert_array_almost_equal(interp10(1.5), np.array(1.))
  233. assert_array_almost_equal(interp10([2.4, 5.6, 6.0]),
  234. np.array([2., 6., 6.]),)
  235. # test fill_value="extrapolate"
  236. extrapolator = interp1d(self.x10, self.y10, kind='nearest',
  237. fill_value='extrapolate')
  238. xp_assert_close(extrapolator([-1., 0, 9, 11]),
  239. [0.0, 0, 9, 9], rtol=1e-14)
  240. opts = dict(kind='nearest',
  241. fill_value='extrapolate',
  242. bounds_error=True)
  243. assert_raises(ValueError, interp1d, self.x10, self.y10, **opts)
  244. def test_nearest_up(self):
  245. # Check the actual implementation of nearest-neighbour interpolation.
  246. # Nearest-up asserts that half-integer case (1.5) rounds up to 2
  247. interp10 = interp1d(self.x10, self.y10, kind='nearest-up')
  248. assert_array_almost_equal(interp10(self.x10), self.y10)
  249. assert_array_almost_equal(interp10(1.2), np.array(1.))
  250. assert_array_almost_equal(interp10(1.5), np.array(2.))
  251. assert_array_almost_equal(interp10([2.4, 5.6, 6.0]),
  252. np.array([2., 6., 6.]),)
  253. # test fill_value="extrapolate"
  254. extrapolator = interp1d(self.x10, self.y10, kind='nearest-up',
  255. fill_value='extrapolate')
  256. xp_assert_close(extrapolator([-1., 0, 9, 11]),
  257. [0.0, 0, 9, 9], rtol=1e-14)
  258. opts = dict(kind='nearest-up',
  259. fill_value='extrapolate',
  260. bounds_error=True)
  261. assert_raises(ValueError, interp1d, self.x10, self.y10, **opts)
  262. def test_previous(self):
  263. # Check the actual implementation of previous interpolation.
  264. interp10 = interp1d(self.x10, self.y10, kind='previous')
  265. assert_array_almost_equal(interp10(self.x10), self.y10)
  266. assert_array_almost_equal(interp10(1.2), np.array(1.))
  267. assert_array_almost_equal(interp10(1.5), np.array(1.))
  268. assert_array_almost_equal(interp10([2.4, 5.6, 6.0]),
  269. np.array([2., 5., 6.]),)
  270. # test fill_value="extrapolate"
  271. extrapolator = interp1d(self.x10, self.y10, kind='previous',
  272. fill_value='extrapolate')
  273. xp_assert_close(extrapolator([-1., 0, 9, 11]),
  274. [np.nan, 0, 9, 9], rtol=1e-14)
  275. # Tests for gh-9591
  276. interpolator1D = interp1d(self.x10, self.y10, kind="previous",
  277. fill_value='extrapolate')
  278. xp_assert_close(interpolator1D([-1, -2, 5, 8, 12, 25]),
  279. [np.nan, np.nan, 5, 8, 9, 9])
  280. interpolator2D = interp1d(self.x10, self.y210, kind="previous",
  281. fill_value='extrapolate')
  282. xp_assert_close(interpolator2D([-1, -2, 5, 8, 12, 25]),
  283. [[np.nan, np.nan, 5, 8, 9, 9],
  284. [np.nan, np.nan, 15, 18, 19, 19]])
  285. interpolator2DAxis0 = interp1d(self.x10, self.y102, kind="previous",
  286. axis=0, fill_value='extrapolate')
  287. xp_assert_close(interpolator2DAxis0([-2, 5, 12]),
  288. [[np.nan, np.nan],
  289. [10, 11],
  290. [18, 19]])
  291. opts = dict(kind='previous',
  292. fill_value='extrapolate',
  293. bounds_error=True)
  294. assert_raises(ValueError, interp1d, self.x10, self.y10, **opts)
  295. # Tests for gh-16813
  296. interpolator1D = interp1d([0, 1, 2],
  297. [0, 1, -1], kind="previous",
  298. fill_value='extrapolate',
  299. assume_sorted=True)
  300. xp_assert_close(interpolator1D([-2, -1, 0, 1, 2, 3, 5]),
  301. [np.nan, np.nan, 0, 1, -1, -1, -1])
  302. interpolator1D = interp1d([2, 0, 1], # x is not ascending
  303. [-1, 0, 1], kind="previous",
  304. fill_value='extrapolate',
  305. assume_sorted=False)
  306. xp_assert_close(interpolator1D([-2, -1, 0, 1, 2, 3, 5]),
  307. [np.nan, np.nan, 0, 1, -1, -1, -1])
  308. interpolator2D = interp1d(self.x10, self.y210_edge_updated,
  309. kind="previous",
  310. fill_value='extrapolate')
  311. xp_assert_close(interpolator2D([-1, -2, 5, 8, 12, 25]),
  312. [[np.nan, np.nan, 5, 8, -30, -30],
  313. [np.nan, np.nan, 15, 18, -30, -30]])
  314. interpolator2DAxis0 = interp1d(self.x10, self.y102_edge_updated,
  315. kind="previous",
  316. axis=0, fill_value='extrapolate')
  317. xp_assert_close(interpolator2DAxis0([-2, 5, 12]),
  318. [[np.nan, np.nan],
  319. [10, 11],
  320. [-30, -30]])
  321. def test_next(self):
  322. # Check the actual implementation of next interpolation.
  323. interp10 = interp1d(self.x10, self.y10, kind='next')
  324. assert_array_almost_equal(interp10(self.x10), self.y10)
  325. assert_array_almost_equal(interp10(1.2), np.array(2.))
  326. assert_array_almost_equal(interp10(1.5), np.array(2.))
  327. assert_array_almost_equal(interp10([2.4, 5.6, 6.0]),
  328. np.array([3., 6., 6.]),)
  329. # test fill_value="extrapolate"
  330. extrapolator = interp1d(self.x10, self.y10, kind='next',
  331. fill_value='extrapolate')
  332. xp_assert_close(extrapolator([-1., 0, 9, 11]),
  333. [0, 0, 9, np.nan], rtol=1e-14)
  334. # Tests for gh-9591
  335. interpolator1D = interp1d(self.x10, self.y10, kind="next",
  336. fill_value='extrapolate')
  337. xp_assert_close(interpolator1D([-1, -2, 5, 8, 12, 25]),
  338. [0, 0, 5, 8, np.nan, np.nan])
  339. interpolator2D = interp1d(self.x10, self.y210, kind="next",
  340. fill_value='extrapolate')
  341. xp_assert_close(interpolator2D([-1, -2, 5, 8, 12, 25]),
  342. [[0, 0, 5, 8, np.nan, np.nan],
  343. [10, 10, 15, 18, np.nan, np.nan]])
  344. interpolator2DAxis0 = interp1d(self.x10, self.y102, kind="next",
  345. axis=0, fill_value='extrapolate')
  346. xp_assert_close(interpolator2DAxis0([-2, 5, 12]),
  347. [[0, 1],
  348. [10, 11],
  349. [np.nan, np.nan]])
  350. opts = dict(kind='next',
  351. fill_value='extrapolate',
  352. bounds_error=True)
  353. assert_raises(ValueError, interp1d, self.x10, self.y10, **opts)
  354. # Tests for gh-16813
  355. interpolator1D = interp1d([0, 1, 2],
  356. [0, 1, -1], kind="next",
  357. fill_value='extrapolate',
  358. assume_sorted=True)
  359. xp_assert_close(interpolator1D([-2, -1, 0, 1, 2, 3, 5]),
  360. [0, 0, 0, 1, -1, np.nan, np.nan])
  361. interpolator1D = interp1d([2, 0, 1], # x is not ascending
  362. [-1, 0, 1], kind="next",
  363. fill_value='extrapolate',
  364. assume_sorted=False)
  365. xp_assert_close(interpolator1D([-2, -1, 0, 1, 2, 3, 5]),
  366. [0, 0, 0, 1, -1, np.nan, np.nan])
  367. interpolator2D = interp1d(self.x10, self.y210_edge_updated,
  368. kind="next",
  369. fill_value='extrapolate')
  370. xp_assert_close(interpolator2D([-1, -2, 5, 8, 12, 25]),
  371. [[30, 30, 5, 8, np.nan, np.nan],
  372. [30, 30, 15, 18, np.nan, np.nan]])
  373. interpolator2DAxis0 = interp1d(self.x10, self.y102_edge_updated,
  374. kind="next",
  375. axis=0, fill_value='extrapolate')
  376. xp_assert_close(interpolator2DAxis0([-2, 5, 12]),
  377. [[30, 30],
  378. [10, 11],
  379. [np.nan, np.nan]])
  380. def test_zero(self):
  381. # Check the actual implementation of zero-order spline interpolation.
  382. interp10 = interp1d(self.x10, self.y10, kind='zero')
  383. assert_array_almost_equal(interp10(self.x10), self.y10)
  384. assert_array_almost_equal(interp10(1.2), np.array(1.))
  385. assert_array_almost_equal(interp10(1.5), np.array(1.))
  386. assert_array_almost_equal(interp10([2.4, 5.6, 6.0]),
  387. np.array([2., 5., 6.]))
  388. def bounds_check_helper(self, interpolant, test_array, fail_value):
  389. # Asserts that a ValueError is raised and that the error message
  390. # contains the value causing this exception.
  391. assert_raises(ValueError, interpolant, test_array)
  392. try:
  393. interpolant(test_array)
  394. except ValueError as err:
  395. assert (f"{fail_value}" in str(err))
  396. def _bounds_check(self, kind='linear'):
  397. # Test that our handling of out-of-bounds input is correct.
  398. extrap10 = interp1d(self.x10, self.y10, fill_value=self.fill_value,
  399. bounds_error=False, kind=kind)
  400. xp_assert_equal(extrap10(11.2), np.array(self.fill_value))
  401. xp_assert_equal(extrap10(-3.4), np.array(self.fill_value))
  402. xp_assert_equal(extrap10([[[11.2], [-3.4], [12.6], [19.3]]]),
  403. np.array(self.fill_value), check_shape=False)
  404. xp_assert_equal(extrap10._check_bounds(
  405. np.array([-1.0, 0.0, 5.0, 9.0, 11.0])),
  406. np.array([[True, False, False, False, False],
  407. [False, False, False, False, True]]))
  408. raises_bounds_error = interp1d(self.x10, self.y10, bounds_error=True,
  409. kind=kind)
  410. self.bounds_check_helper(raises_bounds_error, -1.0, -1.0)
  411. self.bounds_check_helper(raises_bounds_error, 11.0, 11.0)
  412. self.bounds_check_helper(raises_bounds_error, [0.0, -1.0, 0.0], -1.0)
  413. self.bounds_check_helper(raises_bounds_error, [0.0, 1.0, 21.0], 21.0)
  414. raises_bounds_error([0.0, 5.0, 9.0])
  415. def _bounds_check_int_nan_fill(self, kind='linear'):
  416. x = np.arange(10).astype(int)
  417. y = np.arange(10).astype(int)
  418. c = interp1d(x, y, kind=kind, fill_value=np.nan, bounds_error=False)
  419. yi = c(x - 1)
  420. assert np.isnan(yi[0])
  421. assert_array_almost_equal(yi, np.r_[np.nan, y[:-1]])
  422. def test_bounds(self):
  423. for kind in ('linear', 'cubic', 'nearest', 'previous', 'next',
  424. 'slinear', 'zero', 'quadratic'):
  425. self._bounds_check(kind)
  426. self._bounds_check_int_nan_fill(kind)
  427. def _check_fill_value(self, kind):
  428. interp = interp1d(self.x10, self.y10, kind=kind,
  429. fill_value=(-100, 100), bounds_error=False)
  430. assert_array_almost_equal(interp(10), np.asarray(100.))
  431. assert_array_almost_equal(interp(-10), np.asarray(-100.))
  432. assert_array_almost_equal(interp([-10, 10]), [-100, 100])
  433. # Proper broadcasting:
  434. # interp along axis of length 5
  435. # other dim=(2, 3), (3, 2), (2, 2), or (2,)
  436. # one singleton fill_value (works for all)
  437. for y in (self.y235, self.y325, self.y225, self.y25):
  438. interp = interp1d(self.x5, y, kind=kind, axis=-1,
  439. fill_value=100, bounds_error=False)
  440. assert_array_almost_equal(interp(10), np.asarray(100.))
  441. assert_array_almost_equal(interp(-10), np.asarray(100.))
  442. assert_array_almost_equal(interp([-10, 10]), np.asarray(100.))
  443. # singleton lower, singleton upper
  444. interp = interp1d(self.x5, y, kind=kind, axis=-1,
  445. fill_value=(-100, 100), bounds_error=False)
  446. assert_array_almost_equal(interp(10), np.asarray(100.))
  447. assert_array_almost_equal(interp(-10), np.asarray(-100.))
  448. if y.ndim == 3:
  449. result = [[[-100, 100]] * y.shape[1]] * y.shape[0]
  450. else:
  451. result = [[-100, 100]] * y.shape[0]
  452. assert_array_almost_equal(interp([-10, 10]), result)
  453. # one broadcastable (3,) fill_value
  454. fill_value = [100, 200, 300]
  455. for y in (self.y325, self.y225):
  456. assert_raises(ValueError, interp1d, self.x5, y, kind=kind,
  457. axis=-1, fill_value=fill_value, bounds_error=False)
  458. interp = interp1d(self.x5, self.y235, kind=kind, axis=-1,
  459. fill_value=fill_value, bounds_error=False)
  460. assert_array_almost_equal(interp(10), [[100, 200, 300]] * 2)
  461. assert_array_almost_equal(interp(-10), [[100, 200, 300]] * 2)
  462. assert_array_almost_equal(interp([-10, 10]), [[[100, 100],
  463. [200, 200],
  464. [300, 300]]] * 2)
  465. # one broadcastable (2,) fill_value
  466. fill_value = [100, 200]
  467. assert_raises(ValueError, interp1d, self.x5, self.y235, kind=kind,
  468. axis=-1, fill_value=fill_value, bounds_error=False)
  469. for y in (self.y225, self.y325, self.y25):
  470. interp = interp1d(self.x5, y, kind=kind, axis=-1,
  471. fill_value=fill_value, bounds_error=False)
  472. result = [100, 200]
  473. if y.ndim == 3:
  474. result = [result] * y.shape[0]
  475. assert_array_almost_equal(interp(10), result)
  476. assert_array_almost_equal(interp(-10), result)
  477. result = [[100, 100], [200, 200]]
  478. if y.ndim == 3:
  479. result = [result] * y.shape[0]
  480. assert_array_almost_equal(interp([-10, 10]), result)
  481. # broadcastable (3,) lower, singleton upper
  482. fill_value = (np.array([-100, -200, -300]), 100)
  483. for y in (self.y325, self.y225):
  484. assert_raises(ValueError, interp1d, self.x5, y, kind=kind,
  485. axis=-1, fill_value=fill_value, bounds_error=False)
  486. interp = interp1d(self.x5, self.y235, kind=kind, axis=-1,
  487. fill_value=fill_value, bounds_error=False)
  488. assert_array_almost_equal(interp(10), np.asarray(100.))
  489. assert_array_almost_equal(interp(-10), [[-100, -200, -300]] * 2)
  490. assert_array_almost_equal(interp([-10, 10]), [[[-100, 100],
  491. [-200, 100],
  492. [-300, 100]]] * 2)
  493. # broadcastable (2,) lower, singleton upper
  494. fill_value = (np.array([-100, -200]), 100)
  495. assert_raises(ValueError, interp1d, self.x5, self.y235, kind=kind,
  496. axis=-1, fill_value=fill_value, bounds_error=False)
  497. for y in (self.y225, self.y325, self.y25):
  498. interp = interp1d(self.x5, y, kind=kind, axis=-1,
  499. fill_value=fill_value, bounds_error=False)
  500. assert_array_almost_equal(interp(10), np.asarray(100))
  501. result = [-100, -200]
  502. if y.ndim == 3:
  503. result = [result] * y.shape[0]
  504. assert_array_almost_equal(interp(-10), result)
  505. result = [[-100, 100], [-200, 100]]
  506. if y.ndim == 3:
  507. result = [result] * y.shape[0]
  508. assert_array_almost_equal(interp([-10, 10]), result)
  509. # broadcastable (3,) lower, broadcastable (3,) upper
  510. fill_value = ([-100, -200, -300], [100, 200, 300])
  511. for y in (self.y325, self.y225):
  512. assert_raises(ValueError, interp1d, self.x5, y, kind=kind,
  513. axis=-1, fill_value=fill_value, bounds_error=False)
  514. for ii in range(2): # check ndarray as well as list here
  515. if ii == 1:
  516. fill_value = tuple(np.array(f) for f in fill_value)
  517. interp = interp1d(self.x5, self.y235, kind=kind, axis=-1,
  518. fill_value=fill_value, bounds_error=False)
  519. assert_array_almost_equal(interp(10), [[100, 200, 300]] * 2)
  520. assert_array_almost_equal(interp(-10), [[-100, -200, -300]] * 2)
  521. assert_array_almost_equal(interp([-10, 10]), [[[-100, 100],
  522. [-200, 200],
  523. [-300, 300]]] * 2)
  524. # broadcastable (2,) lower, broadcastable (2,) upper
  525. fill_value = ([-100, -200], [100, 200])
  526. assert_raises(ValueError, interp1d, self.x5, self.y235, kind=kind,
  527. axis=-1, fill_value=fill_value, bounds_error=False)
  528. for y in (self.y325, self.y225, self.y25):
  529. interp = interp1d(self.x5, y, kind=kind, axis=-1,
  530. fill_value=fill_value, bounds_error=False)
  531. result = [100, 200]
  532. if y.ndim == 3:
  533. result = [result] * y.shape[0]
  534. assert_array_almost_equal(interp(10), result)
  535. result = [-100, -200]
  536. if y.ndim == 3:
  537. result = [result] * y.shape[0]
  538. assert_array_almost_equal(interp(-10), result)
  539. result = [[-100, 100], [-200, 200]]
  540. if y.ndim == 3:
  541. result = [result] * y.shape[0]
  542. assert_array_almost_equal(interp([-10, 10]), result)
  543. # one broadcastable (2, 2) array-like
  544. fill_value = [[100, 200], [1000, 2000]]
  545. for y in (self.y235, self.y325, self.y25):
  546. assert_raises(ValueError, interp1d, self.x5, y, kind=kind,
  547. axis=-1, fill_value=fill_value, bounds_error=False)
  548. for ii in range(2):
  549. if ii == 1:
  550. fill_value = np.array(fill_value)
  551. interp = interp1d(self.x5, self.y225, kind=kind, axis=-1,
  552. fill_value=fill_value, bounds_error=False)
  553. assert_array_almost_equal(interp(10), [[100, 200], [1000, 2000]])
  554. assert_array_almost_equal(interp(-10), [[100, 200], [1000, 2000]])
  555. assert_array_almost_equal(interp([-10, 10]), [[[100, 100],
  556. [200, 200]],
  557. [[1000, 1000],
  558. [2000, 2000]]])
  559. # broadcastable (2, 2) lower, broadcastable (2, 2) upper
  560. fill_value = ([[-100, -200], [-1000, -2000]],
  561. [[100, 200], [1000, 2000]])
  562. for y in (self.y235, self.y325, self.y25):
  563. assert_raises(ValueError, interp1d, self.x5, y, kind=kind,
  564. axis=-1, fill_value=fill_value, bounds_error=False)
  565. for ii in range(2):
  566. if ii == 1:
  567. fill_value = (np.array(fill_value[0]), np.array(fill_value[1]))
  568. interp = interp1d(self.x5, self.y225, kind=kind, axis=-1,
  569. fill_value=fill_value, bounds_error=False)
  570. assert_array_almost_equal(interp(10), [[100, 200], [1000, 2000]])
  571. assert_array_almost_equal(interp(-10), [[-100, -200],
  572. [-1000, -2000]])
  573. assert_array_almost_equal(interp([-10, 10]), [[[-100, 100],
  574. [-200, 200]],
  575. [[-1000, 1000],
  576. [-2000, 2000]]])
  577. def test_fill_value(self):
  578. # test that two-element fill value works
  579. for kind in ('linear', 'nearest', 'cubic', 'slinear', 'quadratic',
  580. 'zero', 'previous', 'next'):
  581. self._check_fill_value(kind)
  582. def test_fill_value_writeable(self):
  583. # backwards compat: fill_value is a public writeable attribute
  584. interp = interp1d(self.x10, self.y10, fill_value=123.0)
  585. assert interp.fill_value == 123.0
  586. interp.fill_value = 321.0
  587. assert interp.fill_value == 321.0
  588. def _nd_check_interp(self, kind='linear'):
  589. # Check the behavior when the inputs and outputs are multidimensional.
  590. # Multidimensional input.
  591. interp10 = interp1d(self.x10, self.y10, kind=kind)
  592. assert_array_almost_equal(interp10(np.array([[3., 5.], [2., 7.]])),
  593. np.array([[3., 5.], [2., 7.]]))
  594. # Scalar input -> 0-dim scalar array output
  595. assert isinstance(interp10(1.2), np.ndarray)
  596. assert interp10(1.2).shape == ()
  597. # Multidimensional outputs.
  598. interp210 = interp1d(self.x10, self.y210, kind=kind)
  599. assert_array_almost_equal(interp210(1.), np.array([1., 11.]))
  600. assert_array_almost_equal(interp210(np.array([1., 2.])),
  601. np.array([[1., 2.], [11., 12.]]))
  602. interp102 = interp1d(self.x10, self.y102, axis=0, kind=kind)
  603. assert_array_almost_equal(interp102(1.), np.array([2.0, 3.0]))
  604. assert_array_almost_equal(interp102(np.array([1., 3.])),
  605. np.array([[2., 3.], [6., 7.]]))
  606. # Both at the same time!
  607. x_new = np.array([[3., 5.], [2., 7.]])
  608. assert_array_almost_equal(interp210(x_new),
  609. np.array([[[3., 5.], [2., 7.]],
  610. [[13., 15.], [12., 17.]]]))
  611. assert_array_almost_equal(interp102(x_new),
  612. np.array([[[6., 7.], [10., 11.]],
  613. [[4., 5.], [14., 15.]]]))
  614. def _nd_check_shape(self, kind='linear'):
  615. # Check large N-D output shape
  616. a = [4, 5, 6, 7]
  617. y = np.arange(np.prod(a)).reshape(*a)
  618. for n, s in enumerate(a):
  619. x = np.arange(s)
  620. z = interp1d(x, y, axis=n, kind=kind)
  621. assert_array_almost_equal(z(x), y, err_msg=kind)
  622. x2 = np.arange(2*3*1).reshape((2,3,1)) / 12.
  623. b = list(a)
  624. b[n:n+1] = [2, 3, 1]
  625. assert z(x2).shape == tuple(b), kind
  626. def test_nd(self):
  627. for kind in ('linear', 'cubic', 'slinear', 'quadratic', 'nearest',
  628. 'zero', 'previous', 'next'):
  629. self._nd_check_interp(kind)
  630. self._nd_check_shape(kind)
  631. def _check_complex(self, dtype=np.complex128, kind='linear'):
  632. x = np.array([1, 2.5, 3, 3.1, 4, 6.4, 7.9, 8.0, 9.5, 10])
  633. y = x * x ** (1 + 2j)
  634. y = y.astype(dtype)
  635. # simple test
  636. c = interp1d(x, y, kind=kind)
  637. assert_array_almost_equal(y[:-1], c(x)[:-1])
  638. # check against interpolating real+imag separately
  639. xi = np.linspace(1, 10, 31)
  640. cr = interp1d(x, y.real, kind=kind)
  641. ci = interp1d(x, y.imag, kind=kind)
  642. assert_array_almost_equal(c(xi).real, cr(xi))
  643. assert_array_almost_equal(c(xi).imag, ci(xi))
  644. def test_complex(self):
  645. for kind in ('linear', 'nearest', 'cubic', 'slinear', 'quadratic',
  646. 'zero', 'previous', 'next'):
  647. self._check_complex(np.complex64, kind)
  648. self._check_complex(np.complex128, kind)
  @pytest.mark.skipif(IS_PYPY, reason="Test not meaningful on PyPy")
  def test_circular_refs(self):
      """An interp1d object can be released from memory once it is no longer used."""
      xs = np.linspace(0, 1)
      ys = np.linspace(0, 1)
      # assert_deallocated verifies the object it builds is freed after `del`.
      with assert_deallocated(interp1d, xs, ys) as obj:
          obj([0.1, 0.2])
          del obj
  658. def test_overflow_nearest(self):
  659. # Test that the x range doesn't overflow when given integers as input
  660. for kind in ('nearest', 'previous', 'next'):
  661. x = np.array([0, 50, 127], dtype=np.int8)
  662. ii = interp1d(x, x, kind=kind)
  663. assert_array_almost_equal(ii(x), x)
  664. def test_local_nans(self):
  665. # check that for local interpolation kinds (slinear, zero) a single nan
  666. # only affects its local neighborhood
  667. x = np.arange(10).astype(float)
  668. y = x.copy()
  669. y[6] = np.nan
  670. for kind in ('zero', 'slinear'):
  671. ir = interp1d(x, y, kind=kind)
  672. vals = ir([4.9, 7.0])
  673. assert np.isfinite(vals).all()
  674. def test_spline_nans(self):
  675. # Backwards compat: a single nan makes the whole spline interpolation
  676. # return nans in an array of the correct shape. And it doesn't raise,
  677. # just quiet nans because of backcompat.
  678. x = np.arange(8).astype(float)
  679. y = x.copy()
  680. yn = y.copy()
  681. yn[3] = np.nan
  682. for kind in ['quadratic', 'cubic']:
  683. ir = interp1d(x, y, kind=kind)
  684. irn = interp1d(x, yn, kind=kind)
  685. for xnew in (6, [1, 6], [[1, 6], [3, 5]]):
  686. xnew = np.asarray(xnew)
  687. out, outn = ir(x), irn(x)
  688. assert np.isnan(outn).all()
  689. assert out.shape == outn.shape
  690. def test_all_nans(self):
  691. # regression test for gh-11637: interp1d core dumps with all-nan `x`
  692. x = np.ones(10) * np.nan
  693. y = np.arange(10)
  694. with assert_raises(ValueError):
  695. interp1d(x, y, kind='cubic')
  696. def test_read_only(self):
  697. x = np.arange(0, 10)
  698. y = np.exp(-x / 3.0)
  699. xnew = np.arange(0, 9, 0.1)
  700. # Check both read-only and not read-only:
  701. for xnew_writeable in (True, False):
  702. xnew.flags.writeable = xnew_writeable
  703. x.flags.writeable = False
  704. for kind in ('linear', 'nearest', 'zero', 'slinear', 'quadratic',
  705. 'cubic'):
  706. f = interp1d(x, y, kind=kind)
  707. vals = f(xnew)
  708. assert np.isfinite(vals).all()
  @pytest.mark.parametrize(
      "kind", ("linear", "nearest", "nearest-up", "previous", "next")
  )
  def test_single_value(self, kind):
      """A one-point interpolator (https://github.com/scipy/scipy/issues/4043).

      Below / at / above the single point it returns the lower fill value,
      the data value, and the upper fill value respectively.
      """
      lenient = interp1d([1.5], [6], kind=kind, bounds_error=False,
                         fill_value=(2, 10))
      xp_assert_equal(lenient([1, 1.5, 2]), np.asarray([2.0, 6, 10]))

      # With bounds_error=True an out-of-range query still raises.
      strict = interp1d([1.5], [6], kind=kind, bounds_error=True)
      with assert_raises(ValueError, match="x_new is above"):
          strict(2.0)
  class TestLagrange:
      """Tests for scipy.interpolate.lagrange."""

      def test_lagrange(self):
          # deg + 1 samples of a degree-4 polynomial must recover its coefficients.
          poly = poly1d([5, 2, 1, 4, 3])
          nodes = np.arange(len(poly.coeffs))
          fitted = lagrange(nodes, poly(nodes))
          assert_array_almost_equal(poly.coeffs, fitted.coeffs)
  @make_xp_test_case(Akima1DInterpolator)
  class TestAkima1DInterpolator:
      """Tests for Akima1DInterpolator (methods 'akima' and 'makima')."""

      # Shared data for the evaluation tests: samples at x = 0, 1, ..., 10,
      # query points, and the reference 'akima' values at those points.
      _Y = [0., 2., 1., 3., 2., 6., 5.5, 5.5, 2.7, 5.1, 3.]
      _XI = [0., 0.5, 1., 1.5, 2.5, 3.5, 4.5, 5.1, 6.5, 7.2, 8.6, 9.9, 10.]
      _YI_AKIMA = [0., 1.375, 2., 1.5, 1.953125, 2.484375,
                   4.1363636363636366866103344,
                   5.9803623910336236590978842,
                   5.5067291516462386624652936,
                   5.2031367459745245795943447,
                   4.1796554159017080820603951,
                   3.4110386597938129327189927, 3.]

      def test_eval(self, xp):
          x = xp.arange(0., 11., dtype=xp.float64)
          y = xp.asarray(self._Y, dtype=xp.float64)
          ak = Akima1DInterpolator(x, y)
          xi = xp.asarray(self._XI, dtype=xp.float64)
          yi = xp.asarray(self._YI_AKIMA, dtype=xp.float64)
          xp_assert_close(ak(xi), yi)

      def test_eval_mod(self, xp):
          # Reference values generated with the following MATLAB code:
          # format longG
          # x = 0:10; y = [0. 2. 1. 3. 2. 6. 5.5 5.5 2.7 5.1 3.];
          # xi = [0. 0.5 1. 1.5 2.5 3.5 4.5 5.1 6.5 7.2 8.6 9.9 10.];
          # makima(x, y, xi)
          x = xp.arange(0., 11., dtype=xp.float64)
          y = xp.asarray(self._Y, dtype=xp.float64)
          ak = Akima1DInterpolator(x, y, method="makima")
          xi = xp.asarray(self._XI, dtype=xp.float64)
          yi = xp.asarray([
              0.0, 1.34471153846154, 2.0, 1.44375, 1.94375, 2.51939102564103,
              4.10366931918656, 5.98501550899192, 5.51756330960439, 5.1757231914014,
              4.12326636931311, 3.32931513157895, 3.0], dtype=xp.float64)
          xp_assert_close(ak(xi), yi)

      def test_eval_2d(self, xp):
          x = xp.arange(0., 11., dtype=xp.float64)
          y = xp.asarray(self._Y, dtype=xp.float64)
          y = xp.stack((y, 2. * y), axis=1)
          ak = Akima1DInterpolator(x, y)
          xi = xp.asarray(self._XI, dtype=xp.float64)
          yi = xp.asarray(self._YI_AKIMA, dtype=xp.float64)
          yi = xp.stack((yi, 2. * yi), axis=1)
          xp_assert_close(ak(xi), yi)

      def test_eval_3d(self):
          x = np.arange(0., 11.)
          y_ = np.array(self._Y)
          y = np.empty((11, 2, 2))
          y[:, 0, 0] = y_
          y[:, 1, 0] = 2. * y_
          y[:, 0, 1] = 3. * y_
          y[:, 1, 1] = 4. * y_
          ak = Akima1DInterpolator(x, y)
          xi = np.array(self._XI)
          yi_ = np.array(self._YI_AKIMA)
          yi = np.empty((13, 2, 2))
          yi[:, 0, 0] = yi_
          yi[:, 1, 0] = 2. * yi_
          yi[:, 0, 1] = 3. * yi_
          yi[:, 1, 1] = 4. * yi_
          xp_assert_close(ak(xi), yi)

      def test_linear_interpolant_edge_case_1d(self, xp):
          x = xp.asarray([0.0, 1.0], dtype=xp.float64)
          y = xp.asarray([0.5, 1.0])
          akima = Akima1DInterpolator(x, y, axis=0, extrapolate=None)
          xp_assert_close(akima(0.45), xp.asarray(0.725, dtype=xp.float64))

      def test_linear_interpolant_edge_case_2d(self, xp):
          x = xp.asarray([0., 1.])
          y = xp.stack((x, 2. * x, 3. * x, 4. * x), axis=1)
          ak = Akima1DInterpolator(x, y)
          xi = xp.asarray([0.5, 1.])
          yi = xp.asarray([[0.5, 1., 1.5, 2.],
                           [1., 2., 3., 4.]], dtype=xp.float64
                          )
          xp_assert_close(ak(xi), yi)
          ak = Akima1DInterpolator(x, y.T, axis=1)
          xp_assert_close(ak(xi), yi.T)

      def test_linear_interpolant_edge_case_3d(self):
          x = np.arange(0., 2.)
          y_ = np.array([0., 1.])
          y = np.empty((2, 2, 2))
          y[:, 0, 0] = y_
          y[:, 1, 0] = 2. * y_
          y[:, 0, 1] = 3. * y_
          y[:, 1, 1] = 4. * y_
          ak = Akima1DInterpolator(x, y)
          yi_ = np.array([0.5, 1.])
          yi = np.empty((2, 2, 2))
          yi[:, 0, 0] = yi_
          yi[:, 1, 0] = 2. * yi_
          yi[:, 0, 1] = 3. * yi_
          yi[:, 1, 1] = 4. * yi_
          xi = yi_
          xp_assert_close(ak(xi), yi)
          ak = Akima1DInterpolator(x, y.transpose(1, 0, 2), axis=1)
          xp_assert_close(ak(xi), yi.transpose(1, 0, 2))
          ak = Akima1DInterpolator(x, y.transpose(2, 1, 0), axis=2)
          xp_assert_close(ak(xi), yi.transpose(2, 1, 0))

      def test_degenerate_case_multidimensional(self, xp):
          # This test is for issue #5683.
          x = xp.asarray([0, 1, 2], dtype=xp.float64)
          y = xp.stack((x, x**2)).T
          ak = Akima1DInterpolator(x, y)
          x_eval = xp.asarray([0.5, 1.5], dtype=xp.float64)
          y_eval = ak(x_eval)
          xp_assert_close(y_eval, xp.stack((x_eval, x_eval**2)).T)

      def test_extend(self):
          x = np.arange(0., 11.)
          y = np.array(self._Y)
          ak = Akima1DInterpolator(x, y)
          match = "Extending a 1-D Akima interpolator is not yet implemented"
          with pytest.raises(NotImplementedError, match=match):
              ak.extend(None, None)

      def test_mod_invalid_method(self):
          x = np.arange(0., 11.)
          y = np.array(self._Y)
          match = "`method`=invalid is unsupported."
          with pytest.raises(NotImplementedError, match=match):
              Akima1DInterpolator(x, y, method="invalid")  # type: ignore

      def test_extrapolate_attr(self):
          # Check the `extrapolate` constructor argument and the per-call override.
          x = np.linspace(-5, 5, 11)
          y = x**2
          # 17 points with step 1.25: the first 4 (-10 ... -6.25) and the last 4
          # (6.25 ... 10) lie outside the data range [-5, 5].
          x_ext = np.linspace(-10, 10, 17)
          y_ext = x_ext**2
          # Testing all extrapolate cases.
          ak_true = Akima1DInterpolator(x, y, extrapolate=True)
          ak_false = Akima1DInterpolator(x, y, extrapolate=False)
          ak_none = Akima1DInterpolator(x, y, extrapolate=None)
          # None should default to False; extrapolated points are NaN.
          xp_assert_close(ak_false(x_ext), ak_none(x_ext), atol=1e-15)
          xp_assert_equal(ak_false(x_ext)[:4], np.full(4, np.nan))
          # All four trailing out-of-range points, including the last one at x=10.
          xp_assert_equal(ak_false(x_ext)[-4:], np.full(4, np.nan))
          # Extrapolation on call and attribute should be equal.
          xp_assert_close(ak_false(x_ext, extrapolate=True), ak_true(x_ext), atol=1e-15)
          # Extrapolation reproduces the underlying function (x**2).
          xp_assert_close(y_ext, ak_true(x_ext), atol=1e-15)

      def test_no_overflow(self):
          # check a large jump does not cause a float overflow
          x = np.arange(1, 10)
          y = 1.e6*np.sqrt(np.finfo(float).max)*np.heaviside(x-4, 0.5)
          ak1 = Akima1DInterpolator(x, y, method='makima')
          ak2 = Akima1DInterpolator(x, y, method='akima')
          y_eval1 = ak1(x)
          y_eval2 = ak2(x)
          assert np.isfinite(y_eval1).all()
          assert np.isfinite(y_eval2).all()

      @pytest.mark.parametrize("method", [Akima1DInterpolator, PchipInterpolator])
      def test_complex(self, method):
          # Complex-valued data deprecated. `self` is required here: pytest strips
          # the first parameter of a method, so without it `method` is never bound.
          x = np.arange(0., 11.)
          y = np.array(self._Y)
          y = y - 2j*y
          msg = "real values"
          with pytest.raises(ValueError, match=msg):
              method(x, y)

      def test_concurrency(self):
          # Check that no segfaults appear with concurrent access to Akima1D
          x = np.linspace(-5, 5, 11)
          y = x**2
          x_ext = np.linspace(-10, 10, 17)
          ak = Akima1DInterpolator(x, y, extrapolate=True)

          def worker_fn(_, ak, x_ext):
              ak(x_ext)

          _run_concurrent_barrier(10, worker_fn, ak, x_ext)
  @make_xp_test_case(PPoly, BPoly)
  class TestPPolyCommon:
      """Behaviour shared by the PPoly and BPoly piecewise polynomials."""

      def test_sort_check(self, xp):
          # Breakpoints [0, 1, 0.5] are not monotonic; both classes reject them.
          coeffs = xp.asarray([[1, 4], [2, 5], [3, 6]])
          breaks = xp.asarray([0, 1, 0.5])
          for cls in (PPoly, BPoly):
              assert_raises(ValueError, cls, coeffs, breaks)

      def test_ctor_c(self):
          # A 1-D `c` is rejected: it needs at least 2 axes (coefficients, intervals).
          with assert_raises(ValueError):
              PPoly([1, 2], [0, 1])

      def test_extend(self, xp):
          # Building a polynomial in two pieces with extend() -- appending on the
          # right, or prepending on the left -- matches building it in one go.
          np.random.seed(1234)
          order = 3
          x = np.unique(np.r_[0, 10 * np.random.rand(30), 10])
          c = 2*np.random.rand(order+1, len(x)-1, 2, 3) - 1
          c, x = xp.asarray(c), xp.asarray(x)
          for cls in (PPoly, BPoly):
              right = cls(c[:, :9, ...], x[:10])
              right.extend(c[:, 9:, ...], x[10:])
              left = cls(c[:, 10:, ...], x[10:])
              left.extend(c[:, :10, ...], x[:10])
              whole = cls(c, x)
              for part in (right, left):
                  xp_assert_equal(part.c, whole.c)
                  xp_assert_equal(part.x, whole.x)

      def test_extend_diff_orders(self, xp):
          # A degree-1 piece can be extended with a degree-3 piece.
          np.random.seed(1234)
          x = xp.linspace(0, 1, 6)
          c = xp.asarray(np.random.rand(2, 5))
          x2 = xp.linspace(1, 2, 6)
          c2 = xp.asarray(np.random.rand(4, 5))
          # The random pieces are not continuous at x=1, so the left piece is
          # only compared on [0, 1).
          xi1 = xp.linspace(0, 1, 300, endpoint=False)
          xi2 = xp.linspace(1, 2, 300)
          for cls in (PPoly, BPoly):
              pp1, pp2 = cls(c, x), cls(c2, x2)
              combined = cls(c, x)
              combined.extend(c2, x2[1:])
              xp_assert_close(pp1(xi1), combined(xi1))
              xp_assert_close(pp2(xi2), combined(xi2))

      def test_extend_descending(self, xp):
          np.random.seed(0)
          order = 3
          x = np.sort(np.random.uniform(0, 10, 20))
          c = np.random.rand(order + 1, x.shape[0] - 1, 2, 3)
          c, x = xp.asarray(c), xp.asarray(x)
          for cls in (PPoly, BPoly):
              whole = cls(c, x)
              p1 = cls(c[:, :9, ...], x[:10])
              p1.extend(c[:, 9:, ...], x[10:])
              p2 = cls(c[:, 10:, ...], x[10:])
              p2.extend(c[:, :10, ...], x[:10])
              for piece in (p1, p2):
                  xp_assert_equal(piece.c, whole.c)
                  xp_assert_equal(piece.x, whole.x)

      def test_shape(self):
          np.random.seed(1234)
          c = np.random.rand(8, 12, 5, 6, 7)
          x = np.sort(np.random.rand(13))
          pts = np.random.rand(3, 4)
          # Output shape = query shape followed by the trailing axes of `c`.
          for cls in (PPoly, BPoly):
              assert cls(c, x)(pts).shape == (3, 4, 5, 6, 7)
          # 'scalars'
          for cls in (PPoly, BPoly):
              p = cls(c[..., 0, 0, 0], x)
              assert np.shape(p(0.5)) == ()
              assert np.shape(p(np.array(0.5))) == ()
              assert_raises(ValueError, p, np.array([[0.1, 0.2], [0.4]], dtype=object))

      def test_concurrency(self, xp):
          # Check that no segfaults appear with concurrent access to BPoly, PPoly
          coeffs = np.random.rand(8, 12, 5, 6, 7)
          breaks = np.sort(np.random.rand(13))
          pts = np.random.rand(3, 4)
          coeffs, breaks, pts = map(xp.asarray, (coeffs, breaks, pts))

          def worker_fn(_, interp, xpp):
              interp(xpp)

          for cls in (PPoly, BPoly):
              _run_concurrent_barrier(10, worker_fn, cls(coeffs, breaks), pts)

      def test_complex_coef(self):
          # Complex coefficients evaluate as real and imaginary parts separately,
          # for the value and the first two derivatives.
          np.random.seed(12345)
          x = np.sort(np.random.random(13))
          c = np.random.random((8, 12)) * (1. + 0.3j)
          pts = np.random.random(5)
          for cls in (PPoly, BPoly):
              p = cls(c, x)
              p_re, p_im = cls(c.real, x), cls(c.imag, x)
              for nu in [0, 1, 2]:
                  val = p(pts, nu)
                  xp_assert_close(val.real, p_re(pts, nu))
                  xp_assert_close(val.imag, p_im(pts, nu))

      def test_axis(self, xp):
          np.random.seed(12345)
          c = np.random.rand(3, 4, 5, 6, 7, 8)
          c_s = c.shape
          xpp = np.random.random((1, 2))
          c, xpp = xp.asarray(c), xp.asarray(xpp)
          for axis in (0, 1, 2, 3):
              m = c.shape[axis+1]
              x = xp.asarray(np.sort(np.random.rand(m+1)))
              for cls in (PPoly, BPoly):
                  p = cls(c, x, axis=axis)
                  # Stored `c` has the two axes starting at `axis` moved to the front.
                  assert p.c.shape == c_s[axis:axis+2] + c_s[:axis] + c_s[axis+2:]
                  # Evaluation replaces those two axes with the query shape.
                  assert p(xpp).shape == c_s[:axis] + xpp.shape + c_s[2+axis:]
                  # deriv/antideriv does not drop the axis
                  for p1 in [cls(c, x, axis=axis).derivative(),
                             cls(c, x, axis=axis).derivative(2),
                             cls(c, x, axis=axis).antiderivative(),
                             cls(c, x, axis=axis).antiderivative(2)]:
                      assert p1.axis == p.axis
          # c array needs two axes for the coefficients and intervals, so
          # 0 <= axis < c.ndim-1; raise otherwise.
          # NOTE(review): c.ndim is 6 here, so axis=4 is inside that bound;
          # it presumably raises only because `x` (left over from axis=3, length
          # c.shape[4]+1) does not match c.shape[5] intervals -- confirm.
          for axis in (-1, 4, 5, 6):
              for cls in (BPoly, PPoly):
                  assert_raises(ValueError, cls, c=c, x=x, axis=axis)
  class TestPolySubclassing:
      """Constructors and conversions must return the subclass, not the base class."""

      class P(PPoly):
          pass

      class B(BPoly):
          pass

      def _make_polynomials(self):
          # One power-basis and one Bernstein-basis subclass instance sharing
          # the same random coefficients and breakpoints.
          np.random.seed(1234)
          breaks = np.sort(np.random.random(3))
          coeffs = np.random.random((4, 2))
          return self.P(coeffs, breaks), self.B(coeffs, breaks)

      def test_derivative(self):
          pp, bp = self._make_polynomials()
          for poly in (pp, bp):
              assert type(poly.derivative()) is type(poly)
          assert type(pp.antiderivative()) is type(pp)

      def test_from_spline(self):
          np.random.seed(1234)
          x = np.sort(np.r_[0, np.random.rand(11), 1])
          y = np.random.rand(len(x))
          tck = splrep(x, y, s=0)
          assert type(self.P.from_spline(tck)) is self.P

      def test_conversions(self):
          pp, bp = self._make_polynomials()
          assert type(self.P.from_bernstein_basis(bp)) is self.P
          assert type(self.B.from_power_basis(pp)) is self.B

      def test_from_derivatives(self):
          bp = self.B.from_derivatives([0, 1, 2], [[1], [2], [3]])
          assert type(bp) is self.B
  1068. @make_xp_test_case(PPoly)
  1069. class TestPPoly:
  def test_simple(self, xp):
      """Evaluate a two-piece quadratic at one point in each piece."""
      pp = PPoly(xp.asarray([[1, 4], [2, 5], [3, 6]]), xp.asarray([0, 0.5, 1]))
      # Piece 1: x**2 + 2x + 3; piece 2: 4t**2 + 5t + 6 with t = x - 0.5.
      want_lo = 1*0.3**2 + 2*0.3 + 3
      want_hi = 4*(0.7-0.5)**2 + 5*(0.7-0.5) + 6
      xp_assert_close(pp(0.3), xp.asarray(want_lo, dtype=xp.float64))
      xp_assert_close(pp(0.7), xp.asarray(want_hi, dtype=xp.float64))
  def test_periodic(self, xp):
      """With extrapolate='periodic', points outside [0, 1] wrap around."""
      pp = PPoly(xp.asarray([[1, 4], [2, 5], [3, 6]]), xp.asarray([0, 0.5, 1]),
                 extrapolate='periodic')
      f64 = xp.float64
      # x=1.3 wraps to 0.3 (first piece); x=-0.3 wraps to 0.7 (second piece).
      xp_assert_close(pp(1.3),
                      xp.asarray(1 * 0.3 ** 2 + 2 * 0.3 + 3, dtype=f64))
      xp_assert_close(pp(-0.3),
                      xp.asarray(4 * (0.7 - 0.5) ** 2 + 5 * (0.7 - 0.5) + 6,
                                 dtype=f64))
      # First derivatives at the same wrapped points.
      xp_assert_close(pp(1.3, 1), xp.asarray(2 * 0.3 + 2, dtype=f64))
      xp_assert_close(pp(-0.3, 1), xp.asarray(8 * (0.7 - 0.5) + 5, dtype=f64))
  1090. def test_read_only(self):
  1091. c = np.array([[1, 4], [2, 5], [3, 6]])
  1092. x = np.array([0, 0.5, 1])
  1093. xnew = np.array([0, 0.1, 0.2])
  1094. PPoly(c, x, extrapolate='periodic')
  1095. for writeable in (True, False):
  1096. x.flags.writeable = writeable
  1097. c.flags.writeable = writeable
  1098. f = PPoly(c, x)
  1099. vals = f(xnew)
  1100. assert np.isfinite(vals).all()
  def test_descending(self):
      """A PPoly on descending breakpoints matches the same function on ascending ones."""
      def flipped_binomials(deg):
          # binom(n, k) for 0 <= n, k <= deg, with both axes reversed so row 0
          # corresponds to the highest power (the order used for h_powers below).
          rows = np.arange(deg + 1).reshape(-1, 1)
          cols = np.arange(deg + 1)
          return binom(rows, cols)[::-1, ::-1]

      rng = np.random.RandomState(0)
      power = 3
      B = flipped_binomials(power)
      for m in [10, 20, 30]:
          x = np.sort(rng.uniform(0, 10, m + 1))
          ca = rng.uniform(-2, 2, size=(power + 1, m))
          # Re-express every piece about the other end of its interval: scale
          # by powers of the interval width, mix with the binomial matrix,
          # then unscale.
          h_powers = np.diff(x)[None, :] ** np.arange(power + 1)[::-1, None]
          cd = np.dot(B.T, ca * h_powers) / h_powers

          pa = PPoly(ca, x, extrapolate=True)
          pd = PPoly(cd[:, ::-1], x[::-1], extrapolate=True)

          x_test = rng.uniform(-10, 20, 100)
          xp_assert_close(pa(x_test), pd(x_test), rtol=1e-13)
          xp_assert_close(pa(x_test, 1), pd(x_test, 1), rtol=1e-13)
          xp_assert_close(pa.derivative()(x_test), pd.derivative()(x_test),
                          rtol=1e-13)

          # Antiderivatives won't be equal because fixing continuity is done
          # in the reverse order, but their differences (definite integrals)
          # should agree.
          pa_i = pa.antiderivative()
          pd_i = pd.antiderivative()
          for a, b in rng.uniform(-10, 20, (5, 2)):
              xp_assert_close(pa.integrate(a, b), pd.integrate(a, b), rtol=1e-13)
              xp_assert_close(pa_i(b) - pa_i(a), pd_i(b) - pd_i(a), rtol=1e-13)

          xp_assert_close(pa.roots(), np.sort(pd.roots()), rtol=1e-12)
  1140. def test_multi_shape(self, xp):
  1141. c = np.random.rand(6, 2, 1, 2, 3)
  1142. x = np.array([0, 0.5, 1])
  1143. p = PPoly(c, x)
  1144. assert p.x.shape == x.shape
  1145. assert p.c.shape == c.shape
  1146. assert p(0.3).shape == c.shape[2:]
  1147. assert p(np.random.rand(5, 6)).shape == (5, 6) + c.shape[2:]
  1148. dp = p.derivative()
  1149. assert dp.c.shape == (5, 2, 1, 2, 3)
  1150. ip = p.antiderivative()
  1151. assert ip.c.shape == (7, 2, 1, 2, 3)
  def test_construct_fast(self):
      """construct_fast builds a working PPoly from float coefficients."""
      np.random.seed(1234)
      coeffs = np.array([[1, 4], [2, 5], [3, 6]], dtype=float)
      breaks = np.array([0, 0.5, 1])
      pp = PPoly.construct_fast(coeffs, breaks)
      cases = ((0.3, 1*0.3**2 + 2*0.3 + 3),
               (0.7, 4*(0.7-0.5)**2 + 5*(0.7-0.5) + 6))
      for pt, want in cases:
          xp_assert_close(pp(pt), np.asarray(want))
  def test_vs_alternative_implementations(self):
      """PPoly evaluation agrees with the two reference evaluators in this module."""
      rng = np.random.RandomState(1234)
      c = rng.rand(3, 12, 22)
      x = np.sort(np.r_[0, rng.rand(11), 1])
      pp = PPoly(c, x)
      pts = np.r_[0.3, 0.5, 0.33, 0.6]
      ref_full = _ppoly_eval_1(c, x, pts)
      xp_assert_close(pp(pts), ref_full)
      ref_first = _ppoly_eval_2(c[:, :, 0], x, pts)
      xp_assert_close(pp(pts)[:, 0], ref_first)
  def test_from_spline(self):
      """PPoly.from_spline accepts (t, c, k) tuples and BSpline objects."""
      rng = np.random.RandomState(1234)
      x = np.sort(np.r_[0, rng.rand(11), 1])
      y = rng.rand(len(x))
      tck = splrep(x, y, s=0)
      xi = np.linspace(0, 1, 200)

      xp_assert_close(PPoly.from_spline(tck)(xi), splev(xi, tck))

      bspl = BSpline(*tck)
      xp_assert_close(PPoly.from_spline(bspl)(xi), bspl(xi))

      # BSpline's extrapolate attribute propagates unless overridden
      t, c, k = tck
      for extrap in (None, True, False):
          bspl = BSpline(t, c, k, extrapolate=extrap)
          assert PPoly.from_spline(bspl).extrapolate == bspl.extrapolate
  def test_from_spline_2(self, xp):
      """The BSpline's array namespace carries over to the resulting PPoly."""
      rng = np.random.RandomState(1234)
      x = np.sort(np.r_[0, rng.rand(11), 1])
      y = rng.rand(len(x))
      t, c, k = splrep(x, y, s=0)
      bspl = BSpline(xp.asarray(t), xp.asarray(c), k)
      xi = xp.linspace(0, 1, 11)
      xp_assert_close(PPoly.from_spline(bspl)(xi), bspl(xi))
  def test_derivative_simple(self, xp):
      """Derivative coefficients of 4x**3 + 3x**2 + 2x + 1 on [0, 1]."""
      np.random.seed(1234)
      x = xp.asarray([0, 1])
      pp = PPoly(xp.asarray([[4, 3, 2, 1]]).T, x)
      # p'(x) = 12x**2 + 6x + 2 and p''(x) = 24x + 6.
      dpp = PPoly(xp.asarray([[3*4, 2*3, 2]]).T, x)
      ddpp = PPoly(xp.asarray([[2*3*4, 1*2*3]]).T, x)
      xp_assert_close(pp.derivative().c, dpp.c)
      xp_assert_close(pp.derivative(2).c, ddpp.c)
  def test_derivative_eval(self):
      """pp(x, nu) matches splev's derivative of the same order for nu = 0, 1, 2."""
      rng = np.random.RandomState(1234)
      x = np.sort(np.r_[0, rng.rand(11), 1])
      y = rng.rand(len(x))
      tck = splrep(x, y, s=0)
      pp = PPoly.from_spline(tck)
      xi = np.linspace(0, 1, 200)
      for nu in range(3):
          xp_assert_close(pp(xi, nu), splev(xi, tck, nu))
  def test_derivative(self):
      """pp(x, nu) equals pp.derivative(nu)(x), including orders above k=5."""
      rng = np.random.RandomState(1234)
      x = np.sort(np.r_[0, rng.rand(11), 1])
      y = rng.rand(len(x))
      pp = PPoly.from_spline(splrep(x, y, s=0, k=5))
      xi = np.linspace(0, 1, 200)
      for dx in range(10):
          xp_assert_close(pp(xi, dx), pp.derivative(dx)(xi), err_msg=f"dx={dx}")
  1226. def test_antiderivative_of_constant(self):
  1227. # https://github.com/scipy/scipy/issues/4216
  1228. p = PPoly([[1.]], [0, 1])
  1229. xp_assert_equal(p.antiderivative().c, PPoly([[1], [0]], [0, 1]).c)
  1230. xp_assert_equal(p.antiderivative().x, PPoly([[1], [0]], [0, 1]).x)
  1231. def test_antiderivative_regression_4355(self):
  1232. # https://github.com/scipy/scipy/issues/4355
  1233. p = PPoly([[1., 0.5]], [0, 1, 2])
  1234. q = p.antiderivative()
  1235. xp_assert_equal(q.c, [[1, 0.5], [0, 1]])
  1236. xp_assert_equal(q.x, [0.0, 1, 2])
  1237. xp_assert_close(p.integrate(0, 2), np.asarray(1.5))
  1238. xp_assert_close(np.asarray(q(2) - q(0)),
  1239. np.asarray(1.5))
  1240. def test_antiderivative_simple(self, xp):
  1241. # [ p1(x) = 3*x**2 + 2*x + 1,
  1242. # p2(x) = 1.6875]
  1243. c = xp.asarray([[3, 2, 1], [0, 0, 1.6875]], dtype=xp.float64).T
  1244. # [ pp1(x) = x**3 + x**2 + x,
  1245. # pp2(x) = 1.6875*(x - 0.25) + pp1(0.25)]
  1246. ic = xp.asarray([[1, 1, 1, 0], [0, 0, 1.6875, 0.328125]], dtype=xp.float64).T
  1247. # [ ppp1(x) = (1/4)*x**4 + (1/3)*x**3 + (1/2)*x**2,
  1248. # ppp2(x) = (1.6875/2)*(x - 0.25)**2 + pp1(0.25)*x + ppp1(0.25)]
  1249. iic = xp.asarray([[1/4, 1/3, 1/2, 0, 0],
  1250. [0, 0, 1.6875/2, 0.328125, 0.037434895833333336]],
  1251. dtype=xp.float64
  1252. ).T
  1253. x = xp.asarray([0, 0.25, 1], dtype=xp.float64)
  1254. pp = PPoly(c, x)
  1255. ipp = pp.antiderivative()
  1256. iipp = pp.antiderivative(2)
  1257. iipp2 = ipp.antiderivative()
  1258. xp_assert_close(ipp.x, x)
  1259. xp_assert_close(ipp.c.T, ic.T)
  1260. xp_assert_close(iipp.c.T, iic.T)
  1261. xp_assert_close(iipp2.c.T, iic.T)
  1262. def test_antiderivative_vs_derivative(self):
  1263. rng = np.random.RandomState(1234)
  1264. x = np.linspace(0, 1, 30)**2
  1265. y = rng.rand(len(x))
  1266. spl = splrep(x, y, s=0, k=5)
  1267. pp = PPoly.from_spline(spl)
  1268. for dx in range(0, 10):
  1269. ipp = pp.antiderivative(dx)
  1270. # check that derivative is inverse op
  1271. pp2 = ipp.derivative(dx)
  1272. xp_assert_close(pp.c, pp2.c)
  1273. # check continuity
  1274. for k in range(dx):
  1275. pp2 = ipp.derivative(k)
  1276. r = 1e-13
  1277. endpoint = r*pp2.x[:-1] + (1 - r)*pp2.x[1:]
  1278. xp_assert_close(
  1279. pp2(pp2.x[1:]), pp2(endpoint), rtol=1e-7, err_msg=f"dx={dx} k={k}"
  1280. )
  1281. def test_antiderivative_vs_spline(self):
  1282. rng = np.random.RandomState(1234)
  1283. x = np.sort(np.r_[0, rng.rand(11), 1])
  1284. y = rng.rand(len(x))
  1285. spl = splrep(x, y, s=0, k=5)
  1286. pp = PPoly.from_spline(spl)
  1287. for dx in range(0, 10):
  1288. pp2 = pp.antiderivative(dx)
  1289. spl2 = splantider(spl, dx)
  1290. xi = np.linspace(0, 1, 200)
  1291. xp_assert_close(pp2(xi), splev(xi, spl2),
  1292. rtol=1e-7)
  1293. def test_antiderivative_continuity(self):
  1294. c = np.array([[2, 1, 2, 2], [2, 1, 3, 3]]).T
  1295. x = np.array([0, 0.5, 1])
  1296. p = PPoly(c, x)
  1297. ip = p.antiderivative()
  1298. # check continuity
  1299. xp_assert_close(ip(0.5 - 1e-9), ip(0.5 + 1e-9), rtol=1e-8)
  1300. # check that only lowest order coefficients were changed
  1301. p2 = ip.derivative()
  1302. xp_assert_close(p2.c, p.c)
  1303. def test_integrate(self):
  1304. rng = np.random.RandomState(1234)
  1305. x = np.sort(np.r_[0, rng.rand(11), 1])
  1306. y = rng.rand(len(x))
  1307. spl = splrep(x, y, s=0, k=5)
  1308. pp = PPoly.from_spline(spl)
  1309. a, b = 0.3, 0.9
  1310. ig = pp.integrate(a, b)
  1311. ipp = pp.antiderivative()
  1312. xp_assert_close(ig, ipp(b) - ipp(a), check_0d=False)
  1313. xp_assert_close(ig, splint(a, b, spl), check_0d=False)
  1314. a, b = -0.3, 0.9
  1315. ig = pp.integrate(a, b, extrapolate=True)
  1316. xp_assert_close(ig, ipp(b) - ipp(a), check_0d=False)
  1317. assert np.isnan(pp.integrate(a, b, extrapolate=False)).all()
  1318. def test_integrate_readonly(self):
  1319. x = np.array([1, 2, 4])
  1320. c = np.array([[0., 0.], [-1., -1.], [2., -0.], [1., 2.]])
  1321. for writeable in (True, False):
  1322. x.flags.writeable = writeable
  1323. P = PPoly(c, x)
  1324. vals = P.integrate(1, 4)
  1325. assert np.isfinite(vals).all()
  1326. def test_integrate_periodic(self):
  1327. x = np.array([1, 2, 4])
  1328. c = np.array([[0., 0.], [-1., -1.], [2., -0.], [1., 2.]])
  1329. P = PPoly(c, x, extrapolate='periodic')
  1330. I = P.antiderivative()
  1331. period_int = np.asarray(I(4) - I(1))
  1332. xp_assert_close(P.integrate(1, 4), period_int)
  1333. xp_assert_close(P.integrate(-10, -7), period_int)
  1334. xp_assert_close(P.integrate(-10, -4), np.asarray(2 * period_int))
  1335. xp_assert_close(P.integrate(1.5, 2.5),
  1336. np.asarray(I(2.5) - I(1.5)))
  1337. xp_assert_close(P.integrate(3.5, 5),
  1338. np.asarray(I(2) - I(1) + I(4) - I(3.5)))
  1339. xp_assert_close(P.integrate(3.5 + 12, 5 + 12),
  1340. np.asarray(I(2) - I(1) + I(4) - I(3.5)))
  1341. xp_assert_close(P.integrate(3.5, 5 + 12),
  1342. np.asarray(I(2) - I(1) + I(4) - I(3.5) + 4 * period_int))
  1343. xp_assert_close(P.integrate(0, -1),
  1344. np.asarray(I(2) - I(3)))
  1345. xp_assert_close(P.integrate(-9, -10),
  1346. np.asarray(I(2) - I(3)))
  1347. xp_assert_close(P.integrate(0, -10),
  1348. np.asarray(I(2) - I(3) - 3 * period_int))
  1349. def test_roots(self):
  1350. x = np.linspace(0, 1, 31)**2
  1351. y = np.sin(30*x)
  1352. spl = splrep(x, y, s=0, k=3)
  1353. pp = PPoly.from_spline(spl)
  1354. r = pp.roots()
  1355. r = r[(r >= 0 - 1e-15) & (r <= 1 + 1e-15)]
  1356. xp_assert_close(r, sproot(spl), atol=1e-15)
  1357. def test_roots_idzero(self):
  1358. # Roots for piecewise polynomials with identically zero
  1359. # sections.
  1360. c = np.array([[-1, 0.25], [0, 0], [-1, 0.25]]).T
  1361. x = np.array([0, 0.4, 0.6, 1.0])
  1362. pp = PPoly(c, x)
  1363. xp_assert_equal(pp.roots(),
  1364. [0.25, 0.4, np.nan, 0.6 + 0.25])
  1365. # ditto for p.solve(const) with sections identically equal const
  1366. const = 2.
  1367. c1 = c.copy()
  1368. c1[1, :] += const
  1369. pp1 = PPoly(c1, x)
  1370. xp_assert_equal(pp1.solve(const),
  1371. [0.25, 0.4, np.nan, 0.6 + 0.25])
  1372. def test_roots_all_zero(self):
  1373. # test the code path for the polynomial being identically zero everywhere
  1374. c = [[0], [0]]
  1375. x = [0, 1]
  1376. p = PPoly(c, x)
  1377. xp_assert_equal(p.roots(), [0, np.nan])
  1378. xp_assert_equal(p.solve(0), [0, np.nan])
  1379. xp_assert_equal(p.solve(1), [])
  1380. c = [[0, 0], [0, 0]]
  1381. x = [0, 1, 2]
  1382. p = PPoly(c, x)
  1383. xp_assert_equal(p.roots(), [0, np.nan, 1, np.nan])
  1384. xp_assert_equal(p.solve(0), [0, np.nan, 1, np.nan])
  1385. xp_assert_equal(p.solve(1), [])
  1386. def test_roots_repeated(self):
  1387. # Check roots repeated in multiple sections are reported only
  1388. # once.
  1389. # [(x + 1)**2 - 1, -x**2] ; x == 0 is a repeated root
  1390. c = np.array([[1, 0, -1], [-1, 0, 0]]).T
  1391. x = np.array([-1, 0, 1])
  1392. pp = PPoly(c, x)
  1393. xp_assert_equal(pp.roots(), np.asarray([-2.0, 0.0]))
  1394. xp_assert_equal(pp.roots(extrapolate=False), np.asarray([0.0]))
  1395. def test_roots_discont(self):
  1396. # Check that a discontinuity across zero is reported as root
  1397. c = np.array([[1], [-1]]).T
  1398. x = np.array([0, 0.5, 1])
  1399. pp = PPoly(c, x)
  1400. xp_assert_equal(pp.roots(), np.asarray([0.5]))
  1401. xp_assert_equal(pp.roots(discontinuity=False), np.asarray([]))
  1402. # ditto for a discontinuity across y:
  1403. xp_assert_equal(pp.solve(0.5), np.asarray([0.5]))
  1404. xp_assert_equal(pp.solve(0.5, discontinuity=False), np.asarray([]))
  1405. xp_assert_equal(pp.solve(1.5), np.asarray([]))
  1406. xp_assert_equal(pp.solve(1.5, discontinuity=False), np.asarray([]))
  1407. def test_roots_random(self):
  1408. # Check high-order polynomials with random coefficients
  1409. rng = np.random.RandomState(1234)
  1410. num = 0
  1411. for extrapolate in (True, False):
  1412. for order in range(0, 20):
  1413. x = np.unique(np.r_[0, 10 * rng.rand(30), 10])
  1414. c = 2*rng.rand(order+1, len(x)-1, 2, 3) - 1
  1415. pp = PPoly(c, x)
  1416. for y in [0, rng.random()]:
  1417. r = pp.solve(y, discontinuity=False, extrapolate=extrapolate)
  1418. for i in range(2):
  1419. for j in range(3):
  1420. rr = r[i,j]
  1421. if rr.size > 0:
  1422. # Check that the reported roots indeed are roots
  1423. num += rr.size
  1424. val = pp(rr, extrapolate=extrapolate)[:,i,j]
  1425. cmpval = pp(rr, nu=1,
  1426. extrapolate=extrapolate)[:,i,j]
  1427. msg = f"({extrapolate!r}) r = {repr(rr)}"
  1428. xp_assert_close((val-y) / cmpval, np.asarray(0.0),
  1429. atol=1e-7,
  1430. err_msg=msg, check_shape=False)
  1431. # Check that we checked a number of roots
  1432. assert num > 100, repr(num)
  1433. def test_roots_croots(self):
  1434. # Test the complex root finding algorithm
  1435. rng = np.random.RandomState(1234)
  1436. for k in range(1, 15):
  1437. c = rng.rand(k, 1, 130)
  1438. if k == 3:
  1439. # add a case with zero discriminant
  1440. c[:,0,0] = 1, 2, 1
  1441. for y in [0, rng.random()]:
  1442. w = np.empty(c.shape, dtype=complex)
  1443. _ppoly._croots_poly1(c, w, y)
  1444. if k == 1:
  1445. assert np.isnan(w).all()
  1446. continue
  1447. res = -y
  1448. cres = 0
  1449. for i in range(k):
  1450. res += c[i,None] * w**(k-1-i)
  1451. cres += abs(c[i,None] * w**(k-1-i))
  1452. with np.errstate(invalid='ignore'):
  1453. res /= cres
  1454. res = res.ravel()
  1455. res = res[~np.isnan(res)]
  1456. xp_assert_close(res, np.zeros_like(res), atol=1e-10)
  1457. def test_extrapolate_attr(self):
  1458. # [ 1 - x**2 ]
  1459. c = np.array([[-1, 0, 1]]).T
  1460. x = np.array([0, 1])
  1461. for extrapolate in [True, False, None]:
  1462. pp = PPoly(c, x, extrapolate=extrapolate)
  1463. pp_d = pp.derivative()
  1464. pp_i = pp.antiderivative()
  1465. if extrapolate is False:
  1466. assert np.isnan(pp([-0.1, 1.1])).all()
  1467. assert np.isnan(pp_i([-0.1, 1.1])).all()
  1468. assert np.isnan(pp_d([-0.1, 1.1])).all()
  1469. assert pp.roots() == [1]
  1470. else:
  1471. xp_assert_close(pp([-0.1, 1.1]), [1-0.1**2, 1-1.1**2])
  1472. assert not np.isnan(pp_i([-0.1, 1.1])).any()
  1473. assert not np.isnan(pp_d([-0.1, 1.1])).any()
  1474. xp_assert_close(pp.roots(), np.asarray([1.0, -1.0]))
  1475. @make_xp_test_case(BPoly)
  1476. class TestBPoly:
  1477. def test_simple(self, xp):
  1478. x = xp.asarray([0, 1])
  1479. c = xp.asarray([[3]])
  1480. bp = BPoly(c, x)
  1481. xp_assert_close(bp(0.1), xp.asarray(3., dtype=xp.float64))
  1482. def test_simple2(self, xp):
  1483. x = xp.asarray([0, 1])
  1484. c = xp.asarray([[3], [1]])
  1485. bp = BPoly(c, x) # 3*(1-x) + 1*x
  1486. xp_assert_close(bp(0.1), xp.asarray(3*0.9 + 1.*0.1, dtype=xp.float64))
  1487. def test_simple3(self, xp):
  1488. x = xp.asarray([0, 1])
  1489. c = xp.asarray([[3], [1], [4]])
  1490. bp = BPoly(c, x) # 3 * (1-x)**2 + 2 * x (1-x) + 4 * x**2
  1491. xp_assert_close(
  1492. bp(0.2),
  1493. xp.asarray(3 * 0.8*0.8 + 1 * 2*0.2*0.8 + 4 * 0.2*0.2, dtype=xp.float64)
  1494. )
  1495. def test_simple4(self, xp):
  1496. x = xp.asarray([0, 1])
  1497. c = xp.asarray([[1], [1], [1], [2]])
  1498. bp = BPoly(c, x)
  1499. xp_assert_close(bp(0.3),
  1500. xp.asarray( 0.7**3 +
  1501. 3 * 0.7**2 * 0.3 +
  1502. 3 * 0.7 * 0.3**2 +
  1503. 2 * 0.3**3, dtype=xp.float64)
  1504. )
  1505. def test_simple5(self, xp):
  1506. x = xp.asarray([0, 1])
  1507. c = xp.asarray([[1], [1], [8], [2], [1]])
  1508. bp = BPoly(c, x)
  1509. xp_assert_close(bp(0.3),
  1510. xp.asarray( 0.7**4 +
  1511. 4 * 0.7**3 * 0.3 +
  1512. 8 * 6 * 0.7**2 * 0.3**2 +
  1513. 2 * 4 * 0.7 * 0.3**3 +
  1514. 0.3**4, dtype=xp.float64)
  1515. )
  1516. def test_periodic(self, xp):
  1517. x = xp.asarray([0, 1, 3])
  1518. c = xp.asarray([[3, 0], [0, 0], [0, 2]])
  1519. # [3*(1-x)**2, 2*((x-1)/2)**2]
  1520. bp = BPoly(c, x, extrapolate='periodic')
  1521. xp_assert_close(bp(3.4), xp.asarray(3 * 0.6**2, dtype=xp.float64))
  1522. xp_assert_close(bp(-1.3), xp.asarray(2 * (0.7/2)**2, dtype=xp.float64))
  1523. xp_assert_close(bp(3.4, 1), xp.asarray(-6 * 0.6, dtype=xp.float64))
  1524. xp_assert_close(bp(-1.3, 1), xp.asarray(2 * (0.7/2), dtype=xp.float64))
  1525. def test_descending(self):
  1526. rng = np.random.RandomState(0)
  1527. power = 3
  1528. for m in [10, 20, 30]:
  1529. x = np.sort(rng.uniform(0, 10, m + 1))
  1530. ca = rng.uniform(-0.1, 0.1, size=(power + 1, m))
  1531. # We need only to flip coefficients to get it right!
  1532. cd = ca[::-1].copy()
  1533. pa = BPoly(ca, x, extrapolate=True)
  1534. pd = BPoly(cd[:, ::-1], x[::-1], extrapolate=True)
  1535. x_test = rng.uniform(-10, 20, 100)
  1536. xp_assert_close(pa(x_test), pd(x_test), rtol=1e-13)
  1537. xp_assert_close(pa(x_test, 1), pd(x_test, 1), rtol=1e-13)
  1538. pa_d = pa.derivative()
  1539. pd_d = pd.derivative()
  1540. xp_assert_close(pa_d(x_test), pd_d(x_test), rtol=1e-13)
  1541. # Antiderivatives won't be equal because fixing continuity is
  1542. # done in the reverse order, but surely the differences should be
  1543. # equal.
  1544. pa_i = pa.antiderivative()
  1545. pd_i = pd.antiderivative()
  1546. for a, b in rng.uniform(-10, 20, (5, 2)):
  1547. int_a = pa.integrate(a, b)
  1548. int_d = pd.integrate(a, b)
  1549. xp_assert_close(int_a, int_d, rtol=1e-12)
  1550. xp_assert_close(pa_i(b) - pa_i(a), pd_i(b) - pd_i(a),
  1551. rtol=1e-12)
  1552. def test_multi_shape(self):
  1553. rng = np.random.RandomState(1234)
  1554. c = rng.rand(6, 2, 1, 2, 3)
  1555. x = np.array([0, 0.5, 1])
  1556. p = BPoly(c, x)
  1557. assert p.x.shape == x.shape
  1558. assert p.c.shape == c.shape
  1559. assert p(0.3).shape == c.shape[2:]
  1560. assert p(rng.rand(5, 6)).shape == (5, 6) + c.shape[2:]
  1561. dp = p.derivative()
  1562. assert dp.c.shape == (5, 2, 1, 2, 3)
  1563. def test_interval_length(self, xp):
  1564. x = xp.asarray([0, 2])
  1565. c = xp.asarray([[3], [1], [4]])
  1566. bp = BPoly(c, x)
  1567. xval = 0.1
  1568. s = xval / 2 # s = (x - xa) / (xb - xa)
  1569. xp_assert_close(
  1570. bp(xval),
  1571. xp.asarray(3 * (1-s)*(1-s) + 1 * 2*s*(1-s) + 4 * s*s, dtype=xp.float64)
  1572. )
  1573. def test_two_intervals(self, xp):
  1574. x = xp.asarray([0, 1, 3])
  1575. c = xp.asarray([[3, 0], [0, 0], [0, 2]])
  1576. bp = BPoly(c, x) # [3*(1-x)**2, 2*((x-1)/2)**2]
  1577. xp_assert_close(bp(0.4), xp.asarray(3 * 0.6*0.6, dtype=xp.float64))
  1578. xp_assert_close(bp(1.7), xp.asarray(2 * (0.7/2)**2, dtype=xp.float64))
  1579. def test_extrapolate_attr(self):
  1580. x = [0, 2]
  1581. c = [[3], [1], [4]]
  1582. bp = BPoly(c, x)
  1583. for extrapolate in (True, False, None):
  1584. bp = BPoly(c, x, extrapolate=extrapolate)
  1585. bp_d = bp.derivative()
  1586. if extrapolate is False:
  1587. assert np.isnan(bp([-0.1, 2.1])).all()
  1588. assert np.isnan(bp_d([-0.1, 2.1])).all()
  1589. else:
  1590. assert not np.isnan(bp([-0.1, 2.1])).any()
  1591. assert not np.isnan(bp_d([-0.1, 2.1])).any()
  1592. @make_xp_test_case(BPoly)
  1593. class TestBPolyCalculus:
  1594. def test_derivative(self, xp):
  1595. x = xp.asarray([0, 1, 3])
  1596. c = xp.asarray([[3, 0], [0, 0], [0, 2]])
  1597. bp = BPoly(c, x) # [3*(1-x)**2, 2*((x-1)/2)**2]
  1598. bp_der = bp.derivative()
  1599. xp_assert_close(bp_der(0.4), xp.asarray(-6*(0.6), dtype=xp.float64))
  1600. xp_assert_close(bp_der(1.7), xp.asarray(0.7, dtype=xp.float64))
  1601. # derivatives in-place
  1602. xp_assert_close(xp.stack([bp(0.4, nu) for nu in [1, 2, 3]]),
  1603. xp.asarray([-6*(1-0.4), 6., 0.], dtype=xp.float64)
  1604. )
  1605. xp_assert_close(xp.stack([bp(1.7, nu) for nu in [1, 2, 3]]),
  1606. xp.asarray([0.7, 1., 0], dtype=xp.float64)
  1607. )
  1608. def test_derivative_ppoly(self, xp):
  1609. # make sure it's consistent w/ power basis
  1610. rng = np.random.RandomState(1234)
  1611. m, k = 5, 8 # number of intervals, order
  1612. x = np.sort(rng.random(m))
  1613. c = rng.random((k, m-1))
  1614. c, x = xp.asarray(c), xp.asarray(x)
  1615. bp = BPoly(c, x)
  1616. pp = PPoly.from_bernstein_basis(bp)
  1617. for d in range(k):
  1618. bp = bp.derivative()
  1619. pp = pp.derivative()
  1620. xpp = xp.linspace(x[0], x[-1], 21)
  1621. xp_assert_close(bp(xpp), pp(xpp))
  1622. def test_deriv_inplace(self):
  1623. rng = np.random.RandomState(1234)
  1624. m, k = 5, 8 # number of intervals, order
  1625. x = np.sort(rng.random(m))
  1626. c = rng.random((k, m-1))
  1627. # test both real and complex coefficients
  1628. for cc in [c.copy(), c*(1. + 2.j)]:
  1629. bp = BPoly(cc, x)
  1630. xpp = np.linspace(x[0], x[-1], 21)
  1631. for i in range(k):
  1632. xp_assert_close(bp(xpp, i), bp.derivative(i)(xpp))
  1633. def test_antiderivative_simple(self, xp):
  1634. # f(x) = x for x \in [0, 1),
  1635. # (x-1)/2 for x \in [1, 3]
  1636. #
  1637. # antiderivative is then
  1638. # F(x) = x**2 / 2 for x \in [0, 1),
  1639. # 0.5*x*(x/2 - 1) + A for x \in [1, 3]
  1640. # where A = 3/4 for continuity at x = 1.
  1641. x = xp.asarray([0, 1, 3])
  1642. c = xp.asarray([[0, 0], [1, 1]])
  1643. bp = BPoly(c, x)
  1644. bi = bp.antiderivative()
  1645. xx = xp.linspace(0, 3, 11, dtype=xp.float64)
  1646. xp_assert_close(bi(xx),
  1647. xp.where(xx < 1, xx**2 / 2.,
  1648. 0.5 * xx * (xx/2. - 1) + 3./4),
  1649. atol=1e-12, rtol=1e-12)
  1650. def test_der_antider(self):
  1651. rng = np.random.RandomState(1234)
  1652. x = np.sort(rng.random(11))
  1653. c = rng.random((4, 10, 2, 3))
  1654. bp = BPoly(c, x)
  1655. xx = np.linspace(x[0], x[-1], 100)
  1656. xp_assert_close(bp.antiderivative().derivative()(xx),
  1657. bp(xx), atol=1e-12, rtol=1e-12)
  1658. def test_antider_ppoly(self):
  1659. rng = np.random.RandomState(1234)
  1660. x = np.sort(rng.random(11))
  1661. c = rng.random((4, 10, 2, 3))
  1662. bp = BPoly(c, x)
  1663. pp = PPoly.from_bernstein_basis(bp)
  1664. xx = np.linspace(x[0], x[-1], 10)
  1665. xp_assert_close(bp.antiderivative(2)(xx),
  1666. pp.antiderivative(2)(xx), atol=1e-12, rtol=1e-12)
  1667. def test_antider_continuous(self):
  1668. rng = np.random.RandomState(1234)
  1669. x = np.sort(rng.random(11))
  1670. c = rng.random((4, 10))
  1671. bp = BPoly(c, x).antiderivative()
  1672. xx = bp.x[1:-1]
  1673. xp_assert_close(bp(xx - 1e-14),
  1674. bp(xx + 1e-14), atol=1e-12, rtol=1e-12)
  1675. def test_integrate(self, xp):
  1676. rng = np.random.RandomState(1234)
  1677. x = np.sort(rng.random(11))
  1678. c = rng.random((4, 10))
  1679. x, c = xp.asarray(x), xp.asarray(c)
  1680. bp = BPoly(c, x)
  1681. pp = PPoly.from_bernstein_basis(bp)
  1682. xp_assert_close(bp.integrate(0, 1),
  1683. pp.integrate(0, 1), atol=1e-12, rtol=1e-12, check_0d=False)
  1684. def test_integrate_extrap(self):
  1685. c = [[1]]
  1686. x = [0, 1]
  1687. b = BPoly(c, x)
  1688. # default is extrapolate=True
  1689. xp_assert_close(b.integrate(0, 2), np.asarray(2.),
  1690. atol=1e-14, check_0d=False)
  1691. # .integrate argument overrides self.extrapolate
  1692. b1 = BPoly(c, x, extrapolate=False)
  1693. assert np.isnan(b1.integrate(0, 2))
  1694. xp_assert_close(b1.integrate(0, 2, extrapolate=True),
  1695. np.asarray(2.), atol=1e-14, check_0d=False)
  1696. def test_integrate_periodic(self, xp):
  1697. x = xp.asarray([1, 2, 4])
  1698. c = xp.asarray([[0., 0.], [-1., -1.], [2., -0.], [1., 2.]])
  1699. P = BPoly.from_power_basis(PPoly(c, x), extrapolate='periodic')
  1700. I = P.antiderivative()
  1701. period_int = xp.asarray(I(4) - I(1))
  1702. xp_assert_close(P.integrate(1, 4), period_int) #, check_0d=False)
  1703. xp_assert_close(P.integrate(-10, -7), period_int)
  1704. xp_assert_close(P.integrate(-10, -4), xp.asarray(2 * period_int))
  1705. xp_assert_close(P.integrate(1.5, 2.5), xp.asarray(I(2.5) - I(1.5)))
  1706. xp_assert_close(P.integrate(3.5, 5), xp.asarray(I(2) - I(1) + I(4) - I(3.5)))
  1707. xp_assert_close(P.integrate(3.5 + 12, 5 + 12),
  1708. xp.asarray(I(2) - I(1) + I(4) - I(3.5)))
  1709. xp_assert_close(P.integrate(3.5, 5 + 12),
  1710. xp.asarray(I(2) - I(1) + I(4) - I(3.5) + 4 * period_int))
  1711. xp_assert_close(P.integrate(0, -1), xp.asarray(I(2) - I(3)))
  1712. xp_assert_close(P.integrate(-9, -10), xp.asarray(I(2) - I(3)))
  1713. xp_assert_close(P.integrate(0, -10), xp.asarray(I(2) - I(3) - 3 * period_int))
  1714. def test_antider_neg(self, xp):
  1715. # .derivative(-nu) ==> .andiderivative(nu) and vice versa
  1716. c = xp.asarray([[1]])
  1717. x = xp.asarray([0, 1])
  1718. b = BPoly(c, x)
  1719. xx = xp.linspace(0, 1, 21)
  1720. xp_assert_close(b.derivative(-1)(xx), b.antiderivative()(xx),
  1721. atol=1e-12, rtol=1e-12)
  1722. xp_assert_close(b.derivative(1)(xx), b.antiderivative(-1)(xx),
  1723. atol=1e-12, rtol=1e-12)
  1724. @make_xp_test_case(BPoly, PPoly)
  1725. class TestPolyConversions:
  1726. def test_bp_from_pp(self, xp):
  1727. x = xp.asarray([0, 1, 3])
  1728. c = xp.asarray([[3, 2], [1, 8], [4, 3]])
  1729. pp = PPoly(c, x)
  1730. bp = BPoly.from_power_basis(pp)
  1731. pp1 = PPoly.from_bernstein_basis(bp)
  1732. xv = xp.asarray([0.1, 1.4])
  1733. xp_assert_close(pp(xv), bp(xv))
  1734. xp_assert_close(pp(xv), pp1(xv))
  1735. def test_bp_from_pp_random(self):
  1736. rng = np.random.RandomState(1234)
  1737. m, k = 5, 8 # number of intervals, order
  1738. x = np.sort(rng.random(m))
  1739. c = rng.random((k, m-1))
  1740. pp = PPoly(c, x)
  1741. bp = BPoly.from_power_basis(pp)
  1742. pp1 = PPoly.from_bernstein_basis(bp)
  1743. xv = np.linspace(x[0], x[-1], 21)
  1744. xp_assert_close(pp(xv), bp(xv))
  1745. xp_assert_close(pp(xv), pp1(xv))
  1746. def test_pp_from_bp(self, xp):
  1747. x = xp.asarray([0, 1, 3])
  1748. c = xp.asarray([[3, 3], [1, 1], [4, 2]])
  1749. bp = BPoly(c, x)
  1750. pp = PPoly.from_bernstein_basis(bp)
  1751. bp1 = BPoly.from_power_basis(pp)
  1752. xv = xp.asarray([0.1, 1.4])
  1753. xp_assert_close(bp(xv), pp(xv))
  1754. xp_assert_close(bp(xv), bp1(xv))
  1755. def test_broken_conversions(self):
  1756. # regression test for gh-10597: from_power_basis only accepts PPoly etc.
  1757. x = [0, 1, 3]
  1758. c = [[3, 3], [1, 1], [4, 2]]
  1759. pp = PPoly(c, x)
  1760. with assert_raises(TypeError):
  1761. PPoly.from_bernstein_basis(pp)
  1762. bp = BPoly(c, x)
  1763. with assert_raises(TypeError):
  1764. BPoly.from_power_basis(bp)
  1765. class TestBPolyFromDerivatives:
  1766. def test_make_poly_1(self):
  1767. c1 = BPoly._construct_from_derivatives(0, 1, [2], [3])
  1768. xp_assert_close(c1, [2., 3.])
  1769. def test_make_poly_2(self):
  1770. c1 = BPoly._construct_from_derivatives(0, 1, [1, 0], [1])
  1771. xp_assert_close(c1, [1., 1., 1.])
  1772. # f'(0) = 3
  1773. c2 = BPoly._construct_from_derivatives(0, 1, [2, 3], [1])
  1774. xp_assert_close(c2, [2., 7./2, 1.])
  1775. # f'(1) = 3
  1776. c3 = BPoly._construct_from_derivatives(0, 1, [2], [1, 3])
  1777. xp_assert_close(c3, [2., -0.5, 1.])
  1778. def test_make_poly_3(self):
  1779. # f'(0)=2, f''(0)=3
  1780. c1 = BPoly._construct_from_derivatives(0, 1, [1, 2, 3], [4])
  1781. xp_assert_close(c1, [1., 5./3, 17./6, 4.])
  1782. # f'(1)=2, f''(1)=3
  1783. c2 = BPoly._construct_from_derivatives(0, 1, [1], [4, 2, 3])
  1784. xp_assert_close(c2, [1., 19./6, 10./3, 4.])
  1785. # f'(0)=2, f'(1)=3
  1786. c3 = BPoly._construct_from_derivatives(0, 1, [1, 2], [4, 3])
  1787. xp_assert_close(c3, [1., 5./3, 3., 4.])
  1788. def test_make_poly_12(self):
  1789. rng = np.random.RandomState(12345)
  1790. ya = np.r_[0, rng.random(5)]
  1791. yb = np.r_[0, rng.random(5)]
  1792. c = BPoly._construct_from_derivatives(0, 1, ya, yb)
  1793. pp = BPoly(c[:, None], [0, 1])
  1794. for j in range(6):
  1795. xp_assert_close(pp(0.), ya[j], check_0d=False)
  1796. xp_assert_close(pp(1.), yb[j], check_0d=False)
  1797. pp = pp.derivative()
  1798. def test_raise_degree(self):
  1799. rng = np.random.RandomState(12345)
  1800. x = [0, 1]
  1801. k, d = 8, 5
  1802. c = rng.random((k, 1, 2, 3, 4))
  1803. bp = BPoly(c, x)
  1804. c1 = BPoly._raise_degree(c, d)
  1805. bp1 = BPoly(c1, x)
  1806. xp = np.linspace(0, 1, 11)
  1807. xp_assert_close(bp(xp), bp1(xp))
  1808. def test_xi_yi(self):
  1809. assert_raises(ValueError, BPoly.from_derivatives, [0, 1], [0])
  1810. def test_coords_order(self):
  1811. xi = [0, 0, 1]
  1812. yi = [[0], [0], [0]]
  1813. assert_raises(ValueError, BPoly.from_derivatives, xi, yi)
  1814. def test_zeros(self):
  1815. xi = [0, 1, 2, 3]
  1816. yi = [[0, 0], [0], [0, 0], [0, 0]] # NB: will have to raise the degree
  1817. pp = BPoly.from_derivatives(xi, yi)
  1818. assert pp.c.shape == (4, 3)
  1819. ppd = pp.derivative()
  1820. for xp in [0., 0.1, 1., 1.1, 1.9, 2., 2.5]:
  1821. xp_assert_close(pp(xp), np.asarray(0.0))
  1822. xp_assert_close(ppd(xp), np.asarray(0.0))
  1823. def _make_random_mk(self, m, k):
  1824. # k derivatives at each breakpoint
  1825. rng = np.random.RandomState(1234)
  1826. xi = np.asarray([1. * j**2 for j in range(m+1)])
  1827. yi = [rng.random(k) for j in range(m+1)]
  1828. return xi, yi
  1829. def test_random_12(self):
  1830. m, k = 5, 12
  1831. xi, yi = self._make_random_mk(m, k)
  1832. pp = BPoly.from_derivatives(xi, yi)
  1833. for order in range(k//2):
  1834. xp_assert_close(pp(xi), [yy[order] for yy in yi])
  1835. pp = pp.derivative()
  1836. def test_order_zero(self):
  1837. m, k = 5, 12
  1838. xi, yi = self._make_random_mk(m, k)
  1839. assert_raises(ValueError, BPoly.from_derivatives,
  1840. **dict(xi=xi, yi=yi, orders=0))
  1841. def test_orders_too_high(self):
  1842. m, k = 5, 12
  1843. xi, yi = self._make_random_mk(m, k)
  1844. BPoly.from_derivatives(xi, yi, orders=2*k-1) # this is still ok
  1845. assert_raises(ValueError, BPoly.from_derivatives, # but this is not
  1846. **dict(xi=xi, yi=yi, orders=2*k))
  1847. def test_orders_global(self):
  1848. m, k = 5, 12
  1849. xi, yi = self._make_random_mk(m, k)
  1850. # ok, this is confusing. Local polynomials will be of the order 5
  1851. # which means that up to the 2nd derivatives will be used at each point
  1852. order = 5
  1853. pp = BPoly.from_derivatives(xi, yi, orders=order)
  1854. for j in range(order//2+1):
  1855. xp_assert_close(pp(xi[1:-1] - 1e-12), pp(xi[1:-1] + 1e-12))
  1856. pp = pp.derivative()
  1857. assert not np.allclose(pp(xi[1:-1] - 1e-12), pp(xi[1:-1] + 1e-12))
  1858. # now repeat with `order` being even: on each interval, it uses
  1859. # order//2 'derivatives' @ the right-hand endpoint and
  1860. # order//2+1 @ 'derivatives' the left-hand endpoint
  1861. order = 6
  1862. pp = BPoly.from_derivatives(xi, yi, orders=order)
  1863. for j in range(order//2):
  1864. xp_assert_close(pp(xi[1:-1] - 1e-12), pp(xi[1:-1] + 1e-12))
  1865. pp = pp.derivative()
  1866. assert not np.allclose(pp(xi[1:-1] - 1e-12), pp(xi[1:-1] + 1e-12))
  1867. def test_orders_local(self):
  1868. m, k = 7, 12
  1869. xi, yi = self._make_random_mk(m, k)
  1870. orders = [o + 1 for o in range(m)]
  1871. for i, x in enumerate(xi[1:-1]):
  1872. pp = BPoly.from_derivatives(xi, yi, orders=orders)
  1873. for j in range(orders[i] // 2 + 1):
  1874. xp_assert_close(pp(x - 1e-12), pp(x + 1e-12))
  1875. pp = pp.derivative()
  1876. assert not np.allclose(pp(x - 1e-12), pp(x + 1e-12))
  1877. def test_yi_trailing_dims(self):
  1878. rng = np.random.RandomState(1234)
  1879. m, k = 7, 5
  1880. xi = np.sort(rng.random(m+1))
  1881. yi = rng.random((m+1, k, 6, 7, 8))
  1882. pp = BPoly.from_derivatives(xi, yi)
  1883. assert pp.c.shape == (2*k, m, 6, 7, 8)
  1884. def test_gh_5430(self):
  1885. # At least one of these raises an error unless gh-5430 is
  1886. # fixed. In py2k an int is implemented using a C long, so
  1887. # which one fails depends on your system. In py3k there is only
  1888. # one arbitrary precision integer type, so both should fail.
  1889. orders = np.int32(1)
  1890. p = BPoly.from_derivatives([0, 1], [[0], [0]], orders=orders)
  1891. assert_almost_equal(p(0), np.asarray(0))
  1892. orders = np.int64(1)
  1893. p = BPoly.from_derivatives([0, 1], [[0], [0]], orders=orders)
  1894. assert_almost_equal(p(0), np.asarray(0))
  1895. orders = 1
  1896. # This worked before; make sure it still works
  1897. p = BPoly.from_derivatives([0, 1], [[0], [0]], orders=orders)
  1898. assert_almost_equal(p(0), np.asarray(0))
  1899. orders = 1
  1900. class TestNdPPoly:
  def test_simple_1d(self):
      """1-D NdPPoly evaluation matches the module helper _ppoly_eval_1."""
      rng = np.random.RandomState(1234)
      coeffs = rng.rand(4, 5)
      breaks = np.linspace(0, 1, 5+1)
      pts = rng.rand(200)
      got = NdPPoly(coeffs, (breaks,))((pts,))
      want = _ppoly_eval_1(coeffs[:, :, None], breaks, pts).ravel()
      xp_assert_close(got, want)
  def test_simple_2d(self):
      """2-D evaluation via _ppoly.evaluate_nd and via NdPPoly vs _ppoly2d_eval."""
      rng = np.random.RandomState(1234)
      c = rng.rand(4, 5, 6, 7)
      x = np.linspace(0, 1, 6+1)
      y = np.linspace(0, 1, 7+1)**2
      xi = rng.rand(200)
      yi = rng.rand(200)

      # Low-level call; output buffer pre-filled with NaN.
      # NOTE(review): argument roles read off the call shape only, confirm in
      # _ppoly: coefficients flattened to (4*5, 6*7, 1), per-axis sizes (4, 5),
      # points, per-axis derivative orders (0, 0), then a flag of 1.
      out = np.empty([len(xi), 1], dtype=c.dtype)
      out.fill(np.nan)
      _ppoly.evaluate_nd(c.reshape(4*5, 6*7, 1),
                         (x, y),
                         np.array([4, 5], dtype=np.intc),
                         np.c_[xi, yi],
                         np.array([0, 0], dtype=np.intc),
                         1,
                         out)
      xp_assert_close(out.ravel(), _ppoly2d_eval(c, (x, y), xi, yi))

      poly = NdPPoly(c, (x, y))
      for nu in (None, (0, 0), (0, 1), (1, 0), (2, 3), (9, 2)):
          xp_assert_close(poly(np.c_[xi, yi], nu=nu),
                          _ppoly2d_eval(c, (x, y), xi, yi, nu=nu),
                          err_msg=repr(nu))
  def test_simple_3d(self):
      """3-D NdPPoly evaluation (with derivatives) vs _ppoly3d_eval."""
      rng = np.random.RandomState(1234)
      c = rng.rand(4, 5, 6, 7, 8, 9)
      x = np.linspace(0, 1, 7+1)
      y = np.linspace(0, 1, 8+1)**2
      z = np.linspace(0, 1, 9+1)**3
      breaks = (x, y, z)
      pts = tuple(rng.rand(40) for _ in range(3))

      poly = NdPPoly(c, breaks)
      orders = (None, (0, 0, 0), (0, 1, 0), (1, 0, 0), (2, 3, 0), (6, 0, 2))
      for nu in orders:
          xp_assert_close(poly(pts, nu=nu),
                          _ppoly3d_eval(c, breaks, *pts, nu=nu),
                          err_msg=repr(nu))
  def test_simple_4d(self):
      """4-D NdPPoly evaluation vs _ppoly4d_eval."""
      rng = np.random.RandomState(1234)
      c = rng.rand(4, 5, 6, 7, 8, 9, 10, 11)
      x = np.linspace(0, 1, 8+1)
      y = np.linspace(0, 1, 9+1)**2
      z = np.linspace(0, 1, 10+1)**3
      u = np.linspace(0, 1, 11+1)**4
      breaks = (x, y, z, u)
      pts = tuple(rng.rand(20) for _ in range(4))

      xp_assert_close(NdPPoly(c, breaks)(pts), _ppoly4d_eval(c, breaks, *pts))
  1964. def test_deriv_1d(self):
  1965. rng = np.random.RandomState(1234)
  1966. c = rng.rand(4, 5)
  1967. x = np.linspace(0, 1, 5+1)
  1968. p = NdPPoly(c, (x,))
  1969. # derivative
  1970. dp = p.derivative(nu=[1])
  1971. p1 = PPoly(c, x)
  1972. dp1 = p1.derivative()
  1973. xp_assert_close(dp.c, dp1.c)
  1974. # antiderivative
  1975. dp = p.antiderivative(nu=[2])
  1976. p1 = PPoly(c, x)
  1977. dp1 = p1.antiderivative(2)
  1978. xp_assert_close(dp.c, dp1.c)
  1979. def test_deriv_3d(self):
  1980. rng = np.random.RandomState(1234)
  1981. c = rng.rand(4, 5, 6, 7, 8, 9)
  1982. x = np.linspace(0, 1, 7+1)
  1983. y = np.linspace(0, 1, 8+1)**2
  1984. z = np.linspace(0, 1, 9+1)**3
  1985. p = NdPPoly(c, (x, y, z))
  1986. # differentiate vs x
  1987. p1 = PPoly(c.transpose(0, 3, 1, 2, 4, 5), x)
  1988. dp = p.derivative(nu=[2])
  1989. dp1 = p1.derivative(2)
  1990. xp_assert_close(dp.c,
  1991. dp1.c.transpose(0, 2, 3, 1, 4, 5))
  1992. # antidifferentiate vs y
  1993. p1 = PPoly(c.transpose(1, 4, 0, 2, 3, 5), y)
  1994. dp = p.antiderivative(nu=[0, 1, 0])
  1995. dp1 = p1.antiderivative(1)
  1996. xp_assert_close(dp.c,
  1997. dp1.c.transpose(2, 0, 3, 4, 1, 5))
  1998. # differentiate vs z
  1999. p1 = PPoly(c.transpose(2, 5, 0, 1, 3, 4), z)
  2000. dp = p.derivative(nu=[0, 0, 3])
  2001. dp1 = p1.derivative(3)
  2002. xp_assert_close(dp.c,
  2003. dp1.c.transpose(2, 3, 0, 4, 5, 1))
  2004. def test_deriv_3d_simple(self):
  2005. # Integrate to obtain function x y**2 z**4 / (2! 4!)
  2006. rng = np.random.RandomState(1234)
  2007. c = np.ones((1, 1, 1, 3, 4, 5))
  2008. x = np.linspace(0, 1, 3+1)**1
  2009. y = np.linspace(0, 1, 4+1)**2
  2010. z = np.linspace(0, 1, 5+1)**3
  2011. p = NdPPoly(c, (x, y, z))
  2012. ip = p.antiderivative((1, 0, 4))
  2013. ip = ip.antiderivative((0, 2, 0))
  2014. xi = rng.rand(20)
  2015. yi = rng.rand(20)
  2016. zi = rng.rand(20)
  2017. xp_assert_close(ip((xi, yi, zi)),
  2018. xi * yi**2 * zi**4 / (gamma(3)*gamma(5)))
  def test_integrate_2d(self):
      """NdPPoly.integrate agrees with nquad over several rectangles."""
      rng = np.random.RandomState(1234)
      c = rng.rand(4, 5, 16, 17)
      x = np.linspace(0, 1, 16+1)**1
      y = np.linspace(0, 1, 17+1)**2

      def smooth_axis(coeffs, axes, breaks):
          # Transpose so the target dimension's (order, interval) axes lead,
          # flatten the rest into one trailing axis, let fix_continuity
          # adjust a contiguous copy in place, and restore the transposed
          # shape.
          moved = coeffs.transpose(axes)
          flat = moved.reshape(moved.shape[0], moved.shape[1], -1).copy()
          _ppoly.fix_continuity(flat, breaks, 2)
          return flat.reshape(moved.shape)

      # make continuously differentiable so that nquad() has an
      # easier time
      c = smooth_axis(c, (0, 2, 1, 3), x).transpose(0, 2, 1, 3)
      c = smooth_axis(c, (1, 3, 0, 2), y).transpose(2, 0, 3, 1).copy()

      # Check integration (the last rectangle has a reversed y-range)
      p = NdPPoly(c, (x, y))
      rectangles = ([(0, 1), (0, 1)],
                    [(0, 0.5), (0, 1)],
                    [(0, 1), (0, 0.5)],
                    [(0.3, 0.7), (0.6, 0.2)])
      for ranges in rectangles:
          exact = p.integrate(ranges)
          numeric, _ = nquad(lambda xx, yy: p((xx, yy)), ranges,
                             opts=[dict(epsrel=1e-5, epsabs=1e-5)]*2)
          xp_assert_close(exact, numeric, rtol=1e-5, atol=1e-5,
                          check_0d=False, err_msg=repr(ranges))
  2047. def test_integrate_1d(self):
  2048. rng = np.random.RandomState(1234)
  2049. c = rng.rand(4, 5, 6, 16, 17, 18)
  2050. x = np.linspace(0, 1, 16+1)**1
  2051. y = np.linspace(0, 1, 17+1)**2
  2052. z = np.linspace(0, 1, 18+1)**3
  2053. # Check 1-D integration
  2054. p = NdPPoly(c, (x, y, z))
  2055. u = rng.rand(200)
  2056. v = rng.rand(200)
  2057. a, b = 0.2, 0.7
  2058. px = p.integrate_1d(a, b, axis=0)
  2059. pax = p.antiderivative((1, 0, 0))
  2060. xp_assert_close(px((u, v)), pax((b, u, v)) - pax((a, u, v)))
  2061. py = p.integrate_1d(a, b, axis=1)
  2062. pay = p.antiderivative((0, 1, 0))
  2063. xp_assert_close(py((u, v)), pay((u, b, v)) - pay((u, a, v)))
  2064. pz = p.integrate_1d(a, b, axis=2)
  2065. paz = p.antiderivative((0, 0, 1))
  2066. xp_assert_close(pz((u, v)), paz((u, v, b)) - paz((u, v, a)))
  2067. def test_concurrency(self):
  2068. rng = np.random.default_rng(12345)
  2069. c = rng.uniform(size=(4, 5, 6, 7, 8, 9))
  2070. x = np.linspace(0, 1, 7+1)
  2071. y = np.linspace(0, 1, 8+1)**2
  2072. z = np.linspace(0, 1, 9+1)**3
  2073. p = NdPPoly(c, (x, y, z))
  2074. def worker_fn(_, spl):
  2075. xi = rng.uniform(size=40)
  2076. yi = rng.uniform(size=40)
  2077. zi = rng.uniform(size=40)
  2078. spl((xi, yi, zi))
  2079. _run_concurrent_barrier(10, worker_fn, p)
  def _ppoly_eval_1(c, x, xps):
      """Evaluate piecewise polynomial manually

      Brute-force reference: on interval ``j`` the value is
      ``sum_k c[k, j] * (xp - x[j])**(K - k - 1)`` with ``K = c.shape[0]``
      (highest power first).

      Parameters
      ----------
      c : ndarray, shape (K, m, n)
          Coefficients for ``m`` intervals and ``n`` values per point.
      x : array_like, shape (m + 1,)
          Breakpoints, assumed sorted in ascending order.
      xps : sequence of float
          Evaluation points.

      Returns
      -------
      out : ndarray, shape (len(xps), n)
          Values. Rows for points outside ``[x[0], x[-1]]`` (or NaN points)
          are NaN. The dtype is float64, or complex for complex `c`.
      """
      x = np.asarray(x)
      order = c.shape[0]
      last = len(x) - 2  # index of the final interval
      # Promote to at least float64 so complex coefficients keep their
      # imaginary part and NaN can be stored for out-of-range points.
      out = np.zeros((len(xps), c.shape[2]),
                     dtype=np.result_type(c.dtype, np.float64))
      for i, xp in enumerate(xps):
          # Use the actual breakpoint range rather than a hard-coded [0, 1];
          # the chained comparison is also False for NaN.
          if not (x[0] <= xp <= x[-1]):
              out[i, :] = np.nan
              continue
          # Right-continuous lookup x[j] <= xp < x[j+1], with the last interval
          # closed so xp == x[-1] is valid. A plain searchsorted(x, xp) - 1
          # yields j = -1 at xp == x[0] and the left interval at breakpoints.
          j = min(np.searchsorted(x, xp, side='right') - 1, last)
          assert x[j] <= xp <= x[j+1]
          d = xp - x[j]
          out[i, :] = sum(c[k, j] * d**(order - k - 1) for k in range(order))
      return out
  def _ppoly_eval_2(coeffs, breaks, xnew, fill=np.nan):
      """Evaluate piecewise polynomial manually (another way)

      Parameters
      ----------
      coeffs : array_like, shape (K, m)
          Coefficients, highest power first; column ``j`` is the polynomial
          in powers of ``(x - breaks[j])`` on interval ``j``.
      breaks : array_like, shape (m + 1,)
          Breakpoints, assumed sorted in ascending order.
      xnew : array_like
          Evaluation points, any shape.
      fill : float, optional
          Value used for points outside ``[breaks[0], breaks[-1]]``.

      Returns
      -------
      res : ndarray
          Values with the shape of `xnew`. Integer/bool inputs are promoted
          to a floating dtype; floating/complex input dtypes are kept.
      """
      coeffs = np.asarray(coeffs)
      breaks = np.asarray(breaks)
      a = breaks[0]
      b = breaks[-1]
      K = coeffs.shape[0]
      saveshape = np.shape(xnew)
      xnew = np.ravel(xnew)
      # np.empty_like(xnew) would be an integer array for integer points,
      # into which `fill` (NaN) and the float results cannot be stored.
      res = np.empty(xnew.shape, dtype=np.result_type(xnew, 1.0))
      mask = (xnew >= a) & (xnew <= b)
      res[~mask] = fill
      xx = xnew[mask]
      # Interval index: left-side lookup puts x == breaks[0] at -1, which the
      # clip maps to 0; the upper bound is the last valid interval.
      indxs = np.clip(np.searchsorted(breaks, xx) - 1, 0, len(breaks) - 2)
      diff = xx - breaks[indxs]
      V = np.vander(diff, N=K)  # columns: diff**(K-1), ..., diff**0
      # Row-wise dot product: values[i] = V[i, :] @ coeffs[:, indxs[i]]
      res[mask] = (V * coeffs[:, indxs].T).sum(axis=1)
      return res.reshape(saveshape)
  def _dpow(x, y, n):
      """
      d^n (x**y) / dx^n

      Returns ``y! / (y - n)! * x**(y - n)`` (via the Pochhammer symbol), or
      0 when ``n > y``. Intended for integer powers ``y``.

      Raises
      ------
      ValueError
          If ``n`` is negative.
      """
      if n < 0:
          raise ValueError("invalid derivative order")
      elif n > y:
          return 0
      else:
          # poch(y - n + 1, n) == y! / (y - n)!
          return poch(y - n + 1, n) * x**(y - n)


  def _ppolynd_eval(c, xs, points, nu=None):
      """
      Straightforward evaluation of an N-D piecewise polynomial

      Parameters
      ----------
      c : ndarray, shape (k_1, ..., k_N, m_1, ..., m_N)
          Coefficients, highest power first along each of the first N axes;
          the last N axes select the interval in each dimension.
      xs : sequence of 1-D arrays
          Ascending breakpoints; ``xs[d]`` has ``m_d + 1`` entries.
      points : sequence of N 1-D array_like
          Point coordinates, one sequence per dimension.
      nu : sequence of N ints, optional
          Derivative order per dimension (default: all zero).

      Returns
      -------
      out : ndarray, shape (len(points[0]),), dtype c.dtype
          Values; NaN for points outside the breakpoint box.

      Raises
      ------
      ValueError
          If `nu` does not have one entry per dimension.
      """
      from itertools import product

      ndim = len(points)
      if nu is None:
          nu = (0,) * ndim
      if len(nu) != ndim:
          raise ValueError("nu must have one entry per dimension")
      orders = c.shape[:ndim]
      out = np.empty((len(points[0]),), dtype=c.dtype)
      for jout, pt in enumerate(zip(*points)):
          if not all(b[0] <= p <= b[-1] for b, p in zip(xs, pt)):
              out[jout] = np.nan
              continue
          # Right-continuous interval lookup b[j] <= p < b[j+1] (last interval
          # closed), as NdPPoly does. A plain searchsorted(b, p) - 1 wraps to
          # j = -1 (the last interval) at p == b[0].
          js = []
          ss = []
          for b, p in zip(xs, pt):
              j = min(np.searchsorted(b, p, side='right') - 1, len(b) - 2)
              js.append(j)
              ss.append(p - b[j])
          js = tuple(js)
          val = 0
          # product() varies the last power fastest, i.e. the same order as
          # nested loops over k1, k2, ..., so the summation order is kept.
          for ks in product(*(range(m) for m in orders)):
              term = c[tuple(m - k - 1 for m, k in zip(orders, ks)) + js]
              for s, k, n in zip(ss, ks, nu):
                  term = term * _dpow(s, k, n)
              val += term
          out[jout] = val
      return out


  def _ppoly2d_eval(c, xs, xnew, ynew, nu=None):
      """
      Straightforward evaluation of 2-D piecewise polynomial

      Thin wrapper around ``_ppolynd_eval``; see it for conventions.
      """
      return _ppolynd_eval(c, xs, (xnew, ynew), nu=nu)


  def _ppoly3d_eval(c, xs, xnew, ynew, znew, nu=None):
      """
      Straightforward evaluation of 3-D piecewise polynomial

      Thin wrapper around ``_ppolynd_eval``; see it for conventions.
      """
      return _ppolynd_eval(c, xs, (xnew, ynew, znew), nu=nu)


  def _ppoly4d_eval(c, xs, xnew, ynew, znew, unew, nu=None):
      """
      Straightforward evaluation of 4-D piecewise polynomial

      Thin wrapper around ``_ppolynd_eval``; see it for conventions.
      """
      return _ppolynd_eval(c, xs, (xnew, ynew, znew, unew), nu=nu)