graph.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. # Licensed under the GPL: https://www.gnu.org/licenses/old-licenses/gpl-2.0.html
  2. # For details: https://github.com/pylint-dev/pylint/blob/main/LICENSE
  3. # Copyright (c) https://github.com/pylint-dev/pylint/blob/main/CONTRIBUTORS.txt
  4. """Graph manipulation utilities.
  5. (dot generation adapted from pypy/translator/tool/make_dot.py)
  6. """
  7. from __future__ import annotations
  8. import os
  9. import shutil
  10. import subprocess
  11. import tempfile
  12. from collections.abc import Sequence
  13. from typing import Any
  14. def target_info_from_filename(filename: str) -> tuple[str, str, str]:
  15. """Transforms /some/path/foo.png into ('/some/path', 'foo.png', 'png')."""
  16. basename = os.path.basename(filename)
  17. storedir = os.path.dirname(os.path.abspath(filename))
  18. target = os.path.splitext(filename)[-1][1:]
  19. return storedir, basename, target
  20. class DotBackend:
  21. """Dot File back-end."""
  22. def __init__(
  23. self,
  24. graphname: str,
  25. rankdir: str | None = None,
  26. size: Any = None,
  27. ratio: Any = None,
  28. charset: str = "utf-8",
  29. renderer: str = "dot",
  30. additional_param: dict[str, Any] | None = None,
  31. ) -> None:
  32. if additional_param is None:
  33. additional_param = {}
  34. self.graphname = graphname
  35. self.renderer = renderer
  36. self.lines: list[str] = []
  37. self._source: str | None = None
  38. self.emit(f"digraph {normalize_node_id(graphname)} {{")
  39. if rankdir:
  40. self.emit(f"rankdir={rankdir}")
  41. if ratio:
  42. self.emit(f"ratio={ratio}")
  43. if size:
  44. self.emit(f'size="{size}"')
  45. if charset:
  46. assert charset.lower() in {
  47. "utf-8",
  48. "iso-8859-1",
  49. "latin1",
  50. }, f"unsupported charset {charset}"
  51. self.emit(f'charset="{charset}"')
  52. for param in additional_param.items():
  53. self.emit("=".join(param))
  54. def get_source(self) -> str:
  55. """Returns self._source."""
  56. if self._source is None:
  57. self.emit("}\n")
  58. self._source = "\n".join(self.lines)
  59. del self.lines
  60. return self._source
  61. source = property(get_source)
  62. def generate(
  63. self, outputfile: str | None = None, mapfile: str | None = None
  64. ) -> str:
  65. """Generates a graph file.
  66. :param str outputfile: filename and path [defaults to graphname.png]
  67. :param str mapfile: filename and path
  68. :rtype: str
  69. :return: a path to the generated file
  70. :raises RuntimeError: if the executable for rendering was not found
  71. """
  72. # pylint: disable=duplicate-code
  73. graphviz_extensions = ("dot", "gv")
  74. name = self.graphname
  75. if outputfile is None:
  76. target = "png"
  77. pdot, dot_sourcepath = tempfile.mkstemp(".gv", name)
  78. ppng, outputfile = tempfile.mkstemp(".png", name)
  79. os.close(pdot)
  80. os.close(ppng)
  81. else:
  82. _, _, target = target_info_from_filename(outputfile)
  83. if not target:
  84. target = "png"
  85. outputfile = outputfile + "." + target
  86. if target not in graphviz_extensions:
  87. pdot, dot_sourcepath = tempfile.mkstemp(".gv", name)
  88. os.close(pdot)
  89. else:
  90. dot_sourcepath = outputfile
  91. with open(dot_sourcepath, "w", encoding="utf8") as file:
  92. file.write(self.source)
  93. if target not in graphviz_extensions:
  94. if shutil.which(self.renderer) is None:
  95. raise RuntimeError(
  96. f"Cannot generate `{outputfile}` because '{self.renderer}' "
  97. "executable not found. Install graphviz, or specify a `.gv` "
  98. "outputfile to produce the DOT source code."
  99. )
  100. if mapfile:
  101. subprocess.run(
  102. [
  103. self.renderer,
  104. "-Tcmapx",
  105. "-o",
  106. mapfile,
  107. "-T",
  108. target,
  109. dot_sourcepath,
  110. "-o",
  111. outputfile,
  112. ],
  113. check=True,
  114. )
  115. else:
  116. subprocess.run(
  117. [self.renderer, "-T", target, dot_sourcepath, "-o", outputfile],
  118. check=True,
  119. )
  120. os.unlink(dot_sourcepath)
  121. return outputfile
  122. def emit(self, line: str) -> None:
  123. """Adds <line> to final output."""
  124. self.lines.append(line)
  125. def emit_edge(self, name1: str, name2: str, **props: Any) -> None:
  126. """Emit an edge from <name1> to <name2>.
  127. For edge properties: see https://www.graphviz.org/doc/info/attrs.html
  128. """
  129. attrs = [f'{prop}="{value}"' for prop, value in props.items()]
  130. n_from, n_to = normalize_node_id(name1), normalize_node_id(name2)
  131. self.emit(f"{n_from} -> {n_to} [{', '.join(sorted(attrs))}];")
  132. def emit_node(self, name: str, **props: Any) -> None:
  133. """Emit a node with given properties.
  134. For node properties: see https://www.graphviz.org/doc/info/attrs.html
  135. """
  136. attrs = [f'{prop}="{value}"' for prop, value in props.items()]
  137. self.emit(f"{normalize_node_id(name)} [{', '.join(sorted(attrs))}];")
  138. def normalize_node_id(nid: str) -> str:
  139. """Returns a suitable DOT node id for `nid`."""
  140. return f'"{nid}"'
  141. def get_cycles(
  142. graph_dict: dict[str, set[str]], vertices: list[str] | None = None
  143. ) -> Sequence[list[str]]:
  144. """Return a list of detected cycles based on an ordered graph (i.e. keys are
  145. vertices and values are lists of destination vertices representing edges).
  146. """
  147. if not graph_dict:
  148. return ()
  149. result: list[list[str]] = []
  150. if vertices is None:
  151. vertices = list(graph_dict.keys())
  152. for vertice in vertices:
  153. _get_cycles(graph_dict, [], set(), result, vertice)
  154. return result
  155. def _get_cycles(
  156. graph_dict: dict[str, set[str]],
  157. path: list[str],
  158. visited: set[str],
  159. result: list[list[str]],
  160. vertice: str,
  161. ) -> None:
  162. """Recursive function doing the real work for get_cycles."""
  163. if vertice in path:
  164. cycle = [vertice]
  165. for node in path[::-1]:
  166. if node == vertice:
  167. break
  168. cycle.insert(0, node)
  169. # make a canonical representation
  170. start_from = min(cycle)
  171. index = cycle.index(start_from)
  172. cycle = cycle[index:] + cycle[0:index]
  173. # append it to result if not already in
  174. if cycle not in result:
  175. result.append(cycle)
  176. return
  177. path.append(vertice)
  178. try:
  179. for node in graph_dict[vertice]:
  180. # don't check already visited nodes again
  181. if node not in visited:
  182. _get_cycles(graph_dict, path, visited, result, node)
  183. visited.add(node)
  184. except KeyError:
  185. pass
  186. path.pop()