| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762
77727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067306830693070307130723073307430753076307730783079308030813082308330843085308630873088308930903091309230933094309530963097309830993100310131023103310431053106310731083109311031113112311331143115311631173118311931203121312231233124312531263127312831293130313131323133313431353136313731383139314031413142314331443145314631473148314931503151315231533154315531563157315831593160316131623163316431653166316731683169317031713172317331743175317631773178317931803181318231833184318531863187318831893190319131923193319431953196319731983199320032013202320332043205320632073208320932103211321232133214321532163217321832193220322132223223322432253226322732283229323032313232323332343235323632373238323932403241324232433244324532463247324832493250325132523253325432553256325732583259326032613262326332643265326632673268326932703271327232733274327532763
27732783279328032813282328332843285328632873288328932903291329232933294329532963297329832993300330133023303330433053306330733083309331033113312331333143315331633173318331933203321332233233324332533263327332833293330333133323333333433353336333733383339334033413342334333443345334633473348334933503351335233533354335533563357335833593360336133623363336433653366336733683369337033713372337333743375337633773378337933803381338233833384338533863387338833893390339133923393339433953396339733983399340034013402340334043405340634073408340934103411341234133414341534163417341834193420342134223423342434253426342734283429343034313432343334343435343634373438343934403441344234433444344534463447344834493450345134523453345434553456345734583459346034613462346334643465346634673468346934703471347234733474347534763477347834793480348134823483348434853486348734883489349034913492349334943495349634973498349935003501350235033504350535063507350835093510351135123513351435153516351735183519352035213522352335243525352635273528352935303531353235333534353535363537353835393540354135423543354435453546354735483549355035513552355335543555355635573558355935603561356235633564356535663567356835693570357135723573357435753576357735783579358035813582358335843585358635873588358935903591359235933594359535963597359835993600360136023603360436053606360736083609361036113612361336143615361636173618361936203621362236233624362536263627362836293630363136323633363436353636363736383639364036413642364336443645364636473648364936503651365236533654365536563657365836593660366136623663366436653666366736683669367036713672367336743675367636773678367936803681368236833684368536863687368836893690369136923693369436953696369736983699370037013702370337043705370637073708370937103711371237133714371537163717371837193720372137223723372437253726372737283729373037313732373337343735373637373738373937403741374237433744374537463747374837493750375137523753375437553756375737583759376037613762376337643765376637673768376937703771377237733774377537763
77737783779378037813782378337843785378637873788378937903791379237933794379537963797379837993800380138023803380438053806380738083809381038113812381338143815381638173818381938203821382238233824382538263827382838293830383138323833383438353836383738383839384038413842384338443845384638473848384938503851385238533854385538563857385838593860386138623863386438653866386738683869387038713872387338743875387638773878387938803881388238833884388538863887388838893890389138923893389438953896389738983899390039013902390339043905390639073908390939103911391239133914391539163917391839193920392139223923392439253926392739283929393039313932393339343935393639373938393939403941394239433944394539463947394839493950395139523953395439553956395739583959396039613962396339643965396639673968396939703971397239733974397539763977397839793980398139823983398439853986398739883989399039913992399339943995399639973998399940004001400240034004400540064007400840094010401140124013401440154016401740184019402040214022402340244025402640274028402940304031403240334034403540364037403840394040404140424043404440454046404740484049405040514052405340544055405640574058405940604061406240634064406540664067406840694070407140724073407440754076407740784079408040814082408340844085408640874088408940904091409240934094409540964097409840994100410141024103410441054106410741084109411041114112411341144115411641174118411941204121412241234124412541264127412841294130413141324133413441354136413741384139414041414142414341444145414641474148414941504151415241534154415541564157415841594160416141624163416441654166416741684169417041714172417341744175417641774178417941804181418241834184418541864187418841894190419141924193419441954196419741984199420042014202420342044205420642074208420942104211421242134214421542164217421842194220422142224223422442254226422742284229423042314232423342344235423642374238423942404241424242434244424542464247424842494250425142524253425442554256425742584259426042614262426342644265426642674268426942704271427242734274427542764
27742784279428042814282428342844285428642874288428942904291429242934294429542964297429842994300430143024303430443054306430743084309431043114312431343144315431643174318431943204321432243234324432543264327432843294330433143324333433443354336433743384339434043414342434343444345434643474348434943504351435243534354435543564357435843594360436143624363436443654366436743684369437043714372437343744375437643774378437943804381438243834384438543864387438843894390439143924393439443954396439743984399440044014402440344044405440644074408440944104411441244134414441544164417441844194420442144224423442444254426442744284429443044314432443344344435443644374438443944404441444244434444444544464447444844494450445144524453445444554456445744584459446044614462446344644465446644674468446944704471447244734474447544764477447844794480448144824483448444854486448744884489449044914492449344944495449644974498449945004501450245034504450545064507450845094510451145124513451445154516451745184519452045214522452345244525452645274528452945304531453245334534453545364537453845394540454145424543454445454546454745484549455045514552455345544555455645574558455945604561456245634564456545664567456845694570457145724573457445754576457745784579458045814582458345844585458645874588458945904591459245934594459545964597459845994600460146024603460446054606460746084609461046114612461346144615461646174618461946204621462246234624462546264627462846294630463146324633463446354636463746384639464046414642464346444645464646474648464946504651465246534654465546564657465846594660466146624663466446654666466746684669467046714672467346744675467646774678467946804681468246834684468546864687468846894690469146924693469446954696469746984699470047014702470347044705470647074708470947104711471247134714471547164717471847194720472147224723472447254726472747284729473047314732473347344735473647374738473947404741474247434744474547464747474847494750475147524753475447554756475747584759476047614762476347644765476647674768476947704771477247734774477547764
77747784779478047814782478347844785478647874788478947904791479247934794479547964797479847994800480148024803480448054806480748084809481048114812481348144815481648174818481948204821482248234824482548264827482848294830483148324833483448354836483748384839484048414842484348444845484648474848484948504851485248534854485548564857485848594860486148624863486448654866486748684869487048714872487348744875487648774878487948804881488248834884488548864887488848894890489148924893489448954896489748984899490049014902490349044905490649074908490949104911491249134914491549164917491849194920492149224923492449254926492749284929493049314932493349344935493649374938493949404941494249434944494549464947494849494950495149524953495449554956495749584959496049614962496349644965496649674968496949704971497249734974497549764977497849794980498149824983498449854986498749884989499049914992499349944995499649974998499950005001500250035004500550065007500850095010501150125013501450155016501750185019502050215022502350245025502650275028502950305031503250335034503550365037503850395040504150425043504450455046504750485049505050515052505350545055505650575058505950605061506250635064506550665067506850695070507150725073507450755076507750785079508050815082508350845085508650875088508950905091509250935094509550965097509850995100510151025103510451055106510751085109511051115112511351145115511651175118511951205121512251235124512551265127512851295130513151325133513451355136513751385139514051415142514351445145514651475148514951505151515251535154515551565157515851595160516151625163516451655166516751685169517051715172517351745175517651775178517951805181518251835184518551865187518851895190519151925193519451955196519751985199520052015202520352045205520652075208520952105211521252135214521552165217521852195220522152225223522452255226522752285229523052315232523352345235523652375238523952405241524252435244524552465247524852495250525152525253525452555256525752585259526052615262526352645265526652675268526952705271527252735274527552765
27752785279528052815282528352845285528652875288528952905291529252935294529552965297529852995300530153025303530453055306530753085309531053115312531353145315531653175318531953205321532253235324532553265327532853295330533153325333533453355336533753385339534053415342534353445345534653475348534953505351535253535354535553565357535853595360536153625363536453655366536753685369537053715372537353745375537653775378537953805381538253835384538553865387538853895390539153925393539453955396539753985399540054015402540354045405540654075408540954105411541254135414541554165417541854195420542154225423542454255426542754285429543054315432543354345435543654375438543954405441544254435444544554465447544854495450545154525453545454555456545754585459546054615462546354645465546654675468546954705471547254735474547554765477547854795480548154825483548454855486548754885489549054915492549354945495549654975498549955005501550255035504550555065507550855095510551155125513551455155516551755185519552055215522552355245525552655275528552955305531553255335534553555365537553855395540554155425543554455455546554755485549555055515552555355545555555655575558555955605561556255635564556555665567556855695570557155725573557455755576557755785579558055815582558355845585558655875588558955905591559255935594559555965597559855995600560156025603560456055606560756085609561056115612561356145615561656175618561956205621562256235624562556265627562856295630563156325633563456355636563756385639564056415642564356445645564656475648564956505651565256535654565556565657565856595660566156625663566456655666566756685669567056715672567356745675567656775678567956805681568256835684568556865687568856895690569156925693569456955696569756985699570057015702570357045705570657075708570957105711571257135714571557165717571857195720572157225723572457255726572757285729573057315732573357345735573657375738573957405741574257435744574557465747574857495750575157525753575457555756575757585759576057615762576357645765576657675768576957705771577257735774577557765
77757785779578057815782578357845785578657875788578957905791579257935794579557965797579857995800580158025803580458055806580758085809581058115812581358145815581658175818581958205821582258235824582558265827582858295830583158325833583458355836583758385839584058415842584358445845584658475848584958505851585258535854585558565857585858595860586158625863586458655866586758685869587058715872587358745875587658775878587958805881588258835884588558865887588858895890589158925893589458955896589758985899590059015902590359045905590659075908590959105911591259135914591559165917591859195920592159225923592459255926592759285929593059315932593359345935593659375938593959405941594259435944594559465947594859495950595159525953595459555956595759585959596059615962596359645965596659675968596959705971597259735974597559765977597859795980598159825983598459855986598759885989599059915992599359945995599659975998599960006001600260036004600560066007600860096010601160126013601460156016601760186019602060216022602360246025602660276028602960306031603260336034603560366037603860396040604160426043604460456046604760486049605060516052605360546055605660576058605960606061606260636064606560666067606860696070607160726073607460756076607760786079608060816082608360846085608660876088608960906091609260936094609560966097609860996100610161026103610461056106610761086109611061116112611361146115611661176118611961206121612261236124612561266127612861296130613161326133613461356136613761386139614061416142614361446145614661476148614961506151615261536154615561566157615861596160616161626163616461656166616761686169617061716172617361746175617661776178617961806181618261836184618561866187618861896190619161926193619461956196619761986199620062016202620362046205620662076208620962106211621262136214621562166217621862196220622162226223622462256226622762286229623062316232623362346235623662376238623962406241624262436244624562466247624862496250625162526253625462556256625762586259626062616262626362646265626662676268626962706271627262736274627562766
27762786279628062816282628362846285628662876288628962906291629262936294629562966297629862996300630163026303630463056306630763086309631063116312631363146315631663176318631963206321632263236324632563266327632863296330633163326333633463356336633763386339634063416342634363446345634663476348634963506351635263536354 |
- # mypy: allow-untyped-defs
- """Distributed Collective Communication (c10d)."""
- import collections.abc
- import contextlib
- import copy
- import ctypes
- import hashlib
- import io
- import itertools
- import logging
- import os
- import pickle
- import sys
- import time
- import warnings
- from collections import namedtuple
- from collections.abc import Callable
- from datetime import timedelta
- from typing import Any, NewType, TYPE_CHECKING
- from typing_extensions import deprecated
- import torch
- from torch._C import _DistStoreError as DistStoreError
- from torch._C._distributed_c10d import (
- _DistributedBackendOptions,
- _register_process_group,
- _resolve_process_group,
- _unregister_all_process_groups,
- _unregister_process_group,
- AllgatherOptions,
- AllreduceCoalescedOptions,
- AllreduceOptions,
- AllToAllOptions,
- BarrierOptions,
- BroadcastOptions,
- DebugLevel,
- GatherOptions,
- get_debug_level,
- PrefixStore,
- ProcessGroup,
- ReduceOp,
- ReduceOptions,
- ReduceScatterOptions,
- ScatterOptions,
- Store,
- Work,
- )
- from torch._utils_internal import set_pytorch_distributed_envs_from_justknobs
- from torch.monitor import _WaitCounter
- from torch.overrides import handle_torch_function, has_torch_function
- from torch.utils._typing_utils import not_none
- from . import config as dist_config
- from .c10d_logger import _exception_logger, _time_logger
- from .constants import default_pg_nccl_timeout, default_pg_timeout
- from .rendezvous import register_rendezvous_handler, rendezvous # noqa: F401
- __all__ = [
- "Backend",
- "BackendConfig",
- "GroupMember",
- "P2POp",
- "all_gather",
- "all_gather_coalesced",
- "all_gather_object",
- "all_reduce",
- "all_reduce_coalesced",
- "all_to_all",
- "all_to_all_single",
- "barrier",
- "batch_isend_irecv",
- "broadcast",
- "send_object_list",
- "recv_object_list",
- "broadcast_object_list",
- "destroy_process_group",
- "gather",
- "gather_object",
- "get_backend_config",
- "get_backend",
- "get_default_backend_for_device",
- "get_rank",
- "get_world_size",
- "get_pg_count",
- "group",
- "init_process_group",
- "irecv",
- "is_gloo_available",
- "is_initialized",
- "is_mpi_available",
- "is_backend_available",
- "is_nccl_available",
- "is_torchelastic_launched",
- "is_ucc_available",
- "is_xccl_available",
- "isend",
- "monitored_barrier",
- "new_group",
- "new_subgroups",
- "new_subgroups_by_enumeration",
- "recv",
- "reduce",
- "reduce_scatter",
- "scatter",
- "scatter_object_list",
- "send",
- "supports_complex",
- "AllreduceCoalescedOptions",
- "AllreduceOptions",
- "AllToAllOptions",
- "BarrierOptions",
- "BroadcastOptions",
- "GatherOptions",
- "GroupName",
- "PrefixStore",
- "ProcessGroup",
- "ReduceOp",
- "ReduceOptions",
- "ReduceScatterOptions",
- "ScatterOptions",
- "Store",
- "DebugLevel",
- "get_debug_level",
- "Work",
- "default_pg_timeout",
- "get_group_rank",
- "get_global_rank",
- "get_process_group_ranks",
- "reduce_op",
- "all_gather_into_tensor",
- "reduce_scatter_tensor",
- "get_node_local_rank",
- "split_group",
- "shrink_group",
- ]
- _MPI_AVAILABLE = True
- _NCCL_AVAILABLE = True
- _GLOO_AVAILABLE = True
- _UCC_AVAILABLE = True
- _XCCL_AVAILABLE = True
- try:
- # pyrefly: ignore [import-error, missing-import]
- from torchcomms._comms import _BackendWrapper, new_comm
- _TORCHCOMM_AVAILABLE = True
- except ImportError:
- _TORCHCOMM_AVAILABLE = False
def _use_torchcomms_enabled() -> bool:
    """
    Tell whether the torchcomms path should be used.

    Returns:
        ``False`` when the torchcomms import failed; otherwise the value of
        ``dist_config.use_torchcomms``.
    """
    # Without the package the config switch is irrelevant.
    if not _TORCHCOMM_AVAILABLE:
        return False
    return dist_config.use_torchcomms
# Module-level aliases for the standard pickle classes.
# NOTE(review): presumably indirection points so the (de)serializer can be
# swapped out -- confirm against the call sites.
_pickler = pickle.Pickler
_unpickler = pickle.Unpickler
# Distinct static type for process-group names; at runtime a GroupName is a
# plain ``str``.
GroupName = NewType("GroupName", str)
- # Change __module__ of all imported types from torch._C._distributed_c10d that are public
- def _export_c_types() -> None:
- _public_types_to_change_module = [
- AllreduceCoalescedOptions,
- AllreduceOptions,
- AllToAllOptions,
- BarrierOptions,
- BroadcastOptions,
- GatherOptions,
- PrefixStore,
- ProcessGroup,
- ReduceOp,
- ReduceOptions,
- ReduceScatterOptions,
- ScatterOptions,
- Store,
- DebugLevel,
- get_debug_level,
- Work,
- ]
- for type in _public_types_to_change_module:
- type.__module__ = "torch.distributed.distributed_c10d"
_export_c_types()

# Each backend below is optional in the C++ build. When its binding imports
# cleanly it is re-homed under this module and exported through ``__all__``;
# when the import fails the matching availability flag is cleared instead.

try:
    from torch._C._distributed_c10d import ProcessGroupMPI
except ImportError:
    _MPI_AVAILABLE = False
else:
    ProcessGroupMPI.__module__ = "torch.distributed.distributed_c10d"
    __all__ += ["ProcessGroupMPI"]

try:
    from torch._C._distributed_c10d import ProcessGroupNCCL
except ImportError:
    _NCCL_AVAILABLE = False
else:
    ProcessGroupNCCL.__module__ = "torch.distributed.distributed_c10d"
    __all__ += ["ProcessGroupNCCL"]

try:
    from torch._C._distributed_c10d import _ProcessGroupWrapper, ProcessGroupGloo
except ImportError:
    _GLOO_AVAILABLE = False
else:
    ProcessGroupGloo.__module__ = "torch.distributed.distributed_c10d"
    __all__ += ["ProcessGroupGloo"]

try:
    from torch._C._distributed_c10d import ProcessGroupUCC
except ImportError:
    _UCC_AVAILABLE = False
else:
    ProcessGroupUCC.__module__ = "torch.distributed.distributed_c10d"
    __all__ += ["ProcessGroupUCC"]

try:
    from torch._C._distributed_c10d import ProcessGroupXCCL
except ImportError:
    _XCCL_AVAILABLE = False
else:
    ProcessGroupXCCL.__module__ = "torch.distributed.distributed_c10d"
    __all__ += ["ProcessGroupXCCL"]
# Logger named after this module.
logger = logging.getLogger(__name__)
# NOTE(review): presumably the store-key prefix used for the process-group
# wrapper -- confirm at the use site.
PG_WRAPPER_STORE_PREFIX = "pg_wrapper"
- # Some reduce ops are not supported by complex numbers and will result in an error.
- # We currently provide complex support to the distributed API by viewing
- # complex tensors as real (torch.view_as_real), meaning that calling
- # these unsupported ops will return garbage values rather than error out.
- # (e.g. max(2+3i, 3+2i) = 3+3i)
- # We'd like calls to unsupported ops to error out accordingly,
- # rather than returning garbage values.
def supports_complex(reduceOp: ReduceOp) -> bool:
    """
    Report whether a reduce op may be applied to complex tensors.

    Args:
        reduceOp (ReduceOp): the reduction operation to check.

    Returns:
        ``False`` for ``MAX``, ``MIN``, ``PRODUCT``, ``BAND``, ``BOR`` and
        ``BXOR``; ``True`` for every other op.
    """
    # Complex tensors are reduced through their real view, so these ops would
    # yield meaningless values rather than an error.
    unsupported_ops = (
        ReduceOp.MAX,
        ReduceOp.MIN,
        ReduceOp.PRODUCT,
        ReduceOp.BAND,
        ReduceOp.BOR,
        ReduceOp.BXOR,
    )
    if reduceOp in unsupported_ops:
        return False
    return True
- # TODO refactor into enum/strenum
- class Backend(str): # noqa: SLOT000
- """
- An enum-like class for backends.
- Available backends: GLOO, NCCL, UCC, MPI, XCCL, FAKE, and other registered backends.
- The values of this class are lowercase strings, e.g., ``"gloo"``. They can
- be accessed as attributes, e.g., ``Backend.NCCL``.
- This class can be directly called to parse the string, e.g.,
- ``Backend(backend_str)`` will check if ``backend_str`` is valid, and
- return the parsed lowercase string if so. It also accepts uppercase strings,
- e.g., ``Backend("GLOO")`` returns ``"gloo"``.
- .. note:: The entry ``Backend.UNDEFINED`` is present but only used as
- initial value of some fields. Users should neither use it directly
- nor assume its existence.
- """
- UNDEFINED = "undefined"
- GLOO = "gloo"
- NCCL = "nccl"
- UCC = "ucc"
- MPI = "mpi"
- XCCL = "xccl"
- FAKE = "fake"
- _BackendPlugin = namedtuple("_BackendPlugin", ["creator_fn", "extended_api"])
- _plugins: dict[str, _BackendPlugin] = {}
- backend_list = [UNDEFINED, GLOO, NCCL, XCCL, UCC, MPI, FAKE]
- # 3rd-party devices can register the default backend support here
- default_device_backend_map: dict[str, str] = {
- "cpu": GLOO,
- "cuda": NCCL,
- "xpu": XCCL,
- "mps": GLOO,
- }
- backend_capability: dict[str, list[str]] = {
- GLOO: ["cpu", "cuda"],
- NCCL: ["cuda"],
- XCCL: ["xpu"],
- UCC: ["cpu", "cuda"],
- MPI: ["cpu", "cuda"],
- FAKE: ["cpu", "cuda", "hpu", "xpu"],
- }
- backend_type_map: dict[str, ProcessGroup.BackendType] = {
- UNDEFINED: ProcessGroup.BackendType.UNDEFINED,
- GLOO: ProcessGroup.BackendType.GLOO,
- NCCL: ProcessGroup.BackendType.NCCL,
- XCCL: ProcessGroup.BackendType.XCCL,
- UCC: ProcessGroup.BackendType.UCC,
- MPI: ProcessGroup.BackendType.MPI,
- FAKE: ProcessGroup.BackendType.CUSTOM,
- }
- def __new__(cls, name: str):
- """Create and return a new instance of the class."""
- if not isinstance(name, str):
- raise ValueError("Backend constructor parameter must be string-ish")
- value = getattr(Backend, name.upper(), Backend.UNDEFINED)
- if value == Backend.UNDEFINED:
- value = name.lower()
- return value
- @classmethod
- def register_backend(
- cls,
- name,
- func,
- extended_api: bool = False,
- devices: str | list[str] | None = None,
- ) -> None:
- """
- Register a new backend with the given name and instantiating function.
- This class method is used by 3rd party ``ProcessGroup`` extension to
- register new backends.
- Args:
- name (str): Backend name of the ``ProcessGroup`` extension. It
- should match the one in ``init_process_group()``.
- func (function): Function handler that instantiates the backend.
- The function should be implemented in the backend
- extension and takes four arguments, including
- ``store``, ``rank``, ``world_size``, and ``timeout``.
- extended_api (bool, optional): Whether the backend supports extended argument structure.
- Default: ``False``. If set to ``True``, the backend
- will get an instance of ``c10d::DistributedBackendOptions``, and
- a process group options object as defined by the backend implementation.
- device (str or list of str, optional): device type this backend
- supports, e.g. "cpu", "cuda", etc. If `None`,
- assuming both "cpu" and "cuda"
- .. note:: This support of 3rd party backend is experimental and subject to change.
- """
- # This takes care of CUSTOM Out-of-tree backend types, update in backend_list indicates availability
- if not hasattr(Backend, name.upper()):
- setattr(Backend, name.upper(), name.lower())
- if name.lower() not in Backend.backend_list:
- Backend.backend_list.append(name.lower())
- if devices is not None:
- for device in devices:
- if device not in Backend.default_device_backend_map:
- Backend.default_device_backend_map[device] = name.lower()
- Backend.backend_type_map[name.lower()] = ProcessGroup.BackendType.CUSTOM
- # Update device capability matrix in Backend class
- if devices is None:
- # This is more of a backward support for groups like `threaded`:
- # assume default devices "cpu" and "cuda", but warn
- warnings.warn(
- f"Device capability of {name} unspecified, assuming `cpu` and "
- "`cuda` or `xpu`. Please specify it via the `devices` argument of "
- "`register_backend`.",
- stacklevel=2,
- )
- Backend.backend_capability[name.lower()] = (
- ["cpu", "cuda", "xpu"] if torch.xpu.is_available() else ["cpu", "cuda"]
- )
- elif isinstance(devices, str):
- # Single device string specified. Simply convert to list.
- Backend.backend_capability[name.lower()] = [devices]
- else:
- Backend.backend_capability[name.lower()] = devices
- Backend._plugins[name.upper()] = Backend._BackendPlugin(func, extended_api)
class BackendConfig:
    """
    Backend configuration class.

    Parses a backend specification into ``device_backend_map``, a mapping
    from device type (e.g. ``"cpu"``) to the backend used for that device.
    """

    def __init__(self, backend: Backend):
        """
        Build the device-to-backend map from ``backend``.

        Args:
            backend: ``Backend.UNDEFINED``, a single backend name (e.g.
                ``"nccl"``), or a ``"device:backend"`` list such as
                ``"cpu:gloo,cuda:nccl"``.

        Raises:
            ValueError: if no default backend is known for the detected
                accelerator, or if a ``"device:backend"`` string is malformed
                or names the same device twice.
        """
        self.device_backend_map: dict[str, Backend] = {}
        # Keep the full specification in its own local; the parsing loop
        # below must not overwrite it, since error messages report it.
        backend_spec = str(backend)

        if backend_spec == Backend.UNDEFINED:
            # Detect the accelerator on the machine. If no accelerator is
            # available, it returns CPU.
            device_type = torch._C._get_accelerator().type
            try:
                backend_str = Backend.default_device_backend_map[device_type]
            except KeyError:
                raise ValueError(
                    f"We detected accelerator {device_type} on your machine. "
                    f"But we don't know which communication backend to use for this accelerator. "
                    f"Please specify the `backend` argument in the `init_process_group` call."
                ) from None
            self.device_backend_map[device_type] = Backend(backend_str)
        elif backend_spec.lower() in Backend.backend_list:
            # Cases for when backend is a single string (without device types)
            # e.g. "nccl", "gloo", "ucc", "mpi"
            supported_devices = Backend.backend_capability[backend_spec.lower()]
            backend_val = Backend(backend_spec)
            self.device_backend_map = dict.fromkeys(supported_devices, backend_val)
        elif ":" in backend_spec.lower():
            # Backend specified in "device:backend" format
            # make sure the backend string is in the correct format
            # "{device_type1}:{backend1},{device_type2}:{backend2}"
            # e.g. "cpu:gloo,cuda:nccl"
            backend_str_error_message = (
                f"The custom backend string argument is invalid: {backend_spec}.\n"
                "Custom backend string is an experimental feature where the "
                "backend string must be in the format:\n"
                "\"<device_type1>:<backend1>,<device_type2>:<backend2>...\". "
                "e.g. 'cpu:gloo,cuda:nccl'"
            )

            # parse the backend string and populate the device_backend_map
            for device_backend_pair_str in backend_spec.lower().split(","):
                device_backend_pair = device_backend_pair_str.split(":")
                if len(device_backend_pair) != 2:
                    raise ValueError(
                        f"Invalid device:backend pairing: "
                        f"{device_backend_pair_str}. {backend_str_error_message}"
                    )
                device, backend_name = device_backend_pair
                if device in self.device_backend_map:
                    raise ValueError(
                        f"Duplicate device type {device} "
                        f"in backend string: {backend_spec}. {backend_str_error_message}"
                    )
                self.device_backend_map[device] = Backend(backend_name)
        else:
            # User specified a single backend name whose device capability is
            # unknown, assuming it can support the default devices of PyTorch
            # NOTE(review): the warning text mentions only `cpu` and `cuda`,
            # while the map built below also covers `xpu`.
            warnings.warn(
                f"Device capability of {backend_spec} unknown, assuming `cpu` and "
                "`cuda`. You can specify it in `device:backend` format in "
                "`init_process_group` call.",
                stacklevel=2,
            )
            backend_val = Backend(backend_spec)
            self.device_backend_map = {
                "cpu": backend_val,
                "cuda": backend_val,
                "xpu": backend_val,
            }

        logger.info("Using backend config: %s", self.device_backend_map)

    def __repr__(self):
        """Return all the device:backend pairs separated by commas."""
        return ",".join(
            f"{device}:{backend}" for device, backend in self.device_backend_map.items()
        )

    def get_device_backend_map(self) -> dict[str, Backend]:
        """Return backend map of the device."""
        return self.device_backend_map
class _reduce_op:
    r"""
    Deprecated enum-like holder of reduction operations.

    Exposes ``SUM``, ``PRODUCT``, ``MIN``, and ``MAX`` (and every other member
    of ``ReduceOp.RedOpType``) as instance attributes.

    Prefer :class:`~torch.distributed.ReduceOp`.
    """

    def __init__(self) -> None:
        # Mirror each enum member onto this instance; ``__members__`` is the
        # name -> value dict exposed by the enum class.
        members = ReduceOp.RedOpType.__members__
        for member_name in members:
            setattr(self, member_name, members[member_name])
        self.__members__ = ReduceOp.RedOpType.__members__

    @deprecated(
        "`torch.distributed.reduce_op` is deprecated, "
        "please use `torch.distributed.ReduceOp` instead",
        category=FutureWarning,
    )
    def __getattribute__(self, key):
        # Every attribute read goes through the deprecation wrapper above.
        return object.__getattribute__(self, key)


# Module-level instance kept for the deprecated ``torch.distributed.reduce_op`` name.
reduce_op = _reduce_op()
class P2POp:
    """
    Describe one point-to-point operation for ``batch_isend_irecv``.

    An instance records the P2P operation type, the communication buffer, the
    peer rank, the process group, and the tag. Instances are handed to
    ``batch_isend_irecv`` for point-to-point communications.

    Args:
        op (Callable): A function to send data to or receive data from a peer process.
            The type of ``op`` is either ``torch.distributed.isend`` or
            ``torch.distributed.irecv``.
        tensor (Tensor): Tensor to send or receive.
        peer (int, optional): Destination or source rank (global rank).
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        tag (int, optional): Tag to match send with recv.
        group_peer (int, optional): Destination or source rank, expressed as a
            rank within ``group``.
    """

    def __init__(
        self,
        op: Callable,
        tensor: torch.Tensor,
        peer: int | None = None,
        group: ProcessGroup | None = None,
        tag: int = 0,
        group_peer: int | None = None,
    ):
        """Store the operation and resolve the peer as both global and group rank."""
        resolved_group = _group_or_default_group(group)
        global_peer = _canonicalize_group_rank(
            resolved_group, peer, group_peer, return_global=True
        )
        peer_in_group = _canonicalize_group_rank(resolved_group, peer, group_peer)

        self.op = op
        self.tensor = tensor
        self.group = resolved_group
        self.peer = global_peer
        self.tag = tag
        self.group_peer = peer_in_group

    def __new__(
        cls,
        op: Callable,
        tensor: torch.Tensor,
        peer: int | None = None,
        group: ProcessGroup | None = None,
        tag: int = 0,
        group_peer: int | None = None,
    ):
        """Validate ``op`` and ``tensor`` before allocating the instance."""
        _check_op(op)
        _check_single_tensor(tensor, "tensor")
        return object.__new__(cls)

    def __repr__(self):
        """Summarize the op with its group-relative source and destination ranks."""
        own_rank = get_rank(self.group)
        op_name = self.op.__name__
        if self.group:
            group_name = self.group.group_name
        else:
            group_name = "default_pg"

        # Direction is inferred from the op's function name.
        if "send" in op_name:
            src, dst = own_rank, self.group_peer
        elif "recv" in op_name:
            src, dst = self.group_peer, own_rank
        else:
            return super().__repr__()

        return f"P2POp({op_name} pg={group_name}, group_src={src}, group_dst={dst}, {self.tensor.shape}, {self.tensor.dtype})"
