metrics.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488
  1. from __future__ import annotations
  2. import csv
  3. import dataclasses
  4. import inspect
  5. import os
  6. import re
  7. from dataclasses import dataclass
  8. from functools import lru_cache
  9. from typing import Optional, TYPE_CHECKING, Union
  10. from torch._inductor import config
  11. from torch._inductor.utils import get_benchmark_name
  12. from torch.utils._ordered_set import OrderedSet
  13. # Prevent circular import
  14. if TYPE_CHECKING:
  15. from collections.abc import Callable
  16. from torch._inductor.runtime.triton_compat import Config
  17. from torch._inductor.scheduler import BaseSchedulerNode
  18. # counter for tracking how many kernels have been generated
  19. generated_kernel_count = 0
  20. generated_cpp_vec_kernel_count = 0
  21. num_bytes_accessed = 0
  22. nodes_num_elem: list[
  23. tuple[
  24. BaseSchedulerNode,
  25. int,
  26. ]
  27. ] = []
  28. node_runtimes: list[tuple[BaseSchedulerNode, float]] = []
  29. # counters for tracking fusions
  30. ir_nodes_pre_fusion = 0
  31. # counters for tracking to_dtype inserted
  32. cpp_to_dtype_count = 0
  33. @dataclasses.dataclass
  34. class CppOuterLoopFusedCount:
  35. inner_kernel_number: int
  36. local_buffer_number: int = 0
  37. # The length counts the number of outer loop fusions.
  38. cpp_outer_loop_fused_inner_counts: list[CppOuterLoopFusedCount] = []
  39. num_comprehensive_padding = 0
  40. num_matches_for_scatter_upon_const_tensor = 0
  41. num_loop_reordering = 0
  42. num_auto_chunking: int = 0
  43. # counter for parallel reduction.
  44. parallel_reduction_count = 0
  45. codegen_mix_order_reduction = 0
  46. # reset all counters
  47. def reset() -> None:
  48. global generated_kernel_count
  49. global generated_cpp_vec_kernel_count
  50. global num_bytes_accessed, nodes_num_elem
  51. global ir_nodes_pre_fusion
  52. global cpp_to_dtype_count
  53. global cpp_outer_loop_fused_inner_counts
  54. global num_comprehensive_padding
  55. global num_matches_for_scatter_upon_const_tensor
  56. global num_loop_reordering
  57. global parallel_reduction_count
  58. global codegen_mix_order_reduction
  59. global num_auto_chunking
  60. generated_kernel_count = 0
  61. generated_cpp_vec_kernel_count = 0
  62. num_bytes_accessed = 0
  63. nodes_num_elem.clear()
  64. node_runtimes.clear()
  65. ir_nodes_pre_fusion = 0
  66. cpp_to_dtype_count = 0
  67. cpp_outer_loop_fused_inner_counts.clear()
  68. num_comprehensive_padding = 0
  69. num_matches_for_scatter_upon_const_tensor = 0
  70. num_loop_reordering = 0
  71. parallel_reduction_count = 0
  72. codegen_mix_order_reduction = 0
  73. num_auto_chunking = 0
  74. @dataclass
  75. class CachedMetricsDeltas:
  76. """
  77. The subset of metrics we want update across cache hits, e.g., the
  78. FxGraphCache.
  79. """
  80. generated_kernel_count: int
  81. generated_cpp_vec_kernel_count: int
  82. ir_nodes_pre_fusion: int
  83. cpp_to_dtype_count: int
  84. num_bytes_accessed: int
  85. num_matches_for_scatter_upon_const_tensor: int
  86. def get_metric_fields() -> list[str]:
  87. return [field.name for field in dataclasses.fields(CachedMetricsDeltas)]
  88. class CachedMetricsHelper:
  89. """
  90. A helper class to help calculate and apply counter deltas for those
  91. metrics we want to save with cache entries (e.g., FxGraphCache) and
  92. apply on a cache hit.
  93. """
  94. def __init__(self) -> None:
  95. self.cached_metrics = {}
  96. for metric in get_metric_fields():
  97. self.cached_metrics[metric] = globals()[metric]
  98. def get_deltas(self) -> CachedMetricsDeltas:
  99. delta_metrics = {}
  100. for metric in get_metric_fields():
  101. delta_metrics[metric] = globals()[metric] - self.cached_metrics[metric]
  102. return CachedMetricsDeltas(**delta_metrics)
  103. @staticmethod
  104. def apply_deltas(delta: CachedMetricsDeltas) -> None:
  105. for metric in get_metric_fields():
  106. globals()[metric] += getattr(delta, metric)
  107. REGISTERED_METRIC_TABLES: dict[str, MetricTable] = {}
  108. @dataclass
  109. class MetricTable:
  110. table_name: str
  111. column_names: list[str]
  112. num_rows_added: int = 0
  113. def add_row(
  114. self, row_fn: Callable[[], dict[str, Optional[Union[str, float]]]]
  115. ) -> None:
  116. if self.table_name not in enabled_metric_tables():
  117. return
  118. row_dict = row_fn()
  119. assert len(self.column_names) == len(row_dict), (
  120. f"{len(self.column_names)} v.s. {len(row_dict)}"
  121. )
  122. assert OrderedSet(self.column_names) == OrderedSet(row_dict.keys()), (
  123. f"{OrderedSet(self.column_names)} v.s. {OrderedSet(row_dict.keys())}"
  124. )
  125. bn = get_benchmark_name()
  126. # assert bn is not None
  127. row = [bn] + [row_dict[column_name] for column_name in self.column_names]
  128. assert all(isinstance(i, (str, float, type(None))) for i in row)
  129. self._write_row(row)
  130. def output_filename(self) -> str:
  131. return f"metric_table_{self.table_name}.csv"
  132. def write_header(self) -> None:
  133. filename = self.output_filename()
  134. with open(filename, "w") as fd:
  135. writer = csv.writer(fd, lineterminator="\n")
  136. writer.writerow(["model_name"] + self.column_names)
  137. def _write_row(self, row: list[str | float | None]) -> None:
  138. filename = self.output_filename()
  139. if self.num_rows_added == 0 and not os.path.exists(filename):
  140. self.write_header()
  141. self.num_rows_added += 1
  142. for idx, orig_val in enumerate(row):
  143. if isinstance(orig_val, float):
  144. new_val = f"{orig_val:.6f}"
  145. elif orig_val is None:
  146. new_val = ""
  147. else:
  148. new_val = orig_val
  149. row[idx] = new_val
  150. with open(filename, "a") as fd:
  151. writer = csv.writer(fd, lineterminator="\n")
  152. writer.writerow(row)
  153. @staticmethod
  154. def register_table(name: str, column_names: list[str]) -> None:
  155. table = MetricTable(name, column_names)
  156. REGISTERED_METRIC_TABLES[name] = table
  157. MetricTable.register_table(
  158. "slow_fusion",
  159. [
  160. "kernel1_path",
  161. "kernel1_latency",
  162. "kernel2_path",
  163. "kernel2_latency",
  164. "fused_kernel_path",
  165. "fused_kernel_latency",
  166. "slow_down_ratio",
  167. ],
  168. )
  169. # track the fusion statistics for each graph
  170. MetricTable.register_table(
  171. "graph_stats",
  172. [
  173. "graph_id",
  174. "num_nodes_before_fusion",
  175. "num_nodes_after_fusion",
  176. ],
  177. )
  178. # track the perf difference between persistent reduction and non-persistent
  179. # reductions
  180. MetricTable.register_table(
  181. "persistent_red_perf",
  182. [
  183. "kernel0_path",
  184. "kernel1_path",
  185. "kernel2_path",
  186. "kernel3_path",
  187. "kernel0_latency",
  188. "kernel1_latency",
  189. "kernel2_latency",
  190. "kernel3_latency",
  191. "size_hints",
  192. "reduction_hint",
  193. ],
  194. )
  195. # Log the fusion failures due to indexing mismatch
  196. MetricTable.register_table(
  197. "fusion_failure_due_to_indexing_mismatch",
  198. [
  199. "pre_grad_graph_id",
  200. "post_grad_graph_id",
  201. "node1_name",
  202. "node2_name",
  203. "node1_debug_str",
  204. "node2_debug_str",
  205. "common_buffer_names",
  206. "failure_reason",
  207. ],
  208. )
  209. # Log metadata for pointwise/reduction kernels. E.g., model name, kernel path, numel, rnumel, reduction hint
  210. MetricTable.register_table(
  211. "kernel_metadata",
  212. [
  213. "kernel_name",
  214. "kernel_path",
  215. "kernel_category", # pointwise/reduction/foreach etc.
  216. "size_hints",
  217. "reduction_hint",
  218. "line_of_code",
  219. "num_load",
  220. "num_store",
  221. "num_for_loop",
  222. "num_atomic_add",
  223. "num_args",
  224. # xyz numel can be different to size_hints since size_hints are rounded
  225. # up to the nearest power of 2.
  226. # Inductor kernel will burn in the xyz numel in kernel code for static
  227. # shape kernels.
  228. # Logging them will be helpful to find unaligned shape for reduction
  229. "xnumel",
  230. "ynumel",
  231. "rnumel",
  232. "kernel_args_num_gb",
  233. ],
  234. )
  235. def _parse_kernel_fn_code(kernel_module_code: str) -> str:
  236. """
  237. The kernel_module_code is the python module that contains kernel function code.
  238. kernel function is the proper triton kernel function annotated with
  239. @triton.jit
  240. """
  241. from .codecache import PyCodeCache
  242. from .wrapper_benchmark import get_triton_kernel
  243. mod = PyCodeCache.load(kernel_module_code)
  244. kernel = get_triton_kernel(mod)
  245. # kernel is a CachingAutotune; kernel.fn is the JITFunction;
  246. # kernel.fn.fn is the function being decorate by triton.jit
  247. return inspect.getsource(kernel.fn.fn)
  248. def _parse_kernel_line_of_code(proper_kernel_fn_code: str) -> int:
  249. """
  250. Return the line of code for the kernel excluding the decorators.
  251. """
  252. return len(proper_kernel_fn_code.splitlines())
  253. def _parse_size_hints(kernel_module_code: str, kernel_category: str) -> Optional[str]:
  254. if kernel_category == "foreach":
  255. # foreach kernel does not have size_hints
  256. return None
  257. m = re.search(r"size_hints=(\[[0-9, ]*\]),", kernel_module_code)
  258. assert m, "size_hints missing!"
  259. return m.group(1)
  260. def _parse_reduction_hint(
  261. kernel_category: str, kernel_module_code: str
  262. ) -> Optional[str]:
  263. if kernel_category not in ("reduction", "persistent_reduction"):
  264. return None
  265. m = re.search(r"reduction_hint=ReductionHint\.(\w*),", kernel_module_code)
  266. assert m, "reduction_hint not found in kernel source code!"
  267. return m.group(1)
  268. def _count_pattern(proper_kernel_fn_code: str, pattern: str) -> int:
  269. return proper_kernel_fn_code.count(pattern)
  270. def _count_args(proper_kernel_fn_code: str) -> int:
  271. def_line = proper_kernel_fn_code.splitlines()[0]
  272. assert def_line.startswith("def ")
  273. start_idx = def_line.index("(")
  274. end_idx = def_line.index("):")
  275. decl_csv = def_line[start_idx + 1 : end_idx]
  276. comps = decl_csv.split(",")
  277. return len(comps)
  278. def _parse_proper_kernel_fn_code(kernel_fn_code: str) -> str:
  279. """
  280. Skip decorators.
  281. """
  282. start_pos = kernel_fn_code.index("def ")
  283. return kernel_fn_code[start_pos:]
  284. def _parse_numel(proper_kernel_fn_code: str, numel_arg_name: str) -> Optional[int]:
  285. m = re.search(f"{numel_arg_name} = ([\\d]+)", proper_kernel_fn_code)
  286. if m:
  287. return int(m.group(1))
  288. else:
  289. return None
  290. def _parse_kernel_args_num_gb(
  291. kernel_fn_code: str, kernel_category: str
  292. ) -> Optional[float]:
  293. """
  294. inductor meta looks like:
  295. inductor_meta={... 'mutated_arg_names': [], 'no_x_dim': False, 'kernel_num_gb': 2.0},
  296. """
  297. m = re.search(r".kernel_num_gb.:\s*([0-9.]+)", kernel_fn_code)
  298. if m:
  299. return float(m.group(1))
  300. else:
  301. """
  302. There are a few cases that kernel_num_gdb field can be missing:
  303. 1. the field will be missing if config.benchmark_kernel and
  304. config.profile_bandwidth are false
  305. 2. even if config.benchmark_kernel or config.profile_bandwidth is true.
  306. foreach kernel does not have kernel_num_gb field in the metadata
  307. """
  308. return None
  309. def log_kernel_metadata(
  310. kernel_name: str, kernel_path: str, kernel_module_code: str
  311. ) -> None:
  312. """
  313. An utility to log kernel metadata. We may parse metadata from kernel source code here.
  314. It's fine to parse the generated kernel code here since the logging is
  315. disabled by default. It would hurt compilation time.
  316. """
  317. from .wrapper_benchmark import get_kernel_category_by_source_code
  318. kernel_category = get_kernel_category_by_source_code(kernel_module_code)
  319. reduction_hint = _parse_reduction_hint(kernel_category, kernel_module_code)
  320. size_hints = _parse_size_hints(kernel_module_code, kernel_category)
  321. kernel_fn_code = _parse_kernel_fn_code(kernel_module_code)
  322. proper_kernel_fn_code = _parse_proper_kernel_fn_code(kernel_fn_code)
  323. # the line of code excluding the decortors
  324. kernel_line_of_code = _parse_kernel_line_of_code(proper_kernel_fn_code)
  325. get_metric_table("kernel_metadata").add_row(
  326. lambda: {
  327. "kernel_name": kernel_name,
  328. "kernel_path": kernel_path,
  329. "kernel_category": kernel_category,
  330. "size_hints": size_hints,
  331. "reduction_hint": reduction_hint,
  332. "line_of_code": kernel_line_of_code,
  333. "num_load": _count_pattern(proper_kernel_fn_code, "tl.load"),
  334. "num_store": _count_pattern(proper_kernel_fn_code, "tl.store"),
  335. "num_for_loop": _count_pattern(proper_kernel_fn_code, "for "),
  336. "num_atomic_add": _count_pattern(proper_kernel_fn_code, "tl.atomic_add"),
  337. "num_args": _count_args(proper_kernel_fn_code),
  338. "xnumel": _parse_numel(proper_kernel_fn_code, "xnumel"),
  339. "ynumel": _parse_numel(proper_kernel_fn_code, "ynumel"),
  340. "rnumel": _parse_numel(proper_kernel_fn_code, "rnumel"),
  341. "kernel_args_num_gb": _parse_kernel_args_num_gb(
  342. kernel_fn_code, kernel_category
  343. ),
  344. }
  345. )
  346. def purge_old_log_files() -> None:
  347. """
  348. Purge the old log file at the beginning when the benchmark script runs.
  349. Should do it in the parent process rather than the child processes running
  350. each individual model.
  351. """
  352. for name, table in REGISTERED_METRIC_TABLES.items():
  353. if name in enabled_metric_tables():
  354. filename = table.output_filename()
  355. if os.path.exists(filename):
  356. os.unlink(filename)
  357. table.write_header()
  358. def enabled_metric_tables() -> OrderedSet[str]:
  359. return enabled_metric_tables_impl(config.enabled_metric_tables)
  360. @lru_cache
  361. def enabled_metric_tables_impl(config_str: str) -> OrderedSet[str]:
  362. enabled: OrderedSet[str] = OrderedSet()
  363. for name in config_str.split(","):
  364. name = name.strip()
  365. if not name:
  366. continue
  367. assert name in REGISTERED_METRIC_TABLES, (
  368. f"Metric table name {name} is not registered"
  369. )
  370. enabled.add(name)
  371. return enabled
  372. def is_metric_table_enabled(name: str) -> bool:
  373. return name in enabled_metric_tables()
  374. def get_metric_table(name: str) -> MetricTable:
  375. assert name in REGISTERED_METRIC_TABLES, f"Metric table {name} is not defined"
  376. return REGISTERED_METRIC_TABLES[name]
  377. MetricTable.register_table(
  378. "kernel_autotune",
  379. [
  380. "kernel_path",
  381. "kernel_name",
  382. "triton_config",
  383. "latency_ms",
  384. ],
  385. )
  386. def log_kernel_autotune_result(
  387. kernel_path: str, kernel_name: str, config: Config, latency: float
  388. ) -> None:
  389. get_metric_table("kernel_autotune").add_row(
  390. lambda: {
  391. "kernel_path": kernel_path,
  392. "kernel_name": kernel_name,
  393. "triton_config": str(config),
  394. "latency_ms": latency,
  395. }
  396. )