| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498 |
- # mypy: allow-untyped-defs
- import copy
- import logging
- import operator
- import time
- from collections import defaultdict
- from collections.abc import Iterable
- from enum import Enum
- from typing import Any, cast, Optional
- import torch
- import torch.fx as fx
- import torch.nn as nn
- import torch.nn.functional as F
- import torch.utils.mkldnn as th_mkldnn
- from torch.fx.node import Argument, Target
- from torch.fx.passes.shape_prop import ShapeProp
- from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_linear_bn_eval
# Names exported by ``from <module> import *``; this is the module's public API.
__all__ = [
    "matches_module_pattern",
    "replace_node_module",
    "fuse",
    "remove_dropout",
    "extract_subgraph",
    "modules_to_mkldnn",
    "reset_modules",
    "MklSubgraph",
    "gen_mkl_autotuner",
    "use_mkl_length",
    "UnionFind",
    "optimize_for_inference",
]
def _parent_name(target: str) -> tuple[str, str]:
    """
    Split a dotted qualified name into its parent path and its final component.

    For example ``foo.bar.baz`` -> (``foo.bar``, ``baz``); a name without any
    dot has an empty parent, e.g. ``baz`` -> (``""``, ``baz``).
    """
    # rpartition yields ("", "", target) when there is no separator, which
    # gives the empty-parent case for free.
    parent_path, _separator, leaf = target.rpartition(".")
    return parent_path, leaf
# Works for length 2 patterns with 2 modules
def matches_module_pattern(
    pattern: Iterable[type], node: fx.Node, modules: dict[str, Any]
) -> bool:
    """
    Check whether ``node`` and its first argument form a module-call pattern.

    The pair checked is ``(node.args[0], node)`` against the first two entries
    of ``pattern``. Each must be an ``fx.Node`` with op ``"call_module"``
    whose string target is a key of ``modules``, and the module it names must
    have exactly (not a subclass of) the expected type.

    Returns:
        ``True`` if both positions match, ``False`` otherwise (including when
        ``node`` has no positional arguments).
    """
    if not node.args:
        return False
    candidates: tuple[Any, fx.Node] = (node.args[0], node)
    for wanted_type, candidate in zip(pattern, candidates):
        is_known_module_call = (
            isinstance(candidate, fx.Node)
            and candidate.op == "call_module"
            and isinstance(candidate.target, str)
            and candidate.target in modules
        )
        # Short-circuit keeps the lookup below safe: it only runs once the
        # target is known to be a key of ``modules``.
        if not is_known_module_call or type(modules[candidate.target]) is not wanted_type:
            return False
    return True
def replace_node_module(
    node: fx.Node, modules: dict[str, Any], new_module: torch.nn.Module
):
    """
    Swap the module that ``node`` calls for ``new_module``.

    Both the ``modules`` table (keyed by qualified name) and the owning parent
    module (via ``setattr``) are updated, so the node's ``call_module`` target
    now resolves to ``new_module``.

    Raises:
        AssertionError: if ``node.target`` is not a string.
    """
    qualname = node.target
    if not isinstance(qualname, str):
        raise AssertionError(f"Expected str target, got {type(qualname)}")
    parent_path, attr_name = _parent_name(qualname)
    modules[qualname] = new_module
    setattr(modules[parent_path], attr_name, new_module)
