| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507 |
- # mypy: allow-untyped-defs
- import hashlib
- from itertools import chain
- from types import ModuleType
- from typing import Any, Optional, TYPE_CHECKING
- import torch
- import torch.fx
- from torch.fx._compatibility import compatibility
- from torch.fx.graph import _parse_stack_trace
- from torch.fx.node import _format_arg, _get_qualified_name
- from torch.fx.operator_schemas import normalize_function
- from torch.fx.passes.shape_prop import TensorMetadata
- if TYPE_CHECKING:
- import pydot
- HAS_PYDOT = True
- else:
- pydot: Optional[ModuleType]
- try:
- import pydot
- HAS_PYDOT = True
- except ModuleNotFoundError:
- HAS_PYDOT = False
- pydot = None
- __all__ = ["FxGraphDrawer"]
# Fill color per ``node.op``; ops that are missing here get a color picked from
# ``_HASH_COLOR_MAP`` instead (see ``FxGraphDrawer._get_node_style``).
_COLOR_MAP: dict[str, str] = {
    "placeholder": '"AliceBlue"',
    "call_module": "LemonChiffon1",
    "get_param": "Yellow2",
    "get_attr": "LightGrey",
    "output": "PowderBlue",
}
# Palette indexed by ``hash(target name) % len(palette)`` so that a given
# target always receives the same color.
_HASH_COLOR_MAP: list[str] = [
    "CadetBlue1",
    "Coral",
    "DarkOliveGreen1",
    "DarkSeaGreen1",
    "GhostWhite",
    "Khaki1",
    "LavenderBlush1",
    "LightSkyBlue",
    "MistyRose1",
    "MistyRose2",
    "PaleTurquoise2",
    "PeachPuff1",
    "Salmon",
    "Thistle1",
    "Thistle3",
    "Wheat1",
]
# Style attributes applied to the parameter/buffer nodes created in
# ``FxGraphDrawer._to_dot``. NOTE: ``FxGraphDrawer.__init__`` additionally
# writes a "shape" key into this module-level dict.
_WEIGHT_TEMPLATE: dict[str, str] = {
    "fillcolor": "Salmon",
    "style": '"filled,rounded"',
    "fontcolor": "#000000",
}
- if HAS_PYDOT:
- @compatibility(is_backward_compatible=False)
- class FxGraphDrawer:
- """
- Visualize a torch.fx.Graph with graphviz
- Basic usage:
- g = FxGraphDrawer(symbolic_traced, "resnet18")
- g.get_dot_graph().write_svg("a.svg")
- """
def __init__(
    self,
    graph_module: torch.fx.GraphModule,
    name: str,
    ignore_getattr: bool = False,
    ignore_parameters_and_buffers: bool = False,
    skip_node_names_in_args: bool = True,
    parse_stack_trace: bool = False,
    dot_graph_shape: Optional[str] = None,
    normalize_args: bool = False,
):
    """
    Render ``graph_module`` and each of its GraphModule submodules to dot graphs.

    Args:
        graph_module: module whose graph is drawn.
        name: name of the main graph; also the prefix of every submodule graph key.
        ignore_getattr: skip ``get_attr`` nodes (and the edges leaving them).
        ignore_parameters_and_buffers: do not draw parameter/buffer nodes of
            ``call_module`` targets.
        skip_node_names_in_args: drop args/kwargs whose text contains ``%``.
        parse_stack_trace: append ``file:lineno code`` taken from ``node.stack_trace``.
        dot_graph_shape: node shape; ``None`` selects ``"record"``.
        normalize_args: try ``normalize_function`` on args/kwargs before printing.
    """
    self._name = name
    self.dot_graph_shape = "record" if dot_graph_shape is None else dot_graph_shape
    self.normalize_args = normalize_args
    # NOTE: this writes into the module-level template, so the shape chosen by
    # the most recently constructed drawer is what that dict holds.
    _WEIGHT_TEMPLATE["shape"] = self.dot_graph_shape

    # The same rendering switches are forwarded to every _to_dot call.
    render_options = (
        ignore_getattr,
        ignore_parameters_and_buffers,
        skip_node_names_in_args,
        parse_stack_trace,
    )
    self._dot_graphs = {name: self._to_dot(graph_module, name, *render_options)}

    for node in graph_module.graph.nodes:
        if node.op == "call_module":
            submodule = self._get_leaf_node(graph_module, node)
            # Only nested GraphModules get a graph of their own.
            if isinstance(submodule, torch.fx.GraphModule):
                graph_key = f"{name}_{node.target}"
                self._dot_graphs[graph_key] = self._to_dot(
                    submodule, graph_key, *render_options
                )
