profiler_legacy.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. # mypy: allow-untyped-defs
  2. import itertools
  3. import warnings
  4. from typing_extensions import deprecated
  5. import torch
  6. import torch.cuda
  7. from torch.autograd import (
  8. _disable_profiler_legacy,
  9. _enable_profiler_legacy,
  10. DeviceType,
  11. ProfilerConfig,
  12. ProfilerState,
  13. )
  14. from torch.autograd.profiler_util import (
  15. _filter_name,
  16. _filter_stack_entry,
  17. _rewrite_name,
  18. EventList,
  19. FunctionEvent,
  20. MEMORY_EVENT_NAME,
  21. )
  22. __all__ = ["profile"]
  23. @deprecated(
  24. "`torch.autograd.profiler_legacy.profile` is deprecated and will be removed in a future release. "
  25. "Please use `torch.profiler` instead.",
  26. category=None, # TODO: change to `FutureWarning`
  27. )
  28. class profile:
  29. """DEPRECATED: use torch.profiler instead."""
  30. def __init__(
  31. self,
  32. enabled=True,
  33. *,
  34. use_cuda=False,
  35. record_shapes=False,
  36. with_flops=False,
  37. profile_memory=False,
  38. with_stack=False,
  39. with_modules=False,
  40. ):
  41. self.enabled: bool = enabled
  42. if not self.enabled:
  43. return
  44. self.use_cuda = use_cuda
  45. self.function_events = None
  46. self.entered = False
  47. self.record_shapes = record_shapes
  48. self.with_flops = with_flops
  49. self.record_shapes |= self.with_flops
  50. self.profile_memory = profile_memory
  51. self.with_stack = with_stack
  52. self.with_modules = with_modules
  53. if self.use_cuda and not torch.cuda.is_available():
  54. warnings.warn(
  55. "CUDA is not available, disabling CUDA profiling",
  56. stacklevel=2,
  57. )
  58. self.use_cuda = False
  59. if self.use_cuda:
  60. self.profiler_kind = ProfilerState.CUDA
  61. else:
  62. self.profiler_kind = ProfilerState.CPU
  63. def config(self):
  64. return ProfilerConfig(
  65. self.profiler_kind,
  66. self.record_shapes,
  67. self.profile_memory,
  68. self.with_stack,
  69. self.with_flops,
  70. self.with_modules,
  71. # avoid exposing _ExperimentalConfig this in legacy public API
  72. torch._C._profiler._ExperimentalConfig(),
  73. )
  74. def __enter__(self):
  75. if not self.enabled:
  76. return
  77. if self.entered:
  78. raise RuntimeError("Profiler context manager is not reentrant")
  79. self.entered = True
  80. self._start_trace()
  81. return self
  82. def _start_trace(self):
  83. _enable_profiler_legacy(self.config())
  84. def __exit__(self, exc_type, exc_val, exc_tb):
  85. if not self.enabled:
  86. return
  87. if self.use_cuda:
  88. torch.cuda.synchronize()
  89. records = _disable_profiler_legacy()
  90. parsed_results = _parse_legacy_records(records)
  91. # pyrefly: ignore [bad-assignment]
  92. self.function_events = EventList(
  93. parsed_results,
  94. use_device="cuda" if self.use_cuda else None,
  95. profile_memory=self.profile_memory,
  96. with_flops=self.with_flops,
  97. )
  98. # pyrefly: ignore [missing-attribute]
  99. self.function_events._build_tree()
  100. return False
  101. def __repr__(self):
  102. if self.function_events is None:
  103. return "<unfinished profiler_legacy.profile>"
  104. return repr(self.function_events)
  105. def __str__(self):
  106. if self.function_events is None:
  107. return "<unfinished profile.profiler_legacy.profile>"
  108. return str(self.function_events)
  109. def _check_finish(self):
  110. if self.function_events is None:
  111. raise RuntimeError("Profiler didn't finish running")
  112. def table(
  113. self,
  114. sort_by=None,
  115. row_limit=100,
  116. max_src_column_width=75,
  117. max_name_column_width=55,
  118. max_shapes_column_width=80,
  119. header=None,
  120. top_level_events_only=False,
  121. ):
  122. self._check_finish()
  123. if self.function_events is None:
  124. raise AssertionError("Expected profiling results")
  125. return self.function_events.table(
  126. sort_by=sort_by,
  127. row_limit=row_limit,
  128. max_src_column_width=max_src_column_width,
  129. max_name_column_width=max_name_column_width,
  130. max_shapes_column_width=max_shapes_column_width,
  131. header=header,
  132. top_level_events_only=top_level_events_only,
  133. )
  134. table.__doc__ = EventList.table.__doc__
  135. def export_chrome_trace(self, path):
  136. self._check_finish()
  137. if self.function_events is None:
  138. raise AssertionError("Expected profiling results")
  139. return self.function_events.export_chrome_trace(path)
  140. export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__
  141. def export_stacks(self, path: str, metric: str = "self_cpu_time_total"):
  142. self._check_finish()
  143. if self.function_events is None:
  144. raise AssertionError("Expected profiling results")
  145. if not self.with_stack:
  146. raise AssertionError("export_stacks() requires with_stack=True")
  147. return self.function_events.export_stacks(path, metric)
  148. def key_averages(self, group_by_input_shape=False, group_by_stack_n=0):
  149. self._check_finish()
  150. if self.function_events is None:
  151. raise AssertionError("Expected profiling results")
  152. return self.function_events.key_averages(group_by_input_shape, group_by_stack_n)
  153. key_averages.__doc__ = EventList.key_averages.__doc__
  154. def total_average(self):
  155. self._check_finish()
  156. if self.function_events is None:
  157. raise AssertionError("Expected profiling results")
  158. return self.function_events.total_average()
  159. total_average.__doc__ = EventList.total_average.__doc__
  160. @property
  161. def self_cpu_time_total(self):
  162. """Return CPU time as the sum of self times across all events."""
  163. self._check_finish()
  164. if self.function_events is None:
  165. raise AssertionError("Expected profiling results")
  166. return self.function_events.self_cpu_time_total
  167. def _parse_legacy_records(thread_records):
  168. def _get_record_key(record):
  169. """Return a tuple for correlating start and end records in `_parse_legacy_records`."""
  170. return (record.handle(), record.node_id())
  171. start_record = None
  172. functions = []
  173. # '__start_profile' is not guaranteed to be first, so we must find it here
  174. for record in itertools.chain.from_iterable(thread_records):
  175. name = record.name()
  176. if start_record is None and name == "__start_profile":
  177. start_record = record
  178. if start_record is None or start_record.is_remote():
  179. raise AssertionError("Expected a valid local start_record")
  180. for thread_record_list in thread_records:
  181. # accumulated memory allocations per handle
  182. cpu_memory_allocs = {}
  183. cuda_memory_allocs = {}
  184. # ranges per handle
  185. range_starts = {}
  186. filtered_handles = set()
  187. prev_record = None
  188. for record in thread_record_list:
  189. record_key = _get_record_key(record)
  190. if _filter_name(record.name()) or record_key in filtered_handles:
  191. filtered_handles.add(record_key)
  192. continue
  193. if record.kind() == "push":
  194. # workaround to reduce double logging from operator
  195. # wrappers and redispatch
  196. if prev_record is not None:
  197. duplicate = (
  198. prev_record.name() == record.name()
  199. and prev_record.kind() == record.kind()
  200. and prev_record.node_id() == record.node_id()
  201. )
  202. if duplicate:
  203. filtered_handles.add(record_key)
  204. continue
  205. range_starts[record_key] = record
  206. cpu_memory_allocs[record_key] = 0
  207. cuda_memory_allocs[record_key] = 0
  208. elif record.kind() == "pop":
  209. if record_key not in range_starts:
  210. raise AssertionError(
  211. f"Expected record with key {record_key} to exist in range_starts. "
  212. "This means that the pop event did not have a corresponding push."
  213. )
  214. start = range_starts[record_key]
  215. cpu_memory_usage = cpu_memory_allocs[record_key]
  216. cuda_memory_usage = cuda_memory_allocs[record_key]
  217. is_async = start.is_async() or (start.thread_id() != record.thread_id())
  218. is_remote_event = record.is_remote()
  219. start_flops = start.flops()
  220. fe = FunctionEvent(
  221. id=record.handle(),
  222. node_id=record.node_id(),
  223. name=_rewrite_name(name=start.name(), with_wildcard=True),
  224. trace_name=_rewrite_name(name=start.name(), with_wildcard=False),
  225. thread=start.thread_id(),
  226. start_us=start_record.cpu_elapsed_us(start),
  227. end_us=start_record.cpu_elapsed_us(record),
  228. fwd_thread=start.fwd_thread_id(),
  229. input_shapes=start.shapes(),
  230. stack=[
  231. entry for entry in start.stack() if _filter_stack_entry(entry)
  232. ],
  233. scope=start.scope(),
  234. use_device="cuda" if start.has_cuda() else None,
  235. cpu_memory_usage=cpu_memory_usage,
  236. device_memory_usage=cuda_memory_usage,
  237. is_async=is_async,
  238. is_remote=is_remote_event,
  239. sequence_nr=start.sequence_nr(),
  240. device_type=DeviceType.CPU,
  241. is_legacy=True,
  242. flops=start_flops,
  243. )
  244. # note: async events have only cpu total time
  245. if not is_async and start.has_cuda():
  246. duration = start.cuda_elapsed_us(record)
  247. if duration > 0:
  248. fe.append_kernel(start.name(), start.device(), duration)
  249. functions.append(fe)
  250. del range_starts[record_key]
  251. del cpu_memory_allocs[record_key]
  252. del cuda_memory_allocs[record_key]
  253. elif record.kind() == "memory_alloc":
  254. num_open_handles_cpu = len(cpu_memory_allocs)
  255. num_open_handles_cuda = len(cuda_memory_allocs)
  256. if num_open_handles_cpu != num_open_handles_cuda:
  257. raise AssertionError(
  258. f"Expected CPU and CUDA memory allocation handles to match, "
  259. f"but got {num_open_handles_cpu} CPU and {num_open_handles_cuda} CUDA"
  260. )
  261. for handle in cpu_memory_allocs:
  262. cpu_memory_allocs[handle] += record.cpu_memory_usage()
  263. for handle in cuda_memory_allocs:
  264. cuda_memory_allocs[handle] += record.cuda_memory_usage()
  265. if num_open_handles_cpu == 0:
  266. # output event as a top-level memory event
  267. fe = FunctionEvent(
  268. id=0,
  269. name=MEMORY_EVENT_NAME,
  270. trace_name=None,
  271. thread=0,
  272. start_us=0,
  273. end_us=0,
  274. stack=[],
  275. cpu_memory_usage=record.cpu_memory_usage(),
  276. device_memory_usage=record.cuda_memory_usage(),
  277. is_legacy=True,
  278. )
  279. functions.append(fe)
  280. prev_record = record
  281. # Sort functions by start time then by end time ascending.
  282. # This ensures that--in the case of nested events which
  283. # have the same start time (which may happen due to the
  284. # granularity of the given clock tick)--we always show
  285. # the outermost nested call first. This adds stability
  286. # in how FunctionEvents appear
  287. functions.sort(key=lambda evt: [evt.time_range.start, -evt.time_range.end])
  288. return functions