def fuse(model: torch.nn.Module, inplace=False, no_trace=False) -> torch.nn.Module:
    """
    Fuses convolution/BN and linear/BN layers for inference purposes.

    Will deepcopy your model by default, but can modify the model inplace as well.

    A BatchNorm call is folded into the conv/linear call feeding it only when
    that conv/linear output has no other user and the BN tracks running stats.
    The folding itself is done by ``fuse_conv_bn_eval`` /
    ``fuse_linear_bn_eval``, which expect eval-mode modules.

    Args:
        model: The module to fuse.
        inplace: If ``False`` (default), work on a deep copy of ``model``.
        no_trace: If ``True`` and ``model`` is already an ``fx.GraphModule``,
            reuse it instead of re-tracing with ``fx.symbolic_trace``.

    Returns:
        A new ``fx.GraphModule`` built from the fused graph.
    """
    conv_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
    fusable_pairs = [
        (nn.Conv1d, nn.BatchNorm1d),
        (nn.Conv2d, nn.BatchNorm2d),
        (nn.Conv3d, nn.BatchNorm3d),
        (nn.Linear, nn.BatchNorm1d),
    ]
    if not inplace:
        model = copy.deepcopy(model)

    reuse_graph_module = no_trace and isinstance(model, torch.fx.GraphModule)
    fx_model = model if reuse_graph_module else fx.symbolic_trace(model)
    modules = dict(fx_model.named_modules())
    new_graph = copy.deepcopy(fx_model.graph)

    for pair in fusable_pairs:
        for bn_node in new_graph.nodes:
            if not matches_module_pattern(pair, bn_node, modules):
                continue
            producer = bn_node.args[0]
            # The conv/linear output is also consumed elsewhere, so the BN
            # cannot be folded into it without changing those consumers.
            if len(producer.users) > 1:
                continue
            bn = modules[bn_node.target]
            if not bn.track_running_stats:
                continue
            first_layer = modules[producer.target]
            fuse_fn = fuse_conv_bn_eval if pair[0] in conv_types else fuse_linear_bn_eval
            replace_node_module(producer, modules, fuse_fn(first_layer, bn))
            # Route the BN's consumers to the fused layer, then drop the BN.
            bn_node.replace_all_uses_with(producer)
            new_graph.erase_node(bn_node)
    return fx.GraphModule(fx_model, new_graph)
def remove_dropout(model: nn.Module) -> nn.Module:
    """
    Removes all dropout layers from the module.

    The model is symbolically traced, and every ``call_module`` to a dropout
    module is replaced by its input. The recognised dropout modules are
    ``nn.Dropout``, ``nn.Dropout1d``, ``nn.Dropout2d``, ``nn.Dropout3d``,
    ``nn.AlphaDropout`` and ``nn.FeatureAlphaDropout``, including subclasses.
    All of these are the identity in eval mode, so this is meant for
    inference. The returned module no longer applies dropout even if it is
    switched back to training mode.

    Returns:
        A new ``fx.GraphModule`` without the dropout calls.

    Raises:
        AssertionError: if a dropout module is called with anything other
            than exactly one positional argument.
    """
    # nn.Dropout alone misses the channel-wise and alpha variants, which do
    # not inherit from it.
    dropout_types = (
        nn.Dropout,
        nn.Dropout1d,
        nn.Dropout2d,
        nn.Dropout3d,
        nn.AlphaDropout,
        nn.FeatureAlphaDropout,
    )
    fx_model = fx.symbolic_trace(model)

    class DropoutRemover(torch.fx.Transformer):
        def call_module(
            self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
        ) -> Any:
            if isinstance(self.submodules[target], dropout_types):
                if len(args) != 1:
                    raise AssertionError(f"Expected 1 arg for Dropout, got {len(args)}")
                # Forward the input unchanged instead of emitting a node.
                return args[0]
            return super().call_module(target, args, kwargs)

    return DropoutRemover(fx_model).transform()
def extract_subgraph(
    orig_module: nn.Module,
    nodes: list[fx.Node],
    inputs: list[fx.Node],
    outputs: list[fx.Node],
):
    """
    Given lists of nodes from an existing graph that represent a subgraph,
    returns a submodule that executes that subgraph.

    Each entry of ``inputs`` becomes a placeholder with the same name. The
    ``nodes`` are copied in order with their arguments remapped into the new
    graph, and the copies of ``outputs`` are returned as a list. Attributes
    are resolved against ``orig_module``.

    Raises:
        KeyError: if a copied node refers to a node that is neither an input
            nor an earlier entry of ``nodes``.
    """
    sub_graph = fx.Graph()
    remap: dict[fx.Node, fx.Node] = {
        inp: sub_graph.placeholder(inp.name) for inp in inputs
    }
    for old_node in nodes:
        remap[old_node] = sub_graph.node_copy(old_node, remap.__getitem__)
    sub_graph.output([remap[out] for out in outputs])
    sub_graph.lint()
    return fx.GraphModule(orig_module, sub_graph)
