profiler.py 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190
  1. # mypy: allow-untyped-defs
  2. import gzip
  3. import json
  4. import os
  5. import shutil
  6. import tempfile
  7. from abc import ABC, abstractmethod
  8. from collections.abc import Callable, Iterable
  9. from enum import Enum
  10. from functools import partial
  11. from typing import Any, Optional
  12. from typing_extensions import deprecated, Self
  13. from warnings import warn
  14. import torch
  15. import torch.autograd.profiler as prof
  16. from torch._C import _get_privateuse1_backend_name
  17. from torch._C._profiler import (
  18. _add_execution_trace_observer,
  19. _disable_execution_trace_observer,
  20. _enable_execution_trace_observer,
  21. _ExperimentalConfig,
  22. _remove_execution_trace_observer,
  23. )
  24. from torch._environment import is_fbcode
  25. from torch._utils_internal import profiler_allow_cudagraph_cupti_lazy_reinit_cuda12
  26. from torch.autograd import kineto_available, ProfilerActivity
  27. from torch.profiler._memory_profiler import MemoryProfile, MemoryProfileTimeline
# Public API of this module: the names exported by a star-import.
__all__ = [
    "supported_activities",
    "ProfilerAction",
    "schedule",
    "tensorboard_trace_handler",
    "profile",
    "ExecutionTraceObserver",
]
# NOTE(review): presumably the label attached to per-step profiler ranges; the
# code that consumes it is outside the part of the file visible here — confirm.
PROFILER_STEP_NAME = "ProfilerStep"
  37. _WARNINGS_SHOWN = set()
  38. def _warn_once(msg, category=UserWarning, stacklevel=2):
  39. if msg not in _WARNINGS_SHOWN:
  40. _WARNINGS_SHOWN.add(msg)
  41. warn(msg, category=category, stacklevel=stacklevel)
  42. class _NumpyEncoder(json.JSONEncoder):
  43. """
  44. Json encoder for numpy types (np.int, np.float, np.array etc.)
  45. Returns default encoder if numpy is not available
  46. """
  47. def default(self, obj):
  48. """Encode NumPy types to JSON"""
  49. try:
  50. import numpy as np
  51. except ImportError:
  52. return json.JSONEncoder.default(self, obj)
  53. if isinstance(obj, np.integer):
  54. return int(obj)
  55. elif isinstance(obj, np.floating):
  56. return float(obj)
  57. elif isinstance(obj, np.ndarray):
  58. return obj.tolist()
  59. else:
  60. return json.JSONEncoder.default(self, obj)
  61. def supported_activities():
  62. """
  63. Returns a set of supported profiler tracing activities.
  64. Note: profiler uses CUPTI library to trace on-device CUDA kernels.
  65. In case when CUDA is enabled but CUPTI is not available, passing
  66. ``ProfilerActivity.CUDA`` to profiler results in using the legacy CUDA
  67. profiling code (same as in the legacy ``torch.autograd.profiler``).
  68. This, in turn, results in including CUDA time in the profiler table output,
  69. but not in the JSON trace.
  70. """
  71. return torch.autograd._supported_activities()
  72. class _ITraceObserver(ABC):
  73. """Abstract interface for a Trace observer.
  74. This satisfies 3 methods: start, stop and cleanup"""
  75. @abstractmethod
  76. def start(self):
  77. pass
  78. @abstractmethod
  79. def stop(self):
  80. pass
  81. @abstractmethod
  82. def cleanup(self):
  83. pass
  84. class _KinetoProfile:
  85. """Low-level profiler wrap the autograd profile
  86. Args:
  87. activities (iterable): list of activity groups (CPU, CUDA) to use in profiling, supported values:
  88. ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``,
  89. ``torch.profiler.ProfilerActivity.XPU``.
  90. Default value: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA
  91. or (when available) ProfilerActivity.XPU.
  92. record_shapes (bool): save information about operator's input shapes.
  93. profile_memory (bool): track tensor memory allocation/deallocation (see ``export_memory_timeline``
  94. for more details).
  95. with_stack (bool): record source information (file and line number) for the ops.
  96. with_flops (bool): use formula to estimate the FLOPS of specific operators
  97. (matrix multiplication and 2D convolution).
  98. with_modules (bool): record module hierarchy (including function names)
  99. corresponding to the callstack of the op. e.g. If module A's forward call's
  100. module B's forward which contains an aten::add op,
  101. then aten::add's module hierarchy is A.B
  102. Note that this support exist, at the moment, only for TorchScript models
  103. and not eager mode models.
  104. experimental_config (_ExperimentalConfig) : A set of experimental options
  105. used by profiler libraries like Kineto. Note, backward compatibility is not guaranteed.
  106. execution_trace_observer (ExecutionTraceObserver) : A PyTorch Execution Trace Observer object.
  107. `PyTorch Execution Traces <https://arxiv.org/pdf/2305.14516.pdf>`__ offer a graph based
  108. representation of AI/ML workloads and enable replay benchmarks, simulators, and emulators.
  109. When this argument is included the observer start() and stop() will be called for the
  110. same time window as PyTorch profiler.
  111. acc_events (bool): Enable the accumulation of FunctionEvents across multiple profiling cycles
  112. post_processing_timeout_s (float): Optional timeout in seconds for post-processing profiler
  113. results. In this context, post-processing happens after the profiling itself has finished.
  114. If specified, event parsing will stop after this duration and return partial results. Useful
  115. for handling large traces that may take too long to process.
  116. .. note::
  117. This API is experimental and subject to change in the future.
  118. Enabling shape and stack tracing results in additional overhead.
  119. When record_shapes=True is specified, profiler will temporarily hold references to the tensors;
  120. that may further prevent certain optimizations that depend on the reference count and introduce
  121. extra tensor copies.
  122. """
  123. def __init__(
  124. self,
  125. *,
  126. activities: Iterable[ProfilerActivity] | None = None,
  127. record_shapes: bool = False,
  128. profile_memory: bool = False,
  129. with_stack: bool = False,
  130. with_flops: bool = False,
  131. with_modules: bool = False,
  132. experimental_config: _ExperimentalConfig | None = None,
  133. execution_trace_observer: _ITraceObserver | None = None,
  134. acc_events: bool = False,
  135. custom_trace_id_callback: Callable[[], str] | None = None,
  136. post_processing_timeout_s: float | None = None,
  137. ) -> None:
  138. self.activities = set(activities) if activities else supported_activities()
  139. self.record_shapes = record_shapes
  140. self.with_flops = with_flops
  141. self.profile_memory = profile_memory
  142. self.with_stack = with_stack
  143. self.with_modules = with_modules
  144. self.experimental_config = experimental_config
  145. self.execution_trace_observer = execution_trace_observer
  146. self.acc_events = acc_events
  147. self.custom_trace_id_callback = custom_trace_id_callback
  148. self.post_processing_timeout_s = post_processing_timeout_s
  149. self.profiler: prof.profile | None = None
  150. self.has_cudagraphs = False
  151. self.mem_tl: MemoryProfileTimeline | None = None
  152. self.use_device = None
  153. if ProfilerActivity.CUDA in self.activities:
  154. # pyrefly: ignore [bad-assignment]
  155. self.use_device = "cuda"
  156. elif ProfilerActivity.XPU in self.activities:
  157. # pyrefly: ignore [bad-assignment]
  158. self.use_device = "xpu"
  159. elif ProfilerActivity.MTIA in self.activities:
  160. # pyrefly: ignore [bad-assignment]
  161. self.use_device = "mtia"
  162. elif ProfilerActivity.HPU in self.activities:
  163. # pyrefly: ignore [bad-assignment]
  164. self.use_device = "hpu"
  165. elif ProfilerActivity.PrivateUse1 in self.activities:
  166. # pyrefly: ignore [bad-assignment]
  167. self.use_device = _get_privateuse1_backend_name()
  168. # user-defined metadata to be amended to the trace
  169. self.preset_metadata: dict[str, str] = {}
  170. def start(self) -> None:
  171. self.prepare_trace()
  172. self.start_trace()
  173. def stop(self) -> None:
  174. self.stop_trace()
  175. def prepare_trace(self) -> None:
  176. if hasattr(torch, "_inductor"):
  177. import torch._inductor.config as inductor_config
  178. self.has_cudagraphs = inductor_config.triton.cudagraphs
  179. if (self.profiler is None) or (not self.acc_events):
  180. self.profiler = prof.profile(
  181. use_cpu=(ProfilerActivity.CPU in self.activities),
  182. use_device=self.use_device,
  183. record_shapes=self.record_shapes,
  184. with_flops=self.with_flops,
  185. profile_memory=self.profile_memory,
  186. with_stack=self.with_stack,
  187. with_modules=self.with_modules,
  188. use_kineto=True,
  189. experimental_config=self.experimental_config,
  190. acc_events=self.acc_events,
  191. custom_trace_id_callback=self.custom_trace_id_callback,
  192. post_processing_timeout_s=self.post_processing_timeout_s,
  193. )
  194. if (self.profiler is not None) and (not self.acc_events):
  195. _warn_once(
  196. "Warning: Profiler clears events at the end of each cycle."
  197. "Only events from the current cycle will be reported."
  198. "To keep events across cycles, set acc_events=True."
  199. )
  200. self.profiler._prepare_trace()
  201. def start_trace(self) -> None:
  202. if self.execution_trace_observer:
  203. self.execution_trace_observer.start()
  204. if self.profiler is None:
  205. raise AssertionError("Profiler must be initialized before starting trace")
  206. self.profiler._start_trace()
  207. if self.profile_memory:
  208. self.add_metadata_json("profile_memory", "1")
  209. if self.with_stack:
  210. self.add_metadata_json("with_stack", "1")
  211. if self.record_shapes:
  212. self.add_metadata_json("record_shapes", "1")
  213. if self.with_modules:
  214. self.add_metadata_json("with_modules", "1")
  215. if self.with_flops:
  216. self.add_metadata_json("with_flops", "1")
  217. if kineto_available():
  218. dist_info = self._get_distributed_info()
  219. if dist_info:
  220. self.add_metadata_json(
  221. "distributedInfo", json.dumps(dist_info, cls=_NumpyEncoder)
  222. )
  223. cuda_version = None
  224. if hasattr(torch, "version"):
  225. from torch.torch_version import TorchVersion
  226. cuda_version = TorchVersion(getattr(torch.version, "cuda", "0.0"))
  227. if self.has_cudagraphs and (
  228. (cuda_version and cuda_version < "12.6")
  229. or not profiler_allow_cudagraph_cupti_lazy_reinit_cuda12()
  230. ):
  231. os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1"
  232. self.add_metadata_json("DISABLE_CUPTI_LAZY_REINIT", "1")
  233. # FIXME: CUDA Graph does not work well with CUPTI teardown.
  234. # 1) crashes on 1st lazy CUPTI re-init after teardown (CUDA 11)
  235. # 2) crashes on 2nd non-lazy CUPTI re-init after teardown (CUDA 12)
  236. # Workaround: turn off CUPTI teardown when using CUDA Graphs.
  237. os.environ["TEARDOWN_CUPTI"] = "0"
  238. # Insert the preset user metadata to the trace
  239. for k, v in self.preset_metadata.items():
  240. self.add_metadata_json(k, v)
  241. def stop_trace(self) -> None:
  242. if self.execution_trace_observer:
  243. self.execution_trace_observer.stop()
  244. if self.profiler is None:
  245. raise AssertionError("Profiler must be initialized before stopping trace")
  246. self.profiler.__exit__(None, None, None)
  247. def export_chrome_trace(self, path: str):
  248. """
  249. Exports the collected trace in Chrome JSON format. If kineto is enabled, only
  250. last cycle in schedule is exported.
  251. """
  252. if self.profiler is None:
  253. raise AssertionError(
  254. "Profiler must be initialized before exporting chrome trace"
  255. )
  256. if path.endswith(".gz"):
  257. with tempfile.NamedTemporaryFile("w+b", suffix=".json") as fp:
  258. retvalue = self.profiler.export_chrome_trace(fp.name)
  259. with open(fp.name, "rb") as fin, gzip.open(path, "wb") as fout:
  260. fout.writelines(fin)
  261. return retvalue
  262. else:
  263. return self.profiler.export_chrome_trace(path)
  264. def export_stacks(self, path: str, metric: str = "self_cpu_time_total"):
  265. """Save stack traces to a file
  266. Args:
  267. path (str): save stacks file to this location;
  268. metric (str): metric to use: "self_cpu_time_total" or "self_cuda_time_total"
  269. """
  270. if self.profiler is None:
  271. raise AssertionError("Profiler must be initialized before exporting stacks")
  272. return self.profiler.export_stacks(path, metric)
  273. def toggle_collection_dynamic(
  274. self, enable: bool, activities: Iterable[ProfilerActivity]
  275. ) -> None:
  276. """Toggle collection of activities on/off at any point of collection. Currently supports toggling Torch Ops
  277. (CPU) and CUDA activity supported in Kineto
  278. Args:
  279. activities (iterable): list of activity groups to use in profiling, supported values:
  280. ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``
  281. Examples:
  282. .. code-block:: python
  283. with torch.profiler.profile(
  284. activities=[
  285. torch.profiler.ProfilerActivity.CPU,
  286. torch.profiler.ProfilerActivity.CUDA,
  287. ]
  288. ) as p:
  289. code_to_profile_0()
  290. // turn off collection of all CUDA activity
  291. p.toggle_collection_dynamic(False, [torch.profiler.ProfilerActivity.CUDA])
  292. code_to_profile_1()
  293. // turn on collection of all CUDA activity
  294. p.toggle_collection_dynamic(True, [torch.profiler.ProfilerActivity.CUDA])
  295. code_to_profile_2()
  296. print(p.key_averages().table(
  297. sort_by="self_cuda_time_total", row_limit=-1))
  298. """
  299. if self.profiler is None:
  300. return
  301. self.profiler.toggle_collection_dynamic(enable, activities)
  302. def key_averages(
  303. self,
  304. group_by_input_shape: bool = False,
  305. group_by_stack_n: int = 0,
  306. group_by_overload_name: bool = False,
  307. ):
  308. """Averages events, grouping them by operator name and (optionally) input shapes, stack
  309. and overload name.
  310. .. note::
  311. To use shape/stack functionality make sure to set record_shapes/with_stack
  312. when creating profiler context manager.
  313. """
  314. if self.profiler is None:
  315. raise AssertionError(
  316. "Profiler must be initialized before getting key averages"
  317. )
  318. return self.profiler.key_averages(
  319. group_by_input_shape, group_by_stack_n, group_by_overload_name
  320. )
  321. def events(self):
  322. """
  323. Returns the list of unaggregated profiler events,
  324. to be used in the trace callback or after the profiling is finished
  325. """
  326. if self.profiler is None:
  327. raise AssertionError("Profiler must be initialized before accessing events")
  328. return self.profiler.function_events
  329. def add_metadata(self, key: str, value: str) -> None:
  330. """
  331. Adds a user defined metadata with a string key and a string value
  332. into the trace file
  333. """
  334. wrapped_value = '"' + value.replace('"', '\\"') + '"'
  335. torch.autograd._add_metadata_json(key, wrapped_value)
  336. def add_metadata_json(self, key: str, value: str) -> None:
  337. """
  338. Adds a user defined metadata with a string key and a valid json value
  339. into the trace file
  340. """
  341. torch.autograd._add_metadata_json(key, value)
  342. def preset_metadata_json(self, key: str, value: str) -> None:
  343. """
  344. Preset a user defined metadata when the profiler is not started
  345. and added into the trace file later.
  346. Metadata is in the format of a string key and a valid json value
  347. """
  348. self.preset_metadata[key] = value
  349. def _get_distributed_info(self):
  350. import torch.distributed as dist
  351. if not dist.is_available() or not dist.is_initialized():
  352. return None
  353. backend = dist.get_backend()
  354. dist_info = {
  355. "backend": backend,
  356. "rank": dist.get_rank(),
  357. "world_size": dist.get_world_size(),
  358. "pg_count": dist.get_pg_count(),
  359. "pg_config": dist.distributed_c10d._get_all_pg_configs(),
  360. }
  361. if backend == "nccl":
  362. nccl_version = torch.cuda.nccl.version()
  363. # pyrefly: ignore [bad-typed-dict-key, unsupported-operation]
  364. dist_info["nccl_version"] = ".".join(str(v) for v in nccl_version)
  365. return dist_info
  366. def _memory_profile(self) -> MemoryProfile:
  367. required = ("record_shapes", "profile_memory", "with_stack")
  368. missing = [f"{i}=True" for i in required if not getattr(self, i)]
  369. if missing:
  370. raise ValueError(f"{', '.join(missing)} required for memory profiling.")
  371. if self.profiler is None or self.profiler.kineto_results is None:
  372. raise AssertionError(
  373. "Profiler and kineto_results must be initialized for memory profiling"
  374. )
  375. return MemoryProfile(self.profiler.kineto_results)
  376. @deprecated(
  377. "`export_memory_timeline` is deprecated and will be removed in a future version. "
  378. "Please use `torch.cuda.memory._record_memory_history` and `torch.cuda.memory._export_memory_snapshot` instead.",
  379. category=FutureWarning,
  380. )
  381. def export_memory_timeline(self, path: str, device: str | None = None) -> None:
  382. """Export memory event information from the profiler collected
  383. tree for a given device, and export a timeline plot. There are 3
  384. exportable files using ``export_memory_timeline``, each controlled by the
  385. ``path``'s suffix.
  386. - For an HTML compatible plot, use the suffix ``.html``, and a memory timeline
  387. plot will be embedded as a PNG file in the HTML file.
  388. - For plot points consisting of ``[times, [sizes by category]]``, where
  389. ``times`` are timestamps and ``sizes`` are memory usage for each category.
  390. The memory timeline plot will be saved a JSON (``.json``) or gzipped JSON
  391. (``.json.gz``) depending on the suffix.
  392. - For raw memory points, use the suffix ``.raw.json.gz``. Each raw memory
  393. event will consist of ``(timestamp, action, numbytes, category)``, where
  394. ``action`` is one of ``[PREEXISTING, CREATE, INCREMENT_VERSION, DESTROY]``,
  395. and ``category`` is one of the enums from
  396. ``torch.profiler._memory_profiler.Category``.
  397. Output: Memory timeline written as gzipped JSON, JSON, or HTML.
  398. .. deprecated::
  399. ``export_memory_timeline`` is deprecated and will be removed in a future version.
  400. Please use ``torch.cuda.memory._record_memory_history`` and
  401. ``torch.cuda.memory._export_memory_snapshot`` instead.
  402. """
  403. # Default to device 0, if unset. Fallback on cpu.
  404. if device is None:
  405. if self.use_device and self.use_device != "cuda":
  406. device = self.use_device + ":0"
  407. else:
  408. device = "cuda:0" if torch.cuda.is_available() else "cpu"
  409. # Construct the memory timeline plot data
  410. self.mem_tl = MemoryProfileTimeline(self._memory_profile())
  411. # Depending on the file suffix, save the data as json.gz or json.
  412. # For html, we can embed the image into an HTML file.
  413. if path.endswith(".html"):
  414. self.mem_tl.export_memory_timeline_html(path, device)
  415. elif path.endswith(".gz"):
  416. with tempfile.NamedTemporaryFile("w+t", suffix=".json") as fp:
  417. if path.endswith("raw.json.gz"):
  418. self.mem_tl.export_memory_timeline_raw(fp.name, device)
  419. else:
  420. self.mem_tl.export_memory_timeline(fp.name, device)
  421. with open(fp.name) as fin, gzip.open(path, "wt") as fout:
  422. fout.writelines(fin)
  423. else:
  424. self.mem_tl.export_memory_timeline(path, device)
  425. class ProfilerAction(Enum):
  426. """
  427. Profiler actions that can be taken at the specified intervals
  428. """
  429. NONE = 0
  430. WARMUP = 1
  431. RECORD = 2
  432. RECORD_AND_SAVE = 3
  433. def schedule(
  434. *,
  435. wait: int,
  436. warmup: int,
  437. active: int,
  438. repeat: int = 0,
  439. skip_first: int = 0,
  440. skip_first_wait: int = 0,
  441. ) -> Callable:
  442. """
  443. Returns a callable that can be used as profiler ``schedule`` argument. The profiler will skip
  444. the first ``skip_first`` steps, then wait for ``wait`` steps, then do the warmup for the next ``warmup`` steps,
  445. then do the active recording for the next ``active`` steps and then repeat the cycle starting with ``wait`` steps.
  446. The optional number of cycles is specified with the ``repeat`` parameter, the zero value means that
  447. the cycles will continue until the profiling is finished.
  448. The ``skip_first_wait`` parameter controls whether the first ``wait`` stage should be skipped.
  449. This can be useful if a user wants to wait longer than ``skip_first`` between cycles, but not
  450. for the first profile. For example, if ``skip_first`` is 10 and ``wait`` is 20, the first cycle will
  451. wait 10 + 20 = 30 steps before warmup if ``skip_first_wait`` is zero, but will wait only 10
  452. steps if ``skip_first_wait`` is non-zero. All subsequent cycles will then wait 20 steps between the
  453. last active and warmup.
  454. """
  455. def schedule_fn(step: int) -> ProfilerAction:
  456. if step < 0:
  457. raise AssertionError(f"Step must be non-negative. Got {step}.")
  458. if step < skip_first:
  459. return ProfilerAction.NONE
  460. else:
  461. step -= skip_first
  462. # If wait >> skip_first and we want to grab profiling early, shift left by wait if skip_first_wait is True
  463. if skip_first_wait != 0:
  464. step += wait
  465. num_steps = wait + warmup + active
  466. if repeat > 0 and step / num_steps >= repeat:
  467. return ProfilerAction.NONE
  468. mod_step = step % num_steps
  469. if mod_step < wait:
  470. return ProfilerAction.NONE
  471. elif mod_step < wait + warmup:
  472. return ProfilerAction.WARMUP
  473. else:
  474. return (
  475. ProfilerAction.RECORD
  476. if mod_step < num_steps - 1
  477. else ProfilerAction.RECORD_AND_SAVE
  478. )
  479. if wait < 0 or warmup < 0 or active <= 0 or repeat < 0 or skip_first < 0:
  480. raise AssertionError(
  481. f"Invalid profiler schedule arguments. Got wait={wait} (need >= 0), warmup={warmup} (need >= 0), "
  482. f"active={active} (need > 0), repeat={repeat} (need >= 0), skip_first={skip_first} (need >= 0)."
  483. )
  484. if warmup == 0:
  485. warn(
  486. "Profiler won't be using warmup, this can skew profiler results",
  487. stacklevel=2,
  488. )
  489. return schedule_fn
  490. def _default_schedule_fn(_: int) -> ProfilerAction:
  491. """
  492. Default profiler behavior - immediately starts recording the events,
  493. keeps doing it on every profiler step.
  494. """
  495. return ProfilerAction.RECORD
  496. def tensorboard_trace_handler(
  497. dir_name: str, worker_name: str | None = None, use_gzip: bool = False
  498. ):
  499. """
  500. Outputs tracing files to directory of ``dir_name``, then that directory can be
  501. directly delivered to tensorboard as logdir.
  502. ``worker_name`` should be unique for each worker in distributed scenario,
  503. it will be set to '[hostname]_[pid]' by default.
  504. """
  505. import socket
  506. import time
  507. def handler_fn(prof) -> None:
  508. nonlocal worker_name
  509. if not os.path.isdir(dir_name):
  510. try:
  511. os.makedirs(dir_name, exist_ok=True)
  512. except Exception as e:
  513. raise RuntimeError("Can't create directory: " + dir_name) from e
  514. if not worker_name:
  515. worker_name = f"{socket.gethostname()}_{os.getpid()}"
  516. # Use nanosecond here to avoid naming clash when exporting the trace
  517. file_name = f"{worker_name}.{time.time_ns()}.pt.trace.json"
  518. if use_gzip:
  519. file_name = file_name + ".gz"
  520. prof.export_chrome_trace(os.path.join(dir_name, file_name))
  521. return handler_fn
  522. class profile(_KinetoProfile):
  523. """Profiler context manager.
  524. Args:
  525. activities (iterable): list of activity groups (CPU, CUDA) to use in profiling, supported values:
  526. ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``,
  527. ``torch.profiler.ProfilerActivity.XPU``.
  528. Default value: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA
  529. or (when available) ProfilerActivity.XPU.
  530. schedule (Callable): callable that takes step (int) as a single parameter and returns
  531. ``ProfilerAction`` value that specifies the profiler action to perform at each step.
  532. on_trace_ready (Callable): callable that is called at each step when ``schedule``
  533. returns ``ProfilerAction.RECORD_AND_SAVE`` during the profiling.
  534. record_shapes (bool): save information about operator's input shapes.
  535. profile_memory (bool): track tensor memory allocation/deallocation.
  536. with_stack (bool): record source information (file and line number) for the ops.
  537. with_flops (bool): use formula to estimate the FLOPs (floating point operations) of specific operators
  538. (matrix multiplication and 2D convolution).
  539. with_modules (bool): record module hierarchy (including function names)
  540. corresponding to the callstack of the op. e.g. If module A's forward call's
  541. module B's forward which contains an aten::add op,
  542. then aten::add's module hierarchy is A.B
  543. Note that this support exist, at the moment, only for TorchScript models
  544. and not eager mode models.
  545. experimental_config (_ExperimentalConfig) : A set of experimental options
  546. used for Kineto library features. Note, backward compatibility is not guaranteed.
  547. execution_trace_observer (ExecutionTraceObserver) : A PyTorch Execution Trace Observer object.
  548. `PyTorch Execution Traces <https://arxiv.org/pdf/2305.14516.pdf>`__ offer a graph based
  549. representation of AI/ML workloads and enable replay benchmarks, simulators, and emulators.
  550. When this argument is included the observer start() and stop() will be called for the
  551. same time window as PyTorch profiler. See the examples section below for a code sample.
  552. acc_events (bool): Enable the accumulation of FunctionEvents across multiple profiling cycles
  553. post_processing_timeout_s (float): Optional timeout in seconds for post-processing profiler
  554. results. If specified, event parsing will stop after this duration and return partial
  555. results. Useful for handling large traces that may take too long to process.
  556. use_cuda (bool):
  557. .. deprecated:: 1.8.1
  558. use ``activities`` instead.
  559. .. note::
  560. Use :func:`~torch.profiler.schedule` to generate the callable schedule.
  561. Non-default schedules are useful when profiling long training jobs
  562. and allow the user to obtain multiple traces at the different iterations
  563. of the training process.
  564. The default schedule simply records all the events continuously for the
  565. duration of the context manager.
  566. .. note::
  567. Use :func:`~torch.profiler.tensorboard_trace_handler` to generate result files for TensorBoard:
  568. ``on_trace_ready=torch.profiler.tensorboard_trace_handler(dir_name)``
  569. After profiling, result files can be found in the specified directory. Use the command:
  570. ``tensorboard --logdir dir_name``
  571. to see the results in TensorBoard.
  572. For more information, see
  573. `PyTorch Profiler TensorBoard Plugin <https://github.com/pytorch/kineto/tree/master/tb_plugin>`__
  574. .. note::
  575. Enabling shape and stack tracing results in additional overhead.
  576. When record_shapes=True is specified, profiler will temporarily hold references to the tensors;
  577. that may further prevent certain optimizations that depend on the reference count and introduce
  578. extra tensor copies.
  579. Examples:
  580. .. code-block:: python
  581. with torch.profiler.profile(
  582. activities=[
  583. torch.profiler.ProfilerActivity.CPU,
  584. torch.profiler.ProfilerActivity.CUDA,
  585. ]
  586. ) as p:
  587. code_to_profile()
  588. print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
  589. Using the profiler's ``schedule``, ``on_trace_ready`` and ``step`` functions:
  590. .. code-block:: python
  591. # Non-default profiler schedule allows user to turn profiler on and off
  592. # on different iterations of the training loop;
  593. # trace_handler is called every time a new trace becomes available
  594. def trace_handler(prof):
  595. print(
  596. prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)
  597. )
  598. # prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step_num) + ".json")
  599. with torch.profiler.profile(
  600. activities=[
  601. torch.profiler.ProfilerActivity.CPU,
  602. torch.profiler.ProfilerActivity.CUDA,
  603. ],
  604. # In this example with wait=1, warmup=1, active=2, repeat=1,
  605. # profiler will skip the first step/iteration,
  606. # start warming up on the second, record
  607. # the third and the fourth iterations,
  608. # after which the trace will become available
  609. # and on_trace_ready (when set) is called;
  610. # the cycle repeats starting with the next step
  611. schedule=torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=1),
  612. on_trace_ready=trace_handler,
  613. # on_trace_ready=torch.profiler.tensorboard_trace_handler('./log')
  614. # used when outputting for tensorboard
  615. ) as p:
  616. for iter in range(N):
  617. code_iteration_to_profile(iter)
  618. # send a signal to the profiler that the next iteration has started
  619. p.step()
  620. The following sample shows how to set up an Execution Trace Observer (`execution_trace_observer`)
  621. .. code-block:: python
  622. with torch.profiler.profile(
  623. ...
  624. execution_trace_observer=(
  625. ExecutionTraceObserver().register_callback("./execution_trace.json")
  626. ),
  627. ) as p:
  628. for iter in range(N):
  629. code_iteration_to_profile(iter)
  630. p.step()
  631. You can also refer to test_execution_trace_with_kineto() in tests/profiler/test_profiler.py.
  632. Note: One can also pass any object satisfying the _ITraceObserver interface.
  633. """
  def __init__(
      self,
      *,
      activities: Iterable[ProfilerActivity] | None = None,
      schedule: Callable[[int], ProfilerAction] | None = None,
      on_trace_ready: Callable[..., Any] | None = None,
      record_shapes: bool = False,
      profile_memory: bool = False,
      with_stack: bool = False,
      with_flops: bool = False,
      with_modules: bool = False,
      experimental_config: _ExperimentalConfig | None = None,
      execution_trace_observer: _ITraceObserver | None = None,
      acc_events: bool = False,
      # deprecated:
      use_cuda: bool | None = None,
      custom_trace_id_callback: Callable[[], str] | None = None,
      post_processing_timeout_s: float | None = None,
  ) -> None:
      """Set up the profiler, its step schedule and its action state machine.

      Args:
          activities: activities to profile; when empty or ``None`` the result
              of ``supported_activities()`` is used.
          schedule: callable mapping a step number to a ``ProfilerAction``.
              When omitted, ``_default_schedule_fn`` is used and no per-step
              markers are recorded.
          on_trace_ready: callable invoked with this profiler whenever the
              state machine runs the trace-ready action.
          execution_trace_observer: observer to use; when falsy, one is built
              from the environment via
              ``ExecutionTraceObserver.build_execution_trace_obs_from_env()``
              (which may return ``None``).
          use_cuda: deprecated. ``True`` adds ``ProfilerActivity.CUDA`` to the
              activity set, ``False`` removes it; any non-``None`` value emits
              a ``FutureWarning``.
          record_shapes, profile_memory, with_stack, with_flops, with_modules,
          experimental_config, acc_events, custom_trace_id_callback,
          post_processing_timeout_s: forwarded unchanged to the base class.

      Raises:
          AssertionError: if the resolved activity set is empty.
      """
      # Materialize once: ``activities`` may be a one-shot iterable, and the
      # deprecated ``use_cuda`` flag needs a mutable set to adjust.
      activities_set = set(activities) if activities else supported_activities()
      if use_cuda is not None:
          warn(
              "`use_cuda` is deprecated, use `activities` argument instead",
              FutureWarning,
              stacklevel=2,
          )
          if use_cuda:
              activities_set.add(ProfilerActivity.CUDA)
          elif ProfilerActivity.CUDA in activities_set:
              activities_set.remove(ProfilerActivity.CUDA)
      if not activities_set:
          raise AssertionError("No valid profiler activities found")

      super().__init__(
          # Forward the resolved set, not the raw argument: otherwise the
          # ``use_cuda`` adjustment above is discarded and an already-consumed
          # iterator would be handed to the base class.
          # NOTE(review): assumes the base initializer accepts any iterable of
          # ProfilerActivity - confirm against its definition.
          activities=activities_set,
          record_shapes=record_shapes,
          profile_memory=profile_memory,
          with_stack=with_stack,
          with_flops=with_flops,
          with_modules=with_modules,
          experimental_config=experimental_config,
          execution_trace_observer=execution_trace_observer
          if execution_trace_observer
          else ExecutionTraceObserver.build_execution_trace_obs_from_env(),
          acc_events=acc_events,
          custom_trace_id_callback=custom_trace_id_callback,
          post_processing_timeout_s=post_processing_timeout_s,
      )

      if schedule:
          self.schedule = schedule
          # add step markers into the trace and table view
          self.record_steps = True
      else:
          self.schedule = _default_schedule_fn
          self.record_steps = False
      self.on_trace_ready = on_trace_ready
      self.step_num = 0
      self.current_action = self.schedule(self.step_num)
      self.step_rec_fn: prof.record_function | None = None

      self.action_map: dict[
          tuple[ProfilerAction, ProfilerAction | None], list[Any]
      ] = {
          # key is (prev_action, current_action), value is action list corresponding to the state pair.
          (ProfilerAction.NONE, ProfilerAction.NONE): [],
          (ProfilerAction.NONE, ProfilerAction.WARMUP): [self.prepare_trace],
          (ProfilerAction.NONE, ProfilerAction.RECORD): [
              self.prepare_trace,
              self.start_trace,
          ],
          (ProfilerAction.NONE, ProfilerAction.RECORD_AND_SAVE): [
              self.prepare_trace,
              self.start_trace,
          ],
          (ProfilerAction.WARMUP, ProfilerAction.NONE): [
              partial(warn, "Incorrect schedule: WARMUP followed by NONE"),
              self.start_trace,
              self.stop_trace,
          ],
          (ProfilerAction.WARMUP, ProfilerAction.WARMUP): [],
          (ProfilerAction.WARMUP, ProfilerAction.RECORD): [self.start_trace],
          (ProfilerAction.WARMUP, ProfilerAction.RECORD_AND_SAVE): [self.start_trace],
          (ProfilerAction.RECORD, ProfilerAction.NONE): [
              partial(warn, "Incorrect schedule: RECORD followed by NONE"),
              self.stop_trace,
          ],
          (ProfilerAction.RECORD, ProfilerAction.WARMUP): [
              partial(warn, "Incorrect schedule: RECORD followed by WARMUP"),
              self.stop_trace,
          ],
          (ProfilerAction.RECORD, ProfilerAction.RECORD): [],
          (ProfilerAction.RECORD, ProfilerAction.RECORD_AND_SAVE): [],
          (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.NONE): [
              self.stop_trace,
              self._trace_ready,
          ],
          (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.WARMUP): [
              self.stop_trace,
              self._trace_ready,
              self.prepare_trace,
          ],
          (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.RECORD): [
              self.stop_trace,
              self._trace_ready,
              self.prepare_trace,
              self.start_trace,
          ],
          (ProfilerAction.RECORD_AND_SAVE, ProfilerAction.RECORD_AND_SAVE): [
              self.stop_trace,
              self._trace_ready,
              self.prepare_trace,
              self.start_trace,
          ],
          # used for exit action
          (ProfilerAction.WARMUP, None): [self.start_trace, self.stop_trace],
          (ProfilerAction.RECORD, None): [self.stop_trace, self._trace_ready],
          (ProfilerAction.RECORD_AND_SAVE, None): [
              self.stop_trace,
              self._trace_ready,
          ],
      }
      # Start tracking increments to profiler step, this will be used
      # by Kineto
      prof.KinetoStepTracker.init_step_count(PROFILER_STEP_NAME)
  def __enter__(self):
      """Start profiling on entry to the ``with`` block and hand back this profiler."""
      self.start()
      return self
  def __exit__(self, exc_type, exc_val, exc_tb):
      """Stop profiling and release step-tracking and observer resources.

      ``stop()`` runs the exit actions of the state machine, which can invoke
      the user-supplied ``on_trace_ready`` callback. The step-count erasure and
      the observer cleanup are placed in a ``finally`` clause so they still
      happen when ``stop()`` raises; the exception then propagates as usual.

      Returns ``None``, so an exception raised inside the ``with`` body is
      never suppressed.
      """
      try:
          self.stop()
      finally:
          prof.KinetoStepTracker.erase_step_count(PROFILER_STEP_NAME)
          if self.execution_trace_observer:
              self.execution_trace_observer.cleanup()
  def start(self) -> None:
      """Run the initial state transition and open the first step marker.

      The state machine moves from ``ProfilerAction.NONE`` to the action chosen
      for the current step. When step recording is enabled, a
      ``ProfilerStep#<step_num>`` record_function range is entered as well.
      """
      self._transit_action(ProfilerAction.NONE, self.current_action)
      if not self.record_steps:
          return
      step_marker = prof.record_function(f"ProfilerStep#{self.step_num}")
      self.step_rec_fn = step_marker
      step_marker.__enter__()
  def stop(self) -> None:
      """Close the open step marker (if any) and run the exit transition.

      The exit transition is looked up with ``None`` as the target action.
      """
      if self.record_steps:
          open_marker = self.step_rec_fn
          if open_marker:
              open_marker.__exit__(None, None, None)
      self._transit_action(self.current_action, None)
  def step(self) -> None:
      """
      Signals the profiler that the next profiling step has started.

      Closes the current step marker, advances ``step_num``, asks the schedule
      for the new action and runs the matching transition. The Kineto step
      counter is incremented when ``KINETO_USE_DAEMON`` is set, or when running
      in fbcode with ``KINETO_FORCE_STEP_HOOK`` set. Finally a new step marker
      is opened if step recording is enabled.
      """
      if self.record_steps:
          open_marker = self.step_rec_fn
          if open_marker:
              open_marker.__exit__(None, None, None)

      previous_action = self.current_action
      self.step_num += 1
      self.current_action = self.schedule(self.step_num)
      self._transit_action(previous_action, self.current_action)

      # Same short-circuit order as a plain ``if``: is_fbcode() is consulted
      # only when KINETO_USE_DAEMON is unset or empty.
      notify_kineto = os.environ.get("KINETO_USE_DAEMON", "") or (
          is_fbcode() and os.environ.get("KINETO_FORCE_STEP_HOOK", "")
      )
      if notify_kineto:
          prof.KinetoStepTracker.increment_step(PROFILER_STEP_NAME)

      if not self.record_steps:
          return
      next_marker = prof.record_function(f"ProfilerStep#{self.step_num}")
      self.step_rec_fn = next_marker
      next_marker.__enter__()
  def set_custom_trace_id_callback(self, callback) -> None:
      """
      Stores ``callback`` as this profiler's ``custom_trace_id_callback``,
      the hook used when a new trace ID is generated.
      """
      self.custom_trace_id_callback = callback
  def get_trace_id(self):
      """
      Returns the trace ID of the underlying profiler, or ``None`` when no
      underlying profiler exists (``self.profiler`` is ``None``).
      """
      active_profiler = self.profiler
      return None if active_profiler is None else active_profiler.trace_id
  806. def _trace_ready(self) -> None:
  807. if self.on_trace_ready:
  808. self.on_trace_ready(self)
  809. def _transit_action(self, prev_action, current_action) -> None:
  810. action_list = self.action_map.get((prev_action, current_action))
  811. if action_list:
  812. for action in action_list:
  813. action()
  814. def _stats(self) -> prof._ProfilerStats | None:
  815. if self.profiler is None:
  816. return None
  817. return self.profiler._stats
  class ExecutionTraceObserver(_ITraceObserver):
      """Execution Trace Observer

      Each process can have a single ExecutionTraceObserver instance. The observer
      can be added to record function callbacks via calling register_callback()
      explicitly. Without calling unregister_callback(), repeated calls to
      register_callback() will not add additional observers to record function
      callbacks. Once an ExecutionTraceObserver is created, the start() and stop()
      methods control when the event data is recorded.

      Deleting or calling unregister_callback() will remove the observer from the
      record function callbacks, finalize the output file, and will stop
      incurring any overheads.
      """

      def __init__(self) -> None:
          """
          Initializes the default states.
          """
          self._registered = False
          self._execution_trace_running = False
          self.extra_resources_collection = False
          self.resources_dir: str = ""
          # Path requested by the caller (may end in ".gz").
          self.output_file_path: str = ""
          # Path handed to the observer; differs from output_file_path only
          # when the requested output is gzip-compressed.
          self.output_file_path_observer: str = ""

      def __del__(self) -> None:
          """
          Calls unregister_callback() to make sure to finalize outputs.
          """
          # Guard against instances whose __init__ never ran: __del__ is still
          # invoked for them and would otherwise fail on the missing attribute.
          if getattr(self, "_registered", False):
              self.unregister_callback()

      @staticmethod
      def build_execution_trace_obs_from_env() -> Optional["ExecutionTraceObserver"]:
          """
          Returns an ExecutionTraceObserver instance if the environment variable
          ENABLE_PYTORCH_EXECUTION_TRACE is set to 1, otherwise returns None.

          Configures the observer to also collect extra resources if the environment variable
          ``ENABLE_PYTORCH_EXECUTION_TRACE_EXTRAS=1``. These are resources such as generated kernels,
          index tensor data etc. that are required to make the Execution Trace replayable.

          The trace is written to a freshly created temporary ``*.et.json`` file;
          if that file cannot be created a warning is issued and None is returned.
          """
          if os.environ.get("ENABLE_PYTORCH_EXECUTION_TRACE", "0") != "1":
              return None

          try:
              # delete=False: only the name is needed here, the file must
              # outlive this ``with`` block.
              with tempfile.NamedTemporaryFile(
                  "w+t", suffix=".et.json", delete=False
              ) as fp:
                  filename = fp.name
          except Exception as e:
              warn(
                  f"Execution trace will not be recorded. Exception on creating default temporary file: {e}",
                  stacklevel=2,
              )
              return None

          et = ExecutionTraceObserver()
          et.register_callback(filename)
          # additionally, check if the env requires us to collect extra resources
          et.set_extra_resource_collection(
              os.environ.get("ENABLE_PYTORCH_EXECUTION_TRACE_EXTRAS", "0") == "1"
          )
          return et

      def set_extra_resource_collection(self, val) -> None:
          """
          Collects extra resources such as generated kernels, index tensor data, and any other
          metadata that is required to complete the Execution Trace content.

          The caller should call this method with val=True after calling register_callback() if they want
          to collect the extra resources.
          """
          self.extra_resources_collection = val
          if self.extra_resources_collection:
              # Resolve (and create if needed) the resources directory right away.
              self.get_resources_dir(can_create=True)

      def register_callback(self, output_file_path: str) -> Self:
          """
          Adds ET observer to record function callbacks. The data will be
          written to output_file_path.

          When ``output_file_path`` ends with ``.gz`` the observer is pointed at
          a temporary uncompressed file, which unregister_callback() compresses
          into ``output_file_path``. Calling this on an already registered
          observer is a no-op. Returns ``self``.
          """

          def get_temp_uncompressed_file() -> str:
              with tempfile.NamedTemporaryFile("w+b", suffix=".json", delete=False) as fp:
                  return fp.name

          if not self._registered:
              self.output_file_path = output_file_path
              if output_file_path.endswith(".gz"):
                  output_file_path = get_temp_uncompressed_file()
              self.output_file_path_observer = output_file_path
              # NOTE(review): the helper's return value is stored as the
              # registration flag - presumably truthy on success; confirm.
              self._registered = _add_execution_trace_observer(output_file_path)
          return self

      def get_resources_dir(self, can_create=False) -> str | None:
          """
          Generates the resources directory for the generated kernels,
          or index tensor data or any other metadata that is required
          to complete the Execution Trace content.

          The directory is created right where the ET file is being output.

          Only works if the observer has called set_extra_resource_collection(val=True).

          Returns None if the observer is not configured with extra resource collection,
          or if the directory does not exist and could not be (or was not allowed
          to be) created.
          """
          if not self.extra_resources_collection:
              return None
          if self.resources_dir:
              # already created
              return self.resources_dir
          generated_path = ExecutionTraceObserver.get_resources_dir_for_et_path(
              self.output_file_path, create_dir=can_create
          )
          if not generated_path:
              # could not find or create the resources dir
              return None
          self.resources_dir = generated_path
          return self.resources_dir

      @staticmethod
      def get_resources_dir_for_et_path(
          trace_path, create_dir: bool = False
      ) -> str | None:
          """
          Returns the resources directory that belongs to ``trace_path``:
          ``<dir of trace>/<trace file name without its last extension>_resources``.

          If the directory does not exist it is created when ``create_dir`` is
          True; otherwise, or when creation fails (a warning is issued), None is
          returned.
          """
          work_dir, file_name = os.path.split(trace_path)
          resource_dir = os.path.join(
              work_dir, os.path.splitext(file_name)[0] + "_resources"
          )
          if os.path.exists(resource_dir):
              return resource_dir
          if not create_dir:
              return None
          try:
              os.mkdir(resource_dir)
          except Exception:
              warn(
                  f"Execution trace exception when creating {resource_dir}",
                  stacklevel=2,
              )
              return None
          return resource_dir

      def unregister_callback(self) -> None:
          """
          Removes ET observer from record function callbacks.

          Stops capturing, copies generated kernel files into the resources
          directory (best effort: failures only warn), removes the observer and,
          for ``.gz`` outputs, compresses the temporary trace into the requested
          output path. Does nothing if the observer is not registered.
          """

          def _save_triton_kernels() -> None:
              try:
                  resource_dir = self.get_resources_dir()
              except Exception as e:
                  warn(
                      f"Execution trace exception when generating resource directory: {e}",
                      stacklevel=2,
                  )
                  return
              if not resource_dir:
                  return

              # Save the kernel paths for the generated kernels
              from torch._inductor.codecache import PyCodeCache

              kernel_files = [
                  v.__file__
                  for v in PyCodeCache.modules
                  if getattr(v, "__file__", None) is not None
              ]
              for kernel_file in kernel_files:
                  name = os.path.basename(kernel_file)
                  dst = os.path.join(resource_dir, name)
                  shutil.copyfile(kernel_file, dst)

          def _save_gz_file(uncompressed_file: str, output_file: str) -> None:
              print(f"Execution Trace: compressing {uncompressed_file} to {output_file}")
              with open(uncompressed_file, "rb") as fin:
                  with gzip.open(output_file, "wb") as fout:
                      shutil.copyfileobj(fin, fout)
              # Only reached when compression succeeded, so a failed compression
              # never destroys the uncompressed trace.
              os.remove(uncompressed_file)

          if not self._registered:
              return

          self.stop()

          try:
              _save_triton_kernels()
          except Exception as e:
              warn(f"Execution trace failed to save kernels: {e}", stacklevel=2)

          _remove_execution_trace_observer()
          # The observer is gone from here on. Clear the flag before compressing
          # so that a compression failure cannot trigger a second removal later
          # (e.g. from __del__ or cleanup()).
          self._registered = False

          # Must mirror the ".gz" test in register_callback(): only then was the
          # observer redirected to a separate temporary file. Compressing a file
          # onto itself would truncate and then delete the trace.
          if self.output_file_path.endswith(".gz"):
              _save_gz_file(self.output_file_path_observer, self.output_file_path)

      @property
      def is_registered(self):
          """
          Returns True if the execution trace observer is registered, otherwise False.
          """
          return self._registered

      def is_running(self):
          """
          Returns True if the observer is running, otherwise False.
          """
          return self._execution_trace_running

      def start(self) -> None:
          """
          Starts to capture.

          Has no effect unless the observer is registered and not already running.
          """
          if self._registered and not self._execution_trace_running:
              _enable_execution_trace_observer()
              self._execution_trace_running = True
              self._record_pg_config()

      def stop(self) -> None:
          """
          Stops to capture.

          Has no effect if the observer is not running.
          """
          if self._execution_trace_running:
              _disable_execution_trace_observer()
              self._execution_trace_running = False

      def cleanup(self) -> None:
          """
          Calls unregister_callback() to make sure to finalize outputs.
          """
          self.unregister_callback()

      def get_output_file_path(self) -> str | None:
          """
          Returns the output file name or None.
          """
          return self.output_file_path or None

      def _record_pg_config(self) -> None:
          """
          Records the process group config info into the trace as a node named
          ``## process_group:init ##``, when registered and torch.distributed is
          available and initialized.
          """
          # Records the PG config info to the trace as node:
          # ## process_group:init ##
          if (
              self.is_registered
              and torch.distributed.is_available()
              and torch.distributed.is_initialized()
          ):
              pg_config_info = torch.distributed.distributed_c10d._world.pg_config_info
              torch.autograd._record_function_with_args_enter(
                  "## process_group:init ##",
                  json.dumps(pg_config_info, cls=_NumpyEncoder),
              )