splitter_base.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139
  1. # mypy: allow-untyped-defs
  2. import argparse
  3. import copy
  4. import json
  5. import os
  6. from collections import defaultdict
  7. from collections.abc import Iterable, Sequence
  8. from dataclasses import dataclass
  9. from typing import Any, Literal, NamedTuple, Optional
  10. import torch
  11. from torch._logging import trace_structured
  12. from torch.fx._compatibility import compatibility
  13. from torch.fx.node import map_arg
  14. from torch.fx.passes.graph_manipulation import get_size_of_node
  15. from .graph_drawer import FxGraphDrawer
  16. from .operator_support import get_node_target, OperatorSupportBase
  17. from .shape_prop import ShapeProp
  18. from .split_utils import move_non_tensor_nodes_on_boundary, split_by_tags
  19. from .tools_common import (
  20. CALLABLE_NODE_OPS,
  21. FxNetAccFusionsFinder,
  22. is_node_output_tensor,
  23. NodeList,
  24. NodeSet,
  25. Tensors,
  26. )
# Public API of this module: the names exported by ``from ... import *``.
# Every entry must name a class/function defined in this module, otherwise a
# star-import raises AttributeError.
__all__ = [
    "FxNetAccNodesFinder",
    "FxNetSplitterInternalError",
    "Subgraph",
    "SplitResult",
    "generate_inputs_for_submodules",
    "NodeEvent",
    "NodeEventTracker",
]
  36. DEFAULT_MIN_ACC_MODULE_SIZE = 1
  37. DEFAULT_SKIP_FUSION = False
  38. DEFAULT_ALLOW_NON_TENSOR = False
  39. # ENV var and constants for node tracker
  40. TRACKER_DUMP_PATH = "_fx_net_tracker"
  41. NODES_SUFFIX = "_nodes.txt"
  42. ALL_SUFFIX = "_all.txt"
  43. ENV_FX_NET_ACC_SPLITTER_TRACKER_MODE = "FX_NET_ACC_SPLITTER_TRACKER_MODE"
  44. ENV_FX_NET_ACC_SPLITTER_TRACKER_DUMP_PATH = "FX_NET_ACC_SPLITTER_TRACKER_DUMP_PATH"
  45. ENV_FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES = (
  46. "FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES"
  47. )
  48. DUMP_PREFIX = os.environ.get(
  49. ENV_FX_NET_ACC_SPLITTER_TRACKER_DUMP_PATH, TRACKER_DUMP_PATH
  50. )
  51. """
  52. Different modes of the event tracker for local debugging:
  53. "0": No local dumps. Information available by setting breakpoints and visually inspect in pdb.
  54. "1": Dump all events to DUMP_PREFIX_all.txt
  55. "2": In addition to events dump, track nodes specified by ENV_FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES
  56. recursively and dump to DUMP_PREFIX_nodex.txt
  57. "3": In addition to events dump, track all nodes with more than 1 event recursively and dump to DUMP_PREFIX_nodex.txt
  58. In addition to the above local dumps, tracker is always enabled and dumps via trace_structured.
  59. """
  60. # pyrefly: ignore [bad-assignment]
  61. TRACKER_MODE: Literal["0", "1", "2", "3"] = os.environ.get(
  62. ENV_FX_NET_ACC_SPLITTER_TRACKER_MODE, "0"
  63. ) # type: ignore[assignment]
  64. class _SplitterSettingBase:
  65. def __init__(
  66. self,
  67. min_acc_module_size=DEFAULT_MIN_ACC_MODULE_SIZE,
  68. skip_fusion=DEFAULT_SKIP_FUSION,
  69. allow_non_tensor=DEFAULT_ALLOW_NON_TENSOR,
  70. max_acc_splits: int = -1,
  71. move_non_tensor_nodes_on_boundary: bool = False,
  72. ):
  73. parser = argparse.ArgumentParser()
  74. parser.add_argument(
  75. "--min-acc-module-size",
  76. "--min_acc_module_size",
  77. required=False,
  78. type=int,
  79. help="Minimum size limit of an accelerator subgraph.",
  80. )
  81. parser.add_argument(
  82. "--max-acc-splits",
  83. "--max_acc_splits",
  84. required=False,
  85. type=int,
  86. help="Enforce a maximum number of split subgraphs.",
  87. )
  88. parser.add_argument(
  89. "--skip-fusion",
  90. "--skip_fusion",
  91. default=False,
  92. action="store_true",
  93. help="If true then no fusion groups. Fusion group is used to "
  94. "enforce no non-tensor data flow between submodules. If we don't "
  95. "have this constrain, setting this to false is recommended as it "
  96. "can reduce overhead.",
  97. )
  98. parser.add_argument(
  99. "--allow-non-tensor",
  100. "--allow_non_tensor",
  101. default=False,
  102. action="store_true",
  103. help="For some backends non-tensor data flow between cpu and them "
  104. "are not allowed. Therefore, if a node supported by accelerator but "
  105. "it has non-tensor inputs or outputs to a cpu node we would want to "
  106. "consider it as a cpu node during splitting. However, for some backends "
  107. "we might not care about non-tensor data flow and we can set this option "
  108. "to true to disable the functionality that prevent non-tensor data flow.",
  109. )
  110. parser.add_argument(
  111. "--move-non-tensor-nodes-on-boundary",
  112. "--move_non_tensor_nodes_on_boundary",
  113. required=False,
  114. action="store_true",
  115. help="AOTI does not support non-tensor nodes on acc->acc, acc->gpu and gpu->acc boundary. "
  116. "For non-tensor nodes on acc->acc boundary and acc->gpu, we move the nodes from upstream to downstream. "
  117. "For non-tensor nodes on gpu->acc boundary, it is handled by the pre-split process. "
  118. "(by method reduce_acc_nodes_non_tensor_input). ",
  119. )
  120. args, _unknown = parser.parse_known_args()
  121. self.min_acc_module_size: int = (
  122. args.min_acc_module_size
  123. if args.min_acc_module_size
  124. else min_acc_module_size
  125. )
  126. self.skip_fusion: bool = args.skip_fusion if args.skip_fusion else skip_fusion
  127. self.allow_non_tensor: bool = (
  128. args.allow_non_tensor if args.allow_non_tensor else allow_non_tensor
  129. )
  130. self.max_acc_splits: int = max_acc_splits
  131. self.move_non_tensor_nodes_on_boundary: bool = (
  132. args.move_non_tensor_nodes_on_boundary
  133. if args.move_non_tensor_nodes_on_boundary
  134. else move_non_tensor_nodes_on_boundary
  135. )
  136. @compatibility(is_backward_compatible=False)
  137. class NodeEvent:
  138. """
  139. An event in graph split that happened on a node.
  140. source: Subject of the event
  141. desc: readable description
  142. dep: Optional dependency, usually the node that caused the event.
  143. """
  144. def __init__(
  145. self, source: torch.fx.Node, desc: str, dep: Optional[torch.fx.Node] = None
  146. ):
  147. self.source = source
  148. self.desc = desc
  149. self.dep = dep
  150. def to_str(self):
  151. # source: The name of the subject of the event.
  152. # desc: description of the event, in the format of <event_type>|<explanation>
  153. # dep: The name of the cause of this event, which is another node, or #
  154. # if it's caused by the subject node
  155. return f"{self.source.name}: {self.desc} {self.dep.name if self.dep else '#'}"
  156. @compatibility(is_backward_compatible=False)
  157. class NodeEventTracker:
  158. """
  159. Tracks node events during the splitter execution.
  160. """
  161. def __init__(self, tracker_mode, dump_prefix):
  162. self.tracker_mode = tracker_mode
  163. self.dump_prefix = dump_prefix
  164. # list of events
  165. self.events = []
  166. # dict from node name to event index
  167. self.node_events = {}
  168. self.writer = print
  169. def add(self, node: torch.fx.Node, desc: str, dep: Optional[torch.fx.Node] = None):
  170. """
  171. Add a new event to the tracker.
  172. """
  173. event = NodeEvent(node, desc, dep)
  174. self.events.append(event)
  175. if node.name not in self.node_events:
  176. self.node_events[node.name] = []
  177. self.node_events[node.name].append(len(self.events) - 1)
  178. def print_node(self, node_name, recursive=False, tab="", writer=None):
  179. """
  180. Print a node and its events.
  181. @param recursive: if True, print nodes that caused the events on this current node.
  182. @param tab: Indentation for dependencies.
  183. @param writer: function to write to file. If None, use print.
  184. """
  185. if not writer:
  186. writer = self.writer
  187. for idx in self.node_events.get(node_name, []):
  188. event = self.events[idx]
  189. writer(tab + event.to_str())
  190. if recursive and event.dep is not None:
  191. self.print_node(
  192. event.dep.name, recursive=True, tab="| " + tab, writer=writer
  193. )
  194. def to_dict(self):
  195. """
  196. Create dict dump on all events.
  197. """
  198. ret: dict[str, list[str]] = {}
  199. for name in self.node_events:
  200. ret[name] = []
  201. for idx in self.node_events.get(name, []):
  202. event = self.events[idx]
  203. ret[name].append(event.to_str())
  204. return ret
  205. def print_all(self, writer=None):
  206. """
  207. Print all nodes in a list.
  208. @param writer: function to write to file. If None, use print.
  209. """
  210. if not writer:
  211. writer = self.writer
  212. for name in self.node_events:
  213. writer(f"Node: {name}:")
  214. self.print_node(name, recursive=False, tab=" ", writer=writer)
  215. def dump(self):
  216. """
  217. Function to be invoked at the end of the finder execution to printout tracked events specified by the mode.
  218. """
  219. # dump via trace_structured
  220. trace_structured(
  221. "artifact",
  222. metadata_fn=lambda: {
  223. "name": "fx_net_acc_splitter_finder_events",
  224. "encoding": "json",
  225. },
  226. payload_fn=lambda: json.dumps(self.to_dict()),
  227. )
  228. def writeln(f):
  229. def fn(x):
  230. return f.write(x + "\n")
  231. return fn
  232. # Mode 0: no local dump
  233. # Mode >=1: Dump all events to file
  234. if self.tracker_mode >= 1:
  235. with open(self.dump_prefix + ALL_SUFFIX, "w") as f:
  236. self.print_all(writeln(f))
  237. def dump_selected_nodes(nodes):
  238. with open(self.dump_prefix + NODES_SUFFIX, "w") as f:
  239. for node_name in nodes:
  240. writeln(f"===== Tracking node {node_name} =====")
  241. self.print_node(
  242. node_name, recursive=True, tab="|-", writer=writeln(f)
  243. )
  244. writeln(f"===== End of tracking node {node_name} =====")
  245. # Mode 2: Dump specific nodes in recursive manner.
  246. # Mode 3: Dump all nodes with more than 1 event in recursive manner.
  247. if self.tracker_mode == 2 or self.tracker_mode == 3:
  248. nodes = (
  249. os.environ.get(ENV_FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES, "").split(
  250. ","
  251. )
  252. if self.tracker_mode == 2
  253. else [
  254. name for name, events in self.node_events.items() if len(events) > 1
  255. ]
  256. )
  257. dump_selected_nodes(nodes)
  258. @compatibility(is_backward_compatible=False)
  259. class FxNetAccNodesFinder:
  260. """
  261. Finds a set of nodes that can be supported on ACC, excluding nodes that have non-tensor
  262. input/output to cpu nodes to prevent non-tensor data flow between backends and cpu.
  263. I.e. if we have a chain:
  264. ACC_NODE_1 -> ACC_NODE_2 -> ACC_NODE_3 -> CPU_NODE_1
  265. where every ACC node produces non-tensor output, then they all should be treated as CPU nodes.
  266. This behavior can be turned off by passing allow_non_tensor=True.
  267. """
  268. def __init__(
  269. self,
  270. module: torch.fx.GraphModule,
  271. operator_support: OperatorSupportBase,
  272. allow_non_tensor: bool,
  273. ):
  274. self.module = module
  275. self.operator_support = operator_support
  276. self.allow_non_tensor = allow_non_tensor
  277. self.acc_nodes: NodeSet = set()
  278. self.tracker = NodeEventTracker(int(TRACKER_MODE), DUMP_PREFIX)
  279. def reduce_acc_nodes_non_tensor_input_helper(self, cpu_worklist: NodeList):
  280. """
  281. Transitively excludes nodes from ACC supported set.
  282. For every node in the worklist:
  283. - removes its downstream ACC nodes from ACC supported set,
  284. - if any downstream ACC node produces non-tensor output,
  285. then it gets added into the worklist.
  286. """
  287. while cpu_worklist:
  288. node = cpu_worklist.pop(0)
  289. for user in node.users:
  290. if user in self.acc_nodes:
  291. self.acc_nodes.remove(user)
  292. self.tracker.add(user, "acc_del|user_of_new_cpu_node", node)
  293. if not is_node_output_tensor(user):
  294. self.tracker.add(user, "new_cpu_node|non_tensor_output")
  295. cpu_worklist.append(user)
  296. def reduce_acc_nodes_non_tensor_input(self):
  297. """
  298. Excludes nodes from ACC supported set that have direct
  299. upstream CPU nodes that produce non-tensor outputs.
  300. """
  301. non_tensor_cpu_nodes: NodeList = []
  302. for node in self.module.graph.nodes:
  303. if node.op not in CALLABLE_NODE_OPS:
  304. continue
  305. if node in self.acc_nodes:
  306. continue
  307. if is_node_output_tensor(node):
  308. continue
  309. self.tracker.add(node, "new_cpu_node|callable_non_tensor_input")
  310. non_tensor_cpu_nodes.append(node)
  311. self.reduce_acc_nodes_non_tensor_input_helper(non_tensor_cpu_nodes)
  312. def reduce_acc_nodes_non_tensor_output(self):
  313. """
  314. Excludes nodes from ACC supported set that produce non-tensor
  315. outputs and have downstream CPU nodes.
  316. """
  317. while True:
  318. new_cpu_nodes: NodeList = []
  319. for acc_node in self.acc_nodes:
  320. if is_node_output_tensor(acc_node):
  321. continue
  322. for user in acc_node.users:
  323. if user not in self.acc_nodes:
  324. new_cpu_nodes.append(acc_node)
  325. self.tracker.add(
  326. acc_node, "acc_del|non_tensor_output_with_cpu_user", user
  327. )
  328. break
  329. if not new_cpu_nodes:
  330. break
  331. for new_cpu_node in new_cpu_nodes:
  332. self.acc_nodes.remove(new_cpu_node)
  333. self.reduce_acc_nodes_non_tensor_input_helper(new_cpu_nodes)
  334. def __call__(self) -> NodeSet:
  335. submodules = dict(self.module.named_modules())
  336. self.acc_nodes = set()
  337. for n in self.module.graph.nodes:
  338. if n.op not in CALLABLE_NODE_OPS:
  339. self.tracker.add(n, "init_cpu|not_callable")
  340. continue
  341. if not self.operator_support.is_node_supported(submodules, n):
  342. self.tracker.add(n, "init_cpu|operator_support")
  343. continue
  344. self.tracker.add(n, "init_acc|callable_and_operator_supported")
  345. self.acc_nodes.add(n)
  346. if not self.allow_non_tensor:
  347. self.reduce_acc_nodes_non_tensor_input()
  348. self.reduce_acc_nodes_non_tensor_output()
  349. self.tracker.dump()
  350. return self.acc_nodes
  351. @compatibility(is_backward_compatible=False)
  352. class FxNetSplitterInternalError(Exception):
  353. pass
