// CUDAHooksInterface.h (8.8 KB)
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <c10/core/Allocator.h>
  4. #include <c10/util/Exception.h>
  5. #include <c10/util/Registry.h>
  6. #include <ATen/detail/AcceleratorHooksInterface.h>
  7. // NB: Class must live in `at` due to limitations of Registry.h.
  8. namespace at {
  9. // Forward-declares at::cuda::NVRTC
  10. namespace cuda {
  11. struct NVRTC;
  12. } // namespace cuda
  // CUDA_HELP is the explanatory text appended to the error messages raised by
  // the CUDAHooksInterface stubs in this header. It tells the user that the
  // CUDA shared library was not loaded and suggests the usual linker-related
  // cause; the wording of that hint is platform specific (MSVC vs. everything
  // else).
  //
  // The variable is declared `inline constexpr` (C++17): a plain
  // namespace-scope `constexpr` variable has internal linkage, which gives
  // every translation unit including this header its own copy. An inline
  // variable is a single entity across all translation units, which keeps the
  // inline functions that reference it consistent with the one-definition
  // rule.
  #ifdef _MSC_VER
  inline constexpr const char* CUDA_HELP =
      "PyTorch splits its backend into two shared libraries: a CPU library "
      "and a CUDA library; this error has occurred because you are trying "
      "to use some CUDA functionality, but the CUDA library has not been "
      "loaded by the dynamic linker for some reason. The CUDA library MUST "
      "be loaded, EVEN IF you don't directly use any symbols from the CUDA library! "
      "One common culprit is a lack of -INCLUDE:?warp_size@cuda@at@@YAHXZ "
      "in your link arguments; many dynamic linkers will delete dynamic library "
      "dependencies if you don't depend on any of their symbols. You can check "
      "if this has occurred by using link on your binary to see if there is a "
      "dependency on *_cuda.dll library.";
  #else
  inline constexpr const char* CUDA_HELP =
      "PyTorch splits its backend into two shared libraries: a CPU library "
      "and a CUDA library; this error has occurred because you are trying "
      "to use some CUDA functionality, but the CUDA library has not been "
      "loaded by the dynamic linker for some reason. The CUDA library MUST "
      "be loaded, EVEN IF you don't directly use any symbols from the CUDA library! "
      "One common culprit is a lack of -Wl,--no-as-needed in your link arguments; many "
      "dynamic linkers will delete dynamic library dependencies if you don't "
      "depend on any of their symbols. You can check if this has occurred by "
      "using ldd on your binary to see if there is a dependency on *_cuda.so "
      "library.";
  #endif
  38. // The CUDAHooksInterface is an omnibus interface for any CUDA functionality
  39. // which we may want to call into from CPU code (and thus must be dynamically
  40. // dispatched, to allow for separate compilation of CUDA code). How do I
  41. // decide if a function should live in this class? There are two tests:
  42. //
  43. // 1. Does the *implementation* of this function require linking against
  44. // CUDA libraries?
  45. //
  46. // 2. Is this function *called* from non-CUDA ATen code?
  47. //
  48. // (2) should filter out many ostensible use-cases, since many times a CUDA
  49. // function provided by ATen is only really ever used by actual CUDA code.
  50. //
  51. // TODO: Consider putting the stub definitions in another class, so that one
  52. // never forgets to implement each virtual function in the real implementation
  53. // in CUDAHooks. This probably doesn't buy us much though.
  54. struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
  55. // This should never actually be implemented, but it is used to
  56. // squelch -Werror=non-virtual-dtor
  57. ~CUDAHooksInterface() override = default;
  58. // Initialize THCState and, transitively, the CUDA state
  59. void init() const override {
  60. TORCH_CHECK(false, "Cannot initialize CUDA without ATen_cuda library. ", CUDA_HELP);
  61. }
  62. const Generator& getDefaultGenerator(
  63. [[maybe_unused]] DeviceIndex device_index = -1) const override {
  64. TORCH_CHECK(
  65. false,
  66. "Cannot get default CUDA generator without ATen_cuda library. ",
  67. CUDA_HELP);
  68. }
  69. Generator getNewGenerator(
  70. [[maybe_unused]] DeviceIndex device_index = -1) const override {
  71. TORCH_CHECK(
  72. false,
  73. "Cannot get CUDA generator without ATen_cuda library. ",
  74. CUDA_HELP);
  75. }
  76. Device getDeviceFromPtr(void* /*data*/) const override {
  77. TORCH_CHECK(false, "Cannot get device of pointer on CUDA without ATen_cuda library. ", CUDA_HELP);
  78. }
  79. bool isPinnedPtr(const void* /*data*/) const override {
  80. return false;
  81. }
  82. virtual bool hasCUDA() const {
  83. return false;
  84. }
  85. virtual bool hasCUDART() const {
  86. return false;
  87. }
  88. virtual bool hasMAGMA() const {
  89. return false;
  90. }
  91. virtual bool hasCuDNN() const {
  92. return false;
  93. }
  94. virtual bool hasCuSOLVER() const {
  95. return false;
  96. }
  97. virtual bool hasCuBLASLt() const {
  98. return false;
  99. }
  100. virtual bool hasROCM() const {
  101. return false;
  102. }
  103. virtual bool hasCKSDPA() const {
  104. return false;
  105. }
  106. virtual bool hasCKGEMM() const {
  107. return false;
  108. }
  109. virtual const at::cuda::NVRTC& nvrtc() const {
  110. TORCH_CHECK(false, "NVRTC requires CUDA. ", CUDA_HELP);
  111. }
  112. bool hasPrimaryContext(DeviceIndex device_index) const override {
  113. TORCH_CHECK(false, "Cannot call hasPrimaryContext(", device_index, ") without ATen_cuda library. ", CUDA_HELP);
  114. }
  115. virtual DeviceIndex current_device() const {
  116. return -1;
  117. }
  118. Allocator* getPinnedMemoryAllocator() const override {
  119. TORCH_CHECK(false, "Pinned memory requires CUDA. ", CUDA_HELP);
  120. }
  121. virtual Allocator* getCUDADeviceAllocator() const {
  122. TORCH_CHECK(false, "CUDADeviceAllocator requires CUDA. ", CUDA_HELP);
  123. }
  124. virtual bool compiledWithCuDNN() const {
  125. return false;
  126. }
  127. virtual bool compiledWithMIOpen() const {
  128. return false;
  129. }
  130. virtual bool supportsDilatedConvolutionWithCuDNN() const {
  131. return false;
  132. }
  133. virtual bool supportsDepthwiseConvolutionWithCuDNN() const {
  134. return false;
  135. }
  136. virtual bool supportsBFloat16ConvolutionWithCuDNNv8() const {
  137. return false;
  138. }
  139. virtual bool supportsBFloat16RNNWithCuDNN() const {
  140. return false;
  141. }
  142. virtual long versionCuDNN() const {
  143. TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
  144. }
  145. virtual long versionRuntimeCuDNN() const {
  146. TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
  147. }
  148. virtual long versionCuDNNFrontend() const {
  149. TORCH_CHECK(false, "Cannot query cuDNN Frontend version without ATen_cuda library. ", CUDA_HELP);
  150. }
  151. virtual long versionMIOpen() const {
  152. TORCH_CHECK(false, "Cannot query MIOpen version without ATen_cuda library. ", CUDA_HELP);
  153. }
  154. virtual long versionHipBLASLt() const {
  155. TORCH_CHECK(false, "Cannot query HipBLASLt version without ATen_cuda library. ", CUDA_HELP);
  156. }
  157. virtual long versionCUDART() const {
  158. TORCH_CHECK(false, "Cannot query CUDART version without ATen_cuda library. ", CUDA_HELP);
  159. }
  160. virtual std::string showConfig() const {
  161. TORCH_CHECK(false, "Cannot query detailed CUDA version without ATen_cuda library. ", CUDA_HELP);
  162. }
  163. virtual double batchnormMinEpsilonCuDNN() const {
  164. TORCH_CHECK(false,
  165. "Cannot query batchnormMinEpsilonCuDNN() without ATen_cuda library. ", CUDA_HELP);
  166. }
  167. virtual int64_t cuFFTGetPlanCacheMaxSize(DeviceIndex /*device_index*/) const {
  168. TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
  169. }
  170. virtual void cuFFTSetPlanCacheMaxSize(DeviceIndex /*device_index*/, int64_t /*max_size*/) const {
  171. TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
  172. }
  173. virtual int64_t cuFFTGetPlanCacheSize(DeviceIndex /*device_index*/) const {
  174. TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
  175. }
  176. virtual void cuFFTClearPlanCache(DeviceIndex /*device_index*/) const {
  177. TORCH_CHECK(false, "Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
  178. }
  179. virtual int getNumGPUs() const {
  180. return 0;
  181. }
  182. #ifdef USE_ROCM
  183. virtual bool isGPUArch(const std::vector<std::string>& /*archs*/, DeviceIndex = -1 /*device_index*/) const {
  184. TORCH_CHECK(false, "Cannot check GPU arch without ATen_cuda library. ", CUDA_HELP);
  185. }
  186. virtual const std::vector<std::string>& getHipblasltPreferredArchs() const {
  187. static const std::vector<std::string> empty;
  188. TORCH_CHECK(false, "Cannot get hipBLASLt preferred archs without ATen_cuda library. ", CUDA_HELP);
  189. return empty;
  190. }
  191. virtual const std::vector<std::string>& getHipblasltSupportedArchs() const {
  192. static const std::vector<std::string> empty;
  193. TORCH_CHECK(false, "Cannot get hipBLASLt supported archs without ATen_cuda library. ", CUDA_HELP);
  194. return empty;
  195. }
  196. #endif
  197. virtual void deviceSynchronize(DeviceIndex /*device_index*/) const {
  198. TORCH_CHECK(false, "Cannot synchronize CUDA device without ATen_cuda library. ", CUDA_HELP);
  199. }
  200. };
  201. // NB: dummy argument to suppress "ISO C++11 requires at least one argument
  202. // for the "..." in a variadic macro"
  203. struct TORCH_API CUDAHooksArgs {};
  204. TORCH_DECLARE_REGISTRY(CUDAHooksRegistry, CUDAHooksInterface, CUDAHooksArgs);
  205. #define REGISTER_CUDA_HOOKS(clsname) \
  206. C10_REGISTER_CLASS(CUDAHooksRegistry, clsname, clsname)
  207. namespace detail {
  208. TORCH_API const CUDAHooksInterface& getCUDAHooks();
  209. } // namespace detail
  210. } // namespace at
  211. #else
  212. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  213. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)