_IR.py 49 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274
  1. # mypy: allow-untyped-defs
  2. # Copyright (c) Meta Platforms, Inc. and affiliates
  3. import copy
  4. import logging
  5. import operator
  6. from collections import defaultdict
  7. from collections.abc import Callable
  8. from enum import Enum
  9. from inspect import Parameter, Signature, signature
  10. from types import MethodType
  11. from typing import Any, Union
  12. import torch
  13. import torch.fx as fx
  14. from torch.distributed import ProcessGroup
  15. from torch.export import ExportedProgram
  16. from torch.export.unflatten import (
  17. _assign_attr,
  18. _AttrKind,
  19. _sink_params,
  20. InterpreterModule,
  21. )
  22. from torch.fx.node import map_aggregate
  23. from torch.fx.passes.split_module import split_module
  24. from ._backward import _null_coalesce_accumulate, stage_backward
  25. from ._unflatten import _outline_submodules
  26. from ._utils import PipeInfo
  27. from .stage import _PipelineStage
  28. logger = logging.getLogger(__name__)
  29. # TODO:
  30. # 1. investigate gradient sync for shared parameters. how does DDP do it?
  31. # 2. Add parameter movement to split_module
  32. PP_SUBMOD_PREFIX = "submod_pp"
  33. def get_submod_name(stage_idx: int):
  34. """Returns the name of the submod for a given stage index.
  35. For example, "submod_pp_0", "submod_pp_1", etc.
  36. """
  37. return "_".join([PP_SUBMOD_PREFIX, str(stage_idx)])
  38. def _find_loss_from_output_and_spec(output_val, spec_val):
  39. if spec_val is False:
  40. return None
  41. if spec_val is True:
  42. if not isinstance(output_val, fx.Node):
  43. raise RuntimeError(
  44. f"Loss spec must specify a dynamic value but got {output_val}"
  45. )
  46. return output_val
  47. if isinstance(spec_val, (tuple, list)):
  48. if not isinstance(output_val, (tuple, list)):
  49. raise RuntimeError(
  50. f"Output value {output_val} must match type of loss specification "
  51. f"{spec_val}"
  52. )
  53. if len(output_val) != len(spec_val):
  54. raise RuntimeError(
  55. f"Output value {output_val} must match length of loss specification "
  56. f"{spec_val}"
  57. )
  58. for out, spec in zip(output_val, spec_val):
  59. loss_val = _find_loss_from_output_and_spec(out, spec)
  60. if loss_val is not None:
  61. return loss_val
  62. raise RuntimeError(f"Did not find loss value in specification {spec_val}")
  63. if isinstance(spec_val, dict):
  64. if not isinstance(output_val, dict):
  65. raise RuntimeError(
  66. f"Output value {output_val} must match type of loss specification "
  67. f"{spec_val}"
  68. )
  69. if set(output_val.keys()) != set(spec_val.keys()):
  70. raise RuntimeError(
  71. f"Output value {output_val} must match keys of loss specification "
  72. f"{spec_val}"
  73. )
  74. for k in spec_val:
  75. loss_val = _find_loss_from_output_and_spec(output_val[k], spec_val[k])
  76. if loss_val is not None:
  77. return loss_val
  78. raise RuntimeError(f"Did not find loss value in specification {spec_val}")
  79. raise RuntimeError(f"Unsupported type {type(spec_val)} in loss specification")
  80. def _find_loss_output(mod: torch.nn.Module, g: fx.Graph, output_loss_value_spec):
  81. output_nodes = [n for n in g.nodes if n.op == "output"]
  82. if not len(output_nodes) == 1:
  83. raise AssertionError(f"Expected 1 output node, got {len(output_nodes)}")
  84. output_node = output_nodes[0]
  85. output_val = output_node.args[0]
  86. generated_spec: Any = None
  87. if isinstance(mod, TrivialLossWrapper):
  88. # TrivialLossWrapper is pre-defined by PiPPy.
  89. # It has loss as the only output so we can safely assume the first output arg is the loss.
  90. if not len(output_node.args) == 1:
  91. raise AssertionError(f"Expected 1 output arg, got {len(output_node.args)}")
  92. loss_node = output_val
  93. generated_spec = TrivialLossWrapper.loss_spec
  94. elif output_loss_value_spec is None:
  95. # Use default spec, i.e. search for "loss" in output values
  96. if isinstance(output_val, dict) and "loss" in output_val:
  97. loss_node = output_val["loss"]
  98. generated_spec = {k: k == "loss" for k in output_val}
  99. else:
  100. loss_node = None
  101. generated_spec = None
  102. else:
  103. loss_node = _find_loss_from_output_and_spec(output_val, output_loss_value_spec)
  104. generated_spec = output_loss_value_spec
  105. return loss_node, output_node, generated_spec
