vis_utils.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. import os
  2. import tempfile
  3. from ray.dag import DAGNode
  4. from ray.dag.utils import _DAGNodeNameGenerator
  5. from ray.util.annotations import DeveloperAPI
  6. @DeveloperAPI
  7. def plot(dag: DAGNode, to_file=None):
  8. if to_file is None:
  9. tmp_file = tempfile.NamedTemporaryFile(suffix=".png")
  10. to_file = tmp_file.name
  11. extension = "png"
  12. else:
  13. _, extension = os.path.splitext(to_file)
  14. if not extension:
  15. extension = "png"
  16. else:
  17. extension = extension[1:]
  18. graph = _dag_to_dot(dag)
  19. graph.write(to_file, format=extension)
  20. # Render the image directly if running inside a Jupyter notebook
  21. try:
  22. from IPython import display
  23. return display.Image(filename=to_file)
  24. except ImportError:
  25. pass
  26. # close temp file if needed
  27. try:
  28. tmp_file.close()
  29. except NameError:
  30. pass
  31. def _check_pydot_and_graphviz():
  32. """Check if pydot and graphviz are installed.
  33. pydot and graphviz are required for plotting. We check this
  34. during runtime rather than adding them to Ray dependencies.
  35. """
  36. try:
  37. import pydot
  38. except ImportError:
  39. raise ImportError(
  40. "pydot is required to plot DAG, install it with `pip install pydot`."
  41. )
  42. try:
  43. pydot.Dot.create(pydot.Dot())
  44. except (OSError, pydot.InvocationException):
  45. raise ImportError(
  46. "graphviz is required to plot DAG, "
  47. "download it from https://graphviz.gitlab.io/download/"
  48. )
  49. def _get_nodes_and_edges(dag: DAGNode):
  50. """Get all unique nodes and edges in the DAG.
  51. A basic dfs with memoization to get all unique nodes
  52. and edges in the DAG.
  53. Unique nodes will be used to generate unique names,
  54. while edges will be used to construct the graph.
  55. """
  56. edges = []
  57. nodes = []
  58. def _dfs(node):
  59. nodes.append(node)
  60. for child_node in node._get_all_child_nodes():
  61. edges.append((child_node, node))
  62. return node
  63. dag.apply_recursive(_dfs)
  64. return nodes, edges
  65. def _dag_to_dot(dag: DAGNode):
  66. """Create a Dot graph from dag.
  67. TODO(lchu):
  68. 1. add more Dot configs in kwargs,
  69. e.g. rankdir, alignment, etc.
  70. 2. add more contents to graph,
  71. e.g. args, kwargs and options of each node
  72. """
  73. # Step 0: check dependencies and init graph
  74. _check_pydot_and_graphviz()
  75. import pydot
  76. graph = pydot.Dot(rankdir="LR")
  77. # Step 1: generate unique name for each node in dag
  78. nodes, edges = _get_nodes_and_edges(dag)
  79. name_generator = _DAGNodeNameGenerator()
  80. node_names = {}
  81. for node in nodes:
  82. node_names[node] = name_generator.get_node_name(node)
  83. # Step 2: create graph with all the edges
  84. for edge in edges:
  85. graph.add_edge(pydot.Edge(node_names[edge[0]], node_names[edge[1]]))
  86. # if there is only one node
  87. if len(nodes) == 1 and len(edges) == 0:
  88. graph.add_node(pydot.Node(node_names[nodes[0]]))
  89. return graph