distributed.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622
  1. """
  2. This module implements distributed training optimizations for TorchDynamo backends.
  3. It provides functionality to optimize models wrapped in DistributedDataParallel (DDP)
  4. by intelligently splitting compiled graphs to align with DDP's gradient synchronization
  5. boundaries. Key features include:
  6. - Graph partitioning based on parameter bucket sizes
  7. - Optimization of allreduce operations for distributed training
  8. - Support for parameter ignoring and buffer handling
  9. - Submodule compilation and management
  10. - Debugging utilities for distributed training
  11. The main component is the DDPOptimizer class, which handles graph splitting and
  12. recompilation to enable efficient distributed training while maintaining the benefits
  13. of compilation.
  14. """
  15. import logging
  16. import traceback
  17. from collections.abc import Callable
  18. from dataclasses import dataclass, field
  19. from typing import Any, Optional, TYPE_CHECKING
  20. from unittest import mock
  21. import torch
  22. from torch import fx
  23. from torch._dynamo.backends.registry import CompiledFn, CompilerFn
  24. from torch._dynamo.output_graph import GraphCompileReason
  25. from torch._dynamo.utils import deepcopy_to_fake_tensor, detect_fake_mode
  26. from torch._logging import trace_structured
  27. from torch.fx.node import Node
  28. if TYPE_CHECKING:
  29. from torch._functorch._aot_autograd.schemas import ViewAndMutationMeta
# Regular log messages should go through 'log'.
# ddp_graph_log is a separate artifact logger (artifact name "ddp_graphs")
# reserved for dumping graphs; it is used below for the original, split,
# per-submodule and final graph dumps.
# See docs/source/logging.rst for more info.
log = logging.getLogger(__name__)
ddp_graph_log = torch._logging.getArtifactLogger(__name__, "ddp_graphs")
  35. def args_str(args: Any) -> str:
  36. # a debug helper
  37. if torch.is_tensor(args):
  38. return f"T[{args.shape}]"
  39. elif isinstance(args, tuple):
  40. return f"tuple({', '.join([args_str(x) for x in args])})"
  41. elif isinstance(args, list):
  42. return f"list({', '.join([args_str(x) for x in args])})"
  43. else:
  44. return str(args)
