| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542 |
- import functools
- import math
- import os
- import statistics
- import subprocess
- import sys
- from contextlib import contextmanager
- from typing import Any, Dict, List
- from . import language as tl
- from . import runtime
def nvsmi(attrs, device=0):
    """
    Query ``nvidia-smi`` for GPU attributes and return them as integers.

    :param attrs: ``--query-gpu`` field names, e.g. ``["clocks.current.sm"]``.
    :type attrs: list[str]
    :param device: Index of the GPU to query (passed to ``nvidia-smi -i``). Defaults to 0.
    :type device: int
    :return: One integer per requested attribute, in the order requested.
    :rtype: list[int]
    :raises subprocess.CalledProcessError: if ``nvidia-smi`` exits with a non-zero status.
    :raises ValueError: if a reported value is not an integer.
    """
    cmd = [
        'nvidia-smi',
        '-i',
        str(device),
        '--query-gpu=' + ','.join(attrs),
        '--format=csv,noheader,nounits',
    ]
    # Decode via text=True (locale encoding) rather than sys.stdout.encoding:
    # the latter is None when stdout is redirected to e.g. io.StringIO.
    out = subprocess.check_output(cmd, text=True)
    # int() tolerates the surrounding whitespace/newline in each CSV field.
    return [int(x) for x in out.split(',')]
