# itt.py 1.8 KB

  1. # mypy: allow-untyped-defs
  2. from contextlib import contextmanager
  3. from typing import NoReturn
  4. try:
  5. from torch._C import _itt
  6. except ImportError:
  7. class _ITTStub:
  8. @staticmethod
  9. def _fail(*args, **kwargs) -> NoReturn:
  10. raise RuntimeError(
  11. "ITT functions not installed. Are you sure you have a ITT build?"
  12. )
  13. @staticmethod
  14. def is_available() -> bool:
  15. return False
  16. rangePush = _fail
  17. rangePop = _fail
  18. mark = _fail
  19. _itt = _ITTStub() # type: ignore[assignment]
  20. __all__ = ["is_available", "range_push", "range_pop", "mark", "range"]
  21. def is_available():
  22. """
  23. Check if ITT feature is available or not
  24. """
  25. return _itt.is_available()
  26. def range_push(msg):
  27. """
  28. Pushes a range onto a stack of nested range span. Returns zero-based
  29. depth of the range that is started.
  30. Arguments:
  31. msg (str): ASCII message to associate with range
  32. """
  33. return _itt.rangePush(msg)
  34. def range_pop():
  35. """
  36. Pops a range off of a stack of nested range spans. Returns the
  37. zero-based depth of the range that is ended.
  38. """
  39. return _itt.rangePop()
  40. def mark(msg):
  41. """
  42. Describe an instantaneous event that occurred at some point.
  43. Arguments:
  44. msg (str): ASCII message to associate with the event.
  45. """
  46. return _itt.mark(msg)
  47. @contextmanager
  48. def range(msg, *args, **kwargs):
  49. """
  50. Context manager / decorator that pushes an ITT range at the beginning
  51. of its scope, and pops it at the end. If extra arguments are given,
  52. they are passed as arguments to msg.format().
  53. Args:
  54. msg (str): message to associate with the range
  55. """
  56. range_push(msg.format(*args, **kwargs))
  57. try:
  58. yield
  59. finally:
  60. range_pop()