CUDASparse.h 1.0 KB

1234567891011121314151617181920212223242526272829
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/cuda/CUDAContext.h>
  4. #if defined(USE_ROCM)
  5. #include <hipsparse/hipsparse-version.h>
  6. #define HIPSPARSE_VERSION ((hipsparseVersionMajor*100000) + (hipsparseVersionMinor*100) + hipsparseVersionPatch)
  7. #endif
  8. // cuSparse Generic API spsv function was added in CUDA 11.3.0
  9. // hipSparse supports SpSV as well
  10. #if (defined(CUDART_VERSION) && defined(CUSPARSE_VERSION)) || defined(USE_ROCM)
  11. #define AT_USE_CUSPARSE_GENERIC_SPSV() 1
  12. #else
  13. #define AT_USE_CUSPARSE_GENERIC_SPSV() 0
  14. #endif
  15. // cuSparse Generic API spsm function was added in CUDA 11.3.1
  16. // hipSparse supports SpSM as well
  17. #if (defined(CUDART_VERSION) && defined(CUSPARSE_VERSION)) || defined(USE_ROCM)
  18. #define AT_USE_CUSPARSE_GENERIC_SPSM() 1
  19. #else
  20. #define AT_USE_CUSPARSE_GENERIC_SPSM() 0
  21. #endif
  22. #else
  23. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  24. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)