- # pure Python implementation of np.quantile/torch.quantile
- # to avoid unnecessary runtime dependency on numpy/torch
- def _quantile(a, q):
- n = len(a)
- a = sorted(a)
- def get_quantile(q):
- if not (0 <= q <= 1):
- raise ValueError("Quantiles must be in the range [0, 1]")
- point = q * (n - 1)
- lower = math.floor(point)
- upper = math.ceil(point)
- t = point - lower
- return (1 - t) * a[lower] + t * a[upper]
- return [get_quantile(q) for q in q]
- def _summarize_statistics(times, quantiles, return_mode):
- if quantiles is not None:
- ret = _quantile(times, quantiles)
- if len(ret) == 1:
- ret = ret[0]
- return ret
- if return_mode == "all":
- return times
- elif return_mode == "min":
- return min(times)
- elif return_mode == "max":
- return max(times)
- elif return_mode == "mean":
- return statistics.mean(times)
- elif return_mode == "median":
- return statistics.median(times)
def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"):
    """
    Benchmark the runtime of the provided function by replaying it from a CUDA graph.

    :code:`fn` is captured ``n_repeat`` times into one CUDA graph (to minimize host overhead),
    the graph is replayed 10 times, and each replay's elapsed time is divided by ``n_repeat``.

    :param fn: Function to benchmark
    :type fn: Callable
    :param rep: Repetition time (in ms)
    :type rep: int
    :param grad_to_none: Reset the gradient of the provided tensor to None
    :type grad_to_none: torch.tensor, optional
    :param quantiles: Performance percentiles (each in [0, 1]) to return instead of :code:`return_mode`.
        A single quantile is returned as a scalar, several as a list.
    :type quantiles: list[float], optional
    :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all". Default is "mean".
    :type return_mode: str
    :return: Per-call runtime in ms, summarized according to :code:`quantiles` / :code:`return_mode`.
    """
    import torch
    assert return_mode in ["min", "max", "mean", "median", "all"]
    # Everything below runs on a fresh, non-default stream.
    with torch.cuda.stream(torch.cuda.Stream()):
        # warmup
        fn()
        if grad_to_none is not None:
            for x in grad_to_none:
                x.detach_()
                x.requires_grad_(True)
                x.grad = None
        # step 1 - we estimate the amount of time the kernel call takes
        # NOTE: this estimate isn't super accurate because the GPU isn't warmed up at this point
        # but it is probably good enough
        # NOTE: we don't use a graph to estimate the runtime because creating a graph is expensive,
        # ~300ms on A100, so we default to the same method used in `do_bench` (minus the L2
        # cache flush).
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        for _ in range(5):
            fn()
        end_event.record()
        torch.cuda.synchronize()
        estimate_ms = start_event.elapsed_time(end_event) / 5
        # Avoid dividing by zero when the kernel is faster than the timer resolution.
        if estimate_ms == 0:
            n_repeat = 1000
        else:
            n_repeat = max(1, int(rep / estimate_ms))
        # step 2 - construct a cuda graph with `n_repeat` unrolled function calls to minimize
        # host overhead
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            for _ in range(n_repeat):
                # NOTE(review): `x.grad = None` is a host-side Python assignment; it runs during
                # capture but is presumably not re-executed by g.replay() — confirm intended.
                if grad_to_none is not None:
                    for x in grad_to_none:
                        x.grad = None
                fn()
        torch.cuda.synchronize()
        # Untimed replay: the first launch of a freshly captured graph may include one-time
        # upload cost, which would otherwise inflate the first timed sample.
        g.replay()
        torch.cuda.synchronize()
        # measure time and return
        ret = []
        n_retries = 10
        for _ in range(n_retries):
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
            g.replay()
            end_event.record()
            torch.cuda.synchronize()
            ret.append(start_event.elapsed_time(end_event) / n_repeat)
        return _summarize_statistics(ret, quantiles, return_mode)
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean"):
    """
    Benchmark the runtime of the provided function, clearing the L2 cache before every timed call.

    By default, return the mean runtime (in ms) of :code:`fn`. If :code:`quantiles` is given, the
    requested performance percentiles are returned instead.

    :param fn: Function to benchmark
    :type fn: Callable
    :param warmup: Warmup time (in ms)
    :type warmup: int
    :param rep: Repetition time (in ms)
    :type rep: int
    :param grad_to_none: Reset the gradient of the provided tensor to None
    :type grad_to_none: torch.tensor, optional
    :param quantiles: Performance percentiles (each in [0, 1]) to return instead of :code:`return_mode`.
        A single quantile is returned as a scalar, several as a list.
    :type quantiles: list[float], optional
    :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all". Default is "mean".
    :type return_mode: str
    """
    assert return_mode in ["min", "max", "mean", "median", "all"]
    di = runtime.driver.active.get_device_interface()

    fn()
    di.synchronize()

    cache = runtime.driver.active.get_empty_cache_for_benchmark()

    # Estimate the runtime of the function (this estimate includes the cache-clearing time).
    start_event = di.Event(enable_timing=True)
    end_event = di.Event(enable_timing=True)
    start_event.record()
    for _ in range(5):
        runtime.driver.active.clear_cache(cache)
        fn()
    end_event.record()
    di.synchronize()
    estimate_ms = start_event.elapsed_time(end_event) / 5

    # compute number of warmup and repeat; guard against a zero estimate (kernels faster than
    # the timer resolution), mirroring do_bench_cudagraph's fallback of 1000 repetitions.
    if estimate_ms == 0:
        n_warmup = n_repeat = 1000
    else:
        n_warmup = max(1, int(warmup / estimate_ms))
        n_repeat = max(1, int(rep / estimate_ms))
    start_events = [di.Event(enable_timing=True) for _ in range(n_repeat)]
    end_events = [di.Event(enable_timing=True) for _ in range(n_repeat)]
    # Warm-up
    for _ in range(n_warmup):
        fn()
    # Benchmark
    for start, end in zip(start_events, end_events):
        # we don't want `fn` to accumulate gradient values
        # if it contains a backward pass. So we clear the
        # provided gradients
        if grad_to_none is not None:
            for x in grad_to_none:
                x.grad = None
        # we clear the L2 cache before each run
        runtime.driver.active.clear_cache(cache)
        # record time of `fn`
        start.record()
        fn()
        end.record()
    # Record clocks
    di.synchronize()
    times = [s.elapsed_time(e) for s, e in zip(start_events, end_events)]
    return _summarize_statistics(times, quantiles, return_mode)
