| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762
77727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067306830693070307130723073307430753076307730783079308030813082308330843085308630873088308930903091309230933094309530963097309830993100310131023103310431053106310731083109311031113112311331143115311631173118311931203121312231233124312531263127312831293130313131323133313431353136313731383139314031413142314331443145314631473148314931503151315231533154315531563157315831593160316131623163316431653166316731683169317031713172317331743175317631773178317931803181318231833184318531863187318831893190319131923193319431953196319731983199320032013202320332043205320632073208320932103211321232133214321532163217321832193220322132223223322432253226322732283229323032313232323332343235323632373238323932403241324232433244324532463247324832493250325132523253325432553256325732583259326032613262326332643265326632673268326932703271327232733274327532763
27732783279328032813282328332843285328632873288328932903291329232933294329532963297329832993300330133023303330433053306330733083309331033113312 |
- import asyncio
- import logging
- import threading
- import time
- import traceback
- import uuid
- import weakref
- from collections import defaultdict
- from contextlib import nullcontext
- from dataclasses import asdict, dataclass
- from typing import (
- Any,
- Dict,
- List,
- Optional,
- Set,
- Tuple,
- Union,
- )
- import ray
- import ray.exceptions
- from ray.dag.constants import (
- RAY_CGRAPH_ENABLE_NVTX_PROFILING,
- RAY_CGRAPH_ENABLE_TORCH_PROFILING,
- RAY_CGRAPH_VISUALIZE_SCHEDULE,
- )
- from ray.dag.dag_node_operation import (
- _build_dag_node_operation_graph,
- _DAGNodeOperation,
- _DAGNodeOperationType,
- _DAGOperationGraphNode,
- _extract_execution_schedule,
- _generate_actor_to_execution_schedule,
- _generate_overlapped_execution_schedule,
- _visualize_execution_schedule,
- )
- from ray.dag.dag_operation_future import DAGOperationFuture, GPUFuture, ResolvedFuture
- from ray.exceptions import (
- RayCgraphCapacityExceeded,
- RayChannelError,
- RayChannelTimeoutError,
- RayTaskError,
- )
- from ray.experimental.channel import (
- AwaitableBackgroundReader,
- AwaitableBackgroundWriter,
- ChannelContext,
- ChannelInterface,
- ChannelOutputType,
- CompiledDAGArgs,
- CompositeChannel,
- IntraProcessChannel,
- ReaderInterface,
- SynchronousReader,
- SynchronousWriter,
- WriterInterface,
- )
- from ray.experimental.channel.accelerator_context import AcceleratorContext
- from ray.experimental.channel.auto_transport_type import (
- AutoTransportType,
- TypeHintResolver,
- )
- from ray.experimental.channel.cached_channel import CachedChannel
- from ray.experimental.channel.communicator import Communicator
- from ray.experimental.channel.shared_memory_channel import (
- SharedMemoryType,
- )
- from ray.experimental.channel.torch_tensor_accelerator_channel import (
- _destroy_communicator,
- _init_communicator,
- )
- from ray.experimental.channel.torch_tensor_type import TorchTensorType
- from ray.experimental.compiled_dag_ref import (
- CompiledDAGFuture,
- CompiledDAGRef,
- _process_return_vals,
- )
- from ray.util.annotations import DeveloperAPI
- from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
# Module-level logger for compiled DAG construction and execution.
logger = logging.getLogger(__name__)

# Process-wide registry of every compiled DAG created in this process.
# Values are held weakly, so an entry disappears on its own once its
# compiled DAG is garbage collected; whatever is still registered is torn
# down at interpreter shutdown.
_compiled_dags = weakref.WeakValueDictionary()
# Relying on __del__ doesn't work well upon shutdown because
# the destructor order is not guaranteed. We call this function
# upon `ray.worker.shutdown` which is registered to atexit handler
# so that teardown is properly called before objects are destructed.
def _shutdown_all_compiled_dags():
    """Tear down every compiled DAG still registered in ``_compiled_dags``.

    Each live DAG is torn down with ``kill_actors=True``. Afterwards the
    registry is replaced by a fresh, empty ``WeakValueDictionary``.

    The live DAGs are snapshotted into a list before any teardown runs.
    A ``teardown`` that removes (or adds) registry entries therefore cannot
    invalidate the iteration and stop the remaining DAGs from being torn down.
    """
    global _compiled_dags
    # Snapshot first. Iterating the WeakValueDictionary lazily would raise
    # "dictionary changed size during iteration" if a teardown mutated it.
    # The list also holds strong references for the duration of the loop.
    live_dags = list(_compiled_dags.values())
    for compiled_dag in live_dags:
        # Kill DAG actors to avoid hanging during shutdown if the actor tasks
        # cannot be cancelled.
        compiled_dag.teardown(kill_actors=True)
    _compiled_dags = weakref.WeakValueDictionary()
def _check_unused_dag_input_attributes(
    output_node: "ray.dag.MultiOutputNode", input_attributes: Set[str]
) -> None:
    """
    Helper function to check that all input attributes are used in the DAG.

    For example, if the user creates an input attribute by calling
    InputNode()["x"], we ensure that there is a path from the
    InputAttributeNode corresponding to "x" to the DAG's output. If an
    input attribute is not used, throw an error.

    The check walks upstream from ``output_node`` through each node's
    ``_upstream_nodes``, visiting every node once. It collects the ``key`` of
    each reachable ``InputAttributeNode``.

    Args:
        output_node: The starting node for the traversal.
        input_attributes: A set of attributes accessed by the InputNode.

    Raises:
        ValueError: If any attribute in ``input_attributes`` is not reachable
            from ``output_node``. Attributes are listed in sorted order so the
            error message is the same from run to run.
    """
    from ray.dag import InputAttributeNode

    used_attributes = set()
    visited_nodes = set()
    stack: List["ray.dag.DAGNode"] = [output_node]
    while stack:
        current_node = stack.pop()
        if current_node in visited_nodes:
            continue
        visited_nodes.add(current_node)
        if isinstance(current_node, InputAttributeNode):
            used_attributes.add(current_node.key)
        stack.extend(current_node._upstream_nodes)

    unused_attributes = input_attributes - used_attributes
    if not unused_attributes:
        return

    # Set iteration order of str keys depends on hash randomization.
    # Sort the stringified keys so the message is reproducible. Keys are
    # stringified first, as before, so non-str keys are still handled.
    unused_attributes_str = ", ".join(sorted(str(key) for key in unused_attributes))
    input_attributes_str = ", ".join(sorted(str(key) for key in input_attributes))
    unused_phrase = "is unused" if len(unused_attributes) == 1 else "are unused"
    raise ValueError(
        "Compiled Graph expects input to be accessed "
        f"using all of attributes {input_attributes_str}, "
        f"but {unused_attributes_str} {unused_phrase}. "
        "Ensure all input attributes are used and contribute "
        "to the computation of the Compiled Graph output."
    )
@DeveloperAPI
def do_allocate_channel(
    self,
    reader_and_node_list: List[Tuple["ray.actor.ActorHandle", str]],
    typ: ChannelOutputType,
    driver_actor_id: Optional[str] = None,
) -> ChannelInterface:
    """Create the output channel described by ``typ`` from this process.

    The writer handed to ``typ.create_channel`` is the current actor's
    handle. It is ``None`` when the runtime context reports no current actor
    (a ``RuntimeError``), which is the case when called from the driver.

    Args:
        reader_and_node_list: A list of tuples, where each tuple contains a reader
            actor handle and the node ID where the actor is located.
        typ: The output type hint for the channel.
        driver_actor_id: If this channel is read by a driver and that driver is an
            actual actor, this will be the actor ID of that driver actor.

    Returns:
        The channel returned by ``typ.create_channel``.
    """
    try:
        current_writer: Optional["ray.actor.ActorHandle"] = (
            ray.get_runtime_context().current_actor
        )
    except RuntimeError:
        # No current actor handle: we are running on the driver.
        current_writer = None
    return typ.create_channel(
        current_writer,
        reader_and_node_list,
        driver_actor_id,
    )
@DeveloperAPI
def do_exec_tasks(
    self,
    tasks: List["ExecutableTask"],
    schedule: List[_DAGNodeOperation],
    overlap_gpu_communication: bool = False,
) -> None:
    """A generic actor method to begin executing the operations belonging to an
    actor.

    This loops over the schedule, executing each _DAGNodeOperation in the
    order specified. It returns once an operation reports completion
    (``exec_operation`` returns True); otherwise it only exits by raising.

    NVTX or torch profiling can be enabled with
    ``RAY_CGRAPH_ENABLE_NVTX_PROFILING`` / ``RAY_CGRAPH_ENABLE_TORCH_PROFILING``.
    The two are mutually exclusive. A profiler that was started is stopped
    on every exit path, including when an operation raises, so the collected
    profile is not lost on failure.

    Args:
        tasks: the executable tasks corresponding to the actor methods.
        schedule: A list of _DAGNodeOperation that should be executed in order.
        overlap_gpu_communication: Whether to overlap GPU communication with
            computation during DAG execution to improve performance.

    Raises:
        Exception: Any exception raised while preparing or executing tasks is
            logged and re-raised.
    """
    # None means "not started"; cleanup in `finally` only stops a profiler
    # that was actually started.
    nvtx_profile = None
    torch_profile = None
    try:
        for task in tasks:
            task.prepare(overlap_gpu_communication=overlap_gpu_communication)

        assert not (
            RAY_CGRAPH_ENABLE_NVTX_PROFILING and RAY_CGRAPH_ENABLE_TORCH_PROFILING
        ), "NVTX and torch profiling cannot be enabled at the same time."

        if RAY_CGRAPH_ENABLE_NVTX_PROFILING:
            try:
                import nvtx
            except ImportError as e:
                raise ImportError(
                    "Please install nvtx to enable nsight profiling. "
                    "You can install it by running `pip install nvtx`."
                ) from e
            profile = nvtx.Profile()
            profile.enable()
            nvtx_profile = profile

        if RAY_CGRAPH_ENABLE_TORCH_PROFILING:
            import torch

            profile = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                with_stack=True,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    "compiled_graph_torch_profiles"
                ),
            )
            profile.start()
            torch_profile = profile
            logger.info("Torch profiling started")

        done = False
        while not done:
            for operation in schedule:
                done = tasks[operation.exec_task_idx].exec_operation(
                    self, operation.type, overlap_gpu_communication
                )
                if done:
                    break
    except Exception:
        logger.exception("Compiled DAG task exited with exception")
        raise
    finally:
        if nvtx_profile is not None:
            nvtx_profile.disable()
        if torch_profile is not None:
            torch_profile.stop()
            logger.info("Torch profiling stopped")
@DeveloperAPI
def do_profile_tasks(
    self,
    tasks: List["ExecutableTask"],
    schedule: List[_DAGNodeOperation],
    overlap_gpu_communication: bool = False,
) -> None:
    """A generic actor method similar to `do_exec_tasks`, but with profiling enabled.

    For every executed operation, an ``_ExecutableTaskRecord`` is appended to
    ``self.__ray_cgraph_events``; the list is created on first use. Each
    record holds the actor class, name and ID, the method name, the bind
    index, the operation type, and ``time.perf_counter()`` start/end
    timestamps.

    Args:
        tasks: the executable tasks corresponding to the actor methods.
        schedule: A list of _DAGNodeOperation that should be executed in order.
        overlap_gpu_communication: Whether to overlap GPU communication with
            computation during DAG execution to improve performance.

    Raises:
        Exception: Any exception raised while preparing or executing tasks is
            logged and re-raised.
    """
    try:
        for task in tasks:
            task.prepare(overlap_gpu_communication=overlap_gpu_communication)

        # This is a module-level function, so no name mangling applies; the
        # attribute is literally named "__ray_cgraph_events".
        if not hasattr(self, "__ray_cgraph_events"):
            self.__ray_cgraph_events = []

        # The actor's identity is not expected to change while this loop
        # runs. Resolve it once instead of querying the runtime context twice
        # per executed operation.
        runtime_context = ray.get_runtime_context()
        actor_classname = self.__class__.__name__
        actor_name = runtime_context.get_actor_name()
        actor_id = runtime_context.get_actor_id()

        done = False
        while not done:
            for operation in schedule:
                # Look the task up outside the timed region.
                task = tasks[operation.exec_task_idx]
                start_t = time.perf_counter()
                done = task.exec_operation(
                    self, operation.type, overlap_gpu_communication
                )
                end_t = time.perf_counter()
                self.__ray_cgraph_events.append(
                    _ExecutableTaskRecord(
                        actor_classname=actor_classname,
                        actor_name=actor_name,
                        actor_id=actor_id,
                        method_name=task.method_name,
                        bind_index=task.bind_index,
                        operation=operation.type.value,
                        start_t=start_t,
                        end_t=end_t,
                    )
                )
                if done:
                    break
    except Exception:
        logger.exception("Compiled DAG task exited with exception")
        raise
@DeveloperAPI
def do_cancel_executable_tasks(self, tasks: List["ExecutableTask"]) -> None:
    """Release the resources held by ``tasks``.

    The work is done in two full passes over ``tasks``. The first pass
    destroys every task's pending CUDA event. Only then does the second pass
    cancel each task.
    """
    # CUDA events must be destroyed before any other CUDA resource, so the
    # whole destroy pass finishes before cancellation starts.
    for executable_task in tasks:
        executable_task.destroy_cuda_event()
    for executable_task in tasks:
        executable_task.cancel()
def _wrap_exception(exc):
    """Wrap ``exc`` in a ``RayTaskError`` attributed to ``do_exec_tasks``.

    The full traceback of ``exc`` is formatted and passed through
    ``ray._private.utils.format_error_message`` with ``task_exception=True``.
    The result becomes the error's ``traceback_str``, and ``exc`` is
    recorded as the cause.
    """
    formatted_tb = "".join(
        traceback.format_exception(type(exc), exc, exc.__traceback__)
    )
    error_text = ray._private.utils.format_error_message(
        formatted_tb,
        task_exception=True,
    )
    return RayTaskError(
        function_name="do_exec_tasks",
        traceback_str=error_text,
        cause=exc,
    )
def _get_comm_group_id(type_hint: ChannelOutputType) -> Optional[str]:
    """
    Return the communicator group ID carried by ``type_hint``, if any.

    Args:
        type_hint: The type hint of the channel.

    Returns:
        ``type_hint.communicator_id`` when the hint requires an accelerator,
        otherwise None.
    """
    if not type_hint.requires_accelerator():
        return None
    # Accelerator-requiring hints are expected to be TorchTensorType here;
    # the assert enforces that before `communicator_id` is read.
    assert isinstance(type_hint, TorchTensorType)
    return type_hint.communicator_id
def _device_context_manager():
    """
    Pick the context manager under which READ/WRITE communication runs.

    Returns ``nullcontext()`` when torch is not available in the current
    channel context. It also returns ``nullcontext()`` when the channel's
    torch device has type "cuda" but ``torch.cuda.is_available()`` is False.
    Otherwise it returns the accelerator's device context for the channel's
    torch device.
    """
    channel_ctx = ChannelContext.get_current()
    if not channel_ctx.torch_available:
        return nullcontext()

    import torch

    from ray.experimental.channel.accelerator_context import AcceleratorContext

    target_device = channel_ctx.torch_device
    cuda_device_without_cuda = (
        target_device.type == "cuda" and not torch.cuda.is_available()
    )
    if cuda_device_without_cuda:
        # With mocked NCCL we may see a "cuda" device although CUDA is not
        # available; entering torch's CUDA device context would then raise.
        # TODO(rui): consider better mocking NCCL to support device context.
        return nullcontext()
    return AcceleratorContext.get().get_device_context(target_device)
