| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397 |
- # mypy: allow-untyped-defs
- import collections
- import heapq
- import operator
- from collections.abc import Mapping
- from dataclasses import dataclass
- from typing import Any, Optional, Union
- import torch
- import torch.fx
- from torch.fx._compatibility import compatibility
- from torch.fx.node import _get_qualified_name
# Public API of this module: the names exported by ``from <module> import *``.
# Kept as a plain list literal so static analysers and IDEs can read it.
__all__ = [
    "get_acc_ops_name",
    "get_node_target",
    "is_node_output_tensor",
    "FxNetAccFusionsFinder",
    "legalize_graph",
    "stable_topological_sort",
]
# A sequence of tensors of any length, given either as a tuple or as a list.
# NOTE: ``tuple[torch.Tensor]`` would mean a tuple of exactly ONE tensor; the
# trailing ``...`` makes it a variable-length homogeneous tuple.
Tensors = Union[tuple[torch.Tensor, ...], list[torch.Tensor]]
# Either a single tensor or a sequence of tensors.
TensorOrTensors = Union[torch.Tensor, Tensors]
NodeList = list[torch.fx.Node]
NodeSet = set[torch.fx.Node]
Names = list[str]
# FX node ops that invoke a callable (a submodule, a free function, or a
# method), as opposed to ops such as ``placeholder`` or ``get_attr``.
CALLABLE_NODE_OPS = {"call_module", "call_function", "call_method"}
@compatibility(is_backward_compatible=False)
def get_acc_ops_name(k):
    """Return a printable, qualified name for an op target or a class.

    Args:
        k: Either a string, which is returned unchanged, or an object that
            exposes ``__name__`` and ``__module__`` (e.g. a function or a class).

    Returns:
        ``k`` itself when it is a string.
        ``"acc_ops.<name>"`` when the object's module path contains ``"acc_ops"``.
        Otherwise ``"<module>.<name>"``, with ``torch._ops`` rewritten to
        ``torch.ops``.
        When ``__module__`` is empty or ``None`` the result is ``".<name>"``.
    """
    if isinstance(k, str):
        return k

    # ``__module__`` can be ``None``. Treat that like an empty module path
    # instead of crashing on ``None.replace`` further down.
    module = k.__module__ or ""
    if "acc_ops" in module:
        return f"acc_ops.{k.__name__}"

    # WAR for bug in how torch.ops assigns module
    module = module.replace("torch._ops", "torch.ops")
    return f"{module}.{k.__name__}"
@compatibility(is_backward_compatible=False)
def get_node_target(
    submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
) -> str:
    """
    Return a string naming what ``node`` calls.

    - ``call_method``: ``node.target``, i.e. the name of the method being
      called. Same-named methods of different classes map to the same string.
      This is normally fine because such methods are usually called on tensors.
    - ``call_function``: ``"acc_ops.<name>"`` when the function's module path
      contains ``"acc_ops"``; otherwise its qualified name. A target string
      containing "_VariableFunctionsClass" is reported with "torch" instead,
      e.g. ``_VariableFunctionsClass.relu`` becomes ``torch.relu``.
    - ``call_module``: the name of the type of the submodule that
      ``node.target`` refers to. The submodule's ``_base_class_origin``
      attribute is used instead of its type when present.

    Raises:
        AssertionError: if ``node.op`` is not a callable op, or if a
            ``call_module`` / ``call_method`` node has a non-string target.
        KeyError: if a ``call_module`` target is missing from ``submodules``.
    """
    op = node.op
    if op not in CALLABLE_NODE_OPS:
        raise AssertionError(
            "Expect op types of "
            + ", ".join(CALLABLE_NODE_OPS)
            + f", but found {node.op}"
        )

    if op == "call_function":
        fn: Any = node.target
        fn_module = fn.__module__
        if fn_module is not None and "acc_ops" in fn_module:
            return f"acc_ops.{fn.__name__}"
        return _get_qualified_name(fn)

    # The two remaining ops ("call_module" and "call_method") both carry a
    # string target: an attribute path or a method name respectively.
    target_name = node.target
    if not isinstance(target_name, str):
        raise AssertionError(f"Expected str target, got {type(node.target)}")

    if op == "call_module":
        submod = submodules[target_name]
        return get_acc_ops_name(getattr(submod, "_base_class_origin", type(submod)))

    return target_name
