| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451 |
- import copy
- from collections.abc import Callable
- from dataclasses import dataclass
- from typing import Any, NamedTuple, Optional, TYPE_CHECKING, Union
- import torch
- from ._compatibility import compatibility
- from ._symbolic_trace import symbolic_trace
- from .graph import Graph
- from .graph_module import GraphModule
- from .node import Node
- if TYPE_CHECKING:
- from .passes.utils.matcher_with_name_node_map_utils import InternalMatch
- __all__ = [
- "Match",
- "replace_pattern",
- "replace_pattern_with_filters",
- "ReplacedPatterns",
- ]
@compatibility(is_backward_compatible=True)
class Match(NamedTuple):
    """One occurrence of a pattern subgraph found inside a larger graph."""

    # The node from which this match was found.
    anchor: Node
    # Pattern-subgraph node -> the node it matched in the larger graph.
    nodes_map: dict[Node, Node]
@compatibility(is_backward_compatible=False)
@dataclass
class ReplacedPatterns:
    """Record of one pattern match together with the nodes that replaced it."""

    # The node from which this match was found.
    anchor: Node
    # Pattern-subgraph node -> the node it matched in the larger graph.
    nodes_map: dict[Node, Node]
    # Nodes that were added to the graph for this match.
    replacements: list[Node]