# Ops (module types and functions) that optimize_for_inference always treats
# as MKLDNN nodes.
mkldnn_supported = [
    # Module types (matched on the exact type of a call_module target).
    nn.Conv2d,
    nn.Linear,
    nn.BatchNorm2d,
    nn.ReLU,
    nn.MaxPool2d,
    nn.AvgPool2d,
    nn.AdaptiveAvgPool2d,
    # Functions (matched on a call_function target).
    torch.relu,
    torch.transpose,
    torch.sigmoid,
    F.relu,
    F.avg_pool2d,
    F.adaptive_avg_pool2d,
]

# These are operators that may not be convertible into MKLDNN ops (e.g. the
# args are scalar values). Thus, we only include them in the subgraph if their
# arguments are already in MKLDNN.
# TODO: Determine whether this can be removed after type inference.
mkldnn_supported_unknown = [operator.add, operator.mul]


def _to_mkldnn_batchnorm(module, _dtype):
    """Build an MKLDNN batch norm from ``module``; the dtype argument is unused."""
    return th_mkldnn.MkldnnBatchNorm(module)


# Dense module type -> callable(module, dtype) returning its MKLDNN counterpart.
mkldnn_map = {
    nn.Conv2d: th_mkldnn.MkldnnConv2d,
    nn.Linear: th_mkldnn.MkldnnLinear,
    nn.BatchNorm2d: _to_mkldnn_batchnorm,
}
def modules_to_mkldnn(nodes: list[fx.Node], modules: dict[str, nn.Module]):
    """
    Pre-convert each convertible module called in ``nodes`` into MKLDNN.

    For every ``call_module`` node whose module type is a key of
    ``mkldnn_map``, the module is converted (with ``torch.float``) and
    swapped in via `replace_node_module`.

    Returns:
        A dict mapping each new MKLDNN module to a deep copy of the module it
        replaced, so `reset_modules` can undo the conversion.

    Raises:
        AssertionError: if a ``call_module`` node has a non-string target, or
            a converter returns something that is not an ``nn.Module``.
    """
    originals: dict[nn.Module, nn.Module] = {}
    module_calls = (n for n in nodes if n.op == "call_module")
    for node in module_calls:
        if not isinstance(node.target, str):
            raise AssertionError(f"Expected str target, got {type(node.target)}")
        dense_module = modules[node.target]
        converter = mkldnn_map.get(type(dense_module))
        if converter is None:
            continue
        # pyrefly: ignore [bad-index, index-error]
        mkldnn_module = converter(dense_module, torch.float)
        if not isinstance(mkldnn_module, nn.Module):
            raise AssertionError(f"Expected nn.Module, got {type(mkldnn_module)}")
        originals[mkldnn_module] = copy.deepcopy(dense_module)
        replace_node_module(node, modules, mkldnn_module)
    return originals
def reset_modules(
    nodes: list[fx.Node],
    modules: dict[str, nn.Module],
    old_modules: dict[nn.Module, nn.Module],
):
    """
    Undo `modules_to_mkldnn` for the given nodes.

    Every ``call_module`` node whose current module is a key of
    ``old_modules`` gets that entry's value swapped back in via
    `replace_node_module`. Other nodes are left alone.

    Raises:
        AssertionError: if a ``call_module`` node has a non-string target.
    """
    for node in nodes:
        if node.op != "call_module":
            continue
        if not isinstance(node.target, str):
            raise AssertionError(f"Expected str target, got {type(node.target)}")
        current = modules[node.target]
        if current not in old_modules:
            continue
        replace_node_module(node, modules, old_modules[current])
