tools_common.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. # mypy: allow-untyped-defs
  2. import collections
  3. import heapq
  4. import operator
  5. from collections.abc import Mapping
  6. from dataclasses import dataclass
  7. from typing import Any, Optional, Union
  8. import torch
  9. import torch.fx
  10. from torch.fx._compatibility import compatibility
  11. from torch.fx.node import _get_qualified_name
  12. __all__ = [
  13. "get_acc_ops_name",
  14. "get_node_target",
  15. "is_node_output_tensor",
  16. "FxNetAccFusionsFinder",
  17. "legalize_graph",
  18. "stable_topological_sort",
  19. ]
  20. Tensors = Union[tuple[torch.Tensor], list[torch.Tensor]]
  21. TensorOrTensors = Union[torch.Tensor, Tensors]
  22. NodeList = list[torch.fx.Node]
  23. NodeSet = set[torch.fx.Node]
  24. Names = list[str]
  25. CALLABLE_NODE_OPS = {"call_module", "call_function", "call_method"}
  26. @compatibility(is_backward_compatible=False)
  27. def get_acc_ops_name(k):
  28. if isinstance(k, str):
  29. return k
  30. elif k.__module__ and "acc_ops" in k.__module__:
  31. return f"acc_ops.{k.__name__}"
  32. else:
  33. module = k.__module__.replace(
  34. "torch._ops", "torch.ops"
  35. ) # WAR for bug in how torch.ops assigns module
  36. return f"{module if module else ''}.{k.__name__}"
  37. @compatibility(is_backward_compatible=False)
  38. def get_node_target(
  39. submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
  40. ) -> str:
  41. """
  42. Given a `node` returns its target typename.
  43. For "call_method" node, return node.target which is the name of that method being called.
  44. This could potential lead to conflict but should be okay because normally it's on a tensor.
  45. For "call_function" node, return typename of node.target.
  46. For "call_module" node, return typename of the module that node.target point to.
  47. If seeing "_VariableFunctionsClass" in the target name string, it will be replaced by
  48. "torch". e.g. _VariableFunctionsClass.relu would become torch.relu.
  49. """
  50. if node.op not in CALLABLE_NODE_OPS:
  51. raise AssertionError(
  52. "Expect op types of "
  53. + ", ".join(CALLABLE_NODE_OPS)
  54. + f", but found {node.op}"
  55. )
  56. if node.op == "call_module":
  57. if not isinstance(node.target, str):
  58. raise AssertionError(f"Expected str target, got {type(node.target)}")
  59. submod = submodules[node.target]
  60. submod_type = getattr(submod, "_base_class_origin", type(submod))
  61. return get_acc_ops_name(submod_type)
  62. elif node.op == "call_function":
  63. target: Any = node.target
  64. return (
  65. f"acc_ops.{target.__name__}"
  66. if target.__module__ is not None and "acc_ops" in target.__module__
  67. else _get_qualified_name(target)
  68. )
  69. else:
  70. if not isinstance(node.target, str):
  71. raise AssertionError(f"Expected str target, got {type(node.target)}")
  72. return node.target
  73. @compatibility(is_backward_compatible=False)
  74. def is_node_output_tensor(node: torch.fx.Node) -> bool:
  75. """Checks if the node output produces a Tensor or not.
  76. NOTE: This requires to run `ShapeProp` on the containing fx graph before
  77. calling this function. This is because it works by checking the `type`
  78. metadata on the node. This metadata is produced by the `ShapeProp`.
  79. """
  80. type_ = node.meta.get("type", None)
  81. return type_ is not None and issubclass(type_, torch.Tensor)
