_distribution_infrastructure.py 232 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477
74778477947804781478247834784478547864787478847894790479147924793479447954796479747984799480048014802480348044805480648074808480948104811481248134814481548164817481848194820482148224823482448254826482748284829483048314832483348344835483648374838483948404841484248434844484548464847484848494850485148524853485448554856485748584859486048614862486348644865486648674868486948704871487248734874487548764877487848794880488148824883488448854886488748884889489048914892489348944895489648974898489949004901490249034904490549064907490849094910491149124913491449154916491749184919492049214922492349244925492649274928492949304931493249334934493549364937493849394940494149424943494449454946494749484949495049514952495349544955495649574958495949604961496249634964496549664967496849694970497149724973497449754976497749784979498049814982498349844985498649874988498949904991499249934994499549964997499849995000500150025003500450055006500750085009501050115012501350145015501650175018501950205021502250235024502550265027502850295030503150325033503450355036503750385039504050415042504350445045504650475048504950505051505250535054505550565057505850595060506150625063506450655066506750685069507050715072507350745075507650775078507950805081508250835084508550865087508850895090509150925093509450955096509750985099510051015102510351045105510651075108510951105111511251135114511551165117511851195120512151225123512451255126512751285129513051315132513351345135513651375138513951405141514251435144514551465147514851495150515151525153515451555156515751585159516051615162516351645165516651675168516951705171517251735174517551765177517851795180518151825183518451855186518751885189519051915192519351945195519651975198519952005201520252035204520552065207520852095210521152125213521452155216521752185219522052215222522352245225522652275228522952305231523252335234523552365237523852395240524152425243524452455246524752485249525052515252525352545255525652575258525952605261526252635264526552665267526852695270527152725273527452755276527
75278527952805281528252835284528552865287528852895290529152925293529452955296529752985299530053015302530353045305530653075308530953105311531253135314531553165317531853195320532153225323532453255326532753285329533053315332533353345335533653375338533953405341534253435344534553465347534853495350535153525353535453555356535753585359536053615362536353645365536653675368536953705371537253735374537553765377537853795380538153825383538453855386538753885389539053915392539353945395539653975398539954005401540254035404540554065407540854095410541154125413541454155416541754185419542054215422542354245425542654275428542954305431543254335434543554365437543854395440544154425443544454455446544754485449545054515452545354545455545654575458545954605461546254635464546554665467546854695470547154725473547454755476547754785479548054815482548354845485548654875488548954905491549254935494549554965497549854995500550155025503550455055506550755085509551055115512551355145515551655175518551955205521552255235524552555265527552855295530553155325533553455355536553755385539554055415542554355445545554655475548554955505551555255535554555555565557555855595560556155625563556455655566556755685569557055715572557355745575557655775578557955805581558255835584558555865587558855895590559155925593559455955596559755985599560056015602560356045605560656075608560956105611561256135614561556165617561856195620562156225623562456255626562756285629563056315632563356345635563656375638563956405641564256435644564556465647564856495650565156525653565456555656565756585659566056615662566356645665566656675668566956705671567256735674567556765677567856795680568156825683568456855686568756885689569056915692569356945695569656975698569957005701570257035704570557065707570857095710571157125713571457155716571757185719572057215722572357245725572657275728572957305731573257335734573557365737573857395740574157425743574457455746574757485749575057515752575357545755575657575758575957605761576257635764576557665767576857695770577157725773577457755776577
757785779578057815782578357845785578657875788578957905791579257935794579557965797579857995800580158025803580458055806580758085809581058115812581358145815581658175818
  1. import functools
  2. from abc import ABC, abstractmethod
  3. from functools import cached_property
  4. from types import GenericAlias
  5. import inspect
  6. import math
  7. import numpy as np
  8. from numpy import inf
  9. from scipy._lib._array_api import xp_capabilities, xp_promote
  10. from scipy._lib._util import _rng_spawn, _RichResult
  11. from scipy._lib._docscrape import ClassDoc, NumpyDocString
  12. import scipy._lib.array_api_extra as xpx
  13. from scipy import special, stats
  14. from scipy.special._ufuncs import _log1mexp
  15. from scipy.integrate import tanhsinh as _tanhsinh, nsum
  16. from scipy.optimize._bracket import _bracket_root, _bracket_minimum
  17. from scipy.optimize._chandrupatla import _chandrupatla, _chandrupatla_minimize
  18. from scipy.stats._probability_distribution import _ProbabilityDistribution
  19. from scipy.stats import qmc
  20. # in case we need to distinguish between None and not specified
  21. # Typically this is used to determine whether the tolerance has been set by the
  22. # user and make a decision about which method to use to evaluate a distribution
  23. # function. Sometimes, the logic does not consider the value of the tolerance,
  24. # only whether this has been defined or not. This is not intended to be the
  25. # best possible logic; the intent is to establish the structure, which can
  26. # be refined in follow-up work.
  27. # See https://github.com/scipy/scipy/pull/21050#discussion_r1714195433.
  28. _null = object()
  29. def _isnull(x):
  30. return type(x) is object or x is None
  31. __all__ = ['make_distribution', 'Mixture', 'order_statistic',
  32. 'truncate', 'abs', 'exp', 'log']
  33. # Could add other policies for broadcasting and edge/out-of-bounds case handling
  34. # For instance, when edge case handling is known not to be needed, it's much
  35. # faster to turn it off, but it might still be nice to have array conversion
  36. # and shaping done so the user doesn't need to be so careful.
  37. _SKIP_ALL = "skip_all"
  38. # Other cache policies would be useful, too.
  39. _NO_CACHE = "no_cache"
  40. # TODO:
  41. # Test sample dtypes
  42. # Add dtype kwarg (especially for distributions with no parameters)
  43. # When drawing endpoint/out-of-bounds values of a parameter, draw them from
  44. # the endpoints/out-of-bounds region of the full `domain`, not `typical`.
  45. # Distributions without shape parameters probably need to accept a `dtype` parameter;
  46. # right now they default to float64. If we have them default to float16, they will
  47. # need to determine result_type when input is not float16 (overhead).
  48. # Test _solve_bounded bracket logic, and decide what to do about warnings
  49. # Get test coverage to 100%
  50. # Raise when distribution method returns wrong shape/dtype?
  51. # Consider ensuring everything is at least 1D for calculations? Would avoid needing
  52. # to sprinkle `np.asarray` throughout due to indiscriminate conversion of 0D arrays
  53. # to scalars
  54. # Break up `test_basic`: test each method separately
  55. # Fix `sample` for QMCEngine (implementation does not match documentation)
  56. # When a parameter is invalid, set only the offending parameter to NaN (if possible)?
  57. # `_tanhsinh` special case when there are no abscissae between the limits
  58. # example: cdf of uniform between 1.0 and np.nextafter(1.0, np.inf)
  59. # check behavior of moment methods when moments are undefined/infinite -
  60. # basically OK but needs tests
  61. # investigate use of median
  62. # implement symmetric distribution
  63. # implement composite distribution
  64. # implement wrapped distribution
  65. # profile/optimize
  66. # general cleanup (choose keyword-only parameters)
  67. # compare old/new distribution timing
  68. # make video
  69. # add array API support
  70. # why does dist.ilogcdf(-100) not converge to bound? Check solver response to inf
  71. # _chandrupatla_minimize should not report xm = fm = NaN when it fails
  72. # integrate `logmoment` into `moment`? (Not hard, but enough time and code
  73. # complexity to wait for reviewer feedback before adding.)
  74. # Eliminate bracket_root error "`min <= a < b <= max` must be True"
  75. # Test repr?
  76. # use `median` information to improve integration? In some cases this will
  77. # speed things up. If it's not needed, it may be about twice as slow. I think
  78. # it should depend on the accuracy setting.
  79. # in tests, check reference value against that produced using np.vectorize?
  80. # add `axis` to `ks_1samp`
  81. # User tips for faster execution:
  82. # - pass NumPy arrays
  83. # - pass inputs of floating point type (not integers)
  84. # - prefer NumPy scalars or 0d arrays over other size 1 arrays
  85. # - pass no invalid parameters and disable invalid parameter checks with iv_profile
  86. # - provide a Generator if you're going to do sampling
  87. # add options for drawing parameters: log-spacing
  88. # accuracy benchmark suite
  89. # Should caches be attributes so we can more easily ensure that they are not
  90. # modified when caching is turned off?
  91. # Make ShiftedScaledDistribution more efficient - only process underlying
  92. # distribution parameters as necessary.
  93. # Reconsider `all_inclusive`
  94. # Should process_parameters update kwargs rather than returning? Should we
  95. # update parameters rather than setting to what process_parameters returns?
  96. # Questions:
  97. # 1. I override `__getattr__` so that distribution parameters can be read as
  98. # attributes. We don't want users to try to change them.
  99. # - To prevent replacements (dist.a = b), I could override `__setattr__`.
  100. # - To prevent in-place modifications, `__getattr__` could return a copy,
  101. # or it could set the WRITEABLE flag of the array to false.
  102. # Which should I do?
  103. # 2. `cache_policy` is supported in several methods where I imagine it being
  104. # useful, but it needs to be tested. Before doing that:
  105. # - What should the default value be?
  106. # - What should the other values be?
  107. # Or should we just eliminate this policy?
  108. # 3. `validation_policy` is supported in a few places, but it should be checked for
  109. # consistency. I have the same questions as for `cache_policy`.
  110. # 4. `tol` is currently notional. I think there needs to be a way to set
  111. # separate `atol` and `rtol`. Some ways I imagine it being used:
  112. # - Values can be passed to iterative functions (quadrature, root-finder).
  113. # - To control which "method" of a distribution function is used. For
  114. # example, if `atol` is set to `1e-12`, it may be acceptable to compute
  115. # the complementary CDF as 1 - CDF even when CDF is nearly 1; otherwise,
  116. # a (potentially more time-consuming) method would need to be used.
  117. # I'm looking for unified suggestions for the interface, not ad hoc ideas
  118. # for using tolerances. Suppose the user wants to have more control over
  119. # the tolerances used for each method - how do they specify it? It would
  120. # probably be easiest for the user if they could pass tolerances into each
  121. # method, but it's easiest for us if they can only set it as a property of
  122. # the class. Perhaps a dictionary of tolerance settings?
  123. # 5. I also envision that accuracy estimates should be reported to the user
  124. # somehow. I think my preference would be to return a subclass of an array
  125. # with an `error` attribute - yes, really. But this is unlikely to be
  126. # popular, so what are other ideas? Again, we need a unified vision here,
  127. # not just pointing out difficulties (not all errors are known or easy
  128. # to estimate, what to do when errors could compound, etc.).
  129. # 6. The term "method" is used to refer to public instance functions,
  130. # private instance functions, the "method" string argument, and the means
  131. # of calculating the desired quantity (represented by the string argument).
  132. # For the sake of disambiguation, shall I rename the "method" string to
  133. # "strategy" and refer to the means of calculating the quantity as the
  134. # "strategy"?
  135. # Originally, I planned to filter out invalid distribution parameters;
  136. # distribution implementation functions would always work with "compressed",
  137. # 1D arrays containing only valid distribution parameters. There are two
  138. # problems with this:
  139. # - This essentially requires copying all arrays, even if there is only a
  140. # single invalid parameter combination. This is expensive. Then, to output
  141. # the original size data to the user, we need to "decompress" the arrays
  142. # and fill in the NaNs, so more copying. Unless we branch the code when
  143. # there are no invalid data, these copies happen even in the normal case,
  144. # where there are no invalid parameter combinations. We should not incur
  145. # all this overhead in the normal case.
  146. # - For methods that accept arguments other than distribution parameters, the
  147. # user will pass in arrays that are broadcastable with the original arrays,
  148. # not the compressed arrays. This means that this same sort of invalid
  149. # value detection needs to be repeated every time one of these methods is
  150. # called.
  151. # The much simpler solution is to keep the data uncompressed but to replace
  152. # the invalid parameters and arguments with NaNs (and only if some are
  153. # invalid). With this approach, the copying happens only if/when it is
  154. # needed. Most functions involved in stats distribution calculations don't
  155. # mind NaNs; they just return NaN. The behavior "If x_i is NaN, the result
  156. # is NaN" is explicit in the array API. So this should be fine.
  157. #
  158. # Currently, I am still leaving the parameters and function arguments
  159. # in their broadcasted shapes rather than, say, raveling. The intent
  160. # is to avoid back and forth reshaping. If authors of distributions have
  161. # trouble dealing with N-D arrays, we can reconsider this.
  162. #
  163. # Another important decision is that the *private* methods must accept
  164. # the distribution parameters as inputs rather than relying on these
  165. # cached properties directly (although the public methods typically pass
  166. # the cached values to the private methods). This is because the elementwise
  167. # algorithms for quadrature, differentiation, root-finding, and minimization
  168. # prefer that the input functions are strictly elementwise in the sense
  169. # that the value output for a given input element does not depend on the
  170. # shape of the input or that element's location within the input array.
  171. # When the computation has converged for an element, it is removed from
  172. # the computation entirely. As a result, the shape of the arrays passed to
  173. # the function will almost never be broadcastable with the shape of the
  174. # cached parameter arrays.
  175. #
  176. # I've sprinkled in some optimizations for scalars and same-shape/type arrays
  177. # throughout. The biggest time sinks before were:
  178. # - broadcast_arrays
  179. # - result_dtype
  180. # - is_subdtype
  181. # It is much faster to check whether these are necessary than to do them.
  182. class _Domain(ABC):
  183. r""" Representation of the applicable domain of a parameter or variable.
  184. A `_Domain` object is responsible for storing information about the
  185. domain of a parameter or variable, determining whether a value is within
  186. the domain (`contains`), and providing a text/mathematical representation
  187. of itself (`__str__`). Because the domain of a parameter/variable can have
  188. a complicated relationship with other parameters and variables of a
  189. distribution, `_Domain` itself does not try to represent all possibilities;
  190. in fact, it has no implementation and is meant for subclassing.
  191. Attributes
  192. ----------
  193. symbols : dict
  194. A map from special numerical values to symbols for use in `__str__`
  195. Methods
  196. -------
  197. contains(x)
  198. Determine whether the argument is contained within the domain (True)
  199. or not (False). Used for input validation.
  200. get_numerical_endpoints()
  201. Gets the numerical values of the domain endpoints, which may have been
  202. defined symbolically or through a callable.
  203. __str__()
  204. Returns a text representation of the domain (e.g. ``[0, b)``).
  205. Used for generating documentation.
  206. """
  207. symbols = {np.inf: r"\infty", -np.inf: r"-\infty", np.pi: r"\pi", -np.pi: r"-\pi"}
  208. # generic type compatibility with scipy-stubs
  209. __class_getitem__ = classmethod(GenericAlias)
  210. @abstractmethod
  211. def contains(self, x):
  212. raise NotImplementedError()
  213. @abstractmethod
  214. def draw(self, n):
  215. raise NotImplementedError()
  216. @abstractmethod
  217. def get_numerical_endpoints(self, x):
  218. raise NotImplementedError()
  219. @abstractmethod
  220. def __str__(self):
  221. raise NotImplementedError()
  222. class _Interval(_Domain):
  223. r""" Representation of an interval defined by two endpoints.
  224. Each endpoint may be a finite scalar, positive or negative infinity, or
  225. be given by a single parameter. The domain may include the endpoints or
  226. not.
  227. This class still does not provide an implementation of the __str__ method,
  228. so it is meant for subclassing (e.g. a subclass for domains on the real
  229. line).
  230. Attributes
  231. ----------
  232. symbols : dict
  233. Inherited. A map from special values to symbols for use in `__str__`.
  234. endpoints : 2-tuple of float(s) and/or str(s) and/or callable(s).
  235. A tuple with two values. Each may be either a float (the numerical
  236. value of the endpoints of the domain), a string (the name of the
  237. parameters that will define the endpoint), or a callable taking the
  238. parameters used to define the endpoints of the domain as keyword only
  239. arguments and returning a numerical value for the endpoint.
  240. inclusive : 2-tuple of bools
  241. A tuple with two boolean values; each indicates whether the
  242. corresponding endpoint is included within the domain or not.
  243. Methods
  244. -------
  245. define_parameters(*parameters)
  246. Records any parameters used to define the endpoints of the domain
  247. get_numerical_endpoints(parameter_values)
  248. Gets the numerical values of the domain endpoints, which may have been
  249. defined symbolically or through a callable.
  250. contains(item, parameter_values)
  251. Determines whether the argument is contained within the domain
  252. draw(size, rng, proportions, parameter_values)
  253. Draws random values based on the domain.
  254. """
  255. def __init__(self, endpoints=(-inf, inf), inclusive=(False, False)):
  256. self.symbols = super().symbols.copy()
  257. a, b = endpoints
  258. self.endpoints = np.asarray(a)[()], np.asarray(b)[()]
  259. self.inclusive = inclusive
  260. def define_parameters(self, *parameters):
  261. r""" Records any parameters used to define the endpoints of the domain.
  262. Adds the keyword name of each parameter and its text representation
  263. to the `symbols` attribute as key:value pairs.
  264. For instance, a parameter may be passed into to a distribution's
  265. initializer using the keyword `log_a`, and the corresponding
  266. string representation may be '\log(a)'. To form the text
  267. representation of the domain for use in documentation, the
  268. _Domain object needs to map from the keyword name used in the code
  269. to the string representation.
  270. Returns None, but updates the `symbols` attribute.
  271. Parameters
  272. ----------
  273. *parameters : _Parameter objects
  274. Parameters that may define the endpoints of the domain.
  275. """
  276. new_symbols = {param.name: param.symbol for param in parameters}
  277. self.symbols.update(new_symbols)
  278. def get_numerical_endpoints(self, parameter_values):
  279. r""" Get the numerical values of the domain endpoints.
  280. Domain endpoints may be defined symbolically or through a callable.
  281. This returns numerical values of the endpoints given numerical values for
  282. any variables.
  283. Parameters
  284. ----------
  285. parameter_values : dict
  286. A dictionary that maps between string variable names and numerical
  287. values of parameters, which may define the endpoints.
  288. Returns
  289. -------
  290. a, b : ndarray
  291. Numerical values of the endpoints
  292. """
  293. a, b = self.endpoints
  294. # If `a` (`b`) is a string - the name of the parameter that defines
  295. # the endpoint of the domain - then corresponding numerical values
  296. # will be found in the `parameter_values` dictionary.
  297. # If a callable, it will be executed with `parameter_values` passed as
  298. # keyword arguments, and it will return the numerical values.
  299. # Otherwise, it is itself the array of numerical values of the endpoint.
  300. try:
  301. if callable(a):
  302. a = a(**parameter_values)
  303. else:
  304. a = np.asarray(parameter_values.get(a, a))
  305. if callable(b):
  306. b = b(**parameter_values)
  307. else:
  308. b = np.asarray(parameter_values.get(b, b))
  309. except TypeError as e:
  310. message = ("The endpoints of the distribution are defined by "
  311. "parameters, but their values were not provided. When "
  312. f"using a private method of {self.__class__}, pass "
  313. "all required distribution parameters as keyword "
  314. "arguments.")
  315. raise TypeError(message) from e
  316. # Floating point types are used for even integer parameters.
  317. # Convert to float here to ensure consistency throughout framework.
  318. a, b = xp_promote(a, b, force_floating=True, xp=np)
  319. return a, b
  320. def contains(self, item, parameter_values=None):
  321. r"""Determine whether the argument is contained within the domain.
  322. Parameters
  323. ----------
  324. item : ndarray
  325. The argument
  326. parameter_values : dict
  327. A dictionary that maps between string variable names and numerical
  328. values of parameters, which may define the endpoints.
  329. Returns
  330. -------
  331. out : bool
  332. True if `item` is within the domain; False otherwise.
  333. """
  334. parameter_values = parameter_values or {}
  335. # if self.all_inclusive:
  336. # # Returning a 0d value here makes things much faster.
  337. # # I'm not sure if it's safe, though. If it causes a bug someday,
  338. # # I guess it wasn't.
  339. # # Even if there is no bug because of the shape, it is incorrect for
  340. # # `contains` to return True when there are invalid (e.g. NaN)
  341. # # parameters.
  342. # return np.asarray(True)
  343. a, b = self.get_numerical_endpoints(parameter_values)
  344. left_inclusive, right_inclusive = self.inclusive
  345. in_left = item >= a if left_inclusive else item > a
  346. in_right = item <= b if right_inclusive else item < b
  347. return in_left & in_right
  348. def draw(self, n, type_, min, max, squeezed_base_shape, rng=None):
  349. r""" Draw random values from the domain.
  350. Parameters
  351. ----------
  352. n : int
  353. The number of values to be drawn from the domain.
  354. type_ : str
  355. A string indicating whether the values are
  356. - strictly within the domain ('in'),
  357. - at one of the two endpoints ('on'),
  358. - strictly outside the domain ('out'), or
  359. - NaN ('nan').
  360. min, max : ndarray
  361. The endpoints of the domain.
  362. squeezed_based_shape : tuple of ints
  363. See _RealParameter.draw.
  364. rng : np.Generator
  365. The Generator used for drawing random values.
  366. """
  367. rng = np.random.default_rng(rng)
  368. def ints(*args, **kwargs): return rng.integers(*args, **kwargs, endpoint=True)
  369. uniform = rng.uniform if isinstance(self, _RealInterval) else ints
  370. # get copies of min and max with no nans so that uniform doesn't fail
  371. min_nn, max_nn = min.copy(), max.copy()
  372. i = np.isnan(min_nn) | np.isnan(max_nn)
  373. min_nn[i] = 0
  374. max_nn[i] = 1
  375. shape = (n,) + squeezed_base_shape
  376. if type_ == 'in':
  377. z = uniform(min_nn, max_nn, size=shape)
  378. elif type_ == 'on':
  379. z_on_shape = shape
  380. z = np.ones(z_on_shape)
  381. i = rng.random(size=n) < 0.5
  382. z[i] = min
  383. z[~i] = max
  384. elif type_ == 'out':
  385. z = min_nn - uniform(1, 5, size=shape) # 1, 5 is arbitary; we just want
  386. zr = max_nn + uniform(1, 5, size=shape) # some numbers outside domain
  387. i = rng.random(size=n) < 0.5
  388. z[i] = zr[i]
  389. elif type_ == 'nan':
  390. z = np.full(shape, np.nan)
  391. return z
  392. class _RealInterval(_Interval):
  393. r""" Represents a simply-connected subset of the real line; i.e., an interval
  394. Completes the implementation of the `_Interval` class for intervals
  395. on the real line.
  396. Methods
  397. -------
  398. define_parameters(*parameters)
  399. (Inherited) Records any parameters used to define the endpoints of the
  400. domain.
  401. get_numerical_endpoints(parameter_values)
  402. (Inherited) Gets the numerical values of the domain endpoints, which
  403. may have been defined symbolically.
  404. contains(item, parameter_values)
  405. (Inherited) Determines whether the argument is contained within the
  406. domain
  407. __str__()
  408. Returns a string representation of the domain, e.g. "[a, b)".
  409. """
  410. def __str__(self):
  411. a, b = self.endpoints
  412. a, b = self._get_endpoint_str(a, "f1"), self._get_endpoint_str(b, "f2")
  413. left_inclusive, right_inclusive = self.inclusive
  414. left = "[" if left_inclusive else "("
  415. right = "]" if right_inclusive else ")"
  416. return f"{left}{a}, {b}{right}"
  417. def _get_endpoint_str(self, endpoint, funcname):
  418. if callable(endpoint):
  419. if endpoint.__doc__ is not None:
  420. return endpoint.__doc__
  421. params = inspect.signature(endpoint).parameters.values()
  422. params = [
  423. p.name for p in params if p.kind == inspect.Parameter.KEYWORD_ONLY
  424. ]
  425. return f"{funcname}({','.join(params)})"
  426. return self.symbols.get(endpoint, f"{endpoint}")
  427. class _IntegerInterval(_Interval):
  428. r""" Represents an interval of integers
  429. Completes the implementation of the `_Interval` class for simple
  430. domains on the integers.
  431. Methods
  432. -------
  433. define_parameters(*parameters)
  434. (Inherited) Records any parameters used to define the endpoints of the
  435. domain.
  436. get_numerical_endpoints(parameter_values)
  437. (Inherited) Gets the numerical values of the domain endpoints, which
  438. may have been defined symbolically.
  439. contains(item, parameter_values)
  440. (Overridden) Determines whether the argument is contained within the
  441. domain
  442. draw(n, type_, min, max, squeezed_base_shape, rng=None)
  443. (Inherited) Draws random values based on the domain.
  444. __str__()
  445. Returns a string representation of the domain, e.g. "{a, a+1, ..., b-1, b}".
  446. """
  447. def contains(self, item, parameter_values=None):
  448. super_contains = super().contains(item, parameter_values)
  449. integral = (item == np.round(item))
  450. return super_contains & integral
  451. def __str__(self):
  452. a, b = self.endpoints
  453. a = self.symbols.get(a, a)
  454. b = self.symbols.get(b, b)
  455. a_str, b_str = isinstance(a, str), isinstance(b, str)
  456. a_inf = a == r"-\infty" if a_str else np.isinf(a)
  457. b_inf = b == r"\infty" if b_str else np.isinf(b)
  458. # This doesn't work well for cases where ``a`` is floating point
  459. # number large enough that ``nextafter(a, inf) > a + 1``, and
  460. # similarly for ``b`` and nextafter(b, -inf). There may not be any
  461. # distributions fit for SciPy where we would actually need to handle these
  462. # cases though.
  463. ap1 = f"{a} + 1" if a_str else f"{a + 1}"
  464. bm1 = f"{b} - 1" if b_str else f"{b - 1}"
  465. if not a_str and not b_str:
  466. gap = b - a
  467. if gap == 3:
  468. return f"\\{{{a}, {ap1}, {bm1}, {b}\\}}"
  469. if gap == 2:
  470. return f"\\{{{a}, {ap1}, {b}\\}}"
  471. if gap == 1:
  472. return f"\\{{{a}, {b}\\}}"
  473. if gap == 0:
  474. return f"\\{{{a}\\}}"
  475. if not a_inf and b_inf:
  476. ap2 = f"{a} + 2" if a_str else f"{a + 2}"
  477. return f"\\{{{a}, {ap1}, {ap2}, ...\\}}"
  478. if a_inf and not b_inf:
  479. bm2 = f"{b} - 2" if b_str else f"{b - 2}"
  480. return f"\\{{{b}, {bm1}, {bm2}, ...\\}}"
  481. if a_inf and b_inf:
  482. return "\\{..., -2, -1, 0, 1, 2, ...\\}"
  483. return f"\\{{{a}, {ap1}, ..., {bm1}, {b}\\}}"
  484. class _Parameter(ABC):
  485. r""" Representation of a distribution parameter or variable.
  486. A `_Parameter` object is responsible for storing information about a
  487. parameter or variable, providing input validation/standardization of
  488. values passed for that parameter, providing a text/mathematical
  489. representation of the parameter for the documentation (`__str__`), and
  490. drawing random values of itself for testing and benchmarking. It does
  491. not provide a complete implementation of this functionality and is meant
  492. for subclassing.
  493. Attributes
  494. ----------
  495. name : str
  496. The keyword used to pass numerical values of the parameter into the
  497. initializer of the distribution
  498. symbol : str
  499. The text representation of the variable in the documentation. May
  500. include LaTeX.
  501. domain : _Domain
  502. The domain of the parameter for which the distribution is valid.
  503. typical : 2-tuple of floats or strings (consider making a _Domain)
  504. Defines the endpoints of a typical range of values of the parameter.
  505. Used for sampling.
  506. Methods
  507. -------
  508. __str__():
  509. Returns a string description of the variable for use in documentation,
  510. including the keyword used to represent it in code, the symbol used to
  511. represent it mathemtatically, and a description of the valid domain.
  512. draw(size, *, rng, domain, proportions)
  513. Draws random values of the parameter. Proportions of values within
  514. the valid domain, on the endpoints of the domain, outside the domain,
  515. and having value NaN are specified by `proportions`.
  516. validate(x):
  517. Validates and standardizes the argument for use as numerical values
  518. of the parameter.
  519. """
  520. # generic type compatibility with scipy-stubs
  521. __class_getitem__ = classmethod(GenericAlias)
  522. def __init__(self, name, *, domain, symbol=None, typical=None):
  523. self.name = name
  524. self.symbol = symbol or name
  525. self.domain = domain
  526. if typical is not None and not isinstance(typical, _Domain):
  527. typical = domain.__class__(typical)
  528. self.typical = typical or domain
  529. def __str__(self):
  530. r""" String representation of the parameter for use in documentation."""
  531. return f"`{self.name}` for :math:`{self.symbol} \\in {str(self.domain)}`"
  532. def draw(self, size=None, *, rng=None, region='domain', proportions=None,
  533. parameter_values=None):
  534. r""" Draw random values of the parameter for use in testing.
  535. Parameters
  536. ----------
  537. size : tuple of ints
  538. The shape of the array of valid values to be drawn.
  539. rng : np.Generator
  540. The Generator used for drawing random values.
  541. region : str
  542. The region of the `_Parameter` from which to draw. Default is
  543. "domain" (the *full* domain); alternative is "typical". An
  544. enhancement would give a way to interpolate between the two.
  545. proportions : tuple of numbers
  546. A tuple of four non-negative numbers that indicate the expected
  547. relative proportion of elements that:
  548. - are strictly within the domain,
  549. - are at one of the two endpoints,
  550. - are strictly outside the domain, and
  551. - are NaN,
  552. respectively. Default is (1, 0, 0, 0). The number of elements in
  553. each category is drawn from the multinomial distribution with
  554. `np.prod(size)` as the number of trials and `proportions` as the
  555. event probabilities. The values in `proportions` are automatically
  556. normalized to sum to 1.
  557. parameter_values : dict
  558. Map between the names of parameters (that define the endpoints of
  559. `typical`) and numerical values (arrays).
  560. """
  561. parameter_values = parameter_values or {}
  562. domain = self.domain
  563. proportions = (1, 0, 0, 0) if proportions is None else proportions
  564. pvals = proportions / np.sum(proportions)
  565. a, b = domain.get_numerical_endpoints(parameter_values)
  566. a, b = np.broadcast_arrays(a, b)
  567. base_shape = a.shape
  568. extended_shape = np.broadcast_shapes(size, base_shape)
  569. n_extended = np.prod(extended_shape)
  570. n_base = np.prod(base_shape)
  571. n = int(n_extended / n_base) if n_extended else 0
  572. rng = np.random.default_rng(rng)
  573. n_in, n_on, n_out, n_nan = rng.multinomial(n, pvals)
  574. # `min` and `max` can have singleton dimensions that correspond with
  575. # non-singleton dimensions in `size`. We need to be careful to avoid
  576. # shuffling results (e.g. a value that was generated for the domain
  577. # [min[i], max[i]] ends up at index j). To avoid this:
  578. # - Squeeze the singleton dimensions out of `min`/`max`. Squeezing is
  579. # often not the right thing to do, but here is equivalent to moving
  580. # all the dimensions that are singleton in `min`/`max` (which may be
  581. # non-singleton in the result) to the left. This is what we want.
  582. # - Now all the non-singleton dimensions of the result are on the left.
  583. # Ravel them to a single dimension of length `n`, which is now along
  584. # the 0th axis.
  585. # - Reshape the 0th axis back to the required dimensions, and move
  586. # these axes back to their original places.
  587. base_shape_padded = ((1,)*(len(extended_shape) - len(base_shape))
  588. + base_shape)
  589. base_singletons = np.where(np.asarray(base_shape_padded)==1)[0]
  590. new_base_singletons = tuple(range(len(base_singletons)))
  591. # Base singleton dimensions are going to get expanded to these lengths
  592. shape_expansion = np.asarray(extended_shape)[base_singletons]
  593. # assert(np.prod(shape_expansion) == n) # check understanding
  594. # min = np.reshape(min, base_shape_padded)
  595. # max = np.reshape(max, base_shape_padded)
  596. # min = np.moveaxis(min, base_singletons, new_base_singletons)
  597. # max = np.moveaxis(max, base_singletons, new_base_singletons)
  598. # squeezed_base_shape = max.shape[len(base_singletons):]
  599. # assert np.all(min.reshape(squeezed_base_shape) == min.squeeze())
  600. # assert np.all(max.reshape(squeezed_base_shape) == max.squeeze())
  601. # min = np.maximum(a, _fiinfo(a).min/10) if np.any(np.isinf(a)) else a
  602. # max = np.minimum(b, _fiinfo(b).max/10) if np.any(np.isinf(b)) else b
  603. min = np.asarray(a.squeeze())
  604. max = np.asarray(b.squeeze())
  605. squeezed_base_shape = max.shape
  606. if region == 'typical':
  607. typical = self.typical
  608. a, b = typical.get_numerical_endpoints(parameter_values)
  609. a, b = np.broadcast_arrays(a, b)
  610. min_here = np.asarray(a.squeeze())
  611. max_here = np.asarray(b.squeeze())
  612. z_in = typical.draw(n_in, 'in', min_here, max_here, squeezed_base_shape,
  613. rng=rng)
  614. else:
  615. z_in = domain.draw(n_in, 'in', min, max, squeezed_base_shape, rng=rng)
  616. z_on = domain.draw(n_on, 'on', min, max, squeezed_base_shape, rng=rng)
  617. z_out = domain.draw(n_out, 'out', min, max, squeezed_base_shape, rng=rng)
  618. z_nan= domain.draw(n_nan, 'nan', min, max, squeezed_base_shape, rng=rng)
  619. z = np.concatenate((z_in, z_on, z_out, z_nan), axis=0)
  620. z = rng.permuted(z, axis=0)
  621. z = np.reshape(z, tuple(shape_expansion) + squeezed_base_shape)
  622. z = np.moveaxis(z, new_base_singletons, base_singletons)
  623. return z
  624. @abstractmethod
  625. def validate(self, arr):
  626. raise NotImplementedError()
  627. class _RealParameter(_Parameter):
  628. r""" Represents a real-valued parameter.
  629. Implements the remaining methods of _Parameter for real parameters.
  630. All attributes are inherited.
  631. """
  632. def validate(self, arr, parameter_values):
  633. r""" Input validation/standardization of numerical values of a parameter.
  634. Checks whether elements of the argument `arr` are reals, ensuring that
  635. the dtype reflects this. Also produces a logical array that indicates
  636. which elements meet the requirements.
  637. Parameters
  638. ----------
  639. arr : ndarray
  640. The argument array to be validated and standardized.
  641. parameter_values : dict
  642. Map of parameter names to parameter value arrays.
  643. Returns
  644. -------
  645. arr : ndarray
  646. The argument array that has been validated and standardized
  647. (converted to an appropriate dtype, if necessary).
  648. dtype : NumPy dtype
  649. The appropriate floating point dtype of the parameter.
  650. valid : boolean ndarray
  651. Logical array indicating which elements are valid (True) and
  652. which are not (False). The arrays of all distribution parameters
  653. will be broadcasted, and elements for which any parameter value
  654. does not meet the requirements will be replaced with NaN.
  655. """
  656. arr = np.asarray(arr)
  657. valid_dtype = None
  658. # minor optimization - fast track the most common types to avoid
  659. # overhead of np.issubdtype. Checking for `in {...}` doesn't work : /
  660. if arr.dtype == np.float64 or arr.dtype == np.float32:
  661. pass
  662. elif arr.dtype == np.int32 or arr.dtype == np.int64:
  663. arr = np.asarray(arr, dtype=np.float64)
  664. elif np.issubdtype(arr.dtype, np.floating):
  665. pass
  666. elif np.issubdtype(arr.dtype, np.integer):
  667. arr = np.asarray(arr, dtype=np.float64)
  668. else:
  669. message = f"Parameter `{self.name}` must be of real dtype."
  670. raise TypeError(message)
  671. valid = self.domain.contains(arr, parameter_values)
  672. valid = valid & valid_dtype if valid_dtype is not None else valid
  673. return arr[()], arr.dtype, valid
  674. class _Parameterization:
  675. r""" Represents a parameterization of a distribution.
  676. Distributions can have multiple parameterizations. A `_Parameterization`
  677. object is responsible for recording the parameters used by the
  678. parameterization, checking whether keyword arguments passed to the
  679. distribution match the parameterization, and performing input validation
  680. of the numerical values of these parameters.
  681. Attributes
  682. ----------
  683. parameters : dict
  684. String names (of keyword arguments) and the corresponding _Parameters.
  685. Methods
  686. -------
  687. __len__()
  688. Returns the number of parameters in the parameterization.
  689. __str__()
  690. Returns a string representation of the parameterization.
  691. copy
  692. Returns a copy of the parameterization. This is needed for transformed
  693. distributions that add parameters to the parameterization.
  694. matches(parameters)
  695. Checks whether the keyword arguments match the parameterization.
  696. validation(parameter_values)
  697. Input validation / standardization of parameterization. Validates the
  698. numerical values of all parameters.
  699. draw(sizes, rng, proportions)
  700. Draw random values of all parameters of the parameterization for use
  701. in testing.
  702. """
  703. def __init__(self, *parameters):
  704. self.parameters = {param.name: param for param in parameters}
  705. def __len__(self):
  706. return len(self.parameters)
  707. def copy(self):
  708. return _Parameterization(*self.parameters.values())
  709. def matches(self, parameters):
  710. r""" Checks whether the keyword arguments match the parameterization.
  711. Parameters
  712. ----------
  713. parameters : set
  714. Set of names of parameters passed into the distribution as keyword
  715. arguments.
  716. Returns
  717. -------
  718. out : bool
  719. True if the keyword arguments names match the names of the
  720. parameters of this parameterization.
  721. """
  722. return parameters == set(self.parameters.keys())
  723. def validation(self, parameter_values):
  724. r""" Input validation / standardization of parameterization.
  725. Parameters
  726. ----------
  727. parameter_values : dict
  728. The keyword arguments passed as parameter values to the
  729. distribution.
  730. Returns
  731. -------
  732. all_valid : ndarray
  733. Logical array indicating the elements of the broadcasted arrays
  734. for which all parameter values are valid.
  735. dtype : dtype
  736. The common dtype of the parameter arrays. This will determine
  737. the dtype of the output of distribution methods.
  738. """
  739. all_valid = True
  740. dtypes = set() # avoid np.result_type if there's only one type
  741. for name, arr in parameter_values.items():
  742. parameter = self.parameters[name]
  743. arr, dtype, valid = parameter.validate(arr, parameter_values)
  744. dtypes.add(dtype)
  745. all_valid = all_valid & valid
  746. parameter_values[name] = arr
  747. dtype = arr.dtype if len(dtypes)==1 else np.result_type(*list(dtypes))
  748. return all_valid, dtype
  749. def __str__(self):
  750. r"""Returns a string representation of the parameterization."""
  751. messages = [str(param) for name, param in self.parameters.items()]
  752. return ", ".join(messages)
  753. def draw(self, sizes=None, rng=None, proportions=None, region='domain'):
  754. r"""Draw random values of all parameters for use in testing.
  755. Parameters
  756. ----------
  757. sizes : iterable of shape tuples
  758. The size of the array to be generated for each parameter in the
  759. parameterization. Note that the order of sizes is arbitary; the
  760. size of the array generated for a specific parameter is not
  761. controlled individually as written.
  762. rng : NumPy Generator
  763. The generator used to draw random values.
  764. proportions : tuple
  765. A tuple of four non-negative numbers that indicate the expected
  766. relative proportion of elements that are within the parameter's
  767. domain, are on the boundary of the parameter's domain, are outside
  768. the parameter's domain, and have value NaN. For more information,
  769. see the `draw` method of the _Parameter subclasses.
  770. domain : str
  771. The domain of the `_Parameter` from which to draw. Default is
  772. "domain" (the *full* domain); alternative is "typical".
  773. Returns
  774. -------
  775. parameter_values : dict (string: array)
  776. A dictionary of parameter name/value pairs.
  777. """
  778. # ENH: be smart about the order. The domains of some parameters
  779. # depend on others. If the relationshp is simple (e.g. a < b < c),
  780. # we can draw values in order a, b, c.
  781. parameter_values = {}
  782. if sizes is None or not len(sizes) or not np.iterable(sizes[0]):
  783. sizes = [sizes]*len(self.parameters)
  784. for size, param in zip(sizes, self.parameters.values()):
  785. parameter_values[param.name] = param.draw(
  786. size, rng=rng, proportions=proportions,
  787. parameter_values=parameter_values,
  788. region=region
  789. )
  790. return parameter_values
  791. def _set_invalid_nan(f):
  792. # Wrapper for input / output validation and standardization of distribution
  793. # functions that accept either the quantile or percentile as an argument:
  794. # logpdf, pdf
  795. # logpmf, pmf
  796. # logcdf, cdf
  797. # logccdf, ccdf
  798. # ilogcdf, icdf
  799. # ilogccdf, iccdf
  800. # Arguments that are outside the required range are replaced by NaN before
  801. # passing them into the underlying function. The corresponding outputs
  802. # are replaced by the appropriate value before being returned to the user.
  803. # For example, when the argument of `cdf` exceeds the right end of the
  804. # distribution's support, the wrapper replaces the argument with NaN,
  805. # ignores the output of the underlying function, and returns 1.0. It also
  806. # ensures that output is of the appropriate shape and dtype.
  807. endpoints = {'icdf': (0, 1), 'iccdf': (0, 1),
  808. 'ilogcdf': (-np.inf, 0), 'ilogccdf': (-np.inf, 0)}
  809. replacements = {'logpdf': (-inf, -inf), 'pdf': (0, 0),
  810. 'logpmf': (-inf, -inf), 'pmf': (0, 0),
  811. '_logcdf1': (-inf, 0), '_logccdf1': (0, -inf),
  812. '_cdf1': (0, 1), '_ccdf1': (1, 0)}
  813. replace_strict = {'pdf', 'logpdf', 'pmf', 'logpmf'}
  814. replace_exact = {'icdf', 'iccdf', 'ilogcdf', 'ilogccdf'}
  815. clip = {'_cdf1', '_ccdf1'}
  816. clip_log = {'_logcdf1', '_logccdf1'}
  817. # relevant to discrete distributions only
  818. replace_non_integral = {'pmf', 'logpmf', 'pdf', 'logpdf'}
  819. @functools.wraps(f)
  820. def filtered(self, x, *args, **kwargs):
  821. if self.validation_policy == _SKIP_ALL:
  822. return f(self, x, *args, **kwargs)
  823. method_name = f.__name__
  824. x = np.asarray(x)
  825. dtype = self._dtype
  826. shape = self._shape
  827. discrete = isinstance(self, DiscreteDistribution)
  828. keep_low_endpoint = discrete and method_name in {'_cdf1', '_logcdf1',
  829. '_ccdf1', '_logccdf1'}
  830. # Ensure that argument is at least as precise as distribution
  831. # parameters, which are already at least floats. This will avoid issues
  832. # with raising integers to negative integer powers and failure to replace
  833. # invalid integers with NaNs.
  834. if x.dtype != dtype:
  835. dtype = np.result_type(x.dtype, dtype)
  836. x = np.asarray(x, dtype=dtype)
  837. # Broadcasting is slow. Do it only if necessary.
  838. if not x.shape == shape:
  839. try:
  840. shape = np.broadcast_shapes(x.shape, shape)
  841. x = np.broadcast_to(x, shape)
  842. # Should we broadcast the distribution parameters to this shape, too?
  843. except ValueError as e:
  844. message = (
  845. f"The argument provided to `{self.__class__.__name__}"
  846. f".{method_name}` cannot be be broadcast to the same "
  847. "shape as the distribution parameters.")
  848. raise ValueError(message) from e
  849. low, high = endpoints.get(method_name, self.support())
  850. # Check for arguments outside of domain. They'll be replaced with NaNs,
  851. # and the result will be set to the appropriate value.
  852. left_inc, right_inc = self._variable.domain.inclusive
  853. mask_low = (x < low if (method_name in replace_strict and left_inc)
  854. or keep_low_endpoint else x <= low)
  855. mask_high = (x > high if (method_name in replace_strict and right_inc)
  856. else x >= high)
  857. mask_invalid = (mask_low | mask_high)
  858. any_invalid = (mask_invalid if mask_invalid.shape == ()
  859. else np.any(mask_invalid))
  860. # Check for arguments at domain endpoints, whether they
  861. # are part of the domain or not.
  862. any_endpoint = False
  863. if method_name in replace_exact:
  864. mask_low_endpoint = (x == low)
  865. mask_high_endpoint = (x == high)
  866. mask_endpoint = (mask_low_endpoint | mask_high_endpoint)
  867. any_endpoint = (mask_endpoint if mask_endpoint.shape == ()
  868. else np.any(mask_endpoint))
  869. # Check for non-integral arguments to PMF method
  870. # or PDF of a discrete distribution.
  871. any_non_integral = False
  872. if discrete and method_name in replace_non_integral:
  873. mask_non_integral = (x != np.floor(x))
  874. any_non_integral = (mask_non_integral if mask_non_integral.shape == ()
  875. else np.any(mask_non_integral))
  876. # Set out-of-domain arguments to NaN. The result will be set to the
  877. # appropriate value later.
  878. if any_invalid:
  879. x = np.array(x, dtype=dtype, copy=True)
  880. x[mask_invalid] = np.nan
  881. res = np.asarray(f(self, x, *args, **kwargs))
  882. # Ensure that the result is the correct dtype and shape,
  883. # copying (only once) if necessary.
  884. res_needs_copy = False
  885. if res.dtype != dtype:
  886. dtype = np.result_type(dtype, self._dtype)
  887. res_needs_copy = True
  888. if res.shape != shape: # faster to check first
  889. res = np.broadcast_to(res, self._shape)
  890. res_needs_copy = (res_needs_copy or any_invalid
  891. or any_endpoint or any_non_integral)
  892. if res_needs_copy:
  893. res = np.array(res, dtype=dtype, copy=True)
  894. # For non-integral arguments to PMF (and PDF of discrete distribution)
  895. # replace with zero.
  896. if any_non_integral:
  897. zero = -np.inf if method_name in {'logpmf', 'logpdf'} else 0
  898. res[mask_non_integral & ~np.isnan(res)] = zero
  899. # For arguments outside the function domain, replace results
  900. if any_invalid:
  901. replace_low, replace_high = (
  902. replacements.get(method_name, (np.nan, np.nan)))
  903. res[mask_low] = replace_low
  904. res[mask_high] = replace_high
  905. # For arguments at the endpoints of the domain, replace results
  906. if any_endpoint:
  907. a, b = self.support()
  908. if a.shape != shape:
  909. a = np.array(np.broadcast_to(a, shape), copy=True)
  910. b = np.array(np.broadcast_to(b, shape), copy=True)
  911. replace_low_endpoint = (
  912. b[mask_low_endpoint] if method_name.endswith('ccdf')
  913. else a[mask_low_endpoint])
  914. replace_high_endpoint = (
  915. a[mask_high_endpoint] if method_name.endswith('ccdf')
  916. else b[mask_high_endpoint])
  917. if not keep_low_endpoint:
  918. res[mask_low_endpoint] = replace_low_endpoint
  919. res[mask_high_endpoint] = replace_high_endpoint
  920. # Clip probabilities to [0, 1]
  921. if method_name in clip:
  922. res = np.clip(res, 0., 1.)
  923. elif method_name in clip_log:
  924. res = res.real # exp(res) > 0
  925. res = np.clip(res, None, 0.) # exp(res) < 1
  926. return res[()]
  927. return filtered
  928. def _set_invalid_nan_property(f):
  929. # Wrapper for input / output validation and standardization of distribution
  930. # functions that represent properties of the distribution itself:
  931. # logentropy, entropy
  932. # median, mode
  933. # moment
  934. # It ensures that the output is of the correct shape and dtype and that
  935. # there are NaNs wherever the distribution parameters were invalid.
  936. @functools.wraps(f)
  937. def filtered(self, *args, **kwargs):
  938. if self.validation_policy == _SKIP_ALL:
  939. return f(self, *args, **kwargs)
  940. res = f(self, *args, **kwargs)
  941. if res is None:
  942. # message could be more appropriate
  943. raise NotImplementedError(self._not_implemented)
  944. res = np.asarray(res)
  945. needs_copy = False
  946. dtype = res.dtype
  947. if dtype != self._dtype: # this won't work for logmoments (complex)
  948. dtype = np.result_type(dtype, self._dtype)
  949. needs_copy = True
  950. if res.shape != self._shape: # faster to check first
  951. res = np.broadcast_to(res, self._shape)
  952. needs_copy = needs_copy or self._any_invalid
  953. if needs_copy:
  954. res = res.astype(dtype=dtype, copy=True)
  955. if self._any_invalid:
  956. # may be redundant when quadrature is used, but not necessarily
  957. # when formulas are used.
  958. res[self._invalid] = np.nan
  959. return res[()]
  960. return filtered
  961. def _dispatch(f):
  962. # For each public method (instance function) of a distribution (e.g. ccdf),
  963. # there may be several ways ("method"s) that it can be computed (e.g. a
  964. # formula, as the complement of the CDF, or via numerical integration).
  965. # Each "method" is implemented by a different private method (instance
  966. # function).
  967. # This wrapper calls the appropriate private method based on the public
  968. # method and any specified `method` keyword option.
  969. # - If `method` is specified as a string (by the user), the appropriate
  970. # private method is called.
  971. # - If `method` is None:
  972. # - The appropriate private method for the public method is looked up
  973. # in a cache.
  974. # - If the cache does not have an entry for the public method, the
  975. # appropriate "dispatch " function is called to determine which method
  976. # is most appropriate given the available private methods and
  977. # settings (e.g. tolerance).
  978. @functools.wraps(f)
  979. def wrapped(self, *args, method=None, **kwargs):
  980. func_name = f.__name__
  981. method = method or self._method_cache.get(func_name, None)
  982. if callable(method):
  983. pass
  984. elif method is not None:
  985. method = 'logexp' if method == 'log/exp' else method
  986. method_name = func_name.replace('dispatch', method)
  987. method = getattr(self, method_name)
  988. else:
  989. method = f(self, *args, method=method, **kwargs)
  990. if func_name != '_sample_dispatch' and self.cache_policy != _NO_CACHE:
  991. self._method_cache[func_name] = method
  992. try:
  993. return method(*args, **kwargs)
  994. except KeyError as e:
  995. raise NotImplementedError(self._not_implemented) from e
  996. return wrapped
  997. def _cdf2_input_validation(f):
  998. # Wrapper that does the job of `_set_invalid_nan` when `cdf` or `logcdf`
  999. # is called with two quantile arguments.
  1000. # Let's keep it simple; no special cases for speed right now.
  1001. # The strategy is a bit different than for 1-arg `cdf` (and other methods
  1002. # covered by `_set_invalid_nan`). For 1-arg `cdf`, elements of `x` that
  1003. # are outside (or at the edge of) the support get replaced by `nan`,
  1004. # and then the results get replaced by the appropriate value (0 or 1).
  1005. # We *could* do something similar, dispatching to `_cdf1` in these
  1006. # cases. That would be a bit more robust, but it would also be quite
  1007. # a bit more complex, since we'd have to do different things when
  1008. # `x` and `y` are both out of bounds, when just `x` is out of bounds,
  1009. # when just `y` is out of bounds, and when both are out of bounds.
  1010. # I'm not going to do that right now. Instead, simply replace values
  1011. # outside the support by those at the edge of the support. Here, we also
  1012. # omit some of the optimizations that make `_set_invalid_nan` faster for
  1013. # simple arguments (e.g. float64 scalars).
  1014. @functools.wraps(f)
  1015. def wrapped(self, x, y, *args, **kwargs):
  1016. func_name = f.__name__
  1017. low, high = self.support()
  1018. x, y, low, high = np.broadcast_arrays(x, y, low, high)
  1019. dtype = np.result_type(x.dtype, y.dtype, self._dtype)
  1020. # yes, copy to avoid modifying input arrays
  1021. x, y = x.astype(dtype, copy=True), y.astype(dtype, copy=True)
  1022. # Swap arguments to ensure that x < y, and replace
  1023. # out-of domain arguments with domain endpoints. We'll
  1024. # transform the result later.
  1025. i_swap = y < x
  1026. x[i_swap], y[i_swap] = y[i_swap], x[i_swap]
  1027. i = x < low
  1028. x[i] = low[i]
  1029. i = y < low
  1030. y[i] = low[i]
  1031. i = x > high
  1032. x[i] = high[i]
  1033. i = y > high
  1034. y[i] = high[i]
  1035. res = f(self, x, y, *args, **kwargs)
  1036. # Clipping probability to [0, 1]
  1037. if func_name in {'_cdf2', '_ccdf2'}:
  1038. res = np.clip(res, 0., 1.)
  1039. else:
  1040. res = np.clip(res, None, 0.) # exp(res) < 1
  1041. # Transform the result to account for swapped argument order
  1042. res = np.asarray(res)
  1043. if func_name == '_cdf2':
  1044. res[i_swap] *= -1.
  1045. elif func_name == '_ccdf2':
  1046. res[i_swap] *= -1
  1047. res[i_swap] += 2.
  1048. elif func_name == '_logcdf2':
  1049. res = np.asarray(res + 0j) if np.any(i_swap) else res
  1050. res[i_swap] = res[i_swap] + np.pi*1j
  1051. else:
  1052. # res[i_swap] is always positive and less than 1, so it's
  1053. # safe to ensure that the result is real
  1054. res[i_swap] = _logexpxmexpy(np.log(2), res[i_swap]).real
  1055. return res[()]
  1056. return wrapped
  1057. def _fiinfo(x):
  1058. if np.issubdtype(x.dtype, np.inexact):
  1059. return np.finfo(x.dtype)
  1060. else:
  1061. return np.iinfo(x)
  1062. def _kwargs2args(f, args=None, kwargs=None):
  1063. # Wraps a function that accepts a primary argument `x`, secondary
  1064. # arguments `args`, and secondary keyward arguments `kwargs` such that the
  1065. # wrapper accepts only `x` and `args`. The keyword arguments are extracted
  1066. # from `args` passed into the wrapper, and these are passed to the
  1067. # underlying function as `kwargs`.
  1068. # This is a temporary workaround until the scalar algorithms `_tanhsinh`,
  1069. # `_chandrupatla`, etc., support `kwargs` or can operate with compressing
  1070. # arguments to the callable.
  1071. args = args or []
  1072. kwargs = kwargs or {}
  1073. names = list(kwargs.keys())
  1074. n_args = len(args)
  1075. def wrapped(x, *args):
  1076. return f(x, *args[:n_args], **dict(zip(names, args[n_args:])))
  1077. args = tuple(args) + tuple(kwargs.values())
  1078. return wrapped, args
  1079. def _logexpxmexpy(x, y):
  1080. """ Compute the log of the difference of the exponentials of two arguments.
  1081. Avoids over/underflow, but does not prevent loss of precision otherwise.
  1082. """
  1083. # TODO: properly avoid NaN when y is negative infinity
  1084. # TODO: silence warning with taking log of complex nan
  1085. # TODO: deal with x == y better
  1086. i = np.isneginf(np.real(y))
  1087. if np.any(i):
  1088. y = np.asarray(y.copy())
  1089. y[i] = np.finfo(y.dtype).min
  1090. x, y = np.broadcast_arrays(x, y)
  1091. res = np.asarray(special.logsumexp([x, y+np.pi*1j], axis=0))
  1092. i = (x == y)
  1093. res[i] = -np.inf
  1094. return res
  1095. def _guess_bracket(xmin, xmax):
  1096. a = np.full_like(xmin, -1.0)
  1097. b = np.ones_like(xmax)
  1098. i = np.isfinite(xmin) & np.isfinite(xmax)
  1099. a[i] = xmin[i]
  1100. b[i] = xmax[i]
  1101. i = np.isfinite(xmin) & ~np.isfinite(xmax)
  1102. a[i] = xmin[i]
  1103. b[i] = xmin[i] + 1
  1104. i = np.isfinite(xmax) & ~np.isfinite(xmin)
  1105. a[i] = xmax[i] - 1
  1106. b[i] = xmax[i]
  1107. return a, b
  1108. def _log_real_standardize(x):
  1109. """Standardizes the (complex) logarithm of a real number.
  1110. The logarithm of a real number may be represented by a complex number with
  1111. imaginary part that is a multiple of pi*1j. Even multiples correspond with
  1112. a positive real and odd multiples correspond with a negative real.
  1113. Given a logarithm of a real number `x`, this function returns an equivalent
  1114. representation in a standard form: the log of a positive real has imaginary
  1115. part `0` and the log of a negative real has imaginary part `pi`.
  1116. """
  1117. shape = x.shape
  1118. x = np.atleast_1d(x)
  1119. real = np.real(x).astype(x.dtype)
  1120. complex = np.imag(x)
  1121. y = real
  1122. negative = np.exp(complex*1j) < 0.5
  1123. y[negative] = y[negative] + np.pi * 1j
  1124. return y.reshape(shape)[()]
  1125. def _combine_docs(dist_family, *, include_examples=True):
  1126. fields = set(NumpyDocString.sections)
  1127. fields.remove('index')
  1128. if not include_examples:
  1129. fields.remove('Examples')
  1130. doc = ClassDoc(dist_family)
  1131. superdoc = ClassDoc(UnivariateDistribution)
  1132. for field in fields:
  1133. if field in {"Methods", "Attributes"}:
  1134. doc[field] = superdoc[field]
  1135. elif field in {"Summary"}:
  1136. pass
  1137. elif field == "Extended Summary":
  1138. doc[field].append(_generate_domain_support(dist_family))
  1139. elif field == 'Examples':
  1140. doc[field] = [_generate_example(dist_family)]
  1141. else:
  1142. doc[field] += superdoc[field]
  1143. return str(doc)
  1144. def _generate_domain_support(dist_family):
  1145. n_parameterizations = len(dist_family._parameterizations)
  1146. domain = f"\nfor :math:`x \\in {dist_family._variable.domain}`.\n"
  1147. if n_parameterizations == 0:
  1148. support = """
  1149. This class accepts no distribution parameters.
  1150. """
  1151. elif n_parameterizations == 1:
  1152. support = f"""
  1153. This class accepts one parameterization:
  1154. {str(dist_family._parameterizations[0])}.
  1155. """
  1156. else:
  1157. number = {2: 'two', 3: 'three', 4: 'four', 5: 'five'}[
  1158. n_parameterizations]
  1159. parameterizations = [f"- {str(p)}" for p in
  1160. dist_family._parameterizations]
  1161. parameterizations = "\n".join(parameterizations)
  1162. support = f"""
  1163. This class accepts {number} parameterizations:
  1164. {parameterizations}
  1165. """
  1166. support = "\n".join([line.lstrip() for line in support.split("\n")][1:])
  1167. return domain + support
