| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520 |
- # mypy: allow-untyped-defs
- import copy
- from dataclasses import dataclass, field
- from typing import Optional, Union
- import torch.fx
- from torch.fx._compatibility import compatibility
- from torch.fx.graph import map_arg
- from torch.fx.passes.utils import HolderModule, lift_subgraph_as_module
- from .tools_common import CALLABLE_NODE_OPS, is_node_output_tensor, NodeList
- __all__ = [
- "getattr_recursive",
- "setattr_recursive",
- "Component",
- "split_by_tags",
- "move_non_tensor_nodes_on_boundary",
- ]
@compatibility(is_backward_compatible=False)
def getattr_recursive(obj, name):
    """
    Resolve a dotted attribute path such as ``"a.b.0.c"`` starting at ``obj``.

    Each path segment is looked up on the object reached so far. When that
    object is a ``torch.nn.ModuleList`` the segment is looked up as a key of
    its ``_modules`` dict instead of as a regular attribute.

    Returns:
        The object found at the end of the path, or ``None`` as soon as any
        segment cannot be resolved.
    """
    current = obj
    for part in name.split("."):
        if isinstance(current, torch.nn.ModuleList):
            # ModuleList children are resolved through `_modules`.
            if not hasattr(current, "_modules") or part not in current._modules:
                return None
            current = current._modules[part]
            continue
        if not hasattr(current, part):
            return None
        current = getattr(current, part)
    return current
@compatibility(is_backward_compatible=False)
def setattr_recursive(obj, attr, value):
    """
    Assign ``value`` to the attribute addressed by the dotted path ``attr``.

    Every segment except the last is resolved with ``getattr`` starting from
    ``obj``; the last segment is then set on the object reached. Intermediate
    objects must already exist (a missing one raises ``AttributeError`` from
    ``getattr``).
    """
    *parents, leaf = attr.split(".")
    target = obj
    for segment in parents:
        target = getattr(target, segment)
    setattr(target, leaf, value)
@compatibility(is_backward_compatible=False)
@dataclass
class Component:
    """
    A component serves as a container for a subgraph we want to create afterwards.

    The three leading fields are required and positional; every other field
    starts out empty and is filled in while the original graph is being split.
    """

    # Graph that accumulates the nodes belonging to this component.
    graph: torch.fx.Graph
    # Position of the component among all components; a component may only
    # consume values produced by components with a lower or equal order.
    order: int
    # Name of the component (also used as the submodule name after the split).
    name: str

    # Stores the placeholder nodes in `graph`.
    input_placeholders: list[torch.fx.Node] = field(default_factory=list)

    # Store the nodes in original graph that are placeholder in `graph`.
    # Kept index-aligned with `input_placeholders`.
    orig_inputs: list[torch.fx.Node] = field(default_factory=list)

    # Store the nodes in original graph that are outputs in `graph`.
    orig_outputs: list[torch.fx.Node] = field(default_factory=list)

    # Mapping from get_attr node in original graph to get_attr node in `graph`.
    getattr_maps: dict[torch.fx.Node, torch.fx.Node] = field(default_factory=dict)

    # NOTE(review): not read or written anywhere in this part of the file;
    # purpose should be confirmed against other users of Component.
    constructor_args: list[str] = field(default_factory=list)

    # GraphModule built from `graph` once the component is complete.
    gm: Optional[torch.fx.GraphModule] = None