- class _CollOp:
- """
- A class to capture collective operations.
- Args:
- op (Callable): A collective function, e.g. ``torch.distributed.all_reduce``.
- tensor (Tensor): Tensor to operate on.
- dst_tensor (Tensor, optional): Provided when source and destination tensors are not the same.
- redop (ReduceOp, optional): reduce operation.
- root (int, optional): root of broadcast or reduce.
- """
- def __init__(
- self,
- op: Callable,
- tensor: torch.Tensor,
- dst_tensor: torch.Tensor | None = None,
- redop: ReduceOp | None = None,
- root: int | None = None,
- ):
- self.op = op
- self.tensor = tensor
- self.dst_tensor = dst_tensor
- self.redop = redop
- self.root = root
# DO NOT USE THESE FIELDS DIRECTLY.
# Access them through the ``_world`` object so that the ``_world`` override
# mechanism keeps working.

# ProcessGroup -> (backend name, store)
_pg_map: dict[ProcessGroup, tuple[str, Store]] = {}
# ProcessGroup -> group name
_pg_names: dict[ProcessGroup, GroupName] = {}
# ProcessGroup -> {global rank: group rank}
_pg_group_ranks: dict[ProcessGroup, dict[int, int]] = {}

# ProcessGroup -> backend config, stored as a string
_pg_backend_config: dict[ProcessGroup, str] = {}
# Counter exposed through ``_World.group_count``
_group_count = 0
# tag -> process groups registered under that tag
_tags_to_pg: dict[str, list[ProcessGroup]] = {}
# ProcessGroup -> tag
_pg_to_tag: dict[ProcessGroup, str] = {}
_backend: str | None = None
- class _World:
- """
- Container class for c10d process group state.
- This is used during registration and lookup of PG state.
- .. warning:: This is an experimental API intended to expose the inner workings
- of c10d and is subject to change..
- """
- def __init__(self) -> None:
- self._default_pg = None
- self._pg_coalesce_state: dict[ProcessGroup, list[_CollOp]] = {}
- self._comms: list[Any] = []
- @property
- def default_pg(self) -> ProcessGroup | None:
- """
- Process group that includes all ranks of the cluster.
- This default ProcessGroup is used by c10d APIs when a ProcessGroup is needed
- but None is provided.
- """
- return self._default_pg
- @default_pg.setter
- def default_pg(self, value) -> None:
- self._default_pg = value
- @property
- def pg_map(self) -> dict[ProcessGroup, tuple[str, Store]]:
- """
- Provide Mapping from ProcessGroup to backend name and store.
- For NCCL and GLOO pg, it is a map from ProcessGroup to (Backend, Store)
- For MPI pg, it is a map from ProcessGroup to (Backend, None)
- TODO don't expose the map, expose fine grained ops
- """
- global _pg_map
- return _pg_map
- @property
- def pg_names(self) -> dict[ProcessGroup, GroupName]:
- """
- Process group's names, map from ProcessGroup to str.
- TODO don't expose the map, expose fine grained ops
- """
- global _pg_names
- return _pg_names
- @property
- def pg_group_ranks(self) -> dict[ProcessGroup, dict[int, int]]:
- """
- Process group's global rank to local rank mapping.
- TODO don't expose the map, expose fine grained ops
- """
- global _pg_group_ranks
- return _pg_group_ranks
- @property
- def pg_backend_config(self) -> dict[ProcessGroup, str]:
- """
- Process group's backend config.
- TODO don't expose the map, expose fine grained ops
- """
- global _pg_backend_config
- return _pg_backend_config
- @property
- def group_count(self) -> int:
- """
- Process group count for default naming.
- TODO don't expose group_count, use something else instead
- """
- global _group_count
- return _group_count
- @group_count.setter
- def group_count(self, value: int) -> None:
- """Use to compute the name of ProcessGroups when using global synchronization."""
- global _group_count
- _group_count = value
- @property
- def tags_to_pg(self) -> dict[str, list[ProcessGroup]]:
- global _tags_to_pg
- return _tags_to_pg
- @property
- def pg_to_tag(self) -> dict[ProcessGroup, str]:
- global _pg_to_tag
- return _pg_to_tag
- @property
- def pg_coalesce_state(self) -> dict[ProcessGroup, list[_CollOp]]:
- return self._pg_coalesce_state
- @property
- def comms(self) -> list[Any]:
- return self._comms
- @property
- def pg_config_info(self) -> list[dict[str, Any]]:
- """
- Return a list of dict with process groups and backends.
- Along with their unique IDs and configurations (types and ranks).
- """
- config_info: list[dict[str, Any]] = []
- default_pg_size = _get_group_size(None)
- for pg in self.pg_map:
- ranks = self.pg_group_ranks[pg]
- config_info.append(
- {
- "pg_name": self.pg_names[pg],
- "pg_desc": pg.group_desc,
- "backend_config": self.pg_backend_config[pg],
- "ranks": (
- list(ranks.keys()) if len(ranks) != default_pg_size else []
- ), # 'ranks' is an empty list when all ranks are involved in a pg
- "group_size": len(ranks),
- "group_count": self.group_count,
- }
- )
- return config_info
- _world = _World()
- """Holds the singleton instance of ``_World`` used by c10. Experimental extension point to override it"""
class _WorldMeta(type):
    """
    Meta class of ``group`` and ``GroupMember``.

    Allows them to have the class property ``WORLD``.
    """

    # Points to the default PG once initialized.
    @property
    def WORLD(cls) -> ProcessGroup | None:
        # Reads through to the default process group held by ``_world``.
        return _world.default_pg

    @WORLD.setter
    def WORLD(cls, pg: ProcessGroup | None):
        # Assigning ``WORLD`` replaces the default process group held by ``_world``.
        _world.default_pg = pg
class group(metaclass=_WorldMeta):
    """Group class. Placeholder; its ``WORLD`` class property comes from ``_WorldMeta``."""
class GroupMember(metaclass=_WorldMeta):
    """Group member class; its ``WORLD`` class property comes from ``_WorldMeta``."""

    # Sentinel compared against a group to detect that the current process is
    # not a member of it (see ``_rank_not_in_group``).
    NON_GROUP_MEMBER = -100
def _get_default_timeout(backend: Backend) -> timedelta:
    """
    Return the default timeout to use for ``backend``.

    NCCL gets ``default_pg_nccl_timeout`` when that value is a ``timedelta``;
    in every other case ``default_pg_timeout`` is returned.
    """
    # see note on nccl vs other backend timeout (constants.py)
    if backend != Backend.NCCL:
        return default_pg_timeout

    if isinstance(default_pg_nccl_timeout, timedelta):
        return default_pg_nccl_timeout

    # TODO moco benchmark on CPU initializes pgnccl backend today, triggered this assert in CI before it was
    # changed to be a warning. We should fix the moco model.
    warnings.warn(
        "Attempted to get default timeout for nccl backend, but NCCL support is not compiled",
        stacklevel=2,
    )
    return default_pg_timeout
- def _check_valid_timeout(timeout: Any) -> None:
- if not isinstance(timeout, timedelta):
- raise TypeError(
- f"Expected timeout argument to be of type datetime.timedelta, got {timeout}"
- )
# Default process group state
_default_pg_init_method: str | None = None

# Prefix of the store key built by ``_store_based_barrier``.
STORE_BASED_BARRIER_PREFIX = "store_based_barrier_key"
def _get_object_coll_device(group: ProcessGroup | None = None) -> str:
    """
    .. note:: This is an internal helper and does not have backward
        compatibility, please use with caution.

    Pick the device type to use with ``group`` for object collectives or
    barrier.

    Selection rules:

    1. If the group reports exactly one device type, that type is returned.
    2. If it reports none, ``"cpu"`` is returned.
    3. If it reports several and ``"cpu"`` is among them, ``"cpu"`` is
       returned (the object lives in cpu memory); otherwise the first
       reported device type is returned.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    Returns:
        str: The device type to use for object collective with ``group``.

    Raises:
        ValueError: if ``group`` is neither a ``ProcessGroup`` nor a
            ``ProcessGroupGloo``.
    """
    group = group or _get_default_group()

    if not isinstance(group, ProcessGroup):
        warnings.warn(
            f"You are using a Backend {type(group)} as a ProcessGroup. "
            "This usage is deprecated since PyTorch 2.0. Please use a public API "
            "of PyTorch Distributed instead.",
            stacklevel=2,
        )
        # Provide backward compatibility to cases where `group` passed in is
        # actually a Backend (like `ProcessGroupGloo`) rather than a
        # `ProcessGroup` in PT 2.0 sense
        if not isinstance(group, ProcessGroupGloo):
            raise ValueError(f"Expecting a ProcessGroup, but got a {type(group)}.")
        # RPC uses Gloo for object collectives
        return "cpu"

    # ``group._device_types`` is a property pybind that returns the devices
    # ("cpu", "cuda", etc) supported by ``group``. Can be multiple if the
    # ``group`` supports multiple devices.
    devices = group._device_types
    num_devices = len(devices)

    if num_devices == 0:
        # No backend has been registered with this PG (maybe because no
        # collective has been run?) We pick cpu as the default and hopefully
        # this would lazily init Gloo or other available cpu backend.
        return "cpu"

    if num_devices > 1 and torch.device("cpu") in devices:
        # There are multiple backends in this PG and cpu is among them.
        # cpu is preferred as the object is in cpu memory. No need for device
        # copy.
        return "cpu"

    # Either the user fixed exactly one backend in `init_process_group`, or
    # there is no cpu in the backend list and the first backend is picked.
    return devices[0].type
def _get_pg_default_device(group: ProcessGroup | None = None) -> torch.device:
    """
    .. note:: This method will be deprecated, it only stays for
        backward-compatibility reason. Alternatives:

        - If you need to find a device for object collectives, please use
          `_get_object_coll_device(group)`.

        - If you need to query the device types supported by group, please use
          `_device_capability(group)`.

    Return the device type registered with ``group``.

    For example, if `init_process_group("nccl", ...)` was called, the returned
    value would be `torch.device("cuda")`.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    Returns:
        torch.device: The device type registered with ``group``.

    Raises:
        RuntimeError: if no device has been registered with ``group``.
    """
    warnings.warn(
        "`_get_pg_default_device` will be deprecated, it only stays for "
        "backward-compatibility reason. If you need to find a device for object "
        "collectives, please use `_get_object_coll_device`. If you need to query "
        "the device types supported by group, please use "
        "`_device_capability(group)`. ",
        stacklevel=2,
    )
    group = group or _get_default_group()

    if not isinstance(group, ProcessGroup):
        # Provide backward compatibility to cases where `group` passed in is
        # actually a Backend (like `ProcessGroupGloo`) rather than a
        # `ProcessGroup` in PT 2.0 sense
        warnings.warn(
            f"You are using a Backend {type(group)} as a ProcessGroup. "
            "This usage is deprecated since PyTorch 2.0. Please use a public API "
            "of PyTorch Distributed instead.",
            FutureWarning,
            stacklevel=3,
        )
        # Most users create Gloo with private API for object collectives
        return torch.device("cpu")

    # ``group._device_types`` is a property pybind that returns the devices
    # ("cpu", "cuda", etc) supported by ``group``. Can be multiple if the
    # ``group`` supports multiple devices.
    devices = group._device_types
    num_devices = len(devices)

    if num_devices == 0:
        raise RuntimeError(
            "Default device not found, because no backend has been registered "
            "with this ProcessGroup."
        )

    if num_devices == 1:
        # User fixed exactly one backend in `init_process_group`
        return devices[0]

    # There are multiple backends in this PG: prefer cpu, else take the first.
    cpu_device = torch.device("cpu")
    chosen = cpu_device if cpu_device in devices else devices[0]
    warnings.warn(
        "Multiple backends are registered with this ProcessGroup. We cannot "
        f"determine which one is the default. Returning {chosen}. "
        "Please consider using other APIs.",
        stacklevel=2,
    )
    return chosen
- def _device_capability(group: ProcessGroup | None = None) -> list[str]:
- """
- Return the device type(s) supported by ``group``.
- Args:
- group (ProcessGroup, optional): The process group to query. If None,
- the default process group will be used.
- Returns:
- List[str]: A list of device types supported by ``group``.
- """
- group = group or _get_default_group()
- return [device.type for device in group._device_types]
@_time_logger
def _store_based_barrier(
    rank,
    store,
    group_name: GroupName,
    rendezvous_count,
    timeout,
    logging_interval=timedelta(seconds=10),
) -> None:
    """
    Store based barrier for synchronizing processes.

    Barrier based on store which is used for synchronizing processes after
    ``init_process_group`` or ``new_group``. Intended to be used only with
    those two methods and is not a generic alternative to ``barrier()``.

    Args:
        rank: Rank of the calling process; only used in log and error messages.
        store: Store through which the workers check in.
        group_name: Name of the group, used to build the store key.
        rendezvous_count: Number of workers expected to check in.
        timeout (timedelta): Total time to wait before giving up.
        logging_interval (timedelta): Duration of each individual wait on the
            store; progress is logged after every unsuccessful wait.

    Raises:
        DistStoreError: if the barrier has not completed once ``timeout`` has
            elapsed.
    """
    store_key = f"{STORE_BASED_BARRIER_PREFIX}:{group_name}"
    store.add(store_key, 1)
    logger.debug("Added key: %s to store for rank: %s", store_key, rank)

    # Now wait for all workers to check in with the store.
    world_size = rendezvous_count
    # Adding 0 reads the current counter value without changing it.
    worker_count = store.add(store_key, 0)

    last_worker_key = f"{store_key}:last_worker"
    if worker_count == world_size:
        store.set(last_worker_key, "1")

    # adjust the timeout to be at least 10secs + 1sec per thousand ranks to reduce the odds of timeout
    # this value was empirically found while scale testing.
    logging_interval = max(logging_interval, timedelta(seconds=10 + world_size / 1000))

    # Use a monotonic clock so wall-clock adjustments cannot distort the
    # elapsed time compared against ``timeout``.
    start = time.monotonic()
    while True:
        try:
            # This will throw an exception after the logging_interval in which we print out
            # the status of the group or time out officially, throwing runtime error
            store.wait([last_worker_key], logging_interval)
            break
        except RuntimeError as e:
            worker_count = store.add(store_key, 0)
            elapsed = time.monotonic() - start
            # Print status periodically to keep track.
            logger.debug(  # noqa: G200
                "Waiting in store based barrier to initialize process group for %s seconds, "
                "rank: %s, key: %s (world_size=%s, num_workers_joined=%s, timeout=%s error=%s)",
                elapsed,
                rank,
                store_key,
                world_size,
                worker_count,
                timeout,
                e,
            )

            if timedelta(seconds=elapsed) > timeout:
                raise DistStoreError(
                    "Timed out initializing process group in store based barrier on "
                    f"rank {rank}, for key: {store_key} (world_size={world_size}, "
                    f"num_workers_joined={worker_count}, timeout={timeout} error={e})"
                ) from e

    logger.info(
        "Rank %s: Completed store-based barrier for key:%s with %s nodes.",
        rank,
        store_key,
        world_size,
    )
- def _rank_not_in_group(group: ProcessGroup | None) -> bool:
- """Check if the current process's rank is not in a given group."""
- if group is None:
- return False
- return group == GroupMember.NON_GROUP_MEMBER
def _warn_not_in_group(op_name) -> None:
    """Warn that ``op_name`` is being run on a rank outside the given group."""
    world_pg = GroupMember.WORLD
    # -1 stands in for the global rank when no default group exists.
    if world_pg is None:
        global_rank = -1
    else:
        global_rank = world_pg.rank()
    warnings.warn(
        f"Running {op_name} on global rank {global_rank} which does not "
        "belong to the given group.",
        stacklevel=2,
    )
def get_group_rank(group: ProcessGroup, global_rank: int) -> int:
    """
    Translate a global rank into a group rank.

    Args:
        group (ProcessGroup): ProcessGroup to find the relative rank.
        global_rank (int): Global rank to query.

    Returns:
        Group rank of ``global_rank`` relative to ``group``

    Raises:
        ValueError: if ``group`` is not registered, or if ``global_rank`` is
            not part of ``group``.

    N.B. calling this function on the default process group returns identity
    """
    if group is GroupMember.WORLD:
        return global_rank

    registered = _world.pg_group_ranks
    if group not in registered:
        raise ValueError(
            f"Group {group} is not registered, please create group with torch.distributed.new_group API"
        )

    global_to_group = registered[group]
    if global_rank in global_to_group:
        return global_to_group[global_rank]
    raise ValueError(f"Global rank {global_rank} is not part of group {group}")
def get_global_rank(group: ProcessGroup, group_rank: int) -> int:
    """
    Translate a group rank into a global rank.

    Args:
        group (ProcessGroup): ProcessGroup to find the global rank from.
        group_rank (int): Group rank to query.

    Returns:
        Global rank of ``group_rank`` relative to ``group``

    Raises:
        ValueError: if ``group`` is not registered, or if ``group_rank`` is
            not part of ``group``.

    N.B. calling this function on the default process group returns identity
    """
    if group is GroupMember.WORLD:
        return group_rank

    registered = _world.pg_group_ranks
    if group not in registered:
        raise ValueError(
            f"Group {group} is not registered, please create group with torch.distributed.new_group API"
        )

    # The stored mapping goes global -> group, so search it in reverse and
    # return the first match.
    not_found = object()
    global_rank = next(
        (
            candidate
            for candidate, rank_in_group in registered[group].items()
            if rank_in_group == group_rank
        ),
        not_found,
    )
    if global_rank is not_found:
        raise ValueError(f"Group rank {group_rank} is not part of group {group}")
    return global_rank
# TODO: remove this once the ecosystem moves away from it.
@deprecated(
    "`torch.distributed.distributed_c10d._get_global_rank` is deprecated, "
    "please use `torch.distributed.distributed_c10d.get_global_rank` instead",
    category=FutureWarning,
)
def _get_global_rank(group, rank) -> int:
    """
    Use get_global_rank as this method is deprecated.

    Forwards ``group`` and ``rank`` unchanged to ``get_global_rank`` and
    returns its result.
    """
    return get_global_rank(group, rank)
def get_process_group_ranks(group: ProcessGroup | None) -> list[int]:
    """
    Return every global rank that belongs to ``group``.

    Args:
        group (Optional[ProcessGroup]): ProcessGroup to get all ranks from.
            If None, the default process group will be used.

    Returns:
        List of global ranks ordered by group rank.
    """
    target_group = group or _get_default_group()
    global_to_group = _world.pg_group_ranks[target_group]
    # Keys of the mapping are the global ranks.
    return list(global_to_group.keys())
def _get_group_size(group: ProcessGroup | None) -> int:
    """Return the size of ``group``, or of the default group for ``None``/``WORLD``."""
    if group is None or group is GroupMember.WORLD:
        return _get_default_group().size()
    return group.size()
def _get_group_size_by_name(group_name: GroupName) -> int:
    """Resolve ``group_name`` to its process group and return that group's size."""
    return _resolve_process_group(group_name).size()
def _resolve_group_name_by_ranks_and_tag(ranks: list[int], tag: str) -> GroupName:
    """
    Return the name of the process group identified by ``ranks`` and ``tag``.

    Raises:
        ValueError: if no process group matches ``ranks`` and ``tag``.
    """
    # TODO(yifu): remove this function once ranks + tag is not a supported
    # identifier for process group for functional collectives.
    group = _find_pg_by_ranks_and_tag(tag, ranks)
    if group is None:
        # Say what was looked up; an empty message is useless when debugging.
        raise ValueError(
            f"No process group found for ranks {ranks} with tag {tag!r}."
        )
    return group.group_name
- def _check_single_tensor(param, param_name: str) -> None:
- """Check that the parameter ``param_name`` is a single tensor."""
- if not isinstance(param, torch.Tensor):
- raise TypeError(
- f"""Invalid function argument. Expected parameter `{param_name}` of type torch.Tensor
- but got {type(param)} instead."""
- )
- def _check_tensor_list(param, param_name: str) -> None:
- """Check that the parameter ``param_name`` is a list of tensors."""
- if not isinstance(param, list):
- raise TypeError(
- f"""Invalid function argument. Expected parameter `{param_name}` of type List[torch.Tensor]
- but got {type(param)} instead."""
- )
- elif not all(isinstance(p, torch.Tensor) for p in param):
- raise TypeError(
- f"""Invalid function argument. Expected parameter `{param_name}` of type List[torch.Tensor]
- but got {type(param)} with elements of type {[type(p) for p in param]}."""
- )
def _group_or_default_group(group: ProcessGroup | None = None) -> ProcessGroup:
    """Return ``group``, substituting the default group for ``None`` or ``WORLD``."""
    use_default = group is None or group is GroupMember.WORLD
    return _get_default_group() if use_default else group
- def _canonicalize_group_rank(
- group: ProcessGroup,
- global_rank: int | None = None,
- group_rank: int | None = None,
- return_global: bool = False,
- ) -> int:
- """
- Helper method to take _either_ a global rank or a group rank and produce a group rank.
- If 'return_global' is true, produce a global rank instead of a group rank.
- """
- if group_rank is not None:
- if global_rank is not None:
- raise ValueError("Can't specify both group_rank and global_rank")
- if return_global:
- return get_global_rank(group, group_rank)
- else:
- if global_rank is None:
- raise ValueError("Must specify global_rank or group_rank")
- if return_global:
- return global_rank
- group_rank = get_group_rank(group, global_rank)
- return group_rank
- def _check_not_self_rank(group: ProcessGroup, rank: int, rank_type: str):
- if group.rank() == rank:
- raise ValueError(
- f"Invalid {rank_type} rank: {rank_type} rank should not be the same as "
- "the rank of the current process."
- )
- def _as_iterable(obj) -> collections.abc.Iterable:
- return obj if isinstance(obj, list) else (obj,)
- def _ensure_all_tensors_same_dtype(*tensors) -> None:
- last_dtype = None
- for tensor in itertools.chain.from_iterable(map(_as_iterable, tensors)):
- tensor_dtype = tensor.dtype
- # Mixing complex and its element type is allowed
- if tensor_dtype.is_complex:
- tensor_dtype = (
- torch.float32 if tensor_dtype == torch.complex64 else torch.complex128
- )
- if last_dtype is None:
- last_dtype = tensor_dtype
- else:
- if last_dtype != tensor_dtype:
- raise ValueError(
- "Invalid usage of tensors with different dtypes"
- f"Found {last_dtype} and {tensor.dtype}"
- )
def _check_op(op) -> None:
    """Raise ``ValueError`` unless ``op`` is ``isend`` or ``irecv``."""
    allowed_ops = (isend, irecv)
    if op in allowed_ops:
        return
    raise ValueError(
        "Invalid ``op``. Expected ``op`` "
        "to be of type ``torch.distributed.isend`` or "
        "``torch.distributed.irecv``."
    )
- def _check_p2p_op_list(p2p_op_list) -> None:
- """
- Check that the ``p2p_op_list`` is a list of P2POp instances.
- Also, check that all ops use the same group.
- """
- if not isinstance(p2p_op_list, list) or not all(
- isinstance(p2p_op, P2POp) for p2p_op in p2p_op_list
- ):
- raise ValueError(
- "Invalid ``p2p_op_list``. Each op is expected to "
- "to be of type ``torch.distributed.P2POp``."
- )
- group = p2p_op_list[0].group
- if not all(group == p2p_op.group for p2p_op in p2p_op_list):
- raise ValueError("All ops need to use the same group.")
def is_mpi_available() -> bool:
    """
    Check if the MPI backend is available.

    Returns:
        bool: the module-level ``_MPI_AVAILABLE`` flag.
    """
    return _MPI_AVAILABLE
def is_nccl_available() -> bool:
    """
    Check if the NCCL backend is available.

    Returns:
        bool: the module-level ``_NCCL_AVAILABLE`` flag.
    """
    return _NCCL_AVAILABLE
def is_gloo_available() -> bool:
    """
    Check if the Gloo backend is available.

    Returns:
        bool: the module-level ``_GLOO_AVAILABLE`` flag.
    """
    return _GLOO_AVAILABLE
def is_ucc_available() -> bool:
    """
    Check if the UCC backend is available.

    Returns:
        bool: the module-level ``_UCC_AVAILABLE`` flag.
    """
    return _UCC_AVAILABLE
def is_xccl_available() -> bool:
    """
    Check if the XCCL backend is available.

    Returns:
        bool: the module-level ``_XCCL_AVAILABLE`` flag.
    """
    return _XCCL_AVAILABLE
- def _check_single_backend_availability(backend_name: str) -> bool:
- """
- Helper function to check if a single backend is available.
- """
- available_func = getattr(
- torch.distributed, f"is_{str(backend_name).lower()}_available", None
- )
- if available_func:
- return available_func()
- return str(backend_name).lower() in Backend.backend_list
def is_backend_available(backend: str) -> bool:
    """
    Check backend availability.

    Accepts the built-in backends as well as third-party backends registered
    through ``Backend.register_backend``. A composite ``"device:backend"``
    string is available only when every backend it names is available.

    Args:
        backend (str): Backend name.

    Returns:
        bool: Returns true if the backend is available otherwise false.
    """
    if ":" not in backend.lower():
        # Handle simple backend strings like "nccl", "gloo"
        return _check_single_backend_availability(backend)

    # Composite backend like "cpu:gloo": check each mapped backend in turn.
    per_device_backends = BackendConfig(Backend(backend)).get_device_backend_map()
    for mapped_backend in per_device_backends.values():
        if not _check_single_backend_availability(str(mapped_backend)):
            return False
    return True
def is_initialized() -> bool:
    """
    Check if the default process group has been initialized.

    Returns:
        bool: ``True`` when ``GroupMember.WORLD`` is not ``None``.
    """
    return GroupMember.WORLD is not None
def is_torchelastic_launched() -> bool:
    """
    Tell whether this process was started by ``torch.distributed.elastic`` (aka torchelastic).

    The check is simply whether the ``TORCHELASTIC_RUN_ID`` environment
    variable is set. It serves as a proxy because ``TORCHELASTIC_RUN_ID``
    maps to the rendezvous id, which is always a non-null value indicating
    the job id for peer discovery purposes.
    """
    return "TORCHELASTIC_RUN_ID" in os.environ
def _is_barrier_after_init() -> int:
    """Return the integer value of ``TORCH_DIST_INIT_BARRIER`` (default 0).

    The environment variable controls whether a process group performs a
    barrier after its init. The default, 0, means no barrier. If you
    experience issues with this setting, set ``TORCH_DIST_INIT_BARRIER=1``
    to add the barrier.
    """
    raw_setting = os.environ.get("TORCH_DIST_INIT_BARRIER", "0")
    return int(raw_setting)
def _get_default_group() -> ProcessGroup:
    """Return the default process group created by init_process_group.

    Raises:
        ValueError: If the default process group has not been initialized.
    """
    if is_initialized():
        # ``not_none`` is only applied under TYPE_CHECKING; at runtime the
        # group is returned as-is.
        if TYPE_CHECKING:
            return not_none(GroupMember.WORLD)
        else:
            return GroupMember.WORLD
    raise ValueError(
        "Default process group has not been initialized, "
        "please make sure to call init_process_group."
    )
def _get_default_store() -> Store:
    """Return the default store created by init_process_group.

    The store is the second element of the entry recorded for the default
    process group in ``_world.pg_map``.

    Raises:
        ValueError: If the default process group has not been initialized.
    """
    if not is_initialized():
        raise ValueError(
            "Default process group has not been initialized, "
            "please make sure to call init_process_group."
        )
    world_entry = _world.pg_map[_get_default_group()]
    _, store = world_entry
    return store
def _update_default_pg(pg: ProcessGroup | None) -> None:
    """Record ``pg`` as the default process group and publish its rank.

    ``_world.default_pg`` is set to ``pg``. The rank handed to
    ``torch._C._distributed_c10d._set_global_rank`` is ``pg.rank()``, or -1
    when ``pg`` is ``None`` or equals ``GroupMember.NON_GROUP_MEMBER``.
    """
    _world.default_pg = pg
    if pg is not None and pg != GroupMember.NON_GROUP_MEMBER:
        global_rank = pg.rank()
    else:
        global_rank = -1
    torch._C._distributed_c10d._set_global_rank(global_rank)
def get_backend_config(group: ProcessGroup | None = None) -> str:
    """
    Return the backend configuration of the given process group.

    Args:
        group (ProcessGroup, optional): The process group to work on. The
            default is the general main process group. If another specific group
            is specified, the calling process must be part of :attr:`group`.

    Returns:
        The backend configuration of the given process group as a lower case string.

    Raises:
        ValueError: If the calling rank is not part of the process group.
    """
    target_pg = group or _get_default_group()
    if _rank_not_in_group(target_pg):
        raise ValueError("Invalid process group specified")
    recorded_config = _world.pg_backend_config.get(target_pg)
    checked_config = not_none(recorded_config)
    return str(checked_config)
def get_backend(group: ProcessGroup | None = None) -> Backend:
    """
    Return the backend of the given process group.

    Args:
        group (ProcessGroup, optional): The process group to work on. The
            default is the general main process group. If another specific group
            is specified, the calling process must be part of :attr:`group`.

    Returns:
        The backend of the given process group as a lower case string.

    Raises:
        ValueError: If the calling rank is not part of the process group, or
            if the group has no entry in ``_world.pg_map``.
    """
    target_pg = group or _get_default_group()
    if _rank_not_in_group(target_pg):
        raise ValueError("Invalid process group specified")

    map_entry = _world.pg_map.get(target_pg, None)
    if map_entry is None:
        raise ValueError(
            f"Process group {target_pg} is not initialized in the world group map. Please initialize the group first."
        )

    # The first element of the map entry is the backend.
    backend_name = not_none(map_entry)[0]
    return Backend(backend_name)
def get_default_backend_for_device(device: str | torch.device) -> str:
    """
    Return the default backend for the given device.

    Args:
        device (Union[str, torch.device]): The device to get the default backend for.

    Returns:
        The default backend for the given device as a lower case string.

    Raises:
        ValueError: If no default backend is registered for the device type.
    """
    # Strings are parsed through torch.device so that only the type part is used.
    resolved = device if isinstance(device, torch.device) else torch.device(device)
    default_backend = Backend.default_device_backend_map.get(resolved.type)
    if default_backend is not None:
        return default_backend
    raise ValueError(f"Default backend not registered for device : {device}")
