| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426 |
- # mypy: allow-untyped-defs
- from typing import Any, Union
- import torch
- from torch.utils._contextlib import (
- _DecoratorContextManager,
- _NoParamDecoratorContextManager,
- F,
- )
- __all__ = [
- "no_grad",
- "enable_grad",
- "set_grad_enabled",
- "inference_mode",
- "set_multithreading_enabled",
- ]
- class no_grad(_NoParamDecoratorContextManager):
- r"""Context-manager that disables gradient calculation.
- Disabling gradient calculation is useful for inference, when you are sure
- that you will not call :meth:`Tensor.backward()`. It will reduce memory
- consumption for computations that would otherwise have `requires_grad=True`.
- In this mode, the result of every computation will have
- `requires_grad=False`, even when the inputs have `requires_grad=True`.
- There is an exception! All factory functions, or functions that create
- a new Tensor and take a requires_grad kwarg, will NOT be affected by
- this mode.
- This context manager is thread local; it will not affect computation
- in other threads.
- Also functions as a decorator.
- .. note::
- No-grad is one of several mechanisms that can enable or
- disable gradients locally see :ref:`locally-disable-grad-doc` for
- more information on how they compare.
- .. note::
- This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
- If you want to disable forward AD for a computation, you can unpack
- your dual tensors.
- Example::
- >>> # xdoctest: +SKIP
- >>> x = torch.tensor([1.], requires_grad=True)
- >>> with torch.no_grad():
- ... y = x * 2
- >>> y.requires_grad
- False
- >>> @torch.no_grad()
- ... def doubler(x):
- ... return x * 2
- >>> z = doubler(x)
- >>> z.requires_grad
- False
- >>> @torch.no_grad()
- ... def tripler(x):
- ... return x * 3
- >>> z = tripler(x)
- >>> z.requires_grad
- False
- >>> # factory function exception
- >>> with torch.no_grad():
- ... a = torch.nn.Parameter(torch.rand(10))
- >>> a.requires_grad
- True
- """
- def __init__(self) -> None:
- if not torch._jit_internal.is_scripting():
- super().__init__()
- self.prev = False
- def __enter__(self) -> None:
- self.prev = torch.is_grad_enabled()
- torch.set_grad_enabled(False)
- def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
- torch.set_grad_enabled(self.prev)
- class enable_grad(_NoParamDecoratorContextManager):
- r"""Context-manager that enables gradient calculation.
- Enables gradient calculation, if it has been disabled via :class:`~no_grad`
- or :class:`~set_grad_enabled`.
- This context manager is thread local; it will not affect computation
- in other threads.
- Also functions as a decorator.
- .. note::
- enable_grad is one of several mechanisms that can enable or
- disable gradients locally see :ref:`locally-disable-grad-doc` for
- more information on how they compare.
- .. note::
- This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
- Example::
- >>> # xdoctest: +SKIP
- >>> x = torch.tensor([1.], requires_grad=True)
- >>> with torch.no_grad():
- ... with torch.enable_grad():
- ... y = x * 2
- >>> y.requires_grad
- True
- >>> y.backward()
- >>> x.grad
- tensor([2.])
- >>> @torch.enable_grad()
- ... def doubler(x):
- ... return x * 2
- >>> with torch.no_grad():
- ... z = doubler(x)
- >>> z.requires_grad
- True
- >>> @torch.enable_grad()
- ... def tripler(x):
- ... return x * 3
- >>> with torch.no_grad():
- ... z = tripler(x)
- >>> z.requires_grad
- True
- """
- def __enter__(self) -> None:
- self.prev = torch.is_grad_enabled()
- torch._C._set_grad_enabled(True)
- def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
- torch._C._set_grad_enabled(self.prev)
class set_grad_enabled(_DecoratorContextManager):
    r"""Context-manager that sets gradient calculation on or off.

    ``set_grad_enabled`` will enable or disable grads based on its argument :attr:`mode`.
    It can be used as a context-manager or as a function.

    This context manager is thread local; it will not affect computation
    in other threads.

    Args:
        mode (bool): Flag whether to enable grad (``True``), or disable
                     (``False``). This can be used to conditionally enable
                     gradients.

    .. note::
        set_grad_enabled is one of several mechanisms that can enable or
        disable gradients locally; see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.

    Example::
        >>> # xdoctest: +SKIP
        >>> x = torch.tensor([1.], requires_grad=True)
        >>> is_train = False
        >>> with torch.set_grad_enabled(is_train):
        ...     y = x * 2
        >>> y.requires_grad
        False
        >>> _ = torch.set_grad_enabled(True)
        >>> y = x * 2
        >>> y.requires_grad
        True
        >>> _ = torch.set_grad_enabled(False)
        >>> y = x * 2
        >>> y.requires_grad
        False

    """

    def __init__(self, mode: bool) -> None:
        # Capture the mode in effect before changing anything so that
        # ``__exit__`` / ``__call__`` can restore it.
        self.prev = torch.is_grad_enabled()
        self.mode = mode
        # Applied eagerly: this is what makes plain function-style use
        # (``torch.set_grad_enabled(False)``) take effect.
        torch._C._set_grad_enabled(mode)

    def __call__(self, orig_func: F) -> F:
        """Decorate ``orig_func``, first undoing the eager switch from ``__init__``."""
        # The constructor already flipped the mode; when decorating, put the
        # previous mode back before handing off to the base-class wrapper.
        torch._C._set_grad_enabled(self.prev)
        return super().__call__(orig_func)

    def __enter__(self) -> None:
        """Apply the requested grad mode."""
        torch._C._set_grad_enabled(self.mode)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        """Restore the grad mode captured at construction time."""
        torch._C._set_grad_enabled(self.prev)

    def __str__(self) -> str:
        type_name = torch.typename(self)
        return f"{type_name}(mode={self.mode})"

    def __repr__(self) -> str:
        # Same text as ``__str__``.
        return str(self)

    def clone(self) -> "set_grad_enabled":
        r"""
        Create a copy of this class
        """
        # Note: constructing the copy runs ``__init__`` again, which re-reads
        # the current grad mode and applies ``self.mode``.
        return self.__class__(self.mode)
