checkpoint.py 71 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685
  1. # mypy: allow-untyped-defs
  2. import contextlib
  3. import platform
  4. import uuid
  5. import warnings
  6. import weakref
  7. from collections import defaultdict
  8. from typing import * # noqa: F403
  9. import enum
  10. from weakref import ReferenceType
  11. import torch
  12. import torch.fx.traceback as fx_traceback
  13. from torch.utils._pytree import tree_map
  14. from torch.testing._internal.logging_tensor import capture_logs, LoggingTensorMode
  15. from torch.utils._python_dispatch import TorchDispatchMode
  16. from typing import NoReturn
  17. __all__ = [
  18. "checkpoint",
  19. "checkpoint_sequential",
  20. "CheckpointError",
  21. "CheckpointFunction",
  22. "check_backward_validity",
  23. "detach_variable",
  24. "get_device_states",
  25. "set_device_states",
  26. "noop_context_fn",
  27. "set_checkpoint_early_stop",
  28. "DefaultDeviceType",
  29. "set_checkpoint_debug_enabled",
  30. "CheckpointPolicy",
  31. "SelectiveCheckpointContext",
  32. "create_selective_checkpoint_contexts",
  33. "SAC_IGNORED_OPS",
  34. "GraphExecGroup",
  35. ]
  36. _DEFAULT_DETERMINISM_MODE = "default"
  37. _checkpoint_debug_enabled: Optional[bool] = None
  38. @contextlib.contextmanager
  39. def set_checkpoint_debug_enabled(enabled: Optional[bool]):
  40. """
  41. Context manager that sets whether checkpoint should print additional debug
  42. information when running. See the ``debug`` flag for
  43. :func:`~torch.utils.checkpoint.checkpoint` for more information. Note that
  44. when set, this context manager overrides the value of ``debug`` passed to
  45. checkpoint. To defer to the local setting, pass ``None`` to this context.
  46. Args:
  47. enabled (bool): Whether checkpoint should print debug information.
  48. Default is 'None'.
  49. """
  50. global _checkpoint_debug_enabled
  51. try:
  52. prev = _checkpoint_debug_enabled
  53. _checkpoint_debug_enabled = enabled
  54. yield
  55. finally:
  56. _checkpoint_debug_enabled = prev
  57. def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
  58. if isinstance(inputs, tuple):
  59. out = []
  60. for inp in inputs:
  61. if not isinstance(inp, torch.Tensor):
  62. out.append(inp)
  63. continue
  64. x = inp.detach()
  65. x.requires_grad = inp.requires_grad
  66. out.append(x)
  67. return tuple(out)
  68. else:
  69. raise RuntimeError(
  70. "Only tuple of tensors is supported. Got Unsupported input type: ",
  71. type(inputs).__name__,
  72. )
  73. def check_backward_validity(inputs: Iterable[Any]) -> None:
  74. if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)):
  75. warnings.warn(
  76. "None of the inputs have requires_grad=True. Gradients will be None", stacklevel=2
  77. )
  78. def _get_device_module(device="cuda"):
  79. if device == "meta":
  80. return torch.device("meta")
  81. device_module = getattr(torch, device)
  82. return device_module
  83. class DefaultDeviceType:
  84. r"""
  85. A class that manages the default device type for checkpointing.
  86. If no non-CPU tensors are present, the default device type will
  87. be used. The default value is 'cuda'. The device type is used in
  88. the checkpointing process when determining which device states
  89. to save and restore for recomputation.
  90. """
  91. _default_device_type: Optional[str] = None
  92. @staticmethod
  93. def set_device_type(device: str = "cuda") -> None:
  94. """
  95. Set the default device type for checkpointing.
  96. Args:
  97. device (str): The device type to be set as default. Default is 'cuda'.
  98. """
  99. DefaultDeviceType._default_device_type = device
  100. @staticmethod
  101. def get_device_type() -> str:
  102. """
  103. Get the current default device type for checkpointing.
  104. Returns:
  105. str: The current default device type.
  106. """
  107. if not DefaultDeviceType._default_device_type:
  108. DefaultDeviceType._default_device_type = acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu"
  109. return DefaultDeviceType._default_device_type
  110. def _infer_device_type(*args):
  111. device_types = []
  112. def add_device_types(arg) -> None:
  113. nonlocal device_types
  114. if isinstance(arg, torch.Tensor) and arg.device.type != "cpu":
  115. device_types.append(arg.device.type)
  116. tree_map(add_device_types, args)
  117. device_types_set = set(device_types)
  118. if len(device_types_set) > 1:
  119. warnings.warn(
  120. "Tensor arguments, excluding CPU tensors, are detected on at least two types of devices. "
  121. "Device state will only be saved for devices of a single device type, and the remaining "
  122. "devices will be ignored. Consequently, if any checkpointed functions involve randomness, "
  123. "this may result in incorrect gradients. (Note that if CUDA devices are among the devices "
  124. "detected, it will be prioritized; otherwise, the first device encountered will be selected.)"
  125. f"\nDevice types: {sorted(device_types_set)} first device type: {device_types[0]}", stacklevel=2
  126. )
  127. if len(device_types) == 0:
  128. return DefaultDeviceType.get_device_type()
  129. elif "cuda" in device_types_set:
  130. return "cuda"
  131. else:
  132. return device_types[0]
  133. # We can't know if the run_fn will internally move some args to different devices,
  134. # which would require logic to preserve rng states for those devices as well.
  135. # We could paranoically stash and restore ALL the rng states for all visible devices,
  136. # but that seems very wasteful for most cases. Compromise: Stash the RNG state for
  137. # the device of all Tensor args.
  138. #
  139. # To consider: maybe get_device_states and set_device_states should reside in torch/random.py?
  140. def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
  141. # This will not error out if "arg" is a CPU tensor or a non-tensor type because
  142. # the conditionals short-circuit.
  143. fwd_device_ids = []
  144. def add_device_ids(arg) -> None:
  145. nonlocal fwd_device_ids
  146. if isinstance(arg, torch.Tensor) and arg.device.type not in {"cpu", "meta"}:
  147. fwd_device_ids.append(arg.get_device())
  148. tree_map(add_device_ids, args)
  149. fwd_device_states = []
  150. device_module = _get_device_module(_infer_device_type(*args))
  151. for device_id in fwd_device_ids:
  152. with device_module.device(device_id):
  153. fwd_device_states.append(device_module.get_rng_state())
  154. return fwd_device_ids, fwd_device_states
  155. def set_device_states(devices, states, *, device_type=None) -> None:
  156. """Sets random number generator states for the specified devices.
  157. Args:
  158. devices: Device ids to set states for.
  159. states: States to set.
  160. device_type: ``device_type`` of the devices to set states for. Default
  161. is the device returned by a call to ``DefaultDeviceType.get_device_type()``,
  162. which is ``cuda`` if not changed by calling ``DefaultDeviceType::set_device_type()``.
  163. """
  164. if device_type is None:
  165. device_type = DefaultDeviceType.get_device_type()
  166. if device_type == "meta":
  167. return
  168. device_module = _get_device_module(device_type)
  169. for device, state in zip(devices, states, strict=False):
  170. with device_module.device(device):
  171. device_module.set_rng_state(state)
  172. def _get_autocast_kwargs(device_type="cuda"):
  173. if torch.amp.is_autocast_available(device_type):
  174. device_autocast_kwargs = {
  175. "enabled": torch.is_autocast_enabled(device_type),
  176. "dtype": torch.get_autocast_dtype(device_type),
  177. "cache_enabled": torch.is_autocast_cache_enabled(),
  178. }
  179. else:
  180. device_autocast_kwargs = None
  181. cpu_autocast_kwargs = {
  182. "enabled": torch.is_autocast_enabled('cpu'),
  183. "dtype": torch.get_autocast_dtype('cpu'),
  184. "cache_enabled": torch.is_autocast_cache_enabled(),
  185. }
  186. return device_autocast_kwargs, cpu_autocast_kwargs
  187. class CheckpointFunction(torch.autograd.Function):
  188. @staticmethod
  189. # pyrefly: ignore [bad-override]
  190. def forward(ctx, run_function, preserve_rng_state, *args):
  191. check_backward_validity(args)
  192. ctx.run_function = run_function
  193. ctx.preserve_rng_state = preserve_rng_state
  194. # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
  195. ctx.device_type = _infer_device_type(*args)
  196. ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs(
  197. ctx.device_type
  198. )
  199. if preserve_rng_state:
  200. ctx.fwd_cpu_state = torch.get_rng_state()
  201. # Don't eagerly initialize the cuda context by accident.
  202. # (If the user intends that the context is initialized later, within their
  203. # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
  204. # we have no way to anticipate this will happen before we run the function.)
  205. ctx.had_device_in_fwd = False
  206. device_module = _get_device_module(ctx.device_type)
  207. if getattr(device_module, "_initialized", False):
  208. ctx.had_device_in_fwd = True
  209. ctx.fwd_devices, ctx.fwd_device_states = get_device_states(*args)
  210. # Save non-tensor inputs in ctx, keep a placeholder None for tensors
  211. # to be filled out during the backward.
  212. ctx.inputs = []
  213. ctx.tensor_indices = []
  214. tensor_inputs = []
  215. for i, arg in enumerate(args):
  216. if torch.is_tensor(arg):
  217. tensor_inputs.append(arg)
  218. ctx.tensor_indices.append(i)
  219. ctx.inputs.append(None)
  220. else:
  221. ctx.inputs.append(arg)
  222. ctx.save_for_backward(*tensor_inputs)
  223. with torch.no_grad():
  224. outputs = run_function(*args)
  225. return outputs
  226. @staticmethod
  227. def backward(ctx, *args):
  228. if not torch.autograd._is_checkpoint_valid():
  229. raise RuntimeError(
  230. "When use_reentrant=True, torch.utils.checkpoint is incompatible"
  231. " with .grad() or passing an `inputs` parameter to .backward()."
  232. " To resolve this error, you can either set use_reentrant=False,"
  233. " or call .backward() without passing the `inputs` argument."
  234. )
  235. # Copy the list to avoid modifying original list.
  236. inputs = list(ctx.inputs)
  237. tensor_indices = ctx.tensor_indices
  238. tensors = ctx.saved_tensors
  239. # Fill in inputs with appropriate saved tensors.
  240. for i, idx in enumerate(tensor_indices):
  241. inputs[idx] = tensors[i]
  242. # Stash the surrounding rng state, and mimic the state that was
  243. # present at this time during forward. Restore the surrounding state
  244. # when we're done.
  245. rng_devices = []
  246. if ctx.preserve_rng_state and ctx.had_device_in_fwd:
  247. rng_devices = ctx.fwd_devices
  248. with torch.random.fork_rng(
  249. devices=rng_devices, enabled=ctx.preserve_rng_state, device_type=ctx.device_type
  250. ):
  251. if ctx.preserve_rng_state:
  252. torch.set_rng_state(ctx.fwd_cpu_state)
  253. if ctx.had_device_in_fwd:
  254. set_device_states(ctx.fwd_devices, ctx.fwd_device_states, device_type=ctx.device_type)
  255. detached_inputs = detach_variable(tuple(inputs))
  256. device_autocast_ctx = torch.amp.autocast(
  257. device_type=ctx.device_type, **ctx.device_autocast_kwargs
  258. ) if torch.amp.is_autocast_available(ctx.device_type) else contextlib.nullcontext()
  259. with torch.enable_grad(), device_autocast_ctx, torch.amp.autocast("cpu", **ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]
  260. outputs = ctx.run_function(*detached_inputs)
  261. if isinstance(outputs, torch.Tensor):
  262. outputs = (outputs,)
  263. # run backward() with only tensor that requires grad
  264. outputs_with_grad = []
  265. args_with_grad = []
  266. for i in range(len(outputs)):
  267. if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
  268. outputs_with_grad.append(outputs[i])
  269. args_with_grad.append(args[i])
  270. if len(outputs_with_grad) == 0:
  271. raise RuntimeError(
  272. "none of output has requires_grad=True,"
  273. " this checkpoint() is not necessary"
  274. )
  275. torch.autograd.backward(outputs_with_grad, args_with_grad)
  276. grads = tuple(
  277. inp.grad if isinstance(inp, torch.Tensor) else None
  278. for inp in detached_inputs
  279. )
  280. return (None, None) + grads
  281. def noop_context_fn():
  282. return contextlib.nullcontext(), contextlib.nullcontext()
