anomaly_mode.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. # mypy: allow-untyped-defs
  2. r"""Autograd anomaly mode."""
  3. import warnings
  4. import torch
  5. __all__ = ["detect_anomaly", "set_detect_anomaly"]
  6. class detect_anomaly:
  7. r"""Context-manager that enable anomaly detection for the autograd engine.
  8. This does two things:
  9. - Running the forward pass with detection enabled will allow the backward
  10. pass to print the traceback of the forward operation that created the failing
  11. backward function.
  12. - If ``check_nan`` is ``True``, any backward computation that generate "nan"
  13. value will raise an error. Default ``True``.
  14. .. warning::
  15. This mode should be enabled only for debugging as the different tests
  16. will slow down your program execution.
  17. Example:
  18. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ANOMALY)
  19. >>> import torch
  20. >>> from torch import autograd
  21. >>> class MyFunc(autograd.Function):
  22. ... @staticmethod
  23. ... def forward(ctx, inp):
  24. ... return inp.clone()
  25. ...
  26. ... @staticmethod
  27. ... def backward(ctx, gO):
  28. ... # Error during the backward pass
  29. ... raise RuntimeError("Some error in backward")
  30. ... return gO.clone()
  31. >>> def run_fn(a):
  32. ... out = MyFunc.apply(a)
  33. ... return out.sum()
  34. >>> inp = torch.rand(10, 10, requires_grad=True)
  35. >>> out = run_fn(inp)
  36. >>> out.backward()
  37. Traceback (most recent call last):
  38. File "<stdin>", line 1, in <module>
  39. File "/your/pytorch/install/torch/_tensor.py", line 93, in backward
  40. torch.autograd.backward(self, gradient, retain_graph, create_graph)
  41. File "/your/pytorch/install/torch/autograd/__init__.py", line 90, in backward
  42. allow_unreachable=True) # allow_unreachable flag
  43. File "/your/pytorch/install/torch/autograd/function.py", line 76, in apply
  44. return self._forward_cls.backward(self, *args)
  45. File "<stdin>", line 8, in backward
  46. RuntimeError: Some error in backward
  47. >>> with autograd.detect_anomaly():
  48. ... inp = torch.rand(10, 10, requires_grad=True)
  49. ... out = run_fn(inp)
  50. ... out.backward()
  51. Traceback of forward call that caused the error:
  52. File "tmp.py", line 53, in <module>
  53. out = run_fn(inp)
  54. File "tmp.py", line 44, in run_fn
  55. out = MyFunc.apply(a)
  56. Traceback (most recent call last):
  57. File "<stdin>", line 4, in <module>
  58. File "/your/pytorch/install/torch/_tensor.py", line 93, in backward
  59. torch.autograd.backward(self, gradient, retain_graph, create_graph)
  60. File "/your/pytorch/install/torch/autograd/__init__.py", line 90, in backward
  61. allow_unreachable=True) # allow_unreachable flag
  62. File "/your/pytorch/install/torch/autograd/function.py", line 76, in apply
  63. return self._forward_cls.backward(self, *args)
  64. File "<stdin>", line 8, in backward
  65. RuntimeError: Some error in backward
  66. """
  67. def __init__(self, check_nan=True) -> None: # noqa: D107
  68. self.prev = torch.is_anomaly_enabled()
  69. self.check_nan = check_nan
  70. self.prev_check_nan = torch.is_anomaly_check_nan_enabled()
  71. warnings.warn(
  72. "Anomaly Detection has been enabled. "
  73. "This mode will increase the runtime "
  74. "and should only be enabled for debugging.",
  75. stacklevel=2,
  76. )
  77. def __enter__(self) -> None: # noqa: D105
  78. torch.set_anomaly_enabled(True, self.check_nan)
  79. def __exit__(self, *args: object) -> None: # noqa: D105
  80. torch.set_anomaly_enabled(self.prev, self.prev_check_nan)
  81. class set_detect_anomaly:
  82. r"""Context-manager that sets the anomaly detection for the autograd engine on or off.
  83. ``set_detect_anomaly`` will enable or disable the autograd anomaly detection
  84. based on its argument :attr:`mode`.
  85. It can be used as a context-manager or as a function.
  86. See ``detect_anomaly`` above for details of the anomaly detection behaviour.
  87. Args:
  88. mode (bool): Flag whether to enable anomaly detection (``True``),
  89. or disable (``False``).
  90. check_nan (bool): Flag whether to raise an error when the backward
  91. generate "nan"
  92. """
  93. def __init__(self, mode: bool, check_nan: bool = True) -> None: # noqa: D107
  94. self.prev = torch.is_anomaly_enabled()
  95. self.prev_check_nan = torch.is_anomaly_check_nan_enabled()
  96. torch.set_anomaly_enabled(mode, check_nan)
  97. def __enter__(self) -> None: # noqa: D105
  98. pass
  99. def __exit__(self, *args: object) -> None: # noqa: D105
  100. torch.set_anomaly_enabled(self.prev, self.prev_check_nan)