streams.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. # mypy: allow-untyped-defs
  2. # pylint: disable=useless-parent-delegation
  3. from __future__ import annotations
  4. import ctypes
  5. import torch
  6. from torch._utils import _dummy_type
  7. if not hasattr(torch._C, "_CudaStreamBase"):
  8. # Define dummy base classes
  9. torch._C.__dict__["_CudaStreamBase"] = _dummy_type("_CudaStreamBase")
  10. torch._C.__dict__["_CudaEventBase"] = _dummy_type("_CudaEventBase")
  11. class Stream(torch._C._CudaStreamBase):
  12. r"""Wrapper around a CUDA stream.
  13. A CUDA stream is a linear sequence of execution that belongs to a specific
  14. device, independent from other streams. It supports with statement as a
  15. context manager to ensure the operators within the with block are running
  16. on the corresponding stream. See :ref:`cuda-semantics` for details.
  17. Args:
  18. device(torch.device or int, optional): a device on which to allocate
  19. the stream. If :attr:`device` is ``None`` (default) or a negative
  20. integer, this will use the current device.
  21. priority(int, optional): priority of the stream, which can be positive, 0, or negative.
  22. A lower number indicates a higher priority. By default, the priority is set to 0.
  23. If the value falls outside of the allowed priority range, it will automatically be
  24. mapped to the nearest valid priority (lowest for large positive numbers or
  25. highest for large negative numbers).
  26. """
  27. def __new__(cls, device=None, priority=0, **kwargs):
  28. # Check CUDA availability
  29. if not torch.backends.cuda.is_built():
  30. raise RuntimeError("torch.cuda.Stream requires CUDA support")
  31. # setting device manager is expensive, so we avoid it unless necessary
  32. if device is None or ("stream_id" in kwargs and "device_index" in kwargs):
  33. return super().__new__(cls, priority=priority, **kwargs)
  34. else:
  35. with torch.cuda.device(device):
  36. return super().__new__(cls, priority=priority, **kwargs)
  37. def wait_event(self, event: Event | torch.Event) -> None:
  38. r"""Make all future work submitted to the stream wait for an event.
  39. Args:
  40. event (Event, torch.Event): an event to wait for.
  41. .. note:: This is a wrapper around ``cudaStreamWaitEvent()``: see
  42. `CUDA Stream documentation`_ for more info.
  43. This function returns without waiting for :attr:`event`: only future
  44. operations are affected.
  45. .. _CUDA Stream documentation:
  46. https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html
  47. """
  48. event.wait(self)
  49. def wait_stream(self, stream: Stream | torch.Stream) -> None:
  50. r"""Synchronize with another stream.
  51. All future work submitted to this stream will wait until all kernels
  52. submitted to a given stream at the time of call complete.
  53. Args:
  54. stream (Stream, torch.Stream): a stream to synchronize.
  55. .. note:: This function returns without waiting for currently enqueued
  56. kernels in :attr:`stream`: only future operations are affected.
  57. """
  58. self.wait_event(stream.record_event())
  59. def record_event(self, event: Event | torch.Event | None = None):
  60. r"""Record an event.
  61. Args:
  62. event (Event, torch.Event, optional): event to record. If not given, a new one
  63. will be allocated.
  64. Returns:
  65. Recorded event.
  66. """
  67. if event is None:
  68. event = Event()
  69. event.record(self)
  70. return event
  71. def query(self) -> bool:
  72. r"""Check if all the work submitted has been completed.
  73. Returns:
  74. A boolean indicating if all kernels in this stream are completed.
  75. """
  76. return super().query()
  77. def synchronize(self) -> None:
  78. r"""Wait for all the kernels in this stream to complete.
  79. .. note:: This is a wrapper around ``cudaStreamSynchronize()``: see
  80. `CUDA Stream documentation`_ for more info.
  81. """
  82. super().synchronize()
  83. @property
  84. def _as_parameter_(self):
  85. return ctypes.c_void_p(self.cuda_stream)
  86. def __eq__(self, o) -> bool:
  87. if isinstance(o, Stream):
  88. return super().__eq__(o)
  89. return False
  90. def __hash__(self):
  91. return hash((self.cuda_stream, self.device))
  92. def __repr__(self):
  93. return f"<torch.cuda.Stream device={self.device} cuda_stream={self.cuda_stream:#x}>"
  94. def __cuda_stream__(self):
  95. """Implements the CUDA Stream Protocol:
  96. https://nvidia.github.io/cuda-python/cuda-core/latest/interoperability.html#cuda-stream-protocol
  97. Returns:
  98. tuple: A 2-tuple of (version, handle) where version is the protocol version
  99. and handle is the address of cudaStream_t (CUDA) or hipStream_t (ROCm) as a Python int.
  100. """
  101. return (0, self.cuda_stream)
  102. class ExternalStream(Stream):
  103. r"""Wrapper around an externally allocated CUDA stream.
  104. This class is used to wrap streams allocated in other libraries in order
  105. to facilitate data exchange and multi-library interactions.
  106. .. note:: This class doesn't manage the stream life-cycle, it is the user
  107. responsibility to keep the referenced stream alive while this class is
  108. being used.
  109. Args:
  110. stream_ptr(int): Integer representation of the `cudaStream_t` value.
  111. allocated externally.
  112. device(torch.device or int, optional): the device where the stream
  113. was originally allocated. If device is specified incorrectly,
  114. subsequent launches using this stream may fail.
  115. """
  116. def __new__(cls, stream_ptr, device=None, **kwargs):
  117. with torch.cuda.device(device):
  118. return super().__new__(cls, stream_ptr=stream_ptr, **kwargs)
  119. class Event(torch._C._CudaEventBase):
  120. r"""Wrapper around a CUDA event.
  121. CUDA events are synchronization markers that can be used to monitor the
  122. device's progress, to accurately measure timing, and to synchronize CUDA
  123. streams.
  124. The underlying CUDA events are lazily initialized when the event is first
  125. recorded or exported to another process. After creation, only streams on the
  126. same device may record the event. However, streams on any device can wait on
  127. the event.
  128. Args:
  129. enable_timing (bool, optional): indicates if the event should measure time
  130. (default: ``False``)
  131. blocking (bool, optional): if ``True``, :meth:`wait` will be blocking (default: ``False``)
  132. interprocess (bool): if ``True``, the event can be shared between processes
  133. (default: ``False``)
  134. external (bool, optional): indicates whether this event should create event record and event wait nodes, or create an internal cross-stream dependency, when captured in a cuda graph. See `cross-stream dependencies <https://docs.nvidia.com/cuda/archive/12.9.0/cuda-c-programming-guide/index.html#cross-stream-dependencies-and-events>`_, `cudaEventRecordExternal <https://docs.nvidia.com/cuda/archive/12.9.0/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g3457b81d1d32c6a00f6132fbc2693d47>`_, and `cudaEventWaitExternal <https://docs.nvidia.com/cuda/archive/12.9.0/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g0c23426b7252eaa9cef695859991304e>`_ for more information about internal vs. external events. (default: ``False``)
  135. .. _CUDA Event Documentation:
  136. https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html
  137. """ # noqa: B950
  138. def __new__(
  139. cls, enable_timing=False, blocking=False, interprocess=False, external=False
  140. ):
  141. return super().__new__(
  142. cls,
  143. enable_timing=enable_timing,
  144. blocking=blocking,
  145. interprocess=interprocess,
  146. external=external,
  147. )
  148. @classmethod
  149. def from_ipc_handle(cls, device, handle):
  150. r"""Reconstruct an event from an IPC handle on the given device."""
  151. return super().from_ipc_handle(device, handle)
  152. def record(self, stream: Stream | torch.Stream | None = None):
  153. r"""Record the event in a given stream.
  154. Args:
  155. stream (Stream, torch.Stream, optional): Uses ``torch.cuda.current_stream()`` if no stream is specified.
  156. The stream's device must match the event's device.
  157. """
  158. if stream is None:
  159. stream = torch.cuda.current_stream()
  160. # pyrefly: ignore [bad-argument-type]
  161. super().record(stream)
  162. def wait(self, stream: Stream | torch.Stream | None = None) -> None:
  163. r"""Make all future work submitted to the given stream wait for this event.
  164. Args:
  165. stream (Stream, torch.Stream, optional): Uses ``torch.cuda.current_stream()`` if no stream is specified.
  166. .. note:: This is a wrapper around ``cudaStreamWaitEvent()``: see
  167. `CUDA Event documentation`_ for more info.
  168. """
  169. if stream is None:
  170. stream = torch.cuda.current_stream()
  171. # pyrefly: ignore [bad-argument-type]
  172. super().wait(stream)
  173. def query(self):
  174. r"""Check if all work currently captured by event has completed.
  175. Returns:
  176. A boolean indicating if all work currently captured by event has
  177. completed.
  178. """
  179. return super().query()
  180. def elapsed_time(self, end_event: Event):
  181. r"""Return the time elapsed.
  182. Time reported in milliseconds after the event was recorded and
  183. before the end_event was recorded.
  184. Args:
  185. end_event (Event): the end event.
  186. """
  187. return super().elapsed_time(end_event)
  188. def synchronize(self) -> None:
  189. r"""Wait for the event to complete.
  190. Waits until the completion of all work currently captured in this event.
  191. This prevents the CPU thread from proceeding until the event completes.
  192. .. note:: This is a wrapper around ``cudaEventSynchronize()``: see
  193. `CUDA Event documentation`_ for more info.
  194. """
  195. super().synchronize()
  196. def ipc_handle(self):
  197. r"""Return an IPC handle of this event.
  198. If not recorded yet, the event will use the current device.
  199. """
  200. return super().ipc_handle()
  201. @property
  202. def _as_parameter_(self):
  203. return ctypes.c_void_p(self.cuda_event)
  204. def __repr__(self) -> str:
  205. if self.cuda_event:
  206. return f"<torch.cuda.Event {self._as_parameter_.value:#x}>"
  207. else:
  208. return "<torch.cuda.Event uninitialized>"