@dataclass
class Bucket:
    """One gradient bucket as reconstructed by DDPOptimizer from a dynamo graph.

    DDPOptimizer fills buckets while walking the graph in reverse; every visited
    node is recorded in ``nodes``, but only non-ignored, grad-requiring
    parameters count toward ``size``.
    """

    # Bytes counted toward this bucket's capacity (DDPOptimizer.add_param adds
    # each parameter's untyped storage size).
    size: int = 0
    # Display names of the parameters counted in this bucket (used for logging).
    params: list[str] = field(default_factory=list)
    # Graph nodes assigned to this bucket. DDPOptimizer appends them while
    # iterating the graph in reverse, so this list is in reverse graph order.
    nodes: list[fx.Node] = field(default_factory=list)
    # param_ids is just used for unit testing
    param_ids: list[int] = field(default_factory=list)
    # keep track of any buckets that were extended for logging purposes
    # Number of extra nodes added after the bucket reached its capacity.
    opcount_increased_to_capture_external_output: int = 0
    # Value of `size` at the moment the bucket first started being extended.
    paramsize_before_opcount_increase: int = 0
  55. def bucket_has_external_output(bucket: Bucket) -> bool:
  56. nodes_in_bucket = set()
  57. # we want to iterate in reverse order, but clumsi-luckily the bucket.nodes list was already created backwards
  58. # so we don't reverse it here
  59. for node in bucket.nodes:
  60. # assume node.op != output, since those are filtered in the original iteration
  61. nodes_in_bucket.add(node)
  62. for user in node.users:
  63. if user not in nodes_in_bucket:
  64. return True
  65. return False
  66. def pretty_print_buckets(buckets: list[Bucket], bucket_bytes_cap: int) -> None:
  67. headers = ("Index", "Size (b)", "Param Names")
  68. rows: list[tuple[Optional[int], Optional[int], str]] = []
  69. # pyrefly: ignore [implicit-any]
  70. extended_buckets = []
  71. for idx, bucket in enumerate(reversed(buckets)):
  72. if len(bucket.params) > 0:
  73. rows.append((idx, bucket.size, bucket.params[0]))
  74. rows.extend((None, None, param) for param in bucket.params[1:])
  75. if bucket.opcount_increased_to_capture_external_output > 0:
  76. extended_buckets.append(
  77. (
  78. idx,
  79. bucket.opcount_increased_to_capture_external_output,
  80. bucket.size - bucket.paramsize_before_opcount_increase,
  81. )
  82. )
  83. if rows:
  84. log.info(
  85. "\nDDPOptimizer used bucket cap %s and created %d buckets. Enable debug logs for detailed bucket info.",
  86. bucket_bytes_cap,
  87. len(buckets),
  88. )
  89. if extended_buckets:
  90. log.warning(
  91. "Some buckets were extended beyond their requested parameter capacities"
  92. " in order to ensure each subgraph has an output node, required for fx graph partitioning."
  93. " This can be the case when a subgraph would have only contained nodes performing inplace mutation,"
  94. " and returning no logical outputs. This should not be a problem, unless it results in too few graph"
  95. " partitions for optimal DDP performance."
  96. )
  97. try:
  98. from tabulate import tabulate
  99. log.debug(
  100. "\nDDPOptimizer produced the following bucket assignments:\n%s",
  101. tabulate(rows, headers=headers, tablefmt="simple_grid"),
  102. )
  103. if extended_buckets:
  104. log.warning(
  105. "DDPOptimizer extended these buckets to ensure per-subgraph output nodes:\n%s",
  106. tabulate(
  107. extended_buckets,
  108. headers=("Index", "Extra Ops", "Extra Param Size (b)"),
  109. tablefmt="simple_grid",
  110. ),
  111. )
  112. except ImportError:
  113. log.debug(
  114. "Please `pip install tabulate` in order to display ddp bucket sizes and diagnostic information."
  115. )
  116. else:
  117. log.debug("DDPOptimizer captured no parameters and did not split this graph.")
  118. def has_higher_order_op(gm: fx.GraphModule) -> bool:
  119. # Check if there is a higher order op in the graph
  120. for node in gm.graph.nodes:
  121. if node.op == "get_attr":
  122. maybe_param = getattr(gm, node.target)
  123. if isinstance(maybe_param, torch.fx.GraphModule):
  124. return True
  125. return False
  126. def propagate_metadata(orig_gm: fx.GraphModule, split_gm: fx.GraphModule) -> None:
  127. for name, module in split_gm.named_modules():
  128. if "." not in name and len(name):
  129. # TODO: add split id to CompileId: https://github.com/pytorch/tlparse/pull/83/files#r1880649384
  130. module.meta = orig_gm.meta
  131. module._param_name_to_source = orig_gm._param_name_to_source
  132. def propagate_dynamo_source(orig_gm: fx.GraphModule, split_gm: fx.GraphModule) -> None:
  133. name_to_dynamo_source = {}
  134. for node in orig_gm.graph.find_nodes(op="placeholder"):
  135. name_to_dynamo_source[node.name] = node._dynamo_source
  136. for name, module in split_gm.named_modules():
  137. if "." not in name and len(name):
  138. for node in module.graph.find_nodes(op="placeholder"):
  139. # non-placeholder in original_gm may become placeholder in submodules
  140. node._dynamo_source = name_to_dynamo_source.get(node.name)
  141. class DDPOptimizerContext:
  142. def __init__(self) -> None:
  143. self.curr_bucket: int = -1
  144. self.metadata_per_bucket: list[ViewAndMutationMeta] = []
  145. # compile each of the partitioned submodules using the user-provided compiler
  146. class SubmodCompiler(torch.fx.interpreter.Interpreter):
  147. def __init__(
  148. self,
  149. module: fx.GraphModule,
  150. compiler: CompilerFn,
  151. fake_mode: torch._subclasses.fake_tensor.FakeTensorMode,
  152. ) -> None:
  153. super().__init__(module)
  154. self.compiler = compiler
  155. self.fake_mode = fake_mode
  156. # See Note [DDPOptimizer and fw_metadata]
  157. ctx = torch._guards.TracingContext.try_get()
  158. if ctx is not None:
  159. ctx.ddp_optimizer_ctx = DDPOptimizerContext()
  160. def compile_submod(
  161. self, input_mod: fx.GraphModule, args: list[torch.Tensor], kwargs: Any
  162. ) -> Any:
  163. """
  164. Compile the submodule,
  165. using a wrapper to make sure its output is always a tuple,
  166. which is required by AotAutograd based compilers
  167. """
  168. assert len(kwargs) == 0, "We assume only args for these modules"
  169. class WrapperModule(torch.nn.Module):
  170. def __init__(
  171. self, submod: Callable[..., Any], unwrap_singleton_tuple: bool
  172. ) -> None:
  173. super().__init__()
  174. self.submod = submod
  175. self.unwrap_singleton_tuple = unwrap_singleton_tuple
  176. def forward(self, *args: Any) -> Any:
  177. x = self.submod(*args)
  178. # TODO(whc)
  179. # for some reason the isinstance check is necessary if I split one node per submod
  180. # - even though I supposedly wrapped the output in a tuple in those cases, the real
  181. # compiled module was still returning a tensor
  182. if self.unwrap_singleton_tuple and isinstance(x, (tuple, list)):
  183. return x[0]
  184. return x
  185. unwrap_singleton_tuple = False
  186. for sn in input_mod.graph.nodes:
  187. if sn.op == "output":
  188. if not isinstance(sn.args[0], tuple):
  189. unwrap_singleton_tuple = True
  190. sn.args = (sn.args,)
  191. input_mod.recompile()
  192. input_mod.compile_subgraph_reason = GraphCompileReason( # type: ignore[assignment]
  193. "DDPOptimizer intentional graph-break (See Note [DDPOptimizer])."
  194. " Set `torch._dynamo.config.optimize_ddp = False` to disable.",
  195. [
  196. # it's close to useless to get a real stacktrace here, and quite verbose.
  197. traceback.FrameSummary(__file__, 0, "DDPOptimizer"),
  198. ],
  199. )
  200. wrapper = WrapperModule(
  201. self.compiler(input_mod, args),
  202. unwrap_singleton_tuple,
  203. )
  204. return wrapper
  205. # Note:
  206. #
  207. # The way distributed works today around fake tensors can be somewhat confusing.
  208. # Some of these codepaths are shared in both runtime, and compile time. The presence
  209. # of a fake_mode, read off of fake tensor inputs, dictates how we will operate.
  210. #
  211. # A few things to keep in mind:
  212. #
  213. # 1) We invoke `compile_submod` with a real module. The output of that gets stored
  214. # on the graph via `self.module.add_submodule(n.target, compiled_submod_real)`.
  215. #
  216. # 2) When running a call_module targeted node, if we have a fake_mode, we fakify the
  217. # module we got from self.fetch_attr(n.target). Regardless of fake_mode, we then execute it.
  218. #
  219. # 3) Fake tensors should always be around during compile time.
  220. #
  221. # 4) Fake tensors should never be around at runtime.
  222. #
  223. # 5) We end up with a compilation mode that takes a real submodule and fake tensors,
  224. # to match what aot_autograd expects. See Note: [Fake Modules and AOTAutograd]
  225. def run_node(self, n: Node) -> Any:
  226. args, kwargs = self.fetch_args_kwargs_from_env(n)
  227. new_args = []
  228. assert self.fake_mode
  229. for arg in args:
  230. if isinstance(arg, torch.Tensor) and not isinstance(
  231. arg, torch._subclasses.FakeTensor
  232. ):
  233. new_args.append(torch._dynamo.utils.to_fake_tensor(arg, self.fake_mode))
  234. else:
  235. new_args.append(arg)
  236. log.debug("run_node %s, %s got args %s", n.op, n.target, args_str(args))
  237. assert isinstance(args, tuple)
  238. assert isinstance(kwargs, dict)
  239. if n.op == "call_module":
  240. real_mod = self.fetch_attr(str(n.target))
  241. if self.fake_mode:
  242. curr_submod = deepcopy_to_fake_tensor(real_mod, self.fake_mode)
  243. else:
  244. curr_submod = real_mod
  245. ddp_graph_log.debug("\n---%s graph---\n%s", n.target, curr_submod.graph)
  246. # When calling the compiler on the submod, inputs (new_args) are expected to
  247. # be FakeTensors already since Dynamo would have made them FakeTensors in the
  248. # non-DDP flow. However, the parameters are _not_ expected to be FakeTensors,
  249. # since this wrapping happens during compilation
  250. # Note: Returning Fake Tensors on First AOT Autograd Call
  251. #
  252. # Inductor will optimize strides of outputs when it deems it profitable.
  253. # For instance, converting to channels last. When we split the graph here
  254. # into multiple inductor compilations, we need to make sure that the
  255. # output strides of one compilation is appropriately passed to the subsequent
  256. # compilations. However, the mapping from inductor output to dynamo output
  257. # is non-trivial due to aot_autograd's deduping, de-aliasing, mutation, re-writing,
  258. # subclass handling, etc. In order to replay all this logic we set a flag such that
  259. # the first invocation of inductor in aot_autograd will return Fake Tensors with
  260. # appropriate strides. Then, all of aot autograd's runtime logic is replayed.
  261. # This gives us the appropriately strided outputs here which will reflect runtime strides.
  262. class FakeifyFirstAOTInvocationGuard:
  263. def __init__(self) -> None:
  264. self.tc = torch._guards.TracingContext.try_get()
  265. assert self.tc
  266. self.tc.fakify_first_call = True
  267. def __del__(self) -> None:
  268. self.tc.fakify_first_call = False # type: ignore[union-attr]
  269. # For aot_eager and other backends, tracing context is not set
  270. has_tracing_context = torch._guards.TracingContext.try_get() is not None
  271. if has_tracing_context:
  272. g = FakeifyFirstAOTInvocationGuard() # noqa: F841
  273. from torch._dynamo.utils import counters
  274. init = counters["aot_autograd"]["total"]
  275. compiled_submod_real = self.compile_submod(real_mod, new_args, kwargs)
  276. # TODO - better way of doing this?
  277. # Only aot autograd handles fakifying first call
  278. invoked_aot_autograd = init != counters["aot_autograd"]["total"]
  279. # We update the original (outer) graph with a call into the compiled module
  280. # instead of the uncompiled one.
  281. self.module.delete_submodule(n.target) # type: ignore[operator]
  282. n.target = "compiled_" + n.target # type: ignore[operator]
  283. self.module.add_submodule(n.target, compiled_submod_real) # type: ignore[operator]
  284. # Finally, we have to produce inputs for use compiling the next submodule,
  285. # and these need to be FakeTensors, so we execute the module under fake_mode
  286. # Because parameters are not fake we patch fake tensor mode to allow non fake inputs
  287. with (
  288. self.fake_mode,
  289. mock.patch.object(self.fake_mode, "allow_non_fake_inputs", True),
  290. ):
  291. if has_tracing_context and invoked_aot_autograd:
  292. tracing_ctx = torch._guards.TracingContext.try_get()
  293. assert tracing_ctx is not None
  294. # DDPOptimizer maintains 1 dynamo graph -> N AOT graphs
  295. # Dynamo only has 1 tracing context, so it needs to maintain all N AOT metadata instances
  296. ddp_ctx = tracing_ctx.ddp_optimizer_ctx
  297. assert ddp_ctx is not None
  298. assert tracing_ctx.fw_metadata is not None
  299. ddp_ctx.curr_bucket += 1
  300. ddp_ctx.metadata_per_bucket.append(tracing_ctx.fw_metadata)
  301. out = compiled_submod_real(*new_args, **kwargs)
  302. # output should be fake or subclass
  303. assert all(
  304. (not isinstance(t, torch.Tensor) or type(t) is not torch.Tensor)
  305. for t in (out if isinstance(out, (list, tuple)) else [out])
  306. )
  307. return out
  308. else:
  309. return curr_submod(*new_args, **kwargs)
  310. else:
  311. # placeholder or output nodes don't need to get compiled, just executed
  312. return getattr(self, n.op)(n.target, new_args, kwargs)
  313. class DDPOptimizer:
  314. """Note [DDPOptimizer]
  315. DDPOptimizer applies when dynamo compiles models wrapped in DistributedDataParallel (DDP),
  316. breaking the dynamo graph into chunks to compile separately, with the breaks aligning to
  317. the boundaries of gradient-allreduce buckets chosen by DDP.
  318. Background/Motivation
  319. - DDP uses allreduce collectives to synchronize partial gradients computed on different workers
  320. - DDP groups gradient allreduces into 'buckets' to optimize communication efficiency of all-reduce
  321. - Parameters grouped into buckets are assumed to be adjacent in time, so they become ready
  322. at around the same time during backward and thus can share the same allreduce efficiently
  323. - Allreduces must overlap with backward compute for optimal training performance
  324. - DDP schedules allreduces using 'hooks' fired from the c++ autograd engine in pytorch, which
  325. operates when individual grads become 'ready'
  326. - Dynamo+AOTAutograd produces a single fused graph that runs 'atomically' from the perspective of the
  327. autograd engine, such that all gradients become 'ready' at the same time. Hooks fire after the whole
  328. fused backward function executes, preventing any overlap of compute and communication
  329. Algorithm
  330. - DDPOptimizer starts off with an FX graph traced by dynamo which represents forward. It can traverse
  331. this graph in reverse order to determine the true order that gradients will become ready during backward.
  332. - Parameter sizes are counted in reverse order, up to a bucket size limit, at which point a new bucket is started
  333. and a graph break introduced
  334. - Each of the subgraphs is compiled by the compiler provided to dynamo by the user, and then fused back together
  335. into an outer module that is returned to the user
  336. Notes
  337. - It would be better to enforce (by adding an API to DDP) that the bucket splits chosen here are used by DDP,
  338. and that DDP does not need to detect or optimize bucket order by observing execution at runtime, as it does
  339. in eager.
  340. - If Dynamo can't capture a whole graph for the portion of the model wrapped by DDP, this algorithm will currently
  341. produce splits that do not necessarily align with the buckets used by DDP. This should result in performance
  342. degradation approaching the baseline case where graph-splits are not used, but not worse.
  343. - If the backend compiler fails to compile a single subgraph, it will execute eagerly despite the rest of the
  344. subgraphs being compiled
  345. - DDP has a 'parameters_and_buffers_to_ignore' field, which DDPOptimizer attempts to honor by reading markers
  346. left by DDP on individual parameters. In cases where other transformations, such as reparameterization, are
  347. also used, the ignore markers could be lost. If DDPOptimizer fails to ignore a parameter ignored by DDP,
  348. it is not catastrophic but could impact performance by choosing sub-optimal bucket splits.
  349. - DDPOptimizer always ignores all buffers, regardless of their ignore flag, since buffers do not require gradients,
  350. and therefore aren't allreduced by DDP. (They are broadcast during forward, but this is not covered by
  351. DDPOptimizer)
  352. Debugging
  353. - Generally, it is easiest to debug DDPOptimizer in a single process program, using pdb.
  354. - In many cases, the log messages are helpful (they show bucket size assignments)-
  355. just set TORCH_LOGS env to include any of 'dynamo', 'distributed', or 'dist_ddp'.
  356. - See `benchmarks/dynamo/distributed.py` for a simple harness that will run a toy model or a torchbench model
  357. in a single process (or with torchrun, in multiple processes)
  358. Args:
  359. bucket_bytes_cap (int): Controls the size of buckets, in bytes, used to determine graphbreaks. Should be
  360. set to match the equivalent parameter on the original DDP module.
  361. backend_compile_fn (callable): A dynamo compiler function, to be invoked to compile each subgraph.
  362. first_bucket_cap (int): Controls the size of the first bucket. Should match DDP's first bucket cap. DDP
  363. special-cases the first bucket size since it is sometimes optimal to start a small allreduce early.
  364. """
  365. def __init__(
  366. self,
  367. bucket_bytes_cap: int,
  368. backend_compile_fn: CompilerFn,
  369. first_bucket_cap: Optional[int] = None,
  370. ) -> None:
  371. if first_bucket_cap is not None:
  372. self.first_bucket_cap = first_bucket_cap
  373. elif torch.distributed.is_available():
  374. # this constant comes from C10D lib which is not always built
  375. self.first_bucket_cap = torch.distributed._DEFAULT_FIRST_BUCKET_BYTES
  376. else:
  377. self.first_bucket_cap = bucket_bytes_cap
  378. self.bucket_bytes_cap = bucket_bytes_cap
  379. assert self.first_bucket_cap <= self.bucket_bytes_cap, (
  380. "First bucket should be smaller/equal to other buckets to get comms warmed up ASAP"
  381. )
  382. self.backend_compile_fn = backend_compile_fn
  383. def _ignore_parameter(self, parameter: torch.nn.Parameter) -> bool:
  384. return hasattr(parameter, "_ddp_ignored") and parameter._ddp_ignored
  385. def add_param(self, bucket: Bucket, param: torch.nn.Parameter, name: str) -> None:
  386. bucket.size += param.untyped_storage().nbytes()
  387. bucket.params.append(name)
  388. bucket.param_ids.append(id(param))
  389. def add_module_params_to_bucket(
  390. self,
  391. mod: torch.nn.Module,
  392. bucket: Bucket,
  393. processed_modules: set[torch.nn.Module],
  394. prefix: str,
  395. ) -> None:
  396. processed_modules.add(mod)
  397. for name, param in mod.named_parameters():
  398. if param.requires_grad and not self._ignore_parameter(param):
  399. self.add_param(bucket, param, f"{prefix}_{name}")
  400. def add_param_args(self, bucket: Bucket, node: fx.Node) -> None:
  401. for arg in node.args:
  402. if not isinstance(arg, torch.fx.node.Node):
  403. continue
  404. if arg.op != "placeholder":
  405. continue
  406. param = arg.meta["example_value"]
  407. if (
  408. isinstance(param, torch.nn.Parameter)
  409. and param.requires_grad
  410. and not self._ignore_parameter(param)
  411. ):
  412. self.add_param(bucket, param, str(arg.target))
  413. def compile_fn(
  414. self, gm: fx.GraphModule, example_inputs: list[torch.Tensor]
  415. ) -> CompiledFn:
  416. """
  417. Implements graph splitting, first determining a set of of buckets by counting
  418. parameter sizes in reverse graph order, then invoking the user/backend compiler
  419. to compile each subgraph. Finally, stitches compiled graphs into one graphmodule
  420. and returns its callable.
  421. """
  422. # 1: compute the partition map according to DDP bucket logic
  423. buckets = [Bucket()] # (size, param_names)
  424. processed_modules: set[torch.nn.Module] = set()
  425. for node in reversed(gm.graph.nodes):
  426. if node.op in ("output", "placeholder"):
  427. continue
  428. if (
  429. buckets[0].size >= self.bucket_bytes_cap
  430. or len(buckets) == 1
  431. and buckets[0].size >= self.first_bucket_cap
  432. ):
  433. if bucket_has_external_output(buckets[0]):
  434. buckets.insert(0, Bucket())
  435. else:
  436. # continue building this bucket past the point of filling its parameter capacity,
  437. # to increase chances it contains at least one node that is either a global output or
  438. # passed as input to a subsequent graph
  439. if buckets[0].opcount_increased_to_capture_external_output == 0:
  440. buckets[0].paramsize_before_opcount_increase = buckets[0].size
  441. buckets[0].opcount_increased_to_capture_external_output += 1
  442. if node.op == "call_function":
  443. self.add_param_args(buckets[0], node)
  444. elif node.op == "call_module":
  445. target_mod = gm.get_submodule(node.target)
  446. if target_mod not in processed_modules:
  447. self.add_module_params_to_bucket(
  448. target_mod, buckets[0], processed_modules, node.target
  449. )
  450. elif node.op == "call_method":
  451. if isinstance(node.args[0].target, str):
  452. target_mod = None
  453. try:
  454. target_mod = gm.get_submodule(node.args[0].target)
  455. except AttributeError:
  456. pass
  457. if target_mod is not None and target_mod not in processed_modules:
  458. self.add_module_params_to_bucket(
  459. target_mod, buckets[0], processed_modules, node.target
  460. )
  461. # This handles situations like tmp = torch.mm(x, self.weight.t())
  462. # t: "f32[512, 512]" = l_self_seq_2_weight.t(); l_self_seq_2_weight = None
  463. # tmp: "f32[512, 512]" = torch.mm(input_2, t); input_2 = t = None
  464. self.add_param_args(buckets[0], node)
  465. elif node.op == "get_attr":
  466. maybe_param = getattr(gm, node.target)
  467. if (
  468. isinstance(maybe_param, torch.nn.Parameter)
  469. and maybe_param.requires_grad
  470. and not self._ignore_parameter(maybe_param)
  471. ):
  472. self.add_param(buckets[0], maybe_param, node.target)
  473. # All nodes have to be mapped to a bucket, even if they don't have their own params
  474. # Ignored params still end up in buckets, we just don't count them towards the capacity
  475. buckets[0].nodes.append(node)
  476. if len(buckets) > 1 and buckets[0].size == 0:
  477. # we collected a small preamble graph with ops that don't include parameters, fuse it back
  478. buckets[1].nodes.extend(buckets[0].nodes)
  479. assert len(buckets[0].params) == 0, "Params should be empty if size is 0"
  480. del buckets[0]
  481. # stash buckets for testing/debugging purposes
  482. self.buckets = buckets
  483. pretty_print_buckets(buckets, self.bucket_bytes_cap)
  484. if len(buckets) == 1:
  485. # bypass split/fuse logic if there is only one bucket
  486. return self.backend_compile_fn(gm, example_inputs)
  487. # 2: partition the graphmodule according to bucket capacity
  488. partition_map = {}
  489. for idx, b in enumerate(buckets):
  490. for node in b.nodes:
  491. partition_map[node] = idx
  492. split_gm = fx.passes.split_module.split_module(
  493. gm,
  494. None, # type: ignore[arg-type]
  495. lambda node: partition_map[node],
  496. )
  497. # See note [Assumption on Dynamo Metadata]
  498. propagate_dynamo_source(gm, split_gm)
  499. propagate_metadata(gm, split_gm)
  500. debug_str = (
  501. f"\n---orig graph---\n{gm.graph}\n"
  502. + f"\n---split graph---\n{split_gm.graph}\n"
  503. )
  504. for name, module in split_gm.named_modules():
  505. if "." not in name and len(name):
  506. # only print the submod graphs, not their children
  507. debug_str += f"\n---{name} graph---\n{module.graph}\n"
  508. debug_str += "\n---------------\n"
  509. ddp_graph_log.debug(debug_str)
  510. trace_structured(
  511. "optimize_ddp_split_graph",
  512. payload_fn=lambda: split_gm.print_readable(print_output=False),
  513. )
  514. for name, module in split_gm.named_modules():
  515. if "." not in name and len(name):
  516. trace_structured(
  517. "optimize_ddp_split_child",
  518. lambda: {"name": name},
  519. payload_fn=lambda: module.print_readable(print_output=False),
  520. )
  521. fake_mode = detect_fake_mode(example_inputs)
  522. if fake_mode is None:
  523. fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()
  524. submod_compiler = SubmodCompiler(split_gm, self.backend_compile_fn, fake_mode)
  525. with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
  526. submod_compiler.run(*example_inputs)
  527. split_gm.recompile()
  528. ddp_graph_log.debug(
  529. "\n---final graph---\n%s\n---------------\n", split_gm.graph
  530. )
  531. return split_gm