optimization.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498
  1. # mypy: allow-untyped-defs
  2. import copy
  3. import logging
  4. import operator
  5. import time
  6. from collections import defaultdict
  7. from collections.abc import Iterable
  8. from enum import Enum
  9. from typing import Any, cast, Optional
  10. import torch
  11. import torch.fx as fx
  12. import torch.nn as nn
  13. import torch.nn.functional as F
  14. import torch.utils.mkldnn as th_mkldnn
  15. from torch.fx.node import Argument, Target
  16. from torch.fx.passes.shape_prop import ShapeProp
  17. from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_linear_bn_eval
# Names exported by `from <this module> import *`. Private helpers such as
# `_parent_name` are intentionally left out.
__all__ = [
    "matches_module_pattern",
    "replace_node_module",
    "fuse",
    "remove_dropout",
    "extract_subgraph",
    "modules_to_mkldnn",
    "reset_modules",
    "MklSubgraph",
    "gen_mkl_autotuner",
    "use_mkl_length",
    "UnionFind",
    "optimize_for_inference",
]
  32. def _parent_name(target: str) -> tuple[str, str]:
  33. """
  34. Splits a qualname into parent path and last atom.
  35. For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
  36. """
  37. *parent, name = target.rsplit(".", 1)
  38. return parent[0] if parent else "", name
  39. # Works for length 2 patterns with 2 modules
  40. def matches_module_pattern(
  41. pattern: Iterable[type], node: fx.Node, modules: dict[str, Any]
  42. ):
  43. if len(node.args) == 0:
  44. return False
  45. nodes: tuple[Any, fx.Node] = (node.args[0], node)
  46. for expected_type, current_node in zip(pattern, nodes):
  47. if not isinstance(current_node, fx.Node):
  48. return False
  49. if current_node.op != "call_module":
  50. return False
  51. if not isinstance(current_node.target, str):
  52. return False
  53. if current_node.target not in modules:
  54. return False
  55. if type(modules[current_node.target]) is not expected_type:
  56. return False
  57. return True
  58. def replace_node_module(
  59. node: fx.Node, modules: dict[str, Any], new_module: torch.nn.Module
  60. ):
  61. if not isinstance(node.target, str):
  62. raise AssertionError(f"Expected str target, got {type(node.target)}")
  63. parent_name, name = _parent_name(node.target)
  64. modules[node.target] = new_module
  65. setattr(modules[parent_name], name, new_module)
  66. def fuse(model: torch.nn.Module, inplace=False, no_trace=False) -> torch.nn.Module:
  67. """
  68. Fuses convolution/BN and linear/BN layers for inference purposes.
  69. Will deepcopy your model by default, but can modify the model inplace as well.
  70. """
  71. patterns = [
  72. (nn.Conv1d, nn.BatchNorm1d),
  73. (nn.Conv2d, nn.BatchNorm2d),
  74. (nn.Conv3d, nn.BatchNorm3d),
  75. (nn.Linear, nn.BatchNorm1d),
  76. ]
  77. if not inplace:
  78. model = copy.deepcopy(model)
  79. if not no_trace or not isinstance(model, torch.fx.GraphModule):
  80. fx_model = fx.symbolic_trace(model)
  81. else:
  82. fx_model = model
  83. modules = dict(fx_model.named_modules())
  84. new_graph = copy.deepcopy(fx_model.graph)
  85. for pattern in patterns:
  86. for node in new_graph.nodes:
  87. if matches_module_pattern(pattern, node, modules):
  88. if len(node.args[0].users) > 1:
  89. # Output of conv/linear is used by other nodes
  90. continue
  91. first_layer = modules[node.args[0].target]
  92. bn = modules[node.target]
  93. if not bn.track_running_stats:
  94. continue
  95. if pattern[0] in [nn.Conv1d, nn.Conv2d, nn.Conv3d]:
  96. fused_layer = fuse_conv_bn_eval(first_layer, bn)
  97. else: # nn.Linear
  98. fused_layer = fuse_linear_bn_eval(first_layer, bn)
  99. replace_node_module(node.args[0], modules, fused_layer)
  100. node.replace_all_uses_with(node.args[0])
  101. new_graph.erase_node(node)
  102. return fx.GraphModule(fx_model, new_graph)
  103. def remove_dropout(model: nn.Module) -> nn.Module:
  104. """
  105. Removes all dropout layers from the module.
  106. """
  107. fx_model = fx.symbolic_trace(model)
  108. class DropoutRemover(torch.fx.Transformer):
  109. def call_module(
  110. self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any]
  111. ) -> Any:
  112. if isinstance(self.submodules[target], nn.Dropout):
  113. if len(args) != 1:
  114. raise AssertionError(f"Expected 1 arg for Dropout, got {len(args)}")
  115. return args[0]
  116. else:
  117. return super().call_module(target, args, kwargs)
  118. return DropoutRemover(fx_model).transform()
  119. def extract_subgraph(
  120. orig_module: nn.Module,
  121. nodes: list[fx.Node],
  122. inputs: list[fx.Node],
  123. outputs: list[fx.Node],
  124. ):
  125. """
  126. Given lists of nodes from an existing graph that represent a subgraph, returns a submodule that executes that subgraph.
  127. """
  128. new_graph = fx.Graph()
  129. env: dict[fx.Node, fx.Node] = {}
  130. for input in inputs:
  131. new_node = new_graph.placeholder(input.name)
  132. env[input] = new_node
  133. for node in nodes:
  134. new_node = new_graph.node_copy(node, lambda x: env[x])
  135. env[node] = new_node
  136. new_graph.output([env[output] for output in outputs])
  137. new_graph.lint()
  138. return fx.GraphModule(orig_module, new_graph)