@compatibility(is_backward_compatible=False)
def is_node_output_tensor(node: torch.fx.Node) -> bool:
    """Checks if the node output produces a Tensor or not.

    NOTE: This requires to run `ShapeProp` on the containing fx graph before
    calling this function. This is because it works by checking the `type`
    metadata on the node. This metadata is produced by the `ShapeProp`.

    Returns:
        True when ``node.meta["type"]`` is ``torch.Tensor`` or a subclass of it.
        False when the metadata is missing, names a non-tensor class, or is not
        a class at all (e.g. a parameterized generic such as ``list[int]``).
    """
    type_ = node.meta.get("type", None)
    if type_ is None:
        return False
    try:
        return issubclass(type_, torch.Tensor)
    except TypeError:
        # ``issubclass`` rejects values that are not classes. Such metadata
        # cannot describe a tensor type, so report False instead of crashing.
        return False
@compatibility(is_backward_compatible=False)
class FxNetAccFusionsFinder:
    """
    Finds groups of connected ACC nodes that pass non-tensor data between each other.
    Such groups are called fusion groups.

    Usage: construct with a GraphModule and the set of nodes treated as ACC
    nodes, then call the instance. The call returns a dict that maps every node
    of each accepted fusion group to that group's node set.

    NOTE: the ``acc_nodes`` set passed to the constructor is stored by
    reference and mutated in place by ``__call__``. All nodes of any fusion
    group that contains a non-ACC node are removed from it.
    """

    def __init__(self, module: torch.fx.GraphModule, acc_nodes: NodeSet) -> None:
        self.module = module
        # Nodes in graph order. A node's position in this list is its "index"
        # as used throughout this class.
        self.nodes = list(module.graph.nodes)
        # Stored by reference (not copied): __call__ may shrink this set.
        self.acc_nodes = acc_nodes
        # O(1) lookup of a node's position in ``self.nodes``.
        self.node_index = {node: i for i, node in enumerate(self.nodes)}

    @dataclass
    class FusionGroup:
        # Smallest index (position in the finder's graph-ordered node list)
        # among the nodes of this fusion group.
        top_node_idx: int

        # Nodes in this fusion group.
        nodes: NodeSet

        # Inputs to this fusion group: producers of group nodes that are not
        # themselves in the group.
        inputs: NodeSet

        # Nodes in the fusion group that haven't been processed yet by
        # FxNetAccFusionsFinder.__call__.
        nodes_need_process: NodeSet

        def add_node(self, node: torch.fx.Node) -> None:
            """
            Add a node to fusion group.

            This is a no-op if ``node`` is already a member. Otherwise ``node``
            is queued for processing and stops counting as a group input. Its
            callable (call_module / call_function / call_method) producers that
            are not already in the group become group inputs.
            """
            if node in self.nodes:
                return

            self.nodes_need_process.add(node)
            self.nodes.add(node)
            self.inputs.discard(node)
            self.inputs.update(
                {
                    n
                    for n in node.all_input_nodes
                    if n.op in CALLABLE_NODE_OPS and n not in self.nodes
                }
            )

    def recursive_add_node(
        self,
        fusion_group: "FxNetAccFusionsFinder.FusionGroup",
        inputs: Union[NodeSet, NodeList],
        visited: Optional[NodeSet] = None,
    ) -> bool:
        """
        Start from inputs and going reverse topological order. If any upstream node
        is in the fusion group, add all the nodes in this path to fusion group.

        The walk is depth-first through ``all_input_nodes``. It stops at the
        first path that reaches a member of ``fusion_group``: every node on that
        path (except the member it reached) is added via ``add_node``, and True
        is returned. False is returned when no such path exists.

        ``visited`` (when not None) is shared across the whole walk so each node
        is explored at most once. Passing None disables that de-duplication.
        """
        for arg in inputs:
            # skip the node if already seen
            if visited is not None:
                if arg in visited:
                    continue
                visited.add(arg)

            # Skip placeholder and get_attr because they won't be in the fusion group.
            if arg.op not in CALLABLE_NODE_OPS:
                continue

            # If the node has smaller idx, it's already an upstream node of the fusion
            # group. We don't need to check it anymore.
            # (Assuming graph order is topological, nothing upstream of such a
            # node can be a group member.)
            if self.node_index[arg] < fusion_group.top_node_idx:
                continue

            # If the node is in the fusion group, return True.
            if arg in fusion_group.nodes:
                return True

            # Check the upstream nodes of the node, if any of them is in the fusion group
            # we'll add this node to fusion group and return True.
            if self.recursive_add_node(fusion_group, arg.all_input_nodes, visited):
                # Callers pass ``fusion_group.inputs`` as ``inputs``, and
                # add_node mutates that very set. This is safe only because
                # every add_node is immediately followed by a return, so the
                # set iterator is never advanced after the mutation.
                fusion_group.add_node(arg)
                return True

        return False

    def __call__(self) -> dict[torch.fx.Node, NodeSet]:
        """
        Compute the fusion groups.

        Seeds are ACC nodes with a callable op and no ``"tensor_meta"`` entry
        in ``node.meta`` (presumably nodes whose output is not a tensor, if
        ``tensor_meta`` comes from ShapeProp; TODO confirm). Each seed's group
        is grown until no unprocessed members remain:
        - users of members lacking ``"tensor_meta"`` are added;
        - callable producers lacking ``"tensor_meta"`` are added;
        - nodes on paths between group inputs and the group are added
          (via recursive_add_node).

        Returns:
            A dict mapping each node of every accepted group to that group's
            node set. All members of a group share the same set object.

        Side effects:
            A group containing any node not in ``self.acc_nodes`` is rejected.
            All its nodes are removed from ``self.acc_nodes``, which is the
            caller's set and is mutated in place.
        """
        result: dict[torch.fx.Node, NodeSet] = {}
        # Iterate a snapshot: self.acc_nodes may shrink inside the loop.
        acc_nodes = list(self.acc_nodes)

        for node in acc_nodes:
            if node in result:
                continue
            if node.op not in CALLABLE_NODE_OPS:
                continue
            if "tensor_meta" in node.meta:
                continue
            # Skip nodes removed by a previously rejected group.
            if node not in self.acc_nodes:
                continue

            fusion_group: FxNetAccFusionsFinder.FusionGroup = self.FusionGroup(
                top_node_idx=self.node_index[node],
                nodes={node},
                inputs=set(node.all_input_nodes),
                nodes_need_process={node},
            )
            while fusion_group.nodes_need_process:
                # NOTE: this rebinds the outer loop variable ``node``. That is
                # harmless because the for-loop rebinds it on its next iteration.
                node = fusion_group.nodes_need_process.pop()
                self.recursive_add_node(
                    fusion_group,
                    fusion_group.inputs,
                    visited=set(),
                )

                # Optionally add downstream nodes
                if "tensor_meta" not in node.meta:
                    for user in node.users:
                        if user.op not in CALLABLE_NODE_OPS:
                            continue
                        if user in fusion_group.nodes:
                            continue

                        fusion_group.add_node(user)
                        self.recursive_add_node(
                            fusion_group,
                            fusion_group.inputs,
                            visited=set(),
                        )

                # Add some upstream nodes
                for arg in node.all_input_nodes:
                    if arg.op not in CALLABLE_NODE_OPS:
                        continue
                    if "tensor_meta" in arg.meta:
                        continue
                    if arg in fusion_group.nodes:
                        continue

                    fusion_group.add_node(arg)
                    # An upstream addition may precede every current member.
                    fusion_group.top_node_idx = min(
                        fusion_group.top_node_idx, self.node_index[arg]
                    )
                    self.recursive_add_node(
                        fusion_group,
                        fusion_group.inputs,
                        visited=set(),
                    )

            # Accept the group only if every member is an ACC node; otherwise
            # drop all its members from the (caller-owned) ACC set.
            if not (set(fusion_group.nodes) <= self.acc_nodes):
                self.acc_nodes -= fusion_group.nodes
            else:
                for n in fusion_group.nodes:
                    result[n] = fusion_group.nodes

        return result