@DeveloperAPI
class CompiledTask:
    """Wraps the normal Ray DAGNode with some metadata."""

    def __init__(self, idx: int, dag_node: "ray.dag.DAGNode"):
        """
        Args:
            idx: A unique index into the original DAG.
            dag_node: The original DAG node created by the user.
        """
        self.idx = idx
        self.dag_node = dag_node

        # Dict from task index to actor handle for immediate downstream tasks.
        self.downstream_task_idxs: Dict[int, "ray.actor.ActorHandle"] = {}

        # Case 1: The task represents a ClassMethodNode.
        #
        # Multiple return values are written to separate `output_channels`.
        # `output_idxs` represents the tuple index of the output value for
        # multiple returns in a tuple. If an output index is None, it means
        # the complete return value is written to the output channel.
        # Otherwise, the return value is a tuple and the index is used
        # to extract the value to be written to the output channel.
        #
        # Case 2: The task represents an InputNode.
        #
        # `output_idxs` can be an integer or a string to retrieve the
        # corresponding value from `args` or `kwargs` in the DAG's input.
        self.output_channels: List[ChannelInterface] = []
        self.output_idxs: List[Optional[Union[int, str]]] = []

        # The DAGNodes that are arguments to this task.
        # This is used for lazy resolution of the arguments' type hints.
        self.arg_nodes: List["ray.dag.DAGNode"] = []

        # idxs of possible ClassMethodOutputNodes if they exist, used for
        # visualization.
        self.output_node_idxs: List[int] = []

    @property
    def args(self) -> Tuple[Any, ...]:
        """Positional arguments of the wrapped DAG node, from ``get_args()``."""
        return self.dag_node.get_args()

    @property
    def kwargs(self) -> Dict[str, Any]:
        """Keyword arguments of the wrapped DAG node, from ``get_kwargs()``."""
        return self.dag_node.get_kwargs()

    @property
    def num_readers(self) -> int:
        """Number of immediate downstream tasks (entries in ``downstream_task_idxs``)."""
        return len(self.downstream_task_idxs)

    @property
    def arg_type_hints(self) -> List["ChannelOutputType"]:
        """Type hints of the argument nodes, in ``arg_nodes`` order."""
        return [arg_node.type_hint for arg_node in self.arg_nodes]

    def __str__(self) -> str:
        return (
            "\n"
            f"Node: {self.dag_node}\n"
            f"Arguments: {self.args}\n"
            f"Output: {self.output_channels}\n"
        )
class _ExecutableTaskInput:
    """One positional input of an ExecutableTask.

    Args:
        input_variant: Either a ChannelInterface, meaning the value is still
            unresolved and must be taken from that channel's read result, or
            an already-resolved value that is passed through unchanged.
        channel_idx: When ``input_variant`` is a channel, the index of that
            channel's result within the input-channel results; otherwise None.
    """

    def __init__(
        self,
        input_variant: Union[ChannelInterface, Any],
        channel_idx: Optional[int],
    ):
        self.input_variant = input_variant
        self.channel_idx = channel_idx

    def resolve(self, channel_results: Any) -> Any:
        """
        Produce the concrete value for this input.

        Args:
            channel_results: The results from reading the input channels,
                indexed by ``channel_idx``.
        """
        if not isinstance(self.input_variant, ChannelInterface):
            return self.input_variant
        return channel_results[self.channel_idx]
- @DeveloperAPI
- class ExecutableTask:
- """A task that can be executed in a compiled DAG, and it
- corresponds to an actor method.
- """
def __init__(
    self,
    task: "CompiledTask",
    resolved_args: List[Any],
    resolved_kwargs: Dict[str, Any],
):
    """
    Args:
        task: The CompiledTask that this ExecutableTask corresponds to.
        resolved_args: The arguments to the method. Arguments that are
            not Channels will get passed through to the actor method.
            If the argument is a channel, it will be replaced by the
            value read from the channel before the method executes.
        resolved_kwargs: The keyword arguments to the method. Currently, we
            do not support binding kwargs to other DAG nodes, so the values
            of the dictionary cannot be Channels.
    """
    from ray.dag import CollectiveOutputNode

    dag_node = task.dag_node
    self.method_name = dag_node.get_method_name()
    self.bind_index = dag_node._get_bind_index()
    self.output_channels = task.output_channels
    self.output_idxs = task.output_idxs
    self.input_type_hints: List[ChannelOutputType] = task.arg_type_hints
    self.output_type_hint: ChannelOutputType = dag_node.type_hint

    # The accelerator collective operation; only a CollectiveOutputNode
    # carries one.
    self.collective_op: Optional["ray.dag.CollectiveOperation"] = (
        dag_node.collective_op
        if isinstance(dag_node, CollectiveOutputNode)
        else None
    )

    self.input_channels: List[ChannelInterface] = []
    self.task_inputs: List[_ExecutableTaskInput] = []
    self.resolved_kwargs: Dict[str, Any] = resolved_kwargs
    # A unique index which can be used to index into `idx_to_task` to get
    # the corresponding task.
    self.task_idx = task.idx

    # Each distinct channel appears once in `input_channels`. Every
    # occurrence of that channel in the args maps to its index there.
    channel_slot: Dict[ChannelInterface, int] = {}
    for arg in resolved_args:
        if not isinstance(arg, ChannelInterface):
            self.task_inputs.append(_ExecutableTaskInput(arg, None))
            continue
        slot = channel_slot.get(arg)
        if slot is None:
            slot = len(self.input_channels)
            self.input_channels.append(arg)
            channel_slot[arg] = slot
        self.task_inputs.append(_ExecutableTaskInput(arg, slot))

    # Binding kwargs to other DAG nodes is not supported, so no kwarg value
    # may be a channel.
    for kwarg_value in self.resolved_kwargs.values():
        assert not isinstance(kwarg_value, ChannelInterface)

    # Reads input data from upstream DAG nodes.
    self.input_reader: ReaderInterface = SynchronousReader(self.input_channels)
    # Writes output data to downstream DAG nodes.
    self.output_writer: WriterInterface = SynchronousWriter(
        self.output_channels, self.output_idxs
    )

    # Pending result of the last READ or COMPUTE; `wait()` must be called on
    # it to get the value. READ feeds COMPUTE, and COMPUTE feeds WRITE.
    self._intermediate_future: Optional[DAGOperationFuture] = None
def cancel(self):
    """
    Close the input reader and then the output writer.

    What closing does depends on the channel type; typically it releases
    the resources used by the channels.
    """
    for endpoint in (self.input_reader, self.output_writer):
        endpoint.close()
def destroy_cuda_event(self):
    """
    Drop this task's not-yet-waited GPU future, if one exists, from the
    channel context cache and destroy its CUDA event.

    Delegates to ``GPUFuture.remove_gpu_future``, keyed by this task's
    ``task_idx``.
    """
    owner_task_idx = self.task_idx
    GPUFuture.remove_gpu_future(owner_task_idx)
def prepare(self, overlap_gpu_communication: bool = False):
    """
    Get the task ready to run; `exec_operation` may only be called afterwards.

    Steps, in order:
    - Register the custom serializer of every input type hint, then of the
      output type hint.
    - Start the input reader and the output writer.
    - Choose the send/recv stream contexts.

    Both stream contexts stay ``nullcontext()`` unless
    ``overlap_gpu_communication`` is set and the corresponding type hint
    requires an accelerator.

    Args:
        overlap_gpu_communication: Whether to overlap GPU communication with
            computation during DAG execution to improve performance
    """
    for input_hint in self.input_type_hints:
        input_hint.register_custom_serializer()
    self.output_type_hint.register_custom_serializer()
    self.input_reader.start()
    self.output_writer.start()

    # Stream context types differ between accelerators, so these carry no
    # type hint; nullcontext() means "no dedicated stream".
    self._send_stream = nullcontext()
    self._recv_stream = nullcontext()
    if not overlap_gpu_communication:
        return

    def _communicator_for(type_hint):
        # Look up the communicator registered for this hint's group ID.
        group_id = _get_comm_group_id(type_hint)
        communicator = ChannelContext.get_current().communicators.get(group_id)
        assert communicator is not None
        return communicator

    if self.output_type_hint.requires_accelerator():
        self._send_stream = _communicator_for(self.output_type_hint).send_stream

    for input_hint in self.input_type_hints:
        if not input_hint.requires_accelerator():
            continue
        communicator = _communicator_for(input_hint)
        if not isinstance(self._recv_stream, nullcontext):
            assert self._recv_stream == communicator.recv_stream, (
                "Currently all torch tensor input channels of a "
                "Compiled Graph task should use the same recv cuda stream."
            )
        self._recv_stream = communicator.recv_stream
def wrap_and_set_intermediate_future(
    self, val: Any, wrap_in_gpu_future: bool
) -> None:
    """
    Store the result of a READ or COMPUTE operation as the intermediate future.

    The slot must be empty on entry. The value is wrapped as
    ``GPUFuture(val, self.task_idx)`` when ``wrap_in_gpu_future`` is True,
    and as ``ResolvedFuture(val)`` otherwise.

    Args:
        val: The value to wrap in a future.
        wrap_in_gpu_future: Whether to wrap the value in a GPUFuture.
    """
    assert self._intermediate_future is None
    self._intermediate_future = (
        GPUFuture(val, self.task_idx) if wrap_in_gpu_future else ResolvedFuture(val)
    )
def reset_and_wait_intermediate_future(self) -> Any:
    """
    Take the pending intermediate future, clear it, and return its result.

    The wait does not block the CPU because:
    - If the future is a ResolvedFuture, the result is immediately returned.
    - If the future is a GPUFuture, the result is only waited by the current
      CUDA stream, and the CPU is not blocked.

    Returns:
        The result of a READ or COMPUTE operation from the intermediate future.

    Raises:
        AssertionError: If no intermediate future is pending, i.e. the
            READ -> COMPUTE -> WRITE ordering was violated.
    """
    future = self._intermediate_future
    # Same invariant style as `wrap_and_set_intermediate_future`: a result
    # must have been produced before it is consumed. Without this check the
    # failure would surface as an opaque AttributeError on `None.wait()`.
    assert future is not None, (
        "No intermediate future is pending; the operation producing it "
        "(READ or COMPUTE) must run first."
    )
    self._intermediate_future = None
    return future.wait()
- def _read(self, overlap_gpu_communication: bool) -> bool:
- """
- Read input data from upstream DAG nodes and cache the intermediate result.
- Args:
- overlap_gpu_communication: Whether to overlap GPU communication with
- computation during DAG execution to improve performance.
- Returns:
- True if system error occurs and exit the loop; otherwise, False.
- """
- assert self._intermediate_future is None
- exit = False
- try:
- input_data = self.input_reader.read()
- # When overlap_gpu_communication is enabled, wrap the result in
- # a GPUFuture so that this read operation (communication) can
- # be overlapped with computation.
- self.wrap_and_set_intermediate_future(
- input_data,
- wrap_in_gpu_future=overlap_gpu_communication,
- )
- except RayChannelError:
- # Channel closed. Exit the loop.
- exit = True
- return exit
def _compute(
    self,
    overlap_gpu_communication: bool,
    class_handle,
) -> bool:
    """
    Consume the READ result, run this task's computation, and stash the
    COMPUTE result as the new intermediate future.

    The caller must ensure the previous operation on this task was READ, so
    that the pending intermediate future holds this task's inputs.

    Args:
        overlap_gpu_communication: Whether to overlap GPU communication with
            computation during DAG execution to improve performance. When
            True the result is wrapped in a GPUFuture.
        class_handle: An instance of the class to which the actor belongs. For
            example, the type of `class_handle` is <class 'xxxx.Worker'> if the
            actor belongs to the `class Worker` class.

    Returns:
        False. Application-level errors are forwarded downstream as values
        instead of stopping the loop.
    """
    input_data = self.reset_and_wait_intermediate_future()
    try:
        _process_return_vals(input_data, return_single_output=False)
    except Exception as upstream_exc:
        # A previous task raised an application-level exception: forward it
        # as this task's result and skip the actual method call. It is not
        # re-wrapped in a RayTaskError; the previous task already did that.
        output_val = upstream_exc
    else:
        call_args = [
            task_input.resolve(input_data) for task_input in self.task_inputs
        ]
        # Either an accelerator collective operation or an actor method.
        method = (
            self.collective_op.execute
            if self.collective_op is not None
            else getattr(class_handle, self.method_name)
        )
        try:
            output_val = method(*call_args, **self.resolved_kwargs)
        except Exception as exc:
            output_val = _wrap_exception(exc)
    # Wrapping in a GPUFuture (when overlap is on) lets this COMPUTE overlap
    # with communication.
    self.wrap_and_set_intermediate_future(
        output_val, wrap_in_gpu_future=overlap_gpu_communication
    )
    return False
- def _write(self) -> bool:
- """
- Retrieve the intermediate result from the COMPUTE operation and write to its
- downstream DAG nodes. The caller must ensure that the last operation executed
- is COMPUTE so that the function retrieves the correct intermediate result.
- Returns:
- True if system error occurs and exit the loop; otherwise, False.
- """
- output_val = self.reset_and_wait_intermediate_future()
- exit = False
- try:
- self.output_writer.write(output_val)
- except RayChannelError:
- # Channel closed. Exit the loop.
- exit = True
- return exit
def exec_operation(
    self,
    class_handle,
    op_type: _DAGNodeOperationType,
    overlap_gpu_communication: bool = False,
) -> bool:
    """
    Execute one operation of this task.

    An ExecutableTask corresponds to a DAGNode. It consists of three
    operations: READ, COMPUTE, and WRITE, which should be executed in
    order to ensure that each operation can read the correct intermediate
    result.

    Args:
        class_handle: The handle of the class to which the actor belongs.
        op_type: The type of the operation. Possible types are READ,
            COMPUTE, and WRITE.
        overlap_gpu_communication: Whether to overlap GPU communication with
            computation during DAG execution to improve performance.

    Returns:
        True if the next operation should not be executed; otherwise, False.

    Raises:
        ValueError: If ``op_type`` is not READ, COMPUTE, or WRITE.
    """
    if op_type == _DAGNodeOperationType.READ:
        # `self._recv_stream` stays a nullcontext unless `prepare` installed
        # a communicator recv stream (only when overlap is enabled).
        with _device_context_manager(), self._recv_stream:
            return self._read(overlap_gpu_communication)
    if op_type == _DAGNodeOperationType.COMPUTE:
        return self._compute(overlap_gpu_communication, class_handle)
    if op_type == _DAGNodeOperationType.WRITE:
        with _device_context_manager(), self._send_stream:
            return self._write()
    # Fail loudly instead of implicitly returning None (which callers would
    # treat as "continue") for an operation type this task cannot run.
    raise ValueError(f"Unsupported DAG node operation type: {op_type}")
- @dataclass
- class _ExecutableTaskRecord:
- actor_classname: str
- actor_name: str
- actor_id: str
- method_name: str
- bind_index: int
- operation: str
- start_t: float
- end_t: float
- def to_dict(self):
- return asdict(self)
- @DeveloperAPI
- class CompiledDAG:
- """Experimental class for accelerated execution.
- This class should not be called directly. Instead, create
- a ray.dag and call experimental_compile().
- See REP https://github.com/ray-project/enhancements/pull/48 for more
- information.
- """
@ray.remote(num_cpus=0)
class DAGDriverProxyActor:
    """
    Zero-CPU placeholder actor that runs on the driver's node.

    Remote functions cannot be invoked on the driver itself. However, when
    the driver is a reader, the output writer must run remote functions on
    the driver's node: first to create a reader ref there, and later to
    create a larger reader ref if the channel backing store is resized.

    A Compiled Graph creates an actor from this class when the DAG is
    initialized, on the same node as the driver. The class intentionally has
    no methods; it only serves as a target on the driver node through which
    the output writer can invoke remote functions.
    """
