# profiler.py (3.5 KB)
  1. import contextlib
  2. from collections.abc import Iterator
  3. from typing import Literal
  4. import torch
  5. __all__ = [
  6. "start",
  7. "stop",
  8. "profile",
  9. "metal_capture",
  10. "is_metal_capture_enabled",
  11. "is_capturing_metal",
  12. ]
  13. ProfilerMode = Literal["interval", "event", "interval,event"]
  14. def start(mode: ProfilerMode = "interval", wait_until_completed: bool = False) -> None:
  15. r"""Start OS Signpost tracing from MPS backend.
  16. The generated OS Signposts could be recorded and viewed in
  17. XCode Instruments Logging tool.
  18. Args:
  19. mode(str): OS Signpost tracing mode could be "interval", "event",
  20. or both "interval,event".
  21. The interval mode traces the duration of execution of the operations,
  22. whereas event mode marks the completion of executions.
  23. See document `Recording Performance Data`_ for more info.
  24. wait_until_completed(bool): Waits until the MPS Stream complete
  25. executing each encoded GPU operation. This helps generating single
  26. dispatches on the trace's timeline.
  27. Note that enabling this option would affect the performance negatively.
  28. .. _Recording Performance Data:
  29. https://developer.apple.com/documentation/os/logging/recording_performance_data
  30. """
  31. mode_normalized = mode.lower().replace(" ", "")
  32. torch._C._mps_profilerStartTrace( # type: ignore[attr-defined]
  33. mode_normalized, wait_until_completed
  34. )
  35. def stop() -> None:
  36. r"""Stops generating OS Signpost tracing from MPS backend."""
  37. torch._C._mps_profilerStopTrace() # type: ignore[attr-defined]
  38. @contextlib.contextmanager
  39. def profile(
  40. mode: ProfilerMode = "interval", wait_until_completed: bool = False
  41. ) -> Iterator[None]:
  42. r"""Context Manager to enabling generating OS Signpost tracing from MPS backend.
  43. Args:
  44. mode(str): OS Signpost tracing mode could be "interval", "event",
  45. or both "interval,event".
  46. The interval mode traces the duration of execution of the operations,
  47. whereas event mode marks the completion of executions.
  48. See document `Recording Performance Data`_ for more info.
  49. wait_until_completed(bool): Waits until the MPS Stream complete
  50. executing each encoded GPU operation. This helps generating single
  51. dispatches on the trace's timeline.
  52. Note that enabling this option would affect the performance negatively.
  53. .. _Recording Performance Data:
  54. https://developer.apple.com/documentation/os/logging/recording_performance_data
  55. """
  56. try:
  57. start(mode, wait_until_completed)
  58. yield
  59. finally:
  60. stop()
  61. def is_metal_capture_enabled() -> bool:
  62. """Checks if `metal_capture` context manager is usable
  63. To enable metal capture, set MTL_CAPTURE_ENABLED envvar
  64. """
  65. return torch._C._mps_isCaptureEnabled() # type: ignore[attr-defined, no-any-return]
  66. def is_capturing_metal() -> bool:
  67. """Checks if metal capture is in progress"""
  68. return torch._C._mps_isCapturing() # type: ignore[attr-defined, no-any-return]
  69. @contextlib.contextmanager
  70. def metal_capture(fname: str) -> Iterator[None]:
  71. """Context manager that enables capturing of Metal calls into gputrace"""
  72. try:
  73. torch._C._mps_startCapture(fname) # type: ignore[attr-defined]
  74. yield
  75. # Drain all the work that were enqueued during the context call
  76. torch.mps.synchronize()
  77. finally:
  78. torch._C._mps_stopCapture() # type: ignore[attr-defined]