__init__.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635
  1. # mypy: allow-untyped-defs
  2. """
  3. ``torch.autograd`` provides classes and functions implementing automatic differentiation of arbitrary scalar valued functions.
  4. It requires minimal changes to the existing code - you only need to declare :class:`Tensor` s
  5. for which gradients should be computed with the ``requires_grad=True`` keyword.
  6. As of now, we only support autograd for floating point :class:`Tensor` types (
  7. half, float, double and bfloat16) and complex :class:`Tensor` types (cfloat, cdouble).
  8. """
  9. import warnings
  10. from collections.abc import Sequence
  11. from typing import cast, Optional, Union
  12. import torch
  13. from torch import _vmap_internals
  14. from torch.overrides import handle_torch_function, has_torch_function, is_tensor_like
  15. from torch.types import _size, _TensorOrOptionalTensors, _TensorOrTensorsOrGradEdge
  16. from . import forward_ad, functional, graph
  17. from .anomaly_mode import detect_anomaly, set_detect_anomaly
  18. from .function import Function, NestedIOFunction
  19. from .grad_mode import (
  20. _force_original_view_tracking,
  21. _unsafe_preserve_version_counter,
  22. enable_grad,
  23. inference_mode,
  24. no_grad,
  25. set_grad_enabled,
  26. set_multithreading_enabled,
  27. )
  28. from .gradcheck import gradcheck, gradgradcheck
  29. from .graph import _engine_run_backward
  30. from .variable import Variable