def __init__(
    self,
    submit_timeout: Optional[float] = None,
    buffer_size_bytes: Optional[int] = None,
    enable_asyncio: bool = False,
    max_inflight_executions: Optional[int] = None,
    max_buffered_results: Optional[int] = None,
    overlap_gpu_communication: Optional[bool] = None,
    default_communicator: Optional[Union[Communicator, str]] = "create",
):
    """
    Args:
        submit_timeout: The maximum time in seconds to wait for execute() calls.
            None means using default timeout (DAGContext.submit_timeout),
            0 means immediate timeout (immediate success or timeout without
            blocking), -1 means infinite timeout (block indefinitely).
        buffer_size_bytes: The initial buffer size in bytes for messages
            that can be passed between tasks in the DAG. The buffers will
            be automatically resized if larger messages are written to the
            channel. None means using DAGContext.buffer_size_bytes.
        enable_asyncio: Whether to enable asyncio. If enabled, caller must
            be running in an event loop and must use `execute_async` to
            invoke the DAG. Otherwise, the caller should use `execute` to
            invoke the DAG.
        max_inflight_executions: The maximum number of in-flight executions that
            can be submitted via `execute` or `execute_async` before consuming
            the output using `ray.get()`. If the caller submits more executions,
            `RayCgraphCapacityExceeded` is raised. None means using
            DAGContext.max_inflight_executions.
        max_buffered_results: The maximum number of results that can be
            buffered at the driver. If more results are buffered,
            `RayCgraphCapacityExceeded` is raised. Note that when the result
            corresponding to an execution is retrieved (by calling `ray.get()`
            on a `CompiledDAGRef` or awaiting a `CompiledDAGFuture`), results
            corresponding to earlier executions that have not been retrieved
            yet are buffered. None means using DAGContext.max_buffered_results.
        overlap_gpu_communication: (experimental) Whether to overlap GPU
            communication with computation during DAG execution. If True, the
            communication and computation can be overlapped, which can improve
            the performance of the DAG execution. If None, the default value
            (DAGContext.overlap_gpu_communication) will be used.
        default_communicator: The default communicator to use to transfer
            tensors. Three types of values are valid. (1) Communicator:
            For p2p operations, this is the default communicator
            to use for nodes annotated with `with_tensor_transport()` and when
            shared memory is not the desired option (e.g., when transport="accelerator",
            or when transport="auto" for communication between two different GPUs).
            For collective operations, this is the default communicator to use
            when a custom communicator is not specified.
            (2) "create": for each collective operation without a custom communicator
            specified, a communicator is created and initialized on its involved actors,
            or an already created communicator is reused if the set of actors is the same.
            For all p2p operations without a custom communicator specified, it reuses
            an already created collective communicator if the p2p actors are a subset.
            Otherwise, a new communicator is created.
            (3) None: a ValueError will be thrown if a custom communicator is not specified.

    Raises:
        ValueError: If `default_communicator` is a string other than
            "create" or is neither None nor a Communicator, or if the
            resolved `buffer_size_bytes` is not a positive integer.
    """
    from ray.dag import DAGContext

    ctx = DAGContext.get_current()

    self._enable_asyncio: bool = enable_asyncio
    self._fut_queue = asyncio.Queue()
    self._max_inflight_executions = max_inflight_executions
    if self._max_inflight_executions is None:
        self._max_inflight_executions = ctx.max_inflight_executions
    self._max_buffered_results = max_buffered_results
    if self._max_buffered_results is None:
        self._max_buffered_results = ctx.max_buffered_results
    self._dag_id = uuid.uuid4().hex
    self._submit_timeout: Optional[float] = submit_timeout
    if self._submit_timeout is None:
        self._submit_timeout = ctx.submit_timeout
    self._get_timeout: Optional[float] = ctx.get_timeout
    self._buffer_size_bytes: Optional[int] = buffer_size_bytes
    if self._buffer_size_bytes is None:
        self._buffer_size_bytes = ctx.buffer_size_bytes
    self._overlap_gpu_communication: Optional[bool] = overlap_gpu_communication
    if self._overlap_gpu_communication is None:
        self._overlap_gpu_communication = ctx.overlap_gpu_communication

    self._create_default_communicator = False
    if isinstance(default_communicator, str):
        if default_communicator == "create":
            self._create_default_communicator = True
            default_communicator = None
        else:
            raise ValueError(
                "The only allowed string for default_communicator is 'create', "
                f"got {default_communicator}"
            )
    elif default_communicator is not None and not isinstance(
        default_communicator, Communicator
    ):
        raise ValueError(
            "The default_communicator must be None, a string, or a Communicator, "
            f"got {type(default_communicator)}"
        )
    self._default_communicator: Optional[Communicator] = default_communicator

    # Dict from passed-in communicator to set of type hints that refer to it.
    self._communicator_to_type_hints: Dict[
        Communicator,
        Set["ray.experimental.channel.torch_tensor_type.TorchTensorType"],
    ] = defaultdict(set)
    # Dict from set of actors to created communicator ID.
    # These communicators are created by Compiled Graph, rather than passed in.
    # Communicators are only created when self._create_default_communicator is True.
    self._actors_to_created_communicator_id: Dict[
        Tuple["ray.actor.ActorHandle"], str
    ] = {}
    # Set of actors involved in P2P communication using an unresolved communicator.
    self._p2p_actors_with_unresolved_communicators: Set[
        "ray.actor.ActorHandle"
    ] = set()
    # Set of DAG nodes involved in P2P communication using an unresolved communicator.
    self._p2p_dag_nodes_with_unresolved_communicators: Set[
        "ray.dag.DAGNode"
    ] = set()
    # Set of collective operations using an unresolved communicator.
    self._collective_ops_with_unresolved_communicators: Set[
        "ray.dag.collective_node._CollectiveOperation"
    ] = set()

    # Validate before `_buffer_size_bytes` is used to build the default type
    # hint, so an invalid value always surfaces as this ValueError.
    if not isinstance(self._buffer_size_bytes, int) or self._buffer_size_bytes <= 0:
        raise ValueError(
            "`buffer_size_bytes` must be a positive integer, found "
            f"{self._buffer_size_bytes}"
        )

    self._default_type_hint: ChannelOutputType = SharedMemoryType(
        buffer_size_bytes=self._buffer_size_bytes,
        # We conservatively set num_shm_buffers to _max_inflight_executions.
        # It means that the DAG can be underutilized, but it guarantees there's
        # no false positive timeouts.
        num_shm_buffers=self._max_inflight_executions,
    )

    # Used to ensure that the future returned to the
    # caller corresponds to the correct DAG output. I.e.
    # order of futures added to fut_queue should match the
    # order of inputs written to the DAG.
    self._dag_submission_lock = asyncio.Lock()

    # idx -> CompiledTask.
    self.idx_to_task: Dict[int, "CompiledTask"] = {}
    # DAGNode -> idx.
    self.dag_node_to_idx: Dict["ray.dag.DAGNode", int] = {}
    # idx counter.
    self.counter: int = 0

    # Attributes that are set during preprocessing.
    # Preprocessing identifies the input node and output node.
    self.input_task_idx: Optional[int] = None
    self.output_task_idx: Optional[int] = None
    # List of task indices that are input attribute nodes.
    self.input_attr_task_idxs: List[int] = []
    # Denotes whether execute/execute_async returns a list of refs/futures.
    self._returns_list: bool = False
    # Number of expected positional args and kwargs that may be passed to
    # dag.execute.
    self._input_num_positional_args: Optional[int] = None
    self._input_kwargs: Optional[Tuple[str, ...]] = None

    # Cached attributes that are set during compilation.
    self.dag_input_channels: Optional[List[ChannelInterface]] = None
    self.dag_output_channels: Optional[List[ChannelInterface]] = None
    self._dag_submitter: Optional[WriterInterface] = None
    self._dag_output_fetcher: Optional[ReaderInterface] = None

    # ObjectRef for each worker's task. The task is an infinite loop that
    # repeatedly executes the method specified in the DAG.
    self.worker_task_refs: Dict["ray.actor.ActorHandle", "ray.ObjectRef"] = {}
    self.actor_to_tasks: Dict[
        "ray.actor.ActorHandle", List["CompiledTask"]
    ] = defaultdict(list)
    # Mapping from actor handle to its GPU IDs.
    # This is used for type hint resolution for with_tensor_transport("auto").
    self.actor_to_gpu_ids: Dict["ray.actor.ActorHandle", List[str]] = {}
    self.actor_to_executable_tasks: Dict[
        "ray.actor.ActorHandle", List["ExecutableTask"]
    ] = {}
    # Mapping from the actor handle to the execution schedule which is a list
    # of operations to be executed.
    self.actor_to_execution_schedule: Dict[
        "ray.actor.ActorHandle", List[_DAGNodeOperation]
    ] = defaultdict(list)
    # Mapping from the actor handle to the node ID that the actor is on.
    # A None actor handle means the actor is the driver.
    self.actor_to_node_id: Dict[Optional["ray.actor.ActorHandle"], str] = {}

    # The index of the current execution. It is incremented each time
    # the DAG is executed.
    self._execution_index: int = -1
    # The maximum index of finished executions.
    # All results with higher indexes have not been generated yet.
    self._max_finished_execution_index: int = -1
    # execution_index -> {channel_index -> result}
    self._result_buffer: Dict[int, Dict[int, Any]] = defaultdict(dict)
    # channel to possible inner channel
    self._channel_dict: Dict[ChannelInterface, ChannelInterface] = {}

    def _create_proxy_actor() -> "ray.actor.ActorHandle":
        # Creates the driver actor on the same node as the driver.
        #
        # To support the driver as a reader, the output writer needs to be able to
        # invoke remote functions on the driver (e.g., to create the reader ref, to
        # create a reader ref for a larger object when the channel backing store is
        # resized, etc.). The driver actor serves as a way for the output writer
        # to invoke remote functions on the driver node.
        return CompiledDAG.DAGDriverProxyActor.options(
            scheduling_strategy=NodeAffinitySchedulingStrategy(
                ray.get_runtime_context().get_node_id(), soft=False
            )
        ).remote()

    self._proxy_actor = _create_proxy_actor()
    # Set to True when `teardown` API is called.
    self._is_teardown = False
    # Execution index to set of channel indices for CompiledDAGRefs
    # or CompiledDAGFuture whose destructor has been called. A "None"
    # channel index means there is only one channel, and its destructor
    # has been called.
    self._destructed_ref_idxs: Dict[int, Set[Optional[int]]] = dict()
    # Execution index to set of channel indices for CompiledDAGRefs
    # or CompiledDAGFuture whose get() has been called. A "None"
    # channel index means there is only one channel, and its get()
    # has been called.
    self._got_ref_idxs: Dict[int, Set[Optional[int]]] = dict()
@property
def is_teardown(self) -> bool:
    """Whether the `teardown` API has been called on this compiled DAG."""
    torn_down = self._is_teardown
    return torn_down
def get_id(self) -> str:
    """Return the unique ID of this compiled DAG."""
    dag_id = self._dag_id
    return dag_id
- def __str__(self) -> str:
- return f"CompiledDAG({self._dag_id})"
def _add_node(self, node: "ray.dag.DAGNode") -> None:
    """Register ``node`` as a new CompiledTask under the next free index."""
    new_idx = self.counter
    self.idx_to_task[new_idx] = CompiledTask(new_idx, node)
    self.dag_node_to_idx[node] = new_idx
    self.counter = new_idx + 1
