# _signaltools.py (192 KB)
  1. # Author: Travis Oliphant
  2. # 1999 -- 2002
  3. from __future__ import annotations # Provides typing union operator `|` in Python 3.9
  4. import operator
  5. import math
  6. from math import prod as _prod
  7. import timeit
  8. import warnings
  9. from typing import Literal
  10. from numpy._typing import ArrayLike
  11. from scipy.spatial import cKDTree
  12. from . import _sigtools
  13. from ._ltisys import dlti
  14. from ._upfirdn import upfirdn, _output_len, _upfirdn_modes
  15. from scipy import linalg, fft as sp_fft
  16. from scipy import ndimage
  17. from scipy.fft._helper import _init_nd_shape_and_axes
  18. import numpy as np
  19. from scipy.special import lambertw
  20. from .windows import get_window
  21. from ._arraytools import axis_slice, axis_reverse, odd_ext, even_ext, const_ext
  22. from ._filter_design import cheby1, _validate_sos, zpk2sos
  23. from ._fir_filter_design import firwin
  24. from ._sosfilt import _sosfilt
  25. from scipy._lib._array_api import (
  26. array_namespace, is_torch, is_numpy, xp_copy, xp_size, xp_default_dtype,
  27. xp_swapaxes
  28. )
  29. from scipy._lib.array_api_compat import is_array_api_obj
  30. import scipy._lib.array_api_extra as xpx
  31. __all__ = ['correlate', 'correlation_lags', 'correlate2d',
  32. 'convolve', 'convolve2d', 'fftconvolve', 'oaconvolve',
  33. 'order_filter', 'medfilt', 'medfilt2d', 'wiener', 'lfilter',
  34. 'lfiltic', 'sosfilt', 'deconvolve', 'hilbert', 'hilbert2', 'envelope',
  35. 'unique_roots', 'invres', 'invresz', 'residue',
  36. 'residuez', 'resample', 'resample_poly', 'detrend',
  37. 'lfilter_zi', 'sosfilt_zi', 'sosfiltfilt', 'choose_conv_method',
  38. 'filtfilt', 'decimate', 'vectorstrength']
  39. _modedict = {'valid': 0, 'same': 1, 'full': 2}
  40. _boundarydict = {'fill': 0, 'pad': 0, 'wrap': 2, 'circular': 2, 'symm': 1,
  41. 'symmetric': 1, 'reflect': 4}
  42. def _valfrommode(mode):
  43. try:
  44. return _modedict[mode]
  45. except KeyError as e:
  46. raise ValueError("Acceptable mode flags are 'valid',"
  47. " 'same', or 'full'.") from e
  48. def _bvalfromboundary(boundary):
  49. try:
  50. return _boundarydict[boundary] << 2
  51. except KeyError as e:
  52. raise ValueError("Acceptable boundary flags are 'fill', 'circular' "
  53. "(or 'wrap'), and 'symmetric' (or 'symm').") from e
  54. def _inputs_swap_needed(mode, shape1, shape2, axes=None):
  55. """Determine if inputs arrays need to be swapped in `"valid"` mode.
  56. If in `"valid"` mode, returns whether or not the input arrays need to be
  57. swapped depending on whether `shape1` is at least as large as `shape2` in
  58. every calculated dimension.
  59. This is important for some of the correlation and convolution
  60. implementations in this module, where the larger array input needs to come
  61. before the smaller array input when operating in this mode.
  62. Note that if the mode provided is not 'valid', False is immediately
  63. returned.
  64. """
  65. if mode != 'valid':
  66. return False
  67. if not shape1:
  68. return False
  69. if axes is None:
  70. axes = range(len(shape1))
  71. ok1 = all(shape1[i] >= shape2[i] for i in axes)
  72. ok2 = all(shape2[i] >= shape1[i] for i in axes)
  73. if not (ok1 or ok2):
  74. raise ValueError("For 'valid' mode, one must be at least "
  75. "as large as the other in every dimension")
  76. return not ok1
  77. def correlate(in1, in2, mode='full', method='auto'):
  78. r"""
  79. Cross-correlate two N-dimensional arrays.
  80. Cross-correlate `in1` and `in2`, with the output size determined by the
  81. `mode` argument.
  82. Parameters
  83. ----------
  84. in1 : array_like
  85. First input.
  86. in2 : array_like
  87. Second input. Should have the same number of dimensions as `in1`.
  88. mode : str {'full', 'valid', 'same'}, optional
  89. A string indicating the size of the output:
  90. ``full``
  91. The output is the full discrete linear cross-correlation
  92. of the inputs. (Default)
  93. ``valid``
  94. The output consists only of those elements that do not
  95. rely on the zero-padding. In 'valid' mode, either `in1` or `in2`
  96. must be at least as large as the other in every dimension.
  97. ``same``
  98. The output is the same size as `in1`, centered
  99. with respect to the 'full' output.
  100. method : str {'auto', 'direct', 'fft'}, optional
  101. A string indicating which method to use to calculate the correlation.
  102. ``direct``
  103. The correlation is determined directly from sums, the definition of
  104. correlation.
  105. ``fft``
  106. The Fast Fourier Transform is used to perform the correlation more
  107. quickly (only available for numerical arrays.)
  108. ``auto``
  109. Automatically chooses direct or Fourier method based on an estimate
  110. of which is faster (default). See `convolve` Notes for more detail.
  111. .. versionadded:: 0.19.0
  112. Returns
  113. -------
  114. correlate : array
  115. An N-dimensional array containing a subset of the discrete linear
  116. cross-correlation of `in1` with `in2`.
  117. See Also
  118. --------
  119. choose_conv_method : contains more documentation on `method`.
  120. correlation_lags : calculates the lag / displacement indices array for 1D
  121. cross-correlation.
  122. Notes
  123. -----
  124. The correlation z of two d-dimensional arrays x and y is defined as::
  125. z[...,k,...] = sum[..., i_l, ...] x[..., i_l,...] * conj(y[..., i_l - k,...])
  126. This way, if ``x`` and ``y`` are 1-D arrays and ``z = correlate(x, y, 'full')``
  127. then
  128. .. math::
  129. z[k] = \sum_{l=0}^{N-1} x_l \, y_{l-k}^{*}
  130. for :math:`k = -(M-1), \dots, (N-1)`, where :math:`N` is the length of ``x``,
  131. :math:`M` is the length of ``y``, and :math:`y_m = 0` when :math:`m` is outside the
  132. valid range :math:`[0, M-1]`. The size of :math:`z` is :math:`N + M - 1` and
  133. :math:`y^*` denotes the complex conjugate of :math:`y`.
  134. ``method='fft'`` only works for numerical arrays as it relies on
  135. `fftconvolve`. In certain cases (i.e., arrays of objects or when
  136. rounding integers can lose precision), ``method='direct'`` is always used.
  137. When using ``mode='same'`` with even-length inputs, the outputs of `correlate`
  138. and `correlate2d` differ: There is a 1-index offset between them.
  139. Examples
  140. --------
  141. Implement a matched filter using cross-correlation, to recover a signal
  142. that has passed through a noisy channel.
  143. >>> import numpy as np
  144. >>> from scipy import signal
  145. >>> import matplotlib.pyplot as plt
  146. >>> rng = np.random.default_rng()
  147. >>> sig = np.repeat([0., 1., 1., 0., 1., 0., 0., 1.], 128)
  148. >>> sig_noise = sig + rng.standard_normal(len(sig))
  149. >>> corr = signal.correlate(sig_noise, np.ones(128), mode='same') / 128
  150. >>> clock = np.arange(64, len(sig), 128)
  151. >>> fig, (ax_orig, ax_noise, ax_corr) = plt.subplots(3, 1, sharex=True)
  152. >>> ax_orig.plot(sig)
  153. >>> ax_orig.plot(clock, sig[clock], 'ro')
  154. >>> ax_orig.set_title('Original signal')
  155. >>> ax_noise.plot(sig_noise)
  156. >>> ax_noise.set_title('Signal with noise')
  157. >>> ax_corr.plot(corr)
  158. >>> ax_corr.plot(clock, corr[clock], 'ro')
  159. >>> ax_corr.axhline(0.5, ls=':')
  160. >>> ax_corr.set_title('Cross-correlated with rectangular pulse')
  161. >>> ax_orig.margins(0, 0.1)
  162. >>> fig.tight_layout()
  163. >>> plt.show()
  164. Compute the cross-correlation of a noisy signal with the original signal.
  165. >>> x = np.arange(128) / 128
  166. >>> sig = np.sin(2 * np.pi * x)
  167. >>> sig_noise = sig + rng.standard_normal(len(sig))
  168. >>> corr = signal.correlate(sig_noise, sig)
  169. >>> lags = signal.correlation_lags(len(sig), len(sig_noise))
  170. >>> corr /= np.max(corr)
  171. >>> fig, (ax_orig, ax_noise, ax_corr) = plt.subplots(3, 1, figsize=(4.8, 4.8))
  172. >>> ax_orig.plot(sig)
  173. >>> ax_orig.set_title('Original signal')
  174. >>> ax_orig.set_xlabel('Sample Number')
  175. >>> ax_noise.plot(sig_noise)
  176. >>> ax_noise.set_title('Signal with noise')
  177. >>> ax_noise.set_xlabel('Sample Number')
  178. >>> ax_corr.plot(lags, corr)
  179. >>> ax_corr.set_title('Cross-correlated signal')
  180. >>> ax_corr.set_xlabel('Lag')
  181. >>> ax_orig.margins(0, 0.1)
  182. >>> ax_noise.margins(0, 0.1)
  183. >>> ax_corr.margins(0, 0.1)
  184. >>> fig.tight_layout()
  185. >>> plt.show()
  186. """
  187. xp = array_namespace(in1, in2)
  188. in1 = xp.asarray(in1)
  189. in2 = xp.asarray(in2)
  190. if in1.ndim == in2.ndim == 0:
  191. in2_conj = (xp.conj(in2)
  192. if xp.isdtype(in2.dtype, 'complex floating')
  193. else in2)
  194. return in1 * in2_conj
  195. elif in1.ndim != in2.ndim:
  196. raise ValueError("in1 and in2 should have the same dimensionality")
  197. # Don't use _valfrommode, since correlate should not accept numeric modes
  198. try:
  199. val = _modedict[mode]
  200. except KeyError as e:
  201. raise ValueError("Acceptable mode flags are 'valid',"
  202. " 'same', or 'full'.") from e
  203. # this either calls fftconvolve or this function with method=='direct'
  204. if method in ('fft', 'auto'):
  205. return convolve(in1, _reverse_and_conj(in2, xp), mode, method)
  206. elif method == 'direct':
  207. # fastpath to faster numpy.correlate for 1d inputs when possible
  208. if _np_conv_ok(in1, in2, mode, xp):
  209. a_in1 = np.asarray(in1)
  210. a_in2 = np.asarray(in2)
  211. out = np.correlate(a_in1, a_in2, mode)
  212. return xp.asarray(out)
  213. # _correlateND is far slower when in2.size > in1.size, so swap them
  214. # and then undo the effect afterward if mode == 'full'. Also, it fails
  215. # with 'valid' mode if in2 is larger than in1, so swap those, too.
  216. # Don't swap inputs for 'same' mode, since shape of in1 matters.
  217. swapped_inputs = ((mode == 'full') and (xp_size(in2) > xp_size(in1)) or
  218. _inputs_swap_needed(mode, in1.shape, in2.shape))
  219. if swapped_inputs:
  220. in1, in2 = in2, in1
  221. # convert to numpy & back for _sigtools._correlateND
  222. a_in1 = np.asarray(in1)
  223. a_in2 = np.asarray(in2)
  224. if mode == 'valid':
  225. ps = [i - j + 1 for i, j in zip(in1.shape, in2.shape)]
  226. out = np.empty(ps, a_in1.dtype)
  227. z = _sigtools._correlateND(a_in1, a_in2, out, val)
  228. else:
  229. ps = [i + j - 1 for i, j in zip(in1.shape, in2.shape)]
  230. # zero pad input
  231. in1zpadded = np.zeros(ps, a_in1.dtype)
  232. sc = tuple(slice(0, i) for i in in1.shape)
  233. in1zpadded[sc] = a_in1.copy()
  234. if mode == 'full':
  235. out = np.empty(ps, a_in1.dtype)
  236. elif mode == 'same':
  237. out = np.empty(in1.shape, a_in1.dtype)
  238. z = _sigtools._correlateND(in1zpadded, a_in2, out, val)
  239. z = xp.asarray(z)
  240. if swapped_inputs:
  241. # Reverse and conjugate to undo the effect of swapping inputs
  242. z = _reverse_and_conj(z, xp)
  243. return z
  244. else:
  245. raise ValueError("Acceptable method flags are 'auto',"
  246. " 'direct', or 'fft'.")
  247. def correlation_lags(in1_len, in2_len, mode='full'):
  248. r"""
  249. Calculates the lag / displacement indices array for 1D cross-correlation.
  250. Parameters
  251. ----------
  252. in1_len : int
  253. First input size.
  254. in2_len : int
  255. Second input size.
  256. mode : str {'full', 'valid', 'same'}, optional
  257. A string indicating the size of the output.
  258. See the documentation `correlate` for more information.
  259. Returns
  260. -------
  261. lags : array
  262. Returns an array containing cross-correlation lag/displacement indices.
  263. Indices can be indexed with the np.argmax of the correlation to return
  264. the lag/displacement.
  265. See Also
  266. --------
  267. correlate : Compute the N-dimensional cross-correlation.
  268. Notes
  269. -----
  270. Cross-correlation for continuous functions :math:`f` and :math:`g` is
  271. defined as:
  272. .. math::
  273. \left ( f\star g \right )\left ( \tau \right )
  274. \triangleq \int_{t_0}^{t_0 +T}
  275. \overline{f\left ( t \right )}g\left ( t+\tau \right )dt
  276. Where :math:`\tau` is defined as the displacement, also known as the lag.
  277. Cross correlation for discrete functions :math:`f` and :math:`g` is
  278. defined as:
  279. .. math::
  280. \left ( f\star g \right )\left [ n \right ]
  281. \triangleq \sum_{-\infty}^{\infty}
  282. \overline{f\left [ m \right ]}g\left [ m+n \right ]
  283. Where :math:`n` is the lag.
  284. Examples
  285. --------
  286. Cross-correlation of a signal with its time-delayed self.
  287. >>> import numpy as np
  288. >>> from scipy import signal
  289. >>> rng = np.random.default_rng()
  290. >>> x = rng.standard_normal(1000)
  291. >>> y = np.concatenate([rng.standard_normal(100), x])
  292. >>> correlation = signal.correlate(x, y, mode="full")
  293. >>> lags = signal.correlation_lags(x.size, y.size, mode="full")
  294. >>> lag = lags[np.argmax(correlation)]
  295. """
  296. # calculate lag ranges in different modes of operation
  297. if mode == "full":
  298. # the output is the full discrete linear convolution
  299. # of the inputs. (Default)
  300. lags = np.arange(-in2_len + 1, in1_len)
  301. elif mode == "same":
  302. # the output is the same size as `in1`, centered
  303. # with respect to the 'full' output.
  304. # calculate the full output
  305. lags = np.arange(-in2_len + 1, in1_len)
  306. # determine the midpoint in the full output
  307. mid = lags.size // 2
  308. # determine lag_bound to be used with respect
  309. # to the midpoint
  310. lag_bound = in1_len // 2
  311. # calculate lag ranges for even and odd scenarios
  312. if in1_len % 2 == 0:
  313. lags = lags[(mid-lag_bound):(mid+lag_bound)]
  314. else:
  315. lags = lags[(mid-lag_bound):(mid+lag_bound)+1]
  316. elif mode == "valid":
  317. # the output consists only of those elements that do not
  318. # rely on the zero-padding. In 'valid' mode, either `in1` or `in2`
  319. # must be at least as large as the other in every dimension.
  320. # the lag_bound will be either negative or positive
  321. # this let's us infer how to present the lag range
  322. lag_bound = in1_len - in2_len
  323. if lag_bound >= 0:
  324. lags = np.arange(lag_bound + 1)
  325. else:
  326. lags = np.arange(lag_bound, 1)
  327. else:
  328. raise ValueError(f"Mode {mode} is invalid")
  329. return lags
  330. def _centered(arr, newshape):
  331. # Return the center newshape portion of the array.
  332. newshape = np.asarray(newshape)
  333. currshape = np.array(arr.shape)
  334. startind = (currshape - newshape) // 2
  335. endind = startind + newshape
  336. myslice = [slice(startind[k], endind[k]) for k in range(len(endind))]
  337. return arr[tuple(myslice)]
  338. def _init_freq_conv_axes(in1, in2, mode, axes, sorted_axes=False):
  339. """Handle the axes argument for frequency-domain convolution.
  340. Returns the inputs and axes in a standard form, eliminating redundant axes,
  341. swapping the inputs if necessary, and checking for various potential
  342. errors.
  343. Parameters
  344. ----------
  345. in1 : array
  346. First input.
  347. in2 : array
  348. Second input.
  349. mode : str {'full', 'valid', 'same'}, optional
  350. A string indicating the size of the output.
  351. See the documentation `fftconvolve` for more information.
  352. axes : list of ints
  353. Axes over which to compute the FFTs.
  354. sorted_axes : bool, optional
  355. If `True`, sort the axes.
  356. Default is `False`, do not sort.
  357. Returns
  358. -------
  359. in1 : array
  360. The first input, possible swapped with the second input.
  361. in2 : array
  362. The second input, possible swapped with the first input.
  363. axes : list of ints
  364. Axes over which to compute the FFTs.
  365. """
  366. s1 = in1.shape
  367. s2 = in2.shape
  368. noaxes = axes is None
  369. _, axes = _init_nd_shape_and_axes(in1, shape=None, axes=axes)
  370. if not noaxes and not len(axes):
  371. raise ValueError("when provided, axes cannot be empty")
  372. # Axes of length 1 can rely on broadcasting rules for multiply,
  373. # no fft needed.
  374. axes = [a for a in axes if s1[a] != 1 and s2[a] != 1]
  375. if sorted_axes:
  376. axes.sort()
  377. if not all(s1[a] == s2[a] or s1[a] == 1 or s2[a] == 1
  378. for a in range(in1.ndim) if a not in axes):
  379. raise ValueError("incompatible shapes for in1 and in2:"
  380. f" {s1} and {s2}")
  381. # Check that input sizes are compatible with 'valid' mode.
  382. if _inputs_swap_needed(mode, s1, s2, axes=axes):
  383. # Convolution is commutative; order doesn't have any effect on output.
  384. in1, in2 = in2, in1
  385. return in1, in2, axes
  386. def _freq_domain_conv(xp, in1, in2, axes, shape, calc_fast_len=False):
  387. """Convolve two arrays in the frequency domain.
  388. This function implements only base the FFT-related operations.
  389. Specifically, it converts the signals to the frequency domain, multiplies
  390. them, then converts them back to the time domain. Calculations of axes,
  391. shapes, convolution mode, etc. are implemented in higher level-functions,
  392. such as `fftconvolve` and `oaconvolve`. Those functions should be used
  393. instead of this one.
  394. Parameters
  395. ----------
  396. in1 : array_like
  397. First input.
  398. in2 : array_like
  399. Second input. Should have the same number of dimensions as `in1`.
  400. axes : array_like of ints
  401. Axes over which to compute the FFTs.
  402. shape : array_like of ints
  403. The sizes of the FFTs.
  404. calc_fast_len : bool, optional
  405. If `True`, set each value of `shape` to the next fast FFT length.
  406. Default is `False`, use `axes` as-is.
  407. Returns
  408. -------
  409. out : array
  410. An N-dimensional array containing the discrete linear convolution of
  411. `in1` with `in2`.
  412. """
  413. if not len(axes):
  414. return in1 * in2
  415. complex_result = (xp.isdtype(in1.dtype, 'complex floating') or
  416. xp.isdtype(in2.dtype, 'complex floating'))
  417. if calc_fast_len:
  418. # Speed up FFT by padding to optimal size.
  419. fshape = [
  420. sp_fft.next_fast_len(shape[a], not complex_result) for a in axes]
  421. else:
  422. fshape = shape
  423. if not complex_result:
  424. fft, ifft = sp_fft.rfftn, sp_fft.irfftn
  425. else:
  426. fft, ifft = sp_fft.fftn, sp_fft.ifftn
  427. if xp.isdtype(in1.dtype, 'integral'):
  428. in1 = xp.astype(in1, xp.float64)
  429. if xp.isdtype(in2.dtype, 'integral'):
  430. in2 = xp.astype(in2, xp.float64)
  431. sp1 = fft(in1, fshape, axes=axes)
  432. sp2 = fft(in2, fshape, axes=axes)
  433. ret = ifft(sp1 * sp2, fshape, axes=axes)
  434. if calc_fast_len:
  435. fslice = tuple([slice(sz) for sz in shape])
  436. ret = ret[fslice]
  437. return ret
  438. def _apply_conv_mode(ret, s1, s2, mode, axes, xp):
  439. """Calculate the convolution result shape based on the `mode` argument.
  440. Returns the result sliced to the correct size for the given mode.
  441. Parameters
  442. ----------
  443. ret : array
  444. The result array, with the appropriate shape for the 'full' mode.
  445. s1 : list of int
  446. The shape of the first input.
  447. s2 : list of int
  448. The shape of the second input.
  449. mode : str {'full', 'valid', 'same'}
  450. A string indicating the size of the output.
  451. See the documentation `fftconvolve` for more information.
  452. axes : list of ints
  453. Axes over which to compute the convolution.
  454. Returns
  455. -------
  456. ret : array
  457. A copy of `res`, sliced to the correct size for the given `mode`.
  458. """
  459. if mode == "full":
  460. return xp_copy(ret, xp=xp)
  461. elif mode == "same":
  462. return xp_copy(_centered(ret, s1), xp=xp)
  463. elif mode == "valid":
  464. shape_valid = [ret.shape[a] if a not in axes else s1[a] - s2[a] + 1
  465. for a in range(ret.ndim)]
  466. return xp_copy(_centered(ret, shape_valid), xp=xp)
  467. else:
  468. raise ValueError("acceptable mode flags are 'valid',"
  469. " 'same', or 'full'")
  470. def fftconvolve(in1, in2, mode="full", axes=None):
  471. """Convolve two N-dimensional arrays using FFT.
  472. Convolve `in1` and `in2` using the fast Fourier transform method, with
  473. the output size determined by the `mode` argument.
  474. This is generally much faster than `convolve` for large arrays (n > ~500),
  475. but can be slower when only a few output values are needed, and can only
  476. output float arrays (int or object array inputs will be cast to float).
  477. As of v0.19, `convolve` automatically chooses this method or the direct
  478. method based on an estimation of which is faster.
  479. Parameters
  480. ----------
  481. in1 : array_like
  482. First input.
  483. in2 : array_like
  484. Second input. Should have the same number of dimensions as `in1`.
  485. mode : str {'full', 'valid', 'same'}, optional
  486. A string indicating the size of the output:
  487. ``full``
  488. The output is the full discrete linear convolution
  489. of the inputs. (Default)
  490. ``valid``
  491. The output consists only of those elements that do not
  492. rely on the zero-padding. In 'valid' mode, either `in1` or `in2`
  493. must be at least as large as the other in every dimension.
  494. ``same``
  495. The output is the same size as `in1`, centered
  496. with respect to the 'full' output.
  497. axes : int or array_like of ints or None, optional
  498. Axes over which to compute the convolution.
  499. The default is over all axes.
  500. Returns
  501. -------
  502. out : array
  503. An N-dimensional array containing a subset of the discrete linear
  504. convolution of `in1` with `in2`.
  505. See Also
  506. --------
  507. convolve : Uses the direct convolution or FFT convolution algorithm
  508. depending on which is faster.
  509. oaconvolve : Uses the overlap-add method to do convolution, which is
  510. generally faster when the input arrays are large and
  511. significantly different in size.
  512. Examples
  513. --------
  514. Autocorrelation of white noise is an impulse.
  515. >>> import numpy as np
  516. >>> from scipy import signal
  517. >>> rng = np.random.default_rng()
  518. >>> sig = rng.standard_normal(1000)
  519. >>> autocorr = signal.fftconvolve(sig, sig[::-1], mode='full')
  520. >>> import matplotlib.pyplot as plt
  521. >>> fig, (ax_orig, ax_mag) = plt.subplots(2, 1)
  522. >>> ax_orig.plot(sig)
  523. >>> ax_orig.set_title('White noise')
  524. >>> ax_mag.plot(np.arange(-len(sig)+1,len(sig)), autocorr)
  525. >>> ax_mag.set_title('Autocorrelation')
  526. >>> fig.tight_layout()
  527. >>> fig.show()
  528. Gaussian blur implemented using FFT convolution. Notice the dark borders
  529. around the image, due to the zero-padding beyond its boundaries.
  530. The `convolve2d` function allows for other types of image boundaries,
  531. but is far slower.
  532. >>> from scipy import datasets
  533. >>> face = datasets.face(gray=True)
  534. >>> kernel = np.outer(signal.windows.gaussian(70, 8),
  535. ... signal.windows.gaussian(70, 8))
  536. >>> blurred = signal.fftconvolve(face, kernel, mode='same')
  537. >>> fig, (ax_orig, ax_kernel, ax_blurred) = plt.subplots(3, 1,
  538. ... figsize=(6, 15))
  539. >>> ax_orig.imshow(face, cmap='gray')
  540. >>> ax_orig.set_title('Original')
  541. >>> ax_orig.set_axis_off()
  542. >>> ax_kernel.imshow(kernel, cmap='gray')
  543. >>> ax_kernel.set_title('Gaussian kernel')
  544. >>> ax_kernel.set_axis_off()
  545. >>> ax_blurred.imshow(blurred, cmap='gray')
  546. >>> ax_blurred.set_title('Blurred')
  547. >>> ax_blurred.set_axis_off()
  548. >>> fig.show()
  549. """
  550. xp = array_namespace(in1, in2)
  551. in1 = xp.asarray(in1)
  552. in2 = xp.asarray(in2)
  553. if in1.ndim == in2.ndim == 0: # scalar inputs
  554. return in1 * in2
  555. elif in1.ndim != in2.ndim:
  556. raise ValueError("in1 and in2 should have the same dimensionality")
  557. elif xp_size(in1) == 0 or xp_size(in2) == 0: # empty arrays
  558. return xp.asarray([])
  559. in1, in2, axes = _init_freq_conv_axes(in1, in2, mode, axes,
  560. sorted_axes=False)
  561. s1 = in1.shape
  562. s2 = in2.shape
  563. shape = [max((s1[i], s2[i])) if i not in axes else s1[i] + s2[i] - 1
  564. for i in range(in1.ndim)]
  565. ret = _freq_domain_conv(xp, in1, in2, axes, shape, calc_fast_len=True)
  566. return _apply_conv_mode(ret, s1, s2, mode, axes, xp=xp)
  567. def _calc_oa_lens(s1, s2):
  568. """Calculate the optimal FFT lengths for overlap-add convolution.
  569. The calculation is done for a single dimension.
  570. Parameters
  571. ----------
  572. s1 : int
  573. Size of the dimension for the first array.
  574. s2 : int
  575. Size of the dimension for the second array.
  576. Returns
  577. -------
  578. block_size : int
  579. The size of the FFT blocks.
  580. overlap : int
  581. The amount of overlap between two blocks.
  582. in1_step : int
  583. The size of each step for the first array.
  584. in2_step : int
  585. The size of each step for the first array.
  586. """
  587. # Set up the arguments for the conventional FFT approach.
  588. fallback = (s1+s2-1, None, s1, s2)
  589. # Use conventional FFT convolve if sizes are same.
  590. if s1 == s2 or s1 == 1 or s2 == 1:
  591. return fallback
  592. if s2 > s1:
  593. s1, s2 = s2, s1
  594. swapped = True
  595. else:
  596. swapped = False
  597. # There cannot be a useful block size if s2 is more than half of s1.
  598. if s2 >= s1/2:
  599. return fallback
  600. # Derivation of optimal block length
  601. # For original formula see:
  602. # https://en.wikipedia.org/wiki/Overlap-add_method
  603. #
  604. # Formula:
  605. # K = overlap = s2-1
  606. # N = block_size
  607. # C = complexity
  608. # e = exponential, exp(1)
  609. #
  610. # C = (N*(log2(N)+1))/(N-K)
  611. # C = (N*log2(2N))/(N-K)
  612. # C = N/(N-K) * log2(2N)
  613. # C1 = N/(N-K)
  614. # C2 = log2(2N) = ln(2N)/ln(2)
  615. #
  616. # dC1/dN = (1*(N-K)-N)/(N-K)^2 = -K/(N-K)^2
  617. # dC2/dN = 2/(2*N*ln(2)) = 1/(N*ln(2))
  618. #
  619. # dC/dN = dC1/dN*C2 + dC2/dN*C1
  620. # dC/dN = -K*ln(2N)/(ln(2)*(N-K)^2) + N/(N*ln(2)*(N-K))
  621. # dC/dN = -K*ln(2N)/(ln(2)*(N-K)^2) + 1/(ln(2)*(N-K))
  622. # dC/dN = -K*ln(2N)/(ln(2)*(N-K)^2) + (N-K)/(ln(2)*(N-K)^2)
  623. # dC/dN = (-K*ln(2N) + (N-K)/(ln(2)*(N-K)^2)
  624. # dC/dN = (N - K*ln(2N) - K)/(ln(2)*(N-K)^2)
  625. #
  626. # Solve for minimum, where dC/dN = 0
  627. # 0 = (N - K*ln(2N) - K)/(ln(2)*(N-K)^2)
  628. # 0 * ln(2)*(N-K)^2 = N - K*ln(2N) - K
  629. # 0 = N - K*ln(2N) - K
  630. # 0 = N - K*(ln(2N) + 1)
  631. # 0 = N - K*ln(2Ne)
  632. # N = K*ln(2Ne)
  633. # N/K = ln(2Ne)
  634. #
  635. # e^(N/K) = e^ln(2Ne)
  636. # e^(N/K) = 2Ne
  637. # 1/e^(N/K) = 1/(2*N*e)
  638. # e^(N/-K) = 1/(2*N*e)
  639. # e^(N/-K) = K/N*1/(2*K*e)
  640. # N/K*e^(N/-K) = 1/(2*e*K)
  641. # N/-K*e^(N/-K) = -1/(2*e*K)
  642. #
  643. # Using Lambert W function
  644. # https://en.wikipedia.org/wiki/Lambert_W_function
  645. # x = W(y) It is the solution to y = x*e^x
  646. # x = N/-K
  647. # y = -1/(2*e*K)
  648. #
  649. # N/-K = W(-1/(2*e*K))
  650. #
  651. # N = -K*W(-1/(2*e*K))
  652. overlap = s2-1
  653. opt_size = -overlap*lambertw(-1/(2*math.e*overlap), k=-1).real
  654. block_size = sp_fft.next_fast_len(math.ceil(opt_size))
  655. # Use conventional FFT convolve if there is only going to be one block.
  656. if block_size >= s1:
  657. return fallback
  658. if not swapped:
  659. in1_step = block_size-s2+1
  660. in2_step = s2
  661. else:
  662. in1_step = s2
  663. in2_step = block_size-s2+1
  664. return block_size, overlap, in1_step, in2_step
  665. # may want to look at moving xp_swapaxes and this to array-api-extra,
  666. # cross-ref https://github.com/data-apis/array-api-extra/issues/97
  667. def _split(x, indices_or_sections, axis, xp):
  668. """A simplified version of np.split, with `indices` being an list.
  669. """
  670. # https://github.com/numpy/numpy/blob/v2.2.0/numpy/lib/_shape_base_impl.py#L743
  671. Ntotal = x.shape[axis]
  672. # handle array case.
  673. Nsections = len(indices_or_sections) + 1
  674. div_points = [0] + list(indices_or_sections) + [Ntotal]
  675. sub_arys = []
  676. sary = xp_swapaxes(x, axis, 0, xp=xp)
  677. for i in range(Nsections):
  678. st = div_points[i]
  679. end = div_points[i + 1]
  680. sub_arys.append(xp_swapaxes(sary[st:end, ...], axis, 0, xp=xp))
  681. return sub_arys
  682. def oaconvolve(in1, in2, mode="full", axes=None):
  683. """Convolve two N-dimensional arrays using the overlap-add method.
  684. Convolve `in1` and `in2` using the overlap-add method, with
  685. the output size determined by the `mode` argument.
  686. This is generally much faster than `convolve` for large arrays (n > ~500),
  687. and generally much faster than `fftconvolve` when one array is much
  688. larger than the other, but can be slower when only a few output values are
  689. needed or when the arrays are very similar in shape, and can only
  690. output float arrays (int or object array inputs will be cast to float).
  691. Parameters
  692. ----------
  693. in1 : array_like
  694. First input.
  695. in2 : array_like
  696. Second input. Should have the same number of dimensions as `in1`.
  697. mode : str {'full', 'valid', 'same'}, optional
  698. A string indicating the size of the output:
  699. ``full``
  700. The output is the full discrete linear convolution
  701. of the inputs. (Default)
  702. ``valid``
  703. The output consists only of those elements that do not
  704. rely on the zero-padding. In 'valid' mode, either `in1` or `in2`
  705. must be at least as large as the other in every dimension.
  706. ``same``
  707. The output is the same size as `in1`, centered
  708. with respect to the 'full' output.
  709. axes : int or array_like of ints or None, optional
  710. Axes over which to compute the convolution.
  711. The default is over all axes.
  712. Returns
  713. -------
  714. out : array
  715. An N-dimensional array containing a subset of the discrete linear
  716. convolution of `in1` with `in2`.
  717. See Also
  718. --------
  719. convolve : Uses the direct convolution or FFT convolution algorithm
  720. depending on which is faster.
  721. fftconvolve : An implementation of convolution using FFT.
  722. Notes
  723. -----
  724. .. versionadded:: 1.4.0
  725. References
  726. ----------
  727. .. [1] Wikipedia, "Overlap-add_method".
  728. https://en.wikipedia.org/wiki/Overlap-add_method
  729. .. [2] Richard G. Lyons. Understanding Digital Signal Processing,
  730. Third Edition, 2011. Chapter 13.10.
  731. ISBN 13: 978-0137-02741-5
  732. Examples
  733. --------
  734. Convolve a 100,000 sample signal with a 512-sample filter.
  735. >>> import numpy as np
  736. >>> from scipy import signal
  737. >>> rng = np.random.default_rng()
  738. >>> sig = rng.standard_normal(100000)
  739. >>> filt = signal.firwin(512, 0.01)
  740. >>> fsig = signal.oaconvolve(sig, filt)
  741. >>> import matplotlib.pyplot as plt
  742. >>> fig, (ax_orig, ax_mag) = plt.subplots(2, 1)
  743. >>> ax_orig.plot(sig)
  744. >>> ax_orig.set_title('White noise')
  745. >>> ax_mag.plot(fsig)
  746. >>> ax_mag.set_title('Filtered noise')
  747. >>> fig.tight_layout()
  748. >>> fig.show()
  749. """
  750. xp = array_namespace(in1, in2)
  751. in1 = xp.asarray(in1)
  752. in2 = xp.asarray(in2)
  753. if in1.ndim == in2.ndim == 0: # scalar inputs
  754. return in1 * in2
  755. elif in1.ndim != in2.ndim:
  756. raise ValueError("in1 and in2 should have the same dimensionality")
  757. elif in1.size == 0 or in2.size == 0: # empty arrays
  758. return xp.asarray([])
  759. elif in1.shape == in2.shape: # Equivalent to fftconvolve
  760. return fftconvolve(in1, in2, mode=mode, axes=axes)
  761. in1, in2, axes = _init_freq_conv_axes(in1, in2, mode, axes,
  762. sorted_axes=True)
  763. s1 = in1.shape
  764. s2 = in2.shape
  765. if not axes:
  766. ret = in1 * in2
  767. return _apply_conv_mode(ret, s1, s2, mode, axes, xp)
  768. # Calculate this now since in1 is changed later
  769. shape_final = [None if i not in axes else
  770. s1[i] + s2[i] - 1 for i in range(in1.ndim)]
  771. # Calculate the block sizes for the output, steps, first and second inputs.
  772. # It is simpler to calculate them all together than doing them in separate
  773. # loops due to all the special cases that need to be handled.
  774. optimal_sizes = ((-1, -1, s1[i], s2[i]) if i not in axes else
  775. _calc_oa_lens(s1[i], s2[i]) for i in range(in1.ndim))
  776. block_size, overlaps, \
  777. in1_step, in2_step = zip(*optimal_sizes)
  778. # Fall back to fftconvolve if there is only one block in every dimension.
  779. if in1_step == s1 and in2_step == s2:
  780. return fftconvolve(in1, in2, mode=mode, axes=axes)
  781. # Figure out the number of steps and padding.
  782. # This would get too complicated in a list comprehension.
  783. nsteps1 = []
  784. nsteps2 = []
  785. pad_size1 = []
  786. pad_size2 = []
  787. for i in range(in1.ndim):
  788. if i not in axes:
  789. pad_size1 += [(0, 0)]
  790. pad_size2 += [(0, 0)]
  791. continue
  792. if s1[i] > in1_step[i]:
  793. curnstep1 = math.ceil((s1[i]+1)/in1_step[i])
  794. if (block_size[i] - overlaps[i])*curnstep1 < shape_final[i]:
  795. curnstep1 += 1
  796. curpad1 = curnstep1*in1_step[i] - s1[i]
  797. else:
  798. curnstep1 = 1
  799. curpad1 = 0
  800. if s2[i] > in2_step[i]:
  801. curnstep2 = math.ceil((s2[i]+1)/in2_step[i])
  802. if (block_size[i] - overlaps[i])*curnstep2 < shape_final[i]:
  803. curnstep2 += 1
  804. curpad2 = curnstep2*in2_step[i] - s2[i]
  805. else:
  806. curnstep2 = 1
  807. curpad2 = 0
  808. nsteps1 += [curnstep1]
  809. nsteps2 += [curnstep2]
  810. pad_size1 += [(0, curpad1)]
  811. pad_size2 += [(0, curpad2)]
  812. # Pad the array to a size that can be reshaped to the desired shape
  813. # if necessary.
  814. if not all(curpad == (0, 0) for curpad in pad_size1):
  815. in1 = xpx.pad(in1, pad_size1, mode='constant', constant_values=0, xp=xp)
  816. if not all(curpad == (0, 0) for curpad in pad_size2):
  817. in2 = xpx.pad(in2, pad_size2, mode='constant', constant_values=0, xp=xp)
  818. # Reshape the overlap-add parts to input block sizes.
  819. split_axes = [iax+i for i, iax in enumerate(axes)]
  820. fft_axes = [iax+1 for iax in split_axes]
  821. # We need to put each new dimension before the corresponding dimension
  822. # being reshaped in order to get the data in the right layout at the end.
  823. reshape_size1 = list(in1_step)
  824. reshape_size2 = list(in2_step)
  825. for i, iax in enumerate(split_axes):
  826. reshape_size1.insert(iax, nsteps1[i])
  827. reshape_size2.insert(iax, nsteps2[i])
  828. in1 = xp.reshape(in1, tuple(reshape_size1))
  829. in2 = xp.reshape(in2, tuple(reshape_size2))
  830. # Do the convolution.
  831. fft_shape = [block_size[i] for i in axes]
  832. ret = _freq_domain_conv(xp, in1, in2, fft_axes, fft_shape, calc_fast_len=False)
  833. # Do the overlap-add.
  834. for ax, ax_fft, ax_split in zip(axes, fft_axes, split_axes):
  835. overlap = overlaps[ax]
  836. if overlap is None:
  837. continue
  838. ret, overpart = _split(ret, [-overlap], ax_fft, xp=xp)
  839. overpart = _split(overpart, [-1], ax_split, xp=xp)[0]
  840. ret_overpart = _split(ret, [overlap], ax_fft, xp=xp)[0]
  841. ret_overpart = _split(ret_overpart, [1], ax_split, xp)[1]
  842. ret_overpart += overpart
  843. # Reshape back to the correct dimensionality.
  844. shape_ret = [ret.shape[i] if i not in fft_axes else
  845. ret.shape[i]*ret.shape[i-1]
  846. for i in range(ret.ndim) if i not in split_axes]
  847. ret = xp.reshape(ret, tuple(shape_ret))
  848. # Slice to the correct size.
  849. slice_final = tuple([slice(islice) for islice in shape_final])
  850. ret = ret[slice_final]
  851. return _apply_conv_mode(ret, s1, s2, mode, axes, xp)
  852. def _numeric_arrays(arrays, kinds='buifc', xp=None):
  853. """
  854. See if a list of arrays are all numeric.
  855. Parameters
  856. ----------
  857. arrays : array or list of arrays
  858. arrays to check if numeric.
  859. kinds : string-like
  860. The dtypes of the arrays to be checked. If the dtype.kind of
  861. the ndarrays are not in this string the function returns False and
  862. otherwise returns True.
  863. """
  864. if xp is None:
  865. xp = array_namespace(*arrays)
  866. if not is_numpy(xp):
  867. return True
  868. if type(arrays) is np.ndarray:
  869. return arrays.dtype.kind in kinds
  870. for array_ in arrays:
  871. if array_.dtype.kind not in kinds:
  872. return False
  873. return True
  874. def _conv_ops(x_shape, h_shape, mode):
  875. """
  876. Find the number of operations required for direct/fft methods of
  877. convolution. The direct operations were recorded by making a dummy class to
  878. record the number of operations by overriding ``__mul__`` and ``__add__``.
  879. The FFT operations rely on the (well-known) computational complexity of the
  880. FFT (and the implementation of ``_freq_domain_conv``).
  881. """
  882. if mode == "full":
  883. out_shape = [n + k - 1 for n, k in zip(x_shape, h_shape)]
  884. elif mode == "valid":
  885. out_shape = [abs(n - k) + 1 for n, k in zip(x_shape, h_shape)]
  886. elif mode == "same":
  887. out_shape = x_shape
  888. else:
  889. raise ValueError("Acceptable mode flags are 'valid',"
  890. f" 'same', or 'full', not mode={mode}")
  891. s1, s2 = x_shape, h_shape
  892. if len(x_shape) == 1:
  893. s1, s2 = s1[0], s2[0]
  894. if mode == "full":
  895. direct_ops = s1 * s2
  896. elif mode == "valid":
  897. direct_ops = (s2 - s1 + 1) * s1 if s2 >= s1 else (s1 - s2 + 1) * s2
  898. elif mode == "same":
  899. direct_ops = (s1 * s2 if s1 < s2 else
  900. s1 * s2 - (s2 // 2) * ((s2 + 1) // 2))
  901. else:
  902. if mode == "full":
  903. direct_ops = min(_prod(s1), _prod(s2)) * _prod(out_shape)
  904. elif mode == "valid":
  905. direct_ops = min(_prod(s1), _prod(s2)) * _prod(out_shape)
  906. elif mode == "same":
  907. direct_ops = _prod(s1) * _prod(s2)
  908. full_out_shape = [n + k - 1 for n, k in zip(x_shape, h_shape)]
  909. N = _prod(full_out_shape)
  910. fft_ops = 3 * N * np.log(N) # 3 separate FFTs of size full_out_shape
  911. return fft_ops, direct_ops
  912. def _fftconv_faster(x, h, mode):
  913. """
  914. See if using fftconvolve or convolve is faster.
  915. Parameters
  916. ----------
  917. x : np.ndarray
  918. Signal
  919. h : np.ndarray
  920. Kernel
  921. mode : str
  922. Mode passed to convolve
  923. Returns
  924. -------
  925. fft_faster : bool
  926. Notes
  927. -----
  928. See docstring of `choose_conv_method` for details on tuning hardware.
  929. See pull request 11031 for more detail:
  930. https://github.com/scipy/scipy/pull/11031.
  931. """
  932. fft_ops, direct_ops = _conv_ops(x.shape, h.shape, mode)
  933. offset = -1e-3 if x.ndim == 1 else -1e-4
  934. constants = {
  935. "valid": (1.89095737e-9, 2.1364985e-10, offset),
  936. "full": (1.7649070e-9, 2.1414831e-10, offset),
  937. "same": (3.2646654e-9, 2.8478277e-10, offset)
  938. if h.size <= x.size
  939. else (3.21635404e-9, 1.1773253e-8, -1e-5),
  940. } if x.ndim == 1 else {
  941. "valid": (1.85927e-9, 2.11242e-8, offset),
  942. "full": (1.99817e-9, 1.66174e-8, offset),
  943. "same": (2.04735e-9, 1.55367e-8, offset),
  944. }
  945. O_fft, O_direct, O_offset = constants[mode]
  946. return O_fft * fft_ops < O_direct * direct_ops + O_offset
  947. def _reverse_and_conj(x, xp):
  948. """
  949. Reverse array `x` in all dimensions and perform the complex conjugate
  950. """
  951. if not is_torch(xp):
  952. reverse = (slice(None, None, -1),) * x.ndim
  953. x_rev = x[reverse]
  954. else:
  955. # NB: is a copy, not a view as torch does not allow negative indices
  956. # in slices, x-ref https://github.com/pytorch/pytorch/issues/59786
  957. x_rev = xp.flip(x)
  958. # cf https://github.com/data-apis/array-api/issues/824
  959. if xp.isdtype(x.dtype, 'complex floating'):
  960. return xp.conj(x_rev)
  961. else:
  962. return x_rev
  963. def _np_conv_ok(volume, kernel, mode, xp):
  964. """
  965. See if numpy supports convolution of `volume` and `kernel` (i.e. both are
  966. 1D ndarrays and of the appropriate shape). NumPy's 'same' mode uses the
  967. size of the larger input, while SciPy's uses the size of the first input.
  968. Invalid mode strings will return False and be caught by the calling func.
  969. """
  970. if volume.ndim == kernel.ndim == 1:
  971. if mode in ('full', 'valid'):
  972. return True
  973. elif mode == 'same':
  974. return xp_size(volume) >= xp_size(kernel)
  975. else:
  976. return False
  977. def _timeit_fast(stmt="pass", setup="pass", repeat=3):
  978. """
  979. Returns the time the statement/function took, in seconds.
  980. Faster, less precise version of IPython's timeit. `stmt` can be a statement
  981. written as a string or a callable.
  982. Will do only 1 loop (like IPython's timeit) with no repetitions
  983. (unlike IPython) for very slow functions. For fast functions, only does
  984. enough loops to take 5 ms, which seems to produce similar results (on
  985. Windows at least), and avoids doing an extraneous cycle that isn't
  986. measured.
  987. """
  988. timer = timeit.Timer(stmt, setup)
  989. # determine number of calls per rep so total time for 1 rep >= 5 ms
  990. x = 0
  991. for p in range(0, 10):
  992. number = 10**p
  993. x = timer.timeit(number) # seconds
  994. if x >= 5e-3 / 10: # 5 ms for final test, 1/10th that for this one
  995. break
  996. if x > 1: # second
  997. # If it's macroscopic, don't bother with repetitions
  998. best = x
  999. else:
  1000. number *= 10
  1001. r = timer.repeat(repeat, number)
  1002. best = min(r)
  1003. sec = best / number
  1004. return sec
  1005. def choose_conv_method(in1, in2, mode='full', measure=False):
  1006. """
  1007. Find the fastest convolution/correlation method.
  1008. This primarily exists to be called during the ``method='auto'`` option in
  1009. `convolve` and `correlate`. It can also be used to determine the value of
  1010. ``method`` for many different convolutions of the same dtype/shape.
  1011. In addition, it supports timing the convolution to adapt the value of
  1012. ``method`` to a particular set of inputs and/or hardware.
  1013. Parameters
  1014. ----------
  1015. in1 : array_like
  1016. The first argument passed into the convolution function.
  1017. in2 : array_like
  1018. The second argument passed into the convolution function.
  1019. mode : str {'full', 'valid', 'same'}, optional
  1020. A string indicating the size of the output:
  1021. ``full``
  1022. The output is the full discrete linear convolution
  1023. of the inputs. (Default)
  1024. ``valid``
  1025. The output consists only of those elements that do not
  1026. rely on the zero-padding.
  1027. ``same``
  1028. The output is the same size as `in1`, centered
  1029. with respect to the 'full' output.
  1030. measure : bool, optional
  1031. If True, run and time the convolution of `in1` and `in2` with both
  1032. methods and return the fastest. If False (default), predict the fastest
  1033. method using precomputed values.
  1034. Returns
  1035. -------
  1036. method : str
  1037. A string indicating which convolution method is fastest, either
  1038. 'direct' or 'fft'
  1039. times : dict, optional
  1040. A dictionary containing the times (in seconds) needed for each method.
  1041. This value is only returned if ``measure=True``.
  1042. See Also
  1043. --------
  1044. convolve
  1045. correlate
  1046. Notes
  1047. -----
  1048. Generally, this method is 99% accurate for 2D signals and 85% accurate
  1049. for 1D signals for randomly chosen input sizes. For precision, use
  1050. ``measure=True`` to find the fastest method by timing the convolution.
  1051. This can be used to avoid the minimal overhead of finding the fastest
  1052. ``method`` later, or to adapt the value of ``method`` to a particular set
  1053. of inputs.
  1054. Experiments were run on an Amazon EC2 r5a.2xlarge machine to test this
  1055. function. These experiments measured the ratio between the time required
  1056. when using ``method='auto'`` and the time required for the fastest method
  1057. (i.e., ``ratio = time_auto / min(time_fft, time_direct)``). In these
  1058. experiments, we found:
  1059. * There is a 95% chance of this ratio being less than 1.5 for 1D signals
  1060. and a 99% chance of being less than 2.5 for 2D signals.
  1061. * The ratio was always less than 2.5/5 for 1D/2D signals respectively.
  1062. * This function is most inaccurate for 1D convolutions that take between 1
  1063. and 10 milliseconds with ``method='direct'``. A good proxy for this
  1064. (at least in our experiments) is ``1e6 <= in1.size * in2.size <= 1e7``.
  1065. The 2D results almost certainly generalize to 3D/4D/etc because the
  1066. implementation is the same (the 1D implementation is different).
  1067. All the numbers above are specific to the EC2 machine. However, we did find
  1068. that this function generalizes fairly decently across hardware. The speed
  1069. tests were of similar quality (and even slightly better) than the same
  1070. tests performed on the machine to tune this function's numbers (a mid-2014
  1071. 15-inch MacBook Pro with 16GB RAM and a 2.5GHz Intel i7 processor).
  1072. There are cases when `fftconvolve` supports the inputs but this function
  1073. returns `direct` (e.g., to protect against floating point integer
  1074. precision).
  1075. .. versionadded:: 0.19
  1076. Examples
  1077. --------
  1078. Estimate the fastest method for a given input:
  1079. >>> import numpy as np
  1080. >>> from scipy import signal
  1081. >>> rng = np.random.default_rng()
  1082. >>> img = rng.random((32, 32))
  1083. >>> filter = rng.random((8, 8))
  1084. >>> method = signal.choose_conv_method(img, filter, mode='same')
  1085. >>> method
  1086. 'fft'
  1087. This can then be applied to other arrays of the same dtype and shape:
  1088. >>> img2 = rng.random((32, 32))
  1089. >>> filter2 = rng.random((8, 8))
  1090. >>> corr2 = signal.correlate(img2, filter2, mode='same', method=method)
  1091. >>> conv2 = signal.convolve(img2, filter2, mode='same', method=method)
  1092. The output of this function (``method``) works with `correlate` and
  1093. `convolve`.
  1094. """
  1095. xp = array_namespace(in1, in2)
  1096. volume = xp.asarray(in1)
  1097. kernel = xp.asarray(in2)
  1098. if measure:
  1099. times = {}
  1100. for method in ['fft', 'direct']:
  1101. times[method] = _timeit_fast(lambda: convolve(volume, kernel,
  1102. mode=mode, method=method))
  1103. chosen_method = 'fft' if times['fft'] < times['direct'] else 'direct'
  1104. return chosen_method, times
  1105. # for integer input,
  1106. # catch when more precision required than float provides (representing an
  1107. # integer as float can lose precision in fftconvolve if larger than 2**52)
  1108. if any([_numeric_arrays([x], kinds='ui', xp=xp) for x in [volume, kernel]]):
  1109. max_value = int(xp.max(xp.abs(volume))) * int(xp.max(xp.abs(kernel)))
  1110. max_value *= int(min(xp_size(volume), xp_size(kernel)))
  1111. if max_value > 2**np.finfo('float').nmant - 1:
  1112. return 'direct'
  1113. if _numeric_arrays([volume, kernel], kinds='b', xp=xp):
  1114. return 'direct'
  1115. if _numeric_arrays([volume, kernel], xp=xp):
  1116. if _fftconv_faster(volume, kernel, mode):
  1117. return 'fft'
  1118. return 'direct'
  1119. def convolve(in1, in2, mode='full', method='auto'):
  1120. """
  1121. Convolve two N-dimensional arrays.
  1122. Convolve `in1` and `in2`, with the output size determined by the
  1123. `mode` argument.
  1124. Parameters
  1125. ----------
  1126. in1 : array_like
  1127. First input.
  1128. in2 : array_like
  1129. Second input. Should have the same number of dimensions as `in1`.
  1130. mode : str {'full', 'valid', 'same'}, optional
  1131. A string indicating the size of the output:
  1132. ``full``
  1133. The output is the full discrete linear convolution
  1134. of the inputs. (Default)
  1135. ``valid``
  1136. The output consists only of those elements that do not
  1137. rely on the zero-padding. In 'valid' mode, either `in1` or `in2`
  1138. must be at least as large as the other in every dimension.
  1139. ``same``
  1140. The output is the same size as `in1`, centered
  1141. with respect to the 'full' output.
  1142. method : str {'auto', 'direct', 'fft'}, optional
  1143. A string indicating which method to use to calculate the convolution.
  1144. ``direct``
  1145. The convolution is determined directly from sums, the definition of
  1146. convolution.
  1147. ``fft``
  1148. The Fourier Transform is used to perform the convolution by calling
  1149. `fftconvolve`.
  1150. ``auto``
  1151. Automatically chooses direct or Fourier method based on an estimate
  1152. of which is faster (default). See Notes for more detail.
  1153. .. versionadded:: 0.19.0
  1154. Returns
  1155. -------
  1156. convolve : array
  1157. An N-dimensional array containing a subset of the discrete linear
  1158. convolution of `in1` with `in2`.
  1159. Warns
  1160. -----
  1161. RuntimeWarning
  1162. Use of the FFT convolution on input containing NAN or INF will lead
  1163. to the entire output being NAN or INF. Use method='direct' when your
  1164. input contains NAN or INF values.
  1165. See Also
  1166. --------
  1167. numpy.polymul : performs polynomial multiplication (same operation, but
  1168. also accepts poly1d objects)
  1169. choose_conv_method : chooses the fastest appropriate convolution method
  1170. fftconvolve : Always uses the FFT method.
  1171. oaconvolve : Uses the overlap-add method to do convolution, which is
  1172. generally faster when the input arrays are large and
  1173. significantly different in size.
  1174. Notes
  1175. -----
  1176. By default, `convolve` and `correlate` use ``method='auto'``, which calls
  1177. `choose_conv_method` to choose the fastest method using pre-computed
  1178. values (`choose_conv_method` can also measure real-world timing with a
  1179. keyword argument). Because `fftconvolve` relies on floating point numbers,
  1180. there are certain constraints that may force ``method='direct'`` (more detail
  1181. in `choose_conv_method` docstring).
  1182. Examples
  1183. --------
  1184. Smooth a square pulse using a Hann window:
  1185. >>> import numpy as np
  1186. >>> from scipy import signal
  1187. >>> sig = np.repeat([0., 1., 0.], 100)
  1188. >>> win = signal.windows.hann(50)
  1189. >>> filtered = signal.convolve(sig, win, mode='same') / sum(win)
  1190. >>> import matplotlib.pyplot as plt
  1191. >>> fig, (ax_orig, ax_win, ax_filt) = plt.subplots(3, 1, sharex=True)
  1192. >>> ax_orig.plot(sig)
  1193. >>> ax_orig.set_title('Original pulse')
  1194. >>> ax_orig.margins(0, 0.1)
  1195. >>> ax_win.plot(win)
  1196. >>> ax_win.set_title('Filter impulse response')
  1197. >>> ax_win.margins(0, 0.1)
  1198. >>> ax_filt.plot(filtered)
  1199. >>> ax_filt.set_title('Filtered signal')
  1200. >>> ax_filt.margins(0, 0.1)
  1201. >>> fig.tight_layout()
  1202. >>> fig.show()
  1203. """
  1204. xp = array_namespace(in1, in2)
  1205. volume = xp.asarray(in1)
  1206. kernel = xp.asarray(in2)
  1207. if volume.ndim == kernel.ndim == 0:
  1208. return volume * kernel
  1209. elif volume.ndim != kernel.ndim:
  1210. raise ValueError("volume and kernel should have the same "
  1211. "dimensionality")
  1212. if _inputs_swap_needed(mode, volume.shape, kernel.shape):
  1213. # Convolution is commutative; order doesn't have any effect on output
  1214. volume, kernel = kernel, volume
  1215. if method == 'auto':
  1216. method = choose_conv_method(volume, kernel, mode=mode)
  1217. if method == 'fft':
  1218. out = fftconvolve(volume, kernel, mode=mode)
  1219. result_type = xp.result_type(volume, kernel)
  1220. if xp.isdtype(result_type, 'integral'):
  1221. out = xp.round(out)
  1222. if xp.isnan(xp.reshape(out, (-1,))[0]) or xp.isinf(xp.reshape(out, (-1,))[0]):
  1223. warnings.warn("Use of fft convolution on input with NAN or inf"
  1224. " results in NAN or inf output. Consider using"
  1225. " method='direct' instead.",
  1226. category=RuntimeWarning, stacklevel=2)
  1227. return xp.astype(out, result_type)
  1228. elif method == 'direct':
  1229. # fastpath to faster numpy.convolve for 1d inputs when possible
  1230. if _np_conv_ok(volume, kernel, mode, xp):
  1231. # convert to numpy and back
  1232. a_volume = np.asarray(volume)
  1233. a_kernel = np.asarray(kernel)
  1234. out = np.convolve(a_volume, a_kernel, mode)
  1235. return xp.asarray(out)
  1236. return correlate(volume, _reverse_and_conj(kernel, xp), mode, 'direct')
  1237. else:
  1238. raise ValueError("Acceptable method flags are 'auto',"
  1239. " 'direct', or 'fft'.")
  1240. def order_filter(a, domain, rank):
  1241. """
  1242. Perform an order filter on an N-D array.
  1243. Perform an order filter on the array in. The domain argument acts as a
  1244. mask centered over each pixel. The non-zero elements of domain are
  1245. used to select elements surrounding each input pixel which are placed
  1246. in a list. The list is sorted, and the output for that pixel is the
  1247. element corresponding to rank in the sorted list.
  1248. Parameters
  1249. ----------
  1250. a : ndarray
  1251. The N-dimensional input array.
  1252. domain : array_like
  1253. A mask array with the same number of dimensions as `a`.
  1254. Each dimension should have an odd number of elements.
  1255. rank : int
  1256. A non-negative integer which selects the element from the
  1257. sorted list (0 corresponds to the smallest element, 1 is the
  1258. next smallest element, etc.).
  1259. Returns
  1260. -------
  1261. out : ndarray
  1262. The results of the order filter in an array with the same
  1263. shape as `a`.
  1264. Examples
  1265. --------
  1266. >>> import numpy as np
  1267. >>> from scipy import signal
  1268. >>> x = np.arange(25).reshape(5, 5)
  1269. >>> domain = np.identity(3)
  1270. >>> x
  1271. array([[ 0, 1, 2, 3, 4],
  1272. [ 5, 6, 7, 8, 9],
  1273. [10, 11, 12, 13, 14],
  1274. [15, 16, 17, 18, 19],
  1275. [20, 21, 22, 23, 24]])
  1276. >>> signal.order_filter(x, domain, 0)
  1277. array([[ 0, 0, 0, 0, 0],
  1278. [ 0, 0, 1, 2, 0],
  1279. [ 0, 5, 6, 7, 0],
  1280. [ 0, 10, 11, 12, 0],
  1281. [ 0, 0, 0, 0, 0]])
  1282. >>> signal.order_filter(x, domain, 2)
  1283. array([[ 6, 7, 8, 9, 4],
  1284. [ 11, 12, 13, 14, 9],
  1285. [ 16, 17, 18, 19, 14],
  1286. [ 21, 22, 23, 24, 19],
  1287. [ 20, 21, 22, 23, 24]])
  1288. """
  1289. xp = array_namespace(a, domain)
  1290. domain = xp.asarray(domain)
  1291. for dimsize in domain.shape:
  1292. if (dimsize % 2) != 1:
  1293. raise ValueError("Each dimension of domain argument "
  1294. "should have an odd number of elements.")
  1295. a = xp.asarray(a)
  1296. if not (
  1297. xp.isdtype(a.dtype, "integral") or a.dtype in (xp.float32, xp.float64)
  1298. ):
  1299. raise ValueError(f"dtype={a.dtype} is not supported by order_filter")
  1300. result = ndimage.rank_filter(a, rank, footprint=domain, mode='constant')
  1301. return result
  1302. def medfilt(volume, kernel_size=None):
  1303. """
  1304. Perform a median filter on an N-dimensional array.
  1305. Apply a median filter to the input array using a local window-size
  1306. given by `kernel_size`. The array will automatically be zero-padded.
  1307. Parameters
  1308. ----------
  1309. volume : array_like
  1310. An N-dimensional input array.
  1311. kernel_size : array_like, optional
  1312. A scalar or an N-length list giving the size of the median filter
  1313. window in each dimension. Elements of `kernel_size` should be odd.
  1314. If `kernel_size` is a scalar, then this scalar is used as the size in
  1315. each dimension. Default size is 3 for each dimension.
  1316. Returns
  1317. -------
  1318. out : ndarray
  1319. An array the same size as input containing the median filtered
  1320. result.
  1321. Warns
  1322. -----
  1323. UserWarning
  1324. If array size is smaller than kernel size along any dimension
  1325. See Also
  1326. --------
  1327. scipy.ndimage.median_filter
  1328. scipy.signal.medfilt2d
  1329. """
  1330. xp = array_namespace(volume)
  1331. volume = xp.asarray(volume)
  1332. if volume.ndim == 0:
  1333. volume = xpx.atleast_nd(volume, ndim=1, xp=xp)
  1334. if not (xp.isdtype(volume.dtype, "integral") or
  1335. volume.dtype in [xp.float32, xp.float64]):
  1336. raise ValueError(f"dtype={volume.dtype} is not supported by medfilt")
  1337. if kernel_size is None:
  1338. kernel_size = [3] * volume.ndim
  1339. kernel_size = xp.asarray(kernel_size)
  1340. if kernel_size.shape == ():
  1341. kernel_size = xp.repeat(kernel_size, volume.ndim)
  1342. for k in range(volume.ndim):
  1343. if (kernel_size[k] % 2) != 1:
  1344. raise ValueError("Each element of kernel_size should be odd.")
  1345. if any(k > s for k, s in zip(kernel_size, volume.shape)):
  1346. warnings.warn('kernel_size exceeds volume extent: the volume will be '
  1347. 'zero-padded.',
  1348. stacklevel=2)
  1349. size = math.prod(kernel_size)
  1350. result = ndimage.rank_filter(volume, size // 2, size=kernel_size,
  1351. mode='constant')
  1352. return result
  1353. def wiener(im, mysize=None, noise=None):
  1354. """
  1355. Perform a Wiener filter on an N-dimensional array.
  1356. Apply a Wiener filter to the N-dimensional array `im`.
  1357. Parameters
  1358. ----------
  1359. im : ndarray
  1360. An N-dimensional array.
  1361. mysize : int or array_like, optional
  1362. A scalar or an N-length list giving the size of the Wiener filter
  1363. window in each dimension. Elements of mysize should be odd.
  1364. If mysize is a scalar, then this scalar is used as the size
  1365. in each dimension.
  1366. noise : float, optional
  1367. The noise-power to use. If None, then noise is estimated as the
  1368. average of the local variance of the input.
  1369. Returns
  1370. -------
  1371. out : ndarray
  1372. Wiener filtered result with the same shape as `im`.
  1373. Notes
  1374. -----
  1375. This implementation is similar to wiener2 in Matlab/Octave.
  1376. For more details see [1]_
  1377. References
  1378. ----------
  1379. .. [1] Lim, Jae S., Two-Dimensional Signal and Image Processing,
  1380. Englewood Cliffs, NJ, Prentice Hall, 1990, p. 548.
  1381. Examples
  1382. --------
  1383. >>> from scipy.datasets import face
  1384. >>> from scipy.signal import wiener
  1385. >>> import matplotlib.pyplot as plt
  1386. >>> import numpy as np
  1387. >>> rng = np.random.default_rng()
  1388. >>> img = rng.random((40, 40)) #Create a random image
  1389. >>> filtered_img = wiener(img, (5, 5)) #Filter the image
  1390. >>> f, (plot1, plot2) = plt.subplots(1, 2)
  1391. >>> plot1.imshow(img)
  1392. >>> plot2.imshow(filtered_img)
  1393. >>> plt.show()
  1394. """
  1395. xp = array_namespace(im)
  1396. im = xp.asarray(im)
  1397. if mysize is None:
  1398. mysize = [3] * im.ndim
  1399. mysize_arr = xp.asarray(mysize)
  1400. if mysize_arr.shape == ():
  1401. mysize = [mysize] * im.ndim
  1402. # Estimate the local mean
  1403. size = math.prod(mysize)
  1404. lMean = correlate(im, xp.ones(mysize), 'same')
  1405. lsize = float(size)
  1406. lMean = lMean / lsize
  1407. # Estimate the local variance
  1408. lVar = (correlate(im ** 2, xp.ones(mysize), 'same') / lsize - lMean ** 2)
  1409. # Estimate the noise power if needed.
  1410. if noise is None:
  1411. noise = xp.mean(xp.reshape(lVar, (-1,)), axis=0)
  1412. res = (im - lMean)
  1413. res *= (1 - noise / lVar)
  1414. res += lMean
  1415. out = xp.where(lVar < noise, lMean, res)
  1416. return out
  1417. def convolve2d(in1, in2, mode='full', boundary='fill', fillvalue=0):
  1418. """
  1419. Convolve two 2-dimensional arrays.
  1420. Convolve `in1` and `in2` with output size determined by `mode`, and
  1421. boundary conditions determined by `boundary` and `fillvalue`.
  1422. Parameters
  1423. ----------
  1424. in1 : array_like
  1425. First input.
  1426. in2 : array_like
  1427. Second input. Should have the same number of dimensions as `in1`.
  1428. mode : str {'full', 'valid', 'same'}, optional
  1429. A string indicating the size of the output:
  1430. ``full``
  1431. The output is the full discrete linear convolution
  1432. of the inputs. (Default)
  1433. ``valid``
  1434. The output consists only of those elements that do not
  1435. rely on the zero-padding. In 'valid' mode, either `in1` or `in2`
  1436. must be at least as large as the other in every dimension.
  1437. ``same``
  1438. The output is the same size as `in1`, centered
  1439. with respect to the 'full' output.
  1440. boundary : str {'fill', 'wrap', 'symm'}, optional
  1441. A flag indicating how to handle boundaries:
  1442. ``fill``
  1443. pad input arrays with fillvalue. (default)
  1444. ``wrap``
  1445. circular boundary conditions.
  1446. ``symm``
  1447. symmetrical boundary conditions.
  1448. fillvalue : scalar, optional
  1449. Value to fill pad input arrays with. Default is 0.
  1450. Returns
  1451. -------
  1452. out : ndarray
  1453. A 2-dimensional array containing a subset of the discrete linear
  1454. convolution of `in1` with `in2`.
  1455. Examples
  1456. --------
  1457. Compute the gradient of an image by 2D convolution with a complex Scharr
  1458. operator. (Horizontal operator is real, vertical is imaginary.) Use
  1459. symmetric boundary condition to avoid creating edges at the image
  1460. boundaries.
  1461. >>> import numpy as np
  1462. >>> from scipy import signal
  1463. >>> from scipy import datasets
  1464. >>> ascent = datasets.ascent()
  1465. >>> scharr = np.array([[ -3-3j, 0-10j, +3 -3j],
  1466. ... [-10+0j, 0+ 0j, +10 +0j],
  1467. ... [ -3+3j, 0+10j, +3 +3j]]) # Gx + j*Gy
  1468. >>> grad = signal.convolve2d(ascent, scharr, boundary='symm', mode='same')
  1469. >>> import matplotlib.pyplot as plt
  1470. >>> fig, (ax_orig, ax_mag, ax_ang) = plt.subplots(3, 1, figsize=(6, 15))
  1471. >>> ax_orig.imshow(ascent, cmap='gray')
  1472. >>> ax_orig.set_title('Original')
  1473. >>> ax_orig.set_axis_off()
  1474. >>> ax_mag.imshow(np.absolute(grad), cmap='gray')
  1475. >>> ax_mag.set_title('Gradient magnitude')
  1476. >>> ax_mag.set_axis_off()
  1477. >>> ax_ang.imshow(np.angle(grad), cmap='hsv') # hsv is cyclic, like angles
  1478. >>> ax_ang.set_title('Gradient orientation')
  1479. >>> ax_ang.set_axis_off()
  1480. >>> fig.show()
  1481. """
  1482. xp = array_namespace(in1, in2)
  1483. # NB: do work in NumPy, only convert the output
  1484. in1 = np.asarray(in1)
  1485. in2 = np.asarray(in2)
  1486. if not in1.ndim == in2.ndim == 2:
  1487. raise ValueError('convolve2d inputs must both be 2-D arrays')
  1488. if _inputs_swap_needed(mode, in1.shape, in2.shape):
  1489. in1, in2 = in2, in1
  1490. val = _valfrommode(mode)
  1491. bval = _bvalfromboundary(boundary)
  1492. out = _sigtools._convolve2d(in1, in2, 1, val, bval, fillvalue)
  1493. return xp.asarray(out)
  1494. def correlate2d(in1, in2, mode='full', boundary='fill', fillvalue=0):
  1495. """
  1496. Cross-correlate two 2-dimensional arrays.
  1497. Cross correlate `in1` and `in2` with output size determined by `mode`, and
  1498. boundary conditions determined by `boundary` and `fillvalue`.
  1499. Parameters
  1500. ----------
  1501. in1 : array_like
  1502. First input.
  1503. in2 : array_like
  1504. Second input. Should have the same number of dimensions as `in1`.
  1505. mode : str {'full', 'valid', 'same'}, optional
  1506. A string indicating the size of the output:
  1507. ``full``
  1508. The output is the full discrete linear cross-correlation
  1509. of the inputs. (Default)
  1510. ``valid``
  1511. The output consists only of those elements that do not
  1512. rely on the zero-padding. In 'valid' mode, either `in1` or `in2`
  1513. must be at least as large as the other in every dimension.
  1514. ``same``
  1515. The output is the same size as `in1`, centered
  1516. with respect to the 'full' output.
  1517. boundary : str {'fill', 'wrap', 'symm'}, optional
  1518. A flag indicating how to handle boundaries:
  1519. ``fill``
  1520. pad input arrays with fillvalue. (default)
  1521. ``wrap``
  1522. circular boundary conditions.
  1523. ``symm``
  1524. symmetrical boundary conditions.
  1525. fillvalue : scalar, optional
  1526. Value to fill pad input arrays with. Default is 0.
  1527. Returns
  1528. -------
  1529. correlate2d : ndarray
  1530. A 2-dimensional array containing a subset of the discrete linear
  1531. cross-correlation of `in1` with `in2`.
  1532. Notes
  1533. -----
  1534. When using "same" mode with even-length inputs, the outputs of `correlate`
  1535. and `correlate2d` differ: There is a 1-index offset between them.
  1536. Examples
  1537. --------
  1538. Use 2D cross-correlation to find the location of a template in a noisy
  1539. image:
  1540. >>> import numpy as np
  1541. >>> from scipy import signal, datasets, ndimage
  1542. >>> rng = np.random.default_rng()
  1543. >>> face = datasets.face(gray=True) - datasets.face(gray=True).mean()
  1544. >>> face = ndimage.zoom(face[30:500, 400:950], 0.5) # extract the face
  1545. >>> template = np.copy(face[135:165, 140:175]) # right eye
  1546. >>> template -= template.mean()
  1547. >>> face = face + rng.standard_normal(face.shape) * 50 # add noise
  1548. >>> corr = signal.correlate2d(face, template, boundary='symm', mode='same')
  1549. >>> y, x = np.unravel_index(np.argmax(corr), corr.shape) # find the match
  1550. >>> import matplotlib.pyplot as plt
  1551. >>> fig, (ax_orig, ax_template, ax_corr) = plt.subplots(3, 1,
  1552. ... figsize=(6, 15))
  1553. >>> ax_orig.imshow(face, cmap='gray')
  1554. >>> ax_orig.set_title('Original')
  1555. >>> ax_orig.set_axis_off()
  1556. >>> ax_template.imshow(template, cmap='gray')
  1557. >>> ax_template.set_title('Template')
  1558. >>> ax_template.set_axis_off()
  1559. >>> ax_corr.imshow(corr, cmap='gray')
  1560. >>> ax_corr.set_title('Cross-correlation')
  1561. >>> ax_corr.set_axis_off()
  1562. >>> ax_orig.plot(x, y, 'ro')
  1563. >>> fig.show()
  1564. """
  1565. xp = array_namespace(in1, in2)
  1566. in1 = np.asarray(in1)
  1567. in2 = np.asarray(in2)
  1568. if not in1.ndim == in2.ndim == 2:
  1569. raise ValueError('correlate2d inputs must both be 2-D arrays')
  1570. swapped_inputs = _inputs_swap_needed(mode, in1.shape, in2.shape)
  1571. if swapped_inputs:
  1572. in1, in2 = in2, in1
  1573. val = _valfrommode(mode)
  1574. bval = _bvalfromboundary(boundary)
  1575. out = _sigtools._convolve2d(in1, in2.conj(), 0, val, bval, fillvalue)
  1576. if swapped_inputs:
  1577. out = out[::-1, ::-1]
  1578. return xp.asarray(out)
  1579. def medfilt2d(input, kernel_size=3):
  1580. """
  1581. Median filter a 2-dimensional array.
  1582. Apply a median filter to the `input` array using a local window-size
  1583. given by `kernel_size` (must be odd). The array is zero-padded
  1584. automatically.
  1585. Parameters
  1586. ----------
  1587. input : array_like
  1588. A 2-dimensional input array.
  1589. kernel_size : array_like, optional
  1590. A scalar or a list of length 2, giving the size of the
  1591. median filter window in each dimension. Elements of
  1592. `kernel_size` should be odd. If `kernel_size` is a scalar,
  1593. then this scalar is used as the size in each dimension.
  1594. Default is a kernel of size (3, 3).
  1595. Returns
  1596. -------
  1597. out : ndarray
  1598. An array the same size as input containing the median filtered
  1599. result.
  1600. See Also
  1601. --------
  1602. scipy.ndimage.median_filter
  1603. Notes
  1604. -----
  1605. This is faster than `medfilt` when the input dtype is ``uint8``,
  1606. ``float32``, or ``float64``; for other types, this falls back to
  1607. `medfilt`. In some situations, `scipy.ndimage.median_filter` may be
  1608. faster than this function.
  1609. Examples
  1610. --------
  1611. >>> import numpy as np
  1612. >>> from scipy import signal
  1613. >>> x = np.arange(25).reshape(5, 5)
  1614. >>> x
  1615. array([[ 0, 1, 2, 3, 4],
  1616. [ 5, 6, 7, 8, 9],
  1617. [10, 11, 12, 13, 14],
  1618. [15, 16, 17, 18, 19],
  1619. [20, 21, 22, 23, 24]])
  1620. # Replaces i,j with the median out of 5*5 window
  1621. >>> signal.medfilt2d(x, kernel_size=5)
  1622. array([[ 0, 0, 2, 0, 0],
  1623. [ 0, 3, 7, 4, 0],
  1624. [ 2, 8, 12, 9, 4],
  1625. [ 0, 8, 12, 9, 0],
  1626. [ 0, 0, 12, 0, 0]])
  1627. # Replaces i,j with the median out of default 3*3 window
  1628. >>> signal.medfilt2d(x)
  1629. array([[ 0, 1, 2, 3, 0],
  1630. [ 1, 6, 7, 8, 4],
  1631. [ 6, 11, 12, 13, 9],
  1632. [11, 16, 17, 18, 14],
  1633. [ 0, 16, 17, 18, 0]])
  1634. # Replaces i,j with the median out of default 5*3 window
  1635. >>> signal.medfilt2d(x, kernel_size=[5,3])
  1636. array([[ 0, 1, 2, 3, 0],
  1637. [ 0, 6, 7, 8, 3],
  1638. [ 5, 11, 12, 13, 8],
  1639. [ 5, 11, 12, 13, 8],
  1640. [ 0, 11, 12, 13, 0]])
  1641. # Replaces i,j with the median out of default 3*5 window
  1642. >>> signal.medfilt2d(x, kernel_size=[3,5])
  1643. array([[ 0, 0, 2, 1, 0],
  1644. [ 1, 5, 7, 6, 3],
  1645. [ 6, 10, 12, 11, 8],
  1646. [11, 15, 17, 16, 13],
  1647. [ 0, 15, 17, 16, 0]])
  1648. # As seen in the examples,
  1649. # kernel numbers must be odd and not exceed original array dim
  1650. """
  1651. xp = array_namespace(input)
  1652. image = np.asarray(input)
  1653. # checking dtype.type, rather than just dtype, is necessary for
  1654. # excluding np.longdouble with MS Visual C.
  1655. if image.dtype.type not in (np.ubyte, np.float32, np.float64):
  1656. return xp.asarray(medfilt(image, kernel_size))
  1657. if kernel_size is None:
  1658. kernel_size = [3] * 2
  1659. kernel_size = np.asarray(kernel_size)
  1660. if kernel_size.shape == ():
  1661. kernel_size = np.repeat(kernel_size.item(), 2)
  1662. for size in kernel_size:
  1663. if (size % 2) != 1:
  1664. raise ValueError("Each element of kernel_size should be odd.")
  1665. result_np = _sigtools._medfilt2d(image, kernel_size)
  1666. return xp.asarray(result_np)
  1667. def lfilter(b, a, x, axis=-1, zi=None):
  1668. """
  1669. Filter data along one-dimension with an IIR or FIR filter.
  1670. Filter a data sequence, `x`, using a digital filter. This works for many
  1671. fundamental data types (including Object type). The filter is a direct
  1672. form II transposed implementation of the standard difference equation
  1673. (see Notes).
  1674. The function `sosfilt` (and filter design using ``output='sos'``) should be
  1675. preferred over `lfilter` for most filtering tasks, as second-order sections
  1676. have fewer numerical problems.
  1677. Parameters
  1678. ----------
  1679. b : array_like
  1680. The numerator coefficient vector in a 1-D sequence.
  1681. a : array_like
  1682. The denominator coefficient vector in a 1-D sequence. If ``a[0]``
  1683. is not 1, then both `a` and `b` are normalized by ``a[0]``.
  1684. x : array_like
  1685. An N-dimensional input array.
  1686. axis : int, optional
  1687. The axis of the input data array along which to apply the
  1688. linear filter. The filter is applied to each subarray along
  1689. this axis. Default is -1.
  1690. zi : array_like, optional
  1691. Initial conditions for the filter delays. It is a vector
  1692. (or array of vectors for an N-dimensional input) of length
  1693. ``max(len(a), len(b)) - 1``. If `zi` is None or is not given then
  1694. initial rest is assumed. See `lfiltic` for more information.
  1695. Returns
  1696. -------
  1697. y : array
  1698. The output of the digital filter.
  1699. zf : array, optional
  1700. If `zi` is None, this is not returned, otherwise, `zf` holds the
  1701. final filter delay values.
  1702. See Also
  1703. --------
  1704. lfiltic : Construct initial conditions for `lfilter`.
  1705. lfilter_zi : Compute initial state (steady state of step response) for
  1706. `lfilter`.
  1707. filtfilt : A forward-backward filter, to obtain a filter with zero phase.
  1708. savgol_filter : A Savitzky-Golay filter.
  1709. sosfilt: Filter data using cascaded second-order sections.
  1710. sosfiltfilt: A forward-backward filter using second-order sections.
  1711. Notes
  1712. -----
  1713. The filter function is implemented as a direct II transposed structure.
  1714. This means that the filter implements::
  1715. a[0]*y[n] = b[0]*x[n] + b[1]*x[n-1] + ... + b[M]*x[n-M]
  1716. - a[1]*y[n-1] - ... - a[N]*y[n-N]
  1717. where `M` is the degree of the numerator, `N` is the degree of the
  1718. denominator, and `n` is the sample number. It is implemented using
  1719. the following difference equations (assuming M = N)::
  1720. a[0]*y[n] = b[0] * x[n] + d[0][n-1]
  1721. d[0][n] = b[1] * x[n] - a[1] * y[n] + d[1][n-1]
  1722. d[1][n] = b[2] * x[n] - a[2] * y[n] + d[2][n-1]
  1723. ...
  1724. d[N-2][n] = b[N-1]*x[n] - a[N-1]*y[n] + d[N-1][n-1]
  1725. d[N-1][n] = b[N] * x[n] - a[N] * y[n]
  1726. where `d` are the state variables.
  1727. The rational transfer function describing this filter in the
  1728. z-transform domain is::
  1729. -1 -M
  1730. b[0] + b[1]z + ... + b[M] z
  1731. Y(z) = -------------------------------- X(z)
  1732. -1 -N
  1733. a[0] + a[1]z + ... + a[N] z
  1734. Examples
  1735. --------
  1736. Generate a noisy signal to be filtered:
  1737. >>> import numpy as np
  1738. >>> from scipy import signal
  1739. >>> import matplotlib.pyplot as plt
  1740. >>> rng = np.random.default_rng()
  1741. >>> t = np.linspace(-1, 1, 201)
  1742. >>> x = (np.sin(2*np.pi*0.75*t*(1-t) + 2.1) +
  1743. ... 0.1*np.sin(2*np.pi*1.25*t + 1) +
  1744. ... 0.18*np.cos(2*np.pi*3.85*t))
  1745. >>> xn = x + rng.standard_normal(len(t)) * 0.08
  1746. Create an order 3 lowpass butterworth filter:
  1747. >>> b, a = signal.butter(3, 0.05)
  1748. Apply the filter to xn. Use lfilter_zi to choose the initial condition of
  1749. the filter:
  1750. >>> zi = signal.lfilter_zi(b, a)
  1751. >>> z, _ = signal.lfilter(b, a, xn, zi=zi*xn[0])
  1752. Apply the filter again, to have a result filtered at an order the same as
  1753. filtfilt:
  1754. >>> z2, _ = signal.lfilter(b, a, z, zi=zi*z[0])
  1755. Use filtfilt to apply the filter:
  1756. >>> y = signal.filtfilt(b, a, xn)
  1757. Plot the original signal and the various filtered versions:
  1758. >>> plt.figure
  1759. >>> plt.plot(t, xn, 'b', alpha=0.75)
  1760. >>> plt.plot(t, z, 'r--', t, z2, 'r', t, y, 'k')
  1761. >>> plt.legend(('noisy signal', 'lfilter, once', 'lfilter, twice',
  1762. ... 'filtfilt'), loc='best')
  1763. >>> plt.grid(True)
  1764. >>> plt.show()
  1765. """
  1766. xp = array_namespace(b, a, x, zi)
  1767. b = np.atleast_1d(b)
  1768. a = np.atleast_1d(a)
  1769. x = np.asarray(x)
  1770. if zi is not None:
  1771. zi = np.asarray(zi)
  1772. if not (b.ndim == 1 and xp_size(b) > 0):
  1773. raise ValueError(f"Parameter b is not a non-empty 1d array, since {b.shape=}!")
  1774. if not (a.ndim == 1 and xp_size(a) > 0):
  1775. raise ValueError(f"Parameter a is not a non-empty 1d array, since {a.shape=}!")
  1776. if len(a) == 1:
  1777. # This path only supports types fdgFDGO to mirror _linear_filter below.
  1778. # Any of b, a, x, or zi can set the dtype, but there is no default
  1779. # casting of other types; instead a NotImplementedError is raised.
  1780. b = np.asarray(b)
  1781. a = np.asarray(a)
  1782. x = _validate_x(x)
  1783. inputs = [b, a, x]
  1784. if zi is not None:
  1785. # _linear_filter does not broadcast zi, but does do expansion of
  1786. # singleton dims.
  1787. zi = np.asarray(zi)
  1788. if zi.ndim != x.ndim:
  1789. raise ValueError("Dimensions of parameters x and zi must match, but " +
  1790. f"{x.ndim=}, {zi.ndim=}!")
  1791. expected_shape = list(x.shape)
  1792. expected_shape[axis] = b.shape[0] - 1
  1793. expected_shape = tuple(expected_shape)
  1794. # check the trivial case where zi is the right shape first
  1795. if zi.shape != expected_shape:
  1796. strides = zi.ndim * [None]
  1797. if axis < 0:
  1798. axis += zi.ndim
  1799. for k in range(zi.ndim):
  1800. if k == axis and zi.shape[k] == expected_shape[k]:
  1801. strides[k] = zi.strides[k]
  1802. elif k != axis and zi.shape[k] == expected_shape[k]:
  1803. strides[k] = zi.strides[k]
  1804. elif k != axis and zi.shape[k] == 1:
  1805. strides[k] = 0
  1806. else:
  1807. raise ValueError('Unexpected shape for parameter zi: expected '
  1808. f'{expected_shape}, found {zi.shape}.')
  1809. zi = np.lib.stride_tricks.as_strided(zi, expected_shape,
  1810. strides)
  1811. inputs.append(zi)
  1812. dtype = np.result_type(*inputs)
  1813. if dtype.char not in 'fdgFDGO':
  1814. raise NotImplementedError("Parameter's dtypes produced result type " +
  1815. f"'{dtype}', which is not supported!")
  1816. b = np.array(b, dtype=dtype)
  1817. a = np.asarray(a, dtype=dtype)
  1818. b /= a[0]
  1819. x = np.asarray(x, dtype=dtype)
  1820. out_full = np.apply_along_axis(lambda y: np.convolve(b, y), axis, x)
  1821. ind = out_full.ndim * [slice(None)]
  1822. if zi is not None:
  1823. ind[axis] = slice(zi.shape[axis])
  1824. out_full[tuple(ind)] += zi
  1825. ind[axis] = slice(out_full.shape[axis] - len(b) + 1)
  1826. out = out_full[tuple(ind)]
  1827. if zi is None:
  1828. return xp.asarray(out)
  1829. else:
  1830. ind[axis] = slice(out_full.shape[axis] - len(b) + 1, None)
  1831. zf = out_full[tuple(ind)]
  1832. return xp.asarray(out), xp.asarray(zf)
  1833. else:
  1834. if zi is None:
  1835. result =_sigtools._linear_filter(b, a, x, axis)
  1836. return xp.asarray(result)
  1837. else:
  1838. out, zf = _sigtools._linear_filter(b, a, x, axis, zi)
  1839. return xp.asarray(out), xp.asarray(zf)
  1840. def lfiltic(b, a, y, x=None):
  1841. """
  1842. Construct initial conditions for lfilter given input and output vectors.
  1843. Given a linear filter (b, a) and initial conditions on the output `y`
  1844. and the input `x`, return the initial conditions on the state vector zi
  1845. which is used by `lfilter` to generate the output given the input.
  1846. Parameters
  1847. ----------
  1848. b : array_like
  1849. Linear filter term.
  1850. a : array_like
  1851. Linear filter term.
  1852. y : array_like
  1853. Initial conditions.
  1854. If ``N = len(a) - 1``, then ``y = {y[-1], y[-2], ..., y[-N]}``.
  1855. If `y` is too short, it is padded with zeros.
  1856. x : array_like, optional
  1857. Initial conditions.
  1858. If ``M = len(b) - 1``, then ``x = {x[-1], x[-2], ..., x[-M]}``.
  1859. If `x` is not given, its initial conditions are assumed zero.
  1860. If `x` is too short, it is padded with zeros.
  1861. Returns
  1862. -------
  1863. zi : ndarray
  1864. The state vector ``zi = {z_0[-1], z_1[-1], ..., z_K-1[-1]}``,
  1865. where ``K = max(M, N)``.
  1866. See Also
  1867. --------
  1868. lfilter, lfilter_zi
  1869. """
  1870. xp = array_namespace(a, b, y, x)
  1871. a = xpx.atleast_nd(xp.asarray(a), ndim=1, xp=xp)
  1872. b = xpx.atleast_nd(xp.asarray(b), ndim=1, xp=xp)
  1873. if a.ndim > 1:
  1874. raise ValueError('Filter coefficients `a` must be 1-D.')
  1875. if b.ndim > 1:
  1876. raise ValueError('Filter coefficients `b` must be 1-D.')
  1877. N = a.shape[0] - 1
  1878. M = b.shape[0] - 1
  1879. K = max(M, N)
  1880. y = xp.asarray(y)
  1881. if N < 0:
  1882. raise ValueError("There must be at least one `a` coefficient.")
  1883. if x is None:
  1884. result_type = xp.result_type(b, a, y)
  1885. if xp.isdtype(result_type, ('bool', 'integral')): #'bui':
  1886. result_type = xp.float64
  1887. x = xp.zeros(M, dtype=result_type)
  1888. else:
  1889. x = xp.asarray(x)
  1890. result_type = xp.result_type(b, a, y, x)
  1891. if xp.isdtype(result_type, ('bool', 'integral')): #'bui':
  1892. result_type = xp.float64
  1893. x = xp.astype(x, result_type)
  1894. L = xp_size(x)
  1895. if L < M:
  1896. x = xp.concat((x, xp.zeros(M - L)))
  1897. y = xp.astype(y, result_type)
  1898. zi = xp.zeros(K, dtype=result_type)
  1899. L = xp_size(y)
  1900. if L < N:
  1901. y = xp.concat((y, xp.zeros(N - L)))
  1902. for m in range(M):
  1903. zi[m] = xp.sum(b[m + 1:] * x[:M - m], axis=0)
  1904. for m in range(N):
  1905. zi[m] -= xp.sum(a[m + 1:] * y[:N - m], axis=0)
  1906. if a[0] != 1.:
  1907. if a[0] == 0.:
  1908. raise ValueError("First `a` filter coefficient must be non-zero.")
  1909. zi /= a[0]
  1910. return zi
  1911. def deconvolve(signal, divisor):
  1912. """Deconvolves ``divisor`` out of ``signal`` using inverse filtering.
  1913. Returns the quotient and remainder such that
  1914. ``signal = convolve(divisor, quotient) + remainder``
  1915. Parameters
  1916. ----------
  1917. signal : (N,) array_like
  1918. Signal data, typically a recorded signal
  1919. divisor : (N,) array_like
  1920. Divisor data, typically an impulse response or filter that was
  1921. applied to the original signal
  1922. Returns
  1923. -------
  1924. quotient : ndarray
  1925. Quotient, typically the recovered original signal
  1926. remainder : ndarray
  1927. Remainder
  1928. See Also
  1929. --------
  1930. numpy.polydiv : performs polynomial division (same operation, but
  1931. also accepts poly1d objects)
  1932. Examples
  1933. --------
  1934. Deconvolve a signal that's been filtered:
  1935. >>> from scipy import signal
  1936. >>> original = [0, 1, 0, 0, 1, 1, 0, 0]
  1937. >>> impulse_response = [2, 1]
  1938. >>> recorded = signal.convolve(impulse_response, original)
  1939. >>> recorded
  1940. array([0, 2, 1, 0, 2, 3, 1, 0, 0])
  1941. >>> recovered, remainder = signal.deconvolve(recorded, impulse_response)
  1942. >>> recovered
  1943. array([ 0., 1., 0., 0., 1., 1., 0., 0.])
  1944. >>> remainder
  1945. array([0., 0., 0., 0., 0., 0., 0., 0., 0.])
  1946. """
  1947. xp = array_namespace(signal, divisor)
  1948. num = xpx.atleast_nd(xp.asarray(signal), ndim=1, xp=xp)
  1949. den = xpx.atleast_nd(xp.asarray(divisor), ndim=1, xp=xp)
  1950. if not (num.ndim == 1 and xp_size(num) > 0):
  1951. raise ValueError("Parameter signal must be non-empty 1d array, " +
  1952. f"but its shape is {num.shape}!")
  1953. if not (den.ndim == 1 and xp_size(den) > 0):
  1954. raise ValueError("Parameter divisor must be non-empty 1d array, " +
  1955. f"but its shape is {den.shape}!")
  1956. N = num.shape[0]
  1957. D = den.shape[0]
  1958. if D > N:
  1959. quot = []
  1960. rem = num
  1961. else:
  1962. input = xp.zeros(N - D + 1, dtype=xp.float64)
  1963. input[0] = 1
  1964. quot = lfilter(num, den, input)
  1965. rem = num - convolve(den, quot, mode='full')
  1966. return quot, rem
  1967. def hilbert(x, N=None, axis=-1):
  1968. r"""FFT-based computation of the analytic signal.
  1969. The analytic signal is calculated by zeroing out the negative frequencies and
  1970. doubling the amplitudes of the positive frequencies in the FFT domain.
  1971. The imaginary part of the result is the hilbert transform of the real-valued input
  1972. signal.
  1973. The transformation is done along the last axis by default.
  1974. For numpy arrays, `scipy.fft.set_workers` can be used to change the number of
  1975. workers used for the FFTs.
  1976. Parameters
  1977. ----------
  1978. x : array_like
  1979. Signal data. Must be real.
  1980. N : int, optional
  1981. Number of output samples. `x` is initially cropped or zero-padded to length
  1982. `N` along `axis`. Default: ``x.shape[axis]``
  1983. axis : int, optional
  1984. Axis along which to do the transformation. Default: -1.
  1985. Returns
  1986. -------
  1987. xa : ndarray
  1988. Analytic signal of `x`, of each 1-D array along `axis`
  1989. Notes
  1990. -----
  1991. The analytic signal ``x_a(t)`` of a real-valued signal ``x(t)``
  1992. can be expressed as [1]_
  1993. .. math:: x_a = F^{-1}(F(x) 2U) = x + i y\ ,
  1994. where `F` is the Fourier transform, `U` the unit step function,
  1995. and `y` the Hilbert transform of `x`. [2]_
  1996. In other words, the negative half of the frequency spectrum is zeroed
  1997. out, turning the real-valued signal into a complex-valued signal. The Hilbert
  1998. transformed signal can be obtained from ``np.imag(hilbert(x))``, and the
  1999. original signal from ``np.real(hilbert(x))``.
  2000. References
  2001. ----------
  2002. .. [1] Wikipedia, "Analytic signal".
  2003. https://en.wikipedia.org/wiki/Analytic_signal
  2004. .. [2] Wikipedia, "Hilbert Transform".
  2005. https://en.wikipedia.org/wiki/Hilbert_transform
  2006. .. [3] Leon Cohen, "Time-Frequency Analysis", 1995. Chapter 2.
  2007. .. [4] Alan V. Oppenheim, Ronald W. Schafer. Discrete-Time Signal
  2008. Processing, Third Edition, 2009. Chapter 12.
  2009. ISBN 13: 978-1292-02572-8
  2010. See Also
  2011. --------
  2012. envelope: Compute envelope of a real- or complex-valued signal.
  2013. Examples
  2014. --------
  2015. In this example we use the Hilbert transform to determine the amplitude
  2016. envelope and instantaneous frequency of an amplitude-modulated signal.
  2017. Let's create a chirp of which the frequency increases from 20 Hz to 100 Hz and
  2018. apply an amplitude modulation:
  2019. >>> import numpy as np
  2020. >>> import matplotlib.pyplot as plt
  2021. >>> from scipy.signal import hilbert, chirp
  2022. ...
  2023. >>> duration, fs = 1, 400 # 1 s signal with sampling frequency of 400 Hz
  2024. >>> t = np.arange(int(fs*duration)) / fs # timestamps of samples
  2025. >>> signal = chirp(t, 20.0, t[-1], 100.0)
  2026. >>> signal *= (1.0 + 0.5 * np.sin(2.0*np.pi*3.0*t) )
  2027. The amplitude envelope is given by the magnitude of the analytic signal. The
  2028. instantaneous frequency can be obtained by differentiating the
  2029. instantaneous phase in respect to time. The instantaneous phase corresponds
  2030. to the phase angle of the analytic signal.
  2031. >>> analytic_signal = hilbert(signal)
  2032. >>> amplitude_envelope = np.abs(analytic_signal)
  2033. >>> instantaneous_phase = np.unwrap(np.angle(analytic_signal))
  2034. >>> instantaneous_frequency = np.diff(instantaneous_phase) / (2.0*np.pi) * fs
  2035. ...
  2036. >>> fig, (ax0, ax1) = plt.subplots(nrows=2, sharex='all', tight_layout=True)
  2037. >>> ax0.set_title("Amplitude-modulated Chirp Signal")
  2038. >>> ax0.set_ylabel("Amplitude")
  2039. >>> ax0.plot(t, signal, label='Signal')
  2040. >>> ax0.plot(t, amplitude_envelope, label='Envelope')
  2041. >>> ax0.legend()
  2042. >>> ax1.set(xlabel="Time in seconds", ylabel="Frequency in Hz", ylim=(0, 120))
  2043. >>> ax1.plot(t[1:], instantaneous_frequency, 'C2-',
  2044. ... label='Instantaneous Frequency')
  2045. >>> ax1.legend()
  2046. >>> plt.show()
  2047. """
  2048. xp = array_namespace(x)
  2049. x = xp.asarray(x)
  2050. if xp.isdtype(x.dtype, 'complex floating'):
  2051. raise ValueError("x must be real.")
  2052. if N is None:
  2053. N = x.shape[axis]
  2054. if N <= 0:
  2055. raise ValueError("N must be positive.")
  2056. Xf = sp_fft.fft(x, N, axis=axis)
  2057. Xf = xp.moveaxis(Xf, axis, -1)
  2058. if N % 2 == 0:
  2059. Xf[..., 1: N // 2] *= 2.0
  2060. Xf[..., N // 2 + 1:N] = 0.0
  2061. else:
  2062. Xf[..., 1:(N + 1) // 2] *= 2.0
  2063. Xf[..., (N + 1) // 2:N] = 0.0
  2064. Xf = xp.moveaxis(Xf, -1, axis)
  2065. x = sp_fft.ifft(Xf, axis=axis)
  2066. return x
  2067. def hilbert2(x, N=None, axes=(-2, -1)):
  2068. r"""Compute the '2-D' analytic signal of `x`.
  2069. The 2-D analytic signal is calculated as a so-called "single-orthant" transform.
  2070. This is achieved by applying one-dimensional Hilbert functions (as in
  2071. `~scipy.signal.hilbert`) to the first and to the second array axis in Fourier space.
  2072. For NumPy arrays, `scipy.fft.set_workers` can be used to change the number of
  2073. workers used for the FFTs.
  2074. Parameters
  2075. ----------
  2076. x : array_like
  2077. Input signal. Must be at least two-dimensional.
  2078. N : int or tuple of two ints, optional
  2079. Number of output samples. `x` is initially cropped or zero-padded to length
  2080. `N` along `axes`. Default: ``x.shape[i] for i in axes``
  2081. axes : tuple of two ints, optional
  2082. Axes along which to do the transformation. Default: (-2, -1).
  2083. .. versionchanged:: 1.17
  2084. Added `axes` parameter
  2085. Returns
  2086. -------
  2087. xa : ndarray
  2088. Analytic signal of `x` taken along given axes.
  2089. Notes
  2090. -----
  2091. The "single-orthant" transform, as defined in [2]_, is calculated by performing the
  2092. following steps:
  2093. 1. Calculate the two-dimensional FFT of the input, i.e.,
  2094. .. math::
  2095. X[p,q] = \sum_{k,l=0}^{N_0,N_1} x[k,l]\,
  2096. e^{-2j\pi k p/N_0}\, e^{-2j\pi l q/N_1}
  2097. 2. Zero negative frequency bins and double their positive counterparts, i.e.,
  2098. .. math::
  2099. X_a[p,q] = \big(1 + s_{N_0}(p)\big) \big(1 + s_{N_1}(q)\big) X[p,q]
  2100. with :math:`s_N(.)` being a modified sign function defined as
  2101. .. math::
  2102. s_N(p) := \begin{cases}
  2103. -1 & \text{ for } p < 0\ ,\\
  2104. \phantom{-}0 & \text{ for } p = 0\ ,\\
  2105. +1 & \text{ for } 1 \leq p < (N+1) // 2\ ,\\
  2106. \phantom{-}0 & \text{ elsewhere.}
  2107. \end{cases}
  2108. The limitation of the ":math:`+1`" case to the range of ``[1:(N+1)//2]``
  2109. accounts for the unpaired Nyquist frequency bin at :math:`N/2` for even
  2110. :math:`N`. Note that :math:`X_a[p] = \big(1 + s_N(p)\big) X[p]` is the
  2111. one-dimensional Hilbert function (as in `~scipy.signal.hilbert`) in Fourier
  2112. space.
  2113. 3. Produce the analytic signal by performing the inverse FFT, i.e.,
  2114. .. math::
  2115. x_a[k, l] = \frac{1}{N_0 N_1}
  2116. \sum_{p,q=0}^{N_0,N_1} X_a[p,q]\, e^{2j\pi k p/N_0}\, e^{2j\pi l q/N_1}
  2117. The "single-orthant" transform is not the only possible definition of an analytic
  2118. signal in multiple dimensions (as noted in [1]_). Consult [3]_ for a description of
  2119. properties that this 2-D transform does and does not share with the 1-D transform.
  2120. The second example below shows one of the downsides of this approach.
  2121. References
  2122. ----------
  2123. .. [1] Wikipedia, "Analytic signal",
  2124. https://en.wikipedia.org/wiki/Analytic_signal
  2125. .. [2] Hahn, Stefan L. "Multidimensional complex signals with
  2126. single-orthant spectra." Proceedings of the IEEE 80.8
  2127. (1992): 1287-1300.
  2128. `PDF <https://ieeexplore.ieee.org/iel1/5/4083/00158601.pdf>`__
  2129. .. [3] Bülow, Thomas, and Gerald Sommer. "A novel approach to the 2D analytic
  2130. signal." In International Conference on Computer Analysis of Images and
  2131. Patterns, pp. 25-32. Berlin, Heidelberg: Springer Berlin Heidelberg, 1999.
  2132. `PDF <https://www.informatik.uni-kiel.de/inf/Sommer/doc/Publications/tbl/caip99.pdf>`__
  2133. Examples
  2134. --------
  2135. The following example calculates the two-dimensional analytic signal from a single
  2136. impulse with an added constant offset. The impulse produces an FFT where each bin
  2137. has a value of one and the constant offset component produces only a non-zero
  2138. component at the ``(0,0)`` bin.
  2139. >>> import numpy as np
  2140. >>> from scipy.fft import fft2, fftshift, ifftshift
  2141. >>> from scipy.signal import hilbert2
  2142. ...
  2143. >>> # Input signal is unit impulse with a constant offset:
  2144. >>> x = np.ones((5, 5)) / 5
  2145. >>> x[0, 0] += 1
  2146. ...
  2147. >>> X = fftshift(fft2(x)) # Zero frequency bin is at center
  2148. >>> print(X)
  2149. [[1.-0.j 1.-0.j 1.-0.j 1.+0.j 1.+0.j]
  2150. [1.-0.j 1.-0.j 1.-0.j 1.+0.j 1.+0.j]
  2151. [1.-0.j 1.-0.j 6.-0.j 1.+0.j 1.+0.j]
  2152. [1.-0.j 1.-0.j 1.+0.j 1.+0.j 1.+0.j]
  2153. [1.-0.j 1.-0.j 1.+0.j 1.+0.j 1.+0.j]]
  2154. >>> x_a = hilbert2(x)
  2155. >>> X_a = fftshift(fft2(x_a))
  2156. >>> print(np.round(X_a, 3))
  2157. [[ 0.+0.j 0.+0.j -0.+0.j 0.+0.j 0.+0.j]
  2158. [ 0.+0.j 0.+0.j -0.+0.j 0.+0.j 0.+0.j]
  2159. [ 0.+0.j 0.+0.j 6.+0.j 2.+0.j 2.+0.j]
  2160. [ 0.+0.j 0.+0.j 2.+0.j 4.+0.j 4.+0.j]
  2161. [ 0.+0.j 0.+0.j 2.+0.j 4.+0.j 4.+0.j]]
  2162. The FFT of the result illustrates that the values of the lower right quadrant (or
  2163. orthant) with purely positive frequency bins have been quadrupled. The values at its
  2164. borders, where only one frequency component is zero, are doubled. The zero frequency
  2165. bin ``(0, 0)`` has not been altered. All other quadrants have been set to zero.
  2166. This second example illustrates a problem with the "single-orthant" convention. A
  2167. purely real signal can produce an analytic signal which is completely zero:
  2168. >>> from scipy.fft import fft2, fftshift, ifft2, ifftshift
  2169. >>> from scipy.signal import hilbert2
  2170. ...
  2171. >>> # Create a real signal by ensuring `Z[-p,-q] == np.conj(Z[p,q])` holds:
  2172. >>> Z = np.array([[0, 0, 0, 0, 0],
  2173. ... [0, 0, 0, 1, 0],
  2174. ... [0, 0, 0, 0, 0],
  2175. ... [0, 1, 0, 0, 0],
  2176. ... [0, 0, 0, 0, 0]]) * 25
  2177. >>> z = ifft2(ifftshift(Z))
  2178. >>> np.allclose(z.imag, 0) # z is a real signal
  2179. True
  2180. >>> np.sum(z.real**2) # z.real is non-zero
  2181. np.float64(50.0)
  2182. >>> z_a = hilbert2(z.real)
  2183. >>> np.allclose(z_a, 0) # analytic signal is zero
  2184. True
  2185. """
  2186. xp = array_namespace(x)
  2187. x = xpx.atleast_nd(xp.asarray(x), ndim=2, xp=xp)
  2188. if xp.isdtype(x.dtype, 'complex floating'):
  2189. raise ValueError("x must be real.")
  2190. if len(axes) != 2:
  2191. raise ValueError("axes must be a tuple of length 2")
  2192. if axes[0] == axes[1]:
  2193. raise ValueError("axes must contain 2 distinct axes")
  2194. if N is None:
  2195. N = (x.shape[axes[0]], x.shape[axes[1]])
  2196. elif isinstance(N, int):
  2197. if N <= 0:
  2198. raise ValueError("N must be positive.")
  2199. N = (N, N)
  2200. elif len(N) != 2 or np.any(np.asarray(N) <= 0):
  2201. raise ValueError("When given as a tuple, N must hold exactly "
  2202. "two positive integers")
  2203. Xf = sp_fft.fft2(x, N, axes=axes)
  2204. Xf = xp.moveaxis(Xf, axes, (-2, -1))
  2205. k0, k1 = (N[0] + 1) // 2, (N[1] + 1) // 2
  2206. if k0 > 1: # condition k0 > 1 needed for Dask backend
  2207. Xf[..., 1:k0, :] *= 2.0
  2208. if k1 > 1: # condition k1 > 1 needed for Dask backend
  2209. Xf[..., :, 1:k1] *= 2.0
  2210. Xf[..., k0:, :] = 0.0
  2211. Xf[..., :, k1:] = 0.0
  2212. Xf = xp.moveaxis(Xf, (-2, -1), axes)
  2213. x = sp_fft.ifft2(Xf, axes=axes)
  2214. return x
  2215. def envelope(z, bp_in: tuple[int | None, int | None] = (1, None), *,
  2216. n_out: int | None = None, squared: bool = False,
  2217. residual: Literal['lowpass', 'all', None] = 'lowpass',
  2218. axis: int = -1):
  2219. r"""Compute the envelope of a real- or complex-valued signal.
  2220. Parameters
  2221. ----------
  2222. z : ndarray
  2223. Real- or complex-valued input signal, which is assumed to be made up of ``n``
  2224. samples and having sampling interval ``T``. `z` may also be a multidimensional
  2225. array with the time axis being defined by `axis`.
  2226. bp_in : tuple[int | None, int | None], optional
  2227. 2-tuple defining the frequency band ``bp_in[0]:bp_in[1]`` of the input filter.
  2228. The corner frequencies are specified as integer multiples of ``1/(n*T)`` with
  2229. ``-n//2 <= bp_in[0] < bp_in[1] <= (n+1)//2`` being the allowed frequency range.
  2230. ``None`` entries are replaced with ``-n//2`` or ``(n+1)//2`` respectively. The
  2231. default of ``(1, None)`` removes the mean value as well as the negative
  2232. frequency components.
  2233. n_out : int | None, optional
  2234. If not ``None`` the output will be resampled to `n_out` samples. The default
  2235. of ``None`` sets the output to the same length as the input `z`.
  2236. squared : bool, optional
  2237. If set, the square of the envelope is returned. The bandwidth of the squared
  2238. envelope is often smaller than the non-squared envelope bandwidth due to the
  2239. nonlinear nature of the utilized absolute value function. I.e., the embedded
  2240. square root function typically produces addiational harmonics.
  2241. The default is ``False``.
  2242. residual : Literal['lowpass', 'all', None], optional
  2243. This option determines what kind of residual, i.e., the signal part which the
  2244. input bandpass filter removes, is returned. ``'all'`` returns everything except
  2245. the contents of the frequency band ``bp_in[0]:bp_in[1]``, ``'lowpass'``
  2246. returns the contents of the frequency band ``< bp_in[0]``. If ``None`` then
  2247. only the envelope is returned. Default: ``'lowpass'``.
  2248. axis : int, optional
  2249. Axis of `z` over which to compute the envelope. Default is last the axis.
  2250. Returns
  2251. -------
  2252. ndarray
  2253. If parameter `residual` is ``None`` then an array ``z_env`` with the same shape
  2254. as the input `z` is returned, containing its envelope. Otherwise, an array with
  2255. shape ``(2, *z.shape)``, containing the arrays ``z_env`` and ``z_res``, stacked
  2256. along the first axis, is returned.
  2257. It allows unpacking, i.e., ``z_env, z_res = envelope(z, residual='all')``.
  2258. The residual ``z_res`` contains the signal part which the input bandpass filter
  2259. removed, depending on the parameter `residual`. Note that for real-valued
  2260. signals, a real-valued residual is returned. Hence, the negative frequency
  2261. components of `bp_in` are ignored.
  2262. Notes
  2263. -----
  2264. Any complex-valued signal :math:`z(t)` can be described by a real-valued
  2265. instantaneous amplitude :math:`a(t)` and a real-valued instantaneous phase
  2266. :math:`\phi(t)`, i.e., :math:`z(t) = a(t) \exp\!\big(j \phi(t)\big)`. The
  2267. envelope is defined as the absolute value of the amplitude :math:`|a(t)| = |z(t)|`,
  2268. which is at the same time the absolute value of the signal. Hence, :math:`|a(t)|`
  2269. "envelopes" the class of all signals with amplitude :math:`a(t)` and arbitrary
  2270. phase :math:`\phi(t)`.
  2271. For real-valued signals, :math:`x(t) = a(t) \cos\!\big(\phi(t)\big)` is the
  2272. analogous formulation. Hence, :math:`|a(t)|` can be determined by converting
  2273. :math:`x(t)` into an analytic signal :math:`z_a(t)` by means of a Hilbert
  2274. transform, i.e.,
  2275. :math:`z_a(t) = a(t) \cos\!\big(\phi(t)\big) + j a(t) \sin\!\big(\phi(t) \big)`,
  2276. which produces a complex-valued signal with the same envelope :math:`|a(t)|`.
  2277. The implementation is based on computing the FFT of the input signal and then
  2278. performing the necessary operations in Fourier space. Hence, the typical FFT
  2279. caveats need to be taken into account:
  2280. * The signal is assumed to be periodic. Discontinuities between signal start and
  2281. end can lead to unwanted results due to Gibbs phenomenon.
  2282. * The FFT is slow if the signal length is prime or very long. Also, the memory
  2283. demands are typically higher than a comparable FIR/IIR filter based
  2284. implementation.
  2285. * The frequency spacing ``1 / (n*T)`` for corner frequencies of the bandpass filter
  2286. corresponds to the frequencies produced by ``scipy.fft.fftfreq(len(z), T)``.
  2287. If the envelope of a complex-valued signal `z` with no bandpass filtering is
  2288. desired, i.e., ``bp_in=(None, None)``, then the envelope corresponds to the
  2289. absolute value. Hence, it is more efficient to use ``np.abs(z)`` instead of this
  2290. function.
  2291. Although computing the envelope based on the analytic signal [1]_ is the natural
  2292. method for real-valued signals, other methods are also frequently used. The most
  2293. popular alternative is probably the so-called "square-law" envelope detector and
  2294. its relatives [2]_. They do not always compute the correct result for all kinds of
  2295. signals, but are usually correct and typically computationally more efficient for
  2296. most kinds of narrowband signals. The definition for an envelope presented here is
  2297. common where instantaneous amplitude and phase are of interest (e.g., as described
  2298. in [3]_). There exist also other concepts, which rely on the general mathematical
  2299. idea of an envelope [4]_: A pragmatic approach is to determine all upper and lower
  2300. signal peaks and use a spline interpolation to determine the curves [5]_.
  2301. References
  2302. ----------
  2303. .. [1] "Analytic Signal", Wikipedia,
  2304. https://en.wikipedia.org/wiki/Analytic_signal
  2305. .. [2] Lyons, Richard, "Digital envelope detection: The good, the bad, and the
  2306. ugly", IEEE Signal Processing Magazine 34.4 (2017): 183-187.
  2307. `PDF <https://community.infineon.com/gfawx74859/attachments/gfawx74859/psoc135/46469/1/R.%20Lyons_envelope_detection_v3.pdf>`__
  2308. .. [3] T.G. Kincaid, "The complex representation of signals.",
  2309. TIS R67# MH5, General Electric Co. (1966).
  2310. `PDF <https://apps.dtic.mil/sti/tr/pdf/ADA953296.pdf>`__
  2311. .. [4] "Envelope (mathematics)", Wikipedia,
  2312. https://en.wikipedia.org/wiki/Envelope_(mathematics)
  2313. .. [5] Yang, Yanli. "A signal theoretic approach for envelope analysis of
  2314. real-valued signals." IEEE Access 5 (2017): 5623-5630.
  2315. `PDF <https://ieeexplore.ieee.org/iel7/6287639/6514899/07891054.pdf>`__
  2316. See Also
  2317. --------
  2318. hilbert: Compute analytic signal by means of Hilbert transform.
  2319. Examples
  2320. --------
  2321. The following plot illustrates the envelope of a signal with variable frequency and
  2322. a low-frequency drift. To separate the drift from the envelope, a 4 Hz highpass
  2323. filter is used. The low-pass residuum of the input bandpass filter is utilized to
  2324. determine an asymmetric upper and lower bound to enclose the signal. Due to the
  2325. smoothness of the resulting envelope, it is down-sampled from 500 to 40 samples.
  2326. Note that the instantaneous amplitude ``a_x`` and the computed envelope ``x_env``
  2327. are not perfectly identical. This is due to the signal not being perfectly periodic
  2328. as well as the existence of some spectral overlapping of ``x_carrier`` and
  2329. ``x_drift``. Hence, they cannot be completely separated by a bandpass filter.
  2330. >>> import matplotlib.pyplot as plt
  2331. >>> import numpy as np
  2332. >>> from scipy.signal.windows import gaussian
  2333. >>> from scipy.signal import envelope
  2334. ...
  2335. >>> n, n_out = 500, 40 # number of signal samples and envelope samples
  2336. >>> T = 2 / n # sampling interval for 2 s duration
  2337. >>> t = np.arange(n) * T # time stamps
  2338. >>> a_x = gaussian(len(t), 0.4/T) # instantaneous amplitude
  2339. >>> phi_x = 30*np.pi*t + 35*np.cos(2*np.pi*0.25*t) # instantaneous phase
  2340. >>> x_carrier = a_x * np.cos(phi_x)
  2341. >>> x_drift = 0.3 * gaussian(len(t), 0.4/T) # drift
  2342. >>> x = x_carrier + x_drift
  2343. ...
  2344. >>> bp_in = (int(4 * (n*T)), None) # 4 Hz highpass input filter
  2345. >>> x_env, x_res = envelope(x, bp_in, n_out=n_out)
  2346. >>> t_out = np.arange(n_out) * (n / n_out) * T
  2347. ...
  2348. >>> fg0, ax0 = plt.subplots(1, 1, tight_layout=True)
  2349. >>> ax0.set_title(r"$4\,$Hz Highpass Envelope of Drifting Signal")
  2350. >>> ax0.set(xlabel="Time in seconds", xlim=(0, n*T), ylabel="Amplitude")
  2351. >>> ax0.plot(t, x, 'C0-', alpha=0.5, label="Signal")
  2352. >>> ax0.plot(t, x_drift, 'C2--', alpha=0.25, label="Drift")
  2353. >>> ax0.plot(t_out, x_res+x_env, 'C1.-', alpha=0.5, label="Envelope")
  2354. >>> ax0.plot(t_out, x_res-x_env, 'C1.-', alpha=0.5, label=None)
  2355. >>> ax0.grid(True)
  2356. >>> ax0.legend()
  2357. >>> plt.show()
  2358. The second example provides a geometric envelope interpretation of complex-valued
  2359. signals: The following two plots show the complex-valued signal as a blue
  2360. 3d-trajectory and the envelope as an orange round tube with varying diameter, i.e.,
  2361. as :math:`|a(t)| \exp(j\rho(t))`, with :math:`\rho(t)\in[-\pi,\pi]`. Also, the
  2362. projection into the 2d real and imaginary coordinate planes of trajectory and tube
  2363. is depicted. Every point of the complex-valued signal touches the tube's surface.
  2364. The left plot shows an analytic signal, i.e, the phase difference between
  2365. imaginary and real part is always 90 degrees, resulting in a spiraling trajectory.
  2366. It can be seen that in this case the real part has also the expected envelope,
  2367. i.e., representing the absolute value of the instantaneous amplitude.
  2368. The right plot shows the real part of that analytic signal being interpreted
  2369. as a complex-vauled signal, i.e., having zero imaginary part. There the resulting
  2370. envelope is not as smooth as in the analytic case and the instantaneous amplitude
  2371. in the real plane is not recovered. If ``z_re`` had been passed as a real-valued
  2372. signal, i.e., as ``z_re = z.real`` instead of ``z_re = z.real + 0j``, the result
  2373. would have been identical to the left plot. The reason for this is that real-valued
  2374. signals are interpreted as being the real part of a complex-valued analytic signal.
  2375. >>> import matplotlib.pyplot as plt
  2376. >>> import numpy as np
  2377. >>> from scipy.signal.windows import gaussian
  2378. >>> from scipy.signal import envelope
  2379. ...
  2380. >>> n, T = 1000, 1/1000 # number of samples and sampling interval
  2381. >>> t = np.arange(n) * T # time stamps for 1 s duration
  2382. >>> f_c = 3 # Carrier frequency for signal
  2383. >>> z = gaussian(len(t), 0.3/T) * np.exp(2j*np.pi*f_c*t) # analytic signal
  2384. >>> z_re = z.real + 0j # complex signal with zero imaginary part
  2385. ...
  2386. >>> e_a, e_r = (envelope(z_, (None, None), residual=None) for z_ in (z, z_re))
  2387. ...
  2388. >>> # Generate grids to visualize envelopes as 2d and 3d surfaces:
  2389. >>> E2d_t, E2_amp = np.meshgrid(t, [-1, 1])
  2390. >>> E2d_1 = np.ones_like(E2_amp)
  2391. >>> E3d_t, E3d_phi = np.meshgrid(t, np.linspace(-np.pi, np.pi, 300))
  2392. >>> ma = 1.8 # maximum axis values in real and imaginary direction
  2393. ...
  2394. >>> fg0 = plt.figure(figsize=(6.2, 4.))
  2395. >>> ax00 = fg0.add_subplot(1, 2, 1, projection='3d')
  2396. >>> ax01 = fg0.add_subplot(1, 2, 2, projection='3d', sharex=ax00,
  2397. ... sharey=ax00, sharez=ax00)
  2398. >>> ax00.set_title("Analytic Signal")
  2399. >>> ax00.set(xlim=(0, 1), ylim=(-ma, ma), zlim=(-ma, ma))
  2400. >>> ax01.set_title("Real-valued Signal")
  2401. >>> for z_, e_, ax_ in zip((z, z.real), (e_a, e_r), (ax00, ax01)):
  2402. ... ax_.set(xlabel="Time $t$", ylabel="Real Amp. $x(t)$",
  2403. ... zlabel="Imag. Amp. $y(t)$")
  2404. ... ax_.plot(t, z_.real, 'C0-', zs=-ma, zdir='z', alpha=0.5, label="Real")
  2405. ... ax_.plot_surface(E2d_t, e_*E2_amp, -ma*E2d_1, color='C1', alpha=0.25)
  2406. ... ax_.plot(t, z_.imag, 'C0-', zs=+ma, zdir='y', alpha=0.5, label="Imag.")
  2407. ... ax_.plot_surface(E2d_t, ma*E2d_1, e_*E2_amp, color='C1', alpha=0.25)
  2408. ... ax_.plot(t, z_.real, z_.imag, 'C0-', label="Signal")
  2409. ... ax_.plot_surface(E3d_t, e_*np.cos(E3d_phi), e_*np.sin(E3d_phi),
  2410. ... color='C1', alpha=0.5, shade=True, label="Envelope")
  2411. ... ax_.view_init(elev=22.7, azim=-114.3)
  2412. >>> fg0.subplots_adjust(left=0.08, right=0.97, wspace=0.15)
  2413. >>> plt.show()
  2414. """
  2415. xp = array_namespace(z)
  2416. if not (-z.ndim <= axis < z.ndim):
  2417. raise ValueError(f"Invalid parameter {axis=} for {z.shape=}!")
  2418. if not (z.shape[axis] > 0):
  2419. raise ValueError(f"z.shape[axis] not > 0 for {z.shape=}, {axis=}!")
  2420. if len(bp_in) != 2 or not all((isinstance(b_, int) or b_ is None) for b_ in bp_in):
  2421. raise ValueError(f"{bp_in=} isn't a 2-tuple of type (int | None, int | None)!")
  2422. if not ((isinstance(n_out, int) and 0 < n_out) or n_out is None):
  2423. raise ValueError(f"{n_out=} is not a positive integer or None!")
  2424. if residual not in ('lowpass', 'all', None):
  2425. raise ValueError(f"{residual=} not in ['lowpass', 'all', None]!")
  2426. n = z.shape[axis] # number of time samples of input
  2427. n_out = n if n_out is None else n_out
  2428. fak = n_out / n # scaling factor for resampling
  2429. bp = slice(bp_in[0] if bp_in[0] is not None else -(n//2),
  2430. bp_in[1] if bp_in[1] is not None else (n+1)//2)
  2431. if not (-n//2 <= bp.start < bp.stop <= (n+1)//2):
  2432. raise ValueError("`-n//2 <= bp_in[0] < bp_in[1] <= (n+1)//2` does not hold " +
  2433. f"for n={z.shape[axis]=} and {bp_in=}!")
  2434. # moving active axis to end allows to use `...` for indexing:
  2435. z = xp.moveaxis(z, axis, -1)
  2436. if xp.isdtype(z.dtype, 'complex floating'):
  2437. Z = sp_fft.fft(z)
  2438. else: # avoid calculating negative frequency bins for real signals:
  2439. dt = sp_fft.rfft(z[..., :1]).dtype
  2440. Z = xp.zeros_like(z, dtype=dt)
  2441. Z[..., :n//2 + 1] = sp_fft.rfft(z)
  2442. if bp.start > 0: # make signal analytic within bp_in band:
  2443. Z[..., bp] *= 2
  2444. elif bp.stop > 0:
  2445. Z[..., 1:bp.stop] *= 2
  2446. if not (bp.start <= 0 < bp.stop): # envelope is invariant to freq. shifts.
  2447. z_bb = sp_fft.ifft(Z[..., bp], n=n_out) * fak # baseband signal
  2448. else:
  2449. bp_shift = slice(bp.start + n//2, bp.stop + n//2)
  2450. z_bb = sp_fft.ifft(sp_fft.fftshift(Z, axes=-1)[..., bp_shift], n=n_out) * fak
  2451. z_env = xp.abs(z_bb) if not squared else xp.real(z_bb) ** 2 + xp.imag(z_bb) ** 2
  2452. z_env = xp.moveaxis(z_env, -1, axis)
  2453. # Calculate the residual from the input bandpass filter:
  2454. if residual is None:
  2455. return z_env
  2456. if not (bp.start <= 0 < bp.stop):
  2457. Z[..., bp] = 0
  2458. else:
  2459. Z[..., :bp.stop], Z[..., bp.start:] = 0, 0
  2460. if residual == 'lowpass':
  2461. if bp.stop > 0:
  2462. Z[..., bp.stop:(n+1) // 2] = 0
  2463. else:
  2464. Z[..., bp.start:], Z[..., 0:(n + 1) // 2] = 0, 0
  2465. if xp.isdtype(z.dtype, 'complex floating'): # resample accounts for unpaired bins:
  2466. z_res = resample(Z, n_out, axis=-1, domain='freq') # ifft() with corrections
  2467. else: # account for unpaired bin at m//2 before doing irfft():
  2468. if n_out != n and (m := min(n, n_out)) % 2 == 0:
  2469. Z[..., m//2] *= 2 if n_out < n else 0.5
  2470. z_res = fak * sp_fft.irfft(Z, n=n_out)
  2471. return xp.stack((z_env, xp.moveaxis(z_res, -1, axis)), axis=0)
  2472. def _cmplx_sort(p):
  2473. """Sort roots based on magnitude.
  2474. Parameters
  2475. ----------
  2476. p : array_like
  2477. The roots to sort, as a 1-D array.
  2478. Returns
  2479. -------
  2480. p_sorted : ndarray
  2481. Sorted roots.
  2482. indx : ndarray
  2483. Array of indices needed to sort the input `p`.
  2484. Examples
  2485. --------
  2486. >>> from scipy import signal
  2487. >>> vals = [1, 4, 1+1.j, 3]
  2488. >>> p_sorted, indx = signal.cmplx_sort(vals)
  2489. >>> p_sorted
  2490. array([1.+0.j, 1.+1.j, 3.+0.j, 4.+0.j])
  2491. >>> indx
  2492. array([0, 2, 3, 1])
  2493. """
  2494. p = np.asarray(p)
  2495. indx = np.argsort(abs(p))
  2496. return np.take(p, indx, 0), indx
  2497. def unique_roots(p, tol=1e-3, rtype='min'):
  2498. """Determine unique roots and their multiplicities from a list of roots.
  2499. Parameters
  2500. ----------
  2501. p : array_like
  2502. The list of roots.
  2503. tol : float, optional
  2504. The tolerance for two roots to be considered equal in terms of
  2505. the distance between them. Default is 1e-3. Refer to Notes about
  2506. the details on roots grouping.
  2507. rtype : {'max', 'maximum', 'min', 'minimum', 'avg', 'mean'}, optional
  2508. How to determine the returned root if multiple roots are within
  2509. `tol` of each other.
  2510. - 'max', 'maximum': pick the maximum of those roots
  2511. - 'min', 'minimum': pick the minimum of those roots
  2512. - 'avg', 'mean': take the average of those roots
  2513. When finding minimum or maximum among complex roots they are compared
  2514. first by the real part and then by the imaginary part.
  2515. Returns
  2516. -------
  2517. unique : ndarray
  2518. The list of unique roots.
  2519. multiplicity : ndarray
  2520. The multiplicity of each root.
  2521. Notes
  2522. -----
  2523. If we have 3 roots ``a``, ``b`` and ``c``, such that ``a`` is close to
  2524. ``b`` and ``b`` is close to ``c`` (distance is less than `tol`), then it
  2525. doesn't necessarily mean that ``a`` is close to ``c``. It means that roots
  2526. grouping is not unique. In this function we use "greedy" grouping going
  2527. through the roots in the order they are given in the input `p`.
  2528. This utility function is not specific to roots but can be used for any
  2529. sequence of values for which uniqueness and multiplicity has to be
  2530. determined. For a more general routine, see `numpy.unique`.
  2531. Examples
  2532. --------
  2533. >>> from scipy import signal
  2534. >>> vals = [0, 1.3, 1.31, 2.8, 1.25, 2.2, 10.3]
  2535. >>> uniq, mult = signal.unique_roots(vals, tol=2e-2, rtype='avg')
  2536. Check which roots have multiplicity larger than 1:
  2537. >>> uniq[mult > 1]
  2538. array([ 1.305])
  2539. """
  2540. if rtype in ['max', 'maximum']:
  2541. reduce = np.max
  2542. elif rtype in ['min', 'minimum']:
  2543. reduce = np.min
  2544. elif rtype in ['avg', 'mean']:
  2545. reduce = np.mean
  2546. else:
  2547. raise ValueError("`rtype` must be one of "
  2548. "{'max', 'maximum', 'min', 'minimum', 'avg', 'mean'}")
  2549. p = np.asarray(p)
  2550. points = np.empty((len(p), 2))
  2551. points[:, 0] = np.real(p)
  2552. points[:, 1] = np.imag(p)
  2553. tree = cKDTree(points)
  2554. p_unique = []
  2555. p_multiplicity = []
  2556. used = np.zeros(len(p), dtype=bool)
  2557. for i in range(len(p)):
  2558. if used[i]:
  2559. continue
  2560. group = tree.query_ball_point(points[i], tol)
  2561. group = [x for x in group if not used[x]]
  2562. p_unique.append(reduce(p[group]))
  2563. p_multiplicity.append(len(group))
  2564. used[group] = True
  2565. return np.asarray(p_unique), np.asarray(p_multiplicity)
  2566. def invres(r, p, k, tol=1e-3, rtype='avg'):
  2567. """Compute b(s) and a(s) from partial fraction expansion.
  2568. If `M` is the degree of numerator `b` and `N` the degree of denominator
  2569. `a`::
  2570. b(s) b[0] s**(M) + b[1] s**(M-1) + ... + b[M]
  2571. H(s) = ------ = ------------------------------------------
  2572. a(s) a[0] s**(N) + a[1] s**(N-1) + ... + a[N]
  2573. then the partial-fraction expansion H(s) is defined as::
  2574. r[0] r[1] r[-1]
  2575. = -------- + -------- + ... + --------- + k(s)
  2576. (s-p[0]) (s-p[1]) (s-p[-1])
  2577. If there are any repeated roots (closer together than `tol`), then H(s)
  2578. has terms like::
  2579. r[i] r[i+1] r[i+n-1]
  2580. -------- + ----------- + ... + -----------
  2581. (s-p[i]) (s-p[i])**2 (s-p[i])**n
  2582. This function is used for polynomials in positive powers of s or z,
  2583. such as analog filters or digital filters in controls engineering. For
  2584. negative powers of z (typical for digital filters in DSP), use `invresz`.
  2585. Parameters
  2586. ----------
  2587. r : array_like
  2588. Residues corresponding to the poles. For repeated poles, the residues
  2589. must be ordered to correspond to ascending by power fractions.
  2590. p : array_like
  2591. Poles. Equal poles must be adjacent.
  2592. k : array_like
  2593. Coefficients of the direct polynomial term.
  2594. tol : float, optional
  2595. The tolerance for two roots to be considered equal in terms of
  2596. the distance between them. Default is 1e-3. See `unique_roots`
  2597. for further details.
  2598. rtype : {'avg', 'min', 'max'}, optional
  2599. Method for computing a root to represent a group of identical roots.
  2600. Default is 'avg'. See `unique_roots` for further details.
  2601. Returns
  2602. -------
  2603. b : ndarray
  2604. Numerator polynomial coefficients.
  2605. a : ndarray
  2606. Denominator polynomial coefficients.
  2607. See Also
  2608. --------
  2609. residue, invresz, unique_roots
  2610. """
  2611. r = np.atleast_1d(r)
  2612. p = np.atleast_1d(p)
  2613. k = np.trim_zeros(np.atleast_1d(k), 'f')
  2614. unique_poles, multiplicity = _group_poles(p, tol, rtype)
  2615. factors, denominator = _compute_factors(unique_poles, multiplicity,
  2616. include_powers=True)
  2617. if len(k) == 0:
  2618. numerator = 0
  2619. else:
  2620. numerator = np.polymul(k, denominator)
  2621. for residue, factor in zip(r, factors):
  2622. numerator = np.polyadd(numerator, residue * factor)
  2623. return numerator, denominator
def _compute_factors(roots, multiplicity, include_powers=False):
    """Compute the total polynomial divided by factors for each root.

    Let ``T(x) = prod_j (x - roots[j])**multiplicity[j]``. For root ``j``
    with multiplicity ``m`` the returned factors are, in this order::

        T(x) / (x - roots[j])**1, ..., T(x) / (x - roots[j])**m

    when `include_powers` is True, and only ``T(x) / (x - roots[j])**m``
    otherwise. All polynomials are coefficient arrays in descending powers
    (the `numpy.polymul` convention; ``x - root`` is ``[1, -root]``).

    Parameters
    ----------
    roots : sequence
        Distinct roots. Assumed to be a sliceable 1-D sequence (it is
        indexed with ``[-1:0:-1]``).
    multiplicity : sequence of int
        Multiplicity of each root, same length and order as `roots`.
    include_powers : bool, optional
        If True, emit one factor per power ``1..m`` of each root; otherwise
        only the factor for the highest power ``m``.

    Returns
    -------
    factors : list of ndarray
        Quotient polynomials, grouped by root in the order of `roots`.
    total : ndarray
        The full product polynomial ``T(x)``.
    """
    current = np.array([1])
    suffixes = [current]
    # Walk from the last root down to roots[1], accumulating products of the
    # trailing roots. After reversal, suffixes[j] is
    # prod_{l > j} (x - roots[l])**multiplicity[l] (and suffixes[-1] == [1]).
    for pole, mult in zip(roots[-1:0:-1], multiplicity[-1:0:-1]):
        monomial = np.array([1, -pole])
        for _ in range(mult):
            current = np.polymul(current, monomial)
        suffixes.append(current)
    suffixes = suffixes[::-1]

    factors = []
    # `current` is the prefix product prod_{l < j} (x - roots[l])**mult_l
    # at the start of iteration j; after the loop it equals T(x).
    current = np.array([1])
    for pole, mult, suffix in zip(roots, multiplicity, suffixes):
        monomial = np.array([1, -pole])
        block = []
        for i in range(mult):
            # prefix * (x - pole)**i * suffix == T / (x - pole)**(mult - i)
            if i == 0 or include_powers:
                block.append(np.polymul(current, suffix))
            current = np.polymul(current, monomial)
        # Reverse so the factor for the lowest power (T / (x - pole)) is first.
        factors.extend(reversed(block))

    return factors, current
def _compute_residues(poles, multiplicity, numerator):
    """Residues of ``numerator(x) / prod_j (x - poles[j])**multiplicity[j]``.

    The denominator used here is the monic product of the poles; callers
    rescale as needed (`residue` divides the result by ``a[0]``). Presumably
    the fraction is proper (degree of `numerator` below ``sum(multiplicity)``);
    `residue` passes the `numpy.polydiv` remainder when the degrees would
    otherwise not satisfy this.

    For repeated poles the "deflation through subtraction" scheme is used:
    the coefficient of the highest power is found first, its contribution
    is subtracted, and the remainder is divided by ``(x - pole)`` again.

    Parameters
    ----------
    poles : ndarray
        Distinct poles (its ``dtype`` is used for the computation).
    multiplicity : sequence of int
        Multiplicity of each pole.
    numerator : ndarray
        Numerator coefficients in descending powers.

    Returns
    -------
    residues : ndarray
        Residues grouped by pole; for a repeated pole they are ordered by
        ascending power of ``(x - pole)``.
    """
    # Only T / (x - p)**m per pole is needed (include_powers=False).
    denominator_factors, _ = _compute_factors(poles, multiplicity)
    # Do the polynomial arithmetic below in the poles' dtype.
    numerator = numerator.astype(poles.dtype)
    residues = []
    for pole, mult, factor in zip(poles, multiplicity,
                                  denominator_factors):
        if mult == 1:
            # Simple pole: r = N(p) / Q(p), with Q the other poles' product.
            residues.append(np.polyval(numerator, pole) /
                            np.polyval(factor, pole))
        else:
            numer = numerator.copy()
            monomial = np.array([1, -pole])
            # The remainder `d` of Q / (x - p) is Q(p); `factor` becomes the quotient.
            factor, d = np.polydiv(factor, monomial)

            block = []
            for _ in range(mult):
                # The remainder `n` of numer / (x - p) is numer(p).
                numer, n = np.polydiv(numer, monomial)
                r = n[0] / d[0]
                # Subtract r * Q and deflate by one power of (x - p).
                numer = np.polysub(numer, r * factor)
                block.append(r)

            # Coefficients were found from the highest power down; reverse.
            residues.extend(reversed(block))

    return np.asarray(residues)
  2666. def residue(b, a, tol=1e-3, rtype='avg'):
  2667. """Compute partial-fraction expansion of b(s) / a(s).
  2668. If `M` is the degree of numerator `b` and `N` the degree of denominator
  2669. `a`::
  2670. b(s) b[0] s**(M) + b[1] s**(M-1) + ... + b[M]
  2671. H(s) = ------ = ------------------------------------------
  2672. a(s) a[0] s**(N) + a[1] s**(N-1) + ... + a[N]
  2673. then the partial-fraction expansion H(s) is defined as::
  2674. r[0] r[1] r[-1]
  2675. = -------- + -------- + ... + --------- + k(s)
  2676. (s-p[0]) (s-p[1]) (s-p[-1])
  2677. If there are any repeated roots (closer together than `tol`), then H(s)
  2678. has terms like::
  2679. r[i] r[i+1] r[i+n-1]
  2680. -------- + ----------- + ... + -----------
  2681. (s-p[i]) (s-p[i])**2 (s-p[i])**n
  2682. This function is used for polynomials in positive powers of s or z,
  2683. such as analog filters or digital filters in controls engineering. For
  2684. negative powers of z (typical for digital filters in DSP), use `residuez`.
  2685. See Notes for details about the algorithm.
  2686. Parameters
  2687. ----------
  2688. b : array_like
  2689. Numerator polynomial coefficients.
  2690. a : array_like
  2691. Denominator polynomial coefficients.
  2692. tol : float, optional
  2693. The tolerance for two roots to be considered equal in terms of
  2694. the distance between them. Default is 1e-3. See `unique_roots`
  2695. for further details.
  2696. rtype : {'avg', 'min', 'max'}, optional
  2697. Method for computing a root to represent a group of identical roots.
  2698. Default is 'avg'. See `unique_roots` for further details.
  2699. Returns
  2700. -------
  2701. r : ndarray
  2702. Residues corresponding to the poles. For repeated poles, the residues
  2703. are ordered to correspond to ascending by power fractions.
  2704. p : ndarray
  2705. Poles ordered by magnitude in ascending order.
  2706. k : ndarray
  2707. Coefficients of the direct polynomial term.
  2708. See Also
  2709. --------
  2710. invres, residuez, numpy.poly, unique_roots
  2711. Notes
  2712. -----
  2713. The "deflation through subtraction" algorithm is used for
  2714. computations --- method 6 in [1]_.
  2715. The form of partial fraction expansion depends on poles multiplicity in
  2716. the exact mathematical sense. However there is no way to exactly
  2717. determine multiplicity of roots of a polynomial in numerical computing.
  2718. Thus you should think of the result of `residue` with given `tol` as
  2719. partial fraction expansion computed for the denominator composed of the
  2720. computed poles with empirically determined multiplicity. The choice of
  2721. `tol` can drastically change the result if there are close poles.
  2722. References
  2723. ----------
  2724. .. [1] J. F. Mahoney, B. D. Sivazlian, "Partial fractions expansion: a
  2725. review of computational methodology and efficiency", Journal of
  2726. Computational and Applied Mathematics, Vol. 9, 1983.
  2727. """
  2728. b = np.asarray(b)
  2729. a = np.asarray(a)
  2730. if (np.issubdtype(b.dtype, np.complexfloating)
  2731. or np.issubdtype(a.dtype, np.complexfloating)):
  2732. b = b.astype(complex)
  2733. a = a.astype(complex)
  2734. else:
  2735. b = b.astype(float)
  2736. a = a.astype(float)
  2737. b = np.trim_zeros(np.atleast_1d(b), 'f')
  2738. a = np.trim_zeros(np.atleast_1d(a), 'f')
  2739. if a.size == 0:
  2740. raise ValueError("Denominator `a` is zero.")
  2741. poles = np.roots(a)
  2742. if b.size == 0:
  2743. return np.zeros(poles.shape), _cmplx_sort(poles)[0], np.array([])
  2744. if len(b) < len(a):
  2745. k = np.empty(0)
  2746. else:
  2747. k, b = np.polydiv(b, a)
  2748. unique_poles, multiplicity = unique_roots(poles, tol=tol, rtype=rtype)
  2749. unique_poles, order = _cmplx_sort(unique_poles)
  2750. multiplicity = multiplicity[order]
  2751. residues = _compute_residues(unique_poles, multiplicity, b)
  2752. index = 0
  2753. for pole, mult in zip(unique_poles, multiplicity):
  2754. poles[index:index + mult] = pole
  2755. index += mult
  2756. return residues / a[0], poles, k
  2757. def residuez(b, a, tol=1e-3, rtype='avg'):
  2758. """Compute partial-fraction expansion of b(z) / a(z).
  2759. If `M` is the degree of numerator `b` and `N` the degree of denominator
  2760. `a`::
  2761. b(z) b[0] + b[1] z**(-1) + ... + b[M] z**(-M)
  2762. H(z) = ------ = ------------------------------------------
  2763. a(z) a[0] + a[1] z**(-1) + ... + a[N] z**(-N)
  2764. then the partial-fraction expansion H(z) is defined as::
  2765. r[0] r[-1]
  2766. = --------------- + ... + ---------------- + k[0] + k[1]z**(-1) ...
  2767. (1-p[0]z**(-1)) (1-p[-1]z**(-1))
  2768. If there are any repeated roots (closer than `tol`), then the partial
  2769. fraction expansion has terms like::
  2770. r[i] r[i+1] r[i+n-1]
  2771. -------------- + ------------------ + ... + ------------------
  2772. (1-p[i]z**(-1)) (1-p[i]z**(-1))**2 (1-p[i]z**(-1))**n
  2773. This function is used for polynomials in negative powers of z,
  2774. such as digital filters in DSP. For positive powers, use `residue`.
  2775. See Notes of `residue` for details about the algorithm.
  2776. Parameters
  2777. ----------
  2778. b : array_like
  2779. Numerator polynomial coefficients.
  2780. a : array_like
  2781. Denominator polynomial coefficients.
  2782. tol : float, optional
  2783. The tolerance for two roots to be considered equal in terms of
  2784. the distance between them. Default is 1e-3. See `unique_roots`
  2785. for further details.
  2786. rtype : {'avg', 'min', 'max'}, optional
  2787. Method for computing a root to represent a group of identical roots.
  2788. Default is 'avg'. See `unique_roots` for further details.
  2789. Returns
  2790. -------
  2791. r : ndarray
  2792. Residues corresponding to the poles. For repeated poles, the residues
  2793. are ordered to correspond to ascending by power fractions.
  2794. p : ndarray
  2795. Poles ordered by magnitude in ascending order.
  2796. k : ndarray
  2797. Coefficients of the direct polynomial term.
  2798. See Also
  2799. --------
  2800. invresz, residue, unique_roots
  2801. """
  2802. b = np.asarray(b)
  2803. a = np.asarray(a)
  2804. if (np.issubdtype(b.dtype, np.complexfloating)
  2805. or np.issubdtype(a.dtype, np.complexfloating)):
  2806. b = b.astype(complex)
  2807. a = a.astype(complex)
  2808. else:
  2809. b = b.astype(float)
  2810. a = a.astype(float)
  2811. b = np.trim_zeros(np.atleast_1d(b), 'b')
  2812. a = np.trim_zeros(np.atleast_1d(a), 'b')
  2813. if a.size == 0:
  2814. raise ValueError("Denominator `a` is zero.")
  2815. elif a[0] == 0:
  2816. raise ValueError("First coefficient of determinant `a` must be "
  2817. "non-zero.")
  2818. poles = np.roots(a)
  2819. if b.size == 0:
  2820. return np.zeros(poles.shape), _cmplx_sort(poles)[0], np.array([])
  2821. b_rev = b[::-1]
  2822. a_rev = a[::-1]
  2823. if len(b_rev) < len(a_rev):
  2824. k_rev = np.empty(0)
  2825. else:
  2826. k_rev, b_rev = np.polydiv(b_rev, a_rev)
  2827. unique_poles, multiplicity = unique_roots(poles, tol=tol, rtype=rtype)
  2828. unique_poles, order = _cmplx_sort(unique_poles)
  2829. multiplicity = multiplicity[order]
  2830. residues = _compute_residues(1 / unique_poles, multiplicity, b_rev)
  2831. index = 0
  2832. powers = np.empty(len(residues), dtype=int)
  2833. for pole, mult in zip(unique_poles, multiplicity):
  2834. poles[index:index + mult] = pole
  2835. powers[index:index + mult] = 1 + np.arange(mult)
  2836. index += mult
  2837. residues *= (-poles) ** powers / a_rev[0]
  2838. return residues, poles, k_rev[::-1]
  2839. def _group_poles(poles, tol, rtype):
  2840. if rtype in ['max', 'maximum']:
  2841. reduce = np.max
  2842. elif rtype in ['min', 'minimum']:
  2843. reduce = np.min
  2844. elif rtype in ['avg', 'mean']:
  2845. reduce = np.mean
  2846. else:
  2847. raise ValueError("`rtype` must be one of "
  2848. "{'max', 'maximum', 'min', 'minimum', 'avg', 'mean'}")
  2849. unique = []
  2850. multiplicity = []
  2851. pole = poles[0]
  2852. block = [pole]
  2853. for i in range(1, len(poles)):
  2854. if abs(poles[i] - pole) <= tol:
  2855. block.append(pole)
  2856. else:
  2857. unique.append(reduce(block))
  2858. multiplicity.append(len(block))
  2859. pole = poles[i]
  2860. block = [pole]
  2861. unique.append(reduce(block))
  2862. multiplicity.append(len(block))
  2863. return np.asarray(unique), np.asarray(multiplicity)
  2864. def invresz(r, p, k, tol=1e-3, rtype='avg'):
  2865. """Compute b(z) and a(z) from partial fraction expansion.
  2866. If `M` is the degree of numerator `b` and `N` the degree of denominator
  2867. `a`::
  2868. b(z) b[0] + b[1] z**(-1) + ... + b[M] z**(-M)
  2869. H(z) = ------ = ------------------------------------------
  2870. a(z) a[0] + a[1] z**(-1) + ... + a[N] z**(-N)
  2871. then the partial-fraction expansion H(z) is defined as::
  2872. r[0] r[-1]
  2873. = --------------- + ... + ---------------- + k[0] + k[1]z**(-1) ...
  2874. (1-p[0]z**(-1)) (1-p[-1]z**(-1))
  2875. If there are any repeated roots (closer than `tol`), then the partial
  2876. fraction expansion has terms like::
  2877. r[i] r[i+1] r[i+n-1]
  2878. -------------- + ------------------ + ... + ------------------
  2879. (1-p[i]z**(-1)) (1-p[i]z**(-1))**2 (1-p[i]z**(-1))**n
  2880. This function is used for polynomials in negative powers of z,
  2881. such as digital filters in DSP. For positive powers, use `invres`.
  2882. Parameters
  2883. ----------
  2884. r : array_like
  2885. Residues corresponding to the poles. For repeated poles, the residues
  2886. must be ordered to correspond to ascending by power fractions.
  2887. p : array_like
  2888. Poles. Equal poles must be adjacent.
  2889. k : array_like
  2890. Coefficients of the direct polynomial term.
  2891. tol : float, optional
  2892. The tolerance for two roots to be considered equal in terms of
  2893. the distance between them. Default is 1e-3. See `unique_roots`
  2894. for further details.
  2895. rtype : {'avg', 'min', 'max'}, optional
  2896. Method for computing a root to represent a group of identical roots.
  2897. Default is 'avg'. See `unique_roots` for further details.
  2898. Returns
  2899. -------
  2900. b : ndarray
  2901. Numerator polynomial coefficients.
  2902. a : ndarray
  2903. Denominator polynomial coefficients.
  2904. See Also
  2905. --------
  2906. residuez, unique_roots, invres
  2907. """
  2908. r = np.atleast_1d(r)
  2909. p = np.atleast_1d(p)
  2910. k = np.trim_zeros(np.atleast_1d(k), 'b')
  2911. unique_poles, multiplicity = _group_poles(p, tol, rtype)
  2912. factors, denominator = _compute_factors(unique_poles, multiplicity,
  2913. include_powers=True)
  2914. if len(k) == 0:
  2915. numerator = 0
  2916. else:
  2917. numerator = np.polymul(k[::-1], denominator[::-1])
  2918. for residue, factor in zip(r, factors):
  2919. numerator = np.polyadd(numerator, residue * factor[::-1])
  2920. return numerator[::-1], denominator
  2921. def resample(x, num, t=None, axis=0, window=None, domain='time'):
  2922. r"""Resample `x` to `num` samples using the Fourier method along the given `axis`.
  2923. The resampling is performed by shortening or zero-padding the FFT of `x`. This has
  2924. the advantages of providing an ideal antialiasing filter and allowing arbitrary
  2925. up- or down-sampling ratios. The main drawback is the requirement of assuming `x`
  2926. to be a periodic signal.
  2927. Parameters
  2928. ----------
  2929. x : array_like
  2930. The input signal made up of equidistant samples. If `x` is a multidimensional
  2931. array, the parameter `axis` specifies the time/frequency axis. It is assumed
  2932. here that ``n_x = x.shape[axis]`` specifies the number of samples and ``T`` the
  2933. sampling interval.
  2934. num : int
  2935. The number of samples of the resampled output signal. It may be larger or
  2936. smaller than ``n_x``.
  2937. t : array_like, optional
  2938. If `t` is not ``None``, then the timestamps of the resampled signal are also
  2939. returned. `t` must contain at least the first two timestamps of the input
  2940. signal `x` (all others are ignored). The timestamps of the output signal are
  2941. determined by ``t[0] + T * n_x / num * np.arange(num)`` with
  2942. ``T = t[1] - t[0]``. Default is ``None``.
  2943. axis : int, optional
  2944. The time/frequency axis of `x` along which the resampling take place.
  2945. The Default is 0.
  2946. window : array_like, callable, string, float, or tuple, optional
  2947. If not ``None``, it specifies a filter in the Fourier domain, which is applied
  2948. before resampling. I.e., the FFT ``X`` of `x` is calculated by
  2949. ``X = W * fft(x, axis=axis)``. ``W`` may be interpreted as a spectral windowing
  2950. function ``W(f_X)`` which consumes the frequencies ``f_X = fftfreq(n_x, T)``.
  2951. If `window` is a 1d array of length ``n_x`` then ``W=window``.
  2952. If `window` is a callable then ``W = window(f_X)``.
  2953. Otherwise, `window` is passed to `~scipy.signal.get_window`, i.e.,
  2954. ``W = fftshift(signal.get_window(window, n_x))``. Default is ``None``.
  2955. domain : 'time' | 'freq', optional
  2956. If set to ``'time'`` (default) then an FFT is applied to `x`, otherwise
  2957. (``'freq'``) it is asssmued that an FFT was already applied, i.e.,
  2958. ``x = fft(x_t, axis=axis)`` with ``x_t`` being the input signal in the time
  2959. domain.
  2960. Returns
  2961. -------
  2962. x_r : ndarray
  2963. The resampled signal made up of `num` samples and sampling interval
  2964. ``T * n_x / num``.
  2965. t_r : ndarray, optional
  2966. The `num` equidistant timestamps of `x_r`.
  2967. This is only returned if paramater `t` is not ``None``.
  2968. See Also
  2969. --------
  2970. decimate : Downsample a (periodic/non-periodic) signal after applying an FIR
  2971. or IIR filter.
  2972. resample_poly : Resample a (periodic/non-periodic) signal using polyphase filtering
  2973. and an FIR filter.
  2974. Notes
  2975. -----
  2976. This function uses the more efficient one-sided FFT, i.e. `~scipy.fft.rfft` /
  2977. `~scipy.fft.irfft`, if `x` is real-valued and in the time domain.
  2978. Else, the two-sided FFT, i.e., `~scipy.fft.fft` / `~scipy.fft.ifft`, is used
  2979. (all FFT functions are taken from the `scipy.fft` module).
  2980. If a `window` is applied to a real-valued `x`, the one-sided spectral windowing
  2981. function is determined by taking the average of the negative and the positive
  2982. frequency component. This ensures that real-valued signals and complex signals with
  2983. zero imaginary part are treated identically. I.e., passing `x` or passing
  2984. ``x.astype(np.complex128)`` produce the same numeric result.
  2985. If the number of input or output samples are prime or have few prime factors, this
  2986. function may be slow due to utilizing FFTs. Consult `~scipy.fft.prev_fast_len` and
  2987. `~scipy.fft.next_fast_len` for determining efficient signals lengths.
  2988. Alternatively, utilizing `resample_poly` to calculate an intermediate signal (as
  2989. illustrated in the example below) can result in significant speed increases.
  2990. `resample` is intended to be used for periodic signals with equidistant sampling
  2991. intervals. For non-periodic signals, `resample_poly` may be a better choice.
  2992. Consult the `scipy.interpolate` module for methods of resampling signals with
  2993. non-constant sampling intervals.
  2994. Examples
  2995. --------
  2996. The following example depicts a signal being up-sampled from 20 samples to 100
  2997. samples. The ringing at the beginning of the up-sampled signal is due to
  2998. interpreting the signal being periodic. The red square in the plot illustrates that
  2999. periodictiy by showing the first sample of the next cycle of the signal.
  3000. >>> import numpy as np
  3001. >>> import matplotlib.pyplot as plt
  3002. >>> from scipy.signal import resample
  3003. ...
  3004. >>> n0, n1 = 20, 100 # number of samples
  3005. >>> t0 = np.linspace(0, 10, n0, endpoint=False) # input time stamps
  3006. >>> x0 = np.cos(-t0**2/6) # input signal
  3007. ...
  3008. >>> x1 = resample(x0, n1) # resampled signal
  3009. >>> t1 = np.linspace(0, 10, n1, endpoint=False) # timestamps of x1
  3010. ...
  3011. >>> fig0, ax0 = plt.subplots(1, 1, tight_layout=True)
  3012. >>> ax0.set_title(f"Resampling $x(t)$ from {n0} samples to {n1} samples")
  3013. >>> ax0.set(xlabel="Time $t$", ylabel="Amplitude $x(t)$")
  3014. >>> ax0.plot(t1, x1, '.-', alpha=.5, label=f"Resampled")
  3015. >>> ax0.plot(t0, x0, 'o-', alpha=.5, label="Original")
  3016. >>> ax0.plot(10, x0[0], 'rs', alpha=.5, label="Next Cycle")
  3017. >>> ax0.legend(loc='best')
  3018. >>> ax0.grid(True)
  3019. >>> plt.show()
  3020. The following example compares this function with a naive `~scipy.fft.rfft` /
  3021. `~scipy.fft.irfft` combination: An input signal with a sampling interval of one
  3022. second is upsampled by a factor of eight. The first figure depicts an odd number of
  3023. input samples whereas the second figure an even number. The upper subplots show the
  3024. signals over time: The input samples are marked by large green dots, the upsampled
  3025. signals by a continuous and a dashed line. The lower subplots show the magnitude
  3026. spectrum: The FFT values of the input are depicted by large green dots, which lie
  3027. in the frequency interval [-0.5, 0.5] Hz, whereas the frequency interval of the
  3028. upsampled signal is [-4, 4] Hz. The continuous green line depicts the upsampled
  3029. spectrum without antialiasing filter, which is a periodic continuation of the input
  3030. spectrum. The blue x's and orange dots depict the FFT values of the signal created
  3031. by the naive approach as well as this function's result.
  3032. >>> import matplotlib.pyplot as plt
  3033. >>> import numpy as np
  3034. >>> from scipy.fft import fftshift, fftfreq, fft, rfft, irfft
  3035. >>> from scipy.signal import resample, resample_poly
  3036. ...
  3037. >>> fac, T0, T1 = 8, 1, 1/8 # upsampling factor and sampling intervals
  3038. >>> for n0 in (15, 16): # number of samples of input signal
  3039. ... n1 = fac * n0 # number of samples of upsampled signal
  3040. ... t0, t1 = T0 * np.arange(n0), T1 * np.arange(n1) # time stamps
  3041. ... x0 = np.zeros(n0) # input signal has two non-zero sample values
  3042. ... x0[n0//2], x0[n0//2+1] = n0 // 2, -(n0 // 2)
  3043. ...
  3044. ... x1n = irfft(rfft(x0), n=n1) * n1 / n0 # naive resampling
  3045. ... x1r = resample(x0, n1) # resample signal
  3046. ...
  3047. ... # Determine magnitude spectrum:
  3048. ... x0_up = np.zeros_like(x1r) # upsampling without antialiasing filter
  3049. ... x0_up[::n1 // n0] = x0
  3050. ... X0, X0_up = (fftshift(fft(x_)) / n0 for x_ in (x0, x0_up))
  3051. ... XX1 = (fftshift(fft(x_)) / n1 for x_ in (x1n, x1r))
  3052. ... f0, f1 = fftshift(fftfreq(n0, T0)), fftshift(fftfreq(n1, T1)) # frequencies
  3053. ... df = f0[1] - f0[0] # frequency resolution
  3054. ...
  3055. ... fig, (ax0, ax1) = plt.subplots(2, 1, layout='constrained', figsize=(5, 4))
  3056. ... ax0.set_title(rf"Upsampling ${fac}\times$ from {n0} to {n1} samples")
  3057. ... ax0.set(xlabel="Time $t$ in seconds", ylabel="Amplitude $x(t)$",
  3058. ... xlim=(0, n1*T1))
  3059. ... ax0.step(t0, x0, 'C2o-', where='post', alpha=.3, linewidth=2,
  3060. ... label="$x_0(t)$ / $X_0(f)$")
  3061. ... for x_, l_ in zip((x1n, x1r), ('C0--', 'C1-')):
  3062. ... ax0.plot(t1, x_, l_, alpha=.5, label=None)
  3063. ... ax0.grid()
  3064. ... ax1.set(xlabel=rf"Frequency $f$ in hertz ($\Delta f = {df*1e3:.1f}\,$mHz)",
  3065. ... ylabel="Magnitude $|X(f)|$", xlim=(-0.7, 0.7))
  3066. ... ax1.axvspan(0.5/T0, f1[-1], color='gray', alpha=.2)
  3067. ... ax1.axvspan(f1[0], -0.5/T0, color='gray', alpha=.2)
  3068. ... ax1.plot(f1, abs(X0_up), 'C2-', f0, abs(X0), 'C2o', alpha=.3, linewidth=2)
  3069. ... for X_, n_, l_ in zip(XX1, ("naive", "resample"), ('C0x--', 'C1.-')):
  3070. ... ax1.plot(f1, abs(X_), l_, alpha=.5, label=n_)
  3071. ... ax1.grid()
  3072. ... fig.legend(loc='outside lower center', ncols=4)
  3073. >>> plt.show()
  3074. The first figure shows that upsampling an odd number of samples produces identical
  3075. results. The second figure illustrates that the signal produced with the naive
  3076. approach (dashed blue line) from an even number of samples does not touch all
  3077. original samples. This deviation is due to `resample` correctly treating unpaired
  3078. frequency bins. I.e., the input `x1` has a bin pair ±0.5 Hz, whereas the output has
  3079. only one unpaired bin at -0.5 Hz, which demands rescaling of that bin pair.
  3080. Generally, special treatment is required if ``n_x != num`` and ``min(n_x, num)`` is
  3081. even. If the bin values at `±m` are zero, obviously, no special treatment is
  3082. needed. Consult the source code of `resample` for details.
    The final example shows how to utilize `resample_poly` to speed up the
    down-sampling: The input signal has a single non-zero value at :math:`t=0` and is
    downsampled from 19937 to 128 samples. Since 19937 is prime, the FFT is expected to
    be slow. To speed matters up, `resample_poly` is used to first downsample by a factor
    of ``n0 // n1 = 155``, and the result is then passed to `resample`. Two
    parameterizations of `resample_poly` are used: Passing ``padtype='wrap'`` treats the
    input as being periodic, whereas the default parameterization performs zero-padding.
    The upper subplot shows the resulting signals over time, whereas the lower subplot
    depicts the resulting one-sided magnitude spectra.
  3092. >>> import matplotlib.pyplot as plt
  3093. >>> import numpy as np
  3094. >>> from scipy.fft import rfftfreq, rfft
  3095. >>> from scipy.signal import resample, resample_poly
  3096. ...
  3097. >>> n0 = 19937 # number of input samples - prime
  3098. >>> n1 = 128 # number of output samples - fast FFT length
  3099. >>> T0, T1 = 1/n0, 1/n1 # sampling intervals
  3100. >>> t0, t1 = np.arange(n0)*T0, np.arange(n1)*T1 # time stamps
  3101. ...
  3102. >>> x0 = np.zeros(n0) # Input has one non-zero sample
  3103. >>> x0[0] = n0
  3104. >>>
  3105. >>> x1r = resample(x0, n1) # slow due to n0 being prime
  3106. >>> # This is faster:
  3107. >>> x1p = resample(resample_poly(x0, 1, n0 // n1, padtype='wrap'), n1) # periodic
  3108. >>> x2p = resample(resample_poly(x0, 1, n0 // n1), n1) # with zero-padding
  3109. ...
  3110. >>> X0 = rfft(x0) / n0
  3111. >>> X1r, X1p, X2p = rfft(x1r) / n1, rfft(x1p) / n1, rfft(x2p) / n1
  3112. >>> f0, f1 = rfftfreq(n0, T0), rfftfreq(n1, T1)
  3113. ...
  3114. >>> fig, (ax0, ax1) = plt.subplots(2, 1, layout='constrained', figsize=(5, 4))
  3115. >>> ax0.set_title(f"Dowsampled Impulse response (from {n0} to {n1} samples)")
  3116. >>> ax0.set(xlabel="Time $t$ in seconds", ylabel="Amplitude $x(t)$", xlim=(-T1, 1))
  3117. >>> for x_ in (x1r, x1p, x2p):
  3118. ... ax0.plot(t1, x_, alpha=.5)
  3119. >>> ax0.grid()
  3120. >>> ax1.set(xlabel=rf"Frequency $f$ in hertz ($\Delta f = {f1[1]}\,$Hz)",
  3121. ... ylabel="Magnitude $|X(f)|$", xlim=(0, 0.55/T1))
  3122. >>> ax1.axvspan(0.5/T1, f0[-1], color='gray', alpha=.2)
  3123. >>> ax1.plot(f1, abs(X1r), 'C0.-', alpha=.5, label="resample")
  3124. >>> ax1.plot(f1, abs(X1p), 'C1.-', alpha=.5, label="resample_poly(padtype='wrap')")
  3125. >>> ax1.plot(f1, abs(X2p), 'C2x-', alpha=.5, label="resample_poly")
  3126. >>> ax1.grid()
  3127. >>> fig.legend(loc='outside lower center', ncols=2)
  3128. >>> plt.show()
    The plots show that the results of the "pure" `resample` and the usage of the
    default parameters of `resample_poly` agree well. The periodic padding of
    `resample_poly` (``padtype='wrap'``), on the other hand, produces significant
    deviations. This is caused by the discontinuity at the beginning of the signal,
    for which the default filter of `resample_poly` is not well suited. This example
    illustrates that for some use cases, adapting the `resample_poly` parameters may
    be beneficial. `resample` has a big advantage in this regard: It uses the ideal
    antialiasing filter with the maximum bandwidth by default.
    Note that the doubled spectral magnitude at the Nyquist frequency of 64 Hz is due
    to the even number of ``n1=128`` output samples, which requires special treatment
    as discussed in the previous example.
  3140. """
  3141. if domain not in ('time', 'freq'):
  3142. raise ValueError(f"Parameter {domain=} not in ('time', 'freq')!")
  3143. xp = array_namespace(x, t)
  3144. x = xp.asarray(x)
  3145. if x.ndim > 1: # moving active axis to end allows to use `...` in indexing:
  3146. x = xp.moveaxis(x, axis, -1)
  3147. n_x = x.shape[-1] # number of samples along the time/frequency axis
  3148. s_fac = n_x / num # scaling factor represents sample interval dilatation
  3149. m = min(num, n_x) # number of relevant frequency bins
  3150. m2 = m // 2 + 1 # number of relevant frequency bins of a one-sided FFT
  3151. if window is None: # Determine spectral windowing function:
  3152. W = None
  3153. elif callable(window):
  3154. W = window(sp_fft.fftfreq(n_x))
  3155. elif hasattr(window, 'shape'): # must be an array object
  3156. if window.shape != (n_x,):
  3157. raise ValueError(f"{window.shape=} != ({n_x},), i.e., window length " +
  3158. "is not equal to number of frequency bins!")
  3159. W = xp.asarray(window, copy=True) # prevent modifying the function parameters
  3160. else:
  3161. W = sp_fft.fftshift(get_window(window, n_x, xp=xp))
  3162. W = xp.astype(W, xp_default_dtype(xp)) # get_window always returns float64
  3163. if domain == 'time' and not xp.isdtype(x.dtype, 'complex floating'): # use rfft():
  3164. X = sp_fft.rfft(x)
  3165. if W is not None: # fold window, i.e., W1[l] = (W[l] + W[-l]) / 2 for l > 0
  3166. n_X = X.shape[-1]
  3167. W[1:n_X] += xp.flip(W[-n_X+1:]) #W[:-n_X:-1]
  3168. W[1:n_X] /= 2
  3169. X *= W[:n_X] # apply window
  3170. X = X[..., :m2] # extract relevant data
  3171. if m % 2 == 0 and num != n_x: # Account for unpaired bin at m//2:
  3172. X[..., m//2] *= 2 if num < n_x else 0.5
  3173. x_r = sp_fft.irfft(X / s_fac, n=num, overwrite_x=True)
  3174. else: # use standard two-sided FFT:
  3175. X = sp_fft.fft(x) if domain == 'time' else x
  3176. if W is not None:
  3177. X = X * W # writing X *= W could modify parameter x
  3178. Y = xp.zeros(X.shape[:-1] + (num,), dtype=X.dtype)
  3179. Y[..., :m2] = X[..., :m2] # copy part up to Nyquist frequency
  3180. if m2 < m: # == m > 2
  3181. Y[..., m2-m:] = X[..., m2-m:] # copy negative frequency part
  3182. if m % 2 == 0: # Account for unpaired bin at m//2:
  3183. if num < n_x: # down-sampling: unite bin pair into one unpaired bin
  3184. Y[..., -m//2] += X[..., -m//2]
  3185. elif n_x < num: # up-sampling: split unpaired bin into bin pair
  3186. Y[..., m//2] /= 2
  3187. Y[..., num-m//2] = Y[..., m//2]
  3188. x_r = sp_fft.ifft(Y / s_fac, n=num, overwrite_x=True)
  3189. if x_r.ndim > 1: # moving active axis back to original position:
  3190. x_r = xp.moveaxis(x_r, -1, axis)
  3191. if t is not None:
  3192. return x_r, t[0] + (t[1] - t[0]) * s_fac * xp.arange(num)
  3193. return x_r
  3194. def resample_poly(x, up, down, axis=0, window=('kaiser', 5.0),
  3195. padtype='constant', cval=None):
  3196. """
  3197. Resample `x` along the given axis using polyphase filtering.
  3198. The signal `x` is upsampled by the factor `up`, a zero-phase low-pass
  3199. FIR filter is applied, and then it is downsampled by the factor `down`.
  3200. The resulting sample rate is ``up / down`` times the original sample
  3201. rate. By default, values beyond the boundary of the signal are assumed
  3202. to be zero during the filtering step.
  3203. Parameters
  3204. ----------
  3205. x : array_like
  3206. The data to be resampled.
  3207. up : int
  3208. The upsampling factor.
  3209. down : int
  3210. The downsampling factor.
  3211. axis : int, optional
  3212. The axis of `x` that is resampled. Default is 0.
  3213. window : string, tuple, or array_like, optional
  3214. Desired window to use to design the low-pass filter, or the FIR filter
  3215. coefficients to employ. See below for details.
  3216. padtype : string, optional
  3217. `constant`, `line`, `mean`, `median`, `maximum`, `minimum` or any of
  3218. the other signal extension modes supported by `scipy.signal.upfirdn`.
  3219. Changes assumptions on values beyond the boundary. If `constant`,
  3220. assumed to be `cval` (default zero). If `line` assumed to continue a
  3221. linear trend defined by the first and last points. `mean`, `median`,
  3222. `maximum` and `minimum` work as in `np.pad` and assume that the values
  3223. beyond the boundary are the mean, median, maximum or minimum
  3224. respectively of the array along the axis.
  3225. .. versionadded:: 1.4.0
  3226. cval : float, optional
  3227. Value to use if `padtype='constant'`. Default is zero.
  3228. .. versionadded:: 1.4.0
  3229. Returns
  3230. -------
  3231. resampled_x : array
  3232. The resampled array.
  3233. See Also
  3234. --------
  3235. decimate : Downsample the signal after applying an FIR or IIR filter.
  3236. resample : Resample up or down using the FFT method.
  3237. Notes
  3238. -----
  3239. This polyphase method will likely be faster than the Fourier method
  3240. in `scipy.signal.resample` when the number of samples is large and
  3241. prime, or when the number of samples is large and `up` and `down`
  3242. share a large greatest common denominator. The length of the FIR
  3243. filter used will depend on ``max(up, down) // gcd(up, down)``, and
  3244. the number of operations during polyphase filtering will depend on
  3245. the filter length and `down` (see `scipy.signal.upfirdn` for details).
  3246. The argument `window` specifies the FIR low-pass filter design.
  3247. If `window` is an array_like it is assumed to be the FIR filter
  3248. coefficients. Note that the FIR filter is applied after the upsampling
  3249. step, so it should be designed to operate on a signal at a sampling
  3250. frequency higher than the original by a factor of `up//gcd(up, down)`.
  3251. This function's output will be centered with respect to this array, so it
  3252. is best to pass a symmetric filter with an odd number of samples if, as
  3253. is usually the case, a zero-phase filter is desired.
  3254. For any other type of `window`, the functions `scipy.signal.get_window`
  3255. and `scipy.signal.firwin` are called to generate the appropriate filter
  3256. coefficients.
  3257. The first sample of the returned vector is the same as the first
  3258. sample of the input vector. The spacing between samples is changed
  3259. from ``dx`` to ``dx * down / float(up)``.
  3260. Examples
  3261. --------
  3262. By default, the end of the resampled data rises to meet the first
  3263. sample of the next cycle for the FFT method, and gets closer to zero
  3264. for the polyphase method:
  3265. >>> import numpy as np
  3266. >>> from scipy import signal
  3267. >>> import matplotlib.pyplot as plt
  3268. >>> x = np.linspace(0, 10, 20, endpoint=False)
  3269. >>> y = np.cos(-x**2/6.0)
  3270. >>> f_fft = signal.resample(y, 100)
  3271. >>> f_poly = signal.resample_poly(y, 100, 20)
  3272. >>> xnew = np.linspace(0, 10, 100, endpoint=False)
  3273. >>> plt.plot(xnew, f_fft, 'b.-', xnew, f_poly, 'r.-')
  3274. >>> plt.plot(x, y, 'ko-')
  3275. >>> plt.plot(10, y[0], 'bo', 10, 0., 'ro') # boundaries
  3276. >>> plt.legend(['resample', 'resamp_poly', 'data'], loc='best')
  3277. >>> plt.show()
  3278. This default behaviour can be changed by using the padtype option:
  3279. >>> N = 5
  3280. >>> x = np.linspace(0, 1, N, endpoint=False)
  3281. >>> y = 2 + x**2 - 1.7*np.sin(x) + .2*np.cos(11*x)
  3282. >>> y2 = 1 + x**3 + 0.1*np.sin(x) + .1*np.cos(11*x)
  3283. >>> Y = np.stack([y, y2], axis=-1)
  3284. >>> up = 4
  3285. >>> xr = np.linspace(0, 1, N*up, endpoint=False)
  3286. >>> y2 = signal.resample_poly(Y, up, 1, padtype='constant')
  3287. >>> y3 = signal.resample_poly(Y, up, 1, padtype='mean')
  3288. >>> y4 = signal.resample_poly(Y, up, 1, padtype='line')
  3289. >>> for i in [0,1]:
  3290. ... plt.figure()
  3291. ... plt.plot(xr, y4[:,i], 'g.', label='line')
  3292. ... plt.plot(xr, y3[:,i], 'y.', label='mean')
  3293. ... plt.plot(xr, y2[:,i], 'r.', label='constant')
  3294. ... plt.plot(x, Y[:,i], 'k-')
  3295. ... plt.legend()
  3296. >>> plt.show()
  3297. """
  3298. xp = array_namespace(x)
  3299. x = xp.asarray(x)
  3300. if up != int(up):
  3301. raise ValueError("up must be an integer")
  3302. if down != int(down):
  3303. raise ValueError("down must be an integer")
  3304. up = int(up)
  3305. down = int(down)
  3306. if up < 1 or down < 1:
  3307. raise ValueError('up and down must be >= 1')
  3308. if cval is not None and padtype != 'constant':
  3309. raise ValueError('cval has no effect when padtype is ', padtype)
  3310. # Determine our up and down factors
  3311. # Use a rational approximation to save computation time on really long
  3312. # signals
  3313. g_ = math.gcd(up, down)
  3314. up //= g_
  3315. down //= g_
  3316. if up == down == 1:
  3317. return xp.asarray(x, copy=True)
  3318. n_in = x.shape[axis]
  3319. n_out = n_in * up
  3320. n_out = n_out // down + bool(n_out % down)
  3321. if isinstance(window, list) or is_array_api_obj(window):
  3322. window = xp.asarray(window, copy=True) # force a copy (we modify `window`)
  3323. if window.ndim > 1:
  3324. raise ValueError('window must be 1-D')
  3325. half_len = (xp_size(window) - 1) // 2
  3326. h = window
  3327. else:
  3328. # Design a linear-phase low-pass FIR filter
  3329. max_rate = max(up, down)
  3330. f_c = 1. / max_rate # cutoff of FIR filter (rel. to Nyquist)
  3331. half_len = 10 * max_rate # reasonable cutoff for sinc-like function
  3332. if xp.isdtype(x.dtype, ("real floating", "complex floating")):
  3333. h = firwin(2 * half_len + 1, f_c, window=window)
  3334. h = xp.asarray(h, dtype=x.dtype) # match dtype of x
  3335. else:
  3336. h = firwin(2 * half_len + 1, f_c, window=window)
  3337. h = xp.asarray(h)
  3338. h *= up
  3339. # Zero-pad our filter to put the output samples at the center
  3340. n_pre_pad = (down - half_len % down)
  3341. n_post_pad = 0
  3342. n_pre_remove = (half_len + n_pre_pad) // down
  3343. # We should rarely need to do this given our filter lengths...
  3344. while _output_len(h.shape[0] + n_pre_pad + n_post_pad, n_in,
  3345. up, down) < n_out + n_pre_remove:
  3346. n_post_pad += 1
  3347. h = xp.concat((xp.zeros(n_pre_pad, dtype=h.dtype), h,
  3348. xp.zeros(n_post_pad, dtype=h.dtype)))
  3349. n_pre_remove_end = n_pre_remove + n_out
  3350. # XXX consider using stats.quantile, which is natively Array API compatible
  3351. def _median(x, *args, **kwds):
  3352. return xp.asarray(np.median(np.asarray(x), *args, **kwds))
  3353. # Remove background depending on the padtype option
  3354. funcs = {'mean': xp.mean, 'median': _median,
  3355. 'minimum': xp.min, 'maximum': xp.max}
  3356. upfirdn_kwargs = {'mode': 'constant', 'cval': 0}
  3357. if padtype in funcs:
  3358. background_values = funcs[padtype](x, axis=axis, keepdims=True)
  3359. elif padtype in _upfirdn_modes:
  3360. upfirdn_kwargs = {'mode': padtype}
  3361. if padtype == 'constant':
  3362. if cval is None:
  3363. cval = 0
  3364. upfirdn_kwargs['cval'] = cval
  3365. else:
  3366. raise ValueError(
  3367. 'padtype must be one of: maximum, mean, median, minimum, ' +
  3368. ', '.join(_upfirdn_modes))
  3369. if padtype in funcs:
  3370. x = x - background_values
  3371. # filter then remove excess
  3372. y = upfirdn(h, x, up, down, axis=axis, **upfirdn_kwargs)
  3373. keep = [slice(None), ]*x.ndim
  3374. keep[axis] = slice(n_pre_remove, n_pre_remove_end)
  3375. y_keep = y[tuple(keep)]
  3376. # Add background back
  3377. if padtype in funcs:
  3378. y_keep += background_values
  3379. return y_keep
  3380. def _angle(z, xp):
  3381. """np.angle replacement
  3382. """
  3383. # XXX: https://github.com/data-apis/array-api/issues/595
  3384. zimag = xp.imag(z) if xp.isdtype(z.dtype, 'complex floating') else 0.
  3385. a = xp.atan2(zimag, xp.real(z))
  3386. return a
  3387. def vectorstrength(events, period):
  3388. '''
  3389. Determine the vector strength of the events corresponding to the given
  3390. period.
  3391. The vector strength is a measure of phase synchrony, how well the
  3392. timing of the events is synchronized to a single period of a periodic
  3393. signal.
  3394. If multiple periods are used, calculate the vector strength of each.
  3395. This is called the "resonating vector strength".
  3396. Parameters
  3397. ----------
  3398. events : 1D array_like
  3399. An array of time points containing the timing of the events.
  3400. period : float or array_like
  3401. The period of the signal that the events should synchronize to.
  3402. The period is in the same units as `events`. It can also be an array
  3403. of periods, in which case the outputs are arrays of the same length.
  3404. Returns
  3405. -------
  3406. strength : float or 1D array
  3407. The strength of the synchronization. 1.0 is perfect synchronization
  3408. and 0.0 is no synchronization. If `period` is an array, this is also
  3409. an array with each element containing the vector strength at the
  3410. corresponding period.
  3411. phase : float or array
  3412. The phase that the events are most strongly synchronized to in radians.
  3413. If `period` is an array, this is also an array with each element
  3414. containing the phase for the corresponding period.
  3415. References
  3416. ----------
  3417. van Hemmen, JL, Longtin, A, and Vollmayr, AN. Testing resonating vector
  3418. strength: Auditory system, electric fish, and noise.
  3419. Chaos 21, 047508 (2011);
  3420. :doi:`10.1063/1.3670512`.
  3421. van Hemmen, JL. Vector strength after Goldberg, Brown, and von Mises:
  3422. biological and mathematical perspectives. Biol Cybern.
  3423. 2013 Aug;107(4):385-96. :doi:`10.1007/s00422-013-0561-7`.
  3424. van Hemmen, JL and Vollmayr, AN. Resonating vector strength: what happens
  3425. when we vary the "probing" frequency while keeping the spike times
  3426. fixed. Biol Cybern. 2013 Aug;107(4):491-94.
  3427. :doi:`10.1007/s00422-013-0560-8`.
  3428. '''
  3429. xp = array_namespace(events, period)
  3430. events = xp.asarray(events)
  3431. period = xp.asarray(period)
  3432. if xp.isdtype(period.dtype, 'integral'):
  3433. period = xp.astype(period, xp.float64)
  3434. if events.ndim > 1:
  3435. raise ValueError('events cannot have dimensions more than 1')
  3436. if period.ndim > 1:
  3437. raise ValueError('period cannot have dimensions more than 1')
  3438. # we need to know later if period was originally a scalar
  3439. scalarperiod = not period.ndim
  3440. events = xpx.atleast_nd(events, ndim=2, xp=xp)
  3441. period = xpx.atleast_nd(period, ndim=2, xp=xp)
  3442. if xp.any(period <= 0):
  3443. raise ValueError('periods must be positive')
  3444. # this converts the times to vectors
  3445. events_ = xp.astype(events, period.dtype)
  3446. vectors = xp.exp(2j * (xp.pi / period.T @ events_))
  3447. # the vector strength is just the magnitude of the mean of the vectors
  3448. # the vector phase is the angle of the mean of the vectors
  3449. vectormean = xp.mean(vectors, axis=1)
  3450. strength = xp.abs(vectormean)
  3451. phase = _angle(vectormean, xp)
  3452. # if the original period was a scalar, return scalars
  3453. if scalarperiod:
  3454. strength = strength[0]
  3455. phase = phase[0]
  3456. return strength, phase
  3457. def detrend(data: np.ndarray, axis: int = -1,
  3458. type: Literal['linear', 'constant'] = 'linear',
  3459. bp: ArrayLike | int = 0, overwrite_data: bool = False) -> np.ndarray:
  3460. r"""Remove linear or constant trend along axis from data.
  3461. Parameters
  3462. ----------
  3463. data : array_like
  3464. The input data.
  3465. axis : int, optional
  3466. The axis along which to detrend the data. By default this is the
  3467. last axis (-1).
  3468. type : {'linear', 'constant'}, optional
  3469. The type of detrending. If ``type == 'linear'`` (default),
  3470. the result of a linear least-squares fit to `data` is subtracted
  3471. from `data`.
  3472. If ``type == 'constant'``, only the mean of `data` is subtracted.
  3473. bp : array_like of ints, optional
  3474. A sequence of break points. If given, an individual linear fit is
  3475. performed for each part of `data` between two break points.
  3476. Break points are specified as indices into `data`. This parameter
  3477. only has an effect when ``type == 'linear'``.
  3478. overwrite_data: bool, optional
  3479. If True, allow in place detrending and avoid a copy. Default is
  3480. False. In place modification applies only if ``type == 'linear'``
  3481. and `data` is of the floating point dtype ``float32``, ``float64``,
  3482. ``complex64`` or ``complex128``.
  3483. Returns
  3484. -------
  3485. ret : ndarray
  3486. The detrended input data.
  3487. Notes
  3488. -----
  3489. Detrending can be interpreted as subtracting a least squares fit polynomial:
  3490. Setting the parameter `type` to 'constant' corresponds to fitting a zeroth degree
  3491. polynomial, 'linear' to a first degree polynomial. Consult the example below.
  3492. See Also
  3493. --------
  3494. :meth:`numpy.polynomial.polynomial.Polynomial.fit` : Create least squares fit polynomial.
  3495. Examples
  3496. --------
  3497. The following example detrends the function :math:`x(t) = \sin(\pi t) + 1/4`:
  3498. >>> import matplotlib.pyplot as plt
  3499. >>> import numpy as np
  3500. >>> from scipy.signal import detrend
  3501. ...
  3502. >>> t = np.linspace(-0.5, 0.5, 21)
  3503. >>> x = np.sin(np.pi*t) + 1/4
  3504. ...
  3505. >>> x_d_const = detrend(x, type='constant')
  3506. >>> x_d_linear = detrend(x, type='linear')
  3507. ...
  3508. >>> fig1, ax1 = plt.subplots()
  3509. >>> ax1.set_title(r"Detrending $x(t)=\sin(\pi t) + 1/4$")
  3510. >>> ax1.set(xlabel="t", ylabel="$x(t)$", xlim=(t[0], t[-1]))
  3511. >>> ax1.axhline(y=0, color='black', linewidth=.5)
  3512. >>> ax1.axvline(x=0, color='black', linewidth=.5)
  3513. >>> ax1.plot(t, x, 'C0.-', label="No detrending")
  3514. >>> ax1.plot(t, x_d_const, 'C1x-', label="type='constant'")
  3515. >>> ax1.plot(t, x_d_linear, 'C2+-', label="type='linear'")
  3516. >>> ax1.legend()
  3517. >>> plt.show()
  3518. Alternatively, NumPy's `~numpy.polynomial.polynomial.Polynomial` can be used for
  3519. detrending as well:
  3520. >>> pp0 = np.polynomial.Polynomial.fit(t, x, deg=0) # fit degree 0 polynomial
  3521. >>> np.allclose(x_d_const, x - pp0(t)) # compare with constant detrend
  3522. True
  3523. >>> pp1 = np.polynomial.Polynomial.fit(t, x, deg=1) # fit degree 1 polynomial
  3524. >>> np.allclose(x_d_linear, x - pp1(t)) # compare with linear detrend
  3525. True
  3526. Note that `~numpy.polynomial.polynomial.Polynomial` also allows fitting higher
  3527. degree polynomials. Consult its documentation on how to extract the polynomial
  3528. coefficients.
  3529. """ # noqa: E501
  3530. if type not in ['linear', 'l', 'constant', 'c']:
  3531. raise ValueError("Trend type must be 'linear' or 'constant'.")
  3532. xp = array_namespace(data, bp)
  3533. data = np.asarray(data)
  3534. dtype = data.dtype.char
  3535. if dtype not in 'dfDF':
  3536. dtype = 'd'
  3537. if type in ['constant', 'c']:
  3538. ret = data - np.mean(data, axis, keepdims=True)
  3539. return xp.asarray(ret)
  3540. else:
  3541. dshape = data.shape
  3542. N = dshape[axis]
  3543. bp = np.asarray(bp)
  3544. bp = np.sort(np.unique(np.concatenate(np.atleast_1d(0, bp, N))))
  3545. if np.any(bp > N):
  3546. raise ValueError("Breakpoints must be less than length "
  3547. "of data along given axis.")
  3548. # Restructure data so that axis is along first dimension and
  3549. # all other dimensions are collapsed into second dimension
  3550. rnk = len(dshape)
  3551. if axis < 0:
  3552. axis = axis + rnk
  3553. newdata = np.moveaxis(data, axis, 0)
  3554. newdata_shape = newdata.shape
  3555. newdata = newdata.reshape(N, -1)
  3556. if not overwrite_data:
  3557. newdata = newdata.copy() # make sure we have a copy
  3558. if newdata.dtype.char not in 'dfDF':
  3559. newdata = newdata.astype(dtype)
  3560. # Nreg = len(bp) - 1
  3561. # Find leastsq fit and remove it for each piece
  3562. for m in range(len(bp) - 1):
  3563. Npts = bp[m + 1] - bp[m]
  3564. A = np.ones((Npts, 2), dtype)
  3565. A[:, 0] = np.arange(1, Npts + 1, dtype=dtype) / Npts
  3566. sl = slice(bp[m], bp[m + 1])
  3567. coef, resids, rank, s = linalg.lstsq(A, newdata[sl])
  3568. newdata[sl] = newdata[sl] - A @ coef
  3569. # Put data back in original shape.
  3570. newdata = newdata.reshape(newdata_shape)
  3571. ret = np.moveaxis(newdata, 0, axis)
  3572. return xp.asarray(ret)
  3573. def lfilter_zi(b, a):
  3574. """
  3575. Construct initial conditions for lfilter for step response steady-state.
  3576. Compute an initial state `zi` for the `lfilter` function that corresponds
  3577. to the steady state of the step response.
  3578. A typical use of this function is to set the initial state so that the
  3579. output of the filter starts at the same value as the first element of
  3580. the signal to be filtered.
  3581. Parameters
  3582. ----------
  3583. b, a : array_like (1-D)
  3584. The IIR filter coefficients. See `lfilter` for more
  3585. information.
  3586. Returns
  3587. -------
  3588. zi : 1-D ndarray
  3589. The initial state for the filter.
  3590. See Also
  3591. --------
  3592. lfilter, lfiltic, filtfilt
  3593. Notes
  3594. -----
  3595. A linear filter with order m has a state space representation (A, B, C, D),
  3596. for which the output y of the filter can be expressed as::
  3597. z(n+1) = A*z(n) + B*x(n)
  3598. y(n) = C*z(n) + D*x(n)
  3599. where z(n) is a vector of length m, A has shape (m, m), B has shape
  3600. (m, 1), C has shape (1, m) and D has shape (1, 1) (assuming x(n) is
  3601. a scalar). lfilter_zi solves::
  3602. zi = A*zi + B
  3603. In other words, it finds the initial condition for which the response
  3604. to an input of all ones is a constant.
  3605. Given the filter coefficients `a` and `b`, the state space matrices
  3606. for the transposed direct form II implementation of the linear filter,
  3607. which is the implementation used by scipy.signal.lfilter, are::
  3608. A = scipy.linalg.companion(a).T
  3609. B = b[1:] - a[1:]*b[0]
  3610. assuming ``a[0]`` is 1.0; if ``a[0]`` is not 1, `a` and `b` are first
  3611. divided by a[0].
  3612. Examples
  3613. --------
  3614. The following code creates a lowpass Butterworth filter. Then it
  3615. applies that filter to an array whose values are all 1.0; the
  3616. output is also all 1.0, as expected for a lowpass filter. If the
  3617. `zi` argument of `lfilter` had not been given, the output would have
  3618. shown the transient signal.
  3619. >>> from numpy import array, ones
  3620. >>> from scipy.signal import lfilter, lfilter_zi, butter
  3621. >>> b, a = butter(5, 0.25)
  3622. >>> zi = lfilter_zi(b, a)
  3623. >>> y, zo = lfilter(b, a, ones(10), zi=zi)
  3624. >>> y
  3625. array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
  3626. Another example:
  3627. >>> x = array([0.5, 0.5, 0.5, 0.0, 0.0, 0.0, 0.0])
  3628. >>> y, zf = lfilter(b, a, x, zi=zi*x[0])
  3629. >>> y
  3630. array([ 0.5 , 0.5 , 0.5 , 0.49836039, 0.48610528,
  3631. 0.44399389, 0.35505241])
  3632. Note that the `zi` argument to `lfilter` was computed using
  3633. `lfilter_zi` and scaled by ``x[0]``. Then the output `y` has no
  3634. transient until the input drops from 0.5 to 0.0.
  3635. """
  3636. xp = array_namespace(b, a)
  3637. # FIXME: Can this function be replaced with an appropriate
  3638. # use of lfiltic? For example, when b,a = butter(N,Wn),
  3639. # lfiltic(b, a, y=numpy.ones_like(a), x=numpy.ones_like(b)).
  3640. #
  3641. # We could use scipy.signal.normalize, but it uses warnings in
  3642. # cases where a ValueError is more appropriate, and it allows
  3643. # b to be 2D.
  3644. b = xpx.atleast_nd(xp.asarray(b), ndim=1, xp=xp)
  3645. if b.ndim != 1:
  3646. raise ValueError("Numerator b must be 1-D.")
  3647. a = xpx.atleast_nd(xp.asarray(a), ndim=1, xp=xp)
  3648. if a.ndim != 1:
  3649. raise ValueError("Denominator a must be 1-D.")
  3650. while a.shape[0] > 1 and a[0] == 0.0:
  3651. a = a[1:]
  3652. if xp_size(a) < 1:
  3653. raise ValueError("There must be at least one nonzero `a` coefficient.")
  3654. if a[0] != 1.0:
  3655. # Normalize the coefficients so a[0] == 1.
  3656. b = b / a[0]
  3657. a = a / a[0]
  3658. n = max(a.shape[0], b.shape[0])
  3659. # Pad a or b with zeros so they are the same length.
  3660. if a.shape[0] < n:
  3661. a = xp.concat((a, xp.zeros(n - a.shape[0], dtype=a.dtype)))
  3662. elif b.shape[0] < n:
  3663. b = xp.concat((b, xp.zeros(n - b.shape[0], dtype=b.dtype)))
  3664. dt = xp.result_type(a, b)
  3665. IminusA = np.eye(n - 1) - linalg.companion(a).T
  3666. IminusA = xp.asarray(IminusA, dtype=dt)
  3667. B = b[1:] - a[1:] * b[0]
  3668. # Solve zi = A*zi + B
  3669. zi = xp.linalg.solve(IminusA, B)
  3670. # For future reference: we could also use the following
  3671. # explicit formulas to solve the linear system:
  3672. #
  3673. # zi = np.zeros(n - 1)
  3674. # zi[0] = B.sum() / IminusA[:,0].sum()
  3675. # asum = 1.0
  3676. # csum = 0.0
  3677. # for k in range(1,n-1):
  3678. # asum += a[k]
  3679. # csum += b[k] - a[k]*b[0]
  3680. # zi[k] = asum*zi[0] - csum
  3681. return zi
  3682. def sosfilt_zi(sos):
  3683. """
  3684. Construct initial conditions for sosfilt for step response steady-state.
  3685. Compute an initial state `zi` for the `sosfilt` function that corresponds
  3686. to the steady state of the step response.
  3687. A typical use of this function is to set the initial state so that the
  3688. output of the filter starts at the same value as the first element of
  3689. the signal to be filtered.
  3690. Parameters
  3691. ----------
  3692. sos : array_like
  3693. Array of second-order filter coefficients, must have shape
  3694. ``(n_sections, 6)``. See `sosfilt` for the SOS filter format
  3695. specification.
  3696. Returns
  3697. -------
  3698. zi : ndarray
  3699. Initial conditions suitable for use with ``sosfilt``, shape
  3700. ``(n_sections, 2)``.
  3701. See Also
  3702. --------
  3703. sosfilt, zpk2sos
  3704. Notes
  3705. -----
  3706. .. versionadded:: 0.16.0
  3707. Examples
  3708. --------
  3709. Filter a rectangular pulse that begins at time 0, with and without
  3710. the use of the `zi` argument of `scipy.signal.sosfilt`.
  3711. >>> import numpy as np
  3712. >>> from scipy import signal
  3713. >>> import matplotlib.pyplot as plt
  3714. >>> sos = signal.butter(9, 0.125, output='sos')
  3715. >>> zi = signal.sosfilt_zi(sos)
  3716. >>> x = (np.arange(250) < 100).astype(int)
  3717. >>> f1 = signal.sosfilt(sos, x)
  3718. >>> f2, zo = signal.sosfilt(sos, x, zi=zi)
  3719. >>> plt.plot(x, 'k--', label='x')
  3720. >>> plt.plot(f1, 'b', alpha=0.5, linewidth=2, label='filtered')
  3721. >>> plt.plot(f2, 'g', alpha=0.25, linewidth=4, label='filtered with zi')
  3722. >>> plt.legend(loc='best')
  3723. >>> plt.show()
  3724. """
  3725. xp = array_namespace(sos)
  3726. sos = xp.asarray(sos)
  3727. if sos.ndim != 2 or sos.shape[1] != 6:
  3728. raise ValueError('sos must be shape (n_sections, 6)')
  3729. if xp.isdtype(sos.dtype, ("integral", "bool")):
  3730. sos = xp.astype(sos, xp.float64)
  3731. n_sections = sos.shape[0]
  3732. zi = xp.empty((n_sections, 2), dtype=sos.dtype)
  3733. scale = 1.0
  3734. for section in range(n_sections):
  3735. b = sos[section, :3]
  3736. a = sos[section, 3:]
  3737. zi[section, ...] = scale * lfilter_zi(b, a)
  3738. # If H(z) = B(z)/A(z) is this section's transfer function, then
  3739. # b.sum()/a.sum() is H(1), the gain at omega=0. That's the steady
  3740. # state value of this section's step response.
  3741. scale *= xp.sum(b) / xp.sum(a)
  3742. return zi
  3743. def _filtfilt_gust(b, a, x, axis=-1, irlen=None):
  3744. """Forward-backward IIR filter that uses Gustafsson's method.
  3745. Apply the IIR filter defined by ``(b,a)`` to `x` twice, first forward
  3746. then backward, using Gustafsson's initial conditions [1]_.
  3747. Let ``y_fb`` be the result of filtering first forward and then backward,
  3748. and let ``y_bf`` be the result of filtering first backward then forward.
  3749. Gustafsson's method is to compute initial conditions for the forward
  3750. pass and the backward pass such that ``y_fb == y_bf``.
  3751. Parameters
  3752. ----------
  3753. b : scalar or 1-D ndarray
  3754. Numerator coefficients of the filter.
  3755. a : scalar or 1-D ndarray
  3756. Denominator coefficients of the filter.
  3757. x : ndarray
  3758. Data to be filtered.
  3759. axis : int, optional
  3760. Axis of `x` to be filtered. Default is -1.
  3761. irlen : int or None, optional
  3762. The length of the nonnegligible part of the impulse response.
  3763. If `irlen` is None, or if the length of the signal is less than
  3764. ``2 * irlen``, then no part of the impulse response is ignored.
  3765. Returns
  3766. -------
  3767. y : ndarray
  3768. The filtered data.
  3769. x0 : ndarray
  3770. Initial condition for the forward filter.
  3771. x1 : ndarray
  3772. Initial condition for the backward filter.
  3773. Notes
  3774. -----
  3775. Typically the return values `x0` and `x1` are not needed by the
  3776. caller. The intended use of these return values is in unit tests.
  3777. References
  3778. ----------
  3779. .. [1] F. Gustaffson. Determining the initial states in forward-backward
  3780. filtering. Transactions on Signal Processing, 46(4):988-992, 1996.
  3781. """
  3782. # In the comments, "Gustafsson's paper" and [1] refer to the
  3783. # paper referenced in the docstring.
  3784. b = np.atleast_1d(b)
  3785. a = np.atleast_1d(a)
  3786. order = max(len(b), len(a)) - 1
  3787. if order == 0:
  3788. # The filter is just scalar multiplication, with no state.
  3789. scale = (b[0] / a[0])**2
  3790. y = scale * x
  3791. return y, np.array([]), np.array([])
  3792. if axis != -1 or axis != x.ndim - 1:
  3793. # Move the axis containing the data to the end.
  3794. x = np.swapaxes(x, axis, x.ndim - 1)
  3795. # n is the number of samples in the data to be filtered.
  3796. n = x.shape[-1]
  3797. if irlen is None or n <= 2*irlen:
  3798. m = n
  3799. else:
  3800. m = irlen
  3801. # Create Obs, the observability matrix (called O in the paper).
  3802. # This matrix can be interpreted as the operator that propagates
  3803. # an arbitrary initial state to the output, assuming the input is
  3804. # zero.
  3805. # In Gustafsson's paper, the forward and backward filters are not
  3806. # necessarily the same, so he has both O_f and O_b. We use the same
  3807. # filter in both directions, so we only need O. The same comment
  3808. # applies to S below.
  3809. Obs = np.zeros((m, order))
  3810. zi = np.zeros(order)
  3811. zi[0] = 1
  3812. Obs[:, 0] = lfilter(b, a, np.zeros(m), zi=zi)[0]
  3813. for k in range(1, order):
  3814. Obs[k:, k] = Obs[:-k, 0]
  3815. # Obsr is O^R (Gustafsson's notation for row-reversed O)
  3816. Obsr = Obs[::-1]
  3817. # Create S. S is the matrix that applies the filter to the reversed
  3818. # propagated initial conditions. That is,
  3819. # out = S.dot(zi)
  3820. # is the same as
  3821. # tmp, _ = lfilter(b, a, zeros(), zi=zi) # Propagate ICs.
  3822. # out = lfilter(b, a, tmp[::-1]) # Reverse and filter.
  3823. # Equations (5) & (6) of [1]
  3824. S = lfilter(b, a, Obs[::-1], axis=0)
  3825. # Sr is S^R (row-reversed S)
  3826. Sr = S[::-1]
  3827. # M is [(S^R - O), (O^R - S)]
  3828. if m == n:
  3829. M = np.hstack((Sr - Obs, Obsr - S))
  3830. else:
  3831. # Matrix described in section IV of [1].
  3832. M = np.zeros((2*m, 2*order))
  3833. M[:m, :order] = Sr - Obs
  3834. M[m:, order:] = Obsr - S
  3835. # Naive forward-backward and backward-forward filters.
  3836. # These have large transients because the filters use zero initial
  3837. # conditions.
  3838. y_f = lfilter(b, a, x)
  3839. y_fb = lfilter(b, a, y_f[..., ::-1])[..., ::-1]
  3840. y_b = lfilter(b, a, x[..., ::-1])[..., ::-1]
  3841. y_bf = lfilter(b, a, y_b)
  3842. delta_y_bf_fb = y_bf - y_fb
  3843. if m == n:
  3844. delta = delta_y_bf_fb
  3845. else:
  3846. start_m = delta_y_bf_fb[..., :m]
  3847. end_m = delta_y_bf_fb[..., -m:]
  3848. delta = np.concatenate((start_m, end_m), axis=-1)
  3849. # ic_opt holds the "optimal" initial conditions.
  3850. # The following code computes the result shown in the formula
  3851. # of the paper between equations (6) and (7).
  3852. if delta.ndim == 1:
  3853. ic_opt = linalg.lstsq(M, delta)[0]
  3854. else:
  3855. # Reshape delta so it can be used as an array of multiple
  3856. # right-hand-sides in linalg.lstsq.
  3857. delta2d = delta.reshape(-1, delta.shape[-1]).T
  3858. ic_opt0 = linalg.lstsq(M, delta2d)[0].T
  3859. ic_opt = ic_opt0.reshape(delta.shape[:-1] + (M.shape[-1],))
  3860. # Now compute the filtered signal using equation (7) of [1].
  3861. # First, form [S^R, O^R] and call it W.
  3862. if m == n:
  3863. W = np.hstack((Sr, Obsr))
  3864. else:
  3865. W = np.zeros((2*m, 2*order))
  3866. W[:m, :order] = Sr
  3867. W[m:, order:] = Obsr
  3868. # Equation (7) of [1] says
  3869. # Y_fb^opt = Y_fb^0 + W * [x_0^opt; x_{N-1}^opt]
  3870. # `wic` is (almost) the product on the right.
  3871. # W has shape (m, 2*order), and ic_opt has shape (..., 2*order),
  3872. # so we can't use W.dot(ic_opt). Instead, we dot ic_opt with W.T,
  3873. # so wic has shape (..., m).
  3874. wic = ic_opt.dot(W.T)
  3875. # `wic` is "almost" the product of W and the optimal ICs in equation
  3876. # (7)--if we're using a truncated impulse response (m < n), `wic`
  3877. # contains only the adjustments required for the ends of the signal.
  3878. # Here we form y_opt, taking this into account if necessary.
  3879. y_opt = y_fb
  3880. if m == n:
  3881. y_opt += wic
  3882. else:
  3883. y_opt[..., :m] += wic[..., :m]
  3884. y_opt[..., -m:] += wic[..., -m:]
  3885. x0 = ic_opt[..., :order]
  3886. x1 = ic_opt[..., -order:]
  3887. if axis != -1 or axis != x.ndim - 1:
  3888. # Restore the data axis to its original position.
  3889. x0 = np.swapaxes(x0, axis, x.ndim - 1)
  3890. x1 = np.swapaxes(x1, axis, x.ndim - 1)
  3891. y_opt = np.swapaxes(y_opt, axis, x.ndim - 1)
  3892. return y_opt, x0, x1
  3893. def filtfilt(b, a, x, axis=-1, padtype='odd', padlen=None, method='pad',
  3894. irlen=None):
  3895. """
  3896. Apply a digital filter forward and backward to a signal.
  3897. This function applies a linear digital filter twice, once forward and
  3898. once backwards. The combined filter has zero phase and a filter order
  3899. twice that of the original.
  3900. The function provides options for handling the edges of the signal.
  3901. The function `sosfiltfilt` (and filter design using ``output='sos'``)
  3902. should be preferred over `filtfilt` for most filtering tasks, as
  3903. second-order sections have fewer numerical problems.
  3904. Parameters
  3905. ----------
  3906. b : (N,) array_like
  3907. The numerator coefficient vector of the filter.
  3908. a : (N,) array_like
  3909. The denominator coefficient vector of the filter. If ``a[0]``
  3910. is not 1, then both `a` and `b` are normalized by ``a[0]``.
  3911. x : array_like
  3912. The array of data to be filtered.
  3913. axis : int, optional
  3914. The axis of `x` to which the filter is applied.
  3915. Default is -1.
  3916. padtype : str or None, optional
  3917. Must be 'odd', 'even', 'constant', or None. This determines the
  3918. type of extension to use for the padded signal to which the filter
  3919. is applied. If `padtype` is None, no padding is used. The default
  3920. is 'odd'.
  3921. padlen : int or None, optional
  3922. The number of elements by which to extend `x` at both ends of
  3923. `axis` before applying the filter. This value must be less than
  3924. ``x.shape[axis] - 1``. ``padlen=0`` implies no padding.
  3925. The default value is ``3 * max(len(a), len(b))``.
  3926. method : str, optional
  3927. Determines the method for handling the edges of the signal, either
  3928. "pad" or "gust". When `method` is "pad", the signal is padded; the
  3929. type of padding is determined by `padtype` and `padlen`, and `irlen`
  3930. is ignored. When `method` is "gust", Gustafsson's method is used,
  3931. and `padtype` and `padlen` are ignored.
  3932. irlen : int or None, optional
  3933. When `method` is "gust", `irlen` specifies the length of the
  3934. impulse response of the filter. If `irlen` is None, no part
  3935. of the impulse response is ignored. For a long signal, specifying
  3936. `irlen` can significantly improve the performance of the filter.
  3937. Returns
  3938. -------
  3939. y : ndarray
  3940. The filtered output with the same shape as `x`.
  3941. See Also
  3942. --------
  3943. sosfiltfilt, lfilter_zi, lfilter, lfiltic, savgol_filter, sosfilt
  3944. Notes
  3945. -----
  3946. When `method` is "pad", the function pads the data along the given axis
  3947. in one of three ways: odd, even or constant. The odd and even extensions
  3948. have the corresponding symmetry about the end point of the data. The
  3949. constant extension extends the data with the values at the end points. On
  3950. both the forward and backward passes, the initial condition of the
  3951. filter is found by using `lfilter_zi` and scaling it by the end point of
  3952. the extended data.
  3953. When `method` is "gust", Gustafsson's method [1]_ is used. Initial
  3954. conditions are chosen for the forward and backward passes so that the
  3955. forward-backward filter gives the same result as the backward-forward
  3956. filter.
  3957. The option to use Gustaffson's method was added in scipy version 0.16.0.
  3958. References
  3959. ----------
  3960. .. [1] F. Gustaffson, "Determining the initial states in forward-backward
  3961. filtering", Transactions on Signal Processing, Vol. 46, pp. 988-992,
  3962. 1996.
  3963. Examples
  3964. --------
  3965. The examples will use several functions from `scipy.signal`.
  3966. >>> import numpy as np
  3967. >>> from scipy import signal
  3968. >>> import matplotlib.pyplot as plt
  3969. First we create a one second signal that is the sum of two pure sine
  3970. waves, with frequencies 5 Hz and 250 Hz, sampled at 2000 Hz.
  3971. >>> t = np.linspace(0, 1.0, 2001)
  3972. >>> xlow = np.sin(2 * np.pi * 5 * t)
  3973. >>> xhigh = np.sin(2 * np.pi * 250 * t)
  3974. >>> x = xlow + xhigh
  3975. Now create a lowpass Butterworth filter with a cutoff of 0.125 times
  3976. the Nyquist frequency, or 125 Hz, and apply it to ``x`` with `filtfilt`.
  3977. The result should be approximately ``xlow``, with no phase shift.
  3978. >>> b, a = signal.butter(8, 0.125)
  3979. >>> y = signal.filtfilt(b, a, x, padlen=150)
  3980. >>> np.abs(y - xlow).max()
  3981. 9.1086182074789912e-06
  3982. We get a fairly clean result for this artificial example because
  3983. the odd extension is exact, and with the moderately long padding,
  3984. the filter's transients have dissipated by the time the actual data
  3985. is reached. In general, transient effects at the edges are
  3986. unavoidable.
  3987. The following example demonstrates the option ``method="gust"``.
  3988. First, create a filter.
  3989. >>> b, a = signal.ellip(4, 0.01, 120, 0.125) # Filter to be applied.
  3990. `sig` is a random input signal to be filtered.
  3991. >>> rng = np.random.default_rng()
  3992. >>> n = 60
  3993. >>> sig = rng.standard_normal(n)**3 + 3*rng.standard_normal(n).cumsum()
  3994. Apply `filtfilt` to `sig`, once using the Gustafsson method, and
  3995. once using padding, and plot the results for comparison.
  3996. >>> fgust = signal.filtfilt(b, a, sig, method="gust")
  3997. >>> fpad = signal.filtfilt(b, a, sig, padlen=50)
  3998. >>> plt.plot(sig, 'k-', label='input')
  3999. >>> plt.plot(fgust, 'b-', linewidth=4, label='gust')
  4000. >>> plt.plot(fpad, 'c-', linewidth=1.5, label='pad')
  4001. >>> plt.legend(loc='best')
  4002. >>> plt.show()
  4003. The `irlen` argument can be used to improve the performance
  4004. of Gustafsson's method.
  4005. Estimate the impulse response length of the filter.
  4006. >>> z, p, k = signal.tf2zpk(b, a)
  4007. >>> eps = 1e-9
  4008. >>> r = np.max(np.abs(p))
  4009. >>> approx_impulse_len = int(np.ceil(np.log(eps) / np.log(r)))
  4010. >>> approx_impulse_len
  4011. 137
  4012. Apply the filter to a longer signal, with and without the `irlen`
  4013. argument. The difference between `y1` and `y2` is small. For long
  4014. signals, using `irlen` gives a significant performance improvement.
  4015. >>> x = rng.standard_normal(4000)
  4016. >>> y1 = signal.filtfilt(b, a, x, method='gust')
  4017. >>> y2 = signal.filtfilt(b, a, x, method='gust', irlen=approx_impulse_len)
  4018. >>> print(np.max(np.abs(y1 - y2)))
  4019. 2.875334415008979e-10
  4020. """
  4021. xp = array_namespace(b, a, x)
  4022. b = np.atleast_1d(np.asarray(b))
  4023. a = np.atleast_1d(np.asarray(a))
  4024. x = np.asarray(x)
  4025. if method not in ["pad", "gust"]:
  4026. raise ValueError("method must be 'pad' or 'gust'.")
  4027. if method == "gust":
  4028. y, z1, z2 = _filtfilt_gust(b, a, x, axis=axis, irlen=irlen)
  4029. return xp.asarray(y)
  4030. # method == "pad"
  4031. edge, ext = _validate_pad(padtype, padlen, x, axis,
  4032. ntaps=max(len(a), len(b)))
  4033. # Get the steady state of the filter's step response.
  4034. zi = lfilter_zi(b, a)
  4035. # Reshape zi and create x0 so that zi*x0 broadcasts
  4036. # to the correct value for the 'zi' keyword argument
  4037. # to lfilter.
  4038. zi_shape = [1] * x.ndim
  4039. zi_shape[axis] = zi.size
  4040. zi = np.reshape(zi, zi_shape)
  4041. x0 = axis_slice(ext, stop=1, axis=axis)
  4042. # Forward filter.
  4043. (y, zf) = lfilter(b, a, ext, axis=axis, zi=zi * x0)
  4044. # Backward filter.
  4045. # Create y0 so zi*y0 broadcasts appropriately.
  4046. y0 = axis_slice(y, start=-1, axis=axis)
  4047. (y, zf) = lfilter(b, a, axis_reverse(y, axis=axis), axis=axis, zi=zi * y0)
  4048. # Reverse y.
  4049. y = axis_reverse(y, axis=axis)
  4050. if edge > 0:
  4051. # Slice the actual signal from the extended signal.
  4052. y = axis_slice(y, start=edge, stop=-edge, axis=axis)
  4053. if is_torch(xp):
  4054. y = y.copy() # pytorch/pytorch#59786 : no negative strides in pytorch
  4055. return xp.asarray(y)
  4056. def _validate_pad(padtype, padlen, x, axis, ntaps):
  4057. """Helper to validate padding for filtfilt"""
  4058. if padtype not in ['even', 'odd', 'constant', None]:
  4059. raise ValueError(f"Unknown value '{padtype}' given to padtype. "
  4060. "padtype must be 'even', 'odd', 'constant', or None.")
  4061. if padtype is None:
  4062. padlen = 0
  4063. if padlen is None:
  4064. # Original padding; preserved for backwards compatibility.
  4065. edge = ntaps * 3
  4066. else:
  4067. edge = padlen
  4068. # x's 'axis' dimension must be bigger than edge.
  4069. if x.shape[axis] <= edge:
  4070. raise ValueError(
  4071. f"The length of the input vector x must be greater than padlen, "
  4072. f"which is {edge}."
  4073. )
  4074. if padtype is not None and edge > 0:
  4075. # Make an extension of length `edge` at each
  4076. # end of the input array.
  4077. if padtype == 'even':
  4078. ext = even_ext(x, edge, axis=axis)
  4079. elif padtype == 'odd':
  4080. ext = odd_ext(x, edge, axis=axis)
  4081. else:
  4082. ext = const_ext(x, edge, axis=axis)
  4083. else:
  4084. ext = x
  4085. return edge, ext
  4086. def _validate_x(x):
  4087. x = np.asarray(x)
  4088. if x.ndim == 0:
  4089. raise ValueError('x must be at least 1-D')
  4090. return x
  4091. def sosfilt(sos, x, axis=-1, zi=None):
  4092. """
  4093. Filter data along one dimension using cascaded second-order sections.
  4094. Filter a data sequence, `x`, using a digital IIR filter defined by
  4095. `sos`.
  4096. Parameters
  4097. ----------
  4098. sos : array_like
  4099. Array of second-order filter coefficients, must have shape
  4100. ``(n_sections, 6)``. Each row corresponds to a second-order
  4101. section, with the first three columns providing the numerator
  4102. coefficients and the last three providing the denominator
  4103. coefficients.
  4104. x : array_like
  4105. An N-dimensional input array.
  4106. axis : int, optional
  4107. The axis of the input data array along which to apply the
  4108. linear filter. The filter is applied to each subarray along
  4109. this axis. Default is -1.
  4110. zi : array_like, optional
  4111. Initial conditions for the cascaded filter delays. It is a (at
  4112. least 2D) vector of shape ``(n_sections, ..., 2, ...)``, where
  4113. ``..., 2, ...`` denotes the shape of `x`, but with ``x.shape[axis]``
  4114. replaced by 2. If `zi` is None or is not given then initial rest
  4115. (i.e. all zeros) is assumed.
  4116. Note that these initial conditions are *not* the same as the initial
  4117. conditions given by `lfiltic` or `lfilter_zi`.
  4118. Returns
  4119. -------
  4120. y : ndarray
  4121. The output of the digital filter.
  4122. zf : ndarray, optional
  4123. If `zi` is None, this is not returned, otherwise, `zf` holds the
  4124. final filter delay values.
  4125. See Also
  4126. --------
  4127. zpk2sos, sos2zpk, sosfilt_zi, sosfiltfilt, freqz_sos
  4128. Notes
  4129. -----
  4130. The filter function is implemented as a series of second-order filters
  4131. with direct-form II transposed structure. It is designed to minimize
  4132. numerical precision errors for high-order filters.
  4133. .. versionadded:: 0.16.0
  4134. Examples
  4135. --------
  4136. Plot a 13th-order filter's impulse response using both `lfilter` and
  4137. `sosfilt`, showing the instability that results from trying to do a
  4138. 13th-order filter in a single stage (the numerical error pushes some poles
  4139. outside of the unit circle):
  4140. >>> import matplotlib.pyplot as plt
  4141. >>> from scipy import signal
  4142. >>> b, a = signal.ellip(13, 0.009, 80, 0.05, output='ba')
  4143. >>> sos = signal.ellip(13, 0.009, 80, 0.05, output='sos')
  4144. >>> x = signal.unit_impulse(700)
  4145. >>> y_tf = signal.lfilter(b, a, x)
  4146. >>> y_sos = signal.sosfilt(sos, x)
  4147. >>> plt.plot(y_tf, 'r', label='TF')
  4148. >>> plt.plot(y_sos, 'k', label='SOS')
  4149. >>> plt.legend(loc='best')
  4150. >>> plt.show()
  4151. """
  4152. xp = array_namespace(sos, x, zi)
  4153. x = _validate_x(x)
  4154. sos, n_sections = _validate_sos(sos)
  4155. x_zi_shape = list(x.shape)
  4156. x_zi_shape[axis] = 2
  4157. x_zi_shape = tuple([n_sections] + x_zi_shape)
  4158. inputs = [sos, x]
  4159. if zi is not None:
  4160. inputs.append(np.asarray(zi))
  4161. dtype = np.result_type(*inputs)
  4162. if dtype.char not in 'fdgFDGO':
  4163. raise NotImplementedError(f"input type '{dtype}' not supported")
  4164. if zi is not None:
  4165. zi = np.asarray(zi, dtype=dtype)
  4166. # make a copy so that we can operate in place
  4167. # NB: 1. use xp_copy to paper over numpy 1/2 copy= keyword
  4168. # 2. make sure the copied zi remains a numpy array
  4169. zi = xp_copy(zi, xp=array_namespace(zi))
  4170. if zi.shape != x_zi_shape:
  4171. raise ValueError(
  4172. f"Invalid zi shape. With axis={axis!r}, "
  4173. f"an input with shape {x.shape!r}, "
  4174. f"and an sos array with {n_sections} sections, zi must have "
  4175. f"shape {x_zi_shape!r}, got {zi.shape!r}."
  4176. )
  4177. return_zi = True
  4178. else:
  4179. zi = np.zeros(x_zi_shape, dtype=dtype)
  4180. return_zi = False
  4181. axis = axis % x.ndim # make positive
  4182. x = np.moveaxis(x, axis, -1)
  4183. zi = np.moveaxis(zi, (0, axis + 1), (-2, -1))
  4184. x_shape, zi_shape = x.shape, zi.shape
  4185. x = np.reshape(x, (-1, x.shape[-1]))
  4186. x = np.array(x, dtype, order='C') # make a copy, can modify in place
  4187. zi = np.ascontiguousarray(np.reshape(zi, (-1, n_sections, 2)))
  4188. sos = sos.astype(dtype, copy=False)
  4189. _sosfilt(sos, x, zi)
  4190. x = x.reshape(x_shape)
  4191. x = np.moveaxis(x, -1, axis)
  4192. if return_zi:
  4193. zi = zi.reshape(zi_shape)
  4194. zi = np.moveaxis(zi, (-2, -1), (0, axis + 1))
  4195. out = (xp.asarray(x), xp.asarray(zi))
  4196. else:
  4197. out = xp.asarray(x)
  4198. return out
  4199. def sosfiltfilt(sos, x, axis=-1, padtype='odd', padlen=None):
  4200. """
  4201. A forward-backward digital filter using cascaded second-order sections.
  4202. See `filtfilt` for more complete information about this method.
  4203. Parameters
  4204. ----------
  4205. sos : array_like
  4206. Array of second-order filter coefficients, must have shape
  4207. ``(n_sections, 6)``. Each row corresponds to a second-order
  4208. section, with the first three columns providing the numerator
  4209. coefficients and the last three providing the denominator
  4210. coefficients.
  4211. x : array_like
  4212. The array of data to be filtered.
  4213. axis : int, optional
  4214. The axis of `x` to which the filter is applied.
  4215. Default is -1.
  4216. padtype : str or None, optional
  4217. Must be 'odd', 'even', 'constant', or None. This determines the
  4218. type of extension to use for the padded signal to which the filter
  4219. is applied. If `padtype` is None, no padding is used. The default
  4220. is 'odd'.
  4221. padlen : int or None, optional
  4222. The number of elements by which to extend `x` at both ends of
  4223. `axis` before applying the filter. This value must be less than
  4224. ``x.shape[axis] - 1``. ``padlen=0`` implies no padding.
  4225. The default value is::
  4226. 3 * (2 * len(sos) + 1 - min((sos[:, 2] == 0).sum(),
  4227. (sos[:, 5] == 0).sum()))
  4228. The extra subtraction at the end attempts to compensate for poles
  4229. and zeros at the origin (e.g. for odd-order filters) to yield
  4230. equivalent estimates of `padlen` to those of `filtfilt` for
  4231. second-order section filters built with `scipy.signal` functions.
  4232. Returns
  4233. -------
  4234. y : ndarray
  4235. The filtered output with the same shape as `x`.
  4236. See Also
  4237. --------
  4238. filtfilt, sosfilt, sosfilt_zi, freqz_sos
  4239. Notes
  4240. -----
  4241. .. versionadded:: 0.18.0
  4242. Examples
  4243. --------
  4244. >>> import numpy as np
  4245. >>> from scipy.signal import sosfiltfilt, butter
  4246. >>> import matplotlib.pyplot as plt
  4247. >>> rng = np.random.default_rng()
  4248. Create an interesting signal to filter.
  4249. >>> n = 201
  4250. >>> t = np.linspace(0, 1, n)
  4251. >>> x = 1 + (t < 0.5) - 0.25*t**2 + 0.05*rng.standard_normal(n)
  4252. Create a lowpass Butterworth filter, and use it to filter `x`.
  4253. >>> sos = butter(4, 0.125, output='sos')
  4254. >>> y = sosfiltfilt(sos, x)
  4255. For comparison, apply an 8th order filter using `sosfilt`. The filter
  4256. is initialized using the mean of the first four values of `x`.
  4257. >>> from scipy.signal import sosfilt, sosfilt_zi
  4258. >>> sos8 = butter(8, 0.125, output='sos')
  4259. >>> zi = x[:4].mean() * sosfilt_zi(sos8)
  4260. >>> y2, zo = sosfilt(sos8, x, zi=zi)
  4261. Plot the results. Note that the phase of `y` matches the input, while
  4262. `y2` has a significant phase delay.
  4263. >>> plt.plot(t, x, alpha=0.5, label='x(t)')
  4264. >>> plt.plot(t, y, label='y(t)')
  4265. >>> plt.plot(t, y2, label='y2(t)')
  4266. >>> plt.legend(framealpha=1, shadow=True)
  4267. >>> plt.grid(alpha=0.25)
  4268. >>> plt.xlabel('t')
  4269. >>> plt.show()
  4270. """
  4271. xp = array_namespace(sos, x)
  4272. sos, n_sections = _validate_sos(sos)
  4273. x = _validate_x(x)
  4274. # `method` is "pad"...
  4275. ntaps = 2 * n_sections + 1
  4276. ntaps -= min((sos[:, 2] == 0).sum(), (sos[:, 5] == 0).sum())
  4277. edge, ext = _validate_pad(padtype, padlen, x, axis,
  4278. ntaps=ntaps)
  4279. # These steps follow the same form as filtfilt with modifications
  4280. zi = sosfilt_zi(sos) # shape (n_sections, 2) --> (n_sections, ..., 2, ...)
  4281. zi_shape = [1] * x.ndim
  4282. zi_shape[axis] = 2
  4283. zi = zi.reshape([n_sections] + zi_shape)
  4284. x_0 = axis_slice(ext, stop=1, axis=axis)
  4285. (y, zf) = sosfilt(sos, ext, axis=axis, zi=zi * x_0)
  4286. y_0 = axis_slice(y, start=-1, axis=axis)
  4287. (y, zf) = sosfilt(sos, axis_reverse(y, axis=axis), axis=axis, zi=zi * y_0)
  4288. y = axis_reverse(y, axis=axis)
  4289. if edge > 0:
  4290. y = axis_slice(y, start=edge, stop=-edge, axis=axis)
  4291. return xp.asarray(y)
  4292. def decimate(x, q, n=None, ftype='iir', axis=-1, zero_phase=True):
  4293. """
  4294. Downsample the signal after applying an anti-aliasing filter.
  4295. By default, an order 8 Chebyshev type I filter is used. A 30 point FIR
  4296. filter with Hamming window is used if `ftype` is 'fir'.
  4297. Parameters
  4298. ----------
  4299. x : array_like
  4300. The input signal made up of equidistant samples. If `x` is a multidimensional
  4301. array, the parameter `axis` specifies the time axis.
  4302. q : int
  4303. The downsampling factor, which is a postive integer. When using IIR
  4304. downsampling, it is recommended to call `decimate` multiple times for
  4305. downsampling factors higher than 13.
  4306. n : int, optional
  4307. The order of the filter (1 less than the length for 'fir'). Defaults to
  4308. 8 for 'iir' and 20 times the downsampling factor for 'fir'.
  4309. ftype : str {'iir', 'fir'} or ``dlti`` instance, optional
  4310. If 'iir' or 'fir', specifies the type of lowpass filter. If an instance
  4311. of an `dlti` object, uses that object to filter before downsampling.
  4312. axis : int, optional
  4313. The axis along which to decimate.
  4314. zero_phase : bool, optional
  4315. Prevent phase shift by filtering with `filtfilt` instead of `lfilter`
  4316. when using an IIR filter, and shifting the outputs back by the filter's
  4317. group delay when using an FIR filter. The default value of ``True`` is
  4318. recommended, since a phase shift is generally not desired.
  4319. .. versionadded:: 0.18.0
  4320. Returns
  4321. -------
  4322. y : ndarray
  4323. The down-sampled signal.
  4324. See Also
  4325. --------
  4326. resample : Resample up or down using the FFT method.
  4327. resample_poly : Resample using polyphase filtering and an FIR filter.
  4328. Notes
  4329. -----
  4330. For non-integer downsampling factors, `~scipy.signal.resample` can be used. Consult
  4331. the `scipy.interpolate` module for methods of resampling signals with non-constant
  4332. sampling intervals.
  4333. The ``zero_phase`` keyword was added in 0.18.0.
  4334. The possibility to use instances of ``dlti`` as ``ftype`` was added in
  4335. 0.18.0.
  4336. Examples
  4337. --------
  4338. >>> import numpy as np
  4339. >>> from scipy import signal
  4340. >>> import matplotlib.pyplot as plt
  4341. Define wave parameters.
  4342. >>> wave_duration = 3
  4343. >>> sample_rate = 100
  4344. >>> freq = 2
  4345. >>> q = 5
  4346. Calculate number of samples.
  4347. >>> samples = wave_duration*sample_rate
  4348. >>> samples_decimated = int(samples/q)
  4349. Create cosine wave.
  4350. >>> x = np.linspace(0, wave_duration, samples, endpoint=False)
  4351. >>> y = np.cos(x*np.pi*freq*2)
  4352. Decimate cosine wave.
  4353. >>> ydem = signal.decimate(y, q)
  4354. >>> xnew = np.linspace(0, wave_duration, samples_decimated, endpoint=False)
  4355. Plot original and decimated waves.
  4356. >>> plt.plot(x, y, '.-', xnew, ydem, 'o-')
  4357. >>> plt.xlabel('Time, Seconds')
  4358. >>> plt.legend(['data', 'decimated'], loc='best')
  4359. >>> plt.show()
  4360. """
  4361. x = np.asarray(x)
  4362. q = operator.index(q)
  4363. if n is not None:
  4364. n = operator.index(n)
  4365. result_type = x.dtype
  4366. if not np.issubdtype(result_type, np.inexact) \
  4367. or result_type.type == np.float16:
  4368. # upcast integers and float16 to float64
  4369. result_type = np.float64
  4370. if ftype == 'fir':
  4371. if n is None:
  4372. half_len = 10 * q # reasonable cutoff for our sinc-like function
  4373. n = 2 * half_len
  4374. b, a = firwin(n+1, 1. / q, window='hamming'), 1.
  4375. b = np.asarray(b, dtype=result_type)
  4376. a = np.asarray(a, dtype=result_type)
  4377. elif ftype == 'iir':
  4378. iir_use_sos = True
  4379. if n is None:
  4380. n = 8
  4381. sos = cheby1(n, 0.05, 0.8 / q, output='sos')
  4382. sos = np.asarray(sos, dtype=result_type)
  4383. elif isinstance(ftype, dlti):
  4384. system = ftype._as_zpk()
  4385. if system.poles.shape[0] == 0:
  4386. # FIR
  4387. system = ftype._as_tf()
  4388. b, a = system.num, system.den
  4389. ftype = 'fir'
  4390. elif (any(np.iscomplex(system.poles))
  4391. or any(np.iscomplex(system.poles))
  4392. or np.iscomplex(system.gain)):
  4393. # sosfilt & sosfiltfilt don't handle complex coeffs
  4394. iir_use_sos = False
  4395. system = ftype._as_tf()
  4396. b, a = system.num, system.den
  4397. else:
  4398. iir_use_sos = True
  4399. sos = zpk2sos(system.zeros, system.poles, system.gain)
  4400. sos = np.asarray(sos, dtype=result_type)
  4401. else:
  4402. raise ValueError('invalid ftype')
  4403. sl = [slice(None)] * x.ndim
  4404. if ftype == 'fir':
  4405. b = b / a
  4406. if zero_phase:
  4407. y = resample_poly(x, 1, q, axis=axis, window=b)
  4408. else:
  4409. # upfirdn is generally faster than lfilter by a factor equal to the
  4410. # downsampling factor, since it only calculates the needed outputs
  4411. n_out = x.shape[axis] // q + bool(x.shape[axis] % q)
  4412. y = upfirdn(b, x, up=1, down=q, axis=axis)
  4413. sl[axis] = slice(None, n_out, None)
  4414. else: # IIR case
  4415. if zero_phase:
  4416. if iir_use_sos:
  4417. y = sosfiltfilt(sos, x, axis=axis)
  4418. else:
  4419. y = filtfilt(b, a, x, axis=axis)
  4420. else:
  4421. if iir_use_sos:
  4422. y = sosfilt(sos, x, axis=axis)
  4423. else:
  4424. y = lfilter(b, a, x, axis=axis)
  4425. sl[axis] = slice(None, None, q)
  4426. return y[tuple(sl)]