_utils.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577
  1. # mypy: allow-untyped-defs
  2. import functools
  3. import operator
  4. import re
  5. from collections import deque
  6. from dataclasses import dataclass
  7. from typing import Any, Literal, TYPE_CHECKING
  8. from torch.autograd.profiler import profile
  9. from torch.profiler import DeviceType
  10. if TYPE_CHECKING:
  11. from torch.autograd import _KinetoEvent
  12. def _traverse(tree, next_fn, children_fn=lambda x: x.children, reverse: bool = False):
  13. order = reversed if reverse else lambda x: x
  14. remaining = deque(order(tree))
  15. while remaining:
  16. curr_event = next_fn(remaining)
  17. yield curr_event
  18. for child_event in order(children_fn(curr_event)):
  19. remaining.append(child_event)
  20. traverse_dfs = functools.partial(_traverse, next_fn=lambda x: x.pop(), reverse=True)
  21. traverse_bfs = functools.partial(
  22. _traverse, next_fn=lambda x: x.popleft(), reverse=False
  23. )
  24. @dataclass
  25. class EventMetrics:
  26. duration_time_ns: int = 0
  27. self_time_ns: int = 0
  28. idle_time_ns: int = 0
  29. queue_depth: int = 0
  30. @property
  31. def fraction_idle_time(self):
  32. if self.duration_time_ns == 0:
  33. return 0.0
  34. return self.idle_time_ns / self.duration_time_ns