class MklSubgraph:
    """
    A region of an FX graph that is a candidate for running in MKLDNN layout.

    Attributes:
        fx_graph: The graph the region belongs to.
        nodes: Nodes computed inside the region.
        start_nodes: The ``to_mkldnn`` conversions feeding the region.
        end_nodes: The ``to_dense`` conversions leaving the region.

    All three node lists start empty and are filled in by
    ``optimize_for_inference``.
    """

    def __init__(self, fx_graph: fx.Graph):
        self.fx_graph = fx_graph
        # Each instance gets its own fresh lists.
        self.nodes: list[fx.Node] = []
        self.start_nodes: list[fx.Node] = []
        self.end_nodes: list[fx.Node] = []
def gen_mkl_autotuner(example_inputs, iters=10, warmup=1):
    """
    This generates a heuristic that can be passed into `optimize_for_inference` that
    determines whether a subgraph should be run in MKL by running it with the example_inputs.

    On its first call, the heuristic runs ``ShapeProp`` over the graph's
    owning module with ``example_inputs`` so that random inputs of the right
    shape can be built for each subgraph. ``example_inputs`` is passed to
    ``ShapeProp.propagate`` as a single positional argument.
    NOTE(review): confirm this matches how the model is called for
    multi-input models.

    For each subgraph it then does the following:
      1. Extracts a standalone submodule.
      2. Times ``iters`` runs (after ``warmup`` untimed runs) with MKLDNN inputs.
      3. Resets the submodule's modules to their dense originals.
      4. Times it again with dense inputs.

    It returns ``True`` when the MKLDNN run was faster.

    Example usage:
        heuristic = gen_mkl_autotuner(example_inputs, iters=10)
        fast_model = optimization.optimize_for_inference(
            model, {"mkldnn_layout_optimize": {"heuristic": heuristic}}
        )
    """
    fx_model = None
    old_modules = None

    def input_shape(node: fx.Node):
        # Current ShapeProp records results in node.meta["tensor_meta"]; older
        # versions set node.shape instead. Start nodes are `to_mkldnn` calls,
        # which keep their input's shape, so the dense input's metadata is an
        # equally valid source.
        candidates = [node]
        if node.args and isinstance(node.args[0], fx.Node):
            candidates.append(node.args[0])
        for candidate in candidates:
            meta = candidate.meta.get("tensor_meta")
            if meta is not None and hasattr(meta, "shape"):
                return meta.shape
        return node.shape  # type: ignore[attr-defined]

    def use_mkl_heuristic(graph: MklSubgraph) -> bool:
        nonlocal fx_model, old_modules
        input_nodes = graph.start_nodes
        if fx_model is None:
            fx_model = graph.fx_graph.owning_module
            old_modules = graph.fx_graph.old_modules  # type: ignore[attr-defined]
            ShapeProp(fx_model).propagate(example_inputs)
        sample_inputs = [torch.randn(input_shape(node)) for node in input_nodes]
        output_args = cast(list[fx.Node], [node.args[0] for node in graph.end_nodes])
        submodule = extract_subgraph(fx_model, graph.nodes, input_nodes, output_args)

        def benchmark(f):
            for _ in range(warmup):
                f()
            # perf_counter is monotonic and high-resolution; time.time can jump
            # with system clock changes and is coarse on some platforms.
            begin = time.perf_counter()
            for _ in range(iters):
                f()
            return time.perf_counter() - begin

        mkl_time = benchmark(
            lambda: [
                i.to_dense() for i in submodule(*[i.to_mkldnn() for i in sample_inputs])
            ]
        )

        # Swap the submodule's MKLDNN modules back to dense ones for the baseline.
        reset_modules(
            submodule.graph.nodes,
            dict(submodule.named_modules()),
            # pyrefly: ignore [bad-argument-type]
            old_modules,
        )
        no_mkl_time = benchmark(lambda: submodule(*sample_inputs))
        return mkl_time < no_mkl_time

    return use_mkl_heuristic
