| 123456789101112131415161718192021 |
- import torch
- from functools import partial
- from typing import Callable
def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool) -> Callable:
    """Wrap an AMP decorator so callers can use it without a ``device_type``.

    Args:
        dec: The underlying decorator to forward to (this module passes
            ``custom_fwd`` and ``custom_bwd``).
        cuda_amp_deprecated: When true, ``device_type="cuda"`` is added to the
            keyword arguments handed to ``dec``, unless the caller already
            supplied a ``device_type`` of their own.

    Returns:
        A callable that forwards every positional and keyword argument to
        ``dec`` and returns whatever ``dec`` returns.
    """
    # Imported locally: the module header only pulls in ``partial``.
    from functools import wraps

    @wraps(dec)  # keep dec's __name__/__doc__ and expose it via __wrapped__
    def decorator(*args, **kwargs):
        if cuda_amp_deprecated:
            # Default to CUDA, but never clobber an explicit caller choice.
            kwargs.setdefault("device_type", "cuda")
        return dec(*args, **kwargs)

    return decorator
# Pick the import location based on whether ``torch.amp`` provides
# ``custom_fwd``; fall back to ``torch.cuda.amp`` otherwise.
# NOTE(review): the flag name suggests the ``torch.cuda.amp`` variants are
# deprecated once ``torch.amp`` has them — confirm against the PyTorch docs.
deprecated = hasattr(torch.amp, "custom_fwd")  # type: ignore[attr-defined]

if not deprecated:
    from torch.cuda.amp import custom_fwd, custom_bwd
else:
    from torch.amp import custom_fwd, custom_bwd  # type: ignore[attr-defined]

# Rebind both names to wrapped versions; the wrapper adds ``device_type``
# only when the ``torch.amp`` variants were selected.
custom_fwd = custom_amp_decorator(custom_fwd, deprecated)
custom_bwd = custom_amp_decorator(custom_bwd, deprecated)
|