dataset.py 288 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477
74778477947804781478247834784478547864787478847894790479147924793479447954796479747984799480048014802480348044805480648074808480948104811481248134814481548164817481848194820482148224823482448254826482748284829483048314832483348344835483648374838483948404841484248434844484548464847484848494850485148524853485448554856485748584859486048614862486348644865486648674868486948704871487248734874487548764877487848794880488148824883488448854886488748884889489048914892489348944895489648974898489949004901490249034904490549064907490849094910491149124913491449154916491749184919492049214922492349244925492649274928492949304931493249334934493549364937493849394940494149424943494449454946494749484949495049514952495349544955495649574958495949604961496249634964496549664967496849694970497149724973497449754976497749784979498049814982498349844985498649874988498949904991499249934994499549964997499849995000500150025003500450055006500750085009501050115012501350145015501650175018501950205021502250235024502550265027502850295030503150325033503450355036503750385039504050415042504350445045504650475048504950505051505250535054505550565057505850595060506150625063506450655066506750685069507050715072507350745075507650775078507950805081508250835084508550865087508850895090509150925093509450955096509750985099510051015102510351045105510651075108510951105111511251135114511551165117511851195120512151225123512451255126512751285129513051315132513351345135513651375138513951405141514251435144514551465147514851495150515151525153515451555156515751585159516051615162516351645165516651675168516951705171517251735174517551765177517851795180518151825183518451855186518751885189519051915192519351945195519651975198519952005201520252035204520552065207520852095210521152125213521452155216521752185219522052215222522352245225522652275228522952305231523252335234523552365237523852395240524152425243524452455246524752485249525052515252525352545255525652575258525952605261526252635264526552665267526852695270527152725273527452755276527
75278527952805281528252835284528552865287528852895290529152925293529452955296529752985299530053015302530353045305530653075308530953105311531253135314531553165317531853195320532153225323532453255326532753285329533053315332533353345335533653375338533953405341534253435344534553465347534853495350535153525353535453555356535753585359536053615362536353645365536653675368536953705371537253735374537553765377537853795380538153825383538453855386538753885389539053915392539353945395539653975398539954005401540254035404540554065407540854095410541154125413541454155416541754185419542054215422542354245425542654275428542954305431543254335434543554365437543854395440544154425443544454455446544754485449545054515452545354545455545654575458545954605461546254635464546554665467546854695470547154725473547454755476547754785479548054815482548354845485548654875488548954905491549254935494549554965497549854995500550155025503550455055506550755085509551055115512551355145515551655175518551955205521552255235524552555265527552855295530553155325533553455355536553755385539554055415542554355445545554655475548554955505551555255535554555555565557555855595560556155625563556455655566556755685569557055715572557355745575557655775578557955805581558255835584558555865587558855895590559155925593559455955596559755985599560056015602560356045605560656075608560956105611561256135614561556165617561856195620562156225623562456255626562756285629563056315632563356345635563656375638563956405641564256435644564556465647564856495650565156525653565456555656565756585659566056615662566356645665566656675668566956705671567256735674567556765677567856795680568156825683568456855686568756885689569056915692569356945695569656975698569957005701570257035704570557065707570857095710571157125713571457155716571757185719572057215722572357245725572657275728572957305731573257335734573557365737573857395740574157425743574457455746574757485749575057515752575357545755575657575758575957605761576257635764576557665767576857695770577157725773577457755776577
75778577957805781578257835784578557865787578857895790579157925793579457955796579757985799580058015802580358045805580658075808580958105811581258135814581558165817581858195820582158225823582458255826582758285829583058315832583358345835583658375838583958405841584258435844584558465847584858495850585158525853585458555856585758585859586058615862586358645865586658675868586958705871587258735874587558765877587858795880588158825883588458855886588758885889589058915892589358945895589658975898589959005901590259035904590559065907590859095910591159125913591459155916591759185919592059215922592359245925592659275928592959305931593259335934593559365937593859395940594159425943594459455946594759485949595059515952595359545955595659575958595959605961596259635964596559665967596859695970597159725973597459755976597759785979598059815982598359845985598659875988598959905991599259935994599559965997599859996000600160026003600460056006600760086009601060116012601360146015601660176018601960206021602260236024602560266027602860296030603160326033603460356036603760386039604060416042604360446045604660476048604960506051605260536054605560566057605860596060606160626063606460656066606760686069607060716072607360746075607660776078607960806081608260836084608560866087608860896090609160926093609460956096609760986099610061016102610361046105610661076108610961106111611261136114611561166117611861196120612161226123612461256126612761286129613061316132613361346135613661376138613961406141614261436144614561466147614861496150615161526153615461556156615761586159616061616162616361646165616661676168616961706171617261736174617561766177617861796180618161826183618461856186618761886189619061916192619361946195619661976198619962006201620262036204620562066207620862096210621162126213621462156216621762186219622062216222622362246225622662276228622962306231623262336234623562366237623862396240624162426243624462456246624762486249625062516252625362546255625662576258625962606261626262636264626562666267626862696270627162726273627462756276627
76278627962806281628262836284628562866287628862896290629162926293629462956296629762986299630063016302630363046305630663076308630963106311631263136314631563166317631863196320632163226323632463256326632763286329633063316332633363346335633663376338633963406341634263436344634563466347634863496350635163526353635463556356635763586359636063616362636363646365636663676368636963706371637263736374637563766377637863796380638163826383638463856386638763886389639063916392639363946395639663976398639964006401640264036404640564066407640864096410641164126413641464156416641764186419642064216422642364246425642664276428642964306431643264336434643564366437643864396440644164426443644464456446644764486449645064516452645364546455645664576458645964606461646264636464646564666467646864696470647164726473647464756476647764786479648064816482648364846485648664876488648964906491649264936494649564966497649864996500650165026503650465056506650765086509651065116512651365146515651665176518651965206521652265236524652565266527652865296530653165326533653465356536653765386539654065416542654365446545654665476548654965506551655265536554655565566557655865596560656165626563656465656566656765686569657065716572657365746575657665776578657965806581658265836584658565866587658865896590659165926593659465956596659765986599660066016602660366046605660666076608660966106611661266136614661566166617661866196620662166226623662466256626662766286629663066316632663366346635663666376638663966406641664266436644664566466647664866496650665166526653665466556656665766586659666066616662666366646665666666676668666966706671667266736674667566766677667866796680668166826683668466856686668766886689669066916692669366946695669666976698669967006701670267036704670567066707670867096710671167126713671467156716671767186719672067216722672367246725672667276728672967306731673267336734673567366737673867396740674167426743674467456746674767486749675067516752675367546755675667576758675967606761676267636764676567666767676867696770677167726773677467756776677
76778677967806781678267836784678567866787678867896790679167926793679467956796679767986799680068016802680368046805680668076808680968106811681268136814681568166817681868196820682168226823682468256826682768286829683068316832683368346835683668376838683968406841684268436844684568466847684868496850685168526853685468556856685768586859686068616862686368646865686668676868686968706871687268736874687568766877687868796880688168826883688468856886688768886889689068916892689368946895689668976898689969006901690269036904690569066907690869096910691169126913691469156916691769186919692069216922692369246925692669276928692969306931693269336934693569366937693869396940694169426943694469456946694769486949695069516952695369546955695669576958695969606961696269636964696569666967696869696970697169726973697469756976697769786979698069816982698369846985698669876988698969906991699269936994699569966997
  1. import collections
  2. import copy
  3. import html
  4. import itertools
  5. import logging
  6. import time
  7. import warnings
  8. from typing import (
  9. TYPE_CHECKING,
  10. Any,
  11. Callable,
  12. Dict,
  13. Generic,
  14. Iterable,
  15. Iterator,
  16. List,
  17. Literal,
  18. Mapping,
  19. Optional,
  20. Tuple,
  21. TypeVar,
  22. Union,
  23. )
  24. import numpy as np
  25. import ray
  26. import ray.cloudpickle as pickle
  27. from ray._common.usage import usage_lib
  28. from ray._private.thirdparty.tabulate.tabulate import tabulate
  29. from ray.data._internal.compute import ComputeStrategy, TaskPoolStrategy
  30. from ray.data._internal.dataset_repr import _build_dataset_ascii_repr
  31. from ray.data._internal.datasource.bigquery_datasink import BigQueryDatasink
  32. from ray.data._internal.datasource.clickhouse_datasink import (
  33. ClickHouseDatasink,
  34. ClickHouseTableSettings,
  35. SinkMode,
  36. )
  37. from ray.data._internal.datasource.csv_datasink import CSVDatasink
  38. from ray.data._internal.datasource.iceberg_datasink import IcebergDatasink
  39. from ray.data._internal.datasource.image_datasink import ImageDatasink
  40. from ray.data._internal.datasource.json_datasink import JSONDatasink
  41. from ray.data._internal.datasource.lance_datasink import LanceDatasink
  42. from ray.data._internal.datasource.mongo_datasink import MongoDatasink
  43. from ray.data._internal.datasource.numpy_datasink import NumpyDatasink
  44. from ray.data._internal.datasource.parquet_datasink import ParquetDatasink
  45. from ray.data._internal.datasource.sql_datasink import SQLDatasink
  46. from ray.data._internal.datasource.tfrecords_datasink import TFRecordDatasink
  47. from ray.data._internal.datasource.webdataset_datasink import WebDatasetDatasink
  48. from ray.data._internal.equalize import _equalize
  49. from ray.data._internal.execution.interfaces import RefBundle
  50. from ray.data._internal.execution.interfaces.ref_bundle import (
  51. _ref_bundles_iterator_to_block_refs_list,
  52. )
  53. from ray.data._internal.execution.util import memory_string
  54. from ray.data._internal.iterator.iterator_impl import DataIteratorImpl
  55. from ray.data._internal.iterator.stream_split_iterator import StreamSplitDataIterator
  56. from ray.data._internal.logical.interfaces import LogicalPlan
  57. from ray.data._internal.logical.operators import (
  58. Count,
  59. Filter,
  60. FlatMap,
  61. InputData,
  62. Join,
  63. Limit,
  64. MapBatches,
  65. MapRows,
  66. Project,
  67. RandomizeBlocks,
  68. RandomShuffle,
  69. Repartition,
  70. Sort,
  71. StreamingRepartition,
  72. StreamingSplit,
  73. Union as UnionLogicalOperator,
  74. Write,
  75. Zip,
  76. )
  77. from ray.data._internal.pandas_block import PandasBlockBuilder, PandasBlockSchema
  78. from ray.data._internal.plan import ExecutionPlan
  79. from ray.data._internal.planner.exchange.sort_task_spec import SortKey
  80. from ray.data._internal.remote_fn import cached_remote_fn
  81. from ray.data._internal.split import _get_num_rows, _split_at_indices
  82. from ray.data._internal.stats import DatasetStats, DatasetStatsSummary, _StatsManager
  83. from ray.data._internal.tensor_extensions.arrow import (
  84. ArrowTensorTypeV2,
  85. get_arrow_extension_fixed_shape_tensor_types,
  86. )
  87. from ray.data._internal.util import (
  88. AllToAllAPI,
  89. ConsumptionAPI,
  90. _validate_rows_per_file_args,
  91. get_compute_strategy,
  92. merge_resources_to_ray_remote_args,
  93. )
  94. from ray.data.aggregate import (
  95. AggregateFn,
  96. AggregateFnV2,
  97. Max,
  98. Mean,
  99. Min,
  100. Std,
  101. Sum,
  102. Unique,
  103. )
  104. from ray.data.block import (
  105. Block,
  106. BlockAccessor,
  107. DataBatch,
  108. DataBatchColumn,
  109. T,
  110. U,
  111. UserDefinedFunction,
  112. _apply_batch_format,
  113. )
  114. from ray.data.context import DataContext
  115. from ray.data.datasource import Connection, Datasink, FilenameProvider, SaveMode
  116. from ray.data.datasource.datasink import WriteResult, _gen_datasink_write_result
  117. from ray.data.datasource.file_datasink import _FileDatasink
  118. from ray.data.datatype import DataType
  119. from ray.data.iterator import DataIterator
  120. from ray.data.random_access_dataset import RandomAccessDataset
  121. from ray.types import ObjectRef
  122. from ray.util.annotations import Deprecated, DeveloperAPI, PublicAPI
  123. from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
  124. from ray.widgets import Template
  125. from ray.widgets.util import repr_with_fallback
  126. if TYPE_CHECKING:
  127. import daft
  128. import dask
  129. import mars
  130. import modin
  131. import pandas
  132. import pyarrow
  133. import pyspark
  134. import tensorflow as tf
  135. import torch
  136. import torch.utils.data
  137. from tensorflow_metadata.proto.v0 import schema_pb2
  138. from ray.data._internal.execution.interfaces import Executor, NodeIdStr
  139. from ray.data._internal.execution.streaming_executor import StreamingExecutor
  140. from ray.data.grouped_data import GroupedData
  141. from ray.data.stats import DatasetSummary
  142. from ray.data.expressions import Expr, StarExpr, col
  # Module-level logger, named after this module.
  logger = logging.getLogger(__name__)

  # Special column name for train/test split to avoid collision with user columns
  _TRAIN_TEST_SPLIT_COLUMN = "__ray_train_test_split_is_train__"

  # Type aliases for TensorFlow / PyTorch interop. The ``tf.*`` and ``torch.*``
  # names are written as string forward references because those packages are
  # only imported under ``TYPE_CHECKING`` in this module.
  TensorflowFeatureTypeSpec = Union[
      "tf.TypeSpec", List["tf.TypeSpec"], Dict[str, "tf.TypeSpec"]
  ]
  TensorFlowTensorBatchType = Union["tf.Tensor", Dict[str, "tf.Tensor"]]

  # Type variable used in ``TorchBatchType`` below.
  # NOTE(review): presumably the output type of a user-supplied collate
  # function -- confirm against the Torch iteration APIs.
  CollatedData = TypeVar("CollatedData")
  TorchBatchType = Union[Dict[str, "torch.Tensor"], CollatedData]

  TorchDeviceType = Union[str, "torch.device", int]
  153. """
  154. A device identifier, which can be a string (e.g. 'cpu', 'cuda:0'),
  155. a torch.device object, or an integer (e.g. 0 for 'cuda:0').
  156. """
  # Human-readable API group labels. They are passed as ``api_group=...`` to
  # ``@PublicAPI`` on Dataset methods (e.g. ``BT_API_GROUP`` on ``Dataset.map``
  # and ``EXPRESSION_API_GROUP`` on ``Dataset.with_column``).
  BT_API_GROUP = "Basic Transformations"
  SSR_API_GROUP = "Sorting, Shuffling and Repartitioning"
  SMJ_API_GROUP = "Splitting, Merging, Joining Datasets"
  GGA_API_GROUP = "Grouped and Global Aggregations"
  CD_API_GROUP = "Consuming Data"
  IOC_API_GROUP = "I/O and Conversion"
  IM_API_GROUP = "Inspecting Metadata"
  E_API_GROUP = "Execution"
  EXPRESSION_API_GROUP = "Expressions"
  166. @PublicAPI
  167. class Dataset:
  168. """A Dataset is a distributed data collection for data loading and processing.
  169. Datasets are distributed pipelines that produce ``ObjectRef[Block]`` outputs,
  170. where each block holds data in Arrow format, representing a shard of the overall
  171. data collection. The block also determines the unit of parallelism. For more
  172. details, see :ref:`Ray Data Key Concepts <data_key_concepts>`.
  173. Datasets can be created in multiple ways:
  174. * from external storage systems such as local disk, S3, HDFS etc. via the ``read_*()`` APIs.
  175. * from existing memory data via ``from_*()`` APIs
  176. * from synthetic data via ``range_*()`` APIs
  177. The (potentially processed) Dataset can be saved back to external storage systems
  178. via the ``write_*()`` APIs.
  179. Examples:
  180. .. testcode::
  181. :skipif: True
  182. import ray
  183. # Create dataset from synthetic data.
  184. ds = ray.data.range(1000)
  185. # Create dataset from in-memory data.
  186. ds = ray.data.from_items(
  187. [{"col1": i, "col2": i * 2} for i in range(1000)]
  188. )
  189. # Create dataset from external storage system.
  190. ds = ray.data.read_parquet("s3://bucket/path")
  191. # Save dataset back to external storage system.
  192. ds.write_csv("s3://bucket/output")
  193. Dataset has two kinds of operations: transformation, which takes in Dataset
  194. and outputs a new Dataset (e.g. :py:meth:`.map_batches()`); and consumption,
  195. which produces values (not a data stream) as output
  196. (e.g. :meth:`.iter_batches()`).
  197. Dataset transformations are lazy, with execution of the transformations being
  198. triggered by downstream consumption.
  199. Dataset supports parallel processing at scale:
  200. * transformations such as :py:meth:`.map_batches()`
  201. * aggregations such as :py:meth:`.min()`/:py:meth:`.max()`/:py:meth:`.mean()`,
  202. * grouping via :py:meth:`.groupby()`,
  203. * shuffling operations such as :py:meth:`.sort()`, :py:meth:`.random_shuffle()`, and :py:meth:`.repartition()`
  204. * joining via :py:meth:`.join()`
  205. Examples:
  206. >>> import ray
  207. >>> ds = ray.data.range(1000)
  208. >>> # Transform batches (Dict[str, np.ndarray]) with map_batches().
  209. >>> ds.map_batches(lambda batch: {"id": batch["id"] * 2}) # doctest: +ELLIPSIS
  210. MapBatches(<lambda>)
  211. +- Dataset(num_rows=1000, schema={id: int64})
  212. >>> # Compute the maximum.
  213. >>> ds.max("id")
  214. 999
  215. >>> # Shuffle this dataset randomly.
  216. >>> ds.random_shuffle() # doctest: +ELLIPSIS
  217. shape: (1000, 1)
  218. ╭───────╮
  219. │ id │
  220. │ --- │
  221. │ int64 │
  222. ╰───────╯
  223. (Dataset isn't materialized)
  224. >>> # Sort it back in order.
  225. >>> ds.sort("id") # doctest: +ELLIPSIS
  226. shape: (1000, 1)
  227. ╭───────╮
  228. │ id │
  229. │ --- │
  230. │ int64 │
  231. ╰───────╯
  232. (Dataset isn't materialized)
  233. Both unexecuted and materialized Datasets can be passed between Ray tasks and
  234. actors without incurring a copy. Dataset supports conversion to/from several
  235. more featureful dataframe libraries (e.g., Spark, Dask, Modin, MARS), and are also
  236. compatible with distributed TensorFlow / PyTorch.
  237. """
  def __init__(
      self,
      plan: ExecutionPlan,
      logical_plan: LogicalPlan,
  ):
      """Construct a Dataset (internal API).

      The constructor is not part of the Dataset API. Use the ``ray.data.*``
      read methods to construct a dataset.

      Args:
          plan: The execution plan backing this dataset. Must be an
              ``ExecutionPlan`` instance.
          logical_plan: The logical plan that gets linked to ``plan``.
      """
      # Fail fast on a wrong plan type; the assertion message is the actual type.
      assert isinstance(plan, ExecutionPlan), type(plan)
      usage_lib.record_library_usage("dataset")  # Legacy telemetry name.

      self._plan = plan
      self._logical_plan = logical_plan
      # Make the execution plan aware of the logical plan it belongs to.
      self._plan.link_logical_plan(logical_plan)

      # Handle to currently running executor for this dataset.
      self._current_executor: Optional["Executor"] = None
      self._write_ds = None

      # Tag this dataset with an ID generated through the stats manager.
      self._set_uuid(_StatsManager.gen_dataset_id_from_stats_actor())
  @staticmethod
  def copy(
      ds: "Dataset", _deep_copy: bool = False, _as: Optional[type] = None
  ) -> "Dataset":
      """Build a new dataset object on top of a copy of ``ds``'s execution plan.

      The logical plan is not copied; the new object receives the same
      ``ds._logical_plan`` instance.

      Args:
          ds: The dataset to copy.
          _deep_copy: If True, duplicate the execution plan with
              ``deep_copy()`` instead of ``copy()``.
          _as: The class (or callable) used to construct the result. Falls
              back to ``type(ds)`` when not provided.

      Returns:
          The object produced by calling ``_as(copied_plan, ds._logical_plan)``.
      """
      target_cls = _as if _as else type(ds)
      copied_plan = ds._plan.deep_copy() if _deep_copy else ds._plan.copy()
      return target_cls(copied_plan, ds._logical_plan)
  @PublicAPI(api_group=BT_API_GROUP)
  def map(
      self,
      fn: Callable[[Dict[str, Any]], Dict[str, Any]],
      *,
      compute: Optional[ComputeStrategy] = None,
      fn_args: Optional[Iterable[Any]] = None,
      fn_kwargs: Optional[Dict[str, Any]] = None,
      fn_constructor_args: Optional[Iterable[Any]] = None,
      fn_constructor_kwargs: Optional[Dict[str, Any]] = None,
      num_cpus: Optional[float] = None,
      num_gpus: Optional[float] = None,
      memory: Optional[float] = None,
      concurrency: Optional[Union[int, Tuple[int, int], Tuple[int, int, int]]] = None,
      ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
      **ray_remote_args,
  ) -> "Dataset":
      """Apply the given function to each row of this dataset.

      This is the row-wise transformation API. To learn more, see
      :ref:`Transforming rows <transforming_rows>`.

      ``fn`` may be a plain function or a callable class. Ray Data runs
      functions as stateless Ray tasks and classes as stateful Ray actors.
      For more information, see
      :ref:`Stateful Transforms <stateful_transforms>`.

      .. tip::

          If your transformation is vectorized like most NumPy or pandas operations,
          :meth:`~Dataset.map_batches` might be faster.

      .. warning::
          Specifying both ``num_cpus`` and ``num_gpus`` for map tasks is experimental,
          and may result in scheduling or stability issues. Please
          `report any issues <https://github.com/ray-project/ray/issues/new/choose>`_
          to the Ray team.

      Examples:

          .. testcode::

              import os
              from typing import Any, Dict
              import ray

              def parse_filename(row: Dict[str, Any]) -> Dict[str, Any]:
                  row["filename"] = os.path.basename(row["path"])
                  return row

              ds = (
                  ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple", include_paths=True)
                  .map(parse_filename)
              )
              print(ds.schema())

          .. testoutput::

              Column    Type
              ------    ----
              image     ArrowTensorTypeV2(shape=(32, 32, 3), dtype=uint8)
              path      string
              filename  string

      Time complexity: O(dataset size / parallelism)

      Args:
          fn: The function to apply to each row, or a class type
              that can be instantiated to create such a callable.
          compute: The compute strategy to use for the map operation.

              * If ``compute`` is not specified for a function, will use ``ray.data.TaskPoolStrategy()`` to launch concurrent tasks based on the available resources and number of input blocks.

              * Use ``ray.data.TaskPoolStrategy(size=n)`` to launch at most ``n`` concurrent Ray tasks.

              * If ``compute`` is not specified for a callable class, will use ``ray.data.ActorPoolStrategy(min_size=1, max_size=None)`` to launch an autoscaling actor pool from 1 to unlimited workers.

              * Use ``ray.data.ActorPoolStrategy(size=n)`` to use a fixed size actor pool of ``n`` workers.

              * Use ``ray.data.ActorPoolStrategy(min_size=m, max_size=n)`` to use an autoscaling actor pool from ``m`` to ``n`` workers.

              * Use ``ray.data.ActorPoolStrategy(min_size=m, max_size=n, initial_size=initial)`` to use an autoscaling actor pool from ``m`` to ``n`` workers, with an initial size of ``initial``.

          fn_args: Positional arguments to pass to ``fn`` after the first argument.
              These arguments are top-level arguments to the underlying Ray task.
          fn_kwargs: Keyword arguments to pass to ``fn``. These arguments are
              top-level arguments to the underlying Ray task.
          fn_constructor_args: Positional arguments to pass to ``fn``'s constructor.
              You can only provide this if ``fn`` is a callable class. These arguments
              are top-level arguments in the underlying Ray actor construction task.
          fn_constructor_kwargs: Keyword arguments to pass to ``fn``'s constructor.
              This can only be provided if ``fn`` is a callable class. These arguments
              are top-level arguments in the underlying Ray actor construction task.
          num_cpus: The number of CPUs to reserve for each parallel map worker.
          num_gpus: The number of GPUs to reserve for each parallel map worker. For
              example, specify `num_gpus=1` to request 1 GPU for each parallel map
              worker.
          memory: The heap memory in bytes to reserve for each parallel map worker.
          concurrency: This argument is deprecated. Use ``compute`` argument.
          ray_remote_args_fn: A function that returns a dictionary of remote args
              passed to each map worker. The purpose of this argument is to generate
              dynamic arguments for each actor/task, and will be called each time prior
              to initializing the worker. Args returned from this dict will always
              override the args in ``ray_remote_args``. Note: this is an advanced,
              experimental feature.
          ray_remote_args: Additional resource requirements to request from
              Ray for each map worker. See :func:`ray.remote` for details.

      Returns:
          A new :class:`Dataset` whose logical plan ends in a ``MapRows``
          operator applied on top of this dataset's plan.

      .. seealso::

          :meth:`~Dataset.flat_map`
              Call this method to create new rows from existing ones. Unlike
              :meth:`~Dataset.map`, a function passed to
              :meth:`~Dataset.flat_map` can return multiple rows.

          :meth:`~Dataset.map_batches`
              Call this method to transform batches of data.
      """  # noqa: E501
      # Resolve the effective compute strategy from ``fn`` and the
      # (deprecated) ``concurrency`` argument.
      compute_strategy = get_compute_strategy(
          fn,
          fn_constructor_args=fn_constructor_args,
          compute=compute,
          concurrency=concurrency,
      )

      # Fold the explicit resource arguments into the remote args.
      remote_args = merge_resources_to_ray_remote_args(
          num_cpus, num_gpus, memory, ray_remote_args
      )

      copied_plan = self._plan.copy()
      row_mapper = MapRows(
          self._logical_plan.dag,
          fn,
          fn_args=fn_args,
          fn_kwargs=fn_kwargs,
          fn_constructor_args=fn_constructor_args,
          fn_constructor_kwargs=fn_constructor_kwargs,
          compute=compute_strategy,
          ray_remote_args_fn=ray_remote_args_fn,
          ray_remote_args=remote_args,
      )
      return Dataset(copied_plan, LogicalPlan(row_mapper, self.context))
  @Deprecated(message="Use set_name() instead", warning=True)
  def _set_name(self, name: Optional[str]):
      """Deprecated alias that forwards to :meth:`set_name`.

      Args:
          name: The dataset name to pass through to :meth:`set_name`.
      """
      self.set_name(name)
  def set_name(self, name: Optional[str]):
      """Assign a user-defined name to this dataset.

      The name is stored on the dataset's execution plan and is used as a
      prefix for metrics tags.

      Args:
          name: The new dataset name, or ``None``.
      """
      execution_plan = self._plan
      execution_plan._dataset_name = name
  @property
  @Deprecated(message="Use name() instead", warning=True)
  def _name(self) -> Optional[str]:
      """Deprecated alias that returns the value of the ``name`` property."""
      # NOTE(review): the deprecation message says ``name()``, but ``name`` is
      # declared with ``@property`` in this class, so it is read without
      # parentheses -- confirm whether the message wording should change.
      return self.name
  @property
  def name(self) -> Optional[str]:
      """The user-defined dataset name, as stored on the execution plan."""
      execution_plan = self._plan
      return execution_plan._dataset_name
  def get_dataset_id(self) -> str:
      """Return the unique ID of this dataset.

      The ID includes the dataset name, UUID, and current execution index.
      The value is obtained from the underlying execution plan.
      """
      execution_plan = self._plan
      return execution_plan.get_dataset_id()
  @PublicAPI(api_group=BT_API_GROUP)
  def map_batches(
      self,
      fn: UserDefinedFunction[DataBatch, DataBatch],
      *,
      batch_size: Union[int, None, Literal["default"]] = None,
      compute: Optional[ComputeStrategy] = None,
      batch_format: Optional[str] = "default",
      zero_copy_batch: bool = True,
      fn_args: Optional[Iterable[Any]] = None,
      fn_kwargs: Optional[Dict[str, Any]] = None,
      fn_constructor_args: Optional[Iterable[Any]] = None,
      fn_constructor_kwargs: Optional[Dict[str, Any]] = None,
      num_cpus: Optional[float] = None,
      num_gpus: Optional[float] = None,
      memory: Optional[float] = None,
      concurrency: Optional[Union[int, Tuple[int, int], Tuple[int, int, int]]] = None,
      udf_modifying_row_count: bool = True,
      ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
      **ray_remote_args,
  ) -> "Dataset":
      """Apply the given function to batches of data.

      Use this method for preprocessing data and performing inference. To learn
      more, see :ref:`Transforming batches <transforming_batches>`.

      ``fn`` may be a plain function or a callable class. Ray Data runs
      functions as stateless Ray tasks and classes as stateful Ray actors.
      For more information, see
      :ref:`Stateful Transforms <stateful_transforms>`.

      .. tip::
          To understand the format of the input to ``fn``, call :meth:`~Dataset.take_batch`
          on the dataset to get a batch in the same format as will be passed to ``fn``.

      .. note::
          ``fn`` should generally avoid modifying data buffers behind its input
          since these could be zero-copy views into the underlying object residing
          inside Ray's Object Store.

          To perform any modifications it's recommended to copy the data you
          want to modify.

          In rare cases when you can't copy inside your UDF, you can instead
          specify ``zero_copy_batch=False`` and then Ray Data will copy the
          *whole* batch for you, providing ``fn`` with a copy rather than
          a zero-copy view.

      .. warning::
          Specifying both ``num_cpus`` and ``num_gpus`` for map tasks is experimental,
          and may result in scheduling or stability issues. Please
          `report any issues <https://github.com/ray-project/ray/issues/new/choose>`_
          to the Ray team.

      Examples:

          Call :meth:`~Dataset.map_batches` to transform your data.

          .. testcode::

              from typing import Dict
              import numpy as np
              import ray

              def add_dog_years(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
                  batch["age_in_dog_years"] = 7 * batch["age"]
                  return batch

              ds = (
                  ray.data.from_items([
                      {"name": "Luna", "age": 4},
                      {"name": "Rory", "age": 14},
                      {"name": "Scout", "age": 9},
                  ])
                  .map_batches(add_dog_years)
              )
              ds.show()

          .. testoutput::

              {'name': 'Luna', 'age': 4, 'age_in_dog_years': 28}
              {'name': 'Rory', 'age': 14, 'age_in_dog_years': 98}
              {'name': 'Scout', 'age': 9, 'age_in_dog_years': 63}

          If your function returns large objects, yield outputs in chunks.

          .. testcode::

              from typing import Dict
              import ray
              import numpy as np

              def map_fn_with_large_output(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
                  for i in range(3):
                      yield {"large_output": np.ones((100, 1000))}

              ds = (
                  ray.data.from_items([1])
                  .map_batches(map_fn_with_large_output)
              )

          If you require stateful transformation,
          use Python callable class. Here is an example showing how to use stateful transforms to create model inference workers, without having to reload the model on each call.

          .. testcode::

              from typing import Dict
              import numpy as np
              import torch
              import ray

              class TorchPredictor:

                  def __init__(self):
                      self.model = torch.nn.Identity().cuda()
                      self.model.eval()

                  def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
                      inputs = torch.as_tensor(batch["data"], dtype=torch.float32).cuda()
                      with torch.inference_mode():
                          batch["output"] = self.model(inputs).detach().cpu().numpy()
                      return batch

              ds = (
                  ray.data.from_numpy(np.ones((32, 100)))
                  .map_batches(
                      TorchPredictor,
                      # Two workers with one GPU each
                      compute=ray.data.ActorPoolStrategy(size=2),
                      # Batch size is required if you're using GPUs.
                      batch_size=4,
                      num_gpus=1
                  )
              )

          To learn more, see
          :ref:`End-to-end: Offline Batch Inference <batch_inference_home>`.

      Args:
          fn: The function or generator to apply to a record batch, or a class type
              that can be instantiated to create such a callable. Note ``fn`` must be
              pickle-able.
          batch_size: The desired number of rows in each batch, or ``None`` to use
              entire blocks as batches (blocks may contain different numbers of rows).
              The actual size of the batch provided to ``fn`` may be smaller than
              ``batch_size`` if ``batch_size`` doesn't evenly divide the block(s) sent
              to a given map task. Default ``batch_size`` is ``None``.
          compute: The compute strategy to use for the map operation.

              * If ``compute`` is not specified for a function, will use ``ray.data.TaskPoolStrategy()`` to launch concurrent tasks based on the available resources and number of input blocks.

              * Use ``ray.data.TaskPoolStrategy(size=n)`` to launch at most ``n`` concurrent Ray tasks.

              * If ``compute`` is not specified for a callable class, will use ``ray.data.ActorPoolStrategy(min_size=1, max_size=None)`` to launch an autoscaling actor pool from 1 to unlimited workers.

              * Use ``ray.data.ActorPoolStrategy(size=n)`` to use a fixed size actor pool of ``n`` workers.

              * Use ``ray.data.ActorPoolStrategy(min_size=m, max_size=n)`` to use an autoscaling actor pool from ``m`` to ``n`` workers.

              * Use ``ray.data.ActorPoolStrategy(min_size=m, max_size=n, initial_size=initial)`` to use an autoscaling actor pool from ``m`` to ``n`` workers, with an initial size of ``initial``.

          batch_format: If ``"default"`` or ``"numpy"``, batches are
              ``Dict[str, numpy.ndarray]``. If ``"pandas"``, batches are
              ``pandas.DataFrame``. If ``"pyarrow"``, batches are
              ``pyarrow.Table``. If ``batch_format`` is set to ``None`` input
              block format will be used.
          zero_copy_batch: Whether ``fn`` should be provided zero-copy, read-only
              batches. If this is ``True`` and no copy is required for the
              ``batch_format`` conversion, the batch is a zero-copy, read-only
              view on data in Ray's object store, which can decrease memory
              utilization and improve performance. Setting this to ``False``,
              will make a copy of the *whole* batch, therefore allowing UDF to
              modify underlying data buffers (like tensors, binary arrays, etc)
              in place. It's recommended to copy only the data you need to
              modify instead of resorting to copying the whole batch.
          fn_args: Positional arguments to pass to ``fn`` after the first argument.
              These arguments are top-level arguments to the underlying Ray task.
          fn_kwargs: Keyword arguments to pass to ``fn``. These arguments are
              top-level arguments to the underlying Ray task.
          fn_constructor_args: Positional arguments to pass to ``fn``'s constructor.
              You can only provide this if ``fn`` is a callable class. These arguments
              are top-level arguments in the underlying Ray actor construction task.
          fn_constructor_kwargs: Keyword arguments to pass to ``fn``'s constructor.
              This can only be provided if ``fn`` is a callable class. These arguments
              are top-level arguments in the underlying Ray actor construction task.
          num_cpus: The number of CPUs to reserve for each parallel map worker.
          num_gpus: The number of GPUs to reserve for each parallel map worker. For
              example, specify `num_gpus=1` to request 1 GPU for each parallel map
              worker.
          memory: The heap memory in bytes to reserve for each parallel map worker.
          concurrency: This argument is deprecated. Use ``compute`` argument.
          udf_modifying_row_count: If your UDF produces the same number of output rows
              as it receives, set this parameter to False. It allows Ray Data to
              perform more optimizations like limit pushdown.
          ray_remote_args_fn: A function that returns a dictionary of remote args
              passed to each map worker. The purpose of this argument is to generate
              dynamic arguments for each actor/task, and will be called each time prior
              to initializing the worker. Args returned from this dict will always
              override the args in ``ray_remote_args``. Note: this is an advanced,
              experimental feature.
          ray_remote_args: Additional resource requirements to request from
              Ray for each map worker. See :func:`ray.remote` for details.

      Raises:
          ValueError: If GPUs are requested (``num_gpus > 0``) while
              ``batch_size`` is ``None`` or ``"default"``, or if ``batch_size``
              is an integer smaller than 1.

      .. note::

          The size of the batches provided to ``fn`` might be smaller than the
          specified ``batch_size`` if ``batch_size`` doesn't evenly divide the
          block(s) sent to a given map task.

          If ``batch_size`` is set and each input block is smaller than the
          ``batch_size``, Ray Data will bundle up many blocks as the input for one
          task, until their total size is equal to or greater than the given
          ``batch_size``.
          If ``batch_size`` is not set, the bundling will not be performed. Each task
          will receive entire input block as a batch.

      .. seealso::

          :meth:`~Dataset.iter_batches`
              Call this function to iterate over batches of data.

          :meth:`~Dataset.take_batch`
              Call this function to get a batch of data from the dataset
              in the same format as will be passed to the `fn` function of
              :meth:`~Dataset.map_batches`.

          :meth:`~Dataset.flat_map`
              Call this method to create new records from existing ones. Unlike
              :meth:`~Dataset.map`, a function passed to :meth:`~Dataset.flat_map`
              can return multiple records.

          :meth:`~Dataset.map`
              Call this method to transform one record at time.

      """  # noqa: E501
      # GPU workloads must pick an explicit batch size.
      if (num_gpus is not None and num_gpus > 0) and (
          batch_size is None or batch_size == "default"
      ):
          missing_batch_size_msg = (
              "You must provide `batch_size` to `map_batches` when requesting GPUs. "
              "The optimal batch size depends on the model, data, and GPU used. "
              "We recommend using the largest batch size that doesn't result "
              "in your GPU device running out of memory. You can view the GPU memory "
              "usage via the Ray dashboard."
          )
          raise ValueError(missing_batch_size_msg)

      # Integer batch sizes must be strictly positive.
      if isinstance(batch_size, int) and batch_size < 1:
          raise ValueError("Batch size can't be negative or 0")

      # All checks passed; hand off to the implementation that skips them.
      return self._map_batches_without_batch_size_validation(
          fn,
          batch_size=batch_size,
          compute=compute,
          batch_format=batch_format,
          zero_copy_batch=zero_copy_batch,
          fn_args=fn_args,
          fn_kwargs=fn_kwargs,
          fn_constructor_args=fn_constructor_args,
          fn_constructor_kwargs=fn_constructor_kwargs,
          num_cpus=num_cpus,
          num_gpus=num_gpus,
          memory=memory,
          concurrency=concurrency,
          udf_modifying_row_count=udf_modifying_row_count,
          ray_remote_args_fn=ray_remote_args_fn,
          **ray_remote_args,
      )
  def _map_batches_without_batch_size_validation(
      self,
      fn: UserDefinedFunction[DataBatch, DataBatch],
      *,
      batch_size: Union[int, None, Literal["default"]],
      compute: Optional[ComputeStrategy],
      batch_format: Optional[str],
      zero_copy_batch: bool,
      fn_args: Optional[Iterable[Any]],
      fn_kwargs: Optional[Dict[str, Any]],
      fn_constructor_args: Optional[Iterable[Any]],
      fn_constructor_kwargs: Optional[Dict[str, Any]],
      num_cpus: Optional[float],
      num_gpus: Optional[float],
      memory: Optional[float],
      concurrency: Optional[Union[int, Tuple[int, int], Tuple[int, int, int]]],
      udf_modifying_row_count: bool,
      ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]],
      **ray_remote_args,
  ):
      """Build the ``MapBatches`` dataset without checking ``batch_size``.

      This is the implementation behind :meth:`map_batches`. All parameters
      have the same meaning as there, but none of them has a default.

      Returns:
          A new ``Dataset`` whose logical plan ends in a ``MapBatches``
          operator applied on top of this dataset's plan.
      """
      # NOTE: `map_groups` is implemented on top of `map_batches` with
      #       `batch_size=None`. Requesting GPUs with `batch_size=None` makes
      #       `map_batches` raise a value error, so this separate entry point
      #       without batch size validation exists to let users call
      #       `map_groups` with GPUs.
      if batch_size == "default":
          warnings.warn(
              "Passing 'default' to `map_batches` is deprecated and won't be "
              "supported after September 2025. Use `batch_size=None` instead.",
              DeprecationWarning,
          )
          batch_size = None

      compute_strategy = get_compute_strategy(
          fn,
          fn_constructor_args=fn_constructor_args,
          compute=compute,
          concurrency=concurrency,
      )

      # Explicit resource arguments take precedence over same-named entries
      # in ``ray_remote_args``; unset (None) ones are left alone.
      for resource_key, requested in (
          ("num_cpus", num_cpus),
          ("num_gpus", num_gpus),
          ("memory", memory),
      ):
          if requested is not None:
              ray_remote_args[resource_key] = requested

      resolved_batch_format = _apply_batch_format(batch_format)

      copied_plan = self._plan.copy()
      batch_mapper = MapBatches(
          self._logical_plan.dag,
          fn,
          batch_size=batch_size,
          can_modify_num_rows=udf_modifying_row_count,
          batch_format=resolved_batch_format,
          zero_copy_batch=zero_copy_batch,
          # The batch size doubles as the minimum row count per input bundle.
          min_rows_per_bundled_input=batch_size,
          fn_args=fn_args,
          fn_kwargs=fn_kwargs,
          fn_constructor_args=fn_constructor_args,
          fn_constructor_kwargs=fn_constructor_kwargs,
          compute=compute_strategy,
          ray_remote_args_fn=ray_remote_args_fn,
          ray_remote_args=ray_remote_args,
      )
      return Dataset(copied_plan, LogicalPlan(batch_mapper, self.context))
  @PublicAPI(api_group=EXPRESSION_API_GROUP, stability="alpha")
  def with_column(
      self,
      column_name: str,
      expr: Expr,
      *,
      compute: Optional[ComputeStrategy] = None,
      **ray_remote_args,
  ) -> "Dataset":
      """
      Add a new column to the dataset via an expression.

      The new column's values are produced by evaluating ``expr``. An
      expression can be composed of existing columns, literals, and
      user-defined functions (UDFs).

      For callable class UDFs, Ray Data automatically uses actor semantics to maintain
      state across batches. You can customize the compute strategy to control parallelism
      and resource allocation.

      Examples:
          >>> import ray
          >>> from ray.data.expressions import col
          >>> ds = ray.data.range(100)
          >>> # Add a new column 'id_2' by multiplying 'id' by 2.
          >>> ds.with_column("id_2", col("id") * 2).show(2)
          {'id': 0, 'id_2': 0}
          {'id': 1, 'id_2': 2}

          >>> # Using a UDF with with_column
          >>> from ray.data.datatype import DataType
          >>> from ray.data.expressions import udf
          >>> import pyarrow.compute as pc
          >>>
          >>> @udf(return_dtype=DataType.int32())
          ... def add_one(column):
          ...     return pc.add(column, 1)
          >>>
          >>> ds.with_column("id_plus_one", add_one(col("id"))).show(2)
          {'id': 0, 'id_plus_one': 1}
          {'id': 1, 'id_plus_one': 2}

          >>> # Using a callable class UDF (automatically uses actors)
          >>> @udf(return_dtype=DataType.int32())
          ... class AddOffset:
          ...     def __init__(self, offset):
          ...         self.offset = offset
          ...     def __call__(self, x):
          ...         return pc.add(x, self.offset)
          >>>
          >>> add_five = AddOffset(5)
          >>> ds.with_column("id_plus_five", add_five(col("id"))).show(2)
          {'id': 0, 'id_plus_five': 5}
          {'id': 1, 'id_plus_five': 6}

      Args:
          column_name: The name of the new column.
          expr: An expression that defines the new column values.
          compute: The compute strategy to use for the projection operation.
              If not specified and the expression contains callable class UDFs,
              Ray Data automatically uses ``ActorPoolStrategy`` for actor semantics.
              Otherwise, uses ``TaskPoolStrategy``.

              * Use ``ray.data.ActorPoolStrategy(size=n)`` to use a fixed size
                actor pool of ``n`` workers.

              * Use ``ray.data.ActorPoolStrategy(min_size=m, max_size=n)`` to use
                an autoscaling actor pool from ``m`` to ``n`` workers.

          **ray_remote_args: Additional resource requirements to request from
              Ray for the map tasks (e.g., `num_gpus=1`).

      Returns:
          A new dataset with the added column evaluated via the expression.
      """
      # TODO: update schema based on the expression AST.
      from ray.data._internal.logical.operators import Download, Project

      # TODO: Once the expression API supports UDFs, we can clean up the code here.
      from ray.data.expressions import DownloadExpr

      copied_plan = self._plan.copy()
      source_dag = self._logical_plan.dag

      if isinstance(expr, DownloadExpr):
          # Download expressions get a dedicated operator; ``compute`` is
          # not forwarded on this path.
          op = Download(
              source_dag,
              uri_column_names=[expr.uri_column_name],
              output_bytes_column_names=[column_name],
              ray_remote_args=ray_remote_args,
              filesystem=expr.filesystem,
          )
      else:
          # Everything else is a projection: keep all existing columns
          # (``StarExpr``) and append the aliased expression.
          op = Project(
              source_dag,
              exprs=[StarExpr(), expr.alias(column_name)],
              compute=compute,
              ray_remote_args=ray_remote_args,
          )

      return Dataset(copied_plan, LogicalPlan(op, self.context))
  @Deprecated(message="Use `with_column` API instead")
  @PublicAPI(api_group=BT_API_GROUP)
  def add_column(
      self,
      col: str,
      fn: Callable[
          [DataBatch],
          DataBatchColumn,
      ],
      *,
      batch_format: Optional[str] = "pandas",
      compute: Optional[str] = None,
      concurrency: Optional[int] = None,
      **ray_remote_args,
  ) -> "Dataset":
      """Add the given column to the dataset.

      ``fn`` is called on every batch, materialized in ``batch_format``, and must
      return the values of the new column for that batch.

      Examples:

          >>> import ray
          >>> ds = ray.data.range(100)
          >>> ds.schema()
          Column  Type
          ------  ----
          id      int64

          Add a new column equal to ``id * 2``.

          >>> ds.add_column("new_id", lambda df: df["id"] * 2).schema()
          Column  Type
          ------  ----
          id      int64
          new_id  int64

      Time complexity: O(dataset size / parallelism)

      Args:
          col: Name of the column to add. If the name already exists, the
              column is overwritten.
          fn: Map function generating the column values given a batch of
              records in ``batch_format``. For ``"pyarrow"`` it must return a
              ``pyarrow.Array`` or ``pyarrow.ChunkedArray``; for ``"numpy"`` it
              must return a ``numpy.ndarray``.
          batch_format: One of ``"pandas"``, ``"pyarrow"`` or ``"numpy"``.
              If ``"pandas"``, batches are ``pandas.DataFrame``. If
              ``"pyarrow"``, batches are ``pyarrow.Table``. If ``"numpy"``,
              batches are ``Dict[str, numpy.ndarray]``.
          compute: This argument is deprecated. Use ``concurrency`` argument.
          concurrency: The maximum number of Ray workers to use concurrently.
          ray_remote_args: Additional resource requirements to request from
              Ray (e.g., num_gpus=1 to request GPUs for the map tasks). See
              :func:`ray.remote` for details.

      Returns:
          The dataset produced by :meth:`~Dataset.map_batches` with the column
          added (or overwritten).

      Raises:
          ValueError: If ``batch_format`` is not one of the accepted formats,
              or if ``fn`` is not callable.
      """
      # Validate the arguments up front, before any work is scheduled.
      accepted_batch_formats = ["pandas", "pyarrow", "numpy"]
      if batch_format not in accepted_batch_formats:
          raise ValueError(
              f"batch_format argument must be one of {accepted_batch_formats}, "
              f"got: {batch_format}"
          )

      if not callable(fn):
          raise ValueError(f"`fn` must be callable, got {fn}")

      # NOTE: The inner function keeps the name `add_column` because it is the
      # callable handed to `map_batches`.
      def add_column(batch: DataBatch) -> DataBatch:
          column = fn(batch)

          if batch_format == "pandas":
              import pandas as pd

              # The index of the column must be set
              # to align with the index of the batch.
              if isinstance(column, (pd.DataFrame, pd.Index, pd.Series)):
                  column.index = batch.index
              batch.loc[:, col] = column
              return batch

          if batch_format == "pyarrow":
              import pyarrow as pa

              assert isinstance(column, (pa.Array, pa.ChunkedArray)), (
                  f"For pyarrow batch format, the function must return a pyarrow "
                  f"Array, got: {type(column)}"
              )
              # Historically, this method was written for pandas batch format.
              # To resolve https://github.com/ray-project/ray/issues/48090,
              # we also allow pyarrow batch format which is preferred but would be
              # a breaking change to enforce.

              # For pyarrow, the index of the column will be -1 if it is missing in
              # which case we'll want to append it
              column_idx = batch.schema.get_field_index(col)
              if column_idx == -1:
                  return batch.append_column(col, column)
              return batch.set_column(column_idx, col, column)

          # batch format is assumed to be numpy since we checked at the
          # beginning of the add_column function
          assert isinstance(column, np.ndarray), (
              f"For numpy batch format, the function must return a "
              f"numpy.ndarray, got: {type(column)}"
          )
          batch[col] = column
          return batch

      return self.map_batches(
          add_column,
          batch_format=batch_format,
          compute=compute,
          concurrency=concurrency,
          zero_copy_batch=True,
          **ray_remote_args,
      )
  @PublicAPI(api_group=BT_API_GROUP)
  def drop_columns(
      self,
      cols: List[str],
      *,
      compute: Optional[str] = None,
      concurrency: Optional[int] = None,
      **ray_remote_args,
  ) -> "Dataset":
      """Drop one or more columns from the dataset.

      Examples:

          >>> import ray
          >>> ds = ray.data.read_parquet("s3://anonymous@ray-example-data/iris.parquet")
          >>> ds.schema()
          Column        Type
          ------        ----
          sepal.length  double
          sepal.width   double
          petal.length  double
          petal.width   double
          variety       string
          >>> ds.drop_columns(["variety"]).schema()
          Column        Type
          ------        ----
          sepal.length  double
          sepal.width   double
          petal.length  double
          petal.width   double

      Time complexity: O(dataset size / parallelism)

      Args:
          cols: Names of the columns to drop. If any name does not exist,
              an exception is raised. Column names must be unique.
          compute: This argument is deprecated. Use ``concurrency`` argument.
          concurrency: The maximum number of Ray workers to use concurrently.
          ray_remote_args: Additional resource requirements to request from
              Ray (e.g., num_gpus=1 to request GPUs for the map tasks). See
              :func:`ray.remote` for details.

      Returns:
          The dataset produced by :meth:`~Dataset.map_batches` without the
          dropped columns.

      Raises:
          ValueError: If ``cols`` contains the same name more than once.
      """  # noqa: E501
      # A set collapses repeated names, so a shorter set means duplicates.
      distinct_names = set(cols)
      if len(distinct_names) != len(cols):
          raise ValueError(f"drop_columns expects unique column names, got: {cols}")

      # The UDF keeps the name `drop_columns`; it is what `map_batches` receives.
      def drop_columns(batch):
          return batch.drop(cols)

      # Batches are requested as pyarrow tables, so `batch.drop` above is the
      # table's own column-dropping method.
      return self.map_batches(
          drop_columns,
          batch_format="pyarrow",
          zero_copy_batch=True,
          compute=compute,
          concurrency=concurrency,
          **ray_remote_args,
      )
  @PublicAPI(api_group=BT_API_GROUP)
  def select_columns(
      self,
      cols: Union[str, List[str]],
      *,
      compute: Union[str, ComputeStrategy] = None,
      concurrency: Optional[int] = None,
      **ray_remote_args,
  ) -> "Dataset":
      """Select one or more columns from the dataset.

      Specified columns must be in the dataset schema.

      .. tip::

          If you're reading parquet files with :meth:`ray.data.read_parquet`,
          you might be able to speed it up by using projection pushdown; see
          :ref:`Parquet column pruning <parquet_column_pruning>` for details.

      Examples:

          >>> import ray
          >>> ds = ray.data.read_parquet("s3://anonymous@ray-example-data/iris.parquet")
          >>> ds.schema()
          Column        Type
          ------        ----
          sepal.length  double
          sepal.width   double
          petal.length  double
          petal.width   double
          variety       string
          >>> ds.select_columns(["sepal.length", "sepal.width"]).schema()
          Column        Type
          ------        ----
          sepal.length  double
          sepal.width   double

      Time complexity: O(dataset size / parallelism)

      Args:
          cols: Names of the columns to select. If a name isn't in the
              dataset schema, an exception is raised. Columns also should be unique.
          compute: This argument is deprecated. Use ``concurrency`` argument.
              The passed value is not used; the projection always runs with
              ``TaskPoolStrategy(size=concurrency)``.
          concurrency: The maximum number of Ray workers to use concurrently.
          ray_remote_args: Additional resource requirements to request from
              Ray (e.g., num_gpus=1 to request GPUs for the map tasks). See
              :func:`ray.remote` for details.

      Returns:
          A new dataset whose logical plan ends in a projection onto ``cols``.

      Raises:
          ValueError: If ``cols`` is a list containing a non-string element or
              a duplicate name.
          TypeError: If ``cols`` is neither a string nor a list.
      """  # noqa: E501
      from ray.data.expressions import col

      # Normalize the input to a list of names, validating the list form.
      if isinstance(cols, str):
          selected_names = [cols]
      elif isinstance(cols, list):
          if any(not isinstance(name, str) for name in cols):
              raise ValueError(
                  "select_columns requires all elements of 'cols' to be strings."
              )
          if len(set(cols)) != len(cols):
              raise ValueError(
                  "select_columns expected unique column names, "
                  f"got duplicate column names: {cols}"
              )
          selected_names = cols
      else:
          raise TypeError(
              "select_columns requires 'cols' to be a string or a list of strings."
          )

      exprs = [col(name) for name in selected_names]

      # The caller-supplied `compute` is deliberately replaced here.
      compute = TaskPoolStrategy(size=concurrency)

      plan = self._plan.copy()
      projection = Project(
          self._logical_plan.dag,
          exprs=exprs,
          compute=compute,
          ray_remote_args=ray_remote_args,
      )
      return Dataset(plan, LogicalPlan(projection, self.context))
  @PublicAPI(api_group=BT_API_GROUP)
  def rename_columns(
      self,
      names: Union[List[str], Dict[str, str]],
      *,
      concurrency: Optional[Union[int, Tuple[int, int], Tuple[int, int, int]]] = None,
      **ray_remote_args,
  ) -> "Dataset":
      """Rename columns in the dataset.

      Examples:

          >>> import ray
          >>> ds = ray.data.read_parquet("s3://anonymous@ray-example-data/iris.parquet")
          >>> ds.schema()
          Column        Type
          ------        ----
          sepal.length  double
          sepal.width   double
          petal.length  double
          petal.width   double
          variety       string

          You can pass a dictionary mapping old column names to new column names.

          >>> ds.rename_columns({"variety": "category"}).schema()
          Column        Type
          ------        ----
          sepal.length  double
          sepal.width   double
          petal.length  double
          petal.width   double
          category      string

          Or you can pass a list of new column names.

          >>> ds.rename_columns(
          ...     ["sepal_length", "sepal_width", "petal_length", "petal_width", "variety"]
          ... ).schema()
          Column        Type
          ------        ----
          sepal_length  double
          sepal_width   double
          petal_length  double
          petal_width   double
          variety       string

      Args:
          names: A dictionary that maps old column names to new column names, or a
              list of new column names. A list must have exactly one entry per
              column of the current schema.
          concurrency: The maximum number of Ray workers to use concurrently.
              Only an integer or ``None`` is accepted.
          ray_remote_args: Additional resource requirements to request from
              Ray (e.g., num_gpus=1 to request GPUs for the map tasks). See
              :func:`ray.remote` for details.

      Returns:
          A new dataset whose logical plan ends in a projection that renames
          the requested columns.

      Raises:
          ValueError: If ``names`` is empty, contains non-string entries or
              duplicate new names, if a list's length differs from the number
              of columns in the current schema, or if ``concurrency`` is
              neither an integer nor ``None``.
          TypeError: If ``names`` is neither a list nor a dictionary.
      """  # noqa: E501
      # Imported locally, as `select_columns` does, so this method does not rely
      # on a module-level `col` binding.
      from ray.data.expressions import col

      if isinstance(names, dict):
          if not names:
              raise ValueError("rename_columns received 'names' with no entries.")

          # Check types before the duplicate check: building a set from
          # unhashable values would otherwise raise a raw TypeError.
          if not all(
              isinstance(k, str) and isinstance(v, str) for k, v in names.items()
          ):
              raise ValueError(
                  "rename_columns requires both keys and values in the 'names' "
                  "to be strings."
              )

          if len(names.values()) != len(set(names.values())):
              raise ValueError(
                  f"rename_columns received duplicate values in the 'names': {names}"
              )

          exprs = [col(prev)._rename(new) for prev, new in names.items()]

      elif isinstance(names, list):
          if not names:
              raise ValueError(
                  "rename_columns requires 'names' with at least one column name."
              )

          # Same ordering as above: types first, then uniqueness.
          if not all(isinstance(name, str) for name in names):
              raise ValueError(
                  "rename_columns requires all elements in the 'names' to be strings."
              )

          if len(names) != len(set(names)):
              raise ValueError(
                  f"rename_columns received duplicate values in the 'names': {names}"
              )

          # A list renames positionally, so it must cover every current column.
          current_names = self.schema().names
          if len(current_names) != len(names):
              raise ValueError(
                  f"rename_columns requires 'names': {names} length match current "
                  f"schema names: {current_names}."
              )

          exprs = [col(prev)._rename(new) for prev, new in zip(current_names, names)]
      else:
          raise TypeError(
              f"rename_columns expected names to be either List[str] or "
              f"Dict[str, str], got {type(names)}."
          )

      if concurrency is not None and not isinstance(concurrency, int):
          raise ValueError(
              f"Expected `concurrency` to be an integer or `None`, but "
              f"got {concurrency}."
          )

      # Construct the plan and project operation
      compute = TaskPoolStrategy(size=concurrency)

      plan = self._plan.copy()
      select_op = Project(
          self._logical_plan.dag,
          exprs=[StarExpr(), *exprs],
          compute=compute,
          ray_remote_args=ray_remote_args,
      )
      logical_plan = LogicalPlan(select_op, self.context)
      return Dataset(plan, logical_plan)
  @PublicAPI(api_group=BT_API_GROUP)
  def flat_map(
      self,
      fn: UserDefinedFunction[
          Dict[str, Any], Union[List[Dict[str, Any]], Dict[str, Any]]
      ],
      *,
      compute: Optional[ComputeStrategy] = None,
      fn_args: Optional[Iterable[Any]] = None,
      fn_kwargs: Optional[Dict[str, Any]] = None,
      fn_constructor_args: Optional[Iterable[Any]] = None,
      fn_constructor_kwargs: Optional[Dict[str, Any]] = None,
      num_cpus: Optional[float] = None,
      num_gpus: Optional[float] = None,
      memory: Optional[float] = None,
      concurrency: Optional[Union[int, Tuple[int, int], Tuple[int, int, int]]] = None,
      ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
      **ray_remote_args,
  ) -> "Dataset":
      """Apply the given function to each row and then flatten results.

      Use this method if your transformation returns multiple rows for each input
      row.

      You can use either a function or a callable class to perform the transformation.
      For functions, Ray Data uses stateless Ray tasks. For classes, Ray Data uses
      stateful Ray actors. For more information, see
      :ref:`Stateful Transforms <stateful_transforms>`.

      .. tip::
          :meth:`~Dataset.map_batches` can also modify the number of rows. If your
          transformation is vectorized like most NumPy and pandas operations,
          it might be faster.

      .. warning::
          Specifying both ``num_cpus`` and ``num_gpus`` for map tasks is experimental,
          and may result in scheduling or stability issues. Please
          `report any issues <https://github.com/ray-project/ray/issues/new/choose>`_
          to the Ray team.

      Examples:

          .. testcode::

              from typing import Any, Dict, List
              import ray

              def duplicate_row(row: Dict[str, Any]) -> List[Dict[str, Any]]:
                  return [row] * 2

              print(
                  ray.data.range(3)
                  .flat_map(duplicate_row)
                  .take_all()
              )

          .. testoutput::

              [{'id': 0}, {'id': 0}, {'id': 1}, {'id': 1}, {'id': 2}, {'id': 2}]

      Time complexity: O(dataset size / parallelism)

      Args:
          fn: The function or generator to apply to each record, or a class type
              that can be instantiated to create such a callable.
          compute: The compute strategy to use for the map operation.

              * If ``compute`` is not specified for a function, will use ``ray.data.TaskPoolStrategy()`` to launch concurrent tasks based on the available resources and number of input blocks.

              * Use ``ray.data.TaskPoolStrategy(size=n)`` to launch at most ``n`` concurrent Ray tasks.

              * If ``compute`` is not specified for a callable class, will use ``ray.data.ActorPoolStrategy(min_size=1, max_size=None)`` to launch an autoscaling actor pool from 1 to unlimited workers.

              * Use ``ray.data.ActorPoolStrategy(size=n)`` to use a fixed size actor pool of ``n`` workers.

              * Use ``ray.data.ActorPoolStrategy(min_size=m, max_size=n)`` to use an autoscaling actor pool from ``m`` to ``n`` workers.

              * Use ``ray.data.ActorPoolStrategy(min_size=m, max_size=n, initial_size=initial)`` to use an autoscaling actor pool from ``m`` to ``n`` workers, with an initial size of ``initial``.

          fn_args: Positional arguments to pass to ``fn`` after the first argument.
              These arguments are top-level arguments to the underlying Ray task.
          fn_kwargs: Keyword arguments to pass to ``fn``. These arguments are
              top-level arguments to the underlying Ray task.
          fn_constructor_args: Positional arguments to pass to ``fn``'s constructor.
              You can only provide this if ``fn`` is a callable class. These arguments
              are top-level arguments in the underlying Ray actor construction task.
          fn_constructor_kwargs: Keyword arguments to pass to ``fn``'s constructor.
              This can only be provided if ``fn`` is a callable class. These arguments
              are top-level arguments in the underlying Ray actor construction task.
          num_cpus: The number of CPUs to reserve for each parallel map worker.
          num_gpus: The number of GPUs to reserve for each parallel map worker. For
              example, specify `num_gpus=1` to request 1 GPU for each parallel map
              worker.
          memory: The heap memory in bytes to reserve for each parallel map worker.
          concurrency: This argument is deprecated. Use ``compute`` argument.
          ray_remote_args_fn: A function that returns a dictionary of remote args
              passed to each map worker. The purpose of this argument is to generate
              dynamic arguments for each actor/task, and will be called each time
              prior to initializing the worker. Args returned from this dict will
              always override the args in ``ray_remote_args``. Note: this is an
              advanced, experimental feature.
          ray_remote_args: Additional resource requirements to request from
              Ray for each map worker. See :func:`ray.remote` for details.

      Returns:
          A new dataset whose logical plan ends in the flat-map operator.

      .. seealso::

          :meth:`~Dataset.map_batches`
              Call this method to transform batches of data.

          :meth:`~Dataset.map`
              Call this method to transform one row at time.
      """
      # Resolve the compute strategy from `fn`, `compute` and the deprecated
      # `concurrency` argument.
      strategy = get_compute_strategy(
          fn,
          fn_constructor_args=fn_constructor_args,
          compute=compute,
          concurrency=concurrency,
      )

      # Fold the dedicated resource arguments into the remote-args mapping.
      remote_args = merge_resources_to_ray_remote_args(
          num_cpus,
          num_gpus,
          memory,
          ray_remote_args,
      )

      plan = self._plan.copy()
      flat_map_op = FlatMap(
          input_op=self._logical_plan.dag,
          fn=fn,
          fn_args=fn_args,
          fn_kwargs=fn_kwargs,
          fn_constructor_args=fn_constructor_args,
          fn_constructor_kwargs=fn_constructor_kwargs,
          compute=strategy,
          ray_remote_args_fn=ray_remote_args_fn,
          ray_remote_args=remote_args,
      )
      return Dataset(plan, LogicalPlan(flat_map_op, self.context))
  @PublicAPI(api_group=BT_API_GROUP)
  def filter(
      self,
      fn: Optional[UserDefinedFunction[Dict[str, Any], bool]] = None,
      expr: Optional[Union[str, Expr]] = None,
      *,
      compute: Union[str, ComputeStrategy] = None,
      fn_args: Optional[Iterable[Any]] = None,
      fn_kwargs: Optional[Dict[str, Any]] = None,
      fn_constructor_args: Optional[Iterable[Any]] = None,
      fn_constructor_kwargs: Optional[Dict[str, Any]] = None,
      num_cpus: Optional[float] = None,
      num_gpus: Optional[float] = None,
      memory: Optional[float] = None,
      concurrency: Optional[Union[int, Tuple[int, int], Tuple[int, int, int]]] = None,
      ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
      **ray_remote_args,
  ) -> "Dataset":
      """Filter out rows that don't satisfy the given predicate.

      You can use either a function or a callable class or an expression to
      perform the transformation.
      For functions, Ray Data uses stateless Ray tasks. For classes, Ray Data uses
      stateful Ray actors. For more information, see
      :ref:`Stateful Transforms <stateful_transforms>`.

      .. tip::
          If you use the `expr` parameter with a predicate expression, Ray Data
          optimizes your filter with native Arrow interfaces.

      .. deprecated::
          String expressions are deprecated and will be removed in a future version.
          Use predicate expressions from `ray.data.expressions` instead.

      Examples:

          >>> import ray
          >>> from ray.data.expressions import col
          >>> ds = ray.data.range(100)
          >>> # String expressions (deprecated - will warn)
          >>> ds.filter(expr="id <= 4").take_all()
          [{'id': 0}, {'id': 1}, {'id': 2}, {'id': 3}, {'id': 4}]
          >>> # Using predicate expressions (preferred)
          >>> ds.filter(expr=(col("id") > 10) & (col("id") < 20)).take_all()
          [{'id': 11}, {'id': 12}, {'id': 13}, {'id': 14}, {'id': 15}, {'id': 16}, {'id': 17}, {'id': 18}, {'id': 19}]

      Time complexity: O(dataset size / parallelism)

      Args:
          fn: The predicate to apply to each row, or a class type
              that can be instantiated to create such a callable.
          expr: An expression that represents a predicate (boolean condition) for filtering.
              Can be either a string expression (deprecated) or a predicate expression
              from `ray.data.expressions`.
          fn_args: Positional arguments to pass to ``fn`` after the first argument.
              These arguments are top-level arguments to the underlying Ray task.
          fn_kwargs: Keyword arguments to pass to ``fn``. These arguments are
              top-level arguments to the underlying Ray task.
          fn_constructor_args: Positional arguments to pass to ``fn``'s constructor.
              You can only provide this if ``fn`` is a callable class. These arguments
              are top-level arguments in the underlying Ray actor construction task.
          fn_constructor_kwargs: Keyword arguments to pass to ``fn``'s constructor.
              This can only be provided if ``fn`` is a callable class. These arguments
              are top-level arguments in the underlying Ray actor construction task.
          compute: The compute strategy to use for the map operation. Only
              consulted when ``fn`` is given; with ``expr`` the filter always
              runs with ``TaskPoolStrategy(size=concurrency)``.

              * If ``compute`` is not specified for a function, will use ``ray.data.TaskPoolStrategy()`` to launch concurrent tasks based on the available resources and number of input blocks.

              * Use ``ray.data.TaskPoolStrategy(size=n)`` to launch at most ``n`` concurrent Ray tasks.

              * If ``compute`` is not specified for a callable class, will use ``ray.data.ActorPoolStrategy(min_size=1, max_size=None)`` to launch an autoscaling actor pool from 1 to unlimited workers.

              * Use ``ray.data.ActorPoolStrategy(size=n)`` to use a fixed size actor pool of ``n`` workers.

              * Use ``ray.data.ActorPoolStrategy(min_size=m, max_size=n)`` to use an autoscaling actor pool from ``m`` to ``n`` workers.

              * Use ``ray.data.ActorPoolStrategy(min_size=m, max_size=n, initial_size=initial)`` to use an autoscaling actor pool from ``m`` to ``n`` workers, with an initial size of ``initial``.

          num_cpus: The number of CPUs to reserve for each parallel map worker.
          num_gpus: The number of GPUs to reserve for each parallel map worker. For
              example, specify `num_gpus=1` to request 1 GPU for each parallel map
              worker.
          memory: The heap memory in bytes to reserve for each parallel map worker.
          concurrency: This argument is deprecated. Use ``compute`` argument.
          ray_remote_args_fn: A function that returns a dictionary of remote args
              passed to each map worker. The purpose of this argument is to generate
              dynamic arguments for each actor/task, and will be called each time
              prior to initializing the worker. Args returned from this dict will
              always override the args in ``ray_remote_args``. Note: this is an
              advanced, experimental feature.
          ray_remote_args: Additional resource requirements to request from
              Ray (e.g., num_gpus=1 to request GPUs for the map tasks). See
              :func:`ray.remote` for details.

      Returns:
          A new dataset whose logical plan ends in the filter operator.

      Raises:
          ValueError: If neither or both of ``fn`` and ``expr`` are provided,
              if any of the ``fn_*`` arguments is combined with ``expr``, or
              if ``fn`` is not callable.
      """
      # Ensure exactly one of fn or expr is provided
      if (fn is None) == (expr is None):
          raise ValueError("Exactly one of 'fn' or 'expr' must be provided.")

      # Merge ray remote args early
      ray_remote_args = merge_resources_to_ray_remote_args(
          num_cpus,
          num_gpus,
          memory,
          ray_remote_args,
      )

      predicate_expr: Optional[Expr] = None
      filter_compute: Optional[ComputeStrategy] = None

      if expr is not None:
          # The UDF-only arguments have no meaning for an expression filter.
          udf_only_args = (
              fn_args,
              fn_kwargs,
              fn_constructor_args,
              fn_constructor_kwargs,
          )
          if any(arg is not None for arg in udf_only_args):
              raise ValueError(
                  "when 'expr' is used, 'fn_args/fn_kwargs' or "
                  "'fn_constructor_args/fn_constructor_kwargs' cannot be used."
              )

          # Check if expr is a string (deprecated) or Expr object
          if isinstance(expr, str):
              warnings.warn(
                  "String expressions are deprecated and will be removed in a future version. "
                  "Use predicate expressions from ray.data.expressions instead. "
                  "For example: from ray.data.expressions import col; "
                  "ds.filter(expr=col('column_name') > 5)",
                  DeprecationWarning,
                  stacklevel=2,
              )

              from ray.data._internal.planner.plan_expression.expression_evaluator import (  # noqa: E501
                  ExpressionEvaluator,
              )

              # TODO: (srinathk) bind the expression to the actual schema.
              # str -> Ray Data's Expression
              predicate_expr = ExpressionEvaluator.parse_native_expression(expr)
          else:
              # expr is an Expr object (predicate expression)
              predicate_expr = expr

          # `compute` is not consulted on this path.
          filter_compute = TaskPoolStrategy(size=concurrency)
      else:
          # Reject an invalid `fn` before emitting any advisory warning.
          if not callable(fn):
              raise ValueError(
                  f"fn must be a UserDefinedFunction, but got "
                  f"{type(fn).__name__} instead."
              )

          # stacklevel=2 matches the deprecation warning above, so the warning
          # is attributed to the caller of `filter`.
          warnings.warn(
              "Use 'expr' instead of 'fn' when possible for performant filters.",
              stacklevel=2,
          )

          filter_compute = get_compute_strategy(
              fn=fn,
              fn_constructor_args=fn_constructor_args,
              compute=compute,
              concurrency=concurrency,
          )

      # On the `expr` path `fn` and every `fn_*` argument are None (enforced
      # above), so they can be forwarded unchanged on both paths.
      filter_op = Filter(
          input_op=self._logical_plan.dag,
          predicate_expr=predicate_expr,
          fn=fn,
          fn_args=fn_args,
          fn_kwargs=fn_kwargs,
          fn_constructor_args=fn_constructor_args,
          fn_constructor_kwargs=fn_constructor_kwargs,
          compute=filter_compute,
          ray_remote_args_fn=ray_remote_args_fn,
          ray_remote_args=ray_remote_args,
      )

      plan = self._plan.copy()
      logical_plan = LogicalPlan(filter_op, self.context)
      return Dataset(plan, logical_plan)
  @PublicAPI(api_group=SSR_API_GROUP)
  def repartition(
      self,
      num_blocks: Optional[int] = None,
      target_num_rows_per_block: Optional[int] = None,
      *,
      shuffle: bool = False,
      keys: Optional[List[str]] = None,
      sort: bool = False,
  ) -> "Dataset":
      """Repartition the :class:`Dataset` into exactly this number of
      :ref:`blocks <dataset_concept>`.

      This method can be useful to tune the performance of your pipeline. To learn
      more, see :ref:`Advanced: Performance Tips and Tuning <data_performance_tips>`.

      If you're writing data to files, you can also use this method to change the
      number of output files. To learn more, see
      :ref:`Changing the number of output files <changing-number-output-files>`.

      .. note::

          Repartition has three modes:

          * When ``num_blocks`` and ``shuffle=True`` are specified Ray Data performs a full distributed shuffle producing exactly ``num_blocks`` blocks.

          * When ``num_blocks`` and ``shuffle=False`` are specified, Ray Data does NOT perform full shuffle, instead opting in for splitting and combining of the blocks attempting to minimize the necessary data movement (relative to full-blown shuffle). Exactly ``num_blocks`` will be produced.

          * If ``target_num_rows_per_block`` is set (exclusive with ``num_blocks`` and ``shuffle``), streaming repartitioning will be executed, where blocks will be made to carry no more than ``target_num_rows_per_block`` rows. Smaller blocks will be combined into bigger ones up to ``target_num_rows_per_block`` as well.

      .. image:: /data/images/dataset-shuffle.svg
          :align: center

      ..
          https://docs.google.com/drawings/d/132jhE3KXZsf29ho1yUdPrCHB9uheHBWHJhDQMXqIVPA/edit

      Examples:

          >>> import ray
          >>> ds = ray.data.range(100).repartition(10).materialize()
          >>> ds.num_blocks()
          10

      Time complexity: O(dataset size / parallelism)

      Args:
          num_blocks: Number of blocks after repartitioning.
          target_num_rows_per_block: [Experimental] The target number of rows per block to
              repartition. Performs streaming repartitioning of the dataset (no shuffling).
              Note that either `num_blocks` or
              `target_num_rows_per_block` must be set, but not both. When
              `target_num_rows_per_block` is set, it only repartitions
              :class:`Dataset` :ref:`blocks <dataset_concept>` that are larger than
              `target_num_rows_per_block`. Note that the system will internally
              figure out the number of rows per :ref:`blocks <dataset_concept>` for
              optimal execution, based on the `target_num_rows_per_block`. This is
              the current behavior because of the implementation and may change in
              the future.
          shuffle: Whether to perform a distributed shuffle during the
              repartition. When shuffle is enabled, each output block
              contains a subset of data rows from each input block, which
              requires all-to-all data movement. When shuffle is disabled,
              output blocks are created from adjacent input blocks,
              minimizing data movement. Must be ``False`` when
              `target_num_rows_per_block` is set.
          keys: List of key columns repartitioning will use to determine which
              partition a row belongs to after repartitioning (by applying
              hash-partitioning algorithm to the whole dataset). Note that, this
              config is only relevant when `DataContext.use_hash_based_shuffle`
              is set to True. Ignored (with a warning) when
              `target_num_rows_per_block` is set.
          sort: Whether the blocks should be sorted after repartitioning. Note,
              that by default blocks will be sorted in the ascending order.
              Ignored (with a warning) when `target_num_rows_per_block` is set.

      Note that you must set either `num_blocks` or `target_num_rows_per_block`
      but not both.
      Additionally note that this operation materializes the entire dataset in memory
      when you set shuffle to True.

      Returns:
          The repartitioned :class:`Dataset`.

      Raises:
          ValueError: If neither or both of `num_blocks` and
              `target_num_rows_per_block` are set, or if `shuffle` is true
              while `target_num_rows_per_block` is set.
      """  # noqa: E501
      # Reject invalid argument combinations first, so that no "ignored"
      # warning is emitted for a call that fails anyway.
      if (num_blocks is None) and (target_num_rows_per_block is None):
          raise ValueError(
              "Either `num_blocks` or `target_num_rows_per_block` must be set"
          )

      if (num_blocks is not None) and (target_num_rows_per_block is not None):
          raise ValueError(
              "Only one of `num_blocks` or `target_num_rows_per_block` must be set, "
              "but not both."
          )

      streaming = target_num_rows_per_block is not None

      if streaming:
          # `shuffle` is an error here rather than an ignored option, so it
          # raises instead of warning.
          if shuffle:
              raise ValueError(
                  "`shuffle` must be False when `target_num_rows_per_block` is set."
              )

          # `keys` and `sort` are not passed to the streaming operator below.
          if keys is not None:
              warnings.warn(
                  "`keys` is ignored when `target_num_rows_per_block` is set.",
                  stacklevel=2,
              )
          if sort is not False:
              warnings.warn(
                  "`sort` is ignored when `target_num_rows_per_block` is set.",
                  stacklevel=2,
              )

      plan = self._plan.copy()
      if streaming:
          op = StreamingRepartition(
              self._logical_plan.dag,
              target_num_rows_per_block=target_num_rows_per_block,
          )
      else:
          op = Repartition(
              self._logical_plan.dag,
              num_outputs=num_blocks,
              shuffle=shuffle,
              keys=keys,
              sort=sort,
          )

      logical_plan = LogicalPlan(op, self.context)
      return Dataset(plan, logical_plan)
  @AllToAllAPI
  @PublicAPI(api_group=SSR_API_GROUP)
  def random_shuffle(
      self,
      *,
      seed: Optional[int] = None,
      num_blocks: Optional[int] = None,
      **ray_remote_args,
  ) -> "Dataset":
      """Randomly shuffle the rows of this :class:`Dataset`.

      .. tip::

          This method can be slow. For better performance, try
          :ref:`Iterating over batches with shuffling <iterating-over-batches-with-shuffling>`.
          Also, see :ref:`Optimizing shuffles <optimizing_shuffles>`.

      Examples:

          >>> import ray
          >>> ds = ray.data.range(100)
          >>> ds.random_shuffle().take(3)  # doctest: +SKIP
          [{'id': 41}, {'id': 21}, {'id': 92}]
          >>> ds.random_shuffle(seed=42).take(3)  # doctest: +SKIP
          [{'id': 77}, {'id': 21}, {'id': 63}]

      Time complexity: O(dataset size / parallelism)

      Args:
          seed: Fix the random seed to use, otherwise one is chosen
              based on system randomness.
          num_blocks: Deprecated. Passing any value other than ``None``
              raises; use :meth:`~Dataset.repartition` to change the number of
              output blocks.
          ray_remote_args: Forwarded unchanged to the shuffle operator as its
              ``ray_remote_args``.

      Returns:
          The shuffled :class:`Dataset`.

      Raises:
          DeprecationWarning: If ``num_blocks`` is not ``None``.
      """  # noqa: E501
      # The deprecated argument is a hard error: the warning class is raised
      # as an exception, not emitted through `warnings.warn`.
      if num_blocks is not None:
          raise DeprecationWarning(
              "`num_blocks` parameter is deprecated in Ray 2.9. random_shuffle() "
              "does not support to change the number of output blocks. Use "
              "repartition() instead."
          )

      plan = self._plan.copy()
      shuffle_op = RandomShuffle(
          self._logical_plan.dag,
          seed=seed,
          ray_remote_args=ray_remote_args,
      )
      return Dataset(plan, LogicalPlan(shuffle_op, self.context))
  1537. @AllToAllAPI
  1538. @PublicAPI(api_group=SSR_API_GROUP)
  def randomize_block_order(
      self,
      *,
      seed: Optional[int] = None,
  ) -> "Dataset":
      """Randomly shuffle the :ref:`blocks <dataset_concept>` of this :class:`Dataset`.

      This method is useful if you :meth:`~Dataset.split` your dataset into shards and
      want to randomize the data in each shard without performing a full
      :meth:`~Dataset.random_shuffle`.

      Examples:
          >>> import ray
          >>> ds = ray.data.range(100)
          >>> ds.take(5)
          [{'id': 0}, {'id': 1}, {'id': 2}, {'id': 3}, {'id': 4}]
          >>> ds.randomize_block_order().take(5)  # doctest: +SKIP
          [{'id': 15}, {'id': 16}, {'id': 17}, {'id': 18}, {'id': 19}]

      Args:
          seed: Fix the random seed to use, otherwise one is chosen
              based on system randomness.

      Returns:
          The block-shuffled :class:`Dataset`.
      """  # noqa: E501
      plan_copy = self._plan.copy()
      block_shuffle_op = RandomizeBlocks(self._logical_plan.dag, seed=seed)
      return Dataset(plan_copy, LogicalPlan(block_shuffle_op, self.context))
  1568. @PublicAPI(api_group=BT_API_GROUP)
  def random_sample(
      self, fraction: float, *, seed: Optional[int] = None
  ) -> "Dataset":
      """Returns a new :class:`Dataset` containing a random fraction of the rows.

      .. note::

          This method returns roughly ``fraction * total_rows`` rows. An exact number
          of rows isn't guaranteed.

      Examples:
          >>> import ray
          >>> ds1 = ray.data.range(100)
          >>> ds1.random_sample(0.1).count()  # doctest: +SKIP
          10
          >>> ds2 = ray.data.range(1000)
          >>> ds2.random_sample(0.123, seed=42).take(2)  # doctest: +SKIP
          [{'id': 2}, {'id': 9}]
          >>> ds2.random_sample(0.123, seed=42).take(2)  # doctest: +SKIP
          [{'id': 2}, {'id': 9}]

      Args:
          fraction: The fraction of elements to sample. Must lie in ``[0, 1]``.
          seed: Seed for the NumPy random generator used to pick rows. Each
              task combines it with its task index, so different tasks draw
              different (but reproducible) streams.

      Returns:
          Returns a :class:`Dataset` containing the sampled rows.

      Raises:
          ValueError: If ``fraction`` is outside ``[0, 1]`` (NaN included), or
              if the plan reports zero initial blocks.
      """
      # Check the argument before importing anything or touching the plan.
      # The chained comparison is False for NaN, so NaN is rejected too; the
      # previous `fraction < 0 or fraction > 1` test let it through and
      # silently produced an empty sample.
      if not (0 <= fraction <= 1):
          raise ValueError("Fraction must be between 0 and 1.")

      import pandas as pd
      import pyarrow as pa

      if self._plan.initial_num_blocks() == 0:
          raise ValueError("Cannot sample from an empty Dataset.")

      from ray.data._internal.execution.interfaces.task_context import TaskContext

      def random_sample(batch: DataBatch, seed: Optional[int]):
          ctx = TaskContext.get_current()

          # Create the generator once per task context and cache it there, so
          # later batches of the same task continue the same random stream.
          if "rng" in ctx.kwargs:
              rng = ctx.kwargs["rng"]
          else:
              if seed is None:
                  rng = np.random.default_rng()
              else:
                  rng = np.random.default_rng([ctx.task_idx, seed])
              ctx.kwargs["rng"] = rng

          # Keep each row independently with probability `fraction`.
          mask_idx = np.where(rng.random(len(batch)) < fraction)[0]
          if isinstance(batch, pa.Table):
              return batch.take(mask_idx)
          elif isinstance(batch, pd.DataFrame):
              return batch.iloc[mask_idx, :]

          raise ValueError(f"Unsupported batch type: {type(batch)}")

      return self.map_batches(
          random_sample,
          fn_args=[seed],
          batch_format=None,
          batch_size=None,
      )
  1621. @ConsumptionAPI
  1622. @PublicAPI(api_group=SMJ_API_GROUP)
  def streaming_split(
      self,
      n: int,
      *,
      equal: bool = False,
      locality_hints: Optional[List["NodeIdStr"]] = None,
  ) -> List[DataIterator]:
      """Returns ``n`` :class:`DataIterators <ray.data.DataIterator>` that can
      be used to read disjoint subsets of the dataset in parallel.

      This method is the recommended way to consume :class:`Datasets <Dataset>` for
      distributed training.

      Streaming split works by delegating the execution of this :class:`Dataset` to a
      coordinator actor. The coordinator pulls block references from the executed
      stream, and divides those blocks among ``n`` output iterators. Iterators pull
      blocks from the coordinator actor to return to their caller on ``next``.

      The returned iterators are also repeatable; each iteration will trigger a
      new execution of the Dataset. There is an implicit barrier at the start of
      each iteration, which means that `next` must be called on all iterators before
      the iteration starts.

      .. warning::

          Because iterators are pulling blocks from the same :class:`Dataset`
          execution, if one iterator falls behind, other iterators may be stalled.

      Examples:

          .. testcode::

              import ray

              ds = ray.data.range(100)
              it1, it2 = ds.streaming_split(2, equal=True)

          Consume data from iterators in parallel.

          .. testcode::

              @ray.remote
              def consume(it):
                  for batch in it.iter_batches():
                      pass

              ray.get([consume.remote(it1), consume.remote(it2)])

          You can loop over the iterators multiple times (multiple epochs).

          .. testcode::

              @ray.remote
              def train(it):
                  NUM_EPOCHS = 2
                  for _ in range(NUM_EPOCHS):
                      for batch in it.iter_batches():
                          pass

              ray.get([train.remote(it1), train.remote(it2)])

          The following remote function call blocks waiting for a read on ``it2`` to
          start.

          .. testcode::
              :skipif: True

              ray.get(train.remote(it1))

      Args:
          n: Number of output iterators to return.
          equal: If ``True``, each output iterator sees an exactly equal number
              of rows, dropping data if necessary. If ``False``, some iterators may
              see slightly more or less rows than others, but no data is dropped.
          locality_hints: Specify the node ids corresponding to each iterator
              location. Dataset will try to minimize data movement based on the
              iterator output locations. This list must have length ``n``. You can
              get the current node id of a task or actor by calling
              ``ray.get_runtime_context().get_node_id()``.

      Returns:
          The output iterator splits. These iterators are Ray-serializable and can
          be freely passed to any Ray task or actor.

      .. seealso::

          :meth:`Dataset.split`
              Unlike :meth:`~Dataset.streaming_split`, :meth:`~Dataset.split`
              materializes the dataset in memory.
      """
      plan_copy = self._plan.copy()
      split_op = StreamingSplit(
          self._logical_plan.dag,
          num_splits=n,
          equal=equal,
          locality_hints=locality_hints,
      )
      coordinator_ds = Dataset(plan_copy, LogicalPlan(split_op, self.context))
      # The derived dataset carries the same UUID as this dataset.
      coordinator_ds._set_uuid(self._uuid)
      return StreamSplitDataIterator.create(coordinator_ds, n, locality_hints)
  1700. @ConsumptionAPI
  1701. @PublicAPI(api_group=SMJ_API_GROUP)
  def split(
      self, n: int, *, equal: bool = False, locality_hints: Optional[List[Any]] = None
  ) -> List["MaterializedDataset"]:
      """Materialize and split the dataset into ``n`` disjoint pieces.

      This method returns a list of ``MaterializedDataset`` that can be passed to Ray
      Tasks and Actors and used to read the dataset rows in parallel.

      Examples:

          .. testcode::

              @ray.remote
              class Worker:

                  def train(self, data_iterator):
                      for batch in data_iterator.iter_batches(batch_size=8):
                          pass

              workers = [Worker.remote() for _ in range(4)]
              shards = ray.data.range(100).split(n=4, equal=True)
              ray.get([w.train.remote(s) for w, s in zip(workers, shards)])

      Time complexity: O(1)

      Args:
          n: Number of child datasets to return.
          equal: Whether to guarantee each split has an equal
              number of records. This might drop records if the rows can't be
              divided equally among the splits.
          locality_hints: [Experimental] A list of Ray actor handles of size ``n``.
              The system tries to co-locate the blocks of the i-th dataset
              with the i-th actor to maximize data locality.

      Returns:
          A list of ``n`` disjoint dataset splits.

      Raises:
          ValueError: If ``n`` is not positive, or if ``locality_hints`` is given
              and its length differs from ``n``.

      .. seealso::

          :meth:`Dataset.split_at_indices`
              Unlike :meth:`~Dataset.split`, which splits a dataset into approximately
              equal splits, :meth:`Dataset.split_at_indices` lets you split a
              dataset into different sizes.

          :meth:`Dataset.split_proportionately`
              This method is equivalent to :meth:`Dataset.split_at_indices` if
              you compute indices manually.

          :meth:`Dataset.streaming_split`.
              Unlike :meth:`~Dataset.split`, :meth:`~Dataset.streaming_split`
              doesn't materialize the dataset in memory.
      """
      if n <= 0:
          raise ValueError(f"The number of splits {n} is not positive.")

      # Fall back to split_at_indices for an equal split without locality hints.
      # Simple benchmarks show split_at_indices yields more stable performance.
      # See https://github.com/ray-project/ray/pull/26641 for more context.
      if equal and locality_hints is None:
          count = self.count()
          split_index = count // n
          # We create n split indices, which generate n + 1 splits; the last
          # split contains at most (n - 1) rows and can safely be dropped.
          split_indices = [split_index * i for i in range(1, n + 1)]
          shards = self.split_at_indices(split_indices)
          return shards[:n]

      # Test against None rather than truthiness: an empty list must be rejected
      # here, otherwise it reaches the allocation code below and divides by zero.
      if locality_hints is not None and len(locality_hints) != n:
          raise ValueError(
              f"The length of locality_hints {len(locality_hints)} "
              f"doesn't equal the number of splits {n}."
          )

      bundle: RefBundle = self._plan.execute()
      # We should not free blocks since we will materialize the Datasets.
      owned_by_consumer = False
      stats = self._plan.stats()
      block_refs, metadata = zip(*bundle.blocks)

      if locality_hints is None:
          block_refs_splits = np.array_split(block_refs, n)
          metadata_splits = np.array_split(metadata, n)

          split_datasets = []
          for block_refs_split, metadata_split in zip(
              block_refs_splits, metadata_splits
          ):
              ref_bundles = [
                  RefBundle(
                      [(b, m)], owns_blocks=owned_by_consumer, schema=bundle.schema
                  )
                  for b, m in zip(block_refs_split, metadata_split)
              ]
              logical_plan = LogicalPlan(
                  InputData(input_data=ref_bundles),
                  self.context,
              )
              split_datasets.append(
                  MaterializedDataset(
                      ExecutionPlan(stats, self.context.copy()),
                      logical_plan,
                  )
              )
          return split_datasets

      metadata_mapping = dict(zip(block_refs, metadata))

      # If the locality_hints is set, we use a two-round greedy algorithm
      # to co-locate the blocks with the actors based on block
      # and actor's location (node_id).
      #
      # The split algorithm tries to allocate equally-sized blocks regardless
      # of locality. Thus we first calculate the expected number of blocks
      # for each split.
      #
      # In the first round, for each actor, we look for all blocks that
      # match the actor's node_id, then allocate those matched blocks to
      # this actor until we reach the limit (expected number).
      #
      # In the second round: fill each actor's allocation with
      # remaining unallocated blocks until we reach the limit.

      def build_allocation_size_map(
          num_blocks: int, actors: List[Any]
      ) -> Dict[Any, int]:
          """Given the total number of blocks and a list of actors, calculate
          the expected number of blocks to allocate for each actor.
          """
          num_actors = len(actors)
          num_blocks_per_actor = num_blocks // num_actors
          # Use the helper's own actor count instead of the enclosing `n`.
          num_blocks_left = num_blocks - num_blocks_per_actor * num_actors
          num_blocks_by_actor = {}
          for i, actor in enumerate(actors):
              num_blocks_by_actor[actor] = num_blocks_per_actor
              # The first `num_blocks_left` actors absorb the remainder.
              if i < num_blocks_left:
                  num_blocks_by_actor[actor] += 1
          return num_blocks_by_actor

      def build_block_refs_by_node_id(
          blocks: List[ObjectRef[Block]],
      ) -> Dict[str, List[ObjectRef[Block]]]:
          """Build the reverse index from node_id to block_refs. For
          simplicity, if the block is stored on multiple nodes we
          only pick the first one.
          """
          block_ref_locations = ray.experimental.get_object_locations(blocks)
          block_refs_by_node_id = collections.defaultdict(list)
          for block_ref in blocks:
              node_ids = block_ref_locations.get(block_ref, {}).get("node_ids", [])
              # Blocks without a known location are grouped under the key None.
              node_id = node_ids[0] if node_ids else None
              block_refs_by_node_id[node_id].append(block_ref)
          return block_refs_by_node_id

      def build_node_id_by_actor(actors: List[Any]) -> Dict[Any, str]:
          """Build a map from an actor to its node_id."""
          actors_state = ray._private.state.actors()
          return {
              actor: actors_state.get(actor._actor_id.hex(), {})
              .get("Address", {})
              .get("NodeID")
              for actor in actors
          }

      # Expected number of blocks to be allocated for each actor.
      expected_block_count_by_actor = build_allocation_size_map(
          len(block_refs), locality_hints
      )
      # The reverse index from node_id to block_refs.
      block_refs_by_node_id = build_block_refs_by_node_id(block_refs)
      # The map from actor to its node_id.
      node_id_by_actor = build_node_id_by_actor(locality_hints)

      allocation_per_actor = collections.defaultdict(list)

      # First round: give each actor the blocks located on its own node, up to
      # its expected count. Popping removes them from the per-node index, so
      # whatever is left there afterwards is unallocated.
      for actor in locality_hints:
          node_id = node_id_by_actor[actor]
          matching_blocks = block_refs_by_node_id[node_id]
          expected_block_count = expected_block_count_by_actor[actor]
          allocation = []
          while matching_blocks and len(allocation) < expected_block_count:
              allocation.append(matching_blocks.pop())
          allocation_per_actor[actor] = allocation

      # Second round: top up each actor's allocation with the remaining
      # unallocated blocks until it reaches its expected count.
      remaining_block_refs = list(
          itertools.chain.from_iterable(block_refs_by_node_id.values())
      )
      for actor in locality_hints:
          while (
              len(allocation_per_actor[actor]) < expected_block_count_by_actor[actor]
          ):
              allocation_per_actor[actor].append(remaining_block_refs.pop())

      assert len(remaining_block_refs) == 0, len(remaining_block_refs)

      # Distinct local names here: the executed `bundle` (and its schema) and
      # the `metadata` tuple are no longer rebound inside the loop.
      per_split_bundles = []
      for actor in locality_hints:
          actor_blocks = allocation_per_actor[actor]
          actor_metadata = [metadata_mapping[b] for b in actor_blocks]
          per_split_bundles.append(
              RefBundle(
                  tuple(zip(actor_blocks, actor_metadata)),
                  owns_blocks=owned_by_consumer,
                  schema=bundle.schema,
              )
          )

      if equal:
          # Equalize the splits.
          per_split_bundles = _equalize(per_split_bundles, owned_by_consumer)

      split_datasets = []
      for split_bundle in per_split_bundles:
          logical_plan = LogicalPlan(
              InputData(input_data=[split_bundle]), self.context
          )
          split_datasets.append(
              MaterializedDataset(
                  ExecutionPlan(stats, self.context.copy()),
                  logical_plan,
              )
          )
      return split_datasets
  1896. @ConsumptionAPI
  1897. @PublicAPI(api_group=SMJ_API_GROUP)
  def split_at_indices(self, indices: List[int]) -> List["MaterializedDataset"]:
      """Materialize and split the dataset at the given indices (like ``np.split``).

      Examples:
          >>> import ray
          >>> ds = ray.data.range(10)
          >>> d1, d2, d3 = ds.split_at_indices([2, 5])
          >>> d1.take_batch()
          {'id': array([0, 1])}
          >>> d2.take_batch()
          {'id': array([2, 3, 4])}
          >>> d3.take_batch()
          {'id': array([5, 6, 7, 8, 9])}

      Time complexity: O(num splits)

      Args:
          indices: Sorted, non-negative integers which indicate where the
              dataset is split. Any sequence is accepted; it is copied into a
              list before use. If an index exceeds the length of the dataset,
              an empty dataset is returned.

      Returns:
          The dataset splits.

      Raises:
          ValueError: If ``indices`` is empty, not sorted in ascending order,
              or starts with a negative value.

      .. seealso::

          :meth:`Dataset.split`
              Unlike :meth:`~Dataset.split_at_indices`, which lets you split a
              dataset into different sizes, :meth:`Dataset.split` splits a dataset
              into approximately equal splits.

          :meth:`Dataset.split_proportionately`
              This method is equivalent to :meth:`Dataset.split_at_indices` if
              you compute indices manually.

          :meth:`Dataset.streaming_split`.
              Unlike :meth:`~Dataset.split`, :meth:`~Dataset.streaming_split`
              doesn't materialize the dataset in memory.
      """
      # Normalize to a list first. `sorted()` always returns a list, so without
      # this a correctly sorted tuple compared unequal and was rejected as
      # "unsorted", and an ndarray made the comparison ambiguous.
      indices = list(indices)

      if len(indices) < 1:
          raise ValueError("indices must be at least of length 1")
      if sorted(indices) != indices:
          raise ValueError("indices must be sorted")
      # Only the first index needs checking because the list is sorted.
      # Note that 0 is accepted despite the wording of the message.
      if indices[0] < 0:
          raise ValueError("indices must be positive")

      start_time = time.perf_counter()
      bundle: RefBundle = self._plan.execute()
      blocks, metadata = _split_at_indices(
          bundle.blocks,
          indices,
          False,  # NOTE(review): presumably an ownership flag — confirm against _split_at_indices
      )
      split_duration = time.perf_counter() - start_time

      parent_stats = self._plan.stats()
      splits = []
      for bs, ms in zip(blocks, metadata):
          # Every split records the same total duration: execution plus the
          # index-based splitting measured above.
          stats = DatasetStats(metadata={"Split": ms}, parent=parent_stats)
          stats.time_total_s = split_duration
          ref_bundles = [
              RefBundle([(b, m)], owns_blocks=False, schema=bundle.schema)
              for b, m in zip(bs, ms)
          ]
          logical_plan = LogicalPlan(
              InputData(input_data=ref_bundles),
              self.context,
          )
          splits.append(
              MaterializedDataset(
                  ExecutionPlan(stats, self.context.copy()),
                  logical_plan,
              )
          )
      return splits
  1963. @ConsumptionAPI
  1964. @PublicAPI(api_group=SMJ_API_GROUP)
  def split_proportionately(
      self, proportions: List[float]
  ) -> List["MaterializedDataset"]:
      """Materialize and split the dataset using proportions.

      A common use case for this is splitting the dataset into train
      and test sets (equivalent to eg. scikit-learn's ``train_test_split``).
      For a higher level abstraction, see :meth:`Dataset.train_test_split`.

      This method splits datasets so that all splits
      always contains at least one row. If that isn't possible,
      an exception is raised.

      This is equivalent to calculating the indices manually and calling
      :meth:`Dataset.split_at_indices`.

      Examples:
          >>> import ray
          >>> ds = ray.data.range(10)
          >>> d1, d2, d3 = ds.split_proportionately([0.2, 0.5])
          >>> d1.take_batch()
          {'id': array([0, 1])}
          >>> d2.take_batch()
          {'id': array([2, 3, 4, 5, 6])}
          >>> d3.take_batch()
          {'id': array([7, 8, 9])}

      Time complexity: O(num splits)

      Args:
          proportions: List of proportions to split the dataset according to.
              Must sum up to less than 1, and each proportion must be bigger
              than 0.

      Returns:
          The dataset splits.

      .. seealso::

          :meth:`Dataset.split`
              Unlike :meth:`~Dataset.split_proportionately`, which lets you split a
              dataset into different sizes, :meth:`Dataset.split` splits a dataset
              into approximately equal splits.

          :meth:`Dataset.split_at_indices`
              :meth:`Dataset.split_proportionately` uses this method under the hood.

          :meth:`Dataset.streaming_split`.
              Unlike :meth:`~Dataset.split`, :meth:`~Dataset.streaming_split`
              doesn't materialize the dataset in memory.
      """
      if len(proportions) < 1:
          raise ValueError("proportions must be at least of length 1")
      if sum(proportions) >= 1:
          raise ValueError("proportions must sum to less than 1")
      if any(p <= 0 for p in proportions):
          raise ValueError("proportions must be bigger than 0")

      total_rows = self.count()
      # Running totals of the proportions, scaled to row offsets and truncated.
      cut_points = [int(total_rows * share) for share in np.cumsum(proportions)]

      # Ensure each split has at least one element. Walk from the
      # second-to-last cut point towards the first; each time a cut point
      # collides with its right neighbour, move it (and, through `shift`,
      # every cut point to its left) one row further left.
      shift = 0
      for pos in reversed(range(len(cut_points) - 1)):
          cut_points[pos] -= shift
          if cut_points[pos] == cut_points[pos + 1]:
              shift += 1
              cut_points[pos] -= 1

      if any(point <= 0 for point in cut_points):
          raise ValueError(
              "Couldn't create non-empty splits with the given proportions."
          )

      return self.split_at_indices(cut_points)
  2028. @ConsumptionAPI
  2029. @PublicAPI(api_group=SMJ_API_GROUP)
  def train_test_split(
      self,
      test_size: Union[int, float],
      *,
      shuffle: bool = False,
      seed: Optional[int] = None,
      stratify: Optional[str] = None,
  ) -> Tuple["MaterializedDataset", "MaterializedDataset"]:
      """Materialize and split the dataset into train and test subsets.

      Examples:
          >>> import ray
          >>> ds = ray.data.range(8)
          >>> train, test = ds.train_test_split(test_size=0.25)
          >>> train.take_batch()
          {'id': array([0, 1, 2, 3, 4, 5])}
          >>> test.take_batch()
          {'id': array([6, 7])}

      Args:
          test_size: If float, should be between 0.0 and 1.0 and represent the
              proportion of the dataset to include in the test split. If int,
              represents the absolute number of test samples. The train split
              always complements the test split.
          shuffle: Whether or not to globally shuffle the dataset before splitting.
              Defaults to ``False``. This may be a very expensive operation with a
              large dataset.
          seed: Fix the random seed to use for shuffle, otherwise one is chosen
              based on system randomness. Ignored if ``shuffle=False``.
          stratify: Optional column name to use for stratified sampling. If provided,
              the splits will maintain the same proportions of each class in the
              stratify column across both train and test sets.

      Returns:
          Train and test subsets as two ``MaterializedDatasets``.

      Raises:
          TypeError: If ``test_size`` is neither an int nor a float.
          ValueError: If both ``shuffle=True`` and ``stratify`` are given, or if
              ``test_size`` is out of range for its type.

      .. seealso::

          :meth:`Dataset.split_proportionately`
      """
      # Validate the arguments before building any shuffle operator.
      if not isinstance(test_size, (int, float)):
          raise TypeError(f"`test_size` must be int or float got {type(test_size)}.")

      # Validate that shuffle=True and stratify are not both specified.
      if shuffle and stratify is not None:
          raise ValueError(
              "Cannot specify both 'shuffle=True' and 'stratify' parameters. "
              "Stratified splitting maintains class proportions and is incompatible with shuffling."
          )

      ds = self
      if shuffle:
          ds = ds.random_shuffle(seed=seed)

      # Handle stratified splitting.
      if stratify is not None:
          return self._stratified_train_test_split(ds, test_size, stratify)

      # Handle non-stratified splitting.
      if isinstance(test_size, float):
          self._validate_test_size_float(test_size)
          return ds.split_proportionately([1 - test_size])

      # The validator already counted the rows and returns that length, so reuse
      # it instead of calling `ds.count()` a second time.
      ds_length = self._validate_test_size_int(test_size, ds)
      return ds.split_at_indices([ds_length - test_size])
  def _stratified_train_test_split(
      self, ds: "Dataset", test_size: Union[int, float], stratify: str
  ) -> Tuple["MaterializedDataset", "MaterializedDataset"]:
      """Perform stratified train-test split on the dataset.

      Rows are grouped by the ``stratify`` column. Within each group the
      trailing ``int(group_size * test_fraction)`` rows are flagged as test
      rows and the rest as train rows; the flag column is dropped again from
      both outputs.

      Args:
          ds: The dataset to split.
          test_size: Test size as int or float.
          stratify: Column name to use for stratified sampling.

      Returns:
          Train and test subsets as two MaterializedDatasets.
      """
      # Normalize test_size to a fraction. Only the int form needs the dataset
      # length, which the validator returns.
      if not isinstance(test_size, int):
          self._validate_test_size_float(test_size)
      else:
          total_rows = self._validate_test_size_int(test_size, ds)
          test_size = test_size / total_rows

      def add_train_flag(group_batch):
          group_len = len(group_batch)
          num_test = int(group_len * test_size)
          num_train = group_len - num_test
          # Leading entries are train (True), trailing entries are test (False).
          group_batch[_TRAIN_TEST_SPLIT_COLUMN] = np.array(
              [position < num_train for position in range(group_len)]
          )
          return group_batch

      flagged = ds.groupby(stratify).map_groups(add_train_flag).materialize()

      train_ds = flagged.filter(lambda row: row[_TRAIN_TEST_SPLIT_COLUMN])
      train_ds = train_ds.drop_columns([_TRAIN_TEST_SPLIT_COLUMN])

      test_ds = flagged.filter(lambda row: not row[_TRAIN_TEST_SPLIT_COLUMN])
      test_ds = test_ds.drop_columns([_TRAIN_TEST_SPLIT_COLUMN])

      return train_ds, test_ds
  2119. def _validate_test_size_float(self, test_size: float) -> None:
  2120. """Validate test_size when it's a float.
  2121. Args:
  2122. test_size: Test size as float between 0 and 1.
  2123. Raises:
  2124. ValueError: If test_size is not in valid range.
  2125. """
  2126. if test_size <= 0 or test_size >= 1:
  2127. raise ValueError(
  2128. "If `test_size` is a float, it must be bigger than 0 and smaller "
  2129. f"than 1. Got {test_size}."
  2130. )
  2131. def _validate_test_size_int(self, test_size: int, ds: "Dataset") -> int:
  2132. """Validate test_size when it's an int and return dataset length.
  2133. Args:
  2134. test_size: Test size as int.
  2135. ds: Dataset to validate against.
  2136. Returns:
  2137. Dataset length for reuse.
  2138. Raises:
  2139. ValueError: If test_size is not in valid range.
  2140. """
  2141. ds_length = ds.count()
  2142. if test_size <= 0 or test_size >= ds_length:
  2143. raise ValueError(
  2144. "If `test_size` is an int, it must be bigger than 0 and smaller "
  2145. f"than the size of the dataset ({ds_length}). "
  2146. f"Got {test_size}."
  2147. )
  2148. return ds_length
  2149. @PublicAPI(stability="alpha", api_group=SMJ_API_GROUP)
  def streaming_train_test_split(
      self,
      test_size: float,
      *,
      split_type: Literal["hash", "random"] = "random",
      hash_column: Optional[str] = None,
      seed: Optional[int] = None,
      **ray_remote_kwargs,
  ) -> Tuple["Dataset", "Dataset"]:
      """Split the dataset into train and test subsets in a streaming manner.

      This method is recommended for large datasets.

      The split type can be either "hash" or "random".

      - "random": The dataset is split into random train and test subsets.
      - "hash": The dataset is split into train and test subsets based on the hash of the key column.

      .. tip::

          Make sure to set the `preserve_order` flag in the `ExecutionOptions` to True
          to ensure that the split is deterministic across pipeline executions. This is important
          to avoid test rows to end up in the train set and vice versa on multiple executions.
          This can be set with ``ray.data.DataContext.get_current().execution_options.preserve_order = True``.

      Examples:
          Examples with Random split:

          >>> import ray
          >>> ctx = ray.data.DataContext.get_current()
          >>> ctx.execution_options.preserve_order = True
          >>> ds = ray.data.range(8)
          >>> train, test = ds.streaming_train_test_split(test_size=0.25, seed=0)
          >>> train.count()
          6
          >>> test.count()
          2
          >>> ctx.execution_options.preserve_order = False

          Examples with Hash split:

          >>> import ray
          >>> ds = ray.data.range(8)
          >>> train, test = ds.streaming_train_test_split(test_size=0.25, split_type="hash", hash_column="id")
          >>> train.take_batch()
          {'id': array([0, 2, 3, 4, 5, 6])}
          >>> test.take_batch()
          {'id': array([1, 7])}

      Args:
          test_size: The proportion of the dataset to include in the test split.
              Must be strictly between 0.0 and 1.0.
          split_type: The type of split to perform. Can be "hash" or "random".
          hash_column: The column to use for the hash split. Required for hash split
              and rejected for random split.
          seed: The seed to use for the random split. Rejected for hash split.
          **ray_remote_kwargs: Additional kwargs forwarded to ``map_batches``.

      Returns:
          Train and test subsets as two ``Dataset``.

      Raises:
          ValueError: If ``test_size`` is not strictly between 0 and 1 (NaN
              included), if ``seed`` is combined with a hash split, if
              ``hash_column`` is combined with a random split or missing for a
              hash split, or if ``split_type`` is unknown.

      .. seealso::

          :meth:`Dataset.train_test_split`
      """
      # Validate every argument up front, before the heavier imports below.
      # The chained comparison is False for NaN, so NaN is rejected as well.
      if not 0 < test_size < 1:
          raise ValueError("test_size must be between 0 and 1.")
      if seed is not None and split_type == "hash":
          raise ValueError("seed is not supported for hash split")
      if hash_column is not None and split_type == "random":
          raise ValueError("hash_column is not supported for random split")
      if split_type == "hash" and hash_column is None:
          raise ValueError("hash_column is required for hash split")
      if split_type not in ("hash", "random"):
          raise ValueError(f"Invalid split type: {split_type}")

      import hashlib

      import pyarrow as pa

      from ray.data._internal.execution.interfaces.task_context import TaskContext

      def random_split(batch: pa.Table):
          """Flag each row as train with probability ``1 - test_size``.

          The generator is created once per task context and cached there. It
          is seeded from the task index, plus the user-supplied seed when one
          is given, so the draws depend only on those values.
          """
          ctx = TaskContext.get_current()

          if "train_test_split_rng" in ctx.kwargs:
              rng = ctx.kwargs["train_test_split_rng"]
          else:
              entropy = [ctx.task_idx] if seed is None else [ctx.task_idx, seed]
              rng = np.random.default_rng(entropy)
              ctx.kwargs["train_test_split_rng"] = rng

          # Draw Bernoulli samples: True = train, False = test.
          is_train = rng.random(batch.num_rows) < (1 - test_size)
          return batch.append_column(
              _TRAIN_TEST_SPLIT_COLUMN, pa.array(is_train, type=pa.bool_())
          )

      def hash_split(batch: pa.Table) -> pa.Table:
          """Flag each row as train or test from the hash of its key column."""

          def key_to_bucket(key: Any) -> bool:
              # Interpret the 8-byte digest as a 64-bit integer in [0, 2^64);
              # keys hashing below the train share of that range are train.
              h = int.from_bytes(
                  hashlib.blake2b(str(key).encode(), digest_size=8).digest(), "big"
              )
              return h < (1 - test_size) * (1 << 64)

          if hash_column not in batch.column_names:
              raise ValueError(f"Key column {hash_column} not found in batch")

          keys = batch[hash_column].to_numpy()
          bucket_arr = pa.array([key_to_bucket(key) for key in keys], type=pa.bool_())
          return batch.append_column(_TRAIN_TEST_SPLIT_COLUMN, bucket_arr)

      # `split_type` is known to be "random" or "hash" at this point.
      split_fn = random_split if split_type == "random" else hash_split
      bucketted = self.map_batches(
          split_fn,
          batch_format="pyarrow",
          **ray_remote_kwargs,
      )

      ds_train = bucketted.filter(
          expr=f"{_TRAIN_TEST_SPLIT_COLUMN} == True"
      ).drop_columns([_TRAIN_TEST_SPLIT_COLUMN])
      ds_test = bucketted.filter(
          expr=f"{_TRAIN_TEST_SPLIT_COLUMN} == False"
      ).drop_columns([_TRAIN_TEST_SPLIT_COLUMN])

      return ds_train, ds_test
  2269. @PublicAPI(api_group=SMJ_API_GROUP)
  2270. def union(self, *other: "Dataset") -> "Dataset":
  2271. """Concatenate :class:`Datasets <ray.data.Dataset>` across rows.
  2272. The order of the blocks in the datasets is preserved, as is the
  2273. relative ordering between the datasets passed in the argument list.
  2274. .. caution::
  2275. Unioned datasets aren't lineage-serializable. As a result, they can't be
  2276. used as a tunable hyperparameter in Ray Tune.
  2277. Examples:
  2278. >>> import ray
  2279. >>> ds1 = ray.data.range(2)
  2280. >>> ds2 = ray.data.range(3)
  2281. >>> ds1.union(ds2).take_all() # doctest: +SKIP
  2282. [{'id': 0}, {'id': 1}, {'id': 0}, {'id': 1}, {'id': 2}]
  2283. Args:
  2284. *other: The datasets to combine with this one. The datasets
  2285. must have the same schema as this dataset, otherwise the
  2286. behavior is undefined.
  2287. Returns:
  2288. A new dataset holding the rows of the input datasets.
  2289. """
  2290. start_time = time.perf_counter()
  2291. datasets = [self] + list(other)
  2292. logical_plans = [union_ds._plan._logical_plan for union_ds in datasets]
  2293. op = UnionLogicalOperator(
  2294. *[plan.dag for plan in logical_plans],
  2295. )
  2296. logical_plan = LogicalPlan(op, self.context)
  2297. stats = DatasetStats(
  2298. metadata={"Union": []},
  2299. parent=[d._plan.stats() for d in datasets],
  2300. )
  2301. stats.time_total_s = time.perf_counter() - start_time
  2302. return Dataset(
  2303. ExecutionPlan(stats, self.context.copy()),
  2304. logical_plan,
  2305. )
  2306. @AllToAllAPI
  2307. @PublicAPI(api_group=SMJ_API_GROUP)
  2308. def join(
  2309. self,
  2310. ds: "Dataset",
  2311. join_type: str,
  2312. num_partitions: int,
  2313. on: Tuple[str] = ("id",),
  2314. right_on: Optional[Tuple[str]] = None,
  2315. left_suffix: Optional[str] = None,
  2316. right_suffix: Optional[str] = None,
  2317. *,
  2318. partition_size_hint: Optional[int] = None,
  2319. aggregator_ray_remote_args: Optional[Dict[str, Any]] = None,
  2320. validate_schemas: bool = False,
  2321. ) -> "Dataset":
  2322. """Join :class:`Datasets <ray.data.Dataset>` on join keys.
  2323. Args:
  2324. ds: Other dataset to join against
  2325. join_type: The kind of join that should be performed, one of ("inner",
  2326. "left_outer", "right_outer", "full_outer", "left_semi", "right_semi",
  2327. "left_anti", "right_anti").
  2328. num_partitions: Total number of "partitions" input sequences will be split
  2329. into with each partition being joined independently. Increasing number
  2330. of partitions allows to reduce individual partition size, hence reducing
  2331. memory requirements when individual partitions are being joined. Note
  2332. that, consequently, this will also be a total number of blocks that will
  2333. be produced as a result of executing join.
  2334. on: The columns from the left operand that will be used as
  2335. keys for the join operation.
  2336. right_on: The columns from the right operand that will be
  2337. used as keys for the join operation. When none, `on` will
  2338. be assumed to be a list of columns to be used for the right dataset
  2339. as well.
  2340. left_suffix: (Optional) Suffix to be appended for columns of the left
  2341. operand.
  2342. right_suffix: (Optional) Suffix to be appended for columns of the right
  2343. operand.
  2344. partition_size_hint: (Optional) Hint to joining operator about the estimated
  2345. avg expected size of the individual partition (in bytes).
  2346. This is used in estimating the total dataset size and allow to tune
  2347. memory requirement of the individual joining workers to prevent OOMs
  2348. when joining very large datasets.
  2349. aggregator_ray_remote_args: (Optional) Parameter overriding `ray.remote`
  2350. args passed when constructing joining (aggregator) workers.
  2351. validate_schemas: (Optional) Controls whether validation of provided
  2352. configuration against input schemas will be performed (defaults to
  2353. false, since obtaining schemas could be prohibitively expensive).
  2354. Returns:
  2355. A :class:`Dataset` that holds rows of input left Dataset joined with the
  2356. right Dataset based on join type and keys.
  2357. Examples:
  2358. .. testcode::
  2359. :skipif: True
  2360. doubles_ds = ray.data.range(4).map(
  2361. lambda row: {"id": row["id"], "double": int(row["id"]) * 2}
  2362. )
  2363. squares_ds = ray.data.range(4).map(
  2364. lambda row: {"id": row["id"], "square": int(row["id"]) ** 2}
  2365. )
  2366. # Inner join example
  2367. joined_ds = doubles_ds.join(
  2368. squares_ds,
  2369. join_type="inner",
  2370. num_partitions=2,
  2371. on=("id",),
  2372. )
  2373. print(sorted(joined_ds.take_all(), key=lambda item: item["id"]))
  2374. .. testoutput::
  2375. :options: +ELLIPSIS, +NORMALIZE_WHITESPACE
  2376. [
  2377. {'id': 0, 'double': 0, 'square': 0},
  2378. {'id': 1, 'double': 2, 'square': 1},
  2379. {'id': 2, 'double': 4, 'square': 4},
  2380. {'id': 3, 'double': 6, 'square': 9}
  2381. ]
  2382. .. testcode::
  2383. :skipif: True
  2384. # Left anti-join example: find rows in doubles_ds that don't match squares_ds
  2385. partial_squares_ds = ray.data.range(2).map(
  2386. lambda row: {"id": row["id"] + 2, "square": int(row["id"]) ** 2}
  2387. )
  2388. anti_joined_ds = doubles_ds.join(
  2389. partial_squares_ds,
  2390. join_type="left_anti",
  2391. num_partitions=2,
  2392. on=("id",),
  2393. )
  2394. print(sorted(anti_joined_ds.take_all(), key=lambda item: item["id"]))
  2395. .. testoutput::
  2396. :options: +ELLIPSIS, +NORMALIZE_WHITESPACE
  2397. [
  2398. {'id': 0, 'double': 0},
  2399. {'id': 1, 'double': 2}
  2400. ]
  2401. .. testcode::
  2402. :skipif: True
  2403. # Left semi-join example: find rows in doubles_ds that have matches in squares_ds
  2404. # (only returns columns from left dataset)
  2405. semi_joined_ds = doubles_ds.join(
  2406. squares_ds,
  2407. join_type="left_semi",
  2408. num_partitions=2,
  2409. on=("id",),
  2410. )
  2411. print(sorted(semi_joined_ds.take_all(), key=lambda item: item["id"]))
  2412. .. testoutput::
  2413. :options: +ELLIPSIS, +NORMALIZE_WHITESPACE
  2414. [
  2415. {'id': 0, 'double': 0},
  2416. {'id': 1, 'double': 2},
  2417. {'id': 2, 'double': 4},
  2418. {'id': 3, 'double': 6}
  2419. ]
  2420. """
  2421. if not isinstance(on, (tuple, list)):
  2422. raise ValueError(
  2423. f"Expected tuple or list as `on` (got {type(on).__name__})"
  2424. )
  2425. if right_on and not isinstance(right_on, (tuple, list)):
  2426. raise ValueError(
  2427. f"Expected tuple or list as `right_on` (got {type(right_on).__name__})"
  2428. )
  2429. # NOTE: If no separate keys provided for the right side, assume just the left
  2430. # side ones
  2431. right_on = right_on or on
  2432. # NOTE: By default validating schemas are disabled as it could be arbitrarily
  2433. # expensive (potentially executing whole pipeline to completion) to fetch
  2434. # one currently
  2435. if validate_schemas:
  2436. left_op_schema: Optional["Schema"] = self.schema()
  2437. right_op_schema: Optional["Schema"] = ds.schema()
  2438. Join._validate_schemas(left_op_schema, right_op_schema, on, right_on)
  2439. plan = self._plan.copy()
  2440. op = Join(
  2441. left_input_op=self._logical_plan.dag,
  2442. right_input_op=ds._logical_plan.dag,
  2443. left_key_columns=on,
  2444. right_key_columns=right_on,
  2445. join_type=join_type,
  2446. num_partitions=num_partitions,
  2447. left_columns_suffix=left_suffix,
  2448. right_columns_suffix=right_suffix,
  2449. partition_size_hint=partition_size_hint,
  2450. aggregator_ray_remote_args=aggregator_ray_remote_args,
  2451. )
  2452. return Dataset(plan, LogicalPlan(op, self.context))
  2453. @AllToAllAPI
  2454. @PublicAPI(api_group=GGA_API_GROUP)
  2455. def groupby(
  2456. self,
  2457. key: Union[str, List[str], None],
  2458. num_partitions: Optional[int] = None,
  2459. ) -> "GroupedData":
  2460. """Group rows of a :class:`Dataset` according to a column.
  2461. Use this method to transform data based on a
  2462. categorical variable.
  2463. Examples:
  2464. .. testcode::
  2465. import pandas as pd
  2466. import ray
  2467. def normalize_variety(group: pd.DataFrame) -> pd.DataFrame:
  2468. for feature in group.drop(columns=["variety"]).columns:
  2469. group[feature] = group[feature] / group[feature].abs().max()
  2470. return group
  2471. ds = (
  2472. ray.data.read_parquet("s3://anonymous@ray-example-data/iris.parquet")
  2473. .groupby("variety")
  2474. .map_groups(normalize_variety, batch_format="pandas")
  2475. )
  2476. Time complexity: O(dataset size * log(dataset size / parallelism))
  2477. Args:
  2478. key: A column name or list of column names.
  2479. If this is ``None``, place all rows in a single group.
  2480. num_partitions: Number of partitions data will be partitioned into (only
  2481. relevant if hash-shuffling strategy is used). When not set defaults
  2482. to `DataContext.min_parallelism`.
  2483. Returns:
  2484. A lazy :class:`~ray.data.grouped_data.GroupedData`.
  2485. .. seealso::
  2486. :meth:`~ray.data.grouped_data.GroupedData.map_groups`
  2487. Call this method to transform groups of data.
  2488. """
  2489. from ray.data.grouped_data import GroupedData
  2490. # Always allow None since groupby interprets that as grouping all
  2491. # records into a single global group.
  2492. if key is not None:
  2493. # Fetching the schema can trigger execution, so don't fetch it for
  2494. # input validation.
  2495. SortKey(key).validate_schema(self.schema(fetch_if_missing=False))
  2496. if num_partitions is not None and num_partitions <= 0:
  2497. raise ValueError("`num_partitions` must be a positive integer")
  2498. return GroupedData(self, key, num_partitions=num_partitions)
  2499. @AllToAllAPI
  2500. @ConsumptionAPI
  2501. @PublicAPI(api_group=GGA_API_GROUP)
  2502. def unique(self, column: str, ignore_nulls: bool = False) -> List[Any]:
  2503. """List the unique elements in a given column.
  2504. Examples:
  2505. >>> import ray
  2506. >>> ds = ray.data.from_items([1, 2, 3, 2, 3])
  2507. >>> sorted(ds.unique("item"))
  2508. [1, 2, 3]
  2509. This function is very useful for computing labels
  2510. in a machine learning dataset:
  2511. >>> import ray
  2512. >>> ds = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
  2513. >>> sorted(ds.unique("target"))
  2514. [0, 1, 2]
  2515. One common use case is to convert the class labels
  2516. into integers for training and inference:
  2517. >>> classes = {0: 'Setosa', 1: 'Versicolor', 2: 'Virginica'}
  2518. >>> def preprocessor(df, classes):
  2519. ... df["variety"] = df["target"].map(classes)
  2520. ... return df
  2521. >>> train_ds = ds.map_batches(
  2522. ... preprocessor, fn_kwargs={"classes": classes}, batch_format="pandas")
  2523. >>> train_ds.sort("sepal length (cm)").take(1) # Sort to make it deterministic
  2524. [{'sepal length (cm)': 4.3, ..., 'variety': 'Setosa'}]
  2525. Time complexity: O(dataset size / parallelism)
  2526. Args:
  2527. column: The column to collect unique elements over.
  2528. ignore_nulls: If ``True``, ignore null values in the column.
  2529. Returns:
  2530. A list with unique elements in the given column.
  2531. """ # noqa: E501
  2532. ret = self._aggregate_on(Unique, column, ignore_nulls=ignore_nulls)
  2533. return self._aggregate_result(ret)
  2534. @AllToAllAPI
  2535. @ConsumptionAPI
  2536. @PublicAPI(api_group=GGA_API_GROUP)
  2537. def aggregate(self, *aggs: AggregateFn) -> Union[Any, Dict[str, Any]]:
  2538. """Aggregate values using one or more functions.
  2539. Use this method to compute metrics like the product of a column.
  2540. Examples:
  2541. .. testcode::
  2542. import ray
  2543. from ray.data.aggregate import AggregateFn
  2544. ds = ray.data.from_items([{"number": i} for i in range(1, 10)])
  2545. aggregation = AggregateFn(
  2546. init=lambda column: 1,
  2547. # Apply this to each row to produce a partial aggregate result
  2548. accumulate_row=lambda a, row: a * row["number"],
  2549. # Apply this to merge partial aggregate results into a final result
  2550. merge=lambda a1, a2: a1 * a2,
  2551. name="prod"
  2552. )
  2553. print(ds.aggregate(aggregation))
  2554. .. testoutput::
  2555. {'prod': 362880}
  2556. Time complexity: O(dataset size / parallelism)
  2557. Args:
  2558. *aggs: :class:`Aggregations <ray.data.aggregate.AggregateFn>` to perform.
  2559. Returns:
  2560. A ``dict`` where each each value is an aggregation for a given column.
  2561. """
  2562. ret = self.groupby(None).aggregate(*aggs).take(1)
  2563. return ret[0] if len(ret) > 0 else None
  2564. @AllToAllAPI
  2565. @ConsumptionAPI
  2566. @PublicAPI(api_group=GGA_API_GROUP)
  2567. def sum(
  2568. self, on: Optional[Union[str, List[str]]] = None, ignore_nulls: bool = True
  2569. ) -> Union[Any, Dict[str, Any]]:
  2570. """Compute the sum of one or more columns.
  2571. Examples:
  2572. >>> import ray
  2573. >>> ray.data.range(100).sum("id")
  2574. 4950
  2575. >>> ray.data.from_items([
  2576. ... {"A": i, "B": i**2}
  2577. ... for i in range(100)
  2578. ... ]).sum(["A", "B"])
  2579. {'sum(A)': 4950, 'sum(B)': 328350}
  2580. Args:
  2581. on: a column name or a list of column names to aggregate.
  2582. ignore_nulls: Whether to ignore null values. If ``True``, null
  2583. values are ignored when computing the sum. If ``False``,
  2584. when a null value is encountered, the output is ``None``.
  2585. Ray Data considers ``np.nan``, ``None``, and ``pd.NaT`` to be null
  2586. values. Default is ``True``.
  2587. Returns:
  2588. The sum result.
  2589. For different values of ``on``, the return varies:
  2590. - ``on=None``: a dict containing the column-wise sum of all
  2591. columns,
  2592. - ``on="col"``: a scalar representing the sum of all items in
  2593. column ``"col"``,
  2594. - ``on=["col_1", ..., "col_n"]``: an n-column ``dict``
  2595. containing the column-wise sum of the provided columns.
  2596. If the dataset is empty, all values are null. If ``ignore_nulls`` is
  2597. ``False`` and any value is null, then the output is ``None``.
  2598. """
  2599. ret = self._aggregate_on(Sum, on, ignore_nulls=ignore_nulls)
  2600. return self._aggregate_result(ret)
  2601. @AllToAllAPI
  2602. @ConsumptionAPI
  2603. @PublicAPI(api_group=GGA_API_GROUP)
  2604. def min(
  2605. self, on: Optional[Union[str, List[str]]] = None, ignore_nulls: bool = True
  2606. ) -> Union[Any, Dict[str, Any]]:
  2607. """Return the minimum of one or more columns.
  2608. Examples:
  2609. >>> import ray
  2610. >>> ray.data.range(100).min("id")
  2611. 0
  2612. >>> ray.data.from_items([
  2613. ... {"A": i, "B": i**2}
  2614. ... for i in range(100)
  2615. ... ]).min(["A", "B"])
  2616. {'min(A)': 0, 'min(B)': 0}
  2617. Args:
  2618. on: a column name or a list of column names to aggregate.
  2619. ignore_nulls: Whether to ignore null values. If ``True``, null
  2620. values are ignored when computing the min; if ``False``,
  2621. when a null value is encountered, the output is ``None``.
  2622. This method considers ``np.nan``, ``None``, and ``pd.NaT`` to be null
  2623. values. Default is ``True``.
  2624. Returns:
  2625. The min result.
  2626. For different values of ``on``, the return varies:
  2627. - ``on=None``: an dict containing the column-wise min of
  2628. all columns,
  2629. - ``on="col"``: a scalar representing the min of all items in
  2630. column ``"col"``,
  2631. - ``on=["col_1", ..., "col_n"]``: an n-column dict
  2632. containing the column-wise min of the provided columns.
  2633. If the dataset is empty, all values are null. If ``ignore_nulls`` is
  2634. ``False`` and any value is null, then the output is ``None``.
  2635. """
  2636. ret = self._aggregate_on(Min, on, ignore_nulls=ignore_nulls)
  2637. return self._aggregate_result(ret)
  2638. @AllToAllAPI
  2639. @ConsumptionAPI
  2640. @PublicAPI(api_group=GGA_API_GROUP)
  2641. def max(
  2642. self, on: Optional[Union[str, List[str]]] = None, ignore_nulls: bool = True
  2643. ) -> Union[Any, Dict[str, Any]]:
  2644. """Return the maximum of one or more columns.
  2645. Examples:
  2646. >>> import ray
  2647. >>> ray.data.range(100).max("id")
  2648. 99
  2649. >>> ray.data.from_items([
  2650. ... {"A": i, "B": i**2}
  2651. ... for i in range(100)
  2652. ... ]).max(["A", "B"])
  2653. {'max(A)': 99, 'max(B)': 9801}
  2654. Args:
  2655. on: a column name or a list of column names to aggregate.
  2656. ignore_nulls: Whether to ignore null values. If ``True``, null
  2657. values are ignored when computing the max; if ``False``,
  2658. when a null value is encountered, the output is ``None``.
  2659. This method considers ``np.nan``, ``None``, and ``pd.NaT`` to be null
  2660. values. Default is ``True``.
  2661. Returns:
  2662. The max result.
  2663. For different values of ``on``, the return varies:
  2664. - ``on=None``: an dict containing the column-wise max of
  2665. all columns,
  2666. - ``on="col"``: a scalar representing the max of all items in
  2667. column ``"col"``,
  2668. - ``on=["col_1", ..., "col_n"]``: an n-column dict
  2669. containing the column-wise max of the provided columns.
  2670. If the dataset is empty, all values are null. If ``ignore_nulls`` is
  2671. ``False`` and any value is null, then the output is ``None``.
  2672. """
  2673. ret = self._aggregate_on(Max, on, ignore_nulls=ignore_nulls)
  2674. return self._aggregate_result(ret)
  2675. @AllToAllAPI
  2676. @ConsumptionAPI
  2677. @PublicAPI(api_group=GGA_API_GROUP)
  2678. def mean(
  2679. self, on: Optional[Union[str, List[str]]] = None, ignore_nulls: bool = True
  2680. ) -> Union[Any, Dict[str, Any]]:
  2681. """Compute the mean of one or more columns.
  2682. Examples:
  2683. >>> import ray
  2684. >>> ray.data.range(100).mean("id")
  2685. 49.5
  2686. >>> ray.data.from_items([
  2687. ... {"A": i, "B": i**2}
  2688. ... for i in range(100)
  2689. ... ]).mean(["A", "B"])
  2690. {'mean(A)': 49.5, 'mean(B)': 3283.5}
  2691. Args:
  2692. on: a column name or a list of column names to aggregate.
  2693. ignore_nulls: Whether to ignore null values. If ``True``, null
  2694. values are ignored when computing the mean; if ``False``,
  2695. when a null value is encountered, the output is ``None``.
  2696. This method considers ``np.nan``, ``None``, and ``pd.NaT`` to be null
  2697. values. Default is ``True``.
  2698. Returns:
  2699. The mean result.
  2700. For different values of ``on``, the return varies:
  2701. - ``on=None``: an dict containing the column-wise mean of
  2702. all columns,
  2703. - ``on="col"``: a scalar representing the mean of all items in
  2704. column ``"col"``,
  2705. - ``on=["col_1", ..., "col_n"]``: an n-column dict
  2706. containing the column-wise mean of the provided columns.
  2707. If the dataset is empty, all values are null. If ``ignore_nulls`` is
  2708. ``False`` and any value is null, then the output is ``None``.
  2709. """
  2710. ret = self._aggregate_on(Mean, on, ignore_nulls=ignore_nulls)
  2711. return self._aggregate_result(ret)
  2712. @AllToAllAPI
  2713. @ConsumptionAPI
  2714. @PublicAPI(api_group=GGA_API_GROUP)
  2715. def std(
  2716. self,
  2717. on: Optional[Union[str, List[str]]] = None,
  2718. ddof: int = 1,
  2719. ignore_nulls: bool = True,
  2720. ) -> Union[Any, Dict[str, Any]]:
  2721. """Compute the standard deviation of one or more columns.
  2722. .. note::
  2723. This method uses Welford's online method for an accumulator-style
  2724. computation of the standard deviation. This method has
  2725. numerical stability, and is computable in a single pass. This may give
  2726. different (but more accurate) results than NumPy, Pandas, and sklearn, which
  2727. use a less numerically stable two-pass algorithm.
  2728. To learn more, see
  2729. `the Wikapedia article <https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm>`_.
  2730. Examples:
  2731. >>> import ray
  2732. >>> round(ray.data.range(100).std("id", ddof=0), 5)
  2733. 28.86607
  2734. >>> result = ray.data.from_items([
  2735. ... {"A": i, "B": i**2}
  2736. ... for i in range(100)
  2737. ... ]).std(["A", "B"])
  2738. >>> [(key, round(value, 10)) for key, value in result.items()]
  2739. [('std(A)', 29.0114919759), ('std(B)', 2968.1748039269)]
  2740. Args:
  2741. on: a column name or a list of column names to aggregate.
  2742. ddof: Delta Degrees of Freedom. The divisor used in calculations
  2743. is ``N - ddof``, where ``N`` represents the number of elements.
  2744. ignore_nulls: Whether to ignore null values. If ``True``, null
  2745. values are ignored when computing the std; if ``False``,
  2746. when a null value is encountered, the output is ``None``.
  2747. This method considers ``np.nan``, ``None``, and ``pd.NaT`` to be null
  2748. values. Default is ``True``.
  2749. Returns:
  2750. The standard deviation result.
  2751. For different values of ``on``, the return varies:
  2752. - ``on=None``: an dict containing the column-wise std of
  2753. all columns,
  2754. - ``on="col"``: a scalar representing the std of all items in
  2755. column ``"col"``,
  2756. - ``on=["col_1", ..., "col_n"]``: an n-column dict
  2757. containing the column-wise std of the provided columns.
  2758. If the dataset is empty, all values are null. If ``ignore_nulls`` is
  2759. ``False`` and any value is null, then the output is ``None``.
  2760. """ # noqa: E501
  2761. ret = self._aggregate_on(Std, on, ignore_nulls=ignore_nulls, ddof=ddof)
  2762. return self._aggregate_result(ret)
  2763. @AllToAllAPI
  2764. @ConsumptionAPI
  2765. @PublicAPI(api_group=GGA_API_GROUP, stability="alpha")
  2766. def summary(
  2767. self,
  2768. columns: Optional[List[str]] = None,
  2769. override_dtype_agg_mapping: Optional[
  2770. Dict[DataType, Callable[[str], List[AggregateFnV2]]]
  2771. ] = None,
  2772. ) -> "DatasetSummary":
  2773. """Generate a statistical summary of the dataset, organized by data type.
  2774. This method computes various statistics for different column dtypes:
  2775. - For numerical dtypes (int*, float*, decimal, bool): count, mean, min, max, std, approx_quantile (median), missing%, zero%
  2776. - For string and binary dtypes: count, missing%, approx_top_k (top 10 values)
  2777. - For temporal dtypes (timestamp, date, time, duration): count, min, max, missing%
  2778. - For other dtypes: count, missing%, approx_top_k
  2779. You can customize the aggregations performed for specific data types using the
  2780. `override_dtype_agg_mapping` parameter.
  2781. The summary separates statistics into two tables:
  2782. - Schema-matching stats: Statistics that preserve the original column type (e.g., min/max for integers)
  2783. - Schema-changing stats: Statistics that change the type (e.g., mean converts int to float)
  2784. Examples:
  2785. >>> import ray
  2786. >>> ds = ray.data.from_items([
  2787. ... {"age": 25, "salary": 50000, "name": "Alice", "city": "NYC"},
  2788. ... {"age": 30, "salary": 60000, "name": None, "city": "LA"},
  2789. ... {"age": 0, "salary": None, "name": "Bob", "city": None},
  2790. ... ])
  2791. >>> summary = ds.summary()
  2792. >>> # Get combined pandas DataFrame with all statistics
  2793. >>> summary.to_pandas() # doctest: +SKIP
  2794. statistic age city name salary
  2795. 0 approx_quantile[0] 25.000000 None None 60000.000000
  2796. 1 approx_topk[0] NaN {'city': 'LA', 'count': 1} {'count': 1, 'name': 'Bob'} NaN
  2797. 2 approx_topk[1] NaN {'city': 'NYC', 'count': 1} {'count': 1, 'name': 'Alice'} NaN
  2798. 3 count 3.000000 3 3 3.000000
  2799. 4 max 30.000000 NaN NaN 60000.000000
  2800. 5 mean 18.333333 None None 55000.000000
  2801. 6 min 0.000000 NaN NaN 50000.000000
  2802. 7 missing_pct 0.000000 33.333333 33.333333 33.333333
  2803. 8 std 13.123346 None None 5000.000000
  2804. 9 zero_pct 33.333333 None None 0.000000
  2805. >>> # Access individual column statistics
  2806. >>> summary.get_column_stats("age") # doctest: +SKIP
  2807. statistic value
  2808. 0 approx_quantile[0] 25.000000
  2809. 1 approx_topk[0] NaN
  2810. 2 approx_topk[1] NaN
  2811. 3 count 3.000000
  2812. 4 max 30.000000
  2813. 5 mean 18.333333
  2814. 6 min 0.000000
  2815. 7 missing_pct 0.000000
  2816. 8 std 13.123346
  2817. 9 zero_pct 33.333333
  2818. Custom aggregations for specific types:
  2819. >>> from ray.data.datatype import DataType
  2820. >>> from ray.data.aggregate import Sum, Count
  2821. >>> # Override aggregations for int64 columns
  2822. >>> custom_mapping = {
  2823. ... DataType.int64(): lambda col: [Count(on=col), Sum(on=col)]
  2824. ... }
  2825. >>> summary = ds.summary(override_dtype_agg_mapping=custom_mapping)
  2826. Args:
  2827. columns: Optional list of column names to include in the summary.
  2828. If None, all columns will be included.
  2829. override_dtype_agg_mapping: Optional mapping from DataType to factory
  2830. functions. Each factory function takes a column name and returns a
  2831. list of aggregators for that column. This will be merged with the
  2832. default mapping, with user-provided mappings taking precedence.
  2833. Returns:
  2834. A DatasetSummary object with methods to access statistics and the
  2835. original dataset schema. Use `to_pandas()` to get all statistics
  2836. as a DataFrame, or `get_column_stats(col)` for a specific column
  2837. """
  2838. from ray.data.stats import (
  2839. DatasetSummary,
  2840. _build_summary_table,
  2841. _dtype_aggregators_for_dataset,
  2842. _parse_summary_stats,
  2843. )
  2844. # Compute aggregations
  2845. dtype_aggs = _dtype_aggregators_for_dataset(
  2846. self.schema(),
  2847. columns=columns,
  2848. dtype_agg_mapping=override_dtype_agg_mapping,
  2849. )
  2850. if not dtype_aggs.aggregators:
  2851. raise ValueError(
  2852. "summary() requires at least one column with a supported type. "
  2853. f"Columns provided: {columns if columns is not None else 'all'}. "
  2854. "Check that the specified columns exist and have supported types "
  2855. "(numeric, string, binary, or temporal). Columns with None or "
  2856. "object types are skipped."
  2857. )
  2858. aggs_dataset = self.groupby(None).aggregate(*dtype_aggs.aggregators)
  2859. agg_result = aggs_dataset.take(1)[0]
  2860. # Separate statistics by whether they preserve original column types
  2861. original_schema = self.schema().base_schema
  2862. agg_schema = aggs_dataset.schema().base_schema
  2863. (
  2864. schema_matching_stats,
  2865. schema_changing_stats,
  2866. all_columns,
  2867. ) = _parse_summary_stats(
  2868. agg_result, original_schema, agg_schema, dtype_aggs.aggregators
  2869. )
  2870. # Build PyArrow tables
  2871. schema_matching_table = _build_summary_table(
  2872. schema_matching_stats, all_columns, original_schema, preserve_types=True
  2873. )
  2874. schema_changing_table = _build_summary_table(
  2875. schema_changing_stats, all_columns, original_schema, preserve_types=False
  2876. )
  2877. return DatasetSummary(
  2878. _stats_matching_column_dtype=schema_matching_table,
  2879. _stats_mismatching_column_dtype=schema_changing_table,
  2880. dataset_schema=original_schema,
  2881. columns=list(all_columns),
  2882. )
  2883. @AllToAllAPI
  2884. @PublicAPI(api_group=SSR_API_GROUP)
  2885. def sort(
  2886. self,
  2887. key: Union[str, List[str]],
  2888. descending: Union[bool, List[bool]] = False,
  2889. boundaries: List[Union[int, float]] = None,
  2890. ) -> "Dataset":
  2891. """Sort the dataset by the specified key column or key function.
  2892. The `key` parameter must be specified (i.e., it cannot be `None`).
  2893. .. note::
  2894. If provided, the `boundaries` parameter can only be used to partition
  2895. the first sort key.
  2896. Examples:
  2897. >>> import ray
  2898. >>> ds = ray.data.range(15)
  2899. >>> ds = ds.sort("id", descending=False, boundaries=[5, 10])
  2900. >>> for df in ray.get(ds.to_pandas_refs()):
  2901. ... print(df)
  2902. id
  2903. 0 0
  2904. 1 1
  2905. 2 2
  2906. 3 3
  2907. 4 4
  2908. id
  2909. 0 5
  2910. 1 6
  2911. 2 7
  2912. 3 8
  2913. 4 9
  2914. id
  2915. 0 10
  2916. 1 11
  2917. 2 12
  2918. 3 13
  2919. 4 14
  2920. Time complexity: O(dataset size * log(dataset size / parallelism))
  2921. Args:
  2922. key: The column or a list of columns to sort by.
  2923. descending: Whether to sort in descending order. Must be a boolean or a list
  2924. of booleans matching the number of the columns.
  2925. boundaries: The list of values based on which to repartition the dataset.
  2926. For example, if the input boundary is [10,20], rows with values less
  2927. than 10 will be divided into the first block, rows with values greater
  2928. than or equal to 10 and less than 20 will be divided into the
  2929. second block, and rows with values greater than or equal to 20
  2930. will be divided into the third block. If not provided, the
  2931. boundaries will be sampled from the input blocks. This feature
  2932. only supports numeric columns right now.
  2933. Returns:
  2934. A new, sorted :class:`Dataset`.
  2935. Raises:
  2936. ``ValueError``: if the sort key is None.
  2937. """
  2938. if key is None:
  2939. raise ValueError("The 'key' parameter cannot be None for sorting.")
  2940. sort_key = SortKey(key, descending, boundaries)
  2941. plan = self._plan.copy()
  2942. op = Sort(
  2943. self._logical_plan.dag,
  2944. sort_key=sort_key,
  2945. )
  2946. logical_plan = LogicalPlan(op, self.context)
  2947. return Dataset(plan, logical_plan)
  @PublicAPI(api_group=SMJ_API_GROUP)
  def zip(self, *other: "Dataset") -> "Dataset":
      """Zip the columns of this dataset with the columns of other datasets.

      All datasets must contain the same number of rows. The column sets are
      merged, and duplicate column names are disambiguated with suffixes such as
      ``"_1"``.

      .. note::
          The smaller of the two datasets is repartitioned to align the number
          of rows per block with the larger dataset.

      .. note::
          Zipped datasets aren't lineage-serializable. As a result, they can't be
          used as a tunable hyperparameter in Ray Tune.

      Examples:
          >>> import ray
          >>> ds1 = ray.data.range(5)
          >>> ds2 = ray.data.range(5)
          >>> ds3 = ray.data.range(5)
          >>> ds1.zip(ds2, ds3).take_batch()
          {'id': array([0, 1, 2, 3, 4]), 'id_1': array([0, 1, 2, 3, 4]), 'id_2': array([0, 1, 2, 3, 4])}

      Args:
          *other: The datasets to combine with this one. Each dataset must
              have the same row count as this dataset, otherwise a
              ``ValueError`` is raised.

      Returns:
          A :class:`Dataset` containing the columns of the other datasets
          concatenated horizontally with the columns of this dataset, with
          duplicate column names disambiguated with suffixes like ``"_1"``.

      Raises:
          ValueError: If the datasets have different row counts.
      """
      copied_plan = self._plan.copy()
      # Collect the logical DAG of every input: this dataset first, then the
      # others in the order they were passed.
      own_dag = self._logical_plan.dag
      other_dags = [dataset._logical_plan.dag for dataset in other]
      zip_op = Zip(own_dag, *other_dags)
      return Dataset(copied_plan, LogicalPlan(zip_op, self.context))
  @PublicAPI(api_group=BT_API_GROUP)
  def limit(self, limit: int) -> "Dataset":
      """Truncate the dataset to its first ``limit`` rows.

      In contrast to :meth:`~Dataset.take`, no data is moved to the caller's
      machine. The result is a new :class:`Dataset` that points to the truncated
      distributed data.

      Examples:
          >>> import ray
          >>> ds = ray.data.range(1000)
          >>> ds.limit(5).count()
          5

      Time complexity: O(limit specified)

      Args:
          limit: The size of the dataset to truncate to.

      Returns:
          The truncated dataset.
      """
      copied_plan = self._plan.copy()
      limit_op = Limit(self._logical_plan.dag, limit=limit)
      return Dataset(copied_plan, LogicalPlan(limit_op, self.context))
  @ConsumptionAPI
  @PublicAPI(api_group=CD_API_GROUP)
  def take_batch(
      self, batch_size: int = 20, *, batch_format: Optional[str] = "default"
  ) -> DataBatch:
      """Return up to ``batch_size`` rows from the :class:`Dataset` in a batch.

      Ray Data represents batches as NumPy arrays or pandas DataFrames. Choose
      the batch type with ``batch_format``.

      This method is useful for inspecting inputs to :meth:`~Dataset.map_batches`.

      .. warning::
          :meth:`~Dataset.take_batch` moves up to ``batch_size`` rows to the
          caller's machine. If ``batch_size`` is large, this method can cause an
          ``OutOfMemory`` error on the caller.

      Examples:
          >>> import ray
          >>> ds = ray.data.range(100)
          >>> ds.take_batch(5)
          {'id': array([0, 1, 2, 3, 4])}

      Time complexity: O(batch_size specified)

      Args:
          batch_size: The maximum number of rows to return.
          batch_format: If ``"default"`` or ``"numpy"``, batches are
              ``Dict[str, numpy.ndarray]``. If ``"pandas"``, batches are
              ``pandas.DataFrame``.

      Returns:
          A batch of up to ``batch_size`` rows from the dataset.

      Raises:
          ``ValueError``: if the dataset is empty.
      """
      batch_format = _apply_batch_format(batch_format)
      truncated_ds = self.limit(batch_size)

      try:
          batch_iterator = iter(
              truncated_ds.iter_batches(
                  batch_size=batch_size,
                  prefetch_batches=0,
                  batch_format=batch_format,
              )
          )
          # Only the first batch is needed; an exhausted iterator means the
          # dataset produced no rows at all.
          first_batch = next(batch_iterator)
      except StopIteration:
          raise ValueError("The dataset is empty.")

      self._synchronize_progress_bar()

      # Save the computed stats to the original dataset.
      self._plan._snapshot_stats = truncated_ds._plan.stats()
      return first_batch
  @ConsumptionAPI
  @PublicAPI(api_group=CD_API_GROUP)
  def take(self, limit: int = 20) -> List[Dict[str, Any]]:
      """Return up to ``limit`` rows from the :class:`Dataset`.

      This method is useful for inspecting data.

      .. warning::
          :meth:`~Dataset.take` moves up to ``limit`` rows to the caller's
          machine. If ``limit`` is large, this method can cause an
          ``OutOfMemory`` error on the caller.

      Examples:
          >>> import ray
          >>> ds = ray.data.range(100)
          >>> ds.take(3)
          [{'id': 0}, {'id': 1}, {'id': 2}]

      Time complexity: O(limit specified)

      Args:
          limit: The maximum number of rows to return.

      Returns:
          A list of up to ``limit`` rows from the dataset.

      .. seealso::

          :meth:`~Dataset.take_all`
              Call this method to return all rows.
      """
      if ray.util.log_once("dataset_take"):
          logger.info(
              "Tip: Use `take_batch()` instead of `take() / show()` to return "
              "records in pandas or numpy batch format."
          )

      truncated_ds = self.limit(limit)
      rows = []
      for record in truncated_ds.iter_rows():
          rows.append(record)
          # Stop as soon as enough rows have been gathered.
          if not len(rows) < limit:
              break

      self._synchronize_progress_bar()

      # Save the computed stats to the original dataset.
      self._plan._snapshot_stats = truncated_ds._plan.stats()
      return rows
  @ConsumptionAPI
  @PublicAPI(api_group=CD_API_GROUP)
  def take_all(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
      """Return all of the rows in this :class:`Dataset`.

      This method is useful for inspecting small datasets.

      .. warning::
          :meth:`~Dataset.take_all` moves the entire dataset to the caller's
          machine. If the dataset is large, this method can cause an
          ``OutOfMemory`` error on the caller.

      Examples:
          >>> import ray
          >>> ds = ray.data.range(5)
          >>> ds.take_all()
          [{'id': 0}, {'id': 1}, {'id': 2}, {'id': 3}, {'id': 4}]

      Time complexity: O(dataset size)

      Args:
          limit: Raise an error if the size exceeds the specified limit.

      Returns:
          A list of all the rows in the dataset.

      .. seealso::

          :meth:`~Dataset.take`
              Call this method to return a specific number of rows.
      """
      rows = []
      for record in self.iter_rows():
          rows.append(record)
          # Nothing to enforce without a limit, or while still within it.
          if limit is None or len(rows) <= limit:
              continue
          raise ValueError(
              f"The dataset has more than the given limit of {limit} records."
          )

      self._synchronize_progress_bar()
      return rows
  @ConsumptionAPI
  @PublicAPI(api_group=CD_API_GROUP)
  def show(self, limit: int = 20) -> None:
      """Print up to the given number of rows from the :class:`Dataset`.

      This method is useful for inspecting data.

      Examples:
          >>> import ray
          >>> ds = ray.data.range(100)
          >>> ds.show(3)
          {'id': 0}
          {'id': 1}
          {'id': 2}

      Time complexity: O(limit specified)

      Args:
          limit: The maximum number of rows to print.

      .. seealso::

          :meth:`~Dataset.take`
              Call this method to get (not print) a given number of rows.
      """
      # Fetch the rows through ``take`` and print them one per line.
      fetched_rows = self.take(limit)
      for record in fetched_rows:
          print(record)
  @ConsumptionAPI(
      if_more_than_read=True,
      datasource_metadata="row count",
      pattern="Examples:",
  )
  @PublicAPI(api_group=IM_API_GROUP)
  def count(self) -> int:
      """Count the number of rows in the dataset.

      For Datasets which only read Parquet files (created with
      :meth:`~ray.data.read_parquet`), this method reads the file metadata to
      efficiently count the number of rows without reading in the entire data.

      Examples:
          >>> import ray
          >>> ds = ray.data.range(10)
          >>> ds.count()
          10

      Returns:
          The number of records in the dataset.
      """
      # Handle empty dataset.
      if self._plan.initial_num_blocks() == 0:
          return 0

      # For parquet, we can return the count directly from metadata.
      known_count = self._meta_count()
      if known_count is not None:
          return known_count

      copied_plan = self._plan.copy()

      # NOTE: Project the dataset to avoid the need to carry actual
      #       data when we're only interested in the total count
      projected = Project(self._logical_plan.dag, exprs=[])
      counting_plan = LogicalPlan(Count(projected), self.context)
      counting_ds = Dataset(copied_plan, counting_plan)

      total_rows = 0
      for partial in counting_ds.iter_batches(batch_size=None):
          assert Count.COLUMN_NAME in partial, (
              "Outputs from the 'Count' logical operator should contain a column "
              f"named '{Count.COLUMN_NAME}'"
          )
          total_rows += partial[Count.COLUMN_NAME].sum()

      # Explicitly cast to int to avoid returning `np.int64`, which is the result
      # from calculating `sum()` from numpy batches.
      return int(total_rows)
  @ConsumptionAPI(
      if_more_than_read=True,
      datasource_metadata="schema",
      extra_condition="or if ``fetch_if_missing=True`` (the default)",
      pattern="Time complexity:",
  )
  @PublicAPI(api_group=IM_API_GROUP)
  def schema(self, fetch_if_missing: bool = True) -> Optional["Schema"]:
      """Return the schema of the dataset.

      Examples:
          >>> import ray
          >>> ds = ray.data.range(10)
          >>> ds.schema()
          Column  Type
          ------  ----
          id      int64

      Time complexity: O(1)

      Args:
          fetch_if_missing: If True, synchronously fetch the schema if it's
              not known. If False, None is returned if the schema is not known.
              Default is True.

      Returns:
          The :class:`ray.data.Schema` class of the records, or None if the
          schema is not known and fetch_if_missing is False.
      """
      context = self._plan._context

      # First check if the schema is already known from materialized blocks.
      known_schema = self._plan.schema(fetch_if_missing=False)
      if known_schema is not None:
          return Schema(known_schema, data_context=context)

      # Lazily execute only the first block to minimize computation. We achieve
      # this by appending a Limit[1] operation to a copy of this Dataset, which
      # we then execute to get its schema.
      one_row_plan = self.limit(1)._plan
      fetched_schema = one_row_plan.schema(fetch_if_missing=fetch_if_missing)
      if fetched_schema is None:
          return None

      # Remember the result on this dataset's plan so later calls can reuse it.
      self._plan.cache_schema(fetched_schema)
      return Schema(fetched_schema, data_context=context)
  @ConsumptionAPI(
      if_more_than_read=True,
      datasource_metadata="schema",
      extra_condition="or if ``fetch_if_missing=True`` (the default)",
      pattern="Time complexity:",
  )
  @PublicAPI(api_group=IM_API_GROUP)
  def columns(self, fetch_if_missing: bool = True) -> Optional[List[str]]:
      """Return the column names of this Dataset.

      Time complexity: O(1)

      Example:
          >>> import ray
          >>> # Create dataset from synthetic data.
          >>> ds = ray.data.range(1000)
          >>> ds.columns()
          ['id']

      Args:
          fetch_if_missing: If True, synchronously fetch the column names from
              the schema if it's not known. If False, None is returned if the
              schema is not known. Default is True.

      Returns:
          A list of the column names for this Dataset or None if schema is not
          known and `fetch_if_missing` is False.
      """
      resolved_schema = self.schema(fetch_if_missing=fetch_if_missing)
      return None if resolved_schema is None else resolved_schema.names
  @PublicAPI(api_group=IM_API_GROUP)
  def num_blocks(self) -> int:
      """Return the number of blocks of this :class:`Dataset`.

      This method is only implemented for :class:`~ray.data.MaterializedDataset`,
      since the number of blocks may dynamically change during execution.
      For instance, during read and transform operations, Ray Data may dynamically
      adjust the number of blocks to respect memory limits, increasing the
      number of blocks at runtime.

      Returns:
          The number of blocks of this :class:`Dataset`.

      Raises:
          NotImplementedError: Always, in this implementation.
      """
      # Adjacent string literals are concatenated verbatim, so each fragment
      # needs its own trailing space to keep the sentences apart.
      raise NotImplementedError(
          "Number of blocks is only available for `MaterializedDataset`, "
          "because the number of blocks may dynamically change during execution. "
          "Call `ds.materialize()` to get a `MaterializedDataset`."
      )
  @ConsumptionAPI
  @PublicAPI(api_group=IM_API_GROUP)
  def size_bytes(self) -> Optional[int]:
      """Return the in-memory size of the dataset.

      Examples:
          >>> import ray
          >>> ds = ray.data.range(10)
          >>> ds.size_bytes()
          80

      Returns:
          The in-memory size of the dataset in bytes, or None if the
          in-memory size is not known.
      """
      # If the size is known from metadata, return it. The inferred metadata is
      # looked up once and reused.
      inferred_size = self._logical_plan.dag.infer_metadata().size_bytes
      if inferred_size is not None:
          return inferred_size

      # Otherwise execute the plan and add up the per-block sizes.
      metadata = self._plan.execute().metadata
      # The total is unknown if there are no blocks or if any block lacks a
      # size; checking every block avoids a ``TypeError`` from summing ``None``.
      if not metadata or any(m.size_bytes is None for m in metadata):
          return None
      return sum(m.size_bytes for m in metadata)
  @ConsumptionAPI
  @PublicAPI(api_group=IM_API_GROUP)
  def input_files(self) -> List[str]:
      """Return the list of input files for the dataset.

      Duplicate paths are removed. The remaining paths keep the order in which
      the plan reported them, so the result is the same from run to run.

      Examples:
          >>> import ray
          >>> ds = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
          >>> ds.input_files()
          ['ray-example-data/iris.csv']

      Returns:
          The list of input files used to create the dataset, or an empty
          list if the input files are not known.
      """
      # ``dict.fromkeys`` de-duplicates while preserving first-seen order; a
      # ``set`` would yield an order that changes with string hash randomization.
      return list(dict.fromkeys(self._plan.input_files()))
  @ConsumptionAPI
  @PublicAPI(api_group=IOC_API_GROUP)
  def write_parquet(
      self,
      path: str,
      *,
      partition_cols: Optional[List[str]] = None,
      filesystem: Optional["pyarrow.fs.FileSystem"] = None,
      try_create_dir: bool = True,
      arrow_open_stream_args: Optional[Dict[str, Any]] = None,
      filename_provider: Optional[FilenameProvider] = None,
      arrow_parquet_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
      min_rows_per_file: Optional[int] = None,
      max_rows_per_file: Optional[int] = None,
      ray_remote_args: Optional[Dict[str, Any]] = None,
      concurrency: Optional[int] = None,
      num_rows_per_file: Optional[int] = None,
      mode: SaveMode = SaveMode.APPEND,
      **arrow_parquet_args,
  ) -> None:
      """Writes the :class:`~ray.data.Dataset` to parquet files under the provided ``path``.

      The number of files is determined by the number of blocks in the dataset.
      To control the number of blocks, call
      :meth:`~ray.data.Dataset.repartition`.

      If pyarrow can't represent your data, this method errors.

      By default, the format of the output files is ``{uuid}_{block_idx}.parquet``,
      where ``uuid`` is a unique id for the dataset. To modify this behavior,
      implement a custom :class:`~ray.data.datasource.FilenameProvider` and pass it in
      as the ``filename_provider`` argument.

      Examples:
          >>> import ray
          >>> ds = ray.data.range(100)
          >>> ds.write_parquet("local:///tmp/data/")

      Time complexity: O(dataset size / parallelism)

      Args:
          path: The path to the destination root directory, where
              parquet files are written to.
          partition_cols: Column names by which to partition the dataset.
              Files are written in Hive partition style.
          filesystem: The pyarrow filesystem implementation to write to.
              These filesystems are specified in the
              `pyarrow docs <https://arrow.apache.org/docs\
              /python/api/filesystems.html#filesystem-implementations>`_.
              Specify this if you need to provide specific configurations to the
              filesystem. By default, the filesystem is automatically selected based
              on the scheme of the paths. For example, if the path begins with
              ``s3://``, the ``S3FileSystem`` is used.
          try_create_dir: If ``True``, attempts to create all directories in the
              destination path. Does nothing if all directories already
              exist. Defaults to ``True``.
          arrow_open_stream_args: kwargs passed to
              `pyarrow.fs.FileSystem.open_output_stream <https://arrow.apache.org\
              /docs/python/generated/pyarrow.fs.FileSystem.html\
              #pyarrow.fs.FileSystem.open_output_stream>`_, which is used when
              opening the file to write to.
          filename_provider: A :class:`~ray.data.datasource.FilenameProvider`
              implementation. Use this parameter to customize what your filenames
              look like. The filename is expected to be templatized with `{i}`
              to ensure unique filenames when writing multiple files. If it's not
              templatized, Ray Data will add `{i}` to the filename to ensure
              compatibility with the pyarrow `write_dataset <https://arrow.apache.org/docs/python/generated/pyarrow.parquet.write_dataset.html>`_.
          arrow_parquet_args_fn: Callable that returns a dictionary of write
              arguments that are provided to `pyarrow.parquet.ParquetWriter() <https:/\
              /arrow.apache.org/docs/python/generated/\
              pyarrow.parquet.ParquetWriter.html>`_
              when writing each block to a file. Overrides
              any duplicate keys from ``arrow_parquet_args``. Use this argument
              instead of ``arrow_parquet_args`` if any of your write arguments
              can't be pickled, or if you'd like to lazily resolve the write
              arguments for each dataset block. See the note below for more details.
          min_rows_per_file: [Experimental] The target minimum number of rows to write
              to each file. If ``None``, Ray Data writes a system-chosen number of
              rows to each file. If the number of rows per block is larger than the
              specified value, Ray Data writes the number of rows per block to each file.
              The specified value is a hint, not a strict limit. Ray Data
              might write more or fewer rows to each file.
          max_rows_per_file: [Experimental] The target maximum number of rows to write
              to each file. If ``None``, Ray Data writes a system-chosen number of
              rows to each file. If the number of rows per block is smaller than the
              specified value, Ray Data writes the number of rows per block to each file.
              The specified value is a hint, not a strict limit. Ray Data
              might write more or fewer rows to each file. If both ``min_rows_per_file``
              and ``max_rows_per_file`` are specified, ``max_rows_per_file`` takes
              precedence when they cannot both be satisfied.
          ray_remote_args: Kwargs passed to :func:`ray.remote` in the write tasks.
          concurrency: The maximum number of Ray tasks to run concurrently. Set this
              to control number of tasks to run concurrently. This doesn't change the
              total number of tasks run. By default, concurrency is dynamically
              decided based on the available resources.
          num_rows_per_file: [Deprecated] Use min_rows_per_file instead.
          arrow_parquet_args: Options to pass to
              `pyarrow.parquet.ParquetWriter() <https:/\
              /arrow.apache.org/docs/python/generated/\
              pyarrow.parquet.ParquetWriter.html>`_, which is used to write
              out each block to a file. See `arrow_parquet_args_fn` for more detail.
          mode: Determines how to handle existing files. Valid modes are "overwrite", "error",
              "ignore", "append". Defaults to "append".
              NOTE: This method isn't atomic. "Overwrite" first deletes all the data
              before writing to `path`.

      .. note::

          When using `arrow_parquet_args` or `arrow_parquet_args_fn` to pass extra
          options to pyarrow, there are some special cases:

          - `partitioning_flavor`: if it's not provided, default is "hive" in Ray Data.
            Otherwise, it follows pyarrow's behavior: `None` for pyarrow's DirectoryPartitioning,
            "hive" for HivePartitioning, and "filename" for FilenamePartitioning.
            See `pyarrow.dataset.partitioning <https://arrow.apache.org/docs/python/generated/pyarrow.dataset.partitioning.html>`_.
          - `row_group_size`: if provided, it's passed to
            `pyarrow.parquet.ParquetWriter.write_table() <https:/\
            /arrow.apache.org/docs/python/generated/pyarrow\
            .parquet.ParquetWriter.html\
            #pyarrow.parquet.ParquetWriter.write_table>`_.
      """  # noqa: E501
      # Give the datasink a callable even when the caller supplied none.
      if arrow_parquet_args_fn is None:
          arrow_parquet_args_fn = lambda: {}  # noqa: E731

      # Reconcile the deprecated ``num_rows_per_file`` with the min/max arguments.
      effective_min_rows, effective_max_rows = _validate_rows_per_file_args(
          num_rows_per_file=num_rows_per_file,
          min_rows_per_file=min_rows_per_file,
          max_rows_per_file=max_rows_per_file,
      )

      datasink = ParquetDatasink(
          path,
          partition_cols=partition_cols,
          arrow_parquet_args_fn=arrow_parquet_args_fn,
          arrow_parquet_args=arrow_parquet_args,
          min_rows_per_file=effective_min_rows,
          max_rows_per_file=effective_max_rows,
          filesystem=filesystem,
          try_create_dir=try_create_dir,
          open_stream_args=arrow_open_stream_args,
          filename_provider=filename_provider,
          dataset_uuid=self._uuid,
          mode=mode,
      )
      self.write_datasink(
          datasink,
          ray_remote_args=ray_remote_args,
          concurrency=concurrency,
      )
  @ConsumptionAPI
  @PublicAPI(api_group=IOC_API_GROUP)
  def write_json(
      self,
      path: str,
      *,
      filesystem: Optional["pyarrow.fs.FileSystem"] = None,
      try_create_dir: bool = True,
      arrow_open_stream_args: Optional[Dict[str, Any]] = None,
      filename_provider: Optional[FilenameProvider] = None,
      pandas_json_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
      min_rows_per_file: Optional[int] = None,
      ray_remote_args: Optional[Dict[str, Any]] = None,
      concurrency: Optional[int] = None,
      num_rows_per_file: Optional[int] = None,
      mode: SaveMode = SaveMode.APPEND,
      **pandas_json_args,
  ) -> None:
      """Writes the :class:`~ray.data.Dataset` to JSON and JSONL files.

      The number of files is determined by the number of blocks in the dataset.
      To control the number of blocks, call
      :meth:`~ray.data.Dataset.repartition`.

      This method is only supported for datasets with records that are convertible to
      pandas dataframes.

      By default, the format of the output files is ``{uuid}_{block_idx}.json``,
      where ``uuid`` is a unique id for the dataset. To modify this behavior,
      implement a custom :class:`~ray.data.datasource.FilenameProvider` and pass it in
      as the ``filename_provider`` argument.

      Examples:
          Write the dataset as JSON file to a local directory.

          >>> import ray
          >>> import pandas as pd
          >>> ds = ray.data.from_pandas([pd.DataFrame({"one": [1], "two": ["a"]})])
          >>> ds.write_json("local:///tmp/data")

          Write the dataset as JSONL files to a local directory.

          >>> ds = ray.data.read_json("s3://anonymous@ray-example-data/train.jsonl")
          >>> ds.write_json("local:///tmp/data")

      Time complexity: O(dataset size / parallelism)

      Args:
          path: The path to the destination root directory, where
              the JSON files are written to.
          filesystem: The pyarrow filesystem implementation to write to.
              These filesystems are specified in the
              `pyarrow docs <https://arrow.apache.org/docs\
              /python/api/filesystems.html#filesystem-implementations>`_.
              Specify this if you need to provide specific configurations to the
              filesystem. By default, the filesystem is automatically selected based
              on the scheme of the paths. For example, if the path begins with
              ``s3://``, the ``S3FileSystem`` is used.
          try_create_dir: If ``True``, attempts to create all directories in the
              destination path. Does nothing if all directories already
              exist. Defaults to ``True``.
          arrow_open_stream_args: kwargs passed to
              `pyarrow.fs.FileSystem.open_output_stream <https://arrow.apache.org\
              /docs/python/generated/pyarrow.fs.FileSystem.html\
              #pyarrow.fs.FileSystem.open_output_stream>`_, which is used when
              opening the file to write to.
          filename_provider: A :class:`~ray.data.datasource.FilenameProvider`
              implementation. Use this parameter to customize what your filenames
              look like.
          pandas_json_args_fn: Callable that returns a dictionary of write
              arguments that are provided to
              `pandas.DataFrame.to_json() <https://pandas.pydata.org/docs/reference/\
              api/pandas.DataFrame.to_json.html>`_
              when writing each block to a file. Overrides
              any duplicate keys from ``pandas_json_args``. Use this parameter
              instead of ``pandas_json_args`` if any of your write arguments
              can't be pickled, or if you'd like to lazily resolve the write
              arguments for each dataset block.
          min_rows_per_file: [Experimental] The target minimum number of rows to write
              to each file. If ``None``, Ray Data writes a system-chosen number of
              rows to each file. If the number of rows per block is larger than the
              specified value, Ray Data writes the number of rows per block to each file.
              The specified value is a hint, not a strict limit. Ray Data
              might write more or fewer rows to each file.
          ray_remote_args: kwargs passed to :func:`ray.remote` in the write tasks.
          concurrency: The maximum number of Ray tasks to run concurrently. Set this
              to control number of tasks to run concurrently. This doesn't change the
              total number of tasks run. By default, concurrency is dynamically
              decided based on the available resources.
          num_rows_per_file: Deprecated. Use ``min_rows_per_file`` instead.
          pandas_json_args: These args are passed to
              `pandas.DataFrame.to_json() <https://pandas.pydata.org/docs/reference/\
              api/pandas.DataFrame.to_json.html>`_,
              which is used under the hood to write out each
              :class:`~ray.data.Dataset` block. These
              are dict(orient="records", lines=True) by default.
          mode: Determines how to handle existing files. Valid modes are "overwrite", "error",
              "ignore", "append". Defaults to "append".
              NOTE: This method isn't atomic. "Overwrite" first deletes all the data
              before writing to `path`.
      """
      # Give the datasink a callable even when the caller supplied none.
      if pandas_json_args_fn is None:
          pandas_json_args_fn = lambda: {}  # noqa: E731

      # Reconcile the deprecated ``num_rows_per_file`` with ``min_rows_per_file``;
      # the second element of the returned pair is not used for JSON writes.
      effective_min_rows, _ = _validate_rows_per_file_args(
          num_rows_per_file=num_rows_per_file, min_rows_per_file=min_rows_per_file
      )

      datasink = JSONDatasink(
          path,
          pandas_json_args_fn=pandas_json_args_fn,
          pandas_json_args=pandas_json_args,
          min_rows_per_file=effective_min_rows,
          filesystem=filesystem,
          try_create_dir=try_create_dir,
          open_stream_args=arrow_open_stream_args,
          filename_provider=filename_provider,
          dataset_uuid=self._uuid,
          mode=mode,
      )
      self.write_datasink(
          datasink,
          ray_remote_args=ray_remote_args,
          concurrency=concurrency,
      )
  @ConsumptionAPI
  @PublicAPI(stability="alpha", api_group=IOC_API_GROUP)
  def write_iceberg(
      self,
      table_identifier: str,
      catalog_kwargs: Optional[Dict[str, Any]] = None,
      snapshot_properties: Optional[Dict[str, str]] = None,
      mode: "SaveMode" = SaveMode.APPEND,
      overwrite_filter: Optional["Expr"] = None,
      upsert_kwargs: Optional[Dict[str, Any]] = None,
      overwrite_kwargs: Optional[Dict[str, Any]] = None,
      ray_remote_args: Optional[Dict[str, Any]] = None,
      concurrency: Optional[int] = None,
  ) -> None:
      """Writes the :class:`~ray.data.Dataset` to an Iceberg table.

      .. tip::
          For more details on PyIceberg, see
          - URI: https://py.iceberg.apache.org/

      Examples:
          .. testcode::
              :skipif: True

              import ray
              import pandas as pd
              from ray.data import SaveMode
              from ray.data.expressions import col

              # Basic append (default behavior)
              docs = [{"id": i, "title": f"Doc {i}"} for i in range(4)]
              ds = ray.data.from_pandas(pd.DataFrame(docs))
              ds.write_iceberg(
                  table_identifier="db_name.table_name",
                  catalog_kwargs={"name": "default", "type": "sql"}
              )

              # Schema evolution is automatic - new columns are added automatically
              enriched_docs = [{"id": i, "title": f"Doc {i}", "category": "new"} for i in range(3)]
              ds_enriched = ray.data.from_pandas(pd.DataFrame(enriched_docs))
              ds_enriched.write_iceberg(
                  table_identifier="db_name.table_name",
                  catalog_kwargs={"name": "default", "type": "sql"}
              )

              # Upsert mode - update existing rows or insert new ones
              updated_docs = [{"id": 2, "title": "Updated Doc 2"}, {"id": 5, "title": "New Doc 5"}]
              ds_updates = ray.data.from_pandas(pd.DataFrame(updated_docs))
              ds_updates.write_iceberg(
                  table_identifier="db_name.table_name",
                  catalog_kwargs={"name": "default", "type": "sql"},
                  mode=SaveMode.UPSERT,
                  upsert_kwargs={"join_cols": ["id"]},
              )

              # Partial overwrite with Ray Data expressions
              ds.write_iceberg(
                  table_identifier="events.user_activity",
                  catalog_kwargs={"name": "default", "type": "rest"},
                  mode=SaveMode.OVERWRITE,
                  overwrite_filter=col("date") >= "2024-10-28"
              )

      Args:
          table_identifier: Fully qualified table identifier (``db_name.table_name``)
          catalog_kwargs: Optional arguments to pass to PyIceberg's catalog.load_catalog()
              function (such as name, type, etc.). For the function definition, see
              `pyiceberg catalog
              <https://py.iceberg.apache.org/reference/pyiceberg/catalog/\
              #pyiceberg.catalog.load_catalog>`_.
          snapshot_properties: Custom properties to write to snapshot when committing
              to an iceberg table.
          mode: Write mode using SaveMode enum. Options:

              * SaveMode.APPEND (default): Add new data to the table without checking for duplicates.
              * SaveMode.UPSERT: Update existing rows that match on the join condition (``join_cols`` in ``upsert_kwargs``),
                or insert new rows if they don't exist in the table.
              * SaveMode.OVERWRITE: Replace all existing data in the table with new data, or replace
                data matching overwrite_filter if specified.

          overwrite_filter: Optional filter for OVERWRITE mode to perform partial overwrites.
              Must be a Ray Data expression from `ray.data.expressions`. Only rows matching
              this filter are replaced. If None with OVERWRITE mode, replaces all table data.
              Example: `col("date") >= "2024-01-01"` or `(col("region") == "US") & (col("status") == "active")`
          upsert_kwargs: Optional arguments for upsert operations.
              Supported parameters: join_cols (List[str]), case_sensitive (bool), branch (str).
              Note: Ray Data uses a copy-on-write strategy that always updates all columns
              for matched keys and inserts all new keys for optimal parallelism.
          overwrite_kwargs: Optional arguments to pass through to PyIceberg's table.overwrite() method.
              Supported parameters: case_sensitive (bool), branch (str). See PyIceberg documentation
              for details.
          ray_remote_args: kwargs passed to :func:`ray.remote` in the write tasks.
          concurrency: The maximum number of Ray tasks to run concurrently. Set this
              to control number of tasks to run concurrently. This doesn't change the
              total number of tasks run. By default, concurrency is dynamically
              decided based on the available resources.

      Note:
          Schema evolution is automatically enabled. New columns in the incoming data
          are automatically added to the table schema. The schema is extracted
          automatically from the data being written.
      """
      # All Iceberg-specific options are handed to the datasink unchanged; the
      # execution options go to ``write_datasink``.
      datasink = IcebergDatasink(
          table_identifier=table_identifier,
          catalog_kwargs=catalog_kwargs,
          snapshot_properties=snapshot_properties,
          mode=mode,
          overwrite_filter=overwrite_filter,
          upsert_kwargs=upsert_kwargs,
          overwrite_kwargs=overwrite_kwargs,
      )

      self.write_datasink(
          datasink,
          ray_remote_args=ray_remote_args,
          concurrency=concurrency,
      )
  3657. @PublicAPI(stability="alpha", api_group=IOC_API_GROUP)
  3658. @ConsumptionAPI
  3659. def write_images(
  3660. self,
  3661. path: str,
  3662. column: str,
  3663. file_format: str = "png",
  3664. *,
  3665. filesystem: Optional["pyarrow.fs.FileSystem"] = None,
  3666. try_create_dir: bool = True,
  3667. arrow_open_stream_args: Optional[Dict[str, Any]] = None,
  3668. filename_provider: Optional[FilenameProvider] = None,
  3669. ray_remote_args: Dict[str, Any] = None,
  3670. concurrency: Optional[int] = None,
  3671. mode: SaveMode = SaveMode.APPEND,
  3672. ) -> None:
  3673. """Writes the :class:`~ray.data.Dataset` to images.
  3674. Examples:
  3675. >>> import ray
  3676. >>> ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")
  3677. >>> ds.write_images("local:///tmp/images", column="image")
  3678. Time complexity: O(dataset size / parallelism)
  3679. Args:
  3680. path: The path to the destination root directory, where
  3681. the images are written to.
  3682. column: The column containing the data you want to write to images.
  3683. file_format: The image file format to write with. For available options,
  3684. see `Image file formats <https://pillow.readthedocs.io/en/latest\
  3685. /handbook/image-file-formats.html>`_.
  3686. filesystem: The pyarrow filesystem implementation to write to.
  3687. These filesystems are specified in the
  3688. `pyarrow docs <https://arrow.apache.org/docs\
  3689. /python/api/filesystems.html#filesystem-implementations>`_.
  3690. Specify this if you need to provide specific configurations to the
  3691. filesystem. By default, the filesystem is automatically selected based
  3692. on the scheme of the paths. For example, if the path begins with
  3693. ``s3://``, the ``S3FileSystem`` is used.
  3694. try_create_dir: If ``True``, attempts to create all directories in the
  3695. destination path. Does nothing if all directories already
  3696. exist. Defaults to ``True``.
  3697. arrow_open_stream_args: kwargs passed to
  3698. `pyarrow.fs.FileSystem.open_output_stream <https://arrow.apache.org\
  3699. /docs/python/generated/pyarrow.fs.FileSystem.html\
  3700. #pyarrow.fs.FileSystem.open_output_stream>`_, which is used when
  3701. opening the file to write to.
  3702. filename_provider: A :class:`~ray.data.datasource.FilenameProvider`
  3703. implementation. Use this parameter to customize what your filenames
  3704. look like.
  3705. ray_remote_args: kwargs passed to :func:`ray.remote` in the write tasks.
  3706. concurrency: The maximum number of Ray tasks to run concurrently. Set this
  3707. to control number of tasks to run concurrently. This doesn't change the
  3708. total number of tasks run. By default, concurrency is dynamically
  3709. decided based on the available resources.
  3710. mode: Determines how to handle existing files. Valid modes are "overwrite", "error",
  3711. "ignore", "append". Defaults to "append".
  3712. NOTE: This method isn't atomic. "Overwrite" first deletes all the data
  3713. before writing to `path`.
  3714. """ # noqa: E501
  3715. datasink = ImageDatasink(
  3716. path,
  3717. column,
  3718. file_format,
  3719. filesystem=filesystem,
  3720. try_create_dir=try_create_dir,
  3721. open_stream_args=arrow_open_stream_args,
  3722. filename_provider=filename_provider,
  3723. dataset_uuid=self._uuid,
  3724. mode=mode,
  3725. )
  3726. self.write_datasink(
  3727. datasink,
  3728. ray_remote_args=ray_remote_args,
  3729. concurrency=concurrency,
  3730. )
  3731. @ConsumptionAPI
  3732. @PublicAPI(api_group=IOC_API_GROUP)
  3733. def write_csv(
  3734. self,
  3735. path: str,
  3736. *,
  3737. filesystem: Optional["pyarrow.fs.FileSystem"] = None,
  3738. try_create_dir: bool = True,
  3739. arrow_open_stream_args: Optional[Dict[str, Any]] = None,
  3740. filename_provider: Optional[FilenameProvider] = None,
  3741. arrow_csv_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
  3742. min_rows_per_file: Optional[int] = None,
  3743. ray_remote_args: Dict[str, Any] = None,
  3744. concurrency: Optional[int] = None,
  3745. num_rows_per_file: Optional[int] = None,
  3746. mode: SaveMode = SaveMode.APPEND,
  3747. **arrow_csv_args,
  3748. ) -> None:
  3749. """Writes the :class:`~ray.data.Dataset` to CSV files.
  3750. The number of files is determined by the number of blocks in the dataset.
  3751. To control the number of number of blocks, call
  3752. :meth:`~ray.data.Dataset.repartition`.
  3753. This method is only supported for datasets with records that are convertible to
  3754. pyarrow tables.
  3755. By default, the format of the output files is ``{uuid}_{block_idx}.csv``,
  3756. where ``uuid`` is a unique id for the dataset. To modify this behavior,
  3757. implement a custom :class:`~ray.data.datasource.FilenameProvider`
  3758. and pass it in as the ``filename_provider`` argument.
  3759. Examples:
  3760. Write the dataset as CSV files to a local directory.
  3761. >>> import ray
  3762. >>> ds = ray.data.range(100)
  3763. >>> ds.write_csv("local:///tmp/data")
  3764. Write the dataset as CSV files to S3.
  3765. >>> import ray
  3766. >>> ds = ray.data.range(100)
  3767. >>> ds.write_csv("s3://bucket/folder/) # doctest: +SKIP
  3768. Time complexity: O(dataset size / parallelism)
  3769. Args:
  3770. path: The path to the destination root directory, where
  3771. the CSV files are written to.
  3772. filesystem: The pyarrow filesystem implementation to write to.
  3773. These filesystems are specified in the
  3774. `pyarrow docs <https://arrow.apache.org/docs\
  3775. /python/api/filesystems.html#filesystem-implementations>`_.
  3776. Specify this if you need to provide specific configurations to the
  3777. filesystem. By default, the filesystem is automatically selected based
  3778. on the scheme of the paths. For example, if the path begins with
  3779. ``s3://``, the ``S3FileSystem`` is used.
  3780. try_create_dir: If ``True``, attempts to create all directories in the
  3781. destination path if ``True``. Does nothing if all directories already
  3782. exist. Defaults to ``True``.
  3783. arrow_open_stream_args: kwargs passed to
  3784. `pyarrow.fs.FileSystem.open_output_stream <https://arrow.apache.org\
  3785. /docs/python/generated/pyarrow.fs.FileSystem.html\
  3786. #pyarrow.fs.FileSystem.open_output_stream>`_, which is used when
  3787. opening the file to write to.
  3788. filename_provider: A :class:`~ray.data.datasource.FilenameProvider`
  3789. implementation. Use this parameter to customize what your filenames
  3790. look like.
  3791. arrow_csv_args_fn: Callable that returns a dictionary of write
  3792. arguments that are provided to `pyarrow.write.write_csv <https://\
  3793. arrow.apache.org/docs/python/generated/\
  3794. pyarrow.csv.write_csv.html#pyarrow.csv.write_csv>`_ when writing each
  3795. block to a file. Overrides any duplicate keys from ``arrow_csv_args``.
  3796. Use this argument instead of ``arrow_csv_args`` if any of your write
  3797. arguments cannot be pickled, or if you'd like to lazily resolve the
  3798. write arguments for each dataset block.
  3799. min_rows_per_file: [Experimental] The target minimum number of rows to write
  3800. to each file. If ``None``, Ray Data writes a system-chosen number of
  3801. rows to each file. If the number of rows per block is larger than the
  3802. specified value, Ray Data writes the number of rows per block to each file.
  3803. The specified value is a hint, not a strict limit. Ray Data
  3804. might write more or fewer rows to each file.
  3805. ray_remote_args: kwargs passed to :func:`ray.remote` in the write tasks.
  3806. concurrency: The maximum number of Ray tasks to run concurrently. Set this
  3807. to control number of tasks to run concurrently. This doesn't change the
  3808. total number of tasks run. By default, concurrency is dynamically
  3809. decided based on the available resources.
  3810. num_rows_per_file: [Deprecated] Use min_rows_per_file instead.
  3811. arrow_csv_args: Options to pass to `pyarrow.write.write_csv <https://\
  3812. arrow.apache.org/docs/python/generated/pyarrow.csv.write_csv.html\
  3813. #pyarrow.csv.write_csv>`_
  3814. when writing each block to a file.
  3815. mode: Determines how to handle existing files. Valid modes are "overwrite", "error",
  3816. "ignore", "append". Defaults to "append".
  3817. NOTE: This method isn't atomic. "Overwrite" first deletes all the data
  3818. before writing to `path`.
  3819. """
  3820. if arrow_csv_args_fn is None:
  3821. arrow_csv_args_fn = lambda: {} # noqa: E731
  3822. effective_min_rows, _ = _validate_rows_per_file_args(
  3823. num_rows_per_file=num_rows_per_file, min_rows_per_file=min_rows_per_file
  3824. )
  3825. datasink = CSVDatasink(
  3826. path,
  3827. arrow_csv_args_fn=arrow_csv_args_fn,
  3828. arrow_csv_args=arrow_csv_args,
  3829. min_rows_per_file=effective_min_rows,
  3830. filesystem=filesystem,
  3831. try_create_dir=try_create_dir,
  3832. open_stream_args=arrow_open_stream_args,
  3833. filename_provider=filename_provider,
  3834. dataset_uuid=self._uuid,
  3835. mode=mode,
  3836. )
  3837. self.write_datasink(
  3838. datasink,
  3839. ray_remote_args=ray_remote_args,
  3840. concurrency=concurrency,
  3841. )
  3842. @ConsumptionAPI
  3843. @PublicAPI(api_group=IOC_API_GROUP)
  3844. def write_tfrecords(
  3845. self,
  3846. path: str,
  3847. *,
  3848. tf_schema: Optional["schema_pb2.Schema"] = None,
  3849. filesystem: Optional["pyarrow.fs.FileSystem"] = None,
  3850. try_create_dir: bool = True,
  3851. arrow_open_stream_args: Optional[Dict[str, Any]] = None,
  3852. filename_provider: Optional[FilenameProvider] = None,
  3853. min_rows_per_file: Optional[int] = None,
  3854. ray_remote_args: Dict[str, Any] = None,
  3855. concurrency: Optional[int] = None,
  3856. num_rows_per_file: Optional[int] = None,
  3857. mode: SaveMode = SaveMode.APPEND,
  3858. ) -> None:
  3859. """Write the :class:`~ray.data.Dataset` to TFRecord files.
  3860. The `TFRecord <https://www.tensorflow.org/tutorials/load_data/tfrecord>`_
  3861. files contain
  3862. `tf.train.Example <https://www.tensorflow.org/api_docs/python/tf/train/\
  3863. Example>`_
  3864. records, with one Example record for each row in the dataset.
  3865. .. warning::
  3866. tf.train.Feature only natively stores ints, floats, and bytes,
  3867. so this function only supports datasets with these data types,
  3868. and will error if the dataset contains unsupported types.
  3869. The number of files is determined by the number of blocks in the dataset.
  3870. To control the number of number of blocks, call
  3871. :meth:`~ray.data.Dataset.repartition`.
  3872. This method is only supported for datasets with records that are convertible to
  3873. pyarrow tables.
  3874. By default, the format of the output files is ``{uuid}_{block_idx}.tfrecords``,
  3875. where ``uuid`` is a unique id for the dataset. To modify this behavior,
  3876. implement a custom :class:`~ray.data.datasource.FilenameProvider`
  3877. and pass it in as the ``filename_provider`` argument.
  3878. Examples:
  3879. >>> import ray
  3880. >>> ds = ray.data.range(100)
  3881. >>> ds.write_tfrecords("local:///tmp/data/")
  3882. Time complexity: O(dataset size / parallelism)
  3883. Args:
  3884. path: The path to the destination root directory, where tfrecords
  3885. files are written to.
  3886. filesystem: The pyarrow filesystem implementation to write to.
  3887. These filesystems are specified in the
  3888. `pyarrow docs <https://arrow.apache.org/docs\
  3889. /python/api/filesystems.html#filesystem-implementations>`_.
  3890. Specify this if you need to provide specific configurations to the
  3891. filesystem. By default, the filesystem is automatically selected based
  3892. on the scheme of the paths. For example, if the path begins with
  3893. ``s3://``, the ``S3FileSystem`` is used.
  3894. try_create_dir: If ``True``, attempts to create all directories in the
  3895. destination path. Does nothing if all directories already
  3896. exist. Defaults to ``True``.
  3897. arrow_open_stream_args: kwargs passed to
  3898. `pyarrow.fs.FileSystem.open_output_stream <https://arrow.apache.org\
  3899. /docs/python/generated/pyarrow.fs.FileSystem.html\
  3900. #pyarrow.fs.FileSystem.open_output_stream>`_, which is used when
  3901. opening the file to write to.
  3902. filename_provider: A :class:`~ray.data.datasource.FilenameProvider`
  3903. implementation. Use this parameter to customize what your filenames
  3904. look like.
  3905. min_rows_per_file: [Experimental] The target minimum number of rows to write
  3906. to each file. If ``None``, Ray Data writes a system-chosen number of
  3907. rows to each file. If the number of rows per block is larger than the
  3908. specified value, Ray Data writes the number of rows per block to each file.
  3909. The specified value is a hint, not a strict limit. Ray Data
  3910. might write more or fewer rows to each file.
  3911. ray_remote_args: kwargs passed to :func:`ray.remote` in the write tasks.
  3912. concurrency: The maximum number of Ray tasks to run concurrently. Set this
  3913. to control number of tasks to run concurrently. This doesn't change the
  3914. total number of tasks run. By default, concurrency is dynamically
  3915. decided based on the available resources.
  3916. num_rows_per_file: [Deprecated] Use min_rows_per_file instead.
  3917. mode: Determines how to handle existing files. Valid modes are "overwrite", "error",
  3918. "ignore", "append". Defaults to "append".
  3919. NOTE: This method isn't atomic. "Overwrite" first deletes all the data
  3920. before writing to `path`.
  3921. """
  3922. effective_min_rows, _ = _validate_rows_per_file_args(
  3923. num_rows_per_file=num_rows_per_file, min_rows_per_file=min_rows_per_file
  3924. )
  3925. datasink = TFRecordDatasink(
  3926. path=path,
  3927. tf_schema=tf_schema,
  3928. min_rows_per_file=effective_min_rows,
  3929. filesystem=filesystem,
  3930. try_create_dir=try_create_dir,
  3931. open_stream_args=arrow_open_stream_args,
  3932. filename_provider=filename_provider,
  3933. dataset_uuid=self._uuid,
  3934. mode=mode,
  3935. )
  3936. self.write_datasink(
  3937. datasink,
  3938. ray_remote_args=ray_remote_args,
  3939. concurrency=concurrency,
  3940. )
  3941. @ConsumptionAPI
  3942. @PublicAPI(stability="alpha", api_group=IOC_API_GROUP)
  3943. def write_webdataset(
  3944. self,
  3945. path: str,
  3946. *,
  3947. filesystem: Optional["pyarrow.fs.FileSystem"] = None,
  3948. try_create_dir: bool = True,
  3949. arrow_open_stream_args: Optional[Dict[str, Any]] = None,
  3950. filename_provider: Optional[FilenameProvider] = None,
  3951. min_rows_per_file: Optional[int] = None,
  3952. ray_remote_args: Dict[str, Any] = None,
  3953. encoder: Optional[Union[bool, str, callable, list]] = True,
  3954. concurrency: Optional[int] = None,
  3955. num_rows_per_file: Optional[int] = None,
  3956. mode: SaveMode = SaveMode.APPEND,
  3957. ) -> None:
  3958. """Writes the dataset to `WebDataset <https://github.com/webdataset/webdataset>`_ files.
  3959. The `TFRecord <https://www.tensorflow.org/tutorials/load_data/tfrecord>`_
  3960. files will contain
  3961. `tf.train.Example <https://www.tensorflow.org/api_docs/python/tf/train/Example>`_ # noqa: E501
  3962. records, with one Example record for each row in the dataset.
  3963. .. warning::
  3964. tf.train.Feature only natively stores ints, floats, and bytes,
  3965. so this function only supports datasets with these data types,
  3966. and will error if the dataset contains unsupported types.
  3967. This is only supported for datasets convertible to Arrow records.
  3968. To control the number of files, use :meth:`Dataset.repartition`.
  3969. Unless a custom filename provider is given, the format of the output
  3970. files is ``{uuid}_{block_idx}.tfrecords``, where ``uuid`` is a unique id
  3971. for the dataset.
  3972. Examples:
  3973. .. testcode::
  3974. :skipif: True
  3975. import ray
  3976. ds = ray.data.range(100)
  3977. ds.write_webdataset("s3://bucket/folder/")
  3978. Time complexity: O(dataset size / parallelism)
  3979. Args:
  3980. path: The path to the destination root directory, where tfrecords
  3981. files are written to.
  3982. filesystem: The filesystem implementation to write to.
  3983. try_create_dir: If ``True``, attempts to create all
  3984. directories in the destination path. Does nothing if all directories
  3985. already exist. Defaults to ``True``.
  3986. arrow_open_stream_args: kwargs passed to
  3987. ``pyarrow.fs.FileSystem.open_output_stream``
  3988. filename_provider: A :class:`~ray.data.datasource.FilenameProvider`
  3989. implementation. Use this parameter to customize what your filenames
  3990. look like.
  3991. min_rows_per_file: [Experimental] The target minimum number of rows to write
  3992. to each file. If ``None``, Ray Data writes a system-chosen number of
  3993. rows to each file. If the number of rows per block is larger than the
  3994. specified value, Ray Data writes the number of rows per block to each file.
  3995. The specified value is a hint, not a strict limit. Ray Data
  3996. might write more or fewer rows to each file.
  3997. ray_remote_args: Kwargs passed to :func:`ray.remote` in the write tasks.
  3998. concurrency: The maximum number of Ray tasks to run concurrently. Set this
  3999. to control number of tasks to run concurrently. This doesn't change the
  4000. total number of tasks run. By default, concurrency is dynamically
  4001. decided based on the available resources.
  4002. num_rows_per_file: [Deprecated] Use min_rows_per_file instead.
  4003. mode: Determines how to handle existing files. Valid modes are "overwrite", "error",
  4004. "ignore", "append". Defaults to "append".
  4005. NOTE: This method isn't atomic. "Overwrite" first deletes all the data
  4006. before writing to `path`.
  4007. """
  4008. effective_min_rows, _ = _validate_rows_per_file_args(
  4009. num_rows_per_file=num_rows_per_file, min_rows_per_file=min_rows_per_file
  4010. )
  4011. datasink = WebDatasetDatasink(
  4012. path,
  4013. encoder=encoder,
  4014. min_rows_per_file=effective_min_rows,
  4015. filesystem=filesystem,
  4016. try_create_dir=try_create_dir,
  4017. open_stream_args=arrow_open_stream_args,
  4018. filename_provider=filename_provider,
  4019. dataset_uuid=self._uuid,
  4020. mode=mode,
  4021. )
  4022. self.write_datasink(
  4023. datasink,
  4024. ray_remote_args=ray_remote_args,
  4025. concurrency=concurrency,
  4026. )
  4027. @ConsumptionAPI
  4028. @PublicAPI(api_group=IOC_API_GROUP)
  4029. def write_numpy(
  4030. self,
  4031. path: str,
  4032. *,
  4033. column: str,
  4034. filesystem: Optional["pyarrow.fs.FileSystem"] = None,
  4035. try_create_dir: bool = True,
  4036. arrow_open_stream_args: Optional[Dict[str, Any]] = None,
  4037. filename_provider: Optional[FilenameProvider] = None,
  4038. min_rows_per_file: Optional[int] = None,
  4039. ray_remote_args: Dict[str, Any] = None,
  4040. concurrency: Optional[int] = None,
  4041. num_rows_per_file: Optional[int] = None,
  4042. mode: SaveMode = SaveMode.APPEND,
  4043. ) -> None:
  4044. """Writes a column of the :class:`~ray.data.Dataset` to .npy files.
  4045. This is only supported for columns in the datasets that can be converted to
  4046. NumPy arrays.
  4047. The number of files is determined by the number of blocks in the dataset.
  4048. To control the number of number of blocks, call
  4049. :meth:`~ray.data.Dataset.repartition`.
  4050. By default, the format of the output files is ``{uuid}_{block_idx}.npy``,
  4051. where ``uuid`` is a unique id for the dataset. To modify this behavior,
  4052. implement a custom :class:`~ray.data.datasource.FilenameProvider`
  4053. and pass it in as the ``filename_provider`` argument.
  4054. Examples:
  4055. >>> import ray
  4056. >>> ds = ray.data.range(100)
  4057. >>> ds.write_numpy("local:///tmp/data/", column="id")
  4058. Time complexity: O(dataset size / parallelism)
  4059. Args:
  4060. path: The path to the destination root directory, where
  4061. the npy files are written to.
  4062. column: The name of the column that contains the data to
  4063. be written.
  4064. filesystem: The pyarrow filesystem implementation to write to.
  4065. These filesystems are specified in the
  4066. `pyarrow docs <https://arrow.apache.org/docs\
  4067. /python/api/filesystems.html#filesystem-implementations>`_.
  4068. Specify this if you need to provide specific configurations to the
  4069. filesystem. By default, the filesystem is automatically selected based
  4070. on the scheme of the paths. For example, if the path begins with
  4071. ``s3://``, the ``S3FileSystem`` is used.
  4072. try_create_dir: If ``True``, attempts to create all directories in
  4073. destination path. Does nothing if all directories already
  4074. exist. Defaults to ``True``.
  4075. arrow_open_stream_args: kwargs passed to
  4076. `pyarrow.fs.FileSystem.open_output_stream <https://arrow.apache.org\
  4077. /docs/python/generated/pyarrow.fs.FileSystem.html\
  4078. #pyarrow.fs.FileSystem.open_output_stream>`_, which is used when
  4079. opening the file to write to.
  4080. filename_provider: A :class:`~ray.data.datasource.FilenameProvider`
  4081. implementation. Use this parameter to customize what your filenames
  4082. look like.
  4083. min_rows_per_file: [Experimental] The target minimum number of rows to write
  4084. to each file. If ``None``, Ray Data writes a system-chosen number of
  4085. rows to each file. If the number of rows per block is larger than the
  4086. specified value, Ray Data writes the number of rows per block to each file.
  4087. The specified value is a hint, not a strict limit. Ray Data
  4088. might write more or fewer rows to each file.
  4089. ray_remote_args: kwargs passed to :func:`ray.remote` in the write tasks.
  4090. concurrency: The maximum number of Ray tasks to run concurrently. Set this
  4091. to control number of tasks to run concurrently. This doesn't change the
  4092. total number of tasks run. By default, concurrency is dynamically
  4093. decided based on the available resources.
  4094. num_rows_per_file: [Deprecated] Use min_rows_per_file instead.
  4095. mode: Determines how to handle existing files. Valid modes are "overwrite", "error",
  4096. "ignore", "append". Defaults to "append".
  4097. NOTE: This method isn't atomic. "Overwrite" first deletes all the data
  4098. before writing to `path`.
  4099. """
  4100. effective_min_rows, _ = _validate_rows_per_file_args(
  4101. num_rows_per_file=num_rows_per_file, min_rows_per_file=min_rows_per_file
  4102. )
  4103. datasink = NumpyDatasink(
  4104. path,
  4105. column,
  4106. min_rows_per_file=effective_min_rows,
  4107. filesystem=filesystem,
  4108. try_create_dir=try_create_dir,
  4109. open_stream_args=arrow_open_stream_args,
  4110. filename_provider=filename_provider,
  4111. dataset_uuid=self._uuid,
  4112. mode=mode,
  4113. )
  4114. self.write_datasink(
  4115. datasink,
  4116. ray_remote_args=ray_remote_args,
  4117. concurrency=concurrency,
  4118. )
  4119. @ConsumptionAPI
  4120. def write_sql(
  4121. self,
  4122. sql: str,
  4123. connection_factory: Callable[[], Connection],
  4124. ray_remote_args: Optional[Dict[str, Any]] = None,
  4125. concurrency: Optional[int] = None,
  4126. ) -> None:
  4127. """Write to a database that provides a
  4128. `Python DB API2-compliant <https://peps.python.org/pep-0249/>`_ connector.
  4129. .. note::
  4130. This method writes data in parallel using the DB API2 ``executemany``
  4131. method. To learn more about this method, see
  4132. `PEP 249 <https://peps.python.org/pep-0249/#executemany>`_.
  4133. Examples:
  4134. .. testcode::
  4135. import sqlite3
  4136. import ray
  4137. connection = sqlite3.connect("example.db")
  4138. connection.cursor().execute("CREATE TABLE movie(title, year, score)")
  4139. dataset = ray.data.from_items([
  4140. {"title": "Monty Python and the Holy Grail", "year": 1975, "score": 8.2},
  4141. {"title": "And Now for Something Completely Different", "year": 1971, "score": 7.5}
  4142. ])
  4143. dataset.write_sql(
  4144. "INSERT INTO movie VALUES(?, ?, ?)", lambda: sqlite3.connect("example.db")
  4145. )
  4146. result = connection.cursor().execute("SELECT * FROM movie ORDER BY year")
  4147. print(result.fetchall())
  4148. .. testoutput::
  4149. [('And Now for Something Completely Different', 1971, 7.5), ('Monty Python and the Holy Grail', 1975, 8.2)]
  4150. .. testcode::
  4151. :hide:
  4152. import os
  4153. os.remove("example.db")
  4154. Arguments:
  4155. sql: An ``INSERT INTO`` statement that specifies the table to write to. The
  4156. number of parameters must match the number of columns in the table.
  4157. connection_factory: A function that takes no arguments and returns a
  4158. Python DB API2
  4159. `Connection object <https://peps.python.org/pep-0249/#connection-objects>`_.
  4160. ray_remote_args: Keyword arguments passed to :func:`ray.remote` in the
  4161. write tasks.
  4162. concurrency: The maximum number of Ray tasks to run concurrently. Set this
  4163. to control number of tasks to run concurrently. This doesn't change the
  4164. total number of tasks run. By default, concurrency is dynamically
  4165. decided based on the available resources.
  4166. """ # noqa: E501
  4167. datasink = SQLDatasink(sql=sql, connection_factory=connection_factory)
  4168. self.write_datasink(
  4169. datasink,
  4170. ray_remote_args=ray_remote_args,
  4171. concurrency=concurrency,
  4172. )
  4173. @ConsumptionAPI
  4174. def write_snowflake(
  4175. self,
  4176. table: str,
  4177. connection_parameters: str,
  4178. *,
  4179. ray_remote_args: Dict[str, Any] = None,
  4180. concurrency: Optional[int] = None,
  4181. ):
  4182. """Write this ``Dataset`` to a Snowflake table.
  4183. Examples:
  4184. .. testcode::
  4185. :skipif: True
  4186. import ray
  4187. connection_parameters = dict(
  4188. user=...,
  4189. account="ABCDEFG-ABC12345",
  4190. password=...,
  4191. database="SNOWFLAKE_SAMPLE_DATA",
  4192. schema="TPCDS_SF100TCL"
  4193. )
  4194. ds = ray.data.read_parquet("s3://anonymous@ray-example-data/iris.parquet")
  4195. ds.write_snowflake("MY_DATABASE.MY_SCHEMA.IRIS", connection_parameters)
  4196. Args:
  4197. table: The name of the table to write to.
  4198. connection_parameters: Keyword arguments to pass to
  4199. ``snowflake.connector.connect``. To view supported parameters, read
  4200. https://docs.snowflake.com/developer-guide/python-connector/python-connector-api#functions.
  4201. ray_remote_args: Keyword arguments passed to :func:`ray.remote` in the
  4202. write tasks.
  4203. concurrency: The maximum number of Ray tasks to run concurrently. Set this
  4204. to control number of tasks to run concurrently. This doesn't change the
  4205. total number of tasks run. By default, concurrency is dynamically
  4206. decided based on the available resources.
  4207. """ # noqa: E501
  4208. import snowflake.connector
  4209. def snowflake_connection_factory():
  4210. return snowflake.connector.connect(**connection_parameters)
  4211. # Get column names from the dataset schema
  4212. column_names = self.schema().names
  4213. # Generate the SQL insert statement
  4214. columns_str = ", ".join(f'"{col}"' for col in column_names)
  4215. placeholders = ", ".join(["%s"] * len(column_names))
  4216. sql = f"INSERT INTO {table} ({columns_str}) VALUES ({placeholders})"
  4217. self.write_sql(
  4218. sql,
  4219. connection_factory=snowflake_connection_factory,
  4220. ray_remote_args=ray_remote_args,
  4221. concurrency=concurrency,
  4222. )
  4223. @PublicAPI(stability="alpha", api_group=IOC_API_GROUP)
  4224. @ConsumptionAPI
  4225. def write_mongo(
  4226. self,
  4227. uri: str,
  4228. database: str,
  4229. collection: str,
  4230. ray_remote_args: Dict[str, Any] = None,
  4231. concurrency: Optional[int] = None,
  4232. ) -> None:
  4233. """Writes the :class:`~ray.data.Dataset` to a MongoDB database.
  4234. This method is only supported for datasets convertible to pyarrow tables.
  4235. The number of parallel writes is determined by the number of blocks in the
  4236. dataset. To control the number of number of blocks, call
  4237. :meth:`~ray.data.Dataset.repartition`.
  4238. .. warning::
  4239. This method supports only a subset of the PyArrow's types, due to the
  4240. limitation of pymongoarrow which is used underneath. Writing unsupported
  4241. types fails on type checking. See all the supported types at:
  4242. https://mongo-arrow.readthedocs.io/en/stable/api/types.html.
  4243. .. note::
  4244. The records are inserted into MongoDB as new documents. If a record has
  4245. the _id field, this _id must be non-existent in MongoDB, otherwise the write
  4246. is rejected and fail (hence preexisting documents are protected from
  4247. being mutated). It's fine to not have _id field in record and MongoDB will
  4248. auto generate one at insertion.
  4249. Examples:
  4250. .. testcode::
  4251. :skipif: True
  4252. import ray
  4253. ds = ray.data.range(100)
  4254. ds.write_mongo(
  4255. uri="mongodb://username:password@mongodb0.example.com:27017/?authSource=admin",
  4256. database="my_db",
  4257. collection="my_collection"
  4258. )
  4259. Args:
  4260. uri: The URI to the destination MongoDB where the dataset is
  4261. written to. For the URI format, see details in the
  4262. `MongoDB docs <https://www.mongodb.com/docs/manual/reference\
  4263. /connection-string/>`_.
  4264. database: The name of the database. This database must exist otherwise
  4265. a ValueError is raised.
  4266. collection: The name of the collection in the database. This collection
  4267. must exist otherwise a ValueError is raised.
  4268. ray_remote_args: kwargs passed to :func:`ray.remote` in the write tasks.
  4269. concurrency: The maximum number of Ray tasks to run concurrently. Set this
  4270. to control number of tasks to run concurrently. This doesn't change the
  4271. total number of tasks run. By default, concurrency is dynamically
  4272. decided based on the available resources.
  4273. Raises:
  4274. ValueError: if ``database`` doesn't exist.
  4275. ValueError: if ``collection`` doesn't exist.
  4276. """
  4277. datasink = MongoDatasink(
  4278. uri=uri,
  4279. database=database,
  4280. collection=collection,
  4281. )
  4282. self.write_datasink(
  4283. datasink,
  4284. ray_remote_args=ray_remote_args,
  4285. concurrency=concurrency,
  4286. )
  4287. @ConsumptionAPI
  4288. def write_bigquery(
  4289. self,
  4290. project_id: str,
  4291. dataset: str,
  4292. max_retry_cnt: int = 10,
  4293. overwrite_table: Optional[bool] = True,
  4294. ray_remote_args: Dict[str, Any] = None,
  4295. concurrency: Optional[int] = None,
  4296. ) -> None:
  4297. """Write the dataset to a BigQuery dataset table.
  4298. To control the number of parallel write tasks, use ``.repartition()``
  4299. before calling this method.
  4300. Examples:
  4301. .. testcode::
  4302. :skipif: True
  4303. import ray
  4304. import pandas as pd
  4305. docs = [{"title": "BigQuery Datasource test"} for key in range(4)]
  4306. ds = ray.data.from_pandas(pd.DataFrame(docs))
  4307. ds.write_bigquery(
  4308. project_id="my_project_id",
  4309. dataset="my_dataset_table",
  4310. overwrite_table=True
  4311. )
  4312. Args:
  4313. project_id: The name of the associated Google Cloud Project that hosts
  4314. the dataset to read. For more information, see details in
  4315. `Creating and managing projects <https://cloud.google.com/resource-manager/docs/creating-managing-projects>`_.
  4316. dataset: The name of the dataset in the format of ``dataset_id.table_id``.
  4317. The dataset is created if it doesn't already exist.
  4318. max_retry_cnt: The maximum number of retries that an individual block write
  4319. is retried due to BigQuery rate limiting errors. This isn't
  4320. related to Ray fault tolerance retries. The default number of retries
  4321. is 10.
  4322. overwrite_table: Whether the write will overwrite the table if it already
  4323. exists. The default behavior is to overwrite the table.
  4324. ``overwrite_table=False`` will append to the table if it exists.
  4325. ray_remote_args: Kwargs passed to :func:`ray.remote` in the write tasks.
  4326. concurrency: The maximum number of Ray tasks to run concurrently. Set this
  4327. to control number of tasks to run concurrently. This doesn't change the
  4328. total number of tasks run. By default, concurrency is dynamically
  4329. decided based on the available resources.
  4330. """ # noqa: E501
  4331. if ray_remote_args is None:
  4332. ray_remote_args = {}
  4333. # Each write task will launch individual remote tasks to write each block
  4334. # To avoid duplicate block writes, the write task should not be retried
  4335. if ray_remote_args.get("max_retries", 0) != 0:
  4336. warnings.warn(
  4337. "The max_retries of a BigQuery Write Task should be set to 0"
  4338. " to avoid duplicate writes."
  4339. )
  4340. else:
  4341. ray_remote_args["max_retries"] = 0
  4342. datasink = BigQueryDatasink(
  4343. project_id=project_id,
  4344. dataset=dataset,
  4345. max_retry_cnt=max_retry_cnt,
  4346. overwrite_table=overwrite_table,
  4347. )
  4348. self.write_datasink(
  4349. datasink,
  4350. ray_remote_args=ray_remote_args,
  4351. concurrency=concurrency,
  4352. )
  4353. @ConsumptionAPI
  4354. def write_clickhouse(
  4355. self,
  4356. table: str,
  4357. dsn: str,
  4358. *,
  4359. mode: SinkMode = SinkMode.CREATE,
  4360. schema: Optional["pyarrow.Schema"] = None,
  4361. client_settings: Optional[Dict[str, Any]] = None,
  4362. client_kwargs: Optional[Dict[str, Any]] = None,
  4363. table_settings: Optional[ClickHouseTableSettings] = None,
  4364. max_insert_block_rows: Optional[int] = None,
  4365. ray_remote_args: Dict[str, Any] = None,
  4366. concurrency: Optional[int] = None,
  4367. ) -> None:
  4368. """Write the dataset to a ClickHouse dataset table.
  4369. To control the number of parallel write tasks, use ``.repartition()``
  4370. before calling this method.
  4371. Examples:
  4372. .. testcode::
  4373. :skipif: True
  4374. import ray
  4375. import pyarrow as pa
  4376. import pandas as pd
  4377. docs = [{"title": "ClickHouse Datasink test"} for key in range(4)]
  4378. ds = ray.data.from_pandas(pd.DataFrame(docs))
  4379. user_schema = pa.schema(
  4380. [
  4381. ("id", pa.int64()),
  4382. ("title", pa.string()),
  4383. ]
  4384. )
  4385. ds.write_clickhouse(
  4386. table="default.my_table",
  4387. dsn="clickhouse+http://user:pass@localhost:8123/default",
  4388. mode=ray.data.SinkMode.OVERWRITE,
  4389. schema=user_schema,
  4390. table_settings=ray.data.ClickHouseTableSettings(
  4391. engine="ReplacingMergeTree()",
  4392. order_by="id",
  4393. ),
  4394. )
  4395. Args:
  4396. table: Fully qualified table identifier (e.g., "default.my_table").
  4397. The table is created if it doesn't already exist.
  4398. dsn: A string in DSN (Data Source Name) HTTP format
  4399. (e.g., "clickhouse+http://username:password@host:8123/default").
  4400. For more information, see `ClickHouse Connection String doc
  4401. <https://clickhouse.com/docs/en/integrations/sql-clients/cli#connection_string>`_.
  4402. mode: One of SinkMode.CREATE, SinkMode.APPEND, or
  4403. SinkMode.OVERWRITE:
  4404. * SinkMode.CREATE: Create a new table; fail if it already exists. If the table
  4405. does not exist, you must provide a schema (either via the `schema`
  4406. argument or as part of the dataset's first block).
  4407. * SinkMode.APPEND: If the table exists, append data to it; if not, create
  4408. the table using the provided or inferred schema. If the table does
  4409. not exist, you must supply a schema.
  4410. * SinkMode.OVERWRITE: Drop any existing table of this name, then create
  4411. a new table and write data to it. You **must** provide a schema in
  4412. this case, as the table is being re-created.
  4413. schema: Optional :class:`pyarrow.Schema` specifying column definitions.
  4414. This is mandatory if you are creating a new table (i.e., table doesn't
  4415. exist in CREATE or APPEND mode) or overwriting an existing table (OVERWRITE).
  4416. When appending to an existing table, a schema is optional, though you can
  4417. provide one to enforce column types or cast data as needed. If omitted
  4418. (and the table already exists), the existing table definition will be used.
  4419. If omitted and the table must be created, the schema is inferred from
  4420. the first block in the dataset.
  4421. client_settings: Optional ClickHouse server settings to be used with the
  4422. session/every request. For more information, see
  4423. `ClickHouse Client Settings doc
  4424. <https://clickhouse.com/docs/en/integrations/python#settings-argument>`_.
  4425. client_kwargs: Optional keyword arguments to pass to the
  4426. ClickHouse client. For more information, see
  4427. `ClickHouse Core Settings doc
  4428. <https://clickhouse.com/docs/en/integrations/python#additional-options>`_.
  4429. table_settings: An optional :class:`ClickHouseTableSettings` dataclass
  4430. that specifies additional table creation instructions, including:
  4431. * engine (default: `"MergeTree()"`):
  4432. Specifies the engine for the `CREATE TABLE` statement.
  4433. * order_by:
  4434. Sets the `ORDER BY` clause in the `CREATE TABLE` statement, iff not provided.
  4435. When overwriting an existing table, its previous `ORDER BY` (if any) is reused.
  4436. Otherwise, a "best" column is selected automatically (favoring a timestamp column,
  4437. then a non-string column, and lastly the first column).
  4438. * partition_by:
  4439. If present, adds a `PARTITION BY <value>` clause to the `CREATE TABLE` statement.
  4440. * primary_key:
  4441. If present, adds a `PRIMARY KEY (<value>)` clause.
  4442. * settings:
  4443. Appends a `SETTINGS <value>` clause to the `CREATE TABLE` statement, allowing
  4444. custom ClickHouse settings.
  4445. max_insert_block_rows: If you have extremely large blocks, specifying
  4446. a limit here will chunk the insert into multiple smaller insert calls.
  4447. Defaults to None (no chunking).
  4448. ray_remote_args: Kwargs passed to :func:`ray.remote` in the write tasks.
  4449. concurrency: The maximum number of Ray tasks to run concurrently. Set this
  4450. to control number of tasks to run concurrently. This doesn't change the
  4451. total number of tasks run. By default, concurrency is dynamically
  4452. decided based on the available resources.
  4453. """ # noqa: E501
  4454. datasink = ClickHouseDatasink(
  4455. table=table,
  4456. dsn=dsn,
  4457. mode=mode,
  4458. schema=schema,
  4459. client_settings=client_settings,
  4460. client_kwargs=client_kwargs,
  4461. table_settings=table_settings,
  4462. max_insert_block_rows=max_insert_block_rows,
  4463. )
  4464. self.write_datasink(
  4465. datasink,
  4466. ray_remote_args=ray_remote_args,
  4467. concurrency=concurrency,
  4468. )
  4469. @ConsumptionAPI
  4470. def write_lance(
  4471. self,
  4472. path: str,
  4473. *,
  4474. schema: Optional["pyarrow.Schema"] = None,
  4475. mode: Literal["create", "append", "overwrite"] = "create",
  4476. min_rows_per_file: int = 1024 * 1024,
  4477. max_rows_per_file: int = 64 * 1024 * 1024,
  4478. data_storage_version: Optional[str] = None,
  4479. storage_options: Optional[Dict[str, Any]] = None,
  4480. ray_remote_args: Dict[str, Any] = None,
  4481. concurrency: Optional[int] = None,
  4482. ) -> None:
  4483. """Write the dataset to a Lance dataset.
  4484. Examples:
  4485. .. testcode::
  4486. import ray
  4487. import pandas as pd
  4488. docs = [{"title": "Lance data sink test"} for key in range(4)]
  4489. ds = ray.data.from_pandas(pd.DataFrame(docs))
  4490. ds.write_lance("/tmp/data/")
  4491. Args:
  4492. path: The path to the destination Lance dataset.
  4493. schema: The schema of the dataset. If not provided, it is inferred from the data.
  4494. mode: The write mode. Can be "create", "append", or "overwrite".
  4495. min_rows_per_file: The minimum number of rows per file.
  4496. max_rows_per_file: The maximum number of rows per file.
  4497. data_storage_version: The version of the data storage format to use. Newer versions are more
  4498. efficient but require newer versions of lance to read. The default is
  4499. "legacy" which will use the legacy v1 version. See the user guide
  4500. for more details.
  4501. storage_options: The storage options for the writer. Default is None.
  4502. """
  4503. datasink = LanceDatasink(
  4504. path,
  4505. schema=schema,
  4506. mode=mode,
  4507. min_rows_per_file=min_rows_per_file,
  4508. max_rows_per_file=max_rows_per_file,
  4509. data_storage_version=data_storage_version,
  4510. storage_options=storage_options,
  4511. )
  4512. self.write_datasink(
  4513. datasink,
  4514. ray_remote_args=ray_remote_args,
  4515. concurrency=concurrency,
  4516. )
  4517. @ConsumptionAPI(pattern="Time complexity:")
  4518. def write_datasink(
  4519. self,
  4520. datasink: Datasink,
  4521. *,
  4522. ray_remote_args: Dict[str, Any] = None,
  4523. concurrency: Optional[int] = None,
  4524. ) -> None:
  4525. """Writes the dataset to a custom :class:`~ray.data.Datasink`.
  4526. Time complexity: O(dataset size / parallelism)
  4527. Args:
  4528. datasink: The :class:`~ray.data.Datasink` to write to.
  4529. ray_remote_args: Kwargs passed to :func:`ray.remote` in the write tasks.
  4530. concurrency: The maximum number of Ray tasks to run concurrently. Set this
  4531. to control number of tasks to run concurrently. This doesn't change the
  4532. total number of tasks run. By default, concurrency is dynamically
  4533. decided based on the available resources.
  4534. """ # noqa: E501
  4535. if ray_remote_args is None:
  4536. ray_remote_args = {}
  4537. if not datasink.supports_distributed_writes:
  4538. if ray.util.client.ray.is_connected():
  4539. raise ValueError(
  4540. "If you're using Ray Client, Ray Data won't schedule write tasks "
  4541. "on the driver's node."
  4542. )
  4543. ray_remote_args["scheduling_strategy"] = NodeAffinitySchedulingStrategy(
  4544. ray.get_runtime_context().get_node_id(),
  4545. soft=False,
  4546. )
  4547. plan = self._plan.copy()
  4548. write_op = Write(
  4549. self._logical_plan.dag,
  4550. datasink,
  4551. ray_remote_args=ray_remote_args,
  4552. compute=TaskPoolStrategy(concurrency),
  4553. )
  4554. logical_plan = LogicalPlan(write_op, self.context)
  4555. try:
  4556. # Call on_write_start for _FileDatasink before execution to handle
  4557. # SaveMode checks (ERROR raises, OVERWRITE deletes contents, IGNORE skips)
  4558. # and directory creation. For other datasinks, on_write_start is called
  4559. # automatically by the Write operator when the first input bundle arrives.
  4560. if isinstance(datasink, _FileDatasink):
  4561. datasink.on_write_start()
  4562. # TODO (https://github.com/ray-project/ray/issues/59326): There should be no special handling for skipping writes.
  4563. if datasink._skip_write:
  4564. logger.info(
  4565. f"Ignoring write because {datasink.path} already exists"
  4566. )
  4567. return
  4568. self._write_ds = Dataset(plan, logical_plan).materialize()
  4569. iter_, stats, _ = self._write_ds._execute_to_iterator()
  4570. write_results = []
  4571. for bundle in iter_:
  4572. res = ray.get(bundle.block_refs)
  4573. # Generate write result report
  4574. write_results.append(_gen_datasink_write_result(res))
  4575. combined_write_result = WriteResult.combine(*write_results)
  4576. logger.info(
  4577. "Data sink %s finished. %d rows and %s data written.",
  4578. datasink.get_name(),
  4579. combined_write_result.num_rows,
  4580. memory_string(combined_write_result.size_bytes),
  4581. )
  4582. datasink.on_write_complete(combined_write_result)
  4583. except Exception as e:
  4584. datasink.on_write_failed(e)
  4585. raise
  4586. @ConsumptionAPI(
  4587. delegate=(
  4588. "Calling any of the consumption methods on the returned ``DataIterator``"
  4589. ),
  4590. pattern="Returns:",
  4591. )
  4592. @PublicAPI(api_group=CD_API_GROUP)
  4593. def iterator(self) -> DataIterator:
  4594. """Return a :class:`~ray.data.DataIterator` over this dataset.
  4595. Don't call this method directly. Use it internally.
  4596. Returns:
  4597. A :class:`~ray.data.DataIterator` over this dataset.
  4598. """
  4599. return DataIteratorImpl(self)
  4600. @ConsumptionAPI
  4601. @PublicAPI(api_group=CD_API_GROUP)
  4602. def iter_rows(self) -> Iterable[Dict[str, Any]]:
  4603. """Return an iterable over the rows in this dataset.
  4604. Examples:
  4605. >>> import ray
  4606. >>> for row in ray.data.range(3).iter_rows():
  4607. ... print(row)
  4608. {'id': 0}
  4609. {'id': 1}
  4610. {'id': 2}
  4611. Time complexity: O(1)
  4612. Returns:
  4613. An iterable over the rows in this dataset.
  4614. """
  4615. return self.iterator().iter_rows()
  4616. @ConsumptionAPI
  4617. @PublicAPI(api_group=CD_API_GROUP)
  4618. def iter_batches(
  4619. self,
  4620. *,
  4621. prefetch_batches: int = 1,
  4622. batch_size: Optional[int] = 256,
  4623. batch_format: Optional[str] = "default",
  4624. drop_last: bool = False,
  4625. local_shuffle_buffer_size: Optional[int] = None,
  4626. local_shuffle_seed: Optional[int] = None,
  4627. _collate_fn: Optional[Callable[[DataBatch], CollatedData]] = None,
  4628. ) -> Iterable[DataBatch]:
  4629. """Return an iterable over batches of data.
  4630. This method is useful for model training.
  4631. Examples:
  4632. .. testcode::
  4633. import ray
  4634. ds = ray.data.read_images("example://image-datasets/simple")
  4635. for batch in ds.iter_batches(batch_size=2, batch_format="numpy"):
  4636. print(batch)
  4637. .. testoutput::
  4638. :options: +MOCK
  4639. {'image': array([[[[...]]]], dtype=uint8)}
  4640. ...
  4641. {'image': array([[[[...]]]], dtype=uint8)}
  4642. Time complexity: O(1)
  4643. Args:
  4644. prefetch_batches: The number of batches to fetch ahead of the current batch
  4645. to fetch. If set to greater than 0, a separate threadpool is used
  4646. to fetch the objects to the local node and format the batches. Defaults
  4647. to 1.
  4648. batch_size: The number of rows in each batch, or ``None`` to use entire
  4649. blocks as batches (blocks may contain different numbers of rows).
  4650. The final batch may include fewer than ``batch_size`` rows if
  4651. ``drop_last`` is ``False``. Defaults to 256.
  4652. batch_format: If ``"default"`` or ``"numpy"``, batches are
  4653. ``Dict[str, numpy.ndarray]``. If ``"pandas"``, batches are
  4654. ``pandas.DataFrame``.
  4655. drop_last: Whether to drop the last batch if it's incomplete.
  4656. local_shuffle_buffer_size: If not ``None``, the data is randomly shuffled
  4657. using a local in-memory shuffle buffer, and this value serves as the
  4658. minimum number of rows that must be in the local in-memory shuffle
  4659. buffer in order to yield a batch. When there are no more rows to add to
  4660. the buffer, the remaining rows in the buffer are drained.
  4661. local_shuffle_seed: The seed to use for the local random shuffle.
  4662. Returns:
  4663. An iterable over batches of data.
  4664. """
  4665. batch_format = _apply_batch_format(batch_format)
  4666. return self.iterator()._iter_batches(
  4667. prefetch_batches=prefetch_batches,
  4668. batch_size=batch_size,
  4669. batch_format=batch_format,
  4670. drop_last=drop_last,
  4671. local_shuffle_buffer_size=local_shuffle_buffer_size,
  4672. local_shuffle_seed=local_shuffle_seed,
  4673. _collate_fn=_collate_fn,
  4674. )
  4675. @ConsumptionAPI
  4676. @PublicAPI(api_group=CD_API_GROUP)
  4677. def iter_torch_batches(
  4678. self,
  4679. *,
  4680. prefetch_batches: int = 1,
  4681. batch_size: Optional[int] = 256,
  4682. dtypes: Optional[Union["torch.dtype", Dict[str, "torch.dtype"]]] = None,
  4683. device: Union[TorchDeviceType, Literal["auto"]] = "auto",
  4684. collate_fn: Optional[Callable[[Dict[str, np.ndarray]], CollatedData]] = None,
  4685. drop_last: bool = False,
  4686. local_shuffle_buffer_size: Optional[int] = None,
  4687. local_shuffle_seed: Optional[int] = None,
  4688. pin_memory: bool = False,
  4689. ) -> Iterable[TorchBatchType]:
  4690. """Return an iterable over batches of data represented as Torch tensors.
  4691. This iterable yields batches of type ``Dict[str, torch.Tensor]``.
  4692. For more flexibility, call :meth:`~Dataset.iter_batches` and manually convert
  4693. your data to Torch tensors.
  4694. Examples:
  4695. >>> import ray
  4696. >>> for batch in ray.data.range(
  4697. ... 12,
  4698. ... ).iter_torch_batches(batch_size=4):
  4699. ... print(batch)
  4700. {'id': tensor([0, 1, 2, 3])}
  4701. {'id': tensor([4, 5, 6, 7])}
  4702. {'id': tensor([ 8, 9, 10, 11])}
  4703. Use the ``collate_fn`` to customize how the tensor batch is created.
  4704. >>> from typing import Any, Dict
  4705. >>> import torch
  4706. >>> import numpy as np
  4707. >>> import ray
  4708. >>> def collate_fn(batch: Dict[str, np.ndarray]) -> Any:
  4709. ... return torch.stack(
  4710. ... [torch.as_tensor(array) for array in batch.values()],
  4711. ... axis=1
  4712. ... )
  4713. >>> dataset = ray.data.from_items([
  4714. ... {"col_1": 1, "col_2": 2},
  4715. ... {"col_1": 3, "col_2": 4}])
  4716. >>> for batch in dataset.iter_torch_batches(collate_fn=collate_fn):
  4717. ... print(batch)
  4718. tensor([[1, 2],
  4719. [3, 4]])
  4720. Time complexity: O(1)
  4721. Args:
  4722. prefetch_batches: The number of batches to fetch ahead of the current batch
  4723. to fetch. If set to greater than 0, a separate threadpool is used
  4724. to fetch the objects to the local node, format the batches, and apply
  4725. the ``collate_fn``. Defaults to 1.
  4726. batch_size: The number of rows in each batch, or ``None`` to use entire
  4727. blocks as batches (blocks may contain different number of rows).
  4728. The final batch may include fewer than ``batch_size`` rows if
  4729. ``drop_last`` is ``False``. Defaults to 256.
  4730. dtypes: The Torch dtype(s) for the created tensor(s); if ``None``, the dtype
  4731. is inferred from the tensor data. You can't use this parameter with
  4732. ``collate_fn``.
  4733. device: The device on which the tensor should be placed. Defaults to
  4734. "auto" which moves the tensors to the appropriate device when the
  4735. Dataset is passed to Ray Train and ``collate_fn`` is not provided.
  4736. Otherwise, defaults to CPU. You can't use this parameter with
  4737. ``collate_fn``.
  4738. collate_fn: A function to convert a Numpy batch to a PyTorch tensor batch.
  4739. When this parameter is specified, the user should manually handle the
  4740. host to device data transfer outside of collate_fn.
  4741. This is useful for further processing the data after it has been
  4742. batched. Potential use cases include collating along a dimension other
  4743. than the first, padding sequences of various lengths, or generally
  4744. handling batches of different length tensors. If not provided, the
  4745. default collate function is used which simply converts the batch of
  4746. numpy arrays to a batch of PyTorch tensors. This API is still
  4747. experimental and is subject to change. You can't use this parameter in
  4748. conjunction with ``dtypes`` or ``device``.
  4749. drop_last: Whether to drop the last batch if it's incomplete.
  4750. local_shuffle_buffer_size: If not ``None``, the data is randomly shuffled
  4751. using a local in-memory shuffle buffer, and this value serves as the
  4752. minimum number of rows that must be in the local in-memory shuffle
  4753. buffer in order to yield a batch. When there are no more rows to add to
  4754. the buffer, the remaining rows in the buffer are drained.
  4755. ``batch_size`` must also be specified when using local shuffling.
  4756. local_shuffle_seed: The seed to use for the local random shuffle.
  4757. pin_memory: [Alpha] If True, copies the tensor to pinned memory. Note that
  4758. `pin_memory` is only supported when using `DefaultCollateFn`.
  4759. Returns:
  4760. An iterable over Torch Tensor batches.
  4761. .. seealso::
  4762. :meth:`Dataset.iter_batches`
  4763. Call this method to manually convert your data to Torch tensors.
  4764. """ # noqa: E501
  4765. return self.iterator().iter_torch_batches(
  4766. prefetch_batches=prefetch_batches,
  4767. batch_size=batch_size,
  4768. dtypes=dtypes,
  4769. device=device,
  4770. collate_fn=collate_fn,
  4771. drop_last=drop_last,
  4772. local_shuffle_buffer_size=local_shuffle_buffer_size,
  4773. local_shuffle_seed=local_shuffle_seed,
  4774. pin_memory=pin_memory,
  4775. )
  4776. @ConsumptionAPI
  4777. @Deprecated
  4778. def iter_tf_batches(
  4779. self,
  4780. *,
  4781. prefetch_batches: int = 1,
  4782. batch_size: Optional[int] = 256,
  4783. dtypes: Optional[Union["tf.dtypes.DType", Dict[str, "tf.dtypes.DType"]]] = None,
  4784. drop_last: bool = False,
  4785. local_shuffle_buffer_size: Optional[int] = None,
  4786. local_shuffle_seed: Optional[int] = None,
  4787. ) -> Iterable[TensorFlowTensorBatchType]:
  4788. """Return an iterable over batches of data represented as TensorFlow tensors.
  4789. This iterable yields batches of type ``Dict[str, tf.Tensor]``.
  4790. For more flexibility, call :meth:`~Dataset.iter_batches` and manually convert
  4791. your data to TensorFlow tensors.
  4792. .. tip::
  4793. If you don't need the additional flexibility provided by this method,
  4794. consider using :meth:`~ray.data.Dataset.to_tf` instead. It's easier
  4795. to use.
  4796. Examples:
  4797. .. testcode::
  4798. import ray
  4799. ds = ray.data.read_csv("s3://anonymous@air-example-data/iris.csv")
  4800. tf_dataset = ds.to_tf(
  4801. feature_columns="sepal length (cm)",
  4802. label_columns="target",
  4803. batch_size=2
  4804. )
  4805. for features, labels in tf_dataset:
  4806. print(features, labels)
  4807. .. testoutput::
  4808. tf.Tensor([5.1 4.9], shape=(2,), dtype=float64) tf.Tensor([0 0], shape=(2,), dtype=int64)
  4809. ...
  4810. tf.Tensor([6.2 5.9], shape=(2,), dtype=float64) tf.Tensor([2 2], shape=(2,), dtype=int64)
  4811. Time complexity: O(1)
  4812. Args:
  4813. prefetch_batches: The number of batches to fetch ahead of the current batch
  4814. to fetch. If set to greater than 0, a separate threadpool is used
  4815. to fetch the objects to the local node, format the batches, and apply
  4816. the ``collate_fn``. Defaults to 1.
  4817. batch_size: The number of rows in each batch, or ``None`` to use entire
  4818. blocks as batches (blocks may contain different numbers of rows).
  4819. The final batch may include fewer than ``batch_size`` rows if
  4820. ``drop_last`` is ``False``. Defaults to 256.
  4821. dtypes: The TensorFlow dtype(s) for the created tensor(s); if ``None``, the
  4822. dtype is inferred from the tensor data.
  4823. drop_last: Whether to drop the last batch if it's incomplete.
  4824. local_shuffle_buffer_size: If not ``None``, the data is randomly shuffled
  4825. using a local in-memory shuffle buffer, and this value serves as the
  4826. minimum number of rows that must be in the local in-memory shuffle
  4827. buffer in order to yield a batch. When there are no more rows to add to
  4828. the buffer, the remaining rows in the buffer are drained.
  4829. ``batch_size`` must also be specified when using local shuffling.
  4830. local_shuffle_seed: The seed to use for the local random shuffle.
  4831. Returns:
  4832. An iterable over TensorFlow Tensor batches.
  4833. .. seealso::
  4834. :meth:`Dataset.iter_batches`
  4835. Call this method to manually convert your data to TensorFlow tensors.
  4836. """ # noqa: E501
  4837. warnings.warn(
  4838. "`iter_tf_batches` is deprecated and will be removed after May 2025. Use "
  4839. "`to_tf` instead.",
  4840. DeprecationWarning,
  4841. )
  4842. return self.iterator().iter_tf_batches(
  4843. prefetch_batches=prefetch_batches,
  4844. batch_size=batch_size,
  4845. dtypes=dtypes,
  4846. drop_last=drop_last,
  4847. local_shuffle_buffer_size=local_shuffle_buffer_size,
  4848. local_shuffle_seed=local_shuffle_seed,
  4849. )
  4850. @ConsumptionAPI
  4851. @PublicAPI(api_group=IOC_API_GROUP)
  4852. def to_tf(
  4853. self,
  4854. feature_columns: Union[str, List[str]],
  4855. label_columns: Union[str, List[str]],
  4856. *,
  4857. additional_columns: Union[str, List[str]] = None,
  4858. prefetch_batches: int = 1,
  4859. batch_size: int = 1,
  4860. drop_last: bool = False,
  4861. local_shuffle_buffer_size: Optional[int] = None,
  4862. local_shuffle_seed: Optional[int] = None,
  4863. feature_type_spec: Union["tf.TypeSpec", Dict[str, "tf.TypeSpec"]] = None,
  4864. label_type_spec: Union["tf.TypeSpec", Dict[str, "tf.TypeSpec"]] = None,
  4865. additional_type_spec: Union["tf.TypeSpec", Dict[str, "tf.TypeSpec"]] = None,
  4866. ) -> "tf.data.Dataset":
  4867. """Return a `TensorFlow Dataset <https://www.tensorflow.org/api_docs/python/tf/data/Dataset/>`_
  4868. over this :class:`~ray.data.Dataset`.
  4869. .. warning::
  4870. If your :class:`~ray.data.Dataset` contains ragged tensors, this method errors.
  4871. To prevent errors, :ref:`resize your tensors <transforming_tensors>`.
  4872. Examples:
  4873. >>> import ray
  4874. >>> ds = ray.data.read_csv("s3://anonymous@air-example-data/iris.csv")
  4875. >>> ds
  4876. Dataset(num_rows=?, schema=...)
  4877. If your model accepts a single tensor as input, specify a single feature column.
  4878. >>> ds.to_tf(feature_columns="sepal length (cm)", label_columns="target")
  4879. <_OptionsDataset element_spec=(TensorSpec(shape=(None,), dtype=tf.float64, name='sepal length (cm)'), TensorSpec(shape=(None,), dtype=tf.int64, name='target'))>
  4880. If your model accepts a dictionary as input, specify a list of feature columns.
  4881. >>> ds.to_tf(["sepal length (cm)", "sepal width (cm)"], "target")
  4882. <_OptionsDataset element_spec=({'sepal length (cm)': TensorSpec(shape=(None,), dtype=tf.float64, name='sepal length (cm)'), 'sepal width (cm)': TensorSpec(shape=(None,), dtype=tf.float64, name='sepal width (cm)')}, TensorSpec(shape=(None,), dtype=tf.int64, name='target'))>
  4883. If your dataset contains multiple features but your model accepts a single
  4884. tensor as input, combine features with
  4885. :class:`~ray.data.preprocessors.Concatenator`.
  4886. >>> from ray.data.preprocessors import Concatenator
  4887. >>> columns_to_concat = ["sepal length (cm)", "sepal width (cm)", "petal length (cm)", "petal width (cm)"]
  4888. >>> preprocessor = Concatenator(columns=columns_to_concat, output_column_name="features")
  4889. >>> ds = preprocessor.transform(ds)
  4890. >>> ds
  4891. Concatenator
  4892. +- Dataset(num_rows=?, schema=...)
  4893. >>> ds.to_tf("features", "target")
  4894. <_OptionsDataset element_spec=(TensorSpec(shape=(None, 4), dtype=tf.float64, name='features'), TensorSpec(shape=(None,), dtype=tf.int64, name='target'))>
  4895. If your model accepts different types, shapes, or names of tensors as input, specify the type spec.
  4896. If type specs are not specified, they are automatically inferred from the schema of the dataset.
  4897. >>> import tensorflow as tf
  4898. >>> ds.to_tf(
  4899. ... feature_columns="features",
  4900. ... label_columns="target",
  4901. ... feature_type_spec=tf.TensorSpec(shape=(None, 4), dtype=tf.float32, name="features"),
  4902. ... label_type_spec=tf.TensorSpec(shape=(None,), dtype=tf.float32, name="label")
  4903. ... )
  4904. <_OptionsDataset element_spec=(TensorSpec(shape=(None, 4), dtype=tf.float32, name='features'), TensorSpec(shape=(None,), dtype=tf.float32, name='label'))>
  4905. If your model accepts additional metadata aside from features and label, specify a single additional column or a list of additional columns.
  4906. A common use case is to include sample weights in the data samples and train a ``tf.keras.Model`` with ``tf.keras.Model.fit``.
  4907. >>> import pandas as pd
  4908. >>> ds = ds.add_column("sample weights", lambda df: pd.Series([1] * len(df)))
  4909. >>> ds.to_tf(feature_columns="features", label_columns="target", additional_columns="sample weights")
  4910. <_OptionsDataset element_spec=(TensorSpec(shape=(None, 4), dtype=tf.float64, name='features'), TensorSpec(shape=(None,), dtype=tf.int64, name='target'), TensorSpec(shape=(None,), dtype=tf.int64, name='sample weights'))>
  4911. If your model accepts different types, shapes, or names for the additional metadata, specify the type spec of the additional column.
  4912. >>> ds.to_tf(
  4913. ... feature_columns="features",
  4914. ... label_columns="target",
  4915. ... additional_columns="sample weights",
  4916. ... additional_type_spec=tf.TensorSpec(shape=(None,), dtype=tf.float32, name="weight")
  4917. ... )
  4918. <_OptionsDataset element_spec=(TensorSpec(shape=(None, 4), dtype=tf.float64, name='features'), TensorSpec(shape=(None,), dtype=tf.int64, name='target'), TensorSpec(shape=(None,), dtype=tf.float32, name='weight'))>
  4919. Args:
  4920. feature_columns: Columns that correspond to model inputs. If this is a
  4921. string, the input data is a tensor. If this is a list, the input data
  4922. is a ``dict`` that maps column names to their tensor representation.
  4923. label_columns: Columns that correspond to model targets. If this is a
  4924. string, the target data is a tensor. If this is a list, the target data
  4925. is a ``dict`` that maps column names to their tensor representation.
  4926. additional_columns: Columns that correspond to sample weights or other metadata.
  4927. If this is a string, the weight data is a tensor. If this is a list, the
  4928. weight data is a ``dict`` that maps column names to their tensor representation.
  4929. prefetch_batches: The number of batches to fetch ahead of the current batch
  4930. to fetch. If set to greater than 0, a separate threadpool is used
  4931. to fetch the objects to the local node, format the batches, and apply
  4932. the collate_fn. Defaults to 1.
  4933. batch_size: Record batch size. Defaults to 1.
  4934. drop_last: Set to True to drop the last incomplete batch,
  4935. if the dataset size is not divisible by the batch size. If
  4936. False and the size of the stream is not divisible by the batch
  4937. size, then the last batch is smaller. Defaults to False.
  4938. local_shuffle_buffer_size: If non-None, the data is randomly shuffled
  4939. using a local in-memory shuffle buffer, and this value will serve as the
  4940. minimum number of rows that must be in the local in-memory shuffle
  4941. buffer in order to yield a batch. When there are no more rows to add to
  4942. the buffer, the remaining rows in the buffer is drained. This
  4943. buffer size must be greater than or equal to ``batch_size``, and
  4944. therefore ``batch_size`` must also be specified when using local
  4945. shuffling.
  4946. local_shuffle_seed: The seed to use for the local random shuffle.
  4947. feature_type_spec: The `tf.TypeSpec` of `feature_columns`. If there is
  4948. only one column, specify a `tf.TypeSpec`. If there are multiple columns,
  4949. specify a ``dict`` that maps column names to their `tf.TypeSpec`.
  4950. Default is `None` to automatically infer the type of each column.
  4951. label_type_spec: The `tf.TypeSpec` of `label_columns`. If there is
  4952. only one column, specify a `tf.TypeSpec`. If there are multiple columns,
  4953. specify a ``dict`` that maps column names to their `tf.TypeSpec`.
  4954. Default is `None` to automatically infer the type of each column.
  4955. additional_type_spec: The `tf.TypeSpec` of `additional_columns`. If there
  4956. is only one column, specify a `tf.TypeSpec`. If there are multiple
  4957. columns, specify a ``dict`` that maps column names to their `tf.TypeSpec`.
  4958. Default is `None` to automatically infer the type of each column.
  4959. Returns:
  4960. A `TensorFlow Dataset`_ that yields inputs and targets.
  4961. .. seealso::
  4962. :meth:`~ray.data.Dataset.iter_tf_batches`
  4963. Call this method if you need more flexibility.
  4964. """ # noqa: E501
  4965. return self.iterator().to_tf(
  4966. feature_columns=feature_columns,
  4967. label_columns=label_columns,
  4968. additional_columns=additional_columns,
  4969. prefetch_batches=prefetch_batches,
  4970. drop_last=drop_last,
  4971. batch_size=batch_size,
  4972. local_shuffle_buffer_size=local_shuffle_buffer_size,
  4973. local_shuffle_seed=local_shuffle_seed,
  4974. feature_type_spec=feature_type_spec,
  4975. label_type_spec=label_type_spec,
  4976. additional_type_spec=additional_type_spec,
  4977. )
  @ConsumptionAPI(pattern="Time complexity:")
  @PublicAPI(api_group=IOC_API_GROUP)
  def to_daft(self) -> "daft.DataFrame":
      """Convert this :class:`~ray.data.Dataset` into a
      `Daft DataFrame <https://docs.getdaft.io/en/stable/api/dataframe/>`_.

      All the data inside the Ray Dataset is converted into a Daft DataFrame
      in a zero-copy way (using Arrow as the intermediate data format).

      Time complexity: O(dataset size / parallelism)

      Returns:
          A `Daft DataFrame`_ created from this dataset.
      """
      # Imported lazily so Daft is only required when this method is called.
      import daft

      daft_df = daft.from_ray_dataset(self)
      return daft_df
  @ConsumptionAPI(pattern="Time complexity:")
  @PublicAPI(api_group=IOC_API_GROUP)
  def to_dask(
      self,
      meta: Union[
          "pandas.DataFrame",
          "pandas.Series",
          Dict[str, Any],
          Iterable[Any],
          Tuple[Any],
          None,
      ] = None,
      verify_meta: bool = True,
  ) -> "dask.dataframe.DataFrame":
      """Convert this :class:`~ray.data.Dataset` into a
      `Dask DataFrame <https://docs.dask.org/en/stable/generated/dask.dataframe.DataFrame.html#dask.dataframe.DataFrame>`_.

      This is only supported for datasets convertible to Arrow records.

      Note that this function will set the Dask scheduler to Dask-on-Ray
      globally, via the config.

      Time complexity: O(dataset size / parallelism)

      Args:
          meta: An empty `pandas DataFrame`_ or `Series`_ that matches the dtypes and column
              names of the stream. This metadata is necessary for many algorithms in
              dask dataframe to work. For ease of use, some alternative inputs are
              also available. Instead of a DataFrame, a dict of ``{name: dtype}`` or
              iterable of ``(name, dtype)`` can be provided (note that the order of
              the names should match the order of the columns). Instead of a series, a
              tuple of ``(name, dtype)`` can be used.
              By default, this is inferred from the underlying Dataset schema,
              with this argument supplying an optional override.
          verify_meta: If True, Dask will check that the partitions have consistent
              metadata. Defaults to True.

      Returns:
          A `Dask DataFrame`_ created from this dataset.

      .. _pandas DataFrame: https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html
      .. _Series: https://pandas.pydata.org/docs/reference/api/pandas.Series.html
      """  # noqa: E501
      import dask
      import dask.dataframe as dd
      import pandas as pd

      try:
          import pyarrow as pa
      except Exception:
          # pyarrow is optional here; without it the Arrow branch is skipped.
          pa = None

      from ray.data._internal.pandas_block import PandasBlockSchema
      from ray.util.client.common import ClientObjectRef
      from ray.util.dask import ray_dask_get

      # Global side effect: every subsequent Dask computation uses Dask-on-Ray.
      dask.config.set(scheduler=ray_dask_get)

      @dask.delayed
      def block_to_df(block_ref: ObjectRef[Block]) -> pd.DataFrame:
          # If the argument is still an object reference when the delayed
          # function runs, the references were not resolved by the scheduler.
          if isinstance(block_ref, (ray.ObjectRef, ClientObjectRef)):
              raise ValueError(
                  "Dataset.to_dask() must be used with Dask-on-Ray, please "
                  "set the Dask scheduler to ray_dask_get (located in "
                  "ray.util.dask)."
              )
          return _block_to_df(block_ref)

      if meta is None:
          from ray.data.extensions import TensorDtype

          # Infer Dask metadata from Dataset schema.
          schema = self.schema(fetch_if_missing=True)
          # Unwrap the ``Schema`` wrapper (as ``to_mars``, ``to_spark`` and
          # ``to_arrow_refs`` do) so the type checks below see the underlying
          # pandas/Arrow schema; otherwise no branch matches and ``meta``
          # silently stays ``None``.
          if isinstance(schema, Schema):
              schema = schema.base_schema
          if isinstance(schema, PandasBlockSchema):
              # Tensor extension columns are described to Dask as plain objects.
              meta = pd.DataFrame(
                  {
                      col: pd.Series(
                          dtype=(
                              dtype
                              if not isinstance(dtype, TensorDtype)
                              else np.object_
                          )
                      )
                      for col, dtype in zip(schema.names, schema.types)
                  }
              )
          elif pa is not None and isinstance(schema, pa.Schema):
              arrow_tensor_ext_types = get_arrow_extension_fixed_shape_tensor_types()

              if any(
                  isinstance(type_, arrow_tensor_ext_types) for type_ in schema.types
              ):
                  meta = pd.DataFrame(
                      {
                          col: pd.Series(
                              dtype=(
                                  dtype.to_pandas_dtype()
                                  if not isinstance(dtype, arrow_tensor_ext_types)
                                  else np.object_
                              )
                          )
                          for col, dtype in zip(schema.names, schema.types)
                      }
                  )
              else:
                  meta = schema.empty_table().to_pandas()

      # One delayed pandas DataFrame per block of the dataset.
      dfs = [
          block_to_df(block_ref)
          for ref_bundle in self.iter_internal_ref_bundles()
          for block_ref in ref_bundle.block_refs
      ]

      ddf = dd.from_delayed(
          dfs,
          meta=meta,
          verify_meta=verify_meta,
      )
      return ddf
  @ConsumptionAPI(pattern="Time complexity:")
  @PublicAPI(api_group=IOC_API_GROUP)
  def to_mars(self) -> "mars.dataframe.DataFrame":
      """Convert this :class:`~ray.data.Dataset` into a
      `Mars DataFrame <https://mars-project.readthedocs.io/en/latest/reference/dataframe/index.html>`_.

      Time complexity: O(dataset size / parallelism)

      Returns:
          A `Mars DataFrame`_ created from this dataset.
      """  # noqa: E501
      import pandas as pd
      import pyarrow as pa
      from mars.dataframe.datasource.read_raydataset import DataFrameReadRayDataset
      from mars.dataframe.utils import parse_index

      from ray.data._internal.pandas_block import PandasBlockSchema

      pandas_refs = self.to_pandas_refs()

      # remove this when https://github.com/mars-project/mars/issues/2945 got fixed
      base_schema = self.schema()
      if isinstance(base_schema, Schema):
          base_schema = base_schema.base_schema

      # Derive the per-column dtypes from whichever schema flavor we got.
      if isinstance(base_schema, PandasBlockSchema):
          column_dtypes = pd.Series(base_schema.types, index=base_schema.names)
      elif isinstance(base_schema, pa.Schema):
          empty_frame = base_schema.empty_table().to_pandas()
          column_dtypes = empty_frame.dtypes
      else:
          raise NotImplementedError(f"Unsupported format of schema {base_schema}")

      index_value = parse_index(pd.RangeIndex(-1))
      columns_value = parse_index(column_dtypes.index, store_data=True)
      read_op = DataFrameReadRayDataset(refs=pandas_refs)
      return read_op(
          index_value=index_value,
          columns_value=columns_value,
          dtypes=column_dtypes,
      )
  @ConsumptionAPI(pattern="Time complexity:")
  @PublicAPI(api_group=IOC_API_GROUP)
  def to_modin(self) -> "modin.pandas.dataframe.DataFrame":
      """Convert this :class:`~ray.data.Dataset` into a
      `Modin DataFrame <https://modin.readthedocs.io/en/stable/flow/modin/pandas/dataframe.html>`_.

      The dataset is first converted into a distributed set of Pandas
      DataFrames (using :meth:`Dataset.to_pandas_refs`); see the caveats
      there. Those DataFrames are then assembled into the Modin DataFrame with
      ``modin.distributed.dataframe.pandas.partitions.from_partitions()``.

      This is only supported for datasets convertible to Arrow records.
      This function induces a copy of the data. For zero-copy access to the
      underlying data, consider using :meth:`.to_arrow_refs` or
      :meth:`.iter_internal_ref_bundles`.

      Time complexity: O(dataset size / parallelism)

      Returns:
          A `Modin DataFrame`_ created from this dataset.
      """  # noqa: E501
      from modin.distributed.dataframe.pandas.partitions import from_partitions

      # Each pandas DataFrame reference becomes one row-wise (axis=0) partition.
      return from_partitions(self.to_pandas_refs(), axis=0)
  @ConsumptionAPI(pattern="Time complexity:")
  @PublicAPI(api_group=IOC_API_GROUP)
  def to_spark(self, spark: "pyspark.sql.SparkSession") -> "pyspark.sql.DataFrame":
      """Convert this :class:`~ray.data.Dataset` into a
      `Spark DataFrame <https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.html>`_.

      Time complexity: O(dataset size / parallelism)

      Args:
          spark: A `SparkSession`_, which must be created by RayDP (Spark-on-Ray).

      Returns:
          A `Spark DataFrame`_ created from this dataset.

      .. _SparkSession: https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.SparkSession.html
      """  # noqa: E501
      import raydp

      # RayDP is handed the underlying schema, not the ``Schema`` wrapper.
      base_schema = self.schema()
      if isinstance(base_schema, Schema):
          base_schema = base_schema.base_schema

      block_refs = _ref_bundles_iterator_to_block_refs_list(
          self.iter_internal_ref_bundles()
      )
      return raydp.spark.ray_dataset_to_spark_dataframe(
          spark, base_schema, block_refs
      )
  @ConsumptionAPI(pattern="Time complexity:")
  @PublicAPI(api_group=IOC_API_GROUP)
  def to_pandas(self, limit: Optional[int] = None) -> "pandas.DataFrame":
      """Convert this :class:`~ray.data.Dataset` to a single pandas DataFrame.

      This method errors if the number of rows exceeds the provided ``limit``.
      To truncate the dataset beforehand, call :meth:`.limit`.

      Examples:
          >>> import ray
          >>> ds = ray.data.from_items([{"a": i} for i in range(3)])
          >>> ds.to_pandas()
             a
          0  0
          1  1
          2  2

      Time complexity: O(dataset size)

      Args:
          limit: The maximum number of rows to return. An error is
              raised if the dataset has more rows than this limit. Defaults to
              ``None``, which means no limit.

      Returns:
          A pandas DataFrame created from this dataset, containing a limited
          number of rows.

      Raises:
          ValueError: if the number of rows in the :class:`~ray.data.Dataset` exceeds
              ``limit``.
      """
      if limit is not None:
          # Check the size up front so an oversized dataset is rejected before
          # any batch is pulled into local memory.
          count = self.count()
          if count > limit:
              raise ValueError(
                  f"the dataset has more than the given limit of {limit} "
                  f"rows: {count}. If you are sure that a DataFrame with "
                  f"{count} rows will fit in local memory, set "
                  "ds.to_pandas(limit=None) to disable limits."
              )

      builder = PandasBlockBuilder()
      for batch in self.iter_batches(batch_format="pandas", batch_size=None):
          builder.add_block(batch)
      block = builder.build()

      # `PandasBlockBuilder` creates a dataframe with internal extension types like
      # 'TensorDtype'. We use the `to_pandas` method to convert these extension
      # types to regular types.
      return BlockAccessor.for_block(block).to_pandas()
  @ConsumptionAPI(pattern="Time complexity:")
  @DeveloperAPI
  def to_pandas_refs(self) -> List[ObjectRef["pandas.DataFrame"]]:
      """Converts this :class:`~ray.data.Dataset` into a distributed set of Pandas
      dataframes.

      One DataFrame is created for each block in this Dataset.

      This function induces a copy of the data. For zero-copy access to the
      underlying data, consider using :meth:`Dataset.to_arrow_refs` or
      :meth:`Dataset.iter_internal_ref_bundles`.

      Examples:
          >>> import ray
          >>> ds = ray.data.range(10, override_num_blocks=2)
          >>> refs = ds.to_pandas_refs()
          >>> len(refs)
          2

      Time complexity: O(dataset size / parallelism)

      Returns:
          A list of remote pandas DataFrames created from this dataset.
      """
      convert = cached_remote_fn(_block_to_df)
      # Launch one remote conversion task per block, in bundle order.
      return [
          convert.remote(block_ref)
          for bundle in self.iter_internal_ref_bundles()
          for block_ref in bundle.block_refs
      ]
  @ConsumptionAPI(pattern="Time complexity:")
  @DeveloperAPI
  def to_numpy_refs(
      self, *, column: Optional[str] = None
  ) -> List[ObjectRef[np.ndarray]]:
      """Converts this :class:`~ray.data.Dataset` into a distributed set of NumPy
      ndarrays or dictionary of NumPy ndarrays.

      This is only supported for datasets convertible to NumPy ndarrays.
      This function induces a copy of the data. For zero-copy access to the
      underlying data, consider using :meth:`Dataset.to_arrow_refs` or
      :meth:`Dataset.iter_internal_ref_bundles`.

      Examples:
          >>> import ray
          >>> ds = ray.data.range(10, override_num_blocks=2)
          >>> refs = ds.to_numpy_refs()
          >>> len(refs)
          2

      Time complexity: O(dataset size / parallelism)

      Args:
          column: The name of the column to convert to numpy. If ``None``, all
              columns are used; when multiple columns are converted, each
              returned future represents a dict of ndarrays. Defaults to None.

      Returns:
          A list of remote NumPy ndarrays created from this dataset.
      """
      block_to_ndarray = cached_remote_fn(_block_to_ndarray)
      # Launch one remote conversion task per block, in bundle order.
      return [
          block_to_ndarray.remote(block_ref, column=column)
          for bundle in self.iter_internal_ref_bundles()
          for block_ref in bundle.block_refs
      ]
  @ConsumptionAPI(pattern="Time complexity:")
  @DeveloperAPI
  def to_arrow_refs(self) -> List[ObjectRef["pyarrow.Table"]]:
      """Convert this :class:`~ray.data.Dataset` into a distributed set of PyArrow
      tables.

      One PyArrow table is created for each block in this Dataset.

      This method is only supported for datasets convertible to PyArrow tables.
      This function is zero-copy if the existing data is already in PyArrow
      format. Otherwise, the data is converted to PyArrow format.

      Examples:
          >>> import ray
          >>> ds = ray.data.range(10, override_num_blocks=2)
          >>> refs = ds.to_arrow_refs()
          >>> len(refs)
          2

      Time complexity: O(1) unless conversion is required.

      Returns:
          A list of remote PyArrow tables created from this dataset.
      """
      import pyarrow as pa

      executed_bundle: RefBundle = self._plan.execute()
      block_refs: List[
          ObjectRef["pyarrow.Table"]
      ] = _ref_bundles_iterator_to_block_refs_list([executed_bundle])

      # Schema is safe to call since we have already triggered execution with
      # self._plan.execute(), which will cache the schema
      base_schema = self.schema(fetch_if_missing=True)
      if isinstance(base_schema, Schema):
          base_schema = base_schema.base_schema

      if not isinstance(base_schema, pa.Schema):
          # Blocks are not Arrow tables yet: convert each one remotely.
          convert = cached_remote_fn(_block_to_arrow)
          return [convert.remote(block_ref) for block_ref in block_refs]

      # Zero-copy path.
      return block_refs
  @ConsumptionAPI(pattern="Args:")
  def to_random_access_dataset(
      self,
      key: str,
      num_workers: Optional[int] = None,
  ) -> RandomAccessDataset:
      """Convert this dataset into a distributed RandomAccessDataset (EXPERIMENTAL).

      RandomAccessDataset partitions the dataset across the cluster by the given
      sort key, providing efficient random access to records via binary search. A
      number of worker actors are created, each of which has zero-copy access to the
      underlying sorted data blocks of the dataset.

      Note that the key must be unique in the dataset. If there are duplicate keys,
      an arbitrary value is returned.

      This is only supported for Arrow-format datasets.

      Args:
          key: The key column over which records can be queried.
          num_workers: The number of actors to use to serve random access queries.
              By default, this is determined by multiplying the number of Ray nodes
              in the cluster by four. As a rule of thumb, you can expect each worker
              to provide ~3000 records / second via ``get_async()``, and
              ~10000 records / second via ``multiget()``.
      """
      # Default to four workers per Ray node when no count is requested.
      worker_count = 4 * len(ray.nodes()) if num_workers is None else num_workers
      return RandomAccessDataset(self, key, num_workers=worker_count)
  @ConsumptionAPI(pattern="store memory.", insert_after=True)
  @PublicAPI(api_group=E_API_GROUP)
  def materialize(self) -> "MaterializedDataset":
      """Execute and materialize this dataset into object store memory.

      This can be used to read all blocks into memory. By default, Dataset
      doesn't read blocks from the datasource until the first transform.

      Note that this does not mutate the original Dataset. Only the blocks of the
      returned MaterializedDataset class are pinned in memory.

      Examples:
          >>> import ray
          >>> ds = ray.data.range(10)
          >>> materialized_ds = ds.materialize()
          >>> materialized_ds
          shape: (10, 1)
          ╭───────╮
          │ id    │
          │ ---   │
          │ int64 │
          ╞═══════╡
          │ 0     │
          │ 1     │
          │ 2     │
          │ 3     │
          │ 4     │
          │ 5     │
          │ 6     │
          │ 7     │
          │ 8     │
          │ 9     │
          ╰───────╯
          (Showing 10 of 10 rows)

      Returns:
          A MaterializedDataset holding the materialized data blocks.
      """
      ds_copy = Dataset.copy(self, _deep_copy=True, _as=MaterializedDataset)
      executed: RefBundle = ds_copy._plan.execute()
      blocks_with_metadata = executed.blocks

      # TODO(hchen): Here we generate the same number of blocks as
      # the original Dataset. Because the old code path does this, and
      # some unit tests implicitly depend on this behavior.
      # After we remove the old code path, we should consider merging
      # some blocks for better perf.
      single_block_bundles = [
          RefBundle(
              blocks=[block_and_meta],
              owns_blocks=False,
              schema=executed.schema,
          )
          for block_and_meta in blocks_with_metadata
      ]
      logical_plan = LogicalPlan(
          InputData(input_data=single_block_bundles), self.context
      )
      execution_plan = ExecutionPlan(
          ds_copy._plan.stats(), data_context=ds_copy.context
      )
      materialized = MaterializedDataset(execution_plan, logical_plan)

      # Metrics are tagged with the copy's uuid, update the output uuid with
      # this so the user can access the metrics label.
      materialized.set_name(ds_copy.name)
      materialized._set_uuid(ds_copy._get_uuid())
      materialized._plan.execute()  # No-op that marks the plan as fully executed.
      return materialized
  @PublicAPI(api_group=IM_API_GROUP)
  def stats(self) -> str:
      """Returns a string containing execution timing information.

      Note that this does not trigger execution, so if the dataset has not yet
      executed, an empty string is returned.

      Examples:

      .. testcode::

          import ray

          ds = ray.data.range(10)
          assert ds.stats() == ""

          ds = ds.materialize()
          print(ds.stats())

      .. testoutput::
          :options: +MOCK

          Operator 0 Read: 1 tasks executed, 5 blocks produced in 0s
          * Remote wall time: 16.29us min, 7.29ms max, 1.21ms mean, 24.17ms total
          * Remote cpu time: 16.0us min, 2.54ms max, 810.45us mean, 16.21ms total
          * Peak heap memory usage (MiB): 137968.75 min, 142734.38 max, 139846 mean
          * Output num rows: 0 min, 1 max, 0 mean, 10 total
          * Output size bytes: 0 min, 8 max, 4 mean, 80 total
          * Tasks per node: 20 min, 20 max, 20 mean; 1 nodes used

      """
      # Prefer the stats of the executor currently attached to this dataset.
      executor = self._current_executor
      if executor:
          summary = executor.get_stats().to_summary()
          return summary.to_string()

      # Otherwise defer to the write dataset once its output has been computed.
      write_ds = self._write_ds
      if write_ds is not None and write_ds._plan.has_computed_output():
          return write_ds.stats()

      return self._get_stats_summary().to_string()
  @PublicAPI(api_group=IM_API_GROUP, stability="alpha")
  def explain(self):
      """Show the logical plan and physical plan of the dataset.

      Examples:

      .. testcode::

          import ray
          from ray.data import Dataset
          ds: Dataset = ray.data.range(10, override_num_blocks=10)
          ds = ds.map(lambda x: x + 1)
          ds.explain()

      .. testoutput::

          <BLANKLINE>
          -------- Logical Plan --------
          MapRows[Map(<lambda>)]
          +- Read[ReadRange]
          <BLANKLINE>
          -------- Logical Plan (Optimized) --------
          MapRows[Map(<lambda>)]
          +- Read[ReadRange]
          <BLANKLINE>
          -------- Physical Plan --------
          TaskPoolMapOperator[Map(<lambda>)]
          +- TaskPoolMapOperator[ReadRange]
             +- InputDataBuffer[Input]
          <BLANKLINE>
          -------- Physical Plan (Optimized) --------
          TaskPoolMapOperator[ReadRange->Map(<lambda>)]
          +- InputDataBuffer[Input]
          <BLANKLINE>
      """
      # The plan renders its own description; this method only prints it.
      plan_description = self._plan.explain()
      print(plan_description)
  5439. def _get_stats_summary(self) -> DatasetStatsSummary:
  5440. return self._plan.stats().to_summary()
  @ConsumptionAPI(pattern="Examples:")
  @DeveloperAPI
  def iter_internal_ref_bundles(self) -> Iterator[RefBundle]:
      """Get an iterator over ``RefBundles``
      belonging to this Dataset. Calling this function doesn't keep
      the data materialized in-memory.

      Examples:
          >>> import ray
          >>> ds = ray.data.range(1)
          >>> for ref_bundle in ds.iter_internal_ref_bundles():
          ...     for block_ref, block_md in ref_bundle.blocks:
          ...         block = ray.get(block_ref)

      Returns:
          An iterator over this Dataset's ``RefBundles``.
      """
      # Only the first element of the returned triple is needed here.
      ref_bundle_iter, _, _ = self._plan.execute_to_iterator()
      self._synchronize_progress_bar()
      return ref_bundle_iter
  @Deprecated
  @ConsumptionAPI(pattern="Examples:")
  def get_internal_block_refs(self) -> List[ObjectRef[Block]]:
      """Get a list of references to the underlying blocks of this dataset.

      This function can be used for zero-copy access to the data. It blocks
      until the underlying blocks are computed.

      Examples:
          >>> import ray
          >>> ds = ray.data.range(1)
          >>> ds.get_internal_block_refs()
          [ObjectRef(...)]

      Returns:
          A list of references to this dataset's blocks.
      """
      deprecation_message = (
          "`Dataset.get_internal_block_refs()` is deprecated. Use "
          "`Dataset.iter_internal_ref_bundles()` instead."
      )
      logger.warning(deprecation_message)

      executed_bundle = self._plan.execute()
      refs = executed_bundle.block_refs
      self._synchronize_progress_bar()
      return refs
  @DeveloperAPI
  def has_serializable_lineage(self) -> bool:
      """Whether this dataset's lineage is able to be serialized for storage and
      later deserialized, possibly on a different cluster.

      Only datasets that are created from data that we know will still exist at
      deserialization time, e.g. data external to this Ray cluster such as persistent
      cloud object stores, support lineage-based serialization. All of the
      ray.data.read_*() APIs support lineage-based serialization.

      Examples:

          >>> import ray
          >>> ray.data.from_items(list(range(10))).has_serializable_lineage()
          False
          >>> ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv").has_serializable_lineage()
          True
      """  # noqa: E501
      # The lineage is serializable only if every operator in the logical
      # plan's DAG is; stop at the first operator that is not.
      for op in self._logical_plan.dag.post_order_iter():
          if not op.is_lineage_serializable():
              return False
      return True
  @DeveloperAPI
  def serialize_lineage(self) -> bytes:
      """
      Serialize this dataset's lineage, not the actual data or the existing data
      futures, to bytes that can be stored and later deserialized, possibly on a
      different cluster.

      Note that this uses pickle and will drop all computed data, and that everything
      is recomputed from scratch after deserialization.

      Use :py:meth:`Dataset.deserialize_lineage` to deserialize the serialized
      bytes returned from this method into a Dataset.

      .. note::
          Unioned and zipped datasets, produced by :py:meth:`Dataset.union` and
          :py:meth:`Dataset.zip`, are not lineage-serializable.

      Examples:

          .. testcode::

              import ray

              ds = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
              serialized_ds = ds.serialize_lineage()
              ds = ray.data.Dataset.deserialize_lineage(serialized_ds)
              print(ds)

          .. testoutput::

              Dataset(num_rows=?, schema=...)

      Returns:
          Serialized bytes containing the lineage of this dataset.
      """
      if not self.has_serializable_lineage():
          raise ValueError(
              "Lineage-based serialization is not supported for this stream, which "
              "means that it cannot be used as a tunable hyperparameter. "
              "Lineage-based serialization is explicitly NOT supported for unioned "
              "or zipped datasets (see docstrings for those methods), and is only "
              "supported for datasets created from data that we know will still "
              "exist at deserialization time, e.g. external data in persistent cloud "
              "object stores or in-memory data from long-lived clusters. Concretely, "
              "all ray.data.read_*() APIs should support lineage-based "
              "serialization, while all of the ray.data.from_*() APIs do not. To "
              "allow this stream to be serialized to storage, write the data to an "
              "external store (such as AWS S3, GCS, or Azure Blob Storage) using the "
              "Dataset.write_*() APIs, and serialize a new dataset reading "
              "from the external store using the ray.data.read_*() APIs."
          )

      # Copy Dataset and clear the blocks from the execution plan so only the
      # Dataset's lineage is serialized.
      execution_plan_copy = self._plan.deep_copy()
      logical_plan_copy = copy.copy(self._plan._logical_plan)
      lineage_ds = Dataset(execution_plan_copy, logical_plan_copy)
      lineage_ds._plan.clear_snapshot()
      lineage_ds._set_uuid(self._get_uuid())

      def _reduce_remote_fn(rf: ray.remote_function.RemoteFunction):
          # Custom reducer for Ray remote function handles that allows for
          # cross-cluster serialization.
          # This manually unsets the last export session and job to force re-exporting
          # of the function when the handle is deserialized on a new cluster.
          # TODO(Clark): Fix this in core Ray, see issue:
          # https://github.com/ray-project/ray/issues/24152.
          reconstructor, args, state = rf.__reduce__()
          state["_last_export_session_and_job"] = None
          return reconstructor, args, state

      remote_fn_cls = ray.remote_function.RemoteFunction
      serialization_ctx = (
          ray._private.worker.global_worker.get_serialization_context()
      )
      try:
          serialization_ctx._register_cloudpickle_reducer(
              remote_fn_cls, _reduce_remote_fn
          )
          payload = pickle.dumps(lineage_ds)
      finally:
          # Always remove the temporary reducer again, even if pickling failed.
          serialization_ctx._unregister_cloudpickle_reducer(remote_fn_cls)
      return payload
  @staticmethod
  @DeveloperAPI
  def deserialize_lineage(serialized_ds: bytes) -> "Dataset":
      """
      Deserialize the provided lineage-serialized Dataset.

      This uses pickle, and assumes that the provided serialized bytes were
      serialized using :py:meth:`Dataset.serialize_lineage`. Because
      unpickling can execute arbitrary code, only pass bytes from a source
      you trust.

      Examples:

          .. testcode::

              import ray

              ds = ray.data.read_csv("s3://anonymous@ray-example-data/iris.csv")
              serialized_ds = ds.serialize_lineage()
              ds = ray.data.Dataset.deserialize_lineage(serialized_ds)
              print(ds)

          .. testoutput::

              Dataset(num_rows=?, schema=...)

      Args:
          serialized_ds: The serialized Dataset that we wish to deserialize.

      Returns:
          A deserialized ``Dataset`` instance.
      """
      # NOTE: pickle.loads must never be fed untrusted input.
      restored = pickle.loads(serialized_ds)
      return restored
  @property
  @DeveloperAPI
  def context(self) -> DataContext:
      """Return the DataContext used to create this Dataset."""
      # The context is stored on the execution plan, not on the dataset itself.
      plan = self._plan
      return plan._context
  5593. def _aggregate_on(
  5594. self, agg_cls: type, on: Optional[Union[str, List[str]]], *args, **kwargs
  5595. ):
  5596. """Helper for aggregating on a particular subset of the dataset.
  5597. This validates the `on` argument, and converts a list of column names
  5598. or lambdas to a multi-aggregation. A null `on` results in a
  5599. multi-aggregation on all columns for an Arrow Dataset, and a single
  5600. aggregation on the entire row for a simple Dataset.
  5601. """
  5602. aggs = self._build_multicolumn_aggs(agg_cls, on, *args, **kwargs)
  5603. return self.aggregate(*aggs)
  5604. def _build_multicolumn_aggs(
  5605. self,
  5606. agg_cls: type,
  5607. on: Optional[Union[str, List[str]]],
  5608. *args,
  5609. skip_cols: Optional[List[str]] = None,
  5610. **kwargs,
  5611. ):
  5612. """Build set of aggregations for applying a single aggregation to
  5613. multiple columns.
  5614. """
  5615. # Expand None into an aggregation for each column.
  5616. if on is None:
  5617. schema = self.schema(fetch_if_missing=True)
  5618. if schema is not None and not isinstance(schema, type):
  5619. if not skip_cols:
  5620. skip_cols = []
  5621. if len(schema.names) > 0:
  5622. on = [col for col in schema.names if col not in skip_cols]
  5623. if not isinstance(on, list):
  5624. on = [on]
  5625. if len(on) == 0:
  5626. raise ValueError("At least 1 column to aggregate on has to be provided")
  5627. return [agg_cls(on_, *args, **kwargs) for on_ in on]
  5628. def _aggregate_result(self, result: Union[Tuple, Mapping]) -> U:
  5629. if result is not None and len(result) == 1:
  5630. if isinstance(result, tuple):
  5631. return result[0]
  5632. else:
  5633. # NOTE (kfstorm): We cannot call `result[0]` directly on
  5634. # `PandasRow` because indexing a column with position is not
  5635. # supported by pandas.
  5636. return list(result.values())[0]
  5637. else:
  5638. return result
  @repr_with_fallback(["ipywidgets", "8"])
  def _repr_mimebundle_(self, **kwargs):
      """Return a mimebundle with an ipywidget repr and a simple text repr.

      Depending on the frontend where the data is being displayed,
      different mimetypes are used from this bundle.
      See https://ipython.readthedocs.io/en/stable/config/integrating.html
      for information about this method, and
      https://ipywidgets.readthedocs.io/en/latest/embedding.html
      for more information about the jupyter widget mimetype.

      Args:
          **kwargs: Additional arguments passed to the widget's _repr_mimebundle_ method.

      Returns:
          A mimebundle containing an ipywidget repr and a simple text repr.
      """
      import ipywidgets

      heading = ipywidgets.HTML(f"<h2>{self.__class__.__name__}</h2>")
      tabs = self._tab_repr_()
      full_width = ipywidgets.Layout(width="100%")
      container = ipywidgets.VBox([heading, tabs], layout=full_width)

      # Get the widget mime bundle, but replace the plaintext
      # with the Datastream repr
      mimebundle = container._repr_mimebundle_(**kwargs)
      mimebundle.update({"text/plain": repr(self)})
      return mimebundle
  def _tab_repr_(self):
      """Build an ipywidgets ``Tab`` with a "Metadata" and a "Schema" page.

      The metadata page tabulates the plan's initial block count and the
      row count from ``_meta_count()``. The schema page shows a name/type
      table, the plain data type, or an "Unknown schema" notice, depending
      on what ``schema(fetch_if_missing=False)`` returns.
      """
      from ipywidgets import HTML, Tab

      summary = {
          "num_blocks": self._plan.initial_num_blocks(),
          "num_rows": self._meta_count(),
      }

      # Show metadata if available, but don't trigger execution.
      dataset_schema = self.schema(fetch_if_missing=False)
      if dataset_schema is None:
          schema_html = Template("rendered_html_common.html.j2").render(
              content="<h5>Unknown schema</h5>"
          )
      elif isinstance(dataset_schema, type):
          escaped_type = html.escape(str(dataset_schema))
          schema_html = Template("rendered_html_common.html.j2").render(
              content=f"<h5>Data type: <code>{escaped_type}</code></h5>"
          )
      else:
          # Prefer a type's `__name__`; fall back to its string form.
          column_types = {
              column: getattr(column_type, "__name__", str(column_type))
              for column, column_type in zip(
                  dataset_schema.names, dataset_schema.types
              )
          }
          schema_table = tabulate(
              tabular_data=column_types.items(),
              tablefmt="html",
              showindex=False,
              headers=["Name", "Type"],
          )
          schema_html = Template("scrollableTable.html.j2").render(
              table=schema_table,
              max_height="300px",
          )

      summary_table = tabulate(
          tabular_data=summary.items(),
          tablefmt="html",
          showindex=False,
          headers=["Field", "Value"],
      )
      summary_html = Template("scrollableTable.html.j2").render(
          table=summary_table,
          max_height="300px",
      )

      pages = [HTML(summary_html), HTML(schema_html)]
      return Tab(pages, titles=["Metadata", "Schema"])
  5711. def __repr__(self) -> str:
  5712. return self._tabular_repr()
  def _tabular_repr(self) -> str:
      """Render this dataset as text without fetching a missing schema.

      Returns:
          The output of ``_build_dataset_ascii_repr`` when the cached schema
          is a ``Schema`` instance; otherwise the plan's string form.
      """
      schema = self.schema(fetch_if_missing=False)
      # `isinstance` is False for `None`, so one check covers both the
      # "no schema" and the "not a Schema object" cases.
      if isinstance(schema, Schema):
          materialized = isinstance(self, MaterializedDataset)
          return _build_dataset_ascii_repr(self, schema, materialized)
      return self._plan.get_plan_as_string(self.__class__)
  5719. def __str__(self) -> str:
  5720. return repr(self)
  5721. def __bool__(self) -> bool:
  5722. # Prevents `__len__` from being called to check if it is None
  5723. # see: issue #25152
  5724. return True
  5725. def __len__(self) -> int:
  5726. raise AttributeError(
  5727. "Use `ds.count()` to compute the length of a distributed Dataset. "
  5728. "This may be an expensive operation."
  5729. )
  5730. def __iter__(self):
  5731. raise TypeError(
  5732. "`Dataset` objects aren't iterable. To iterate records, call "
  5733. "`ds.iter_rows()` or `ds.iter_batches()`. For more information, read "
  5734. "https://docs.ray.io/en/latest/data/iterating-over-data.html."
  5735. )
  def _block_num_rows(self) -> List[int]:
      """Return the result of ``_get_num_rows`` for every block.

      One remote call is submitted per block reference, in the order the
      references appear in ``iter_internal_ref_bundles()``. All calls are
      then resolved together with ``ray.get``.
      """
      count_rows = cached_remote_fn(_get_num_rows)
      pending = [
          count_rows.remote(block_ref)
          for bundle in self.iter_internal_ref_bundles()
          for block_ref in bundle.block_refs
      ]
      return ray.get(pending)
  5743. def _meta_count(self) -> Optional[int]:
  5744. return self._plan.meta_count()
  5745. def _get_uuid(self) -> str:
  5746. return self._uuid
  5747. def _set_uuid(self, uuid: str) -> None:
  5748. self._uuid = uuid
  5749. self._plan._dataset_uuid = uuid
  5750. self._plan._in_stats.dataset_uuid = uuid
  5751. def _synchronize_progress_bar(self):
  5752. """Flush progress bar output by shutting down the current executor.
  5753. This should be called at the end of all blocking APIs (e.g., `take`), but not
  5754. async APIs (e.g., `iter_batches`).
  5755. The streaming executor runs in a separate generator / thread, so it is
  5756. possible the shutdown logic runs even after a call to retrieve rows from the
  5757. stream has finished. Explicit shutdown avoids this, which can clobber console
  5758. output (https://github.com/ray-project/ray/issues/32414).
  5759. """
  5760. if self._current_executor:
  5761. # NOTE: This method expected to have executor fully shutdown upon returning
  5762. # from this method
  5763. self._current_executor.shutdown(force=True)
  5764. self._current_executor = None
  def _execute_to_iterator(
      self,
  ) -> Tuple[Iterator[RefBundle], DatasetStats, Optional["StreamingExecutor"]]:
      """Run the plan's ``execute_to_iterator()`` and remember its executor.

      Returns:
          The ``(bundle_iter, stats, executor)`` triple produced by the plan.
      """
      execution = self._plan.execute_to_iterator()
      bundle_iter, stats, executor = execution

      # Capture current executor to be able to clean it up properly, once
      # dataset is garbage-collected
      self._current_executor = executor

      return (bundle_iter, stats, executor)
  5773. def __getstate__(self):
  5774. # Note: excludes _current_executor which is not serializable.
  5775. return {
  5776. "plan": self._plan,
  5777. "uuid": self._uuid,
  5778. "logical_plan": self._logical_plan,
  5779. }
  5780. def __setstate__(self, state):
  5781. self._plan = state["plan"]
  5782. self._uuid = state["uuid"]
  5783. self._logical_plan = state["logical_plan"]
  5784. self._current_executor = None
  5785. def __del__(self):
  5786. if not self._current_executor:
  5787. return
  5788. # When Python shuts down, `ray` might evaluate to `<module None from None>`.
  5789. # This value is truthy and not `None`, so we use a try-catch in addition to
  5790. # `if ray is not None`. For more information, see #42382.
  5791. try:
  5792. if ray is not None and ray.is_initialized():
  5793. # NOTE: Upon garbage-collection we're allowing running tasks
  5794. # to be terminated asynchronously (ie avoid unnecessary
  5795. # synchronization on their completion)
  5796. self._current_executor.shutdown(force=False)
  5797. except TypeError:
  5798. pass
  @PublicAPI
  class MaterializedDataset(Dataset, Generic[T]):
      """A Dataset materialized in Ray memory, e.g., via `.materialize()`.

      Its blocks are held in Ray object store memory, so the object can be
      shared with, or iterated over by, multiple Ray tasks without
      re-executing the computations that produced the stream.
      """

      def num_blocks(self) -> int:
          """Return the number of blocks of this :class:`MaterializedDataset`.

          Examples:
              >>> import ray
              >>> ds = ray.data.range(100).repartition(10).materialize()
              >>> ds.num_blocks()
              10

          Time complexity: O(1)

          Returns:
              The number of blocks of this :class:`Dataset`.
          """
          block_count = self._plan.initial_num_blocks()
          return block_count
  @PublicAPI(stability="beta")
  class Schema:
      """Dataset schema.

      Wraps an Arrow or Pandas schema and exposes its column names and
      Arrow-style column types.

      Attributes:
          base_schema: The underlying Arrow or Pandas schema.
      """

      def __init__(
          self,
          base_schema: Union["pyarrow.lib.Schema", "PandasBlockSchema"],
          *,
          data_context: Optional[DataContext] = None,
      ):
          """
          Initialize a :class:`Schema` wrapper around an Arrow or Pandas schema.

          Args:
              base_schema: The underlying Arrow or Pandas schema.
              data_context: The data context to use for this schema. When
                  omitted (or falsy), a deep copy of the current context is
                  taken.
          """
          self.base_schema = base_schema

          # Snapshot the current context, so that the config of Datasets is always
          # determined by the config at the time it was created.
          self._context = data_context or copy.deepcopy(DataContext.get_current())

      @property
      def names(self) -> List[str]:
          """Lists the columns of this Dataset."""
          return self.base_schema.names

      @property
      def types(self) -> List[Union[type[object], "pyarrow.lib.DataType"]]:
          """Lists the types of this Dataset in Arrow format

          An Arrow base schema reports its own types. For any other base
          schema each dtype is converted to an Arrow type. A dtype Arrow
          cannot represent yields ``object``; any other conversion failure
          is logged and yields ``None``.
          """
          import pandas as pd
          import pyarrow as pa
          from pandas.core.dtypes.dtypes import BaseMaskedDtype

          from ray.data.extensions import ArrowTensorType, TensorDtype

          def _convert_to_pa_type(
              dtype: Union[np.dtype, pd.ArrowDtype, BaseMaskedDtype],
          ) -> pa.DataType:
              if isinstance(dtype, pd.ArrowDtype):
                  return dtype.pyarrow_dtype
              elif isinstance(dtype, pd.StringDtype):
                  # StringDtype is not a BaseMaskedDtype, handle separately
                  return pa.string()
              elif isinstance(dtype, BaseMaskedDtype):
                  # Masked dtypes are converted through their NumPy equivalent.
                  dtype = dtype.numpy_dtype
              return pa.from_numpy_dtype(dtype)

          if isinstance(self.base_schema, pa.lib.Schema):
              return list(self.base_schema.types)

          arrow_types = []
          for dtype in self.base_schema.types:
              if isinstance(dtype, TensorDtype):
                  if self._context.use_arrow_tensor_v2:
                      pa_tensor_type_class = ArrowTensorTypeV2
                  else:
                      pa_tensor_type_class = ArrowTensorType

                  # Manually convert our Pandas tensor extension type to Arrow.
                  arrow_types.append(
                      pa_tensor_type_class(
                          shape=dtype._shape,
                          dtype=_convert_to_pa_type(dtype._dtype),
                      )
                  )
              else:
                  try:
                      arrow_types.append(_convert_to_pa_type(dtype))
                  except pa.ArrowNotImplementedError:
                      arrow_types.append(object)
                  except Exception:
                      logger.exception(f"Error converting dtype {dtype} to Arrow.")
                      arrow_types.append(None)
          return arrow_types

      def __eq__(self, other):
          """Compare two schemas by column names and column types.

          Returns ``NotImplemented`` for non-``Schema`` operands so Python can
          try the other operand's comparison; a plain ``==`` against an
          unrelated object still evaluates to ``False``.
          """
          if not isinstance(other, Schema):
              return NotImplemented
          # Compare the cheap name lists first; `types` may have to convert
          # every dtype and is only needed when the names already match.
          return other.names == self.names and other.types == self.types

      def __repr__(self):
          """Render a two-column text table of column names and types.

          The name column is padded to the longest name (at least the width
          of the "Column" header) plus two spaces. Trailing whitespace is
          stripped from the result.
          """
          name_header = "Column"
          type_header = "Type"
          padding = 2

          names = self.names
          longest = max([len(name) for name in names] + [len(name_header)])
          width = longest + padding

          lines = [
              name_header.ljust(width) + type_header,
              ("-" * len(name_header)).ljust(width) + "-" * len(type_header),
          ]
          lines.extend(
              f"{name.ljust(width)}{column_type}"
              for name, column_type in zip(names, self.types)
          )
          return "\n".join(lines).rstrip()
  def _block_to_df(block: Block) -> "pandas.DataFrame":
      """Return ``block`` converted with its accessor's ``to_pandas()``."""
      return BlockAccessor.for_block(block).to_pandas()
  def _block_to_ndarray(block: Block, column: Optional[str]):
      """Return ``block`` converted with its accessor's ``to_numpy(column)``.

      Args:
          block: The block to convert.
          column: Passed through to the accessor's ``to_numpy``.
      """
      return BlockAccessor.for_block(block).to_numpy(column)
  def _block_to_arrow(block: Block):
      """Return ``block`` converted with its accessor's ``to_arrow()``."""
      return BlockAccessor.for_block(block).to_arrow()