profiler.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. """Integration with pytorch profiler."""
  2. import os
  3. import wandb
  4. from wandb.errors import Error, UsageError
  5. from wandb.sdk.lib import telemetry
# Module names handed to `wandb.util.get_module` inside `torch_trace_handler`:
# the top-level PyTorch package and its profiler submodule.
PYTORCH_MODULE = "torch"
PYTORCH_PROFILER_MODULE = "torch.profiler"
  8. def torch_trace_handler():
  9. """Create a trace handler for traces generated by the profiler.
  10. Provide as an argument to `torch.profiler.profile`:
  11. ```python
  12. torch.profiler.profile(..., on_trace_ready=wandb.profiler.torch_trace_handler())
  13. ```
  14. Calling this function ensures that profiler charts & tables can be viewed in
  15. your run dashboard on wandb.ai.
  16. Please note that `wandb.init()` must be called before this function is
  17. invoked, and the reinit setting must not be set to "create_new". The PyTorch
  18. (torch) version must also be at least 1.9, in order to ensure stability of
  19. their Profiler API.
  20. Args:
  21. None
  22. Returns:
  23. None
  24. Raises:
  25. UsageError if wandb.init() hasn't been called before profiling.
  26. Error if torch version is less than 1.9.0.
  27. Examples:
  28. ```python
  29. run = wandb.init()
  30. run.config.id = "profile_code"
  31. with torch.profiler.profile(
  32. schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
  33. on_trace_ready=wandb.profiler.torch_trace_handler(),
  34. record_shapes=True,
  35. with_stack=True,
  36. ) as prof:
  37. for i, batch in enumerate(dataloader):
  38. if step >= 5:
  39. break
  40. train(batch)
  41. prof.step()
  42. ```
  43. """
  44. from packaging.version import parse
  45. torch = wandb.util.get_module(PYTORCH_MODULE, required=True)
  46. torch_profiler = wandb.util.get_module(PYTORCH_PROFILER_MODULE, required=True)
  47. if parse(torch.__version__) < parse("1.9.0"):
  48. raise Error(
  49. f"torch version must be at least 1.9 in order to use the PyTorch Profiler API.\
  50. \nVersion of torch currently installed: {torch.__version__}"
  51. )
  52. try:
  53. logdir = os.path.join(wandb.run.dir, "pytorch_traces") # type: ignore
  54. os.mkdir(logdir)
  55. except AttributeError:
  56. raise UsageError(
  57. "Please call `wandb.init()` before `wandb.profiler.torch_trace_handler()`"
  58. ) from None
  59. with telemetry.context() as tel:
  60. tel.feature.torch_profiler_trace = True
  61. return torch_profiler.tensorboard_trace_handler(logdir)