def _generate_example(dist_family):
    # Build the text of the "Examples" docstring section for `dist_family`.
    #
    # An instance of the family is drawn with a fixed seed, its parameters
    # are rounded to two decimals, and the family is re-instantiated from the
    # rounded values so that the printed instantiation reproduces the printed
    # results. The methods shown in the example are evaluated here, and their
    # outputs are embedded in the returned text.
    n_parameters = dist_family._num_parameters(0)
    shapes = [()] * n_parameters
    rng = np.random.default_rng(615681484984984)
    i = 0  # index of the parameterization used for the example
    dist = dist_family._draw(shapes, rng=rng, i_parameterization=i)

    # NOTE(review): this second generator is not referenced again in this
    # function.
    rng = np.random.default_rng(2354873452)
    name = dist_family.__name__
    if n_parameters:
        parameter_names = list(dist._parameterizations[i].parameters)
        parameter_values = [round(getattr(dist, name), 2) for name in
                            parameter_names]
        name_values = [f"{name}={value}" for name, value in
                       zip(parameter_names, parameter_values)]
        instantiation = f"{name}({', '.join(name_values)})"
        attributes = ", ".join([f"X.{param}" for param in dist._parameters])
        X = dist_family(**dict(zip(parameter_names, parameter_values)))
    else:
        instantiation = f"{name}()"
        X = dist

    # Arguments used in the example: the quantiles at probabilities `p` and
    # `2 * p`, rounded to two decimals.
    p = 0.32
    x = round(X.icdf(p), 2)
    y = round(X.icdf(2 * p), 2)  # noqa: F841

    example = f"""
    To use the distribution class, it must be instantiated using keyword
    parameters corresponding with one of the accepted parameterizations.
    >>> import numpy as np
    >>> import matplotlib.pyplot as plt
    >>> from scipy import stats
    >>> from scipy.stats import {name}
    >>> X = {instantiation}
    For convenience, the ``plot`` method can be used to visualize the density
    and other functions of the distribution.
    >>> X.plot()
    >>> plt.show()
    The support of the underlying distribution is available using the ``support``
    method.
    >>> X.support()
    {X.support()}
    """

    # The parameter-attribute paragraph only makes sense for families that
    # have parameters.
    if n_parameters:
        example += f"""
        The numerical values of parameters associated with all parameterizations
        are available as attributes.
        >>> {attributes}
        {tuple(X._parameters.values())}
        """

    example += f"""
    To evaluate the probability density/mass function of the underlying distribution
    at argument ``x={x}``:
    >>> x = {x}
    >>> X.pdf(x), X.pmf(x)
    {X.pdf(x), X.pmf(x)}
    The cumulative distribution function, its complement, and the logarithm
    of these functions are evaluated similarly.
    >>> np.allclose(np.exp(X.logccdf(x)), 1 - X.cdf(x))
    True
    """

    # When two-arg CDF is implemented for DiscreteDistribution, consider removing
    # the special-casing here.
    if issubclass(dist_family, ContinuousDistribution):
        example_continuous = f"""
        The inverse of these functions with respect to the argument ``x`` is also
        available.
        >>> logp = np.log(1 - X.ccdf(x))
        >>> np.allclose(X.ilogcdf(logp), x)
        True
        Note that distribution functions and their logarithms also have two-argument
        versions for working with the probability mass between two arguments. The
        result tends to be more accurate than the naive implementation because it avoids
        subtractive cancellation.
        >>> y = {y}
        >>> np.allclose(X.ccdf(x, y), 1 - (X.cdf(y) - X.cdf(x)))
        True
        """
        example += example_continuous

    example += f"""
    There are methods for computing measures of central tendency,
    dispersion, higher moments, and entropy.
    >>> X.mean(), X.median(), X.mode()
    {X.mean(), X.median(), X.mode()}
    >>> X.variance(), X.standard_deviation()
    {X.variance(), X.standard_deviation()}
    >>> X.skewness(), X.kurtosis()
    {X.skewness(), X.kurtosis()}
    >>> np.allclose(X.moment(order=6, kind='standardized'),
    ... X.moment(order=6, kind='central') / X.variance()**3)
    True
    """

    # When logentropy is implemented for DiscreteDistribution, remove special-casing
    if issubclass(dist_family, ContinuousDistribution):
        example += """
        >>> np.allclose(np.exp(X.logentropy()), X.entropy())
        True
        """
    else:
        # Other families show the entropy value itself instead.
        example += f"""
        >>> X.entropy()
        {X.entropy()}
        """

    example += f"""
    Pseudo-random samples can be drawn from
    the underlying distribution using ``sample``.
    >>> X.sample(shape=(4,))
    {repr(X.sample(shape=(4,)))} # may vary
    """
    # remove the indentation due to use of block quote within function;
    # eliminate blank first line
    example = "\n".join([line.lstrip() for line in example.split("\n")][1:])
    return example
  1278. class UnivariateDistribution(_ProbabilityDistribution):
  1279. r""" Class that represents a continuous statistical distribution.
  1280. Parameters
  1281. ----------
  1282. tol : positive float, optional
  1283. The desired relative tolerance of calculations. Left unspecified,
  1284. calculations may be faster; when provided, calculations may be
  1285. more likely to meet the desired accuracy.
  1286. validation_policy : {None, "skip_all"}
  1287. Specifies the level of input validation to perform. Left unspecified,
  1288. input validation is performed to ensure appropriate behavior in edge
  1289. case (e.g. parameters out of domain, argument outside of distribution
  1290. support, etc.) and improve consistency of output dtype, shape, etc.
  1291. Pass ``'skip_all'`` to avoid the computational overhead of these
  1292. checks when rough edges are acceptable.
  1293. cache_policy : {None, "no_cache"}
  1294. Specifies the extent to which intermediate results are cached. Left
  1295. unspecified, intermediate results of some calculations (e.g. distribution
  1296. support, moments, etc.) are cached to improve performance of future
  1297. calculations. Pass ``'no_cache'`` to reduce memory reserved by the class
  1298. instance.
  1299. Attributes
  1300. ----------
  1301. All parameters are available as attributes.
  1302. Methods
  1303. -------
  1304. support
  1305. plot
  1306. sample
  1307. moment
  1308. mean
  1309. median
  1310. mode
  1311. variance
  1312. standard_deviation
  1313. skewness
  1314. kurtosis
  1315. pdf
  1316. logpdf
  1317. cdf
  1318. icdf
  1319. ccdf
  1320. iccdf
  1321. logcdf
  1322. ilogcdf
  1323. logccdf
  1324. ilogccdf
  1325. entropy
  1326. logentropy
  1327. See Also
  1328. --------
  1329. :ref:`rv_infrastructure` : Tutorial
  1330. Notes
  1331. -----
  1332. The following abbreviations are used throughout the documentation.
  1333. - PDF: probability density function
  1334. - CDF: cumulative distribution function
  1335. - CCDF: complementary CDF
  1336. - entropy: differential entropy
  1337. - log-*F*: logarithm of *F* (e.g. log-CDF)
  1338. - inverse *F*: inverse function of *F* (e.g. inverse CDF)
  1339. The API documentation is written to describe the API, not to serve as
  1340. a statistical reference. Effort is made to be correct at the level
  1341. required to use the functionality, not to be mathematically rigorous.
  1342. For example, continuity and differentiability may be implicitly assumed.
  1343. For precise mathematical definitions, consult your preferred mathematical
  1344. text.
  1345. """
  1346. __array_priority__ = 1
  1347. _parameterizations = [] # type: ignore[var-annotated]
  1348. ### Initialization
  1349. def __init__(self, *, tol=_null, validation_policy=None, cache_policy=None,
  1350. **parameters):
  1351. self.tol = tol
  1352. self.validation_policy = validation_policy
  1353. self.cache_policy = cache_policy
  1354. self._not_implemented = (
  1355. f"`{self.__class__.__name__}` does not provide an accurate "
  1356. "implementation of the required method. Consider leaving "
  1357. "`method` and `tol` unspecified to use another implementation."
  1358. )
  1359. self._original_parameters = {}
  1360. # We may want to override the `__init__` method with parameters so
  1361. # IDEs can suggest parameter names. If there are multiple parameterizations,
  1362. # we'll need the default values of parameters to be None; this will
  1363. # filter out the parameters that were not actually specified by the user.
  1364. parameters = {key: val for key, val in
  1365. sorted(parameters.items()) if val is not None}
  1366. self._update_parameters(**parameters)
