grad_mode.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426
  1. # mypy: allow-untyped-defs
  2. from typing import Any, Union
  3. import torch
  4. from torch.utils._contextlib import (
  5. _DecoratorContextManager,
  6. _NoParamDecoratorContextManager,
  7. F,
  8. )
# Names exported by ``from <this module> import *``. The underscore-prefixed
# classes and helper functions defined in this file are not listed here.
__all__ = [
    "no_grad",
    "enable_grad",
    "set_grad_enabled",
    "inference_mode",
    "set_multithreading_enabled",
]
class no_grad(_NoParamDecoratorContextManager):
    r"""Context-manager that disables gradient calculation.

    Disabling gradient calculation is useful for inference, when you are sure
    that you will not call :meth:`Tensor.backward()`. It will reduce memory
    consumption for computations that would otherwise have `requires_grad=True`.

    In this mode, the result of every computation will have
    `requires_grad=False`, even when the inputs have `requires_grad=True`.
    There is an exception! All factory functions, or functions that create
    a new Tensor and take a requires_grad kwarg, will NOT be affected by
    this mode.

    This context manager is thread local; it will not affect computation
    in other threads.

    Also functions as a decorator.

    .. note::
        No-grad is one of several mechanisms that can enable or
        disable gradients locally; see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
        If you want to disable forward AD for a computation, you can unpack
        your dual tensors.

    Example::
        >>> # xdoctest: +SKIP
        >>> x = torch.tensor([1.], requires_grad=True)
        >>> with torch.no_grad():
        ...     y = x * 2
        >>> y.requires_grad
        False
        >>> @torch.no_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> z = doubler(x)
        >>> z.requires_grad
        False
        >>> @torch.no_grad()
        ... def tripler(x):
        ...     return x * 3
        >>> z = tripler(x)
        >>> z.requires_grad
        False
        >>> # factory function exception
        >>> with torch.no_grad():
        ...     a = torch.nn.Parameter(torch.rand(10))
        >>> a.requires_grad
        True
    """

    def __init__(self) -> None:
        # The base-class initializer is skipped whenever is_scripting()
        # reports True.
        # NOTE(review): presumably this class is also compiled by TorchScript,
        # which would make the exact form of this guard significant -- confirm
        # before restructuring it.
        if not torch._jit_internal.is_scripting():
            super().__init__()
        # Placeholder value; __enter__ overwrites it with the actual grad mode.
        self.prev = False

    def __enter__(self) -> None:
        # Record the current mode so __exit__ can restore it, then switch
        # gradient calculation off.
        self.prev = torch.is_grad_enabled()
        torch.set_grad_enabled(False)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # Put back the mode recorded by __enter__. The exception arguments are
        # not inspected and None is returned, so exceptions are not suppressed.
        torch.set_grad_enabled(self.prev)
  71. class enable_grad(_NoParamDecoratorContextManager):
  72. r"""Context-manager that enables gradient calculation.
  73. Enables gradient calculation, if it has been disabled via :class:`~no_grad`
  74. or :class:`~set_grad_enabled`.
  75. This context manager is thread local; it will not affect computation
  76. in other threads.
  77. Also functions as a decorator.
  78. .. note::
  79. enable_grad is one of several mechanisms that can enable or
  80. disable gradients locally see :ref:`locally-disable-grad-doc` for
  81. more information on how they compare.
  82. .. note::
  83. This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
  84. Example::
  85. >>> # xdoctest: +SKIP
  86. >>> x = torch.tensor([1.], requires_grad=True)
  87. >>> with torch.no_grad():
  88. ... with torch.enable_grad():
  89. ... y = x * 2
  90. >>> y.requires_grad
  91. True
  92. >>> y.backward()
  93. >>> x.grad
  94. tensor([2.])
  95. >>> @torch.enable_grad()
  96. ... def doubler(x):
  97. ... return x * 2
  98. >>> with torch.no_grad():
  99. ... z = doubler(x)
  100. >>> z.requires_grad
  101. True
  102. >>> @torch.enable_grad()
  103. ... def tripler(x):
  104. ... return x * 3
  105. >>> with torch.no_grad():
  106. ... z = tripler(x)
  107. >>> z.requires_grad
  108. True
  109. """
  110. def __enter__(self) -> None:
  111. self.prev = torch.is_grad_enabled()
  112. torch._C._set_grad_enabled(True)
  113. def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
  114. torch._C._set_grad_enabled(self.prev)
  115. class set_grad_enabled(_DecoratorContextManager):
  116. r"""Context-manager that sets gradient calculation on or off.
  117. ``set_grad_enabled`` will enable or disable grads based on its argument :attr:`mode`.
  118. It can be used as a context-manager or as a function.
  119. This context manager is thread local; it will not affect computation
  120. in other threads.
  121. Args:
  122. mode (bool): Flag whether to enable grad (``True``), or disable
  123. (``False``). This can be used to conditionally enable
  124. gradients.
  125. .. note::
  126. set_grad_enabled is one of several mechanisms that can enable or
  127. disable gradients locally see :ref:`locally-disable-grad-doc` for
  128. more information on how they compare.
  129. .. note::
  130. This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
  131. Example::
  132. >>> # xdoctest: +SKIP
  133. >>> x = torch.tensor([1.], requires_grad=True)
  134. >>> is_train = False
  135. >>> with torch.set_grad_enabled(is_train):
  136. ... y = x * 2
  137. >>> y.requires_grad
  138. False
  139. >>> _ = torch.set_grad_enabled(True)
  140. >>> y = x * 2
  141. >>> y.requires_grad
  142. True
  143. >>> _ = torch.set_grad_enabled(False)
  144. >>> y = x * 2
  145. >>> y.requires_grad
  146. False
  147. """
  148. def __init__(self, mode: bool) -> None:
  149. self.prev = torch.is_grad_enabled()
  150. self.mode = mode
  151. torch._C._set_grad_enabled(mode)
  152. def __call__(self, orig_func: F) -> F:
  153. torch._C._set_grad_enabled(self.prev)
  154. return super().__call__(orig_func)
  155. def __enter__(self) -> None:
  156. torch._C._set_grad_enabled(self.mode)
  157. def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
  158. torch._C._set_grad_enabled(self.prev)
  159. def __str__(self) -> str:
  160. return f"{torch.typename(self)}(mode={self.mode})"
  161. def __repr__(self) -> str:
  162. return str(self)
  163. def clone(self) -> "set_grad_enabled":
  164. r"""
  165. Create a copy of this class
  166. """
  167. return self.__class__(self.mode)
  168. class inference_mode(_DecoratorContextManager):
  169. r"""Context manager that enables or disables inference mode.
  170. InferenceMode is analogous to :class:`~no_grad` and should be used
  171. when you are certain your operations will not interact with autograd
  172. (e.g., during data loading or model evaluation). Compared to
  173. :class:`~no_grad`, it removes additional overhead by disabling view
  174. tracking and version counter bumps. It is also more restrictive, in
  175. that tensors created in this mode cannot be used in computations
  176. recorded by autograd.
  177. This context manager is thread-local; it does not affect computation
  178. in other threads.
  179. Also functions as a decorator.
  180. .. note::
  181. Inference mode is one of several mechanisms that can locally enable
  182. or disable gradients. See :ref:`locally-disable-grad-doc` for a
  183. comparison. If avoiding the use of tensors created in inference mode
  184. in autograd-tracked regions is difficult, consider benchmarking your
  185. code with and without inference mode to weigh the performance benefits
  186. against the trade-offs. You can always use :class:`~no_grad` instead.
  187. .. note::
  188. Unlike some other mechanisms that locally enable or disable grad,
  189. entering inference_mode also disables :ref:`forward-mode AD <forward-mode-ad>`.
  190. .. warning::
  191. `inference_mode` does NOT automatically set the model to evaluation mode.
  192. For proper inference behavior (e.g., disabling dropout, using running statistics
  193. in batch normalization), you must explicitly set your model to evaluation mode using
  194. `model.eval()` in addition to using this context manager.
  195. Args:
  196. mode (bool or function): Either a boolean flag to enable or disable
  197. inference mode, or a Python function to decorate with inference
  198. mode enabled.
  199. Example::
  200. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
  201. >>> import torch
  202. >>> x = torch.ones(1, 2, 3, requires_grad=True)
  203. >>> with torch.inference_mode():
  204. ... y = x * x
  205. >>> y.requires_grad
  206. False
  207. >>> # xdoctest: +SKIP("want string isn't quite right")
  208. >>> y._version
  209. Traceback (most recent call last):
  210. File "<stdin>", line 1, in <module>
  211. RuntimeError: Inference tensors do not track version counter.
  212. >>> @torch.inference_mode()
  213. ... def func(x):
  214. ... return x * x
  215. >>> out = func(x)
  216. >>> out.requires_grad
  217. False
  218. >>> @torch.inference_mode()
  219. ... def doubler(x):
  220. ... return x * 2
  221. >>> out = doubler(x)
  222. >>> out.requires_grad
  223. False
  224. """
  225. def __init__(self, mode: bool = True) -> None:
  226. if not torch._jit_internal.is_scripting():
  227. super().__init__()
  228. self.mode = mode
  229. def __new__(cls, mode=True):
  230. if isinstance(mode, bool):
  231. return super().__new__(cls)
  232. return cls()(mode)
  233. def __enter__(self) -> None:
  234. self._inference_mode_context = torch._C._InferenceMode(self.mode)
  235. self._inference_mode_context.__enter__()
  236. def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
  237. self._inference_mode_context.__exit__(exc_type, exc_value, traceback)
  238. def clone(self) -> "inference_mode":
  239. r"""
  240. Create a copy of this class
  241. """
  242. return self.__class__(self.mode)
  243. def _enter_inference_mode(mode):
  244. mode_context = torch._C._InferenceMode(mode)
  245. mode_context.__enter__()
  246. return mode_context
  247. def _exit_inference_mode(mode):
  248. mode.__exit__(None, None, None)
  249. class set_multithreading_enabled(_DecoratorContextManager):
  250. r"""Context-manager that enables or disables multithreaded backward.
  251. Ordinarily, when :ref:`accelerator<accelerators>` devices are in use,
  252. the backward pass runs on device-specific worker threads. The engine
  253. creates these threads based on the number of available devices and
  254. reuses them across iterations.
  255. When ``mode=False``, the backward pass runs on the calling thread
  256. instead. ``mode=True`` restores the default behavior.
  257. This can be used as a context-manager or as a function. It is
  258. thread-local and will not affect computation in other threads.
  259. Args:
  260. mode (bool): Whether to enable multithreaded backward (``True``,
  261. default) or disable (``False``).
  262. .. note::
  263. This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`,
  264. which never uses multithreading.
  265. """
  266. def __init__(self, mode: bool) -> None:
  267. self.prev = torch._C._is_multithreading_enabled()
  268. torch._C._set_multithreading_enabled(mode)
  269. self.mode = mode
  270. def __enter__(self) -> None:
  271. pass
  272. def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
  273. torch._C._set_multithreading_enabled(self.prev)
  274. def clone(self) -> "set_multithreading_enabled":
  275. r"""
  276. Create a copy of this class
  277. """
  278. return self.__class__(self.mode)
  279. class _force_original_view_tracking(_DecoratorContextManager):
  280. r"""Context-manager that sets whether or not to always enable view-replay in autograd.
  281. ``set_view_replay_enabled`` will enable or disable view-replay based on its argument :attr:`mode`.
  282. It can be used as a context-manager or as a function.
  283. This context manager is thread local; it will not affect computation
  284. in other threads.
  285. When a tensor view is mutated, the autograd engine needs to decide whether or not
  286. to regenerate the "updated view" by either replaying the chain of views from the updated base,
  287. or with a single call to as_strided.
  288. If set_view_replay_enabled is set to True, then autograd will always use view replay.
  289. Otherwise, it will fall back to its existing logic.
  290. Args:
  291. mode (bool): Flag whether to enable view-replay (``True``), or disable
  292. (``False``).
  293. """
  294. def __init__(self, mode: bool) -> None:
  295. self.prev = torch._C._is_view_replay_enabled()
  296. torch._C._set_view_replay_enabled(mode)
  297. self.mode = mode
  298. def __enter__(self) -> None:
  299. pass
  300. def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
  301. torch._C._set_view_replay_enabled(self.prev)
  302. def clone(self):
  303. return self.__class__(self.mode)
  304. class _unsafe_preserve_version_counter(_DecoratorContextManager):
  305. r"""DO NOT USE THIS UNLESS YOU KNOW EXACTLY WHAT YOU'RE DOING.
  306. This context manager can lead to arbitrary silent-correctness issues in any other part of your code
  307. (even the ones not touched directly by the context manager)!
  308. Ordinarily, autograd will track mutations to tensors by incrementing it's `._version` attribute.
  309. This is generally important for correctness, as for example, mutating a tensor that autograd has saved
  310. for the backwards pass can result in incorrect gradients, and autograd uses the version counter to detect
  311. and error out in this situation.
  312. However, there are rare instances where it might be useful to hide mutations from autograd. For example:
  313. if a tensor is very large, and you'd like to free its memory by storing it elsewhere, and re-populate
  314. the tensor right before it is needed by autograd.
  315. Args:
  316. tensor (torch.Tensor): the tensor in question, that you would like to preserve the version counter of.
  317. .. note::
  318. This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
  319. """
  320. def __init__(self, tensors: Union[torch.Tensor, tuple[torch.Tensor, ...]]) -> None:
  321. self.tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tensors
  322. if not isinstance(self.tensors, tuple):
  323. raise AssertionError("Expected tensors to be a tuple")
  324. self.prev_versions = tuple(t._version for t in self.tensors)
  325. def __enter__(self) -> None:
  326. pass
  327. # pyrefly: ignore [bad-override]
  328. def __exit__(self, *args) -> None:
  329. torch._C._autograd._unsafe_set_version_counter(self.tensors, self.prev_versions)