def _insert_stage_symbolic_backward(
    g: fx.Graph,
    loss_node: fx.Node,
    output_node: fx.Node,
) -> fx.Graph:
    """Append symbolic backward computation for the pipeline stages to ``g``.

    Walks ``g`` in reverse, starting from ``loss_node``. For every
    ``call_module`` (stage) node that ``loss_node`` transitively depends on, it
    inserts a ``stage_backward`` call before ``output_node``. The gradients
    that call produces are then routed to that stage's input nodes. When
    several users contribute a gradient to the same non-placeholder value,
    the contributions are combined with ``_null_coalesce_accumulate`` nodes.

    Args:
        g: Forward-only graph, modified in place. Every ``call_function`` node
            in it must be a two-argument ``operator.getitem`` call.
        loss_node: Node holding the loss; it is seeded with a ``None`` gradient.
        output_node: The graph's output node; new nodes are inserted before it.

    Returns:
        ``g`` itself.

    Raises:
        AssertionError: if ``g`` contains a ``call_function`` node that is not
            a two-argument ``operator.getitem`` call.
    """
    # Collect metadata about tuple output values. TODO: move this to split_module or FX IR
    # `tuples` maps each indexed value to a tuple whose position i holds the
    # getitem node that reads element i (None where no getitem exists).
    tuples: dict[fx.Node, tuple] = {}
    for node in reversed(g.nodes):
        if node.op == "call_function":
            # In the forward pass, only emit placeholder, module calls, and
            # getitem calls. If we have a target other than getitem in this
            # (forward-only) code, there is a bug.
            if not node.target == operator.getitem:
                raise AssertionError(
                    "Found non-getitem call in forward pass. Please report a bug to PiPPy"
                )
            if not len(node.args) == 2:
                raise AssertionError(
                    "Found malformed getitem call. Please report a bug to PiPPy"
                )
            indexed_value, node_idx = tuple(node.args)

            # indexed_value is a collection that we are indexing into. It could
            # exist in the tuples map if we've processed another `getitem`
            # already.
            existing_list_size = (
                len(tuples[indexed_value]) if indexed_value in tuples else -1
            )
            # Grow the reconstructed tuple so that index node_idx fits.
            new_list_size = max(node_idx + 1, existing_list_size)

            reconstructed_list = [None for _ in range(new_list_size)]

            # Copy over existing elements if present
            if indexed_value in tuples:
                for i, val in enumerate(tuples[indexed_value]):
                    reconstructed_list[i] = val

            # Populate value represented by this node
            reconstructed_list[node_idx] = node

            tuples[indexed_value] = tuple(reconstructed_list)

    # Keep track of nodes that dominate the loss node.
    # We will only emit backward operations for nodes that can contribute
    # to the specified loss value.
    # (dict used as an insertion-ordered set; values are always None)
    live_nodes = {loss_node: None}
    # Forward value -> node holding its gradient. The loss starts with None;
    # NOTE(review): presumably stage_backward treats a None output grad as the
    # implicit seed for the loss — confirm in _backward.stage_backward.
    val_to_grad: dict[fx.Node, fx.Node | None] = {loss_node: None}

    def assign_or_accumulate_grad(forward_node, grad_value):
        # A second gradient for an already-seen non-placeholder value is
        # combined with the existing one via a _null_coalesce_accumulate node.
        # Placeholder gradients are simply overwritten. Within this function,
        # val_to_grad is only read for call_module nodes and their getitem
        # results, never for placeholders.
        if forward_node in val_to_grad and forward_node.op != "placeholder":
            grad_value = g.call_function(
                _null_coalesce_accumulate,
                (val_to_grad[forward_node], grad_value),
            )
        val_to_grad[forward_node] = grad_value

    # All backward nodes are emitted just before the graph's output node.
    with g.inserting_before(output_node):
        for node in reversed(g.nodes):
            if node not in live_nodes:
                continue

            def add_to_live_nodes(n):
                live_nodes.setdefault(n, None)

            # Everything this live node reads is live too.
            fx.node.map_arg(node.args, add_to_live_nodes)
            fx.node.map_arg(node.kwargs, add_to_live_nodes)
            if node.op == "call_module":
                output_grads: tuple[fx.Node | None, ...] | fx.Node | None
                if node in tuples:
                    # Multi-output stage: one grad slot per getitem'd element.
                    stage_output = tuples[node]
                    output_grads = tuple(val_to_grad.get(n) for n in tuples[node])
                    outputs_with_grads_idxs = [
                        i for i, n in enumerate(tuples[node]) if n in live_nodes
                    ]
                else:
                    # Single-output stage.
                    stage_output = (node,)
                    output_grads = val_to_grad[node]
                    outputs_with_grads_idxs = [0]

                output_grads = (
                    (output_grads,)
                    if not isinstance(output_grads, tuple)
                    else output_grads
                )

                grad_call = g.call_function(
                    stage_backward,
                    kwargs={
                        "stage_output": stage_output,
                        "output_grads": output_grads,
                        "input_values": list(node.all_input_nodes),
                        "outputs_with_grads_idxs": outputs_with_grads_idxs,
                    },
                )
                # Insert backward stage debug info
                # NOTE(review): re-assigning an equal kwargs dict appears to be a
                # no-op leftover — confirm before removing.
                kwargs_copy = dict(grad_call.kwargs)
                grad_call.kwargs = kwargs_copy

                # NOTE(review): `grads` is just `grad_call`; the first Proxy
                # round-trip looks redundant.
                grad_call_proxy = fx.Proxy(grad_call)
                grads = grad_call_proxy.node

                input_nodes = list(node.all_input_nodes)
                grads_proxy = fx.Proxy(grads)
                # Indexing the Proxy emits an operator.getitem node on the
                # stage_backward result. This assumes element i is the grad for
                # input_values[i] (same order as input_nodes) — confirm against
                # stage_backward's return value.
                for i, input_node in enumerate(input_nodes):
                    assign_or_accumulate_grad(input_node, grads_proxy[i].node)  # type: ignore[index]

    return g
  198. class PipeSequential(torch.nn.Sequential):
  199. @staticmethod
  200. def from_sequential(sequential_instance: torch.nn.Sequential):
  201. return PipeSequential(*[copy.copy(m) for m in sequential_instance])
  202. def forward(self, input):
  203. for i, module in enumerate(self):
  204. input = module(input)
  205. if i != len(self) - 1:
  206. pipe_split()
  207. return input
  208. class LossWrapper(torch.nn.Module):
  209. """
  210. LossWrapper is a convenient abstract class that allows you to wrap up both
  211. your model as well as its loss function and specify the connectivity between
  212. the inputs, model, loss function, and output value. Example::
  213. class MyModelWrapper(LossWrapper):
  214. def forward(self, x, targets):
  215. model_out = self.module(x)
  216. loss_value = self.loss_fn(model_out, targets)
  217. return loss_value
  218. The above example defines a connectivity where we expect the forward/loss/backward
  219. training procedure to take two arguments (x and targets), pass x into the module
  220. to get the output of the feedforward computation, pass the model output and the
  221. targets value into the loss function, and get and return the loss value, which will
  222. be backpropagated by PiPPy. The above class would then be instantiated like::
  223. model = ... # instantiate the model
  224. loss_fn = torch.nn.MSELoss() # for the sake of demonstration
  225. wrapper = MyModelWrapper(model, loss_fn)
  226. pipe = Pipe.from_tracing(wrapper, ...)
  227. """
  228. def __init__(self, module, loss_fn):
  229. super().__init__()
  230. self.module = module
  231. self.loss_fn = loss_fn
  232. def forward(self, *args, **kwargs):
  233. raise NotImplementedError(
  234. "This instance of LossWrapper does not have an overridden"
  235. "forward(). Please implement forward() to specify the arguments, "
  236. "connection between the module and loss, and loss output "
  237. "value."
  238. )
  239. class TrivialLossWrapper(LossWrapper):
  240. # pyrefly: ignore [bad-override]
  241. def forward(self, x, targets):
  242. model_out = self.module(x)
  243. return self.loss_fn(model_out, targets)
  244. loss_spec = True
  245. # Pipe model representation
  246. #
  247. # Pipe can be thought of as an `nn.Sequential++`. That is to say: it specifies
  248. # a single topological ordering of pipeline "stages" that, when run in series,
  249. # constitutes all of the operations of the program. However, unlike `nn.Sequential`,
  250. # Pipe allows non-local usages of values, so long as those uses still respect
  251. # topological ordering. In particular:
  252. #
  253. # 1. Non-local activations. This type of usage can appear in, for example, skip
  254. # connections. These values will be directly transmitted from the "def" stage