@compatibility(is_backward_compatible=False)
@dataclass
class Subgraph:
    """A group of nodes that will be split out together into one submodule."""

    # True if this subgraph's nodes are meant to run on the accelerator.
    is_acc: bool
    # The nodes belonging to this subgraph.
    nodes: NodeList
    # Optional device index for this subgraph; None when unset.
    # NOTE(review): not read anywhere in the visible code -- confirm its consumer.
    device_ordinal: Optional[int] = None
@compatibility(is_backward_compatible=False)
class SplitResult(NamedTuple):
    """
    Stores the results of the splitter.

    Attributes:
        split_module: root module after splitting.
        submodule_inputs: a dict that maps submodule name to its inputs
            (presumably as produced by ``generate_inputs_for_submodules`` --
            confirm at the construction site).
        non_acc_submodule_prefix: the prefix for non acc submodules. For
            acc submodule the prefix is always "_run_on_acc_".
    """

    split_module: torch.fx.GraphModule
    submodule_inputs: dict[str, Any]
    non_acc_submodule_prefix: str
  373. @compatibility(is_backward_compatible=False)
  374. def generate_inputs_for_submodules(
  375. model: torch.nn.Module,
  376. inputs: Sequence[Any],
  377. target_submodules: Iterable[str],
  378. deepcopy: bool = False,
  379. ) -> dict[str, Any]:
  380. """
  381. Generate inputs for targeting submdoules in the given model. Note that if two submodules refer to the same obj, this
  382. function doesn't work.
  383. Args:
  384. model: root model.
  385. inputs: inputs to the root model.
  386. target_submodules: submodules that we want to generate inputs for.
  387. Returns:
  388. A dict that maps from submodule name to its inputs.
  389. """
  390. handles = []
  391. results = {}
  392. submodule_to_names = {mod: name for name, mod in model.named_modules()}
  393. def pre_forward(module, module_inputs):
  394. results[submodule_to_names[module]] = (
  395. copy.deepcopy(module_inputs) if deepcopy else module_inputs
  396. )
  397. for name, mod in model.named_modules():
  398. if name in target_submodules:
  399. if not isinstance(mod, torch.jit.ScriptModule):
  400. handles.append(mod.register_forward_pre_hook(pre_forward))
  401. def clean_up_handles():
  402. for h in handles:
  403. h.remove()
  404. try:
  405. with torch.no_grad():
  406. model(*inputs)
  407. except Exception as e:
  408. clean_up_handles()
  409. raise e
  410. clean_up_handles()
  411. return results
