# test_signaltools.py (192 KB)
  1. import sys
  2. import math
  3. import warnings
  4. from concurrent.futures import ThreadPoolExecutor, as_completed
  5. from itertools import product
  6. from math import gcd
  7. import pytest
  8. from pytest import raises as assert_raises
  9. import numpy as np
  10. from numpy.exceptions import ComplexWarning
  11. from scipy import fft as sp_fft
  12. from scipy.ndimage import correlate1d
  13. from scipy.optimize import fmin, linear_sum_assignment
  14. from scipy import signal
  15. from scipy.signal import (
  16. correlate, correlate2d, correlation_lags, convolve, convolve2d,
  17. fftconvolve, oaconvolve, choose_conv_method, envelope,
  18. hilbert, hilbert2, lfilter, lfilter_zi, filtfilt, butter, zpk2tf, zpk2sos,
  19. invres, invresz, vectorstrength, lfiltic, tf2sos, sosfilt, sosfiltfilt,
  20. sosfilt_zi, tf2zpk, BadCoefficients, detrend, unique_roots, residue,
  21. residuez)
  22. from scipy.signal.windows import hann
  23. from scipy.signal._signaltools import _filtfilt_gust, _compute_factors, _group_poles
  24. from scipy.signal._upfirdn import _upfirdn_modes
  25. from scipy._lib import _testutils
  26. from scipy._lib._array_api import (
  27. xp_assert_close, xp_assert_equal, is_numpy, is_torch, is_jax, is_cupy,
  28. assert_array_almost_equal, assert_almost_equal,
  29. xp_copy, xp_size, xp_default_dtype, array_namespace, make_xp_test_case,
  30. make_xp_pytest_param, SCIPY_DEVICE, _xp_copy_to_numpy
  31. )