# Public API of ``torch.autograd``: the names exported by
# ``from torch.autograd import *``. "grad_mode" is the submodule bound on this
# package by the ``from .grad_mode import ...`` statement above; "variable" is
# the deprecated ``variable`` function defined later in this module.
__all__ = [
    "Variable",
    "Function",
    "backward",
    "grad_mode",
    "NestedIOFunction",
    "detect_anomaly",
    "enable_grad",
    "grad",
    "gradcheck",
    "gradgradcheck",
    "inference_mode",
    "no_grad",
    "set_detect_anomaly",
    "set_grad_enabled",
    "set_multithreading_enabled",
    "variable",
]
# A tensor slot that may be empty, e.g. a per-output gradient where ``None``
# means "no gradient supplied".
_OptionalTensor = Optional[torch.Tensor]
# Shape description used for shape-mismatch error messages: a regular size, a
# sequence of sizes, or a Tensor. The Tensor member presumably covers
# ``_nested_tensor_size()`` results for C++ nested tensors -- TODO confirm.
_ShapeorNestedShape = Union[_size, Sequence[_size], torch.Tensor]
  51. def _calculate_shape(
  52. output: Union[torch.Tensor, graph.GradientEdge],
  53. grad: torch.Tensor,
  54. is_grads_batched: bool,
  55. ) -> tuple[_ShapeorNestedShape, _ShapeorNestedShape]:
  56. # is_same_size ensures that both tensors are either nested or non nested
  57. # circular import
  58. from torch.nested._internal.nested_tensor import NestedTensor
  59. if isinstance(output, graph.GradientEdge):
  60. # We have already checked that we are not a C++ NestedTensor
  61. if is_grads_batched:
  62. raise RuntimeError("Batched grads are not supported with GradientEdge")
  63. out_metadata = output.node._input_metadata[output.output_nr]
  64. return torch.Size(out_metadata.shape), grad.shape
  65. if output.is_nested and not isinstance(output, NestedTensor):
  66. if is_grads_batched:
  67. raise RuntimeError("Batched grads are not supported with Nested Tensor.")
  68. out_shape = output._nested_tensor_size()
  69. grad_shape = grad._nested_tensor_size()
  70. return out_shape, grad_shape
  71. reg_out_shape = output.shape
  72. reg_grad_shape = grad.shape if not is_grads_batched else grad.shape[1:]
  73. return reg_out_shape, reg_grad_shape
  74. def _make_grads(
  75. outputs: Union[Sequence[torch.Tensor], Sequence[graph.GradientEdge]],
  76. grads: Sequence[_OptionalTensor],
  77. is_grads_batched: bool,
  78. ) -> tuple[_OptionalTensor, ...]:
  79. new_grads: list[_OptionalTensor] = []
  80. for out, grad in zip(outputs, grads):
  81. # pyrefly: ignore [redundant-cast]
  82. out = cast(Union[torch.Tensor, graph.GradientEdge], out)
  83. out_size = None
  84. out_device = None
  85. if isinstance(out, graph.GradientEdge):
  86. out_metadata = out.node._input_metadata[out.output_nr]
  87. out_size = torch.Size(out_metadata.shape)
  88. out_dtype = out_metadata.dtype
  89. out_device = out_metadata.device
  90. out_is_nested = out_metadata.is_nested_tensor
  91. if out_metadata.is_cpp_nested_tensor:
  92. raise RuntimeError(
  93. "C++ NestedTensor are not supported with GradientEdge"
  94. )
  95. out_is_cpp_nested = False
  96. else:
  97. # circular import
  98. from torch.nested._internal.nested_tensor import NestedTensor
  99. if not isinstance(out, torch.Tensor):
  100. raise AssertionError("Expected output to be a torch.Tensor")
  101. out_dtype = out.dtype
  102. out_is_nested = out.is_nested
  103. out_is_cpp_nested = out_is_nested and not isinstance(out, NestedTensor)
  104. if not out_is_cpp_nested:
  105. out_size = out.shape
  106. if isinstance(grad, torch.Tensor):
  107. from torch.fx.experimental.symbolic_shapes import expect_true, sym_eq
  108. first_grad = grad if not is_grads_batched else grad[0]
  109. # TODO: We can remove this conditional once we uniformly use
  110. # singleton int to represent jagged dimension, so that size() call
  111. # on nested tensor works.
  112. if out_is_cpp_nested:
  113. if not isinstance(out, torch.Tensor):
  114. raise AssertionError("Expected output to be a torch.Tensor.")
  115. shape_matches = torch.is_same_size(out, first_grad)
  116. else:
  117. # We need to do a regular size check, without going through
  118. # the operator, to be able to handle unbacked symints
  119. # (expect_true ensures we can deal with unbacked)
  120. if out_size is None:
  121. raise AssertionError("Expected out_size to be set.")
  122. shape_matches = expect_true(sym_eq(out_size, first_grad.size()))
  123. if not shape_matches:
  124. out = cast(Union[torch.Tensor, graph.GradientEdge], out) # type: ignore[redundant-cast]
  125. out_shape, grad_shape = _calculate_shape(
  126. out, first_grad, is_grads_batched
  127. )
  128. if is_grads_batched:
  129. raise RuntimeError(
  130. "If `is_grads_batched=True`, we interpret the first "
  131. "dimension of each grad_output as the batch dimension. "
  132. "The sizes of the remaining dimensions are expected to match "
  133. "the shape of corresponding output, but a mismatch "
  134. "was detected: grad_output["
  135. + str(grads.index(grad))
  136. + "] has a shape of "
  137. + str(grad_shape)
  138. + " and output["
  139. + str(outputs.index(out))
  140. + "] has a shape of "
  141. + str(out_shape)
  142. + ". "
  143. "If you only want some tensors in `grad_output` to be considered "
  144. "batched, consider using vmap."
  145. )
  146. else:
  147. raise RuntimeError(
  148. "Mismatch in shape: grad_output["
  149. + str(grads.index(grad))
  150. + "] has a shape of "
  151. + str(grad_shape)
  152. + " and output["
  153. + str(outputs.index(out))
  154. + "] has a shape of "
  155. + str(out_shape)
  156. + "."
  157. )
  158. if out_dtype.is_complex != grad.dtype.is_complex:
  159. raise RuntimeError(
  160. "For complex Tensors, both grad_output and output"
  161. " are required to have the same dtype."
  162. " Mismatch in dtype: grad_output["
  163. + str(grads.index(grad))
  164. + "] has a dtype of "
  165. + str(grad.dtype)
  166. + " and output["
  167. + str(outputs.index(out))
  168. + "] has a dtype of "
  169. + str(out_dtype)
  170. + "."
  171. )
  172. new_grads.append(grad)
  173. elif grad is None:
  174. if isinstance(out, graph.GradientEdge) or out.requires_grad: # type: ignore[attr-defined]
  175. if isinstance(out, graph.GradientEdge):
  176. if out_size is None:
  177. raise AssertionError("Expected out_size to be set.")
  178. out_numel_is_1 = all(o == 1 for o in out_size)
  179. else:
  180. if not isinstance(out, torch.Tensor):
  181. raise AssertionError("Expected output to be a torch.Tensor")
  182. out_numel_is_1 = out.numel() == 1
  183. if not out_numel_is_1:
  184. raise RuntimeError(
  185. "grad can be implicitly created only for scalar outputs"
  186. )
  187. if not out_dtype.is_floating_point:
  188. msg = (
  189. "grad can be implicitly created only for real scalar outputs"
  190. f" but got {out_dtype}"
  191. )
  192. raise RuntimeError(msg)
  193. if isinstance(out, graph.GradientEdge):
  194. if out_size is None:
  195. raise AssertionError("Expected out_size to be set.")
  196. if out_device is None:
  197. raise AssertionError("Expected out_device to be set.")
  198. new_grads.append(
  199. torch.ones(
  200. out_size,
  201. dtype=out_dtype,
  202. device=out_device,
  203. )
  204. )
  205. else:
  206. if not isinstance(out, torch.Tensor):
  207. raise AssertionError("Expected output to be a torch.Tensor")
  208. new_grads.append(
  209. torch.ones_like(out, memory_format=torch.preserve_format)
  210. )
  211. else:
  212. new_grads.append(None)
  213. else:
  214. raise TypeError(
  215. "gradients can be either Tensors or None, but got "
  216. + type(grad).__name__
  217. )
  218. return tuple(new_grads)
  219. def _tensor_or_tensors_to_tuple(
  220. tensors: Optional[_TensorOrOptionalTensors], length: int
  221. ) -> tuple[_OptionalTensor, ...]:
  222. if tensors is None:
  223. return (None,) * length
  224. if isinstance(tensors, torch.Tensor):
  225. return (tensors,)
  226. return tuple(tensors)
  227. def backward(
  228. tensors: _TensorOrTensorsOrGradEdge,
  229. grad_tensors: Optional[_TensorOrOptionalTensors] = None,
  230. retain_graph: Optional[bool] = None,
  231. create_graph: bool = False,
  232. grad_variables: Optional[_TensorOrOptionalTensors] = None,
  233. inputs: Optional[_TensorOrTensorsOrGradEdge] = None,
  234. ) -> None:
  235. r"""Compute the sum of gradients of given tensors with respect to graph leaves.
  236. The graph is differentiated using the chain rule. If any of ``tensors``
  237. are non-scalar (i.e. their data has more than one element) and require
  238. gradient, then the Jacobian-vector product would be computed, in this
  239. case the function additionally requires specifying ``grad_tensors``.
  240. It should be a sequence of matching length, that contains the "vector"
  241. in the Jacobian-vector product, usually the gradient of the differentiated
  242. function w.r.t. corresponding tensors (``None`` is an acceptable value for
  243. all tensors that don't need gradient tensors).
  244. This function accumulates gradients in the leaves - you might need to zero
  245. ``.grad`` attributes or set them to ``None`` before calling it.
  246. See :ref:`Default gradient layouts<default-grad-layouts>`
  247. for details on the memory layout of accumulated gradients.
  248. .. note::
  249. Using this method with ``create_graph=True`` will create a reference cycle
  250. between the parameter and its gradient which can cause a memory leak.
  251. We recommend using ``autograd.grad`` when creating the graph to avoid this.
  252. If you have to use this function, make sure to reset the ``.grad`` fields of your
  253. parameters to ``None`` after use to break the cycle and avoid the leak.
  254. .. note::
  255. If you run any forward ops, create ``grad_tensors``, and/or call ``backward``
  256. in a user-specified CUDA stream context, see
  257. :ref:`Stream semantics of backward passes<bwd-cuda-stream-semantics>`.
  258. .. note::
  259. When ``inputs`` are provided and a given input is not a leaf,
  260. the current implementation will call its grad_fn (even though it is not strictly needed to get this gradients).
  261. It is an implementation detail on which the user should not rely.
  262. See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.
  263. Args:
  264. tensors (Sequence[Tensor] or Tensor or Sequence[GradientEdge] or GradientEdge): Tensors of which
  265. the derivative will be computed.
  266. grad_tensors (Sequence[Tensor or None] or Tensor, optional): The "vector" in
  267. the Jacobian-vector product, usually gradients w.r.t. each element of
  268. corresponding tensors. None values can be specified for scalar Tensors or
  269. ones that don't require grad. If a None value would be acceptable for all
  270. grad_tensors, then this argument is optional.
  271. retain_graph (bool, optional): If ``False``, the graph used to compute the grad
  272. will be freed. Note that in nearly all cases setting this option to ``True``
  273. is not needed and often can be worked around in a much more efficient
  274. way. Defaults to the value of ``create_graph``.
  275. create_graph (bool, optional): If ``True``, graph of the derivative will
  276. be constructed, allowing to compute higher order derivative products.
  277. Defaults to ``False``.
  278. inputs (Sequence[Tensor] or Tensor or Sequence[GradientEdge], optional): Inputs w.r.t. which the gradient
  279. be will accumulated into ``.grad``. All other Tensors will be ignored. If
  280. not provided, the gradient is accumulated into all the leaf Tensors that
  281. were used to compute the :attr:`tensors`.
  282. """
  283. if torch._C._are_functorch_transforms_active():
  284. raise RuntimeError(
  285. "backward() called inside a functorch transform. This is not "
  286. "supported, please use functorch.grad or functorch.vjp instead "
  287. "or call backward() outside of functorch transforms."
  288. )
  289. if grad_variables is not None:
  290. warnings.warn(
  291. "`grad_variables` is deprecated. Use `grad_tensors` instead.",
  292. FutureWarning,
  293. stacklevel=2,
  294. )
  295. if grad_tensors is None:
  296. grad_tensors = grad_variables
  297. else:
  298. raise RuntimeError(
  299. "`grad_tensors` and `grad_variables` (deprecated) "
  300. "arguments both passed to `backward()`. Please only "
  301. "use `grad_tensors`."
  302. )
  303. inputs_tuple: tuple[Union[torch.Tensor, graph.GradientEdge], ...]
  304. if inputs is None:
  305. inputs_tuple = ()
  306. elif isinstance(inputs, (torch.Tensor, graph.GradientEdge)):
  307. inputs_tuple = (inputs,)
  308. else:
  309. inputs_tuple = tuple(inputs)
  310. if len(inputs_tuple) == 0:
  311. raise RuntimeError("`inputs` argument to `backward()` cannot be empty.")
  312. if is_tensor_like(tensors) or isinstance(tensors, graph.GradientEdge):
  313. tensors = cast(
  314. Union[tuple[torch.Tensor], tuple[graph.GradientEdge]], (tensors,)
  315. )
  316. else:
  317. # pyrefly: ignore [bad-argument-type]
  318. tensors = tuple(tensors)
  319. # Check for __torch_function__ on tensors (similar to torch.autograd.grad)
  320. # This allows tensor subclasses to customize backward behavior
  321. t_tensors = tuple(t for t in tensors if is_tensor_like(t))
  322. t_inputs = tuple(t for t in inputs_tuple if is_tensor_like(t))
  323. overridable_args = t_tensors + t_inputs
  324. if has_torch_function(overridable_args):
  325. return handle_torch_function(
  326. backward,
  327. overridable_args,
  328. tensors,
  329. grad_tensors=grad_tensors,
  330. retain_graph=retain_graph,
  331. create_graph=create_graph,
  332. inputs=inputs,
  333. )
  334. grad_tensors_ = _tensor_or_tensors_to_tuple(grad_tensors, len(tensors))
  335. grad_tensors_ = _make_grads(tensors, grad_tensors_, is_grads_batched=False)
  336. if retain_graph is None:
  337. retain_graph = create_graph
  338. # The reason we repeat the same comment below is that
  339. # some Python versions print out the first line of a multi-line function
  340. # calls in the traceback and some print out the last line
  341. _engine_run_backward(
  342. tensors,
  343. grad_tensors_,
  344. retain_graph,
  345. create_graph,
  346. inputs_tuple,
  347. allow_unreachable=True,
  348. accumulate_grad=True,
  349. )
  350. def grad(
  351. outputs: _TensorOrTensorsOrGradEdge,
  352. inputs: _TensorOrTensorsOrGradEdge,
  353. grad_outputs: Optional[_TensorOrOptionalTensors] = None,
  354. retain_graph: Optional[bool] = None,
  355. create_graph: bool = False,
  356. only_inputs: bool = True,
  357. allow_unused: Optional[bool] = None,
  358. is_grads_batched: bool = False,
  359. materialize_grads: bool = False,
  360. ) -> tuple[torch.Tensor, ...]:
  361. r"""Compute and return the sum of gradients of outputs with respect to the inputs.
  362. ``grad_outputs`` should be a sequence of length matching ``output``
  363. containing the "vector" in vector-Jacobian product, usually the pre-computed
  364. gradients w.r.t. each of the outputs. If an output doesn't require_grad,
  365. then the gradient can be ``None``).
  366. .. note::
  367. If you run any forward ops, create ``grad_outputs``, and/or call ``grad``
  368. in a user-specified CUDA stream context, see
  369. :ref:`Stream semantics of backward passes<bwd-cuda-stream-semantics>`.
  370. .. note::
  371. ``only_inputs`` argument is deprecated and is ignored now (defaults to ``True``).
  372. To accumulate gradient for other parts of the graph, please use
  373. ``torch.autograd.backward``.
  374. Args:
  375. outputs (sequence of Tensor or GradientEdge): outputs of the differentiated function.
  376. inputs (sequence of Tensor or GradientEdge): Inputs w.r.t. which the gradient will be
  377. returned (and not accumulated into ``.grad``).
  378. grad_outputs (sequence of [Tensor or None] or Tensor, optional): The "vector" in the
  379. vector-Jacobian product. Usually gradients w.r.t. each output. None values can be
  380. specified for scalar Tensors or ones that don't require grad. If a None value would be
  381. acceptable for all grad_tensors, then this argument is optional. Default: None.
  382. retain_graph (bool, optional): If ``False``, the graph used to compute the grad
  383. will be freed. Note that in nearly all cases setting this option to ``True``
  384. is not needed and often can be worked around in a much more efficient
  385. way. Defaults to the value of ``create_graph``.
  386. create_graph (bool, optional): If ``True``, graph of the derivative will
  387. be constructed, allowing to compute higher order derivative products.
  388. Default: ``False``.
  389. allow_unused (Optional[bool], optional): If ``False``, specifying inputs
  390. that were not used when computing outputs (and therefore their grad is
  391. always zero) is an error. Defaults to the value of ``materialize_grads``.
  392. is_grads_batched (bool, optional): If ``True``, the first dimension of each
  393. tensor in ``grad_outputs`` will be interpreted as the batch dimension.
  394. Instead of computing a single vector-Jacobian product, we compute a
  395. batch of vector-Jacobian products for each "vector" in the batch.
  396. We use the vmap prototype feature as the backend to vectorize calls
  397. to the autograd engine so that this computation can be performed in a
  398. single call. This should lead to performance improvements when compared
  399. to manually looping and performing backward multiple times. Note that
  400. due to this feature being experimental, there may be performance
  401. cliffs. Please use ``torch._C._debug_only_display_vmap_fallback_warnings(True)``
  402. to show any performance warnings and file an issue on github if warnings exist
  403. for your use case. Defaults to ``False``.
  404. materialize_grads (bool, optional): If ``True``, set the gradient for unused inputs
  405. to zero instead of None. This is useful when computing higher-order derivatives.
  406. If ``materialize_grads`` is ``True`` and ``allow_unused`` is ``False``, an error
  407. will be raised. Defaults to ``False``.
  408. """
  409. if materialize_grads and allow_unused is False:
  410. raise ValueError(
  411. "Expected allow_unused to be True or not passed when materialize_grads=True, "
  412. "but got: allow_unused=False."
  413. )
  414. if allow_unused is None:
  415. allow_unused = materialize_grads
  416. if is_tensor_like(outputs) or isinstance(outputs, graph.GradientEdge):
  417. outputs = cast(
  418. Union[Sequence[torch.Tensor], Sequence[graph.GradientEdge]], (outputs,)
  419. )
  420. else:
  421. # pyrefly: ignore [bad-argument-type]
  422. outputs = tuple(outputs)
  423. if is_tensor_like(inputs) or isinstance(inputs, graph.GradientEdge):
  424. inputs = cast(_TensorOrTensorsOrGradEdge, (inputs,))
  425. else:
  426. # pyrefly: ignore [bad-argument-type]
  427. inputs = tuple(inputs)
  428. t_outputs = tuple(i for i in outputs if is_tensor_like(i))
  429. t_inputs = tuple(i for i in inputs if is_tensor_like(i))
  430. overridable_args = t_outputs + t_inputs
  431. if has_torch_function(overridable_args):
  432. return handle_torch_function(
  433. grad,
  434. overridable_args,
  435. outputs,
  436. inputs,
  437. grad_outputs=grad_outputs,
  438. retain_graph=retain_graph,
  439. create_graph=create_graph,
  440. only_inputs=only_inputs,
  441. allow_unused=allow_unused,
  442. is_grads_batched=is_grads_batched,
  443. materialize_grads=materialize_grads,
  444. )
  445. if not only_inputs:
  446. warnings.warn(
  447. "only_inputs argument is deprecated and is ignored now "
  448. "(defaults to True). To accumulate gradient for other "
  449. "parts of the graph, please use torch.autograd.backward.",
  450. FutureWarning,
  451. stacklevel=2,
  452. )
  453. grad_outputs_ = _tensor_or_tensors_to_tuple(grad_outputs, len(outputs))
  454. grad_outputs_ = _make_grads(
  455. outputs, grad_outputs_, is_grads_batched=is_grads_batched
  456. )
  457. if retain_graph is None:
  458. retain_graph = create_graph
  459. # The reason we repeat the same comment several times below is because
  460. # some Python versions print out the first line of multi-line function
  461. # calls in the traceback and some print out the last line
  462. if is_grads_batched:
  463. def vjp(gO):
  464. return _engine_run_backward(
  465. outputs,
  466. gO,
  467. retain_graph,
  468. create_graph,
  469. inputs,
  470. allow_unused,
  471. accumulate_grad=False,
  472. )
  473. result = _vmap_internals._vmap(vjp, 0, 0, allow_none_pass_through=True)(
  474. grad_outputs_
  475. )
  476. else:
  477. result = _engine_run_backward(
  478. outputs,
  479. grad_outputs_,
  480. retain_graph,
  481. create_graph,
  482. inputs,
  483. allow_unused,
  484. accumulate_grad=False,
  485. )
  486. if materialize_grads:
  487. if any(
  488. result[i] is None and not is_tensor_like(inputs[i])
  489. for i in range(len(inputs))
  490. ):
  491. raise RuntimeError(
  492. "materialize_grads cannot be used when the given input is a GradientEdge"
  493. )
  494. result = tuple(
  495. output
  496. if output is not None
  497. else torch.zeros_like(input, requires_grad=create_graph)
  498. for (output, input) in zip(result, inputs)
  499. )
  500. return result
  501. # This function applies in case of gradient checkpointing for memory
  502. # optimization. Currently, gradient checkpointing is supported only if the
  503. # execution engine is invoked through torch.autograd.backward() and its
  504. # inputs argument is not passed. It is not supported for torch.autograd.grad().
  505. # This is because if inputs are specified, the gradient won't be calculated for
  506. # anything else e.g. model parameters like weights, bias etc.
  507. #
  508. # This function returns whether the checkpointing is valid i.e. torch.autograd.backward
  509. # or not i.e. torch.autograd.grad. The implementation works by maintaining a thread
  510. # local variable in torch/csrc/autograd/engine.cpp which looks at the NodeTask
  511. # in the stack and before a NodeTask is executed in evaluate_function, it
  512. # checks for whether reentrant backwards is imperative or not.
  513. # See https://github.com/pytorch/pytorch/pull/4594 for more discussion/context
  514. def _is_checkpoint_valid():
  515. return Variable._execution_engine.is_checkpoint_valid()
  516. def variable(*args, **kwargs): # noqa: D103
  517. raise RuntimeError(
  518. "torch.autograd.variable(...) is deprecated, use torch.tensor(...) instead"
  519. )
  520. # Monkey patching variable.Variable to fix FX codegen. FX generates a call by roughly doing
  521. # f"{fn.__module__}.{fn.__name__}(...). This yields torch.autograd.variable.Variable(...) in the
  522. # output of an FX graph. Unfortunately the module name torch.autograd.variable is shadowed by the
  523. # deprecated function - variable(...).
  524. variable.Variable = Variable # type: ignore[attr-defined]