# Note: [torch.compile and checkpoint]
# TorchDynamo does not step inside utils.checkpoint function. The flow
# looks like this:
#  1) TorchDynamo tries to wrap utils.checkpoint in a HigherOrderOp by
#     speculatively checking if the forward function is safe to trace.
#  2) If yes, then the Dynamo-generated Fx graph has the wrapped higher
#     order op. As a result, TorchDynamo does not look inside utils.checkpoint.
#  3) If not, then TorchDynamo falls back to eager by performing a graph
#     break. And here, the following disable wrapper ensures that
#     TorchDynamo does not trigger again on the frames created by
#     utils.checkpoint innards.
  294. @torch._disable_dynamo
  295. def checkpoint(
  296. function,
  297. *args,
  298. use_reentrant: Optional[bool] = None,
  299. context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
  300. determinism_check: str = _DEFAULT_DETERMINISM_MODE,
  301. debug: bool = False,
  302. early_stop: bool = True,
  303. **kwargs
  304. ):
  305. r"""Checkpoint a model or part of the model.
  306. Activation checkpointing is a technique that trades compute for memory.
  307. By default, tensors computed during the forward pass are kept alive until
  308. they are used in gradient computations in the backward pass. To reduce this
  309. memory usage, tensors produced in the passed :attr:`function` are not kept
  310. alive until the backward pass. Instead, any passed tensors in :attr:`args`
  311. are kept alive, and the unsaved tensors are recomputed by re-invoking
  312. :attr:`function` in the backward pass as needed for gradient computation.
  313. Activation checkpointing can be applied to any part of a model -- this is
  314. sometimes described as "checkpointing" that part of the model.
  315. There are currently two checkpointing implementations available, determined
  316. by the :attr:`use_reentrant` parameter. It is recommended that you use
  317. ``use_reentrant=False``. Please refer the note below for a discussion of
  318. their differences.
  319. .. warning::
  320. If the :attr:`function` invocation during the backward pass differs
  321. from the forward pass, e.g., due to a global variable, the checkpointed
  322. version may not be equivalent, potentially causing an
  323. error being raised or leading to silently incorrect gradients.
  324. .. warning::
  325. The ``use_reentrant`` parameter should be passed explicitly. In version
  326. 2.9 we will raise an exception if ``use_reentrant`` is not passed.
  327. If you are using the ``use_reentrant=True`` variant, please refer to the
  328. note below for important considerations and potential limitations.
  329. .. note::
  330. The reentrant variant of checkpoint (``use_reentrant=True``) and
  331. the non-reentrant variant of checkpoint (``use_reentrant=False``)
  332. differ in the following ways:
  333. * Non-reentrant checkpoint stops recomputation as soon as all needed
  334. intermediate activations have been recomputed. This feature is enabled
  335. by default, but can be disabled with :func:`set_checkpoint_early_stop`.
  336. Reentrant checkpoint always recomputes :attr:`function` in its
  337. entirety during the backward pass.
  338. * The reentrant variant does not record the autograd graph during the
  339. forward pass, as it runs with the forward pass under
  340. :func:`torch.no_grad`. The non-reentrant version does record the
  341. autograd graph, allowing one to perform backward on the graph within
  342. checkpointed regions.
  343. * The reentrant checkpoint only supports the
  344. :func:`torch.autograd.backward` API for the backward pass without its
  345. `inputs` argument, while the non-reentrant version supports all ways
  346. of performing the backward pass.
  347. * At least one input and output must have ``requires_grad=True`` for the
  348. reentrant variant. If this condition is unmet, the checkpointed part
  349. of the model will not have gradients. The non-reentrant version does
  350. not have this requirement.
  351. * The reentrant version does not consider tensors in nested structures
  352. (e.g., custom objects, lists, dicts, etc) as participating in
  353. autograd, while the non-reentrant version does.
  354. * The reentrant checkpoint does not support checkpointed regions with
  355. detached tensors from the computational graph, whereas the
  356. non-reentrant version does. For the reentrant variant, if the
  357. checkpointed segment contains tensors detached using ``detach()`` or
  358. with :func:`torch.no_grad`, the backward pass will raise an error.
  359. This is because ``checkpoint`` makes all the outputs require gradients
  360. and this causes issues when a tensor is defined to have no gradient in
  361. the model. To avoid this, detach the tensors outside of the
  362. ``checkpoint`` function.
  363. Args:
  364. function: describes what to run in the forward pass of the model or
  365. part of the model. It should also know how to handle the inputs
  366. passed as the tuple. For example, in LSTM, if user passes
  367. ``(activation, hidden)``, :attr:`function` should correctly use the
  368. first input as ``activation`` and the second input as ``hidden``
  369. args: tuple containing inputs to the :attr:`function`
  370. Keyword args:
  371. preserve_rng_state(bool, optional): Omit stashing and restoring
  372. the RNG state during each checkpoint. Note that under torch.compile,
  373. this flag doesn't take effect and we always preserve RNG state.
  374. Default: ``True``
  375. use_reentrant(bool):
  376. specify whether to use the activation checkpoint variant that
  377. requires reentrant autograd. This parameter should be passed
  378. explicitly. In version 2.9 we will raise an exception if
  379. ``use_reentrant`` is not passed. If ``use_reentrant=False``,
  380. ``checkpoint`` will use an implementation that does not require
  381. reentrant autograd. This allows ``checkpoint`` to support additional
  382. functionality, such as working as expected with
  383. ``torch.autograd.grad`` and support for keyword arguments input into
  384. the checkpointed function.
  385. context_fn(Callable, optional): A callable returning a tuple of two
  386. context managers. The function and its recomputation will be run
  387. under the first and second context managers respectively.
  388. This argument is only supported if ``use_reentrant=False``.
  389. determinism_check(str, optional): A string specifying the determinism
  390. check to perform. By default it is set to ``"default"`` which
  391. compares the shapes, dtypes, and devices of the recomputed tensors
  392. against those the saved tensors. To turn off this check, specify
  393. ``"none"``. Currently these are the only two supported values.
  394. Please open an issue if you would like to see more determinism
  395. checks. This argument is only supported if ``use_reentrant=False``,
  396. if ``use_reentrant=True``, the determinism check is always disabled.
  397. debug(bool, optional): If ``True``, error messages will also include
  398. a trace of the operators ran during the original forward computation
  399. as well as the recomputation. This argument is only supported if
  400. ``use_reentrant=False``.
  401. early_stop(bool, optional): If ``True``, non-reentrant checkpoint stops
  402. recomputation as soon as it has computed all needed Tensors. This
  403. argument is ignored if ``use_reentrant=True``. Can be overridden
  404. globally using :func:`set_checkpoint_early_stop` context manager.
  405. Default: ``True``.
  406. Returns:
  407. Output of running :attr:`function` on :attr:`*args`
  408. """
  409. if use_reentrant is None:
  410. warnings.warn(
  411. "torch.utils.checkpoint: the use_reentrant parameter should be "
  412. "passed explicitly. Starting in PyTorch 2.9, calling checkpoint "
  413. "without use_reentrant will raise an exception. use_reentrant=False is "
  414. "recommended, but if you need to preserve the current default "
  415. "behavior, you can pass use_reentrant=True. Refer to docs for more "
  416. "details on the differences between the two variants.",
  417. stacklevel=2
  418. )
  419. use_reentrant = True
  420. # Hack to mix *args with **kwargs in a python 2.7-compliant way
  421. preserve = kwargs.pop("preserve_rng_state", True)
  422. if kwargs and use_reentrant:
  423. raise ValueError(
  424. "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
  425. )
  426. if use_reentrant:
  427. if context_fn is not noop_context_fn or debug is not False:
  428. raise ValueError(
  429. "Passing `context_fn` or `debug` is only supported when "
  430. "use_reentrant=False."
  431. )
  432. return CheckpointFunction.apply(function, preserve, *args)
  433. else:
  434. gen = _checkpoint_without_reentrant_generator(
  435. function, preserve, context_fn, determinism_check, debug, early_stop, *args, **kwargs
  436. )
  437. # Runs pre-forward logic
  438. next(gen)
  439. ret = function(*args, **kwargs)
  440. # Runs post-forward logic
  441. try:
  442. next(gen)
  443. except StopIteration:
  444. return ret
  445. def checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs):
  446. r"""Checkpoint a sequential model to save memory.
  447. Sequential models execute a list of modules/functions in order
  448. (sequentially). Therefore, we can divide such a model in various segments
  449. and checkpoint each segment. All segments except the last will not store
  450. the intermediate activations. The inputs of each checkpointed segment will
  451. be saved for re-running the segment in the backward pass.
  452. .. warning::
  453. The ``use_reentrant`` parameter should be passed explicitly. In version
  454. 2.9 we will raise an exception if ``use_reentrant`` is not passed.
  455. If you are using the ``use_reentrant=True` variant, please see
  456. :func:`~torch.utils.checkpoint.checkpoint` for
  457. the important considerations and limitations of this variant. It is
  458. recommended that you use ``use_reentrant=False``.
  459. .. warning:
  460. Since PyTorch 1.4, it allows only one Tensor as the input and
  461. intermediate outputs, just like :class:`torch.nn.Sequential`.
  462. Args:
  463. functions: A :class:`torch.nn.Sequential` or the list of modules or
  464. functions (comprising the model) to run sequentially.
  465. segments: Number of chunks to create in the model
  466. input: A Tensor that is input to :attr:`functions`
  467. preserve_rng_state(bool, optional): Omit stashing and restoring
  468. the RNG state during each checkpoint.
  469. Default: ``True``
  470. use_reentrant(bool):
  471. specify whether to use the activation checkpoint variant that
  472. requires reentrant autograd. This parameter should be passed
  473. explicitly. In version 2.5 we will raise an exception if
  474. ``use_reentrant`` is not passed. If ``use_reentrant=False``,
  475. ``checkpoint`` will use an implementation that does not require
  476. reentrant autograd. This allows ``checkpoint`` to support additional
  477. functionality, such as working as expected with
  478. ``torch.autograd.grad`` and support for keyword arguments input into
  479. the checkpointed function.
  480. Returns:
  481. Output of running :attr:`functions` sequentially on :attr:`*inputs`
  482. Example:
  483. >>> # xdoctest: +SKIP("stub")
  484. >>> model = nn.Sequential(...)
  485. >>> input_var = checkpoint_sequential(model, chunks, input_var)
  486. """
  487. if use_reentrant is None:
  488. warnings.warn(
  489. "torch.utils.checkpoint.checkpoint_sequential: the use_reentrant "
  490. "parameter should be passed explicitly. "
  491. "In version 2.9 we will raise an exception if use_reentrant "
  492. "is not passed. use_reentrant=False is "
  493. "recommended, but if you need to preserve the current default "
  494. "behavior, you can pass use_reentrant=True. Refer to docs for more "
  495. "details on the differences between the two variants.", stacklevel=2
  496. )
  497. use_reentrant = True
  498. # Hack for keyword-only parameter in a python 2.7-compliant way
  499. preserve = kwargs.pop("preserve_rng_state", True)
  500. if kwargs:
  501. raise ValueError(
  502. "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
  503. )
  504. def run_function(start, end, functions):
  505. def forward(input):
  506. for j in range(start, end + 1):
  507. input = functions[j](input)
  508. return input
  509. return forward
  510. if isinstance(functions, torch.nn.Sequential):
  511. functions = list(functions.children())
  512. segment_size = len(functions) // segments
  513. # the last chunk has to be non-volatile
  514. end = -1
  515. for start in range(0, segment_size * (segments - 1), segment_size):
  516. end = start + segment_size - 1
  517. input = checkpoint(
  518. run_function(start, end, functions),
  519. input,
  520. use_reentrant=use_reentrant,
  521. preserve_rng_state=preserve,
  522. )
  523. return run_function(end + 1, len(functions) - 1, functions)(input)
  524. def _internal_assert(cond) -> None:
  525. if not cond:
  526. raise AssertionError(
  527. "Something went unexpectedly wrong in activation checkpoint. "
  528. "Please report this bug by filing an issue to PyTorch."
  529. )
  530. # NOTE [ Nestable Checkpoint ]
  531. #
  532. # The semantics of nested checkpoint can be defined by two basic rules.
  533. # Following the two rules leads to an important implication that is central
  534. # to motivating the design.
  535. #
  536. # Rule 1. Saved tensors are managed by inner-most checkpoint only and hidden
  537. # from any outer layers of checkpoint.
  538. #
  539. # Rule 2. The inputs of inner checkpoints are treated as tensors saved to its
  540. # parent checkpoint.
  541. #
  542. # Implication: To recompute any given saved tensor, we need to recompute all of
  543. # the checkpoints wrapping it.
  544. #
  545. # Why is this implied? To unpack a saved tensor X during backward we need to
  546. # recompute the inner-most checkpoint (#1), and in order to recompute that
  547. # checkpoint we need to have its inputs, which are managed by that checkpoint's
  548. # parent (#2), which thus also needs to be recomputed first. Continue this line
  549. # of reasoning and we realize that in order to unpack X, all checkpoints that
  550. # were active at the time X was saved need to be recomputed. (unless we have
  551. # already done so in that backward for some other saved tensor).
  552. #
  553. # In practice, we use a noop autograd Function to save inputs as saved tensors.
  554. # During unpack, calling ctx.saved_tensors triggers the parent checkpoint to
  555. # recompute.
  556. #
  557. # Rule 3. We should start recomputation as if there are no checkpoints currently
  558. # active. Checkpoints encountered during recomputation are still
  559. # respected.
  560. #
  561. # When we start recomputation, we push the saved variable hook meant for
  562. # recomputation on the stack. See examples in Rule 6 for more context.
  563. #
  564. # * * * *
  565. #
  566. # Beyond the basic semantics specific to nested checkpoint, we impose several
  567. # more constraints that may apply to checkpointing in general.
  568. #
  569. # Rule 4. Lifetime of recomputed tensors
  570. #
  571. # Recomputed tensors are considered specific to particular invocations
  572. # of backward and are always cleared immediately as they are unpacked.
  573. # Particularly, we require this to happen even if retain_graph=True.
  574. #
  575. # [ Implementation details of Rule 4 ]
  576. #
  577. # If we were okay with recomputed tensors staying alive after backward is run
  578. # with retain_graph=True, we would store recomputed variables as the values of a
  579. # WeakKeyDictionary and pack strong references to the keys, so that as we
  580. # backward, those packed keys would be cleared as long as retain_graph=False.
  581. # Clearing the packed key clears the corresponding entry in the WKD.
  582. #
  583. # If we wish recomputed variables to be immediately cleared as we unpack them in
  584. # the retain_graph=True case, we cannot rely on the packed keys to be cleared by
  585. # backward automatically. Instead of packing the strong reference to the key
  586. # directly, we pack a container object, which we manually clear as we unpack.
  587. #
  588. # An important detail is that if a second backward happens, the second
  589. # recomputation needs to reset the container with a newly created key.
  590. #
  591. # Rule 5. Stop recomputation as soon as we've recomputed the saved tensors we
  592. # know we need.
  593. #
  594. # [ Implementation details of Rule 5 ]
  595. #
  596. # During recomputation, raise an exception if the number of recomputed tensors
  597. # matches the number of tensors that we expected to recompute. We wrap the
  598. # recomputation call with a try-catch to catch this specific exception. See
  599. # Rule #6 below for some examples.
  600. #
  601. # Rule 6. We support doing backward inside checkpoint context
  602. #
  603. # [ retain_graph is True]
  604. #
  605. # def fn(x):
  606. # y = x.sin()
  607. # z = y.cos()
  608. # gx, = torch.autograd.grad(z, x, retain_graph=True)
  609. # return gx, z
  610. #
  611. # out = checkpoint(fn)(inp)
  612. # out.backward()
  613. #
  614. # Because z is saved by cos while checkpoint is enabled, it would not be
  615. # actually saved, and so the .grad() call inside must trigger a recomputation.
  616. #
  617. # During recomputation the "inner pack hook" has two responsibilities:
  618. #
  619. # 1) As usual, populating the WeakKeyDictionary storing recomputed tensors
  620. # 2) Pack the actual tensor (detached) so that one may perform backward on the
  621. # recomputed graph. The tensors saved to this graph will live until the end
  622. # of recomputation, or die earlier if someone performs backward with
  623. # retain_graph=False.
  624. #
  625. # More generally performing backward on the recomputed graph occurs in the
  626. # following cases:
  627. # - If backward is performed inside forward,
  628. # - During the original forward IF early-stop is disabled
  629. # - During the original backward
  630. # - If there are multiple .grad()/.backward() calls, we would perform backward
  631. # on the recomputed graph even if early-stop is enabled (see the example below)
  632. #
  633. # [ retain_graph is False ]
  634. #
  635. # The example below shows what happens if during recomputation we find that some
  636. # of the tensors we are trying to recompute have already been cleared.
  637. #
  638. # Spoiler: we don't do anything special, we just skip over them!
  639. #
  640. # def fn(x):
  641. # y = x.sin() # (1)
  642. # z = y.cos() # (2)
  643. # gx, = torch.autograd.grad(z, x) # (3)
  644. # return x.cos() * gx # (4)
  645. #
  646. # out = checkpoint(fn)(inp)
  647. # out.backward() # (5)
  648. #
  649. # 1, 2. Don't save x and y since we are inside a checkpoint.
  650. # 3. Trigger a recompute of fn since x and y weren't saved.
  651. # And depending on whether early stop is enabled, either stop at (2) or
  652. # continue running the function.
  653. # Because we are running backward with retain_graph=False, we clear x and y's
  654. # holders.
  655. # 4. Don't save x since we are inside a checkpoint.
  656. # 5. Calling backward triggers another recompute of fn. During recompute, we see
  657. # that x and y have already been cleared in the original graph as indicated
  658. # by holder=None. We skip over them. We still save x at (4) (since its holder
  659. # is still alive.)
  660. _enable_checkpoint_early_stop: Optional[bool] = None
  661. @contextlib.contextmanager
  662. def set_checkpoint_early_stop(enable: bool):
  663. """Context manager that sets whether checkpoint should stop recomputation early.
  664. By default, non-reentrant checkpoint stops recomputation as soon as it
  665. has computed all needed Tensors. This context manager can be used to disable
  666. that feature if it is problematic for your specific application.
  667. This context manager only needs to be active when forward is run. It does
  668. not need to be active during backward.
  669. Example::
  670. >>> # xdoctest: +SKIP(failing)
  671. >>> message = "saved tensors default hooks are disabled"
  672. >>> with set_checkpoint_early_stop(False):
  673. ... # Any checkpoint under this context manager will respect this
  674. ... # context manager, even if its backward is performed outside.
  675. ... out = checkpoint(fn, inputs)
  676. ...
  677. >>> out.backward()
  678. """
  679. global _enable_checkpoint_early_stop
  680. try:
  681. prev = _enable_checkpoint_early_stop
  682. _enable_checkpoint_early_stop = enable
  683. yield
  684. finally:
  685. _enable_checkpoint_early_stop = prev
  686. class _Handle:
  687. pass
  688. class _Holder:
  689. def __init__(self) -> None:
  690. self.handles: Dict[int, Optional[_Handle]] = {}
  691. class _NoopSaveInputs(torch.autograd.Function):
  692. @staticmethod
  693. # pyrefly: ignore [bad-override]
  694. def forward(*args):
  695. return torch.empty((0,))
  696. @staticmethod
  697. def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
  698. # Only tensors can be saved with ctx.save_for_backward, everything else
  699. # is captured by get_args, which is saved directly on ctx
  700. tensor_indices, tensors = zip(
  701. *[(i, o) for i, o in enumerate(inputs) if isinstance(o, torch.Tensor)], strict=False
  702. )
  703. idx2saved_idx = {b: a for a, b in enumerate(tensor_indices)}
  704. # args but with tensors replaced with None as placeholders
  705. args = [None if isinstance(o, torch.Tensor) else o for o in inputs]
  706. def get_args(saved_tensors):
  707. # restore the placeholders with the original tensors grabbed from
  708. # ctx.saved_tensors (which may be saved on a parent checkpoint if
  709. # this checkpoint is nested, and that would trigger a recursive
  710. # unpack!)
  711. ret = [
  712. saved_tensors[idx2saved_idx[i]] if i in tensor_indices else o
  713. for i, o in enumerate(args)
  714. ]
  715. # grab the tail since we also saved the dummy to avoid having to explicitly
  716. # handle the case where there are no tensor inputs
  717. return ret[1:]
  718. ctx.get_args = get_args
  719. ctx.save_for_backward(*tensors)
  720. @staticmethod
  721. def backward(ctx, *grad_outputs) -> NoReturn:
  722. raise AssertionError("Did not expect to backward on this graph")
  723. class _CheckpointFrame:
  724. def __init__(self, recompute_fn, early_stop, unpack_error_cb, metadata_fn) -> None:
  725. self.recompute_fn = recompute_fn
  726. self.input_saver = None
  727. self.weak_holders: List[ReferenceType] = []
  728. # We store this as a weakkeydictionary so that in the case of a partial
  729. # backward, the entries in the dict are cleared alongside the Holder
  730. # which will be removed when the SavedVariable is cleared.
  731. self.recomputed: DefaultDict[
  732. int, weakref.WeakKeyDictionary[_Handle, torch.Tensor]
  733. ] = defaultdict(weakref.WeakKeyDictionary)
  734. # We need both recomp_counter and recomputed since they can diverge
  735. # https://github.com/pytorch/pytorch/pull/90105#discussion_r1135889885
  736. self.recomp_counter: DefaultDict[int, int] = defaultdict(int)
  737. self.is_recomputed: DefaultDict[int, bool] = defaultdict(bool)
  738. # See Rule 5
  739. self.early_stop = early_stop
  740. # Debugging
  741. self.metadata_fn = metadata_fn
  742. self.unpack_error_cb = unpack_error_cb
  743. self.x_metadatas = []
  744. self.forward_completed = False
  745. self.ignore_saved_mismatch = False
  746. def check_recomputed_tensors_match(self, gid) -> None:
  747. if self.ignore_saved_mismatch:
  748. # TODO: we can probably make this check stricter by checking that
  749. # the metadata of the first tensors still match.
  750. return
  751. # NOTE [ Error handling for checkpoint ]
  752. #
  753. # At a high level, we need to check that the tensors saved
  754. # during original forward matches tensors saved during recompute
  755. # This means handling 3 cases:
  756. #
  757. # 1. During recompute, more tensors were saved.
  758. #
  759. # Usually this is hidden due to the StopRecomputationError
  760. # but if early stop is not enabled, or we would have errored
  761. # anyway because there aren't enough weak_holders. But we
  762. # do want to have a nice error. See the _recomputation_hook
  763. # for details.
  764. if not len(self.weak_holders) == self.recomp_counter[gid]:
  765. # 2. During recompute, fewer tensors were saved
  766. #
  767. # We know that every time we save something do original forward
  768. # we append to weak_holder, and every time we save a tensor
  769. # during recompute we increment recompute_counter.
  770. raise CheckpointError(
  771. "torch.utils.checkpoint: A different number of tensors was saved "
  772. "during the original forward and recomputation.\n"
  773. f"Number of tensors saved during forward: {len(self.weak_holders)}\n"
  774. f"Number of tensors saved during recomputation: {self.recomp_counter[gid]}.\n"
  775. f"{_debug_tip_msg}"
  776. )
  777. # 3. During recompute, the same tensors were saved, but they
  778. # have different metadata
  779. nb_meta_different = []
  780. for idx, weak_holder in enumerate(self.weak_holders):
  781. holder = weak_holder()
  782. if holder is None:
  783. continue
  784. # We've seen all holders since we iterate over them in order
  785. # For every holder that is still alive now, it must've been
  786. # alive when we saw it during recompute, therefore, the
  787. # gid must be set.
  788. _internal_assert(gid in holder.handles)
  789. # We know this is the first unpack, so it couldn't have been set
  790. # to None yet.
  791. _internal_assert(holder.handles[gid] is not None)
  792. # We always set these together in the recomputation hook
  793. _internal_assert(holder.handles[gid] in self.recomputed[gid])
  794. # see pack hook, x_metadata is 1:1 with weak_holders.
  795. x_meta = self.x_metadatas[idx]
  796. recomputed_x = self.recomputed[gid][holder.handles[gid]]
  797. if x_meta != self.metadata_fn(recomputed_x):
  798. nb_meta_different.append((idx, x_meta, self.metadata_fn(recomputed_x)))
  799. if len(nb_meta_different) > 0:
  800. mismatched_tensors = ""
  801. for idx, x_meta, recomputed_meta in nb_meta_different:
  802. mismatched_tensors += (
  803. f"tensor at position {idx}:\n"
  804. f"saved metadata: {x_meta}\n"
  805. f"recomputed metadata: {recomputed_meta}\n"
  806. )
  807. raise CheckpointError(
  808. "torch.utils.checkpoint: Recomputed values for the following tensors "
  809. "have different metadata than during the forward pass.\n"
  810. f"{mismatched_tensors}.\n"
  811. f"{_debug_tip_msg}"
  812. )
  813. _debug_tip_msg = """
  814. Tip: To see a more detailed error message, either pass `debug=True` to
  815. `torch.utils.checkpoint.checkpoint(...)` or wrap the code block
  816. with `with torch.utils.checkpoint.set_checkpoint_debug_enabled(True):` to
  817. enable checkpoint‑debug mode globally.
  818. """
  819. _checkpoint_error_template = """ \
  820. An error happened while unpacking tensors; dumping logs of latest computation
  821. because you passed `debug=True` to `torch.utils.checkpoint.checkpoint()`.
  822. Scroll all the way down for guidance on how to navigate these logs.
  823. +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
  824. | 1. Stack traces of the operators that ran in the original forward |
  825. +------------------------------------------------------------------------------+
  826. {forward_traces}
  827. +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
  828. | 2. Stack traces of the operators that ran during recomputation |
  829. +------------------------------------------------------------------------------+
  830. {recompute_traces}
  831. +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
  832. | 3. Log of operators in the original forward and recomputation |
  833. +------------------------------------------------------------------------------+
  834. (Scroll up to correlate stack traces with each operation listed below. This
  835. helps identify their source in the code.)
  836. IMPORTANT: Differences in "detach" calls between the original forward and the
  837. recomputation are expected. They are introduced by the checkpointing
  838. mechanism and can be ignored.
  839. Operations executed during the original forward:
  840. {forward_ops}
  841. Operations executed during recomputation:
  842. {recompute_ops}
  843. +------------------------------------------------------------------------------+
  844. ERROR: Detected non-determinism while running activation checkpointing
  845. You are seeing this error because you passed `debug=True` to checkpoint and
  846. tensors to be saved during the original forward and differ between those saved
  847. during recomputation. This can happen if different operators were ran in the
  848. original forward and in the recomputation.
  849. To identify where the mismatch may be coming from, you can do the following:
  850. 1) Compare the operators ran during original forward and recomputation to
  851. see where they differ. These operators are printed above in the order they
  852. were executed.
  853. 2) Review the stack trace for each operator to locate its invocation source.
  854. Each operator's stack trace is printed in their execution order.
  855. Note that the logs can be quite long. Here's how they are structured:
  856. (Tip: you can Ctrl-f for these headers)
  857. 1. Stack traces of the operators that ran in the original forward
  858. 2. Stack traces of the operators that ran during recomputation
  859. 3. Log of operators in the original forward and recomputation
  860. 4. Error message <--- You are here
  861. --------------------------------------------------------------------------------
  862. """
  863. class CheckpointError(RuntimeError):
  864. pass
  865. def _get_debug_context_and_cb() -> Tuple[Callable[[], Any], Callable[[CheckpointError], None]]:
  866. # This function returns the context_fn and error_cb to be used by the
  867. # checkpointing mechanism. error_cb is invoked when an error is detected
  868. # during unpack.
  869. # record_context_cpp is not support on non-linux non-x86_64 platforms
  870. cpp_tb = platform.machine() == 'x86_64' and platform.system() == 'Linux'
  871. class CaptureLogs:
  872. def __init__(self) -> None:
  873. self.logs = None
  874. self.tbs = None
  875. def get_context_manager(self):
  876. @contextlib.contextmanager
  877. def logging_mode():
  878. with LoggingTensorMode(), \
  879. capture_logs(True, python_tb=True, script_tb=True, cpp_tb=cpp_tb) as logs_and_tb:
  880. # pyrefly: ignore [bad-assignment]
  881. self.logs, self.tbs = logs_and_tb
  882. yield logs_and_tb
  883. return logging_mode()
  884. capture_logs_fwd = CaptureLogs()
  885. capture_logs_recompute = CaptureLogs()
  886. def unpack_error_cb(e: CheckpointError) -> NoReturn:
  887. def get_str_tb(label, capture_logs):
  888. out = ""
  889. total_len = len(capture_logs.logs)
  890. for i, (log, tb) in enumerate(zip(capture_logs.logs, capture_logs.tbs, strict=False)):
  891. out += f"{log} ({i + 1} of {total_len} in {label})\n\n"
  892. found_torch_dispatch = False
  893. for line in tb:
  894. # Start printing stack trace only after __torch_dispatch__ is found
  895. is_torch_dispatch = line['name'] == '__torch_dispatch__'
  896. if not found_torch_dispatch and not is_torch_dispatch:
  897. continue
  898. elif is_torch_dispatch:
  899. found_torch_dispatch = True
  900. continue
  901. out += f"{line['filename']}:{line['line']}:{line['name']}\n"
  902. out += "\n\n"
  903. return out
  904. if capture_logs_fwd.logs is None:
  905. raise AssertionError("capture_logs_fwd.logs is None")
  906. if capture_logs_recompute.logs is None:
  907. raise AssertionError("capture_logs_recompute.logs is None")
  908. raise CheckpointError(
  909. _checkpoint_error_template.format(
  910. forward_traces=get_str_tb("original", capture_logs_fwd),
  911. recompute_traces=get_str_tb("recompute", capture_logs_recompute),
  912. forward_ops="\n".join(capture_logs_fwd.logs),
  913. recompute_ops="\n".join(capture_logs_recompute.logs)
  914. )
  915. ) from e
  916. def context_fn():
  917. return capture_logs_fwd.get_context_manager(), capture_logs_recompute.get_context_manager()
  918. return context_fn, unpack_error_cb
  919. def _default_meta_extractor(x: torch.Tensor) -> Dict[str, Any]:
  920. # These properties are fast to check, easy to understand
  921. return {
  922. "shape": x.shape,
  923. "dtype": x.dtype,
  924. "device": x.device
  925. }
  926. _allowed_determinism_checks_to_fns: Dict[str, Callable[[torch.Tensor], Any]] = {
  927. _DEFAULT_DETERMINISM_MODE: _default_meta_extractor,
  928. "none": lambda _: None,
  929. }
  930. # See Rule 5
  931. class _StopRecomputationError(Exception):
  932. pass
