CUDAStream.h 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
#include <cuda_runtime_api.h>

#include <c10/core/DeviceGuard.h>
#include <c10/core/Stream.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/util/Exception.h>

#include <algorithm>
#include <cstddef>
#include <functional>
#include <ostream>
#include <tuple>
  8. /*
  9. * Stream pool note.
  10. *
  11. * A CUDAStream is an abstraction of an actual cuStream on the GPU. CUDAStreams
  12. * are backed by cuStreams, but they use several pools to minimize the costs
  13. * associated with creating, retaining, and destroying cuStreams.
  14. *
  15. * There are three pools per device, and a device's pools are lazily created.
  16. *
  17. * The first pool contains only the default stream. When the default stream
  18. * is requested it's returned.
  19. *
  20. * The second pool is the "low priority" or "default priority" streams. In
  21. * HIP builds there is no distinction between streams in this pool and streams
  22. * in the third pool (below). There are 32 of these streams per device, and
  23. * when a stream is requested one of these streams is returned round-robin.
  24. * That is, the first stream requested is at index 0, the second at index 1...
  25. * to index 31, then index 0 again.
  26. *
  27. * This means that if 33 low priority streams are requested, the first and
  28. * last streams requested are actually the same stream (under the covers)
  29. * and kernels enqueued on them cannot run concurrently.
  30. *
  31. * The third pool is the "high priority" streams. The third pool acts like
  32. * the second pool except the streams are created with a higher priority.
  33. *
  34. * These pools suggest that stream users should prefer many short-lived streams,
  35. * as the cost of acquiring and releasing streams is effectively zero. If
  36. * many longer-lived streams are required in performance critical scenarios
  37. * then the functionality here may need to be extended to allow, for example,
  38. * "reserving" a subset of the pool so that other streams do not accidentally
  39. * overlap the performance critical streams.
  40. *
  41. * Note: although the notion of "current stream for device" is thread local
  42. * (every OS thread has a separate current stream, as one might expect),
  43. * the stream pool is global across all threads; stream 0 is always stream 0
  44. * no matter which thread you use it on. Multiple threads can synchronize
  45. * on the same stream. Although the CUDA documentation is not very clear
  46. * on the matter, streams are thread safe; e.g., it is safe to enqueue
  47. * a kernel on the same stream from two different threads.
  48. */
  49. namespace c10::cuda {
  50. static constexpr int max_compile_time_stream_priorities = 4;
  51. // Value object representing a CUDA stream. This is just a wrapper
  52. // around c10::Stream, but it comes with a little extra CUDA-specific
  53. // functionality (conversion to cudaStream_t), and a guarantee that
  54. // the wrapped c10::Stream really is a CUDA stream.
  55. class C10_CUDA_API CUDAStream {
  56. public:
  57. enum Unchecked { UNCHECKED };
  58. /// Construct a CUDAStream from a Stream. This construction is checked,
  59. /// and will raise an error if the Stream is not, in fact, a CUDA stream.
  60. explicit CUDAStream(Stream stream) : stream_(stream) {
  61. TORCH_CHECK(stream_.device_type() == DeviceType::CUDA);
  62. }
  63. /// Construct a CUDAStream from a Stream with no error checking.
  64. /// This constructor uses the "named" constructor idiom, and can
  65. /// be invoked as: CUDAStream(CUDAStream::UNCHECKED, stream)
  66. explicit CUDAStream(Unchecked /*unused*/, Stream stream) : stream_(stream) {}
  67. bool operator==(const CUDAStream& other) const noexcept {
  68. return unwrap() == other.unwrap();
  69. }
  70. bool operator!=(const CUDAStream& other) const noexcept {
  71. return unwrap() != other.unwrap();
  72. }
  73. /// Implicit conversion to cudaStream_t.
  74. operator cudaStream_t() const {
  75. return stream();
  76. }
  77. /// Implicit conversion to Stream (a.k.a., forget that the stream is a
  78. /// CUDA stream).
  79. operator Stream() const {
  80. return unwrap();
  81. }
  82. /// Used to avoid baking in device type explicitly to Python-side API.
  83. DeviceType device_type() const {
  84. return DeviceType::CUDA;
  85. }
  86. /// Get the CUDA device index that this stream is associated with.
  87. DeviceIndex device_index() const {
  88. return stream_.device_index();
  89. }
  90. /// Get the full Device that this stream is associated with. The Device
  91. /// is guaranteed to be a CUDA device.
  92. Device device() const {
  93. return Device(DeviceType::CUDA, device_index());
  94. }
  95. /// Return the stream ID corresponding to this particular stream.
  96. StreamId id() const {
  97. return stream_.id();
  98. }
  99. bool query() const;
  100. void synchronize() const;
  101. int priority() const {
  102. DeviceGuard guard{stream_.device()};
  103. int priority = 0;
  104. C10_CUDA_CHECK(cudaStreamGetPriority(stream(), &priority));
  105. return priority;
  106. }
  107. /// Explicit conversion to cudaStream_t.
  108. cudaStream_t stream() const;
  109. /// Explicit conversion to Stream.
  110. Stream unwrap() const {
  111. return stream_;
  112. }
  113. /// Reversibly pack a CUDAStream into a struct representation.
  114. /// Previously the stream's data was packed into a single int64_t,
  115. /// as it was assumed the fields would not require more than
  116. /// 64 bits of storage in total.
  117. /// See https://github.com/pytorch/pytorch/issues/75854
  118. /// for more information regarding newer platforms that may violate
  119. /// this assumption.
  120. ///
  121. /// The CUDAStream can be unpacked using unpack().
  122. struct c10::StreamData3 pack3() const {
  123. return stream_.pack3();
  124. }
  125. // Unpack a CUDAStream from the 3 fields generated by pack().
  126. static CUDAStream unpack3(
  127. StreamId stream_id,
  128. DeviceIndex device_index,
  129. DeviceType device_type) {
  130. return CUDAStream(Stream::unpack3(stream_id, device_index, device_type));
  131. }
  132. static std::tuple<int, int> priority_range() {
  133. // Note: this returns the range of priority **supported by PyTorch**, not
  134. // the range of priority **supported by CUDA**. The former is a subset of
  135. // the latter.
  136. int least_priority = 0, greatest_priority = 0;
  137. C10_CUDA_CHECK(
  138. cudaDeviceGetStreamPriorityRange(&least_priority, &greatest_priority));
  139. #ifdef USE_ROCM
  140. // See Note [HIP stream priorities]
  141. TORCH_INTERNAL_ASSERT(
  142. least_priority == 1, "Unexpected HIP stream priority range");
  143. least_priority = 0;
  144. #else
  145. TORCH_INTERNAL_ASSERT(
  146. least_priority == 0, "Unexpected CUDA stream priority range");
  147. #endif
  148. TORCH_INTERNAL_ASSERT(
  149. greatest_priority <= -1, "Unexpected CUDA stream priority range");
  150. greatest_priority = std::max(
  151. -c10::cuda::max_compile_time_stream_priorities + 1, greatest_priority);
  152. return std::make_tuple(least_priority, greatest_priority);
  153. }
  154. // Deleted for now; use CUDAEvent::block instead
  155. // void synchronize_with(const CUDAEvent& event) const;
  156. private:
  157. Stream stream_;
  158. };
  159. /**
  160. * Get a new stream from the CUDA stream pool. You can think of this
  161. * as "creating" a new stream, but no such creation actually happens;
  162. * instead, streams are preallocated from the pool and returned in a
  163. * round-robin fashion.
  164. *
  165. * You can request a stream from the high priority pool by setting
  166. * isHighPriority to true, or a stream for a specific device by setting device
  167. * (defaulting to the current CUDA stream.)
  168. */
  169. C10_API CUDAStream
  170. getStreamFromPool(const bool isHighPriority = false, DeviceIndex device = -1);
  171. // no default priority to disambiguate overloads
  172. C10_API CUDAStream
  173. getStreamFromPool(const int priority, DeviceIndex device = -1);
  174. /**
  175. * Get a CUDAStream from a externally allocated one.
  176. *
  177. * This is mainly for interoperability with different libraries where we
  178. * want to operate on a non-torch allocated stream for data exchange or similar
  179. * purposes
  180. */
  181. C10_API CUDAStream
  182. getStreamFromExternal(cudaStream_t ext_stream, DeviceIndex device_index);
  183. /**
  184. * Get the default CUDA stream, for the passed CUDA device, or for the
  185. * current device if no device index is passed. The default stream is
  186. * where most computation occurs when you aren't explicitly using
  187. * streams.
  188. */
  189. C10_API CUDAStream getDefaultCUDAStream(DeviceIndex device_index = -1);
  190. /**
  191. * Get the current CUDA stream, for the passed CUDA device, or for the
  192. * current device if no device index is passed. The current CUDA stream
  193. * will usually be the default CUDA stream for the device, but it may
  194. * be different if someone called 'setCurrentCUDAStream' or used 'StreamGuard'
  195. * or 'CUDAStreamGuard'.
  196. */
  197. C10_API CUDAStream getCurrentCUDAStream(DeviceIndex device_index = -1);
  198. /**
  199. * Set the current stream on the device of the passed in stream to be
  200. * the passed in stream. Yes, you read that right: this function
  201. * has *nothing* to do with the current device: it toggles the current
  202. * stream of the device of the passed stream.
  203. *
  204. * Confused? Avoid using this function; prefer using 'CUDAStreamGuard' instead
  205. * (which will switch both your current device and current stream in the way you
  206. * expect, and reset it back to its original state afterwards).
  207. */
  208. C10_API void setCurrentCUDAStream(CUDAStream stream);
  209. C10_API std::ostream& operator<<(std::ostream& stream, const CUDAStream& s);
  210. } // namespace c10::cuda
  211. // hipify v2 backward compat in external projects
  212. #ifdef USE_ROCM
  213. namespace c10::hip {
  214. using c10::cuda::getStreamFromExternal;
  215. using c10::cuda::getStreamFromPool;
  216. // must use inline wrappers instead of reference aliases due to default args
  217. inline c10::cuda::CUDAStream getDefaultHIPStream(
  218. DeviceIndex device_index = -1) {
  219. return c10::cuda::getDefaultCUDAStream(device_index);
  220. }
  221. inline c10::cuda::CUDAStream getCurrentHIPStream(
  222. DeviceIndex device_index = -1) {
  223. return c10::cuda::getCurrentCUDAStream(device_index);
  224. }
  225. inline auto& setCurrentHIPStream = c10::cuda::setCurrentCUDAStream;
  226. inline c10::cuda::CUDAStream getStreamFromPoolMasqueradingAsCUDA(
  227. const bool isHighPriority = false,
  228. DeviceIndex device = -1) {
  229. return c10::cuda::getStreamFromPool(isHighPriority, device);
  230. }
  231. inline c10::cuda::CUDAStream getStreamFromPoolMasqueradingAsCUDA(
  232. const int priority,
  233. DeviceIndex device = -1) {
  234. return c10::cuda::getStreamFromPool(priority, device);
  235. }
  236. inline auto& getStreamFromExternalMasqueradingAsCUDA =
  237. c10::cuda::getStreamFromExternal;
  238. inline c10::cuda::CUDAStream getDefaultHIPStreamMasqueradingAsCUDA(
  239. DeviceIndex device_index = -1) {
  240. return c10::cuda::getDefaultCUDAStream(device_index);
  241. }
  242. inline c10::cuda::CUDAStream getCurrentHIPStreamMasqueradingAsCUDA(
  243. DeviceIndex device_index = -1) {
  244. return c10::cuda::getCurrentCUDAStream(device_index);
  245. }
  246. inline auto& setCurrentHIPStreamMasqueradingAsCUDA =
  247. c10::cuda::setCurrentCUDAStream;
  248. } // namespace c10::hip
  249. #endif
  250. namespace std {
  251. template <>
  252. struct hash<c10::cuda::CUDAStream> {
  253. size_t operator()(c10::cuda::CUDAStream s) const noexcept {
  254. return std::hash<c10::Stream>{}(s.unwrap());
  255. }
  256. };
  257. } // namespace std
  258. #else
  259. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  260. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)