Exceptions.h 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <cublas_v2.h>
  4. #include <cusparse.h>
  5. #include <c10/macros/Export.h>
  6. #if !defined(USE_ROCM)
  7. #include <cusolver_common.h>
  8. #else
  9. #include <hipsolver/hipsolver.h>
  10. #endif
  11. #if defined(USE_CUDSS)
  12. #include <cudss.h>
  13. #endif
  14. #include <ATen/Context.h>
  15. #include <c10/util/Exception.h>
  16. #include <c10/cuda/CUDAException.h>
  17. namespace c10 {
  18. class CuDNNError : public c10::Error {
  19. using Error::Error;
  20. };
  21. } // namespace c10
  // Checks the result object produced by a cuDNN Frontend call.
  //
  // EXPR is evaluated exactly once and must yield an object that provides
  // is_good() and get_message(). When is_good() returns false, a CuDNNError
  // whose text is "cuDNN Frontend error: " followed by get_message() is raised
  // through TORCH_CHECK_WITH. Extra arguments are accepted but not used.
  //
  // NOTE: the last line of the macro must NOT end in a line continuation. A
  // trailing backslash there splices whatever follows (here, the next #define)
  // into this macro's replacement list.
  #define AT_CUDNN_FRONTEND_CHECK(EXPR, ...) \
    do { \
      auto error_object = EXPR; \
      if (!error_object.is_good()) { \
        TORCH_CHECK_WITH(CuDNNError, false, \
            "cuDNN Frontend error: ", error_object.get_message()); \
      } \
    } while (0)

  // Forwards to AT_CUDNN_CHECK, inserting a newline ahead of any extra message
  // arguments (presumably shape descriptions, going by the name; confirm at
  // the call sites).
  #define AT_CUDNN_CHECK_WITH_SHAPES(EXPR, ...) AT_CUDNN_CHECK(EXPR, "\n", ##__VA_ARGS__)
  // See Note [CHECK macro]
  //
  // Evaluates EXPR once as a cudnnStatus_t and raises a CuDNNError through
  // TORCH_CHECK_WITH for any status other than CUDNN_STATUS_SUCCESS. The
  // message is "cuDNN error: " plus cudnnGetErrorString(status), followed by
  // any extra arguments. CUDNN_STATUS_NOT_SUPPORTED additionally gets a hint
  // about non-contiguous inputs, placed before the extra arguments.
  #define AT_CUDNN_CHECK(EXPR, ...) \
    do { \
      cudnnStatus_t status = EXPR; \
      if (status == CUDNN_STATUS_NOT_SUPPORTED) { \
        TORCH_CHECK_WITH(CuDNNError, false, \
            "cuDNN error: ", \
            cudnnGetErrorString(status), \
            ". This error may appear if you passed in a non-contiguous input.", ##__VA_ARGS__); \
      } else if (status != CUDNN_STATUS_SUCCESS) { \
        TORCH_CHECK_WITH(CuDNNError, false, \
            "cuDNN error: ", cudnnGetErrorString(status), ##__VA_ARGS__); \
      } \
    } while (0)
  47. namespace at::cuda::blas {
  48. C10_EXPORT const char* _cublasGetErrorEnum(cublasStatus_t error);
  49. } // namespace at::cuda::blas
  50. #define TORCH_CUDABLAS_CHECK(EXPR) \
  51. do { \
  52. cublasStatus_t __err = EXPR; \
  53. TORCH_CHECK(__err == CUBLAS_STATUS_SUCCESS, \
  54. "CUDA error: ", \
  55. at::cuda::blas::_cublasGetErrorEnum(__err), \
  56. " when calling `" #EXPR "`"); \
  57. } while (0)
  58. const char *cusparseGetErrorString(cusparseStatus_t status);
  59. #define TORCH_CUDASPARSE_CHECK(EXPR) \
  60. do { \
  61. cusparseStatus_t __err = EXPR; \
  62. TORCH_CHECK(__err == CUSPARSE_STATUS_SUCCESS, \
  63. "CUDA error: ", \
  64. cusparseGetErrorString(__err), \
  65. " when calling `" #EXPR "`"); \
  66. } while (0)
  #if defined(USE_CUDSS)

  namespace at::cuda::cudss {

  // Maps a cudssStatus_t to a C string; only declared here.
  C10_EXPORT const char* cudssGetErrorMessage(cudssStatus_t error);

  } // namespace at::cuda::cudss

  // Evaluates EXPR once as a cudssStatus_t.
  //   * CUDSS_STATUS_EXECUTION_FAILED is reported through TORCH_CHECK_LINALG
  //     with an added hint that the input matrix may contain NaN.
  //   * Any other status that is not CUDSS_STATUS_SUCCESS is reported through
  //     TORCH_CHECK.
  // Both messages include cudssGetErrorMessage(__err) and the stringified EXPR.
  #define TORCH_CUDSS_CHECK(EXPR) \
    do { \
      cudssStatus_t __err = EXPR; \
      if (__err == CUDSS_STATUS_EXECUTION_FAILED) { \
        TORCH_CHECK_LINALG( \
            false, \
            "cudss error: ", \
            at::cuda::cudss::cudssGetErrorMessage(__err), \
            ", when calling `" #EXPR "`", \
            ". This error may appear if the input matrix contains NaN. "); \
      } else { \
        TORCH_CHECK( \
            __err == CUDSS_STATUS_SUCCESS, \
            "cudss error: ", \
            at::cuda::cudss::cudssGetErrorMessage(__err), \
            ", when calling `" #EXPR "`. "); \
      } \
    } while (0)

  #else

  // Without USE_CUDSS the macro expands to EXPR alone: the expression is still
  // evaluated, but its result is not inspected.
  #define TORCH_CUDSS_CHECK(EXPR) EXPR

  #endif
  92. namespace at::cuda::solver {
  93. #if !defined(USE_ROCM)
  94. C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status);
  95. constexpr const char* _cusolver_backend_suggestion = \
  96. "If you keep seeing this error, you may use " \
  97. "`torch.backends.cuda.preferred_linalg_library()` to try " \
  98. "linear algebra operators with other supported backends. " \
  99. "See https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.preferred_linalg_library";
  100. // When cuda >= 11.5, cusolver normally finishes execution and sets info array indicating convergence issue.
  101. #define TORCH_CUSOLVER_CHECK(EXPR) \
  102. do { \
  103. cusolverStatus_t __err = EXPR; \
  104. if (__err == CUSOLVER_STATUS_INVALID_VALUE) { \
  105. TORCH_CHECK_LINALG( \
  106. false, \
  107. "cusolver error: ", \
  108. at::cuda::solver::cusolverGetErrorMessage(__err), \
  109. ", when calling `" #EXPR "`", \
  110. ". This error may appear if the input matrix contains NaN. ", \
  111. at::cuda::solver::_cusolver_backend_suggestion); \
  112. } else { \
  113. TORCH_CHECK( \
  114. __err == CUSOLVER_STATUS_SUCCESS, \
  115. "cusolver error: ", \
  116. at::cuda::solver::cusolverGetErrorMessage(__err), \
  117. ", when calling `" #EXPR "`. ", \
  118. at::cuda::solver::_cusolver_backend_suggestion); \
  119. } \
  120. } while (0)
  121. #else // defined(USE_ROCM)
  122. C10_EXPORT const char* hipsolverGetErrorMessage(hipsolverStatus_t status);
  123. constexpr const char* _hipsolver_backend_suggestion = \
  124. "If you keep seeing this error, you may use " \
  125. "`torch.backends.cuda.preferred_linalg_library()` to try " \
  126. "linear algebra operators with other supported backends. " \
  127. "See https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.preferred_linalg_library";
  128. #define TORCH_CUSOLVER_CHECK(EXPR) \
  129. do { \
  130. hipsolverStatus_t __err = EXPR; \
  131. if (__err == HIPSOLVER_STATUS_INVALID_VALUE) { \
  132. TORCH_CHECK_LINALG( \
  133. false, \
  134. "hipsolver error: ", \
  135. at::cuda::solver::hipsolverGetErrorMessage(__err), \
  136. ", when calling `" #EXPR "`", \
  137. ". This error may appear if the input matrix contains NaN. ", \
  138. at::cuda::solver::_hipsolver_backend_suggestion); \
  139. } else { \
  140. TORCH_CHECK( \
  141. __err == HIPSOLVER_STATUS_SUCCESS, \
  142. "hipsolver error: ", \
  143. at::cuda::solver::hipsolverGetErrorMessage(__err), \
  144. ", when calling `" #EXPR "`. ", \
  145. at::cuda::solver::_hipsolver_backend_suggestion); \
  146. } \
  147. } while (0)
  148. #endif
  149. } // namespace at::cuda::solver
  // ATen-level alias for C10_CUDA_CHECK.
  #define AT_CUDA_CHECK(EXPR) C10_CUDA_CHECK(EXPR)

  // For CUDA Driver API
  //
  // This is here instead of in c10 because NVRTC is loaded dynamically via a stub
  // in ATen, and we need to use its nvrtcGetErrorString.
  // See NOTE [ USE OF NVRTC AND DRIVER API ].
  //
  // Evaluates EXPR once as a CUresult and fails via TORCH_CHECK when it is not
  // CUDA_SUCCESS.
  #if !defined(USE_ROCM)

  // Non-ROCm build: the error text is looked up with cuGetErrorString obtained
  // from at::globalContext().getNVRTC(); if that lookup itself fails, the
  // message falls back to "unknown error".
  #define AT_CUDA_DRIVER_CHECK(EXPR) \
    do { \
      CUresult __err = EXPR; \
      if (__err != CUDA_SUCCESS) { \
        const char* driver_msg; \
        [[maybe_unused]] CUresult lookup_status = \
            at::globalContext().getNVRTC().cuGetErrorString(__err, &driver_msg); \
        if (lookup_status == CUDA_SUCCESS) { \
          TORCH_CHECK(false, "CUDA driver error: ", driver_msg); \
        } else { \
          TORCH_CHECK(false, "CUDA driver error: unknown error"); \
        } \
      } \
    } while (0)

  #else

  // ROCm build: no string lookup is done; the numeric value of the error code
  // is reported instead.
  #define AT_CUDA_DRIVER_CHECK(EXPR) \
    do { \
      CUresult __err = EXPR; \
      if (__err != CUDA_SUCCESS) { \
        TORCH_CHECK(false, "CUDA driver error: ", static_cast<int>(__err)); \
      } \
    } while (0)

  #endif
  // For CUDA NVRTC
  //
  // Note: As of CUDA 10, nvrtc error code 7, NVRTC_ERROR_BUILTIN_OPERATION_FAILURE,
  // incorrectly produces the error string "NVRTC unknown error."
  // The following maps it correctly.
  //
  // This is here instead of in c10 because NVRTC is loaded dynamically via a stub
  // in ATen, and we need to use its nvrtcGetErrorString.
  // See NOTE [ USE OF NVRTC AND DRIVER API ].
  //
  // Evaluates EXPR once as an nvrtcResult and fails via TORCH_CHECK when it is
  // not NVRTC_SUCCESS. Code 7 gets a hard-coded name; every other code is
  // described by nvrtcGetErrorString from at::globalContext().getNVRTC().
  #define AT_CUDA_NVRTC_CHECK(EXPR) \
    do { \
      nvrtcResult __err = EXPR; \
      if (__err != NVRTC_SUCCESS) { \
        if (static_cast<int>(__err) == 7) { \
          TORCH_CHECK(false, "CUDA NVRTC error: NVRTC_ERROR_BUILTIN_OPERATION_FAILURE"); \
        } else { \
          TORCH_CHECK(false, "CUDA NVRTC error: ", at::globalContext().getNVRTC().nvrtcGetErrorString(__err)); \
        } \
      } \
    } while (0)
  200. #else
  201. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  202. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)