# profiler_util.py (54 KB)
  1. # mypy: allow-untyped-defs
  2. import bisect
  3. import itertools
  4. import math
  5. from collections import defaultdict, namedtuple
  6. from operator import attrgetter
  7. from typing import Any, Optional
  8. from typing_extensions import deprecated
  9. import torch
  10. from torch.autograd import DeviceType
  11. __all__ = [
  12. "EventList",
  13. "FormattedTimesMixin",
  14. "Interval",
  15. "Kernel",
  16. "FunctionEvent",
  17. "FunctionEventAvg",
  18. "StringTable",
  19. "MemRecordsAcc",
  20. ]
  21. class EventList(list):
  22. """A list of profiling events with helper methods for analysis and visualization.
  23. EventList extends the standard Python list to provide specialized methods for
  24. working with profiling events (FunctionEvent or FunctionEventAvg objects).
  25. It includes utilities for aggregating statistics, formatting output tables,
  26. and exporting profiling data.
  27. This class is typically returned by profiler methods and should not be
  28. instantiated directly by users.
  29. Args:
  30. *args: Standard list arguments.
  31. use_device (str, optional): Device type for profiling ("cuda", "xpu", etc.).
  32. profile_memory (bool, optional): Whether memory profiling was enabled. Default: False.
  33. with_flops (bool, optional): Whether to include FLOP counts. Default: False.
  34. Attributes:
  35. _use_device (str): Device type being profiled.
  36. _profile_memory (bool): Whether memory profiling is enabled.
  37. _with_flops (bool): Whether FLOP counting is enabled.
  38. _tree_built (bool): Whether the event tree structure has been built.
  39. Key Methods:
  40. table(...): Format events as a table string for display.
  41. export_chrome_trace(path): Export to Chrome tracing format.
  42. export_stacks(path, metric): Export stack traces with metrics.
  43. key_averages(...): Compute averaged statistics grouped by operation name.
  44. total_average(): Compute aggregate totals across all events (sums, not averages).
  45. Properties:
  46. self_cpu_time_total: Sum of self CPU time across all events.
  47. Example::
  48. import torch
  49. from torch.profiler import profile, ProfilerActivity
  50. with profile(activities=[ProfilerActivity.CPU]) as prof:
  51. x = torch.randn(100, 100)
  52. y = torch.matmul(x, x)
  53. # EventList is returned by prof.events()
  54. events = prof.events()
  55. # Display as formatted table
  56. print(
  57. events.table(
  58. sort_by="cpu_time_total", row_limit=20, top_level_events_only=False
  59. )
  60. )
  61. # Export to Chrome tracing format
  62. events.export_chrome_trace("trace.json")
  63. # Get averaged statistics
  64. avg_events = events.key_averages()
  65. print(avg_events.table())
  66. # Export stack traces
  67. events.export_stacks("stacks.txt", "self_cpu_time_total")
  68. See Also:
  69. - :class:`FunctionEvent`: Individual profiling event
  70. - :class:`FunctionEventAvg`: Averaged profiling statistics
  71. - :meth:`table`: Format events as a readable table
  72. - :meth:`key_averages`: Aggregate events by operation name
  73. """
  74. def __init__(self, *args, **kwargs):
  75. use_device = kwargs.pop("use_device", None)
  76. profile_memory = kwargs.pop("profile_memory", False)
  77. with_flops = kwargs.pop("with_flops", False)
  78. # pyrefly: ignore [not-iterable]
  79. super().__init__(*args, **kwargs)
  80. self._use_device = use_device
  81. self._profile_memory = profile_memory
  82. self._tree_built = False
  83. self._with_flops = with_flops
  84. def _build_tree(self):
  85. self._populate_cpu_children()
  86. self._remove_dup_nodes()
  87. self._set_backward_stacktraces()
  88. self._tree_built = True
  89. def __str__(self):
  90. return self.table()
  91. def _remove_dup_nodes(self):
  92. while True:
  93. to_delete = set()
  94. for idx in range(len(self)):
  95. if (
  96. self[idx].cpu_parent is not None
  97. and self[idx].cpu_parent.name == self[idx].name
  98. and len(self[idx].cpu_parent.cpu_children) == 1
  99. ):
  100. self[idx].cpu_parent.cpu_children = self[idx].cpu_children
  101. self[idx].cpu_parent.kernels = self[idx].kernels # lift kernels up
  102. for ch in self[idx].cpu_children:
  103. ch.cpu_parent = self[idx].cpu_parent
  104. to_delete.add(idx)
  105. if len(to_delete) == 0:
  106. break
  107. new_evts = [ev for ind, ev in enumerate(self) if ind not in to_delete]
  108. self.clear()
  109. self.extend(new_evts)
  110. def _populate_cpu_children(self):
  111. """Populate child events into each underlying FunctionEvent object.
  112. One event is a child of another if [s1, e1) is inside [s2, e2). Where
  113. s1 and e1 would be start and end of the child event's interval. And
  114. s2 and e2 start and end of the parent event's interval
  115. Example: In event list [[0, 10], [1, 3], [3, 4]] would have make [0, 10]
  116. be a parent of two other intervals.
  117. If for any reason two intervals intersect only partially, this function
  118. will not record a parent child relationship between then.
  119. """
  120. # Some events can be async (i.e. start and end on different threads),
  121. # since it's generally undefined how to attribute children ranges to
  122. # async ranges, we do not use them when calculating nested ranges and stats
  123. sync_events = [
  124. evt
  125. for evt in self
  126. if not evt.is_async and evt.device_type == DeviceType.CPU
  127. ]
  128. events = sorted(
  129. sync_events,
  130. key=attrgetter("thread"),
  131. )
  132. # Group by both thread and node_id, so that events that happen to have
  133. # the same thread_id but are from different nodes aren't incorrectly
  134. # grouped together.
  135. threads = itertools.groupby(
  136. events, key=lambda event: (event.thread, event.node_id)
  137. )
  138. # For each thread we keep a stack of current nested parents.
  139. # We maintain the invariant that each interval is a subset of all other
  140. # intervals lower in the stack.
  141. #
  142. # First we sort the intervals by their start time. Then we iterate over them.
  143. # Every time we see a new interval we remove several parents from
  144. # the top until we restore the invariant. Then parent child relationship
  145. # if recorded if the stack is not empty.
  146. # Finally we add new interval to the list
  147. #
  148. # Algorithm has O(N * log(N)) complexity where N is number of
  149. # intervals
  150. for _thread_id, thread_events in threads:
  151. thread_events_ = sorted(
  152. thread_events,
  153. key=lambda event: [event.time_range.start, -event.time_range.end],
  154. )
  155. current_events: list[FunctionEvent] = []
  156. for event in thread_events_:
  157. while len(current_events) > 0:
  158. parent = current_events[-1]
  159. if (
  160. event.time_range.start >= parent.time_range.end
  161. or event.time_range.end > parent.time_range.end
  162. ):
  163. # this can't be a parent
  164. current_events.pop()
  165. else:
  166. parent.append_cpu_child(event)
  167. if event.cpu_parent is not None:
  168. raise AssertionError(
  169. f"There is already a CPU parent event for {event.key}"
  170. )
  171. event.set_cpu_parent(parent)
  172. break
  173. current_events.append(event)
  174. def _set_backward_stacktraces(self):
  175. def bw_parent(evt):
  176. if evt is None:
  177. return None
  178. elif evt.scope == 1: # BACKWARD_FUNCTION
  179. return evt
  180. else:
  181. return bw_parent(evt.cpu_parent)
  182. fwd_stacks = {}
  183. for evt in self:
  184. if bw_parent(evt) is None and evt.stack is not None:
  185. t = (evt.sequence_nr, evt.thread)
  186. if t not in fwd_stacks:
  187. fwd_stacks[t] = evt.stack
  188. for evt in self:
  189. p = bw_parent(evt)
  190. if p is not None:
  191. if p.fwd_thread is None:
  192. raise AssertionError(
  193. "Expected fwd_thread to be set for backward parent"
  194. )
  195. t = (p.sequence_nr, p.fwd_thread)
  196. evt.stack = fwd_stacks.get(t, [])
  197. @property
  198. def self_cpu_time_total(self):
  199. return sum(event.self_cpu_time_total for event in self)
  200. def table(
  201. self,
  202. sort_by=None,
  203. row_limit=100,
  204. max_src_column_width=75,
  205. max_name_column_width=55,
  206. max_shapes_column_width=80,
  207. header=None,
  208. top_level_events_only=False,
  209. time_unit=None,
  210. ):
  211. """Print an EventList as a nicely formatted table.
  212. Args:
  213. sort_by (str, optional): Attribute used to sort entries. By default
  214. they are printed in the same order as they were registered.
  215. Valid keys include: ``cpu_time``, ``cuda_time``, ``xpu_time``,
  216. ``cpu_time_total``, ``cuda_time_total``, ``xpu_time_total``,
  217. ``cpu_memory_usage``, ``cuda_memory_usage``, ``xpu_memory_usage``,
  218. ``self_cpu_memory_usage``, ``self_cuda_memory_usage``,
  219. ``self_xpu_memory_usage``, ``count``.
  220. top_level_events_only(bool, optional): Boolean flag to determine the
  221. selection of events to display. If true, the profiler will only
  222. display events at top level like top-level invocation of python
  223. `lstm`, python `add` or other functions, nested events like low-level
  224. cpu/cuda/xpu ops events are omitted for profiler result readability.
  225. time_unit(str, optional): A time unit to be used for all values in the
  226. table. Valid options are: ``s``, ``ms`` and ``us``.
  227. Returns:
  228. A string containing the table.
  229. """
  230. return _build_table(
  231. self,
  232. sort_by=sort_by,
  233. row_limit=row_limit,
  234. max_src_column_width=max_src_column_width,
  235. max_name_column_width=max_name_column_width,
  236. max_shapes_column_width=max_shapes_column_width,
  237. header=header,
  238. profile_memory=self._profile_memory,
  239. with_flops=self._with_flops,
  240. top_level_events_only=top_level_events_only,
  241. time_unit=time_unit,
  242. )
  243. def export_chrome_trace(self, path):
  244. """Export an EventList as a Chrome tracing tools file.
  245. The checkpoint can be later loaded and inspected under ``chrome://tracing`` URL.
  246. Args:
  247. path (str): Path where the trace will be written.
  248. """
  249. import os
  250. device_name = "cuda" if not self._use_device else self._use_device
  251. with open(path, "w") as f:
  252. next_id = 0
  253. # Use file IO over using json.dump since JSON dumping is very slow and
  254. # this technique is proven to give a 4x speedup.
  255. f.write("[")
  256. for evt in self:
  257. if evt.trace_name is None:
  258. continue
  259. f.write(
  260. '{{"name": "{}", '
  261. '"ph": "X", '
  262. '"ts": {}, '
  263. '"dur": {}, '
  264. '"tid": {}, '
  265. '"pid": "CPU functions", '
  266. '"args": {{}}}}, '.format(
  267. evt.trace_name,
  268. evt.time_range.start,
  269. evt.time_range.elapsed_us(),
  270. evt.thread
  271. if not evt.is_remote
  272. else f'" node_id:{evt.node_id}, thread_id:{evt.thread} "',
  273. )
  274. )
  275. for _ in evt.kernels:
  276. # 's' and 'f' draw Flow arrows from
  277. # the CPU launch to the GPU kernel
  278. f.write(
  279. f'{{"name": "{evt.trace_name}", '
  280. '"ph": "s", '
  281. f'"ts": {evt.time_range.start}, '
  282. f'"tid": {evt.thread}, '
  283. '"pid": "CPU functions", '
  284. f'"id": {next_id}, '
  285. f'"cat": "cpu_to_{device_name}", '
  286. '"args": {}}, '
  287. )
  288. # Note: use torch.profiler to get device kernel trace
  289. next_id += 1
  290. if len(self) > 0:
  291. # remove trailing whitespace and comma
  292. f.seek(f.tell() - 2, os.SEEK_SET)
  293. f.truncate()
  294. f.write("]")
  295. def supported_export_stacks_metrics(self):
  296. return [
  297. "self_cpu_time_total",
  298. "self_cuda_time_total",
  299. "self_xpu_time_total",
  300. "self_privateuse1_time_total",
  301. ]
  302. def export_stacks(self, path: str, metric: str):
  303. if metric not in self.supported_export_stacks_metrics():
  304. raise ValueError(
  305. "metric should be one of: "
  306. + str(self.supported_export_stacks_metrics())
  307. )
  308. translate_table = str.maketrans(" ;\t\n", "____")
  309. with open(path, "w") as f:
  310. for evt in self:
  311. if evt.stack and len(evt.stack) > 0:
  312. metric_value = getattr(
  313. evt,
  314. metric.replace("cuda", "device")
  315. .replace("xpu", "device")
  316. .replace("privateuse1", "device"),
  317. )
  318. if int(metric_value) > 0:
  319. stack_str = ""
  320. for entry in reversed(evt.stack):
  321. stack_str += entry.translate(translate_table)
  322. stack_str += ";"
  323. stack_str = stack_str[:-1] + " " + str(int(metric_value))
  324. f.write(stack_str + "\n")
  325. def key_averages(
  326. self,
  327. group_by_input_shapes=False,
  328. group_by_stack_n=0,
  329. group_by_overload_name=False,
  330. ):
  331. """Averages all function events over their keys.
  332. Args:
  333. group_by_input_shapes: group entries by
  334. (event name, input shapes) rather than just event name.
  335. This is useful to see which input shapes contribute to the runtime
  336. the most and may help with size-specific optimizations or
  337. choosing the best candidates for quantization (aka fitting a roof line)
  338. group_by_stack_n: group by top n stack trace entries
  339. group_by_overload_name: Differentiate operators by their overload name e.g. aten::add.Tensor
  340. and aten::add.out will be aggregated separately
  341. Returns:
  342. An EventList containing FunctionEventAvg objects.
  343. """
  344. if not self._tree_built:
  345. raise AssertionError(
  346. "Expected tree to be built before calling key_averages"
  347. )
  348. stats: dict[tuple[str, ...], FunctionEventAvg] = defaultdict(FunctionEventAvg)
  349. def get_key(
  350. event, group_by_input_shapes, group_by_stack_n, group_by_overload_name
  351. ) -> tuple[str, ...]:
  352. key = [
  353. str(event.key),
  354. str(event.node_id),
  355. str(event.device_type),
  356. str(event.is_legacy),
  357. str(event.is_user_annotation),
  358. ]
  359. if group_by_overload_name:
  360. key.append(evt.overload_name)
  361. if group_by_input_shapes:
  362. key.append(str(event.input_shapes))
  363. if group_by_stack_n > 0:
  364. key += event.stack[:group_by_stack_n]
  365. return tuple(key)
  366. for evt in self:
  367. stats[
  368. get_key(
  369. evt, group_by_input_shapes, group_by_stack_n, group_by_overload_name
  370. )
  371. ].add(evt)
  372. avg_list = EventList(
  373. stats.values(),
  374. use_device=self._use_device,
  375. profile_memory=self._profile_memory,
  376. with_flops=self._with_flops,
  377. )
  378. for evt in avg_list:
  379. evt.stack = evt.stack[:group_by_stack_n]
  380. if not group_by_input_shapes:
  381. evt.input_shapes = ""
  382. if not group_by_overload_name:
  383. evt.overload_name = ""
  384. return avg_list
  385. def total_average(self):
  386. """Compute aggregate statistics across all events.
  387. Accumulates statistics from all events into a single FunctionEventAvg object.
  388. This is primarily useful for computing total metrics (total CPU time, total
  389. memory usage, etc.) across the entire profiling session, regardless of
  390. operation type.
  391. Note:
  392. This sums up times and counts across ALL different operations, so the
  393. "average" metrics (like cpu_time) represent the average time per operation
  394. call across the entire session, mixing all operation types together.
  395. For per-operation averages, use :meth:`key_averages` instead.
  396. Returns:
  397. FunctionEventAvg: A single aggregate object with key="Total" containing
  398. accumulated statistics.
  399. """
  400. total_stat = FunctionEventAvg()
  401. for evt in self:
  402. total_stat += evt
  403. total_stat.key = None
  404. total_stat.key = "Total"
  405. return total_stat
  406. def _format_time(time_us):
  407. """Define how to format time in FunctionEvent."""
  408. US_IN_SECOND = 1000.0 * 1000.0
  409. US_IN_MS = 1000.0
  410. if time_us >= US_IN_SECOND:
  411. return f"{time_us / US_IN_SECOND:.3f}s"
  412. if time_us >= US_IN_MS:
  413. return f"{time_us / US_IN_MS:.3f}ms"
  414. return f"{time_us:.3f}us"
  415. def _format_time_share(time_us, total_time_us):
  416. """Define how to format time in FunctionEvent."""
  417. if total_time_us == 0:
  418. if time_us != 0:
  419. raise AssertionError(f"Expected time_us == 0 but got {time_us}")
  420. return "NaN"
  421. return f"{time_us * 100.0 / total_time_us:.2f}%"
  422. def _format_memory(nbytes):
  423. """Return a formatted memory size string."""
  424. KB = 1024
  425. MB = 1024 * KB
  426. GB = 1024 * MB
  427. if abs(nbytes) >= GB:
  428. return f"{nbytes * 1.0 / GB:.2f} GB"
  429. elif abs(nbytes) >= MB:
  430. return f"{nbytes * 1.0 / MB:.2f} MB"
  431. elif abs(nbytes) >= KB:
  432. return f"{nbytes * 1.0 / KB:.2f} KB"
  433. else:
  434. return str(nbytes) + " B"
  435. def _attr_formatter(name):
  436. return property(lambda self: _format_time(getattr(self, name)))
  437. class FormattedTimesMixin:
  438. """Helpers for FunctionEvent and FunctionEventAvg.
  439. The subclass should define `*_time_total` and `count` attributes.
  440. """
  441. cpu_time_str = _attr_formatter("cpu_time")
  442. device_time_str = _attr_formatter("device_time")
  443. cpu_time_total_str = _attr_formatter("cpu_time_total")
  444. device_time_total_str = _attr_formatter("device_time_total")
  445. self_cpu_time_total_str = _attr_formatter("self_cpu_time_total")
  446. self_device_time_total_str = _attr_formatter("self_device_time_total")
  447. @property
  448. def cpu_time(self):
  449. return 0.0 if self.count == 0 else 1.0 * self.cpu_time_total / self.count # type: ignore[attr-defined]
  450. @property
  451. def device_time(self):
  452. return 0.0 if self.count == 0 else 1.0 * self.device_time_total / self.count # type: ignore[attr-defined]
  453. @property
  454. @deprecated(
  455. "`cuda_time` is deprecated, please use `device_time` instead.",
  456. category=FutureWarning,
  457. )
  458. def cuda_time(self): # To be deprecated
  459. return self.device_time
  460. class Interval:
  461. def __init__(self, start, end):
  462. self.start = start
  463. self.end = end
  464. def elapsed_us(self):
  465. r"""
  466. Returns the length of the interval
  467. """
  468. return self.end - self.start
  469. Kernel = namedtuple("Kernel", ["name", "device", "duration"])
  470. class FunctionEvent(FormattedTimesMixin):
  471. """Profiling information about a single function.
  472. FunctionEvent records the execution of a single operation during profiling.
  473. These events are obtained from the profiler/kineto and contain detailed
  474. timing and memory usage information.
  475. .. note::
  476. FunctionEvent objects are typically created by the profiler/kineto and should not
  477. be instantiated directly by users. Access them through the profiler's output.
  478. Attributes:
  479. id (int): Unique identifier for this event.
  480. node_id (int): Node identifier for distributed profiling (-1 if not applicable).
  481. name (str): Name of the profiled function/operator.
  482. overload_name (str): Overload name for the operator (requires _ExperimentalConfig(capture_overload_names=True) set).
  483. trace_name (str): Same as name, just changes ProfilerStep* to ProfilerStep#
  484. time_range (Interval): Time interval containing start and end timestamps in microseconds.
  485. thread (int): Thread ID where the operation started.
  486. fwd_thread (int): Thread ID of the corresponding forward operation.
  487. kernels (List[Kernel]): List of device kernels launched by this operation.
  488. count (int): Number of times this event was called (usually 1).
  489. cpu_children (List[FunctionEvent]): Direct CPU child operations.
  490. cpu_parent (FunctionEvent): Direct CPU parent operation.
  491. input_shapes (Tuple[int, ...]): Shapes of input tensors (requires record_shapes=true).
  492. concrete_inputs (List[Any]): Concrete input values (requires record_shapes=true).
  493. kwinputs (Dict[str, Any]): Keyword arguments (requires record_shapes=true).
  494. stack (List[str]): Python stack trace where the operation was called (requires with_stack=true).
  495. scope (int): at::RecordScope identifier (0=forward, 1=backward, etc.).
  496. use_device (str): Device type being profiled ("cuda", "xpu", etc.).
  497. cpu_memory_usage (int): CPU memory allocated in bytes.
  498. device_memory_usage (int): Device memory allocated in bytes.
  499. is_async (bool): Whether this is an asynchronous operation.
  500. is_remote (bool): Whether this operation occurred on a remote node.
  501. sequence_nr (int): Sequence number for autograd operations.
  502. device_type (DeviceType): Type of device (CPU, CUDA, XPU, PrivateUse1, etc.).
  503. device_index (int): Index of the device (e.g., GPU 0, 1, 2).
  504. device_resource_id (int): Resource ID on the device (ie. stream ID).
  505. is_legacy (bool): Whether this is from the legacy profiler.
  506. flops (int): Estimated floating point operations.
  507. is_user_annotation (bool): Whether this is a user-annotated region.
  508. metadata_json (str): Additional metadata in JSON format.
  509. Properties:
  510. cpu_time_total (float): Total CPU time in microseconds.
  511. device_time_total (float): Total device (CUDA/XPU/etc) time in microseconds.
  512. self_cpu_time_total (float): CPU time excluding child operations.
  513. self_device_time_total (float): Device time excluding child operations.
  514. self_cpu_memory_usage (int): CPU memory usage excluding child operations.
  515. self_device_memory_usage (int): Device memory usage excluding child operations.
  516. cpu_time (float): Average CPU time per call.
  517. device_time (float): Average device time per call.
  518. key (str): Key used for grouping events (usually same as name).
  519. See Also:
  520. - :class:`torch.profiler.profile`: Context manager for profiling
  521. - :class:`EventList`: List container for FunctionEvent objects with helper methods
  522. - :class:`FunctionEventAvg`: Averaged statistics over multiple FunctionEvent objects
  523. """
  524. def __init__(
  525. self,
  526. id,
  527. name,
  528. thread,
  529. start_us,
  530. end_us,
  531. overload_name=None,
  532. fwd_thread=None,
  533. input_shapes=None,
  534. stack=None,
  535. scope=0,
  536. use_device=None,
  537. cpu_memory_usage=0,
  538. device_memory_usage=0,
  539. is_async=False,
  540. is_remote=False,
  541. sequence_nr=-1,
  542. node_id=-1,
  543. device_type=DeviceType.CPU,
  544. device_index=0,
  545. device_resource_id=None,
  546. is_legacy=False,
  547. flops=None,
  548. trace_name=None,
  549. concrete_inputs=None,
  550. kwinputs=None,
  551. is_user_annotation=False,
  552. metadata_json=None,
  553. ):
  554. self.id: int = id
  555. self.node_id: int = node_id
  556. self.name: str = name
  557. # pyrefly: ignore [bad-assignment]
  558. self.overload_name: str = overload_name
  559. # pyrefly: ignore [bad-assignment]
  560. self.trace_name: str = trace_name
  561. self.time_range: Interval = Interval(start_us, end_us)
  562. self.thread: int = thread
  563. self.fwd_thread: Optional[int] = fwd_thread
  564. self.kernels: list[Kernel] = []
  565. self.count: int = 1
  566. self.cpu_children: list[FunctionEvent] = []
  567. self.cpu_parent: Optional[FunctionEvent] = None
  568. # pyrefly: ignore [bad-assignment]
  569. self.input_shapes: tuple[int, ...] = input_shapes
  570. # pyrefly: ignore [bad-assignment]
  571. self.concrete_inputs: list[Any] = concrete_inputs
  572. # pyrefly: ignore [bad-assignment]
  573. self.kwinputs: dict[str, Any] = kwinputs
  574. # pyrefly: ignore [bad-assignment]
  575. self.stack: list = stack
  576. self.scope: int = scope
  577. self.use_device: Optional[str] = use_device
  578. self.cpu_memory_usage: int = cpu_memory_usage
  579. self.device_memory_usage: int = device_memory_usage
  580. self.is_async: bool = is_async
  581. self.is_remote: bool = is_remote
  582. self.sequence_nr: int = sequence_nr
  583. self.device_type: DeviceType = device_type
  584. self.device_index: int = device_index
  585. self.device_resource_id: int = (
  586. thread if device_resource_id is None else device_resource_id
  587. )
  588. self.is_legacy: bool = is_legacy
  589. self.flops: Optional[int] = flops
  590. self.is_user_annotation: Optional[bool] = is_user_annotation
  591. self.self_cpu_percent = -1
  592. self.total_cpu_percent = -1
  593. self.total_device_percent = -1
  594. self.metadata_json = metadata_json
  595. def append_kernel(self, name, device, duration):
  596. if self.device_type != DeviceType.CPU:
  597. raise AssertionError("Expected device_type to be CPU")
  598. self.kernels.append(Kernel(name, device, duration))
  599. def append_cpu_child(self, child):
  600. """Append a CPU child of type FunctionEvent.
  601. One is supposed to append only direct children to the event to have
  602. correct self cpu time being reported.
  603. """
  604. if self.device_type != DeviceType.CPU:
  605. raise AssertionError("Expected device_type to be CPU")
  606. if not isinstance(child, FunctionEvent):
  607. raise AssertionError("Expected child to be a FunctionEvent")
  608. if child.device_type != DeviceType.CPU:
  609. raise AssertionError("Expected child device_type to be CPU")
  610. self.cpu_children.append(child)
  611. def set_cpu_parent(self, parent):
  612. """Set the immediate CPU parent of type FunctionEvent.
  613. One profiling FunctionEvent should have only one CPU parent such that
  614. the child's range interval is completely inside the parent's. We use
  615. this connection to determine the event is from top-level op or not.
  616. """
  617. if self.device_type != DeviceType.CPU:
  618. raise AssertionError("Expected device_type to be CPU")
  619. if not isinstance(parent, FunctionEvent):
  620. raise AssertionError("Expected parent to be a FunctionEvent")
  621. if parent.device_type != DeviceType.CPU:
  622. raise AssertionError("Expected parent device_type to be CPU")
  623. self.cpu_parent = parent
  624. # Note: async events don't have children, are not used when computing 'self'
  625. # metrics of other events, have only total cpu time
  626. @property
  627. def self_cpu_memory_usage(self):
  628. if self.is_async or self.device_type != DeviceType.CPU:
  629. return 0
  630. return self.cpu_memory_usage - sum(
  631. child.cpu_memory_usage for child in self.cpu_children
  632. )
  633. @property
  634. def self_device_memory_usage(self):
  635. if self.is_async or self.device_type != DeviceType.CPU:
  636. return 0
  637. return self.device_memory_usage - sum(
  638. child.device_memory_usage for child in self.cpu_children
  639. )
  640. @property
  641. @deprecated(
  642. "`self_cuda_memory_usage` is deprecated. Use `self_device_memory_usage` instead.",
  643. category=FutureWarning,
  644. )
  645. def self_cuda_memory_usage(self): # To be deprecated
  646. return self.self_device_memory_usage
  @property
  def cpu_time_total(self):
      """Elapsed time of this event's range in microseconds; 0 for non-CPU events."""
      return self.time_range.elapsed_us() if self.device_type == DeviceType.CPU else 0

  @property
  def self_cpu_time_total(self):
      """CPU time spent in this event itself, excluding its CPU children.

      Async events and non-CPU events report 0.
      """
      if not self.is_async and self.device_type == DeviceType.CPU:
          own_time = self.cpu_time_total
          return own_time - sum(child.cpu_time_total for child in self.cpu_children)
      return 0
  @property
  def device_time_total(self):
      """Total device time associated with this event.

      * async events, or profiles without a device: 0
      * non-legacy CPU events: durations of this op's own kernels plus the
        ``device_time_total`` of every CPU child
      * legacy CPU events: durations of this op's kernels only (each legacy
        CPU event carries a single, fake kernel)
      * device events: the elapsed time of the event's own range (microseconds)

      Raises:
          AssertionError: for a non-CPU device type other than CUDA,
              PrivateUse1, MTIA, HPU or XPU.
      """
      if self.is_async or not self.use_device:
          return 0
      if self.device_type != DeviceType.CPU:
          supported_devices = (
              DeviceType.CUDA,
              DeviceType.PrivateUse1,
              DeviceType.MTIA,
              DeviceType.HPU,
              DeviceType.XPU,
          )
          if self.device_type not in supported_devices:
              raise AssertionError(
                  f"Expected device_type to be CUDA, PrivateUse1, MTIA, HPU or XPU, but got {self.device_type}"
              )
          return self.time_range.elapsed_us()
      own_kernel_time = sum(kinfo.duration for kinfo in self.kernels)
      if self.is_legacy:
          # each legacy cpu event has a single (fake) kernel
          return own_kernel_time
      # account for the kernels in the children ops
      return own_kernel_time + sum(ch.device_time_total for ch in self.cpu_children)

  @property
  @deprecated(
      "`cuda_time_total` is deprecated. Use `device_time_total` instead.",
      category=FutureWarning,
  )
  def cuda_time_total(self):  # To be deprecated
      """Deprecated alias of ``device_time_total`` (emits a FutureWarning)."""
      return self.device_time_total
  @property
  def self_device_time_total(self):
      """Device time attributed to this event alone.

      For CPU events this is ``device_time_total`` minus the
      ``device_time_total`` of every CPU child. For device events it equals
      ``device_time_total``. Async events and device-less profiles report 0.

      Raises:
          AssertionError: for a non-CPU device type other than CUDA,
              PrivateUse1, MTIA, HPU or XPU.
      """
      if self.is_async or not self.use_device:
          return 0
      if self.device_type == DeviceType.CPU:
          own_total = self.device_time_total
          return own_total - sum(child.device_time_total for child in self.cpu_children)
      supported_devices = (
          DeviceType.CUDA,
          DeviceType.PrivateUse1,
          DeviceType.MTIA,
          DeviceType.HPU,
          DeviceType.XPU,
      )
      if self.device_type not in supported_devices:
          raise AssertionError(
              f"Expected device_type to be CUDA, PrivateUse1, MTIA, HPU or XPU, but got {self.device_type}"
          )
      return self.device_time_total

  @property
  @deprecated(
      "`self_cuda_time_total` is deprecated. Use `self_device_time_total` instead.",
      category=FutureWarning,
  )
  def self_cuda_time_total(self):  # To be deprecated
      """Deprecated alias of ``self_device_time_total`` (emits a FutureWarning)."""
      return self.self_device_time_total
  @property
  def key(self):
      """Grouping key used when aggregating events; for a FunctionEvent, its name."""
      event_name = self.name
      return event_name
  def __repr__(self):
      """Debug string with identity, timing, hierarchy, shape and memory fields."""
      dev = self.use_device
      dev_time = self.device_time_str
      dev_memory = self.device_memory_usage
      # ``name=`` intentionally appears twice, matching the historical output.
      pieces = [
          f"<FunctionEvent id={self.id} name={self.name} overload_name={self.overload_name} ",
          f"device_type={self.device_type} node_id={self.node_id} cpu_time={self.cpu_time_str} ",
          f"start_us={self.time_range.start} end_us={self.time_range.end} ",
          f"cpu_children={str([child.id for child in self.cpu_children])} {dev}_time={dev_time} ",
          f"name={self.name} thread={self.thread} input_shapes={str(self.input_shapes)} ",
          f"cpu_memory_usage={self.cpu_memory_usage} {dev}_memory_usage={dev_memory} ",
          f"is_async={self.is_async} is_remote={self.is_remote} seq_nr={self.sequence_nr} is_legacy={self.is_legacy}>",
      ]
      return "".join(pieces)
  class FunctionEventAvg(FormattedTimesMixin):
      """Averaged profiling statistics over multiple FunctionEvent objects.

      FunctionEventAvg aggregates statistics from multiple FunctionEvent objects
      with the same key (typically same operation name). This is useful for getting
      average performance metrics across multiple invocations of the same operation.

      This class is typically created by calling :meth:`EventList.key_averages()` on
      a profiler's event list.

      Attributes:
          key (str): Grouping key for the events (typically operation name).
          count (int): Total number of events aggregated.
          node_id (int): Node identifier for distributed profiling (-1 if not applicable).
          is_async (bool): Whether the operations are asynchronous.
          is_remote (bool): Whether the operations occurred on a remote node.
          use_device (str): Device type being profiled ("cuda", "xpu", etc.).
          cpu_time_total (int): Accumulated total CPU time in microseconds.
          device_time_total (int): Accumulated total device time in microseconds.
          self_cpu_time_total (int): Accumulated self CPU time (excluding children) in microseconds.
          self_device_time_total (int): Accumulated self device time (excluding children) in microseconds.
          input_shapes (List[List[int]]): Input tensor shapes (requires record_shapes=true).
          overload_name (str): Operator overload name (requires _ExperimentalConfig(capture_overload_names=True) set).
          stack (List[str]): Python stack trace where the operation was called (requires with_stack=true).
          scope (int): at::RecordScope identifier (0=forward, 1=backward, etc.).
          cpu_memory_usage (int): Accumulated CPU memory usage in bytes.
          device_memory_usage (int): Accumulated device memory usage in bytes.
          self_cpu_memory_usage (int): Accumulated self CPU memory usage in bytes.
          self_device_memory_usage (int): Accumulated self device memory usage in bytes.
          cpu_children (List[FunctionEvent]): CPU child events.
          cpu_parent (FunctionEvent): CPU parent event.
          device_type (DeviceType): Type of device (CPU, CUDA, XPU, PrivateUse1, etc.).
          is_legacy (bool): Whether from legacy profiler.
          flops (int): Total floating point operations.
          is_user_annotation (bool): Whether this is a user-annotated region.

      Properties:
          cpu_time (float): Average CPU time per invocation.
          device_time (float): Average device time per invocation.

      See Also:
          - :class:`EventList.key_averages`: Method that creates FunctionEventAvg objects
          - :class:`FunctionEvent`: Individual profiling event
          - :class:`EventList`: Container for profiling events
      """

      def __init__(self) -> None:
          self.key: Optional[str] = None
          self.count: int = 0
          self.node_id: int = 0
          self.is_async: bool = False
          self.is_remote: bool = False
          self.use_device: Optional[str] = None
          self.cpu_time_total: int = 0
          self.device_time_total: int = 0
          self.self_cpu_time_total: int = 0
          self.self_device_time_total: int = 0
          self.input_shapes: Optional[list[list[int]]] = None
          self.overload_name: Optional[str] = None
          self.stack: Optional[list] = None
          self.scope: Optional[int] = None
          self.cpu_memory_usage: int = 0
          self.device_memory_usage: int = 0
          self.self_cpu_memory_usage: int = 0
          self.self_device_memory_usage: int = 0
          self.cpu_children: Optional[list[FunctionEvent]] = None
          self.cpu_parent: Optional[FunctionEvent] = None
          self.device_type: DeviceType = DeviceType.CPU
          self.is_legacy: bool = False
          self.flops: int = 0
          # Overwritten by the first add(). Initialized here so an aggregate
          # that never received an event still exposes the attribute
          # (_build_table reads it).
          self.is_user_annotation: bool = False

      def add(self, other):
          """Fold ``other`` into this aggregate and return ``self``.

          While this aggregate has no key yet, the event being added supplies the
          descriptive fields (key, node id, device type, stack, shapes, ...).
          Every call adds ``other``'s times, memory usage, call count and flops
          to the running totals.

          Args:
              other: A FunctionEvent or another FunctionEventAvg with the same key.

          Returns:
              ``self``, so the method can back ``+=``.

          Raises:
              AssertionError: if ``other`` is neither a FunctionEvent nor a
                  FunctionEventAvg, or its key differs from this aggregate's key.
          """
          # Validate the type before reading any attribute of ``other`` so a bad
          # argument fails with this message rather than an AttributeError.
          if not isinstance(other, (FunctionEvent, FunctionEventAvg)):
              raise AssertionError(
                  "Expected other to be a FunctionEvent or FunctionEventAvg"
              )
          if self.key is None:
              # First function being recorded as part of FunctionEventAvg, propagate
              # fields.
              self.key = other.key
              self.node_id = other.node_id
              self.is_async = other.is_async
              self.is_remote = other.is_remote
              self.cpu_parent = other.cpu_parent
              self.cpu_children = other.cpu_children
              self.overload_name = other.overload_name
              self.input_shapes = other.input_shapes
              self.stack = other.stack
              self.scope = other.scope
              self.device_type = other.device_type
              self.is_legacy = other.is_legacy
              self.use_device = other.use_device
              self.is_user_annotation = other.is_user_annotation
          if other.key != self.key:
              raise AssertionError(
                  f"Expected keys to match, but got {other.key} vs {self.key}"
              )
          self.cpu_time_total += other.cpu_time_total
          self.device_time_total += other.device_time_total
          self.self_cpu_time_total += other.self_cpu_time_total
          self.self_device_time_total += other.self_device_time_total
          self.cpu_memory_usage += other.cpu_memory_usage
          self.device_memory_usage += other.device_memory_usage
          self.self_cpu_memory_usage += other.self_cpu_memory_usage
          self.self_device_memory_usage += other.self_device_memory_usage
          self.count += other.count
          if self.flops is None:
              # pyrefly: ignore [bad-assignment]
              self.flops = other.flops
          elif other.flops is not None:
              self.flops += other.flops
          return self

      def __iadd__(self, other):
          """In-place ``+=``; same as :meth:`add`."""
          return self.add(other)

      def __repr__(self):
          device_name = "cuda" if not self.use_device else self.use_device
          self_device_time = self.self_device_time_total_str
          device_time = self.device_time_str
          device_memory = self.device_memory_usage
          return (
              f"<FunctionEventAvg key={self.key} self_cpu_time={self.self_cpu_time_total_str} cpu_time={self.cpu_time_str} "
              f" self_{device_name}_time={self_device_time} {device_name}_time={device_time} input_shapes={str(self.input_shapes)} "
              f"cpu_memory_usage={self.cpu_memory_usage} {device_name}_memory_usage={device_memory}>"
          )
  class StringTable(defaultdict):
      """Memoizing map from raw (possibly mangled) names to demangled names.

      A missing key is demangled with ``torch._C._demangle`` on first lookup and
      the result is stored, so later lookups of the same key are plain dict hits.
      """

      def __missing__(self, key):
          # Keys of length <= 1 are kept verbatim: a single letter such as 't'
          # would otherwise demangle to a C++ type name ('unsigned short').
          # Checking only the length is a deliberately simple guard for now.
          value = key if len(key) <= 1 else torch._C._demangle(key)
          self[key] = value
          return value
  class MemRecordsAcc:
      """Acceleration structure for accessing mem_records in interval.

      Records are indexed by the start time of their first element
      (``record[0].start_ns()``). Interval queries then take two binary
      searches instead of a scan over every record.
      """

      def __init__(self, mem_records):
          self._mem_records = mem_records
          self._start_nses: list[int] = []
          self._indices: list[int] = []
          if len(mem_records) == 0:
              return
          keyed = sorted(
              (record[0].start_ns(), idx) for idx, record in enumerate(mem_records)
          )
          self._start_nses, self._indices = zip(*keyed)  # type: ignore[assignment]

      def in_interval(self, start_ns, end_ns):
          r"""
          Yield the records whose start time lies in ``[start_ns, end_ns]``
          (both ends inclusive), ordered by start time.
          """
          lo = bisect.bisect_left(self._start_nses, start_ns)
          hi = bisect.bisect_right(self._start_nses, end_ns)
          yield from (self._mem_records[pos] for pos in self._indices[lo:hi])
  876. def _filter_stack_entry(entry):
  877. filtered_entries = [
  878. ("autograd/__init__", "_make_grads"),
  879. ("autograd/__init__", "backward"),
  880. ("torch/tensor", "backward"),
  881. ("_internal/common_utils", "prof_callable"),
  882. ("_internal/common_utils", "prof_func_call"),
  883. ("_internal/common_utils", "prof_meth_call"),
  884. ]
  885. return all(not (f[0] in entry and f[1] in entry) for f in filtered_entries)
  886. MEMORY_EVENT_NAME = "[memory]"
  887. OUT_OF_MEMORY_EVENT_NAME = "[OutOfMemory]"
  888. def _filter_name(name):
  889. # ignoring the following utility ops
  890. filtered_out_names = [
  891. MEMORY_EVENT_NAME, # used only for the top-level memory events
  892. OUT_OF_MEMORY_EVENT_NAME,
  893. "profiler::_record_function_enter",
  894. "profiler::_record_function_enter_new",
  895. "profiler::_record_function_exit",
  896. "aten::is_leaf",
  897. "aten::output_nr",
  898. "aten::_version",
  899. ]
  900. return name in filtered_out_names
  def _rewrite_name(name, with_wildcard=False):
      """Demangle ``name`` and optionally collapse numbered profiler-step names.

      Args:
          name: Raw event name, possibly a mangled C++ symbol.
          with_wildcard: When True, any name starting with ``ProfilerStep#`` is
              replaced by ``ProfilerStep*`` so that all steps aggregate together
              in the profiler table output.

      Returns:
          The demangled (and possibly wildcarded) name.
      """
      demangled = StringTable()[name]
      if not with_wildcard:
          return demangled
      return "ProfilerStep*" if demangled.startswith("ProfilerStep#") else demangled
  def _build_table(
      events,
      sort_by=None,
      header=None,
      row_limit=100,
      max_src_column_width=75,
      max_name_column_width=55,
      max_shapes_column_width=80,
      with_flops=False,
      profile_memory=False,
      top_level_events_only=False,
      time_unit=None,
  ):
      """Print a summary of events (which can be a list of FunctionEvent or FunctionEventAvg).

      Args:
          events: Sequence of FunctionEvent / FunctionEventAvg objects.
          sort_by: Attribute to sort rows by, descending. ``cuda``, ``xpu`` and
              ``privateuse1`` in the name are mapped to ``device``.
          header: Optional title printed above the table.
          row_limit: Maximum number of rows printed.
          max_src_column_width, max_name_column_width, max_shapes_column_width:
              Upper bounds for the corresponding column widths (None = no bound).
          with_flops: Add a FLOPs column (dropped if no event has positive flops).
          profile_memory: Add CPU memory columns (plus device memory columns
              when a device is in use and reports memory).
          top_level_events_only: Skip events that have a CPU parent.
          time_unit: ``"s"``, ``"ms"`` or ``"us"`` to force a time unit; any
              other value keeps each event's preformatted time string.

      Returns:
          The rendered table as a string ("" for an empty ``events``).

      Raises:
          RuntimeError: if events carry device time but ``use_device`` is unset.
      """
      if len(events) == 0:
          return ""

      has_device_time = any(event.self_device_time_total > 0 for event in events)
      has_device_mem = any(event.self_device_memory_usage > 0 for event in events)
      use_device = events[0].use_device
      # Running on PrivateUse1 device with profiler but not enable
      # ProfilerActivity.PrivateUse1 can also catch privateuse1 memory usage.
      # Here only need to check has_privateuse1_time if not use_device.
      if not use_device and has_device_time:
          raise RuntimeError("use_device is None, but there is device performance data.")

      has_input_shapes = any(
          (event.input_shapes is not None and len(event.input_shapes) > 0)
          for event in events
      )
      has_overload_names = any(
          (event.overload_name is not None and len(event.overload_name) > 0)
          for event in events
      )

      if sort_by is not None:
          # Legacy "cuda"/"xpu"/"privateuse1" attribute names map to "device".
          sort_attr = (
              sort_by.replace("cuda", "device")
              .replace("xpu", "device")
              .replace("privateuse1", "device")
          )
          events = EventList(
              sorted(events, key=lambda evt: getattr(evt, sort_attr), reverse=True),
              use_device=use_device,
              profile_memory=profile_memory,
              with_flops=with_flops,
          )

      name_column_width = max(len(evt.key) for evt in events) + 4
      if max_name_column_width is not None:
          name_column_width = min(name_column_width, max_name_column_width)

      shapes_column_width = max(len(str(evt.input_shapes)) for evt in events) + 4
      if max_shapes_column_width is not None:
          shapes_column_width = min(shapes_column_width, max_shapes_column_width)

      DEFAULT_COLUMN_WIDTH = 12
      flops_column_width = DEFAULT_COLUMN_WIDTH

      src_column_width = None
      stacks = [
          evt.stack for evt in events if evt.stack is not None and len(evt.stack) > 0
      ]
      has_stack = len(stacks) > 0
      if has_stack:
          src_column_width = (
              max(max(len(entry) for entry in stack) for stack in stacks) + 4
          )
          if max_src_column_width is not None:
              src_column_width = min(src_column_width, max_src_column_width)

      headers = ["Name"]
      if has_overload_names:
          headers.append("Overload Name")
      headers += [
          "Self CPU %",
          "Self CPU",
          "CPU total %",
          "CPU total",
          "CPU time avg",
      ]
      device_name = use_device.upper() if use_device is not None else "None"
      if has_device_time:
          headers.extend(
              [
                  f"Self {device_name}",
                  f"Self {device_name} %",
                  f"{device_name} total",
                  f"{device_name} time avg",
              ]
          )
      if profile_memory:
          headers.extend(
              [
                  "CPU Mem",
                  "Self CPU Mem",
              ]
          )
          if use_device and has_device_mem:
              headers.extend(
                  [
                      f"{device_name} Mem",
                      f"Self {device_name} Mem",
                  ]
              )
      headers.append("# of Calls")
      # Only append Node ID if any event has a valid (>= 0) Node ID
      append_node_id = any(evt.node_id != -1 for evt in events)
      if append_node_id:
          headers.append("Node ID")

      SPACING_SIZE = 2
      row_format = ""
      header_sep = ""
      line_length = -SPACING_SIZE

      def add_column(padding, text_dir=">"):
          # Grow the row format string, the "----" separator and the total line
          # length by one column of width ``padding``.
          nonlocal row_format, header_sep, line_length
          row_format += "{: " + text_dir + str(padding) + "}" + (" " * SPACING_SIZE)
          header_sep += "-" * padding + (" " * SPACING_SIZE)
          line_length += padding + SPACING_SIZE

      def auto_scale_flops(flops):
          # Pick a FLOPs/KFLOPs/.../PFLOPs unit from the magnitude of ``flops``
          # and return (scale factor, unit label).
          flop_headers = [
              "FLOPs",
              "KFLOPs",
              "MFLOPs",
              "GFLOPs",
              "TFLOPs",
              "PFLOPs",
          ]
          if flops <= 0:
              raise AssertionError(f"Expected flops to be positive, but got {flops}")
          # pyrefly: ignore [no-matching-overload]
          log_flops = max(0, min(math.log10(flops) / 3, float(len(flop_headers) - 1)))
          if not (log_flops >= 0 and log_flops < len(flop_headers)):
              raise AssertionError(
                  f"Expected log_flops to be in range [0, {len(flop_headers)}), but got {log_flops}"
              )
          return (pow(10, (math.floor(log_flops) * -3.0)), flop_headers[int(log_flops)])

      add_column(name_column_width)
      if has_overload_names:
          add_column(name_column_width)
      for _ in headers[1 + has_overload_names :]:
          add_column(DEFAULT_COLUMN_WIDTH)

      if has_input_shapes:
          headers.append("Input Shapes")
          add_column(shapes_column_width)

      # The FLOPs column must precede "Source Location". Each row emits its
      # values in this order, and the multi-line stack printing below assumes
      # the source location is the last column.
      if with_flops:
          # Auto-scaling of flops header; an event may report ``None`` flops.
          raw_flops = [evt.flops for evt in events if (evt.flops or 0) > 0]
          if len(raw_flops) != 0:
              (flops_scale, flops_header) = auto_scale_flops(min(raw_flops))
              headers.append(f"Total {flops_header}")
              add_column(flops_column_width)
          else:
              with_flops = False  # can't find any valid flops

      if has_stack:
          headers.append("Source Location")
          add_column(src_column_width, text_dir="<")

      result = []

      def append(s):
          result.append(s)
          result.append("\n")  # Yes, newline after the end as well

      sum_self_cpu_time_total = 0
      sum_self_device_time_total = 0
      for evt in events:
          sum_self_cpu_time_total += evt.self_cpu_time_total
          if evt.device_type == DeviceType.CPU and evt.is_legacy:
              # in legacy profiler, kernel info is stored in cpu events
              sum_self_device_time_total += evt.self_device_time_total
          elif (
              evt.device_type
              in [
                  DeviceType.CUDA,
                  DeviceType.PrivateUse1,
                  DeviceType.MTIA,
                  DeviceType.HPU,
                  DeviceType.XPU,
              ]
              and not evt.is_user_annotation
          ):
              # in kineto profiler, there're events with the correct device type
              # (e.g. CUDA). HPU is listed because FunctionEvent.device_time_total
              # accepts it too.
              sum_self_device_time_total += evt.self_device_time_total

      # Actual printing
      if header is not None:
          append("=" * line_length)
          append(header)
      if top_level_events_only:
          append("=" * line_length)
          append("This report only display top-level ops statistics")
      append(header_sep)
      append(row_format.format(*headers))
      append(header_sep)

      def trim_path(path, src_column_width):
          # Keep the tail of an over-long path and mark the cut with "...".
          if len(path) > src_column_width:
              offset = len(path) - src_column_width
              path = path[offset:]
              if len(path) > 3:
                  path = "..." + path[3:]
          return path

      def override_time_unit(time_us, default_str, time_unit):
          US_IN_SECOND = 1000.0 * 1000.0
          US_IN_MS = 1000.0
          if time_unit == "s":
              return f"{time_us / US_IN_SECOND:.3f}s"
          elif time_unit == "ms":
              return f"{time_us / US_IN_MS:.3f}ms"
          elif time_unit == "us":
              return f"{time_us:.3f}us"
          else:
              return default_str

      def truncate_name(text):
          # Shorten text longer than the name-column limit to exactly
          # ``max_name_column_width`` characters ending in "..."; text that
          # already fits is returned unchanged.
          if max_name_column_width is not None and len(text) > max_name_column_width:
              return text[: (max_name_column_width - 3)] + "..."
          return text

      event_limit = 0
      for evt in events:
          if event_limit == row_limit:
              break
          if top_level_events_only and evt.cpu_parent is not None:
              continue
          event_limit += 1
          name = truncate_name(evt.key)

          evt.self_cpu_percent = _format_time_share(
              evt.self_cpu_time_total, sum_self_cpu_time_total
          )
          evt.total_cpu_percent = (
              _format_time_share(evt.cpu_time_total, sum_self_cpu_time_total)
              if not evt.is_async
              else 0
          )

          row_values = [name]
          if has_overload_names:
              # Events without an overload name get an empty cell.
              row_values.append(truncate_name(evt.overload_name or ""))
          row_values += [
              # Self CPU total %, 0 for async events.
              evt.self_cpu_percent,
              override_time_unit(
                  evt.self_cpu_time_total, evt.self_cpu_time_total_str, time_unit
              ),  # Self CPU total
              # CPU total %, 0 for async events.
              evt.total_cpu_percent,
              override_time_unit(
                  evt.cpu_time_total, evt.cpu_time_total_str, time_unit
              ),  # CPU total
              override_time_unit(
                  evt.cpu_time, evt.cpu_time_str, time_unit
              ),  # CPU time avg
          ]
          if has_device_time:
              evt.total_device_percent = _format_time_share(
                  evt.self_device_time_total, sum_self_device_time_total
              )
              row_values.extend(
                  [
                      override_time_unit(
                          evt.self_device_time_total,
                          evt.self_device_time_total_str,
                          time_unit,
                      ),
                      # device time total %
                      evt.total_device_percent,
                      override_time_unit(
                          evt.device_time_total, evt.device_time_total_str, time_unit
                      ),
                      override_time_unit(
                          evt.device_time, evt.device_time_str, time_unit
                      ),  # device time avg
                  ]
              )
          if profile_memory:
              row_values.extend(
                  [
                      # CPU Mem Total
                      _format_memory(evt.cpu_memory_usage),
                      # Self CPU Mem Total
                      _format_memory(evt.self_cpu_memory_usage),
                  ]
              )
              if use_device and has_device_mem:
                  row_values.extend(
                      [
                          # Device Mem Total
                          _format_memory(evt.device_memory_usage),
                          # Self Device Mem Total
                          _format_memory(evt.self_device_memory_usage),
                      ]
                  )
          row_values.append(evt.count)  # Number of calls
          if append_node_id:
              row_values.append(evt.node_id)
          if has_input_shapes:
              row_values.append(str(evt.input_shapes)[:shapes_column_width])
          if with_flops:
              evt_flops = evt.flops or 0
              if evt_flops <= 0:
                  row_values.append("--")
              else:
                  row_values.append(f"{evt_flops * flops_scale:8.3f}")  # type: ignore[possibly-undefined]
          # Individual events may have no stack even when others do.
          evt_stack = (evt.stack or []) if has_stack else []
          if has_stack:
              src_field = trim_path(evt_stack[0], src_column_width) if evt_stack else ""
              row_values.append(src_field)
          append(row_format.format(*row_values))

          if has_stack:
              # Remaining frames go on continuation rows in the last column.
              empty_headers = [""] * (len(headers) - 1)
              for entry in evt_stack[1:]:
                  append(
                      row_format.format(
                          *(empty_headers + [trim_path(entry, src_column_width)])
                      )
                  )
              empty_headers.append("")
              append(row_format.format(*empty_headers))

      append(header_sep)
      append(
          f"Self CPU time total: {override_time_unit(sum_self_cpu_time_total, _format_time(sum_self_cpu_time_total), time_unit)}"
      )
      if has_device_time:
          append(
              f"Self {use_device.upper() if use_device is not None else 'None'} "
              f"time total: {override_time_unit(sum_self_device_time_total, _format_time(sum_self_device_time_total), time_unit)}"
          )
      return "".join(result)
  1243. # Collect all events with stack traces and format them canonically
  1244. def _canonicalize_profiler_events(events):
  1245. """
  1246. Extract and format all events with stack traces in a canonical way
  1247. for deterministic testing.
  1248. """
  1249. events_with_traces = []
  1250. for event in events:
  1251. # Extract relevant fields
  1252. event_name = event.get("name", "")
  1253. node_name = event["args"].get("node_name", "")
  1254. stack_trace = event["args"].get("stack_trace", "")
  1255. # Get the last non-empty line of the stack trace
  1256. lines = [s.strip() for s in stack_trace.split("\n") if s.strip()]
  1257. stack_trace = lines[-1] if lines else ""
  1258. events_with_traces.append(
  1259. {
  1260. "event_name": event_name[:30],
  1261. "node_name": node_name,
  1262. "stack_trace": stack_trace,
  1263. "start_time": event.get("ts", 0),
  1264. }
  1265. )
  1266. # Sort by node_name for deterministic ordering
  1267. events_with_traces.sort(key=lambda x: x["start_time"])
  1268. # Format as a string
  1269. lines: list[str] = []
  1270. for evt in events_with_traces:
  1271. lines.append(
  1272. f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}"
  1273. )
  1274. return "\n".join(lines)