class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks):
    """Saved-tensor hooks active while a checkpointed function is recomputed.

    The pack hook pairs the N-th tensor saved during recomputation for graph
    id ``gid`` with the N-th holder recorded in the original forward. If that
    holder is still alive, it attaches a fresh ``_Handle`` to it and stores the
    (detached if needed) tensor in ``target_frame.recomputed[gid]``. The unpack
    hook is the identity, so a graph built during recomputation can itself be
    backwarded (see Rule 6).

    Args:
        target_frame_ref: weak reference to the ``_CheckpointFrame`` being
            recomputed.
        gid: graph id the recomputed tensors are recorded under.
    """

    def __init__(self, target_frame_ref: ReferenceType, gid: Union["GraphExecGroup", int]) -> None:
        def pack_hook(x):
            # Pack the tensor detached (see Rule 6, responsibility 2).
            x = x.detach() if x.requires_grad else x
            target_frame = target_frame_ref()
            if target_frame is None:
                raise AssertionError("Internal error: target_frame reference is None")
            # recomp_idx is this tensor's position in save order; it pairs it
            # with weak_holders[recomp_idx] from the original forward.
            recomp_idx = target_frame.recomp_counter[gid]
            target_frame.recomp_counter[gid] += 1

            if recomp_idx >= len(target_frame.weak_holders):
                # More tensors are being saved than during the original
                # forward. With early stop this should be unreachable, because
                # recomputation is aborted once the last expected tensor is
                # saved (see the end of this hook).
                if target_frame.early_stop:
                    raise AssertionError("Unexpected state: target_frame.early_stop is set")
                if not target_frame.forward_completed:
                    # We run into this case when early stop is not enabled and do
                    # grad within checkpoint.
                    # We need to set this flag, so we don't error out later when
                    # we check if the number of tensors saved during forward and
                    # recomputation match.
                    target_frame.ignore_saved_mismatch = True
                    return x
                raise CheckpointError(
                    "torch.utils.checkpoint: trying to save more tensors during "
                    "recomputation than during the original forward pass.\n"
                    f"{_debug_tip_msg}"
                )

            holder = target_frame.weak_holders[recomp_idx]()

            # This holder may have been cleared because someone may have called
            # backward within forward. If so, we don't need to save.
            if holder is not None:
                # A holder is populated at most once per gid.
                _internal_assert(holder.handles.get(gid, None) is None)
                holder.handles[gid] = _Handle()
                target_frame.recomputed[gid][holder.handles[gid]] = x

            # Early stop (Rule 5): once every expected tensor has been
            # recomputed, abort the rest of the recompute function. The
            # checkpoint unpack hook catches _StopRecomputationError.
            if target_frame.early_stop and target_frame.recomp_counter[gid] == len(
                target_frame.weak_holders
            ):
                raise _StopRecomputationError
            # See Rule 6: [ retain_graph is True ] above
            return x

        def unpack_hook(x):
            # See Rule 6: [ retain_graph is True ] above for an example of when
            # the graph created during recomputation could be backwarded.
            return x

        super().__init__(pack_hook, unpack_hook)
