compiled_dag_node.py 140 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
732783279328032813282328332843285328632873288328932903291329232933294329532963297329832993300330133023303330433053306330733083309331033113312
  1. import asyncio
  2. import logging
  3. import threading
  4. import time
  5. import traceback
  6. import uuid
  7. import weakref
  8. from collections import defaultdict
  9. from contextlib import nullcontext
  10. from dataclasses import asdict, dataclass
  11. from typing import (
  12. Any,
  13. Dict,
  14. List,
  15. Optional,
  16. Set,
  17. Tuple,
  18. Union,
  19. )
  20. import ray
  21. import ray.exceptions
  22. from ray.dag.constants import (
  23. RAY_CGRAPH_ENABLE_NVTX_PROFILING,
  24. RAY_CGRAPH_ENABLE_TORCH_PROFILING,
  25. RAY_CGRAPH_VISUALIZE_SCHEDULE,
  26. )
  27. from ray.dag.dag_node_operation import (
  28. _build_dag_node_operation_graph,
  29. _DAGNodeOperation,
  30. _DAGNodeOperationType,
  31. _DAGOperationGraphNode,
  32. _extract_execution_schedule,
  33. _generate_actor_to_execution_schedule,
  34. _generate_overlapped_execution_schedule,
  35. _visualize_execution_schedule,
  36. )
  37. from ray.dag.dag_operation_future import DAGOperationFuture, GPUFuture, ResolvedFuture
  38. from ray.exceptions import (
  39. RayCgraphCapacityExceeded,
  40. RayChannelError,
  41. RayChannelTimeoutError,
  42. RayTaskError,
  43. )
  44. from ray.experimental.channel import (
  45. AwaitableBackgroundReader,
  46. AwaitableBackgroundWriter,
  47. ChannelContext,
  48. ChannelInterface,
  49. ChannelOutputType,
  50. CompiledDAGArgs,
  51. CompositeChannel,
  52. IntraProcessChannel,
  53. ReaderInterface,
  54. SynchronousReader,
  55. SynchronousWriter,
  56. WriterInterface,
  57. )
  58. from ray.experimental.channel.accelerator_context import AcceleratorContext
  59. from ray.experimental.channel.auto_transport_type import (
  60. AutoTransportType,
  61. TypeHintResolver,
  62. )
  63. from ray.experimental.channel.cached_channel import CachedChannel
  64. from ray.experimental.channel.communicator import Communicator
  65. from ray.experimental.channel.shared_memory_channel import (
  66. SharedMemoryType,
  67. )
  68. from ray.experimental.channel.torch_tensor_accelerator_channel import (
  69. _destroy_communicator,
  70. _init_communicator,
  71. )
  72. from ray.experimental.channel.torch_tensor_type import TorchTensorType
  73. from ray.experimental.compiled_dag_ref import (
  74. CompiledDAGFuture,
  75. CompiledDAGRef,
  76. _process_return_vals,
  77. )
  78. from ray.util.annotations import DeveloperAPI
  79. from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
  80. logger = logging.getLogger(__name__)
  81. # Keep tracking of every compiled dag created during the lifetime of
  82. # this process. It tracks them as weakref meaning when the compiled dag
  83. # is GC'ed, it is automatically removed from here. It is used to teardown
  84. # compiled dags at interpreter shutdown time.
  85. _compiled_dags = weakref.WeakValueDictionary()
  86. # Relying on __del__ doesn't work well upon shutdown because
  87. # the destructor order is not guaranteed. We call this function
  88. # upon `ray.worker.shutdown` which is registered to atexit handler
  89. # so that teardown is properly called before objects are destructed.
  90. def _shutdown_all_compiled_dags():
  91. global _compiled_dags
  92. for _, compiled_dag in _compiled_dags.items():
  93. # Kill DAG actors to avoid hanging during shutdown if the actor tasks
  94. # cannot be cancelled.
  95. compiled_dag.teardown(kill_actors=True)
  96. _compiled_dags = weakref.WeakValueDictionary()
  97. def _check_unused_dag_input_attributes(
  98. output_node: "ray.dag.MultiOutputNode", input_attributes: Set[str]
  99. ) -> Set[str]:
  100. """
  101. Helper function to check that all input attributes are used in the DAG.
  102. For example, if the user creates an input attribute by calling
  103. InputNode()["x"], we ensure that there is a path from the
  104. InputAttributeNode corresponding to "x" to the DAG's output. If an
  105. input attribute is not used, throw an error.
  106. Args:
  107. output_node: The starting node for the traversal.
  108. input_attributes: A set of attributes accessed by the InputNode.
  109. """
  110. from ray.dag import InputAttributeNode
  111. used_attributes = set()
  112. visited_nodes = set()
  113. stack: List["ray.dag.DAGNode"] = [output_node]
  114. while stack:
  115. current_node = stack.pop()
  116. if current_node in visited_nodes:
  117. continue
  118. visited_nodes.add(current_node)
  119. if isinstance(current_node, InputAttributeNode):
  120. used_attributes.add(current_node.key)
  121. stack.extend(current_node._upstream_nodes)
  122. unused_attributes = input_attributes - used_attributes
  123. if unused_attributes:
  124. unused_attributes_str = ", ".join(str(key) for key in unused_attributes)
  125. input_attributes_str = ", ".join(str(key) for key in input_attributes)
  126. unused_phrase = "is unused" if len(unused_attributes) == 1 else "are unused"
  127. raise ValueError(
  128. "Compiled Graph expects input to be accessed "
  129. f"using all of attributes {input_attributes_str}, "
  130. f"but {unused_attributes_str} {unused_phrase}. "
  131. "Ensure all input attributes are used and contribute "
  132. "to the computation of the Compiled Graph output."
  133. )
  134. @DeveloperAPI
  135. def do_allocate_channel(
  136. self,
  137. reader_and_node_list: List[Tuple["ray.actor.ActorHandle", str]],
  138. typ: ChannelOutputType,
  139. driver_actor_id: Optional[str] = None,
  140. ) -> ChannelInterface:
  141. """Generic actor method to allocate an output channel.
  142. Args:
  143. reader_and_node_list: A list of tuples, where each tuple contains a reader
  144. actor handle and the node ID where the actor is located.
  145. typ: The output type hint for the channel.
  146. driver_actor_id: If this channel is read by a driver and that driver is an
  147. actual actor, this will be the actor ID of that driver actor.
  148. Returns:
  149. The allocated channel.
  150. """
  151. # None means it is called from a driver.
  152. writer: Optional["ray.actor.ActorHandle"] = None
  153. try:
  154. writer = ray.get_runtime_context().current_actor
  155. except RuntimeError:
  156. # This is the driver so there is no current actor handle.
  157. pass
  158. output_channel = typ.create_channel(
  159. writer,
  160. reader_and_node_list,
  161. driver_actor_id,
  162. )
  163. return output_channel
  164. @DeveloperAPI
  165. def do_exec_tasks(
  166. self,
  167. tasks: List["ExecutableTask"],
  168. schedule: List[_DAGNodeOperation],
  169. overlap_gpu_communication: bool = False,
  170. ) -> None:
  171. """A generic actor method to begin executing the operations belonging to an
  172. actor. This runs an infinite loop to execute each _DAGNodeOperation in the
  173. order specified by the schedule. It exits only if the actor dies or an
  174. exception is thrown.
  175. Args:
  176. tasks: the executable tasks corresponding to the actor methods.
  177. schedule: A list of _DAGNodeOperation that should be executed in order.
  178. overlap_gpu_communication: Whether to overlap GPU communication with
  179. computation during DAG execution to improve performance.
  180. """
  181. try:
  182. for task in tasks:
  183. task.prepare(overlap_gpu_communication=overlap_gpu_communication)
  184. if RAY_CGRAPH_ENABLE_NVTX_PROFILING:
  185. assert (
  186. not RAY_CGRAPH_ENABLE_TORCH_PROFILING
  187. ), "NVTX and torch profiling cannot be enabled at the same time."
  188. try:
  189. import nvtx
  190. except ImportError:
  191. raise ImportError(
  192. "Please install nvtx to enable nsight profiling. "
  193. "You can install it by running `pip install nvtx`."
  194. )
  195. nvtx_profile = nvtx.Profile()
  196. nvtx_profile.enable()
  197. if RAY_CGRAPH_ENABLE_TORCH_PROFILING:
  198. assert (
  199. not RAY_CGRAPH_ENABLE_NVTX_PROFILING
  200. ), "NVTX and torch profiling cannot be enabled at the same time."
  201. import torch
  202. torch_profile = torch.profiler.profile(
  203. activities=[
  204. torch.profiler.ProfilerActivity.CPU,
  205. torch.profiler.ProfilerActivity.CUDA,
  206. ],
  207. with_stack=True,
  208. on_trace_ready=torch.profiler.tensorboard_trace_handler(
  209. "compiled_graph_torch_profiles"
  210. ),
  211. )
  212. torch_profile.start()
  213. logger.info("Torch profiling started")
  214. done = False
  215. while True:
  216. if done:
  217. break
  218. for operation in schedule:
  219. done = tasks[operation.exec_task_idx].exec_operation(
  220. self, operation.type, overlap_gpu_communication
  221. )
  222. if done:
  223. break
  224. if RAY_CGRAPH_ENABLE_NVTX_PROFILING:
  225. nvtx_profile.disable()
  226. if RAY_CGRAPH_ENABLE_TORCH_PROFILING:
  227. torch_profile.stop()
  228. logger.info("Torch profiling stopped")
  229. except Exception:
  230. logging.exception("Compiled DAG task exited with exception")
  231. raise
  232. @DeveloperAPI
  233. def do_profile_tasks(
  234. self,
  235. tasks: List["ExecutableTask"],
  236. schedule: List[_DAGNodeOperation],
  237. overlap_gpu_communication: bool = False,
  238. ) -> None:
  239. """A generic actor method similar to `do_exec_tasks`, but with profiling enabled.
  240. Args:
  241. tasks: the executable tasks corresponding to the actor methods.
  242. schedule: A list of _DAGNodeOperation that should be executed in order.
  243. overlap_gpu_communication: Whether to overlap GPU communication with
  244. computation during DAG execution to improve performance.
  245. """
  246. try:
  247. for task in tasks:
  248. task.prepare(overlap_gpu_communication=overlap_gpu_communication)
  249. if not hasattr(self, "__ray_cgraph_events"):
  250. self.__ray_cgraph_events = []
  251. done = False
  252. while True:
  253. if done:
  254. break
  255. for operation in schedule:
  256. start_t = time.perf_counter()
  257. task = tasks[operation.exec_task_idx]
  258. done = task.exec_operation(
  259. self, operation.type, overlap_gpu_communication
  260. )
  261. end_t = time.perf_counter()
  262. self.__ray_cgraph_events.append(
  263. _ExecutableTaskRecord(
  264. actor_classname=self.__class__.__name__,
  265. actor_name=ray.get_runtime_context().get_actor_name(),
  266. actor_id=ray.get_runtime_context().get_actor_id(),
  267. method_name=task.method_name,
  268. bind_index=task.bind_index,
  269. operation=operation.type.value,
  270. start_t=start_t,
  271. end_t=end_t,
  272. )
  273. )
  274. if done:
  275. break
  276. except Exception:
  277. logging.exception("Compiled DAG task exited with exception")
  278. raise
  279. @DeveloperAPI
  280. def do_cancel_executable_tasks(self, tasks: List["ExecutableTask"]) -> None:
  281. # CUDA events should be destroyed before other CUDA resources.
  282. for task in tasks:
  283. task.destroy_cuda_event()
  284. for task in tasks:
  285. task.cancel()
  286. def _wrap_exception(exc):
  287. backtrace = ray._private.utils.format_error_message(
  288. "".join(traceback.format_exception(type(exc), exc, exc.__traceback__)),
  289. task_exception=True,
  290. )
  291. wrapped = RayTaskError(
  292. function_name="do_exec_tasks",
  293. traceback_str=backtrace,
  294. cause=exc,
  295. )
  296. return wrapped
  297. def _get_comm_group_id(type_hint: ChannelOutputType) -> Optional[str]:
  298. """
  299. Get the communicator group ID from the type hint. If the type hint does not
  300. require communicator, return None.
  301. Args:
  302. type_hint: The type hint of the channel.
  303. Returns:
  304. The communicator group ID if the type hint requires communicator,
  305. otherwise None.
  306. """
  307. if type_hint.requires_accelerator():
  308. assert isinstance(type_hint, TorchTensorType)
  309. return type_hint.communicator_id
  310. return None
  311. def _device_context_manager():
  312. """
  313. Return a context manager for executing communication operations
  314. (i.e., READ and WRITE). For accelerator operations, the context manager
  315. uses the proper cuda device from channel context, otherwise,
  316. nullcontext will be returned.
  317. """
  318. if not ChannelContext.get_current().torch_available:
  319. return nullcontext()
  320. import torch
  321. from ray.experimental.channel.accelerator_context import AcceleratorContext
  322. device = ChannelContext.get_current().torch_device
  323. if device.type == "cuda" and not torch.cuda.is_available():
  324. # In the case of mocked NCCL, we may get a device with type "cuda"
  325. # but CUDA is not available. We return nullcontext() in that case,
  326. # otherwise torch raises a runtime error if the cuda device context
  327. # manager is used.
  328. # TODO(rui): consider better mocking NCCL to support device context.
  329. return nullcontext()
  330. return AcceleratorContext.get().get_device_context(device)
  331. @DeveloperAPI
  332. class CompiledTask:
  333. """Wraps the normal Ray DAGNode with some metadata."""
  334. def __init__(self, idx: int, dag_node: "ray.dag.DAGNode"):
  335. """
  336. Args:
  337. idx: A unique index into the original DAG.
  338. dag_node: The original DAG node created by the user.
  339. """
  340. self.idx = idx
  341. self.dag_node = dag_node
  342. # Dict from task index to actor handle for immediate downstream tasks.
  343. self.downstream_task_idxs: Dict[int, "ray.actor.ActorHandle"] = {}
  344. # Case 1: The task represents a ClassMethodNode.
  345. #
  346. # Multiple return values are written to separate `output_channels`.
  347. # `output_idxs` represents the tuple index of the output value for
  348. # multiple returns in a tuple. If an output index is None, it means
  349. # the complete return value is written to the output channel.
  350. # Otherwise, the return value is a tuple and the index is used
  351. # to extract the value to be written to the output channel.
  352. #
  353. # Case 2: The task represents an InputNode.
  354. #
  355. # `output_idxs` can be an integer or a string to retrieve the
  356. # corresponding value from `args` or `kwargs` in the DAG's input.
  357. self.output_channels: List[ChannelInterface] = []
  358. self.output_idxs: List[Optional[Union[int, str]]] = []
  359. # The DAGNodes that are arguments to this task.
  360. # This is used for lazy resolution of the arguments' type hints.
  361. self.arg_nodes: List["ray.dag.DAGNode"] = []
  362. # idxs of possible ClassMethodOutputNodes if they exist, used for visualization
  363. self.output_node_idxs: List[int] = []
  364. @property
  365. def args(self) -> Tuple[Any]:
  366. return self.dag_node.get_args()
  367. @property
  368. def kwargs(self) -> Dict[str, Any]:
  369. return self.dag_node.get_kwargs()
  370. @property
  371. def num_readers(self) -> int:
  372. return len(self.downstream_task_idxs)
  373. @property
  374. def arg_type_hints(self) -> List["ChannelOutputType"]:
  375. return [arg_node.type_hint for arg_node in self.arg_nodes]
  376. def __str__(self) -> str:
  377. return f"""
  378. Node: {self.dag_node}
  379. Arguments: {self.args}
  380. Output: {self.output_channels}
  381. """
  382. class _ExecutableTaskInput:
  383. """Represents an input to an ExecutableTask.
  384. Args:
  385. input_variant: either an unresolved input (when type is ChannelInterface)
  386. , or a resolved input value (when type is Any)
  387. channel_idx: if input_variant is an unresolved input, this is the index
  388. into the input channels list.
  389. """
  390. def __init__(
  391. self,
  392. input_variant: Union[ChannelInterface, Any],
  393. channel_idx: Optional[int],
  394. ):
  395. self.input_variant = input_variant
  396. self.channel_idx = channel_idx
  397. def resolve(self, channel_results: Any) -> Any:
  398. """
  399. Resolve the input value from the channel results.
  400. Args:
  401. channel_results: The results from reading the input channels.
  402. """
  403. if isinstance(self.input_variant, ChannelInterface):
  404. value = channel_results[self.channel_idx]
  405. else:
  406. value = self.input_variant
  407. return value
  408. @DeveloperAPI
  409. class ExecutableTask:
  410. """A task that can be executed in a compiled DAG, and it
  411. corresponds to an actor method.
  412. """
  413. def __init__(
  414. self,
  415. task: "CompiledTask",
  416. resolved_args: List[Any],
  417. resolved_kwargs: Dict[str, Any],
  418. ):
  419. """
  420. Args:
  421. task: The CompiledTask that this ExecutableTask corresponds to.
  422. resolved_args: The arguments to the method. Arguments that are
  423. not Channels will get passed through to the actor method.
  424. If the argument is a channel, it will be replaced by the
  425. value read from the channel before the method executes.
  426. resolved_kwargs: The keyword arguments to the method. Currently, we
  427. do not support binding kwargs to other DAG nodes, so the values
  428. of the dictionary cannot be Channels.
  429. """
  430. from ray.dag import CollectiveOutputNode
  431. self.method_name = task.dag_node.get_method_name()
  432. self.bind_index = task.dag_node._get_bind_index()
  433. self.output_channels = task.output_channels
  434. self.output_idxs = task.output_idxs
  435. self.input_type_hints: List[ChannelOutputType] = task.arg_type_hints
  436. self.output_type_hint: ChannelOutputType = task.dag_node.type_hint
  437. # The accelerator collective operation.
  438. self.collective_op: Optional["ray.dag.CollectiveOperation"] = None
  439. if isinstance(task.dag_node, CollectiveOutputNode):
  440. self.collective_op = task.dag_node.collective_op
  441. self.input_channels: List[ChannelInterface] = []
  442. self.task_inputs: List[_ExecutableTaskInput] = []
  443. self.resolved_kwargs: Dict[str, Any] = resolved_kwargs
  444. # A unique index which can be used to index into `idx_to_task` to get
  445. # the corresponding task.
  446. self.task_idx = task.idx
  447. # Reverse map for input_channels: maps an input channel to
  448. # its index in input_channels.
  449. input_channel_to_idx: dict[ChannelInterface, int] = {}
  450. for arg in resolved_args:
  451. if isinstance(arg, ChannelInterface):
  452. channel = arg
  453. if channel in input_channel_to_idx:
  454. # The same channel was added before, so reuse the index.
  455. channel_idx = input_channel_to_idx[channel]
  456. else:
  457. # Add a new channel to the list of input channels.
  458. self.input_channels.append(channel)
  459. channel_idx = len(self.input_channels) - 1
  460. input_channel_to_idx[channel] = channel_idx
  461. task_input = _ExecutableTaskInput(arg, channel_idx)
  462. else:
  463. task_input = _ExecutableTaskInput(arg, None)
  464. self.task_inputs.append(task_input)
  465. # Currently DAGs do not support binding kwargs to other DAG nodes.
  466. for val in self.resolved_kwargs.values():
  467. assert not isinstance(val, ChannelInterface)
  468. # Input reader to read input data from upstream DAG nodes.
  469. self.input_reader: ReaderInterface = SynchronousReader(self.input_channels)
  470. # Output writer to write output data to downstream DAG nodes.
  471. self.output_writer: WriterInterface = SynchronousWriter(
  472. self.output_channels, self.output_idxs
  473. )
  474. # The intermediate future for a READ or COMPUTE operation,
  475. # and `wait()` must be called to get the actual result of the operation.
  476. # The result of a READ operation will be used by a COMPUTE operation,
  477. # and the result of a COMPUTE operation will be used by a WRITE operation.
  478. self._intermediate_future: Optional[DAGOperationFuture] = None
  479. def cancel(self):
  480. """
  481. Close all the input channels and the output channel. The exact behavior
  482. depends on the type of channel. Typically, it will release the resources
  483. used by the channels.
  484. """
  485. self.input_reader.close()
  486. self.output_writer.close()
  487. def destroy_cuda_event(self):
  488. """
  489. If this executable task has created a GPU future that is not yet waited on,
  490. that future is in the channel context cache. Remove the future from the cache
  491. and destroy its CUDA event.
  492. """
  493. GPUFuture.remove_gpu_future(self.task_idx)
  494. def prepare(self, overlap_gpu_communication: bool = False):
  495. """
  496. Prepare the task for execution. The `exec_operation` function can only
  497. be called after `prepare` has been called.
  498. Args:
  499. overlap_gpu_communication: Whether to overlap GPU communication with
  500. computation during DAG execution to improve performance
  501. """
  502. for typ_hint in self.input_type_hints:
  503. typ_hint.register_custom_serializer()
  504. self.output_type_hint.register_custom_serializer()
  505. self.input_reader.start()
  506. self.output_writer.start()
  507. # Stream context type are different between different accelerators.
  508. # Type hint is not applicable here.
  509. self._send_stream = nullcontext()
  510. self._recv_stream = nullcontext()
  511. if not overlap_gpu_communication:
  512. return
  513. # Set up send_stream and recv_stream when overlap_gpu_communication
  514. # is configured
  515. if self.output_type_hint.requires_accelerator():
  516. comm_group_id = _get_comm_group_id(self.output_type_hint)
  517. comm_group = ChannelContext.get_current().communicators.get(comm_group_id)
  518. assert comm_group is not None
  519. self._send_stream = comm_group.send_stream
  520. if self.input_type_hints:
  521. for type_hint in self.input_type_hints:
  522. if type_hint.requires_accelerator():
  523. comm_group_id = _get_comm_group_id(type_hint)
  524. comm_group = ChannelContext.get_current().communicators.get(
  525. comm_group_id
  526. )
  527. assert comm_group is not None
  528. if not isinstance(self._recv_stream, nullcontext):
  529. assert self._recv_stream == comm_group.recv_stream, (
  530. "Currently all torch tensor input channels of a "
  531. "Compiled Graph task should use the same recv cuda stream."
  532. )
  533. self._recv_stream = comm_group.recv_stream
  534. def wrap_and_set_intermediate_future(
  535. self, val: Any, wrap_in_gpu_future: bool
  536. ) -> None:
  537. """
  538. Wrap the value in a `DAGOperationFuture` and store to the intermediate future.
  539. The value corresponds to result of a READ or COMPUTE operation.
  540. If wrap_in_gpu_future is True, the value will be wrapped in a GPUFuture,
  541. Otherwise, the future will be a ResolvedFuture.
  542. Args:
  543. val: The value to wrap in a future.
  544. wrap_in_gpu_future: Whether to wrap the value in a GPUFuture.
  545. """
  546. assert self._intermediate_future is None
  547. if wrap_in_gpu_future:
  548. future = GPUFuture(val, self.task_idx)
  549. else:
  550. future = ResolvedFuture(val)
  551. self._intermediate_future = future
  552. def reset_and_wait_intermediate_future(self) -> Any:
  553. """
  554. Reset the intermediate future and wait for the result.
  555. The wait does not block the CPU because:
  556. - If the future is a ResolvedFuture, the result is immediately returned.
  557. - If the future is a GPUFuture, the result is only waited by the current
  558. CUDA stream, and the CPU is not blocked.
  559. Returns:
  560. The result of a READ or COMPUTE operation from the intermediate future.
  561. """
  562. future = self._intermediate_future
  563. self._intermediate_future = None
  564. return future.wait()
  565. def _read(self, overlap_gpu_communication: bool) -> bool:
  566. """
  567. Read input data from upstream DAG nodes and cache the intermediate result.
  568. Args:
  569. overlap_gpu_communication: Whether to overlap GPU communication with
  570. computation during DAG execution to improve performance.
  571. Returns:
  572. True if system error occurs and exit the loop; otherwise, False.
  573. """
  574. assert self._intermediate_future is None
  575. exit = False
  576. try:
  577. input_data = self.input_reader.read()
  578. # When overlap_gpu_communication is enabled, wrap the result in
  579. # a GPUFuture so that this read operation (communication) can
  580. # be overlapped with computation.
  581. self.wrap_and_set_intermediate_future(
  582. input_data,
  583. wrap_in_gpu_future=overlap_gpu_communication,
  584. )
  585. except RayChannelError:
  586. # Channel closed. Exit the loop.
  587. exit = True
  588. return exit
  589. def _compute(
  590. self,
  591. overlap_gpu_communication: bool,
  592. class_handle,
  593. ) -> bool:
  594. """
  595. Retrieve the intermediate result from the READ operation and perform the
  596. computation. Then, cache the new intermediate result. The caller must ensure
  597. that the last operation executed is READ so that the function retrieves the
  598. correct intermediate result.
  599. Args:
  600. overlap_gpu_communication: Whether to overlap GPU communication with
  601. computation during DAG execution to improve performance.
  602. class_handle: An instance of the class to which the actor belongs. For
  603. example, the type of `class_handle` is <class 'xxxx.Worker'> if the
  604. actor belongs to the `class Worker` class.
  605. Returns:
  606. True if system error occurs and exit the loop; otherwise, False.
  607. """
  608. input_data = self.reset_and_wait_intermediate_future()
  609. try:
  610. _process_return_vals(input_data, return_single_output=False)
  611. except Exception as exc:
  612. # Previous task raised an application-level exception.
  613. # Propagate it and skip the actual task. We don't need to wrap the
  614. # exception in a RayTaskError here because it has already been wrapped
  615. # by the previous task.
  616. self.wrap_and_set_intermediate_future(
  617. exc, wrap_in_gpu_future=overlap_gpu_communication
  618. )
  619. return False
  620. resolved_inputs = []
  621. for task_input in self.task_inputs:
  622. resolved_inputs.append(task_input.resolve(input_data))
  623. if self.collective_op is not None:
  624. # Run an accelerator collective operation.
  625. method = self.collective_op.execute
  626. else:
  627. # Run an actor method.
  628. method = getattr(class_handle, self.method_name)
  629. try:
  630. output_val = method(*resolved_inputs, **self.resolved_kwargs)
  631. except Exception as exc:
  632. output_val = _wrap_exception(exc)
  633. # When overlap_gpu_communication is enabled, wrap the result in a GPUFuture
  634. # so that this compute operation can be overlapped with communication.
  635. self.wrap_and_set_intermediate_future(
  636. output_val, wrap_in_gpu_future=overlap_gpu_communication
  637. )
  638. return False
  639. def _write(self) -> bool:
  640. """
  641. Retrieve the intermediate result from the COMPUTE operation and write to its
  642. downstream DAG nodes. The caller must ensure that the last operation executed
  643. is COMPUTE so that the function retrieves the correct intermediate result.
  644. Returns:
  645. True if system error occurs and exit the loop; otherwise, False.
  646. """
  647. output_val = self.reset_and_wait_intermediate_future()
  648. exit = False
  649. try:
  650. self.output_writer.write(output_val)
  651. except RayChannelError:
  652. # Channel closed. Exit the loop.
  653. exit = True
  654. return exit
  655. def exec_operation(
  656. self,
  657. class_handle,
  658. op_type: _DAGNodeOperationType,
  659. overlap_gpu_communication: bool = False,
  660. ) -> bool:
  661. """
  662. An ExecutableTask corresponds to a DAGNode. It consists of three
  663. operations: READ, COMPUTE, and WRITE, which should be executed in
  664. order to ensure that each operation can read the correct intermediate
  665. result.
  666. Args:
  667. class_handle: The handle of the class to which the actor belongs.
  668. op_type: The type of the operation. Possible types are READ,
  669. COMPUTE, and WRITE.
  670. overlap_gpu_communication: Whether to overlap GPU communication with
  671. computation during DAG execution to improve performance.
  672. Returns:
  673. True if the next operation should not be executed; otherwise, False.
  674. """
  675. if op_type == _DAGNodeOperationType.READ:
  676. with _device_context_manager():
  677. with self._recv_stream:
  678. return self._read(overlap_gpu_communication)
  679. elif op_type == _DAGNodeOperationType.COMPUTE:
  680. return self._compute(overlap_gpu_communication, class_handle)
  681. elif op_type == _DAGNodeOperationType.WRITE:
  682. with _device_context_manager():
  683. with self._send_stream:
  684. return self._write()
  685. @dataclass
  686. class _ExecutableTaskRecord:
  687. actor_classname: str
  688. actor_name: str
  689. actor_id: str
  690. method_name: str
  691. bind_index: int
  692. operation: str
  693. start_t: float
  694. end_t: float
  695. def to_dict(self):
  696. return asdict(self)
  697. @DeveloperAPI
  698. class CompiledDAG:
  699. """Experimental class for accelerated execution.
  700. This class should not be called directly. Instead, create
  701. a ray.dag and call experimental_compile().
  702. See REP https://github.com/ray-project/enhancements/pull/48 for more
  703. information.
  704. """
  705. @ray.remote(num_cpus=0)
  706. class DAGDriverProxyActor:
  707. """
  708. To support the driver as a reader, the output writer needs to be able to invoke
  709. remote functions on the driver. This is necessary so that the output writer can
  710. create a reader ref on the driver node, and later potentially create a larger
  711. reader ref on the driver node if the channel backing store needs to be resized.
  712. However, remote functions cannot be invoked on the driver.
  713. A Compiled Graph creates an actor from this class when the DAG is initialized.
  714. The actor is on the same node as the driver. This class has an empty
  715. implementation, though it serves as a way for the output writer to invoke remote
  716. functions on the driver node.
  717. """
  718. pass
  719. def __init__(
  720. self,
  721. submit_timeout: Optional[float] = None,
  722. buffer_size_bytes: Optional[int] = None,
  723. enable_asyncio: bool = False,
  724. max_inflight_executions: Optional[int] = None,
  725. max_buffered_results: Optional[int] = None,
  726. overlap_gpu_communication: Optional[bool] = None,
  727. default_communicator: Optional[Union[Communicator, str]] = "create",
  728. ):
  729. """
  730. Args:
  731. submit_timeout: The maximum time in seconds to wait for execute() calls.
  732. None means using default timeout (DAGContext.submit_timeout),
  733. 0 means immediate timeout (immediate success or timeout without
  734. blocking), -1 means infinite timeout (block indefinitely).
  735. buffer_size_bytes: The initial buffer size in bytes for messages
  736. that can be passed between tasks in the DAG. The buffers will
  737. be automatically resized if larger messages are written to the
  738. channel.
  739. enable_asyncio: Whether to enable asyncio. If enabled, caller must
  740. be running in an event loop and must use `execute_async` to
  741. invoke the DAG. Otherwise, the caller should use `execute` to
  742. invoke the DAG.
  743. max_inflight_executions: The maximum number of in-flight executions that
  744. can be submitted via `execute` or `execute_async` before consuming
  745. the output using `ray.get()`. If the caller submits more executions,
  746. `RayCgraphCapacityExceeded` is raised.
  747. max_buffered_results: The maximum number of results that can be
  748. buffered at the driver. If more results are buffered,
  749. `RayCgraphCapacityExceeded` is raised. Note that
  750. when result corresponding to an execution is retrieved
  751. (by calling `ray.get()` on a `CompiledDAGRef` or
  752. `CompiledDAGRef` or await on a `CompiledDAGFuture), results
  753. corresponding to earlier executions that have not been retrieved
  754. yet are buffered.
  755. overlap_gpu_communication: (experimental) Whether to overlap GPU
  756. communication with computation during DAG execution. If True, the
  757. communication and computation can be overlapped, which can improve
  758. the performance of the DAG execution. If None, the default value
  759. will be used.
  760. _default_communicator: The default communicator to use to transfer
  761. tensors. Three types of values are valid. (1) Communicator:
  762. For p2p operations, this is the default communicator
  763. to use for nodes annotated with `with_tensor_transport()` and when
  764. shared memory is not the desired option (e.g., when transport="accelerator",
  765. or when transport="auto" for communication between two different GPUs).
  766. For collective operations, this is the default communicator to use
  767. when a custom communicator is not specified.
  768. (2) "create": for each collective operation without a custom communicator
  769. specified, a communicator is created and initialized on its involved actors,
  770. or an already created communicator is reused if the set of actors is the same.
  771. For all p2p operations without a custom communicator specified, it reuses
  772. an already created collective communicator if the p2p actors are a subset.
  773. Otherwise, a new communicator is created.
  774. (3) None: a ValueError will be thrown if a custom communicator is not specified.
  775. Returns:
  776. Channel: A wrapper around ray.ObjectRef.
  777. """
  778. from ray.dag import DAGContext
  779. ctx = DAGContext.get_current()
  780. self._enable_asyncio: bool = enable_asyncio
  781. self._fut_queue = asyncio.Queue()
  782. self._max_inflight_executions = max_inflight_executions
  783. if self._max_inflight_executions is None:
  784. self._max_inflight_executions = ctx.max_inflight_executions
  785. self._max_buffered_results = max_buffered_results
  786. if self._max_buffered_results is None:
  787. self._max_buffered_results = ctx.max_buffered_results
  788. self._dag_id = uuid.uuid4().hex
  789. self._submit_timeout: Optional[float] = submit_timeout
  790. if self._submit_timeout is None:
  791. self._submit_timeout = ctx.submit_timeout
  792. self._get_timeout: Optional[float] = ctx.get_timeout
  793. self._buffer_size_bytes: Optional[int] = buffer_size_bytes
  794. if self._buffer_size_bytes is None:
  795. self._buffer_size_bytes = ctx.buffer_size_bytes
  796. self._overlap_gpu_communication: Optional[bool] = overlap_gpu_communication
  797. if self._overlap_gpu_communication is None:
  798. self._overlap_gpu_communication = ctx.overlap_gpu_communication
  799. self._create_default_communicator = False
  800. if isinstance(default_communicator, str):
  801. if default_communicator == "create":
  802. self._create_default_communicator = True
  803. default_communicator = None
  804. else:
  805. raise ValueError(
  806. "The only allowed string for default_communicator is 'create', "
  807. f"got {default_communicator}"
  808. )
  809. elif default_communicator is not None and not isinstance(
  810. default_communicator, Communicator
  811. ):
  812. raise ValueError(
  813. "The default_communicator must be None, a string, or a Communicator, "
  814. f"got {type(default_communicator)}"
  815. )
  816. self._default_communicator: Optional[Communicator] = default_communicator
  817. # Dict from passed-in communicator to set of type hints that refer to it.
  818. self._communicator_to_type_hints: Dict[
  819. Communicator,
  820. Set["ray.experimental.channel.torch_tensor_type.TorchTensorType"],
  821. ] = defaultdict(set)
  822. # Dict from set of actors to created communicator ID.
  823. # These communicators are created by Compiled Graph, rather than passed in.
  824. # Communicators are only created when self._create_default_communicator is True.
  825. self._actors_to_created_communicator_id: Dict[
  826. Tuple["ray.actor.ActorHandle"], str
  827. ] = {}
  828. # Set of actors involved in P2P communication using an unresolved communicator.
  829. self._p2p_actors_with_unresolved_communicators: Set[
  830. "ray.actor.ActorHandle"
  831. ] = set()
  832. # Set of DAG nodes involved in P2P communication using an unresolved communicator.
  833. self._p2p_dag_nodes_with_unresolved_communicators: Set[
  834. "ray.dag.DAGNode"
  835. ] = set()
  836. # Set of collective operations using an unresolved communicator.
  837. self._collective_ops_with_unresolved_communicators: Set[
  838. "ray.dag.collective_node._CollectiveOperation"
  839. ] = set()
  840. self._default_type_hint: ChannelOutputType = SharedMemoryType(
  841. buffer_size_bytes=self._buffer_size_bytes,
  842. # We conservatively set num_shm_buffers to _max_inflight_executions.
  843. # It means that the DAG can be underutilized, but it guarantees there's
  844. # no false positive timeouts.
  845. num_shm_buffers=self._max_inflight_executions,
  846. )
  847. if not isinstance(self._buffer_size_bytes, int) or self._buffer_size_bytes <= 0:
  848. raise ValueError(
  849. "`buffer_size_bytes` must be a positive integer, found "
  850. f"{self._buffer_size_bytes}"
  851. )
  852. # Used to ensure that the future returned to the
  853. # caller corresponds to the correct DAG output. I.e.
  854. # order of futures added to fut_queue should match the
  855. # order of inputs written to the DAG.
  856. self._dag_submission_lock = asyncio.Lock()
  857. # idx -> CompiledTask.
  858. self.idx_to_task: Dict[int, "CompiledTask"] = {}
  859. # DAGNode -> idx.
  860. self.dag_node_to_idx: Dict["ray.dag.DAGNode", int] = {}
  861. # idx counter.
  862. self.counter: int = 0
  863. # Attributes that are set during preprocessing.
  864. # Preprocessing identifies the input node and output node.
  865. self.input_task_idx: Optional[int] = None
  866. self.output_task_idx: Optional[int] = None
  867. # List of task indices that are input attribute nodes.
  868. self.input_attr_task_idxs: List[int] = []
  869. # Denotes whether execute/execute_async returns a list of refs/futures.
  870. self._returns_list: bool = False
  871. # Number of expected positional args and kwargs that may be passed to
  872. # dag.execute.
  873. self._input_num_positional_args: Optional[int] = None
  874. self._input_kwargs: Tuple[str, ...] = None
  875. # Cached attributes that are set during compilation.
  876. self.dag_input_channels: Optional[List[ChannelInterface]] = None
  877. self.dag_output_channels: Optional[List[ChannelInterface]] = None
  878. self._dag_submitter: Optional[WriterInterface] = None
  879. self._dag_output_fetcher: Optional[ReaderInterface] = None
  880. # ObjectRef for each worker's task. The task is an infinite loop that
  881. # repeatedly executes the method specified in the DAG.
  882. self.worker_task_refs: Dict["ray.actor.ActorHandle", "ray.ObjectRef"] = {}
  883. self.actor_to_tasks: Dict[
  884. "ray.actor.ActorHandle", List["CompiledTask"]
  885. ] = defaultdict(list)
  886. # Mapping from actor handle to its GPU IDs.
  887. # This is used for type hint resolution for with_tensor_transport("auto").
  888. self.actor_to_gpu_ids: Dict["ray.actor.ActorHandle", List[str]] = {}
  889. self.actor_to_executable_tasks: Dict[
  890. "ray.actor.ActorHandle", List["ExecutableTask"]
  891. ] = {}
  892. # Mapping from the actor handle to the execution schedule which is a list
  893. # of operations to be executed.
  894. self.actor_to_execution_schedule: Dict[
  895. "ray.actor.ActorHandle", List[_DAGNodeOperation]
  896. ] = defaultdict(list)
  897. # Mapping from the actor handle to the node ID that the actor is on.
  898. # A None actor handle means the actor is the driver.
  899. self.actor_to_node_id: Dict[Optional["ray.actor.ActorHandle"], str] = {}
  900. # The index of the current execution. It is incremented each time
  901. # the DAG is executed.
  902. self._execution_index: int = -1
  903. # The maximum index of finished executions.
  904. # All results with higher indexes have not been generated yet.
  905. self._max_finished_execution_index: int = -1
  906. # execution_index -> {channel_index -> result}
  907. self._result_buffer: Dict[int, Dict[int, Any]] = defaultdict(dict)
  908. # channel to possible inner channel
  909. self._channel_dict: Dict[ChannelInterface, ChannelInterface] = {}
  910. def _create_proxy_actor() -> "ray.actor.ActorHandle":
  911. # Creates the driver actor on the same node as the driver.
  912. #
  913. # To support the driver as a reader, the output writer needs to be able to
  914. # invoke remote functions on the driver (e.g., to create the reader ref, to
  915. # create a reader ref for a larger object when the channel backing store is
  916. # resized, etc.). The driver actor serves as a way for the output writer
  917. # to invoke remote functions on the driver node.
  918. return CompiledDAG.DAGDriverProxyActor.options(
  919. scheduling_strategy=NodeAffinitySchedulingStrategy(
  920. ray.get_runtime_context().get_node_id(), soft=False
  921. )
  922. ).remote()
  923. self._proxy_actor = _create_proxy_actor()
  924. # Set to True when `teardown` API is called.
  925. self._is_teardown = False
  926. # Execution index to set of channel indices for CompiledDAGRefs
  927. # or CompiledDAGFuture whose destructor has been called. A "None"
  928. # channel index means there is only one channel, and its destructor
  929. # has been called.
  930. self._destructed_ref_idxs: Dict[int, Set[Optional[int]]] = dict()
  931. # Execution index to set of channel indices for CompiledDAGRefs
  932. # or CompiledDAGFuture whose get() has been called. A "None"
  933. # channel index means there is only one channel, and its get()
  934. # has been called.
  935. self._got_ref_idxs: Dict[int, Set[Optional[int]]] = dict()
  936. @property
  937. def is_teardown(self) -> bool:
  938. return self._is_teardown
  939. def get_id(self) -> str:
  940. """
  941. Get the unique ID of the compiled DAG.
  942. """
  943. return self._dag_id
  944. def __str__(self) -> str:
  945. return f"CompiledDAG({self._dag_id})"
  946. def _add_node(self, node: "ray.dag.DAGNode") -> None:
  947. idx = self.counter
  948. self.idx_to_task[idx] = CompiledTask(idx, node)
  949. self.dag_node_to_idx[node] = idx
  950. self.counter += 1
