split_utils.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520
  1. # mypy: allow-untyped-defs
  2. import copy
  3. from dataclasses import dataclass, field
  4. from typing import Optional, Union
  5. import torch.fx
  6. from torch.fx._compatibility import compatibility
  7. from torch.fx.graph import map_arg
  8. from torch.fx.passes.utils import HolderModule, lift_subgraph_as_module
  9. from .tools_common import CALLABLE_NODE_OPS, is_node_output_tensor, NodeList
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
    "getattr_recursive",
    "setattr_recursive",
    "Component",
    "split_by_tags",
    "move_non_tensor_nodes_on_boundary",
]
  17. @compatibility(is_backward_compatible=False)
  18. def getattr_recursive(obj, name):
  19. for layer in name.split("."):
  20. if isinstance(obj, torch.nn.ModuleList):
  21. if hasattr(obj, "_modules") and layer in obj._modules:
  22. obj = obj._modules[layer]
  23. else:
  24. return None
  25. elif hasattr(obj, layer):
  26. obj = getattr(obj, layer)
  27. else:
  28. return None
  29. return obj
  30. @compatibility(is_backward_compatible=False)
  31. def setattr_recursive(obj, attr, value):
  32. if "." not in attr:
  33. setattr(obj, attr, value)
  34. else:
  35. layer = attr.split(".")
  36. setattr_recursive(getattr(obj, layer[0]), ".".join(layer[1:]), value)
@compatibility(is_backward_compatible=False)
@dataclass
class Component:
    """
    A component serves as a container for a subgraph we want to create afterwards.

    ``split_by_tags`` creates one Component per tag and copies every node
    carrying that tag into ``graph``. It then lifts ``graph`` into a module
    stored in ``gm``.
    """

    # Graph receiving copies of the nodes that belong to this component.
    graph: torch.fx.Graph
    # Creation index; split_by_tags assigns the index of the tag in `tags`.
    # A node's component may not have a lower order than any of its
    # upstream components.
    order: int
    # Component name; split_by_tags uses the tag string. It is also the
    # submodule name in the split GraphModule.
    name: str

    # Stores the placeholder nodes in `graph`.
    input_placeholders: list[torch.fx.Node] = field(default_factory=list)
    # Store the nodes in original graph that are placeholder in `graph`.
    # Kept parallel to `input_placeholders`: the i-th entry here is the
    # original node that the i-th placeholder stands for.
    orig_inputs: list[torch.fx.Node] = field(default_factory=list)
    # Store the nodes in original graph that are outputs in `graph`, i.e.
    # nodes whose values are needed by the main (top-level) graph.
    orig_outputs: list[torch.fx.Node] = field(default_factory=list)
    # Mapping from get_attr node in original graph to get_attr node in `graph`.
    getattr_maps: dict[torch.fx.Node, torch.fx.Node] = field(default_factory=dict)
    # NOTE(review): not read or written by split_by_tags; presumably kept for
    # other users of Component -- confirm before relying on it.
    constructor_args: list[str] = field(default_factory=list)
    # Module lifted from `graph`; None until split_by_tags fills it in.
    gm: Optional[torch.fx.GraphModule] = None
