graph_drawer.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507
  1. # mypy: allow-untyped-defs
  2. import hashlib
  3. from itertools import chain
  4. from types import ModuleType
  5. from typing import Any, Optional, TYPE_CHECKING
  6. import torch
  7. import torch.fx
  8. from torch.fx._compatibility import compatibility
  9. from torch.fx.graph import _parse_stack_trace
  10. from torch.fx.node import _format_arg, _get_qualified_name
  11. from torch.fx.operator_schemas import normalize_function
  12. from torch.fx.passes.shape_prop import TensorMetadata
  13. if TYPE_CHECKING:
  14. import pydot
  15. HAS_PYDOT = True
  16. else:
  17. pydot: Optional[ModuleType]
  18. try:
  19. import pydot
  20. HAS_PYDOT = True
  21. except ModuleNotFoundError:
  22. HAS_PYDOT = False
  23. pydot = None
  24. __all__ = ["FxGraphDrawer"]
  25. _COLOR_MAP = {
  26. "placeholder": '"AliceBlue"',
  27. "call_module": "LemonChiffon1",
  28. "get_param": "Yellow2",
  29. "get_attr": "LightGrey",
  30. "output": "PowderBlue",
  31. }
  32. _HASH_COLOR_MAP = [
  33. "CadetBlue1",
  34. "Coral",
  35. "DarkOliveGreen1",
  36. "DarkSeaGreen1",
  37. "GhostWhite",
  38. "Khaki1",
  39. "LavenderBlush1",
  40. "LightSkyBlue",
  41. "MistyRose1",
  42. "MistyRose2",
  43. "PaleTurquoise2",
  44. "PeachPuff1",
  45. "Salmon",
  46. "Thistle1",
  47. "Thistle3",
  48. "Wheat1",
  49. ]
  50. _WEIGHT_TEMPLATE = {
  51. "fillcolor": "Salmon",
  52. "style": '"filled,rounded"',
  53. "fontcolor": "#000000",
  54. }
  55. if HAS_PYDOT:
  56. @compatibility(is_backward_compatible=False)
  57. class FxGraphDrawer:
  58. """
  59. Visualize a torch.fx.Graph with graphviz
  60. Basic usage:
  61. g = FxGraphDrawer(symbolic_traced, "resnet18")
  62. g.get_dot_graph().write_svg("a.svg")
  63. """
  def __init__(
      self,
      graph_module: torch.fx.GraphModule,
      name: str,
      ignore_getattr: bool = False,
      ignore_parameters_and_buffers: bool = False,
      skip_node_names_in_args: bool = True,
      parse_stack_trace: bool = False,
      dot_graph_shape: Optional[str] = None,
      normalize_args: bool = False,
  ):
      """
      Render ``graph_module`` and its directly called GraphModule submodules.

      One dot graph is built for ``graph_module`` itself (stored under
      ``name``) and one for every ``call_module`` node whose target resolves
      to a ``torch.fx.GraphModule`` (stored under ``f"{name}_{node.target}"``).

      Args:
          graph_module: the traced module to draw.
          name: name of the main graph; also the prefix of submodule graph names.
          ignore_getattr: forwarded to ``_to_dot``; skips ``get_attr`` nodes.
          ignore_parameters_and_buffers: forwarded to ``_to_dot``; skips the
              extra parameter/buffer nodes.
          skip_node_names_in_args: forwarded to ``_to_dot``.
          parse_stack_trace: forwarded to ``_to_dot``.
          dot_graph_shape: node shape to use; ``"record"`` when ``None``.
          normalize_args: stored on the instance and read when labels are built.
      """
      self._name = name
      self.dot_graph_shape = "record" if dot_graph_shape is None else dot_graph_shape
      self.normalize_args = normalize_args
      # The weight-node template is module-level state; it is updated with the
      # shape of the drawer that is currently being constructed.
      _WEIGHT_TEMPLATE["shape"] = self.dot_graph_shape

      # Options shared by every _to_dot call below, in positional order.
      render_options = (
          ignore_getattr,
          ignore_parameters_and_buffers,
          skip_node_names_in_args,
          parse_stack_trace,
      )

      self._dot_graphs = {name: self._to_dot(graph_module, name, *render_options)}

      for node in graph_module.graph.nodes:
          if node.op == "call_module":
              submodule = self._get_leaf_node(graph_module, node)
              if isinstance(submodule, torch.fx.GraphModule):
                  sub_name = f"{name}_{node.target}"
                  self._dot_graphs[sub_name] = self._to_dot(
                      submodule, sub_name, *render_options
                  )
  def get_dot_graph(self, submod_name=None) -> pydot.Dot:
      """
      Return the dot graph of the main module, or of one submodule.

      With ``submod_name`` left as ``None`` the main graph is returned;
      otherwise the graph of the submodule with that name is looked up.

      Example:
          >>> # xdoctest: +REQUIRES(module:pydot)
          >>> # xdoctest: +REQUIRES(module:ubelt)
          >>> # define module
          >>> class MyModule(torch.nn.Module):
          >>>     def __init__(self) -> None:
          >>>         super().__init__()
          >>>         self.linear = torch.nn.Linear(4, 5)
          >>>     def forward(self, x):
          >>>         return self.linear(x).clamp(min=0.0, max=1.0)
          >>> module = MyModule()
          >>> # trace the module
          >>> symbolic_traced = torch.fx.symbolic_trace(module)
          >>> # setup output file
          >>> import ubelt as ub
          >>> dpath = ub.Path.appdir("torch/tests/FxGraphDrawer").ensuredir()
          >>> fpath = dpath / "linear.svg"
          >>> # draw the graph
          >>> g = FxGraphDrawer(symbolic_traced, "linear")
          >>> g.get_dot_graph().write_svg(fpath)
      """
      if submod_name is not None:
          return self.get_submod_dot_graph(submod_name)
      return self.get_main_dot_graph()
  def get_main_dot_graph(self) -> pydot.Dot:
      """Return the dot graph stored under the name given to the constructor."""
      main_key = self._name
      return self._dot_graphs[main_key]
  def get_submod_dot_graph(self, submod_name) -> pydot.Dot:
      """
      Return the dot graph of the submodule called ``submod_name``.

      Submodule graphs are keyed as ``"<main name>_<submod_name>"``; an
      unknown name raises ``KeyError`` from the dictionary lookup.
      """
      key = f"{self._name}_{submod_name}"
      return self._dot_graphs[key]
  def get_all_dot_graphs(self) -> dict[str, pydot.Dot]:
      """Return the internal name -> dot graph mapping (the dict itself, not a copy)."""
      all_graphs = self._dot_graphs
      return all_graphs
  def _get_node_style(self, node: torch.fx.Node) -> dict[str, str]:
      """
      Return the dot style attributes (shape, colors, style) for ``node``.

      The fill color comes from ``_COLOR_MAP`` when the node's opcode is listed
      there; otherwise it is picked from ``_HASH_COLOR_MAP`` using an MD5 digest
      of the printed target, so equal targets always share a color.
      """
      if node.op in _COLOR_MAP:
          fill = _COLOR_MAP[node.op]
      else:
          # Not random: the color is derived from the target's printed name so it
          # is stable from one run to the next.
          target_name = node._pretty_print_target(node.target)
          digest = hashlib.md5(
              target_name.encode(), usedforsecurity=False
          ).hexdigest()
          palette_index = int(digest[:8], 16) % len(_HASH_COLOR_MAP)
          fill = _HASH_COLOR_MAP[palette_index]

      return {
          "shape": self.dot_graph_shape,
          "fillcolor": fill,
          "style": '"filled,rounded"',
          "fontcolor": "#000000",
      }
  161. def _get_leaf_node(
  162. self, module: torch.nn.Module, node: torch.fx.Node
  163. ) -> torch.nn.Module:
  164. py_obj = module
  165. if not isinstance(node.target, str):
  166. raise AssertionError(f"Expected str target, got {type(node.target)}")
  167. atoms = node.target.split(".")
  168. for atom in atoms:
  169. if not hasattr(py_obj, atom):
  170. raise RuntimeError(
  171. str(py_obj) + " does not have attribute " + atom + "!"
  172. )
  173. py_obj = getattr(py_obj, atom)
  174. return py_obj
  175. def _typename(self, target: Any) -> str:
  176. if isinstance(target, torch.nn.Module):
  177. ret = torch.typename(target)
  178. elif isinstance(target, str):
  179. ret = target
  180. else:
  181. ret = _get_qualified_name(target)
  182. # Escape "{" and "}" to prevent dot files like:
  183. # https://gist.github.com/SungMinCho/1a017aab662c75d805c5954d62c5aabc
  184. # which triggers `Error: bad label format (...)` from dot
  185. return ret.replace("{", r"\{").replace("}", r"\}")
  186. # shorten path to avoid drawing long boxes
  187. # for full path = '/home/weif/pytorch/test.py'
  188. # return short path = 'pytorch/test.py'
  189. def _shorten_file_name(
  190. self,
  191. full_file_name: str,
  192. truncate_to_last_n: int = 2,
  193. ):
  194. splits = full_file_name.split("/")
  195. if len(splits) >= truncate_to_last_n:
  196. return "/".join(splits[-truncate_to_last_n:])
  197. return full_file_name
  198. def _get_node_label(
  199. self,
  200. module: torch.fx.GraphModule,
  201. node: torch.fx.Node,
  202. skip_node_names_in_args: bool,
  203. parse_stack_trace: bool,
  204. ) -> str:
  205. def _get_str_for_args_kwargs(arg):
  206. if isinstance(arg, tuple):
  207. prefix, suffix = r"|args=(\l", r",\n)\l"
  208. arg_strs_list = [_format_arg(a, max_list_len=8) for a in arg]
  209. elif isinstance(arg, dict):
  210. prefix, suffix = r"|kwargs={\l", r",\n}\l"
  211. arg_strs_list = [
  212. f"{k}: {_format_arg(v, max_list_len=8)}" for k, v in arg.items()
  213. ]
  214. else: # Fall back to nothing in unexpected case.
  215. return ""
  216. # Strip out node names if requested.
  217. if skip_node_names_in_args:
  218. arg_strs_list = [a for a in arg_strs_list if "%" not in a]
  219. if len(arg_strs_list) == 0:
  220. return ""
  221. arg_strs = prefix + r",\n".join(arg_strs_list) + suffix
  222. if len(arg_strs_list) == 1:
  223. arg_strs = arg_strs.replace(r"\l", "").replace(r"\n", "")
  224. return arg_strs.replace("{", r"\{").replace("}", r"\}")
  225. label = "{" + f"name=%{node.name}|op_code={node.op}\n"
  226. if node.op == "call_module":
  227. leaf_module = self._get_leaf_node(module, node)
  228. label += r"\n" + self._typename(leaf_module) + r"\n|"
  229. extra = ""
  230. if hasattr(leaf_module, "__constants__"):
  231. extra = r"\n".join(
  232. [
  233. f"{c}: {getattr(leaf_module, c)}"
  234. for c in leaf_module.__constants__ # type: ignore[union-attr]
  235. ] # type: ignore[union-attr]
  236. )
  237. label += extra + r"\n"
  238. else:
  239. label += f"|target={self._typename(node.target)}" + r"\n"
  240. if self.normalize_args:
  241. try:
  242. args, kwargs = normalize_function( # type: ignore[misc]
  243. node.target, # type: ignore[arg-type]
  244. node.args, # type: ignore[arg-type]
  245. node.kwargs,
  246. normalize_to_only_use_kwargs=True,
  247. )
  248. except Exception:
  249. # Fallback to not normalizing if there's an exception.
  250. # Some functions need overloads specified to normalize.
  251. args, kwargs = node.args, node.kwargs
  252. else:
  253. args, kwargs = node.args, node.kwargs
  254. if len(args) > 0:
  255. label += _get_str_for_args_kwargs(args)
  256. if len(kwargs) > 0:
  257. label += _get_str_for_args_kwargs(kwargs)
  258. label += f"|num_users={len(node.users)}" + r"\n"
  259. tensor_meta = node.meta.get("tensor_meta")
  260. label += self._tensor_meta_to_label(tensor_meta)
  261. # for original fx graph
  262. # print buf=buf0, n_origin=6
  263. buf_meta = node.meta.get("buf_meta", None)
  264. if buf_meta is not None:
  265. label += f"|buf={buf_meta.name}" + r"\n"
  266. label += f"|n_origin={buf_meta.n_origin}" + r"\n"
  267. # for original fx graph
  268. # print file:lineno code
  269. if parse_stack_trace and node.stack_trace is not None:
  270. parsed_stack_trace = _parse_stack_trace(node.stack_trace)
  271. fname = self._shorten_file_name(parsed_stack_trace.file)
  272. label += (
  273. f"|file={fname}:{parsed_stack_trace.lineno} {parsed_stack_trace.code}"
  274. + r"\n"
  275. )
  276. return label + "}"
  277. def _tensor_meta_to_label(self, tm) -> str:
  278. if tm is None:
  279. return ""
  280. elif isinstance(tm, TensorMetadata):
  281. return self._stringify_tensor_meta(tm)
  282. elif isinstance(tm, list):
  283. result = ""
  284. for item in tm:
  285. result += self._tensor_meta_to_label(item)
  286. return result
  287. elif isinstance(tm, dict):
  288. result = ""
  289. for v in tm.values():
  290. result += self._tensor_meta_to_label(v)
  291. return result
  292. elif isinstance(tm, tuple):
  293. result = ""
  294. for item in tm:
  295. result += self._tensor_meta_to_label(item)
  296. return result
  297. else:
  298. raise RuntimeError(f"Unsupported tensor meta type {type(tm)}")
  299. def _stringify_tensor_meta(self, tm: TensorMetadata) -> str:
  300. result = ""
  301. if not hasattr(tm, "dtype"):
  302. print("tm", tm)
  303. result += "|" + "dtype" + "=" + str(tm.dtype) + r"\n"
  304. result += "|" + "shape" + "=" + str(tuple(tm.shape)) + r"\n"
  305. result += "|" + "requires_grad" + "=" + str(tm.requires_grad) + r"\n"
  306. result += "|" + "stride" + "=" + str(tm.stride) + r"\n"
  307. if tm.is_quantized:
  308. if tm.qparams is None:
  309. raise AssertionError("qparams is None for quantized tensor")
  310. if "qscheme" not in tm.qparams:
  311. raise AssertionError("qscheme not in qparams")
  312. qscheme = tm.qparams["qscheme"]
  313. if qscheme in {
  314. torch.per_tensor_affine,
  315. torch.per_tensor_symmetric,
  316. }:
  317. result += "|" + "q_scale" + "=" + str(tm.qparams["scale"]) + r"\n"
  318. result += (
  319. "|"
  320. + "q_zero_point"
  321. + "="
  322. + str(tm.qparams["zero_point"])
  323. + r"\n"
  324. )
  325. elif qscheme in {
  326. torch.per_channel_affine,
  327. torch.per_channel_symmetric,
  328. torch.per_channel_affine_float_qparams,
  329. }:
  330. result += (
  331. "|"
  332. + "q_per_channel_scale"
  333. + "="
  334. + str(tm.qparams["scale"])
  335. + r"\n"
  336. )
  337. result += (
  338. "|"
  339. + "q_per_channel_zero_point"
  340. + "="
  341. + str(tm.qparams["zero_point"])
  342. + r"\n"
  343. )
  344. result += (
  345. "|"
  346. + "q_per_channel_axis"
  347. + "="
  348. + str(tm.qparams["axis"])
  349. + r"\n"
  350. )
  351. else:
  352. raise RuntimeError(f"Unsupported qscheme: {qscheme}")
  353. result += "|" + "qscheme" + "=" + str(tm.qparams["qscheme"]) + r"\n"
  354. return result
  355. def _get_tensor_label(self, t: torch.Tensor) -> str:
  356. return str(t.dtype) + str(list(t.shape)) + r"\n"
  # When parse_stack_trace=True, each node label also carries a
  # "file:lineno code" field taken from the node's stack trace.
  def _to_dot(
      self,
      graph_module: torch.fx.GraphModule,
      name: str,
      ignore_getattr: bool,
      ignore_parameters_and_buffers: bool,
      skip_node_names_in_args: bool,
      parse_stack_trace: bool,
  ) -> pydot.Dot:
      """
      Actual interface to visualize a fx.Graph. Note that it takes in the
      GraphModule instead of the Graph.

      Every graph node becomes a dot node. Nodes whose ``buf_meta`` metadata
      reports more than one origin are grouped into a cluster per buffer name.
      For ``call_module`` nodes that are not GraphModules, one extra node per
      parameter and buffer is added with an edge into the module node. Finally
      an edge is added from every node to each of its users.

      Args:
          graph_module: the module whose graph is drawn.
          name: name of the resulting dot graph.
          ignore_getattr: if True, ``get_attr`` nodes and their outgoing edges
              are omitted.
          ignore_parameters_and_buffers: if True, the parameters and buffers
              created with the module will not be added as nodes and edges.
          skip_node_names_in_args: forwarded to ``_get_node_label``.
          parse_stack_trace: forwarded to ``_get_node_label``.

      Returns:
          The populated ``pydot.Dot`` graph.
      """
      # "TB" means top-to-bottom rank direction in layout
      dot_graph = pydot.Dot(name, rankdir="TB")

      # Style of parameter/buffer nodes. A per-call copy carrying this drawer's
      # own shape is used, so the result does not depend on what the shared
      # module-level template currently holds.
      weight_style = {**_WEIGHT_TEMPLATE, "shape": self.dot_graph_shape}

      buf_name_to_subgraph = {}

      for node in graph_module.graph.nodes:
          if ignore_getattr and node.op == "get_attr":
              continue

          style = self._get_node_style(node)
          dot_node = pydot.Node(
              node.name,
              label=self._get_node_label(
                  graph_module, node, skip_node_names_in_args, parse_stack_trace
              ),
              **style,  # type: ignore[arg-type]
          )

          # Nodes belonging to a buffer with several origins go into that
          # buffer's cluster; everything else goes straight into the main graph.
          current_graph = dot_graph
          buf_meta = node.meta.get("buf_meta", None)
          if buf_meta is not None and buf_meta.n_origin > 1:
              buf_name = buf_meta.name
              if buf_name not in buf_name_to_subgraph:
                  buf_name_to_subgraph[buf_name] = pydot.Cluster(
                      buf_name, label=buf_name
                  )
              current_graph = buf_name_to_subgraph[buf_name]
          current_graph.add_node(dot_node)

          if node.op != "call_module":
              continue
          leaf_module = self._get_leaf_node(graph_module, node)
          if ignore_parameters_and_buffers or isinstance(
              leaf_module, torch.fx.GraphModule
          ):
              continue

          for pname, ptensor in chain(
              leaf_module.named_parameters(),
              leaf_module.named_buffers(),
          ):
              pname1 = node.name + "." + pname
              # Parenthesized so that only the kind is conditional: the name
              # prefix and the trailing "\l" apply to parameters and buffers
              # alike.
              kind = (
                  "parameter"
                  if isinstance(ptensor, torch.nn.Parameter)
                  else "buffer"
              )
              label1 = pname1 + "|op_code=get_" + kind + r"\l"
              dot_w_node = pydot.Node(
                  pname1,
                  label="{" + label1 + self._get_tensor_label(ptensor) + "}",
                  **weight_style,  # type: ignore[arg-type]
              )
              dot_graph.add_node(dot_w_node)
              dot_graph.add_edge(pydot.Edge(pname1, node.name))

      for subgraph in buf_name_to_subgraph.values():
          subgraph.set("color", "royalblue")
          subgraph.set("penwidth", "2")
          dot_graph.add_subgraph(subgraph)  # type: ignore[arg-type]

      for node in graph_module.graph.nodes:
          if ignore_getattr and node.op == "get_attr":
              continue
          for user in node.users:
              dot_graph.add_edge(pydot.Edge(node.name, user.name))

      return dot_graph
  433. else:
  434. if not TYPE_CHECKING:
  @compatibility(is_backward_compatible=False)
  class FxGraphDrawer:
      """Stand-in defined when pydot could not be imported; it cannot be constructed."""

      def __init__(
          self,
          graph_module: torch.fx.GraphModule,
          name: str,
          ignore_getattr: bool = False,
          ignore_parameters_and_buffers: bool = False,
          skip_node_names_in_args: bool = True,
          parse_stack_trace: bool = False,
          dot_graph_shape: Optional[str] = None,
          normalize_args: bool = False,
      ):
          # Same signature as the real drawer, so a call fails with the install
          # hint below rather than with an argument error.
          message = (
              "FXGraphDrawer requires the pydot package to be installed. Please install "
              "pydot through your favorite Python package manager."
          )
          raise RuntimeError(message)