# Module types and functions that `optimize_for_inference` always treats as
# MKLDNN nodes (it wraps them in `to_mkldnn` / `to_dense` conversions).
mkldnn_supported = [
    nn.Conv2d,
    nn.Linear,
    nn.BatchNorm2d,
    nn.ReLU,
    nn.MaxPool2d,
    nn.AvgPool2d,
    nn.AdaptiveAvgPool2d,
    torch.relu,
    torch.transpose,
    torch.sigmoid,
    F.relu,
    F.avg_pool2d,
    F.adaptive_avg_pool2d,
]
# These are operators that may not be convertible into MKLDNN ops (e.g. the
# args are scalar values). Thus, we only include them in the subgraph if their
# arguments are already in MKLDNN.
# TODO: Determine whether this can be removed after type inference.
mkldnn_supported_unknown = [operator.add, operator.mul]
# Maps a module type to a callable that builds its MKLDNN counterpart.
# `modules_to_mkldnn` calls each entry as `(module, torch.float)`; the
# BatchNorm entry ignores the second argument.
mkldnn_map = {
    nn.Conv2d: th_mkldnn.MkldnnConv2d,
    nn.Linear: th_mkldnn.MkldnnLinear,
    nn.BatchNorm2d: lambda a, _: th_mkldnn.MkldnnBatchNorm(a),
}
  164. def modules_to_mkldnn(nodes: list[fx.Node], modules: dict[str, nn.Module]):
  165. """
  166. For each node, if it's a module that can be preconverted into MKLDNN,
  167. then we do so and create a mapping to allow us to convert from the MKLDNN
  168. version of the module to the original.
  169. """
  170. old_modules: dict[nn.Module, nn.Module] = {}
  171. for node in nodes:
  172. if node.op == "call_module":
  173. if not isinstance(node.target, str):
  174. raise AssertionError(f"Expected str target, got {type(node.target)}")
  175. cur_module = modules[node.target]
  176. if type(cur_module) in mkldnn_map:
  177. # pyrefly: ignore [bad-index, index-error]
  178. new_module = mkldnn_map[type(cur_module)](cur_module, torch.float)
  179. if not isinstance(new_module, nn.Module):
  180. raise AssertionError(f"Expected nn.Module, got {type(new_module)}")
  181. old_modules[new_module] = copy.deepcopy(cur_module)
  182. replace_node_module(node, modules, new_module)
  183. return old_modules
  184. def reset_modules(
  185. nodes: list[fx.Node],
  186. modules: dict[str, nn.Module],
  187. old_modules: dict[nn.Module, nn.Module],
  188. ):
  189. """
  190. Maps each module that's been changed with `modules_to_mkldnn` back to its
  191. original.
  192. """
  193. for node in nodes:
  194. if node.op == "call_module":
  195. if not isinstance(node.target, str):
  196. raise AssertionError(f"Expected str target, got {type(node.target)}")
  197. cur_module = modules[node.target]
  198. if cur_module in old_modules:
  199. replace_node_module(node, modules, old_modules[cur_module])
  200. class MklSubgraph:
  201. def __init__(self, fx_graph: fx.Graph):
  202. self.fx_graph = fx_graph
  203. self.nodes: list[fx.Node] = []
  204. self.start_nodes: list[fx.Node] = []
  205. self.end_nodes: list[fx.Node] = []
  206. def gen_mkl_autotuner(example_inputs, iters=10, warmup=1):
  207. """
  208. This generates a heuristic that can be passed into `optimize_for_inference` that
  209. determines whether a subgraph should be run in MKL by running it with the example_inputs.
  210. Example usage:
  211. heuristic = gen_mkl_autotuner(example_inputs, iters=10)
  212. fast_model = optimization.optimize_for_inference(model, heuristic)
  213. """
  214. fx_model = None
  215. old_modules = None
  216. def use_mkl_heuristic(graph: MklSubgraph) -> bool:
  217. nonlocal fx_model, old_modules
  218. input_nodes = graph.start_nodes
  219. if fx_model is None:
  220. fx_model = graph.fx_graph.owning_module
  221. old_modules = graph.fx_graph.old_modules # type: ignore[attr-defined]
  222. ShapeProp(fx_model).propagate(example_inputs)
  223. sample_inputs = [torch.randn(node.shape) for node in input_nodes] # type: ignore[attr-defined]
  224. output_args = cast(list[fx.Node], [node.args[0] for node in graph.end_nodes])
  225. submodule = extract_subgraph(fx_model, graph.nodes, input_nodes, output_args)
  226. def benchmark(f):
  227. for _ in range(warmup):
  228. f()
  229. begin = time.time()
  230. for _ in range(iters):
  231. f()
  232. return time.time() - begin
  233. mkl_time = benchmark(
  234. lambda: [
  235. i.to_dense() for i in submodule(*[i.to_mkldnn() for i in sample_inputs])
  236. ]
  237. )
  238. reset_modules(
  239. submodule.graph.nodes,
  240. dict(submodule.named_modules()),
  241. # pyrefly: ignore [bad-argument-type]
  242. old_modules,
  243. )
  244. no_mkl_time = benchmark(lambda: submodule(*sample_inputs))
  245. return mkl_time < no_mkl_time
  246. return use_mkl_heuristic
  247. def use_mkl_length(graph: MklSubgraph) -> bool:
  248. """
  249. This is a heuristic that can be passed into `optimize_for_inference` that
  250. determines whether a subgraph should be run in MKL by checking if there
  251. are more than 2 nodes in it
  252. """
  253. return len(graph.nodes) > 2
  254. class UnionFind:
  255. def __init__(self, n):
  256. self.parent: list[Optional[int]] = [None] * n
  257. self.size: list[int] = [0] * n
  258. def make_set(self, v: int):
  259. self.parent[v] = v
  260. self.size[v] = 1
  261. def find(self, v: int) -> int:
  262. par = self.parent[v]
  263. if v == par:
  264. return v
  265. if par is None:
  266. raise AssertionError("Parent is None")
  267. self.parent[v] = self.find(par)
  268. return cast(int, self.parent[v])
  269. def join(self, a: int, b: int):
  270. a, b = self.find(a), self.find(b)
  271. if a == b:
  272. return a
  273. if self.size[a] < self.size[b]:
  274. a, b = b, a
  275. self.parent[b] = a
  276. self.size[a] += self.size[b]
  277. def optimize_for_inference(
  278. model: torch.nn.Module,
  279. pass_config: Optional[dict[str, Any]] = None,
  280. tracer: type[fx.Tracer] = fx.Tracer,
  281. ) -> torch.nn.Module:
  282. """
  283. Performs a set of optimization passes to optimize a model for the
  284. purposes of inference. Specifically, the passes that are run are:
  285. 1. Conv/BN fusion
  286. 2. Dropout removal
  287. 3. MKL layout optimizations
  288. The third optimization takes a function `use_mkl_heuristic` that's used
  289. to determine whether a subgraph should be explicitly run in MKL layout.
  290. Note: As FX does not currently handle aliasing, this pass currently
  291. assumes nothing aliases. If that isn't true, use at your own risk.
  292. """
  293. default_pass_config = {
  294. "conv_bn_fuse": True,
  295. "remove_dropout": True,
  296. "mkldnn_layout_optimize": {"heuristic": use_mkl_length},
  297. }
  298. if pass_config is None:
  299. pass_config = {}
  300. default_pass_config.update(pass_config)
  301. if default_pass_config["conv_bn_fuse"]:
  302. model = fuse(model)
  303. if default_pass_config["remove_dropout"]:
  304. model = remove_dropout(model)
  305. if default_pass_config["mkldnn_layout_optimize"] is False:
  306. return model
  307. if not isinstance(default_pass_config["mkldnn_layout_optimize"], dict):
  308. raise RuntimeError("mkldnn_layout_optimize config is not a dict")
  309. if "heuristic" not in default_pass_config["mkldnn_layout_optimize"]:
  310. raise RuntimeError("Heuristic not found in mkldnn_layout_optimize config")
  311. use_mkl_heuristic = default_pass_config["mkldnn_layout_optimize"]["heuristic"]
  312. cur_tracer = tracer()
  313. fx_graph = cur_tracer.trace(copy.deepcopy(model))
  314. fx.GraphModule(cur_tracer.root, fx_graph)
  315. modules: dict[str, nn.Module] = dict(model.named_modules())
  316. class MklSupport(Enum):
  317. NO = 1
  318. YES = 2
  319. UNKNOWN = 3
  320. # Inserts to_mkldnn and to_dense around every node we want to be a MKLDNN node.
  321. # If the op is in `mkldnn_supported` then we always treat it as a MKLDNN node.
  322. # However, if it's in `mkldnn_supported_unknown`, then we only treat it as
  323. # a MKLDNN node if its inputs are MKLDNN nodes.
  324. for node in list(fx_graph.nodes):
  325. supports_mkldnn = MklSupport.NO
  326. if node.op == "call_module":
  327. cur_module = modules[node.target]
  328. if type(cur_module) in mkldnn_supported:
  329. supports_mkldnn = MklSupport.YES
  330. sample_parameter = next(cur_module.parameters(), None)
  331. if sample_parameter is not None:
  332. if sample_parameter.dtype != torch.float:
  333. raise AssertionError(
  334. "this pass is only for torch.float modules"
  335. )
  336. if sample_parameter.device != torch.device("cpu"):
  337. raise AssertionError("this pass is only for CPU modules")
  338. elif node.op == "call_function":
  339. if node.target in mkldnn_supported:
  340. supports_mkldnn = MklSupport.YES
  341. elif node.target in mkldnn_supported_unknown:
  342. supports_mkldnn = MklSupport.UNKNOWN
  343. if supports_mkldnn != MklSupport.NO:
  344. if supports_mkldnn == MklSupport.UNKNOWN:
  345. if not any(arg.target == "to_dense" for arg in node.args):
  346. continue
  347. with fx_graph.inserting_before(node):
  348. mkldnn_args = fx.map_arg(
  349. node.args, lambda n: fx_graph.call_method("to_mkldnn", (n,))
  350. )
  351. node.args = cast(tuple[fx.node.Argument], mkldnn_args)
  352. with fx_graph.inserting_after(node):
  353. dense_x = fx_graph.create_node("call_method", "to_dense", (node,))
  354. node.replace_all_uses_with(dense_x)
  355. dense_x.args = (node,)
  356. # Does pre-conversion of all modules into MKLDNN (when possible)
  357. old_modules = modules_to_mkldnn(list(fx_graph.nodes), modules)
  358. fx_graph.old_modules = old_modules # type: ignore[attr-defined]
  359. # optimizes all a -> to_dense -> to_mkldnn -> b patterns into a -> b
  360. for node in fx_graph.nodes:
  361. if node.op == "call_method" and node.target == "to_dense":
  362. prv_node = node.args[0]
  363. users = list(node.users)
  364. for user in users:
  365. if user.op == "call_method" and user.target == "to_mkldnn":
  366. user.replace_all_uses_with(prv_node)
  367. fx_graph.erase_node(user)
  368. if len(node.users) == 0:
  369. fx_graph.erase_node(node)
  370. num_nodes = len(fx_graph.nodes)
  371. uf = UnionFind(num_nodes)
  372. def get_color(n):
  373. if hasattr(n, "color"): # Current node is part of a MKL subgraph
  374. return uf.find(n.color)
  375. if hasattr(n, "start_color"): # Current node is input to MKL subgraph
  376. return uf.find(n.start_color)
  377. return None
  378. # This code is to find each MKLDNN subgraph. Each MKLDNN subgraph consists
  379. # of input nodes (which are only `to_mkldnn` calls), output nodes
  380. # (`to_dense` calls), and intermediate nodes, which are run entirely on
  381. # MKLDNN layout tensors.
  382. #
  383. # Specifically, this code does a flood fill on a directed acyclic graph
  384. # (DAG), starting from each possible "start node" (i.e: `to_mkldnn` nodes).
  385. # If every node only had one input, this would be sufficient. However, in
  386. # the case that a node has multiple inputs coming from different start
  387. # nodes (i.e. colors), we need to join these 2 colors into 1. That's done
  388. # using a Disjoint Set Union.
  389. for cur_idx, node in enumerate(fx_graph.nodes):
  390. if node.op == "call_method" and node.target == "to_mkldnn":
  391. node.start_color = cur_idx
  392. uf.make_set(cur_idx)
  393. elif node.op == "call_method" and node.target == "to_dense":
  394. if get_color(node.args[0]) is None:
  395. raise AssertionError("Expected color for to_dense input")
  396. node.end_color = get_color(node.args[0])
  397. else:
  398. cur_colors = [
  399. get_color(i)
  400. for i in node.all_input_nodes
  401. if isinstance(i, fx.Node)
  402. if get_color(i) is not None
  403. ]
  404. if len(cur_colors) == 0:
  405. continue
  406. if any(i is None for i in cur_colors):
  407. raise AssertionError("Found None in cur_colors")
  408. cur_colors = sorted(cur_colors)
  409. node.color = cur_colors[0]
  410. for other_color in cur_colors[1:]:
  411. uf.join(cur_colors[0], other_color)
  412. mkldnn_graphs: dict[int, MklSubgraph] = defaultdict(lambda: MklSubgraph(fx_graph))
  413. for node in fx_graph.nodes:
  414. if hasattr(node, "color"):
  415. mkldnn_graphs[uf.find(node.color)].nodes.append(node)
  416. if hasattr(node, "start_color"):
  417. mkldnn_graphs[uf.find(node.start_color)].start_nodes.append(node)
  418. if hasattr(node, "end_color"):
  419. mkldnn_graphs[uf.find(node.end_color)].end_nodes.append(node)
  420. # Now that we have all the subgraphs, we need to decide which MKLDNN
  421. # subgraphs we actually want to keep in MKLDNN.
  422. for graph in mkldnn_graphs.values():
  423. if not use_mkl_heuristic(graph):
  424. for node in graph.start_nodes + graph.end_nodes:
  425. prv = node.args[0]
  426. node.replace_all_uses_with(prv) # type: ignore[arg-type]
  427. fx_graph.erase_node(node)
  428. reset_modules(graph.nodes, modules, old_modules)
  429. mkldnn_conversions = 0
  430. for node in fx_graph.nodes:
  431. if node.target == "to_mkldnn" or node.target == "to_dense":
  432. mkldnn_conversions += 1
  433. logging.getLogger(__name__).info("mkldnn conversions: %s", mkldnn_conversions)
  434. fx_graph.lint()
  435. result = fx.GraphModule(model, fx_graph)
  436. return result