def assert_close(x, y, atol=None, rtol=None, err_msg=''):
    """
    Asserts that two inputs are close within a certain tolerance.

    NaNs in matching positions are treated as equal.

    :param x: The first input.
    :type x: scalar, list, numpy.ndarray, or torch.Tensor
    :param y: The second input.
    :type y: scalar, list, numpy.ndarray, or torch.Tensor
    :param atol: The absolute tolerance, or a callable mapping the dtype of :code:`x` to one. Default value is 1e-2.
    :type atol: float or Callable, optional
    :param rtol: The relative tolerance, or a callable mapping the dtype of :code:`x` to one. Default value is 0.
    :type rtol: float or Callable, optional
    :param err_msg: The error message to use if the assertion fails.
    :type err_msg: str
    :raises AssertionError: if the inputs are not close.
    """
    import numpy as np
    import torch

    def _to_numpy(t):
        # numpy has no bfloat16, so upcast before converting.
        if t.dtype == torch.bfloat16:
            t = t.float()
        return t.cpu().detach().numpy()

    # canonicalize arguments to be tensors
    if not isinstance(x, torch.Tensor):
        x = torch.tensor(x)
    if not isinstance(y, torch.Tensor):
        y = torch.tensor(y)
    # absolute tolerance
    if atol is None:
        atol = 1e-2
    atol = atol(x.dtype) if callable(atol) else atol
    # relative tolerance hook
    if rtol is None:
        rtol = 0.
    rtol = rtol(x.dtype) if callable(rtol) else rtol
    # we use numpy instead of pytorch
    # as it seems more memory efficient
    # pytorch tends to oom on large tensors
    x = _to_numpy(x)
    y = _to_numpy(y)
    # we handle size==1 case separately as we can
    # provide better error message there
    if x.size > 1 or y.size > 1:
        np.testing.assert_allclose(x, y, atol=atol, rtol=rtol, equal_nan=True, err_msg=err_msg)
        return
    # equal_nan=True keeps the scalar path consistent with the array path above.
    if not np.allclose(x, y, atol=atol, rtol=rtol, equal_nan=True):
        raise AssertionError(f'{err_msg} {x} is not close to {y} (atol={atol}, rtol={rtol})')
class Benchmark:
    """
    Describes one line-plot benchmark; consumed by :code:`perf_report` / :code:`Mark.run`.
    """

    def __init__(
        self,
        x_names: List[str],
        x_vals: List[Any],
        line_arg: str,
        line_vals: List[Any],
        line_names: List[str],
        plot_name: str,
        args: Dict[str, Any],
        xlabel: str = '',
        ylabel: str = '',
        x_log: bool = False,
        y_log: bool = False,
        styles=None,
    ):
        """
        Store the benchmark configuration.

        Each entry of x_vals is either a scalar (shared by every name in x_names) or a
        tuple/list holding exactly one value per name in x_names.

        :param x_names: Names of the arguments swept along the x axis.
        :type x_names: List[str]
        :param x_vals: Values taken by the :code:`x_names` arguments.
        :type x_vals: List[Any]
        :param line_arg: Name of the argument whose values distinguish the plotted lines.
        :type line_arg: str
        :param line_vals: Values taken by :code:`line_arg`, one per line.
        :type line_vals: List[Any]
        :param line_names: Legend label for each line.
        :type line_names: List[str]
        :param plot_name: Name of the plot (also used for output file names).
        :type plot_name: str
        :param args: Keyword arguments held fixed for every call.
        :type args: Dict[str, Any]
        :param xlabel: x-axis label.
        :type xlabel: str, optional
        :param ylabel: y-axis label.
        :type ylabel: str, optional
        :param x_log: Use a logarithmic x axis.
        :type x_log: bool, optional
        :param y_log: Use a logarithmic y axis.
        :type y_log: bool, optional
        :param styles: Per-line (color, linestyle) pairs.
        :type styles: list[tuple[str, str]], optional
        """
        # x-axis sweep
        self.x_names, self.x_vals, self.x_log = x_names, x_vals, x_log
        # one plotted line per value of line_arg
        self.line_arg, self.line_vals, self.line_names = line_arg, line_vals, line_names
        self.y_log, self.styles = y_log, styles
        # plot labelling
        self.xlabel, self.ylabel, self.plot_name = xlabel, ylabel, plot_name
        # arguments held constant across calls
        self.args = args
