// DispatchStub.h
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <c10/core/DeviceType.h>
  4. #include <c10/macros/Macros.h>
  5. #include <atomic>
  6. #include <utility>
  7. #include <variant>
// Implements instruction-set-specific function dispatch.
//
// Kernels that may make use of specialized instruction sets (e.g. AVX2) are
// compiled multiple times with different compiler flags (e.g. -mavx2). A
// DispatchStub contains a table of function pointers for a kernel. At runtime,
// the fastest available kernel is chosen based on the features reported by
// cpuinfo.
//
// Example:
//
// In native/MyKernel.h:
//   using fn_type = void(*)(const Tensor& x);
//   DECLARE_DISPATCH(fn_type, stub)
//
// In native/MyKernel.cpp:
//   DEFINE_DISPATCH(stub);
//
// In native/cpu/MyKernel.cpp:
//   namespace {
//   // use anonymous namespace so that different cpu versions won't conflict
//   void kernel(const Tensor& x) { ... }
//   }
//   REGISTER_DISPATCH(stub, &kernel);
//
// To call:
//   stub(kCPU, tensor);
//
// TODO: CPU instruction set selection should be folded into whatever
// the main dispatch mechanism is.
//
// Supported device types for registration:
// - CPU: Central Processing Unit
// - CUDA: NVIDIA GPUs
// - HIP: AMD GPUs
// - MPS: Apple Silicon GPUs (Metal Performance Shaders)
// - MTIA: Meta Training and Inference Devices
// - XPU: Intel GPUs (only when built with USE_XPU)
// - HPU: Reserved for HPU (Intel Gaudi) device types
// - PrivateUse1: Reserved for private/custom device types
//
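// To add a device to the supported list, add a new *_dispatch_ptr member to
// the DispatchStubImpl struct in this header, add the matching
// set_*_dispatch_ptr() setter to DispatchStub (plus a Register*Dispatch helper
// and REGISTER_*_DISPATCH macro if kernels should self-register), and update
// the device switch in get_call_ptr(). You will also need to update the
// inlined device list used by `is_device_supported()`.
//
//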
// Silence clang's -Wundefined-var-template for the rest of this header: the
// static members DispatchStub<...>::DEFAULT, AVX512, AVX2, etc. are only
// declared here; their definitions come from REGISTER_*_DISPATCH expansions
// elsewhere. The matching C10_CLANG_DIAGNOSTIC_POP() is at the end of the file.
C10_CLANG_DIAGNOSTIC_PUSH()
C10_CLANG_DIAGNOSTIC_IGNORE("-Wundefined-var-template")
  56. namespace at::native {
  57. enum class CPUCapability {
  58. DEFAULT = 0,
  59. #if defined(HAVE_VSX_CPU_DEFINITION)
  60. VSX = 1,
  61. #elif defined(HAVE_ZVECTOR_CPU_DEFINITION)
  62. ZVECTOR = 1,
  63. #elif defined(HAVE_SVE256_CPU_DEFINITION) && defined(HAVE_ARM_BF16_CPU_DEFINITION)
  64. SVE256 = 1,
  65. #else
  66. AVX2 = 1,
  67. AVX512 = 2,
  68. #endif
  69. NUM_OPTIONS
  70. };
  71. // Enum for error types
  72. enum class ErrorType {
  73. MissingDeviceKernel,
  74. DeviceNotSupported
  75. };
  76. // Alias for the return type using std::variant
  77. using DispatchResult = std::variant<void*, ErrorType>;
  78. CPUCapability get_cpu_capability();
  79. template <typename FnPtr, typename T>
  80. struct DispatchStub;
  81. /**
  82. * The sole purpose of this class is to outline methods that don't need to be
  83. * specialized or otherwise inlined and duplicated (by the compiler due to
  84. * template expansion), since it causes size bloat if there are a significant
  85. * number of specialization of the DispatchStub<> class.
  86. */
  87. struct TORCH_API DispatchStubImpl {
  88. // The DispatchStubImpl::try_get_call_ptr() method is used to get the call
  89. // pointer for a given device type. If the call pointer is not found,
  90. // DispatchStubImpl::try_get_call_ptr() returns an ErrorType.
  91. // The main difference between try_get_call_ptr() and get_call_ptr() is that
  92. // try_get_call_ptr() will return the ErrorType and not raise an exception.
  93. DispatchResult try_get_call_ptr(
  94. c10::DeviceType device_type
  95. , void *DEFAULT
  96. #ifdef HAVE_AVX512_CPU_DEFINITION
  97. , void *AVX512
  98. #endif
  99. #ifdef HAVE_AVX2_CPU_DEFINITION
  100. , void *AVX2
  101. #endif
  102. #ifdef HAVE_VSX_CPU_DEFINITION
  103. , void *VSX
  104. #endif
  105. #ifdef HAVE_ZVECTOR_CPU_DEFINITION
  106. , void *ZVECTOR
  107. #endif
  108. #ifdef HAVE_SVE256_CPU_DEFINITION
  109. , void *SVE256
  110. #endif
  111. );
  112. // Analogous to try_get_call_ptr(), but it will return the ErrorType and not
  113. // raise an exception.
  114. DispatchResult try_choose_cpu_impl(
  115. void *DEFAULT
  116. #ifdef HAVE_AVX512_CPU_DEFINITION
  117. , void *AVX512
  118. #endif
  119. #ifdef HAVE_AVX2_CPU_DEFINITION
  120. , void *AVX2
  121. #endif
  122. #ifdef HAVE_VSX_CPU_DEFINITION
  123. , void *VSX
  124. #endif
  125. #ifdef HAVE_ZVECTOR_CPU_DEFINITION
  126. , void *ZVECTOR
  127. #endif
  128. #ifdef HAVE_SVE256_CPU_DEFINITION
  129. , void *SVE256
  130. #endif
  131. );
  132. void* get_call_ptr(
  133. c10::DeviceType device_type
  134. , void *DEFAULT
  135. #ifdef HAVE_AVX512_CPU_DEFINITION
  136. , void *AVX512
  137. #endif
  138. #ifdef HAVE_AVX2_CPU_DEFINITION
  139. , void *AVX2
  140. #endif
  141. #ifdef HAVE_VSX_CPU_DEFINITION
  142. , void *VSX
  143. #endif
  144. #ifdef HAVE_ZVECTOR_CPU_DEFINITION
  145. , void *ZVECTOR
  146. #endif
  147. #ifdef HAVE_SVE256_CPU_DEFINITION
  148. , void *SVE256
  149. #endif
  150. );
  151. /**
  152. * The CPU Dispatch actual method is chosen in decreasing order of preference by
  153. * DispatchStubImpl::choose_cpu_impl() in case none is found by
  154. * DispatchStubImpl::get_call_ptr() in cpu_dispatch_ptr.
  155. */
  156. void* choose_cpu_impl(
  157. void *DEFAULT
  158. #ifdef HAVE_AVX512_CPU_DEFINITION
  159. , void *AVX512
  160. #endif
  161. #ifdef HAVE_AVX2_CPU_DEFINITION
  162. , void *AVX2
  163. #endif
  164. #ifdef HAVE_VSX_CPU_DEFINITION
  165. , void *VSX
  166. #endif
  167. #ifdef HAVE_ZVECTOR_CPU_DEFINITION
  168. , void *ZVECTOR
  169. #endif
  170. #ifdef HAVE_SVE256_CPU_DEFINITION
  171. , void *SVE256
  172. #endif
  173. );
  174. // Fixing dispatch error in Windows debug builds.
  175. // See https://github.com/pytorch/pytorch/issues/22681 for more details.
  176. #if defined(_MSC_VER) && defined(_DEBUG)
  177. std::atomic<void*> cpu_dispatch_ptr;
  178. void* cuda_dispatch_ptr;
  179. void* hip_dispatch_ptr;
  180. void* mps_dispatch_ptr;
  181. void* mtia_dispatch_ptr;
  182. #if defined(USE_XPU)
  183. void* xpu_dispatch_ptr;
  184. #endif
  185. void* hpu_dispatch_ptr;
  186. void* privateuse1_dispatch_ptr;
  187. #else
  188. std::atomic<void*> cpu_dispatch_ptr{nullptr};
  189. void* cuda_dispatch_ptr = nullptr;
  190. void* hip_dispatch_ptr = nullptr;
  191. void* mps_dispatch_ptr = nullptr;
  192. void* mtia_dispatch_ptr = nullptr;
  193. #if defined(USE_XPU)
  194. void* xpu_dispatch_ptr = nullptr;
  195. #endif
  196. void* hpu_dispatch_ptr = nullptr;
  197. void* privateuse1_dispatch_ptr = nullptr;
  198. #endif
  199. };
  200. template <typename rT, typename T, typename... Args>
  201. struct DispatchStub<rT (*)(Args...), T> {
  202. using FnPtr = rT (*) (Args...);
  203. DispatchStub() = default;
  204. DispatchStub(const DispatchStub&) = delete;
  205. DispatchStub& operator=(const DispatchStub&) = delete;
  206. private:
  207. FnPtr get_call_ptr(const c10::DeviceType device_type) {
  208. return reinterpret_cast<FnPtr>(
  209. impl.get_call_ptr(device_type
  210. , reinterpret_cast<void*>(DEFAULT)
  211. #ifdef HAVE_AVX512_CPU_DEFINITION
  212. , reinterpret_cast<void*>(AVX512)
  213. #endif
  214. #ifdef HAVE_AVX2_CPU_DEFINITION
  215. , reinterpret_cast<void*>(AVX2)
  216. #endif
  217. #ifdef HAVE_VSX_CPU_DEFINITION
  218. , reinterpret_cast<void*>(VSX)
  219. #endif
  220. #ifdef HAVE_ZVECTOR_CPU_DEFINITION
  221. , reinterpret_cast<void*>(ZVECTOR)
  222. #endif
  223. #ifdef HAVE_SVE256_CPU_DEFINITION
  224. , reinterpret_cast<void*>(SVE256)
  225. #endif
  226. )
  227. );
  228. }
  229. public:
  230. template <typename... ArgTypes>
  231. rT operator()(c10::DeviceType device_type, ArgTypes&&... args) {
  232. FnPtr call_ptr = get_call_ptr(device_type);
  233. return (*call_ptr)(std::forward<ArgTypes>(args)...);
  234. }
  235. void set_cuda_dispatch_ptr(FnPtr fn_ptr) {
  236. impl.cuda_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  237. }
  238. #if defined(USE_XPU)
  239. void set_xpu_dispatch_ptr(FnPtr fn_ptr){
  240. impl.xpu_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  241. }
  242. #endif
  243. void set_hpu_dispatch_ptr(FnPtr fn_ptr) {
  244. impl.hpu_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  245. }
  246. void set_hip_dispatch_ptr(FnPtr fn_ptr) {
  247. impl.hip_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  248. }
  249. void set_mps_dispatch_ptr(FnPtr fn_ptr) {
  250. impl.mps_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  251. }
  252. void set_mtia_dispatch_ptr(FnPtr fn_ptr) {
  253. impl.mtia_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  254. }
  255. void set_privateuse1_dispatch_ptr(FnPtr fn_ptr) {
  256. impl.privateuse1_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  257. }
  258. // Returns true if the dispatcher has a kernel registered for this device
  259. // type.
  260. bool is_device_supported(const c10::DeviceType device_type) {
  261. auto result = impl.try_get_call_ptr(device_type
  262. , reinterpret_cast<void*>(DEFAULT)
  263. #ifdef HAVE_AVX512_CPU_DEFINITION
  264. , reinterpret_cast<void*>(AVX512)
  265. #endif
  266. #ifdef HAVE_AVX2_CPU_DEFINITION
  267. , reinterpret_cast<void*>(AVX2)
  268. #endif
  269. #ifdef HAVE_VSX_CPU_DEFINITION
  270. , reinterpret_cast<void*>(VSX)
  271. #endif
  272. #ifdef HAVE_ZVECTOR_CPU_DEFINITION
  273. , reinterpret_cast<void*>(ZVECTOR)
  274. #endif
  275. #ifdef HAVE_SVE256_CPU_DEFINITION
  276. , reinterpret_cast<void*>(SVE256)
  277. #endif
  278. );
  279. if (std::holds_alternative<ErrorType>(result)){
  280. return false;
  281. }
  282. return true;
  283. }
  284. static TORCH_API FnPtr DEFAULT;
  285. #ifdef HAVE_AVX512_CPU_DEFINITION
  286. static TORCH_API FnPtr AVX512;
  287. #endif
  288. #ifdef HAVE_AVX2_CPU_DEFINITION
  289. static TORCH_API FnPtr AVX2;
  290. #endif
  291. #ifdef HAVE_VSX_CPU_DEFINITION
  292. static TORCH_API FnPtr VSX;
  293. #endif
  294. #ifdef HAVE_ZVECTOR_CPU_DEFINITION
  295. static TORCH_API FnPtr ZVECTOR;
  296. #endif
  297. #ifdef HAVE_SVE256_CPU_DEFINITION
  298. static TORCH_API FnPtr SVE256;
  299. #endif
  300. private:
  301. DispatchStubImpl impl;
  302. };
  303. namespace {
  304. template <typename DispatchStub>
  305. struct RegisterCUDADispatch {
  306. RegisterCUDADispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
  307. stub.set_cuda_dispatch_ptr(value);
  308. }
  309. };
  310. template <typename DispatchStub>
  311. struct RegisterXPUDispatch {
  312. RegisterXPUDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value){
  313. stub.set_xpu_dispatch_ptr(value);
  314. }
  315. };
  316. template <typename DispatchStub>
  317. struct RegisterHPUDispatch {
  318. RegisterHPUDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value){
  319. stub.set_hpu_dispatch_ptr(value);
  320. }
  321. };
  322. template <typename DispatchStub>
  323. struct RegisterMPSDispatch {
  324. RegisterMPSDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
  325. stub.set_mps_dispatch_ptr(value);
  326. }
  327. };
  328. template <typename DispatchStub>
  329. struct RegisterHIPDispatch {
  330. RegisterHIPDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
  331. // TODO: make this point at hip_dispatch_ptr
  332. stub.set_cuda_dispatch_ptr(value);
  333. }
  334. };
  335. template <typename DispatchStub>
  336. struct RegisterMTIADispatch {
  337. RegisterMTIADispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
  338. stub.set_mtia_dispatch_ptr(value);
  339. }
  340. };
  341. template <typename DispatchStub>
  342. struct RegisterPRIVATEUSE1Dispatch {
  343. RegisterPRIVATEUSE1Dispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
  344. stub.set_privateuse1_dispatch_ptr(value);
  345. }
  346. };
  347. } // anonymous namespace
  348. // Compiler will complain if you put things like std::tuple<Tensor, Tensor> in
  349. // the `fn` argument of DECLARE_DISPATCH. Some possible workarounds, e.g.,
  350. // adding parentheses and using helper struct to get rid of the parentheses, do
  351. // not work with MSVC. So do a `using`-declaration if you need to pass in such
  352. // `fn`, e.g., grid_sampler_2d_backward_cpu_kernel in GridSampleKernel.h.
  353. #define DECLARE_DISPATCH(fn, name) \
  354. struct name##_DECLARE_DISPATCH_type : DispatchStub<fn, name##_DECLARE_DISPATCH_type> { \
  355. name##_DECLARE_DISPATCH_type() = default; \
  356. name##_DECLARE_DISPATCH_type(const name##_DECLARE_DISPATCH_type&) = delete; \
  357. name##_DECLARE_DISPATCH_type& operator=(const name##_DECLARE_DISPATCH_type&) = delete; \
  358. name##_DECLARE_DISPATCH_type(name##_DECLARE_DISPATCH_type&&) = delete; \
  359. name##_DECLARE_DISPATCH_type& operator=(name##_DECLARE_DISPATCH_type&&) = delete; \
  360. ~name##_DECLARE_DISPATCH_type() = default; \
  361. }; \
  362. extern TORCH_API struct name##_DECLARE_DISPATCH_type name;
  363. #define DEFINE_DISPATCH(name) struct name##_DECLARE_DISPATCH_type name
  364. #define REGISTER_ARCH_DISPATCH(name, arch, fn) \
  365. template <> name##_DECLARE_DISPATCH_type::FnPtr TORCH_API DispatchStub<name##_DECLARE_DISPATCH_type::FnPtr, struct name##_DECLARE_DISPATCH_type>::arch = fn;
  366. #ifdef HAVE_AVX512_CPU_DEFINITION
  367. #define REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX512, fn)
  368. #else
  369. #define REGISTER_AVX512_DISPATCH(name, fn)
  370. #endif
  371. #ifdef HAVE_AVX2_CPU_DEFINITION
  372. #define REGISTER_AVX2_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX2, fn)
  373. #else
  374. #define REGISTER_AVX2_DISPATCH(name, fn)
  375. #endif
  376. #ifdef HAVE_VSX_CPU_DEFINITION
  377. #define REGISTER_VSX_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, VSX, fn)
  378. #else
  379. #define REGISTER_VSX_DISPATCH(name, fn)
  380. #endif
  381. #ifdef HAVE_ZVECTOR_CPU_DEFINITION
  382. #define REGISTER_ZVECTOR_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, ZVECTOR, fn)
  383. #else
  384. #define REGISTER_ZVECTOR_DISPATCH(name, fn)
  385. #endif
  386. #ifdef HAVE_SVE256_CPU_DEFINITION
  387. #define REGISTER_SVE256_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, SVE256, fn)
  388. #else
  389. #define REGISTER_SVE256_DISPATCH(name, fn)
  390. #endif
  391. // Macro to register the same kernel for all CPU arch types. This is useful
  392. // if a kernel does not benefit from being recompiled across different arch types.
  393. #define REGISTER_ALL_CPU_DISPATCH(name, fn) \
  394. REGISTER_ARCH_DISPATCH(name, DEFAULT, fn) \
  395. REGISTER_AVX512_DISPATCH(name, fn) \
  396. REGISTER_AVX2_DISPATCH(name, fn) \
  397. REGISTER_VSX_DISPATCH(name, fn) \
  398. REGISTER_ZVECTOR_DISPATCH(name, fn) \
  399. REGISTER_SVE256_DISPATCH(name, fn)
  400. #define REGISTER_NO_CPU_DISPATCH(name) \
  401. REGISTER_ALL_CPU_DISPATCH(name, nullptr)
  402. #define REGISTER_CUDA_DISPATCH(name, fn) \
  403. static RegisterCUDADispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
  404. #define REGISTER_XPU_DISPATCH(name, fn) \
  405. static RegisterXPUDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
  406. #define REGISTER_HPU_DISPATCH(name, fn) \
  407. static RegisterHPUDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
  408. #define REGISTER_HIP_DISPATCH(name, fn) \
  409. static RegisterHIPDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
  410. #define REGISTER_MPS_DISPATCH(name, fn) \
  411. static RegisterMPSDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
  412. #define REGISTER_MTIA_DISPATCH(name, fn) \
  413. static RegisterMTIADispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
  414. #define REGISTER_PRIVATEUSE1_DISPATCH(name, fn) \
  415. static RegisterPRIVATEUSE1Dispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
  416. // NB: This macro must be used in an actual 'cu' file; if you try using
  417. // it from a 'cpp' file it will not work!
  418. #if defined(__CUDACC__)
  419. #define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
  420. #elif defined(__HIPCC__)
  421. // TODO: cut this over to HIP dispatch once we stop pretending that CUDA
  422. // is HIP in the PyTorch HIPify build.
  423. #define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
  424. // #define REGISTER_DISPATCH(name, fn) REGISTER_HIP_DISPATCH(name, fn)
  425. #elif defined(__OBJC__) && defined(USE_MPS)
  426. // NB: this macro must be used from a 'mm' file in order to dispatch a MPS kernel
  427. #define REGISTER_DISPATCH(name, fn) REGISTER_MPS_DISPATCH(name, fn)
  428. #elif defined(CPU_CAPABILITY)
  429. // REGISTER_DISPATCH now dispatches an AVX512 kernel to nullptr but registers other dispatches.
  430. // ALSO_REGISTER_AVX512_DISPATCH should be used for ensuring AVX512 dispatch, among others.
  431. // ALSO_REGISTER_SVE256_DISPATCH should be used for ensuring SVE256 dispatch, among others.
  432. #ifdef CPU_CAPABILITY_AVX512
  433. #define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, ((void*)(fn) ? nullptr : nullptr))
  434. #else
  435. #define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
  436. #endif
  437. #define ALSO_REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
  438. #define ALSO_REGISTER_SVE256_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
  439. #endif
  440. } // namespace at::native
  441. C10_CLANG_DIAGNOSTIC_POP()
  442. #else
  443. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  444. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)