Utils.h 849 B

123456789101112131415161718192021222324252627
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/core/Tensor.h>
  4. #include <ATen/cuda/Exceptions.h>
  5. #include <ATen/cudnn/Handle.h>
  6. #include <ATen/cudnn/cudnn-wrapper.h>
  7. namespace at::native {
  8. // cuDNN has a buggy check for tensor being contiguous (that is, it does
  9. // not ignore stride for dimension that is equal to 0). This function
  10. // makes tensors which have zero stride contiguous, by setting the
  11. // strides to 1 as cuDNN likes.
  12. inline Tensor contiguousIfZeroInStrides(const Tensor& t) {
  13. for (auto s : t.strides()) {
  14. if (s == 0)
  15. return t.contiguous();
  16. }
  17. return t;
  18. }
  19. } // namespace at::native
  20. #else
  21. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  22. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)