autotuner.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483
  1. from __future__ import annotations
  2. import builtins
  3. import time
  4. import inspect
  5. import hashlib
  6. import json
  7. from functools import cached_property
  8. from typing import Dict, Tuple, List, Optional
  9. from .. import knobs
  10. from .jit import KernelInterface, JITFunction
  11. from .errors import OutOfResources, PTXASError, AutotunerError
  12. from .driver import driver
  13. from .cache import get_cache_manager, triton_key
  14. from triton._C.libtriton import get_cache_invalidating_env_vars
  15. class Autotuner(KernelInterface):
  16. def __init__(self, fn, arg_names, configs, key, reset_to_zero, restore_value, pre_hook=None, post_hook=None,
  17. prune_configs_by: Optional[Dict] = None, warmup=None, rep=None, use_cuda_graph=False, do_bench=None,
  18. cache_results=False):
  19. """
  20. :param prune_configs_by: a dict of functions that are used to prune configs, fields:
  21. 'perf_model': performance model used to predicate running time with different configs, returns running time
  22. 'top_k': number of configs to bench
  23. 'early_config_prune': a function used to prune configs. It should have the signature
  24. `prune_configs_by( configs: List[triton.Config], named_args: Dict[str, Any], **kwargs: Dict[str, Any]) -> List[triton.Config]:`
  25. and return pruned configs. It should return at least one config.
  26. """
  27. if not configs:
  28. self.configs = [Config({}, num_warps=4, num_stages=3, num_ctas=1)]
  29. else:
  30. self.configs = configs
  31. self.keys = key
  32. self.cache: Dict[Tuple, Config] = {}
  33. self.arg_names = arg_names
  34. self.cache_results = (cache_results or knobs.autotuning.cache) and not knobs.runtime.interpret
  35. # Reset to zero or restore values
  36. self.reset_to_zero = []
  37. if reset_to_zero is not None:
  38. self.reset_to_zero = list(reset_to_zero)
  39. self.restore_value = []
  40. if restore_value is not None:
  41. self.restore_value = list(restore_value)
  42. # Hook to reset or restore for required tensors
  43. self.pre_hook = lambda kwargs, reset_only=False: 0
  44. self.post_hook = lambda kwargs, exception: 0
  45. self.user_defined_pre_hook = False
  46. self.user_defined_post_hook = False
  47. if pre_hook:
  48. self.pre_hook = pre_hook
  49. self.user_defined_pre_hook = True
  50. elif (len(self.reset_to_zero) > 0 or len(self.restore_value) > 0):
  51. def _pre_hook(kwargs, reset_only=False):
  52. for name in self.reset_to_zero:
  53. kwargs[name].zero_()
  54. if not reset_only:
  55. self.restore_copies = {name: kwargs[name].clone() for name in self.restore_value}
  56. self.pre_hook = _pre_hook
  57. if post_hook:
  58. self.post_hook = post_hook
  59. self.user_defined_post_hook = True
  60. elif len(self.restore_value) > 0:
  61. def _post_hook(kwargs, exception):
  62. for name in self.restore_value:
  63. kwargs[name].copy_(self.restore_copies[name])
  64. self.restore_copies = {}
  65. self.post_hook = _post_hook
  66. self.perf_model = None
  67. self.configs_top_k = 1.0
  68. self.early_config_prune = None
  69. if prune_configs_by:
  70. self.perf_model = prune_configs_by.get("perf_model", self.perf_model)
  71. self.configs_top_k = prune_configs_by.get("top_k", self.configs_top_k)
  72. self.early_config_prune = prune_configs_by.get("early_config_prune", self.early_config_prune)
  73. self.fn = fn
  74. self.base_fn = fn
  75. while not inspect.isfunction(self.base_fn):
  76. self.base_fn = self.base_fn.fn
  77. self._do_bench = do_bench
  78. self.num_warmups = warmup
  79. self.num_reps = rep
  80. self.use_cuda_graph = use_cuda_graph
  81. # If we got explicitly called via the old interface, raise a warning
  82. # and proceed with the old behavior.
  83. if warmup is not None or rep is not None or use_cuda_graph:
  84. import warnings
  85. warnings.warn(("warmup, rep, and use_cuda_graph parameters are deprecated. See "
  86. "https://github.com/triton-lang/triton/pull/4496 for details."), DeprecationWarning,
  87. stacklevel=1)
  88. if use_cuda_graph:
  89. from ..testing import do_bench_cudagraph
  90. self._do_bench = lambda kernel_call, quantiles: do_bench_cudagraph(
  91. kernel_call,
  92. rep=rep if rep is not None else 100,
  93. quantiles=quantiles,
  94. )
  95. return
  96. import triton.testing
  97. self._do_bench = lambda kernel_call, quantiles: triton.testing.do_bench(
  98. kernel_call,
  99. warmup=warmup if warmup is not None else 25,
  100. rep=rep if rep is not None else 100,
  101. quantiles=quantiles,
  102. )
  103. return
  104. @cached_property
  105. def do_bench(self):
  106. if self._do_bench is None:
  107. return driver.active.get_benchmarker()
  108. return self._do_bench
  109. def _bench(self, *args, config, **meta):
  110. from ..compiler.errors import CompileTimeAssertionFailure
  111. verbose = knobs.autotuning.print
  112. if verbose:
  113. print(f"Autotuning kernel {self.base_fn.__name__} with config {config}")
  114. # check for conflicts, i.e. meta-parameters both provided
  115. # as kwargs and by the autotuner
  116. conflicts = meta.keys() & config.kwargs.keys()
  117. if conflicts:
  118. raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}."
  119. " Make sure that you don't re-define auto-tuned symbols.")
  120. # augment meta-parameters with tunable ones
  121. current = dict(meta, **config.all_kwargs())
  122. full_nargs = {**self.nargs, **current}
  123. def kernel_call():
  124. if config.pre_hook:
  125. config.pre_hook(full_nargs)
  126. self.pre_hook(full_nargs)
  127. try:
  128. self.fn.run(
  129. *args,
  130. **current,
  131. )
  132. except Exception as e:
  133. try:
  134. self.post_hook(full_nargs, exception=e)
  135. finally:
  136. # Throw exception raised by `self.fn.run`
  137. raise
  138. self.post_hook(full_nargs, exception=None)
  139. try:
  140. return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
  141. except (OutOfResources, CompileTimeAssertionFailure, PTXASError) as e:
  142. if verbose:
  143. print(f"Autotuning failed with {e}")
  144. return [float("inf"), float("inf"), float("inf")]
  145. def check_disk_cache(self, tuning_key, configs, bench_fn):
  146. # We can't serialize prehooks, so just give up and run the benchmarks.
  147. if not tuning_key or any(cfg.pre_hook for cfg in configs):
  148. bench_fn()
  149. return False
  150. from triton.compiler.compiler import make_backend
  151. fn = self.fn
  152. while not isinstance(fn, JITFunction):
  153. fn = fn.fn
  154. env_vars = get_cache_invalidating_env_vars()
  155. cache_key = [
  156. triton_key(),
  157. make_backend(driver.active.get_current_target()).hash(),
  158. fn.cache_key,
  159. str(sorted(env_vars.items())),
  160. str(tuning_key),
  161. ] + [str(c) for c in configs]
  162. cache_key = hashlib.sha256("-".join(cache_key).encode("utf-8")).hexdigest()
  163. cache = get_cache_manager(cache_key)
  164. file_name = f"{fn.__name__[:150]}.autotune.json"
  165. path = cache.get_file(file_name)
  166. if path:
  167. with open(path, "r") as cached_configs:
  168. timings = json.load(cached_configs)["configs_timings"]
  169. timings = {Config(**config): timing for config, timing in timings}
  170. self.cache[tuning_key] = builtins.min(timings, key=timings.get)
  171. self.configs_timings = timings
  172. return True
  173. bench_fn()
  174. cache.put(
  175. json.dumps({
  176. "key":
  177. tuning_key,
  178. "configs_timings":
  179. [(config.__dict__, timings) for config, timings in self.configs_timings.items() if not config.pre_hook],
  180. }), file_name, binary=False)
  181. return False
  182. def run(self, *args, **kwargs):
  183. self.nargs = dict(zip(self.arg_names, args))
  184. used_cached_result = True
  185. if len(self.configs) > 1:
  186. all_args = {**self.nargs, **kwargs}
  187. _args = {k: v for (k, v) in all_args.items() if k in self.arg_names}
  188. key = [_args[key] for key in self.keys if key in _args]
  189. for _, arg in _args.items():
  190. if hasattr(arg, "dtype"):
  191. key.append(str(arg.dtype))
  192. key = tuple(key)
  193. if key not in self.cache:
  194. used_cached_result = False
  195. pruned_configs = self.prune_configs(kwargs)
  196. def benchmark():
  197. bench_start = time.perf_counter()
  198. timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  199. bench_end = time.perf_counter()
  200. self.bench_time = bench_end - bench_start
  201. self.cache[key] = builtins.min(timings, key=timings.get)
  202. full_nargs = {**self.nargs, **kwargs, **self.cache[key].all_kwargs()}
  203. self.pre_hook(full_nargs, reset_only=True)
  204. self.configs_timings = timings
  205. if self.cache_results:
  206. used_cached_result = self.check_disk_cache(key, pruned_configs, benchmark)
  207. else:
  208. benchmark()
  209. config = self.cache[key]
  210. else:
  211. config = self.configs[0]
  212. self.best_config = config
  213. if knobs.autotuning.print and not used_cached_result:
  214. print(f"Triton autotuning for function {self.base_fn.__name__},\nwith key as {key},\n"
  215. f"finished after {self.bench_time:.2f}s,\nbest config selected: {self.best_config};")
  216. if config.pre_hook is not None:
  217. full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()}
  218. config.pre_hook(full_nargs)
  219. ret = self.fn.run(
  220. *args,
  221. **kwargs,
  222. **config.all_kwargs(),
  223. )
  224. self.nargs = None
  225. return ret
  226. def prune_configs(self, kwargs: Dict) -> List[Config]:
  227. pruned_configs = self.configs
  228. if self.early_config_prune:
  229. pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs)
  230. if not pruned_configs:
  231. raise AutotunerError(
  232. "No valid autotuner configs after pruning. `early_config_prune` should return at least one config.")
  233. if self.perf_model:
  234. top_k = self.configs_top_k
  235. if isinstance(top_k, float) and top_k <= 1.0:
  236. top_k = int(len(self.configs) * top_k)
  237. elif not isinstance(top_k, int):
  238. # Slice index must be an integer
  239. raise TypeError("Error while pruning configs, top_k must be either 1) a float <= 1.0 or 2) an int")
  240. if len(pruned_configs) > top_k:
  241. est_timing = {
  242. config: self.perf_model(
  243. **self.nargs,
  244. **kwargs,
  245. **config.all_kwargs(),
  246. )
  247. for config in pruned_configs
  248. }
  249. pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
  250. return pruned_configs
  251. def warmup(self, *args, **kwargs):
  252. self.nargs = dict(zip(self.arg_names, args))
  253. ret = []
  254. for autotune_config in self.prune_configs(kwargs):
  255. ret.append(self.fn.warmup(
  256. *args,
  257. **kwargs,
  258. **autotune_config.all_kwargs(),
  259. ))
  260. self.nargs = None
  261. return ret
  262. class Config:
  263. """
  264. An object that represents a possible kernel configuration for the auto-tuner to try.
  265. :ivar kwargs: a dictionary of meta-parameters to pass to the kernel as keyword arguments.
  266. :type kwargs: dict[Str, Any]
  267. :ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if
  268. `num_warps=8`, then each kernel instance will be automatically parallelized to
  269. cooperatively execute using `8 * 32 = 256` threads.
  270. :type num_warps: int
  271. :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops.
  272. Mostly useful for matrix multiplication workloads on SM80+ GPUs.
  273. :type num_stages: int
  274. :ivar num_ctas: number of blocks in a block cluster. SM90+ only.
  275. :type num_ctas: int
  276. :type maxnreg: Optional[int]
  277. :ivar maxnreg: maximum number of registers one thread can use. Corresponds
  278. to ptx .maxnreg directive. Not supported on all platforms.
  279. :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this
  280. function are args.
  281. :ivar ir_override: filename of a user-defined IR (*.{ttgir|llir|ptx|amdgcn}).
  282. """
  283. def __init__(self, kwargs, num_warps=4, num_stages=3, num_ctas=1, maxnreg=None, pre_hook=None, ir_override=None):
  284. self.kwargs = kwargs
  285. self.num_warps = num_warps
  286. self.num_ctas = num_ctas
  287. self.num_stages = num_stages
  288. self.maxnreg = maxnreg
  289. self.pre_hook = pre_hook
  290. self.ir_override = ir_override
  291. def __setstate__(self, state):
  292. self.kwargs = state.get("kwargs", {})
  293. self.num_warps = state.get("num_warps", 4)
  294. self.num_stages = state.get("num_stages", 3)
  295. self.num_ctas = state.get("num_ctas", 1)
  296. self.maxnreg = state.get("maxnreg", None)
  297. self.pre_hook = state.get("pre_hook", None)
  298. self.ir_override = state.get("ir_override", None)
  299. def all_kwargs(self):
  300. return {
  301. **self.kwargs, **{
  302. k: v
  303. for (k, v) in (
  304. ("num_warps", self.num_warps),
  305. ("num_ctas", self.num_ctas),
  306. ("num_stages", self.num_stages),
  307. ("maxnreg", self.maxnreg),
  308. ("ir_override", self.ir_override),
  309. ) if v is not None
  310. }
  311. }
  312. def __str__(self):
  313. res = []
  314. for k, v in self.kwargs.items():
  315. res.append(f"{k}: {v}")
  316. res.append(f"num_warps: {self.num_warps}")
  317. res.append(f"num_ctas: {self.num_ctas}")
  318. res.append(f"num_stages: {self.num_stages}")
  319. res.append(f"maxnreg: {self.maxnreg}")
  320. return ", ".join(res)
  321. def __hash__(self):
  322. return hash((*self.all_kwargs().items(), self.pre_hook))
  323. def __eq__(self, other):
  324. self_tuple = tuple((
  325. *self.all_kwargs().items(),
  326. self.pre_hook,
  327. ))
  328. other_tuple = tuple((
  329. *other.all_kwargs().items(),
  330. other.pre_hook,
  331. ))
  332. return self_tuple == other_tuple
  333. def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, pre_hook=None, post_hook=None,
  334. warmup=None, rep=None, use_cuda_graph=False, do_bench=None, cache_results=False):
  335. """
  336. Decorator for auto-tuning a :code:`triton.jit`'d function.
  337. .. highlight:: python
  338. .. code-block:: python
  339. @triton.autotune(configs=[
  340. triton.Config(kwargs={'BLOCK_SIZE': 128}, num_warps=4),
  341. triton.Config(kwargs={'BLOCK_SIZE': 1024}, num_warps=8),
  342. ],
  343. key=['x_size'] # the two above configs will be evaluated anytime
  344. # the value of x_size changes
  345. )
  346. @triton.jit
  347. def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr):
  348. ...
  349. :note: When all the configurations are evaluated, the kernel will run multiple times.
  350. This means that whatever value the kernel updates will be updated multiple times.
  351. To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
  352. resets the value of the provided tensor to `zero` before running any configuration.
  353. If the environment variable :code:`TRITON_PRINT_AUTOTUNING` is set to
  354. :code:`"1"`, Triton will print a message to stdout after autotuning each
  355. kernel, including the time spent autotuning and the best configuration.
  356. :param configs: a list of :code:`triton.Config` objects
  357. :type configs: list[triton.Config]
  358. :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
  359. :type key: list[str]
  360. :param prune_configs_by: a dict of functions that are used to prune configs, fields:
  361. 'perf_model': performance model used to predicate running time with different configs, returns running time
  362. 'top_k': number of configs to bench
  363. 'early_config_prune': a function used to prune configs. It should have the signature
  364. `prune_configs_by( configs: List[triton.Config], named_args: Dict[str, Any], **kwargs: Dict[str, Any]) -> List[triton.Config]:`
  365. and return pruned configs. It should return at least one config.
  366. :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
  367. :type reset_to_zero: list[str]
  368. :param restore_value: a list of argument names whose value will be restored after evaluating any configs.
  369. :type restore_value: list[str]
  370. :param pre_hook: a function that will be called before the kernel is called.
  371. This overrides the default pre_hook used for 'reset_to_zero' and 'restore_value'.
  372. 'kwargs': a dict of all arguments passed to the kernel.
  373. 'reset_only': a boolean indicating whether the pre_hook is called to reset the values only, without a corresponding post_hook.
  374. :type pre_hook: lambda args, reset_only
  375. :param post_hook: a function that will be called after the kernel is called.
  376. This overrides the default post_hook used for 'restore_value'.
  377. 'kwargs': a dict of all arguments passed to the kernel.
  378. 'exception': the exception raised by the kernel in case of a compilation or runtime error.
  379. :type post_hook: lambda args, exception
  380. :param warmup: warmup time (in ms) to pass to benchmarking (deprecated).
  381. :type warmup: int
  382. :param rep: repetition time (in ms) to pass to benchmarking (deprecated).
  383. :type rep: int
  384. :param do_bench: a benchmark function to measure the time of each run.
  385. :type do_bench: lambda fn, quantiles
  386. :param cache_results: whether to cache autotune timings to disk. Defaults to False.
  387. "type cache_results: bool
  388. """
  389. def decorator(fn):
  390. return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook,
  391. post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep,
  392. use_cuda_graph=use_cuda_graph, do_bench=do_bench, cache_results=cache_results)
  393. return decorator
  394. class Heuristics(KernelInterface):
  395. def __init__(self, fn, arg_names, values) -> None:
  396. self.fn = fn
  397. self.values = values
  398. self.arg_names = arg_names
  399. def run(self, *args, **kwargs):
  400. for v, heur in self.values.items():
  401. kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
  402. return self.fn.run(*args, **kwargs)
  403. def heuristics(values):
  404. """
  405. Decorator for specifying how the values of certain meta-parameters may be computed.
  406. This is useful for cases where auto-tuning is prohibitively expensive, or just not applicable.
  407. .. highlight:: python
  408. .. code-block:: python
  409. # smallest power-of-two >= x_size
  410. @triton.heuristics(values={'BLOCK_SIZE': lambda args: triton.next_power_of_2(args['x_size'])})
  411. @triton.jit
  412. def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr):
  413. ...
  414. :param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter.
  415. each such function takes a list of positional arguments as input.
  416. :type values: dict[str, Callable[[dict[str, Any]], Any]]
  417. """
  418. def decorator(fn):
  419. return Heuristics(fn, fn.arg_names, values)
  420. return decorator