| 1234567891011121314151617181920212223242526272829 |
- #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
- #pragma once
- #include <ATen/cuda/CUDAContext.h>
- #if defined(USE_ROCM)
- #include <hipsparse/hipsparse-version.h>
- #define HIPSPARSE_VERSION ((hipsparseVersionMajor*100000) + (hipsparseVersionMinor*100) + hipsparseVersionPatch)
- #endif
- // cuSparse Generic API spsv function was added in CUDA 11.3.0
- // hipSparse supports SpSV as well
- #if (defined(CUDART_VERSION) && defined(CUSPARSE_VERSION)) || defined(USE_ROCM)
- #define AT_USE_CUSPARSE_GENERIC_SPSV() 1
- #else
- #define AT_USE_CUSPARSE_GENERIC_SPSV() 0
- #endif
- // cuSparse Generic API spsm function was added in CUDA 11.3.1
- // hipSparse supports SpSM as well
- #if (defined(CUDART_VERSION) && defined(CUSPARSE_VERSION)) || defined(USE_ROCM)
- #define AT_USE_CUSPARSE_GENERIC_SPSM() 1
- #else
- #define AT_USE_CUSPARSE_GENERIC_SPSM() 0
- #endif
- #else
- #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
- #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|