_utils.py 936 B

1234567891011121314151617181920212223242526
  1. import torch
  2. from torch.types import Device as _device_t
  3. def _get_device_index(device: _device_t, optional: bool = False) -> int:
  4. if isinstance(device, int):
  5. return device
  6. if isinstance(device, str):
  7. device = torch.device(device)
  8. device_index: int | None = None
  9. if isinstance(device, torch.device):
  10. acc = torch.accelerator.current_accelerator()
  11. if acc is None:
  12. raise RuntimeError("Accelerator expected")
  13. if acc.type != device.type:
  14. raise ValueError(
  15. f"{device.type} doesn't match the current accelerator {acc}."
  16. )
  17. device_index = device.index
  18. if device_index is None:
  19. if not optional:
  20. raise ValueError(
  21. f"Expected a torch.device with a specified index or an integer, but got:{device}"
  22. )
  23. return torch.accelerator.current_device_index()
  24. return device_index