algorithm.py 211 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477
7477847794780478147824783478447854786478747884789479047914792479347944795479647974798479948004801480248034804480548064807480848094810481148124813481448154816481748184819482048214822482348244825482648274828482948304831483248334834483548364837483848394840484148424843484448454846484748484849485048514852485348544855485648574858485948604861486248634864486548664867486848694870487148724873487448754876487748784879488048814882488348844885488648874888488948904891489248934894489548964897
  1. import concurrent
  2. import copy
  3. import functools
  4. import importlib
  5. import importlib.metadata
  6. import json
  7. import logging
  8. import os
  9. import pathlib
  10. import re
  11. import tempfile
  12. import time
  13. from collections import defaultdict
  14. from datetime import datetime
  15. from typing import (
  16. TYPE_CHECKING,
  17. Any,
  18. Callable,
  19. Collection,
  20. DefaultDict,
  21. Dict,
  22. List,
  23. Optional,
  24. Set,
  25. Tuple,
  26. Type,
  27. Union,
  28. )
  29. import gymnasium as gym
  30. import numpy as np
  31. import pyarrow.fs
  32. import tree # pip install dm_tree
  33. from packaging import version
  34. import ray
  35. import ray.cloudpickle as pickle
  36. from ray._common.deprecation import (
  37. DEPRECATED_VALUE,
  38. Deprecated,
  39. deprecation_warning,
  40. )
  41. from ray._common.usage.usage_lib import TagKey, record_extra_usage_tag
  42. from ray.actor import ActorHandle
  43. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
  44. from ray.rllib.algorithms.registry import ALGORITHMS_CLASS_TO_NAME as ALL_ALGORITHMS
  45. from ray.rllib.algorithms.utils import (
  46. AggregatorActor,
  47. _get_env_runner_bundles,
  48. _get_learner_bundles,
  49. _get_main_process_bundle,
  50. _get_offline_eval_runner_bundles,
  51. )
  52. from ray.rllib.callbacks.utils import make_callback
  53. from ray.rllib.connectors.agent.obs_preproc import ObsPreprocessorConnector
  54. from ray.rllib.connectors.connector_pipeline_v2 import ConnectorPipelineV2
  55. from ray.rllib.core import (
  56. COMPONENT_ENV_RUNNER,
  57. COMPONENT_ENV_TO_MODULE_CONNECTOR,
  58. COMPONENT_EVAL_ENV_RUNNER,
  59. COMPONENT_LEARNER,
  60. COMPONENT_LEARNER_GROUP,
  61. COMPONENT_METRICS_LOGGER,
  62. COMPONENT_MODULE_TO_ENV_CONNECTOR,
  63. COMPONENT_RL_MODULE,
  64. DEFAULT_MODULE_ID,
  65. )
  66. from ray.rllib.core.columns import Columns
  67. from ray.rllib.core.rl_module import validate_module_id
  68. from ray.rllib.core.rl_module.multi_rl_module import (
  69. MultiRLModule,
  70. MultiRLModuleSpec,
  71. )
  72. from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec
  73. from ray.rllib.env import INPUT_ENV_SPACES
  74. from ray.rllib.env.env_context import EnvContext
  75. from ray.rllib.env.env_runner import EnvRunner
  76. from ray.rllib.env.env_runner_group import EnvRunnerGroup
  77. from ray.rllib.env.utils import _gym_env_creator
  78. from ray.rllib.evaluation.metrics import (
  79. collect_episodes,
  80. summarize_episodes,
  81. )
  82. from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
  83. from ray.rllib.offline import get_dataset_and_shards
  84. from ray.rllib.offline.estimators import (
  85. DirectMethod,
  86. DoublyRobust,
  87. ImportanceSampling,
  88. OffPolicyEstimator,
  89. WeightedImportanceSampling,
  90. )
  91. from ray.rllib.offline.offline_evaluator import OfflineEvaluator
  92. from ray.rllib.policy.policy import Policy, PolicySpec
  93. from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
  94. from ray.rllib.utils import FilterManager, deep_update, force_list
  95. from ray.rllib.utils.actor_manager import FaultTolerantActorManager
  96. from ray.rllib.utils.annotations import (
  97. DeveloperAPI,
  98. ExperimentalAPI,
  99. OldAPIStack,
  100. OverrideToImplementCustomLogic,
  101. OverrideToImplementCustomLogic_CallToSuperRecommended,
  102. PublicAPI,
  103. override,
  104. )
  105. from ray.rllib.utils.checkpoints import (
  106. CHECKPOINT_VERSION,
  107. CHECKPOINT_VERSION_LEARNER_AND_ENV_RUNNER,
  108. Checkpointable,
  109. get_checkpoint_info,
  110. try_import_msgpack,
  111. )
  112. from ray.rllib.utils.debug import update_global_seed_if_necessary
  113. from ray.rllib.utils.error import ERR_MSG_INVALID_ENV_DESCRIPTOR, EnvError
  114. from ray.rllib.utils.framework import try_import_tf
  115. from ray.rllib.utils.from_config import from_config
  116. from ray.rllib.utils.metrics import (
  117. AGGREGATOR_ACTOR_RESULTS,
  118. ALL_MODULES,
  119. DATASET_NUM_ITERS_EVALUATED,
  120. ENV_RUNNER_RESULTS,
  121. ENV_RUNNER_SAMPLING_TIMER,
  122. EPISODE_LEN_MEAN,
  123. EPISODE_RETURN_MEAN,
  124. EVALUATION_ITERATION_TIMER,
  125. EVALUATION_RESULTS,
  126. FAULT_TOLERANCE_STATS,
  127. LEARNER_RESULTS,
  128. LEARNER_UPDATE_TIMER,
  129. NUM_AGENT_STEPS_SAMPLED,
  130. NUM_AGENT_STEPS_SAMPLED_LIFETIME,
  131. NUM_AGENT_STEPS_SAMPLED_THIS_ITER,
  132. NUM_AGENT_STEPS_TRAINED,
  133. NUM_AGENT_STEPS_TRAINED_LIFETIME,
  134. NUM_ENV_STEPS_SAMPLED,
  135. NUM_ENV_STEPS_SAMPLED_FOR_EVALUATION_THIS_ITER,
  136. NUM_ENV_STEPS_SAMPLED_LIFETIME,
  137. NUM_ENV_STEPS_SAMPLED_THIS_ITER,
  138. NUM_ENV_STEPS_TRAINED,
  139. NUM_ENV_STEPS_TRAINED_LIFETIME,
  140. NUM_EPISODES,
  141. NUM_EPISODES_LIFETIME,
  142. NUM_TRAINING_STEP_CALLS_PER_ITERATION,
  143. OFFLINE_EVAL_RUNNER_RESULTS,
  144. OFFLINE_EVALUATION_ITERATION_TIMER,
  145. RESTORE_ENV_RUNNERS_TIMER,
  146. RESTORE_EVAL_ENV_RUNNERS_TIMER,
  147. RESTORE_OFFLINE_EVAL_RUNNERS_TIMER,
  148. STEPS_TRAINED_THIS_ITER_COUNTER,
  149. SYNCH_ENV_CONNECTOR_STATES_TIMER,
  150. SYNCH_EVAL_ENV_CONNECTOR_STATES_TIMER,
  151. SYNCH_WORKER_WEIGHTS_TIMER,
  152. TIMERS,
  153. TRAINING_ITERATION_TIMER,
  154. TRAINING_STEP_TIMER,
  155. )
  156. from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
  157. from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
  158. from ray.rllib.utils.metrics.ray_metrics import (
  159. DEFAULT_HISTOGRAM_BOUNDARIES_LONG_EVENTS,
  160. DEFAULT_HISTOGRAM_BOUNDARIES_SHORT_EVENTS,
  161. TimerAndPrometheusLogger,
  162. )
  163. from ray.rllib.utils.replay_buffers import MultiAgentReplayBuffer, ReplayBuffer
  164. from ray.rllib.utils.runners.runner_group import RunnerGroup
  165. from ray.rllib.utils.serialization import NOT_SERIALIZABLE, deserialize_type
  166. from ray.rllib.utils.spaces import space_utils
  167. from ray.rllib.utils.typing import (
  168. AgentConnectorDataType,
  169. AgentID,
  170. AgentToModuleMappingFn,
  171. AlgorithmConfigDict,
  172. EnvCreator,
  173. EnvInfoDict,
  174. EnvType,
  175. EpisodeID,
  176. ModuleID,
  177. PartialAlgorithmConfigDict,
  178. PolicyID,
  179. PolicyState,
  180. ResultDict,
  181. SampleBatchType,
  182. ShouldModuleBeUpdatedFn,
  183. StateDict,
  184. TensorStructType,
  185. TensorType,
  186. )
  187. from ray.train.constants import DEFAULT_STORAGE_PATH
  188. from ray.tune import Checkpoint
  189. from ray.tune.execution.placement_groups import PlacementGroupFactory
  190. from ray.tune.experiment.trial import ExportFormat
  191. from ray.tune.logger import Logger, UnifiedLogger
  192. from ray.tune.registry import ENV_CREATOR, _global_registry, get_trainable_cls
  193. from ray.tune.resources import Resources
  194. from ray.tune.result import TRAINING_ITERATION
  195. from ray.tune.trainable import Trainable
  196. from ray.util import log_once
  197. from ray.util.metrics import Counter, Histogram
  198. from ray.util.timer import _Timer
  199. if TYPE_CHECKING:  # Type-checking-only imports; the class below refers to these names via quoted annotations.
  200. from ray.rllib.core.learner.learner_group import LearnerGroup
  201. from ray.rllib.offline.offline_data import OfflineData
  202. tf1, tf, tfv = try_import_tf()  # Unpacks three values; presumably the TF1-compat module, the TF module, and the TF major version -- TODO confirm against try_import_tf.
  203. logger = logging.getLogger(__name__)  # Module-level logger named after this module.
  204. @PublicAPI
  205. class Algorithm(Checkpointable, Trainable):
  206. """An RLlib algorithm responsible for training one or more neural network models.
  207. You can write your own Algorithm classes by sub-classing from `Algorithm`
  208. or any of its built-in subclasses.
  209. Override the `training_step` method to implement your own algorithm logic.
  210. Find the various built-in `training_step()` methods for different algorithms in
  211. their respective [algo name].py files, for example:
  212. `ray.rllib.algorithms.dqn.dqn.py` or `ray.rllib.algorithms.impala.impala.py`.
  213. The most important API methods an Algorithm exposes are `train()` for running a
  214. single training iteration, `evaluate()` for running a single round of evaluation,
  215. `save_to_path()` for creating a checkpoint, and `restore_from_path()` for loading a
  216. state from an existing checkpoint.
  217. """
  218. #: The AlgorithmConfig instance of the Algorithm.
  219. config: Optional[AlgorithmConfig] = None
  220. #: The MetricsLogger instance of the Algorithm. RLlib uses this to log
  221. #: metrics from within the `training_step()` method. Users can use it to log
  222. #: metrics from within their custom Algorithm-based callbacks.
  223. metrics: Optional[MetricsLogger] = None
  224. #: The `EnvRunnerGroup` of the Algorithm. An `EnvRunnerGroup` is
  225. #: composed of a single local `EnvRunner` (see: `self.env_runner`), serving as
  226. #: the reference copy of the models to be trained and optionally one or more
  227. #: remote `EnvRunners` used to generate training samples from the RL
  228. #: environment, in parallel. EnvRunnerGroup is fault-tolerant and elastic. It
  229. #: tracks health states for all the managed remote EnvRunner actors. As a
  230. #: result, Algorithm should never access the underlying actor handles directly.
  231. #: Instead, always access them via all the foreach APIs with assigned IDs of
  232. #: the underlying EnvRunners.
  233. env_runner_group: Optional[EnvRunnerGroup] = None
  234. #: A special EnvRunnerGroup only used for evaluation, not to
  235. #: collect training samples.
  236. eval_env_runner_group: Optional[EnvRunnerGroup] = None
  237. #: The `LearnerGroup` instance of the Algorithm, managing either
  238. #: one local `Learner` or one or more remote `Learner` actors. Responsible for
  239. #: updating the models from RL environment (episode) data.
  240. learner_group: Optional["LearnerGroup"] = None
  241. #: An optional OfflineData instance, used for offline RL.
  242. offline_data: Optional["OfflineData"] = None
  243. # Whether to allow unknown top-level config keys.
  244. _allow_unknown_configs = False
  245. # List of top-level keys with value=dict, for which new sub-keys are
  246. # allowed to be added to the value dict.
  247. _allow_unknown_subkeys = [
  248. "tf_session_args",
  249. "local_tf_session_args",
  250. "env_config",
  251. "model",
  252. "optimizer",
  253. "custom_resources_per_env_runner",
  254. "custom_resources_per_worker",
  255. "evaluation_config",
  256. "exploration_config",
  257. "replay_buffer_config",
  258. "extra_python_environs_for_worker",
  259. "input_config",
  260. "output_config",
  261. ]
  262. # List of top level keys with value=dict, for which we always override the
  263. # entire value (dict), iff the "type" key in that value dict changes.
  264. _override_all_subkeys_if_type_changes = [
  265. "exploration_config",
  266. "replay_buffer_config",
  267. ]
  268. # List of keys that are always fully overridden if present in any dict or sub-dict
  269. _override_all_key_list = ["off_policy_estimation_methods", "policies"]
  270. _progress_metrics = (
  271. f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}",
  272. f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}",
  273. f"{NUM_ENV_STEPS_SAMPLED_LIFETIME}",
  274. f"{NUM_ENV_STEPS_TRAINED_LIFETIME}",
  275. f"{NUM_EPISODES_LIFETIME}",
  276. f"{ENV_RUNNER_RESULTS}/{EPISODE_LEN_MEAN}",
  277. )
  278. # Backward compatibility with old checkpoint system (now through the
  279. # `Checkpointable` API).
  280. METADATA_FILE_NAME = "rllib_checkpoint.json"
  281. STATE_FILE_NAME = "algorithm_state"
  @classmethod
  @override(Checkpointable)
  def from_checkpoint(
      cls,
      path: Union[str, Checkpoint],
      filesystem: Optional["pyarrow.fs.FileSystem"] = None,
      *,
      # @OldAPIStack
      policy_ids: Optional[Collection[PolicyID]] = None,
      policy_mapping_fn: Optional[Callable[[AgentID, EpisodeID], PolicyID]] = None,
      policies_to_train: Optional[
          Union[
              Collection[PolicyID],
              Callable[[PolicyID, Optional[SampleBatchType]], bool],
          ]
      ] = None,
      # deprecated args
      checkpoint=DEPRECATED_VALUE,
      **kwargs,
  ) -> "Algorithm":
      """Creates a new algorithm instance from a given checkpoint.

      Checkpoints with version >= 2.0 are restored through the superclass'
      `from_checkpoint()` implementation. Checkpoints with version >= 1.0 (and
      < 2.0) are converted into an algorithm state first and then restored via
      `Algorithm.from_state()`. Older checkpoints are rejected.

      Args:
          path: The path (str) to the checkpoint directory to use or a Ray Train
              Checkpoint instance to restore from.
          filesystem: PyArrow FileSystem to use to access data at the `path`. If not
              specified, this is inferred from the URI scheme of `path`.
          policy_ids: Optional list of PolicyIDs to recover. This allows users to
              restore an Algorithm with only a subset of the originally present
              Policies.
          policy_mapping_fn: An optional (updated) policy mapping function to use from
              here on.
          policies_to_train: An optional list of policy IDs to be trained or a
              callable taking PolicyID and SampleBatchType and returning a bool
              (trainable or not?). If None, will keep the existing setup in place.
              Policies, whose IDs are not in the list (or for which the callable
              returns False) will not be updated.
          checkpoint: Deprecated. Use `path` instead. Passing any value other than
              the default triggers `deprecation_warning(..., error=True)`.
          **kwargs: Only used for checkpoints with version >= 2.0, in which case
              they are forwarded to the superclass' `from_checkpoint()`.

      Returns:
          The instantiated Algorithm.

      Raises:
          ValueError: If the checkpoint has version 0.1 or any other version
              < 1.0, or if a `msgpack` checkpoint containing policies other than
              only the default policy is restored without `policy_mapping_fn`.
      """
      if checkpoint != DEPRECATED_VALUE:
          deprecation_warning(
              old="Algorithm.from_checkpoint(checkpoint=...)",
              new="Algorithm.from_checkpoint(path=...)",
              error=True,
          )

      checkpoint_info = get_checkpoint_info(path, filesystem)
      checkpoint_version = checkpoint_info["checkpoint_version"]

      # New API stack -> Use Checkpointable's default implementation.
      if checkpoint_version >= version.Version("2.0"):
          # `path` is a Checkpoint instance: Translate to directory and continue.
          if isinstance(path, Checkpoint):
              path = path.to_directory()
          return super().from_checkpoint(path, filesystem=filesystem, **kwargs)

      # Not possible for (v0.1) (algo class and config information missing
      # or very hard to retrieve).
      if checkpoint_version == version.Version("0.1"):
          # Note the leading space in the second literal: Without it, the two
          # sentences are glued together ("...`!In this case...").
          raise ValueError(
              "Cannot restore a v0 checkpoint using `Algorithm.from_checkpoint()`!"
              " In this case, do the following:\n"
              "1) Create a new Algorithm object using your original config.\n"
              "2) Call the `restore()` method of this algo object passing it"
              " your checkpoint dir or AIR Checkpoint object."
          )
      if checkpoint_version < version.Version("1.0"):
          raise ValueError(
              "`checkpoint_info['checkpoint_version']` in `Algorithm.from_checkpoint"
              "()` must be 1.0 or later! You are using a checkpoint with "
              f"version v{checkpoint_version}."
          )

      # This is a msgpack checkpoint.
      # User did not provide unserializable function with this call
      # (`policy_mapping_fn`). Note that if `policies_to_train` is None, it
      # defaults to training all policies (so it's ok to not provide this here).
      if checkpoint_info["format"] == "msgpack" and policy_mapping_fn is None:
          # Anything other than "only DEFAULT_POLICY_ID present" requires the
          # user to pass in the mapping function -> Provide meaningful error
          # message.
          if checkpoint_info["policy_ids"] != {DEFAULT_POLICY_ID}:
              raise ValueError(
                  "You are trying to restore a multi-agent algorithm from a "
                  "`msgpack` formatted checkpoint, which do NOT store the "
                  "`policy_mapping_fn` or `policies_to_train` "
                  "functions! Make sure that when using the "
                  "`Algorithm.from_checkpoint()` utility, you also pass the "
                  "args: `policy_mapping_fn` and `policies_to_train` with your "
                  "call. You might leave `policies_to_train=None` in case "
                  "you would like to train all policies anyways."
              )
          # Only DEFAULT_POLICY_ID present in this algorithm, provide the
          # default implementation of the mapping function.
          policy_mapping_fn = AlgorithmConfig.DEFAULT_POLICY_MAPPING_FN

      state = Algorithm._checkpoint_info_to_algorithm_state(
          checkpoint_info=checkpoint_info,
          policy_ids=policy_ids,
          policy_mapping_fn=policy_mapping_fn,
          policies_to_train=policies_to_train,
      )
      return Algorithm.from_state(state)
  @PublicAPI
  def __init__(
      self,
      config: Optional[AlgorithmConfig] = None,
      env=None,  # deprecated arg
      logger_creator: Optional[Callable[[], Logger]] = None,
      **kwargs,
  ):
      """Initializes an Algorithm instance.

      Resolves `config` into a validated, frozen AlgorithmConfig, derives the
      env creator, sets up a default logger creator (if none is given),
      initializes placeholder attributes, and finally calls the superclass
      constructor.

      Args:
          config: Algorithm-specific configuration object (or a config/state
              dict). If None, the result of `self.get_default_config()` is used
              without any overrides.
          env: Deprecated. If not None, a deprecation warning is issued and the
              value is set on the config via `config.environment(env)`.
          logger_creator: Callable that creates a ray.tune.Logger
              object. If unspecified, a default logger is created.
          **kwargs: Arguments passed to the Trainable base class.
      """
      default_config = self.get_default_config()

      # `config` defaults to None in the signature: Treat this as "no
      # overrides", rather than failing further below on `None.to_dict()`.
      if config is None:
          config = {}

      # Translate possible dict into an AlgorithmConfig object, as well as,
      # resolving generic config objects into specific ones (e.g. passing
      # an `AlgorithmConfig` super-class instance into a PPO constructor,
      # which normally would expect a PPOConfig object).
      if isinstance(config, dict):
          # A "class" key marks the dict as a config *state* (to be restored
          # through `from_state()`), not a plain settings dict.
          is_state_dict = "class" in config
          # `self.get_default_config()` also returned a dict ->
          # Last resort: Create core AlgorithmConfig from merged dicts.
          if isinstance(default_config, dict):
              if is_state_dict:
                  # The result must be assigned. Otherwise, `config` remains a
                  # dict and the attribute access on it below fails.
                  config = AlgorithmConfig.from_state(config)
              else:
                  config = AlgorithmConfig.from_dict(
                      config_dict=self.merge_algorithm_configs(
                          default_config, config, True
                      )
                  )
          # Default config is an AlgorithmConfig -> update its properties
          # from the given config dict.
          elif is_state_dict:
              config = default_config.from_state(config)
          else:
              config = default_config.update_from_dict(config)
      # Given AlgorithmConfig is not of the same type as the default config:
      # This could be the case e.g. if the user is building an algo from a
      # generic AlgorithmConfig() object.
      elif not isinstance(config, type(default_config)):
          config = default_config.update_from_dict(config.to_dict())
      else:
          config = default_config.from_state(config.get_state())

      # In case this algo is using a generic config (with no algo_class set), set it
      # here.
      if config.algo_class is None:
          config.algo_class = type(self)

      if env is not None:
          deprecation_warning(
              old=f"algo = Algorithm(env='{env}', ...)",
              new=f"algo = AlgorithmConfig().environment('{env}').build()",
              error=False,
          )
          config.environment(env)

      # Validate and freeze our AlgorithmConfig object (no more changes possible).
      config.validate()
      config.freeze()

      # Convert `env` provided in config into a concrete env creator callable, which
      # takes an EnvContext (config dict) as arg and returning an RLlib supported Env
      # type (e.g. a gym.Env).
      self._env_id, self.env_creator = self._get_env_id_and_creator(
          config.env, config
      )
      env_descr = (
          self._env_id.__name__ if isinstance(self._env_id, type) else self._env_id
      )

      # Placeholder for a local replay buffer instance.
      self.local_replay_buffer = None

      # Placeholder for our LearnerGroup responsible for updating the RLModule(s).
      self.learner_group: Optional["LearnerGroup"] = None

      # The Algorithm's `MetricsLogger` object to collect stats from all its
      # components (including timers, counters and other stats in its own
      # `training_step()` and other methods) as well as custom callbacks.
      self.metrics: MetricsLogger = MetricsLogger(
          root=True, stats_cls_lookup=config.stats_cls_lookup
      )

      # Create a default logger creator if no logger_creator is specified
      if logger_creator is None:
          # Default logdir prefix containing the agent's name and the
          # env id.
          timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
          # Replace path separators in the env description, so that it can be
          # used inside a directory name.
          env_descr_for_dir = re.sub("[/\\\\]", "-", str(env_descr))
          logdir_prefix = f"{type(self).__name__}_{env_descr_for_dir}_{timestr}"
          if not os.path.exists(DEFAULT_STORAGE_PATH):
              # Possible race condition if dir is created several times on
              # rollout workers
              os.makedirs(DEFAULT_STORAGE_PATH, exist_ok=True)
          logdir = tempfile.mkdtemp(prefix=logdir_prefix, dir=DEFAULT_STORAGE_PATH)

          # Allow users to more precisely configure the created logger
          # via "logger_config.type".
          if config.logger_config and "type" in config.logger_config:

              def default_logger_creator(config):
                  """Creates a custom logger with the default prefix."""
                  cfg = config["logger_config"].copy()
                  cls = cfg.pop("type")
                  # Provide default for logdir, in case the user does
                  # not specify this in the "logger_config" dict.
                  logdir_ = cfg.pop("logdir", logdir)
                  return from_config(cls=cls, _args=[cfg], logdir=logdir_)

          # If no `type` given, use tune's UnifiedLogger as last resort.
          else:

              def default_logger_creator(config):
                  """Creates a Unified logger with the default prefix."""
                  return UnifiedLogger(config, logdir, loggers=None)

          logger_creator = default_logger_creator

      # Metrics-related properties.
      self._timers = defaultdict(_Timer)
      self._counters = defaultdict(int)
      self._episode_history = []
      self._episodes_to_be_collected = []

      # The fully qualified AlgorithmConfig used for evaluation
      # (or None if evaluation not setup).
      self.evaluation_config: Optional[AlgorithmConfig] = None
      # Evaluation EnvRunnerGroup and metrics last returned by `self.evaluate()`.
      self.eval_env_runner_group: Optional[EnvRunnerGroup] = None

      # Ray metrics - Algorithm
      self._metrics_step_time: Optional[Histogram] = None
      self._metrics_run_one_training_iteration_time: Optional[Histogram] = None
      self._metrics_run_one_evaluation_time: Optional[Histogram] = None
      self._metrics_compile_iteration_results_time: Optional[Histogram] = None
      self._metrics_training_step_time: Optional[Histogram] = None
      self._metrics_evaluate_time: Optional[Histogram] = None
      self._metrics_evaluate_sync_env_runner_weights_time: Optional[Histogram] = None
      self._metrics_evaluate_sync_connector_states_time: Optional[Histogram] = None
      self._metrics_step_sync_env_runner_states_time: Optional[Histogram] = None
      self._metrics_load_checkpoint_time: Optional[Histogram] = None
      self._metrics_save_checkpoint_time: Optional[Histogram] = None

      # Ray metrics - Algorithm callbacks
      self._metrics_callback_on_train_result_time: Optional[Histogram] = None
      self._metrics_callback_on_evaluate_start_time: Optional[Histogram] = None
      self._metrics_callback_on_evaluate_end_time: Optional[Histogram] = None
      self._metrics_callback_on_evaluate_offline_start_time: Optional[
          Histogram
      ] = None
      self._metrics_callback_on_evaluate_offline_end_time: Optional[Histogram] = None

      # Ray metrics - IMPALA
      self._metrics_impala_training_step_time: Optional[Histogram] = None
      self._metrics_impala_training_step_aggregator_preprocessing_time: Optional[
          Histogram
      ] = None
      self._metrics_impala_training_step_learner_group_loop_time: Optional[
          Histogram
      ] = None
      self._metrics_impala_training_step_sync_env_runner_state_time: Optional[
          Histogram
      ] = None
      self._metrics_impala_sample_and_get_connector_states_time: Optional[
          Histogram
      ] = None
      self._metrics_impala_training_step_input_batches: Optional[Counter] = None
      self._metrics_impala_training_step_zero_input_batches: Optional[Counter] = None
      self._metrics_impala_training_step_env_steps_dropped: Optional[Counter] = None

      # Hand the resolved (frozen) config and the logger creator on to the
      # superclass constructor.
      super().__init__(
          config=config,
          logger_creator=logger_creator,
          **kwargs,
      )
  def _set_up_metrics(self):
      """Creates the Ray `Histogram` objects for this Algorithm's timing metrics.

      For each entry in the table below, a `Histogram` named
      `rllib_algorithm_<suffix>` is created, stored under the attribute
      `self._metrics_<suffix>`, and given the default tag
      `{"rllib": <name of this class>}`.
      """
      long_events = DEFAULT_HISTOGRAM_BOUNDARIES_LONG_EVENTS
      short_events = DEFAULT_HISTOGRAM_BOUNDARIES_SHORT_EVENTS

      # (metric suffix, description, histogram boundaries), in creation order.
      histogram_specs = (
          # Ray metrics - Algorithm
          ("step_time", "Time spent in Algorithm.step()", long_events),
          (
              "run_one_training_iteration_time",
              "Time spent in Algorithm._run_one_training_iteration()",
              long_events,
          ),
          (
              "run_one_evaluation_time",
              "Time spent in Algorithm._run_one_evaluation()",
              long_events,
          ),
          (
              "compile_iteration_results_time",
              "Time spent in Algorithm._compile_iteration_results()",
              short_events,
          ),
          (
              "training_step_time",
              "Time spent in Algorithm.training_step()",
              long_events,
          ),
          ("evaluate_time", "Time spent in Algorithm.evaluate()", long_events),
          (
              "evaluate_sync_env_runner_weights_time",
              "Time spent on syncing weights to the eval EnvRunners in the Algorithm.evaluate()",
              short_events,
          ),
          (
              "evaluate_sync_connector_states_time",
              "Time spent on syncing connector states in the Algorithm.evaluate()",
              short_events,
          ),
          (
              "step_sync_env_runner_states_time",
              "Time spent in sync_env_runner_states code block of the Algorithm.step()",
              short_events,
          ),
          (
              "load_checkpoint_time",
              "Time spent in Algorithm.load_checkpoint()",
              short_events,
          ),
          (
              "save_checkpoint_time",
              "Time spent in Algorithm.save_checkpoint()",
              short_events,
          ),
          # Ray metrics - Algorithm callbacks
          (
              "callback_on_train_result_time",
              "Time spent in callback 'on_train_result()'",
              short_events,
          ),
          (
              "callback_on_evaluate_start_time",
              "Time spent in callback 'on_evaluate_start()'",
              short_events,
          ),
          (
              "callback_on_evaluate_end_time",
              "Time spent in callback 'on_evaluate_end()'",
              short_events,
          ),
          (
              "callback_on_evaluate_offline_start_time",
              "Time spent in callback 'on_evaluate_offline_start()'",
              short_events,
          ),
          (
              "callback_on_evaluate_offline_end_time",
              "Time spent in callback 'on_evaluate_offline_end()'",
              short_events,
          ),
      )

      for suffix, description, boundaries in histogram_specs:
          histogram = Histogram(
              name="rllib_algorithm_" + suffix,
              description=description,
              boundaries=boundaries,
              tag_keys=("rllib",),
          )
          # Store the histogram on `self` first, then tag it (same order as
          # assigning the attribute and calling `set_default_tags()` on it).
          setattr(self, "_metrics_" + suffix, histogram)
          histogram.set_default_tags({"rllib": self.__class__.__name__})
  @OverrideToImplementCustomLogic
  @classmethod
  def get_default_config(cls) -> AlgorithmConfig:
      """Returns the default config object for this Algorithm class.

      Returns:
          A newly created, generic `AlgorithmConfig` instance.
      """
      default_config = AlgorithmConfig()
      return default_config
  @OverrideToImplementCustomLogic
  def _remote_worker_ids_for_metrics(self) -> List[int]:
      """Returns the IDs of the remote workers from which to fetch metrics.

      This default implementation returns whatever
      `self.env_runner_group.healthy_worker_ids()` reports. Specific Algorithm
      implementations can override this method to use only a subset of the
      workers for metrics collection.

      Returns:
          List of remote worker IDs to fetch metrics from.
      """
      env_runners = self.env_runner_group
      return env_runners.healthy_worker_ids()
  696. @OverrideToImplementCustomLogic_CallToSuperRecommended
  697. @override(Trainable)
  698. def setup(self, config: AlgorithmConfig) -> None:
  699. # Setup our config: Merge the user-supplied config dict (which could
  700. # be a partial config dict) with the class' default.
  701. if not isinstance(config, AlgorithmConfig):
  702. assert isinstance(config, PartialAlgorithmConfigDict)
  703. config_obj = self.get_default_config()
  704. if not isinstance(config_obj, AlgorithmConfig):
  705. assert isinstance(config, PartialAlgorithmConfigDict)
  706. config_obj = AlgorithmConfig().from_dict(config_obj)
  707. config_obj.update_from_dict(config)
  708. config_obj.env = self._env_id
  709. self.config = config_obj
  710. # Set Algorithm's seed after we have - if necessary - enabled
  711. # tf eager-execution.
  712. update_global_seed_if_necessary(self.config.framework_str, self.config.seed)
  713. self._record_usage(self.config)
  714. # Create the callbacks object.
  715. if self.config.enable_env_runner_and_connector_v2:
  716. self.callbacks = [cls() for cls in force_list(self.config.callbacks_class)]
  717. else:
  718. self.callbacks = self.config.callbacks_class()
  719. if self.config.log_level in ["WARN", "ERROR"]:
  720. logger.info(
  721. f"Current log_level is {self.config.log_level}. For more information, "
  722. "set 'log_level': 'INFO' / 'DEBUG' or use the -v and "
  723. "-vv flags."
  724. )
  725. if self.config.log_level:
  726. logging.getLogger("ray.rllib").setLevel(self.config.log_level)
  727. # Create local replay buffer if necessary.
  728. self.local_replay_buffer = self._create_local_replay_buffer_if_necessary(
  729. self.config
  730. )
  731. # Create a dict, mapping ActorHandles to sets of open remote
  732. # requests (object refs). This way, we keep track, of which actors
  733. # inside this Algorithm (e.g. a remote EnvRunner) have
  734. # already been sent how many (e.g. `sample()`) requests.
  735. self.remote_requests_in_flight: DefaultDict[
  736. ActorHandle, Set[ray.ObjectRef]
  737. ] = defaultdict(set)
  738. self.env_runner_group: Optional[EnvRunnerGroup] = None
  739. # In case there is no local EnvRunner anymore, we need to handle connector
  740. # pipelines directly here.
  741. self.spaces: Optional[Dict] = None
  742. self.env_to_module_connector: Optional[ConnectorPipelineV2] = None
  743. self.module_to_env_connector: Optional[ConnectorPipelineV2] = None
  744. # Offline RL settings.
  745. input_evaluation = self.config.get("input_evaluation")
  746. if input_evaluation is not None and input_evaluation is not DEPRECATED_VALUE:
  747. ope_dict = {str(ope): {"type": ope} for ope in input_evaluation}
  748. deprecation_warning(
  749. old="config.input_evaluation={}".format(input_evaluation),
  750. new="config.evaluation(evaluation_config=config.overrides("
  751. f"off_policy_estimation_methods={ope_dict}"
  752. "))",
  753. error=True,
  754. help="Running OPE during training is not recommended.",
  755. )
  756. self.config.off_policy_estimation_methods = ope_dict
  757. # If an input path is available and we are on the new API stack generate
  758. # an `OfflineData` instance.
  759. if self.config.is_offline:
  760. from ray.rllib.offline.offline_data import OfflineData
  761. # Use either user-provided `OfflineData` class or RLlib's default.
  762. offline_data_class = self.config.offline_data_class or OfflineData
  763. # Build the `OfflineData` class.
  764. self.offline_data = offline_data_class(self.config)
  765. # Otherwise set the attribute to `None`.
  766. else:
  767. self.offline_data = None
  768. if self.config.is_online or not self.config.enable_env_runner_and_connector_v2:
  769. # Create a set of env runner actors via a EnvRunnerGroup.
  770. self.env_runner_group = EnvRunnerGroup(
  771. env_creator=self.env_creator,
  772. validate_env=self.validate_env,
  773. default_policy_class=self.get_default_policy_class(self.config),
  774. config=self.config,
  775. # New API stack: User decides whether to create local env runner.
  776. # Old API stack: Always create local EnvRunner.
  777. local_env_runner=(
  778. True
  779. if not self.config.enable_env_runner_and_connector_v2
  780. else self.config.create_local_env_runner
  781. ),
  782. logdir=self.logdir,
  783. tune_trial_id=self.trial_id,
  784. )
  785. # Compile, validate, and freeze an evaluation config.
  786. self.evaluation_config = self.config.get_evaluation_config_object()
  787. self.evaluation_config.validate()
  788. self.evaluation_config.freeze()
  789. # Evaluation EnvRunnerGroup setup.
  790. # User would like to setup a separate evaluation worker set.
  791. # Note: We skip EnvRunnerGroup creation if we need to do offline evaluation.
  792. if self._should_create_evaluation_env_runners(self.evaluation_config):
  793. _, env_creator = self._get_env_id_and_creator(
  794. self.evaluation_config.env, self.evaluation_config
  795. )
  796. # Create a separate evaluation worker set for evaluation.
  797. # If evaluation_num_env_runners=0, use the evaluation set's local
  798. # worker for evaluation, otherwise, use its remote workers
  799. # (parallelized evaluation).
  800. self.eval_env_runner_group: EnvRunnerGroup = EnvRunnerGroup(
  801. env_creator=env_creator,
  802. validate_env=None,
  803. default_policy_class=self.get_default_policy_class(self.config),
  804. config=self.evaluation_config,
  805. logdir=self.logdir,
  806. tune_trial_id=self.trial_id,
  807. # New API stack: User decides whether to create local env runner.
  808. # Old API stack: Always create local EnvRunner.
  809. local_env_runner=(
  810. True
  811. if not self.evaluation_config.enable_env_runner_and_connector_v2
  812. else self.evaluation_config.create_local_env_runner
  813. ),
  814. pg_offset=self.config.num_env_runners,
  815. )
  816. if self.env_runner_group:
  817. self.spaces = self.env_runner_group.get_spaces()
  818. elif self.eval_env_runner_group:
  819. self.spaces = self.eval_env_runner_group.get_spaces()
  820. if self.env_runner is None and self.spaces is not None:
  821. self.env_to_module_connector = self.config.build_env_to_module_connector(
  822. spaces=self.spaces
  823. )
  824. self.module_to_env_connector = self.config.build_module_to_env_connector(
  825. spaces=self.spaces
  826. )
  827. self.evaluation_dataset = None
  828. if (
  829. self.evaluation_config.off_policy_estimation_methods
  830. and not self.evaluation_config.ope_split_batch_by_episode
  831. ):
  832. # the num worker is set to 0 to avoid creating shards. The dataset will not
  833. # be repartioned to num_workers blocks.
  834. logger.info("Creating evaluation dataset ...")
  835. self.evaluation_dataset, _ = get_dataset_and_shards(
  836. self.evaluation_config, num_workers=0
  837. )
  838. logger.info("Evaluation dataset created")
  839. self.reward_estimators: Dict[str, OffPolicyEstimator] = {}
  840. ope_types = {
  841. "is": ImportanceSampling,
  842. "wis": WeightedImportanceSampling,
  843. "dm": DirectMethod,
  844. "dr": DoublyRobust,
  845. }
  846. for name, method_config in self.config.off_policy_estimation_methods.items():
  847. method_type = method_config.pop("type")
  848. if method_type in ope_types:
  849. deprecation_warning(
  850. old=method_type,
  851. new=str(ope_types[method_type]),
  852. error=True,
  853. )
  854. method_type = ope_types[method_type]
  855. elif isinstance(method_type, str):
  856. logger.log(0, "Trying to import from string: " + method_type)
  857. mod, obj = method_type.rsplit(".", 1)
  858. mod = importlib.import_module(mod)
  859. method_type = getattr(mod, obj)
  860. if isinstance(method_type, type) and issubclass(
  861. method_type, OfflineEvaluator
  862. ):
  863. # TODO(kourosh) : Add an integration test for all these
  864. # offline evaluators.
  865. policy = self.get_policy()
  866. if issubclass(method_type, OffPolicyEstimator):
  867. method_config["gamma"] = self.config.gamma
  868. self.reward_estimators[name] = method_type(policy, **method_config)
  869. else:
  870. raise ValueError(
  871. f"Unknown off_policy_estimation type: {method_type}! Must be "
  872. "either a class path or a sub-class of ray.rllib."
  873. "offline.offline_evaluator::OfflineEvaluator"
  874. )
  875. # TODO (Rohan138): Refactor this and remove deprecated methods
  876. # Need to add back method_type in case Algorithm is restored from checkpoint
  877. method_config["type"] = method_type
  878. if self.config.enable_rl_module_and_learner:
  879. spaces = {
  880. INPUT_ENV_SPACES: (
  881. self.config.observation_space,
  882. self.config.action_space,
  883. )
  884. }
  885. if self.env_runner_group:
  886. spaces.update(self.spaces)
  887. elif self.eval_env_runner_group:
  888. spaces.update(self.eval_env_runner_group.get_spaces())
  889. else:
  890. # If the algorithm is online we use the spaces from as they are
  891. # provided.
  892. if self.config.is_online:
  893. spaces.update(
  894. {
  895. DEFAULT_MODULE_ID: (
  896. self.config.observation_space,
  897. self.config.action_space,
  898. ),
  899. }
  900. )
  901. # Otherwise, when we are offline we need to check, if the learner connector
  902. # is transforming the spaces.
  903. elif self.config.is_offline:
  904. # Build the learner connector with the input spaces from the environment.
  905. learner_connector = self.config.build_learner_connector(
  906. input_observation_space=spaces[INPUT_ENV_SPACES][0],
  907. input_action_space=spaces[INPUT_ENV_SPACES][1],
  908. )
  909. # Update the `spaces` dictionary by using the output spaces of the learner
  910. # connector pipeline.
  911. spaces.update(
  912. {
  913. DEFAULT_MODULE_ID: (
  914. learner_connector.observation_space,
  915. learner_connector.action_space,
  916. ),
  917. }
  918. )
  919. module_spec: MultiRLModuleSpec = self.config.get_multi_rl_module_spec(
  920. spaces=spaces,
  921. inference_only=False,
  922. )
  923. self.learner_group = self.config.build_learner_group(
  924. rl_module_spec=module_spec
  925. )
  926. # Check if there are modules to load from the `module_spec`.
  927. rl_module_ckpt_dirs = {}
  928. multi_rl_module_ckpt_dir = module_spec.load_state_path
  929. modules_to_load = module_spec.modules_to_load
  930. for module_id, sub_module_spec in module_spec.rl_module_specs.items():
  931. if sub_module_spec.load_state_path:
  932. rl_module_ckpt_dirs[module_id] = sub_module_spec.load_state_path
  933. if multi_rl_module_ckpt_dir or rl_module_ckpt_dirs:
  934. self.learner_group.load_module_state(
  935. multi_rl_module_ckpt_dir=multi_rl_module_ckpt_dir,
  936. modules_to_load=modules_to_load,
  937. rl_module_ckpt_dirs=rl_module_ckpt_dirs,
  938. )
  939. # Sync the weights from the learner group to the EnvRunners.
  940. rl_module_state = self.learner_group.get_state(
  941. components=COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE,
  942. inference_only=True,
  943. )[COMPONENT_LEARNER]
  944. if self.env_runner_group:
  945. self.env_runner_group.sync_env_runner_states(
  946. config=self.config,
  947. env_steps_sampled=self.metrics.peek(
  948. (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME), default=0
  949. ),
  950. rl_module_state=rl_module_state,
  951. env_to_module=self.env_to_module_connector,
  952. module_to_env=self.module_to_env_connector,
  953. )
  954. elif self.eval_env_runner_group:
  955. self.eval_env_runner_group.sync_env_runner_states(
  956. config=self.evaluation_config,
  957. env_steps_sampled=self.metrics.peek(
  958. (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME), default=0
  959. ),
  960. rl_module_state=rl_module_state,
  961. env_to_module=self.env_to_module_connector,
  962. module_to_env=self.module_to_env_connector,
  963. )
  964. # TODO (simon): Update modules in DataWorkers.
  965. if self.offline_data:
  966. # If the learners are remote we need to provide specific
  967. # information and the learner's actor handles.
  968. if self.learner_group.is_remote:
  969. # If learners run on different nodes, locality hints help
  970. # to use the nearest learner in the workers that do the
  971. # data preprocessing.
  972. learner_node_ids = self.learner_group.foreach_learner(
  973. lambda _: ray.get_runtime_context().get_node_id()
  974. )
  975. self.offline_data.locality_hints = [
  976. node_id.get() for node_id in learner_node_ids
  977. ]
  978. # Provide the actor handles for the learners for module
  979. # updating during preprocessing.
  980. self.offline_data.learner_handles = self.learner_group._workers
  981. # Otherwise we can simply pass in the local learner.
  982. else:
  983. self.offline_data.learner_handles = [self.learner_group._learner]
  984. # TODO (simon, sven): Replace these set-some-object's-attributes-
  985. # directly? We should find some solution for this in the future, an API,
  986. # or setting these in the OfflineData constructor?
  987. # Provide the module_spec. Note, in the remote case this is needed
  988. # because the learner module cannot be copied, but must be built.
  989. self.offline_data.module_spec = module_spec
  990. # Provide the `OfflineData` instance with space information. It might
  991. # need it for reading recorded experiences.
  992. self.offline_data.spaces = spaces
  993. if self._should_create_offline_evaluation_runners(self.evaluation_config):
  994. from ray.rllib.offline.offline_evaluation_runner_group import (
  995. OfflineEvaluationRunnerGroup,
  996. )
  997. # If no inference-only `RLModule` should be used in offline evaluation,
  998. # get the complete learner module.
  999. if not self.evaluation_config.offline_eval_rl_module_inference_only:
  1000. rl_module_state = self.learner_group.get_state(
  1001. components=COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE,
  1002. inference_only=False,
  1003. )[COMPONENT_LEARNER]
  1004. # Create the offline evaluation runner group.
  1005. self.offline_eval_runner_group: OfflineEvaluationRunnerGroup = OfflineEvaluationRunnerGroup(
  1006. config=self.evaluation_config,
  1007. # Do not create a local runner such that the dataset can be split.
  1008. local_runner=self.config.num_offline_eval_runners == 0,
  1009. # Provide the `RLModule`'s state for the `OfflinePreLearner`s.
  1010. module_state=rl_module_state[COMPONENT_RL_MODULE],
  1011. module_spec=module_spec,
  1012. # Note, even if no environment is run, the `MultiRLModule` needs
  1013. # spaces to construct the policy network.
  1014. spaces=spaces,
  1015. )
  1016. # Create an Aggregator actor set, if necessary.
  1017. self._aggregator_actor_manager = None
  1018. if self.config.enable_rl_module_and_learner and (
  1019. self.config.num_aggregator_actors_per_learner > 0
  1020. ):
  1021. rl_module_spec = self.config.get_multi_rl_module_spec(
  1022. spaces=self.spaces,
  1023. inference_only=False,
  1024. )
  1025. agg_cls = ray.remote(
  1026. num_cpus=1,
  1027. max_restarts=-1,
  1028. )(AggregatorActor)
  1029. self._aggregator_actor_manager = FaultTolerantActorManager(
  1030. [
  1031. agg_cls.remote(self.config, rl_module_spec)
  1032. for _ in range(
  1033. (self.config.num_learners or 1)
  1034. * self.config.num_aggregator_actors_per_learner
  1035. )
  1036. ],
  1037. max_remote_requests_in_flight_per_actor=(
  1038. self.config.max_requests_in_flight_per_aggregator_actor
  1039. ),
  1040. )
  1041. # Get the devices of each learner.
  1042. learner_locations = list(
  1043. enumerate(
  1044. self.learner_group.foreach_learner(
  1045. func=lambda _learner: (_learner.node, _learner.device),
  1046. )
  1047. )
  1048. )
  1049. # Get the devices of each AggregatorActor.
  1050. aggregator_locations = list(
  1051. enumerate(
  1052. self._aggregator_actor_manager.foreach_actor(
  1053. func=lambda actor: (actor._node, actor._device)
  1054. )
  1055. )
  1056. )
  1057. self._aggregator_actor_to_learner = {}
  1058. for agg_idx, aggregator_location in aggregator_locations:
  1059. aggregator_location = aggregator_location.get()
  1060. for learner_idx, learner_location in learner_locations:
  1061. # TODO (sven): Activate full comparison (including device) when Ray
  1062. # has figured out GPU pre-loading.
  1063. if learner_location.get()[0] == aggregator_location[0]:
  1064. # Round-robin, in case all Learners are on same device/node.
  1065. learner_locations = learner_locations[1:] + [
  1066. learner_locations[0]
  1067. ]
  1068. self._aggregator_actor_to_learner[agg_idx] = learner_idx
  1069. break
  1070. if agg_idx not in self._aggregator_actor_to_learner:
  1071. raise RuntimeError(
  1072. "No Learner worker found that matches aggregation worker "
  1073. f"#{agg_idx}'s node ({aggregator_location[0]}) and device "
  1074. f"({aggregator_location[1]})! The Learner workers' locations "
  1075. f"are {learner_locations}."
  1076. )
  1077. # Make sure, each Learner index is mapped to from at least one
  1078. # AggregatorActor.
  1079. if not all(
  1080. learner_idx in self._aggregator_actor_to_learner.values()
  1081. for learner_idx in range(self.config.num_learners or 1)
  1082. ):
  1083. raise RuntimeError(
  1084. "Some Learner indices are not mapped to from any AggregatorActors! "
  1085. "Final AggregatorActor idx -> Learner idx mapping is: "
  1086. f"{self._aggregator_actor_to_learner}"
  1087. )
  1088. # Ray metrics
  1089. self._set_up_metrics()
  1090. # Run `on_algorithm_init` callback after initialization is done.
  1091. make_callback(
  1092. "on_algorithm_init",
  1093. self.callbacks,
  1094. self.config.callbacks_on_algorithm_init,
  1095. kwargs=dict(
  1096. algorithm=self,
  1097. metrics_logger=self.metrics,
  1098. ),
  1099. )
  1100. @OverrideToImplementCustomLogic
  1101. @classmethod
  1102. def get_default_policy_class(
  1103. cls,
  1104. config: AlgorithmConfig,
  1105. ) -> Optional[Type[Policy]]:
  1106. """Returns a default Policy class to use, given a config.
  1107. This class will be used by an Algorithm in case
  1108. the policy class is not provided by the user in any single- or
  1109. multi-agent PolicySpec.
  1110. Note: This method is ignored when the RLModule API is enabled.
  1111. """
  1112. return None
  1113. @override(Trainable)
  1114. def step(self) -> ResultDict:
  1115. """Implements the main `Algorithm.train()` logic.
  1116. Takes n attempts to perform a single training step. Thereby
  1117. catches RayErrors resulting from worker failures. After n attempts,
  1118. fails gracefully.
  1119. Override this method in your Algorithm sub-classes if you would like to
  1120. handle worker failures yourself.
  1121. Otherwise, override only `training_step()` to implement the core
  1122. algorithm logic.
  1123. Returns:
  1124. The results dict with stats/infos on sampling, training,
  1125. and - if required - evaluation.
  1126. """
  1127. # Ray metrics
  1128. with TimerAndPrometheusLogger(self._metrics_step_time):
  1129. # Do we have to run `self.evaluate()` this iteration?
  1130. # `self.iteration` gets incremented after this function returns,
  1131. # meaning that e.g. the first time this function is called,
  1132. # self.iteration will be 0.
  1133. evaluate_this_iter = bool(
  1134. self.config.evaluation_interval
  1135. and (self.iteration + 1) % self.config.evaluation_interval == 0
  1136. )
  1137. evaluate_offline_this_iter = bool(
  1138. self.config.offline_evaluation_interval
  1139. and (self.iteration + 1) % self.config.offline_evaluation_interval == 0
  1140. )
  1141. # Results dict for training (and if appolicable: evaluation).
  1142. eval_results: ResultDict = {}
  1143. # Parallel eval + training: Kick off evaluation-loop and parallel train() call.
  1144. if evaluate_this_iter and (
  1145. self.config.evaluation_parallel_to_training
  1146. or self.config.offline_evaluation_parallel_to_training
  1147. ):
  1148. (
  1149. train_results,
  1150. eval_results,
  1151. train_iter_ctx,
  1152. ) = self._run_one_training_iteration_and_evaluation_in_parallel()
  1153. # - No evaluation necessary, just run the next training iteration.
  1154. # - We have to evaluate in this training iteration, but no parallelism ->
  1155. # evaluate after the training iteration is entirely done.
  1156. else:
  1157. if self.config.enable_env_runner_and_connector_v2:
  1158. train_results, train_iter_ctx = self._run_one_training_iteration()
  1159. else:
  1160. (
  1161. train_results,
  1162. train_iter_ctx,
  1163. ) = self._run_one_training_iteration_old_api_stack()
  1164. # Sequential: Train (already done above), then evaluate.
  1165. if evaluate_this_iter and not self.config.evaluation_parallel_to_training:
  1166. eval_results = self._run_one_evaluation(parallel_train_future=None)
  1167. if evaluate_offline_this_iter:
  1168. offline_eval_results = self._run_one_offline_evaluation()
  1169. # If we already have online evaluation results merge the offline
  1170. # evaluation results.
  1171. if eval_results:
  1172. eval_results[EVALUATION_RESULTS].update(
  1173. offline_eval_results[EVALUATION_RESULTS]
  1174. )
  1175. # Otherwise, just assign.
  1176. else:
  1177. eval_results = offline_eval_results
  1178. # Sync EnvRunner workers.
  1179. # TODO (sven): For the new API stack, the common execution pattern for any algo
  1180. # should be: [sample + get_metrics + get_state] -> send all these in one remote
  1181. # call down to `training_step` (where episodes are sent as ray object
  1182. # references). Then distribute the episode refs to the learners, store metrics
  1183. # in special key in result dict and perform the connector merge/broadcast
  1184. # inside the `training_step` as well. See the new IMPALA for an example.
  1185. if self.config.enable_env_runner_and_connector_v2:
  1186. if (
  1187. not self.config._dont_auto_sync_env_runner_states
  1188. and self.env_runner_group
  1189. ):
  1190. # Synchronize EnvToModule and ModuleToEnv connector states and broadcast
  1191. # new states back to all EnvRunners.
  1192. with self.metrics.log_time(
  1193. (TIMERS, SYNCH_ENV_CONNECTOR_STATES_TIMER)
  1194. ):
  1195. with TimerAndPrometheusLogger(
  1196. self._metrics_step_sync_env_runner_states_time
  1197. ):
  1198. self.env_runner_group.sync_env_runner_states(
  1199. config=self.config,
  1200. env_steps_sampled=self.metrics.peek(
  1201. (
  1202. ENV_RUNNER_RESULTS,
  1203. NUM_ENV_STEPS_SAMPLED_LIFETIME,
  1204. ),
  1205. default=0,
  1206. ),
  1207. env_to_module=self.env_to_module_connector,
  1208. module_to_env=self.module_to_env_connector,
  1209. )
  1210. # Compile final ResultDict from `train_results` and `eval_results`. Note
  1211. # that, as opposed to the old API stack, EnvRunner stats should already be
  1212. # in `train_results` and `eval_results`.
  1213. results = self._compile_iteration_results(
  1214. train_results=train_results,
  1215. eval_results=eval_results,
  1216. )
  1217. else:
  1218. self._sync_filters_if_needed(
  1219. central_worker=self.env_runner_group.local_env_runner,
  1220. workers=self.env_runner_group,
  1221. config=self.config,
  1222. )
  1223. # Get EnvRunner metrics and compile them into results.
  1224. episodes_this_iter = collect_episodes(
  1225. self.env_runner_group,
  1226. self._remote_worker_ids_for_metrics(),
  1227. timeout_seconds=self.config.metrics_episode_collection_timeout_s,
  1228. )
  1229. results = self._compile_iteration_results_old_api_stack(
  1230. episodes_this_iter=episodes_this_iter,
  1231. step_ctx=train_iter_ctx,
  1232. iteration_results={**train_results, **eval_results},
  1233. )
  1234. return results
  1235. @PublicAPI
  1236. def evaluate_offline(self) -> ResultDict:
  1237. """Evaluates current policy offline under `evaluation_config` settings.
  1238. Returns:
  1239. A ResultDict only containing the offline evaluation results from the current
  1240. iteration.
  1241. """
  1242. # First synchronize weights.
  1243. self.offline_eval_runner_group.sync_weights(
  1244. from_worker_or_learner_group=self.learner_group,
  1245. inference_only=self.config.offline_eval_rl_module_inference_only,
  1246. )
  1247. # TODO (simon): Check, how we can sync without a local runner. Also,
  1248. # connectors are in the data pipeline not directly in the runner applied.
  1249. # NOTE (simon): Connector synching must actually happen in the OfflinePreLearner/OfflinePreEvaluation
  1250. # if self.config.broadcast_offline_eval_runner_states:
  1251. # # TODO (simon): Create offline equivalent.
  1252. # with self.metrics.log_time(TIMERS, SYNCH_EVAL_ENV_CONNECTOR_STATES_TIMER):
  1253. # self.offline_eval_runner_group.sync_runner_states(
  1254. # from_runner=
  1255. # )
  1256. with TimerAndPrometheusLogger(
  1257. self._metrics_callback_on_evaluate_offline_start_time
  1258. ):
  1259. make_callback(
  1260. "on_evaluate_offline_start",
  1261. callbacks_objects=self.callbacks,
  1262. callbacks_functions=self.config.callbacks_on_evaluate_offline_start,
  1263. kwargs=dict(algorithm=self, metrics_logger=self.metrics),
  1264. )
  1265. # Evaluate with fixed duration.
  1266. if self.offline_eval_runner_group.num_healthy_remote_runners > 0:
  1267. self._evaluate_offline_with_fixed_duration()
  1268. else:
  1269. self._evaluate_offline_on_local_runner()
  1270. # Check, whether we have any results.
  1271. if log_once("no_offline_eval_results") and not self.metrics.peek(
  1272. (EVALUATION_RESULTS, OFFLINE_EVAL_RUNNER_RESULTS)
  1273. ):
  1274. logger.warning(
  1275. "No offline evaluation results found for this iteration. "
  1276. "This can happen if the offline evaluation runner(s) is/are not healthy."
  1277. )
  1278. # Peek the offline evaluation results from the metrics store.
  1279. eval_results = self.metrics.peek(
  1280. (EVALUATION_RESULTS, OFFLINE_EVAL_RUNNER_RESULTS),
  1281. default={},
  1282. latest_merged_only=True,
  1283. )
  1284. # Trigger `on_evaluate_offline_end` callback.
  1285. with TimerAndPrometheusLogger(
  1286. self._metrics_callback_on_evaluate_offline_end_time
  1287. ):
  1288. make_callback(
  1289. "on_evaluate_offline_end",
  1290. callbacks_objects=self.callbacks,
  1291. callbacks_functions=self.config.callbacks_on_evaluate_offline_end,
  1292. kwargs=dict(
  1293. algorithm=self,
  1294. metrics_logger=self.metrics,
  1295. evaluation_metrics=eval_results,
  1296. ),
  1297. )
  1298. # Also return the results here for convenience.
  1299. return {OFFLINE_EVAL_RUNNER_RESULTS: eval_results}
  1300. @PublicAPI
  1301. def evaluate(
  1302. self,
  1303. parallel_train_future: Optional[concurrent.futures.ThreadPoolExecutor] = None,
  1304. ) -> ResultDict:
  1305. """Evaluates current policy under `evaluation_config` settings.
  1306. Args:
  1307. parallel_train_future: In case, we are training and avaluating in parallel,
  1308. this arg carries the currently running ThreadPoolExecutor object that
  1309. runs the training iteration. Use `parallel_train_future.done()` to
  1310. check, whether the parallel training job has completed and
  1311. `parallel_train_future.result()` to get its return values.
  1312. Returns:
  1313. A ResultDict only containing the evaluation results from the current
  1314. iteration.
  1315. """
  1316. with TimerAndPrometheusLogger(self._metrics_evaluate_time):
  1317. # Call the `_before_evaluate` hook.
  1318. self._before_evaluate()
  1319. if self.evaluation_dataset is not None:
  1320. return self._run_offline_evaluation_old_api_stack()
  1321. if self.config.enable_env_runner_and_connector_v2:
  1322. if (
  1323. self.env_runner_group is not None
  1324. and self.env_runner_group.healthy_env_runner_ids()
  1325. ):
  1326. # TODO (sven): Replace this with a new ActorManager API:
  1327. # try_remote_request_till_success("get_state") -> tuple(int,
  1328. # remoteresult)
  1329. weights_src = self.env_runner_group._worker_manager._actors[
  1330. self.env_runner_group.healthy_env_runner_ids()[0]
  1331. ]
  1332. else:
  1333. weights_src = self.learner_group
  1334. else:
  1335. weights_src = self.env_runner
  1336. # Sync weights to the evaluation EnvRunners.
  1337. if self.eval_env_runner_group is not None:
  1338. with TimerAndPrometheusLogger(
  1339. self._metrics_evaluate_sync_env_runner_weights_time
  1340. ):
  1341. self.eval_env_runner_group.sync_weights(
  1342. from_worker_or_learner_group=weights_src,
  1343. inference_only=True,
  1344. )
  1345. # Merge (eval) EnvRunner states and broadcast the merged state back
  1346. # to the remote (eval) EnvRunner actors.
  1347. if self.config.enable_env_runner_and_connector_v2:
  1348. if self.evaluation_config.broadcast_env_runner_states:
  1349. with self.metrics.log_time(
  1350. (TIMERS, SYNCH_EVAL_ENV_CONNECTOR_STATES_TIMER)
  1351. ):
  1352. with TimerAndPrometheusLogger(
  1353. self._metrics_evaluate_sync_connector_states_time
  1354. ):
  1355. self.eval_env_runner_group.sync_env_runner_states(
  1356. config=self.evaluation_config,
  1357. from_worker=self.env_runner,
  1358. env_steps_sampled=self.metrics.peek(
  1359. (
  1360. ENV_RUNNER_RESULTS,
  1361. NUM_ENV_STEPS_SAMPLED_LIFETIME,
  1362. ),
  1363. default=0,
  1364. ),
  1365. env_to_module=self.env_to_module_connector,
  1366. module_to_env=self.module_to_env_connector,
  1367. )
  1368. else:
  1369. self._sync_filters_if_needed(
  1370. central_worker=self.env_runner_group.local_env_runner,
  1371. workers=self.eval_env_runner_group,
  1372. config=self.evaluation_config,
  1373. )
  1374. # Sync weights to the local EnvRunner (if no eval EnvRunnerGroup).
  1375. elif self.config.enable_env_runner_and_connector_v2:
  1376. self.env_runner_group.sync_weights(
  1377. from_worker_or_learner_group=weights_src,
  1378. inference_only=True,
  1379. )
  1380. with TimerAndPrometheusLogger(
  1381. self._metrics_callback_on_evaluate_start_time
  1382. ):
  1383. make_callback(
  1384. "on_evaluate_start",
  1385. callbacks_objects=self.callbacks,
  1386. callbacks_functions=self.config.callbacks_on_evaluate_start,
  1387. kwargs=dict(algorithm=self, metrics_logger=self.metrics),
  1388. )
  1389. env_steps = agent_steps = 0
  1390. batches = []
  1391. # We will use a user provided evaluation function.
  1392. if self.config.custom_evaluation_function:
  1393. if self.config.enable_env_runner_and_connector_v2:
  1394. (
  1395. eval_results,
  1396. env_steps,
  1397. agent_steps,
  1398. ) = self._evaluate_with_custom_eval_function()
  1399. else:
  1400. eval_results = self.config.custom_evaluation_function()
  1401. # There is no eval EnvRunnerGroup -> Run on local EnvRunner.
  1402. elif self.eval_env_runner_group is None and self.env_runner:
  1403. (
  1404. eval_results,
  1405. env_steps,
  1406. agent_steps,
  1407. batches,
  1408. ) = self._evaluate_on_local_env_runner(self.env_runner)
  1409. # There is only a local eval EnvRunner -> Run on that.
  1410. elif self.eval_env_runner_group.num_healthy_remote_workers() == 0:
  1411. (
  1412. eval_results,
  1413. env_steps,
  1414. agent_steps,
  1415. batches,
  1416. ) = self._evaluate_on_local_env_runner(self.eval_env_runner)
  1417. # There are healthy remote evaluation workers -> Run on these.
  1418. elif self.eval_env_runner_group.num_healthy_remote_workers() > 0:
  1419. # Running in automatic duration mode (parallel with training step).
  1420. if self.config.evaluation_duration == "auto":
  1421. assert parallel_train_future is not None
  1422. (
  1423. eval_results,
  1424. env_steps,
  1425. agent_steps,
  1426. batches,
  1427. ) = self._evaluate_with_auto_duration(parallel_train_future)
  1428. # Running with a fixed amount of data to sample.
  1429. else:
  1430. (
  1431. eval_results,
  1432. env_steps,
  1433. agent_steps,
  1434. batches,
  1435. ) = self._evaluate_with_fixed_duration()
  1436. # Can't find a good way to run this evaluation -> Wait for next iteration.
  1437. else:
  1438. eval_results = {}
  1439. if self.config.enable_env_runner_and_connector_v2:
  1440. if log_once("no_eval_results") and not self.metrics.peek(
  1441. (EVALUATION_RESULTS, ENV_RUNNER_RESULTS)
  1442. ):
  1443. logger.warning(
  1444. "No evaluation results found for this iteration. This can happen if the evaluation worker(s) is/are not healthy."
  1445. )
  1446. # Peek the results here from the metrics store if requested.
  1447. eval_results = self.metrics.peek(
  1448. key=EVALUATION_RESULTS,
  1449. default={},
  1450. latest_merged_only=True,
  1451. )
  1452. else:
  1453. eval_results = {ENV_RUNNER_RESULTS: eval_results}
  1454. eval_results[NUM_AGENT_STEPS_SAMPLED_THIS_ITER] = agent_steps
  1455. eval_results[NUM_ENV_STEPS_SAMPLED_THIS_ITER] = env_steps
  1456. eval_results["timesteps_this_iter"] = env_steps
  1457. self._counters[
  1458. NUM_ENV_STEPS_SAMPLED_FOR_EVALUATION_THIS_ITER
  1459. ] = env_steps
  1460. # Compute off-policy estimates
  1461. if not self.config.custom_evaluation_function:
  1462. estimates = defaultdict(list)
  1463. # for each batch run the estimator's fwd pass
  1464. for name, estimator in self.reward_estimators.items():
  1465. for batch in batches:
  1466. estimate_result = estimator.estimate(
  1467. batch,
  1468. split_batch_by_episode=self.config.ope_split_batch_by_episode,
  1469. )
  1470. estimates[name].append(estimate_result)
  1471. # collate estimates from all batches
  1472. if estimates:
  1473. eval_results["off_policy_estimator"] = {}
  1474. for name, estimate_list in estimates.items():
  1475. avg_estimate = tree.map_structure(
  1476. lambda *x: np.mean(x, axis=0), *estimate_list
  1477. )
  1478. eval_results["off_policy_estimator"][name] = avg_estimate
  1479. # Trigger `on_evaluate_end` callback.
  1480. with TimerAndPrometheusLogger(self._metrics_callback_on_evaluate_end_time):
  1481. make_callback(
  1482. "on_evaluate_end",
  1483. callbacks_objects=self.callbacks,
  1484. callbacks_functions=self.config.callbacks_on_evaluate_end,
  1485. kwargs=dict(
  1486. algorithm=self,
  1487. metrics_logger=self.metrics,
  1488. evaluation_metrics=eval_results,
  1489. ),
  1490. )
  1491. # Also return the results here for convenience.
  1492. return eval_results
  1493. def _evaluate_with_custom_eval_function(self) -> Tuple[ResultDict, int, int]:
  1494. logger.info(
  1495. f"Evaluating current state of {self} using the custom eval function "
  1496. f"{self.config.custom_evaluation_function}"
  1497. )
  1498. if self.config.enable_env_runner_and_connector_v2:
  1499. (
  1500. eval_results,
  1501. env_steps,
  1502. agent_steps,
  1503. ) = self.config.custom_evaluation_function(self, self.eval_env_runner_group)
  1504. if not env_steps or not agent_steps:
  1505. raise ValueError(
  1506. "Custom eval function must return "
  1507. "`Tuple[ResultDict, int, int]` with `int, int` being "
  1508. f"`env_steps` and `agent_steps`! Got {env_steps}, {agent_steps}."
  1509. )
  1510. else:
  1511. eval_results = self.config.custom_evaluation_function()
  1512. if not eval_results or not isinstance(eval_results, dict):
  1513. raise ValueError(
  1514. "Custom eval function must return "
  1515. f"dict of metrics! Got {eval_results}."
  1516. )
  1517. return eval_results, env_steps, agent_steps
  1518. def _evaluate_offline_on_local_runner(self):
  1519. # How many episodes/timesteps do we need to run?
  1520. unit = "batches"
  1521. duration = (
  1522. self.config.offline_evaluation_duration
  1523. * self.config.dataset_num_iters_per_eval_runner
  1524. )
  1525. logger.info(f"Evaluating current state of {self} for {duration} {unit}.")
  1526. results = self.offline_eval_runner_group.local_runner.run()
  1527. self.metrics.aggregate(
  1528. [results],
  1529. key=(EVALUATION_RESULTS, OFFLINE_EVAL_RUNNER_RESULTS),
  1530. )
  1531. def _evaluate_on_local_env_runner(self, env_runner):
  1532. if hasattr(env_runner, "input_reader") and env_runner.input_reader is None:
  1533. raise ValueError(
  1534. "Can't evaluate on a local worker if this local worker does not have "
  1535. "an environment!\nTry one of the following:"
  1536. "\n1) Set `evaluation_interval` > 0 to force creating a separate "
  1537. "evaluation EnvRunnerGroup.\n2) Set `create_local_env_runner=True` to "
  1538. "force the local (non-eval) EnvRunner to have an environment to "
  1539. "evaluate on."
  1540. )
  1541. elif self.config.evaluation_parallel_to_training:
  1542. raise ValueError(
  1543. "Cannot run on local evaluation worker parallel to training! Try "
  1544. "setting `evaluation_parallel_to_training=False`."
  1545. )
  1546. # How many episodes/timesteps do we need to run?
  1547. unit = self.config.evaluation_duration_unit
  1548. duration = self.config.evaluation_duration
  1549. eval_cfg = self.evaluation_config
  1550. env_steps = agent_steps = 0
  1551. logger.info(f"Evaluating current state of {self} for {duration} {unit}.")
  1552. all_batches = []
  1553. if self.config.enable_env_runner_and_connector_v2:
  1554. episodes = env_runner.sample(
  1555. num_timesteps=duration if unit == "timesteps" else None,
  1556. num_episodes=duration if unit == "episodes" else None,
  1557. )
  1558. agent_steps += sum(e.agent_steps() for e in episodes)
  1559. env_steps += sum(e.env_steps() for e in episodes)
  1560. elif unit == "episodes":
  1561. for _ in range(duration):
  1562. batch = env_runner.sample()
  1563. agent_steps += batch.agent_steps()
  1564. env_steps += batch.env_steps()
  1565. if self.reward_estimators:
  1566. all_batches.append(batch)
  1567. else:
  1568. batch = env_runner.sample()
  1569. agent_steps += batch.agent_steps()
  1570. env_steps += batch.env_steps()
  1571. if self.reward_estimators:
  1572. all_batches.append(batch)
  1573. env_runner_results = env_runner.get_metrics()
  1574. if not self.config.enable_env_runner_and_connector_v2:
  1575. env_runner_results = summarize_episodes(
  1576. env_runner_results,
  1577. env_runner_results,
  1578. keep_custom_metrics=eval_cfg.keep_per_episode_custom_metrics,
  1579. )
  1580. else:
  1581. self.metrics.aggregate(
  1582. [env_runner_results],
  1583. key=(EVALUATION_RESULTS, ENV_RUNNER_RESULTS),
  1584. )
  1585. env_runner_results = None
  1586. return env_runner_results, env_steps, agent_steps, all_batches
    def _evaluate_with_auto_duration(self, parallel_train_future):
        """Evaluates on remote eval EnvRunners while the parallel training runs.

        Keeps issuing sampling requests to the eval EnvRunnerGroup for at least
        one round and for as long as `parallel_train_future.done()` is False
        and at least one remote eval worker is healthy.

        Args:
            parallel_train_future: Object whose `done()` method is polled to
                find out whether the parallel training step has finished.

        Returns:
            A tuple `(env_runner_results, env_steps, agent_steps, all_batches)`.
            On the new API stack the metrics are aggregated into `self.metrics`
            and `env_runner_results` is None; on the old API stack it is the
            output of `summarize_episodes`. `all_batches` is only filled on the
            old API stack and only if reward estimators are configured.
        """
        logger.info(
            f"Evaluating current state of {self} for as long as the parallelly "
            "running training step takes."
        )

        all_metrics = []
        all_batches = []

        # Number of currently healthy remote eval workers.
        num_healthy_workers = self.eval_env_runner_group.num_healthy_remote_workers()
        # Do we have to force-reset the EnvRunners before the first round of `sample()`
        # calls.?
        force_reset = self.config.evaluation_force_reset_envs_before_iteration

        # Remote function used on healthy EnvRunners to sample, get metrics, and
        # step counts.
        def _env_runner_remote(worker, num, round, iter):
            # Sample AND get_metrics, but only return metrics (and steps actually taken)
            # to save time.
            episodes = worker.sample(
                num_timesteps=num, force_reset=force_reset and round == 0
            )
            metrics = worker.get_metrics()
            env_steps = sum(e.env_steps() for e in episodes)
            agent_steps = sum(e.agent_steps() for e in episodes)
            # `iter` is passed through unchanged so the caller can tell which
            # algorithm iteration a result belongs to.
            return env_steps, agent_steps, metrics, iter

        env_steps = agent_steps = 0

        # Mean training iteration time, used below to estimate how many
        # timesteps to request per worker.
        if self.config.enable_env_runner_and_connector_v2:
            train_mean_time = self.metrics.peek(
                (TIMERS, TRAINING_ITERATION_TIMER), default=0.0
            )
        else:
            train_mean_time = self._timers[TRAINING_ITERATION_TIMER].mean
        t0 = time.time()
        algo_iteration = self.iteration

        _round = -1
        while (
            # In case all the remote evaluation workers die during a round of
            # evaluation, we need to stop.
            num_healthy_workers > 0
            # Run at least for one round AND at least for as long as the parallel
            # training step takes.
            and (_round == -1 or not parallel_train_future.done())
        ):
            _round += 1

            # New API stack -> EnvRunners return Episodes.
            if self.config.enable_env_runner_and_connector_v2:
                # Compute rough number of timesteps it takes for a single EnvRunner
                # to occupy the estimated (parallelly running) train step.
                throughput_estimate = self.metrics.peek(
                    (
                        EVALUATION_RESULTS,
                        ENV_RUNNER_RESULTS,
                        NUM_ENV_STEPS_SAMPLED_LIFETIME,
                    ),
                    throughput=True,
                    # Note (artur): Peeking throughputs of lifetime metrics
                    # results in a dictionary with both throughputs (since last
                    # restore and total). We only need the throughput since last
                    # restore here.
                    default={"throughput_since_last_restore": 0.0},
                )["throughput_since_last_restore"]
                _num = min(
                    # Clamp number of steps to take between a max and a min.
                    self.config.evaluation_auto_duration_max_env_steps_per_sample,
                    max(
                        self.config.evaluation_auto_duration_min_env_steps_per_sample,
                        (
                            # How much time do we have left?
                            (train_mean_time - (time.time() - t0))
                            # Multiply by our own (eval) throughput to get the timesteps
                            # to do (per worker).
                            * throughput_estimate
                            / num_healthy_workers
                        ),
                    ),
                )

                results = (
                    self.eval_env_runner_group.foreach_env_runner_async_fetch_ready(
                        func=_env_runner_remote,
                        kwargs={"num": _num, "round": _round, "iter": algo_iteration},
                        tag="_env_runner_remote",
                    )
                )
                for env_s, ag_s, metrics, iter in results:
                    # Ignore eval results kicked off in an earlier iteration.
                    # (those results would be outdated and thus misleading).
                    if iter != self.iteration:
                        continue
                    env_steps += env_s
                    agent_steps += ag_s
                    all_metrics.append(metrics)
                # Short pause between rounds of requests.
                time.sleep(0.01)

            # Old API stack -> RolloutWorkers return batches.
            else:
                results = (
                    self.eval_env_runner_group.foreach_env_runner_async_fetch_ready(
                        func=lambda w: (w.sample(), w.get_metrics(), algo_iteration),
                        tag="env_runner_sample_and_get_metrics",
                    )
                )
                for batch, metrics, iter in results:
                    # Ignore results that belong to a different iteration.
                    if iter != self.iteration:
                        continue
                    env_steps += batch.env_steps()
                    agent_steps += batch.agent_steps()
                    all_metrics.extend(metrics)
                    if self.reward_estimators:
                        # TODO: (kourosh) This approach will cause an OOM issue when
                        # the dataset gets huge (should be ok for now).
                        all_batches.append(batch)

            # Update correct number of healthy remote workers.
            num_healthy_workers = (
                self.eval_env_runner_group.num_healthy_remote_workers()
            )

        if num_healthy_workers == 0:
            logger.warning(
                "Calling `sample()` on your remote evaluation worker(s) "
                "resulted in all workers crashing! Make sure a) your environment is not"
                " too unstable, b) you have enough evaluation workers "
                "(`config.evaluation(evaluation_num_env_runners=...)`) to cover for "
                "occasional losses, and c) you use the `config.fault_tolerance("
                "restart_failed_env_runners=True)` setting."
            )

        # Old API stack: Summarize the collected metrics directly.
        if not self.config.enable_env_runner_and_connector_v2:
            env_runner_results = summarize_episodes(
                all_metrics,
                all_metrics,
                keep_custom_metrics=(
                    self.evaluation_config.keep_per_episode_custom_metrics
                ),
            )
            num_episodes = env_runner_results[NUM_EPISODES]
        # New API stack: Aggregate into the metrics store; nothing is returned
        # as `env_runner_results`.
        else:
            self.metrics.aggregate(
                all_metrics,
                key=(EVALUATION_RESULTS, ENV_RUNNER_RESULTS),
            )
            num_episodes = self.metrics.peek(
                (EVALUATION_RESULTS, ENV_RUNNER_RESULTS, NUM_EPISODES),
                default=0,
            )
            env_runner_results = None

        # Warn if results are empty, it could be that this is because the auto-time is
        # not enough to run through one full episode.
        if (
            self.config.evaluation_force_reset_envs_before_iteration
            and num_episodes == 0
        ):
            logger.warning(
                "This evaluation iteration resulted in an empty set of episode summary "
                "results! It's possible that the auto-duration time (roughly the mean "
                "time it takes for the training step to finish) is not enough to finish"
                " even a single episode. Your current mean training iteration time is "
                f"{train_mean_time}sec. Try setting the min iteration time to a higher "
                "value via the `config.reporting(min_time_s_per_iteration=...)` OR you "
                "can also set `config.evaluation_force_reset_envs_before_iteration` to "
                "False. However, keep in mind that then the evaluation results may "
                "contain some episode stats generated with earlier weights versions."
            )

        return env_runner_results, env_steps, agent_steps, all_batches
  def _evaluate_offline_with_fixed_duration(self) -> None:
      """Runs offline evaluation on the remote runners for a fixed duration.

      Repeatedly sends asynchronous `run()` requests to the offline evaluation
      runners and collects whatever results are ready. The loop ends when
      `offline_evaluation_duration * num_offline_eval_runners` dataset
      iterations have been evaluated, when no result has arrived for more than
      `offline_evaluation_timeout_s` seconds, or when no healthy remote runner
      is left.

      The collected metrics are aggregated into `self.metrics` under the key
      `(EVALUATION_RESULTS, OFFLINE_EVAL_RUNNER_RESULTS)`.
      """
      num_workers = self.config.num_offline_eval_runners
      time_out = self.config.offline_evaluation_timeout_s

      def _offline_eval_runner_remote(runner, iteration):
          # Return the iteration alongside the metrics, so that outdated
          # results (from an earlier iteration) can be discarded.
          return runner.run(), iteration

      algo_iteration = self.iteration
      # The request is identical in every round -> build it only once.
      remote_fn = functools.partial(
          _offline_eval_runner_remote,
          iteration=algo_iteration,
      )

      all_metrics = []
      # How many dataset iterations have we run (across all eval runners)?
      num_units_done = 0
      num_healthy_workers = self.offline_eval_runner_group.num_healthy_remote_runners

      # TODO (simon): Note, agent steps might not be available, but only
      # module steps.
      t_last_result = time.time()

      # In case all the remote evaluation workers die during a round of
      # evaluation, we need to stop.
      while num_healthy_workers > 0:
          units_left_to_do = (
              self.config.offline_evaluation_duration * num_workers - num_units_done
          )
          if units_left_to_do <= 0:
              break

          self.offline_eval_runner_group.foreach_runner_async(func=remote_fn)
          results = self.offline_eval_runner_group.fetch_ready_async_reqs(
              return_obj_refs=False, timeout_seconds=0.01
          )

          # Make sure we properly time out if we have not received any results
          # for more than `time_out` seconds.
          time_now = time.time()
          if not results and time_now - t_last_result > time_out:
              break
          elif results:
              t_last_result = time_now

          for _, (met, result_iteration) in results:
              # Discard outdated results (from a slow runner).
              if result_iteration != self.iteration:
                  continue
              all_metrics.append(met)
              # Note, the `dataset_num_iters_per_eval_runner` must be smaller than
              # `offline_evaluation_duration` // `num_offline_eval_runners`.
              num_units_done += (
                  met[ALL_MODULES][DATASET_NUM_ITERS_EVALUATED].peek()
                  if DATASET_NUM_ITERS_EVALUATED in met[ALL_MODULES]
                  else 0
              )

          # Update correct number of healthy remote workers.
          num_healthy_workers = (
              self.offline_eval_runner_group.num_healthy_remote_runners
          )

      if num_healthy_workers == 0:
          logger.warning(
              "Calling `run()` on your remote offline evaluation runner(s) "
              "resulted in all runners crashing! Make sure a) your dataset is not"
              " corrupted, b) you have enough offline evaluation runners "
              "(`config.evaluation(num_offline_eval_runners=...)`) to cover for "
              "occasional losses, and c) you use the `config.fault_tolerance("
              "restart_failed_offline_eval_runners=True)` setting."
          )

      self.metrics.aggregate(
          all_metrics,
          key=(EVALUATION_RESULTS, OFFLINE_EVAL_RUNNER_RESULTS),
      )
  def _evaluate_with_fixed_duration(self):
      """Runs evaluation on the remote eval EnvRunners for a fixed duration.

      Keeps requesting samples from the healthy remote evaluation workers. The
      loop ends when `config.evaluation_duration` units (episodes or timesteps,
      per `config.evaluation_duration_unit`) have been collected, when no
      result has arrived for more than `config.evaluation_sample_timeout_s`
      seconds, or when no healthy remote worker is left. Results tagged with
      an iteration other than the current one are discarded.

      Returns:
          A tuple `(env_runner_results, env_steps, agent_steps, all_batches)`.
          On the old API stack, `env_runner_results` is the output of
          `summarize_episodes()`. On the new API stack it is None, and the
          metrics are aggregated into `self.metrics` under
          `(EVALUATION_RESULTS, ENV_RUNNER_RESULTS)` instead. `all_batches` is
          only filled on the old API stack and only if `self.reward_estimators`
          is truthy.
      """
      # How many episodes/timesteps do we need to run?
      unit = self.config.evaluation_duration_unit
      eval_cfg = self.evaluation_config
      num_workers = self.config.evaluation_num_env_runners
      force_reset = self.config.evaluation_force_reset_envs_before_iteration
      time_out = self.config.evaluation_sample_timeout_s
      count_env_steps = self.config.count_steps_by == "env_steps"

      # Remote function used on healthy EnvRunners to sample, get metrics, and
      # step counts.
      def _env_runner_remote(worker, num, round_idx, iteration, _force_reset):
          # Sample AND get_metrics, but only return metrics (and steps actually
          # taken) to save time. Also return the iteration to check, whether we
          # should discard an outdated result (from a slow worker).
          episodes = worker.sample(
              num_timesteps=(
                  num[worker.worker_index] if unit == "timesteps" else None
              ),
              num_episodes=(num[worker.worker_index] if unit == "episodes" else None),
              # Only force-reset in the very first round of this evaluation.
              force_reset=_force_reset and round_idx == 0,
          )
          metrics = worker.get_metrics()
          env_steps = sum(e.env_steps() for e in episodes)
          agent_steps = sum(e.agent_steps() for e in episodes)
          return env_steps, agent_steps, metrics, iteration

      all_metrics = []
      all_batches = []

      # How many episodes/timesteps have we run (across all eval workers)?
      num_units_done = 0
      num_healthy_workers = self.eval_env_runner_group.num_healthy_remote_workers()

      env_steps = agent_steps = 0

      t_last_result = time.time()
      _round = -1
      algo_iteration = self.iteration

      # In case all the remote evaluation workers die during a round of
      # evaluation, we need to stop.
      while num_healthy_workers > 0:
          units_left_to_do = self.config.evaluation_duration - num_units_done
          if units_left_to_do <= 0:
              break

          _round += 1

          # New API stack -> EnvRunners return Episodes.
          if self.config.enable_env_runner_and_connector_v2:
              # Spread the remaining units as evenly as possible over the
              # workers.
              _num = [None] + [  # [None]: skip idx=0 (local worker)
                  (units_left_to_do // num_healthy_workers)
                  + bool(i <= (units_left_to_do % num_healthy_workers))
                  for i in range(1, num_workers + 1)
              ]
              results = (
                  self.eval_env_runner_group.foreach_env_runner_async_fetch_ready(
                      func=_env_runner_remote,
                      kwargs={
                          "num": _num,
                          "round_idx": _round,
                          "iteration": algo_iteration,
                          "_force_reset": force_reset,
                      },
                      tag="_env_runner_remote",
                  )
              )
              # Make sure we properly time out if we have not received any
              # results for more than `time_out` seconds.
              time_now = time.time()
              if not results and time_now - t_last_result > time_out:
                  break
              elif results:
                  t_last_result = time_now

              for env_s, ag_s, met, result_iteration in results:
                  if result_iteration != self.iteration:
                      continue
                  env_steps += env_s
                  agent_steps += ag_s
                  all_metrics.append(met)
                  num_units_done += (
                      (met[NUM_EPISODES].peek() if NUM_EPISODES in met else 0)
                      if unit == "episodes"
                      else (env_s if count_env_steps else ag_s)
                  )
          # Old API stack -> RolloutWorkers return batches.
          else:
              units_per_healthy_remote_worker = (
                  1
                  if unit == "episodes"
                  else eval_cfg.rollout_fragment_length
                  * eval_cfg.num_envs_per_env_runner
              )
              # Select proper number of evaluation workers for this round.
              selected_eval_worker_ids = [
                  worker_id
                  for i, worker_id in enumerate(
                      self.eval_env_runner_group.healthy_worker_ids()
                  )
                  if i * units_per_healthy_remote_worker < units_left_to_do
              ]
              results = (
                  self.eval_env_runner_group.foreach_env_runner_async_fetch_ready(
                      func=lambda w: (w.sample(), w.get_metrics(), algo_iteration),
                      remote_worker_ids=selected_eval_worker_ids,
                      tag="env_runner_sample_and_get_metrics",
                  )
              )
              # Make sure we properly time out if we have not received any
              # results for more than `time_out` seconds.
              time_now = time.time()
              if not results and time_now - t_last_result > time_out:
                  break
              elif results:
                  t_last_result = time_now

              # Only results of the current iteration count towards the
              # requested duration; outdated ones are dropped entirely.
              num_fresh_results = 0
              for batch, metrics, result_iteration in results:
                  if result_iteration != self.iteration:
                      continue
                  num_fresh_results += 1
                  env_steps += batch.env_steps()
                  agent_steps += batch.agent_steps()
                  all_metrics.extend(metrics)
                  if self.reward_estimators:
                      # TODO: (kourosh) This approach will cause an OOM issue when
                      # the dataset gets huge (should be ok for now).
                      all_batches.append(batch)

              # 1 episode per returned batch.
              if unit == "episodes":
                  num_units_done += num_fresh_results
              # n timesteps per returned batch.
              else:
                  num_units_done = env_steps if count_env_steps else agent_steps

          # Update correct number of healthy remote workers.
          num_healthy_workers = (
              self.eval_env_runner_group.num_healthy_remote_workers()
          )

      if num_healthy_workers == 0:
          logger.warning(
              "Calling `sample()` on your remote evaluation worker(s) "
              "resulted in all workers crashing! Make sure a) your environment is not"
              " too unstable, b) you have enough evaluation workers "
              "(`config.evaluation(evaluation_num_env_runners=...)`) to cover for "
              "occasional losses, and c) you use the `config.fault_tolerance("
              "restart_failed_env_runners=True)` setting."
          )

      if not self.config.enable_env_runner_and_connector_v2:
          env_runner_results = summarize_episodes(
              all_metrics,
              all_metrics,
              keep_custom_metrics=(
                  self.evaluation_config.keep_per_episode_custom_metrics
              ),
          )
          num_episodes = env_runner_results[NUM_EPISODES]
      else:
          self.metrics.aggregate(
              all_metrics,
              key=(EVALUATION_RESULTS, ENV_RUNNER_RESULTS),
          )
          num_episodes = self.metrics.peek(
              (EVALUATION_RESULTS, ENV_RUNNER_RESULTS, NUM_EPISODES),
              default=0,
              latest_merged_only=True,
          )
          env_runner_results = None

      # Warn if results are empty, it could be that this is because the eval
      # timesteps are not enough to run through one full episode.
      if num_episodes == 0:
          logger.warning(
              "This evaluation iteration resulted in an empty set of episode summary "
              "results! It's possible that your configured duration timesteps are not"
              " enough to finish even a single episode. You have configured "
              f"{self.config.evaluation_duration} "
              f"{self.config.evaluation_duration_unit}. For 'timesteps', try "
              "increasing this value via the `config.evaluation(evaluation_duration="
              "...)` OR change the unit to 'episodes' via `config.evaluation("
              "evaluation_duration_unit='episodes')` OR try increasing the timeout "
              "threshold via `config.evaluation(evaluation_sample_timeout_s=...)` OR "
              "you can also set `config.evaluation_force_reset_envs_before_iteration`"
              " to False. However, keep in mind that in the latter case, the "
              "evaluation results may contain some episode stats generated with "
              "earlier weights versions."
          )

      return env_runner_results, env_steps, agent_steps, all_batches
  @OverrideToImplementCustomLogic
  def restore_env_runners(self, env_runner_group: EnvRunnerGroup) -> List[int]:
      """Try bringing back unhealthy EnvRunners and - if successful - sync with local.

      Algorithms that use custom EnvRunners may override this method to
      disable the default, and create custom restoration logics. Note that
      "restoring" does not include the actual restarting process, but merely
      what should happen after such a restart of a (previously failed) worker.

      Args:
          env_runner_group: The EnvRunnerGroup to restore. This may be the
              training or the evaluation EnvRunnerGroup.

      Returns:
          A list of EnvRunner indices that have been restored during the call of
          this method.
      """
      # This is really cheap, since probe_unhealthy_env_runners() is a no-op
      # if there are no unhealthy workers.
      restored = None
      if self.config.is_online:
          restored = env_runner_group.probe_unhealthy_env_runners()

      if not restored:
          return []

      # Count the restored workers.
      self._counters["total_num_restored_workers"] += len(restored)

      # Only set (and only read by `_sync_env_runner`) for multi-agent setups
      # on the new API stack.
      multi_rl_module_spec = None

      from_env_runner = env_runner_group.local_env_runner or self.env_runner
      # Sync from local EnvRunner, if it exists.
      if from_env_runner is not None:
          # Get the state of the EnvRunner.
          state = from_env_runner.get_state()

          # Take out (old) connector states from local worker's state.
          if not self.config.enable_env_runner_and_connector_v2:
              for pol_states in state["policy_states"].values():
                  pol_states.pop("connector_configs", None)
          elif self.config.is_multi_agent:
              multi_rl_module_spec = MultiRLModuleSpec.from_module(
                  from_env_runner.module
              )
      # Otherwise, build the state from the LearnerGroup and the Algorithm's
      # own connectors.
      else:
          multi_rl_module_spec = (
              self.learner_group.foreach_learner(
                  lambda learner: MultiRLModuleSpec.from_module(learner.module)
              )
              .result_or_errors[0]
              .get()
          )
          # Sync the weights from the learner group to the EnvRunners.
          state = self.learner_group.get_state(
              components=COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE,
              inference_only=True,
          )[COMPONENT_LEARNER]
          state[
              COMPONENT_ENV_TO_MODULE_CONNECTOR
          ] = self.env_to_module_connector.get_state()
          state[
              COMPONENT_MODULE_TO_ENV_CONNECTOR
          ] = self.module_to_env_connector.get_state()
          state[NUM_ENV_STEPS_SAMPLED_LIFETIME] = self.metrics.peek(
              (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME),
              default=0,
          )

      # Put the state into the object store only AFTER all modifications were
      # made: `ray.put()` stores a snapshot, so later changes to `state` would
      # not reach the restored EnvRunners.
      state_ref = ray.put(state)

      def _sync_env_runner(er):  # noqa
          # Bring the set of modules in line with the reference spec (new API
          # stack, multi-agent only).
          if (
              er.config.enable_env_runner_and_connector_v2
              and er.config.is_multi_agent
          ):
              # Remove modules, if necessary. Iterate over a copy, because
              # modules are removed while looping.
              for module_id in er.module._rl_modules.copy():
                  if module_id not in multi_rl_module_spec.rl_module_specs:
                      er.module.remove_module(module_id, raise_err_if_not_found=True)
              # Add modules, if necessary.
              for mid, mod_spec in multi_rl_module_spec.rl_module_specs.items():
                  if mid not in er.module:
                      er.module.add_module(mid, mod_spec.build(), override=False)
          # Now that the MultiRLModule is fixed, update the state.
          er.set_state(ray.get(state_ref))

      # By default, entire local EnvRunner state is synced after restoration
      # to bring the previously failed EnvRunner up to date.
      env_runner_group.foreach_env_runner(
          func=_sync_env_runner,
          remote_worker_ids=restored,
          # Don't update the local EnvRunner, b/c it's the one we are synching
          # from.
          local_env_runner=False,
          timeout_seconds=self.config.env_runner_restore_timeout_s,
      )

      return restored
  @OverrideToImplementCustomLogic
  def restore_offline_eval_runners(self, runner_group: RunnerGroup) -> List[int]:
      """Probes unhealthy offline eval runners and re-syncs the restored ones.

      Restored runners receive the state of the first healthy runner of the
      group and get their dataset iterator re-assigned.

      Args:
          runner_group: The group of offline evaluation runners to restore.

      Returns:
          An empty list if `runner_group` or its local runner is missing,
          otherwise the result of `runner_group.probe_unhealthy_runners()`.
      """
      if not runner_group or not runner_group.local_runner:
          return []

      restored = runner_group.probe_unhealthy_runners()
      # Nothing came back -> nothing to sync.
      if not restored:
          return restored

      # Count the restored workers.
      self._counters["total_num_restored_workers"] += len(restored)
      restore_timeout_s = self.evaluation_config.offline_eval_runner_restore_timeout_s

      # Get the state of the correct (reference) worker.
      # NOTE(review): a single ID (not a list) is passed as `remote_worker_ids`
      # here - confirm that `foreach_runner` accepts this.
      reference_runner_id = runner_group.healthy_runner_ids()[0]
      reference_state = runner_group.foreach_runner(
          "get_state",
          local_runner=False,
          remote_worker_ids=reference_runner_id,
      )[0]
      state_ref = ray.put(reference_state)

      def _sync_runner(r):
          r.set_state(ray.get(state_ref))

      # By default, entire `Runner` state is synced after restoration
      # to bring the previously failed `Runner` up to date.
      runner_group.foreach_runner(
          func=_sync_runner,
          remote_worker_ids=restored,
          # Don't update the local `Runner`.
          local_runner=False,
          timeout_seconds=restore_timeout_s,
      )

      # Restore the correct data iterator split stream.
      # TODO (simon): Define a `restore` method in the `RunnerGroup`
      # such that we do not have to check here for the group.
      # Also get a different streaming split if a runner fails and is not
      # recreated.
      # NOTE(review): `_offline_data_iterators` is indexed with the whole
      # `restored` collection - verify this container supports that.
      iterator_kwargs = {"iterator": runner_group._offline_data_iterators[restored]}
      runner_group.foreach_runner(
          func="set_dataset_iterator",
          remote_worker_ids=restored,
          local_runner=False,
          timeout_seconds=restore_timeout_s,
          kwargs=iterator_kwargs,
      )

      return restored
  @OverrideToImplementCustomLogic
  def training_step(self) -> None:
      """Default single iteration logic of an algorithm (new API stack only).

      - Collects episodes in parallel from the Algorithm's EnvRunners until
        `config.total_train_batch_size` steps (agent- or env-steps, depending on
        `config.count_steps_by`) have been sampled.
      - Aggregates the EnvRunners' metrics under `ENV_RUNNER_RESULTS`.
      - Updates the LearnerGroup on the collected episodes and aggregates the
        results under `LEARNER_RESULTS`.
      - Syncs the weights of those RLModules that were actually trained from
        the LearnerGroup to the EnvRunners.

      Returns:
          None. Results are compiled and extracted automatically through a
          single `self.metrics.reduce()` call at the very end of an iteration
          (which might contain more than one call to `training_step()`). This
          way, we make sure that we account for all results generated by each
          individual `training_step()` call.

      Raises:
          NotImplementedError: If the old API stack is configured
              (`config.enable_env_runner_and_connector_v2` is False).
      """
      if not self.config.enable_env_runner_and_connector_v2:
          raise NotImplementedError(
              "The `Algorithm.training_step()` default implementation no longer "
              "supports the old API stack! If you would like to continue "
              "using these "
              "old APIs with this default `training_step`, simply subclass "
              "`Algorithm` and override its `training_step` method (copy/paste the "
              "code and delete this error message)."
          )

      # Collect a list of Episodes from EnvRunners until we reach the train batch
      # size.
      with self.metrics.log_time((TIMERS, ENV_RUNNER_SAMPLING_TIMER)):
          # The only difference between counting by agent- or env-steps is the
          # name of the step-limit argument.
          step_limit_kwarg = (
              "max_agent_steps"
              if self.config.count_steps_by == "agent_steps"
              else "max_env_steps"
          )
          episodes, env_runner_results = synchronous_parallel_sample(
              worker_set=self.env_runner_group,
              sample_timeout_s=self.config.sample_timeout_s,
              _uses_new_env_runners=True,
              _return_metrics=True,
              **{step_limit_kwarg: self.config.total_train_batch_size},
          )
          # Reduce EnvRunner metrics over the n EnvRunners.
          self.metrics.aggregate(env_runner_results, key=ENV_RUNNER_RESULTS)

      with self.metrics.log_time((TIMERS, LEARNER_UPDATE_TIMER)):
          learner_results = self.learner_group.update(
              episodes=episodes,
              timesteps={
                  # The EnvRunner metrics (including the lifetime step count)
                  # were aggregated under `ENV_RUNNER_RESULTS` above.
                  NUM_ENV_STEPS_SAMPLED_LIFETIME: (
                      self.metrics.peek(
                          (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME)
                      )
                  ),
              },
          )
          self.metrics.aggregate(learner_results, key=LEARNER_RESULTS)

      # Update weights - after learning on the local worker - on all
      # remote workers (only those RLModules that were actually trained).
      with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)):
          self.env_runner_group.sync_weights(
              from_worker_or_learner_group=self.learner_group,
              policies=list(set(learner_results.keys()) - {ALL_MODULES}),
              inference_only=True,
          )
  @PublicAPI
  def get_module(self, module_id: ModuleID = DEFAULT_MODULE_ID) -> Optional[RLModule]:
      """Returns the (single-agent) RLModule with `module_id` (None if ID not found).

      Args:
          module_id: ID of the (single-agent) RLModule to return from the
              MultiRLModule used by the local EnvRunner.

      Returns:
          The RLModule found under the ModuleID key inside the local EnvRunner's
          MultiRLModule. None if `module_id` doesn't exist. If the EnvRunner's
          module is not a MultiRLModule, that module itself is returned.
      """
      local_env_runner = self.env_runner
      if local_env_runner is None:
          # No local EnvRunner -> fetch the module from the remote EnvRunner
          # with index 1.
          fetched_modules = self.env_runner_group.foreach_env_runner(
              lambda er: er.module,
              remote_worker_ids=[1],
              local_env_runner=False,
          )
          module = fetched_modules[0]
      else:
          module = local_env_runner.module

      return module.get(module_id) if isinstance(module, MultiRLModule) else module
  @PublicAPI
  def add_module(
      self,
      module_id: ModuleID,
      module_spec: RLModuleSpec,
      *,
      config_overrides: Optional[Dict] = None,
      new_agent_to_module_mapping_fn: Optional[AgentToModuleMappingFn] = None,
      new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None,
      add_to_learners: bool = True,
      add_to_env_runners: bool = True,
      add_to_eval_env_runners: bool = True,
  ) -> MultiRLModuleSpec:
      """Adds a new (single-agent) RLModule to this Algorithm's MARLModule.

      Note that an Algorithm has up to 3 different components to which to add
      the new module to: The LearnerGroup (with n Learners), the EnvRunnerGroup
      (with m EnvRunners plus a local one) and - if applicable - the eval
      EnvRunnerGroup (with o EnvRunners plus a local one).

      Args:
          module_id: ID of the RLModule to add to the MARLModule.
              IMPORTANT: Must not contain characters that
              are also not allowed in Unix/Win filesystems, such as: `<>:"/|?*`,
              or a dot, space or backslash at the end of the ID.
          module_spec: The SingleAgentRLModuleSpec to use for constructing the
              new RLModule.
          config_overrides: The `AlgorithmConfig` overrides that should apply to
              the new Module, if any.
          new_agent_to_module_mapping_fn: An optional (updated) AgentID to
              ModuleID mapping function to use from here on. Note that already
              ongoing episodes will not change their mapping but will use the
              old mapping till the end of the episode.
          new_should_module_be_updated: An optional sequence of ModuleIDs or a
              callable taking ModuleID and SampleBatchType and returning whether
              the ModuleID should be updated (trained).
              If None, will keep the existing setup in place. RLModules,
              whose IDs are not in the list (or for which the callable
              returns False) will not be updated.
          add_to_learners: Whether to add the new RLModule to the LearnerGroup
              (with its n Learners).
          add_to_env_runners: Whether to add the new RLModule to the
              EnvRunnerGroup (with its m EnvRunners plus the local one).
          add_to_eval_env_runners: Whether to add the new RLModule to the eval
              EnvRunnerGroup (with its o EnvRunners plus the local one).

      Returns:
          The new MultiRLModuleSpec (after the RLModule has been added).

      Raises:
          RuntimeError: If this Algorithm is not set up as multi-agent.
          ValueError: If none of the three `add_to_...` flags is set.
      """
      validate_module_id(module_id, error=True)

      if not self.config.is_multi_agent:
          raise RuntimeError(
              "Can't add a new RLModule to a single-agent setup! Make sure that your "
              "setup is already initially multi-agent by either defining >1 "
              "RLModules in your `rl_module_spec` or assigning a ModuleID other "
              f"than {DEFAULT_MODULE_ID} to your (only) RLModule."
          )

      if not (add_to_learners or add_to_env_runners or add_to_eval_env_runners):
          raise ValueError(
              "At least one of `add_to_learners`, `add_to_env_runners`, or "
              "`add_to_eval_env_runners` must be set to True!"
          )

      # The to-be-returned new MultiRLModuleSpec.
      multi_rl_module_spec = None

      # Add to Learners and sync weights.
      if add_to_learners:
          multi_rl_module_spec = self.learner_group.add_module(
              module_id=module_id,
              module_spec=module_spec,
              config_overrides=config_overrides,
              new_should_module_be_updated=new_should_module_be_updated,
          )

          # Change our config (AlgorithmConfig) to contain the new Module (this
          # uses the spec just returned by the LearnerGroup).
          # TODO (sven): This is a hack to manipulate the AlgorithmConfig directly,
          # but we'll deprecate config.policies soon anyway.
          self.config._is_frozen = False
          self.config.policies[module_id] = PolicySpec()
          if config_overrides is not None:
              self.config.multi_agent(
                  algorithm_config_overrides_per_module={module_id: config_overrides}
              )
          if new_agent_to_module_mapping_fn is not None:
              self.config.multi_agent(policy_mapping_fn=new_agent_to_module_mapping_fn)
          self.config.rl_module(rl_module_spec=multi_rl_module_spec)
          if new_should_module_be_updated is not None:
              self.config.multi_agent(policies_to_train=new_should_module_be_updated)
          self.config.freeze()

      def _add(_env_runner, _module_spec=module_spec):
          # Add the RLModule to the existing one on the EnvRunner.
          _env_runner.module.add_module(
              module_id=module_id, module=_module_spec.build()
          )
          # Update the `agent_to_module_mapping_fn` on the EnvRunner.
          if new_agent_to_module_mapping_fn is not None:
              _env_runner.config.multi_agent(
                  policy_mapping_fn=new_agent_to_module_mapping_fn,
              )
          # Update the `should_module_be_updated` on the EnvRunner. Note that
          # even though this information is typically not needed by the
          # EnvRunner, it's good practice to keep this setting updated
          # everywhere either way.
          if new_should_module_be_updated is not None:
              _env_runner.config.multi_agent(
                  policies_to_train=new_should_module_be_updated,
              )
          return MultiRLModuleSpec.from_module(_env_runner.module)

      def _add_to_group(group, current_spec):
          # Adds the module on all EnvRunners of `group`, syncs the weights from
          # the LearnerGroup and returns the spec to report to the caller.
          specs = group.foreach_env_runner(_add)
          if current_spec is None:
              current_spec = specs[0]
          group.sync_weights(
              from_worker_or_learner_group=self.learner_group,
              inference_only=True,
          )
          return current_spec

      # Add to (training) EnvRunners and sync weights.
      if add_to_env_runners:
          multi_rl_module_spec = _add_to_group(
              self.env_runner_group, multi_rl_module_spec
          )

      # Add to eval EnvRunners and sync weights.
      if add_to_eval_env_runners is True and self.eval_env_runner_group is not None:
          multi_rl_module_spec = _add_to_group(
              self.eval_env_runner_group, multi_rl_module_spec
          )

      return multi_rl_module_spec
  @PublicAPI
  def remove_module(
      self,
      module_id: ModuleID,
      *,
      new_agent_to_module_mapping_fn: Optional[AgentToModuleMappingFn] = None,
      new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None,
      remove_from_learners: bool = True,
      remove_from_env_runners: bool = True,
      remove_from_eval_env_runners: bool = True,
  ) -> Optional[MultiRLModuleSpec]:
      """Removes an existing (single-agent) RLModule from this Algorithm's MARLModule.

      Args:
          module_id: ID of the RLModule to remove from the MARLModule.
              IMPORTANT: Must not contain characters that
              are also not allowed in Unix/Win filesystems, such as: `<>:"/|?*`,
              or a dot, space or backslash at the end of the ID.
          new_agent_to_module_mapping_fn: An optional (updated) AgentID to
              ModuleID mapping function to use from here on. Note that already
              ongoing episodes will not change their mapping but will use the
              old mapping till the end of the episode.
          new_should_module_be_updated: An optional sequence of ModuleIDs or a
              callable taking ModuleID and SampleBatchType and returning whether
              the ModuleID should be updated (trained).
              If None, will keep the existing setup in place. RLModules,
              whose IDs are not in the list (or for which the callable
              returns False) will not be updated.
          remove_from_learners: Whether to remove the RLModule from the
              LearnerGroup (with its n Learners).
          remove_from_env_runners: Whether to remove the RLModule from the
              EnvRunnerGroup (with its m EnvRunners plus the local one).
          remove_from_eval_env_runners: Whether to remove the RLModule from the
              eval EnvRunnerGroup (with its o EnvRunners plus the local one).

      Returns:
          The new MultiRLModuleSpec (after the RLModule has been removed). None,
          if the module was not removed from any component.
      """
      # The to-be-returned new MultiRLModuleSpec.
      multi_rl_module_spec = None

      # Remove RLModule from the LearnerGroup.
      if remove_from_learners:
          multi_rl_module_spec = self.learner_group.remove_module(
              module_id=module_id,
              new_should_module_be_updated=new_should_module_be_updated,
          )

          # Change our config (AlgorithmConfig) with the Module removed (this
          # uses the spec just returned by the LearnerGroup).
          # TODO (sven): This is a hack to manipulate the AlgorithmConfig directly,
          # but we'll deprecate config.policies soon anyway.
          self.config._is_frozen = False
          del self.config.policies[module_id]
          self.config.algorithm_config_overrides_per_module.pop(module_id, None)
          if new_agent_to_module_mapping_fn is not None:
              self.config.multi_agent(policy_mapping_fn=new_agent_to_module_mapping_fn)
          self.config.rl_module(rl_module_spec=multi_rl_module_spec)
          if new_should_module_be_updated is not None:
              self.config.multi_agent(policies_to_train=new_should_module_be_updated)
          self.config.freeze()

      def _remove(_env_runner):
          # Remove the RLModule from the existing one on the EnvRunner.
          _env_runner.module.remove_module(module_id=module_id)
          # Update the `agent_to_module_mapping_fn` on the EnvRunner.
          if new_agent_to_module_mapping_fn is not None:
              _env_runner.config.multi_agent(
                  policy_mapping_fn=new_agent_to_module_mapping_fn
              )
          # Update the `should_module_be_updated` on the EnvRunner as well, the
          # same way `add_module()` does, to keep this setting consistent
          # everywhere.
          if new_should_module_be_updated is not None:
              _env_runner.config.multi_agent(
                  policies_to_train=new_should_module_be_updated,
              )
          # Force reset all ongoing episodes on the EnvRunner to avoid having
          # different ModuleIDs compute actions for the same AgentID in the same
          # episode.
          # TODO (sven): Create an API for this.
          _env_runner._needs_initial_reset = True

          return MultiRLModuleSpec.from_module(_env_runner.module)

      def _remove_from_group(group, current_spec):
          # Removes the module on all EnvRunners of `group`, syncs the weights
          # from the LearnerGroup and returns the spec to report to the caller.
          specs = group.foreach_env_runner(_remove)
          if current_spec is None:
              current_spec = specs[0]
          group.sync_weights(
              from_worker_or_learner_group=self.learner_group,
              inference_only=True,
          )
          return current_spec

      # Remove from (training) EnvRunners and sync weights.
      if remove_from_env_runners:
          multi_rl_module_spec = _remove_from_group(
              self.env_runner_group, multi_rl_module_spec
          )

      # Remove from (eval) EnvRunners and sync weights.
      if (
          remove_from_eval_env_runners is True
          and self.eval_env_runner_group is not None
      ):
          multi_rl_module_spec = _remove_from_group(
              self.eval_env_runner_group, multi_rl_module_spec
          )

      return multi_rl_module_spec
  @OldAPIStack
  def get_policy(self, policy_id: PolicyID = DEFAULT_POLICY_ID) -> Policy:
      """Looks up a policy by its ID on the Algorithm's EnvRunner.

      Args:
          policy_id: ID of the policy to return.

      Returns:
          Whatever `self.env_runner.get_policy(policy_id)` returns for that ID.
      """
      local_env_runner = self.env_runner
      return local_env_runner.get_policy(policy_id)
  @PublicAPI
  def get_weights(self, policies: Optional[List[PolicyID]] = None) -> dict:
      """Return a dict mapping Module/Policy IDs to weights.

      Args:
          policies: Optional list of policies to return weights for,
              or None for all policies.

      Returns:
          The weights as returned by the LearnerGroup, if one exists, otherwise
          those returned by the Algorithm's EnvRunner.
      """
      # No LearnerGroup -> get the weights from the EnvRunner.
      if self.learner_group is None:
          return self.env_runner.get_weights(policies)
      # New API stack (get weights from LearnerGroup).
      return self.learner_group.get_weights(module_ids=policies)
  @PublicAPI
  def set_weights(self, weights: Dict[PolicyID, dict]):
      """Set RLModule/Policy weights by Module/Policy ID.

      Args:
          weights: Dict mapping ModuleID/PolicyID to weights.
      """
      # New API stack -> Use `set_state` API and specify the LearnerGroup state
      # in the call, which will automatically take care of weight synching to
      # all EnvRunners.
      if self.learner_group is not None:
          self.set_state(
              {
                  COMPONENT_LEARNER_GROUP: {
                      COMPONENT_LEARNER: {
                          COMPONENT_RL_MODULE: weights,
                      },
                  },
              },
          )

      # Also set the weights on the local EnvRunner. If a LearnerGroup exists,
      # the weights have already been set above, so a missing local EnvRunner
      # is skipped instead of crashing. Without a LearnerGroup, the local
      # EnvRunner is the only target and the call is made unconditionally.
      local_env_runner = self.env_runner_group.local_env_runner
      if self.learner_group is None or local_env_runner is not None:
          local_env_runner.set_weights(weights)
  @OldAPIStack
  def add_policy(
      self,
      policy_id: PolicyID,
      policy_cls: Optional[Type[Policy]] = None,
      policy: Optional[Policy] = None,
      *,
      observation_space: Optional[gym.spaces.Space] = None,
      action_space: Optional[gym.spaces.Space] = None,
      config: Optional[Union[AlgorithmConfig, PartialAlgorithmConfigDict]] = None,
      policy_state: Optional[PolicyState] = None,
      policy_mapping_fn: Optional[Callable[[AgentID, EpisodeID], PolicyID]] = None,
      policies_to_train: Optional[
          Union[
              Collection[PolicyID],
              Callable[[PolicyID, Optional[SampleBatchType]], bool],
          ]
      ] = None,
      add_to_env_runners: bool = True,
      add_to_eval_env_runners: bool = True,
      module_spec: Optional[RLModuleSpec] = None,
      # Deprecated arg.
      evaluation_workers=DEPRECATED_VALUE,
      add_to_learners=DEPRECATED_VALUE,
  ) -> Optional[Policy]:
      """Adds a new policy to this Algorithm (old API stack only).

      Args:
          policy_id: ID of the policy to add.
              IMPORTANT: Must not contain characters that
              are also not allowed in Unix/Win filesystems, such as: `<>:"/|?*`,
              or a dot, space or backslash at the end of the ID.
          policy_cls: The Policy class to use for constructing the new Policy.
              Note: Only one of `policy_cls` or `policy` must be provided.
          policy: The Policy instance to add to this algorithm. If not None, the
              given Policy object will be directly inserted into the Algorithm's
              local worker and clones of that Policy will be created on all remote
              workers as well as all evaluation workers.
              Note: Only one of `policy_cls` or `policy` must be provided.
          observation_space: The observation space of the policy to add.
              If None, try to infer this space from the environment.
          action_space: The action space of the policy to add.
              If None, try to infer this space from the environment.
          config: The config object or overrides for the policy to add.
          policy_state: Optional state dict to apply to the new
              policy instance, right after its construction.
          policy_mapping_fn: An optional (updated) policy mapping function
              to use from here on. Note that already ongoing episodes will
              not change their mapping but will use the old mapping till
              the end of the episode.
          policies_to_train: An optional list of policy IDs to be trained
              or a callable taking PolicyID and SampleBatchType and
              returning a bool (trainable or not?).
              If None, will keep the existing setup in place. Policies,
              whose IDs are not in the list (or for which the callable
              returns False) will not be updated.
          add_to_env_runners: Whether to add the new policy to the EnvRunnerGroup
              (with its m EnvRunners plus the local one).
          add_to_eval_env_runners: Whether to add the new policy to the eval
              EnvRunnerGroup (with its o EnvRunners plus the local one).
          module_spec: In the new RLModule API we need to pass in the module_spec for
              the new module that is supposed to be added. Knowing the policy spec is
              not sufficient.
          evaluation_workers: Deprecated. Use `add_to_eval_env_runners` instead.
          add_to_learners: Deprecated. No longer supported.

      Returns:
          The newly added policy: The copy on the local EnvRunner, if
          `add_to_env_runners` is set, otherwise the copy in the local eval
          EnvRunner's policy map (if it was added there). None, if the policy was
          added to neither group.

      Raises:
          ValueError: If called on the new API stack
              (`config.enable_env_runner_and_connector_v2=True`).
      """
      if self.config.enable_env_runner_and_connector_v2:
          raise ValueError(
              "`Algorithm.add_policy()` is not supported on the new API stack w/ "
              "EnvRunners! Use `Algorithm.add_module()` instead. Also see "
              "`rllib/examples/self_play_league_based_with_open_spiel.py` for an "
              "example."
          )

      # Deprecated arguments: Both are flagged as hard errors.
      if evaluation_workers != DEPRECATED_VALUE:
          deprecation_warning(
              old="Algorithm.add_policy(evaluation_workers=...)",
              new="Algorithm.add_policy(add_to_eval_env_runners=...)",
              error=True,
          )
      if add_to_learners != DEPRECATED_VALUE:
          deprecation_warning(
              old="Algorithm.add_policy(add_to_learners=..)",
              help="Hybrid API stack no longer supported by RLlib!",
              error=True,
          )

      validate_module_id(policy_id, error=True)

      # Keyword arguments forwarded identically to both EnvRunnerGroups.
      shared_kwargs = dict(
          observation_space=observation_space,
          action_space=action_space,
          config=config,
          policy_state=policy_state,
          policy_mapping_fn=policy_mapping_fn,
          policies_to_train=policies_to_train,
          module_spec=module_spec,
      )

      # Add to the (training) EnvRunners.
      if add_to_env_runners is True:
          self.env_runner_group.add_policy(
              policy_id, policy_cls, policy, **shared_kwargs
          )

      # Add to the evaluation EnvRunners, if there are any.
      if add_to_eval_env_runners is True and self.eval_env_runner_group is not None:
          self.eval_env_runner_group.add_policy(
              policy_id, policy_cls, policy, **shared_kwargs
          )

      # Hand back the new policy, preferably the local (training) EnvRunner's copy.
      if add_to_env_runners:
          return self.get_policy(policy_id)
      if add_to_eval_env_runners and self.eval_env_runner_group:
          return self.eval_env_runner.policy_map[policy_id]
      return None
  @OldAPIStack
  def remove_policy(
      self,
      policy_id: PolicyID = DEFAULT_POLICY_ID,
      *,
      policy_mapping_fn: Optional[Callable[[AgentID], PolicyID]] = None,
      policies_to_train: Optional[
          Union[
              Collection[PolicyID],
              Callable[[PolicyID, Optional[SampleBatchType]], bool],
          ]
      ] = None,
      remove_from_env_runners: bool = True,
      remove_from_eval_env_runners: bool = True,
      # Deprecated args.
      evaluation_workers=DEPRECATED_VALUE,
      remove_from_learners=DEPRECATED_VALUE,
  ) -> None:
      """Removes a policy from this Algorithm.

      Args:
          policy_id: ID of the policy to be removed.
          policy_mapping_fn: An optional (updated) policy mapping function
              to use from here on. Note that already ongoing episodes will
              not change their mapping but will use the old mapping till
              the end of the episode.
          policies_to_train: An optional list of policy IDs to be trained
              or a callable taking PolicyID and SampleBatchType and
              returning a bool (trainable or not?).
              If None, will keep the existing setup in place. Policies,
              whose IDs are not in the list (or for which the callable
              returns False) will not be updated.
          remove_from_env_runners: Whether to remove the Policy from the
              EnvRunnerGroup (with its m EnvRunners plus the local one).
          remove_from_eval_env_runners: Whether to remove the Policy from the eval
              EnvRunnerGroup (with its o EnvRunners plus the local one).
          evaluation_workers: Deprecated. If given, its value is used as
              `remove_from_eval_env_runners`.
          remove_from_learners: Deprecated. No longer supported.
      """
      # Soft-deprecated: Warn, but still honor the provided value.
      if evaluation_workers != DEPRECATED_VALUE:
          deprecation_warning(
              old="Algorithm.remove_policy(evaluation_workers=...)",
              new="Algorithm.remove_policy(remove_from_eval_env_runners=...)",
              error=False,
          )
          remove_from_eval_env_runners = evaluation_workers

      # Hard-deprecated.
      if remove_from_learners != DEPRECATED_VALUE:
          deprecation_warning(
              old="Algorithm.remove_policy(remove_from_learners=..)",
              help="Hybrid API stack no longer supported by RLlib!",
              error=True,
          )

      def _remove_on(env_runner):
          # Executed on each EnvRunner (including the local one).
          env_runner.remove_policy(
              policy_id=policy_id,
              policy_mapping_fn=policy_mapping_fn,
              policies_to_train=policies_to_train,
          )

      # (Training) EnvRunners.
      if remove_from_env_runners:
          self.env_runner_group.foreach_env_runner(
              _remove_on, local_env_runner=True
          )

      # Evaluation EnvRunners, if there are any.
      if remove_from_eval_env_runners and self.eval_env_runner_group is not None:
          self.eval_env_runner_group.foreach_env_runner(
              _remove_on, local_env_runner=True
          )
  @OldAPIStack
  @staticmethod
  def from_state(state: Dict) -> "Algorithm":
      """Builds a new Algorithm instance from a state object.

      The `state` of an instantiated Algorithm can be retrieved by calling its
      `get_state` method. It contains all information necessary
      to create the Algorithm from scratch. No access to the original code (e.g.
      configs, knowledge of the Algorithm's class, etc..) is needed.

      Args:
          state: The state to recover a new Algorithm instance from. Must contain
              the keys "algorithm_class" and "config".

      Returns:
          A new Algorithm instance.

      Raises:
          ValueError: If `state` has no "algorithm_class" entry or no (non-empty)
              "config" entry.
      """
      algo_cls: Type[Algorithm] = state.get("algorithm_class")
      if algo_cls is None:
          raise ValueError(
              "No `algorithm_class` key was found in given `state`! "
              "Cannot create new Algorithm."
          )

      algo_config = state.get("config")
      if not algo_config:
          raise ValueError("No `config` found in given Algorithm state!")

      # Construct the new Algorithm from the stored config, then apply the
      # state to it.
      restored_algo = algo_cls(config=algo_config)
      restored_algo.__setstate__(state)
      return restored_algo
  @OldAPIStack
  def export_policy_model(
      self,
      export_dir: str,
      policy_id: PolicyID = DEFAULT_POLICY_ID,
      onnx: Optional[int] = None,
  ) -> None:
      """Exports policy model with given policy_id to a local directory.

      Args:
          export_dir: Writable local directory.
          policy_id: Optional policy id to export.
          onnx: If given, will export model in ONNX format. The
              value of this parameter set the ONNX OpSet version to use.
              If None, the output format will be DL framework specific.

      Raises:
          KeyError: if `policy_id` cannot be found in this Algorithm.
      """
      policy = self.get_policy(policy_id)
      # `get_policy()` returns None for unknown IDs. Fail with a clear error
      # (same as `export_policy_checkpoint()`) instead of an AttributeError on
      # None.
      if policy is None:
          raise KeyError(f"Policy with ID {policy_id} not found in Algorithm!")
      policy.export_model(export_dir, onnx)
  @OldAPIStack
  def export_policy_checkpoint(
      self,
      export_dir: str,
      policy_id: PolicyID = DEFAULT_POLICY_ID,
  ) -> None:
      """Exports a checkpoint of a single Policy to a local directory.

      Args:
          export_dir: Writable local directory to store the checkpoint
              information into.
          policy_id: Optional policy ID to export. If not provided, will export
              "default_policy". If `policy_id` does not exist in this Algorithm,
              will raise a KeyError.

      Raises:
          KeyError: if `policy_id` cannot be found in this Algorithm.
      """
      policy_to_export = self.get_policy(policy_id)
      if policy_to_export is not None:
          policy_to_export.export_checkpoint(export_dir)
          return
      raise KeyError(f"Policy with ID {policy_id} not found in Algorithm!")
  @override(Trainable)
  def save_checkpoint(self, checkpoint_dir: str) -> None:
      """Exports checkpoint to a local directory.

      On the new API stack (`config.enable_rl_module_and_learner=True`), this
      delegates to `self.save_to_path()` and returns.

      On the old API stack, the following files are written into
      `checkpoint_dir`::

          policies/
              pol_1/
                  policy_state.pkl
              pol_2/
                  policy_state.pkl
          rllib_checkpoint.json
          algorithm_state.pkl

      Note: `rllib_checkpoint.json` contains a "checkpoint_version" key helping
      RLlib to remain backward compatible wrt. restoring from checkpoints from
      Ray 2.0 onwards.

      Args:
          checkpoint_dir: The directory where the checkpoint files will be stored.
      """
      with TimerAndPrometheusLogger(self._metrics_save_checkpoint_time):
          # New API stack: Delegate to the `Checkpointable` implementation of
          # `save_to_path()` and return.
          if self.config.enable_rl_module_and_learner:
              self.save_to_path(
                  checkpoint_dir,
                  use_msgpack=self.config._use_msgpack_checkpoints,
              )
              return

          # From here on: Old API stack only.
          checkpoint_dir = pathlib.Path(checkpoint_dir)
          state = self.__getstate__()

          # Extract policy states from worker state (Policies get their own
          # checkpoint sub-dirs).
          policy_states = {}
          if "worker" in state and "policy_states" in state["worker"]:
              policy_states = state["worker"].pop("policy_states", {})

          # Add RLlib checkpoint version.
          state["checkpoint_version"] = CHECKPOINT_VERSION

          # Write state (w/o policies) to disk.
          state_file = checkpoint_dir / "algorithm_state.pkl"
          with open(state_file, "wb") as f:
              pickle.dump(state, f)

          # Write rllib_checkpoint.json.
          with open(checkpoint_dir / "rllib_checkpoint.json", "w") as f:
              json.dump(
                  {
                      "type": "Algorithm",
                      "checkpoint_version": str(state["checkpoint_version"]),
                      "format": "cloudpickle",
                      "state_file": str(state_file),
                      "policy_ids": list(policy_states.keys()),
                      "ray_version": ray.__version__,
                      "ray_commit": ray.__commit__,
                  },
                  f,
              )

          # Write individual policies to disk, each in their own sub-directory.
          for pid, policy_state in policy_states.items():
              # Disallow policyIDs that would not work as directory names.
              validate_module_id(pid, error=True)
              policy_dir = checkpoint_dir / "policies" / pid
              os.makedirs(policy_dir, exist_ok=True)
              policy = self.get_policy(pid)
              policy.export_checkpoint(policy_dir, policy_state=policy_state)
  @override(Trainable)
  def load_checkpoint(self, checkpoint_dir: str) -> None:
      """Restores this Algorithm from a checkpoint directory.

      Afterwards, the `on_checkpoint_loaded` callback is invoked.

      Args:
          checkpoint_dir: The (local) directory holding the checkpoint.
      """
      with TimerAndPrometheusLogger(self._metrics_load_checkpoint_time):
          if not self.config.enable_rl_module_and_learner:
              # Old API stack: Checkpoint is provided as a local directory.
              # Convert it into an Algorithm state and apply that state.
              info = get_checkpoint_info(checkpoint_dir)
              algo_state = Algorithm._checkpoint_info_to_algorithm_state(info)
              self.__setstate__(algo_state)
          else:
              # New API stack: Delegate to the `Checkpointable` implementation
              # of `restore_from_path()`.
              self.restore_from_path(checkpoint_dir)

          # Call the `on_checkpoint_loaded` callback.
          make_callback(
              "on_checkpoint_loaded",
              callbacks_objects=self.callbacks,
              callbacks_functions=self.config.callbacks_on_checkpoint_loaded,
              kwargs={"algorithm": self},
          )
  @override(Checkpointable)
  def get_state(
      self,
      components: Optional[Union[str, Collection[str]]] = None,
      *,
      not_components: Optional[Union[str, Collection[str]]] = None,
      **kwargs,
  ) -> StateDict:
      """Returns the state of this Algorithm (new API stack only).

      Args:
          components: Optional component name(s) to restrict the returned state to.
          not_components: Optional component name(s) to exclude from the returned
              state.
          **kwargs: Forwarded to the `get_state()` calls of the EnvRunners and the
              LearnerGroup.

      Returns:
          A state dict with (depending on `components`/`not_components`) the
          states of the local EnvRunner (or of the two connector pipelines, if
          there is no local EnvRunner), the local eval EnvRunner, and the
          LearnerGroup. The MetricsLogger state and the current training iteration
          are always included.

      Raises:
          RuntimeError: If called on the old API stack.
      """
      if not self.config.enable_env_runner_and_connector_v2:
          raise RuntimeError(
              "Algorithm.get_state() not supported on the old API stack! "
              "Use Algorithm.__getstate__() instead."
          )

      def _state_without_rl_module(env_runner):
          # We don't want the RLModule state from the EnvRunners (it's
          # `inference_only` anyway and already provided in full by the
          # Learners).
          sub_components = self._get_subcomponents(COMPONENT_RL_MODULE, components)
          excluded = force_list(
              self._get_subcomponents(COMPONENT_RL_MODULE, not_components)
          ) + [COMPONENT_RL_MODULE]
          return env_runner.get_state(
              components=sub_components,
              not_components=excluded,
              **kwargs,
          )

      state = {}

      if self.config.is_online:
          if not self.env_runner:
              # No local EnvRunner -> Store the connector pipelines' states
              # directly.
              if self._check_component(
                  COMPONENT_ENV_TO_MODULE_CONNECTOR, components, not_components
              ):
                  env_to_module = self.env_to_module_connector
                  state[COMPONENT_ENV_TO_MODULE_CONNECTOR] = env_to_module.get_state()
              if self._check_component(
                  COMPONENT_MODULE_TO_ENV_CONNECTOR, components, not_components
              ):
                  module_to_env = self.module_to_env_connector
                  state[COMPONENT_MODULE_TO_ENV_CONNECTOR] = module_to_env.get_state()
          elif self._check_component(
              COMPONENT_ENV_RUNNER, components, not_components
          ):
              # Get (local) EnvRunner state (w/o RLModule).
              state[COMPONENT_ENV_RUNNER] = _state_without_rl_module(
                  self.env_runner
              )

      # Get (local) evaluation EnvRunner state (w/o RLModule).
      if self.eval_env_runner and self._check_component(
          COMPONENT_EVAL_ENV_RUNNER, components, not_components
      ):
          state[COMPONENT_EVAL_ENV_RUNNER] = _state_without_rl_module(
              self.eval_env_runner
          )

      # Get LearnerGroup state (w/ RLModule).
      if self._check_component(COMPONENT_LEARNER_GROUP, components, not_components):
          learner_components = self._get_subcomponents(
              COMPONENT_LEARNER_GROUP, components
          )
          learner_not_components = self._get_subcomponents(
              COMPONENT_LEARNER_GROUP, not_components
          )
          state[COMPONENT_LEARNER_GROUP] = self.learner_group.get_state(
              components=learner_components,
              not_components=learner_not_components,
              **kwargs,
          )

      # Get entire MetricsLogger state.
      # TODO (sven): Make `MetricsLogger` a Checkpointable.
      state[COMPONENT_METRICS_LOGGER] = self.metrics.get_state()

      # Save current `training_iteration`.
      state[TRAINING_ITERATION] = self.training_iteration

      return state
  @override(Checkpointable)
  def set_state(self, state: StateDict) -> None:
      """Applies the given state to this Algorithm and its sub-components.

      Only those components are touched, whose keys are present in `state`.

      Args:
          state: The state dict to apply. May contain entries for the (training)
              EnvRunner, the eval EnvRunner, the LearnerGroup, the MetricsLogger,
              and the training iteration.
      """
      # Set the (training) EnvRunners' states.
      if COMPONENT_ENV_RUNNER in state:
          env_runner_state = state[COMPONENT_ENV_RUNNER]
          if self.env_runner:
              self.env_runner.set_state(env_runner_state)
          else:
              # No local EnvRunner -> Apply the connector states directly to the
              # Algorithm's own connector pipelines.
              self.env_to_module_connector.set_state(
                  env_runner_state[COMPONENT_ENV_TO_MODULE_CONNECTOR]
              )
              self.module_to_env_connector.set_state(
                  env_runner_state[COMPONENT_MODULE_TO_ENV_CONNECTOR]
              )
          self.env_runner_group.sync_env_runner_states(
              config=self.config,
              from_worker=self.env_runner,
              env_steps_sampled=self.metrics.peek(
                  (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME), default=0
              ),
              env_to_module=self.env_to_module_connector,
              module_to_env=self.module_to_env_connector,
          )

      # Set the (eval) EnvRunners' states.
      if self.eval_env_runner_group and COMPONENT_EVAL_ENV_RUNNER in state:
          if self.eval_env_runner:
              # The local eval EnvRunner receives the (training) EnvRunner state,
              # if present. Only the eval key is guaranteed to be present here,
              # so fall back to the eval entry instead of raising a KeyError.
              # NOTE(review): Confirm that preferring the training EnvRunner's
              # state over the eval EnvRunner's own state is intended.
              eval_state_key = (
                  COMPONENT_ENV_RUNNER
                  if COMPONENT_ENV_RUNNER in state
                  else COMPONENT_EVAL_ENV_RUNNER
              )
              self.eval_env_runner.set_state(state[eval_state_key])
          self.eval_env_runner_group.sync_env_runner_states(
              config=self.evaluation_config,
              from_worker=self.env_runner,
              env_steps_sampled=self.metrics.peek(
                  (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME), default=0
              ),
              env_to_module=self.env_to_module_connector,
              module_to_env=self.module_to_env_connector,
          )

      # Set the LearnerGroup's state.
      if COMPONENT_LEARNER_GROUP in state:
          self.learner_group.set_state(state[COMPONENT_LEARNER_GROUP])
          # Sync new weights to all EnvRunners.
          self.env_runner_group.sync_weights(
              from_worker_or_learner_group=self.learner_group,
              inference_only=True,
          )
          if self.eval_env_runner_group:
              self.eval_env_runner_group.sync_weights(
                  from_worker_or_learner_group=self.learner_group,
                  inference_only=True,
              )

      # TODO (sven): Make `MetricsLogger` a Checkpointable.
      if COMPONENT_METRICS_LOGGER in state:
          self.metrics.set_state(state[COMPONENT_METRICS_LOGGER])

      if TRAINING_ITERATION in state:
          self._iteration = state[TRAINING_ITERATION]
  @override(Checkpointable)
  def get_checkpointable_components(self) -> List[Tuple[str, "Checkpointable"]]:
      """Returns the (name, component) pairs that make up this Algorithm.

      Returns:
          A list starting with the LearnerGroup, followed (in online settings) by
          either the local EnvRunner or - if there is none - the existing
          connector pipelines, and finally the local eval EnvRunner (if any).
      """
      checkpointables = [(COMPONENT_LEARNER_GROUP, self.learner_group)]

      if self.config.is_online:
          if self.env_runner:
              checkpointables.append((COMPONENT_ENV_RUNNER, self.env_runner))
          else:
              # No local EnvRunner -> Register the connector pipelines directly.
              if self.env_to_module_connector:
                  checkpointables.append(
                      (
                          COMPONENT_ENV_TO_MODULE_CONNECTOR,
                          self.env_to_module_connector,
                      )
                  )
              if self.module_to_env_connector:
                  checkpointables.append(
                      (
                          COMPONENT_MODULE_TO_ENV_CONNECTOR,
                          self.module_to_env_connector,
                      )
                  )

      if self.eval_env_runner:
          checkpointables.append((COMPONENT_EVAL_ENV_RUNNER, self.eval_env_runner))

      return checkpointables
  @override(Checkpointable)
  def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]:
      """Returns the constructor args and kwargs describing this Algorithm.

      Returns:
          A 2-tuple of a) the positional args (only the config's state) and
          b) the keyword args (empty).
      """
      ctor_args = (self.config.get_state(),)
      ctor_kwargs = {}
      return ctor_args, ctor_kwargs
  @override(Checkpointable)
  def restore_from_path(self, path, *args, **kwargs):
      """Restores this Algorithm from `path` and re-syncs the EnvRunners.

      Overrides the parent method, b/c we might have to sync the EnvRunner weights
      and connector states after having restored/loaded the LearnerGroup state.

      Args:
          path: The checkpoint directory to restore from.
          *args: Forwarded to the parent's `restore_from_path()`.
          **kwargs: Forwarded to the parent's `restore_from_path()`. An optional
              `component` entry (may be None) is inspected to find out, whether
              (a part of) the LearnerGroup is being restored.
      """
      super().restore_from_path(path, *args, **kwargs)

      path = pathlib.Path(path)

      # Sync EnvRunners, if LearnerGroup's checkpoint can be found in path
      # or user loaded a subcomponent within the LearnerGroup (for example a module).
      # Note that `component` may be passed in explicitly as None.
      component = kwargs.get("component")
      if (path / COMPONENT_LEARNER_GROUP).is_dir() or (
          component is not None and COMPONENT_LEARNER_GROUP in component
      ):
          # Make also sure, all (training) EnvRunners get the just loaded weights, but
          # only the inference-only ones.
          self.env_runner_group.sync_weights(
              from_worker_or_learner_group=self.learner_group,
              inference_only=True,
          )

      # Without remote EnvRunners, there are no states to sync.
      if self.env_runner_group.num_remote_env_runners() <= 0:
          return

      env_steps_sampled = self.metrics.peek(
          (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME), default=0
      )

      if self.env_runner:
          # Get the connector states from the local `EnvRunner`.
          self.env_runner_group.sync_env_runner_states(
              config=self.config,
              env_steps_sampled=env_steps_sampled,
              from_worker=self.env_runner,
          )
      else:
          # We have remote `EnvRunner`s but no local `EnvRunner` -> Restore the
          # connector states from path, then broadcast them.
          if (path / COMPONENT_ENV_TO_MODULE_CONNECTOR).is_dir():
              self.env_to_module_connector.restore_from_path(
                  path / COMPONENT_ENV_TO_MODULE_CONNECTOR, *args, **kwargs
              )
          if (path / COMPONENT_MODULE_TO_ENV_CONNECTOR).is_dir():
              self.module_to_env_connector.restore_from_path(
                  path / COMPONENT_MODULE_TO_ENV_CONNECTOR, *args, **kwargs
              )
          self.env_runner_group.sync_env_runner_states(
              config=self.config,
              from_worker=None,
              env_steps_sampled=env_steps_sampled,
              env_to_module=self.env_to_module_connector,
              module_to_env=self.module_to_env_connector,
          )
  @override(Trainable)
  def log_result(self, result: ResultDict) -> None:
      """Invokes the `on_train_result` callback, then logs `result`.

      Args:
          result: The results dict to hand to the callback(s) and to log.
      """
      # The callback runs before the actual logging, so that the user has a
      # chance to mutate the result.
      # TODO (sven): It might not make sense to pass in the MetricsLogger at this late
      #  point in time. In here, the result dict has already been "compiled" (reduced)
      #  by the MetricsLogger and there is probably no point in adding more Stats
      #  here.
      callback_kwargs = {
          "algorithm": self,
          "metrics_logger": self.metrics,
          "result": result,
      }
      with TimerAndPrometheusLogger(self._metrics_callback_on_train_result_time):
          make_callback(
              "on_train_result",
              callbacks_objects=self.callbacks,
              callbacks_functions=self.config.callbacks_on_train_result,
              kwargs=callback_kwargs,
          )

      # Then log according to Trainable's logging logic.
      Trainable.log_result(self, result)
  @override(Trainable)
  def cleanup(self) -> None:
      """Shuts down the Learners, aggregator actors, and all runner groups.

      Attributes that are missing or None are skipped.
      """
      # Stop all Learners.
      learner_group = getattr(self, "learner_group", None)
      if learner_group is not None:
          learner_group.shutdown()

      # Stop all aggregation actors.
      aggregator_actor_manager = getattr(self, "_aggregator_actor_manager", None)
      if aggregator_actor_manager is not None:
          aggregator_actor_manager.clear()

      # Stop all EnvRunners (training, evaluation) and offline eval runners.
      for group_attr in (
          "env_runner_group",
          "eval_env_runner_group",
          "offline_eval_runner_group",
      ):
          runner_group = getattr(self, group_attr, None)
          if runner_group is not None:
              runner_group.stop()
  @OverrideToImplementCustomLogic
  @classmethod
  @override(Trainable)
  def default_resource_request(
      cls, config: Union[AlgorithmConfig, PartialAlgorithmConfigDict]
  ) -> Union[Resources, PlacementGroupFactory]:
      """Computes the resource bundles this Algorithm requests.

      Args:
          config: The config (object or partial dict), which is merged into the
              Algorithm's default config, then validated and frozen.

      Returns:
          A PlacementGroupFactory with the bundles for (in this order) the main
          process, the EnvRunners, the evaluation EnvRunners, the offline
          evaluation runners, and the Learners, using
          `config.placement_strategy`.
      """
      config = cls.get_default_config().update_from_dict(config)
      config.validate()
      config.freeze()
      eval_config = config.get_evaluation_config_object()
      eval_config.validate()
      eval_config.freeze()

      # Main process bundle.
      if config.enable_rl_module_and_learner:
          main_process = _get_main_process_bundle(config)
      else:
          # Learner API disabled: The main process holds the configured GPUs
          # (unless GPUs are faked).
          main_process = {
              "CPU": config.num_cpus_for_main_process,
              "GPU": 0 if config._fake_gpus else config.num_gpus,
          }

      # (Training) EnvRunner bundles.
      env_runner_bundles = _get_env_runner_bundles(config)

      # Evaluation EnvRunner bundles.
      if cls._should_create_evaluation_env_runners(eval_config):
          eval_env_runner_bundles = _get_env_runner_bundles(eval_config)
      else:
          eval_env_runner_bundles = []

      # Offline evaluation runner bundles.
      if cls._should_create_offline_evaluation_runners(eval_config):
          offline_eval_runner_bundles = _get_offline_eval_runner_bundles(eval_config)
      else:
          offline_eval_runner_bundles = []

      # Learner bundles (new API stack only).
      learner_bundles = []
      if config.enable_rl_module_and_learner:
          learner_bundles = _get_learner_bundles(config)

      bundles = (
          [main_process]
          + env_runner_bundles
          + eval_env_runner_bundles
          + offline_eval_runner_bundles
          + learner_bundles
      )

      return PlacementGroupFactory(
          bundles=bundles,
          strategy=config.placement_strategy,
      )
  @DeveloperAPI
  def _before_evaluate(self):
      """Pre-evaluation callback.

      The default implementation does nothing.
      """
  3123. @staticmethod
  3124. def _get_env_id_and_creator(
  3125. env_specifier: Union[str, EnvType, None], config: AlgorithmConfig
  3126. ) -> Tuple[Optional[str], EnvCreator]:
  3127. """Returns env_id and creator callable given original env id from config.
  3128. Args:
  3129. env_specifier: An env class, an already tune registered env ID, a known
  3130. gym env name, or None (if no env is used).
  3131. config: The AlgorithmConfig object.
  3132. Returns:
  3133. Tuple consisting of a) env ID string and b) env creator callable.
  3134. """
  3135. # Environment is specified via a string.
  3136. if isinstance(env_specifier, str):
  3137. # An already registered env.
  3138. if _global_registry.contains(ENV_CREATOR, env_specifier):
  3139. return env_specifier, _global_registry.get(ENV_CREATOR, env_specifier)
  3140. # A class path specifier.
  3141. elif "." in env_specifier:
  3142. def env_creator_from_classpath(env_context):
  3143. try:
  3144. env_obj = from_config(env_specifier, env_context)
  3145. except ValueError:
  3146. raise EnvError(
  3147. ERR_MSG_INVALID_ENV_DESCRIPTOR.format(env_specifier)
  3148. )
  3149. return env_obj
  3150. return env_specifier, env_creator_from_classpath
  3151. # Try gym/PyBullet.
  3152. else:
  3153. return env_specifier, functools.partial(
  3154. _gym_env_creator, env_descriptor=env_specifier
  3155. )
  3156. elif isinstance(env_specifier, type):
  3157. env_id = env_specifier # .__name__
  3158. if config["remote_worker_envs"]:
  3159. # Check gym version (0.22 or higher?).
  3160. # If > 0.21, can't perform auto-wrapping of the given class as this
  3161. # would lead to a pickle error.
  3162. gym_version = importlib.metadata.version("gym")
  3163. if version.parse(gym_version) >= version.parse("0.22"):
  3164. raise ValueError(
  3165. "Cannot specify a gym.Env class via `config.env` while setting "
  3166. "`config.remote_worker_env=True` AND your gym version is >= "
  3167. "0.22! Try installing an older version of gym or set `config."
  3168. "remote_worker_env=False`."
  3169. )
  3170. @ray.remote(num_cpus=1)
  3171. class _wrapper(env_specifier):
  3172. # Add convenience `_get_spaces` and `_is_multi_agent`
  3173. # methods:
  3174. def _get_spaces(self):
  3175. return self.observation_space, self.action_space
  3176. def _is_multi_agent(self):
  3177. from ray.rllib.env.multi_agent_env import MultiAgentEnv
  3178. return isinstance(self, MultiAgentEnv)
  3179. return env_id, lambda cfg: _wrapper.remote(cfg)
  3180. # gym.Env-subclass: Also go through our RLlib gym-creator.
  3181. elif issubclass(env_specifier, gym.Env):
  3182. return env_id, functools.partial(
  3183. _gym_env_creator,
  3184. env_descriptor=env_specifier,
  3185. )
  3186. # All other env classes: Call c'tor directly.
  3187. else:
  3188. return env_id, lambda cfg: env_specifier(cfg)
  3189. # No env -> Env creator always returns None.
  3190. elif env_specifier is None:
  3191. return None, lambda env_config: None
  3192. else:
  3193. raise ValueError(
  3194. "{} is an invalid env specifier. ".format(env_specifier)
  3195. + "You can specify a custom env as either a class "
  3196. '(e.g., YourEnvCls) or a registered env id (e.g., "your_env").'
  3197. )
  def _sync_filters_if_needed(
      self,
      *,
      central_worker: EnvRunner,
      workers: EnvRunnerGroup,
      config: AlgorithmConfig,
  ) -> None:
      """Synchronizes the filter stats between `workers` and `central_worker`.

      Does nothing, if there is no `central_worker` or if
      `config.observation_filter` is "NoFilter".

      Args:
          central_worker: The worker to sync/aggregate all `workers`' filter stats to
              and from which to (possibly) broadcast the updated filter stats back to
              `workers`.
          workers: The EnvRunnerGroup, whose EnvRunners' filter stats should be used
              for aggregation on `central_worker` and which (possibly) get updated
              from `central_worker` after the sync.
          config: The algorithm config instance. This is used to determine, whether
              syncing from `workers` should happen at all and whether broadcasting
              back to `workers` (after possible syncing) should happen.
      """
      if not central_worker or config.observation_filter == "NoFilter":
          return

      FilterManager.synchronize(
          central_worker.filters,
          workers,
          update_remote=config.update_worker_filter_stats,
          timeout_seconds=config.sync_filters_on_rollout_workers_timeout_s,
          use_remote_data_for_update=config.use_worker_filter_stats,
      )
  @classmethod
  @override(Trainable)
  def resource_help(cls, config: Union[AlgorithmConfig, AlgorithmConfigDict]) -> str:
      """Returns a help text explaining how to adjust resource requests.

      Args:
          config: The config of this Algorithm; it is appended verbatim to the
              returned help text.

      Returns:
          A human-readable string pointing to the `AlgorithmConfig` methods that
          control EnvRunner and Learner resources, followed by `config`.
      """
      return (
          "\n\nYou can adjust the resource requests of RLlib Algorithms by calling "
          "`AlgorithmConfig.env_runners("
          "num_env_runners=.., num_cpus_per_env_runner=.., "
          "num_gpus_per_env_runner=.., ..)` and "
          # The class name was misspelled as `AgorithmConfig` here before.
          "`AlgorithmConfig.learners(num_learners=.., num_gpus_per_learner=..)`. See "
          "the `ray.rllib.algorithms.algorithm_config.AlgorithmConfig` classes "
          "(each Algorithm has its own subclass of this class) for more info.\n\n"
          f"The config of this Algorithm is: {config}"
      )
  @override(Trainable)
  def get_auto_filled_metrics(
      self,
      now: Optional[datetime] = None,
      time_this_iter: Optional[float] = None,
      timestamp: Optional[int] = None,
      debug_metrics_only: bool = False,
  ) -> dict:
      """Returns the auto-filled metrics with `config` guaranteed to be a dict.

      Delegates to the parent implementation and afterwards converts the value
      under the `config` key into a plain dict (via `to_dict()`) in case it is
      an `AlgorithmConfig` object, so Tune APIs expecting a dict keep working.

      Args:
          now: Forwarded unchanged to the parent implementation.
          time_this_iter: Forwarded unchanged to the parent implementation.
          timestamp: Forwarded unchanged to the parent implementation.
          debug_metrics_only: Forwarded unchanged to the parent implementation.

      Returns:
          The auto-filled metrics dict.

      Raises:
          KeyError: If the parent's result does not contain a `config` key.
      """
      metrics = super().get_auto_filled_metrics(
          now, time_this_iter, timestamp, debug_metrics_only
      )
      if "config" not in metrics:
          raise KeyError("`config` key not found in auto-filled results dict!")

      config_entry = metrics["config"]
      # Already a dict -> nothing to convert.
      if isinstance(config_entry, dict):
          return metrics

      # Otherwise, it must be an AlgorithmConfig object -> convert to dict.
      assert isinstance(config_entry, AlgorithmConfig)
      metrics["config"] = config_entry.to_dict()
      return metrics
  @classmethod
  def merge_algorithm_configs(
      cls,
      config1: AlgorithmConfigDict,
      config2: PartialAlgorithmConfigDict,
      _allow_unknown_configs: Optional[bool] = None,
  ) -> AlgorithmConfigDict:
      """Deep-merges a partial override dict into a complete config dict.

      `config1` itself is not modified; a deep copy of it serves as the merge
      target. Nested structures are respected and values from `config2` win.

      Args:
          config1: The complete Algorithm config dict that gets overridden.
          config2: The partial override config dict to merge on top of
              `config1`.
          _allow_unknown_configs: If True, keys in `config2` that don't exist
              in `config1` are allowed and will be added to the final config.
              If None, the class-level `_allow_unknown_configs` is used.

      Returns:
          The merged full algorithm config dict.
      """
      base = copy.deepcopy(config1)

      # The dict-style callbacks interface is no longer supported
      # (`error=True` is passed to the deprecation helper).
      if type(config2.get("callbacks")) is dict:
          deprecation_warning(
              "callbacks dict interface",
              "a class extending rllib.callbacks.callbacks.RLlibCallback; "
              "see `rllib/examples/metrics/custom_metrics_and_callbacks.py` for an "
              "example.",
              error=True,
          )

      allow_unknown = (
          cls._allow_unknown_configs
          if _allow_unknown_configs is None
          else _allow_unknown_configs
      )
      return deep_update(
          base,
          config2,
          allow_unknown,
          cls._allow_unknown_subkeys,
          cls._override_all_subkeys_if_type_changes,
          cls._override_all_key_list,
      )
  @staticmethod
  @ExperimentalAPI
  def validate_env(env: EnvType, env_context: EnvContext) -> None:
      """Validation hook for environments used by this Algorithm class.

      The base implementation does nothing. Child classes may override it to
      implement custom validation behavior.

      Args:
          env: The (sub-)environment to validate. This is normally a
              single sub-environment (e.g. a gym.Env) within a vectorized
              setup.
          env_context: The EnvContext to configure the environment.

      Raises:
          Exception: in case something is wrong with the given environment.
      """
  def _run_one_training_iteration(self) -> Tuple[ResultDict, "TrainIterCtx"]:
      """Runs one training iteration (`self.iteration` will be +1 after this).

      Calls `self.training_step()` repeatedly until the `TrainIterCtx` signals
      that the loop should stop (e.g. once the configured minimum time (sec),
      minimum sample- or minimum training steps have been reached).

      Returns:
          A tuple consisting of the results compiled through
          `self.metrics.compile()` and the `TrainIterCtx` object used for the
          iteration loop. Note that the user still has full control over the
          history and reduce behavior of individual metrics at the time these
          metrics are logged with `self.metrics.log_...()`.

      Raises:
          ValueError: If `self.training_step()` returns anything other than None.
      """
      with TimerAndPrometheusLogger(self._metrics_run_one_training_iteration_time):
          with self.metrics.log_time((TIMERS, TRAINING_ITERATION_TIMER)):
              # In case we are training (in a thread) parallel to evaluation,
              # we may have to re-enable eager mode here (gets disabled in the
              # thread).
              if self.config.get("framework") == "tf2" and not tf.executing_eagerly():
                  tf1.enable_eager_execution()

              # Passed to `should_stop()` so the context knows whether at least
              # one training step has already been executed.
              has_run_once = False
              # Create a step context ...
              with TrainIterCtx(algo=self) as train_iter_ctx:
                  # .. so we can query it whether we should stop the iteration loop (e.g.
                  # when we have reached `min_time_s_per_iteration`).
                  while not train_iter_ctx.should_stop(has_run_once):
                      # Before training step, try to bring failed workers back.
                      with self.metrics.log_time((TIMERS, RESTORE_ENV_RUNNERS_TIMER)):
                          restored = self.restore_env_runners(self.env_runner_group)
                      # Fire the callback for re-created EnvRunners.
                      if restored:
                          self._make_on_env_runners_recreated_callbacks(
                              config=self.config,
                              env_runner_group=self.env_runner_group,
                              restored_env_runner_indices=restored,
                          )

                      # Try to train one step.
                      with self.metrics.log_time((TIMERS, TRAINING_STEP_TIMER)):
                          with TimerAndPrometheusLogger(
                              self._metrics_training_step_time
                          ):
                              training_step_return_value = self.training_step()
                          has_run_once = True

                      # On the new API stack, results should NOT be returned anymore as
                      # a dict, but purely logged through the `MetricsLogger` API. This
                      # way, we make sure to never miss a single stats/counter/timer
                      # when calling `self.training_step()` more than once within the same
                      # iteration.
                      if training_step_return_value is not None:
                          raise ValueError(
                              "`Algorithm.training_step()` should NOT return a result "
                              "dict anymore on the new API stack! Instead, log all "
                              "results, timers, counters through the `self.metrics` "
                              "(MetricsLogger) instance of the Algorithm and return "
                              "None. The logged results are compiled automatically into "
                              "one single result dict per training iteration."
                          )

                      # Count this `training_step()` call (summed up over the loop).
                      # TODO (sven): Resolve this metric through log_time's future
                      # ability to compute throughput.
                      self.metrics.log_value(
                          NUM_TRAINING_STEP_CALLS_PER_ITERATION,
                          1,
                          reduce="sum",
                      )

          # If aggregator actors are configured, fetch whatever metrics they have
          # ready (`timeout_seconds=0.0`) and merge them into our own metrics.
          if self.config.num_aggregator_actors_per_learner:
              remote_aggregator_metrics = self._aggregator_actor_manager.foreach_actor_async_fetch_ready(
                  func=lambda actor: actor.get_metrics(),
                  tag="metrics",
                  timeout_seconds=0.0,
                  return_obj_refs=False,
                  # (Artur) TODO: In the future, we want to make aggregator actors fault tolerant and should make this configurable
                  ignore_ray_errors=False,
              )
              self.metrics.aggregate(
                  remote_aggregator_metrics,
                  key=AGGREGATOR_ACTOR_RESULTS,
              )

          # Only here (at the end of the iteration), compile the results into a single result dict.
          # Calling compile here reduces the metrics into single values and adds throughputs to the results where applicable.
          compiled_metrics = self.metrics.compile()

      return compiled_metrics, train_iter_ctx
  def _run_one_offline_evaluation(self):
      """Runs one offline evaluation step and handles failed runners.

      First tries to restore crashed offline evaluation runners (firing the
      `on_offline_eval_runners_recreated` callback if any were restored), then
      calls `self.evaluate_offline()` and finally adds runner health stats to
      the results.

      Returns:
          A dict mapping `EVALUATION_RESULTS` to the results of the offline
          evaluation call.
      """
      # Restore crashed offline evaluation runners.
      if self.offline_eval_runner_group is not None:
          restore_timer_key = (TIMERS, RESTORE_OFFLINE_EVAL_RUNNERS_TIMER)
          with self.metrics.log_time(restore_timer_key):
              restored = self.restore_offline_eval_runners(
                  self.offline_eval_runner_group
              )
          # Notify callbacks about the re-created offline evaluation runners.
          if restored:
              make_callback(
                  "on_offline_eval_runners_recreated",
                  callbacks_objects=self.callbacks,
                  callbacks_functions=(
                      self.config.callbacks_on_offline_eval_runners_recreated
                  ),
                  kwargs={
                      "algorithm": self,
                      "env_runner_group": self.offline_eval_runner_group,
                      "env_runner_indices": restored,
                  },
              )

      # Time the actual offline evaluation run.
      with self.metrics.log_time((TIMERS, OFFLINE_EVALUATION_ITERATION_TIMER)):
          eval_results = self.evaluate_offline()

      # Report the state of the remote offline eval runners after this run.
      runner_group = self.offline_eval_runner_group
      if runner_group is not None:
          eval_results[
              "num_healthy_offline_eval_runners"
          ] = runner_group.num_healthy_remote_runners
          eval_results[
              "offline_runners_actor_manager_num_outstanding_async_reqs"
          ] = runner_group.num_in_flight_async_reqs
          eval_results[
              "num_remote_offline_eval_runners_restarts"
          ] = runner_group.num_remote_runner_restarts

      return {EVALUATION_RESULTS: eval_results}
  def _run_one_evaluation(
      self,
      parallel_train_future: Optional[concurrent.futures.Future] = None,
  ) -> ResultDict:
      """Runs evaluation step via `self.evaluate()` and handling worker failures.

      Args:
          parallel_train_future: In case we are training and evaluating in
              parallel, this arg carries the `concurrent.futures.Future`
              returned by the executor's `submit()` call for the training
              iteration (not the executor itself). Use
              `parallel_train_future.done()` to check whether the parallel
              training job has completed and `parallel_train_future.result()`
              to get its return values.

      Returns:
          A dict mapping `EVALUATION_RESULTS` to the results of the evaluation
          call.
      """
      with TimerAndPrometheusLogger(self._metrics_run_one_evaluation_time):
          # Before evaluating, try to bring failed eval EnvRunners back.
          if self.eval_env_runner_group is not None:
              if self.config.enable_env_runner_and_connector_v2:
                  with self.metrics.log_time(
                      (TIMERS, RESTORE_EVAL_ENV_RUNNERS_TIMER)
                  ):
                      restored = self.restore_env_runners(self.eval_env_runner_group)
              else:
                  # Old API stack: Timing goes through `self._timers`.
                  with self._timers["restore_eval_workers"]:
                      restored = self.restore_env_runners(self.eval_env_runner_group)
              # Fire the callback for re-created EnvRunners.
              if restored:
                  self._make_on_env_runners_recreated_callbacks(
                      config=self.evaluation_config,
                      env_runner_group=self.eval_env_runner_group,
                      restored_env_runner_indices=restored,
                  )

          # Run `self.evaluate()` only once per training iteration.
          if self.config.enable_env_runner_and_connector_v2:
              with self.metrics.log_time((TIMERS, EVALUATION_ITERATION_TIMER)):
                  eval_results = self.evaluate(
                      parallel_train_future=parallel_train_future,
                  )
          else:
              with self._timers[EVALUATION_ITERATION_TIMER]:
                  eval_results = self.evaluate(
                      parallel_train_future=parallel_train_future,
                  )
              # Old API stack: Also report the env steps sampled for evaluation
              # to the timer.
              self._timers[EVALUATION_ITERATION_TIMER].push_units_processed(
                  self._counters[NUM_ENV_STEPS_SAMPLED_FOR_EVALUATION_THIS_ITER]
              )

          # After evaluation, do a round of health check on remote eval workers to see if
          # any of the failed workers are back.
          if self.eval_env_runner_group is not None:
              # Add number of healthy evaluation workers after this iteration.
              eval_results[
                  "num_healthy_workers"
              ] = self.eval_env_runner_group.num_healthy_remote_workers()
              eval_results[
                  "actor_manager_num_outstanding_async_reqs"
              ] = self.eval_env_runner_group.num_in_flight_async_reqs()
              eval_results[
                  "num_remote_worker_restarts"
              ] = self.eval_env_runner_group.num_remote_worker_restarts()

      return {EVALUATION_RESULTS: eval_results}
  def _run_one_training_iteration_and_evaluation_in_parallel(
      self,
  ) -> Tuple[ResultDict, ResultDict, "TrainIterCtx"]:
      """Runs one training iteration and one evaluation step concurrently.

      The training iteration is submitted to a ThreadPoolExecutor (using the new
      or the old API stack variant, depending on the config), while the
      evaluation step runs in the calling thread. The training future is handed
      to the evaluation step, which allows evaluation to run as long as the
      training iteration takes in auto-duration mode
      (config.evaluation_duration=auto).

      Returns:
          A tuple containing the training results, the evaluation results, and
          the `TrainIterCtx` object returned by the training call.
      """
      with concurrent.futures.ThreadPoolExecutor() as pool:
          # Pick the training entry point matching the configured API stack.
          if self.config.enable_env_runner_and_connector_v2:
              train_fn = self._run_one_training_iteration
          else:
              train_fn = self._run_one_training_iteration_old_api_stack
          train_future = pool.submit(train_fn)

          # Evaluate in this thread while training runs in the background.
          evaluation_results = self._run_one_evaluation(
              parallel_train_future=train_future
          )
          # Blocks until the training iteration has finished.
          train_results, train_iter_ctx = train_future.result()

      return train_results, evaluation_results, train_iter_ctx
  def _run_offline_evaluation_old_api_stack(self):
      """Runs offline evaluation via `OfflineEvaluator.estimate_on_dataset()` API.

      This method will be used when `evaluation_dataset` is provided.

      Note: This will only work if the policy is a single agent policy.

      Returns:
          A dict with the single key "off_policy_estimator", which maps each
          reward estimator's name to its estimate on the evaluation dataset.
      """
      # Only a single policy on the local EnvRunner is supported.
      assert len(self.env_runner_group.local_env_runner.policy_map) == 1

      num_parallel = self.evaluation_config.evaluation_num_env_runners or 1
      estimates = {
          name: estimator.estimate_on_dataset(
              self.evaluation_dataset,
              n_parallelism=num_parallel,
          )
          for name, estimator in self.reward_estimators.items()
      }
      return {"off_policy_estimator": estimates}
  @classmethod
  def _should_create_evaluation_env_runners(cls, eval_config: "AlgorithmConfig"):
      """Determines whether we need to create evaluation workers.

      Returns False if we need to run offline evaluation
      (with ope.estimate_on_dataset API) or when local worker is to be used for
      evaluation. Note: We only use estimate_on_dataset API with bandits for now.
      That is when ope_split_batch_by_episode is False.

      TODO: In future we will do the same for episodic RL OPE.
      """
      uses_dataset_ope = (
          eval_config.off_policy_estimation_methods
          and not eval_config.ope_split_batch_by_episode
      )
      # Offline (dataset-based) evaluation does not need evaluation EnvRunners.
      if uses_dataset_ope:
          return False
      return (
          eval_config.evaluation_num_env_runners > 0
          or eval_config.evaluation_interval
      )
  3560. # TODO (simon, sven): Flexibilize the different env/offline components and move
  3561. # away from the currently hard-coded: (1) eval `EnvRunnerGroup`, (2) OfflineData
  3562. # and (3) `OfflineEvaluationRunnerGroup`.
  @classmethod
  def _should_create_offline_evaluation_runners(cls, eval_config: "AlgorithmConfig"):
      """Determines whether we need to create offline evaluation workers.

      This is the case if an `offline_evaluation_interval` is set or if
      `num_offline_eval_runners` is greater than 0.
      """
      if eval_config.offline_evaluation_interval is not None:
          return True
      return eval_config.num_offline_eval_runners > 0
  def _compile_iteration_results(self, *, train_results, eval_results):
      """Combines training and evaluation results into one result dict.

      Args:
          train_results: The training results; a shallow copy of them forms the
              basis of the returned dict.
          eval_results: The evaluation results. If non-empty, must be a dict with
              `EVALUATION_RESULTS` as its only key.

      Returns:
          The combined results, including the lifetime env-steps-sampled count
          and, if an EnvRunnerGroup exists, fault tolerance and async request
          stats.

      Raises:
          ValueError: If `self._timers` is non-empty (no longer supported on the
              new API stack).
      """
      with TimerAndPrometheusLogger(self._metrics_compile_iteration_results_time):
          # Error if users still use `self._timers`.
          if self._timers:
              raise ValueError(
                  "`Algorithm._timers` is no longer supported on the new API stack! "
                  "Instead, use `Algorithm.metrics.log_time("
                  "[some key (str) or nested key sequence (tuple)])`, e.g. inside your "
                  "custom `training_step()` method, do: "
                  "`with self.metrics.log_time(('timers', 'my_block_to_be_timed')): ...`"
              )

          # Work on a shallow copy to leave `train_results` untouched.
          results: ResultDict = train_results.copy()

          # Make sure the lifetime counter is available at the top level.
          if NUM_ENV_STEPS_SAMPLED_LIFETIME not in results:
              env_runner_results = results.get(ENV_RUNNER_RESULTS, {})
              results[NUM_ENV_STEPS_SAMPLED_LIFETIME] = env_runner_results.get(
                  NUM_ENV_STEPS_SAMPLED_LIFETIME, 0
              )

          # Evaluation results.
          if eval_results:
              assert isinstance(eval_results, dict)
              assert len(eval_results) == 1
              assert EVALUATION_RESULTS in eval_results
              results.update(eval_results)

          # EnvRunner actors fault tolerance stats.
          runner_group = self.env_runner_group
          if runner_group:
              results[FAULT_TOLERANCE_STATS] = {
                  "num_healthy_workers": runner_group.num_healthy_remote_workers(),
                  "num_remote_worker_restarts": (
                      runner_group.num_remote_worker_restarts()
                  ),
              }
              results["env_runner_group"] = {
                  "actor_manager_num_outstanding_async_reqs": (
                      runner_group.num_in_flight_async_reqs()
                  ),
              }

      return results
  def _make_on_env_runners_recreated_callbacks(
      self,
      *,
      config,
      env_runner_group,
      restored_env_runner_indices,
  ):
      """Fires the callbacks signalling that EnvRunners have been re-created.

      Both the `on_env_runners_recreated` callback and the older
      `on_workers_recreated` callback are made.

      Args:
          config: The config providing the `callbacks_on_env_runners_recreated`
              callables and the `in_evaluation` flag.
          env_runner_group: The EnvRunnerGroup whose EnvRunners were re-created.
          restored_env_runner_indices: The indices of the re-created EnvRunners.
      """
      make_callback(
          "on_env_runners_recreated",
          callbacks_objects=self.callbacks,
          callbacks_functions=config.callbacks_on_env_runners_recreated,
          kwargs={
              "algorithm": self,
              "env_runner_group": env_runner_group,
              "env_runner_indices": restored_env_runner_indices,
              "is_evaluation": config.in_evaluation,
          },
      )
      # TODO (sven): Deprecate this call.
      make_callback(
          "on_workers_recreated",
          callbacks_objects=self.callbacks,
          kwargs={
              "algorithm": self,
              "worker_set": env_runner_group,
              "worker_ids": restored_env_runner_indices,
              "is_evaluation": config.in_evaluation,
          },
      )
  def __repr__(self):
      """Returns the class name, plus key config settings if available.

      With `enable_rl_module_and_learner` set, the env, the number of EnvRunners
      and Learners, and the multi-agent flag are included. Otherwise, only the
      class name is returned.
      """
      cls_name = type(self).__name__
      if not self.config.enable_rl_module_and_learner:
          return cls_name

      details = "; ".join(
          (
              f"env={self.config.env}",
              f"env-runners={self.config.num_env_runners}",
              f"learners={self.config.num_learners}",
              f"multi-agent={self.config.is_multi_agent}",
          )
      )
      return f"{cls_name}({details})"
  @property
  def env_runner(self):
      """The local EnvRunner of the algo's EnvRunnerGroup (None if no group)."""
      group = self.env_runner_group
      return group.local_env_runner if group else None
  @property
  def eval_env_runner(self):
      """The local EnvRunner of the algo's eval EnvRunnerGroup (None if no group)."""
      group = self.eval_env_runner_group
      return group.local_env_runner if group else None
  def _record_usage(self, config):
      """Records usage tags for the framework, worker count, and algorithm.

      Args:
          config: Algorithm config dict. Its "framework" and "num_env_runners"
              entries are recorded.
      """
      record_extra_usage_tag(TagKey.RLLIB_FRAMEWORK, config["framework"])
      record_extra_usage_tag(TagKey.RLLIB_NUM_WORKERS, str(config["num_env_runners"]))

      # We do not want to collect user defined algorithm names.
      algo_name = self.__class__.__name__
      reported_name = algo_name if algo_name in ALL_ALGORITHMS else "USER_DEFINED"
      record_extra_usage_tag(TagKey.RLLIB_ALGORITHM, reported_name)
  @OldAPIStack
  def _export_model(
      self, export_formats: List[str], export_dir: str
  ) -> Dict[str, str]:
      """Exports the policy in each of the requested formats.

      Args:
          export_formats: The formats to export. Validated through
              `ExportFormat.validate()`.
          export_dir: The directory under which one sub-path per format is
              created.

      Returns:
          A dict mapping each exported format to the path it was written to.
      """
      ExportFormat.validate(export_formats)

      def _export_onnx(target):
          # The ONNX opset comes from the `ONNX_OPSET` env variable.
          self.export_policy_model(
              target, onnx=int(os.getenv("ONNX_OPSET", "11"))
          )

      # Supported formats (in export order) and the function handling each.
      exporters = (
          (ExportFormat.CHECKPOINT, self.export_policy_checkpoint),
          (ExportFormat.MODEL, self.export_policy_model),
          (ExportFormat.ONNX, _export_onnx),
      )

      exported = {}
      for export_format, export_fn in exporters:
          if export_format not in export_formats:
              continue
          target_path = os.path.join(export_dir, export_format)
          export_fn(target_path)
          exported[export_format] = target_path
      return exported
  @OldAPIStack
  def __getstate__(self) -> Dict:
      """Returns current state of Algorithm, sufficient to restore it from scratch.

      Returns:
          The current state dict of this Algorithm, which can be used to sufficiently
          restore the algorithm from scratch without any other information.

      Raises:
          RuntimeError: If `enable_env_runner_and_connector_v2` is set in the
              config (use `Algorithm.get_state()` in that case).
      """
      if self.config.enable_env_runner_and_connector_v2:
          raise RuntimeError(
              "Algorithm.__getstate__() not supported anymore on the new API stack! "
              "Use Algorithm.get_state() instead."
          )

      # Add config to state so complete Algorithm can be reproduced w/o it.
      state = {
          "algorithm_class": type(self),
          "config": self.config.get_state(),
      }

      # State of the local EnvRunner.
      if hasattr(self, "env_runner_group"):
          local_runner = self.env_runner_group.local_env_runner
          state["worker"] = local_runner.get_state()

      # Also store eval `policy_mapping_fn` (in case it's different from main
      # one). Note, the new `EnvRunner API` has no policy mapping function.
      if getattr(self, "eval_env_runner_group", None) is not None:
          state["eval_policy_mapping_fn"] = self.eval_env_runner.policy_mapping_fn

      # Save counters.
      state["counters"] = self._counters

      # TODO: Experimental functionality: Store contents of replay buffer
      # to checkpoint, only if user has configured this.
      replay_buffer = self.local_replay_buffer
      if replay_buffer is not None and self.config.get("store_buffer_in_checkpoints"):
          state["local_replay_buffer"] = replay_buffer.get_state()

      # Save current `training_iteration`.
      state[TRAINING_ITERATION] = self.training_iteration

      return state
  @OldAPIStack
  def __setstate__(self, state) -> None:
      """Sets the algorithm to the provided state.

      Restores (where present in `state`) the EnvRunners' state, the evaluation
      EnvRunners' state, the local replay buffer contents, the counters, and the
      training iteration.

      Args:
          state: The state dict to restore this Algorithm instance to. `state` may
              have been returned by a call to an Algorithm's `__getstate__()` method.

      Raises:
          RuntimeError: If `enable_env_runner_and_connector_v2` is set in the
              config (use `Algorithm.set_state()` in that case).
      """
      if self.config.enable_env_runner_and_connector_v2:
          raise RuntimeError(
              "Algorithm.__setstate__() not supported anymore on the new API stack! "
              "Use Algorithm.set_state() instead."
          )

      # Old API stack: The local worker stores its state (together with all the
      # Module information) in state['worker'].
      if hasattr(self, "env_runner_group") and "worker" in state and state["worker"]:
          # Restore the local EnvRunner directly ...
          self.env_runner.set_state(state["worker"])
          # ... and hand the state to the other EnvRunners through an object
          # reference created with `ray.put()`.
          remote_state_ref = ray.put(state["worker"])
          self.env_runner_group.foreach_env_runner(
              lambda w: w.set_state(ray.get(remote_state_ref)),
              local_env_runner=False,
          )
          if self.eval_env_runner_group:
              # Avoid `state` being pickled into the remote function below.
              _eval_policy_mapping_fn = state.get("eval_policy_mapping_fn")

              def _setup_eval_worker(w):
                  w.set_state(ray.get(remote_state_ref))
                  # Override `policy_mapping_fn` as it might be different for eval
                  # workers.
                  w.set_policy_mapping_fn(_eval_policy_mapping_fn)

              # If evaluation workers are used, also restore the policies
              # there in case they are used for evaluation purpose.
              self.eval_env_runner_group.foreach_env_runner(_setup_eval_worker)

      # Restore replay buffer data.
      if self.local_replay_buffer is not None:
          # TODO: Experimental functionality: Restore contents of replay
          # buffer from checkpoint, only if user has configured this.
          if self.config.store_buffer_in_checkpoints:
              if "local_replay_buffer" in state:
                  self.local_replay_buffer.set_state(state["local_replay_buffer"])
              else:
                  logger.warning(
                      "`store_buffer_in_checkpoints` is True, but no replay "
                      "data found in state!"
                  )
          # Buffer data exists in the state, but restoring it is not configured:
          # warn about this (guarded by `log_once`).
          elif "local_replay_buffer" in state and log_once(
              "no_store_buffer_in_checkpoints_but_data_found"
          ):
              logger.warning(
                  "`store_buffer_in_checkpoints` is False, but some replay "
                  "data found in state!"
              )

      # Restore counters and the training iteration, if present in `state`.
      if "counters" in state:
          self._counters = state["counters"]

      if TRAINING_ITERATION in state:
          self._iteration = state[TRAINING_ITERATION]
  @OldAPIStack
  @staticmethod
  def _checkpoint_info_to_algorithm_state(
      checkpoint_info: dict,
      *,
      policy_ids: Optional[Collection[PolicyID]] = None,
      policy_mapping_fn: Optional[Callable[[AgentID, EpisodeID], PolicyID]] = None,
      policies_to_train: Optional[
          Union[
              Collection[PolicyID],
              Callable[[PolicyID, Optional[SampleBatchType]], bool],
          ]
      ] = None,
  ) -> Dict:
      """Converts a checkpoint info or object to a proper Algorithm state dict.

      The returned state dict can be used inside self.__setstate__().

      Args:
          checkpoint_info: A checkpoint info dict as returned by
              `ray.rllib.utils.checkpoints.get_checkpoint_info(
              [checkpoint dir or AIR Checkpoint])`.
          policy_ids: Optional list/set of PolicyIDs. If not None, only those policies
              listed here will be included in the returned state. Note that
              state items such as filters, the `is_policy_to_train` function, as
              well as the multi-agent `policy_ids` dict will be adjusted as well,
              based on this arg.
          policy_mapping_fn: An optional (updated) policy mapping function
              to include in the returned state.
          policies_to_train: An optional list of policy IDs to be trained
              or a callable taking PolicyID and SampleBatchType and
              returning a bool (trainable or not?) to include in the returned state.

      Returns:
          The state dict usable within the `self.__setstate__()` method.

      Raises:
          ValueError: If `checkpoint_info` does not describe an Algorithm
              checkpoint or if a required policy state file is missing.
      """
      if checkpoint_info["type"] != "Algorithm":
          raise ValueError(
              "`checkpoint` arg passed to "
              "`Algorithm._checkpoint_info_to_algorithm_state()` must be an "
              f"Algorithm checkpoint (but is {checkpoint_info['type']})!"
          )

      # `msgpack` is only non-None for msgpack-formatted checkpoints. All format
      # decisions below rely on it, so a missing "format" key is consistently
      # treated as a pickle checkpoint.
      msgpack = None
      if checkpoint_info.get("format") == "msgpack":
          msgpack = try_import_msgpack(error=True)

      with open(checkpoint_info["state_file"], "rb") as f:
          if msgpack is not None:
              data = f.read()
              state = msgpack.unpackb(data, raw=False)
          else:
              # SECURITY NOTE: Unpickling executes arbitrary code. Only load
              # checkpoints from trusted sources.
              state = pickle.load(f)

      # Old API stack: Policies are in separate sub-dirs.
      if checkpoint_info["checkpoint_version"] > version.Version(
          "0.1"
      ) and state.get("worker"):
          worker_state = state["worker"]

          # Retrieve the set of all required policy IDs.
          policy_ids = set(
              policy_ids if policy_ids is not None else worker_state["policy_ids"]
          )

          # Remove those policies entirely from filters that are not in
          # `policy_ids`.
          worker_state["filters"] = {
              pid: policy_filter
              for pid, policy_filter in worker_state["filters"].items()
              if pid in policy_ids
          }

          # Get Algorithm class.
          if isinstance(state["algorithm_class"], str):
              # Try deserializing from a full classpath.
              # Or as a last resort: Tune registered algorithm name.
              state["algorithm_class"] = deserialize_type(
                  state["algorithm_class"]
              ) or get_trainable_cls(state["algorithm_class"])

          # Compile actual config object.
          default_config = state["algorithm_class"].get_default_config()
          if isinstance(default_config, AlgorithmConfig):
              new_config = default_config.update_from_dict(state["config"])
          else:
              new_config = Algorithm.merge_algorithm_configs(
                  default_config, state["config"]
              )

          # Remove policies from multiagent dict that are not in `policy_ids`.
          new_policies = new_config.policies
          if isinstance(new_policies, (set, list, tuple)):
              new_policies = {pid for pid in new_policies if pid in policy_ids}
          else:
              new_policies = {
                  pid: spec for pid, spec in new_policies.items() if pid in policy_ids
              }
          new_config.multi_agent(
              policies=new_policies,
              policies_to_train=policies_to_train,
              **(
                  {"policy_mapping_fn": policy_mapping_fn}
                  if policy_mapping_fn is not None
                  else {}
              ),
          )
          state["config"] = new_config

          # Prepare local `worker` state to add policies' states into it,
          # read from separate policy checkpoint files.
          worker_state["policy_states"] = {}
          policy_state_filename = "policy_state." + (
              "msgpck" if msgpack is not None else "pkl"
          )
          for pid in policy_ids:
              policy_state_file = os.path.join(
                  checkpoint_info["checkpoint_dir"],
                  "policies",
                  pid,
                  policy_state_filename,
              )
              if not os.path.isfile(policy_state_file):
                  raise ValueError(
                      "Given checkpoint does not seem to be valid! No policy "
                      f"state file found for PID={pid}. "
                      f"The file not found is: {policy_state_file}."
                  )

              with open(policy_state_file, "rb") as f:
                  if msgpack is not None:
                      worker_state["policy_states"][pid] = msgpack.load(f)
                  else:
                      # SECURITY NOTE: See the note on unpickling above.
                      worker_state["policy_states"][pid] = pickle.load(f)

          # These two functions are never serialized in a msgpack checkpoint (which
          # does not store code, unlike a cloudpickle checkpoint). Hence the user has
          # to provide them with the `Algorithm.from_checkpoint()` call.
          if policy_mapping_fn is not None:
              worker_state["policy_mapping_fn"] = policy_mapping_fn
          if (
              policies_to_train is not None
              # `policies_to_train` might be left None in case all policies should be
              # trained.
              or worker_state["is_policy_to_train"] == NOT_SERIALIZABLE
          ):
              worker_state["is_policy_to_train"] = policies_to_train

          # Point to the Learner state sub-directory of the checkpoint.
          if state["config"].enable_rl_module_and_learner:
              state["learner_state_dir"] = os.path.join(
                  checkpoint_info["checkpoint_dir"], "learner"
              )

      return state
  3924. @OldAPIStack
  3925. def _create_local_replay_buffer_if_necessary(
  3926. self, config: PartialAlgorithmConfigDict
  3927. ) -> Optional[MultiAgentReplayBuffer]:
  3928. """Create a MultiAgentReplayBuffer instance if necessary.
  3929. Args:
  3930. config: Algorithm-specific configuration data.
  3931. Returns:
  3932. MultiAgentReplayBuffer instance based on algorithm config.
  3933. None, if local replay buffer is not needed.
  3934. """
  3935. if not config.get("replay_buffer_config") or config["replay_buffer_config"].get(
  3936. "no_local_replay_buffer"
  3937. ):
  3938. return
  3939. # Add parameters, if necessary.
  3940. if "EpisodeReplayBuffer" in config["replay_buffer_config"]["type"]:
  3941. # TODO (simon): Subclassing needs a proper class and therefore
  3942. # we need at this moment the string checking. Because we add
  3943. # this keyword argument the old stack ReplayBuffer constructors
  3944. # will exit with an error b/c tje keyword argument is unknown to them.
  3945. config["replay_buffer_config"][
  3946. "metrics_num_episodes_for_smoothing"
  3947. ] = self.config.metrics_num_episodes_for_smoothing
  3948. return from_config(ReplayBuffer, config["replay_buffer_config"])
