Context.h 25 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/BlasBackend.h>
  4. #include <ATen/CPUGeneratorImpl.h>
  5. #include <ATen/DeviceAccelerator.h>
  6. #include <ATen/LinalgBackend.h>
  7. #include <ATen/ROCmFABackend.h>
  8. #include <ATen/SDPBackend.h>
  9. #include <ATen/core/ATenGeneral.h>
  10. #include <ATen/core/DeprecatedTypeProperties.h>
  11. #include <ATen/core/Generator.h>
  12. #include <ATen/core/LegacyTypeDispatch.h>
  13. #include <ATen/detail/AcceleratorHooksInterface.h>
  14. #include <ATen/detail/CUDAHooksInterface.h>
  15. #include <ATen/detail/HIPHooksInterface.h>
  16. #include <ATen/detail/HPUHooksInterface.h>
  17. #include <ATen/detail/IPUHooksInterface.h>
  18. #include <ATen/detail/MAIAHooksInterface.h>
  19. #include <ATen/detail/MPSHooksInterface.h>
  20. #include <ATen/detail/MTIAHooksInterface.h>
  21. #include <ATen/detail/PrivateUse1HooksInterface.h>
  22. #include <ATen/detail/XLAHooksInterface.h>
  23. #include <ATen/detail/XPUHooksInterface.h>
  24. #include <c10/core/QEngine.h>
  25. #include <c10/core/impl/DeviceGuardImplInterface.h>
  26. #include <c10/util/CallOnce.h>
  27. #include <c10/util/Exception.h>
  28. #include <c10/util/env.h>
  29. #include <c10/util/hash.h>
  30. #include <c10/util/irange.h>
  31. #include <cstdint>
  32. #include <map>
  33. #include <mutex>
  34. #include <unordered_map>
  35. namespace at {
  36. class Tensor;
  37. enum class TORCH_API Float32MatmulPrecision { HIGHEST, HIGH, MEDIUM };
  // Value type reported by Context::allowFP16ReductionCuBLAS() and
  // Context::allowBF16ReductionCuBLAS(). The underlying type is uint8_t and
  // the numeric values are spelled out explicitly.
  enum class CuBLASReductionOption : uint8_t {
    AllowReducedPrecisionWithSplitK = 0,
    DisallowReducedPrecisionAllowSplitK = 1,
    DisallowReducedPrecisionDisallowSplitK = 2,
  };
  43. enum class TORCH_API Float32Backend { GENERIC, CUDA, MKLDNN };
  44. enum class TORCH_API Float32Op { ALL, CONV, RNN, MATMUL };
  45. enum class TORCH_API Float32Precision { NONE, IEEE, TF32, BF16 };
  46. TORCH_API Float32Backend str2backend(const std::string& name);
  47. TORCH_API Float32Op str2op(const std::string& name);
  48. TORCH_API Float32Precision str2precision(const std::string& name);
  49. TORCH_API std::string precision2str(Float32Precision prec);
  50. class TORCH_API Context {
  51. public:
  52. Context();
  53. const Generator& defaultGenerator(Device device) {
  54. c10::DeviceType device_type = device.type();
  55. lazyInitDevice(device_type);
  56. if (device_type == at::kCPU) {
  57. return at::detail::getDefaultCPUGenerator();
  58. } else {
  59. return getAcceleratorHooksInterface(device_type)
  60. .getDefaultGenerator(device.index());
  61. }
  62. }
  63. const AcceleratorHooksInterface& getAcceleratorHooksInterface(
  64. std::optional<c10::DeviceType> opt_device_type = std::nullopt) {
  65. if (!opt_device_type.has_value()) {
  66. opt_device_type = at::getAccelerator(true);
  67. }
  68. if (opt_device_type == at::kCUDA) {
  69. return at::detail::getCUDAHooks();
  70. } else if (opt_device_type == at::kXPU) {
  71. return at::detail::getXPUHooks();
  72. } else if (opt_device_type == at::kMPS) {
  73. return at::detail::getMPSHooks();
  74. } else if (opt_device_type == at::kPrivateUse1) {
  75. return at::detail::getPrivateUse1Hooks();
  76. } else if (opt_device_type == at::kMTIA) {
  77. return at::detail::getMTIAHooks();
  78. } else if (opt_device_type == at::kHIP) {
  79. return at::detail::getHIPHooks();
  80. } else if (opt_device_type == at::kHPU) {
  81. return at::detail::getHPUHooks();
  82. } else if (opt_device_type == at::kXLA) {
  83. return at::detail::getXLAHooks();
  84. } else {
  85. TORCH_CHECK(
  86. false,
  87. opt_device_type.has_value()
  88. ? c10::DeviceTypeName(opt_device_type.value())
  89. : "None",
  90. " device type not an accelerator.");
  91. }
  92. }
  93. Device getDeviceFromPtr(void* data, c10::DeviceType device_type) {
  94. lazyInitDevice(device_type);
  95. if (device_type == at::kCPU) {
  96. return c10::DeviceType::CPU;
  97. } else {
  98. return getAcceleratorHooksInterface(device_type).getDeviceFromPtr(data);
  99. }
  100. }
  101. bool isPinnedPtr(
  102. const void* data,
  103. std::optional<c10::DeviceType> device_type = std::nullopt) {
  104. auto opt_device_type =
  105. device_type.has_value() ? device_type : at::getAccelerator();
  106. if (!opt_device_type.has_value() || // there is no accelerator
  107. !at::isAccelerator(
  108. opt_device_type.value())) { // passed device not an accelerator
  109. return false;
  110. }
  111. if (!init_[static_cast<int8_t>(opt_device_type.value())].test_once()) {
  112. // If the device is not initialized, no pointer can be pinned for it
  113. return false;
  114. }
  115. return getAcceleratorHooksInterface(opt_device_type).isPinnedPtr(data);
  116. }
  117. Allocator* getPinnedMemoryAllocator(
  118. std::optional<c10::DeviceType> device_type = std::nullopt) {
  119. auto opt_device_type =
  120. device_type.has_value() ? device_type : at::getAccelerator();
  121. if (opt_device_type) {
  122. lazyInitDevice(opt_device_type.value());
  123. }
  124. return getAcceleratorHooksInterface(device_type).getPinnedMemoryAllocator();
  125. }
  126. void lazyInitDevice(c10::DeviceType device_type) {
  127. if (device_type != at::kCPU) {
  128. c10::call_once(init_[static_cast<int8_t>(device_type)], [&] {
  129. getAcceleratorHooksInterface(device_type).init();
  130. });
  131. }
  132. }
  // Static capability queries that are defined out of line.
  // NOTE(review): presumably each reports whether the named library/feature is
  // available in this build -- confirm against the implementation file.
  static bool hasOpenMP();
  static bool hasMKL();
  static bool hasKleidiAI();
  static bool hasLAPACK();
  static bool hasMKLDNN();
  static bool ckSupported();
  static bool hasEigenSparse();
  140. static bool hasMAGMA() {
  141. return detail::getCUDAHooks().hasMAGMA();
  142. }
  143. static bool hasCUDA() {
  144. return detail::getCUDAHooks().hasCUDA();
  145. }
  146. static bool hasMTIA() {
  147. return detail::getMTIAHooks().hasMTIA();
  148. }
  149. static bool hasCUDART() {
  150. return detail::getCUDAHooks().hasCUDART();
  151. }
  152. static long versionCUDART() {
  153. return detail::getCUDAHooks().versionCUDART();
  154. }
  155. static bool hasCuDNN() {
  156. return detail::getCUDAHooks().hasCuDNN();
  157. }
  158. static long versionCuDNN() {
  159. return detail::getCUDAHooks().versionCuDNN();
  160. }
  161. static long versionRuntimeCuDNN() {
  162. return detail::getCUDAHooks().versionRuntimeCuDNN();
  163. }
  164. static long versionCuDNNFrontend() {
  165. return detail::getCUDAHooks().versionCuDNNFrontend();
  166. }
  167. static bool hasCuSOLVER() {
  168. return detail::getCUDAHooks().hasCuSOLVER();
  169. }
  170. static bool hasCuBLASLt() {
  171. return detail::getCUDAHooks().hasCuBLASLt();
  172. }
  173. static bool hasROCM() {
  174. return detail::getCUDAHooks().hasROCM();
  175. }
  176. static bool hasCKSDPA() {
  177. return detail::getCUDAHooks().hasCKSDPA();
  178. }
  179. static bool hasCKGEMM() {
  180. return detail::getCUDAHooks().hasCKGEMM();
  181. }
  182. static bool hasHIP() {
  183. return detail::getHIPHooks().hasHIP();
  184. }
  185. static bool hasMPS() {
  186. return detail::getMPSHooks().hasMPS();
  187. }
  188. static bool hasIPU() {
  189. return c10::impl::hasDeviceGuardImpl(c10::DeviceType::IPU);
  190. }
  191. static bool hasXLA() {
  192. return detail::getXLAHooks().hasXLA();
  193. }
  194. static bool hasXPU() {
  195. return detail::getXPUHooks().hasXPU();
  196. }
  197. static bool hasLazy() {
  198. return c10::impl::hasDeviceGuardImpl(c10::DeviceType::Lazy);
  199. }
  200. static bool hasMAIA() {
  201. return c10::impl::hasDeviceGuardImpl(c10::DeviceType::MAIA);
  202. }
  203. static bool hasHPU() {
  204. return detail::getHPUHooks().hasHPU();
  205. }
  206. static const at::cuda::NVRTC& getNVRTC() {
  207. return detail::getCUDAHooks().nvrtc();
  208. }
  209. static const at::xpu::LevelZero& getLevelZero() {
  210. return detail::getXPUHooks().level_zero();
  211. }
  212. static bool setFlushDenormal(bool on);
  213. // NB: This method is *purely* whether or not a user requested
  214. // that CuDNN was enabled, it doesn't actually say anything about
  215. // whether or not CuDNN is actually usable. Use cudnn_is_acceptable
  216. // to test this instead
  217. bool userEnabledCuDNN() const;
  218. void setUserEnabledCuDNN(bool e);
  219. bool userEnabledMkldnn() const;
  220. void setUserEnabledMkldnn(bool e);
  221. bool benchmarkCuDNN() const;
  222. void setBenchmarkCuDNN(bool /*b*/);
  223. int benchmarkLimitCuDNN() const;
  224. void setBenchmarkLimitCuDNN(int /*b*/);
  225. bool immediateMiopen() const;
  226. void setImmediateMiopen(bool /*b*/);
  227. bool deterministicCuDNN() const;
  228. void setDeterministicCuDNN(bool /*b*/);
  229. bool deterministicMkldnn() const;
  230. void setDeterministicMkldnn(bool /*b*/);
  231. bool userEnabledNNPACK() const;
  232. void setUserEnabledNNPACK(bool e);
  233. // Note [Disabling Fused SDP Kernels]
  234. // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  235. // Flash and Memory Efficient SDP kernels are enabled by default.
  236. // However, they can be disabled by setting
  237. // at::globalContext().setUserEnabledFlashSDP(false) flag.
  238. // This is useful for debugging purposes. For example, if you want to
  239. // compare the performance of the flash SDP kernels with the unfused
  240. // kernel, you can disable the flash SDP kernels. By disabling
  241. // the math SDP kernel, you can force your code to use flash kernels.
  242. // The math SDP kernel can be disabled by setting
  243. // at::globalContext().setUserEnabledMathSDP(false) flag.
  244. void setSDPPriorityOrder(const std::vector<int64_t>& order);
  245. std::array<at::SDPBackend, at::num_sdp_backends> sDPPriorityOrder();
  246. void setSDPUseFlash(bool /*e*/);
  247. bool userEnabledFlashSDP() const;
  248. void setSDPUseFA3(bool /*e*/);
  249. bool userEnabledFA3SDP() const;
  250. void setSDPUseMemEfficient(bool /*e*/);
  251. bool userEnabledMemEfficientSDP() const;
  252. void setSDPUseMath(bool /*e*/);
  253. bool userEnabledMathSDP() const;
  254. void setSDPUseCuDNN(bool /*e*/);
  255. bool userEnabledCuDNNSDP() const;
  256. void setAllowFP16BF16ReductionMathSDP(bool /*e*/);
  257. bool allowFP16BF16ReductionMathSDP() const;
  258. void setSDPUseOverrideable(bool /*e*/);
  259. bool userEnabledOverrideableSDP() const;
  260. at::LinalgBackend linalgPreferredBackend() const;
  261. void setLinalgPreferredBackend(at::LinalgBackend /*b*/);
  262. at::BlasBackend blasPreferredBackend();
  263. void setBlasPreferredBackend(at::BlasBackend /*b*/);
  264. at::ROCmFABackend getROCmFAPreferredBackend();
  265. void setROCmFAPreferredBackend(at::ROCmFABackend /*b*/);
  266. // Note [Enabling Deterministic Operations]
  267. // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  268. // Operations in PyTorch that normally act nondeterministically, but have an
  269. // alternate deterministic implementation, should satisfy the following
  270. // requirements:
  271. //
  272. // * Include this comment: "See Note [Enabling Deterministic Operations]"
  273. //
  274. // * Check the value of `at::globalContext().deterministicAlgorithms()` to
  275. // toggle
  276. // between nondeterministic and deterministic implementations.
  277. //
  278. // * Have an entry in the list of PyTorch operations that toggle between
  279. // nondeterministic
  280. // and deterministic implementations, in the docstring of
  281. // `use_deterministic_algorithms()` in torch/__init__.py
  282. //
  283. // `example_func()` below shows an example of toggling between
  284. // nondeterministic and deterministic implementations:
  285. //
  286. // void example_func() {
  287. // // See Note [Enabling Deterministic Operations]
  288. // if (at::globalContext().deterministicAlgorithms()) {
  289. // example_func_deterministic();
  290. // } else {
  291. // example_func_nondeterministic();
  292. // }
  293. // }
  294. bool deterministicAlgorithms() const;
  295. bool deterministicAlgorithmsWarnOnly() const;
  296. void setDeterministicAlgorithms(bool /*b*/, bool /*warn_only*/);
  297. bool deterministicFillUninitializedMemory() const;
  298. void setDeterministicFillUninitializedMemory(bool /*b*/);
  // Note [Writing Nondeterministic Operations]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Operations in PyTorch that act nondeterministically and do not have an
  // alternate deterministic implementation should satisfy the following
  // requirements:
  //
  // * Include this comment: "See Note [Writing Nondeterministic Operations]"
  //
  // * Include a comment explaining why the operation is nondeterministic.
  //
  // * Throw an error when `Context::deterministicAlgorithms()` is true. Most
  //   of the time, this should be accomplished by calling
  //   `at::globalContext().alertNotDeterministic()`.
  //
  // * Have an entry in the list of nondeterministic PyTorch operations in the
  //   docstring of `use_deterministic_algorithms()` in torch/__init__.py
  //
  // * Have a test function in `test/test_torch.py` whose name begins with
  //   `test_nondeterministic_alert_`. Alternatively, if CuBLAS workspace
  //   configuration is the reason for nondeterminism, the operation should be
  //   included in the `test_cublas_config_nondeterministic_alert` test. Any
  //   new tests should ideally follow a pattern similar to the existing ones.
  //
  // `example_func()` below shows an example of the comments and
  // error-throwing code for a nondeterministic operation:
  //
  //    void example_func() {
  //      // See Note [Writing Nondeterministic Operations]
  //      // Nondeterministic because <reason>
  //      at::globalContext().alertNotDeterministic("example_func");
  //      ...
  //    }

  // Throws an error if `Context::deterministicAlgorithms()` is true
  static void alertNotDeterministic(std::string_view const& caller);
  333. void setFloat32MatmulPrecision(const std::string& s);
  334. void setFloat32Precision(
  335. Float32Backend backend,
  336. Float32Op op,
  337. Float32Precision p);
  338. bool allowTF32CuDNN(std::optional<Float32Op> op = std::nullopt) const;
  339. void setAllowTF32CuDNN(bool /*b*/);
  340. bool allowTF32OneDNN() const;
  341. void setAllowTF32OneDNN(bool /*b*/);
  342. bool allowTF32CuBLAS() const;
  343. void setAllowTF32CuBLAS(bool /*b*/);
  344. Float32MatmulPrecision float32MatmulPrecision() const;
  345. Float32Precision float32Precision(Float32Backend backend, Float32Op op) const;
  346. CuBLASReductionOption allowFP16ReductionCuBLAS() const;
  347. void setAllowFP16ReductionCuBLAS(
  348. bool allow_reduced_precision,
  349. bool allow_splitk = true);
  350. CuBLASReductionOption allowBF16ReductionCuBLAS() const;
  351. void setAllowBF16ReductionCuBLAS(
  352. bool allow_reduced_precision,
  353. bool allow_splitk = true);
  354. bool allowFP16AccumulationCuBLAS() const;
  355. void setAllowFP16AccumulationCuBLAS(bool /*b*/);
  356. bool rocmAllowGroupGemmCk() const;
  357. // Matmuls can use a so-called "persistent" kernel which launches one CUDA
  358. // block for each SM on the GPU, and each block then iterates over multiple
  359. // output tiles. This allows to use software pipelining to hide the begin/end
  360. // latencies (e.g., epilogue), especially when only one tile fits per SM.
  361. // However, if some SMs are busy (e.g., with a background NCCL kernel), the
  362. // matmul's blocks will be scheduled in two waves and, in the absence of some
  363. // smart load balancing, the kernel will take twice as long. This flag allows
  364. // to make matmuls target only a subset of the SMs, so they can fully schedule
  365. // even next to a comms kernel, and only be a few percent slower.
  366. std::optional<int32_t> _SMCarveout_EXPERIMENTAL() const;
  367. void _setSMCarveout_EXPERIMENTAL(std::optional<int32_t> /*c*/);
  368. at::QEngine qEngine() const;
  369. void setQEngine(at::QEngine e);
  370. static const std::vector<at::QEngine>& supportedQEngines();
  371. static bool isXNNPACKAvailable();
  372. void setCheckSparseTensorInvariants(std::optional<bool> e);
  373. std::optional<bool> checkSparseTensorInvariants(
  374. bool warn_when_uninitialized = false) const;
  375. // This method is used to release the original weight after pre-packing.
  376. // It should be called once before loading/running the model.
  377. // NB: By default it is set to true for mobile builds.
  378. void setReleaseWeightsWhenPrepacking(bool e);
  379. bool releaseWeightsWhenPrepacking() const;
  380. void setDisplayVmapFallbackWarnings(bool enabled);
  381. bool areVmapFallbackWarningsEnabled() const;
  382. void setWarnOnAccumulateGradStreamMismatch(bool enabled);
  383. bool warnOnAccumulateGradStreamMismatch() const;
  384. bool isDefaultMobileCPUAllocatorSet();
  385. void setDefaultMobileCPUAllocator();
  386. void unsetDefaultMobileCPUAllocator();
  387. bool allowFP16ReductionCPU() const;
  388. void setAllowFP16ReductionCPU(bool /*b*/);
  389. // Preserved for BC
  390. void lazyInitCUDA() {
  391. TORCH_WARN_DEPRECATION(
  392. "lazyInitCUDA is deprecated. Please use lazyInitDevice(at::kCUDA) instead.")
  393. lazyInitDevice(at::kCUDA);
  394. }
  395. void lazyInitHIP() {
  396. TORCH_WARN_DEPRECATION(
  397. "lazyInitHIP is deprecated. Please use lazyInitDevice(at::kHIP) instead.")
  398. lazyInitDevice(at::kHIP);
  399. }
  400. void lazyInitXPU() {
  401. TORCH_WARN_DEPRECATION(
  402. "lazyInitXPU is deprecated. Please use lazyInitDevice(at::kXPU) instead.")
  403. lazyInitDevice(at::kXPU);
  404. }
  405. void lazyInitMTIA() {
  406. TORCH_WARN_DEPRECATION(
  407. "lazyInitMTIA is deprecated. Please use lazyInitDevice(at::kMTIA) instead.")
  408. lazyInitDevice(at::kMTIA);
  409. }
  410. void lazyInitPrivateUse1() {
  411. TORCH_WARN_DEPRECATION(
  412. "lazyInitPrivateUse1 is deprecated. Please use lazyInitDevice(at::kPrivateUse1) instead.")
  413. lazyInitDevice(at::kPrivateUse1);
  414. }
  415. private:
  416. std::array<c10::once_flag, at::COMPILE_TIME_MAX_DEVICE_TYPES> init_;
  417. bool enabled_cudnn = true;
  418. bool deterministic_cudnn = false;
  419. bool deterministic_mkldnn = false;
  420. bool _deterministic_algorithms = false;
  421. bool _deterministic_algorithms_warn_only = false;
  422. bool _deterministic_fill_uninitialized_memory = true;
  423. std::array<at::SDPBackend, at::num_sdp_backends> sdp_priority_order = {
  424. at::SDPBackend::flash_attention,
  425. at::SDPBackend::efficient_attention,
  426. at::SDPBackend::math,
  427. at::SDPBackend::cudnn_attention,
  428. at::SDPBackend::overrideable};
  429. bool enabled_flashSDP = true;
  430. bool enabled_fa3SDP = false;
  431. bool enabled_mem_efficientSDP = true;
  432. bool enabled_mathSDP = true;
  433. bool enabled_cudnnSDP = true;
  434. bool enabled_overrideable = true;
  435. bool allow_fp16_bf16_reduction_mathSDP = false;
  436. bool benchmark_cudnn = false;
  437. bool immediate_miopen = false;
  438. Float32MatmulPrecision float32_matmul_precision =
  439. c10::utils::check_env("TORCH_ALLOW_TF32_CUBLAS_OVERRIDE") == true
  440. ? at::Float32MatmulPrecision::HIGH
  441. : at::Float32MatmulPrecision::HIGHEST;
  442. int benchmark_limit_cudnn = 10;
  443. bool allow_tf32_cudnn = true;
  444. CuBLASReductionOption allow_fp16_reduction_cublas =
  445. CuBLASReductionOption::AllowReducedPrecisionWithSplitK;
  446. CuBLASReductionOption allow_bf16_reduction_cublas =
  447. CuBLASReductionOption::AllowReducedPrecisionWithSplitK;
  448. bool allow_fp16_accumulation_cublas = false;
  449. std::optional<int32_t> sm_carveout = std::nullopt;
  450. bool enabled_mkldnn = true;
  451. bool allow_tf32_onednn = false;
  452. bool enabled_nnpack = true;
  453. at::LinalgBackend linalg_preferred_backend =
  454. (c10::utils::check_env("TORCH_LINALG_PREFER_CUSOLVER") == true ||
  455. c10::utils::check_env("TORCH_LINALG_PREFER_HIPSOLVER") == true) // alias
  456. ? at::LinalgBackend::Cusolver
  457. : at::LinalgBackend::Default;
  458. at::BlasBackend blas_preferred_backend =
  459. (c10::utils::check_env("TORCH_BLAS_PREFER_CUBLASLT") == true ||
  460. c10::utils::check_env("TORCH_BLAS_PREFER_HIPBLASLT") == true) // alias
  461. ? at::BlasBackend::Cublaslt
  462. : at::BlasBackend::Default;
  463. at::ROCmFABackend rocm_fa_preferred_backend =
  464. c10::utils::check_env("TORCH_ROCM_FA_PREFER_CK") == true
  465. ? at::ROCmFABackend::Ck
  466. : at::ROCmFABackend::Default;
  467. #ifdef C10_MOBILE
  468. bool release_original_weights = true;
  469. #else
  470. bool release_original_weights = false;
  471. #endif
  472. bool display_vmap_fallback_warnings_ = false;
  473. bool warn_on_accumulate_grad_stream_mismatch_ = true;
  474. std::atomic<at::QEngine> quantized_engine = at::QEngine::NoQEngine;
  475. std::optional<bool> enable_sparse_tensor_invariant_checks = std::nullopt;
  476. bool allow_fp16_reduction_cpu = false;
  477. using Key = std::pair<Float32Backend, Float32Op>;
  478. std::unordered_map<Key, Float32Precision, c10::hash<Key>> fp32_precision = {
  479. {{Float32Backend::GENERIC, Float32Op::ALL}, Float32Precision::NONE},
  480. {{Float32Backend::MKLDNN, Float32Op::ALL}, Float32Precision::NONE},
  481. {{Float32Backend::MKLDNN, Float32Op::CONV}, Float32Precision::NONE},
  482. {{Float32Backend::MKLDNN, Float32Op::RNN}, Float32Precision::NONE},
  483. {{Float32Backend::MKLDNN, Float32Op::MATMUL}, Float32Precision::NONE},
  484. {{Float32Backend::CUDA, Float32Op::ALL}, Float32Precision::NONE},
  485. {{Float32Backend::CUDA, Float32Op::CONV}, Float32Precision::TF32},
  486. {{Float32Backend::CUDA, Float32Op::RNN}, Float32Precision::TF32},
  487. {{Float32Backend::CUDA, Float32Op::MATMUL},
  488. float32_matmul_precision == at::Float32MatmulPrecision::HIGHEST
  489. ? Float32Precision::NONE
  490. : Float32Precision::TF32},
  491. };
  492. Allocator* prev_allocator_ptr_{nullptr};
  493. };
  494. TORCH_API Context& globalContext();
  495. inline void init() {
  496. globalContext();
  497. }
  498. TORCH_API Allocator* getCPUAllocator();
  499. inline DeprecatedTypeProperties& getDeprecatedTypeProperties(
  500. Backend p,
  501. ScalarType s) {
  502. return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
  503. p, s);
  504. }
  505. inline DeprecatedTypeProperties& CPU(ScalarType s) {
  506. return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
  507. Backend::CPU, s);
  508. }
  509. inline DeprecatedTypeProperties& CUDA(ScalarType s) {
  510. return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
  511. Backend::CUDA, s);
  512. }
  513. inline DeprecatedTypeProperties& HIP(ScalarType s) {
  514. return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
  515. Backend::HIP, s);
  516. }
  517. inline DeprecatedTypeProperties& MPS(ScalarType s) {
  518. return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
  519. Backend::MPS, s);
  520. }
  521. inline bool hasCUDA() {
  522. return globalContext().hasCUDA();
  523. }
  524. inline bool hasMTIA() {
  525. return globalContext().hasMTIA();
  526. }
  527. inline bool hasHIP() {
  528. return globalContext().hasHIP();
  529. }
  530. inline bool hasIPU() {
  531. return globalContext().hasIPU();
  532. }
  533. inline bool hasXLA() {
  534. return globalContext().hasXLA();
  535. }
  536. inline bool hasMPS() {
  537. return globalContext().hasMPS();
  538. }
  539. inline bool hasMAIA() {
  540. return globalContext().hasMAIA();
  541. }
  542. inline bool hasXPU() {
  543. return globalContext().hasXPU();
  544. }
  545. inline bool hasHPU() {
  546. return globalContext().hasHPU();
  547. }
  548. // Despite its name, this function returns the number of *CUDA* GPUs.
  549. inline size_t getNumGPUs() {
  550. // WARNING: DO NOT ADD LOGIC TO HANDLE OTHER DEVICE TYPES TO THIS
  551. // FUNCTION. If you are interested in interrogating the number of
  552. // devices for a specific device type, add that function to the
  553. // relevant library (e.g., similar to at::cuda::device_count())
  554. if (hasCUDA() && hasHIP()) {
  555. TORCH_CHECK(
  556. false,
  557. "Enabling both CUDA and HIP in ATen is not supported, as HIP masquerades "
  558. "to be CUDA (e.g., when you say CUDA, on a HIP build of ATen, this actually "
  559. "means HIP. Rebuild PyTorch with one or the other disabled.");
  560. } else if (hasCUDA()) {
  561. return detail::getCUDAHooks().deviceCount();
  562. } else if (hasHIP()) {
  563. return detail::getHIPHooks().getNumGPUs();
  564. } else {
  565. return 0;
  566. }
  567. }
  568. inline bool hasOpenMP() {
  569. return globalContext().hasOpenMP();
  570. }
  571. inline bool hasMKL() {
  572. return globalContext().hasMKL();
  573. }
  574. inline bool hasKleidiAI() {
  575. return globalContext().hasKleidiAI();
  576. }
  577. inline bool hasLAPACK() {
  578. return globalContext().hasLAPACK();
  579. }
  580. inline bool hasEigenSparse() {
  581. return globalContext().hasEigenSparse();
  582. }
  583. inline bool hasMAGMA() {
  584. return globalContext().hasMAGMA();
  585. }
  586. inline bool hasMKLDNN() {
  587. return globalContext().hasMKLDNN();
  588. }
  589. inline void manual_seed(uint64_t seed) {
  590. {
  591. auto gen = globalContext().defaultGenerator(c10::DeviceType::CPU);
  592. // See Note [Acquire lock when using random generators]
  593. std::lock_guard<std::mutex> lock(gen.mutex());
  594. gen.set_current_seed(seed);
  595. }
  596. const auto opt_device_type = at::getAccelerator();
  597. if (!opt_device_type.has_value()) {
  598. return;
  599. }
  600. const auto num_gpus = globalContext()
  601. .getAcceleratorHooksInterface(opt_device_type)
  602. .deviceCount();
  603. for (const auto i : c10::irange(num_gpus)) {
  604. auto gen = globalContext().defaultGenerator(
  605. Device(opt_device_type.value(), static_cast<c10::DeviceIndex>(i)));
  606. {
  607. // See Note [Acquire lock when using random generators]
  608. std::lock_guard<std::mutex> lock(gen.mutex());
  609. gen.set_current_seed(seed);
  610. }
  611. }
  612. }
  613. // When the global flag `allow_tf32` is set to true, cuBLAS handles are
  614. // automatically configured to use math mode CUBLAS_TF32_TENSOR_OP_MATH.
  615. // For some operators, such as addmv, TF32 offers no performance improvement
  616. // but causes precision loss. To help this case, this class implements
  617. // a RAII guard that can be used to quickly disable TF32 within its scope.
  618. //
  619. // Usage:
  620. // NoTF32Guard disable_tf32;
  621. struct TORCH_API NoTF32Guard {
  622. NoTF32Guard();
  623. NoTF32Guard(NoTF32Guard&& other) = delete;
  624. NoTF32Guard(const NoTF32Guard&) = delete;
  625. NoTF32Guard& operator=(const NoTF32Guard&) = delete;
  626. NoTF32Guard& operator=(NoTF32Guard&&) = delete;
  627. ~NoTF32Guard();
  628. static bool should_disable_tf32();
  629. private:
  630. bool changed = false;
  631. };
  632. struct TORCH_API ROCmBackwardPassGuard {
  633. ROCmBackwardPassGuard();
  634. ROCmBackwardPassGuard(ROCmBackwardPassGuard&& other) = delete;
  635. ROCmBackwardPassGuard(const ROCmBackwardPassGuard&) = delete;
  636. ROCmBackwardPassGuard& operator=(const ROCmBackwardPassGuard&) = delete;
  637. ROCmBackwardPassGuard& operator=(ROCmBackwardPassGuard&&) = delete;
  638. ~ROCmBackwardPassGuard();
  639. static bool is_backward_pass();
  640. };
  641. } // namespace at
  642. #else
  643. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  644. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)