| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139 |
- # mypy: allow-untyped-defs
- import argparse
- import copy
- import json
- import os
- from collections import defaultdict
- from collections.abc import Iterable, Sequence
- from dataclasses import dataclass
- from typing import Any, Literal, NamedTuple, Optional
- import torch
- from torch._logging import trace_structured
- from torch.fx._compatibility import compatibility
- from torch.fx.node import map_arg
- from torch.fx.passes.graph_manipulation import get_size_of_node
- from .graph_drawer import FxGraphDrawer
- from .operator_support import get_node_target, OperatorSupportBase
- from .shape_prop import ShapeProp
- from .split_utils import move_non_tensor_nodes_on_boundary, split_by_tags
- from .tools_common import (
- CALLABLE_NODE_OPS,
- FxNetAccFusionsFinder,
- is_node_output_tensor,
- NodeList,
- NodeSet,
- Tensors,
- )
__all__ = [
    "FxNetAccNodesFinder",
    "FxNetSplitterInternalError",
    "Subgraph",
    "SplitResult",
    "generate_inputs_for_submodules",
    "NodeEvent",
    "NodeEventTracker",
]

# Default values used by the splitter settings.
DEFAULT_MIN_ACC_MODULE_SIZE = 1
DEFAULT_SKIP_FUSION = False
DEFAULT_ALLOW_NON_TENSOR = False

# ENV var and constants for node tracker
TRACKER_DUMP_PATH = "_fx_net_tracker"
NODES_SUFFIX = "_nodes.txt"
ALL_SUFFIX = "_all.txt"
ENV_FX_NET_ACC_SPLITTER_TRACKER_MODE = "FX_NET_ACC_SPLITTER_TRACKER_MODE"
ENV_FX_NET_ACC_SPLITTER_TRACKER_DUMP_PATH = "FX_NET_ACC_SPLITTER_TRACKER_DUMP_PATH"
ENV_FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES = (
    "FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES"
)

# Path prefix for local tracker dumps; the files written are
# DUMP_PREFIX + ALL_SUFFIX and DUMP_PREFIX + NODES_SUFFIX.
DUMP_PREFIX = os.environ.get(
    ENV_FX_NET_ACC_SPLITTER_TRACKER_DUMP_PATH, TRACKER_DUMP_PATH
)

# Different modes of the event tracker for local debugging:
#   "0": No local dumps. Information is available by setting breakpoints and
#        inspecting it in pdb.
#   "1": Dump all events to DUMP_PREFIX_all.txt.
#   "2": In addition to the events dump, recursively track the nodes listed
#        (comma-separated) in ENV_FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES and
#        dump them to DUMP_PREFIX_nodes.txt.
#   "3": In addition to the events dump, recursively track every node with
#        more than one event and dump them to DUMP_PREFIX_nodes.txt.
# Independently of the local dumps, the tracker always emits its events via
# trace_structured.
# NOTE(review): the value is not validated here; it is passed through int()
# where the nodes finder creates its tracker, so a non-numeric setting fails
# there.
# pyrefly: ignore [bad-assignment]
TRACKER_MODE: Literal["0", "1", "2", "3"] = os.environ.get(
    ENV_FX_NET_ACC_SPLITTER_TRACKER_MODE, "0"
)  # type: ignore[assignment]
class _SplitterSettingBase:
    """
    Settings for the splitter, optionally overridden from the command line.

    The constructor parses ``sys.argv`` with ``argparse.parse_known_args``
    (unrecognized arguments are ignored). When a flag is given on the command
    line it takes precedence over the matching constructor argument;
    otherwise the constructor argument is used.

    Attributes:
        min_acc_module_size: minimum size of an accelerator subgraph.
        skip_fusion: if True, no fusion groups are built.
        allow_non_tensor: if True, non-tensor data flow between backends is
            not prevented during splitting.
        max_acc_splits: maximum number of split subgraphs (-1 by default).
        move_non_tensor_nodes_on_boundary: move non-tensor nodes sitting on
            acc/gpu boundaries (see the flag help text).
    """

    def __init__(
        self,
        min_acc_module_size=DEFAULT_MIN_ACC_MODULE_SIZE,
        skip_fusion=DEFAULT_SKIP_FUSION,
        allow_non_tensor=DEFAULT_ALLOW_NON_TENSOR,
        max_acc_splits: int = -1,
        move_non_tensor_nodes_on_boundary: bool = False,
    ):
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "--min-acc-module-size",
            "--min_acc_module_size",
            required=False,
            type=int,
            help="Minimum size limit of an accelerator subgraph.",
        )
        parser.add_argument(
            "--max-acc-splits",
            "--max_acc_splits",
            required=False,
            type=int,
            help="Enforce a maximum number of split subgraphs.",
        )
        parser.add_argument(
            "--skip-fusion",
            "--skip_fusion",
            default=False,
            action="store_true",
            help="If true then no fusion groups. Fusion group is used to "
            "enforce no non-tensor data flow between submodules. If we don't "
            "have this constrain, setting this to false is recommended as it "
            "can reduce overhead.",
        )
        parser.add_argument(
            "--allow-non-tensor",
            "--allow_non_tensor",
            default=False,
            action="store_true",
            help="For some backends non-tensor data flow between cpu and them "
            "are not allowed. Therefore, if a node supported by accelerator but "
            "it has non-tensor inputs or outputs to a cpu node we would want to "
            "consider it as a cpu node during splitting. However, for some backends "
            "we might not care about non-tensor data flow and we can set this option "
            "to true to disable the functionality that prevent non-tensor data flow.",
        )
        parser.add_argument(
            "--move-non-tensor-nodes-on-boundary",
            "--move_non_tensor_nodes_on_boundary",
            required=False,
            action="store_true",
            help="AOTI does not support non-tensor nodes on acc->acc, acc->gpu and gpu->acc boundary. "
            "For non-tensor nodes on acc->acc boundary and acc->gpu, we move the nodes from upstream to downstream. "
            "For non-tensor nodes on gpu->acc boundary, it is handled by the pre-split process. "
            "(by method reduce_acc_nodes_non_tensor_input). ",
        )
        args, _unknown = parser.parse_known_args()

        # Integer flags: compare against None so an explicit 0 on the command
        # line is honored instead of silently falling back to the default.
        self.min_acc_module_size: int = (
            args.min_acc_module_size
            if args.min_acc_module_size is not None
            else min_acc_module_size
        )
        # The --max-acc-splits flag is declared above, so apply it the same way
        # as --min-acc-module-size rather than ignoring it.
        self.max_acc_splits: int = (
            args.max_acc_splits
            if args.max_acc_splits is not None
            else max_acc_splits
        )
        # store_true flags can only switch an option on; when absent (False)
        # the constructor argument decides.
        self.skip_fusion: bool = args.skip_fusion or skip_fusion
        self.allow_non_tensor: bool = args.allow_non_tensor or allow_non_tensor
        self.move_non_tensor_nodes_on_boundary: bool = (
            args.move_non_tensor_nodes_on_boundary
            or move_non_tensor_nodes_on_boundary
        )
