_safeguard.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. # mypy: allow-untyped-defs
  2. import torch
  3. from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
  4. from torch.overrides import TorchFunctionMode
  5. class AutogradStateOpsFailSafeguard(TorchFunctionMode):
  6. """
  7. Detect grad state ops during exporting the graph and fail the process by
  8. raising an error, to avoid unexpected behavior. Those grad mode ops could be:
  9. `torch.no_grad`
  10. `torch.enable_grad`
  11. `torch.set_grad_enabled`
  12. Export with predispatch mode is exempted.
  13. """
  14. def __torch_function__(self, func, types, args=(), kwargs=None):
  15. kwargs = kwargs or {}
  16. unsupported_grad_mode_ops = [
  17. torch._C._set_grad_enabled,
  18. ]
  19. # It's only enabled while tracing, by confirming the torch dispatch mode is
  20. # any active PROXY. This is to allow the autograd ops out of tracing.
  21. current_state = torch._C.is_grad_enabled()
  22. if func in unsupported_grad_mode_ops:
  23. if len(args) != 1:
  24. raise AssertionError(
  25. f"Expected exactly 1 argument for grad mode op, but got {len(args)}"
  26. )
  27. changed_state = args[0]
  28. mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)
  29. # Intend to check if it's not the pre_dispatch mode. It's allowed to use
  30. # autograd ops in pre_dispatch mode, e.g. `torch.no_grad`
  31. if (
  32. mode
  33. and isinstance(mode, ProxyTorchDispatchMode)
  34. and not mode.pre_dispatch
  35. and changed_state != current_state
  36. ):
  37. raise RuntimeError(
  38. f"Encountered autograd state manager op {func} trying to change global autograd state "
  39. "while exporting. This is unsafe because we don't capture this op in torch.export "
  40. "today, hence we can't reflect the user intention soundly. You can fix this by "
  41. "adding a torch.no_grad() context around the export call."
  42. )
  43. return func(*args, **kwargs)