def _update_parameters(self, *, validation_policy=None, **params):
    r""" Update the numerical values of distribution parameters.

    The values in `params` are merged over the parameters recorded from
    previous calls, the merged set is (unless validation is skipped)
    matched to a parameterization, broadcast and validated, and the
    instance's cached results are cleared.

    Parameters
    ----------
    **params : array_like
        Desired numerical values of the distribution parameters. Any or all
        of the parameters initially used to instantiate the distribution
        may be modified. Parameters used in alternative parameterizations
        are not accepted.
    validation_policy : str
        If truthy, used instead of ``self.validation_policy`` for this
        call. When the effective policy equals `_SKIP_ALL`, the parameters
        are passed directly to `_process_parameters` without identifying
        a parameterization, broadcasting, or validating.

    Raises
    ------
    ValueError
        If parameters are provided to a family that defines no
        parameterizations. The helpers called here also raise
        `ValueError` when the parameters match no parameterization or
        cannot be broadcast together.

    """
    # `parameters` and `original_parameters` are the same dict object at
    # this point, so the update below is also reflected in what is stored
    # as `_original_parameters` at the end. `parameters` is rebound to new
    # objects further down.
    parameters = original_parameters = self._original_parameters.copy()
    parameters.update(**params)
    parameterization = None

    # Defaults; these remain in effect unless the validation branch below
    # runs and overwrites them.
    self._invalid = np.asarray(False)
    self._any_invalid = False
    self._shape = tuple()
    self._ndim = 0
    self._size = 1
    self._dtype = np.float64

    if (validation_policy or self.validation_policy) == _SKIP_ALL:
        parameters = self._process_parameters(**parameters)
    elif not len(self._parameterizations):
        # A family without parameterizations accepts no parameters at all.
        if parameters:
            message = (f"The `{self.__class__.__name__}` distribution "
                       "family does not accept parameters, but parameters "
                       f"`{set(parameters)}` were provided.")
            raise ValueError(message)
    else:
        # This is default behavior, which re-runs all parameter validations
        # even when only a single parameter is modified. For many
        # distributions, the domain of a parameter doesn't depend on other
        # parameters, so parameters could safely be modified without
        # re-validating all other parameters. To handle these cases more
        # efficiently, we could allow the developer to override this
        # behavior.

        # Currently the user can only update the original parameterization.
        # Even though that parameterization is already known,
        # `_identify_parameterization` is called to produce a nice error
        # message if the user passes other values. To be a little more
        # efficient, we could detect whether the values passed are
        # consistent with the original parameterization rather than finding
        # it from scratch. However, we might want other parameterizations
        # to be accepted, which would require other changes, so I didn't
        # optimize this.

        parameterization = self._identify_parameterization(parameters)
        parameters, shape, size, ndim = self._broadcast(parameters)
        parameters, invalid, any_invalid, dtype = (
            self._validate(parameterization, parameters))
        parameters = self._process_parameters(**parameters)

        self._invalid = invalid
        self._any_invalid = any_invalid
        self._shape = shape
        self._size = size
        self._ndim = ndim
        self._dtype = dtype

    # Anything cached was computed from the previous parameter values.
    self.reset_cache()
    self._parameters = parameters
    self._parameterization = parameterization
    self._original_parameters = original_parameters
    for name in self._parameters.keys():
        # Make parameters properties of the class; return values from the instance
        # The property is installed on the class itself, and only if the
        # class does not already have an attribute of that name. `name` is
        # bound as a default argument so each property keeps its own name.
        if hasattr(self.__class__, name):
            continue
        setattr(self.__class__, name, property(lambda self_, name_=name:
                                               self_._parameters[name_].copy()[()]))
  1434. def reset_cache(self):
  1435. r""" Clear all cached values.
  1436. To improve the speed of some calculations, the distribution's support
  1437. and moments are cached.
  1438. This function is called automatically whenever the distribution
  1439. parameters are updated.
  1440. """
  1441. # We could offer finer control over what is cleared.
  1442. # For simplicity, these will still exist even if cache_policy is
  1443. # NO_CACHE; they just won't be populated. This allows caching to be
  1444. # turned on and off easily.
  1445. self._moment_raw_cache = {}
  1446. self._moment_central_cache = {}
  1447. self._moment_standardized_cache = {}
  1448. self._support_cache = None
  1449. self._method_cache = {}
  1450. self._constant_cache = None
  1451. def _identify_parameterization(self, parameters):
  1452. # Determine whether a `parameters` dictionary matches is consistent
  1453. # with one of the parameterizations of the distribution. If so,
  1454. # return that parameterization object; if not, raise an error.
  1455. #
  1456. # I've come back to this a few times wanting to avoid this explicit
  1457. # loop. I've considered several possibilities, but they've all been a
  1458. # little unusual. For example, we could override `_eq_` so we can
  1459. # use _parameterizations.index() to retrieve the parameterization,
  1460. # or the user could put the parameterizations in a dictionary so we
  1461. # could look them up with a key (e.g. frozenset of parameter names).
  1462. # I haven't been sure enough of these approaches to implement them.
  1463. parameter_names_set = set(parameters)
  1464. for parameterization in self._parameterizations:
  1465. if parameterization.matches(parameter_names_set):
  1466. break
  1467. else:
  1468. if not parameter_names_set:
  1469. message = (f"The `{self.__class__.__name__}` distribution "
  1470. "family requires parameters, but none were "
  1471. "provided.")
  1472. else:
  1473. parameter_names = self._get_parameter_str(parameters)
  1474. message = (f"The provided parameters `{parameter_names}` "
  1475. "do not match a supported parameterization of the "
  1476. f"`{self.__class__.__name__}` distribution family.")
  1477. raise ValueError(message)
  1478. return parameterization
  1479. def _broadcast(self, parameters):
  1480. # Broadcast the distribution parameters to the same shape. If the
  1481. # arrays are not broadcastable, raise a meaningful error.
  1482. #
  1483. # We always make sure that the parameters *are* the same shape
  1484. # and not just broadcastable. Users can access parameters as
  1485. # attributes, and I think they should see the arrays as the same shape.
  1486. # More importantly, arrays should be the same shape before logical
  1487. # indexing operations, which are needed in infrastructure code when
  1488. # there are invalid parameters, and may be needed in
  1489. # distribution-specific code. We don't want developers to need to
  1490. # broadcast in implementation functions.
  1491. # It's much faster to check whether broadcasting is necessary than to
  1492. # broadcast when it's not necessary.
  1493. parameter_vals = [np.asarray(parameter)
  1494. for parameter in parameters.values()]
  1495. parameter_shapes = set(parameter.shape for parameter in parameter_vals)
  1496. if len(parameter_shapes) == 1:
  1497. return (parameters, parameter_vals[0].shape,
  1498. parameter_vals[0].size, parameter_vals[0].ndim)
  1499. try:
  1500. parameter_vals = np.broadcast_arrays(*parameter_vals)
  1501. except ValueError as e:
  1502. parameter_names = self._get_parameter_str(parameters)
  1503. message = (f"The parameters `{parameter_names}` provided to the "
  1504. f"`{self.__class__.__name__}` distribution family "
  1505. "cannot be broadcast to the same shape.")
  1506. raise ValueError(message) from e
  1507. return (dict(zip(parameters.keys(), parameter_vals)),
  1508. parameter_vals[0].shape,
  1509. parameter_vals[0].size,
  1510. parameter_vals[0].ndim)
  1511. def _validate(self, parameterization, parameters):
  1512. # Broadcasts distribution parameter arrays and converts them to a
  1513. # consistent dtype. Replaces invalid parameters with `np.nan`.
  1514. # Returns the validated parameters, a boolean mask indicated *which*
  1515. # elements are invalid, a boolean scalar indicating whether *any*
  1516. # are invalid (to skip special treatments if none are invalid), and
  1517. # the common dtype.
  1518. valid, dtype = parameterization.validation(parameters)
  1519. invalid = ~valid
  1520. any_invalid = invalid if invalid.shape == () else np.any(invalid)
  1521. # If necessary, make the arrays contiguous and replace invalid with NaN
  1522. if any_invalid:
  1523. for parameter_name in parameters:
  1524. parameters[parameter_name] = np.copy(
  1525. parameters[parameter_name])
  1526. parameters[parameter_name][invalid] = np.nan
  1527. return parameters, invalid, any_invalid, dtype
  1528. def _process_parameters(self, **params):
  1529. r""" Process and cache distribution parameters for reuse.
  1530. This is intended to be overridden by subclasses. It allows distribution
  1531. authors to pre-process parameters for re-use. For instance, when a user
  1532. parameterizes a LogUniform distribution with `a` and `b`, it makes
  1533. sense to calculate `log(a)` and `log(b)` because these values will be
  1534. used in almost all distribution methods. The dictionary returned by
  1535. this method is passed to all private methods that calculate functions
  1536. of the distribution.
  1537. """
  1538. return params
  1539. def _get_parameter_str(self, parameters):
  1540. # Get a string representation of the parameters like "{a, b, c}".
  1541. return f"{{{', '.join(parameters.keys())}}}"
  1542. def _copy_parameterization(self):
  1543. self._parameterizations = self._parameterizations.copy()
  1544. for i in range(len(self._parameterizations)):
  1545. self._parameterizations[i] = self._parameterizations[i].copy()
  1546. ### Attributes
  1547. # `tol` attribute is just notional right now. See Question 4 above.
  @property
  def tol(self):
      r"""positive float:
      The desired relative tolerance of calculations. Left unspecified,
      calculations may be faster; when provided, calculations may be
      more likely to meet the desired accuracy.
      """
      return self._tol

  @tol.setter
  def tol(self, tol):
      # Values that `_isnull` reports as null are stored untouched.
      if _isnull(tol):
          self._tol = tol
          return

      tol = np.asarray(tol)
      # The dtype test must come before the `> 0` comparison: comparing a
      # non-numeric scalar (e.g. a string) against 0 makes NumPy raise its
      # own TypeError, hiding the ValueError documented below.
      if (tol.shape != ()
              or not np.issubdtype(tol.dtype, np.floating)
              or not tol > 0):  # `not tol > 0` also catches NaNs
          message = (f"Attribute `tol` of `{self.__class__.__name__}` must "
                     "be a positive float, if specified.")
          raise ValueError(message)
      # Store a scalar rather than a 0-d array.
      self._tol = tol[()]
  @property
  def cache_policy(self):
      r"""{None, "no_cache"}:
      Specifies the extent to which intermediate results are cached. Left
      unspecified, intermediate results of some calculations (e.g. distribution
      support, moments, etc.) are cached to improve performance of future
      calculations. Pass ``'no_cache'`` to reduce memory reserved by the class
      instance.
      """
      return self._cache_policy

  @cache_policy.setter
  def cache_policy(self, cache_policy):
      # Anything other than None is compared case-insensitively, by its
      # string form.
      if cache_policy is not None:
          cache_policy = str(cache_policy).lower()
      cache_policies = {None, 'no_cache'}
      if cache_policy in cache_policies:
          self._cache_policy = cache_policy
          return
      message = (f"Attribute `cache_policy` of `{self.__class__.__name__}` "
                 f"must be one of {cache_policies}, if specified.")
      raise ValueError(message)
  @property
  def validation_policy(self):
      r"""{None, "skip_all"}:
      Specifies the level of input validation to perform. Left unspecified,
      input validation is performed to ensure appropriate behavior in edge
      case (e.g. parameters out of domain, argument outside of distribution
      support, etc.) and improve consistency of output dtype, shape, etc.
      Use ``'skip_all'`` to avoid the computational overhead of these
      checks when rough edges are acceptable.
      """
      return self._validation_policy

  @validation_policy.setter
  def validation_policy(self, validation_policy):
      # Anything other than None is compared case-insensitively, by its
      # string form.
      if validation_policy is not None:
          validation_policy = str(validation_policy).lower()
      iv_policies = {None, 'skip_all'}
      if validation_policy in iv_policies:
          self._validation_policy = validation_policy
          return
      message = (f"Attribute `validation_policy` of `{self.__class__.__name__}` "
                 f"must be one of {iv_policies}, if specified.")
      raise ValueError(message)
  1608. ### Other magic methods
  1609. def __repr__(self):
  1610. r""" Returns a string representation of the distribution.
  1611. Includes the name of the distribution family, the names of the
  1612. parameters and the `repr` of each of their values.
  1613. """
  1614. class_name = self.__class__.__name__
  1615. parameters = list(self._original_parameters.items())
  1616. info = []
  1617. with np.printoptions(threshold=10):
  1618. str_parameters = [f"{symbol}={repr(value)}" for symbol, value in parameters]
  1619. str_parameters = f"{', '.join(str_parameters)}"
  1620. info.append(str_parameters)
  1621. return f"{class_name}({', '.join(info)})"
  1622. def __str__(self):
  1623. class_name = self.__class__.__name__
  1624. parameters = list(self._original_parameters.items())
  1625. info = []
  1626. with np.printoptions(threshold=10):
  1627. str_parameters = [f"{symbol}={str(value)}" for symbol, value in parameters]
  1628. str_parameters = f"{', '.join(str_parameters)}"
  1629. info.append(str_parameters)
  1630. return f"{class_name}({', '.join(info)})"
  def __add__(self, loc):
      # `X + loc`: wrap this distribution in a `ShiftedScaledDistribution`
      # with the given `loc`.
      shifted = ShiftedScaledDistribution(self, loc=loc)
      return shifted
  def __sub__(self, loc):
      # `X - loc`: wrap this distribution in a `ShiftedScaledDistribution`
      # whose `loc` is the negated argument.
      negated_loc = -loc
      return ShiftedScaledDistribution(self, loc=negated_loc)
  def __mul__(self, scale):
      # `X * scale`: wrap this distribution in a `ShiftedScaledDistribution`
      # with the given `scale`.
      scaled = ShiftedScaledDistribution(self, scale=scale)
      return scaled
  def __truediv__(self, scale):
      # `X / scale`: wrap this distribution in a `ShiftedScaledDistribution`
      # whose `scale` is the reciprocal of the argument.
      reciprocal = 1/scale
      return ShiftedScaledDistribution(self, scale=reciprocal)
  1639. def __pow__(self, other):
  1640. if not np.isscalar(other) or other <= 0 or other != int(other):
  1641. message = ("Raising a random variable to the power of an argument is only "
  1642. "implemented when the argument is a positive integer.")
  1643. raise NotImplementedError(message)
  1644. # Fill in repr_pattern with the repr of self before taking abs.
  1645. # Avoids having unnecessary abs in the repr.
  1646. with np.printoptions(threshold=10):
  1647. repr_pattern = f"({repr(self)})**{repr(other)}"
  1648. str_pattern = f"({str(self)})**{str(other)}"
  1649. X = abs(self) if other % 2 == 0 else self
  1650. funcs = dict(g=lambda u: u**other, repr_pattern=repr_pattern,
  1651. str_pattern=str_pattern,
  1652. h=lambda u: np.sign(u) * np.abs(u)**(1 / other),
  1653. dh=lambda u: 1/other * np.abs(u)**(1/other - 1))
  1654. return MonotonicTransformedDistribution(X, **funcs, increasing=True)
  1655. def __radd__(self, other):
  1656. return self.__add__(other)
  1657. def __rsub__(self, other):
  1658. return self.__neg__().__add__(other)
  1659. def __rmul__(self, other):
  1660. return self.__mul__(other)
  1661. def __rtruediv__(self, other):
  1662. a, b = self.support()
  1663. with np.printoptions(threshold=10):
  1664. funcs = dict(g=lambda u: 1 / u,
  1665. repr_pattern=f"{repr(other)}/({repr(self)})",
  1666. str_pattern=f"{str(other)}/({str(self)})",
  1667. h=lambda u: 1 / u, dh=lambda u: 1 / u ** 2)
  1668. if np.all(a >= 0) or np.all(b <= 0):
  1669. out = MonotonicTransformedDistribution(self, **funcs, increasing=False)
  1670. else:
  1671. message = ("Division by a random variable is only implemented "
  1672. "when the support is either non-negative or non-positive.")
  1673. raise NotImplementedError(message)
  1674. if np.all(other == 1):
  1675. return out
  1676. else:
  1677. return out * other
  1678. def __rpow__(self, other):
  1679. with np.printoptions(threshold=10):
  1680. funcs = dict(g=lambda u: other**u,
  1681. h=lambda u: np.log(u) / np.log(other),
  1682. dh=lambda u: 1 / np.abs(u * np.log(other)),
  1683. repr_pattern=f"{repr(other)}**({repr(self)})",
  1684. str_pattern=f"{str(other)}**({str(self)})",)
  1685. if not np.isscalar(other) or other <= 0 or other == 1:
  1686. message = ("Raising an argument to the power of a random variable is only "
  1687. "implemented when the argument is a positive scalar other than "
  1688. "1.")
  1689. raise NotImplementedError(message)
  1690. if other > 1:
  1691. return MonotonicTransformedDistribution(self, **funcs, increasing=True)
  1692. else:
  1693. return MonotonicTransformedDistribution(self, **funcs, increasing=False)
  1694. def __neg__(self):
  1695. return self * -1
  def __abs__(self):
      # `abs(X)`: wrap this distribution in a `FoldedDistribution`.
      folded = FoldedDistribution(self)
      return folded
  1698. ### Utilities
  1699. ## Input validation
  def _validate_order_kind(self, order, kind, kinds):
      # Validates the `order` and `kind` arguments of `moment`, raising a
      # distribution-specific ValueError to make the source of the error
      # easy to identify. It is deliberately flexible about what counts as
      # an integer: any value equal to its rounded value is accepted.
      # Returns `order` converted to a scalar of `self._dtype`, or `order`
      # unchanged when validation is skipped.
      if self.validation_policy == _SKIP_ALL:
          return order

      order = np.asarray(order, dtype=self._dtype)[()]
      order_message = (f"Argument `order` of `{self.__class__.__name__}.moment` "
                       "must be a finite, positive integer.")
      try:
          rounded = round(order.item())
      # If this fails for any reason (e.g. it's an array, it's infinite)
      # it's not a valid `order`.
      except Exception as e:
          raise ValueError(order_message) from e

      # NOTE: only negative values are rejected here, so an order of zero
      # passes even though the message says "positive".
      if rounded < 0 or rounded != order:
          raise ValueError(order_message)

      kind_message = (f"Argument `kind` of `{self.__class__.__name__}.moment` "
                      f"must be one of {set(kinds)}.")
      if kind.lower() not in kinds:
          raise ValueError(kind_message)

      return order
  1723. def _preserve_type(self, x):
  1724. x = np.asarray(x)
  1725. if x.dtype != self._dtype:
  1726. x = x.astype(self._dtype)
  1727. return x[()]
  1728. ## Testing
  @classmethod
  def _draw(cls, sizes=None, rng=None, i_parameterization=None,
            proportions=None):
      r""" Draw a specific (fully-defined) distribution from the family.

      See _Parameterization.draw for documentation details.
      """
      rng = np.random.default_rng(rng)
      # A family without parameterizations is instantiated with no arguments.
      if len(cls._parameterizations) == 0:
          return cls()

      if i_parameterization is None:
          # Pick one of the available parameterizations at random.
          n = cls._num_parameterizations()
          last_index = max(0, n - 1)
          i_parameterization = rng.integers(0, last_index, endpoint=True)

      chosen = cls._parameterizations[i_parameterization]
      drawn = chosen.draw(sizes, rng, proportions=proportions,
                          region='typical')
      return cls(**drawn)
  @classmethod
  def _num_parameterizations(cls):
      # Returns the number of parameterizations accepted by the family.
      count = len(cls._parameterizations)
      return count
  @classmethod
  def _num_parameters(cls, i_parameterization=0):
      # Returns the number of parameters used in the specified
      # parameterization, or 0 when the family has no parameterizations.
      if not cls._num_parameterizations():
          return 0
      return len(cls._parameterizations[i_parameterization])
  1755. ## Algorithms
  1756. def _quadrature(self, integrand, limits=None, args=None,
  1757. params=None, log=False):
  1758. # Performs numerical integration of an integrand between limits.
  1759. # Much of this should be added to `_tanhsinh`.
  1760. a, b = self._support(**params) if limits is None else limits
  1761. a, b = np.broadcast_arrays(a, b)
  1762. if not a.size:
  1763. # maybe need to figure out result type from a, b
  1764. return np.empty(a.shape, dtype=self._dtype)
  1765. args = [] if args is None else args
  1766. params = {} if params is None else params
  1767. f, args = _kwargs2args(integrand, args=args, kwargs=params)
  1768. args = np.broadcast_arrays(*args)
  1769. # If we know the median or mean, consider breaking up the interval
  1770. rtol = None if _isnull(self.tol) else self.tol
  1771. # For now, we ignore the status, but I want to return the error
  1772. # estimate - see question 5 at the top.
  1773. if isinstance(self, ContinuousDistribution):
  1774. res = _tanhsinh(f, a, b, args=args, log=log, rtol=rtol)
  1775. return res.integral
  1776. else:
  1777. res = nsum(f, a, b, args=args, log=log, tolerances=dict(rtol=rtol)).sum
  1778. res = np.asarray(res)
  1779. # The result should be nan when parameters are nan, so need to special
  1780. # case this.
  1781. cond = np.isnan(params.popitem()[1]) if params else np.True_
  1782. cond = np.broadcast_to(cond, a.shape)
  1783. res[(a > b)] = -np.inf if log else 0 # fix in nsum?
  1784. res[cond] = np.nan
  1785. return res[()]
  def _solve_bounded(self, f, p, *, bounds=None, params=None, xatol=None):
      # Finds the argument of a function that produces the desired output.
      # Much of this should be added to _bracket_root / _chandrupatla.
      #
      # f : callable taking `x` and the entries of `params` as keywords.
      # p : target value(s); the root of `f(x, **params) - p` is sought.
      # bounds : optional pair `(xmin, xmax)`; defaults to
      #     `self._support(**params)`.
      # xatol : absolute tolerance forwarded to `_chandrupatla`.
      # Returns the `_chandrupatla` result object; for empty input, a
      # `_RichResult` whose `xl`, `x`, `xr`, `fl` and `fr` are empty arrays.
      #
      # `params` is normalized first: `self._support(**params)` below would
      # otherwise fail with a TypeError when `params` is None.
      params = {} if params is None else params
      xmin, xmax = self._support(**params) if bounds is None else bounds

      p, xmin, xmax = np.broadcast_arrays(p, xmin, xmax)
      if not p.size:
          # might need to figure out result type based on p
          res = _RichResult()
          empty = np.empty(p.shape, dtype=self._dtype)
          res.xl, res.x, res.xr = empty, empty, empty
          res.fl, res.fr = empty, empty
          # Nothing to solve: hand back the empty result instead of
          # discarding it and running the solvers on empty arrays.
          return res

      def f2(x, _p, **kwargs):  # named `_p` to avoid conflict with shape `p`
          return f(x, **kwargs) - _p

      f3, args = _kwargs2args(f2, args=[p], kwargs=params)
      # If we know the median or mean, should use it

      # Any operations between 0d array and a scalar produces a scalar, so...
      shape = xmin.shape
      xmin, xmax = np.atleast_1d(xmin, xmax)

      xl0, xr0 = _guess_bracket(xmin, xmax)
      xmin = xmin.reshape(shape)
      xmax = xmax.reshape(shape)
      xl0 = xl0.reshape(shape)
      xr0 = xr0.reshape(shape)

      res = _bracket_root(f3, xl0=xl0, xr0=xr0, xmin=xmin, xmax=xmax, args=args)
      # For now, we ignore the status, but I want to use the bracket width
      # as an error estimate - see question 5 at the top.

      xrtol = None if _isnull(self.tol) else self.tol
      tolerances = dict(xrtol=xrtol, xatol=xatol, fatol=0, frtol=0)
      return _chandrupatla(f3, a=res.xl, b=res.xr, args=args, **tolerances)
  1817. ## Other
  def _overrides(self, method_name):
      # Determines whether a class overrides a specified method.
      # Returns True when the attribute named `method_name` found on this
      # instance's class is a different object from the one found on
      # `UnivariateDistribution`; returns False when they are the same object
      # (including when the attribute is missing from both).

      # Sometimes we use `_overrides` to check whether a certain method is
      # overridden and if so, call it. This raises the question of why we
      # don't do the more obvious thing: restructure so that if the private
      # method is overridden, Python will call it instead of the inherited
      # version automatically. The short answer is that there are multiple
      # ways a user might wish to evaluate a method, and simply overriding the
      # method with a formula is not always the best option. For more complete
      # discussion of the considerations, see:
      # https://github.com/scipy/scipy/pull/21050#discussion_r1707798901
      own_impl = getattr(self.__class__, method_name, None)
      base_impl = getattr(UnivariateDistribution, method_name, None)
      return own_impl is not base_impl
  1833. ### Distribution properties
  1834. # The following "distribution properties" are exposed via a public method
  1835. # that accepts only options (not distribution parameters or quantile/
  1836. # percentile argument).
  1837. # support
  1838. # logentropy, entropy,
  1839. # median, mode, mean,
  1840. # variance, standard_deviation
  1841. # skewness, kurtosis
  1842. # Common options are:
  1843. # method - a string that indicates which method should be used to compute
  1844. # the quantity (e.g. a formula or numerical integration).
  1845. # Input/output validation is provided by the `_set_invalid_nan_property`
  1846. # decorator. These are the methods meant to be called by users.
  1847. #
  1848. # Each public method calls a private "dispatch" method that
  1849. # determines which "method" (strategy for calculating the desired quantity)
  1850. # to use by default and, via the `@_dispatch` decorator, calls the
  1851. # method and computes the result.
  1852. # Dispatch methods always accept:
  1853. # method - as passed from the public method
  1854. # params - a dictionary of distribution shape parameters passed by
  1855. # the public method.
  1856. # Dispatch methods accept `params` rather than relying on the state of the
  1857. # object because iterative algorithms like `_tanhsinh` and `_chandrupatla`
  1858. # need their callable to follow a strict elementwise protocol: each element
  1859. # of the output is determined solely by the values of the inputs at the
  1860. # corresponding location. The public methods do not satisfy this protocol
  1861. # because they do not accept the parameters as arguments, producing an
  1862. # output that generally has a different shape than that of the input. Also,
  1863. # by calling "dispatch" methods rather than the public methods, the
  1864. # iterative algorithms avoid the overhead of input validation.
  1865. #
  1866. # Each dispatch method can designate the responsibility of computing
  1867. # the required value to any of several "implementation" methods. These
  1868. # methods accept only `**params`, the parameter dictionary passed from
  1869. # the public method via the dispatch method. We separate the implementation
  1870. # methods from the dispatch methods for the sake of simplicity (via
  1871. # compartmentalization) and to allow subclasses to override certain
  1872. # implementation methods (typically only the "formula" methods). The names
  1873. # of implementation methods are combinations of the public method name and
  1874. # the name of the "method" (strategy for calculating the desired quantity)
  1875. # string. (In fact, the name of the implementation method is calculated
  1876. # from these two strings in the `_dispatch` decorator.) Common method
  1877. # strings are:
  1878. # formula - distribution-specific analytical expressions to be implemented
  1879. # by subclasses.
  1880. # log/exp - Compute the log of a number and then exponentiate it or vice
  1881. # versa.
  1882. # quadrature - Compute the value via numerical integration.
  1883. #
  1884. # The default method (strategy) is determined based on what implementation
  1885. # methods are available and the error tolerance of the user. Typically,
  1886. # a formula is always used if available. We fall back to "log/exp" if a
  1887. # formula for the logarithm or exponential of the quantity is available,
  1888. # and we use quadrature otherwise.
  def support(self):
      # If this were a `cached_property`, we couldn't update the value
      # when the distribution parameters change.
      # Caching is important, though, because calls to _support take a few
      # microseconds even when `a` and `b` are already the same shape.
      cached = self._support_cache
      if cached is not None:
          return cached

      shape = self._shape
      lower, upper = self._support(**self._parameters)
      # Bring both endpoints to the shape of the distribution.
      if lower.shape != shape:
          lower = np.broadcast_to(lower, shape)
      if upper.shape != shape:
          upper = np.broadcast_to(upper, shape)

      if self._any_invalid:
          # Work on copies so NaN can be written at the invalid positions.
          lower = np.asarray(lower).copy()
          upper = np.asarray(upper).copy()
          lower[self._invalid] = np.nan
          upper[self._invalid] = np.nan
          lower, upper = lower[()], upper[()]

      endpoints = (lower, upper)
      if self.cache_policy != _NO_CACHE:
          self._support_cache = endpoints
      return endpoints
  1909. def _support(self, **params):
  1910. # Computes the support given distribution parameters
  1911. a, b = self._variable.domain.get_numerical_endpoints(params)
  1912. if len(params):
  1913. # the parameters should all be of the same dtype and shape at this point
  1914. vals = list(params.values())
  1915. shape = vals[0].shape
  1916. a = np.broadcast_to(a, shape) if a.shape != shape else a
  1917. b = np.broadcast_to(b, shape) if b.shape != shape else b
  1918. return self._preserve_type(a), self._preserve_type(b)
  @_set_invalid_nan_property
  def logentropy(self, *, method=None):
      # Public entry point: dispatches with the stored parameters and adds
      # `0j` so that the result is complex.
      result = self._logentropy_dispatch(method=method, **self._parameters)
      return result + 0j
  @_dispatch
  def _logentropy_dispatch(self, method=None, **params):
      # Chooses the implementation: a log-entropy formula if the subclass
      # overrides one, otherwise the log of an overridden entropy formula
      # (with quadrature fallback), otherwise quadrature.
      if self._overrides('_logentropy_formula'):
          return self._logentropy_formula
      if self._overrides('_entropy_formula'):
          return self._logentropy_logexp_safe
      return self._logentropy_quadrature
  1931. def _logentropy_formula(self, **params):
  1932. raise NotImplementedError(self._not_implemented)
  def _logentropy_logexp(self, **params):
      # Complex logarithm of the entropy (`+0j` keeps the log defined for
      # negative entropy), passed through `_log_real_standardize`.
      entropy = self._entropy_dispatch(**params)
      log_entropy = np.log(entropy + 0j)
      return _log_real_standardize(log_entropy)
  1936. def _logentropy_logexp_safe(self, **params):
  1937. out = self._logentropy_logexp(**params)
  1938. mask = np.isinf(out.real)
  1939. if np.any(mask):
  1940. params_mask = {key:val[mask] for key, val in params.items()}
  1941. out = np.asarray(out)
  1942. out[mask] = self._logentropy_quadrature(**params_mask)
  1943. return out[()]
  def _logentropy_quadrature(self, **params):
      # Integrates, in log space, the log-density plus the complex log of
      # the log-density; `pi*1j` is then added to the result before it is
      # passed through `_log_real_standardize`.
      def logintegrand(x, **params):
          log_density = self._logpxf_dispatch(x, **params)
          return log_density + np.log(0j + log_density)

      log_integral = self._quadrature(logintegrand, params=params, log=True)
      return _log_real_standardize(log_integral + np.pi*1j)
  @_set_invalid_nan_property
  def entropy(self, *, method=None):
      # Public entry point: dispatches with the stored parameters.
      result = self._entropy_dispatch(method=method, **self._parameters)
      return result
  @_dispatch
  def _entropy_dispatch(self, method=None, **params):
      # Chooses the implementation: an entropy formula if the subclass
      # overrides one, otherwise the exponential of an overridden log-entropy
      # formula, otherwise quadrature.
      if self._overrides('_entropy_formula'):
          return self._entropy_formula
      if self._overrides('_logentropy_formula'):
          return self._entropy_logexp
      return self._entropy_quadrature
  1962. def _entropy_formula(self, **params):
  1963. raise NotImplementedError(self._not_implemented)
  1964. def _entropy_logexp(self, **params):
  1965. return np.real(np.exp(self._logentropy_dispatch(**params)))
  1966. def _entropy_quadrature(self, **params):
  1967. def integrand(x, **params):
  1968. pxf = self._pxf_dispatch(x, **params)
  1969. logpxf = self._logpxf_dispatch(x, **params)
  1970. temp = np.asarray(pxf)
  1971. i = (pxf != 0) # 0 * inf -> nan; should be 0
  1972. temp[i] = -pxf[i]*logpxf[i]
  1973. return temp
  1974. return self._quadrature(integrand, params=params)
  @_set_invalid_nan_property
  def median(self, *, method=None):
      # Public entry point: dispatches with the stored parameters.
      result = self._median_dispatch(method=method, **self._parameters)
      return result
  @_dispatch
  def _median_dispatch(self, method=None, **params):
      # Chooses the implementation: a median formula if the subclass
      # overrides one, otherwise the inverse CDF evaluated at 0.5.
      if self._overrides('_median_formula'):
          return self._median_formula
      return self._median_icdf
  1985. def _median_formula(self, **params):
  1986. raise NotImplementedError(self._not_implemented)
  1987. def _median_icdf(self, **params):
  1988. return self._icdf_dispatch(np.asarray(0.5, dtype=self._dtype), **params)
  @_set_invalid_nan_property
  def mode(self, *, method=None):
      # Public entry point: dispatches with the stored parameters.
      result = self._mode_dispatch(method=method, **self._parameters)
      return result
  @_dispatch
  def _mode_dispatch(self, method=None, **params):
      # Chooses the implementation: a mode formula if the subclass overrides
      # one, otherwise numerical optimization.
      # We could add a method that looks for a critical point with
      # differentiation and the root finder
      if self._overrides('_mode_formula'):
          return self._mode_formula
      return self._mode_optimization
  2001. def _mode_formula(self, **params):
  2002. raise NotImplementedError(self._not_implemented)
  2003. def _mode_optimization(self, xatol=None, **params):
  2004. if not self._size:
  2005. return np.empty(self._shape, dtype=self._dtype)
  2006. a, b = self._support(**params)
  2007. m = self._median_dispatch(**params)
  2008. f, args = _kwargs2args(lambda x, **params: -self._pxf_dispatch(x, **params),
  2009. args=(), kwargs=params)
  2010. res_b = _bracket_minimum(f, m, xmin=a, xmax=b, args=args)
  2011. res = _chandrupatla_minimize(f, res_b.xl, res_b.xm, res_b.xr,
  2012. args=args, xatol=xatol)
  2013. mode = np.asarray(res.x)
  2014. mode_at_boundary = res_b.status == -1
  2015. mode_at_left = mode_at_boundary & (res_b.fl <= res_b.fm)
  2016. mode_at_right = mode_at_boundary & (res_b.fr < res_b.fm)
  2017. mode[mode_at_left] = a[mode_at_left]
  2018. mode[mode_at_right] = b[mode_at_right]
  2019. return mode[()]
  def mean(self, *, method=None):
      # The mean is the first raw moment.
      order = 1
      return self.moment(order, kind='raw', method=method)
  def variance(self, *, method=None):
      # The variance is the second central moment.
      order = 2
      return self.moment(order, kind='central', method=method)
  def standard_deviation(self, *, method=None):
      # The standard deviation is the square root of the variance.
      var = self.variance(method=method)
      return np.sqrt(var)
  def skewness(self, *, method=None):
      # The skewness is the third standardized moment.
      order = 3
      return self.moment(order, kind='standardized', method=method)
  def kurtosis(self, *, method=None, convention='non-excess'):
      # The kurtosis is the fourth standardized moment; with the 'excess'
      # convention, 3 is subtracted. `convention` is matched
      # case-insensitively and anything else raises ValueError.
      conventions = {'non-excess', 'excess'}
      message = (f'Parameter `convention` of `{self.__class__.__name__}.kurtosis` '
                 f"must be one of {conventions}.")
      convention = convention.lower()
      if convention not in conventions:
          raise ValueError(message)

      k = self.moment(4, kind='standardized', method=method)
      if convention == 'excess':
          return k - 3
      return k
  2037. ### Distribution functions
  2038. # The following functions related to the distribution PDF and CDF are
  2039. # exposed via a public method that accepts one positional argument - the
  2040. # quantile - and keyword options (but not distribution parameters).
  2041. # logpdf, pdf
  2042. # logcdf, cdf
  2043. # logccdf, ccdf
  2044. # The `logcdf` and `cdf` functions can also be called with two positional
  2045. # arguments - lower and upper quantiles - and they return the probability
  2046. # mass (integral of the PDF) between them. The 2-arg versions of `logccdf`
  2047. # and `ccdf` return the complement of this quantity.
  2048. # All the (1-arg) cumulative distribution functions have inverse
  2049. # functions, which accept one positional argument - the percentile.
  2050. # ilogcdf, icdf
  2051. # ilogccdf, iccdf
  2052. # Common keyword options include:
  2053. # method - a string that indicates which method should be used to compute
  2054. # the quantity (e.g. a formula or numerical integration).
  2055. # Tolerance options should be added.
  2056. # Input/output validation is provided by the `_set_invalid_nan`
  2057. # decorator. These are the methods meant to be called by users.
  2058. #
  2059. # Each public method calls a private "dispatch" method that
  2060. # determines which "method" (strategy for calculating the desired quantity)
  2061. # to use by default and, via the `@_dispatch` decorator, calls the
  2062. # method and computes the result.
  2063. # Each dispatch method can designate the responsibility of computing
  2064. # the required value to any of several "implementation" methods. These
  2065. # methods accept only `**params`, the parameter dictionary passed from
  2066. # the public method via the dispatch method.
  2067. # See the note corresponding with the "Distribution Parameters" for more
  2068. # information.
  2069. ## Probability Density/Mass Functions
  @_set_invalid_nan
  def logpdf(self, x, /, *, method=None):
      # Public entry point: dispatches with the stored parameters.
      result = self._logpdf_dispatch(x, method=method, **self._parameters)
      return result
  @_dispatch
  def _logpdf_dispatch(self, x, *, method=None, **params):
      # Chooses the implementation: a log-PDF formula if the subclass
      # overrides one; otherwise, when `tol` is null, the log of the PDF.
      # In any other case `method` is returned as received.
      if self._overrides('_logpdf_formula'):
          return self._logpdf_formula
      if _isnull(self.tol):  # ensure that developers override _logpdf
          return self._logpdf_logexp
      return method
  2080. def _logpdf_formula(self, x, **params):
  2081. raise NotImplementedError(self._not_implemented)
  2082. def _logpdf_logexp(self, x, **params):
  2083. return np.log(self._pdf_dispatch(x, **params))
  @_set_invalid_nan
  def pdf(self, x, /, *, method=None):
      # Public entry point: dispatches with the stored parameters.
      result = self._pdf_dispatch(x, method=method, **self._parameters)
      return result
  @_dispatch
  def _pdf_dispatch(self, x, *, method=None, **params):
      # Chooses the implementation: a PDF formula if the subclass overrides
      # one, otherwise the exponential of the log-PDF.
      if self._overrides('_pdf_formula'):
          return self._pdf_formula
      return self._pdf_logexp
  2094. def _pdf_formula(self, x, **params):
  2095. raise NotImplementedError(self._not_implemented)
  2096. def _pdf_logexp(self, x, **params):
  2097. return np.exp(self._logpdf_dispatch(x, **params))
  @_set_invalid_nan
  def logpmf(self, x, /, *, method=None):
      # Public entry point: dispatches with the stored parameters.
      result = self._logpmf_dispatch(x, method=method, **self._parameters)
      return result
  @_dispatch
  def _logpmf_dispatch(self, x, *, method=None, **params):
      # Chooses the implementation: a log-PMF formula if the subclass
      # overrides one; otherwise, when `tol` is null, the log of the PMF.
      # In any other case `method` is returned as received.
      if self._overrides('_logpmf_formula'):
          return self._logpmf_formula
      if _isnull(self.tol):  # ensure that developers override _logpmf
          return self._logpmf_logexp
      return method
  2108. def _logpmf_formula(self, x, **params):
  2109. raise NotImplementedError(self._not_implemented)
  2110. def _logpmf_logexp(self, x, **params):
  2111. with np.errstate(divide='ignore'):
  2112. return np.log(self._pmf_dispatch(x, **params))
  @_set_invalid_nan
  def pmf(self, x, /, *, method=None):
      # Public entry point: dispatches with the stored parameters.
      result = self._pmf_dispatch(x, method=method, **self._parameters)
      return result
  @_dispatch
  def _pmf_dispatch(self, x, *, method=None, **params):
      # Chooses the implementation: a PMF formula if the subclass overrides
      # one, otherwise the exponential of the log-PMF.
      if self._overrides('_pmf_formula'):
          return self._pmf_formula
      return self._pmf_logexp
  2123. def _pmf_formula(self, x, **params):
  2124. raise NotImplementedError(self._not_implemented)
  2125. def _pmf_logexp(self, x, **params):
  2126. return np.exp(self._logpmf_dispatch(x, **params))
  2127. ## Cumulative Distribution Functions
  def logcdf(self, x, y=None, /, *, method=None):
      # One argument: the one-argument log-CDF at `x`. Two arguments: the
      # two-argument form, evaluated with `x` and `y`.
      if y is not None:
          return self._logcdf2(x, y, method=method)
      return self._logcdf1(x, method=method)
  @_cdf2_input_validation
  def _logcdf2(self, x, y, *, method):
      # Two-argument log-CDF; the result is always returned as complex.
      out = self._logcdf2_dispatch(x, y, method=method, **self._parameters)
      if np.issubdtype(out.dtype, np.complexfloating):
          return out
      return out + 0j
  @_dispatch
  def _logcdf2_dispatch(self, x, y, *, method=None, **params):
      # Chooses the implementation, in order of preference: a two-argument
      # log-CDF formula; subtraction in log space when a log-CDF or log-CCDF
      # formula is overridden; the log of the two-argument CDF (with
      # quadrature fallback) when a CDF or CCDF formula is overridden;
      # otherwise quadrature.
      # dtype is complex if any x > y, else real
      # Should revisit this logic.
      if self._overrides('_logcdf2_formula'):
          return self._logcdf2_formula
      if (self._overrides('_logcdf_formula')
              or self._overrides('_logccdf_formula')):
          return self._logcdf2_subtraction
      if (self._overrides('_cdf_formula')
              or self._overrides('_ccdf_formula')):
          return self._logcdf2_logexp_safe
      return self._logcdf2_quadrature
  2152. def _logcdf2_formula(self, x, y, **params):
  2153. raise NotImplementedError(self._not_implemented)
  def _logcdf2_subtraction(self, x, y, **params):
      # Log of the probability mass between `x` and `y`, assembled from the
      # one-argument log-CDF and log-CCDF. The arguments are sorted first;
      # where they were given in reverse order, `pi*1j` is added to the log.
      flip_sign = x > y  # some results will be negative
      lo, hi = np.minimum(x, y), np.maximum(x, y)

      logcdf_lo = self._logcdf_dispatch(lo, **params)
      logcdf_hi = self._logcdf_dispatch(hi, **params)
      logccdf_lo = self._logccdf_dispatch(lo, **params)
      logccdf_hi = self._logccdf_dispatch(hi, **params)

      # Three regimes, decided by whether both log-CDF values, or both
      # log-CCDF values, lie below -1.
      in_left = (logcdf_lo < -1) & (logcdf_hi < -1)
      in_right = (logccdf_lo < -1) & (logccdf_hi < -1)
      in_center = ~in_left & ~in_right

      # Default: difference of the CDFs, computed in log space.
      log_mass = _logexpxmexpy(logcdf_hi, logcdf_lo)
      # Right regime: difference of the complementary CDFs instead.
      log_mass[in_right] = _logexpxmexpy(logccdf_lo, logccdf_hi)[in_right]
      # Central regime: one minus the combined mass of the two tails.
      log_tail = np.logaddexp(logcdf_lo, logccdf_hi)[in_center]
      log_mass[in_center] = _log1mexp(log_tail)

      log_mass[flip_sign] += np.pi * 1j
      if np.any(flip_sign):
          return log_mass[()]
      return log_mass.real[()]
  2170. def _logcdf2_logexp(self, x, y, **params):
  2171. expres = self._cdf2_dispatch(x, y, **params)
  2172. expres = expres + 0j if np.any(x > y) else expres
  2173. return np.log(expres)
  2174. def _logcdf2_logexp_safe(self, x, y, **params):
  2175. out = self._logcdf2_logexp(x, y, **params)
  2176. mask = np.isinf(out.real)
  2177. if np.any(mask):
  2178. params_mask = {key: np.broadcast_to(val, mask.shape)[mask]
  2179. for key, val in params.items()}
  2180. out = np.asarray(out)
  2181. out[mask] = self._logcdf2_quadrature(x[mask], y[mask], **params_mask)
  2182. return out[()]
  2183. def _logcdf2_quadrature(self, x, y, **params):
  2184. logres = self._quadrature(self._logpxf_dispatch, limits=(x, y),
  2185. log=True, params=params)
  2186. return logres
  @_set_invalid_nan
  def _logcdf1(self, x, *, method=None):
      """One-argument log-CDF evaluated with this instance's parameters."""
      parameters = self._parameters
      return self._logcdf_dispatch(x, method=method, **parameters)
  @_dispatch
  def _logcdf_dispatch(self, x, *, method=None, **params):
      """Pick the default implementation of the one-argument log-CDF.

      Preference order: log-CDF formula, complement of the log-CCDF
      formula, logarithm of the CDF formula (with fallback), quadrature.
      """
      overridden = self._overrides
      if overridden('_logcdf_formula'):
          return self._logcdf_formula
      if overridden('_logccdf_formula'):
          return self._logcdf_complement
      if overridden('_cdf_formula'):
          return self._logcdf_logexp_safe
      return self._logcdf_quadrature
  2201. def _logcdf_formula(self, x, **params):
  2202. raise NotImplementedError(self._not_implemented)
  def _logcdf_complement(self, x, **params):
      """Log-CDF obtained by applying `_log1mexp` to the log-CCDF."""
      log_ccdf = self._logccdf_dispatch(x, **params)
      return _log1mexp(log_ccdf)
  2205. def _logcdf_logexp(self, x, **params):
  2206. return np.log(self._cdf_dispatch(x, **params))
  2207. def _logcdf_logexp_safe(self, x, **params):
  2208. out = self._logcdf_logexp(x, **params)
  2209. mask = np.isinf(out)
  2210. if np.any(mask):
  2211. params_mask = {key:np.broadcast_to(val, mask.shape)[mask]
  2212. for key, val in params.items()}
  2213. out = np.asarray(out)
  2214. out[mask] = self._logcdf_quadrature(x[mask], **params_mask)
  2215. return out[()]
  2216. def _logcdf_quadrature(self, x, **params):
  2217. a, _ = self._support(**params)
  2218. return self._quadrature(self._logpxf_dispatch, limits=(a, x),
  2219. params=params, log=True)
  def cdf(self, x, y=None, /, *, method=None):
      # One positional argument -> one-argument CDF; two -> the
      # two-argument form evaluated by `_cdf2`.
      if y is not None:
          return self._cdf2(x, y, method=method)
      return self._cdf1(x, method=method)
  @_cdf2_input_validation
  def _cdf2(self, x, y, *, method):
      """Two-argument CDF evaluated with this instance's parameters."""
      parameters = self._parameters
      return self._cdf2_dispatch(x, y, method=method, **parameters)
  @_dispatch
  def _cdf2_dispatch(self, x, y, *, method=None, **params):
      """Pick the default implementation of the two-argument CDF.

      Preference order: dedicated formula, exponential of the
      two-argument log-CDF, subtraction of one-argument values (with
      fallback), quadrature.
      """
      # Should revisit this logic.
      overridden = self._overrides
      if overridden('_cdf2_formula'):
          return self._cdf2_formula
      if overridden('_logcdf_formula') or overridden('_logccdf_formula'):
          return self._cdf2_logexp
      if overridden('_cdf_formula') or overridden('_ccdf_formula'):
          return self._cdf2_subtraction_safe
      return self._cdf2_quadrature
  2241. def _cdf2_formula(self, x, y, **params):
  2242. raise NotImplementedError(self._not_implemented)
  2243. def _cdf2_logexp(self, x, y, **params):
  2244. return np.real(np.exp(self._logcdf2_dispatch(x, y, **params)))
  2245. def _cdf2_subtraction(self, x, y, **params):
  2246. # Improvements:
  2247. # Lazy evaluation of cdf/ccdf only where needed
  2248. # Stack x and y to reduce function calls?
  2249. cdf_x = self._cdf_dispatch(x, **params)
  2250. cdf_y = self._cdf_dispatch(y, **params)
  2251. ccdf_x = self._ccdf_dispatch(x, **params)
  2252. ccdf_y = self._ccdf_dispatch(y, **params)
  2253. i = (ccdf_x < 0.5) & (ccdf_y < 0.5)
  2254. return np.where(i, ccdf_x-ccdf_y, cdf_y-cdf_x)
  def _cdf2_subtraction_safe(self, x, y, **params):
      """`_cdf2_subtraction` with a quadrature fallback for cancellation.

      An element is recomputed by `_cdf2_quadrature` when ``tol`` times
      its magnitude is below the floating-point spacing of the larger
      operand of the subtraction.  ``tol`` is ``self.tol`` unless that is
      null, in which case ``sqrt(eps)`` of the working dtype is used.
      """
      lower_cdf = self._cdf_dispatch(x, **params)
      upper_cdf = self._cdf_dispatch(y, **params)
      lower_ccdf = self._ccdf_dispatch(x, **params)
      upper_ccdf = self._ccdf_dispatch(y, **params)
      use_ccdf = (lower_ccdf < 0.5) & (upper_ccdf < 0.5)
      out = np.where(use_ccdf, lower_ccdf - upper_ccdf, upper_cdf - lower_cdf)

      eps = np.finfo(self._dtype).eps
      tol = self.tol if not _isnull(self.tol) else np.sqrt(eps)
      largest_cdf = np.maximum(lower_cdf, upper_cdf)
      largest_ccdf = np.maximum(lower_ccdf, upper_ccdf)
      resolution = np.spacing(np.where(use_ccdf, largest_ccdf, largest_cdf))
      unreliable = np.abs(tol * out) < resolution
      if not np.any(unreliable):
          return out[()]

      selected_params = {}
      for name, value in params.items():
          expanded = np.broadcast_to(value, unreliable.shape)
          selected_params[name] = expanded[unreliable]
      out = np.asarray(out)
      out[unreliable] = self._cdf2_quadrature(x[unreliable], y[unreliable],
                                              **selected_params)
      return out[()]
  2274. def _cdf2_quadrature(self, x, y, **params):
  2275. return self._quadrature(self._pxf_dispatch, limits=(x, y), params=params)
  @_set_invalid_nan
  def _cdf1(self, x, *, method):
      """One-argument CDF evaluated with this instance's parameters."""
      parameters = self._parameters
      return self._cdf_dispatch(x, method=method, **parameters)
  @_dispatch
  def _cdf_dispatch(self, x, *, method=None, **params):
      """Pick the default implementation of the one-argument CDF.

      Preference order: CDF formula, exponential of the log-CDF formula,
      complement of the CCDF formula (with fallback), quadrature.
      """
      overridden = self._overrides
      if overridden('_cdf_formula'):
          return self._cdf_formula
      if overridden('_logcdf_formula'):
          return self._cdf_logexp
      if overridden('_ccdf_formula'):
          return self._cdf_complement_safe
      return self._cdf_quadrature
  2290. def _cdf_formula(self, x, **params):
  2291. raise NotImplementedError(self._not_implemented)
  2292. def _cdf_logexp(self, x, **params):
  2293. return np.exp(self._logcdf_dispatch(x, **params))
  2294. def _cdf_complement(self, x, **params):
  2295. return 1 - self._ccdf_dispatch(x, **params)
  def _cdf_complement_safe(self, x, **params):
      """CDF as ``1 - ccdf`` with a quadrature fallback for cancellation.

      An element is recomputed by `_cdf_quadrature` when ``tol`` times
      the result is below the floating-point spacing of the CCDF value.
      ``tol`` is ``self.tol`` unless that is null, in which case
      ``sqrt(eps)`` of the working dtype is used.
      """
      ccdf = self._ccdf_dispatch(x, **params)
      out = 1 - ccdf
      eps = np.finfo(self._dtype).eps
      tol = self.tol if not _isnull(self.tol) else np.sqrt(eps)
      mask = tol * out < np.spacing(ccdf)
      if np.any(mask):
          params_mask = {key: np.broadcast_to(val, mask.shape)[mask]
                         for key, val in params.items()}
          out = np.asarray(out)
          # Forward the masked parameters by keyword: `*params_mask` would
          # pass the dict's *keys* as positional arguments, which
          # `_cdf_quadrature(self, x, **params)` cannot accept.
          out[mask] = self._cdf_quadrature(x[mask], **params_mask)
      return out[()]
  2308. def _cdf_quadrature(self, x, **params):
  2309. a, _ = self._support(**params)
  2310. return self._quadrature(self._pxf_dispatch, limits=(a, x),
  2311. params=params)
  def logccdf(self, x, y=None, /, *, method=None):
      # One positional argument -> one-argument log-CCDF; two -> the
      # two-argument form evaluated by `_logccdf2`.
      if y is not None:
          return self._logccdf2(x, y, method=method)
      return self._logccdf1(x, method=method)
  @_cdf2_input_validation
  def _logccdf2(self, x, y, *, method):
      """Two-argument log-CCDF evaluated with this instance's parameters."""
      parameters = self._parameters
      return self._logccdf2_dispatch(x, y, method=method, **parameters)
  @_dispatch
  def _logccdf2_dispatch(self, x, y, *, method=None, **params):
      """Pick the default implementation of the two-argument log-CCDF.

      A dedicated formula is preferred; otherwise the log-space addition
      of one-argument values is used.
      """
      # if _logccdf2_formula exists, we could use the complement
      # if _ccdf2_formula exists, we could use log/exp
      if not self._overrides('_logccdf2_formula'):
          return self._logccdf2_addition
      return self._logccdf2_formula
  2329. def _logccdf2_formula(self, x, y, **params):
  2330. raise NotImplementedError(self._not_implemented)
  2331. def _logccdf2_addition(self, x, y, **params):
  2332. logcdf_x = self._logcdf_dispatch(x, **params)
  2333. logccdf_y = self._logccdf_dispatch(y, **params)
  2334. return special.logsumexp([logcdf_x, logccdf_y], axis=0)
  @_set_invalid_nan
  def _logccdf1(self, x, *, method=None):
      """One-argument log-CCDF evaluated with this instance's parameters."""
      parameters = self._parameters
      return self._logccdf_dispatch(x, method=method, **parameters)
  @_dispatch
  def _logccdf_dispatch(self, x, method=None, **params):
      """Pick the default implementation of the one-argument log-CCDF.

      Preference order: log-CCDF formula, complement of the log-CDF
      formula, logarithm of the CCDF formula (with fallback), quadrature.
      """
      overridden = self._overrides
      if overridden('_logccdf_formula'):
          return self._logccdf_formula
      if overridden('_logcdf_formula'):
          return self._logccdf_complement
      if overridden('_ccdf_formula'):
          return self._logccdf_logexp_safe
      return self._logccdf_quadrature
  2349. def _logccdf_formula(self, x, **params):
  2350. raise NotImplementedError(self._not_implemented)
  def _logccdf_complement(self, x, **params):
      """Log-CCDF obtained by applying `_log1mexp` to the log-CDF."""
      log_cdf = self._logcdf_dispatch(x, **params)
      return _log1mexp(log_cdf)
  2353. def _logccdf_logexp(self, x, **params):
  2354. return np.log(self._ccdf_dispatch(x, **params))
  2355. def _logccdf_logexp_safe(self, x, **params):
  2356. out = self._logccdf_logexp(x, **params)
  2357. mask = np.isinf(out)
  2358. if np.any(mask):
  2359. params_mask = {key: np.broadcast_to(val, mask.shape)[mask]
  2360. for key, val in params.items()}
  2361. out = np.asarray(out)
  2362. out[mask] = self._logccdf_quadrature(x[mask], **params_mask)
  2363. return out[()]
  2364. def _logccdf_quadrature(self, x, **params):
  2365. _, b = self._support(**params)
  2366. return self._quadrature(self._logpxf_dispatch, limits=(x, b),
  2367. params=params, log=True)
  def ccdf(self, x, y=None, /, *, method=None):
      # One positional argument -> one-argument CCDF; two -> the
      # two-argument form evaluated by `_ccdf2`.
      if y is not None:
          return self._ccdf2(x, y, method=method)
      return self._ccdf1(x, method=method)
  @_cdf2_input_validation
  def _ccdf2(self, x, y, *, method):
      """Two-argument CCDF evaluated with this instance's parameters."""
      parameters = self._parameters
      return self._ccdf2_dispatch(x, y, method=method, **parameters)
  @_dispatch
  def _ccdf2_dispatch(self, x, y, *, method=None, **params):
      """Pick the default implementation of the two-argument CCDF.

      A dedicated formula is preferred; otherwise the addition of
      one-argument values is used.
      """
      if not self._overrides('_ccdf2_formula'):
          return self._ccdf2_addition
      return self._ccdf2_formula
  2383. def _ccdf2_formula(self, x, y, **params):
  2384. raise NotImplementedError(self._not_implemented)
  2385. def _ccdf2_addition(self, x, y, **params):
  2386. cdf_x = self._cdf_dispatch(x, **params)
  2387. ccdf_y = self._ccdf_dispatch(y, **params)
  2388. # even if x > y, cdf(x, y) + ccdf(x,y) sums to 1
  2389. return cdf_x + ccdf_y
  @_set_invalid_nan
  def _ccdf1(self, x, *, method):
      """One-argument CCDF evaluated with this instance's parameters."""
      parameters = self._parameters
      return self._ccdf_dispatch(x, method=method, **parameters)
  @_dispatch
  def _ccdf_dispatch(self, x, method=None, **params):
      """Pick the default implementation of the one-argument CCDF.

      Preference order: CCDF formula, exponential of the log-CCDF
      formula, complement of the CDF formula (with fallback), quadrature.
      """
      overridden = self._overrides
      if overridden('_ccdf_formula'):
          return self._ccdf_formula
      if overridden('_logccdf_formula'):
          return self._ccdf_logexp
      if overridden('_cdf_formula'):
          return self._ccdf_complement_safe
      return self._ccdf_quadrature
  2404. def _ccdf_formula(self, x, **params):
  2405. raise NotImplementedError(self._not_implemented)
  2406. def _ccdf_logexp(self, x, **params):
  2407. return np.exp(self._logccdf_dispatch(x, **params))
  2408. def _ccdf_complement(self, x, **params):
  2409. return 1 - self._cdf_dispatch(x, **params)
  def _ccdf_complement_safe(self, x, **params):
      """CCDF as ``1 - cdf`` with a quadrature fallback for cancellation.

      An element is recomputed by `_ccdf_quadrature` when ``tol`` times
      the result is below the floating-point spacing of the CDF value.
      ``tol`` is ``self.tol`` unless that is null, in which case
      ``sqrt(eps)`` of the working dtype is used.
      """
      cdf_values = self._cdf_dispatch(x, **params)
      out = 1 - cdf_values
      eps = np.finfo(self._dtype).eps
      tol = self.tol if not _isnull(self.tol) else np.sqrt(eps)
      unreliable = tol * out < np.spacing(cdf_values)
      if not np.any(unreliable):
          return out[()]
      selected_params = {}
      for name, value in params.items():
          expanded = np.broadcast_to(value, unreliable.shape)
          selected_params[name] = expanded[unreliable]
      out = np.asarray(out)
      out[unreliable] = self._ccdf_quadrature(x[unreliable],
                                              **selected_params)
      return out[()]
  2422. def _ccdf_quadrature(self, x, **params):
  2423. _, b = self._support(**params)
  2424. return self._quadrature(self._pxf_dispatch, limits=(x, b),
  2425. params=params)
  2426. ## Inverse cumulative distribution functions
  @_set_invalid_nan
  def ilogcdf(self, logp, /, *, method=None):
      # Inverse of the log-CDF, evaluated with this instance's parameters.
      parameters = self._parameters
      return self._ilogcdf_dispatch(logp, method=method, **parameters)
  @_dispatch
  def _ilogcdf_dispatch(self, x, method=None, **params):
      """Pick the default implementation of the inverse log-CDF.

      Preference order: inverse log-CDF formula, complement of the
      inverse log-CCDF formula, numerical inversion.
      """
      overridden = self._overrides
      if overridden('_ilogcdf_formula'):
          return self._ilogcdf_formula
      if overridden('_ilogccdf_formula'):
          return self._ilogcdf_complement
      return self._ilogcdf_inversion
  2439. def _ilogcdf_formula(self, x, **params):
  2440. raise NotImplementedError(self._not_implemented)
  def _ilogcdf_complement(self, x, **params):
      """Inverse log-CDF via the inverse log-CCDF at ``_log1mexp(x)``."""
      complementary_logp = _log1mexp(x)
      return self._ilogccdf_dispatch(complementary_logp, **params)
  2443. def _ilogcdf_inversion(self, x, **params):
  2444. return self._solve_bounded_continuous(self._logcdf_dispatch, x, params=params)
  @_set_invalid_nan
  def icdf(self, p, /, *, method=None):
      # Inverse of the CDF, evaluated with this instance's parameters.
      parameters = self._parameters
      return self._icdf_dispatch(p, method=method, **parameters)
  @_dispatch
  def _icdf_dispatch(self, x, method=None, **params):
      """Pick the default implementation of the inverse CDF.

      Preference order: inverse CDF formula, complement of the inverse
      CCDF formula (with fallback), numerical inversion.
      """
      overridden = self._overrides
      if overridden('_icdf_formula'):
          return self._icdf_formula
      if overridden('_iccdf_formula'):
          return self._icdf_complement_safe
      return self._icdf_inversion
  2457. def _icdf_formula(self, x, **params):
  2458. raise NotImplementedError(self._not_implemented)
  2459. def _icdf_complement(self, x, **params):
  2460. return self._iccdf_dispatch(1 - x, **params)
  def _icdf_complement_safe(self, x, **params):
      """Inverse CDF via the inverse CCDF with an inversion fallback.

      An element is recomputed by `_icdf_inversion` when ``tol * x`` is
      below the floating-point spacing of ``1 - x``.  ``tol`` is
      ``self.tol`` unless that is null, in which case ``sqrt(eps)`` of the
      working dtype is used.
      """
      out = self._icdf_complement(x, **params)
      eps = np.finfo(self._dtype).eps
      tol = self.tol if not _isnull(self.tol) else np.sqrt(eps)
      mask = tol * x < np.spacing(1 - x)
      if np.any(mask):
          params_mask = {key: np.broadcast_to(val, mask.shape)[mask]
                         for key, val in params.items()}
          out = np.asarray(out)
          # Forward the masked parameters by keyword: `*params_mask` would
          # pass the dict's *keys* as positional arguments, which
          # `_icdf_inversion(self, x, **params)` cannot accept.
          out[mask] = self._icdf_inversion(x[mask], **params_mask)
      return out[()]
  2472. def _icdf_inversion(self, x, **params):
  2473. return self._solve_bounded_continuous(self._cdf_dispatch, x, params=params)
  @_set_invalid_nan
  def ilogccdf(self, logp, /, *, method=None):
      # Inverse of the log-CCDF, evaluated with this instance's parameters.
      parameters = self._parameters
      return self._ilogccdf_dispatch(logp, method=method, **parameters)
  @_dispatch
  def _ilogccdf_dispatch(self, x, method=None, **params):
      """Pick the default implementation of the inverse log-CCDF.

      Preference order: inverse log-CCDF formula, complement of the
      inverse log-CDF formula, numerical inversion.
      """
      overridden = self._overrides
      if overridden('_ilogccdf_formula'):
          return self._ilogccdf_formula
      if overridden('_ilogcdf_formula'):
          return self._ilogccdf_complement
      return self._ilogccdf_inversion
  2486. def _ilogccdf_formula(self, x, **params):
  2487. raise NotImplementedError(self._not_implemented)
  def _ilogccdf_complement(self, x, **params):
      """Inverse log-CCDF via the inverse log-CDF at ``_log1mexp(x)``."""
      complementary_logp = _log1mexp(x)
      return self._ilogcdf_dispatch(complementary_logp, **params)
  2490. def _ilogccdf_inversion(self, x, **params):
  2491. return self._solve_bounded_continuous(self._logccdf_dispatch, x, params=params)
  @_set_invalid_nan
  def iccdf(self, p, /, *, method=None):
      # Inverse of the CCDF, evaluated with this instance's parameters.
      parameters = self._parameters
      return self._iccdf_dispatch(p, method=method, **parameters)
  @_dispatch
  def _iccdf_dispatch(self, x, method=None, **params):
      """Pick the default implementation of the inverse CCDF.

      Preference order: inverse CCDF formula, complement of the inverse
      CDF formula (with fallback), numerical inversion.
      """
      overridden = self._overrides
      if overridden('_iccdf_formula'):
          return self._iccdf_formula
      if overridden('_icdf_formula'):
          return self._iccdf_complement_safe
      return self._iccdf_inversion
  2504. def _iccdf_formula(self, x, **params):
  2505. raise NotImplementedError(self._not_implemented)
  2506. def _iccdf_complement(self, x, **params):
  2507. return self._icdf_dispatch(1 - x, **params)
  def _iccdf_complement_safe(self, x, **params):
      """Inverse CCDF via the inverse CDF with an inversion fallback.

      An element is recomputed by `_iccdf_inversion` when ``tol * x`` is
      below the floating-point spacing of ``1 - x``.  ``tol`` is
      ``self.tol`` unless that is null, in which case ``sqrt(eps)`` of the
      working dtype is used.
      """
      out = self._iccdf_complement(x, **params)
      eps = np.finfo(self._dtype).eps
      tol = self.tol if not _isnull(self.tol) else np.sqrt(eps)
      mask = tol * x < np.spacing(1 - x)
      if np.any(mask):
          params_mask = {key: np.broadcast_to(val, mask.shape)[mask]
                         for key, val in params.items()}
          out = np.asarray(out)
          # Forward the masked parameters by keyword: `*params_mask` would
          # pass the dict's *keys* as positional arguments, which
          # `_iccdf_inversion(self, x, **params)` cannot accept.
          out[mask] = self._iccdf_inversion(x[mask], **params_mask)
      return out[()]
  2519. def _iccdf_inversion(self, x, **params):
  2520. return self._solve_bounded_continuous(self._ccdf_dispatch, x, params=params)
  2521. ### Sampling Functions
  2522. # The following functions for drawing samples from the distribution are
  2523. # exposed via a public method that accepts one positional argument - the
  2524. # shape of the sample - and keyword options (but not distribution
  2525. # parameters).
  2526. # sample
  2527. # ~~qmc_sample~~ built into sample now
  2528. #
  2529. # Common keyword options include:
  2530. # method - a string that indicates which method should be used to compute
  2531. # the quantity (e.g. a formula or numerical integration).
  2532. # rng - the NumPy Generator/SciPy QMCEngine object to used for drawing numbers.
  2533. #
  2534. # Input/output validation is included in each function, since there is
  2535. # little code to be shared.
  2536. # These are the methods meant to be called by users.
  2537. #
  2538. # Each public method calls a private "dispatch" method that
  2539. # determines which "method" (strategy for calculating the desired quantity)
  2540. # to use by default and, via the `@_dispatch` decorator, calls the
  2541. # method and computes the result.
  2542. # Each dispatch method can designate the responsibility of sampling to any
  2543. # of several "implementation" methods. These methods accept only
  2544. # `**params`, the parameter dictionary passed from the public method via
  2545. # the "dispatch" method.
  2546. # See the note corresponding with the "Distribution Parameters" for more
  2547. # information.
  2548. # TODO:
  2549. # - should we accept a QRNG with `d != 1`?
  def sample(self, shape=(), *, method=None, rng=None):
      # needs output validation to ensure that developer returns correct
      # dtype and shape
      # A non-iterable `shape` is treated as a one-element shape.
      if np.iterable(shape):
          sample_shape = tuple(shape)
      else:
          sample_shape = (shape,)
      full_shape = sample_shape + self._shape
      # QMC engines are used as given; anything else is passed through
      # `np.random.default_rng`.
      if not isinstance(rng, qmc.QMCEngine):
          rng = np.random.default_rng(rng)
      draws = self._sample_dispatch(full_shape, method=method, rng=rng,
                                    **self._parameters)
      return draws.astype(self._dtype, copy=False)
  @_dispatch
  def _sample_dispatch(self, full_shape, *, method, rng, **params):
      """Pick the default sampling implementation.

      The sampling formula is used only when it is overridden and `rng`
      is not a QMC engine; otherwise inverse transform sampling is used.
      """
      # make sure that tests catch if sample is 0d array
      use_formula = (self._overrides('_sample_formula')
                     and not isinstance(rng, qmc.QMCEngine))
      return (self._sample_formula if use_formula
              else self._sample_inverse_transform)
  2567. def _sample_formula(self, full_shape, *, rng, **params):
  2568. raise NotImplementedError(self._not_implemented)
  2569. def _sample_inverse_transform(self, full_shape, *, rng, **params):
  2570. if isinstance(rng, qmc.QMCEngine):
  2571. uniform = self._qmc_uniform(full_shape, qrng=rng, **params)
  2572. else:
  2573. uniform = rng.random(size=full_shape, dtype=self._dtype)
  2574. return self._icdf_dispatch(uniform, **params)
  def _qmc_uniform(self, full_shape, *, qrng, **params):
      """Draw uniform variates of shape `full_shape` using fresh QMC engines.

      `qrng` supplies the engine class, its configuration and the RNG
      from which per-sequence RNGs are spawned.  `params` is accepted but
      not used in this method.
      """
      # Generate QMC uniform sample(s) on unit interval with specified shape;
      # if `sample_shape != ()`, then each slice along axis 0 is independent.
      sample_shape = full_shape[:len(full_shape)-len(self._shape)]
      # Determine the number of independent sequences and the length of each.
      n_low_discrepancy = sample_shape[0] if sample_shape else 1
      n_independent = math.prod(full_shape[1:] if sample_shape else full_shape)
      # For each independent sequence, we'll need a new QRNG of the appropriate class
      # with its own RNG. (If scramble=False, we don't really need all the separate
      # rngs, but I'm not going to add a special code path right now.)
      rngs = _rng_spawn(qrng.rng, n_independent)
      qrng_class = qrng.__class__
      kwargs = dict(d=1, scramble=qrng.scramble, optimization=qrng._optimization)
      if isinstance(qrng, qmc.Sobol):
          kwargs['bits'] = qrng.bits
      # Draw uniform low-discrepancy sequences scrambled with each RNG
      uniforms = []
      # NOTE: the loop rebinds the name `qrng`; the caller's engine object
      # is not used again after `kwargs` is built.
      for rng in rngs:
          qrng = qrng_class(seed=rng, **kwargs)
          uniform = qrng.random(n_low_discrepancy)
          uniform = uniform.reshape(n_low_discrepancy if sample_shape else ())[()]
          uniforms.append(uniform)
      # Reorder the axes and ensure that the shape is correct
      uniform = np.moveaxis(np.stack(uniforms), -1, 0) if uniforms else np.asarray([])
      return uniform.reshape(full_shape)
  2600. ### Moments
  2601. # The `moment` method accepts two positional arguments - the order and kind
  2602. # (raw, central, or standard) of the moment - and a keyword option:
  2603. # method - a string that indicates which method should be used to compute
  2604. # the quantity (e.g. a formula or numerical integration).
  2605. # Like the distribution properties, input/output validation is provided by
  2606. # the `_set_invalid_nan_property` decorator.
  2607. #
  2608. # Unlike most public methods above, `moment` dispatches to one of three
  2609. # private methods - one for each 'kind'. Like most *public* methods above,
  2610. # each of these private methods calls a private "dispatch" method that
  2611. # determines which "method" (strategy for calculating the desired quantity)
  2612. # to use. Also, each dispatch method can designate the responsibility
  2613. # computing the moment to one of several "implementation" methods.
  2614. # Unlike the dispatch methods above, however, the `@_dispatch` decorator
  2615. # is not used, and both logic and method calls are included in the function
  2616. # itself.
  2617. # Instead of determining which method will be used based solely on the
  2618. # implementation methods available and calling only the corresponding
  2619. # implementation method, *all* the implementation methods are called
  2620. # in sequence until one returns the desired information. When an
  2621. # implementation methods cannot provide the requested information, it
  2622. # returns the object None (which is distinct from arrays with NaNs or infs,
  2623. # which are valid values of moments).
  2624. # The reason for this approach is that although formulae for the first
  2625. # few moments of a distribution may be found, general formulae that work
  2626. # for all orders are not always easy to find. This approach allows the
  2627. # developer to write "formula" implementation functions that return the
  2628. # desired moment when it is available and None otherwise.
  2629. #
  2630. # Note that the first implementation method called is a cache. This is
  2631. # important because lower-order moments are often needed to compute
  2632. # higher moments from formulae, so we eliminate redundant calculations
  2633. # when moments of several orders are needed.
  2634. @cached_property
  2635. def _moment_methods(self):
  2636. return {'cache', 'formula', 'transform',
  2637. 'normalize', 'general', 'quadrature'}
  2638. @property
  2639. def _zero(self):
  2640. return self._constants()[0]
  2641. @property
  2642. def _one(self):
  2643. return self._constants()[1]
  def _constants(self):
      """Return ``[0, 1]`` passed through `_preserve_type`, cached if allowed.

      The result is stored on the instance unless the cache policy is
      `_NO_CACHE`.
      """
      cached = self._constant_cache
      if cached is not None:
          return cached

      fresh = self._preserve_type([0, 1])
      if self.cache_policy != _NO_CACHE:
          self._constant_cache = fresh
      return fresh
  @_set_invalid_nan_property
  def moment(self, order=1, kind='raw', *, method=None):
      # Route to the private implementation that matches `kind`;
      # `_validate_order_kind` checks the inputs and returns the order to
      # use.
      handlers = dict(raw=self._moment_raw,
                      central=self._moment_central,
                      standardized=self._moment_standardized)
      order = self._validate_order_kind(order, kind, handlers)
      return handlers[kind](order, method=method)
  2659. def _moment_raw(self, order=1, *, method=None):
  2660. """Raw distribution moment about the origin."""
  2661. # Consider exposing the point about which moments are taken as an
  2662. # option. This is easy to support, since `_moment_transform_center`
  2663. # does all the work.
  2664. methods = self._moment_methods if method is None else {method}
  2665. return self._moment_raw_dispatch(order, methods=methods, **self._parameters)
  def _moment_raw_dispatch(self, order, *, methods, **params):
      """Compute a raw moment with the first permitted method that succeeds.

      Methods named in `methods` are tried in a fixed order (cache,
      formula, transform, general, quadrature, quadrature_icdf) until one
      returns something other than None.  A successful result is cached
      unless the cache policy is `_NO_CACHE`.
      """
      value = None

      def pending(name):
          # True while no result has been found and `name` is permitted.
          return value is None and name in methods

      if pending('cache'):
          value = self._moment_raw_cache.get(order, None)
      if pending('formula'):
          value = self._moment_raw_formula(order, **params)
      if pending('transform') and order > 1:
          value = self._moment_raw_transform(order, **params)
      if pending('general'):
          value = self._moment_raw_general(order, **params)
      if pending('quadrature'):
          value = self._moment_from_pxf(order, center=self._zero, **params)
      if pending('quadrature_icdf'):
          value = self._moment_integrate_icdf(order, center=self._zero,
                                              **params)

      if value is not None and self.cache_policy != _NO_CACHE:
          self._moment_raw_cache[order] = value
      return value
  def _moment_raw_formula(self, order, **params):
      """Formula hook for raw moments; this default returns None."""
      return None
  2685. def _moment_raw_transform(self, order, **params):
  2686. central_moments = []
  2687. for i in range(int(order) + 1):
  2688. methods = {'cache', 'formula', 'normalize', 'general'}
  2689. moment_i = self._moment_central_dispatch(order=i, methods=methods, **params)
  2690. if moment_i is None:
  2691. return None
  2692. central_moments.append(moment_i)
  2693. # Doesn't make sense to get the mean by "transform", since that's
  2694. # how we got here. Questionable whether 'quadrature' should be here.
  2695. mean_methods = {'cache', 'formula', 'quadrature'}
  2696. mean = self._moment_raw_dispatch(self._one, methods=mean_methods, **params)
  2697. if mean is None:
  2698. return None
  2699. moment = self._moment_transform_center(order, central_moments, mean, self._zero)
  2700. return moment
  2701. def _moment_raw_general(self, order, **params):
  2702. # This is the only general formula for a raw moment of a probability
  2703. # distribution
  2704. return self._one if order == 0 else None
  2705. def _moment_central(self, order=1, *, method=None):
  2706. """Distribution moment about the mean."""
  2707. methods = self._moment_methods if method is None else {method}
  2708. return self._moment_central_dispatch(order, methods=methods, **self._parameters)
  def _moment_central_dispatch(self, order, *, methods, **params):
      """Compute a central moment with the first permitted method that succeeds.

      Methods named in `methods` are tried in a fixed order (cache,
      formula, transform, normalize, general, quadrature,
      quadrature_icdf) until one returns something other than None.  A
      successful result is cached unless the cache policy is `_NO_CACHE`.
      """
      value = None

      def pending(name):
          # True while no result has been found and `name` is permitted.
          return value is None and name in methods

      if pending('cache'):
          value = self._moment_central_cache.get(order, None)
      if pending('formula'):
          value = self._moment_central_formula(order, **params)
      if pending('transform'):
          value = self._moment_central_transform(order, **params)
      if pending('normalize') and order > 2:
          value = self._moment_central_normalize(order, **params)
      if pending('general'):
          value = self._moment_central_general(order, **params)
      if pending('quadrature'):
          center = self._moment_raw_dispatch(self._one, **params,
                                             methods=self._moment_methods)
          value = self._moment_from_pxf(order, center=center, **params)
      if pending('quadrature_icdf'):
          center = self._moment_raw_dispatch(self._one, **params,
                                             methods=self._moment_methods)
          value = self._moment_integrate_icdf(order, center=center,
                                              **params)

      if value is not None and self.cache_policy != _NO_CACHE:
          self._moment_central_cache[order] = value
      return value
  def _moment_central_formula(self, order, **params):
      """Formula hook for central moments; this default returns None."""
      return None
  2734. def _moment_central_transform(self, order, **params):
  2735. raw_moments = []
  2736. for i in range(int(order) + 1):
  2737. methods = {'cache', 'formula', 'general'}
  2738. moment_i = self._moment_raw_dispatch(order=i, methods=methods, **params)
  2739. if moment_i is None:
  2740. return None
  2741. raw_moments.append(moment_i)
  2742. mean_methods = self._moment_methods
  2743. mean = self._moment_raw_dispatch(self._one, methods=mean_methods, **params)
  2744. moment = self._moment_transform_center(order, raw_moments, self._zero, mean)
  2745. return moment
  2746. def _moment_central_normalize(self, order, **params):
  2747. methods = {'cache', 'formula', 'general'}
  2748. standard_moment = self._moment_standardized_dispatch(order, **params,
  2749. methods=methods)
  2750. if standard_moment is None:
  2751. return None
  2752. var = self._moment_central_dispatch(2, methods=self._moment_methods, **params)
  2753. return standard_moment*var**(order/2)
  2754. def _moment_central_general(self, order, **params):
  2755. general_central_moments = {0: self._one, 1: self._zero}
  2756. return general_central_moments.get(order, None)
  2757. def _moment_standardized(self, order=1, *, method=None):
  2758. """Standardized distribution moment."""
  2759. methods = self._moment_methods if method is None else {method}
  2760. return self._moment_standardized_dispatch(order, methods=methods,
  2761. **self._parameters)
  def _moment_standardized_dispatch(self, order, *, methods, **params):
      """Compute a standardized moment with the first permitted method that succeeds.

      Methods named in `methods` are tried in a fixed order: cache,
      formula, normalize without quadrature, general, then normalize with
      quadrature.  A successful result is cached unless the cache policy
      is `_NO_CACHE`.
      """
      value = None

      def pending(name):
          # True while no result has been found and `name` is permitted.
          return value is None and name in methods

      if pending('cache'):
          value = self._moment_standardized_cache.get(order, None)
      if pending('formula'):
          value = self._moment_standardized_formula(order, **params)
      if pending('normalize'):
          value = self._moment_standardized_normalize(order, False, **params)
      if pending('general'):
          value = self._moment_standardized_general(order, **params)
      if pending('normalize'):
          value = self._moment_standardized_normalize(order, True, **params)

      if value is not None and self.cache_policy != _NO_CACHE:
          self._moment_standardized_cache[order] = value
      return value
  def _moment_standardized_formula(self, order, **params):
      """Formula hook for standardized moments; this default returns None."""
      return None
  2779. def _moment_standardized_normalize(self, order, use_quadrature, **params):
  2780. methods = ({'quadrature'} if use_quadrature
  2781. else {'cache', 'formula', 'transform'})
  2782. central_moment = self._moment_central_dispatch(order, **params,
  2783. methods=methods)
  2784. if central_moment is None:
  2785. return None
  2786. var = self._moment_central_dispatch(2, methods=self._moment_methods,
  2787. **params)
  2788. return central_moment/var**(order/2)
  2789. def _moment_standardized_general(self, order, **params):
  2790. general_standard_moments = {0: self._one, 1: self._zero, 2: self._one}
  2791. return general_standard_moments.get(order, None)
  2792. def _moment_from_pxf(self, order, center, **params):
  2793. def integrand(x, order, center, **params):
  2794. pxf = self._pxf_dispatch(x, **params)
  2795. return pxf*(x-center)**order
  2796. return self._quadrature(integrand, args=(order, center), params=params)
  2797. def _moment_integrate_icdf(self, order, center, **params):
  2798. def integrand(x, order, center, **params):
  2799. x = self._icdf_dispatch(x, **params)
  2800. return (x-center)**order
  2801. return self._quadrature(integrand, limits=(0., 1.),
  2802. args=(order, center), params=params)
  2803. def _moment_transform_center(self, order, moment_as, a, b):
  2804. a, b, *moment_as = np.broadcast_arrays(a, b, *moment_as)
  2805. n = order
  2806. i = np.arange(n+1).reshape([-1]+[1]*a.ndim) # orthogonal to other axes
  2807. i = self._preserve_type(i)
  2808. n_choose_i = special.binom(n, i)
  2809. with np.errstate(invalid='ignore'): # can happen with infinite moment
  2810. moment_b = np.sum(n_choose_i*moment_as*(a-b)**(n-i), axis=0)
  2811. return moment_b
  2812. def _logmoment(self, order=1, *, logcenter=None, standardized=False):
  2813. # make this private until it is worked into moment
  2814. if logcenter is None or standardized is True:
  2815. logmean = self._logmoment_quad(self._one, -np.inf, **self._parameters)
  2816. else:
  2817. logmean = None
  2818. logcenter = logmean if logcenter is None else logcenter
  2819. res = self._logmoment_quad(order, logcenter, **self._parameters)
  2820. if standardized:
  2821. logvar = self._logmoment_quad(2, logmean, **self._parameters)
  2822. res = res - logvar * (order/2)
  2823. return res
  2824. def _logmoment_quad(self, order, logcenter, **params):
  2825. def logintegrand(x, order, logcenter, **params):
  2826. logpdf = self._logpxf_dispatch(x, **params)
  2827. return logpdf + order * _logexpxmexpy(np.log(x + 0j), logcenter)
  2828. ## if logx == logcenter, `_logexpxmexpy` returns (-inf + 0j)
  2829. ## multiplying by order produces (-inf + nan j) - bad
  2830. ## We're skipping logmoment tests, so we might don't need to fix
  2831. ## now, but if we ever do use run them, this might help:
  2832. # logx = np.log(x+0j)
  2833. # out = np.asarray(logpdf + order*_logexpxmexpy(logx, logcenter))
  2834. # i = (logx == logcenter)
  2835. # out[i] = logpdf[i]
  2836. # return out
  2837. return self._quadrature(logintegrand, args=(order, logcenter),
  2838. params=params, log=True)
  2839. ### Convenience
  def plot(self, x='x', y=None, *, t=None, ax=None):
      r"""Plot a function of the distribution.

      Convenience function for quick visualization of the distribution
      underlying the random variable.

      Parameters
      ----------
      x, y : str, optional
          String indicating the quantities to be used as the abscissa and
          ordinate (horizontal and vertical coordinates), respectively.
          Defaults are ``'x'`` (the domain of the random variable) and either
          ``'pdf'`` (the probability density function) (continuous) or
          ``'pmf'`` (the probability mass function) (discrete).
          Valid values are:
          'x', 'pdf', 'pmf', 'cdf', 'ccdf', 'icdf', 'iccdf', 'logpdf', 'logpmf',
          'logcdf', 'logccdf', 'ilogcdf', 'ilogccdf'.
      t : 3-tuple of (str, float, float), optional
          Tuple indicating the limits within which the quantities are plotted.
          The default is ``('cdf', 0.0005, 0.9995)`` if the domain is infinite,
          indicating that the central 99.9% of the distribution is to be shown;
          otherwise, endpoints of the support are used where they are finite.
          Valid values are:
          'x', 'cdf', 'ccdf', 'icdf', 'iccdf', 'logcdf', 'logccdf',
          'ilogcdf', 'ilogccdf'.
      ax : `matplotlib.axes`, optional
          Axes on which to generate the plot. If not provided, use the
          current axes.

      Returns
      -------
      ax : `matplotlib.axes`
          Axes on which the plot was generated.
          The plot can be customized by manipulating this object.

      Raises
      ------
      ValueError
          If `x`, `y`, or the first element of `t` is not one of the valid
          values, if any distribution parameter is invalid, if the
          parameters have more than one dimension, or if the quantile
          limits derived from `t` are not finite.
      ModuleNotFoundError
          If `matplotlib` is not installed.

      Examples
      --------
      Instantiate a distribution with the desired parameters:

      >>> import numpy as np
      >>> import matplotlib.pyplot as plt
      >>> from scipy import stats
      >>> X = stats.Normal(mu=1., sigma=2.)

      Plot the PDF over the central 99.9% of the distribution.
      Compare against a histogram of a random sample.

      >>> ax = X.plot()
      >>> sample = X.sample(10000)
      >>> ax.hist(sample, density=True, bins=50, alpha=0.5)
      >>> plt.show()

      Plot ``logpdf(x)`` as a function of ``x`` in the left tail,
      where the log of the CDF is between -10 and ``np.log(0.5)``.

      >>> X.plot('x', 'logpdf', t=('logcdf', -10, np.log(0.5)))
      >>> plt.show()

      Plot the PDF of the normal distribution as a function of the
      CDF for various values of the scale parameter.

      >>> X = stats.Normal(mu=0., sigma=[0.5, 1., 2])
      >>> X.plot('cdf', 'pdf')
      >>> plt.show()

      """

      # Strategy: given t limits, get quantile limits. Form grid of
      # quantiles, compute requested x and y at quantiles, and plot.
      # Currently, the grid of quantiles is always linearly spaced.
      # Instead of always computing linearly-spaced quantiles, it
      # would be better to choose:
      # a) quantiles or probabilities
      # b) linearly or logarithmically spaced
      # based on the specified `t`.
      # TODO:
      # - smart spacing of points
      # - when the parameters of the distribution are an array,
      #   use the full range of abscissae for all curves

      discrete = isinstance(self, DiscreteDistribution)
      t_is_quantile = {'x', 'icdf', 'iccdf', 'ilogcdf', 'ilogccdf'}
      t_is_probability = {'cdf', 'ccdf', 'logcdf', 'logccdf'}
      valid_t = t_is_quantile.union(t_is_probability)
      valid_xy = valid_t.union({'pdf', 'logpdf', 'pmf', 'logpmf'})
      y_default = 'pmf' if discrete else 'pdf'
      y = y_default if y is None else y

      ndim = self._ndim
      x_name, y_name = x, y
      t_name = 'cdf' if t is None else t[0]

      # Default limits: use the exact endpoint of the support on a side
      # where it is finite, otherwise trim 0.05% of probability there.
      a, b = self.support()
      tliml_default = 0 if np.all(np.isfinite(a)) else 0.0005
      tliml = tliml_default if t is None else t[1]
      tlimr_default = 1 if np.all(np.isfinite(b)) else 0.9995
      tlimr = tlimr_default if t is None else t[2]
      tlim = np.asarray([tliml, tlimr])
      # Add a trailing axis so the limits broadcast against 1-D parameters.
      tlim = tlim[:, np.newaxis] if ndim else tlim

      # pdf/logpdf are not valid for `t` because we can't easily invert them
      message = (f'Argument `t` of `{self.__class__.__name__}.plot` '
                 f'must be one of {valid_t}')
      if t_name not in valid_t:
          raise ValueError(message)

      message = (f'Argument `x` of `{self.__class__.__name__}.plot` '
                 f'must be one of {valid_xy}')
      if x_name not in valid_xy:
          raise ValueError(message)

      message = (f'Argument `y` of `{self.__class__.__name__}.plot` '
                 f'must be one of {valid_xy}')
      if y_name not in valid_xy:
          raise ValueError(message)

      # This could just be a warning
      message = (f'`{self.__class__.__name__}.plot` was called on a random '
                 'variable with at least one invalid shape parameters. When '
                 'a parameter is invalid, no plot can be shown.')
      if self._any_invalid:
          raise ValueError(message)

      # We could automatically ravel, but do we want to? For now, raise.
      message = ("To use `plot`, distribution parameters must be "
                 "scalars or arrays with one or fewer dimensions.")
      if ndim > 1:
          raise ValueError(message)

      # matplotlib is imported lazily, so it is only required when
      # `plot` is actually called.
      try:
          import matplotlib.pyplot as plt  # noqa: F401, E402
      except ModuleNotFoundError as exc:
          message = ("`matplotlib` must be installed to use "
                     f"`{self.__class__.__name__}.plot`.")
          raise ModuleNotFoundError(message) from exc
      ax = plt.gca() if ax is None else ax

      # get quantile limits given t limits
      qlim = tlim if t_name in t_is_quantile else getattr(self, 'i'+t_name)(tlim)

      message = (f"`{self.__class__.__name__}.plot` received invalid input for `t`: "
                 f"calling {'i'+t_name}({tlim}) produced {qlim}.")
      if not np.all(np.isfinite(qlim)):
          raise ValueError(message)

      # form quantile grid
      if discrete and x_name in t_is_quantile:
          # should probably aggregate for large ranges
          q = np.arange(np.min(qlim[0]), np.max(qlim[1]) + 1)
          q = q[:, np.newaxis] if ndim else q
      else:
          grid = np.linspace(0, 1, 300)
          grid = grid[:, np.newaxis] if ndim else grid
          q = qlim[0] + (qlim[1] - qlim[0]) * grid
          q = np.round(q) if discrete else q

      # compute requested x and y at quantile grid
      x = q if x_name in t_is_quantile else getattr(self, x_name)(q)
      y = q if y_name in t_is_quantile else getattr(self, y_name)(q)

      # make plot
      x, y = np.broadcast_arrays(x.T, np.atleast_2d(y.T))
      for xi, yi in zip(x, y):  # plot is vectorized, but bar/step don't seem to be
          if discrete and x_name in t_is_quantile and y_name == 'pmf':
              # should this just be a step plot, too?
              ax.bar(xi, yi, alpha=np.sqrt(1/y.shape[0]))  # alpha heuristic
          elif discrete and x_name in t_is_quantile:
              values = yi
              edges = np.concatenate((xi, [xi[-1]+1]))
              ax.stairs(values, edges, baseline=None)
          else:
              ax.plot(xi, yi)
      ax.set_xlabel(f"${x_name}$")
      ax.set_ylabel(f"${y_name}$")
      ax.set_title(str(self))

      # only need a legend if distribution has parameters
      if len(self._parameters):
          label = []
          parameters = self._parameterization.parameters
          param_names = list(parameters)
          param_arrays = [np.atleast_1d(self._parameters[pname])
                          for pname in param_names]
          # One legend entry per curve, listing every parameter value.
          for param_vals in zip(*param_arrays):
              assignments = [f"${parameters[name].symbol}$ = {val:.4g}"
                             for name, val in zip(param_names, param_vals)]
              label.append(", ".join(assignments))
          ax.legend(label)

      return ax
  3001. ### Fitting
  3002. # All methods above treat the distribution parameters as fixed, and the
  3003. # variable argument may be a quantile or probability. The fitting functions
  3004. # are fundamentally different because the quantiles (often observations)
  3005. # are considered to be fixed, and the distribution parameters are the
  3006. # variables. In a sense, they are like an inverse of the sampling
  3007. # functions.
  3008. #
  3009. # At first glance, it would seem ideal for `fit` to be a classmethod,
  3010. # called like `LogUniform.fit(sample=sample)`.
  3011. # I tried this. I insisted on it for a while. But if `fit` is a
  3012. # classmethod, it cannot call instance methods. If we want to support MLE,
  3013. # MPS, MoM, MoLM, then we end up with most of the distribution functions
  3014. # above needing to be classmethods, too. All state information, such as
  3015. # tolerances and the underlying distribution of `ShiftedScaledDistribution`
  3016. # and `OrderStatisticDistribution`, would need to be passed into all
  3017. # methods. And I'm not really sure how we would call `fit` as a
  3018. # classmethod of a transformed distribution - maybe
  3019. # ShiftedScaledDistribution.fit would accept the class of the
  3020. # shifted/scaled distribution as an argument?
  3021. #
  3022. # In any case, it was a conscious decision for the infrastructure to
  3023. # treat the parameters as "fixed" and the quantile/percentile arguments
  3024. # as "variable". There are a lot of advantages to this structure, and I
  3025. # don't think the fact that a few methods reverse the fixed and variable
  # quantities should make us question that choice. It can still accommodate
  3027. # these methods reasonably efficiently.
  class ContinuousDistribution(UnivariateDistribution):
      # Continuous specialization of `UnivariateDistribution`: the PMF is
      # identically zero, and the generic "pxf" hooks resolve to the PDF.

      def _overrides(self, method_name):
          """Report the PMF formulas as overridden; otherwise defer to the base."""
          always_overridden = {'_logpmf_formula', '_pmf_formula'}
          return (method_name in always_overridden
                  or super()._overrides(method_name))

      def _pmf_formula(self, x, **params):
          """Return zeros shaped like `x`."""
          return np.zeros_like(x)

      def _logpmf_formula(self, x, **params):
          """Return ``-inf`` shaped like `x`."""
          return np.full_like(x, -np.inf)

      def _pxf_dispatch(self, x, *, method=None, **params):
          """Forward to ``_pdf_dispatch``."""
          return self._pdf_dispatch(x, method=method, **params)

      def _logpxf_dispatch(self, x, *, method=None, **params):
          """Forward to ``_logpdf_dispatch``."""
          return self._logpdf_dispatch(x, method=method, **params)

      def _solve_bounded_continuous(self, func, p, params, xatol=None):
          """Run ``_solve_bounded`` and return the ``x`` attribute of its result."""
          solution = self._solve_bounded(func, p, params=params, xatol=xatol)
          return solution.x
  class DiscreteDistribution(UnivariateDistribution):
      # Discrete specialization of `UnivariateDistribution`: the generic
      # "pxf" hooks resolve to the PMF, quadrature-based CDFs are evaluated
      # at floored arguments, and inversion/mode finding return integers.

      def _overrides(self, method_name):
          """Report the PDF formulas as overridden; otherwise defer to the base."""
          if method_name in {'_logpdf_formula', '_pdf_formula'}:
              return True
          return super()._overrides(method_name)

      def _logpdf_formula(self, x, **params):
          """Return NaN where `x` (or the first parameter) is NaN, else ``inf``."""
          if params:
              # Only the first parameter is inspected for NaNs.
              p = next(iter(params.values()))
              nan_result = np.isnan(x) | np.isnan(p)
          else:
              nan_result = np.isnan(x)
          return np.where(nan_result, np.nan, np.inf)

      def _pdf_formula(self, x, **params):
          """Return NaN where `x` (or the first parameter) is NaN, else ``inf``."""
          if params:
              # Only the first parameter is inspected for NaNs.
              p = next(iter(params.values()))
              nan_result = np.isnan(x) | np.isnan(p)
          else:
              nan_result = np.isnan(x)
          return np.where(nan_result, np.nan, np.inf)

      def _pxf_dispatch(self, x, *, method=None, **params):
          """Forward to ``_pmf_dispatch``."""
          return self._pmf_dispatch(x, method=method, **params)

      def _logpxf_dispatch(self, x, *, method=None, **params):
          """Forward to ``_logpmf_dispatch``."""
          return self._logpmf_dispatch(x, method=method, **params)

      # The quadrature-based (log-)CDFs are evaluated at ``floor(x)``; the
      # complementary versions are evaluated at ``floor(x + 1)``.
      def _cdf_quadrature(self, x, **params):
          return super()._cdf_quadrature(np.floor(x), **params)

      def _logcdf_quadrature(self, x, **params):
          return super()._logcdf_quadrature(np.floor(x), **params)

      def _ccdf_quadrature(self, x, **params):
          return super()._ccdf_quadrature(np.floor(x + 1), **params)

      def _logccdf_quadrature(self, x, **params):
          return super()._logccdf_quadrature(np.floor(x + 1), **params)

      # Two-argument CDF variants are not implemented for discrete
      # distributions; each raises `NotImplementedError`.
      def _cdf2(self, x, y, *, method):
          raise NotImplementedError(
              "Two argument cdf functions are currently only supported for "
              "continuous distributions.")

      def _ccdf2(self, x, y, *, method):
          raise NotImplementedError(
              "Two argument cdf functions are currently only supported for "
              "continuous distributions.")

      def _logcdf2(self, x, y, *, method):
          raise NotImplementedError(
              "Two argument cdf functions are currently only supported for "
              "continuous distributions.")

      def _logccdf2(self, x, y, *, method):
          raise NotImplementedError(
              "Two argument cdf functions are currently only supported for "
              "continuous distributions.")

      def _solve_bounded_discrete(self, func, p, params, comp):
          """Solve the discrete inversion problem for `func` at level `p`.

          `func` is the (log-)CDF or (log-)CCDF, `params` is a dict of
          distribution parameters, and `comp` is the comparison (``>=`` for
          a CDF, ``<=`` for a CCDF) that the returned integer must satisfy
          as ``comp(func(x), p)``.
          """
          # We're trying to solve one of these two problems:
          # a) find the smallest integer x* within the support s.t. F(x*) >= p
          # b) find the smallest integer x* within the support s.t. G(x*) = 1 - F(x*) <= p
          # Our approach is to solve a continuous version of the problem that narrows the
          # solution down to an integer x s.t. either x* = x or x* = x + 1. At the end,
          # we'll choose between them.

          # First, solve func(x) == p where func is a continuous, monotone interpolant
          # of either the monotone increasing F or monotone decreasing G.
          res = self._solve_bounded(func, p, params=params, xatol=0.9)

          # Here, `_solve_bounded` can terminate for one of three reasons:
          # 1. `func(res.x) == p` (`fatol = 0` is satisfied),
          # 2. `res.xl` and `res.xr` bracket the root and `|res.xr - res.xl| <= xatol`, or
          # 3. There is no solution within the support.

          # There are several possible strategies for using `res.xl`, `res.x`, and/or
          # `res.xr` to find a solution to the original, discrete problem. Here is ours.

          # Consider case 2a. Because F is an increasing function, we know
          # that F(xr) >= p (and F(xl) <= p), so F(floor(xr) + 1) >= p.
          # F(floor(xr)) *may* be >= p, but we can't know until we evaluate it.
          # F(floor(xr) - 1) < p (strictly) because floor(xr) - 1 < xl and F decreases
          # monotonically as the argument decreases. So we choose x = floor(xr), and
          # later we'll choose between x* = x and x* = x + 1.
          x = np.asarray(np.floor(res.xr))

          # This is also suitable for case 2b. Because G is a *decreasing* function, we
          # know that G(xr) <= p (and G(xl) >= p), so G(floor(xr) + 1) <= p.
          # G(floor(xr)) *may* be <= p, but we can't know until we evaluate it.
          # G(floor(xr) - 1) > p (strictly) because floor(xr) - 1 < xl and G increases
          # as the argument decreases. So we would still want to choose x = floor(xr), and
          # later we'll choose between x* = x and x* = x + 1.

          # Now we consider case 1a/b. In this case, `res.x` solved the equation
          # *exactly*, so the algorithm may have terminated before the bracket is tight
          # enough to rely on `res.xr`. If `res.x` happens to be integral, `res.x` is
          # the solution to the discrete problem, and floor(res.x) == res.x, so
          # floor(res.x) is the solution to the discrete problem. If not:
          # a) F(floor(res.x)) < p (strictly) and F(floor(res.x) + 1) > p (strictly). So
          #    floor(res.x) + 1 is the solution to the discrete problem.
          # b) G(floor(res.x)) > p (strictly) and G(floor(res.x) + 1) < p (strictly). So
          #    floor(res.x) + 1 is again the solution to the discrete problem.
          # Either way, we can choose x = floor(res.x), and at the end we'll choose
          # between x* = x and x* = x + 1.
          mask = res.fun == 0
          x[mask] = np.floor(res.x[mask])

          # For case 3, let xmin be the left endpoint of the support, and note that in
          # general, F(xmin) > 0 and G(xmin) < 1. Therefore it is possible that:
          # a) F(x) > p for all x in the support (e.g. because p ~ 0)
          # b) G(x) < p for all x in the support (e.g. because p ~ 1)
          # In these cases, `_solve_bounded` would fail to find a root of the continuous
          # equation above, but the solution to the original, discrete problem is the left
          # endpoint of the support.
          # This case is handled before we get to this function; otherwise,
          # `_solve_bounded` may spin its wheels for a long time in vain.

          # Now, we choose between x* = x and x* = x + 1: if func(x) satisfies the
          # comparison `comp` (>= for cdf, <= for ccdf), the solution is x* = x;
          # otherwise the solution must be x* = x + 1.
          f = func(x, **params)
          x = np.where(comp(f, p), x, x + 1.0)
          x[np.isnan(f)] = np.nan  # needed? why would func(x) be NaN within support?
          return x

      def _base_discrete_inversion(self, p, func, comp, /, **params):
          """Shared implementation of the discrete inverse (log-)CDF/CCDF.

          `func` is the forward function to invert and `comp` the comparison
          that the result must satisfy; both are positional-only.
          """
          # For discrete distributions, icdf(p) is defined as the minimum integer x*
          # within the support such that F(x*) >= p; iccdf(p) is the minimum integer x*
          # within the support such that G(x*) <= p.

          # Identify where the solution is xmin.
          # (See rationale in `_solve_bounded_discrete`.)
          xmin, xmax = self._support(**params)
          p, xmin, _ = np.broadcast_arrays(p, xmin, xmax)
          mask = comp(func(xmin, **params), p)

          # Use `apply_where` to perform the inversion only when necessary.
          def f1(p, *args):
              # `args` arrive in the order of `params.values()`; rebuild the dict.
              return self._solve_bounded_discrete(
                  func, p, params=dict(zip(params.keys(), args)), comp=comp)

          x = xpx.apply_where(~mask, (p, *params.values()), f1, fill_value=xmin)

          # x above may be a finite value even when p is NaN, so the returned value
          # should be NaN. We need to handle this as a special case.
          x[np.isnan(p)] = np.nan
          return x[()]

      def _icdf_inversion(self, x, **params):
          """Invert ``_cdf_dispatch`` using the ``>=`` comparison."""
          return self._base_discrete_inversion(x, self._cdf_dispatch,
                                               np.greater_equal, **params)

      def _ilogcdf_inversion(self, x, **params):
          """Invert ``_logcdf_dispatch`` using the ``>=`` comparison."""
          return self._base_discrete_inversion(x, self._logcdf_dispatch,
                                               np.greater_equal, **params)

      def _iccdf_inversion(self, x, **params):
          """Invert ``_ccdf_dispatch`` using the ``<=`` comparison."""
          return self._base_discrete_inversion(x, self._ccdf_dispatch,
                                               np.less_equal, **params)

      def _ilogccdf_inversion(self, x, **params):
          """Invert ``_logccdf_dispatch`` using the ``<=`` comparison."""
          return self._base_discrete_inversion(x, self._logccdf_dispatch,
                                               np.less_equal, **params)

      def _mode_optimization(self, **params):
          """Refine the base class's mode estimate to the best nearby integer."""
          # If `x` is the true mode of a unimodal continuous function, we can find
          # the mode among integers by rounding in each direction and checking
          # which is better. If the difference between `x` and the nearest integer
          # is less than `xatol`, the computed value of `x` may end up on the wrong
          # side of the nearest integer. Setting `xatol=0.5` guarantees that at most
          # three integers need to be checked: ``floor(x)``, ``ceil(x)``, and the
          # integer one step beyond ``round(x)`` on the side away from `x`.
          x = super()._mode_optimization(xatol=0.5, **params)
          # NOTE(review): this uses the public `support()` rather than
          # `_support(**params)` - confirm that is intended.
          low, high = self.support()
          xl, xr = np.floor(x), np.ceil(x)
          nearest = np.round(x)
          # Clip to stay within support. There will be redundant calculation
          # when clipping since `xo` will be one of `xl` or `xr`, but let's
          # keep the implementation simple for now.
          xo = np.clip(nearest + np.copysign(1, nearest - x), low, high)
          x = np.stack([xl, xo, xr])
          # Keep, elementwise, whichever candidate has the largest PMF.
          idx = np.argmax(self._pmf_dispatch(x, **params), axis=0)
          return np.choose(idx, [xl, xo, xr])

      def _logentropy_quadrature(self, **params):
          """Log-entropy via ``self._quadrature`` called with ``log=True``."""
          def logintegrand(x, **params):
              logpmf = self._logpmf_dispatch(x, **params)
              # Entropy summand is -pmf*log(pmf), so log-entropy summand is
              # logpmf + log(logpmf) + pi*j. But pmf is always between 0 and 1,
              # so logpmf is always negative, and so log(logpmf) = log(-logpmf) + pi*j.
              # The two imaginary components "cancel" each other out (which we would
              # expect because each term of the entropy summand is positive).
              # Where logpmf is not finite, the summand's log is set to -inf.
              return np.where(np.isfinite(logpmf), logpmf + np.log(-logpmf), -np.inf)
          return self._quadrature(logintegrand, params=params, log=True)
  3207. # Special case the names of some new-style distributions in `make_distribution`
  3208. _distribution_names = {
  3209. # Continuous
  3210. 'argus': 'ARGUS',
  3211. 'betaprime': 'BetaPrime',
  3212. 'chi2': 'ChiSquared',
  3213. 'crystalball': 'CrystalBall',
  3214. 'dgamma': 'DoubleGamma',
  3215. 'dweibull': 'DoubleWeibull',
  3216. 'expon': 'Exponential',
  3217. 'exponnorm': 'ExponentiallyModifiedNormal',
  3218. 'exponweib': 'ExponentialWeibull',
  3219. 'exponpow': 'ExponentialPower',
  3220. 'fatiguelife': 'FatigueLife',
  3221. 'foldcauchy': 'FoldedCauchy',
  3222. 'foldnorm': 'FoldedNormal',
  3223. 'genlogistic': 'GeneralizedLogistic',
  3224. 'gennorm': 'GeneralizedNormal',
  3225. 'genpareto': 'GeneralizedPareto',
  3226. 'genexpon': 'GeneralizedExponential',
  3227. 'genextreme': 'GeneralizedExtremeValue',
  3228. 'gausshyper': 'GaussHypergeometric',
  3229. 'gengamma': 'GeneralizedGamma',
  3230. 'genhalflogistic': 'GeneralizedHalfLogistic',
  3231. 'geninvgauss': 'GeneralizedInverseGaussian',
  3232. 'gumbel_r': 'Gumbel',
  3233. 'gumbel_l': 'ReflectedGumbel',
  3234. 'halfcauchy': 'HalfCauchy',
  3235. 'halflogistic': 'HalfLogistic',
  3236. 'halfnorm': 'HalfNormal',
  3237. 'halfgennorm': 'HalfGeneralizedNormal',
  3238. 'hypsecant': 'HyperbolicSecant',
  3239. 'invgamma': 'InverseGammma',
  3240. 'invgauss': 'InverseGaussian',
  3241. 'invweibull': 'InverseWeibull',
  3242. 'irwinhall': 'IrwinHall',
  3243. 'jf_skew_t': 'JonesFaddySkewT',
  3244. 'johnsonsb': 'JohnsonSB',
  3245. 'johnsonsu': 'JohnsonSU',
  3246. 'ksone': 'KSOneSided',
  3247. 'kstwo': 'KSTwoSided',
  3248. 'kstwobign': 'KSTwoSidedAsymptotic',
  3249. 'laplace_asymmetric': 'LaplaceAsymmetric',
  3250. 'levy_l': 'LevyLeft',
  3251. 'levy_stable': 'LevyStable',
  3252. 'loggamma': 'ExpGamma', # really the Exponential Gamma Distribution
  3253. 'loglaplace': 'LogLaplace',
  3254. 'lognorm': 'LogNormal',
  3255. 'loguniform': 'LogUniform',
  3256. 'ncx2': 'NoncentralChiSquared',
  3257. 'nct': 'NoncentralT',
  3258. 'norm': 'Normal',
  3259. 'norminvgauss': 'NormalInverseGaussian',
  3260. 'powerlaw': 'PowerLaw',
  3261. 'powernorm': 'PowerNormal',
  3262. 'rdist': 'R',
  3263. 'rel_breitwigner': 'RelativisticBreitWigner',
  3264. 'recipinvgauss': 'ReciprocalInverseGaussian',
  3265. 'reciprocal': 'LogUniform',
  3266. 'semicircular': 'SemiCircular',
  3267. 'skewcauchy': 'SkewCauchy',
  3268. 'skewnorm': 'SkewNormal',
  3269. 'studentized_range': 'StudentizedRange',
  3270. 't': 'StudentT',
  3271. 'trapezoid': 'Trapezoidal',
  3272. 'triang': 'Triangular',
  3273. 'truncexpon': 'TruncatedExponential',
  3274. 'truncnorm': 'TruncatedNormal',
  3275. 'truncpareto': 'TruncatedPareto',
  3276. 'truncweibull_min': 'TruncatedWeibull',
  3277. 'tukeylambda': 'TukeyLambda',
  3278. 'vonmises_line': 'VonMisesLine',
  3279. 'weibull_min': 'Weibull',
  3280. 'weibull_max': 'ReflectedWeibull',
  3281. 'wrapcauchy': 'WrappedCauchyLine',
  3282. # Discrete
  3283. 'betabinom': 'BetaBinomial',
  3284. 'betanbinom': 'BetaNegativeBinomial',
  3285. 'dlaplace': 'LaplaceDiscrete',
  3286. 'geom': 'Geometric',
  3287. 'hypergeom': 'Hypergeometric',
  3288. 'logser': 'LogarithmicSeries',
  3289. 'nbinom': 'NegativeBinomial',
  3290. 'nchypergeom_fisher': 'NoncentralHypergeometricFisher',
  3291. 'nchypergeom_wallenius': 'NoncentralHypergeometricWallenius',
  3292. 'nhypergeom': 'NegativeHypergeometric',
  3293. 'poisson_binom': 'PoissonBinomial',
  3294. 'randint': 'UniformDiscrete',
  3295. 'yulesimon': 'YuleSimon',
  3296. 'zipf': 'Zeta',
  3297. }
  3298. # beta, genextreme, gengamma, t, tukeylambda need work for 1D arrays
  3299. @xp_capabilities(np_only=True)
  3300. def make_distribution(dist):
  3301. """Generate a `UnivariateDistribution` class from a compatible object
  3302. The argument may be an instance of `rv_continuous` or an instance of
  3303. another class that satisfies the interface described below.
  3304. The returned value is a `ContinuousDistribution` subclass if the input is an
  3305. instance of `rv_continuous` or a `DiscreteDistribution` subclass if the input
  3306. is an instance of `rv_discrete`. Like any subclass of `UnivariateDistribution`,
  3307. it must be instantiated (i.e. by passing all shape parameters as keyword
  3308. arguments) before use. Once instantiated, the resulting object will have the
  3309. same interface as any other instance of `UnivariateDistribution`; e.g.,
  3310. `scipy.stats.Normal`, `scipy.stats.Binomial`.
  3311. .. note::
  3312. `make_distribution` does not work perfectly with all instances of
  3313. `rv_continuous`. Known failures include `levy_stable`, `vonmises`,
  3314. `hypergeom`, 'nchypergeom_fisher', 'nchypergeom_wallenius', and
  3315. `poisson_binom`. Some methods of some distributions will not support
  3316. array shape parameters.
  3317. Parameters
  3318. ----------
  3319. dist : `rv_continuous`
  3320. Instance of `rv_continuous`, `rv_discrete`, or an instance of any class with
  3321. the following attributes:
  3322. __make_distribution_version__ : str
  3323. A string containing the version number of SciPy in which this interface
  3324. is defined. The preferred interface may change in future SciPy versions,
  3325. in which case support for an old interface version may be deprecated
  3326. and eventually removed.
  3327. parameters : dict or tuple
  3328. If a dictionary, each key is the name of a parameter,
  3329. and the corresponding value is either a dictionary or tuple.
  3330. If the value is a dictionary, it may have the following items, with default
  3331. values used for entries which aren't present.
  3332. endpoints : tuple, default: (-inf, inf)
  3333. A tuple defining the lower and upper endpoints of the domain of the
  3334. parameter; allowable values are floats, the name (string) of another
  3335. parameter, or a callable taking parameters as keyword only
  3336. arguments and returning the numerical value of an endpoint for
  3337. given parameter values.
  3338. inclusive : tuple of bool, default: (False, False)
  3339. A tuple specifying whether the endpoints are included within the domain
  3340. of the parameter.
  3341. typical : tuple, default: ``endpoints``
  3342. Defining endpoints of a typical range of values of a parameter. Can be
  3343. used for sampling parameter values for testing. Behaves like the
  3344. ``endpoints`` tuple above, and should define a subinterval of the
  3345. domain given by ``endpoints``.
  3346. A tuple value ``(a, b)`` associated to a key in the ``parameters``
  3347. dictionary is equivalent to ``{endpoints: (a, b)}``.
  3348. Custom distributions with multiple parameterizations can be defined by
  3349. having the ``parameters`` attribute be a tuple of dictionaries with
  3350. the structure described above. In this case, ``dist``\'s class must also
  3351. define a method ``process_parameters`` to map between the different
  3352. parameterizations. It must take all parameters from all parameterizations
  3353. as optional keyword arguments and return a dictionary mapping parameters to
  3354. values, filling in values from other parameterizations using values from
  3355. the supplied parameterization. See example.
  3356. support : dict or tuple
  3357. A dictionary describing the support of the distribution or a tuple
  3358. describing the endpoints of the support. This behaves identically to
  3359. the values of the parameters dict described above, except that the key
  3360. ``typical`` is ignored.
  3361. The class **must** also define a ``pdf`` method and **may** define methods
  3362. ``logentropy``, ``entropy``, ``median``, ``mode``, ``logpdf``,
  3363. ``logcdf``, ``cdf``, ``logccdf``, ``ccdf``,
  3364. ``ilogcdf``, ``icdf``, ``ilogccdf``, ``iccdf``,
  3365. ``moment``, and ``sample``.
  3366. If defined, these methods must accept the parameters of the distribution as
  3367. keyword arguments and also accept any positional-only arguments accepted by
  3368. the corresponding method of `ContinuousDistribution`.
  3369. When multiple parameterizations are defined, these methods must accept
  3370. all parameters from all parameterizations. The ``moment`` method
  3371. must accept the ``order`` and ``kind`` arguments by position or keyword, but
  3372. may return ``None`` if a formula is not available for the arguments; in this
  3373. case, the infrastructure will fall back to a default implementation. The
  3374. ``sample`` method must accept ``shape`` by position or keyword, but contrary
  3375. to the public method of the same name, the argument it receives will be the
  3376. *full* shape of the output array - that is, the shape passed to the public
  3377. method prepended to the broadcasted shape of random variable parameters.
  3378. Returns
  3379. -------
  3380. CustomDistribution : `UnivariateDistribution`
  3381. A subclass of `UnivariateDistribution` corresponding with `dist`. The
  3382. initializer requires all shape parameters to be passed as keyword arguments
  3383. (using the same names as the instance of `rv_continuous`/`rv_discrete`).
  3384. Notes
  3385. -----
  3386. The documentation of `UnivariateDistribution` is not rendered. See below for
  3387. an example of how to instantiate the class (i.e. pass all shape parameters of
  3388. `dist` to the initializer as keyword arguments). Documentation of all methods
  3389. is identical to that of `scipy.stats.Normal`. Use ``help`` on the returned
  3390. class or its methods for more information.
  3391. Examples
  3392. --------
  3393. >>> import numpy as np
  3394. >>> import matplotlib.pyplot as plt
  3395. >>> from scipy import stats
  3396. >>> from scipy import special
  3397. Create a `ContinuousDistribution` from `scipy.stats.loguniform`.
  3398. >>> LogUniform = stats.make_distribution(stats.loguniform)
  3399. >>> X = LogUniform(a=1.0, b=3.0)
  3400. >>> np.isclose((X + 0.25).median(), stats.loguniform.ppf(0.5, 1, 3, loc=0.25))
  3401. np.True_
  3402. >>> X.plot()
  3403. >>> sample = X.sample(10000, rng=np.random.default_rng())
  3404. >>> plt.hist(sample, density=True, bins=30)
  3405. >>> plt.legend(('pdf', 'histogram'))
  3406. >>> plt.show()
  3407. Create a custom distribution.
  3408. >>> class MyLogUniform:
  3409. ... @property
  3410. ... def __make_distribution_version__(self):
  3411. ... return "1.16.0"
  3412. ...
  3413. ... @property
  3414. ... def parameters(self):
  3415. ... return {'a': {'endpoints': (0, np.inf),
  3416. ... 'inclusive': (False, False)},
  3417. ... 'b': {'endpoints': ('a', np.inf),
  3418. ... 'inclusive': (False, False)}}
  3419. ...
  3420. ... @property
  3421. ... def support(self):
  3422. ... return {'endpoints': ('a', 'b'), 'inclusive': (True, True)}
  3423. ...
  3424. ... def pdf(self, x, a, b):
  3425. ... return 1 / (x * (np.log(b)- np.log(a)))
  3426. >>>
  3427. >>> MyLogUniform = stats.make_distribution(MyLogUniform())
  3428. >>> Y = MyLogUniform(a=1.0, b=3.0)
  3429. >>> np.isclose(Y.cdf(2.), X.cdf(2.))
  3430. np.True_
  3431. Create a custom distribution with variable support.
  3432. >>> class MyUniformCube:
  3433. ... @property
  3434. ... def __make_distribution_version__(self):
  3435. ... return "1.16.0"
  3436. ...
  3437. ... @property
  3438. ... def parameters(self):
  3439. ... return {"a": (-np.inf, np.inf),
  3440. ... "b": {'endpoints':('a', np.inf), 'inclusive':(True, False)}}
  3441. ...
  3442. ... @property
  3443. ... def support(self):
  3444. ... def left(*, a, b):
  3445. ... return a**3
  3446. ...
  3447. ... def right(*, a, b):
  3448. ... return b**3
  3449. ... return (left, right)
  3450. ...
  3451. ... def pdf(self, x, *, a, b):
  3452. ... return 1 / (3*(b - a)*np.cbrt(x)**2)
  3453. ...
  3454. ... def cdf(self, x, *, a, b):
  3455. ... return (np.cbrt(x) - a) / (b - a)
  3456. >>>
  3457. >>> MyUniformCube = stats.make_distribution(MyUniformCube())
  3458. >>> X = MyUniformCube(a=-2, b=2)
  3459. >>> Y = stats.Uniform(a=-2, b=2)**3
  3460. >>> X.support()
  3461. (-8.0, 8.0)
  3462. >>> np.isclose(X.cdf(2.1), Y.cdf(2.1))
  3463. np.True_
  3464. Create a custom distribution with multiple parameterizations. Here we create a
  3465. custom version of the beta distribution that has an alternative parameterization
  3466. in terms of the mean ``mu`` and a dispersion parameter ``nu``.
  3467. >>> class MyBeta:
  3468. ... @property
  3469. ... def __make_distribution_version__(self):
  3470. ... return "1.16.0"
  3471. ...
  3472. ... @property
  3473. ... def parameters(self):
  3474. ... return ({"a": (0, np.inf), "b": (0, np.inf)},
  3475. ... {"mu": (0, 1), "nu": (0, np.inf)})
  3476. ...
  3477. ... def process_parameters(self, a=None, b=None, mu=None, nu=None):
  3478. ... if a is not None and b is not None:
  3479. ... nu = a + b
  3480. ... mu = a / nu
  3481. ... else:
  3482. ... a = mu * nu
  3483. ... b = nu - a
  3484. ... return dict(a=a, b=b, mu=mu, nu=nu)
  3485. ...
  3486. ... @property
  3487. ... def support(self):
  3488. ... return {'endpoints': (0, 1)}
  3489. ...
  3490. ... def pdf(self, x, a, b, mu, nu):
  3491. ... return special._ufuncs._beta_pdf(x, a, b)
  3492. ...
  3493. ... def cdf(self, x, a, b, mu, nu):
  3494. ... return special.betainc(a, b, x)
  3495. >>>
  3496. >>> MyBeta = stats.make_distribution(MyBeta())
  3497. >>> X = MyBeta(a=2.0, b=2.0)
  3498. >>> Y = MyBeta(mu=0.5, nu=4.0)
  3499. >>> np.isclose(X.pdf(0.3), Y.pdf(0.3))
  3500. np.True_
  3501. """
  3502. if dist in {stats.levy_stable, stats.vonmises, stats.hypergeom,
  3503. stats.nchypergeom_fisher, stats.nchypergeom_wallenius,
  3504. stats.poisson_binom}:
  3505. raise NotImplementedError(f"`{dist.name}` is not supported.")
  3506. if isinstance(dist, stats.rv_continuous | stats.rv_discrete):
  3507. return _make_distribution_rv_generic(dist)
  3508. elif getattr(dist, "__make_distribution_version__", "0.0.0") >= "1.16.0":
  3509. return _make_distribution_custom(dist)
  3510. else:
  3511. message = ("The argument must be an instance of `rv_continuous`, "
  3512. "`rv_discrete`, or an instance of a class with attribute "
  3513. "`__make_distribution_version__ >= 1.16`.")
  3514. raise ValueError(message)
  3515. def _make_distribution_rv_generic(dist):
  3516. parameters = []
  3517. names = []
  3518. support = getattr(dist, '_support', (dist.a, dist.b))
  3519. for shape_info in dist._shape_info():
  3520. domain = _RealInterval(endpoints=shape_info.endpoints,
  3521. inclusive=shape_info.inclusive)
  3522. param = _RealParameter(shape_info.name, domain=domain)
  3523. parameters.append(param)
  3524. names.append(shape_info.name)
  3525. repr_str = _distribution_names.get(dist.name, dist.name.capitalize())
  3526. if isinstance(dist, stats.rv_continuous):
  3527. old_class, new_class = stats.rv_continuous, ContinuousDistribution
  3528. else:
  3529. old_class, new_class = stats.rv_discrete, DiscreteDistribution
  3530. def _overrides(method_name):
  3531. return (getattr(dist.__class__, method_name, None)
  3532. is not getattr(old_class, method_name, None))
  3533. if _overrides("_get_support"):
  3534. def left(**parameter_values):
  3535. a, _ = dist._get_support(**parameter_values)
  3536. return np.asarray(a)[()]
  3537. def right(**parameter_values):
  3538. _, b = dist._get_support(**parameter_values)
  3539. return np.asarray(b)[()]
  3540. endpoints = (left, right)
  3541. else:
  3542. endpoints = support
  3543. _x_support = _RealInterval(endpoints=endpoints, inclusive=(True, True))
  3544. _x_param = _RealParameter('x', domain=_x_support, typical=(-1, 1))
  3545. class CustomDistribution(new_class):
  3546. _parameterizations = ([_Parameterization(*parameters)] if parameters
  3547. else [])
  3548. _variable = _x_param
  3549. __class_getitem__ = None
  3550. def __repr__(self):
  3551. s = super().__repr__()
  3552. return s.replace('CustomDistribution', repr_str)
  3553. def __str__(self):
  3554. s = super().__str__()
  3555. return s.replace('CustomDistribution', repr_str)
  3556. def _sample_formula(self, full_shape=(), *, rng=None, **kwargs):
  3557. return dist._rvs(size=full_shape, random_state=rng, **kwargs)
  3558. def _moment_raw_formula(self, order, **kwargs):
  3559. return dist._munp(int(order), **kwargs)
  3560. def _moment_raw_formula_1(self, order, **kwargs):
  3561. if order != 1:
  3562. return None
  3563. return dist._stats(**kwargs)[0]
  3564. def _moment_central_formula(self, order, **kwargs):
  3565. if order != 2:
  3566. return None
  3567. return dist._stats(**kwargs)[1]
  3568. def _moment_standard_formula(self, order, **kwargs):
  3569. if order == 3:
  3570. if dist._stats_has_moments:
  3571. kwargs['moments'] = 's'
  3572. return dist._stats(**kwargs)[int(order - 1)]
  3573. elif order == 4:
  3574. if dist._stats_has_moments:
  3575. kwargs['moments'] = 'k'
  3576. k = dist._stats(**kwargs)[int(order - 1)]
  3577. return k if k is None else k + 3
  3578. else:
  3579. return None
  3580. methods = {'_logpdf': '_logpdf_formula',
  3581. '_pdf': '_pdf_formula',
  3582. '_logpmf': '_logpmf_formula',
  3583. '_pmf': '_pmf_formula',
  3584. '_logcdf': '_logcdf_formula',
  3585. '_cdf': '_cdf_formula',
  3586. '_logsf': '_logccdf_formula',
  3587. '_sf': '_ccdf_formula',
  3588. '_ppf': '_icdf_formula',
  3589. '_isf': '_iccdf_formula',
  3590. '_entropy': '_entropy_formula',
  3591. '_median': '_median_formula'}
  3592. # These are not desirable overrides for the new infrastructure
  3593. skip_override = {'norminvgauss': {'_sf', '_isf'}}
  3594. for old_method, new_method in methods.items():
  3595. if dist.name in skip_override and old_method in skip_override[dist.name]:
  3596. continue
  3597. # If method of old distribution overrides generic implementation...
  3598. method = getattr(dist.__class__, old_method, None)
  3599. super_method = getattr(old_class, old_method, None)
  3600. if method is not super_method:
  3601. # Make it an attribute of the new object with the new name
  3602. setattr(CustomDistribution, new_method, getattr(dist, old_method))
  3603. if _overrides('_munp'):
  3604. CustomDistribution._moment_raw_formula = _moment_raw_formula
  3605. if _overrides('_rvs'):
  3606. CustomDistribution._sample_formula = _sample_formula
  3607. if _overrides('_stats'):
  3608. CustomDistribution._moment_standardized_formula = _moment_standard_formula
  3609. if not _overrides('_munp'):
  3610. CustomDistribution._moment_raw_formula = _moment_raw_formula_1
  3611. CustomDistribution._moment_central_formula = _moment_central_formula
  3612. support_etc = _combine_docs(CustomDistribution, include_examples=False).lstrip()
  3613. docs = [
  3614. f"This class represents `scipy.stats.{dist.name}` as a subclass of "
  3615. f"`{new_class}`.",
  3616. f"The `repr`/`str` of class instances is `{repr_str}`.",
  3617. f"The PDF of the distribution is defined {support_etc}"
  3618. ]
  3619. CustomDistribution.__doc__ = ("\n".join(docs))
  3620. return CustomDistribution
  3621. def _get_domain_info(info):
  3622. domain_info = {"endpoints": info} if isinstance(info, tuple) else info
  3623. typical = domain_info.pop("typical", None)
  3624. return domain_info, typical
  3625. def _make_distribution_custom(dist):
  3626. dist_parameters = (
  3627. dist.parameters if isinstance(dist.parameters, tuple) else (dist.parameters, )
  3628. )
  3629. parameterizations = []
  3630. for parameterization in dist_parameters:
  3631. # The attribute name ``parameters`` appears reasonable from a user facing
  3632. # perspective, but there is a little tension here with the internal. It's
  3633. # important to keep in mind that the ``parameters`` attribute in a
  3634. # user-created custom distribution specifies ``_parameterizations`` within
  3635. # the infrastructure.
  3636. parameters = []
  3637. for name, info in parameterization.items():
  3638. domain_info, typical = _get_domain_info(info)
  3639. domain = _RealInterval(**domain_info)
  3640. param = _RealParameter(name, domain=domain, typical=typical)
  3641. parameters.append(param)
  3642. parameterizations.append(_Parameterization(*parameters) if parameters else [])
  3643. domain_info, _ = _get_domain_info(dist.support)
  3644. _x_support = _RealInterval(**domain_info)
  3645. _x_param = _RealParameter('x', domain=_x_support)
  3646. repr_str = dist.__class__.__name__
  3647. class CustomDistribution(ContinuousDistribution):
  3648. _parameterizations = parameterizations
  3649. _variable = _x_param
  3650. def __repr__(self):
  3651. s = super().__repr__()
  3652. return s.replace('CustomDistribution', repr_str)
  3653. def __str__(self):
  3654. s = super().__str__()
  3655. return s.replace('CustomDistribution', repr_str)
  3656. methods = {'sample', 'logentropy', 'entropy',
  3657. 'median', 'mode', 'logpdf', 'pdf',
  3658. 'logcdf2', 'logcdf', 'cdf2', 'cdf',
  3659. 'logccdf2', 'logccdf', 'ccdf2', 'ccdf',
  3660. 'ilogcdf', 'icdf', 'ilogccdf', 'iccdf'}
  3661. for method in methods:
  3662. if hasattr(dist, method):
  3663. # Make it an attribute of the new object with the new name
  3664. new_method = f"_{method}_formula"
  3665. setattr(CustomDistribution, new_method, getattr(dist, method))
  3666. if hasattr(dist, 'moment'):
  3667. def _moment_raw_formula(self, order, **kwargs):
  3668. return dist.moment(order, kind='raw', **kwargs)
  3669. def _moment_central_formula(self, order, **kwargs):
  3670. return dist.moment(order, kind='central', **kwargs)
  3671. def _moment_standardized_formula(self, order, **kwargs):
  3672. return dist.moment(order, kind='standardized', **kwargs)
  3673. CustomDistribution._moment_raw_formula = _moment_raw_formula
  3674. CustomDistribution._moment_central_formula = _moment_central_formula
  3675. CustomDistribution._moment_standardized_formula = _moment_standardized_formula
  3676. if hasattr(dist, 'process_parameters'):
  3677. setattr(
  3678. CustomDistribution,
  3679. "_process_parameters",
  3680. getattr(dist, "process_parameters")
  3681. )
  3682. support_etc = _combine_docs(CustomDistribution, include_examples=False).lstrip()
  3683. docs = [
  3684. f"This class represents `{repr_str}` as a subclass of "
  3685. "`ContinuousDistribution`.",
  3686. f"The PDF of the distribution is defined {support_etc}"
  3687. ]
  3688. CustomDistribution.__doc__ = ("\n".join(docs))
  3689. return CustomDistribution
  3690. # Rough sketch of how we might shift/scale distributions. The purpose of
  3691. # making it a separate class is for
  3692. # a) simplicity of the ContinuousDistribution class and
  3693. # b) avoiding the requirement that every distribution accept loc/scale.
  3694. # The simplicity of ContinuousDistribution is important, because there are
  3695. # several other distribution transformations to be supported; e.g., truncation,
  3696. # wrapping, folding, and doubling. We wouldn't want to cram all of this
  3697. # into the `ContinuousDistribution` class. Also, the order of the composition
  3698. # matters (e.g. truncate then shift/scale or vice versa). It's easier to
  3699. # accommodate different orders if the transformation is built up from
  3700. # components rather than all built into `ContinuousDistribution`.
  3701. def _shift_scale_distribution_function_2arg(func):
  3702. def wrapped(self, x, y, *args, loc, scale, sign, **kwargs):
  3703. item = func.__name__
  3704. f = getattr(self._dist, item)
  3705. # Obviously it's possible to get away with half of the work here.
  3706. # Let's focus on correct results first and optimize later.
  3707. xt = self._transform(x, loc, scale)
  3708. yt = self._transform(y, loc, scale)
  3709. fxy = f(xt, yt, *args, **kwargs)
  3710. fyx = f(yt, xt, *args, **kwargs)
  3711. return np.real_if_close(np.where(sign, fxy, fyx))[()]
  3712. return wrapped
  3713. def _shift_scale_distribution_function(func):
  3714. # c is for complementary
  3715. citem = {'_logcdf_dispatch': '_logccdf_dispatch',
  3716. '_cdf_dispatch': '_ccdf_dispatch',
  3717. '_logccdf_dispatch': '_logcdf_dispatch',
  3718. '_ccdf_dispatch': '_cdf_dispatch'}
  3719. def wrapped(self, x, *args, loc, scale, sign, **kwargs):
  3720. item = func.__name__
  3721. f = getattr(self._dist, item)
  3722. cf = getattr(self._dist, citem[item])
  3723. # Obviously it's possible to get away with half of the work here.
  3724. # Let's focus on correct results first and optimize later.
  3725. xt = self._transform(x, loc, scale)
  3726. fx = f(xt, *args, **kwargs)
  3727. cfx = cf(xt, *args, **kwargs)
  3728. return np.where(sign, fx, cfx)[()]
  3729. return wrapped
  3730. def _shift_scale_inverse_function(func):
  3731. citem = {'_ilogcdf_dispatch': '_ilogccdf_dispatch',
  3732. '_icdf_dispatch': '_iccdf_dispatch',
  3733. '_ilogccdf_dispatch': '_ilogcdf_dispatch',
  3734. '_iccdf_dispatch': '_icdf_dispatch'}
  3735. def wrapped(self, p, *args, loc, scale, sign, **kwargs):
  3736. item = func.__name__
  3737. f = getattr(self._dist, item)
  3738. cf = getattr(self._dist, citem[item])
  3739. # Obviously it's possible to get away with half of the work here.
  3740. # Let's focus on correct results first and optimize later.
  3741. fx = self._itransform(f(p, *args, **kwargs), loc, scale)
  3742. cfx = self._itransform(cf(p, *args, **kwargs), loc, scale)
  3743. return np.where(sign, fx, cfx)[()]
  3744. return wrapped
  3745. class TransformedDistribution(ContinuousDistribution):
  3746. def __init__(self, X, /, *args, **kwargs):
  3747. if not isinstance(X, ContinuousDistribution):
  3748. message = "Transformations are currently only supported for continuous RVs."
  3749. raise NotImplementedError(message)
  3750. self._copy_parameterization()
  3751. self._variable = X._variable
  3752. self._dist = X
  3753. if X._parameterization:
  3754. # Add standard distribution parameters to our parameterization
  3755. dist_parameters = X._parameterization.parameters
  3756. set_params = set(dist_parameters)
  3757. if not self._parameterizations:
  3758. self._parameterizations.append(_Parameterization())
  3759. for parameterization in self._parameterizations:
  3760. if set_params.intersection(parameterization.parameters):
  3761. message = (f"One or more of the parameters of {X} has "
  3762. "the same name as a parameter of "
  3763. f"{self.__class__.__name__}. Name collisions "
  3764. "create ambiguities and are not supported.")
  3765. raise ValueError(message)
  3766. parameterization.parameters.update(dist_parameters)
  3767. super().__init__(*args, **kwargs)
  3768. def _overrides(self, method_name):
  3769. return (self._dist._overrides(method_name)
  3770. or super()._overrides(method_name))
  3771. def reset_cache(self):
  3772. self._dist.reset_cache()
  3773. super().reset_cache()
  3774. def _update_parameters(self, *, validation_policy=None, **params):
  3775. # maybe broadcast everything before processing?
  3776. parameters = {}
  3777. # There may be some issues with _original_parameters
  3778. # We only want to update with _dist._original_parameters during
  3779. # initialization. Afterward that, we want to start with
  3780. # self._original_parameters.
  3781. parameters.update(self._dist._original_parameters)
  3782. parameters.update(params)
  3783. super()._update_parameters(validation_policy=validation_policy, **parameters)
  3784. def _process_parameters(self, **params):
  3785. return self._dist._process_parameters(**params)
  3786. def __repr__(self):
  3787. raise NotImplementedError()
  3788. def __str__(self):
  3789. raise NotImplementedError()
  3790. class TruncatedDistribution(TransformedDistribution):
  3791. """Truncated distribution."""
  3792. # TODO:
  3793. # - consider avoiding catastropic cancellation by using appropriate tail
  3794. # - if the mode of `_dist` is within the support, it's still the mode
  3795. # - rejection sampling might be more efficient than inverse transform
  3796. _lb_domain = _RealInterval(endpoints=(-inf, 'ub'), inclusive=(True, False))
  3797. _lb_param = _RealParameter('lb', symbol=r'b_l',
  3798. domain=_lb_domain, typical=(0.1, 0.2))
  3799. _ub_domain = _RealInterval(endpoints=('lb', inf), inclusive=(False, True))
  3800. _ub_param = _RealParameter('ub', symbol=r'b_u',
  3801. domain=_ub_domain, typical=(0.8, 0.9))
  3802. _parameterizations = [_Parameterization(_lb_param, _ub_param),
  3803. _Parameterization(_lb_param),
  3804. _Parameterization(_ub_param)]
  3805. def __init__(self, X, /, *args, lb=-np.inf, ub=np.inf, **kwargs):
  3806. return super().__init__(X, *args, lb=lb, ub=ub, **kwargs)
  3807. def _process_parameters(self, lb=None, ub=None, **params):
  3808. lb = lb if lb is not None else np.full_like(lb, -np.inf)[()]
  3809. ub = ub if ub is not None else np.full_like(ub, np.inf)[()]
  3810. parameters = self._dist._process_parameters(**params)
  3811. a, b = self._support(lb=lb, ub=ub, **parameters)
  3812. logmass = self._dist._logcdf2_dispatch(a, b, **parameters)
  3813. parameters.update(dict(lb=lb, ub=ub, _a=a, _b=b, logmass=logmass))
  3814. return parameters
  3815. def _support(self, lb, ub, **params):
  3816. a, b = self._dist._support(**params)
  3817. return np.maximum(a, lb), np.minimum(b, ub)
  3818. def _overrides(self, method_name):
  3819. return False
  3820. def _logpdf_dispatch(self, x, *args, lb, ub, _a, _b, logmass, **params):
  3821. logpdf = self._dist._logpdf_dispatch(x, *args, **params)
  3822. return logpdf - logmass
  3823. def _logcdf_dispatch(self, x, *args, lb, ub, _a, _b, logmass, **params):
  3824. logcdf = self._dist._logcdf2_dispatch(_a, x, *args, **params)
  3825. # of course, if this result is small we could compute with the other tail
  3826. return logcdf - logmass
  3827. def _logccdf_dispatch(self, x, *args, lb, ub, _a, _b, logmass, **params):
  3828. logccdf = self._dist._logcdf2_dispatch(x, _b, *args, **params)
  3829. return logccdf - logmass
  3830. def _logcdf2_dispatch(self, x, y, *args, lb, ub, _a, _b, logmass, **params):
  3831. logcdf2 = self._dist._logcdf2_dispatch(x, y, *args, **params)
  3832. return logcdf2 - logmass
  3833. def _ilogcdf_dispatch(self, logp, *args, lb, ub, _a, _b, logmass, **params):
  3834. log_Fa = self._dist._logcdf_dispatch(_a, *args, **params)
  3835. logp_adjusted = np.logaddexp(log_Fa, logp + logmass)
  3836. return self._dist._ilogcdf_dispatch(logp_adjusted, *args, **params)
  3837. def _ilogccdf_dispatch(self, logp, *args, lb, ub, _a, _b, logmass, **params):
  3838. log_cFb = self._dist._logccdf_dispatch(_b, *args, **params)
  3839. logp_adjusted = np.logaddexp(log_cFb, logp + logmass)
  3840. return self._dist._ilogccdf_dispatch(logp_adjusted, *args, **params)
  3841. def _icdf_dispatch(self, p, *args, lb, ub, _a, _b, logmass, **params):
  3842. Fa = self._dist._cdf_dispatch(_a, *args, **params)
  3843. p_adjusted = Fa + p*np.exp(logmass)
  3844. return self._dist._icdf_dispatch(p_adjusted, *args, **params)
  3845. def _iccdf_dispatch(self, p, *args, lb, ub, _a, _b, logmass, **params):
  3846. cFb = self._dist._ccdf_dispatch(_b, *args, **params)
  3847. p_adjusted = cFb + p*np.exp(logmass)
  3848. return self._dist._iccdf_dispatch(p_adjusted, *args, **params)
  3849. def __repr__(self):
  3850. with np.printoptions(threshold=10):
  3851. return (f"truncate({repr(self._dist)}, "
  3852. f"lb={repr(self.lb)}, ub={repr(self.ub)})")
  3853. def __str__(self):
  3854. with np.printoptions(threshold=10):
  3855. return (f"truncate({str(self._dist)}, "
  3856. f"lb={str(self.lb)}, ub={str(self.ub)})")
  3857. @xp_capabilities(np_only=True)
  3858. def truncate(X, lb=-np.inf, ub=np.inf):
  3859. """Truncate the support of a random variable.
  3860. Given a random variable `X`, `truncate` returns a random variable with
  3861. support truncated to the interval between `lb` and `ub`. The underlying
  3862. probability density function is normalized accordingly.
  3863. Parameters
  3864. ----------
  3865. X : `ContinuousDistribution`
  3866. The random variable to be truncated.
  3867. lb, ub : float array-like
  3868. The lower and upper truncation points, respectively. Must be
  3869. broadcastable with one another and the shape of `X`.
  3870. Returns
  3871. -------
  3872. X : `ContinuousDistribution`
  3873. The truncated random variable.
  3874. References
  3875. ----------
  3876. .. [1] "Truncated Distribution". *Wikipedia*.
  3877. https://en.wikipedia.org/wiki/Truncated_distribution
  3878. Examples
  3879. --------
  3880. Compare against `scipy.stats.truncnorm`, which truncates a standard normal,
  3881. *then* shifts and scales it.
  3882. >>> import numpy as np
  3883. >>> import matplotlib.pyplot as plt
  3884. >>> from scipy import stats
  3885. >>> loc, scale, lb, ub = 1, 2, -2, 2
  3886. >>> X = stats.truncnorm(lb, ub, loc, scale)
  3887. >>> Y = scale * stats.truncate(stats.Normal(), lb, ub) + loc
  3888. >>> x = np.linspace(-3, 5, 300)
  3889. >>> plt.plot(x, X.pdf(x), '-', label='X')
  3890. >>> plt.plot(x, Y.pdf(x), '--', label='Y')
  3891. >>> plt.xlabel('x')
  3892. >>> plt.ylabel('PDF')
  3893. >>> plt.title('Truncated, then Shifted/Scaled Normal')
  3894. >>> plt.legend()
  3895. >>> plt.show()
  3896. However, suppose we wish to shift and scale a normal random variable,
  3897. then truncate its support to given values. This is straightforward with
  3898. `truncate`.
  3899. >>> Z = stats.truncate(scale * stats.Normal() + loc, lb, ub)
  3900. >>> Z.plot()
  3901. >>> plt.show()
  3902. Furthermore, `truncate` can be applied to any random variable:
  3903. >>> Rayleigh = stats.make_distribution(stats.rayleigh)
  3904. >>> W = stats.truncate(Rayleigh(), lb=0.5, ub=3)
  3905. >>> W.plot()
  3906. >>> plt.show()
  3907. """
  3908. return TruncatedDistribution(X, lb=lb, ub=ub)
  3909. class ShiftedScaledDistribution(TransformedDistribution):
  3910. """Distribution with a standard shift/scale transformation."""
  3911. # Unclear whether infinite loc/scale will work reasonably in all cases
  3912. _loc_domain = _RealInterval(endpoints=(-inf, inf), inclusive=(True, True))
  3913. _loc_param = _RealParameter('loc', symbol=r'\mu',
  3914. domain=_loc_domain, typical=(1, 2))
  3915. _scale_domain = _RealInterval(endpoints=(-inf, inf), inclusive=(True, True))
  3916. _scale_param = _RealParameter('scale', symbol=r'\sigma',
  3917. domain=_scale_domain, typical=(0.1, 10))
  3918. _parameterizations = [_Parameterization(_loc_param, _scale_param),
  3919. _Parameterization(_loc_param),
  3920. _Parameterization(_scale_param)]
  3921. def _process_parameters(self, loc=None, scale=None, **params):
  3922. loc = loc if loc is not None else np.zeros_like(scale)[()]
  3923. scale = scale if scale is not None else np.ones_like(loc)[()]
  3924. sign = scale > 0
  3925. parameters = self._dist._process_parameters(**params)
  3926. parameters.update(dict(loc=loc, scale=scale, sign=sign))
  3927. return parameters
  3928. def _transform(self, x, loc, scale, **kwargs):
  3929. return (x - loc)/scale
  3930. def _itransform(self, x, loc, scale, **kwargs):
  3931. return x * scale + loc
  3932. def _support(self, loc, scale, sign, **params):
  3933. # Add shortcut for infinite support?
  3934. a, b = self._dist._support(**params)
  3935. a, b = self._itransform(a, loc, scale), self._itransform(b, loc, scale)
  3936. return np.where(sign, a, b)[()], np.where(sign, b, a)[()]
  3937. def __repr__(self):
  3938. with np.printoptions(threshold=10):
  3939. result = f"{repr(self.scale)}*{repr(self._dist)}"
  3940. if not self.loc.ndim and self.loc < 0:
  3941. result += f" - {repr(-self.loc)}"
  3942. elif (np.any(self.loc != 0)
  3943. or not np.can_cast(self.loc.dtype, self.scale.dtype)):
  3944. # We don't want to hide a zero array loc if it can cause
  3945. # a type promotion.
  3946. result += f" + {repr(self.loc)}"
  3947. return result
  3948. def __str__(self):
  3949. with np.printoptions(threshold=10):
  3950. result = f"{str(self.scale)}*{str(self._dist)}"
  3951. if not self.loc.ndim and self.loc < 0:
  3952. result += f" - {str(-self.loc)}"
  3953. elif (np.any(self.loc != 0)
  3954. or not np.can_cast(self.loc.dtype, self.scale.dtype)):
  3955. # We don't want to hide a zero array loc if it can cause
  3956. # a type promotion.
  3957. result += f" + {str(self.loc)}"
  3958. return result
  3959. # Here, we override all the `_dispatch` methods rather than the public
  3960. # methods or _function methods. Why not the public methods?
  3961. # If we were to override the public methods, then other
  3962. # TransformedDistribution classes (which could transform a
  3963. # ShiftedScaledDistribution) would need to call the public methods of
  3964. # ShiftedScaledDistribution, which would run the input validation again.
  3965. # Why not the _function methods? For distributions that rely on the
  3966. # default implementation of methods (e.g. `quadrature`, `inversion`),
  3967. # the implementation would "see" the location and scale like other
  3968. # distribution parameters, so they could affect the accuracy of the
  3969. # calculations. I think it is cleaner if `loc` and `scale` do not affect
  3970. # the underlying calculations at all.
  3971. def _entropy_dispatch(self, *args, loc, scale, sign, **params):
  3972. return (self._dist._entropy_dispatch(*args, **params)
  3973. + np.log(np.abs(scale)))
  3974. def _logentropy_dispatch(self, *args, loc, scale, sign, **params):
  3975. lH0 = self._dist._logentropy_dispatch(*args, **params)
  3976. lls = np.log(np.log(np.abs(scale))+0j)
  3977. return special.logsumexp(np.broadcast_arrays(lH0, lls), axis=0)
  3978. def _median_dispatch(self, *, method, loc, scale, sign, **params):
  3979. raw = self._dist._median_dispatch(method=method, **params)
  3980. return self._itransform(raw, loc, scale)
  3981. def _mode_dispatch(self, *, method, loc, scale, sign, **params):
  3982. raw = self._dist._mode_dispatch(method=method, **params)
  3983. return self._itransform(raw, loc, scale)
  3984. def _logpdf_dispatch(self, x, *args, loc, scale, sign, **params):
  3985. x = self._transform(x, loc, scale)
  3986. logpdf = self._dist._logpdf_dispatch(x, *args, **params)
  3987. return logpdf - np.log(np.abs(scale))
  3988. def _pdf_dispatch(self, x, *args, loc, scale, sign, **params):
  3989. x = self._transform(x, loc, scale)
  3990. pdf = self._dist._pdf_dispatch(x, *args, **params)
  3991. return pdf / np.abs(scale)
  3992. def _logpmf_dispatch(self, x, *args, loc, scale, sign, **params):
  3993. x = self._transform(x, loc, scale)
  3994. logpmf = self._dist._logpmf_dispatch(x, *args, **params)
  3995. return logpmf - np.log(np.abs(scale))
  3996. def _pmf_dispatch(self, x, *args, loc, scale, sign, **params):
  3997. x = self._transform(x, loc, scale)
  3998. pmf = self._dist._pmf_dispatch(x, *args, **params)
  3999. return pmf / np.abs(scale)
  4000. def _logpxf_dispatch(self, x, *args, loc, scale, sign, **params):
  4001. x = self._transform(x, loc, scale)
  4002. logpxf = self._dist._logpxf_dispatch(x, *args, **params)
  4003. return logpxf - np.log(np.abs(scale))
  4004. def _pxf_dispatch(self, x, *args, loc, scale, sign, **params):
  4005. x = self._transform(x, loc, scale)
  4006. pxf = self._dist._pxf_dispatch(x, *args, **params)
  4007. return pxf / np.abs(scale)
  4008. # Sorry about the magic. This is just a draft to show the behavior.
  4009. @_shift_scale_distribution_function
  4010. def _logcdf_dispatch(self, x, *, method=None, **params):
  4011. pass
  4012. @_shift_scale_distribution_function
  4013. def _cdf_dispatch(self, x, *, method=None, **params):
  4014. pass
  4015. @_shift_scale_distribution_function
  4016. def _logccdf_dispatch(self, x, *, method=None, **params):
  4017. pass
  4018. @_shift_scale_distribution_function
  4019. def _ccdf_dispatch(self, x, *, method=None, **params):
  4020. pass
  4021. @_shift_scale_distribution_function_2arg
  4022. def _logcdf2_dispatch(self, x, y, *, method=None, **params):
  4023. pass
  4024. @_shift_scale_distribution_function_2arg
  4025. def _cdf2_dispatch(self, x, y, *, method=None, **params):
  4026. pass
  4027. @_shift_scale_distribution_function_2arg
  4028. def _logccdf2_dispatch(self, x, y, *, method=None, **params):
  4029. pass
  4030. @_shift_scale_distribution_function_2arg
  4031. def _ccdf2_dispatch(self, x, y, *, method=None, **params):
  4032. pass
  4033. @_shift_scale_inverse_function
  4034. def _ilogcdf_dispatch(self, x, *, method=None, **params):
  4035. pass
  4036. @_shift_scale_inverse_function
  4037. def _icdf_dispatch(self, x, *, method=None, **params):
  4038. pass
  4039. @_shift_scale_inverse_function
  4040. def _ilogccdf_dispatch(self, x, *, method=None, **params):
  4041. pass
  4042. @_shift_scale_inverse_function
  4043. def _iccdf_dispatch(self, x, *, method=None, **params):
  4044. pass
  4045. def _moment_standardized_dispatch(self, order, *, loc, scale, sign, methods,
  4046. **params):
  4047. res = (self._dist._moment_standardized_dispatch(
  4048. order, methods=methods, **params))
  4049. return None if res is None else res * np.sign(scale)**order
  4050. def _moment_central_dispatch(self, order, *, loc, scale, sign, methods,
  4051. **params):
  4052. res = (self._dist._moment_central_dispatch(
  4053. order, methods=methods, **params))
  4054. return None if res is None else res * scale**order
  4055. def _moment_raw_dispatch(self, order, *, loc, scale, sign, methods,
  4056. **params):
  4057. raw_moments = []
  4058. methods_highest_order = methods
  4059. for i in range(int(order) + 1):
  4060. methods = (self._moment_methods if i < order
  4061. else methods_highest_order)
  4062. raw = self._dist._moment_raw_dispatch(i, methods=methods, **params)
  4063. if raw is None:
  4064. return None
  4065. moment_i = raw * scale**i
  4066. raw_moments.append(moment_i)
  4067. return self._moment_transform_center(
  4068. order, raw_moments, loc, self._zero)
  4069. def _sample_dispatch(self, full_shape, *,
  4070. rng, loc, scale, sign, method, **params):
  4071. rvs = self._dist._sample_dispatch(full_shape, method=method, rng=rng, **params)
  4072. return self._itransform(rvs, loc=loc, scale=scale, sign=sign, **params)
  4073. def __add__(self, loc):
  4074. return ShiftedScaledDistribution(self._dist, loc=self.loc + loc,
  4075. scale=self.scale)
  4076. def __sub__(self, loc):
  4077. return ShiftedScaledDistribution(self._dist, loc=self.loc - loc,
  4078. scale=self.scale)
  4079. def __mul__(self, scale):
  4080. return ShiftedScaledDistribution(self._dist,
  4081. loc=self.loc * scale,
  4082. scale=self.scale * scale)
  4083. def __truediv__(self, scale):
  4084. return ShiftedScaledDistribution(self._dist,
  4085. loc=self.loc / scale,
  4086. scale=self.scale / scale)
  class OrderStatisticDistribution(TransformedDistribution):
      r"""Probability distribution of an order statistic

      An instance of this class represents a random variable that follows the
      distribution underlying the :math:`r^{\text{th}}` order statistic of a
      sample of :math:`n` observations of a random variable :math:`X`.

      Parameters
      ----------
      dist : `ContinuousDistribution`
          The random variable :math:`X`
      n : array_like
          The (integer) sample size :math:`n`
      r : array_like
          The (integer) rank of the order statistic :math:`r`

      Notes
      -----
      If we make :math:`n` observations of a continuous random variable
      :math:`X` and sort them in increasing order
      :math:`X_{(1)}, \dots, X_{(r)}, \dots, X_{(n)}`,
      :math:`X_{(r)}` is known as the :math:`r^{\text{th}}` order statistic.

      If the PDF, CDF, and CCDF underlying :math:`X` are denoted :math:`f`,
      :math:`F`, and :math:`F'`, respectively, then the PDF underlying
      :math:`X_{(r)}` is given by:

      .. math::

          f_r(x) = \frac{n!}{(r-1)! (n-r)!} f(x) F(x)^{r-1} F'(x)^{n - r}

      The CDF and other methods of the distribution underlying :math:`X_{(r)}`
      are calculated using the fact that :math:`X = F^{-1}(U)`, where :math:`U` is
      a standard uniform random variable, and that the order statistics of
      observations of `U` follow a beta distribution, :math:`B(r, n - r + 1)`.

      References
      ----------
      .. [1] Order statistic. *Wikipedia*. https://en.wikipedia.org/wiki/Order_statistic

      Examples
      --------
      Suppose we are interested in order statistics of samples of size five drawn
      from the standard normal distribution. Plot the PDF underlying the fourth
      order statistic and compare with a normalized histogram from simulation.

      >>> import numpy as np
      >>> import matplotlib.pyplot as plt
      >>> from scipy import stats
      >>> from scipy.stats._distribution_infrastructure import OrderStatisticDistribution
      >>>
      >>> X = stats.Normal()
      >>> data = X.sample(shape=(10000, 5))
      >>> ranks = np.sort(data, axis=1)
      >>> Y = OrderStatisticDistribution(X, r=4, n=5)
      >>>
      >>> ax = plt.gca()
      >>> Y.plot(ax=ax)
      >>> ax.hist(ranks[:, 3], density=True, bins=30)
      >>> plt.show()

      """

      # These can be restricted to _IntegerInterval/_IntegerParameter in a separate
      # PR if desired.
      # The upper endpoint of the `r` domain is the parameter `n`, so the `n`
      # parameter has to be registered with the `r` domain once both exist.
      _r_domain = _RealInterval(endpoints=(1, 'n'), inclusive=(True, True))
      _r_param = _RealParameter('r', domain=_r_domain, typical=(1, 2))

      _n_domain = _RealInterval(endpoints=(1, np.inf), inclusive=(True, True))
      _n_param = _RealParameter('n', domain=_n_domain, typical=(1, 4))

      _r_domain.define_parameters(_n_param)

      _parameterizations = [_Parameterization(_r_param, _n_param)]

      def __init__(self, dist, /, *args, r, n, **kwargs):
          super().__init__(dist, *args, r=r, n=n, **kwargs)

      def _support(self, *args, r, n, **kwargs):
          # `r` and `n` are dropped; the support is that of the wrapped
          # distribution.
          return self._dist._support(*args, **kwargs)

      def _process_parameters(self, r=None, n=None, **params):
          # Let the wrapped distribution process its own parameters, then add
          # `r` and `n` to the resulting mapping.
          processed = self._dist._process_parameters(**params)
          processed.update(r=r, n=n)
          return processed

      def _overrides(self, method_name):
          # Only the formulas implemented below are reported as overridden.
          implemented = {'_logpdf_formula', '_pdf_formula',
                         '_cdf_formula', '_ccdf_formula',
                         '_icdf_formula', '_iccdf_formula'}
          return method_name in implemented

      def _logpdf_formula(self, x, r, n, **kwargs):
          base = self._dist
          log_norm = special.betaln(r, n - r + 1)
          log_density = base._logpdf_dispatch(x, **kwargs)
          # log-methods sometimes use complex dtype with 0 imaginary component,
          # but `_tanhsinh` doesn't accept complex limits of integration; take `real`.
          x_real = x.real
          log_cdf = base._logcdf_dispatch(x_real, **kwargs)
          log_ccdf = base._logccdf_dispatch(x_real, **kwargs)
          # A plain `(r-1)*log_cdf + (n-r)*log_ccdf` is problematic when the
          # exponent is 0 and the corresponding log-probability is -inf.
          # The PDF in these cases is 0^0, so those terms are replaced with
          # log(1) = 0.
          n_below, n_above = r - 1, n - r
          cdf_term = np.where((n_below == 0) & np.isneginf(log_cdf),
                              0, n_below*log_cdf)
          ccdf_term = np.where((n_above == 0) & np.isneginf(log_ccdf),
                               0, n_above*log_ccdf)
          return log_density + cdf_term + ccdf_term - log_norm

      def _pdf_formula(self, x, r, n, **kwargs):
          base = self._dist
          # 1 / norm = factorial(n) / (factorial(r-1) * factorial(n-r))
          norm = special.beta(r, n - r + 1)
          density = base._pdf_dispatch(x, **kwargs)
          cdf = base._cdf_dispatch(x, **kwargs)
          ccdf = base._ccdf_dispatch(x, **kwargs)
          return density * cdf**(r-1) * ccdf**(n-r) / norm

      def _cdf_formula(self, x, r, n, **kwargs):
          # Regularized incomplete beta function evaluated at the CDF of `x`
          # under the wrapped distribution.
          u = self._dist._cdf_dispatch(x, **kwargs)
          return special.betainc(r, n-r+1, u)

      def _ccdf_formula(self, x, r, n, **kwargs):
          # Complement of `_cdf_formula`, via `betaincc`.
          u = self._dist._cdf_dispatch(x, **kwargs)
          return special.betaincc(r, n-r+1, u)

      def _icdf_formula(self, p, r, n, **kwargs):
          # Invert the beta CDF first, then the wrapped distribution's CDF.
          u = special.betaincinv(r, n-r+1, p)
          return self._dist._icdf_dispatch(u, **kwargs)

      def _iccdf_formula(self, p, r, n, **kwargs):
          # Invert the beta CCDF first, then the wrapped distribution's CDF.
          u = special.betainccinv(r, n-r+1, p)
          return self._dist._icdf_dispatch(u, **kwargs)

      def __repr__(self):
          with np.printoptions(threshold=10):
              dist_text = repr(self._dist)
              r_text = repr(self.r)
              n_text = repr(self.n)
              return f"order_statistic({dist_text}, r={r_text}, n={n_text})"

      def __str__(self):
          with np.printoptions(threshold=10):
              dist_text = str(self._dist)
              r_text = str(self.r)
              n_text = str(self.n)
              return f"order_statistic({dist_text}, r={r_text}, n={n_text})"
  @xp_capabilities(np_only=True)
  def order_statistic(X, /, *, r, n):
      r"""Probability distribution of an order statistic

      Returns a random variable that follows the distribution underlying the
      :math:`r^{\text{th}}` order statistic of a sample of :math:`n`
      observations of a random variable :math:`X`.

      Parameters
      ----------
      X : `ContinuousDistribution`
          The random variable :math:`X`
      r : array_like
          The (positive integer) rank of the order statistic :math:`r`
      n : array_like
          The (positive integer) sample size :math:`n`

      Returns
      -------
      Y : `ContinuousDistribution`
          A random variable that follows the distribution of the prescribed
          order statistic.

      Raises
      ------
      ValueError
          If any element of `r` or `n` is not a positive integer (that is,
          is non-integral or less than one).

      Notes
      -----
      If we make :math:`n` observations of a continuous random variable
      :math:`X` and sort them in increasing order
      :math:`X_{(1)}, \dots, X_{(r)}, \dots, X_{(n)}`,
      :math:`X_{(r)}` is known as the :math:`r^{\text{th}}` order statistic.

      If the PDF, CDF, and CCDF underlying :math:`X` are denoted :math:`f`,
      :math:`F`, and :math:`F'`, respectively, then the PDF underlying
      :math:`X_{(r)}` is given by:

      .. math::

          f_r(x) = \frac{n!}{(r-1)! (n-r)!} f(x) F(x)^{r-1} F'(x)^{n - r}

      The CDF and other methods of the distribution underlying :math:`X_{(r)}`
      are calculated using the fact that :math:`X = F^{-1}(U)`, where :math:`U` is
      a standard uniform random variable, and that the order statistics of
      observations of `U` follow a beta distribution, :math:`B(r, n - r + 1)`.

      References
      ----------
      .. [1] Order statistic. *Wikipedia*. https://en.wikipedia.org/wiki/Order_statistic

      Examples
      --------
      Suppose we are interested in order statistics of samples of size five drawn
      from the standard normal distribution. Plot the PDF underlying each
      order statistic and compare with a normalized histogram from simulation.

      >>> import numpy as np
      >>> import matplotlib.pyplot as plt
      >>> from scipy import stats
      >>>
      >>> X = stats.Normal()
      >>> data = X.sample(shape=(10000, 5))
      >>> ranks = np.sort(data, axis=1)
      >>> Y = stats.order_statistic(X, r=[1, 2, 3, 4, 5], n=5)
      >>>
      >>> ax = plt.gca()
      >>> colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
      >>> for i in range(5):
      ...     y = ranks[:, i]
      ...     ax.hist(y, density=True, bins=30, alpha=0.1, color=colors[i])
      >>> Y.plot(ax=ax)
      >>> plt.show()

      """
      r, n = np.asarray(r), np.asarray(n)
      # "Positive" excludes zero, so the lower bound of the check is 1 rather
      # than 0; the error message below has always promised this.
      invalid_r = (r != np.floor(r)) | (r < 1)
      invalid_n = (n != np.floor(n)) | (n < 1)
      if np.any(invalid_r) or np.any(invalid_n):
          message = "`r` and `n` must contain only positive integers."
          raise ValueError(message)
      return OrderStatisticDistribution(X, r=r, n=n)
  class Mixture(_ProbabilityDistribution):
      r"""Representation of a mixture distribution.

      A mixture distribution is the distribution of a random variable
      defined in the following way: first, a random variable is selected
      from `components` according to the probabilities given by `weights`, then
      the selected random variable is realized.

      Parameters
      ----------
      components : sequence of `ContinuousDistribution`
          The underlying instances of `ContinuousDistribution`.
          All must have scalar shape parameters (if any); e.g., the `pdf` evaluated
          at a scalar argument must return a scalar.
      weights : sequence of floats, optional
          The corresponding probabilities of selecting each random variable.
          Must be non-negative and sum to one. The default behavior is to weight
          all components equally.

      Attributes
      ----------
      components : sequence of `ContinuousDistribution`
          The underlying instances of `ContinuousDistribution`.
      weights : ndarray
          The corresponding probabilities of selecting each random variable.

      Methods
      -------
      support
      sample
      moment
      mean
      median
      mode
      variance
      standard_deviation
      skewness
      kurtosis
      pdf
      logpdf
      cdf
      icdf
      ccdf
      iccdf
      logcdf
      ilogcdf
      logccdf
      ilogccdf
      entropy

      Notes
      -----
      The following abbreviations are used throughout the documentation.

      - PDF: probability density function
      - CDF: cumulative distribution function
      - CCDF: complementary CDF
      - entropy: differential entropy
      - log-*F*: logarithm of *F* (e.g. log-CDF)
      - inverse *F*: inverse function of *F* (e.g. inverse CDF)

      References
      ----------
      .. [1] Mixture distribution, *Wikipedia*,
             https://en.wikipedia.org/wiki/Mixture_distribution

      Examples
      --------
      A mixture of normal distributions:

      >>> import numpy as np
      >>> from scipy import stats
      >>> import matplotlib.pyplot as plt
      >>> X1 = stats.Normal(mu=-2, sigma=1)
      >>> X2 = stats.Normal(mu=2, sigma=1)
      >>> mixture = stats.Mixture([X1, X2], weights=[0.4, 0.6])
      >>> print(f'mean: {mixture.mean():.2f}, '
      ...       f'median: {mixture.median():.2f}, '
      ...       f'mode: {mixture.mode():.2f}')
      mean: 0.40, median: 1.04, mode: 2.00
      >>> x = np.linspace(-10, 10, 300)
      >>> plt.plot(x, mixture.pdf(x))
      >>> plt.title('PDF of normal distribution mixture')
      >>> plt.show()

      """
      # Todo:
      # Add support for array shapes, weights

      def _input_validation(self, components, weights):
          """Check `components` and `weights`; return them for storage.

          Raises ``ValueError`` if `components` is empty, if any component is
          not a `ContinuousDistribution` with shape ``()``, or if `weights`
          (when given) has the wrong length, a non-inexact dtype, does not sum
          to one, or contains a negative entry. When `weights` is given it is
          returned as an ndarray; ``None`` is returned unchanged.
          """
          if len(components) == 0:
              message = ("`components` must contain at least one random variable.")
              raise ValueError(message)

          for var in components:
              # will generalize to other kinds of distributions when there
              # *are* other kinds of distributions
              if not isinstance(var, ContinuousDistribution):
                  message = ("Each element of `components` must be an instance of "
                             "`ContinuousDistribution`.")
                  raise ValueError(message)
              if not var._shape == ():
                  message = "All elements of `components` must have scalar shapes."
                  raise ValueError(message)

          if weights is None:
              return components, weights

          weights = np.asarray(weights)
          if weights.shape != (len(components),):
              message = "`components` and `weights` must have the same length."
              raise ValueError(message)

          if not np.issubdtype(weights.dtype, np.inexact):
              message = "`weights` must have floating point dtype."
              raise ValueError(message)

          if not np.isclose(np.sum(weights), 1.0):
              message = "`weights` must sum to 1.0."
              raise ValueError(message)

          if not np.all(weights >= 0):
              message = "All `weights` must be non-negative."
              raise ValueError(message)

          return components, weights

      def __init__(self, components, *, weights=None):
          components, weights = self._input_validation(components, weights)
          n = len(components)
          # Result dtype and shape are derived from the components only.
          dtype = np.result_type(*(var._dtype for var in components))
          self._shape = np.broadcast_shapes(*(var._shape for var in components))
          self._dtype, self._components = dtype, components
          # Equal weighting is the default when no weights are supplied.
          self._weights = np.full(n, 1/n, dtype=dtype) if weights is None else weights
          self.validation_policy = None

      @property
      def components(self):
          # A new list is returned so callers cannot alter the stored sequence.
          return list(self._components)

      @property
      def weights(self):
          # A copy is returned so callers cannot alter the stored weights.
          return self._weights.copy()

      def _full(self, val, *args):
          """Array filled with `val`, broadcast/promoted against `args`."""
          args = [np.asarray(arg) for arg in args]
          dtype = np.result_type(self._dtype, *(arg.dtype for arg in args))
          shape = np.broadcast_shapes(self._shape, *(arg.shape for arg in args))
          return np.full(shape, val, dtype=dtype)

      def _sum(self, fun, *args):
          """Weighted sum over components of the method named `fun`."""
          out = self._full(0, *args)
          for var, weight in zip(self._components, self._weights):
              out += getattr(var, fun)(*args) * weight
          return out[()]

      def _logsum(self, fun, *args):
          """Log of the weighted sum, given component log-method `fun`."""
          out = self._full(-np.inf, *args)
          for var, log_weight in zip(self._components, np.log(self._weights)):
              np.logaddexp(out, getattr(var, fun)(*args) + log_weight, out=out)
          return out[()]

      def support(self):
          # Smallest lower endpoint and largest upper endpoint over all
          # components.
          a = self._full(np.inf)
          b = self._full(-np.inf)
          for var in self._components:
              lower, upper = var.support()
              a = np.minimum(a, lower)
              b = np.maximum(b, upper)
          return a, b

      def _raise_if_method(self, method):
          """Reject any explicit `method`; none are implemented here."""
          if method is not None:
              raise NotImplementedError("`method` not implemented for this distribution.")

      def logentropy(self, *, method=None):
          self._raise_if_method(method)

          def log_integrand(x):
              # `x` passed by `_tanhsinh` will be of complex dtype because
              # `log_integrand` returns complex values, but the imaginary
              # component is always zero. Extract the real part because
              # `logpdf` uses `logaddexp`, which fails for complex input.
              logpdf = self.logpdf(x.real)
              return logpdf + np.log(logpdf + 0j)

          res = _tanhsinh(log_integrand, *self.support(), log=True).integral
          return _log_real_standardize(res + np.pi*1j)

      def entropy(self, *, method=None):
          self._raise_if_method(method)
          return _tanhsinh(lambda x: -self.pdf(x) * self.logpdf(x),
                           *self.support()).integral

      def mode(self, *, method=None):
          self._raise_if_method(method)
          a, b = self.support()

          # Maximize the PDF by minimizing its negative over the support.
          def f(x):
              return -self.pdf(x)

          res = _bracket_minimum(f, 1., xmin=a, xmax=b)
          res = _chandrupatla_minimize(f, res.xl, res.xm, res.xr)
          return res.x

      def median(self, *, method=None):
          self._raise_if_method(method)
          return self.icdf(0.5)

      def mean(self, *, method=None):
          self._raise_if_method(method)
          return self._sum('mean')

      def variance(self, *, method=None):
          self._raise_if_method(method)
          return self._moment_central(2)

      def standard_deviation(self, *, method=None):
          self._raise_if_method(method)
          return self.variance()**0.5

      def skewness(self, *, method=None):
          self._raise_if_method(method)
          return self._moment_standardized(3)

      def kurtosis(self, *, method=None):
          self._raise_if_method(method)
          return self._moment_standardized(4)

      def moment(self, order=1, kind='raw', *, method=None):
          self._raise_if_method(method)
          kinds = {'raw': self._moment_raw,
                   'central': self._moment_central,
                   'standardized': self._moment_standardized}
          order = ContinuousDistribution._validate_order_kind(self, order, kind, kinds)
          moment_kind = kinds[kind]
          return moment_kind(order)

      def _moment_raw(self, order):
          """Weighted sum of the components' raw moments of `order`."""
          out = self._full(0)
          for var, weight in zip(self._components, self._weights):
              out += var.moment(order, kind='raw') * weight
          return out[()]

      def _moment_central(self, order):
          """Central moment of the mixture about the mixture mean.

          Each component's central moments of orders 0 through `order` are
          re-centered from the component mean to the mixture mean with
          ``_moment_transform_center`` and then combined with the weights.
          """
          order = int(order)
          out = self._full(0)
          # The mixture mean does not depend on the component being processed,
          # so it is computed once rather than once per component.
          mixture_mean = self.mean()
          for var, weight in zip(self._components, self._weights):
              moment_as = [var.moment(k, kind='central')
                           for k in range(order + 1)]
              moment = var._moment_transform_center(
                  order, moment_as, var.mean(), mixture_mean)
              out += moment * weight
          return out[()]

      def _moment_standardized(self, order):
          """Central moment divided by ``standard_deviation()**order``."""
          return self._moment_central(order) / self.standard_deviation()**order

      def pdf(self, x, /, *, method=None):
          self._raise_if_method(method)
          return self._sum('pdf', x)

      def logpdf(self, x, /, *, method=None):
          self._raise_if_method(method)
          return self._logsum('logpdf', x)

      def pmf(self, x, /, *, method=None):
          self._raise_if_method(method)
          return self._sum('pmf', x)

      def logpmf(self, x, /, *, method=None):
          self._raise_if_method(method)
          return self._logsum('logpmf', x)

      def cdf(self, x, y=None, /, *, method=None):
          self._raise_if_method(method)
          args = (x,) if y is None else (x, y)
          return self._sum('cdf', *args)

      def logcdf(self, x, y=None, /, *, method=None):
          self._raise_if_method(method)
          args = (x,) if y is None else (x, y)
          return self._logsum('logcdf', *args)

      def ccdf(self, x, y=None, /, *, method=None):
          self._raise_if_method(method)
          args = (x,) if y is None else (x, y)
          return self._sum('ccdf', *args)

      def logccdf(self, x, y=None, /, *, method=None):
          self._raise_if_method(method)
          args = (x,) if y is None else (x, y)
          return self._logsum('logccdf', *args)

      def _invert(self, fun, p):
          """Solve ``getattr(self, fun)(x) == p`` for ``x`` within the support."""
          xmin, xmax = self.support()
          fun = getattr(self, fun)

          def f(x, p):
              return fun(x) - p

          xl0, xr0 = _guess_bracket(xmin, xmax)
          res = _bracket_root(f, xl0=xl0, xr0=xr0, xmin=xmin, xmax=xmax, args=(p,))
          return _chandrupatla(f, a=res.xl, b=res.xr, args=(p,)).x

      def icdf(self, p, /, *, method=None):
          self._raise_if_method(method)
          return self._invert('cdf', p)

      def iccdf(self, p, /, *, method=None):
          self._raise_if_method(method)
          return self._invert('ccdf', p)

      def ilogcdf(self, p, /, *, method=None):
          self._raise_if_method(method)
          return self._invert('logcdf', p)

      def ilogccdf(self, p, /, *, method=None):
          self._raise_if_method(method)
          return self._invert('logccdf', p)

      def sample(self, shape=(), *, rng=None, method=None):
          self._raise_if_method(method)
          rng = np.random.default_rng(rng)
          size = np.prod(np.atleast_1d(shape))
          # Decide how many draws come from each component, draw them, then
          # shuffle so the component of origin is not encoded in the position.
          ns = rng.multinomial(size, self._weights)
          x = [var.sample(shape=n, rng=rng) for n, var in zip(ns, self._components)]
          x = np.reshape(rng.permuted(np.concatenate(x)), shape)
          return x[()]

      def _describe(self, fmt):
          """Build the text shared by `__repr__` and `__str__`.

          `fmt` is the formatter (``repr`` or ``str``) applied to every
          component and to the weights.
          """
          pieces = ["Mixture(\n", " [\n"]
          with np.printoptions(threshold=10):
              pieces.extend(f" {fmt(component)},\n"
                            for component in self.components)
              pieces.append(" ],\n")
              pieces.append(f" weights={fmt(self.weights)},\n")
          pieces.append(")")
          return "".join(pieces)

      def __repr__(self):
          return self._describe(repr)

      def __str__(self):
          return self._describe(str)
  class MonotonicTransformedDistribution(TransformedDistribution):
      r"""Distribution underlying a strictly monotonic function of a random variable

      Given a random variable :math:`X`; a strictly monotonic function
      :math:`g(u)`, its inverse :math:`h(u) = g^{-1}(u)`, and the derivative magnitude
      :math:`|h'(u)| = \left| \frac{dh(u)}{du} \right|`, define the distribution
      underlying the random variable :math:`Y = g(X)`.

      Parameters
      ----------
      X : `ContinuousDistribution`
          The random variable :math:`X`.
      g, h, dh : callable
          Elementwise functions representing the mathematical functions
          :math:`g(u)`, :math:`h(u)`, and :math:`|h'(u)|`
      logdh : callable, optional
          Elementwise function representing :math:`\log(|h'(u)|)`.
          The default is ``lambda u: np.log(dh(u))``, but providing
          a custom implementation may avoid over/underflow.
      increasing : bool, optional
          Whether the function is strictly increasing (True, default)
          or strictly decreasing (False).
      repr_pattern : str, optional
          A string pattern for determining the __repr__. The __repr__
          for X will be substituted into the position where `***` appears.
          For example:
              ``"exp(***)"`` for the repr of an exponentially transformed
              distribution
          The default is ``f"{g.__name__}(***)"``.
      str_pattern : str, optional
          A string pattern for determining `__str__`. The `__str__`
          for X will be substituted into the position where `***` appears.
          For example:
              ``"exp(***)"`` for the str of an exponentially transformed
              distribution
          The default is the value `repr_pattern` takes.
      """

      def __init__(self, X, /, *args, g, h, dh, logdh=None,
                   increasing=True, repr_pattern=None,
                   str_pattern=None, **kwargs):
          super().__init__(X, *args, **kwargs)
          self._g, self._h, self._dh = g, h, dh

          if logdh is None:
              # Fall back to taking the log of the supplied `dh`.
              def logdh(u):
                  return np.log(dh(u))
          self._logdh = logdh

          # For an increasing `g`, the CDF-family functions of `Y` map onto
          # the same functions of `X`; for a decreasing `g`, each maps onto
          # its complement instead.
          base = self._dist
          direct = (base._cdf_dispatch, base._icdf_dispatch,
                    base._logcdf_dispatch, base._ilogcdf_dispatch)
          complement = (base._ccdf_dispatch, base._iccdf_dispatch,
                        base._logccdf_dispatch, base._ilogccdf_dispatch)
          if not increasing:
              direct, complement = complement, direct
          self._xdf, self._ixdf, self._logxdf, self._ilogxdf = direct
          self._cxdf, self._icxdf, self._logcxdf, self._ilogcxdf = complement

          self._increasing = increasing
          self._repr_pattern = repr_pattern or f"{g.__name__}(***)"
          self._str_pattern = str_pattern or self._repr_pattern

      def __repr__(self):
          with np.printoptions(threshold=10):
              inner = repr(self._dist)
              return self._repr_pattern.replace("***", inner)

      def __str__(self):
          with np.printoptions(threshold=10):
              inner = str(self._dist)
              return self._str_pattern.replace("***", inner)

      def _overrides(self, method_name):
          # Do not use the generic overrides of TransformedDistribution
          return False

      def _support(self, **params):
          lower, upper = self._dist._support(**params)
          # For reciprocal transformation, we want this zero to become -inf
          negative_zero = np.asarray("-0", dtype=upper.dtype)
          upper = np.where(upper == 0, negative_zero, upper)
          # A decreasing `g` swaps which endpoint ends up on the left.
          first, second = ((lower, upper) if self._increasing
                           else (upper, lower))
          with np.errstate(divide='ignore'):
              return self._g(first), self._g(second)

      def _logpdf_dispatch(self, x, *args, **params):
          log_density = self._dist._logpdf_dispatch(self._h(x), *args, **params)
          return log_density + self._logdh(x)

      def _pdf_dispatch(self, x, *args, **params):
          density = self._dist._pdf_dispatch(self._h(x), *args, **params)
          return density * self._dh(x)

      def _logcdf_dispatch(self, x, *args, **params):
          u = self._h(x)
          return self._logxdf(u, *args, **params)

      def _cdf_dispatch(self, x, *args, **params):
          u = self._h(x)
          return self._xdf(u, *args, **params)

      def _logccdf_dispatch(self, x, *args, **params):
          u = self._h(x)
          return self._logcxdf(u, *args, **params)

      def _ccdf_dispatch(self, x, *args, **params):
          u = self._h(x)
          return self._cxdf(u, *args, **params)

      def _ilogcdf_dispatch(self, p, *args, **params):
          u = self._ilogxdf(p, *args, **params)
          return self._g(u)

      def _icdf_dispatch(self, p, *args, **params):
          u = self._ixdf(p, *args, **params)
          return self._g(u)

      def _ilogccdf_dispatch(self, p, *args, **params):
          u = self._ilogcxdf(p, *args, **params)
          return self._g(u)

      def _iccdf_dispatch(self, p, *args, **params):
          u = self._icxdf(p, *args, **params)
          return self._g(u)

      def _sample_dispatch(self, full_shape, *, method, rng, **params):
          # Sample `X`, then push the draws through `g`.
          draws = self._dist._sample_dispatch(
              full_shape, method=method, rng=rng, **params)
          return self._g(draws)
  class FoldedDistribution(TransformedDistribution):
      r"""Distribution underlying the absolute value of a random variable

      Given a random variable :math:`X`; define the distribution
      underlying the random variable :math:`Y = |X|`.

      Parameters
      ----------
      X : `ContinuousDistribution`
          The random variable :math:`X`.

      Returns
      -------
      Y : `ContinuousDistribution`
          The random variable :math:`Y = |X|`

      """
      # Many enhancements are possible if distribution is symmetric. Start
      # with the general case; enhance later.

      def __init__(self, X, /, *args, **kwargs):
          super().__init__(X, *args, **kwargs)
          # I think we need to allow `_support` to define whether the endpoints
          # are inclusive or not. In the meantime, it's best to ensure that the lower
          # endpoint (typically 0 for folded distribution) is inclusive so PDF evaluates
          # correctly at that point.
          domain = self._variable.domain
          domain.inclusive = (True, domain.inclusive[1])

      def _overrides(self, method_name):
          # Do not use the generic overrides of TransformedDistribution
          return False

      def _support(self, **params):
          lower, upper = self._dist._support(**params)
          abs_lower, abs_upper = np.abs(lower), np.abs(upper)
          nearest = np.asarray(np.minimum(abs_lower, abs_upper))
          farthest = np.maximum(abs_lower, abs_upper)
          # Where the original support straddles zero, the folded support
          # starts at zero.
          straddles_zero = (lower < 0) & (upper > 0)
          nearest[straddles_zero] = 0
          return nearest[()], farthest[()]

      def _mirrored_limits(self, x, **params):
          """Interval ``[-|x|, |x|]`` clipped to the support of the wrapped
          distribution; returned as ``(left, right)``."""
          magnitude = np.abs(x)
          a, b = self._dist._support(**params)
          return np.maximum(-magnitude, a), np.minimum(magnitude, b)

      def _logpdf_dispatch(self, x, *args, method=None, **params):
          magnitude = np.abs(x)
          base = self._dist
          from_right = np.asarray(
              base._logpdf_dispatch(magnitude, *args, method=method, **params))
          from_left = np.asarray(
              base._logpdf_dispatch(-magnitude, *args, method=method, **params))
          a, b = base._support(**params)
          # Outside the original support a branch contributes no density.
          from_left[-magnitude < a] = -np.inf
          from_right[magnitude > b] = -np.inf
          stacked = np.stack([from_left, from_right])
          return special.logsumexp(stacked, axis=0)

      def _pdf_dispatch(self, x, *args, method=None, **params):
          magnitude = np.abs(x)
          base = self._dist
          from_right = np.asarray(
              base._pdf_dispatch(magnitude, *args, method=method, **params))
          from_left = np.asarray(
              base._pdf_dispatch(-magnitude, *args, method=method, **params))
          a, b = base._support(**params)
          # Outside the original support a branch contributes no density.
          from_left[-magnitude < a] = 0
          from_right[magnitude > b] = 0
          return from_left + from_right

      def _logcdf_dispatch(self, x, *args, method=None, **params):
          xl, xr = self._mirrored_limits(x, **params)
          log_prob = self._dist._logcdf2_dispatch(
              xl, xr, *args, method=method, **params)
          return log_prob.real

      def _cdf_dispatch(self, x, *args, method=None, **params):
          xl, xr = self._mirrored_limits(x, **params)
          # NOTE(review): unlike the three sibling methods, `method` is not
          # forwarded to the two-argument dispatch here - confirm intentional.
          return self._dist._cdf2_dispatch(xl, xr, *args, **params)

      def _logccdf_dispatch(self, x, *args, method=None, **params):
          xl, xr = self._mirrored_limits(x, **params)
          log_prob = self._dist._logccdf2_dispatch(
              xl, xr, *args, method=method, **params)
          return log_prob.real

      def _ccdf_dispatch(self, x, *args, method=None, **params):
          xl, xr = self._mirrored_limits(x, **params)
          return self._dist._ccdf2_dispatch(
              xl, xr, *args, method=method, **params)

      def _sample_dispatch(self, full_shape, *, method, rng, **params):
          # Sample `X`, then take the absolute value of the draws.
          draws = self._dist._sample_dispatch(
              full_shape, method=method, rng=rng, **params)
          return np.abs(draws)

      def __repr__(self):
          with np.printoptions(threshold=10):
              inner = repr(self._dist)
              return f"abs({inner})"

      def __str__(self):
          with np.printoptions(threshold=10):
              inner = str(self._dist)
              return f"abs({inner})"
  @xp_capabilities(np_only=True)
  def abs(X, /):
      r"""Absolute value of a random variable

      Parameters
      ----------
      X : `ContinuousDistribution`
          The random variable :math:`X`.

      Returns
      -------
      Y : `ContinuousDistribution`
          A random variable :math:`Y = |X|`.

      Examples
      --------
      Suppose we have a normally distributed random variable :math:`X`:

      >>> import numpy as np
      >>> from scipy import stats
      >>> X = stats.Normal()

      We wish to have a random variable :math:`Y` distributed according to
      the folded normal distribution; that is, a random variable :math:`|X|`.

      >>> Y = stats.abs(X)

      The PDF of the distribution in the left half plane is "folded" over to
      the right half plane. Because the normal PDF is symmetric, the resulting
      PDF is zero for negative arguments and doubled for positive arguments.

      >>> import matplotlib.pyplot as plt
      >>> x = np.linspace(0, 5, 300)
      >>> ax = plt.gca()
      >>> Y.plot(x='x', y='pdf', t=('x', -1, 5), ax=ax)
      >>> plt.plot(x, 2 * X.pdf(x), '--')
      >>> plt.legend(('PDF of `Y`', 'Doubled PDF of `X`'))
      >>> plt.show()

      """
      # The folding itself is implemented by `FoldedDistribution`.
      folded = FoldedDistribution(X)
      return folded
  @xp_capabilities(np_only=True)
  def exp(X, /):
      r"""Natural exponential of a random variable

      Parameters
      ----------
      X : `ContinuousDistribution`
          The random variable :math:`X`.

      Returns
      -------
      Y : `ContinuousDistribution`
          A random variable :math:`Y = \exp(X)`.

      Examples
      --------
      Suppose we have a normally distributed random variable :math:`X`:

      >>> import numpy as np
      >>> from scipy import stats
      >>> X = stats.Normal()

      We wish to have a lognormally distributed random variable :math:`Y`,
      a random variable whose natural logarithm is :math:`X`.
      If :math:`X` is to be the natural logarithm of :math:`Y`, then we
      must take :math:`Y` to be the natural exponential of :math:`X`.

      >>> Y = stats.exp(X)

      To demonstrate that ``X`` represents the logarithm of ``Y``,
      we plot a normalized histogram of the logarithm of observations of
      ``Y`` against the PDF underlying ``X``.

      >>> import matplotlib.pyplot as plt
      >>> rng = np.random.default_rng(435383595582522)
      >>> y = Y.sample(shape=10000, rng=rng)
      >>> ax = plt.gca()
      >>> ax.hist(np.log(y), bins=50, density=True)
      >>> X.plot(ax=ax)
      >>> plt.legend(('PDF of `X`', 'histogram of `log(y)`'))
      >>> plt.show()

      """
      # The inverse of `exp` is `log`, whose derivative is 1/u; its logarithm
      # is supplied directly as -log(u).
      def inverse_derivative(u):
          return 1 / u

      def log_inverse_derivative(u):
          return -np.log(u)

      return MonotonicTransformedDistribution(
          X, g=np.exp, h=np.log,
          dh=inverse_derivative, logdh=log_inverse_derivative)
  4810. @xp_capabilities(np_only=True)
  4811. def log(X, /):
  4812. r"""Natural logarithm of a non-negative random variable
  4813. Parameters
  4814. ----------
  4815. X : `ContinuousDistribution`
  4816. The random variable :math:`X` with positive support.
  4817. Returns
  4818. -------
  4819. Y : `ContinuousDistribution`
  4820. A random variable :math:`Y = \log(X)`.
  4821. Examples
  4822. --------
  4823. Suppose we have a gamma distributed random variable :math:`X`:
  4824. >>> import numpy as np
  4825. >>> from scipy import stats
  4826. >>> Gamma = stats.make_distribution(stats.gamma)
  4827. >>> X = Gamma(a=1.0)
  4828. We wish to have an exp-gamma distributed random variable :math:`Y`,
  4829. a random variable whose natural exponential is :math:`X`.
  4830. If :math:`X` is to be the natural exponential of :math:`Y`, then we
  4831. must take :math:`Y` to be the natural logarithm of :math:`X`.
  4832. >>> Y = stats.log(X)
  4833. To demonstrate that ``X`` represents the exponential of ``Y``,
  4834. we plot a normalized histogram of the exponential of observations of
  4835. ``Y`` against the PDF underlying ``X``.
  4836. >>> import matplotlib.pyplot as plt
  4837. >>> rng = np.random.default_rng(435383595582522)
  4838. >>> y = Y.sample(shape=10000, rng=rng)
  4839. >>> ax = plt.gca()
  4840. >>> ax.hist(np.exp(y), bins=50, density=True)
  4841. >>> X.plot(ax=ax)
  4842. >>> plt.legend(('PDF of `X`', 'histogram of `exp(y)`'))
  4843. >>> plt.show()
  4844. """
  4845. if np.any(X.support()[0] < 0):
  4846. message = ("The logarithm of a random variable is only implemented when the "
  4847. "support is non-negative.")
  4848. raise NotImplementedError(message)
  4849. return MonotonicTransformedDistribution(X, g=np.log, h=np.exp, dh=np.exp,
  4850. logdh=lambda u: u)