| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647 |
- # mypy: allow-untyped-defs
- import torch
- from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
- from torch.overrides import TorchFunctionMode
- class AutogradStateOpsFailSafeguard(TorchFunctionMode):
- """
- Detect grad state ops during exporting the graph and fail the process by
- raising an error, to avoid unexpected behavior. Those grad mode ops could be:
- `torch.no_grad`
- `torch.enable_grad`
- `torch.set_grad_enabled`
- Export with predispatch mode is exempted.
- """
- def __torch_function__(self, func, types, args=(), kwargs=None):
- kwargs = kwargs or {}
- unsupported_grad_mode_ops = [
- torch._C._set_grad_enabled,
- ]
- # It's only enabled while tracing, by confirming the torch dispatch mode is
- # any active PROXY. This is to allow the autograd ops out of tracing.
- current_state = torch._C.is_grad_enabled()
- if func in unsupported_grad_mode_ops:
- if len(args) != 1:
- raise AssertionError(
- f"Expected exactly 1 argument for grad mode op, but got {len(args)}"
- )
- changed_state = args[0]
- mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)
- # Intend to check if it's not the pre_dispatch mode. It's allowed to use
- # autograd ops in pre_dispatch mode, e.g. `torch.no_grad`
- if (
- mode
- and isinstance(mode, ProxyTorchDispatchMode)
- and not mode.pre_dispatch
- and changed_state != current_state
- ):
- raise RuntimeError(
- f"Encountered autograd state manager op {func} trying to change global autograd state "
- "while exporting. This is unsafe because we don't capture this op in torch.export "
- "today, hence we can't reflect the user intention soundly. You can fix this by "
- "adding a torch.no_grad() context around the export call."
- )
- return func(*args, **kwargs)
|