test_multivariate.py 193 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477
747784779478047814782478347844785478647874788478947904791479247934794479547964797479847994800480148024803480448054806480748084809481048114812481348144815481648174818481948204821482248234824482548264827482848294830483148324833483448354836483748384839484048414842484348444845484648474848484948504851485248534854485548564857485848594860486148624863486448654866486748684869487048714872487348744875487648774878487948804881488248834884488548864887488848894890489148924893489448954896489748984899490049014902490349044905490649074908490949104911491249134914491549164917491849194920492149224923492449254926492749284929493049314932493349344935493649374938493949404941494249434944494549464947494849494950495149524953495449554956495749584959496049614962496349644965496649674968496949704971497249734974497549764977497849794980498149824983498449854986498749884989499049914992499349944995499649974998499950005001500250035004500550065007500850095010501150125013501450155016501750185019502050215022502350245025502650275028
  1. """
  2. Test functions for multivariate normal, t, and related distributions.
  3. """
  4. import pickle
  5. from dataclasses import dataclass
  6. from numpy.testing import (assert_allclose, assert_almost_equal,
  7. assert_array_almost_equal, assert_equal,
  8. assert_array_less, assert_)
  9. import pytest
  10. from pytest import raises as assert_raises
  11. from .test_continuous_basic import check_distribution_rvs
  12. import numpy as np
  13. import scipy.linalg
  14. from scipy.stats._multivariate import (_PSD,
  15. _lnB,
  16. multivariate_normal_frozen)
  17. from scipy.stats import (multivariate_normal, multivariate_hypergeom,
  18. matrix_normal, special_ortho_group, ortho_group,
  19. random_correlation, unitary_group, dirichlet,
  20. beta, wishart, multinomial, invwishart, chi2,
  21. invgamma, norm, uniform, ks_2samp, kstest, binom,
  22. hypergeom, multivariate_t, cauchy, normaltest,
  23. random_table, uniform_direction, vonmises_fisher,
  24. dirichlet_multinomial, vonmises, matrix_t)
  25. from scipy.stats import _covariance, Covariance
  26. from scipy.stats._continuous_distns import _norm_pdf as norm_pdf
  27. from scipy import stats
  28. from scipy.integrate import tanhsinh, cubature, quad
  29. from scipy.integrate import romb, qmc_quad, dblquad, tplquad
  30. from scipy.special import multigammaln
  31. import scipy.special as special
  32. from .common_tests import check_random_state_property
  33. from .data._mvt import _qsimvtv
  34. from unittest.mock import patch
  35. def assert_close(res, ref, *args, **kwargs):
  36. res, ref = np.asarray(res), np.asarray(ref)
  37. assert_allclose(res, ref, *args, **kwargs)
  38. assert_equal(res.shape, ref.shape)