- class inference_mode(_DecoratorContextManager):
- r"""Context manager that enables or disables inference mode.
- InferenceMode is analogous to :class:`~no_grad` and should be used
- when you are certain your operations will not interact with autograd
- (e.g., during data loading or model evaluation). Compared to
- :class:`~no_grad`, it removes additional overhead by disabling view
- tracking and version counter bumps. It is also more restrictive, in
- that tensors created in this mode cannot be used in computations
- recorded by autograd.
- This context manager is thread-local; it does not affect computation
- in other threads.
- Also functions as a decorator.
- .. note::
- Inference mode is one of several mechanisms that can locally enable
- or disable gradients. See :ref:`locally-disable-grad-doc` for a
- comparison. If avoiding the use of tensors created in inference mode
- in autograd-tracked regions is difficult, consider benchmarking your
- code with and without inference mode to weigh the performance benefits
- against the trade-offs. You can always use :class:`~no_grad` instead.
- .. note::
- Unlike some other mechanisms that locally enable or disable grad,
- entering inference_mode also disables :ref:`forward-mode AD <forward-mode-ad>`.
- .. warning::
- `inference_mode` does NOT automatically set the model to evaluation mode.
- For proper inference behavior (e.g., disabling dropout, using running statistics
- in batch normalization), you must explicitly set your model to evaluation mode using
- `model.eval()` in addition to using this context manager.
- Args:
- mode (bool or function): Either a boolean flag to enable or disable
- inference mode, or a Python function to decorate with inference
- mode enabled.
- Example::
- >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
- >>> import torch
- >>> x = torch.ones(1, 2, 3, requires_grad=True)
- >>> with torch.inference_mode():
- ... y = x * x
- >>> y.requires_grad
- False
- >>> # xdoctest: +SKIP("want string isn't quite right")
- >>> y._version
- Traceback (most recent call last):
- File "<stdin>", line 1, in <module>
- RuntimeError: Inference tensors do not track version counter.
- >>> @torch.inference_mode()
- ... def func(x):
- ... return x * x
- >>> out = func(x)
- >>> out.requires_grad
- False
- >>> @torch.inference_mode()
- ... def doubler(x):
- ... return x * 2
- >>> out = doubler(x)
- >>> out.requires_grad
- False
- """
- def __init__(self, mode: bool = True) -> None:
- if not torch._jit_internal.is_scripting():
- super().__init__()
- self.mode = mode
- def __new__(cls, mode=True):
- if isinstance(mode, bool):
- return super().__new__(cls)
- return cls()(mode)
- def __enter__(self) -> None:
- self._inference_mode_context = torch._C._InferenceMode(self.mode)
- self._inference_mode_context.__enter__()
- def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
- self._inference_mode_context.__exit__(exc_type, exc_value, traceback)
- def clone(self) -> "inference_mode":
- r"""
- Create a copy of this class
- """
- return self.__class__(self.mode)
- def _enter_inference_mode(mode):
- mode_context = torch._C._InferenceMode(mode)
- mode_context.__enter__()
- return mode_context
def _exit_inference_mode(mode):
    """Leave the guard ``mode`` by calling its ``__exit__`` with no exception info."""
    # All three exception slots are ``None``: this is a normal, non-error exit.
    exc_type = exc_value = traceback = None
    mode.__exit__(exc_type, exc_value, traceback)