def get_dot_graph(self, submod_name=None) -> pydot.Dot:
    """
    Return the dot graph of the main module, or of one submodule.

    With ``submod_name`` left as ``None`` the main graph is returned; otherwise
    the graph stored for that submodule is looked up.

    Example:
        >>> # xdoctest: +REQUIRES(module:pydot)
        >>> # xdoctest: +REQUIRES(module:ubelt)
        >>> # define module
        >>> class MyModule(torch.nn.Module):
        >>>     def __init__(self) -> None:
        >>>         super().__init__()
        >>>         self.linear = torch.nn.Linear(4, 5)
        >>>     def forward(self, x):
        >>>         return self.linear(x).clamp(min=0.0, max=1.0)
        >>> module = MyModule()
        >>> # trace the module
        >>> symbolic_traced = torch.fx.symbolic_trace(module)
        >>> # setup output file
        >>> import ubelt as ub
        >>> dpath = ub.Path.appdir("torch/tests/FxGraphDrawer").ensuredir()
        >>> fpath = dpath / "linear.svg"
        >>> # draw the graph
        >>> g = FxGraphDrawer(symbolic_traced, "linear")
        >>> g.get_dot_graph().write_svg(fpath)
    """
    if submod_name is not None:
        return self.get_submod_dot_graph(submod_name)
    return self.get_main_dot_graph()
def get_main_dot_graph(self) -> pydot.Dot:
    """Return the graph stored under the drawer's own name."""
    main_key = self._name
    return self._dot_graphs[main_key]
def get_submod_dot_graph(self, submod_name) -> pydot.Dot:
    """Return the graph stored under ``"<name>_<submod_name>"``; KeyError if absent."""
    submod_key = f"{self._name}_{submod_name}"
    return self._dot_graphs[submod_key]
def get_all_dot_graphs(self) -> dict[str, pydot.Dot]:
    """Return the internal name -> graph mapping itself (not a copy)."""
    graphs = self._dot_graphs
    return graphs
def _get_node_style(self, node: torch.fx.Node) -> dict[str, str]:
    """
    Build the style attributes for ``node``.

    The fill color comes from ``_COLOR_MAP`` when ``node.op`` is listed there;
    otherwise it is chosen from ``_HASH_COLOR_MAP`` by hashing the printed
    target name, so the same target is always drawn in the same color.
    """
    if node.op in _COLOR_MAP:
        fill = _COLOR_MAP[node.op]
    else:
        printed_target = node._pretty_print_target(node.target)
        digest = hashlib.md5(printed_target.encode(), usedforsecurity=False).hexdigest()
        # The first 8 hex digits are enough to spread targets over the palette.
        palette_index = int(digest[:8], 16) % len(_HASH_COLOR_MAP)
        fill = _HASH_COLOR_MAP[palette_index]

    return {
        "shape": self.dot_graph_shape,
        "fillcolor": fill,
        "style": '"filled,rounded"',
        "fontcolor": "#000000",
    }
