import os
import tempfile
from typing import Optional

from ray.dag import DAGNode
from ray.dag.utils import _DAGNodeNameGenerator
from ray.util.annotations import DeveloperAPI


@DeveloperAPI
def plot(dag: DAGNode, to_file: Optional[str] = None):
    """Render ``dag`` as an image via pydot/graphviz.

    Args:
        dag: The DAG to visualize.
        to_file: Destination path. The output format is taken from the file
            extension (without the leading "."); ``"png"`` is used when the
            path has no usable extension. If ``None``, the graph is rendered
            to a PNG inside a temporary directory that is removed before this
            function returns.

    Returns:
        An ``IPython.display.Image`` of the rendered file when IPython can be
        imported, otherwise ``None``.

    Raises:
        ImportError: If pydot or graphviz is not available.
    """
    if to_file is not None:
        _, extension = os.path.splitext(to_file)
        # ``splitext`` keeps the leading "."; strip it and fall back to PNG
        # when nothing remains (no extension, or a path ending in ".").
        return _write_and_display(dag, to_file, extension[1:] or "png")

    # Render into a fresh temporary directory rather than through a
    # NamedTemporaryFile: graphviz has to open the path for writing itself,
    # and on Windows a NamedTemporaryFile cannot be reopened by name while it
    # is still open. The directory (and file) is cleaned up on exit.
    with tempfile.TemporaryDirectory() as tmp_dir:
        return _write_and_display(dag, os.path.join(tmp_dir, "dag.png"), "png")


def _write_and_display(dag: DAGNode, path: str, extension: str):
    """Write ``dag`` to ``path`` in ``extension`` format and return an IPython
    image of it when IPython is importable (``None`` otherwise)."""
    graph = _dag_to_dot(dag)
    graph.write(path, format=extension)

    # Render the image directly if running inside a Jupyter notebook.
    try:
        from IPython import display
    except ImportError:
        return None
    # NOTE(review): this relies on ``display.Image(filename=...)`` embedding
    # the file data at construction (its documented default when no ``url``
    # is given), because a temporary file is deleted right after this returns.
    return display.Image(filename=path)


def _check_pydot_and_graphviz():
    """Check if pydot and graphviz are installed.

    pydot and graphviz are required for plotting. We check this during
    runtime rather than adding them to Ray dependencies.

    Raises:
        ImportError: If pydot cannot be imported, or if rendering an empty
            graph through pydot fails (graphviz not usable).
    """
    try:
        import pydot
    except ImportError as e:
        raise ImportError(
            "pydot is required to plot DAG, install it with `pip install pydot`."
        ) from e
    try:
        # Rendering an empty graph makes pydot invoke graphviz; a failure here
        # means graphviz is missing or not runnable.
        pydot.Dot.create(pydot.Dot())
    except (OSError, pydot.InvocationException) as e:
        raise ImportError(
            "graphviz is required to plot DAG, "
            "download it from https://graphviz.gitlab.io/download/"
        ) from e


def _get_nodes_and_edges(dag: DAGNode):
    """Get all nodes and edges in the DAG.

    Walks the DAG with ``DAGNode.apply_recursive``. Uniqueness of the
    collected nodes depends on ``apply_recursive`` visiting each node once
    (memoized traversal). Nodes are used to generate unique names, while edges
    are used to construct the graph.

    Returns:
        A ``(nodes, edges)`` tuple. ``nodes`` is a list of nodes in visit order.
        ``edges`` is a list of ``(child_node, node)`` pairs, one per node
        returned by ``node._get_all_child_nodes()``; each edge points from the
        child node to the node that references it.
    """
    edges = []
    nodes = []

    def _dfs(node):
        nodes.append(node)
        edges.extend((child_node, node) for child_node in node._get_all_child_nodes())
        # apply_recursive uses the return value as the node's result; return
        # the node unchanged so the traversal is read-only.
        return node

    dag.apply_recursive(_dfs)
    return nodes, edges


def _dag_to_dot(dag: DAGNode):
    """Create a Dot graph from dag.

    TODO(lchu):
    1. add more Dot configs in kwargs,
    e.g. rankdir, alignment, etc.
    2. add more contents to graph,
    e.g. args, kwargs and options of each node

    Returns:
        A left-to-right ``pydot.Dot`` graph with one edge per DAG edge.

    Raises:
        ImportError: If pydot or graphviz is not available.
    """
    # Step 0: check dependencies and init graph
    _check_pydot_and_graphviz()
    import pydot

    graph = pydot.Dot(rankdir="LR")

    # Step 1: generate unique name for each node in dag
    nodes, edges = _get_nodes_and_edges(dag)
    name_generator = _DAGNodeNameGenerator()
    node_names = {node: name_generator.get_node_name(node) for node in nodes}

    # Step 2: create graph with all the edges; pydot creates the endpoint
    # nodes implicitly from the edge names.
    for src, dst in edges:
        graph.add_edge(pydot.Edge(node_names[src], node_names[dst]))

    # A single-node DAG has no edges, so its node must be added explicitly
    # for it to appear in the output.
    if len(nodes) == 1 and not edges:
        graph.add_node(pydot.Node(node_names[nodes[0]]))

    return graph