def use_mkl_length(graph: MklSubgraph) -> bool:
    """
    This is a heuristic that can be passed into `optimize_for_inference` that
    determines whether a subgraph should be run in MKL by checking if there
    are more than 2 nodes in it.

    Presumably the intent is that short subgraphs do not do enough work to pay
    for the layout conversions at their boundaries.
    """
    node_count = len(graph.nodes)
    return node_count > 2
class UnionFind:
    """
    Disjoint-set (union-find) over the integers ``0 .. n-1`` with path
    compression and union by size.

    Elements start uninitialised (parent ``None``). Call `make_set` on an
    element before passing it to `find` or `join`.
    """

    def __init__(self, n):
        # parent[v] is None until make_set(v); a root is its own parent.
        self.parent: list[Optional[int]] = [None] * n
        # size[r] is the number of elements in the set rooted at r (only
        # meaningful while r is a root).
        self.size: list[int] = [0] * n

    def make_set(self, v: int):
        """Make ``v`` a singleton set."""
        self.parent[v] = v
        self.size[v] = 1

    def find(self, v: int) -> int:
        """
        Return the root of the set containing ``v``, compressing its path.

        Raises:
            AssertionError: if ``v`` (or an element on its path) was never
                initialised with `make_set`.
        """
        root = v
        while True:
            par = self.parent[root]
            if par is None:
                raise AssertionError("Parent is None")
            if par == root:
                break
            root = par
        # Point every element on the path directly at the root. Iterative, so
        # there is no dependence on the recursion limit.
        while v != root:
            nxt = cast(int, self.parent[v])
            self.parent[v] = root
            v = nxt
        return root

    def join(self, a: int, b: int) -> int:
        """
        Merge the sets containing ``a`` and ``b``.

        Returns:
            The root of the merged set. It is returned both when a merge
            happens and when ``a`` and ``b`` were already in the same set.
        """
        a, b = self.find(a), self.find(b)
        if a == b:
            return a
        # Union by size: hang the smaller tree under the larger one.
        if self.size[a] < self.size[b]:
            a, b = b, a
        self.parent[b] = a
        self.size[a] += self.size[b]
        return a