- @compatibility(is_backward_compatible=False)
- def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
- """
- Replace the graph of the given GraphModule with one that contains the same nodes as the
- original, but in topologically sorted order.
- This is used by the merge_matmul transformation below, which disturbs the topologically sorted
- order of its input GraphModule, so that this order is restored before further transformation.
- Arguments:
- gm: The graph module to topologically sort. It is modified in-place.
- Returns:
- The graph module in-place sorted
- Warning:
- This topological sort is NOT stable, it will NOT preserve the original node order.
- If you need a stable topological sort, use stable_topological_sort instead.
- """
- # These operators are used for making runtime assertions before any
- # data-dependent operators occur. We want to prioritize sorting these to
- # ensure that these assertions appear before any data-dependent operations
- # in the graph.
- PRIORITIZED_OPS = [
- operator.add,
- operator.mul,
- operator.sub,
- operator.floordiv,
- operator.truediv,
- operator.mod,
- operator.le,
- operator.lt,
- operator.ge,
- operator.gt,
- operator.eq,
- operator.ne,
- torch.ops.aten.sym_constrain_range.default,
- torch.ops.aten.sym_constrain_range_for_size.default,
- torch.ops.aten._assert_async.msg,
- torch.ops.aten.scalar_tensor.default,
- torch.ops.aten._assert_scalar.default,
- ]
- indeg = dict.fromkeys(gm.graph.nodes, 0)
- new_graph = torch.fx.Graph()
- # Track how many unfulfilled dependencies each node has
- for node in gm.graph.nodes:
- for user in node.users:
- indeg[user] += 1
- queue: collections.deque = collections.deque()
- # Add all nodes with no dependencies to the queue
- for node in gm.graph.nodes:
- if indeg[node] == 0:
- queue.append(node)
- env: dict[torch.fx.Node, torch.fx.Node] = {}
- # Pop nodes from the queue, and add nodes that have had all their
- # dependencies fulfilled
- while len(queue) > 0:
- cur = queue.popleft()
- env[cur] = new_graph.node_copy(cur, lambda x: env[x])
- for user in cur.users:
- indeg[user] -= 1
- if indeg[user] == 0:
- if user.op == "call_function" and user.target in PRIORITIZED_OPS:
- queue.appendleft(user)
- else:
- queue.append(user)
- # If the new graph's size is not as large as the old one, then there must be
- # a cycle (i.e. some node's dependencies were not satisfied.)
- if len(new_graph.nodes) < len(gm.graph.nodes):
- raise RuntimeError(
- f"Input graph has cycles, unable to add {[node for node in indeg if indeg[node] != 0]}"
- )
- new_graph._codegen = gm.graph._codegen
- gm.graph = new_graph
- return gm
@compatibility(is_backward_compatible=False)
def stable_topological_sort(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    Rebuild ``gm``'s graph in topological order, keeping the original order wherever possible.

    The resulting order:
    1. respects every data dependency, and
    2. otherwise keeps nodes in their original relative order.

    Implementation: Kahn's algorithm with a min-heap keyed on each node's
    original position. Among all nodes whose dependencies are satisfied, the
    one that appeared earliest in the original graph is always emitted next.

    Arguments:
        gm: The graph module to topologically sort. It is modified in-place.

    Returns:
        The graph module in-place sorted

    Raises:
        AssertionError: if the graph contains a cycle.
    """
    old_graph = gm.graph
    sorted_graph = torch.fx.Graph()

    # One pass records each node's original position and counts, for every
    # node, how many producers must be emitted before it.
    waiting_on = dict.fromkeys(old_graph.nodes, 0)
    position: dict[torch.fx.Node, int] = {}
    for idx, producer in enumerate(old_graph.nodes):
        position[producer] = idx
        for consumer in producer.users:
            waiting_on[consumer] += 1

    # Min-heap of (original_position, node). Positions are unique, so nodes
    # themselves are never compared.
    frontier: list[tuple[int, torch.fx.Node]] = [
        (position[n], n) for n in old_graph.nodes if waiting_on[n] == 0
    ]
    heapq.heapify(frontier)

    remap: dict[torch.fx.Node, torch.fx.Node] = {}
    while frontier:
        _, current = heapq.heappop(frontier)
        remap[current] = sorted_graph.node_copy(current, lambda n: remap[n])
        for consumer in current.users:
            waiting_on[consumer] -= 1
            if waiting_on[consumer] == 0:
                heapq.heappush(frontier, (position[consumer], consumer))

    # Unemitted nodes still wait on a producer, which only happens with a cycle.
    if len(sorted_graph.nodes) != len(old_graph.nodes):
        raise AssertionError(
            f"Input graph has cycles, unable to add {[n for n in waiting_on if waiting_on[n] != 0]}"
        )

    sorted_graph._codegen = old_graph._codegen
    gm.graph = sorted_graph
    return gm
|