@compatibility(is_backward_compatible=False)
class NodeEvent:
    """
    A single event recorded against a node while the graph is being split.

    Attributes:
        source: the node the event is about.
        desc: human-readable description, formatted as
            ``<event_type>|<explanation>``.
        dep: optional node that triggered the event (``None`` when the
            source node itself is the cause).
    """

    def __init__(
        self, source: torch.fx.Node, desc: str, dep: Optional[torch.fx.Node] = None
    ):
        self.source, self.desc, self.dep = source, desc, dep

    def to_str(self):
        """
        Render the event as ``"<source name>: <desc> <dep name>"``.

        ``#`` stands in for the dependency when the event has none.
        """
        if not self.dep:
            cause = "#"
        else:
            cause = self.dep.name
        return f"{self.source.name}: {self.desc} {cause}"
@compatibility(is_backward_compatible=False)
class NodeEventTracker:
    """
    Tracks node events during the splitter execution.

    Events are kept in insertion order in ``events``. ``node_events`` indexes
    them by node name. ``dump`` always emits them through
    ``trace_structured`` and, depending on ``tracker_mode``, also writes
    local text files.
    """

    def __init__(self, tracker_mode, dump_prefix):
        # Compared numerically in ``dump`` (0: no local files, >=1: all
        # events, 2/3: additionally a per-node recursive dump).
        self.tracker_mode = tracker_mode
        # Path prefix for the local dump files.
        self.dump_prefix = dump_prefix
        # list of events
        self.events = []
        # dict from node name to list of indices into ``self.events``
        self.node_events = {}
        # Default sink used when no writer is passed explicitly.
        self.writer = print

    def add(self, node: torch.fx.Node, desc: str, dep: Optional[torch.fx.Node] = None):
        """
        Add a new event to the tracker.
        """
        self.events.append(NodeEvent(node, desc, dep))
        self.node_events.setdefault(node.name, []).append(len(self.events) - 1)

    def print_node(self, node_name, recursive=False, tab="", writer=None):
        """
        Print a node and its events.

        @param recursive: if True, also print (indented) the events of the
            nodes that caused events on this node.
        @param tab: Indentation prefix for this level.
        @param writer: function to write a line. If None, use ``self.writer``.

        NOTE(review): there is no cycle guard; this assumes the ``dep`` links
        between events never form a cycle.
        """
        if not writer:
            writer = self.writer
        for idx in self.node_events.get(node_name, []):
            event = self.events[idx]
            writer(tab + event.to_str())
            if recursive and event.dep is not None:
                self.print_node(
                    event.dep.name, recursive=True, tab="| " + tab, writer=writer
                )

    def to_dict(self):
        """
        Create a dict mapping each node name to its rendered events.
        """
        ret: dict[str, list[str]] = {
            name: [self.events[idx].to_str() for idx in indices]
            for name, indices in self.node_events.items()
        }
        return ret

    def print_all(self, writer=None):
        """
        Print all nodes in a list.

        @param writer: function to write a line. If None, use ``self.writer``.
        """
        if not writer:
            writer = self.writer
        for name in self.node_events:
            writer(f"Node: {name}:")
            self.print_node(name, recursive=False, tab="  ", writer=writer)

    def dump(self):
        """
        Emit the tracked events; call once at the end of the finder run.

        - Always: structured-trace artifact ``fx_net_acc_splitter_finder_events``.
        - Mode >= 1: every event to ``dump_prefix + ALL_SUFFIX``.
        - Mode 2: recursive dump of the nodes named (comma-separated) in
          ``ENV_FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES``.
        - Mode 3: recursive dump of every node with more than one event.
        Modes 2 and 3 write to ``dump_prefix + NODES_SUFFIX``.
        """
        # dump via trace_structured
        trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "fx_net_acc_splitter_finder_events",
                "encoding": "json",
            },
            payload_fn=lambda: json.dumps(self.to_dict()),
        )

        def writeln(f):
            # Factory: returns a line writer bound to file ``f``.
            def fn(x):
                return f.write(x + "\n")

            return fn

        # Mode 0: no local dump
        # Mode >=1: Dump all events to file
        if self.tracker_mode >= 1:
            with open(self.dump_prefix + ALL_SUFFIX, "w") as f:
                self.print_all(writeln(f))

        # Mode 2: Dump specific nodes in recursive manner.
        # Mode 3: Dump all nodes with more than 1 event in recursive manner.
        if self.tracker_mode in (2, 3):
            if self.tracker_mode == 2:
                raw = os.environ.get(ENV_FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES, "")
                # Drop blanks so an unset/empty variable or "a, b" spacing does
                # not produce bogus node names.
                nodes = [name.strip() for name in raw.split(",") if name.strip()]
            else:
                nodes = [
                    name for name, events in self.node_events.items() if len(events) > 1
                ]
            with open(self.dump_prefix + NODES_SUFFIX, "w") as f:
                # Bind the writer once: calling ``writeln`` directly with a
                # string would only build a closure and write nothing.
                write = writeln(f)
                for node_name in nodes:
                    write(f"===== Tracking node {node_name} =====")
                    self.print_node(node_name, recursive=True, tab="|-", writer=write)
                    write(f"===== End of tracking node {node_name} =====")
