CUDAUtils.h 670 B

12345678910111213141516171819202122232425
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/cuda/CUDAContext.h>
  4. namespace at::cuda {
  5. // Check if every tensor in a list of tensors matches the current
  6. // device.
  7. inline bool check_device(ArrayRef<Tensor> ts) {
  8. if (ts.empty()) {
  9. return true;
  10. }
  11. Device curDevice = Device(kCUDA, current_device());
  12. for (const Tensor& t : ts) {
  13. if (t.device() != curDevice) return false;
  14. }
  15. return true;
  16. }
  17. } // namespace at::cuda
  18. #else
  19. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  20. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)