def _preprocess(self) -> None:
    """Before compiling, preprocess the DAG to build an index from task to
    upstream and downstream tasks, and to set the input and output node(s)
    of the DAG.

    Along the way it:
    - wraps a single output node in a MultiOutputNode,
    - fills in default type hints,
    - tracks which actors/nodes need a communicator,
    - resolves `with_tensor_transport("auto")` type hints,
    - initializes communicators,
    - records the positional args / kwargs that `execute()` expects.

    Note: this is not fully idempotent. `input_task_idx`, `output_task_idx`
    and `input_attr_task_idxs` are reset on entry, but each task's
    `arg_nodes` list is appended to again on every call.

    Raises:
        NotImplementedError: If the DAG has no InputNode or contains a
            FunctionNode (non-actor task).
        ValueError: For unsupported node types, methods bound to actors not
            created with `Actor.remote()`, DAG inputs that require an
            accelerator, DAG nodes bound as kwargs, invalid InputNode keys,
            mixing direct and indexed InputNode usage, or leaf nodes.
    """
    from ray.dag import (
        ClassMethodNode,
        CollectiveOutputNode,
        DAGNode,
        FunctionNode,
        InputAttributeNode,
        InputNode,
        MultiOutputNode,
    )

    self.input_task_idx, self.output_task_idx = None, None
    # Reset so re-running preprocessing doesn't record the same input
    # attribute tasks twice.
    self.input_attr_task_idxs = []

    input_attributes: Set[str] = set()
    # Find the input node and input attribute nodes in the DAG.
    for idx, task in self.idx_to_task.items():
        if isinstance(task.dag_node, InputNode):
            assert self.input_task_idx is None, "More than one InputNode found"
            self.input_task_idx = idx
            # handle_unused_attributes:
            # Save input attributes in a set.
            input_node = task.dag_node
            input_attributes.update(input_node.input_attribute_nodes.keys())
        elif isinstance(task.dag_node, InputAttributeNode):
            self.input_attr_task_idxs.append(idx)

    # Find the (multi-)output node to the DAG.
    for idx, task in self.idx_to_task.items():
        if idx == self.input_task_idx or isinstance(
            task.dag_node, InputAttributeNode
        ):
            continue
        if (
            len(task.downstream_task_idxs) == 0
            and task.dag_node.is_cgraph_output_node
        ):
            assert self.output_task_idx is None, "More than one output node found"
            self.output_task_idx = idx

    assert self.output_task_idx is not None
    output_node = self.idx_to_task[self.output_task_idx].dag_node
    # Add an MultiOutputNode to the end of the DAG if it's not already there.
    if not isinstance(output_node, MultiOutputNode):
        output_node = MultiOutputNode([output_node])
        self._add_node(output_node)
        self.output_task_idx = self.dag_node_to_idx[output_node]
    else:
        self._returns_list = True

    # TODO: Support no-input DAGs (use an empty object to signal).
    if self.input_task_idx is None:
        raise NotImplementedError(
            "Compiled DAGs currently require exactly one InputNode"
        )

    # Whether the DAG binds directly to the InputNode(), versus binding to
    # a positional arg or kwarg of the input. For example, a.foo.bind(inp)
    # instead of a.foo.bind(inp[0]) or a.foo.bind(inp.key).
    direct_input: Optional[bool] = None
    # Collect the set of InputNode keys bound to DAG node args.
    input_positional_args: Set[int] = set()
    input_kwargs: Set[str] = set()
    # Set of tasks with annotation of with_tensor_transport("auto").
    # These only correspond to ClassMethodNodes, but not InputNodes
    # or InputAttributeNodes.
    auto_transport_tasks: Set["CompiledTask"] = set()

    # For each task node, set its upstream and downstream task nodes, fill in
    # default type hints, and record which nodes need a communicator or
    # auto-transport resolution.
    for task_idx, task in self.idx_to_task.items():
        dag_node = task.dag_node
        if not isinstance(
            dag_node,
            (InputNode, InputAttributeNode, MultiOutputNode, ClassMethodNode),
        ):
            if isinstance(dag_node, FunctionNode):
                # TODO(swang): Support non-actor tasks.
                raise NotImplementedError(
                    "Compiled DAGs currently only support actor method nodes"
                )
            else:
                raise ValueError(f"Found unsupported node of type {type(dag_node)}")

        if isinstance(dag_node, ClassMethodNode) and dag_node.is_class_method_call:
            actor_handle = dag_node._get_actor_handle()
            if actor_handle is None:
                raise ValueError(
                    "Compiled DAGs can only bind methods to an actor "
                    "that is already created with Actor.remote()"
                )

            if actor_handle not in self.actor_to_gpu_ids:
                self.actor_to_gpu_ids[actor_handle] = CompiledDAG._get_gpu_ids(
                    actor_handle
                )

            if isinstance(dag_node.type_hint, AutoTransportType):
                auto_transport_tasks.add(task)

            # Collect actors for accelerator P2P methods.
            if dag_node.type_hint.requires_accelerator():
                self._track_communicator_usage(dag_node, {actor_handle})
            # Collect accelerator collective operations.
            if isinstance(dag_node, CollectiveOutputNode):
                self._track_communicator_usage(
                    dag_node,
                    set(dag_node._collective_op.actor_handles),
                    collective_op=True,
                )
                assert not self._overlap_gpu_communication, (
                    "Currently, the overlap_gpu_communication option is not "
                    "supported for accelerator collective operations. Please set "
                    "overlap_gpu_communication=False."
                )
        elif isinstance(dag_node, (InputNode, InputAttributeNode)):
            if dag_node.type_hint.requires_accelerator():
                raise ValueError(
                    "DAG inputs cannot be transferred via accelerator because "
                    "the driver cannot participate in the communicator group"
                )
            if isinstance(dag_node.type_hint, AutoTransportType):
                # Currently driver on GPU is not supported, so we always
                # use shared memory to transfer tensors.
                dag_node.type_hint = TorchTensorType(
                    device=dag_node.type_hint.device
                )

        if type(dag_node.type_hint) is ChannelOutputType:
            # No type hint specified by the user. Replace
            # with the default type hint for this DAG.
            dag_node.type_hint = self._default_type_hint

        for val in task.kwargs.values():
            if isinstance(val, DAGNode):
                raise ValueError(
                    "Compiled DAG currently does not support binding to "
                    "other DAG nodes as kwargs"
                )

        for arg in task.args:
            if not isinstance(arg, DAGNode):
                continue

            upstream_node_idx = self.dag_node_to_idx[arg]
            upstream_task = self.idx_to_task[upstream_node_idx]
            downstream_actor_handle = None
            if (
                isinstance(dag_node, ClassMethodNode)
                and dag_node.is_class_method_call
            ):
                downstream_actor_handle = dag_node._get_actor_handle()

            # Add upstream node as the argument nodes of this task, whose
            # type hints may be updated when resolved lazily.
            task.arg_nodes.append(upstream_task.dag_node)

            if isinstance(upstream_task.dag_node, InputAttributeNode):
                # Record all of the keys used to index the InputNode.
                # During execution, we will check that the user provides
                # the same args and kwargs.
                if isinstance(upstream_task.dag_node.key, int):
                    input_positional_args.add(upstream_task.dag_node.key)
                elif isinstance(upstream_task.dag_node.key, str):
                    input_kwargs.add(upstream_task.dag_node.key)
                else:
                    raise ValueError(
                        "InputNode() can only be indexed using int "
                        "for positional args or str for kwargs."
                    )

                if direct_input is not None and direct_input:
                    raise ValueError(
                        "All tasks must either use InputNode() "
                        "directly, or they must index to specific args or "
                        "kwargs."
                    )
                direct_input = False

                # If the upstream node is an InputAttributeNode, treat the
                # DAG's input node as the actual upstream node
                upstream_task = self.idx_to_task[self.input_task_idx]

            elif isinstance(upstream_task.dag_node, InputNode):
                if direct_input is not None and not direct_input:
                    raise ValueError(
                        "All tasks must either use InputNode() directly, "
                        "or they must index to specific args or kwargs."
                    )
                direct_input = True

            upstream_task.downstream_task_idxs[task_idx] = downstream_actor_handle

            if upstream_task.dag_node.type_hint.requires_accelerator():
                # Here we are processing the args of the DAGNode, so track
                # downstream actors only, upstream actor is already tracked
                # when processing the DAGNode itself.
                self._track_communicator_usage(
                    upstream_task.dag_node,
                    {downstream_actor_handle},
                )

    # Check that all specified input attributes, e.g., InputNode()["x"],
    # are used in the DAG.
    _check_unused_dag_input_attributes(output_node, input_attributes)

    self._check_leaf_nodes()

    self._resolve_auto_transport(auto_transport_tasks)

    self._init_communicators()

    if direct_input:
        self._input_num_positional_args = 1
    elif not input_positional_args:
        self._input_num_positional_args = 0
    else:
        self._input_num_positional_args = max(input_positional_args) + 1
    self._input_kwargs = tuple(input_kwargs)
def _init_communicators(self) -> None:
    """
    Initialize every communicator the DAG needs, in three phases.

    1. Communicators passed in by the user are initialized, and the
       resulting ID is set on every type hint that referenced them.
    2. Each collective operation without a communicator gets one created
       for its ordered actor tuple; a communicator already created for the
       same tuple is reused.
    3. P2P type hints without a communicator reuse the first created
       collective communicator whose actors include all P2P actors;
       otherwise a new communicator is created for the P2P actors.

    Raises:
        ValueError: If a collective operation needs a communicator but
            communicator creation ("create") was not enabled.
    """
    # Phase 1: user-provided communicators.
    for user_comm, hints in self._communicator_to_type_hints.items():
        user_comm_id = _init_communicator(
            user_comm.get_actor_handles(),
            user_comm,
            self._overlap_gpu_communication,
        )
        for hint in hints:
            hint.set_communicator_id(user_comm_id)

    # The registered accelerator context (if any) supplies the module name and
    # communicator class passed to `_init_communicator` for the ones below.
    accelerator_module_name = AcceleratorContext.get().module_name
    accelerator_communicator_cls = AcceleratorContext.get().communicator_cls

    # Phase 2: collective operations, reusing communicators per actor tuple.
    created_ids = self._actors_to_created_communicator_id
    for op in self._collective_ops_with_unresolved_communicators:
        if not self._create_default_communicator:
            raise ValueError(
                "Communicator creation is not allowed for collective operations."
            )
        # A tuple (not a set) preserves the actor order of the collective op.
        actor_tuple = tuple(op.actor_handles)
        if actor_tuple not in created_ids:
            created_ids[actor_tuple] = _init_communicator(
                list(actor_tuple),
                None,
                self._overlap_gpu_communication,
                accelerator_module_name,
                accelerator_communicator_cls,
            )
        op.type_hint.set_communicator_id(created_ids[actor_tuple])

    # Phase 3: P2P operations.
    p2p_actors = self._p2p_actors_with_unresolved_communicators
    p2p_comm_id = None
    if p2p_actors:
        p2p_comm_id = next(
            (
                cid
                for actor_tuple, cid in created_ids.items()
                if p2p_actors.issubset(actor_tuple)
            ),
            None,
        )
        if p2p_comm_id is None:
            p2p_comm_id = _init_communicator(
                list(p2p_actors),
                None,
                self._overlap_gpu_communication,
                accelerator_module_name,
                accelerator_communicator_cls,
            )
    for node in self._p2p_dag_nodes_with_unresolved_communicators:
        node.type_hint.set_communicator_id(p2p_comm_id)
- def _track_communicator_usage(
- self,
- dag_node: "ray.dag.DAGNode",
- actors: Set["ray.actor.ActorHandle"],
- collective_op: bool = False,
- ) -> None:
- """
- Track the usage of a communicator.
- This method first determines the communicator to use: if a custom
- communicator is specified, use it; if not and a default communicator
- is available, use it; otherwise, it records necessary information to
- create a new communicator later.
- This method also performs validation checks on the passed-in communicator.
- Args:
- dag_node: The DAG node that uses the communicator, this is the node
- that has the `with_tensor_transport()` type hint for p2p communication,
- or a `CollectiveOutputNode` for collective operations.
- actors: The full or partial set of actors that use the communicator.
- This method should be called one or multiple times so that all actors
- of the communicator are tracked.
- collective_op: Whether the communicator is used for a collective operation.
- """
- if None in actors:
- raise ValueError("Driver cannot participate in the communicator group.")
- if collective_op:
- type_hint = dag_node._collective_op.type_hint
- else:
- type_hint = dag_node.type_hint
- communicator = type_hint.get_custom_communicator()
- if communicator is None:
- if (
- self._default_communicator is None
- and not self._create_default_communicator
- ):
- if dag_node._original_type_hint is not None:
- assert isinstance(dag_node._original_type_hint, AutoTransportType)
- raise ValueError(
- f"with_tensor_transport(transport='auto') is used for DAGNode {dag_node}, "
- "This requires specifying a default communicator or 'create' for "
- "_default_communicator when calling experimental_compile()."
- )
- raise ValueError(
- f"DAGNode {dag_node} has no custom communicator specified. "
- "Please specify a custom communicator for the DAGNode using "
- "`with_tensor_transport()`, or specify a communicator or 'create' for "
- "_default_communicator when calling experimental_compile()."
- )
- communicator = self._default_communicator
- if communicator is None:
- if collective_op:
- self._collective_ops_with_unresolved_communicators.add(
- dag_node._collective_op
- )
- else:
- self._p2p_dag_nodes_with_unresolved_communicators.add(dag_node)
- self._p2p_actors_with_unresolved_communicators.update(actors)
- else:
- if collective_op:
- if set(communicator.get_actor_handles()) != actors:
- raise ValueError(
- "The passed-in communicator must have the same set "
- "of actors as the collective operation. "
- f"The passed-in communicator has actors {communicator.get_actor_handles()} "
- f"while the collective operation has actors {actors}."
- )
- else:
- if not actors.issubset(set(communicator.get_actor_handles())):
- raise ValueError(
- "The passed-in communicator must include all of the actors "
- "used in the P2P operation. "
- f"The passed-in communicator has actors {communicator.get_actor_handles()} "
- f"while the P2P operation has actors {actors}."
- )
- self._communicator_to_type_hints[communicator].add(type_hint)
def _resolve_auto_transport(
    self,
    auto_transport_tasks: Set["CompiledTask"],
) -> None:
    """
    Replace each `with_tensor_transport("auto")` type hint with a concrete one.

    For every task, the writer is the task's actor and the readers are the
    actor handles of its downstream tasks. Each is paired with its node ID
    and handed to `TypeHintResolver`. When the resolved hint requires an
    accelerator, the writer and readers are tracked so that a communicator
    can be initialized for them.
    """
    resolver = TypeHintResolver(self.actor_to_gpu_ids)
    for task in auto_transport_tasks:
        node = task.dag_node
        writer = node._get_actor_handle()
        readers = task.downstream_task_idxs.values()
        writer_with_node = (writer, self._get_node_id(writer))
        readers_with_nodes = [(r, self._get_node_id(r)) for r in readers]
        # Store the resolved hint on the node: its `register_custom_serializer`
        # is the one called when channel I/O is prepared.
        node.type_hint = resolver.resolve(
            node.type_hint,
            writer_with_node,
            readers_with_nodes,
        )
        if node.type_hint.requires_accelerator():
            self._track_communicator_usage(node, set(readers).union({writer}))
def _check_leaf_nodes(self) -> None:
    """
    Raise if the DAG contains leaf nodes.

    A leaf node is a ClassMethodNode that has no downstream tasks and is not
    a Compiled Graph output node. Leaf nodes are rejected because an
    exception thrown by a leaf node would not be propagated to the driver.

    Raises:
        ValueError: If any leaf node is found; the message lists the method
            names of the offending nodes.
    """
    from ray.dag import ClassMethodNode

    leaf_nodes: List["ray.dag.DAGNode"] = []
    for task in self.idx_to_task.values():
        if not isinstance(task.dag_node, ClassMethodNode):
            continue
        if (
            len(task.downstream_task_idxs) == 0
            and not task.dag_node.is_cgraph_output_node
        ):
            leaf_nodes.append(task.dag_node)

    if leaf_nodes:
        raise ValueError(
            "Compiled DAG doesn't support leaf nodes, i.e., nodes that don't have "
            "downstream nodes and are not output nodes. There are "
            f"{len(leaf_nodes)} leaf nodes in the DAG. Please add the outputs of "
            f"{[leaf_node.get_method_name() for leaf_node in leaf_nodes]} to the "
            "MultiOutputNode."
        )
@staticmethod
def _get_gpu_ids(actor_handle: "ray.actor.ActorHandle") -> List[str]:
    """
    Return the GPU IDs assigned to an actor.

    The accelerator IDs are read from the runtime context inside the actor
    itself (through ``__ray_call__``); only the ``"GPU"`` entry is kept.

    Args:
        actor_handle: Handle of the actor to query.

    Returns:
        The actor's GPU IDs, or an empty list when there is no ``"GPU"`` entry.
    """
    accelerator_ids_ref = actor_handle.__ray_call__.remote(
        lambda self: ray.get_runtime_context().get_accelerator_ids()
    )
    ids_by_accelerator_type = ray.get(accelerator_ids_ref)
    return ids_by_accelerator_type.get("GPU", [])
def _get_node_id(self, actor_handle: Optional["ray.actor.ActorHandle"]) -> str:
    """
    Return the node ID of an actor handle (or of the driver), memoized.

    Args:
        actor_handle: The actor handle, or None if the actor handle is the
            driver.

    Returns:
        The node ID of the actor handle or driver. Looked-up values are cached
        in ``self.actor_to_node_id``, so each handle is queried at most once.
    """
    cache = self.actor_to_node_id
    if actor_handle in cache:
        return cache[actor_handle]

    # The driver (None) and the proxy actor both use the node ID from the
    # current runtime context instead of a remote call.
    use_local_context = actor_handle == self._proxy_actor or actor_handle is None
    if use_local_context:
        resolved_node_id = ray.get_runtime_context().get_node_id()
    else:
        node_id_ref = actor_handle.__ray_call__.remote(
            lambda self: ray.get_runtime_context().get_node_id()
        )
        resolved_node_id = ray.get(node_id_ref)

    cache[actor_handle] = resolved_node_id
    return resolved_node_id
