streams.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. # mypy: allow-untyped-defs
  2. # pylint: disable=useless-parent-delegation
  3. from __future__ import annotations
  4. import ctypes
  5. import torch
  6. from torch._utils import _dummy_type
  7. if not hasattr(torch._C, "_XpuStreamBase"):
  8. # Define dummy base classes
  9. torch._C.__dict__["_XpuStreamBase"] = _dummy_type("_XpuStreamBase")
  10. torch._C.__dict__["_XpuEventBase"] = _dummy_type("_XpuEventBase")
  11. class Stream(torch._C._XpuStreamBase):
  12. r"""Wrapper around a XPU stream.
  13. A XPU stream is a linear sequence of execution that belongs to a specific
  14. device, independent from other streams. It supports with statement as a
  15. context manager to ensure the operators within the with block are running
  16. on the corresponding stream.
  17. Args:
  18. device(torch.device or int, optional): a device on which to allocate
  19. the stream. If :attr:`device` is ``None`` (default) or a negative
  20. integer, this will use the current device.
  21. priority(int, optional): priority of the stream, which can be positive, 0, or negative.
  22. A lower number indicates a higher priority. By default, the priority is set to 0.
  23. If the value falls outside of the allowed priority range, it will automatically be
  24. mapped to the nearest valid priority (lowest for large positive numbers or
  25. highest for large negative numbers).
  26. """
  27. def __new__(cls, device=None, priority=0, **kwargs):
  28. # setting device manager is expensive, so we avoid it unless necessary
  29. if device is None or ("stream_id" in kwargs and "device_index" in kwargs):
  30. return super().__new__(cls, priority=priority, **kwargs)
  31. else:
  32. with torch.xpu.device(device):
  33. return super().__new__(cls, priority=priority, **kwargs)
  34. def wait_event(self, event: Event | torch.Event) -> None:
  35. r"""Make all future work submitted to the stream wait for an event.
  36. Args:
  37. event (Event, torch.Event): an event to wait for.
  38. """
  39. event.wait(self)
  40. def wait_stream(self, stream: Stream | torch.Stream) -> None:
  41. r"""Synchronize with another stream.
  42. All future work submitted to this stream will wait until all kernels
  43. submitted to a given stream at the time of call complete.
  44. Args:
  45. stream (Stream, torch.Stream): a stream to synchronize.
  46. """
  47. self.wait_event(stream.record_event())
  48. def record_event(self, event: Event | torch.Event | None = None):
  49. r"""Record an event.
  50. Args:
  51. event (Event, torch.Event, optional): event to record. If not given, a new one
  52. will be allocated.
  53. Returns:
  54. Recorded event.
  55. """
  56. if event is None:
  57. event = Event()
  58. event.record(self)
  59. return event
  60. def query(self) -> bool:
  61. r"""Check if all the work submitted has been completed.
  62. Returns:
  63. A boolean indicating if all kernels in this stream are completed.
  64. """
  65. return super().query()
  66. def synchronize(self) -> None:
  67. r"""Wait for all the kernels in this stream to complete."""
  68. super().synchronize()
  69. @property
  70. def _as_parameter_(self):
  71. return ctypes.c_void_p(self.sycl_queue)
  72. def __eq__(self, o):
  73. if isinstance(o, Stream):
  74. return super().__eq__(o)
  75. return False
  76. def __hash__(self):
  77. return hash((self.sycl_queue, self.device))
  78. def __repr__(self) -> str:
  79. return f"torch.xpu.Stream(device={self.device} sycl_queue={self.sycl_queue:#x})"
  80. class Event(torch._C._XpuEventBase):
  81. r"""Wrapper around a XPU event.
  82. XPU events are synchronization markers that can be used to monitor the
  83. device's progress, and to synchronize XPU streams.
  84. The underlying XPU events are lazily initialized when the event is first
  85. recorded. After creation, only streams on the same device may record the
  86. event. However, streams on any device can wait on the event.
  87. Args:
  88. enable_timing (bool, optional): indicates if the event should measure time
  89. (default: ``False``)
  90. """
  91. def __new__(cls, enable_timing=False):
  92. return super().__new__(cls, enable_timing=enable_timing)
  93. # pyrefly: ignore [bad-override]
  94. def record(self, stream: Stream | torch.Stream | None = None) -> None:
  95. r"""Record the event in a given stream.
  96. Args:
  97. stream (Stream, torch.Stream, optional): Uses ``torch.xpu.current_stream()`` if no stream is specified.
  98. The stream's device must match the event's device.
  99. """
  100. if stream is None:
  101. stream = torch.xpu.current_stream()
  102. # pyrefly: ignore [bad-argument-type]
  103. super().record(stream)
  104. def wait(self, stream: Stream | torch.Stream | None = None) -> None:
  105. r"""Make all future work submitted to the given stream wait for this event.
  106. Args:
  107. stream (Stream, torch.Stream, optional): Uses ``torch.xpu.current_stream()`` if no stream is specified.
  108. """
  109. if stream is None:
  110. stream = torch.xpu.current_stream()
  111. # pyrefly: ignore [bad-argument-type]
  112. super().wait(stream)
  113. def query(self) -> bool:
  114. r"""Check if all work currently captured by event has completed.
  115. Returns:
  116. A boolean indicating if all work currently captured by event has
  117. completed.
  118. """
  119. return super().query()
  120. def elapsed_time(self, end_event: Event):
  121. r"""Return the time elapsed.
  122. Time reported in milliseconds after the event was recorded and
  123. before the end_event was recorded.
  124. Args:
  125. end_event (Event): the end event.
  126. """
  127. return super().elapsed_time(end_event)
  128. def synchronize(self) -> None:
  129. r"""Wait for the event to complete.
  130. Waits until the completion of all work currently captured in this event.
  131. This prevents the CPU thread from proceeding until the event completes.
  132. """
  133. super().synchronize()
  134. @property
  135. def _as_parameter_(self):
  136. return ctypes.c_void_p(self.sycl_event)
  137. def __repr__(self) -> str:
  138. if self.sycl_event:
  139. return f"torch.xpu.Event(sycl_event={self.sycl_event:#x})"
  140. else:
  141. return "torch.xpu.Event(uninitialized)"