// Dispatch.h
#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
#pragma once
#include <c10/core/ScalarType.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/Half.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/complex.h>
#include <torch/headeronly/core/Dispatch.h>

#include <climits>
#include <cstdint>
#include <string>

#ifdef __CUDACC__
#include <cuda.h> // For CUDA_VERSION
#endif
  13. #ifdef TEMPLATE_SELECTIVE_BUILD
  14. #include <ATen/selected_mobile_ops.h>
  15. #else
  16. namespace at {
  17. /**
  18. * The method should_include_kernel_dtype() returns true/false
  19. * based on whether the switching code for a specific dtype should be
  20. * included based on build time constants generated from tracing model
  21. * execution. This method will be implemented via code-generation and
  22. * included in this file when code-gen is ready.
  23. */
  24. inline constexpr bool should_include_kernel_dtype(
  25. const char* /*kernel_tag_str*/,
  26. at::ScalarType /*scalar_type*/
  27. ) {
  28. return true;
  29. }
  30. } // namespace at
  31. #endif
  32. /**
  33. * In the Facebook internal build (using BUCK), this macro is enabled by
  34. * passing in -c pt.enable_record_kernel_dtype=1 when building the tracer
  35. * binary.
  36. */
  37. #if defined ENABLE_RECORD_KERNEL_FUNCTION_DTYPE
  38. namespace at::detail {
  39. TORCH_API void record_kernel_function_dtype(std::string name);
  40. } // namespace at::detail
  41. #define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type) \
  42. at::detail::record_kernel_function_dtype( \
  43. std::string(NAME) + "$" + toString(enum_type));
  44. #else
  45. #define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type)
  46. #endif
  47. #define AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type) \
  48. do { \
  49. if constexpr (!at::should_include_kernel_dtype( \
  50. at_dispatch_name, enum_type)) { \
  51. TORCH_CHECK( \
  52. false, \
  53. "dtype '", \
  54. toString(enum_type), \
  55. "' not selected for kernel tag ", \
  56. at_dispatch_name); \
  57. } \
  58. } while (0)
  59. #define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...) \
  60. THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL( \
  61. AT_PRIVATE_CHECK_SELECTIVE_BUILD, enum_type, HINT, __VA_ARGS__)
  62. #define AT_DISPATCH_CASE(enum_type, ...) \
  63. AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__)
  64. #define AT_DISPATCH_CASE_QINT(enum_type, scalar_type, ...) \
  65. case enum_type: { \
  66. AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \
  67. using scalar_t = scalar_type; \
  68. using underlying_t [[maybe_unused]] = typename scalar_t::underlying; \
  69. [[maybe_unused]] const auto& SCALAR_TYPE = enum_type; \
  70. [[maybe_unused]] const auto& UNDERLYING_TYPE = toUnderlying(enum_type); \
  71. return __VA_ARGS__(); \
  72. }
  73. #define AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
  74. enum_type, scalar_type, bitwidth, qmin, qmax, ...) \
  75. case enum_type: { \
  76. AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \
  77. using scalar_t = scalar_type; \
  78. using underlying_t [[maybe_unused]] = typename scalar_t::underlying; \
  79. [[maybe_unused]] const auto& SCALAR_TYPE = enum_type; \
  80. [[maybe_unused]] const auto& UNDERLYING_TYPE = toUnderlying(enum_type); \
  81. [[maybe_unused]] int bit_width = bitwidth; \
  82. [[maybe_unused]] int64_t quant_min = qmin; \
  83. [[maybe_unused]] int64_t quant_max = qmax; \
  84. return __VA_ARGS__(); \
  85. }
