compiled_autograd.py 63 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634
  1. """
  2. Provides functionality for compiling PyTorch's autograd (automatic differentiation) system.
  3. This module implements compiled autograd, which traces and optimizes backward pass
  4. computations at runtime. The key components are:
  5. - AutogradCompilerInstance: Traces and compiles autograd graphs using FX
  6. - Context managers (_enable/_disable): Control when compiled autograd is active
  7. - Utility functions: Support graph manipulation, tensor operations, and hooks
  8. Compiled autograd can significantly improve backward pass performance by removing
  9. Python overhead and enabling additional optimizations. It works by capturing
  10. backward computations into an FX graph that can be compiled and optimized,
  11. while maintaining the same semantics as eager mode autograd.
  12. """
  13. import contextlib
  14. import functools
  15. import itertools
  16. import operator
  17. import time
  18. from collections import Counter, defaultdict
  19. from collections.abc import Callable, Generator, Sequence
  20. from typing import Any, Optional, TYPE_CHECKING, Union
  21. import torch
  22. import torch.utils._pytree as pytree
  23. from torch._dispatch.python import enable_python_dispatcher
  24. from torch._dynamo.external_utils import (
  25. call_accumulate_grad,
  26. call_backward,
  27. call_hook,
  28. FakeCompiledAutogradEngine,
  29. unwrap_maybe_dynamic_int,
  30. )
  31. from torch._dynamo.source import GetItemSource, LocalSource
  32. from torch._dynamo.utils import (
  33. counters,
  34. get_chromium_event_logger,
  35. lazy_format_graph_code,
  36. set_locals_to_steal,
  37. )
  38. from torch._functorch._aot_autograd.runtime_wrappers import (
  39. AutogradLazyBackwardCompileInfo,
  40. CachedAutogradLazyBackwardCompileInfo,
  41. )
  42. from torch._guards import compile_context, CompileContext, CompileId, Source
  43. from torch._logging import getArtifactLogger, trace_structured
  44. from torch._prims_common import clone_preserve_strides
  45. from torch._subclasses import FakeTensorMode
  46. from torch._subclasses.fake_tensor import FakeTensor
  47. from torch.fx import GraphModule
  48. from torch.fx.experimental._backward_state import BackwardState
  49. from torch.fx.experimental.proxy_tensor import (
  50. decompose,
  51. disable_autocast_cache,
  52. disable_proxy_modes_tracing,
  53. fetch_object_proxy,
  54. ProxyTorchDispatchMode,
  55. PythonKeyTracer,
  56. track_tensor_tree,
  57. )
  58. from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv
  59. from torch.fx.traceback import preserve_node_meta, set_stack_trace
  60. from torch.types import FloatLikeType, IntLikeType
  61. from torch.utils._ordered_set import OrderedSet
  62. from torch.utils._traceback import CapturedTraceback
  63. if TYPE_CHECKING:
  64. from torch.fx.proxy import Proxy