# torch._disable_dynamo creates a reference cycle with the function it
# decorates. Decorating this small module-level helper, which has no closure,
# keeps that cycle from also keeping other objects alive.
# https://github.com/pytorch/pytorch/issues/154642
# Note: does not work when fn is compiled
@torch._disable_dynamo
def _run_fn_with_dynamo_disabled(fn, *args, **kwargs):
    """Call ``fn(*args, **kwargs)`` under ``torch._disable_dynamo`` and return its result.

    Deliberately a closure-free module-level function; see the comment above.
    """
    return fn(*args, **kwargs)
class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
    """Saved-tensor hooks for a checkpoint region described by ``frame``.

    NOTE(review): presumably installed around the original forward; the call
    site is outside this excerpt.

    ``pack_hook`` saves an empty ``_Holder`` instead of the tensor and keeps
    only a weak reference to it in ``frame.weak_holders`` (plus the tensor's
    metadata when ``frame.metadata_fn`` is set).

    ``unpack_hook`` handles the first unpack for a graph id. It rebuilds the
    inputs saved by ``frame.input_saver`` and reruns ``frame.recompute_fn``
    under ``_recomputation_hook``, then checks the result with
    ``frame.check_recomputed_tensors_match``. It returns the recomputed tensor
    that belongs to ``holder``. Each holder can be unpacked at most once per
    graph id.

    If ``frame.unpack_error_cb`` is set, a ``CheckpointError`` raised during
    unpack is handed to it.
    """

    def __init__(self, frame) -> None:
        def pack_hook(x):
            # See Rule 4 above
            holder = _Holder()
            frame.weak_holders.append(weakref.ref(holder))
            # Save metadata to detect non-determinism
            if frame.metadata_fn is not None:
                with torch.no_grad():
                    frame.x_metadatas.append(frame.metadata_fn(x))
            return holder

        def unpack_hook(holder):
            # First check if we're inside a GraphExecGroup context
            gid: Union[GraphExecGroup, None, int] = GraphExecGroup._get_current_group()
            if gid is None:
                # Fallback to using the current graph task id
                gid = torch._C._current_graph_task_id()
                if gid == -1:
                    # generate a temporary id if we trigger unpack outside of a backward call
                    gid = int(uuid.uuid4())

            # Recompute at most once per gid; later unpacks for the same gid
            # read from frame.recomputed.
            if not frame.is_recomputed[gid]:
                ctx = frame.input_saver.grad_fn
                args = ctx.get_args(ctx.saved_tensors)

                try:
                    with _recomputation_hook(
                        weakref.ref(frame), gid
                    ), torch.autograd.enable_grad():
                        # See Note: [compiled autograd and checkpoint unpack hook]
                        _run_fn_with_dynamo_disabled(frame.recompute_fn, *args)
                except _StopRecomputationError:
                    # Early stop (Rule 5): all needed tensors were recomputed.
                    pass
                frame.is_recomputed[gid] = True
                frame.check_recomputed_tensors_match(gid)

            _internal_assert(gid in holder.handles)

            # A None handle means this holder was already unpacked for this gid.
            if holder.handles[gid] is None:
                extra = ""
                if torch._C._get_graph_exec_group() is not None:
                    extra = (
                        "Performing two backward calls that overlap (i.e. require the same "
                        "saved activation in order to compute gradients) is not allowed while "
                        "under the torch.utils.checkpoint.GraphExecGroup context. "
                    )
                raise CheckpointError(
                    "torch.utils.checkpoint: Unpack is being triggered for a tensor that was already "
                    f"unpacked once. {extra}If you are calling ctx.saved_tensors in backward, make sure "
                    "to do so only once. Otherwise please open an issue with details on your use case."
                )
            _internal_assert(holder.handles[gid] in frame.recomputed[gid])
            ret = frame.recomputed[gid][holder.handles[gid]]
            # Rule 4: drop the handle as soon as the tensor is handed out.
            holder.handles[gid] = None
            return ret

        if frame.unpack_error_cb is not None:
            def unpack_hook_with_error_cb(holder):
                try:
                    return unpack_hook(holder)
                except CheckpointError as e:
                    frame.unpack_error_cb(e)
            super().__init__(pack_hook, unpack_hook_with_error_cb)
        else:
            super().__init__(pack_hook, unpack_hook)
  1044. def _is_compiling(func, args, kwargs):
  1045. # Check if we are under AOTAutograd tracing
  1046. # Checking that a functional mode is active should always do what we want
  1047. return torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY) is not None
  1048. class _VersionWrapper:
  1049. # Check that cached tensors are not mutated.
  1050. def __init__(self, val) -> None:
  1051. self.val: Union[torch.Tensor, Any] = val
  1052. self.version: Optional[int] = val._version if isinstance(val, torch.Tensor) else None
  1053. def get_val(self, allow_cache_entry_mutation):
  1054. if self.version is not None and not allow_cache_entry_mutation:
  1055. if self.val._version != self.version:
  1056. # Can we give user a stack trace of where the mutation happened?
  1057. raise RuntimeError(
  1058. "Tensor cached during selective activation checkpoint has been mutated"
  1059. )
  1060. return self.val
  1061. def _maybe_detach(x, any_ret_has_alias_info):
  1062. # We detach for two separate reasons:
  1063. # - For view ops, we need to ensure that when the tensor is returned from
  1064. # CachedDispatchMode, as_view sees that the AutogradMeta is nullptr
  1065. # - Avoid reference cycles
  1066. # For case 1, it is not enough to check whether x has differentiable dtype
  1067. # because non-differentiable dtype can have non-nullptr AutogradMeta, e.g.
  1068. # when the tensor is a view.
  1069. if isinstance(x, torch.Tensor) and (x.is_floating_point() or x.is_complex() or any_ret_has_alias_info):
  1070. with torch._C._SetExcludeDispatchKeyGuard(torch._C.DispatchKey.ADInplaceOrView, False):
  1071. # Ensure that view performed beneath autograd properly propagates
  1072. # version counter. TODO: Use reentrant_dispatch instead of
  1073. # manually manipulating dispatch keys. Using reentrant_dispatch
  1074. # would respect inference_mode, though that is not relevant for
  1075. # this case.
  1076. x = x.detach()
  1077. return x
  1078. class SelectiveCheckpointContext:
  1079. """
  1080. Context passed to policy function during selective checkpointing.
  1081. This class is used to pass relevant metadata to the policy function during
  1082. selective checkpointing. The metadata includes whether the current invocation
  1083. of the policy function is during recomputation or not.
  1084. Example:
  1085. >>> # xdoctest: +SKIP(stub)
  1086. >>>
  1087. >>> def policy_fn(ctx, op, *args, **kwargs):
  1088. >>> print(ctx.is_recompute)
  1089. >>>
  1090. >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
  1091. >>>
  1092. >>> out = torch.utils.checkpoint.checkpoint(
  1093. >>> fn, x, y,
  1094. >>> use_reentrant=False,
  1095. >>> context_fn=context_fn,
  1096. >>> )
  1097. """
  1098. def __init__(self, *, is_recompute) -> None:
  1099. self.is_recompute = is_recompute
  1100. class CheckpointPolicy(enum.Enum):
  1101. """
  1102. Enum for specifying the policy for checkpointing during backpropagation.
  1103. The following policies are supported:
  1104. - ``{MUST,PREFER}_SAVE``: The operation's output will be saved during the forward
  1105. pass and will not be recomputed during the backward pass
  1106. - ``{MUST,PREFER}_RECOMPUTE``: The operation's output will not be saved during the
  1107. forward pass and will be recomputed during the backward pass
  1108. - ``{MUST,PREFER}_CPU_OFFLOAD``: The operation's output will be saved during the
  1109. forward pass, offloaded to CPU, and reloaded to GPU during the backward pass
  1110. Use ``MUST_*`` over ``PREFER_*`` to indicate that the policy should not be overridden
  1111. by other subsystems like `torch.compile`.
  1112. .. note::
  1113. A policy function that always returns ``PREFER_RECOMPUTE`` is
  1114. equivalent to vanilla checkpointing.
  1115. A policy function that returns ``PREFER_SAVE`` every op is
  1116. NOT equivalent to not using checkpointing. Using such a policy would
  1117. save additional tensors not limited to ones that are actually needed for
  1118. gradient computation.
  1119. """
  1120. MUST_SAVE = 0
  1121. PREFER_SAVE = 1
  1122. MUST_RECOMPUTE = 2
  1123. PREFER_RECOMPUTE = 3
  1124. MUST_CPU_OFFLOAD = 4
  1125. PREFER_CPU_OFFLOAD = 5
  1126. def _policy_from_bool(b):
  1127. # For backward compatibility
  1128. return CheckpointPolicy.MUST_SAVE if b else CheckpointPolicy.PREFER_RECOMPUTE
  1129. SAC_IGNORED_OPS = {
  1130. # AC inserts different number of detach during forward and recompute.
  1131. torch.ops.aten.detach.default,
  1132. # AC's determinism check invokes additional metadata ops during forward.
  1133. # With subclasses involved, these metadata ops become dispatchable, this
  1134. # can result in incorrectness if these ops are selected cached.
  1135. torch.ops.prim.device.default,
  1136. } | set(torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns) # type: ignore[has-type]
  1137. class _CachingTorchDispatchMode(TorchDispatchMode):
  1138. @classmethod
  1139. def ignore_compile_internals(cls):
  1140. return True
  1141. # Used together with _CachedTorchDispatchMode to implement SAC.
  1142. def __init__(self, policy_fn, storage) -> None:
  1143. self.policy_fn = policy_fn
  1144. self.storage = storage
  1145. def __torch_dispatch__(self, func, types, args=(), kwargs=None):
  1146. if func in SAC_IGNORED_OPS:
  1147. return func(*args, **kwargs)
  1148. kwargs = {} if kwargs is None else kwargs
  1149. policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=False),
  1150. func, *args, **kwargs)
  1151. if isinstance(policy, bool):
  1152. policy = _policy_from_bool(policy)
  1153. is_compiling = _is_compiling(func, args, kwargs)
  1154. if is_compiling:
  1155. # Overwrite each node's "recompute" tag to add in the user annotation.
  1156. fx_traceback.current_meta["recompute"] = policy
  1157. out = func(*args, **kwargs)
  1158. # HOPs don't support func._schema
  1159. # HOPs don't alias -> this is always true today and will be always true for a long time
  1160. # TODO HOPs don't mutate -> this is always true today but will not be true forever
  1161. if isinstance(func, torch._ops.HigherOrderOperator):
  1162. any_ret_has_alias_info = False
  1163. else:
  1164. any_ret_has_alias_info = any(ret.alias_info is not None for ret in func._schema.returns)
  1165. if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling:
  1166. self.storage[func].append(tree_map(lambda x: _VersionWrapper(_maybe_detach(x, any_ret_has_alias_info)), out))
  1167. return out
  1168. class _CachedTorchDispatchMode(TorchDispatchMode):
  1169. @classmethod
  1170. def ignore_compile_internals(cls):
  1171. return True
  1172. # Used together with _CachedTorchDispatchMode to implement SAC.
  1173. def __init__(self, policy_fn, storage, allow_cache_entry_mutation) -> None:
  1174. self.policy_fn = policy_fn
  1175. self.storage = storage
  1176. self.allow_cache_entry_mutation = allow_cache_entry_mutation
  1177. def __torch_dispatch__(self, func, types, args=(), kwargs=None):
  1178. if func in SAC_IGNORED_OPS:
  1179. return func(*args, **kwargs)
  1180. kwargs = {} if kwargs is None else kwargs
  1181. policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=True),
  1182. func, *args, **kwargs)
  1183. if isinstance(policy, bool):
  1184. policy = _policy_from_bool(policy)
  1185. is_compiling = _is_compiling(func, args, kwargs)
  1186. if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling:
  1187. storage = self.storage.get(func)
  1188. if storage is None:
  1189. raise RuntimeError(f"{func} encountered during backward, but not found in storage")
  1190. if len(storage) == 0:
  1191. raise RuntimeError(
  1192. "Trying to backward an extra time. You are only allowed to backward once "
  1193. "on any region computed under selective activation checkpoint."
  1194. )
  1195. out = tree_map(lambda x: x.get_val(self.allow_cache_entry_mutation), storage.pop(0))
  1196. else:
  1197. out = func(*args, **kwargs)
  1198. return out
  1199. def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False):
  1200. """
  1201. Helper to avoid recomputing certain ops during activation checkpointing.
  1202. Use this with `torch.utils.checkpoint.checkpoint` to control which
  1203. operations are recomputed during the backward pass.
  1204. Args:
  1205. policy_fn_or_list (Callable or List):
  1206. - If a policy function is provided, it should accept a
  1207. :class:`SelectiveCheckpointContext`, the :class:`OpOverload`, args and
  1208. kwargs to the op, and return a :class:`CheckpointPolicy` enum value
  1209. indicating whether the execution of the op should be recomputed or not.
  1210. - If a list of operations is provided, it is equivalent to a policy
  1211. returning `CheckpointPolicy.MUST_SAVE` for the specified
  1212. operations and `CheckpointPolicy.PREFER_RECOMPUTE` for all other
  1213. operations.
  1214. allow_cache_entry_mutation (bool, optional): By default, an error is
  1215. raised if any tensors cached by selective activation checkpoint are
  1216. mutated in order to ensure correctness. If set to `True`, this check
  1217. is disabled.
  1218. Returns:
  1219. A tuple of two context managers.
  1220. Example:
  1221. >>> # xdoctest: +REQUIRES(LINUX)
  1222. >>> import functools
  1223. >>>
  1224. >>> x = torch.rand(10, 10, requires_grad=True)
  1225. >>> y = torch.rand(10, 10, requires_grad=True)
  1226. >>>
  1227. >>> ops_to_save = [
  1228. >>> torch.ops.aten.mm.default,
  1229. >>> ]
  1230. >>>
  1231. >>> def policy_fn(ctx, op, *args, **kwargs):
  1232. >>> if op in ops_to_save:
  1233. >>> return CheckpointPolicy.MUST_SAVE
  1234. >>> else:
  1235. >>> return CheckpointPolicy.PREFER_RECOMPUTE
  1236. >>>
  1237. >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
  1238. >>>
  1239. >>> # or equivalently
  1240. >>> context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save)
  1241. >>>
  1242. >>> def fn(x, y):
  1243. >>> return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
  1244. >>>
  1245. >>> out = torch.utils.checkpoint.checkpoint(
  1246. >>> fn, x, y,
  1247. >>> use_reentrant=False,
  1248. >>> context_fn=context_fn,
  1249. >>> )
  1250. """
  1251. # NB: If grad_mode is disabled, checkpoint would not run forward under
  1252. # context_fn anyway, so proceed as usual.
  1253. if isinstance(policy_fn_or_list, list):
  1254. for op in policy_fn_or_list:
  1255. if not isinstance(op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)):
  1256. _extra_msg = (
  1257. "Please update the OpOverloadPacket to a specific OpOverload."
  1258. "For example, if you have `torch.ops.aten.mm`, change it to `torch.ops.aten.mm.default`."
  1259. ) if isinstance(op, torch._ops.OpOverloadPacket) else ""
  1260. raise ValueError(
  1261. f"Expected op in `op_list` to be an OpOverload but got: {op} "
  1262. f"of type {type(op)}. {_extra_msg}"
  1263. )
  1264. def policy_fn(ctx, op, *args, **kwargs):
  1265. if op in policy_fn_or_list:
  1266. return CheckpointPolicy.MUST_SAVE
  1267. else:
  1268. return CheckpointPolicy.PREFER_RECOMPUTE
  1269. elif callable(policy_fn_or_list):
  1270. policy_fn = policy_fn_or_list
  1271. else:
  1272. raise TypeError("policy_fn_or_list must be either a function or a list of ops.")
  1273. storage: Dict[Any, List[Any]] = defaultdict(list)
  1274. return (
  1275. _CachingTorchDispatchMode(policy_fn, storage),
  1276. _CachedTorchDispatchMode(policy_fn, storage, allow_cache_entry_mutation),
  1277. )