@compatibility(is_backward_compatible=False)
class FxNetAccNodesFinder:
    """
    Finds a set of nodes that can be supported on ACC, excluding nodes that have non-tensor
    input/output to cpu nodes to prevent non-tensor data flow between backends and cpu.

    I.e. if we have a chain:

    ACC_NODE_1 -> ACC_NODE_2 -> ACC_NODE_3 -> CPU_NODE_1

    where every ACC node produces non-tensor output, then they all should be treated as CPU nodes.

    This behavior can be turned off by passing allow_non_tensor=True.

    Every decision is recorded in ``self.tracker`` and dumped at the end of
    ``__call__``.
    """

    def __init__(
        self,
        module: torch.fx.GraphModule,
        operator_support: OperatorSupportBase,
        allow_non_tensor: bool,
    ):
        self.module = module
        self.operator_support = operator_support
        self.allow_non_tensor = allow_non_tensor
        self.acc_nodes: NodeSet = set()
        # NOTE(review): the tracker lives for the finder's lifetime, so calling
        # the finder more than once accumulates events from every run.
        self.tracker = NodeEventTracker(int(TRACKER_MODE), DUMP_PREFIX)

    def reduce_acc_nodes_non_tensor_input_helper(self, cpu_worklist: NodeList):
        """
        Transitively excludes nodes from ACC supported set.

        For every node in the worklist (processed breadth-first):
        - removes its downstream ACC nodes from ACC supported set,
        - if any downstream ACC node produces non-tensor output,
          then it gets added into the worklist.

        ``cpu_worklist`` is left empty on return.
        """
        from collections import deque

        # deque gives O(1) pops from the front; list.pop(0) is O(n) per pop.
        worklist = deque(cpu_worklist)
        # Callers have always seen their list drained; keep that side effect.
        cpu_worklist.clear()

        while worklist:
            node = worklist.popleft()
            for user in node.users:
                if user not in self.acc_nodes:
                    continue
                self.acc_nodes.remove(user)
                self.tracker.add(user, "acc_del|user_of_new_cpu_node", node)
                if not is_node_output_tensor(user):
                    self.tracker.add(user, "new_cpu_node|non_tensor_output")
                    worklist.append(user)

    def reduce_acc_nodes_non_tensor_input(self):
        """
        Excludes nodes from ACC supported set that have direct
        upstream CPU nodes that produce non-tensor outputs.
        """
        non_tensor_cpu_nodes: NodeList = []

        for node in self.module.graph.nodes:
            if node.op not in CALLABLE_NODE_OPS:
                continue
            if node in self.acc_nodes:
                continue
            if is_node_output_tensor(node):
                continue
            # A CPU node with non-tensor output seeds the exclusion worklist.
            self.tracker.add(node, "new_cpu_node|callable_non_tensor_input")
            non_tensor_cpu_nodes.append(node)

        self.reduce_acc_nodes_non_tensor_input_helper(non_tensor_cpu_nodes)

    def reduce_acc_nodes_non_tensor_output(self):
        """
        Excludes nodes from ACC supported set that produce non-tensor
        outputs and have downstream CPU nodes.

        Repeats until a full pass finds nothing to exclude, because each
        exclusion can create new CPU users for other ACC nodes.
        """
        while True:
            new_cpu_nodes: NodeList = []

            for acc_node in self.acc_nodes:
                if is_node_output_tensor(acc_node):
                    continue
                for user in acc_node.users:
                    if user not in self.acc_nodes:
                        new_cpu_nodes.append(acc_node)
                        self.tracker.add(
                            acc_node, "acc_del|non_tensor_output_with_cpu_user", user
                        )
                        break

            if not new_cpu_nodes:
                break

            for new_cpu_node in new_cpu_nodes:
                self.acc_nodes.remove(new_cpu_node)

            self.reduce_acc_nodes_non_tensor_input_helper(new_cpu_nodes)

    def __call__(self) -> NodeSet:
        """
        Compute and return the set of ACC-supported nodes.

        A node starts as ACC when its op is callable and the operator
        support accepts it. Unless ``allow_non_tensor`` is set, nodes
        involved in non-tensor data flow with CPU nodes are then removed.
        """
        submodules = dict(self.module.named_modules())
        self.acc_nodes = set()
        for n in self.module.graph.nodes:
            if n.op not in CALLABLE_NODE_OPS:
                self.tracker.add(n, "init_cpu|not_callable")
                continue
            if not self.operator_support.is_node_supported(submodules, n):
                self.tracker.add(n, "init_cpu|operator_support")
                continue

            self.tracker.add(n, "init_acc|callable_and_operator_supported")
            self.acc_nodes.add(n)

        if not self.allow_non_tensor:
            self.reduce_acc_nodes_non_tensor_input()
            self.reduce_acc_nodes_non_tensor_output()
        self.tracker.dump()
        return self.acc_nodes
@compatibility(is_backward_compatible=False)
class FxNetSplitterInternalError(Exception):
    """Exception type reserved for internal errors of the FX net splitter."""
@compatibility(is_backward_compatible=False)
@dataclass
class Subgraph:
    """
    A group of nodes that the splitter places together in one submodule.

    Acc subgraphs become ``_run_on_acc_*`` submodules. Non-acc subgraphs use
    the splitter's ``non_acc_submodule_name`` prefix.
    """

    # True when the nodes run on the accelerator, False for the non-acc side.
    is_acc: bool
    # Nodes belonging to this subgraph.
    # NOTE(review): presumably in graph (topological) order — confirm in
    # put_nodes_into_subgraphs.
    nodes: NodeList
    # Optional device index for the subgraph; not read in this section.
    device_ordinal: Optional[int] = None
@compatibility(is_backward_compatible=False)
class SplitResult(NamedTuple):
    """
    Stores the results of the splitter.

    Attributes:
        split_module: root module after splitting.
        submodule_inputs: a dict that maps submodule name to its inputs.
        non_acc_submodule_prefix: the prefix for non acc submodules. For
            acc submodule the prefix is always "_run_on_acc_".
    """

    # Root GraphModule whose graph calls the generated submodules.
    split_module: torch.fx.GraphModule
    # Submodule name -> the inputs it receives.
    submodule_inputs: dict[str, Any]
    # Name prefix used for the non-acc submodules.
    non_acc_submodule_prefix: str
