_pattern_matcher.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681
  1. # mypy: allow-untyped-defs
  2. import json
  3. import math
  4. import os
  5. import re
  6. import torch
  7. import torch.utils.benchmark as benchmark
  8. from torch._C._profiler import (
  9. _EventType,
  10. _ExtraFields_PyCall,
  11. _ExtraFields_PyCCall,
  12. _ExtraFields_TorchOp,
  13. _ProfilerEvent,
  14. )
  15. from torch.profiler import profile
  16. from torch.profiler._utils import index_of_first_match, traverse_bfs, traverse_dfs
  17. class Pattern:
  18. """
  19. Base class for all patterns, subclass this class and implement match()
  20. to define custom patterns.
  21. In subclass, define description and skip property.
  22. """
  23. def __init__(self, prof: profile, should_benchmark: bool = False) -> None:
  24. self.prof = prof
  25. self.should_benchmark = should_benchmark
  26. self.name = "Please specify a name for pattern"
  27. self.description = "Please specify a description for pattern"
  28. self.url = ""
  29. if prof.profiler is None or prof.profiler.kineto_results is None:
  30. raise AssertionError("profiler and kineto_results must not be None")
  31. self.event_tree = prof.profiler.kineto_results.experimental_event_tree()
  32. self.tid_root: dict[int, list[_ProfilerEvent]] = {}
  33. for event in self.event_tree:
  34. self.tid_root.setdefault(event.start_tid, []).append(event)
  35. @property
  36. def skip(self) -> bool:
  37. return False
  38. def report(self, event: _ProfilerEvent):
  39. msg = (
  40. f"{self.description}\n[Source Code Location] {source_code_location(event)}"
  41. )
  42. return msg
  43. def eventTreeTraversal(self):
  44. """
  45. Traverse the event tree and yield all events.
  46. Override this method in subclass to customize the traversal.
  47. """
  48. yield from traverse_dfs(self.event_tree)
  49. def summary(self, events: list[_ProfilerEvent]):
  50. default_summary = f"{self.name}: {len(events)} events matched."
  51. if self.should_benchmark:
  52. # If benchmark summary is not empty, use it.
  53. return (
  54. self.benchmark_summary(events)
  55. if hasattr(self, "benchmark") # type: ignore[attr-defined]
  56. else default_summary
  57. )
  58. return default_summary
  59. def benchmark_summary(self, events: list[_ProfilerEvent]) -> str:
  60. def format_time(time_ns: int) -> str:
  61. unit_lst = ["ns", "us", "ms"]
  62. for unit in unit_lst:
  63. if time_ns < 1000:
  64. return f"{time_ns:.2f} {unit}"
  65. time_ns //= 1000
  66. return f"{time_ns:.2f} s"
  67. if not hasattr(self, "benchmark"):
  68. raise AssertionError("Please implement benchmark()")
  69. shapes_factor_map = self.benchmark(events) # type: ignore[attr-defined]
  70. original_time = sum(event.duration_time_ns for event in events)
  71. new_time = sum(
  72. shapes_factor_map[input_shapes(event)] * event.duration_time_ns
  73. for event in events
  74. )
  75. return (
  76. f"{self.name}: {len(events)} events matched. "
  77. f"Total Estimated Speedup: {format_time(original_time - new_time)} ({round(original_time / new_time, 2)}X)"
  78. )
  79. def match(self, event: _ProfilerEvent):
  80. """
  81. Return True if the event matches the pattern.
  82. This method should be overridden in subclass.
  83. """
  84. raise NotImplementedError
  85. def matched_events(self):
  86. if self.skip:
  87. return []
  88. matched_events = [
  89. event for event in self.eventTreeTraversal() if self.match(event)
  90. ]
  91. return matched_events
  92. def root_of(self, event: _ProfilerEvent):
  93. while event.parent:
  94. event = event.parent
  95. return event
  96. def siblings_of(self, event: _ProfilerEvent):
  97. if event.parent:
  98. children = event.parent.children
  99. else:
  100. children = self.tid_root[event.start_tid]
  101. index = children.index(event)
  102. return children[:index], children[index + 1 :]
  103. def next_of(self, event: _ProfilerEvent):
  104. _, next_events = self.siblings_of(event)
  105. return next_events[0] if next_events else None
  106. def prev_of(self, event: _ProfilerEvent):
  107. prev_events, _ = self.siblings_of(event)
  108. return prev_events[-1] if prev_events else None
  109. def go_up_until(self, event: _ProfilerEvent, predicate):
  110. if not event:
  111. return None
  112. while event.parent and not predicate(event):
  113. event = event.parent
  114. return event
  115. # Patterns
  116. class NamePattern(Pattern):
  117. def __init__(
  118. self, prof: profile, name: str, should_benchmark: bool = False
  119. ) -> None:
  120. super().__init__(prof, should_benchmark)
  121. self.description = f"Matched Name Event: {name}"
  122. self.name = name
  123. def match(self, event: _ProfilerEvent):
  124. return re.search(self.name, event.name) is not None
  125. class ExtraCUDACopyPattern(Pattern):
  126. """
  127. This pattern identifies if we creates a constant tensor on CPU and immediately moves it to GPU.
  128. example: torch.zeros((100, 100)).to("cuda")
  129. Pattern:
  130. built-in method |built-in method
  131. ... | aten::to
  132. aten::fill_/aten::zero_ | aten::_to_copy
  133. Algorithm:
  134. We start at node aten::to, go parent events' previous events,
  135. and check if we have a aten::fill_/aten::zero_ as we keep going down the tree.
  136. We always select the last child in the children list when we go down the tree.
  137. If at any step we failed, it is not a match.
  138. """
  139. def __init__(self, prof: profile, should_benchmark: bool = False) -> None:
  140. super().__init__(prof, should_benchmark)
  141. self.name = "Extra CUDA Copy Pattern"
  142. self.description = "Filled a CPU tensor and immediately moved it to GPU. Please initialize it on GPU."
  143. self.url = "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#create-tensors-directly-on-the-target-device"
  144. self.init_ops = {
  145. "aten::fill_",
  146. "aten::zero_",
  147. "aten::normal_",
  148. "aten::uniform_",
  149. }
  150. @property
  151. def skip(self) -> bool:
  152. return not self.prof.with_stack or not self.prof.record_shapes
  153. def match(self, event):
  154. # TODO: We should also check tensor identities
  155. if event.name != "aten::to":
  156. return False
  157. to_event = event
  158. if not event.children:
  159. return False
  160. event = event.children[-1]
  161. if event.name != "aten::_to_copy":
  162. return False
  163. if not event.children:
  164. return False
  165. event = event.children[-1]
  166. if event.name != "aten::copy_":
  167. return False
  168. # aten::copy_ should have the first 2 args dtype the same
  169. dtypes = input_dtypes(event)
  170. if len(dtypes) < 2:
  171. return False
  172. if dtypes[0] is None or dtypes[0] != dtypes[1]:
  173. return False
  174. event = to_event
  175. # Up one level
  176. event = event.parent
  177. if event is None:
  178. return False
  179. # Check if we have a aten::fill_ in previous leaf
  180. event = self.prev_of(event)
  181. if event is None:
  182. return False
  183. while event.children:
  184. event = event.children[-1]
  185. # aten::zero_ is a special optimization case where fill_ is not called
  186. if event.name in self.init_ops:
  187. return True
  188. return event.name in self.init_ops
  189. # TODO: Check if tensor is reused
  190. def benchmark(self, events: list[_ProfilerEvent]):
  191. shapes_factor_map = {input_shapes(event): 0.0 for event in events}
  192. for shape in shapes_factor_map:
  193. size = shape[0]
  194. to_timer = benchmark.Timer(
  195. stmt='torch.ones(size).to("cuda")', globals={"size": size}
  196. )
  197. de_timer = benchmark.Timer(
  198. stmt='torch.ones(size, device="cuda")', globals={"size": size}
  199. )
  200. to_time = to_timer.timeit(10).mean
  201. de_time = de_timer.timeit(10).mean
  202. shapes_factor_map[shape] = de_time / to_time
  203. return shapes_factor_map
  204. class ForLoopIndexingPattern(Pattern):
  205. """
  206. This pattern identifies if we use a for loop to index a tensor that
  207. can be vectorized.
  208. example:
  209. tensor = torch.empty((100, 100))
  210. for i in range(100):
  211. tensor[i] = i
  212. Pattern:
  213. aten::select | ... | aten::select | ... (Repeat)
  214. Algorithm:
  215. We start at node aten::select, and we check if we can find this alternating patterns.
  216. We also keep a dictionary to avoid duplicate match in the for loop.
  217. """
  218. def __init__(self, prof: profile, should_benchmark: bool = False) -> None:
  219. super().__init__(prof, should_benchmark)
  220. self.name = "For Loop Indexing Pattern"
  221. self.description = "For loop indexing detected. Vectorization recommended."
  222. self.visited: set[int] = set()
  223. def eventTreeTraversal(self):
  224. """
  225. We need to use BFS traversal order to avoid duplicate match.
  226. """
  227. yield from traverse_bfs(self.event_tree)
  228. def match(self, event: _ProfilerEvent):
  229. if event.name != "aten::select":
  230. return False
  231. if event.id in self.visited:
  232. return False
  233. repeat_count = 1
  234. _, next = self.siblings_of(event)
  235. if len(next) <= 1:
  236. return False
  237. # Custom event list matching
  238. def same_ops(list1, list2) -> bool:
  239. if len(list1) != len(list2):
  240. return False
  241. for op1, op2 in zip(list1, list2, strict=True):
  242. if op1.name != op2.name:
  243. return False
  244. return True
  245. # Record the ops between two aten::select
  246. next_select_idx = index_of_first_match(next, lambda e: e.name == "aten::select")
  247. if next_select_idx is None:
  248. return False
  249. indexing_ops = [event] + next[:next_select_idx]
  250. next = next[len(indexing_ops) - 1 :]
  251. for i in range(0, len(next), len(indexing_ops)):
  252. if same_ops(indexing_ops, next[i : i + len(indexing_ops)]):
  253. repeat_count += 1
  254. self.visited.add(next[i].id)
  255. else:
  256. break
  257. return repeat_count >= 10
  258. class FP32MatMulPattern(Pattern):
  259. def __init__(self, prof: profile, should_benchmark: bool = False) -> None:
  260. super().__init__(prof, should_benchmark)
  261. self.name = "FP32 MatMul Pattern"
  262. self.description = (
  263. "You are currently using GPU that supports TF32. "
  264. "Please enable TF32 by setting 'torch.backends.cuda.matmul.allow_tf32 = True'"
  265. )
  266. self.url = "https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
  267. @property
  268. def skip(self):
  269. if torch.version.hip is not None:
  270. has_tf32 = False
  271. else:
  272. # Anything less than sm_80 is not Ampere which doesn't support TF32
  273. has_tf32 = all(
  274. int(re.sub("sm_|compute_", "", arch)) >= 80
  275. for arch in torch.cuda.get_arch_list()
  276. )
  277. return has_tf32 is False or super().skip or not self.prof.record_shapes
  278. def match(self, event: _ProfilerEvent) -> bool:
  279. # If we saw this pattern once, we don't need to match it again
  280. if event.tag != _EventType.TorchOp:
  281. return False
  282. if not isinstance(event.extra_fields, _ExtraFields_TorchOp):
  283. raise AssertionError(
  284. f"expected _ExtraFields_TorchOp, got {type(event.extra_fields).__name__}"
  285. )
  286. if event.name == "aten::mm":
  287. if event.extra_fields.allow_tf32_cublas is False:
  288. return True
  289. return False
  290. def report(self, event: _ProfilerEvent):
  291. return self.description
  292. def benchmark(self, events: list[_ProfilerEvent]):
  293. shapes_factor_map = {input_shapes(event): 0.0 for event in events}
  294. for shape in shapes_factor_map:
  295. matrixA = torch.randn(shape[0], device="cuda", dtype=torch.float32)
  296. matrixB = torch.randn(shape[1], device="cuda", dtype=torch.float32)
  297. fp32_timer = benchmark.Timer(
  298. stmt="torch.mm(matrixA, matrixB)",
  299. globals={"matrixA": matrixA, "matrixB": matrixB},
  300. )
  301. tf32_timer = benchmark.Timer(
  302. stmt="torch.mm(matrixA, matrixB)",
  303. setup="torch.backends.cuda.matmul.allow_tf32 = True",
  304. globals={"matrixA": matrixA, "matrixB": matrixB},
  305. )
  306. torch.backends.cuda.matmul.allow_tf32 = False
  307. fp32_time = fp32_timer.timeit(10).mean
  308. tf32_time = tf32_timer.timeit(10).mean
  309. shapes_factor_map[shape] = tf32_time / fp32_time
  310. return shapes_factor_map
  311. class OptimizerSingleTensorPattern(Pattern):
  312. """
  313. This pattern identifies if we are using the single-tensor version of an optimizer.
  314. example:
  315. optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
  316. By adding foreach=True to enable multi-tensor optimizer, we can gain speedup when
  317. the kernels are relatively small.
  318. Pattern:
  319. XXXXX: _single_tenser_<OPTIMIZER_NAME>
  320. Algorithm:
  321. String match
  322. """
  323. def __init__(self, prof: profile, should_benchmark: bool = False) -> None:
  324. super().__init__(prof, should_benchmark)
  325. self.name = "Optimizer Single Tensor Pattern"
  326. self.optimizers_with_foreach = ["adam", "sgd", "adamw"]
  327. self.description = (
  328. "Detected optimizer running with single tensor implementation. "
  329. "Please enable multi tensor implementation by passing 'foreach=True' into optimizer."
  330. )
  331. self.url = ""
  332. def match(self, event: _ProfilerEvent) -> bool:
  333. for optimizer in self.optimizers_with_foreach:
  334. if event.name.endswith(f"_single_tensor_{optimizer}"):
  335. return True
  336. return False
  337. class SynchronizedDataLoaderPattern(Pattern):
  338. """
  339. This pattern identifies if we are using num_workers=0 in DataLoader.
  340. example:
  341. torch.utils.data.DataLoader(dataset, batch_size=batch_size)
  342. Add num_workers=N to the arguments. N depends on system configuration.
  343. Pattern:
  344. dataloader.py(...): __iter__
  345. dataloader.py(...): _get_iterator
  346. NOT dataloader.py(...): check_worker_number_rationality
  347. Algorithm:
  348. If we don't see check_worker_number_rationality call in the dataloader __iter__,
  349. It is not an asynchronous dataloader.
  350. """
  351. def __init__(self, prof: profile, should_benchmark: bool = False) -> None:
  352. super().__init__(prof, should_benchmark)
  353. self.name = "Synchronized DataLoader Pattern"
  354. self.description = (
  355. "Detected DataLoader running with synchronized implementation. "
  356. "Please enable asynchronous dataloading by setting num_workers > 0 when initializing DataLoader."
  357. )
  358. self.url = (
  359. "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html"
  360. "#enable-async-data-loading-and-augmentation"
  361. )
  362. def match(self, event: _ProfilerEvent) -> bool:
  363. def is_dataloader_function(name: str, function_name: str):
  364. return name.startswith(
  365. os.path.join("torch", "utils", "data", "dataloader.py")
  366. ) and name.endswith(function_name)
  367. # TODO: fixme! Due to lifetime issues of the function name, this field might
  368. # actually point to an already freed string when the even is a PyCall.
  369. # Just silently skip this to unblock testing.
  370. try:
  371. event.name
  372. except UnicodeDecodeError:
  373. return False
  374. if not is_dataloader_function(event.name, "__iter__"):
  375. return False
  376. if not event.children:
  377. return False
  378. event = event.children[0]
  379. if not is_dataloader_function(event.name, "_get_iterator"):
  380. return False
  381. if not event.children:
  382. return False
  383. event = event.children[0]
  384. return not is_dataloader_function(event.name, "check_worker_number_rationality")
  385. # TODO: We should also check if the loader is bottleneck.
  386. class GradNotSetToNonePattern(Pattern):
  387. """
  388. This pattern identifies if we are not setting grad to None in zero_grad.
  389. example:
  390. optimizer.zero_grad()
  391. By setting set_to_none=True, we can gain speedup
  392. Pattern:
  393. XXXXX: _zero_grad
  394. NOT aten::zeros
  395. aten::zero_
  396. aten::zero_ is called on each parameter in the model.
  397. We also want to make sure it is not called by aten::zeros.
  398. Algorithm:
  399. String match
  400. """
  401. def __init__(self, prof: profile, should_benchmark: bool = False) -> None:
  402. super().__init__(prof, should_benchmark)
  403. self.name = "Gradient Set To Zero Instead of None Pattern"
  404. self.description = (
  405. "Detected gradient set to zero instead of None. "
  406. "Please add 'set_to_none=True' when calling zero_grad()."
  407. )
  408. self.url = (
  409. "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html"
  410. "#disable-gradient-calculation-for-validation-or-inference"
  411. )
  412. def match(self, event: _ProfilerEvent) -> bool:
  413. if not event.name.endswith(": zero_grad"):
  414. return False
  415. if not event.children:
  416. return False
  417. for sub_event in traverse_dfs(event.children):
  418. if (
  419. sub_event.name == "aten::zero_"
  420. and sub_event.parent.name != "aten::zeros"
  421. ):
  422. return True
  423. # TODO: We should also check if the optimizer's numerical behavior will change.
  424. return False
  425. class Conv2dBiasFollowedByBatchNorm2dPattern(Pattern):
  426. """
  427. This pattern identifies if we are enabling bias in Conv2d which is followed by BatchNorm2d.
  428. Bias doesn't do anything when followed by batchnorm.
  429. Pattern:
  430. nn.Module: Conv2d | nn.Module: BatchNorm2d
  431. ...
  432. aten::conv2d AND dtype of third argument is not null
  433. The third argument is the bias
  434. Algorithm:
  435. String match
  436. """
  437. def __init__(self, prof: profile, should_benchmark: bool = False) -> None:
  438. super().__init__(prof, should_benchmark)
  439. self.name = "Enabling Bias in Conv2d Followed By BatchNorm Pattern"
  440. self.description = "Detected bias enabled in Conv2d that is followed by BatchNorm2d. Please set 'bias=False' in Conv2d."
  441. self.url = (
  442. "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html"
  443. "#disable-bias-for-convolutions-directly-followed-by-a-batch-norm"
  444. )
  445. @property
  446. def skip(self):
  447. return self.prof.record_shapes is False or super().skip
  448. def match(self, event: _ProfilerEvent):
  449. if event.name != "aten::conv2d":
  450. return False
  451. if len(input_dtypes(event)) < 3 or input_dtypes(event)[2] is None:
  452. return False
  453. # This means bias=True
  454. event = self.go_up_until(
  455. event, lambda e: e.name.startswith("nn.Module: Conv2d")
  456. )
  457. if not event:
  458. return False
  459. event = self.next_of(event)
  460. if not event:
  461. return False
  462. return event.name.startswith("nn.Module: BatchNorm2d")
  463. class MatMulDimInFP16Pattern(Pattern):
  464. def __init__(self, prof: profile, should_benchmark: bool = False) -> None:
  465. super().__init__(prof, should_benchmark)
  466. self.name = "Matrix Multiplication Dimension Not Aligned Pattern"
  467. self.description = "Detected matmul with dimension not aligned. Please use matmul with aligned dimension."
  468. self.url = "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#use-mixed-precision-and-amp"
  469. @property
  470. def skip(self) -> bool:
  471. return not self.prof.with_stack or not self.prof.record_shapes
  472. def match(self, event: _ProfilerEvent) -> bool:
  473. def mutiple_of(shapes, multiple):
  474. return all(dim % multiple == 0 for shape in shapes for dim in shape[-2:])
  475. if event.name not in ("aten::mm", "aten::bmm", "aten::addmm"):
  476. return False
  477. if not input_dtypes(event):
  478. return False
  479. arg_dtype = input_dtypes(event)[0]
  480. if arg_dtype in (torch.bfloat16, torch.half) and not mutiple_of(
  481. input_shapes(event), 8
  482. ):
  483. return True
  484. return False
  485. def benchmark(self, events: list[_ProfilerEvent]):
  486. def closest_multiple(shapes, multiple):
  487. return [multiple * math.ceil(shape / multiple) for shape in shapes]
  488. shapes_factor_map = {input_shapes(event): 0.0 for event in events}
  489. for shape in shapes_factor_map:
  490. matrixA = torch.randn(shape[0], device="cuda", dtype=torch.float16)
  491. matrixB = torch.randn(shape[1], device="cuda", dtype=torch.float16)
  492. not_aligned_dim_timer = benchmark.Timer(
  493. stmt="torch.mm(matrixA, matrixB)",
  494. globals={"matrixA": matrixA, "matrixB": matrixB},
  495. )
  496. matrixA = torch.randn(
  497. closest_multiple(shape[0], 8), device="cuda", dtype=torch.float16
  498. )
  499. matrixB = torch.randn(
  500. closest_multiple(shape[1], 8), device="cuda", dtype=torch.float16
  501. )
  502. aligned_dim_timer = benchmark.Timer(
  503. stmt="torch.mm(matrixA, matrixB)",
  504. globals={"matrixA": matrixA, "matrixB": matrixB},
  505. )
  506. not_aligned_dim_time = not_aligned_dim_timer.timeit(10).mean
  507. aligned_dim_time = aligned_dim_timer.timeit(10).mean
  508. shapes_factor_map[shape] = aligned_dim_time / not_aligned_dim_time
  509. return shapes_factor_map
  510. def source_code_location(event: _ProfilerEvent | None) -> str:
  511. while event:
  512. if event.tag == _EventType.PyCall or event.tag == _EventType.PyCCall:
  513. if not isinstance(
  514. event.extra_fields, (_ExtraFields_PyCall, _ExtraFields_PyCCall)
  515. ):
  516. raise AssertionError(
  517. f"expected _ExtraFields_PyCall or _ExtraFields_PyCCall, "
  518. f"got {type(event.extra_fields).__name__}"
  519. )
  520. if not event.extra_fields.caller.file_name.startswith("torch" + os.sep):
  521. return f"{event.extra_fields.caller.file_name}:{event.extra_fields.caller.line_number}"
  522. event = event.parent
  523. return "No source code location found"
  524. def input_shapes(event: _ProfilerEvent):
  525. if not isinstance(event.extra_fields, _ExtraFields_TorchOp):
  526. raise AssertionError(
  527. f"expected _ExtraFields_TorchOp, got {type(event.extra_fields).__name__}"
  528. )
  529. return tuple(tuple(getattr(i, "sizes", ())) for i in event.extra_fields.inputs)
  530. def input_dtypes(event: _ProfilerEvent):
  531. if not isinstance(event.extra_fields, _ExtraFields_TorchOp):
  532. raise AssertionError(
  533. f"expected _ExtraFields_TorchOp, got {type(event.extra_fields).__name__}"
  534. )
  535. return tuple(getattr(i, "dtype", None) for i in event.extra_fields.inputs)
  536. def report_all_anti_patterns(
  537. prof,
  538. should_benchmark: bool = False,
  539. print_enable: bool = True,
  540. json_report_dir: str | None = None,
  541. ) -> None:
  542. report_dict: dict = {}
  543. anti_patterns = [
  544. ExtraCUDACopyPattern(prof, should_benchmark),
  545. # ForLoopIndexingPattern(prof, should_benchmark),
  546. FP32MatMulPattern(prof, should_benchmark),
  547. OptimizerSingleTensorPattern(prof, should_benchmark),
  548. SynchronizedDataLoaderPattern(prof, should_benchmark),
  549. GradNotSetToNonePattern(prof, should_benchmark),
  550. Conv2dBiasFollowedByBatchNorm2dPattern(prof, should_benchmark),
  551. MatMulDimInFP16Pattern(prof, should_benchmark),
  552. ]
  553. reported = set()
  554. summaries = []
  555. message_list = [f"{'-' * 40}TorchTidy Report{'-' * 40}"]
  556. message_list.append("Matched Events:")
  557. for anti_pattern in anti_patterns:
  558. matched_events = anti_pattern.matched_events()
  559. if not matched_events:
  560. continue
  561. summaries.append(anti_pattern.summary(matched_events))
  562. for event in matched_events:
  563. report_msg = anti_pattern.report(event)
  564. if report_msg not in reported:
  565. message_list.append(report_msg)
  566. reported.add(report_msg)
  567. src_location, line_no = source_code_location(event).split(":")
  568. report_dict.setdefault(src_location, []).append(
  569. {
  570. "line_number": int(line_no),
  571. "name": anti_pattern.name,
  572. "url": anti_pattern.url,
  573. "message": anti_pattern.description,
  574. }
  575. )
  576. if json_report_dir is not None:
  577. json_report_path = os.path.join(json_report_dir, "torchtidy_report.json")
  578. if os.path.exists(json_report_path):
  579. with open(json_report_path) as f:
  580. exisiting_report = json.load(f)
  581. exisiting_report.update(report_dict)
  582. report_dict = exisiting_report
  583. with open(json_report_path, "w") as f:
  584. json.dump(report_dict, f, indent=4)
  585. message_list.append("Summary:")
  586. message_list += summaries
  587. message_list.append(f"{'-' * 40}TorchTidy Report{'-' * 40}")
  588. if print_enable:
  589. print("\n".join(message_list))