_utils.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839
  1. from typing import Any
  2. import torch
  3. # The _get_device_index has been moved to torch.utils._get_device_index
  4. from torch._utils import _get_device_index as _torch_get_device_index
  5. def _get_device_index(
  6. device: Any, optional: bool = False, allow_cpu: bool = False
  7. ) -> int:
  8. r"""Get the device index from :attr:`device`, which can be a torch.device
  9. object, a Python integer, or ``None``.
  10. If :attr:`device` is a torch.device object, returns the device index if it
  11. is a XPU device. Note that for a XPU device without a specified index,
  12. i.e., ``torch.device('xpu')``, this will return the current default XPU
  13. device if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``,
  14. CPU devices will be accepted and ``-1`` will be returned in this case.
  15. If :attr:`device` is a Python integer, it is returned as is.
  16. If :attr:`device` is ``None``, this will return the current default XPU
  17. device if :attr:`optional` is ``True``.
  18. """
  19. if isinstance(device, int):
  20. return device
  21. if isinstance(device, str):
  22. device = torch.device(device)
  23. if isinstance(device, torch.device):
  24. if allow_cpu:
  25. if device.type not in ["xpu", "cpu"]:
  26. raise ValueError(f"Expected a xpu or cpu device, but got: {device}")
  27. elif device.type != "xpu":
  28. raise ValueError(f"Expected a xpu device, but got: {device}")
  29. if not torch.jit.is_scripting():
  30. if isinstance(device, torch.xpu.device):
  31. return device.idx
  32. return _torch_get_device_index(device, optional, allow_cpu)