__init__.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. r"""
PyTorch Profiler is a tool that allows the collection of performance metrics during training and inference.
Profiler's context manager API can be used to better understand which model operators are the most expensive,
examine their input shapes and stack traces, study device kernel activity, and visualize the execution trace.

.. note::
    An earlier version of the API in the :mod:`torch.autograd` module is considered legacy and will be deprecated.
  7. """
  8. import os
  9. from typing import Any
  10. from typing_extensions import TypeVarTuple, Unpack
  11. from torch._C._autograd import _supported_activities, DeviceType, kineto_available
  12. from torch._C._profiler import _ExperimentalConfig, ProfilerActivity, RecordScope
  13. from torch._environment import is_fbcode
  14. from torch.autograd.profiler import KinetoStepTracker, record_function
  15. from torch.optim.optimizer import Optimizer, register_optimizer_step_post_hook
  16. from .profiler import (
  17. _KinetoProfile,
  18. ExecutionTraceObserver,
  19. profile,
  20. ProfilerAction,
  21. schedule,
  22. supported_activities,
  23. tensorboard_trace_handler,
  24. )
  25. __all__ = [
  26. "profile",
  27. "schedule",
  28. "supported_activities",
  29. "tensorboard_trace_handler",
  30. "ProfilerAction",
  31. "ProfilerActivity",
  32. "kineto_available",
  33. "DeviceType",
  34. "record_function",
  35. "ExecutionTraceObserver",
  36. ]
  37. from . import itt
  38. _Ts = TypeVarTuple("_Ts")
  39. def _optimizer_post_hook(
  40. optimizer: Optimizer, args: tuple[Unpack[_Ts]], kwargs: dict[str, Any]
  41. ) -> None:
  42. KinetoStepTracker.increment_step("Optimizer")
  43. if os.environ.get("KINETO_USE_DAEMON", "") or (
  44. is_fbcode() and os.environ.get("KINETO_FORCE_OPTIMIZER_HOOK", "")
  45. ):
  46. _ = register_optimizer_step_post_hook(_optimizer_post_hook)