@OldAPIStack
def _run_one_training_iteration_old_api_stack(self):
    """Repeatedly call `self.training_step()` until the iteration may end.

    The loop is driven by a `TrainIterCtx`: it keeps going for as long as
    `train_iter_ctx.should_stop(...)` returns False for the most recent
    `training_step()` return value.

    Returns:
        A 2-tuple `(results, train_iter_ctx)`, where `results` is the most
        recent truthy return value of `self.training_step()` (an empty dict if
        there was none) and `train_iter_ctx` is the context object that
        tracked the loop.
    """
    # NOTE(review): the nesting of the `with`/`while` blocks below was
    # reconstructed from a whitespace-flattened listing - confirm against the
    # upstream file.
    with self._timers[TRAINING_ITERATION_TIMER]:
        # Switch eager execution on if the configured framework is "tf2", but
        # TF currently reports that it is not executing eagerly.
        if self.config.get("framework") == "tf2" and not tf.executing_eagerly():
            tf1.enable_eager_execution()

        # `results` keeps the last truthy training-step output;
        # `training_step_results` is what `should_stop()` gets to inspect
        # (None before the very first step).
        results = {}
        training_step_results = None
        with TrainIterCtx(algo=self) as train_iter_ctx:
            while not train_iter_ctx.should_stop(training_step_results):
                # Before each training step, try to restore EnvRunners.
                with self._timers["restore_workers"]:
                    restored = self.restore_env_runners(self.env_runner_group)
                    # Fire the callback for re-created EnvRunners.
                    if restored:
                        self._make_on_env_runners_recreated_callbacks(
                            config=self.config,
                            env_runner_group=self.env_runner_group,
                            restored_env_runner_indices=restored,
                        )

                with self._timers[TRAINING_STEP_TIMER]:
                    training_step_results = self.training_step()

                # Only overwrite `results` with a truthy step output, so a
                # falsy (e.g. None) output does not wipe earlier results.
                if training_step_results:
                    results = training_step_results

    return results, train_iter_ctx
  3972. @OldAPIStack
  3973. def _compile_iteration_results_old_api_stack(
  3974. self, *, episodes_this_iter, step_ctx, iteration_results
  3975. ):
  3976. # Results to be returned.
  3977. results: ResultDict = {}
  3978. # Evaluation results.
  3979. if "evaluation" in iteration_results:
  3980. eval_results = iteration_results.pop("evaluation")
  3981. iteration_results.pop(EVALUATION_RESULTS, None)
  3982. results["evaluation"] = results[EVALUATION_RESULTS] = eval_results
  3983. # Custom metrics and episode media.
  3984. results["custom_metrics"] = iteration_results.pop("custom_metrics", {})
  3985. results["episode_media"] = iteration_results.pop("episode_media", {})
  3986. # Learner info.
  3987. results["info"] = {LEARNER_INFO: iteration_results}
  3988. # Calculate how many (if any) of older, historical episodes we have to add to
  3989. # `episodes_this_iter` in order to reach the required smoothing window.
  3990. episodes_for_metrics = episodes_this_iter[:]
  3991. missing = self.config.metrics_num_episodes_for_smoothing - len(
  3992. episodes_this_iter
  3993. )
  3994. # We have to add some older episodes to reach the smoothing window size.
  3995. if missing > 0:
  3996. episodes_for_metrics = self._episode_history[-missing:] + episodes_this_iter
  3997. assert (
  3998. len(episodes_for_metrics)
  3999. <= self.config.metrics_num_episodes_for_smoothing
  4000. )
  4001. # Note that when there are more than `metrics_num_episodes_for_smoothing`
  4002. # episodes in `episodes_for_metrics`, leave them as-is. In this case, we'll
  4003. # compute the stats over that larger number.
  4004. # Add new episodes to our history and make sure it doesn't grow larger than
  4005. # needed.
  4006. self._episode_history.extend(episodes_this_iter)
  4007. self._episode_history = self._episode_history[
  4008. -self.config.metrics_num_episodes_for_smoothing :
  4009. ]
  4010. results[ENV_RUNNER_RESULTS] = summarize_episodes(
  4011. episodes_for_metrics,
  4012. episodes_this_iter,
  4013. self.config.keep_per_episode_custom_metrics,
  4014. )
  4015. results[
  4016. "num_healthy_workers"
  4017. ] = self.env_runner_group.num_healthy_remote_workers()
  4018. results[
  4019. "actor_manager_num_outstanding_async_reqs"
  4020. ] = self.env_runner_group.num_in_flight_async_reqs()
  4021. results[
  4022. "num_remote_worker_restarts"
  4023. ] = self.env_runner_group.num_remote_worker_restarts()
  4024. # Train-steps- and env/agent-steps this iteration.
  4025. for c in [
  4026. NUM_AGENT_STEPS_SAMPLED,
  4027. NUM_AGENT_STEPS_TRAINED,
  4028. NUM_ENV_STEPS_SAMPLED,
  4029. NUM_ENV_STEPS_TRAINED,
  4030. ]:
  4031. results[c] = self._counters[c]
  4032. time_taken_sec = step_ctx.get_time_taken_sec()
  4033. if self.config.count_steps_by == "agent_steps":
  4034. results[NUM_AGENT_STEPS_SAMPLED + "_this_iter"] = step_ctx.sampled
  4035. results[NUM_AGENT_STEPS_TRAINED + "_this_iter"] = step_ctx.trained
  4036. results[NUM_AGENT_STEPS_SAMPLED + "_throughput_per_sec"] = (
  4037. step_ctx.sampled / time_taken_sec
  4038. )
  4039. results[NUM_AGENT_STEPS_TRAINED + "_throughput_per_sec"] = (
  4040. step_ctx.trained / time_taken_sec
  4041. )
  4042. # TODO: For CQL and other algos, count by trained steps.
  4043. results["timesteps_total"] = self._counters[NUM_AGENT_STEPS_SAMPLED]
  4044. else:
  4045. results[NUM_ENV_STEPS_SAMPLED + "_this_iter"] = step_ctx.sampled
  4046. results[NUM_ENV_STEPS_TRAINED + "_this_iter"] = step_ctx.trained
  4047. results[NUM_ENV_STEPS_SAMPLED + "_throughput_per_sec"] = (
  4048. step_ctx.sampled / time_taken_sec
  4049. )
  4050. results[NUM_ENV_STEPS_TRAINED + "_throughput_per_sec"] = (
  4051. step_ctx.trained / time_taken_sec
  4052. )
  4053. # TODO: For CQL and other algos, count by trained steps.
  4054. results["timesteps_total"] = self._counters[NUM_ENV_STEPS_SAMPLED]
  4055. # Forward compatibility with new API stack.
  4056. results[NUM_ENV_STEPS_SAMPLED_LIFETIME] = results["timesteps_total"]
  4057. results[NUM_AGENT_STEPS_SAMPLED_LIFETIME] = self._counters[
  4058. NUM_AGENT_STEPS_SAMPLED
  4059. ]
  4060. # TODO: Backward compatibility.
  4061. results[STEPS_TRAINED_THIS_ITER_COUNTER] = step_ctx.trained
  4062. results["agent_timesteps_total"] = self._counters[NUM_AGENT_STEPS_SAMPLED]
  4063. # Process timer results.
  4064. timers = {}
  4065. for k, timer in self._timers.items():
  4066. timers["{}_time_ms".format(k)] = round(timer.mean * 1000, 3)
  4067. if timer.has_units_processed():
  4068. timers["{}_throughput".format(k)] = round(timer.mean_throughput, 3)
  4069. results["timers"] = timers
  4070. # Process counter results.
  4071. counters = {}
  4072. for k, counter in self._counters.items():
  4073. counters[k] = counter
  4074. results["counters"] = counters
  4075. # TODO: Backward compatibility.
  4076. results["info"].update(counters)
  4077. return results
  4078. @OldAPIStack
  4079. @Deprecated(
  4080. help="`Algorithm.compute_single_action` should no longer be used. Get the "
  4081. "RLModule instance through `Algorithm.get_module([module ID])`, then compute "
  4082. "actions through `RLModule.forward_inference({'obs': [obs batch]})`.",
  4083. error=False,
  4084. )
  4085. def compute_single_action(
  4086. self,
  4087. observation: Optional[TensorStructType] = None,
  4088. state: Optional[List[TensorStructType]] = None,
  4089. *,
  4090. prev_action: Optional[TensorStructType] = None,
  4091. prev_reward: Optional[float] = None,
  4092. info: Optional[EnvInfoDict] = None,
  4093. input_dict: Optional[SampleBatch] = None,
  4094. policy_id: PolicyID = DEFAULT_POLICY_ID,
  4095. full_fetch: bool = False,
  4096. explore: Optional[bool] = None,
  4097. timestep: Optional[int] = None,
  4098. episode=None,
  4099. unsquash_action: Optional[bool] = None,
  4100. clip_action: Optional[bool] = None,
  4101. ) -> Union[
  4102. TensorStructType,
  4103. Tuple[TensorStructType, List[TensorType], Dict[str, TensorType]],
  4104. ]:
  4105. if unsquash_action is None:
  4106. unsquash_action = self.config.normalize_actions
  4107. elif clip_action is None:
  4108. clip_action = self.config.clip_actions
  4109. err_msg = (
  4110. "Provide either `input_dict` OR [`observation`, ...] as "
  4111. "args to `Algorithm.compute_single_action()`!"
  4112. )
  4113. if input_dict is not None:
  4114. assert (
  4115. observation is None
  4116. and prev_action is None
  4117. and prev_reward is None
  4118. and state is None
  4119. ), err_msg
  4120. observation = input_dict[Columns.OBS]
  4121. else:
  4122. assert observation is not None, err_msg
  4123. policy = self.get_policy(policy_id)
  4124. if policy is None:
  4125. raise KeyError(
  4126. f"PolicyID '{policy_id}' not found in PolicyMap of the "
  4127. f"Algorithm's local worker!"
  4128. )
  4129. pp = policy.agent_connectors[ObsPreprocessorConnector]
  4130. if not isinstance(observation, (np.ndarray, dict, tuple)):
  4131. try:
  4132. observation = np.asarray(observation)
  4133. except Exception:
  4134. raise ValueError(
  4135. f"Observation type {type(observation)} cannot be converted to "
  4136. f"np.ndarray."
  4137. )
  4138. if pp:
  4139. assert len(pp) == 1, "Only one preprocessor should be in the pipeline"
  4140. pp = pp[0]
  4141. if not pp.is_identity():
  4142. pp.in_eval()
  4143. if observation is not None:
  4144. _input_dict = {Columns.OBS: observation}
  4145. elif input_dict is not None:
  4146. _input_dict = {Columns.OBS: input_dict[Columns.OBS]}
  4147. else:
  4148. raise ValueError(
  4149. "Either observation or input_dict must be provided."
  4150. )
  4151. acd = AgentConnectorDataType("0", "0", _input_dict)
  4152. pp.reset(env_id="0")
  4153. ac_o = pp([acd])[0]
  4154. observation = ac_o.data[Columns.OBS]
  4155. if input_dict is not None:
  4156. input_dict[Columns.OBS] = observation
  4157. action, state, extra = policy.compute_single_action(
  4158. input_dict=input_dict,
  4159. explore=explore,
  4160. timestep=timestep,
  4161. episode=episode,
  4162. )
  4163. else:
  4164. action, state, extra = policy.compute_single_action(
  4165. obs=observation,
  4166. state=state,
  4167. prev_action=prev_action,
  4168. prev_reward=prev_reward,
  4169. info=info,
  4170. explore=explore,
  4171. timestep=timestep,
  4172. episode=episode,
  4173. )
  4174. if unsquash_action:
  4175. action = space_utils.unsquash_action(action, policy.action_space_struct)
  4176. elif clip_action:
  4177. action = space_utils.clip_action(action, policy.action_space_struct)
  4178. if state or full_fetch:
  4179. return action, state, extra
  4180. else:
  4181. return action
  4182. @OldAPIStack
  4183. @Deprecated(
  4184. help="`Algorithm.compute_actions` should no longer be used. Get the RLModule "
  4185. "instance through `Algorithm.get_module([module ID])`, then compute actions "
  4186. "through `RLModule.forward_inference({'obs': [obs batch]})`.",
  4187. error=False,
  4188. )
  4189. def compute_actions(
  4190. self,
  4191. observations: TensorStructType,
  4192. state: Optional[List[TensorStructType]] = None,
  4193. *,
  4194. prev_action: Optional[TensorStructType] = None,
  4195. prev_reward: Optional[TensorStructType] = None,
  4196. info: Optional[EnvInfoDict] = None,
  4197. policy_id: PolicyID = DEFAULT_POLICY_ID,
  4198. full_fetch: bool = False,
  4199. explore: Optional[bool] = None,
  4200. timestep: Optional[int] = None,
  4201. episodes=None,
  4202. unsquash_actions: Optional[bool] = None,
  4203. clip_actions: Optional[bool] = None,
  4204. ):
  4205. if unsquash_actions is None:
  4206. unsquash_actions = self.config.normalize_actions
  4207. elif clip_actions is None:
  4208. clip_actions = self.config.clip_actions
  4209. state_defined = state is not None
  4210. policy = self.get_policy(policy_id)
  4211. filtered_obs, filtered_state = [], []
  4212. for agent_id, ob in observations.items():
  4213. worker = self.env_runner_group.local_env_runner
  4214. if worker.preprocessors.get(policy_id) is not None:
  4215. preprocessed = worker.preprocessors[policy_id].transform(ob)
  4216. else:
  4217. preprocessed = ob
  4218. filtered = worker.filters[policy_id](preprocessed, update=False)
  4219. filtered_obs.append(filtered)
  4220. if state is None:
  4221. continue
  4222. elif agent_id in state:
  4223. filtered_state.append(state[agent_id])
  4224. else:
  4225. filtered_state.append(policy.get_initial_state())
  4226. obs_batch = np.stack(filtered_obs)
  4227. if state is None:
  4228. state = []
  4229. else:
  4230. state = list(zip(*filtered_state))
  4231. state = [np.stack(s) for s in state]
  4232. input_dict = {Columns.OBS: obs_batch}
  4233. if prev_action is not None:
  4234. input_dict[SampleBatch.PREV_ACTIONS] = prev_action
  4235. if prev_reward is not None:
  4236. input_dict[SampleBatch.PREV_REWARDS] = prev_reward
  4237. if info:
  4238. input_dict[Columns.INFOS] = info
  4239. for i, s in enumerate(state):
  4240. input_dict[f"state_in_{i}"] = s
  4241. actions, states, infos = policy.compute_actions_from_input_dict(
  4242. input_dict=input_dict,
  4243. explore=explore,
  4244. timestep=timestep,
  4245. episodes=episodes,
  4246. )
  4247. single_actions = space_utils.unbatch(actions)
  4248. actions = {}
  4249. for key, a in zip(observations, single_actions):
  4250. if unsquash_actions:
  4251. a = space_utils.unsquash_action(a, policy.action_space_struct)
  4252. elif clip_actions:
  4253. a = space_utils.clip_action(a, policy.action_space_struct)
  4254. actions[key] = a
  4255. unbatched_states = {}
  4256. for idx, agent_id in enumerate(observations):
  4257. unbatched_states[agent_id] = [s[idx] for s in states]
  4258. if state_defined or full_fetch:
  4259. return actions, unbatched_states, infos
  4260. else:
  4261. return actions