# Appended to the user-facing NotImplementedError / AssertionError messages
# raised in this module to tell users how to opt out of compiled autograd.
TURN_OFF_MSG = """You can turn off compiled autograd by either:
1. Moving the unsupported autograd call outside of the torch.compile'd region.
2. Wrapping the unsupported autograd call in the torch._dynamo.compiled_autograd._disable() context manager.
3. Setting torch._dynamo.config.compiled_autograd=False for the torch.compile call containing the unsupported autograd call.
4. Setting torch._dynamo.config.compiled_autograd=False at the start of the program."""
# Artifact logger for the "compiled_autograd" logging artifact.
compiled_autograd_log = getArtifactLogger(__name__, "compiled_autograd")
# Artifact logger for the "compiled_autograd_verbose" artifact; whether that
# artifact is enabled is queried by snapshot_verbose_logging_enabled().
verbose_log = getArtifactLogger(__name__, "compiled_autograd_verbose")
  72. def snapshot_verbose_logging_enabled() -> bool:
  73. return torch._logging._internal.log_state.is_artifact_enabled(
  74. "compiled_autograd_verbose"
  75. )
  76. def snapshot_cudagraph_enabled() -> bool:
  77. return torch._inductor.config.triton.cudagraphs
  78. def maybe_clone(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
  79. if x is not None:
  80. return clone_preserve_strides(x)
  81. return x
  82. def extract_bw_module(CompiledFunction: Any) -> Callable[..., Any]:
  83. if isinstance(
  84. CompiledFunction._lazy_backward_info, AutogradLazyBackwardCompileInfo
  85. ):
  86. return CompiledFunction._lazy_backward_info.bw_module
  87. elif isinstance(
  88. CompiledFunction._lazy_backward_info, CachedAutogradLazyBackwardCompileInfo
  89. ):
  90. with torch._subclasses.fake_tensor.unset_fake_temporarily():
  91. return CompiledFunction._lazy_backward_info.bw_module_fn()
  92. else:
  93. raise AssertionError(
  94. "Unexpected Lazy Backward Compilation Info Type. Please file an issue."
  95. )
  96. # Note: [Anomaly Mode Semantics in Compiled Autograd]
  97. # In the eager autograd engine, anomaly mode is able to detect NaNs
  98. # after each node. This is useful, because the executed code with
  99. # and without anomaly mode are the same. So assuming determinism,
  100. # a NaN in regular mode should also happen in anomaly mode.
  101. #
  102. # With torch.compile, following eager semantics would require inserting
  103. # runtime asserts to check for NaNs, which could prevent some fusions.
  104. # This results in different code being run with and without anomaly mode.
  105. # So different semantics are needed, this implementation below will check
  106. # for NaNs at the end of the autograd call, instead of after each node
  107. class NaNChecker:
  108. def __init__(self, accumulate_grad: bool) -> None:
  109. self.accumulate_grad = accumulate_grad
  110. self.params_indices: list[int] = []
  111. self.params_to_check: dict[str, torch.Tensor] = {}
  112. self.output_names: list[str] = []
  113. def prep_with_graph(self, graph: torch.fx.Graph) -> None:
  114. inputs_node = next(iter(graph.nodes))
  115. acc_grad_nodes = graph.find_nodes(
  116. op="call_function", target=call_accumulate_grad
  117. )
  118. output_nodes = graph.find_nodes(op="output")[0].args[0]
  119. assert self.accumulate_grad == bool(
  120. acc_grad_nodes
  121. ) and self.accumulate_grad == (not output_nodes)
  122. for node in acc_grad_nodes:
  123. param_node = node.args[0]
  124. # AccumulateGrad always saves a reference to the param
  125. # so Compiled Autograd will always lift the param and
  126. # this should always be true
  127. assert (
  128. param_node.target is operator.getitem
  129. and param_node.args[0] is inputs_node # type: ignore[possibly-undefined]
  130. and isinstance(param_node.args[1], int)
  131. )
  132. self.params_indices.append(param_node.args[1])
  133. self.output_names = [node.name for node in output_nodes]
  134. def prep_with_inputs(self, inputs: tuple[torch.Tensor, ...]) -> None:
  135. if not self.accumulate_grad:
  136. # Using .grad, nothing to prep
  137. return
  138. # Using .backward, we must check existing grads on params if any
  139. for idx in self.params_indices:
  140. grad = inputs[idx].grad
  141. if grad is not None:
  142. assert not torch.isnan(grad).any(), (
  143. f"Compiled autograd running under anomaly mode with inputs[{idx}] already "
  144. f"having NaN gradient. This is not supported. {TURN_OFF_MSG}"
  145. )
  146. self.params_to_check[f"inputs[{idx}]"] = inputs[idx]
  147. def check(self, out: tuple[torch.Tensor, ...]) -> None:
  148. if self.accumulate_grad:
  149. # Using .backward, graph outputs are empty
  150. assert not out
  151. nan_params: list[str] = []
  152. for inputs_str, param in self.params_to_check.items():
  153. assert param.grad is not None # not true for autograd.grad
  154. if torch.isnan(param.grad).any():
  155. nan_params.append(inputs_str)
  156. if nan_params:
  157. raise RuntimeError(
  158. f"Compiled Autograd returned NaN gradients for parameters: {','.join(nan_params)}."
  159. )
  160. else:
  161. # Using .grad, graph outputs are grads
  162. nan_grads: list[str] = []
  163. for i, grad in enumerate(out):
  164. if torch.isnan(grad).any():
  165. nan_grads.append(self.output_names[i])
  166. if nan_grads:
  167. raise RuntimeError(
  168. f"Compiled Autograd returned NaN gradients for output nodes: {','.join(nan_grads)}."
  169. )
  170. # We lazily bind "functional backward" variants for PyTorch built-in autograd
  171. # nodes to this class. Example: torch._dynamo.compiled_autograd.ops.MulBackward0
  172. # Each "functional backward" is bound the first time the node's apply_with_saved
  173. # function is called. It's possible to avoid lazy binding and instead bind
  174. # all of this upfront (perhaps at import time) via codegen changes.
  175. class OpNamespace:
  176. def __init__(self) -> None:
  177. self.custom_function_name_counter: Counter[str] = Counter()
  178. def add(
  179. self,
  180. name: str,
  181. fn: Callable[..., Any],
  182. is_custom_function: bool,
  183. is_traceable: bool,
  184. ) -> str:
  185. if is_custom_function:
  186. name = "CppNode" + name
  187. count = self.custom_function_name_counter[name]
  188. self.custom_function_name_counter[name] += 1
  189. name = f"{name}{count}"
  190. assert not hasattr(self, name)
  191. result = Op(name, fn, is_custom_function)
  192. if is_traceable:
  193. setattr(self, name, torch._dynamo.allow_in_graph(result))
  194. else:
  195. # C++ autograd function was not marked as traceable
  196. # Dynamo can't dry run it at compile time, so must fallback to eager
  197. @torch._dynamo.disable # type: ignore[misc]
  198. def run_non_traceable_cpp_in_eager(*args: Any, **kwargs: Any) -> Any:
  199. return result(*args, **kwargs)
  200. setattr(self, name, run_non_traceable_cpp_in_eager)
  201. return name
  202. def get(self, name: str) -> Any:
  203. return getattr(self, name)
  204. class Op:
  205. def __init__(
  206. self, name: str, fn: Callable[..., Any], is_custom_function: bool
  207. ) -> None:
  208. self.fn = fn
  209. self.is_custom_function = is_custom_function
  210. self.__name__ = name
  211. self.__module__ = "torch._dynamo.compiled_autograd.ops"
  212. def __call__(self, *args: Any, **kwargs: Any) -> Any:
  213. return self.fn(*args, **kwargs)
  214. def __repr__(self) -> str:
  215. return self.__module__ + "." + self.__name__
# Module-level namespace for functional-backward callables:
# AutogradCompilerInstance.bind_function adds to it via ops.add, and
# apply_functional / validate_outputs look entries up via ops.get.
ops = OpNamespace()
# Placeholders created, in this order, at the start of every compiled autograd
# graph. begin_capture unpacks the created proxies positionally, so the order
# is load-bearing.
_graph_placeholders = ["inputs", "sizes", "scalars", "hooks", "packed_data"]
# Call targets considered "impure" (they run hooks / user backward / grad
# accumulation / final callbacks). NOTE(review): presumably used elsewhere in
# this file to keep these nodes from being dead-code-eliminated — usage not
# visible in this chunk.
_impure_targets = OrderedSet(
    [
        call_hook,
        call_backward,
        FakeCompiledAutogradEngine._exec_final_callbacks_stub,
        call_accumulate_grad,
    ]
)

# Monotonic id source; begin_capture draws each capture's ``self.id`` from it.
COMPILE_COUNTER = itertools.count()
  227. def make_compile_context(compiled_autograd_id: int) -> Any:
  228. return compile_context(
  229. CompileContext(
  230. CompileId(
  231. compiled_autograd_id=compiled_autograd_id,
  232. frame_id=None,
  233. frame_compile_id=None,
  234. )
  235. )
  236. )
  237. class AutogradCompilerInstance:
  238. def __init__(self, compiler_fn: Callable[..., Any]) -> None:
  239. self.compiler_fn = compiler_fn
  240. self.stack = contextlib.ExitStack()
  241. self.close = self.stack.close
  242. self.shape_env = ShapeEnv()
  243. self.fake_tensor_mode = FakeTensorMode(
  244. allow_fallback_kernels=True,
  245. allow_non_fake_inputs=True,
  246. shape_env=self.shape_env,
  247. )
  248. self.fx_tracer = PythonKeyTracer()
  249. self.proxy_mode = ProxyTorchDispatchMode(self.fx_tracer, "symbolic")
  250. self.hooks_proxy: Optional[Proxy] = None
  251. def wrap_fake(self, x: torch.Tensor, source: Optional[Source]) -> FakeTensor:
  252. assert isinstance(x, torch.Tensor)
  253. return self.fake_tensor_mode.from_tensor(x, source=source)
  254. @staticmethod
  255. def source(name: str, idx: Any) -> GetItemSource:
  256. return GetItemSource(LocalSource(name), idx)
  257. def begin_capture(
  258. self,
  259. inputs: list[torch.Tensor],
  260. sizes: list[int],
  261. scalars: list[Union[int, float]],
  262. origins: list[list[tuple[int, str]]],
  263. accumulate_grad: bool,
  264. check_nans: bool,
  265. ) -> tuple[str, list[torch.Tensor], list[IntLikeType], list[FloatLikeType]]:
  266. counters["compiled_autograd"]["captures"] += 1
  267. self.id = next(COMPILE_COUNTER)
  268. self.aot_id_counter: dict[int, int] = defaultdict(int)
  269. self.compile_context = make_compile_context(self.id)
  270. self.compile_context.__enter__()
  271. self.nan_checker = NaNChecker(accumulate_grad) if check_nans else None
  272. self.start_time_ns = time.time_ns()
  273. get_chromium_event_logger().log_event_start(
  274. "compiled_autograd",
  275. self.start_time_ns,
  276. {"graph_id": self.id},
  277. log_pt2_compile_event=True,
  278. )
  279. self.fx_tracer.root = torch.nn.Module()
  280. self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer)
  281. self.fx_tracer.tensor_attrs = {}
  282. self.symnode_proxy_lookup = {}
  283. (
  284. args_proxy,
  285. self.sizes_proxy,
  286. self.scalars_proxy,
  287. self.hooks_proxy,
  288. self.packed_data_proxy,
  289. ) = (
  290. self.fx_tracer.create_proxy("placeholder", name, (), {})
  291. for name in _graph_placeholders
  292. )
  293. self.stack.enter_context(preserve_node_meta())
  294. inputs_origins, sizes_origins, scalars_origins = origins
  295. # Turn on PythonDispatcher during initial trace to make it identifiable
  296. # that tracing is happening, which is needed to prevent hashing symints
  297. self.stack.enter_context(enable_python_dispatcher())
  298. # tensor inputs to fake tensors
  299. x = inputs[0] # mypy will complain about unbound x
  300. try:
  301. for idx, x in enumerate(inputs):
  302. inputs[idx] = self.wrap_fake(x, self.source("inputs", idx))
  303. except Exception as e:
  304. raise NotImplementedError(
  305. f"Found tensor of type {type(x)}, which is not supported by FakeTensorMode. {TURN_OFF_MSG}"
  306. ) from e
  307. self.bind_objects_to_proxies(inputs, args_proxy, inputs_origins)
  308. # size inputs to symints
  309. sym_sizes = [
  310. self.shape_env.create_unspecified_symint_and_symbol(
  311. val,
  312. self.source("sizes", idx),
  313. DimDynamic.DYNAMIC,
  314. )
  315. for idx, val in enumerate(sizes)
  316. ]
  317. # We want to mark every size as dynamic, but since there's no way to
  318. # mark a primitive `int` as dynamic, we need to wrap it in a tensor.
  319. # In the graph, we unwrap it with `unwrap_maybe_dynamic_int` back into a primitive.
  320. proxies = [self.sizes_proxy[i] for i in range(len(sym_sizes))] # type: ignore[index]
  321. for i, symint in enumerate(sym_sizes):
  322. proxies[i] = self.fx_tracer.create_proxy(
  323. "call_function",
  324. unwrap_maybe_dynamic_int,
  325. (proxies[i],),
  326. {},
  327. )
  328. self.symnode_proxy_lookup[symint.node] = proxies[i]
  329. proxies = self.bind_objects_to_proxies(sym_sizes, proxies, sizes_origins)
  330. for idx, val in enumerate(scalars):
  331. source = self.source("scalars", idx)
  332. if isinstance(val, int):
  333. scalars[idx] = self.shape_env.create_unspecified_symint_and_symbol(
  334. val,
  335. source,
  336. DimDynamic.DYNAMIC,
  337. )
  338. elif isinstance(val, float):
  339. scalars[idx] = self.shape_env.create_symfloatnode(
  340. self.shape_env.create_unspecified_symbol(
  341. val,
  342. source=source,
  343. dynamic_dim=DimDynamic.DYNAMIC,
  344. ),
  345. hint=val,
  346. source=source,
  347. )
  348. else:
  349. raise AssertionError("Unexpected scalar type: ", type(val))
  350. self.bind_objects_to_proxies(scalars, self.scalars_proxy, scalars_origins)
  351. for i, symval in enumerate(scalars):
  352. self.symnode_proxy_lookup[symval.node] = self.scalars_proxy[i] # type: ignore[union-attr]
  353. # TODO(jansel): are all these modes needed?
  354. self.stack.enter_context(decompose({}))
  355. self.stack.enter_context(self.fake_tensor_mode)
  356. self.stack.enter_context(self.proxy_mode)
  357. self.stack.enter_context(disable_autocast_cache())
  358. # Needed to make sure we don't accidentally specialize any symbols
  359. assert self.fake_tensor_mode.shape_env is not None
  360. env = self.fake_tensor_mode.shape_env
  361. self.stack.enter_context(
  362. torch.fx.experimental.symbolic_shapes._suppress_guards(env)
  363. )
  364. # pyrefly: ignore [bad-return]
  365. return (
  366. str(CompileContext.current_compile_id()),
  367. inputs,
  368. sym_sizes,
  369. scalars, # type: ignore[return-value]
  370. )
  371. def log_compile_reasons(
  372. self,
  373. compile_reasons: list[str],
  374. ) -> None:
  375. assert compile_reasons
  376. trace_structured(
  377. "artifact",
  378. metadata_fn=lambda: {
  379. "name": "compiled_autograd_compile_reasons",
  380. "encoding": "json",
  381. },
  382. payload_fn=lambda: compile_reasons,
  383. )
    def proxy_call_aot_backward(
        self,
        pinputs: Sequence[Any],
        psaved_tensors: Sequence[torch.Tensor],
        saved_tensors: Sequence[torch.Tensor],
        pctx: Any,
        ctx: Any,
        maybe_backward_state_idx: Optional[int],
        opaque_object_indices: list[int],
    ) -> Sequence[Any]:
        """Trace an AOTAutograd-compiled backward into the compiled autograd graph.

        Records an ``allow_in_graph`` prologue call, copies every node of the
        AOT backward module (from ``extract_bw_module``) into
        ``self.fx_tracer.graph`` under renamed node names, then runs the AOT
        epilogue on dummy tensors bound to the copied outputs.

        Args:
            pinputs: proxies of this node's backward inputs, passed to the
                prologue after the saved tensors / symints / opaque objects.
            psaved_tensors: proxies of the saved tensors.
            saved_tensors: not referenced in this method.
            pctx: not referenced in this method.
            ctx: backward ctx. ``ctx._forward_cls`` supplies the AOT metadata
                and ``_aot_id``; ``ctx._get_compiled_autograd_symints()``
                supplies the symints.
            maybe_backward_state_idx: index into the hooks placeholder of the
                BackwardState argument, or None when there is none.
            opaque_object_indices: indices into the hooks placeholder of opaque
                objects passed to the prologue.

        Returns:
            The epilogue's results mapped through ``self.to_proxy``.

        Raises:
            RuntimeError: if grad mode is enabled and any forward output
                requires grad (higher order gradients are unsupported).
        """
        # The AOTBackward call consists of three things: the prologue, the
        # backward graph, and the epilogue.
        # Our strategy is:
        # - allow_in_graph the prologue (in the CA graph and Dynamo graph),
        # - copy-paste the backward graph into the CA graph so that CA passes and Dynamo can see it
        # - trace directly through the epilogue. Anything that gets baked in is
        #   constant metadata (for example, metadata about the number of outputs, or removing
        #   RNG arguments or effect tokens).
        # If Dynamo graph capture were better, then we could add a node for the prologue
        # into the CA graph and have Dynamo trace into it.

        psymints = [self.to_proxy(e) for e in ctx._get_compiled_autograd_symints()]
        popaque_objects = [
            self.hooks_proxy[idx]  # type: ignore[index]
            for idx in opaque_object_indices
        ]

        # NOTE: we should only close over constants
        # (the nested functions below capture metadata / maybe_subclass_metadata /
        # aot_id / bw_module, never the CompiledFunction class itself).
        CompiledFunction = ctx._forward_cls
        bw_module = extract_bw_module(CompiledFunction)
        metadata = CompiledFunction.metadata
        maybe_subclass_metadata = CompiledFunction.maybe_subclass_metadata
        aot_id = CompiledFunction._aot_id
        del CompiledFunction

        if torch.is_grad_enabled():
            for output_alias_info in metadata.output_info:
                if output_alias_info.requires_grad:
                    raise RuntimeError(
                        "torch.compile does not currently support higher order gradients."
                    )

        # Recorded as a single opaque call_function node; its outputs are
        # indexed below (pgrads[i]) to feed the copied backward graph.
        @torch._dynamo.allow_in_graph  # type: ignore[misc]
        def call_aot_bwd_prologue(
            ctx_saved_tensors: Sequence[torch.Tensor],
            ctx_symints: Sequence[IntLikeType],
            ctx_opaque_objs: Sequence[Any],
            *flat_args: Sequence[Any],
        ) -> Any:
            out = torch._functorch._aot_autograd.runtime_wrappers._backward_prologue_functional(
                ctx_saved_tensors,
                ctx_symints,
                ctx_opaque_objs,
                metadata,
                maybe_subclass_metadata,
                *flat_args,
            )
            return out

        pgrads = self.fx_tracer.create_proxy(
            kind="call_function",
            target=call_aot_bwd_prologue,
            args=(
                psaved_tensors,
                psymints,
                popaque_objects,
                *pinputs,
            ),
            kwargs={},
        )

        pbackward_state = None
        if maybe_backward_state_idx is not None:
            pbackward_state = self.hooks_proxy[maybe_backward_state_idx]  # type: ignore[index]

        # Copy-paste the AOT backward graph into the compiled autograd graph
        def copy_paste_aot_backward_graph() -> list[torch.Tensor]:
            # Number of leading placeholder nodes in ``graph``.
            def num_inputs(graph: torch.fx.Graph) -> int:
                num_args = 0
                for node in graph.nodes:
                    if node.op == "placeholder":
                        num_args += 1
                        continue
                    else:
                        break
                return num_args

            # set up the proxy inputs to bw_module
            # the calling convention is: [*symints, *args (primals and tangents), backward_state]
            num_args = num_inputs(bw_module.graph)  # type: ignore[attr-defined]
            # The prologue does not produce backward_state, so when present the
            # last placeholder is fed from the hooks proxy instead of pgrads.
            pall_args = [
                pgrads[i] for i in range(num_args - int(pbackward_state is not None))
            ]
            # replace the symints with our symints
            symints = ctx._get_compiled_autograd_symints()
            assert len(symints) == len(ctx.symints)
            psymints = [self.to_proxy(e) for e in symints]
            pall_args[: len(symints)] = psymints
            # Add backward_state
            if pbackward_state is not None:
                pall_args.append(pbackward_state)

            # run over all nodes of the aot_backward graph.
            # copy and paste them all into the compiled autograd graph.
            args_idx = 0
            # Maps bw_module nodes -> their counterparts in the CA graph.
            # pyrefly: ignore [implicit-any]
            value_remap = {}
            poutputs: Optional[list[torch.fx.Proxy]] = None

            # names of nodes must appear only once in the fx.Graph
            # dedup AOT backwards that appear multiple times
            deduped_aot_id = str(aot_id)
            if self.aot_id_counter[aot_id]:
                deduped_aot_id += f"_{self.aot_id_counter[aot_id]}"
            self.aot_id_counter[aot_id] += 1

            def make_unique(node_name: str) -> str:
                # make it both informative and unique
                return f"aot{deduped_aot_id}_{node_name}"

            for node in bw_module.graph.nodes:  # type: ignore[attr-defined]
                if node.op == "placeholder":
                    # Reuse the already-created input proxy node; only rename it.
                    ph = pall_args[args_idx].node
                    ph.name = make_unique(node.name)
                    value_remap[node] = ph
                    args_idx += 1
                elif node.op == "output":
                    assert len(node.args) == 1
                    poutputs = [
                        torch.fx.Proxy(value_remap[n], self.fx_tracer)
                        if isinstance(n, torch.fx.Node)
                        else n
                        for n in node.args[0]
                    ]
                elif node.op == "get_attr":
                    # NOTE(review): getattr() does not resolve dotted targets;
                    # assumes AOT backward graphs only reference top-level attrs.
                    name = node.target
                    qualname = self.fx_tracer.get_fresh_qualname(name)
                    setattr(self.fx_tracer.root, qualname, getattr(bw_module, name))
                    result = self.fx_tracer.create_node("get_attr", qualname, (), {})
                    result.name = make_unique(node.name)
                    value_remap[node] = result
                elif node.op == "call_function":
                    if node.target is torch.ops.aten.view.default:
                        # this aot bwd graph is being lazily compiled
                        # we must manually apply the view_to_reshape post grad pass
                        # since it was already applied to the aot fwd, and baked into the gradients
                        # NOTE: this assignment mutates bw_module's own graph in place.
                        node.target = torch.ops.aten.reshape.default
                    result = self.fx_tracer.graph.node_copy(
                        node, lambda n: value_remap[n]
                    )
                    result.name = make_unique(node.name)
                    value_remap[node] = result
                elif node.op == "call_module":
                    # NOTE(review): unlike the other branches, the copied node is
                    # not renamed with make_unique(); confirm this is intended.
                    name = node.target
                    qualname = self.fx_tracer.get_fresh_qualname(name)
                    setattr(self.fx_tracer.root, qualname, getattr(bw_module, name))
                    result = self.fx_tracer.graph.node_copy(
                        node, lambda n: value_remap[n]
                    )
                    result.target = qualname
                    value_remap[node] = result
                else:
                    raise AssertionError("shouldn't get here")

            assert poutputs is not None

            # In general we don't know what the shapes of the outputs are, so allocate
            # some dummy sizes for them.
            def dummy() -> torch.Tensor:
                with disable_proxy_modes_tracing():
                    return torch.zeros(0, 0, 0, 0, 123)

            # Non-proxy outputs (constants such as None) are passed through unchanged.
            outputs = [
                dummy() if isinstance(o, torch.fx.Proxy) else o for o in poutputs
            ]
            self.bind_objects_to_proxies(outputs, poutputs)
            return outputs

        outputs = copy_paste_aot_backward_graph()

        # Passed to the epilogue as make_subclass_override: records subclass
        # construction as an allow_in_graph call and returns a dummy bound to it.
        def proxy_subclass_constructor(
            subclass_meta: Any, is_runtime: bool, unwrapped_args: Sequence[Any]
        ) -> torch.Tensor:
            @torch._dynamo.allow_in_graph  # type: ignore[misc]
            def make_subclass(*unwrapped_args: Any) -> Any:
                return subclass_meta.creation_fn(unwrapped_args, is_runtime=is_runtime)

            punwrapped_args = pytree.tree_map(self.to_proxy, unwrapped_args)

            poutput = self.fx_tracer.create_proxy(
                kind="call_function",
                target=make_subclass,
                args=tuple(punwrapped_args),
                kwargs={},
            )

            output = self.allocate_dummy()
            self.bind_objects_to_proxies([output], [poutput])
            return output

        results = torch._functorch._aot_autograd.runtime_wrappers._backward_epilogue_functional(
            metadata,
            maybe_subclass_metadata,
            outputs,
            make_subclass_override=proxy_subclass_constructor,
        )
        presults = pytree.tree_map(self.to_proxy, results)
        return presults
  571. def proxy_call_backward(
  572. self,
  573. inputs: Sequence[Any],
  574. output_metadatas: Sequence[Any],
  575. saved_tensors: Sequence[torch.Tensor],
  576. backward_idx: int,
  577. ctx: torch.autograd.function.BackwardCFunction,
  578. maybe_backward_state_idx: Optional[int],
  579. opaque_object_indices: list[int],
  580. ) -> tuple[Optional[torch.Tensor], ...]:
  581. assert self.hooks_proxy is not None
  582. pctx = self.hooks_proxy[backward_idx] # type: ignore[index]
  583. pinputs = self.to_proxy(inputs)
  584. psaved_tensors = self.to_proxy(saved_tensors)
  585. if hasattr(ctx._forward_cls, "_aot_id"): # type: ignore[attr-defined]
  586. # AOT backward
  587. proxies = self.proxy_call_aot_backward(
  588. pinputs,
  589. psaved_tensors,
  590. saved_tensors,
  591. pctx,
  592. ctx,
  593. maybe_backward_state_idx,
  594. opaque_object_indices,
  595. )
  596. else:
  597. proxies = self.fx_tracer.create_proxy(
  598. kind="call_function",
  599. target=call_backward,
  600. args=(
  601. pctx,
  602. psaved_tensors,
  603. *pinputs,
  604. ),
  605. kwargs={},
  606. )
  607. assert proxies is not None
  608. with disable_proxy_modes_tracing():
  609. # create fake Tensors
  610. grad_ins: list[Optional[torch.Tensor]] = []
  611. for idx, output_metadata in enumerate(output_metadatas):
  612. if output_metadata is None or proxies[idx] is None:
  613. grad_ins.append(None)
  614. continue
  615. layout, device, dtype, size = output_metadata
  616. grad_ins.append(
  617. torch.empty(size=size, dtype=dtype, layout=layout, device=device)
  618. )
  619. self.bind_objects_to_proxies(grad_ins, proxies)
  620. return tuple(grad_ins)
  621. def call_copy_slices_prologue(
  622. self,
  623. inputs: Sequence[Any],
  624. base_sizes: Sequence[Any],
  625. base_strides: Sequence[Any],
  626. base_storage_offset: Any,
  627. view_sizes: Sequence[Any],
  628. view_strides: Sequence[Any],
  629. view_storage_offset: Any,
  630. ) -> Sequence[torch.Tensor]:
  631. args = (
  632. inputs,
  633. self.to_proxy(base_sizes),
  634. self.to_proxy(base_strides),
  635. self.to_proxy(base_storage_offset),
  636. self.to_proxy(view_sizes),
  637. self.to_proxy(view_strides),
  638. self.to_proxy(view_storage_offset),
  639. )
  640. return self.proxy_call(copy_slices_prologue, args, [None] * 3)
  641. def call_copy_slices_epilogue(
  642. self,
  643. needs_input_grad: Sequence[bool],
  644. result: torch.Tensor,
  645. res: Sequence[Any],
  646. grad_slice: torch.Tensor,
  647. ) -> Sequence[torch.Tensor]:
  648. return self.proxy_call(
  649. copy_slices_epilogue,
  650. (needs_input_grad, result, res, grad_slice),
  651. [None] * len(needs_input_grad),
  652. )
  653. def allocate_dummy(self) -> torch.Tensor:
  654. with disable_proxy_modes_tracing():
  655. # Weird quantity so it's easy to grep
  656. return torch.zeros([0, 123456789])
  657. def bind_function(
  658. self,
  659. fn_name: str,
  660. fn: Callable[..., Any],
  661. is_custom_function: bool,
  662. is_traceable: bool,
  663. ) -> str:
  664. """Binds ops.fn_name = fn"""
  665. return ops.add(fn_name, fn, is_custom_function, is_traceable)
  666. def apply_functional(
  667. self,
  668. fn_name: str,
  669. grads: Sequence[Any],
  670. args: Any,
  671. output_metadata: Sequence[Any],
  672. ) -> Sequence[torch.Tensor]:
  673. """Proxies a call to ops.fn_name(grads, *args) into the graph"""
  674. op = ops.get(fn_name)
  675. return self.proxy_call(op, (grads, *args), output_metadata)
  676. def proxy_call(
  677. self, fn: Callable[..., Any], args: Any, output_metadata: Sequence[Any]
  678. ) -> Sequence[torch.Tensor]:
  679. """Proxies a call to fn(*args) into the graph"""
  680. proxy_args = pytree.tree_map(lambda e: self.to_proxy(e), args)
  681. proxy_out = self.fx_tracer.create_proxy(
  682. "call_function", fn, args=proxy_args, kwargs={}
  683. )
  684. result = [self.allocate_dummy() for _ in output_metadata]
  685. self.bind_objects_to_proxies(result, [proxy_out[i] for i in range(len(result))])
  686. return result
  687. def validate_outputs(
  688. self, _: Any, outputs: Sequence[Any], args: Any, output_metadata: Sequence[Any]
  689. ) -> Sequence[torch.Tensor]:
  690. """Proxies a call to ops.validate_outputs(outputs, *args) into the graph"""
  691. op = ops.get("validate_outputs")
  692. proxy_args = pytree.tree_map(self.to_proxy, (outputs, *args))
  693. new_proxy_outputs = self.fx_tracer.create_proxy(
  694. "call_function", op, args=proxy_args, kwargs={}
  695. )
  696. assert len(output_metadata) == len(outputs)
  697. self.bind_objects_to_proxies(outputs, new_proxy_outputs)
  698. return outputs
  699. def accumulate(self, old_var: Any, new_var: Any) -> torch.Tensor:
  700. old_var_proxy = self.to_proxy(old_var)
  701. new_var_proxy = self.to_proxy(new_var)
  702. proxy_out = self.fx_tracer.create_proxy(
  703. "call_function", torch.add, args=(old_var_proxy, new_var_proxy), kwargs={}
  704. )
  705. result = self.allocate_dummy()
  706. self.bind_objects_to_proxies([result], [proxy_out])
  707. return result
  708. def accumulate_grad(
  709. self, variable: torch.Tensor, grad: torch.Tensor, has_post_hooks: bool
  710. ) -> None:
  711. self.fx_tracer.create_proxy(
  712. "call_function",
  713. call_accumulate_grad,
  714. args=(
  715. self.to_proxy(variable),
  716. self.to_proxy(grad),
  717. has_post_hooks,
  718. ),
  719. kwargs={},
  720. )
  721. def proxy_call_hook(
  722. self, hook: Callable[..., Any], *args: Any, **kwargs: Any
  723. ) -> torch.fx.Proxy:
  724. return self.fx_tracer.create_proxy(
  725. "call_function",
  726. call_hook,
  727. (
  728. hook,
  729. *[self.to_proxy(x) for x in args],
  730. ),
  731. kwargs,
  732. )
  733. def unpack_hook(self, hook_id: int, data_id: int) -> torch.Tensor:
  734. assert self.hooks_proxy is not None
  735. hook = self.hooks_proxy[hook_id] # type: ignore[index]
  736. data = self.packed_data_proxy[data_id] # type: ignore[index]
  737. proxy = self.proxy_call_hook(
  738. hook,
  739. data,
  740. hook_type="unpack_hook",
  741. )
  742. out = self.allocate_dummy()
  743. self.bind_objects_to_proxies([out], [proxy])
  744. return out
  745. def tensor_pre_hook(
  746. self, inputs: list[torch.Tensor], hook_id: int, i: int
  747. ) -> list[torch.Tensor]:
  748. assert self.hooks_proxy is not None
  749. hook = self.hooks_proxy[hook_id] # type: ignore[index]
  750. proxy = self.proxy_call_hook(
  751. hook,
  752. inputs[i],
  753. hook_type="tensor_pre_hook",
  754. )
  755. with disable_proxy_modes_tracing():
  756. inputs[i] = maybe_clone(inputs[i]) # type: ignore[assignment]
  757. self.bind_objects_to_proxies([inputs[i]], [proxy])
  758. return inputs
  759. def cpp_tensor_pre_hook(
  760. self, inputs: list[torch.Tensor], hook_id: int, i: int
  761. ) -> list[torch.Tensor]:
  762. proxy = self.fx_tracer.create_proxy(
  763. "call_function",
  764. torch._C._dynamo.compiled_autograd.call_cpp_tensor_pre_hooks,
  765. (hook_id, self.to_proxy(inputs[i])),
  766. {},
  767. )
  768. with disable_proxy_modes_tracing():
  769. inputs[i] = maybe_clone(inputs[i]) # type: ignore[assignment]
  770. self.bind_objects_to_proxies([inputs[i]], [proxy])
  771. return inputs
  772. def pre_hook(self, inputs: Sequence[Any], hook_id: int) -> list[torch.Tensor]:
  773. assert self.hooks_proxy is not None
  774. hook = self.hooks_proxy[hook_id] # type: ignore[index]
  775. proxies = self.proxy_call_hook(
  776. hook,
  777. inputs,
  778. hook_type="pre_hook",
  779. )
  780. with disable_proxy_modes_tracing():
  781. inputs = [maybe_clone(x) for x in inputs]
  782. self.bind_objects_to_proxies(inputs, proxies)
  783. return inputs
  784. def post_hook(
  785. self, outputs: list[torch.Tensor], inputs: Sequence[torch.Tensor], hook_id: int
  786. ) -> list[torch.Tensor]:
  787. assert self.hooks_proxy is not None
  788. hook = self.hooks_proxy[hook_id] # type: ignore[index]
  789. proxies = self.proxy_call_hook(
  790. hook,
  791. outputs,
  792. inputs,
  793. hook_type="post_hook",
  794. )
  795. with disable_proxy_modes_tracing():
  796. outputs = [maybe_clone(x) for x in outputs] # type: ignore[misc]
  797. self.bind_objects_to_proxies(outputs, proxies)
  798. return outputs
  799. def post_acc_grad_hook(
  800. self, input: torch.Tensor, hook_id: int
  801. ) -> list[torch.Tensor]:
  802. assert isinstance(input, torch.Tensor)
  803. assert self.hooks_proxy is not None
  804. hook = self.hooks_proxy[hook_id] # type: ignore[index]
  805. proxy = self.proxy_call_hook(
  806. hook,
  807. input,
  808. hook_type="post_acc_grad_hook",
  809. )
  810. with disable_proxy_modes_tracing():
  811. res = [maybe_clone(input)]
  812. self.bind_objects_to_proxies(res, [proxy])
  813. return res # type: ignore[return-value]
    # Note: [Compiled autograd and cudagraphs]
    # Eager autograd backward implements scalars as 0-dim tensors, see DivBackward0::other_.
    # When compiled autograd traces those nodes, it lifts the scalar tensors, resulting in a graph
    # with some cpu 0-dim tensor inputs. To prevent the entire graph from skipping cudagraph, we move the
    # scalar tensors to cuda. This works because ATen/prims ops will accept cuda 0-dim tensors too.
    def move_graph_nodes_to_cuda(self, graph: torch.fx.Graph) -> list[int]:
        """Retarget eligible CPU 0-dim graph inputs to CUDA in the node metadata.

        An input qualifies when its ``meta["val"]`` is a CPU tensor with zero
        dimensions and every user is either an aten/prims ``OpOverload`` or an
        ``Op`` that is not a custom function. Qualifying inputs are only changed
        when at least one input's ``meta["val"]`` is on CUDA; then their
        ``meta["val"]`` is replaced by a ``.cuda()`` copy.

        Returns:
            Positions, within the users of the ``inputs`` placeholder, of the
            inputs that were retargeted (empty if no input is on CUDA). The caller
            uses them as indices into the runtime ``inputs`` list.
            NOTE(review): this assumes the i-th user of ``inputs`` is the getitem
            for runtime index i; the asserts below only check that the getitems
            are contiguous right after the placeholders.
        """
        to_move: dict[int, torch.fx.Node] = {}
        has_cuda_inputs = False
        nodes = list(graph.nodes)
        assert nodes[0].target == "inputs"
        inputs = nodes[0]
        inputs_users = list(inputs.users.keys())
        # input access nodes should immediately follow placeholder nodes
        first_getitem_idx = len(_graph_placeholders)
        assert nodes[first_getitem_idx] == inputs_users[0]
        last_getitem_idx = first_getitem_idx + len(inputs_users) - 1
        assert nodes[last_getitem_idx] == inputs_users[-1]
        # getitem nodes on inputs
        for i, node in enumerate(inputs_users):
            # The first CUDA input only sets the flag; it is never a candidate itself.
            if not has_cuda_inputs and node.meta["val"].device.type == "cuda":
                has_cuda_inputs = True
                continue

            is_cpu = node.meta["val"].device.type == "cpu"
            is_scalar = len(node.meta["val"].size()) == 0
            if is_cpu and is_scalar:
                node_users = list(node.users.keys())
                # We can only move the cpu scalar if it is not exposed to user code.
                if all(
                    (
                        isinstance(user.target, torch._ops.OpOverload)
                        and user.target.namespace in ("prims", "aten")
                    )
                    or (
                        isinstance(user.target, Op)
                        and not user.target.is_custom_function
                    )
                    for user in node_users
                ):
                    # all users are prims/aten, can move safely
                    to_move[i] = node

        # only move cpu scalars to cuda if there were cuda activations in this graph,
        # this is to handle the case where cudagraphs is enabled on a cpu-only graph
        if has_cuda_inputs:
            for node in to_move.values():
                verbose_log.debug("Moving node %s from cpu to cuda", node)
                node.meta["val"] = node.meta["val"].cuda()

            # return runtime indices we need to move to cuda
            return list(to_move.keys())

        return []
  863. def is_sym_node(self, node: Any) -> bool:
  864. return (
  865. isinstance(node, torch.fx.Node)
  866. and node.op == "call_function"
  867. and node.target
  868. in [torch.ops.aten.sym_size.int, torch.ops.aten.sym_numel.default]
  869. )
  870. def dce(self) -> None:
  871. # Most of these removed nodes would have been removed during Dynamo and AOTDispatch
  872. # Remove some of these nodes earlier to improve compilation speed
  873. # Dynamo guards will error instead of creating aliasing guards unless we unpack them in the graph
  874. unpack_nodes: OrderedSet[torch.fx.Node] = OrderedSet()
  875. i: int | None = None
  876. for i, node in enumerate(self.fx_tracer.graph.find_nodes(op="placeholder")): # noqa: B007
  877. unpack_nodes.update(node.users.keys())
  878. assert i == len(_graph_placeholders) - 1
  879. def is_impure(node: torch.fx.Node) -> bool:
  880. if node in unpack_nodes or (
  881. node.op == "call_function" and node.target in _impure_targets
  882. ):
  883. return True
  884. return node.is_impure()
  885. before = len(self.fx_tracer.graph.nodes)
  886. self.fx_tracer.graph.eliminate_dead_code(is_impure)
  887. after = len(self.fx_tracer.graph.nodes)
  888. verbose_log.debug("DCE removed %d nodes", before - after)
  889. def remove_unused_sizes(self) -> set[int]:
  890. used_sizes = []
  891. unused_sizes = []
  892. # seek placeholder, should be at nodes[1]
  893. it = iter(self.fx_tracer.graph.nodes)
  894. next(it)
  895. sizes_node = next(it)
  896. assert sizes_node.name == "sizes"
  897. for getitem_node in sizes_node.users:
  898. assert getitem_node.target is operator.getitem
  899. if getitem_node.users:
  900. used_sizes.append(getitem_node)
  901. else:
  902. # remove from the graph
  903. unused_sizes.append(getitem_node)
  904. used_sizes_idx: set[int] = set()
  905. for used in used_sizes:
  906. assert isinstance(used.args, tuple)
  907. assert used.args[0] == sizes_node
  908. assert isinstance(used.args[1], int)
  909. next_size_idx = len(used_sizes_idx)
  910. # used later reindex the runtime sizes arg
  911. used_sizes_idx.add(used.args[1])
  912. # reindex the graph
  913. used.args = (used.args[0], next_size_idx)
  914. for unused in unused_sizes:
  915. self.fx_tracer.graph.erase_node(unused)
  916. return used_sizes_idx
  917. def create_graph_module(self, id: str) -> GraphModule:
  918. return GraphModule(self.fx_tracer.root, self.fx_tracer.graph, id)
    def end_capture(self, outputs: Any) -> tuple[Callable[..., Any], Any]:
        """Finish tracing and return ``(runtime_wrapper, self.compiler_fn(graph))``.

        Steps, in order:
        - emit the final-callbacks stub and the graph ``output`` node;
        - when cudagraphs is enabled, retarget eligible CPU scalar inputs to CUDA;
        - strip the dummy-tensor metadata;
        - run the reordering passes, then DCE and unused-size removal;
        - log the graph and hand it to ``self.compiler_fn``.

        ``runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks, packed_inputs)``
        adapts the runtime arguments to the pruned graph (filtered sizes, inputs
        moved to CUDA) and calls ``compiled_fn`` inside ``_disable()``.
        """
        self.fx_tracer.create_proxy(
            "call_function",
            FakeCompiledAutogradEngine._exec_final_callbacks_stub,
            (),
            {},
        )
        self.stack.close()
        self.fx_tracer.create_node(
            "output",
            "output",
            (self.fx_tracer.create_arg(self.to_proxy(outputs)),),
            {},
        )
        # Indices into the runtime `inputs` list whose tensors the wrapper moves to CUDA.
        runtime_inputs_to_move: list[int] = []
        if snapshot_cudagraph_enabled():
            runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph)

        # We traced using dummy tensors. Delete all the metadata of the dummy tensors.
        # It's probably better to refactor this class to use a different tracer
        # than the make_fx tracer, but that is a larger change.
        for node in self.fx_tracer.graph.nodes:
            for field in ["tensor_meta", "example_value", "val"]:
                if field in node.meta:
                    del node.meta[field]

        trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "compiled_autograd_graph_pre_reordering",
                "encoding": "string",
            },
            payload_fn=lambda: GraphModule(
                self.fx_tracer.root,
                self.fx_tracer.graph,
                f"CompiledAutograd{self.id}PreReordering",
            ).print_readable(print_output=False),
        )
        # Reordering passes; order matters. pre_hooks are moved twice: first as
        # early as possible so accumulate_grad reordering can see them, then back
        # to just before their registered node.
        self.delay_unpack_hook_nodes()
        self.reorder_tensor_pre_hook_nodes()
        self.reorder_pre_hook_nodes_to_schedule_asap()
        self.reorder_accumulate_grad_nodes()
        self.reorder_pre_hook_nodes_to_mimic_eager()
        self.reorder_post_acc_grad_hook_nodes()
        self.reorder_post_hook_nodes()
        # TODO(yf225): work around: remove dead codes like `sym_size` and `sym_numel` which are not used downstream. e.g.
        # ```
        # sym_numel_default = torch.ops.aten.sym_numel.default(sum_109); sum_109 = None
        # eq_115 = 16 == sym_numel_default; sym_numel_default = eq_115 = None
        # sym_size_int_39 = torch.ops.aten.sym_size.int(getitem_112, 1); getitem_112 = None
        # eq_116 = 16 == sym_size_int_39; eq_116 = None
        # eq_117 = 16 == sym_size_int_39; sym_size_int_39 = eq_117 = None
        # ```
        # Proper fix is Richard's Python compiled autograd effort which will avoid calling make_fx and
        # should prevent these ops from going into the CA graph.
        self.dce()
        if self.nan_checker:
            self.nan_checker.prep_with_graph(self.fx_tracer.graph)

        # keep only sizes that are actually used in the graph
        used_sizes_idx = self.remove_unused_sizes()

        graph = self.create_graph_module(f"CompiledAutograd{self.id}")
        set_locals_to_steal(graph, ["inputs"])
        lazy_graph_code = lazy_format_graph_code(
            "Compiled autograd graph",
            graph,
            include_device=True,
            include_stride=True,
            colored=True,
        )
        compiled_autograd_log.info("%s", lazy_graph_code)
        verbose_log.debug("%s", lazy_graph_code)
        trace_structured(
            "compiled_autograd_graph",
            payload_fn=lambda: graph.print_readable(print_output=False),
        )

        def runtime_wrapper(
            compiled_fn: Callable[..., Any],
            inputs: Any,
            sizes: Any,
            scalars: Any,
            hooks: Any,
            packed_inputs: Any,
        ) -> tuple[Any, Any]:
            # Run the compiled graph with `in_compiled_autograd_region` set for the
            # duration of the call.
            global in_compiled_autograd_region
            try:
                in_compiled_autograd_region = True
                if self.nan_checker:
                    self.nan_checker.prep_with_inputs(inputs)

                # Keep only the sizes the graph still reads, in ascending original
                # index order (the layout remove_unused_sizes reindexed against).
                filtered_sizes = []
                for idx, integer in enumerate(sizes):
                    if idx in used_sizes_idx:
                        # can't create negative size
                        if integer > 0:
                            # Positive sizes are carried as dim 1 of an empty tensor and
                            # marked dynamic (presumably so Dynamo traces them
                            # symbolically rather than specializing — confirm).
                            filtered_sizes.append(torch.empty(0, integer))
                            torch._dynamo.maybe_mark_dynamic(filtered_sizes[-1], 1)
                        else:
                            filtered_sizes.append(integer)

                for i in runtime_inputs_to_move:
                    inputs[i] = inputs[i].pin_memory().cuda(non_blocking=True)

                with _disable(), make_compile_context(self.id):
                    out = compiled_fn(
                        inputs, filtered_sizes, scalars, hooks, packed_inputs
                    )
                    if self.nan_checker:
                        self.nan_checker.check(out)
                    return out
            finally:
                in_compiled_autograd_region = False

        # Capture is over: close the chromium "compiled_autograd" event and the
        # compile context (both once per capture, not per runtime call).
        get_chromium_event_logger().log_event_end(
            "compiled_autograd",
            time.time_ns(),
            {"graph_id": self.id},
            self.start_time_ns,
            log_pt2_compile_event=True,
        )
        self.compile_context.__exit__(None, None, None)
        return runtime_wrapper, self.compiler_fn(graph)
  1034. @staticmethod
  1035. def get_all_nodes(args: Sequence[Any]) -> list[torch.fx.Node]:
  1036. # filter out non-Node args, like None
  1037. nodes = [n for n in args if type(n) is torch.fx.Node]
  1038. return nodes
  1039. @staticmethod
  1040. def is_placeholder(node: torch.fx.Node) -> bool:
  1041. if node.op == "placeholder" or (
  1042. node.op == "call_function"
  1043. and node.target is operator.getitem
  1044. and node.args[0].op == "placeholder" # type: ignore[union-attr, arg-type]
  1045. ):
  1046. return True
  1047. return False
    def reorder_accumulate_grad_nodes(self) -> None:
        """
        Usage of AOTAutograd causes all the accumulate_grad_ nodes to get pushed to the end of
        the graph. This differs from eager mode, which schedules them as soon as possible. This
        pass attempts to reorder the graph to mimic eager behavior.

        Each accumulate_grad node is moved to right after the later of its param
        and grad producers. It is left alone if it already follows that node, or
        if that node is a placeholder (or a getitem on one).
        """
        for node in self.fx_tracer.graph.find_nodes(
            op="call_function", target=call_accumulate_grad
        ):
            param_node, grad_node = node.args[0], node.args[1]
            getitem_node = None
            # If the grad is a getitem of some multi-output node, anchor on that
            # producer and move the getitem together with the accumulate node.
            if grad_node.target is operator.getitem:
                getitem_node = grad_node
                grad_node = getitem_node.args[0]

            arg = max([param_node, grad_node])  # last arg
            if arg is not node.prev and not self.is_placeholder(arg):
                arg.append(node)
                # Appended second, so it lands between `arg` and `node`, i.e. before its consumer.
                if getitem_node is not None:
                    arg.append(getitem_node)
  1067. def delay_unpack_hook_nodes(self) -> None:
  1068. """
  1069. We can delay unpack hooks until they are needed, even later than in the eager autograd engine.
  1070. """
  1071. for node in self.fx_tracer.graph.find_nodes(
  1072. op="call_function", target=call_hook
  1073. ):
  1074. if node.kwargs.get("hook_type", None) != "unpack_hook":
  1075. continue
  1076. first_user = min(node.users)
  1077. first_user.prepend(node)
    def reorder_tensor_pre_hook_nodes(self) -> None:
        """
        Usage of AOTAutograd causes all the tensor_pre_hook nodes to get pushed
        to the end of the graph. This differs from eager mode, which schedules
        them as soon as possible. This pass attempts to reorder the graph to
        mimic eager behavior.

        The hook's own getitem (args[0]) and the hook call are placed right after
        the grad tensor they consume (args[1]), unless the hook already follows it
        or that tensor is a placeholder (or a getitem on one).
        """
        for node in self.fx_tracer.graph.find_nodes(
            op="call_function", target=call_hook
        ):
            if node.kwargs.get("hook_type", None) != "tensor_pre_hook":
                continue

            getitem_node = node.args[0]
            input_node = node.args[1]  # tensor_pre_hook handle only one grad tensor
            if input_node is not node.prev and not self.is_placeholder(input_node):
                # Resulting order: input_node, getitem_node, node.
                input_node.append(getitem_node)
                getitem_node.append(node)
    def reorder_pre_hook_nodes_to_schedule_asap(self) -> None:
        """
        In this function, we schedule the pre hooks as soon as possible. This
        does not match eager behavior (schedule pre hook right before its
        registered node), but it can make acc grad be scheduled properly when
        the pre hooks are registered to them. After reordering acc grad node, we
        will reorder the pre hooks again to mimic eager behavior.

        For each pre_hook, the hook's own getitem (args[0]) is moved right after
        the last node producing its inputs. The hook and any getitem input
        arguments ("hook_block") follow it.
        """
        for node in self.fx_tracer.graph.find_nodes(
            op="call_function", target=call_hook
        ):
            if node.kwargs.get("hook_type", None) != "pre_hook":
                continue

            getitem_node = node.args[0]
            # pre_hook handle a tuple of grad tensors
            input_nodes = self.get_all_nodes(node.args[1])
            to_remove = []
            to_append = []
            hook_block = [node]  # contain the hook and hook args getitem
            # Replace getitem inputs by the node they index into, so the anchor is
            # the real producer; the getitems themselves move with the hook.
            for n in input_nodes:
                if n.op == "call_function" and n.target is operator.getitem:
                    to_append.append(n.args[0])
                    to_remove.append(n)
                    hook_block.append(n)
            for a, b in zip(to_remove, to_append):
                input_nodes.remove(a)
                input_nodes.append(b)  # type: ignore[arg-type]

            # NOTE(review): max() raises ValueError if the hook has no Node inputs.
            arg = max(input_nodes)  # last input
            if arg is not node.prev and not self.is_placeholder(arg):
                arg.append(getitem_node)
                # Each append lands directly after getitem_node, so the getitem
                # args (appended later) end up before the hook call that uses them.
                for n in hook_block:
                    getitem_node.append(n)
    def reorder_pre_hook_nodes_to_mimic_eager(self) -> None:
        """
        Usage of AOTAutograd causes all the pre_hook nodes to get pushed to the
        end of the graph. This differs from eager mode, which schedules them
        right before their registered node execution. This pass attempts to
        reorder the graph to mimic eager behavior.

        The registered node is taken as the first user of the hook's first output
        getitem. The hook's getitem, the hook call and its output getitems are
        placed immediately before it.
        """
        pre_hooks = []
        for node in self.fx_tracer.graph.find_nodes(
            op="call_function", target=call_hook
        ):
            if node.kwargs.get("hook_type", None) != "pre_hook":
                continue
            pre_hooks.append(node)

        # Walk in reverse so that, for hooks sharing a registered node, each
        # prepend puts earlier hooks in front of later ones.
        for node in reversed(pre_hooks):
            hook_getitem_node = node.args[0]

            users = list(node.users.keys())
            if len(users) == 0:
                continue

            # users are all getitem ops and they are used by same registered node
            assert all(
                user.op == "call_function" and user.target is operator.getitem
                for user in users
            )
            registered_node = next(iter(users[0].users.keys()))

            if registered_node is not node.next:
                registered_node.prepend(hook_getitem_node)
                registered_node.prepend(node)
                for getitem in users:
                    registered_node.prepend(getitem)
    def reorder_post_acc_grad_hook_nodes(self) -> None:
        """
        Usage of AOTAutograd causes all the post_acc_grad_hook nodes to get
        pushed to the end of the graph. This differs from eager mode, which
        schedules them as soon as possible. This pass attempts to reorder the
        graph to mimic eager behavior.

        Each hook's getitem and call are placed right after the
        ``call_accumulate_grad`` node of the same param.

        Raises:
            AssertionError: if a hook's param has no accumulate_grad user.
        """
        post_acc_grad_hooks = []
        for node in self.fx_tracer.graph.find_nodes(
            op="call_function", target=call_hook
        ):
            if node.kwargs.get("hook_type", None) != "post_acc_grad_hook":
                continue
            post_acc_grad_hooks.append(node)

        # nodes in post_acc_grad_hooks are in topo order. For hooks registered
        # to same node, we should keep their relative order
        for node in reversed(post_acc_grad_hooks):
            getitem_node = node.args[0]
            param_node = node.args[1]  # post_acc_grad_hook handle one param

            # find the corresponding acc_grad node
            acc_grad_node = None
            for n in list(param_node.users.keys()):
                if n.op == "call_function" and n.target is call_accumulate_grad:
                    acc_grad_node = n
                    break

            assert acc_grad_node is not None, (
                "post_acc_grad_hook must have corresponding acc grad node"
            )

            # append post_acc_grad_hook after acc_grad node
            acc_grad_node.append(getitem_node)
            getitem_node.append(node)
    def reorder_post_hook_nodes(self) -> None:
        """
        Usage of AOTAutograd causes all the post_hook nodes to get pushed to the
        end of the graph. This differs from eager mode, which schedules them as
        soon as possible. This pass attempts to reorder the graph to mimic eager
        behavior.

        Only post hooks with an empty outputs list (args[1]) are moved. They go
        after the last of their input nodes and those inputs' non-hook users. If
        that last node is an accumulate_grad whose param has a post_acc_grad_hook,
        they go after that hook instead.
        """
        post_hooks = []
        for node in self.fx_tracer.graph.find_nodes(
            op="call_function", target=call_hook
        ):
            if node.kwargs.get("hook_type", None) != "post_hook":
                continue
            post_hooks.append(node)

        for node in reversed(post_hooks):
            getitem_node = node.args[0]
            output_nodes = node.args[1]
            input_nodes = node.args[2]

            # Hooks that consume outputs already sit after the producing node.
            if len(output_nodes) > 0:
                continue

            # pyrefly: ignore [implicit-any]
            input_nodes_and_users = []
            input_nodes_and_users.extend(list(input_nodes))
            for input_node in input_nodes:
                input_nodes_and_users.extend(
                    user
                    for user in list(input_node.users.keys())
                    if not (
                        user.op == "call_function"
                        and user.target is call_hook
                        # NOTE(review): this tests the outer `node` (always a
                        # post_hook, given the filter above), not `user`, so the
                        # clause is always True and *every* call_hook user is
                        # excluded. `user.kwargs` was probably intended; changing
                        # it alters the schedule, so confirm before fixing.
                        and node.kwargs.get("hook_type", None) == "post_hook"
                    )
                )

            arg = max(input_nodes_and_users)  # last input users
            if arg.op == "call_function" and arg.target is call_accumulate_grad:
                param_node = arg.args[0]
                post_acc_grad_hook_node = None
                # Takes the last matching post_acc_grad_hook user of the param.
                for n in list(param_node.users.keys()):
                    if (
                        n.op == "call_function"
                        and n.target is call_hook
                        and n.kwargs.get("hook_type", None) == "post_acc_grad_hook"
                    ):
                        post_acc_grad_hook_node = n

                if post_acc_grad_hook_node is not None:
                    post_acc_grad_hook_node.append(getitem_node)
                    getitem_node.append(node)
                    continue

            if arg is not node.prev and not self.is_placeholder(arg):
                arg.append(getitem_node)
                getitem_node.append(node)
  1239. def to_proxy(self, t: Any) -> Any:
  1240. if t is None:
  1241. return None
  1242. if isinstance(t, list):
  1243. return [self.to_proxy(x) for x in t]
  1244. if isinstance(t, tuple):
  1245. return tuple(self.to_proxy(x) for x in t)
  1246. if isinstance(t, (torch.SymInt, torch.SymFloat)):
  1247. return self.symnode_proxy_lookup[t.node]
  1248. if not isinstance(t, torch.Tensor):
  1249. # constant types like device, dtype, str
  1250. return t
  1251. proxy_tensor = fetch_object_proxy(self.fx_tracer, t)
  1252. assert isinstance(proxy_tensor, torch.fx.experimental.proxy_tensor._ProxyTensor)
  1253. return proxy_tensor.proxy
  1254. def bind_objects_to_proxies(
  1255. self,
  1256. objects: Sequence[Any],
  1257. proxies: Any,
  1258. origins: Optional[list[tuple[int, str]]] = None,
  1259. ) -> Sequence[Any]:
  1260. if isinstance(proxies, torch.fx.Proxy):
  1261. if origins:
  1262. assert len(origins) == len(objects)
  1263. bound_proxies = []
  1264. for i in range(len(objects)):
  1265. nodecall_index, node_name = origins[i]
  1266. self.set_node_origin(node_name, nodecall_index, None)
  1267. bound_proxies.append(proxies[i]) # type: ignore[index]
  1268. proxies = bound_proxies
  1269. else:
  1270. proxies = [proxies[i] for i in range(len(objects))] # type: ignore[index]
  1271. assert len(objects) == len(proxies)
  1272. track_tensor_tree(objects, proxies, constant=None, tracer=self.fx_tracer)
  1273. return proxies
  1274. def bind_backward_state(self, index: int) -> BackwardState:
  1275. assert self.hooks_proxy is not None
  1276. proxy = self.hooks_proxy[index] # type: ignore[index]
  1277. bw_state = BackwardState()
  1278. track_tensor_tree(bw_state, proxy, constant=None, tracer=self.fx_tracer)
  1279. return bw_state
    def set_node_origin(
        self,
        node_name: str,
        nodecall_index: int,
        pyobj: Optional[torch.autograd.Function],
    ) -> None:
        """Install a stack trace labelling the autograd node currently being traced.

        The label is ``"<node_name><aot_id> (NodeCall <nodecall_index>)"``. The
        ``<aot_id>`` part is only present when ``pyobj``'s forward class has an
        ``_aot_id``. The label is passed to ``set_stack_trace`` (presumably picked
        up by FX nodes created afterwards — confirm).

        Raises:
            RuntimeError: if the AOT-created backward has no
                ``_lazy_backward_info`` (saved by AOTAutogradCache).
        """
        maybe_aot_id = ""
        if pyobj is not None:
            forward_cls = pyobj._forward_cls  # type: ignore[attr-defined]
            if hasattr(forward_cls, "_aot_id"):
                # backward was created by AOT Dispatcher
                if forward_cls._lazy_backward_info is None:
                    raise RuntimeError(
                        """This compiled backward function was saved by AOTAutogradCache, which does not support
compiled autograd. Please turn off AOTAutogradCache using `TORCHINDUCTOR_AUTOGRAD_CACHE=0`."""
                    )
                maybe_aot_id = forward_cls._aot_id
        new_code = f"{node_name}{maybe_aot_id} (NodeCall {nodecall_index})"
        raw_stack_trace = CapturedTraceback.extract().format()[-1]
        # The replace() target below is the literal source text of the line above,
        # which is expected to appear in the last formatted frame; swapping it puts
        # the label in place of the code. Keep the two byte-identical.
        new_stack_trace = raw_stack_trace.replace(
            "raw_stack_trace = CapturedTraceback.extract().format()[-1]", new_code
        )
        set_stack_trace(new_stack_trace)
  1298. raw_stack_trace = CapturedTraceback.extract().format()[-1]
  1299. new_stack_trace = raw_stack_trace.replace(
  1300. "raw_stack_trace = CapturedTraceback.extract().format()[-1]", new_code
  1301. )
  1302. set_stack_trace(new_stack_trace)
  1303. # state of the autograd engine dispatch, kept in sync by enable/disable context managers
  1304. compiled_autograd_enabled = False
  1305. # global flag to check if compiled autograd is enabled but Dynamo stance is "force_eager"
  1306. compiled_autograd_enabled_force_eager = False
  1307. # global flag to check if we are processing graphs produced from a compiled autograd graph
  1308. in_compiled_autograd_region = False
  1309. active_disable_ctx = False
  1310. depth = 0
  1311. @contextlib.contextmanager
  1312. def _enable(
  1313. compiler_fn: Callable[..., Any],
  1314. dynamic: bool = True,
  1315. ignore_active_disable_ctx: bool = True,
  1316. ) -> Generator[None, None, None]:
  1317. # The entrypoint to enable CA.
  1318. # It is recommended to enable via `torch._dynamo.config.compiled_autograd = True` rather
  1319. # than using this context manager directly. If you are torch.compiling the corresponding
  1320. # forward pass, make sure they are wrapped under this context as well.
  1321. #
  1322. # Example:
  1323. # def train(model, inputs, target):
  1324. # compiled_model = torch.compile(model)
  1325. # pred = compiled_model(data)
  1326. # loss = compute_loss(pred, target)
  1327. # loss.backward()
  1328. #
  1329. # with _enable(compiler_fn):
  1330. # train(model, inputs, target)
  1331. #
  1332. # Inputs:
  1333. # - compiler_fn: The wrapper that will consume the compiled autograd graph, e.g. `torch.compile`
  1334. # - dynamic: Whether compiled autograd will treat tensors in the autograd graph (params, activations) as dynamic.
  1335. # This doesn't affect the dynamic configuration of the compilation wrapper.
  1336. if not ignore_active_disable_ctx and active_disable_ctx:
  1337. yield
  1338. else:
  1339. if dynamic:
  1340. assert type(dynamic) is bool
  1341. from torch._dynamo import eval_frame
  1342. if eval_frame._stance.stance == "force_eager":
  1343. # If user explicitly sets Dynamo stance to "force_eager", we want Compiled Autograd
  1344. # to fall back to eager as well.
  1345. global compiled_autograd_enabled_force_eager
  1346. compiled_autograd_enabled_force_eager = True
  1347. try:
  1348. yield
  1349. finally:
  1350. compiled_autograd_enabled_force_eager = False
  1351. else:
  1352. # we need to import this, because user might not have imported it if they directly use this context manager
  1353. # we need to lazily import it, because of circular dependencies
  1354. if torch.cuda.is_available():
  1355. from torch._inductor import cudagraph_trees # noqa: F401
  1356. (
  1357. prior_compiler,
  1358. prior_dynamic,
  1359. ) = torch._C._dynamo.compiled_autograd.set_autograd_compiler(
  1360. functools.partial(AutogradCompilerInstance, compiler_fn), dynamic
  1361. )
  1362. if snapshot_verbose_logging_enabled():
  1363. torch._C._dynamo.compiled_autograd.set_verbose_logger(verbose_log) # type:ignore[arg-type]
  1364. global compiled_autograd_enabled
  1365. compiled_autograd_enabled = True
  1366. global depth
  1367. prior_depth = depth
  1368. depth += 1
  1369. try:
  1370. with torch.autograd.set_multithreading_enabled(False):
  1371. yield
  1372. finally:
  1373. if not prior_compiler:
  1374. compiled_autograd_enabled = False
  1375. torch._C._dynamo.compiled_autograd.set_autograd_compiler(
  1376. prior_compiler, prior_dynamic
  1377. )
  1378. depth -= 1
  1379. assert depth == prior_depth, (
  1380. "Nested Compiled Autograd Contexts must return before their parent context"
  1381. )
  1382. @contextlib.contextmanager
  1383. def _disable() -> Generator[None, None, None]:
  1384. (
  1385. prior_compiler,
  1386. prior_dynamic,
  1387. ) = torch._C._dynamo.compiled_autograd.set_autograd_compiler(None, False)
  1388. global compiled_autograd_enabled
  1389. compiled_autograd_enabled = False
  1390. global active_disable_ctx
  1391. if not active_disable_ctx:
  1392. active_disable_ctx = True
  1393. try:
  1394. yield
  1395. finally:
  1396. if prior_compiler:
  1397. compiled_autograd_enabled = True
  1398. active_disable_ctx = False
  1399. torch._C._dynamo.compiled_autograd.set_autograd_compiler(
  1400. prior_compiler, prior_dynamic
  1401. )
  1402. # return to starting state of a new process
  1403. def reset() -> None:
  1404. global compiled_autograd_enabled
  1405. compiled_autograd_enabled = False
  1406. assert not in_compiled_autograd_region
  1407. torch._C._dynamo.compiled_autograd.set_autograd_compiler(None, False)
  1408. torch._C._dynamo.compiled_autograd.set_verbose_logger(None)
  1409. torch._C._dynamo.compiled_autograd.clear_cache()
  1410. global COMPILE_COUNTER
  1411. COMPILE_COUNTER = itertools.count()
  1412. # Reimplementation of part of CopySlices::apply in Python.
  1413. # The shared code is really similar so we're not going to try to deduplicate.
  1414. def copy_slices_prologue(
  1415. inputs: Sequence[torch.Tensor],
  1416. base_sizes: Sequence[IntLikeType],
  1417. base_strides: Sequence[IntLikeType],
  1418. base_storage_offset: IntLikeType,
  1419. view_sizes: Sequence[IntLikeType],
  1420. view_strides: Sequence[IntLikeType],
  1421. view_storage_offset: IntLikeType,
  1422. ) -> list[torch.Tensor]:
  1423. grad = inputs[0]
  1424. result = grad.new_empty_strided(base_sizes, base_strides)
  1425. assert grad is not None
  1426. result.copy_(grad)
  1427. offset = view_storage_offset - base_storage_offset
  1428. grad_slice = result.as_strided(view_sizes, view_strides, offset)
  1429. return [result, grad_slice, grad_slice.clone(memory_format=torch.contiguous_format)]
  1430. # Reimplementation of part of CopySlices::apply in Python.
  1431. # The shared code is really similar so we're not going to try to deduplicate.
  1432. def copy_slices_epilogue(
  1433. needs_input_grad: Sequence[bool],
  1434. result: torch.Tensor,
  1435. res: Sequence[Optional[torch.Tensor]],
  1436. grad_slice: torch.Tensor,
  1437. ) -> list[Optional[torch.Tensor]]:
  1438. grad_inputs: list[Optional[torch.Tensor]] = [None] * len(needs_input_grad)
  1439. for i in range(len(needs_input_grad)):
  1440. if needs_input_grad[i]:
  1441. if res[i] is None:
  1442. continue
  1443. if i == 0:
  1444. to_copy = res[i]
  1445. assert to_copy is not None
  1446. grad_slice.copy_(to_copy)
  1447. grad_inputs[i] = result
  1448. else:
  1449. grad_inputs[i] = res[i]
  1450. return grad_inputs