@dataclass
class Interval:
    """A span between two timestamps, optionally tagged with a queue depth."""

    # Start of the span; other code in this file fills it with nanosecond timestamps.
    start: int
    # End of the span, in the same unit as ``start``.
    end: int
    # Queue depth associated with the span; 0 when not supplied.
    queue_depth: int = 0
  40. class EventKey:
  41. def __init__(self, event) -> None:
  42. self.event = event
  43. def __hash__(self):
  44. return hash(self.event.id)
  45. def __eq__(self, other):
  46. return self.event.id == other.event.id
  47. def __repr__(self) -> str:
  48. return f"{self.event.name}"
  49. def intervals_overlap(self, intervals: list[Interval]):
  50. overlap_time = 0
  51. intervals = sorted(intervals, key=lambda x: x.start)
  52. if intervals:
  53. overlap_start = max(self.event.start_time_ns, intervals[0].start)
  54. overlap_end = min(self.event.end_time_ns, intervals[0].end)
  55. if overlap_start < overlap_end:
  56. overlap_time += overlap_end - overlap_start
  57. i, j = 0, 1
  58. while j < len(intervals):
  59. prev_interval = intervals[i]
  60. curr_interval = intervals[j]
  61. j += 1
  62. if prev_interval.end > curr_interval.start:
  63. # Completely subsumed by previous interval
  64. if prev_interval.end > curr_interval.end:
  65. j += 1
  66. continue
  67. else:
  68. curr_interval.start = prev_interval.end
  69. i = j
  70. overlap_start = max(self.event.start_time_ns, curr_interval.start)
  71. overlap_end = min(self.event.end_time_ns, curr_interval.end)
  72. if overlap_start < overlap_end:
  73. overlap_time += overlap_end - overlap_start
  74. return overlap_time
  75. class BasicEvaluation:
  76. def __init__(self, prof: profile) -> None:
  77. self.profile = prof
  78. self.metrics: dict[EventKey, EventMetrics] = {}
  79. self.compute_self_time()
  80. self.event_keys = sorted(
  81. self.metrics.keys(), key=lambda x: x.event.start_time_ns
  82. )
  83. self.events = [e.event for e in self.event_keys]
  84. self.cuda_events: list[_KinetoEvent] = []
  85. self.queue_depth_list = self.compute_queue_depth()
  86. self.compute_idle_time()
  87. def compute_self_time(self) -> None:
  88. """
  89. Computes event's self time(total time - time in child ops).
  90. """
  91. if self.profile.kineto_results is None:
  92. raise AssertionError("kineto_results must not be None")
  93. stack = deque(self.profile.kineto_results.experimental_event_tree())
  94. # standard iterating dfs
  95. while stack:
  96. curr_event = stack.pop()
  97. self_time = curr_event.duration_time_ns
  98. for child_event in curr_event.children:
  99. self_time -= child_event.duration_time_ns
  100. stack.append(child_event)
  101. if EventKey(curr_event) in self.metrics:
  102. raise AssertionError(
  103. f"Duplicate id: {curr_event.id}, {curr_event.name}"
  104. )
  105. self.metrics[EventKey(curr_event)] = EventMetrics(self_time_ns=self_time)
  106. self.metrics[
  107. EventKey(curr_event)
  108. ].duration_time_ns = curr_event.duration_time_ns
  109. def compute_queue_depth(self):
  110. """
  111. Computes queue_depth at each event. This will calculate the queue depth data for
  112. All the events in the tree.
  113. This will return a list of Interval of queue depth data of cuda launch and kernels.
  114. """
  115. if self.profile.kineto_results is None:
  116. raise AssertionError("kineto_results must not be None")
  117. cuda_event_list = self.profile.kineto_results.events()
  118. def is_cuda_launch_kernel(e):
  119. """Check if the event is a CUDA launch kernel."""
  120. launch_patterns = {
  121. "cudaLaunchKernel", # Standard CUDA
  122. "cudaLaunchKernelExC", # Extended C
  123. "__cudaLaunchKernel", # Internal
  124. "cudaLaunchCooperativeKernel", # Collaborative (single-device)
  125. "cudaLaunchCooperativeKernelMultiDevice", # Collaborative (multi-devices)
  126. }
  127. name = str(getattr(e, "name", e))
  128. return any(name.startswith(pattern) for pattern in launch_patterns)
  129. def is_cuda_kernel(e):
  130. """Check if the event is a CUDA runtime kernel."""
  131. # Check if the kernel is CUDA
  132. if e.device_type() != DeviceType.CUDA:
  133. return False
  134. name = str(getattr(e, "name", e)).lower()
  135. # Exclude memory operations
  136. exclude_patterns = {"mem", "cpy", "alloc", "free"}
  137. return not any(pattern in name for pattern in exclude_patterns)
  138. cuda_launch_events = sorted(
  139. (e for e in cuda_event_list if is_cuda_launch_kernel(e)),
  140. key=lambda x: x.start_ns(),
  141. )
  142. cuda_kernel_events = sorted(
  143. (e for e in cuda_event_list if is_cuda_kernel(e)),
  144. key=lambda x: x.start_ns(),
  145. )
  146. self.cuda_events = sorted(
  147. cuda_launch_events + cuda_kernel_events, key=lambda x: x.start_ns()
  148. )
  149. kernel_mapping: dict[_KinetoEvent, int] = {}
  150. last_mapped_kernel = 0
  151. for cuda_launch_event in cuda_launch_events:
  152. index = index_of_first_match(
  153. cuda_kernel_events,
  154. lambda x: x.linked_correlation_id()
  155. == cuda_launch_event.linked_correlation_id(),
  156. start=last_mapped_kernel,
  157. )
  158. kernel_mapping[cuda_launch_event] = index
  159. last_mapped_kernel = index if index is not None else last_mapped_kernel
  160. current_kernel_index = 0
  161. spawned_kernel_index = -1
  162. all_events = cuda_launch_events + cuda_kernel_events + self.events
  163. def new_old_event_comparator(event):
  164. if hasattr(event, "start_us"):
  165. return event.start_us() * 1000
  166. if hasattr(event, "start_ns"):
  167. return event.start_ns()
  168. if hasattr(event, "start_time_ns"):
  169. return event.start_time_ns
  170. raise Exception("Unknown Event Type") # noqa: TRY002
  171. queue_depth_list: list[Interval] = []
  172. all_events.sort(key=new_old_event_comparator)
  173. for event in all_events:
  174. # Find latest cuda kernel event
  175. if hasattr(event, "start_us"):
  176. start_time = event.start_us() * 1000
  177. # pyrefly: ignore [missing-attribute]
  178. end_time = (event.start_us() + event.duration_us()) * 1000
  179. # Find current spawned cuda kernel event
  180. if event in kernel_mapping and kernel_mapping[event] is not None:
  181. spawned_kernel_index = kernel_mapping[event]
  182. if hasattr(event, "start_ns"):
  183. start_time = event.start_ns()
  184. end_time = event.start_ns() + event.duration_ns()
  185. # Find current spawned cuda kernel event
  186. if event in kernel_mapping and kernel_mapping[event] is not None:
  187. spawned_kernel_index = kernel_mapping[event]
  188. elif hasattr(event, "start_time_ns"):
  189. start_time = event.start_time_ns # type: ignore[attr-defined]
  190. end_time = event.end_time_ns # type: ignore[attr-defined]
  191. while (
  192. current_kernel_index < len(cuda_kernel_events)
  193. and (cuda_kernel_events[current_kernel_index].start_ns()) <= start_time # type: ignore[possibly-undefined]
  194. ):
  195. current_kernel_index += 1
  196. current_queue_depth = spawned_kernel_index - current_kernel_index + 1
  197. current_queue_depth = max(current_queue_depth, 0)
  198. if hasattr(event, "start_us") or hasattr(event, "start_ns"):
  199. queue_depth_list.append(
  200. Interval(start_time, end_time, current_queue_depth) # type: ignore[possibly-undefined]
  201. )
  202. elif hasattr(event, "start_time_ns"):
  203. self.metrics[EventKey(event)].queue_depth = current_queue_depth
  204. return queue_depth_list
  205. def compute_idle_time(self) -> None:
  206. """
  207. Computes idle time of the profile.
  208. """
  209. # Based on queue_depth_list, we can calculate idle time for all the events
  210. idle = False
  211. idle_start = 0
  212. idle_intervals: list[Interval] = []
  213. if self.queue_depth_list and self.events:
  214. idle_intervals += [
  215. Interval(self.events[0].start_time_ns, self.queue_depth_list[0].start),
  216. Interval(self.queue_depth_list[-1].end, self.events[-1].end_time_ns),
  217. ]
  218. for data_point in self.queue_depth_list:
  219. if data_point.queue_depth == 0 and not idle:
  220. idle_start = data_point.end
  221. idle = True
  222. if data_point.queue_depth > 0 and idle:
  223. idle_intervals.append(Interval(idle_start, data_point.start))
  224. idle = False
  225. event_list = [e.event for e in self.metrics]
  226. for event in event_list:
  227. self.metrics[EventKey(event)].idle_time_ns = EventKey(
  228. event
  229. ).intervals_overlap(idle_intervals)
  230. def rank_events(self, length):
  231. """
  232. Filter and Rank the events based on some heuristics:
  233. 1) Events that are in the falling phase of the queue depth.
  234. 2) Events that have a high idle_time, self_time difference.
  235. Parameters:
  236. length: The number of events to return.
  237. """
  238. # Find the interval when qd is falling to 0
  239. import torch
  240. queue_depth_list = list(reversed(self.queue_depth_list))
  241. qd_values = [e.queue_depth for e in queue_depth_list]
  242. bottom_threashold = 0
  243. top_threashold = 4
  244. decrease_interval = []
  245. i = 0
  246. while i < len(qd_values):
  247. if qd_values[i] > bottom_threashold:
  248. i += 1
  249. continue
  250. for j in range(i + 1, len(qd_values)):
  251. # Find next zero and if the max value between them exceeds
  252. # the threshold, then we have a falling interval
  253. next_minimum_idx = index_of_first_match(
  254. qd_values, lambda x: x <= bottom_threashold, start=j
  255. )
  256. peak_idx = argmax(qd_values, start=j, end=next_minimum_idx)
  257. # if is a valid peak, we add to list and continue
  258. if peak_idx is not None and qd_values[peak_idx] >= top_threashold:
  259. decrease_interval.append(
  260. Interval(
  261. queue_depth_list[peak_idx].start, queue_depth_list[i].start
  262. )
  263. )
  264. i = next_minimum_idx if next_minimum_idx is not None else i
  265. break
  266. i += 1
  267. # Filter out events that are not in the decrease interval
  268. event_list = [
  269. event
  270. for event in self.metrics
  271. if event.intervals_overlap(decrease_interval)
  272. ]
  273. if event_list:
  274. self_time = torch.tensor(
  275. [self.metrics[event].self_time_ns for event in event_list],
  276. dtype=torch.float32,
  277. )
  278. idle_time = torch.tensor(
  279. [self.metrics[event].fraction_idle_time for event in event_list],
  280. dtype=torch.float32,
  281. )
  282. normalized_gain = (idle_time - torch.mean(idle_time)) / torch.std(idle_time)
  283. normalized_self = (self_time - torch.mean(self_time)) / torch.std(self_time)
  284. heuristic_score_list = normalized_gain + 0.6 * normalized_self
  285. # Sort events by heuristic
  286. event_list = [
  287. event
  288. for _, event in sorted(
  289. zip(heuristic_score_list, event_list, strict=True),
  290. key=operator.itemgetter(0),
  291. reverse=True,
  292. )
  293. ]
  294. event_list = event_list[:length]
  295. return event_list
  296. def get_optimizable_events(self, length: int = 1, print_enable: bool = True):
  297. event_list = self.rank_events(length)
  298. if not print_enable:
  299. return event_list
  300. output = "Optimizable events:\n" if event_list else "No events to optimize\n"
  301. output += "\n".join(
  302. [
  303. f"""{"-" * 80}
  304. Event: {event}
  305. Source code location: {source_code_location(event.event)}
  306. Percentage idle time: {self.metrics[event].fraction_idle_time * 100:.2f}%
  307. {"-" * 80}"""
  308. for event in event_list
  309. ]
  310. )
  311. if print_enable:
  312. print(output)
  313. return event_list
  314. def index_of_first_match(seq, predicate, start=0, end=None):
  315. if end is None or end >= len(seq):
  316. end = len(seq)
  317. for i in range(start, end):
  318. if predicate(seq[i]):
  319. return i
  320. return None
  321. def argmax(seq, key=lambda x: x, start=0, end=None):
  322. seq = seq[start:end]
  323. if len(seq) == 0:
  324. return None
  325. return seq.index(max(seq, key=key)) + start
  326. def source_code_location(event):
  327. while event is not None:
  328. match = re.search(r"\.py\(.*\)", event.name)
  329. if match is None:
  330. event = event.parent
  331. continue
  332. return event.name
  333. return "No source code location found"
  334. # Provide an OSS workaround for cudagraphs + CUPTI issue
  335. # https://github.com/pytorch/pytorch/issues/75504
  336. # TODO(dberard) - deprecate / remove workaround for CUDA >= 12, when
  337. # we stop supporting older CUDA versions.