@Deprecated(new="Algorithm.restore_env_runners", error=True)
def restore_workers(self, *args, **kwargs):
    """Deprecated; the decorator names `Algorithm.restore_env_runners` instead."""
    # Intentionally empty: the body does nothing with its arguments.
    # NOTE(review): presumably `error=True` makes `Deprecated` raise on use -
    # confirm against the decorator's implementation.
    pass
@Deprecated(
    new="Algorithm.env_runner_group",
    error=True,
)
@property
def workers(self):
    """Deprecated alias; returns `self.env_runner_group`."""
    # NOTE(review): `@Deprecated` is applied on top of `@property`, i.e. it
    # receives the `property` object rather than the plain function - confirm
    # that attribute access still behaves as intended.
    return self.env_runner_group
@Deprecated(
    new="Algorithm.eval_env_runner_group",
    error=True,
)
@property
def evaluation_workers(self):
    """Deprecated alias; returns `self.eval_env_runner_group`."""
    # NOTE(review): `@Deprecated` is applied on top of `@property`, i.e. it
    # receives the `property` object rather than the plain function - confirm
    # that attribute access still behaves as intended.
    return self.eval_env_runner_group
  4279. class TrainIterCtx:
  4280. def __init__(self, algo: Algorithm):
  4281. self.algo = algo
  4282. self.time_start = None
  4283. self.time_stop = None
  4284. def __enter__(self):
  4285. # Before first call to `step()`, `results` is expected to be None ->
  4286. # Start with self.failures=-1 -> set to 0 before the very first call
  4287. # to `self.step()`.
  4288. self.failures = -1
  4289. self.time_start = time.time()
  4290. self.sampled = 0
  4291. self.trained = 0
  4292. if self.algo.config.enable_env_runner_and_connector_v2:
  4293. self.init_env_steps_sampled = self.algo.metrics.peek(
  4294. (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME), default=0
  4295. )
  4296. self.init_env_steps_trained = self.algo.metrics.peek(
  4297. (LEARNER_RESULTS, ALL_MODULES, NUM_ENV_STEPS_TRAINED_LIFETIME),
  4298. default=0,
  4299. )
  4300. self.init_agent_steps_sampled = sum(
  4301. self.algo.metrics.peek(
  4302. (ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED_LIFETIME), default={}
  4303. ).values()
  4304. )
  4305. self.init_agent_steps_trained = sum(
  4306. self.algo.metrics.peek(
  4307. (LEARNER_RESULTS, NUM_AGENT_STEPS_TRAINED_LIFETIME), default={}
  4308. ).values()
  4309. )
  4310. else:
  4311. self.init_env_steps_sampled = self.algo._counters[NUM_ENV_STEPS_SAMPLED]
  4312. self.init_env_steps_trained = self.algo._counters[NUM_ENV_STEPS_TRAINED]
  4313. self.init_agent_steps_sampled = self.algo._counters[NUM_AGENT_STEPS_SAMPLED]
  4314. self.init_agent_steps_trained = self.algo._counters[NUM_AGENT_STEPS_TRAINED]
  4315. self.failure_tolerance = (
  4316. self.algo.config.num_consecutive_env_runner_failures_tolerance
  4317. )
  4318. return self
  4319. def __exit__(self, *args):
  4320. self.time_stop = time.time()
  4321. def get_time_taken_sec(self) -> float:
  4322. """Returns the time we spent in the context in seconds."""
  4323. return self.time_stop - self.time_start
  4324. def should_stop(self, results):
  4325. # Before first call to `step()`.
  4326. if results in [None, False]:
  4327. # Fail after n retries.
  4328. self.failures += 1
  4329. if self.failures > self.failure_tolerance:
  4330. raise RuntimeError(
  4331. "More than `num_consecutive_env_runner_failures_tolerance="
  4332. f"{self.failure_tolerance}` consecutive worker failures! "
  4333. "Exiting."
  4334. )
  4335. # Continue to very first `step()` call or retry `step()` after
  4336. # a (tolerable) failure.
  4337. return False
  4338. # Stopping criteria.
  4339. if self.algo.config.enable_env_runner_and_connector_v2:
  4340. if self.algo.config.count_steps_by == "agent_steps":
  4341. self.sampled = (
  4342. sum(
  4343. self.algo.metrics.peek(
  4344. (ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED_LIFETIME),
  4345. default={},
  4346. ).values()
  4347. )
  4348. - self.init_agent_steps_sampled
  4349. )
  4350. self.trained = (
  4351. sum(
  4352. self.algo.metrics.peek(
  4353. (LEARNER_RESULTS, NUM_AGENT_STEPS_TRAINED_LIFETIME),
  4354. default={},
  4355. ).values()
  4356. )
  4357. - self.init_agent_steps_trained
  4358. )
  4359. else:
  4360. self.sampled = (
  4361. self.algo.metrics.peek(
  4362. (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME), default=0
  4363. )
  4364. - self.init_env_steps_sampled
  4365. )
  4366. self.trained = (
  4367. self.algo.metrics.peek(
  4368. (LEARNER_RESULTS, ALL_MODULES, NUM_ENV_STEPS_TRAINED_LIFETIME),
  4369. default=0,
  4370. )
  4371. - self.init_env_steps_trained
  4372. )
  4373. else:
  4374. if self.algo.config.count_steps_by == "agent_steps":
  4375. self.sampled = (
  4376. self.algo._counters[NUM_AGENT_STEPS_SAMPLED]
  4377. - self.init_agent_steps_sampled
  4378. )
  4379. self.trained = (
  4380. self.algo._counters[NUM_AGENT_STEPS_TRAINED]
  4381. - self.init_agent_steps_trained
  4382. )
  4383. else:
  4384. self.sampled = (
  4385. self.algo._counters[NUM_ENV_STEPS_SAMPLED]
  4386. - self.init_env_steps_sampled
  4387. )
  4388. self.trained = (
  4389. self.algo._counters[NUM_ENV_STEPS_TRAINED]
  4390. - self.init_env_steps_trained
  4391. )
  4392. min_t = self.algo.config.min_time_s_per_iteration
  4393. min_sample_ts = self.algo.config.min_sample_timesteps_per_iteration
  4394. min_train_ts = self.algo.config.min_train_timesteps_per_iteration
  4395. # Repeat if not enough time has passed or if not enough
  4396. # env|train timesteps have been processed (or these min
  4397. # values are not provided by the user).
  4398. if (
  4399. (not min_t or time.time() - self.time_start >= min_t)
  4400. and (not min_sample_ts or self.sampled >= min_sample_ts)
  4401. and (not min_train_ts or self.trained >= min_train_ts)
  4402. ):
  4403. return True
  4404. else:
  4405. return False