class Mark:
    """
    A benchmark function paired with the :class:`Benchmark` configuration(s) to run it with.

    Instances are returned by :code:`perf_report`; call :meth:`run` to execute the benchmarks.
    """

    def __init__(self, fn, benchmarks):
        # Called once per (x value, line value) pair; returns either a single number or a
        # (mean, min, max) triple.
        self.fn = fn
        # A single Benchmark or a list of Benchmarks.
        self.benchmarks = benchmarks

    def _plot(self, bench, df, x_names, y_mean_labels, y_min_labels, y_max_labels, save_path, show_plots):
        """Draw one line per benchmark line (with an optional min/max band), then save and/or show it."""
        import matplotlib.pyplot as plt

        fig = plt.figure()
        ax = plt.subplot()
        # Plot first x value on x axis if there are multiple.
        first_x = x_names[0]
        for i, (mean_label, min_label, max_label) in enumerate(zip(y_mean_labels, y_min_labels, y_max_labels)):
            y_min, y_max = df[min_label], df[max_label]
            col = bench.styles[i][0] if bench.styles else None
            sty = bench.styles[i][1] if bench.styles else None
            ax.plot(df[first_x], df[mean_label], label=mean_label, color=col, ls=sty)
            # Only shade a band when fn actually reported min/max values.
            if not y_min.isnull().all() and not y_max.isnull().all():
                y_min = y_min.astype(float)
                y_max = y_max.astype(float)
                ax.fill_between(df[first_x], y_min, y_max, alpha=0.15, color=col)
        ax.legend()
        ax.set_xlabel(bench.xlabel or first_x)
        ax.set_ylabel(bench.ylabel)
        ax.set_xscale("log" if bench.x_log else "linear")
        ax.set_yscale("log" if bench.y_log else "linear")
        # Save before showing: once plt.show() returns the figure may already be closed,
        # and saving afterwards can produce a blank image.
        if save_path:
            fig.savefig(os.path.join(save_path, f"{bench.plot_name}.png"))
        if show_plots:
            plt.show()
        # Release the figure so running many benchmarks does not accumulate open figures.
        plt.close(fig)

    def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, diff_col=False,
             save_precision=6, **kwargs):
        """
        Run a single benchmark configuration and collect its results.

        :param bench: The configuration to run.
        :param save_path: Directory to write ``<plot_name>.png`` / ``<plot_name>.csv`` into; falsy to skip saving.
        :param show_plots: Whether to display the plot.
        :param print_data: Whether to print the result table.
        :param diff_col: If the benchmark has exactly two lines, add a ``Diff`` column
            (second line minus first line).
        :param save_precision: Number of decimals for floats in the saved CSV.
        :param kwargs: Extra keyword arguments forwarded to the benchmarked function.
        :return: DataFrame with the x columns followed by one mean column per line
            (plus ``Diff`` when requested).
        """
        import pandas as pd

        y_mean_labels = [f'{x} ({bench.ylabel})' for x in bench.line_names]
        y_min_labels = [f'{x}-min ({bench.ylabel})' for x in bench.line_names]
        y_max_labels = [f'{x}-max ({bench.ylabel})' for x in bench.line_names]
        x_names = list(bench.x_names)
        df = pd.DataFrame(columns=x_names + y_mean_labels + y_min_labels + y_max_labels)
        for x in bench.x_vals:
            # x can be a single value or a sequence of values.
            if not isinstance(x, (list, tuple)):
                x = [x for _ in x_names]

            if len(x) != len(x_names):
                raise ValueError(f"Expected {len(x_names)} values, got {x}")
            x_args = dict(zip(x_names, x))

            row_mean, row_min, row_max = [], [], []
            for y in bench.line_vals:
                ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args, **kwargs)
                try:
                    y_mean, y_min, y_max = ret
                except TypeError:
                    # fn returned a single non-iterable value: no min/max available.
                    y_mean, y_min, y_max = ret, None, None
                row_mean += [y_mean]
                row_min += [y_min]
                row_max += [y_max]
            df.loc[len(df)] = list(x) + row_mean + row_min + row_max

        if bench.plot_name:
            self._plot(bench, df, x_names, y_mean_labels, y_min_labels, y_max_labels, save_path, show_plots)
        df = df[x_names + y_mean_labels]
        # The diff is between the two *lines*; the x columns are not part of it.
        if diff_col and len(y_mean_labels) == 2:
            col0, col1 = y_mean_labels
            df = df.assign(Diff=df[col1] - df[col0])

        if print_data:
            print(bench.plot_name + ':')
            print(df.to_string())
        if save_path:
            df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format=f"%.{save_precision}f",
                      index=False)
        return df

    def run(self, show_plots=False, print_data=False, save_path='', return_df=False, **kwargs):
        """
        Run every benchmark attached to this mark.

        :param show_plots: Whether to display each plot.
        :param print_data: Whether to print each result table.
        :param save_path: Directory for .png/.csv outputs and a ``results.html`` index; falsy to skip saving.
        :param return_df: Whether to return the result DataFrame(s).
        :param kwargs: Forwarded to :meth:`_run` (``diff_col``, ``save_precision``, or extra
            arguments for the benchmarked function).
        :return: The DataFrame (single Benchmark) or list of DataFrames if :code:`return_df`, else None.
        """
        has_single_bench = isinstance(self.benchmarks, Benchmark)
        benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks
        if save_path:
            # Create the directory up front: _run writes the .png/.csv files into it.
            os.makedirs(save_path, exist_ok=True)
        result_dfs = []
        try:
            for bench in benchmarks:
                result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs))
        finally:
            # Index whatever plots were produced, even if a later benchmark raised.
            if save_path:
                with open(os.path.join(save_path, "results.html"), "w") as html:
                    html.write("<html><body>\n")
                    for bench in benchmarks[:len(result_dfs)]:
                        html.write(f"<image src=\"{bench.plot_name}.png\"/>\n")
                    html.write("</body></html>\n")
        if not return_df:
            return None
        return result_dfs[0] if has_single_bench else result_dfs