- def _replace_attributes(gm: GraphModule, replacement: torch.nn.Module) -> None:
- gm.delete_all_unused_submodules()
- if isinstance(replacement, GraphModule):
- replacement.graph.lint()
- def try_get_attr(gm: torch.nn.Module, target: str) -> Optional[Any]:
- module_path, _, attr_name = target.rpartition(".")
- try:
- mod: torch.nn.Module = gm.get_submodule(module_path)
- except AttributeError:
- return None
- attr = getattr(mod, attr_name, None)
- return attr
- for node in gm.graph.nodes:
- if node.op == "call_module" or node.op == "get_attr":
- gm_attr = try_get_attr(gm, node.target)
- replacement_attr = try_get_attr(replacement, node.target)
- # CASE 1: This target already exists as an attribute in our
- # result GraphModule. Whether or not it exists in
- # `replacement`, the existing submodule takes precedence.
- if gm_attr is not None:
- continue
- # CASE 2: The target exists as an attribute in `replacement`
- # only, so we need to copy it over.
- elif replacement_attr is not None:
- new_attr = copy.deepcopy(replacement_attr)
- if isinstance(replacement_attr, torch.nn.Module):
- gm.add_submodule(node.target, new_attr)
- else:
- setattr(gm, node.target, new_attr)
- # CASE 3: The target doesn't exist as an attribute in `gm`
- # or `replacement`
- else:
- raise RuntimeError(
- 'Attempted to create a "',
- node.op,
- '" node during subgraph rewriting '
- f"with target {node.target}, but "
- "the referenced attribute does not "
- "exist in the replacement GraphModule",
- )
- gm.graph.lint()
@compatibility(is_backward_compatible=True)
def replace_pattern(
    gm: GraphModule,
    pattern: Union[Callable, GraphModule],
    replacement: Union[Callable, GraphModule],
) -> list[Match]:
    """
    Matches all possible non-overlapping sets of operators and their
    data dependencies (``pattern``) in the Graph of a GraphModule
    (``gm``), then replaces each of these matched subgraphs with another
    subgraph (``replacement``).

    Args:
        ``gm``: The GraphModule that wraps the Graph to operate on
        ``pattern``: The subgraph to match in ``gm`` for replacement
        ``replacement``: The subgraph to replace ``pattern`` with

    Returns:
        List[Match]: A list of ``Match`` objects representing the places
        in the original graph that ``pattern`` was matched to. The list
        is empty if there are no matches. ``Match`` is defined as:

        .. code-block:: python

            class Match(NamedTuple):
                # Node from which the match was found
                anchor: Node
                # Maps nodes in the pattern subgraph to nodes in the larger graph
                nodes_map: Dict[Node, Node]

    Examples:

    .. code-block:: python

        import torch
        from torch.fx import symbolic_trace, subgraph_rewriter


        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, w1, w2):
                m1 = torch.cat([w1, w2]).sum()
                m2 = torch.cat([w1, w2]).sum()
                return x + torch.max(m1) + torch.max(m2)


        def pattern(w1, w2):
            return torch.cat([w1, w2])


        def replacement(w1, w2):
            return torch.stack([w1, w2])


        traced_module = symbolic_trace(M())

        subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)

    The above code will first match ``pattern`` in the ``forward``
    method of ``traced_module``. Pattern-matching is done based on
    use-def relationships, not node names. For example, if you had
    ``p = torch.cat([a, b])`` in ``pattern``, you could match
    ``m = torch.cat([a, b])`` in the original ``forward`` function,
    despite the variable names being different (``p`` vs ``m``).

    The ``return`` statement in ``pattern`` is matched based on its
    value only; it may or may not match to the ``return`` statement in
    the larger graph. In other words, the pattern doesn't have to extend
    to the end of the larger graph.

    When the pattern is matched, it will be removed from the larger
    function and replaced by ``replacement``. If there are multiple
    matches for ``pattern`` in the larger function, each non-overlapping
    match will be replaced. In the case of a match overlap, the first
    found match in the set of overlapping matches will be replaced.
    ("First" here being defined as the first in a topological ordering
    of the Nodes' use-def relationships. In most cases, the first Node
    is the parameter that appears directly after ``self``, while the
    last Node is whatever the function returns.)

    One important thing to note is that the parameters of the
    ``pattern`` Callable must be used in the Callable itself,
    and the parameters of the ``replacement`` Callable must match
    the pattern. The first rule is why, in the above code block, the
    ``forward`` function has parameters ``x, w1, w2``, but the
    ``pattern`` function only has parameters ``w1, w2``. ``pattern``
    doesn't use ``x``, so it shouldn't specify ``x`` as a parameter.
    As an example of the second rule, consider replacing

    .. code-block:: python

        def pattern(x, y):
            return torch.neg(x) + torch.relu(y)

    with

    .. code-block:: python

        def replacement(x, y):
            return torch.relu(x)

    In this case, ``replacement`` needs the same number of parameters
    as ``pattern`` (both ``x`` and ``y``), even though the parameter
    ``y`` isn't used in ``replacement``.

    After calling ``subgraph_rewriter.replace_pattern``, the generated
    Python code looks like this:

    .. code-block:: python

        def forward(self, x, w1, w2):
            stack_1 = torch.stack([w1, w2])
            sum_1 = stack_1.sum()
            stack_2 = torch.stack([w1, w2])
            sum_2 = stack_2.sum()
            max_1 = torch.max(sum_1)
            add_1 = x + max_1
            max_2 = torch.max(sum_2)
            add_2 = add_1 + max_2
            return add_2
    """
    # The shared implementation reports richer `ReplacedPatterns` records;
    # this stable API exposes only the anchor and node mapping of each one.
    matches: list[Match] = []
    for replaced in _replace_pattern(gm, pattern, replacement):
        matches.append(Match(anchor=replaced.anchor, nodes_map=replaced.nodes_map))
    return matches