def optimize_for_inference(
    model: torch.nn.Module,
    pass_config: Optional[dict[str, Any]] = None,
    tracer: type[fx.Tracer] = fx.Tracer,
) -> torch.nn.Module:
    """
    Performs a set of optimization passes to optimize a model for the
    purposes of inference. Specifically, the passes that are run are:
        1. Conv/BN fusion
        2. Dropout removal
        3. MKL layout optimizations

    The third optimization takes a function `use_mkl_heuristic` that's used
    to determine whether a subgraph should be explicitly run in MKL layout.

    Args:
        model: The module to optimize. Modules that are eligible for MKLDNN
            must have float32 CPU parameters.
        pass_config: Overrides merged over the defaults
            ``{"conv_bn_fuse": True, "remove_dropout": True,
            "mkldnn_layout_optimize": {"heuristic": use_mkl_length}}``.
            Set ``"mkldnn_layout_optimize"`` to ``False`` to skip pass 3.
        tracer: The ``fx.Tracer`` subclass used to trace the model for pass 3.

    Returns:
        The optimized module. When pass 3 is disabled, this is the result of
        passes 1 and 2.

    Raises:
        RuntimeError: if ``"mkldnn_layout_optimize"`` is neither ``False`` nor
            a dict containing a ``"heuristic"`` entry.
        AssertionError: if an MKLDNN-eligible module has non-float or non-CPU
            parameters.

    Note: As FX does not currently handle aliasing, this pass currently
    assumes nothing aliases. If that isn't true, use at your own risk.
    """
    default_pass_config = {
        "conv_bn_fuse": True,
        "remove_dropout": True,
        "mkldnn_layout_optimize": {"heuristic": use_mkl_length},
    }
    if pass_config is None:
        pass_config = {}
    default_pass_config.update(pass_config)

    if default_pass_config["conv_bn_fuse"]:
        model = fuse(model)
    if default_pass_config["remove_dropout"]:
        model = remove_dropout(model)
    if default_pass_config["mkldnn_layout_optimize"] is False:
        return model
    if not isinstance(default_pass_config["mkldnn_layout_optimize"], dict):
        raise RuntimeError("mkldnn_layout_optimize config is not a dict")
    if "heuristic" not in default_pass_config["mkldnn_layout_optimize"]:
        raise RuntimeError("Heuristic not found in mkldnn_layout_optimize config")
    use_mkl_heuristic = default_pass_config["mkldnn_layout_optimize"]["heuristic"]

    if not (default_pass_config["conv_bn_fuse"] or default_pass_config["remove_dropout"]):
        # Neither pass above produced a new module, so `model` is still the
        # caller's object. The MKLDNN pass below swaps submodules via setattr
        # (modules_to_mkldnn / reset_modules), so work on a private copy
        # rather than mutating the caller's model.
        model = copy.deepcopy(model)

    cur_tracer = tracer()
    fx_graph = cur_tracer.trace(copy.deepcopy(model))
    # The GraphModule is built only so the graph gets an owning module, which
    # gen_mkl_autotuner reads through `fx_graph.owning_module`.
    fx.GraphModule(cur_tracer.root, fx_graph)
    modules: dict[str, nn.Module] = dict(model.named_modules())

    class MklSupport(Enum):
        NO = 1
        YES = 2
        UNKNOWN = 3

    # Inserts to_mkldnn and to_dense around every node we want to be a MKLDNN node.
    # If the op is in `mkldnn_supported` then we always treat it as a MKLDNN node.
    # However, if it's in `mkldnn_supported_unknown`, then we only treat it as
    # a MKLDNN node if its inputs are MKLDNN nodes.
    for node in list(fx_graph.nodes):
        supports_mkldnn = MklSupport.NO
        if node.op == "call_module":
            cur_module = modules[node.target]
            if type(cur_module) in mkldnn_supported:
                supports_mkldnn = MklSupport.YES
                sample_parameter = next(cur_module.parameters(), None)
                if sample_parameter is not None:
                    if sample_parameter.dtype != torch.float:
                        raise AssertionError(
                            "this pass is only for torch.float modules"
                        )
                    if sample_parameter.device != torch.device("cpu"):
                        raise AssertionError("this pass is only for CPU modules")
        elif node.op == "call_function":
            if node.target in mkldnn_supported:
                supports_mkldnn = MklSupport.YES
            elif node.target in mkldnn_supported_unknown:
                supports_mkldnn = MklSupport.UNKNOWN

        if supports_mkldnn == MklSupport.NO:
            continue
        if supports_mkldnn == MklSupport.UNKNOWN:
            # Args may be non-Node values (e.g. the scalar in `x + 1`), which
            # have no `.target`, so check the type before reading it.
            if not any(
                isinstance(arg, fx.Node) and arg.target == "to_dense"
                for arg in node.args
            ):
                continue
        with fx_graph.inserting_before(node):
            mkldnn_args = fx.map_arg(
                node.args, lambda n: fx_graph.call_method("to_mkldnn", (n,))
            )
        node.args = cast(tuple[fx.node.Argument], mkldnn_args)
        with fx_graph.inserting_after(node):
            dense_x = fx_graph.create_node("call_method", "to_dense", (node,))
        node.replace_all_uses_with(dense_x)
        # dense_x was itself a user of `node`, so the call above rewired it to
        # point at itself; restore its input.
        dense_x.args = (node,)

    # Does pre-conversion of all modules into MKLDNN (when possible)
    old_modules = modules_to_mkldnn(list(fx_graph.nodes), modules)
    fx_graph.old_modules = old_modules  # type: ignore[attr-defined]

    # optimizes all a -> to_dense -> to_mkldnn -> b patterns into a -> b
    for node in fx_graph.nodes:
        if node.op == "call_method" and node.target == "to_dense":
            prv_node = node.args[0]
            users = list(node.users)
            for user in users:
                if user.op == "call_method" and user.target == "to_mkldnn":
                    user.replace_all_uses_with(prv_node)
                    fx_graph.erase_node(user)
            if len(node.users) == 0:
                fx_graph.erase_node(node)

    num_nodes = len(fx_graph.nodes)
    uf = UnionFind(num_nodes)

    def get_color(n):
        if hasattr(n, "color"):  # Current node is part of a MKL subgraph
            return uf.find(n.color)
        if hasattr(n, "start_color"):  # Current node is input to MKL subgraph
            return uf.find(n.start_color)
        return None

    # This code is to find each MKLDNN subgraph. Each MKLDNN subgraph consists
    # of input nodes (which are only `to_mkldnn` calls), output nodes
    # (`to_dense` calls), and intermediate nodes, which are run entirely on
    # MKLDNN layout tensors.
    #
    # Specifically, this code does a flood fill on a directed acyclic graph
    # (DAG), starting from each possible "start node" (i.e: `to_mkldnn` nodes).
    # If every node only had one input, this would be sufficient. However, in
    # the case that a node has multiple inputs coming from different start
    # nodes (i.e. colors), we need to join these 2 colors into 1. That's done
    # using a Disjoint Set Union.
    for cur_idx, node in enumerate(fx_graph.nodes):
        if node.op == "call_method" and node.target == "to_mkldnn":
            node.start_color = cur_idx
            uf.make_set(cur_idx)
        elif node.op == "call_method" and node.target == "to_dense":
            if get_color(node.args[0]) is None:
                raise AssertionError("Expected color for to_dense input")
            node.end_color = get_color(node.args[0])
        else:
            cur_colors = [
                get_color(i)
                for i in node.all_input_nodes
                if isinstance(i, fx.Node)
                if get_color(i) is not None
            ]
            if len(cur_colors) == 0:
                continue
            if any(i is None for i in cur_colors):
                raise AssertionError("Found None in cur_colors")
            cur_colors = sorted(cur_colors)
            node.color = cur_colors[0]
            for other_color in cur_colors[1:]:
                uf.join(cur_colors[0], other_color)

    mkldnn_graphs: dict[int, MklSubgraph] = defaultdict(lambda: MklSubgraph(fx_graph))
    for node in fx_graph.nodes:
        if hasattr(node, "color"):
            mkldnn_graphs[uf.find(node.color)].nodes.append(node)
        if hasattr(node, "start_color"):
            mkldnn_graphs[uf.find(node.start_color)].start_nodes.append(node)
        if hasattr(node, "end_color"):
            mkldnn_graphs[uf.find(node.end_color)].end_nodes.append(node)

    # Now that we have all the subgraphs, we need to decide which MKLDNN
    # subgraphs we actually want to keep in MKLDNN.
    for graph in mkldnn_graphs.values():
        if not use_mkl_heuristic(graph):
            # Drop this subgraph's boundary conversions and restore its
            # dense modules.
            for node in graph.start_nodes + graph.end_nodes:
                prv = node.args[0]
                node.replace_all_uses_with(prv)  # type: ignore[arg-type]
                fx_graph.erase_node(node)
            reset_modules(graph.nodes, modules, old_modules)

    mkldnn_conversions = 0
    for node in fx_graph.nodes:
        if node.target == "to_mkldnn" or node.target == "to_dense":
            mkldnn_conversions += 1

    logging.getLogger(__name__).info("mkldnn conversions: %s", mkldnn_conversions)
    fx_graph.lint()
    result = fx.GraphModule(model, fx_graph)
    return result
|