def perf_report(benchmarks):
    """
    Mark a function for benchmarking. The benchmark can then be executed by using the :code:`.run` method on the return value.

    :param benchmarks: Benchmarking configurations.
    :type benchmarks: :class:`Benchmark` or list of :class:`Benchmark`
    :return: A decorator that wraps the decorated function into a :class:`Mark`.
    """

    def wrapper(fn):
        return Mark(fn, benchmarks)

    return wrapper
def get_dram_gbps(device=None):
    """
    Return the theoretical DRAM bandwidth of a device in GB/s.

    :param device: Device index; defaults to the driver's current device.
    :return: Bandwidth in GB/s, computed from the reported memory clock and bus width.
    """
    from .runtime import driver
    if device is None:
        device = driver.active.get_device_interface().current_device()
    # Fetch the properties once instead of once per field.
    props = driver.active.utils.get_device_properties(device)
    mem_clock_khz = props["mem_clock_rate"]  # in kHz
    bus_width = props["mem_bus_width"]  # presumably in bits (hence the / 8 below) — TODO confirm
    # kHz * bits * 2 (presumably double data rate) / 1e6 -> Gbit/s, then / 8 -> GB/s
    bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8
    return bw_gbps
def get_max_tensorcore_tflops(dtype, clock_rate, device=None):
    """
    Estimate the peak tensor-core throughput (in tera-ops/s) for ``dtype`` on a CUDA device.

    :param dtype: Operand dtype (a torch dtype or a Triton fp8 dtype).
    :param clock_rate: SM clock rate; presumably in kHz given the 1e-9 scale factor
        (kHz * ops/cycle * 1e-9 -> tera-ops/s) — TODO confirm against callers.
    :param device: CUDA device index; defaults to the current device.
    :return: Estimated peak throughput.
    :raises RuntimeError: if ``dtype`` is not supported for the device's compute capability.
    """
    import torch
    from .runtime import driver
    # `is None` rather than truthiness: device 0 is a valid, explicitly requested device.
    if device is None:
        device = torch.cuda.current_device()

    # Counts 4 sub-cores per SM.
    num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4
    capability = torch.cuda.get_device_capability(device)
    if capability[0] < 8:
        # Compute capability < 8: only fp16 is modelled. Raise instead of assert so the
        # check survives `python -O`.
        if dtype != torch.float16:
            raise RuntimeError("dtype not supported")
        ops_per_sub_core = 256  # 2 4x4x4 Tensor Cores
    elif dtype in [torch.float32, torch.int32]:
        ops_per_sub_core = 256
    elif dtype in [torch.float16, torch.bfloat16, torch.int16]:
        ops_per_sub_core = 512
    elif dtype in [torch.int8, tl.float8e4nv, tl.float8e4b15, tl.float8e5]:
        ops_per_sub_core = 1024
    else:
        raise RuntimeError("dtype not supported")
    tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
    return tflops
