XPUFunctions.h 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <c10/core/Device.h>
  4. #include <c10/xpu/XPUDeviceProp.h>
  5. #include <c10/xpu/XPUMacros.h>
  6. // The naming convention used here matches the naming convention of torch.xpu
  7. namespace c10::xpu {
  8. // Log a warning only once if no devices are detected.
  9. C10_XPU_API DeviceIndex device_count();
  10. // Throws an error if no devices are detected.
  11. C10_XPU_API DeviceIndex device_count_ensure_non_zero();
  12. C10_XPU_API DeviceIndex current_device();
  13. C10_XPU_API void set_device(DeviceIndex device);
  14. C10_XPU_API DeviceIndex exchange_device(DeviceIndex device);
  15. C10_XPU_API DeviceIndex maybe_exchange_device(DeviceIndex to_device);
  16. C10_XPU_API sycl::device& get_raw_device(DeviceIndex device);
  17. C10_XPU_API sycl::context& get_device_context();
  18. C10_XPU_API void get_device_properties(
  19. DeviceProp* device_prop,
  20. DeviceIndex device);
  21. C10_XPU_API DeviceIndex get_device_idx_from_pointer(void* ptr);
  22. static inline void check_device_index(DeviceIndex device_index) {
  23. TORCH_CHECK(
  24. device_index >= 0 && device_index < c10::xpu::device_count(),
  25. "The device index is out of range. It must be in [0, ",
  26. static_cast<int>(c10::xpu::device_count()),
  27. "), but got ",
  28. static_cast<int>(device_index),
  29. ".");
  30. }
  31. } // namespace c10::xpu
  32. #else
  33. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  34. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)