# subgraph_rewriter.py
  1. import copy
  2. from collections.abc import Callable
  3. from dataclasses import dataclass
  4. from typing import Any, NamedTuple, Optional, TYPE_CHECKING, Union
  5. import torch
  6. from ._compatibility import compatibility
  7. from ._symbolic_trace import symbolic_trace
  8. from .graph import Graph
  9. from .graph_module import GraphModule
  10. from .node import Node
  11. if TYPE_CHECKING:
  12. from .passes.utils.matcher_with_name_node_map_utils import InternalMatch
# Names exported by ``from <this module> import *``.
__all__ = [
    "Match",
    "replace_pattern",
    "replace_pattern_with_filters",
    "ReplacedPatterns",
]
  19. @compatibility(is_backward_compatible=True)
  20. class Match(NamedTuple):
  21. # Node from which the match was found
  22. anchor: Node
  23. # Maps nodes in the pattern subgraph to nodes in the larger graph
  24. nodes_map: dict[Node, Node]
  25. @compatibility(is_backward_compatible=False)
  26. @dataclass
  27. class ReplacedPatterns:
  28. # Node from which the match was found
  29. anchor: Node
  30. # Maps nodes in the pattern subgraph to nodes in the larger graph
  31. nodes_map: dict[Node, Node]
  32. # List of nodes that were added into the graph
  33. replacements: list[Node]
  34. def _replace_attributes(gm: GraphModule, replacement: torch.nn.Module) -> None:
  35. gm.delete_all_unused_submodules()
  36. if isinstance(replacement, GraphModule):
  37. replacement.graph.lint()
  38. def try_get_attr(gm: torch.nn.Module, target: str) -> Optional[Any]:
  39. module_path, _, attr_name = target.rpartition(".")
  40. try:
  41. mod: torch.nn.Module = gm.get_submodule(module_path)
  42. except AttributeError:
  43. return None
  44. attr = getattr(mod, attr_name, None)
  45. return attr
  46. for node in gm.graph.nodes:
  47. if node.op == "call_module" or node.op == "get_attr":
  48. gm_attr = try_get_attr(gm, node.target)
  49. replacement_attr = try_get_attr(replacement, node.target)
  50. # CASE 1: This target already exists as an attribute in our
  51. # result GraphModule. Whether or not it exists in
  52. # `replacement`, the existing submodule takes precedence.
  53. if gm_attr is not None:
  54. continue
  55. # CASE 2: The target exists as an attribute in `replacement`
  56. # only, so we need to copy it over.
  57. elif replacement_attr is not None:
  58. new_attr = copy.deepcopy(replacement_attr)
  59. if isinstance(replacement_attr, torch.nn.Module):
  60. gm.add_submodule(node.target, new_attr)
  61. else:
  62. setattr(gm, node.target, new_attr)
  63. # CASE 3: The target doesn't exist as an attribute in `gm`
  64. # or `replacement`
  65. else:
  66. raise RuntimeError(
  67. 'Attempted to create a "',
  68. node.op,
  69. '" node during subgraph rewriting '
  70. f"with target {node.target}, but "
  71. "the referenced attribute does not "
  72. "exist in the replacement GraphModule",
  73. )
  74. gm.graph.lint()
  75. @compatibility(is_backward_compatible=True)
  76. def replace_pattern(
  77. gm: GraphModule,
  78. pattern: Union[Callable, GraphModule],
  79. replacement: Union[Callable, GraphModule],
  80. ) -> list[Match]:
  81. """
  82. Matches all possible non-overlapping sets of operators and their
  83. data dependencies (``pattern``) in the Graph of a GraphModule
  84. (``gm``), then replaces each of these matched subgraphs with another
  85. subgraph (``replacement``).
  86. Args:
  87. ``gm``: The GraphModule that wraps the Graph to operate on
  88. ``pattern``: The subgraph to match in ``gm`` for replacement
  89. ``replacement``: The subgraph to replace ``pattern`` with
  90. Returns:
  91. List[Match]: A list of ``Match`` objects representing the places
  92. in the original graph that ``pattern`` was matched to. The list
  93. is empty if there are no matches. ``Match`` is defined as:
  94. .. code-block:: python
  95. class Match(NamedTuple):
  96. # Node from which the match was found
  97. anchor: Node
  98. # Maps nodes in the pattern subgraph to nodes in the larger graph
  99. nodes_map: Dict[Node, Node]
  100. Examples:
  101. .. code-block:: python
  102. import torch
  103. from torch.fx import symbolic_trace, subgraph_rewriter
  104. class M(torch.nn.Module):
  105. def __init__(self) -> None:
  106. super().__init__()
  107. def forward(self, x, w1, w2):
  108. m1 = torch.cat([w1, w2]).sum()
  109. m2 = torch.cat([w1, w2]).sum()
  110. return x + torch.max(m1) + torch.max(m2)
  111. def pattern(w1, w2):
  112. return torch.cat([w1, w2])
  113. def replacement(w1, w2):
  114. return torch.stack([w1, w2])
  115. traced_module = symbolic_trace(M())
  116. subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)
  117. The above code will first match ``pattern`` in the ``forward``
  118. method of ``traced_module``. Pattern-matching is done based on
  119. use-def relationships, not node names. For example, if you had
  120. ``p = torch.cat([a, b])`` in ``pattern``, you could match
  121. ``m = torch.cat([a, b])`` in the original ``forward`` function,
  122. despite the variable names being different (``p`` vs ``m``).
  123. The ``return`` statement in ``pattern`` is matched based on its
  124. value only; it may or may not match to the ``return`` statement in
  125. the larger graph. In other words, the pattern doesn't have to extend
  126. to the end of the larger graph.
  127. When the pattern is matched, it will be removed from the larger
  128. function and replaced by ``replacement``. If there are multiple
  129. matches for ``pattern`` in the larger function, each non-overlapping
  130. match will be replaced. In the case of a match overlap, the first
  131. found match in the set of overlapping matches will be replaced.
  132. ("First" here being defined as the first in a topological ordering
  133. of the Nodes' use-def relationships. In most cases, the first Node
  134. is the parameter that appears directly after ``self``, while the
  135. last Node is whatever the function returns.)
  136. One important thing to note is that the parameters of the
  137. ``pattern`` Callable must be used in the Callable itself,
  138. and the parameters of the ``replacement`` Callable must match
  139. the pattern. The first rule is why, in the above code block, the
  140. ``forward`` function has parameters ``x, w1, w2``, but the
  141. ``pattern`` function only has parameters ``w1, w2``. ``pattern``
  142. doesn't use ``x``, so it shouldn't specify ``x`` as a parameter.
  143. As an example of the second rule, consider replacing
  144. .. code-block:: python
  145. def pattern(x, y):
  146. return torch.neg(x) + torch.relu(y)
  147. with
  148. .. code-block:: python
  149. def replacement(x, y):
  150. return torch.relu(x)
  151. In this case, ``replacement`` needs the same number of parameters
  152. as ``pattern`` (both ``x`` and ``y``), even though the parameter
  153. ``y`` isn't used in ``replacement``.
  154. After calling ``subgraph_rewriter.replace_pattern``, the generated
  155. Python code looks like this:
  156. .. code-block:: python
  157. def forward(self, x, w1, w2):
  158. stack_1 = torch.stack([w1, w2])
  159. sum_1 = stack_1.sum()
  160. stack_2 = torch.stack([w1, w2])
  161. sum_2 = stack_2.sum()
  162. max_1 = torch.max(sum_1)
  163. add_1 = x + max_1
  164. max_2 = torch.max(sum_2)
  165. add_2 = add_1 + max_2
  166. return add_2
  167. """
  168. match_and_replacements = _replace_pattern(gm, pattern, replacement)
  169. return [
  170. Match(anchor=m.anchor, nodes_map=m.nodes_map) for m in match_and_replacements
  171. ]
  172. # Experimental API, not backward compatible
  173. @compatibility(is_backward_compatible=False)
  174. def replace_pattern_with_filters(
  175. gm: GraphModule,
  176. pattern: Union[Callable, Graph, GraphModule],
  177. replacement: Union[Callable, Graph, GraphModule, None] = None,
  178. match_filters: Optional[
  179. list[Callable[["InternalMatch", Graph, Graph], bool]]
  180. ] = None,
  181. ignore_literals: bool = False,
  182. # Placed at the end to avoid breaking backward compatibility
  183. replacement_callback: Optional[
  184. Callable[["InternalMatch", Graph, Graph], Graph]
  185. ] = None,
  186. node_name_match: str = "",
  187. ) -> list[ReplacedPatterns]:
  188. """
  189. See replace_pattern for documentation. This function is an overload with an additional match_filter argument.
  190. Args:
  191. ``match_filters``: A list of functions that take in
  192. (match: InternalMatch, original_graph: Graph, pattern_graph: Graph) and return a boolean indicating
  193. whether the match satisfies the condition.
  194. See matcher_utils.py for definition of InternalMatch.
  195. ``replacement_callback``: A function that takes in a match and returns a
  196. Graph to be used as the replacement. This allows you to construct a
  197. replacement graph based on the match.
  198. ``replacement_callback``: Node name to match. If not empty, it will try to match the node name.
  199. """
  200. return _replace_pattern(
  201. gm,
  202. pattern,
  203. replacement,
  204. match_filters,
  205. ignore_literals,
  206. replacement_callback,
  207. node_name_match,
  208. )
  209. def _replace_pattern(
  210. gm: GraphModule,
  211. pattern: Union[Callable, Graph, GraphModule],
  212. replacement: Union[Callable, Graph, GraphModule, None] = None,
  213. match_filters: Optional[
  214. list[Callable[["InternalMatch", Graph, Graph], bool]]
  215. ] = None,
  216. ignore_literals: bool = False,
  217. # Placed at the end to avoid breaking backward compatibility
  218. replacement_callback: Optional[
  219. Callable[["InternalMatch", Graph, Graph], Graph]
  220. ] = None,
  221. node_name_match: str = "",
  222. ) -> list[ReplacedPatterns]:
  223. from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher
  224. if match_filters is None:
  225. match_filters = []
  226. # Get the graphs for `gm`, `pattern`, `replacement`
  227. original_graph: Graph = gm.graph
  228. if isinstance(pattern, GraphModule):
  229. pattern_graph = pattern.graph
  230. elif isinstance(pattern, Graph):
  231. pattern_graph = pattern
  232. else:
  233. pattern_graph = symbolic_trace(pattern).graph # type: ignore[arg-type]
  234. matcher = SubgraphMatcher(
  235. pattern_graph,
  236. match_output=False,
  237. match_placeholder=False,
  238. remove_overlapping_matches=True,
  239. ignore_literals=ignore_literals,
  240. )
  241. _matches: list[InternalMatch] = matcher.match(
  242. original_graph, node_name_match=node_name_match
  243. )
  244. # Filter out matches that don't match the filter
  245. _matches = [
  246. m
  247. for m in _matches
  248. if all(
  249. match_filter(m, original_graph, pattern_graph)
  250. for match_filter in match_filters
  251. )
  252. ]
  253. if isinstance(replacement, GraphModule):
  254. common_replacement_graph = replacement.graph
  255. elif isinstance(replacement, Graph):
  256. common_replacement_graph = replacement
  257. elif callable(replacement):
  258. common_replacement_graph = symbolic_trace(replacement).graph
  259. else:
  260. if replacement_callback is None:
  261. raise AssertionError(
  262. "Must provide either a replacement GraphModule or a replacement callback"
  263. )
  264. common_replacement_graph = None # type: ignore[assignment]
  265. # As we progressively replace nodes, we'll need to keep track of how the match results should change
  266. match_changed_node: dict[Node, Node] = {}
  267. match_and_replacements = []
  268. for match in _matches:
  269. if replacement_callback is not None:
  270. replacement_graph = replacement_callback(
  271. match, original_graph, pattern_graph
  272. )
  273. else:
  274. if common_replacement_graph is None:
  275. raise AssertionError(
  276. "Must provide either a replacement GraphModule or a replacement callback"
  277. )
  278. replacement_graph = common_replacement_graph
  279. replacement_placeholders = [
  280. n for n in replacement_graph.nodes if n.op == "placeholder"
  281. ]
  282. # Build connecting between replacement graph's input and original graph input producer node
  283. # Initialize `val_map` with mappings from placeholder nodes in
  284. # `replacement` to their corresponding node in `original_graph`
  285. if len(match.placeholder_nodes) != len(replacement_placeholders):
  286. raise AssertionError(
  287. f"Placeholder count mismatch: {len(match.placeholder_nodes)} vs "
  288. f"{len(replacement_placeholders)}"
  289. )
  290. val_map: dict[Node, Node] = {}
  291. for rn, gn in zip(replacement_placeholders, match.placeholder_nodes):
  292. if isinstance(gn, Node):
  293. val_map[rn] = match_changed_node.get(gn, gn)
  294. if gn != val_map[rn]:
  295. # Update match.placeholder_nodes and match.nodes_map with the node that replaced gn
  296. gn_ind = match.placeholder_nodes.index(gn)
  297. match.placeholder_nodes[gn_ind] = match_changed_node[gn]
  298. map_key = list(match.nodes_map.keys())[
  299. list(match.nodes_map.values()).index(gn)
  300. ]
  301. match.nodes_map[map_key] = match_changed_node[gn]
  302. else:
  303. val_map[rn] = gn
  304. # Copy the replacement graph over
  305. user_nodes: set[Node] = set()
  306. for n in match.returning_nodes:
  307. user_nodes.update(n.users)
  308. first_user_node = None
  309. if len(user_nodes) == 0:
  310. first_user_node = None
  311. elif len(user_nodes) == 1:
  312. first_user_node = next(iter(user_nodes))
  313. else:
  314. # If there are multiple user nodes, we need to find the first user node
  315. # in the current execution order of the `original_graph`
  316. for n in original_graph.nodes:
  317. if n in user_nodes:
  318. first_user_node = n
  319. break
  320. first_next_node = None
  321. if first_user_node is None:
  322. # no users, so we insert the replacement graph before the first next
  323. # node of returning nodes
  324. next_node = None
  325. for n in reversed(original_graph.nodes):
  326. if n in match.returning_nodes:
  327. first_next_node = next_node
  328. break
  329. else:
  330. next_node = n
  331. insert_point = (
  332. first_user_node if first_user_node is not None else first_next_node
  333. )
  334. if insert_point is None:
  335. raise AssertionError("The insert point can't be None")
  336. with original_graph.inserting_before(insert_point):
  337. copied_returning_nodes = original_graph.graph_copy(
  338. replacement_graph, val_map
  339. )
  340. if isinstance(copied_returning_nodes, Node):
  341. copied_returning_nodes = (copied_returning_nodes,)
  342. # Get a list of nodes that have been replaced into the graph
  343. replacement_nodes: list[Node] = [
  344. v for v in val_map.values() if v not in match.placeholder_nodes
  345. ]
  346. # Hook the output Node of the replacement subgraph in to the
  347. # original Graph at the correct location
  348. if len(match.returning_nodes) != len(copied_returning_nodes): # type: ignore[arg-type]
  349. raise AssertionError(
  350. f"Returning nodes count mismatch: {len(match.returning_nodes)} vs "
  351. f"{len(copied_returning_nodes)}" # pyrefly: ignore [bad-argument-type]
  352. )
  353. for gn, copied_node in zip(match.returning_nodes, copied_returning_nodes): # type: ignore[arg-type]
  354. gn.replace_all_uses_with(copied_node)
  355. match_changed_node[gn] = copied_node
  356. # Remove the original nodes
  357. for node in reversed(pattern_graph.nodes):
  358. if node.op != "placeholder" and node.op != "output":
  359. gn = match.nodes_map[node]
  360. gm.graph.erase_node(gn)
  361. match_and_replacements.append(
  362. ReplacedPatterns(
  363. anchor=match.anchors[0],
  364. nodes_map=match.nodes_map,
  365. replacements=replacement_nodes,
  366. )
  367. )
  368. # Update the passed-in GraphModule to reflect the new state of
  369. # `original_graph`
  370. gm.recompile()
  371. # If `replacement` was an nn.Module, we'll need to make sure that
  372. # all the submodules have been copied over correctly
  373. if isinstance(replacement, torch.nn.Module):
  374. _replace_attributes(gm, replacement)
  375. return match_and_replacements