CUDAEvent.h 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <c10/core/alignment.h>
  4. #include <c10/core/impl/GPUTrace.h>
  5. #include <c10/cuda/CUDAGuard.h>
  6. #include <c10/cuda/CUDAStream.h>
  7. #include <c10/util/Exception.h>
  8. #include <c10/util/irange.h>
  /*
   * `cudaEventExternal` is a PyTorch-specific event flag, not one defined by
   * the CUDA runtime. Passing it marks a CUDAEvent as being used only to
   * synchronize with work that lives outside of a CUDA graph, as opposed to
   * expressing cross-stream dependencies inside a graph being captured.
   * CUDAEvent removes this bit before handing its flags to
   * cudaEventCreateWithFlags, and instead uses cudaEventRecordExternal /
   * cudaEventWaitExternal while a stream capture is in progress.
   *
   * Background reading:
   * https://docs.nvidia.com/cuda/archive/12.9.0/cuda-c-programming-guide/index.html#cross-stream-dependencies-and-events
   * https://docs.nvidia.com/cuda/archive/12.9.0/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g3457b81d1d32c6a00f6132fbc2693d47
   * https://docs.nvidia.com/cuda/archive/12.9.0/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g0c23426b7252eaa9cef695859991304e
   */
  #define cudaEventExternal 0x8
  19. namespace c10::cuda {
  /*
   * CUDAEvents are movable, non-copyable wrappers around CUDA's events.
   *
   * A CUDAEvent is created lazily when it is first recorded, unless it is
   * reconstructed from a cudaIpcEventHandle_t. The event has a device, and this
   * device is taken from the first recording stream. If the event is
   * reconstructed from a handle, the device must be passed explicitly; if
   * ipc_handle() is called before the event was ever recorded, the current
   * stream's device is used. Later streams that record the event must match
   * this device.
   *
   * Ownership: a CUDAEvent owns at most one cudaEvent_t and destroys it (on the
   * device that created it) when the wrapper is destroyed or move-assigned over.
   */
  struct CUDAEvent {
    // Constructors
    // Default value for `flags` is specified below - it's cudaEventDisableTiming
    CUDAEvent() noexcept = default;
    CUDAEvent(unsigned int flags) noexcept : flags_{flags} {}
    // Opens an event exported by another process; the event is created
    // immediately on `device_index`.
    CUDAEvent(DeviceIndex device_index, const cudaIpcEventHandle_t* handle)
        : device_index_(device_index) {
      CUDAGuard guard(device_index_);
      C10_CUDA_CHECK(cudaIpcOpenEventHandle(&event_, *handle));
      is_created_ = true;
    }

    // Note: event destruction done on creating device to avoid creating a
    // CUDA context on other devices.
    ~CUDAEvent() {
      destroyEvent();
    }

    CUDAEvent(const CUDAEvent&) = delete;
    CUDAEvent& operator=(const CUDAEvent&) = delete;

    CUDAEvent(CUDAEvent&& other) noexcept {
      moveHelper(std::move(other));
    }
    // Move assignment releases the event this object currently owns before
    // taking over `other`'s; overwriting event_ directly would leak it.
    CUDAEvent& operator=(CUDAEvent&& other) noexcept {
      if (this != &other) {
        destroyEvent();
        moveHelper(std::move(other));
      }
      return *this;
    }

    operator cudaEvent_t() const {
      return event();
    }

    // Less than operator (to allow use in sets)
    friend bool operator<(const CUDAEvent& left, const CUDAEvent& right) {
      return left.event_ < right.event_;
    }

    // Returns the owning device, or an empty optional if no cudaEvent_t has
    // been created yet.
    std::optional<c10::Device> device() const {
      if (is_created_) {
        return c10::Device(c10::kCUDA, device_index_);
      } else {
        return {};
      }
    }

    bool isCreated() const {
      return is_created_;
    }
    DeviceIndex device_index() const {
      return device_index_;
    }
    cudaEvent_t event() const {
      return event_;
    }

    // Returns true if the event has completed or was never created, false if
    // it is still pending. Any other CUDA error is raised.
    // Note: cudaEventQuery can be safely called from any device
    bool query() const {
      if (!is_created_) {
        return true;
      }
      cudaError_t err = cudaEventQuery(event_);
      if (err == cudaSuccess) {
        return true;
      } else if (err != cudaErrorNotReady) {
        C10_CUDA_CHECK(err);
      } else {
        // ignore and clear the error if not ready
        (void)cudaGetLastError();
      }
      return false;
    }

    // Records the event on the current CUDA stream.
    void record() {
      record(getCurrentCUDAStream());
    }

    // Records the event on `stream` only if it has never been recorded.
    void recordOnce(const CUDAStream& stream) {
      if (!was_recorded_) {
        record(stream);
      }
    }

    // Records the event on `stream`, creating it on the stream's device first
    // if necessary. Raises if the stream lives on a different device.
    // Note: cudaEventRecord must be called on the same device as the event.
    void record(const CUDAStream& stream) {
      if (!is_created_) {
        createEvent(stream.device_index());
      }
      TORCH_CHECK(
          device_index_ == stream.device_index(),
          "Event device ",
          device_index_,
          " does not match recording stream's device ",
          stream.device_index(),
          ".");
      CUDAGuard guard(device_index_);
#ifndef USE_ROCM
      // It is an error to use cudaEventRecordExternal when not doing stream
      // capture. external_ is tested first so ordinary events skip the
      // capture-status query altogether.
      unsigned int flags = (external_ &&
                            c10::cuda::currentStreamCaptureStatusMayInitCtx() !=
                                c10::cuda::CaptureStatus::None)
          ? cudaEventRecordExternal
          : cudaEventRecordDefault;
      C10_CUDA_CHECK(cudaEventRecordWithFlags(event_, stream, flags));
#else
      C10_CUDA_CHECK(cudaEventRecord(event_, stream));
#endif
      const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
      if (C10_UNLIKELY(interp)) {
        (*interp)->trace_gpu_event_record(
            c10::kCUDA,
            reinterpret_cast<uintptr_t>(event_),
            reinterpret_cast<uintptr_t>(stream.stream()));
      }
      was_recorded_ = true;
    }

    // Makes subsequent work on `stream` wait for this event. Does nothing if
    // the event has never been created.
    // Note: cudaStreamWaitEvent must be called on the same device as the stream.
    // The event has no actual GPU resources associated with it.
    void block(const CUDAStream& stream) {
      if (is_created_) {
        CUDAGuard guard(stream.device_index());
#ifndef USE_ROCM
        // It is an error to use cudaEventWaitExternal when not doing stream
        // capture. external_ is tested first so ordinary events skip the
        // capture-status query altogether.
        unsigned int flags = (external_ &&
                              c10::cuda::currentStreamCaptureStatusMayInitCtx() !=
                                  c10::cuda::CaptureStatus::None)
            ? cudaEventWaitExternal
            : cudaEventWaitDefault;
        C10_CUDA_CHECK(cudaStreamWaitEvent(stream, event_, flags));
#else
        C10_CUDA_CHECK(cudaStreamWaitEvent(stream, event_));
#endif
        const c10::impl::PyInterpreter* interp =
            c10::impl::GPUTrace::get_trace();
        if (C10_UNLIKELY(interp)) {
          (*interp)->trace_gpu_event_wait(
              c10::kCUDA,
              reinterpret_cast<uintptr_t>(event_),
              reinterpret_cast<uintptr_t>(stream.stream()));
        }
      }
    }

    // Returns the time in milliseconds between this event and `other`. Both
    // events must have timing enabled, be recorded, and be complete.
    // Note: cudaEventElapsedTime can be safely called from any device
    float elapsed_time(const CUDAEvent& other) const {
      TORCH_CHECK_VALUE(
          !(flags_ & cudaEventDisableTiming) &&
              !(other.flags_ & cudaEventDisableTiming),
          "Both events must be created with argument 'enable_timing=True'.");
      TORCH_CHECK_VALUE(
          is_created_ && other.isCreated(),
          "Both events must be recorded before calculating elapsed time.");
      TORCH_CHECK(
          query() && other.query(),
          "Both events must be completed before calculating elapsed time.");
      float time_ms = 0;
      // We do not strictly have to set the device index to the same as our event,
      // but if we don't and the current device is not initialized, it will
      // create a new cuda context, which will consume a lot of memory.
      CUDAGuard guard(device_index_);
      // Completion was verified above; any CUDA error reported here is raised.
      C10_CUDA_CHECK(cudaEventElapsedTime(&time_ms, event_, other.event_));
      return time_ms;
    }

    // Blocks the host until the event completes. Does nothing if the event
    // has never been created.
    // Note: cudaEventSynchronize can be safely called from any device
    void synchronize() const {
      if (is_created_) {
        const c10::impl::PyInterpreter* interp =
            c10::impl::GPUTrace::get_trace();
        if (C10_UNLIKELY(interp)) {
          (*interp)->trace_gpu_event_synchronization(
              c10::kCUDA, reinterpret_cast<uintptr_t>(event_));
        }
        C10_CUDA_CHECK(cudaEventSynchronize(event_));
      }
    }

    // Writes an IPC handle for this event into `handle`, creating the event
    // on the current stream's device first if necessary.
    // Note: cudaIpcGetEventHandle must be called on the same device as the event
    void ipc_handle(cudaIpcEventHandle_t* handle) {
      if (!is_created_) {
        // this CUDAEvent object was initially constructed from flags but event_
        // is not created yet.
        createEvent(getCurrentCUDAStream().device_index());
      }
      CUDAGuard guard(device_index_);
      C10_CUDA_CHECK(cudaIpcGetEventHandle(handle, event_));
    }

    // Eagerly creates the event on `device_index` if it does not exist yet.
    void create(DeviceIndex device_index) {
      if (!is_created_) {
        createEvent(device_index);
      }
    }

   private:
    unsigned int flags_ = cudaEventDisableTiming;
    bool is_created_ = false;
    bool was_recorded_ = false;
    bool external_ = false;
    DeviceIndex device_index_ = -1;
    cudaEvent_t event_{};

    // Creates the cudaEvent_t on `device_index` from flags_, after stripping
    // the torch-only cudaEventExternal bit and remembering it in external_.
    void createEvent(DeviceIndex device_index) {
      external_ = (flags_ & cudaEventExternal) != 0;
#ifdef USE_ROCM
      TORCH_CHECK(!external_, "External events are disallowed in rocm");
#endif
      flags_ &= ~cudaEventExternal;
      device_index_ = device_index;
      CUDAGuard guard(device_index_);
      C10_CUDA_CHECK(cudaEventCreateWithFlags(&event_, flags_));
      const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
      if (C10_UNLIKELY(interp)) {
        (*interp)->trace_gpu_event_creation(
            c10::kCUDA, reinterpret_cast<uintptr_t>(event_));
      }
      is_created_ = true;
    }

    // Destroys the owned cudaEvent_t (if any) on the device that created it,
    // then returns this object to the "not created" state. Destruction
    // failures are only warned about.
    void destroyEvent() noexcept {
      if (!is_created_) {
        return;
      }
      CUDAGuard guard(device_index_);
      const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
      if (C10_UNLIKELY(interp)) {
        (*interp)->trace_gpu_event_deletion(
            c10::kCUDA, reinterpret_cast<uintptr_t>(event_));
      }
      C10_CUDA_CHECK_WARN(cudaEventDestroy(event_));
      is_created_ = false;
      was_recorded_ = false;
      event_ = cudaEvent_t{};
    }

    // Transfers ownership of all state from `other` to this. Callers must
    // already have released any event this object owned (see destroyEvent).
    void moveHelper(CUDAEvent&& other) {
      flags_ = other.flags_;
      is_created_ = other.is_created_;
      was_recorded_ = other.was_recorded_;
      external_ = other.external_;
      device_index_ = other.device_index_;
      event_ = other.event_;
      // Leave `other` as a valid empty wrapper so its destructor does not
      // destroy the event that now belongs to this object.
      other.is_created_ = false;
      other.was_recorded_ = false;
      other.event_ = cudaEvent_t{};
    }
  };
  // CUDAEventPool - A thread-safe pool of CUDA events to avoid the overhead of
  // repeatedly calling cudaEventCreate(). Concurrent cudaEventCreate() calls
  // can incur significant cost on some device/driver combinations.
  //
  // This pool maintains per-device lists of pre-created CUDA events.
  // Borrowed events are returned to the pool via a custom unique_ptr deleter.
  //
  // Lifetime: the deleter of a pooled Event holds a reference to the
  // per-device pool it was borrowed from, so a CUDAEventPool must outlive every
  // Event it hands out for a valid device.
  class CUDAEventPool {
   public:
    using Event = std::unique_ptr<
        c10::cuda::CUDAEvent,
        std::function<void(c10::cuda::CUDAEvent*)>>;

    // Creates one pool per device reported by c10::cuda::device_count(); when
    // init_num_events > 0, pre-creates that many events on every device.
    CUDAEventPool(size_t init_num_events = 0)
        : pools_(c10::cuda::device_count()) {
      if (init_num_events > 0) {
        reserve_events_on_pools(init_num_events);
      }
    }

    // Acquire an event associated with a given device. If device is invalid,
    // fall back to a regular CUDAEvent and no pooling.
    Event get(const DeviceIndex device) {
      // Compare in size_t: narrowing pools_.size() to DeviceIndex could wrap.
      if (device < 0 || static_cast<size_t>(device) >= pools_.size()) {
        auto deleter = [](CUDAEvent* event) { delete event; };
        return Event(std::make_unique<CUDAEvent>().release(), deleter);
      }
      auto& pool = pools_[device];
      // Returns the event to the pool it was borrowed from. This runs from
      // unique_ptr's destructor and is noexcept, so it must not throw: if the
      // event cannot be put back (lock or allocation failure) it is destroyed
      // instead of terminating the process.
      auto destructor = [&pool](CUDAEvent* event) noexcept {
        if (event == nullptr) {
          return;
        }
        std::unique_ptr<CUDAEvent> owned(event);
        try {
          std::lock_guard<std::mutex> lock(pool.mutex_);
          // push_back has the strong guarantee: on failure `owned` still holds
          // the event and destroys it when it goes out of scope.
          pool.event_pool_.push_back(std::move(owned));
        } catch (...) {
          // Best effort: dropping one pooled event beats std::terminate.
        }
      };
      {
        std::lock_guard<std::mutex> lock(pool.mutex_);
        if (!pool.event_pool_.empty()) {
          auto event = std::move(pool.event_pool_.back());
          pool.event_pool_.pop_back();
          return Event(event.release(), destructor);
        }
      }
      // Pool is empty: hand out a fresh (not yet created) CUDAEvent that will
      // join this device's pool when released.
      return Event(std::make_unique<CUDAEvent>().release(), destructor);
    }

    // Destroys all idle events held by every device pool. Events currently
    // borrowed are unaffected and are returned to their pool when released.
    void empty_cache() {
      for (auto& pool : pools_) {
        std::vector<std::unique_ptr<CUDAEvent>> drained;
        {
          std::lock_guard<std::mutex> lock(pool.mutex_);
          drained.swap(pool.event_pool_);
        }
        // `drained` is destroyed here, outside the lock, so the
        // cudaEventDestroy calls do not stall concurrent get()/release.
      }
    }

   private:
    // Pre-initialize each device pool with N events. This prevents
    // cudaEventCreate() from invoking during steady-state execution.
    void reserve_events_on_pools(size_t num_events) {
      for (const auto device_idx : c10::irange(pools_.size())) {
        const auto device = static_cast<DeviceIndex>(device_idx);
        std::vector<Event> temp_events;
        temp_events.reserve(num_events);
        pools_[device_idx].event_pool_.reserve(num_events);
        for ([[maybe_unused]] const auto _ : c10::irange(num_events)) {
          auto event = get(device);
          event->create(device);
          temp_events.emplace_back(std::move(event));
        }
        // Events will be returned to pool when temp_events is destroyed.
      }
    }

    struct alignas(c10::hardware_destructive_interference_size) PerDevicePool {
      alignas(c10::hardware_destructive_interference_size) std::mutex mutex_;
      std::vector<std::unique_ptr<CUDAEvent>> event_pool_;
    };
    // One entry per device. It is never resized after construction, which keeps
    // the references captured by pooled-event deleters valid.
    std::vector<PerDevicePool> pools_;
  };
  326. } // namespace c10::cuda
  327. #else
  328. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  329. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)