def _get_process_group_uid(pg: ProcessGroup) -> int:
    """Return the ``uid`` of the group's CUDA backend, or -1.

    The uid is returned only when NCCL is available and the backend registered
    for ``torch.device("cuda")`` is a ``ProcessGroupNCCL``. A ``RuntimeError``
    raised by the backend lookup is treated the same as having no backend.
    """
    try:
        cuda_backend = pg._get_backend(torch.device("cuda"))
    except RuntimeError:
        cuda_backend = None

    nccl_backed = is_nccl_available() and isinstance(cuda_backend, ProcessGroupNCCL)
    if not nccl_backed:
        return -1
    return cuda_backend.uid
def _get_pg_config(group: ProcessGroup | None = None) -> dict[str, Any]:
    """
    Return the pg configuration of the given process group.

    When ``group`` is falsy the default process group is used. The result has
    the keys ``pg_name``, ``pg_desc``, ``backend_config``, ``pg_size`` and
    ``ranks``.
    """
    target_pg = group or _get_default_group()
    config: dict[str, Any] = {}
    config["pg_name"] = _get_process_group_name(target_pg)
    config["pg_desc"] = target_pg.group_desc
    config["backend_config"] = get_backend_config(target_pg)
    config["pg_size"] = _get_group_size(target_pg)
    config["ranks"] = get_process_group_ranks(target_pg)
    return config
def _get_all_pg_configs() -> list[dict[str, Any]]:
    """
    Return the pg configuration of all the process groups.

    One entry is produced, in iteration order, for every process group that is
    a key of ``_world.pg_map``.
    """
    all_configs: list[dict[str, Any]] = []
    for known_pg in _world.pg_map:
        all_configs.append(_get_pg_config(known_pg))
    return all_configs
def get_pg_count() -> int:
    """
    Return the number of process groups.

    The value is read from ``_world.group_count``.
    """
    pg_count = _world.group_count
    return pg_count
- def get_node_local_rank(fallback_rank: int | None = None) -> int:
- """
- Return the local rank of the current process relative to the node.
- Semantically, this is a useful concept for mapping processes to devices.
- For example, on a node with 8 accelerator you could use the node local rank to decide
- which accelerator device to bind the process to.
- In practice, the actual assignment of node local ranks is handled by the process launcher outside of pytorch,
- and communicated via the `LOCAL_RANK` environment variable.
- Torchrun will automatically populate `LOCAL_RANK`, but other launchers may not. If `LOCAL_RANK` is unspecified,
- this API will fall back to the provided kwarg 'fallback_rank' if specified, otherwise it will raise an error. The
- intent is to allow writing an application that runs either in single or multi device contexts without error.
- """
- if "LOCAL_RANK" in os.environ:
- return int(os.environ["LOCAL_RANK"])
- elif fallback_rank is not None:
- return int(fallback_rank)
- raise RuntimeError(
- "LOCAL_RANK is not in the environment. Consider passing fallback_rank to allow `get_node_local_rank` to work, "
- "assuming you are not running in a multi-device context and want the code to run locally instead."
- )
def _add_ephemeral_timeout_for_all_pgs(timeout: timedelta) -> None:
    """
    Add an ephemeral timeout extension for all PGs locally on one rank.

    The timeout gets reset when the first collective issued after this API is
    called has finished.

    NOTE: We only support setting the timeout for cuda backends for now. In
    this function only process groups whose cuda backend is a
    ``ProcessGroupNCCL`` (with NCCL available) are extended.

    NOTE: While this feature provides flexibility in specific scenarios, it
    introduces statefulness to timeout setting. Therefore, it is advisable to
    use this API sparingly and consider alternative approaches, such as
    directly setting the timeout or utilizing a barrier collective (one can
    set any timeout to the barrier), whenever feasible.

    Args:
        timeout (timedelta): The delta of timeout to extend.

    Returns:
        None.
    """
    for known_pg in _world.pg_map:
        if torch.device("cuda") not in known_pg._device_types:
            continue
        cuda_backend = known_pg._get_backend(torch.device("cuda"))
        if is_nccl_available() and isinstance(cuda_backend, ProcessGroupNCCL):
            cuda_backend._add_ephemeral_timeout(timeout)
def _set_pg_timeout(timeout: timedelta, group: ProcessGroup | None = None) -> None:
    """
    Set the timeout for the given process group when users want to use a different timeout instead of
    default values.

    The timeout is applied to the group's cpu backend when it is a
    ``ProcessGroupGloo``, and to its cuda backend when it is a
    ``ProcessGroupNCCL``, a ``ProcessGroupGloo`` or (with torchcomms enabled) a
    ``_BackendWrapper``. If no such backend is found a warning is issued.

    Args:
        timeout (timedelta): Timeout for operations executed against the process group which
            users want to set. Default value is 10 minutes for NCCL and 30 minutes for other backends.
            This is the duration after which collectives will be aborted asynchronously and the process will crash.
            This is done since CUDA execution is async and it is no longer safe to continue executing user code since
            failed async NCCL operations might result in subsequent CUDA operations running on corrupted data.
            When TORCH_NCCL_BLOCKING_WAIT is set, the process will block and wait for this timeout.
        group (ProcessGroup, optional): The process group to work on. The
            default is the general main process group. If another specific group
            is specified, the calling process must be part of :attr:`group`.

    Returns:
        None

    Raises:
        ValueError: If the calling rank is not part of the process group.
        AssertionError: If ``group`` is not a ``ProcessGroup``.
    """
    if group is None:
        group = _get_default_group()
    if _rank_not_in_group(group):
        raise ValueError("Invalid process group specified")
    if not isinstance(group, ProcessGroup):
        raise AssertionError(f"Expected ProcessGroup, got {type(group)}")

    device_types = group._device_types
    targets = set()

    if torch.device("cpu") in device_types and is_gloo_available():
        cpu_backend = group._get_backend(torch.device("cpu"))
        if isinstance(cpu_backend, ProcessGroupGloo):
            targets.add(cpu_backend)

    if torch.device("cuda") in device_types:
        cuda_backend = group._get_backend(torch.device("cuda"))
        # Evaluated left to right with short-circuiting, like an if/elif chain.
        supported = (
            (is_nccl_available() and isinstance(cuda_backend, ProcessGroupNCCL))
            or (is_gloo_available() and isinstance(cuda_backend, ProcessGroupGloo))
            or (_use_torchcomms_enabled() and isinstance(cuda_backend, _BackendWrapper))
        )
        if supported:
            targets.add(cuda_backend)  # type: ignore[arg-type]

    if not targets:
        warnings.warn(
            "Set timeout is now only supported for either nccl or gloo.", stacklevel=2
        )
    for target in targets:
        target._set_default_timeout(timeout)
- @_exception_logger
- @_time_logger
- def init_process_group(
- backend: str | None = None,
- init_method: str | None = None,
- timeout: timedelta | None = None,
- world_size: int = -1,
- rank: int = -1,
- store: Store | None = None,
- group_name: str = "",
- pg_options: Any | None = None,
- device_id: torch.device | int | None = None,
- _ranks: list[int] | None = None,
- ) -> None:
- """
- Initialize the default distributed process group.
- This will also initialize the distributed package.
- There are 2 main ways to initialize a process group:
- 1. Specify ``store``, ``rank``, and ``world_size`` explicitly.
- 2. Specify ``init_method`` (a URL string) which indicates where/how
- to discover peers. Optionally specify ``rank`` and ``world_size``,
- or encode all required parameters in the URL and omit them.
- If neither is specified, ``init_method`` is assumed to be "env://".
- Args:
- backend (str or Backend, optional): The backend to use. Depending on
- build-time configurations, valid values include ``mpi``, ``gloo``,
- ``nccl``, ``ucc``, ``xccl`` or one that is registered by a third-party
- plugin.
- Since 2.6, if ``backend`` is not provided, c10d will use a backend
- registered for the device type indicated by the `device_id` kwarg
- (if provided). The known default registrations today are: ``nccl``
- for ``cuda``, ``gloo`` for ``cpu``, ``xccl`` for ``xpu``.
- If neither ``backend`` nor ``device_id`` is provided, c10d will
- detect the accelerator on the run-time machine and use a backend
- registered for that detected accelerator (or ``cpu``).
- This field can be given as a lowercase string (e.g., ``"gloo"``),
- which can also be accessed via :class:`Backend` attributes (e.g.,
- ``Backend.GLOO``).
- If using multiple processes per machine with ``nccl`` backend, each
- process must have exclusive access to every GPU it uses, as sharing
- GPUs between processes can result in deadlock or NCCL invalid usage.
- ``ucc`` backend is experimental.
- Default backend for the device can be queried with
- :func:`get_default_backend_for_device`.
- init_method (str, optional): URL specifying how to initialize the
- process group. Default is "env://" if no
- ``init_method`` or ``store`` is specified.
- Mutually exclusive with ``store``.
- world_size (int, optional): Number of processes participating in
- the job. Required if ``store`` is specified.
- rank (int, optional): Rank of the current process (it should be a
- number between 0 and ``world_size``-1).
- Required if ``store`` is specified.
- store(Store, optional): Key/value store accessible to all workers, used
- to exchange connection/address information.
- Mutually exclusive with ``init_method``.
- timeout (timedelta, optional): Timeout for operations executed against
- the process group. Default value is 10 minutes for NCCL and 30 minutes for other backends.
- This is the duration after which collectives will be aborted asynchronously and the process will crash.
- This is done since CUDA execution is async and it is no longer safe to continue executing user code since
- failed async NCCL operations might result in subsequent CUDA operations running on corrupted data.
- When TORCH_NCCL_BLOCKING_WAIT is set, the process will block and wait for this timeout.
- group_name (str, optional, deprecated): Group name. This argument is ignored.
- pg_options (ProcessGroupOptions, optional): process group options
- specifying what additional options need to be passed in during
- the construction of specific process groups. As of now, the only
- option we support is ``ProcessGroupNCCL.Options`` for the ``nccl``
- backend, ``is_high_priority_stream`` can be specified so that
- the nccl backend can pick up high priority cuda streams when
- there are compute kernels waiting. For other available options to configure nccl,
- see https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t
- device_id (torch.device | int, optional): a single, specific device
- this process will work on, allowing for backend-specific
- optimizations. Currently this has two effects, only under
- NCCL: the communicator is immediately formed (calling
- ``ncclCommInit*`` immediately rather than the normal lazy
- call) and sub-groups will use ``ncclCommSplit`` when
- possible to avoid unnecessary overhead of group creation. If you
- want to know NCCL initialization error early, you can also use this
- field. If an `int` is provided, the API assumes that the accelerator
- type at compile time will be used.
- _ranks: The ranks in the process group. If provided, the process
- group name will be the hash of all the ranks in the group.
- Raises:
- ValueError: If the default process group is already initialized, or if
- ``device_id`` is an ``int`` while no accelerator support is found, or if
- ``device_id`` does not match the accelerator type, has no index, or has
- an index that is out of range.
- AssertionError: If both ``init_method`` and ``store`` are specified, or if
- ``store`` is specified with a non-positive ``world_size`` or a negative
- ``rank``.
- .. note:: To enable ``backend == Backend.MPI``, PyTorch needs to be built from source
- on a system that supports MPI.
- .. note:: Support for multiple backends is experimental. Currently when no backend is
- specified, both ``gloo`` and ``nccl`` backends will be created. The ``gloo`` backend
- will be used for collectives with CPU tensors and the ``nccl`` backend will be used
- for collectives with CUDA tensors. A custom backend can be specified by passing in
- a string with format "<device_type>:<backend_name>,<device_type>:<backend_name>", e.g.
- "cpu:gloo,cuda:custom_backend".
- """
- global _world
- global _backend
- global _default_pg_init_method
- # Refuse to initialize again when a default process group already exists.
- if GroupMember.WORLD is not None:
- raise ValueError("trying to initialize the default process group twice!")
- set_pytorch_distributed_envs_from_justknobs()
- # Depending on the import order, some trace_rules functions may be evaluated
- # during the import phase. In such a case, these functions may not correctly
- # add the distributed related rules due to a circular import dependency.
- # We need to clear the lru_cache at runtime to ensure the correctness
- # of these trace_rules.
- #
- # Since this API must be called before any distributed code is compiled,
- # clearing the cache here should be safe.
- if "torch._dynamo" in sys.modules:
- torch._dynamo.trace_rules.clear_lru_cache()
- # ``store`` and ``init_method`` are mutually exclusive.
- if not ((store is None) or (init_method is None)):
- raise AssertionError("Cannot specify both init_method and store.")
- if store is not None:
- if not world_size > 0:
- raise AssertionError("world_size must be positive if using store")
- if not rank >= 0:
- raise AssertionError("rank must be non-negative if using store")
- elif init_method is None:
- init_method = "env://"
- # Get the compile-time accelerator type.
- # None indicates no accelerator support.
- acc = torch.accelerator.current_accelerator()
- # Auto complete device id: an int is combined with the accelerator type
- # to build a full torch.device.
- if isinstance(device_id, int):
- if acc is None:
- raise ValueError(
- "device_id is an int, but no accelerator support is found from the current compilation. "
- "Please use a different compiled version that supports your accelerator."
- )
- device_id = torch.device(acc.type, device_id)
- # Sanity check device_id (cpu devices are not checked here)
- if device_id is not None and device_id.type != "cpu":
- # Type
- if acc is None or device_id.type != acc.type:
- raise ValueError(
- f"device_id {device_id} does not match the current compilation's accelerator support: {acc}. "
- "Please use a different compiled version that supports your accelerator."
- )
- # Index
- if device_id.index is None:
- raise ValueError("Please use a device_id with index.")
- # Range
- if device_id.index >= torch.accelerator.device_count():
- raise ValueError(
- f"device_id {device_id} is out of range. Please use a device index less than "
- f"the number of accelerators available: {torch.accelerator.device_count()}."
- )
- logger.info("Using device: %s", device_id)
- # If user did not provide a backend string but provided a device id, e.g.
- # >>> init_process_group(device_id=device)
- # we try to figure out the backend name based on the device type.
- if backend is None and device_id is not None:
- # Note: 3rd-party devices can register default backend through the
- # default map below.
- backend = Backend.default_device_backend_map.get(device_id.type)
- # If we still cannot figure it out, e.g.
- # >>> init_process_group()
- # we set it to `undefined` and rely on lazy init.
- if backend is None:
- backend = "undefined"
- # Convert string into `Backend` type
- backend = Backend(backend)
- if timeout is None:
- timeout = _get_default_timeout(backend)
- _check_valid_timeout(timeout)
- """
- Group name is not visible to users unless they access
- internals of c10d. This means we can ignore the value
- they provide as it not exposed in a public way.
- """
- if _ranks is None or len(_ranks) == 0:
- group_name = _process_group_name([], use_hashed_name=False)
- else:
- group_name = _process_group_name(_ranks, use_hashed_name=True)
- if backend == Backend.MPI:
- if world_size != -1 or rank != -1:
- warnings.warn(
- f"For MPI backend, world_size ({world_size}) and rank ({rank}) "
- "are ignored since they are assigned by the "
- "MPI runtime.",
- stacklevel=2,
- )
- default_pg, _ = _new_process_group_helper(
- -1,
- -1,
- [],
- backend,
- Store(), # Placeholder value since store cannot be None
- group_name,
- timeout=timeout,
- group_desc="default_pg",
- )
- else:
- # backward compatible API
- if store is None:
- if backend == Backend.FAKE:
- from torch.testing._internal.distributed.fake_pg import FakeStore
- store = FakeStore()
- else:
- rendezvous_iterator = rendezvous(
- not_none(init_method), rank, world_size, timeout=timeout
- )
- store, rank, world_size = next(rendezvous_iterator)
- store.set_timeout(timeout)
- # Use a PrefixStore to avoid accidental overrides of keys used by
- # different systems (e.g. RPC) in case the store is multi-tenant.
- store = PrefixStore("default_pg", store)
- default_pg, _ = _new_process_group_helper(
- world_size,
- rank,
- [],
- backend,
- store,
- group_name,
- backend_options=pg_options,
- timeout=timeout,
- device_id=device_id,
- group_desc="default_pg",
- )
- _update_default_pg(default_pg)
- _world.pg_group_ranks[GroupMember.WORLD] = { # type: ignore[index]
- i: i
- for i in range(GroupMember.WORLD.size()) # type: ignore[attr-defined]
- }
- _backend = _world.pg_map[not_none(GroupMember.WORLD)][0]
- _default_pg_init_method = init_method
- old_hook = sys.excepthook
- excepthook_prefix = f"[rank{get_rank()}]"
- def _distributed_excepthook(*args):
- old_stderr = sys.stderr
- sys.stderr = buf = io.StringIO()
- try:
- old_hook(*args)
- finally:
- sys.stderr = old_stderr
- msg = buf.getvalue()
- msg = "\n".join(
- f"{excepthook_prefix}: {s}" if s != "" else "" for s in msg.split("\n")
- )
- sys.stderr.write(msg)
- sys.stderr.flush()
- sys.excepthook = _distributed_excepthook
- if _is_barrier_after_init() == 1:
- # barrier at the end to ensure that once we return from this method, all
- # process groups including global variables (if any) are updated
- # correctly on all ranks.
- # Update 04/2023: for large-scale runs, this barrier (esp. store-based
- # barrier) may be costly and/or unscalable. Also, in a lot of cases,
- # these barriers may be unnecessary, as proven by a green CI after
- # removal. An environment variable `TORCH_DIST_INIT_BARRIER` has been
- # added which enables this barrier only when set to 1.
- logger.debug(
- "Performing barrier after ProcessGroup initialization since "
- "TORCH_DIST_INIT_BARRIER = 1"
- )
- if backend == Backend.MPI:
- # MPI backend doesn't use store.
- barrier()
- else:
- # Use store based barrier here since barrier() used a bunch of
- # default devices and messes up NCCL internal state.
- _store_based_barrier(rank, store, group_name, world_size, timeout)
def _get_split_source(pg: ProcessGroup):
    """Return the backend of ``pg`` that can be split from, or ``None``.

    The candidate is the backend for ``pg.bound_device_id`` when one is bound.
    Otherwise, for the default process group only, it is the backend for
    ``torch.device("cuda")`` (a ``RuntimeError`` from that lookup leaves no
    candidate). ``None`` is returned when there is no candidate or the
    candidate does not report ``supports_splitting``.
    """
    candidate = None
    if pg.bound_device_id:
        candidate = pg._get_backend(pg.bound_device_id)
    elif pg is _world.default_pg:
        try:
            candidate = pg._get_backend(torch.device("cuda"))
        except RuntimeError:
            # no cuda device associated with this backend
            pass

    if not (candidate and candidate.supports_splitting):
        return None

    # If necessary, find a backend to split from by peeling process
    # group wrappers from our potentially wrapped process group.
    while _GLOO_AVAILABLE and isinstance(candidate, _ProcessGroupWrapper):
        candidate = candidate.wrapped_pg

    return candidate
- def _new_process_group_helper(
- group_size,
- group_rank,
- global_ranks_in_group,
- backend,
- store,
- group_name: GroupName,
- backend_options=None,
- timeout=None,
- pg_tag=None,
- device_id=None,
- group_desc=None,
- ):
- """
- Create a new distributed process group.
- This function must be called by ALL processes in the global group, even if
- the calling process is not part of the newly created group. In that case,
- this function returns GroupMember.NON_GROUP_MEMBER.
- This function is called with ``global_ranks_in_group == []`` for the default group.
- """
- global _world
- if group_name in _world.pg_names.values():
- raise ValueError(
- "The specified group name has already been "
- "created, please use a different group name"
- )
- if device_id is not None and (device_id.index is None or device_id.type == "cpu"):
- raise ValueError(
- "init_process_group device_id parameter must be an accelerator with an index"
- )
- # Note: _new_process_group_helper is only called from init_process_group, which always provides a timeout value
- _check_valid_timeout(timeout)
- if pg_tag not in [None, ""]:
- # creating with the same tag and rank set results in the same underlying PG
- existing_group = _find_pg_by_ranks_and_tag(pg_tag, global_ranks_in_group)
- if existing_group:
- _, prefix_store = _world.pg_map[existing_group]
- return existing_group, prefix_store
- group_desc = "undefined" if group_desc is None else group_desc
- # The list of group ranks is empty if we're creating the default group.
- is_default_group = len(global_ranks_in_group) == 0
- # nccl and potentially other backends allow creation of
- # communicators based on pre-existing ones, which can save
- # initialization time. Due to lazy initialization of
- # communicators in some backends, we have to be careful and only
- # split when we *know* the default PG has already started communicator initialization.
- # We know this if we have bound a device id to the default pg (eager initialized).
- if is_initialized() and _get_default_group().bound_device_id:
- split_from = _get_split_source(_get_default_group())
- else:
- split_from = None
- # If this is a subgroup (which means group_ranks is specified),
- # we check if the current process is a member of the new group.
- if not is_default_group:
- global_rank = _get_default_group().rank()
- if global_rank not in global_ranks_in_group:
- # If we are using `ncclCommSplit` (or similar split from
- # other APIs) to create the communicator, we will need to
- # call `ncclCommSplit` on *all* ranks in this new group's
- # parent group, even those not in the new group. This is
- # a requirement of the NCCL API as otherwise we would get
- # out of sync.
- if split_from:
- split_from.perform_nocolor_split(_get_default_group().bound_device_id)
- return GroupMember.NON_GROUP_MEMBER, None
- prefix_store = PrefixStore(f"{group_name}/", store)
- # The backend for PG will be set later based on what's inside BackendConfig
- # and timeout are set in each backend's option.
- pg: ProcessGroup = ProcessGroup(
- prefix_store,
- group_rank,
- group_size,
- )
- backend_config = BackendConfig(backend)
- # Set the default backend when single backend is passed in.
- if "," not in str(backend) and ":" not in str(backend):
- if backend not in Backend.backend_type_map:
- raise AssertionError(f"Unknown backend type {backend}")
- if backend == Backend.UNDEFINED:
- # Currently when backend is UNDEFINED, only one backend will be initialized
- # we use nccl (if cuda is available) or gloo as default backend
- # so we can correctly call getDefaultBackend which in ProcessGroup.
- if Backend.NCCL in backend_config.get_device_backend_map().values():
- pg._set_default_backend(ProcessGroup.BackendType.NCCL)
- else:
- pg._set_default_backend(ProcessGroup.BackendType.GLOO)
- else:
- pg._set_default_backend(Backend.backend_type_map[backend])
- # In order to correctly call pg._has_hooks(), we should set the default backend
- # when multi backend is passed in
- else:
- if Backend.NCCL in backend_config.device_backend_map.values():
- pg._set_default_backend(ProcessGroup.BackendType.NCCL)
- elif Backend._plugins.keys():
- custom_backend = next(iter(Backend._plugins.keys()))
- if custom_backend in backend_config.device_backend_map.values():
- pg._set_default_backend(ProcessGroup.BackendType.CUSTOM)
- else:
- pg._set_default_backend(ProcessGroup.BackendType.GLOO)
- if device_id:
- pg.bound_device_id = device_id
- backend_class: torch._C._distributed_c10d.Backend
- for device, backend_str in backend_config.get_device_backend_map().items():
- # Use the group name as prefix in the default store, such that
- # a single store can be reused by multiple groups.
- backend_prefix_store = PrefixStore(f"{device}/", prefix_store)
- if _use_torchcomms_enabled() and backend_str not in [Backend.FAKE]:
- torch_device = torch.device(device)
- logger.warning(
- "Using TorchComms backend (enabled via %s) for device %s with backend %s",
- "TORCH_DISTRIBUTED_USE_TORCHCOMMS env var"
- if os.environ.get("TORCH_DISTRIBUTED_USE_TORCHCOMMS")
- else "dist_config.use_torchcomms",
- torch_device,
- backend_str,
- )
- # TODO: figure out pg option conversion for torchComms.
- comm = new_comm(
- backend_str, torch_device, name=group_name, store=backend_prefix_store
- )
- # Keep a reference so the comm outlives this function scope.
- _world.comms.append(comm)
- group_name = GroupName(group_name)
- # pyrefly: ignore [bad-assignment]
- backend_class = _BackendWrapper(comm)
- backend_type = ProcessGroup.BackendType.CUSTOM
- elif backend_str == Backend.MPI:
- if not is_mpi_available():
- raise RuntimeError(
- "Distributed package doesn't have MPI built in."
- " MPI is only included if you build PyTorch from"
- " source on a host that has MPI installed."
- )
- backend_class = ProcessGroupMPI.create(global_ranks_in_group)
- backend_type = ProcessGroup.BackendType.MPI
- if not backend_class:
- return GroupMember.NON_GROUP_MEMBER, None
- # create new process group with accurate rank and size
- if pg.rank() == -1 and pg.size() == -1:
- pg = ProcessGroup(
- backend_prefix_store,
- backend_class.rank(),
- backend_class.size(),
- )
- pg._set_default_backend(backend_type)
- elif backend_str == Backend.GLOO:
- # TODO: remove this check after lazy initialization is supported
- # if pg_options is not None:
- # raise RuntimeError("GLOO options not supported")
- if not is_gloo_available():
- raise RuntimeError("Distributed package doesn't have Gloo built in")
- backend_class = ProcessGroupGloo(
- backend_prefix_store,
- group_rank,
- group_size,
- # pyrefly: ignore [bad-argument-type]
- timeout=timeout,
- )
- backend_class.options.global_ranks_in_group = global_ranks_in_group
- backend_class.options.group_name = group_name
- backend_type = ProcessGroup.BackendType.GLOO
- elif backend_str == Backend.NCCL:
- if not is_nccl_available():
- raise RuntimeError("Distributed package doesn't have NCCL built in")
- if backend_options is not None:
- if not isinstance(backend_options, ProcessGroupNCCL.Options):
- raise AssertionError(
- "Expected backend_options argument to be of type ProcessGroupNCCL.Options"
- )
- if backend_options._timeout != timeout:
- warnings.warn(
- "backend_options._timeout was specified, "
- "but timeout kwarg has a default value that will always override it. ",
- stacklevel=2,
- )
- else:
- # default backend_options for NCCL
- backend_options = ProcessGroupNCCL.Options()
- backend_options.is_high_priority_stream = False
- # pyrefly: ignore [bad-argument-type]
- backend_options._timeout = timeout
- if split_from:
- backend_options.split_from = split_from
- backend_options.split_color = _process_group_color(
- global_ranks_in_group
- )
- backend_options.global_ranks_in_group = global_ranks_in_group
- backend_options.group_name = group_name
- backend_class = ProcessGroupNCCL(
- backend_prefix_store, group_rank, group_size, backend_options
- )
- backend_type = ProcessGroup.BackendType.NCCL
- elif backend_str == Backend.UCC and is_ucc_available():
- # TODO: once UCC plugin is fully deprecated, remove
- # is_ucc_available() from above elif-condition and raise
- # RuntimeError if is_ucc_available() returns false.
- backend_class = ProcessGroupUCC(
- backend_prefix_store,
- group_rank,
- group_size,
- # pyrefly: ignore [bad-argument-type]
- timeout=timeout,
- )
- backend_type = ProcessGroup.BackendType.UCC
- elif backend_str == Backend.XCCL:
- if not is_xccl_available():
- raise RuntimeError("Distributed package doesn't have XCCL built in")
- backend_options = ProcessGroupXCCL.Options()
- backend_options.global_ranks_in_group = global_ranks_in_group
- backend_options.group_name = group_name
- # pyrefly: ignore [bad-argument-type]
- backend_options._timeout = timeout
- backend_class = ProcessGroupXCCL(
- backend_prefix_store, group_rank, group_size, backend_options
- )
- backend_type = ProcessGroup.BackendType.XCCL
- else:
- if backend_str.upper() not in Backend._plugins:
- raise AssertionError(f"Unknown c10d backend type {backend_str.upper()}")
- backend_plugin = Backend._plugins[backend_str.upper()]
- creator_fn = backend_plugin.creator_fn
- extended_api = backend_plugin.extended_api
- backend_type = ProcessGroup.BackendType.CUSTOM
- if not extended_api:
- backend_class = creator_fn(
- backend_prefix_store, group_rank, group_size, timeout
- )
- else:
- dist_backend_opts = _DistributedBackendOptions()
- dist_backend_opts.store = backend_prefix_store
- dist_backend_opts.group_rank = group_rank
- dist_backend_opts.group_size = group_size
- # pyrefly: ignore [bad-argument-type]
- dist_backend_opts.timeout = timeout
- dist_backend_opts.group_id = group_name
- dist_backend_opts.global_ranks_in_group = global_ranks_in_group
- backend_class = creator_fn(dist_backend_opts, backend_options)
- # Set sequence numbers for gloo and nccl backends.
- if backend_str == Backend.GLOO and not _use_torchcomms_enabled():
- if not isinstance(backend_class, ProcessGroupGloo):
- raise AssertionError(
- f"Expected ProcessGroupGloo, got {type(backend_class)}"
- )
- backend_class._set_sequence_number_for_group()
- elif backend_str == Backend.NCCL and not _use_torchcomms_enabled():
- if not isinstance(backend_class, ProcessGroupNCCL):
- raise AssertionError(
- f"Expected ProcessGroupNCCL, got {type(backend_class)}"
- )
- backend_class._set_sequence_number_for_group()
- # If the type is a subclass of ProcessGroup then return this process group immediately
- # TODO: This defaults to the old behavior for PythonProcessGroups which overwrites the
- # ProcessGroup instance
- if issubclass(type(backend_class), ProcessGroup):
- pg = backend_class # type: ignore[assignment]
- break
- # Process group wrapper initialization for supported PGs when TORCH_DISTRIBUTED_DEBUG is set
- if (
- backend_str in [Backend.GLOO, Backend.NCCL, Backend.XCCL, Backend.UCC]
- or backend_str.upper() in Backend._plugins
- ):
- # In debug mode and if GLOO is available, wrap in a wrapper PG that
- # enables enhanced collective checking for debuggability.
- if get_debug_level() == DebugLevel.DETAIL:
- if not _GLOO_AVAILABLE:
- logger.info(
- """TORCH_DISTRIBUTED_DEBUG was set to DETAIL, but
- GLOO is not available. Build with Gloo to
- create a wrapper process group in debug mode
- to aid collective desynchronization debugging."""
- )
- else:
- backend_class = _create_process_group_wrapper(
- wrapped_pg=backend_class,
- store_prefix=group_name,
- store=backend_prefix_store,
- rank=group_rank,
- world_size=group_size,
- # pyrefly: ignore [bad-argument-type]
- timeout=timeout,
- )
- elif (
- get_debug_level() == DebugLevel.DETAIL
- or os.environ.get("TORCH_DISTRIBUTED_DEBUG", "") == "DETAIL"
- ) and _use_torchcomms_enabled():
- logger.warning(
- "(Backend %s) Debug Mode not supported for torchcomms.", backend_str
- )
- # register only a single backend when all get_device_backend_map values are the same
- if len(set(backend_config.get_device_backend_map().values())) == 1:
- for device in backend_config.get_device_backend_map():
- pg._register_backend(torch.device(device), backend_type, backend_class)
- # break out of outer loop to not create any more backends
- break
- pg._register_backend(torch.device(device), backend_type, backend_class)
- # set group_name and group_desc to backend
- if group_name is None:
- raise AssertionError("group_name must not be None")
- if group_desc is None:
- raise AssertionError("group_desc must not be None")
- pg._set_group_name(group_name)
- pg._set_group_desc(group_desc)
- if device_id and pg._get_backend(device_id).supports_splitting:
- eager_backend = pg._get_backend(device_id)
- eager_backend.eager_connect_single_device(device_id)
- # update global state
- _world.pg_map[pg] = (backend, prefix_store)
- _world.pg_names[pg] = group_name
- _register_process_group(group_name, pg)
- _world.pg_backend_config[pg] = str(backend_config)
- # "" is the default tag for user PGs
- if pg_tag in [None, ""]:
- pg_tag = f"ptd:{group_name}"
- _world.tags_to_pg.setdefault("", []).append(pg)
- else:
- pg_tag = f"user:{pg_tag}"
- _world.tags_to_pg.setdefault(pg_tag, []).append(pg)
- _world.pg_to_tag[pg] = pg_tag
- return pg, prefix_store
def destroy_process_group(group: ProcessGroup | None = None):
    """
    Destroy a given process group, and deinitialize the distributed package.

    Args:
        group (ProcessGroup, optional): The process group to be destroyed, if
                                        group.WORLD is given, all process
                                        groups including the default one will
                                        be destroyed.

    Raises:
        AssertionError: if no group is given and there is no default group.
        ValueError: if the group is not present in the global process group map.
    """
    global _world

    # Ranks that are not members of the group have nothing to tear down.
    if group == GroupMember.NON_GROUP_MEMBER:
        return

    if group is None:
        pg = GroupMember.WORLD
    else:
        pg = group

    if pg is None:
        raise AssertionError("Process group cannot be None")
    if _world.pg_map.get(pg, None) is None:
        raise ValueError("Invalid process group specified")

    # When users register Python onCompletion hooks, those hooks will run on a
    # different thread than the main thread. Today, the ProcessGroup dtor does
    # wait for that thread. However, the dtor might finish after the Python
    # Interpreter exits. After that grabbing the GIL for the Python hook will crash.
    # We can either revive the interpreter when running hooks or keep the main one
    # alive until all works and hooks are done. The current implementation does the
    # latter. Therefore, we explicitly call _wait_for_pending_works() here to wait
    # for the pending hooks to finish.
    if type(pg) is ProcessGroup and pg._has_hooks():
        pg._wait_for_pending_works()

    if group is None or group == GroupMember.WORLD:
        for comm in _world.comms:
            comm.finalize()
        _world.comms.clear()
        # shutdown all backends in the order of pg names. shutting down in order because
        # ncclCommAbort() was a 'collective' call in some versions of NCCL.
        for pg_to_shutdown in sorted(
            _world.pg_names, key=lambda x: _world.pg_names[x], reverse=True
        ):
            pg_to_shutdown.shutdown()

        _update_default_pg(None)
        _world.pg_map.clear()
        _world.pg_names.clear()
        _world.pg_group_ranks.clear()
        _world.pg_backend_config.clear()
        _world.pg_to_tag.clear()
        _world.tags_to_pg.clear()
        _world.pg_coalesce_state.clear()
        _unregister_all_process_groups()

        # when process group doesn't have an explicit name (only WORLD (default)
        # process group can have an explicit name), we use global _world.group_count
        # to generate the name. We need to reset the counter on destruction to
        # allow consistent value to be generated when we re-create process
        # groups after some trainers recover from failure
        #
        # We only reset this when WORLD is being destroyed because if this
        # process group is in good state, we aren't dealing with failures.
        _world.group_count = 0
    else:
        if _TORCHCOMM_AVAILABLE:
            for device_type in pg._device_types:
                backend = pg._get_backend(device_type)
                if isinstance(backend, _BackendWrapper):
                    backend.get_comm().finalize()
            # NOTE(review): this empties the whole tracked comm list, not only
            # the comms owned by `pg` -- confirm that is intended.
            _world.comms.clear()
        pg.shutdown()
        del _world.pg_map[pg]
        del _world.pg_names[pg]
        del _world.pg_group_ranks[pg]
        del _world.pg_backend_config[pg]
        if pg in _world.pg_coalesce_state:
            warnings.warn(
                "Some coalesced collectives haven't been launched when "
                "ProcessGroup is destroyed. They will be cleaned.",
                stacklevel=2,
            )
            del _world.pg_coalesce_state[pg]

        # pop() both reads and removes the tag, and tolerates a missing entry
        # (a get() followed by an unconditional del would raise KeyError).
        tag = _world.pg_to_tag.pop(pg, None)
        if tag is not None:
            try:
                _world.tags_to_pg[tag].remove(pg)
                if tag.startswith("ptd:"):
                    _world.tags_to_pg[""].remove(pg)
            except (KeyError, ValueError):
                # Best effort: the tag bucket or the group entry is already gone.
                pass
        _unregister_process_group(pg.group_name)