class _SplitterBase:
    """
    Splits a GraphModule into sub-GraphModules for execution on CPU or the accelerator.
    Output is a GraphModule with supported and unsupported operators grouped into as few sub-GraphModules as possible.
    Assumes that only "call_module", "call_function" and "call_method" from FX IR can potentially be executed on the accelerator.

    Given the following graph:
          ==> b ==>
        //         \\
       a             d
        \\         //
          ==> c ==>

    class SimpleModule(torch.nn.Module):
        def forward(self, a):
            b = torch.sin(a)
            c = torch.cos(a)
            d = b + c
            return d

    and providing "operator_support" that indicates that 'b' and 'c' can be executed on the accelerator,
    we will get the following split result:

    main:
    def forward(self, a):
        run_on_acc_0_0 = self._run_on_acc_0_0(a)
        getitem = run_on_acc_0_0[0]
        getitem_1 = run_on_acc_0_0[1]
        run_on_cpu_1_1 = self._run_on_cpu_1_1(getitem, getitem_1)
        return run_on_cpu_1_1

    _run_on_acc_0_0:
    def forward(self, a):
        sin_1 = torch.sin(a)
        cos_1 = torch.cos(a)
        return (sin_1, cos_1)

    _run_on_cpu_1_1:
    def forward(self, sin_1, cos_1):
        add_1 = sin_1 + cos_1
        return add_1
    """

    # PCIe bandwidth for the backend in bytes per second: 100 * 2**30, i.e.
    # 100 GiB/s by default. split_preview divides it by a submodule's I/O size
    # in bytes to estimate a PCIe-bound max QPS.
    PCIe_BW = 100 * 2**30
  450. def __init__(
  451. self,
  452. module: torch.fx.GraphModule,
  453. sample_input: Sequence[Any],
  454. operator_support: OperatorSupportBase,
  455. settings: _SplitterSettingBase,
  456. non_acc_submodule_name: str = "_run_on_cpu_",
  457. return_tuple: bool = False,
  458. nodes_finder: Optional[FxNetAccNodesFinder] = None,
  459. ):
  460. """
  461. Preprocesses graph before splitting:
  462. - finds nodes supported by ACC,
  463. - finds fusion groups for ACC nodes having non-tensor IO,
  464. - builds a graph of direct dependencies,
  465. - builds a map of fused nodes to their fusions.
  466. As a result we get self.acc_nodes, self.deps and self.fusions.
  467. """
  468. if not isinstance(module, torch.fx.GraphModule):
  469. raise AssertionError(f"Expected GraphModule, got {type(module)}")
  470. self.module = module
  471. ShapeProp(self.module).propagate(*sample_input)
  472. self.settings = settings
  473. self.operator_support = operator_support
  474. self.sample_input = sample_input
  475. if nodes_finder is None:
  476. nodes_finder = FxNetAccNodesFinder(
  477. self.module, self.operator_support, self.settings.allow_non_tensor
  478. )
  479. self.acc_nodes = nodes_finder()
  480. if self.settings.skip_fusion:
  481. self.fusions = {}
  482. else:
  483. self.fusions = FxNetAccFusionsFinder(module, self.acc_nodes)()
  484. # Modify deps to add more deps for fused nodes
  485. self.deps = self.find_deps()
  486. self.update_deps_for_fusions()
  487. self.non_acc_submodule_name = non_acc_submodule_name
  488. self._node_submodule_map: dict[str, str] = {}
  489. self._return_tuple = return_tuple
  490. self.tags: list[str] = []
  491. # ===============================================================
  492. # Helpers for ctor and initial state
  493. # ===============================================================
  494. def get_node_submodule_map(self) -> dict[str, str]:
  495. """Returns a map from node name to submodule name, e.g.
  496. node: main_module_impl_impl_over_arch_unary_multiple_embedding
  497. _pooling_embedding_pooling_sparse_entity_equivalence_key
  498. _proxy_embedding_bag
  499. maps to submodule name of: _run_on_acc_1
  500. """
  501. return self._node_submodule_map
  502. def find_deps(self) -> dict[torch.fx.Node, NodeSet]:
  503. """
  504. Builds a graph of node dependencies. Leaf nodes don't have any
  505. dependencies and the "output" node doesn't have nodes depending on it.
  506. Resulting graph has only direct dependencies, i.e. there are no
  507. transitive dependencies.
  508. """
  509. deps: dict[torch.fx.Node, NodeSet] = defaultdict(set)
  510. for node in self.module.graph.nodes:
  511. if node.op not in CALLABLE_NODE_OPS:
  512. continue
  513. for user in node.users:
  514. if user.op != "output":
  515. deps[user].add(node)
  516. return deps
  517. def update_deps_for_fusions(self):
  518. """
  519. Updates graph of dependencies so that:
  520. - nodes from the same fusion depend on the same set of outer nodes,
  521. - outer nodes depending on a fusion depend on all nodes in that fusion.
  522. """
  523. for node in self.fusions:
  524. fusion = self.fusions[node]
  525. for fused_neighbor in fusion:
  526. self.deps[node].update(self.deps[fused_neighbor] - fusion)
  527. for user in fused_neighbor.users:
  528. if user not in fusion:
  529. self.deps[user].add(node)
  530. # ===============================================================
  531. # Helpers for preview
  532. # ===============================================================
  533. def _lower_model_to_backend(
  534. self, mod: torch.fx.GraphModule, inputs: Tensors
  535. ) -> torch.nn.Module:
  536. """
  537. Lower the model to a backend.
  538. """
  539. return mod
  540. def _find_culprit(self, mod: torch.fx.GraphModule, inputs: Tensors) -> str:
  541. """
  542. When an error occurs during lowering or running the lowered mod, we use this
  543. function to find culprits in the `mod` that causes the error.
  544. """
  545. return "Unable to find a culprit because _find_culprit() function is not implemented."
  546. def _draw_graph_based_on_node_support(
  547. self, mod: torch.fx.GraphModule, supported_nodes: NodeList
  548. ):
  549. color_map = {
  550. "default": "AliceBlue",
  551. "supported": "chartreuse1",
  552. "unsupported": "crimson",
  553. }
  554. class CustomDrawer(FxGraphDrawer):
  555. def _get_node_style(self, node):
  556. template = super()._get_node_style(node)
  557. if node in supported_nodes:
  558. template["fillcolor"] = color_map["supported"]
  559. elif node.op in CALLABLE_NODE_OPS:
  560. template["fillcolor"] = color_map["unsupported"]
  561. else:
  562. template["fillcolor"] = color_map["default"]
  563. return template
  564. drawer = CustomDrawer(mod, "node_support", ignore_getattr=True)
  565. dot_graph = drawer.get_main_dot_graph()
  566. # pyre-fixme[16]: `pydot.Dot` has no attribute `write_raw`.
  567. dot_graph.write_raw("node_support.dot") # type: ignore[attr-defined]
  568. def node_support_preview(self, dump_graph: bool = False):
  569. submodules = dict(self.module.named_modules())
  570. supported_nodes: NodeList = []
  571. supported_node_types = defaultdict(set)
  572. unsupported_node_types = defaultdict(set)
  573. def get_dtype(arg):
  574. tensor_meta = arg.meta.get("tensor_meta")
  575. return getattr(tensor_meta, "dtype", None)
  576. for node in self.module.graph.nodes:
  577. if node.op not in CALLABLE_NODE_OPS:
  578. continue
  579. target = get_node_target(submodules, node)
  580. # Store dtype of arg in node.args. If arg doesn't have dtype, i.e. not a tensor, we'll store None.
  581. arg_dtypes = [
  582. get_dtype(arg) if isinstance(arg, torch.fx.Node) else None
  583. for arg in node.args
  584. ]
  585. # Find last non-None element. If all elements are None, return max_len.
  586. last_index = len(arg_dtypes) - next(
  587. (
  588. i
  589. for i, dtype in enumerate(reversed(arg_dtypes))
  590. if dtype is not None
  591. ),
  592. len(arg_dtypes),
  593. )
  594. # Strip None elements at the end.
  595. arg_dtypes_tuple = tuple(arg_dtypes[:last_index])
  596. kwarg_dtypes_tuple = tuple(
  597. (k, get_dtype(arg))
  598. for k, arg in node.kwargs.items()
  599. if isinstance(arg, torch.fx.Node)
  600. )
  601. if self.operator_support.is_node_supported(submodules, node):
  602. supported_nodes.append(node)
  603. supported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple))
  604. else:
  605. unsupported_node_types[target].add(
  606. (arg_dtypes_tuple, kwarg_dtypes_tuple)
  607. )
  608. if dump_graph:
  609. self._draw_graph_based_on_node_support(self.module, supported_nodes)
  610. reports = "\nSupported node types in the model:\n"
  611. for t, dtypes in supported_node_types.items():
  612. for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes:
  613. reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n"
  614. reports += "\nUnsupported node types in the model:\n"
  615. for t, dtypes in unsupported_node_types.items():
  616. for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes:
  617. reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n"
  618. print(reports)
  619. # Return reports for testing purpose
  620. return reports
  621. def split_preview(self, dump_graph: bool = False):
  622. reports = ""
  623. subgraphs = self.put_nodes_into_subgraphs()
  624. acc_subgraphs_num = len([g for g in subgraphs if g.is_acc])
  625. cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num
  626. reports += f"Before removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:"
  627. reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"
  628. subgraphs = self.remove_small_acc_subgraphs(subgraphs)
  629. acc_subgraphs_num = len([g for g in subgraphs if g.is_acc])
  630. cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num
  631. reports += f"After removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:"
  632. reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"
  633. for i, subgraph in enumerate(subgraphs):
  634. reports += (
  635. f"_run_on_acc_{i}: "
  636. if subgraph.is_acc
  637. else f"{self.non_acc_submodule_name}{i}: "
  638. )
  639. reports += f"{len(subgraph.nodes)} node(s)\n"
  640. self.tag(subgraphs)
  641. split_mod = self.split(remove_tag=True)
  642. split_mod.eval()
  643. if dump_graph:
  644. drawer = FxGraphDrawer(split_mod, "preview", ignore_getattr=True)
  645. dot_graphs = drawer.get_all_dot_graphs()
  646. for name, dot_graph in dot_graphs.items():
  647. # pyre-fixme[16]: `pydot.Dot` has no attribute `write_raw`.
  648. dot_graph.write_raw(f"{name}.dot") # type: ignore[attr-defined]
  649. max_qps: float = self.PCIe_BW
  650. bottleneck_module = ""
  651. for node in split_mod.graph.nodes:
  652. if node.op == "call_module" and "acc" in node.target:
  653. reports += f"\nProcessing acc submodule {node.target}\n"
  654. submod = getattr(split_mod, node.target)
  655. def get_submod_inputs(main_mod, submod, example_inputs):
  656. sub_inputs = None
  657. def get_inputs(self, inputs):
  658. nonlocal sub_inputs
  659. sub_inputs = inputs
  660. handle = submod.register_forward_pre_hook(get_inputs)
  661. main_mod(*example_inputs)
  662. handle.remove()
  663. return sub_inputs
  664. submod_inputs = get_submod_inputs(split_mod, submod, self.sample_input)
  665. ShapeProp(submod).propagate(*submod_inputs)
  666. total_input_bytes = 0
  667. total_output_bytes = 0
  668. reports += "Checking inputs...\n"
  669. for n in submod.graph.nodes:
  670. if n.op == "placeholder":
  671. if not is_node_output_tensor(n):
  672. reports += f"Input {n.name} is not a tensor, this might cause problems during lowering!\n"
  673. else:
  674. total_input_bytes += get_size_of_node(submod, n)[0]
  675. if n.op == "output":
  676. output_node = n
  677. reports += "Checking outputs...\n"
  678. def get_bytes(node: torch.fx.Node):
  679. nonlocal total_output_bytes
  680. nonlocal reports
  681. if not is_node_output_tensor(node):
  682. reports += f"Output {node.name} is not a tensor, this might cause problems during lowering!\n"
  683. else:
  684. total_output_bytes += get_size_of_node(submod, node)[0]
  685. map_arg(output_node.args, get_bytes) # type: ignore[possibly-undefined]
  686. qps = self.PCIe_BW / max(total_input_bytes, total_output_bytes)
  687. reports += f"Total input size in bytes is {total_input_bytes}, total output size in bytes is {total_output_bytes},"
  688. reports += f" theoretical max qps (bounds by PCIe bandwidth) for this submodule is {qps}.\n"
  689. if qps < max_qps:
  690. max_qps = qps
  691. bottleneck_module = node.target
  692. try:
  693. lowered_submod = self._lower_model_to_backend(submod, submod_inputs)
  694. except RuntimeError:
  695. reports += "Run into an error during lowering!\n"
  696. reports += self._find_culprit(submod, submod_inputs)
  697. continue
  698. try:
  699. lowered_submod(*submod_inputs)
  700. except RuntimeError:
  701. reports += "Run into an error during inference!\n"
  702. reports += self._find_culprit(submod, submod_inputs)
  703. else:
  704. reports += "Lowering and running succeed!\n"
  705. reports += f"\nTheoretical max qps (bounds by PCIe bandwidth) for this model is {max_qps},"
  706. reports += f" bottleneck is submodule {bottleneck_module}."
  707. print(reports)
  708. # return the reports for testing purposes
  709. return reports
  710. # ===============================================================
  711. # Helpers for extend_acc_subgraph() method
  712. # ===============================================================
  713. def find_reverse_deps(
  714. self, tag_id: Optional[int] = None
  715. ) -> dict[torch.fx.Node, NodeSet]:
  716. """
  717. Builds reversed topological node dependencies, if tag_id is specified,
  718. we ignore nodes that are in later subgraph i.e. nodes have greater tag_id.
  719. """
  720. result: dict[torch.fx.Node, NodeSet] = defaultdict(set)
  721. for node in self.module.graph.nodes:
  722. if node.op not in CALLABLE_NODE_OPS:
  723. continue
  724. for user in node.users:
  725. if user.op not in CALLABLE_NODE_OPS:
  726. continue
  727. if tag_id is None or (int(user.tag.split("_")[-1]) < tag_id):
  728. result[node].add(user)
  729. return result
  730. def update_reverse_deps_for_fusions(self, deps: dict[torch.fx.Node, NodeSet]):
  731. processed_node = set()
  732. for node, fusion in self.fusions.items():
  733. if node in processed_node:
  734. continue
  735. new_dep = set()
  736. # Create a new dependency set which include all the
  737. # dependencies of the nodes in the fusion group
  738. for n in fusion:
  739. new_dep.update(deps[n])
  740. # Exclude nodes in the fusion
  741. new_dep.difference_update(fusion)
  742. # Update dependency
  743. for n in fusion:
  744. deps[n] = new_dep
  745. for arg in n.all_input_nodes:
  746. if arg not in fusion:
  747. deps[arg].update(fusion)
  748. processed_node.add(n)
  749. def find_parent_nodes_of_subgraph(self, tag: str) -> NodeSet:
  750. """
  751. Finds parent nodes of the `tag` subgraph.
  752. Traverse the inputs of nodes in the subgraph, if input doesn't belong to the subgraph
  753. and is not a placeholder, we consider it as the parent node of the subgraph.
  754. """
  755. parent_nodes = set()
  756. for node in self.module.graph.nodes:
  757. if node.op in CALLABLE_NODE_OPS and node.tag == tag:
  758. for arg in node.all_input_nodes:
  759. if arg.op in CALLABLE_NODE_OPS and arg.tag != tag:
  760. parent_nodes.add(arg)
  761. return parent_nodes
  762. def extend_acc_subgraph(self, tag: str):
  763. """
  764. Extend the acc subgraph with `tag` going the reversed topological direction.
  765. """
  766. # Dict that maps node to its users and ignore users that
  767. # are in the subgraph that has greater tag
  768. deps = self.find_reverse_deps(tag_id=int(tag.rsplit("_", maxsplit=1)[-1]))
  769. self.update_reverse_deps_for_fusions(deps)
  770. # Parent nodes of the subgraph
  771. parent_nodes = self.find_parent_nodes_of_subgraph(tag)
  772. visited_nodes: NodeSet = set()
  773. while parent_nodes:
  774. node = None
  775. # Find a acc node that depends on visited nodes only
  776. for n in parent_nodes:
  777. if deps[n] <= visited_nodes and n in self.acc_nodes:
  778. node = n
  779. break
  780. if node is None:
  781. break
  782. # Put the node into `tag` subgraph
  783. node.tag = tag # type: ignore[attr-defined]
  784. parent_nodes.remove(node)
  785. visited_nodes.add(node)
  786. # If node is in a fusion group, add all fusion buddies to parent nodes
  787. if node in self.fusions:
  788. for fusion_node in self.fusions[node]:
  789. if fusion_node not in visited_nodes:
  790. parent_nodes.add(fusion_node)
  791. # Add inputs of the node to parent nodes
  792. for arg in node.all_input_nodes:
  793. if arg.op in CALLABLE_NODE_OPS and arg not in visited_nodes:
  794. parent_nodes.add(arg)
  795. # ===============================================================
  796. # Helpers for split() method
  797. # ===============================================================
  798. def starter_nodes(self) -> tuple[NodeSet, NodeSet]:
  799. """
  800. Finds nodes that consume module inputs or get_attr nodes.
  801. """
  802. starter_cpu_nodes: NodeSet = set()
  803. starter_acc_nodes: NodeSet = set()
  804. for node in self.module.graph.nodes:
  805. # edge case, call_function, but with no dependencies
  806. if node.op == "call_function" and len(node.all_input_nodes) == 0:
  807. if node in self.acc_nodes:
  808. starter_acc_nodes.add(node)
  809. else:
  810. starter_cpu_nodes.add(node)
  811. if node.op not in {"placeholder", "get_attr"}:
  812. continue
  813. for user in node.users:
  814. if user in self.acc_nodes:
  815. starter_acc_nodes.add(user)
  816. else:
  817. starter_cpu_nodes.add(user)
  818. return starter_cpu_nodes, starter_acc_nodes