@compatibility(is_backward_compatible=False)
def generate_inputs_for_submodules(
    model: torch.nn.Module,
    inputs: Sequence[Any],
    target_submodules: Iterable[str],
    deepcopy: bool = False,
) -> dict[str, Any]:
    """
    Generate inputs for targeting submodules in the given model. Note that if two submodules refer to the same obj, this
    function doesn't work.

    Runs ``model(*inputs)`` once under ``torch.no_grad()`` and records the
    positional inputs each targeted submodule receives, using forward pre-hooks.
    TorchScript submodules are skipped (no hook is registered on them). If a
    submodule runs more than once, the inputs of its last call are kept.

    Args:
        model: root model.
        inputs: inputs to the root model.
        target_submodules: names (as in ``model.named_modules()``) of the
            submodules to capture inputs for. Any iterable is accepted,
            including one-shot iterators.
        deepcopy: if True, store deep copies of the captured inputs.

    Returns:
        A dict that maps from submodule name to its inputs.

    Raises:
        Whatever the forward pass raises; the hooks are removed either way.
    """
    # Materialize once: a membership test against an iterator consumes it,
    # which would silently drop later targets.
    targets = set(target_submodules)
    results: dict[str, Any] = {}
    submodule_to_names = {mod: name for name, mod in model.named_modules()}

    def pre_forward(module, module_inputs):
        results[submodule_to_names[module]] = (
            copy.deepcopy(module_inputs) if deepcopy else module_inputs
        )

    handles = [
        mod.register_forward_pre_hook(pre_forward)
        for name, mod in model.named_modules()
        if name in targets and not isinstance(mod, torch.jit.ScriptModule)
    ]

    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        # Never leave hooks registered on the caller's model.
        for handle in handles:
            handle.remove()

    return results
- class _SplitterBase:
- """
- Splits a GraphModule into sub-GraphModules for execution on CPU or the accelerator.
- Output is a GraphModule with supported and unsupported operators grouped into as few sub-GraphModules as possible.
- Assumes that only "call_module", "call_function" and "call_method" from FX IR can potentially be executed on the accelerator.
- Given the following graph:
- ==> b ==>
- // \\
- a d
- \\ //
- ==> c ==>
- class SimpleModule(torch.nn.Module):
- def forward(self, a):
- b = torch.sin(a)
- c = torch.cos(a)
- d = b + c
- return d
- and providing "operator_support" that indicates that 'b' and 'c' can be executed on the accelerator,
- we will get the following split result:
- main:
- def forward(self, a):
- run_on_acc_0_0 = self._run_on_acc_0_0(a)
- getitem = run_on_acc_0_0[0]
- getitem_1 = run_on_acc_0_0[1]
- run_on_cpu_1_1 = self._run_on_cpu_1_1(getitem, getitem_1)
- return run_on_cpu_1_1
- _run_on_acc_0_0:
- def forward(self, a):
- sin_1 = torch.sin(a)
- cos_1 = torch.cos(a)
- return (sin_1, cos_1)
- _run_on_cpu_1_1:
- def forward(self, sin_1, cos_1):
- add_1 = sin_1 + cos_1
- return add_1
- """
- # PCIe bandwidth for the backend, default to 100 GB/s
- PCIe_BW = 100 * 2**30
def __init__(
    self,
    module: torch.fx.GraphModule,
    sample_input: Sequence[Any],
    operator_support: OperatorSupportBase,
    settings: _SplitterSettingBase,
    non_acc_submodule_name: str = "_run_on_cpu_",
    return_tuple: bool = False,
    nodes_finder: Optional[FxNetAccNodesFinder] = None,
):
    """
    Prepare ``module`` for splitting.

    In order, this:
      1. runs ``ShapeProp`` over ``module`` with ``sample_input``,
      2. finds the ACC-supported nodes (using ``nodes_finder`` if given,
         otherwise a default ``FxNetAccNodesFinder``) -> ``self.acc_nodes``,
      3. finds fusion groups for ACC nodes with non-tensor IO, unless
         ``settings.skip_fusion`` is set -> ``self.fusions``,
      4. builds the direct-dependency graph and extends it for fusions
         -> ``self.deps``.

    Raises:
        AssertionError: if ``module`` is not a ``torch.fx.GraphModule``.
    """
    if not isinstance(module, torch.fx.GraphModule):
        raise AssertionError(f"Expected GraphModule, got {type(module)}")

    self.module = module
    # Annotate node metadata first; the node finder runs afterwards on the
    # annotated graph.
    ShapeProp(self.module).propagate(*sample_input)

    self.settings = settings
    self.operator_support = operator_support
    self.sample_input = sample_input

    finder = (
        nodes_finder
        if nodes_finder is not None
        else FxNetAccNodesFinder(
            self.module, self.operator_support, self.settings.allow_non_tensor
        )
    )
    self.acc_nodes = finder()

    self.fusions = (
        {}
        if self.settings.skip_fusion
        else FxNetAccFusionsFinder(module, self.acc_nodes)()
    )

    # Direct dependencies first, then the extra edges implied by fusions.
    self.deps = self.find_deps()
    self.update_deps_for_fusions()

    self.non_acc_submodule_name = non_acc_submodule_name
    self._node_submodule_map: dict[str, str] = {}
    self._return_tuple = return_tuple

    self.tags: list[str] = []
- # ===============================================================
- # Helpers for ctor and initial state
- # ===============================================================
def get_node_submodule_map(self) -> dict[str, str]:
    """Return the map from node name to the name of the submodule it went into.

    For example, the node
        main_module_impl_impl_over_arch_unary_multiple_embedding
        _pooling_embedding_pooling_sparse_entity_equivalence_key
        _proxy_embedding_bag
    (one name, wrapped here) may map to the submodule name ``_run_on_acc_1``.

    The returned dict is the splitter's internal map (not a copy). It is
    created empty in ``__init__``.
    NOTE(review): it is filled outside this section, presumably during the
    split — confirm before relying on it earlier.
    """
    return self._node_submodule_map