- class set_multithreading_enabled(_DecoratorContextManager):
- r"""Context-manager that enables or disables multithreaded backward.
- Ordinarily, when :ref:`accelerator<accelerators>` devices are in use,
- the backward pass runs on device-specific worker threads. The engine
- creates these threads based on the number of available devices and
- reuses them across iterations.
- When ``mode=False``, the backward pass runs on the calling thread
- instead. ``mode=True`` restores the default behavior.
- This can be used as a context-manager or as a function. It is
- thread-local and will not affect computation in other threads.
- Args:
- mode (bool): Whether to enable multithreaded backward (``True``,
- default) or disable (``False``).
- .. note::
- This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`,
- which never uses multithreading.
- """
- def __init__(self, mode: bool) -> None:
- self.prev = torch._C._is_multithreading_enabled()
- torch._C._set_multithreading_enabled(mode)
- self.mode = mode
- def __enter__(self) -> None:
- pass
- def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
- torch._C._set_multithreading_enabled(self.prev)
- def clone(self) -> "set_multithreading_enabled":
- r"""
- Create a copy of this class
- """
- return self.__class__(self.mode)
- class _force_original_view_tracking(_DecoratorContextManager):
- r"""Context-manager that sets whether or not to always enable view-replay in autograd.
- ``set_view_replay_enabled`` will enable or disable view-replay based on its argument :attr:`mode`.
- It can be used as a context-manager or as a function.
- This context manager is thread local; it will not affect computation
- in other threads.
- When a tensor view is mutated, the autograd engine needs to decide whether or not
- to regenerate the "updated view" by either replaying the chain of views from the updated base,
- or with a single call to as_strided.
- If set_view_replay_enabled is set to True, then autograd will always use view replay.
- Otherwise, it will fall back to its existing logic.
- Args:
- mode (bool): Flag whether to enable view-replay (``True``), or disable
- (``False``).
- """
- def __init__(self, mode: bool) -> None:
- self.prev = torch._C._is_view_replay_enabled()
- torch._C._set_view_replay_enabled(mode)
- self.mode = mode
- def __enter__(self) -> None:
- pass
- def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
- torch._C._set_view_replay_enabled(self.prev)
- def clone(self):
- return self.__class__(self.mode)
- class _unsafe_preserve_version_counter(_DecoratorContextManager):
- r"""DO NOT USE THIS UNLESS YOU KNOW EXACTLY WHAT YOU'RE DOING.
- This context manager can lead to arbitrary silent-correctness issues in any other part of your code
- (even the ones not touched directly by the context manager)!
- Ordinarily, autograd will track mutations to tensors by incrementing it's `._version` attribute.
- This is generally important for correctness, as for example, mutating a tensor that autograd has saved
- for the backwards pass can result in incorrect gradients, and autograd uses the version counter to detect
- and error out in this situation.
- However, there are rare instances where it might be useful to hide mutations from autograd. For example:
- if a tensor is very large, and you'd like to free its memory by storing it elsewhere, and re-populate
- the tensor right before it is needed by autograd.
- Args:
- tensor (torch.Tensor): the tensor in question, that you would like to preserve the version counter of.
- .. note::
- This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
- """
- def __init__(self, tensors: Union[torch.Tensor, tuple[torch.Tensor, ...]]) -> None:
- self.tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tensors
- if not isinstance(self.tensors, tuple):
- raise AssertionError("Expected tensors to be a tuple")
- self.prev_versions = tuple(t._version for t in self.tensors)
- def __enter__(self) -> None:
- pass
- # pyrefly: ignore [bad-override]
- def __exit__(self, *args) -> None:
- torch._C._autograd._unsafe_set_version_counter(self.tensors, self.prev_versions)
|