- def _get_leaf_node(
- self, module: torch.nn.Module, node: torch.fx.Node
- ) -> torch.nn.Module:
- py_obj = module
- if not isinstance(node.target, str):
- raise AssertionError(f"Expected str target, got {type(node.target)}")
- atoms = node.target.split(".")
- for atom in atoms:
- if not hasattr(py_obj, atom):
- raise RuntimeError(
- str(py_obj) + " does not have attribute " + atom + "!"
- )
- py_obj = getattr(py_obj, atom)
- return py_obj
- def _typename(self, target: Any) -> str:
- if isinstance(target, torch.nn.Module):
- ret = torch.typename(target)
- elif isinstance(target, str):
- ret = target
- else:
- ret = _get_qualified_name(target)
- # Escape "{" and "}" to prevent dot files like:
- # https://gist.github.com/SungMinCho/1a017aab662c75d805c5954d62c5aabc
- # which triggers `Error: bad label format (...)` from dot
- return ret.replace("{", r"\{").replace("}", r"\}")
- # shorten path to avoid drawing long boxes
- # for full path = '/home/weif/pytorch/test.py'
- # return short path = 'pytorch/test.py'
- def _shorten_file_name(
- self,
- full_file_name: str,
- truncate_to_last_n: int = 2,
- ):
- splits = full_file_name.split("/")
- if len(splits) >= truncate_to_last_n:
- return "/".join(splits[-truncate_to_last_n:])
- return full_file_name
def _get_node_label(
    self,
    module: torch.fx.GraphModule,
    node: torch.fx.Node,
    skip_node_names_in_args: bool,
    parse_stack_trace: bool,
) -> str:
    """
    Assemble the brace-wrapped, ``|``-separated label text for ``node``.

    The label always carries the node name and op code. ``call_module`` nodes
    add the module type and its ``__constants__``; all other nodes add the
    target, the (optionally normalized) args/kwargs and the number of users.
    Tensor metadata, ``buf_meta`` and the parsed stack trace are appended when
    available or requested.
    """

    def render_arguments(values):
        # Tuples are printed as "args=(...)", dicts as "kwargs={...}".
        if isinstance(values, tuple):
            opener, closer = r"|args=(\l", r",\n)\l"
            rendered = [_format_arg(item, max_list_len=8) for item in values]
        elif isinstance(values, dict):
            opener, closer = r"|kwargs={\l", r",\n}\l"
            rendered = [
                f"{key}: {_format_arg(val, max_list_len=8)}"
                for key, val in values.items()
            ]
        else:  # Fall back to nothing in unexpected case.
            return ""

        # Strip out node names if requested.
        if skip_node_names_in_args:
            rendered = [text for text in rendered if "%" not in text]
        if not rendered:
            return ""

        joined = opener + r",\n".join(rendered) + closer
        if len(rendered) == 1:
            # A single entry is kept on one line: drop the line-break escapes.
            joined = joined.replace(r"\l", "").replace(r"\n", "")
        return joined.replace("{", r"\{").replace("}", r"\}")

    pieces = ["{" + f"name=%{node.name}|op_code={node.op}\n"]

    if node.op == "call_module":
        leaf_module = self._get_leaf_node(module, node)
        pieces.append(r"\n" + self._typename(leaf_module) + r"\n|")
        if hasattr(leaf_module, "__constants__"):
            constants_text = r"\n".join(
                f"{const}: {getattr(leaf_module, const)}"
                for const in leaf_module.__constants__  # type: ignore[union-attr]
            )
        else:
            constants_text = ""
        pieces.append(constants_text + r"\n")
    else:
        pieces.append(f"|target={self._typename(node.target)}" + r"\n")

        args, kwargs = node.args, node.kwargs
        if self.normalize_args:
            try:
                args, kwargs = normalize_function(  # type: ignore[misc]
                    node.target,  # type: ignore[arg-type]
                    node.args,  # type: ignore[arg-type]
                    node.kwargs,
                    normalize_to_only_use_kwargs=True,
                )
            except Exception:
                # Fallback to not normalizing if there's an exception.
                # Some functions need overloads specified to normalize.
                args, kwargs = node.args, node.kwargs

        if args:
            pieces.append(render_arguments(args))
        if kwargs:
            pieces.append(render_arguments(kwargs))
        pieces.append(f"|num_users={len(node.users)}" + r"\n")

    pieces.append(self._tensor_meta_to_label(node.meta.get("tensor_meta")))

    # for original fx graph
    # print buf=buf0, n_origin=6
    buf_meta = node.meta.get("buf_meta", None)
    if buf_meta is not None:
        pieces.append(f"|buf={buf_meta.name}" + r"\n")
        pieces.append(f"|n_origin={buf_meta.n_origin}" + r"\n")

    # for original fx graph
    # print file:lineno code
    if parse_stack_trace and node.stack_trace is not None:
        parsed = _parse_stack_trace(node.stack_trace)
        short_name = self._shorten_file_name(parsed.file)
        pieces.append(
            f"|file={short_name}:{parsed.lineno} {parsed.code}" + r"\n"
        )

    pieces.append("}")
    return "".join(pieces)