def find_deps(self) -> dict[torch.fx.Node, NodeSet]:
    """
    Build the graph of direct node dependencies.

    The result maps a node to the set of callable nodes it consumes directly.
    There is no transitive closure. Only producers whose op is in
    ``CALLABLE_NODE_OPS`` are recorded, and the ``output`` node never appears
    as a key. The result is a ``defaultdict(set)``, so looking up a node
    without dependencies yields an empty set.
    """
    deps: dict[torch.fx.Node, NodeSet] = defaultdict(set)

    callable_producers = (
        n for n in self.module.graph.nodes if n.op in CALLABLE_NODE_OPS
    )
    for producer in callable_producers:
        consumers = (u for u in producer.users if u.op != "output")
        for consumer in consumers:
            deps[consumer].add(producer)

    return deps
def update_deps_for_fusions(self):
    """
    Widen ``self.deps`` in place so fusion groups behave as one unit.

    For each ``node -> fusion`` entry in ``self.fusions``:
      - ``node`` additionally depends on every outside dependency of any
        member of ``fusion``;
      - every user of a member that lies outside ``fusion`` additionally
        depends on ``node``.
    """
    for node, fusion in self.fusions.items():
        node_deps = self.deps[node]
        for member in fusion:
            node_deps.update(self.deps[member] - fusion)
            outside_users = (u for u in member.users if u not in fusion)
            for user in outside_users:
                self.deps[user].add(node)
- # ===============================================================
- # Helpers for preview
- # ===============================================================
def _lower_model_to_backend(
    self, mod: torch.fx.GraphModule, inputs: Tensors
) -> torch.nn.Module:
    """
    Lower the model to a backend.

    Hook for subclasses. This base implementation performs no lowering: it
    ignores ``inputs`` and returns ``mod`` unchanged. ``split_preview`` calls
    it for every acc submodule and reports a ``RuntimeError`` raised here as a
    lowering failure.
    """
    return mod
def _find_culprit(self, mod: torch.fx.GraphModule, inputs: Tensors) -> str:
    """
    When an error occurs during lowering or running the lowered mod, we use this
    function to find culprits in the `mod` that causes the error.

    Hook for subclasses. The returned text is appended verbatim to the
    ``split_preview`` report. This base implementation only returns a
    "not implemented" message.
    """
    return "Unable to find a culprit because _find_culprit() function is not implemented."
def _draw_graph_based_on_node_support(
    self, mod: torch.fx.GraphModule, supported_nodes: NodeList
):
    """
    Write ``node_support.dot`` in the current working directory, colored by
    operator support.

    Nodes in ``supported_nodes`` are filled ``chartreuse1``. Other callable
    nodes are filled ``crimson``, and every remaining node ``AliceBlue``.
    The drawer is created with ``ignore_getattr=True``.
    """
    color_map = {
        "default": "AliceBlue",
        "supported": "chartreuse1",
        "unsupported": "crimson",
    }
    # Membership is checked once per drawn node; a set keeps that O(1)
    # instead of scanning the list every time.
    supported = set(supported_nodes)

    class CustomDrawer(FxGraphDrawer):
        def _get_node_style(self, node):
            template = super()._get_node_style(node)
            if node in supported:
                template["fillcolor"] = color_map["supported"]
            elif node.op in CALLABLE_NODE_OPS:
                template["fillcolor"] = color_map["unsupported"]
            else:
                template["fillcolor"] = color_map["default"]

            return template

    drawer = CustomDrawer(mod, "node_support", ignore_getattr=True)
    dot_graph = drawer.get_main_dot_graph()
    # pyre-fixme[16]: `pydot.Dot` has no attribute `write_raw`.
    dot_graph.write_raw("node_support.dot")  # type: ignore[attr-defined]
def node_support_preview(self, dump_graph: bool = False):
    """
    Report which callable node targets the operator support accepts.

    For every callable node, the report records its target together with a
    dtype signature: the dtypes of its positional Node args (trailing
    non-tensor args stripped) and of its keyword Node args. The node is
    listed under "Supported" or "Unsupported" according to
    ``operator_support.is_node_supported``.

    Args:
        dump_graph: if True, also write ``node_support.dot`` via
            ``_draw_graph_based_on_node_support``.

    Returns:
        The report string, which is also printed (returned for testing).
    """
    submodules = dict(self.module.named_modules())

    supported_nodes: NodeList = []
    supported_node_types = defaultdict(set)
    unsupported_node_types = defaultdict(set)

    def get_dtype(arg):
        # dtype from ShapeProp's tensor metadata, or None for non-tensors.
        tensor_meta = arg.meta.get("tensor_meta")
        return getattr(tensor_meta, "dtype", None)

    for node in self.module.graph.nodes:
        if node.op not in CALLABLE_NODE_OPS:
            continue

        target = get_node_target(submodules, node)

        # Store dtype of arg in node.args. If arg doesn't have dtype, i.e. not a tensor, we'll store None.
        arg_dtypes = [
            get_dtype(arg) if isinstance(arg, torch.fx.Node) else None
            for arg in node.args
        ]

        # Length of the prefix ending at the last non-None dtype. If every
        # entry is None this is 0, so the whole list is stripped.
        last_index = len(arg_dtypes) - next(
            (
                i
                for i, dtype in enumerate(reversed(arg_dtypes))
                if dtype is not None
            ),
            len(arg_dtypes),
        )

        # Strip None elements at the end.
        arg_dtypes_tuple = tuple(arg_dtypes[:last_index])
        kwarg_dtypes_tuple = tuple(
            (k, get_dtype(arg))
            for k, arg in node.kwargs.items()
            if isinstance(arg, torch.fx.Node)
        )
        dtype_signature = (arg_dtypes_tuple, kwarg_dtypes_tuple)

        if self.operator_support.is_node_supported(submodules, node):
            supported_nodes.append(node)
            supported_node_types[target].add(dtype_signature)
        else:
            unsupported_node_types[target].add(dtype_signature)

    if dump_graph:
        self._draw_graph_based_on_node_support(self.module, supported_nodes)

    def format_node_types(node_types):
        # One line per (target, dtype signature) pair.
        return [
            f"{t}: ({arg_sig}, {dict(kwarg_sig)})\n"
            for t, dtypes in node_types.items()
            for arg_sig, kwarg_sig in dtypes
        ]

    # Build the report in one join instead of repeated string concatenation.
    reports = "".join(
        [
            "\nSupported node types in the model:\n",
            *format_node_types(supported_node_types),
            "\nUnsupported node types in the model:\n",
            *format_node_types(unsupported_node_types),
        ]
    )

    print(reports)

    # Return reports for testing purpose
    return reports
