| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129 |
- # mypy: allow-untyped-defs
- r"""This package adds support for NVIDIA Tools Extension (NVTX) used in profiling."""
- from contextlib import contextmanager
- try:
- from torch._C import _nvtx
- except ImportError:
- class _NVTXStub:
- @staticmethod
- def _fail(*args, **kwargs):
- raise RuntimeError(
- "NVTX functions not installed. Are you sure you have a CUDA build?"
- )
- rangePushA = _fail
- rangePop = _fail
- markA = _fail
- _nvtx = _NVTXStub() # type: ignore[assignment]
- __all__ = ["range_push", "range_pop", "range_start", "range_end", "mark", "range"]
- def range_push(msg):
- """
- Push a range onto a stack of nested range span. Returns zero-based depth of the range that is started.
- Args:
- msg (str): ASCII message to associate with range
- """
- return _nvtx.rangePushA(msg)
- def range_pop():
- """Pop a range off of a stack of nested range spans. Returns the zero-based depth of the range that is ended."""
- return _nvtx.rangePop()
- def range_start(msg) -> int:
- """
- Mark the start of a range with string message. It returns an unique handle
- for this range to pass to the corresponding call to rangeEnd().
- A key difference between this and range_push/range_pop is that the
- range_start/range_end version supports range across threads (start on one
- thread and end on another thread).
- Returns: A range handle (uint64_t) that can be passed to range_end().
- Args:
- msg (str): ASCII message to associate with the range.
- """
- # pyrefly: ignore [missing-attribute]
- return _nvtx.rangeStartA(msg)
- def range_end(range_id) -> None:
- """
- Mark the end of a range for a given range_id.
- Args:
- range_id (int): an unique handle for the start range.
- """
- # pyrefly: ignore [missing-attribute]
- _nvtx.rangeEnd(range_id)
- def _device_range_start(msg: str, stream: int = 0) -> object:
- """
- Marks the start of a range with string message.
- It returns an opaque heap-allocated handle for this range
- to pass to the corresponding call to device_range_end().
- A key difference between this and range_start is that the
- range_start marks the range right away, while _device_range_start
- marks the start of the range as soon as all the tasks on the
- CUDA stream are completed.
- Returns: An opaque heap-allocated handle that should be passed to _device_range_end().
- Args:
- msg (str): ASCII message to associate with the range.
- stream (int): CUDA stream id.
- """
- # pyrefly: ignore [missing-attribute]
- return _nvtx.deviceRangeStart(msg, stream)
- def _device_range_end(range_handle: object, stream: int = 0) -> None:
- """
- Mark the end of a range for a given range_handle as soon as all the tasks
- on the CUDA stream are completed.
- Args:
- range_handle: an unique handle for the start range.
- stream (int): CUDA stream id.
- """
- # pyrefly: ignore [missing-attribute]
- _nvtx.deviceRangeEnd(range_handle, stream)
- def mark(msg):
- """
- Describe an instantaneous event that occurred at some point.
- Args:
- msg (str): ASCII message to associate with the event.
- """
- return _nvtx.markA(msg)
- @contextmanager
- def range(msg, *args, **kwargs):
- """
- Context manager / decorator that pushes an NVTX range at the beginning
- of its scope, and pops it at the end. If extra arguments are given,
- they are passed as arguments to msg.format().
- Args:
- msg (str): message to associate with the range
- """
- range_push(msg.format(*args, **kwargs))
- try:
- yield
- finally:
- range_pop()
|