# create decorator that wraps test function into
# a cuda-memcheck system call
def cuda_memcheck(**target_kwargs):
    """
    Build a decorator that re-runs a pytest test under ``cuda-memcheck``.

    The test is re-launched as ``cuda-memcheck pytest -vs <file>::<name>[<id>]`` when every
    ``target_kwargs`` key/value pair appears in the call's keyword arguments and the parent
    process is not already ``cuda-memcheck``; otherwise the test body runs directly.

    :param target_kwargs: Keyword-argument values selecting which parametrizations are memchecked.
    :return: A decorator for pytest test functions. A memchecked test must take a
        ``request`` fixture (used to recover its parametrization id).
    """

    def decorator(test_fn):

        @functools.wraps(test_fn)
        def wrapper(*args, **kwargs):
            # Subset test: every (name, value) pair of target_kwargs must be present in kwargs.
            run_cuda_memcheck = target_kwargs.items() <= kwargs.items()
            if not run_cuda_memcheck:
                return test_fn(*args, **kwargs)
            # psutil is only needed to inspect the parent process, so import it lazily.
            import psutil
            ppid_name = psutil.Process(os.getppid()).name()
            if ppid_name == "cuda-memcheck":
                # We are the re-launched child running under cuda-memcheck: run the test itself.
                return test_fn(*args, **kwargs)
            # get path of the file that defines the test
            path = os.path.realpath(test_fn.__globals__["__file__"])
            # Minimal environment; PyTorch's CUDA caching allocator is disabled, presumably so
            # cuda-memcheck sees each allocation individually.
            env = {"PATH": os.environ["PATH"], "PYTORCH_NO_CUDA_MEMORY_CACHING": "1"}
            assert 'request' in kwargs, "memcheck'ed test must have a (possibly unused) `request` fixture"
            test_id = kwargs['request'].node.callspec.id
            cmd = f"{path}::{test_fn.__name__}[{test_id}]"
            out = subprocess.run(["cuda-memcheck", "pytest", "-vs", cmd], capture_output=True, env=env)
            assert out.returncode == 0, "cuda-memcheck returned an error: bounds checking failed"
            assert "ERROR SUMMARY: 0 errors" in out.stdout.decode(errors="replace")

        return wrapper

    return decorator
@contextmanager
def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215):
    """
    Context manager that locks GPU 0's SM and memory clocks for reproducible benchmarking.

    Enables persistence mode, locks both clocks to the requested values, and checks that they
    took effect. On exit (including on error) persistence mode is turned off and both clock
    locks are reset. Persistence mode is disabled regardless of its state before entry.

    :param ref_sm_clock: SM clock to lock to, in MHz.
    :param ref_mem_clock: Memory clock to lock to, in MHz.
    :yields: ``(tflops, gbps)`` — peak throughput and DRAM bandwidth at the locked clocks,
        computed from hard-coded device constants (see the comment below).
    :raises AssertionError: if a clock does not settle within 10 MHz of the requested value.
    """
    try:
        subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "1"])
        subprocess.check_output([
            "nvidia-smi",
            "-i",
            "0",
            f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}",
        ])
        subprocess.check_output([
            "nvidia-smi",
            "-i",
            "0",
            f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}",
        ])
        cur_sm_clock = nvsmi(["clocks.current.sm"])[0]
        cur_mem_clock = nvsmi(["clocks.current.memory"])[0]
        assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz"
        assert abs(cur_mem_clock - ref_mem_clock) < 10, f"GPU memory must run at {ref_mem_clock} MHz"
        # NOTE(review): 108 SMs and a 640-byte (5120-bit) memory interface appear to describe an
        # A100 — TODO confirm; other GPUs need different constants.
        tflops = 1e-6 * 2 * 108 * 4 * 256 * ref_sm_clock
        gbps = 640 * 2 * ref_mem_clock * 1e-3
        yield tflops, gbps
    finally:
        subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "0"])
        subprocess.check_output(["nvidia-smi", "-i", "0", "-rgc"])
        subprocess.check_output(["nvidia-smi", "-i", "0", "-rmc"])
def get_max_simd_tflops(dtype, clock_rate, device=None):
    """
    Estimate the peak non-tensor-core (SIMD) throughput in TFLOPS for ``dtype`` on a CUDA device.

    :param dtype: ``torch.float32``, ``torch.float16`` or (compute capability >= 8) ``torch.bfloat16``.
    :param clock_rate: SM clock rate; presumably in kHz given the 1e-9 scale factor — TODO confirm.
    :param device: CUDA device index; defaults to the current device.
    :return: Estimated peak TFLOPS.
    :raises RuntimeError: if ``dtype`` is not supported for the device's compute capability.
    """
    import torch
    from .runtime import driver
    # `is None` rather than truthiness: device 0 is a valid, explicitly requested device.
    if device is None:
        device = torch.cuda.current_device()

    # Counts 4 sub-cores per SM.
    num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4
    # Query the requested device, not whichever device happens to be current.
    capability = torch.cuda.get_device_capability(device)
    if capability[0] < 8:
        if dtype == torch.float32:
            ops_per_sub_core = 32  # 2*16
        elif dtype == torch.float16:
            ops_per_sub_core = 64
        else:
            raise RuntimeError("dtype not supported")
    else:
        if dtype == torch.float32:
            ops_per_sub_core = 32
        elif dtype in [torch.float16, torch.bfloat16]:
            ops_per_sub_core = 64
        else:
            raise RuntimeError("dtype not supported")

    tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9
    return tflops
|