def put_nodes_into_subgraphs(self) -> list[Subgraph]:
    """
    Greedily partition callable nodes into alternating acc / non-acc subgraphs.

    The traversal starts from ``self.starter_nodes()``. While in one mode
    (acc or non-acc) it keeps absorbing any queued node of that kind whose
    dependencies (``self.deps[n]``) have all been visited. When no such node
    is left, it closes the current subgraph and flips the mode.

    Returns:
        The subgraphs in creation order; consecutive subgraphs have opposite
        ``is_acc``.

    Raises:
        FxNetSplitterInternalError: "Subgraph can't be empty" when no queued
            node of the current kind is ready and the current subgraph has
            no nodes yet. "Couldn't create subgraphs" when the traversal
            produced nothing.
    """
    # Start from the nodes that consume graph inputs / get_attr values (and
    # dependency-free call nodes); see starter_nodes().
    current_cpu_nodes, current_acc_nodes = self.starter_nodes()
    visited_nodes: NodeSet = set()
    # Determine which subgraph to start from based on which subgraph has
    # 0-dep node: start in acc mode unless some non-acc starter has no deps.
    acc_subgraph: bool = not any(len(self.deps[n]) == 0 for n in current_cpu_nodes)
    current_subgraph_nodes: NodeList = []
    # Result accumulator
    subgraphs: list[Subgraph] = []
    while current_cpu_nodes or current_acc_nodes:
        # Find a node that should belong to the current subgraph and has all
        # dependencies resolved. current_nodes is a set, so "first" means set
        # iteration order, not graph order.
        current_nodes = current_acc_nodes if acc_subgraph else current_cpu_nodes
        node = next(
            (n for n in current_nodes if self.deps[n] <= visited_nodes),
            None,
        )
        # If nothing was found, then it's time to flip the mode and start a new subgraph
        if node is None:
            if not current_subgraph_nodes:
                raise FxNetSplitterInternalError("Subgraph can't be empty")
            subgraphs.append(
                Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes)
            )
            acc_subgraph = not acc_subgraph
            current_subgraph_nodes = []
            continue
        current_nodes.remove(node)
        visited_nodes.add(node)
        current_subgraph_nodes.append(node)
        # Add fusion buddies. They are queued under `node`'s own kind
        # (acc / non-acc); the buddies' own acc membership is not checked.
        if node in self.fusions:
            if node in self.acc_nodes:
                current_acc_nodes.update(self.fusions[node] - visited_nodes)
            else:
                current_cpu_nodes.update(self.fusions[node] - visited_nodes)
        # Put depending nodes into the queue
        for user in node.users:
            if user.op not in CALLABLE_NODE_OPS:
                continue
            # Add downstream nodes
            if user in self.acc_nodes:
                current_acc_nodes.add(user)
            else:
                current_cpu_nodes.add(user)
    # Check if the last subgraph was not created
    if current_subgraph_nodes:
        subgraphs.append(
            Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes)
        )
    if not subgraphs:
        raise FxNetSplitterInternalError("Couldn't create subgraphs")
    return subgraphs
  872. def remove_small_acc_subgraphs(self, subgraphs: list[Subgraph]) -> list[Subgraph]:
  873. """
  874. This pass finds ACC submodules with less than specified size and merges
  875. them with adjacent CPU submodules.
  876. """
  877. result: list[Subgraph] = []
  878. for subgraph in subgraphs:
  879. if subgraph.is_acc:
  880. if len(subgraph.nodes) >= self.settings.min_acc_module_size:
  881. result.append(subgraph)
  882. else:
  883. print(
  884. "Eliminating acc subgraph because it's smaller than the threshold: "
  885. f"{len(subgraph.nodes)} < {self.settings.min_acc_module_size}"
  886. )
  887. if result:
  888. result[-1].nodes.extend(subgraph.nodes)
  889. else:
  890. subgraph.is_acc = False
  891. result.append(subgraph)
  892. else:
  893. if result and not result[-1].is_acc:
  894. result[-1].nodes.extend(subgraph.nodes)
  895. else:
  896. result.append(subgraph)
  897. return result
  898. def tag(self, subgraphs: list[Subgraph]):
  899. self.tags = []
  900. for subgraph in subgraphs:
  901. tag = (
  902. f"_run_on_acc_{len(self.tags)}"
  903. if subgraph.is_acc
  904. else f"{self.non_acc_submodule_name}{len(self.tags)}"
  905. )
  906. self.tags.append(tag)
  907. for node in subgraph.nodes:
  908. if hasattr(node, "tag"):
  909. raise FxNetSplitterInternalError(f"Node {node} was already tagged")
  910. node.tag = tag # type: ignore[attr-defined]
  911. self._node_submodule_map[node.name] = tag
  912. def split(self, remove_tag: bool = False) -> torch.fx.GraphModule:
  913. split_module = split_by_tags(
  914. self.module, self.tags, return_tuple=self._return_tuple
  915. )
  916. if remove_tag:
  917. for node in self.module.graph.nodes:
  918. if hasattr(node, "tag"):
  919. del node.tag
  920. return split_module # type: ignore[return-value]
  921. def __call__(self) -> torch.fx.GraphModule:
  922. subgraphs = self.put_nodes_into_subgraphs()
  923. if self.settings.move_non_tensor_nodes_on_boundary:
  924. move_non_tensor_nodes_on_boundary(subgraphs)
  925. subgraphs = self.remove_small_acc_subgraphs(subgraphs)
  926. acc_subgraphs_count = len([s for s in subgraphs if s.is_acc])
  927. non_acc_subgraphs_count = len(subgraphs) - acc_subgraphs_count
  928. print(
  929. f"Got {acc_subgraphs_count} acc subgraphs and {non_acc_subgraphs_count} non-acc subgraphs"
  930. )
  931. self.tag(subgraphs)
  932. return self.split()
  933. def generate_split_results(self) -> SplitResult:
  934. split_module = self()
  935. submodule_names = []
  936. for name, _mod in split_module.named_children():
  937. submodule_names.append(name)
  938. if (
  939. self.settings.max_acc_splits > 0
  940. and len(submodule_names) > self.settings.max_acc_splits
  941. ):
  942. raise ValueError(
  943. "Cannot fulfill max_acc_splits limit. "
  944. "This may cause split fragmentation and "
  945. "result in performance issues."
  946. )
  947. submodule_inputs = generate_inputs_for_submodules(
  948. split_module, self.sample_input, submodule_names
  949. )
  950. return SplitResult(split_module, submodule_inputs, self.non_acc_submodule_name)