# Initialize the C++ autograd bindings at import time. A falsy return from
# ``_autograd_init()`` makes importing this module fail.
if not torch._C._autograd_init():
    raise RuntimeError("autograd initialization failed")
  527. # Import all native method/classes
  528. from torch._C._autograd import (
  529. _add_metadata_json,
  530. _disable_profiler,
  531. _disable_profiler_legacy,
  532. _enable_profiler,
  533. _enable_profiler_legacy,
  534. _enable_record_function,
  535. _get_sequence_nr,
  536. _kineto_step,
  537. _KinetoEvent,
  538. _pop_saved_tensors_default_hooks,
  539. _prepare_profiler,
  540. _profiler_enabled,
  541. _ProfilerResult,
  542. _push_saved_tensors_default_hooks,
  543. _record_function_with_args_enter,
  544. _record_function_with_args_exit,
  545. _set_empty_test_observer,
  546. _supported_activities,
  547. _toggle_collection_dynamic,
  548. DeviceType,
  549. kineto_available,
  550. ProfilerEvent,
  551. SavedTensor,
  552. )
  553. from torch._C._profiler import ProfilerActivity, ProfilerConfig, ProfilerState
  554. from . import profiler
  555. def _register_py_tensor_class_for_device(device, cls):
  556. if not isinstance(cls, type):
  557. raise RuntimeError("cls isn't a typeinfo object")
  558. torch._C._register_py_class_for_device(device, cls)
# Expose the C++ query under a public name and attach its docstring via
# ``_add_docstr``. These calls run once, at module import.
is_multithreading_enabled = torch._C._is_multithreading_enabled
torch._C._add_docstr(
    is_multithreading_enabled, "Returns True if multithreading is currently enabled."
)
# Same pattern for the view-replay query.
is_view_replay_enabled = torch._C._is_view_replay_enabled
torch._C._add_docstr(
    is_view_replay_enabled, "Returns True if view-replay is currently enabled."
)