def _get_or_compile(
    self,
) -> None:
    """Compile an execution path. This allocates channels for adjacent
    tasks to send/receive values. An infinite task is submitted to each
    actor in the DAG that repeatedly receives from input channel(s) and
    sends to output channel(s).

    This function is idempotent and will cache the previously allocated
    channels. After calling this function, _dag_submitter and
    _dag_output_fetcher will be set and can be used to invoke and fetch
    outputs for the DAG.

    Phases:
        1. Breadth-first traversal from the input task that allocates output
           channels for actor method calls, and for the InputNode (one
           channel per InputAttributeNode, or one for the whole input).
        2. Validation that every unvisited task has at least one DAGNode
           input, plus optional deadlock detection.
        3. Construction of ExecutableTasks per actor. A channel read by more
           than one task on the same actor is wrapped in a CachedChannel.
        4. Building the execution schedule and launching the execution loop
           on each actor.
        5. Creation of the driver-side input writer and output reader.

    Raises:
        ValueError: If a task takes no DAGNode input, or if deadlock detection
            is enabled and reports a deadlock.
    """
    from collections import deque

    from ray.dag import (
        ClassMethodNode,
        DAGNode,
        InputAttributeNode,
        InputNode,
        MultiOutputNode,
    )

    if self.input_task_idx is None:
        self._preprocess()
    assert self.input_task_idx is not None

    if self._dag_submitter is not None:
        assert self._dag_output_fetcher is not None
        return

    def _single_output_channel(arg_node):
        # Return the only output channel of the task producing `arg_node`,
        # asserting that there is exactly one and that it is set.
        upstream_task = self.idx_to_task[self.dag_node_to_idx[arg_node]]
        assert len(upstream_task.output_channels) == 1
        channel = upstream_task.output_channels[0]
        assert channel is not None
        return channel

    # A deque gives O(1) pops from the left for the BFS frontier.
    frontier = deque([self.input_task_idx])
    visited = set()
    # Create output buffers. This loop does a breadth-first search through the DAG.
    while frontier:
        cur_idx = frontier.popleft()
        if cur_idx in visited:
            continue
        visited.add(cur_idx)

        task = self.idx_to_task[cur_idx]
        if (
            isinstance(task.dag_node, ClassMethodNode)
            and task.dag_node.is_class_method_call
        ):
            # Create output buffers for the actor method.
            assert len(task.output_channels) == 0
            # `output_to_readers` stores the reader tasks for each output of
            # the current node. If the current node returns one output, the
            # readers are the downstream nodes of the current node. If the
            # current node returns multiple outputs, the readers of each
            # output are the downstream nodes of the ClassMethodNode that
            # is a class method output.
            output_to_readers: Dict[CompiledTask, List[CompiledTask]] = defaultdict(
                list
            )
            for idx in task.downstream_task_idxs:
                downstream_task = self.idx_to_task[idx]
                downstream_node = downstream_task.dag_node
                if (
                    isinstance(downstream_node, ClassMethodNode)
                    and downstream_node.is_class_method_output
                ):
                    output_to_readers[downstream_task] = [
                        self.idx_to_task[reader_idx]
                        for reader_idx in downstream_task.downstream_task_idxs
                    ]
                else:
                    output_to_readers[task].append(downstream_task)

            fn = task.dag_node._get_remote_method("__ray_call__")
            for output, readers in output_to_readers.items():
                reader_and_node_list: List[Tuple["ray.actor.ActorHandle", str]] = []
                # Use reader_handles_set to deduplicate readers on the
                # same actor, because with CachedChannel each actor will
                # only read from the upstream channel once.
                reader_handles_set = set()
                read_by_multi_output_node = False
                for reader in readers:
                    if isinstance(reader.dag_node, MultiOutputNode):
                        read_by_multi_output_node = True
                        # inserting at 0 to make sure driver is first reader as
                        # expected by CompositeChannel read
                        reader_and_node_list.insert(
                            0,
                            (
                                self._proxy_actor,
                                self._get_node_id(self._proxy_actor),
                            ),
                        )
                    else:
                        reader_handle = reader.dag_node._get_actor_handle()
                        if reader_handle not in reader_handles_set:
                            reader_and_node_list.append(
                                (reader_handle, self._get_node_id(reader_handle))
                            )
                            reader_handles_set.add(reader_handle)

                # if driver is an actual actor, gets driver actor id
                driver_actor_id = (
                    ray.get_runtime_context().get_actor_id()
                    if read_by_multi_output_node
                    else None
                )
                # Create an output channel for each output of the current node.
                output_channel = ray.get(
                    fn.remote(
                        do_allocate_channel,
                        reader_and_node_list,
                        task.dag_node.type_hint,
                        driver_actor_id,
                    )
                )
                output_idx = None
                downstream_node = output.dag_node
                if (
                    isinstance(downstream_node, ClassMethodNode)
                    and downstream_node.is_class_method_output
                ):
                    output_idx = downstream_node.output_idx
                task.output_channels.append(output_channel)
                task.output_idxs.append(output_idx)
                task.output_node_idxs.append(self.dag_node_to_idx[downstream_node])

            actor_handle = task.dag_node._get_actor_handle()
            assert actor_handle is not None
            self.actor_to_tasks[actor_handle].append(task)
        elif (
            isinstance(task.dag_node, ClassMethodNode)
            and task.dag_node.is_class_method_output
        ):
            # A class method output reuses the upstream call's channel whose
            # output index matches its own.
            task_node = task.dag_node
            upstream_node = task_node.class_method_call
            assert upstream_node
            upstream_task = self.idx_to_task[self.dag_node_to_idx[upstream_node]]
            for channel, channel_output_idx in zip(
                upstream_task.output_channels, upstream_task.output_idxs
            ):
                if channel_output_idx == task_node.output_idx:
                    task.output_channels.append(channel)
                    task.output_idxs.append(channel_output_idx)
            assert len(task.output_channels) == 1
        elif isinstance(task.dag_node, InputNode):
            # A dictionary that maps an InputNode or InputAttributeNode to its
            # readers and the node on which the reader is running. Use `set` to
            # deduplicate readers on the same actor because with CachedChannel
            # each actor will only read from the shared memory once.
            input_node_to_reader_and_node_set: Dict[
                Union[InputNode, InputAttributeNode],
                Set[Tuple["ray.actor.ActorHandle", str]],
            ] = defaultdict(set)

            for idx in task.downstream_task_idxs:
                reader_task = self.idx_to_task[idx]
                assert isinstance(reader_task.dag_node, ClassMethodNode)
                reader_handle = reader_task.dag_node._get_actor_handle()
                reader_node_id = self._get_node_id(reader_handle)
                for arg in reader_task.args:
                    if isinstance(arg, (InputAttributeNode, InputNode)):
                        input_node_to_reader_and_node_set[arg].add(
                            (reader_handle, reader_node_id)
                        )

            # A single channel is responsible for sending the same data to
            # corresponding consumers. Therefore, we create a channel for
            # each InputAttributeNode, or a single channel for the entire
            # input data if there are no InputAttributeNodes.
            task.output_channels = []
            for (
                input_dag_node,
                reader_and_node_set,
            ) in input_node_to_reader_and_node_set.items():
                reader_and_node_list = list(reader_and_node_set)
                output_channel = do_allocate_channel(
                    self,
                    reader_and_node_list,
                    input_dag_node.type_hint,
                    None,
                )
                task.output_channels.append(output_channel)
                task.output_idxs.append(
                    None
                    if isinstance(input_dag_node, InputNode)
                    else input_dag_node.key
                )

                # Update the InputAttributeNode's `output_channels`, which is
                # used to determine whether to create a CachedChannel.
                if isinstance(input_dag_node, InputAttributeNode):
                    input_attr_idx = self.dag_node_to_idx[input_dag_node]
                    input_attr_task = self.idx_to_task[input_attr_idx]
                    input_attr_task.output_channels.append(output_channel)
                    assert len(input_attr_task.output_channels) == 1
        else:
            assert isinstance(task.dag_node, (InputAttributeNode, MultiOutputNode))

        frontier.extend(task.downstream_task_idxs)

    # Validate input channels for tasks that have not been visited
    for node_idx, task in self.idx_to_task.items():
        if (
            node_idx == self.input_task_idx
            or node_idx == self.output_task_idx
            or isinstance(task.dag_node, InputAttributeNode)
        ):
            continue
        if node_idx in visited:
            continue
        if not any(isinstance(arg, DAGNode) for arg in task.args):
            raise ValueError(
                "Compiled DAGs require each task to take a ray.dag.InputNode "
                "or at least one other DAGNode as an input. "
                "Invalid task node:\n"
                f"{task.dag_node}\n"
                "Please bind the task to proper DAG nodes."
            )

    from ray.dag.constants import RAY_CGRAPH_ENABLE_DETECT_DEADLOCK

    if RAY_CGRAPH_ENABLE_DETECT_DEADLOCK and self._detect_deadlock():
        raise ValueError(
            "This DAG cannot be compiled because it will deadlock on accelerator "
            "calls. If you believe this is a false positive, please disable "
            "the graph verification by setting the environment variable "
            "RAY_CGRAPH_ENABLE_DETECT_DEADLOCK to 0 and file an issue at "
            "https://github.com/ray-project/ray/issues/new/."
        )

    input_task = self.idx_to_task[self.input_task_idx]
    self.dag_input_channels = input_task.output_channels
    assert self.dag_input_channels is not None

    # Create executable tasks for each actor
    for actor_handle, tasks in self.actor_to_tasks.items():
        # Dict from arg to the set of tasks that consume it.
        arg_to_consumers: Dict[DAGNode, Set[CompiledTask]] = defaultdict(set)

        # Step 1: populate `arg_to_consumers` and perform some validation.
        for task in tasks:
            has_at_least_one_channel_input = False
            for arg in task.args:
                if isinstance(arg, DAGNode):
                    has_at_least_one_channel_input = True
                    arg_to_consumers[arg].add(task)
                    # Validates that the upstream task has exactly one channel.
                    _single_output_channel(arg)

            # TODO: Support no-input DAGs (use an empty object to signal).
            if not has_at_least_one_channel_input:
                raise ValueError(
                    "Compiled DAGs require each task to take a "
                    "ray.dag.InputNode or at least one other DAGNode as an "
                    "input"
                )

        # Step 2: create cached channels if needed
        # Dict from original channel to the channel to be used in execution.
        # The value of this dict is either the original channel or a newly
        # created CachedChannel (if the original channel is read more than once).
        for arg, consumers in arg_to_consumers.items():
            arg_channel = _single_output_channel(arg)
            if len(consumers) > 1:
                self._channel_dict[arg_channel] = CachedChannel(
                    len(consumers),
                    arg_channel,
                )
            else:
                self._channel_dict[arg_channel] = arg_channel

        # Step 3: create executable tasks for the actor
        executable_tasks = []
        for task in tasks:
            resolved_args: List[Any] = []
            for arg in task.args:
                if isinstance(arg, DAGNode):
                    resolved_args.append(
                        self._channel_dict[_single_output_channel(arg)]
                    )
                else:
                    # Constant arg
                    resolved_args.append(arg)
            executable_tasks.append(
                ExecutableTask(
                    task,
                    resolved_args,
                    task.kwargs,
                )
            )
        # Sort executable tasks based on their bind index, i.e., submission order
        # so that they will be executed in that order.
        executable_tasks.sort(key=lambda exec_task: exec_task.bind_index)
        self.actor_to_executable_tasks[actor_handle] = executable_tasks

    from ray.dag.constants import RAY_CGRAPH_ENABLE_PROFILING

    if RAY_CGRAPH_ENABLE_PROFILING:
        exec_task_func = do_profile_tasks
    else:
        exec_task_func = do_exec_tasks

    # Build an execution schedule for each actor
    self.actor_to_execution_schedule = self._build_execution_schedule()
    for actor_handle, executable_tasks in self.actor_to_executable_tasks.items():
        self.worker_task_refs[actor_handle] = actor_handle.__ray_call__.options(
            concurrency_group="_ray_system"
        ).remote(
            exec_task_func,
            executable_tasks,
            self.actor_to_execution_schedule[actor_handle],
            self._overlap_gpu_communication,
        )

    assert self.output_task_idx is not None
    self.dag_output_channels = []
    for output in self.idx_to_task[self.output_task_idx].args:
        assert isinstance(output, DAGNode)
        output_idx = self.dag_node_to_idx[output]
        task = self.idx_to_task[output_idx]
        assert len(task.output_channels) == 1
        self.dag_output_channels.append(task.output_channels[0])

    # Register custom serializers for input, input attribute, and output nodes.
    self._register_input_output_custom_serializer()

    assert self.dag_input_channels
    assert self.dag_output_channels
    # `all(...)` is required here: asserting on a non-empty list comprehension
    # would always pass regardless of its contents.
    assert all(
        output_channel is not None for output_channel in self.dag_output_channels
    )

    # If no MultiOutputNode was specified during the DAG creation, there is only
    # one output. Return a single output channel instead of a list of
    # channels.
    if not self._returns_list:
        assert len(self.dag_output_channels) == 1

    # Driver should ray.put on input, ray.get/release on output
    self._monitor = self._monitor_failures()
    input_task = self.idx_to_task[self.input_task_idx]
    if self._enable_asyncio:
        self._dag_submitter = AwaitableBackgroundWriter(
            self.dag_input_channels,
            input_task.output_idxs,
            is_input=True,
        )
        self._dag_output_fetcher = AwaitableBackgroundReader(
            self.dag_output_channels,
            self._fut_queue,
        )
    else:
        self._dag_submitter = SynchronousWriter(
            self.dag_input_channels, input_task.output_idxs, is_input=True
        )
        self._dag_output_fetcher = SynchronousReader(self.dag_output_channels)

    self._dag_submitter.start()
    self._dag_output_fetcher.start()
def _generate_dag_operation_graph_node(
    self,
) -> Dict["ray.actor.ActorHandle", List[List[_DAGOperationGraphNode]]]:
    """
    Split every executable task into READ, COMPUTE, and WRITE operation nodes.

    An operation is marked as requiring an accelerator as follows:
        * READ: any upstream node's type hint requires an accelerator.
        * COMPUTE: the DAG node is a ``CollectiveOutputNode``.
        * WRITE: the DAG node's own type hint requires an accelerator.

    Returns:
        A dictionary that maps an actor handle to a list of lists of
        _DAGOperationGraphNode. For the same actor, the index of the
        outer list corresponds to the index of the ExecutableTask in
        the list of `executable_tasks` in `actor_to_executable_tasks`,
        i.e. `exec_task_idx`. In the inner list, the order of operations
        is READ, COMPUTE, and WRITE.

        Example:
        {
            actor1: [
                [READ COMPUTE WRITE] # exec_task_idx 0
                [READ COMPUTE WRITE] # exec_task_idx 1
            ]
        }
    """
    from ray.dag.collective_node import CollectiveOutputNode

    assert self.idx_to_task
    assert self.actor_to_executable_tasks

    actor_to_operation_nodes: Dict[
        "ray.actor.ActorHandle", List[List[_DAGOperationGraphNode]]
    ] = defaultdict(list)

    for executable_tasks in self.actor_to_executable_tasks.values():
        for exec_task_idx, exec_task in enumerate(executable_tasks):
            task_idx = exec_task.task_idx
            dag_node = self.idx_to_task[task_idx].dag_node
            method_name = exec_task.method_name
            # The dictionary is keyed by the handle taken from the DAG node.
            node_actor = dag_node._get_actor_handle()

            # (operation type, requires accelerator), in READ/COMPUTE/WRITE order.
            operation_specs = (
                (
                    _DAGNodeOperationType.READ,
                    any(
                        upstream.type_hint.requires_accelerator()
                        for upstream in dag_node._upstream_nodes
                    ),
                ),
                (
                    _DAGNodeOperationType.COMPUTE,
                    isinstance(dag_node, CollectiveOutputNode),
                ),
                (
                    _DAGNodeOperationType.WRITE,
                    dag_node.type_hint.requires_accelerator(),
                ),
            )
            actor_to_operation_nodes[node_actor].append(
                [
                    _DAGOperationGraphNode(
                        _DAGNodeOperation(exec_task_idx, op_type, method_name),
                        task_idx,
                        node_actor,
                        needs_accelerator,
                    )
                    for op_type, needs_accelerator in operation_specs
                ]
            )
    return actor_to_operation_nodes
def _build_execution_schedule(
    self,
) -> Dict["ray.actor.ActorHandle", List[_DAGNodeOperation]]:
    """
    Produce the per-actor execution schedule as lists of _DAGNodeOperation.

    The schedule is built in three steps:

    1. Build a graph of _DAGOperationGraphNode from the READ/COMPUTE/WRITE
       operations of every task (see `_generate_dag_operation_graph_node`
       and `_build_dag_node_operation_graph`).
    2. Topologically sort that graph into one schedule per actor. Several
       nodes may have zero in-degree at once; `_select_next_nodes` decides
       which one goes next. The aim is a schedule that follows ascending
       `bind_index` as closely as possible, so it matches user intuition.
    3. If GPU communication overlap is enabled, derive an overlapped
       schedule from the one produced in step 2.

    When RAY_CGRAPH_VISUALIZE_SCHEDULE is set, the schedules are visualized.

    See `test_execution_schedule` for examples.

    Returns:
        actor_to_execution_schedule: A dictionary that maps an actor handle to
            the execution schedule which is a list of operations to be executed.
    """
    # Step 1: operation graph.
    operation_nodes_by_actor = self._generate_dag_operation_graph_node()
    operation_graph = _build_dag_node_operation_graph(
        self.idx_to_task, operation_nodes_by_actor
    )

    # Step 2: topological sort into one schedule per actor.
    base_schedule = _generate_actor_to_execution_schedule(operation_graph)

    # Step 3: optionally overlap GPU communication.
    overlapped_schedule = (
        _generate_overlapped_execution_schedule(base_schedule)
        if self._overlap_gpu_communication
        else None
    )

    if RAY_CGRAPH_VISUALIZE_SCHEDULE:
        _visualize_execution_schedule(
            base_schedule, overlapped_schedule, operation_graph
        )

    chosen_schedule = (
        base_schedule if overlapped_schedule is None else overlapped_schedule
    )
    return _extract_execution_schedule(chosen_schedule)
