CUDAFunctions.h 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  // This header provides C++ wrappers around commonly used CUDA API functions.
  // The benefit of using C++ here is that we can raise an exception in the
  // event of an error, rather than explicitly passing around error codes. This
  // leads to more natural APIs.
  //
  // The naming convention used here matches the naming convention of torch.cuda.
  #include <c10/core/Device.h>
  #include <c10/core/impl/GPUTrace.h>
  #include <c10/cuda/CUDAException.h>
  #include <c10/cuda/CUDAMacros.h>
  #include <cuda_runtime_api.h>

  #include <optional>
  14. namespace c10::cuda {
  15. // NB: In the past, we were inconsistent about whether or not this reported
  16. // an error if there were driver problems are not. Based on experience
  17. // interacting with users, it seems that people basically ~never want this
  18. // function to fail; it should just return zero if things are not working.
  19. // Oblige them.
  20. // It still might log a warning for user first time it's invoked
  21. C10_CUDA_API DeviceIndex device_count() noexcept;
  22. // Version of device_count that throws is no devices are detected
  23. C10_CUDA_API DeviceIndex device_count_ensure_non_zero();
  24. C10_CUDA_API DeviceIndex current_device();
  25. C10_CUDA_API void set_device(DeviceIndex device, const bool force = false);
  26. C10_CUDA_API void device_synchronize();
  27. C10_CUDA_API void warn_or_error_on_sync();
  28. // Raw CUDA device management functions
  29. C10_CUDA_API cudaError_t GetDeviceCount(int* dev_count);
  30. C10_CUDA_API cudaError_t GetDevice(DeviceIndex* device);
  31. C10_CUDA_API cudaError_t
  32. SetDevice(DeviceIndex device, const bool force = false);
  33. C10_CUDA_API cudaError_t MaybeSetDevice(DeviceIndex device);
  34. C10_CUDA_API DeviceIndex ExchangeDevice(DeviceIndex device);
  35. C10_CUDA_API DeviceIndex MaybeExchangeDevice(DeviceIndex device);
  36. C10_CUDA_API void SetTargetDevice();
  // How synchronizing calls are reported. The inline helpers in this header
  // call warn_or_error_on_sync() whenever the mode is not L_DISABLED.
  enum class SyncDebugMode { L_DISABLED = 0, L_WARN, L_ERROR };

  // Holder for c10 global CUDA state (similar to at::GlobalContext).
  // Currently it only stores the CUDA synchronization debug mode, but it can
  // be expanded to hold other related global state, e.g. to record stream
  // usage.
  //
  // NOTE(review): sync_debug_mode is a plain, non-atomic field; nothing here
  // synchronizes concurrent set/get calls from different threads.
  class WarningState {
   public:
    // Replaces the current synchronization debug mode with `mode`.
    void set_sync_debug_mode(SyncDebugMode mode) noexcept {
      sync_debug_mode = mode;
    }

    // Returns the current synchronization debug mode. Marked const so the
    // mode can be queried through a const WarningState reference or pointer.
    SyncDebugMode get_sync_debug_mode() const noexcept {
      return sync_debug_mode;
    }

   private:
    // Starts disabled, so sync checks are skipped until a mode is set.
    SyncDebugMode sync_debug_mode = SyncDebugMode::L_DISABLED;
  };
  53. C10_CUDA_API __inline__ WarningState& warning_state() {
  54. static WarningState warning_state_;
  55. return warning_state_;
  56. }
  57. // the subsequent functions are defined in the header because for performance
  58. // reasons we want them to be inline
  59. C10_CUDA_API void __inline__ memcpy_and_sync(
  60. void* dst,
  61. const void* src,
  62. int64_t nbytes,
  63. cudaMemcpyKind kind,
  64. cudaStream_t stream) {
  65. if (C10_UNLIKELY(
  66. warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) {
  67. warn_or_error_on_sync();
  68. }
  69. const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
  70. if (C10_UNLIKELY(interp)) {
  71. (*interp)->trace_gpu_stream_synchronization(
  72. c10::kCUDA, reinterpret_cast<uintptr_t>(stream));
  73. }
  74. #if defined(USE_ROCM) && USE_ROCM
  75. // As of ROCm 6.4.1, HIP runtime does not raise an error during capture of
  76. // hipMemcpyWithStream which is a synchronous call. Thus, we add a check
  77. // here explicitly.
  78. hipStreamCaptureStatus captureStatus;
  79. C10_CUDA_CHECK(hipStreamGetCaptureInfo(stream, &captureStatus, nullptr));
  80. if (C10_LIKELY(captureStatus == hipStreamCaptureStatusNone)) {
  81. C10_CUDA_CHECK(hipMemcpyWithStream(dst, src, nbytes, kind, stream));
  82. } else {
  83. C10_CUDA_CHECK(hipErrorStreamCaptureUnsupported);
  84. }
  85. #else
  86. C10_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream));
  87. C10_CUDA_CHECK(cudaStreamSynchronize(stream));
  88. #endif
  89. }
  90. C10_CUDA_API void __inline__ stream_synchronize(cudaStream_t stream) {
  91. if (C10_UNLIKELY(
  92. warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) {
  93. warn_or_error_on_sync();
  94. }
  95. const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
  96. if (C10_UNLIKELY(interp)) {
  97. (*interp)->trace_gpu_stream_synchronization(
  98. c10::kCUDA, reinterpret_cast<uintptr_t>(stream));
  99. }
  100. C10_CUDA_CHECK(cudaStreamSynchronize(stream));
  101. }
  102. C10_CUDA_API bool hasPrimaryContext(DeviceIndex device_index);
  103. C10_CUDA_API std::optional<DeviceIndex> getDeviceIndexWithPrimaryContext();
  104. } // namespace c10::cuda
  #if defined(USE_ROCM)
  // Backward compatibility between hipify v1 and v2 for external projects:
  // the c10::cuda helpers are re-exported under the c10::hip namespace.
  namespace c10::hip {
    // Device count and selection.
    using c10::cuda::current_device;
    using c10::cuda::device_count;
    using c10::cuda::device_count_ensure_non_zero;
    using c10::cuda::set_device;

    // Raw device management functions.
    using c10::cuda::ExchangeDevice;
    using c10::cuda::GetDevice;
    using c10::cuda::GetDeviceCount;
    using c10::cuda::MaybeExchangeDevice;
    using c10::cuda::MaybeSetDevice;
    using c10::cuda::SetDevice;
    using c10::cuda::SetTargetDevice;

    // Primary-context queries.
    using c10::cuda::getDeviceIndexWithPrimaryContext;
    using c10::cuda::hasPrimaryContext;

    // Synchronization helpers.
    using c10::cuda::device_synchronize;
    using c10::cuda::memcpy_and_sync;
    using c10::cuda::stream_synchronize;
    using c10::cuda::warn_or_error_on_sync;
  } // namespace c10::hip
  #endif // defined(USE_ROCM)
  127. #else
  128. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  129. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)