| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154 |
- #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
- #pragma once
- // This header provides C++ wrappers around commonly used CUDA API functions.
- // The benefit of using C++ here is that we can raise an exception in the
- // event of an error, rather than explicitly pass around error codes. This
- // leads to more natural APIs.
- //
- // The naming convention used here matches the naming convention of torch.cuda
#include <c10/core/Device.h>
#include <c10/core/impl/GPUTrace.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAMacros.h>
#include <cuda_runtime_api.h>

#include <cstdint>
#include <optional>
- namespace c10::cuda {
- // NB: In the past, we were inconsistent about whether or not this reported
- // an error if there were driver problems are not. Based on experience
- // interacting with users, it seems that people basically ~never want this
- // function to fail; it should just return zero if things are not working.
- // Oblige them.
- // It still might log a warning for user first time it's invoked
- C10_CUDA_API DeviceIndex device_count() noexcept;
- // Version of device_count that throws is no devices are detected
- C10_CUDA_API DeviceIndex device_count_ensure_non_zero();
- C10_CUDA_API DeviceIndex current_device();
- C10_CUDA_API void set_device(DeviceIndex device, const bool force = false);
- C10_CUDA_API void device_synchronize();
- C10_CUDA_API void warn_or_error_on_sync();
- // Raw CUDA device management functions
- C10_CUDA_API cudaError_t GetDeviceCount(int* dev_count);
- C10_CUDA_API cudaError_t GetDevice(DeviceIndex* device);
- C10_CUDA_API cudaError_t
- SetDevice(DeviceIndex device, const bool force = false);
- C10_CUDA_API cudaError_t MaybeSetDevice(DeviceIndex device);
- C10_CUDA_API DeviceIndex ExchangeDevice(DeviceIndex device);
- C10_CUDA_API DeviceIndex MaybeExchangeDevice(DeviceIndex device);
- C10_CUDA_API void SetTargetDevice();
- enum class SyncDebugMode { L_DISABLED = 0, L_WARN, L_ERROR };
- // this is a holder for c10 global state (similar to at GlobalContext)
- // currently it's used to store cuda synchronization warning state,
- // but can be expanded to hold other related global state, e.g. to
- // record stream usage
- class WarningState {
- public:
- void set_sync_debug_mode(SyncDebugMode l) {
- sync_debug_mode = l;
- }
- SyncDebugMode get_sync_debug_mode() {
- return sync_debug_mode;
- }
- private:
- SyncDebugMode sync_debug_mode = SyncDebugMode::L_DISABLED;
- };
- C10_CUDA_API __inline__ WarningState& warning_state() {
- static WarningState warning_state_;
- return warning_state_;
- }
- // the subsequent functions are defined in the header because for performance
- // reasons we want them to be inline
- C10_CUDA_API void __inline__ memcpy_and_sync(
- void* dst,
- const void* src,
- int64_t nbytes,
- cudaMemcpyKind kind,
- cudaStream_t stream) {
- if (C10_UNLIKELY(
- warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) {
- warn_or_error_on_sync();
- }
- const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
- if (C10_UNLIKELY(interp)) {
- (*interp)->trace_gpu_stream_synchronization(
- c10::kCUDA, reinterpret_cast<uintptr_t>(stream));
- }
- #if defined(USE_ROCM) && USE_ROCM
- // As of ROCm 6.4.1, HIP runtime does not raise an error during capture of
- // hipMemcpyWithStream which is a synchronous call. Thus, we add a check
- // here explicitly.
- hipStreamCaptureStatus captureStatus;
- C10_CUDA_CHECK(hipStreamGetCaptureInfo(stream, &captureStatus, nullptr));
- if (C10_LIKELY(captureStatus == hipStreamCaptureStatusNone)) {
- C10_CUDA_CHECK(hipMemcpyWithStream(dst, src, nbytes, kind, stream));
- } else {
- C10_CUDA_CHECK(hipErrorStreamCaptureUnsupported);
- }
- #else
- C10_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream));
- C10_CUDA_CHECK(cudaStreamSynchronize(stream));
- #endif
- }
- C10_CUDA_API void __inline__ stream_synchronize(cudaStream_t stream) {
- if (C10_UNLIKELY(
- warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) {
- warn_or_error_on_sync();
- }
- const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
- if (C10_UNLIKELY(interp)) {
- (*interp)->trace_gpu_stream_synchronization(
- c10::kCUDA, reinterpret_cast<uintptr_t>(stream));
- }
- C10_CUDA_CHECK(cudaStreamSynchronize(stream));
- }
- C10_CUDA_API bool hasPrimaryContext(DeviceIndex device_index);
- C10_CUDA_API std::optional<DeviceIndex> getDeviceIndexWithPrimaryContext();
- } // namespace c10::cuda
#if defined(USE_ROCM)
// Backward compatibility between hipify v1 and v2 for external projects:
// re-export the following c10::cuda functions under c10::hip.
namespace c10::hip {
// Device enumeration and selection.
using c10::cuda::current_device;
using c10::cuda::device_count;
using c10::cuda::device_count_ensure_non_zero;
using c10::cuda::set_device;

// Raw device management.
using c10::cuda::ExchangeDevice;
using c10::cuda::GetDevice;
using c10::cuda::GetDeviceCount;
using c10::cuda::MaybeExchangeDevice;
using c10::cuda::MaybeSetDevice;
using c10::cuda::SetDevice;
using c10::cuda::SetTargetDevice;

// Synchronization helpers.
using c10::cuda::device_synchronize;
using c10::cuda::memcpy_and_sync;
using c10::cuda::stream_synchronize;
using c10::cuda::warn_or_error_on_sync;

// Primary-context queries.
using c10::cuda::getDeviceIndexWithPrimaryContext;
using c10::cuda::hasPrimaryContext;
} // namespace c10::hip
#endif // defined(USE_ROCM)
- #else
- #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
- #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|