@compatibility(is_backward_compatible=False)
class FxNetAccFusionsFinder:
    """
    Finds groups of connected ACC nodes that pass non-tensor data between each other.
    Such groups are called fusion groups.

    A node counts as producing tensor data when its ``meta`` contains
    ``"tensor_meta"``. Callable nodes without that entry count as passing
    non-tensor data. NOTE(review): ``tensor_meta`` is presumably filled in by a
    shape-propagation pass before this finder runs; confirm with callers.

    Calling an instance returns a dict that maps every node of each accepted
    fusion group to that group's node set.

    A group that contains any node outside ``acc_nodes`` is rejected. Its nodes
    are then removed from ``self.acc_nodes`` in place with ``-=``. For a set
    argument, as annotated, this mutates the very set passed to ``__init__``.
    """

    def __init__(self, module: torch.fx.GraphModule, acc_nodes: NodeSet):
        self.module = module
        # All graph nodes in graph order. Positions in this list are the
        # indices that FusionGroup.top_node_idx refers to.
        self.nodes = list(module.graph.nodes)
        self.acc_nodes = acc_nodes
        # Node -> position in self.nodes, for constant-time index lookups.
        self.node_index = {node: i for i, node in enumerate(self.nodes)}

    @dataclass
    class FusionGroup:
        # The smallest idx of nodes in the fusion group after topological sorting all the nodes in the model.
        top_node_idx: int

        # Nodes in this fusion group.
        nodes: NodeSet

        # Inputs to this fusion group.
        inputs: NodeSet

        # Nodes in the fusion group that haven't been processed yet.
        nodes_need_process: NodeSet

        def add_node(self, node: torch.fx.Node) -> None:
            """
            Add a node to fusion group.

            Does nothing if ``node`` is already a member. Otherwise the node:
            - is queued in ``nodes_need_process``;
            - is removed from ``inputs``, since it is now internal to the group;
            - contributes its callable inputs that are not yet members as new
              group inputs.
            """
            if node in self.nodes:
                return

            self.nodes_need_process.add(node)
            self.nodes.add(node)
            self.inputs.discard(node)
            self.inputs.update(
                {
                    n
                    for n in node.all_input_nodes
                    if n.op in CALLABLE_NODE_OPS and n not in self.nodes
                }
            )

    def recursive_add_node(
        self,
        fusion_group: "FxNetAccFusionsFinder.FusionGroup",
        inputs: Union[NodeSet, NodeList],
        visited: Optional[NodeSet] = None,
    ) -> bool:
        """
        Start from inputs and going reverse topological order. If any upstream node
        is in the fusion group, add all the nodes in this path to fusion group.

        Args:
            fusion_group: The group being grown.
            inputs: Nodes to search from. ``__call__`` passes the group's own
                ``inputs`` set; recursive calls pass a node's ``all_input_nodes``.
            visited: Nodes already explored in this search. The same set is
                passed down every recursive call, so each node is explored at
                most once. When None, no de-duplication is done.

        Returns:
            True if a path back into the group was found. The nodes on that
            path have then been added via ``add_node``. False otherwise.
            The search stops at the first such path. ``__call__`` calls this
            again after each node it processes, so further paths are picked
            up later.

        Note: ``inputs`` may be ``fusion_group.inputs`` itself, which
        ``add_node`` mutates. That is only safe because the loop returns
        immediately after every ``add_node`` call and never resumes iterating.
        """
        for arg in inputs:
            # skip the node if already seen
            if visited is not None:
                if arg in visited:
                    continue
                visited.add(arg)

            # Skip placeholder and get_attr because they won't be in the fusion group.
            if arg.op not in CALLABLE_NODE_OPS:
                continue

            # If the node has smaller idx, it's already an upstream node of the fusion
            # group. We don't need to check it anymore.
            if self.node_index[arg] < fusion_group.top_node_idx:
                continue

            # If the node is in the fusion group, return True.
            if arg in fusion_group.nodes:
                return True

            # Check the upstream nodes of the node, if any of them is in the fusion group
            # we'll add this node to fusion group and return True.
            if self.recursive_add_node(fusion_group, arg.all_input_nodes, visited):
                fusion_group.add_node(arg)
                return True

        return False

    def __call__(self) -> dict[torch.fx.Node, NodeSet]:
        """
        Build the fusion groups.

        Returns:
            A dict mapping each node of every accepted fusion group to that
            group's node set. All members of one group share the same set
            object. Nodes that are not in an accepted group are absent.
        """
        result: dict[torch.fx.Node, NodeSet] = {}
        # Iterate over a snapshot, because self.acc_nodes shrinks below
        # whenever a group is rejected.
        acc_nodes = list(self.acc_nodes)

        for node in acc_nodes:
            # Already part of an accepted group.
            if node in result:
                continue
            if node.op not in CALLABLE_NODE_OPS:
                continue
            # Only nodes that pass non-tensor data seed a fusion group.
            if "tensor_meta" in node.meta:
                continue
            # Removed from acc_nodes by an earlier rejected group.
            if node not in self.acc_nodes:
                continue

            fusion_group: FxNetAccFusionsFinder.FusionGroup = self.FusionGroup(
                top_node_idx=self.node_index[node],
                nodes={node},
                inputs=set(node.all_input_nodes),
                nodes_need_process={node},
            )
            # Grow the group until no newly added node is left to process.
            while fusion_group.nodes_need_process:
                # NOTE: this rebinds the for-loop variable `node`. The for-loop
                # reassigns it on its next iteration, so only the code below
                # sees the popped node.
                node = fusion_group.nodes_need_process.pop()
                # Pull in nodes lying on a path from the group's inputs back
                # into the group.
                self.recursive_add_node(
                    fusion_group,
                    fusion_group.inputs,
                    visited=set(),
                )

                # Optionally add downstream nodes. A node without tensor_meta
                # passes non-tensor data, so every callable user of it joins
                # the group.
                if "tensor_meta" not in node.meta:
                    for user in node.users:
                        if user.op not in CALLABLE_NODE_OPS:
                            continue
                        if user in fusion_group.nodes:
                            continue

                        fusion_group.add_node(user)
                        self.recursive_add_node(
                            fusion_group,
                            fusion_group.inputs,
                            visited=set(),
                        )

                # Add some upstream nodes: callable inputs without
                # tensor_meta, i.e. those feeding non-tensor data into this
                # node. top_node_idx is lowered so it stays the smallest index
                # in the group, which recursive_add_node uses for pruning.
                for arg in node.all_input_nodes:
                    if arg.op not in CALLABLE_NODE_OPS:
                        continue
                    if "tensor_meta" in arg.meta:
                        continue
                    if arg in fusion_group.nodes:
                        continue

                    fusion_group.add_node(arg)
                    fusion_group.top_node_idx = min(
                        fusion_group.top_node_idx, self.node_index[arg]
                    )
                    self.recursive_add_node(
                        fusion_group,
                        fusion_group.inputs,
                        visited=set(),
                    )

            # Reject the whole group if any member is not an acc node. Its
            # members are also dropped from acc_nodes, so any later group that
            # reaches them is rejected too.
            if not (set(fusion_group.nodes) <= self.acc_nodes):
                self.acc_nodes -= fusion_group.nodes
            else:
                for n in fusion_group.nodes:
                    result[n] = fusion_group.nodes

        return result
  212. @compatibility(is_backward_compatible=False)
  213. def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
  214. """
  215. Replace the graph of the given GraphModule with one that contains the same nodes as the
  216. original, but in topologically sorted order.
  217. This is used by the merge_matmul transformation below, which disturbs the topologically sorted
  218. order of its input GraphModule, so that this order is restored before further transformation.
  219. Arguments:
  220. gm: The graph module to topologically sort. It is modified in-place.
  221. Returns:
  222. The graph module in-place sorted
  223. Warning:
  224. This topological sort is NOT stable, it will NOT preserve the original node order.
  225. If you need a stable topological sort, use stable_topological_sort instead.
  226. """
  227. # These operators are used for making runtime assertions before any
  228. # data-dependent operators occur. We want to prioritize sorting these to
  229. # ensure that these assertions appear before any data-dependent operations
  230. # in the graph.
  231. PRIORITIZED_OPS = [
  232. operator.add,
  233. operator.mul,
  234. operator.sub,
  235. operator.floordiv,
  236. operator.truediv,
  237. operator.mod,
  238. operator.le,
  239. operator.lt,
  240. operator.ge,
  241. operator.gt,
  242. operator.eq,
  243. operator.ne,
  244. torch.ops.aten.sym_constrain_range.default,
  245. torch.ops.aten.sym_constrain_range_for_size.default,
  246. torch.ops.aten._assert_async.msg,
  247. torch.ops.aten.scalar_tensor.default,
  248. torch.ops.aten._assert_scalar.default,
  249. ]
  250. indeg = dict.fromkeys(gm.graph.nodes, 0)
  251. new_graph = torch.fx.Graph()
  252. # Track how many unfulfilled dependencies each node has
  253. for node in gm.graph.nodes:
  254. for user in node.users:
  255. indeg[user] += 1
  256. queue: collections.deque = collections.deque()
  257. # Add all nodes with no dependencies to the queue
  258. for node in gm.graph.nodes:
  259. if indeg[node] == 0:
  260. queue.append(node)
  261. env: dict[torch.fx.Node, torch.fx.Node] = {}
  262. # Pop nodes from the queue, and add nodes that have had all their
  263. # dependencies fulfilled
  264. while len(queue) > 0:
  265. cur = queue.popleft()
  266. env[cur] = new_graph.node_copy(cur, lambda x: env[x])
  267. for user in cur.users:
  268. indeg[user] -= 1
  269. if indeg[user] == 0:
  270. if user.op == "call_function" and user.target in PRIORITIZED_OPS:
  271. queue.appendleft(user)
  272. else:
  273. queue.append(user)
  274. # If the new graph's size is not as large as the old one, then there must be
  275. # a cycle (i.e. some node's dependencies were not satisfied.)
  276. if len(new_graph.nodes) < len(gm.graph.nodes):
  277. raise RuntimeError(
  278. f"Input graph has cycles, unable to add {[node for node in indeg if indeg[node] != 0]}"
  279. )
  280. new_graph._codegen = gm.graph._codegen
  281. gm.graph = new_graph
  282. return gm
  283. @compatibility(is_backward_compatible=False)
  284. def stable_topological_sort(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
  285. """
  286. Replace the graph of the given GraphModule with one that contains the same nodes as the
  287. original, but in topologically sorted order while preserving the original node order
  288. as much as possible.
  289. This function performs a stable topological sort where nodes appear in an order that:
  290. 1. Respects data dependencies (topological ordering)
  291. 2. Preserves the original node order when there are no dependency constraints
  292. The algorithm uses Kahn's algorithm with a priority queue: nodes with all dependencies
  293. satisfied are added to a min-heap, ordered by their original position. This ensures
  294. we always process the earliest node in the original order among ready nodes.
  295. Arguments:
  296. gm: The graph module to topologically sort. It is modified in-place.
  297. Returns:
  298. The graph module in-place sorted
  299. """
  300. indeg = dict.fromkeys(gm.graph.nodes, 0)
  301. new_graph = torch.fx.Graph()
  302. # Build node to original index mapping
  303. node_to_id: dict[torch.fx.Node, int] = {
  304. node: idx for idx, node in enumerate(gm.graph.nodes)
  305. }
  306. # Track how many unfulfilled dependencies each node has
  307. for node in gm.graph.nodes:
  308. for user in node.users:
  309. indeg[user] += 1
  310. # Priority queue: (original_index, node)
  311. # Use min-heap to always process the node with smallest original index
  312. ready_queue: list[tuple[int, torch.fx.Node]] = []
  313. for node in gm.graph.nodes:
  314. if indeg[node] == 0:
  315. heapq.heappush(ready_queue, (node_to_id[node], node))
  316. env: dict[torch.fx.Node, torch.fx.Node] = {}
  317. # Process nodes
  318. while ready_queue:
  319. # Pop node with smallest original index
  320. _, cur = heapq.heappop(ready_queue)
  321. env[cur] = new_graph.node_copy(cur, lambda x: env[x])
  322. # Update in-degrees and add newly ready nodes
  323. for user in cur.users:
  324. indeg[user] -= 1
  325. if indeg[user] == 0:
  326. heapq.heappush(ready_queue, (node_to_id[user], user))
  327. # Check if all nodes were processed
  328. if len(new_graph.nodes) != len(gm.graph.nodes):
  329. raise AssertionError(
  330. f"Input graph has cycles, unable to add {[node for node in indeg if indeg[node] != 0]}"
  331. )
  332. new_graph._codegen = gm.graph._codegen
  333. gm.graph = new_graph
  334. return gm