// XPUUtils.h
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/xpu/XPUContext.h>
  4. namespace at::xpu {
  5. // Check if every tensor in a list of tensors matches the current device.
  6. inline bool check_device(ArrayRef<Tensor> ts) {
  7. if (ts.empty()) {
  8. return true;
  9. }
  10. Device curDevice = Device(kXPU, current_device());
  11. for (const Tensor& t : ts) {
  12. if (t.device() != curDevice) {
  13. return false;
  14. }
  15. }
  16. return true;
  17. }
  18. } // namespace at::xpu
  19. #else
  20. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  21. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)