def _detect_deadlock(self) -> bool:
    """
    Report whether the compiled graph is known to deadlock.

    TODO (kevin85421): Avoid false negatives.

    A compiled graph may deadlock when there are accelerator channels and the
    readers have control dependencies on the same actor, e.g.::

        actor1.a ---> actor2.f1
                 |
                 ---> actor2.f2

    Here `f1` must run before `f2` on actor2. If `actor1.a` writes to
    `actor2.f2` before `actor2.f1`, the graph deadlocks. The current execution
    schedule is not granular enough to detect this case.

    Detection is not implemented yet: this only logs a debug message.

    Returns:
        True if a deadlock is detected; otherwise, False. Currently this is
        always False.
    """
    deadlock_found = False
    logger.debug("Deadlock detection has not been implemented yet.")
    return deadlock_found
def _monitor_failures(self):
    """
    Start and return a daemon thread that watches the worker tasks.

    The thread calls ``ray.get`` on every entry of ``worker_task_refs``.
    If that call raises, the compiled DAG is torn down. On
    ``KeyboardInterrupt`` it is torn down with ``kill_actors=True``. The
    returned thread also exposes ``teardown`` and ``wait_teardown`` for
    explicit teardown.

    Only a weak reference to this CompiledDAG is held by the thread.

    Returns:
        The started monitor thread.
    """
    outer_ref = weakref.ref(self)

    class Monitor(threading.Thread):
        def __init__(self):
            super().__init__(daemon=True)
            self.name = "CompiledGraphMonitorThread"
            # Guards teardown so it is performed at most once for this DAG.
            self._in_teardown_lock = threading.Lock()
            self._teardown_done = False

        def _outer_ref_alive(self) -> bool:
            # Returns False (and logs an error) if the CompiledDAG was
            # garbage-collected before teardown.
            if outer_ref() is not None:
                return True
            logger.error(
                "CompiledDAG has been destructed before teardown. "
                "This should not occur please report an issue at "
                "https://github.com/ray-project/ray/issues/new/.",
                stack_info=True,
            )
            return False

        def wait_teardown(self, kill_actors: bool = False):
            # Wait for each worker task to exit, warning (and optionally
            # killing the actor) when it outlives the teardown timeout.
            outer = outer_ref()
            if not self._outer_ref_alive():
                return

            from ray.dag import DAGContext

            teardown_timeout = DAGContext.get_current().teardown_timeout
            for actor, ref in outer.worker_task_refs.items():
                try:
                    ray.get(ref, timeout=teardown_timeout)
                except ray.exceptions.GetTimeoutError:
                    warning = (
                        f"Compiled DAG actor {actor} is still running "
                        f"{teardown_timeout}s after teardown()."
                    )
                    if kill_actors:
                        warning += (
                            " Force-killing actor. "
                            "Increase RAY_CGRAPH_teardown_timeout if you want "
                            "teardown to wait longer."
                        )
                        ray.kill(actor)
                    else:
                        warning += (
                            " Teardown may hang. "
                            "Call teardown with kill_actors=True if force kill "
                            "is desired."
                        )
                    logger.warning(warning)
                except Exception:
                    # Only completion matters; a task that ended in an
                    # exception is fine.
                    continue
                else:
                    continue

                # Reached only after a timeout: block until the task is done,
                # ignoring whatever error it ends with.
                try:
                    ray.get(ref)
                except Exception:
                    pass

            if kill_actors:
                # Actor tasks were given the chance to exit above; now force
                # kill every actor.
                for actor in outer.worker_task_refs:
                    logger.info(f"Killing actor: {actor}")
                    ray.kill(actor)

        def teardown(self, kill_actors: bool = False):
            with self._in_teardown_lock:
                if self._teardown_done:
                    return

                outer = outer_ref()
                if not self._outer_ref_alive():
                    return

                logger.info("Tearing down compiled DAG")
                outer._dag_submitter.close()
                outer._dag_output_fetcher.close()

                for actor in outer.actor_to_executable_tasks:
                    logger.info(f"Cancelling compiled worker on actor: {actor}")
                # Cancel all actor loops in parallel.
                cancel_refs = [
                    actor.__ray_call__.remote(do_cancel_executable_tasks, tasks)
                    for actor, tasks in outer.actor_to_executable_tasks.items()
                ]
                for cancel_ref in cancel_refs:
                    try:
                        ray.get(cancel_ref, timeout=30)
                    except RayChannelError:
                        # Channel error happens when a channel is closed
                        # or timed out. In this case, do not log.
                        pass
                    except Exception:
                        logger.exception("Error cancelling worker task")

                for communicator_id in (
                    outer._actors_to_created_communicator_id.values()
                ):
                    _destroy_communicator(communicator_id)

                logger.info("Waiting for worker tasks to exit")
                self.wait_teardown(kill_actors=kill_actors)
                logger.info("Teardown complete")
                self._teardown_done = True

        def run(self):
            try:
                outer = outer_ref()
                if not self._outer_ref_alive():
                    return
                ray.get(list(outer.worker_task_refs.values()))
            except KeyboardInterrupt:
                logger.info(
                    "Received KeyboardInterrupt, tearing down with kill_actors=True"
                )
                self.teardown(kill_actors=True)
            except Exception as e:
                logger.debug(f"Handling exception from worker tasks: {e}")
                self.teardown()

    monitor_thread = Monitor()
    monitor_thread.start()
    return monitor_thread
def _raise_if_too_many_inflight_executions(self):
    """
    Guard against exceeding the configured number of in-flight executions.

    The in-flight count is ``self._execution_index`` minus
    ``self._max_finished_execution_index``.

    Raises:
        ray.exceptions.RayCgraphCapacityExceeded: If the in-flight count has
            reached ``self._max_inflight_executions``.
    """
    inflight_count = self._execution_index - self._max_finished_execution_index
    if inflight_count < self._max_inflight_executions:
        return
    raise ray.exceptions.RayCgraphCapacityExceeded(
        "The compiled graph can't have more than "
        f"{self._max_inflight_executions} in-flight executions, and you "
        f"currently have {inflight_count} in-flight executions. "
        "Retrieve an output using ray.get before submitting more requests or "
        "increase `_max_inflight_executions`. "
        "`dag.experimental_compile(_max_inflight_executions=...)`"
    )
def _has_execution_results(
    self,
    execution_index: int,
) -> bool:
    """Report whether results for an execution index are already buffered.

    Checking ``self._result_buffer`` first avoids fetching and caching the
    same results twice.

    Args:
        execution_index: The execution index corresponding to the result.

    Returns:
        True if the result for the given index has been fetched and cached.
    """
    buffered_results = self._result_buffer
    return execution_index in buffered_results
def _cache_execution_results(
    self,
    execution_index: int,
    result: Any,
):
    """Store per-channel results of one execution in self._result_buffer.

    Results are stored as ``{channel_index: value}`` under the execution
    index, which makes removing single elements and computing the buffer size
    cheap. If results for the index are already buffered, nothing happens.
    Channels whose CompiledDAGRef was already destructed are not cached.

    Args:
        execution_index: The execution index corresponding to the result.
        result: The results from all channels to be cached.
    """
    if self._has_execution_results(execution_index):
        return
    destructed = self._destructed_ref_idxs
    for chan_idx, res in enumerate(result):
        ref_already_destructed = (
            execution_index in destructed
            and chan_idx in destructed[execution_index]
        )
        if ref_already_destructed:
            continue
        self._result_buffer[execution_index][chan_idx] = res
def _get_execution_results(
    self, execution_index: int, channel_index: Optional[int]
) -> List[Any]:
    """Pop buffered results for an execution and return them as a list.

    Args:
        execution_index: The execution index to retrieve results from.
        channel_index: The index of the output channel corresponding to the result.
            Channel indexing is consistent with the order of
            self.dag_output_channels. None means that the result wraps outputs from
            all output channels.

    Returns:
        The execution result for the given execution index and channel index.
        With ``channel_index=None`` this is every buffered channel's value
        ordered by channel index; otherwise it is a one-element list.
    """
    # CompiledDAGRef and CompiledDAGFuture never request the same
    # (execution index, channel index) pair twice, so the key should always
    # exist; the assert is a sanity check against misuse.
    assert execution_index in self._result_buffer

    if channel_index is not None:
        result = [self._result_buffer[execution_index].pop(channel_index)]
    else:
        # Rebuild the original list layout, ordered by channel index.
        by_channel = self._result_buffer.pop(execution_index)
        result = [by_channel[chan_idx] for chan_idx in sorted(by_channel)]

    self._got_ref_idxs.setdefault(execution_index, set()).add(channel_index)
    self._clean_up_buffers(execution_index)
    return result
def _delete_execution_results(self, execution_index: int, channel_index: int):
    """
    Record that the result for an execution/channel pair is no longer needed.

    Called when a CompiledDAGRef or CompiledDAGFuture is destructed. Only the
    bookkeeping in ``self._destructed_ref_idxs`` is updated here; the actual
    buffers are released lazily by ``_clean_up_buffers`` once nothing needs
    them anymore.

    Args:
        execution_index: The execution index to destruct results from.
        channel_index: The index of the output channel corresponding to the result.
    """
    destructed_channels = self._destructed_ref_idxs.setdefault(
        execution_index, set()
    )
    destructed_channels.add(channel_index)
    self._clean_up_buffers(execution_index)
def _try_release_result_buffer(self, execution_index: int) -> bool:
    """
    Try to release the result buffer for the given execution index.

    The buffer and its bookkeeping entries (``_destructed_ref_idxs`` and
    ``_got_ref_idxs``) are dropped once every output channel of the execution
    has been either fetched or destructed. A ``None`` entry in the fetched set
    means the results of all channels were fetched together.

    Args:
        execution_index: The execution index whose result buffer to release.

    Returns:
        True if the buffer was released, False otherwise.
    """
    got_channel_idxs = self._got_ref_idxs.get(execution_index, set())
    if None in got_channel_idxs:
        # The message must be a plain string; a trailing comma here would
        # turn it into a tuple and garble the AssertionError text.
        assert len(got_channel_idxs) == 1, (
            "when None exists in got_channel_idxs, it means all channels, and "
            "it should be the only value in the set"
        )
        should_release = True
    else:
        destructed_channel_idxs = self._destructed_ref_idxs.get(
            execution_index, set()
        )
        processed_channel_idxs = got_channel_idxs.union(destructed_channel_idxs)
        # No more processing is needed for this execution index.
        should_release = processed_channel_idxs == set(
            range(len(self.dag_output_channels))
        )

    if not should_release:
        return False

    self._result_buffer.pop(execution_index, None)
    self._destructed_ref_idxs.pop(execution_index, None)
    self._got_ref_idxs.pop(execution_index, None)
    return True
def _try_release_native_buffer(
    self, idx_to_release: int, timeout: Optional[float] = None
) -> bool:
    """
    Try to release the native buffer for the given execution index.

    Native buffers are released in order, so only the index directly after
    ``_max_finished_execution_index`` is eligible. It is released only when
    every output of that execution was destructed without being fetched. A
    ``None`` entry in the destructed set stands for "all channels".

    Args:
        idx_to_release: The execution index to release buffers from.
        timeout: The maximum time in seconds to wait for the release.

    Returns:
        Whether the buffers have been released.

    Raises:
        RayChannelTimeoutError: If releasing the channel buffers times out.
    """
    if idx_to_release != self._max_finished_execution_index + 1:
        # Native buffer can only be released for the next execution index.
        return False

    destructed_channel_idxs = self._destructed_ref_idxs.get(idx_to_release, set())
    if None in destructed_channel_idxs:
        # The message must be a plain string; a trailing comma here would
        # turn it into a tuple and garble the AssertionError text.
        assert len(destructed_channel_idxs) == 1, (
            "when None exists in destructed_channel_idxs, it means all channels, "
            "and it should be the only value in the set"
        )
        should_release = True
    else:
        should_release = len(destructed_channel_idxs) == len(
            self.dag_output_channels
        )

    if not should_release:
        return False

    # refs corresponding to idx_to_release are all destructed,
    # and they are never fetched or cached.
    assert idx_to_release not in self._result_buffer
    assert idx_to_release not in self._got_ref_idxs

    try:
        self._dag_output_fetcher.release_channel_buffers(timeout)
    except RayChannelTimeoutError as e:
        raise RayChannelTimeoutError(
            "Releasing native buffers corresponding to a stale CompiledDAGRef "
            "is taking a long time. If this is expected, increase "
            f"RAY_CGRAPH_get_timeout which is currently {self._get_timeout} "
            "seconds. Otherwise, this may indicate that the execution "
            "is hanging."
        ) from e

    # Tolerate a missing entry (e.g. no output channels and nothing recorded).
    self._destructed_ref_idxs.pop(idx_to_release, None)
    return True
def _try_release_buffer(
    self, idx_to_release: int, timeout: Optional[float] = None
) -> bool:
    """
    Try to release the buffers for one execution index.

    The native buffer is attempted first. If that succeeds, the execution's
    result counts as consumed (and discarded), so
    ``_max_finished_execution_index`` is advanced. Otherwise the result buffer
    is attempted.

    Args:
        idx_to_release: The execution index to release buffers from.
        timeout: The maximum time in seconds to wait for the release.

    Returns:
        Whether the native buffer or result buffer has been released.
    """
    native_released = self._try_release_native_buffer(idx_to_release, timeout)
    if not native_released:
        return self._try_release_result_buffer(idx_to_release)
    self._max_finished_execution_index += 1
    return True
def _try_release_buffers(self):
    """
    Release consecutive buffers for as long as possible.

    Starting at ``_max_finished_execution_index + 1``, buffers are released
    one index at a time until a release attempt fails. Releasing a native
    buffer advances ``_max_finished_execution_index``, which moves the next
    attempt forward. The overall ``self._get_timeout`` budget is shared by
    all attempts (never going below 0); ``-1`` means no limit.
    """
    remaining_timeout = self._get_timeout
    while True:
        attempt_started = time.monotonic()
        next_idx = self._max_finished_execution_index + 1
        if not self._try_release_buffer(next_idx, remaining_timeout):
            return
        if remaining_timeout == -1:
            continue
        elapsed = time.monotonic() - attempt_started
        remaining_timeout = max(remaining_timeout - elapsed, 0)
def _clean_up_buffers(self, idx_to_release: int):
    """
    Release native and result buffers that are no longer needed.

    First, one release attempt is made for ``idx_to_release``, the index that
    just changed state (e.g. right after get() is called, or after a
    CompiledDAGRef/CompiledDAGFuture is destructed). Then consecutive buffers
    from ``_max_finished_execution_index + 1`` onwards are released as far as
    possible.

    Args:
        idx_to_release: The execution index that requires a clean up,
            e.g., right after get() is called or a CompiledDAGRef/CompiledDAGFuture
            is destructed.
    """
    # The specific index that triggered the clean up.
    self._try_release_buffer(idx_to_release)
    # Sweep everything that has become releasable in order.
    self._try_release_buffers()