class TestCovariance:
    """Tests for `Covariance` and the `_covariance.CovVia*` implementations.

    Covers constructor input validation, the `Covariance.from_*` factories,
    agreement of each implementation with `_PSD` (log pseudo-determinant,
    rank, whitening/coloring), and use of the objects as the `cov` argument
    of `multivariate_normal`.
    """

    def test_input_validation(self):
        """Invalid constructor inputs raise ValueError with a clear message."""

        message = "The input `precision` must be a square, two-dimensional..."
        with pytest.raises(ValueError, match=message):
            _covariance.CovViaPrecision(np.ones(2))

        message = "`precision.shape` must equal `covariance.shape`."
        with pytest.raises(ValueError, match=message):
            _covariance.CovViaPrecision(np.eye(3), covariance=np.eye(2))

        message = "The input `diagonal` must be a one-dimensional array..."
        with pytest.raises(ValueError, match=message):
            _covariance.CovViaDiagonal("alpaca")

        message = "The input `cholesky` must be a square, two-dimensional..."
        with pytest.raises(ValueError, match=message):
            _covariance.CovViaCholesky(np.ones(2))

        message = "The input `eigenvalues` must be a one-dimensional..."
        with pytest.raises(ValueError, match=message):
            _covariance.CovViaEigendecomposition(("alpaca", np.eye(2)))

        message = "The input `eigenvectors` must be a square..."
        with pytest.raises(ValueError, match=message):
            _covariance.CovViaEigendecomposition((np.ones(2), "alpaca"))

        message = "The shapes of `eigenvalues` and `eigenvectors` must be..."
        with pytest.raises(ValueError, match=message):
            _covariance.CovViaEigendecomposition(([1, 2, 3], np.eye(2)))

    # Maps each `CovVia<name>` suffix to the function that converts a plain
    # covariance matrix into that type's constructor argument. Key order
    # matters: positional indexing of `_all_covariance_types` below depends
    # on it (Diagonal, Precision, Cholesky, Eigendecomposition, PSD).
    _covariance_preprocessing = {"Diagonal": np.diag,
                                 "Precision": np.linalg.inv,
                                 "Cholesky": np.linalg.cholesky,
                                 "Eigendecomposition": np.linalg.eigh,
                                 "PSD": lambda x:
                                     _PSD(x, allow_singular=True)}
    _all_covariance_types = np.array(list(_covariance_preprocessing))
    # Test matrices: full-rank and singular, diagonal and general.
    _matrices = {"diagonal full rank": np.diag([1, 2, 3]),
                 "general full rank": [[5, 1, 3], [1, 6, 4], [3, 4, 7]],
                 "diagonal singular": np.diag([1, 0, 3]),
                 "general singular": [[5, -1, 0], [-1, 5, 0], [0, 0, 0]]}
    # Which CovVia* types each matrix is tested with; other combinations are
    # skipped. E.g. [[0, -2, -1]] selects Diagonal, Eigendecomposition, PSD.
    _cov_types = {"diagonal full rank": _all_covariance_types,
                  "general full rank": _all_covariance_types[1:],
                  "diagonal singular": _all_covariance_types[[0, -2, -1]],
                  "general singular": _all_covariance_types[-2:]}

    # `[:-1]` excludes "PSD", the last key, from the factory test.
    @pytest.mark.parametrize("cov_type_name", _all_covariance_types[:-1])
    def test_factories(self, cov_type_name):
        """`Covariance.from_<name>` matches direct `CovVia<Name>` construction."""
        A = np.diag([1, 2, 3])
        x = [-4, 2, 5]

        cov_type = getattr(_covariance, f"CovVia{cov_type_name}")
        preprocessing = self._covariance_preprocessing[cov_type_name]
        factory = getattr(Covariance, f"from_{cov_type_name.lower()}")

        res = factory(preprocessing(A))
        ref = cov_type(preprocessing(A))
        assert type(res) is type(ref)
        assert_allclose(res.whiten(x), ref.whiten(x))

    @pytest.mark.parametrize("matrix_type", list(_matrices))
    @pytest.mark.parametrize("cov_type_name", _all_covariance_types)
    def test_covariance(self, matrix_type, cov_type_name):
        """Each supported CovVia* type agrees with `_PSD` on the same matrix.

        Checks log_pdet, rank, shape and covariance, plus whitening (and,
        where available, coloring) of 1-D and 3-D inputs.
        """
        message = (f"CovVia{cov_type_name} does not support {matrix_type} "
                   "matrices")
        if cov_type_name not in self._cov_types[matrix_type]:
            pytest.skip(message)

        A = self._matrices[matrix_type]
        cov_type = getattr(_covariance, f"CovVia{cov_type_name}")
        preprocessing = self._covariance_preprocessing[cov_type_name]

        psd = _PSD(A, allow_singular=True)

        # test properties
        cov_object = cov_type(preprocessing(A))
        assert_close(cov_object.log_pdet, psd.log_pdet)
        assert_equal(cov_object.rank, psd.rank)
        assert_equal(cov_object.shape, np.asarray(A).shape)
        assert_close(cov_object.covariance, np.asarray(A))

        # test whitening/coloring 1D x
        rng = np.random.default_rng(5292808890472453840)
        x = rng.random(size=3)
        res = cov_object.whiten(x)
        ref = x @ psd.U
        # res != ref in general; but res @ res == ref @ ref
        assert_close(res @ res, ref @ ref)
        if hasattr(cov_object, "_colorize") and "singular" not in matrix_type:
            # CovViaPSD does not have _colorize
            assert_close(cov_object.colorize(res), x)

        # test whitening/coloring 3D x
        x = rng.random(size=(2, 4, 3))
        res = cov_object.whiten(x)
        ref = x @ psd.U
        # Compare squared norms along the last axis (see the 1D case above).
        assert_close((res**2).sum(axis=-1), (ref**2).sum(axis=-1))
        if hasattr(cov_object, "_colorize") and "singular" not in matrix_type:
            assert_close(cov_object.colorize(res), x)

        # gh-19197 reported that multivariate normal `rvs` produced incorrect
        # results when a singular Covariance object was produced using
        # `from_eigendecomposition`. This was due to an issue in `colorize`
        # with singular covariance matrices. Check this edge case, which is
        # skipped in the previous tests: coloring the identity must yield a
        # factor whose Gram matrix reproduces `A`.
        if hasattr(cov_object, "_colorize"):
            res = cov_object.colorize(np.eye(len(A)))
            assert_close(res.T @ res, A)

    @pytest.mark.parametrize("size", [None, tuple(), 1, (2, 4, 3)])
    @pytest.mark.parametrize("matrix_type", list(_matrices))
    @pytest.mark.parametrize("cov_type_name", _all_covariance_types)
    def test_mvn_with_covariance(self, size, matrix_type, cov_type_name):
        """`multivariate_normal` accepts a CovVia* object as `cov`.

        With identically seeded generators, `rvs` from the function and the
        frozen interfaces must agree (and, for CovViaPSD, match
        `Generator.multivariate_normal`). pdf, logpdf and entropy must match
        those computed from the plain matrix.
        """
        message = (f"CovVia{cov_type_name} does not support {matrix_type} "
                   "matrices")
        if cov_type_name not in self._cov_types[matrix_type]:
            pytest.skip(message)

        A = self._matrices[matrix_type]
        cov_type = getattr(_covariance, f"CovVia{cov_type_name}")
        preprocessing = self._covariance_preprocessing[cov_type_name]

        mean = [0.1, 0.2, 0.3]
        cov_object = cov_type(preprocessing(A))
        mvn = multivariate_normal
        dist0 = multivariate_normal(mean, A, allow_singular=True)
        dist1 = multivariate_normal(mean, cov_object, allow_singular=True)

        # Re-seed before each draw so all three samples use the same stream.
        rng = np.random.default_rng(5292808890472453840)
        x = rng.multivariate_normal(mean, A, size=size)
        rng = np.random.default_rng(5292808890472453840)
        x1 = mvn.rvs(mean, cov_object, size=size, random_state=rng)
        rng = np.random.default_rng(5292808890472453840)
        x2 = mvn(mean, cov_object, seed=rng).rvs(size=size)
        if isinstance(cov_object, _covariance.CovViaPSD):
            assert_close(x1, np.squeeze(x))  # for backward compatibility
            assert_close(x2, np.squeeze(x))
        else:
            assert_equal(x1.shape, x.shape)
            assert_equal(x2.shape, x.shape)
            assert_close(x2, x1)

        assert_close(mvn.pdf(x, mean, cov_object), dist0.pdf(x))
        assert_close(dist1.pdf(x), dist0.pdf(x))
        assert_close(mvn.logpdf(x, mean, cov_object), dist0.logpdf(x))
        assert_close(dist1.logpdf(x), dist0.logpdf(x))
        assert_close(mvn.entropy(mean, cov_object), dist0.entropy())
        assert_close(dist1.entropy(), dist0.entropy())

    @pytest.mark.parametrize("size", [tuple(), (2, 4, 3)])
    @pytest.mark.parametrize("cov_type_name", _all_covariance_types)
    def test_mvn_with_covariance_cdf(self, size, cov_type_name):
        """cdf/logcdf with a CovVia* `cov` match those from the plain matrix."""
        # This is split from the test above because it's slow to be running
        # with all matrix types, and there's no need because _mvn.mvnun
        # does the calculation. All Covariance needs to do is provide the
        # `covariance` attribute.
        matrix_type = "diagonal full rank"
        A = self._matrices[matrix_type]
        cov_type = getattr(_covariance, f"CovVia{cov_type_name}")
        preprocessing = self._covariance_preprocessing[cov_type_name]

        mean = [0.1, 0.2, 0.3]
        cov_object = cov_type(preprocessing(A))
        mvn = multivariate_normal
        dist0 = multivariate_normal(mean, A, allow_singular=True)
        dist1 = multivariate_normal(mean, cov_object, allow_singular=True)

        rng = np.random.default_rng(5292808890472453840)
        x = rng.multivariate_normal(mean, A, size=size)

        assert_close(mvn.cdf(x, mean, cov_object), dist0.cdf(x))
        assert_close(dist1.cdf(x), dist0.cdf(x))
        assert_close(mvn.logcdf(x, mean, cov_object), dist0.logcdf(x))
        assert_close(dist1.logcdf(x), dist0.logcdf(x))

    def test_covariance_instantiation(self):
        """The `Covariance` base class cannot be instantiated directly."""
        message = "The `Covariance` class cannot be instantiated directly."
        with pytest.raises(NotImplementedError, match=message):
            Covariance()

    @pytest.mark.filterwarnings("ignore::RuntimeWarning")  # matrix not PSD
    def test_gh9942(self):
        # Originally there was a mistake in the `multivariate_normal_frozen`
        # `rvs` method that caused all covariance objects to be processed as
        # a `_CovViaPSD`. Ensure that this is resolved.
        # `A` has a slightly negative eigenvalue (-1e-8), so it is not PSD.
        A = np.diag([1, 2, -1e-8])
        n = A.shape[0]
        mean = np.zeros(n)

        # Error if the matrix is processed as a `_CovViaPSD`
        with pytest.raises(ValueError, match="The input matrix must be..."):
            multivariate_normal(mean, A).rvs()

        # No error if it is provided as a `CovViaEigendecomposition`
        seed = 3562050283508273023
        rng1 = np.random.default_rng(seed)
        rng2 = np.random.default_rng(seed)
        cov = Covariance.from_eigendecomposition(np.linalg.eigh(A))
        rv = multivariate_normal(mean, cov)
        res = rv.rvs(random_state=rng1)
        ref = multivariate_normal.rvs(mean, cov, random_state=rng2)
        # Frozen and function interfaces must draw identical samples.
        assert_equal(res, ref)

    def test_gh19197(self):
        # gh-19197 reported that multivariate normal `rvs` produced incorrect
        # results when a singular Covariance object was produced using
        # `from_eigendecomposition`. Check that this specific issue is
        # resolved; a more general test is included in `test_covariance`.
        # All eigenvalues zero: every sample must equal the mean.
        mean = np.ones(2)
        cov = Covariance.from_eigendecomposition((np.zeros(2), np.eye(2)))
        dist = scipy.stats.multivariate_normal(mean=mean, cov=cov)
        rvs = dist.rvs(size=None)
        assert_equal(rvs, mean)

        # The zero eigenvalue is paired with the second eigenvector column
        # ([0, 400], not unit norm), so only the first coordinate may vary.
        cov = scipy.stats.Covariance.from_eigendecomposition(
            (np.array([1., 0.]), np.array([[1., 0.], [0., 400.]])))
        dist = scipy.stats.multivariate_normal(mean=mean, cov=cov)
        rvs = dist.rvs(size=None)
        assert rvs[0] != mean[0]
        assert rvs[1] == mean[1]
  227. def _random_covariance(dim, evals, rng, singular=False):
  228. # Generates random covariance matrix with dimensionality `dim` and
  229. # eigenvalues `evals` using provided Generator `rng`. Randomly sets
  230. # some evals to zero if `singular` is True.
  231. A = rng.random((dim, dim))
  232. A = A @ A.T
  233. _, v = np.linalg.eigh(A)
  234. if singular:
  235. zero_eigs = rng.normal(size=dim) > 0
  236. evals[zero_eigs] = 0
  237. cov = v @ np.diag(evals) @ v.T
  238. return cov
  239. def _sample_orthonormal_matrix(n):
  240. rng = np.random.default_rng(9086764251)
  241. M = rng.standard_normal((n, n))
  242. u, s, v = scipy.linalg.svd(M)
  243. return u
  244. def marginal_pdf(X, X_ndim, dimensions, x):
  245. """Integrate marginalized dimensions of multivariate
  246. probability distribution to calculate the marginalized
  247. distribution.
  248. """
  249. # Sort input data based on order of dimensions
  250. dimensions = np.asarray(dimensions)
  251. dimensions[dimensions < 0] += X_ndim
  252. dim_sort_idx = dimensions.argsort()
  253. x = x[:, dim_sort_idx]
  254. i_marginalize = np.ones(X_ndim, dtype=bool)
  255. i_marginalize[dimensions] = False
  256. def g(z):
  257. y = np.empty((z.shape[0], x.shape[0], X_ndim))
  258. y[..., i_marginalize] = z[:, np.newaxis, :]
  259. y[..., ~i_marginalize] = x
  260. return X.pdf(y)
  261. inf = np.full(X_ndim - len(dimensions), np.inf)
  262. return cubature(g, -inf, inf).estimate
  263. @dataclass
  264. class MVNProblem:
  265. """Instantiate a multivariate normal integration problem with special structure.
  266. When covariance matrix is a correlation matrix where the off-diagonal entries
  267. ``covar[i, j] == lambdas[i]*lambdas[j]`` for ``i != j``, then the multidimensional
  268. integral reduces to a simpler univariate integral that can be numerically integrated
  269. easily.
  270. The ``generate_*()`` classmethods provide a few options for creating variations
  271. of this problem.
  272. References
  273. ----------
  274. .. [1] Tong, Y.L. "The Multivariate Normal Distribution".
  275. Springer-Verlag. p192. 1990.
  276. """
  277. ndim : int
  278. low : np.ndarray
  279. high : np.ndarray
  280. lambdas : np.ndarray
  281. covar : np.ndarray
  282. target_val : float
  283. target_err : float
  284. #: The `generator_halves()` case has an analytically-known true value that we'll
  285. #: record here. It remain None for most cases, though.
  286. true_val : float | None = None
  287. def __init__(self, ndim, low, high, lambdas):
  288. super().__init__()
  289. self.ndim = ndim
  290. self.low = low
  291. self.high = high
  292. self.lambdas = lambdas
  293. self.covar = np.outer(self.lambdas, self.lambdas)
  294. np.fill_diagonal(self.covar, 1.0)
  295. self.find_target()
  296. @classmethod
  297. def generate_semigeneral(cls, ndim, rng=None):
  298. """Random lambdas, random upper bounds, infinite lower bounds.
  299. """
  300. rng = np.random.default_rng(rng)
  301. low = np.full(ndim, -np.inf)
  302. high = rng.uniform(0.0, np.sqrt(ndim), size=ndim)
  303. lambdas = rng.uniform(-1.0, 1.0, size=ndim)
  304. self = cls(
  305. ndim=ndim,
  306. low=low,
  307. high=high,
  308. lambdas=lambdas,
  309. )
  310. return self
  311. @classmethod
  312. def generate_constant(cls, ndim, rng=None):
  313. """Constant off-diagonal covariance, random upper bounds, infinite lower bounds.
  314. """
  315. rng = np.random.default_rng(rng)
  316. low = np.full(ndim, -np.inf)
  317. high = rng.uniform(0.0, np.sqrt(ndim), size=ndim)
  318. sigma = np.sqrt(rng.uniform(0.0, 1.0))
  319. lambdas = np.full(ndim, sigma)
  320. self = cls(
  321. ndim=ndim,
  322. low=low,
  323. high=high,
  324. lambdas=lambdas,
  325. )
  326. return self
  327. @classmethod
  328. def generate_halves(cls, ndim, rng=None):
  329. """Off-diagonal covariance of 0.5, negative orthant bounds.
  330. True analytically-derived answer is 1/(ndim+1).
  331. """
  332. low = np.full(ndim, -np.inf)
  333. high = np.zeros(ndim)
  334. lambdas = np.sqrt(0.5)
  335. self = cls(
  336. ndim=ndim,
  337. low=low,
  338. high=high,
  339. lambdas=lambdas,
  340. )
  341. self.true_val = 1 / (ndim+1)
  342. return self
  343. def find_target(self, **kwds):
  344. """Perform the simplified integral and store the results.
  345. """
  346. d = dict(
  347. a=-9.0,
  348. b=+9.0,
  349. )
  350. d.update(kwds)
  351. self.target_val, self.target_err = quad(self.univariate_func, **d)
  352. def _univariate_term(self, t):
  353. """The parameter-specific term of the univariate integrand,
  354. for separate plotting.
  355. """
  356. denom = np.sqrt(1 - self.lambdas**2)
  357. return np.prod(
  358. special.ndtr((self.high + self.lambdas*t[:, np.newaxis]) / denom) -
  359. special.ndtr((self.low + self.lambdas*t[:, np.newaxis]) / denom),
  360. axis=1,
  361. )
  362. def univariate_func(self, t):
  363. """Univariate integrand.
  364. """
  365. t = np.atleast_1d(t)
  366. return np.squeeze(norm_pdf(t) * self._univariate_term(t))
  367. def plot_integrand(self):
  368. """Plot the univariate integrand and its component terms for understanding.
  369. """
  370. from matplotlib import pyplot as plt
  371. t = np.linspace(-9.0, 9.0, 1001)
  372. plt.plot(t, norm_pdf(t), label=r'$\phi(t)$')
  373. plt.plot(t, self._univariate_term(t), label=r'$f(t)$')
  374. plt.plot(t, self.univariate_func(t), label=r'$f(t)*phi(t)$')
  375. plt.legend()
  376. @dataclass
  377. class SingularMVNProblem:
  378. """Instantiate a multivariate normal integration problem with a special singular
  379. covariance structure.
  380. When covariance matrix is a correlation matrix where the off-diagonal entries
  381. ``covar[i, j] == -lambdas[i]*lambdas[j]`` for ``i != j``, and
  382. ``sum(lambdas**2 / (1+lambdas**2)) == 1``, then the matrix is singular, and
  383. the multidimensional integral reduces to a simpler univariate integral that
  384. can be numerically integrated fairly easily.
  385. The lower bound must be infinite, though the upper bounds can be general.
  386. References
  387. ----------
  388. .. [1] Kwong, K.-S. (1995). "Evaluation of the one-sided percentage points of the
  389. singular multivariate normal distribution." Journal of Statistical
  390. Computation and Simulation, 51(2-4), 121-135. doi:10.1080/00949659508811627
  391. """
  392. ndim : int
  393. low : np.ndarray
  394. high : np.ndarray
  395. lambdas : np.ndarray
  396. covar : np.ndarray
  397. target_val : float
  398. target_err : float
  399. def __init__(self, ndim, high, lambdas):
  400. self.ndim = ndim
  401. self.high = high
  402. self.lambdas = lambdas
  403. self.low = np.full(ndim, -np.inf)
  404. self.covar = -np.outer(self.lambdas, self.lambdas)
  405. np.fill_diagonal(self.covar, 1.0)
  406. self.find_target()
  407. @classmethod
  408. def generate_semiinfinite(cls, ndim, rng=None):
  409. """Singular lambdas, random upper bounds.
  410. """
  411. rng = np.random.default_rng(rng)
  412. high = rng.uniform(0.0, np.sqrt(ndim), size=ndim)
  413. p = rng.dirichlet(np.full(ndim, 1.0))
  414. lambdas = np.sqrt(p / (1-p)) * rng.choice([-1.0, 1.0], size=ndim)
  415. self = cls(
  416. ndim=ndim,
  417. high=high,
  418. lambdas=lambdas,
  419. )
  420. return self
  421. def find_target(self, **kwds):
  422. d = dict(
  423. a=-9.0,
  424. b=+9.0,
  425. )
  426. d.update(kwds)
  427. self.target_val, self.target_err = quad(self.univariate_func, **d)
  428. def _univariate_term(self, t):
  429. denom = np.sqrt(1 + self.lambdas**2)
  430. i1 = np.prod(
  431. special.ndtr((self.high - 1j*self.lambdas*t[:, np.newaxis]) / denom),
  432. axis=1,
  433. )
  434. i2 = np.prod(
  435. special.ndtr((-self.high + 1j*self.lambdas*t[:, np.newaxis]) / denom),
  436. axis=1,
  437. )
  438. # The imaginary part is an odd function, so it can be ignored; it will integrate
  439. # out to 0.
  440. return (i1 - (-1)**self.ndim * i2).real
  441. def univariate_func(self, t):
  442. t = np.atleast_1d(t)
  443. return (norm_pdf(t) * self._univariate_term(t)).squeeze()
  444. def plot_integrand(self):
  445. """Plot the univariate integrand and its component terms for understanding.
  446. """
  447. from matplotlib import pyplot as plt
  448. t = np.linspace(-9.0, 9.0, 1001)
  449. plt.plot(t, norm_pdf(t), label=r'$\phi(t)$')
  450. plt.plot(t, self._univariate_term(t), label=r'$f(t)$')
  451. plt.plot(t, self.univariate_func(t), label=r'$f(t)*phi(t)$')
  452. plt.ylim(-0.1, 1.1)
  453. plt.legend()
  454. class TestMultivariateNormal:
  455. def test_input_shape(self):
  456. mu = np.arange(3)
  457. cov = np.identity(2)
  458. assert_raises(ValueError, multivariate_normal.pdf, (0, 1), mu, cov)
  459. assert_raises(ValueError, multivariate_normal.pdf, (0, 1, 2), mu, cov)
  460. assert_raises(ValueError, multivariate_normal.cdf, (0, 1), mu, cov)
  461. assert_raises(ValueError, multivariate_normal.cdf, (0, 1, 2), mu, cov)
  462. def test_scalar_values(self):
  463. rng = np.random.default_rng(1234)
  464. # When evaluated on scalar data, the pdf should return a scalar
  465. x, mean, cov = 1.5, 1.7, 2.5
  466. pdf = multivariate_normal.pdf(x, mean, cov)
  467. assert_equal(pdf.ndim, 0)
  468. # When evaluated on a single vector, the pdf should return a scalar
  469. x = rng.standard_normal(5)
  470. mean = rng.standard_normal(5)
  471. cov = np.abs(rng.standard_normal(5)) # Diagonal values for cov. matrix
  472. pdf = multivariate_normal.pdf(x, mean, cov)
  473. assert_equal(pdf.ndim, 0)
  474. # When evaluated on scalar data, the cdf should return a scalar
  475. x, mean, cov = 1.5, 1.7, 2.5
  476. cdf = multivariate_normal.cdf(x, mean, cov)
  477. assert_equal(cdf.ndim, 0)
  478. # When evaluated on a single vector, the cdf should return a scalar
  479. x = rng.standard_normal(5)
  480. mean = rng.standard_normal(5)
  481. cov = np.abs(rng.standard_normal(5)) # Diagonal values for cov. matrix
  482. cdf = multivariate_normal.cdf(x, mean, cov)
  483. assert_equal(cdf.ndim, 0)
  484. def test_logpdf(self):
  485. # Check that the log of the pdf is in fact the logpdf
  486. rng = np.random.default_rng(1234)
  487. x = rng.standard_normal(5)
  488. mean = rng.standard_normal(5)
  489. cov = np.abs(rng.standard_normal(5))
  490. d1 = multivariate_normal.logpdf(x, mean, cov)
  491. d2 = multivariate_normal.pdf(x, mean, cov)
  492. assert_allclose(d1, np.log(d2))
  493. def test_logpdf_default_values(self):
  494. # Check that the log of the pdf is in fact the logpdf
  495. # with default parameters Mean=None and cov = 1
  496. rng = np.random.default_rng(1234)
  497. x = rng.standard_normal(5)
  498. d1 = multivariate_normal.logpdf(x)
  499. d2 = multivariate_normal.pdf(x)
  500. # check whether default values are being used
  501. d3 = multivariate_normal.logpdf(x, None, 1)
  502. d4 = multivariate_normal.pdf(x, None, 1)
  503. assert_allclose(d1, np.log(d2))
  504. assert_allclose(d3, np.log(d4))
  505. def test_logcdf(self):
  506. # Check that the log of the cdf is in fact the logcdf
  507. rng = np.random.default_rng(1234)
  508. x = rng.standard_normal(5)
  509. mean = rng.standard_normal(5)
  510. cov = np.abs(rng.standard_normal(5))
  511. d1 = multivariate_normal.logcdf(x, mean, cov)
  512. d2 = multivariate_normal.cdf(x, mean, cov)
  513. assert_allclose(d1, np.log(d2))
  514. def test_logcdf_default_values(self):
  515. # Check that the log of the cdf is in fact the logcdf
  516. # with default parameters Mean=None and cov = 1
  517. rng = np.random.default_rng(1234)
  518. x = rng.standard_normal(5)
  519. d1 = multivariate_normal.logcdf(x)
  520. d2 = multivariate_normal.cdf(x)
  521. # check whether default values are being used
  522. d3 = multivariate_normal.logcdf(x, None, 1)
  523. d4 = multivariate_normal.cdf(x, None, 1)
  524. assert_allclose(d1, np.log(d2))
  525. assert_allclose(d3, np.log(d4))
  526. def test_rank(self):
  527. # Check that the rank is detected correctly.
  528. rng = np.random.default_rng(1234)
  529. n = 4
  530. mean = rng.standard_normal(n)
  531. for expected_rank in range(1, n + 1):
  532. s = rng.standard_normal((n, expected_rank))
  533. cov = np.dot(s, s.T)
  534. distn = multivariate_normal(mean, cov, allow_singular=True)
  535. assert_equal(distn.cov_object.rank, expected_rank)
    def test_degenerate_distributions(self):
        # A rank-k covariance embedded in n > k dimensions, and a rotated
        # copy of it, must give the same (log)pdf on its support as the
        # original k-dimensional distribution, and zero density off it.
        rng = np.random.default_rng(1234)
        for n in range(1, 5):
            z = rng.standard_normal(n)
            for k in range(1, n):
                # Sample a small covariance matrix.
                s = rng.standard_normal((k, k))
                cov_kk = np.dot(s, s.T)

                # Embed the small covariance matrix into a larger singular one.
                cov_nn = np.zeros((n, n))
                cov_nn[:k, :k] = cov_kk

                # Embed part of the vector in the same way
                x = np.zeros(n)
                x[:k] = z[:k]

                # Define a rotation of the larger low rank matrix.
                # The support of the rotated distribution is the span of
                # u[:, :k], and y = u @ x lies in it.
                u = _sample_orthonormal_matrix(n)
                cov_rr = np.dot(u, np.dot(cov_nn, u.T))
                y = np.dot(u, x)

                # Check some identities.
                distn_kk = multivariate_normal(np.zeros(k), cov_kk,
                                               allow_singular=True)
                distn_nn = multivariate_normal(np.zeros(n), cov_nn,
                                               allow_singular=True)
                distn_rr = multivariate_normal(np.zeros(n), cov_rr,
                                               allow_singular=True)
                assert_equal(distn_kk.cov_object.rank, k)
                assert_equal(distn_nn.cov_object.rank, k)
                assert_equal(distn_rr.cov_object.rank, k)
                pdf_kk = distn_kk.pdf(x[:k])
                pdf_nn = distn_nn.pdf(x)
                pdf_rr = distn_rr.pdf(y)
                assert_allclose(pdf_kk, pdf_nn)
                assert_allclose(pdf_kk, pdf_rr)
                logpdf_kk = distn_kk.logpdf(x[:k])
                logpdf_nn = distn_nn.logpdf(x)
                logpdf_rr = distn_rr.logpdf(y)
                assert_allclose(logpdf_kk, logpdf_nn)
                assert_allclose(logpdf_kk, logpdf_rr)

                # Add an orthogonal component and find the density
                # (u[:, -1] is orthogonal to u[:, :k] because k < n).
                y_orth = y + u[:, -1]
                pdf_rr_orth = distn_rr.pdf(y_orth)
                logpdf_rr_orth = distn_rr.logpdf(y_orth)

                # Ensure that this has zero probability
                assert_equal(pdf_rr_orth, 0.0)
                assert_equal(logpdf_rr_orth, -np.inf)
  581. def test_degenerate_array(self):
  582. # Test that we can generate arrays of random variate from a degenerate
  583. # multivariate normal, and that the pdf for these samples is non-zero
  584. # (i.e. samples from the distribution lie on the subspace)
  585. k = 10
  586. for n in range(2, 6):
  587. for r in range(1, n):
  588. mn = np.zeros(n)
  589. u = _sample_orthonormal_matrix(n)[:, :r]
  590. vr = np.dot(u, u.T)
  591. X = multivariate_normal.rvs(mean=mn, cov=vr, size=k)
  592. pdf = multivariate_normal.pdf(X, mean=mn, cov=vr,
  593. allow_singular=True)
  594. assert_equal(pdf.size, k)
  595. assert np.all(pdf > 0.0)
  596. logpdf = multivariate_normal.logpdf(X, mean=mn, cov=vr,
  597. allow_singular=True)
  598. assert_equal(logpdf.size, k)
  599. assert np.all(logpdf > -np.inf)
  600. def test_large_pseudo_determinant(self):
  601. # Check that large pseudo-determinants are handled appropriately.
  602. # Construct a singular diagonal covariance matrix
  603. # whose pseudo determinant overflows double precision.
  604. large_total_log = 1000.0
  605. npos = 100
  606. nzero = 2
  607. large_entry = np.exp(large_total_log / npos)
  608. n = npos + nzero
  609. cov = np.zeros((n, n), dtype=float)
  610. np.fill_diagonal(cov, large_entry)
  611. cov[-nzero:, -nzero:] = 0
  612. # Check some determinants.
  613. assert_equal(scipy.linalg.det(cov), 0)
  614. assert_equal(scipy.linalg.det(cov[:npos, :npos]), np.inf)
  615. assert_allclose(np.linalg.slogdet(cov[:npos, :npos]),
  616. (1, large_total_log))
  617. # Check the pseudo-determinant.
  618. psd = _PSD(cov)
  619. assert_allclose(psd.log_pdet, large_total_log)
  620. def test_broadcasting(self):
  621. rng = np.random.RandomState(1234)
  622. n = 4
  623. # Construct a random covariance matrix.
  624. data = rng.randn(n, n)
  625. cov = np.dot(data, data.T)
  626. mean = rng.randn(n)
  627. # Construct an ndarray which can be interpreted as
  628. # a 2x3 array whose elements are random data vectors.
  629. X = rng.randn(2, 3, n)
  630. # Check that multiple data points can be evaluated at once.
  631. desired_pdf = multivariate_normal.pdf(X, mean, cov)
  632. desired_cdf = multivariate_normal.cdf(X, mean, cov)
  633. for i in range(2):
  634. for j in range(3):
  635. actual = multivariate_normal.pdf(X[i, j], mean, cov)
  636. assert_allclose(actual, desired_pdf[i,j])
  637. # Repeat for cdf
  638. actual = multivariate_normal.cdf(X[i, j], mean, cov)
  639. assert_allclose(actual, desired_cdf[i,j], rtol=1e-3)
  640. def test_normal_1D(self):
  641. # The probability density function for a 1D normal variable should
  642. # agree with the standard normal distribution in scipy.stats.distributions
  643. x = np.linspace(0, 2, 10)
  644. mean, cov = 1.2, 0.9
  645. scale = cov**0.5
  646. d1 = norm.pdf(x, mean, scale)
  647. d2 = multivariate_normal.pdf(x, mean, cov)
  648. assert_allclose(d1, d2)
  649. # The same should hold for the cumulative distribution function
  650. d1 = norm.cdf(x, mean, scale)
  651. d2 = multivariate_normal.cdf(x, mean, cov)
  652. assert_allclose(d1, d2)
  653. def test_marginalization(self):
  654. # Integrating out one of the variables of a 2D Gaussian should
  655. # yield a 1D Gaussian
  656. mean = np.array([2.5, 3.5])
  657. cov = np.array([[.5, 0.2], [0.2, .6]])
  658. n = 2 ** 8 + 1 # Number of samples
  659. delta = 6 / (n - 1) # Grid spacing
  660. v = np.linspace(0, 6, n)
  661. xv, yv = np.meshgrid(v, v)
  662. pos = np.empty((n, n, 2))
  663. pos[:, :, 0] = xv
  664. pos[:, :, 1] = yv
  665. pdf = multivariate_normal.pdf(pos, mean, cov)
  666. # Marginalize over x and y axis
  667. margin_x = romb(pdf, delta, axis=0)
  668. margin_y = romb(pdf, delta, axis=1)
  669. # Compare with standard normal distribution
  670. gauss_x = norm.pdf(v, loc=mean[0], scale=cov[0, 0] ** 0.5)
  671. gauss_y = norm.pdf(v, loc=mean[1], scale=cov[1, 1] ** 0.5)
  672. assert_allclose(margin_x, gauss_x, rtol=1e-2, atol=1e-2)
  673. assert_allclose(margin_y, gauss_y, rtol=1e-2, atol=1e-2)
  674. def test_frozen(self):
  675. # The frozen distribution should agree with the regular one
  676. rng = np.random.default_rng(1234)
  677. x = rng.standard_normal(5)
  678. mean = rng.standard_normal(5)
  679. cov = np.abs(rng.standard_normal(5))
  680. norm_frozen = multivariate_normal(mean, cov)
  681. assert_allclose(norm_frozen.pdf(x), multivariate_normal.pdf(x, mean, cov))
  682. assert_allclose(norm_frozen.logpdf(x),
  683. multivariate_normal.logpdf(x, mean, cov))
  684. assert_allclose(norm_frozen.cdf(x), multivariate_normal.cdf(x, mean, cov))
  685. assert_allclose(norm_frozen.logcdf(x),
  686. multivariate_normal.logcdf(x, mean, cov))
  687. @pytest.mark.parametrize(
  688. 'covariance',
  689. [
  690. np.eye(2),
  691. Covariance.from_diagonal([1, 1]),
  692. ]
  693. )
  694. def test_frozen_multivariate_normal_exposes_attributes(self, covariance):
  695. mean = np.ones((2,))
  696. cov_should_be = np.eye(2)
  697. norm_frozen = multivariate_normal(mean, covariance)
  698. assert np.allclose(norm_frozen.mean, mean)
  699. assert np.allclose(norm_frozen.cov, cov_should_be)
  700. def test_pseudodet_pinv(self):
  701. # Make sure that pseudo-inverse and pseudo-det agree on cutoff
  702. # Assemble random covariance matrix with large and small eigenvalues
  703. rng = np.random.default_rng(1234)
  704. n = 7
  705. x = rng.standard_normal((n, n))
  706. cov = np.dot(x, x.T)
  707. s, u = scipy.linalg.eigh(cov)
  708. s = np.full(n, 0.5)
  709. s[0] = 1.0
  710. s[-1] = 1e-7
  711. cov = np.dot(u, np.dot(np.diag(s), u.T))
  712. # Set cond so that the lowest eigenvalue is below the cutoff
  713. cond = 1e-5
  714. psd = _PSD(cov, cond=cond)
  715. psd_pinv = _PSD(psd.pinv, cond=cond)
  716. # Check that the log pseudo-determinant agrees with the sum
  717. # of the logs of all but the smallest eigenvalue
  718. assert_allclose(psd.log_pdet, np.sum(np.log(s[:-1])))
  719. # Check that the pseudo-determinant of the pseudo-inverse
  720. # agrees with 1 / pseudo-determinant
  721. assert_allclose(-psd.log_pdet, psd_pinv.log_pdet)
  722. def test_exception_nonsquare_cov(self):
  723. cov = [[1, 2, 3], [4, 5, 6]]
  724. assert_raises(ValueError, _PSD, cov)
  725. def test_exception_nonfinite_cov(self):
  726. cov_nan = [[1, 0], [0, np.nan]]
  727. assert_raises(ValueError, _PSD, cov_nan)
  728. cov_inf = [[1, 0], [0, np.inf]]
  729. assert_raises(ValueError, _PSD, cov_inf)
  730. def test_exception_non_psd_cov(self):
  731. cov = [[1, 0], [0, -1]]
  732. assert_raises(ValueError, _PSD, cov)
  733. def test_exception_singular_cov(self):
  734. rng = np.random.default_rng(1234)
  735. x = rng.standard_normal(5)
  736. mean = rng.standard_normal(5)
  737. cov = np.ones((5, 5))
  738. e = np.linalg.LinAlgError
  739. assert_raises(e, multivariate_normal, mean, cov)
  740. assert_raises(e, multivariate_normal.pdf, x, mean, cov)
  741. assert_raises(e, multivariate_normal.logpdf, x, mean, cov)
  742. assert_raises(e, multivariate_normal.cdf, x, mean, cov)
  743. assert_raises(e, multivariate_normal.logcdf, x, mean, cov)
  744. # Message used to be "singular matrix", but this is more accurate.
  745. # See gh-15508
  746. cov = [[1., 0.], [1., 1.]]
  747. msg = "When `allow_singular is False`, the input matrix"
  748. with pytest.raises(np.linalg.LinAlgError, match=msg):
  749. multivariate_normal(cov=cov)
  750. def test_R_values(self):
  751. # Compare the multivariate pdf with some values precomputed
  752. # in R version 3.0.1 (2013-05-16) on Mac OS X 10.6.
  753. # The values below were generated by the following R-script:
  754. # > library(mnormt)
  755. # > x <- seq(0, 2, length=5)
  756. # > y <- 3*x - 2
  757. # > z <- x + cos(y)
  758. # > mu <- c(1, 3, 2)
  759. # > Sigma <- matrix(c(1,2,0,2,5,0.5,0,0.5,3), 3, 3)
  760. # > r_pdf <- dmnorm(cbind(x,y,z), mu, Sigma)
  761. r_pdf = np.array([0.0002214706, 0.0013819953, 0.0049138692,
  762. 0.0103803050, 0.0140250800])
  763. x = np.linspace(0, 2, 5)
  764. y = 3 * x - 2
  765. z = x + np.cos(y)
  766. r = np.array([x, y, z]).T
  767. mean = np.array([1, 3, 2], 'd')
  768. cov = np.array([[1, 2, 0], [2, 5, .5], [0, .5, 3]], 'd')
  769. pdf = multivariate_normal.pdf(r, mean, cov)
  770. assert_allclose(pdf, r_pdf, atol=1e-10)
  771. # Compare the multivariate cdf with some values precomputed
  772. # in R version 3.3.2 (2016-10-31) on Debian GNU/Linux.
  773. # The values below were generated by the following R-script:
  774. # > library(mnormt)
  775. # > x <- seq(0, 2, length=5)
  776. # > y <- 3*x - 2
  777. # > z <- x + cos(y)
  778. # > mu <- c(1, 3, 2)
  779. # > Sigma <- matrix(c(1,2,0,2,5,0.5,0,0.5,3), 3, 3)
  780. # > r_cdf <- pmnorm(cbind(x,y,z), mu, Sigma)
  781. r_cdf = np.array([0.0017866215, 0.0267142892, 0.0857098761,
  782. 0.1063242573, 0.2501068509])
  783. cdf = multivariate_normal.cdf(r, mean, cov)
  784. assert_allclose(cdf, r_cdf, atol=2e-5)
  785. # Also test bivariate cdf with some values precomputed
  786. # in R version 3.3.2 (2016-10-31) on Debian GNU/Linux.
  787. # The values below were generated by the following R-script:
  788. # > library(mnormt)
  789. # > x <- seq(0, 2, length=5)
  790. # > y <- 3*x - 2
  791. # > mu <- c(1, 3)
  792. # > Sigma <- matrix(c(1,2,2,5), 2, 2)
  793. # > r_cdf2 <- pmnorm(cbind(x,y), mu, Sigma)
  794. r_cdf2 = np.array([0.01262147, 0.05838989, 0.18389571,
  795. 0.40696599, 0.66470577])
  796. r2 = np.array([x, y]).T
  797. mean2 = np.array([1, 3], 'd')
  798. cov2 = np.array([[1, 2], [2, 5]], 'd')
  799. cdf2 = multivariate_normal.cdf(r2, mean2, cov2)
  800. assert_allclose(cdf2, r_cdf2, atol=1e-5)
  801. def test_multivariate_normal_rvs_zero_covariance(self):
  802. mean = np.zeros(2)
  803. covariance = np.zeros((2, 2))
  804. model = multivariate_normal(mean, covariance, allow_singular=True)
  805. sample = model.rvs()
  806. assert_equal(sample, [0, 0])
  807. def test_rvs_shape(self):
  808. # Check that rvs parses the mean and covariance correctly, and returns
  809. # an array of the right shape
  810. N = 300
  811. d = 4
  812. sample = multivariate_normal.rvs(mean=np.zeros(d), cov=1, size=N)
  813. assert_equal(sample.shape, (N, d))
  814. sample = multivariate_normal.rvs(mean=None,
  815. cov=np.array([[2, .1], [.1, 1]]),
  816. size=N)
  817. assert_equal(sample.shape, (N, 2))
  818. u = multivariate_normal(mean=0, cov=1)
  819. sample = u.rvs(N)
  820. assert_equal(sample.shape, (N, ))
  821. def test_large_sample(self):
  822. # Generate large sample and compare sample mean and sample covariance
  823. # with mean and covariance matrix.
  824. rng = np.random.RandomState(2846)
  825. n = 3
  826. mean = rng.randn(n)
  827. M = rng.randn(n, n)
  828. cov = np.dot(M, M.T)
  829. size = 5000
  830. sample = multivariate_normal.rvs(mean, cov, size, random_state=rng)
  831. assert_allclose(np.cov(sample.T), cov, rtol=1e-1)
  832. assert_allclose(sample.mean(0), mean, rtol=1e-1)
  833. def test_entropy(self):
  834. rng = np.random.RandomState(2846)
  835. n = 3
  836. mean = rng.randn(n)
  837. M = rng.randn(n, n)
  838. cov = np.dot(M, M.T)
  839. rv = multivariate_normal(mean, cov)
  840. # Check that frozen distribution agrees with entropy function
  841. assert_almost_equal(rv.entropy(), multivariate_normal.entropy(mean, cov))
  842. # Compare entropy with manually computed expression involving
  843. # the sum of the logs of the eigenvalues of the covariance matrix
  844. eigs = np.linalg.eig(cov)[0]
  845. desired = 1 / 2 * (n * (np.log(2 * np.pi) + 1) + np.sum(np.log(eigs)))
  846. assert_almost_equal(desired, rv.entropy())
  847. def test_lnB(self):
  848. alpha = np.array([1, 1, 1])
  849. desired = .5 # e^lnB = 1/2 for [1, 1, 1]
  850. assert_almost_equal(np.exp(_lnB(alpha)), desired)
  851. def test_cdf_with_lower_limit_arrays(self):
  852. # test CDF with lower limit in several dimensions
  853. rng = np.random.default_rng(2408071309372769818)
  854. mean = [0, 0]
  855. cov = np.eye(2)
  856. a = rng.random((4, 3, 2))*6 - 3
  857. b = rng.random((4, 3, 2))*6 - 3
  858. cdf1 = multivariate_normal.cdf(b, mean, cov, lower_limit=a)
  859. cdf2a = multivariate_normal.cdf(b, mean, cov)
  860. cdf2b = multivariate_normal.cdf(a, mean, cov)
  861. ab1 = np.concatenate((a[..., 0:1], b[..., 1:2]), axis=-1)
  862. ab2 = np.concatenate((a[..., 1:2], b[..., 0:1]), axis=-1)
  863. cdf2ab1 = multivariate_normal.cdf(ab1, mean, cov)
  864. cdf2ab2 = multivariate_normal.cdf(ab2, mean, cov)
  865. cdf2 = cdf2a + cdf2b - cdf2ab1 - cdf2ab2
  866. assert_allclose(cdf1, cdf2)
  867. def test_cdf_with_lower_limit_consistency(self):
  868. # check that multivariate normal CDF functions are consistent
  869. rng = np.random.default_rng(2408071309372769818)
  870. mean = rng.random(3)
  871. cov = rng.random((3, 3))
  872. cov = cov @ cov.T
  873. a = rng.random((2, 3))*6 - 3
  874. b = rng.random((2, 3))*6 - 3
  875. cdf1 = multivariate_normal.cdf(b, mean, cov, lower_limit=a)
  876. cdf2 = multivariate_normal(mean, cov).cdf(b, lower_limit=a)
  877. cdf3 = np.exp(multivariate_normal.logcdf(b, mean, cov, lower_limit=a))
  878. cdf4 = np.exp(multivariate_normal(mean, cov).logcdf(b, lower_limit=a))
  879. assert_allclose(cdf2, cdf1, rtol=1e-4)
  880. assert_allclose(cdf3, cdf1, rtol=1e-4)
  881. assert_allclose(cdf4, cdf1, rtol=1e-4)
  882. def test_cdf_signs(self):
  883. # check that sign of output is correct when np.any(lower > x)
  884. mean = np.zeros(3)
  885. cov = np.eye(3)
  886. b = [[1, 1, 1], [0, 0, 0], [1, 0, 1], [0, 1, 0]]
  887. a = [[0, 0, 0], [1, 1, 1], [0, 1, 0], [1, 0, 1]]
  888. # when odd number of elements of b < a, output is negative
  889. expected_signs = np.array([1, -1, -1, 1])
  890. cdf = multivariate_normal.cdf(b, mean, cov, lower_limit=a)
  891. assert_allclose(cdf, cdf[0]*expected_signs)
  892. @pytest.mark.slow
  893. @pytest.mark.parametrize("ndim", [2, 3])
  894. def test_cdf_vs_cubature(self, ndim):
  895. rng = np.random.default_rng(123)
  896. a = rng.uniform(size=(ndim, ndim))
  897. cov = a.T @ a
  898. m = rng.uniform(size=ndim)
  899. dist = multivariate_normal(mean=m, cov=cov)
  900. x = rng.uniform(low=-3, high=3, size=(ndim,))
  901. cdf = dist.cdf(x)
  902. dist_i = multivariate_normal(mean=[0]*ndim, cov=cov)
  903. cdf_i = cubature(dist_i.pdf, [-np.inf]*ndim, x - m).estimate
  904. assert_allclose(cdf, cdf_i, atol=5e-6)
  905. def test_cdf_known(self):
  906. # https://github.com/scipy/scipy/pull/17410#issuecomment-1312628547
  907. for ndim in range(2, 12):
  908. cov = np.full((ndim, ndim), 0.5)
  909. np.fill_diagonal(cov, 1.)
  910. dist = multivariate_normal([0]*ndim, cov=cov)
  911. assert_allclose(
  912. dist.cdf([0]*ndim),
  913. 1. / (1. + ndim),
  914. atol=5e-5
  915. )
  916. @pytest.mark.parametrize("ndim", range(2, 10))
  917. @pytest.mark.parametrize("seed", [0xdeadbeef, 0xdd24528764c9773579731c6b022b48e2])
  918. def test_cdf_vs_univariate(self, seed, ndim):
  919. rng = np.random.default_rng(seed)
  920. case = MVNProblem.generate_semigeneral(ndim=ndim, rng=rng)
  921. assert (case.low == -np.inf).all()
  922. dist = multivariate_normal(mean=[0]*ndim, cov=case.covar)
  923. cdf_val = dist.cdf(case.high, rng=rng)
  924. assert_allclose(cdf_val, case.target_val, atol=5e-5)
  925. @pytest.mark.parametrize("ndim", range(2, 11))
  926. @pytest.mark.parametrize("seed", [0xdeadbeef, 0xdd24528764c9773579731c6b022b48e2])
  927. def test_cdf_vs_univariate_2(self, seed, ndim):
  928. rng = np.random.default_rng(seed)
  929. case = MVNProblem.generate_constant(ndim=ndim, rng=rng)
  930. assert (case.low == -np.inf).all()
  931. dist = multivariate_normal(mean=[0]*ndim, cov=case.covar)
  932. cdf_val = dist.cdf(case.high, rng=rng)
  933. assert_allclose(cdf_val, case.target_val, atol=5e-5)
  934. @pytest.mark.parametrize("ndim", range(4, 11))
  935. @pytest.mark.parametrize("seed", [0xdeadbeef, 0xdd24528764c9773579731c6b022b48e4])
  936. def test_cdf_vs_univariate_singular(self, seed, ndim):
  937. # NB: ndim = 2, 3 has much poorer accuracy than ndim > 3 for many seeds.
  938. # No idea why.
  939. rng = np.random.default_rng(seed)
  940. case = SingularMVNProblem.generate_semiinfinite(ndim=ndim, rng=rng)
  941. assert (case.low == -np.inf).all()
  942. dist = multivariate_normal(mean=[0]*ndim, cov=case.covar, allow_singular=True,
  943. # default maxpts is too slow, limit it here
  944. maxpts=10_000*case.covar.shape[0]
  945. )
  946. cdf_val = dist.cdf(case.high, rng=rng)
  947. assert_allclose(cdf_val, case.target_val, atol=1e-3)
  948. def test_mean_cov(self):
  949. # test the interaction between a Covariance object and mean
  950. P = np.diag(1 / np.array([1, 2, 3]))
  951. cov_object = _covariance.CovViaPrecision(P)
  952. message = "`cov` represents a covariance matrix in 3 dimensions..."
  953. with pytest.raises(ValueError, match=message):
  954. multivariate_normal.entropy([0, 0], cov_object)
  955. with pytest.raises(ValueError, match=message):
  956. multivariate_normal([0, 0], cov_object)
  957. x = [0.5, 0.5, 0.5]
  958. ref = multivariate_normal.pdf(x, [0, 0, 0], cov_object)
  959. assert_equal(multivariate_normal.pdf(x, cov=cov_object), ref)
  960. ref = multivariate_normal.pdf(x, [1, 1, 1], cov_object)
  961. assert_equal(multivariate_normal.pdf(x, 1, cov=cov_object), ref)
  962. def test_fit_wrong_fit_data_shape(self):
  963. data = [1, 3]
  964. error_msg = "`x` must be two-dimensional."
  965. with pytest.raises(ValueError, match=error_msg):
  966. multivariate_normal.fit(data)
  967. @pytest.mark.parametrize('dim', (3, 5))
  968. def test_fit_correctness(self, dim):
  969. rng = np.random.default_rng(4385269356937404)
  970. x = rng.random((100, dim))
  971. mean_est, cov_est = multivariate_normal.fit(x)
  972. mean_ref, cov_ref = np.mean(x, axis=0), np.cov(x.T, ddof=0)
  973. assert_allclose(mean_est, mean_ref, atol=1e-15)
  974. assert_allclose(cov_est, cov_ref, rtol=1e-15)
  975. def test_fit_both_parameters_fixed(self):
  976. data = np.full((2, 1), 3)
  977. mean_fixed = 1.
  978. cov_fixed = np.atleast_2d(1.)
  979. mean, cov = multivariate_normal.fit(data, fix_mean=mean_fixed,
  980. fix_cov=cov_fixed)
  981. assert_equal(mean, mean_fixed)
  982. assert_equal(cov, cov_fixed)
  983. @pytest.mark.parametrize('fix_mean', [np.zeros((2, 2)),
  984. np.zeros((3, ))])
  985. def test_fit_fix_mean_input_validation(self, fix_mean):
  986. msg = ("`fix_mean` must be a one-dimensional array the same "
  987. "length as the dimensionality of the vectors `x`.")
  988. with pytest.raises(ValueError, match=msg):
  989. multivariate_normal.fit(np.eye(2), fix_mean=fix_mean)
  990. @pytest.mark.parametrize('fix_cov', [np.zeros((2, )),
  991. np.zeros((3, 2)),
  992. np.zeros((4, 4))])
  993. def test_fit_fix_cov_input_validation_dimension(self, fix_cov):
  994. msg = ("`fix_cov` must be a two-dimensional square array "
  995. "of same side length as the dimensionality of the "
  996. "vectors `x`.")
  997. with pytest.raises(ValueError, match=msg):
  998. multivariate_normal.fit(np.eye(3), fix_cov=fix_cov)
  999. def test_fit_fix_cov_not_positive_semidefinite(self):
  1000. error_msg = "`fix_cov` must be symmetric positive semidefinite."
  1001. with pytest.raises(ValueError, match=error_msg):
  1002. fix_cov = np.array([[1., 0.], [0., -1.]])
  1003. multivariate_normal.fit(np.eye(2), fix_cov=fix_cov)
  1004. def test_fit_fix_mean(self):
  1005. rng = np.random.default_rng(4385269356937404)
  1006. loc = rng.random(3)
  1007. A = rng.random((3, 3))
  1008. cov = np.dot(A, A.T)
  1009. samples = multivariate_normal.rvs(mean=loc, cov=cov, size=100,
  1010. random_state=rng)
  1011. mean_free, cov_free = multivariate_normal.fit(samples)
  1012. logp_free = multivariate_normal.logpdf(samples, mean=mean_free,
  1013. cov=cov_free).sum()
  1014. mean_fix, cov_fix = multivariate_normal.fit(samples, fix_mean=loc)
  1015. assert_equal(mean_fix, loc)
  1016. logp_fix = multivariate_normal.logpdf(samples, mean=mean_fix,
  1017. cov=cov_fix).sum()
  1018. # test that fixed parameters result in lower likelihood than free
  1019. # parameters
  1020. assert logp_fix < logp_free
  1021. # test that a small perturbation of the resulting parameters
  1022. # has lower likelihood than the estimated parameters
  1023. A = rng.random((3, 3))
  1024. m = 1e-8 * np.dot(A, A.T)
  1025. cov_perturbed = cov_fix + m
  1026. logp_perturbed = (multivariate_normal.logpdf(samples,
  1027. mean=mean_fix,
  1028. cov=cov_perturbed)
  1029. ).sum()
  1030. assert logp_perturbed < logp_fix
  1031. def test_fit_fix_cov(self):
  1032. rng = np.random.default_rng(4385269356937404)
  1033. loc = rng.random(3)
  1034. A = rng.random((3, 3))
  1035. cov = np.dot(A, A.T)
  1036. samples = multivariate_normal.rvs(mean=loc, cov=cov,
  1037. size=100, random_state=rng)
  1038. mean_free, cov_free = multivariate_normal.fit(samples)
  1039. logp_free = multivariate_normal.logpdf(samples, mean=mean_free,
  1040. cov=cov_free).sum()
  1041. mean_fix, cov_fix = multivariate_normal.fit(samples, fix_cov=cov)
  1042. assert_equal(mean_fix, np.mean(samples, axis=0))
  1043. assert_equal(cov_fix, cov)
  1044. logp_fix = multivariate_normal.logpdf(samples, mean=mean_fix,
  1045. cov=cov_fix).sum()
  1046. # test that fixed parameters result in lower likelihood than free
  1047. # parameters
  1048. assert logp_fix < logp_free
  1049. # test that a small perturbation of the resulting parameters
  1050. # has lower likelihood than the estimated parameters
  1051. mean_perturbed = mean_fix + 1e-8 * rng.random(3)
  1052. logp_perturbed = (multivariate_normal.logpdf(samples,
  1053. mean=mean_perturbed,
  1054. cov=cov_fix)
  1055. ).sum()
  1056. assert logp_perturbed < logp_fix
  1057. class TestMarginal:
  1058. @pytest.mark.parametrize('dist,kwargs', [(multivariate_normal, {}),
  1059. (multivariate_t, {'df': 4})])
  1060. @pytest.mark.parametrize('X_ndim', [3])
  1061. @pytest.mark.parametrize('dimensions', [[1], [-1, 1]])
  1062. @pytest.mark.parametrize('frozen', [True, False])
  1063. @pytest.mark.parametrize('cov_object', [True, False])
  1064. def test_marginal_distribution(self, dist, X_ndim, dimensions, frozen,
  1065. cov_object, kwargs):
  1066. rng = np.random.default_rng(413911473)
  1067. loc = rng.standard_normal(X_ndim)
  1068. A = rng.standard_normal((X_ndim, X_ndim))
  1069. scale = A @ A.T
  1070. if cov_object and dist == multivariate_t:
  1071. pytest.skip('`multivariate_t` does not accept a `Covariance` object')
  1072. elif cov_object:
  1073. scale = _covariance.CovViaPrecision(scale)
  1074. # number of points at which to evaluate marginal PDF
  1075. x = np.random.standard_normal((4, len(dimensions)))
  1076. X = dist(loc, scale, **kwargs)
  1077. if frozen:
  1078. Y = X.marginal(dimensions)
  1079. res = Y.pdf(x)
  1080. else:
  1081. Y = dist.marginal(dimensions, loc, scale, **kwargs)
  1082. res = Y.pdf(x)
  1083. ref = marginal_pdf(X, X_ndim, dimensions, x)
  1084. assert_allclose(ref, res)
  1085. @pytest.mark.parametrize('dist', [multivariate_normal, multivariate_t])
  1086. def test_marginal_input_validation(self, dist):
  1087. rng = np.random.default_rng(413911473)
  1088. mean = rng.standard_normal(3)
  1089. A = rng.standard_normal((3, 3))
  1090. cov = A @ A.T
  1091. X = dist(mean, cov)
  1092. msg = r"Dimensions \[3\] are invalid .*"
  1093. with pytest.raises(ValueError, match=msg):
  1094. X.marginal(3)
  1095. with pytest.raises(ValueError, match=msg):
  1096. X.marginal([0, 1, 2, 3])
  1097. msg = r"All elements of `dimensions` must be unique."
  1098. with pytest.raises(ValueError, match=msg):
  1099. X.marginal([2, -1])
  1100. with pytest.raises(ValueError, match=msg):
  1101. X.marginal([[0, 1]])
  1102. msg = r"Elements of `dimensions` must be integers."
  1103. with pytest.raises(ValueError, match=msg):
  1104. X.marginal([1.1, 2.0])
  1105. @pytest.mark.parametrize('dist', [multivariate_normal, multivariate_t])
  1106. def test_marginal_special_cases(self, dist):
  1107. rng = np.random.default_rng(413911473)
  1108. loc = rng.standard_normal(3)
  1109. A = rng.standard_normal((3, 3))
  1110. scale = A @ A.T
  1111. X = dist(loc, scale)
  1112. msg = r"Cannot marginalize all dimensions."
  1113. with pytest.raises(ValueError, match=msg):
  1114. X.marginal([])
  1115. class TestMatrixNormal:
  1116. def test_bad_input(self):
  1117. # Check that bad inputs raise errors
  1118. num_rows = 4
  1119. num_cols = 3
  1120. M = np.full((num_rows,num_cols), 0.3)
  1121. U = 0.5 * np.identity(num_rows) + np.full((num_rows, num_rows), 0.5)
  1122. V = 0.7 * np.identity(num_cols) + np.full((num_cols, num_cols), 0.3)
  1123. # Incorrect dimensions
  1124. assert_raises(ValueError, matrix_normal, np.zeros((5,4,3)))
  1125. assert_raises(ValueError, matrix_normal, M, np.zeros(10), V)
  1126. assert_raises(ValueError, matrix_normal, M, U, np.zeros(10))
  1127. assert_raises(ValueError, matrix_normal, M, U, U)
  1128. assert_raises(ValueError, matrix_normal, M, V, V)
  1129. assert_raises(ValueError, matrix_normal, M.T, U, V)
  1130. e = np.linalg.LinAlgError
  1131. # Singular covariance for the rvs method of a non-frozen instance
  1132. assert_raises(e, matrix_normal.rvs,
  1133. M, U, np.ones((num_cols, num_cols)))
  1134. assert_raises(e, matrix_normal.rvs,
  1135. M, np.ones((num_rows, num_rows)), V)
  1136. # Singular covariance for a frozen instance
  1137. assert_raises(e, matrix_normal, M, U, np.ones((num_cols, num_cols)))
  1138. assert_raises(e, matrix_normal, M, np.ones((num_rows, num_rows)), V)
  1139. def test_default_inputs(self):
  1140. # Check that default argument handling works
  1141. num_rows = 4
  1142. num_cols = 3
  1143. M = np.full((num_rows,num_cols), 0.3)
  1144. U = 0.5 * np.identity(num_rows) + np.full((num_rows, num_rows), 0.5)
  1145. V = 0.7 * np.identity(num_cols) + np.full((num_cols, num_cols), 0.3)
  1146. Z = np.zeros((num_rows, num_cols))
  1147. Zr = np.zeros((num_rows, 1))
  1148. Zc = np.zeros((1, num_cols))
  1149. Ir = np.identity(num_rows)
  1150. Ic = np.identity(num_cols)
  1151. I1 = np.identity(1)
  1152. assert_equal(matrix_normal.rvs(mean=M, rowcov=U, colcov=V).shape,
  1153. (num_rows, num_cols))
  1154. assert_equal(matrix_normal.rvs(mean=M).shape,
  1155. (num_rows, num_cols))
  1156. assert_equal(matrix_normal.rvs(rowcov=U).shape,
  1157. (num_rows, 1))
  1158. assert_equal(matrix_normal.rvs(colcov=V).shape,
  1159. (1, num_cols))
  1160. assert_equal(matrix_normal.rvs(mean=M, colcov=V).shape,
  1161. (num_rows, num_cols))
  1162. assert_equal(matrix_normal.rvs(mean=M, rowcov=U).shape,
  1163. (num_rows, num_cols))
  1164. assert_equal(matrix_normal.rvs(rowcov=U, colcov=V).shape,
  1165. (num_rows, num_cols))
  1166. assert_equal(matrix_normal(mean=M).rowcov, Ir)
  1167. assert_equal(matrix_normal(mean=M).colcov, Ic)
  1168. assert_equal(matrix_normal(rowcov=U).mean, Zr)
  1169. assert_equal(matrix_normal(rowcov=U).colcov, I1)
  1170. assert_equal(matrix_normal(colcov=V).mean, Zc)
  1171. assert_equal(matrix_normal(colcov=V).rowcov, I1)
  1172. assert_equal(matrix_normal(mean=M, rowcov=U).colcov, Ic)
  1173. assert_equal(matrix_normal(mean=M, colcov=V).rowcov, Ir)
  1174. assert_equal(matrix_normal(rowcov=U, colcov=V).mean, Z)
  1175. def test_covariance_expansion(self):
  1176. # Check that covariance can be specified with scalar or vector
  1177. num_rows = 4
  1178. num_cols = 3
  1179. M = np.full((num_rows, num_cols), 0.3)
  1180. Uv = np.full(num_rows, 0.2)
  1181. Us = 0.2
  1182. Vv = np.full(num_cols, 0.1)
  1183. Vs = 0.1
  1184. Ir = np.identity(num_rows)
  1185. Ic = np.identity(num_cols)
  1186. assert_equal(matrix_normal(mean=M, rowcov=Uv, colcov=Vv).rowcov,
  1187. 0.2*Ir)
  1188. assert_equal(matrix_normal(mean=M, rowcov=Uv, colcov=Vv).colcov,
  1189. 0.1*Ic)
  1190. assert_equal(matrix_normal(mean=M, rowcov=Us, colcov=Vs).rowcov,
  1191. 0.2*Ir)
  1192. assert_equal(matrix_normal(mean=M, rowcov=Us, colcov=Vs).colcov,
  1193. 0.1*Ic)
  1194. def test_frozen_matrix_normal(self):
  1195. for i in range(1,5):
  1196. for j in range(1,5):
  1197. M = np.full((i,j), 0.3)
  1198. U = 0.5 * np.identity(i) + np.full((i,i), 0.5)
  1199. V = 0.7 * np.identity(j) + np.full((j,j), 0.3)
  1200. frozen = matrix_normal(mean=M, rowcov=U, colcov=V)
  1201. rvs1 = frozen.rvs(random_state=1234)
  1202. rvs2 = matrix_normal.rvs(mean=M, rowcov=U, colcov=V,
  1203. random_state=1234)
  1204. assert_equal(rvs1, rvs2)
  1205. X = frozen.rvs(random_state=1234)
  1206. pdf1 = frozen.pdf(X)
  1207. pdf2 = matrix_normal.pdf(X, mean=M, rowcov=U, colcov=V)
  1208. assert_equal(pdf1, pdf2)
  1209. logpdf1 = frozen.logpdf(X)
  1210. logpdf2 = matrix_normal.logpdf(X, mean=M, rowcov=U, colcov=V)
  1211. assert_equal(logpdf1, logpdf2)
  1212. def test_matches_multivariate(self):
  1213. # Check that the pdfs match those obtained by vectorising and
  1214. # treating as a multivariate normal.
  1215. for i in range(1,5):
  1216. for j in range(1,5):
  1217. M = np.full((i,j), 0.3)
  1218. U = 0.5 * np.identity(i) + np.full((i,i), 0.5)
  1219. V = 0.7 * np.identity(j) + np.full((j,j), 0.3)
  1220. frozen = matrix_normal(mean=M, rowcov=U, colcov=V)
  1221. X = frozen.rvs(random_state=1234)
  1222. pdf1 = frozen.pdf(X)
  1223. logpdf1 = frozen.logpdf(X)
  1224. entropy1 = frozen.entropy()
  1225. vecX = X.T.flatten()
  1226. vecM = M.T.flatten()
  1227. cov = np.kron(V,U)
  1228. pdf2 = multivariate_normal.pdf(vecX, mean=vecM, cov=cov)
  1229. logpdf2 = multivariate_normal.logpdf(vecX, mean=vecM, cov=cov)
  1230. entropy2 = multivariate_normal.entropy(mean=vecM, cov=cov)
  1231. assert_allclose(pdf1, pdf2, rtol=1E-10)
  1232. assert_allclose(logpdf1, logpdf2, rtol=1E-10)
  1233. assert_allclose(entropy1, entropy2)
  1234. def test_array_input(self):
  1235. # Check array of inputs has the same output as the separate entries.
  1236. num_rows = 4
  1237. num_cols = 3
  1238. M = np.full((num_rows,num_cols), 0.3)
  1239. U = 0.5 * np.identity(num_rows) + np.full((num_rows, num_rows), 0.5)
  1240. V = 0.7 * np.identity(num_cols) + np.full((num_cols, num_cols), 0.3)
  1241. N = 10
  1242. frozen = matrix_normal(mean=M, rowcov=U, colcov=V)
  1243. X1 = frozen.rvs(size=N, random_state=1234)
  1244. X2 = frozen.rvs(size=N, random_state=4321)
  1245. X = np.concatenate((X1[np.newaxis,:,:,:],X2[np.newaxis,:,:,:]), axis=0)
  1246. assert_equal(X.shape, (2, N, num_rows, num_cols))
  1247. array_logpdf = frozen.logpdf(X)
  1248. assert_equal(array_logpdf.shape, (2, N))
  1249. for i in range(2):
  1250. for j in range(N):
  1251. separate_logpdf = matrix_normal.logpdf(X[i,j], mean=M,
  1252. rowcov=U, colcov=V)
  1253. assert_allclose(separate_logpdf, array_logpdf[i,j], 1E-10)
  1254. def test_moments(self):
  1255. # Check that the sample moments match the parameters
  1256. num_rows = 4
  1257. num_cols = 3
  1258. M = np.full((num_rows,num_cols), 0.3)
  1259. U = 0.5 * np.identity(num_rows) + np.full((num_rows, num_rows), 0.5)
  1260. V = 0.7 * np.identity(num_cols) + np.full((num_cols, num_cols), 0.3)
  1261. N = 1000
  1262. frozen = matrix_normal(mean=M, rowcov=U, colcov=V)
  1263. X = frozen.rvs(size=N, random_state=1234)
  1264. sample_mean = np.mean(X,axis=0)
  1265. assert_allclose(sample_mean, M, atol=0.1)
  1266. sample_colcov = np.cov(X.reshape(N*num_rows,num_cols).T)
  1267. assert_allclose(sample_colcov, V, atol=0.1)
  1268. sample_rowcov = np.cov(np.swapaxes(X,1,2).reshape(
  1269. N*num_cols,num_rows).T)
  1270. assert_allclose(sample_rowcov, U, atol=0.1)
  1271. def test_samples(self):
  1272. # Regression test to ensure that we always generate the same stream of
  1273. # random variates.
  1274. actual = matrix_normal.rvs(
  1275. mean=np.array([[1, 2], [3, 4]]),
  1276. rowcov=np.array([[4, -1], [-1, 2]]),
  1277. colcov=np.array([[5, 1], [1, 10]]),
  1278. random_state=np.random.default_rng(0),
  1279. size=2
  1280. )
  1281. expected = np.array(
  1282. [[[1.56228264238181, -1.24136424071189],
  1283. [2.46865788392114, 6.22964440489445]],
  1284. [[3.86405716144353, 10.73714311429529],
  1285. [2.59428444080606, 5.79987854490876]]]
  1286. )
  1287. assert_allclose(actual, expected)
  1288. class TestMatrixT:
  1289. def test_bad_input(self):
  1290. # Check that bad inputs raise errors
  1291. num_rows = 4
  1292. num_cols = 3
  1293. df = 5
  1294. M = np.full((num_rows, num_cols), 0.3)
  1295. U = 0.5 * np.identity(num_rows) + np.full((num_rows, num_rows), 0.5)
  1296. V = 0.7 * np.identity(num_cols) + np.full((num_cols, num_cols), 0.3)
  1297. # Nonpositive degrees of freedom
  1298. with pytest.raises(ValueError, match="Degrees of freedom must be positive."):
  1299. matrix_t(df=0)
  1300. # Incorrect dimensions
  1301. with pytest.raises(ValueError, match="Array `mean` must be 2D."):
  1302. matrix_t(mean=np.zeros((5, 4, 3)))
  1303. with pytest.raises(ValueError, match="Array `mean` has invalid shape."):
  1304. matrix_t(mean=np.zeros((4, 3, 0)))
  1305. with pytest.raises(ValueError, match="Array `row_spread` has invalid shape."):
  1306. matrix_t(row_spread=np.ones((1, 0)))
  1307. with pytest.raises(
  1308. ValueError, match="Array `row_spread` must be a scalar or a 2D array."
  1309. ):
  1310. matrix_t(row_spread=np.ones((1, 2, 3)))
  1311. with pytest.raises(ValueError, match="Array `row_spread` must be square."):
  1312. matrix_t(row_spread=np.ones((1, 2)))
  1313. with pytest.raises(ValueError, match="Array `col_spread` has invalid shape."):
  1314. matrix_t(col_spread=np.ones((1, 0)))
  1315. with pytest.raises(
  1316. ValueError, match="Array `col_spread` must be a scalar or a 2D array."
  1317. ):
  1318. matrix_t(col_spread=np.ones((1, 2, 3)))
  1319. with pytest.raises(ValueError, match="Array `col_spread` must be square."):
  1320. matrix_t(col_spread=np.ones((1, 2)))
  1321. with pytest.raises(
  1322. ValueError,
  1323. match="Arrays `mean` and `row_spread` must have the same number "
  1324. "of rows.",
  1325. ):
  1326. matrix_t(mean=M, row_spread=V)
  1327. with pytest.raises(
  1328. ValueError,
  1329. match="Arrays `mean` and `col_spread` must have the same number "
  1330. "of columns.",
  1331. ):
  1332. matrix_t(mean=M, col_spread=U)
  1333. # Incorrect dimension of input matrix
  1334. with pytest.raises(
  1335. ValueError,
  1336. match="The shape of array `X` is not conformal with "
  1337. "the distribution parameters.",
  1338. ):
  1339. matrix_t.pdf(X=np.zeros((num_rows, num_rows)), mean=M)
  1340. # Singular covariance for a non-frozen instance
  1341. with pytest.raises(
  1342. np.linalg.LinAlgError,
  1343. match="2-th leading minor of the array is not positive definite",
  1344. ):
  1345. matrix_t.rvs(M, U, np.ones((num_cols, num_cols)), df)
  1346. with pytest.raises(
  1347. np.linalg.LinAlgError,
  1348. match="2-th leading minor of the array is not positive definite",
  1349. ):
  1350. matrix_t.rvs(M, np.ones((num_rows, num_rows)), V, df)
  1351. # Singular covariance for a frozen instance
  1352. with pytest.raises(
  1353. np.linalg.LinAlgError,
  1354. match="When `allow_singular is False`, the input matrix must be "
  1355. "symmetric positive definite.",
  1356. ):
  1357. matrix_t(M, U, np.ones((num_cols, num_cols)), df)
  1358. with pytest.raises(
  1359. np.linalg.LinAlgError,
  1360. match="When `allow_singular is False`, the input matrix must be "
  1361. "symmetric positive definite.",
  1362. ):
  1363. matrix_t(M, np.ones((num_rows, num_rows)), V, df)
  1364. def test_default_inputs(self):
  1365. # Check that default argument handling works
  1366. num_rows = 4
  1367. num_cols = 3
  1368. df = 5
  1369. M = np.full((num_rows, num_cols), 0.3)
  1370. U = 0.5 * np.identity(num_rows) + np.full((num_rows, num_rows), 0.5)
  1371. V = 0.7 * np.identity(num_cols) + np.full((num_cols, num_cols), 0.3)
  1372. Z = np.zeros((num_rows, num_cols))
  1373. Zr = np.zeros((num_rows, 1))
  1374. Zc = np.zeros((1, num_cols))
  1375. Ir = np.identity(num_rows)
  1376. Ic = np.identity(num_cols)
  1377. I1 = np.identity(1)
  1378. dfdefault = 1
  1379. assert_equal(
  1380. matrix_t.rvs(mean=M, row_spread=U, col_spread=V, df=df).shape,
  1381. (num_rows, num_cols),
  1382. )
  1383. assert_equal(matrix_t.rvs(mean=M).shape, (num_rows, num_cols))
  1384. assert_equal(matrix_t.rvs(row_spread=U).shape, (num_rows, 1))
  1385. assert_equal(matrix_t.rvs(col_spread=V).shape, (1, num_cols))
  1386. assert_equal(matrix_t.rvs(mean=M, col_spread=V).shape, (num_rows, num_cols))
  1387. assert_equal(matrix_t.rvs(mean=M, row_spread=U).shape, (num_rows, num_cols))
  1388. assert_equal(
  1389. matrix_t.rvs(row_spread=U, col_spread=V).shape, (num_rows, num_cols)
  1390. )
  1391. assert_equal(matrix_t().df, dfdefault)
  1392. assert_equal(matrix_t(mean=M).row_spread, Ir)
  1393. assert_equal(matrix_t(mean=M).col_spread, Ic)
  1394. assert_equal(matrix_t(row_spread=U).mean, Zr)
  1395. assert_equal(matrix_t(row_spread=U).col_spread, I1)
  1396. assert_equal(matrix_t(col_spread=V).mean, Zc)
  1397. assert_equal(matrix_t(col_spread=V).row_spread, I1)
  1398. assert_equal(matrix_t(mean=M, row_spread=U).col_spread, Ic)
  1399. assert_equal(matrix_t(mean=M, col_spread=V).row_spread, Ir)
  1400. assert_equal(matrix_t(row_spread=U, col_spread=V, df=df).mean, Z)
  1401. def test_covariance_expansion(self):
  1402. # Check that covariance can be specified with scalar or vector
  1403. num_rows = 4
  1404. num_cols = 3
  1405. df = 1
  1406. M = np.full((num_rows, num_cols), 0.3)
  1407. Uv = np.full(num_rows, 0.2)
  1408. Us = 0.2
  1409. Vv = np.full(num_cols, 0.1)
  1410. Vs = 0.1
  1411. Ir = np.identity(num_rows)
  1412. Ic = np.identity(num_cols)
  1413. assert_equal(
  1414. matrix_t(mean=M, row_spread=Uv, col_spread=Vv, df=df).row_spread, 0.2 * Ir
  1415. )
  1416. assert_equal(
  1417. matrix_t(mean=M, row_spread=Uv, col_spread=Vv, df=df).col_spread, 0.1 * Ic
  1418. )
  1419. assert_equal(
  1420. matrix_t(mean=M, row_spread=Us, col_spread=Vs, df=df).row_spread, 0.2 * Ir
  1421. )
  1422. assert_equal(
  1423. matrix_t(mean=M, row_spread=Us, col_spread=Vs, df=df).col_spread, 0.1 * Ic
  1424. )
  1425. @pytest.mark.parametrize("i", range(1, 4))
  1426. @pytest.mark.parametrize("j", range(1, 4))
  1427. def test_frozen_matrix_t(self, i, j):
  1428. M = np.full((i, j), 0.3)
  1429. U = 0.5 * np.identity(i) + np.full((i, i), 0.5)
  1430. V = 0.7 * np.identity(j) + np.full((j, j), 0.3)
  1431. df = i + j
  1432. frozen = matrix_t(mean=M, row_spread=U, col_spread=V, df=df)
  1433. rvs1 = frozen.rvs(random_state=1234)
  1434. rvs2 = matrix_t.rvs(
  1435. mean=M, row_spread=U, col_spread=V, df=df, random_state=1234
  1436. )
  1437. assert_equal(rvs1, rvs2)
  1438. X = frozen.rvs(random_state=1234)
  1439. pdf1 = frozen.pdf(X)
  1440. pdf2 = matrix_t.pdf(X, mean=M, row_spread=U, col_spread=V, df=df)
  1441. assert_equal(pdf1, pdf2)
  1442. logpdf1 = frozen.logpdf(X)
  1443. logpdf2 = matrix_t.logpdf(X, mean=M, row_spread=U, col_spread=V, df=df)
  1444. assert_equal(logpdf1, logpdf2)
  1445. def test_array_input(self):
  1446. # Check array of inputs has the same output as the separate entries.
  1447. num_rows = 4
  1448. num_cols = 3
  1449. M = np.full((num_rows, num_cols), 0.3)
  1450. U = 0.5 * np.identity(num_rows) + np.full((num_rows, num_rows), 0.5)
  1451. V = 0.7 * np.identity(num_cols) + np.full((num_cols, num_cols), 0.3)
  1452. df = 1
  1453. N = 10
  1454. frozen = matrix_t(mean=M, row_spread=U, col_spread=V, df=df)
  1455. X1 = frozen.rvs(size=N, random_state=1234)
  1456. X2 = frozen.rvs(size=N, random_state=4321)
  1457. X = np.concatenate((X1[np.newaxis, :, :, :], X2[np.newaxis, :, :, :]), axis=0)
  1458. assert_equal(X.shape, (2, N, num_rows, num_cols))
  1459. array_logpdf = frozen.logpdf(X)
  1460. logpdf_shape = array_logpdf.shape
  1461. assert_equal(logpdf_shape, (2, N))
  1462. for i in range(2):
  1463. for j in range(N):
  1464. separate_logpdf = matrix_t.logpdf(
  1465. X[i, j], mean=M, row_spread=U, col_spread=V, df=df
  1466. )
  1467. assert_allclose(separate_logpdf, array_logpdf[i, j], 1e-10)
  1468. @staticmethod
  1469. def relative_error(vec1: np.ndarray, vec2: np.ndarray):
  1470. numerator = np.linalg.norm(vec1 - vec2) ** 2
  1471. denominator = np.linalg.norm(vec1) ** 2 + np.linalg.norm(vec2) ** 2
  1472. return numerator / denominator
  1473. @staticmethod
  1474. def matrix_divergence(mat_true: np.ndarray, mat_est: np.ndarray) -> float:
  1475. mat_true_psd = _PSD(mat_true, allow_singular=False)
  1476. mat_est_psd = _PSD(mat_est, allow_singular=False)
  1477. if (np.exp(mat_est_psd.log_pdet) <= 0) or (np.exp(mat_true_psd.log_pdet) <= 0):
  1478. return np.inf
  1479. trace_term = np.trace(mat_est_psd.pinv @ mat_true)
  1480. log_detratio = mat_est_psd.log_pdet - mat_true_psd.log_pdet
  1481. return (trace_term + log_detratio - len(mat_true)) / 2
  1482. @staticmethod
  1483. def vec(a_mat: np.ndarray) -> np.ndarray:
  1484. """
  1485. For an (m,n) array `a_mat` the output `vec(a_mat)` is an (m*n, 1)
  1486. array formed by stacking the columns of `a_mat` in the order in
  1487. which they occur in `a_mat`.
  1488. """
  1489. assert a_mat.ndim == 2
  1490. return a_mat.T.reshape((a_mat.size,))
  1491. def test_moments(self):
  1492. r"""
  1493. Gupta and Nagar (2000) Theorem 4.3.1 (p.135)
  1494. --------------------------------------------
  1495. The covariance of the vectorized matrix variate t-distribution equals
  1496. $ (V \otimes U) / (\text{df} - 2)$, where $\otimes$
  1497. denotes the usual Kronecker product.
  1498. """
  1499. df = 5
  1500. num_rows = 4
  1501. num_cols = 3
  1502. M = np.full((num_rows, num_cols), 0.3)
  1503. U = 0.5 * np.identity(num_rows) + np.full((num_rows, num_rows), 0.5)
  1504. V = 0.7 * np.identity(num_cols) + np.full((num_cols, num_cols), 0.3)
  1505. N = 10**4
  1506. atol = 1e-1
  1507. frozen = matrix_t(mean=M, row_spread=U, col_spread=V, df=df)
  1508. X = frozen.rvs(size=N, random_state=42)
  1509. relerr = self.relative_error(M, X.mean(axis=0))
  1510. assert_close(relerr, 0, atol=atol)
  1511. cov_vec_true = np.kron(V, U) / (df - 2)
  1512. cov_vec_rvs = np.cov(np.array([self.vec(x) for x in X]), rowvar=False)
  1513. kl = self.matrix_divergence(cov_vec_true, cov_vec_rvs)
  1514. assert_close(kl, 0, atol=atol)
  1515. def test_pdf_against_julia(self):
  1516. """
  1517. Test values generated from Julia.
  1518. Dockerfile
  1519. ----------
  1520. FROM julia:1.11.5
  1521. RUN julia -e 'using Pkg; Pkg.add("Distributions"); Pkg.add("PDMats")'
  1522. WORKDIR /usr/src
  1523. Commands
  1524. --------
  1525. using DelimitedFiles
  1526. using Distributions
  1527. using PDMats
  1528. using Random
  1529. Random.seed!(42)
  1530. ν = 5
  1531. M = [1 2 3; 4 5 6]
  1532. Σ = PDMats.PDMat([1 0.5; 0.5 1])
  1533. Ω = PDMats.PDMat([1 0.3 0.2; 0.3 1 0.4; 0.2 0.4 1])
  1534. dist = MatrixTDist(ν, M, Σ, Ω)
  1535. samples = rand(dist, 10)
  1536. pdfs = [pdf(dist, s) for s in samples]
  1537. """
  1538. df = 5
  1539. M = np.array([[1, 2, 3], [4, 5, 6]])
  1540. U = np.array([[1, 0.5], [0.5, 1]])
  1541. V = np.array([[1, 0.3, 0.2], [0.3, 1, 0.4], [0.2, 0.4, 1]])
  1542. rtol = 1e-10
  1543. samples_j = np.array(
  1544. [
  1545. [
  1546. [0.958884881003464, 2.328976673167312, 2.936195396714506],
  1547. [3.656388568544394, 5.677549814962506, 6.292509556719057]
  1548. ],
  1549. [
  1550. [0.830992685140180, 2.588946865508210, 3.310327469315906],
  1551. [3.850637198786261, 5.106074165416971, 6.403143979925566]
  1552. ],
  1553. [
  1554. [1.572053537500711, 1.760828063560249, 2.812123062636012],
  1555. [4.156334686390513, 5.075942019982631, 5.827004350136873]
  1556. ],
  1557. [
  1558. [1.683810860278459, 2.801203900480317, 4.054517744825265],
  1559. [4.778239956376877, 5.070613721477604, 6.640349743267192]
  1560. ],
  1561. [
  1562. [0.443183825511296, 2.072092271247398, 3.045385527559403],
  1563. [4.374387994815022, 5.083432151729137, 5.958013783940404]
  1564. ],
  1565. [
  1566. [0.311591337218329, 1.162836182564980, 2.562167762547456],
  1567. [3.079154928756626, 4.202325496476140, 5.485839479663457]
  1568. ],
  1569. [
  1570. [0.943713128785340, 1.923800464789872, 2.511941262351750],
  1571. [4.124882619205123, 4.889406461458511, 5.689675454116582]
  1572. ],
  1573. [
  1574. [1.487852512870631, 1.933859334657448, 2.681311906634522],
  1575. [4.124418827930267, 5.335204598518954, 5.988120342017037]
  1576. ],
  1577. [
  1578. [1.002470749319751, 1.386785511789551, 2.890832331097640],
  1579. [4.372884362128993, 4.729718562700068, 6.732322315921552]
  1580. ],
  1581. [
  1582. [1.421351511333299, 2.106946903600814, 2.654619331838720],
  1583. [4.188693248790616, 5.336439611284261, 5.279121290355546]
  1584. ]
  1585. ]
  1586. )
  1587. pdfs_j = np.array(
  1588. [
  1589. 0.082798951655369,
  1590. 0.119993852401118,
  1591. 0.151969434727803,
  1592. 0.003620324481841,
  1593. 0.072538716346179,
  1594. 0.027002666410192,
  1595. 0.485180162388507,
  1596. 0.135740468069511,
  1597. 0.013619162593841,
  1598. 0.034813885519299
  1599. ]
  1600. )
  1601. pdfs_py = matrix_t.pdf(samples_j, mean=M, row_spread=U, col_spread=V, df=df)
  1602. assert_allclose(pdfs_j, pdfs_py, rtol=rtol)
  1603. def test_pdf_against_mathematica(self):
  1604. """
  1605. Test values generated from Mathematica 13.0.0 for Linux x86 (64-bit)
  1606. Release ID 13.0.0.0 (7522564, 2021120311723), Patch Level 0
  1607. mu={{1,2,3},{4,5,6}};
  1608. sigma={{1,0.5},{0.5,1}};
  1609. omega={{1,0.3,0.2},{0.3,1,0.4},{0.2,0.4,1}};
  1610. df=5;
  1611. sampleSize=10;
  1612. SeedRandom[42];
  1613. dist=MatrixTDistribution[mu,sigma,omega,df];
  1614. samples=SetPrecision[RandomVariate[dist,sampleSize],15];
  1615. pdfs=SetPrecision[PDF[dist,#]&/@samples,15];
  1616. """
  1617. df = 5
  1618. M = np.array([[1, 2, 3], [4, 5, 6]])
  1619. U = np.array([[1, 0.5], [0.5, 1]])
  1620. V = np.array([[1, 0.3, 0.2], [0.3, 1, 0.4], [0.2, 0.4, 1]])
  1621. rtol = 1e-10
  1622. samples_m = np.array(
  1623. [
  1624. [
  1625. [0.639971699425374, 2.171718671534955, 2.575826093352771],
  1626. [4.031082477912233, 5.021680958526638, 6.268126154787008],
  1627. ],
  1628. [
  1629. [1.164842884206232, 2.526297099993045, 3.781375229865069],
  1630. [3.912979114956833, 4.202714884504189, 5.661830748993523],
  1631. ],
  1632. [
  1633. [1.00461853907369, 2.080028751298565, 3.406489485602410],
  1634. [3.993327716320432, 5.655909265966448, 6.578059791357837],
  1635. ],
  1636. [
  1637. [0.80625209501374, 2.529009560674907, 2.807513313302189],
  1638. [3.722896768794995, 5.26987322525995, 5.801155613199776],
  1639. ],
  1640. [
  1641. [0.445816208657817, 3.224059910964103, 2.954990980541423],
  1642. [3.451520519442941, 7.064424621385415, 5.438834195890955],
  1643. ],
  1644. [
  1645. [0.919232769636664, 2.374572300756703, 3.495118928313048],
  1646. [3.924447237903237, 5.627654256287447, 5.806104608153957],
  1647. ],
  1648. [
  1649. [2.014242004090113, 1.377018127709871, 3.114064311468686],
  1650. [3.88881648137925, 4.603482820518904, 5.714205489738063],
  1651. ],
  1652. [
  1653. [1.322000147426889, 2.602135838377777, 2.558921028724319],
  1654. [4.50534702030683, 5.861137323151889, 5.181872548334852],
  1655. ],
  1656. [
  1657. [1.448743656862261, 2.053847557652242, 3.637321543241769],
  1658. [4.097711403906707, 4.506916241403669, 5.68010653497977],
  1659. ],
  1660. [
  1661. [1.045187318995198, 1.645467189679729, 3.284396214544507],
  1662. [3.648493466445393, 5.004212508553601, 6.301624351328048],
  1663. ],
  1664. ]
  1665. )
  1666. pdfs_m = np.array(
  1667. [
  1668. 0.085671937131824,
  1669. 0.004821273644067,
  1670. 0.105978034029754,
  1671. 0.174250448808208,
  1672. 3.945711836053583e-05,
  1673. 0.027158790350349,
  1674. 0.00299095120309,
  1675. 0.005594546018078,
  1676. 0.025788366971310,
  1677. 0.120210733598845,
  1678. ]
  1679. )
  1680. pdfs_py = matrix_t.pdf(samples_m, mean=M, row_spread=U, col_spread=V, df=df)
  1681. assert_allclose(pdfs_m, pdfs_py, rtol=rtol)
  1682. def test_samples(self):
  1683. df = 5
  1684. num_rows = 4
  1685. num_cols = 3
  1686. M = np.full((num_rows, num_cols), 0.3)
  1687. U = 0.5 * np.identity(num_rows) + np.full((num_rows, num_rows), 0.5)
  1688. V = 0.7 * np.identity(num_cols) + np.full((num_cols, num_cols), 0.3)
  1689. N = 10**4
  1690. rtol = 0.05
  1691. # `rvs` performs Cholesky-inverse-Wishart sampling on the smaller
  1692. # dimension of `mean`
  1693. frozen = matrix_t(mean=M, row_spread=U, col_spread=V, df=df)
  1694. X = frozen.rvs(size=N, random_state=42) # column-wise rvs
  1695. m = X.mean(0)
  1696. frozenT = matrix_t(mean=M.T, row_spread=V, col_spread=U, df=df)
  1697. XT = frozenT.rvs(size=N, random_state=42) # row-wise rvs
  1698. mT = XT.mean(0)
  1699. # Gupta and Nagar (2000) Theorem 4.3.3 (p.137)
  1700. # --------------------------------------------
  1701. # If T follows a matrix variate t-distribution with mean M and row_spread U
  1702. # and col_spread V and df degrees of freedom, then its transpose T.T follows
  1703. # a matrix variate t-distribution with mean M.T and row_spread V and
  1704. # col_spread U and df degrees of freedom.
  1705. assert_allclose(M, m, rtol=rtol)
  1706. assert_allclose(M.T, mT, rtol=rtol)
  1707. assert_allclose(m, mT.T, rtol=rtol)
  1708. assert_allclose(m.T, mT, rtol=rtol)
  1709. @pytest.mark.parametrize("shape_case", ["row", "col"])
  1710. def test_against_multivariate_t(self, shape_case):
  1711. r"""
  1712. Gupta and Nagar (2000) p.133f
  1713. When the number of rows or the number of columns equals 1 the
  1714. matrix t reduces to the multivariate t. But, the matrix t
  1715. is parameterized by raw 2nd moments whereas the multivariate t is
  1716. parameterized by a covariance (raw 2nd central moment normalized by df).
  1717. We can see the difference by comparing the author's notation
  1718. $t_p(n, \omega, \mathbf{\mu}, \Sigma)$
  1719. for a matrix t with a single column
  1720. to the formula (4.1.2) for the PDF of the multivariate t.
  1721. """
  1722. rtol = 1e-6
  1723. df = 5
  1724. if shape_case == "row":
  1725. num_rows = 1
  1726. num_cols = 3
  1727. row_spread = 1
  1728. col_spread = np.array([[1, 0.3, 0.2], [0.3, 1, 0.4], [0.2, 0.4, 1]])
  1729. shape = col_spread / df
  1730. else: # shape_case == "col"
  1731. num_rows = 3
  1732. num_cols = 1
  1733. row_spread = np.array([[1, 0.3, 0.2], [0.3, 1, 0.4], [0.2, 0.4, 1]])
  1734. col_spread=1
  1735. shape = row_spread / df
  1736. M = np.full((num_rows, num_cols), 0.3)
  1737. t_mat = matrix_t(
  1738. mean=M, row_spread=row_spread, col_spread=col_spread, df=df
  1739. )
  1740. t_mvt = multivariate_t(loc=M.squeeze(), shape=shape, df=df)
  1741. X = t_mat.rvs(size=3, random_state=42)
  1742. t_mat_logpdf = t_mat.logpdf(X)
  1743. t_mvt_logpdf = t_mvt.logpdf(X.squeeze())
  1744. assert_allclose(t_mvt_logpdf, t_mat_logpdf, rtol=rtol)
  1745. class TestDirichlet:
  1746. def test_frozen_dirichlet(self):
  1747. rng = np.random.default_rng(2846)
  1748. n = rng.integers(1, 32)
  1749. alpha = rng.uniform(10e-10, 100, n)
  1750. d = dirichlet(alpha)
  1751. assert_equal(d.var(), dirichlet.var(alpha))
  1752. assert_equal(d.mean(), dirichlet.mean(alpha))
  1753. assert_equal(d.entropy(), dirichlet.entropy(alpha))
  1754. num_tests = 10
  1755. for i in range(num_tests):
  1756. x = rng.uniform(10e-10, 100, n)
  1757. x /= np.sum(x)
  1758. assert_equal(d.pdf(x[:-1]), dirichlet.pdf(x[:-1], alpha))
  1759. assert_equal(d.logpdf(x[:-1]), dirichlet.logpdf(x[:-1], alpha))
  1760. def test_numpy_rvs_shape_compatibility(self):
  1761. rng = np.random.default_rng(2846)
  1762. alpha = np.array([1.0, 2.0, 3.0])
  1763. x = rng.dirichlet(alpha, size=7)
  1764. assert_equal(x.shape, (7, 3))
  1765. assert_raises(ValueError, dirichlet.pdf, x, alpha)
  1766. assert_raises(ValueError, dirichlet.logpdf, x, alpha)
  1767. dirichlet.pdf(x.T, alpha)
  1768. dirichlet.pdf(x.T[:-1], alpha)
  1769. dirichlet.logpdf(x.T, alpha)
  1770. dirichlet.logpdf(x.T[:-1], alpha)
  1771. def test_alpha_with_zeros(self):
  1772. rng = np.random.default_rng(2846)
  1773. alpha = [1.0, 0.0, 3.0]
  1774. # don't pass invalid alpha to np.random.dirichlet
  1775. x = rng.dirichlet(np.maximum(1e-9, alpha), size=7).T
  1776. assert_raises(ValueError, dirichlet.pdf, x, alpha)
  1777. assert_raises(ValueError, dirichlet.logpdf, x, alpha)
  1778. def test_alpha_with_negative_entries(self):
  1779. rng = np.random.default_rng(2846)
  1780. alpha = [1.0, -2.0, 3.0]
  1781. # don't pass invalid alpha to np.random.dirichlet
  1782. x = rng.dirichlet(np.maximum(1e-9, alpha), size=7).T
  1783. assert_raises(ValueError, dirichlet.pdf, x, alpha)
  1784. assert_raises(ValueError, dirichlet.logpdf, x, alpha)
  1785. def test_data_with_zeros(self):
  1786. alpha = np.array([1.0, 2.0, 3.0, 4.0])
  1787. x = np.array([0.1, 0.0, 0.2, 0.7])
  1788. dirichlet.pdf(x, alpha)
  1789. dirichlet.logpdf(x, alpha)
  1790. alpha = np.array([1.0, 1.0, 1.0, 1.0])
  1791. assert_almost_equal(dirichlet.pdf(x, alpha), 6)
  1792. assert_almost_equal(dirichlet.logpdf(x, alpha), np.log(6))
  1793. def test_data_with_zeros_and_small_alpha(self):
  1794. alpha = np.array([1.0, 0.5, 3.0, 4.0])
  1795. x = np.array([0.1, 0.0, 0.2, 0.7])
  1796. assert_raises(ValueError, dirichlet.pdf, x, alpha)
  1797. assert_raises(ValueError, dirichlet.logpdf, x, alpha)
  1798. def test_data_with_negative_entries(self):
  1799. alpha = np.array([1.0, 2.0, 3.0, 4.0])
  1800. x = np.array([0.1, -0.1, 0.3, 0.7])
  1801. assert_raises(ValueError, dirichlet.pdf, x, alpha)
  1802. assert_raises(ValueError, dirichlet.logpdf, x, alpha)
  1803. def test_data_with_too_large_entries(self):
  1804. alpha = np.array([1.0, 2.0, 3.0, 4.0])
  1805. x = np.array([0.1, 1.1, 0.3, 0.7])
  1806. assert_raises(ValueError, dirichlet.pdf, x, alpha)
  1807. assert_raises(ValueError, dirichlet.logpdf, x, alpha)
  1808. def test_data_too_deep_c(self):
  1809. alpha = np.array([1.0, 2.0, 3.0])
  1810. x = np.full((2, 7, 7), 1 / 14)
  1811. assert_raises(ValueError, dirichlet.pdf, x, alpha)
  1812. assert_raises(ValueError, dirichlet.logpdf, x, alpha)
  1813. def test_alpha_too_deep(self):
  1814. alpha = np.array([[1.0, 2.0], [3.0, 4.0]])
  1815. x = np.full((2, 2, 7), 1 / 4)
  1816. assert_raises(ValueError, dirichlet.pdf, x, alpha)
  1817. assert_raises(ValueError, dirichlet.logpdf, x, alpha)
  1818. def test_alpha_correct_depth(self):
  1819. alpha = np.array([1.0, 2.0, 3.0])
  1820. x = np.full((3, 7), 1 / 3)
  1821. dirichlet.pdf(x, alpha)
  1822. dirichlet.logpdf(x, alpha)
  1823. def test_non_simplex_data(self):
  1824. alpha = np.array([1.0, 2.0, 3.0])
  1825. x = np.full((3, 7), 1 / 2)
  1826. assert_raises(ValueError, dirichlet.pdf, x, alpha)
  1827. assert_raises(ValueError, dirichlet.logpdf, x, alpha)
  1828. def test_data_vector_too_short(self):
  1829. alpha = np.array([1.0, 2.0, 3.0, 4.0])
  1830. x = np.full((2, 7), 1 / 2)
  1831. assert_raises(ValueError, dirichlet.pdf, x, alpha)
  1832. assert_raises(ValueError, dirichlet.logpdf, x, alpha)
  1833. def test_data_vector_too_long(self):
  1834. alpha = np.array([1.0, 2.0, 3.0, 4.0])
  1835. x = np.full((5, 7), 1 / 5)
  1836. assert_raises(ValueError, dirichlet.pdf, x, alpha)
  1837. assert_raises(ValueError, dirichlet.logpdf, x, alpha)
  1838. def test_mean_var_cov(self):
  1839. # Reference values calculated by hand and confirmed with Mathematica, e.g.
  1840. # `Covariance[DirichletDistribution[{ 1, 0.8, 0.2, 10^-300}]]`
  1841. alpha = np.array([1., 0.8, 0.2])
  1842. d = dirichlet(alpha)
  1843. expected_mean = [0.5, 0.4, 0.1]
  1844. expected_var = [1. / 12., 0.08, 0.03]
  1845. expected_cov = [
  1846. [ 1. / 12, -1. / 15, -1. / 60],
  1847. [-1. / 15, 2. / 25, -1. / 75],
  1848. [-1. / 60, -1. / 75, 3. / 100],
  1849. ]
  1850. assert_array_almost_equal(d.mean(), expected_mean)
  1851. assert_array_almost_equal(d.var(), expected_var)
  1852. assert_array_almost_equal(d.cov(), expected_cov)
  1853. def test_scalar_values(self):
  1854. alpha = np.array([0.2])
  1855. d = dirichlet(alpha)
  1856. # For alpha of length 1, mean and var should be scalar instead of array
  1857. assert_equal(d.mean().ndim, 0)
  1858. assert_equal(d.var().ndim, 0)
  1859. assert_equal(d.pdf([1.]).ndim, 0)
  1860. assert_equal(d.logpdf([1.]).ndim, 0)
  1861. def test_K_and_K_minus_1_calls_equal(self):
  1862. # Test that calls with K and K-1 entries yield the same results.
  1863. rng = np.random.default_rng(2846)
  1864. n = rng.integers(1, 32)
  1865. alpha = rng.uniform(10e-10, 100, n)
  1866. d = dirichlet(alpha)
  1867. num_tests = 10
  1868. for i in range(num_tests):
  1869. x = rng.uniform(10e-10, 100, n)
  1870. x /= np.sum(x)
  1871. assert_almost_equal(d.pdf(x[:-1]), d.pdf(x))
  1872. def test_multiple_entry_calls(self):
  1873. # Test that calls with multiple x vectors as matrix work
  1874. rng = np.random.default_rng(2846)
  1875. n = rng.integers(1, 32)
  1876. alpha = rng.uniform(10e-10, 100, n)
  1877. d = dirichlet(alpha)
  1878. num_tests = 10
  1879. num_multiple = 5
  1880. xm = None
  1881. for i in range(num_tests):
  1882. for m in range(num_multiple):
  1883. x = rng.uniform(10e-10, 100, n)
  1884. x /= np.sum(x)
  1885. if xm is not None:
  1886. xm = np.vstack((xm, x))
  1887. else:
  1888. xm = x
  1889. rm = d.pdf(xm.T)
  1890. rs = None
  1891. for xs in xm:
  1892. r = d.pdf(xs)
  1893. if rs is not None:
  1894. rs = np.append(rs, r)
  1895. else:
  1896. rs = r
  1897. assert_array_almost_equal(rm, rs)
  1898. def test_2D_dirichlet_is_beta(self):
  1899. rng = np.random.default_rng(2846)
  1900. alpha = rng.uniform(10e-10, 100, 2)
  1901. d = dirichlet(alpha)
  1902. b = beta(alpha[0], alpha[1])
  1903. num_tests = 10
  1904. for i in range(num_tests):
  1905. x = rng.uniform(10e-10, 100, 2)
  1906. x /= np.sum(x)
  1907. assert_almost_equal(b.pdf(x), d.pdf([x]))
  1908. assert_almost_equal(b.mean(), d.mean()[0])
  1909. assert_almost_equal(b.var(), d.var()[0])
  1910. def test_multivariate_normal_dimensions_mismatch():
  1911. # Regression test for GH #3493. Check that setting up a PDF with a mean of
  1912. # length M and a covariance matrix of size (N, N), where M != N, raises a
  1913. # ValueError with an informative error message.
  1914. mu = np.array([0.0, 0.0])
  1915. sigma = np.array([[1.0]])
  1916. assert_raises(ValueError, multivariate_normal, mu, sigma)
  1917. # A simple check that the right error message was passed along. Checking
  1918. # that the entire message is there, word for word, would be somewhat
  1919. # fragile, so we just check for the leading part.
  1920. try:
  1921. multivariate_normal(mu, sigma)
  1922. except ValueError as e:
  1923. msg = "Dimension mismatch"
  1924. assert_equal(str(e)[:len(msg)], msg)
  1925. class TestWishart:
  1926. def test_scale_dimensions(self):
  1927. # Test that we can call the Wishart with various scale dimensions
  1928. # Test case: dim=1, scale=1
  1929. true_scale = np.array(1, ndmin=2)
  1930. scales = [
  1931. 1, # scalar
  1932. [1], # iterable
  1933. np.array(1), # 0-dim
  1934. np.r_[1], # 1-dim
  1935. np.array(1, ndmin=2) # 2-dim
  1936. ]
  1937. for scale in scales:
  1938. w = wishart(1, scale)
  1939. assert_equal(w.scale, true_scale)
  1940. assert_equal(w.scale.shape, true_scale.shape)
  1941. # Test case: dim=2, scale=[[1,0]
  1942. # [0,2]
  1943. true_scale = np.array([[1,0],
  1944. [0,2]])
  1945. scales = [
  1946. [1,2], # iterable
  1947. np.r_[1,2], # 1-dim
  1948. np.array([[1,0], # 2-dim
  1949. [0,2]])
  1950. ]
  1951. for scale in scales:
  1952. w = wishart(2, scale)
  1953. assert_equal(w.scale, true_scale)
  1954. assert_equal(w.scale.shape, true_scale.shape)
  1955. # We cannot call with a df < dim - 1
  1956. assert_raises(ValueError, wishart, 1, np.eye(2))
  1957. # But we can call with dim - 1 < df < dim
  1958. wishart(1.1, np.eye(2)) # no error
  1959. # see gh-5562
  1960. # We cannot call with a 3-dimension array
  1961. scale = np.array(1, ndmin=3)
  1962. assert_raises(ValueError, wishart, 1, scale)
  1963. def test_quantile_dimensions(self):
  1964. # Test that we can call the Wishart rvs with various quantile dimensions
  1965. # If dim == 1, consider x.shape = [1,1,1]
  1966. X = [
  1967. 1, # scalar
  1968. [1], # iterable
  1969. np.array(1), # 0-dim
  1970. np.r_[1], # 1-dim
  1971. np.array(1, ndmin=2), # 2-dim
  1972. np.array([1], ndmin=3) # 3-dim
  1973. ]
  1974. w = wishart(1,1)
  1975. density = w.pdf(np.array(1, ndmin=3))
  1976. for x in X:
  1977. assert_equal(w.pdf(x), density)
  1978. # If dim == 1, consider x.shape = [1,1,*]
  1979. X = [
  1980. [1,2,3], # iterable
  1981. np.r_[1,2,3], # 1-dim
  1982. np.array([1,2,3], ndmin=3) # 3-dim
  1983. ]
  1984. w = wishart(1,1)
  1985. density = w.pdf(np.array([1,2,3], ndmin=3))
  1986. for x in X:
  1987. assert_equal(w.pdf(x), density)
  1988. # If dim == 2, consider x.shape = [2,2,1]
  1989. # where x[:,:,*] = np.eye(1)*2
  1990. X = [
  1991. 2, # scalar
  1992. [2,2], # iterable
  1993. np.array(2), # 0-dim
  1994. np.r_[2,2], # 1-dim
  1995. np.array([[2,0],
  1996. [0,2]]), # 2-dim
  1997. np.array([[2,0],
  1998. [0,2]])[:,:,np.newaxis] # 3-dim
  1999. ]
  2000. w = wishart(2,np.eye(2))
  2001. density = w.pdf(np.array([[2,0],
  2002. [0,2]])[:,:,np.newaxis])
  2003. for x in X:
  2004. assert_equal(w.pdf(x), density)
  2005. def test_frozen(self):
  2006. # Test that the frozen and non-frozen Wishart gives the same answers
  2007. # Construct an arbitrary positive definite scale matrix
  2008. dim = 4
  2009. scale = np.diag(np.arange(dim)+1)
  2010. scale[np.tril_indices(dim, k=-1)] = np.arange(dim * (dim-1) // 2)
  2011. scale = np.dot(scale.T, scale)
  2012. # Construct a collection of positive definite matrices to test the PDF
  2013. X = []
  2014. for i in range(5):
  2015. x = np.diag(np.arange(dim)+(i+1)**2)
  2016. x[np.tril_indices(dim, k=-1)] = np.arange(dim * (dim-1) // 2)
  2017. x = np.dot(x.T, x)
  2018. X.append(x)
  2019. X = np.array(X).T
  2020. # Construct a 1D and 2D set of parameters
  2021. parameters = [
  2022. (10, 1, np.linspace(0.1, 10, 5)), # 1D case
  2023. (10, scale, X)
  2024. ]
  2025. for (df, scale, x) in parameters:
  2026. w = wishart(df, scale)
  2027. assert_equal(w.var(), wishart.var(df, scale))
  2028. assert_equal(w.mean(), wishart.mean(df, scale))
  2029. assert_equal(w.mode(), wishart.mode(df, scale))
  2030. assert_equal(w.entropy(), wishart.entropy(df, scale))
  2031. assert_equal(w.pdf(x), wishart.pdf(x, df, scale))
  2032. def test_wishart_2D_rvs(self):
  2033. dim = 3
  2034. df = 10
  2035. # Construct a simple non-diagonal positive definite matrix
  2036. scale = np.eye(dim)
  2037. scale[0,1] = 0.5
  2038. scale[1,0] = 0.5
  2039. # Construct frozen Wishart random variables
  2040. w = wishart(df, scale)
  2041. # Get the generated random variables from a known seed
  2042. rng = np.random.RandomState(248042)
  2043. w_rvs = wishart.rvs(df, scale, random_state=rng)
  2044. rng = np.random.RandomState(248042)
  2045. frozen_w_rvs = w.rvs(random_state=rng)
  2046. # Manually calculate what it should be, based on the Bartlett (1933)
  2047. # decomposition of a Wishart into D A A' D', where D is the Cholesky
  2048. # factorization of the scale matrix and A is the lower triangular matrix
  2049. # with the square root of chi^2 variates on the diagonal and N(0,1)
  2050. # variates in the lower triangle.
  2051. rng = np.random.RandomState(248042)
  2052. covariances = rng.normal(size=3)
  2053. variances = np.r_[
  2054. rng.chisquare(df),
  2055. rng.chisquare(df-1),
  2056. rng.chisquare(df-2),
  2057. ]**0.5
  2058. # Construct the lower-triangular A matrix
  2059. A = np.diag(variances)
  2060. A[np.tril_indices(dim, k=-1)] = covariances
  2061. # Wishart random variate
  2062. D = np.linalg.cholesky(scale)
  2063. DA = D.dot(A)
  2064. manual_w_rvs = np.dot(DA, DA.T)
  2065. # Test for equality
  2066. assert_allclose(w_rvs, manual_w_rvs)
  2067. assert_allclose(frozen_w_rvs, manual_w_rvs)
  2068. def test_1D_is_chisquared(self):
  2069. # The 1-dimensional Wishart with an identity scale matrix is just a
  2070. # chi-squared distribution.
  2071. # Test variance, mean, entropy, pdf
  2072. # Kolgomorov-Smirnov test for rvs
  2073. rng = np.random.default_rng(482974)
  2074. sn = 500
  2075. dim = 1
  2076. scale = np.eye(dim)
  2077. df_range = np.arange(1, 10, 2, dtype=float)
  2078. X = np.linspace(0.1,10,num=10)
  2079. for df in df_range:
  2080. w = wishart(df, scale)
  2081. c = chi2(df)
  2082. # Statistics
  2083. assert_allclose(w.var(), c.var())
  2084. assert_allclose(w.mean(), c.mean())
  2085. assert_allclose(w.entropy(), c.entropy())
  2086. # PDF
  2087. assert_allclose(w.pdf(X), c.pdf(X))
  2088. # rvs
  2089. rvs = w.rvs(size=sn, random_state=rng)
  2090. args = (df,)
  2091. alpha = 0.01
  2092. check_distribution_rvs('chi2', args, alpha, rvs)
  2093. def test_is_scaled_chisquared(self):
  2094. # The 2-dimensional Wishart with an arbitrary scale matrix can be
  2095. # transformed to a scaled chi-squared distribution.
  2096. # For :math:`S \sim W_p(V,n)` and :math:`\lambda \in \mathbb{R}^p` we have
  2097. # :math:`\lambda' S \lambda \sim \lambda' V \lambda \times \chi^2(n)`
  2098. rng = np.random.default_rng(482974)
  2099. sn = 500
  2100. df = 10
  2101. dim = 4
  2102. # Construct an arbitrary positive definite matrix
  2103. scale = np.diag(np.arange(4)+1)
  2104. scale[np.tril_indices(4, k=-1)] = np.arange(6)
  2105. scale = np.dot(scale.T, scale)
  2106. # Use :math:`\lambda = [1, \dots, 1]'`
  2107. lamda = np.ones((dim,1))
  2108. sigma_lamda = lamda.T.dot(scale).dot(lamda).squeeze()
  2109. w = wishart(df, sigma_lamda)
  2110. c = chi2(df, scale=sigma_lamda)
  2111. # Statistics
  2112. assert_allclose(w.var(), c.var())
  2113. assert_allclose(w.mean(), c.mean())
  2114. assert_allclose(w.entropy(), c.entropy())
  2115. # PDF
  2116. X = np.linspace(0.1,10,num=10)
  2117. assert_allclose(w.pdf(X), c.pdf(X))
  2118. # rvs
  2119. rvs = w.rvs(size=sn, random_state=rng)
  2120. args = (df,0,sigma_lamda)
  2121. alpha = 0.01
  2122. check_distribution_rvs('chi2', args, alpha, rvs)
  2123. class TestMultinomial:
  2124. def test_logpmf(self):
  2125. vals1 = multinomial.logpmf((3,4), 7, (0.3, 0.7))
  2126. assert_allclose(vals1, -1.483270127243324, rtol=1e-8)
  2127. vals2 = multinomial.logpmf([3, 4], 0, [.3, .7])
  2128. assert vals2 == -np.inf
  2129. vals3 = multinomial.logpmf([0, 0], 0, [.3, .7])
  2130. assert vals3 == 0
  2131. vals4 = multinomial.logpmf([3, 4], 0, [-2, 3])
  2132. assert_allclose(vals4, np.nan, rtol=1e-8)
  2133. def test_reduces_binomial(self):
  2134. # test that the multinomial pmf reduces to the binomial pmf in the 2d
  2135. # case
  2136. val1 = multinomial.logpmf((3, 4), 7, (0.3, 0.7))
  2137. val2 = binom.logpmf(3, 7, 0.3)
  2138. assert_allclose(val1, val2, rtol=1e-8)
  2139. val1 = multinomial.pmf((6, 8), 14, (0.1, 0.9))
  2140. val2 = binom.pmf(6, 14, 0.1)
  2141. assert_allclose(val1, val2, rtol=1e-8)
  2142. def test_R(self):
  2143. # test against the values produced by this R code
  2144. # (https://stat.ethz.ch/R-manual/R-devel/library/stats/html/Multinom.html)
  2145. # X <- t(as.matrix(expand.grid(0:3, 0:3))); X <- X[, colSums(X) <= 3]
  2146. # X <- rbind(X, 3:3 - colSums(X)); dimnames(X) <- list(letters[1:3], NULL)
  2147. # X
  2148. # apply(X, 2, function(x) dmultinom(x, prob = c(1,2,5)))
  2149. n, p = 3, [1./8, 2./8, 5./8]
  2150. r_vals = {(0, 0, 3): 0.244140625, (1, 0, 2): 0.146484375,
  2151. (2, 0, 1): 0.029296875, (3, 0, 0): 0.001953125,
  2152. (0, 1, 2): 0.292968750, (1, 1, 1): 0.117187500,
  2153. (2, 1, 0): 0.011718750, (0, 2, 1): 0.117187500,
  2154. (1, 2, 0): 0.023437500, (0, 3, 0): 0.015625000}
  2155. for x in r_vals:
  2156. assert_allclose(multinomial.pmf(x, n, p), r_vals[x], atol=1e-14)
  2157. @pytest.mark.parametrize("n", [0, 3])
  2158. def test_rvs_np(self, n):
  2159. # test that .rvs agrees w/numpy
  2160. message = "Some rows of `p` do not sum to 1.0 within..."
  2161. with pytest.warns(FutureWarning, match=message):
  2162. rndm = np.random.RandomState(123)
  2163. sc_rvs = multinomial.rvs(n, [1/4.]*3, size=7, random_state=123)
  2164. np_rvs = rndm.multinomial(n, [1/4.]*3, size=7)
  2165. assert_equal(sc_rvs, np_rvs)
  2166. with pytest.warns(FutureWarning, match=message):
  2167. rndm = np.random.RandomState(123)
  2168. sc_rvs = multinomial.rvs(n, [1/4.]*5, size=7, random_state=123)
  2169. np_rvs = rndm.multinomial(n, [1/4.]*5, size=7)
  2170. assert_equal(sc_rvs, np_rvs)
  2171. def test_pmf(self):
  2172. vals0 = multinomial.pmf((5,), 5, (1,))
  2173. assert_allclose(vals0, 1, rtol=1e-8)
  2174. vals1 = multinomial.pmf((3,4), 7, (.3, .7))
  2175. assert_allclose(vals1, .22689449999999994, rtol=1e-8)
  2176. vals2 = multinomial.pmf([[[3,5],[0,8]], [[-1, 9], [1, 1]]], 8,
  2177. (.1, .9))
  2178. assert_allclose(vals2, [[.03306744, .43046721], [0, 0]], rtol=1e-8)
  2179. x = np.empty((0,2), dtype=np.float64)
  2180. vals3 = multinomial.pmf(x, 4, (.3, .7))
  2181. assert_equal(vals3, np.empty([], dtype=np.float64))
  2182. vals4 = multinomial.pmf([1,2], 4, (.3, .7))
  2183. assert_allclose(vals4, 0, rtol=1e-8)
  2184. vals5 = multinomial.pmf([3, 3, 0], 6, [2/3.0, 1/3.0, 0])
  2185. assert_allclose(vals5, 0.219478737997, rtol=1e-8)
  2186. vals5 = multinomial.pmf([0, 0, 0], 0, [2/3.0, 1/3.0, 0])
  2187. assert vals5 == 1
  2188. vals6 = multinomial.pmf([2, 1, 0], 0, [2/3.0, 1/3.0, 0])
  2189. assert vals6 == 0
  2190. def test_pmf_broadcasting(self):
  2191. vals0 = multinomial.pmf([1, 2], 3, [[.1, .9], [.2, .8]])
  2192. assert_allclose(vals0, [.243, .384], rtol=1e-8)
  2193. vals1 = multinomial.pmf([1, 2], [3, 4], [.1, .9])
  2194. assert_allclose(vals1, [.243, 0], rtol=1e-8)
  2195. vals2 = multinomial.pmf([[[1, 2], [1, 1]]], 3, [.1, .9])
  2196. assert_allclose(vals2, [[.243, 0]], rtol=1e-8)
  2197. vals3 = multinomial.pmf([1, 2], [[[3], [4]]], [.1, .9])
  2198. assert_allclose(vals3, [[[.243], [0]]], rtol=1e-8)
  2199. vals4 = multinomial.pmf([[1, 2], [1,1]], [[[[3]]]], [.1, .9])
  2200. assert_allclose(vals4, [[[[.243, 0]]]], rtol=1e-8)
  2201. @pytest.mark.parametrize("n", [0, 5])
  2202. def test_cov(self, n):
  2203. cov1 = multinomial.cov(n, (.2, .3, .5))
  2204. cov2 = [[n*.2*.8, -n*.2*.3, -n*.2*.5],
  2205. [-n*.3*.2, n*.3*.7, -n*.3*.5],
  2206. [-n*.5*.2, -n*.5*.3, n*.5*.5]]
  2207. assert_allclose(cov1, cov2, rtol=1e-8)
  2208. def test_cov_broadcasting(self):
  2209. cov1 = multinomial.cov(5, [[.1, .9], [.2, .8]])
  2210. cov2 = [[[.45, -.45],[-.45, .45]], [[.8, -.8], [-.8, .8]]]
  2211. assert_allclose(cov1, cov2, rtol=1e-8)
  2212. cov3 = multinomial.cov([4, 5], [.1, .9])
  2213. cov4 = [[[.36, -.36], [-.36, .36]], [[.45, -.45], [-.45, .45]]]
  2214. assert_allclose(cov3, cov4, rtol=1e-8)
  2215. cov5 = multinomial.cov([4, 5], [[.3, .7], [.4, .6]])
  2216. cov6 = [[[4*.3*.7, -4*.3*.7], [-4*.3*.7, 4*.3*.7]],
  2217. [[5*.4*.6, -5*.4*.6], [-5*.4*.6, 5*.4*.6]]]
  2218. assert_allclose(cov5, cov6, rtol=1e-8)
  2219. @pytest.mark.parametrize("n", [0, 2])
  2220. def test_entropy(self, n):
  2221. # this is equivalent to a binomial distribution with n=2, so the
  2222. # entropy .77899774929 is easily computed "by hand"
  2223. ent0 = multinomial.entropy(n, [.2, .8])
  2224. assert_allclose(ent0, binom.entropy(n, .2), rtol=1e-8)
  2225. def test_entropy_broadcasting(self):
  2226. ent0 = multinomial.entropy([2, 3], [.2, .8])
  2227. assert_allclose(ent0, [binom.entropy(2, .2), binom.entropy(3, .2)],
  2228. rtol=1e-8)
  2229. ent1 = multinomial.entropy([7, 8], [[.3, .7], [.4, .6]])
  2230. assert_allclose(ent1, [binom.entropy(7, .3), binom.entropy(8, .4)],
  2231. rtol=1e-8)
  2232. ent2 = multinomial.entropy([[7], [8]], [[.3, .7], [.4, .6]])
  2233. assert_allclose(ent2,
  2234. [[binom.entropy(7, .3), binom.entropy(7, .4)],
  2235. [binom.entropy(8, .3), binom.entropy(8, .4)]],
  2236. rtol=1e-8)
  2237. @pytest.mark.parametrize("n", [0, 5])
  2238. def test_mean(self, n):
  2239. mean1 = multinomial.mean(n, [.2, .8])
  2240. assert_allclose(mean1, [n*.2, n*.8], rtol=1e-8)
  2241. def test_mean_broadcasting(self):
  2242. mean1 = multinomial.mean([5, 6], [.2, .8])
  2243. assert_allclose(mean1, [[5*.2, 5*.8], [6*.2, 6*.8]], rtol=1e-8)
  2244. def test_frozen(self):
  2245. # The frozen distribution should agree with the regular one
  2246. n = 12
  2247. pvals = (.1, .2, .3, .4)
  2248. x = [[0,0,0,12],[0,0,1,11],[0,1,1,10],[1,1,1,9],[1,1,2,8]]
  2249. x = np.asarray(x, dtype=np.float64)
  2250. mn_frozen = multinomial(n, pvals)
  2251. assert_allclose(mn_frozen.pmf(x), multinomial.pmf(x, n, pvals))
  2252. assert_allclose(mn_frozen.logpmf(x), multinomial.logpmf(x, n, pvals))
  2253. assert_allclose(mn_frozen.entropy(), multinomial.entropy(n, pvals))
  2254. def test_gh_11860(self):
  2255. # gh-11860 reported cases in which the adjustments made by multinomial
  2256. # to the last element of `p` can cause `nan`s even when the input is
  2257. # essentially valid. Check that a pathological case returns a finite,
  2258. # nonzero result. (This would fail in main before the PR.)
  2259. n = 88
  2260. rng = np.random.default_rng(8879715917488330089)
  2261. p = rng.random(n)
  2262. p[-1] = 1e-30
  2263. p /= np.sum(p)
  2264. x = np.ones(n)
  2265. logpmf = multinomial.logpmf(x, n, p)
  2266. assert np.isfinite(logpmf)
  2267. @pytest.mark.parametrize('dtype', [np.float32, np.float64])
  2268. def test_gh_22565(self, dtype):
  2269. # Same issue as gh-11860 above, essentially, but the original
  2270. # fix didn't completely solve the problem.
  2271. n = 19
  2272. p = np.asarray([0.2, 0.2, 0.2, 0.2, 0.2], dtype=dtype)
  2273. res1 = multinomial.pmf(x=[1, 2, 5, 7, 4], n=n, p=p)
  2274. res2 = multinomial.pmf(x=[1, 2, 4, 5, 7], n=n, p=p)
  2275. np.testing.assert_allclose(res1, res2, rtol=1e-15)
  2276. class TestInvwishart:
  2277. def test_frozen(self):
  2278. # Test that the frozen and non-frozen inverse Wishart gives the same
  2279. # answers
  2280. # Construct an arbitrary positive definite scale matrix
  2281. dim = 4
  2282. scale = np.diag(np.arange(dim)+1)
  2283. scale[np.tril_indices(dim, k=-1)] = np.arange(dim*(dim-1)/2)
  2284. scale = np.dot(scale.T, scale)
  2285. # Construct a collection of positive definite matrices to test the PDF
  2286. X = []
  2287. for i in range(5):
  2288. x = np.diag(np.arange(dim)+(i+1)**2)
  2289. x[np.tril_indices(dim, k=-1)] = np.arange(dim*(dim-1)/2)
  2290. x = np.dot(x.T, x)
  2291. X.append(x)
  2292. X = np.array(X).T
  2293. # Construct a 1D and 2D set of parameters
  2294. parameters = [
  2295. (10, 1, np.linspace(0.1, 10, 5)), # 1D case
  2296. (10, scale, X)
  2297. ]
  2298. for (df, scale, x) in parameters:
  2299. iw = invwishart(df, scale)
  2300. assert_equal(iw.var(), invwishart.var(df, scale))
  2301. assert_equal(iw.mean(), invwishart.mean(df, scale))
  2302. assert_equal(iw.mode(), invwishart.mode(df, scale))
  2303. assert_allclose(iw.pdf(x), invwishart.pdf(x, df, scale))
  2304. def test_1D_is_invgamma(self):
  2305. # The 1-dimensional inverse Wishart with an identity scale matrix is
  2306. # just an inverse gamma distribution.
  2307. # Test variance, mean, pdf, entropy
  2308. # Kolgomorov-Smirnov test for rvs
  2309. rng = np.random.RandomState(482974)
  2310. sn = 500
  2311. dim = 1
  2312. scale = np.eye(dim)
  2313. df_range = np.arange(5, 20, 2, dtype=float)
  2314. X = np.linspace(0.1,10,num=10)
  2315. for df in df_range:
  2316. iw = invwishart(df, scale)
  2317. ig = invgamma(df/2, scale=1./2)
  2318. # Statistics
  2319. assert_allclose(iw.var(), ig.var())
  2320. assert_allclose(iw.mean(), ig.mean())
  2321. # PDF
  2322. assert_allclose(iw.pdf(X), ig.pdf(X))
  2323. # rvs
  2324. rvs = iw.rvs(size=sn, random_state=rng)
  2325. args = (df/2, 0, 1./2)
  2326. alpha = 0.01
  2327. check_distribution_rvs('invgamma', args, alpha, rvs)
  2328. # entropy
  2329. assert_allclose(iw.entropy(), ig.entropy())
  2330. def test_invwishart_2D_rvs(self):
  2331. dim = 3
  2332. df = 10
  2333. # Construct a simple non-diagonal positive definite matrix
  2334. scale = np.eye(dim)
  2335. scale[0,1] = 0.5
  2336. scale[1,0] = 0.5
  2337. # Construct frozen inverse-Wishart random variables
  2338. iw = invwishart(df, scale)
  2339. # Get the generated random variables from a known seed
  2340. rng = np.random.RandomState(608072)
  2341. iw_rvs = invwishart.rvs(df, scale, random_state=rng)
  2342. rng = np.random.RandomState(608072)
  2343. frozen_iw_rvs = iw.rvs(random_state=rng)
  2344. # Manually calculate what it should be, based on the decomposition in
  2345. # https://arxiv.org/abs/2310.15884 of an invers-Wishart into L L',
  2346. # where L A = D, D is the Cholesky factorization of the scale matrix,
  2347. # and A is the lower triangular matrix with the square root of chi^2
  2348. # variates on the diagonal and N(0,1) variates in the lower triangle.
  2349. # the diagonal chi^2 variates in this A are reversed compared to those
  2350. # in the Bartlett decomposition A for Wishart rvs.
  2351. rng = np.random.RandomState(608072)
  2352. covariances = rng.normal(size=3)
  2353. variances = np.r_[
  2354. rng.chisquare(df-2),
  2355. rng.chisquare(df-1),
  2356. rng.chisquare(df),
  2357. ]**0.5
  2358. # Construct the lower-triangular A matrix
  2359. A = np.diag(variances)
  2360. A[np.tril_indices(dim, k=-1)] = covariances
  2361. # inverse-Wishart random variate
  2362. D = np.linalg.cholesky(scale)
  2363. L = np.linalg.solve(A.T, D.T).T
  2364. manual_iw_rvs = np.dot(L, L.T)
  2365. # Test for equality
  2366. assert_allclose(iw_rvs, manual_iw_rvs)
  2367. assert_allclose(frozen_iw_rvs, manual_iw_rvs)
  2368. def test_sample_mean(self):
  2369. """Test that sample mean consistent with known mean."""
  2370. # Construct an arbitrary positive definite scale matrix
  2371. df = 10
  2372. sample_size = 20_000
  2373. for dim in [1, 5]:
  2374. scale = np.diag(np.arange(dim) + 1)
  2375. scale[np.tril_indices(dim, k=-1)] = np.arange(dim * (dim - 1) / 2)
  2376. scale = np.dot(scale.T, scale)
  2377. dist = invwishart(df, scale)
  2378. Xmean_exp = dist.mean()
  2379. Xvar_exp = dist.var()
  2380. Xmean_std = (Xvar_exp / sample_size)**0.5 # asymptotic SE of mean estimate
  2381. X = dist.rvs(size=sample_size, random_state=1234)
  2382. Xmean_est = X.mean(axis=0)
  2383. ntests = dim*(dim + 1)//2
  2384. fail_rate = 0.01 / ntests # correct for multiple tests
  2385. max_diff = norm.ppf(1 - fail_rate / 2)
  2386. assert np.allclose(
  2387. (Xmean_est - Xmean_exp) / Xmean_std,
  2388. 0,
  2389. atol=max_diff,
  2390. )
  2391. def test_logpdf_4x4(self):
  2392. """Regression test for gh-8844."""
  2393. X = np.array([[2, 1, 0, 0.5],
  2394. [1, 2, 0.5, 0.5],
  2395. [0, 0.5, 3, 1],
  2396. [0.5, 0.5, 1, 2]])
  2397. Psi = np.array([[9, 7, 3, 1],
  2398. [7, 9, 5, 1],
  2399. [3, 5, 8, 2],
  2400. [1, 1, 2, 9]])
  2401. nu = 6
  2402. prob = invwishart.logpdf(X, nu, Psi)
  2403. # Explicit calculation from the formula on wikipedia.
  2404. p = X.shape[0]
  2405. sig, logdetX = np.linalg.slogdet(X)
  2406. sig, logdetPsi = np.linalg.slogdet(Psi)
  2407. M = np.linalg.solve(X, Psi)
  2408. expected = ((nu/2)*logdetPsi
  2409. - (nu*p/2)*np.log(2)
  2410. - multigammaln(nu/2, p)
  2411. - (nu + p + 1)/2*logdetX
  2412. - 0.5*M.trace())
  2413. assert_allclose(prob, expected)
  2414. class TestSpecialOrthoGroup:
  2415. def test_reproducibility(self):
  2416. x = special_ortho_group.rvs(3, random_state=np.random.default_rng(514))
  2417. expected = np.array([[-0.93200988, 0.01533561, -0.36210826],
  2418. [0.35742128, 0.20446501, -0.91128705],
  2419. [0.06006333, -0.97875374, -0.19604469]])
  2420. assert_array_almost_equal(x, expected)
  2421. def test_invalid_dim(self):
  2422. assert_raises(ValueError, special_ortho_group.rvs, None)
  2423. assert_raises(ValueError, special_ortho_group.rvs, (2, 2))
  2424. assert_raises(ValueError, special_ortho_group.rvs, -1)
  2425. assert_raises(ValueError, special_ortho_group.rvs, 2.5)
  2426. def test_frozen_matrix(self):
  2427. dim = 7
  2428. frozen = special_ortho_group(dim)
  2429. rvs1 = frozen.rvs(random_state=1234)
  2430. rvs2 = special_ortho_group.rvs(dim, random_state=1234)
  2431. assert_equal(rvs1, rvs2)
  2432. def test_det_and_ortho(self):
  2433. xs = [special_ortho_group.rvs(dim)
  2434. for dim in range(2,12)
  2435. for i in range(3)]
  2436. # Test that determinants are always +1
  2437. dets = [np.linalg.det(x) for x in xs]
  2438. assert_allclose(dets, [1.]*30, rtol=1e-13)
  2439. # Test that these are orthogonal matrices
  2440. for x in xs:
  2441. assert_array_almost_equal(np.dot(x, x.T),
  2442. np.eye(x.shape[0]))
  2443. def test_haar(self):
  2444. # Test that the distribution is constant under rotation
  2445. # Every column should have the same distribution
  2446. # Additionally, the distribution should be invariant under another rotation
  2447. # Generate samples
  2448. dim = 5
  2449. samples = 1000 # Not too many, or the test takes too long
  2450. ks_prob = .05
  2451. xs = special_ortho_group.rvs(
  2452. dim, size=samples, random_state=np.random.default_rng(513)
  2453. )
  2454. # Dot a few rows (0, 1, 2) with unit vectors (0, 2, 4, 3),
  2455. # effectively picking off entries in the matrices of xs.
  2456. # These projections should all have the same distribution,
  2457. # establishing rotational invariance. We use the two-sided
  2458. # KS test to confirm this.
  2459. # We could instead test that angles between random vectors
  2460. # are uniformly distributed, but the below is sufficient.
  2461. # It is not feasible to consider all pairs, so pick a few.
  2462. els = ((0,0), (0,2), (1,4), (2,3))
  2463. #proj = {(er, ec): [x[er][ec] for x in xs] for er, ec in els}
  2464. proj = {(er, ec): sorted([x[er][ec] for x in xs]) for er, ec in els}
  2465. pairs = [(e0, e1) for e0 in els for e1 in els if e0 > e1]
  2466. ks_tests = [ks_2samp(proj[p0], proj[p1])[1] for (p0, p1) in pairs]
  2467. assert_array_less([ks_prob]*len(pairs), ks_tests)
  2468. def test_one_by_one(self):
  2469. # Test that the distribution is a delta function at the identity matrix
  2470. # when dim=1
  2471. assert_allclose(special_ortho_group.rvs(1, size=1000), 1, rtol=1e-13)
  2472. def test_zero_by_zero(self):
  2473. assert_equal(special_ortho_group.rvs(0, size=4).shape, (4, 0, 0))
  2474. class TestOrthoGroup:
  2475. def test_reproducibility(self):
  2476. seed = 514
  2477. rng = np.random.RandomState(seed)
  2478. x = ortho_group.rvs(3, random_state=rng)
  2479. x2 = ortho_group.rvs(3, random_state=seed)
  2480. # Note this matrix has det -1, distinguishing O(N) from SO(N)
  2481. assert_almost_equal(np.linalg.det(x), -1)
  2482. expected = np.array([[0.381686, -0.090374, 0.919863],
  2483. [0.905794, -0.161537, -0.391718],
  2484. [-0.183993, -0.98272, -0.020204]])
  2485. assert_array_almost_equal(x, expected)
  2486. assert_array_almost_equal(x2, expected)
  2487. def test_invalid_dim(self):
  2488. assert_raises(ValueError, ortho_group.rvs, None)
  2489. assert_raises(ValueError, ortho_group.rvs, (2, 2))
  2490. assert_raises(ValueError, ortho_group.rvs, -1)
  2491. assert_raises(ValueError, ortho_group.rvs, 2.5)
  2492. def test_frozen_matrix(self):
  2493. dim = 7
  2494. frozen = ortho_group(dim)
  2495. frozen_seed = ortho_group(dim, seed=1234)
  2496. rvs1 = frozen.rvs(random_state=1234)
  2497. rvs2 = ortho_group.rvs(dim, random_state=1234)
  2498. rvs3 = frozen_seed.rvs(size=1)
  2499. assert_equal(rvs1, rvs2)
  2500. assert_equal(rvs1, rvs3)
  2501. def test_det_and_ortho(self):
  2502. xs = [[ortho_group.rvs(dim)
  2503. for i in range(10)]
  2504. for dim in range(2,12)]
  2505. # Test that abs determinants are always +1
  2506. dets = np.array([[np.linalg.det(x) for x in xx] for xx in xs])
  2507. assert_allclose(np.fabs(dets), np.ones(dets.shape), rtol=1e-13)
  2508. # Test that these are orthogonal matrices
  2509. for xx in xs:
  2510. for x in xx:
  2511. assert_array_almost_equal(np.dot(x, x.T),
  2512. np.eye(x.shape[0]))
  2513. @pytest.mark.parametrize("dim", [2, 5, 10, 20])
  2514. def test_det_distribution_gh18272(self, dim):
  2515. # Test that positive and negative determinants are equally likely.
  2516. rng = np.random.default_rng(6796248956179332344)
  2517. dist = ortho_group(dim=dim)
  2518. rvs = dist.rvs(size=5000, random_state=rng)
  2519. dets = scipy.linalg.det(rvs)
  2520. k = np.sum(dets > 0)
  2521. n = len(dets)
  2522. res = stats.binomtest(k, n)
  2523. low, high = res.proportion_ci(confidence_level=0.95)
  2524. assert low < 0.5 < high
  2525. def test_haar(self):
  2526. # Test that the distribution is constant under rotation
  2527. # Every column should have the same distribution
  2528. # Additionally, the distribution should be invariant under another rotation
  2529. # Generate samples
  2530. dim = 5
  2531. samples = 1000 # Not too many, or the test takes too long
  2532. ks_prob = .05
  2533. rng = np.random.RandomState(518) # Note that the test is sensitive to seed too
  2534. xs = ortho_group.rvs(dim, size=samples, random_state=rng)
  2535. # Dot a few rows (0, 1, 2) with unit vectors (0, 2, 4, 3),
  2536. # effectively picking off entries in the matrices of xs.
  2537. # These projections should all have the same distribution,
  2538. # establishing rotational invariance. We use the two-sided
  2539. # KS test to confirm this.
  2540. # We could instead test that angles between random vectors
  2541. # are uniformly distributed, but the below is sufficient.
  2542. # It is not feasible to consider all pairs, so pick a few.
  2543. els = ((0,0), (0,2), (1,4), (2,3))
  2544. #proj = {(er, ec): [x[er][ec] for x in xs] for er, ec in els}
  2545. proj = {(er, ec): sorted([x[er][ec] for x in xs]) for er, ec in els}
  2546. pairs = [(e0, e1) for e0 in els for e1 in els if e0 > e1]
  2547. ks_tests = [ks_2samp(proj[p0], proj[p1])[1] for (p0, p1) in pairs]
  2548. assert_array_less([ks_prob]*len(pairs), ks_tests)
  2549. def test_one_by_one(self):
  2550. # Test that the 1x1 distribution gives ±1 with equal probability.
  2551. dim = 1
  2552. xs = ortho_group.rvs(dim, size=5000, random_state=np.random.default_rng(514))
  2553. assert_allclose(np.abs(xs), 1, rtol=1e-13)
  2554. k = np.sum(xs > 0)
  2555. n = len(xs)
  2556. res = stats.binomtest(k, n)
  2557. low, high = res.proportion_ci(confidence_level=0.95)
  2558. assert low < 0.5 < high
  2559. def test_zero_by_zero(self):
  2560. assert_equal(special_ortho_group.rvs(0, size=4).shape, (4, 0, 0))
  2561. @pytest.mark.slow
  2562. def test_pairwise_distances(self):
  2563. # Test that the distribution of pairwise distances is close to correct.
  2564. rng = np.random.RandomState(514)
  2565. def random_ortho(dim, random_state=None):
  2566. u, _s, v = np.linalg.svd(rng.normal(size=(dim, dim)))
  2567. return np.dot(u, v)
  2568. for dim in range(2, 6):
  2569. def generate_test_statistics(rvs, N=1000, eps=1e-10):
  2570. stats = np.array([
  2571. np.sum((rvs(dim=dim, random_state=rng) -
  2572. rvs(dim=dim, random_state=rng))**2)
  2573. for _ in range(N)
  2574. ])
  2575. # Add a bit of noise to account for numeric accuracy.
  2576. stats += np.random.uniform(-eps, eps, size=stats.shape)
  2577. return stats
  2578. expected = generate_test_statistics(random_ortho)
  2579. actual = generate_test_statistics(scipy.stats.ortho_group.rvs)
  2580. _D, p = scipy.stats.ks_2samp(expected, actual)
  2581. assert_array_less(.05, p)
  2582. class TestRandomCorrelation:
  2583. def test_reproducibility(self):
  2584. rng = np.random.RandomState(514)
  2585. eigs = (.5, .8, 1.2, 1.5)
  2586. x = random_correlation.rvs(eigs, random_state=rng)
  2587. x2 = random_correlation.rvs(eigs, random_state=514)
  2588. expected = np.array([[1., -0.184851, 0.109017, -0.227494],
  2589. [-0.184851, 1., 0.231236, 0.326669],
  2590. [0.109017, 0.231236, 1., -0.178912],
  2591. [-0.227494, 0.326669, -0.178912, 1.]])
  2592. assert_array_almost_equal(x, expected)
  2593. assert_array_almost_equal(x2, expected)
  2594. def test_invalid_eigs(self):
  2595. assert_raises(ValueError, random_correlation.rvs, None)
  2596. assert_raises(ValueError, random_correlation.rvs, 'test')
  2597. assert_raises(ValueError, random_correlation.rvs, 2.5)
  2598. assert_raises(ValueError, random_correlation.rvs, [2.5])
  2599. assert_raises(ValueError, random_correlation.rvs, [[1,2],[3,4]])
  2600. assert_raises(ValueError, random_correlation.rvs, [2.5, -.5])
  2601. assert_raises(ValueError, random_correlation.rvs, [1, 2, .1])
  2602. def test_frozen_matrix(self):
  2603. eigs = (.5, .8, 1.2, 1.5)
  2604. frozen = random_correlation(eigs)
  2605. frozen_seed = random_correlation(eigs, seed=514)
  2606. rvs1 = random_correlation.rvs(eigs, random_state=514)
  2607. rvs2 = frozen.rvs(random_state=514)
  2608. rvs3 = frozen_seed.rvs()
  2609. assert_equal(rvs1, rvs2)
  2610. assert_equal(rvs1, rvs3)
  2611. def test_definition(self):
  2612. # Test the definition of a correlation matrix in several dimensions:
  2613. #
  2614. # 1. Det is product of eigenvalues (and positive by construction
  2615. # in examples)
  2616. # 2. 1's on diagonal
  2617. # 3. Matrix is symmetric
  2618. def norm(i, e):
  2619. return i*e/sum(e)
  2620. rng = np.random.RandomState(123)
  2621. eigs = [norm(i, rng.uniform(size=i)) for i in range(2, 6)]
  2622. eigs.append([4,0,0,0])
  2623. ones = [[1.]*len(e) for e in eigs]
  2624. xs = [random_correlation.rvs(e, random_state=rng) for e in eigs]
  2625. # Test that determinants are products of eigenvalues
  2626. # These are positive by construction
  2627. # Could also test that the eigenvalues themselves are correct,
  2628. # but this seems sufficient.
  2629. dets = [np.fabs(np.linalg.det(x)) for x in xs]
  2630. dets_known = [np.prod(e) for e in eigs]
  2631. assert_allclose(dets, dets_known, rtol=1e-13, atol=1e-13)
  2632. # Test for 1's on the diagonal
  2633. diags = [np.diag(x) for x in xs]
  2634. for a, b in zip(diags, ones):
  2635. assert_allclose(a, b, rtol=1e-13)
  2636. # Correlation matrices are symmetric
  2637. for x in xs:
  2638. assert_allclose(x, x.T, rtol=1e-13)
  2639. def test_to_corr(self):
  2640. # Check some corner cases in to_corr
  2641. # ajj == 1
  2642. m = np.array([[0.1, 0], [0, 1]], dtype=float)
  2643. m = random_correlation._to_corr(m)
  2644. assert_allclose(m, np.array([[1, 0], [0, 0.1]]))
  2645. # Floating point overflow; fails to compute the correct
  2646. # rotation, but should still produce some valid rotation
  2647. # rather than infs/nans
  2648. with np.errstate(over='ignore'):
  2649. g = np.array([[0, 1], [-1, 0]])
  2650. m0 = np.array([[1e300, 0], [0, np.nextafter(1, 0)]], dtype=float)
  2651. m = random_correlation._to_corr(m0.copy())
  2652. assert_allclose(m, g.T.dot(m0).dot(g))
  2653. m0 = np.array([[0.9, 1e300], [1e300, 1.1]], dtype=float)
  2654. m = random_correlation._to_corr(m0.copy())
  2655. assert_allclose(m, g.T.dot(m0).dot(g))
  2656. # Zero discriminant; should set the first diag entry to 1
  2657. m0 = np.array([[2, 1], [1, 2]], dtype=float)
  2658. m = random_correlation._to_corr(m0.copy())
  2659. assert_allclose(m[0,0], 1)
  2660. # Slightly negative discriminant; should be approx correct still
  2661. m0 = np.array([[2 + 1e-7, 1], [1, 2]], dtype=float)
  2662. m = random_correlation._to_corr(m0.copy())
  2663. assert_allclose(m[0,0], 1)
  2664. class TestUniformDirection:
  2665. @pytest.mark.parametrize("dim", [1, 3])
  2666. @pytest.mark.parametrize("size", [None, 1, 5, (5, 4)])
  2667. def test_samples(self, dim, size):
  2668. # test that samples have correct shape and norm 1
  2669. rng = np.random.default_rng(2777937887058094419)
  2670. uniform_direction_dist = uniform_direction(dim, seed=rng)
  2671. samples = uniform_direction_dist.rvs(size)
  2672. mean, cov = np.zeros(dim), np.eye(dim)
  2673. expected_shape = rng.multivariate_normal(mean, cov, size=size).shape
  2674. assert samples.shape == expected_shape
  2675. norms = np.linalg.norm(samples, axis=-1)
  2676. assert_allclose(norms, 1.)
  2677. @pytest.mark.parametrize("dim", [None, 0, (2, 2), 2.5])
  2678. def test_invalid_dim(self, dim):
  2679. message = ("Dimension of vector must be specified, "
  2680. "and must be an integer greater than 0.")
  2681. with pytest.raises(ValueError, match=message):
  2682. uniform_direction.rvs(dim)
  2683. def test_frozen_distribution(self):
  2684. dim = 5
  2685. frozen = uniform_direction(dim)
  2686. frozen_seed = uniform_direction(dim, seed=514)
  2687. rvs1 = frozen.rvs(random_state=514)
  2688. rvs2 = uniform_direction.rvs(dim, random_state=514)
  2689. rvs3 = frozen_seed.rvs()
  2690. assert_equal(rvs1, rvs2)
  2691. assert_equal(rvs1, rvs3)
  2692. @pytest.mark.parametrize("dim", [2, 5, 8])
  2693. def test_uniform(self, dim):
  2694. rng = np.random.default_rng(1036978481269651776)
  2695. spherical_dist = uniform_direction(dim, seed=rng)
  2696. # generate random, orthogonal vectors
  2697. v1, v2 = spherical_dist.rvs(size=2)
  2698. v2 -= v1 @ v2 * v1
  2699. v2 /= np.linalg.norm(v2)
  2700. assert_allclose(v1 @ v2, 0, atol=1e-14) # orthogonal
  2701. # generate data and project onto orthogonal vectors
  2702. samples = spherical_dist.rvs(size=10000)
  2703. s1 = samples @ v1
  2704. s2 = samples @ v2
  2705. angles = np.arctan2(s1, s2)
  2706. # test that angles follow a uniform distribution
  2707. # normalize angles to range [0, 1]
  2708. angles += np.pi
  2709. angles /= 2*np.pi
  2710. # perform KS test
  2711. uniform_dist = uniform()
  2712. kstest_result = kstest(angles, uniform_dist.cdf)
  2713. assert kstest_result.pvalue > 0.05
  2714. class TestUnitaryGroup:
  2715. def test_reproducibility(self):
  2716. rng = np.random.RandomState(514)
  2717. x = unitary_group.rvs(3, random_state=rng)
  2718. x2 = unitary_group.rvs(3, random_state=514)
  2719. expected = np.array(
  2720. [[0.308771+0.360312j, 0.044021+0.622082j, 0.160327+0.600173j],
  2721. [0.732757+0.297107j, 0.076692-0.4614j, -0.394349+0.022613j],
  2722. [-0.148844+0.357037j, -0.284602-0.557949j, 0.607051+0.299257j]]
  2723. )
  2724. assert_array_almost_equal(x, expected)
  2725. assert_array_almost_equal(x2, expected)
  2726. def test_invalid_dim(self):
  2727. assert_raises(ValueError, unitary_group.rvs, None)
  2728. assert_raises(ValueError, unitary_group.rvs, (2, 2))
  2729. assert_raises(ValueError, unitary_group.rvs, -1)
  2730. assert_raises(ValueError, unitary_group.rvs, 2.5)
  2731. def test_frozen_matrix(self):
  2732. dim = 7
  2733. frozen = unitary_group(dim)
  2734. frozen_seed = unitary_group(dim, seed=514)
  2735. rvs1 = frozen.rvs(random_state=514)
  2736. rvs2 = unitary_group.rvs(dim, random_state=514)
  2737. rvs3 = frozen_seed.rvs(size=1)
  2738. assert_equal(rvs1, rvs2)
  2739. assert_equal(rvs1, rvs3)
  2740. def test_unitarity(self):
  2741. xs = [unitary_group.rvs(dim)
  2742. for dim in range(2,12)
  2743. for i in range(3)]
  2744. # Test that these are unitary matrices
  2745. for x in xs:
  2746. assert_allclose(np.dot(x, x.conj().T), np.eye(x.shape[0]), atol=1e-15)
  2747. def test_haar(self):
  2748. # Test that the eigenvalues, which lie on the unit circle in
  2749. # the complex plane, are uncorrelated.
  2750. # Generate samples
  2751. for dim in (1, 5):
  2752. samples = 1000 # Not too many, or the test takes too long
  2753. # Note that the test is sensitive to seed too
  2754. xs = unitary_group.rvs(
  2755. dim, size=samples, random_state=np.random.default_rng(514)
  2756. )
  2757. # The angles "x" of the eigenvalues should be uniformly distributed
  2758. # Overall this seems to be a necessary but weak test of the distribution.
  2759. eigs = np.vstack([scipy.linalg.eigvals(x) for x in xs])
  2760. x = np.arctan2(eigs.imag, eigs.real)
  2761. res = kstest(x.ravel(), uniform(-np.pi, 2*np.pi).cdf)
  2762. assert_(res.pvalue > 0.05)
  2763. def test_zero_by_zero(self):
  2764. assert_equal(unitary_group.rvs(0, size=4).shape, (4, 0, 0))
  2765. class TestMultivariateT:
  2766. # These tests were created by running vpa(mvtpdf(...)) in MATLAB. The
  2767. # function takes no `mu` parameter. The tests were run as
  2768. #
  2769. # >> ans = vpa(mvtpdf(x - mu, shape, df));
  2770. #
  2771. PDF_TESTS = [(
  2772. # x
  2773. [
  2774. [1, 2],
  2775. [4, 1],
  2776. [2, 1],
  2777. [2, 4],
  2778. [1, 4],
  2779. [4, 1],
  2780. [3, 2],
  2781. [3, 3],
  2782. [4, 4],
  2783. [5, 1],
  2784. ],
  2785. # loc
  2786. [0, 0],
  2787. # shape
  2788. [
  2789. [1, 0],
  2790. [0, 1]
  2791. ],
  2792. # df
  2793. 4,
  2794. # ans
  2795. [
  2796. 0.013972450422333741737457302178882,
  2797. 0.0010998721906793330026219646100571,
  2798. 0.013972450422333741737457302178882,
  2799. 0.00073682844024025606101402363634634,
  2800. 0.0010998721906793330026219646100571,
  2801. 0.0010998721906793330026219646100571,
  2802. 0.0020732579600816823488240725481546,
  2803. 0.00095660371505271429414668515889275,
  2804. 0.00021831953784896498569831346792114,
  2805. 0.00037725616140301147447000396084604
  2806. ]
  2807. ), (
  2808. # x
  2809. [
  2810. [0.9718, 0.1298, 0.8134],
  2811. [0.4922, 0.5522, 0.7185],
  2812. [0.3010, 0.1491, 0.5008],
  2813. [0.5971, 0.2585, 0.8940],
  2814. [0.5434, 0.5287, 0.9507],
  2815. ],
  2816. # loc
  2817. [-1, 1, 50],
  2818. # shape
  2819. [
  2820. [1.0000, 0.5000, 0.2500],
  2821. [0.5000, 1.0000, -0.1000],
  2822. [0.2500, -0.1000, 1.0000],
  2823. ],
  2824. # df
  2825. 8,
  2826. # ans
  2827. [
  2828. 0.00000000000000069609279697467772867405511133763,
  2829. 0.00000000000000073700739052207366474839369535934,
  2830. 0.00000000000000069522909962669171512174435447027,
  2831. 0.00000000000000074212293557998314091880208889767,
  2832. 0.00000000000000077039675154022118593323030449058,
  2833. ]
  2834. )]
  2835. @pytest.mark.parametrize("x, loc, shape, df, ans", PDF_TESTS)
  2836. def test_pdf_correctness(self, x, loc, shape, df, ans):
  2837. dist = multivariate_t(loc, shape, df, seed=0)
  2838. val = dist.pdf(x)
  2839. assert_array_almost_equal(val, ans)
  2840. @pytest.mark.parametrize("x, loc, shape, df, ans", PDF_TESTS)
  2841. def test_logpdf_correct(self, x, loc, shape, df, ans):
  2842. dist = multivariate_t(loc, shape, df, seed=0)
  2843. val1 = dist.pdf(x)
  2844. val2 = dist.logpdf(x)
  2845. assert_array_almost_equal(np.log(val1), val2)
  2846. # https://github.com/scipy/scipy/issues/10042#issuecomment-576795195
  2847. def test_mvt_with_df_one_is_cauchy(self):
  2848. x = [9, 7, 4, 1, -3, 9, 0, -3, -1, 3]
  2849. val = multivariate_t.pdf(x, df=1)
  2850. ans = cauchy.pdf(x)
  2851. assert_array_almost_equal(val, ans)
  2852. def test_mvt_with_high_df_is_approx_normal(self):
  2853. # `normaltest` returns the chi-squared statistic and the associated
  2854. # p-value. The null hypothesis is that `x` came from a normal
  2855. # distribution, so a low p-value represents rejecting the null, i.e.
  2856. # that it is unlikely that `x` came a normal distribution.
  2857. P_VAL_MIN = 0.1
  2858. dist = multivariate_t(0, 1, df=100000, seed=1)
  2859. samples = dist.rvs(size=100000)
  2860. _, p = normaltest(samples)
  2861. assert (p > P_VAL_MIN)
  2862. dist = multivariate_t([-2, 3], [[10, -1], [-1, 10]], df=100000,
  2863. seed=42)
  2864. samples = dist.rvs(size=100000)
  2865. _, p = normaltest(samples)
  2866. assert ((p > P_VAL_MIN).all())
  2867. @pytest.mark.thread_unsafe(reason="uses mocking")
  2868. @patch('scipy.stats.multivariate_normal._logpdf')
  2869. def test_mvt_with_inf_df_calls_normal(self, mock):
  2870. dist = multivariate_t(0, 1, df=np.inf, seed=7)
  2871. assert isinstance(dist, multivariate_normal_frozen)
  2872. multivariate_t.pdf(0, df=np.inf)
  2873. assert mock.call_count == 1
  2874. multivariate_t.logpdf(0, df=np.inf)
  2875. assert mock.call_count == 2
  2876. def test_shape_correctness(self):
  2877. # pdf and logpdf should return scalar when the
  2878. # number of samples in x is one.
  2879. dim = 4
  2880. loc = np.zeros(dim)
  2881. shape = np.eye(dim)
  2882. df = 4.5
  2883. x = np.zeros(dim)
  2884. res = multivariate_t(loc, shape, df).pdf(x)
  2885. assert np.isscalar(res)
  2886. res = multivariate_t(loc, shape, df).logpdf(x)
  2887. assert np.isscalar(res)
  2888. # pdf() and logpdf() should return probabilities of shape
  2889. # (n_samples,) when x has n_samples.
  2890. n_samples = 7
  2891. rng = np.random.default_rng(2767231913)
  2892. x = rng.random((n_samples, dim))
  2893. res = multivariate_t(loc, shape, df).pdf(x)
  2894. assert (res.shape == (n_samples,))
  2895. res = multivariate_t(loc, shape, df).logpdf(x)
  2896. assert (res.shape == (n_samples,))
  2897. # rvs() should return scalar unless a size argument is applied.
  2898. res = multivariate_t(np.zeros(1), np.eye(1), 1).rvs()
  2899. assert np.isscalar(res)
  2900. # rvs() should return vector of shape (size,) if size argument
  2901. # is applied.
  2902. size = 7
  2903. res = multivariate_t(np.zeros(1), np.eye(1), 1).rvs(size=size)
  2904. assert (res.shape == (size,))
  2905. def test_default_arguments(self):
  2906. dist = multivariate_t()
  2907. assert_equal(dist.loc, [0])
  2908. assert_equal(dist.shape, [[1]])
  2909. assert (dist.df == 1)
  2910. DEFAULT_ARGS_TESTS = [
  2911. (None, None, None, 0, 1, 1),
  2912. (None, None, 7, 0, 1, 7),
  2913. (None, [[7, 0], [0, 7]], None, [0, 0], [[7, 0], [0, 7]], 1),
  2914. (None, [[7, 0], [0, 7]], 7, [0, 0], [[7, 0], [0, 7]], 7),
  2915. ([7, 7], None, None, [7, 7], [[1, 0], [0, 1]], 1),
  2916. ([7, 7], None, 7, [7, 7], [[1, 0], [0, 1]], 7),
  2917. ([7, 7], [[7, 0], [0, 7]], None, [7, 7], [[7, 0], [0, 7]], 1),
  2918. ([7, 7], [[7, 0], [0, 7]], 7, [7, 7], [[7, 0], [0, 7]], 7)
  2919. ]
  2920. @pytest.mark.parametrize("loc, shape, df, loc_ans, shape_ans, df_ans",
  2921. DEFAULT_ARGS_TESTS)
  2922. def test_default_args(self, loc, shape, df, loc_ans, shape_ans, df_ans):
  2923. dist = multivariate_t(loc=loc, shape=shape, df=df)
  2924. assert_equal(dist.loc, loc_ans)
  2925. assert_equal(dist.shape, shape_ans)
  2926. assert (dist.df == df_ans)
  2927. ARGS_SHAPES_TESTS = [
  2928. (-1, 2, 3, [-1], [[2]], 3),
  2929. ([-1], [2], 3, [-1], [[2]], 3),
  2930. (np.array([-1]), np.array([2]), 3, [-1], [[2]], 3)
  2931. ]
  2932. @pytest.mark.parametrize("loc, shape, df, loc_ans, shape_ans, df_ans",
  2933. ARGS_SHAPES_TESTS)
  2934. def test_scalar_list_and_ndarray_arguments(self, loc, shape, df, loc_ans,
  2935. shape_ans, df_ans):
  2936. dist = multivariate_t(loc, shape, df)
  2937. assert_equal(dist.loc, loc_ans)
  2938. assert_equal(dist.shape, shape_ans)
  2939. assert_equal(dist.df, df_ans)
  2940. def test_argument_error_handling(self):
  2941. # `loc` should be a one-dimensional vector.
  2942. loc = [[1, 1]]
  2943. assert_raises(ValueError,
  2944. multivariate_t,
  2945. **dict(loc=loc))
  2946. # `shape` should be scalar or square matrix.
  2947. shape = [[1, 1], [2, 2], [3, 3]]
  2948. assert_raises(ValueError,
  2949. multivariate_t,
  2950. **dict(loc=loc, shape=shape))
  2951. # `df` should be greater than zero.
  2952. loc = np.zeros(2)
  2953. shape = np.eye(2)
  2954. df = -1
  2955. assert_raises(ValueError,
  2956. multivariate_t,
  2957. **dict(loc=loc, shape=shape, df=df))
  2958. df = 0
  2959. assert_raises(ValueError,
  2960. multivariate_t,
  2961. **dict(loc=loc, shape=shape, df=df))
  2962. def test_reproducibility(self):
  2963. rng = np.random.RandomState(4)
  2964. loc = rng.uniform(size=3)
  2965. shape = np.eye(3)
  2966. dist1 = multivariate_t(loc, shape, df=3, seed=2)
  2967. dist2 = multivariate_t(loc, shape, df=3, seed=2)
  2968. samples1 = dist1.rvs(size=10)
  2969. samples2 = dist2.rvs(size=10)
  2970. assert_equal(samples1, samples2)
  2971. def test_allow_singular(self):
  2972. # Make shape singular and verify error was raised.
  2973. args = dict(loc=[0,0], shape=[[0,0],[0,1]], df=1, allow_singular=False)
  2974. assert_raises(np.linalg.LinAlgError, multivariate_t, **args)
  2975. @pytest.mark.parametrize("size", [(10, 3), (5, 6, 4, 3)])
  2976. @pytest.mark.parametrize("dim", [2, 3, 4, 5])
  2977. @pytest.mark.parametrize("df", [1., 2., np.inf])
  2978. def test_rvs(self, size, dim, df):
  2979. dist = multivariate_t(np.zeros(dim), np.eye(dim), df)
  2980. rvs = dist.rvs(size=size)
  2981. assert rvs.shape == size + (dim, )
  2982. def test_cdf_signs(self):
  2983. # check that sign of output is correct when np.any(lower > x)
  2984. mean = np.zeros(3)
  2985. cov = np.eye(3)
  2986. df = 10
  2987. b = [[1, 1, 1], [0, 0, 0], [1, 0, 1], [0, 1, 0]]
  2988. a = [[0, 0, 0], [1, 1, 1], [0, 1, 0], [1, 0, 1]]
  2989. # when odd number of elements of b < a, output is negative
  2990. expected_signs = np.array([1, -1, -1, 1])
  2991. cdf = multivariate_normal.cdf(b, mean, cov, df, lower_limit=a)
  2992. assert_allclose(cdf, cdf[0]*expected_signs)
  2993. @pytest.mark.parametrize('dim', [1, 2, 5])
  2994. def test_cdf_against_multivariate_normal(self, dim):
  2995. # Check accuracy against MVN randomly-generated cases
  2996. self.cdf_against_mvn_test(dim)
  2997. @pytest.mark.parametrize('dim', [3, 6, 9])
  2998. def test_cdf_against_multivariate_normal_singular(self, dim):
  2999. # Check accuracy against MVN for randomly-generated singular cases
  3000. self.cdf_against_mvn_test(3, True)
  3001. def cdf_against_mvn_test(self, dim, singular=False):
  3002. # Check for accuracy in the limit that df -> oo and MVT -> MVN
  3003. rng = np.random.default_rng(413722918996573)
  3004. n = 3
  3005. w = 10**rng.uniform(-2, 1, size=dim)
  3006. cov = _random_covariance(dim, w, rng, singular)
  3007. mean = 10**rng.uniform(-1, 2, size=dim) * np.sign(rng.normal(size=dim))
  3008. a = -10**rng.uniform(-1, 2, size=(n, dim)) + mean
  3009. b = 10**rng.uniform(-1, 2, size=(n, dim)) + mean
  3010. res = stats.multivariate_t.cdf(b, mean, cov, df=10000, lower_limit=a,
  3011. allow_singular=True, random_state=rng)
  3012. ref = stats.multivariate_normal.cdf(b, mean, cov, allow_singular=True,
  3013. lower_limit=a)
  3014. assert_allclose(res, ref, atol=5e-4)
  3015. def test_cdf_against_univariate_t(self):
  3016. rng = np.random.default_rng(413722918996573)
  3017. cov = 2
  3018. mean = 0
  3019. x = rng.normal(size=10, scale=np.sqrt(cov))
  3020. df = 3
  3021. res = stats.multivariate_t.cdf(x, mean, cov, df, lower_limit=-np.inf,
  3022. random_state=rng)
  3023. ref = stats.t.cdf(x, df, mean, np.sqrt(cov))
  3024. incorrect = stats.norm.cdf(x, mean, np.sqrt(cov))
  3025. assert_allclose(res, ref, atol=5e-4) # close to t
  3026. assert np.all(np.abs(res - incorrect) > 1e-3) # not close to normal
  3027. @pytest.mark.parametrize("dim", [2, 3, 5, 10])
  3028. @pytest.mark.parametrize("seed", [3363958638, 7891119608, 3887698049,
  3029. 5013150848, 1495033423, 6170824608])
  3030. @pytest.mark.parametrize("singular", [False, True])
  3031. def test_cdf_against_qsimvtv(self, dim, seed, singular):
  3032. if singular and seed != 3363958638:
  3033. pytest.skip('Agreement with qsimvtv is not great in singular case')
  3034. rng = np.random.default_rng(seed)
  3035. w = 10**rng.uniform(-2, 2, size=dim)
  3036. cov = _random_covariance(dim, w, rng, singular)
  3037. mean = rng.random(dim)
  3038. a = -rng.random(dim)
  3039. b = rng.random(dim)
  3040. df = rng.random() * 5
  3041. # no lower limit
  3042. res = stats.multivariate_t.cdf(b, mean, cov, df, random_state=rng,
  3043. allow_singular=True)
  3044. with np.errstate(invalid='ignore'):
  3045. ref = _qsimvtv(20000, df, cov, np.inf*a, b - mean, rng)[0]
  3046. assert_allclose(res, ref, atol=2e-4, rtol=1e-3)
  3047. # with lower limit
  3048. res = stats.multivariate_t.cdf(b, mean, cov, df, lower_limit=a,
  3049. random_state=rng, allow_singular=True)
  3050. with np.errstate(invalid='ignore'):
  3051. ref = _qsimvtv(20000, df, cov, a - mean, b - mean, rng)[0]
  3052. assert_allclose(res, ref, atol=1e-4, rtol=1e-3)
  3053. @pytest.mark.slow
  3054. def test_cdf_against_generic_integrators(self):
  3055. # Compare result against generic numerical integrators
  3056. dim = 3
  3057. rng = np.random.default_rng(41372291899657)
  3058. w = 10 ** rng.uniform(-1, 1, size=dim)
  3059. cov = _random_covariance(dim, w, rng, singular=True)
  3060. mean = rng.random(dim)
  3061. a = -rng.random(dim)
  3062. b = rng.random(dim)
  3063. df = rng.random() * 5
  3064. res = stats.multivariate_t.cdf(b, mean, cov, df, random_state=rng,
  3065. lower_limit=a)
  3066. def integrand(x):
  3067. return stats.multivariate_t.pdf(x.T, mean, cov, df)
  3068. ref = qmc_quad(integrand, a, b, qrng=stats.qmc.Halton(d=dim, seed=rng))
  3069. assert_allclose(res, ref.integral, rtol=1e-3)
  3070. def integrand(*zyx):
  3071. return stats.multivariate_t.pdf(zyx[::-1], mean, cov, df)
  3072. ref = tplquad(integrand, a[0], b[0], a[1], b[1], a[2], b[2])
  3073. assert_allclose(res, ref[0], rtol=1e-3)
  3074. def test_against_matlab(self):
  3075. # Test against matlab mvtcdf:
  3076. # C = [6.21786909 0.2333667 7.95506077;
  3077. # 0.2333667 29.67390923 16.53946426;
  3078. # 7.95506077 16.53946426 19.17725252]
  3079. # df = 1.9559939787727658
  3080. # mvtcdf([0, 0, 0], C, df) % 0.2523
  3081. rng = np.random.default_rng(2967390923)
  3082. cov = np.array([[ 6.21786909, 0.2333667 , 7.95506077],
  3083. [ 0.2333667 , 29.67390923, 16.53946426],
  3084. [ 7.95506077, 16.53946426, 19.17725252]])
  3085. df = 1.9559939787727658
  3086. dist = stats.multivariate_t(shape=cov, df=df)
  3087. res = dist.cdf([0, 0, 0], random_state=rng)
  3088. ref = 0.2523
  3089. assert_allclose(res, ref, rtol=1e-3)
  3090. def test_frozen(self):
  3091. seed = 4137229573
  3092. rng = np.random.default_rng(seed)
  3093. loc = rng.uniform(size=3)
  3094. x = rng.uniform(size=3) + loc
  3095. shape = np.eye(3)
  3096. df = rng.random()
  3097. args = (loc, shape, df)
  3098. rng_frozen = np.random.default_rng(seed)
  3099. rng_unfrozen = np.random.default_rng(seed)
  3100. dist = stats.multivariate_t(*args, seed=rng_frozen)
  3101. assert_equal(dist.cdf(x),
  3102. multivariate_t.cdf(x, *args, random_state=rng_unfrozen))
  3103. def test_vectorized(self):
  3104. dim = 4
  3105. n = (2, 3)
  3106. rng = np.random.default_rng(413722918996573)
  3107. A = rng.random(size=(dim, dim))
  3108. cov = A @ A.T
  3109. mean = rng.random(dim)
  3110. x = rng.random(n + (dim,))
  3111. df = rng.random() * 5
  3112. res = stats.multivariate_t.cdf(x, mean, cov, df, random_state=rng)
  3113. def _cdf_1d(x):
  3114. return _qsimvtv(10000, df, cov, -np.inf*x, x-mean, rng)[0]
  3115. ref = np.apply_along_axis(_cdf_1d, -1, x)
  3116. assert_allclose(res, ref, atol=1e-4, rtol=1e-3)
  3117. @pytest.mark.parametrize("dim", (3, 7))
  3118. def test_against_analytical(self, dim):
  3119. rng = np.random.default_rng(413722918996573)
  3120. A = scipy.linalg.toeplitz(c=[1] + [0.5] * (dim - 1))
  3121. res = stats.multivariate_t(shape=A).cdf([0] * dim, random_state=rng)
  3122. ref = 1 / (dim + 1)
  3123. assert_allclose(res, ref, rtol=5e-5)
  3124. def test_entropy_inf_df(self):
  3125. cov = np.eye(3, 3)
  3126. df = np.inf
  3127. mvt_entropy = stats.multivariate_t.entropy(shape=cov, df=df)
  3128. mvn_entropy = stats.multivariate_normal.entropy(None, cov)
  3129. assert mvt_entropy == mvn_entropy
  3130. @pytest.mark.parametrize("df", [1, 10, 100])
  3131. def test_entropy_1d(self, df):
  3132. mvt_entropy = stats.multivariate_t.entropy(shape=1., df=df)
  3133. t_entropy = stats.t.entropy(df=df)
  3134. assert_allclose(mvt_entropy, t_entropy, rtol=1e-13)
  3135. # entropy reference values were computed via numerical integration
  3136. #
  3137. # def integrand(x, y, mvt):
  3138. # vec = np.array([x, y])
  3139. # return mvt.logpdf(vec) * mvt.pdf(vec)
  3140. # def multivariate_t_entropy_quad_2d(df, cov):
  3141. # dim = cov.shape[0]
  3142. # loc = np.zeros((dim, ))
  3143. # mvt = stats.multivariate_t(loc, cov, df)
  3144. # limit = 100
  3145. # return -integrate.dblquad(integrand, -limit, limit, -limit, limit,
  3146. # args=(mvt, ))[0]
  3147. @pytest.mark.parametrize("df, cov, ref, tol",
  3148. [(10, np.eye(2, 2), 3.0378770664093313, 1e-14),
  3149. (100, np.array([[0.5, 1], [1, 10]]),
  3150. 3.55102424550609, 1e-8)])
  3151. def test_entropy_vs_numerical_integration(self, df, cov, ref, tol):
  3152. loc = np.zeros((2, ))
  3153. mvt = stats.multivariate_t(loc, cov, df)
  3154. assert_allclose(mvt.entropy(), ref, rtol=tol)
  3155. @pytest.mark.parametrize(
  3156. "df, dim, ref, tol",
  3157. [
  3158. (10, 1, 1.5212624929756808, 1e-15),
  3159. (100, 1, 1.4289633653182439, 1e-13),
  3160. (500, 1, 1.420939531869349, 1e-14),
  3161. (1e20, 1, 1.4189385332046727, 1e-15),
  3162. (1e100, 1, 1.4189385332046727, 1e-15),
  3163. (10, 10, 15.069150450832911, 1e-15),
  3164. (1000, 10, 14.19936546446673, 1e-13),
  3165. (1e20, 10, 14.189385332046728, 1e-15),
  3166. (1e100, 10, 14.189385332046728, 1e-15),
  3167. (10, 100, 148.28902883192654, 1e-15),
  3168. (1000, 100, 141.99155538003762, 1e-14),
  3169. (1e20, 100, 141.8938533204673, 1e-15),
  3170. (1e100, 100, 141.8938533204673, 1e-15),
  3171. ]
  3172. )
  3173. def test_extreme_entropy(self, df, dim, ref, tol):
  3174. # Reference values were calculated with mpmath:
  3175. # from mpmath import mp
  3176. # mp.dps = 500
  3177. #
  3178. # def mul_t_mpmath_entropy(dim, df=1):
  3179. # dim = mp.mpf(dim)
  3180. # df = mp.mpf(df)
  3181. # halfsum = (dim + df)/2
  3182. # half_df = df/2
  3183. #
  3184. # return float(
  3185. # -mp.loggamma(halfsum) + mp.loggamma(half_df)
  3186. # + dim / 2 * mp.log(df * mp.pi)
  3187. # + halfsum * (mp.digamma(halfsum) - mp.digamma(half_df))
  3188. # + 0.0
  3189. # )
  3190. mvt = stats.multivariate_t(shape=np.eye(dim), df=df)
  3191. assert_allclose(mvt.entropy(), ref, rtol=tol)
  3192. def test_entropy_with_covariance(self):
  3193. # Generated using np.randn(5, 5) and then rounding
  3194. # to two decimal places
  3195. _A = np.array([
  3196. [1.42, 0.09, -0.49, 0.17, 0.74],
  3197. [-1.13, -0.01, 0.71, 0.4, -0.56],
  3198. [1.07, 0.44, -0.28, -0.44, 0.29],
  3199. [-1.5, -0.94, -0.67, 0.73, -1.1],
  3200. [0.17, -0.08, 1.46, -0.32, 1.36]
  3201. ])
  3202. # Set cov to be a symmetric positive semi-definite matrix
  3203. cov = _A @ _A.T
  3204. # Test the asymptotic case. For large degrees of freedom
  3205. # the entropy approaches the multivariate normal entropy.
  3206. df = 1e20
  3207. mul_t_entropy = stats.multivariate_t.entropy(shape=cov, df=df)
  3208. mul_norm_entropy = multivariate_normal(None, cov=cov).entropy()
  3209. assert_allclose(mul_t_entropy, mul_norm_entropy, rtol=1e-15)
  3210. # Test the regular case. For a dim of 5 the threshold comes out
  3211. # to be approximately 766.45. So using slightly
  3212. # different dfs on each site of the threshold, the entropies
  3213. # are being compared.
  3214. df1 = 765
  3215. df2 = 768
  3216. _entropy1 = stats.multivariate_t.entropy(shape=cov, df=df1)
  3217. _entropy2 = stats.multivariate_t.entropy(shape=cov, df=df2)
  3218. assert_allclose(_entropy1, _entropy2, rtol=1e-5)
  3219. def test_logpdf_df_inf_gh19930(self):
  3220. # `multivariate_t._logpdf` (and `logpdf`/`pdf`) was not working with infinite
  3221. # `df` after an update to `multivariate_normal._logpdf`.
  3222. # Reproducible example from the issue
  3223. res = multivariate_t.logpdf(1, 1, 1, df=np.inf)
  3224. ref = multivariate_normal.logpdf(1, 1, 1)
  3225. assert_allclose(res, ref)
  3226. # More extensive test
  3227. # Generate a valid multivariate normal distribution and corresponding MVT
  3228. rng = np.random.default_rng(324893259825)
  3229. mean = rng.random(3)
  3230. cov = rng.random((3, 3)) + np.eye(3)*3
  3231. cov = cov.T + cov
  3232. X = multivariate_normal(mean=mean, cov=cov)
  3233. Y = multivariate_t(loc=mean, shape=cov, df=np.inf)
  3234. # compare the pdf and logpdf at 10 random points
  3235. x = X.rvs(10)
  3236. assert_allclose(Y.logpdf(x), X.logpdf(x))
  3237. assert_allclose(Y.pdf(x), X.pdf(x))
  3238. class TestMultivariateHypergeom:
  3239. @pytest.mark.parametrize(
  3240. "x, m, n, expected",
  3241. [
  3242. # Ground truth value from R dmvhyper
  3243. ([3, 4], [5, 10], 7, -1.119814),
  3244. # test for `n=0`
  3245. ([3, 4], [5, 10], 0, -np.inf),
  3246. # test for `x < 0`
  3247. ([-3, 4], [5, 10], 7, -np.inf),
  3248. # test for `m < 0` (RuntimeWarning issue)
  3249. ([3, 4], [-5, 10], 7, np.nan),
  3250. # test for all `m < 0` and `x.sum() != n`
  3251. ([[1, 2], [3, 4]], [[-4, -6], [-5, -10]],
  3252. [3, 7], [np.nan, np.nan]),
  3253. # test for `x < 0` and `m < 0` (RuntimeWarning issue)
  3254. ([-3, 4], [-5, 10], 1, np.nan),
  3255. # test for `x > m`
  3256. ([1, 11], [10, 1], 12, np.nan),
  3257. # test for `m < 0` (RuntimeWarning issue)
  3258. ([1, 11], [10, -1], 12, np.nan),
  3259. # test for `n < 0`
  3260. ([3, 4], [5, 10], -7, np.nan),
  3261. # test for `x.sum() != n`
  3262. ([3, 3], [5, 10], 7, -np.inf)
  3263. ]
  3264. )
  3265. def test_logpmf(self, x, m, n, expected):
  3266. vals = multivariate_hypergeom.logpmf(x, m, n)
  3267. assert_allclose(vals, expected, rtol=1e-6)
  3268. def test_reduces_hypergeom(self):
  3269. # test that the multivariate_hypergeom pmf reduces to the
  3270. # hypergeom pmf in the 2d case.
  3271. val1 = multivariate_hypergeom.pmf(x=[3, 1], m=[10, 5], n=4)
  3272. val2 = hypergeom.pmf(k=3, M=15, n=4, N=10)
  3273. assert_allclose(val1, val2, rtol=1e-8)
  3274. val1 = multivariate_hypergeom.pmf(x=[7, 3], m=[15, 10], n=10)
  3275. val2 = hypergeom.pmf(k=7, M=25, n=10, N=15)
  3276. assert_allclose(val1, val2, rtol=1e-8)
  3277. def test_rvs(self):
  3278. # test if `rvs` is unbiased and large sample size converges
  3279. # to the true mean.
  3280. rv = multivariate_hypergeom(m=[3, 5], n=4)
  3281. rvs = rv.rvs(size=1000, random_state=123)
  3282. assert_allclose(rvs.mean(0), rv.mean(), rtol=1e-2)
  3283. def test_rvs_broadcasting(self):
  3284. rv = multivariate_hypergeom(m=[[3, 5], [5, 10]], n=[4, 9])
  3285. rvs = rv.rvs(size=(1000, 2), random_state=123)
  3286. assert_allclose(rvs.mean(0), rv.mean(), rtol=1e-2)
  3287. @pytest.mark.parametrize('m, n', (
  3288. ([0, 0, 20, 0, 0], 5), ([0, 0, 0, 0, 0], 0),
  3289. ([0, 0], 0), ([0], 0)
  3290. ))
  3291. def test_rvs_gh16171(self, m, n):
  3292. res = multivariate_hypergeom.rvs(m, n)
  3293. m = np.asarray(m)
  3294. res_ex = m.copy()
  3295. res_ex[m != 0] = n
  3296. assert_equal(res, res_ex)
  3297. @pytest.mark.parametrize(
  3298. "x, m, n, expected",
  3299. [
  3300. ([5], [5], 5, 1),
  3301. ([3, 4], [5, 10], 7, 0.3263403),
  3302. # Ground truth value from R dmvhyper
  3303. ([[[3, 5], [0, 8]], [[-1, 9], [1, 1]]],
  3304. [5, 10], [[8, 8], [8, 2]],
  3305. [[0.3916084, 0.006993007], [0, 0.4761905]]),
  3306. # test with empty arrays.
  3307. (np.array([], dtype=int), np.array([], dtype=int), 0, []),
  3308. ([1, 2], [4, 5], 5, 0),
  3309. # Ground truth value from R dmvhyper
  3310. ([3, 3, 0], [5, 6, 7], 6, 0.01077354)
  3311. ]
  3312. )
  3313. def test_pmf(self, x, m, n, expected):
  3314. vals = multivariate_hypergeom.pmf(x, m, n)
  3315. assert_allclose(vals, expected, rtol=1e-7)
  3316. @pytest.mark.parametrize(
  3317. "x, m, n, expected",
  3318. [
  3319. ([3, 4], [[5, 10], [10, 15]], 7, [0.3263403, 0.3407531]),
  3320. ([[1], [2]], [[3], [4]], [1, 3], [1., 0.]),
  3321. ([[[1], [2]]], [[3], [4]], [1, 3], [[1., 0.]]),
  3322. ([[1], [2]], [[[[3]]]], [1, 3], [[[1., 0.]]])
  3323. ]
  3324. )
  3325. def test_pmf_broadcasting(self, x, m, n, expected):
  3326. vals = multivariate_hypergeom.pmf(x, m, n)
  3327. assert_allclose(vals, expected, rtol=1e-7)
  3328. def test_cov(self):
  3329. cov1 = multivariate_hypergeom.cov(m=[3, 7, 10], n=12)
  3330. cov2 = [[0.64421053, -0.26526316, -0.37894737],
  3331. [-0.26526316, 1.14947368, -0.88421053],
  3332. [-0.37894737, -0.88421053, 1.26315789]]
  3333. assert_allclose(cov1, cov2, rtol=1e-8)
  3334. def test_cov_broadcasting(self):
  3335. cov1 = multivariate_hypergeom.cov(m=[[7, 9], [10, 15]], n=[8, 12])
  3336. cov2 = [[[1.05, -1.05], [-1.05, 1.05]],
  3337. [[1.56, -1.56], [-1.56, 1.56]]]
  3338. assert_allclose(cov1, cov2, rtol=1e-8)
  3339. cov3 = multivariate_hypergeom.cov(m=[[4], [5]], n=[4, 5])
  3340. cov4 = [[[0.]], [[0.]]]
  3341. assert_allclose(cov3, cov4, rtol=1e-8)
  3342. cov5 = multivariate_hypergeom.cov(m=[7, 9], n=[8, 12])
  3343. cov6 = [[[1.05, -1.05], [-1.05, 1.05]],
  3344. [[0.7875, -0.7875], [-0.7875, 0.7875]]]
  3345. assert_allclose(cov5, cov6, rtol=1e-8)
  3346. def test_var(self):
  3347. # test with hypergeom
  3348. var0 = multivariate_hypergeom.var(m=[10, 5], n=4)
  3349. var1 = hypergeom.var(M=15, n=4, N=10)
  3350. assert_allclose(var0, var1, rtol=1e-8)
  3351. def test_var_broadcasting(self):
  3352. var0 = multivariate_hypergeom.var(m=[10, 5], n=[4, 8])
  3353. var1 = multivariate_hypergeom.var(m=[10, 5], n=4)
  3354. var2 = multivariate_hypergeom.var(m=[10, 5], n=8)
  3355. assert_allclose(var0[0], var1, rtol=1e-8)
  3356. assert_allclose(var0[1], var2, rtol=1e-8)
  3357. var3 = multivariate_hypergeom.var(m=[[10, 5], [10, 14]], n=[4, 8])
  3358. var4 = [[0.6984127, 0.6984127], [1.352657, 1.352657]]
  3359. assert_allclose(var3, var4, rtol=1e-8)
  3360. var5 = multivariate_hypergeom.var(m=[[5], [10]], n=[5, 10])
  3361. var6 = [[0.], [0.]]
  3362. assert_allclose(var5, var6, rtol=1e-8)
  3363. def test_mean(self):
  3364. # test with hypergeom
  3365. mean0 = multivariate_hypergeom.mean(m=[10, 5], n=4)
  3366. mean1 = hypergeom.mean(M=15, n=4, N=10)
  3367. assert_allclose(mean0[0], mean1, rtol=1e-8)
  3368. mean2 = multivariate_hypergeom.mean(m=[12, 8], n=10)
  3369. mean3 = [12.*10./20., 8.*10./20.]
  3370. assert_allclose(mean2, mean3, rtol=1e-8)
  3371. def test_mean_broadcasting(self):
  3372. mean0 = multivariate_hypergeom.mean(m=[[3, 5], [10, 5]], n=[4, 8])
  3373. mean1 = [[3.*4./8., 5.*4./8.], [10.*8./15., 5.*8./15.]]
  3374. assert_allclose(mean0, mean1, rtol=1e-8)
  3375. def test_mean_edge_cases(self):
  3376. mean0 = multivariate_hypergeom.mean(m=[0, 0, 0], n=0)
  3377. assert_equal(mean0, [0., 0., 0.])
  3378. mean1 = multivariate_hypergeom.mean(m=[1, 0, 0], n=2)
  3379. assert_equal(mean1, [np.nan, np.nan, np.nan])
  3380. mean2 = multivariate_hypergeom.mean(m=[[1, 0, 0], [1, 0, 1]], n=2)
  3381. assert_allclose(mean2, [[np.nan, np.nan, np.nan], [1., 0., 1.]],
  3382. rtol=1e-17)
  3383. mean3 = multivariate_hypergeom.mean(m=np.array([], dtype=int), n=0)
  3384. assert_equal(mean3, [])
  3385. assert_(mean3.shape == (0, ))
  3386. def test_var_edge_cases(self):
  3387. var0 = multivariate_hypergeom.var(m=[0, 0, 0], n=0)
  3388. assert_allclose(var0, [0., 0., 0.], rtol=1e-16)
  3389. var1 = multivariate_hypergeom.var(m=[1, 0, 0], n=2)
  3390. assert_equal(var1, [np.nan, np.nan, np.nan])
  3391. var2 = multivariate_hypergeom.var(m=[[1, 0, 0], [1, 0, 1]], n=2)
  3392. assert_allclose(var2, [[np.nan, np.nan, np.nan], [0., 0., 0.]],
  3393. rtol=1e-17)
  3394. var3 = multivariate_hypergeom.var(m=np.array([], dtype=int), n=0)
  3395. assert_equal(var3, [])
  3396. assert_(var3.shape == (0, ))
  3397. def test_cov_edge_cases(self):
  3398. cov0 = multivariate_hypergeom.cov(m=[1, 0, 0], n=1)
  3399. cov1 = [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]
  3400. assert_allclose(cov0, cov1, rtol=1e-17)
  3401. cov3 = multivariate_hypergeom.cov(m=[0, 0, 0], n=0)
  3402. cov4 = [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]
  3403. assert_equal(cov3, cov4)
  3404. cov5 = multivariate_hypergeom.cov(m=np.array([], dtype=int), n=0)
  3405. cov6 = np.array([], dtype=np.float64).reshape(0, 0)
  3406. assert_allclose(cov5, cov6, rtol=1e-17)
  3407. assert_(cov5.shape == (0, 0))
  3408. def test_frozen(self):
  3409. # The frozen distribution should agree with the regular one
  3410. n = 12
  3411. m = [7, 9, 11, 13]
  3412. x = [[0, 0, 0, 12], [0, 0, 1, 11], [0, 1, 1, 10],
  3413. [1, 1, 1, 9], [1, 1, 2, 8]]
  3414. x = np.asarray(x, dtype=int)
  3415. mhg_frozen = multivariate_hypergeom(m, n)
  3416. assert_allclose(mhg_frozen.pmf(x),
  3417. multivariate_hypergeom.pmf(x, m, n))
  3418. assert_allclose(mhg_frozen.logpmf(x),
  3419. multivariate_hypergeom.logpmf(x, m, n))
  3420. assert_allclose(mhg_frozen.var(), multivariate_hypergeom.var(m, n))
  3421. assert_allclose(mhg_frozen.cov(), multivariate_hypergeom.cov(m, n))
  3422. def test_invalid_params(self):
  3423. assert_raises(ValueError, multivariate_hypergeom.pmf, 5, 10, 5)
  3424. assert_raises(ValueError, multivariate_hypergeom.pmf, 5, [10], 5)
  3425. assert_raises(ValueError, multivariate_hypergeom.pmf, [5, 4], [10], 5)
  3426. assert_raises(TypeError, multivariate_hypergeom.pmf, [5.5, 4.5],
  3427. [10, 15], 5)
  3428. assert_raises(TypeError, multivariate_hypergeom.pmf, [5, 4],
  3429. [10.5, 15.5], 5)
  3430. assert_raises(TypeError, multivariate_hypergeom.pmf, [5, 4],
  3431. [10, 15], 5.5)
  3432. class TestRandomTable:
  3433. def get_rng(self):
  3434. return np.random.default_rng(628174795866951638)
  3435. def test_process_parameters(self):
  3436. message = "`row` must be one-dimensional"
  3437. with pytest.raises(ValueError, match=message):
  3438. random_table([[1, 2]], [1, 2])
  3439. message = "`col` must be one-dimensional"
  3440. with pytest.raises(ValueError, match=message):
  3441. random_table([1, 2], [[1, 2]])
  3442. message = "each element of `row` must be non-negative"
  3443. with pytest.raises(ValueError, match=message):
  3444. random_table([1, -1], [1, 2])
  3445. message = "each element of `col` must be non-negative"
  3446. with pytest.raises(ValueError, match=message):
  3447. random_table([1, 2], [1, -2])
  3448. message = "sums over `row` and `col` must be equal"
  3449. with pytest.raises(ValueError, match=message):
  3450. random_table([1, 2], [1, 0])
  3451. message = "each element of `row` must be an integer"
  3452. with pytest.raises(ValueError, match=message):
  3453. random_table([2.1, 2.1], [1, 1, 2])
  3454. message = "each element of `col` must be an integer"
  3455. with pytest.raises(ValueError, match=message):
  3456. random_table([1, 2], [1.1, 1.1, 1])
  3457. row = [1, 3]
  3458. col = [2, 1, 1]
  3459. r, c, n = random_table._process_parameters([1, 3], [2, 1, 1])
  3460. assert_equal(row, r)
  3461. assert_equal(col, c)
  3462. assert n == np.sum(row)
  3463. @pytest.mark.parametrize("scale,method",
  3464. ((1, "boyett"), (100, "patefield")))
  3465. def test_process_rvs_method_on_None(self, scale, method):
  3466. row = np.array([1, 3]) * scale
  3467. col = np.array([2, 1, 1]) * scale
  3468. ct = random_table
  3469. expected = ct.rvs(row, col, method=method, random_state=1)
  3470. got = ct.rvs(row, col, method=None, random_state=1)
  3471. assert_equal(expected, got)
  3472. def test_process_rvs_method_bad_argument(self):
  3473. row = [1, 3]
  3474. col = [2, 1, 1]
  3475. # order of items in set is random, so cannot check that
  3476. message = "'foo' not recognized, must be one of"
  3477. with pytest.raises(ValueError, match=message):
  3478. random_table.rvs(row, col, method="foo")
  3479. @pytest.mark.parametrize('frozen', (True, False))
  3480. @pytest.mark.parametrize('log', (True, False))
  3481. def test_pmf_logpmf(self, frozen, log):
  3482. # The pmf is tested through random sample generation
  3483. # with Boyett's algorithm, whose implementation is simple
  3484. # enough to verify manually for correctness.
  3485. rng = self.get_rng()
  3486. row = [2, 6]
  3487. col = [1, 3, 4]
  3488. rvs = random_table.rvs(row, col, size=1000,
  3489. method="boyett", random_state=rng)
  3490. obj = random_table(row, col) if frozen else random_table
  3491. method = getattr(obj, "logpmf" if log else "pmf")
  3492. if not frozen:
  3493. original_method = method
  3494. def method(x):
  3495. return original_method(x, row, col)
  3496. pmf = (lambda x: np.exp(method(x))) if log else method
  3497. unique_rvs, counts = np.unique(rvs, axis=0, return_counts=True)
  3498. # rough accuracy check
  3499. p = pmf(unique_rvs)
  3500. assert_allclose(p * len(rvs), counts, rtol=0.1)
  3501. # accept any iterable
  3502. p2 = pmf(list(unique_rvs[0]))
  3503. assert_equal(p2, p[0])
  3504. # accept high-dimensional input and 2d input
  3505. rvs_nd = rvs.reshape((10, 100) + rvs.shape[1:])
  3506. p = pmf(rvs_nd)
  3507. assert p.shape == (10, 100)
  3508. for i in range(p.shape[0]):
  3509. for j in range(p.shape[1]):
  3510. pij = p[i, j]
  3511. rvij = rvs_nd[i, j]
  3512. qij = pmf(rvij)
  3513. assert_equal(pij, qij)
  3514. # probability is zero if column marginal does not match
  3515. x = [[0, 1, 1], [2, 1, 3]]
  3516. assert_equal(np.sum(x, axis=-1), row)
  3517. p = pmf(x)
  3518. assert p == 0
  3519. # probability is zero if row marginal does not match
  3520. x = [[0, 1, 2], [1, 2, 2]]
  3521. assert_equal(np.sum(x, axis=-2), col)
  3522. p = pmf(x)
  3523. assert p == 0
  3524. # response to invalid inputs
  3525. message = "`x` must be at least two-dimensional"
  3526. with pytest.raises(ValueError, match=message):
  3527. pmf([1])
  3528. message = "`x` must contain only integral values"
  3529. with pytest.raises(ValueError, match=message):
  3530. pmf([[1.1]])
  3531. message = "`x` must contain only integral values"
  3532. with pytest.raises(ValueError, match=message):
  3533. pmf([[np.nan]])
  3534. message = "`x` must contain only non-negative values"
  3535. with pytest.raises(ValueError, match=message):
  3536. pmf([[-1]])
  3537. message = "shape of `x` must agree with `row`"
  3538. with pytest.raises(ValueError, match=message):
  3539. pmf([[1, 2, 3]])
  3540. message = "shape of `x` must agree with `col`"
  3541. with pytest.raises(ValueError, match=message):
  3542. pmf([[1, 2],
  3543. [3, 4]])
  3544. @pytest.mark.parametrize("method", ("boyett", "patefield"))
  3545. def test_rvs_mean(self, method):
  3546. # test if `rvs` is unbiased and large sample size converges
  3547. # to the true mean.
  3548. rng = self.get_rng()
  3549. row = [2, 6]
  3550. col = [1, 3, 4]
  3551. rvs = random_table.rvs(row, col, size=1000, method=method,
  3552. random_state=rng)
  3553. mean = random_table.mean(row, col)
  3554. assert_equal(np.sum(mean), np.sum(row))
  3555. assert_allclose(rvs.mean(0), mean, atol=0.05)
  3556. assert_equal(rvs.sum(axis=-1), np.broadcast_to(row, (1000, 2)))
  3557. assert_equal(rvs.sum(axis=-2), np.broadcast_to(col, (1000, 3)))
  3558. def test_rvs_cov(self):
  3559. # test if `rvs` generated with patefield and boyett algorithms
  3560. # produce approximately the same covariance matrix
  3561. rng = self.get_rng()
  3562. row = [2, 6]
  3563. col = [1, 3, 4]
  3564. rvs1 = random_table.rvs(row, col, size=10000, method="boyett",
  3565. random_state=rng)
  3566. rvs2 = random_table.rvs(row, col, size=10000, method="patefield",
  3567. random_state=rng)
  3568. cov1 = np.var(rvs1, axis=0)
  3569. cov2 = np.var(rvs2, axis=0)
  3570. assert_allclose(cov1, cov2, atol=0.02)
  3571. @pytest.mark.parametrize("method", ("boyett", "patefield"))
  3572. def test_rvs_size(self, method):
  3573. row = [2, 6]
  3574. col = [1, 3, 4]
  3575. # test size `None`
  3576. rv = random_table.rvs(row, col, method=method,
  3577. random_state=self.get_rng())
  3578. assert rv.shape == (2, 3)
  3579. # test size 1
  3580. rv2 = random_table.rvs(row, col, size=1, method=method,
  3581. random_state=self.get_rng())
  3582. assert rv2.shape == (1, 2, 3)
  3583. assert_equal(rv, rv2[0])
  3584. # test size 0
  3585. rv3 = random_table.rvs(row, col, size=0, method=method,
  3586. random_state=self.get_rng())
  3587. assert rv3.shape == (0, 2, 3)
  3588. # test other valid size
  3589. rv4 = random_table.rvs(row, col, size=20, method=method,
  3590. random_state=self.get_rng())
  3591. assert rv4.shape == (20, 2, 3)
  3592. rv5 = random_table.rvs(row, col, size=(4, 5), method=method,
  3593. random_state=self.get_rng())
  3594. assert rv5.shape == (4, 5, 2, 3)
  3595. assert_allclose(rv5.reshape(20, 2, 3), rv4, rtol=1e-15)
  3596. # test invalid size
  3597. message = "`size` must be a non-negative integer or `None`"
  3598. with pytest.raises(ValueError, match=message):
  3599. random_table.rvs(row, col, size=-1, method=method,
  3600. random_state=self.get_rng())
  3601. with pytest.raises(ValueError, match=message):
  3602. random_table.rvs(row, col, size=np.nan, method=method,
  3603. random_state=self.get_rng())
  3604. @pytest.mark.parametrize("method", ("boyett", "patefield"))
  3605. def test_rvs_method(self, method):
  3606. # This test assumes that pmf is correct and checks that random samples
  3607. # follow this probability distribution. This seems like a circular
  3608. # argument, since pmf is checked in test_pmf_logpmf with random samples
  3609. # generated with the rvs method. This test is not redundant, because
  3610. # test_pmf_logpmf intentionally uses rvs generation with Boyett only,
  3611. # but here we test both Boyett and Patefield.
  3612. row = [2, 6]
  3613. col = [1, 3, 4]
  3614. ct = random_table
  3615. rvs = ct.rvs(row, col, size=100000, method=method,
  3616. random_state=self.get_rng())
  3617. unique_rvs, counts = np.unique(rvs, axis=0, return_counts=True)
  3618. # generated frequencies should match expected frequencies
  3619. p = ct.pmf(unique_rvs, row, col)
  3620. assert_allclose(p * len(rvs), counts, rtol=0.02)
  3621. @pytest.mark.parametrize("method", ("boyett", "patefield"))
  3622. def test_rvs_with_zeros_in_col_row(self, method):
  3623. row = [0, 1, 0]
  3624. col = [1, 0, 0, 0]
  3625. d = random_table(row, col)
  3626. rv = d.rvs(1000, method=method, random_state=self.get_rng())
  3627. expected = np.zeros((1000, len(row), len(col)))
  3628. expected[...] = [[0, 0, 0, 0],
  3629. [1, 0, 0, 0],
  3630. [0, 0, 0, 0]]
  3631. assert_equal(rv, expected)
  3632. @pytest.mark.parametrize("method", (None, "boyett", "patefield"))
  3633. @pytest.mark.parametrize("col", ([], [0]))
  3634. @pytest.mark.parametrize("row", ([], [0]))
  3635. def test_rvs_with_edge_cases(self, method, row, col):
  3636. d = random_table(row, col)
  3637. rv = d.rvs(10, method=method, random_state=self.get_rng())
  3638. expected = np.zeros((10, len(row), len(col)))
  3639. assert_equal(rv, expected)
  3640. @pytest.mark.parametrize('v', (1, 2))
  3641. def test_rvs_rcont(self, v):
  3642. # This test checks the internal low-level interface.
  3643. # It is implicitly also checked by the other test_rvs* calls.
  3644. import scipy.stats._rcont as _rcont
  3645. row = np.array([1, 3], dtype=np.int64)
  3646. col = np.array([2, 1, 1], dtype=np.int64)
  3647. rvs = getattr(_rcont, f"rvs_rcont{v}")
  3648. ntot = np.sum(row)
  3649. result = rvs(row, col, ntot, 1, self.get_rng())
  3650. assert result.shape == (1, len(row), len(col))
  3651. assert np.sum(result) == ntot
  3652. def test_frozen(self):
  3653. row = [2, 6]
  3654. col = [1, 3, 4]
  3655. d = random_table(row, col, seed=self.get_rng())
  3656. sample = d.rvs()
  3657. expected = random_table.mean(row, col)
  3658. assert_equal(expected, d.mean())
  3659. expected = random_table.pmf(sample, row, col)
  3660. assert_equal(expected, d.pmf(sample))
  3661. expected = random_table.logpmf(sample, row, col)
  3662. assert_equal(expected, d.logpmf(sample))
  3663. @pytest.mark.parametrize("method", ("boyett", "patefield"))
  3664. def test_rvs_frozen(self, method):
  3665. row = [2, 6]
  3666. col = [1, 3, 4]
  3667. d = random_table(row, col, seed=self.get_rng())
  3668. expected = random_table.rvs(row, col, size=10, method=method,
  3669. random_state=self.get_rng())
  3670. got = d.rvs(size=10, method=method)
  3671. assert_equal(expected, got)
  3672. def check_pickling(distfn, args):
  3673. # check that a distribution instance pickles and unpickles
  3674. # pay special attention to the random_state property
  3675. # save the random_state (restore later)
  3676. rndm = distfn.random_state
  3677. distfn.random_state = 1234
  3678. distfn.rvs(*args, size=8)
  3679. s = pickle.dumps(distfn)
  3680. r0 = distfn.rvs(*args, size=8)
  3681. unpickled = pickle.loads(s)
  3682. r1 = unpickled.rvs(*args, size=8)
  3683. assert_equal(r0, r1)
  3684. # restore the random_state
  3685. distfn.random_state = rndm
  3686. @pytest.mark.thread_unsafe(reason="uses numpy global random state and monkey-patching")
  3687. def test_random_state_property():
  3688. scale = np.eye(3)
  3689. scale[0, 1] = 0.5
  3690. scale[1, 0] = 0.5
  3691. dists = [
  3692. [multivariate_normal, ()],
  3693. [dirichlet, (np.array([1.]), )],
  3694. [wishart, (10, scale)],
  3695. [invwishart, (10, scale)],
  3696. [multinomial, (5, [0.5, 0.4, 0.1])],
  3697. [ortho_group, (2,)],
  3698. [special_ortho_group, (2,)]
  3699. ]
  3700. for distfn, args in dists:
  3701. check_random_state_property(distfn, args)
  3702. check_pickling(distfn, args)
  3703. class TestVonMises_Fisher:
  3704. @pytest.mark.parametrize("dim", [2, 3, 4, 6])
  3705. @pytest.mark.parametrize("size", [None, 1, 5, (5, 4)])
  3706. def test_samples(self, dim, size):
  3707. # test that samples have correct shape and norm 1
  3708. rng = np.random.default_rng(2777937887058094419)
  3709. mu = np.full((dim, ), 1/np.sqrt(dim))
  3710. vmf_dist = vonmises_fisher(mu, 1, seed=rng)
  3711. samples = vmf_dist.rvs(size)
  3712. mean, cov = np.zeros(dim), np.eye(dim)
  3713. expected_shape = rng.multivariate_normal(mean, cov, size=size).shape
  3714. assert samples.shape == expected_shape
  3715. norms = np.linalg.norm(samples, axis=-1)
  3716. assert_allclose(norms, 1.)
  3717. @pytest.mark.parametrize("dim", [5, 8])
  3718. @pytest.mark.parametrize("kappa", [1e15, 1e20, 1e30])
  3719. def test_sampling_high_concentration(self, dim, kappa):
  3720. # test that no warnings are encountered for high values
  3721. rng = np.random.default_rng(2777937887058094419)
  3722. mu = np.full((dim, ), 1/np.sqrt(dim))
  3723. vmf_dist = vonmises_fisher(mu, kappa, seed=rng)
  3724. vmf_dist.rvs(10)
  3725. def test_two_dimensional_mu(self):
  3726. mu = np.ones((2, 2))
  3727. msg = "'mu' must have one-dimensional shape."
  3728. with pytest.raises(ValueError, match=msg):
  3729. vonmises_fisher(mu, 1)
  3730. def test_wrong_norm_mu(self):
  3731. mu = np.ones((2, ))
  3732. msg = "'mu' must be a unit vector of norm 1."
  3733. with pytest.raises(ValueError, match=msg):
  3734. vonmises_fisher(mu, 1)
  3735. def test_one_entry_mu(self):
  3736. mu = np.ones((1, ))
  3737. msg = "'mu' must have at least two entries."
  3738. with pytest.raises(ValueError, match=msg):
  3739. vonmises_fisher(mu, 1)
  3740. @pytest.mark.parametrize("kappa", [-1, (5, 3)])
  3741. def test_kappa_validation(self, kappa):
  3742. msg = "'kappa' must be a positive scalar."
  3743. with pytest.raises(ValueError, match=msg):
  3744. vonmises_fisher([1, 0], kappa)
  3745. @pytest.mark.parametrize("kappa", [0, 0.])
  3746. def test_kappa_zero(self, kappa):
  3747. msg = ("For 'kappa=0' the von Mises-Fisher distribution "
  3748. "becomes the uniform distribution on the sphere "
  3749. "surface. Consider using 'scipy.stats.uniform_direction' "
  3750. "instead.")
  3751. with pytest.raises(ValueError, match=msg):
  3752. vonmises_fisher([1, 0], kappa)
  3753. @pytest.mark.parametrize("method", [vonmises_fisher.pdf,
  3754. vonmises_fisher.logpdf])
  3755. def test_invalid_shapes_pdf_logpdf(self, method):
  3756. x = np.array([1., 0., 0])
  3757. msg = ("The dimensionality of the last axis of 'x' must "
  3758. "match the dimensionality of the von Mises Fisher "
  3759. "distribution.")
  3760. with pytest.raises(ValueError, match=msg):
  3761. method(x, [1, 0], 1)
  3762. @pytest.mark.parametrize("method", [vonmises_fisher.pdf,
  3763. vonmises_fisher.logpdf])
  3764. def test_unnormalized_input(self, method):
  3765. x = np.array([0.5, 0.])
  3766. msg = "'x' must be unit vectors of norm 1 along last dimension."
  3767. with pytest.raises(ValueError, match=msg):
  3768. method(x, [1, 0], 1)
  3769. # Expected values of the vonmises-fisher logPDF were computed via mpmath
  3770. # from mpmath import mp
  3771. # import numpy as np
  3772. # mp.dps = 50
  3773. # def logpdf_mpmath(x, mu, kappa):
  3774. # dim = mu.size
  3775. # halfdim = mp.mpf(0.5 * dim)
  3776. # kappa = mp.mpf(kappa)
  3777. # const = (kappa**(halfdim - mp.one)/((2*mp.pi)**halfdim * \
  3778. # mp.besseli(halfdim -mp.one, kappa)))
  3779. # return float(const * mp.exp(kappa*mp.fdot(x, mu)))
  3780. @pytest.mark.parametrize('x, mu, kappa, reference',
  3781. [(np.array([1., 0., 0.]), np.array([1., 0., 0.]),
  3782. 1e-4, 0.0795854295583605),
  3783. (np.array([1., 0., 0]), np.array([0., 0., 1.]),
  3784. 1e-4, 0.07957747141331854),
  3785. (np.array([1., 0., 0.]), np.array([1., 0., 0.]),
  3786. 100, 15.915494309189533),
  3787. (np.array([1., 0., 0]), np.array([0., 0., 1.]),
  3788. 100, 5.920684802611232e-43),
  3789. (np.array([1., 0., 0.]),
  3790. np.array([np.sqrt(0.98), np.sqrt(0.02), 0.]),
  3791. 2000, 5.930499050746588e-07),
  3792. (np.array([1., 0., 0]), np.array([1., 0., 0.]),
  3793. 2000, 318.3098861837907),
  3794. (np.array([1., 0., 0., 0., 0.]),
  3795. np.array([1., 0., 0., 0., 0.]),
  3796. 2000, 101371.86957712633),
  3797. (np.array([1., 0., 0., 0., 0.]),
  3798. np.array([np.sqrt(0.98), np.sqrt(0.02), 0.,
  3799. 0, 0.]),
  3800. 2000, 0.00018886808182653578),
  3801. (np.array([1., 0., 0., 0., 0.]),
  3802. np.array([np.sqrt(0.8), np.sqrt(0.2), 0.,
  3803. 0, 0.]),
  3804. 2000, 2.0255393314603194e-87)])
  3805. def test_pdf_accuracy(self, x, mu, kappa, reference):
  3806. pdf = vonmises_fisher(mu, kappa).pdf(x)
  3807. assert_allclose(pdf, reference, rtol=1e-13)
  3808. # Expected values of the vonmises-fisher logPDF were computed via mpmath
  3809. # from mpmath import mp
  3810. # import numpy as np
  3811. # mp.dps = 50
  3812. # def logpdf_mpmath(x, mu, kappa):
  3813. # dim = mu.size
  3814. # halfdim = mp.mpf(0.5 * dim)
  3815. # kappa = mp.mpf(kappa)
  3816. # two = mp.mpf(2.)
  3817. # const = (kappa**(halfdim - mp.one)/((two*mp.pi)**halfdim * \
  3818. # mp.besseli(halfdim - mp.one, kappa)))
  3819. # return float(mp.log(const * mp.exp(kappa*mp.fdot(x, mu))))
  3820. @pytest.mark.parametrize('x, mu, kappa, reference',
  3821. [(np.array([1., 0., 0.]), np.array([1., 0., 0.]),
  3822. 1e-4, -2.5309242486359573),
  3823. (np.array([1., 0., 0]), np.array([0., 0., 1.]),
  3824. 1e-4, -2.5310242486359575),
  3825. (np.array([1., 0., 0.]), np.array([1., 0., 0.]),
  3826. 100, 2.767293119578746),
  3827. (np.array([1., 0., 0]), np.array([0., 0., 1.]),
  3828. 100, -97.23270688042125),
  3829. (np.array([1., 0., 0.]),
  3830. np.array([np.sqrt(0.98), np.sqrt(0.02), 0.]),
  3831. 2000, -14.337987284534103),
  3832. (np.array([1., 0., 0]), np.array([1., 0., 0.]),
  3833. 2000, 5.763025393132737),
  3834. (np.array([1., 0., 0., 0., 0.]),
  3835. np.array([1., 0., 0., 0., 0.]),
  3836. 2000, 11.526550911307156),
  3837. (np.array([1., 0., 0., 0., 0.]),
  3838. np.array([np.sqrt(0.98), np.sqrt(0.02), 0.,
  3839. 0, 0.]),
  3840. 2000, -8.574461766359684),
  3841. (np.array([1., 0., 0., 0., 0.]),
  3842. np.array([np.sqrt(0.8), np.sqrt(0.2), 0.,
  3843. 0, 0.]),
  3844. 2000, -199.61906708886113)])
  3845. def test_logpdf_accuracy(self, x, mu, kappa, reference):
  3846. logpdf = vonmises_fisher(mu, kappa).logpdf(x)
  3847. assert_allclose(logpdf, reference, rtol=1e-14)
  3848. # Expected values of the vonmises-fisher entropy were computed via mpmath
  3849. # from mpmath import mp
  3850. # import numpy as np
  3851. # mp.dps = 50
  3852. # def entropy_mpmath(dim, kappa):
  3853. # mu = np.full((dim, ), 1/np.sqrt(dim))
  3854. # kappa = mp.mpf(kappa)
  3855. # halfdim = mp.mpf(0.5 * dim)
  3856. # logconstant = (mp.log(kappa**(halfdim - mp.one)
  3857. # /((2*mp.pi)**halfdim
  3858. # * mp.besseli(halfdim -mp.one, kappa)))
  3859. # return float(-logconstant - kappa * mp.besseli(halfdim, kappa)/
  3860. # mp.besseli(halfdim -1, kappa))
  3861. @pytest.mark.parametrize('dim, kappa, reference',
  3862. [(3, 1e-4, 2.531024245302624),
  3863. (3, 100, -1.7672931195787458),
  3864. (5, 5000, -11.359032310024453),
  3865. (8, 1, 3.4189526482545527)])
  3866. def test_entropy_accuracy(self, dim, kappa, reference):
  3867. mu = np.full((dim, ), 1/np.sqrt(dim))
  3868. entropy = vonmises_fisher(mu, kappa).entropy()
  3869. assert_allclose(entropy, reference, rtol=2e-14)
  3870. @pytest.mark.parametrize("method", [vonmises_fisher.pdf,
  3871. vonmises_fisher.logpdf])
  3872. def test_broadcasting(self, method):
  3873. # test that pdf and logpdf values are correctly broadcasted
  3874. testshape = (2, 2)
  3875. rng = np.random.default_rng(2777937887058094419)
  3876. x = uniform_direction(3).rvs(testshape, random_state=rng)
  3877. mu = np.full((3, ), 1/np.sqrt(3))
  3878. kappa = 5
  3879. result_all = method(x, mu, kappa)
  3880. assert result_all.shape == testshape
  3881. for i in range(testshape[0]):
  3882. for j in range(testshape[1]):
  3883. current_val = method(x[i, j, :], mu, kappa)
  3884. assert_allclose(current_val, result_all[i, j], rtol=1e-15)
  3885. def test_vs_vonmises_2d(self):
  3886. # test that in 2D, von Mises-Fisher yields the same results
  3887. # as the von Mises distribution
  3888. rng = np.random.default_rng(2777937887058094419)
  3889. mu = np.array([0, 1])
  3890. mu_angle = np.arctan2(mu[1], mu[0])
  3891. kappa = 20
  3892. vmf = vonmises_fisher(mu, kappa)
  3893. vonmises_dist = vonmises(loc=mu_angle, kappa=kappa)
  3894. vectors = uniform_direction(2).rvs(10, random_state=rng)
  3895. angles = np.arctan2(vectors[:, 1], vectors[:, 0])
  3896. assert_allclose(vonmises_dist.entropy(), vmf.entropy())
  3897. assert_allclose(vonmises_dist.pdf(angles), vmf.pdf(vectors))
  3898. assert_allclose(vonmises_dist.logpdf(angles), vmf.logpdf(vectors))
  3899. @pytest.mark.parametrize("dim", [2, 3, 6])
  3900. @pytest.mark.parametrize("kappa, mu_tol, kappa_tol",
  3901. [(1, 5e-2, 5e-2),
  3902. (10, 1e-2, 1e-2),
  3903. (100, 5e-3, 2e-2),
  3904. (1000, 1e-3, 2e-2)])
  3905. def test_fit_accuracy(self, dim, kappa, mu_tol, kappa_tol):
  3906. mu = np.full((dim, ), 1/np.sqrt(dim))
  3907. vmf_dist = vonmises_fisher(mu, kappa)
  3908. rng = np.random.default_rng(2777937887058094419)
  3909. n_samples = 10000
  3910. samples = vmf_dist.rvs(n_samples, random_state=rng)
  3911. mu_fit, kappa_fit = vonmises_fisher.fit(samples)
  3912. angular_error = np.arccos(mu.dot(mu_fit))
  3913. assert_allclose(angular_error, 0., atol=mu_tol, rtol=0)
  3914. assert_allclose(kappa, kappa_fit, rtol=kappa_tol)
  3915. def test_fit_error_one_dimensional_data(self):
  3916. x = np.zeros((3, ))
  3917. msg = "'x' must be two dimensional."
  3918. with pytest.raises(ValueError, match=msg):
  3919. vonmises_fisher.fit(x)
  3920. def test_fit_error_unnormalized_data(self):
  3921. x = np.ones((3, 3))
  3922. msg = "'x' must be unit vectors of norm 1 along last dimension."
  3923. with pytest.raises(ValueError, match=msg):
  3924. vonmises_fisher.fit(x)
  3925. def test_frozen_distribution(self):
  3926. mu = np.array([0, 0, 1])
  3927. kappa = 5
  3928. frozen = vonmises_fisher(mu, kappa)
  3929. frozen_seed = vonmises_fisher(mu, kappa, seed=514)
  3930. rvs1 = frozen.rvs(random_state=514)
  3931. rvs2 = vonmises_fisher.rvs(mu, kappa, random_state=514)
  3932. rvs3 = frozen_seed.rvs()
  3933. assert_equal(rvs1, rvs2)
  3934. assert_equal(rvs1, rvs3)
  3935. class TestDirichletMultinomial:
  3936. @classmethod
  3937. def get_params(self, m):
  3938. rng = np.random.default_rng(28469824356873456)
  3939. alpha = rng.uniform(0, 100, size=2)
  3940. x = rng.integers(1, 20, size=(m, 2))
  3941. n = x.sum(axis=-1)
  3942. return rng, m, alpha, n, x
  3943. def test_frozen(self):
  3944. rng = np.random.default_rng(28469824356873456)
  3945. alpha = rng.uniform(0, 100, 10)
  3946. x = rng.integers(0, 10, 10)
  3947. n = np.sum(x, axis=-1)
  3948. d = dirichlet_multinomial(alpha, n)
  3949. assert_equal(d.logpmf(x), dirichlet_multinomial.logpmf(x, alpha, n))
  3950. assert_equal(d.pmf(x), dirichlet_multinomial.pmf(x, alpha, n))
  3951. assert_equal(d.mean(), dirichlet_multinomial.mean(alpha, n))
  3952. assert_equal(d.var(), dirichlet_multinomial.var(alpha, n))
  3953. assert_equal(d.cov(), dirichlet_multinomial.cov(alpha, n))
  3954. def test_pmf_logpmf_against_R(self):
  3955. # # Compare PMF against R's extraDistr ddirmnon
  3956. # # library(extraDistr)
  3957. # # options(digits=16)
  3958. # ddirmnom(c(1, 2, 3), 6, c(3, 4, 5))
  3959. x = np.array([1, 2, 3])
  3960. n = np.sum(x)
  3961. alpha = np.array([3, 4, 5])
  3962. res = dirichlet_multinomial.pmf(x, alpha, n)
  3963. logres = dirichlet_multinomial.logpmf(x, alpha, n)
  3964. ref = 0.08484162895927638
  3965. assert_allclose(res, ref)
  3966. assert_allclose(logres, np.log(ref))
  3967. assert res.shape == logres.shape == ()
  3968. # library(extraDistr)
  3969. # options(digits=16)
  3970. # ddirmnom(c(4, 3, 2, 0, 2, 3, 5, 7, 4, 7), 37,
  3971. # c(45.01025314, 21.98739582, 15.14851365, 80.21588671,
  3972. # 52.84935481, 25.20905262, 53.85373737, 4.88568118,
  3973. # 89.06440654, 20.11359466))
  3974. rng = np.random.default_rng(28469824356873456)
  3975. alpha = rng.uniform(0, 100, 10)
  3976. x = rng.integers(0, 10, 10)
  3977. n = np.sum(x, axis=-1)
  3978. res = dirichlet_multinomial(alpha, n).pmf(x)
  3979. logres = dirichlet_multinomial.logpmf(x, alpha, n)
  3980. ref = 3.65409306285992e-16
  3981. assert_allclose(res, ref)
  3982. assert_allclose(logres, np.log(ref))
  3983. def test_pmf_logpmf_support(self):
  3984. # when the sum of the category counts does not equal the number of
  3985. # trials, the PMF is zero
  3986. rng, m, alpha, n, x = self.get_params(1)
  3987. n += 1
  3988. assert_equal(dirichlet_multinomial(alpha, n).pmf(x), 0)
  3989. assert_equal(dirichlet_multinomial(alpha, n).logpmf(x), -np.inf)
  3990. rng, m, alpha, n, x = self.get_params(10)
  3991. i = rng.random(size=10) > 0.5
  3992. x[i] = np.round(x[i] * 2) # sum of these x does not equal n
  3993. assert_equal(dirichlet_multinomial(alpha, n).pmf(x)[i], 0)
  3994. assert_equal(dirichlet_multinomial(alpha, n).logpmf(x)[i], -np.inf)
  3995. assert np.all(dirichlet_multinomial(alpha, n).pmf(x)[~i] > 0)
  3996. assert np.all(dirichlet_multinomial(alpha, n).logpmf(x)[~i] > -np.inf)
  3997. def test_dimensionality_one(self):
  3998. # if the dimensionality is one, there is only one possible outcome
  3999. n = 6 # number of trials
  4000. alpha = [10] # concentration parameters
  4001. x = np.asarray([n]) # counts
  4002. dist = dirichlet_multinomial(alpha, n)
  4003. assert_equal(dist.pmf(x), 1)
  4004. assert_equal(dist.pmf(x+1), 0)
  4005. assert_equal(dist.logpmf(x), 0)
  4006. assert_equal(dist.logpmf(x+1), -np.inf)
  4007. assert_equal(dist.mean(), n)
  4008. assert_equal(dist.var(), 0)
  4009. assert_equal(dist.cov(), 0)
  4010. def test_n_is_zero(self):
  4011. # similarly, only one possible outcome if n is zero
  4012. n = 0
  4013. alpha = np.asarray([1., 1.])
  4014. x = np.asarray([0, 0])
  4015. dist = dirichlet_multinomial(alpha, n)
  4016. assert_equal(dist.pmf(x), 1)
  4017. assert_equal(dist.pmf(x+1), 0)
  4018. assert_equal(dist.logpmf(x), 0)
  4019. assert_equal(dist.logpmf(x+1), -np.inf)
  4020. assert_equal(dist.mean(), [0, 0])
  4021. assert_equal(dist.var(), [0, 0])
  4022. assert_equal(dist.cov(), [[0, 0], [0, 0]])
  4023. @pytest.mark.parametrize('method_name', ['pmf', 'logpmf'])
  4024. def test_against_betabinom_pmf(self, method_name):
  4025. rng, m, alpha, n, x = self.get_params(100)
  4026. method = getattr(dirichlet_multinomial(alpha, n), method_name)
  4027. ref_method = getattr(stats.betabinom(n, *alpha.T), method_name)
  4028. res = method(x)
  4029. ref = ref_method(x.T[0])
  4030. assert_allclose(res, ref)
  4031. @pytest.mark.parametrize('method_name', ['mean', 'var'])
  4032. def test_against_betabinom_moments(self, method_name):
  4033. rng, m, alpha, n, x = self.get_params(100)
  4034. method = getattr(dirichlet_multinomial(alpha, n), method_name)
  4035. ref_method = getattr(stats.betabinom(n, *alpha.T), method_name)
  4036. res = method()[:, 0]
  4037. ref = ref_method()
  4038. assert_allclose(res, ref)
  4039. def test_moments(self):
  4040. rng = np.random.default_rng(28469824356873456)
  4041. dim = 5
  4042. n = rng.integers(1, 100)
  4043. alpha = rng.random(size=dim) * 10
  4044. dist = dirichlet_multinomial(alpha, n)
  4045. # Generate a random sample from the distribution using NumPy
  4046. m = 100000
  4047. p = rng.dirichlet(alpha, size=m)
  4048. x = rng.multinomial(n, p, size=m)
  4049. assert_allclose(dist.mean(), np.mean(x, axis=0), rtol=5e-3)
  4050. assert_allclose(dist.var(), np.var(x, axis=0), rtol=1e-2)
  4051. assert dist.mean().shape == dist.var().shape == (dim,)
  4052. cov = dist.cov()
  4053. assert cov.shape == (dim, dim)
  4054. assert_allclose(cov, np.cov(x.T), rtol=2e-2)
  4055. assert_equal(np.diag(cov), dist.var())
  4056. assert np.all(scipy.linalg.eigh(cov)[0] > 0) # positive definite
  4057. def test_input_validation(self):
  4058. # valid inputs
  4059. x0 = np.array([1, 2, 3])
  4060. n0 = np.sum(x0)
  4061. alpha0 = np.array([3, 4, 5])
  4062. text = "`x` must contain only non-negative integers."
  4063. with assert_raises(ValueError, match=text):
  4064. dirichlet_multinomial.logpmf([1, -1, 3], alpha0, n0)
  4065. with assert_raises(ValueError, match=text):
  4066. dirichlet_multinomial.logpmf([1, 2.1, 3], alpha0, n0)
  4067. text = "`alpha` must contain only positive values."
  4068. with assert_raises(ValueError, match=text):
  4069. dirichlet_multinomial.logpmf(x0, [3, 0, 4], n0)
  4070. with assert_raises(ValueError, match=text):
  4071. dirichlet_multinomial.logpmf(x0, [3, -1, 4], n0)
  4072. text = "`n` must be a non-negative integer."
  4073. with assert_raises(ValueError, match=text):
  4074. dirichlet_multinomial.logpmf(x0, alpha0, 49.1)
  4075. with assert_raises(ValueError, match=text):
  4076. dirichlet_multinomial.logpmf(x0, alpha0, -1)
  4077. x = np.array([1, 2, 3, 4])
  4078. alpha = np.array([3, 4, 5])
  4079. text = "`x` and `alpha` must be broadcastable."
  4080. with assert_raises(ValueError, match=text):
  4081. dirichlet_multinomial.logpmf(x, alpha, x.sum())
  4082. @pytest.mark.parametrize('method', ['pmf', 'logpmf'])
  4083. def test_broadcasting_pmf(self, method):
  4084. alpha = np.array([[3, 4, 5], [4, 5, 6], [5, 5, 7], [8, 9, 10]])
  4085. n = np.array([[6], [7], [8]])
  4086. x = np.array([[1, 2, 3], [2, 2, 3]]).reshape((2, 1, 1, 3))
  4087. method = getattr(dirichlet_multinomial, method)
  4088. res = method(x, alpha, n)
  4089. assert res.shape == (2, 3, 4)
  4090. for i in range(len(x)):
  4091. for j in range(len(n)):
  4092. for k in range(len(alpha)):
  4093. res_ijk = res[i, j, k]
  4094. ref = method(x[i].squeeze(), alpha[k].squeeze(), n[j].squeeze())
  4095. assert_allclose(res_ijk, ref)
  4096. @pytest.mark.parametrize('method_name', ['mean', 'var', 'cov'])
  4097. def test_broadcasting_moments(self, method_name):
  4098. alpha = np.array([[3, 4, 5], [4, 5, 6], [5, 5, 7], [8, 9, 10]])
  4099. n = np.array([[6], [7], [8]])
  4100. method = getattr(dirichlet_multinomial, method_name)
  4101. res = method(alpha, n)
  4102. assert res.shape == (3, 4, 3) if method_name != 'cov' else (3, 4, 3, 3)
  4103. for j in range(len(n)):
  4104. for k in range(len(alpha)):
  4105. res_ijk = res[j, k]
  4106. ref = method(alpha[k].squeeze(), n[j].squeeze())
  4107. assert_allclose(res_ijk, ref)
  4108. class TestNormalInverseGamma:
  4109. def test_marginal_x(self):
  4110. # According to [1], sqrt(a * lmbda / b) * (x - u) should follow a t-distribution
  4111. # with 2*a degrees of freedom. Test that this is true of the PDF and random
  4112. # variates.
  4113. rng = np.random.default_rng(8925849245)
  4114. mu, lmbda, a, b = rng.random(4)
  4115. norm_inv_gamma = stats.normal_inverse_gamma(mu, lmbda, a, b)
  4116. t = stats.t(2*a, loc=mu, scale=1/np.sqrt(a * lmbda / b))
  4117. # Test PDF
  4118. x = np.linspace(-5, 5, 11)
  4119. res = tanhsinh(lambda s2, x: norm_inv_gamma.pdf(x, s2), 0, np.inf, args=(x,))
  4120. ref = t.pdf(x)
  4121. assert_allclose(res.integral, ref)
  4122. # Test RVS
  4123. res = norm_inv_gamma.rvs(size=10000, random_state=rng)
  4124. _, pvalue = stats.ks_1samp(res[0], t.cdf)
  4125. assert pvalue > 0.1
  4126. def test_marginal_s2(self):
  4127. # According to [1], s2 should follow an inverse gamma distribution with
  4128. # shapes a, b (where b is the scale in our parameterization). Test that
  4129. # this is true of the PDF and random variates.
  4130. rng = np.random.default_rng(8925849245)
  4131. mu, lmbda, a, b = rng.random(4)
  4132. norm_inv_gamma = stats.normal_inverse_gamma(mu, lmbda, a, b)
  4133. inv_gamma = stats.invgamma(a, scale=b)
  4134. # Test PDF
  4135. s2 = np.linspace(0.1, 10, 10)
  4136. res = tanhsinh(lambda x, s2: norm_inv_gamma.pdf(x, s2),
  4137. -np.inf, np.inf, args=(s2,))
  4138. ref = inv_gamma.pdf(s2)
  4139. assert_allclose(res.integral, ref)
  4140. # Test RVS
  4141. res = norm_inv_gamma.rvs(size=10000, random_state=rng)
  4142. _, pvalue = stats.ks_1samp(res[1], inv_gamma.cdf)
  4143. assert pvalue > 0.1
  4144. def test_pdf_logpdf(self):
  4145. # Check that PDF and log-PDF are consistent
  4146. rng = np.random.default_rng(8925849245)
  4147. mu, lmbda, a, b = rng.random((4, 20)) - 0.25 # make some invalid
  4148. x, s2 = rng.random(size=(2, 20)) - 0.25
  4149. res = stats.normal_inverse_gamma(mu, lmbda, a, b).pdf(x, s2)
  4150. ref = stats.normal_inverse_gamma.logpdf(x, s2, mu, lmbda, a, b)
  4151. assert_allclose(res, np.exp(ref))
  4152. def test_invalid_and_special_cases(self):
  4153. # Test cases that are handled by input validation rather than the formulas
  4154. rng = np.random.default_rng(8925849245)
  4155. mu, lmbda, a, b = rng.random(4)
  4156. x, s2 = rng.random(2)
  4157. res = stats.normal_inverse_gamma(np.nan, lmbda, a, b).pdf(x, s2)
  4158. assert_equal(res, np.nan)
  4159. res = stats.normal_inverse_gamma(mu, -1, a, b).pdf(x, s2)
  4160. assert_equal(res, np.nan)
  4161. res = stats.normal_inverse_gamma(mu, lmbda, 0, b).pdf(x, s2)
  4162. assert_equal(res, np.nan)
  4163. res = stats.normal_inverse_gamma(mu, lmbda, a, -1).pdf(x, s2)
  4164. assert_equal(res, np.nan)
  4165. res = stats.normal_inverse_gamma(mu, lmbda, a, b).pdf(x, -1)
  4166. assert_equal(res, 0)
  4167. # PDF with out-of-support s2 is not zero if shape parameter is invalid
  4168. res = stats.normal_inverse_gamma(mu, [-1, np.nan], a, b).pdf(x, -1)
  4169. assert_equal(res, np.nan)
  4170. res = stats.normal_inverse_gamma(mu, -1, a, b).mean()
  4171. assert_equal(res, (np.nan, np.nan))
  4172. res = stats.normal_inverse_gamma(mu, lmbda, -1, b).var()
  4173. assert_equal(res, (np.nan, np.nan))
  4174. with pytest.raises(ValueError, match="Domain error in arguments..."):
  4175. stats.normal_inverse_gamma(mu, lmbda, a, -1).rvs()
  4176. def test_broadcasting(self):
  4177. # Test methods with broadcastable array parameters. Roughly speaking, the
  4178. # shapes should be the broadcasted shapes of all arguments, and the raveled
  4179. # outputs should be the same as the outputs with raveled inputs.
  4180. rng = np.random.default_rng(8925849245)
  4181. b = rng.random(2)
  4182. a = rng.random((3, 1)) + 2 # for defined moments
  4183. lmbda = rng.random((4, 1, 1))
  4184. mu = rng.random((5, 1, 1, 1))
  4185. s2 = rng.random((6, 1, 1, 1, 1))
  4186. x = rng.random((7, 1, 1, 1, 1, 1))
  4187. dist = stats.normal_inverse_gamma(mu, lmbda, a, b)
  4188. # Test PDF and log-PDF
  4189. broadcasted = np.broadcast_arrays(x, s2, mu, lmbda, a, b)
  4190. broadcasted_raveled = [np.ravel(arr) for arr in broadcasted]
  4191. res = dist.pdf(x, s2)
  4192. assert res.shape == broadcasted[0].shape
  4193. assert_allclose(res.ravel(),
  4194. stats.normal_inverse_gamma.pdf(*broadcasted_raveled))
  4195. res = dist.logpdf(x, s2)
  4196. assert res.shape == broadcasted[0].shape
  4197. assert_allclose(res.ravel(),
  4198. stats.normal_inverse_gamma.logpdf(*broadcasted_raveled))
  4199. # Test moments
  4200. broadcasted = np.broadcast_arrays(mu, lmbda, a, b)
  4201. broadcasted_raveled = [np.ravel(arr) for arr in broadcasted]
  4202. res = dist.mean()
  4203. assert res[0].shape == broadcasted[0].shape
  4204. assert_allclose((res[0].ravel(), res[1].ravel()),
  4205. stats.normal_inverse_gamma.mean(*broadcasted_raveled))
  4206. res = dist.var()
  4207. assert res[0].shape == broadcasted[0].shape
  4208. assert_allclose((res[0].ravel(), res[1].ravel()),
  4209. stats.normal_inverse_gamma.var(*broadcasted_raveled))
  4210. # Test RVS
  4211. size = (6, 5, 4, 3, 2)
  4212. rng = np.random.default_rng(2348923985324)
  4213. res = dist.rvs(size=size, random_state=rng)
  4214. rng = np.random.default_rng(2348923985324)
  4215. shape = 6, 5*4*3*2
  4216. ref = stats.normal_inverse_gamma.rvs(*broadcasted_raveled, size=shape,
  4217. random_state=rng)
  4218. assert_allclose((res[0].reshape(shape), res[1].reshape(shape)), ref)
  4219. @pytest.mark.slow
  4220. @pytest.mark.fail_slow(10)
  4221. def test_moments(self):
  4222. # Test moments against quadrature
  4223. rng = np.random.default_rng(8925849245)
  4224. mu, lmbda, a, b = rng.random(4)
  4225. a += 2 # ensure defined
  4226. dist = stats.normal_inverse_gamma(mu, lmbda, a, b)
  4227. res = dist.mean()
  4228. ref = dblquad(lambda s2, x: dist.pdf(x, s2) * x, -np.inf, np.inf, 0, np.inf)
  4229. assert_allclose(res[0], ref[0], rtol=1e-6)
  4230. ref = dblquad(lambda s2, x: dist.pdf(x, s2) * s2, -np.inf, np.inf, 0, np.inf)
  4231. assert_allclose(res[1], ref[0], rtol=1e-6)
  4232. @pytest.mark.parametrize('dtype', [np.int32, np.float16, np.float32, np.float64])
  4233. def test_dtype(self, dtype):
  4234. if np.__version__ < "2":
  4235. pytest.skip("Scalar dtypes only respected after NEP 50.")
  4236. rng = np.random.default_rng(8925849245)
  4237. x, s2, mu, lmbda, a, b = rng.uniform(3, 10, size=6).astype(dtype)
  4238. dtype_out = np.result_type(1.0, dtype)
  4239. dist = stats.normal_inverse_gamma(mu, lmbda, a, b)
  4240. assert dist.rvs()[0].dtype == dtype_out
  4241. assert dist.rvs()[1].dtype == dtype_out
  4242. assert dist.mean()[0].dtype == dtype_out
  4243. assert dist.mean()[1].dtype == dtype_out
  4244. assert dist.var()[0].dtype == dtype_out
  4245. assert dist.var()[1].dtype == dtype_out
  4246. assert dist.logpdf(x, s2).dtype == dtype_out
  4247. assert dist.pdf(x, s2).dtype == dtype_out