# Experimental API, not backward compatible
@compatibility(is_backward_compatible=False)
def replace_pattern_with_filters(
    gm: GraphModule,
    pattern: Union[Callable, Graph, GraphModule],
    replacement: Union[Callable, Graph, GraphModule, None] = None,
    match_filters: Optional[
        list[Callable[["InternalMatch", Graph, Graph], bool]]
    ] = None,
    ignore_literals: bool = False,
    # Placed at the end to avoid breaking backward compatibility
    replacement_callback: Optional[
        Callable[["InternalMatch", Graph, Graph], Graph]
    ] = None,
    node_name_match: str = "",
) -> list[ReplacedPatterns]:
    """
    See replace_pattern for documentation. This function is a variant of it
    that additionally supports match filters, a per-match replacement
    callback, and node-name matching.

    Args:
        ``gm``: The GraphModule that wraps the Graph to operate on
        ``pattern``: The subgraph to match in ``gm`` for replacement
        ``replacement``: The subgraph to replace ``pattern`` with. May be
            ``None`` only when ``replacement_callback`` is given.
        ``match_filters``: A list of functions that take in
            (match: InternalMatch, original_graph: Graph, pattern_graph: Graph) and return a boolean indicating
            whether the match satisfies the condition. A match is replaced
            only if every filter accepts it.
            See matcher_utils.py for definition of InternalMatch.
        ``ignore_literals``: Passed through to the subgraph matcher's
            ``ignore_literals`` option.
        ``replacement_callback``: A function that takes in
            (match: InternalMatch, original_graph: Graph, pattern_graph: Graph) and returns a
            Graph to be used as the replacement. This allows you to construct a
            replacement graph based on the match.
        ``node_name_match``: Node name to match. If not empty, it will try to match the node name.

    Returns:
        List[ReplacedPatterns]: One entry per replaced match, recording the
        match anchor, the pattern-to-graph node mapping and the nodes that
        were added to the graph.
    """
    return _replace_pattern(
        gm,
        pattern,
        replacement,
        match_filters=match_filters,
        ignore_literals=ignore_literals,
        replacement_callback=replacement_callback,
        node_name_match=node_name_match,
    )