def _execute_until(
    self,
    execution_index: int,
    channel_index: Optional[int] = None,
    timeout: Optional[float] = None,
):
    """
    Advance execution until results up to ``execution_index`` are buffered.

    Each loop iteration handles the index right after
    ``_max_finished_execution_index``:
      - If that index's native buffer can be released, its result is dropped.
        This is the case for indices whose CompiledDAGRef's have been
        destructed.
      - Otherwise the result is read from the DAG output fetcher and cached.

    Either way, the finished index advances by one. If the DAG has already
    been executed up to ``execution_index``, nothing happens.

    Args:
        execution_index: The execution index to execute until (inclusive).
        channel_index: Index of the output channel of interest, consistent
            with the order of self.dag_output_channels. None means results
            from all output channels are wrapped into a single list.
            NOTE(review): this method's body does not read this argument.
        timeout: Overall time budget in seconds, shared across iterations.
            None means the default timeout (``self._get_timeout``). 0 means
            immediate timeout (immediate success or timeout without
            blocking). -1 means block indefinitely.

    Raises:
        RayCgraphCapacityExceeded: If the result buffer already holds
            ``_max_buffered_results`` entries while more executions remain.
        RayChannelTimeoutError: If releasing or reading a result times out.

    TODO(rui): catch the case that user holds onto the CompiledDAGRefs
    """
    remaining = self._get_timeout if timeout is None else timeout
    while self._max_finished_execution_index < execution_index:
        if len(self._result_buffer) >= self._max_buffered_results:
            raise RayCgraphCapacityExceeded(
                "The compiled graph can't have more than "
                f"{self._max_buffered_results} buffered results, and you "
                f"currently have {len(self._result_buffer)} buffered results. "
                "Call `ray.get()` on CompiledDAGRef's (or await on "
                "CompiledDAGFuture's) to retrieve results, or increase "
                "`_max_buffered_results` if buffering is desired, note that "
                "this will increase driver memory usage."
            )
        iteration_start = time.monotonic()
        try:
            released = self._try_release_native_buffer(
                self._max_finished_execution_index + 1, remaining
            )
            if not released:
                # Results from every output channel are cached separately so
                # they can be retrieved individually.
                result = self._dag_output_fetcher.read(remaining)
                self._cache_execution_results(
                    self._max_finished_execution_index + 1, result
                )
            # Released or cached: in both cases this index is now finished.
            self._max_finished_execution_index += 1
        except RayChannelTimeoutError as e:
            raise RayChannelTimeoutError(
                "If the execution is expected to take a long time, increase "
                f"RAY_CGRAPH_get_timeout which is currently {self._get_timeout} "
                "seconds. Otherwise, this may indicate that the execution is "
                "hanging."
            ) from e
        if remaining != -1:
            remaining = max(remaining - (time.monotonic() - iteration_start), 0)
def execute(
    self,
    *args,
    **kwargs,
) -> Union[CompiledDAGRef, List[CompiledDAGRef]]:
    """Submit one execution of this DAG through the compiled execution path.

    NOTE: Not thread-safe due to _execution_index etc.

    Args:
        args: Positional args to the InputNode.
        kwargs: Keyword args to the InputNode.

    Returns:
        A single CompiledDAGRef for this execution. If the DAG returns a list,
        it is instead a list with one CompiledDAGRef per output channel.

    Raises:
        ValueError: If the DAG was compiled with enable_asyncio=True.
        RayChannelTimeoutError: If writing the input to the DAG does not
            complete within self._submit_timeout seconds.
    """
    if self._enable_asyncio:
        raise ValueError("Use execute_async if enable_asyncio=True")

    self._get_or_compile()
    self._check_inputs(args, kwargs)

    # Serializing a tuple goes through pickle5, which costs several
    # microseconds. A lone positional argument (often `bytes`, which needs no
    # serialization) is common enough to deserve a fast path that skips the
    # CompiledDAGArgs wrapper.
    single_positional = len(args) == 1 and not kwargs
    inp = args[0] if single_positional else CompiledDAGArgs(args=args, kwargs=kwargs)

    # Release whatever buffers we can first, based on
    # max_finished_execution_index, so the in-flight count checked next is
    # current.
    self._try_release_buffers()
    self._raise_if_too_many_inflight_executions()

    try:
        self._dag_submitter.write(inp, self._submit_timeout)
    except RayChannelTimeoutError as e:
        raise RayChannelTimeoutError(
            "If the execution is expected to take a long time, increase "
            f"RAY_CGRAPH_submit_timeout which is currently {self._submit_timeout} "
            "seconds. Otherwise, this may indicate that execution is hanging."
        ) from e

    self._execution_index += 1
    if not self._returns_list:
        return CompiledDAGRef(self, self._execution_index)
    return [
        CompiledDAGRef(self, self._execution_index, channel_index)
        for channel_index in range(len(self.dag_output_channels))
    ]