def split_preview(self, dump_graph: bool = False):
    """
    Dry-run the split and report on the result.

    The report (printed and returned) contains:
      - the number of acc / cpu subgraphs before and after
        ``remove_small_acc_subgraphs``, and the node count of each subgraph;
      - for every acc submodule of the split module: warnings for non-tensor
        inputs/outputs, total input/output bytes, a PCIe-bandwidth-bound QPS
        estimate, and whether ``_lower_model_to_backend`` and running the
        lowered module succeed (``RuntimeError`` is reported, not raised);
      - the overall PCIe-bound QPS and the bottleneck submodule.

    Args:
        dump_graph: if True, write one ``<name>.dot`` file per graph of the
            split module to the current working directory.

    Returns:
        The report string (returned for testing purposes).
    """
    reports = ""
    subgraphs = self.put_nodes_into_subgraphs()
    acc_subgraphs_num = len([g for g in subgraphs if g.is_acc])
    cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num
    reports += f"Before removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:"
    reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"

    subgraphs = self.remove_small_acc_subgraphs(subgraphs)
    acc_subgraphs_num = len([g for g in subgraphs if g.is_acc])
    cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num
    reports += f"After removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:"
    reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"

    for i, subgraph in enumerate(subgraphs):
        reports += (
            f"_run_on_acc_{i}: "
            if subgraph.is_acc
            else f"{self.non_acc_submodule_name}{i}: "
        )
        reports += f"{len(subgraph.nodes)} node(s)\n"

    self.tag(subgraphs)
    split_mod = self.split(remove_tag=True)
    split_mod.eval()

    if dump_graph:
        drawer = FxGraphDrawer(split_mod, "preview", ignore_getattr=True)
        dot_graphs = drawer.get_all_dot_graphs()
        for name, dot_graph in dot_graphs.items():
            # pyre-fixme[16]: `pydot.Dot` has no attribute `write_raw`.
            dot_graph.write_raw(f"{name}.dot")  # type: ignore[attr-defined]

    def get_submod_inputs(main_mod, submod, example_inputs):
        # Run ``main_mod`` once and capture the positional inputs ``submod``
        # receives (None if it is never called).
        sub_inputs = None

        def get_inputs(module, inputs):
            nonlocal sub_inputs
            sub_inputs = inputs

        handle = submod.register_forward_pre_hook(get_inputs)
        try:
            main_mod(*example_inputs)
        finally:
            # Detach the hook even when the forward pass raises.
            handle.remove()
        return sub_inputs

    max_qps: float = self.PCIe_BW
    bottleneck_module = ""

    for node in split_mod.graph.nodes:
        if node.op == "call_module" and "acc" in node.target:
            reports += f"\nProcessing acc submodule {node.target}\n"

            submod = getattr(split_mod, node.target)
            submod_inputs = get_submod_inputs(split_mod, submod, self.sample_input)
            ShapeProp(submod).propagate(*submod_inputs)

            total_input_bytes = 0
            total_output_bytes = 0

            reports += "Checking inputs...\n"
            for n in submod.graph.nodes:
                if n.op == "placeholder":
                    if not is_node_output_tensor(n):
                        reports += f"Input {n.name} is not a tensor, this might cause problems during lowering!\n"
                    else:
                        total_input_bytes += get_size_of_node(submod, n)[0]
                if n.op == "output":
                    output_node = n

            reports += "Checking outputs...\n"

            def get_bytes(out: torch.fx.Node):
                nonlocal total_output_bytes
                nonlocal reports
                if not is_node_output_tensor(out):
                    reports += f"Output {out.name} is not a tensor, this might cause problems during lowering!\n"
                else:
                    total_output_bytes += get_size_of_node(submod, out)[0]

            map_arg(output_node.args, get_bytes)  # type: ignore[possibly-undefined]

            # A submodule with no tensor I/O moves nothing over PCIe, so it
            # imposes no bandwidth bound (and must not divide by zero).
            max_bytes = max(total_input_bytes, total_output_bytes)
            qps = self.PCIe_BW / max_bytes if max_bytes else float("inf")
            reports += f"Total input size in bytes is {total_input_bytes}, total output size in bytes is {total_output_bytes},"
            reports += f" theoretical max qps (bounds by PCIe bandwidth) for this submodule is {qps}.\n"

            if qps < max_qps:
                max_qps = qps
                bottleneck_module = node.target

            try:
                lowered_submod = self._lower_model_to_backend(submod, submod_inputs)
            except RuntimeError:
                reports += "Run into an error during lowering!\n"
                reports += self._find_culprit(submod, submod_inputs)
                continue

            try:
                lowered_submod(*submod_inputs)
            except RuntimeError:
                reports += "Run into an error during inference!\n"
                reports += self._find_culprit(submod, submod_inputs)
            else:
                reports += "Lowering and running succeed!\n"

    reports += f"\nTheoretical max qps (bounds by PCIe bandwidth) for this model is {max_qps},"
    reports += f" bottleneck is submodule {bottleneck_module}."
    print(reports)

    # return the reports for testing purposes
    return reports
- # ===============================================================
- # Helpers for extend_acc_subgraph() method
- # ===============================================================
def find_reverse_deps(
    self, tag_id: Optional[int] = None
) -> dict[torch.fx.Node, NodeSet]:
    """
    Map each callable node to its callable users (reverse dependencies).

    When ``tag_id`` is given, a user is kept only if the integer suffix of
    its ``tag`` (the text after the last ``_``) is strictly smaller than
    ``tag_id``. Users in the ``tag_id`` subgraph itself or in later ones are
    therefore ignored. Nodes with no kept users are not keys of the
    returned ``defaultdict(set)``.
    """
    result: dict[torch.fx.Node, NodeSet] = defaultdict(set)

    def is_kept(user) -> bool:
        if tag_id is None:
            return True
        return int(user.tag.rsplit("_", 1)[-1]) < tag_id

    for node in self.module.graph.nodes:
        if node.op not in CALLABLE_NODE_OPS:
            continue
        kept = [u for u in node.users if u.op in CALLABLE_NODE_OPS and is_kept(u)]
        if kept:
            result[node].update(kept)

    return result