#    to all stages that use them, skipping intermediate stages. During autograd,
#    gradients will be propagated back through this skip connection in the
#    reverse direction of how activations propagated in the forward pass.
  258. # 2. Non-local parameter/module invocations. This occurs when a parameter is used
  259. # in a stage downstream of where it is resident. These values can be carried
  260. # forward similarly to (1), but in addition one might want to replicate the
  261. # value on multiple stages. Gradients for these shared parameters will be
  262. # accumulated separately on each stage, but there will be an additional
  263. # gradient accumulation before the optimizer step.
# Register `_pipe_split()` as an ATen operator. This is required for Export to
# preserve this marker in the graph.
# Schema "() -> ()": the op takes no inputs and returns nothing.
torch.library.define("pippy::_pipe_split", "() -> ()")


@torch.library.impl("pippy::_pipe_split", "BackendSelect")
def _pipe_split():
    """Eager implementation: a no-op."""
    return None


# Fake-tensor implementation (also a no-op). Redefining `_pipe_split` here is
# intentional, hence the F811 / no-redef suppressions.
@torch.library.register_fake("pippy::_pipe_split")  # type: ignore[no-redef]
def _pipe_split():  # noqa: F811
    """Fake (tracing-time) implementation: a no-op."""
    return None


# Add an alias for convenience
# This is the default OpOverload; graph nodes produced by export are matched
# against it as their `target`.
aten_pipe_split_alias = torch.ops.pippy._pipe_split.default

