common.py 230 B

1234567891011
  1. # mypy: allow-untyped-defs
  2. from importlib.util import find_spec
  3. import torch
  4. __all__ = ["amp_definitely_not_available"]
  5. def amp_definitely_not_available():
  6. return not (torch.cuda.is_available() or find_spec("torch_xla"))