def update_reverse_deps_for_fusions(self, deps: dict[torch.fx.Node, NodeSet]):
    """
    Make the reverse-dependency map ``deps`` fusion-aware, in place.

    For each fusion group in ``self.fusions``:
      - every member is given the union of all members' reverse deps, minus
        the group itself;
      - every input of a member that lies outside the group gets the whole
        group added to its reverse deps.

    Note: all members of a group are assigned the *same* set object. A later
    in-place ``update`` on one member's entry (e.g. when that member is an
    outside input of another group) is therefore seen by all of them.
    """
    processed_node = set()

    for node, fusion in self.fusions.items():
        # Several keys can share one group; handle each group only once.
        if node in processed_node:
            continue

        new_dep = set()

        # Create a new dependency set which include all the
        # dependencies of the nodes in the fusion group
        for n in fusion:
            new_dep.update(deps[n])

        # Exclude nodes in the fusion
        new_dep.difference_update(fusion)

        # Update dependency
        for n in fusion:
            deps[n] = new_dep

            for arg in n.all_input_nodes:
                if arg not in fusion:
                    deps[arg].update(fusion)

            processed_node.add(n)
def find_parent_nodes_of_subgraph(self, tag: str) -> NodeSet:
    """
    Return the parent nodes of the ``tag`` subgraph.

    A parent is an input of some node in the subgraph that is not itself in
    the subgraph. Only inputs whose op is in ``CALLABLE_NODE_OPS`` qualify,
    so placeholders and other non-callable inputs are ignored.
    """
    members = (
        n for n in self.module.graph.nodes
        if n.op in CALLABLE_NODE_OPS and n.tag == tag
    )
    return {
        arg
        for member in members
        for arg in member.all_input_nodes
        if arg.op in CALLABLE_NODE_OPS and arg.tag != tag
    }
def extend_acc_subgraph(self, tag: str):
    """
    Extend the acc subgraph with `tag` going the reversed topological direction.

    Starting from the subgraph's parent nodes, repeatedly pull in an acc node
    whose reverse deps (its users in subgraphs with a smaller tag id, widened
    for fusions) have all already been pulled in. Each pulled-in node is
    retagged with ``tag``, and its fusion buddies and callable inputs become
    new candidates. The loop stops when no candidate qualifies.

    If several candidates qualify, the first one in ``parent_nodes``
    iteration order (a set) is taken. Expects ``tag`` to end in ``_<int>``.
    """
    # Dict that maps node to its users, ignoring users whose subgraph tag id
    # is greater than or equal to this one (including this subgraph itself).
    deps = self.find_reverse_deps(tag_id=int(tag.rsplit("_", maxsplit=1)[-1]))
    self.update_reverse_deps_for_fusions(deps)

    # Parent nodes of the subgraph
    parent_nodes = self.find_parent_nodes_of_subgraph(tag)

    visited_nodes: NodeSet = set()

    while parent_nodes:
        node = None

        # Find a acc node that depends on visited nodes only
        for n in parent_nodes:
            if deps[n] <= visited_nodes and n in self.acc_nodes:
                node = n
                break

        if node is None:
            break

        # Put the node into `tag` subgraph
        node.tag = tag  # type: ignore[attr-defined]
        parent_nodes.remove(node)
        visited_nodes.add(node)

        # If node is in a fusion group, add all fusion buddies to parent nodes
        if node in self.fusions:
            for fusion_node in self.fusions[node]:
                if fusion_node not in visited_nodes:
                    parent_nodes.add(fusion_node)

        # Add inputs of the node to parent nodes
        for arg in node.all_input_nodes:
            if arg.op in CALLABLE_NODE_OPS and arg not in visited_nodes:
                parent_nodes.add(arg)
# ===============================================================
# Helpers for split() method
# ===============================================================
def starter_nodes(self) -> tuple[NodeSet, NodeSet]:
    """
    Find the nodes from which the split traversal starts.

    A starter is either
    * a node whose op is in ``CALLABLE_NODE_OPS`` and that has no input
      nodes at all, or
    * a user of a module input (``placeholder``) or of a ``get_attr`` node.

    Returns:
        ``(starter_cpu_nodes, starter_acc_nodes)``: the starters partitioned
        by membership in ``self.acc_nodes``.
    """
    starter_cpu_nodes: NodeSet = set()
    starter_acc_nodes: NodeSet = set()

    def _add(n):
        # Starters are partitioned by whether the node is supported on acc.
        if n in self.acc_nodes:
            starter_acc_nodes.add(n)
        else:
            starter_cpu_nodes.add(n)

    for node in self.module.graph.nodes:
        # Edge case: a callable node with no inputs is never reached as the
        # user of another node. put_nodes_into_subgraphs only visits starters
        # and users of visited nodes, so such a node must be seeded here or it
        # is never placed in a subgraph (and never tagged). This applies to
        # every callable op, not just call_function.
        if node.op in CALLABLE_NODE_OPS and not node.all_input_nodes:
            _add(node)
            continue

        if node.op not in {"placeholder", "get_attr"}:
            continue
        for user in node.users:
            _add(user)

    return starter_cpu_nodes, starter_acc_nodes