def _check_inputs(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
    """
    Validate the user's execution arguments against the DAG's input spec.

    The number of positional args must equal ``_input_num_positional_args``.
    Every name in ``_input_kwargs`` must be present in ``kwargs``. Extra
    keyword arguments are not rejected here.

    Raises:
        ValueError: On a positional-arg count mismatch or a missing kwarg.
    """
    expected_positional = self._input_num_positional_args
    if len(args) != expected_positional:
        raise ValueError(
            "dag.execute() or dag.execute_async() must be "
            f"called with {expected_positional} positional args, got "
            f"{len(args)}"
        )
    for required_name in self._input_kwargs:
        if required_name in kwargs:
            continue
        raise ValueError(
            "dag.execute() or dag.execute_async() "
            f"must be called with kwarg `{required_name}`"
        )
async def execute_async(
    self,
    *args,
    **kwargs,
) -> Union[CompiledDAGFuture, List[CompiledDAGFuture]]:
    """Submit one execution of this DAG through the compiled execution path.

    NOTE: Not thread-safe.

    Args:
        args: Positional args to the InputNode.
        kwargs: Keyword args to the InputNode.

    Returns:
        A single CompiledDAGFuture for this execution. If the DAG returns a
        list, it is instead a list with one CompiledDAGFuture per output
        channel.

    Raises:
        ValueError: If the DAG was compiled with enable_asyncio=False.
    """
    if not self._enable_asyncio:
        raise ValueError("Use execute if enable_asyncio=False")

    self._get_or_compile()
    self._check_inputs(args, kwargs)

    async with self._dag_submission_lock:
        # Serializing a tuple goes through pickle5, which costs several
        # microseconds. A lone positional argument (often `bytes`, which needs
        # no serialization) is common enough to deserve a fast path that skips
        # the CompiledDAGArgs wrapper.
        if len(args) == 1 and not kwargs:
            inp = args[0]
        else:
            inp = CompiledDAGArgs(args=args, kwargs=kwargs)

        self._raise_if_too_many_inflight_executions()
        await self._dag_submitter.write(inp)

        # Allocate the future through which the caller will get the result.
        raw_future = asyncio.Future()
        await self._fut_queue.put(raw_future)
        self._execution_index += 1

        if self._returns_list:
            return [
                CompiledDAGFuture(
                    self, self._execution_index, raw_future, channel_index
                )
                for channel_index in range(len(self.dag_output_channels))
            ]
        return CompiledDAGFuture(self, self._execution_index, raw_future)
def _visualize_ascii(self) -> str:
    """
    Visualize the compiled graph in
    ASCII format with directional markers.

    This function generates an ASCII visualization of a Compiled Graph,
    where each task node is labeled,
    and edges use `<` and `>` markers to show data flow direction.

    This method is called by:
        - `compiled_dag.visualize(format="ascii")`

    High-Level Algorithm:
    - Topological Sorting: Sort nodes topologically to organize
        them into layers based on dependencies.
    - Grid Initialization: Set up a 2D grid canvas with dimensions based
        on the number of layers and the maximum number of nodes per layer.
        The grid is widened if needed so that every node label fits.
    - Node Placement: Position each node on the grid according to its
        layer and relative position within that layer.
        Spacing is added for readability, and directional markers (`<` and `>`)
        are added to edges to show input/output flow clearly.

    This method should be called
        **after** compiling the graph with `experimental_compile()`.

    Returns:
        ASCII representation of the CG with Nodes Information,
        Edges Information and Graph Built.

    Raises:
        ValueError: If the DAG has not been compiled, or a compiled task has
            no valid `dag_node`.

    Limitations:
    - Note: This is only used for quick visualization for small graphs.
        For complex graphs (e.g. more than 20 tasks), please use graphviz.
    - Scale: Works best for smaller CGs (typically fewer than 20 tasks).
        Larger CGs may result in dense, less readable ASCII
        outputs due to limited space for node and edge rendering.
    - Shape: Ideal for relatively shallow CGs with clear dependency paths.
        For deep, highly branched or densely connected CGs,
        readability may suffer.
    - Edge Overlap: In cases with high fan-out (i.e., nodes with many children)
        or fan-in (nodes with many parents), edge lines may intersect or overlap
        in the ASCII visualization, potentially obscuring some connections.
    - Multi-output Tasks: Multi-output tasks can be visualized, but positioning
        may cause line breaks or overlap when a task has multiple outputs that
        feed into nodes at varying depths.

    Example:
        Basic Visualization:
        ```python
        # Print the CG structure in ASCII format
        print(compiled_dag.visualize(format="ascii"))
        ```

        Example of Ordered Visualization (tasks are built in order
            to reduce line intersection):
        ```python
        with InputNode() as i:
            o1, o2, o3 = a.return_three.bind(i)
            o4 = b.echo.bind(o1)
            o5 = b.echo.bind(o2)
            o6, o7 = b.return_two.bind(o3)
            dag = MultiOutputNode([o4, o5, o6, o7])

        compiled_dag = dag.experimental_compile()
        compiled_dag.visualize(format="ascii",view=True)


        # Output:
        # 0:InputNode
        # |
        # 1:Actor_54777d:return_three
        # |---------------------------->|---------------------------->|                                  # noqa
        # 2:Output[0]                   3:Output[1]                   4:Output[2]                        # noqa
        # |                             |                             |                                  # noqa
        # 5:Actor_c927c9:echo           6:Actor_c927c9:echo           7:Actor_c927c9:return_two          # noqa
        # |                             |                             |---------------------------->|    # noqa
        # |                             |                             9:Output[0]                   10:Output[1] # noqa
        # |<----------------------------|-----------------------------|-----------------------------|    # noqa
        # 8:MultiOutputNode
        ```

        Example of Anti-pattern Visualization (There are intersections):
        # We can switch the nodes ordering to reduce intersections, i.e. swap o2 and o3
        ```python
        with InputNode() as i:
            o1, o2, o3 = a.return_three.bind(i)
            o4 = b.echo.bind(o1)
            o5 = b.echo.bind(o3)
            o6, o7 = b.return_two.bind(o2)
            dag = MultiOutputNode([o4, o5, o6, o7])

        compiled_dag = dag.experimental_compile()
        compiled_dag.visualize(format="ascii",view=True)

        # Output (Nodes 5, 7, 9, 10 should connect to Node 8):
        # 0:InputNode
        # |
        # 1:Actor_84835a:return_three
        # |---------------------------->|---------------------------->|                              # noqa
        # 2:Output[0]                   3:Output[1]                   4:Output[2]                    # noqa
        # |                             |                             |                              # noqa
        # 5:Actor_02a6a1:echo           6:Actor_02a6a1:return_two     7:Actor_02a6a1:echo            # noqa
        # |                             |---------------------------->|                              # noqa
        # |                             9:Output[0]                   10:Output[1]                   # noqa
        # |<----------------------------------------------------------|                              # noqa
        # 8:MultiOutputNode
        ```
    """
    from ray.dag import (
        ClassMethodNode,
        DAGNode,
        InputAttributeNode,
        InputNode,
        MultiOutputNode,
    )

    # Check that the DAG has been compiled
    if not hasattr(self, "idx_to_task") or not self.idx_to_task:
        raise ValueError(
            "The DAG must be compiled before calling 'visualize()'. "
            "Please call 'experimental_compile()' first."
        )

    # Check that each CompiledTask has a valid dag_node
    for idx, task in self.idx_to_task.items():
        if not hasattr(task, "dag_node") or not isinstance(task.dag_node, DAGNode):
            raise ValueError(
                f"Task at index {idx} does not have a valid 'dag_node'. "
                "Ensure that 'experimental_compile()' completed successfully."
            )

    from collections import defaultdict, deque

    # Adjacency list for the DAG; maps a node index to its downstream nodes.
    adj_list: Dict[int, List[int]] = defaultdict(list)
    # Indegree count for topological sorting; maps a node index to its indegree.
    indegree: Dict[int, int] = defaultdict(int)
    # Tracks whether a node is a child of a node with more than one child.
    is_multi_output: Dict[int, bool] = defaultdict(bool)
    # Maps such child node indices to their (fan-out) parent node indices.
    child2parent: Dict[int, int] = defaultdict(int)
    ascii_visualization = ""
    # Node information; maps a node index to its descriptive label.
    node_info: Dict[int, str] = {}

    # Edge information; tuples of (upstream_index, downstream_index, edge_label).
    edge_info: List[Tuple[int, int, str]] = []

    for idx, task in self.idx_to_task.items():
        dag_node = task.dag_node
        label = f"Task {idx} "

        # Determine the type and label of the node
        if isinstance(dag_node, InputNode):
            label += "InputNode"
        elif isinstance(dag_node, InputAttributeNode):
            label += f"InputAttributeNode[{dag_node.key}]"
        elif isinstance(dag_node, MultiOutputNode):
            label += "MultiOutputNode"
        elif isinstance(dag_node, ClassMethodNode):
            if dag_node.is_class_method_call:
                method_name = dag_node.get_method_name()
                actor_handle = dag_node._get_actor_handle()
                actor_id = (
                    actor_handle._actor_id.hex()[:6] if actor_handle else "unknown"
                )
                label += f"Actor: {actor_id}... Method: {method_name}"
            elif dag_node.is_class_method_output:
                label += f"ClassMethodOutputNode[{dag_node.output_idx}]"
            else:
                label += "ClassMethodNode"
        else:
            label += type(dag_node).__name__

        node_info[idx] = label

        for arg_index, arg in enumerate(dag_node.get_args()):
            if isinstance(arg, DAGNode):
                upstream_task_idx = self.dag_node_to_idx[arg]

                # Get the type hint for this argument
                if arg_index < len(task.arg_type_hints):
                    if task.arg_type_hints[arg_index].requires_accelerator():
                        type_hint = "Accelerator"
                    else:
                        type_hint = type(task.arg_type_hints[arg_index]).__name__
                else:
                    type_hint = "UnknownType"

                adj_list[upstream_task_idx].append(idx)
                indegree[idx] += 1
                edge_info.append((upstream_task_idx, idx, type_hint))

    width_adjust = 0
    for upstream_task_idx, child_idx_list in adj_list.items():
        # Mark as multi-output if the node has more than one output path
        if len(child_idx_list) > 1:
            for child in child_idx_list:
                is_multi_output[child] = True
                child2parent[child] = upstream_task_idx
            width_adjust = max(width_adjust, len(child_idx_list))

    # Topological sort to determine layers
    layers = defaultdict(list)
    zero_indegree = deque([idx for idx in self.idx_to_task if indegree[idx] == 0])
    layer_index = 0

    while zero_indegree:
        next_layer = deque()
        while zero_indegree:
            task_idx = zero_indegree.popleft()
            layers[layer_index].append(task_idx)
            for downstream in adj_list[task_idx]:
                indegree[downstream] -= 1
                if indegree[downstream] == 0:
                    next_layer.append(downstream)
        zero_indegree = next_layer
        layer_index += 1

    # Print detailed node information
    ascii_visualization += "Nodes Information:\n"
    for idx, info in node_info.items():
        ascii_visualization += f'{idx} [label="{info}"] \n'

    # Print edges
    ascii_visualization += "\nEdges Information:\n"
    for upstream_task, downstream_task, type_hint in edge_info:
        if type_hint == "Accelerator":
            edge_channel = "+++"
        else:
            edge_channel = "---"
        ascii_visualization += (
            f"{upstream_task} {edge_channel}>" f" {downstream_task}\n"
        )

    # Add the legend to the output
    ascii_visualization += "\nLegend:\n"
    ascii_visualization += "+++> : Represents Accelerator-type data channels\n"
    ascii_visualization += "---> : Represents Shared Memory data channels\n"

    # Find the maximum width (number of nodes in any layer)
    max_width = max(len(layer) for layer in layers.values()) + width_adjust
    height = len(layers)

    # Column of every task within its own layer. A multi-output child is
    # shifted right by its parent's column. The parent is not necessarily in
    # the layer directly above the child: for example, a node that feeds both a
    # task and that task's downstream task. So the parent is looked up
    # wherever it lives. The previous `layers[layer_num - 1].index(parent)`
    # raised ValueError in that case.
    col_in_layer: Dict[int, int] = {
        task_idx: col_num
        for layer_tasks in layers.values()
        for col_num, task_idx in enumerate(layer_tasks)
    }

    # First pass: compute each node's grid label and (row, column) position.
    task_to_pos: Dict[int, Tuple[int, int]] = {}
    task_to_label: Dict[int, str] = {}
    for layer_num, layer_tasks in layers.items():
        layer_y = layer_num * 2  # Every second row is for nodes
        for col_num, task_idx in enumerate(layer_tasks):
            task = self.idx_to_task[task_idx]
            task_info = f"{task_idx}:"

            # Determine if it's an actor method or a regular task
            if isinstance(task.dag_node, ClassMethodNode):
                if task.dag_node.is_class_method_call:
                    method_name = task.dag_node.get_method_name()
                    actor_handle = task.dag_node._get_actor_handle()
                    actor_id = (
                        actor_handle._actor_id.hex()[:6]
                        if actor_handle
                        else "unknown"
                    )
                    task_info += f"Actor_{actor_id}:{method_name}"
                elif task.dag_node.is_class_method_output:
                    task_info += f"Output[{task.dag_node.output_idx}]"
                else:
                    task_info += "UnknownMethod"
            else:
                task_info += type(task.dag_node).__name__

            adjust_col_num = 0
            if task_idx in is_multi_output:
                adjust_col_num = col_in_layer.get(child2parent[task_idx], 0)
            col_x = (col_num + adjust_col_num) * 30  # Every 30th column for spacing
            task_to_label[task_idx] = task_info
            task_to_pos[task_idx] = (layer_y, col_x)

    # Build grid for ASCII visualization. Nodes are placed 30 columns apart,
    # but only `max_width * 20` columns were reserved. That let edge and arrow
    # drawing index past the end of a row (IndexError). Every edge glyph is
    # drawn at or between node columns, and every label is at least 2 chars
    # ("N:"), so a grid wide enough for all labels fits every glyph as well.
    grid_width = max(
        max_width * 20,
        max(
            col_x + len(task_to_label[task_idx])
            for task_idx, (_, col_x) in task_to_pos.items()
        ),
    )
    grid = [[" " for _ in range(grid_width)] for _ in range(height * 2 - 1)]

    # Place the task information into the grid
    for task_idx, (layer_y, col_x) in task_to_pos.items():
        for i, char in enumerate(task_to_label[task_idx]):
            grid[layer_y][col_x + i] = char

    # Connect the nodes with lines
    for upstream_task, downstream_tasks in adj_list.items():
        upstream_y, upstream_x = task_to_pos[upstream_task]
        for downstream_task in downstream_tasks:
            downstream_y, downstream_x = task_to_pos[downstream_task]

            # Draw vertical line
            for y in range(upstream_y + 1, downstream_y):
                if grid[y][upstream_x] == " ":
                    grid[y][upstream_x] = "|"

            # Draw horizontal line with directional arrows
            if upstream_x != downstream_x:
                for x in range(
                    min(upstream_x, downstream_x) + 1,
                    max(upstream_x, downstream_x),
                ):
                    grid[downstream_y - 1][x] = (
                        "-"
                        if grid[downstream_y - 1][x] == " "
                        else grid[downstream_y - 1][x]
                    )

                # Add arrows to indicate flow direction
                if downstream_x > upstream_x:
                    grid[downstream_y - 1][downstream_x - 1] = ">"
                else:
                    grid[downstream_y - 1][downstream_x + 1] = "<"

            # Draw connection to the next task
            grid[downstream_y - 1][downstream_x] = "|"

    # Ensure proper multi-output task connection
    for idx, task in self.idx_to_task.items():
        if isinstance(task.dag_node, MultiOutputNode):
            output_tasks = task.dag_node.get_args()
            for output_task in output_tasks:
                if isinstance(output_task, DAGNode):
                    output_task_idx = self.dag_node_to_idx[output_task]
                    if output_task_idx in task_to_pos:
                        output_y, output_x = task_to_pos[output_task_idx]
                        grid[output_y - 1][output_x] = "|"

    # Convert grid to string for printing
    ascii_visualization += "\nGraph Built:\n"
    ascii_visualization += "\n".join("".join(row) for row in grid)

    return ascii_visualization
def get_channel_details(
    self, channel: ChannelInterface, downstream_actor_id: str
) -> str:
    """
    Describe a channel for graph visualization, one line per channel layer.

    The first line is the type name of ``channel`` itself.

    If ``self._channel_dict`` maps ``channel`` to a different channel, that
    outer channel's type name is appended on its own line. A CachedChannel
    also gets a truncated channel id.

    If the (possibly replaced) channel is a CompositeChannel with an entry for
    ``downstream_actor_id``, the inner channel's type name is appended. An
    IntraProcessChannel also gets a truncated channel id.

    Args:
        channel: The channel to describe.
        downstream_actor_id: Hex id of the actor reading from the channel.

    Returns:
        Newline-separated channel details.
    """
    detail_lines = [type(channel).__name__]

    # Outer channel, if this one is mapped to a different channel.
    if channel in self._channel_dict and self._channel_dict[channel] != channel:
        channel = self._channel_dict[channel]
        outer_line = type(channel).__name__
        if type(channel) is CachedChannel:
            outer_line += f", {channel._channel_id[:6]}..."
        detail_lines.append(outer_line)

    # Inner channel that serves the given downstream actor.
    if (
        type(channel) is CompositeChannel
        and downstream_actor_id in channel._channel_dict
    ):
        inner = channel._channel_dict[downstream_actor_id]
        inner_line = type(inner).__name__
        if type(inner) is IntraProcessChannel:
            inner_line += f", {inner._channel_id[:6]}..."
        detail_lines.append(inner_line)

    return "\n".join(detail_lines)
def visualize(
    self,
    filename="compiled_graph",
    format="png",
    view=False,
    channel_details=False,
) -> str:
    """
    Render the compiled graph's tasks and the dependencies between them.

    Call this only **after** the graph has been compiled with
    `experimental_compile()`.

    Args:
        filename: Output file name (without extension) for Graphviz formats.
            Not used for the ASCII format.
        format: Output format, e.g. 'png', 'pdf' (rendered via Graphviz) or
            'ascii' (rendered as text).
        view: For Graphviz formats, open the rendered file with the default
            viewer. For the ASCII format, also print the visualization.
        channel_details: If True, label every edge with channel type and
            channel details. Not supported with the ASCII format.

    Returns:
        For Graphviz-based formats (e.g. 'png', 'pdf', 'jpeg'), the Graphviz
        DOT source of the compiled graph. For the ASCII format, the ASCII
        rendering, returned whether or not it was printed.

    Raises:
        ValueError: If the graph is empty or not properly compiled, or if
            `channel_details` is combined with the ASCII format.
        ImportError: If the `graphviz` package is not installed.
    """
    if format == "ascii":
        if channel_details:
            raise ValueError(
                "Parameters 'channel_details' are"
                " not compatible with 'ascii' format."
            )
        rendered = self._visualize_ascii()
        if view:
            print(rendered)
        return rendered

    try:
        import graphviz
    except ImportError:
        raise ImportError(
            "Please install graphviz to visualize the compiled graph. "
            "You can install it by running `pip install graphviz`."
        )
    from ray.dag import (
        ClassMethodNode,
        DAGNode,
        InputAttributeNode,
        InputNode,
        MultiOutputNode,
    )

    # The graph must be compiled, and every compiled task needs a DAG node.
    if not hasattr(self, "idx_to_task") or not self.idx_to_task:
        raise ValueError(
            "The DAG must be compiled before calling 'visualize()'. "
            "Please call 'experimental_compile()' first."
        )
    for idx, task in self.idx_to_task.items():
        if not hasattr(task, "dag_node") or not isinstance(task.dag_node, DAGNode):
            raise ValueError(
                f"Task at index {idx} does not have a valid 'dag_node'. "
                "Ensure that 'experimental_compile()' completed successfully."
            )

    dot = graphviz.Digraph(name="compiled_graph", format=format)

    # One color per actor, assigned on first sight. Colors between 24k and 40k
    # were tested as readable; others may be too dark, especially after
    # wrapping back around to 0.
    actor_id_to_color = defaultdict(
        lambda: f"#{((len(actor_id_to_color) * 2000 + 24000) % 0xFFFFFF):06X}"
    )

    def describe(dag_node):
        """Return (label suffix, shape, fill color) for one DAG node."""
        if isinstance(dag_node, InputNode):
            return "InputNode", "rectangle", "lightblue"
        if isinstance(dag_node, InputAttributeNode):
            return f"InputAttributeNode[{dag_node.key}]", "rectangle", "lightblue"
        if isinstance(dag_node, MultiOutputNode):
            return "MultiOutputNode", "rectangle", "yellow"
        if not isinstance(dag_node, ClassMethodNode):
            # Unexpected node type.
            return type(dag_node).__name__, "diamond", "red"
        if dag_node.is_class_method_call:
            method_name = dag_node.get_method_name()
            actor = dag_node._get_actor_handle()
            if not actor:
                return f"Method: {method_name}", "oval", "lightgreen"
            class_name = actor._ray_actor_creation_function_descriptor.class_name
            actor_id = actor._actor_id.hex()
            return (
                f"Actor: {class_name}\n"
                f"ID: {actor_id[:6]}...\n"
                f"Method: {method_name}",
                "oval",
                actor_id_to_color[actor_id],
            )
        if dag_node.is_class_method_output:
            return f"ClassMethodOutputNode[{dag_node.output_idx}]", "rectangle", "orange"
        # Unexpected ClassMethodNode.
        return "ClassMethodNode", "diamond", "red"

    for idx, task in self.idx_to_task.items():
        dag_node = task.dag_node
        label_suffix, shape, fillcolor = describe(dag_node)
        dot.node(
            str(idx),
            f"Task {idx}\n{label_suffix}",
            shape=shape,
            style="filled",
            fillcolor=fillcolor,
        )

        channel_type_str = None
        if channel_details:
            channel_type_str = (
                type(dag_node.type_hint).__name__
                if dag_node.type_hint
                else "UnknownType"
            ) + "\n"

        # Assumes multiple output channels exist only for tasks with
        # multiple returns.
        outputs = task.output_channels
        if len(outputs) == 1:
            # Single output: one edge per downstream node.
            for downstream_node in dag_node._downstream_nodes:
                downstream_idx = self.dag_node_to_idx[downstream_node]
                edge_label = None
                if channel_details:
                    if type(downstream_node) is ClassMethodNode:
                        reader_id = downstream_node._get_actor_handle()._actor_id.hex()
                    else:
                        reader_id = self._proxy_actor._actor_id.hex()
                    edge_label = channel_type_str + self.get_channel_details(
                        outputs[0], reader_id
                    )
                dot.edge(str(idx), str(downstream_idx), label=edge_label)
        elif len(outputs) > 1:
            # Multiple returns: each output channel feeds one output node.
            assert len(task.output_idxs) == len(outputs)
            for output_channel, downstream_idx in zip(
                outputs, task.output_node_idxs
            ):
                edge_label = None
                if channel_details:
                    edge_label = channel_type_str + self.get_channel_details(
                        output_channel,
                        dag_node._get_actor_handle()._actor_id.hex(),
                    )
                dot.edge(str(idx), str(downstream_idx), label=edge_label)

        if type(dag_node) is InputAttributeNode:
            # Link each InputAttributeNode back to the InputNode.
            dot.edge(str(self.input_task_idx), str(idx))

    dot.render(filename, view=view)
    return dot.source
def _register_input_output_custom_serializer(self):
    """
    Register the custom serializers declared by the graph's boundary nodes.

    Calls ``register_custom_serializer()`` on these type hints, in order:
      1. the input node's type hint;
      2. each input attribute node's type hint;
      3. the type hint of each argument of the output task.

    Requires ``input_task_idx`` and ``output_task_idx`` to be set. This is
    checked with asserts.
    """
    assert self.input_task_idx is not None
    assert self.output_task_idx is not None

    def register_for_task(task_idx):
        self.idx_to_task[task_idx].dag_node.type_hint.register_custom_serializer()

    register_for_task(self.input_task_idx)
    for attr_task_idx in self.input_attr_task_idxs:
        register_for_task(attr_task_idx)

    output_task = self.idx_to_task[self.output_task_idx]
    for output_arg in output_task.args:
        output_arg.type_hint.register_custom_serializer()
- def teardown(self, kill_actors: bool = False):
- """
- Teardown and cancel all actor tasks for this DAG. After this
- function returns, the actors should be available to execute new tasks
- or compile a new DAG.
- Note: This method is automatically called when the CompiledDAG is destructed
- or the script exits. However, this should be explicitly called before compiling
- another graph on the same actors. Python may not garbage collect the
- CompiledDAG object immediately when you may expect.
- """
- if self._is_teardown:
- return
- monitor = getattr(self, "_monitor", None)
- if monitor is not None:
- from ray.dag import DAGContext
- ctx = DAGContext.get_current()
- monitor.teardown(kill_actors=kill_actors)
- monitor.join(timeout=ctx.teardown_timeout)
- # We do not log a warning here if the thread is still alive because
- # wait_teardown already logs upon teardown_timeout.
- self._is_teardown = True
def __del__(self):
    """Tear the DAG down (without killing actors) when the object is destructed."""
    self.teardown(kill_actors=False)
@DeveloperAPI
def build_compiled_dag_from_ray_dag(
    dag: "ray.dag.DAGNode",
    submit_timeout: Optional[float] = None,
    buffer_size_bytes: Optional[int] = None,
    enable_asyncio: bool = False,
    max_inflight_executions: Optional[int] = None,
    max_buffered_results: Optional[int] = None,
    overlap_gpu_communication: Optional[bool] = None,
    default_communicator: Optional[Union[Communicator, str]] = "create",
) -> "CompiledDAG":
    """Build, compile, and register a CompiledDAG for the given Ray DAG.

    Steps:
      1. Construct a CompiledDAG, forwarding all options positionally to its
         constructor.
      2. Add nodes via ``traverse_and_apply`` starting from the DAG's root
         (``dag._find_root()``).
      3. Compile the DAG.
      4. Record the result in the module-level ``_compiled_dags`` mapping,
         keyed by ``get_id()``.

    Returns:
        The compiled DAG.
    """
    compiled_dag = CompiledDAG(
        submit_timeout,
        buffer_size_bytes,
        enable_asyncio,
        max_inflight_executions,
        max_buffered_results,
        overlap_gpu_communication,
        default_communicator,
    )

    def _add_to_compiled_dag(node):
        compiled_dag._add_node(node)
        return node

    dag._find_root().traverse_and_apply(_add_to_compiled_dag)
    compiled_dag._get_or_compile()

    # Item assignment mutates the module-level registry; no rebinding needed.
    _compiled_dags[compiled_dag.get_id()] = compiled_dag
    return compiled_dag
|