- def _replace_pattern(
- gm: GraphModule,
- pattern: Union[Callable, Graph, GraphModule],
- replacement: Union[Callable, Graph, GraphModule, None] = None,
- match_filters: Optional[
- list[Callable[["InternalMatch", Graph, Graph], bool]]
- ] = None,
- ignore_literals: bool = False,
- # Placed at the end to avoid breaking backward compatibility
- replacement_callback: Optional[
- Callable[["InternalMatch", Graph, Graph], Graph]
- ] = None,
- node_name_match: str = "",
- ) -> list[ReplacedPatterns]:
- from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher
- if match_filters is None:
- match_filters = []
- # Get the graphs for `gm`, `pattern`, `replacement`
- original_graph: Graph = gm.graph
- if isinstance(pattern, GraphModule):
- pattern_graph = pattern.graph
- elif isinstance(pattern, Graph):
- pattern_graph = pattern
- else:
- pattern_graph = symbolic_trace(pattern).graph # type: ignore[arg-type]
- matcher = SubgraphMatcher(
- pattern_graph,
- match_output=False,
- match_placeholder=False,
- remove_overlapping_matches=True,
- ignore_literals=ignore_literals,
- )
- _matches: list[InternalMatch] = matcher.match(
- original_graph, node_name_match=node_name_match
- )
- # Filter out matches that don't match the filter
- _matches = [
- m
- for m in _matches
- if all(
- match_filter(m, original_graph, pattern_graph)
- for match_filter in match_filters
- )
- ]
- if isinstance(replacement, GraphModule):
- common_replacement_graph = replacement.graph
- elif isinstance(replacement, Graph):
- common_replacement_graph = replacement
- elif callable(replacement):
- common_replacement_graph = symbolic_trace(replacement).graph
- else:
- if replacement_callback is None:
- raise AssertionError(
- "Must provide either a replacement GraphModule or a replacement callback"
- )
- common_replacement_graph = None # type: ignore[assignment]
- # As we progressively replace nodes, we'll need to keep track of how the match results should change
- match_changed_node: dict[Node, Node] = {}
- match_and_replacements = []
- for match in _matches:
- if replacement_callback is not None:
- replacement_graph = replacement_callback(
- match, original_graph, pattern_graph
- )
- else:
- if common_replacement_graph is None:
- raise AssertionError(
- "Must provide either a replacement GraphModule or a replacement callback"
- )
- replacement_graph = common_replacement_graph
- replacement_placeholders = [
- n for n in replacement_graph.nodes if n.op == "placeholder"
- ]
- # Build connecting between replacement graph's input and original graph input producer node
- # Initialize `val_map` with mappings from placeholder nodes in
- # `replacement` to their corresponding node in `original_graph`
- if len(match.placeholder_nodes) != len(replacement_placeholders):
- raise AssertionError(
- f"Placeholder count mismatch: {len(match.placeholder_nodes)} vs "
- f"{len(replacement_placeholders)}"
- )
- val_map: dict[Node, Node] = {}
- for rn, gn in zip(replacement_placeholders, match.placeholder_nodes):
- if isinstance(gn, Node):
- val_map[rn] = match_changed_node.get(gn, gn)
- if gn != val_map[rn]:
- # Update match.placeholder_nodes and match.nodes_map with the node that replaced gn
- gn_ind = match.placeholder_nodes.index(gn)
- match.placeholder_nodes[gn_ind] = match_changed_node[gn]
- map_key = list(match.nodes_map.keys())[
- list(match.nodes_map.values()).index(gn)
- ]
- match.nodes_map[map_key] = match_changed_node[gn]
- else:
- val_map[rn] = gn
- # Copy the replacement graph over
- user_nodes: set[Node] = set()
- for n in match.returning_nodes:
- user_nodes.update(n.users)
- first_user_node = None
- if len(user_nodes) == 0:
- first_user_node = None
- elif len(user_nodes) == 1:
- first_user_node = next(iter(user_nodes))
- else:
- # If there are multiple user nodes, we need to find the first user node
- # in the current execution order of the `original_graph`
- for n in original_graph.nodes:
- if n in user_nodes:
- first_user_node = n
- break
- first_next_node = None
- if first_user_node is None:
- # no users, so we insert the replacement graph before the first next
- # node of returning nodes
- next_node = None
- for n in reversed(original_graph.nodes):
- if n in match.returning_nodes:
- first_next_node = next_node
- break
- else:
- next_node = n
- insert_point = (
- first_user_node if first_user_node is not None else first_next_node
- )
- if insert_point is None:
- raise AssertionError("The insert point can't be None")
- with original_graph.inserting_before(insert_point):
- copied_returning_nodes = original_graph.graph_copy(
- replacement_graph, val_map
- )
- if isinstance(copied_returning_nodes, Node):
- copied_returning_nodes = (copied_returning_nodes,)
- # Get a list of nodes that have been replaced into the graph
- replacement_nodes: list[Node] = [
- v for v in val_map.values() if v not in match.placeholder_nodes
- ]
- # Hook the output Node of the replacement subgraph in to the
- # original Graph at the correct location
- if len(match.returning_nodes) != len(copied_returning_nodes): # type: ignore[arg-type]
- raise AssertionError(
- f"Returning nodes count mismatch: {len(match.returning_nodes)} vs "
- f"{len(copied_returning_nodes)}" # pyrefly: ignore [bad-argument-type]
- )
- for gn, copied_node in zip(match.returning_nodes, copied_returning_nodes): # type: ignore[arg-type]
- gn.replace_all_uses_with(copied_node)
- match_changed_node[gn] = copied_node
- # Remove the original nodes
- for node in reversed(pattern_graph.nodes):
- if node.op != "placeholder" and node.op != "output":
- gn = match.nodes_map[node]
- gm.graph.erase_node(gn)
- match_and_replacements.append(
- ReplacedPatterns(
- anchor=match.anchors[0],
- nodes_map=match.nodes_map,
- replacements=replacement_nodes,
- )
- )
- # Update the passed-in GraphModule to reflect the new state of
- # `original_graph`
- gm.recompile()
- # If `replacement` was an nn.Module, we'll need to make sure that
- # all the submodules have been copied over correctly
- if isinstance(replacement, torch.nn.Module):
- _replace_attributes(gm, replacement)
- return match_and_replacements
|