def put_nodes_into_subgraphs(self) -> list[Subgraph]:
    """
    Partition the graph into alternating acc / non-acc subgraphs.

    Traversal starts from ``starter_nodes()``, and two work queues are kept:
    one for acc nodes and one for CPU nodes. While a subgraph of one kind is
    being built, any queued node of that kind whose dependencies
    (``self.deps[node]``) are all placed is appended to it. That node's
    fusion buddies are queued on the node's own side, and its callable users
    are queued according to their own acc support. When no node of the
    current kind is ready, the subgraph is closed and the kind flips.

    Returns:
        The subgraphs in the order they were built.

    Raises:
        FxNetSplitterInternalError: if a flip would close an empty subgraph,
            or if no subgraph was produced at all.
    """
    cpu_queue, acc_queue = self.starter_nodes()
    # Work queues indexed by "is this an acc node?".
    queues = {False: cpu_queue, True: acc_queue}
    placed: NodeSet = set()

    # Begin with a CPU subgraph when some CPU starter has no dependencies;
    # otherwise the first subgraph is an acc one.
    building_acc: bool = all(len(self.deps[n]) != 0 for n in cpu_queue)

    in_progress: NodeList = []
    built: list[Subgraph] = []

    while cpu_queue or acc_queue:
        queue = queues[building_acc]
        ready = next((n for n in queue if self.deps[n] <= placed), None)

        if ready is None:
            # Nothing of the current kind can run yet: close this subgraph
            # and switch to the other kind.
            if not in_progress:
                raise FxNetSplitterInternalError("Subgraph can't be empty")
            built.append(Subgraph(is_acc=building_acc, nodes=in_progress))
            building_acc = not building_acc
            in_progress = []
            continue

        queue.remove(ready)
        placed.add(ready)
        in_progress.append(ready)

        # Fusion buddies go to the same queue as the node they are fused with.
        if ready in self.fusions:
            queues[ready in self.acc_nodes].update(self.fusions[ready] - placed)

        # Callable users are queued according to their own acc support.
        for user in ready.users:
            if user.op in CALLABLE_NODE_OPS:
                queues[user in self.acc_nodes].add(user)

    # Flush the subgraph that was still open when the queues ran dry.
    if in_progress:
        built.append(Subgraph(is_acc=building_acc, nodes=in_progress))

    if not built:
        raise FxNetSplitterInternalError("Couldn't create subgraphs")
    return built
def remove_small_acc_subgraphs(self, subgraphs: list[Subgraph]) -> list[Subgraph]:
    """
    Fold acc subgraphs smaller than ``settings.min_acc_module_size`` away.

    A too-small acc subgraph is merged into the previous subgraph of the
    result, whatever that subgraph's kind. If it is the very first one, it is
    kept but turned into a non-acc subgraph instead. A non-acc subgraph is
    merged into the previous result entry when that entry is non-acc too.

    The ``nodes`` lists of the input ``Subgraph`` objects are extended in
    place, and a kept-first small subgraph has its ``is_acc`` cleared.

    Returns:
        The merged list of subgraphs.
    """
    merged: list[Subgraph] = []
    for sg in subgraphs:
        if not sg.is_acc:
            # Glue onto a preceding non-acc subgraph when there is one.
            if merged and not merged[-1].is_acc:
                merged[-1].nodes.extend(sg.nodes)
            else:
                merged.append(sg)
            continue

        if len(sg.nodes) >= self.settings.min_acc_module_size:
            merged.append(sg)
            continue

        print(
            "Eliminating acc subgraph because it's smaller than the threshold: "
            f"{len(sg.nodes)} < {self.settings.min_acc_module_size}"
        )
        if not merged:
            # Nothing to fold into yet: keep it, but as a non-acc subgraph.
            sg.is_acc = False
            merged.append(sg)
        else:
            merged[-1].nodes.extend(sg.nodes)
    return merged
def tag(self, subgraphs: list[Subgraph]):
    """
    Give every node of every subgraph a submodule tag.

    Tags are numbered by subgraph position: acc subgraphs get
    ``_run_on_acc_<i>`` and the others ``<non_acc_submodule_name><i>``.
    ``self.tags`` is reset and then filled with these tags in order. Each node
    receives a ``tag`` attribute, and ``self._node_submodule_map`` maps the
    node's name to its tag.

    Raises:
        FxNetSplitterInternalError: if a node already has a ``tag`` attribute.
    """
    self.tags = []
    for position, sg in enumerate(subgraphs):
        prefix = "_run_on_acc_" if sg.is_acc else self.non_acc_submodule_name
        label = f"{prefix}{position}"
        self.tags.append(label)

        for node in sg.nodes:
            if hasattr(node, "tag"):
                raise FxNetSplitterInternalError(f"Node {node} was already tagged")
            node.tag = label  # type: ignore[attr-defined]
            self._node_submodule_map[node.name] = label
def split(self, remove_tag: bool = False) -> torch.fx.GraphModule:
    """
    Split ``self.module`` into submodules according to ``self.tags``.

    Args:
        remove_tag: when True, after splitting, delete the ``tag`` attribute
            from every node of ``self.module.graph`` that has one.

    Returns:
        The module produced by ``split_by_tags``.
    """
    result = split_by_tags(self.module, self.tags, return_tuple=self._return_tuple)

    if remove_tag:
        for node in filter(lambda n: hasattr(n, "tag"), self.module.graph.nodes):
            del node.tag

    return result  # type: ignore[return-value]
def __call__(self) -> torch.fx.GraphModule:
    """
    Run the whole splitting pipeline and return the split module.

    Steps:
    1. Build the subgraphs.
    2. If ``settings.move_non_tensor_nodes_on_boundary`` is set, move the
       non-tensor nodes that sit on subgraph boundaries.
    3. Merge away acc subgraphs below the configured minimum size.
    4. Print the acc / non-acc counts.
    5. Tag the nodes and split.
    """
    subgraphs = self.put_nodes_into_subgraphs()
    if self.settings.move_non_tensor_nodes_on_boundary:
        move_non_tensor_nodes_on_boundary(subgraphs)
    subgraphs = self.remove_small_acc_subgraphs(subgraphs)

    acc_count = sum(1 for sg in subgraphs if sg.is_acc)
    print(
        f"Got {acc_count} acc subgraphs and {len(subgraphs) - acc_count} non-acc subgraphs"
    )

    self.tag(subgraphs)
    return self.split()
def generate_split_results(self) -> SplitResult:
    """
    Split the module and compute sample inputs for each of its submodules.

    Returns:
        ``SplitResult(split_module, submodule_inputs, non_acc_submodule_name)``.

    Raises:
        ValueError: if ``settings.max_acc_splits`` is positive and the split
            module has more direct children than that limit.
    """
    split_module = self()
    submodule_names = [child_name for child_name, _ in split_module.named_children()]

    # NOTE(review): every direct child counts toward this limit, acc or not.
    limit = self.settings.max_acc_splits
    if limit > 0 and len(submodule_names) > limit:
        raise ValueError(
            "Cannot fulfill max_acc_splits limit. "
            "This may cause split fragmentation and "
            "result in performance issues."
        )

    submodule_inputs = generate_inputs_for_submodules(
        split_module, self.sample_input, submodule_names
    )
    return SplitResult(split_module, submodule_inputs, self.non_acc_submodule_name)
|