# Short aliases for the pytest marks used as decorators on the tests below.
skip_xp_backends = pytest.mark.skip_xp_backends
xfail_xp_backends = pytest.mark.xfail_xp_backends
# NOTE(review): presumably consumed by the array-API test machinery to find the
# modules under test -- confirm against the project's conftest.
lazy_xp_modules = [signal]
  35. @make_xp_test_case(convolve)
  36. class TestConvolve:
  37. @skip_xp_backends("jax.numpy",
  38. reason="jax returns floats; scipy returns ints; cf gh-6076")
  39. def test_basic(self, xp):
  40. a = xp.asarray([3, 4, 5, 6, 5, 4])
  41. b = xp.asarray([1, 2, 3])
  42. c = convolve(a, b)
  43. xp_assert_equal(c, xp.asarray([3, 10, 22, 28, 32, 32, 23, 12]))
  44. @skip_xp_backends("jax.numpy",
  45. reason="jax returns floats; scipy returns ints; cf gh-6076")
  46. def test_same(self, xp):
  47. a = xp.asarray([3, 4, 5])
  48. b = xp.asarray([1, 2, 3, 4])
  49. c = convolve(a, b, mode="same")
  50. xp_assert_equal(c, xp.asarray([10, 22, 34]))
  51. @skip_xp_backends("jax.numpy",
  52. reason="jax returns floats; scipy returns ints; cf gh-6076")
  53. def test_same_eq(self, xp):
  54. a = xp.asarray([3, 4, 5])
  55. b = xp.asarray([1, 2, 3])
  56. c = convolve(a, b, mode="same")
  57. xp_assert_equal(c, xp.asarray([10, 22, 22]))
  58. def test_complex(self, xp):
  59. x = xp.asarray([1 + 1j, 2 + 1j, 3 + 1j])
  60. y = xp.asarray([1 + 1j, 2 + 1j])
  61. z = convolve(x, y)
  62. xp_assert_equal(z, xp.asarray([2j, 2 + 6j, 5 + 8j, 5 + 5j]))
  63. @xfail_xp_backends("jax.numpy", reason="wrong output dtype")
  64. def test_zero_rank(self, xp):
  65. a = xp.asarray(1289)
  66. b = xp.asarray(4567)
  67. c = convolve(a, b)
  68. xp_assert_equal(c, a * b)
  69. @skip_xp_backends(np_only=True, reason="pure python")
  70. def test_zero_rank_python_scalars(self, xp):
  71. a = 1289
  72. b = 4567
  73. c = convolve(a, b)
  74. assert c == a * b
  75. @xfail_xp_backends("jax.numpy", reason="disagreement between methods")
  76. def test_broadcastable(self, xp):
  77. a = xp.reshape(xp.arange(27), (3, 3, 3))
  78. b = xp.arange(3)
  79. for i in range(3):
  80. b_shape = [1]*3
  81. b_shape[i] = 3
  82. x = convolve(a, xp.reshape(b, tuple(b_shape)), method='direct')
  83. y = convolve(a, xp.reshape(b, tuple(b_shape)), method='fft')
  84. xp_assert_close(x, y, atol=1e-14)
  85. @xfail_xp_backends("jax.numpy", reason="wrong output dtype")
  86. def test_single_element(self, xp):
  87. a = xp.asarray([4967])
  88. b = xp.asarray([3920])
  89. c = convolve(a, b)
  90. xp_assert_equal(c, a * b)
  91. @skip_xp_backends("jax.numpy",)
  92. @skip_xp_backends("cupy")
  93. def test_2d_arrays(self, xp):
  94. a = xp.asarray([[1, 2, 3], [3, 4, 5]])
  95. b = xp.asarray([[2, 3, 4], [4, 5, 6]])
  96. c = convolve(a, b)
  97. d = xp.asarray([[2, 7, 16, 17, 12],
  98. [10, 30, 62, 58, 38],
  99. [12, 31, 58, 49, 30]])
  100. xp_assert_equal(c, d)
  101. @skip_xp_backends("torch")
  102. @skip_xp_backends("cupy")
  103. def test_input_swapping(self, xp):
  104. small = xp.reshape(xp.arange(8), (2, 2, 2))
  105. big = 1j * xp.reshape(xp.arange(27, dtype=xp.complex128), (3, 3, 3))
  106. big += xp.reshape(xp.arange(27, dtype=xp.complex128)[::-1], (3, 3, 3))
  107. out_array = xp.asarray(
  108. [[[0 + 0j, 26 + 0j, 25 + 1j, 24 + 2j],
  109. [52 + 0j, 151 + 5j, 145 + 11j, 93 + 11j],
  110. [46 + 6j, 133 + 23j, 127 + 29j, 81 + 23j],
  111. [40 + 12j, 98 + 32j, 93 + 37j, 54 + 24j]],
  112. [[104 + 0j, 247 + 13j, 237 + 23j, 135 + 21j],
  113. [282 + 30j, 632 + 96j, 604 + 124j, 330 + 86j],
  114. [246 + 66j, 548 + 180j, 520 + 208j, 282 + 134j],
  115. [142 + 66j, 307 + 161j, 289 + 179j, 153 + 107j]],
  116. [[68 + 36j, 157 + 103j, 147 + 113j, 81 + 75j],
  117. [174 + 138j, 380 + 348j, 352 + 376j, 186 + 230j],
  118. [138 + 174j, 296 + 432j, 268 + 460j, 138 + 278j],
  119. [70 + 138j, 145 + 323j, 127 + 341j, 63 + 197j]],
  120. [[32 + 72j, 68 + 166j, 59 + 175j, 30 + 100j],
  121. [68 + 192j, 139 + 433j, 117 + 455j, 57 + 255j],
  122. [38 + 222j, 73 + 499j, 51 + 521j, 21 + 291j],
  123. [12 + 144j, 20 + 318j, 7 + 331j, 0 + 182j]]])
  124. xp_assert_equal(convolve(small, big, 'full'), out_array)
  125. xp_assert_equal(convolve(big, small, 'full'), out_array)
  126. xp_assert_equal(convolve(small, big, 'same'),
  127. out_array[1:3, 1:3, 1:3])
  128. xp_assert_equal(convolve(big, small, 'same'),
  129. out_array[0:3, 0:3, 0:3])
  130. xp_assert_equal(convolve(small, big, 'valid'),
  131. out_array[1:3, 1:3, 1:3])
  132. xp_assert_equal(convolve(big, small, 'valid'),
  133. out_array[1:3, 1:3, 1:3])
  134. def test_invalid_params(self, xp):
  135. a = xp.asarray([3, 4, 5])
  136. b = xp.asarray([1, 2, 3])
  137. assert_raises(ValueError, convolve, a, b, mode='spam')
  138. assert_raises(ValueError, convolve, a, b, mode='eggs', method='fft')
  139. assert_raises(ValueError, convolve, a, b, mode='ham', method='direct')
  140. assert_raises(ValueError, convolve, a, b, mode='full', method='bacon')
  141. assert_raises(ValueError, convolve, a, b, mode='same', method='bacon')
  142. @skip_xp_backends("jax.numpy", reason="dtypes do not match")
  143. def test_valid_mode2(self, xp):
  144. # See gh-5897
  145. a = xp.asarray([1, 2, 3, 6, 5, 3])
  146. b = xp.asarray([2, 3, 4, 5, 3, 4, 2, 2, 1])
  147. expected = xp.asarray([70, 78, 73, 65])
  148. out = convolve(a, b, 'valid')
  149. xp_assert_equal(out, expected)
  150. out = convolve(b, a, 'valid')
  151. xp_assert_equal(out, expected)
  152. a = xp.asarray([1 + 5j, 2 - 1j, 3 + 0j])
  153. b = xp.asarray([2 - 3j, 1 + 0j])
  154. expected = xp.asarray([2 - 3j, 8 - 10j])
  155. out = convolve(a, b, 'valid')
  156. xp_assert_equal(out, expected)
  157. out = convolve(b, a, 'valid')
  158. xp_assert_equal(out, expected)
  159. @skip_xp_backends("jax.numpy", reason="dtypes do not match")
  160. def test_same_mode(self, xp):
  161. a = xp.asarray([1, 2, 3, 3, 1, 2])
  162. b = xp.asarray([1, 4, 3, 4, 5, 6, 7, 4, 3, 2, 1, 1, 3])
  163. c = convolve(a, b, 'same')
  164. d = xp.asarray([57, 61, 63, 57, 45, 36])
  165. xp_assert_equal(c, d)
  166. @skip_xp_backends("cupy", reason="different exception")
  167. def test_invalid_shapes(self, xp):
  168. # By "invalid," we mean that no one
  169. # array has dimensions that are all at
  170. # least as large as the corresponding
  171. # dimensions of the other array. This
  172. # setup should throw a ValueError.
  173. a = xp.reshape(xp.arange(1, 7), (2, 3))
  174. b = xp.reshape(xp.arange(-6, 0), (3, 2))
  175. assert_raises(ValueError, convolve, *(a, b), **{'mode': 'valid'})
  176. assert_raises(ValueError, convolve, *(b, a), **{'mode': 'valid'})
  177. @skip_xp_backends(np_only=True, reason="TODO: convert this test")
  178. def test_convolve_method(self, xp, n=100):
  179. # this types data structure was manually encoded instead of
  180. # using custom filters on the soon-to-be-removed np.sctypes
  181. types = {'uint16', 'uint64', 'int64', 'int32',
  182. 'complex128', 'float64', 'float16',
  183. 'complex64', 'float32', 'int16',
  184. 'uint8', 'uint32', 'int8', 'bool'}
  185. args = [(t1, t2, mode) for t1 in types for t2 in types
  186. for mode in ['valid', 'full', 'same']]
  187. # These are random arrays, which means test is much stronger than
  188. # convolving testing by convolving two np.ones arrays
  189. rng = np.random.RandomState(42)
  190. array_types = {'i': rng.choice([0, 1], size=n),
  191. 'f': rng.randn(n)}
  192. array_types['b'] = array_types['u'] = array_types['i']
  193. array_types['c'] = array_types['f'] + 0.5j*array_types['f']
  194. for t1, t2, mode in args:
  195. x1 = array_types[np.dtype(t1).kind].astype(t1)
  196. x2 = array_types[np.dtype(t2).kind].astype(t2)
  197. results = {key: convolve(x1, x2, method=key, mode=mode)
  198. for key in ['fft', 'direct']}
  199. assert results['fft'].dtype == results['direct'].dtype
  200. if 'bool' in t1 and 'bool' in t2:
  201. assert choose_conv_method(x1, x2) == 'direct'
  202. continue
  203. # Found by experiment. Found approx smallest value for (rtol, atol)
  204. # threshold to have tests pass.
  205. if any([t in {'complex64', 'float32'} for t in [t1, t2]]):
  206. kwargs = {'rtol': 1.0e-4, 'atol': 1e-6}
  207. elif 'float16' in [t1, t2]:
  208. # atol is default for np.allclose
  209. kwargs = {'rtol': 1e-3, 'atol': 1e-3}
  210. else:
  211. # defaults for np.allclose (different from assert_allclose)
  212. kwargs = {'rtol': 1e-5, 'atol': 1e-8}
  213. xp_assert_close(results['fft'], results['direct'], **kwargs)
  214. @skip_xp_backends("jax.numpy", reason="dtypes do not match")
  215. def test_convolve_method_large_input(self, xp):
  216. # This is really a test that convolving two large integers goes to the
  217. # direct method even if they're in the fft method.
  218. for n in [10, 20, 50, 51, 52, 53, 54, 60, 62]:
  219. z = xp.asarray([2**n], dtype=xp.int64)
  220. fft = convolve(z, z, method='fft')
  221. direct = convolve(z, z, method='direct')
  222. # this is the case when integer precision gets to us
  223. # issue #6076 has more detail, hopefully more tests after resolved
  224. # # XXX: revisit check_dtype under np 2.0: 32bit linux & windows
  225. if n < 50:
  226. val = xp.asarray([2**(2*n)])
  227. xp_assert_equal(fft, direct)
  228. xp_assert_equal(fft, val, check_dtype=False)
  229. xp_assert_equal(direct, val, check_dtype=False)
  230. @skip_xp_backends(np_only=True)
  231. def test_mismatched_dims(self, xp):
  232. # Input arrays should have the same number of dimensions
  233. assert_raises(ValueError, convolve, [1], 2, method='direct')
  234. assert_raises(ValueError, convolve, 1, [2], method='direct')
  235. assert_raises(ValueError, convolve, [1], 2, method='fft')
  236. assert_raises(ValueError, convolve, 1, [2], method='fft')
  237. assert_raises(ValueError, convolve, [1], [[2]])
  238. assert_raises(ValueError, convolve, [3], 2)
  239. @make_xp_test_case(convolve2d)
  240. class TestConvolve2d:
  241. @skip_xp_backends("jax.numpy", reason="dtypes do not match")
  242. def test_2d_arrays(self, xp):
  243. a = xp.asarray([[1, 2, 3], [3, 4, 5]])
  244. b = xp.asarray([[2, 3, 4], [4, 5, 6]])
  245. d = xp.asarray([[2, 7, 16, 17, 12],
  246. [10, 30, 62, 58, 38],
  247. [12, 31, 58, 49, 30]])
  248. e = convolve2d(a, b)
  249. xp_assert_equal(e, d)
  250. @skip_xp_backends("jax.numpy", reason="dtypes do not match")
  251. def test_valid_mode(self, xp):
  252. e = xp.asarray([[2, 3, 4, 5, 6, 7, 8], [4, 5, 6, 7, 8, 9, 10]])
  253. f = xp.asarray([[1, 2, 3], [3, 4, 5]])
  254. h = xp.asarray([[62, 80, 98, 116, 134]])
  255. g = convolve2d(e, f, 'valid')
  256. xp_assert_equal(g, h)
  257. # See gh-5897
  258. g = convolve2d(f, e, 'valid')
  259. xp_assert_equal(g, h)
  260. @skip_xp_backends("torch", reason="dtypes do not match")
  261. def test_valid_mode_complx(self, xp):
  262. e = xp.asarray([[2, 3, 4, 5, 6, 7, 8], [4, 5, 6, 7, 8, 9, 10]])
  263. f = xp.asarray([[1, 2, 3], [3, 4, 5]], dtype=xp.complex128) + 1j
  264. h = xp.asarray([[62.+24.j, 80.+30.j, 98.+36.j, 116.+42.j, 134.+48.j]])
  265. g = convolve2d(e, f, 'valid')
  266. xp_assert_close(g, h)
  267. # See gh-5897
  268. g = convolve2d(f, e, 'valid')
  269. xp_assert_equal(g, h)
  270. @skip_xp_backends("jax.numpy", reason="jax only allows fillvalue=0")
  271. def test_fillvalue(self, xp):
  272. a = xp.asarray([[1, 2, 3], [3, 4, 5]])
  273. b = xp.asarray([[2, 3, 4], [4, 5, 6]])
  274. fillval = 1
  275. c = convolve2d(a, b, 'full', 'fill', fillval)
  276. d = xp.asarray([[24, 26, 31, 34, 32],
  277. [28, 40, 62, 64, 52],
  278. [32, 46, 67, 62, 48]])
  279. xp_assert_equal(c, d)
  280. def test_fillvalue_errors(self, xp):
  281. msg = "could not cast `fillvalue` directly to the output "
  282. with warnings.catch_warnings():
  283. warnings.filterwarnings("ignore", "Casting complex values", ComplexWarning)
  284. with assert_raises(ValueError, match=msg):
  285. convolve2d([[1]], [[1, 2]], fillvalue=1j)
  286. msg = "`fillvalue` must be scalar or an array with "
  287. with assert_raises(ValueError, match=msg):
  288. convolve2d([[1]], [[1, 2]], fillvalue=[1, 2])
  289. def test_fillvalue_empty(self, xp):
  290. # Check that fillvalue being empty raises an error:
  291. assert_raises(ValueError, convolve2d, [[1]], [[1, 2]],
  292. fillvalue=[])
  293. @skip_xp_backends("jax.numpy", reason="jax only supports boundary='fill'")
  294. def test_wrap_boundary(self, xp):
  295. a = xp.asarray([[1, 2, 3], [3, 4, 5]])
  296. b = xp.asarray([[2, 3, 4], [4, 5, 6]])
  297. c = convolve2d(a, b, 'full', 'wrap')
  298. d = xp.asarray([[80, 80, 74, 80, 80],
  299. [68, 68, 62, 68, 68],
  300. [80, 80, 74, 80, 80]])
  301. xp_assert_equal(c, d)
  @skip_xp_backends("jax.numpy", reason="jax only supports boundary='fill'")
  def test_sym_boundary(self, xp):
      """Full convolution with boundary='symm'."""
      image = xp.asarray([[1, 2, 3], [3, 4, 5]])
      kernel = xp.asarray([[2, 3, 4], [4, 5, 6]])
      expected = xp.asarray([[34, 30, 44, 62, 66],
                             [52, 48, 62, 80, 84],
                             [82, 78, 92, 110, 114]])

      xp_assert_equal(convolve2d(image, kernel, 'full', 'symm'), expected)
  @skip_xp_backends("jax.numpy", reason="jax only supports boundary='fill'")
  @pytest.mark.parametrize('func', [convolve2d, correlate2d])
  @pytest.mark.parametrize('boundary, expected',
                           [('symm', [[37.0, 42.0, 44.0, 45.0]]),
                            ('wrap', [[43.0, 44.0, 42.0, 39.0]])])
  def test_same_with_boundary(self, func, boundary, expected, xp):
      """mode='same' with 'symm'/'wrap' boundaries and a "long" kernel.

      This is a regression test for gh-8684 and gh-8814.
      """
      # The kernel (21 taps) is much longer than the image (4 samples), so
      # the image values must be extended more than once to handle the
      # requested boundary method.
      image = xp.asarray([[2.0, -1.0, 3.0, 4.0]])
      long_kernel = xp.ones((1, 21))

      actual = func(image, long_kernel, mode='same', boundary=boundary)

      # The expected results were calculated "by hand".  Because the kernel
      # is all ones, convolve2d and correlate2d are expected to agree.
      xp_assert_equal(actual, xp.asarray(expected))
  @skip_xp_backends("jax.numpy", reason="jax only supports boundary='fill'")
  def test_boundary_extension_same(self, xp):
      """mode='same', boundary='wrap' agrees with `ndimage.convolve`.

      Regression test for gh-12686.
      """
      import scipy.ndimage as ndi

      image = xp.reshape(xp.arange(1, 10*3 + 1, dtype=xp.float64), (10, 3))
      kernel = xp.reshape(xp.arange(1, 10*10 + 1, dtype=xp.float64), (10, 10))

      actual = convolve2d(image, kernel, mode='same', boundary='wrap')

      # ndimage.convolve with these arguments produces the reference result.
      reference = ndi.convolve(image, kernel, mode='wrap', origin=(-1, -1))
      xp_assert_equal(actual, reference)
  @skip_xp_backends("jax.numpy", reason="jax only supports boundary='fill'")
  def test_boundary_extension_full(self, xp):
      """mode='full', boundary='wrap' agrees with `ndimage.convolve`.

      Regression test for gh-12686.
      """
      import scipy.ndimage as ndi

      image = xp.reshape(xp.arange(1, 3*3 + 1, dtype=xp.float64), (3, 3))
      kernel = xp.reshape(xp.arange(1, 6*6 + 1, dtype=xp.float64), (6, 6))
      actual = convolve2d(image, kernel, mode='full', boundary='wrap')

      # Build the reference with ndimage.convolve on a NumPy copy of the
      # image that has been wrap-padded by 3 on every side; the last row
      # and column of that result are dropped before comparing.
      image_np = np.arange(1, 3*3 + 1, dtype=float).reshape(3, 3)
      padded = xp.asarray(np.pad(image_np, ((3, 3), (3, 3)), 'wrap'))
      reference = ndi.convolve(padded, kernel, mode='wrap')[:-1, :-1]

      xp_assert_equal(actual, xp.asarray(reference))
  def test_invalid_shapes(self, xp):
      """mode='valid' rejects operands where neither contains the other."""
      # By "invalid," we mean that no one array has dimensions that are all
      # at least as large as the corresponding dimensions of the other
      # array. This setup should throw a ValueError.
      wide = xp.reshape(xp.arange(1, 7), (2, 3))
      tall = xp.reshape(xp.arange(-6, 0), (3, 2))

      for in1, in2 in ((wide, tall), (tall, wide)):
          with assert_raises(ValueError):
              convolve2d(in1, in2, mode='valid')
  @skip_xp_backends("jax.numpy",
                    reason="jax returns floats; scipy returns ints; cf gh-6076")
  def test_same_mode(self, xp):
      """mode='same' returns an output shaped like the first operand."""
      small = xp.asarray([[1, 2, 3], [3, 4, 5]])
      large = xp.asarray([[2, 3, 4, 5, 6, 7, 8],
                          [4, 5, 6, 7, 8, 9, 10]])
      expected = xp.asarray([[22, 28, 34],
                             [80, 98, 116]])

      xp_assert_equal(convolve2d(small, large, 'same'), expected)
  @skip_xp_backends("jax.numpy",
                    reason="jax returns floats; scipy returns ints; cf gh-6076")
  def test_valid_mode2(self, xp):
      """mode='valid' gives the same result for either operand order.

      See gh-5897.  Checked for an integer pair and a complex pair.
      """
      # Integer operands.
      e = xp.asarray([[1, 2, 3], [3, 4, 5]])
      f = xp.asarray([[2, 3, 4, 5, 6, 7, 8], [4, 5, 6, 7, 8, 9, 10]])
      expected = xp.asarray([[62, 80, 98, 116, 134]])
      for in1, in2 in ((e, f), (f, e)):
          xp_assert_equal(convolve2d(in1, in2, 'valid'), expected)

      # Complex operands.
      e = xp.asarray([[1 + 1j, 2 - 3j], [3 + 1j, 4 + 0j]])
      f = xp.asarray([[2 - 1j, 3 + 2j, 4 + 0j], [4 - 0j, 5 + 1j, 6 - 3j]])
      expected = xp.asarray([[27 - 1j, 46. + 2j]])
      for in1, in2 in ((e, f), (f, e)):
          xp_assert_equal(convolve2d(in1, in2, 'valid'), expected)
  @skip_xp_backends("torch",
      reason="only integer tensors of a single element can be converted"
  )
  def test_consistency_convolve_funcs(self, xp):
      """np.convolve, signal.convolve and signal.convolve2d agree in 1-D."""
      # Compare np.convolve, signal.convolve, signal.convolve2d
      a = xp.arange(5)
      b = xp.asarray([3.2, 1.4, 3])
      a_np = _xp_copy_to_numpy(a)
      b_np = _xp_copy_to_numpy(b)

      for mode in ['full', 'valid', 'same']:
          from_numpy = xp.asarray(np.convolve(a_np, b_np, mode=mode))
          xp_assert_close(from_numpy, signal.convolve(a, b, mode=mode))

          # Run the 2-D routine on single-row inputs, then drop that row axis.
          from_2d = xp.squeeze(
              signal.convolve2d(a[None, :], b[None, :], mode=mode),
              axis=0
          )
          xp_assert_close(from_2d, signal.convolve(a, b, mode=mode))
  def test_invalid_dims(self, xp):
      """0-D, 1-D and 3-D inputs are rejected by convolve2d."""
      bad_inputs = [(3, 4), ([3], [4]), ([[[3]]], [[[4]]])]
      for in1, in2 in bad_inputs:
          with assert_raises(ValueError):
              convolve2d(in1, in2)
  @pytest.mark.slow
  @pytest.mark.xfail_on_32bit("Can't create large array for test")
  @skip_xp_backends(np_only=True, reason="stride_tricks")
  def test_large_array(self, xp):
      """Indexing into a very large input must not overflow an int (gh-10761)."""
      n_rows = 2**31 // (1000 * xp.int64().itemsize)
      _testutils.check_free_memory(
          2 * n_rows * 1001 * np.int64().itemsize / 1e6
      )

      # Create a chequered pattern of 1s and 0s
      flat = xp.zeros(1001 * n_rows, dtype=xp.int64)
      flat[::2] = 1
      # View the flat buffer as (n_rows, 1000) with a row stride of 1001
      # int64 items (8008 bytes).
      board = np.lib.stride_tricks.as_strided(
          flat, shape=(n_rows, 1000), strides=(8008, 8)
      )

      count = signal.convolve2d(board, [[1, 1]])
      # No sum of two neighbouring cells may exceed 1.
      assert np.where(count > 1)[0].size == 0
  @make_xp_test_case(fftconvolve)
  class TestFFTConvolve:
      """Tests for `fftconvolve` across modes, dtypes and `axes` settings.

      In the parametrized tests an `axes` value of ``''`` means "call
      `fftconvolve` without the `axes` argument"; list-valued `axes` are
      converted to tuples before being passed on.
      """

      @skip_xp_backends("torch", reason="dtypes do not match")
      @pytest.mark.parametrize('axes', ['', None, 0, [0], -1, [-1]])
      def test_real(self, axes, xp):
          """1-D integer self-convolution."""
          a = xp.asarray([1, 2, 3])
          expected = xp.asarray([1, 4, 10, 12, 9.])

          if axes == '':
              out = fftconvolve(a, a)
          else:
              if isinstance(axes, list):
                  axes = tuple(axes)
              out = fftconvolve(a, a, axes=axes)

          xp_assert_close(out, expected, atol=1.5e-6)

      @skip_xp_backends("torch", reason="dtypes do not match")
      @pytest.mark.parametrize('axes', [1, [1], -1, [-1]])
      def test_real_axes(self, axes, xp):
          """Integer self-convolution along the last axis of a tiled 2-D input."""
          a = xp.asarray([1, 2, 3])
          expected = xp.asarray([1, 4, 10, 12, 9.])

          a = xp.asarray(np.tile(a, [2, 1]))
          expected = xp.asarray(np.tile(expected, [2, 1]))

          if isinstance(axes, list):
              axes = tuple(axes)

          out = fftconvolve(a, a, axes=axes)
          xp_assert_close(out, expected, atol=1.5e-6)

      @pytest.mark.parametrize('axes', ['', None, 0, [0], -1, [-1]])
      def test_complex(self, axes, xp):
          """1-D complex self-convolution."""
          a = xp.asarray([1 + 1j, 2 + 2j, 3 + 3j])
          expected = xp.asarray([0 + 2j, 0 + 8j, 0 + 20j, 0 + 24j, 0 + 18j])

          if axes == '':
              out = fftconvolve(a, a)
          else:
              if isinstance(axes, list):
                  axes = tuple(axes)
              out = fftconvolve(a, a, axes=axes)

          xp_assert_close(out, expected, atol=1.5e-6)

      @pytest.mark.parametrize('axes', [1, [1], -1, [-1]])
      def test_complex_axes(self, axes, xp):
          """Complex self-convolution along the last axis of a tiled 2-D input."""
          a = xp.asarray([1 + 1j, 2 + 2j, 3 + 3j])
          expected = xp.asarray([0 + 2j, 0 + 8j, 0 + 20j, 0 + 24j, 0 + 18j])

          a = xp.asarray(np.tile(a, [2, 1]))
          expected = xp.asarray(np.tile(expected, [2, 1]))

          if isinstance(axes, list):
              axes = tuple(axes)

          out = fftconvolve(a, a, axes=axes)
          xp_assert_close(out, expected, atol=1.5e-6)

      @pytest.mark.parametrize('axes', ['',
                                        None,
                                        [0, 1],
                                        [1, 0],
                                        [0, -1],
                                        [-1, 0],
                                        [-2, 1],
                                        [1, -2],
                                        [-2, -1],
                                        [-1, -2]])
      def test_2d_real_same(self, axes, xp):
          """2-D real self-convolution over both axes."""
          a = xp.asarray([[1.0, 2, 3],
                          [4, 5, 6]])
          expected = xp.asarray([[1.0, 4, 10, 12, 9],
                                 [8, 26, 56, 54, 36],
                                 [16, 40, 73, 60, 36]])

          if axes == '':
              out = fftconvolve(a, a)
          else:
              if isinstance(axes, list):
                  axes = tuple(axes)
              out = fftconvolve(a, a, axes=axes)

          xp_assert_close(out, expected)

      @pytest.mark.parametrize('axes', [[1, 2],
                                        [2, 1],
                                        [1, -1],
                                        [-1, 1],
                                        [-2, 2],
                                        [2, -2],
                                        [-2, -1],
                                        [-1, -2]])
      def test_2d_real_same_axes(self, axes, xp):
          """2-D real self-convolution over the last two axes of a 3-D input."""
          a = xp.asarray([[1, 2, 3],
                          [4, 5, 6]])
          expected = xp.asarray([[1, 4, 10, 12, 9],
                                 [8, 26, 56, 54, 36],
                                 [16, 40, 73, 60, 36]])

          a = xp.asarray(np.tile(a, [2, 1, 1]))
          expected = xp.asarray(np.tile(expected, [2, 1, 1]))

          if isinstance(axes, list):
              axes = tuple(axes)

          out = fftconvolve(a, a, axes=axes)
          xp_assert_close(out, expected, atol=1.5e-6, check_dtype=False)

      @pytest.mark.parametrize('axes', ['',
                                        None,
                                        [0, 1],
                                        [1, 0],
                                        [0, -1],
                                        [-1, 0],
                                        [-2, 1],
                                        [1, -2],
                                        [-2, -1],
                                        [-1, -2]])
      def test_2d_complex_same(self, axes, xp):
          """2-D complex self-convolution over both axes."""
          a = xp.asarray([[1 + 2j, 3 + 4j, 5 + 6j],
                          [2 + 1j, 4 + 3j, 6 + 5j]])
          expected = xp.asarray([
              [-3 + 4j, -10 + 20j, -21 + 56j, -18 + 76j, -11 + 60j],
              [10j, 44j, 118j, 156j, 122j],
              [3 + 4j, 10 + 20j, 21 + 56j, 18 + 76j, 11 + 60j]
          ])

          if axes == '':
              out = fftconvolve(a, a)
          else:
              if isinstance(axes, list):
                  axes = tuple(axes)
              out = fftconvolve(a, a, axes=axes)

          xp_assert_close(out, expected, atol=1.5e-6)

      @pytest.mark.parametrize('axes', [[1, 2],
                                        [2, 1],
                                        [1, -1],
                                        [-1, 1],
                                        [-2, 2],
                                        [2, -2],
                                        [-2, -1],
                                        [-1, -2]])
      def test_2d_complex_same_axes(self, axes, xp):
          """2-D complex self-convolution over the last two axes of a 3-D input."""
          a = xp.asarray([[1 + 2j, 3 + 4j, 5 + 6j],
                          [2 + 1j, 4 + 3j, 6 + 5j]])
          expected = xp.asarray([
              [-3 + 4j, -10 + 20j, -21 + 56j, -18 + 76j, -11 + 60j],
              [10j, 44j, 118j, 156j, 122j],
              [3 + 4j, 10 + 20j, 21 + 56j, 18 + 76j, 11 + 60j]
          ])

          a = xp.asarray(np.tile(a, [2, 1, 1]))
          expected = xp.asarray(np.tile(expected, [2, 1, 1]))

          if isinstance(axes, list):
              axes = tuple(axes)

          out = fftconvolve(a, a, axes=axes)
          xp_assert_close(out, expected, atol=1.5e-6)

      @skip_xp_backends("torch", reason="dtypes do not match")
      @pytest.mark.parametrize('axes', ['', None, 0, [0], -1, [-1]])
      def test_real_same_mode(self, axes, xp):
          """mode='same' for integer inputs, in both operand orders."""
          a = xp.asarray([1, 2, 3])
          b = xp.asarray([3, 3, 5, 6, 8, 7, 9, 0, 1])
          expected_1 = xp.asarray([35., 41., 47.])
          expected_2 = xp.asarray([9., 20., 25., 35., 41., 47., 39., 28., 2.])

          # Convert once up front; '' and None pass through unchanged.
          if isinstance(axes, list):
              axes = tuple(axes)

          if axes == '':
              out = fftconvolve(a, b, 'same')
          else:
              out = fftconvolve(a, b, 'same', axes=axes)
          xp_assert_close(out, expected_1)

          if axes == '':
              out = fftconvolve(b, a, 'same')
          else:
              out = fftconvolve(b, a, 'same', axes=axes)
          xp_assert_close(out, expected_2, atol=1.5e-6)

      @skip_xp_backends("torch", reason="dtypes do not match")
      @pytest.mark.parametrize('axes', [1, -1, [1], [-1]])
      def test_real_same_mode_axes(self, axes, xp):
          """mode='same' along the last axis of tiled 2-D integer inputs."""
          a = xp.asarray([1, 2, 3])
          b = xp.asarray([3, 3, 5, 6, 8, 7, 9, 0, 1])
          expected_1 = xp.asarray([35., 41., 47.])
          expected_2 = xp.asarray([9., 20., 25., 35., 41., 47., 39., 28., 2.])

          a = xp.asarray(np.tile(a, [2, 1]))
          b = xp.asarray(np.tile(b, [2, 1]))
          expected_1 = xp.asarray(np.tile(expected_1, [2, 1]))
          expected_2 = xp.asarray(np.tile(expected_2, [2, 1]))

          if isinstance(axes, list):
              axes = tuple(axes)

          out = fftconvolve(a, b, 'same', axes=axes)
          xp_assert_close(out, expected_1, atol=1.5e-6)

          out = fftconvolve(b, a, 'same', axes=axes)
          xp_assert_close(out, expected_2, atol=1.5e-6)

      @skip_xp_backends("torch", reason="dtypes do not match")
      @pytest.mark.parametrize('axes', ['', None, 0, [0], -1, [-1]])
      def test_valid_mode_real(self, axes, xp):
          """mode='valid' for integer inputs, in both operand orders."""
          # See gh-5897
          a = xp.asarray([3, 2, 1])
          b = xp.asarray([3, 3, 5, 6, 8, 7, 9, 0, 1])
          expected = xp.asarray([24., 31., 41., 43., 49., 25., 12.])

          # Convert once up front; '' and None pass through unchanged.
          if isinstance(axes, list):
              axes = tuple(axes)

          if axes == '':
              out = fftconvolve(a, b, 'valid')
          else:
              out = fftconvolve(a, b, 'valid', axes=axes)
          xp_assert_close(out, expected, atol=1.5e-6)

          if axes == '':
              out = fftconvolve(b, a, 'valid')
          else:
              out = fftconvolve(b, a, 'valid', axes=axes)
          xp_assert_close(out, expected, atol=1.5e-6)

      @skip_xp_backends("torch", reason="dtypes do not match")
      @pytest.mark.parametrize('axes', [1, [1]])
      def test_valid_mode_real_axes(self, axes, xp):
          """mode='valid' along axis 1 of tiled 2-D integer inputs."""
          # See gh-5897
          a = xp.asarray([3, 2, 1])
          b = xp.asarray([3, 3, 5, 6, 8, 7, 9, 0, 1])
          expected = xp.asarray([24., 31., 41., 43., 49., 25., 12.])

          a = xp.asarray(np.tile(a, [2, 1]))
          b = xp.asarray(np.tile(b, [2, 1]))
          expected = xp.asarray(np.tile(expected, [2, 1]))

          if isinstance(axes, list):
              axes = tuple(axes)

          out = fftconvolve(a, b, 'valid', axes=axes)
          xp_assert_close(out, expected, atol=1.5e-6)

      @pytest.mark.parametrize('axes', ['', None, 0, [0], -1, [-1]])
      def test_valid_mode_complex(self, axes, xp):
          """mode='valid' for complex inputs, in both operand orders."""
          a = xp.asarray([3 - 1j, 2 + 7j, 1 + 0j])
          b = xp.asarray([3 + 2j, 3 - 3j, 5 + 0j, 6 - 1j, 8 + 0j])
          expected = xp.asarray([45. + 12.j, 30. + 23.j, 48 + 32.j])

          # Convert once up front; '' and None pass through unchanged.
          if isinstance(axes, list):
              axes = tuple(axes)

          if axes == '':
              out = fftconvolve(a, b, 'valid')
          else:
              out = fftconvolve(a, b, 'valid', axes=axes)
          xp_assert_close(out, expected, atol=1.5e-6)

          if axes == '':
              out = fftconvolve(b, a, 'valid')
          else:
              out = fftconvolve(b, a, 'valid', axes=axes)
          xp_assert_close(out, expected, atol=1.5e-6)

      @pytest.mark.parametrize('axes', [1, [1], -1, [-1]])
      def test_valid_mode_complex_axes(self, axes, xp):
          """mode='valid' along the last axis of tiled 2-D complex inputs."""
          a = xp.asarray([3 - 1j, 2 + 7j, 1 + 0j])
          b = xp.asarray([3 + 2j, 3 - 3j, 5 + 0j, 6 - 1j, 8 + 0j])
          expected = xp.asarray([45. + 12.j, 30. + 23.j, 48 + 32.j])

          a = xp.asarray(np.tile(a, [2, 1]))
          b = xp.asarray(np.tile(b, [2, 1]))
          expected = xp.asarray(np.tile(expected, [2, 1]))

          if isinstance(axes, list):
              axes = tuple(axes)

          out = fftconvolve(a, b, 'valid', axes=axes)
          xp_assert_close(out, expected, atol=1.5e-6)

          out = fftconvolve(b, a, 'valid', axes=axes)
          xp_assert_close(out, expected, atol=1.5e-6)

      @skip_xp_backends("jax.numpy", reason="mapped axes must have same shape")
      @skip_xp_backends("torch", reason="dtypes do not match")
      def test_valid_mode_ignore_nonaxes(self, xp):
          """mode='valid' on axis 1 when axis 0 has lengths 2 and 1."""
          # See gh-5897
          a = xp.asarray([3, 2, 1])
          b = xp.asarray([3, 3, 5, 6, 8, 7, 9, 0, 1])
          expected = xp.asarray([24., 31., 41., 43., 49., 25., 12.])

          a = xp.asarray(np.tile(a, [2, 1]))
          b = xp.asarray(np.tile(b, [1, 1]))
          expected = xp.asarray(np.tile(expected, [2, 1]))

          out = fftconvolve(a, b, 'valid', axes=1)
          xp_assert_close(out, expected, atol=1.5e-6)

      @xfail_xp_backends("cupy", reason="dtypes do not match")
      @xfail_xp_backends("jax.numpy", reason="assorted error messages")
      @pytest.mark.parametrize("a,b", [([], []), ([5, 6], []), ([], [7])])
      def test_empty(self, a, b, xp):
          """Zero-length inputs give an empty result."""
          # Regression test for #1745: crashes with 0-length input.
          xp_assert_equal(
              fftconvolve(xp.asarray(a), xp.asarray(b)),
              xp.asarray([]),
          )

      @skip_xp_backends("jax.numpy", reason="jnp.pad: pad_width with nd=0")
      def test_zero_rank(self, xp):
          """0-D inputs convolve to their product."""
          a = xp.asarray(4967)
          b = xp.asarray(3920)
          out = fftconvolve(a, b)
          xp_assert_equal(out, a * b)

      def test_single_element(self, xp):
          """Length-1 inputs convolve to their product (cast to the output dtype)."""
          a = xp.asarray([4967])
          b = xp.asarray([3920])
          out = fftconvolve(a, b)
          xp_assert_equal(out,
                          xp.asarray(a * b, dtype=out.dtype))

      @pytest.mark.parametrize('axes', ['', None, 0, [0], -1, [-1]])
      def test_random_data(self, axes, xp):
          """Full convolution of random complex 1-D data matches np.convolve."""
          # Draw every sample from the seeded generator so the inputs are
          # reproducible from run to run.
          rng = np.random.default_rng(1234)
          a_np = rng.random(1233) + 1j * rng.random(1233)
          b_np = rng.random(1321) + 1j * rng.random(1321)
          expected = xp.asarray(np.convolve(a_np, b_np, 'full'))
          a = xp.asarray(a_np)
          b = xp.asarray(b_np)

          if axes == '':
              out = fftconvolve(a, b, 'full')
          else:
              if isinstance(axes, list):
                  axes = tuple(axes)
              out = fftconvolve(a, b, 'full', axes=axes)

          xp_assert_close(out, expected, rtol=1e-10)

      @pytest.mark.parametrize('axes', [1, [1], -1, [-1]])
      def test_random_data_axes(self, axes, xp):
          """Random complex data tiled to 2-D, convolved along the last axis."""
          # Draw every sample from the seeded generator so the inputs are
          # reproducible from run to run.
          rng = np.random.default_rng(1234)
          a_np = rng.random(1233) + 1j * rng.random(1233)
          b_np = rng.random(1321) + 1j * rng.random(1321)
          expected = np.convolve(a_np, b_np, 'full')

          a_np = np.tile(a_np, [2, 1])
          b_np = np.tile(b_np, [2, 1])
          expected = xp.asarray(np.tile(expected, [2, 1]))

          a = xp.asarray(a_np)
          b = xp.asarray(b_np)

          if isinstance(axes, list):
              axes = tuple(axes)

          out = fftconvolve(a, b, 'full', axes=axes)
          xp_assert_close(out, expected, rtol=1e-10)

      @xfail_xp_backends(np_only=True, reason="TODO: swapaxes")
      @pytest.mark.parametrize('axes', [[1, 4],
                                        [4, 1],
                                        [1, -1],
                                        [-1, 1],
                                        [-4, 4],
                                        [4, -4],
                                        [-4, -1],
                                        [-1, -4]])
      def test_random_data_multidim_axes(self, axes, xp):
          """2-D convolution embedded in 5-D arrays, with broadcast extra axes.

          The reference result comes from `convolve2d` on the 2-D data.
          """
          a_shape, b_shape = (123, 22), (132, 11)
          # Draw every sample from the seeded generator so the inputs are
          # reproducible from run to run.
          rng = np.random.default_rng(1234)
          a = xp.asarray(rng.random(a_shape) + 1j * rng.random(a_shape))
          b = xp.asarray(rng.random(b_shape) + 1j * rng.random(b_shape))
          expected = convolve2d(a, b, 'full')

          # Append three length-1 axes, then move the two data axes to
          # positions 1 and 4 (the axes named by the parametrization).
          a = a[:, :, None, None, None]
          b = b[:, :, None, None, None]
          expected = expected[:, :, None, None, None]

          a = xp.moveaxis(a.swapaxes(0, 2), 1, 4)
          b = xp.moveaxis(b.swapaxes(0, 2), 1, 4)
          expected = xp.moveaxis(expected.swapaxes(0, 2), 1, 4)

          # use 1 for dimension 2 in a and 3 in b to test broadcasting
          a = xp.asarray(np.tile(a, [2, 1, 3, 1, 1]))
          b = xp.asarray(np.tile(b, [2, 1, 1, 4, 1]))
          expected = xp.asarray(np.tile(expected, [2, 1, 3, 4, 1]))

          out = fftconvolve(a, b, 'full', axes=axes)
          xp_assert_close(out, expected, rtol=1e-10, atol=1e-10)

      @pytest.mark.slow
      @pytest.mark.parametrize(
          'n',
          list(range(1, 100)) +
          list(range(1000, 1500)) +
          np.random.RandomState(1234).randint(1001, 10000, 5).tolist())
      def test_many_sizes(self, n, xp):
          """Full convolution matches np.convolve for many input lengths."""
          a_np = np.random.rand(n) + 1j * np.random.rand(n)
          b_np = np.random.rand(n) + 1j * np.random.rand(n)
          expected = xp.asarray(np.convolve(a_np, b_np, 'full'))
          a = xp.asarray(a_np)
          b = xp.asarray(b_np)

          out = fftconvolve(a, b, 'full')
          xp_assert_close(out, expected, atol=1e-10)

          out = fftconvolve(a, b, 'full', axes=(0,))
          xp_assert_close(out, expected, atol=1e-10)

      @skip_xp_backends(np_only=True)
      def test_fft_nan(self, xp):
          """FFT-based `signal.convolve` warns when the input holds NaN or inf."""
          n = 1000
          rng = np.random.default_rng(43876432987)
          sig_nan = xp.asarray(rng.standard_normal(n))

          # The filter and the accepted warning text do not depend on `val`.
          coeffs = xp.asarray(signal.firwin(200, 0.2))
          msg = "Use of fft convolution.*|invalid value encountered.*"

          for val in [np.nan, np.inf]:
              sig_nan[100] = val
              with pytest.warns(RuntimeWarning, match=msg):
                  signal.convolve(sig_nan, coeffs, mode='same', method='fft')
  def fftconvolve_err(*args, **kwargs):
      """Stand-in for `fftconvolve` that always raises RuntimeError.

      Any positional or keyword arguments are accepted and ignored.
      """
      message = 'Fell back to fftconvolve'
      raise RuntimeError(message)
  def gen_oa_shapes(sizes):
      """Return ordered pairs from `sizes` whose members differ by more than 3.

      Pairs are produced in `itertools.product(sizes, repeat=2)` order.
      """
      pairs = []
      for first, second in product(sizes, repeat=2):
          if abs(first - second) <= 3:
              continue
          pairs.append((first, second))
      return pairs
  def gen_oa_shapes_2d(sizes):
      """Return ``(a0, b0, a1, b1, mode)`` parametrization tuples.

      The size pairs from `gen_oa_shapes` are zipped with themselves into
      4-tuples, and each is combined with the modes 'full', 'valid' and
      'same'.  A 'valid' entry is kept only when one operand is strictly
      larger than the other in both dimensions.
      """
      pairs_dim0 = gen_oa_shapes(sizes)
      pairs_dim1 = gen_oa_shapes(sizes)
      combined = [dim0 + dim1 for dim0, dim1 in zip(pairs_dim0, pairs_dim1)]

      cases = []
      for shape, mode in product(combined, ['full', 'valid', 'same']):
          if mode == 'valid':
              a_larger = shape[0] > shape[1] and shape[2] > shape[3]
              b_larger = shape[0] < shape[1] and shape[2] < shape[3]
              if not (a_larger or b_larger):
                  continue
          cases.append(shape + (mode,))
      return cases
  def gen_oa_shapes_eq(sizes):
      """Return ordered pairs ``(a, b)`` from `sizes` with ``a >= b``.

      Pairs are produced in `itertools.product(sizes, repeat=2)` order.
      """
      pairs = []
      for first, second in product(sizes, repeat=2):
          if first >= second:
              pairs.append((first, second))
      return pairs
  @make_xp_test_case(oaconvolve)
  class TestOAConvolve:
      """Tests comparing `oaconvolve` against `fftconvolve` reference results.

      Several tests replace ``signal._signaltools.fftconvolve`` with
      `fftconvolve_err` after the reference has been computed, so any call
      made through that attribute raises RuntimeError.
      """

      @pytest.mark.slow()
      @pytest.mark.parametrize('shape_a_0, shape_b_0',
                               gen_oa_shapes_eq(list(range(1, 100, 1)) +
                                                list(range(100, 1000, 23)))
                               )
      def test_real_manylens(self, shape_a_0, shape_b_0, xp):
          """Real 1-D inputs over many length combinations."""
          a_np = np.random.rand(shape_a_0)
          b_np = np.random.rand(shape_b_0)
          reference = xp.asarray(fftconvolve(a_np, b_np))

          result = oaconvolve(xp.asarray(a_np), xp.asarray(b_np))

          assert_array_almost_equal(result, reference)

      @pytest.mark.parametrize('shape_a_0, shape_b_0',
                               gen_oa_shapes([50, 47, 6, 4, 1]))
      @pytest.mark.parametrize('is_complex', [True, False])
      @pytest.mark.parametrize('mode', ['full', 'valid', 'same'])
      def test_1d_noaxes(self, shape_a_0, shape_b_0,
                         is_complex, mode, monkeypatch, xp):
          """1-D inputs, every mode, without the `axes` argument."""
          a_np = np.random.rand(shape_a_0)
          b_np = np.random.rand(shape_b_0)
          if is_complex:
              a_np = a_np + 1j*np.random.rand(shape_a_0)
              b_np = b_np + 1j*np.random.rand(shape_b_0)

          reference = xp.asarray(fftconvolve(a_np, b_np, mode=mode))

          a_xp = xp.asarray(a_np)
          b_xp = xp.asarray(b_np)

          # Patched only after the reference result has been computed.
          monkeypatch.setattr(signal._signaltools, 'fftconvolve',
                              fftconvolve_err)
          result = oaconvolve(a_xp, b_xp, mode=mode)

          assert_array_almost_equal(result, reference)

      @pytest.mark.parametrize('axes', [0, 1])
      @pytest.mark.parametrize('shape_a_0, shape_b_0',
                               gen_oa_shapes([50, 47, 6, 4]))
      @pytest.mark.parametrize('shape_a_extra', [1, 3])
      @pytest.mark.parametrize('shape_b_extra', [1, 3])
      @pytest.mark.parametrize('is_complex', [True, False])
      @pytest.mark.parametrize('mode', ['full', 'valid', 'same'])
      def test_1d_axes(self, axes, shape_a_0, shape_b_0,
                       shape_a_extra, shape_b_extra,
                       is_complex, mode, monkeypatch, xp):
          """2-D inputs convolved along a single axis."""
          # Both dimensions start at the "extra" size; the convolved axis
          # is then overwritten with the parametrized length.
          dims_a = [shape_a_extra, shape_a_extra]
          dims_b = [shape_b_extra, shape_b_extra]
          dims_a[axes] = shape_a_0
          dims_b[axes] = shape_b_0

          a_np = np.random.rand(*dims_a)
          b_np = np.random.rand(*dims_b)
          if is_complex:
              a_np = a_np + 1j*np.random.rand(*dims_a)
              b_np = b_np + 1j*np.random.rand(*dims_b)

          reference = xp.asarray(
              fftconvolve(a_np, b_np, mode=mode, axes=axes)
          )

          a_xp = xp.asarray(a_np)
          b_xp = xp.asarray(b_np)

          # Patched only after the reference result has been computed.
          monkeypatch.setattr(signal._signaltools, 'fftconvolve',
                              fftconvolve_err)
          result = oaconvolve(a_xp, b_xp, mode=mode, axes=axes)

          assert_array_almost_equal(result, reference)

      @pytest.mark.parametrize('shape_a_0, shape_b_0, '
                               'shape_a_1, shape_b_1, mode',
                               gen_oa_shapes_2d([50, 47, 6, 4]))
      @pytest.mark.parametrize('is_complex', [True, False])
      def test_2d_noaxes(self, shape_a_0, shape_b_0,
                         shape_a_1, shape_b_1, mode,
                         is_complex, monkeypatch, xp):
          """2-D inputs without the `axes` argument."""
          a_np = np.random.rand(shape_a_0, shape_a_1)
          b_np = np.random.rand(shape_b_0, shape_b_1)
          if is_complex:
              a_np = a_np + 1j*np.random.rand(shape_a_0, shape_a_1)
              b_np = b_np + 1j*np.random.rand(shape_b_0, shape_b_1)

          reference = xp.asarray(fftconvolve(a_np, b_np, mode=mode))

          a_xp = xp.asarray(a_np)
          b_xp = xp.asarray(b_np)

          # Patched only after the reference result has been computed.
          monkeypatch.setattr(signal._signaltools, 'fftconvolve',
                              fftconvolve_err)
          result = oaconvolve(a_xp, b_xp, mode=mode)

          assert_array_almost_equal(result, reference)

      @pytest.mark.parametrize('axes', [[0, 1], [0, 2], [1, 2]])
      @pytest.mark.parametrize('shape_a_0, shape_b_0, '
                               'shape_a_1, shape_b_1, mode',
                               gen_oa_shapes_2d([50, 47, 6, 4]))
      @pytest.mark.parametrize('shape_a_extra', [1, 3])
      @pytest.mark.parametrize('shape_b_extra', [1, 3])
      @pytest.mark.parametrize('is_complex', [True, False])
      def test_2d_axes(self, axes, shape_a_0, shape_b_0,
                       shape_a_1, shape_b_1, mode,
                       shape_a_extra, shape_b_extra,
                       is_complex, monkeypatch, xp):
          """3-D inputs convolved along two of their axes."""
          # All three dimensions start at the "extra" size; the two
          # convolved axes are then overwritten with the parametrized lengths.
          dims_a = [shape_a_extra, shape_a_extra, shape_a_extra]
          dims_b = [shape_b_extra, shape_b_extra, shape_b_extra]
          first_axis, second_axis = axes[0], axes[1]
          dims_a[first_axis] = shape_a_0
          dims_b[first_axis] = shape_b_0
          dims_a[second_axis] = shape_a_1
          dims_b[second_axis] = shape_b_1

          a_np = np.random.rand(*dims_a)
          b_np = np.random.rand(*dims_b)
          if is_complex:
              a_np = a_np + 1j*np.random.rand(*dims_a)
              b_np = b_np + 1j*np.random.rand(*dims_b)

          axes = tuple(axes)  # XXX for CuPy
          reference = xp.asarray(
              fftconvolve(a_np, b_np, mode=mode, axes=axes)
          )

          a_xp = xp.asarray(a_np)
          b_xp = xp.asarray(b_np)

          # Patched only after the reference result has been computed.
          monkeypatch.setattr(signal._signaltools, 'fftconvolve',
                              fftconvolve_err)
          result = oaconvolve(a_xp, b_xp, mode=mode, axes=axes)

          assert_array_almost_equal(result, reference)

      @xfail_xp_backends("torch", reason="ValueError: Target length must be positive")
      @pytest.mark.parametrize("a,b", [([], []), ([5, 6], []), ([], [7])])
      def test_empty(self, a, b, xp):
          """Zero-length inputs give an empty result (dtype not checked)."""
          # Regression test for #1745: crashes with 0-length input.
          result = oaconvolve(xp.asarray(a), xp.asarray(b))
          xp_assert_equal(result, xp.asarray([]), check_dtype=False)

      def test_zero_rank(self, xp):
          """0-D inputs convolve to their product."""
          lhs = xp.asarray(4967)
          rhs = xp.asarray(3920)
          xp_assert_equal(oaconvolve(lhs, rhs), lhs * rhs)

      def test_single_element(self, xp):
          """Length-1 inputs convolve to their product."""
          lhs = xp.asarray([4967])
          rhs = xp.asarray([3920])
          xp_assert_equal(oaconvolve(lhs, rhs), lhs * rhs)
  @skip_xp_backends(np_only=True, reason="assertions may differ on backends")
  @pytest.mark.parametrize('convapproach',
                           [make_xp_pytest_param(fftconvolve),
                            make_xp_pytest_param(oaconvolve)])
  class TestAllFreqConvolves:
      """Argument-validation tests run against each parametrized `convapproach`."""

      def test_invalid_shapes(self, convapproach, xp):
          """mode='valid' rejects operands where neither contains the other."""
          wide = np.arange(1, 7).reshape((2, 3))
          tall = np.arange(-6, 0).reshape((3, 2))
          expected_msg = ("For 'valid' mode, one must be at least "
                          "as large as the other in every dimension")
          with assert_raises(ValueError, match=expected_msg):
              convapproach(wide, tall, mode='valid')

      def test_invalid_shapes_axes(self, convapproach, xp):
          """Shapes that differ on an axis outside `axes` are rejected."""
          in1 = np.zeros([5, 6, 2, 1])
          in2 = np.zeros([5, 6, 3, 1])
          expected_msg = (r"incompatible shapes for in1 and in2:"
                          r" \(5L?, 6L?, 2L?, 1L?\) and"
                          r" \(5L?, 6L?, 3L?, 1L?\)")
          with assert_raises(ValueError, match=expected_msg):
              convapproach(in1, in2, axes=[0, 1])

      @pytest.mark.parametrize('a,b',
                               [([1], 2),
                                (1, [2]),
                                ([3], [[2]])])
      def test_mismatched_dims(self, a, b, convapproach, xp):
          """Operands of different dimensionality are rejected."""
          expected_msg = ("in1 and in2 should have the same"
                          " dimensionality")
          with assert_raises(ValueError, match=expected_msg):
              convapproach(a, b)

      def test_invalid_flags(self, convapproach, xp):
          """Bad `mode` and `axes` values raise ValueError with specific text."""
          # Each entry: keyword arguments for the call, expected message.
          bad_calls = [
              ({'mode': 'chips'},
               "acceptable mode flags are 'valid', 'same', or 'full'"),
              ({'axes': []},
               "when provided, axes cannot be empty"),
              ({'axes': [[1, 2], [3, 4]]},
               "axes must be a scalar or iterable of integers"),
              ({'axes': [1., 2., 3., 4.]},
               "axes must be a scalar or iterable of integers"),
              ({'axes': [1]},
               "axes exceeds dimensionality of input"),
              ({'axes': [-2]},
               "axes exceeds dimensionality of input"),
              ({'axes': [0, 0]},
               "all axes must be unique"),
          ]
          for kwargs, expected_msg in bad_calls:
              with assert_raises(ValueError, match=expected_msg):
                  convapproach([1], [2], **kwargs)
  @skip_xp_backends(np_only=True, reason="assertions may differ on backends")
  @pytest.mark.filterwarnings('ignore::DeprecationWarning')
  @pytest.mark.parametrize('dtype', [np.longdouble, np.clongdouble])
  @make_xp_test_case(convolve, fftconvolve)
  def test_convolve_longdtype_input(dtype, xp):
      """`fftconvolve` keeps long-double dtypes and matches direct convolution."""
      image = np.random.random((27, 27)).astype(dtype)
      kernel = np.random.random((4, 4)).astype(dtype)
      if np.iscomplexobj(dtype()):
          # Give the complex case a non-zero imaginary part.
          image += .1j
          kernel -= .1j

      result = fftconvolve(image, kernel)
      reference = convolve(image, kernel, method='direct')

      xp_assert_close(result, reference)
      assert result.dtype == dtype
  996. class TestMedFilt:
  # Shared fixture data for the tests in this class.
  # IN: 10x10 integer input image; tests convert it with
  # xp.asarray / np.array (optionally with a dtype).
  IN = [[50, 50, 50, 50, 50, 92, 18, 27, 65, 46],
        [50, 50, 50, 50, 50, 0, 72, 77, 68, 66],
        [50, 50, 50, 50, 50, 46, 47, 19, 64, 77],
        [50, 50, 50, 50, 50, 42, 15, 29, 95, 35],
        [50, 50, 50, 50, 50, 46, 34, 9, 21, 66],
        [70, 97, 28, 68, 78, 77, 61, 58, 71, 42],
        [64, 53, 44, 29, 68, 32, 19, 68, 24, 84],
        [3, 33, 53, 67, 1, 78, 74, 55, 12, 83],
        [7, 11, 46, 70, 60, 47, 24, 43, 61, 26],
        [32, 61, 88, 7, 39, 4, 92, 64, 45, 61]]
  # OUT: expected result of `signal.medfilt(IN, KERNEL_SIZE)`, as asserted
  # in `test_basic`.
  OUT = [[0, 50, 50, 50, 42, 15, 15, 18, 27, 0],
         [0, 50, 50, 50, 50, 42, 19, 21, 29, 0],
         [50, 50, 50, 50, 50, 47, 34, 34, 46, 35],
         [50, 50, 50, 50, 50, 50, 42, 47, 64, 42],
         [50, 50, 50, 50, 50, 50, 46, 55, 64, 35],
         [33, 50, 50, 50, 50, 47, 46, 43, 55, 26],
         [32, 50, 50, 50, 50, 47, 46, 45, 55, 26],
         [7, 46, 50, 50, 47, 46, 46, 43, 45, 21],
         [0, 32, 33, 39, 32, 32, 43, 43, 43, 0],
         [0, 7, 11, 7, 4, 4, 19, 19, 24, 0]]
  # Passed as the `kernel_size` argument in `test_basic`.
  KERNEL_SIZE = [7,3]
  @make_xp_test_case(signal.medfilt, signal.medfilt2d)
  def test_basic(self, xp):
      """`medfilt` reproduces OUT, and `medfilt2d` agrees with `medfilt`."""
      image = xp.asarray(self.IN)
      expected = xp.asarray(self.OUT)
      window = xp.asarray(self.KERNEL_SIZE)

      filtered = signal.medfilt(image, window)
      # medfilt2d is given a float64 copy of the same image.
      filtered_2d = signal.medfilt2d(
          xp.asarray(image, dtype=xp.float64), window
      )

      xp_assert_equal(filtered, expected)
      xp_assert_equal(filtered, filtered_2d, check_dtype=False)
  @pytest.mark.parametrize('dtype', ["uint8", "int8", "uint16", "int16",
                                     "uint32", "int32", "uint64", "int64",
                                     "float32", "float64"])
  @make_xp_test_case(signal.medfilt, signal.medfilt2d)
  def test_types(self, dtype, xp):
      """`medfilt` and `medfilt2d` return the same dtype as their input.

      `dtype` arrives as a name and is looked up on `xp`.  The uint16,
      uint32 and uint64 cases are skipped on torch.
      """
      # volume input and output types match
      if is_torch(xp) and dtype in ["uint16", "uint32", "uint64"]:
          pytest.skip("torch does not support unsigned ints")

      dtype = getattr(xp, dtype)
      in_typed = xp.asarray(self.IN, dtype=dtype)
      assert signal.medfilt(in_typed).dtype == dtype
      assert signal.medfilt2d(in_typed).dtype == dtype
  @skip_xp_backends(np_only=True, reason="assertions may differ")
  @pytest.mark.parametrize('dtype', [np.bool_, np.complex64, np.complex128,
                                     np.clongdouble, np.float16,
                                     "float96", "float128"])
  @make_xp_test_case(signal.medfilt, signal.medfilt2d)
  def test_invalid_dtypes(self, dtype, xp):
      """Unsupported dtypes make `medfilt` and `medfilt2d` raise ValueError."""
      # We can only test this on platforms that support a native type of
      # float96 or float128; comparing to np.longdouble allows us to filter
      # out non-native types
      named_extended = dtype in ["float96", "float128"]
      if named_extended and np.finfo(np.longdouble).dtype != dtype:
          pytest.skip(f"Platform does not support {dtype}")

      in_typed = np.array(self.IN, dtype=dtype)
      for median_filter in (signal.medfilt, signal.medfilt2d):
          with pytest.raises(ValueError, match="not supported"):
              median_filter(in_typed)
  @skip_xp_backends(np_only=True, reason="object arrays")
  @make_xp_test_case(signal.medfilt)
  def test_none(self, xp):
      """Regression test for gh-1651 / trac #1124: None must not segfault."""
      acceptable_errors = (ValueError, TypeError)
      with assert_raises(acceptable_errors):
          signal.medfilt(None)
  @skip_xp_backends(np_only=True, reason="strides are only writeable in NumPy")
  @make_xp_test_case(signal.medfilt)
  def test_odd_strides(self, xp):
      """medfilt handles a one-element array that carries an unusual stride."""
      # Guard against a regression with possibly-contiguous NumPy arrays
      # whose stride is odd: the stride set below would lead into wrong
      # memory if it were used (it does not need to be used).
      backing = xp.arange(10, dtype=xp.float64)
      one_element = np.lib.stride_tricks.as_strided(backing[5:6], strides=(16,))

      filtered = signal.medfilt(one_element, 1)
      xp_assert_close(filtered, xp.asarray([5.]))
  @skip_xp_backends(
      "jax.numpy",
      reason="chunk assignment does not work on jax immutable arrays"
  )
  @pytest.mark.parametrize("dtype", ["uint8", "float32", "float64"])
  @make_xp_test_case(signal.medfilt2d)
  def test_medfilt2d_parallel(self, dtype, xp):
      """Filter four overlapping quadrants in threads and compare with OUT."""
      dtype = getattr(xp, dtype)
      in_typed = xp.asarray(self.IN, dtype=dtype)
      expected = xp.asarray(self.OUT, dtype=dtype)

      # Equal shapes keep the slice arithmetic below simple.
      assert in_typed.shape == expected.shape

      # The output is computed in four chunks split at (half_rows, half_cols).
      # Each input chunk is extended by a margin derived from half the kernel
      # size so that the full output chunk can be calculated.
      half_rows = expected.shape[0] // 2
      half_cols = expected.shape[1] // 2
      row_margin = self.KERNEL_SIZE[0] // 2 + 1
      col_margin = self.KERNEL_SIZE[1] // 2 + 1

      def axis_slices(index, split, margin):
          """Return the (input, crop, output) slices of one chunk axis."""
          if index == 0:
              return (slice(0, split + margin),
                      slice(0, -margin),
                      slice(0, split))
          return (slice(split - margin, None),
                  slice(margin, None),
                  slice(split, None))

      def apply(chunk):
          """Filter one chunk; return the cropped data and where to store it."""
          row_index, col_index = chunk
          rows_in, rows_crop, rows_out = axis_slices(
              row_index, half_rows, row_margin)
          cols_in, cols_crop, cols_out = axis_slices(
              col_index, half_cols, col_margin)

          # Only compute here; writing to the output is left to the caller.
          filtered = signal.medfilt2d(in_typed[rows_in, cols_in],
                                      self.KERNEL_SIZE)
          return filtered[rows_crop, cols_crop], rows_out, cols_out

      # One thread per chunk.
      output = xp.zeros_like(expected)
      with ThreadPoolExecutor(max_workers=4) as pool:
          chunks = {(0, 0), (0, 1), (1, 0), (1, 1)}
          futures = {pool.submit(apply, chunk) for chunk in chunks}

          # Store every result as soon as it is available.
          for finished in as_completed(futures):
              block, row_target, col_target = finished.result()
              output[row_target, col_target] = block

      xp_assert_equal(output, expected)
  @make_xp_test_case(signal.wiener)
  class TestWiener:
      """Tests for signal.wiener."""

      @skip_xp_backends("cupy", reason="XXX: can_cast in cupy <= 13.2")
      def test_basic(self, xp):
          """The default call and mysize=3 both match the reference values."""
          noisy = xp.asarray(
              [[5, 6, 4, 3],
               [3, 5, 6, 2],
               [2, 3, 5, 6],
               [1, 6, 9, 7]],
              dtype=xp.float64,
          )
          reference = xp.asarray(
              [[2.16374269, 3.2222222222, 2.8888888889, 1.6666666667],
               [2.666666667, 4.33333333333, 4.44444444444, 2.8888888888],
               [2.222222222, 4.4444444444, 5.4444444444, 4.801066874837],
               [1.33333333333, 3.92735042735, 6.0712560386, 5.0404040404]]
          )

          with_default_size = signal.wiener(noisy)
          assert_array_almost_equal(with_default_size, reference, decimal=6)

          with_explicit_size = signal.wiener(noisy, mysize=3)
          assert_array_almost_equal(with_explicit_size, reference, decimal=6)
  # padtype values used to parametrize the resample_poly tests below: the
  # options named here followed by every entry of _upfirdn_modes.
  padtype_options = ["mean", "median", "minimum", "maximum", "line"]
  padtype_options += _upfirdn_modes
  1141. class TestResample:
  @make_xp_test_case(signal.resample, signal.resample_poly)
  @xfail_xp_backends("cupy", reason="does not raise with non-int upsampling factor")
  def test_basic(self, xp):
      """Argument validation and small regression checks for both resamplers."""
      sig = xp.arange(128, dtype=xp.float64)
      num = 256
      win = signal.get_window(('kaiser', 8.0), 160, xp=xp)

      # Regression test for issue #3603: window.shape must equal
      # sig.shape[0], which is not the case here.
      assert_raises(ValueError, signal.resample, sig, num, window=win)
      assert_raises(ValueError, signal.resample, sig, num, domain='INVALID')

      # Other degenerate conditions, given as (up, down) plus keyword options.
      invalid_poly_calls = (
          (('yo', 1), {}),
          ((1, 0), {}),
          ((1.3, 2), {}),
          ((2, 1.3), {}),
          ((2, 1), {'padtype': ''}),
          ((2, 1), {'padtype': 'mean', 'cval': 10}),
          ((2, 1), {'window': xp.eye(2)}),
      )
      for up_down, options in invalid_poly_calls:
          assert_raises(ValueError, signal.resample_poly, sig, *up_down,
                        **options)

      # Issue #6505: window.shape must not be modified when axis != 0.
      two_rows = xp.tile(xp.arange(160, dtype=xp.float64), (2, 1))
      signal.resample(two_rows, num, axis=-1, window=win)
      assert win.shape == (160,)

      # Cover both cval=None and cval != None.
      reference = signal.resample_poly(sig, 2, 1)
      constant_default_cval = signal.resample_poly(sig, 2, 1, padtype='constant')
      constant_zero_cval = signal.resample_poly(sig, 2, 1, padtype='constant',
                                                cval=0)
      xp_assert_equal(constant_zero_cval, reference)
      xp_assert_equal(constant_default_cval, reference)
  @pytest.mark.parametrize('window', (None, 'hamming'))
  @pytest.mark.parametrize('N', (20, 19))
  @pytest.mark.parametrize('num', (100, 101, 10, 11))
  @make_xp_test_case(signal.resample)
  def test_rfft(self, N, num, window, xp):
      """Real input must give the same result as its complex-cast version.

      Checks that the rfft speed-up for real input agrees with the regular
      fft route that is taken for complex input.
      """
      real_dtype = xp_default_dtype(xp)
      complex_dtype = xp.complex64 if real_dtype == xp.float32 else xp.complex128
      grid = xp.linspace(0, 10, N, endpoint=False)

      # 1-D signal.
      chirp = xp.cos(-grid**2/6.0)
      complex_result = signal.resample(xp.astype(chirp, complex_dtype), num,
                                       window=window)
      xp_assert_close(signal.resample(chirp, num, window=window),
                      xp.real(complex_result))

      # Two stacked signals, resampled along axis 1.
      stacked = xp.stack([xp.cos(-grid**2/6.0), xp.sin(-grid**2/6.0)])
      stacked_complex = xp.astype(stacked, complex_dtype)
      complex_result_2d = signal.resample(stacked_complex, num, axis=1,
                                          window=window)
      tolerance = 1e-9 if real_dtype == xp.float64 else 3e-7
      xp_assert_close(
          signal.resample(stacked, num, axis=1, window=window),
          xp.real(complex_result_2d),
          atol=tolerance)
  @make_xp_test_case(signal.resample)
  def test_input_domain(self, xp):
      """domain='freq' on the FFT of a signal matches domain='time' on it."""
      num = 256
      time_signal = xp.astype(xp.arange(256), xp.complex128)
      freq_signal = sp_fft.fft(time_signal)

      from_freq = signal.resample(freq_signal, num, domain='freq')
      from_time = signal.resample(time_signal, num, domain='time')
      xp_assert_close(from_freq, from_time, atol=1e-9)
  @pytest.mark.parametrize('nx', (1, 2, 3, 5, 8))
  @pytest.mark.parametrize('ny', (1, 2, 3, 5, 8))
  @pytest.mark.parametrize('dtype', ('float64', 'complex128'))
  @make_xp_test_case(signal.resample)
  def test_dc(self, nx, ny, dtype, xp):
      """A signal of nx ones resamples to ny ones."""
      dtype = getattr(xp, dtype)
      ones_in = xp.asarray([1] * nx, dtype=dtype)
      resampled = signal.resample(ones_in, ny)
      # The expected array takes the dtype of the result.
      ones_out = xp.asarray([1] * ny, dtype=resampled.dtype)
      xp_assert_close(resampled, ones_out)
  @skip_xp_backends("cupy", reason="padtype not supported by upfirdn")
  @pytest.mark.parametrize('padtype', padtype_options)
  @make_xp_test_case(signal.resample_poly)
  def test_mutable_window(self, padtype, xp):
      """resample_poly leaves a caller-supplied window array unmodified."""
      impulse = xp.zeros(3)
      window = xp.asarray(np.random.RandomState(0).randn(2))
      snapshot = xp.asarray(window, copy=True)

      signal.resample_poly(impulse, 5, 1, window=window, padtype=padtype)

      xp_assert_equal(window, snapshot)
  @skip_xp_backends("cupy", reason="padtype not supported by upfirdn")
  @make_xp_test_case(signal.resample_poly)
  @pytest.mark.parametrize('padtype', padtype_options)
  def test_output_float32(self, padtype, xp):
      """float32 signal and float32 window give a float32 result."""
      samples = xp.arange(10, dtype=xp.float32)
      taps = xp.asarray([1, 1, 1], dtype=xp.float32)
      resampled = signal.resample_poly(samples, 1, 2, window=taps,
                                       padtype=padtype)
      assert resampled.dtype == xp.float32
  @pytest.mark.parametrize('padtype', padtype_options)
  @pytest.mark.parametrize('dtype', ['float32', 'float64'])
  @skip_xp_backends("cupy", reason="padtype not supported by upfirdn")
  @make_xp_test_case(signal.resample_poly)
  def test_output_match_dtype(self, padtype, dtype, xp):
      """The dtype of the input is preserved (issue #14733)."""
      dtype = getattr(xp, dtype)
      samples = xp.arange(10, dtype=dtype)
      resampled = signal.resample_poly(samples, 1, 2, padtype=padtype)
      assert resampled.dtype == samples.dtype
  @skip_xp_backends("cupy", reason="padtype not supported by upfirdn")
  @pytest.mark.parametrize(
      "method, ext, padtype",
      [("fft", False, None)]
      + list(
          product(
              ["polyphase"], [False, True], padtype_options,
          )
      ),
  )
  @make_xp_test_case(signal.resample, signal.resample_poly)
  def test_resample_methods(self, method, ext, padtype, xp):
      """Resample sinusoids and random noise and compare with references.

      ``method`` selects signal.resample ('fft') or signal.resample_poly
      (anything else).  With ``ext`` true, an explicitly designed Kaiser
      window is passed to resample_poly whenever the rate actually changes.
      ``padtype`` is forwarded to resample_poly only.
      """
      # Test resampling of sinusoids and random noise (1-sec)
      rate = 100
      rates_to = [49, 50, 51, 99, 100, 101, 199, 200, 201]

      # Sinusoids, windowed to avoid edge artifacts
      t = xp.arange(rate, dtype=xp.float64) / float(rate)
      # One row per test frequency; broadcasting against t gives a 2-D signal.
      freqs = xp.asarray((1., 10., 40.))[:, xp.newaxis]
      x = xp.sin(2 * xp.pi * freqs * t) * hann(rate, xp=xp)

      for rate_to in rates_to:
          # Reference: the same windowed sinusoids sampled at the target rate.
          t_to = xp.arange(rate_to, dtype=xp.float64) / float(rate_to)
          y_tos = xp.sin(2 * xp.pi * freqs * t_to) * hann(rate_to, xp=xp)
          if method == 'fft':
              y_resamps = signal.resample(x, rate_to, axis=-1)
          else:
              if ext and rate_to != rate:
                  # Match default window design
                  g = gcd(rate_to, rate)
                  up = rate_to // g
                  down = rate // g
                  max_rate = max(up, down)
                  f_c = 1. / max_rate
                  half_len = 10 * max_rate
                  window = signal.firwin(2 * half_len + 1, f_c,
                                         window=('kaiser', 5.0))
                  window = xp.asarray(window)
                  polyargs = {'window': window, 'padtype': padtype}
              else:
                  polyargs = {'padtype': padtype}

              y_resamps = signal.resample_poly(x, rate_to, rate, axis=-1,
                                               **polyargs)

          for i in range(y_tos.shape[0]):
              y_to = y_tos[i, :]
              y_resamp = y_resamps[i, :]
              freq = float(freqs[i, 0])
              if freq >= 0.5 * rate_to:
                  # At or above half the target rate the component is
                  # expected to be mostly low-passed away, so compare with
                  # zeros.
                  y_to = xp.zeros_like(y_to)  # mostly low-passed away
                  # 'minimum' / 'maximum' padding gets a looser tolerance.
                  if padtype in ['minimum', 'maximum']:
                      xp_assert_close(y_resamp, y_to, atol=3e-1)
                  else:
                      xp_assert_close(y_resamp, y_to, atol=1e-3)
              else:
                  assert y_to.shape == y_resamp.shape
                  corr = np.corrcoef(y_to, y_resamp)[0, 1]
                  assert corr > 0.99, (corr, rate, rate_to)

      # Random data
      rng = np.random.RandomState(0)
      x = hann(rate) * np.cumsum(rng.randn(rate))  # low-pass, wind
      x = xp.asarray(x)
      for rate_to in rates_to:
          # random data
          t_to = xp.arange(rate_to, dtype=xp.float64) / float(rate_to)
          # Reference: linear interpolation of x onto the target time grid.
          y_to = np.interp(t_to, t, x)
          if method == 'fft':
              y_resamp = signal.resample(x, rate_to)
          else:
              y_resamp = signal.resample_poly(x, rate_to, rate,
                                              padtype=padtype)
          assert y_to.shape == y_resamp.shape
          corr = xp.asarray(np.corrcoef(y_to, y_resamp)[0, 1])
          assert corr > 0.99, corr

      # More tests of fft method (Master 0.18.1 fails these)
      if method == 'fft':
          x1 = xp.asarray([1.+0.j, 0.+0.j])
          y1_test = signal.resample(x1, 4)
          # upsampling a complex array
          y1_true = xp.asarray([1.+0.j, 0.5+0.j, 0.+0.j, 0.5+0.j])
          xp_assert_close(y1_test, y1_true, atol=1e-12)
          x2 = xp.asarray([1., 0.5, 0., 0.5])
          y2_test = signal.resample(x2, 2)  # downsampling a real array
          y2_true = xp.asarray([1., 0.])
          xp_assert_close(y2_test, y2_true, atol=1e-12)
  @pytest.mark.parametrize("n_in", (8, 9))
  @pytest.mark.parametrize("n_out", (3, 4))
  @make_xp_test_case(signal.resample)
  def test_resample_win_func(self, n_in, n_out):
      """A callable window returning 0.5 everywhere halves the result."""
      def half_gain(freqs):
          """Return an array of 0.5 shaped like `freqs`."""
          return 0.5 * np.ones_like(freqs)

      ones_in = np.ones(n_in)
      unwindowed = signal.resample(ones_in, n_out)
      halved = signal.resample(ones_in, n_out, window=half_gain)
      xp_assert_close(2*halved, unwindowed, atol=1e-12)
  @pytest.mark.parametrize("n_in", (6, 12))
  @pytest.mark.parametrize("n_out", (3, 4))
  @make_xp_test_case(signal.resample)
  def test__resample_param_t(self, n_in, n_out):
      """Check the time vector returned when parameter `t` is given.

      Only `t[0]` and `t[1]` are passed, which is all that is utilized.
      """
      start, step = 10, 2
      ones_in = np.ones(n_in)

      without_t = signal.resample(ones_in, n_out)
      with_t, new_times = signal.resample(ones_in, n_out,
                                          t=[start, start+step])
      # The new spacing is the old one scaled by n_in / n_out.
      expected_times = 10 + np.arange(len(without_t)) * step * n_in / n_out

      xp_assert_equal(with_t, without_t)  # no influence of `t`
      xp_assert_close(new_times, expected_times, atol=1e-12)
  @pytest.mark.parametrize("n1", (2, 3, 7, 8))
  @pytest.mark.parametrize("n0", (2, 3, 7, 8))
  @make_xp_test_case(signal.resample)
  def test_resample_nyquist(self, n0, n1):
      """Behavior at the Nyquist frequency (regression test for #14569)."""
      f_ny = min(n0, n1) // 2
      # The same cosine sampled on the source grid and on the target grid.
      source = np.cos(2 * np.pi * f_ny * (np.arange(n0) / n0))
      target = np.cos(2 * np.pi * f_ny * (np.arange(n1) / n1))

      from_real = signal.resample(source, n1)
      from_complex = signal.resample(source + 0j, n1)

      xp_assert_close(from_real, target, atol=1e-12)
      xp_assert_close(from_complex.real, target, atol=1e-12)
  @pytest.mark.parametrize('down_factor', [2, 11, 79])
  @pytest.mark.parametrize("dtype", [int, np.float32, np.complex64, float, complex])
  @make_xp_test_case(signal.resample_poly)
  def test_poly_vs_filtfilt(self, down_factor, dtype, xp):
      """resample_poly with up=1 matches filtfilt followed by slicing.

      ``dtype`` is the dtype of the random input signal and ``down_factor``
      the decimation factor.
      """
      random_state = np.random.RandomState(17)
      size = 10000

      x = random_state.randn(size).astype(dtype)
      # np.issubdtype also recognises the builtin `complex`; a membership
      # test against (np.complex64, np.complex128) does not, which left that
      # case without an imaginary part.
      if np.issubdtype(dtype, np.complexfloating):
          x += 1j * random_state.randn(size)

      # resample_poly assumes zeros outside of signal, whereas filtfilt
      # can only constant-pad. Make them equivalent:
      x[0] = 0
      x[-1] = 0

      h = signal.firwin(31, 1. / down_factor, window='hamming')
      yf = filtfilt(h, 1.0, x, padtype='constant')[::down_factor]

      # Need to pass convolved version of filter to resample_poly,
      # since filtfilt does forward and backward, but resample_poly
      # only goes forward
      hc = convolve(h, np.flip(h))

      # Use yf.copy() to avoid negative strides, which are unsupported
      # in torch.
      x, hc, yf = map(xp.asarray, (x, hc, yf.copy()))
      y = signal.resample_poly(x, 1, down_factor, window=hc)
      xp_assert_close(yf, y, atol=3e-7, rtol=6e-7)
  @make_xp_test_case(signal.resample_poly)
  def test_correlate1d(self, xp):
      """resample_poly(up=1) equals a decimated correlate1d result.

      For several signal lengths and both an even and an odd number of
      weights, the output of resample_poly is compared with every
      ``down``-th sample of correlate1d applied with the flipped weights.
      """
      # Seeded generator, so that a failure can be reproduced.
      rng = np.random.RandomState(0)
      for down in [2, 4]:
          for nx in range(1, 40, down):
              for nweights in (32, 33):
                  x = rng.random_sample((nx,))
                  weights = rng.random_sample((nweights,))
                  y_g = correlate1d(x, np.flip(weights), mode='constant')
                  x, weights, y_g = map(xp.asarray, (x, weights, y_g))
                  y_s = signal.resample_poly(
                      x, up=1, down=down, window=weights)
                  xp_assert_close(y_g[::down], y_s)
  @make_xp_test_case(signal.resample_poly)
  @pytest.mark.parametrize('dtype', ['int32', 'float32'])
  @skip_xp_backends("cupy", reason="padtype not supported by upfirdn")
  def test_gh_15620(self, dtype, xp):
      """Upsampling with padtype='smooth' gives some nonzero output (gh-15620)."""
      dtype = getattr(xp, dtype)
      triangle = xp.asarray([0, 1, 2, 3, 2, 1, 0], dtype=dtype)
      upsampled = signal.resample_poly(triangle, up=2, down=1,
                                       padtype='smooth')
      assert np.count_nonzero(upsampled) > 0
  @make_xp_test_case(signal.cspline1d_eval)
  class TestCSpline1DEval:
      """Tests for signal.cspline1d_eval."""

      def test_basic(self, xp):
          """Evaluating the spline at the knots reproduces the samples."""
          y = np.asarray([1, 2, 3, 4, 3, 2, 1, 2, 3.0])
          x = np.arange(y.shape[0])
          dx = x[1] - x[0]
          cj = xp.asarray(signal.cspline1d(y))

          # Ten evaluation points per knot spacing.
          x2 = xp.arange(len(y) * 10.0) / 10.0
          y2 = signal.cspline1d_eval(cj, x2, dx=dx, x0=x[0])

          # make sure interpolated values are on knot points
          assert_array_almost_equal(y2[::10], xp.asarray(y), decimal=5)

      def test_complex(self, xp):
          """The result has the same dtype as the complex input signal."""
          # create some smoothly varying complex signal to interpolate
          x = np.arange(2.0)
          T = 10.0
          f = 1.0 / T
          y = np.exp(2.0J * np.pi * f * x)

          # get the cspline transform
          cy = xp.asarray(signal.cspline1d(y))

          # determine new test x value and interpolate
          xnew = xp.asarray([0.5])
          ynew = signal.cspline1d_eval(cy, xnew)

          assert ynew.dtype == xp.asarray(y).dtype
  @make_xp_test_case(signal.order_filter)
  class TestOrderFilt:
      """Tests for signal.order_filter."""

      def test_basic(self, xp):
          """Rank 1 over the two outer neighbours of a 1-D input."""
          values = xp.asarray([1, 2, 3])
          footprint = xp.asarray([1, 0, 1])
          filtered = signal.order_filter(values, footprint, 1)
          xp_assert_equal(filtered, xp.asarray([2, 3, 2]))

      def test_doc_example(self, xp):
          """Minimum, maximum and median over a 3x3 identity domain."""
          float_dtype = xp_default_dtype(xp)
          grid = xp.reshape(xp.arange(25, dtype=float_dtype), (5, 5))
          diagonal = xp.eye(3, dtype=float_dtype)

          # Rank 0: minimum of the (zero-padded) diagonal elements.
          minimum = xp.asarray(
              [[0., 0., 0., 0., 0.],
               [0., 0., 1., 2., 0.],
               [0., 5., 6., 7., 0.],
               [0., 10., 11., 12., 0.],
               [0., 0., 0., 0., 0.]],
              dtype=float_dtype
          )
          xp_assert_close(signal.order_filter(grid, diagonal, 0), minimum)

          # Rank 2: maximum of the (zero-padded) diagonal elements.
          maximum = xp.asarray(
              [[6., 7., 8., 9., 4.],
               [11., 12., 13., 14., 9.],
               [16., 17., 18., 19., 14.],
               [21., 22., 23., 24., 19.],
               [20., 21., 22., 23., 24.]],
          )
          xp_assert_close(signal.order_filter(grid, diagonal, 2), maximum)

          # Rank 1 completes the set: median of the three diagonal elements.
          median = xp.asarray(
              [[0, 1, 2, 3, 0],
               [5, 6, 7, 8, 3],
               [10, 11, 12, 13, 8],
               [15, 16, 17, 18, 13],
               [0, 15, 16, 17, 18]],
              dtype=float_dtype
          )
          xp_assert_close(signal.order_filter(grid, diagonal, 1), median)

      @xfail_xp_backends('dask.array', reason='repeat requires an axis')
      @xfail_xp_backends('torch', reason='array-api-compat#292')
      @make_xp_test_case(signal.medfilt)
      def test_medfilt_order_filter(self, xp):
          """medfilt(x, 3) and rank 4 over a full 3x3 domain share one result."""
          grid = xp.reshape(xp.arange(25), (5, 5))

          # Both calls are compared against the same expected array.
          median = xp.asarray(
              [[0, 1, 2, 3, 0],
               [1, 6, 7, 8, 4],
               [6, 11, 12, 13, 9],
               [11, 16, 17, 18, 14],
               [0, 16, 17, 18, 0]],
          )
          xp_assert_close(signal.medfilt(grid, 3), median)
          xp_assert_close(
              signal.order_filter(grid, xp.ones((3, 3)), 4),
              median
          )

      def test_order_filter_asymmetric(self, xp):
          """Ranks 0 and 1 over a domain that is not symmetric."""
          grid = xp.reshape(xp.arange(25), (5, 5))
          lopsided = xp.asarray(
              [[1, 1, 0],
               [0, 1, 0],
               [0, 0, 0]],
          )

          rank0 = xp.asarray(
              [[0, 0, 0, 0, 0],
               [0, 0, 1, 2, 3],
               [0, 5, 6, 7, 8],
               [0, 10, 11, 12, 13],
               [0, 15, 16, 17, 18]]
          )
          xp_assert_close(signal.order_filter(grid, lopsided, 0), rank0)

          rank1 = xp.asarray(
              [[0, 0, 0, 0, 0],
               [0, 1, 2, 3, 4],
               [5, 6, 7, 8, 9],
               [10, 11, 12, 13, 14],
               [15, 16, 17, 18, 19]]
          )
          xp_assert_close(signal.order_filter(grid, lopsided, 1), rank1)
  1517. @make_xp_test_case(lfilter)
  1518. class _TestLinearFilter:
  def generate(self, shape, xp):
      """Return 0, 1, ..., n-1 arranged as `shape`, passed through convert_dtype.

      `shape` is either an int (1-D result of that length) or a sequence of
      dimensions whose product gives the number of elements.
      """
      is_int_shape = isinstance(shape, int)
      count = shape if is_int_shape else math.prod(shape)
      ramp = xp.linspace(0, count - 1, count)
      if not is_int_shape:
          ramp = xp.reshape(ramp, shape)
      return self.convert_dtype(ramp, xp)
  def convert_dtype(self, arr, xp):
      """Convert `arr` to the dtype under test (`self.dtype`).

      For the object dtype the result is a NumPy object array in which
      every element has been passed through `self.type`.  Otherwise `arr`
      is converted with `xp.asarray`; a string `self.dtype` is looked up as
      an attribute of `xp` first.
      """
      if self.dtype == np.dtype('O'):
          arr = np.asarray(arr)
          out = np.empty(arr.shape, self.dtype)
          # Named `pairs` rather than `iter` so the builtin is not shadowed.
          pairs = np.nditer([arr, out], ['refs_ok', 'zerosize_ok'],
                            [['readonly'], ['writeonly']])
          for src, dst in pairs:
              # src[()] extracts the scalar from the 0-d view.
              dst[...] = self.type(src[()])
          return out

      dtype = (getattr(xp, self.dtype)
               if isinstance(self.dtype, str)
               else self.dtype)
      return xp.asarray(arr, dtype=dtype)
  @skip_xp_backends('cupy', reason='XXX https://github.com/scipy/scipy/issues/23539')
  def test_invalid_params(self, xp):
      """Verify all exceptions are raised.

      Covers a `b` or `a` that is not one-dimensional or is empty
      (ValueError) and uint64 parameters (NotImplementedError).
      """
      b, a, x = xp.asarray([1]), xp.asarray([2]), xp.asarray([3, 4])

      with pytest.raises(ValueError, match="^Parameter b is not"):
          lfilter(xp.eye(2), a, x)  # b not one-dimensional
      with pytest.raises(ValueError, match="^Parameter b is not"):
          lfilter(xp.asarray([]), a, x)  # b empty
      with pytest.raises(ValueError, match="^Parameter a is not"):
          lfilter(b, xp.eye(2), x)  # a not one-dimensional
      with pytest.raises(ValueError, match="^Parameter a is not"):
          lfilter(b, xp.asarray([]), x)  # a empty

      # The cast happens outside the raises block so that only lfilter
      # itself can satisfy the expectation.
      b_u64, a_u64, x_u64 = (xp.astype(v_, xp.uint64, copy=False)
                             for v_ in (b, a, x))
      with pytest.raises(NotImplementedError, match="^Parameter's dtypes produced "):
          lfilter(b_u64, a_u64, x_u64)  # fails with uint64 dtype
  def test_rank_1_IIR(self, xp):
      """1-D input filtered with two-tap numerator and denominator."""
      samples = self.generate((6,), xp)
      numerator = self.convert_dtype([1, -1], xp)
      denominator = self.convert_dtype([0.5, -0.5], xp)
      expected = self.convert_dtype([0, 2, 4, 6, 8, 10.], xp)

      filtered = lfilter(numerator, denominator, samples)
      assert_array_almost_equal(filtered, expected)
  def test_rank_1_FIR(self, xp):
      """1-D input filtered with a two-tap numerator and a == [1]."""
      samples = self.generate((6,), xp)
      numerator = self.convert_dtype([1, 1], xp)
      denominator = self.convert_dtype([1], xp)
      expected = self.convert_dtype([0, 1, 3, 5, 7, 9.], xp)

      filtered = lfilter(numerator, denominator, samples)
      assert_array_almost_equal(filtered, expected)
  @skip_xp_backends('cupy', reason='XXX https://github.com/cupy/cupy/pull/8677')
  def test_rank_1_IIR_init_cond(self, xp):
      """1-D IIR case with initial state: check output and final state."""
      samples = self.generate((6,), xp)
      numerator = self.convert_dtype([1, 0, -1], xp)
      denominator = self.convert_dtype([0.5, -0.5], xp)
      initial_state = self.convert_dtype([1, 2], xp)
      expected_output = self.convert_dtype([1, 5, 9, 13, 17, 21], xp)
      expected_state = self.convert_dtype([13, -10], xp)

      output, final_state = lfilter(numerator, denominator, samples,
                                    zi=initial_state)

      assert_array_almost_equal(output, expected_output)
      assert_array_almost_equal(final_state, expected_state)
  @skip_xp_backends('cupy', reason='XXX https://github.com/cupy/cupy/pull/8677')
  def test_rank_1_FIR_init_cond(self, xp):
      """1-D FIR case with initial state: check output and final state."""
      samples = self.generate((6,), xp)
      numerator = self.convert_dtype([1, 1, 1], xp)
      denominator = self.convert_dtype([1], xp)
      initial_state = self.convert_dtype([1, 1], xp)
      expected_output = self.convert_dtype([1, 2, 3, 6, 9, 12.], xp)
      expected_state = self.convert_dtype([9, 5], xp)

      output, final_state = lfilter(numerator, denominator, samples,
                                    zi=initial_state)

      assert_array_almost_equal(output, expected_output)
      assert_array_almost_equal(final_state, expected_state)
  def test_rank_2_IIR_axis_0(self, xp):
      """(4, 3) input filtered along axis 0."""
      samples = self.generate((4, 3), xp)
      numerator = self.convert_dtype([1, -1], xp)
      denominator = self.convert_dtype([0.5, 0.5], xp)
      expected = self.convert_dtype(
          [[0, 2, 4], [6, 4, 2], [0, 2, 4], [6, 4, 2]], xp)

      filtered = lfilter(numerator, denominator, samples, axis=0)
      assert_array_almost_equal(expected, filtered)
  def test_rank_2_IIR_axis_1(self, xp):
      """(4, 3) input filtered along axis 1."""
      samples = self.generate((4, 3), xp)
      numerator = self.convert_dtype([1, -1], xp)
      denominator = self.convert_dtype([0.5, 0.5], xp)
      expected = self.convert_dtype(
          [[0, 2, 0], [6, -4, 6], [12, -10, 12], [18, -16, 18]], xp)

      filtered = lfilter(numerator, denominator, samples, axis=1)
      assert_array_almost_equal(expected, filtered)
  @skip_xp_backends('cupy', reason='XXX https://github.com/cupy/cupy/pull/8677')
  def test_rank_2_IIR_axis_0_init_cond(self, xp):
      """(4, 3) input with a (4, 1) initial state.

      Note that, despite the name, the filter is applied along axis=1.
      """
      samples = self.generate((4, 3), xp)
      numerator = self.convert_dtype([1, -1], xp)
      denominator = self.convert_dtype([0.5, 0.5], xp)
      initial_state = self.convert_dtype(np.ones((4,1)), xp)
      expected_output = self.convert_dtype(
          [[1, 1, 1], [7, -5, 7], [13, -11, 13], [19, -17, 19]], xp)
      # Column vector, to line up with the (4, 1) state.
      expected_state = self.convert_dtype(
          [-5, -17, -29, -41], xp)[:, np.newaxis]

      output, final_state = lfilter(numerator, denominator, samples,
                                    axis=1, zi=initial_state)

      assert_array_almost_equal(expected_output, output)
      assert_array_almost_equal(final_state, expected_state)
  @skip_xp_backends('cupy', reason='XXX https://github.com/cupy/cupy/pull/8677')
  def test_rank_2_IIR_axis_1_init_cond(self, xp):
      """(4, 3) input with a (1, 3) initial state.

      Note that, despite the name, the filter is applied along axis=0.
      """
      samples = self.generate((4, 3), xp)
      numerator = self.convert_dtype([1, -1], xp)
      denominator = self.convert_dtype([0.5, 0.5], xp)
      initial_state = self.convert_dtype(np.ones((1, 3)), xp)
      expected_output = self.convert_dtype(
          [[1, 3, 5], [5, 3, 1], [1, 3, 5], [5, 3, 1]], xp)
      expected_state = self.convert_dtype([[-23, -23, -23]], xp)

      output, final_state = lfilter(numerator, denominator, samples,
                                    axis=0, zi=initial_state)

      assert_array_almost_equal(expected_output, output)
      assert_array_almost_equal(final_state, expected_state)
  def test_rank_3_IIR(self, xp):
      """3-D input: each axis must match filtering the 1-D slices in NumPy."""
      samples = self.generate((4, 3, 2), xp)
      numerator = self.convert_dtype([1, -1], xp)
      denominator = self.convert_dtype([0.5, 0.5], xp)
      den_np, num_np, samples_np = map(
          _xp_copy_to_numpy, (denominator, numerator, samples))

      def filter_slice(vector):
          """Reference: filter a single 1-D slice."""
          return lfilter(num_np, den_np, vector)

      for axis in range(samples.ndim):
          filtered = lfilter(numerator, denominator, samples, axis)
          reference = np.apply_along_axis(filter_slice, axis, samples_np)
          assert_array_almost_equal(filtered, xp.asarray(reference))
  @xfail_xp_backends("cupy", reason="inaccurate")
  def test_rank_3_IIR_init_cond(self, xp):
      """3-D IIR case with initial conditions, checked along every axis.

      Output and final state are compared with a reference obtained by
      filtering each 1-D slice of a NumPy copy of the input, using a
      one-element initial state.
      """
      x = self.generate((4, 3, 2), xp)
      b = self.convert_dtype([1, -1], xp)
      a = self.convert_dtype([0.5, 0.5], xp)
      zi1 = self.convert_dtype([1], xp)

      # None of these depend on the axis, so they are built once up front.
      x_np, b_np, a_np, zi1_np = map(_xp_copy_to_numpy, (x, b, a, zi1))

      def lf0(w):
          """Reference output for a single 1-D slice."""
          return lfilter(b_np, a_np, w, zi=zi1_np)[0]

      def lf1(w):
          """Reference final state for a single 1-D slice."""
          return lfilter(b_np, a_np, w, zi=zi1_np)[1]

      for axis in range(x.ndim):
          # All-ones initial state with length 1 along the filtered axis.
          zi_shape = list(x.shape)
          zi_shape[axis] = 1
          zi = self.convert_dtype(xp.ones(zi_shape), xp)

          y, zf = lfilter(b, a, x, axis, zi)

          y_r = np.apply_along_axis(lf0, axis, x_np)
          zf_r = np.apply_along_axis(lf1, axis, x_np)
          assert_array_almost_equal(y, xp.asarray(y_r))
          assert_array_almost_equal(zf, xp.asarray(zf_r))
  def test_rank_3_FIR(self, xp):
      """3-D input, a == [1]: each axis must match 1-D filtering in NumPy."""
      samples = self.generate((4, 3, 2), xp)
      numerator = self.convert_dtype([1, 0, -1], xp)
      denominator = self.convert_dtype([1], xp)
      den_np, num_np, samples_np = map(
          _xp_copy_to_numpy, (denominator, numerator, samples))

      def filter_slice(vector):
          """Reference: filter a single 1-D slice."""
          return lfilter(num_np, den_np, vector)

      for axis in range(samples.ndim):
          filtered = lfilter(numerator, denominator, samples, axis)
          reference = np.apply_along_axis(filter_slice, axis, samples_np)
          assert_array_almost_equal(filtered, xp.asarray(reference))
  @xfail_xp_backends("cupy", reason="inaccurate")
  def test_rank_3_FIR_init_cond(self, xp):
      """3-D FIR case with initial conditions, checked along every axis.

      Output and final state are compared with a reference obtained by
      filtering each 1-D slice of a NumPy copy of the input, using a
      two-element initial state.
      """
      x = self.generate((4, 3, 2), xp)
      b = self.convert_dtype([1, 0, -1], xp)
      a = self.convert_dtype([1], xp)
      zi1 = self.convert_dtype([1, 1], xp)

      # Axis-independent NumPy copies, converted exactly once.
      x_np, b_np, a_np, zi1_np = map(_xp_copy_to_numpy, (x, b, a, zi1))

      def lf0(w):
          """Reference output for a single 1-D slice."""
          return lfilter(b_np, a_np, w, zi=zi1_np)[0]

      def lf1(w):
          """Reference final state for a single 1-D slice."""
          return lfilter(b_np, a_np, w, zi=zi1_np)[1]

      for axis in range(x.ndim):
          # All-ones initial state with length 2 along the filtered axis.
          zi_shape = list(x.shape)
          zi_shape[axis] = 2
          zi = self.convert_dtype(xp.ones(zi_shape), xp)

          y, zf = lfilter(b, a, x, axis, zi)

          y_r = np.apply_along_axis(lf0, axis, x_np)
          zf_r = np.apply_along_axis(lf1, axis, x_np)
          assert_array_almost_equal(y, xp.asarray(y_r))
          assert_array_almost_equal(zf, xp.asarray(zf_r))
  @skip_xp_backends('cupy', reason='XXX https://github.com/cupy/cupy/pull/8677')
  def test_zi_pseudobroadcast(self, xp):
      """A zi with length-1 leading dimensions behaves like the full zi."""
      samples = self.generate((4, 5, 20), xp)
      numerator, denominator = signal.butter(8, 0.2, output='ba')
      numerator = self.convert_dtype(numerator, xp)
      denominator = self.convert_dtype(denominator, xp)
      state_len = numerator.shape[0] - 1

      # lfilter requires x.ndim == zi.ndim exactly. However, zi can have
      # length 1 dimensions.
      state_full = self.convert_dtype(xp.ones((4, 5, state_len)), xp)
      state_single = self.convert_dtype(xp.ones((1, 1, state_len)), xp)

      out_full, final_full = lfilter(numerator, denominator, samples,
                                     zi=state_full)
      out_single, final_single = lfilter(numerator, denominator, samples,
                                         zi=state_single)

      assert_array_almost_equal(out_single, out_full)
      assert_array_almost_equal(final_full, final_single)

      # lfilter does not prepend ones, so a 1-D zi is rejected.
      assert_raises(ValueError, lfilter, numerator, denominator, samples, -1,
                    xp.ones(state_len))
  @skip_xp_backends('cupy', reason='XXX https://github.com/cupy/cupy/pull/8677')
  def test_scalar_a(self, xp):
      """`a` may be given as a scalar instead of a 1-D array."""
      samples = self.generate(6, xp)
      numerator = self.convert_dtype([1, 0, -1], xp)
      denominator = self.convert_dtype([1], xp)
      expected = self.convert_dtype([0, 1, 2, 2, 2, 2], xp)

      # Index element 0 so that a scalar is passed.
      filtered = lfilter(numerator, denominator[0], samples)
      assert_array_almost_equal(filtered, expected)
  @skip_xp_backends('cupy', reason='XXX https://github.com/cupy/cupy/pull/8677')
  def test_zi_some_singleton_dims(self, xp):
      """zi with only some singleton axes is expanded correctly (gh-4681).

      lfilter doesn't really broadcast (no prepending of 1's), but it does
      do singleton expansion if x and zi have the same ndim.  This was
      broken only if a subset of the axes were singletons.
      """
      samples = self.convert_dtype(xp.zeros((3, 2, 5), dtype=xp.int64), xp)
      numerator = self.convert_dtype(xp.ones(5, dtype=xp.int64), xp)
      denominator = self.convert_dtype(xp.asarray([1, 0, 0]), xp)

      # Initial state: ones, twos and threes along the first axis.
      state_np = np.ones((3, 1, 4), dtype=np.int64)
      for row, factor in ((1, 2), (2, 3)):
          state_np[row, :, :] *= factor
      initial_state = self.convert_dtype(xp.asarray(state_np), xp)

      expected_state = self.convert_dtype(
          xp.zeros((3, 2, 4), dtype=xp.int64), xp)
      expected_np = np.zeros((3, 2, 5), dtype=np.int64)
      expected_np[:, :, :4] = [[[1]], [[2]], [[3]]]
      expected_output = self.convert_dtype(xp.asarray(expected_np), xp)

      # IIR
      out_iir, final_iir = lfilter(numerator, denominator, samples, -1,
                                   initial_state)
      assert_array_almost_equal(out_iir, expected_output)
      assert_array_almost_equal(final_iir, expected_state)

      # FIR
      out_fir, final_fir = lfilter(numerator, denominator[0], samples, -1,
                                   initial_state)
      assert_array_almost_equal(out_fir, expected_output)
      assert_array_almost_equal(final_fir, expected_state)
  1740. def base_bad_size_zi(self, b, a, x, axis, zi, xp):
  1741. b = self.convert_dtype(b, xp)
  1742. a = self.convert_dtype(a, xp)
  1743. x = self.convert_dtype(x, xp)
  1744. zi = self.convert_dtype(zi, xp)
  1745. assert_raises(ValueError, lfilter, b, a, x, axis, zi)
  1746. @skip_xp_backends('cupy', reason='cupy does not raise')
  1747. def test_bad_size_zi(self, xp):
  1748. # rank 1
  1749. x1 = xp.arange(6)
  1750. self.base_bad_size_zi([1], [1], x1, -1, [1], xp)
  1751. self.base_bad_size_zi([1, 1], [1], x1, -1, [0, 1], xp)
  1752. self.base_bad_size_zi([1, 1], [1], x1, -1, [[0]], xp)
  1753. self.base_bad_size_zi([1, 1], [1], x1, -1, [0, 1, 2], xp)
  1754. self.base_bad_size_zi([1, 1, 1], [1], x1, -1, [[0]], xp)
  1755. self.base_bad_size_zi([1, 1, 1], [1], x1, -1, [0, 1, 2], xp)
  1756. self.base_bad_size_zi([1], [1, 1], x1, -1, [0, 1], xp)
  1757. self.base_bad_size_zi([1], [1, 1], x1, -1, [[0]], xp)
  1758. self.base_bad_size_zi([1], [1, 1], x1, -1, [0, 1, 2], xp)
  1759. self.base_bad_size_zi([1, 1, 1], [1, 1], x1, -1, [0], xp)
  1760. self.base_bad_size_zi([1, 1, 1], [1, 1], x1, -1, [[0], [1]], xp)
  1761. self.base_bad_size_zi([1, 1, 1], [1, 1], x1, -1, [0, 1, 2], xp)
  1762. self.base_bad_size_zi([1, 1, 1], [1, 1], x1, -1, [0, 1, 2, 3], xp)
  1763. self.base_bad_size_zi([1, 1], [1, 1, 1], x1, -1, [0], xp)
  1764. self.base_bad_size_zi([1, 1], [1, 1, 1], x1, -1, [[0], [1]], xp)
  1765. self.base_bad_size_zi([1, 1], [1, 1, 1], x1, -1, [0, 1, 2], xp)
  1766. self.base_bad_size_zi([1, 1], [1, 1, 1], x1, -1, [0, 1, 2, 3], xp)
  1767. # rank 2
  1768. x2 = np.arange(12).reshape((4,3))
  1769. x2 = xp.asarray(x2)
  1770. # for axis=0 zi.shape should == (max(len(a),len(b))-1, 3)
  1771. self.base_bad_size_zi([1], [1], x2, 0, [0], xp)
  1772. # for each of these there are 5 cases tested (in this order):
  1773. # 1. not deep enough, right # elements
  1774. # 2. too deep, right # elements
  1775. # 3. right depth, right # elements, transposed
  1776. # 4. right depth, too few elements
  1777. # 5. right depth, too many elements
  1778. self.base_bad_size_zi([1, 1], [1], x2, 0, [0, 1, 2], xp)
  1779. self.base_bad_size_zi([1, 1], [1], x2, 0, [[[0, 1, 2]]], xp)
  1780. self.base_bad_size_zi([1, 1], [1], x2, 0, [[0], [1], [2]], xp)
  1781. self.base_bad_size_zi([1, 1], [1], x2, 0, [[0, 1]], xp)
  1782. self.base_bad_size_zi([1, 1], [1], x2, 0, [[0, 1, 2, 3]], xp)
  1783. self.base_bad_size_zi([1, 1, 1], [1], x2, 0, [0, 1, 2, 3, 4, 5], xp)
  1784. self.base_bad_size_zi([1, 1, 1], [1], x2, 0, [[[0, 1, 2], [3, 4, 5]]], xp)
  1785. self.base_bad_size_zi([1, 1, 1], [1], x2, 0, [[0, 1], [2, 3], [4, 5]], xp)
  1786. self.base_bad_size_zi([1, 1, 1], [1], x2, 0, [[0, 1], [2, 3]], xp)
  1787. self.base_bad_size_zi([1, 1, 1], [1], x2, 0, [[0, 1, 2, 3], [4, 5, 6, 7]], xp)
  1788. self.base_bad_size_zi([1], [1, 1], x2, 0, [0, 1, 2], xp)
  1789. self.base_bad_size_zi([1], [1, 1], x2, 0, [[[0, 1, 2]]], xp)
  1790. self.base_bad_size_zi([1], [1, 1], x2, 0, [[0], [1], [2]], xp)
  1791. self.base_bad_size_zi([1], [1, 1], x2, 0, [[0, 1]], xp)
  1792. self.base_bad_size_zi([1], [1, 1], x2, 0, [[0, 1, 2, 3]], xp)
  1793. self.base_bad_size_zi([1], [1, 1, 1], x2, 0, [0, 1, 2, 3, 4, 5], xp)
  1794. self.base_bad_size_zi([1], [1, 1, 1], x2, 0, [[[0, 1, 2], [3, 4, 5]]], xp)
  1795. self.base_bad_size_zi([1], [1, 1, 1], x2, 0, [[0, 1], [2, 3], [4, 5]], xp)
  1796. self.base_bad_size_zi([1], [1, 1, 1], x2, 0, [[0, 1], [2, 3]], xp)
  1797. self.base_bad_size_zi([1], [1, 1, 1], x2, 0, [[0, 1, 2, 3], [4, 5, 6, 7]], xp)
  1798. self.base_bad_size_zi([1, 1, 1], [1, 1], x2, 0, [0, 1, 2, 3, 4, 5], xp)
  1799. self.base_bad_size_zi([1, 1, 1], [1, 1], x2, 0, [[[0, 1, 2], [3, 4, 5]]], xp)
  1800. self.base_bad_size_zi([1, 1, 1], [1, 1], x2, 0, [[0, 1], [2, 3], [4, 5]], xp)
  1801. self.base_bad_size_zi([1, 1, 1], [1, 1], x2, 0, [[0, 1], [2, 3]], xp)
  1802. self.base_bad_size_zi([1, 1, 1], [1, 1], x2, 0,
  1803. [[0, 1, 2, 3], [4, 5, 6, 7]], xp)
  1804. # for axis=1 zi.shape should == (4, max(len(a),len(b))-1)
  1805. self.base_bad_size_zi([1], [1], x2, 1, [0], xp)
  1806. self.base_bad_size_zi([1, 1], [1], x2, 1, [0, 1, 2, 3], xp)
  1807. self.base_bad_size_zi([1, 1], [1], x2, 1, [[[0], [1], [2], [3]]], xp)
  1808. self.base_bad_size_zi([1, 1], [1], x2, 1, [[0, 1, 2, 3]], xp)
  1809. self.base_bad_size_zi([1, 1], [1], x2, 1, [[0], [1], [2]], xp)
  1810. self.base_bad_size_zi([1, 1], [1], x2, 1, [[0], [1], [2], [3], [4]], xp)
  1811. self.base_bad_size_zi([1, 1, 1], [1], x2, 1, [0, 1, 2, 3, 4, 5, 6, 7], xp)
  1812. self.base_bad_size_zi([1, 1, 1], [1], x2, 1,
  1813. [[[0, 1], [2, 3], [4, 5], [6, 7]]], xp)
  1814. self.base_bad_size_zi([1, 1, 1], [1], x2, 1, [[0, 1, 2, 3], [4, 5, 6, 7]], xp)
  1815. self.base_bad_size_zi([1, 1, 1], [1], x2, 1, [[0, 1], [2, 3], [4, 5]], xp)
  1816. self.base_bad_size_zi([1, 1, 1], [1], x2, 1,
  1817. [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]], xp)
  1818. self.base_bad_size_zi([1], [1, 1], x2, 1, [0, 1, 2, 3], xp)
  1819. self.base_bad_size_zi([1], [1, 1], x2, 1, [[[0], [1], [2], [3]]], xp)
  1820. self.base_bad_size_zi([1], [1, 1], x2, 1, [[0, 1, 2, 3]], xp)
  1821. self.base_bad_size_zi([1], [1, 1], x2, 1, [[0], [1], [2]], xp)
  1822. self.base_bad_size_zi([1], [1, 1], x2, 1, [[0], [1], [2], [3], [4]], xp)
  1823. self.base_bad_size_zi([1], [1, 1, 1], x2, 1, [0, 1, 2, 3, 4, 5, 6, 7], xp)
  1824. self.base_bad_size_zi([1], [1, 1, 1], x2, 1,
  1825. [[[0, 1], [2, 3], [4, 5], [6, 7]]], xp)
  1826. self.base_bad_size_zi([1], [1, 1, 1], x2, 1, [[0, 1, 2, 3], [4, 5, 6, 7]], xp)
  1827. self.base_bad_size_zi([1], [1, 1, 1], x2, 1, [[0, 1], [2, 3], [4, 5]], xp)
  1828. self.base_bad_size_zi([1], [1, 1, 1], x2, 1, [[0, 1],
  1829. [2, 3], [4, 5], [6, 7], [8, 9]], xp)
  1830. self.base_bad_size_zi([1, 1, 1], [1, 1], x2, 1, [0, 1, 2, 3, 4, 5, 6, 7], xp)
  1831. self.base_bad_size_zi([1, 1, 1], [1, 1], x2, 1,
  1832. [[[0, 1], [2, 3], [4, 5], [6, 7]]], xp)
  1833. self.base_bad_size_zi([1, 1, 1], [1, 1], x2, 1,
  1834. [[0, 1, 2, 3], [4, 5, 6, 7]], xp)
  1835. self.base_bad_size_zi([1, 1, 1], [1, 1], x2, 1, [[0, 1], [2, 3], [4, 5]], xp)
  1836. self.base_bad_size_zi([1, 1, 1], [1, 1], x2, 1,
  1837. [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]], xp)
  1838. def test_empty_zi(self, xp):
  1839. # Regression test for #880: empty array for zi crashes.
  1840. x = self.generate((5,), xp)
  1841. a = self.convert_dtype([1], xp)
  1842. b = self.convert_dtype([1], xp)
  1843. zi = self.convert_dtype([], xp)
  1844. y, zf = lfilter(b, a, x, zi=zi)
  1845. assert_array_almost_equal(y, x)
  1846. assert zf.dtype == (getattr(xp, self.dtype)
  1847. if isinstance(self.dtype, str)
  1848. else self.dtype)
  1849. assert xp_size(zf) == 0
  1850. @skip_xp_backends('jax.numpy', reason='jax does not support inplace ops')
  1851. @pytest.mark.parametrize('a', (1, [1], [1, .5, 1.5], 2, [2], [2, 1, 3]),
  1852. ids=str)
  1853. @make_xp_test_case(lfiltic)
  1854. def test_lfiltic(self, a, xp):
  1855. # Test for #22470: lfiltic does not handle `a[0] != 1`
  1856. # and, more in general, test that lfiltic behaves consistently with lfilter
  1857. if is_cupy(xp) and isinstance(a, int | float):
  1858. pytest.skip('cupy does not supoprt scalar filter coefficients')
  1859. x = self.generate(6, xp) # arbitrary input
  1860. b = self.convert_dtype([.5, 1., .2], xp) # arbitrary b
  1861. a = self.convert_dtype(a, xp)
  1862. N = xp_size(a) - 1
  1863. M = xp_size(b) - 1
  1864. K = M + N if is_cupy(xp) else max(N, M)
  1865. # compute reference initial conditions as final conditions of lfilter
  1866. y1, zi_1 = lfilter(b, a, x, zi=self.generate(K, xp))
  1867. # copute initial conditions from lfiltic
  1868. zi_2 = lfiltic(b, a, xp.flip(y1), xp.flip(x))
  1869. # compare lfiltic's output with reference
  1870. assert_array_almost_equal(zi_1, zi_2)
  1871. @make_xp_test_case(lfiltic)
  1872. def test_lfiltic_bad_coeffs(xp):
  1873. # Test for invalid filter coefficients (wrong shape or zero `a[0]`)
  1874. assert_raises(ValueError, lfiltic, [1, 2], [], [0, 0], [0, 1])
  1875. assert_raises(ValueError, lfiltic, [1, 2], [0, 2], [0, 0], [0, 1])
  1876. assert_raises(ValueError, lfiltic, [1, 2], [[1], [2]], [0, 0], [0, 1])
  1877. assert_raises(ValueError, lfiltic, [[1], [2]], [1], [0, 0], [0, 1])
  1878. @skip_xp_backends(
  1879. 'array_api_strict', reason='int64 and float64 cannot be promoted together'
  1880. )
  1881. @skip_xp_backends('jax.numpy', reason='jax dtype defaults differ')
  1882. @make_xp_test_case(lfiltic)
  1883. def test_lfiltic_bad_zi(self, xp):
  1884. # Regression test for #3699: bad initial conditions
  1885. a = self.convert_dtype([1], xp)
  1886. b = self.convert_dtype([1], xp)
  1887. # "y" sets the datatype of zi, so it truncates if int
  1888. zi = lfiltic(b, a, xp.asarray([1., 0]))
  1889. zi_1 = lfiltic(b, a, xp.asarray([1.0, 0]))
  1890. zi_2 = lfiltic(b, a, xp.asarray([True, False]))
  1891. xp_assert_equal(zi, zi_1)
  1892. check_dtype_arg = {} if self.dtype == object else {'check_dtype': False}
  1893. xp_assert_equal(zi, zi_2, **check_dtype_arg)
  1894. @skip_xp_backends('cupy', reason='XXX https://github.com/cupy/cupy/pull/8677')
  1895. def test_short_x_FIR(self, xp):
  1896. # regression test for #5116
  1897. # x shorter than b, with non None zi fails
  1898. a = self.convert_dtype([1], xp)
  1899. b = self.convert_dtype([1, 0, -1], xp)
  1900. zi = self.convert_dtype([2, 7], xp)
  1901. x = self.convert_dtype([72], xp)
  1902. ye = self.convert_dtype([74], xp)
  1903. zfe = self.convert_dtype([7, -72], xp)
  1904. y, zf = lfilter(b, a, x, zi=zi)
  1905. assert_array_almost_equal(y, ye)
  1906. assert_array_almost_equal(zf, zfe)
  1907. @skip_xp_backends('cupy', reason='XXX https://github.com/cupy/cupy/pull/8677')
  1908. def test_short_x_IIR(self, xp):
  1909. # regression test for #5116
  1910. # x shorter than b, with non None zi fails
  1911. a = self.convert_dtype([1, 1], xp)
  1912. b = self.convert_dtype([1, 0, -1], xp)
  1913. zi = self.convert_dtype([2, 7], xp)
  1914. x = self.convert_dtype([72], xp)
  1915. ye = self.convert_dtype([74], xp)
  1916. zfe = self.convert_dtype([-67, -72], xp)
  1917. y, zf = lfilter(b, a, x, zi=zi)
  1918. assert_array_almost_equal(y, ye)
  1919. assert_array_almost_equal(zf, zfe)
  1920. def test_do_not_modify_a_b_IIR(self, xp):
  1921. x = self.generate((6,), xp)
  1922. b = self.convert_dtype([1, -1], xp)
  1923. b0 = xp_copy(b, xp=xp)
  1924. a = self.convert_dtype([0.5, -0.5], xp)
  1925. a0 = xp_copy(a, xp=xp)
  1926. y_r = self.convert_dtype([0, 2, 4, 6, 8, 10.], xp)
  1927. y_f = lfilter(b, a, x)
  1928. assert_array_almost_equal(y_f, y_r)
  1929. xp_assert_equal(b, b0)
  1930. xp_assert_equal(a, a0)
  1931. def test_do_not_modify_a_b_FIR(self, xp):
  1932. x = self.generate((6,), xp)
  1933. b = self.convert_dtype([1, 0, 1], xp)
  1934. b0 = xp_copy(b, xp=xp)
  1935. a = self.convert_dtype([2], xp)
  1936. a0 = xp_copy(a, xp=xp)
  1937. y_r = self.convert_dtype([0, 0.5, 1, 2, 3, 4.], xp)
  1938. y_f = lfilter(b, a, x)
  1939. assert_array_almost_equal(y_f, y_r)
  1940. xp_assert_equal(b, b0)
  1941. xp_assert_equal(a, a0)
  1942. @skip_xp_backends(np_only=True)
  1943. @pytest.mark.parametrize("a", [1.0, [1.0], np.array(1.0)])
  1944. @pytest.mark.parametrize("b", [1.0, [1.0], np.array(1.0)])
  1945. def test_scalar_input(self, a, b, xp):
  1946. data = np.random.randn(10)
  1947. data = xp.asarray(data)
  1948. xp_assert_close(
  1949. lfilter(xp.asarray([1.0]), xp.asarray([1.0]), data),
  1950. lfilter(b, a, data)
  1951. )
