profiler.py 1.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. # mypy: allow-untyped-defs
  2. import contextlib
  3. from . import check_error, cudart
  4. __all__ = ["start", "stop", "profile"]
  5. DEFAULT_FLAGS = [
  6. "gpustarttimestamp",
  7. "gpuendtimestamp",
  8. "gridsize3d",
  9. "threadblocksize",
  10. "streamid",
  11. "enableonstart 0",
  12. "conckerneltrace",
  13. ]
  14. def start():
  15. r"""Starts cuda profiler data collection.
  16. .. warning::
  17. Raises CudaError in case of it is unable to start the profiler.
  18. """
  19. check_error(cudart().cudaProfilerStart())
  20. def stop():
  21. r"""Stops cuda profiler data collection.
  22. .. warning::
  23. Raises CudaError in case of it is unable to stop the profiler.
  24. """
  25. check_error(cudart().cudaProfilerStop())
  26. @contextlib.contextmanager
  27. def profile():
  28. """
  29. Enable profiling.
  30. Context Manager to enabling profile collection by the active profiling tool from CUDA backend.
  31. Example:
  32. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
  33. >>> import torch
  34. >>> model = torch.nn.Linear(20, 30).cuda()
  35. >>> inputs = torch.randn(128, 20).cuda()
  36. >>> with torch.cuda.profiler.profile() as prof:
  37. ... model(inputs)
  38. """
  39. try:
  40. start()
  41. yield
  42. finally:
  43. stop()