autocast_mode.h 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/ATen.h>
  4. #include <ATen/NativeFunctions.h>
  5. #include <ATen/Operators.h>
  6. #include <torch/library.h>
  7. #include <c10/core/impl/LocalDispatchKeySet.h>
  8. #include <c10/util/intrusive_ptr.h>
  9. namespace at::autocast {
  // Autocast state accessors. Only the declarations are visible in this
  // header; the definitions live out of line.
  // NOTE(review): the per-function notes below are inferred from the names and
  // from the deprecated wrappers in this header that forward to these
  // functions -- confirm against the implementation file.

  // Reports whether autocast is enabled for `device_type`.
  TORCH_API bool is_autocast_enabled(at::DeviceType device_type);
  // Sets the enabled flag for `device_type`.
  TORCH_API void set_autocast_enabled(at::DeviceType device_type, bool enabled);
  // Returns the dtype autocast uses for `device_type`; this header treats it
  // as the "lower precision" floating-point type of that device.
  TORCH_API at::ScalarType get_autocast_dtype(at::DeviceType device_type);
  // Sets the autocast dtype for `device_type`.
  TORCH_API void set_autocast_dtype(
      at::DeviceType device_type,
      at::ScalarType dtype);
  // Presumably drops the cache of previously casted tensors -- TODO confirm.
  TORCH_API void clear_cache();
  // Presumably track how deeply autocast regions are nested and return the
  // updated depth -- TODO confirm.
  TORCH_API int increment_nesting();
  TORCH_API int decrement_nesting();
  // Query / set whether the cast cache is in use.
  TORCH_API bool is_autocast_cache_enabled();
  TORCH_API void set_autocast_cache_enabled(bool enabled);
  21. // deprecated CUDA-specific autocast APIs
  22. C10_DEPRECATED_MESSAGE(
  23. "at::autocast::is_enabled() is deprecated. Please use at::autocast::is_autocast_enabled(at::kCUDA) instead.")
  24. inline bool is_enabled() {
  25. TORCH_WARN_DEPRECATION(
  26. "at::autocast::",
  27. __func__,
  28. "() is deprecated. Please use at::autocast::is_autocast_enabled(at::kCUDA) instead.")
  29. return is_autocast_enabled(at::kCUDA);
  30. }
  31. C10_DEPRECATED_MESSAGE(
  32. "at::autocast::set_enabled(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(at::kCUDA, enabled) instead.")
  33. inline void set_enabled(bool enabled) {
  34. TORCH_WARN_DEPRECATION(
  35. "at::autocast::",
  36. __func__,
  37. "(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(at::kCUDA, enabled) instead.")
  38. set_autocast_enabled(at::kCUDA, enabled);
  39. }
  40. C10_DEPRECATED_MESSAGE(
  41. "at::autocast::get_autocast_gpu_dtype() is deprecated. Please use at::autocast::get_autocast_dtype(at::kCUDA) instead.")
  42. inline at::ScalarType get_autocast_gpu_dtype() {
  43. TORCH_WARN_DEPRECATION(
  44. "at::autocast::",
  45. __func__,
  46. "() is deprecated. Please use at::autocast::get_autocast_dtype(at::kCUDA) instead.")
  47. return get_autocast_dtype(at::kCUDA);
  48. }
  49. C10_DEPRECATED_MESSAGE(
  50. "at::autocast::set_autocast_gpu_dtype(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(at::kCUDA, dtype) instead.")
  51. inline void set_autocast_gpu_dtype(at::ScalarType dtype) {
  52. TORCH_WARN_DEPRECATION(
  53. "at::autocast::",
  54. __func__,
  55. "(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(at::kCUDA, dtype) instead.")
  56. set_autocast_dtype(at::kCUDA, dtype);
  57. }
  // Generates the four deprecated per-backend wrappers for one backend:
  //   bool           is_<name>_enabled()
  //   void           set_<name>_enabled(bool enabled)
  //   at::ScalarType get_autocast_<name>_dtype()
  //   void           set_autocast_<name>_dtype(at::ScalarType dtype)
  //
  // `name` is pasted into the function names; `device_type` is both passed to
  // the device-generic function and stringized into the messages, so the
  // warning shows exactly the expression the caller should switch to.
  // Each wrapper carries a compile-time C10_DEPRECATED_MESSAGE and also emits
  // TORCH_WARN_DEPRECATION at run time before forwarding.
  //
  // The run-time message of get_autocast_<name>_dtype() used to read
  // "Please at::autocast::..." (missing "use"); it now matches the attribute
  // text and the other three wrappers.
  //
  // No comments are placed inside the macro body: a `//` comment would absorb
  // the following continuation lines.
  #define DECLARE_DEPRECATED_AUTOCAST_APIS(name, device_type) \
    C10_DEPRECATED_MESSAGE( \
        "at::autocast::is_" #name \
        "_enabled() is deprecated. Please use at::autocast::is_autocast_enabled(" #device_type \
        ") instead.") \
    inline bool is_##name##_enabled() { \
      TORCH_WARN_DEPRECATION( \
          "at::autocast::", \
          __func__, \
          "() is deprecated. Please use at::autocast::is_autocast_enabled(" #device_type \
          ") instead.") \
      return is_autocast_enabled(device_type); \
    } \
    \
    C10_DEPRECATED_MESSAGE( \
        "at::autocast::set_" #name \
        "_enabled(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(" #device_type \
        ", enabled) instead.") \
    inline void set_##name##_enabled(bool enabled) { \
      TORCH_WARN_DEPRECATION( \
          "at::autocast::", \
          __func__, \
          "(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(" #device_type \
          ", enabled) instead.") \
      set_autocast_enabled(device_type, enabled); \
    } \
    \
    C10_DEPRECATED_MESSAGE( \
        "at::autocast::get_autocast_" #name \
        "_dtype() is deprecated. Please use at::autocast::get_autocast_dtype(" #device_type \
        ") instead.") \
    inline at::ScalarType get_autocast_##name##_dtype() { \
      TORCH_WARN_DEPRECATION( \
          "at::autocast::", \
          __func__, \
          "() is deprecated. Please use at::autocast::get_autocast_dtype(" #device_type \
          ") instead.") \
      return get_autocast_dtype(device_type); \
    } \
    \
    C10_DEPRECATED_MESSAGE( \
        "at::autocast::set_autocast_" #name \
        "_dtype(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(" #device_type \
        ", dtype) instead.") \
    inline void set_autocast_##name##_dtype(at::ScalarType dtype) { \
      TORCH_WARN_DEPRECATION( \
          "at::autocast::", \
          __func__, \
          "(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(" #device_type \
          ", dtype) instead.") \
      set_autocast_dtype(device_type, dtype); \
    }
  // X-macro listing every backend that has deprecated backend-specific
  // autocast functions. APPLY is invoked once per backend as
  // APPLY(<function-name fragment>, <device type constant>).
  #define AT_FORALL_DEPRECATED_AUTOCAST_BACKENDS(APPLY) \
    APPLY(cpu, at::kCPU) \
    APPLY(mtia, at::kMTIA) \
    APPLY(xpu, at::kXPU) \
    APPLY(xla, at::kXLA) \
    APPLY(hpu, at::kHPU) \
    APPLY(ipu, at::kIPU) \
    APPLY(privateuseone, at::kPrivateUse1)
  // Deprecated backend-specific autocast APIs for the other (non-CUDA)
  // backends: expands DECLARE_DEPRECATED_AUTOCAST_APIS once per entry of
  // AT_FORALL_DEPRECATED_AUTOCAST_BACKENDS, defining four inline wrappers per
  // backend in this namespace.
  // NOLINTNEXTLINE(misc-use-internal-linkage)
  AT_FORALL_DEPRECATED_AUTOCAST_BACKENDS(DECLARE_DEPRECATED_AUTOCAST_APIS)
  121. const std::array<at::DeviceType, 10> _AUTOCAST_SUPPORTED_DEVICES{
  122. at::kCPU,
  123. at::kCUDA,
  124. at::kMTIA,
  125. at::kMAIA,
  126. at::kXPU,
  127. at::kIPU,
  128. at::kHPU,
  129. at::kXLA,
  130. at::kPrivateUse1,
  131. at::kMPS};
  132. namespace {
  133. inline bool is_autocast_eligible(
  134. const Tensor& tensor,
  135. c10::DeviceType device_type) {
  136. switch (device_type) {
  137. case c10::DeviceType::CUDA:
  138. return (tensor.is_cuda() || tensor.is_xla()) &&
  139. tensor.is_floating_point();
  140. case c10::DeviceType::CPU:
  141. return (tensor.is_cpu() || tensor.is_mkldnn()) &&
  142. tensor.is_floating_point();
  143. case c10::DeviceType::MTIA:
  144. return tensor.is_mtia() && tensor.is_floating_point();
  145. case c10::DeviceType::MAIA:
  146. return tensor.is_maia() && tensor.is_floating_point();
  147. case c10::DeviceType::XPU:
  148. return tensor.is_xpu() && tensor.is_floating_point();
  149. case c10::DeviceType::IPU:
  150. return tensor.is_ipu() && tensor.is_floating_point();
  151. case c10::DeviceType::HPU:
  152. return tensor.is_hpu() && tensor.is_floating_point();
  153. case c10::DeviceType::XLA:
  154. return tensor.is_xla() && tensor.is_floating_point();
  155. case c10::DeviceType::PrivateUse1:
  156. return tensor.is_privateuseone() && tensor.is_floating_point();
  157. case c10::DeviceType::MPS:
  158. return tensor.is_mps() && tensor.is_floating_point();
  159. default:
  160. return false;
  161. }
  162. }
  163. } // namespace
  164. inline DispatchKey get_autocast_dispatch_key_from_device_type(
  165. c10::DeviceType device_type) {
  166. switch (device_type) {
  167. case c10::DeviceType::CUDA:
  168. return DispatchKey::Autocast;
  169. case c10::DeviceType::CPU:
  170. return DispatchKey::AutocastCPU;
  171. case c10::DeviceType::MTIA:
  172. return DispatchKey::AutocastMTIA;
  173. case c10::DeviceType::MAIA:
  174. return DispatchKey::AutocastMAIA;
  175. case c10::DeviceType::XPU:
  176. return DispatchKey::AutocastXPU;
  177. case c10::DeviceType::IPU:
  178. return DispatchKey::AutocastIPU;
  179. case c10::DeviceType::HPU:
  180. return DispatchKey::AutocastHPU;
  181. case c10::DeviceType::XLA:
  182. return DispatchKey::AutocastXLA;
  183. case c10::DeviceType::PrivateUse1:
  184. return DispatchKey::AutocastPrivateUse1;
  185. case c10::DeviceType::MPS:
  186. return DispatchKey::AutocastMPS;
  187. default:
  188. TORCH_CHECK(
  189. false,
  190. "unknown device type for autocast in get_autocast_dispatch_key_from_device_type");
  191. }
  192. }
  193. inline bool is_autocast_available(c10::DeviceType device_type) {
  194. if (std::find(
  195. _AUTOCAST_SUPPORTED_DEVICES.begin(),
  196. _AUTOCAST_SUPPORTED_DEVICES.end(),
  197. device_type) != _AUTOCAST_SUPPORTED_DEVICES.end()) {
  198. return true;
  199. } else {
  200. return false;
  201. }
  202. }
  203. inline at::ScalarType get_lower_precision_fp_from_device_type(
  204. c10::DeviceType device_type) {
  205. if (is_autocast_available(device_type)) {
  206. return get_autocast_dtype(device_type);
  207. } else {
  208. TORCH_CHECK(
  209. false,
  210. "unknown device type for autocast in get_lower_precision_fp_from_device_type");
  211. }
  212. }
  213. /********************************************************************
  214. Logic to extract the promote type from any Tensor or TensorList args.
  215. ********************************************************************/
  216. // Overload to catch Tensor args.
  217. // If nextArg is floating-point, compare its scalar_type with our
  218. // current best guess for the promote type, and update if necessary.
  219. inline at::ScalarType prioritize(
  220. at::ScalarType current,
  221. const Tensor& nextArg,
  222. c10::DeviceType device_type = c10::DeviceType::CUDA) {
  223. if (current == at::kDouble) {
  224. TORCH_CHECK(false, "promote type is double in at::autocast::prioritize");
  225. return current;
  226. }
  227. at::ScalarType lower_precision_fp =
  228. get_lower_precision_fp_from_device_type(device_type);
  229. if (is_autocast_eligible(nextArg, device_type)) {
  230. auto next = nextArg.scalar_type();
  231. if (next == at::kDouble) {
  232. return current; // ignores double tensors
  233. } else if (current == at::kFloat || next == at::kFloat) {
  234. return at::kFloat; // prioritizes float over lower_precision_fp
  235. } else if (current == lower_precision_fp && next == lower_precision_fp) {
  236. return lower_precision_fp;
  237. } else {
  238. TORCH_CHECK(
  239. false, "Unexpected floating ScalarType in at::autocast::prioritize");
  240. return current;
  241. }
  242. } else {
  243. return current;
  244. }
  245. }
  246. // Overload to catch TensorList args (for e.g. cat, stack).
  247. // Reuses the overload above to process each Tensor in the list.
  248. inline at::ScalarType prioritize(
  249. at::ScalarType current,
  250. const TensorList& list,
  251. c10::DeviceType device_type = c10::DeviceType::CUDA) {
  252. for (const auto& tensor : list) {
  253. current = prioritize(current, tensor, device_type);
  254. }
  255. return current;
  256. }
  257. inline at::ScalarType prioritize(
  258. at::ScalarType current,
  259. const ITensorListRef& list,
  260. c10::DeviceType device_type = c10::DeviceType::CUDA) {
  261. for (const auto& tensor : list) {
  262. current = prioritize(current, tensor, device_type);
  263. }
  264. return current;
  265. }
  266. // Template to catch non-Tensor args (no-op that returns current best guess)
  267. template <typename T>
  268. inline at::ScalarType prioritize(
  269. at::ScalarType current,
  270. T nextArg,
  271. c10::DeviceType device_type = c10::DeviceType::CUDA) {
  272. return current;
  273. }
  274. // Overload for the tail case.
  275. inline at::ScalarType promote_type(
  276. at::ScalarType current,
  277. c10::DeviceType device_type) {
  278. return current;
  279. }
  280. // Unpack args and determine if incoming lower_precision_fp tensors need to be
  281. // promoted to float32. Non-Tensor arguments are ignored.
  282. template <typename Arg0, typename... Args>
  283. inline at::ScalarType promote_type(
  284. at::ScalarType current,
  285. c10::DeviceType device_type,
  286. Arg0 arg0,
  287. Args... args) {
  288. auto new_current = prioritize(current, arg0, device_type);
  289. return promote_type(new_current, device_type, args...);
  290. }
  291. /****************************************************
  292. Logic to apply cached casting to any Tensor argument.
  293. ****************************************************/
  294. inline bool is_eligible(
  295. const Tensor& arg,
  296. c10::DeviceType device_type = c10::DeviceType::CUDA) {
  297. return (
  298. arg.defined() && is_autocast_eligible(arg, device_type) &&
  299. (arg.scalar_type() != at::kDouble));
  300. }
  // Overload to catch Tensor args.
  // Declared here, defined out of line. The other cached_cast overloads in
  // this header funnel each individual tensor through this one.
  // NOTE(review): presumably returns `arg` cast to `to_type` when it is
  // eligible for `device_type`, possibly reusing a cached result -- confirm
  // against the implementation.
  TORCH_API Tensor cached_cast(
      at::ScalarType to_type,
      const Tensor& arg,
      c10::DeviceType device_type = c10::DeviceType::CUDA);
  306. // Overload to process std::optional<Tensor>
  307. inline std::optional<Tensor> cached_cast(
  308. at::ScalarType to_type,
  309. const std::optional<Tensor>& arg,
  310. c10::DeviceType device_type = c10::DeviceType::CUDA) {
  311. if (arg.has_value()) {
  312. return cached_cast(to_type, *arg, device_type);
  313. } else {
  314. return std::nullopt;
  315. }
  316. }
  317. // Overload to process TensorLists
  318. inline std::vector<Tensor> cached_cast(
  319. at::ScalarType to_type,
  320. const TensorList& arg,
  321. c10::DeviceType device_type = c10::DeviceType::CUDA) {
  322. std::vector<Tensor> vec;
  323. vec.reserve(arg.size());
  324. for (const auto& t : arg) {
  325. vec.emplace_back(cached_cast(to_type, t, device_type));
  326. }
  327. return vec;
  328. }
  329. inline std::vector<Tensor> cached_cast(
  330. at::ScalarType to_type,
  331. const ITensorListRef& arg,
  332. c10::DeviceType device_type = c10::DeviceType::CUDA) {
  333. std::vector<Tensor> vec;
  334. vec.reserve(arg.size());
  335. for (const auto& t : arg) {
  336. vec.emplace_back(cached_cast(to_type, t, device_type));
  337. }
  338. return vec;
  339. }
  340. // Template to catch non-Tensor args.
  341. template <typename T>
  342. inline T cached_cast(
  343. at::ScalarType to_type,
  344. T arg,
  345. c10::DeviceType device_type = c10::DeviceType::CUDA) {
  346. return arg;
  347. }
  /*******************************************************
  Logic to flip an output dtype flag.
  Keep it simple for now by assuming only one such flag is
  present in the argument list. If I ever need a function
  with more than one such flag I'll figure out something else.
  The policy is:
  If the user has explicitly specified a dtype, respect it.
  Otherwise, set it to the autocast type.
  ********************************************************/
  357. // Overload to catch dtype flags
  358. std::optional<ScalarType> inline set_opt_dtype(
  359. at::ScalarType to_type,
  360. const std::optional<ScalarType>& dtype) {
  361. return dtype.has_value() ? dtype : to_type;
  362. }
  363. // Template to catch other args
  364. template <typename T>
  365. inline T set_opt_dtype(at::ScalarType to_type, T arg) {
  366. return arg;
  367. }
  368. template <typename... Args>
  369. inline bool firstarg_is_eligible(
  370. c10::DeviceType device_type,
  371. const Tensor& arg,
  372. Args... args) {
  373. return is_eligible(arg, device_type);
  374. }
  375. template <typename... Args>
  376. inline at::ScalarType type_from_firstarg(
  377. c10::DeviceType device_type,
  378. at::ScalarType to_type,
  379. const Tensor& arg,
  380. Args... args) {
  381. return (is_eligible(arg, device_type) ? to_type : arg.scalar_type());
  382. }
  // Policies correspond to op categories that need code-divergent handling.
  // The WrapFunction_ templates below are specialized on a policy template
  // parameter.
  enum class CastPolicy : uint8_t {
    // Cast all inputs to lower_precision_fp before running the op.
    // Currently, lower_precision_fp is fp16 for AutocastCUDA, and is defined
    // by the user (default bf16) for AutocastCPU or other devices.
    lower_precision_fp = 0,

    // Cast all inputs to at::kFloat before running the op.
    fp32 = 1,

    // For functions (like softmax) that
    //   1. we'd like to run in fp32 and
    //   2. have a std::optional<ScalarType> arg that controls the output type.
    // Wrapper policy: if the output type is already set, don't touch it;
    // otherwise set it to at::kFloat.
    fp32_set_opt_dtype = 2,

    // For functions (like norm) that
    //   1. we'd like to run in fp32 and
    //   2. have some overloads that accept an output type and other overloads
    //      that don't.
    // These wrappers wrap the overloads WITHOUT an output dtype.
    // Wrapper policy: append at::kFloat to the args and redispatch to the
    // type-aware overload.
    fp32_append_dtype = 3,

    // Run in the widest dtype among several args.
    promote = 4,
  };
  408. /********************************************************************************************************
  409. Templates to provide wrapper functions
  410. I'm copying the pattern used in core/boxing/impl/WrapFunctionIntoFunctor.h to
  411. extract args and return type. (see also
  412. https://stackoverflow.com/questions/46533698/how-to-deduce-argument-list-from-function-pointer)
  413. This strategy uses an exterior "WrapFunction" that extracts arguments on behalf
  414. of (in my case several specializations of) an interior "WrapFunction_".
  415. Interior WrapFunction_ specializations are defined for each CastPolicy.
  416. ********************************************************************************************************/
  417. // Base template for WrapFunction_, which is specialized to contain a "call"
  418. // method each CastPolicy
  419. template <
  420. CastPolicy policy,
  421. c10::DeviceType device_type,
  422. class Redispatch,
  423. Redispatch* F,
  424. class Ret,
  425. class ArgList>
  426. struct WrapFunction_ {};
  427. // CastPolicy::lower_precision_fp General_DeviceType
  428. template <
  429. c10::DeviceType device_type,
  430. class Redispatch,
  431. Redispatch* F,
  432. class Ret,
  433. class... Args>
  434. struct WrapFunction_<
  435. CastPolicy::lower_precision_fp,
  436. device_type,
  437. Redispatch,
  438. F,
  439. Ret,
  440. guts::typelist::typelist<Args...>> {
  441. static Ret call(Args... args) {
  442. c10::impl::ExcludeDispatchKeyGuard no_autocast(
  443. get_autocast_dispatch_key_from_device_type(device_type));
  444. return (*F)(cached_cast(
  445. get_lower_precision_fp_from_device_type(device_type),
  446. args,
  447. device_type)...);
  448. }
  449. };
  450. // CastPolicy::fp32 General_DeviceType
  451. template <
  452. c10::DeviceType device_type,
  453. class Redispatch,
  454. Redispatch* F,
  455. class Ret,
  456. class... Args>
  457. struct WrapFunction_<
  458. CastPolicy::fp32,
  459. device_type,
  460. Redispatch,
  461. F,
  462. Ret,
  463. guts::typelist::typelist<Args...>> {
  464. static Ret call(Args... args) {
  465. c10::impl::ExcludeDispatchKeyGuard no_autocast(
  466. get_autocast_dispatch_key_from_device_type(device_type));
  467. return (*F)(cached_cast(at::kFloat, args, device_type)...);
  468. }
  469. };
  470. // CastPolicy::fp32_set_opt_dtype General_DeviceType
  471. template <
  472. c10::DeviceType device_type,
  473. class Redispatch,
  474. Redispatch* F,
  475. class Ret,
  476. class... Args>
  477. struct WrapFunction_<
  478. CastPolicy::fp32_set_opt_dtype,
  479. device_type,
  480. Redispatch,
  481. F,
  482. Ret,
  483. guts::typelist::typelist<Args...>> {
  484. static Ret call(Args... args) {
  485. c10::impl::ExcludeDispatchKeyGuard no_autocast(
  486. get_autocast_dispatch_key_from_device_type(device_type));
  487. if (firstarg_is_eligible(device_type, args...)) {
  488. return (*F)(set_opt_dtype(at::kFloat, args)...);
  489. } else {
  490. // If ineligible, calls F with unaltered args. Does not set opt dtype,
  491. // because setting opt dtype explicitly may interfere with internal
  492. // implicit promotion decisions.
  493. return (*F)(args...);
  494. }
  495. }
  496. };
  497. // CastPolicy::fp32_append_dtype General_DeviceType
  498. template <
  499. c10::DeviceType device_type,
  500. class Redispatch,
  501. Redispatch* F,
  502. class Ret,
  503. class... Args>
  504. struct WrapFunction_<
  505. CastPolicy::fp32_append_dtype,
  506. device_type,
  507. Redispatch,
  508. F,
  509. Ret,
  510. guts::typelist::typelist<Args...>> {
  511. static Ret call(Args... args) {
  512. c10::impl::ExcludeDispatchKeyGuard no_autocast(
  513. get_autocast_dispatch_key_from_device_type(device_type));
  514. at::ScalarType out_type =
  515. type_from_firstarg(device_type, at::kFloat, args...);
  516. return (*F)(args..., out_type);
  517. }
  518. };
  519. // CastPolicy::promote General_DeviceType
  520. template <
  521. c10::DeviceType device_type,
  522. class Redispatch,
  523. Redispatch* F,
  524. class Ret,
  525. class... Args>
  526. struct WrapFunction_<
  527. CastPolicy::promote,
  528. device_type,
  529. Redispatch,
  530. F,
  531. Ret,
  532. guts::typelist::typelist<Args...>> {
  533. static Ret call(Args... args) {
  534. c10::impl::ExcludeDispatchKeyGuard no_autocast(
  535. get_autocast_dispatch_key_from_device_type(device_type));
  536. auto to_type = promote_type(
  537. get_lower_precision_fp_from_device_type(device_type),
  538. device_type,
  539. args...);
  540. return (*F)(cached_cast(to_type, args, device_type)...);
  541. }
  542. };
  543. // Wrapper to infer return_type and parameter_types for WrapFunction_ (imitating
  544. // core/boxing/impl/WrapFunctionIntoFunctor.h)
  545. template <
  546. CastPolicy policy,
  547. c10::DeviceType device_type,
  548. class Registered, // The signature for which we're registering. The
  549. // dispatcher's calling code invokes our registered
  550. // functions with arguments matching Registered, so we
  551. // register WrapFunction_::call methods with a matching
  552. // signature to properly field those arguments.
  553. // guts::function_traits below extracts return_type and
  554. // parameter_types from Registered, which WrapFunction_
  555. // templates above use to declare their call methods.
  556. class Redispatch, // The signature for the function we're redispatching to.
  557. // In most cases this is the same as Registered, but for
  558. // some ops (for example, ops where we append a dtype)
  559. // it's useful to redispatch to a function with a
  560. // different signature.
  561. Redispatch* F> // The actual function we're redispatching to.
  562. struct WrapFunction final {
  563. using type = WrapFunction_<
  564. policy,
  565. device_type,
  566. Redispatch,
  567. F,
  568. typename guts::function_traits<Registered>::return_type,
  569. typename guts::function_traits<Registered>::parameter_types>;
  570. };
  571. /*****************************************************************************************************************
  572. This section performs load-time registration for autocast wrappers.
  573. It's debatable at what level operations should be patched. We'd like casts to
  574. be autograd-exposed and precede autograd history recording, so that for
  575. lower_precision_fp ops, input tensors are saved for backward in
  576. lower_precision_fp rather than fp32. Saving inputs in lower_precision_fp
  577. can significantly reduce a model's memory footprint.
  578. Option 1 (strawman): Patch only at the level of explicit calls into
  579. cudnn/cublas (cudnn_convolution, etc), because those are the code paths that are
  580. guaranteed to use Tensor Cores, therefore they're the ones that will benefit
  581. most from lower_precision_fp. Potential pitfall: convolutions (and other ops)
  582. are wrapped in several layers of at::* calls. If one of those happens to record
  583. autograd history, then we've lost the opportunity to save inputs in
  584. lower_precision_fp.
  585. Option 2: Patch the Python-exposed surface of calls, to make 100% sure autograd
  586. history recording can't sneak in ahead of autocast. This mirrors Apex most
  587. closely.
  588. I think Option 2 is the right answer for all ops, not just convolutions. Option
  589. 2 is what I implement here.
  590. *****************************************************************************************************************/
  591. /********************************************************************************************************************
  592. Explicit registration for out-of-place ops
  The stuff below could be codegenned. Ed said
  > you are going to have to write the function definition at some point, I
  > wouldn't try to get clever about it.
  Therefore, for the moment, this is all copy pasted in from
  VariableTypeEverything.cpp with appropriate substitutions.
  597. ********************************************************************************************************************/
  598. } // namespace at::autocast
  // Qualifies an op name with the at:: namespace.
  #define ADD_NS(OP_NAME) at::OP_NAME

  // Argument-counting helpers for KERNEL(). _KERNEL_OVERLOAD_NARG(...) yields
  // 1 when given two arguments (OP, POLICY) and 2 when given three
  // (OP, OVERLOAD, POLICY): appending "2, 1" shifts the right literal into the
  // COUNT slot of the _IMPL macro.
  #define _KERNEL_OVERLOAD_NARG_IMPL(A0, A1, A2, COUNT, ...) COUNT
  #define _KERNEL_OVERLOAD_NARG(...) \
    C10_EXPAND_MSVC_WORKAROUND(_KERNEL_OVERLOAD_NARG_IMPL(__VA_ARGS__, 2, 1))
  // Common case: the registration signature matches the redispatch signature
  // (which is why decltype(ATEN_FN(OP)) appears twice in the WrapFunction
  // instantiation).
  //
  // KERNEL1 registers an autocast wrapper for the op "aten::OP" (no overload
  // name). DEVICE_TYPE is the c10::DeviceType template argument and
  // CAST_POLICY the CastPolicy enumerator name. The expansion calls `m.impl`,
  // so a library handle named `m` must be in scope, and it already ends with a
  // semicolon.
  #define KERNEL1(DEVICE_TYPE, OP, CAST_POLICY) \
    m.impl( \
        TORCH_SELECTIVE_NAME("aten::" #OP), \
        &::at::autocast::WrapFunction< \
            ::at::autocast::CastPolicy::CAST_POLICY, \
            DEVICE_TYPE, \
            decltype(ATEN_FN(OP)), \
            decltype(ATEN_FN(OP)), \
            &ATEN_FN(OP)>::type::call);
  // Overload-aware variant of KERNEL1: registers an autocast wrapper for
  // "aten::OP.OVERLOAD", using ATEN_FN2(OP, OVERLOAD) as both the registered
  // and the redispatch function. DEVICE_TYPE is the c10::DeviceType template
  // argument and CAST_POLICY the CastPolicy enumerator name. Requires a
  // library handle named `m` in scope; the expansion already ends with a
  // semicolon.
  #define KERNEL2(DEVICE_TYPE, OP, OVERLOAD, CAST_POLICY) \
    m.impl( \
        TORCH_SELECTIVE_NAME("aten::" #OP "." #OVERLOAD), \
        &::at::autocast::WrapFunction< \
            ::at::autocast::CastPolicy::CAST_POLICY, \
            DEVICE_TYPE, \
            decltype(ATEN_FN2(OP, OVERLOAD)), \
            decltype(ATEN_FN2(OP, OVERLOAD)), \
            &ATEN_FN2(OP, OVERLOAD)>::type::call);
  // KERNEL(DEVICE_TYPE, OP, POLICY) or KERNEL(DEVICE_TYPE, OP, OVERLOAD, POLICY)
  // dispatches to KERNEL1 or KERNEL2 based on how many arguments follow the
  // device type. The two helper layers exist so the argument count is fully
  // expanded before it is pasted onto the "KERNEL" prefix.
  #define _KERNEL_DISPATCH(DEVICE_TYPE, ARITY, ...) \
    C10_CONCATENATE(KERNEL, ARITY)(DEVICE_TYPE, __VA_ARGS__)
  #define _KERNEL_IMPL(DEVICE_TYPE, ...) \
    _KERNEL_DISPATCH(DEVICE_TYPE, _KERNEL_OVERLOAD_NARG(__VA_ARGS__), __VA_ARGS__)
  #define KERNEL(DEVICE_TYPE, ...) _KERNEL_IMPL(DEVICE_TYPE, __VA_ARGS__)
  // Less common but still useful case: redispatching to a function with a new
  // signature (e.g. one that takes an appended dtype).
  //
  //   DEVICE_TYPE     c10::DeviceType template argument.
  //   TARGET_FN       function to redispatch to (its address is taken).
  //   SCHEMA_NAME     string literal appended to "aten::" as the op name.
  //   REGISTERED_SIG  signature the wrapper is registered with.
  //   TARGET_SIG      signature of TARGET_FN.
  //   CAST_POLICY     CastPolicy enumerator name.
  //
  // Requires a library handle named `m` in scope; the expansion already ends
  // with a semicolon.
  #define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \
      DEVICE_TYPE, \
      TARGET_FN, \
      SCHEMA_NAME, \
      REGISTERED_SIG, \
      TARGET_SIG, \
      CAST_POLICY) \
    m.impl( \
        TORCH_SELECTIVE_NAME("aten::" SCHEMA_NAME), \
        &::at::autocast::WrapFunction< \
            ::at::autocast::CastPolicy::CAST_POLICY, \
            DEVICE_TYPE, \
            REGISTERED_SIG, \
            TARGET_SIG, \
            &TARGET_FN>::type::call);
  // AutocastCPU registration shorthands.
  // KERNEL_CPU takes (OP, POLICY) or (OP, OVERLOAD, POLICY);
  // KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CPU takes the generic macro's
  // arguments minus the device type. Both prepend c10::DeviceType::CPU.
  #define KERNEL_CPU(...) \
    KERNEL(c10::DeviceType::CPU, __VA_ARGS__)
  #define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CPU( \
      TARGET_FN, \
      SCHEMA_NAME, \
      REGISTERED_SIG, \
      TARGET_SIG, \
      CAST_POLICY) \
    KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \
        c10::DeviceType::CPU, \
        TARGET_FN, \
        SCHEMA_NAME, \
        REGISTERED_SIG, \
        TARGET_SIG, \
        CAST_POLICY)
  // AutocastCUDA registration shorthands.
  // KERNEL_CUDA takes (OP, POLICY) or (OP, OVERLOAD, POLICY);
  // KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CUDA takes the generic macro's
  // arguments minus the device type. Both prepend c10::DeviceType::CUDA.
  #define KERNEL_CUDA(...) \
    KERNEL(c10::DeviceType::CUDA, __VA_ARGS__)
  #define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CUDA( \
      TARGET_FN, \
      SCHEMA_NAME, \
      REGISTERED_SIG, \
      TARGET_SIG, \
      CAST_POLICY) \
    KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \
        c10::DeviceType::CUDA, \
        TARGET_FN, \
        SCHEMA_NAME, \
        REGISTERED_SIG, \
        TARGET_SIG, \
        CAST_POLICY)
  // AutocastMTIA registration shorthands.
  // KERNEL_MTIA takes (OP, POLICY) or (OP, OVERLOAD, POLICY);
  // KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_MTIA takes the generic macro's
  // arguments minus the device type. Both prepend c10::DeviceType::MTIA.
  #define KERNEL_MTIA(...) \
    KERNEL(c10::DeviceType::MTIA, __VA_ARGS__)
  #define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_MTIA( \
      TARGET_FN, \
      SCHEMA_NAME, \
      REGISTERED_SIG, \
      TARGET_SIG, \
      CAST_POLICY) \
    KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \
        c10::DeviceType::MTIA, \
        TARGET_FN, \
        SCHEMA_NAME, \
        REGISTERED_SIG, \
        TARGET_SIG, \
        CAST_POLICY)
  // AutocastMAIA registration shorthands.
  // KERNEL_MAIA takes (OP, POLICY) or (OP, OVERLOAD, POLICY);
  // KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_MAIA takes the generic macro's
  // arguments minus the device type. Both prepend c10::DeviceType::MAIA.
  #define KERNEL_MAIA(...) \
    KERNEL(c10::DeviceType::MAIA, __VA_ARGS__)
  #define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_MAIA( \
      TARGET_FN, \
      SCHEMA_NAME, \
      REGISTERED_SIG, \
      TARGET_SIG, \
      CAST_POLICY) \
    KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \
        c10::DeviceType::MAIA, \
        TARGET_FN, \
        SCHEMA_NAME, \
        REGISTERED_SIG, \
        TARGET_SIG, \
        CAST_POLICY)
  // KERNEL_XPU/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_XPU
  // AutocastXPU registration helpers. KERNEL_XPU is invoked as (OP, POLICY)
  // or (OP, OVERLOAD, POLICY); it hands those arguments to KERNEL behind a
  // leading c10::DeviceType::XPU.
  #define KERNEL_XPU(...) \
    KERNEL(c10::DeviceType::XPU, __VA_ARGS__)
  // XPU flavour of KERNEL_DIFFERENT_REDISPATCH_SIGNATURE: the five arguments
  // (redispatch function, registered name, registered signature, redispatch
  // signature, policy) are forwarded unchanged and in order, preceded by
  // c10::DeviceType::XPU.
  #define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_XPU(             \
      FN, OP_NAME, REGISTER_SIG, REDISPATCH_SIG, CAST_POLICY)    \
    KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(                       \
        c10::DeviceType::XPU,                                    \
        FN, OP_NAME, REGISTER_SIG, REDISPATCH_SIG, CAST_POLICY)
  // KERNEL_PRIVATEUSEONE/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_PRIVATEUSEONE
  // AutocastPrivateUse1 registration helpers. KERNEL_PRIVATEUSEONE is invoked
  // as (OP, POLICY) or (OP, OVERLOAD, POLICY); it hands those arguments to
  // KERNEL behind a leading c10::DeviceType::PrivateUse1.
  #define KERNEL_PRIVATEUSEONE(...)   \
    KERNEL(                           \
        c10::DeviceType::PrivateUse1, \
        __VA_ARGS__)
  // PrivateUse1 flavour of KERNEL_DIFFERENT_REDISPATCH_SIGNATURE: the five
  // arguments (redispatch function, registered name, registered signature,
  // redispatch signature, policy) are forwarded unchanged and in order,
  // preceded by c10::DeviceType::PrivateUse1.
  #define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_PRIVATEUSEONE(   \
      FN, OP_NAME, REGISTER_SIG, REDISPATCH_SIG, CAST_POLICY)    \
    KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(                       \
        c10::DeviceType::PrivateUse1,                            \
        FN, OP_NAME, REGISTER_SIG, REDISPATCH_SIG, CAST_POLICY)
  // KERNEL_MPS
  // AutocastMPS registration helper, invoked as (OP, POLICY) or
  // (OP, OVERLOAD, POLICY); it hands those arguments to KERNEL behind a
  // leading c10::DeviceType::MPS.
  #define KERNEL_MPS(...) \
    KERNEL(c10::DeviceType::MPS, __VA_ARGS__)
  746. // Op lists for different policies.
  747. // To make sure other backends can reuse the policy op list.
  // AT_FORALL_LOWER_PRECISION_FP(OP_ENTRY)
  // X-macro op list for the lower-precision floating-point policy. OP_ENTRY is
  // expanded once per entry, in the order written, and receives either (OP) or
  // (OP, SECOND) -- SECOND is presumably the overload name, mirroring the
  // (OP, OVERLOAD, POLICY) form accepted by the KERNEL_* macros.
  #define AT_FORALL_LOWER_PRECISION_FP(OP_ENTRY)        \
    OP_ENTRY(_convolution, deprecated)                  \
    OP_ENTRY(_convolution)                              \
    OP_ENTRY(conv1d)                                    \
    OP_ENTRY(conv2d)                                    \
    OP_ENTRY(conv3d)                                    \
    OP_ENTRY(conv_tbc)                                  \
    OP_ENTRY(conv_transpose1d)                          \
    OP_ENTRY(conv_transpose2d, input)                   \
    OP_ENTRY(conv_transpose3d, input)                   \
    OP_ENTRY(convolution)                               \
    OP_ENTRY(prelu)                                     \
    OP_ENTRY(addmm)                                     \
    OP_ENTRY(addmv)                                     \
    OP_ENTRY(addr)                                      \
    OP_ENTRY(matmul)                                    \
    OP_ENTRY(einsum)                                    \
    OP_ENTRY(mm)                                        \
    OP_ENTRY(mv)                                        \
    OP_ENTRY(linalg_vecdot)                             \
    OP_ENTRY(linear)                                    \
    OP_ENTRY(addbmm)                                    \
    OP_ENTRY(baddbmm)                                   \
    OP_ENTRY(bmm)                                       \
    OP_ENTRY(chain_matmul)                              \
    OP_ENTRY(linalg_multi_dot)                          \
    OP_ENTRY(_thnn_fused_lstm_cell)                     \
    OP_ENTRY(_thnn_fused_gru_cell)                      \
    OP_ENTRY(lstm_cell)                                 \
    OP_ENTRY(gru_cell)                                  \
    OP_ENTRY(rnn_tanh_cell)                             \
    OP_ENTRY(rnn_relu_cell)                             \
    OP_ENTRY(_scaled_dot_product_flash_attention)       \
    OP_ENTRY(scaled_dot_product_attention)
  // AT_FORALL_FP32(OP_ENTRY)
  // X-macro op list for the fp32 policy. OP_ENTRY is expanded once per entry,
  // in the order written, and receives either (OP) or (OP, SECOND) -- SECOND
  // is presumably the overload name, mirroring the (OP, OVERLOAD, POLICY) form
  // accepted by the KERNEL_* macros.
  #define AT_FORALL_FP32(OP_ENTRY)                   \
    OP_ENTRY(acos)                                   \
    OP_ENTRY(asin)                                   \
    OP_ENTRY(cosh)                                   \
    OP_ENTRY(erfinv)                                 \
    OP_ENTRY(exp)                                    \
    OP_ENTRY(expm1)                                  \
    OP_ENTRY(log)                                    \
    OP_ENTRY(log10)                                  \
    OP_ENTRY(log2)                                   \
    OP_ENTRY(log1p)                                  \
    OP_ENTRY(reciprocal)                             \
    OP_ENTRY(rsqrt)                                  \
    OP_ENTRY(sinh)                                   \
    OP_ENTRY(tan)                                    \
    OP_ENTRY(pow, Tensor_Scalar)                     \
    OP_ENTRY(pow, Tensor_Tensor)                     \
    OP_ENTRY(pow, Scalar)                            \
    OP_ENTRY(softplus)                               \
    OP_ENTRY(layer_norm)                             \
    OP_ENTRY(native_layer_norm)                      \
    OP_ENTRY(group_norm)                             \
    OP_ENTRY(frobenius_norm, dim)                    \
    OP_ENTRY(nuclear_norm)                           \
    OP_ENTRY(nuclear_norm, dim)                      \
    OP_ENTRY(cosine_similarity)                      \
    OP_ENTRY(poisson_nll_loss)                       \
    OP_ENTRY(cosine_embedding_loss)                  \
    OP_ENTRY(nll_loss)                               \
    OP_ENTRY(nll_loss2d)                             \
    OP_ENTRY(hinge_embedding_loss)                   \
    OP_ENTRY(kl_div)                                 \
    OP_ENTRY(l1_loss)                                \
    OP_ENTRY(smooth_l1_loss)                         \
    OP_ENTRY(huber_loss)                             \
    OP_ENTRY(mse_loss)                               \
    OP_ENTRY(margin_ranking_loss)                    \
    OP_ENTRY(multilabel_margin_loss)                 \
    OP_ENTRY(soft_margin_loss)                       \
    OP_ENTRY(triplet_margin_loss)                    \
    OP_ENTRY(multi_margin_loss)                      \
    OP_ENTRY(binary_cross_entropy_with_logits)       \
    OP_ENTRY(dist)                                   \
    OP_ENTRY(pdist)                                  \
    OP_ENTRY(cdist)                                  \
    OP_ENTRY(renorm)                                 \
    OP_ENTRY(logsumexp)                              \
    OP_ENTRY(upsample_nearest1d)                     \
    OP_ENTRY(_upsample_nearest_exact1d)              \
    OP_ENTRY(upsample_nearest2d)                     \
    OP_ENTRY(_upsample_nearest_exact2d)              \
    OP_ENTRY(upsample_nearest3d)                     \
    OP_ENTRY(_upsample_nearest_exact3d)              \
    OP_ENTRY(upsample_linear1d)                      \
    OP_ENTRY(upsample_bilinear2d)                    \
    OP_ENTRY(_upsample_bilinear2d_aa)                \
    OP_ENTRY(upsample_trilinear3d)                   \
    OP_ENTRY(upsample_bicubic2d)                     \
    OP_ENTRY(_upsample_bicubic2d_aa)
  // AT_FORALL_FP32_SET_OPT_DTYPE(OP_ENTRY)
  // X-macro op list for the fp32_set_opt_dtype policy. OP_ENTRY is expanded
  // once per entry, in the order written, and receives either (OP) or
  // (OP, SECOND) -- SECOND is presumably the overload name, mirroring the
  // (OP, OVERLOAD, POLICY) form accepted by the KERNEL_* macros.
  #define AT_FORALL_FP32_SET_OPT_DTYPE(OP_ENTRY) \
    OP_ENTRY(prod)                               \
    OP_ENTRY(prod, dim_int)                      \
    OP_ENTRY(prod, dim_Dimname)                  \
    OP_ENTRY(softmax, int)                       \
    OP_ENTRY(softmax, Dimname)                   \
    OP_ENTRY(log_softmax, int)                   \
    OP_ENTRY(log_softmax, Dimname)               \
    OP_ENTRY(cumprod)                            \
    OP_ENTRY(cumprod, dimname)                   \
    OP_ENTRY(cumsum)                             \
    OP_ENTRY(cumsum, dimname)                    \
    OP_ENTRY(linalg_vector_norm)                 \
    OP_ENTRY(linalg_matrix_norm)                 \
    OP_ENTRY(linalg_matrix_norm, str_ord)        \
    OP_ENTRY(sum)                                \
    OP_ENTRY(sum, dim_IntList)                   \
    OP_ENTRY(sum, dim_DimnameList)
  // AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE(ENTRY)
  // X-macro list of ops whose redispatch signature differs from the registered
  // one. ENTRY is expanded once per op with five arguments, in the same order
  // as the parameters of the KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_* macros:
  // redispatch function, registered name, registered signature, redispatch
  // signature, policy. In every entry below the redispatch signature is the
  // registered signature with a trailing ScalarType appended (and, for
  // "norm.Scalar", the Scalar argument widened to std::optional<Scalar>).
  #define AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE(ENTRY)                         \
    ENTRY(                                                                        \
        ADD_NS(norm),                                                             \
        "norm.Scalar",                                                            \
        Tensor(const Tensor&, const Scalar&),                                     \
        Tensor(const Tensor&, const std::optional<Scalar>&, ScalarType),          \
        fp32_append_dtype)                                                        \
    ENTRY(                                                                        \
        ADD_NS(norm),                                                             \
        "norm.ScalarOpt_dim",                                                     \
        Tensor(const Tensor&, const std::optional<Scalar>&, IntArrayRef, bool),   \
        Tensor(const Tensor&, const std::optional<Scalar>&, IntArrayRef, bool,    \
               ScalarType),                                                       \
        fp32_append_dtype)                                                        \
    ENTRY(                                                                        \
        ADD_NS(norm),                                                             \
        "norm.names_ScalarOpt_dim",                                               \
        Tensor(const Tensor&, const std::optional<Scalar>&, DimnameList, bool),   \
        Tensor(const Tensor&, const std::optional<Scalar>&, DimnameList, bool,    \
               ScalarType),                                                       \
        fp32_append_dtype)
  // AT_FORALL_PROMOTE(OP_ENTRY)
  // X-macro op list for the promote policy. OP_ENTRY is expanded once per
  // entry, in the order written, with the op name as its only argument.
  #define AT_FORALL_PROMOTE(OP_ENTRY) \
    OP_ENTRY(addcdiv)                 \
    OP_ENTRY(addcmul)                 \
    OP_ENTRY(atan2)                   \
    OP_ENTRY(bilinear)                \
    OP_ENTRY(cross)                   \
    OP_ENTRY(dot)                     \
    OP_ENTRY(vdot)                    \
    OP_ENTRY(grid_sampler)            \
    OP_ENTRY(index_put)               \
    OP_ENTRY(tensordot)               \
    OP_ENTRY(scatter_add)
  897. #else
  898. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  899. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)