def _abort_process_group(group: ProcessGroup | None = None):
    """
    Abort a given process group. If group.WORLD (i.e. `None`) is given, all
    process groups including the default one will be aborted.

    Args:
        group (ProcessGroup, optional): The process group to be aborted.

    Raises:
        AssertionError: if no group is given and there is no default group.
        ValueError: if the group is not present in the global process group map.

    .. note:: this API is experimental and currently only works with the NCCL
        backend.

    .. note:: this API should be used with `TORCH_NCCL_ASYNC_ERROR_HANDLING`
        turned off (i.e. set to 0). Otherwise, ProcessGroupNCCL's watchdog may
        automatically handle errors or timeouts for you including aborting the
        ProcessGroup.
    """
    global _world

    # Ranks that are not members of the group have nothing to abort.
    if group == GroupMember.NON_GROUP_MEMBER:
        return

    pg = group or GroupMember.WORLD

    if pg is None:
        raise AssertionError("Process group cannot be None")
    if _world.pg_map.get(pg, None) is None:
        raise ValueError("Invalid process group specified or has been destroyed.")

    # The CUDA backend (if any) is only used below to bracket the aborts in a
    # NCCL group; a group without one simply skips the bracketing.
    try:
        backend = pg._get_backend(torch.device("cuda"))
    except RuntimeError:
        backend = None

    if group is None or group == GroupMember.WORLD:
        # Abort all backends within a ncclGroupStart|End semantic.
        # This ensures that different NCCL communicators' abort calls won't
        # deadlock each other.
        # For details, please see: https://github.com/pytorch/pytorch/issues/119797
        if is_nccl_available() and isinstance(backend, ProcessGroupNCCL):
            backend._group_start()
        for pg_to_abort in sorted(
            _world.pg_names, key=lambda x: _world.pg_names[x], reverse=True
        ):
            pg_to_abort.abort()
        if is_nccl_available() and isinstance(backend, ProcessGroupNCCL):
            backend._group_end()

        _update_default_pg(None)
        _world.pg_map.clear()
        _world.pg_names.clear()
        _world.pg_group_ranks.clear()
        _world.pg_backend_config.clear()
        _world.pg_to_tag.clear()
        _world.tags_to_pg.clear()
        _world.pg_coalesce_state.clear()
        _unregister_all_process_groups()

        # when process group doesn't have an explicit name (only WORLD (default)
        # process group can have an explicit name), we use global _world.group_count
        # to generate the name. We need to reset the counter on destruction to
        # allow consistent value to be generated when we re-create process
        # groups after some trainers recover from failure
        #
        # We only reset this when WORLD is being destroyed because if this
        # process group is in good state, we aren't dealing with failures.
        _world.group_count = 0
    else:
        pg.abort()
        del _world.pg_map[pg]
        del _world.pg_names[pg]
        del _world.pg_group_ranks[pg]
        del _world.pg_backend_config[pg]
        if pg in _world.pg_coalesce_state:
            warnings.warn(
                "Some coalesced collectives haven't been launched when "
                "ProcessGroup is aborted. They will be cleaned.",
                stacklevel=2,
            )
            del _world.pg_coalesce_state[pg]

        # pop() both reads and removes the tag, and tolerates a missing entry
        # (a get() followed by an unconditional del would raise KeyError).
        tag = _world.pg_to_tag.pop(pg, None)
        if tag is not None:
            try:
                _world.tags_to_pg[tag].remove(pg)
                if tag.startswith("ptd:"):
                    _world.tags_to_pg[""].remove(pg)
            except (KeyError, ValueError):
                # Best effort: the tag bucket or the group entry is already gone.
                pass
        _unregister_process_group(pg.group_name)
def get_rank(group: ProcessGroup | None = None) -> int:
    """
    Return the rank of the current process in ``group`` (the default group if omitted).

    Rank is a unique identifier assigned to each process within a distributed
    process group. Ranks are consecutive integers ranging from 0 to
    ``world_size - 1``.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    Returns:
        The rank of the calling process within the process group
        -1, if not part of the group

    """
    if _rank_not_in_group(group):
        return -1

    # The rank in the default group doubles as the global rank that is
    # translated into a group-local rank below.
    global_rank = _get_default_group().rank()
    targets_world = group is None or group is GroupMember.WORLD
    return global_rank if targets_world else get_group_rank(group, global_rank)
def get_world_size(group: ProcessGroup | None = None) -> int:
    """
    Return the number of processes in the given process group.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    Returns:
        The world size of the process group
        -1, if not part of the group

    """
    # Non-members get the -1 sentinel instead of a size.
    return -1 if _rank_not_in_group(group) else _get_group_size(group)
def isend(
    tensor: torch.Tensor,
    dst: int | None = None,
    group: ProcessGroup | None = None,
    tag: int = 0,
    group_dst: int | None = None,
) -> Work | None:
    """
    Send a tensor asynchronously.

    .. warning::
        Modifying ``tensor`` before the request completes causes undefined
        behavior.

    .. warning::
        ``tag`` is not supported with the NCCL backend.

    Unlike send, which is blocking, isend allows src == dst rank, i.e. send to self.

    Args:
        tensor (Tensor): Tensor to send.
        dst (int): Destination rank on global process group (regardless of ``group`` argument)
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        tag (int, optional): Tag to match send with remote recv
        group_dst (int, optional): Destination rank on ``group``. Invalid to specify both ``dst`` and ``group_dst``

    Returns:
        A distributed request object.
        None, if not part of the group

    """
    group = _group_or_default_group(group)
    group_dst = _canonicalize_group_rank(group, dst, group_dst)
    _check_single_tensor(tensor, "tensor")

    if _rank_not_in_group(group):
        _warn_not_in_group("isend")
        return None

    # Complex tensors are handed to the backend as their real-valued view.
    payload = torch.view_as_real(tensor) if tensor.is_complex() else tensor
    return group.send([payload], group_dst, tag)
def irecv(
    tensor: torch.Tensor,
    src: int | None = None,
    group: ProcessGroup | None = None,
    tag: int = 0,
    group_src: int | None = None,
) -> Work | None:
    """
    Receives a tensor asynchronously.

    .. warning::
        ``tag`` is not supported with the NCCL backend.

    Unlike recv, which is blocking, irecv allows src == dst rank, i.e. recv from self.

    Args:
        tensor (Tensor): Tensor to fill with received data.
        src (int, optional): Source rank on global process group (regardless of ``group`` argument).
            Will receive from any process if unspecified.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        tag (int, optional): Tag to match recv with remote send
        group_src (int, optional): Source rank on ``group``. Invalid to specify both ``src`` and ``group_src``.

    Returns:
        A distributed request object.
        None, if not part of the group

    """
    _check_single_tensor(tensor, "tensor")
    if _rank_not_in_group(group):
        _warn_not_in_group("irecv")
        return None

    # Complex tensors are received through their real-valued view.
    buffer = torch.view_as_real(tensor) if tensor.is_complex() else tensor

    group = _group_or_default_group(group)
    if src is not None or group_src is not None:
        group_src = _canonicalize_group_rank(group, src, group_src)
        return group.recv([buffer], group_src, tag)

    # Neither form of source rank was given: accept data from any peer.
    return group.recv_anysource([buffer], tag)
@_exception_logger
def send(
    tensor: torch.Tensor,
    dst: int | None = None,
    group: ProcessGroup | None = None,
    tag: int = 0,
    group_dst: int | None = None,
) -> None:
    """
    Send a tensor synchronously.

    .. warning::
        ``tag`` is not supported with the NCCL backend.

    Args:
        tensor (Tensor): Tensor to send.
        dst (int): Destination rank on global process group (regardless of ``group`` argument).
            Destination rank should not be the same as the rank of the current process.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        tag (int, optional): Tag to match send with remote recv
        group_dst (int, optional): Destination rank on ``group``.  Invalid to specify both ``dst`` and ``group_dst``.

    """
    pg = _group_or_default_group(group)
    peer = _canonicalize_group_rank(pg, dst, group_dst)
    # The blocking variant rejects sending to the calling rank itself.
    _check_not_self_rank(pg, peer, "destination")

    request = isend(tensor, group=pg, tag=tag, group_dst=peer)
    if request is None:
        # isend returns None when this rank is not part of the group.
        return
    request.wait()
@_exception_logger
def recv(
    tensor: torch.Tensor,
    src: int | None = None,
    group: ProcessGroup | None = None,
    tag: int = 0,
    group_src: int | None = None,
) -> int:
    """
    Receives a tensor synchronously.

    .. warning::
        ``tag`` is not supported with the NCCL backend.

    Args:
        tensor (Tensor): Tensor to fill with received data.
        src (int, optional): Source rank on global process group (regardless of ``group`` argument).
            Will receive from any process if unspecified.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        tag (int, optional): Tag to match recv with remote send
        group_src (int, optional): Source rank on ``group``. Invalid to specify both ``src`` and ``group_src``.

    Returns:
        Sender rank
        -1, if not part of the group

    """
    request = irecv(tensor, src=src, group=group, tag=tag, group_src=group_src)
    if request is None:
        return -1
    request.wait()

    # A caller-supplied global source rank is returned unchanged.
    if src is not None:
        return src

    if group_src is None:
        # Receive-from-anyone: ask the completed request who the sender was.
        group_src = request._source_rank()
    pg = _group_or_default_group(group)
    _check_not_self_rank(pg, group_src, "source")
    return get_global_rank(pg, group_src)
class _IllegalWork(Work):
    """Placeholder ``Work`` whose result-bearing attributes must not be accessed."""

    def __getattribute__(self, name):
        # Accessing any of these names raises; every other attribute lookup
        # evaluates to None because nothing is delegated to the base class.
        forbidden = (
            "is_success",
            "exception",
            "wait",
            "source_rank",
            "_source_rank",
            "result",
            "synchronize",
        )
        if name not in forbidden:
            return None
        raise ValueError(f"Illegal to call {name} on IllegalWork object")
- class _CoalescingManager:
- def __init__(self) -> None:
- self.works: list[Work] = []
- def append(self, work: Work | None = None):
- if work:
- self.works.append(work)
- def wait(self):
- for work in self.works:
- work.wait()
@contextlib.contextmanager
def _coalescing_manager(
    group: ProcessGroup | None = None,
    device: torch.device | None = None,
    async_ops: bool = False,
):
    """
    Context manager used to coalesce collectives or P2P operations when possible.

    Args:
        group (`ProcessGroup`, optional): The process group to work on. If None,
            the default process group will be used.
        device (`torch.device`, optional): Default is None, set to a device if
            there isn't a `**_coalesced` implementation by the backend.
        async_ops (`bool`, optional): whether the coalesced ops are async ops.

    Yields:
        A ``_CoalescingManager``. With ``async_ops=True`` the resulting work
        handle is appended to it when the context exits.

    Raises:
        ValueError: if the group already has recorded ops when the context is entered.
        AssertionError: if an op without fast-path support was recorded.

    Examples:
        >>> # xdoctest: +SKIP("no rank")
        >>> # Synchronous ops
        >>> with _coalescing_manager():
        >>>     for i in range(num_colls):
        >>>         dist.all_reduce(tensors[i])
        >>> # Asynchronous ops
        >>> with _coalescing_manager(async_ops=True) as cm:
        >>>     for i in range(num_colls):
        >>>         dist.all_reduce(tensors[i])
        >>> cm.wait()

    .. warning::
       :func:`_coalescing_manager` does not currently support coalescing
       all-reduces with different reduce operators, e.g.  `ReduceOp.SUM` mixed
       with `ReduceOp.PRODUCT`.
    """
    group = group or _get_default_group()
    op_list = _world.pg_coalesce_state.setdefault(group, [])
    if op_list:
        raise ValueError(
            "ProcessGroup has non-empty op list at the start of coalescing"
        )
    if device:
        group._start_coalescing(device)
    cm = _CoalescingManager()
    try:
        yield cm
    except BaseException:
        # Drop the recorded ops if the body failed; otherwise the group would
        # stay marked as coalescing and later collectives on it would be
        # recorded instead of launched.
        _world.pg_coalesce_state.pop(group, None)
        raise
    work = None
    op_list = _world.pg_coalesce_state.pop(group)
    if op_list:
        # Collectives supporting "Fast Path" coalescing are captured.
        # See implementation in corresponding collective APIs.
        # Currently supported:
        # - coalesced `all_reduce`
        # - coalesced `all_gather_into_tensor`
        # - coalesced `reduce_scatter_tensor`
        # The first recorded op decides which coalesced primitive is issued
        # (and, for reductions, which reduce op is used).
        op0 = op_list[0].op
        if op0 is all_reduce:
            tensors = [op.tensor for op in op_list]
            all_reduce_opts = AllreduceCoalescedOptions()
            all_reduce_opts.reduceOp = not_none(op_list[0].redop)
            all_reduce_opts.asyncOp = async_ops
            work = group.allreduce_coalesced(tensors, all_reduce_opts)
        elif op0 is all_gather_into_tensor:
            inputs = []
            outputs = []
            for op in op_list:
                inputs.append(op.tensor)
                outputs.append(not_none(op.dst_tensor))
            all_gather_opts = AllgatherOptions()
            all_gather_opts.asyncOp = async_ops
            # Pass the options so `async_ops` is honoured, as in the other branches.
            work = group.allgather_into_tensor_coalesced(
                outputs, inputs, all_gather_opts
            )
        elif op0 is reduce_scatter_tensor:
            inputs = []
            outputs = []
            for op in op_list:
                inputs.append(op.tensor)
                outputs.append(not_none(op.dst_tensor))
            reduce_opts = ReduceScatterOptions()
            reduce_opts.reduceOp = not_none(op_list[0].redop)
            reduce_opts.asyncOp = async_ops
            work = group.reduce_scatter_tensor_coalesced(outputs, inputs, reduce_opts)
        else:
            raise AssertionError(
                f"Coalescing manager does not support fast-path coalescing of {op0}, "
                f"yet {op0} is still recorded in op list. This is an internal error of c10d."
            )

    if device:
        # Old style of letting each coll inside the context manager to call into C++ counterpart via python binding
        work = group._end_coalescing(device)

    if async_ops:
        cm.append(work)
    elif (
        work is not None
    ):  # Backward compatible with backends that don't sync at CPP level
        work.wait()
    # Otherwise, the backend has sync'ed at CPP level
class _TimeEstimator:
    """Result holder yielded by the ``_time_estimator`` context manager."""

    def __init__(self) -> None:
        # Stays None until the context manager stores the backend's estimate.
        self.estimated_time: float | None = None
@contextlib.contextmanager
def _time_estimator(
    group: ProcessGroup | None = None,
    device: torch.device | None = None,
):
    """
    Context manager used to estimate time of collectives.
    Within the context manager, nothing is actually run and the backend just simulates
    the collective time only.

    Args:
        group (`ProcessGroup`, optional): The process group to work on. If None,
            the default process group will be used.
        device (`torch.device`, optional): Device whose backend performs the
            estimate. If None, the default device of ``group`` is used.

    Yields:
        A ``_TimeEstimator`` whose ``estimated_time`` is filled in when the
        context exits.

    Raises:
        NotImplementedError: if the backend does not support time estimation.

    Examples:
        >>> # xdoctest: +SKIP("no rank")
        >>> # Synchronous ops
        >>> with _time_estimator() as cm:
        >>>     for i in range(num_colls):
        >>>         dist.all_reduce(tensors[i])
        >>> # estimate time is stored in cm.estimated_time

    .. warning::
       :func:`_time_estimator` currently only support NCCL backend but it can
       easily be extended to other backends.

       Also a NCCL communicator needs to be created because only with a real communicator can we do accurate estimation.
       The communicator internally has knowledge about the links it runs on
       (e.g. intra-node or inter-node, whether the links are NVLink or PCI-e or IB).
    """
    # TODO: We need to also support torch inductor for the time estimator.
    pg = group or _get_default_group()
    target_device = device or _get_pg_default_device(pg)
    estimating_backend = pg._get_backend(target_device)
    if not estimating_backend.supports_time_estimate:
        raise NotImplementedError(
            f"collective time estimator is not supported in the current version of backend {estimating_backend}"
        )

    estimating_backend._start_time_estimate()  # type: ignore[attr-defined]
    estimator = _TimeEstimator()
    yield estimator
    # Reached only when the body finished without raising.
    estimator.estimated_time = estimating_backend._end_time_estimate()  # type: ignore[attr-defined]
def batch_isend_irecv(p2p_op_list: list[P2POp]) -> list[Work]:
    """
    Send or Receive a batch of tensors asynchronously and return a list of requests.

    Process each of the operations in ``p2p_op_list`` and return the corresponding
    requests. NCCL, Gloo, and UCC backend are currently supported.

    Args:
        p2p_op_list: A list of point-to-point operations(type of each operator is
            ``torch.distributed.P2POp``). The order of the isend/irecv in the list
            matters and it needs to match with corresponding isend/irecv on the
            remote end.

    Returns:
        A list of distributed request objects returned by calling the corresponding
        op in the op_list.

    Examples:
        >>> # xdoctest: +SKIP("no rank")
        >>> send_tensor = torch.arange(2, dtype=torch.float32) + 2 * rank
        >>> recv_tensor = torch.randn(2, dtype=torch.float32)
        >>> send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1) % world_size)
        >>> recv_op = dist.P2POp(
        ...     dist.irecv, recv_tensor, (rank - 1 + world_size) % world_size
        ... )
        >>> reqs = batch_isend_irecv([send_op, recv_op])
        >>> for req in reqs:
        >>>     req.wait()
        >>> recv_tensor
        tensor([2, 3])     # Rank 0
        tensor([0, 1])     # Rank 1

    .. note:: Note that when this API is used with the NCCL PG backend, users must set
        the current GPU device with `torch.cuda.set_device`, otherwise it will
        lead to unexpected hang issues.

        In addition, if this API is the first collective call in the ``group``
        passed to ``dist.P2POp``, all ranks of the ``group`` must participate in
        this API call; otherwise, the behavior is undefined. If this API call is
        not the first collective call in the ``group``, batched P2P operations
        involving only a subset of ranks of the ``group`` are allowed.
    """
    _check_p2p_op_list(p2p_op_list)

    # The first op decides which group and device the whole batch uses.
    first_op = p2p_op_list[0]
    group = first_op.group if first_op.group is not None else _get_default_group()
    device = first_op.tensor.device

    def launch(p2p_op: P2POp):
        # isend takes its peer as `group_dst`, every other op as `group_src`.
        peer_key = "group_dst" if p2p_op.op is isend else "group_src"
        return p2p_op.op(
            p2p_op.tensor,
            group=p2p_op.group,
            tag=p2p_op.tag,
            **{peer_key: p2p_op.group_peer},
        )

    if type(group) is ProcessGroup and group._get_backend(device).supports_coalescing:
        # NCCL style coalescing
        with _coalescing_manager(group, device, async_ops=True) as cm:
            for p2p_op in p2p_op_list:
                launch(p2p_op)
        # The manager records the coalesced work when the context exits.
        return cm.works

    # backend not support coalescing: issue the ops one by one and keep the
    # truthy request handles.
    reqs = []
    for p2p_op in p2p_op_list:
        request = launch(p2p_op)
        if request:
            reqs.append(request)
    return reqs
@_exception_logger
def broadcast(
    tensor: torch.Tensor,
    src: int | None = None,
    group: ProcessGroup | None = None,
    async_op: bool = False,
    group_src: int | None = None,
):
    """
    Broadcasts the tensor to the whole group.

    ``tensor`` must have the same number of elements in all processes
    participating in the collective.

    Args:
        tensor (Tensor): Data to be sent if ``src`` is the rank of current
            process, and tensor to be used to save received data otherwise.
        src (int): Source rank on global process group (regardless of ``group`` argument).
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op
        group_src (int): Source rank on ``group``.  Must specify one of ``group_src``
            and ``src`` but not both.

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    """
    group = _group_or_default_group(group)
    group_src = _canonicalize_group_rank(group, src, group_src, return_global=False)
    _check_single_tensor(tensor, "tensor")
    if _rank_not_in_group(group):
        _warn_not_in_group("broadcast")
        return None

    options = BroadcastOptions()
    options.rootRank = group_src
    options.rootTensor = 0
    options.asyncOp = async_op

    # Complex tensors are broadcast through their real-valued view.
    payload = torch.view_as_real(tensor) if tensor.is_complex() else tensor
    handle = group.broadcast([payload], options)

    if async_op:
        return handle
    # Backward compatible with backends that don't sync at CPP level
    if handle is not None:
        handle.wait()
    # Otherwise, the backend has sync'ed at CPP level
    return None
@_exception_logger
def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op: bool = False):
    """
    Reduces the tensor data across all machines in a way that all get the final result.

    After the call ``tensor`` is going to be bitwise identical in all processes.

    Complex tensors are supported.

    Args:
        tensor (Tensor): Input and output of the collective. The function
            operates in-place.
        op (optional): One of the values from
            ``torch.distributed.ReduceOp``
            enum.  Specifies an operation used for element-wise reductions.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    Raises:
        ValueError: if ``tensor`` is complex and ``op`` does not support complex tensors.

    Examples:
        >>> # xdoctest: +SKIP("no rank")
        >>> # All tensors below are of torch.int64 type.
        >>> # We have 2 process groups, 2 ranks.
        >>> device = torch.device(f"cuda:{rank}")
        >>> tensor = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank
        >>> tensor
        tensor([1, 2], device='cuda:0') # Rank 0
        tensor([3, 4], device='cuda:1') # Rank 1
        >>> dist.all_reduce(tensor, op=ReduceOp.SUM)
        >>> tensor
        tensor([4, 6], device='cuda:0') # Rank 0
        tensor([4, 6], device='cuda:1') # Rank 1

        >>> # All tensors below are of torch.cfloat type.
        >>> # We have 2 process groups, 2 ranks.
        >>> tensor = torch.tensor(
        ...     [1 + 1j, 2 + 2j], dtype=torch.cfloat, device=device
        ... ) + 2 * rank * (1 + 1j)
        >>> tensor
        tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0
        tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1
        >>> dist.all_reduce(tensor, op=ReduceOp.SUM)
        >>> tensor
        tensor([4.+4.j, 6.+6.j], device='cuda:0') # Rank 0
        tensor([4.+4.j, 6.+6.j], device='cuda:1') # Rank 1

    """
    # Dynamo has built-in logic to map legacy distributed ops to functional collectives.
    # Let's redirect to a torch function mode that can mimic this logic outside Dynamo
    # (e.g., non-strict export implements such a torch function mode).
    relevant_args = (tensor,)
    if has_torch_function(relevant_args):
        return handle_torch_function(
            all_reduce,
            relevant_args,
            tensor,
            op=op,
            group=group,
            async_op=async_op,
        )

    _check_single_tensor(tensor, "tensor")
    if _rank_not_in_group(group):
        _warn_not_in_group("all_reduce")
        return None

    if tensor.is_complex():
        if not supports_complex(op):
            raise ValueError(f"all_reduce does not support {op} on complex tensors")
        # Reduce the real-valued view of the complex tensor.
        tensor = torch.view_as_real(tensor)

    options = AllreduceOptions()
    options.reduceOp = op
    options.asyncOp = async_op

    if group is None:
        group = _get_default_group()

    if group in _world.pg_coalesce_state:
        # We are in coalescing context, do not issue single operation, just append a collective representation
        _world.pg_coalesce_state[group].append(
            _CollOp(all_reduce, tensor, None, op, None)
        )
        return _IllegalWork() if async_op else None

    handle = group.allreduce([tensor], options)

    if async_op:
        return handle
    # Backward compatible with backends that don't sync at CPP level
    if handle is not None:
        handle.wait()
    # Otherwise, the backend has sync'ed at CPP level
    return None
@_exception_logger
@deprecated(
    "`torch.distributed.all_reduce_coalesced` will be deprecated. If you must "
    "use it, please revisit our documentation later at "
    "https://pytorch.org/docs/main/distributed.html#collective-functions",
    category=FutureWarning,
)
def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op: bool = False):
    """
    WARNING: at this time individual shape checking is not implemented across nodes.

    For example, if the rank 0 node passes [torch.rand(4), torch.rand(2)] and the
    rank 1 node passes [torch.rand(2), torch.rand(2), torch.rand(2)], the allreduce
    operation will proceed without complaint and return erroneous outputs. This lack
    of shape checking results in significant performance improvements but users of this
    function should take extra care to ensure that each node passes in tensors whose
    shapes match across nodes.

    Reduces each tensor in tensors (residing on the same device) across all machines
    in such a way that all get the final result.

    After the call each tensor in tensors is going to be bitwise identical
    in all processes.

    Complex tensors are supported.

    Args:
        tensors (Union[List[Tensor], Tensor]): Input and output of the collective.
            The function operates in-place.
        op (Optional[ReduceOp]): One of the values from
            ``torch.distributed.ReduceOp`` enum. Specifies an operation used for
            element-wise reductions.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (Optional[bool]): Whether this op should be an async op.

    Returns:
        The future obtained from the work handle, if async_op is set to True.
        None, if not async_op or if not part of the group.

    Raises:
        ValueError: if any tensor is complex and ``op`` does not support complex tensors.

    """
    # A lone tensor is treated as a one-element list.
    if isinstance(tensors, torch.Tensor):
        tensors = [tensors]
    _check_tensor_list(tensors, "tensor")
    _ensure_all_tensors_same_dtype(tensors)
    if _rank_not_in_group(group):
        _warn_not_in_group("all_reduce_coalesced")
        return None

    has_complex = any(t.is_complex() for t in tensors)
    if has_complex and not supports_complex(op):
        raise ValueError(f"all_reduce does not support {op} on complex tensors")

    # Complex entries are reduced through their real-valued views.
    tensors = [torch.view_as_real(t) if t.is_complex() else t for t in tensors]

    options = AllreduceCoalescedOptions()
    options.reduceOp = op
    options.asyncOp = async_op
    pg = group or _get_default_group()
    handle = pg.allreduce_coalesced(tensors, options)

    if async_op:
        return handle.get_future()
    # Backward compatible with backends that don't sync at CPP level
    if handle is not None:
        handle.wait()
    # Otherwise, the backend has sync'ed at CPP level
    return None