def _preprocess(self) -> None:
    """Before compiling, preprocess the DAG to build an index from task to
    upstream and downstream tasks, and to set the input and output node(s)
    of the DAG.

    Steps:
    - locate the single InputNode and all InputAttributeNodes;
    - locate the single output node, wrapping it in a MultiOutputNode if
      needed;
    - validate node types;
    - link each task to its upstream/downstream tasks;
    - record which InputNode keys are used;
    - track communicator usage, resolve "auto" transports and initialize
      communicators.

    Raises:
        NotImplementedError: If there is no InputNode, or a FunctionNode
            (non-actor task) is present.
        ValueError: For unsupported node types, unbound actors, accelerator
            transport on DAG inputs, DAG nodes passed as kwargs, invalid
            InputNode keys, or mixing direct and indexed InputNode use.

    NOTE(review): an earlier version of this docstring claimed the function
    is idempotent. That is not evident from the code:
    - each call appends to `self.input_attr_task_idxs` and to every task's
      `arg_nodes`;
    - each call reruns `self._init_communicators()`;
    - after the first call adds a MultiOutputNode, a later call likely finds
      that node as the output and sets `self._returns_list = True`.
    Verify before calling this more than once.
    """
    from ray.dag import (
        ClassMethodNode,
        CollectiveOutputNode,
        DAGNode,
        FunctionNode,
        InputAttributeNode,
        InputNode,
        MultiOutputNode,
    )

    self.input_task_idx, self.output_task_idx = None, None

    input_attributes: Set[str] = set()
    # Find the input node and input attribute nodes in the DAG.
    for idx, task in self.idx_to_task.items():
        if isinstance(task.dag_node, InputNode):
            assert self.input_task_idx is None, "More than one InputNode found"
            self.input_task_idx = idx
            # handle_unused_attributes:
            # Save input attributes in a set.
            input_node = task.dag_node
            input_attributes.update(input_node.input_attribute_nodes.keys())
        elif isinstance(task.dag_node, InputAttributeNode):
            self.input_attr_task_idxs.append(idx)

    # Find the (multi-)output node to the DAG.
    for idx, task in self.idx_to_task.items():
        if idx == self.input_task_idx or isinstance(
            task.dag_node, InputAttributeNode
        ):
            continue
        # The output is the (unique) node with no downstream tasks that is
        # flagged as a compiled-graph output node.
        if (
            len(task.downstream_task_idxs) == 0
            and task.dag_node.is_cgraph_output_node
        ):
            assert self.output_task_idx is None, "More than one output node found"
            self.output_task_idx = idx

    assert self.output_task_idx is not None
    output_node = self.idx_to_task[self.output_task_idx].dag_node
    # Add an MultiOutputNode to the end of the DAG if it's not already there.
    if not isinstance(output_node, MultiOutputNode):
        output_node = MultiOutputNode([output_node])
        self._add_node(output_node)
        self.output_task_idx = self.dag_node_to_idx[output_node]
    else:
        # The user built the MultiOutputNode, so execute() returns a list.
        self._returns_list = True

    # TODO: Support no-input DAGs (use an empty object to signal).
    if self.input_task_idx is None:
        raise NotImplementedError(
            "Compiled DAGs currently require exactly one InputNode"
        )

    # Whether the DAG binds directly to the InputNode(), versus binding to
    # a positional arg or kwarg of the input. For example, a.foo.bind(inp)
    # instead of a.foo.bind(inp[0]) or a.foo.bind(inp.key).
    direct_input: Optional[bool] = None
    # Collect the set of InputNode keys bound to DAG node args.
    input_positional_args: Set[int] = set()
    input_kwargs: Set[str] = set()
    # Set of tasks with annotation of with_tensor_transport("auto").
    # These only correspond to ClassMethodNodes, but not InputNodes
    # or InputAttributeNodes.
    auto_transport_tasks: Set["CompiledTask"] = set()

    # For each task node, set its upstream and downstream task nodes.
    # Also collect the set of tasks that produce torch.tensors.
    for task_idx, task in self.idx_to_task.items():
        dag_node = task.dag_node
        # Only input, input-attribute, multi-output and class-method nodes
        # are supported.
        if not (
            isinstance(dag_node, InputNode)
            or isinstance(dag_node, InputAttributeNode)
            or isinstance(dag_node, MultiOutputNode)
            or isinstance(dag_node, ClassMethodNode)
        ):
            if isinstance(dag_node, FunctionNode):
                # TODO(swang): Support non-actor tasks.
                raise NotImplementedError(
                    "Compiled DAGs currently only support actor method nodes"
                )
            else:
                raise ValueError(f"Found unsupported node of type {type(dag_node)}")

        if isinstance(dag_node, ClassMethodNode) and dag_node.is_class_method_call:
            actor_handle = dag_node._get_actor_handle()
            if actor_handle is None:
                raise ValueError(
                    "Compiled DAGs can only bind methods to an actor "
                    "that is already created with Actor.remote()"
                )

            # Cache each actor's GPU IDs once; used later to resolve
            # with_tensor_transport("auto").
            if actor_handle not in self.actor_to_gpu_ids:
                self.actor_to_gpu_ids[actor_handle] = CompiledDAG._get_gpu_ids(
                    actor_handle
                )

            if isinstance(dag_node.type_hint, AutoTransportType):
                auto_transport_tasks.add(task)

            # Collect actors for accelerator P2P methods.
            if dag_node.type_hint.requires_accelerator():
                self._track_communicator_usage(dag_node, {actor_handle})
            # Collect accelerator collective operations.
            if isinstance(dag_node, CollectiveOutputNode):
                self._track_communicator_usage(
                    dag_node,
                    set(dag_node._collective_op.actor_handles),
                    collective_op=True,
                )
                assert not self._overlap_gpu_communication, (
                    "Currently, the overlap_gpu_communication option is not "
                    "supported for accelerator collective operations. Please set "
                    "overlap_gpu_communication=False."
                )
        elif isinstance(dag_node, InputNode) or isinstance(
            dag_node, InputAttributeNode
        ):
            if dag_node.type_hint.requires_accelerator():
                raise ValueError(
                    "DAG inputs cannot be transferred via accelerator because "
                    "the driver cannot participate in the communicator group"
                )
            if isinstance(dag_node.type_hint, AutoTransportType):
                # Currently driver on GPU is not supported, so we always
                # use shared memory to transfer tensors.
                dag_node.type_hint = TorchTensorType(
                    device=dag_node.type_hint.device
                )

        if type(dag_node.type_hint) is ChannelOutputType:
            # No type hint specified by the user. Replace
            # with the default type hint for this DAG.
            dag_node.type_hint = self._default_type_hint

        for _, val in task.kwargs.items():
            if isinstance(val, DAGNode):
                raise ValueError(
                    "Compiled DAG currently does not support binding to "
                    "other DAG nodes as kwargs"
                )

        for _, arg in enumerate(task.args):
            # Only DAG-node args create edges; plain values are constants.
            if not isinstance(arg, DAGNode):
                continue

            upstream_node_idx = self.dag_node_to_idx[arg]
            upstream_task = self.idx_to_task[upstream_node_idx]
            # Stays None unless this task is an actor method call
            # (e.g. for a MultiOutputNode).
            downstream_actor_handle = None
            if (
                isinstance(dag_node, ClassMethodNode)
                and dag_node.is_class_method_call
            ):
                downstream_actor_handle = dag_node._get_actor_handle()

            # Add upstream node as the argument nodes of this task, whose
            # type hints may be updated when resolved lazily.
            task.arg_nodes.append(upstream_task.dag_node)

            if isinstance(upstream_task.dag_node, InputAttributeNode):
                # Record all of the keys used to index the InputNode.
                # During execution, we will check that the user provides
                # the same args and kwargs.
                if isinstance(upstream_task.dag_node.key, int):
                    input_positional_args.add(upstream_task.dag_node.key)
                elif isinstance(upstream_task.dag_node.key, str):
                    input_kwargs.add(upstream_task.dag_node.key)
                else:
                    raise ValueError(
                        "InputNode() can only be indexed using int "
                        "for positional args or str for kwargs."
                    )

                if direct_input is not None and direct_input:
                    raise ValueError(
                        "All tasks must either use InputNode() "
                        "directly, or they must index to specific args or "
                        "kwargs."
                    )
                direct_input = False

                # If the upstream node is an InputAttributeNode, treat the
                # DAG's input node as the actual upstream node
                upstream_task = self.idx_to_task[self.input_task_idx]

            elif isinstance(upstream_task.dag_node, InputNode):
                if direct_input is not None and not direct_input:
                    raise ValueError(
                        "All tasks must either use InputNode() directly, "
                        "or they must index to specific args or kwargs."
                    )
                direct_input = True

            # Record the edge upstream -> this task, keyed by task index and
            # valued by this task's actor handle (None if not an actor call).
            upstream_task.downstream_task_idxs[task_idx] = downstream_actor_handle

            if upstream_task.dag_node.type_hint.requires_accelerator():
                # Here we are processing the args of the DAGNode, so track
                # downstream actors only, upstream actor is already tracked
                # when processing the DAGNode itself.
                self._track_communicator_usage(
                    upstream_task.dag_node,
                    {downstream_actor_handle},
                )

    # Check that all specified input attributes, e.g., InputNode()["x"],
    # are used in the DAG.
    _check_unused_dag_input_attributes(output_node, input_attributes)

    self._check_leaf_nodes()

    self._resolve_auto_transport(auto_transport_tasks)

    self._init_communicators()

    # Binding InputNode() directly counts as one positional arg. Otherwise
    # the count is one past the highest positional index used, so unused
    # lower indices still count.
    if direct_input:
        self._input_num_positional_args = 1
    elif not input_positional_args:
        self._input_num_positional_args = 0
    else:
        self._input_num_positional_args = max(input_positional_args) + 1
    self._input_kwargs = tuple(input_kwargs)
  1142. self._resolve_auto_transport(auto_transport_tasks)
  1143. self._init_communicators()
  1144. if direct_input:
  1145. self._input_num_positional_args = 1
  1146. elif not input_positional_args:
  1147. self._input_num_positional_args = 0
  1148. else:
  1149. self._input_num_positional_args = max(input_positional_args) + 1
  1150. self._input_kwargs = tuple(input_kwargs)
  1151. def _init_communicators(self) -> None:
  1152. """
  1153. Initialize communicators for the DAG.
  1154. """
  1155. # First, initialize communicators that are passed in by the user.
  1156. for communicator, type_hints in self._communicator_to_type_hints.items():
  1157. communicator_id = _init_communicator(
  1158. communicator.get_actor_handles(),
  1159. communicator,
  1160. self._overlap_gpu_communication,
  1161. )
  1162. for type_hint in type_hints:
  1163. type_hint.set_communicator_id(communicator_id)
  1164. # Second, get registered accelerator context if any.
  1165. accelerator_module_name = AcceleratorContext.get().module_name
  1166. accelerator_communicator_cls = AcceleratorContext.get().communicator_cls
  1167. # Then, create communicators for collective operations.
  1168. # Reuse an already created communicator for the same set of actors.
  1169. for collective_op in self._collective_ops_with_unresolved_communicators:
  1170. if not self._create_default_communicator:
  1171. raise ValueError(
  1172. "Communicator creation is not allowed for collective operations."
  1173. )
  1174. # using tuple to preserve the order of actors for collective operations
  1175. actors = tuple(collective_op.actor_handles)
  1176. if actors in self._actors_to_created_communicator_id:
  1177. communicator_id = self._actors_to_created_communicator_id[actors]
  1178. else:
  1179. communicator_id = _init_communicator(
  1180. list(actors),
  1181. None,
  1182. self._overlap_gpu_communication,
  1183. accelerator_module_name,
  1184. accelerator_communicator_cls,
  1185. )
  1186. self._actors_to_created_communicator_id[actors] = communicator_id
  1187. collective_op.type_hint.set_communicator_id(communicator_id)
  1188. # Finally, create a communicator for P2P operations.
  1189. # Reuse an already created collective op communicator when p2p actors
  1190. # are a subset of the actors in the collective op communicator.
  1191. p2p_communicator_id = None
  1192. if self._p2p_actors_with_unresolved_communicators:
  1193. for (
  1194. actors,
  1195. communicator_id,
  1196. ) in self._actors_to_created_communicator_id.items():
  1197. if self._p2p_actors_with_unresolved_communicators.issubset(actors):
  1198. p2p_communicator_id = communicator_id
  1199. break
  1200. if p2p_communicator_id is None:
  1201. p2p_communicator_id = _init_communicator(
  1202. list(self._p2p_actors_with_unresolved_communicators),
  1203. None,
  1204. self._overlap_gpu_communication,
  1205. accelerator_module_name,
  1206. accelerator_communicator_cls,
  1207. )
  1208. for dag_node in self._p2p_dag_nodes_with_unresolved_communicators:
  1209. dag_node.type_hint.set_communicator_id(p2p_communicator_id)
  1210. def _track_communicator_usage(
  1211. self,
  1212. dag_node: "ray.dag.DAGNode",
  1213. actors: Set["ray.actor.ActorHandle"],
  1214. collective_op: bool = False,
  1215. ) -> None:
  1216. """
  1217. Track the usage of a communicator.
  1218. This method first determines the communicator to use: if a custom
  1219. communicator is specified, use it; if not and a default communicator
  1220. is available, use it; otherwise, it records necessary information to
  1221. create a new communicator later.
  1222. This method also performs validation checks on the passed-in communicator.
  1223. Args:
  1224. dag_node: The DAG node that uses the communicator, this is the node
  1225. that has the `with_tensor_transport()` type hint for p2p communication,
  1226. or a `CollectiveOutputNode` for collective operations.
  1227. actors: The full or partial set of actors that use the communicator.
  1228. This method should be called one or multiple times so that all actors
  1229. of the communicator are tracked.
  1230. collective_op: Whether the communicator is used for a collective operation.
  1231. """
  1232. if None in actors:
  1233. raise ValueError("Driver cannot participate in the communicator group.")
  1234. if collective_op:
  1235. type_hint = dag_node._collective_op.type_hint
  1236. else:
  1237. type_hint = dag_node.type_hint
  1238. communicator = type_hint.get_custom_communicator()
  1239. if communicator is None:
  1240. if (
  1241. self._default_communicator is None
  1242. and not self._create_default_communicator
  1243. ):
  1244. if dag_node._original_type_hint is not None:
  1245. assert isinstance(dag_node._original_type_hint, AutoTransportType)
  1246. raise ValueError(
  1247. f"with_tensor_transport(transport='auto') is used for DAGNode {dag_node}, "
  1248. "This requires specifying a default communicator or 'create' for "
  1249. "_default_communicator when calling experimental_compile()."
  1250. )
  1251. raise ValueError(
  1252. f"DAGNode {dag_node} has no custom communicator specified. "
  1253. "Please specify a custom communicator for the DAGNode using "
  1254. "`with_tensor_transport()`, or specify a communicator or 'create' for "
  1255. "_default_communicator when calling experimental_compile()."
  1256. )
  1257. communicator = self._default_communicator
  1258. if communicator is None:
  1259. if collective_op:
  1260. self._collective_ops_with_unresolved_communicators.add(
  1261. dag_node._collective_op
  1262. )
  1263. else:
  1264. self._p2p_dag_nodes_with_unresolved_communicators.add(dag_node)
  1265. self._p2p_actors_with_unresolved_communicators.update(actors)
  1266. else:
  1267. if collective_op:
  1268. if set(communicator.get_actor_handles()) != actors:
  1269. raise ValueError(
  1270. "The passed-in communicator must have the same set "
  1271. "of actors as the collective operation. "
  1272. f"The passed-in communicator has actors {communicator.get_actor_handles()} "
  1273. f"while the collective operation has actors {actors}."
  1274. )
  1275. else:
  1276. if not actors.issubset(set(communicator.get_actor_handles())):
  1277. raise ValueError(
  1278. "The passed-in communicator must include all of the actors "
  1279. "used in the P2P operation. "
  1280. f"The passed-in communicator has actors {communicator.get_actor_handles()} "
  1281. f"while the P2P operation has actors {actors}."
  1282. )
  1283. self._communicator_to_type_hints[communicator].add(type_hint)
  1284. def _resolve_auto_transport(
  1285. self,
  1286. auto_transport_tasks: Set["CompiledTask"],
  1287. ) -> None:
  1288. """
  1289. Resolve the auto transport type hint for the DAG.
  1290. """
  1291. type_hint_resolver = TypeHintResolver(self.actor_to_gpu_ids)
  1292. # Resolve AutoChannelType type hints and track the actors that use accelerator.
  1293. # This is needed so that the communicator group can be initialized for
  1294. # these actors that use accelerator.
  1295. for task in auto_transport_tasks:
  1296. writer = task.dag_node._get_actor_handle()
  1297. readers = task.downstream_task_idxs.values()
  1298. writer_and_node = (writer, self._get_node_id(writer))
  1299. reader_and_node_list = [
  1300. (reader, self._get_node_id(reader)) for reader in readers
  1301. ]
  1302. # Update the type hint to the resolved one. This is needed because
  1303. # the resolved type hint's `register_custom_serializer` will be called
  1304. # in preparation for channel I/O.
  1305. task.dag_node.type_hint = type_hint_resolver.resolve(
  1306. task.dag_node.type_hint,
  1307. writer_and_node,
  1308. reader_and_node_list,
  1309. )
  1310. if task.dag_node.type_hint.requires_accelerator():
  1311. self._track_communicator_usage(
  1312. task.dag_node,
  1313. set(readers).union({writer}),
  1314. )
  1315. def _check_leaf_nodes(self) -> None:
  1316. """
  1317. Check if there are leaf nodes in the DAG and raise an error if there are.
  1318. """
  1319. from ray.dag import (
  1320. ClassMethodNode,
  1321. DAGNode,
  1322. )
  1323. leaf_nodes: List[DAGNode] = []
  1324. for _, task in self.idx_to_task.items():
  1325. if not isinstance(task.dag_node, ClassMethodNode):
  1326. continue
  1327. if (
  1328. len(task.downstream_task_idxs) == 0
  1329. and not task.dag_node.is_cgraph_output_node
  1330. ):
  1331. leaf_nodes.append(task.dag_node)
  1332. # Leaf nodes are not allowed because the exception thrown by the leaf
  1333. # node will not be propagated to the driver.
  1334. if len(leaf_nodes) != 0:
  1335. raise ValueError(
  1336. "Compiled DAG doesn't support leaf nodes, i.e., nodes that don't have "
  1337. "downstream nodes and are not output nodes. There are "
  1338. f"{len(leaf_nodes)} leaf nodes in the DAG. Please add the outputs of "
  1339. f"{[leaf_node.get_method_name() for leaf_node in leaf_nodes]} to the "
  1340. f"the MultiOutputNode."
  1341. )
  1342. @staticmethod
  1343. def _get_gpu_ids(actor_handle: "ray.actor.ActorHandle") -> List[str]:
  1344. """
  1345. Get the GPU IDs of an actor handle.
  1346. """
  1347. accelerator_ids = ray.get(
  1348. actor_handle.__ray_call__.remote(
  1349. lambda self: ray.get_runtime_context().get_accelerator_ids()
  1350. )
  1351. )
  1352. return accelerator_ids.get("GPU", [])
  1353. def _get_node_id(self, actor_handle: Optional["ray.actor.ActorHandle"]) -> str:
  1354. """
  1355. Get the node ID of an actor handle and cache it.
  1356. Args:
  1357. actor_handle: The actor handle, or None if the actor handle is the
  1358. driver.
  1359. Returns:
  1360. The node ID of the actor handle or driver.
  1361. """
  1362. if actor_handle in self.actor_to_node_id:
  1363. return self.actor_to_node_id[actor_handle]
  1364. node_id = None
  1365. if actor_handle == self._proxy_actor or actor_handle is None:
  1366. node_id = ray.get_runtime_context().get_node_id()
  1367. else:
  1368. node_id = ray.get(
  1369. actor_handle.__ray_call__.remote(
  1370. lambda self: ray.get_runtime_context().get_node_id()
  1371. )
  1372. )
  1373. self.actor_to_node_id[actor_handle] = node_id
  1374. return node_id
  1375. def _get_or_compile(
  1376. self,
  1377. ) -> None:
  1378. """Compile an execution path. This allocates channels for adjacent
  1379. tasks to send/receive values. An infinite task is submitted to each
  1380. actor in the DAG that repeatedly receives from input channel(s) and
  1381. sends to output channel(s).
  1382. This function is idempotent and will cache the previously allocated
  1383. channels. After calling this function, _dag_submitter and
  1384. _dag_output_fetcher will be set and can be used to invoke and fetch
  1385. outputs for the DAG.
  1386. """
  1387. from ray.dag import (
  1388. ClassMethodNode,
  1389. DAGNode,
  1390. InputAttributeNode,
  1391. InputNode,
  1392. MultiOutputNode,
  1393. )
  1394. if self.input_task_idx is None:
  1395. self._preprocess()
  1396. assert self.input_task_idx is not None
  1397. if self._dag_submitter is not None:
  1398. assert self._dag_output_fetcher is not None
  1399. return
  1400. frontier = [self.input_task_idx]
  1401. visited = set()
  1402. # Create output buffers. This loop does a breadth-first search through the DAG.
  1403. while frontier:
  1404. cur_idx = frontier.pop(0)
  1405. if cur_idx in visited:
  1406. continue
  1407. visited.add(cur_idx)
  1408. task = self.idx_to_task[cur_idx]
  1409. if (
  1410. isinstance(task.dag_node, ClassMethodNode)
  1411. and task.dag_node.is_class_method_call
  1412. ):
  1413. # Create output buffers for the actor method.
  1414. assert len(task.output_channels) == 0
  1415. # `output_to_readers` stores the reader tasks for each output of
  1416. # the current node. If the current node returns one output, the
  1417. # readers are the downstream nodes of the current node. If the
  1418. # current node returns multiple outputs, the readers of each
  1419. # output are the downstream nodes of the ClassMethodNode that
  1420. # is a class method output.
  1421. output_to_readers: Dict[CompiledTask, List[CompiledTask]] = defaultdict(
  1422. list
  1423. )
  1424. for idx in task.downstream_task_idxs:
  1425. downstream_task = self.idx_to_task[idx]
  1426. downstream_node = downstream_task.dag_node
  1427. if (
  1428. isinstance(downstream_node, ClassMethodNode)
  1429. and downstream_node.is_class_method_output
  1430. ):
  1431. output_to_readers[downstream_task] = [
  1432. self.idx_to_task[idx]
  1433. for idx in downstream_task.downstream_task_idxs
  1434. ]
  1435. else:
  1436. if task not in output_to_readers:
  1437. output_to_readers[task] = []
  1438. output_to_readers[task].append(downstream_task)
  1439. fn = task.dag_node._get_remote_method("__ray_call__")
  1440. for output, readers in output_to_readers.items():
  1441. reader_and_node_list: List[Tuple["ray.actor.ActorHandle", str]] = []
  1442. # Use reader_handles_set to deduplicate readers on the
  1443. # same actor, because with CachedChannel each actor will
  1444. # only read from the upstream channel once.
  1445. reader_handles_set = set()
  1446. read_by_multi_output_node = False
  1447. for reader in readers:
  1448. if isinstance(reader.dag_node, MultiOutputNode):
  1449. read_by_multi_output_node = True
  1450. # inserting at 0 to make sure driver is first reader as
  1451. # expected by CompositeChannel read
  1452. reader_and_node_list.insert(
  1453. 0,
  1454. (
  1455. self._proxy_actor,
  1456. self._get_node_id(self._proxy_actor),
  1457. ),
  1458. )
  1459. else:
  1460. reader_handle = reader.dag_node._get_actor_handle()
  1461. if reader_handle not in reader_handles_set:
  1462. reader_handle = reader.dag_node._get_actor_handle()
  1463. reader_and_node_list.append(
  1464. (reader_handle, self._get_node_id(reader_handle))
  1465. )
  1466. reader_handles_set.add(reader_handle)
  1467. # if driver is an actual actor, gets driver actor id
  1468. driver_actor_id = (
  1469. ray.get_runtime_context().get_actor_id()
  1470. if read_by_multi_output_node
  1471. else None
  1472. )
  1473. # Create an output channel for each output of the current node.
  1474. output_channel = ray.get(
  1475. fn.remote(
  1476. do_allocate_channel,
  1477. reader_and_node_list,
  1478. task.dag_node.type_hint,
  1479. driver_actor_id,
  1480. )
  1481. )
  1482. output_idx = None
  1483. downstream_node = output.dag_node
  1484. if (
  1485. isinstance(downstream_node, ClassMethodNode)
  1486. and downstream_node.is_class_method_output
  1487. ):
  1488. output_idx = downstream_node.output_idx
  1489. task.output_channels.append(output_channel)
  1490. task.output_idxs.append(output_idx)
  1491. task.output_node_idxs.append(self.dag_node_to_idx[downstream_node])
  1492. actor_handle = task.dag_node._get_actor_handle()
  1493. assert actor_handle is not None
  1494. self.actor_to_tasks[actor_handle].append(task)
  1495. elif (
  1496. isinstance(task.dag_node, ClassMethodNode)
  1497. and task.dag_node.is_class_method_output
  1498. ):
  1499. task_node = task.dag_node
  1500. upstream_node = task_node.class_method_call
  1501. assert upstream_node
  1502. upstream_task = self.idx_to_task[self.dag_node_to_idx[upstream_node]]
  1503. for i in range(len(upstream_task.output_channels)):
  1504. if upstream_task.output_idxs[i] == task_node.output_idx:
  1505. task.output_channels.append(upstream_task.output_channels[i])
  1506. task.output_idxs.append(upstream_task.output_idxs[i])
  1507. assert len(task.output_channels) == 1
  1508. elif isinstance(task.dag_node, InputNode):
  1509. # A dictionary that maps an InputNode or InputAttributeNode to its
  1510. # readers and the node on which the reader is running. Use `set` to
  1511. # deduplicate readers on the same actor because with CachedChannel
  1512. # each actor will only read from the shared memory once.
  1513. input_node_to_reader_and_node_set: Dict[
  1514. Union[InputNode, InputAttributeNode],
  1515. Set[Tuple["ray.actor.ActorHandle", str]],
  1516. ] = defaultdict(set)
  1517. for idx in task.downstream_task_idxs:
  1518. reader_task = self.idx_to_task[idx]
  1519. assert isinstance(reader_task.dag_node, ClassMethodNode)
  1520. reader_handle = reader_task.dag_node._get_actor_handle()
  1521. reader_node_id = self._get_node_id(reader_handle)
  1522. for arg in reader_task.args:
  1523. if isinstance(arg, InputAttributeNode) or isinstance(
  1524. arg, InputNode
  1525. ):
  1526. input_node_to_reader_and_node_set[arg].add(
  1527. (reader_handle, reader_node_id)
  1528. )
  1529. # A single channel is responsible for sending the same data to
  1530. # corresponding consumers. Therefore, we create a channel for
  1531. # each InputAttributeNode, or a single channel for the entire
  1532. # input data if there are no InputAttributeNodes.
  1533. task.output_channels = []
  1534. for input_dag_node in input_node_to_reader_and_node_set:
  1535. reader_and_node_list = list(
  1536. input_node_to_reader_and_node_set[input_dag_node]
  1537. )
  1538. output_channel = do_allocate_channel(
  1539. self,
  1540. reader_and_node_list,
  1541. input_dag_node.type_hint,
  1542. None,
  1543. )
  1544. task.output_channels.append(output_channel)
  1545. task.output_idxs.append(
  1546. None
  1547. if isinstance(input_dag_node, InputNode)
  1548. else input_dag_node.key
  1549. )
  1550. # Update the InputAttributeNode's `output_channels`, which is
  1551. # used to determine whether to create a CachedChannel.
  1552. if isinstance(input_dag_node, InputAttributeNode):
  1553. input_attr_idx = self.dag_node_to_idx[input_dag_node]
  1554. input_attr_task = self.idx_to_task[input_attr_idx]
  1555. input_attr_task.output_channels.append(output_channel)
  1556. assert len(input_attr_task.output_channels) == 1
  1557. else:
  1558. assert isinstance(task.dag_node, InputAttributeNode) or isinstance(
  1559. task.dag_node, MultiOutputNode
  1560. )
  1561. for idx in task.downstream_task_idxs:
  1562. frontier.append(idx)
  1563. # Validate input channels for tasks that have not been visited
  1564. for node_idx, task in self.idx_to_task.items():
  1565. if (
  1566. node_idx == self.input_task_idx
  1567. or node_idx == self.output_task_idx
  1568. or isinstance(task.dag_node, InputAttributeNode)
  1569. ):
  1570. continue
  1571. if node_idx not in visited:
  1572. has_at_least_one_channel_input = False
  1573. for arg in task.args:
  1574. if isinstance(arg, DAGNode):
  1575. has_at_least_one_channel_input = True
  1576. if not has_at_least_one_channel_input:
  1577. raise ValueError(
  1578. "Compiled DAGs require each task to take a ray.dag.InputNode "
  1579. "or at least one other DAGNode as an input. "
  1580. "Invalid task node:\n"
  1581. f"{task.dag_node}\n"
  1582. "Please bind the task to proper DAG nodes."
  1583. )
  1584. from ray.dag.constants import RAY_CGRAPH_ENABLE_DETECT_DEADLOCK
  1585. if RAY_CGRAPH_ENABLE_DETECT_DEADLOCK and self._detect_deadlock():
  1586. raise ValueError(
  1587. "This DAG cannot be compiled because it will deadlock on accelerator "
  1588. "calls. If you believe this is a false positive, please disable "
  1589. "the graph verification by setting the environment variable "
  1590. "RAY_CGRAPH_ENABLE_DETECT_DEADLOCK to 0 and file an issue at "
  1591. "https://github.com/ray-project/ray/issues/new/."
  1592. )
  1593. input_task = self.idx_to_task[self.input_task_idx]
  1594. self.dag_input_channels = input_task.output_channels
  1595. assert self.dag_input_channels is not None
  1596. # Create executable tasks for each actor
  1597. for actor_handle, tasks in self.actor_to_tasks.items():
  1598. # Dict from arg to the set of tasks that consume it.
  1599. arg_to_consumers: Dict[DAGNode, Set[CompiledTask]] = defaultdict(set)
  1600. # Step 1: populate `arg_to_consumers` and perform some validation.
  1601. for task in tasks:
  1602. has_at_least_one_channel_input = False
  1603. for arg in task.args:
  1604. if isinstance(arg, DAGNode):
  1605. has_at_least_one_channel_input = True
  1606. arg_to_consumers[arg].add(task)
  1607. arg_idx = self.dag_node_to_idx[arg]
  1608. upstream_task = self.idx_to_task[arg_idx]
  1609. assert len(upstream_task.output_channels) == 1
  1610. arg_channel = upstream_task.output_channels[0]
  1611. assert arg_channel is not None
  1612. # TODO: Support no-input DAGs (use an empty object to signal).
  1613. if not has_at_least_one_channel_input:
  1614. raise ValueError(
  1615. "Compiled DAGs require each task to take a "
  1616. "ray.dag.InputNode or at least one other DAGNode as an "
  1617. "input"
  1618. )
  1619. # Step 2: create cached channels if needed
  1620. # Dict from original channel to the channel to be used in execution.
  1621. # The value of this dict is either the original channel or a newly
  1622. # created CachedChannel (if the original channel is read more than once).
  1623. for arg, consumers in arg_to_consumers.items():
  1624. arg_idx = self.dag_node_to_idx[arg]
  1625. upstream_task = self.idx_to_task[arg_idx]
  1626. assert len(upstream_task.output_channels) == 1
  1627. arg_channel = upstream_task.output_channels[0]
  1628. assert arg_channel is not None
  1629. if len(consumers) > 1:
  1630. self._channel_dict[arg_channel] = CachedChannel(
  1631. len(consumers),
  1632. arg_channel,
  1633. )
  1634. else:
  1635. self._channel_dict[arg_channel] = arg_channel
  1636. # Step 3: create executable tasks for the actor
  1637. executable_tasks = []
  1638. for task in tasks:
  1639. resolved_args: List[Any] = []
  1640. for arg in task.args:
  1641. if isinstance(arg, DAGNode):
  1642. arg_idx = self.dag_node_to_idx[arg]
  1643. upstream_task = self.idx_to_task[arg_idx]
  1644. assert len(upstream_task.output_channels) == 1
  1645. arg_channel = upstream_task.output_channels[0]
  1646. assert arg_channel is not None
  1647. arg_channel = self._channel_dict[arg_channel]
  1648. resolved_args.append(arg_channel)
  1649. else:
  1650. # Constant arg
  1651. resolved_args.append(arg)
  1652. executable_task = ExecutableTask(
  1653. task,
  1654. resolved_args,
  1655. task.kwargs,
  1656. )
  1657. executable_tasks.append(executable_task)
  1658. # Sort executable tasks based on their bind index, i.e., submission order
  1659. # so that they will be executed in that order.
  1660. executable_tasks.sort(key=lambda task: task.bind_index)
  1661. self.actor_to_executable_tasks[actor_handle] = executable_tasks
  1662. from ray.dag.constants import RAY_CGRAPH_ENABLE_PROFILING
  1663. if RAY_CGRAPH_ENABLE_PROFILING:
  1664. exec_task_func = do_profile_tasks
  1665. else:
  1666. exec_task_func = do_exec_tasks
  1667. # Build an execution schedule for each actor
  1668. self.actor_to_execution_schedule = self._build_execution_schedule()
  1669. for actor_handle, executable_tasks in self.actor_to_executable_tasks.items():
  1670. self.worker_task_refs[actor_handle] = actor_handle.__ray_call__.options(
  1671. concurrency_group="_ray_system"
  1672. ).remote(
  1673. exec_task_func,
  1674. executable_tasks,
  1675. self.actor_to_execution_schedule[actor_handle],
  1676. self._overlap_gpu_communication,
  1677. )
  1678. assert self.output_task_idx is not None
  1679. self.dag_output_channels = []
  1680. for output in self.idx_to_task[self.output_task_idx].args:
  1681. assert isinstance(output, DAGNode)
  1682. output_idx = self.dag_node_to_idx[output]
  1683. task = self.idx_to_task[output_idx]
  1684. assert len(task.output_channels) == 1
  1685. self.dag_output_channels.append(task.output_channels[0])
  1686. # Register custom serializers for input, input attribute, and output nodes.
  1687. self._register_input_output_custom_serializer()
  1688. assert self.dag_input_channels
  1689. assert self.dag_output_channels
  1690. assert [
  1691. output_channel is not None for output_channel in self.dag_output_channels
  1692. ]
  1693. # If no MultiOutputNode was specified during the DAG creation, there is only
  1694. # one output. Return a single output channel instead of a list of
  1695. # channels.
  1696. if not self._returns_list:
  1697. assert len(self.dag_output_channels) == 1
  1698. # Driver should ray.put on input, ray.get/release on output
  1699. self._monitor = self._monitor_failures()
  1700. input_task = self.idx_to_task[self.input_task_idx]
  1701. if self._enable_asyncio:
  1702. self._dag_submitter = AwaitableBackgroundWriter(
  1703. self.dag_input_channels,
  1704. input_task.output_idxs,
  1705. is_input=True,
  1706. )
  1707. self._dag_output_fetcher = AwaitableBackgroundReader(
  1708. self.dag_output_channels,
  1709. self._fut_queue,
  1710. )
  1711. else:
  1712. self._dag_submitter = SynchronousWriter(
  1713. self.dag_input_channels, input_task.output_idxs, is_input=True
  1714. )
  1715. self._dag_output_fetcher = SynchronousReader(self.dag_output_channels)
  1716. self._dag_submitter.start()
  1717. self._dag_output_fetcher.start()
  1718. def _generate_dag_operation_graph_node(
  1719. self,
  1720. ) -> Dict["ray.actor.ActorHandle", List[List[_DAGOperationGraphNode]]]:
  1721. """
  1722. Generate READ, COMPUTE, and WRITE operations for each DAG node.
  1723. Returns:
  1724. A dictionary that maps an actor handle to a list of lists of
  1725. _DAGOperationGraphNode. For the same actor, the index of the
  1726. outer list corresponds to the index of the ExecutableTask in
  1727. the list of `executable_tasks` in `actor_to_executable_tasks`,
  1728. i.e. `exec_task_idx`. In the inner list, the order of operations
  1729. is READ, COMPUTE, and WRITE.
  1730. Example:
  1731. {
  1732. actor1: [
  1733. [READ COMPUTE WRITE] # exec_task_idx 0
  1734. [READ COMPUTE WRITE] # exec_task_idx 1
  1735. ]
  1736. }
  1737. """
  1738. from ray.dag.collective_node import CollectiveOutputNode
  1739. assert self.idx_to_task
  1740. assert self.actor_to_executable_tasks
  1741. actor_to_operation_nodes: Dict[
  1742. "ray.actor.ActorHandle", List[List[_DAGOperationGraphNode]]
  1743. ] = defaultdict(list)
  1744. for actor_handle, executable_tasks in self.actor_to_executable_tasks.items():
  1745. for exec_task_idx, exec_task in enumerate(executable_tasks):
  1746. # Divide a DAG node into three _DAGOperationGraphNodes: READ, COMPUTE,
  1747. # and WRITE. Each _DAGOperationGraphNode has a _DAGNodeOperation.
  1748. task_idx = exec_task.task_idx
  1749. dag_node = self.idx_to_task[task_idx].dag_node
  1750. method_name = exec_task.method_name
  1751. actor_handle = dag_node._get_actor_handle()
  1752. requires_accelerator_read = False
  1753. for upstream_node in dag_node._upstream_nodes:
  1754. if upstream_node.type_hint.requires_accelerator():
  1755. requires_accelerator_read = True
  1756. break
  1757. requires_accelerator_compute = isinstance(
  1758. dag_node, CollectiveOutputNode
  1759. )
  1760. requires_accelerator_write = dag_node.type_hint.requires_accelerator()
  1761. read_node = _DAGOperationGraphNode(
  1762. _DAGNodeOperation(
  1763. exec_task_idx, _DAGNodeOperationType.READ, method_name
  1764. ),
  1765. task_idx,
  1766. actor_handle,
  1767. requires_accelerator_read,
  1768. )
  1769. compute_node = _DAGOperationGraphNode(
  1770. _DAGNodeOperation(
  1771. exec_task_idx, _DAGNodeOperationType.COMPUTE, method_name
  1772. ),
  1773. task_idx,
  1774. actor_handle,
  1775. requires_accelerator_compute,
  1776. )
  1777. write_node = _DAGOperationGraphNode(
  1778. _DAGNodeOperation(
  1779. exec_task_idx, _DAGNodeOperationType.WRITE, method_name
  1780. ),
  1781. task_idx,
  1782. actor_handle,
  1783. requires_accelerator_write,
  1784. )
  1785. actor_to_operation_nodes[actor_handle].append(
  1786. [read_node, compute_node, write_node]
  1787. )
  1788. return actor_to_operation_nodes
  1789. def _build_execution_schedule(
  1790. self,
  1791. ) -> Dict["ray.actor.ActorHandle", List[_DAGNodeOperation]]:
  1792. """
  1793. Generate an execution schedule for each actor. The schedule is a list of
  1794. _DAGNodeOperation.
  1795. Step 1: Generate a DAG node operation graph. Refer to the functions
  1796. `_generate_dag_operation_graph_node` and `_build_dag_node_operation_graph`
  1797. for more details.
  1798. Step 2: Topological sort
  1799. It is possible to have multiple _DAGOperationGraphNodes with zero in-degree.
  1800. Refer to the function `_select_next_nodes` for the logic of selecting nodes.
  1801. Then, put the selected nodes into the corresponding actors' schedules.
  1802. The schedule should be intuitive to users, meaning that the execution should
  1803. perform operations in ascending order of `bind_index` as much as possible.
  1804. [Example]:
  1805. See `test_execution_schedule` for more examples.
  1806. Returns:
  1807. actor_to_execution_schedule: A dictionary that maps an actor handle to
  1808. the execution schedule which is a list of operations to be executed.
  1809. """
  1810. # Step 1: Build a graph of _DAGOperationGraphNode
  1811. actor_to_operation_nodes = self._generate_dag_operation_graph_node()
  1812. graph = _build_dag_node_operation_graph(
  1813. self.idx_to_task, actor_to_operation_nodes
  1814. )
  1815. # Step 2: Generate an execution schedule for each actor using topological sort
  1816. actor_to_execution_schedule = _generate_actor_to_execution_schedule(graph)
  1817. # Step 3: Overlap GPU communication for the execution schedule if configured
  1818. actor_to_overlapped_schedule = None
  1819. if self._overlap_gpu_communication:
  1820. actor_to_overlapped_schedule = _generate_overlapped_execution_schedule(
  1821. actor_to_execution_schedule
  1822. )
  1823. if RAY_CGRAPH_VISUALIZE_SCHEDULE:
  1824. _visualize_execution_schedule(
  1825. actor_to_execution_schedule, actor_to_overlapped_schedule, graph
  1826. )
  1827. if actor_to_overlapped_schedule is not None:
  1828. return _extract_execution_schedule(actor_to_overlapped_schedule)
  1829. else:
  1830. return _extract_execution_schedule(actor_to_execution_schedule)
  1831. def _detect_deadlock(self) -> bool:
  1832. """
  1833. TODO (kevin85421): Avoid false negatives.
  1834. Currently, a compiled graph may deadlock if there are accelerator channels,
  1835. and the readers have control dependencies on the same actor. For example:
  1836. actor1.a ---> actor2.f1
  1837. |
  1838. ---> actor2.f2
  1839. The control dependency between `actor2.f1` and `actor2.f2` is that `f1` should
  1840. run before `f2`. If `actor1.a` writes to `actor2.f2` before `actor2.f1`, a
  1841. deadlock will occur.
  1842. Currently, the execution schedule is not granular enough to detect this
  1843. deadlock.
  1844. Returns:
  1845. True if a deadlock is detected; otherwise, False.
  1846. """
  1847. logger.debug("Deadlock detection has not been implemented yet.")
  1848. return False
    def _monitor_failures(self):
        """
        Create, start, and return a daemon thread that watches the worker tasks.

        The returned ``Monitor`` thread blocks on ``ray.get`` of every ref in
        ``self.worker_task_refs``. If that call raises, the DAG is torn down:

        - ``KeyboardInterrupt``: with ``kill_actors=True``;
        - any other ``Exception``: with the default ``kill_actors=False``.

        ``Monitor`` also exposes ``teardown()`` and ``wait_teardown()``. The
        monitor reaches this CompiledDAG only through a weak reference
        (``get_outer``).

        Returns:
            The started ``Monitor`` thread.
        """
        # Weak reference so the Monitor class/instance does not itself own the
        # CompiledDAG.
        get_outer = weakref.ref(self)

        class Monitor(threading.Thread):
            def __init__(self):
                super().__init__(daemon=True)
                self.name = "CompiledGraphMonitorThread"
                # Lock to make sure that we only perform teardown for this DAG
                # once.
                self._in_teardown_lock = threading.Lock()
                self._teardown_done = False

            def _outer_ref_alive(self) -> bool:
                """Return True if the CompiledDAG still exists; log an error if not."""
                if get_outer() is None:
                    logger.error(
                        "CompiledDAG has been destructed before teardown. "
                        "This should not occur please report an issue at "
                        "https://github.com/ray-project/ray/issues/new/.",
                        stack_info=True,
                    )
                    return False
                return True

            def wait_teardown(self, kill_actors: bool = False):
                """
                Wait for every worker task to finish after cancellation.

                Each ref in ``worker_task_refs`` is awaited for up to
                ``DAGContext.get_current().teardown_timeout`` seconds.

                On timeout, a warning is logged. If ``kill_actors`` is True,
                the actor is also killed. The ref is then awaited again with no
                timeout, which can block indefinitely when ``kill_actors`` is
                False.

                Exceptions raised by the worker tasks themselves are ignored.
                After all refs are handled, every actor is killed if
                ``kill_actors`` is True.
                """
                outer = get_outer()
                if not self._outer_ref_alive():
                    return

                from ray.dag import DAGContext

                ctx = DAGContext.get_current()
                teardown_timeout = ctx.teardown_timeout

                for actor, ref in outer.worker_task_refs.items():
                    timeout = False
                    try:
                        ray.get(ref, timeout=teardown_timeout)
                    except ray.exceptions.GetTimeoutError:
                        msg = (
                            f"Compiled DAG actor {actor} is still running "
                            f"{teardown_timeout}s after teardown()."
                        )
                        if kill_actors:
                            msg += (
                                " Force-killing actor. "
                                "Increase RAY_CGRAPH_teardown_timeout if you want "
                                "teardown to wait longer."
                            )
                            ray.kill(actor)
                        else:
                            msg += (
                                " Teardown may hang. "
                                "Call teardown with kill_actors=True if force kill "
                                "is desired."
                            )

                        logger.warning(msg)
                        timeout = True
                    except Exception:
                        # We just want to check that the task has finished so
                        # we don't care if the actor task ended in an
                        # exception.
                        pass

                    if not timeout:
                        continue

                    # Timed out above: wait again without a timeout. Any error
                    # (e.g. from the actor having been killed) is ignored.
                    try:
                        ray.get(ref)
                    except Exception:
                        pass

                if kill_actors:
                    # In the previous loop, we allow the actor tasks to exit first.
                    # Now, we force kill the actors if not yet.
                    for actor in outer.worker_task_refs:
                        logger.info(f"Killing actor: {actor}")
                        ray.kill(actor)

            def teardown(self, kill_actors: bool = False):
                """
                Tear down the compiled DAG at most once; later calls return early.

                Steps, in order:

                1. Close the submitter and output fetcher.
                2. Ask every actor to cancel its executable tasks, waiting up
                   to 30 s per actor. ``RayChannelError`` is ignored; other
                   errors are logged.
                3. Destroy the communicators recorded in
                   ``_actors_to_created_communicator_id``.
                4. Wait for the worker tasks via ``wait_teardown``.

                ``_teardown_done`` is set only after all steps complete.
                """
                with self._in_teardown_lock:
                    if self._teardown_done:
                        return

                    outer = get_outer()
                    if not self._outer_ref_alive():
                        return

                    logger.info("Tearing down compiled DAG")
                    # NOTE(review): `_get_or_compile` starts this monitor
                    # before assigning `_dag_submitter`/`_dag_output_fetcher`.
                    # If a worker task fails inside that window, these may
                    # still be None here and `.close()` would raise. Confirm
                    # whether that window matters in practice.
                    outer._dag_submitter.close()
                    outer._dag_output_fetcher.close()

                    for actor in outer.actor_to_executable_tasks.keys():
                        logger.info(f"Cancelling compiled worker on actor: {actor}")
                    # Cancel all actor loops in parallel.
                    cancel_refs = [
                        actor.__ray_call__.remote(do_cancel_executable_tasks, tasks)
                        for actor, tasks in outer.actor_to_executable_tasks.items()
                    ]
                    for cancel_ref in cancel_refs:
                        try:
                            ray.get(cancel_ref, timeout=30)
                        except RayChannelError:
                            # Channel error happens when a channel is closed
                            # or timed out. In this case, do not log.
                            pass
                        except Exception:
                            logger.exception("Error cancelling worker task")
                            pass

                    # Only communicators recorded in this mapping are
                    # destroyed here.
                    for (
                        communicator_id
                    ) in outer._actors_to_created_communicator_id.values():
                        _destroy_communicator(communicator_id)

                    logger.info("Waiting for worker tasks to exit")
                    self.wait_teardown(kill_actors=kill_actors)
                    logger.info("Teardown complete")
                    self._teardown_done = True

            def run(self):
                """Block on all worker tasks; tear down if waiting on them raises."""
                try:
                    outer = get_outer()
                    if not self._outer_ref_alive():
                        return
                    # NOTE(review): `outer` stays referenced by this frame while
                    # `ray.get` blocks, so the CompiledDAG is kept alive for as
                    # long as the worker tasks run, despite the weakref above.
                    # Confirm this is intended.
                    ray.get(list(outer.worker_task_refs.values()))
                except KeyboardInterrupt:
                    logger.info(
                        "Received KeyboardInterrupt, tearing down with kill_actors=True"
                    )
                    self.teardown(kill_actors=True)
                except Exception as e:
                    logger.debug(f"Handling exception from worker tasks: {e}")
                    self.teardown()

        monitor = Monitor()
        monitor.start()
        return monitor
  1969. def _raise_if_too_many_inflight_executions(self):
  1970. num_inflight_executions = (
  1971. self._execution_index - self._max_finished_execution_index
  1972. )
  1973. if num_inflight_executions >= self._max_inflight_executions:
  1974. raise ray.exceptions.RayCgraphCapacityExceeded(
  1975. "The compiled graph can't have more than "
  1976. f"{self._max_inflight_executions} in-flight executions, and you "
  1977. f"currently have {num_inflight_executions} in-flight executions. "
  1978. "Retrieve an output using ray.get before submitting more requests or "
  1979. "increase `_max_inflight_executions`. "
  1980. "`dag.experimental_compile(_max_inflight_executions=...)`"
  1981. )
  1982. def _has_execution_results(
  1983. self,
  1984. execution_index: int,
  1985. ) -> bool:
  1986. """Check whether there are results corresponding to the given execution
  1987. index stored in self._result_buffer. This helps avoid fetching and
  1988. caching results again.
  1989. Args:
  1990. execution_index: The execution index corresponding to the result.
  1991. Returns:
  1992. Whether the result for the given index has been fetched and cached.
  1993. """
  1994. return execution_index in self._result_buffer
  1995. def _cache_execution_results(
  1996. self,
  1997. execution_index: int,
  1998. result: Any,
  1999. ):
  2000. """Cache execution results in self._result_buffer. Results are converted
  2001. to dictionary format to allow efficient element removal and calculation of
  2002. the buffer size. This can only be called once per execution index.
  2003. Args:
  2004. execution_index: The execution index corresponding to the result.
  2005. result: The results from all channels to be cached.
  2006. """
  2007. if not self._has_execution_results(execution_index):
  2008. for chan_idx, res in enumerate(result):
  2009. # avoid caching for any CompiledDAGRef that has already been destructed.
  2010. if not (
  2011. execution_index in self._destructed_ref_idxs
  2012. and chan_idx in self._destructed_ref_idxs[execution_index]
  2013. ):
  2014. self._result_buffer[execution_index][chan_idx] = res
  2015. def _get_execution_results(
  2016. self, execution_index: int, channel_index: Optional[int]
  2017. ) -> List[Any]:
  2018. """Retrieve execution results from self._result_buffer and return the result.
  2019. Results are converted back to original list format ordered by output channel
  2020. index.
  2021. Args:
  2022. execution_index: The execution index to retrieve results from.
  2023. channel_index: The index of the output channel corresponding to the result.
  2024. Channel indexing is consistent with the order of
  2025. self.dag_output_channels. None means that the result wraps outputs from
  2026. all output channels.
  2027. Returns:
  2028. The execution result corresponding to the given execution index and channel
  2029. index.
  2030. """
  2031. # Although CompiledDAGRef and CompiledDAGFuture guarantee that the same
  2032. # execution index and channel index combination will not be requested multiple
  2033. # times and therefore self._result_buffer will always have execution_index as
  2034. # a key, we still do a sanity check to avoid misuses.
  2035. assert execution_index in self._result_buffer
  2036. if channel_index is None:
  2037. # Convert results stored in self._result_buffer back to original
  2038. # list representation
  2039. result = [
  2040. kv[1]
  2041. for kv in sorted(
  2042. self._result_buffer.pop(execution_index).items(),
  2043. key=lambda kv: kv[0],
  2044. )
  2045. ]
  2046. else:
  2047. result = [self._result_buffer[execution_index].pop(channel_index)]
  2048. if execution_index not in self._got_ref_idxs:
  2049. self._got_ref_idxs[execution_index] = set()
  2050. self._got_ref_idxs[execution_index].add(channel_index)
  2051. self._clean_up_buffers(execution_index)
  2052. return result
  2053. def _delete_execution_results(self, execution_index: int, channel_index: int):
  2054. """
  2055. Delete the execution results for the given execution index and channel index.
  2056. This method should be called when a CompiledDAGRef or CompiledDAGFuture is
  2057. destructed.
  2058. Note that this method maintains metadata for the deleted execution results,
  2059. and only actually deletes the buffers lazily when the buffer is not needed
  2060. anymore.
  2061. Args:
  2062. execution_index: The execution index to destruct results from.
  2063. channel_index: The index of the output channel corresponding to the result.
  2064. """
  2065. if execution_index not in self._destructed_ref_idxs:
  2066. self._destructed_ref_idxs[execution_index] = set()
  2067. self._destructed_ref_idxs[execution_index].add(channel_index)
  2068. self._clean_up_buffers(execution_index)
  2069. def _try_release_result_buffer(self, execution_index: int):
  2070. """
  2071. Try to release the result buffer for the given execution index.
  2072. """
  2073. should_release = False
  2074. got_channel_idxs = self._got_ref_idxs.get(execution_index, set())
  2075. if None in got_channel_idxs:
  2076. assert len(got_channel_idxs) == 1, (
  2077. "when None exists in got_channel_idxs, it means all channels, and "
  2078. "it should be the only value in the set",
  2079. )
  2080. should_release = True
  2081. else:
  2082. destructed_channel_idxs = self._destructed_ref_idxs.get(
  2083. execution_index, set()
  2084. )
  2085. processed_channel_idxs = got_channel_idxs.union(destructed_channel_idxs)
  2086. # No more processing is needed for this execution index.
  2087. should_release = processed_channel_idxs == set(
  2088. range(len(self.dag_output_channels))
  2089. )
  2090. if not should_release:
  2091. return False
  2092. self._result_buffer.pop(execution_index, None)
  2093. self._destructed_ref_idxs.pop(execution_index, None)
  2094. self._got_ref_idxs.pop(execution_index, None)
  2095. return True
  2096. def _try_release_native_buffer(
  2097. self, idx_to_release: int, timeout: Optional[float] = None
  2098. ) -> bool:
  2099. """
  2100. Try to release the native buffer for the given execution index.
  2101. Args:
  2102. idx_to_release: The execution index to release buffers from.
  2103. timeout: The maximum time in seconds to wait for the release.
  2104. Returns:
  2105. Whether the buffers have been released.
  2106. """
  2107. if idx_to_release != self._max_finished_execution_index + 1:
  2108. # Native buffer can only be released for the next execution index.
  2109. return False
  2110. destructed_channel_idxs = self._destructed_ref_idxs.get(idx_to_release, set())
  2111. should_release = False
  2112. if None in destructed_channel_idxs:
  2113. assert len(destructed_channel_idxs) == 1, (
  2114. "when None exists in destructed_channel_idxs, it means all channels, "
  2115. "and it should be the only value in the set",
  2116. )
  2117. should_release = True
  2118. elif len(destructed_channel_idxs) == len(self.dag_output_channels):
  2119. should_release = True
  2120. if not should_release:
  2121. return False
  2122. # refs corresponding to idx_to_release are all destructed,
  2123. # and they are never fetched or cached.
  2124. assert idx_to_release not in self._result_buffer
  2125. assert idx_to_release not in self._got_ref_idxs
  2126. try:
  2127. self._dag_output_fetcher.release_channel_buffers(timeout)
  2128. except RayChannelTimeoutError as e:
  2129. raise RayChannelTimeoutError(
  2130. "Releasing native buffers corresponding to a stale CompiledDAGRef "
  2131. "is taking a long time. If this is expected, increase "
  2132. f"RAY_CGRAPH_get_timeout which is currently {self._get_timeout} "
  2133. "seconds. Otherwise, this may indicate that the execution "
  2134. "is hanging."
  2135. ) from e
  2136. self._destructed_ref_idxs.pop(idx_to_release)
  2137. return True
  2138. def _try_release_buffer(
  2139. self, idx_to_release: int, timeout: Optional[float] = None
  2140. ) -> bool:
  2141. """
  2142. Try to release the buffer for the given execution index.
  2143. First try to release the native buffer, then try to release the result buffer.
  2144. Args:
  2145. idx_to_release: The execution index to release buffers from.
  2146. timeout: The maximum time in seconds to wait for the release.
  2147. Returns:
  2148. Whether the native buffer or result buffer has been released.
  2149. """
  2150. if self._try_release_native_buffer(idx_to_release, timeout):
  2151. # Releasing native buffer means the corresponding execution result
  2152. # is consumed (and discarded).
  2153. self._max_finished_execution_index += 1
  2154. return True
  2155. return self._try_release_result_buffer(idx_to_release)
  2156. def _try_release_buffers(self):
  2157. """
  2158. Repeatedly release buffer if possible.
  2159. This method starts from _max_finished_execution_index + 1 and tries to release
  2160. as many buffers as possible. If a native buffer is released,
  2161. _max_finished_execution_index will be incremented.
  2162. """
  2163. timeout = self._get_timeout
  2164. while True:
  2165. start_time = time.monotonic()
  2166. if not self._try_release_buffer(
  2167. self._max_finished_execution_index + 1, timeout
  2168. ):
  2169. break
  2170. if timeout != -1:
  2171. timeout -= time.monotonic() - start_time
  2172. timeout = max(timeout, 0)
  2173. def _clean_up_buffers(self, idx_to_release: int):
  2174. """
  2175. Clean up native and result buffers.
  2176. This method:
  2177. 1. Tries to release the buffer for the given execution index.
  2178. This index is the specific one that requires a clean up,
  2179. e.g., right after get() is called or a CompiledDAGRef/CompiledDAGFuture
  2180. is destructed.
  2181. 2. Tries to release all buffers starting from _max_finished_execution_index + 1.
  2182. This step is to clean up buffers that are no longer needed.
  2183. Args:
  2184. idx_to_release: The execution index that requires a clean up,
  2185. e.g., right after get() is called or a CompiledDAGRef/CompiledDAGFuture
  2186. is destructed.
  2187. """
  2188. self._try_release_buffer(idx_to_release)
  2189. self._try_release_buffers()
  2190. def _execute_until(
  2191. self,
  2192. execution_index: int,
  2193. channel_index: Optional[int] = None,
  2194. timeout: Optional[float] = None,
  2195. ):
  2196. """Repeatedly execute this DAG until the given execution index and
  2197. buffer results for all CompiledDagRef's.
  2198. If the DAG has already been executed up to the given index, it will do nothing.
  2199. Note: If this comes across execution indices for which the corresponding
  2200. CompiledDAGRef's have been destructed, it will release the buffer and not
  2201. cache the result.
  2202. Args:
  2203. execution_index: The execution index to execute until.
  2204. channel_index: The index of the output channel to get the result from.
  2205. Channel indexing is consistent with the order of
  2206. self.dag_output_channels. None means wrapping results from all output
  2207. channels into a single list.
  2208. timeout: The maximum time in seconds to wait for the execution.
  2209. None means using default timeout (DAGContext.get_timeout),
  2210. 0 means immediate timeout (immediate success or timeout without
  2211. blocking), -1 means infinite timeout (block indefinitely).
  2212. TODO(rui): catch the case that user holds onto the CompiledDAGRefs
  2213. """
  2214. if timeout is None:
  2215. timeout = self._get_timeout
  2216. while self._max_finished_execution_index < execution_index:
  2217. if len(self._result_buffer) >= self._max_buffered_results:
  2218. raise RayCgraphCapacityExceeded(
  2219. "The compiled graph can't have more than "
  2220. f"{self._max_buffered_results} buffered results, and you "
  2221. f"currently have {len(self._result_buffer)} buffered results. "
  2222. "Call `ray.get()` on CompiledDAGRef's (or await on "
  2223. "CompiledDAGFuture's) to retrieve results, or increase "
  2224. f"`_max_buffered_results` if buffering is desired, note that "
  2225. "this will increase driver memory usage."
  2226. )
  2227. start_time = time.monotonic()
  2228. # Fetch results from each output channel up to execution_index and cache
  2229. # them separately to enable individual retrieval
  2230. # If a CompiledDagRef for a specific execution index has been destructed,
  2231. # release the channel buffers for that execution index instead of caching
  2232. try:
  2233. if not self._try_release_native_buffer(
  2234. self._max_finished_execution_index + 1, timeout
  2235. ):
  2236. result = self._dag_output_fetcher.read(timeout)
  2237. self._cache_execution_results(
  2238. self._max_finished_execution_index + 1,
  2239. result,
  2240. )
  2241. # We have either released the native buffer or fetched and
  2242. # cached the result buffer, therefore we always increment
  2243. # _max_finished_execution_index.
  2244. self._max_finished_execution_index += 1
  2245. except RayChannelTimeoutError as e:
  2246. raise RayChannelTimeoutError(
  2247. "If the execution is expected to take a long time, increase "
  2248. f"RAY_CGRAPH_get_timeout which is currently {self._get_timeout} "
  2249. "seconds. Otherwise, this may indicate that the execution is "
  2250. "hanging."
  2251. ) from e
  2252. if timeout != -1:
  2253. timeout -= time.monotonic() - start_time
  2254. timeout = max(timeout, 0)
  2255. def execute(
  2256. self,
  2257. *args,
  2258. **kwargs,
  2259. ) -> Union[CompiledDAGRef, List[CompiledDAGRef]]:
  2260. """Execute this DAG using the compiled execution path.
  2261. Args:
  2262. args: Args to the InputNode.
  2263. kwargs: Kwargs to the InputNode
  2264. Returns:
  2265. A list of Channels that can be used to read the DAG result.
  2266. Raises:
  2267. RayChannelTimeoutError: If the execution does not complete within
  2268. self._submit_timeout seconds.
  2269. NOTE: Not thread-safe due to _execution_index etc.
  2270. """
  2271. if self._enable_asyncio:
  2272. raise ValueError("Use execute_async if enable_asyncio=True")
  2273. self._get_or_compile()
  2274. self._check_inputs(args, kwargs)
  2275. if len(args) == 1 and len(kwargs) == 0:
  2276. # When serializing a tuple, the Ray serializer invokes pickle5, which adds
  2277. # several microseconds of overhead. One common case for Compiled Graphs is
  2278. # passing a single argument (oftentimes of of type `bytes`, which requires
  2279. # no serialization). To avoid imposing this overhead on this common case, we
  2280. # create a fast path for this case that avoids pickle5.
  2281. inp = args[0]
  2282. else:
  2283. inp = CompiledDAGArgs(args=args, kwargs=kwargs)
  2284. # We want to release any buffers we can at this point based on the
  2285. # max_finished_execution_index so that the number of inflight executions
  2286. # is up to date.
  2287. self._try_release_buffers()
  2288. self._raise_if_too_many_inflight_executions()
  2289. try:
  2290. self._dag_submitter.write(inp, self._submit_timeout)
  2291. except RayChannelTimeoutError as e:
  2292. raise RayChannelTimeoutError(
  2293. "If the execution is expected to take a long time, increase "
  2294. f"RAY_CGRAPH_submit_timeout which is currently {self._submit_timeout} "
  2295. "seconds. Otherwise, this may indicate that execution is hanging."
  2296. ) from e
  2297. self._execution_index += 1
  2298. if self._returns_list:
  2299. ref = [
  2300. CompiledDAGRef(self, self._execution_index, channel_index)
  2301. for channel_index in range(len(self.dag_output_channels))
  2302. ]
  2303. else:
  2304. ref = CompiledDAGRef(self, self._execution_index)
  2305. return ref
  2306. def _check_inputs(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
  2307. """
  2308. Helper method to check that the DAG args provided by the user during
  2309. execution are valid according to the defined DAG.
  2310. """
  2311. if len(args) != self._input_num_positional_args:
  2312. raise ValueError(
  2313. "dag.execute() or dag.execute_async() must be "
  2314. f"called with {self._input_num_positional_args} positional args, got "
  2315. f"{len(args)}"
  2316. )
  2317. for kwarg in self._input_kwargs:
  2318. if kwarg not in kwargs:
  2319. raise ValueError(
  2320. "dag.execute() or dag.execute_async() "
  2321. f"must be called with kwarg `{kwarg}`"
  2322. )
  2323. async def execute_async(
  2324. self,
  2325. *args,
  2326. **kwargs,
  2327. ) -> Union[CompiledDAGFuture, List[CompiledDAGFuture]]:
  2328. """Execute this DAG using the compiled execution path.
  2329. NOTE: Not thread-safe.
  2330. Args:
  2331. args: Args to the InputNode.
  2332. kwargs: Kwargs to the InputNode.
  2333. Returns:
  2334. A list of Channels that can be used to read the DAG result.
  2335. """
  2336. if not self._enable_asyncio:
  2337. raise ValueError("Use execute if enable_asyncio=False")
  2338. self._get_or_compile()
  2339. self._check_inputs(args, kwargs)
  2340. async with self._dag_submission_lock:
  2341. if len(args) == 1 and len(kwargs) == 0:
  2342. # When serializing a tuple, the Ray serializer invokes pickle5, which
  2343. # adds several microseconds of overhead. One common case for accelerated
  2344. # DAGs is passing a single argument (oftentimes of of type `bytes`,
  2345. # which requires no serialization). To avoid imposing this overhead on
  2346. # this common case, we create a fast path for this case that avoids
  2347. # pickle5.
  2348. inp = args[0]
  2349. else:
  2350. inp = CompiledDAGArgs(args=args, kwargs=kwargs)
  2351. self._raise_if_too_many_inflight_executions()
  2352. await self._dag_submitter.write(inp)
  2353. # Allocate a future that the caller can use to get the result.
  2354. fut = asyncio.Future()
  2355. await self._fut_queue.put(fut)
  2356. self._execution_index += 1
  2357. if self._returns_list:
  2358. fut = [
  2359. CompiledDAGFuture(self, self._execution_index, fut, channel_index)
  2360. for channel_index in range(len(self.dag_output_channels))
  2361. ]
  2362. else:
  2363. fut = CompiledDAGFuture(self, self._execution_index, fut)
  2364. return fut
  2365. def _visualize_ascii(self) -> str:
  2366. """
  2367. Visualize the compiled graph in
  2368. ASCII format with directional markers.
  2369. This function generates an ASCII visualization of a Compiled Graph,
  2370. where each task node is labeled,
  2371. and edges use `<` and `>` markers to show data flow direction.
  2372. This method is called by:
  2373. - `compiled_dag.visualize(format="ascii")`
  2374. High-Level Algorithm:
  2375. - Topological Sorting: Sort nodes topologically to organize
  2376. them into layers based on dependencies.
  2377. - Grid Initialization: Set up a 2D grid canvas with dimensions based
  2378. on the number of layers and the maximum number of nodes per layer.
  2379. - Node Placement: Position each node on the grid according to its
  2380. layer and relative position within that layer.
  2381. Spacing is added for readability, and directional markers (`<` and `>`)
  2382. are added to edges to show input/output flow clearly.
  2383. This method should be called
  2384. **after** compiling the graph with `experimental_compile()`.
  2385. Returns:
  2386. ASCII representation of the CG with Nodes Information,
  2387. Edges Information and Graph Built.
  2388. Limitations:
  2389. - Note: This is only used for quick visualization for small graphs.
  2390. For complex graph (i.e. more than 20 tasks), please use graphviz.
  2391. - Scale: Works best for smaller CGs (typically fewer than 20 tasks).
  2392. Larger CGs may result in dense, less readable ASCII
  2393. outputs due to limited space for node and edge rendering.
  2394. - Shape: Ideal for relatively shallow CGs with clear dependency paths.
  2395. For deep, highly branched or densely connected CGs,
  2396. readability may suffer.
  2397. - Edge Overlap: In cases with high fan-out (i.e., nodes with many children)
  2398. or fan-in (nodes with many parents), edge lines may intersect or overlap
  2399. in the ASCII visualization, potentially obscuring some connections.
  2400. - Multi-output Tasks: Multi-output tasks can be visualized, but positioning
  2401. may cause line breaks or overlap when a task has multiple outputs that
  2402. feed into nodes at varying depths.
  2403. Example:
  2404. Basic Visualization:
  2405. ```python
  2406. # Print the CG structure in ASCII format
  2407. print(compiled_dag.visualize(format="ascii"))
  2408. ```
  2409. Example of Ordered Visualization (task is build in order
  2410. to reduce line intersection):
  2411. ```python
  2412. with InputNode() as i:
  2413. o1, o2, o3 = a.return_three.bind(i)
  2414. o4 = b.echo.bind(o1)
  2415. o5 = b.echo.bind(o2)
  2416. o6, o7 = b.return_two.bind(o3)
  2417. dag = MultiOutputNode([o4, o5, o6, o7])
  2418. compiled_dag = dag.experimental_compile()
  2419. compiled_dag.visualize(format="ascii",view=True)
  2420. # Output:
  2421. # 0:InputNode
  2422. # |
  2423. # 1:Actor_54777d:return_three
  2424. # |---------------------------->|---------------------------->| # noqa
  2425. # 2:Output[0] 3:Output[1] 4:Output[2] # noqa
  2426. # | | | # noqa
  2427. # 5:Actor_c927c9:echo 6:Actor_c927c9:echo 7:Actor_c927c9:return_two # noqa
  2428. # | | |---------------------------->| # noqa
  2429. # | | 9:Output[0] 10:Output[1] # noqa
  2430. # |<----------------------------|-----------------------------|-----------------------------| # noqa
  2431. # 8:MultiOutputNode
  2432. ```
  2433. Example of Anti-pattern Visualization (There are intersections):
  2434. # We can swtich the nodes ordering to reduce intersections, i.e. swap o2 and o3
  2435. ```python
  2436. with InputNode() as i:
  2437. o1, o2, o3 = a.return_three.bind(i)
  2438. o4 = b.echo.bind(o1)
  2439. o5 = b.echo.bind(o3)
  2440. o6, o7 = b.return_two.bind(o2)
  2441. dag = MultiOutputNode([o4, o5, o6, o7])
  2442. compiled_dag = dag.experimental_compile()
  2443. compiled_dag.visualize(format="ascii",view=True)
  2444. # Output (Nodes 5, 7, 9, 10 should connect to Node 8):
  2445. # 0:InputNode
  2446. # |
  2447. # 1:Actor_84835a:return_three
  2448. # |---------------------------->|---------------------------->| # noqa
  2449. # 2:Output[0] 3:Output[1] 4:Output[2] # noqa
  2450. # | | | # noqa
  2451. # 5:Actor_02a6a1:echo 6:Actor_02a6a1:return_two 7:Actor_02a6a1:echo # noqa
  2452. # | |---------------------------->| # noqa
  2453. # | 9:Output[0] 10:Output[1] # noqa
  2454. # |<----------------------------------------------------------| # noqa
  2455. # 8:MultiOutputNode
  2456. ```
  2457. """
  2458. from ray.dag import (
  2459. ClassMethodNode,
  2460. DAGNode,
  2461. InputAttributeNode,
  2462. InputNode,
  2463. MultiOutputNode,
  2464. )
  2465. # Check that the DAG has been compiled
  2466. if not hasattr(self, "idx_to_task") or not self.idx_to_task:
  2467. raise ValueError(
  2468. "The DAG must be compiled before calling 'visualize()'. "
  2469. "Please call 'experimental_compile()' first."
  2470. )
  2471. # Check that each CompiledTask has a valid dag_node
  2472. for idx, task in self.idx_to_task.items():
  2473. if not hasattr(task, "dag_node") or not isinstance(task.dag_node, DAGNode):
  2474. raise ValueError(
  2475. f"Task at index {idx} does not have a valid 'dag_node'. "
  2476. "Ensure that 'experimental_compile()' completed successfully."
  2477. )
  2478. from collections import defaultdict, deque
  2479. # Create adjacency list representation of the DAG
  2480. # Adjacency list for DAG; maps a node index to its downstream nodes.
  2481. adj_list: Dict[int, List[int]] = defaultdict(list)
  2482. # Indegree count for topological sorting; maps a node index to its indegree.
  2483. indegree: Dict[int, int] = defaultdict(int)
  2484. # Tracks whether a node is a multi-output node.
  2485. is_multi_output: Dict[int, bool] = defaultdict(bool)
  2486. # Maps child node indices to their parent node indices.
  2487. child2parent: Dict[int, int] = defaultdict(int)
  2488. ascii_visualization = ""
  2489. # Node information; maps a node index to its descriptive label.
  2490. node_info: Dict[int, str] = {}
  2491. # Edge information; tuples of (upstream_index, downstream_index, edge_label).
  2492. edge_info: List[Tuple[int, int, str]] = []
  2493. for idx, task in self.idx_to_task.items():
  2494. dag_node = task.dag_node
  2495. label = f"Task {idx} "
  2496. # Determine the type and label of the node
  2497. if isinstance(dag_node, InputNode):
  2498. label += "InputNode"
  2499. elif isinstance(dag_node, InputAttributeNode):
  2500. label += f"InputAttributeNode[{dag_node.key}]"
  2501. elif isinstance(dag_node, MultiOutputNode):
  2502. label += "MultiOutputNode"
  2503. elif isinstance(dag_node, ClassMethodNode):
  2504. if dag_node.is_class_method_call:
  2505. method_name = dag_node.get_method_name()
  2506. actor_handle = dag_node._get_actor_handle()
  2507. actor_id = (
  2508. actor_handle._actor_id.hex()[:6] if actor_handle else "unknown"
  2509. )
  2510. label += f"Actor: {actor_id}... Method: {method_name}"
  2511. elif dag_node.is_class_method_output:
  2512. label += f"ClassMethodOutputNode[{dag_node.output_idx}]"
  2513. else:
  2514. label += "ClassMethodNode"
  2515. else:
  2516. label += type(dag_node).__name__
  2517. node_info[idx] = label
  2518. for arg_index, arg in enumerate(dag_node.get_args()):
  2519. if isinstance(arg, DAGNode):
  2520. upstream_task_idx = self.dag_node_to_idx[arg]
  2521. # Get the type hint for this argument
  2522. if arg_index < len(task.arg_type_hints):
  2523. if task.arg_type_hints[arg_index].requires_accelerator():
  2524. type_hint = "Accelerator"
  2525. else:
  2526. type_hint = type(task.arg_type_hints[arg_index]).__name__
  2527. else:
  2528. type_hint = "UnknownType"
  2529. adj_list[upstream_task_idx].append(idx)
  2530. indegree[idx] += 1
  2531. edge_info.append((upstream_task_idx, idx, type_hint))
  2532. width_adjust = 0
  2533. for upstream_task_idx, child_idx_list in adj_list.items():
  2534. # Mark as multi-output if the node has more than one output path
  2535. if len(child_idx_list) > 1:
  2536. for child in child_idx_list:
  2537. is_multi_output[child] = True
  2538. child2parent[child] = upstream_task_idx
  2539. width_adjust = max(width_adjust, len(child_idx_list))
  2540. # Topological sort to determine layers
  2541. layers = defaultdict(list)
  2542. zero_indegree = deque([idx for idx in self.idx_to_task if indegree[idx] == 0])
  2543. layer_index = 0
  2544. while zero_indegree:
  2545. next_layer = deque()
  2546. while zero_indegree:
  2547. task_idx = zero_indegree.popleft()
  2548. layers[layer_index].append(task_idx)
  2549. for downstream in adj_list[task_idx]:
  2550. indegree[downstream] -= 1
  2551. if indegree[downstream] == 0:
  2552. next_layer.append(downstream)
  2553. zero_indegree = next_layer
  2554. layer_index += 1
  2555. # Print detailed node information
  2556. ascii_visualization += "Nodes Information:\n"
  2557. for idx, info in node_info.items():
  2558. ascii_visualization += f'{idx} [label="{info}"] \n'
  2559. # Print edges
  2560. ascii_visualization += "\nEdges Information:\n"
  2561. for upstream_task, downstream_task, type_hint in edge_info:
  2562. if type_hint == "Accelerator":
  2563. edgs_channel = "+++"
  2564. else:
  2565. edgs_channel = "---"
  2566. ascii_visualization += (
  2567. f"{upstream_task} {edgs_channel}>" f" {downstream_task}\n"
  2568. )
  2569. # Add the legend to the output
  2570. ascii_visualization += "\nLegend:\n"
  2571. ascii_visualization += "+++> : Represents Accelerator-type data channels\n"
  2572. ascii_visualization += "---> : Represents Shared Memory data channels\n"
  2573. # Find the maximum width (number of nodes in any layer)
  2574. max_width = max(len(layer) for layer in layers.values()) + width_adjust
  2575. height = len(layers)
  2576. # Build grid for ASCII visualization
  2577. grid = [[" " for _ in range(max_width * 20)] for _ in range(height * 2 - 1)]
  2578. # Place nodes in the grid with more details
  2579. task_to_pos = {}
  2580. for layer_num, layer_tasks in layers.items():
  2581. layer_y = layer_num * 2 # Every second row is for nodes
  2582. for col_num, task_idx in enumerate(layer_tasks):
  2583. task = self.idx_to_task[task_idx]
  2584. task_info = f"{task_idx}:"
  2585. # Determine if it's an actor method or a regular task
  2586. if isinstance(task.dag_node, ClassMethodNode):
  2587. if task.dag_node.is_class_method_call:
  2588. method_name = task.dag_node.get_method_name()
  2589. actor_handle = task.dag_node._get_actor_handle()
  2590. actor_id = (
  2591. actor_handle._actor_id.hex()[:6]
  2592. if actor_handle
  2593. else "unknown"
  2594. )
  2595. task_info += f"Actor_{actor_id}:{method_name}"
  2596. elif task.dag_node.is_class_method_output:
  2597. task_info += f"Output[{task.dag_node.output_idx}]"
  2598. else:
  2599. task_info += "UnknownMethod"
  2600. else:
  2601. task_info += type(task.dag_node).__name__
  2602. adjust_col_num = 0
  2603. if task_idx in is_multi_output:
  2604. adjust_col_num = layers[layer_num - 1].index(child2parent[task_idx])
  2605. col_x = (col_num + adjust_col_num) * 30 # Every 30th column for spacing
  2606. # Place the task information into the grid
  2607. for i, char in enumerate(task_info):
  2608. if col_x + i < len(grid[0]): # Ensure we don't overflow the grid
  2609. grid[layer_y][col_x + i] = char
  2610. task_to_pos[task_idx] = (layer_y, col_x)
  2611. # Connect the nodes with lines
  2612. for upstream_task, downstream_tasks in adj_list.items():
  2613. upstream_y, upstream_x = task_to_pos[upstream_task]
  2614. for downstream_task in downstream_tasks:
  2615. downstream_y, downstream_x = task_to_pos[downstream_task]
  2616. # Draw vertical line
  2617. for y in range(upstream_y + 1, downstream_y):
  2618. if grid[y][upstream_x] == " ":
  2619. grid[y][upstream_x] = "|"
  2620. # Draw horizontal line with directional arrows
  2621. if upstream_x != downstream_x:
  2622. for x in range(
  2623. min(upstream_x, downstream_x) + 1,
  2624. max(upstream_x, downstream_x),
  2625. ):
  2626. grid[downstream_y - 1][x] = (
  2627. "-"
  2628. if grid[downstream_y - 1][x] == " "
  2629. else grid[downstream_y - 1][x]
  2630. )
  2631. # Add arrows to indicate flow direction
  2632. if downstream_x > upstream_x:
  2633. grid[downstream_y - 1][downstream_x - 1] = ">"
  2634. else:
  2635. grid[downstream_y - 1][downstream_x + 1] = "<"
  2636. # Draw connection to the next task
  2637. grid[downstream_y - 1][downstream_x] = "|"
  2638. # Ensure proper multi-output task connection
  2639. for idx, task in self.idx_to_task.items():
  2640. if isinstance(task.dag_node, MultiOutputNode):
  2641. output_tasks = task.dag_node.get_args()
  2642. for i, output_task in enumerate(output_tasks):
  2643. if isinstance(output_task, DAGNode):
  2644. output_task_idx = self.dag_node_to_idx[output_task]
  2645. if output_task_idx in task_to_pos:
  2646. output_y, output_x = task_to_pos[output_task_idx]
  2647. grid[output_y - 1][output_x] = "|"
  2648. # Convert grid to string for printing
  2649. ascii_visualization += "\nGraph Built:\n"
  2650. ascii_visualization += "\n".join("".join(row) for row in grid)
  2651. return ascii_visualization
  def get_channel_details(
      self, channel: ChannelInterface, downstream_actor_id: str
  ) -> str:
      """
      Describe the outer and inner channel types (and shortened channel ids)
      of ``channel`` as seen from ``downstream_actor_id``.

      Used for graph visualization to label edges.

      Args:
          channel: The channel to describe.
          downstream_actor_id: Actor ID (hex string, as passed by
              ``visualize``) used to select the inner channel of a
              ``CompositeChannel``.

      Returns:
          A newline-separated description. It starts with the type name of
          ``channel``. If ``self._channel_dict`` maps ``channel`` to a
          different channel, the type name of that outer channel comes next.
          The resulting (outer or original) channel gets ", <first 6 chars of
          its id>..." appended when it is a ``CachedChannel``. If that channel
          is a ``CompositeChannel`` holding an entry for
          ``downstream_actor_id``, the inner channel's type name follows. It
          gets an id suffix when it is an ``IntraProcessChannel``.
      """
      details = type(channel).__name__

      # Swap in the outer channel when one is registered and it differs.
      if channel in self._channel_dict:
          outer_channel = self._channel_dict[channel]
          if outer_channel != channel:
              channel = outer_channel
              details = f"{details}\n{type(channel).__name__}"
      if type(channel) is CachedChannel:
          details = f"{details}, {channel._channel_id[:6]}..."

      # Only a CompositeChannel with an entry for this actor has an inner
      # channel worth reporting.
      if (
          type(channel) is not CompositeChannel
          or downstream_actor_id not in channel._channel_dict
      ):
          return details

      inner_channel = channel._channel_dict[downstream_actor_id]
      details = f"{details}\n{type(inner_channel).__name__}"
      if type(inner_channel) is IntraProcessChannel:
          details = f"{details}, {inner_channel._channel_id[:6]}..."
      return details
  def visualize(
      self,
      filename="compiled_graph",
      format="png",
      view=False,
      channel_details=False,
  ) -> str:
      """
      Visualize the compiled graph by showing tasks and their dependencies.

      This method should be called **after** the graph has been compiled using
      `experimental_compile()`.

      Args:
          filename: For non-ASCII formats, the output file name (without
              extension) passed to Graphviz's ``render``. Ignored for the
              ASCII format.
          format: The format of the output file (e.g., 'png', 'pdf', 'ascii').
          view: For non-ASCII formats: whether to open the rendered file with
              the default viewer. For ASCII format: whether to also print the
              visualization to stdout; the ASCII string is returned either way.
          channel_details: If True, adds channel details to edges. Not
              supported together with the 'ascii' format.

      Returns:
          The string representation of the compiled graph. For Graphviz-based
          formats (e.g., 'png', 'pdf', 'jpeg'), returns the Graphviz DOT string
          representation of the compiled graph. For ASCII format, returns the
          ASCII string representation of the compiled graph.

      Raises:
          ValueError: If ``channel_details`` is requested with the 'ascii'
              format; or, for non-ASCII formats, if the graph is empty or not
              properly compiled, or a compiled task lacks a valid ``dag_node``.
          ImportError: If the `graphviz` package is not installed.
      """
      if format == "ascii":
          if channel_details:
              raise ValueError(
                  "Parameters 'channel_details' are"
                  " not compatible with 'ascii' format."
              )
          ascii_visualization_str = self._visualize_ascii()
          if view:
              print(ascii_visualization_str)
          # Returned regardless of `view`.
          return ascii_visualization_str

      try:
          import graphviz
      except ImportError as err:
          # Chain the original error so the underlying cause stays visible.
          raise ImportError(
              "Please install graphviz to visualize the compiled graph. "
              "You can install it by running `pip install graphviz`."
          ) from err

      from ray.dag import (
          ClassMethodNode,
          DAGNode,
          InputAttributeNode,
          InputNode,
          MultiOutputNode,
      )

      # Check that the DAG has been compiled
      if not hasattr(self, "idx_to_task") or not self.idx_to_task:
          raise ValueError(
              "The DAG must be compiled before calling 'visualize()'. "
              "Please call 'experimental_compile()' first."
          )

      # Check that each CompiledTask has a valid dag_node
      for idx, task in self.idx_to_task.items():
          if not hasattr(task, "dag_node") or not isinstance(task.dag_node, DAGNode):
              raise ValueError(
                  f"Task at index {idx} does not have a valid 'dag_node'. "
                  "Ensure that 'experimental_compile()' completed successfully."
              )

      # Dot file for debugging
      dot = graphviz.Digraph(name="compiled_graph", format=format)

      # Give every actor a unique color, colors between 24k -> 40k tested as readable
      # other colors may be too dark, especially when wrapping back around to 0.
      # The factory runs before the new key is inserted, so len() is the number
      # of actors already colored.
      actor_id_to_color = defaultdict(
          lambda: f"#{((len(actor_id_to_color) * 2000 + 24000) % 0xFFFFFF):06X}"
      )

      # Add nodes with task information
      for idx, task in self.idx_to_task.items():
          dag_node = task.dag_node

          # Initialize the label and attributes
          label = f"Task {idx}\n"
          shape = "oval"  # Default shape
          style = "filled"
          fillcolor = ""

          # Handle different types of dag_node
          if isinstance(dag_node, InputNode):
              label += "InputNode"
              shape = "rectangle"
              fillcolor = "lightblue"
          elif isinstance(dag_node, InputAttributeNode):
              label += f"InputAttributeNode[{dag_node.key}]"
              shape = "rectangle"
              fillcolor = "lightblue"
          elif isinstance(dag_node, MultiOutputNode):
              label += "MultiOutputNode"
              shape = "rectangle"
              fillcolor = "yellow"
          elif isinstance(dag_node, ClassMethodNode):
              if dag_node.is_class_method_call:
                  # Class Method Call Node
                  method_name = dag_node.get_method_name()
                  actor = dag_node._get_actor_handle()
                  if actor:
                      class_name = (
                          actor._ray_actor_creation_function_descriptor.class_name
                      )
                      actor_id = actor._actor_id.hex()
                      label += f"Actor: {class_name}\n"
                      label += f"ID: {actor_id[:6]}...\n"
                      label += f"Method: {method_name}"
                      fillcolor = actor_id_to_color[actor_id]
                  else:
                      label += f"Method: {method_name}"
                      fillcolor = "lightgreen"
                  shape = "oval"
              elif dag_node.is_class_method_output:
                  # Class Method Output Node
                  label += f"ClassMethodOutputNode[{dag_node.output_idx}]"
                  shape = "rectangle"
                  fillcolor = "orange"
              else:
                  # Unexpected ClassMethodNode
                  label += "ClassMethodNode"
                  shape = "diamond"
                  fillcolor = "red"
          else:
              # Unexpected node type
              label += type(dag_node).__name__
              shape = "diamond"
              fillcolor = "red"

          # Add the node to the graph with attributes
          dot.node(str(idx), label, shape=shape, style=style, fillcolor=fillcolor)

          # Prefix for edge labels: the node's type-hint class name (or
          # "UnknownType"), only computed when channel details are requested.
          channel_type_str = (
              (
                  type(dag_node.type_hint).__name__
                  if dag_node.type_hint
                  else "UnknownType"
              )
              + "\n"
              if channel_details
              else None
          )

          # This logic is built on the assumption that there will only be multiple
          # output channels if the task has multiple returns
          # case: task with one output
          if len(task.output_channels) == 1:
              for downstream_node in task.dag_node._downstream_nodes:
                  downstream_idx = self.dag_node_to_idx[downstream_node]
                  edge_label = None
                  if channel_details:
                      edge_label = channel_type_str
                      # Readers that are not ClassMethodNodes are described
                      # relative to the proxy actor's ID.
                      edge_label += self.get_channel_details(
                          task.output_channels[0],
                          (
                              downstream_node._get_actor_handle()._actor_id.hex()
                              if type(downstream_node) is ClassMethodNode
                              else self._proxy_actor._actor_id.hex()
                          ),
                      )
                  dot.edge(str(idx), str(downstream_idx), label=edge_label)
          # case: multi return, output channels connect to class method output nodes
          elif len(task.output_channels) > 1:
              assert len(task.output_idxs) == len(task.output_channels)
              for output_channel, downstream_idx in zip(
                  task.output_channels, task.output_node_idxs
              ):
                  edge_label = None
                  if channel_details:
                      edge_label = channel_type_str
                      # Described relative to this task's own actor; presumably
                      # the output nodes live on the same actor -- TODO confirm.
                      edge_label += self.get_channel_details(
                          output_channel,
                          task.dag_node._get_actor_handle()._actor_id.hex(),
                      )
                  dot.edge(str(idx), str(downstream_idx), label=edge_label)

          if type(task.dag_node) is InputAttributeNode:
              # Add an edge from the InputNode to this InputAttributeNode
              dot.edge(str(self.input_task_idx), str(idx))

      dot.render(filename, view=view)
      return dot.source
  2859. def _register_input_output_custom_serializer(self):
  2860. """
  2861. Register custom serializers for input, input attribute, and output nodes.
  2862. """
  2863. assert self.input_task_idx is not None
  2864. assert self.output_task_idx is not None
  2865. # Register custom serializers for input node.
  2866. input_task = self.idx_to_task[self.input_task_idx]
  2867. input_task.dag_node.type_hint.register_custom_serializer()
  2868. # Register custom serializers for input attribute nodes.
  2869. for input_attr_task_idx in self.input_attr_task_idxs:
  2870. input_attr_task = self.idx_to_task[input_attr_task_idx]
  2871. input_attr_task.dag_node.type_hint.register_custom_serializer()
  2872. # Register custom serializers for output nodes.
  2873. for output in self.idx_to_task[self.output_task_idx].args:
  2874. output.type_hint.register_custom_serializer()
  2875. def teardown(self, kill_actors: bool = False):
  2876. """
  2877. Teardown and cancel all actor tasks for this DAG. After this
  2878. function returns, the actors should be available to execute new tasks
  2879. or compile a new DAG.
  2880. Note: This method is automatically called when the CompiledDAG is destructed
  2881. or the script exits. However, this should be explicitly called before compiling
  2882. another graph on the same actors. Python may not garbage collect the
  2883. CompiledDAG object immediately when you may expect.
  2884. """
  2885. if self._is_teardown:
  2886. return
  2887. monitor = getattr(self, "_monitor", None)
  2888. if monitor is not None:
  2889. from ray.dag import DAGContext
  2890. ctx = DAGContext.get_current()
  2891. monitor.teardown(kill_actors=kill_actors)
  2892. monitor.join(timeout=ctx.teardown_timeout)
  2893. # We do not log a warning here if the thread is still alive because
  2894. # wait_teardown already logs upon teardown_timeout.
  2895. self._is_teardown = True
  2896. def __del__(self):
  2897. self.teardown()
  @DeveloperAPI
  def build_compiled_dag_from_ray_dag(
      dag: "ray.dag.DAGNode",
      submit_timeout: Optional[float] = None,
      buffer_size_bytes: Optional[int] = None,
      enable_asyncio: bool = False,
      max_inflight_executions: Optional[int] = None,
      max_buffered_results: Optional[int] = None,
      overlap_gpu_communication: Optional[bool] = None,
      default_communicator: Optional[Union[Communicator, str]] = "create",
  ) -> "CompiledDAG":
      """
      Build a CompiledDAG from a Ray DAG, compile it, and register it.

      Every node visited by ``traverse_and_apply`` from the root of ``dag`` is
      added to a new ``CompiledDAG``, which is then compiled. The result is
      stored in the module-level ``_compiled_dags`` mapping under its ID.

      Args:
          dag: A node of the Ray DAG to compile; the traversal starts from
              ``dag._find_root()``.
          submit_timeout, buffer_size_bytes, enable_asyncio,
          max_inflight_executions, max_buffered_results,
          overlap_gpu_communication, default_communicator: Forwarded
              positionally, in this order, to the ``CompiledDAG`` constructor.

      Returns:
          The compiled ``CompiledDAG``.
      """
      compiled_dag = CompiledDAG(
          submit_timeout,
          buffer_size_bytes,
          enable_asyncio,
          max_inflight_executions,
          max_buffered_results,
          overlap_gpu_communication,
          default_communicator,
      )

      def _register_node(node):
          # traverse_and_apply callback: add the node, return it unchanged.
          compiled_dag._add_node(node)
          return node

      dag._find_root().traverse_and_apply(_register_node)
      compiled_dag._get_or_compile()

      # Item assignment mutates the module-level registry; no `global` needed.
      _compiled_dags[compiled_dag.get_id()] = compiled_dag
      return compiled_dag