# NB: this helper wraps fn before calling checkpoint_impl. kwargs and
# saving/restoring of global state are handled here.
def _checkpoint_without_reentrant_generator(
    fn,
    preserve_rng_state=True,
    context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
    determinism_check: str = _DEFAULT_DETERMINISM_MODE,
    debug: bool = False,
    early_stop: bool = True,
    *args,
    **kwargs
):
    """Checkpointing without reentrant autograd.

    This is a generator that yields exactly once.

    - While it is suspended at the ``yield``, the caller's code runs under the
      checkpoint hooks and the forward context returned by ``context_fn``.
    - If the dummy input saver has no ``grad_fn`` (ambient grad mode is off),
      it yields with neither of them installed and returns.
    - ``fn`` itself is only invoked here from ``recompute_fn``, i.e. during
      recomputation.
    - Resuming the generator marks the forward as completed and performs the
      device-initialization check below.

    Args:
        fn: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if user passes
            ``(activation, hidden)``, :attr:`fn` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional): If ``False``, omit stashing and
            restoring the RNG state during each checkpoint.
            Default: ``True``
        context_fn(Callable, optional): A callable returning a tuple of two
            context managers. The function and its recomputation will be run
            under the first and second context managers respectively.
        determinism_check(str, optional): A string specifying the determinism
            check to perform. By default it is set to ``"default"`` which
            compares the shapes, dtypes, and devices of the recomputed tensors
            against those of the saved tensors. To turn off this check, specify
            ``"none"``. Currently these are the only two supported values.
            Please open an issue if you would like to see more determinism
            checks.
        debug(bool, optional): If ``True``, error messages will also include
            a trace of the operators ran during the original forward computation
            as well as the recomputation. A non-``None`` module-level
            ``_checkpoint_debug_enabled`` takes precedence over this argument.
            Incompatible with a non-default ``context_fn``.
        early_stop(bool, optional): If ``True``, non-reentrant checkpoint stops
            recomputation as soon as it has computed all needed Tensors. Can be
            overridden globally using :func:`set_checkpoint_early_stop` context
            manager. Default: ``True``.
        *args: Arguments to pass in to the given ``fn``.
        **kwargs: Keyword arguments to pass into the given ``fn``.

    Raises:
        ValueError: if debug mode is enabled together with a non-default
            ``context_fn``, or if ``determinism_check`` is not a supported value.
        AssertionError: if compiling with a non-default ``context_fn`` whose
            two context managers are not both ``TorchDispatchMode`` instances.
        RuntimeError: on resumption, if ``preserve_rng_state`` is set and the
            device module became initialized during the forward although it
            was not initialized beforehand.
    """
    unpack_error_cb = None

    # A non-None global debug setting takes precedence over the `debug` argument.
    if _checkpoint_debug_enabled if _checkpoint_debug_enabled is not None else debug:
        if context_fn is not noop_context_fn:
            raise ValueError(
                "debug=True is incompatible with non-default context_fn"
            )
        # Debug mode supplies its own context managers plus an error callback
        # that is handed to the checkpoint frame below.
        context_fn, unpack_error_cb = _get_debug_context_and_cb()

    if determinism_check in _allowed_determinism_checks_to_fns:
        metadata_fn = _allowed_determinism_checks_to_fns[determinism_check]
    else:
        raise ValueError(
            f"determinism_check should be one of {list(_allowed_determinism_checks_to_fns.keys())}, "
            f"but got {determinism_check}"
        )

    device_type = _infer_device_type(*args)
    device_module = _get_device_module(device_type)
    forward_context, recompute_context = context_fn()
    # Under torch.compile, a custom context_fn must produce dispatch modes.
    if _is_compiling(fn, args, kwargs) and context_fn is not noop_context_fn:
        if (
            not isinstance(forward_context, TorchDispatchMode)
            or not isinstance(recompute_context, TorchDispatchMode)
        ):
            raise AssertionError(
                "In torch.compile mode, `context_fn` arg passed to `torch.utils.checkpoint` "
                "must generate a tuple of two `TorchDispatchMode`s."
            )
    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
    device_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs(device_type=device_type)

    if preserve_rng_state:
        fwd_cpu_state = torch.get_rng_state()
        # Don't eagerly initialize the cuda context by accident.
        # (If the user intends that the context is initialized later, within their
        # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
        # we have no way to anticipate this will happen before we run the function.
        # If they do so, we raise an error.)
        # NB: had_device_in_fwd / fwd_devices are only bound on this branch;
        # every later use is guarded by preserve_rng_state (and had_device_in_fwd).
        had_device_in_fwd = False
        if getattr(device_module, "_initialized", False):
            had_device_in_fwd = True
            fwd_devices, fwd_device_states = get_device_states(*args)

    from torch.overrides import _get_current_function_mode_stack
    from torch.utils._device import DeviceContext

    # recompute_fn should respect the device context of the original forward.
    # Picks the last DeviceContext on the current function-mode stack
    # (searched from the end), or a no-op context if there is none.
    device_ctx = next(
        filter(
            lambda mode: isinstance(mode, DeviceContext),
            reversed(_get_current_function_mode_stack()),
        ),
        contextlib.nullcontext(),
    )

    def recompute_fn(*inputs) -> None:
        # inputs is unpacked as (kwargs, *args).
        kwargs, *args = inputs
        # This will be called later during recomputation. This wrapping enables
        # the necessary global state to be captured.
        rng_devices = []
        if preserve_rng_state and had_device_in_fwd:
            rng_devices = fwd_devices
        # fork_rng scopes the RNG restore below so it does not leak out of
        # the recomputation.
        with torch.random.fork_rng(
            devices=rng_devices, enabled=preserve_rng_state, device_type=device_type
        ):
            if preserve_rng_state:
                torch.set_rng_state(fwd_cpu_state)
                if had_device_in_fwd:
                    set_device_states(fwd_devices, fwd_device_states, device_type=device_type)

            # Re-enter the autocast state captured at forward time.
            device_autocast_ctx = torch.amp.autocast(
                device_type=device_type, **device_autocast_kwargs
            ) if torch.amp.is_autocast_available(device_type) else contextlib.nullcontext()
            with device_autocast_ctx, torch.amp.autocast("cpu", **cpu_autocast_kwargs), recompute_context, device_ctx:  # type: ignore[attr-defined]
                fn(*args, **kwargs)

    # A non-None global _enable_checkpoint_early_stop overrides `early_stop`.
    new_frame = _CheckpointFrame(
        recompute_fn,
        _enable_checkpoint_early_stop if _enable_checkpoint_early_stop is not None else early_stop,
        unpack_error_cb,
        metadata_fn
    )
    dummy = torch.empty((0,), requires_grad=True)
    new_frame.input_saver = _NoopSaveInputs.apply(dummy, kwargs, *args)

    # When ambient grad_mode is False
    # (no grad_fn): yield without installing hooks or the forward context.
    if new_frame.input_saver.grad_fn is None:
        yield
        return

    # The caller's forward runs while suspended inside this `with`.
    with _checkpoint_hook(new_frame), forward_context:
        yield
    new_frame.forward_completed = True

    if getattr(device_module, "_initialized", False) and \
       preserve_rng_state and not had_device_in_fwd:  # type: ignore[possibly-undefined]
        # Device was not initialized before running the forward, so we didn't
        # stash the device state.
        raise RuntimeError(
            "PyTorch's device state was initialized in the forward pass "
            "of a Checkpoint, which is not allowed. Please open an issue "
            "if you need this feature."
        )

    return
  1413. class GraphExecGroup:
  1414. """Any checkpointed regions encountered by backward under the same instance
  1415. of this context manager will trigger recompute at most once, even if
  1416. there are multiple calls to backward.
  1417. Backward calls under the same instance of this context manager must execute
  1418. over non-overlapping regions of the backward graph even if retain_graph=True.
  1419. In particular, any two backward call cannot use the same saved activation for
  1420. gradient computation.
  1421. .. note::
  1422. This context manager only affects checkpoint with use_reentrant=False, and
  1423. is a no-op otherwise.
  1424. """
  1425. def __enter__(self) -> "GraphExecGroup":
  1426. if torch._C._get_graph_exec_group() is not None:
  1427. raise RuntimeError(
  1428. "GraphExecGroup contexts cannot be nested. "
  1429. f"Already inside group {torch._C._get_graph_exec_group()}"
  1430. )
  1431. torch._C._set_graph_exec_group(self)
  1432. return self
  1433. def __exit__(self, *args: object) -> None:
  1434. torch._C._set_graph_exec_group(None)
  1435. @classmethod
  1436. def _get_current_group(cls) -> Optional["GraphExecGroup"]:
  1437. # Private API to be used by utils like AC
  1438. return torch._C._get_graph_exec_group()
# Note: [compiled autograd and checkpoint unpack hook]
# When tracing via compiled autograd, this hook will be visible to the
# compiler if the forward of this checkpointed region ran in eager.
# If the forward had run under compile, it would have been wrapped in a
# higher order op. See Note: [torch.compile and checkpoint].
#
# Since we run the recomputation hook under an enable_grad context,
# AOTDispatch will trace a joint graph for this hook, and may
# save different activations than in eager. This conflicts with the
# strict activation count checks in `frame.check_recomputed_tensors_match`.
# So, we disable this hook to force it to recompute eager checkpointed regions
# in eager. This could be removed if we can disable the partitioner for this
# graph segment.