@_exception_logger
def reduce(
    tensor: torch.Tensor,
    dst: int | None = None,
    op=ReduceOp.SUM,
    group: ProcessGroup | None = None,
    async_op: bool = False,
    group_dst: int | None = None,
):
    """
    Reduces the tensor data across all machines.

    Only the process with rank ``dst`` is going to receive the final result.

    Args:
        tensor (Tensor): Input and output of the collective. The function
            operates in-place.
        dst (int): Destination rank on global process group (regardless of ``group`` argument)
        op (optional): One of the values from
            ``torch.distributed.ReduceOp``
            enum.  Specifies an operation used for element-wise reductions.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op
        group_dst (int): Destination rank on ``group``.  Must specify one of ``group_dst``
            and ``dst`` but not both.

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    """
    group = _group_or_default_group(group)
    group_dst = _canonicalize_group_rank(group, dst, group_dst, return_global=False)
    _check_single_tensor(tensor, "tensor")
    if _rank_not_in_group(group):
        _warn_not_in_group("reduce")
        return None

    options = ReduceOptions()
    options.reduceOp = op
    options.rootRank = group_dst
    options.asyncOp = async_op
    handle = group.reduce([tensor], options)

    if async_op:
        return handle
    # Backward compatible with backends that don't sync at CPP level
    if handle is not None:
        handle.wait()
    # Otherwise, the backend has sync'ed at CPP level
    return None
def _object_to_tensor(obj, device, group):
    """
    Serialize ``obj`` into a byte tensor placed on ``device``.

    Returns:
        A ``(byte_tensor, local_size)`` pair, where ``local_size`` is a
        one-element ``LongTensor`` on ``device`` holding the byte tensor's
        number of elements.
    """
    with _WaitCounter("pytorch.wait_counter.c10d._object_to_tensor").guard():
        stream = io.BytesIO()
        _pickler(stream).dump(obj)
        storage = torch.ByteStorage._from_buffer(stream.getvalue())  # type: ignore[attr-defined]
        # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype.
        # Otherwise, it will cause 100X slowdown.
        # See: https://github.com/pytorch/pytorch/issues/65696
        payload = torch.ByteTensor(storage).to(device)

        # In DETAIL debug mode with NCCL, log the payload size and hash.
        if get_debug_level() == DebugLevel.DETAIL and is_nccl_available():
            if get_backend(group) == Backend.NCCL:
                payload_hash = torch._C._distributed_c10d._hash_tensors([payload])
                logger.warning(
                    "_object_to_tensor size: %s hash value: %s",
                    payload.numel(),
                    payload_hash,
                )

        payload_size = torch.LongTensor([payload.numel()]).to(device)
        return payload, payload_size
def _tensor_to_object(tensor, tensor_size, group):
    """Unpickle the object stored in the first ``tensor_size`` bytes of ``tensor``.

    Args:
        tensor: Byte tensor holding pickled data, possibly followed by padding.
        tensor_size: Number of leading bytes that belong to the pickle stream;
            used as the upper bound of a slice.
        group: Process group; only consulted to look up the backend when
            detail-level debug logging is active.

    Returns:
        The object produced by the module's unpickler.
    """
    with _WaitCounter("pytorch.wait_counter.c10d._tensor_to_object").guard():
        detail_debug = get_debug_level() == DebugLevel.DETAIL and is_nccl_available()
        if detail_debug and get_backend(group) == Backend.NCCL:
            tensor_hash = torch._C._distributed_c10d._hash_tensors([tensor])
            logger.warning(
                "_tensor_to_object size: %s hash value: %s", tensor.numel(), tensor_hash
            )

        # Move to host memory first, then drop any trailing padding bytes.
        raw_bytes = tensor.cpu().numpy().tobytes()
        return _unpickler(io.BytesIO(raw_bytes[:tensor_size])).load()
@_exception_logger
def all_gather_object(object_list, obj, group=None):
    """
    Gathers picklable objects from the whole group into a list.

    Similar to :func:`all_gather`, but Python objects can be passed in.
    Note that the object must be picklable in order to be gathered.

    Args:
        object_list (list[Any]): Output list. It should be correctly sized as the
            size of the group for this collective and will contain the output.
        obj (Any): Picklable Python object to be broadcast from current process.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Default is ``None``.

    Returns:
        None. If the calling rank is part of this group, the output of the
        collective will be populated into the input ``object_list``. If the
        calling rank is not part of the group, the passed in ``object_list`` will
        be unmodified.

    .. note:: Note that this API differs slightly from the :func:`all_gather`
        collective since it does not provide an ``async_op`` handle and thus
        will be a blocking call.

    .. note:: For NCCL-based process groups, internal tensor representations
        of objects must be moved to the GPU device before communication takes
        place. In this case, the device used is given by
        ``torch.cuda.current_device()`` and it is the user's responsibility to
        ensure that this is set so that each rank has an individual GPU, via
        ``torch.cuda.set_device()``.

    .. warning::
        Object collectives have a number of serious performance and scalability
        limitations.  See :ref:`object_collectives` for details.

    .. warning::
        :func:`all_gather_object` uses ``pickle`` module implicitly, which is
        known to be insecure. It is possible to construct malicious pickle data
        which will execute arbitrary code during unpickling. Only call this
        function with data you trust.

    .. warning::
        Calling :func:`all_gather_object` with GPU tensors is not well supported
        and inefficient as it incurs GPU -> CPU transfer since tensors would be
        pickled. Please consider using :func:`all_gather` instead.

    Example::
        >>> # xdoctest: +SKIP("need process group init")
        >>> # Note: Process group initialization omitted on each rank.
        >>> import torch.distributed as dist
        >>> # Assumes world_size of 3.
        >>> gather_objects = ["foo", 12, {1: 2}]  # any picklable object
        >>> output = [None for _ in gather_objects]
        >>> dist.all_gather_object(output, gather_objects[dist.get_rank()])
        >>> output
        ['foo', 12, {1: 2}]
    """
    if _rank_not_in_group(group):
        _warn_not_in_group("all_gather_object")
        return

    device = _get_object_coll_device(group)
    payload, payload_size = _object_to_tensor(obj, device, group)

    # Step 1: exchange the serialized sizes so every rank knows the largest
    # payload (for padding) and each rank's true length (for trimming).
    world = get_world_size(group=group)
    sizes = torch.zeros(world, dtype=torch.long, device=device)
    per_rank_sizes = [sizes[rank].unsqueeze(dim=0) for rank in range(world)]
    all_gather(per_rank_sizes, payload_size, group=group)
    largest = int(max(per_rank_sizes).item())  # type: ignore[type-var]

    # Step 2: pad the local payload to the common size and gather every
    # payload into non-overlapping views of one flat buffer.
    payload.resize_(largest)
    gathered = torch.empty(largest * world, dtype=torch.uint8, device=device)
    slots = [gathered[largest * rank : largest * (rank + 1)] for rank in range(world)]
    all_gather(slots, payload, group=group)

    # Step 3: turn each rank's bytes back into an object, trimming the padding.
    for rank, (slot, size) in enumerate(zip(slots, per_rank_sizes)):
        object_list[rank] = _tensor_to_object(slot.type(torch.uint8), size, group)
@_exception_logger
def gather_object(
    obj: Any,
    object_gather_list: list[Any] | None = None,
    dst: int | None = None,
    group: ProcessGroup | None = None,
    group_dst: int | None = None,
):
    """
    Gathers picklable objects from the whole group in a single process.

    Similar to :func:`gather`, but Python objects can be passed in. Note that the
    object must be picklable in order to be gathered.

    Args:
        obj (Any): Input object. Must be picklable.
        object_gather_list (list[Any]): Output list. On the ``dst`` rank, it
            should be correctly sized as the size of the group for this
            collective and will contain the output. Must be ``None`` on non-dst
            ranks. (default is ``None``)
        dst (int, optional): Destination rank on global process group (regardless of ``group`` argument).
            (If both ``dst`` and ``group_dst`` are None, default is global rank 0)
        group: (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Default is ``None``.
        group_dst (int, optional): Destination rank on ``group``.  Invalid to specify both ``dst`` and ``group_dst``

    Returns:
        None. On the ``dst`` rank, ``object_gather_list`` will contain the
        output of the collective.

    Raises:
        AssertionError: If ``object_gather_list`` is ``None`` on the ``dst`` rank
            once the gather has completed.

    .. note:: Note that this API differs slightly from the gather collective
        since it does not provide an async_op handle and thus will be a blocking
        call.

    .. note:: For NCCL-based process groups, internal tensor representations
        of objects must be moved to the GPU device before communication takes
        place. In this case, the device used is given by
        ``torch.cuda.current_device()`` and it is the user's responsibility to
        ensure that this is set so that each rank has an individual GPU, via
        ``torch.cuda.set_device()``.

    .. warning::
        Object collectives have a number of serious performance and scalability
        limitations.  See :ref:`object_collectives` for details.

    .. warning::
        :func:`gather_object` uses ``pickle`` module implicitly, which is
        known to be insecure. It is possible to construct malicious pickle data
        which will execute arbitrary code during unpickling. Only call this
        function with data you trust.

    .. warning::
        Calling :func:`gather_object` with GPU tensors is not well supported
        and inefficient as it incurs GPU -> CPU transfer since tensors would be
        pickled. Please consider using :func:`gather` instead.

    Example::
        >>> # xdoctest: +SKIP("need process group init")
        >>> # Note: Process group initialization omitted on each rank.
        >>> import torch.distributed as dist
        >>> # Assumes world_size of 3.
        >>> gather_objects = ["foo", 12, {1: 2}]  # any picklable object
        >>> output = [None for _ in gather_objects]
        >>> dist.gather_object(
        ...     gather_objects[dist.get_rank()],
        ...     output if dist.get_rank() == 0 else None,
        ...     dst=0
        ... )
        >>> # On rank 0
        >>> output
        ['foo', 12, {1: 2}]
    """
    group = _group_or_default_group(group)
    if dst is None and group_dst is None:
        dst = 0
    group_dst = _canonicalize_group_rank(group, dst, group_dst, return_global=False)
    if _rank_not_in_group(group):
        _warn_not_in_group("gather_object")
        return

    # Check that object_gather_list matches this rank's role before communicating.
    my_group_rank = group.rank()
    _validate_output_list_for_rank(my_group_rank, group_dst, object_gather_list)
    is_dst = my_group_rank == group_dst

    device = _get_object_coll_device(group)
    payload, payload_size = _object_to_tensor(obj, device, group)

    # Sizes are exchanged with an all-gather (not a gather) because every rank
    # must pad its payload to the same, maximal size before the real gather.
    world = get_world_size(group=group)
    sizes = torch.zeros(world, dtype=torch.long, device=device)
    per_rank_sizes = [sizes[rank].unsqueeze(dim=0) for rank in range(world)]
    all_gather(per_rank_sizes, payload_size, group=group)
    largest = int(max(per_rank_sizes).item())  # type: ignore[type-var]
    payload.resize_(largest)

    # Only the destination allocates receive buffers; they are non-overlapping
    # views of one flat tensor.
    slots = None
    if is_dst:
        gathered = torch.empty(largest * world, dtype=torch.uint8, device=device)
        slots = [
            gathered[largest * rank : largest * (rank + 1)] for rank in range(world)
        ]

    # All ranks call gather with equal-sized tensors.
    gather(
        payload,
        gather_list=slots,
        group_dst=group_dst,
        group=group,
    )
    if not is_dst:
        return

    if object_gather_list is None:
        raise AssertionError("Must provide object_gather_list on dst rank")
    for rank, slot in enumerate(slots):  # type: ignore[arg-type]
        object_gather_list[rank] = _tensor_to_object(
            slot.type(torch.uint8), per_rank_sizes[rank], group
        )
@_exception_logger
def send_object_list(
    object_list: list[Any],
    dst: int | None = None,
    group: ProcessGroup | None = None,
    device: torch.device | None = None,
    group_dst: int | None = None,
    use_batch: bool = False,
):
    """
    Sends picklable objects in ``object_list`` synchronously.

    Similar to :func:`send`, but Python objects can be passed in.
    Note that all objects in ``object_list`` must be picklable in order to be
    sent.

    Args:
        object_list (List[Any]): List of input objects to send. Must be non-empty.
            Each object must be picklable. Receiver must provide lists of equal sizes.
        dst (int): Destination rank to send ``object_list`` to.
            Destination rank is based on global process group (regardless of ``group`` argument)
        group: (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Default is ``None``.
        device (``torch.device``, optional): If not None, the objects are
            serialized and converted to tensors which are moved to the
            ``device`` before sending. Default is ``None``.
        group_dst (int, optional): Destination rank on ``group``.
            Must specify one of ``dst`` and ``group_dst`` but not both
        use_batch (bool, optional): If True, use batch p2p operations instead of
            regular send operations. This avoids initializing 2-rank communicators and
            uses existing entire group communicators. See batch_isend_irecv for usage and
            assumptions. Default is ``False``.

    Returns:
        ``None``.

    Raises:
        ValueError: If ``object_list`` is empty.

    .. note:: For NCCL-based process groups, internal tensor representations
        of objects must be moved to the GPU device before communication takes
        place. In this case, the device used is given by
        ``torch.cuda.current_device()`` and it is the user's responsibility to
        ensure that this is set so that each rank has an individual GPU, via
        ``torch.cuda.set_device()``.

    .. warning::
        Object collectives have a number of serious performance and scalability
        limitations.  See :ref:`object_collectives` for details.

    .. warning::
        :func:`send_object_list` uses ``pickle`` module implicitly, which
        is known to be insecure. It is possible to construct malicious pickle
        data which will execute arbitrary code during unpickling. Only call this
        function with data you trust.

    .. warning::
        Calling :func:`send_object_list` with GPU tensors is not well supported
        and inefficient as it incurs GPU -> CPU transfer since tensors would be
        pickled. Please consider using :func:`send` instead.

    Example::
        >>> # xdoctest: +SKIP("need process group init")
        >>> # Note: Process group initialization omitted on each rank.
        >>> import torch.distributed as dist
        >>> # Assumes backend is not NCCL
        >>> device = torch.device("cpu")
        >>> if dist.get_rank() == 0:
        >>>     # Assumes world_size of 2.
        >>>     objects = ["foo", 12, {1: 2}]  # any picklable object
        >>>     dist.send_object_list(objects, dst=1, device=device)
        >>> else:
        >>>     objects = [None, None, None]
        >>>     dist.recv_object_list(objects, src=0, device=device)
        >>> objects
        ['foo', 12, {1: 2}]
    """
    group = _group_or_default_group(group)
    group_dst = _canonicalize_group_rank(group, dst, group_dst)
    _check_not_self_rank(group, group_dst, "destination")
    if _rank_not_in_group(group):
        _warn_not_in_group("send_object_list")
        return

    # Device selection: an explicit ``device`` wins; otherwise fall back to the
    # device chosen for object collectives on this group (kept for backwards
    # compatibility with callers that never passed ``device``).
    current_device = device or _get_object_coll_device(group)

    # Serialize every object into a (byte tensor, size tensor) pair.
    serialized = [_object_to_tensor(obj, current_device, group) for obj in object_list]
    if not serialized:
        # Previously this surfaced as an opaque "not enough values to unpack"
        # ValueError from ``zip(*[])``; keep the exception type, fix the message.
        raise ValueError("send_object_list expects a non-empty object_list.")
    tensor_list = [byte_tensor for byte_tensor, _ in serialized]
    size_list = [size_tensor for _, size_tensor in serialized]

    def _send_tensor(payload):
        # Single place that picks between the batched and the regular p2p path.
        if use_batch:
            batch_isend_irecv(
                [P2POp(isend, payload, group_peer=group_dst, group=group)]
            ).pop().wait()
        else:
            send(payload, group_dst=group_dst, group=group)

    # First message: per-object sizes, so the receiver can allocate and split.
    _send_tensor(torch.cat(size_list))

    # Second message: the serialized bytes. torch.cat performs an extra copy,
    # which is skipped when there is only a single tensor to send.
    if len(tensor_list) == 1:
        object_tensor = tensor_list[0]
    else:
        object_tensor = torch.cat(tensor_list)
    _send_tensor(object_tensor)
@_exception_logger
def recv_object_list(
    object_list: list[Any],
    src: int | None = None,
    group: ProcessGroup | None = None,
    device: torch.device | None = None,
    group_src: int | None = None,
    use_batch: bool = False,
):
    """
    Receives picklable objects in ``object_list`` synchronously.

    Similar to :func:`recv`, but can receive Python objects.

    Args:
        object_list (List[Any]): List of objects to receive into.
            Must provide a list of sizes equal to the size of the list being sent.
        src (int, optional): Source rank from which to recv ``object_list``.
            Source rank is based on global process group (regardless of ``group`` argument)
            Documented as receiving from any rank if set to None; the rank is
            resolved by an internal helper, so confirm this is still supported.
            Default is ``None``.
        group: (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Default is ``None``.
        device (``torch.device``, optional): If not None, receives on this device.
            Default is ``None``.
        group_src (int, optional): Source rank on ``group``. Invalid to specify both ``src`` and ``group_src``.
        use_batch (bool, optional): If True, use batch p2p operations instead of
            regular recv operations. This avoids initializing 2-rank communicators and
            uses existing entire group communicators. See batch_isend_irecv for usage and
            assumptions. Default is ``False``.

    Returns:
        Sender rank. -1 if rank is not part of the group. If rank is part of the group,
        ``object_list`` will contain the sent objects from ``src`` rank.

    Raises:
        AssertionError: If the sender rank reported for the sizes message differs
            from the one reported for the objects message.

    .. note:: For NCCL-based process groups, internal tensor representations
        of objects must be moved to the GPU device before communication takes
        place. In this case, the device used is given by
        ``torch.cuda.current_device()`` and it is the user's responsibility to
        ensure that this is set so that each rank has an individual GPU, via
        ``torch.cuda.set_device()``.

    .. warning::
        Object collectives have a number of serious performance and scalability
        limitations.  See :ref:`object_collectives` for details.

    .. warning::
        :func:`recv_object_list` uses ``pickle`` module implicitly, which
        is known to be insecure. It is possible to construct malicious pickle
        data which will execute arbitrary code during unpickling. Only call this
        function with data you trust.

    .. warning::
        Calling :func:`recv_object_list` with GPU tensors is not well supported
        and inefficient as it incurs GPU -> CPU transfer since tensors would be
        pickled. Please consider using :func:`recv` instead.

    Example::
        >>> # xdoctest: +SKIP("need process group init")
        >>> # Note: Process group initialization omitted on each rank.
        >>> import torch.distributed as dist
        >>> # Assumes backend is not NCCL
        >>> device = torch.device("cpu")
        >>> if dist.get_rank() == 0:
        >>>     # Assumes world_size of 2.
        >>>     objects = ["foo", 12, {1: 2}]  # any picklable object
        >>>     dist.send_object_list(objects, dst=1, device=device)
        >>> else:
        >>>     objects = [None, None, None]
        >>>     dist.recv_object_list(objects, src=0, device=device)
        >>> objects
        ['foo', 12, {1: 2}]
    """
    group = _group_or_default_group(group)
    # NOTE(review): how a ``None`` src/group_src pair is treated is decided by
    # _canonicalize_group_rank, which is not visible from this function.
    group_src = _canonicalize_group_rank(group, src, group_src)
    _check_not_self_rank(group, group_src, "source")
    if _rank_not_in_group(group):
        _warn_not_in_group("recv_object_list")
        return -1

    # Device selection: an explicit ``device`` wins; otherwise fall back to the
    # device chosen for object collectives on this group (kept for backwards
    # compatibility with callers that never passed ``device``).
    current_device = device or _get_object_coll_device(group)

    def _recv_into(buffer):
        # Fill ``buffer`` from the peer and report the sender's global rank,
        # through either the batched or the regular p2p path.
        if use_batch:
            work = batch_isend_irecv(
                [
                    P2POp(
                        irecv,
                        buffer,
                        group_peer=group_src,
                        group=group,
                    )
                ]
            ).pop()
            work.wait()
            return get_global_rank(group, group_src)
        return recv(buffer, group=group, group_src=group_src)

    # First message: one size per expected object.
    object_sizes_tensor = torch.empty(
        len(object_list), dtype=torch.long, device=current_device
    )
    rank_sizes = _recv_into(object_sizes_tensor)

    # Second message: all serialized objects back to back.
    object_tensor = torch.empty(  # type: ignore[call-overload]
        torch.sum(object_sizes_tensor).item(),  # type: ignore[arg-type]
        dtype=torch.uint8,
        device=current_device,
    )
    rank_objects = _recv_into(object_tensor)

    if rank_sizes != rank_objects:
        raise AssertionError("Mismatch in return ranks for object sizes and objects.")

    # Walk the flat byte buffer, carving out one object per received size.
    cursor = 0
    for index, nbytes in enumerate(object_sizes_tensor):
        end = cursor + nbytes
        chunk = object_tensor[cursor:end].type(torch.uint8)
        cursor = end
        object_list[index] = _tensor_to_object(chunk, nbytes, group)
    return rank_objects
@_exception_logger
def broadcast_object_list(
    object_list: list[Any],
    src: int | None = None,
    group: ProcessGroup | None = None,
    device: torch.device | None = None,
    group_src: int | None = None,
):
    """
    Broadcasts picklable objects in ``object_list`` to the whole group.

    Similar to :func:`broadcast`, but Python objects can be passed in.
    Note that all objects in ``object_list`` must be picklable in order to be
    broadcasted.

    Args:
        object_list (List[Any]): List of input objects to broadcast.
            Each object must be picklable. Only objects on the ``src`` rank will
            be broadcast, but each rank must provide lists of equal sizes.
            Must be non-empty on the ``src`` rank.
        src (int, optional): Source rank from which to broadcast ``object_list``.
            Source rank is based on global process group (regardless of ``group`` argument).
            (If both ``src`` and ``group_src`` are None, default is global rank 0)
        group: (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Default is ``None``.
        device (``torch.device``, optional): If not None, the objects are
            serialized and converted to tensors which are moved to the
            ``device`` before broadcasting. Default is ``None``.
        group_src (int, optional): Source rank on ``group``. Invalid to specify
            both ``group_src`` and ``src``.

    Returns:
        ``None``. If rank is part of the group, ``object_list`` will contain the
        broadcasted objects from ``src`` rank.

    Raises:
        ValueError: If ``object_list`` is empty on the ``src`` rank.

    .. note:: For NCCL-based process groups, internal tensor representations
        of objects must be moved to the GPU device before communication takes
        place. In this case, the device used is given by
        ``torch.cuda.current_device()`` and it is the user's responsibility to
        ensure that this is set so that each rank has an individual GPU, via
        ``torch.cuda.set_device()``.

    .. note:: Note that this API differs slightly from the :func:`broadcast`
        collective since it does not provide an ``async_op`` handle and thus
        will be a blocking call.

    .. warning::
        Object collectives have a number of serious performance and scalability
        limitations.  See :ref:`object_collectives` for details.

    .. warning::
        :func:`broadcast_object_list` uses ``pickle`` module implicitly, which
        is known to be insecure. It is possible to construct malicious pickle
        data which will execute arbitrary code during unpickling. Only call this
        function with data you trust.

    .. warning::
        Calling :func:`broadcast_object_list` with GPU tensors is not well supported
        and inefficient as it incurs GPU -> CPU transfer since tensors would be
        pickled. Please consider using :func:`broadcast` instead.

    Example::
        >>> # xdoctest: +SKIP("need process group init")
        >>> # Note: Process group initialization omitted on each rank.
        >>> import torch.distributed as dist
        >>> if dist.get_rank() == 0:
        >>>     # Assumes world_size of 3.
        >>>     objects = ["foo", 12, {1: 2}]  # any picklable object
        >>> else:
        >>>     objects = [None, None, None]
        >>> # Assumes backend is not NCCL
        >>> device = torch.device("cpu")
        >>> dist.broadcast_object_list(objects, src=0, device=device)
        >>> objects
        ['foo', 12, {1: 2}]
    """
    group = _group_or_default_group(group)
    if src is None and group_src is None:
        src = 0
    group_src = _canonicalize_group_rank(group, src, group_src, return_global=False)
    if _rank_not_in_group(group):
        _warn_not_in_group("broadcast_object_list")
        return

    # Device selection: an explicit ``device`` wins; otherwise fall back to the
    # device chosen for object collectives on this group (kept for backwards
    # compatibility with callers that never passed ``device``).
    current_device = device or _get_object_coll_device(group)
    is_src = group.rank() == group_src

    # Only the source serializes; everyone else prepares a buffer for the sizes.
    tensor_list: list[torch.Tensor] = []
    if is_src:
        serialized = [
            _object_to_tensor(obj, current_device, group) for obj in object_list
        ]
        if not serialized:
            # Previously this surfaced as an opaque "not enough values to
            # unpack" ValueError from ``zip(*[])``; keep the exception type,
            # fix the message.
            raise ValueError(
                "broadcast_object_list expects a non-empty object_list on the source rank."
            )
        tensor_list = [byte_tensor for byte_tensor, _ in serialized]
        object_sizes_tensor = torch.cat([size_tensor for _, size_tensor in serialized])
    else:
        object_sizes_tensor = torch.empty(
            len(object_list), dtype=torch.long, device=current_device
        )

    # First broadcast: per-object sizes.
    broadcast(object_sizes_tensor, group_src=group_src, group=group)

    # Second broadcast: the serialized bytes. torch.cat performs an extra copy,
    # which is skipped on the source when there is only a single tensor.
    if is_src:
        if len(tensor_list) == 1:
            object_tensor = tensor_list[0]
        else:
            object_tensor = torch.cat(tensor_list)
    else:
        object_tensor = torch.empty(  # type: ignore[call-overload]
            torch.sum(object_sizes_tensor).item(),  # type: ignore[arg-type]
            dtype=torch.uint8,
            device=current_device,
        )
    broadcast(object_tensor, group_src=group_src, group=group)

    # The source already holds the objects; only receivers deserialize.
    if is_src:
        return
    cursor = 0
    for index, nbytes in enumerate(object_sizes_tensor):
        end = cursor + nbytes
        chunk = object_tensor[cursor:end].type(torch.uint8)
        cursor = end
        object_list[index] = _tensor_to_object(chunk, nbytes, group)
@_exception_logger
def scatter_object_list(
    scatter_object_output_list: list[Any],
    scatter_object_input_list: list[Any] | None = None,
    src: int | None = None,
    group: ProcessGroup | None = None,
    group_src: int | None = None,
):
    """
    Scatters picklable objects in ``scatter_object_input_list`` to the whole group.

    Similar to :func:`scatter`, but Python objects can be passed in. On
    each rank, the scattered object will be stored as the first element of
    ``scatter_object_output_list``. Note that all objects in
    ``scatter_object_input_list`` must be picklable in order to be scattered.

    Args:
        scatter_object_output_list (List[Any]): Non-empty list whose first
            element will store the object scattered to this rank.
        scatter_object_input_list (List[Any], optional): List of input objects to scatter.
            Each object must be picklable. Only objects on the ``src`` rank will
            be scattered, and the argument can be ``None`` for non-src ranks.
            Must be non-empty on the ``src`` rank.
        src (int): Source rank from which to scatter ``scatter_object_input_list``.
            Source rank is based on global process group (regardless of ``group`` argument).
            (If both ``src`` and ``group_src`` are None, default is global rank 0)
        group: (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Default is ``None``.
        group_src (int, optional): Source rank on ``group``.  Invalid to specify both ``src`` and ``group_src``

    Returns:
        ``None``. If rank is part of the group, ``scatter_object_output_list``
        will have its first element set to the scattered object for this rank.

    Raises:
        ValueError: If ``scatter_object_output_list`` is not a list with at
            least one element, or if the ``src`` rank passes ``None`` or an
            empty ``scatter_object_input_list``.

    .. note:: Note that this API differs slightly from the scatter collective
        since it does not provide an ``async_op`` handle and thus will be a
        blocking call.

    .. warning::
        Object collectives have a number of serious performance and scalability
        limitations.  See :ref:`object_collectives` for details.

    .. warning::
        :func:`scatter_object_list` uses ``pickle`` module implicitly, which
        is known to be insecure. It is possible to construct malicious pickle
        data which will execute arbitrary code during unpickling. Only call this
        function with data you trust.

    .. warning::
        Calling :func:`scatter_object_list` with GPU tensors is not well supported
        and inefficient as it incurs GPU -> CPU transfer since tensors would be
        pickled. Please consider using :func:`scatter` instead.

    Example::
        >>> # xdoctest: +SKIP("need process group init")
        >>> # Note: Process group initialization omitted on each rank.
        >>> import torch.distributed as dist
        >>> if dist.get_rank() == 0:
        >>>     # Assumes world_size of 3.
        >>>     objects = ["foo", 12, {1: 2}]  # any picklable object
        >>> else:
        >>>     # Can be any list on non-src ranks, elements are not used.
        >>>     objects = [None, None, None]
        >>> output_list = [None]
        >>> dist.scatter_object_list(output_list, objects, src=0)
        >>> # Rank i gets objects[i]. For example, on rank 2:
        >>> output_list
        [{1: 2}]
    """
    group = _group_or_default_group(group)
    if src is None and group_src is None:
        src = 0
    group_src = _canonicalize_group_rank(group, src, group_src, return_global=False)
    if _rank_not_in_group(group):
        _warn_not_in_group("scatter_object_list")
        return

    if (
        not isinstance(scatter_object_output_list, list)
        or len(scatter_object_output_list) < 1
    ):
        raise ValueError(
            "Expected argument scatter_object_output_list to be a list of size at least 1."
        )

    is_src = group.rank() == group_src
    pg_device = _get_object_coll_device(group)

    # Non-source ranks pass ``None`` as the scatter inputs.
    tensor_list = None
    tensor_sizes = None
    if is_src:
        if scatter_object_input_list is None:
            raise ValueError(
                "source rank must provide non-None scatter_object_input_list"
            )
        serialized = [
            _object_to_tensor(obj, pg_device, group)
            for obj in scatter_object_input_list
        ]
        if not serialized:
            # Previously this surfaced as an opaque "not enough values to
            # unpack" ValueError from ``zip(*[])``; keep the exception type,
            # fix the message.
            raise ValueError(
                "source rank must provide non-empty scatter_object_input_list"
            )
        tensor_list = [byte_tensor for byte_tensor, _ in serialized]
        tensor_sizes = [size_tensor for _, size_tensor in serialized]

        # All ranks must call scatter() with equal-sized tensors, so the source
        # pads every payload to the largest one and broadcasts that size.
        max_tensor_size = max(tensor_sizes)
        for byte_tensor in tensor_list:
            byte_tensor.resize_(max_tensor_size)
    else:
        max_tensor_size = torch.tensor([0], dtype=torch.long, device=pg_device)
    broadcast(max_tensor_size, group_src=group_src, group=group)

    # Scatter the padded serialized objects.
    # pyrefly: ignore [no-matching-overload]
    output_tensor = torch.empty(
        max_tensor_size.item(), dtype=torch.uint8, device=pg_device
    )
    scatter(
        output_tensor,
        scatter_list=tensor_list,
        group_src=group_src,
        group=group,
    )

    # Scatter each object's true size so the padding can be trimmed on receipt.
    obj_tensor_size = torch.tensor([0], dtype=torch.long, device=pg_device)
    scatter(
        obj_tensor_size,
        scatter_list=tensor_sizes,
        group_src=group_src,
        group=group,
    )

    # Deserialize back to object
    scatter_object_output_list[0] = _tensor_to_object(
        output_tensor, obj_tensor_size, group
    )