# Ask Export to preserve the `_pipe_split` op.
# See examples in pytorch/torch/fx/node.py
# Marking it side-effectful makes FX treat its nodes as impure, so dead-code
# elimination does not drop them even though they have no users.
fx.node._side_effectful_functions.add(aten_pipe_split_alias)
  278. # User facing API
  279. def pipe_split():
  280. """
  281. pipe_split is a special operator that is used to mark the boundary between
  282. stages in a module. It is used to split the module into stages. It is a
  283. no-op if your annotated module is run eagerly.
  284. Example:
  285. >>> # xdoctest: +SKIP
  286. >>> def forward(self, x):
  287. >>> x = torch.mm(x, self.mm_param)
  288. >>> x = torch.relu(x)
  289. >>> pipe_split()
  290. >>> x = self.lin(x)
  291. >>> return x
  292. The above example will be split into two stages.
  293. """
  294. return torch.ops.pippy._pipe_split()
  295. class MultiUseParameterConfig(Enum):
  296. TRANSMIT = 1
  297. REPLICATE = 2
  298. MultiUseParamSpec = Union[MultiUseParameterConfig, dict[str, MultiUseParameterConfig]]
  299. class DetachExecutor(fx.Interpreter):
  300. """
  301. Special interpreter to run the split_gm in testing that detaches all inputs to
  302. a module invocation. This is needed so that the values at the boundary are
  303. leaf modules in autograd execution.
  304. """
  305. def __init__(self, module, garbage_collect_values=True):
  306. garbage_collect_values = False
  307. super().__init__(module, garbage_collect_values)
  308. self.value_remap = {}
  309. def run(self, *args, initial_env=None): # type: ignore[override]
  310. self.value_remap = {}
  311. return super().run(*args, initial_env=initial_env)
  312. def call_module(self, target, args, kwargs):
  313. def detach_tensors(a):
  314. if isinstance(a, torch.Tensor) and a.requires_grad:
  315. if a not in self.value_remap:
  316. new_val = a.detach().requires_grad_(True)
  317. self.value_remap[a] = new_val
  318. return self.value_remap[a]
  319. else:
  320. return a
  321. """
  322. def dont_traverse_size(a):
  323. return type(a) is not torch.Size
  324. """
  325. args = map_aggregate(
  326. args,
  327. detach_tensors, # dont_traverse_size
  328. )
  329. kwargs = map_aggregate(
  330. kwargs,
  331. detach_tensors, # dont_traverse_size
  332. )
  333. return super().call_module(target, args, kwargs)
  334. def call_function(self, target, args, kwargs):
  335. # HACK to reroute saved input tensors to point to the detach()ed version
  336. if target is stage_backward:
  337. kwargs = dict(kwargs)
  338. kwargs["input_values"] = [
  339. self.value_remap.get(v, v) for v in kwargs["input_values"]
  340. ]
  341. return super().call_function(target, args, kwargs)
  342. class _NodeReference:
  343. def __init__(self, name):
  344. self.name = name
  345. name: str
  346. class _LinearNodeList:
  347. def __init__(self, node_list):
  348. self.serialize_node_list = []
  349. for node in node_list:
  350. node_args = fx.node.map_arg(node.args, lambda n: _NodeReference(n.name)) # type: ignore[arg-type,return-value]
  351. node_kwargs = fx.node.map_arg(node.kwargs, lambda n: _NodeReference(n.name)) # type: ignore[arg-type,return-value]
  352. serialize_node = fx.Node(
  353. graph=None, # type: ignore[arg-type]
  354. name=node.name,
  355. op=node.op,
  356. target=node.target,
  357. args=node_args, # type: ignore[arg-type]
  358. kwargs=node_kwargs, # type: ignore[arg-type]
  359. return_type=node.type,
  360. )
  361. serialize_node.meta = copy.copy(node.meta)
  362. self.serialize_node_list.append(serialize_node)
  363. def to_graph(self):
  364. graph = fx.Graph()
  365. ref_str_to_node: dict[str, fx.Node] = {}
  366. def ref_to_node(arg):
  367. if isinstance(arg, _NodeReference):
  368. return ref_str_to_node[arg.name]
  369. else:
  370. return arg
  371. for node in self.serialize_node_list:
  372. node_args = map_aggregate(node.args, ref_to_node)
  373. node_kwargs = map_aggregate(node.kwargs, ref_to_node)
  374. deser_node = graph.create_node(
  375. op=node.op,
  376. target=node.target,
  377. args=node_args, # type: ignore[arg-type]
  378. kwargs=node_kwargs, # type: ignore[arg-type]
  379. name=node.name,
  380. type_expr=node.type,
  381. )
  382. ref_str_to_node[node.name] = deser_node
  383. return graph
  384. def _direct_serialization_deserialize(body, nodes):
  385. """
  386. Custom `__reduce__` method for serialization.
  387. DO AS I SAY -- NOT AS I DO. This violates the principle that
  388. GraphModules serialize via code export & re-tracing. We allow
  389. for this here because **PIPE STAGES SHOULD NOT BE PERSISTED
  390. TO DISK -- THIS IS ONLY FOR TRANSMISSION VIA RPC**. Persisting
  391. these instances to disk will expose internal implementation
  392. details of `fx.Graph` and related data structures and is
  393. NOT advised.
  394. """
  395. class DummyModule(torch.nn.Module):
  396. def __init__(self, body):
  397. super().__init__()
  398. self.__dict__.update(body)
  399. dummy = DummyModule(body)
  400. return fx.GraphModule(dummy, nodes.to_graph())
  401. def _direct_serialization_reduce(self):
  402. serialization_dict = dict(self.__dict__)
  403. serialization_dict.pop("_graph")
  404. return (
  405. _direct_serialization_deserialize,
  406. (serialization_dict, _LinearNodeList(self.graph.nodes)),
  407. )
  408. def _modify_graph_op_device(
  409. gm: torch.fx.GraphModule,
  410. new_device: torch.device,
  411. ):
  412. """
  413. Modify the device argument of all "call_function" nodes in the graph. This
  414. is useful for moving the graph to a different device. In particular for
  415. generator ops, like torch.ones.
  416. """
  417. modified = False
  418. for node in gm.graph.nodes:
  419. if node.op == "call_function":
  420. if "device" in node.kwargs and node.kwargs["device"] != new_device:
  421. logger.debug(
  422. f"Changing device of Node {node.name} from {node.kwargs['device']} to {new_device}" # noqa: G004
  423. )
  424. node.update_kwarg("device", new_device)
  425. modified = True
  426. elif node.op == "call_module":
  427. # Recursively modify "device" in submodules
  428. submod = gm.get_submodule(node.target)
  429. if isinstance(submod, torch.fx.GraphModule):
  430. _modify_graph_op_device(submod, new_device)
  431. elif isinstance(submod, InterpreterModule):
  432. # If unflattening has been performed, we need to access its graph module by `.graph_module`
  433. _modify_graph_op_device(submod.graph_module, new_device) # type: ignore[arg-type]
  434. else:
  435. logger.warning(
  436. f"Skipping device modification for submodule {node.target} because it is a {type(submod)}" # noqa: G004
  437. )
  438. if modified:
  439. gm.recompile()
  440. class Pipe(torch.nn.Module):
  441. def __init__(
  442. self,
  443. split_gm: fx.GraphModule,
  444. num_stages: int,
  445. has_loss_and_backward: bool,
  446. loss_spec,
  447. ):
  448. # TODO: is there a way not to hard wire init?
  449. torch.nn.Module.__init__(self)
  450. self.split_gm: fx.GraphModule = split_gm
  451. self.executor: DetachExecutor = DetachExecutor(self.split_gm)
  452. self.num_stages: int = num_stages
  453. self.has_loss_and_backward = has_loss_and_backward
  454. self.loss_spec = loss_spec
  455. for node in split_gm.graph.nodes:
  456. if not (
  457. node.op in {"call_module", "placeholder", "output"}
  458. or (node.op, node.target) == ("call_function", operator.getitem)
  459. or (node.op, node.target) == ("call_method", "backward")
  460. or (node.op, node.target) == ("call_function", stage_backward)
  461. or (node.op, node.target)
  462. == ("call_function", _null_coalesce_accumulate)
  463. ):
  464. raise AssertionError(f"Unexpected node: {node}")
  465. # Detect replicated parameters so we know that we have to do an additional allreduce
  466. # before applying the optimizer
  467. #
  468. # Note that this also handles the case where there were multiple calls to a single
  469. # module from different stages, regardless of whether that module invocation
  470. # was handled by the logic above.
  471. # Map parameter value to a dictionary that maps the user pipeline module
  472. # to the local qualname within that module
  473. params_to_users: dict[torch.nn.Parameter, dict[str, str]] = {}
  474. for m_qualname, mod in self.split_gm.named_children():
  475. for p_qualname, param in mod.named_parameters():
  476. params_to_users.setdefault(param, {})
  477. params_to_users[param][m_qualname] = p_qualname
  478. self.replicated_params: list[dict[str, str]] = [
  479. use_mapping
  480. for _, use_mapping in params_to_users.items()
  481. if len(use_mapping) > 1
  482. ]
  483. # We must break the aliasing relationship between the replicated parameters for correct
  484. # numerics in reference runs. If we do not do this, the autograd tape in separate stages
  485. # will have a reference to the same tensor value and will erroneously apply gradient
  486. # updates multiple times. Therefore, for each replicated parameter set, we deepcopy the
  487. # values so that we have separate instances.
  488. for param_mapping in self.replicated_params:
  489. for submod_name, param_qualname in param_mapping.items():
  490. submod = getattr(self.split_gm, submod_name)
  491. atoms = param_qualname.split(".")
  492. for atom in atoms[:-1]:
  493. submod = getattr(submod, atom)
  494. setattr(submod, atoms[-1], copy.deepcopy(getattr(submod, atoms[-1])))
  495. def throw(self, *args, **kwargs):
  496. raise RuntimeError(
  497. "To run pipeline locally, invoke the Pipe object directly, not `split_gm`"
  498. )
  499. self.split_gm.forward = throw
  500. # Make submodules use custom direct-serialized GraphModule
  501. i = 0
  502. while True:
  503. try:
  504. name = get_submod_name(i)
  505. submod = getattr(self.split_gm, name)
  506. submod.__class__.__reduce__ = _direct_serialization_reduce
  507. i += 1
  508. except AttributeError:
  509. break
  510. def forward(self, *args, **kwargs):
  511. executor_args = args
  512. if len(kwargs) > 0:
  513. parameters = []
  514. for node in self.split_gm.graph.nodes:
  515. if node.op == "placeholder":
  516. if node.args and len(node.args) > 0:
  517. parameters.append(
  518. Parameter(
  519. node.target,
  520. Parameter.POSITIONAL_OR_KEYWORD,
  521. default=node.args[0],
  522. )
  523. )
  524. else:
  525. parameter_kind = Parameter.POSITIONAL_OR_KEYWORD
  526. param_name = node.target
  527. if node.target.startswith("**"):
  528. parameter_kind = Parameter.VAR_KEYWORD # type: ignore[assignment]
  529. param_name = param_name[2:]
  530. elif node.target.startswith("*"):
  531. parameter_kind = Parameter.VAR_POSITIONAL # type: ignore[assignment]
  532. param_name = param_name[1:]
  533. parameters.append(Parameter(param_name, parameter_kind))
  534. signature = Signature(parameters)
  535. ba = signature.bind(*args, **kwargs)
  536. ba.apply_defaults()
  537. executor_args = ba.arguments.values() # type: ignore[assignment]
  538. res = self.executor.run(*executor_args)
  539. return res
  540. def get_stage_module(self, stage_idx: int) -> torch.nn.Module:
  541. """
  542. Return a stage module corresponding to `stage_idx` of the `pipe`.
  543. """
  544. if stage_idx < 0 or stage_idx >= self.num_stages:
  545. raise ValueError(f"Invalid stage index {stage_idx}!")
  546. submod_name = get_submod_name(stage_idx)
  547. return getattr(self.split_gm, submod_name)
  548. @staticmethod
  549. def _number_and_count_forward_stages(gm: fx.GraphModule):
  550. num_stages = 0
  551. found_idxs: dict[int, None] = {}
  552. for node in gm.graph.nodes:
  553. if node.op == "call_module" and node.target.startswith(PP_SUBMOD_PREFIX):
  554. node.meta["stage_idx"] = int(node.target[len(PP_SUBMOD_PREFIX) + 1 :])
  555. found_idxs.setdefault(node.meta["stage_idx"])
  556. num_stages += 1
  557. # this assert will fail if a split point is inserted before the first layer, which creates empty first submodule
  558. # Update: the following assert may fail against some torch versions >=
  559. # 2.2.0, as:
  560. # submod_0, submod_1, submod_2, ...
  561. # may be named as
  562. # submod_0, submod_2, submod_4, ...
  563. # TODO: investigate
  564. # assert all(i in found_idxs for i in range(num_stages))
  565. return num_stages
  566. @staticmethod
  567. def _from_traced(
  568. mod: torch.nn.Module,
  569. exported_program: ExportedProgram,
  570. multi_use_param_spec: MultiUseParamSpec | None = None,
  571. output_loss_value_spec=None,
  572. split_policy: Callable[[torch.fx.GraphModule], torch.fx.GraphModule]
  573. | None = None,
  574. ):
  575. """
  576. Additionally, the ``output_loss_value_spec`` value can be specified to disambiguate
  577. which value in the output of `forward` is the loss value on which PiPPy should apply
  578. backpropagation. For example, if your ``forward`` returns a tuple ``(loss, model_out)``,
  579. you can specify ``output_loss_value_spec=(True, False)``. Or, if your ``forward`` returns
  580. a dict ``{'loss': loss_value, 'model_out': model_out}``, you can specify
  581. ``output_loss_value_spec={'loss': True, 'model_out': False}``
  582. """
  583. traced = exported_program.module(check_guards=False)
  584. if split_policy is not None:
  585. logger.info("Auto-splitting model")
  586. traced = split_policy(traced) # type: ignore[arg-type]
  587. logger.debug(traced.print_readable(print_output=False)) # type: ignore[operator]
  588. # Deduplicate `get_attr` nodes that refer to the same parameter . Downstream code for moving
  589. # parameters relies on the invariant that parameter accesses happen once. This is not necessarily
  590. # the case (especially with custom tracers), so fix that up here.
  591. get_attr_nodes: dict[str, fx.Node] = {}
  592. for node in traced.graph.nodes: # type: ignore[union-attr]
  593. if node.op == "get_attr":
  594. get_attr_nodes.setdefault(node.target, node)
  595. if get_attr_nodes[node.target] != node:
  596. node.replace_all_uses_with(get_attr_nodes[node.target])
  597. traced.graph.erase_node(node) # type: ignore[operator, union-attr]
  598. # avoid looking at next node by keeping track of previous pipe_split
  599. prev_pipe_split_idx = -1
  600. pipe_split_nodes_to_erase = set()
  601. for i, node in enumerate(traced.graph.nodes): # type: ignore[arg-type, union-attr]
  602. if (node.op, node.target) == ("call_function", pipe_split):
  603. if prev_pipe_split_idx == i - 1:
  604. pipe_split_nodes_to_erase.add(node)
  605. prev_pipe_split_idx = i
  606. for node in pipe_split_nodes_to_erase:
  607. traced.graph.erase_node(node) # type: ignore[operator, union-attr]
  608. traced.recompile() # type: ignore[operator]
  609. part_idx = 0
  610. def split_callback(n: fx.Node):
  611. nonlocal part_idx
  612. if (n.op, n.target) == (
  613. "call_function",
  614. aten_pipe_split_alias,
  615. ):
  616. logger.debug(f"Found pipe_split {part_idx}") # noqa: G004
  617. part_idx += 1
  618. return part_idx
  619. # TODO: what does split do with module invocations? does it move the modules
  620. # into the submodules?
  621. split = split_module(traced, mod, split_callback, partition_affix="pp") # type: ignore[arg-type]
  622. # a (custom) tracer can produce dead code like orphan get_attr nodes
  623. split.graph.eliminate_dead_code()
  624. # peephole to remove pipe_split
  625. for submodule in split.modules():
  626. if isinstance(submodule, fx.GraphModule):
  627. for node in submodule.graph.nodes:
  628. if (node.op, node.target) == (
  629. "call_function",
  630. aten_pipe_split_alias,
  631. ):
  632. submodule.graph.erase_node(node)
  633. submodule.recompile()
  634. for name, submodule in split.named_children():
  635. if isinstance(submodule, fx.GraphModule):
  636. new_submod = _outline_submodules(submodule.graph)
  637. # Replace old submod
  638. split.register_module(name, new_submod)
  639. # TODO: backport this into split_module
  640. def delete_user_reference(node, user):
  641. """
  642. Delete reference of `node` from `user`'s arg list.
  643. Args:
  644. - node: a `get_attr` node at root.
  645. - user: a submodule node that uses `node`.
  646. """
  647. if not len(user.kwargs) == 0:
  648. raise AssertionError(
  649. f"Expected user.kwargs to be empty, got {len(user.kwargs)}"
  650. )
  651. use_idxs = [i for i, arg in enumerate(user.args) if arg == node]
  652. if not len(use_idxs) == 1:
  653. raise AssertionError(f"Expected 1 use index, got {len(use_idxs)}")
  654. args_copy = list(user.args)
  655. args_copy.pop(use_idxs[0])
  656. user.args = tuple(args_copy)
  657. logger.debug(
  658. f"Deleted {node} from user {user}, arg index = {use_idxs[0]}" # noqa: G004
  659. )
  660. # A list of param referrals for deferred deletion.
  661. # To be accumulated in `move_param_to_callee`.
  662. to_delete = []
  663. def _recursive_getattr_with_parent(mod, fqn):
  664. # Returns getattr call given a nested FQN, and the last parent
  665. atoms = fqn.split(".")
  666. for atom in atoms[:-1]:
  667. if not hasattr(mod, atom):
  668. return None, None
  669. mod = getattr(mod, atom)
  670. if not hasattr(mod, atoms[-1]):
  671. return mod, None
  672. attr = getattr(mod, atoms[-1])
  673. return mod, attr
  674. def move_param_to_callee(
  675. root,
  676. callee_name,
  677. param_fqn,
  678. ):
  679. """
  680. Move a parameter from the root module to a submodule.
  681. Args:
  682. root: The root module.
  683. callee_name: The name of the submodule to move the parameter to.
  684. param_fqn: The fully qualified name of the parameter to move.
  685. """
  686. # `atoms` is a list of strings representing the path to the
  687. # parameter in the original model
  688. atoms = param_fqn.split(".")
  689. mod_itr, param_val = _recursive_getattr_with_parent(split, param_fqn)
  690. # Check whether the parameter is a buffer or a parameter
  691. is_buffer = atoms[-1] in mod_itr._buffers
  692. # Check whether the parameter is a tensor
  693. if not isinstance(param_val, torch.Tensor):
  694. raise AssertionError(
  695. f"Expected '{param_fqn}' to be {torch.Tensor} but got {type(param_val)}."
  696. + (
  697. f" It might happen if module '{param_fqn}' was passed to some 'leaf function'"
  698. f"(see https://pytorch.org/docs/stable/fx.html#fx.wrap). Please inspect "
  699. f"usages of '{param_fqn}' in the traced graph."
  700. if isinstance(param_val, torch.nn.Module)
  701. else ""
  702. )
  703. )
  704. # Get submodule
  705. callee = root.get_submodule(callee_name)
  706. if hasattr(callee, param_fqn):
  707. raise AssertionError(
  708. f"Module {callee_name} already has a parameter named {param_fqn}"
  709. )
  710. # Assign the parameter to the submodule
  711. if is_buffer:
  712. _assign_attr(
  713. param_val,
  714. callee,
  715. param_fqn,
  716. attr_kind=_AttrKind.BUFFER,
  717. persistent=True, # TODO: handle non-persistent buffer
  718. )
  719. else:
  720. _assign_attr(
  721. param_val,
  722. callee,
  723. param_fqn,
  724. attr_kind=_AttrKind.PARAMETER,
  725. )
  726. logger.debug(f"Moved parameter {param_fqn} to {callee_name}") # noqa: G004
  727. # Next step is to replace placeholder of submodule with a get_attr.
  728. # Those placeholders are created by `split_module` inside each
  729. # submodule.
  730. # Update: this step is now moved to `_sink_params` because
  731. # `_sink_params` can do it recursively (i.e. for modules inside
  732. # submodule)
  733. to_delete.append((mod_itr, atoms[-1]))
  734. # Get the list of all parameters in the root module
  735. attr_nodes = list(filter(lambda n: n.op == "get_attr", split.graph.nodes))
  736. for node in attr_nodes:
  737. # Check whether the parameter is used in only one submodule
  738. if len(node.users) > 1:
  739. logger.info(
  740. f"Parameter {node.target} used in multiple stages: {node.users}." # noqa: G004
  741. )
  742. for user in node.users:
  743. if not user.op == "call_module":
  744. raise AssertionError(
  745. f"Expected user.op to be 'call_module', got {user.op}"
  746. )
  747. # Move parameter into submodule
  748. move_param_to_callee(
  749. split,
  750. user.target,
  751. node.target,
  752. )
  753. # [aliasing] store tensor id -> list of FQNs, built from state dict
  754. # Also assign non-persistent buffers
  755. id_to_fqns: dict[int, set[str]] = defaultdict(set)
  756. for fqn, tensor in mod.state_dict(keep_vars=True).items():
  757. id_to_fqns[id(tensor)].add(fqn)
  758. for fqn, tensor in mod.named_buffers():
  759. id_to_fqns[id(tensor)].add(fqn)
  760. # After moving the params to their corresponding hierarchies, we also
  761. # need to move the `get_attr` nodes from the root of the graph to those
  762. # hierarchies.
  763. # [aliasing] use id -> fqn mapping to list out all valid FQNs
  764. inputs_to_state: dict[str, list[str]] = {}
  765. for attr in attr_nodes:
  766. _, tensor = _recursive_getattr_with_parent(mod, attr.target)
  767. fqns = list(id_to_fqns[id(tensor)])
  768. if fqns:
  769. inputs_to_state[attr.name] = fqns
  770. elif attr.target in exported_program.constants: # lifted constants
  771. inputs_to_state[attr.name] = [attr.target]
  772. # [aliasing] for each submodule split, assign attributes on FQNs that may be used.
  773. # We determine this based on whether or not the FQN attribute parent exists.
  774. # i.e. if the last submodule exists, assign the attribute.
  775. added_attributes: dict[str, list[str]] = defaultdict(list)
  776. for fqn, tensor in mod.state_dict(keep_vars=True).items():
  777. for name, submod in split.named_children():
  778. if isinstance(submod, fx.GraphModule):
  779. parent, child = _recursive_getattr_with_parent(submod, fqn)
  780. if (
  781. parent and child is None
  782. ): # parent exists, attribute doesn't -> assign
  783. added_attributes[name].append(fqn)
  784. setattr(parent, fqn.split(".")[-1], tensor)
  785. # Deferral deletion: Remove the original attributes (to params) from the
  786. # root GraphModule
  787. for mod_itr, last_atom in to_delete:
  788. try:
  789. delattr(mod_itr, last_atom)
  790. except AttributeError:
  791. # This is expected if the parameter is used in multiple stages
  792. pass
  793. # This is done by (1) `_sink_params` at each submodule;
  794. for submod in split.children():
  795. if isinstance(submod, fx.GraphModule):
  796. _sink_params(submod, inputs_to_state, [])
  797. submod.graph.lint()
  798. submod.recompile()
  799. # [aliasing] This step is not super necessary, but helps reduce parameter usage/memory.
  800. # After _sink_params() routine has run, clean up unused attributes that we previously added.
  801. # Determine this based on the get_attr nodes - if not used, remove it.
  802. for name, attributes in added_attributes.items():
  803. submod = getattr(split, name)
  804. unused_attributes = set(attributes)
  805. # track used attributes in the submodule, running DFS on subgraph hierarchy
  806. stack = [("", submod)] # (scope, submodule)
  807. while stack:
  808. scope, _mod = stack.pop()
  809. if isinstance(_mod, (fx.GraphModule, InterpreterModule)):
  810. for node in _mod.graph.nodes:
  811. if node.op == "get_attr":
  812. # get_attr might get access deeper level attribute
  813. fqn = scope + "." + node.target if scope else node.target
  814. unused_attributes.discard(fqn)
  815. for _name, _submod in _mod.named_children():
  816. stack.append((scope + "." + _name if scope else _name, _submod))
  817. # delete unused attributes
  818. for attr in unused_attributes:
  819. mod_itr, atoms = submod, attr.split(".")
  820. for atom in atoms[:-1]:
  821. mod_itr = getattr(mod_itr, atom)
  822. delattr(mod_itr, atoms[-1])
  823. for node in attr_nodes:
  824. # And (2): remove `get_attr` node from submod's arg list
  825. for user in copy.copy(node.users):
  826. if not user.op == "call_module":
  827. raise AssertionError(
  828. f"Expected user.op to be 'call_module', got {user.op}"
  829. )
  830. delete_user_reference(node, user)
  831. # And (3): remove the `get_attr` node from the root graph.
  832. split.graph.erase_node(node)
  833. split.delete_all_unused_submodules()
  834. split.graph.lint()
  835. split.recompile()
  836. num_stages = Pipe._number_and_count_forward_stages(split)
  837. has_loss_and_backward = False
  838. generated_loss_spec = output_loss_value_spec
  839. if output_loss_value_spec is not None:
  840. loss_node, output_node, generated_loss_spec = _find_loss_output(
  841. mod, split.graph, output_loss_value_spec
  842. )
  843. if loss_node is not None:
  844. _insert_stage_symbolic_backward(
  845. split.graph,
  846. loss_node,
  847. output_node,
  848. )
  849. split.recompile()
  850. has_loss_and_backward = True
  851. logger.debug("Pipeline is in training mode, backward pass generated")
  852. else:
  853. raise RuntimeError(
  854. f"Did not find any loss value according to {output_loss_value_spec=}"
  855. )
  856. else:
  857. logger.debug("Pipeline is in inference mode, backward pass not generated")
  858. logger.debug(f"Full pipe model:\n{split}") # noqa: G004
  859. return Pipe(
  860. split,
  861. num_stages,
  862. has_loss_and_backward,
  863. generated_loss_spec,
  864. )
  865. def print_readable(self):
  866. """
  867. Print the pipe in a human-readable format.
  868. This will print both the root pipe and each stage module.
  869. """
  870. self.split_gm.print_readable()
  871. @staticmethod
  872. def _trace_with_export(
  873. mod: torch.nn.Module,
  874. example_args: tuple[Any, ...],
  875. example_kwargs: dict[str, Any] | None = None,
  876. ) -> ExportedProgram:
  877. logger.info("Tracing model ...")
  878. try:
  879. ep = torch.export.export(mod, example_args, example_kwargs)
  880. except Exception as e:
  881. raise RuntimeError(
  882. "It seems that we cannot capture your model as a full graph. "
  883. "Typical reasons include graph breaks, data/shape-dependent "
  884. "control flow, or missing meta kernels for custom operators. "
  885. "You can use our manual pipeline interfaces, or try to fix the "
  886. "graph breaks, see https://pytorch.org/docs/stable/export.html"
  887. ) from e
  888. return ep
  889. @staticmethod
  890. def from_tracing(
  891. mod: torch.nn.Module,
  892. example_args: tuple[Any, ...],
  893. example_kwargs: dict[str, Any] | None = None,
  894. split_policy: Callable[[fx.GraphModule], fx.GraphModule] | None = None,
  895. ):
  896. # If a param will be used in multiple pipeline stages, we default the strategy to REPLICATE'ing the param across
  897. # stages instead of TRANSMIT'ting it
  898. multi_use_param_spec = MultiUseParameterConfig.REPLICATE
  899. # Figure out which output is loss from output_chunk_spec
  900. output_loss_value_spec: Any = None
  901. # Deprecated
  902. """
  903. if output_chunk_spec is not None:
  904. output_loss_value_spec = map_aggregate(
  905. output_chunk_spec, lambda v: isinstance(v, _LossReducer)
  906. )
  907. """
  908. # Trace with export
  909. exported_program = Pipe._trace_with_export(
  910. mod,
  911. example_args,
  912. example_kwargs,
  913. )
  914. pipe = Pipe._from_traced(
  915. mod,
  916. exported_program,
  917. multi_use_param_spec,
  918. output_loss_value_spec=output_loss_value_spec,
  919. split_policy=split_policy,
  920. )
  921. # Users want the first pipeline stage to accept kwargs if the original
  922. # program does. This is controlled by the `_codegen` field of the graph,
  923. # so we make a copy here. Note: we only want the input spec and not the
  924. # output spec, because the output spec is for the last stage. Maybe a
  925. # TODO? Not sure yet.
  926. split = pipe.split_gm
  927. traced = exported_program.module()
  928. submod0 = next(iter(split.children()))
  929. submod0_sign = signature(submod0.forward)
  930. model_sign = signature(traced.forward)
  931. if len(model_sign.parameters) != len(submod0_sign.parameters):
  932. # We don't change the signature of the first stage if it takes
  933. # different number of args than original model
  934. logger.info(
  935. f"Original model takes {len(model_sign.parameters)} args but the " # noqa: G004
  936. f"first pipeline stage takes {len(submod0_sign.parameters)}. "
  937. "Please provide args to respective pipeline stages."
  938. )
  939. else:
  940. # Support kwargs for the first stage
  941. submod0.graph._codegen = copy.deepcopy(traced.graph._codegen) # type: ignore[union-attr]
  942. # `_replace` is actually not "private" or internal. based on this doc:
  943. # To prevent conflicts with field names, the method and attribute names
  944. # start with an underscore
  945. submod0.graph._codegen.pytree_info = ( # type: ignore[union-attr]
  946. submod0.graph._codegen.pytree_info._replace(out_spec=None) # type: ignore[operator, union-attr]
  947. )
  948. submod0.recompile()
  949. return pipe
  950. def __str__(self):
  951. return self.split_gm.__str__()
  952. def __repr__(self):
  953. return self.split_gm.__repr__()
  954. def info(self) -> PipeInfo:
  955. """
  956. Get information about the pipe.
  957. Returns
  958. -------
  959. PipeInfo
  960. A dataclass containing information about the pipe.
  961. """
  962. return PipeInfo(
  963. graph=self.split_gm.graph,
  964. num_stages=self.num_stages,
  965. has_loss_and_backward=self.has_loss_and_backward,
  966. )
  967. def build_stage(
  968. self,
  969. stage_index: int,
  970. device: torch.device,
  971. group: ProcessGroup | None = None,
  972. ) -> _PipelineStage:
  973. """
  974. Create a `PipelineStage` given a stage index and distributed group.
  975. The `PipelineStage` can run with `PipelineSchedule`s.
  976. """
  977. # Find stage module
  978. stage_module = self.get_stage_module(stage_index)
  979. # Move ops argument to device
  980. # Today PT2 tracer does not treat `x.device` as a symbolic device;
  981. # instead, the device of tracing time got burned into the generated
  982. # code. Here we provide a workaround for users to manually modify the
  983. # "device" kwarg of operations. Such operation may include:
  984. # `torch.ones`, `torch.zeros`, `torch.rand`, etc.
  985. if isinstance(stage_module, torch.fx.GraphModule):
  986. _modify_graph_op_device(stage_module, device)
  987. else:
  988. logger.warning(
  989. f"Expected a `torch.fx.GraphModule` but got {type(stage_module)}" # noqa: G004
  990. )
  991. # Detach pipe info
  992. # Note: be careful what's included in `pipe_info`. We don't want to keep
  993. # a reference to `Pipe` or `Pipe.split_gm` which stops python from
  994. # recycling them. When python recycles them, other stage modules (which
  995. # are irrelevant to current rank) can be automatically freed.
  996. pipe_info = self.info()
  997. return _PipelineStage(stage_module, stage_index, pipe_info, device, group)
  998. class SplitPoint(Enum):
  999. """
  1000. Enum representing the points at which a split can occur in the execution of a submodule.
  1001. Attributes:
  1002. BEGINNING: Represents adding a split point *before* the execution of a certain submodule in the `forward` function.
  1003. END: Represents adding a split point *after* the execution of a certain submodule in the `forward` function.
  1004. """
  1005. BEGINNING = 1
  1006. END = 2
  1007. # For backward compatibility, we kept the PipeSplitWrapper class because `class
  1008. # SplitPoint` used to be defined in this class.
  1009. class PipeSplitWrapper:
  1010. # Create a class alias for BC
  1011. SplitPoint = SplitPoint
  1012. def _split_before_forward(self, *args, **kwargs):
  1013. pipe_split()
  1014. return self._orig_forward(*args, **kwargs)
  1015. def _split_after_forward(self, *args, **kwargs):
  1016. try:
  1017. return self._orig_forward(*args, **kwargs)
  1018. finally:
  1019. pipe_split()
  1020. def annotate_split_points(mod: torch.nn.Module, spec: dict[str, SplitPoint]):
  1021. # TODO: make this implementation out-of-place?
  1022. for qualname, split_type in spec.items():
  1023. atoms = qualname.split(".")
  1024. predecessor_module = mod
  1025. for i, atom in enumerate(atoms[:-1]):
  1026. try:
  1027. predecessor_module = getattr(predecessor_module, atom)
  1028. except AttributeError as e:
  1029. raise AttributeError(
  1030. f"Specified target {qualname} referenced "
  1031. f"nonexistent module {'.'.join(atoms[: i + 1])}"
  1032. ) from e
  1033. mod_to_wrap = getattr(predecessor_module, atoms[-1])
  1034. mod_to_wrap._orig_forward = mod_to_wrap.forward
  1035. if split_type == SplitPoint.BEGINNING:
  1036. mod_to_wrap.forward = MethodType(_split_before_forward, mod_to_wrap)
  1037. elif split_type == SplitPoint.END:
  1038. mod_to_wrap.forward = MethodType(_split_after_forward, mod_to_wrap)
  1039. else:
  1040. raise ValueError("Unknown split point type.")
  1041. def pipeline(
  1042. module: torch.nn.Module,
  1043. mb_args: tuple[Any, ...],
  1044. mb_kwargs: dict[str, Any] | None = None,
  1045. split_spec: dict[str, SplitPoint] | None = None,
  1046. split_policy: Callable[[fx.GraphModule], fx.GraphModule] | None = None,
  1047. ) -> Pipe:
  1048. """
  1049. Split a module based on a specification.
  1050. See `Pipe` for more details.
  1051. Arguments
  1052. ---------
  1053. module:
  1054. The module to be split.
  1055. mb_args:
  1056. Example positional inputs, in micro-batch form.
  1057. mb_kwargs:
  1058. Example keyword inputs, in micro-batch form. (default: `None`)
  1059. split_spec:
  1060. A dictionary using submodule names as split marker. (default: `None`)
  1061. split_policy:
  1062. The policy to use for splitting the module. (default: `None`)
  1063. Returns
  1064. -------
  1065. A pipeline representation of class `Pipe`.
  1066. """
  1067. if split_spec is not None and split_policy is not None:
  1068. raise ValueError(
  1069. "Cannot specify both `split_spec` and `split_policy`. Please use only one of them."
  1070. )
  1071. if split_spec is not None:
  1072. # Annotate split points in the module based on user spec
  1073. annotate_split_points(module, split_spec)
  1074. return Pipe.from_tracing(
  1075. mod=module,
  1076. example_args=mb_args,
  1077. example_kwargs=mb_kwargs,
  1078. )
  1079. else:
  1080. # Use split policy
  1081. return Pipe.from_tracing(
  1082. mod=module,
  1083. example_args=mb_args,
  1084. example_kwargs=mb_kwargs,
  1085. split_policy=split_policy,
  1086. )