XPUStream.h 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  #include <c10/core/Stream.h>
  #include <c10/core/impl/GPUTrace.h>
  #include <c10/xpu/XPUFunctions.h>

  #include <cstdint>
  #include <functional>
  #include <iosfwd>
  #include <tuple>
  6. namespace c10::xpu {
  /*
   * Note [Stream Management]
   *
   * An XPUStream is an abstraction of an actual SYCL queue on which SYCL kernels
   * can execute. Each device manages its SYCL queues through several stream
   * pools, and a device's pools are created lazily.
   *
   * There is one pool per stream priority level supported by PyTorch
   * (max_compile_time_stream_priorities of them; see XPUStream::priority_range(),
   * where a lower number means a higher priority). Each pool holds 32 queues per
   * device, and when a queue is requested from a pool, one of its queues is
   * returned in round-robin order. That is, the first queue requested is at
   * index 0, the second at index 1, ... up to index 31, after which it wraps
   * back to index 0.
   *
   * This means that if 33 queues are requested from the same pool, the first and
   * last queues requested are actually the same queue (under the covers), and
   * kernels enqueued on them cannot run concurrently.
   *
   * It is safe to enqueue kernels on the same queue from two different threads,
   * as described by the SYCL specification.
   */
  /// Number of stream priority levels PyTorch supports for XPU at compile time.
  /// XPUStream::priority_range() derives its bounds from this value, giving
  /// priorities 1 (lowest) through -1 (highest).
  // `inline` (C++17) provides a single program-wide definition. The previous
  // `static` gave every translation unit that includes this header its own
  // internal-linkage copy.
  inline constexpr int max_compile_time_stream_priorities = 3;
  28. /*
  29. * This serves as a wrapper around c10::Stream and acts as a representation for
  30. * a SYCL queue, which allows asynchronous execution of XPU tasks.
  31. */
  32. class C10_XPU_API XPUStream {
  33. public:
  34. enum Unchecked { UNCHECKED };
  35. /// Construct a XPUStream from a Stream. This construction is checked, and
  36. /// will raise an error if the Stream is not, in fact, a XPU stream.
  37. explicit XPUStream(Stream stream) : stream_(stream) {
  38. TORCH_CHECK(stream_.device_type() == DeviceType::XPU);
  39. }
  40. /// Construct a XPUStream from a Stream with no error checking.
  41. explicit XPUStream(Unchecked /*unused*/, Stream stream) : stream_(stream) {}
  42. bool operator==(const XPUStream& other) const noexcept {
  43. return unwrap() == other.unwrap();
  44. }
  45. bool operator!=(const XPUStream& other) const noexcept {
  46. return unwrap() != other.unwrap();
  47. }
  48. /// Implicit conversion to sycl::queue&.
  49. operator sycl::queue&() const {
  50. return queue();
  51. }
  52. /// Implicit conversion to sycl::queue*.
  53. operator sycl::queue*() const {
  54. return &queue();
  55. }
  56. /// Implicit conversion to Stream (a.k.a., forget that the stream is a
  57. /// XPU stream).
  58. operator Stream() const {
  59. return unwrap();
  60. }
  61. /// Get the XPU device type that this stream is associated with.
  62. DeviceType device_type() const {
  63. return DeviceType::XPU;
  64. }
  65. /// Get the XPU device index that this stream is associated with.
  66. DeviceIndex device_index() const {
  67. return stream_.device_index();
  68. }
  69. /// Get the full Device that this stream is associated with. The Device is
  70. /// guaranteed to be a XPU device.
  71. Device device() const {
  72. return Device(DeviceType::XPU, device_index());
  73. }
  74. /// Return the stream ID corresponding to this particular stream. StreamId is
  75. /// a int64_t representation generated by its type and index.
  76. StreamId id() const {
  77. return stream_.id();
  78. }
  79. /// Return true if all enqueued tasks in this stream have been completed,
  80. /// otherwise return false.
  81. bool query() const {
  82. return queue().ext_oneapi_empty();
  83. }
  84. /// Performs a blocking wait for the completion of all enqueued tasks in this
  85. /// stream.
  86. void synchronize() const {
  87. queue().wait_and_throw();
  88. const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
  89. if (C10_UNLIKELY(interp)) {
  90. (*interp)->trace_gpu_stream_synchronization(
  91. c10::kXPU, reinterpret_cast<uintptr_t>(&queue()));
  92. }
  93. }
  94. /// Return the priority that this stream is associated with. Lower numbers
  95. /// represent higher priority.
  96. int priority() const;
  97. /// Explicit conversion to sycl::queue&.
  98. sycl::queue& queue() const;
  99. /// Explicit conversion to Stream.
  100. Stream unwrap() const {
  101. return stream_;
  102. }
  103. /// Reversibly pack a XPUStream into a struct representation. The XPUStream
  104. /// can be unpacked using unpack3().
  105. struct c10::StreamData3 pack3() const {
  106. return stream_.pack3();
  107. }
  108. /// Unpack a XPUStream from the 3 fields generated by pack3().
  109. static XPUStream unpack3(
  110. StreamId stream_id,
  111. DeviceIndex device_index,
  112. DeviceType device_type) {
  113. return XPUStream(Stream::unpack3(stream_id, device_index, device_type));
  114. }
  115. /// Return the range of priority **supported by PyTorch**.
  116. static std::tuple<int, int> priority_range() {
  117. // See Note [XPU Stream priorities]
  118. return std::make_tuple(1, -max_compile_time_stream_priorities + 2);
  119. }
  120. private:
  121. Stream stream_;
  122. };
  123. /**
  124. * Get a stream from the pool in a round-robin fashion.
  125. *
  126. * You can request a stream from the highest priority pool by setting
  127. * isHighPriority to true for a specific device.
  128. */
  129. C10_XPU_API XPUStream
  130. getStreamFromPool(const bool isHighPriority = false, DeviceIndex device = -1);
  131. /**
  132. * Get a stream from the pool in a round-robin fashion.
  133. *
  134. * You can request a stream by setting a priority value for a specific device.
  135. * The priority number lower, the priority higher.
  136. */
  137. C10_XPU_API XPUStream
  138. getStreamFromPool(const int priority, DeviceIndex device = -1);
  139. /**
  140. * Get an XPUStream from an external SYCL queue.
  141. *
  142. * This function allows interoperability with other libraries by enabling
  143. * the use of an external SYCL queue that was not created by PyTorch. This
  144. * can be useful for data exchange or other operations where integration
  145. * with non-PyTorch queues is required.
  146. *
  147. * NOTE: It is the user's responsibility to ensure that the referenced SYCL
  148. * queue remains alive while the corresponding XPUStream, or any c10::Stream
  149. * derived from it, is in use. The different SYCL queue pointers will result in
  150. * distinct XPUStream instances, even if the SYCL queues they dereference are
  151. * equivalent.
  152. */
  153. C10_XPU_API XPUStream
  154. getStreamFromExternal(sycl::queue* ext_queue, DeviceIndex device_index);
  155. /**
  156. * Get the current XPU stream, for the passed XPU device, or for the current
  157. * device if no device index is passed.
  158. */
  159. C10_XPU_API XPUStream getCurrentXPUStream(DeviceIndex device = -1);
  160. /**
  161. * Set the current stream on the device of the passed in stream to be the passed
  162. * in stream.
  163. */
  164. C10_XPU_API void setCurrentXPUStream(XPUStream stream);
  165. C10_XPU_API std::ostream& operator<<(std::ostream& stream, const XPUStream& s);
  166. /**
  167. * Block all reserved SYCL queues in the stream pools on the device, and wait
  168. * for their synchronizations.
  169. */
  170. C10_XPU_API void syncStreamsOnDevice(DeviceIndex device = -1);
  171. } // namespace c10::xpu
  172. namespace std {
  173. template <>
  174. struct hash<c10::xpu::XPUStream> {
  175. size_t operator()(c10::xpu::XPUStream s) const noexcept {
  176. return std::hash<c10::Stream>{}(s.unwrap());
  177. }
  178. };
  179. } // namespace std
  180. #else
  181. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  182. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)