@_exception_logger
def all_gather(tensor_list, tensor, group=None, async_op=False):
    """
    Gathers tensors from the whole group in a list.

    Complex and uneven sized tensors are supported.

    Args:
        tensor_list (list[Tensor]): Output list. It should contain
            correctly-sized tensors to be used for output of the collective.
            Uneven sized tensors are supported.
        tensor (Tensor): Tensor to be broadcast from current process.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    Examples:
        >>> # xdoctest: +SKIP("need process group init")
        >>> # All tensors below are of torch.int64 dtype.
        >>> # We have 2 process groups, 2 ranks.
        >>> device = torch.device(f"cuda:{rank}")
        >>> tensor_list = [
        ...     torch.zeros(2, dtype=torch.int64, device=device) for _ in range(2)
        ... ]
        >>> tensor_list
        [tensor([0, 0], device='cuda:0'), tensor([0, 0], device='cuda:0')] # Rank 0
        [tensor([0, 0], device='cuda:1'), tensor([0, 0], device='cuda:1')] # Rank 1
        >>> tensor = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank
        >>> tensor
        tensor([1, 2], device='cuda:0') # Rank 0
        tensor([3, 4], device='cuda:1') # Rank 1
        >>> dist.all_gather(tensor_list, tensor)
        >>> tensor_list
        [tensor([1, 2], device='cuda:0'), tensor([3, 4], device='cuda:0')] # Rank 0
        [tensor([1, 2], device='cuda:1'), tensor([3, 4], device='cuda:1')] # Rank 1

        >>> # All tensors below are of torch.cfloat dtype.
        >>> # We have 2 process groups, 2 ranks.
        >>> tensor_list = [
        ...     torch.zeros(2, dtype=torch.cfloat, device=device) for _ in range(2)
        ... ]
        >>> tensor_list
        [tensor([0.+0.j, 0.+0.j], device='cuda:0'), tensor([0.+0.j, 0.+0.j], device='cuda:0')] # Rank 0
        [tensor([0.+0.j, 0.+0.j], device='cuda:1'), tensor([0.+0.j, 0.+0.j], device='cuda:1')] # Rank 1
        >>> tensor = torch.tensor(
        ...     [1 + 1j, 2 + 2j], dtype=torch.cfloat, device=device
        ... ) + 2 * rank * (1 + 1j)
        >>> tensor
        tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0
        tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1
        >>> dist.all_gather(tensor_list, tensor)
        >>> tensor_list
        [tensor([1.+1.j, 2.+2.j], device='cuda:0'), tensor([3.+3.j, 4.+4.j], device='cuda:0')] # Rank 0
        [tensor([1.+1.j, 2.+2.j], device='cuda:1'), tensor([3.+3.j, 4.+4.j], device='cuda:1')] # Rank 1

    """
    # Dynamo maps legacy distributed ops to functional collectives on its own.
    # Outside Dynamo, a torch function mode can mimic that logic (non-strict
    # export implements one), so hand the call over when such an override is
    # active for the input tensor.
    relevant_args = (tensor,)
    if has_torch_function(relevant_args):
        return handle_torch_function(
            all_gather,
            relevant_args,
            tensor_list,
            tensor,
            group=group,
            async_op=async_op,
        )

    _check_tensor_list(tensor_list, "tensor_list")
    _check_single_tensor(tensor, "tensor")
    _ensure_all_tensors_same_dtype(tensor_list, tensor)
    if _rank_not_in_group(group):
        _warn_not_in_group("all_gather")
        return

    def _as_real(candidate):
        # Complex tensors are handed to the backend through their real view.
        return torch.view_as_real(candidate) if candidate.is_complex() else candidate

    tensor_list = [_as_real(entry) for entry in tensor_list]
    tensor = _as_real(tensor)

    if not group:
        group = _get_default_group()
    opts = AllgatherOptions()
    opts.asyncOp = async_op
    work = group.allgather([tensor_list], [tensor], opts)

    if async_op:
        return work
    # Backward compatible with backends that don't sync at CPP level; when
    # ``work`` is None the backend has already sync'ed at CPP level.
    if work is not None:
        work.wait()
    return None
@_exception_logger
def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=False):
    """
    Gather the input tensor of every rank into one output tensor.

    This function requires all tensors to be the same size on each process.

    Args:
        output_tensor (Tensor): Output tensor to accommodate tensor elements
            from all ranks. It must be correctly sized to have one of the
            following forms:
            (i) a concatenation of all the input tensors along the primary
            dimension; for definition of "concatenation", see ``torch.cat()``;
            (ii) a stack of all the input tensors along the primary dimension;
            for definition of "stack", see ``torch.stack()``.
            Examples below may better explain the supported output forms.
        input_tensor (Tensor): Tensor to be gathered from current rank.
            Different from the ``all_gather`` API, the input tensors in this
            API must have the same size across all ranks.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op.

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group.
        While the group is in a coalescing context the collective is only
        recorded: ``_IllegalWork()`` is returned for async_op=True and None
        otherwise.

    Examples:
        >>> # xdoctest: +SKIP("need process group init")
        >>> # All tensors below are of torch.int64 dtype and on CUDA devices.
        >>> # We have two ranks.
        >>> device = torch.device(f"cuda:{rank}")
        >>> tensor_in = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank
        >>> tensor_in
        tensor([1, 2], device='cuda:0') # Rank 0
        tensor([3, 4], device='cuda:1') # Rank 1
        >>> # Output in concatenation form
        >>> tensor_out = torch.zeros(world_size * 2, dtype=torch.int64, device=device)
        >>> dist.all_gather_into_tensor(tensor_out, tensor_in)
        >>> tensor_out
        tensor([1, 2, 3, 4], device='cuda:0') # Rank 0
        tensor([1, 2, 3, 4], device='cuda:1') # Rank 1
        >>> # Output in stack form
        >>> tensor_out2 = torch.zeros(world_size, 2, dtype=torch.int64, device=device)
        >>> dist.all_gather_into_tensor(tensor_out2, tensor_in)
        >>> tensor_out2
        tensor([[1, 2],
                [3, 4]], device='cuda:0') # Rank 0
        tensor([[1, 2],
                [3, 4]], device='cuda:1') # Rank 1
    """
    # Dynamo has built-in logic to map legacy distributed ops to functional collectives.
    # Let's redirect to a torch function mode that can mimic this logic outside Dynamo
    # (e.g., non-strict export implements such a torch function mode).
    dispatch_args = (input_tensor,)
    if has_torch_function(dispatch_args):
        return handle_torch_function(
            all_gather_into_tensor,
            dispatch_args,
            output_tensor,
            input_tensor,
            group=group,
            async_op=async_op,
        )

    _check_single_tensor(input_tensor, "input_tensor")
    _check_single_tensor(output_tensor, "output_tensor")

    if _rank_not_in_group(group):
        _warn_not_in_group("all_gather_into_tensor")
        return None

    # Complex tensors are handed to the backend as their real-valued views.
    if output_tensor.is_complex():
        output_tensor = torch.view_as_real(output_tensor)
    if input_tensor.is_complex():
        input_tensor = torch.view_as_real(input_tensor)

    gather_opts = AllgatherOptions()
    gather_opts.asyncOp = async_op
    pg = group or _get_default_group()

    if pg in _world.pg_coalesce_state:
        # We are in coalescing context, do not issue single operation, just append a collective representation
        pending = _CollOp(all_gather_into_tensor, input_tensor, output_tensor)
        _world.pg_coalesce_state[pg].append(pending)
        return _IllegalWork() if async_op else None

    work = pg._allgather_base(output_tensor, input_tensor, gather_opts)
    if async_op:
        return work
    if work is not None:
        # Backward compatible with backends that don't sync at CPP level
        work.wait()
    # Otherwise, the backend has sync'ed at CPP level
    return None
@_exception_logger
@deprecated(
    "`torch.distributed._all_gather_base` is a private function and will be deprecated. "
    "Please use `torch.distributed.all_gather_into_tensor` instead.",
    category=FutureWarning,
)
def _all_gather_base(output_tensor, input_tensor, group=None, async_op: bool = False):
    """
    Deprecated private alias that forwards to ``all_gather_into_tensor``.

    Args:
        output_tensor (Tensor): Output tensor that receives the gathered data.
        input_tensor (Tensor): Tensor contributed by the current process.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op.

    Returns:
        Whatever ``all_gather_into_tensor`` returns for the same arguments.

    .. warning::
        `_all_gather_base` is a private function. Users should use
        `all_gather_into_tensor` instead.

    """
    # All arguments are forwarded unchanged.
    return all_gather_into_tensor(
        output_tensor,
        input_tensor,
        group=group,
        async_op=async_op,
    )
@_exception_logger
@deprecated(
    "`torch.distributed.all_gather_coalesced` will be deprecated. If you must use it, "
    "please revisit our documentation later at "
    "https://pytorch.org/docs/main/distributed.html#collective-functions",
    category=FutureWarning,
)
def all_gather_coalesced(
    output_tensor_lists, input_tensor_list, group=None, async_op: bool = False
):
    """
    Gather a list of input tensors from every rank in one coalesced collective.

    Complex tensors are supported: they are passed to the backend through
    their ``torch.view_as_real`` views.

    Args:
        output_tensor_lists (list[list[Tensor]]): Output list. It should contain
            correctly-sized tensors to be used for output of the collective.
        input_tensor_list (list[Tensor]): Tensors contributed by the current
            process. At least one tensor has to be non empty.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op.

    Returns:
        The result of ``get_future()`` on the backend work handle, if async_op
        is set to True.
        None, if not async_op or if not part of the group.

    Raises:
        TypeError: If ``output_tensor_lists`` is not a list.

    Example:
        we have 2 ranks.
        rank 0 passes:
            input_tensor_list = [[[1, 1], [1, 1]], [2], [3, 3]]
            output_tensor_lists =
               [[[[-1, -1], [-1, -1]], [-1], [-1, -1]],
                [[[-1, -1], [-1, -1]], [-1], [-1, -1]]]
        rank 1 passes:
            input_tensor_list = [[[3, 3], [3, 3]], [5], [1, 1]]
            output_tensor_lists =
               [[[[-1, -1], [-1, -1]], [-1], [-1, -1]],
                [[[-1, -1], [-1, -1]], [-1], [-1, -1]]]
        both rank 0 and 1 get:
            output_tensor_lists =
               [[[[1, 1], [1, 1]], [2], [3, 3]],
                [[[3, 3], [3, 3]], [5], [1, 1]]].

    WARNING: at this time individual shape checking is not implemented across nodes.
    For example, if the rank 0 node passes [torch.rand(4), torch.rand(2)] and the
    rank 1 node passes [torch.rand(2), torch.rand(2), torch.rand(2)], the
    all_gather_coalesced operation will proceed without complaint and return
    erroneous outputs. This lack of shape checking results in significant
    performance improvements but users of this function should take extra care
    to ensure that each node passes in tensors whose shapes match across nodes.
    """
    # We only check basic compatibility with C++ params here, C++ code will
    # do shape and type checking.
    if _rank_not_in_group(group):
        _warn_not_in_group("all_gather_coalesced")
        return None

    _check_tensor_list(input_tensor_list, "input_tensor_list")
    _ensure_all_tensors_same_dtype(input_tensor_list)
    if not isinstance(output_tensor_lists, list):
        raise TypeError(
            "Invalid function argument: output_tensor_lists should be a list"
        )
    for per_rank_outputs in output_tensor_lists:
        _check_tensor_list(per_rank_outputs, "output_tensor_lists")
        _ensure_all_tensors_same_dtype(per_rank_outputs)

    def _as_real(t):
        # Complex tensors are exchanged through their real-valued views.
        return torch.view_as_real(t) if t.is_complex() else t

    real_output_lists = [
        [_as_real(t) for t in per_rank_outputs]
        for per_rank_outputs in output_tensor_lists
    ]
    real_inputs = [_as_real(t) for t in input_tensor_list]

    pg = group or _get_default_group()
    gather_opts = AllgatherOptions()
    gather_opts.asyncOp = async_op

    work = pg.allgather_coalesced(real_output_lists, real_inputs, gather_opts)
    if async_op:
        return work.get_future()
    if work is not None:
        # Backward compatible with backends that don't sync at CPP level
        work.wait()
    # Otherwise, the backend has sync'ed at CPP level
    return None
- def _validate_output_list_for_rank(my_rank: int, dst: int, gather_list):
- if dst == my_rank:
- if not gather_list:
- raise ValueError(
- "Argument ``gather_list`` must be specified on destination rank."
- )
- elif gather_list:
- raise ValueError(
- "Argument ``gather_list`` must NOT be specified on non-destination ranks."
- )
@_exception_logger
def gather(
    tensor: torch.Tensor,
    gather_list: list[torch.Tensor] | None = None,
    dst: int | None = None,
    group: ProcessGroup | None = None,
    async_op: bool = False,
    group_dst: int | None = None,
):
    """
    Collect one tensor from every rank onto a single destination rank.

    This function requires all tensors to be the same size on each process.

    Args:
        tensor (Tensor): Input tensor.
        gather_list (list[Tensor], optional): List of appropriately,
            same-sized tensors to use for gathered data
            (default is None, must be specified on the destination rank)
        dst (int, optional): Destination rank on global process group (regardless of ``group`` argument).
            (If both ``dst`` and ``group_dst`` are None, default is global rank 0)
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op
        group_dst (int, optional): Destination rank on ``group``.  Invalid to specify both ``dst`` and ``group_dst``

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    Raises:
        ValueError: If ``gather_list`` is missing on the destination rank or
            given on any other rank.

    .. note:: Note that all Tensors in gather_list must have the same size.

    Example::
        >>> # xdoctest: +SKIP("no rank")
        >>> # We have 2 ranks.
        >>> tensor_size = 2
        >>> device = torch.device(f'cuda:{rank}')
        >>> tensor = torch.ones(tensor_size, device=device) + rank
        >>> if dist.get_rank() == 0:
        ...     gather_list = [torch.zeros_like(tensor, device=device) for i in range(2)]
        ... else:
        ...     gather_list = None
        >>> dist.gather(tensor, gather_list, dst=0)
        >>> # Rank 0 gets gathered data.
        >>> gather_list
        [tensor([1., 1.], device='cuda:0'), tensor([2., 2.], device='cuda:0')] # Rank 0
        None                                                                   # Rank 1

    """
    _check_single_tensor(tensor, "tensor")

    # Parameter ``gather_list`` may be left unspecified on non-dst ranks.
    if not gather_list:
        gather_list = []
    else:
        _check_tensor_list(gather_list, "gather_list")
    _ensure_all_tensors_same_dtype(tensor, gather_list)

    group = _group_or_default_group(group)
    if _rank_not_in_group(group):
        _warn_not_in_group("gather")
        return None

    # With neither destination given, fall back to global rank 0.
    if dst is None and group_dst is None:
        dst = 0
    group_dst = _canonicalize_group_rank(group, dst, group_dst, return_global=False)

    my_group_rank = group.rank()
    _validate_output_list_for_rank(my_group_rank, group_dst, gather_list)

    # Only the destination rank hands output buffers to the backend.
    if group_dst == my_group_rank:
        output_tensors = [gather_list]
    else:
        output_tensors = []
    input_tensors = [tensor]

    gather_opts = GatherOptions()
    gather_opts.rootRank = group_dst
    gather_opts.asyncOp = async_op

    work = group.gather(output_tensors, input_tensors, gather_opts)
    if async_op:
        return work
    if work is not None:
        # Backward compatible with backends that don't sync at CPP level
        work.wait()
    # Otherwise, the backend has sync'ed at CPP level
    return None
@_exception_logger
def scatter(
    tensor: torch.Tensor,
    scatter_list: list[torch.Tensor] | None = None,
    src: int | None = None,
    group: ProcessGroup | None = None,
    async_op: bool = False,
    group_src: int | None = None,
):
    """
    Scatters a list of tensors to all processes in a group.

    Each process will receive exactly one tensor and store its data in the
    ``tensor`` argument.

    Complex tensors are supported: they are passed to the backend through
    their ``torch.view_as_real`` views.

    Args:
        tensor (Tensor): Output tensor.
        scatter_list (list[Tensor], optional): List of tensors to scatter (default is
            None, must be specified on the source rank)
        src (int, optional): Source rank on global process group (regardless of ``group`` argument).
            (If both ``src`` and ``group_src`` are None, default is global rank 0)
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op
        group_src (int, optional): Source rank on ``group``.  Invalid to specify both ``src`` and ``group_src``

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    Raises:
        ValueError: If ``scatter_list`` is missing on the source rank or given
            on any other rank.

    .. note:: Note that all Tensors in scatter_list must have the same size.

    Example::
        >>> # xdoctest: +SKIP("need process group init")
        >>> # Note: Process group initialization omitted on each rank.
        >>> import torch.distributed as dist
        >>> tensor_size = 2
        >>> device = torch.device(f'cuda:{rank}')
        >>> output_tensor = torch.zeros(tensor_size, device=device)
        >>> if dist.get_rank() == 0:
        ...     # Assumes world_size of 2.
        ...     # Only tensors, all of which must be the same size.
        ...     t_ones = torch.ones(tensor_size, device=device)
        ...     t_fives = torch.ones(tensor_size, device=device) * 5
        ...     scatter_list = [t_ones, t_fives]
        ... else:
        ...     scatter_list = None
        >>> dist.scatter(output_tensor, scatter_list, src=0)
        >>> # Rank i gets scatter_list[i].
        >>> output_tensor
        tensor([1., 1.], device='cuda:0') # Rank 0
        tensor([5., 5.], device='cuda:1') # Rank 1

    """
    _check_single_tensor(tensor, "tensor")
    # Parameter ``scatter_list`` may be left unspecified on non-src ranks.
    if scatter_list:
        _check_tensor_list(scatter_list, "scatter_list")
    else:
        scatter_list = []
    _ensure_all_tensors_same_dtype(tensor, scatter_list)
    group = _group_or_default_group(group)

    # Return early for non-member ranks BEFORE resolving the source rank, the
    # same order ``gather`` uses. Resolving a rank against a group this process
    # does not belong to is not meaningful and must not pre-empt the documented
    # "warn and return None" behavior.
    if _rank_not_in_group(group):
        _warn_not_in_group("scatter")
        return

    # With neither source given, fall back to global rank 0.
    if src is None and group_src is None:
        src = 0
    group_src = _canonicalize_group_rank(group, src, group_src, return_global=False)

    # Complex tensors are handed to the backend as their real-valued views.
    scatter_list = [
        t if not t.is_complex() else torch.view_as_real(t) for t in scatter_list
    ]
    tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor)

    my_group_rank = group.rank()
    # Every rank receives into ``tensor``; only the source supplies inputs.
    output_tensors = [tensor]
    if group_src == my_group_rank:
        if not scatter_list:
            raise ValueError(
                "Argument ``scatter_list`` must be specified on source rank."
            )
        input_tensors = [scatter_list]
    else:
        if scatter_list:
            raise ValueError(
                "Argument ``scatter_list`` must NOT be specified on non-source ranks."
            )
        input_tensors = []

    opts = ScatterOptions()
    opts.rootRank = group_src
    opts.asyncOp = async_op
    work = group.scatter(output_tensors, input_tensors, opts)

    if async_op:
        return work
    elif (
        work is not None
    ):  # Backward compatible with backends that don't sync at CPP level
        work.wait()
    # Otherwise, the backend has sync'ed at CPP level
@_exception_logger
def reduce_scatter(
    output, input_list, op=ReduceOp.SUM, group=None, async_op: bool = False
):
    """
    Reduce a list of tensors across the group, then scatter the result.

    Args:
        output (Tensor): Output tensor.
        input_list (list[Tensor]): List of tensors to reduce and scatter.
        op (optional): One of the values from
            ``torch.distributed.ReduceOp``
            enum.  Specifies an operation used for element-wise reductions.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op.

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group.

    """
    _check_single_tensor(output, "output")
    _check_tensor_list(input_list, "input_list")
    _ensure_all_tensors_same_dtype(output, input_list)

    if _rank_not_in_group(group):
        _warn_not_in_group("reduce_scatter")
        return None

    rs_opts = ReduceScatterOptions()
    rs_opts.reduceOp = op
    rs_opts.asyncOp = async_op

    pg = group or _get_default_group()
    work = pg.reduce_scatter([output], [input_list], rs_opts)

    if async_op:
        return work
    if work is not None:
        # Backward compatible with backends that don't sync at CPP level
        work.wait()
    # Otherwise, the backend has sync'ed at CPP level
    return None
@_exception_logger
def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=False):
    """
    Reduce a single tensor across the group, then scatter it to all ranks.

    Args:
        output (Tensor): Output tensor. It should have the same size across all
            ranks.
        input (Tensor): Input tensor to be reduced and scattered. Its size
            should be output tensor size times the world size. The input tensor
            can have one of the following shapes:
            (i) a concatenation of the output tensors along the primary
            dimension, or
            (ii) a stack of the output tensors along the primary dimension.
            For definition of "concatenation", see ``torch.cat()``.
            For definition of "stack", see ``torch.stack()``.
        op (optional): One of the values from ``torch.distributed.ReduceOp``
            enum. Defaults to ``ReduceOp.SUM``.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op.

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group.
        While the group is in a coalescing context the collective is only
        recorded: ``_IllegalWork()`` is returned for async_op=True and None
        otherwise.

    Examples:
        >>> # xdoctest: +SKIP("need process group init")
        >>> # All tensors below are of torch.int64 dtype and on CUDA devices.
        >>> # We have two ranks.
        >>> device = torch.device(f"cuda:{rank}")
        >>> tensor_out = torch.zeros(2, dtype=torch.int64, device=device)
        >>> # Input in concatenation form
        >>> tensor_in = torch.arange(world_size * 2, dtype=torch.int64, device=device)
        >>> tensor_in
        tensor([0, 1, 2, 3], device='cuda:0') # Rank 0
        tensor([0, 1, 2, 3], device='cuda:1') # Rank 1
        >>> dist.reduce_scatter_tensor(tensor_out, tensor_in)
        >>> tensor_out
        tensor([0, 2], device='cuda:0') # Rank 0
        tensor([4, 6], device='cuda:1') # Rank 1
        >>> # Input in stack form
        >>> tensor_in = torch.reshape(tensor_in, (world_size, 2))
        >>> tensor_in
        tensor([[0, 1],
                [2, 3]], device='cuda:0') # Rank 0
        tensor([[0, 1],
                [2, 3]], device='cuda:1') # Rank 1
        >>> dist.reduce_scatter_tensor(tensor_out, tensor_in)
        >>> tensor_out
        tensor([0, 2], device='cuda:0') # Rank 0
        tensor([4, 6], device='cuda:1') # Rank 1

    """
    # Dynamo has built-in logic to map legacy distributed ops to functional collectives.
    # Let's redirect to a torch function mode that can mimic this logic outside Dynamo
    # (e.g., non-strict export implements such a torch function mode).
    dispatch_args = (input,)
    if has_torch_function(dispatch_args):
        return handle_torch_function(
            reduce_scatter_tensor,
            dispatch_args,
            output,
            input,
            op=op,
            group=group,
            async_op=async_op,
        )

    _check_single_tensor(output, "output")
    _check_single_tensor(input, "input")

    if _rank_not_in_group(group):
        _warn_not_in_group("reduce_scatter_tensor")
        return None

    rs_opts = ReduceScatterOptions()
    rs_opts.reduceOp = op
    rs_opts.asyncOp = async_op
    pg = group or _get_default_group()

    # Check if we are in coalescing context
    # If we are, do not issue single operation, just append a collective representation
    if pg in _world.pg_coalesce_state:
        pending = _CollOp(reduce_scatter_tensor, input, output, op, None)
        _world.pg_coalesce_state[pg].append(pending)
        return _IllegalWork() if async_op else None

    work = pg._reduce_scatter_base(output, input, rs_opts)
    if async_op:
        return work
    if work is not None:
        # Backward compatible with backends that don't sync at CPP level
        work.wait()
    # Otherwise, the backend has sync'ed at CPP level
    return None
@deprecated(
    "`torch.distributed._reduce_scatter_base` is a private function and will be deprecated. "
    "Please use `torch.distributed.reduce_scatter_tensor` instead.",
    category=FutureWarning,
)
def _reduce_scatter_base(
    output, input, op=ReduceOp.SUM, group=None, async_op: bool = False
):
    """
    Deprecated private alias that forwards to ``reduce_scatter_tensor``.

    Args:
        output (Tensor): Output tensor.
        input (Tensor): Input tensor that is of size output tensor size times world size
        op (optional): Reduction operation, forwarded unchanged. Defaults to
            ``ReduceOp.SUM``.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op.

    Returns:
        Whatever ``reduce_scatter_tensor`` returns for the same arguments.

    .. warning::
        `_reduce_scatter_base` is a private function. Users should use
        `reduce_scatter_tensor` instead.

    """
    # All arguments are forwarded unchanged.
    return reduce_scatter_tensor(
        output,
        input,
        op=op,
        group=group,
        async_op=async_op,
    )
@_exception_logger
def all_to_all_single(
    output,
    input,
    output_split_sizes=None,
    input_split_sizes=None,
    group=None,
    async_op: bool = False,
):
    """
    Split input tensor and then scatter the split list to all processes in a group.

    Later the received tensors are concatenated from all the processes in the group
    and returned as a single output tensor.

    Complex tensors are supported: they are passed to the backend through
    their ``torch.view_as_real`` views. ``output`` and ``input`` are checked
    to have the same dtype.

    Args:
        output (Tensor): Gathered concatenated output tensor.
        input (Tensor): Input tensor to scatter.
        output_split_sizes: (list[Int], optional): Output split sizes for dim 0
            if specified None or empty, dim 0 of ``output`` tensor must divide
            equally by ``world_size``.
        input_split_sizes: (list[Int], optional): Input split sizes for dim 0
            if specified None or empty, dim 0 of ``input`` tensor must divide
            equally by ``world_size``.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op.

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group.

    .. warning::
        `all_to_all_single` is experimental and subject to change.

    Examples:
        >>> # xdoctest: +SKIP("Undefined rank")
        >>> input = torch.arange(4) + rank * 4
        >>> input
        tensor([0, 1, 2, 3])     # Rank 0
        tensor([4, 5, 6, 7])     # Rank 1
        tensor([8, 9, 10, 11])   # Rank 2
        tensor([12, 13, 14, 15]) # Rank 3
        >>> output = torch.empty([4], dtype=torch.int64)
        >>> dist.all_to_all_single(output, input)
        >>> output
        tensor([0, 4, 8, 12])    # Rank 0
        tensor([1, 5, 9, 13])    # Rank 1
        tensor([2, 6, 10, 14])   # Rank 2
        tensor([3, 7, 11, 15])   # Rank 3

        >>> # Essentially, it is similar to following operation:
        >>> scatter_list = list(input.chunk(world_size))
        >>> gather_list = list(output.chunk(world_size))
        >>> for i in range(world_size):
        ...     dist.scatter(gather_list[i], scatter_list if i == rank else [], src=i)

        >>> # Another example with uneven split
        >>> input
        tensor([0, 1, 2, 3, 4, 5])                                       # Rank 0
        tensor([10, 11, 12, 13, 14, 15, 16, 17, 18])                     # Rank 1
        tensor([20, 21, 22, 23, 24])                                     # Rank 2
        tensor([30, 31, 32, 33, 34, 35, 36])                             # Rank 3
        >>> input_splits
        [2, 2, 1, 1]                                                     # Rank 0
        [3, 2, 2, 2]                                                     # Rank 1
        [2, 1, 1, 1]                                                     # Rank 2
        [2, 2, 2, 1]                                                     # Rank 3
        >>> output_splits
        [2, 3, 2, 2]                                                     # Rank 0
        [2, 2, 1, 2]                                                     # Rank 1
        [1, 2, 1, 2]                                                     # Rank 2
        [1, 2, 1, 1]                                                     # Rank 3
        >>> output = ...
        >>> dist.all_to_all_single(output, input, output_splits, input_splits)
        >>> output
        tensor([ 0,  1, 10, 11, 12, 20, 21, 30, 31])                     # Rank 0
        tensor([ 2,  3, 13, 14, 22, 32, 33])                             # Rank 1
        tensor([ 4, 15, 16, 23, 34, 35])                                 # Rank 2
        tensor([ 5, 17, 18, 24, 36])                                     # Rank 3

        >>> # Another example with tensors of torch.cfloat type.
        >>> input = torch.tensor(
        ...     [1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=torch.cfloat
        ... ) + 4 * rank * (1 + 1j)
        >>> input
        tensor([1+1j, 2+2j, 3+3j, 4+4j])                                # Rank 0
        tensor([5+5j, 6+6j, 7+7j, 8+8j])                                # Rank 1
        tensor([9+9j, 10+10j, 11+11j, 12+12j])                          # Rank 2
        tensor([13+13j, 14+14j, 15+15j, 16+16j])                        # Rank 3
        >>> # The output must share the input's dtype.
        >>> output = torch.empty([4], dtype=torch.cfloat)
        >>> dist.all_to_all_single(output, input)
        >>> output
        tensor([1+1j, 5+5j, 9+9j, 13+13j])                              # Rank 0
        tensor([2+2j, 6+6j, 10+10j, 14+14j])                            # Rank 1
        tensor([3+3j, 7+7j, 11+11j, 15+15j])                            # Rank 2
        tensor([4+4j, 8+8j, 12+12j, 16+16j])                            # Rank 3
    """
    # Dynamo has built-in logic to map legacy distributed ops to functional collectives.
    # Let's redirect to a torch function mode that can mimic this logic outside Dynamo
    # (e.g., non-strict export implements such a torch function mode).
    dispatch_args = (input,)
    if has_torch_function(dispatch_args):
        return handle_torch_function(
            all_to_all_single,
            dispatch_args,
            output,
            input,
            output_split_sizes=output_split_sizes,
            input_split_sizes=input_split_sizes,
            group=group,
            async_op=async_op,
        )

    if _rank_not_in_group(group):
        _warn_not_in_group("all_to_all_single")
        return None

    a2a_opts = AllToAllOptions()
    a2a_opts.asyncOp = async_op

    _check_single_tensor(output, "output")
    _check_single_tensor(input, "input")
    _ensure_all_tensors_same_dtype(output, input)

    # Complex tensors are handed to the backend as their real-valued views.
    if input.is_complex():
        input = torch.view_as_real(input)
    if output.is_complex():
        output = torch.view_as_real(output)

    # ``None`` split sizes are passed to the backend as empty lists.
    if output_split_sizes is None:
        output_split_sizes = []
    if input_split_sizes is None:
        input_split_sizes = []

    pg = group or _get_default_group()
    work = pg.alltoall_base(
        output, input, output_split_sizes, input_split_sizes, a2a_opts
    )

    if async_op:
        return work
    if work is not None:
        # Backward compatible with backends that don't sync at CPP level
        work.wait()
    # Otherwise, the backend has sync'ed at CPP level
    return None