class TestLinearFilterFloat32(_TestLinearFilter):
    # Runs the tests inherited from _TestLinearFilter with dtype 'float32'.
    dtype = 'float32'
class TestLinearFilterFloat64(_TestLinearFilter):
    # Runs the tests inherited from _TestLinearFilter with dtype 'float64'.
    dtype = 'float64'
@skip_xp_backends(np_only=True)
class TestLinearFilterFloatExtended(_TestLinearFilter):
    # Runs the tests inherited from _TestLinearFilter with NumPy's
    # extended-precision real dtype ('g' is the longdouble type code).
    dtype = np.dtype('g')
class TestLinearFilterComplex64(_TestLinearFilter):
    # Runs the tests inherited from _TestLinearFilter with dtype 'complex64'.
    dtype = 'complex64'
class TestLinearFilterComplex128(_TestLinearFilter):
    # Runs the tests inherited from _TestLinearFilter with dtype 'complex128'.
    dtype = 'complex128'
@skip_xp_backends(np_only=True)
class TestLinearFilterComplexExtended(_TestLinearFilter):
    # Runs the tests inherited from _TestLinearFilter with NumPy's
    # extended-precision complex dtype ('G' is the clongdouble type code).
    dtype = np.dtype('G')
  1966. @make_xp_test_case(lfilter)
  1967. def test_lfilter_bad_object(xp):
  1968. # lfilter: object arrays with non-numeric objects raise TypeError.
  1969. # Regression test for ticket #1452.
  1970. if hasattr(sys, 'abiflags') and 'd' in sys.abiflags:
  1971. pytest.skip('test is flaky when run with python3-dbg')
  1972. assert_raises(TypeError, lfilter, [1.0], [1.0], [1.0, None, 2.0])
  1973. assert_raises(TypeError, lfilter, [1.0], [None], [1.0, 2.0, 3.0])
  1974. assert_raises(TypeError, lfilter, [None], [1.0], [1.0, 2.0, 3.0])
  1975. @make_xp_test_case(lfilter)
  1976. def test_lfilter_notimplemented_input(xp):
  1977. # Should not crash, gh-7991
  1978. assert_raises(NotImplementedError, lfilter, [2,3], [4,5], [1,2,3,4,5])
  1979. @pytest.mark.parametrize('dt', ["uint8", "int8", "uint16", "int16",
  1980. "uint32", "int32",
  1981. "uint64", "int64",
  1982. "float32", "float64",
  1983. ])
  1984. @xfail_xp_backends("jax.numpy", reason="fails all around")
  1985. @make_xp_test_case(correlate)
  1986. class TestCorrelateReal:
  1987. def _setup_rank1(self, dt, xp):
  1988. a = xp.linspace(0, 3, 4, dtype=dt)
  1989. b = xp.linspace(1, 2, 2, dtype=dt)
  1990. y_r = xp.asarray([0, 2, 5, 8, 3], dtype=dt)
  1991. return a, b, y_r
  1992. def equal_tolerance(self, res_dt):
  1993. # default value of keyword
  1994. decimal = 6
  1995. try:
  1996. dt_info = np.finfo(res_dt)
  1997. if hasattr(dt_info, 'resolution'):
  1998. decimal = int(-0.5*np.log10(dt_info.resolution))
  1999. except Exception:
  2000. pass
  2001. return decimal
  2002. def equal_tolerance_fft(self, res_dt):
  2003. # FFT implementations convert longdouble arguments down to
  2004. # double so don't expect better precision, see gh-9520
  2005. if res_dt == np.longdouble:
  2006. return self.equal_tolerance(np.float64)
  2007. else:
  2008. return self.equal_tolerance(res_dt)
  2009. @skip_xp_backends(np_only=True, reason="order='F'")
  2010. def test_method(self, dt, xp):
  2011. dt = getattr(xp, dt)
  2012. a, b, y_r = self._setup_rank3(dt, xp)
  2013. y_fft = correlate(a, b, method='fft')
  2014. y_direct = correlate(a, b, method='direct')
  2015. assert_array_almost_equal(y_r, y_fft,
  2016. decimal=self.equal_tolerance_fft(y_fft.dtype),)
  2017. assert_array_almost_equal(y_r, y_direct,
  2018. decimal=self.equal_tolerance(y_direct.dtype),)
  2019. assert y_fft.dtype == dt
  2020. assert y_direct.dtype == dt
  2021. def test_rank1_valid(self, dt, xp):
  2022. if is_torch(xp) and dt in ["uint16", "uint32", "uint64"]:
  2023. pytest.skip("torch does not support unsigned ints")
  2024. dt = getattr(xp, dt) if isinstance(dt, str) else dt
  2025. a, b, y_r = self._setup_rank1(dt, xp)
  2026. y = correlate(a, b, 'valid')
  2027. assert_array_almost_equal(y, y_r[1:4])
  2028. assert y.dtype == dt
  2029. # See gh-5897
  2030. y = correlate(b, a, 'valid')
  2031. assert_array_almost_equal(y, xp.flip(y_r[1:4]))
  2032. assert y.dtype == dt
  2033. def test_rank1_same(self, dt, xp):
  2034. if is_torch(xp) and dt in ["uint16", "uint32", "uint64"]:
  2035. pytest.skip("torch does not support unsigned ints")
  2036. dt = getattr(xp, dt) if isinstance(dt, str) else dt
  2037. a, b, y_r = self._setup_rank1(dt, xp)
  2038. y = correlate(a, b, 'same')
  2039. assert_array_almost_equal(y, y_r[:-1])
  2040. assert y.dtype == dt
  2041. def test_rank1_full(self, dt, xp):
  2042. if is_torch(xp) and dt in ["uint16", "uint32", "uint64"]:
  2043. pytest.skip("torch does not support unsigned ints")
  2044. dt = getattr(xp, dt) if isinstance(dt, str) else dt
  2045. a, b, y_r = self._setup_rank1(dt, xp)
  2046. y = correlate(a, b, 'full')
  2047. assert_array_almost_equal(y, y_r)
  2048. assert y.dtype == dt
  2049. def _setup_rank3(self, dt, xp):
  2050. a = np.linspace(0, 39, 40).reshape((2, 4, 5), order='F').astype(
  2051. dt)
  2052. b = np.linspace(0, 23, 24).reshape((2, 3, 4), order='F').astype(
  2053. dt)
  2054. y_r = np.array([[[0., 184., 504., 912., 1360., 888., 472., 160.],
  2055. [46., 432., 1062., 1840., 2672., 1698., 864., 266.],
  2056. [134., 736., 1662., 2768., 3920., 2418., 1168., 314.],
  2057. [260., 952., 1932., 3056., 4208., 2580., 1240., 332.],
  2058. [202., 664., 1290., 1984., 2688., 1590., 712., 150.],
  2059. [114., 344., 642., 960., 1280., 726., 296., 38.]],
  2060. [[23., 400., 1035., 1832., 2696., 1737., 904., 293.],
  2061. [134., 920., 2166., 3680., 5280., 3306., 1640., 474.],
  2062. [325., 1544., 3369., 5512., 7720., 4683., 2192., 535.],
  2063. [571., 1964., 3891., 6064., 8272., 4989., 2324., 565.],
  2064. [434., 1360., 2586., 3920., 5264., 3054., 1312., 230.],
  2065. [241., 700., 1281., 1888., 2496., 1383., 532., 39.]],
  2066. [[22., 214., 528., 916., 1332., 846., 430., 132.],
  2067. [86., 484., 1098., 1832., 2600., 1602., 772., 206.],
  2068. [188., 802., 1698., 2732., 3788., 2256., 1018., 218.],
  2069. [308., 1006., 1950., 2996., 4052., 2400., 1078., 230.],
  2070. [230., 692., 1290., 1928., 2568., 1458., 596., 78.],
  2071. [126., 354., 636., 924., 1212., 654., 234., 0.]]],
  2072. dtype=np.float64).astype(dt)
  2073. return a, b, y_r
  2074. @skip_xp_backends(np_only=True, reason="order='F'")
  2075. def test_rank3_valid(self, dt, xp):
  2076. dt = getattr(xp, dt) if isinstance(dt, str) else dt
  2077. a, b, y_r = self._setup_rank3(dt, xp)
  2078. y = correlate(a, b, "valid")
  2079. assert_array_almost_equal(y, y_r[1:2, 2:4, 3:5])
  2080. assert y.dtype == dt
  2081. # See gh-5897
  2082. y = correlate(b, a, "valid")
  2083. assert_array_almost_equal(y, y_r[1:2, 2:4, 3:5][::-1, ::-1, ::-1])
  2084. assert y.dtype == dt
  2085. @skip_xp_backends(np_only=True, reason="order='F'")
  2086. def test_rank3_same(self, dt, xp):
  2087. dt = getattr(xp, dt) if isinstance(dt, str) else dt
  2088. a, b, y_r = self._setup_rank3(dt, xp)
  2089. y = correlate(a, b, "same")
  2090. xp_assert_close(y, y_r[0:-1, 1:-1, 1:-2])
  2091. assert y.dtype == dt
  2092. @skip_xp_backends(np_only=True, reason="order='F'")
  2093. def test_rank3_all(self, dt, xp):
  2094. dt = getattr(xp, dt) if isinstance(dt, str) else dt
  2095. a, b, y_r = self._setup_rank3(dt, xp)
  2096. y = correlate(a, b)
  2097. xp_assert_close(y, y_r)
  2098. assert y.dtype == dt
  2099. @make_xp_test_case(correlate)
  2100. class TestCorrelate:
  2101. # Tests that don't depend on dtype
  2102. @skip_xp_backends(np_only=True)
  2103. def test_invalid_shapes(self, xp):
  2104. # By "invalid," we mean that no one
  2105. # array has dimensions that are all at
  2106. # least as large as the corresponding
  2107. # dimensions of the other array. This
  2108. # setup should throw a ValueError.
  2109. a = np.arange(1, 7).reshape((2, 3))
  2110. b = np.arange(-6, 0).reshape((3, 2))
  2111. assert_raises(ValueError, correlate, *(a, b), **{'mode': 'valid'})
  2112. assert_raises(ValueError, correlate, *(b, a), **{'mode': 'valid'})
  2113. @skip_xp_backends(np_only=True)
  2114. def test_invalid_params(self, xp):
  2115. a = [3, 4, 5]
  2116. b = [1, 2, 3]
  2117. assert_raises(ValueError, correlate, a, b, mode='spam')
  2118. assert_raises(ValueError, correlate, a, b, mode='eggs', method='fft')
  2119. assert_raises(ValueError, correlate, a, b, mode='ham', method='direct')
  2120. assert_raises(ValueError, correlate, a, b, mode='full', method='bacon')
  2121. assert_raises(ValueError, correlate, a, b, mode='same', method='bacon')
  2122. @skip_xp_backends(np_only=True)
  2123. def test_mismatched_dims(self, xp):
  2124. # Input arrays should have the same number of dimensions
  2125. assert_raises(ValueError, correlate, [1], 2, method='direct')
  2126. assert_raises(ValueError, correlate, 1, [2], method='direct')
  2127. assert_raises(ValueError, correlate, [1], 2, method='fft')
  2128. assert_raises(ValueError, correlate, 1, [2], method='fft')
  2129. assert_raises(ValueError, correlate, [1], [[2]])
  2130. assert_raises(ValueError, correlate, [3], 2)
  2131. @skip_xp_backends(cpu_only=True, exceptions=['cupy'])
  2132. @skip_xp_backends("jax.numpy", reason="dtype differs")
  2133. def test_numpy_fastpath(self, xp):
  2134. a = xp.asarray([1, 2, 3])
  2135. b = xp.asarray([4, 5])
  2136. xp_assert_close(correlate(a, b, mode='same'), xp.asarray([5, 14, 23]))
  2137. a = xp.asarray([1, 2, 3])
  2138. b = xp.asarray([4, 5, 6])
  2139. xp_assert_close(correlate(a, b, mode='same'), xp.asarray([17, 32, 23]))
  2140. xp_assert_close(correlate(a, b, mode='full'), xp.asarray([6, 17, 32, 23, 12]))
  2141. xp_assert_close(correlate(a, b, mode='valid'), xp.asarray([32]))
  2142. @make_xp_test_case(correlation_lags)
  2143. @pytest.mark.parametrize("mode", ["valid", "same", "full"])
  2144. @pytest.mark.parametrize("behind", [True, False])
  2145. @pytest.mark.parametrize("input_size", [100, 101, 1000, 1001,
  2146. pytest.param(10000, marks=[pytest.mark.slow]),
  2147. pytest.param(10001, marks=[pytest.mark.slow])]
  2148. )
  2149. def test_correlation_lags(mode, behind, input_size, xp):
  2150. # generate random data
  2151. rng = np.random.RandomState(0)
  2152. in1 = rng.standard_normal(input_size)
  2153. offset = int(input_size/10)
  2154. # generate offset version of array to correlate with
  2155. if behind:
  2156. # y is behind x
  2157. in2 = np.concatenate([rng.standard_normal(offset), in1])
  2158. expected = -offset
  2159. else:
  2160. # y is ahead of x
  2161. in2 = in1[offset:]
  2162. expected = offset
  2163. # cross correlate, returning lag information
  2164. correlation = correlate(in1, in2, mode=mode)
  2165. lags = correlation_lags(in1.size, in2.size, mode=mode)
  2166. # identify the peak
  2167. lag_index = np.argmax(correlation)
  2168. # Check as expected
  2169. xp_assert_equal(lags[lag_index], expected)
  2170. # Correlation and lags shape should match
  2171. assert lags.shape == correlation.shape
  2172. @make_xp_test_case(correlation_lags)
  2173. def test_correlation_lags_invalid_mode(xp):
  2174. with pytest.raises(ValueError, match="Mode asdfgh is invalid"):
  2175. correlation_lags(100, 100, mode="asdfgh")
  2176. @make_xp_test_case(correlate)
  2177. @pytest.mark.parametrize('dt_name', ['complex64', 'complex128'])
  2178. class TestCorrelateComplex:
  2179. # The decimal precision to be used for comparing results.
  2180. # This value will be passed as the 'decimal' keyword argument of
  2181. # assert_array_almost_equal().
  2182. # Since correlate may chose to use FFT method which converts
  2183. # longdoubles to doubles internally don't expect better precision
  2184. # for longdouble than for double (see gh-9520).
  2185. def decimal(self, dt, xp):
  2186. if is_numpy(xp) and dt == np.clongdouble:
  2187. dt = np.cdouble
  2188. # emulate np.finfo(dt).precision for complex64 and complex128
  2189. prec = {64: 15, 32: 6}[xp.finfo(dt).bits]
  2190. return int(2 * prec / 3)
  2191. def _setup_rank1(self, dt, mode, xp):
  2192. rng = np.random.default_rng(9)
  2193. a = np.random.randn(10).astype(dt)
  2194. a += 1j * rng.standard_normal(10).astype(dt)
  2195. b = np.random.randn(8).astype(dt)
  2196. b += 1j * rng.standard_normal(8).astype(dt)
  2197. y_r = (correlate(a.real, b.real, mode=mode) +
  2198. correlate(a.imag, b.imag, mode=mode)).astype(dt)
  2199. y_r += 1j * (-correlate(a.real, b.imag, mode=mode) +
  2200. correlate(a.imag, b.real, mode=mode))
  2201. a, b, y_r = xp.asarray(a), xp.asarray(b), xp.asarray(y_r)
  2202. return a, b, y_r
  2203. def test_rank1_valid(self, dt_name, xp):
  2204. a, b, y_r = self._setup_rank1(dt_name, 'valid', xp)
  2205. dt = getattr(xp, dt_name)
  2206. y = correlate(a, b, 'valid')
  2207. assert_array_almost_equal(y, y_r, decimal=self.decimal(dt, xp))
  2208. assert y.dtype == dt
  2209. # See gh-5897
  2210. y = correlate(b, a, 'valid')
  2211. assert_array_almost_equal(y, xp.conj(xp.flip(y_r)),
  2212. decimal=self.decimal(dt, xp))
  2213. assert y.dtype == dt
  2214. def test_rank1_same(self, dt_name, xp):
  2215. a, b, y_r = self._setup_rank1(dt_name, 'same', xp)
  2216. dt = getattr(xp, dt_name)
  2217. y = correlate(a, b, 'same')
  2218. assert_array_almost_equal(y, y_r, decimal=self.decimal(dt, xp))
  2219. assert y.dtype == dt
  2220. def test_rank1_full(self, dt_name, xp):
  2221. a, b, y_r = self._setup_rank1(dt_name, 'full', xp)
  2222. dt = getattr(xp, dt_name)
  2223. y = correlate(a, b, 'full')
  2224. assert_array_almost_equal(y, y_r, decimal=self.decimal(dt, xp))
  2225. assert y.dtype == dt
  2226. def test_swap_full(self, dt_name, xp):
  2227. dt = getattr(xp, dt_name)
  2228. d = xp.asarray([0.+0.j, 1.+1.j, 2.+2.j], dtype=dt)
  2229. k = xp.asarray([1.+3.j, 2.+4.j, 3.+5.j, 4.+6.j], dtype=dt)
  2230. y = correlate(d, k)
  2231. xp_assert_close(
  2232. y, xp.asarray([0.+0.j, 10.-2.j, 28.-6.j, 22.-6.j, 16.-6.j, 8.-4.j]),
  2233. atol=1e-6, check_dtype=False
  2234. )
  2235. def test_swap_same(self, dt_name, xp):
  2236. d = xp.asarray([0.+0.j, 1.+1.j, 2.+2.j])
  2237. k = xp.asarray([1.+3.j, 2.+4.j, 3.+5.j, 4.+6.j])
  2238. y = correlate(d, k, mode="same")
  2239. xp_assert_close(y, xp.asarray([10.-2.j, 28.-6.j, 22.-6.j]))
  2240. @skip_xp_backends("cupy", reason="notimplementederror")
  2241. def test_rank3(self, dt_name, xp):
  2242. if is_jax(xp) and SCIPY_DEVICE != "cpu":
  2243. pytest.xfail(reason="error tolerances exceeded with JAX on gpu")
  2244. a = np.random.randn(10, 8, 6).astype(dt_name)
  2245. a += 1j * np.random.randn(10, 8, 6).astype(dt_name)
  2246. b = np.random.randn(8, 6, 4).astype(dt_name)
  2247. b += 1j * np.random.randn(8, 6, 4).astype(dt_name)
  2248. y_r = (correlate(a.real, b.real)
  2249. + correlate(a.imag, b.imag)).astype(dt_name)
  2250. y_r += 1j * (-correlate(a.real, b.imag) + correlate(a.imag, b.real))
  2251. a, b, y_r = xp.asarray(a), xp.asarray(b), xp.asarray(y_r)
  2252. dt = getattr(xp, dt_name)
  2253. y = correlate(a, b, 'full')
  2254. assert_array_almost_equal(y, y_r, decimal=self.decimal(dt, xp) - 1)
  2255. assert y.dtype == dt
  2256. @skip_xp_backends(np_only=True) # XXX: check 0D/scalars on backends.
  2257. def test_rank0(self, dt_name, xp):
  2258. a = np.array(np.random.randn()).astype(dt_name)
  2259. a += 1j * np.array(np.random.randn()).astype(dt_name)
  2260. b = np.array(np.random.randn()).astype(dt_name)
  2261. b += 1j * np.array(np.random.randn()).astype(dt_name)
  2262. dt = getattr(xp, dt_name)
  2263. y_r = (correlate(a.real, b.real)
  2264. + correlate(a.imag, b.imag)).astype(dt)
  2265. y_r += 1j * np.array(-correlate(a.real, b.imag) +
  2266. correlate(a.imag, b.real))
  2267. a, b = xp.asarray(a), xp.asarray(b)
  2268. y = correlate(a, b, 'full')
  2269. assert_array_almost_equal(y, y_r, decimal=self.decimal(dt, xp) - 1)
  2270. assert y.dtype == dt
  2271. xp_assert_equal(correlate([1], [2j]), np.asarray(correlate(1, 2j)),
  2272. check_shape=False)
  2273. xp_assert_equal(correlate([2j], [3j]), np.asarray(correlate(2j, 3j)),
  2274. check_shape=False)
  2275. xp_assert_equal(correlate([3j], [4]), np.asarray(correlate(3j, 4)),
  2276. check_shape=False)
  2277. class TestCorrelate2d:
  2278. @make_xp_test_case(signal.correlate)
  2279. def test_consistency_correlate_funcs(self, xp):
  2280. # Compare np.correlate, signal.correlate, signal.correlate2d
  2281. a = np.arange(5)
  2282. b = np.array([3.2, 1.4, 3])
  2283. for mode in ['full', 'valid', 'same']:
  2284. a_xp, b_xp = xp.asarray(a), xp.asarray(b)
  2285. np_corr_result = np.correlate(a, b, mode=mode)
  2286. assert_almost_equal(signal.correlate(a_xp, b_xp, mode=mode),
  2287. xp.asarray(np_corr_result))
  2288. # See gh-5897
  2289. if mode == 'valid':
  2290. np_corr_result = np.correlate(b, a, mode=mode)
  2291. assert_almost_equal(signal.correlate(b_xp, a_xp, mode=mode),
  2292. xp.asarray(np_corr_result))
  2293. @skip_xp_backends(np_only=True)
  2294. @make_xp_test_case(signal.correlate2d)
  2295. def test_consistency_correlate_funcs_2(self, xp):
  2296. # Compare np.correlate, signal.correlate, signal.correlate2d
  2297. a = np.arange(5)
  2298. b = np.array([3.2, 1.4, 3])
  2299. for mode in ['full', 'valid', 'same']:
  2300. assert_almost_equal(np.squeeze(signal.correlate2d([a], [b],
  2301. mode=mode)),
  2302. signal.correlate(a, b, mode=mode))
  2303. # See gh-5897
  2304. if mode == 'valid':
  2305. assert_almost_equal(np.squeeze(signal.correlate2d([b], [a],
  2306. mode=mode)),
  2307. signal.correlate(b, a, mode=mode))
  2308. @skip_xp_backends(np_only=True)
  2309. @make_xp_test_case(signal.correlate2d)
  2310. def test_invalid_shapes(self, xp):
  2311. # By "invalid," we mean that no one
  2312. # array has dimensions that are all at
  2313. # least as large as the corresponding
  2314. # dimensions of the other array. This
  2315. # setup should throw a ValueError.
  2316. a = np.arange(1, 7).reshape((2, 3))
  2317. b = np.arange(-6, 0).reshape((3, 2))
  2318. assert_raises(ValueError, signal.correlate2d, *(a, b), **{'mode': 'valid'})
  2319. assert_raises(ValueError, signal.correlate2d, *(b, a), **{'mode': 'valid'})
  2320. @make_xp_test_case(signal.correlate2d)
  2321. def test_complex_input(self, xp):
  2322. xp_assert_equal(signal.correlate2d(xp.asarray([[1]]), xp.asarray([[2j]])),
  2323. xp.asarray([-2j]), check_shape=False, check_dtype=False)
  2324. xp_assert_equal(signal.correlate2d(xp.asarray([[2j]]), xp.asarray([[3j]])),
  2325. xp.asarray([6+0j]), check_shape=False, check_dtype=False)
  2326. xp_assert_equal(signal.correlate2d(xp.asarray([[3j]]), xp.asarray([[4]])),
  2327. xp.asarray([12j]), check_shape=False, check_dtype=False)
  2328. @make_xp_test_case(lfilter_zi)
  2329. class TestLFilterZI:
  2330. @skip_xp_backends(np_only=True, reason='list inputs are numpy specific')
  2331. def test_array_like(self, xp):
  2332. zi_expected = xp.asarray([5.0, -1.0])
  2333. zi = lfilter_zi([1.0, 0.0, 2.0], [1.0, -1.0, 0.5])
  2334. assert_array_almost_equal(zi, zi_expected)
  2335. def test_basic(self, xp):
  2336. a = xp.asarray([1.0, -1.0, 0.5])
  2337. b = xp.asarray([1.0, 0.0, 2.0])
  2338. zi_expected = xp.asarray([5.0, -1.0])
  2339. zi = lfilter_zi(b, a)
  2340. assert_array_almost_equal(zi, zi_expected)
  2341. def test_scale_invariance(self, xp):
  2342. # Regression test. There was a bug in which b was not correctly
  2343. # rescaled when a[0] was nonzero.
  2344. b = xp.asarray([2.0, 8, 5])
  2345. a = xp.asarray([1.0, 1, 8])
  2346. zi1 = lfilter_zi(b, a)
  2347. zi2 = lfilter_zi(2*b, 2*a)
  2348. xp_assert_close(zi2, zi1, rtol=1e-12)
  2349. @pytest.mark.parametrize('dtype', ['float32', 'float64'])
  2350. def test_types(self, dtype, xp):
  2351. dtype = getattr(xp, dtype)
  2352. b = xp.zeros((8), dtype=dtype)
  2353. a = xp.asarray([1], dtype=dtype)
  2354. assert signal.lfilter_zi(b, a).dtype == dtype
  @make_xp_test_case(filtfilt, sosfiltfilt)
  class TestFiltFilt:
      """Tests for forward-backward filtering via `filtfilt`/`sosfiltfilt`.

      The class attribute `filtfilt_kind` selects the filter representation
      used by the `filtfilt` helper method: 'tf' (transfer function) or 'sos'
      (second-order sections).
      """
      filtfilt_kind = 'tf'

      def filtfilt(self, zpk, x, axis=-1, padtype='odd', padlen=None,
                   method='pad', irlen=None, xp=None):
          """Filter `x` forward and backward with the filter given as `zpk`.

          The zeros/poles/gain triple is converted to the representation
          named by `filtfilt_kind` and the coefficients are converted with
          `xp.asarray` before filtering. `method` and `irlen` are forwarded
          for the 'tf' kind only.

          Raises
          ------
          ValueError
              If `filtfilt_kind` is neither 'tf' nor 'sos'.
          """
          if self.filtfilt_kind == 'tf':
              b, a = zpk2tf(*zpk)
              b, a = xp.asarray(b), xp.asarray(a)
              return filtfilt(b, a, x, axis, padtype, padlen, method, irlen)
          if self.filtfilt_kind == 'sos':
              sos = zpk2sos(*zpk)
              sos = xp.asarray(sos)
              return sosfiltfilt(sos, x, axis, padtype, padlen)
          # An unknown kind used to fall through and return None, which only
          # surfaced later as a confusing failure inside the calling test.
          raise ValueError(f"unknown filtfilt_kind: {self.filtfilt_kind!r}")

      def _skip_jax_sos(self, xp):
          """Skip the calling test for the JAX backend with the 'sos' kind."""
          if is_jax(xp) and self.filtfilt_kind == 'sos':
              pytest.skip(reason='sosfilt works in-place')

      @skip_xp_backends('torch', reason='negative strides')
      def test_basic(self, xp):
          """Filtering a ramp with b == a must return the ramp."""
          self._skip_jax_sos(xp)

          zpk = tf2zpk(xp.asarray([1., 2, 3]), xp.asarray([1., 2, 3]))
          out = self.filtfilt(zpk, xp.arange(12), xp=xp)
          atol = 4e-9 if is_cupy(xp) else 5.28e-11
          xp_assert_close(out, xp.arange(12, dtype=xp.float64), atol=atol)

      @skip_xp_backends('torch', reason='negative strides')
      def test_sine(self, xp):
          """A high-order lowpass must recover the low-frequency component."""
          self._skip_jax_sos(xp)

          rate = 2000
          t = xp.linspace(0, 1.0, rate + 1)
          # A signal with low frequency and a high frequency.
          xlow = xp.sin(5 * 2 * np.pi * t)
          xhigh = xp.sin(250 * 2 * np.pi * t)
          x = xlow + xhigh

          zpk = butter(8, xp.asarray(0.125), output='zpk')
          # r is the magnitude of the largest pole.
          r = np.abs(zpk[1]).max()
          eps = 1e-5
          # n estimates the number of steps for the
          # transient to decay by a factor of eps.
          n = int(np.ceil(np.log(eps) / np.log(r)))

          # High order lowpass filter...
          y = self.filtfilt(zpk, x, padlen=n, xp=xp)
          # Result should be just xlow.
          err = np.abs(y - xlow).max()
          assert err < 1e-4

          # A 2D case.
          x2d = xp.asarray(np.vstack([xlow, xlow + xhigh]))
          y2d = self.filtfilt(zpk, x2d, padlen=n, axis=1, xp=xp)
          assert y2d.shape == x2d.shape
          err = np.abs(y2d - xlow).max()
          assert err < 1e-4

          # Use the previous result to check the use of the axis keyword.
          # (Regression test for ticket #1620)
          y2dt = self.filtfilt(zpk, x2d.T, padlen=n, axis=0, xp=xp)
          xp_assert_equal(y2d, y2dt.T)

      @skip_xp_backends('torch', reason='negative strides')
      def test_axis(self, xp):
          """Filtering along an axis must commute with swapping that axis."""
          self._skip_jax_sos(xp)

          # Test the 'axis' keyword on a 3D array.
          x = np.arange(10.0 * 11.0 * 12.0).reshape(10, 11, 12)
          x = xp.asarray(x)

          zpk = butter(3, xp.asarray(0.125), output='zpk')
          y0 = self.filtfilt(zpk, x, padlen=0, axis=0, xp=xp)
          y1 = self.filtfilt(
              zpk, xp.asarray(np.swapaxes(x, 0, 1)), padlen=0, axis=1, xp=xp
          )
          xp_assert_equal(y0, xp.asarray(np.swapaxes(y1, 0, 1)))
          y2 = self.filtfilt(
              zpk, xp.asarray(np.swapaxes(x, 0, 2)), padlen=0, axis=2, xp=xp
          )
          xp_assert_equal(y0, xp.asarray(np.swapaxes(y2, 0, 2)))

      @skip_xp_backends(np_only=True,
                        reason='python scalars in array_namespace are np-only')
      def test_acoeff(self, xp):
          """`filtfilt` must accept the 'a' coefficient as a single number."""
          if self.filtfilt_kind != 'tf':
              # Report a skip (as the gust tests do) instead of silently
              # passing via a bare `return`.
              pytest.skip('scalar `a` coefficient only applies to TF systems')

          out = signal.filtfilt(
              xp.asarray([.5, .5]), 1, xp.arange(10, dtype=xp.float64)
          )
          xp_assert_close(out, xp.arange(10, dtype=xp.float64),
                          rtol=1e-14, atol=1e-14)

      @skip_xp_backends(np_only=True, reason='_filtfilt_gust is np-only')
      def test_gust_simple(self, xp):
          """Compare `_filtfilt_gust` with a hand-computed length-2 case."""
          if self.filtfilt_kind != 'tf':
              pytest.skip('gust only implemented for TF systems')

          # The input array has length 2.  The exact solution for this case
          # was computed "by hand".
          x = xp.asarray([1.0, 2.0])
          b = xp.asarray([0.5])
          a = xp.asarray([1.0, -0.5])
          y, z1, z2 = _filtfilt_gust(b, a, x)
          xp_assert_close(z1[0], 0.3*x[0] + 0.2*x[1])
          xp_assert_close(z2[0], 0.2*x[0] + 0.3*x[1])
          xp_assert_close(
              y,
              xp.asarray([z1[0] + 0.25*z2[0] + 0.25*x[0] + 0.125*x[1],
                          0.25*z1[0] + z2[0] + 0.125*x[0] + 0.25*x[1]])
          )

      @skip_xp_backends(np_only=True,
                        reason='python scalars in array_namespace are np-only')
      def test_gust_scalars(self, xp):
          """Scalar `b` and `a` with method='gust' scale the input."""
          if self.filtfilt_kind != 'tf':
              pytest.skip('gust only implemented for TF systems')

          # The filter coefficients are both scalars, so the filter simply
          # multiplies its input by b/a.  When it is used in filtfilt, the
          # factor is (b/a)**2.
          x = xp.arange(12)
          b = 3.0
          a = 2.0
          y = filtfilt(b, a, x, method="gust")
          expected = (b/a)**2 * x
          xp_assert_close(y, expected)
  @make_xp_test_case(sosfiltfilt, filtfilt)
  class TestSOSFiltFilt(TestFiltFilt):
      """Run the inherited `TestFiltFilt` tests with `filtfilt_kind = 'sos'`."""
      filtfilt_kind = 'sos'

      @skip_xp_backends('jax.numpy', reason='sosfilt works in-place')
      @skip_xp_backends('torch', reason='negative strides')
      def test_equivalence(self, xp):
          """Test equivalence between sosfiltfilt and filtfilt"""
          data = xp.asarray(np.random.RandomState(0).randn(1000))
          # Butterworth designs of order 1 through 5.
          for order in range(1, 6):
              zpk = signal.butter(order, 0.35, output='zpk')
              tf_coeffs = zpk2tf(*zpk)
              sections = zpk2sos(*zpk)
              b, a = (xp.asarray(c) for c in tf_coeffs)
              sos = xp.asarray(sections)
              y_tf = filtfilt(b, a, data)
              y_sos = sosfiltfilt(sos, data)
              xp_assert_close(y_tf, y_sos, atol=1e-12, err_msg=f'order={order}')
  def filtfilt_gust_opt(b, a, x):
      """
      Reference forward-backward filter with Gustafsson edge handling.

      Gives the same result as `scipy.signal._signaltools._filtfilt_gust`,
      but accepts 1-d arrays only.  The initial conditions are found with
      `fmin` from `scipy.optimize`, which minimizes the squared difference
      between the forward-backward and the backward-forward filtered signals.
      `_filtfilt_gust` is significantly faster than this implementation.

      Returns the filtered signal, then the forward and the backward
      initial conditions.
      """
      order = max(len(a), len(b)) - 1

      def mismatch(ics, b, a, x):
          """Objective function: disagreement between the two filter orders."""
          ic_fwd, ic_bwd = ics[:order], ics[order:]
          fwd, _ = lfilter(b, a, x, zi=ic_fwd)
          fwd_bwd = lfilter(b, a, fwd[::-1], zi=ic_bwd)[0][::-1]

          bwd = lfilter(b, a, x[::-1], zi=ic_bwd)[0][::-1]
          bwd_fwd, _ = lfilter(b, a, bwd, zi=ic_fwd)
          return np.sum((fwd_bwd - bwd_fwd)**2)

      # Starting guess: the `lfilter_zi` state scaled by the mean of the
      # first/last `order` samples.
      unit_zi = lfilter_zi(b, a)
      start = np.concatenate((x[:order].mean()*unit_zi,
                              x[-order:].mean()*unit_zi))

      opt, _, _, _, warnflag = fmin(mismatch, start, args=(b, a, x),
                                    xtol=1e-10, ftol=1e-12,
                                    maxfun=10000, maxiter=10000,
                                    full_output=True, disp=False)
      if warnflag > 0:
          raise RuntimeError(
              f"minimization failed in filtfilt_gust_opt: warnflag={warnflag}"
          )
      ic_fwd, ic_bwd = opt[:order], opt[order:]

      # Apply the forward-backward filter using the computed initial
      # conditions.
      backward = lfilter(b, a, x[::-1], zi=ic_bwd)[0][::-1]
      y, _ = lfilter(b, a, backward, zi=ic_fwd)

      return y, ic_fwd, ic_bwd
  def check_filtfilt_gust(b, a, shape, axis, irlen=None):
      """Check `filtfilt(..., method="gust")` against `filtfilt_gust_opt`.

      Seeded random data of the given `shape` is filtered along `axis`.  The
      output of `filtfilt`, and the output and initial conditions of
      `_filtfilt_gust`, must agree with the reference implementation.
      """
      # Generate the data to be filtered.
      data = np.random.default_rng(123).standard_normal(shape)

      # Apply filtfilt.  This is the main calculation to be checked.
      y = filtfilt(b, a, data, axis=axis, method="gust", irlen=irlen)

      # Also call the private function so we can test the ICs.
      yg, zg1, zg2 = _filtfilt_gust(b, a, data, axis=axis, irlen=irlen)

      # The reference implementation only handles 1-D arrays: move the
      # filtered axis to the end and loop over all leading indices.
      moved = np.swapaxes(data, axis, -1)
      lead_shape = moved.shape[:-1]
      n_ic = max(len(a), len(b)) - 1
      y_ref = np.empty_like(moved)
      z1_ref = np.empty(lead_shape + (n_ic,))
      z2_ref = np.empty(lead_shape + (n_ic,))
      for idx in product(*map(range, lead_shape)):
          y_ref[idx], z1_ref[idx], z2_ref[idx] = filtfilt_gust_opt(b, a, moved[idx])

      # Restore the original axis order before comparing.
      y_ref = np.swapaxes(y_ref, -1, axis)
      z1_ref = np.swapaxes(z1_ref, -1, axis)
      z2_ref = np.swapaxes(z2_ref, -1, axis)

      comparisons = ((y, y_ref), (yg, y_ref), (zg1, z1_ref), (zg2, z2_ref))
      for actual, desired in comparisons:
          xp_assert_close(actual, desired, rtol=1e-8, atol=1e-9)
  @make_xp_test_case(choose_conv_method)
  @pytest.mark.fail_slow(10)
  def test_choose_conv_method(xp):
      """Check the method picked by `choose_conv_method` for small inputs.

      For every mode, small 1-d and 2-d random inputs must select 'direct',
      and `measure=True` must additionally return a dict of timings for both
      methods.  Inputs of the dtypes named in `not_fft_conv_supp` (when this
      NumPy build provides them) and a large int64 value must also select
      'direct'.
      """
      # Use a seeded generator, like the other helpers in this file, so the
      # inputs are reproducible and the global NumPy random state is untouched.
      rng = np.random.default_rng(1234)
      for mode in ['valid', 'same', 'full']:
          for ndim in [1, 2]:
              n, k, true_method = 8, 6, 'direct'
              x = rng.standard_normal((n,) * ndim)
              h = rng.standard_normal((k,) * ndim)

              method = choose_conv_method(x, h, mode=mode)
              assert method == true_method

              method_try, times = choose_conv_method(x, h, mode=mode,
                                                     measure=True)
              assert method_try in {'fft', 'direct'}
              assert isinstance(times, dict)
              assert 'fft' in times and 'direct' in times

          n = 10
          for not_fft_conv_supp in ["complex256", "complex192"]:
              if hasattr(np, not_fft_conv_supp):
                  x = np.ones(n, dtype=not_fft_conv_supp)
                  h = x.copy()
                  assert choose_conv_method(x, h, mode=mode) == 'direct'

          # NOTE(review): presumably 2**51 is chosen so that the product is
          # too large for an exact FFT-based result -- confirm against
          # `choose_conv_method`.
          x = np.array([2**51], dtype=np.int64)
          h = x.copy()
          assert choose_conv_method(x, h, mode=mode) == 'direct'
  @make_xp_test_case(choose_conv_method)
  def test_choose_conv_method_2(xp):
      """Wide complex dtypes must make `choose_conv_method` pick 'direct'."""
      # Only dtypes that this NumPy build actually provides are exercised.
      candidates = ("complex256", "complex192")
      available = [name for name in candidates if hasattr(np, name)]
      size = 10
      for mode in ['valid', 'same', 'full']:
          for dtype_name in available:
              x = np.ones(size, dtype=dtype_name)
              h = x.copy()
              assert choose_conv_method(x, h, mode=mode) == 'direct'
  @skip_xp_backends(np_only=True)
  @pytest.mark.fail_slow(10)
  def test_filtfilt_gust(xp):
      """Run `check_filtfilt_gust` on 1-d and 3-d inputs for an elliptic filter."""
      # Design a filter.
      z, p, k = signal.ellip(3, 0.01, 120, 0.0875, output='zpk')
      b, a = zpk2tf(z, p, k)

      # Find the approximate impulse response length of the filter.
      tol = 1e-10
      largest_pole = np.max(np.abs(p))
      approx_impulse_len = int(np.ceil(np.log(tol) / np.log(largest_pole)))
      signal_len = 5 * approx_impulse_len

      for irlen in (None, approx_impulse_len):
          # 1-d test case
          check_filtfilt_gust(b, a, (signal_len,), 0, irlen)

          # 3-d test case; test each axis.
          for axis in range(3):
              shape = [2, 2, 2]
              shape[axis] = signal_len
              check_filtfilt_gust(b, a, shape, axis, irlen)

      # Test case with length less than 2*approx_impulse_len.
      # In this case, `filtfilt_gust` should behave the same as if
      # `irlen=None` was given.
      short_len = 2*approx_impulse_len - 50
      check_filtfilt_gust(b, a, (short_len,), 0, approx_impulse_len)
  @make_xp_test_case(signal.decimate)
  class TestDecimate:
      """Tests for `signal.decimate`."""

      def test_bad_args(self, xp):
          """A non-integer `q` or `n` must raise TypeError."""
          x = xp.arange(12)
          assert_raises(TypeError, signal.decimate, x, q=0.5, n=1)
          assert_raises(TypeError, signal.decimate, x, q=2, n=0.5)

      def test_basic_IIR(self, xp):
          """Order-1 IIR decimation of a ramp, rounded, is every 2nd sample."""
          x = xp.arange(12)
          y = signal.decimate(x, 2, n=1, ftype='iir', zero_phase=False).round()
          xp_assert_equal(y, x[::2].astype(float))

      def test_basic_FIR(self, xp):
          """Order-1 FIR decimation of a ramp, rounded, is every 2nd sample."""
          x = xp.arange(12)
          y = signal.decimate(x, 2, n=1, ftype='fir', zero_phase=False).round()
          xp_assert_equal(y, x[::2].astype(float))

      def test_shape(self, xp):
          """Only the decimated axis changes length."""
          # Regression test for ticket #1480.
          z = xp.zeros((30, 30))
          d0 = signal.decimate(z, 2, axis=0, zero_phase=False)
          assert d0.shape == (15, 30)
          d1 = signal.decimate(z, 2, axis=1, zero_phase=False)
          assert d1.shape == (30, 15)

      @skip_xp_backends(np_only=True, reason="test code is NumPy specific")
      def test_phaseshift_FIR(self, xp):
          with warnings.catch_warnings():
              warnings.filterwarnings(
                  "ignore", "Badly conditioned filter", BadCoefficients)
              self._test_phaseshift(method='fir', zero_phase=False)

      @skip_xp_backends(np_only=True, reason="test code is NumPy specific")
      def test_zero_phase_FIR(self, xp):
          with warnings.catch_warnings():
              warnings.filterwarnings(
                  "ignore", "Badly conditioned filter", BadCoefficients)
              self._test_phaseshift(method='fir', zero_phase=True)

      @skip_xp_backends(np_only=True, reason="test code is NumPy specific")
      def test_phaseshift_IIR(self, xp):
          self._test_phaseshift(method='iir', zero_phase=False)

      @skip_xp_backends(np_only=True, reason="test code is NumPy specific")
      def test_zero_phase_IIR(self, xp):
          self._test_phaseshift(method='iir', zero_phase=True)

      def _test_phaseshift(self, method, zero_phase):
          """Compare the phase of decimated sinusoids with the filter phase.

          `method` must be 'fir' or 'iir'; any other value raises ValueError.
          With `zero_phase` true the expected phase shift is zero, otherwise
          it is taken from the frequency response of the filter.
          """
          # TODO. Look into making tests using this work for CuPy.
          rate = 120
          rates_to = [15, 20, 30, 40]  # q = 8, 6, 4, 3

          t_tot = 100  # Need to let antialiasing filters settle
          t = np.arange(rate*t_tot+1) / float(rate)

          # Sinusoids at 0.8*nyquist, windowed to avoid edge artifacts
          freqs = np.array(rates_to) * 0.8 / 2
          d = (np.exp(1j * 2 * np.pi * freqs[:, np.newaxis] * t)
               * signal.windows.tukey(t.size, 0.1))

          for rate_to in rates_to:
              q = rate // rate_to
              t_to = np.arange(rate_to*t_tot+1) / float(rate_to)
              d_tos = (np.exp(1j * 2 * np.pi * freqs[:, np.newaxis] * t_to)
                       * signal.windows.tukey(t_to.size, 0.1))

              # Set up downsampling filters, match v0.17 defaults
              if method == 'fir':
                  n = 30
                  system = signal.dlti(signal.firwin(n + 1, 1. / q,
                                                     window='hamming'), 1.)
              elif method == 'iir':
                  n = 8
                  wc = 0.8*np.pi/q
                  system = signal.dlti(*signal.cheby1(n, 0.05, wc/np.pi))
              else:
                  # Without this, `system` and `n` would be unbound below.
                  raise ValueError(f"unknown method: {method!r}")

              # Calculate expected phase response, as unit complex vector
              if zero_phase is False:
                  _, h_resps = signal.freqz(system.num, system.den,
                                            freqs/rate*2*np.pi)
                  h_resps /= np.abs(h_resps)
              else:
                  h_resps = np.ones_like(freqs)

              y_resamps = signal.decimate(d.real, q, n, ftype=system,
                                          zero_phase=zero_phase)

              # Get phase from complex inner product, like CSD
              h_resamps = np.sum(d_tos.conj() * y_resamps, axis=-1)
              h_resamps /= np.abs(h_resamps)
              subnyq = freqs < 0.5*rate_to

              # Complex vectors should be aligned, only compare below nyquist
              result = np.angle(h_resps.conj()*h_resamps)[subnyq]
              xp_assert_close(result, np.zeros_like(result),
                              atol=1e-3, rtol=1e-3)

      def test_auto_n(self, xp):
          # Test that our value of n is a reasonable choice (depends on
          # the downsampling factor)
          sfreq = 100.
          n = 1000
          t = xp.arange(n) / sfreq
          # will alias for decimations (>= 15)
          x = xp.asarray(xp.sqrt(2. / n) * xp.sin(2 * xp.pi * (sfreq / 30.) * t))
          # Use xp.sqrt(x.dot(x)) instead of xp.linalg.vector_norm(x) because
          # linear algebra extension is not universally available.
          xp_assert_close(xp.sqrt(x.dot(x)), xp.asarray(1.), rtol=1e-3,
                          check_0d=False)
          x_out = signal.decimate(x, 30, ftype='fir')
          assert xp.sqrt(x_out.dot(x_out)) < 0.01

      def test_long_float32(self, xp):
          # regression: gh-15072.  With 32-bit float and either lfilter
          # or filtfilt, this is numerically unstable
          x = signal.decimate(xp.ones(10_000, dtype=xp.float32), 10)
          # Reduce inside the array namespace rather than iterating the
          # array element by element with the Python builtin `any`.
          assert not xp.any(xp.isnan(x))

      def test_float16_upcast(self):
          # float16 must be upcast to float64
          x = signal.decimate(np.ones(100, dtype=np.float16), 10)
          assert x.dtype.type == np.float64

      @skip_xp_backends(np_only=True, reason="dlti")
      def test_complex_iir_dlti(self, xp):
          # regression: gh-17845
          # centre frequency for filter [Hz]
          fcentre = 50
          # filter passband width [Hz]
          fwidth = 5
          # sample rate [Hz]
          fs = 1e3

          z, p, k = signal.butter(2, xp.asarray(2*xp.pi*fwidth/2),
                                  output='zpk', fs=fs)
          z = z.astype(complex) * xp.exp(xp.asarray(2j * xp.pi * fcentre/fs))
          p = p.astype(complex) * xp.exp(xp.asarray(2j * xp.pi * fcentre/fs))
          system = signal.dlti(z, p, k)

          t = xp.arange(200) / fs

          # input
          u = (xp.exp(2j * xp.pi * fcentre * t)
               + 0.5 * xp.exp(-2j * xp.pi * fcentre * t))

          ynzp = signal.decimate(u, 2, ftype=system, zero_phase=False)
          ynzpref = signal.lfilter(*signal.zpk2tf(z, p, k),
                                   u)[::2]

          xp_assert_equal(ynzp, ynzpref)

          yzp = signal.decimate(u, 2, ftype=system, zero_phase=True)
          yzpref = signal.filtfilt(*signal.zpk2tf(z, p, k),
                                   u)[::2]

          xp_assert_close(yzp, yzpref, rtol=1e-10, atol=1e-13)

      @skip_xp_backends(np_only=True, reason="dlti")
      def test_complex_fir_dlti(self, xp):
          # centre frequency for filter [Hz]
          fcentre = 50
          # filter passband width [Hz]
          fwidth = 5
          # sample rate [Hz]
          fs = 1e3
          numtaps = 20

          # FIR filter about 0Hz
          bbase = signal.firwin(numtaps, fwidth/2, fs=fs)

          # rotate these to desired frequency
          zbase = np.roots(bbase)
          zrot = zbase * np.exp(2j * np.pi * fcentre/fs)
          # FIR filter about 50Hz, maintaining passband gain of 0dB
          bz = bbase[0] * np.poly(zrot)

          system = signal.dlti(bz, 1)

          t = np.arange(200) / fs

          # input
          u = (np.exp(2j * np.pi * fcentre * t)
               + 0.5 * np.exp(-2j * np.pi * fcentre * t))

          ynzp = signal.decimate(u, 2, ftype=system, zero_phase=False)
          ynzpref = signal.upfirdn(bz, u, up=1, down=2)[:100]

          xp_assert_equal(ynzp, ynzpref)

          yzp = signal.decimate(u, 2, ftype=system, zero_phase=True)
          yzpref = signal.resample_poly(u, 1, 2, window=bz)

          xp_assert_equal(yzp, yzpref)
  @make_xp_test_case(hilbert)
  class TestHilbert:
      """Tests for `hilbert`."""

      def test_bad_args(self, xp):
          """Complex input and `N=0` must both raise ValueError."""
          complex_input = xp.asarray([1.0 + 0.0j])
          assert_raises(ValueError, hilbert, complex_input)
          real_input = xp.arange(8.0)
          assert_raises(ValueError, hilbert, real_input, N=0)

      def test_hilbert_theoretical(self, xp):
          # test cases by Ariel Rokem
          places = 14

          pi = xp.pi
          t = xp.arange(0, 2 * pi, pi / 256, dtype=xp.float64)
          slow_sin, slow_cos = xp.sin(t), xp.cos(t)
          fast_sin, fast_cos = xp.sin(2 * t), xp.cos(2 * t)
          signals = xp.stack([slow_sin, slow_cos, fast_sin, fast_cos])

          analytic = hilbert(signals)
          magnitude = xp.abs(analytic)
          phase = xp.atan2(xp.imag(analytic), xp.real(analytic))  # np.angle
          real_part = xp.real(analytic)

          # The real part should be equal to the original signals:
          assert_almost_equal(real_part, signals, places)
          # The absolute value should be one everywhere, for this input:
          assert_almost_equal(magnitude, xp.ones(signals.shape), places)

          # Expected phase ramps as (row, number of bins, start, stop):
          # the 'slow' sine goes from -pi/2 to pi/2 and the 'slow' cosine
          # from 0 to pi in the first 256 bins; the 'fast' signals make the
          # same transitions in half the time.
          ramps = [
              (0, 256, -pi / 2, pi / 2),
              (1, 256, 0, pi),
              (2, 128, -pi / 2, pi / 2),
              (3, 128, 0, pi),
          ]
          for row, bins, start, stop in ramps:
              assert_almost_equal(
                  phase[row, :bins],
                  xp.arange(start, stop, pi / bins, dtype=xp.float64),
                  places)

          # The imaginary part of hilbert(cos(t)) = sin(t) Wikipedia
          assert_almost_equal(xp.imag(analytic[1, :]), slow_sin, places)

      def test_hilbert_axisN(self, xp):
          # tests for axis and N arguments
          data = xp.reshape(xp.arange(18, dtype=xp.float64), (3, 6))
          # test axis
          along_last = hilbert(data, axis=-1)
          xp_assert_equal(hilbert(data.T, axis=0), along_last.T)
          # test 1d
          assert_almost_equal(hilbert(data[0, :]), along_last[0, :], 14)

          # test N
          padded = hilbert(data, N=20, axis=-1)
          assert padded.shape == (3, 20)
          assert hilbert(data.T, N=20, axis=0).shape == (20, 3)
          # the next test is just a regression test,
          # no idea whether numbers make sense
          first_row_ref = xp.asarray(np.array([
              0.000000000000000e+00 - 1.72015830311905j,
              1.000000000000000e+00 - 2.047794505137069j,
              1.999999999999999e+00 - 2.244055555687583j,
              3.000000000000000e+00 - 1.262750302935009j,
              4.000000000000000e+00 - 1.066489252384493j,
              5.000000000000000e+00 + 2.918022706971047j,
              8.881784197001253e-17 + 3.845658908989067j,
              -9.444121133484362e-17 + 0.985044202202061j,
              -1.776356839400251e-16 + 1.332257797702019j,
              -3.996802888650564e-16 + 0.501905089898885j,
              1.332267629550188e-16 + 0.668696078880782j,
              -1.192678053963799e-16 + 0.235487067862679j,
              -1.776356839400251e-16 + 0.286439612812121j,
              3.108624468950438e-16 + 0.031676888064907j,
              1.332267629550188e-16 - 0.019275656884536j,
              -2.360035624836702e-16 - 0.1652588660287j,
              0.000000000000000e+00 - 0.332049855010597j,
              3.552713678800501e-16 - 0.403810179797771j,
              8.881784197001253e-17 - 0.751023775297729j,
              9.444121133484362e-17 - 0.79252210110103j]))
          assert_almost_equal(padded[0, :], first_row_ref, 14,
                              err_msg='N regression')

      def test_hilbert_axis_3d(self, xp):
          """Transforming along a moved axis must match the last-axis result."""
          cube = xp.reshape(xp.arange(3 * 5 * 7, dtype=xp.float64), (3, 5, 7))
          # test axis
          reference = hilbert(cube, axis=-1)
          for axis in [0, 1]:
              moved = xp.moveaxis(cube, -1, axis)
              restored = xp.moveaxis(hilbert(moved, axis=axis), axis, -1)
              xp_assert_equal(reference, restored)

      @pytest.mark.parametrize('dtype', ['float32', 'float64'])
      def test_hilbert_types(self, dtype, xp):
          """The real part of the result keeps the input dtype."""
          xp_dtype = getattr(xp, dtype)
          zeros = xp.zeros(8, dtype=xp_dtype)
          assert xp.real(hilbert(zeros)).dtype == xp_dtype
  @make_xp_test_case(hilbert2)
  class TestHilbert2:
      """Test function `signal.hilbert2`. """

      @skip_xp_backends(np_only=True, reason='list inputs are numpy-specific')
      def test_array_like(self, xp):
          """A nested list is accepted as input."""
          nested = [[1, 2, 3], [4, 5, 6]]
          hilbert2(nested)

      def test_bad_args(self, xp):
          """Raise all exceptions in `hilbert2`. """
          grid = xp.reshape(xp.arange(16), (4, 4))

          with pytest.raises(ValueError, match="^x must be real."):
              hilbert2(xp.asarray([[1.0 + 0.0j]]))

          # Invalid values of `N`, each with the message it must trigger.
          bad_n = [
              (-1, "^N must be positive."),
              ((1, 1, 1), "^When given as a tuple, N must hold"),
              ((0, 1), "^When given as a tuple, N must hold"),
          ]
          for n_arg, pattern in bad_n:
              with pytest.raises(ValueError, match=pattern):
                  hilbert2(grid, N=n_arg)

      @skip_xp_backends("cupy", reason="CuPy's hilbert2 does not have axes= argument")
      def test_bad_args2(self, xp):
          """Invalid `axes` arguments must raise ValueError."""
          grid = xp.reshape(xp.arange(16), (4, 4))

          bad_axes = [
              ((0, 1, 2), "^axes must be a tuple of length 2"),
              ((0, 0), "^axes must contain 2 distinct axes"),
          ]
          for axes_arg, pattern in bad_axes:
              with pytest.raises(ValueError, match=pattern):
                  hilbert2(grid, axes=axes_arg)

      @pytest.mark.parametrize('dtype', ['float32', 'float64'])
      def test_hilbert2_types(self, dtype, xp):
          """The real part of the result keeps the input dtype."""
          xp_dtype = getattr(xp, dtype)
          zeros = xp.zeros((2, 32), dtype=xp_dtype)
          real_part = xp.real(signal.hilbert2(zeros))
          assert real_part.dtype == xp_dtype

      def test_1d_input(self, xp):
          """Needed for 100% coverage """
          samples = xp.asarray([0., 1., 1., 0., -1., -1.])
          as_column = signal.hilbert2(xp.reshape(samples, (6, 1)))
          as_row = signal.hilbert2(xp.reshape(samples, (1, 6)))
          xp_assert_close(as_column, as_row.T)

      def test_parameter_N(self, xp):
          """Compare passing tuple to single int. """
          zeros = xp.zeros((5, 5))
          from_int = hilbert2(zeros, N=4)
          from_tuple = hilbert2(zeros, N=(4, 4))
          xp_assert_equal(from_tuple, from_int)

      @pytest.mark.parametrize('shape', [(4, 5), (5, 4), (4, 4), (5, 5)])
      @skip_xp_backends("cupy", reason="Bug in cupy implementation, see cupy#9396")
      def test_quadrant_values(self, shape, xp):
          """Compare desired and calculated values in Fourier space. """
          x_f = xp.ones(shape, dtype=xp.complex128)  # FFT of input signal
          x_f[0, 0] += 7
          x = xp.real(sp_fft.ifft2(x_f))  # x.imag is zero

          x_as_f = sp_fft.fft2(hilbert2(x))

          # Create slices for bins with purely positive and purely negative
          # frequencies (can be verified with `sp_fft.fftfreq()`):
          split0 = (shape[0] + 1) // 2
          split1 = (shape[1] + 1) // 2
          f0_pos, f0_neg = slice(1, split0), slice(split0, None)
          f1_pos, f1_neg = slice(1, split1), slice(split1, None)

          # Verify all values:
          atol = 1e-12  # for x of dtype complex128
          # (index into the spectrum, expected gain relative to `x_f`)
          scaled_regions = [
              ((f0_pos, f1_pos), 4),
              ((0, f1_pos), 2),
              ((f0_pos, 0), 2),
          ]
          for region, gain in scaled_regions:
              xp_assert_close(x_as_f[region], x_f[region] * gain, atol=atol)
          xp_assert_close(x_as_f[0, 0], x_f[0, 0], atol=atol)

          zeroed = x_as_f[f0_neg, f1_neg]  # check for zeroed orthants
          xp_assert_close(zeroed, xp.zeros_like(zeroed), atol=atol)

      @pytest.mark.parametrize('shape', [(4, 5), (5, 4), (4, 4), (5, 5)])
      def test_zero_analytic_signal(self, shape, xp):
          """Test that a real signal with Z[-p,-q] == np.conj(Z[p,q])
          produces a zero analytic signal."""
          mid0, mid1 = shape[0] // 2, shape[1] // 2

          spectrum = xp.zeros(shape)
          spectrum[mid0 - 1, mid1 + 1] = 1.0
          spectrum[mid0 + 1, mid1 - 1] = 1.0
          spectrum = sp_fft.ifftshift(spectrum)

          x = xp.real(sp_fft.ifft2(spectrum))
          # The input itself must not be identically zero.
          assert xp.sum(abs(x)) > 0.0

          x_as = hilbert2(x)
          tolerance = xp.finfo(x_as.dtype).eps*16
          xp_assert_close(x_as, xp.zeros_like(x_as), atol=tolerance)

      @pytest.mark.parametrize('sh0', [4, 5])
      @pytest.mark.parametrize('sh1', [6, 7])
      @pytest.mark.parametrize('sh2', [8, 9])
      @pytest.mark.parametrize('not_axis', [0, 1, 2])
      @skip_xp_backends("cupy", reason="cupy implementation does not have axes kwarg")
      def test_3d_vs_slice(self, sh0, sh1, sh2, not_axis, xp):
          """2d transform on 3d array is equal to 2d transform on 2d slices."""
          volume = xp.reshape(xp.arange(sh0 * sh1 * sh2, dtype=xp.float64),
                              (sh0, sh1, sh2))

          # The two axes that remain after dropping `not_axis`.
          transform_axes = [ax for ax in (0, 1, 2) if ax != not_axis]
          from_3d = hilbert2(volume, axes=transform_axes)

          slices = xp.unstack(volume, axis=not_axis)
          from_2d = xp.stack([hilbert2(s) for s in slices], axis=not_axis)

          xp_assert_close(from_3d, from_2d)

      @skip_xp_backends("cupy", reason="cupy implementation does not have axes kwarg")
      def test_3d_axis_order(self, xp):
          """2d transform on equal arrays with moved axis are equal."""
          base = xp.reshape(xp.arange(5 * 7 * 9, dtype=xp.float64), (5, 7, 9))
          base_as = hilbert2(base)

          # (destination of axis 0, axes to transform in the moved array)
          cases = [(1, (0, 2)), (2, (0, 1))]
          for dest, axes in cases:
              moved = xp.moveaxis(base, 0, dest)
              moved_as = hilbert2(moved, axes=axes)
              restored = xp.moveaxis(moved_as, dest, 0)
              xp_assert_close(base_as, restored)
  2954. @make_xp_test_case(envelope)
  2955. class TestEnvelope:
  2956. """Unit tests for function `._signaltools.envelope()`. """
  @staticmethod
  def assert_close(actual, desired, msg, xp):
      """Little helper to compare two arrays with proper tolerances.

      `atol` and `rtol` are both 1e-12 when the default dtype of `xp` is
      float64 and 1e-5 otherwise; `msg` is passed on as the error message.
      """
      a_r_tol = ({'atol': 1e-12, 'rtol': 1e-12}
                 if xp_default_dtype(xp) == xp.float64
                 else {'atol': 1e-5, 'rtol': 1e-5})
      xp_assert_close(actual, desired, **a_r_tol, err_msg=msg)
  def test_envelope_invalid_parameters(self, xp):
      """For `envelope()` Raise all exceptions that are used to verify function
      parameters. """
      # Build every input with `xp` so that the parameter validation is
      # exercised on the backend under test (this first check previously
      # used a NumPy array regardless of backend).
      with pytest.raises(ValueError,
                         match=r"Invalid parameter axis=2 for z.shape=.*"):
          envelope(xp.ones(3), axis=2)
      with pytest.raises(ValueError,
                         match=r"z.shape\[axis\] not > 0 for z.shape=.*"):
          envelope(xp.ones((3, 0)), axis=1)
      for bp_in in [(0, 1, 2), (0, 2.), (None, 2.)]:
          ts = ', '.join(map(str, bp_in))
          with pytest.raises(ValueError,
                             match=rf"bp_in=\({ts}\) isn't a 2-tuple of.*"):
              # noinspection PyTypeChecker
              envelope(xp.ones(4), bp_in=bp_in)
      with pytest.raises(ValueError,
                         match="n_out=10.0 is not a positive integer or.*"):
          # noinspection PyTypeChecker
          envelope(xp.ones(4), n_out=10.)
      for bp_in in [(-1, 3), (1, 1), (0, 10)]:
          with pytest.raises(ValueError,
                             match=r"`-n//2 <= bp_in\[0\] < bp_in\[1\] <=.*"):
              envelope(xp.ones(4), bp_in=bp_in)
      with pytest.raises(ValueError, match="residual='undefined' not in .*"):
          # noinspection PyTypeChecker
          envelope(xp.ones(4), residual='undefined')
  @skip_xp_backends("jax.numpy", reason="XXX: immutable arrays")
  def test_envelope_verify_parameters(self, xp):
      """Ensure that the various parametrizations produce compatible results. """
      real_dt = xp_default_dtype(xp)
      cplx_dt = xp.complex64 if real_dt == xp.float32 else xp.complex128
      band = (1, 3)

      def spectrum(values):
          # Expected Fourier coefficients, in the complex dtype chosen above.
          return xp.asarray(values, dtype=cplx_dt)

      def env_pair(sig, **kwargs):
          # Envelope and residual of `sig` for the fixed band.
          return xp.unstack(envelope(sig, band, **kwargs))

      Z = xp.asarray([4.0, 2, 2, 3, 0], dtype=real_dt)
      Zr_a = xp.asarray([4.0, 0, 0, 6, 0, 0, 0, 0], dtype=real_dt)
      z = sp_fft.irfft(Z)
      n = z.shape[0]

      # the reference envelope:
      ze2_ref, zr_ref = env_pair(z, residual='all', squared=True)
      self.assert_close(sp_fft.rfft(ze2_ref), spectrum([4, 2, 0, 0, 0]),
                        msg="Envelope calculation error", xp=xp)
      self.assert_close(sp_fft.rfft(zr_ref), spectrum([4, 0, 0, 3, 0]),
                        msg="Residual calculation error", xp=xp)

      # unsquared envelope
      ze_plain, zr_plain = env_pair(z, residual='all', squared=False)
      self.assert_close(
          ze_plain**2, ze2_ref,
          msg="Unsquared versus Squared envelope calculation error", xp=xp)
      self.assert_close(
          zr_plain, zr_ref,
          msg="Unsquared versus Squared residual calculation error", xp=xp)

      # three times as many output samples
      ze2_up, zr_up = env_pair(z, residual='all', squared=True, n_out=3*n)
      self.assert_close(ze2_up[::3], ze2_ref,
                        msg="3x up-sampled envelope calculation error", xp=xp)
      self.assert_close(zr_up[::3], zr_ref,
                        msg="3x up-sampled residual calculation error", xp=xp)

      # lowpass residual
      ze2_lp, zr_lp = env_pair(z, residual='lowpass', squared=True)
      self.assert_close(
          ze2_lp, ze2_ref,
          msg="`residual='lowpass'` envelope calculation error", xp=xp)
      self.assert_close(
          sp_fft.rfft(zr_lp), spectrum([4, 0, 0, 0, 0]),
          msg="`residual='lowpass'` residual calculation error", xp=xp)

      # no residual: only the envelope is returned
      ze2_only = envelope(z, band, residual=None, squared=True)
      self.assert_close(ze2_only, ze2_ref,
                        msg="`residual=None` envelope calculation error", xp=xp)

      # compare complex analytic signal to real version
      Z_a = xp.asarray(Z, copy=True)
      Z_a[1:] *= 2
      z_a = sp_fft.ifft(Z_a, n=n)  # analytic signal of Z
      self.assert_close(xp.real(z_a), z,
                        msg="Reference analytic signal error", xp=xp)
      ze2_a, zr_a = env_pair(z_a, residual='all', squared=True)
      self.assert_close(ze2_a, xp.astype(ze2_ref, cplx_dt),  # dtypes must match
                        msg="Complex envelope calculation error", xp=xp)
      self.assert_close(sp_fft.fft(zr_a), spectrum(Zr_a),
                        msg="Complex residual calculation error", xp=xp)
  @skip_xp_backends("jax.numpy", reason="XXX: immutable arrays")
  @pytest.mark.parametrize(
      " Z, bp_in, Ze2_desired, Zr_desired",
      [([1, 0, 2, 2, 0], (1, None), [4, 2, 0, 0, 0], [1, 0, 0, 0, 0]),
       ([4, 0, 2, 0, 0], (0, None), [4, 0, 2, 0, 0], [0, 0, 0, 0, 0]),
       ([4, 0, 0, 2, 0], (None, None), [4, 0, 0, 2, 0], [0, 0, 0, 0, 0]),
       ([0, 0, 2, 2, 0], (1, 3), [2, 0, 0, 0, 0], [0, 0, 0, 2, 0]),
       ([4, 0, 2, 2, 0], (-3, 3), [4, 0, 2, 0, 0], [0, 0, 0, 2, 0]),
       ([4, 0, 3, 4, 0], (None, 1), [2, 0, 0, 0, 0], [0, 0, 3, 4, 0]),
       ([4, 0, 3, 4, 0], (None, 0), [0, 0, 0, 0, 0], [4, 0, 3, 4, 0])])
  def test_envelope_real_signals(self, Z, bp_in, Ze2_desired, Zr_desired, xp):
      """Test envelope calculation with real-valued test signals.

      All comparisons happen in Fourier space, where the effect of the
      bandpass filter is easy to read off.  The squared envelope can also be
      worked out by hand there: the coefficients of the complex-valued
      Fourier series of the signal follow directly from an FFT, and the
      absolute square of a Fourier series is again a Fourier series.
      """
      def as_spectrum(values):
          # Convert to float64 first and then to complex128.
          as_real = xp.asarray(values, dtype=xp.float64)
          return xp.asarray(as_real, dtype=xp.complex128)

      z = sp_fft.irfft(xp.asarray(Z, dtype=xp.float64))
      want_env = as_spectrum(Ze2_desired)
      want_res = as_spectrum(Zr_desired)

      # Spectra of envelope and residual for both residual modes.
      spectra = {}
      for mode in ('all', 'lowpass'):
          env, res = xp.unstack(envelope(z, bp_in, residual=mode, squared=True))
          spectra[mode] = (sp_fft.rfft(env), sp_fft.rfft(res))

      got_env, got_res = spectra['all']
      self.assert_close(got_env, want_env,
                        msg="Envelope calculation error (residual='all')", xp=xp)
      self.assert_close(got_res, want_res,
                        msg="Residual calculation error (residual='all')", xp=xp)

      # For the lowpass residual, the expected coefficients from the upper
      # band edge onwards are zero.
      upper = bp_in[1]
      if upper is not None:
          want_res[upper:] = 0

      got_env_lp, got_res_lp = spectra['lowpass']
      self.assert_close(got_env_lp, want_env,
                        msg="Envelope calculation error (residual='lowpass')", xp=xp)
      self.assert_close(got_res_lp, want_res,
                        msg="Residual calculation error (residual='lowpass')", xp=xp)
  @skip_xp_backends("jax.numpy", reason="XXX: immutable arrays")
  @pytest.mark.parametrize(
      " Z, bp_in, Ze2_desired, Zr_desired",
      [([0, 5, 0, 5, 0], (None, None), [5, 0, 10, 0, 5], [0, 0, 0, 0, 0]),
       ([1, 5, 0, 5, 2], (-1, 2), [5, 0, 10, 0, 5], [1, 0, 0, 0, 2]),
       ([1, 2, 6, 0, 6, 3], (-1, 2), [0, 6, 0, 12, 0, 6], [1, 2, 0, 0, 0, 3])
       ])
  def test_envelope_complex_signals(self, Z, bp_in, Ze2_desired, Zr_desired, xp):
      """Check the squared envelope and residual of complex-valued signals.

      Only the complex envelope needs testing here, since the handling of
      ``None`` entries in the bandpass filter is already covered by the test
      for real-valued signals. All spectra are given in ``fftshift`` order.
      """
      shifted_spectrum = xp.asarray(Z, dtype=xp.float64)
      env_want = xp.asarray(Ze2_desired, dtype=xp.complex128)
      res_want = xp.asarray(Zr_desired, dtype=xp.complex128)

      # Undo the centering before transforming back to the time domain.
      z_t = sp_fft.ifft(sp_fft.ifftshift(shifted_spectrum))
      env_t, res_t = xp.unstack(envelope(z_t, bp_in, residual='all', squared=True))

      env_got = sp_fft.fftshift(sp_fft.fft(env_t))
      res_got = sp_fft.fftshift(sp_fft.fft(res_t))

      self.assert_close(env_got, env_want,
                        msg="Envelope calculation error", xp=xp)
      self.assert_close(res_got, res_want,
                        msg="Residual calculation error", xp=xp)
  @skip_xp_backends("jax.numpy", reason="XXX: immutable arrays")
  def test_envelope_verify_axis_parameter(self, xp):
      """Check that `axis` selects the signal axis of a two-channel input.

      The same two real-valued signals are processed row-wise (``axis=1``)
      and, after transposing, column-wise (``axis=0``); both must reproduce
      the same reference spectra.
      """
      real_dtype = xp_default_dtype(xp)
      if real_dtype == xp.float32:
          cplx_dtype = xp.complex64
      else:
          cplx_dtype = xp.complex128

      spectra = xp.asarray([[1.0, 0, 2, 2, 0], [7, 0, 4, 4, 0]], dtype=real_dtype)
      z = sp_fft.irfft(spectra)
      env_want = xp.asarray([[4, 2, 0, 0, 0], [16, 8, 0, 0, 0]],
                            dtype=cplx_dtype)
      res_want = xp.asarray([[1, 0, 0, 0, 0], [7, 0, 0, 0, 0]], dtype=cplx_dtype)

      env_rows, res_rows = xp.unstack(envelope(z, squared=True, axis=1))
      env_cols, res_cols = xp.unstack(envelope(z.T, squared=True, axis=0))

      # Transform everything first, then compare against the references.
      time_signals = (env_rows, res_rows, env_cols.T, res_cols.T)
      got = [sp_fft.rfft(sig_) for sig_ in time_signals]
      wanted = (env_want, res_want, env_want, res_want)
      messages = ("2d envelope calculation error",
                  "2d residual calculation error",
                  "Transposed 2d envelope calc. error",
                  "Transposed 2d residual calc. error")
      for actual, desired, msg in zip(got, wanted, messages):
          self.assert_close(actual, desired, msg=msg, xp=xp)
  @skip_xp_backends("jax.numpy", reason="XXX: immutable arrays")
  def test_envelope_verify_axis_parameter_complex(self, xp):
      """Check the `axis` parameter with a two-channel complex-valued input.

      The signals are built from spectra given in ``fftshift`` order and are
      processed row-wise (``axis=1``) as well as, after transposing,
      column-wise (``axis=0``); both must reproduce the reference spectra.
      """
      real_dtype = xp_default_dtype(xp)
      if real_dtype == xp.float32:
          cplx_dtype = xp.complex64
      else:
          cplx_dtype = xp.complex128

      shifted = xp.asarray([[1.0, 5, 0, 5, 2], [1, 10, 0, 10, 2]], dtype=real_dtype)
      z = sp_fft.ifft(sp_fft.ifftshift(shifted, axes=1))
      env_want = xp.asarray([[5, 0, 10, 0, 5], [20, 0, 40, 0, 20]],
                            dtype=cplx_dtype)
      res_want = xp.asarray([[1, 0, 0, 0, 2], [1, 0, 0, 0, 2]], dtype=cplx_dtype)

      options = {'bp_in': (-1, 2), 'residual': 'all', 'squared': True}
      env_rows, res_rows = xp.unstack(envelope(z, axis=1, **options))
      env_cols, res_cols = xp.unstack(envelope(z.T, axis=0, **options))

      # Transform everything first, then compare against the references.
      time_signals = (env_rows, res_rows, env_cols.T, res_cols.T)
      got = [sp_fft.fftshift(sp_fft.fft(sig_), axes=1) for sig_ in time_signals]
      wanted = (env_want, res_want, env_want, res_want)
      messages = ("2d envelope calculation error",
                  "2d residual calculation error",
                  "Transposed 2d envelope calc. error",
                  "Transposed 2d residual calc. error")
      for actual, desired, msg in zip(got, wanted, messages):
          self.assert_close(actual, desired, msg=msg, xp=xp)
  @skip_xp_backends("jax.numpy", reason="XXX: immutable arrays")
  @pytest.mark.parametrize('X', [[4, 0, 0, 1, 2], [4, 0, 0, 2, 1, 2]])
  def test_compare_envelope_hilbert(self, X, xp):
      """Check that `envelope()` without bandpass matches ``abs(hilbert())``."""
      one_sided = xp.asarray(X, dtype=xp.float64)
      x_t = sp_fft.irfft(one_sided)

      from_hilbert = xp.abs(hilbert(x_t))
      from_envelope = envelope(x_t, (None, None), residual=None)

      self.assert_close(from_hilbert, from_envelope,
                        msg="Hilbert-Envelope comparison error", xp=xp)
  def test_nyquist(self):
      """Test behavior when the input is a cosine at the Nyquist frequency.

      Resampling even-length signals requires accounting for the unpaired bin
      at the Nyquist frequency (consult the source code of `resample`). Since
      `envelope` excludes the Nyquist frequency from the envelope calculation,
      only the residuals need to be investigated.
      """
      nyquist_cos = sp_fft.irfft([0, 0, 8])  # = [2, -2, 2, -2]
      upsampled = signal.resample(nyquist_cos, num=6)  # = [2, -1, -1, 2, -1, -1]

      # real-valued case
      env_r, res_r = envelope(nyquist_cos, n_out=6, residual='all')
      # complex-valued case
      env_c, res_c = envelope(nyquist_cos + 0j, n_out=6, residual='all')

      tol = {'atol': 1e-12}
      xp_assert_close(env_r, np.zeros(6), **tol)
      xp_assert_close(res_r, upsampled, **tol)
      xp_assert_close(env_c, np.zeros(6, dtype=env_c.dtype), **tol)
      xp_assert_close(res_c, upsampled.astype(env_c.dtype), **tol)
  class TestPartialFractionExpansion:
      """Tests for `residue`, `residuez`, `invres`, `invresz` and their helpers."""

      @staticmethod
      def assert_rp_almost_equal(r, p, r_true, p_true, decimal=7):
          """Compare residue/pole pairs irrespective of their ordering.

          Each computed ``(r, p)`` pair is matched to a reference pair by
          solving a linear sum assignment on the combined pole/residue
          distance; the matched entries are then compared to `decimal` places.
          """
          xp = array_namespace(r, p)
          r_true = xp.asarray(r_true)
          p_true = xp.asarray(p_true)

          # Pairwise distances: rows are computed pairs, columns are references.
          distance = xp.hypot(abs(p[:, None] - p_true),
                              abs(r[:, None] - r_true))

          rows, cols = linear_sum_assignment(_xp_copy_to_numpy(distance))
          assert_almost_equal(p[rows], p_true[cols], decimal=decimal)
          assert_almost_equal(r[rows], r_true[cols], decimal=decimal)

      @skip_xp_backends(np_only=True)
      def test_compute_factors(self, xp):
          """Check `_compute_factors` with and without ``include_powers``."""
          factors, poly = _compute_factors([1, 2, 3], [3, 2, 1])
          assert len(factors) == 3
          assert_almost_equal(factors[0], np.poly([2, 2, 3]))
          assert_almost_equal(factors[1], np.poly([1, 1, 1, 3]))
          assert_almost_equal(factors[2], np.poly([1, 1, 1, 2, 2]))
          assert_almost_equal(poly, np.poly([1, 1, 1, 2, 2, 3]))

          factors, poly = _compute_factors([1, 2, 3], [3, 2, 1],
                                           include_powers=True)
          assert len(factors) == 6
          assert_almost_equal(factors[0], np.poly([1, 1, 2, 2, 3]))
          assert_almost_equal(factors[1], np.poly([1, 2, 2, 3]))
          assert_almost_equal(factors[2], np.poly([2, 2, 3]))
          assert_almost_equal(factors[3], np.poly([1, 1, 1, 2, 3]))
          assert_almost_equal(factors[4], np.poly([1, 1, 1, 3]))
          assert_almost_equal(factors[5], np.poly([1, 1, 1, 2, 2]))
          assert_almost_equal(poly, np.poly([1, 1, 1, 2, 2, 3]))

      @skip_xp_backends(np_only=True)
      def test_group_poles(self, xp):
          """Check that nearby poles are grouped and counted by `_group_poles`."""
          unique, multiplicity = _group_poles(
              [1.0, 1.001, 1.003, 2.0, 2.003, 3.0], 0.1, 'min')
          xp_assert_close(unique, [1.0, 2.0, 3.0])
          xp_assert_close(multiplicity, [3, 2, 1])

      @make_xp_test_case(residue)
      def test_residue_general(self, xp):
          """Check `residue` against a collection of reference expansions."""
          # Tests are taken from issue #4464; note that poles in scipy are
          # ordered by increasing absolute value, opposite to MATLAB.
          r, p, k = residue(xp.asarray([5, 3, -2, 7]), xp.asarray([-4, 0, 8, 3]))
          assert_almost_equal(r, xp.asarray([1.3320, -0.6653, -1.4167]), decimal=4)
          assert_almost_equal(p, xp.asarray([-0.4093, -1.1644, 1.5737]), decimal=4)
          assert_almost_equal(k, xp.asarray([-1.2500]), decimal=4)

          r, p, k = residue(xp.asarray([-4, 8]), xp.asarray([1, 6, 8]))
          assert_almost_equal(r, xp.asarray([8, -12]))
          assert_almost_equal(p, xp.asarray([-2, -4]))
          assert k.size == 0

          r, p, k = residue(xp.asarray([4, 1]), xp.asarray([1, -1, -2]))
          assert_almost_equal(r, xp.asarray([1, 3]))
          assert_almost_equal(p, xp.asarray([-1, 2]))
          assert k.size == 0

          r, p, k = residue(xp.asarray([4, 3]),
                            xp.asarray([2, -3.4, 1.98, -0.406]))
          self.assert_rp_almost_equal(
              r, p, [-18.125 - 13.125j, -18.125 + 13.125j, 36.25],
              [0.5 - 0.2j, 0.5 + 0.2j, 0.7])
          assert k.size == 0

          r, p, k = residue(xp.asarray([2, 1]), xp.asarray([1, 5, 8, 4]))
          self.assert_rp_almost_equal(r, p, [-1, 1, 3],
                                      [-1, -2, -2])
          assert k.size == 0

          r, p, k = residue(xp.asarray([3, -1.1, 0.88, -2.396, 1.348]),
                            xp.asarray([1, -0.7, -0.14, 0.048]))
          assert_almost_equal(r, xp.asarray([-3, 4, 1]))
          assert_almost_equal(p, xp.asarray([0.2, -0.3, 0.8]))
          assert_almost_equal(k, xp.asarray([3, 1]))

          r, p, k = residue(xp.asarray([1]), xp.asarray([1, 2, -3]))
          assert_almost_equal(r, xp.asarray([0.25, -0.25]))
          assert_almost_equal(p, xp.asarray([1, -3]))
          assert k.size == 0

          r, p, k = residue(xp.asarray([1, 0, -5]), xp.asarray([1, 0, 0, 0, -1]))
          self.assert_rp_almost_equal(r, p,
                                      [1, 1.5j, -1.5j, -1],
                                      [-1, -1j, 1j, 1])
          assert k.size == 0

          r, p, k = residue(xp.asarray([3, 8, 6]), xp.asarray([1, 3, 3, 1]))
          self.assert_rp_almost_equal(r, p, [1, 2, 3],
                                      [-1, -1, -1])
          assert k.size == 0

          r, p, k = residue(xp.asarray([3, -1]), xp.asarray([1, -3, 2]))
          assert_almost_equal(r, xp.asarray([-2, 5]))
          assert_almost_equal(p, xp.asarray([1, 2]))
          assert k.size == 0

          r, p, k = residue(xp.asarray([2, 3, -1]), xp.asarray([1, -3, 2]))
          assert_almost_equal(r, xp.asarray([-4, 13]))
          assert_almost_equal(p, xp.asarray([1, 2]))
          assert_almost_equal(k, xp.asarray([2]))

          r, p, k = residue(xp.asarray([7, 2, 3, -1]), xp.asarray([1, -3, 2]))
          assert_almost_equal(r, xp.asarray([-11, 69]))
          assert_almost_equal(p, xp.asarray([1, 2]))
          assert_almost_equal(k, xp.asarray([7, 23]))

          r, p, k = residue(xp.asarray([2, 3, -1]), xp.asarray([1, -3, 4, -2]))
          self.assert_rp_almost_equal(r, p, [4, -1 + 3.5j, -1 - 3.5j],
                                      [1, 1 - 1j, 1 + 1j])
          assert k.size == 0

      @make_xp_test_case(residue)
      def test_residue_leading_zeros(self, xp):
          """Check that padding with leading zeros does not change `residue`."""
          # Leading zeros in numerator or denominator must not affect the answer.
          r0, p0, k0 = residue(xp.asarray([5, 3, -2, 7]), xp.asarray([-4, 0, 8, 3]))
          r1, p1, k1 = residue(xp.asarray([0, 5, 3, -2, 7]),
                               xp.asarray([-4, 0, 8, 3]))
          r2, p2, k2 = residue(xp.asarray([5, 3, -2, 7]),
                               xp.asarray([0, -4, 0, 8, 3]))
          r3, p3, k3 = residue(xp.asarray([0, 0, 5, 3, -2, 7]),
                               xp.asarray([0, 0, 0, -4, 0, 8, 3]))
          assert_almost_equal(r0, r1)
          assert_almost_equal(r0, r2)
          assert_almost_equal(r0, r3)
          assert_almost_equal(p0, p1)
          assert_almost_equal(p0, p2)
          assert_almost_equal(p0, p3)
          assert_almost_equal(k0, k1)
          assert_almost_equal(k0, k2)
          assert_almost_equal(k0, k3)

      @make_xp_test_case(residue)
      def test_residue_degenerate(self, xp):
          """Check `residue` for zero numerators and a zero denominator."""
          # Several tests for zero numerator and denominator.
          r, p, k = residue(xp.asarray([0, 0]), xp.asarray([1, 6, 8]))
          assert_almost_equal(r, xp.asarray([0, 0]))
          assert_almost_equal(p, xp.asarray([-2, -4]))
          assert k.size == 0

          r, p, k = residue(xp.asarray(0), xp.asarray(1))
          assert r.size == 0
          assert p.size == 0
          assert k.size == 0

          with pytest.raises(ValueError, match="Denominator `a` is zero."):
              # Pass backend arrays (not Python scalars) so that the error
              # path is exercised for the namespace under test, as is done
              # in `test_residuez_degenerate`.
              residue(xp.asarray(1), xp.asarray(0))

      @make_xp_test_case(residuez)
      def test_residuez_general(self, xp):
          """Check `residuez` against a collection of reference expansions."""
          r, p, k = residuez(xp.asarray([1, 6, 6, 2]),
                             xp.asarray([1, -(2 + 1j), (1 + 2j), -1j]))
          self.assert_rp_almost_equal(r, p, [-2+2.5j, 7.5+7.5j, -4.5-12j],
                                      [1j, 1, 1])
          assert_almost_equal(k, xp.asarray([2j]))

          r, p, k = residuez(xp.asarray([1, 2, 1]), xp.asarray([1, -1, 0.3561]))
          self.assert_rp_almost_equal(r, p,
                                      [-0.9041 - 5.9928j, -0.9041 + 5.9928j],
                                      [0.5 + 0.3257j, 0.5 - 0.3257j],
                                      decimal=4)
          assert_almost_equal(k, xp.asarray([2.8082]), decimal=4)

          r, p, k = residuez(xp.asarray([1, -1]), xp.asarray([1, -5, 6]))
          assert_almost_equal(r, xp.asarray([-1, 2]))
          assert_almost_equal(p, xp.asarray([2, 3]))
          assert k.size == 0

          r, p, k = residuez(xp.asarray([2, 3, 4]), xp.asarray([1, 3, 3, 1]))
          self.assert_rp_almost_equal(r, p, [4, -5, 3], [-1, -1, -1])
          assert k.size == 0

          r, p, k = residuez(xp.asarray([1, -10, -4, 4]), xp.asarray([2, -2, -4]))
          assert_almost_equal(r, xp.asarray([0.5, -1.5]))
          assert_almost_equal(p, xp.asarray([-1, 2]))
          assert_almost_equal(k, xp.asarray([1.5, -1]))

          r, p, k = residuez(xp.asarray([18]), xp.asarray([18, 3, -4, -1]))
          self.assert_rp_almost_equal(r, p,
                                      [0.36, 0.24, 0.4], [0.5, -1/3, -1/3])
          assert k.size == 0

          r, p, k = residuez(xp.asarray([2, 3]),
                             xp.asarray(np.polymul([1, -1/2], [1, 1/4])))
          assert_almost_equal(r, xp.asarray([-10/3, 16/3]))
          assert_almost_equal(p, xp.asarray([-0.25, 0.5]))
          assert k.size == 0

          r, p, k = residuez(xp.asarray([1, -2, 1]), xp.asarray([1, -1]))
          assert_almost_equal(r, xp.asarray([0]))
          assert_almost_equal(p, xp.asarray([1]))
          assert_almost_equal(k, xp.asarray([1, -1]))

          r, p, k = residuez(xp.asarray(1), xp.asarray([1, -1j]))
          assert_almost_equal(r, xp.asarray([1]))
          assert_almost_equal(p, xp.asarray([1j]))
          assert k.size == 0

          r, p, k = residuez(xp.asarray(1), xp.asarray([1, -1, 0.25]))
          assert_almost_equal(r, xp.asarray([0, 1]))
          assert_almost_equal(p, xp.asarray([0.5, 0.5]))
          assert k.size == 0

          r, p, k = residuez(xp.asarray(1), xp.asarray([1, -0.75, .125]))
          assert_almost_equal(r, xp.asarray([-1, 2]))
          assert_almost_equal(p, xp.asarray([0.25, 0.5]))
          assert k.size == 0

          r, p, k = residuez(xp.asarray([1, 6, 2]), xp.asarray([1, -2, 1]))
          assert_almost_equal(r, xp.asarray([-10, 9]))
          assert_almost_equal(p, xp.asarray([1, 1]))
          assert_almost_equal(k, xp.asarray([2]))

          r, p, k = residuez(xp.asarray([6, 2]), xp.asarray([1, -2, 1]))
          assert_almost_equal(r, xp.asarray([-2, 8]))
          assert_almost_equal(p, xp.asarray([1, 1]))
          assert k.size == 0

          r, p, k = residuez(xp.asarray([1, 6, 6, 2]), xp.asarray([1, -2, 1]))
          assert_almost_equal(r, xp.asarray([-24, 15]))
          assert_almost_equal(p, xp.asarray([1, 1]))
          assert_almost_equal(k, xp.asarray([10, 2]))

          r, p, k = residuez(xp.asarray([1, 0, 1]), xp.asarray([1, 0, 0, 0, 0, -1]))
          self.assert_rp_almost_equal(r, p,
                                      [0.2618 + 0.1902j, 0.2618 - 0.1902j,
                                       0.4, 0.0382 - 0.1176j, 0.0382 + 0.1176j],
                                      [-0.8090 + 0.5878j, -0.8090 - 0.5878j,
                                       1.0, 0.3090 + 0.9511j, 0.3090 - 0.9511j],
                                      decimal=4)
          assert k.size == 0

      @make_xp_test_case(residuez)
      def test_residuez_trailing_zeros(self, xp):
          """Check that padding with trailing zeros does not change `residuez`."""
          # Trailing zeros in numerator or denominator must not affect the
          # answer.
          r0, p0, k0 = residuez(xp.asarray([5, 3, -2, 7]),
                                xp.asarray([-4, 0, 8, 3]))
          r1, p1, k1 = residuez(xp.asarray([5, 3, -2, 7, 0]),
                                xp.asarray([-4, 0, 8, 3]))
          r2, p2, k2 = residuez(xp.asarray([5, 3, -2, 7]),
                                xp.asarray([-4, 0, 8, 3, 0]))
          r3, p3, k3 = residuez(xp.asarray([5, 3, -2, 7, 0, 0]),
                                xp.asarray([-4, 0, 8, 3, 0, 0, 0]))
          assert_almost_equal(r0, r1)
          assert_almost_equal(r0, r2)
          assert_almost_equal(r0, r3)
          assert_almost_equal(p0, p1)
          assert_almost_equal(p0, p2)
          assert_almost_equal(p0, p3)
          assert_almost_equal(k0, k1)
          assert_almost_equal(k0, k2)
          assert_almost_equal(k0, k3)

      @make_xp_test_case(residuez)
      def test_residuez_degenerate(self, xp):
          """Check `residuez` for zero numerators and invalid denominators."""
          r, p, k = residuez(xp.asarray([0, 0]), xp.asarray([1, 6, 8]))
          assert_almost_equal(r, xp.asarray([0, 0]))
          assert_almost_equal(p, xp.asarray([-2, -4]))
          assert k.size == 0

          r, p, k = residuez(xp.asarray(0), xp.asarray(1))
          assert r.size == 0
          assert p.size == 0
          assert k.size == 0

          with pytest.raises(ValueError, match="Denominator `a` is zero."):
              residuez(xp.asarray(1), xp.asarray(0))

          with pytest.raises(ValueError,
                             match="First coefficient of determinant `a` must "
                                   "be non-zero."):
              residuez(xp.asarray(1), xp.asarray([0, 1, 2, 3]))

      @make_xp_test_case(invres, invresz)
      def test_inverse_unique_roots_different_rtypes(self, xp):
          """Check `invres`/`invresz` with distinct poles for every `rtype`."""
          # This test was inspired by github issue 2496.
          r = xp.asarray([3 / 10, -1 / 6, -2 / 15])
          p = xp.asarray([0, -2, -5])
          k = xp.asarray([])
          b_expected = xp.asarray([0.0, 1, 3])
          a_expected = xp.asarray([1, 7, 10, 0])

          # With the default tolerance, the rtype does not matter
          # for this example.
          for rtype in ('avg', 'mean', 'min', 'minimum', 'max', 'maximum'):
              b, a = invres(r, p, k, rtype=rtype)
              xp_assert_close(b, b_expected, atol=5e-16)
              xp_assert_close(a, a_expected, check_dtype=False, atol=5e-16)

              b, a = invresz(r, p, k, rtype=rtype)
              xp_assert_close(b, b_expected, atol=5e-16)
              xp_assert_close(a, a_expected, check_dtype=False, atol=5e-16)

      @make_xp_test_case(invres, invresz)
      def test_inverse_repeated_roots_different_rtypes(self, xp):
          """Check `invres`/`invresz` with a repeated pole for every `rtype`."""
          r = xp.asarray([3 / 20, -7 / 36, -1 / 6, 2 / 45])
          p = xp.asarray([0, -2, -2, -5])
          k = xp.asarray([])
          b_expected = xp.asarray([0.0, 0, 1, 3])
          b_expected_z = xp.asarray([-1/6, -2/3, 11/6, 3])
          a_expected = xp.asarray([1, 9, 24, 20, 0])

          for rtype in ('avg', 'mean', 'min', 'minimum', 'max', 'maximum'):
              b, a = invres(r, p, k, rtype=rtype)
              xp_assert_close(b, b_expected, atol=1e-14)
              xp_assert_close(a, a_expected, check_dtype=False)

              b, a = invresz(r, p, k, rtype=rtype)
              xp_assert_close(b, b_expected_z, atol=1e-14)
              xp_assert_close(a, a_expected, check_dtype=False)

      @make_xp_test_case(invres, invresz)
      def test_inverse_bad_rtype(self, xp):
          """Check that an unknown `rtype` raises a `ValueError`."""
          r = xp.asarray([3 / 20, -7 / 36, -1 / 6, 2 / 45])
          p = xp.asarray([0, -2, -2, -5])
          k = xp.asarray([])
          with pytest.raises(ValueError, match="`rtype` must be one of"):
              invres(r, p, k, rtype='median')
          with pytest.raises(ValueError, match="`rtype` must be one of"):
              invresz(r, p, k, rtype='median')

      @make_xp_test_case(invresz)
      def test_invresz_one_coefficient_bug(self, xp):
          """Check `invresz` with one residue, one pole and one direct term."""
          # Regression test for issue in gh-4646.
          r = xp.asarray([1])
          p = xp.asarray([2])
          k = xp.asarray([0])
          b, a = invresz(r, p, k)
          xp_assert_close(b, xp.asarray([1]))
          xp_assert_close(a, xp.asarray([1.0, -2.0]))

      @make_xp_test_case(invres)
      def test_invres(self, xp):
          """Check `invres` against reference numerators and denominators."""
          b, a = invres(xp.asarray([1]), xp.asarray([1]), xp.asarray([]))
          assert_almost_equal(b, xp.asarray([1]))
          assert_almost_equal(a, xp.asarray([1, -1]))

          b, a = invres(xp.asarray([1 - 1j, 2, 0.5 - 3j]),
                        xp.asarray([1, 0.5j, 1 + 1j]), xp.asarray([]))
          assert_almost_equal(b, xp.asarray([3.5 - 4j, -8.5 + 0.25j, 3.5 + 3.25j]))
          assert_almost_equal(a, xp.asarray([1, -2 - 1.5j, 0.5 + 2j, 0.5 - 0.5j]))

          b, a = invres(xp.asarray([0.5, 1]), xp.asarray([1 - 1j, 2 + 2j]),
                        xp.asarray([1, 2, 3]))
          assert_almost_equal(b, xp.asarray([1, -1 - 1j, 1 - 2j, 0.5 - 3j, 10]))
          assert_almost_equal(a, xp.asarray([1, -3 - 1j, 4]))

          b, a = invres(xp.asarray([-1, 2, 1j, 3 - 1j, 4, -2]),
                        xp.asarray([-1, 2 - 1j, 2 - 1j, 3, 3, 3]), xp.asarray([]))
          assert_almost_equal(b,
                              xp.asarray([4 - 1j, -28 + 16j, 40 - 62j, 100 + 24j,
                                          -292 + 219j, 192 - 268j]))
          assert_almost_equal(a,
                              xp.asarray([1, -12 + 2j, 53 - 20j, -96 + 68j, 27 - 72j,
                                          108 - 54j, -81 + 108j]))

          b, a = invres(xp.asarray([-1, 1j]), xp.asarray([1, 1]),
                        xp.asarray([1, 2]))
          assert_almost_equal(b, xp.asarray([1, 0, -4, 3 + 1j]))
          assert_almost_equal(a, xp.asarray([1, -2, 1]))

      @make_xp_test_case(invresz)
      def test_invresz(self, xp):
          """Check `invresz` against reference numerators and denominators."""
          b, a = invresz(xp.asarray([1]), xp.asarray([1]), xp.asarray([]))
          assert_almost_equal(b, xp.asarray([1]))
          assert_almost_equal(a, xp.asarray([1, -1]))

          b, a = invresz(xp.asarray([1 - 1j, 2, 0.5 - 3j]),
                         xp.asarray([1, 0.5j, 1 + 1j]), xp.asarray([]))
          assert_almost_equal(b, xp.asarray([3.5 - 4j, -8.5 + 0.25j, 3.5 + 3.25j]))
          assert_almost_equal(a, xp.asarray([1, -2 - 1.5j, 0.5 + 2j, 0.5 - 0.5j]))

          b, a = invresz(xp.asarray([0.5, 1]),
                         xp.asarray([1 - 1j, 2 + 2j]),
                         xp.asarray([1, 2, 3]))
          assert_almost_equal(b, xp.asarray([2.5, -3 - 1j, 1 - 2j, -1 - 3j, 12]))
          assert_almost_equal(a, xp.asarray([1, -3 - 1j, 4]))

          b, a = invresz(xp.asarray([-1, 2, 1j, 3 - 1j, 4, -2]),
                         xp.asarray([-1, 2 - 1j, 2 - 1j, 3, 3, 3]),
                         xp.asarray([]))
          assert_almost_equal(b,
                              xp.asarray([6, -50 + 11j, 100 - 72j, 80 + 58j,
                                          -354 + 228j, 234 - 297j]))
          assert_almost_equal(a,
                              xp.asarray([1, -12 + 2j, 53 - 20j, -96 + 68j, 27 - 72j,
                                          108 - 54j, -81 + 108j]))

          b, a = invresz(xp.asarray([-1, 1j]),
                         xp.asarray([1, 1]),
                         xp.asarray([1, 2]))
          assert_almost_equal(b, xp.asarray([1j, 1, -3, 2]))
          assert_almost_equal(a, xp.asarray([1, -2, 1]))

      @skip_xp_backends(np_only=True)
      @make_xp_test_case(invres, invresz)
      def test_inverse_scalar_arguments(self, xp):
          """Check that `invres`/`invresz` accept plain Python scalars."""
          b, a = invres(1, 1, 1)
          assert_almost_equal(b, [1, 0])
          assert_almost_equal(a, [1, -1])

          b, a = invresz(1, 1, 1)
          assert_almost_equal(b, [2, -1])
          assert_almost_equal(a, [1, -1])
  @make_xp_test_case(vectorstrength)
  class TestVectorstrength:
      """Tests for `vectorstrength` with scalar and one-dimensional periods."""

      def test_single_1dperiod(self, xp):
          """One event and a scalar period: expect 0-d strength and phase."""
          spikes = xp.asarray([.5])
          cycle = 5.
          want_strength = 1.
          want_phase = .1

          strength, phase = vectorstrength(spikes, cycle)

          assert strength.ndim == 0
          assert phase.ndim == 0
          assert math.isclose(strength, want_strength, abs_tol=1.5e-7)
          assert math.isclose(phase, 2 * math.pi * want_phase, abs_tol=1.5e-7)

      @xfail_xp_backends('torch', reason="phase modulo 2*pi")
      def test_single_2dperiod(self, xp):
          """One event and an array of periods: expect 1-d strength and phase."""
          spikes = xp.asarray([.5])
          cycles = xp.asarray([1, 2, 5.])
          want_strength = xp.asarray([1.] * 3)
          want_phase = xp.asarray([.5, .25, .1])

          strength, phase = vectorstrength(spikes, cycles)

          assert strength.ndim == 1
          assert phase.ndim == 1
          assert_array_almost_equal(strength, want_strength)
          assert_almost_equal(phase, 2 * xp.pi * want_phase)

      def test_equal_1dperiod(self, xp):
          """Identical events and a scalar period: expect full strength."""
          spikes = xp.asarray([.25, .25, .25, .25, .25, .25])
          cycle = 2
          want_strength = 1.
          want_phase = .125

          strength, phase = vectorstrength(spikes, cycle)

          assert strength.ndim == 0
          assert phase.ndim == 0
          assert math.isclose(strength, want_strength, abs_tol=1.5e-7)
          assert math.isclose(phase, 2 * math.pi * want_phase, abs_tol=1.5e-7)

      def test_equal_2dperiod(self, xp):
          """Identical events and an array of periods: expect full strength."""
          spikes = xp.asarray([.25, .25, .25, .25, .25, .25])
          cycles = xp.asarray([1, 2, ])
          want_strength = xp.asarray([1.] * 2)
          want_phase = xp.asarray([.25, .125])

          strength, phase = vectorstrength(spikes, cycles)

          assert strength.ndim == 1
          assert phase.ndim == 1
          assert_almost_equal(strength, want_strength)
          assert_almost_equal(phase, 2 * xp.pi * want_phase)

      def test_spaced_1dperiod(self, xp):
          """Events spaced by whole periods (scalar period): full strength."""
          spikes = xp.asarray([.1, 1.1, 2.1, 4.1, 10.1])
          cycle = 1
          want_strength = 1.
          want_phase = .1

          strength, phase = vectorstrength(spikes, cycle)

          assert strength.ndim == 0
          assert phase.ndim == 0
          assert math.isclose(strength, want_strength, abs_tol=1.5e-7)
          assert math.isclose(phase, 2 * math.pi * want_phase, abs_tol=1.5e-6)

      def test_spaced_2dperiod(self, xp):
          """Events spaced by whole periods (array of periods): full strength."""
          spikes = xp.asarray([.1, 1.1, 2.1, 4.1, 10.1])
          cycles = xp.asarray([1, .5])
          want_strength = xp.asarray([1.] * 2)
          want_phase = xp.asarray([.1, .2])

          strength, phase = vectorstrength(spikes, cycles)

          assert strength.ndim == 1
          assert phase.ndim == 1
          assert_almost_equal(strength, want_strength)
          # Use an explicit relative tolerance when the default dtype is float32.
          if xp_default_dtype(xp) == xp.float32:
              tolerance = {'rtol': 2e-6}
          else:
              tolerance = {}
          xp_assert_close(phase, 2 * xp.pi * want_phase, **tolerance)

      def test_partial_1dperiod(self, xp):
          """Partially aligned events and a scalar period: strength 1/3."""
          spikes = xp.asarray([.25, .5, .75])
          cycle = 1
          want_strength = 1. / 3.
          want_phase = .5

          strength, phase = vectorstrength(spikes, cycle)

          assert strength.ndim == 0
          assert phase.ndim == 0
          assert math.isclose(strength, want_strength)
          assert math.isclose(phase, 2 * math.pi * want_phase)

      @xfail_xp_backends("torch", reason="phase modulo 2*pi")
      def test_partial_2dperiod(self, xp):
          """Partially aligned events and an array of periods: strength 1/3."""
          spikes = xp.asarray([.25, .5, .75])
          cycles = xp.asarray([1., 1., 1., 1.])
          want_strength = xp.asarray([1. / 3.] * 4)
          want_phase = xp.asarray([.5, .5, .5, .5])

          strength, phase = vectorstrength(spikes, cycles)

          assert strength.ndim == 1
          assert phase.ndim == 1
          assert_almost_equal(strength, want_strength)
          assert_almost_equal(phase, 2 * xp.pi * want_phase)

      def test_opposite_1dperiod(self, xp):
          """Evenly spread events and a scalar period: expect zero strength."""
          spikes = xp.asarray([0, .25, .5, .75])
          cycle = 1.
          want_strength = 0

          strength, phase = vectorstrength(spikes, cycle)

          assert strength.ndim == 0
          assert phase.ndim == 0
          assert math.isclose(strength, want_strength, abs_tol=1.5e-7)

      def test_opposite_2dperiod(self, xp):
          """Evenly spread events and an array of periods: zero strength."""
          spikes = xp.asarray([0, .25, .5, .75])
          cycles = xp.asarray([1.] * 10)
          want_strength = xp.asarray([0.] * 10)

          strength, phase = vectorstrength(spikes, cycles)

          assert strength.ndim == 1
          assert phase.ndim == 1
          assert_almost_equal(strength, want_strength)

      def test_2d_events_ValueError(self, xp):
          """Two-dimensional events must raise a `ValueError`."""
          spikes = xp.asarray([[1, 2]])
          cycle = 1.
          with assert_raises(ValueError):
              vectorstrength(spikes, cycle)

      def test_2d_period_ValueError(self, xp):
          """A two-dimensional period must raise a `ValueError`."""
          spikes = 1.
          cycles = xp.asarray([[1]])
          with assert_raises(ValueError):
              vectorstrength(spikes, cycles)

      def test_zero_period_ValueError(self, xp):
          """A period of zero must raise a `ValueError`."""
          spikes = 1.
          cycle = 0
          with assert_raises(ValueError):
              vectorstrength(spikes, cycle)

      def test_negative_period_ValueError(self, xp):
          """A negative period must raise a `ValueError`."""
          spikes = 1.
          cycle = -1
          with assert_raises(ValueError):
              vectorstrength(spikes, cycle)
  3625. # XXX: restore testing on CuPy, where possible. Multiple issues in this test:
  3626. # 1. _zi functions deliberately incompatible in cupy
  3627. # (https://github.com/scipy/scipy/pull/21713#issuecomment-2417494443)
  3628. # 2. a CuPy issue to be fixed in 14.0 only
  3629. # (https://github.com/cupy/cupy/pull/8677)
  3630. # 3. an issue with CuPy's __array__ not numpy-2.0 compatible
  3631. @skip_xp_backends(cpu_only=True)
  3632. @make_xp_test_case(sosfilt)
  3633. @pytest.mark.parametrize('dt', ['float32', 'float64', 'complex64', 'complex128'])
  3634. class TestSOSFilt:
  3635. # The test_rank* tests are pulled from _TestLinearFilter
  @skip_xp_backends('jax.numpy', reason='buffer array is read-only')
  def test_rank1(self, dt, xp):
      """Check `sosfilt` on one-dimensional input for simple filters."""
      dtype = getattr(xp, dt)
      ramp = xp.linspace(0, 5, 6, dtype=dtype)

      cases = (
          # Test simple IIR
          ([1, -1], [0.5, -0.5], [0, 2, 4, 6, 8, 10.]),
          # Test simple FIR
          # NOTE: This was changed (rel. to TestLinear...) to add a pole @zero:
          ([1, 1], [1, 0], [0, 1, 3, 5, 7, 9.]),
      )
      for num, den, want in cases:
          b = xp.asarray(num, dtype=dtype)
          a = xp.asarray(den, dtype=dtype)
          y_r = xp.asarray(want, dtype=dtype)
          # XXX while tf2sos is numpy only
          sos = xp.asarray(tf2sos(np.asarray(b), np.asarray(a)))
          assert_array_almost_equal(sosfilt(sos, ramp), y_r)

      # A single section assembled by hand, applied to a constant input.
      numerator = xp.asarray([1.0, 1, 0])
      denominator = xp.asarray([1.0, 0, 0])
      section = xp.reshape(xp.concat((numerator, denominator)), (1, 6))
      y = sosfilt(section, xp.ones(8))
      xp_assert_close(y, xp.asarray([1.0, 2, 2, 2, 2, 2, 2, 2]))
  @skip_xp_backends('jax.numpy', reason='buffer array is read-only')
  def test_rank2(self, dt, xp):
      """Check `sosfilt` along each axis of a two-dimensional input."""
      dtype = getattr(xp, dt)
      shape = (4, 3)
      n_samples = math.prod(shape)
      x = xp.reshape(xp.linspace(0, n_samples - 1, n_samples, dtype=dtype), shape)

      b = xp.asarray([1, -1], dtype=dtype)
      a = xp.asarray([0.5, 0.5], dtype=dtype)

      # Reference outputs, indexed by the axis that is filtered along.
      expected = (
          xp.asarray([[0, 2, 4], [6, 4, 2], [0, 2, 4], [6, 4, 2]],
                     dtype=dtype),
          xp.asarray([[0, 2, 0], [6, -4, 6], [12, -10, 12],
                      [18, -16, 18]], dtype=dtype),
      )

      b_np, a_np = np.asarray(b), np.asarray(a)
      for axis, want in enumerate(expected):
          sos = xp.asarray(tf2sos(b_np, a_np))  # XXX
          got = sosfilt(sos, x, axis=axis)
          assert_array_almost_equal(want, got)
  @skip_xp_backends('jax.numpy', reason='buffer array is read-only')
  def test_rank3(self, dt, xp):
      """Check `sosfilt` along the last axis of 3-d input against `lfilter`."""
      dtype = getattr(xp, dt)
      shape = (4, 3, 2)
      n_samples = math.prod(shape)
      x = xp.reshape(xp.linspace(0, n_samples - 1, n_samples), shape)

      b = xp.asarray([1, -1], dtype=dtype)
      a = xp.asarray([0.5, 0.5], dtype=dtype)

      # Test last axis
      # XXX until tf2sos is array api compatible
      sos = xp.asarray(tf2sos(np.asarray(b), np.asarray(a)))
      y = sosfilt(sos, x)

      # Each 1-d slice along the last axis must match `lfilter`'s output.
      n_outer, n_inner = x.shape[0], x.shape[1]
      for i in range(n_outer):
          for j in range(n_inner):
              reference = lfilter(b, a, x[i, j, ...])
              assert_array_almost_equal(y[i, j, ...], reference)
  def _get_ab_sos(self, xp):
      """Build one three-section low-pass filter in both TF and SOS form.

      Three second-order Butterworth designs (cutoffs 0.25, 0.75, 0.75) are
      stacked as SOS rows, and their numerators/denominators are convolved
      into a single transfer function.

      Returns
      -------
      a, b, sos
          Denominator, numerator (in that order) and the ``(3, 6)`` SOS
          array, each converted with ``xp.asarray``.
      """
      designs = [
          signal.butter(2, 0.25, 'low'),
          signal.butter(2, 0.75, 'low'),
          signal.butter(2, 0.75, 'low'),
      ]
      (b1, a1), (b2, a2), (b3, a3) = designs

      # Cascading sections multiplies their polynomials, i.e. convolves
      # the coefficient vectors.
      num = np.convolve(np.convolve(b1, b2), b3)
      den = np.convolve(np.convolve(a1, a2), a3)

      # One row per section: numerator coefficients followed by denominator.
      sos_np = np.array([np.r_[bk, ak] for bk, ak in designs])

      return xp.asarray(den), xp.asarray(num), xp.asarray(sos_np)
  @skip_xp_backends('jax.numpy', reason='item assignment')
  def test_initial_conditions(self, dt, xp):
      """Check ``zi`` handling of `sosfilt` against `lfilter`.

      Two things are verified:

      * filtering a signal in two chunks, feeding the final state of the
        first chunk in as ``zi`` of the second, gives the same output for
        `sosfilt` as for `lfilter` (which in turn matches one-shot `lfilter`);
      * starting from ``sosfilt_zi(sos)``, an all-ones input yields an
        all-ones output and leaves the state unchanged.
      """
      a, b, sos = self._get_ab_sos(xp)

      # Use a locally seeded generator instead of the global NumPy random
      # state so that the input, and hence any failure, is reproducible.
      rng = np.random.RandomState(1234)
      x = rng.rand(50).astype(dt)
      x = xp.asarray(x)
      dt = getattr(xp, dt)

      # Stopping filtering and continuing
      y_true, zi = lfilter(b, a, x[:20], zi=xp.zeros(6))
      y_true = xp.concat((y_true, lfilter(b, a, x[20:], zi=zi)[0]))
      xp_assert_close(y_true, lfilter(b, a, x))

      y_sos, zi = sosfilt(sos, x[:20], zi=xp.zeros((3, 2)))
      y_sos = xp.concat((y_sos, sosfilt(sos, x[20:], zi=zi)[0]))
      xp_assert_close(y_true, y_sos)

      # Use a step function
      zi = sosfilt_zi(sos)
      x = xp.ones(8, dtype=dt)
      y, zf = sosfilt(sos, x, zi=zi)

      xp_assert_close(y, xp.ones(8), check_dtype=False)
      xp_assert_close(zf, zi, check_dtype=False)
  @skip_xp_backends('jax.numpy', reason='item assignment')
  @skip_xp_backends('array_api_strict', reason='fancy indexing not supported')
  def test_initial_conditions_2(self, dt, xp):
      """Check that ``zi`` must match the shape of an n-D input to `sosfilt`."""
      dtype = getattr(xp, dt)
      _, _, sos = self._get_ab_sos(xp)
      step_zi = sosfilt_zi(sos)

      # Initial condition shape matching
      ones_3d = xp.reshape(xp.ones(8, dtype=dtype), (1, 1, 8))  # 3D

      # The 2-D ``zi`` does not fit a 3-D input.
      with pytest.raises(ValueError):
          sosfilt(sos, ones_3d, zi=step_zi)

      n_sections, n_state = step_zi.shape[0], step_zi.shape[-1]
      zi_4d = xp.reshape(xp_copy(step_zi, xp=xp), (n_sections, 1, 1, n_state))

      # A last axis of length 3 is rejected as well.
      with pytest.raises(ValueError):
          sosfilt(sos, ones_3d, zi=zi_4d[:, :, :, [0, 1, 1]])

      # With a correctly shaped ``zi`` the step input passes through unchanged.
      out, final_state = sosfilt(sos, ones_3d, zi=zi_4d)
      xp_assert_close(out[0, 0], xp.ones(8), check_dtype=False)
      xp_assert_close(final_state[:, 0, 0, :], step_zi, check_dtype=False)
  @skip_xp_backends('jax.numpy', reason='item assignment')
  def test_initial_conditions_3d_axis1(self, dt, xp):
      """Exercise ``zi`` when `sosfilt` filters along axis 1 of a 3-D input."""
      # Filter along this axis.
      filt_axis = 1
      tol = dict(rtol=1e-10, atol=1e-13)

      # Input array: seeded random integers cast to the parametrized dtype.
      data = np.random.RandomState(159).randint(0, 5, size=(2, 15, 3)).astype(dt)
      data = xp.asarray(data)

      # Design a filter in ZPK format and convert to SOS
      zpk = signal.butter(6, 0.35, output='zpk')
      sos = xp.asarray(zpk2sos(*zpk))  # XXX while zpk2sos is numpy-only
      n_sections = sos.shape[0]

      # Initial conditions, all zeros: the input shape with the filtered
      # axis replaced by 2, prefixed with the number of sections.
      state_shape = (
          n_sections, *data.shape[:filt_axis], 2, *data.shape[filt_axis + 1:]
      )
      zero_state = xp.zeros(state_shape)

      # Apply the filter to the whole input.
      y_full, z_full = sosfilt(sos, data, axis=filt_axis, zi=zero_state)

      # Apply the filter in two stages, carrying the state across.
      y_head, z_head = sosfilt(sos, data[:, :5, :], axis=filt_axis, zi=zero_state)
      y_tail, z_tail = sosfilt(sos, data[:, 5:, :], axis=filt_axis, zi=z_head)

      # The joined output and the final state must equal the one-pass results.
      y_joined = xp.concat((y_head, y_tail), axis=filt_axis)
      xp_assert_close(y_joined, y_full, **tol)
      xp_assert_close(z_tail, z_full, **tol)

      # let's try the "step" initial condition
      first_sample = data[:, 0:1, :]
      sos_state = xp.reshape(sosfilt_zi(sos), (n_sections, 1, 2, 1))
      sos_state = sos_state * first_sample
      y_sos = sosfilt(sos, data, axis=filt_axis, zi=sos_state)[0]

      # check it against the TF form
      b, a = zpk2tf(*zpk)
      b, a = xp.asarray(b), xp.asarray(a)  # XXX while zpk2tf is numpy-only
      tf_state = lfilter_zi(b, a)
      tf_state = xp.reshape(tf_state, (1, xp_size(tf_state), 1))
      tf_state = tf_state * first_sample
      y_tf = lfilter(b, a, data, axis=filt_axis, zi=tf_state)[0]
      xp_assert_close(y_sos, y_tf, **tol)
  @skip_xp_backends('torch', reason='issues a RuntimeWarning')
  @skip_xp_backends('jax.numpy', reason='item assignment')
  def test_bad_zi_shape(self, dt, xp):
      """Check the two ``ValueError`` messages `sosfilt` raises for bad arguments."""
      dtype = getattr(xp, dt)
      # The shape of zi is checked before using any values in the
      # arguments, so xp.empty is fine for creating the arguments.
      signal_in = xp.empty((3, 15, 3), dtype=dtype)
      sections = xp.zeros((4, 6))
      bad_zi = xp.empty((4, 3, 3, 2))  # Correct shape is (4, 3, 2, 3)
      call_kwargs = dict(zi=bad_zi, axis=1)

      # Column 3 of ``sections`` is still zero at this point.
      with pytest.raises(ValueError, match='should be all ones'):
          sosfilt(sections, signal_in, **call_kwargs)

      # With column 3 set to one, the zi shape error is reported instead.
      sections[:, 3] = 1.
      with pytest.raises(ValueError, match='Invalid zi shape'):
          sosfilt(sections, signal_in, **call_kwargs)
  @skip_xp_backends('jax.numpy', reason='item assignment')
  def test_sosfilt_zi(self, dt, xp):
      """Check that ``sosfilt_zi`` is a fixed point for an all-ones input."""
      dtype = getattr(xp, dt)
      sos = xp.asarray(
          signal.butter(6, 0.2, output='sos')
      )  # XXX while butter is np-only
      initial_state = sosfilt_zi(sos)

      step = xp.ones(40, dtype=dtype)
      response, final_state = sosfilt(sos, step, zi=initial_state)
      xp_assert_close(final_state, initial_state, rtol=1e-13, check_dtype=False)

      # Expected steady state value of the step response of this filter:
      # product over sections of sum(numerator) / sum(denominator).
      numerator_sums = xp.sum(sos[:, :3], axis=-1)
      denominator_sums = xp.sum(sos[:, 3:], axis=-1)
      steady_state = xp.prod(numerator_sums / denominator_sums)
      xp_assert_close(response, steady_state * xp.ones_like(response), rtol=1e-13)
  @skip_xp_backends(np_only=True)
  def test_sosfilt_zi_2(self, dt, xp):
      """Same fixed-point check as above, but ``zi`` is passed as a nested list."""
      # zi as array-like
      dtype = getattr(xp, dt)
      sos = xp.asarray(
          signal.butter(6, 0.2, output='sos')
      )  # XXX while butter is np-only
      initial_state = sosfilt_zi(sos)

      step = xp.ones(40, dtype=dtype)
      _, final_state = sosfilt(sos, step, zi=initial_state.tolist())
      xp_assert_close(final_state, initial_state, rtol=1e-13, check_dtype=False)
  @make_xp_test_case(signal.deconvolve)
  class TestDeconvolve:
      """Tests for `signal.deconvolve`."""

      @skip_xp_backends(np_only=True, reason="list inputs are numpy-specific")
      def test_array_like(self, xp):
          """Plain Python lists are accepted for the divisor and the expected value."""
          # From docstring example: with lists
          expected = [0.0, 1, 0, 0, 1, 1, 0, 0]
          kernel = [2, 1]
          observed = xp.asarray([0.0, 2, 1, 0, 2, 3, 1, 0, 0])
          recovered, _remainder = signal.deconvolve(observed, kernel)
          xp_assert_close(recovered, expected)

      def test_basic(self, xp):
          """Deconvolving the recorded signal recovers the original one."""
          # From docstring example
          expected = xp.asarray([0.0, 1, 0, 0, 1, 1, 0, 0], dtype=xp.float64)
          kernel = xp.asarray([2, 1])
          observed = xp.asarray([0.0, 2, 1, 0, 2, 3, 1, 0, 0])
          recovered, _remainder = signal.deconvolve(observed, kernel)
          xp_assert_close(recovered, expected)

      @xfail_xp_backends("cupy", reason="different error message")
      def test_n_dimensional_signal(self, xp):
          """A 2-D ``signal`` argument raises ``ValueError``."""
          observed = xp.asarray([[0, 0], [0, 0]])
          kernel = xp.asarray([0, 0])
          pattern = "^Parameter signal must be non-empty"
          with pytest.raises(ValueError, match=pattern):
              _quotient, _remainder = signal.deconvolve(observed, kernel)

      @xfail_xp_backends("cupy", reason="different error message")
      def test_n_dimensional_divisor(self, xp):
          """A 2-D ``divisor`` argument raises ``ValueError``."""
          observed = xp.asarray([0, 0])
          kernel = xp.asarray([[0, 0], [0, 0]])
          pattern = "^Parameter divisor must be non-empty"
          with pytest.raises(ValueError, match=pattern):
              _quotient, _remainder = signal.deconvolve(observed, kernel)

      def test_divisor_greater_signal(self, xp):
          """Return signal as `remainder` when ``len(divisor) > len(signal)``."""
          dividend = xp.asarray([0, 1, 2])
          divisor = xp.asarray([0, 1, 2, 4, 5])
          quotient, remainder = signal.deconvolve(dividend, divisor)
          xp_assert_equal(remainder, dividend)
          # The quotient is expected to be empty in this case.
          assert xp_size(xp.asarray(quotient)) == 0
  @make_xp_test_case(detrend)
  class TestDetrend:
      """Tests for `detrend`."""

      def test_basic(self, xp):
          """A perfectly linear input detrends to zeros."""
          expected = xp.asarray([0, 0, 0])
          result = detrend(xp.asarray([1, 2, 3]))
          assert_array_almost_equal(result, expected)

      @skip_xp_backends("jax.numpy", reason="overwrite_data not implemented")
      def test_copy(self, xp):
          """``overwrite_data`` True and False give the same result."""
          samples = xp.asarray([1, 1.2, 1.5, 1.6, 2.4])
          from_copy = detrend(samples, overwrite_data=False)
          from_inplace = detrend(samples, overwrite_data=True)
          assert_array_almost_equal(from_copy, from_inplace)

      @pytest.mark.parametrize('kind', ['linear', 'constant'])
      @pytest.mark.parametrize('axis', [0, 1, 2])
      def test_axis(self, axis, kind, xp):
          """The output shape equals the input shape for every axis and type."""
          cube = xp.reshape(xp.arange(5*6*7), (5, 6, 7))
          result = detrend(cube, type=kind, axis=axis)
          assert result.shape == cube.shape

      def test_bp(self, xp):
          """Piecewise-linear data is removed entirely with a breakpoint at 3."""
          # Two linear pieces that meet the breakpoint at index 3.
          data = xp.asarray([0, 1, 2, 5, 0, -5, -10])
          result = detrend(data, type='linear', bp=3)
          xp_assert_close(result, xp.zeros_like(result), atol=1e-14)

          # repeat with ndim > 1 and axis
          data = xp.asarray(data)[None, :, None]
          result = detrend(data, type="linear", bp=3, axis=1)
          xp_assert_close(result, xp.zeros_like(result), atol=1e-14)

          # breakpoint index > shape[axis]: raises
          with assert_raises(ValueError):
              detrend(data, type="linear", bp=3)

      @pytest.mark.parametrize('bp', [np.array([0, 2]), [0, 2]])
      def test_detrend_array_bp(self, bp, xp):
          """``bp`` may be given as an array or as a list (gh-18675)."""
          # regression test for https://github.com/scipy/scipy/issues/18675
          samples = np.random.RandomState(12345).rand(10)
          samples = xp.asarray(samples, dtype=xp_default_dtype(xp))

          if isinstance(bp, np.ndarray) and not is_jax(xp):
              # JAX expects a static array for bp, so don't call xp.asarray
              # for JAX.
              bp = xp.asarray(bp)
          elif not (is_numpy(xp) or is_jax(xp)):
              pytest.skip("list bp is currently numpy and jax only")

          result = detrend(samples, bp=bp)
          res_scipy_191 = xp.asarray([
              -4.44089210e-16, -2.22044605e-16, -1.11128506e-01,
              -1.69470553e-01, 1.14710683e-01, 6.35468419e-02,
              3.53533144e-01, -3.67877935e-02, -2.00417675e-02,
              -1.94362049e-01,
          ])

          # A looser tolerance is used when the default dtype is float32.
          if xp_default_dtype(xp) == xp.float32:
              atol = 3e-7
          else:
              atol = 1e-14
          xp_assert_close(result, res_scipy_191, atol=atol)
  @make_xp_test_case(unique_roots)
  class TestUniqueRoots:
      """Tests for `unique_roots`."""

      def test_real_no_repeat(self, xp):
          """Distinct real roots come back unchanged, each with multiplicity 1."""
          p = xp.asarray([-1.0, -0.5, 0.3, 1.2, 10.0])
          unique, multiplicity = unique_roots(p)
          assert_almost_equal(unique, p, decimal=15)
          xp_assert_equal(multiplicity, xp.ones(len(p), dtype=int))

      def test_real_repeat(self, xp):
          """Real roots grouped with ``tol=1e-1`` for each ``rtype``."""
          p = xp.asarray([-1.0, -0.95, -0.89, -0.8, 0.5, 1.0, 1.05])

          unique, multiplicity = unique_roots(p, tol=1e-1, rtype='min')
          assert_almost_equal(unique, xp.asarray([-1.0, -0.89, 0.5, 1.0]), decimal=15)
          xp_assert_equal(multiplicity, xp.asarray([2, 2, 1, 2]))

          unique, multiplicity = unique_roots(p, tol=1e-1, rtype='max')
          assert_almost_equal(unique, xp.asarray([-0.95, -0.8, 0.5, 1.05]), decimal=15)
          xp_assert_equal(multiplicity, xp.asarray([2, 2, 1, 2]))

          unique, multiplicity = unique_roots(p, tol=1e-1, rtype='avg')
          assert_almost_equal(unique, xp.asarray([-0.975, -0.845, 0.5, 1.025]),
                              decimal=15)
          xp_assert_equal(multiplicity, xp.asarray([2, 2, 1, 2]))

      def test_complex_no_repeat(self, xp):
          """Distinct complex roots come back unchanged, each with multiplicity 1."""
          p = xp.asarray([-1.0, 1.0j, 0.5 + 0.5j, -1.0 - 1.0j, 3.0 + 2.0j])
          unique, multiplicity = unique_roots(p)
          assert_almost_equal(unique, p, decimal=15)
          xp_assert_equal(multiplicity, xp.ones(len(p), dtype=int))

      def test_complex_repeat(self, xp):
          """Complex roots grouped with ``tol=1e-1`` for each ``rtype``."""
          p = xp.asarray([-1.0, -1.0 + 0.05j, -0.95 + 0.15j, -0.90 + 0.15j, 0.0,
                          0.5 + 0.5j, 0.45 + 0.55j])

          unique, multiplicity = unique_roots(p, tol=1e-1, rtype='min')
          assert_almost_equal(unique,
                              xp.asarray([-1.0, -0.95 + 0.15j, 0.0, 0.45 + 0.55j]),
                              decimal=15)
          xp_assert_equal(multiplicity, xp.asarray([2, 2, 1, 2]))

          unique, multiplicity = unique_roots(p, tol=1e-1, rtype='max')
          assert_almost_equal(
              unique,
              xp.asarray(
                  [-1.0 + 0.05j, -0.90 + 0.15j, 0.0, 0.5 + 0.5j]
              ),
              decimal=15,
          )
          xp_assert_equal(multiplicity, xp.asarray([2, 2, 1, 2]))

          unique, multiplicity = unique_roots(p, tol=1e-1, rtype='avg')
          assert_almost_equal(
              unique,
              xp.asarray([-1.0 + 0.025j, -0.925 + 0.15j, 0.0, 0.475 + 0.525j]),
              decimal=15,
          )
          xp_assert_equal(multiplicity, xp.asarray([2, 2, 1, 2]))

      def test_gh_4915(self, xp):
          """Roots of ``conv(ones(5), ones(5))`` are four values of multiplicity 2."""
          p = xp.asarray(np.roots(np.convolve(np.ones(5), np.ones(5))))
          true_roots = xp.asarray(
              [-(-1)**(1/5), (-1)**(4/5), -(-1)**(3/5), (-1)**(2/5)]
          )

          unique, multiplicity = unique_roots(p)
          # Sort once; the comparison relies on ``true_roots`` being listed
          # in the order that ``xp.sort`` produces.
          unique = xp.sort(unique)

          assert_almost_equal(unique, true_roots, decimal=7)
          xp_assert_equal(multiplicity, xp.asarray([2, 2, 2, 2]))

      def test_complex_roots_extra(self, xp):
          """Mixed real/complex inputs, with the default and with a large tolerance."""
          unique, multiplicity = unique_roots(xp.asarray([1.0, 1.0j, 1.0]))
          assert_almost_equal(unique, xp.asarray([1.0, 1.0j]), decimal=15)
          xp_assert_equal(multiplicity, xp.asarray([2, 1]))

          unique, multiplicity = unique_roots(
              xp.asarray([1, 1 + 2e-9, 1e-9 + 1j]), tol=0.1
          )
          assert_almost_equal(unique, xp.asarray([1.0, 1e-9 + 1.0j]), decimal=15)
          xp_assert_equal(multiplicity, xp.asarray([2, 1]))

      def test_single_unique_root(self, xp):
          """With ``tol=2`` all 100 random roots collapse into a single one."""
          # Seeded generator so the input is reproducible. Every point lies in
          # the unit square, so any two of them are less than 2 apart.
          rng = np.random.RandomState(1234)
          p_np = rng.rand(100) + 1j * rng.rand(100)
          p = xp.asarray(p_np)

          unique, multiplicity = unique_roots(p, 2)

          # The expected minimum is taken from the NumPy array, so ``np.min``
          # is never applied to a non-NumPy backend array.
          assert_almost_equal(unique, xp.asarray([np.min(p_np)]), decimal=15)
          xp_assert_equal(multiplicity, xp.asarray([100]))
  def test_gh_22684():
      """Check that `signal.resample_poly` returns complex64 for complex64 input."""
      samples = np.arange(2000, dtype=np.complex64)
      resampled = signal.resample_poly(samples, 6, 4)
      assert resampled.dtype == np.complex64