def _init_for_cuda_graphs() -> None:
    """Enter and immediately leave an empty autograd ``profile()`` context.

    Nothing is recorded or returned; the call exists only for the side
    effects of starting and stopping the profiler once (see the workaround
    note directly above this function).
    """
    # Imported locally rather than relying on the module-level import.
    from torch.autograd.profiler import profile

    with profile():
        pass
@dataclass
class TimelineEvent:
    """Represents an event in the profiler timeline."""

    # Position on the timeline; taken from the trace event's "ts" (plus "dur"
    # for the end of a marker).
    timestamp: int
    # "start"/"end" bracket an FX marker; "regular" is any other trace event.
    event_type: Literal["start", "end", "regular"]
    # Kind of FX marker; None for regular events.
    marker_type: Literal["filename", "node"] | None
    # File name (str) or node index (int) parsed from the marker; None for
    # regular events.
    identifier: str | int | None
    # The original trace-event dictionary this entry was built from.
    event: dict[str, Any]
@dataclass
class ContextStackEntry:
    """Represents a context (filename or node) in the stack."""

    # Which kind of FX marker opened this context.
    context_type: Literal["filename", "node"]
    # File name for "filename" contexts, node index for "node" contexts.
    identifier: str | int
    # Metadata looked up for this context; may be None when nothing was found.
    metadata: dict | None
    tid: int | None = None  # Thread ID associated with this context
  357. def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
  358. """
  359. Maps recorded profiler events to their corresponding fx nodes and adds stack traces.
  360. Builds a timeline of all events (regular ops and FX markers for filenames/nodes),
  361. sorts by timestamp, then processes chronologically while maintaining a context stack of active
  362. filename/node scopes. Regular events are augmented with stack traces and node names from the
  363. innermost active context. Runtime is O(n log n) for n events.
  364. Args:
  365. traced_data: Json of profiler events from Chrome trace
  366. Returns:
  367. Dict mapping recorded event names to their aten operations with added stack traces
  368. """
  369. from torch.fx.traceback import _FX_METADATA_REGISTRY
  370. trace_events = traced_data.get("traceEvents", [])
  371. # Create event timeline
  372. event_timeline: list[TimelineEvent] = []
  373. def is_fx_marker_event(event):
  374. return (
  375. event.get("cat") == "cpu_op"
  376. and event.get("name", "").startswith("## ")
  377. and event.get("name", "").endswith(" ##")
  378. )
  379. def append_fx_marker_event(event_type, identifier, event):
  380. start_ts = event["ts"]
  381. end_ts = start_ts + event["dur"]
  382. event_timeline.append(
  383. TimelineEvent(start_ts, "start", event_type, identifier, event)
  384. )
  385. event_timeline.append(
  386. TimelineEvent(end_ts, "end", event_type, identifier, event)
  387. )
  388. for event in trace_events:
  389. if "ts" not in event or "dur" not in event:
  390. continue
  391. if is_fx_marker_event(event):
  392. content = event["name"][3:-3]
  393. if content.endswith(".py"):
  394. append_fx_marker_event("filename", content, event)
  395. else:
  396. try:
  397. node_index = int(content)
  398. except ValueError:
  399. pass
  400. append_fx_marker_event("node", node_index, event) # type: ignore[possibly-undefined]
  401. else:
  402. # Regular event that needs augmentation
  403. start_ts = event["ts"]
  404. event_timeline.append(TimelineEvent(start_ts, "regular", None, None, event))
  405. # Sort by timestamp
  406. event_timeline.sort(key=lambda x: x.timestamp)
  407. # Process events in chronological order with a stack
  408. context_stack: list[ContextStackEntry] = []
  409. # Invariant: all start event has a corresponding end event
  410. for timeline_event in event_timeline:
  411. match timeline_event.event_type:
  412. case "start":
  413. if timeline_event.identifier is None:
  414. raise AssertionError("identifier must not be None for start event")
  415. if timeline_event.marker_type == "filename":
  416. if not isinstance(timeline_event.identifier, str):
  417. raise AssertionError(
  418. f"identifier must be str for filename marker, "
  419. f"got {type(timeline_event.identifier).__name__}"
  420. )
  421. # Push filename context - query metadata registry on-demand
  422. metadata = _FX_METADATA_REGISTRY.get(timeline_event.identifier)
  423. tid = timeline_event.event.get("tid")
  424. context_stack.append(
  425. ContextStackEntry(
  426. "filename", timeline_event.identifier, metadata, tid
  427. )
  428. )
  429. elif timeline_event.marker_type == "node":
  430. # Find the current filename from stack
  431. current_file_metadata = None
  432. tid = timeline_event.event.get("tid")
  433. for ctx_entry in reversed(context_stack):
  434. if (
  435. ctx_entry.context_type == "filename"
  436. and ctx_entry.tid == tid
  437. ):
  438. current_file_metadata = ctx_entry.metadata
  439. break
  440. if current_file_metadata:
  441. node_metadata = current_file_metadata.get("node_metadata", {})
  442. if timeline_event.identifier in node_metadata:
  443. node_meta: dict | None = node_metadata[
  444. timeline_event.identifier
  445. ]
  446. context_stack.append(
  447. ContextStackEntry(
  448. "node", timeline_event.identifier, node_meta, tid
  449. )
  450. )
  451. case "end":
  452. # Pop from stack - search backwards to find matching context
  453. for i in range(len(context_stack) - 1, -1, -1):
  454. ctx_entry = context_stack[i]
  455. if (
  456. timeline_event.marker_type == ctx_entry.context_type
  457. and timeline_event.identifier == ctx_entry.identifier
  458. ):
  459. context_stack.pop(i)
  460. break
  461. case "regular":
  462. # Apply metadata from current context stack
  463. # Find the most specific context (node takes precedence over filename)
  464. # Only augment events with the same tid as the file/node event matched
  465. current_stack_trace = None
  466. current_node_name = None
  467. event_tid = timeline_event.event.get("tid")
  468. for ctx_entry in reversed(context_stack):
  469. # Only apply metadata from contexts with matching tid
  470. if ctx_entry.tid == event_tid:
  471. if ctx_entry.context_type == "node" and ctx_entry.metadata:
  472. current_stack_trace = ctx_entry.metadata.get(
  473. "stack_trace", "No model stack trace available"
  474. )
  475. current_node_name = ctx_entry.metadata.get("name", "")
  476. # Do we want to only attach the stack trace of the lowest node or stack trace of all nodes
  477. # if nodes are nested, e.g. in nested graph modules
  478. break
  479. # Augment the event
  480. if current_stack_trace or current_node_name:
  481. args = timeline_event.event.setdefault("args", {})
  482. if current_stack_trace:
  483. args["stack_trace"] = current_stack_trace
  484. if current_node_name:
  485. args["node_name"] = current_node_name