@_exception_logger
def all_to_all(
    output_tensor_list, input_tensor_list, group=None, async_op: bool = False
):
    """
    Scatters list of input tensors to all processes in a group and return gathered list of tensors in output list.

    Complex tensors are supported: they are passed to the backend through
    their ``torch.view_as_real`` views. Input and output tensors are checked
    to have the same dtype.

    Args:
        output_tensor_list (list[Tensor]): List of tensors to be gathered one
            per rank.
        input_tensor_list (list[Tensor]): List of tensors to scatter one per rank.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op.

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group.

    .. warning::
        `all_to_all` is experimental and subject to change.

    Examples:
        >>> # xdoctest: +SKIP("Undefined rank")
        >>> input = torch.arange(4) + rank * 4
        >>> input = list(input.chunk(4))
        >>> input
        [tensor([0]), tensor([1]), tensor([2]), tensor([3])]     # Rank 0
        [tensor([4]), tensor([5]), tensor([6]), tensor([7])]     # Rank 1
        [tensor([8]), tensor([9]), tensor([10]), tensor([11])]   # Rank 2
        [tensor([12]), tensor([13]), tensor([14]), tensor([15])] # Rank 3
        >>> output = list(torch.empty([4], dtype=torch.int64).chunk(4))
        >>> dist.all_to_all(output, input)
        >>> output
        [tensor([0]), tensor([4]), tensor([8]), tensor([12])]    # Rank 0
        [tensor([1]), tensor([5]), tensor([9]), tensor([13])]    # Rank 1
        [tensor([2]), tensor([6]), tensor([10]), tensor([14])]   # Rank 2
        [tensor([3]), tensor([7]), tensor([11]), tensor([15])]   # Rank 3

        >>> # Essentially, it is similar to following operation:
        >>> scatter_list = input
        >>> gather_list = output
        >>> for i in range(world_size):
        ...     dist.scatter(gather_list[i], scatter_list if i == rank else [], src=i)

        >>> # Another example with uneven split
        >>> input
        tensor([0, 1, 2, 3, 4, 5])                                       # Rank 0
        tensor([10, 11, 12, 13, 14, 15, 16, 17, 18])                     # Rank 1
        tensor([20, 21, 22, 23, 24])                                     # Rank 2
        tensor([30, 31, 32, 33, 34, 35, 36])                             # Rank 3
        >>> input_splits
        [2, 2, 1, 1]                                                     # Rank 0
        [3, 2, 2, 2]                                                     # Rank 1
        [2, 1, 1, 1]                                                     # Rank 2
        [2, 2, 2, 1]                                                     # Rank 3
        >>> output_splits
        [2, 3, 2, 2]                                                     # Rank 0
        [2, 2, 1, 2]                                                     # Rank 1
        [1, 2, 1, 2]                                                     # Rank 2
        [1, 2, 1, 1]                                                     # Rank 3
        >>> input = list(input.split(input_splits))
        >>> input
        [tensor([0, 1]), tensor([2, 3]), tensor([4]), tensor([5])]                   # Rank 0
        [tensor([10, 11, 12]), tensor([13, 14]), tensor([15, 16]), tensor([17, 18])] # Rank 1
        [tensor([20, 21]), tensor([22]), tensor([23]), tensor([24])]                 # Rank 2
        [tensor([30, 31]), tensor([32, 33]), tensor([34, 35]), tensor([36])]         # Rank 3
        >>> output = ...
        >>> dist.all_to_all(output, input)
        >>> output
        [tensor([0, 1]), tensor([10, 11, 12]), tensor([20, 21]), tensor([30, 31])]   # Rank 0
        [tensor([2, 3]), tensor([13, 14]), tensor([22]), tensor([32, 33])]           # Rank 1
        [tensor([4]), tensor([15, 16]), tensor([23]), tensor([34, 35])]              # Rank 2
        [tensor([5]), tensor([17, 18]), tensor([24]), tensor([36])]                  # Rank 3

        >>> # Another example with tensors of torch.cfloat type.
        >>> input = torch.tensor(
        ...     [1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=torch.cfloat
        ... ) + 4 * rank * (1 + 1j)
        >>> input = list(input.chunk(4))
        >>> input
        [tensor([1+1j]), tensor([2+2j]), tensor([3+3j]), tensor([4+4j])]            # Rank 0
        [tensor([5+5j]), tensor([6+6j]), tensor([7+7j]), tensor([8+8j])]            # Rank 1
        [tensor([9+9j]), tensor([10+10j]), tensor([11+11j]), tensor([12+12j])]      # Rank 2
        [tensor([13+13j]), tensor([14+14j]), tensor([15+15j]), tensor([16+16j])]    # Rank 3
        >>> # The output tensors must share the inputs' dtype.
        >>> output = list(torch.empty([4], dtype=torch.cfloat).chunk(4))
        >>> dist.all_to_all(output, input)
        >>> output
        [tensor([1+1j]), tensor([5+5j]), tensor([9+9j]), tensor([13+13j])]          # Rank 0
        [tensor([2+2j]), tensor([6+6j]), tensor([10+10j]), tensor([14+14j])]        # Rank 1
        [tensor([3+3j]), tensor([7+7j]), tensor([11+11j]), tensor([15+15j])]        # Rank 2
        [tensor([4+4j]), tensor([8+8j]), tensor([12+12j]), tensor([16+16j])]        # Rank 3

    """
    if _rank_not_in_group(group):
        _warn_not_in_group("all_to_all")
        return None

    a2a_opts = AllToAllOptions()
    a2a_opts.asyncOp = async_op

    _check_tensor_list(output_tensor_list, "output_tensor_list")
    _check_tensor_list(input_tensor_list, "input_tensor_list")
    _ensure_all_tensors_same_dtype(output_tensor_list, input_tensor_list)

    # Complex tensors are handed to the backend as their real-valued views.
    real_inputs = []
    for src_tensor in input_tensor_list:
        if src_tensor.is_complex():
            src_tensor = torch.view_as_real(src_tensor)
        real_inputs.append(src_tensor)

    real_outputs = []
    for dst_tensor in output_tensor_list:
        if dst_tensor.is_complex():
            dst_tensor = torch.view_as_real(dst_tensor)
        real_outputs.append(dst_tensor)

    pg = group or _get_default_group()
    work = pg.alltoall(real_outputs, real_inputs, a2a_opts)

    if async_op:
        return work
    if work is not None:
        # Backward compatible with backends that don't sync at CPP level
        work.wait()
    # Otherwise, the backend has sync'ed at CPP level
    return None
@_exception_logger
def barrier(
    group: ProcessGroup | None = GroupMember.WORLD,
    async_op: bool = False,
    device_ids=None,
):
    """
    Synchronize all processes.

    The call blocks until every process of the group has entered it when
    ``async_op`` is False; otherwise the blocking happens when ``wait()`` is
    called on the returned work handle.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op
        device_ids ([int], optional): List of device/GPU ids. Only one id is
            expected; only the first entry is used to pick the device.

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    .. note:: `ProcessGroupNCCL` now blocks the cpu thread till the completion of the barrier collective.
    .. note:: `ProcessGroupNCCL` implements barrier as an all_reduce of a 1-element tensor. A device must be chosen
       for allocating this tensor. The device choice is made by checking in this order (1) the first device passed to
       `device_ids` arg of barrier if not None, (2) the device passed to init_process_group if not None, (3) the device
       that was first used with this process group, if another collective with tensor inputs has been performed, (4)
       the device index indicated by the global rank mod local device count.
    """
    group = group or _get_default_group()

    if _rank_not_in_group(group):
        _warn_not_in_group("barrier")
        return

    opts = BarrierOptions()
    opts.asyncOp = async_op

    # Accelerator detected on this machine; CPU is reported when there is none.
    accelerator = torch._C._get_accelerator()

    if isinstance(device_ids, list):
        # Explicit request from the caller: honour the first id only.
        opts.device_ids = device_ids
        # pyrefly: ignore [read-only]
        opts.device = torch.device(accelerator.type, device_ids[0])
    elif getattr(group, "bound_device_id", None) is not None:
        # Device bound through `init_process_group(device_id=...)`.
        opts.device = group.bound_device_id  # type: ignore[assignment]
    elif accelerator.type == "cpu" or _get_object_coll_device(group) == "cpu":
        # pyrefly: ignore [read-only]
        opts.device = torch.device("cpu")
    else:
        # Fall back to the device of the current context. When the user never
        # selected one this may be default device 0, which can lead to hangs
        # or to every process creating a context on device 0.
        # pyrefly: ignore [read-only]
        opts.device = accelerator
        if group.rank() == 0:
            warnings.warn(  # warn only once
                "barrier(): using the device under current context. "
                "You can specify `device_id` in `init_process_group` to mute this warning.",
                stacklevel=2,
            )

    work = group.barrier(opts=opts)

    if async_op:
        return work
    if work is not None:
        # Backward compatible with backends that don't sync at CPP level
        work.wait()
    # Otherwise, the backend has sync'ed at CPP level
def monitored_barrier(
    group: ProcessGroup | None = GroupMember.WORLD,
    timeout=None,
    wait_all_ranks: bool = False,
):
    """
    Synchronize processes similar to ``torch.distributed.barrier``, but consider a configurable timeout.

    It is able to report ranks that did not pass this barrier within the provided timeout.
    Non-zero ranks block until a send/recv is processed from rank 0. Rank 0 blocks
    until the send/recv from every other rank is processed, and reports failures for
    ranks that did not respond in time. Note that if one rank does not reach the
    monitored_barrier (for example due to a hang), all other ranks would fail in
    monitored_barrier.

    This collective will block all processes/ranks in the group, until the
    whole group exits the function successfully, making it useful for debugging
    and synchronizing. However, it can have a performance impact and should only
    be used for debugging or scenarios that require full synchronization points
    on the host-side. For debugging purposes, this barrier can be inserted
    before the application's collective calls to check if any ranks are
    desynchronized.

    .. note:: Note that this collective is only supported with the GLOO backend.

    Args:
        group (ProcessGroup, optional): The process group to work on. If
            ``None``, the default process group will be used.
        timeout (datetime.timedelta, optional): Timeout for monitored_barrier.
            If ``None``, the default process group timeout will be used.
        wait_all_ranks (bool, optional): Whether to collect all failed ranks or
            not. By default, this is ``False`` and ``monitored_barrier`` on rank 0
            will throw on the first failed rank it encounters in order to fail
            fast. By setting ``wait_all_ranks=True`` ``monitored_barrier`` will
            collect all failed ranks and throw an error containing information
            about all failed ranks.

    Returns:
        ``None``.

    Example::
        >>> # xdoctest: +SKIP("need process group init")
        >>> # Note: Process group initialization omitted on each rank.
        >>> import torch.distributed as dist
        >>> if dist.get_rank() != 1:
        >>>     dist.monitored_barrier()  # Raises exception indicating that
        >>> # rank 1 did not call into monitored_barrier.
        >>> # Example with wait_all_ranks=True
        >>> if dist.get_rank() == 0:
        >>>     dist.monitored_barrier(wait_all_ranks=True)  # Raises exception
        >>> # indicating that ranks 1, 2, ... world_size - 1 did not call into
        >>> # monitored_barrier.
    """
    # The membership check has to come before any other use of the group,
    # otherwise an "Invalid process group" error is raised.
    if _rank_not_in_group(group):
        _warn_not_in_group("monitored_barrier")
        return

    if get_backend(group) != Backend.GLOO:
        raise ValueError("monitored_barrier is only implemented for GLOO backend.")

    if timeout is None:
        timeout = _get_default_timeout(get_backend(group))
    elif isinstance(timeout, float):
        # TODO(whc) apparently some existing test case for monitored_barrier passes in a timeout in float format?
        warnings.warn(
            "Please specify timeout arg as a timedelta. "
            f"Converting current value of {timeout} assuming it represents seconds",
            stacklevel=2,
        )
        timeout = timedelta(seconds=timeout)

    _check_valid_timeout(timeout)

    target_pg = group if group is not None else _get_default_group()
    return target_pg.monitored_barrier(  # type:ignore[attr-defined]
        timeout, wait_all_ranks=wait_all_ranks
    )
def _create_process_group_wrapper(
    wrapped_pg: torch._C._distributed_c10d.Backend,
    store_prefix: str,
    store: Store,
    rank: int,
    world_size: int,
    timeout: timedelta = default_pg_timeout,
):
    """
    Wrap ``wrapped_pg`` in a ``_ProcessGroupWrapper`` backed by a Gloo helper group.

    The helper ``ProcessGroupGloo`` is created on its own ``PrefixStore`` (prefix
    ``PG_WRAPPER_STORE_PREFIX:<store_prefix>``) layered over ``store``.

    Raises:
        AssertionError: if Gloo is not available.
    """
    if not _GLOO_AVAILABLE:
        raise AssertionError("ProcessGroupWrapper unsupported without GLOO backend.")

    # (whc) this appears to be just for the gloo backend? if so, `default_pg_timeout` is appropriate...

    # The helper process group gets a separate prefix store of its own.
    helper_store = PrefixStore(f"{PG_WRAPPER_STORE_PREFIX}:{store_prefix}", store)
    helper_pg = ProcessGroupGloo(helper_store, rank, world_size, timeout=timeout)

    # Wrap the underlying pg with ProcessGroupWrapper.
    return _ProcessGroupWrapper(wrapped_pg, helper_pg)
# Helper that deterministically hashes a list of ranks to a unique string.
def _hash_ranks_to_str(ranks: list[int]) -> str:
    """Return the SHA-1 hex digest of the ranks joined with the current PG-name count."""
    joined_ranks = "_".join(str(rank) for rank in ranks)
    # The number of already-named PGs is appended in case there is already a
    # PG with the same rank composition.
    salted = f"{joined_ranks}_{len(_world.pg_names)}"
    return hashlib.sha1(salted.encode("utf-8"), usedforsecurity=False).hexdigest()
# Maps a list of ranks to an integer color.
def _process_group_color(ranks: list[int]) -> int:
    """Return a non-negative color that fits in a C ``int``, derived from ``hash(tuple(ranks))``."""
    # A list is not hashable, so hash the equivalent tuple.
    rank_hash = hash(tuple(ranks))
    # Split color must be:
    # - a non-negative integer;
    # - a type compatible with C's int because we are pybinding to the latter.
    # Hence the absolute hash value is reduced modulo 2**(bits in c_int - 1).
    c_int_bits = ctypes.sizeof(ctypes.c_int) * 8
    return abs(rank_hash) % (2 ** (c_int_bits - 1))
def _process_group_name(ranks, use_hashed_name) -> GroupName:
    """
    Create the name for a new process group.

    With ``use_hashed_name`` the name is the hash produced by
    ``_hash_ranks_to_str(ranks)``; otherwise it is the current
    ``_world.group_count`` as a string, and the counter is then incremented.
    """
    if use_hashed_name:
        return GroupName(_hash_ranks_to_str(ranks))

    pg_name = GroupName(str(_world.group_count))
    _world.group_count += 1
    # TODO: why is group count incremented only in the else path?
    return pg_name
def _get_backend_from_str(backend: str | None = None) -> Backend:
    """Convert ``backend`` to a ``Backend``; a falsy value selects the default group's backend."""
    # Default to the same backend as the global process group
    # if backend is not specified.
    chosen = backend if backend else get_backend(_get_default_group())
    return Backend(chosen)
def _is_safe_to_split() -> bool:
    """
    Check whether it is safe to split any process group in the world.

    Splitting is only safe when the default pg has a bound device id; otherwise
    users must be aware that a pg is only splittable after the first collective
    has been issued.
    """
    bound_device = _get_default_group().bound_device_id
    return bound_device is not None
@_time_logger
def split_group(
    parent_pg: ProcessGroup | None = None,
    split_ranks: list | None = None,
    timeout: timedelta | None = None,
    pg_options: Any | None = None,
    group_desc: str | None = None,
) -> ProcessGroup | None:
    """
    Create a new process group split from the given parent process group.

    .. warning::
        This is an experimental API. Only the ``NCCL`` and custom plugin backends
        are supported. Other backends will raise an error.

    Users of this API must guarantee that all ranks in the parent group enter this API call,
    and the split of the sub groups is the same across all ranks in the parent group.

    Args:
        parent_pg (ProcessGroup, optional): The parent process group. If None,
            the default process group will be used. Users need to guarantee that
            the parent group is fully initialized (e.g, communicators are initialized)
        split_ranks (list[list[int]]): the split ranks, which is a list of list of ranks.
            Users need to make sure the validity of the split ranks such that one
            split (represented by one inner list of ints) does not overlap with any other split.
            Note that the ranks in each split is the group rank (instead of global rank)
            in the parent pg. For example, if the parent group has 4 ranks, and split_ranks can be
            [[0, 1], [2, 3]]. Note [[0,1]] is also a valid split, in which case ranks 2, 3 would
            return a non-group member.
        timeout (timedelta, optional): see `init_process_group` for details and default value.
        pg_options (ProcessGroupOptions, optional): Additional options need to be passed in during
            the construction of specific process groups. i.e. ``is_high_priority_stream``
            can be specified so that process group can pick up high priority cuda streams.
        group_desc (str, optional): a string to describe the process group.

    Returns:
        ProcessGroup if the current rank is within one split/subgroup given by split_ranks,
        or None if the current rank is not part of any ``split_ranks``.

    Raises:
        ValueError: if ``split_ranks`` is None/empty, the parent group is not
            registered, the calling rank is not in the parent group, or an
            inspected split is empty, too large, or contains duplicates.
        RuntimeError: if the default pg has no bound device or the parent
            backend cannot be split (both only when torchcomms is not enabled).
    """
    # check inputs
    if split_ranks is None or len(split_ranks) == 0:
        raise ValueError("split_ranks cannot be None or empty")

    default_pg = _get_default_group()
    device_id = default_pg.bound_device_id
    if not (device_id or _use_torchcomms_enabled()):
        raise RuntimeError(
            "No device associated with the default pg, not safe to split any process groups"
        )

    global_rank = default_pg.rank()
    global_world_size = default_pg.size()

    parent_pg = parent_pg or default_pg
    if parent_pg not in _world.pg_group_ranks:
        raise ValueError(f"Group {parent_pg} is not registered")

    # Global rank -> rank inside the parent group, plus the inverse mapping.
    global_to_parent_rank = _world.pg_group_ranks[parent_pg]
    parent_to_global_rank = {
        rank_in_parent: rank_in_world
        for rank_in_world, rank_in_parent in global_to_parent_rank.items()
    }

    if global_rank not in global_to_parent_rank:
        raise ValueError(
            f"Global rank {global_rank} is not part of the parent group {parent_pg}"
        )
    parent_group_rank = global_to_parent_rank[global_rank]

    parent_backend = parent_pg._get_backend(torch.device("cuda"))

    # if the parent backend does not support splitting, raise error
    # currently this API only support NCCL backend
    splittable = parent_backend and parent_backend.supports_splitting
    if not splittable and not _use_torchcomms_enabled():
        raise RuntimeError(
            "No backend for the parent process group or its backend does not support splitting"
        )

    # set the group_desc before the color or no_color split
    if hasattr(parent_backend, "comm_split_count") and group_desc is None:
        group_desc = f"{parent_pg.group_desc}:split:{parent_backend.comm_split_count()}"  # type: ignore[attr-defined]

    # The split group uses the same type of backend as the parent process group.
    parent_backend_str, _ = _world.pg_map[parent_pg]
    backend = Backend(parent_backend_str)
    backend_config = BackendConfig(backend)

    # TODO: figure out pg option for torchComms
    if pg_options is None and not _use_torchcomms_enabled():
        # default pg_options same as the parent process group
        # A deep copy is needed because if the option will be modified inside split
        # and if we split parent pg multiple times, we will run into device out of bound error.
        pg_options = copy.deepcopy(parent_backend.options)

    # this timeout defaulting/validation is used for all the new_groups/new_subgroups variants,
    # which may just pass their timeout value (or None)
    if timeout is None:
        timeout = _get_default_timeout(backend)
    _check_valid_timeout(timeout)

    # Find my group of ranks in split_ranks. Ranks that belong to no split
    # simply pass in the first split group, and None will be returned.
    # Splits are validated in order until the one containing this rank is found.
    my_group = split_ranks[0]
    for candidate in split_ranks:
        if len(candidate) == 0:
            raise ValueError("the split group cannot be empty")
        if len(candidate) > global_world_size:
            raise ValueError(
                "the split group's size should be less or equal to the world_size set by init_process_group"
            )
        if len(candidate) != len(set(candidate)):
            raise ValueError("the split group cannot have duplicate ranks")
        ordered_candidate = sorted(candidate)
        if parent_group_rank in ordered_candidate:
            my_group = ordered_candidate
            break

    # use_hashed_name is True to ensure that subgroups have unique names.
    # This is needed as some backends (e.g. Gloo) use the group name as a
    # PrefixStore prefix for initialization of splits. Thus, names have to be
    # unique to avoid key collisions.
    group_name = _process_group_name(my_group, use_hashed_name=True)

    split_pg = parent_pg.split_group(
        my_group,
        timeout=timeout,
        opts=pg_options,
        group_name=group_name,
        group_desc=group_desc,
    )
    if split_pg is None:
        return None

    global_ranks_in_my_group = [
        parent_to_global_rank[rank_in_parent] for rank_in_parent in my_group
    ]
    split_pg.bound_device_id = device_id  # type: ignore[union-attr]

    split_backend_class = split_pg._get_backend(torch.device("cuda"))
    if not _use_torchcomms_enabled():
        split_backend_class._set_sequence_number_for_group()

    if split_pg.group_name != group_name:
        raise AssertionError(
            f"group name should be set to {group_name} but got {split_pg.group_name}"
        )

    # update global state
    _world.pg_map[split_pg] = (backend, split_pg.get_group_store())
    _world.pg_names[split_pg] = group_name
    _register_process_group(group_name, split_pg)
    _world.pg_backend_config[split_pg] = str(backend_config)

    pg_tag = f"ptd:{group_name}"
    _world.tags_to_pg.setdefault(pg_tag, []).append(split_pg)
    _world.pg_to_tag[split_pg] = pg_tag

    # Create the global rank to group rank mapping
    _world.pg_group_ranks[split_pg] = {
        rank_in_world: rank_in_split
        for rank_in_split, rank_in_world in enumerate(global_ranks_in_my_group)
    }

    if _use_torchcomms_enabled():
        # pyrefly: ignore [missing-attribute]
        _world.comms.append(split_backend_class.get_comm())

    return split_pg
@_time_logger
def new_group(
    ranks=None,
    timeout=None,
    backend=None,
    pg_options=None,
    use_local_synchronization: bool = False,
    group_desc=None,
    device_id: torch.device | None = None,
):
    """
    Create a new distributed group.

    This function requires that all processes in the main group (i.e. all
    processes that are part of the distributed job) enter this function, even
    if they are not going to be members of the group. Additionally, groups
    should be created in the same order in all processes.

    .. warning::
        Safe concurrent usage:
        When using multiple process groups with the ``NCCL`` backend, the user
        must ensure a globally consistent execution order of collectives across
        ranks.

        If multiple threads within a process issue collectives, explicit
        synchronization is necessary to ensure consistent ordering.

        When using async variants of torch.distributed communication APIs,
        a work object is returned and the communication kernel is
        enqueued on a separate CUDA stream, allowing overlap of communication
        and computation. Once one or more async ops have been issued on one process
        group, they must be synchronized with other cuda streams by calling `work.wait()`
        before using another process group.

        See `Using multiple NCCL communicators concurrently
        <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#using-multiple-nccl-communicators-concurrently>`_
        for more details.

    Args:
        ranks (list[int]): List of ranks of group members. If ``None``, will be
            set to all ranks. Default is ``None``.
        timeout (timedelta, optional): see `init_process_group` for details and default value.
        backend (str or Backend, optional): The backend to use. Depending on
            build-time configurations, valid values are ``gloo`` and ``nccl``.
            By default uses the same backend as the global group. This field
            should be given as a lowercase string (e.g., ``"gloo"``), which can
            also be accessed via :class:`Backend` attributes (e.g.,
            ``Backend.GLOO``). If ``None`` is passed in, the backend
            corresponding to the default process group will be used. Default is
            ``None``.
        pg_options (ProcessGroupOptions, optional): process group options
            specifying what additional options need to be passed in during
            the construction of specific process groups. i.e. for the ``nccl``
            backend, ``is_high_priority_stream`` can be specified so that
            process group can pick up high priority cuda streams. For other
            available options to config nccl, see
            https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t
        use_local_synchronization (bool, optional): perform a group-local
            barrier at the end of the process group creation. This is different
            in that non-member ranks don't need to call into API and don't
            join the barrier.
        group_desc (str, optional): a string to describe the process group.
        device_id (torch.device, optional): a single, specific device
            to "bind" this process to. The `new_group` call will try to initialize
            a communication backend immediately for the device if this field is given.

    Returns:
        A handle of distributed group that can be given to collective calls or
        GroupMember.NON_GROUP_MEMBER if the rank is not part of ``ranks``.

    N.B. use_local_synchronization doesn't work with MPI.

    N.B. While use_local_synchronization=True can be significantly faster with larger
    clusters and small process groups, care must be taken since it changes cluster behavior
    as non-member ranks don't join the group barrier().

    N.B. use_local_synchronization=True can lead to deadlocks when each rank creates
    multiple overlapping process groups. To avoid that, make sure all ranks follow the
    same global creation order.
    """
    # Plain delegation: the public API never sets a tag (pg_tag=None).
    return _new_group_with_tag(
        ranks=ranks,
        timeout=timeout,
        backend=backend,
        backend_options=pg_options,
        pg_tag=None,
        use_local_synchronization=use_local_synchronization,
        group_desc=group_desc,
        device_id=device_id,
    )
def _new_group_with_tag(
    ranks=None,
    timeout=None,
    backend=None,
    backend_options=None,
    pg_tag=None,
    use_local_synchronization=False,
    group_desc=None,
    device_id: torch.device | None = None,
):
    """
    Variant of ``new_group`` that exposes tag creation.

    :: N.B. The mechanism is experimental and tied to the functional collectives effort, see
    ``torch.distributed._functional_collectives`` for reference on how to use it.

    Returns the new process group, or ``None`` when
    ``use_local_synchronization`` is set and the calling rank is not listed in
    ``ranks``.
    """
    default_pg = _get_default_group()
    default_bound_device = default_pg.bound_device_id

    # Inherit the default pg's bound device, or insist that an explicit one matches it.
    if device_id is None:
        device_id = default_bound_device
    elif default_bound_device is not None and device_id != default_bound_device:
        raise AssertionError(
            "Mismatched bound device between new pg and the default pg."
        )

    default_backend, default_store = _world.pg_map[default_pg]
    global_rank = default_pg.rank()
    global_world_size = default_pg.size()

    # Default to the same backend as the global process group
    # if the backend is not specified.
    backend = Backend(backend or default_backend)

    # this timeout defaulting/validation is used for all the new_groups/new_subgroups variants,
    # which may just pass their timeout value (or None)
    if timeout is None:
        timeout = _get_default_timeout(backend)
    _check_valid_timeout(timeout)

    if use_local_synchronization:
        # MPI backend doesn't have a way for us to perform a partial sync
        if backend == Backend.MPI:
            raise ValueError(
                "MPI backend doesn't support use_local_synchronization=True"
            )
        # With local synchronization, non-member ranks bail out immediately.
        if ranks is not None and get_rank() not in ranks:
            return None

    if ranks is None:
        # No explicit ranks: the new group spans the whole world.
        ranks = list(range(global_world_size))
        group_world_size = global_world_size
        group_rank = global_rank
    else:
        # checks the input ranks
        ranks = sorted(ranks)
        group_world_size = len(ranks)
        if group_world_size > global_world_size:
            raise ValueError(
                "the new group's world size should be less or "
                "equal to the world size set by "
                "init_process_group"
            )
        # check ranks' sanity
        if any(rank < 0 or rank >= global_world_size for rank in ranks):
            raise ValueError(
                "The new group's rank should be within "
                "the world_size set by init_process_group"
            )
        group_rank = ranks.index(global_rank) if global_rank in ranks else None

    group_name = _process_group_name(ranks, use_hashed_name=use_local_synchronization)

    pg, pg_store = _new_process_group_helper(
        group_world_size,
        group_rank,
        ranks,
        backend,
        default_store,
        group_name,
        backend_options=backend_options,
        timeout=timeout,
        pg_tag=pg_tag,
        device_id=device_id,
        group_desc=group_desc,
    )

    # Create the global rank to group rank mapping
    _world.pg_group_ranks[pg] = {
        rank_in_world: rank_in_group for rank_in_group, rank_in_world in enumerate(ranks)
    }

    if _is_barrier_after_init() == 1:
        # barrier at the end to ensure that once we return from this method, all
        # process groups including global variables (if any) are updated
        # correctly on all ranks.
        # Update 04/2023: for large-scale runs, this barrier (esp. store-based
        # barrier) may be costly and/or unscalable. Also, in a lot of cases,
        # these barriers may be unnecessary, as proven by a green CI after
        # removal. An environment variable `TORCH_DIST_INIT_BARRIER` has been
        # added which enables this barrier only when set to 1.
        logger.info(
            "Performing barrier after ProcessGroup initialization since "
            "TORCH_DIST_INIT_BARRIER = 1"
        )
        if backend == Backend.MPI:
            # MPI doesn't have store.
            barrier()
        else:
            if use_local_synchronization:
                barrier_store = pg_store
                world_size = len(ranks)
            else:
                barrier_store = default_store
                world_size = get_world_size()
            # Use store based barrier here since barrier() used a bunch of
            # default devices and messes up NCCL internal state.
            _store_based_barrier(
                global_rank, barrier_store, group_name, world_size, timeout
            )

    return pg
