| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114 |
- import os
- import tempfile
- from ray.dag import DAGNode
- from ray.dag.utils import _DAGNodeNameGenerator
- from ray.util.annotations import DeveloperAPI
@DeveloperAPI
def plot(dag: DAGNode, to_file=None):
    """Draw ``dag`` to an image file and return it for notebook display.

    Args:
        dag: Root node of the DAG to draw.
        to_file: Destination path. The output format is taken from the
            path's extension and falls back to "png" when the extension is
            missing or empty. When ``None``, the image is written to a
            temporary ".png" file that is closed before returning.

    Returns:
        An ``IPython.display.Image`` built from the written file when
        IPython can be imported, otherwise ``None``.

    Raises:
        ImportError: If pydot or graphviz is unavailable (raised while the
            Dot graph is being built).
    """
    # Explicit sentinel instead of probing for the name with NameError later.
    tmp_file = None
    if to_file is None:
        tmp_file = tempfile.NamedTemporaryFile(suffix=".png")
        to_file = tmp_file.name
        extension = "png"
    else:
        _, extension = os.path.splitext(to_file)
        # splitext keeps the leading dot; strip it and default to png when
        # nothing is left (no extension, or a path ending in a bare ".").
        extension = extension[1:] or "png"

    try:
        graph = _dag_to_dot(dag)
        graph.write(to_file, format=extension)

        # Render the image directly if running inside a Jupyter notebook
        try:
            from IPython import display

            # NOTE(review): Image is expected to load the file contents when
            # constructed (embedding is the default for local files), so
            # closing the temp file afterwards is safe -- confirm if the
            # embed default ever changes.
            return display.Image(filename=to_file)
        except ImportError:
            pass
        return None
    finally:
        # Close the temp file on every exit path: the IPython return, the
        # plain return, and any error raised while building/writing.
        if tmp_file is not None:
            tmp_file.close()
- def _check_pydot_and_graphviz():
- """Check if pydot and graphviz are installed.
- pydot and graphviz are required for plotting. We check this
- during runtime rather than adding them to Ray dependencies.
- """
- try:
- import pydot
- except ImportError:
- raise ImportError(
- "pydot is required to plot DAG, install it with `pip install pydot`."
- )
- try:
- pydot.Dot.create(pydot.Dot())
- except (OSError, pydot.InvocationException):
- raise ImportError(
- "graphviz is required to plot DAG, "
- "download it from https://graphviz.gitlab.io/download/"
- )
def _get_nodes_and_edges(dag: DAGNode):
    """Collect the nodes and edges of the DAG rooted at ``dag``.

    Traversal is delegated to ``dag.apply_recursive``. Every node the
    visitor is invoked on is recorded, along with one ``(child, parent)``
    pair for each child that node reports.

    The node list is used to generate unique names, and the edge list is
    used to construct the graph.

    Returns:
        A ``(nodes, edges)`` tuple of lists.
    """
    visited = []
    links = []

    def _visit(current):
        visited.append(current)
        # Edges point from the child towards the node that depends on it.
        links.extend(
            (child, current) for child in current._get_all_child_nodes()
        )
        return current

    dag.apply_recursive(_visit)
    return visited, links
def _dag_to_dot(dag: DAGNode):
    """Build a pydot ``Dot`` graph describing ``dag``.

    TODO(lchu):
    1. add more Dot configs in kwargs,
    e.g. rankdir, alignment, etc.
    2. add more contents to graph,
    e.g. args, kwargs and options of each node

    Returns:
        A ``pydot.Dot`` created with ``rankdir="LR"`` holding one edge per
        DAG edge, or a single node when the DAG has exactly one node and
        no edges.
    """
    # Validate the optional dependencies before touching pydot.
    _check_pydot_and_graphviz()
    import pydot

    dot = pydot.Dot(rankdir="LR")

    # Assign every DAG node a unique display name.
    all_nodes, all_edges = _get_nodes_and_edges(dag)
    namer = _DAGNodeNameGenerator()
    labels = {item: namer.get_node_name(item) for item in all_nodes}

    # Edges implicitly create their endpoint nodes in the Dot graph.
    for source, target in all_edges:
        dot.add_edge(pydot.Edge(labels[source], labels[target]))

    # A lone node has no edge to introduce it, so add it explicitly.
    if len(all_nodes) == 1 and not all_edges:
        dot.add_node(pydot.Node(labels[all_nodes[0]]))

    return dot
|