| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721 |
- #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
- #pragma once
- #include <ATen/BlasBackend.h>
- #include <ATen/CPUGeneratorImpl.h>
- #include <ATen/DeviceAccelerator.h>
- #include <ATen/LinalgBackend.h>
- #include <ATen/ROCmFABackend.h>
- #include <ATen/SDPBackend.h>
- #include <ATen/core/ATenGeneral.h>
- #include <ATen/core/DeprecatedTypeProperties.h>
- #include <ATen/core/Generator.h>
- #include <ATen/core/LegacyTypeDispatch.h>
- #include <ATen/detail/AcceleratorHooksInterface.h>
- #include <ATen/detail/CUDAHooksInterface.h>
- #include <ATen/detail/HIPHooksInterface.h>
- #include <ATen/detail/HPUHooksInterface.h>
- #include <ATen/detail/IPUHooksInterface.h>
- #include <ATen/detail/MAIAHooksInterface.h>
- #include <ATen/detail/MPSHooksInterface.h>
- #include <ATen/detail/MTIAHooksInterface.h>
- #include <ATen/detail/PrivateUse1HooksInterface.h>
- #include <ATen/detail/XLAHooksInterface.h>
- #include <ATen/detail/XPUHooksInterface.h>
- #include <c10/core/QEngine.h>
- #include <c10/core/impl/DeviceGuardImplInterface.h>
- #include <c10/util/CallOnce.h>
- #include <c10/util/Exception.h>
- #include <c10/util/env.h>
- #include <c10/util/hash.h>
- #include <c10/util/irange.h>
- #include <cstdint>
- #include <map>
- #include <mutex>
- #include <unordered_map>
- namespace at {
- class Tensor;
- enum class TORCH_API Float32MatmulPrecision { HIGHEST, HIGH, MEDIUM };
- enum class CuBLASReductionOption : uint8_t {
- AllowReducedPrecisionWithSplitK = 0,
- DisallowReducedPrecisionAllowSplitK = 1,
- DisallowReducedPrecisionDisallowSplitK = 2,
- };
- enum class TORCH_API Float32Backend { GENERIC, CUDA, MKLDNN };
- enum class TORCH_API Float32Op { ALL, CONV, RNN, MATMUL };
- enum class TORCH_API Float32Precision { NONE, IEEE, TF32, BF16 };
- TORCH_API Float32Backend str2backend(const std::string& name);
- TORCH_API Float32Op str2op(const std::string& name);
- TORCH_API Float32Precision str2precision(const std::string& name);
- TORCH_API std::string precision2str(Float32Precision prec);
- class TORCH_API Context {
- public:
- Context();
- const Generator& defaultGenerator(Device device) {
- c10::DeviceType device_type = device.type();
- lazyInitDevice(device_type);
- if (device_type == at::kCPU) {
- return at::detail::getDefaultCPUGenerator();
- } else {
- return getAcceleratorHooksInterface(device_type)
- .getDefaultGenerator(device.index());
- }
- }
- const AcceleratorHooksInterface& getAcceleratorHooksInterface(
- std::optional<c10::DeviceType> opt_device_type = std::nullopt) {
- if (!opt_device_type.has_value()) {
- opt_device_type = at::getAccelerator(true);
- }
- if (opt_device_type == at::kCUDA) {
- return at::detail::getCUDAHooks();
- } else if (opt_device_type == at::kXPU) {
- return at::detail::getXPUHooks();
- } else if (opt_device_type == at::kMPS) {
- return at::detail::getMPSHooks();
- } else if (opt_device_type == at::kPrivateUse1) {
- return at::detail::getPrivateUse1Hooks();
- } else if (opt_device_type == at::kMTIA) {
- return at::detail::getMTIAHooks();
- } else if (opt_device_type == at::kHIP) {
- return at::detail::getHIPHooks();
- } else if (opt_device_type == at::kHPU) {
- return at::detail::getHPUHooks();
- } else if (opt_device_type == at::kXLA) {
- return at::detail::getXLAHooks();
- } else {
- TORCH_CHECK(
- false,
- opt_device_type.has_value()
- ? c10::DeviceTypeName(opt_device_type.value())
- : "None",
- " device type not an accelerator.");
- }
- }
- Device getDeviceFromPtr(void* data, c10::DeviceType device_type) {
- lazyInitDevice(device_type);
- if (device_type == at::kCPU) {
- return c10::DeviceType::CPU;
- } else {
- return getAcceleratorHooksInterface(device_type).getDeviceFromPtr(data);
- }
- }
- bool isPinnedPtr(
- const void* data,
- std::optional<c10::DeviceType> device_type = std::nullopt) {
- auto opt_device_type =
- device_type.has_value() ? device_type : at::getAccelerator();
- if (!opt_device_type.has_value() || // there is no accelerator
- !at::isAccelerator(
- opt_device_type.value())) { // passed device not an accelerator
- return false;
- }
- if (!init_[static_cast<int8_t>(opt_device_type.value())].test_once()) {
- // If the device is not initialized, no pointer can be pinned for it
- return false;
- }
- return getAcceleratorHooksInterface(opt_device_type).isPinnedPtr(data);
- }
- Allocator* getPinnedMemoryAllocator(
- std::optional<c10::DeviceType> device_type = std::nullopt) {
- auto opt_device_type =
- device_type.has_value() ? device_type : at::getAccelerator();
- if (opt_device_type) {
- lazyInitDevice(opt_device_type.value());
- }
- return getAcceleratorHooksInterface(device_type).getPinnedMemoryAllocator();
- }
- void lazyInitDevice(c10::DeviceType device_type) {
- if (device_type != at::kCPU) {
- c10::call_once(init_[static_cast<int8_t>(device_type)], [&] {
- getAcceleratorHooksInterface(device_type).init();
- });
- }
- }
- static bool hasOpenMP();
- static bool hasMKL();
- static bool hasKleidiAI();
- static bool hasLAPACK();
- static bool hasMKLDNN();
- static bool ckSupported();
- static bool hasEigenSparse();
- static bool hasMAGMA() {
- return detail::getCUDAHooks().hasMAGMA();
- }
- static bool hasCUDA() {
- return detail::getCUDAHooks().hasCUDA();
- }
- static bool hasMTIA() {
- return detail::getMTIAHooks().hasMTIA();
- }
- static bool hasCUDART() {
- return detail::getCUDAHooks().hasCUDART();
- }
- static long versionCUDART() {
- return detail::getCUDAHooks().versionCUDART();
- }
- static bool hasCuDNN() {
- return detail::getCUDAHooks().hasCuDNN();
- }
- static long versionCuDNN() {
- return detail::getCUDAHooks().versionCuDNN();
- }
- static long versionRuntimeCuDNN() {
- return detail::getCUDAHooks().versionRuntimeCuDNN();
- }
- static long versionCuDNNFrontend() {
- return detail::getCUDAHooks().versionCuDNNFrontend();
- }
- static bool hasCuSOLVER() {
- return detail::getCUDAHooks().hasCuSOLVER();
- }
- static bool hasCuBLASLt() {
- return detail::getCUDAHooks().hasCuBLASLt();
- }
- static bool hasROCM() {
- return detail::getCUDAHooks().hasROCM();
- }
- static bool hasCKSDPA() {
- return detail::getCUDAHooks().hasCKSDPA();
- }
- static bool hasCKGEMM() {
- return detail::getCUDAHooks().hasCKGEMM();
- }
- static bool hasHIP() {
- return detail::getHIPHooks().hasHIP();
- }
- static bool hasMPS() {
- return detail::getMPSHooks().hasMPS();
- }
- static bool hasIPU() {
- return c10::impl::hasDeviceGuardImpl(c10::DeviceType::IPU);
- }
- static bool hasXLA() {
- return detail::getXLAHooks().hasXLA();
- }
- static bool hasXPU() {
- return detail::getXPUHooks().hasXPU();
- }
- static bool hasLazy() {
- return c10::impl::hasDeviceGuardImpl(c10::DeviceType::Lazy);
- }
- static bool hasMAIA() {
- return c10::impl::hasDeviceGuardImpl(c10::DeviceType::MAIA);
- }
- static bool hasHPU() {
- return detail::getHPUHooks().hasHPU();
- }
- static const at::cuda::NVRTC& getNVRTC() {
- return detail::getCUDAHooks().nvrtc();
- }
- static const at::xpu::LevelZero& getLevelZero() {
- return detail::getXPUHooks().level_zero();
- }
- static bool setFlushDenormal(bool on);
- // NB: This method is *purely* whether or not a user requested
- // that CuDNN was enabled, it doesn't actually say anything about
- // whether or not CuDNN is actually usable. Use cudnn_is_acceptable
- // to test this instead
- bool userEnabledCuDNN() const;
- void setUserEnabledCuDNN(bool e);
- bool userEnabledMkldnn() const;
- void setUserEnabledMkldnn(bool e);
- bool benchmarkCuDNN() const;
- void setBenchmarkCuDNN(bool /*b*/);
- int benchmarkLimitCuDNN() const;
- void setBenchmarkLimitCuDNN(int /*b*/);
- bool immediateMiopen() const;
- void setImmediateMiopen(bool /*b*/);
- bool deterministicCuDNN() const;
- void setDeterministicCuDNN(bool /*b*/);
- bool deterministicMkldnn() const;
- void setDeterministicMkldnn(bool /*b*/);
- bool userEnabledNNPACK() const;
- void setUserEnabledNNPACK(bool e);
- // Note [Disabling Fused SDP Kernels]
- // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- // Flash and Memory Efficient SDP kernels are enabled by default.
- // However, they can be disabled by setting
- // at::globalContext().setUserEnabledFlashSDP(false) flag.
- // This is useful for debugging purposes. For example, if you want to
- // compare the performance of the flash SDP kernels with the unfused
- // kernel, you can disable the flash SDP kernels. By disabling
- // the math SDP kernel, you can force your code to use flash kernels.
- // The math SDP kernel can be disabled by setting
- // at::globalContext().setUserEnabledMathSDP(false) flag.
- void setSDPPriorityOrder(const std::vector<int64_t>& order);
- std::array<at::SDPBackend, at::num_sdp_backends> sDPPriorityOrder();
- void setSDPUseFlash(bool /*e*/);
- bool userEnabledFlashSDP() const;
- void setSDPUseFA3(bool /*e*/);
- bool userEnabledFA3SDP() const;
- void setSDPUseMemEfficient(bool /*e*/);
- bool userEnabledMemEfficientSDP() const;
- void setSDPUseMath(bool /*e*/);
- bool userEnabledMathSDP() const;
- void setSDPUseCuDNN(bool /*e*/);
- bool userEnabledCuDNNSDP() const;
- void setAllowFP16BF16ReductionMathSDP(bool /*e*/);
- bool allowFP16BF16ReductionMathSDP() const;
- void setSDPUseOverrideable(bool /*e*/);
- bool userEnabledOverrideableSDP() const;
- at::LinalgBackend linalgPreferredBackend() const;
- void setLinalgPreferredBackend(at::LinalgBackend /*b*/);
- at::BlasBackend blasPreferredBackend();
- void setBlasPreferredBackend(at::BlasBackend /*b*/);
- at::ROCmFABackend getROCmFAPreferredBackend();
- void setROCmFAPreferredBackend(at::ROCmFABackend /*b*/);
- // Note [Enabling Deterministic Operations]
- // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- // Operations in PyTorch that normally act nondeterministically, but have an
- // alternate deterministic implementation, should satisfy the following
- // requirements:
- //
- // * Include this comment: "See Note [Enabling Deterministic Operations]"
- //
- // * Check the value of `at::globalContext().deterministicAlgorithms()` to
- // toggle
- // between nondeterministic and deterministic implementations.
- //
- // * Have an entry in the list of PyTorch operations that toggle between
- // nondeterministic
- // and deterministic implementations, in the docstring of
- // `use_deterministic_algorithms()` in torch/__init__.py
- //
- // `example_func()` below shows an example of toggling between
- // nondeterministic and deterministic implementations:
- //
- // void example_func() {
- // // See Note [Enabling Deterministic Operations]
- // if (at::globalContext().deterministicAlgorithms()) {
- // example_func_deterministic();
- // } else {
- // example_func_nondeterministic();
- // }
- // }
- bool deterministicAlgorithms() const;
- bool deterministicAlgorithmsWarnOnly() const;
- void setDeterministicAlgorithms(bool /*b*/, bool /*warn_only*/);
- bool deterministicFillUninitializedMemory() const;
- void setDeterministicFillUninitializedMemory(bool /*b*/);
- // Note [Writing Nondeterministic Operations]
- // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- // Operations in PyTorch that act nondeterministically and do not have an
- // alternate deterministic implementation should satisfy the following
- // requirements:
- //
- // * Include this comment: "See Note [Writing Nondeterministic Operations]"
- //
- // * Include a comment explaining why the operation is nondeterministic.
- //
- // * Throw an error when `Context::deterministicAlgorithms()` is true. Most
- // of the time, this should be accomplished by calling
- // `at::globalContext().alertNotDeterminstic().
- //
- // * Have an entry in the list of nondeterministic PyTorch operations in the
- // docstring of `use_deterministic_algorithms()` in torch/__init__.py
- //
- // * Have a test function in `test/test_torch.py` whose name begins with
- // `test_nondeterministic_alert_`. Alternatively, if CuBLAS workspace
- // configuration is the reason for nondeterminism, the operation should be
- // included in the `test_cublas_config_nondeterministic_alert` test. Any new
- // tests should ideally follow a pattern similar to the existing ones.
- //
- // `example_func()` below shows an example of the comments and error-throwing
- // code for a nondeterministic operation:
- //
- // void example_func() {
- // // See Note [Writing Nondeterministic Operations]
- // // Nondeterministic because <reason>
- // at::globalContext().alertNondeterministic("example_func");
- // ...
- // }
- // Throws an error if `Context::deterministicAlgorithms()` is true
- static void alertNotDeterministic(std::string_view const& caller);
- void setFloat32MatmulPrecision(const std::string& s);
- void setFloat32Precision(
- Float32Backend backend,
- Float32Op op,
- Float32Precision p);
- bool allowTF32CuDNN(std::optional<Float32Op> op = std::nullopt) const;
- void setAllowTF32CuDNN(bool /*b*/);
- bool allowTF32OneDNN() const;
- void setAllowTF32OneDNN(bool /*b*/);
- bool allowTF32CuBLAS() const;
- void setAllowTF32CuBLAS(bool /*b*/);
- Float32MatmulPrecision float32MatmulPrecision() const;
- Float32Precision float32Precision(Float32Backend backend, Float32Op op) const;
- CuBLASReductionOption allowFP16ReductionCuBLAS() const;
- void setAllowFP16ReductionCuBLAS(
- bool allow_reduced_precision,
- bool allow_splitk = true);
- CuBLASReductionOption allowBF16ReductionCuBLAS() const;
- void setAllowBF16ReductionCuBLAS(
- bool allow_reduced_precision,
- bool allow_splitk = true);
- bool allowFP16AccumulationCuBLAS() const;
- void setAllowFP16AccumulationCuBLAS(bool /*b*/);
- bool rocmAllowGroupGemmCk() const;
- // Matmuls can use a so-called "persistent" kernel which launches one CUDA
- // block for each SM on the GPU, and each block then iterates over multiple
- // output tiles. This allows to use software pipelining to hide the begin/end
- // latencies (e.g., epilogue), especially when only one tile fits per SM.
- // However, if some SMs are busy (e.g., with a background NCCL kernel), the
- // matmul's blocks will be scheduled in two waves and, in the absence of some
- // smart load balancing, the kernel will take twice as long. This flag allows
- // to make matmuls target only a subset of the SMs, so they can fully schedule
- // even next to a comms kernel, and only be a few percent slower.
- std::optional<int32_t> _SMCarveout_EXPERIMENTAL() const;
- void _setSMCarveout_EXPERIMENTAL(std::optional<int32_t> /*c*/);
- at::QEngine qEngine() const;
- void setQEngine(at::QEngine e);
- static const std::vector<at::QEngine>& supportedQEngines();
- static bool isXNNPACKAvailable();
- void setCheckSparseTensorInvariants(std::optional<bool> e);
- std::optional<bool> checkSparseTensorInvariants(
- bool warn_when_uninitialized = false) const;
- // This method is used to release the original weight after pre-packing.
- // It should be called once before loading/running the model.
- // NB: By default it is set to true for mobile builds.
- void setReleaseWeightsWhenPrepacking(bool e);
- bool releaseWeightsWhenPrepacking() const;
- void setDisplayVmapFallbackWarnings(bool enabled);
- bool areVmapFallbackWarningsEnabled() const;
- void setWarnOnAccumulateGradStreamMismatch(bool enabled);
- bool warnOnAccumulateGradStreamMismatch() const;
- bool isDefaultMobileCPUAllocatorSet();
- void setDefaultMobileCPUAllocator();
- void unsetDefaultMobileCPUAllocator();
- bool allowFP16ReductionCPU() const;
- void setAllowFP16ReductionCPU(bool /*b*/);
- // Preserved for BC
- void lazyInitCUDA() {
- TORCH_WARN_DEPRECATION(
- "lazyInitCUDA is deprecated. Please use lazyInitDevice(at::kCUDA) instead.")
- lazyInitDevice(at::kCUDA);
- }
- void lazyInitHIP() {
- TORCH_WARN_DEPRECATION(
- "lazyInitHIP is deprecated. Please use lazyInitDevice(at::kHIP) instead.")
- lazyInitDevice(at::kHIP);
- }
- void lazyInitXPU() {
- TORCH_WARN_DEPRECATION(
- "lazyInitXPU is deprecated. Please use lazyInitDevice(at::kXPU) instead.")
- lazyInitDevice(at::kXPU);
- }
- void lazyInitMTIA() {
- TORCH_WARN_DEPRECATION(
- "lazyInitMTIA is deprecated. Please use lazyInitDevice(at::kMTIA) instead.")
- lazyInitDevice(at::kMTIA);
- }
- void lazyInitPrivateUse1() {
- TORCH_WARN_DEPRECATION(
- "lazyInitPrivateUse1 is deprecated. Please use lazyInitDevice(at::kPrivateUse1) instead.")
- lazyInitDevice(at::kPrivateUse1);
- }
- private:
- std::array<c10::once_flag, at::COMPILE_TIME_MAX_DEVICE_TYPES> init_;
- bool enabled_cudnn = true;
- bool deterministic_cudnn = false;
- bool deterministic_mkldnn = false;
- bool _deterministic_algorithms = false;
- bool _deterministic_algorithms_warn_only = false;
- bool _deterministic_fill_uninitialized_memory = true;
- std::array<at::SDPBackend, at::num_sdp_backends> sdp_priority_order = {
- at::SDPBackend::flash_attention,
- at::SDPBackend::efficient_attention,
- at::SDPBackend::math,
- at::SDPBackend::cudnn_attention,
- at::SDPBackend::overrideable};
- bool enabled_flashSDP = true;
- bool enabled_fa3SDP = false;
- bool enabled_mem_efficientSDP = true;
- bool enabled_mathSDP = true;
- bool enabled_cudnnSDP = true;
- bool enabled_overrideable = true;
- bool allow_fp16_bf16_reduction_mathSDP = false;
- bool benchmark_cudnn = false;
- bool immediate_miopen = false;
- Float32MatmulPrecision float32_matmul_precision =
- c10::utils::check_env("TORCH_ALLOW_TF32_CUBLAS_OVERRIDE") == true
- ? at::Float32MatmulPrecision::HIGH
- : at::Float32MatmulPrecision::HIGHEST;
- int benchmark_limit_cudnn = 10;
- bool allow_tf32_cudnn = true;
- CuBLASReductionOption allow_fp16_reduction_cublas =
- CuBLASReductionOption::AllowReducedPrecisionWithSplitK;
- CuBLASReductionOption allow_bf16_reduction_cublas =
- CuBLASReductionOption::AllowReducedPrecisionWithSplitK;
- bool allow_fp16_accumulation_cublas = false;
- std::optional<int32_t> sm_carveout = std::nullopt;
- bool enabled_mkldnn = true;
- bool allow_tf32_onednn = false;
- bool enabled_nnpack = true;
- at::LinalgBackend linalg_preferred_backend =
- (c10::utils::check_env("TORCH_LINALG_PREFER_CUSOLVER") == true ||
- c10::utils::check_env("TORCH_LINALG_PREFER_HIPSOLVER") == true) // alias
- ? at::LinalgBackend::Cusolver
- : at::LinalgBackend::Default;
- at::BlasBackend blas_preferred_backend =
- (c10::utils::check_env("TORCH_BLAS_PREFER_CUBLASLT") == true ||
- c10::utils::check_env("TORCH_BLAS_PREFER_HIPBLASLT") == true) // alias
- ? at::BlasBackend::Cublaslt
- : at::BlasBackend::Default;
- at::ROCmFABackend rocm_fa_preferred_backend =
- c10::utils::check_env("TORCH_ROCM_FA_PREFER_CK") == true
- ? at::ROCmFABackend::Ck
- : at::ROCmFABackend::Default;
- #ifdef C10_MOBILE
- bool release_original_weights = true;
- #else
- bool release_original_weights = false;
- #endif
- bool display_vmap_fallback_warnings_ = false;
- bool warn_on_accumulate_grad_stream_mismatch_ = true;
- std::atomic<at::QEngine> quantized_engine = at::QEngine::NoQEngine;
- std::optional<bool> enable_sparse_tensor_invariant_checks = std::nullopt;
- bool allow_fp16_reduction_cpu = false;
- using Key = std::pair<Float32Backend, Float32Op>;
- std::unordered_map<Key, Float32Precision, c10::hash<Key>> fp32_precision = {
- {{Float32Backend::GENERIC, Float32Op::ALL}, Float32Precision::NONE},
- {{Float32Backend::MKLDNN, Float32Op::ALL}, Float32Precision::NONE},
- {{Float32Backend::MKLDNN, Float32Op::CONV}, Float32Precision::NONE},
- {{Float32Backend::MKLDNN, Float32Op::RNN}, Float32Precision::NONE},
- {{Float32Backend::MKLDNN, Float32Op::MATMUL}, Float32Precision::NONE},
- {{Float32Backend::CUDA, Float32Op::ALL}, Float32Precision::NONE},
- {{Float32Backend::CUDA, Float32Op::CONV}, Float32Precision::TF32},
- {{Float32Backend::CUDA, Float32Op::RNN}, Float32Precision::TF32},
- {{Float32Backend::CUDA, Float32Op::MATMUL},
- float32_matmul_precision == at::Float32MatmulPrecision::HIGHEST
- ? Float32Precision::NONE
- : Float32Precision::TF32},
- };
- Allocator* prev_allocator_ptr_{nullptr};
- };
- TORCH_API Context& globalContext();
- inline void init() {
- globalContext();
- }
- TORCH_API Allocator* getCPUAllocator();
- inline DeprecatedTypeProperties& getDeprecatedTypeProperties(
- Backend p,
- ScalarType s) {
- return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
- p, s);
- }
- inline DeprecatedTypeProperties& CPU(ScalarType s) {
- return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
- Backend::CPU, s);
- }
- inline DeprecatedTypeProperties& CUDA(ScalarType s) {
- return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
- Backend::CUDA, s);
- }
- inline DeprecatedTypeProperties& HIP(ScalarType s) {
- return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
- Backend::HIP, s);
- }
- inline DeprecatedTypeProperties& MPS(ScalarType s) {
- return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
- Backend::MPS, s);
- }
- inline bool hasCUDA() {
- return globalContext().hasCUDA();
- }
- inline bool hasMTIA() {
- return globalContext().hasMTIA();
- }
- inline bool hasHIP() {
- return globalContext().hasHIP();
- }
- inline bool hasIPU() {
- return globalContext().hasIPU();
- }
- inline bool hasXLA() {
- return globalContext().hasXLA();
- }
- inline bool hasMPS() {
- return globalContext().hasMPS();
- }
- inline bool hasMAIA() {
- return globalContext().hasMAIA();
- }
- inline bool hasXPU() {
- return globalContext().hasXPU();
- }
- inline bool hasHPU() {
- return globalContext().hasHPU();
- }
- // Despite its name, this function returns the number of *CUDA* GPUs.
- inline size_t getNumGPUs() {
- // WARNING: DO NOT ADD LOGIC TO HANDLE OTHER DEVICE TYPES TO THIS
- // FUNCTION. If you are interested in interrogating the number of
- // devices for a specific device type, add that function to the
- // relevant library (e.g., similar to at::cuda::device_count())
- if (hasCUDA() && hasHIP()) {
- TORCH_CHECK(
- false,
- "Enabling both CUDA and HIP in ATen is not supported, as HIP masquerades "
- "to be CUDA (e.g., when you say CUDA, on a HIP build of ATen, this actually "
- "means HIP. Rebuild PyTorch with one or the other disabled.");
- } else if (hasCUDA()) {
- return detail::getCUDAHooks().deviceCount();
- } else if (hasHIP()) {
- return detail::getHIPHooks().getNumGPUs();
- } else {
- return 0;
- }
- }
- inline bool hasOpenMP() {
- return globalContext().hasOpenMP();
- }
- inline bool hasMKL() {
- return globalContext().hasMKL();
- }
- inline bool hasKleidiAI() {
- return globalContext().hasKleidiAI();
- }
- inline bool hasLAPACK() {
- return globalContext().hasLAPACK();
- }
- inline bool hasEigenSparse() {
- return globalContext().hasEigenSparse();
- }
- inline bool hasMAGMA() {
- return globalContext().hasMAGMA();
- }
- inline bool hasMKLDNN() {
- return globalContext().hasMKLDNN();
- }
- inline void manual_seed(uint64_t seed) {
- {
- auto gen = globalContext().defaultGenerator(c10::DeviceType::CPU);
- // See Note [Acquire lock when using random generators]
- std::lock_guard<std::mutex> lock(gen.mutex());
- gen.set_current_seed(seed);
- }
- const auto opt_device_type = at::getAccelerator();
- if (!opt_device_type.has_value()) {
- return;
- }
- const auto num_gpus = globalContext()
- .getAcceleratorHooksInterface(opt_device_type)
- .deviceCount();
- for (const auto i : c10::irange(num_gpus)) {
- auto gen = globalContext().defaultGenerator(
- Device(opt_device_type.value(), static_cast<c10::DeviceIndex>(i)));
- {
- // See Note [Acquire lock when using random generators]
- std::lock_guard<std::mutex> lock(gen.mutex());
- gen.set_current_seed(seed);
- }
- }
- }
- // When the global flag `allow_tf32` is set to true, cuBLAS handles are
- // automatically configured to use math mode CUBLAS_TF32_TENSOR_OP_MATH.
- // For some operators, such as addmv, TF32 offers no performance improvement
- // but causes precision loss. To help this case, this class implements
- // a RAII guard that can be used to quickly disable TF32 within its scope.
- //
- // Usage:
- // NoTF32Guard disable_tf32;
- struct TORCH_API NoTF32Guard {
- NoTF32Guard();
- NoTF32Guard(NoTF32Guard&& other) = delete;
- NoTF32Guard(const NoTF32Guard&) = delete;
- NoTF32Guard& operator=(const NoTF32Guard&) = delete;
- NoTF32Guard& operator=(NoTF32Guard&&) = delete;
- ~NoTF32Guard();
- static bool should_disable_tf32();
- private:
- bool changed = false;
- };
- struct TORCH_API ROCmBackwardPassGuard {
- ROCmBackwardPassGuard();
- ROCmBackwardPassGuard(ROCmBackwardPassGuard&& other) = delete;
- ROCmBackwardPassGuard(const ROCmBackwardPassGuard&) = delete;
- ROCmBackwardPassGuard& operator=(const ROCmBackwardPassGuard&) = delete;
- ROCmBackwardPassGuard& operator=(ROCmBackwardPassGuard&&) = delete;
- ~ROCmBackwardPassGuard();
- static bool is_backward_pass();
- };
- } // namespace at
- #else
- #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
- #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|