// The AT_DISPATCH_* family of macros provides the ability to
// conveniently generate specializations of a kernel over all of the
// dtypes we care about in PyTorch. We call it "dispatch" because
// we are "dispatching" to the correct, dtype-specific kernel.
//
// A standard usage looks like:
//
//   AT_DISPATCH_ALL_TYPES(self.scalar_type(), "op_name", [&] {
//     // Your code here, with 'scalar_t' now defined to
//     // be the dtype in question
//   });
//
// There are many variations of this macro, so it's important to
// understand exactly /which/ dtypes you want to get instantiated, as
// well as what the "default" set is.
//
// The default set of dtypes that are instantiated (e.g., by
// AT_DISPATCH_ALL_TYPES) are floating point types (float, double),
// and integral types (int32_t, int64_t, int16_t, int8_t, uint8_t),
// but NOT booleans (bool), half-precision floats (Half) or
// complex numbers (c10::complex<float>, c10::complex<double>).
// This "cut" is somewhat historical (the default types are the
// ones that TH historically supported), but it also reflects the
// fact that the non-default types are "poorly" behaved (booleans
// are NOT integers mod 2, half precision operations essentially
// don't exist on CPU, complex numbers are an experimental application).
//
// Here are the questions you should generally ask to decide which
// dispatch you want:
//
// 1. Is this an integral or floating point specific operation?
//    (If so, you'll want one of the FLOATING or INTEGRAL macros.)
//
// 2. Should half be supported? (If you're on CPU, the answer is almost
//    definitely no. If you do want support, use one of the AND_HALF
//    macros.)
//
// Much rarer situations:
//
// 3. Should bool be supported? (You often have to write your kernel
//    differently if arithmetic operations are involved.) If so,
//    use AT_DISPATCH_ALL_TYPES_AND along with ScalarType::Bool.
//
// 4. Should complex be supported? The answer is almost always no,
//    unless you are working on "generic" code that should work on
//    all dtypes.
//
// Parameters:
// -----------
//
// 1. The NAME argument is a "tag" that is used to trace and then
//    conditionally compile fragments of the case statements such
//    that the kernel functions are specialized only for the dtypes
//    that are needed. The NAME parameter *must* be a build-time
//    constant const char*, such as a string literal (it can't be a
//    std::string, etc.).
//
//    Please ensure that the NAME is unique for every implementation,
//    or you run the risk of over-including code for the kernel
//    functions. There is no risk of missing out on any code: a shared
//    NAME can only cause extra dtype cases to be kept (a Type-2 error),
//    never needed ones to be dropped (a Type-1 error).
//
// Switch-like syntax:
// -------------------
// There is also a switch-case like syntax, which is useful if a kernel
// needs to be specialized for particular scalar types:
//
//   AT_DISPATCH_SWITCH(self.scalar_type(), "op_name",
//       AT_DISPATCH_CASE_INTEGRAL_TYPES([&] {
//         op_integral<scalar_t>(iter);
//       })
//       AT_DISPATCH_CASE_FLOATING_TYPES([&] {
//         op_floating<scalar_t>(iter);
//       })
//       AT_DISPATCH_CASE(kBool, [&] {
//         op_bool(iter);
//       })
//   );
//
// For each AT_DISPATCH_FOO macro, there is a corresponding
// AT_DISPATCH_CASE_FOO macro which can be used inside of an
// AT_DISPATCH_SWITCH block.
  167. // NB: the the_type variable is not used, but we have kept it for
  168. // backwards compatibility. It's probably not used by anyone though;
  169. // but we're just being safe (and it doesn't hurt.) Note we must
  170. // use it to shut up warnings about unused store.
  171. #define AT_DISPATCH_SWITCH(TYPE, NAME, ...) \
  172. THO_DISPATCH_SWITCH_TMPL( \
  173. RECORD_KERNEL_FUNCTION_DTYPE, \
  174. TORCH_CHECK_NOT_IMPLEMENTED, \
  175. TYPE, \
  176. NAME, \
  177. __VA_ARGS__)
  178. #define AT_DISPATCH_CASE_FLOATING_TYPES(...) \
  179. AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
  180. AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)
  181. #define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
  182. AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
  183. #define AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(...) \
  184. AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
  185. AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
  186. AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
  187. #define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \
  188. AT_DISPATCH_SWITCH( \
  189. TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(__VA_ARGS__))
  190. #define AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(...) \
  191. AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
  192. AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
  193. #define AT_DISPATCH_REDUCED_FLOATING_TYPES(TYPE, NAME, ...) \
  194. AT_DISPATCH_SWITCH( \
  195. TYPE, NAME, AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__))
  196. #define AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE, ...) \
  197. AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
  198. AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
  199. #define AT_DISPATCH_FLOATING_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
  200. AT_DISPATCH_SWITCH( \
  201. TYPE, \
  202. NAME, \
  203. AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE, __VA_ARGS__))
  204. #define AT_DISPATCH_CASE_FLOATING_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, ...) \
  205. AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
  206. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  207. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
  208. #define AT_DISPATCH_FLOATING_TYPES_AND2( \
  209. SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
  210. AT_DISPATCH_SWITCH( \
  211. TYPE, \
  212. NAME, \
  213. AT_DISPATCH_CASE_FLOATING_TYPES_AND2( \
  214. SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
  215. #define AT_DISPATCH_CASE_FLOATING_TYPES_AND3( \
  216. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
  217. AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
  218. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  219. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
  220. AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
  221. #define AT_DISPATCH_FLOATING_TYPES_AND3( \
  222. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
  223. AT_DISPATCH_SWITCH( \
  224. TYPE, \
  225. NAME, \
  226. AT_DISPATCH_CASE_FLOATING_TYPES_AND3( \
  227. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
  228. #define AT_DISPATCH_CASE_FLOATING_TYPES_AND4( \
  229. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
  230. AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
  231. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  232. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
  233. AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
  234. AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
  235. #define AT_DISPATCH_CASE_FLOATING_TYPES_AND5( \
  236. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \
  237. AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
  238. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  239. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
  240. AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
  241. AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
  242. AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__)
  243. #define AT_DISPATCH_FLOATING_TYPES_AND4( \
  244. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
  245. AT_DISPATCH_SWITCH( \
  246. TYPE, \
  247. NAME, \
  248. AT_DISPATCH_CASE_FLOATING_TYPES_AND4( \
  249. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
  250. #define AT_DISPATCH_FLOATING_TYPES_AND5( \
  251. SCALARTYPE1, \
  252. SCALARTYPE2, \
  253. SCALARTYPE3, \
  254. SCALARTYPE4, \
  255. SCALARTYPE5, \
  256. TYPE, \
  257. NAME, \
  258. ...) \
  259. AT_DISPATCH_SWITCH( \
  260. TYPE, \
  261. NAME, \
  262. AT_DISPATCH_CASE_FLOATING_TYPES_AND5( \
  263. SCALARTYPE1, \
  264. SCALARTYPE2, \
  265. SCALARTYPE3, \
  266. SCALARTYPE4, \
  267. SCALARTYPE5, \
  268. __VA_ARGS__))
  269. #define AT_DISPATCH_CASE_COMPLEX_TYPES(...) \
  270. AT_DISPATCH_CASE(at::ScalarType::ComplexDouble, __VA_ARGS__) \
  271. AT_DISPATCH_CASE(at::ScalarType::ComplexFloat, __VA_ARGS__)
  272. #define AT_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \
  273. AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__))
  274. #define AT_DISPATCH_CASE_COMPLEX_TYPES_AND(SCALARTYPE, ...) \
  275. AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__) \
  276. AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
  277. #define AT_DISPATCH_COMPLEX_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
  278. AT_DISPATCH_SWITCH( \
  279. TYPE, NAME, AT_DISPATCH_CASE_COMPLEX_TYPES_AND(SCALARTYPE, __VA_ARGS__))
  280. #define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(...) \
  281. AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
  282. AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)
  283. #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
  284. AT_DISPATCH_SWITCH( \
  285. TYPE, NAME, AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__))
  286. #define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1(SCALARTYPE, ...) \
  287. AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
  288. AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
  289. #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1( \
  290. SCALARTYPE, TYPE, NAME, ...) \
  291. AT_DISPATCH_SWITCH( \
  292. TYPE, \
  293. NAME, \
  294. AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1( \
  295. SCALARTYPE, __VA_ARGS__))
  296. #define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2( \
  297. SCALARTYPE1, SCALARTYPE2, ...) \
  298. AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
  299. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  300. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
  301. #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( \
  302. SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
  303. AT_DISPATCH_SWITCH( \
  304. TYPE, \
  305. NAME, \
  306. AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2( \
  307. SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
  308. #define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3( \
  309. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
  310. AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
  311. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  312. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
  313. AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
  314. #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3( \
  315. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
  316. AT_DISPATCH_SWITCH( \
  317. TYPE, \
  318. NAME, \
  319. AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3( \
  320. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
  321. #define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4( \
  322. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
  323. AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
  324. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  325. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
  326. AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
  327. AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
  328. #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4( \
  329. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
  330. AT_DISPATCH_SWITCH( \
  331. TYPE, \
  332. NAME, \
  333. AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4( \
  334. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
  335. #define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND5( \
  336. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \
  337. AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
  338. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  339. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
  340. AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
  341. AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
  342. AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__)
  343. #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND5( \
  344. SCALARTYPE1, \
  345. SCALARTYPE2, \
  346. SCALARTYPE3, \
  347. SCALARTYPE4, \
  348. SCALARTYPE5, \
  349. TYPE, \
  350. NAME, \
  351. ...) \
  352. AT_DISPATCH_SWITCH( \
  353. TYPE, \
  354. NAME, \
  355. AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND5( \
  356. SCALARTYPE1, \
  357. SCALARTYPE2, \
  358. SCALARTYPE3, \
  359. SCALARTYPE4, \
  360. SCALARTYPE5, \
  361. __VA_ARGS__))
  362. #define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND6( \
  363. SCALARTYPE1, \
  364. SCALARTYPE2, \
  365. SCALARTYPE3, \
  366. SCALARTYPE4, \
  367. SCALARTYPE5, \
  368. SCALARTYPE6, \
  369. ...) \
  370. AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \
  371. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  372. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
  373. AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
  374. AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
  375. AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
  376. AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__)
  377. #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND6( \
  378. SCALARTYPE1, \
  379. SCALARTYPE2, \
  380. SCALARTYPE3, \
  381. SCALARTYPE4, \
  382. SCALARTYPE5, \
  383. SCALARTYPE6, \
  384. TYPE, \
  385. NAME, \
  386. ...) \
  387. AT_DISPATCH_SWITCH( \
  388. TYPE, \
  389. NAME, \
  390. AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND6( \
  391. SCALARTYPE1, \
  392. SCALARTYPE2, \
  393. SCALARTYPE3, \
  394. SCALARTYPE4, \
  395. SCALARTYPE5, \
  396. SCALARTYPE6, \
  397. __VA_ARGS__))
  398. #define AT_DISPATCH_CASE_INTEGRAL_TYPES(...) \
  399. AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
  400. AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
  401. AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
  402. AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
  403. AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)
  404. #define AT_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
  405. AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
  406. #define AT_DISPATCH_CASE_INTEGRAL_TYPES_AND(SCALARTYPE, ...) \
  407. AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__) \
  408. AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
  409. #define AT_DISPATCH_INTEGRAL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
  410. AT_DISPATCH_SWITCH( \
  411. TYPE, \
  412. NAME, \
  413. AT_DISPATCH_CASE_INTEGRAL_TYPES_AND(SCALARTYPE, __VA_ARGS__))
  414. #define AT_DISPATCH_CASE_ALL_TYPES(...) \
  415. AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__) \
  416. AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)
  417. #define AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
  418. AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__))
  419. #define AT_DISPATCH_CASE_QINT_TYPES(...) \
  420. AT_DISPATCH_CASE_QINT(at::kQInt8, at::qint8, __VA_ARGS__) \
  421. AT_DISPATCH_CASE_QINT(at::kQUInt8, at::quint8, __VA_ARGS__) \
  422. AT_DISPATCH_CASE_QINT(at::kQInt32, at::qint32, __VA_ARGS__)
  423. #define AT_DISPATCH_QINT_TYPES(TYPE, NAME, ...) \
  424. AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_QINT_TYPES(__VA_ARGS__))
  425. #define AT_DISPATCH_CASE_QINT_TYPES_AND(SCALARTYPE, ...) \
  426. AT_DISPATCH_CASE_QINT_TYPES(__VA_ARGS__) \
  427. AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
  428. #define AT_DISPATCH_QINT_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
  429. AT_DISPATCH_SWITCH( \
  430. TYPE, NAME, AT_DISPATCH_CASE_QINT_TYPES_AND(SCALARTYPE, __VA_ARGS__))
  431. #define AT_DISPATCH_CASE_QINT_BYTE_TYPES(...) \
  432. AT_DISPATCH_CASE_QINT(at::kQInt8, at::qint8, __VA_ARGS__) \
  433. AT_DISPATCH_CASE_QINT(at::kQUInt8, at::quint8, __VA_ARGS__)
  434. #define AT_DISPATCH_QINT_BYTE_TYPES(TYPE, NAME, ...) \
  435. AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_QINT_BYTE_TYPES(__VA_ARGS__))
  436. #define AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES(...) \
  437. AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
  438. at::kQInt8, at::qint8, CHAR_BIT, SCHAR_MIN, SCHAR_MAX, __VA_ARGS__) \
  439. AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
  440. at::kQUInt8, at::quint8, CHAR_BIT, 0, UCHAR_MAX, __VA_ARGS__) \
  441. AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
  442. at::kQInt32, \
  443. at::qint32, \
  444. CHAR_BIT * sizeof(int), \
  445. INT_MIN, \
  446. INT_MAX, \
  447. __VA_ARGS__) \
  448. AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
  449. at::kQUInt4x2, at::quint4x2, 4, 0, 15, __VA_ARGS__) \
  450. AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
  451. at::kQUInt2x4, at::quint2x4, 2, 0, 3, __VA_ARGS__)
  452. #define AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(TYPE, NAME, ...) \
  453. AT_DISPATCH_SWITCH( \
  454. TYPE, NAME, AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES(__VA_ARGS__))
  455. #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(...) \
  456. AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \
  457. AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)
  458. #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX(TYPE, NAME, ...) \
  459. AT_DISPATCH_SWITCH( \
  460. TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__))
  461. #define AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE, ...) \
  462. AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \
  463. AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
  464. #define AT_DISPATCH_ALL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
  465. AT_DISPATCH_SWITCH( \
  466. TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE, __VA_ARGS__))
  467. #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, ...) \
  468. AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
  469. AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__)
  470. #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, TYPE, NAME, ...) \
  471. AT_DISPATCH_SWITCH( \
  472. TYPE, \
  473. NAME, \
  474. AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, __VA_ARGS__))
  475. #define AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, ...) \
  476. AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \
  477. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  478. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
  479. #define AT_DISPATCH_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
  480. AT_DISPATCH_SWITCH( \
  481. TYPE, \
  482. NAME, \
  483. AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
  484. #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2( \
  485. SCALARTYPE1, SCALARTYPE2, ...) \
  486. AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
  487. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  488. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)
  489. #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( \
  490. SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
  491. AT_DISPATCH_SWITCH( \
  492. TYPE, \
  493. NAME, \
  494. AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2( \
  495. SCALARTYPE1, SCALARTYPE2, __VA_ARGS__))
  496. #define AT_DISPATCH_CASE_ALL_TYPES_AND3( \
  497. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
  498. AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \
  499. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  500. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
  501. AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
  502. #define AT_DISPATCH_ALL_TYPES_AND3( \
  503. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
  504. AT_DISPATCH_SWITCH( \
  505. TYPE, \
  506. NAME, \
  507. AT_DISPATCH_CASE_ALL_TYPES_AND3( \
  508. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
  509. #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3( \
  510. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \
  511. AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
  512. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  513. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
  514. AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)
  515. #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( \
  516. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
  517. AT_DISPATCH_SWITCH( \
  518. TYPE, \
  519. NAME, \
  520. AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3( \
  521. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
  522. #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \
  523. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
  524. AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
  525. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  526. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
  527. AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
  528. AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
  529. #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
  530. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
  531. AT_DISPATCH_SWITCH( \
  532. TYPE, \
  533. NAME, \
  534. AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \
  535. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
  536. #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5( \
  537. SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \
  538. AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
  539. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  540. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
  541. AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
  542. AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
  543. AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__)
  544. #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5( \
  545. SCALARTYPE1, \
  546. SCALARTYPE2, \
  547. SCALARTYPE3, \
  548. SCALARTYPE4, \
  549. SCALARTYPE5, \
  550. TYPE, \
  551. NAME, \
  552. ...) \
  553. AT_DISPATCH_SWITCH( \
  554. TYPE, \
  555. NAME, \
  556. AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5( \
  557. SCALARTYPE1, \
  558. SCALARTYPE2, \
  559. SCALARTYPE3, \
  560. SCALARTYPE4, \
  561. SCALARTYPE5, \
  562. __VA_ARGS__))
  563. #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6( \
  564. SCALARTYPE1, \
  565. SCALARTYPE2, \
  566. SCALARTYPE3, \
  567. SCALARTYPE4, \
  568. SCALARTYPE5, \
  569. SCALARTYPE6, \
  570. ...) \
  571. AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
  572. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  573. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
  574. AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
  575. AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
  576. AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
  577. AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__)
  578. #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6( \
  579. SCALARTYPE1, \
  580. SCALARTYPE2, \
  581. SCALARTYPE3, \
  582. SCALARTYPE4, \
  583. SCALARTYPE5, \
  584. SCALARTYPE6, \
  585. TYPE, \
  586. NAME, \
  587. ...) \
  588. AT_DISPATCH_SWITCH( \
  589. TYPE, \
  590. NAME, \
  591. AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6( \
  592. SCALARTYPE1, \
  593. SCALARTYPE2, \
  594. SCALARTYPE3, \
  595. SCALARTYPE4, \
  596. SCALARTYPE5, \
  597. SCALARTYPE6, \
  598. __VA_ARGS__))
  599. #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND7( \
  600. SCALARTYPE1, \
  601. SCALARTYPE2, \
  602. SCALARTYPE3, \
  603. SCALARTYPE4, \
  604. SCALARTYPE5, \
  605. SCALARTYPE6, \
  606. SCALARTYPE7, \
  607. ...) \
  608. AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
  609. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  610. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
  611. AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
  612. AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
  613. AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
  614. AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__) \
  615. AT_DISPATCH_CASE(SCALARTYPE7, __VA_ARGS__)
  616. #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND7( \
  617. SCALARTYPE1, \
  618. SCALARTYPE2, \
  619. SCALARTYPE3, \
  620. SCALARTYPE4, \
  621. SCALARTYPE5, \
  622. SCALARTYPE6, \
  623. SCALARTYPE7, \
  624. TYPE, \
  625. NAME, \
  626. ...) \
  627. AT_DISPATCH_SWITCH( \
  628. TYPE, \
  629. NAME, \
  630. AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND7( \
  631. SCALARTYPE1, \
  632. SCALARTYPE2, \
  633. SCALARTYPE3, \
  634. SCALARTYPE4, \
  635. SCALARTYPE5, \
  636. SCALARTYPE6, \
  637. SCALARTYPE7, \
  638. __VA_ARGS__))
  639. #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND8( \
  640. SCALARTYPE1, \
  641. SCALARTYPE2, \
  642. SCALARTYPE3, \
  643. SCALARTYPE4, \
  644. SCALARTYPE5, \
  645. SCALARTYPE6, \
  646. SCALARTYPE7, \
  647. SCALARTYPE8, \
  648. ...) \
  649. AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
  650. AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \
  651. AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \
  652. AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \
  653. AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \
  654. AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \
  655. AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__) \
  656. AT_DISPATCH_CASE(SCALARTYPE7, __VA_ARGS__) \
  657. AT_DISPATCH_CASE(SCALARTYPE8, __VA_ARGS__)
  658. #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND8( \
  659. SCALARTYPE1, \
  660. SCALARTYPE2, \
  661. SCALARTYPE3, \
  662. SCALARTYPE4, \
  663. SCALARTYPE5, \
  664. SCALARTYPE6, \
  665. SCALARTYPE7, \
  666. SCALARTYPE8, \
  667. TYPE, \
  668. NAME, \
  669. ...) \
  670. AT_DISPATCH_SWITCH( \
  671. TYPE, \
  672. NAME, \
  673. AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND8( \
  674. SCALARTYPE1, \
  675. SCALARTYPE2, \
  676. SCALARTYPE3, \
  677. SCALARTYPE4, \
  678. SCALARTYPE5, \
  679. SCALARTYPE6, \
  680. SCALARTYPE7, \
  681. SCALARTYPE8, \
  682. __VA_ARGS__))
  683. #define AT_DISPATCH_CASE_BIT_TYPES(...) \
  684. AT_DISPATCH_CASE(at::ScalarType::Bits1x8, __VA_ARGS__) \
  685. AT_DISPATCH_CASE(at::ScalarType::Bits2x4, __VA_ARGS__) \
  686. AT_DISPATCH_CASE(at::ScalarType::Bits4x2, __VA_ARGS__) \
  687. AT_DISPATCH_CASE(at::ScalarType::Bits8, __VA_ARGS__) \
  688. AT_DISPATCH_CASE(at::ScalarType::Bits16, __VA_ARGS__)
  689. #define AT_DISPATCH_BIT_TYPES(TYPE, NAME, ...) \
  690. AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_BIT_TYPES(__VA_ARGS__))
  691. #define AT_DISPATCH_INDEX_TYPES(TYPE, NAME, ...) \
  692. AT_DISPATCH_SWITCH( \
  693. TYPE, \
  694. NAME, \
  695. AT_PRIVATE_CASE_TYPE_USING_HINT( \
  696. at::ScalarType::Int, index_t, __VA_ARGS__) \
  697. AT_PRIVATE_CASE_TYPE_USING_HINT( \
  698. at::ScalarType::Long, index_t, __VA_ARGS__))
  699. #else
  700. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  701. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)