| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184 |
- # mypy: allow-untyped-defs
- # pylint: disable=useless-parent-delegation
- from __future__ import annotations
- import ctypes
- import torch
- from torch._utils import _dummy_type
- if not hasattr(torch._C, "_XpuStreamBase"):
- # Define dummy base classes
- torch._C.__dict__["_XpuStreamBase"] = _dummy_type("_XpuStreamBase")
- torch._C.__dict__["_XpuEventBase"] = _dummy_type("_XpuEventBase")
- class Stream(torch._C._XpuStreamBase):
- r"""Wrapper around a XPU stream.
- A XPU stream is a linear sequence of execution that belongs to a specific
- device, independent from other streams. It supports with statement as a
- context manager to ensure the operators within the with block are running
- on the corresponding stream.
- Args:
- device(torch.device or int, optional): a device on which to allocate
- the stream. If :attr:`device` is ``None`` (default) or a negative
- integer, this will use the current device.
- priority(int, optional): priority of the stream, which can be positive, 0, or negative.
- A lower number indicates a higher priority. By default, the priority is set to 0.
- If the value falls outside of the allowed priority range, it will automatically be
- mapped to the nearest valid priority (lowest for large positive numbers or
- highest for large negative numbers).
- """
- def __new__(cls, device=None, priority=0, **kwargs):
- # setting device manager is expensive, so we avoid it unless necessary
- if device is None or ("stream_id" in kwargs and "device_index" in kwargs):
- return super().__new__(cls, priority=priority, **kwargs)
- else:
- with torch.xpu.device(device):
- return super().__new__(cls, priority=priority, **kwargs)
- def wait_event(self, event: Event | torch.Event) -> None:
- r"""Make all future work submitted to the stream wait for an event.
- Args:
- event (Event, torch.Event): an event to wait for.
- """
- event.wait(self)
- def wait_stream(self, stream: Stream | torch.Stream) -> None:
- r"""Synchronize with another stream.
- All future work submitted to this stream will wait until all kernels
- submitted to a given stream at the time of call complete.
- Args:
- stream (Stream, torch.Stream): a stream to synchronize.
- """
- self.wait_event(stream.record_event())
- def record_event(self, event: Event | torch.Event | None = None):
- r"""Record an event.
- Args:
- event (Event, torch.Event, optional): event to record. If not given, a new one
- will be allocated.
- Returns:
- Recorded event.
- """
- if event is None:
- event = Event()
- event.record(self)
- return event
- def query(self) -> bool:
- r"""Check if all the work submitted has been completed.
- Returns:
- A boolean indicating if all kernels in this stream are completed.
- """
- return super().query()
- def synchronize(self) -> None:
- r"""Wait for all the kernels in this stream to complete."""
- super().synchronize()
- @property
- def _as_parameter_(self):
- return ctypes.c_void_p(self.sycl_queue)
- def __eq__(self, o):
- if isinstance(o, Stream):
- return super().__eq__(o)
- return False
- def __hash__(self):
- return hash((self.sycl_queue, self.device))
- def __repr__(self) -> str:
- return f"torch.xpu.Stream(device={self.device} sycl_queue={self.sycl_queue:#x})"
- class Event(torch._C._XpuEventBase):
- r"""Wrapper around a XPU event.
- XPU events are synchronization markers that can be used to monitor the
- device's progress, and to synchronize XPU streams.
- The underlying XPU events are lazily initialized when the event is first
- recorded. After creation, only streams on the same device may record the
- event. However, streams on any device can wait on the event.
- Args:
- enable_timing (bool, optional): indicates if the event should measure time
- (default: ``False``)
- """
- def __new__(cls, enable_timing=False):
- return super().__new__(cls, enable_timing=enable_timing)
- # pyrefly: ignore [bad-override]
- def record(self, stream: Stream | torch.Stream | None = None) -> None:
- r"""Record the event in a given stream.
- Args:
- stream (Stream, torch.Stream, optional): Uses ``torch.xpu.current_stream()`` if no stream is specified.
- The stream's device must match the event's device.
- """
- if stream is None:
- stream = torch.xpu.current_stream()
- # pyrefly: ignore [bad-argument-type]
- super().record(stream)
- def wait(self, stream: Stream | torch.Stream | None = None) -> None:
- r"""Make all future work submitted to the given stream wait for this event.
- Args:
- stream (Stream, torch.Stream, optional): Uses ``torch.xpu.current_stream()`` if no stream is specified.
- """
- if stream is None:
- stream = torch.xpu.current_stream()
- # pyrefly: ignore [bad-argument-type]
- super().wait(stream)
- def query(self) -> bool:
- r"""Check if all work currently captured by event has completed.
- Returns:
- A boolean indicating if all work currently captured by event has
- completed.
- """
- return super().query()
- def elapsed_time(self, end_event: Event):
- r"""Return the time elapsed.
- Time reported in milliseconds after the event was recorded and
- before the end_event was recorded.
- Args:
- end_event (Event): the end event.
- """
- return super().elapsed_time(end_event)
- def synchronize(self) -> None:
- r"""Wait for the event to complete.
- Waits until the completion of all work currently captured in this event.
- This prevents the CPU thread from proceeding until the event completes.
- """
- super().synchronize()
- @property
- def _as_parameter_(self):
- return ctypes.c_void_p(self.sycl_event)
- def __repr__(self) -> str:
- if self.sycl_event:
- return f"torch.xpu.Event(sycl_event={self.sycl_event:#x})"
- else:
- return "torch.xpu.Event(uninitialized)"
|