torch.py 676 B

123456789101112131415161718192021
  1. import torch
  2. from functools import partial
  3. from typing import Callable
  4. def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool):
  5. def decorator(*args, **kwargs):
  6. if cuda_amp_deprecated:
  7. kwargs["device_type"] = "cuda"
  8. return dec(*args, **kwargs)
  9. return decorator
  10. if hasattr(torch.amp, "custom_fwd"): # type: ignore[attr-defined]
  11. deprecated = True
  12. from torch.amp import custom_fwd, custom_bwd # type: ignore[attr-defined]
  13. else:
  14. deprecated = False
  15. from torch.cuda.amp import custom_fwd, custom_bwd
  16. custom_fwd = custom_amp_decorator(custom_fwd, deprecated)
  17. custom_bwd = custom_amp_decorator(custom_bwd, deprecated)