- @compatibility(is_backward_compatible=False)
- def split_by_tags(
- gm: torch.fx.GraphModule,
- tags: list[str],
- return_fqn_mapping: bool = False,
- return_tuple: bool = False,
- GraphModuleCls: type[torch.fx.GraphModule] = torch.fx.GraphModule,
- ) -> Union[torch.fx.GraphModule, tuple[torch.fx.GraphModule, dict[str, str]]]:
- """
- Splits a GraphModule using tags on its graph nodes. We honor the order of
- tags. For example, we have tags = ["a", "b", "c"], the function will create
- the initial submodules in the order of "a", "b", "c".
- To set a tag:
- gm.graph.nodes[idx].tag = "mytag"
- This will result in all nodes with the same tag being extracted and placed in their
- own submodule. For placeholder, output and get_attr node, the tag is ignored. placeholder
- and output nodes are created when needed while get_attr nodes get copied to submodules
- where they are used.
- Given the following module def:
- class SimpleModule(torch.nn.Module):
- def __init__(self) -> None:
- super().__init__()
- self.linear1 = torch.nn.Linear(...)
- self.linear2 = torch.nn.Linear(...)
- self.linear3 = torch.nn.Linear(...)
- def forward(self, in1, in2):
- r1 = self.linear1(in1)
- r2 = self.linear2(in2)
- r3 = torch.cat([r1, r2])
- return self.linear3(r3)
- Marking the node corresponding to in1 with the tag sc.REQUEST_ONLY.lower() results in the following split:
- ro:
- def forward(self, in1):
- self = self.root
- linear1 = self.linear1(in1)
- return linear1
- main:
- def forward(self, in2, linear1):
- self = self.root
- linear2 = self.linear2(in2)
- cat_1 = torch.cat([linear1, linear2])
- linear3 = self.linear3(cat_1)
- return linear3
- main:
- def forward(self, in1, in2):
- self = self.root
- ro_0 = self.ro_0(in1)
- main_1 = self.main_1(in2, ro_0)
- return main_1
- Returns:
- split_gm: torch fx graph after split
- orig_to_split_fqn_mapping: a map between the original fqn and the fqn
- after split for call_module and get_attr.
- """
- def flatten(x: torch.fx.node.Argument) -> NodeList:
- """
- Stores nodes in x to a list and returns the list.
- """
- r: NodeList = []
- map_arg(x, r.append)
- return r
- # Mapping from node in original module to node in created submodule.
- node_remapping: dict[torch.fx.Node, torch.fx.Node] = {}
- # Mapping from node in original module or created submodules to
- # corresponding component.
- node_to_component: dict[torch.fx.Node, Component] = {}
- # Mapping from tag to the corresponding component.
- tag_to_component: dict[str, Component] = {}
- # Stores all components.
- all_components: list[Component] = []
- # Stores nodes that will be used in main graph.
- used_in_main: dict[torch.fx.Node, None] = {}
- # Main graph after split.
- main_g = torch.fx.Graph()
- # Mapping from node in original module to node in main graph after split.
- main_remapping: dict[torch.fx.Node, torch.fx.Node] = {}
- # Output node of original module.
- output_node: Optional[torch.fx.Node] = None
- # Create a component for each tag, we don't expect to create other components afterwards.
- for tag in tags:
- comp = Component(torch.fx.Graph(), len(all_components), f"{tag}")
- all_components.append(comp)
- tag_to_component[tag] = comp
- # Traverse the nodes in original graph and take care of them.
- for node in gm.graph.nodes:
- if node.op == "output":
- if output_node is not None:
- raise RuntimeError("Multiple output nodes in graph!")
- output_node = node
- continue
- # Placeholders in the original graph get copied to main graph.
- if node.op == "placeholder":
- main_remapping[node] = main_g.placeholder(node.name, type_expr=node.type)
- main_remapping[node].meta = copy.copy(node.meta)
- continue
- # Get_attr nodes are ignored because we are not tagging them.
- # Instead, we copy them directly to the submodules use them afterwards.
- if node.op == "get_attr":
- continue
- # Now we process callable nodes which are nodes with op of call_module,
- # call_function or call_method. Every callable nodes should be tagged.
- if not hasattr(node, "tag"):
- raise AssertionError(f"Node does not have tag: {node.format_node()}")
- upstream_components = [
- node_to_component[x]
- for x in flatten(node.args) + flatten(node.kwargs)
- if x.op not in {"placeholder", "get_attr"}
- ]
- comp = tag_to_component[node.tag]
- node_to_component[node] = comp
- # Max order of upperstream components.
- mx = max((c.order for c in upstream_components), default=0)
- # Expect the component for `node` has higher order then its upstream components.
- if comp.order < mx:
- raise AssertionError(
- f"Component {comp.name} order must be >= max of its upstream components, "
- f"order={comp.order} and max={mx}"
- )
- # Map a input of `node` to nodes in the component's graph.
- def remap_func(x):
- # If input is a get_attr node, copy it to current component's graph.
- # Returns the get_attr node in current component's graph.
- if x.op == "get_attr":
- if x not in comp.getattr_maps:
- comp.getattr_maps[x] = comp.graph.get_attr(
- x.target, type_expr=x.type
- )
- comp.getattr_maps[x].meta = copy.copy(x.meta)
- return comp.getattr_maps[x]
- # If input is not a placeholder, it should have been put into a component
- # already. If it's the current component then we return the corresponding
- # node in the component.
- if x.op != "placeholder" and node_to_component[x] == comp:
- return node_remapping[x]
- # If input is a placeholder or it's in other components, we want to make it
- # as a placeholder in current component's graph.
- if x not in comp.orig_inputs:
- comp.orig_inputs.append(x)
- placeholder = comp.graph.placeholder(x.name, type_expr=x.type)
- placeholder.meta = copy.copy(x.meta)
- comp.input_placeholders.append(placeholder)
- used_in_main[x] = None
- return comp.input_placeholders[comp.orig_inputs.index(x)]
- n = comp.graph.node_copy(node, remap_func)
- n.tag = node.tag # type: ignore[attr-defined]
- node_remapping[node] = n
- node_to_component[n] = comp
- if output_node is None:
- raise RuntimeError("Graph had no output node!")
- for x in flatten(output_node.args[0]):
- if x.op == "get_attr":
- # We don't need components mapping for nodes of type "get_attr"
- # that are consumed by the output. Only need to make sure we create
- # corresponding counterparts in the resulting graph.
- main_remapping[x] = main_g.get_attr(x.name, type_expr=x.type)
- else:
- # All component results consumed by the output node should be
- # marked as "used in main".
- used_in_main[x] = None
- # If a node is used in main graph then we mark it as an output in the component
- # it belongs to.
- for n in used_in_main:
- if n.op != "placeholder":
- node_to_component[n].orig_outputs.append(n)
- # Now we create a graphmodule for each component.
- orig_to_split_fqn_mapping: dict[str, str] = {}
- for comp in all_components:
- outs = tuple(map(node_remapping.__getitem__, comp.orig_outputs))
- if return_tuple:
- comp.graph.output(outs)
- else:
- # Take care of the args of FX output node. If there's a single
- # output then the output node args is like (output_single), else
- # if there're multiple outputs then the output node args is like
- # ((output_0, output_1, ...)).
- comp.graph.output(outs[0] if len(outs) == 1 else outs)
- comp.gm, comp_orig_to_split_fqn_mapping = lift_subgraph_as_module(
- gm, subgraph=comp.graph, comp_name=comp.name
- )
- orig_to_split_fqn_mapping.update(comp_orig_to_split_fqn_mapping)
- # Create a call_module node in main graph.
- main_node = main_g.call_module(
- comp.name,
- args=tuple(map(main_remapping.__getitem__, comp.orig_inputs)),
- kwargs=None,
- )
- if len(outs) == 1 and not return_tuple:
- main_remapping[comp.orig_outputs[0]] = main_node
- else:
- for i, o in enumerate(comp.orig_outputs):
- # Use Proxy to record getitem access.
- main_remapping[o] = torch.fx.Proxy(main_node)[i].node # type: ignore[index]
- main_g.output(map_arg(output_node.args[0], main_remapping.__getitem__))
- main_root = HolderModule({comp.name: comp.gm for comp in all_components})
- main_g._codegen = gm.graph._codegen
- # If the output nodes consumes get_attr directly in the original graph,
- # then we need to make sure get_attr is copied to the new graph.
- for x in flatten(output_node.args[0]):
- if x.op == "get_attr":
- setattr(main_root, x.name, getattr_recursive(gm, x.target)) # type: ignore[arg-type]
- result_gm = GraphModuleCls(main_root, main_g)
- if return_fqn_mapping:
- return result_gm, orig_to_split_fqn_mapping
- return result_gm
@compatibility(is_backward_compatible=False)
def move_non_tensor_nodes_on_boundary(subgraphs) -> None:
    """
    Move non-tensor nodes on the boundary between subgraphs.

    For each subgraph:
    1. Find nodes whose type is not tensor and any of its children is in another
       subgraph, put them in a queue for next step
    2. Do a BFS on those nodes in the queue, and run a DFS for each node, let's say node X and it is in subgraph A:
       a. if it is in to_subgraph, return (continue DFS)
       b. if it is in from_subgraph, collect the nodes to nodes_to_move, and continue DFS
       c. otherwise, this means it cannot be moved
       d. also check if node X's parent should be put into the queue. (The queue may
          have duplicated nodes, just process the node once)

    The subgraphs are modified in place: moved nodes are removed from the
    ``nodes`` list of their subgraph and appended to the ``nodes`` list of the
    target subgraph, in the order they had in the source subgraph. A line is
    printed for every moved node and for every node that cannot be moved
    because its children do not all live in a single other subgraph.

    Args:
        subgraphs: List of subgraphs containing nodes to be processed. Each
            element must expose a mutable ``nodes`` list of ``torch.fx.Node``
            and an ``is_acc`` flag; only subgraphs with a truthy ``is_acc`` are
            used as a source of moves.

    Raises:
        AssertionError: If a queued node has no callable children in any
            subgraph, or a parent of a moved node unexpectedly has no child in
            another subgraph.
    """
    from collections import deque

    # Create a mapping from node to subgraph for quick lookup
    node_to_subgraph: dict[torch.fx.Node, int] = {}
    for i, subgraph in enumerate(subgraphs):
        for node in subgraph.nodes:
            node_to_subgraph[node] = i

    def get_children_in_graph(node: torch.fx.Node) -> list[torch.fx.Node]:
        """Get children nodes that are in callable ops and in some subgraph"""
        return [
            user
            for user in node.users
            if user.op in CALLABLE_NODE_OPS and user in node_to_subgraph
        ]

    def get_parents_in_graph(node: torch.fx.Node) -> list[torch.fx.Node]:
        """Get parent nodes that are in callable ops and in some subgraph"""
        return [
            arg
            for arg in node.all_input_nodes
            if arg.op in CALLABLE_NODE_OPS and arg in node_to_subgraph
        ]

    def has_children_in_other_subgraph(
        node: torch.fx.Node, current_subgraph_idx: int
    ) -> bool:
        """
        Check if the node has any children in a subgraph different from current_subgraph_idx.
        This is the requirement used in both step 1 and step d.
        """
        return any(
            node_to_subgraph[child] != current_subgraph_idx
            for child in get_children_in_graph(node)
        )

    def can_move_node_and_dependencies(
        node: torch.fx.Node, from_subgraph: int, to_subgraph: int
    ) -> tuple[bool, set[torch.fx.Node]]:
        """
        Check if node and its dependencies can be moved from from_subgraph to to_subgraph.
        Returns (can_move, nodes_to_move)

        For node X, do a DFS on its descendants, for each node:
        - if it is in to_subgraph, return (continue DFS)
        - if it is in from_subgraph, collect the nodes to nodes_to_move, and continue DFS
        - otherwise, this means it cannot be moved

        The traversal uses an explicit stack so that long chains of nodes do
        not run into the interpreter's recursion limit.
        """
        nodes_to_move: set[torch.fx.Node] = set()
        visited: set[torch.fx.Node] = set()
        stack = [node]

        while stack:
            current = stack.pop()
            if current in visited:
                continue
            visited.add(current)

            owner = node_to_subgraph.get(current)
            # Nodes outside every subgraph are skipped; nodes already in the
            # target subgraph terminate this branch of the search.
            if owner is None or owner == to_subgraph:
                continue
            if owner != from_subgraph:
                # A descendant lives in a third subgraph: nothing can be moved.
                return False, nodes_to_move

            nodes_to_move.add(current)
            stack.extend(get_children_in_graph(current))

        return True, nodes_to_move

    # For each subgraph, find non-tensor nodes with children in other subgraphs
    for subgraph_idx, subgraph in enumerate(subgraphs):
        # non acc nodes cannot be moved to downstream acc graph, so skip
        if not subgraph.is_acc:
            continue

        # Step 1: Find non-tensor nodes with children in other subgraphs
        queue: deque[torch.fx.Node] = deque()
        processed: set[torch.fx.Node] = set()

        for node in subgraph.nodes:
            # Check if node is non-tensor
            if is_node_output_tensor(node):
                continue

            # Check if node meets step 1 requirement: any children in another subgraph
            if has_children_in_other_subgraph(node, subgraph_idx):
                queue.append(node)

        # Step 2: BFS to move nodes that meet the criteria
        while queue:
            current_node = queue.popleft()

            # Skip if already processed (queue may have duplicates)
            if current_node in processed:
                continue
            processed.add(current_node)

            # Skip if node is no longer in this subgraph (may have been moved)
            if node_to_subgraph.get(current_node) != subgraph_idx:
                continue

            children = get_children_in_graph(current_node)
            if not children:
                raise AssertionError(
                    "Only node that has children in other subgraph can be moved"
                )

            # Find target subgraph. The children should all be in the same subgraph except current subgraph
            target_subgraph_candidates = {
                node_to_subgraph[child]
                for child in children
                if node_to_subgraph[child] != subgraph_idx
            }

            # If multiple children live in different subgraphs, the node cannot be moved. User needs to find other ways to move it.
            if len(target_subgraph_candidates) != 1:
                print(
                    f"Cannot move non-tensor node {current_node.name} on boundary because it has children in multiple subgraphs"
                )
                continue

            target_subgraph = target_subgraph_candidates.pop()

            # Check if we can move this node and its dependencies
            can_move, nodes_to_move = can_move_node_and_dependencies(
                current_node, subgraph_idx, target_subgraph
            )
            if not can_move:
                continue

            # Move all nodes in nodes_to_move to target subgraph. Iterate in
            # the order of the source subgraph (snapshotted first, since the
            # list is mutated below) so that the result and the log do not
            # depend on set iteration order.
            ordered_moves = [n for n in subgraph.nodes if n in nodes_to_move]
            for node_to_move in ordered_moves:
                # Remove from current subgraph
                subgraph.nodes.remove(node_to_move)
                # Add to target subgraph
                subgraphs[target_subgraph].nodes.append(node_to_move)
                # Update mapping
                node_to_subgraph[node_to_move] = target_subgraph
                print(
                    f"In order move the non-tensor node {current_node.name} on boundary, "
                    f"moved node {node_to_move.name} from {'acc' if subgraph.is_acc else 'gpu'}_{subgraph_idx} "
                    f"to {'acc' if subgraphs[target_subgraph].is_acc else 'gpu'}_{target_subgraph}"
                )

            # Add parents to the queue if they're non-tensor and not already processed
            # and meet the requirement from step 1 (any children in another subgraph)
            for parent in get_parents_in_graph(current_node):
                if (
                    not is_node_output_tensor(parent)
                    and parent not in processed
                    and parent in node_to_subgraph
                    and node_to_subgraph[parent] == subgraph_idx
                ):
                    # Check if parent meets step 1 requirement: any children in another subgraph
                    if not has_children_in_other_subgraph(parent, subgraph_idx):
                        raise AssertionError(
                            f"Parent {parent.name} should have children in another subgraph"
                        )
                    queue.append(parent)
|