- def _tensor_meta_to_label(self, tm) -> str:
- if tm is None:
- return ""
- elif isinstance(tm, TensorMetadata):
- return self._stringify_tensor_meta(tm)
- elif isinstance(tm, list):
- result = ""
- for item in tm:
- result += self._tensor_meta_to_label(item)
- return result
- elif isinstance(tm, dict):
- result = ""
- for v in tm.values():
- result += self._tensor_meta_to_label(v)
- return result
- elif isinstance(tm, tuple):
- result = ""
- for item in tm:
- result += self._tensor_meta_to_label(item)
- return result
- else:
- raise RuntimeError(f"Unsupported tensor meta type {type(tm)}")
- def _stringify_tensor_meta(self, tm: TensorMetadata) -> str:
- result = ""
- if not hasattr(tm, "dtype"):
- print("tm", tm)
- result += "|" + "dtype" + "=" + str(tm.dtype) + r"\n"
- result += "|" + "shape" + "=" + str(tuple(tm.shape)) + r"\n"
- result += "|" + "requires_grad" + "=" + str(tm.requires_grad) + r"\n"
- result += "|" + "stride" + "=" + str(tm.stride) + r"\n"
- if tm.is_quantized:
- if tm.qparams is None:
- raise AssertionError("qparams is None for quantized tensor")
- if "qscheme" not in tm.qparams:
- raise AssertionError("qscheme not in qparams")
- qscheme = tm.qparams["qscheme"]
- if qscheme in {
- torch.per_tensor_affine,
- torch.per_tensor_symmetric,
- }:
- result += "|" + "q_scale" + "=" + str(tm.qparams["scale"]) + r"\n"
- result += (
- "|"
- + "q_zero_point"
- + "="
- + str(tm.qparams["zero_point"])
- + r"\n"
- )
- elif qscheme in {
- torch.per_channel_affine,
- torch.per_channel_symmetric,
- torch.per_channel_affine_float_qparams,
- }:
- result += (
- "|"
- + "q_per_channel_scale"
- + "="
- + str(tm.qparams["scale"])
- + r"\n"
- )
- result += (
- "|"
- + "q_per_channel_zero_point"
- + "="
- + str(tm.qparams["zero_point"])
- + r"\n"
- )
- result += (
- "|"
- + "q_per_channel_axis"
- + "="
- + str(tm.qparams["axis"])
- + r"\n"
- )
- else:
- raise RuntimeError(f"Unsupported qscheme: {qscheme}")
- result += "|" + "qscheme" + "=" + str(tm.qparams["qscheme"]) + r"\n"
- return result
- def _get_tensor_label(self, t: torch.Tensor) -> str:
- return str(t.dtype) + str(list(t.shape)) + r"\n"
- # when parse_stack_trace=True
- # print file:lineno code
def _to_dot(
    self,
    graph_module: torch.fx.GraphModule,
    name: str,
    ignore_getattr: bool,
    ignore_parameters_and_buffers: bool,
    skip_node_names_in_args: bool,
    parse_stack_trace: bool,
) -> pydot.Dot:
    """
    Actual interface to visualize a fx.Graph. Note that it takes in the GraphModule instead of the Graph.

    Args:
        graph_module: module whose ``graph.nodes`` are drawn.
        name: name given to the resulting ``pydot.Dot``.
        ignore_getattr: skip ``get_attr`` nodes and the edges leaving them.
        ignore_parameters_and_buffers: if True, the parameters and buffers
            created with the module will not be added as nodes and edges.
        skip_node_names_in_args: forwarded to ``_get_node_label``.
        parse_stack_trace: forwarded to ``_get_node_label``.

    Returns:
        The populated ``pydot.Dot``. Nodes whose ``buf_meta.n_origin`` exceeds 1
        are grouped into one cluster per buffer name.
    """
    # "TB" means top-to-bottom rank direction in layout
    dot_graph = pydot.Dot(name, rankdir="TB")

    # Per-call copy: take the shape from this drawer instead of relying on
    # whatever the shared module-level template currently holds.
    weight_style = {**_WEIGHT_TEMPLATE, "shape": self.dot_graph_shape}

    buf_name_to_subgraph = {}

    def add_params_and_buffers(owner_name, leaf_module):
        # One extra node per parameter/buffer, each with an edge into its owner.
        for pname, ptensor in chain(
            leaf_module.named_parameters(),
            # pyrefly: ignore [bad-argument-type]
            leaf_module.named_buffers(),
        ):
            qualified_name = owner_name + "." + pname
            # Parenthesized on purpose: the conditional must choose only the
            # suffix. Unparenthesized, it swallows the whole concatenation and
            # buffers end up labelled just "buffer".
            kind = "parameter" if isinstance(ptensor, torch.nn.Parameter) else "buffer"
            header = qualified_name + "|op_code=get_" + kind + r"\l"
            dot_w_node = pydot.Node(
                qualified_name,
                label="{" + header + self._get_tensor_label(ptensor) + "}",
                **weight_style,  # type: ignore[arg-type]
            )
            dot_graph.add_node(dot_w_node)
            dot_graph.add_edge(pydot.Edge(qualified_name, owner_name))

    for node in graph_module.graph.nodes:
        if ignore_getattr and node.op == "get_attr":
            continue

        style = self._get_node_style(node)
        dot_node = pydot.Node(
            node.name,
            label=self._get_node_label(
                graph_module, node, skip_node_names_in_args, parse_stack_trace
            ),
            **style,  # type: ignore[arg-type]
        )

        current_graph = dot_graph
        buf_meta = node.meta.get("buf_meta", None)
        if buf_meta is not None and buf_meta.n_origin > 1:
            buf_name = buf_meta.name
            if buf_name not in buf_name_to_subgraph:
                buf_name_to_subgraph[buf_name] = pydot.Cluster(
                    buf_name, label=buf_name
                )
            current_graph = buf_name_to_subgraph[buf_name]
        current_graph.add_node(dot_node)

        if node.op == "call_module":
            leaf_module = self._get_leaf_node(graph_module, node)
            # Nested GraphModules are drawn as graphs of their own, so their
            # parameters are not expanded here.
            if not ignore_parameters_and_buffers and not isinstance(
                leaf_module, torch.fx.GraphModule
            ):
                add_params_and_buffers(node.name, leaf_module)

    for subgraph in buf_name_to_subgraph.values():
        subgraph.set("color", "royalblue")
        subgraph.set("penwidth", "2")
        dot_graph.add_subgraph(subgraph)  # type: ignore[arg-type]

    for node in graph_module.graph.nodes:
        if ignore_getattr and node.op == "get_attr":
            continue
        for user in node.users:
            dot_graph.add_edge(pydot.Edge(node.name, user.name))

    return dot_graph
- else:
- if not TYPE_CHECKING:
@compatibility(is_backward_compatible=False)
class FxGraphDrawer:
    """Stand-in defined when pydot is not importable; constructing it always raises."""

    def __init__(
        self,
        graph_module: torch.fx.GraphModule,
        name: str,
        ignore_getattr: bool = False,
        ignore_parameters_and_buffers: bool = False,
        skip_node_names_in_args: bool = True,
        parse_stack_trace: bool = False,
        dot_graph_shape: Optional[str] = None,
        normalize_args: bool = False,
    ):
        # Same signature as the real drawer, so callers fail here with a clear
        # message rather than with a TypeError on the arguments.
        missing_pydot_message = (
            "FXGraphDrawer requires the pydot package to be installed. Please install "
            "pydot through your favorite Python package manager."
        )
        raise RuntimeError(missing_pydot_message)
|