testing.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542
  1. import functools
  2. import math
  3. import os
  4. import statistics
  5. import subprocess
  6. import sys
  7. from contextlib import contextmanager
  8. from typing import Any, Dict, List
  9. from . import language as tl
  10. from . import runtime
  11. def nvsmi(attrs):
  12. attrs = ','.join(attrs)
  13. cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits']
  14. out = subprocess.check_output(cmd)
  15. ret = out.decode(sys.stdout.encoding).split(',')
  16. ret = [int(x) for x in ret]
  17. return ret
  18. # pure Python implementation of np.quantile/torch.quantile
  19. # to avoid unnecessary runtime dependency on numpy/torch
  20. def _quantile(a, q):
  21. n = len(a)
  22. a = sorted(a)
  23. def get_quantile(q):
  24. if not (0 <= q <= 1):
  25. raise ValueError("Quantiles must be in the range [0, 1]")
  26. point = q * (n - 1)
  27. lower = math.floor(point)
  28. upper = math.ceil(point)
  29. t = point - lower
  30. return (1 - t) * a[lower] + t * a[upper]
  31. return [get_quantile(q) for q in q]
  32. def _summarize_statistics(times, quantiles, return_mode):
  33. if quantiles is not None:
  34. ret = _quantile(times, quantiles)
  35. if len(ret) == 1:
  36. ret = ret[0]
  37. return ret
  38. if return_mode == "all":
  39. return times
  40. elif return_mode == "min":
  41. return min(times)
  42. elif return_mode == "max":
  43. return max(times)
  44. elif return_mode == "mean":
  45. return statistics.mean(times)
  46. elif return_mode == "median":
  47. return statistics.median(times)
  48. def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"):
  49. """
  50. Benchmark the runtime of the provided function.
  51. :param fn: Function to benchmark
  52. :type fn: Callable
  53. :param rep: Repetition time (in ms)
  54. :type rep: int
  55. :param grad_to_none: Reset the gradient of the provided tensor to None
  56. :type grad_to_none: torch.tensor, optional
  57. :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all". Default is "mean".
  58. :type return_mode: str
  59. """
  60. import torch
  61. assert return_mode in ["min", "max", "mean", "median", "all"]
  62. with torch.cuda.stream(torch.cuda.Stream()):
  63. # warmup
  64. fn()
  65. if grad_to_none is not None:
  66. for x in grad_to_none:
  67. x.detach_()
  68. x.requires_grad_(True)
  69. x.grad = None
  70. # step 1 - we estimate the amount of time the kernel call takes
  71. # NOTE: this estimate isn't super accurate because the GPU isn't warmed up at this point
  72. # but it is probably good enough
  73. # NOTE: we don't use a graph to estimate the runtime because creating a graph is expensive,
  74. # ~300ms on A100, so we default to the same method used in `do_bench` (minus the L2
  75. # cache flush).
  76. start_event = torch.cuda.Event(enable_timing=True)
  77. end_event = torch.cuda.Event(enable_timing=True)
  78. start_event.record()
  79. for _ in range(5):
  80. fn()
  81. end_event.record()
  82. torch.cuda.synchronize()
  83. estimate_ms = start_event.elapsed_time(end_event) / 5
  84. # Rewrite to avoid possible division by 0 issues with fast benchmarks
  85. if estimate_ms == 0:
  86. n_repeat = 1000
  87. else:
  88. n_repeat = max(1, int(rep / estimate_ms))
  89. # step 2 - construct a cuda graph with `n_repeat` unrolled function calls to minimize
  90. # host overhead
  91. g = torch.cuda.CUDAGraph()
  92. with torch.cuda.graph(g):
  93. for _ in range(n_repeat):
  94. if grad_to_none is not None:
  95. for x in grad_to_none:
  96. x.grad = None
  97. fn()
  98. torch.cuda.synchronize()
  99. # measure time and return
  100. ret = []
  101. n_retries = 10
  102. for _ in range(n_retries):
  103. start_event = torch.cuda.Event(enable_timing=True)
  104. end_event = torch.cuda.Event(enable_timing=True)
  105. start_event.record()
  106. g.replay()
  107. end_event.record()
  108. torch.cuda.synchronize()
  109. ret += [start_event.elapsed_time(end_event) / n_repeat]
  110. return _summarize_statistics(ret, quantiles, return_mode)
  111. def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean"):
  112. """
  113. Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
  114. the 20-th and 80-th performance percentile.
  115. :param fn: Function to benchmark
  116. :type fn: Callable
  117. :param warmup: Warmup time (in ms)
  118. :type warmup: int
  119. :param rep: Repetition time (in ms)
  120. :type rep: int
  121. :param grad_to_none: Reset the gradient of the provided tensor to None
  122. :type grad_to_none: torch.tensor, optional
  123. :param quantiles: Performance percentile to return in addition to the median.
  124. :type quantiles: list[float], optional
  125. :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all". Default is "mean".
  126. :type return_mode: str
  127. """
  128. assert return_mode in ["min", "max", "mean", "median", "all"]
  129. di = runtime.driver.active.get_device_interface()
  130. fn()
  131. di.synchronize()
  132. cache = runtime.driver.active.get_empty_cache_for_benchmark()
  133. # Estimate the runtime of the function
  134. start_event = di.Event(enable_timing=True)
  135. end_event = di.Event(enable_timing=True)
  136. start_event.record()
  137. for _ in range(5):
  138. runtime.driver.active.clear_cache(cache)
  139. fn()
  140. end_event.record()
  141. di.synchronize()
  142. estimate_ms = start_event.elapsed_time(end_event) / 5
  143. # compute number of warmup and repeat
  144. n_warmup = max(1, int(warmup / estimate_ms))
  145. n_repeat = max(1, int(rep / estimate_ms))
  146. start_event = [di.Event(enable_timing=True) for i in range(n_repeat)]
  147. end_event = [di.Event(enable_timing=True) for i in range(n_repeat)]
  148. # Warm-up
  149. for _ in range(n_warmup):
  150. fn()
  151. # Benchmark
  152. for i in range(n_repeat):
  153. # we don't want `fn` to accumulate gradient values
  154. # if it contains a backward pass. So we clear the
  155. # provided gradients
  156. if grad_to_none is not None:
  157. for x in grad_to_none:
  158. x.grad = None
  159. # we clear the L2 cache before each run
  160. runtime.driver.active.clear_cache(cache)
  161. # record time of `fn`
  162. start_event[i].record()
  163. fn()
  164. end_event[i].record()
  165. # Record clocks
  166. di.synchronize()
  167. times = [s.elapsed_time(e) for s, e in zip(start_event, end_event)]
  168. return _summarize_statistics(times, quantiles, return_mode)
  169. def assert_close(x, y, atol=None, rtol=None, err_msg=''):
  170. """
  171. Asserts that two inputs are close within a certain tolerance.
  172. :param x: The first input.
  173. :type x: scala, list, numpy.ndarray, or torch.Tensor
  174. :param y: The second input.
  175. :type y: scala, list, numpy.ndarray, or torch.Tensor
  176. :param atol: The absolute tolerance. Default value is 1e-2.
  177. :type atol: float, optional
  178. :param rtol: The relative tolerance. Default value is 0.
  179. :type rtol: float, optional
  180. :param err_msg: The error message to use if the assertion fails.
  181. :type err_msg: str
  182. """
  183. import numpy as np
  184. import torch
  185. # canonicalize arguments to be tensors
  186. if not isinstance(x, torch.Tensor):
  187. x = torch.tensor(x)
  188. if not isinstance(y, torch.Tensor):
  189. y = torch.tensor(y)
  190. # absolute tolerance
  191. if atol is None:
  192. atol = 1e-2
  193. atol = atol(x.dtype) if callable(atol) else atol
  194. # relative tolerance hook
  195. if rtol is None:
  196. rtol = 0.
  197. rtol = rtol(x.dtype) if callable(rtol) else rtol
  198. # we use numpy instead of pytorch
  199. # as it seems more memory efficient
  200. # pytorch tends to oom on large tensors
  201. if isinstance(x, torch.Tensor):
  202. if x.dtype == torch.bfloat16:
  203. x = x.float()
  204. x = x.cpu().detach().numpy()
  205. if isinstance(y, torch.Tensor):
  206. if y.dtype == torch.bfloat16:
  207. y = y.float()
  208. y = y.cpu().detach().numpy()
  209. # we handle size==1 case separately as we can
  210. # provide better error message there
  211. if x.size > 1 or y.size > 1:
  212. np.testing.assert_allclose(x, y, atol=atol, rtol=rtol, equal_nan=True)
  213. return
  214. if not np.allclose(x, y, atol=atol, rtol=rtol):
  215. raise AssertionError(f'{err_msg} {x} is not close to {y} (atol={atol}, rtol={rtol})')
  216. class Benchmark:
  217. """
  218. This class is used by the :code:`perf_report` function to generate line plots with a concise API.
  219. """
  220. def __init__(
  221. self,
  222. x_names: List[str],
  223. x_vals: List[Any],
  224. line_arg: str,
  225. line_vals: List[Any],
  226. line_names: List[str],
  227. plot_name: str,
  228. args: Dict[str, Any],
  229. xlabel: str = '',
  230. ylabel: str = '',
  231. x_log: bool = False,
  232. y_log: bool = False,
  233. styles=None,
  234. ):
  235. """
  236. Constructor.
  237. x_vals can be a list of scalars or a list of tuples/lists. If x_vals is a list
  238. of scalars and there are multiple x_names, all arguments will have the same value.
  239. If x_vals is a list of tuples/lists, each element should have the same length as
  240. x_names.
  241. :param x_names: Name of the arguments that should appear on the x axis of the plot.
  242. :type x_names: List[str]
  243. :param x_vals: List of values to use for the arguments in :code:`x_names`.
  244. :type x_vals: List[Any]
  245. :param line_arg: Argument name for which different values correspond to different lines in the plot.
  246. :type line_arg: str
  247. :param line_vals: List of values to use for the arguments in :code:`line_arg`.
  248. :type line_vals: List[Any]
  249. :param line_names: Label names for the different lines.
  250. :type line_names: List[str]
  251. :param plot_name: Name of the plot.
  252. :type plot_name: str
  253. :param args: Dictionary of keyword arguments to remain fixed throughout the benchmark.
  254. :type args: Dict[str, Any]
  255. :param xlabel: Label for the x axis of the plot.
  256. :type xlabel: str, optional
  257. :param ylabel: Label for the y axis of the plot.
  258. :type ylabel: str, optional
  259. :param x_log: Whether the x axis should be log scale.
  260. :type x_log: bool, optional
  261. :param y_log: Whether the y axis should be log scale.
  262. :type y_log: bool, optional
  263. :param styles: A list of tuples, where each tuple contains two elements: a color and a linestyle.
  264. :type styles: list[tuple[str, str]]
  265. """
  266. self.x_names = x_names
  267. self.x_vals = x_vals
  268. self.x_log = x_log
  269. self.line_arg = line_arg
  270. self.line_vals = line_vals
  271. self.line_names = line_names
  272. self.y_log = y_log
  273. self.styles = styles
  274. # plot info
  275. self.xlabel = xlabel
  276. self.ylabel = ylabel
  277. self.plot_name = plot_name
  278. self.args = args
  279. class Mark:
  280. def __init__(self, fn, benchmarks):
  281. self.fn = fn
  282. self.benchmarks = benchmarks
  283. def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, diff_col=False,
  284. save_precision=6, **kwrags):
  285. import os
  286. import matplotlib.pyplot as plt
  287. import pandas as pd
  288. y_mean_labels = [f'{x} ({bench.ylabel})' for x in bench.line_names]
  289. y_min_labels = [f'{x}-min ({bench.ylabel})' for x in bench.line_names]
  290. y_max_labels = [f'{x}-max ({bench.ylabel})' for x in bench.line_names]
  291. x_names = list(bench.x_names)
  292. df = pd.DataFrame(columns=x_names + y_mean_labels + y_min_labels + y_max_labels)
  293. for x in bench.x_vals:
  294. # x can be a single value or a sequence of values.
  295. if not isinstance(x, (list, tuple)):
  296. x = [x for _ in x_names]
  297. if len(x) != len(x_names):
  298. raise ValueError(f"Expected {len(x_names)} values, got {x}")
  299. x_args = dict(zip(x_names, x))
  300. row_mean, row_min, row_max = [], [], []
  301. for y in bench.line_vals:
  302. ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args, **kwrags)
  303. try:
  304. y_mean, y_min, y_max = ret
  305. except TypeError:
  306. y_mean, y_min, y_max = ret, None, None
  307. row_mean += [y_mean]
  308. row_min += [y_min]
  309. row_max += [y_max]
  310. df.loc[len(df)] = list(x) + row_mean + row_min + row_max
  311. if bench.plot_name:
  312. plt.figure()
  313. ax = plt.subplot()
  314. # Plot first x value on x axis if there are multiple.
  315. first_x = x_names[0]
  316. for i, (mean_label, min_label, max_label) in enumerate(zip(y_mean_labels, y_min_labels, y_max_labels)):
  317. y_min, y_max = df[min_label], df[max_label]
  318. col = bench.styles[i][0] if bench.styles else None
  319. sty = bench.styles[i][1] if bench.styles else None
  320. ax.plot(df[first_x], df[mean_label], label=mean_label, color=col, ls=sty)
  321. if not y_min.isnull().all() and not y_max.isnull().all():
  322. y_min = y_min.astype(float)
  323. y_max = y_max.astype(float)
  324. ax.fill_between(df[first_x], y_min, y_max, alpha=0.15, color=col)
  325. ax.legend()
  326. ax.set_xlabel(bench.xlabel or first_x)
  327. ax.set_ylabel(bench.ylabel)
  328. # ax.set_title(bench.plot_name)
  329. ax.set_xscale("log" if bench.x_log else "linear")
  330. ax.set_yscale("log" if bench.y_log else "linear")
  331. if show_plots:
  332. plt.show()
  333. if save_path:
  334. plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png"))
  335. df = df[x_names + y_mean_labels]
  336. if diff_col and df.shape[1] == 2:
  337. col0, col1 = df.columns.tolist()
  338. df['Diff'] = df[col1] - df[col0]
  339. if print_data:
  340. print(bench.plot_name + ':')
  341. print(df.to_string())
  342. if save_path:
  343. df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format=f"%.{save_precision}f",
  344. index=False)
  345. return df
  346. def run(self, show_plots=False, print_data=False, save_path='', return_df=False, **kwargs):
  347. has_single_bench = isinstance(self.benchmarks, Benchmark)
  348. benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks
  349. result_dfs = []
  350. try:
  351. for bench in benchmarks:
  352. result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs))
  353. finally:
  354. if save_path:
  355. # Create directory if it doesn't exist
  356. os.makedirs(save_path, exist_ok=True)
  357. with open(os.path.join(save_path, "results.html"), "w") as html:
  358. html.write("<html><body>\n")
  359. for bench in benchmarks[:len(result_dfs)]:
  360. html.write(f"<image src=\"{bench.plot_name}.png\"/>\n")
  361. html.write("</body></html>\n")
  362. if return_df:
  363. if has_single_bench:
  364. return result_dfs[0]
  365. else:
  366. return result_dfs
  367. return None
  368. def perf_report(benchmarks):
  369. """
  370. Mark a function for benchmarking. The benchmark can then be executed by using the :code:`.run` method on the return value.
  371. :param benchmarks: Benchmarking configurations.
  372. :type benchmarks: List of :class:`Benchmark`
  373. """
  374. wrapper = lambda fn: Mark(fn, benchmarks)
  375. return wrapper
  376. def get_dram_gbps(device=None):
  377. ''' return DRAM bandwidth in GB/s '''
  378. from .runtime import driver
  379. if device is None:
  380. device = driver.active.get_device_interface().current_device()
  381. mem_clock_khz = driver.active.utils.get_device_properties(device)["mem_clock_rate"] # in kHz
  382. bus_width = driver.active.utils.get_device_properties(device)["mem_bus_width"]
  383. bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8 # In GB/s
  384. return bw_gbps
  385. def get_max_tensorcore_tflops(dtype, clock_rate, device=None):
  386. import torch
  387. from .runtime import driver
  388. if not device:
  389. device = torch.cuda.current_device()
  390. num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4
  391. capability = torch.cuda.get_device_capability(device)
  392. if capability[0] < 8:
  393. assert dtype == torch.float16
  394. ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores
  395. else:
  396. if dtype in [torch.float32, torch.int32]:
  397. ops_per_sub_core = 256
  398. elif dtype in [torch.float16, torch.bfloat16, torch.int16]:
  399. ops_per_sub_core = 512
  400. elif dtype in [torch.int8, tl.float8e4nv, tl.float8e4b15, tl.float8e5]:
  401. ops_per_sub_core = 1024
  402. else:
  403. raise RuntimeError("dtype not supported")
  404. tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
  405. return tflops
  406. # create decorator that wraps test function into
  407. # a cuda-memcheck system call
  408. def cuda_memcheck(**target_kwargs):
  409. def decorator(test_fn):
  410. @functools.wraps(test_fn)
  411. def wrapper(*args, **kwargs):
  412. import psutil
  413. ppid_name = psutil.Process(os.getppid()).name()
  414. run_cuda_memcheck = target_kwargs.items() <= kwargs.items()
  415. if run_cuda_memcheck and ppid_name != "cuda-memcheck":
  416. path = os.path.realpath(test_fn.__globals__["__file__"])
  417. # get path of current file
  418. env = {"PATH": os.environ["PATH"], "PYTORCH_NO_CUDA_MEMORY_CACHING": "1"}
  419. assert 'request' in kwargs, "memcheck'ed test must have a (possibly unused) `request` fixture"
  420. test_id = kwargs['request'].node.callspec.id
  421. cmd = f"{path}::{test_fn.__name__}[{test_id}]"
  422. out = subprocess.run(["cuda-memcheck", "pytest", "-vs", cmd], capture_output=True, env=env)
  423. assert out.returncode == 0, "cuda-memcheck returned an error: bounds checking failed"
  424. assert "ERROR SUMMARY: 0 errors" in str(out.stdout)
  425. else:
  426. test_fn(*args, **kwargs)
  427. return wrapper
  428. return decorator
  429. @contextmanager
  430. def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215):
  431. try:
  432. subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "1"])
  433. subprocess.check_output([
  434. "nvidia-smi",
  435. "-i",
  436. "0",
  437. f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}",
  438. ])
  439. subprocess.check_output([
  440. "nvidia-smi",
  441. "-i",
  442. "0",
  443. f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}",
  444. ])
  445. cur_sm_clock = nvsmi(["clocks.current.sm"])[0]
  446. cur_mem_clock = nvsmi(["clocks.current.memory"])[0]
  447. assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz"
  448. assert abs(cur_mem_clock - ref_mem_clock) < 10, f"GPU SMs must run at {ref_mem_clock} MHz"
  449. tflops = 1e-6 * 2 * 108 * 4 * 256 * ref_sm_clock
  450. gbps = 640 * 2 * ref_mem_clock * 1e-3
  451. yield tflops, gbps
  452. finally:
  453. subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "0"])
  454. subprocess.check_output(["nvidia-smi", "-i", "0", "-rgc"])
  455. subprocess.check_output(["nvidia-smi", "-i", "0", "-rmc"])
  456. def get_max_simd_tflops(dtype, clock_rate, device=None):
  457. import torch
  458. from .runtime import driver
  459. if not device:
  460. device = torch.cuda.current_device()
  461. num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4
  462. capability = torch.cuda.get_device_capability()
  463. if capability[0] < 8:
  464. if dtype == torch.float32:
  465. ops_per_sub_core = 32 # 2*16
  466. elif dtype == torch.float16:
  467. ops_per_sub_core = 64
  468. else:
  469. raise RuntimeError("dtype not supported")
  470. else:
  471. if dtype == torch.float32:
  472. ops_per_sub_core = 32
  473. elif dtype in [torch.float16, torch.bfloat16]:
  474. ops_per_sub_core = 64
  475. else:
  476. raise RuntimeError("dtype not supported")
  477. tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
  478. return tflops