@compatibility(is_backward_compatible=False)
def split_by_tags(
    gm: torch.fx.GraphModule,
    tags: list[str],
    return_fqn_mapping: bool = False,
    return_tuple: bool = False,
    GraphModuleCls: type[torch.fx.GraphModule] = torch.fx.GraphModule,
) -> Union[torch.fx.GraphModule, tuple[torch.fx.GraphModule, dict[str, str]]]:
    """
    Splits a GraphModule using tags on its graph nodes. We honor the order of
    tags. For example, with tags = ["a", "b", "c"], the function creates the
    submodules in the order "a", "b", "c". Each submodule is named after its tag.

    To set a tag, assign a ``tag`` attribute on a node, e.g.::

        for node in gm.graph.nodes:
            node.tag = "mytag"

    All nodes with the same tag are extracted and placed in their own
    submodule. For placeholder, output and get_attr nodes the tag is ignored:
    - placeholder and output nodes are created where needed;
    - get_attr nodes are copied into every submodule that uses them.
    Every other (callable) node must carry a tag, and that tag must be one of
    ``tags``. A node's tag may not come earlier in ``tags`` than the tag of any
    of its (non-placeholder, non-get_attr) inputs.

    Given the following module def::

        class SimpleModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = torch.nn.Linear(...)
                self.linear2 = torch.nn.Linear(...)
                self.linear3 = torch.nn.Linear(...)

            def forward(self, in1, in2):
                r1 = self.linear1(in1)
                r2 = self.linear2(in2)
                r3 = torch.cat([r1, r2])
                return self.linear3(r3)

    Tag the ``linear1`` node with "ro" and the ``linear2``, ``cat`` and
    ``linear3`` nodes with "main", then call ``split_by_tags(gm, ["ro", "main"])``.
    The result is roughly the following split.

    Submodule ``ro``::

        def forward(self, in1):
            linear1 = self.linear1(in1)
            return linear1

    Submodule ``main``::

        def forward(self, in2, linear1):
            linear2 = self.linear2(in2)
            cat_1 = torch.cat([linear1, linear2])
            linear3 = self.linear3(cat_1)
            return linear3

    Top-level split module::

        def forward(self, in1, in2):
            ro = self.ro(in1)
            main = self.main(in2, ro)
            return main

    Args:
        gm: The GraphModule to split. Its callable nodes must be tagged.
        tags: Tags in the order their submodules should be created.
        return_fqn_mapping: If True, also return the fqn mapping (see Returns).
        return_tuple: If True, every submodule returns a tuple, even when it
            has a single output.
        GraphModuleCls: Class used to construct the resulting top-level module.

    Returns:
        split_gm: torch fx graph after split
        orig_to_split_fqn_mapping: a map between the original fqn and the fqn
            after split for call_module and get_attr. Only returned when
            ``return_fqn_mapping`` is True.

    Raises:
        RuntimeError: if the graph has no output node or more than one.
        AssertionError: if a callable node has no ``tag`` attribute, or its
            tag's component has a lower order than an upstream component.
        KeyError: if a node's tag is not in ``tags``.
    """

    def flatten(x: torch.fx.node.Argument) -> NodeList:
        """
        Stores nodes in x to a list and returns the list.
        """
        r: NodeList = []
        map_arg(x, r.append)
        return r

    # Mapping from node in original module to node in created submodule.
    node_remapping: dict[torch.fx.Node, torch.fx.Node] = {}

    # Mapping from node in original module or created submodules to
    # corresponding component.
    node_to_component: dict[torch.fx.Node, Component] = {}

    # Mapping from tag to the corresponding component.
    tag_to_component: dict[str, Component] = {}

    # Stores all components.
    all_components: list[Component] = []

    # Stores nodes that will be used in main graph. Used as an
    # insertion-ordered set (values are always None).
    used_in_main: dict[torch.fx.Node, None] = {}

    # Main graph after split.
    main_g = torch.fx.Graph()

    # Mapping from node in original module to node in main graph after split.
    main_remapping: dict[torch.fx.Node, torch.fx.Node] = {}

    # Output node of original module.
    output_node: Optional[torch.fx.Node] = None

    # Create a component for each tag; we don't expect to create other components afterwards.
    for tag in tags:
        comp = Component(torch.fx.Graph(), len(all_components), f"{tag}")
        all_components.append(comp)
        tag_to_component[tag] = comp

    # Traverse the nodes in original graph and take care of them.
    for node in gm.graph.nodes:
        if node.op == "output":
            if output_node is not None:
                raise RuntimeError("Multiple output nodes in graph!")
            output_node = node
            continue

        # Placeholders in the original graph get copied to main graph.
        if node.op == "placeholder":
            main_remapping[node] = main_g.placeholder(node.name, type_expr=node.type)
            main_remapping[node].meta = copy.copy(node.meta)
            continue

        # Get_attr nodes are ignored because we are not tagging them.
        # Instead, we copy them directly into the submodules that use them afterwards.
        if node.op == "get_attr":
            continue

        # Now we process callable nodes, i.e. nodes with op of call_module,
        # call_function or call_method. Every callable node should be tagged.
        if not hasattr(node, "tag"):
            raise AssertionError(f"Node does not have tag: {node.format_node()}")

        upstream_components = [
            node_to_component[x]
            for x in flatten(node.args) + flatten(node.kwargs)
            if x.op not in {"placeholder", "get_attr"}
        ]

        comp = tag_to_component[node.tag]
        node_to_component[node] = comp

        # Max order of upstream components (0 when every input is a
        # placeholder or get_attr).
        mx = max((c.order for c in upstream_components), default=0)

        # Expect the component for `node` to have an order no lower than its upstream components.
        if comp.order < mx:
            raise AssertionError(
                f"Component {comp.name} order must be >= max of its upstream components, "
                f"order={comp.order} and max={mx}"
            )

        # Map an input of `node` to nodes in the component's graph.
        def remap_func(x):
            # If input is a get_attr node, copy it to current component's graph
            # (once per component) and return that copy.
            if x.op == "get_attr":
                if x not in comp.getattr_maps:
                    comp.getattr_maps[x] = comp.graph.get_attr(
                        x.target, type_expr=x.type
                    )
                    comp.getattr_maps[x].meta = copy.copy(x.meta)
                return comp.getattr_maps[x]

            # If input is not a placeholder, it should have been put into a component
            # already. If it's the current component then we return the corresponding
            # node in the component.
            if x.op != "placeholder" and node_to_component[x] == comp:
                return node_remapping[x]

            # If input is a placeholder or it's in other components, we want to make it
            # a placeholder in current component's graph. It is also recorded as
            # used in main, because the main graph must feed it to this component.
            if x not in comp.orig_inputs:
                comp.orig_inputs.append(x)
                placeholder = comp.graph.placeholder(x.name, type_expr=x.type)
                placeholder.meta = copy.copy(x.meta)
                comp.input_placeholders.append(placeholder)
                used_in_main[x] = None

            return comp.input_placeholders[comp.orig_inputs.index(x)]

        # Copy `node` into the component graph, translating each input via remap_func.
        n = comp.graph.node_copy(node, remap_func)
        n.tag = node.tag  # type: ignore[attr-defined]
        node_remapping[node] = n
        node_to_component[n] = comp

    if output_node is None:
        raise RuntimeError("Graph had no output node!")

    for x in flatten(output_node.args[0]):
        if x.op == "get_attr":
            # We don't need components mapping for nodes of type "get_attr"
            # that are consumed by the output. Only need to make sure we create
            # corresponding counterparts in the resulting graph. The target is
            # x.name; the attribute is installed on main_root under that name below.
            main_remapping[x] = main_g.get_attr(x.name, type_expr=x.type)
        else:
            # All component results consumed by the output node should be
            # marked as "used in main".
            used_in_main[x] = None

    # If a node is used in main graph then we mark it as an output in the component
    # it belongs to.
    for n in used_in_main:
        if n.op != "placeholder":
            node_to_component[n].orig_outputs.append(n)

    # Now we create a graphmodule for each component.
    orig_to_split_fqn_mapping: dict[str, str] = {}
    for comp in all_components:
        outs = tuple(map(node_remapping.__getitem__, comp.orig_outputs))

        if return_tuple:
            comp.graph.output(outs)
        else:
            # Take care of the args of FX output node. If there's a single
            # output then the output node args is like (output_single), else
            # if there're multiple outputs then the output node args is like
            # ((output_0, output_1, ...)).
            comp.graph.output(outs[0] if len(outs) == 1 else outs)

        comp.gm, comp_orig_to_split_fqn_mapping = lift_subgraph_as_module(
            gm, subgraph=comp.graph, comp_name=comp.name
        )
        orig_to_split_fqn_mapping.update(comp_orig_to_split_fqn_mapping)

        # Create a call_module node in main graph.
        main_node = main_g.call_module(
            comp.name,
            args=tuple(map(main_remapping.__getitem__, comp.orig_inputs)),
            kwargs=None,
        )

        if len(outs) == 1 and not return_tuple:
            main_remapping[comp.orig_outputs[0]] = main_node
        else:
            for i, o in enumerate(comp.orig_outputs):
                # Use Proxy to record getitem access.
                main_remapping[o] = torch.fx.Proxy(main_node)[i].node  # type: ignore[index]

    main_g.output(map_arg(output_node.args[0], main_remapping.__getitem__))
    main_root = HolderModule({comp.name: comp.gm for comp in all_components})
    # Reuse the original graph's code generator for the split top-level graph.
    main_g._codegen = gm.graph._codegen

    # If the output node consumes get_attr directly in the original graph,
    # then we need to make sure get_attr is copied to the new graph.
    for x in flatten(output_node.args[0]):
        if x.op == "get_attr":
            setattr(main_root, x.name, getattr_recursive(gm, x.target))  # type: ignore[arg-type]

    result_gm = GraphModuleCls(main_root, main_g)
    if return_fqn_mapping:
        return result_gm, orig_to_split_fqn_mapping

    return result_gm
  260. @compatibility(is_backward_compatible=False)
  261. def move_non_tensor_nodes_on_boundary(subgraphs) -> None:
  262. """
  263. Move non-tensor nodes on the boundary between subgraphs.
  264. For each subgraph:
  265. 1. Find nodes whose type is not tensor and any of its children is in another
  266. subgraph, put them in a queue for next step
  267. 2. Do a BFS on those nodes in the queue, and run a DFS for each node, let's say node X and it is in subgraph A:
  268. a. if it is in to_subgraph, return (continue DFS)
  269. b. if it is in from_subgraph, collect the nodes to nodes_to_move, and continue DFS
  270. c. otherwise, this means it cannot be moved
  271. d. also check if node X's parent should be put into the queue. (The queue may
  272. have duplicated nodes, just process the node once)
  273. Args:
  274. subgraphs: List of subgraphs containing nodes to be processed
  275. """
  276. # Create a mapping from node to subgraph for quick lookup
  277. node_to_subgraph: dict[torch.fx.Node, int] = {}
  278. for i, subgraph in enumerate(subgraphs):
  279. for node in subgraph.nodes:
  280. node_to_subgraph[node] = i
  281. def get_children_in_graph(node: torch.fx.Node) -> list[torch.fx.Node]:
  282. """Get children nodes that are in callable ops and in some subgraph"""
  283. return [
  284. user
  285. for user in node.users
  286. if user.op in CALLABLE_NODE_OPS and user in node_to_subgraph
  287. ]
  288. def get_parents_in_graph(node: torch.fx.Node) -> list[torch.fx.Node]:
  289. """Get parent nodes that are in callable ops and in some subgraph"""
  290. return [
  291. arg
  292. for arg in node.all_input_nodes
  293. if arg.op in CALLABLE_NODE_OPS and arg in node_to_subgraph
  294. ]
  295. def has_children_in_other_subgraph(
  296. node: torch.fx.Node, current_subgraph_idx: int
  297. ) -> bool:
  298. """
  299. Check if the node has any children in a subgraph different from current_subgraph_idx.
  300. This is the requirement used in both step 1 and step d.
  301. """
  302. children = get_children_in_graph(node)
  303. return any(
  304. node_to_subgraph[child] != current_subgraph_idx for child in children
  305. )
  306. def can_move_node_and_dependencies(
  307. node: torch.fx.Node, from_subgraph: int, to_subgraph: int
  308. ) -> tuple[bool, set[torch.fx.Node]]:
  309. """
  310. Check if node and its dependencies can be moved from from_subgraph to to_subgraph.
  311. Returns (can_move, nodes_to_move)
  312. For node X, do a DFS on its descendants, for each node:
  313. - if it is in to_subgraph, return (continue DFS)
  314. - if it is in from_subgraph, collect the nodes to nodes_to_move, and continue DFS
  315. - otherwise, this means it cannot be moved
  316. """
  317. nodes_to_move = set()
  318. visited = set()
  319. can_move = True
  320. def dfs(current_node):
  321. nonlocal can_move, nodes_to_move
  322. if current_node in visited:
  323. return
  324. visited.add(current_node)
  325. # Check current node's subgraph
  326. if current_node not in node_to_subgraph:
  327. return # Skip nodes not in any subgraph
  328. current_subgraph = node_to_subgraph[current_node]
  329. if current_subgraph == to_subgraph:
  330. # If it is in to_subgraph, just end DFS
  331. return
  332. elif current_subgraph == from_subgraph:
  333. # If it is in from_subgraph, collect it and continue DFS
  334. nodes_to_move.add(current_node)
  335. else:
  336. # Otherwise, this means it cannot be moved
  337. can_move = False
  338. return
  339. # Continue DFS on children
  340. children = get_children_in_graph(current_node)
  341. for child in children:
  342. if can_move: # Only continue if we haven't already failed
  343. dfs(child)
  344. # Start DFS from the original node
  345. dfs(node)
  346. return can_move, nodes_to_move
  347. # For each subgraph, find non-tensor nodes with children in other subgraphs
  348. for subgraph_idx, subgraph in enumerate(subgraphs):
  349. # non acc nodes cannot be moved to downstream acc graph, so skip
  350. if not subgraph.is_acc:
  351. continue
  352. # Step 1: Find non-tensor nodes with children in other subgraphs
  353. queue: list[torch.fx.Node] = []
  354. processed: set[torch.fx.Node] = set()
  355. for node in subgraph.nodes:
  356. # Check if node is non-tensor
  357. if is_node_output_tensor(node):
  358. continue
  359. # Check if node meets step 1 requirement: any children in another subgraph
  360. if has_children_in_other_subgraph(node, subgraph_idx):
  361. queue.append(node)
  362. # Step 2: BFS to move nodes that meet the criteria
  363. while queue:
  364. current_node = queue.pop(0)
  365. # Skip if already processed (queue may have duplicates)
  366. if current_node in processed:
  367. continue
  368. processed.add(current_node)
  369. # Skip if node is no longer in this subgraph (may have been moved)
  370. if (
  371. current_node not in node_to_subgraph
  372. or node_to_subgraph[current_node] != subgraph_idx
  373. ):
  374. continue
  375. children = get_children_in_graph(current_node)
  376. if len(children) == 0:
  377. raise AssertionError(
  378. "Only node that has children in other subgraph can be moved"
  379. )
  380. # Find target subgraph. The children should all be in the same subgraph except current subgraph
  381. target_subgraph_candidates = set()
  382. for child in children:
  383. child_subgraph = node_to_subgraph[child]
  384. if child_subgraph != subgraph_idx:
  385. target_subgraph_candidates.add(child_subgraph)
  386. # If multiple children live in different subgraphs, the node cannot be moved. User needs to find other ways to move it.
  387. if len(target_subgraph_candidates) != 1:
  388. print(
  389. f"Cannot move non-tensor node {current_node.name} on boundary because it has children in multiple subgraphs"
  390. )
  391. continue
  392. target_subgraph = target_subgraph_candidates.pop()
  393. # Check if we can move this node and its dependencies
  394. can_move, nodes_to_move = can_move_node_and_dependencies(
  395. current_node, subgraph_idx, target_subgraph
  396. )
  397. if can_move:
  398. # Move all nodes in nodes_to_move to target subgraph
  399. for node_to_move in nodes_to_move:
  400. # Remove from current subgraph
  401. subgraph.nodes.remove(node_to_move)
  402. # Add to target subgraph
  403. subgraphs[target_subgraph].nodes.append(node_to_move)
  404. # Update mapping
  405. node_to_subgraph[node_to_move] = target_subgraph
  406. print(
  407. f"In order move the non-tensor node {current_node.name} on boundary, "
  408. f"moved node {node_to_move.name} from {'acc' if subgraph.is_acc else 'gpu'}_{subgraph_idx} "
  409. f"to {'acc' if subgraphs[target_subgraph].is_acc else 'gpu'}_{target_subgraph}"
  410. )
  411. # Add parents to the queue if they're non-tensor and not already processed
  412. # and meet the requirement from step 1 (any children in another subgraph)
  413. parents = get_parents_in_graph(current_node)
  414. for parent in parents:
  415. if (
  416. not is_node_output_tensor(parent)
  417. and parent not in processed
  418. and parent in node_to_subgraph
  419. and node_to_subgraph[parent] == subgraph_idx
  420. ):
  421. # Check if parent meets step 1 requirement: any children in another subgraph
  422. if not has_children_in_other_subgraph(parent, subgraph_idx):
  423. raise AssertionError(
  424. f"Parent {parent.name} should have children in another subgraph"
  425. )
  426. queue.append(parent)