def new_subgroups(
    group_size=None,
    group=None,
    timeout=None,
    backend=None,
    pg_options=None,
    group_desc=None,
):
    """
    Create subgroups of equal size.

    By default, it creates intra-machine subgroups,
    where each of which contains all the ranks of a machine, based on the assumption
    that each machine has the same number of devices.

    This is a convenience API that calls ``new_group`` to generate multiple subgroups.
    It requires that all processes in the main group (i.e. all
    processes that are part of the distributed job) enter this function, even
    if they are not going to be members of the group.

    .. warning::
        If ``group_size`` is passed in, the world size must be divisible by ``group_size``.
        If no ``group_size`` is passed in, it is assumed that you are creating a group based
        on CUDA, and the group size is determined by the number of CUDA devices. If not all
        the machines have the same number of devices, the subgroup division will be
        different across nodes and can cause unexpected behaviors. Therefore, if you are
        creating a subgroup that does not depend on CUDA (such as Gloo on CPU), please
        pass in ``group_size`` correctly.

    .. warning::
        See warning `Safe concurrent usage` for `new_group` API for important details about
        using multiple process groups concurrently in a safe manner.

    Args:
        group_size (int, optional): The size of each subgroup. If ``None``,
            the default subgroup size is equal to the number of devices on each machine,
            based on the assumption that each machine has exactly the same
            number of devices. Default is ``None``.
        group (ProcessGroup, optional): The process group to work on. If
            ``None``, the default process group will be used. Default is ``None``.
        timeout (timedelta, optional): see `init_process_group` for details and default value.
        backend (str or Backend, optional): The backend to use. Depending on
            build-time configurations, valid values are ``gloo`` and ``nccl``.
            By default uses the same backend as the global group. This field
            should be given as a lowercase string (e.g., ``"gloo"``), which can
            also be accessed via :class:`Backend` attributes (e.g.,
            ``Backend.GLOO``). If ``None`` is passed in, the backend
            corresponding to the default process group will be used. Default is
            ``None``.
        pg_options (ProcessGroupOptions, optional): process group options
            specifying what additional options need to be passed in during
            the construction of specific process groups. i.e. for the ``nccl``
            backend, ``is_high_priority_stream`` can be specified so that
            process group can pick up high priority cuda streams.
        group_desc (str, optional): A string describing the group. Each subgroup will
            inherit its group_desc

    Returns:
        The subgroup containing the current rank, and all the subgroups used for cleanup.

    Raises:
        ValueError: if ``group_size`` is omitted while CUDA is unavailable, is
            not positive, exceeds the world size, or does not divide it.

    Examples:
        >>> # Create intra-machine subgroups.
        >>> # xdoctest: +SKIP("need process group init")
        >>> cur_subgroup, subgroups = dist.new_subgroups()
        >>> # Allreduce within the machine.
        >>> rank = dist.get_rank()
        >>> tensor = torch.ones(1, device=rank) * rank
        >>> dist.all_reduce(tensor, group=cur_subgroup)
        >>> tensor
        tensor([28])  # Assume 8 CUDA devices per machine.  28 is sum(range(8)).
        >>> # Cleanup.
        >>> for subgroup in subgroups:
        >>>     dist.destroy_process_group(subgroup)
    """
    if group_size is None:
        if not torch.cuda.is_available():
            # Each literal ends with a space so the concatenated sentences
            # do not run into each other.
            raise ValueError(
                "Default group size only takes effect when CUDA is available. "
                "If your subgroup uses a backend that does not depend on CUDA, "
                "please pass in 'group_size' correctly."
            )
        group_size = torch.cuda.device_count()

    if group_size <= 0:
        raise ValueError(f"The arg 'group_size' ({group_size}) must be positive")

    world_size = get_world_size(group=group)
    if world_size < group_size:
        raise ValueError(
            f"The arg 'group_size' ({group_size}) must not exceed the world size ({world_size})"
        )
    if world_size % group_size != 0:
        raise ValueError(
            f"The world size ({world_size}) must be divisible by '{group_size=}'"
        )

    # TODO: Use itertools.batched(get_process_group_ranks(group=group), group_size) instead when Python 3.12 is supported.
    ranks = get_process_group_ranks(group=group)
    # Consecutive slices of `group_size` ranks, one slice per subgroup.
    ranks_per_subgroup_list = [
        ranks[i : i + group_size] for i in range(0, len(ranks), group_size)
    ]
    return new_subgroups_by_enumeration(
        ranks_per_subgroup_list,
        timeout=timeout,
        backend=backend,
        pg_options=pg_options,
        group_desc=group_desc,
    )
def new_subgroups_by_enumeration(
    ranks_per_subgroup_list,
    timeout=None,
    backend=None,
    pg_options=None,
    group_desc=None,
):
    """
    Create subgroups by dividing the global world.

    The division is specified by a nested list of ranks. The subgroups cannot have
    overlap, and some ranks may not have to be in any subgroup.

    This is a convenience API that calls ``new_group`` to generate multiple subgroups.
    It requires that all processes in the main group (i.e. all
    processes that are part of the distributed job) enter this function, even
    if they are not going to be members of the group.

    .. warning::
        See warning `Safe concurrent usage` for `new_group` API for important details about
        using multiple process groups concurrently in a safe manner.

    Args:
        ranks_per_subgroup_list (list[list[int]]): A nested list of ranks of
            group members.
        timeout (timedelta, optional): see `init_process_group` for details and default value.
        backend (str or Backend, optional): The backend to use. Depending on
            build-time configurations, valid values are ``gloo`` and ``nccl``.
            By default uses the same backend as the global group. This field
            should be given as a lowercase string (e.g., ``"gloo"``), which can
            also be accessed via :class:`Backend` attributes (e.g.,
            ``Backend.GLOO``). If ``None`` is passed in, the backend
            corresponding to the default process group will be used. Default is
            ``None``.
        pg_options (ProcessGroupOptions, optional): process group options
            specifying what additional options need to be passed in during
            the construction of specific process groups. i.e. for the ``nccl``
            backend, ``is_high_priority_stream`` can be specified so that
            process group can pick up high priority cuda streams.
        group_desc (str, optional): A string describing the group. Each subgroup will
            inherit its group_desc.

    Returns:
        The subgroup containing the current rank, and all the subgroups used for cleanup.

    Raises:
        ValueError: if ``ranks_per_subgroup_list`` is None or empty, or if a
            rank is listed more than once. The overlap check runs before any
            process group is created.

    Examples:
        >>> # Create two subgroups, where each has 2 processes.
        >>> # xdoctest: +SKIP("need process group init")
        >>> cur_subgroup, subgroups = dist.new_subgroups_by_enumeration(
        ...     ranks_per_subgroup_list=[[0, 2], [1, 3]]
        ... )
        >>> rank = dist.get_rank()
        >>> tensor = torch.ones(1, device=rank) * rank
        >>> dist.all_reduce(tensor, group=cur_subgroup)
        >>> tensor
        tensor([2])     # Subgroup 0: ranks 0 and 2
        tensor([4])     # Subgroup 1: ranks 1 and 3
    """
    if ranks_per_subgroup_list is None or len(ranks_per_subgroup_list) == 0:
        raise ValueError("The arg 'ranks_per_subgroup_list' cannot be empty")

    # Reject overlapping subgroups up front, so that an invalid layout does not
    # leave already-created process groups behind.
    # Maps each rank to the subgroup rank list in which it was first seen.
    rank_to_ranks_dict = {}  # type: ignore[var-annotated]
    for ranks in ranks_per_subgroup_list:
        for rank in ranks:
            if rank in rank_to_ranks_dict:
                raise ValueError(
                    f"Rank {rank} has appeared in both subgroup {rank_to_ranks_dict[rank]} and {ranks}"
                )
            rank_to_ranks_dict[rank] = ranks

    # The calling rank does not change between iterations; look it up once.
    my_rank = get_rank()

    subgroups = []
    cur_subgroup = None
    for ranks in ranks_per_subgroup_list:
        subgroup = new_group(
            ranks=ranks,
            timeout=timeout,
            backend=backend,
            pg_options=pg_options,
            group_desc=group_desc,
        )
        subgroups.append(subgroup)
        if my_rank in ranks:
            cur_subgroup = subgroup
            logger.info("Rank %s is assigned to subgroup %s", my_rank, ranks)

    return cur_subgroup, subgroups
def _find_pg_by_ranks_and_tag(tag: str, ranks: list[int]) -> ProcessGroup | None:
    """
    Look up a process group registered under ``tag`` that matches ``ranks``.

    A non-empty tag without a ``ptd:`` or ``user:`` prefix is treated as a
    ``user:`` tag. A group matches when its size equals ``len(ranks)`` and
    every entry of ``ranks`` is among its ranks. Returns the first match, or
    ``None`` when there is none.
    """
    if tag and not tag.startswith(("ptd:", "user:")):
        tag = f"user:{tag}"

    for candidate in _world.tags_to_pg.get(tag, []):
        if candidate.size() == len(ranks):
            member_ranks = get_process_group_ranks(candidate)
            if all(rank in member_ranks for rank in ranks):
                return candidate
    return None
def _find_or_create_pg_by_ranks_and_tag(
    tag: str, ranks: list[int], stride: int
) -> ProcessGroup:
    """
    Return the process group for this rank's slice of ``ranks``, creating it if needed.

    ``ranks`` is cut into consecutive slices of ``stride`` entries and the slice
    containing the current rank is selected (the last such slice if several
    contain it). An existing group with that tag and those ranks is reused;
    otherwise a new tagged group is created.

    Raises:
        ValueError: if ``len(ranks)`` is not divisible by ``stride``, or a new
            group would have to be created with an empty tag.
        AssertionError: if no slice contains the current rank.
    """
    if len(ranks) % stride != 0:
        raise ValueError(
            f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})"
        )

    my_rank = get_rank()

    if stride == len(ranks):
        # A single slice covering the whole list.
        rank_sets = [ranks.copy()]
    else:
        rank_sets = [
            ranks[start : start + stride] for start in range(0, len(ranks), stride)
        ]

    my_ranks = None
    for rank_set in rank_sets:
        if my_rank in rank_set:
            my_ranks = rank_set
    if my_ranks is None:
        raise AssertionError("rankset doesn't include the current node")

    my_ranks = sorted(my_ranks)

    existing_pg = _find_pg_by_ranks_and_tag(tag, my_ranks)
    if existing_pg is not None:
        return existing_pg
    if tag == "":
        raise ValueError("Cannot automatically create PG with empty tag")
    # TODO copy settings and timeout from default PG
    return _new_group_with_tag(my_ranks, pg_tag=tag)
def _get_group_tag(pg: ProcessGroup) -> str:
    """Look up the tag recorded for ``pg`` and drop a leading ``"user:"`` prefix, if any."""
    return _world.pg_to_tag[pg].removeprefix("user:")
def _get_process_group_name(pg: ProcessGroup) -> str:
    """Return the name recorded for ``pg``, or the string ``"None"`` when it has no entry."""
    registered_names = _world.pg_names
    return registered_names.get(pg, "None")
def _get_process_group_store(pg: ProcessGroup) -> Store:
    """Return the store half of the ``(backend name, store)`` pair kept in ``_world.pg_map``."""
    _backend_name, store = _world.pg_map[pg]
    return store
# Values accepted by the ``shrink_flags`` argument of ``shrink_group``.
# Plain shrink with no extra handling.
SHRINK_DEFAULT = 0x00
# Attempt to terminate ongoing operations in the parent communicator first.
SHRINK_ABORT = 0x01
@_time_logger
def shrink_group(
    ranks_to_exclude: list[int],
    group: ProcessGroup | None = None,
    shrink_flags: int = SHRINK_DEFAULT,
    pg_options: Any | None = None,
) -> ProcessGroup:
    """
    Build a smaller process group by dropping ranks from an existing one.

    The returned group contains exactly the ranks of the original group that
    are not listed in ``ranks_to_exclude``.

    Args:
        ranks_to_exclude (List[int]): Ranks of the original ``group`` that
            must not be part of the new group.
        group (ProcessGroup, optional): Group to shrink. ``None`` (the
            default) selects the default process group.
        shrink_flags (int, optional): Either ``SHRINK_DEFAULT`` (default) or
            ``SHRINK_ABORT``. With ``SHRINK_ABORT`` the backend attempts to
            terminate ongoing operations in the parent communicator before
            shrinking.
        pg_options (ProcessGroupOptions, optional): Backend-specific options
            for the shrunken group. They are handed to the backend's
            ``shrink`` call unchanged; when omitted, the new group inherits
            defaults from the parent.

    Returns:
        ProcessGroup: the new group made of the remaining ranks. When the
        default group is shrunk, the returned group is installed as the new
        default group.

    Raises:
        TypeError: if the group's backend does not support shrinking, if
            ``ranks_to_exclude`` is not a list, or if it holds a non-integer.
        ValueError: if ``ranks_to_exclude`` is invalid (empty, out of bounds,
            duplicates, or excludes all ranks), if ``shrink_flags`` is not a
            known flag, or if the group has fewer than two ranks.
        RuntimeError: if an excluded rank calls this function or the backend
            fails the operation.

    Notes:
        - Only ranks that stay in the group should call this function;
          excluded ranks must not take part in the shrink.
        - Shrinking the default group destroys every other process group,
          because the rank reassignment leaves them inconsistent.
    """
    # Cheap argument checks come first, before any group is looked up.
    _validate_shrink_inputs(ranks_to_exclude, shrink_flags)

    # Resolve the group being shrunk and snapshot its size, rank and name.
    shrink_target = _prepare_shrink_target_group(group)

    # Fetch the group's backend and make sure it is able to shrink.
    parent_backend = _validate_shrink_backend_requirements(shrink_target)

    # Check each excluded rank against the resolved group.
    excluded = _validate_and_process_excluded_ranks(ranks_to_exclude, shrink_target)

    # Backend-specific shrink; it receives the excluded ranks in ascending order.
    shrunk_backend = parent_backend.shrink(sorted(excluded), shrink_flags, pg_options)

    # Record the caller's options on the group info, then swap the old group
    # for the new one in the global bookkeeping.
    shrink_target["pg_options_override"] = pg_options
    return _finalize_shrunk_group(shrink_target, excluded, shrunk_backend)
- def _validate_shrink_inputs(ranks_to_exclude: list[int], shrink_flags: int) -> None:
- """Validate input parameters for shrink_group."""
- if not isinstance(ranks_to_exclude, list):
- raise TypeError(
- f"ranks_to_exclude must be a list, but got {type(ranks_to_exclude).__name__}. "
- f"Example: [1, 3, 5] to exclude ranks 1, 3, and 5."
- )
- if not ranks_to_exclude:
- raise ValueError(
- "ranks_to_exclude cannot be empty. To shrink a group, you must specify at least "
- "one rank to exclude. Example: [failed_rank_id]"
- )
- # Validate shrink_flags with clear explanation of valid values
- valid_flags = [SHRINK_DEFAULT, SHRINK_ABORT]
- if not isinstance(shrink_flags, int) or shrink_flags not in valid_flags:
- raise ValueError(
- f"Invalid shrink_flags value: {shrink_flags}. Must be one of: "
- f"SHRINK_DEFAULT ({SHRINK_DEFAULT}) or SHRINK_ABORT ({SHRINK_ABORT}). "
- f"Use SHRINK_ABORT to abort ongoing operations before shrinking."
- )
def _prepare_shrink_target_group(group: ProcessGroup | None) -> dict:
    """
    Resolve the group to shrink and snapshot the properties later steps need.

    Args:
        group: group to shrink; ``None`` selects the default process group.

    Returns:
        dict: with keys ``process_group``, ``is_default_group``,
        ``group_size``, ``current_rank`` (the value of ``rank()`` on the
        group) and ``group_name``.

    Raises:
        ValueError: if the resolved group has fewer than two ranks.
    """
    if group is None:
        target_pg = _get_default_group()
    else:
        target_pg = group

    # Query the group once here so later steps can read plain values.
    group_size = int(target_pg.size())
    is_default_group = target_pg == _get_default_group()
    current_rank = target_pg.rank()
    group_name = _get_process_group_name(target_pg)

    # A group with a single rank has nothing that could be removed.
    if group_size <= 1:
        raise ValueError(
            f"Cannot shrink a process group with size {group_size}. "
            "Group must have at least 2 ranks to support shrinking."
        )

    return {
        "process_group": target_pg,
        "is_default_group": is_default_group,
        "group_size": group_size,
        "current_rank": current_rank,
        "group_name": group_name,
    }
- def _validate_shrink_backend_requirements(group_info: dict) -> Any:
- """Return the backend implementation for the target group or raise if unsupported."""
- target_pg = group_info["process_group"]
- group_name = group_info["group_name"]
- # Get the group's backend directly via ProcessGroup API. Prefer a bound device if present,
- # otherwise try CUDA then fall back to CPU.
- try:
- preferred_device = getattr(target_pg, "bound_device_id", None)
- if preferred_device is not None:
- backend_impl = target_pg._get_backend(preferred_device)
- else:
- # Try CUDA first if available, else CPU
- try:
- backend_impl = target_pg._get_backend(torch.device("cuda"))
- except Exception:
- backend_impl = target_pg._get_backend(torch.device("cpu"))
- except RuntimeError as e:
- raise RuntimeError(
- f"Cannot access device backend for process group '{group_name}'. "
- f"Ensure the process group was initialized with a compatible device backend and devices are available."
- ) from e
- try:
- supports = bool(backend_impl.supports_shrinking)
- except Exception:
- supports = False
- if not supports:
- raise TypeError(
- f"Process group backend for '{group_name}' does not support shrinking operations."
- )
- return backend_impl
- def _validate_and_process_excluded_ranks(
- ranks_to_exclude: list[int], group_info: dict
- ) -> set:
- """Validate excluded ranks and convert to set for efficient operations."""
- group_size = group_info["group_size"]
- current_rank = group_info["current_rank"]
- # Use set for O(1) duplicate detection and membership testing
- excluded_ranks_set = set()
- # Validate each rank with detailed error messages
- for i, rank in enumerate(ranks_to_exclude):
- if not isinstance(rank, int):
- raise TypeError(
- f"All elements in ranks_to_exclude must be integers. "
- f"Element at index {i} is {type(rank).__name__}: {rank}"
- )
- if not (0 <= rank < group_size):
- raise ValueError(
- f"Rank {rank} at index {i} is out of bounds for group size {group_size}. "
- f"Valid ranks are in range [0, {group_size - 1}]."
- )
- if rank in excluded_ranks_set:
- raise ValueError(
- f"Duplicate rank {rank} found in ranks_to_exclude at index {i}. "
- f"Each rank can only be excluded once."
- )
- excluded_ranks_set.add(rank)
- # Ensure we don't exclude all ranks
- if len(excluded_ranks_set) >= group_size:
- raise ValueError(
- f"Cannot exclude all {group_size} ranks from process group. "
- f"At least one rank must remain. Excluding {len(excluded_ranks_set)} ranks."
- )
- # Critical check: current rank should not be in excluded list
- if current_rank in excluded_ranks_set:
- raise RuntimeError(
- f"Current rank {current_rank} is in the exclusion list and should not call shrink_group(). "
- f"Only non-excluded ranks should participate in the shrinking operation. "
- f"Excluded ranks should terminate their processes instead."
- )
- return excluded_ranks_set
def _finalize_shrunk_group(
    group_info: dict, excluded_ranks_set: set, new_backend
) -> ProcessGroup:
    """
    Retire the original group and register the shrunken replacement.

    Steps, in order: destroy all other groups when the default group is being
    shrunk, capture the old group's metadata and rank list, clean up the old
    group, wrap ``new_backend`` in a new process group, install it as the
    default group if applicable, and record it in the global state.

    Args:
        group_info: dict providing ``process_group`` and ``is_default_group``.
        excluded_ranks_set: ranks removed by the shrink.
        new_backend: backend object produced by the backend's ``shrink`` call.

    Returns:
        ProcessGroup: the newly created group.
    """
    old_pg = group_info["process_group"]
    shrinking_default = group_info["is_default_group"]

    if shrinking_default:
        _destroy_all_other_groups(exclude_group=old_pg)

    # Everything read from the old group has to be captured before it is
    # cleaned up below.
    metadata = _extract_group_metadata(old_pg)
    # NOTE(review): excluded_ranks_set is filtered directly against the list
    # returned by get_process_group_ranks - confirm both use the same rank
    # numbering when a non-default group is shrunk.
    survivors = [
        member
        for member in get_process_group_ranks(old_pg)
        if member not in excluded_ranks_set
    ]

    _cleanup_original_group(old_pg, shrinking_default)

    new_pg = _create_shrunk_process_group(
        new_backend, survivors, metadata, shrinking_default
    )

    if shrinking_default:
        _update_default_pg(new_pg)

    # Surviving ranks are renumbered 0..n-1 in their original order.
    _update_process_group_global_state(
        pg=new_pg,
        backend_name=metadata["backend_name"],
        store=metadata["store"],
        group_name=metadata["new_group_name"],
        backend_config=metadata["backend_config"],
        rank_mapping=dict(zip(survivors, range(len(survivors)))),
    )
    return new_pg
def _extract_group_metadata(target_pg: ProcessGroup) -> dict:
    """
    Snapshot what is needed from a group before it is cleaned up.

    Args:
        target_pg: the group about to be replaced.

    Returns:
        dict: with keys ``backend_name``, ``store``, ``backend_config``
        (``""`` when none is recorded), ``original_group_name``,
        ``new_group_name`` and ``bound_device_id`` (``None`` when the group
        has no such attribute).
    """
    backend_name, store = _world.pg_map[target_pg]
    backend_config = _world.pg_backend_config.get(target_pg, "")
    old_name = _get_process_group_name(target_pg)

    # Read the device binding now so nothing has to touch the group after it
    # has been destroyed.
    bound_device_id = getattr(target_pg, "bound_device_id", None)

    # The new name is a hashed name built from the ranks the group holds at
    # this point, i.e. before any rank has been removed.
    current_ranks = list(get_process_group_ranks(target_pg))
    new_name = _process_group_name(current_ranks, use_hashed_name=True)

    return {
        "backend_name": backend_name,
        "store": store,
        "backend_config": backend_config,
        "original_group_name": old_name,
        "new_group_name": new_name,
        "bound_device_id": bound_device_id,
    }
def _cleanup_original_group(target_pg: ProcessGroup, is_default_group: bool) -> None:
    """
    Destroy the group being replaced and drop its global bookkeeping.

    A failure inside ``destroy_process_group`` is logged as a warning and not
    re-raised; the global-state cleanup runs either way.

    Args:
        target_pg: the group to destroy.
        is_default_group: only used to word the warning message.
    """
    try:
        destroy_process_group(target_pg)
    except Exception:
        logger.warning(
            "Failed to destroy %s group during shrinking",
            "default" if is_default_group else "non-default",
            exc_info=True,
        )

    # Runs unconditionally so no stale entries survive a failed destroy.
    _cleanup_process_group_global_state(target_pg)
def _create_shrunk_process_group(
    new_backend, remaining_ranks: list[int], metadata: dict, is_default_group: bool
) -> ProcessGroup:
    """
    Wrap the backend returned by a shrink in a new ``ProcessGroup``.

    Args:
        new_backend: backend object for the shrunken communicator; its
            ``rank()`` and ``size()`` become the rank and size of the new group.
        remaining_ranks: ranks that survive the shrink (not read by this
            function).
        metadata: dict as produced by ``_extract_group_metadata``.
        is_default_group: selects the group description
            (``"default:shrunken"`` vs ``"<original name>:shrunk"``).

    Returns:
        ProcessGroup: the new group with backend, name, description and, when
        the original was bound to a device, the same device binding.
    """
    rank_in_group = new_backend.rank()
    world_size = new_backend.size()
    group_name = metadata["new_group_name"]

    group_desc = (
        "default:shrunken"
        if is_default_group
        else f"{metadata['original_group_name']}:shrunk"
    )

    # The new group gets its own prefixed view on a clone of the parent store.
    parent_store = metadata["store"].clone()
    new_pg = ProcessGroup(
        PrefixStore(f"{group_name}/", parent_store), rank_in_group, world_size
    )

    # Register the backend for the original group's bound device, or for CPU
    # when the original group had none.
    bound_device_id = metadata.get("bound_device_id")
    backend_device = torch.device("cpu") if bound_device_id is None else bound_device_id
    backend_type = (
        ProcessGroup.BackendType.NCCL
        if backend_device.type == "cuda"
        else ProcessGroup.BackendType.GLOO
    )
    new_pg._register_backend(backend_device, backend_type, new_backend)
    new_pg._set_default_backend(backend_type)

    # Carry the device binding over only when the original group had one.
    if bound_device_id is not None:
        new_pg.bound_device_id = bound_device_id

    new_pg._set_group_name(group_name)
    new_pg._set_group_desc(group_desc)

    # Keep the recorded backend configuration attached to the new group.
    backend_config = metadata.get("backend_config")
    if backend_config is not None:
        _world.pg_backend_config[new_pg] = backend_config

    return new_pg
def _destroy_all_other_groups(exclude_group: ProcessGroup | None = None) -> None:
    """
    Abort every known process group except ``exclude_group`` and drop its global state.

    Used when the default group is shrunk: the shrink reassigns global ranks,
    which leaves every other existing group inconsistent.

    Groups are torn down with ``_abort_process_group`` instead of a collective
    destroy, since excluded ranks may not take part in collective operations.
    A failed abort is logged and does not stop the remaining cleanup.

    Args:
        exclude_group (ProcessGroup, optional): Group to leave untouched.
            If None, every group is destroyed.
    """
    # Build the full list first so the registry is not mutated while it is
    # being iterated.
    doomed = [
        pg
        for pg in _world.pg_group_ranks
        if exclude_group is None or pg != exclude_group
    ]

    # Tell the user which groups are about to disappear.
    if doomed:
        logger.warning(
            "Shrinking default group will destroy %d other process groups: %s. "
            "This is necessary because shrinking the default group reassigns global ranks, "
            "making existing groups inconsistent.",
            len(doomed),
            ", ".join(_get_process_group_name(pg) for pg in doomed),
        )

    for pg in doomed:
        try:
            _abort_process_group(pg)
        except Exception:
            # Not fatal: the group may already be gone.
            logger.warning(
                "Failed to abort process group %s",
                _get_process_group_name(pg),
                exc_info=True,
            )

        # Always drop the Python-side state, whatever the abort did.
        _cleanup_process_group_global_state(pg)
def _cleanup_process_group_global_state(pg: ProcessGroup) -> None:
    """
    Remove every trace of ``pg`` from the module-level registries.

    Intended as a safety net after ``destroy_process_group`` or an abort, so
    that no stale entries remain even when those calls fail or clean up only
    partially. Entries that are already missing are skipped silently.

    State that is cleared:
        - ``_world.pg_map`` (backend name and store)
        - ``_world.pg_group_ranks`` (rank mapping)
        - ``_world.pg_backend_config`` (backend configuration)
        - ``_world.pg_names`` (group name)
        - ``_world.pg_to_tag`` and ``_world.tags_to_pg`` (tag mappings)
        - the name registration, via ``_unregister_process_group``
        - ``_world.pg_coalesce_state`` (coalescing state)

    Any unexpected error is logged as a warning and not propagated.

    Args:
        pg (ProcessGroup): The process group to forget.
    """
    try:
        _world.pg_map.pop(pg, None)
        _world.pg_group_ranks.pop(pg, None)
        _world.pg_backend_config.pop(pg, None)

        registered_name = _world.pg_names.pop(pg, None)

        tag = _world.pg_to_tag.pop(pg, None)
        if tag is not None and tag in _world.tags_to_pg:
            try:
                tagged_groups = _world.tags_to_pg[tag]
                tagged_groups.remove(pg)
                # Drop the tag entirely once no group carries it any more.
                if not tagged_groups:
                    _world.tags_to_pg.pop(tag, None)
            except (ValueError, KeyError):
                # The group was no longer listed under its tag.
                pass

        if registered_name is not None:
            try:
                _unregister_process_group(registered_name)
            except Exception:
                # The name may never have been registered, or is already gone.
                pass

        _world.pg_coalesce_state.pop(pg, None)

    except Exception:
        # Best effort: callers keep cleaning up other groups after a failure.
        logger.warning(
            "Failed to fully clean up global state for process group", exc_info=True
        )
def _update_process_group_global_state(
    pg: ProcessGroup,
    backend_name: str,
    store: Store,
    group_name: GroupName,
    backend_config: str,
    rank_mapping: dict[int, int] | None = None,
    pg_tag: str | None = None,
    user_tag: str | None = None,
) -> None:
    """
    Record a process group in all module-level registries in one place.

    Args:
        pg (ProcessGroup): The group being registered.
        backend_name (str): Stored with ``store`` in ``_world.pg_map``.
        store (Store): Stored with ``backend_name`` in ``_world.pg_map``.
        group_name (str): Stored in ``_world.pg_names`` and passed to
            ``_register_process_group``.
        backend_config (str): Stored in ``_world.pg_backend_config``.
        rank_mapping (Dict[int, int], optional): Global rank to group rank
            mapping for ``_world.pg_group_ranks``. If None, that registry is
            left untouched.
        pg_tag (str, optional): Tag to file the group under. If None,
            ``f"ptd:{group_name}"`` is used. Ignored when ``user_tag`` is given.
        user_tag (str, optional): User-provided tag. When given, the group is
            filed under both the ``""`` tag and ``"user:{user_tag}"``, and the
            latter is recorded as the group's tag.
    """
    _world.pg_map[pg] = (backend_name, store)
    _world.pg_names[pg] = group_name
    _world.pg_backend_config[pg] = backend_config

    _register_process_group(group_name, pg)

    if rank_mapping is not None:
        _world.pg_group_ranks[pg] = rank_mapping

    tags = _world.tags_to_pg
    if user_tag is not None:
        # User-tagged groups are additionally listed under the empty tag.
        tags.setdefault("", []).append(pg)
        effective_tag = f"user:{user_tag}"
    else:
        effective_tag = f"ptd:{group_name}" if pg_tag is None else pg_tag

    tags.setdefault(effective_tag, []).append(pg)
    _world.pg_to_tag[pg] = effective_tag
|