split_module.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682
  1. # mypy: allow-untyped-defs
  2. import inspect
  3. import logging
  4. from collections import OrderedDict
  5. from collections.abc import Callable
  6. from typing import Any, Optional
  7. import torch
  8. from torch.fx._compatibility import compatibility
  9. from torch.fx._utils import lazy_format_graph_code
  10. from torch.fx.graph_module import GraphModule
  11. from torch.fx.node import Node
# Names exported by `from ... import *`.
__all__ = ["Partition", "split_module"]

# Module-level logger; `log` and `_LOGGER` are two names for the same logger object.
log = _LOGGER = logging.getLogger(__name__)
  14. @compatibility(is_backward_compatible=True)
  15. class Partition:
  16. def __init__(self, name: str):
  17. self.name: str = name
  18. self.submod_name = f"submod_{name}"
  19. self.node_names: list[str] = []
  20. self.inputs: dict[str, None] = {}
  21. self.outputs: dict[str, None] = {}
  22. self.dependencies: dict[str, None] = {}
  23. self.dependents: dict[str, None] = {}
  24. self.graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
  25. self.environment: dict[Node, Node] = {}
  26. self.targets: dict[str, Any] = {}
  27. def __repr__(self) -> str:
  28. return (
  29. f"name: {self.name},\n"
  30. f" nodes: {self.node_names},\n"
  31. f" inputs: {self.inputs},\n"
  32. f" outputs: {self.outputs},\n"
  33. f" partitions depended on: {self.dependencies},\n"
  34. f" partition dependents: {self.dependents}"
  35. )
  36. def _get_attr_from_qualname(mod: torch.nn.Module, qualname: str) -> Any:
  37. attr_val = mod
  38. for atom in qualname.split("."): # type: ignore[union-attr]
  39. if not hasattr(attr_val, atom):
  40. raise AttributeError(f"Node target {qualname} not found!")
  41. attr_val = getattr(attr_val, atom)
  42. return attr_val
  43. # Creates subgraphs out of main graph
  44. @compatibility(is_backward_compatible=True)
  45. def split_module(
  46. m: GraphModule,
  47. root_m: torch.nn.Module,
  48. split_callback: Callable[[Node], int],
  49. qualname_map: Optional[dict[str, str]] = None,
  50. keep_original_order: Optional[bool] = False,
  51. keep_original_node_name: Optional[bool] = False,
  52. keep_original_input_name: bool = True,
  53. *,
  54. partition_affix: Optional[str] = None,
  55. ):
  56. """
  57. Creates subgraphs out of main graph
  58. Args:
  59. m (GraphModule): Graph module to split
  60. root_m (torch.nn.Module): root nn module. Not currently used. Included
  61. because the root nn module is usually transformed via
  62. torch.fx._symbolic_trace.symbolic_trace (see example below)
  63. split_callback (Callable[[Node], int]): Callable function
  64. that maps a given Node instance to a numeric partition identifier.
  65. split_module will use this function as the policy for which operations
  66. appear in which partitions in the output Module.
  67. qualname_map: Optional[Dict[str, str]]: optional output parameter that returns a
  68. mapping from new target names in the module after split to old target
  69. names in the original module.
  70. keep_original_order: Optional[bool]: keep the original order of the GraphModule
  71. or use the Topological order of the new constructed GraphModule
  72. keep_original_node_name: Optional[bool]: If the partitioned graphs should
  73. have the same node names as the original graph.
  74. keep_original_input_name: bool: If the partitioned graphs should
  75. have the same input names as the original graph.
  76. partition_affix: Optional[str]: If specified, the submodules' names will contain
  77. the affix, e.g. "submod_<affix>_<idx>".
  78. Returns:
  79. GraphModule: the module after split.
  80. Example:
  81. This is a sample setup:
  82. import torch
  83. from torch.fx._symbolic_trace import symbolic_trace
  84. from torch.fx.graph_module import GraphModule
  85. from torch.fx.node import Node
  86. from torch.fx.passes.split_module import split_module
  87. class MyModule(torch.nn.Module):
  88. def __init__(self) -> None:
  89. super().__init__()
  90. self.param = torch.nn.Parameter(torch.rand(3, 4))
  91. self.linear = torch.nn.Linear(4, 5)
  92. def forward(self, x, y):
  93. z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
  94. w = self.linear(y).clamp(min=0.0, max=1.0)
  95. return z + w
  96. # symbolically trace model
  97. my_module = MyModule()
  98. my_module_traced = symbolic_trace(my_module)
  99. # random mod partitioning
  100. partition_counter = 0
  101. NPARTITIONS = 3
  102. def mod_partition(node: Node):
  103. global partition_counter
  104. partition = partition_counter % NPARTITIONS
  105. partition_counter = (partition_counter + 1) % NPARTITIONS
  106. return partition
  107. # split module in module with submodules
  108. module_with_submodules = split_module(
  109. my_module_traced, my_module, mod_partition
  110. )
  111. Output looks like this. Original graph is broken into partitions
  112. > print(module_with_submodules)
  113. GraphModule(
  114. (submod_0): GraphModule(
  115. (linear): Linear(in_features=4, out_features=5, bias=True)
  116. )
  117. (submod_1): GraphModule(
  118. (linear): Linear(in_features=4, out_features=5, bias=True)
  119. )
  120. (submod_2): GraphModule()
  121. )
  122. def forward(self, x, y):
  123. param = self.param
  124. submod_0 = self.submod_0(x, param, y); x = param = y = None
  125. getitem = submod_0[0]
  126. getitem_1 = submod_0[1]; submod_0 = None
  127. submod_1 = self.submod_1(getitem, getitem_1); getitem = getitem_1 = None
  128. getitem_2 = submod_1[0]
  129. getitem_3 = submod_1[1]; submod_1 = None
  130. submod_2 = self.submod_2(getitem_2, getitem_3); getitem_2 = getitem_3 = None
  131. return submod_2
  132. Output of split module is the same as output of input traced module.
  133. This is an example within a test setting:
  134. > orig_out = my_module_traced(x, y)
  135. > submodules_out = module_with_submodules(x, y)
  136. > self.assertEqual(orig_out, submodules_out)
  137. True
  138. """
  139. log.debug(
  140. "%s",
  141. lazy_format_graph_code("pre split_module", m, colored=True),
  142. )
  143. def construct_graph(
  144. node: Node,
  145. base_mod_env: dict[str, Node],
  146. base_mod_attrs: dict[str, torch.fx.graph_module.GraphModule],
  147. ):
  148. if node.op == "placeholder":
  149. default_value = (
  150. node.args[0] if len(node.args) > 0 else inspect.Signature.empty
  151. )
  152. if keep_original_node_name:
  153. args = (
  154. () if default_value is inspect.Signature.empty else (default_value,)
  155. )
  156. base_mod_env[node.name] = base_mod_graph.create_node(
  157. "placeholder",
  158. node.name,
  159. args=args, # type: ignore[arg-type]
  160. type_expr=node.type,
  161. )
  162. else:
  163. base_mod_env[node.name] = base_mod_graph.placeholder(
  164. node.target, # type: ignore[arg-type]
  165. type_expr=node.type,
  166. default_value=default_value,
  167. )
  168. base_mod_env[node.name].meta = node.meta.copy()
  169. elif node.op == "get_attr":
  170. base_mod_env[node.name] = base_mod_graph.get_attr(node.target) # type: ignore[arg-type]
  171. base_mod_env[node.name].meta = node.meta.copy()
  172. if not isinstance(node.target, str):
  173. raise AssertionError(f"Expected str target, got {type(node.target)}")
  174. attr_val = _get_attr_from_qualname(m, node.target)
  175. base_mod_attrs[node.target] = attr_val # type: ignore[index]
  176. return base_mod_env, base_mod_attrs
  177. import sympy
  178. partitions: dict[str, Partition] = {}
  179. orig_nodes: dict[str, Node] = {}
  180. symbol_to_node: dict[sympy.Symbol, Node] = {}
  181. def record_cross_partition_use(def_node: Node, use_node: Optional[Node]):
  182. from torch.fx.experimental.symbolic_shapes import free_symbols
  183. defined = getattr(def_node, "_fx_partition", None)
  184. used = getattr(use_node, "_fx_partition", None)
  185. log.debug(
  186. "record_cross_partition_use %s (%s) %s (%s)",
  187. def_node.name,
  188. defined,
  189. use_node.name if use_node is not None else "-",
  190. used,
  191. )
  192. if defined != used:
  193. if defined is not None:
  194. def_partition = partitions[defined]
  195. def_partition.outputs.setdefault(def_node.name)
  196. if used is not None:
  197. def_partition.dependents.setdefault(used)
  198. if used is not None:
  199. use_partition = partitions[used]
  200. use_partition.inputs.setdefault(def_node.name)
  201. # We have made def_node an input to the use_partition. If
  202. # this input has symbolic symbols in its size, those also must
  203. # be made as inputs to the partition
  204. if (def_val := def_node.meta.get("example_value")) is not None:
  205. for s in sorted(free_symbols(def_val), key=str):
  206. s_node = symbol_to_node[s]
  207. use_partition.inputs.setdefault(s_node.name)
  208. if symbol_to_node[s].op != "placeholder":
  209. # If the node that defines the symbol is not a
  210. # placeholder, we must make it an output of the
  211. # partition. Note that this may be in a different
  212. # partition than defined! Although, this doesn't
  213. # really make a difference for correctness, since
  214. # defined is guaranteed to have the symbol in
  215. # scope and can return it; you just get less
  216. # optimal codegen in this case.
  217. s_defined = getattr(s_node, "_fx_partition", None)
  218. if s_defined is not None:
  219. s_def_partition = partitions[s_defined]
  220. s_def_partition.outputs.setdefault(s_node.name)
  221. s_def_partition.dependents.setdefault(used)
  222. use_partition.dependencies.setdefault(s_defined)
  223. if defined is not None:
  224. use_partition.dependencies.setdefault(defined)
  225. def instantiate_node_partition_mapping(node):
  226. partition_idx = split_callback(node)
  227. partition_name = str(partition_idx)
  228. if partition_affix is not None:
  229. # For example, if user specifies partition_affix = "pp", then the
  230. # partition name will be "pp_0", "pp_1", etc
  231. partition_name = "_".join([partition_affix, partition_name])
  232. log.debug(
  233. "instantiate_node_partition_mapping %s (%s)", node.name, partition_name
  234. )
  235. # add node to partitions
  236. partition = partitions.get(partition_name)
  237. if partition is None:
  238. partitions[partition_name] = partition = Partition(partition_name)
  239. partition.node_names.append(node.name)
  240. node._fx_partition = partition_name
  241. # Global State Nodes are nodes which by their global state effects,
  242. # "taint" all downstream nodes while they are active.
  243. GLOBAL_STATE_NODES = [
  244. torch.amp._enter_autocast,
  245. torch.amp._exit_autocast,
  246. torch._C._set_grad_enabled,
  247. ]
  248. # For grad regions:
  249. # ------------------------
  250. # 1. first region: we do nothing
  251. # 2. subsequent regions: we insert the set_grad at the beginning
  252. grad_regions: OrderedDict[Node, set[int]] = OrderedDict()
  253. # For autocast regions:
  254. # ------------------------
  255. # 1. first region: we will only insert the _exit at the end
  256. # 2. intermediate regions: we will insert both the
  257. # _enter at the beginning and _exit at the end
  258. # 3. last region: we will only insert _enter at the beginning
  259. # We will do so in the order in which the autocasts were instantiated.
  260. autocast_regions: OrderedDict[Node, set[int]] = OrderedDict()
  261. autocast_exits: dict[Node, Optional[Node]] = {}
  262. active_grad = None
  263. active_autocasts = set()
  264. for node in m.graph.nodes:
  265. # This will prefer placeholder bindings, because those come first.
  266. # This is a little dangerous though: it is possible that an unbacked
  267. # symbol is used without any binding site for it, in which case we
  268. # will get a KeyError not able to find it. I'd like to fix this by
  269. # having passes.runtime_assert establish some invariants that I can
  270. # rely on later, but this needs some extra work. Quick fix first.
  271. # See https://github.com/pytorch/pytorch/issues/130534
  272. if (
  273. (val := node.meta.get("example_value")) is not None
  274. and isinstance(val, (torch.SymInt, torch.SymFloat))
  275. and isinstance(s0 := val.node.expr, sympy.Symbol)
  276. and s0 not in symbol_to_node
  277. ):
  278. symbol_to_node[val.node.expr] = node
  279. if node.op in ["placeholder", "get_attr", "output"]:
  280. continue
  281. instantiate_node_partition_mapping(node)
  282. if node.op == "call_function" and node.target in GLOBAL_STATE_NODES:
  283. if node.target is torch._C._set_grad_enabled:
  284. if len(node.args) != 1:
  285. raise AssertionError(
  286. f"Expected 1 arg for _set_grad_enabled, got {len(node.args)}"
  287. )
  288. if not isinstance(node.args[0], bool):
  289. raise AssertionError(f"Expected bool arg, got {type(node.args[0])}")
  290. active_grad = node
  291. grad_regions[active_grad] = set({split_callback(node)})
  292. elif node.target is torch.amp._enter_autocast:
  293. # Should all be python constants
  294. if not all(not isinstance(arg, Node) for arg in node.args):
  295. raise AssertionError(
  296. "Expected all args to be python constants, not Nodes"
  297. )
  298. active_autocasts.add(node)
  299. autocast_regions[node] = set({split_callback(node)})
  300. autocast_exits[node] = None
  301. elif node.target is torch.amp._exit_autocast:
  302. if len(node.args) != 1:
  303. raise AssertionError(
  304. f"Expected 1 arg for _exit_autocast, got {len(node.args)}"
  305. )
  306. autocast_regions[node.args[0]].add(split_callback(node))
  307. active_autocasts.remove(node.args[0])
  308. autocast_exits[node.args[0]] = node
  309. if active_grad is not None:
  310. grad_regions[active_grad].add(split_callback(node))
  311. for a in active_autocasts:
  312. autocast_regions[a].add(split_callback(node))
  313. if not all(v is not None for v in autocast_exits.values()):
  314. raise AssertionError("autocast must exit")
  315. # pyrefly: ignore [bad-assignment]
  316. autocast_regions = {k: sorted(v) for k, v in autocast_regions.items()}
  317. # pyrefly: ignore [bad-assignment]
  318. grad_regions = {k: sorted(v) for k, v in grad_regions.items()}
  319. if _LOGGER.isEnabledFor(logging.DEBUG):
  320. _LOGGER.debug("autocast_regions: %s", autocast_regions)
  321. _LOGGER.debug("grad_regions: %s", grad_regions)
  322. assert_monotonically_increasing = bool(autocast_regions) or bool(grad_regions)
  323. # split nodes into partitions
  324. highest_partition = -1
  325. for node in m.graph.nodes:
  326. orig_nodes[node.name] = node
  327. # TODO currently placeholders/parameters aren't put into random partitions,
  328. # rather they're added to the graphs where they are used down below
  329. if node.op in ["placeholder", "get_attr"]:
  330. continue
  331. if node.op == "output":
  332. torch.fx.graph.map_arg(
  333. node.args[0], lambda n: record_cross_partition_use(n, None)
  334. )
  335. continue
  336. if assert_monotonically_increasing:
  337. pid = split_callback(node)
  338. if highest_partition > pid:
  339. raise AssertionError(
  340. "autocast or set_grad_enabled require monotonically increasing "
  341. f"partitions: highest: {highest_partition}, this node's: {pid}"
  342. )
  343. highest_partition = pid
  344. # do not capture cross-partition dependencies for global state nodes as they will be
  345. # self-contained - their setup and unwind will be isolated to each partition submodule.
  346. if node.target not in GLOBAL_STATE_NODES:
  347. torch.fx.graph.map_arg(
  348. node.args, lambda def_node: record_cross_partition_use(def_node, node)
  349. )
  350. torch.fx.graph.map_arg(
  351. node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)
  352. ) # noqa: B950
  353. original_partition_order = list(partitions.keys())
  354. # find partitions with no dependencies
  355. root_partitions: list[str] = []
  356. for partition_name, partition in partitions.items():
  357. if not len(partition.dependencies):
  358. root_partitions.append(partition_name)
  359. # check partitions for circular dependencies and create topological partition ordering
  360. sorted_partitions: list[str] = []
  361. while root_partitions:
  362. root_partition = root_partitions.pop()
  363. sorted_partitions.append(root_partition)
  364. for dependent in partitions[root_partition].dependents:
  365. partitions[dependent].dependencies.pop(root_partition) # noqa: B909
  366. if not partitions[dependent].dependencies:
  367. root_partitions.append(dependent)
  368. if len(sorted_partitions) != len(partitions):
  369. raise RuntimeError("cycle exists between partitions!")
  370. # Enter prelude
  371. for regions_mapping in [autocast_regions, grad_regions]:
  372. for node, regions in regions_mapping.items():
  373. if len(regions) == 0:
  374. raise AssertionError("Expected at least one region for node")
  375. # pyrefly: ignore [bad-index]
  376. partitions[str(regions[0])].environment[node] = node
  377. # pyrefly: ignore [bad-index, index-error]
  378. # pyrefly: ignore [bad-index, index-error]
  379. for r in regions[1:]:
  380. partition = partitions[str(r)]
  381. new_node = partition.graph.create_node(
  382. op=node.op,
  383. target=node.target,
  384. args=tuple(arg for arg in node.args),
  385. kwargs={},
  386. type_expr=node.type,
  387. )
  388. new_node.meta = (
  389. node.meta.copy()
  390. ) # is it really a good idea to copy this?
  391. partition.environment[node] = new_node
  392. # add placeholders to partition inputs
  393. for partition_name in sorted_partitions:
  394. partition = partitions[partition_name]
  395. new_inputs: dict[str, None] = {}
  396. counter = 0
  397. for inp in partition.inputs:
  398. orig_node = orig_nodes[inp]
  399. # We don't pass in get_attr nodes as inputs to the partition, but
  400. # instead set them as targets and use getattr within the module
  401. def add_placeholder():
  402. if keep_original_input_name:
  403. name = inp
  404. else:
  405. nonlocal counter
  406. name = f"arg_{counter}"
  407. counter += 1
  408. placeholder = partition.graph.placeholder(
  409. name,
  410. type_expr=orig_nodes[inp].type,
  411. )
  412. new_inputs[inp] = None
  413. return placeholder
  414. if orig_node.op == "get_attr":
  415. if not isinstance(orig_node.target, str):
  416. raise AssertionError(
  417. f"Expected str target, got {type(orig_node.target)}"
  418. )
  419. orig_attr = _get_attr_from_qualname(m, orig_node.target)
  420. if isinstance(orig_attr, torch.nn.Module):
  421. placeholder = partition.graph.get_attr(orig_node.target)
  422. partition.targets[orig_node.target] = orig_attr
  423. else:
  424. placeholder = add_placeholder()
  425. else:
  426. placeholder = add_placeholder()
  427. placeholder.meta = orig_nodes[inp].meta.copy()
  428. partition.environment[orig_nodes[inp]] = placeholder
  429. partition.inputs = new_inputs
  430. # Transform nodes and collect targets for partition's submodule
  431. for node in m.graph.nodes:
  432. if hasattr(node, "_fx_partition"):
  433. partition = partitions[node._fx_partition]
  434. # swap out old graph nodes in kw/args with references to new nodes in this submodule
  435. environment = partition.environment
  436. gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
  437. gathered_kwargs = torch.fx.graph.map_arg(
  438. node.kwargs, lambda n: environment[n]
  439. )
  440. if node.op not in ["call_module", "get_attr"]:
  441. target = node.target
  442. else:
  443. target_attr = _get_attr_from_qualname(m, node.target)
  444. target = node.target.replace(".", "_")
  445. partition.targets[target] = target_attr
  446. # Fill in the passed-in mapping from new qualname to old qualname
  447. if qualname_map is not None:
  448. # When creating the split module later, the submodules will have
  449. # path prefix matching the corresponding partition's submod_name
  450. qualname = f"{partition.submod_name}.{target}"
  451. qualname_map[qualname] = node.target
  452. if not isinstance(gathered_args, tuple):
  453. raise AssertionError(
  454. f"Expected tuple for gathered_args, got {type(gathered_args)}"
  455. )
  456. if not isinstance(gathered_kwargs, dict):
  457. raise AssertionError(
  458. f"Expected dict for gathered_kwargs, got {type(gathered_kwargs)}"
  459. )
  460. name = node.name if keep_original_node_name else None
  461. new_node = partition.graph.create_node(
  462. op=node.op,
  463. target=target,
  464. args=gathered_args,
  465. kwargs=gathered_kwargs,
  466. type_expr=node.type,
  467. name=name,
  468. )
  469. new_node.meta = node.meta.copy()
  470. partition.environment[node] = new_node
  471. # Exit epilogue
  472. for regions_mapping in [autocast_regions]:
  473. for node in reversed(regions_mapping):
  474. regions = regions_mapping[node]
  475. if len(regions) == 0:
  476. raise AssertionError("Expected at least one region")
  477. # pyrefly: ignore [bad-index, index-error]
  478. for r in regions[:-1]:
  479. partition = partitions[str(r)]
  480. exit_node = autocast_exits[node]
  481. if exit_node is None:
  482. raise AssertionError("Missing exit node")
  483. new_node = partition.graph.create_node(
  484. op=exit_node.op,
  485. target=exit_node.target,
  486. args=(partition.environment[node],),
  487. kwargs={},
  488. type_expr=exit_node.type,
  489. )
  490. new_node.meta = (
  491. exit_node.meta.copy()
  492. ) # is it really a good idea to copy this?
  493. # original module environment dict mapping node names to nodes
  494. orig_mod_env: dict[str, Node] = {}
  495. # Set up values to construct base module
  496. base_mod_env: dict[str, Node] = {}
  497. base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
  498. base_mod_attrs: dict[str, torch.fx.graph_module.GraphModule] = {}
  499. if not keep_original_order:
  500. for node in m.graph.nodes:
  501. base_mod_env, base_mod_attrs = construct_graph(
  502. node, base_mod_env, base_mod_attrs
  503. )
  504. else:
  505. # Go through the graph to construct the mapping dict
  506. for node in m.graph.nodes:
  507. orig_mod_env[node.name] = node
  508. # Do some things iterating over the partitions in topological order again:
  509. # 1) Finish off submodule Graphs by setting corresponding outputs
  510. # 2) Construct GraphModules for each submodule
  511. # 3) Construct the base graph by emitting calls to those submodules in
  512. # topological order or original order specified by keep_original_order
  513. construct_order_partitions = (
  514. sorted_partitions if not keep_original_order else original_partition_order
  515. )
  516. already_constructed_attr_nodes = set()
  517. # We actually need to insert the placeholder nodes in the original order
  518. # otherwise graph signature will be wrong.
  519. original_order = [node for node in m.graph.nodes if node.op == "placeholder"]
  520. for partition_name in construct_order_partitions:
  521. partition = partitions[partition_name]
  522. # Set correct output values
  523. output_vals = tuple(
  524. partition.environment[orig_nodes[name]] for name in partition.outputs
  525. )
  526. # skip output node generation if there are no output values
  527. num_output_vals = len(output_vals)
  528. if num_output_vals == 1:
  529. partition.graph.output(output_vals[0])
  530. elif num_output_vals > 1:
  531. partition.graph.output(output_vals)
  532. else:
  533. # Invariant - Graph should always have an output node.
  534. partition.graph.output(())
  535. if keep_original_order:
  536. # first get the attr nodes required by this partition
  537. orig_mod_attr_nodes: list[Node] = [
  538. orig_mod_env[key]
  539. for key in partition.inputs
  540. if key not in original_order
  541. ]
  542. for node in original_order:
  543. if node in already_constructed_attr_nodes:
  544. continue # already added this attr to the base graph
  545. base_mod_env, _based_mod_attrs = construct_graph(
  546. node, base_mod_env, base_mod_attrs
  547. )
  548. already_constructed_attr_nodes.add(node)
  549. # Construct GraphModule for this partition
  550. for node in orig_mod_attr_nodes: # type: ignore[attr-defined]
  551. if node in already_constructed_attr_nodes:
  552. continue
  553. base_mod_env, base_mod_attrs = construct_graph(
  554. node, base_mod_env, base_mod_attrs
  555. )
  556. already_constructed_attr_nodes.add(node)
  557. base_mod_attrs[partition.submod_name] = torch.fx.graph_module.GraphModule(
  558. partition.targets, partition.graph
  559. ) # noqa: B950
  560. # Emit call in base graph to this submodule
  561. output_val = base_mod_graph.call_module(
  562. partition.submod_name,
  563. tuple(base_mod_env[name] for name in partition.inputs),
  564. )
  565. num_outputs = len(partition.outputs)
  566. if num_outputs > 1:
  567. # Unpack multiple return values from submodule
  568. output_val_proxy = torch.fx.proxy.Proxy(output_val)
  569. for i, output_name in enumerate(partition.outputs):
  570. base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
  571. elif num_outputs == 1:
  572. base_mod_env[next(iter(partition.outputs))] = output_val
  573. # When keep_original_order=True and if the graph doesn't have any
  574. # `call_function` node then `base_mod_graph`, `base_mod_env` and `base_mod_attrs`
  575. # are never populated.
  576. # For this case, we call `construct_graph` here which takes care of updating them.
  577. if keep_original_order and not base_mod_env:
  578. for node in m.graph.nodes:
  579. base_mod_env, base_mod_attrs = construct_graph(
  580. node, base_mod_env, base_mod_attrs
  581. )
  582. # Add output node to `base_mod_graph` (i.e. the split graph) which will be returned.
  583. for node in m.graph.nodes:
  584. if node.op == "output":
  585. base_mod_graph.output(
  586. torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])
  587. ) # noqa: B950
  588. ret = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
  589. log.debug(
  590. "%s",
  591. lazy_format_graph_code("post split_module", ret, colored=True),
  592. )
  593. return ret