jit.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099
  1. from __future__ import annotations, division
  2. import ast
  3. import copy
  4. import hashlib
  5. import inspect
  6. import itertools
  7. import threading
  8. import re
  9. import textwrap
  10. from collections import defaultdict
  11. from dataclasses import dataclass
  12. from functools import cached_property
  13. from typing import Callable, Generic, Iterable, Optional, TypeVar, overload, Dict, Any, Tuple
  14. from triton.backends import BaseBackend
  15. from types import ModuleType
  16. from .. import knobs
  17. from .driver import driver
  18. from . import _async_compile
  19. from .._utils import find_paths_if, get_iterable_path, type_canonicalisation_dict, is_namedtuple
  20. from .cache import get_cache_key
  21. from triton._C.libtriton import get_cache_invalidating_env_vars, native_specialize_impl
# Dotted names of the Triton and Gluon language modules; DependenciesFinder
# lists both in its `supported_modules` set.
TRITON_MODULE = "triton.language"
GLUON_MODULE = "triton.experimental.gluon.language"
# Generic type parameter used by KernelInterface (Generic[T]).
T = TypeVar("T")
  25. # -----------------------------------------------------------------------------
  26. # Dependencies Finder
  27. # -----------------------------------------------------------------------------
  28. class DependenciesFinder(ast.NodeVisitor):
  29. """
  30. This AST visitor is used to find dependencies of a JITFunction. This can
  31. be used to invalidate a JITFunction's hash when its source code -- or
  32. that of its dependencies -- changes.
  33. This visitor also keeps track of the global variables touched by the
  34. JITFunction. When we launch the kernel, we check that these have the same
  35. values as they did when we ran this visitor. If not, we raise an error (or
  36. otherwise we could recompile).
  37. """
  38. def __init__(self, name, globals, nonlocals, src) -> None:
  39. super().__init__()
  40. self.name = name
  41. self.hasher = hashlib.sha256(src.encode("utf-8"))
  42. # This function's __globals__ dict.
  43. self.globals = globals
  44. self.nonlocals = nonlocals
  45. # Python builtins that can be accessed from Triton kernels.
  46. self.supported_python_builtins = {
  47. 'float',
  48. 'getattr',
  49. 'int',
  50. 'isinstance',
  51. 'len',
  52. 'list',
  53. 'max',
  54. 'min',
  55. 'print',
  56. 'range',
  57. }
  58. self.supported_modules = {
  59. GLUON_MODULE,
  60. TRITON_MODULE,
  61. "copy",
  62. "math",
  63. }
  64. # used_global_vals tells us which global variables are used by this
  65. # function and all those it transitively calls, plus the values of those
  66. # variables when each function was initially run. (That is, if A calls
  67. # C, and B calls C, then the values for C in used_global_vals will be
  68. # from the first time C was run, either by A or B.)
  69. #
  70. # Each function may have a different __globals__ dict, so the global
  71. # variable `foo` may actually have a different value in the different
  72. # functions. Thus this map is actually
  73. # (var_name, id(__globals__)) -> (var_value, __globals__).
  74. self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {}
  75. self.visiting_arg_default_value = False
  76. @property
  77. def ret(self):
  78. return self.hasher.hexdigest()
  79. def _is_triton_builtin(self, node, func):
  80. if inspect.isbuiltin(node.func):
  81. return True
  82. module = getattr(func, "__module__", "")
  83. return module.startswith(TRITON_MODULE)
  84. def _update_hash(self, func):
  85. assert isinstance(func, JITCallable)
  86. # Merge our used_global_vals with those of the called function,
  87. # after checking that all overlapping values are consistent.
  88. for k in self.used_global_vals.keys() & func.used_global_vals.keys():
  89. var_name, _ = k
  90. v1, _ = self.used_global_vals[k]
  91. v2, _ = func.used_global_vals[k]
  92. if v1 != v2:
  93. raise RuntimeError(
  94. f"Global variable {var_name} has value {v1} when compiling {self.name}, but inner kernel {func.__name__} has conflicting value {v2} from when it was first compiled. This is not allowed."
  95. )
  96. self.used_global_vals.update(func.used_global_vals)
  97. # update hash
  98. func_key = func.cache_key
  99. func_key += str(getattr(func, "noinline", False))
  100. self.hasher.update(func_key.encode("utf-8"))
  101. def record_reference(self, val, var_dict=None, name=None):
  102. from ..language.core import constexpr
  103. # Only keep track of "interesting" global variables, that non-evil users
  104. # might change. Don't consider functions, modules, builtins, etc. This
  105. # helps keep the list of vars we have to check small.
  106. if val is None or type(val) is ModuleType:
  107. return
  108. if getattr(val, "__triton_aggregate__", False):
  109. for attr in val.hash_attrs:
  110. self.record_reference(attr)
  111. return
  112. if getattr(val, "__triton_builtin__", False):
  113. return
  114. # Stubs that aren't real functions
  115. if getattr(val, "__module__", "") == "triton.language.extra.libdevice":
  116. return
  117. if isinstance(val, JITCallable):
  118. self._update_hash(val)
  119. return
  120. if callable(val) and not isinstance(val, type) and not isinstance(val, constexpr):
  121. raise RuntimeError(f"Unsupported function referenced: {val}")
  122. # Python default arguments are resolved only once, when the
  123. # function is defined. So if you do `foo(a=A)` and the value of
  124. # A changes, foo will still use the old value of A.
  125. # It would be pretty evil if someone did `import x` and then
  126. # `x = blah`.
  127. if self.visiting_arg_default_value:
  128. return
  129. if var_dict is not None:
  130. self.used_global_vals[(name, id(var_dict))] = (copy.deepcopy(val), var_dict)
  131. return
  132. def visit_Name(self, node):
  133. if type(node.ctx) is ast.Store:
  134. return node.id
  135. if node.id in self.local_names:
  136. # The global name is hidden by the local name.
  137. return None
  138. def name_lookup(name):
  139. val = self.globals.get(name, None)
  140. if val is not None:
  141. return val, self.globals
  142. val = self.nonlocals.get(name, None)
  143. if val is not None:
  144. return val, self.nonlocals
  145. return None, None
  146. val, var_dict = name_lookup(node.id)
  147. if node.id in self.supported_python_builtins:
  148. return val
  149. self.record_reference(val, var_dict, node.id)
  150. return val
  151. def visit_Tuple(self, node):
  152. # We need to explicitly return the tuple values so that visit_Assign can
  153. # access them in the case of `a, b = ...`.
  154. return [self.visit(elt) for elt in node.elts]
  155. def visit_Attribute(self, node):
  156. lhs = self.visit(node.value)
  157. while isinstance(lhs, ast.Attribute):
  158. lhs = self.visit(lhs.value)
  159. lhs_name = getattr(lhs, "__name__", "")
  160. if lhs is None or lhs_name in self.supported_modules:
  161. return None
  162. ret = getattr(lhs, node.attr)
  163. self.record_reference(ret)
  164. return ret
  165. def visit_FunctionDef(self, node):
  166. # Save the local name, which may hide the global name.
  167. self.local_names = {arg.arg for arg in node.args.args}
  168. self.generic_visit(node)
  169. def visit_arguments(self, node):
  170. # The purpose of this function is to visit everything in `arguments`
  171. # just like `generic_visit`, except when we're visiting default values
  172. # (i.e. the `foo` part of `def fn(x = foo)`), we set
  173. # self.visiting_arg_default_value = True. This allows visit_Name to be
  174. # aware that we're inside function default values, which have special
  175. # semantics.
  176. # According to the AST docs, the arguments node has the following structure.
  177. #
  178. # arguments = (arg* posonlyargs, arg* args, arg? vararg, arg* kwonlyargs,
  179. # expr* kw_defaults, arg? kwarg, expr* defaults)
  180. def visit_defaults(defaults):
  181. try:
  182. assert not self.visiting_arg_default_value
  183. self.visiting_arg_default_value = True
  184. for expr in defaults:
  185. if expr is not None:
  186. self.visit(expr)
  187. finally:
  188. self.visiting_arg_default_value = False
  189. for arg in itertools.chain(node.posonlyargs, node.args, [node.vararg] if node.vararg else [], node.kwonlyargs):
  190. self.visit(arg)
  191. visit_defaults(node.kw_defaults)
  192. if node.kwarg is not None:
  193. self.visit(node.kwarg)
  194. visit_defaults(node.defaults)
  195. def visitAssnTarget(self, node):
  196. # Target is either a single string, or a list of strings (if the assn
  197. # target is a tuple).
  198. target = self.visit(node)
  199. if isinstance(target, list):
  200. self.local_names |= set(target)
  201. else:
  202. self.local_names.add(target)
  203. def visit_Assign(self, node):
  204. if len(node.targets) != 1:
  205. # TODO(jlebar): I don't actually know how to hit this. You don't
  206. # get it from `a, b = ...` -- in that case, node.targets is a single
  207. # Tuple, and in fact we *do* need to handle that case if we want
  208. # existing code to work.
  209. raise TypeError("Simultaneous multiple assignment is not supported.")
  210. self.visitAssnTarget(node.targets[0])
  211. # This will re-visit the target, but that's OK.
  212. self.generic_visit(node)
  213. def visit_AnnAssign(self, node):
  214. self.visitAssnTarget(node.target)
  215. # This will re-visit the target, but that's OK.
  216. self.generic_visit(node)
  217. def visit_For(self, node):
  218. self.visitAssnTarget(node.target)
  219. # This will re-visit the target, but that's fine.
  220. self.generic_visit(node)
  221. # -----------------------------------------------------------------------------
  222. # JITFunction
  223. # -----------------------------------------------------------------------------
  224. def _normalize_ty(ty) -> str:
  225. import triton.language.core as core
  226. if isinstance(ty, str):
  227. ty = ty.strip()
  228. if ty.startswith("const "):
  229. ty = ty.removeprefix("const")
  230. ty = _normalize_ty(ty)
  231. assert ty.startswith("*")
  232. return "*k" + ty[1:]
  233. if ty.endswith("*"):
  234. return "*" + _normalize_ty(ty[:-1])
  235. if ty.startswith("*"):
  236. return "*" + _normalize_ty(ty[1:])
  237. if ty.startswith("tl."):
  238. return _normalize_ty(ty.removeprefix("tl."))
  239. elif isinstance(ty, core.pointer_type):
  240. return f"*{_normalize_ty(ty.element_ty)}"
  241. elif isinstance(ty, core.dtype):
  242. ty = ty.name
  243. elif isinstance(ty, type):
  244. ty = ty.__name__
  245. else:
  246. ty = str(ty)
  247. return type_canonicalisation_dict.get(ty.replace("_t", ""), ty)
  248. class KernelParam:
  249. """Represents a parameter (name plus metadata) to a @jit'ed function."""
  250. def __init__(self, num: int, param: inspect.Parameter, do_not_specialize: bool,
  251. do_not_specialize_on_alignment: bool):
  252. self.num = num
  253. self._param = param
  254. self.do_not_specialize = do_not_specialize
  255. self.do_not_specialize_on_alignment = do_not_specialize_on_alignment
  256. @cached_property
  257. def name(self):
  258. return self._param.name
  259. @cached_property
  260. def annotation(self) -> str:
  261. if not self._param.annotation or self._param.annotation == inspect.Parameter.empty:
  262. return ""
  263. return _normalize_ty(self._param.annotation)
  264. @cached_property
  265. def annotation_type(self) -> str:
  266. a = self.annotation
  267. if a.startswith("*k"):
  268. a = a[2:]
  269. elif a.startswith("*"):
  270. a = a[1:]
  271. if a in set(type_canonicalisation_dict.values()):
  272. return self.annotation
  273. return ""
  274. @cached_property
  275. def is_constexpr(self):
  276. return "constexpr" in self.annotation
  277. @cached_property
  278. def is_const(self):
  279. if self.is_constexpr:
  280. return False
  281. return "const" in self.annotation or self.annotation.startswith("*k")
  282. @property
  283. def default(self):
  284. return self._param.default
  285. @property
  286. def has_default(self):
  287. return self._param.default != inspect.Parameter.empty
  288. def mangle_type(arg, specialize=False):
  289. is_const = False
  290. align = True
  291. return native_specialize_impl(BaseBackend, arg, is_const, specialize, align)[0]
  292. class KernelInterface(Generic[T]):
  293. run: T
  294. def warmup(self, *args, grid, **kwargs):
  295. return self.run(grid=grid, warmup=True, *map(MockTensor.wrap_dtype, args), **kwargs)
  296. def run(self, *args, grid, warmup, **kwargs):
  297. raise NotImplementedError("run not implemented")
  298. def __getitem__(self, grid) -> T:
  299. """
  300. A JIT function is launched with: fn[grid](*args, **kwargs).
  301. Hence JITFunction.__getitem__ returns a callable proxy that
  302. memorizes the grid.
  303. """
  304. return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  305. # return cast(T, functools.partial(cast(Callable, self.run), grid=grid))
  306. def serialize_specialization_data(name, signature, constants, attrs, options, key):
  307. constants = {
  308. key: str(value) if value.__class__.__name__ == "dtype" else
  309. {"constexpr": value.value} if value.__class__.__name__ == "constexpr" else value
  310. for key, value in constants.items()
  311. }
  312. import json
  313. obj = {
  314. 'name': name, 'signature': signature, 'constant_keys': [list(x) for x in constants.keys()], 'constant_vals':
  315. list(constants.values()), 'attrs_keys': [list(x) for x in attrs.keys()], 'attrs_vals': list(attrs.values()),
  316. 'options': options.__dict__, 'key': key
  317. }
  318. serialized_obj = json.dumps(obj)
  319. return serialized_obj
def create_function_from_signature(sig, kparams, backend):
    """
    Equivalent to sig.bind followed by apply_defaults. This generates a
    native Python function (using exec) which can be memoized on a per-kernel
    basis to avoid having to run these expensive functions -- which constitute
    much of the kernel launch overhead -- every time we run the kernel.

    Args:
        sig: the kernel's ``inspect.Signature``.
        kparams: one KernelParam-like object per parameter of `sig`, in order.
        backend: object passed through to ``specialize_impl`` in the generated code.

    Returns:
        The generated ``dynamic_func(<params>, **options)``, which returns
        ``(params_dict, specialization_list, options)``.
    """
    assert len(sig.parameters) == len(kparams)
    # Create the function argument list and the dict entries for the return statement
    specialization = []
    # signature
    for name, kp in zip(sig.parameters.keys(), kparams):
        if kp.is_constexpr:
            specialization.append(f'("constexpr", {name})')
        else:
            # These three are *source text* ('True'/'False') spliced into the
            # generated call below, not booleans.
            is_const = 'True' if kp.is_const else 'False'
            specialize = 'False' if kp.do_not_specialize else 'True'
            align = 'False' if kp.do_not_specialize_on_alignment else 'True'
            ret = f"specialize_impl(backend, {name}, {is_const}, {specialize}, {align})"
            if kp.annotation_type:
                if isinstance(kp.annotation_type, str):
                    if kp.annotation_type == "u1" or kp.annotation_type[:2] in ["fp", "bf"]:
                        # we do not specialize non-constexpr floats and bools:
                        specialize = False
                # NOTE: `specialize` is a non-empty string (truthy, even when
                # it is 'False') unless it was rebound to the bool False just
                # above, so only that case takes the `else` branch.
                if specialize:
                    specialization.append(f'("{kp.annotation_type}",) + {ret}[1:]')
                else:
                    # skip runtime specialization:
                    specialization.append(f'("{kp.annotation_type}", None)')
            else:
                specialization.append(f"{ret}")

    # compute argument string for a given parameter
    arg = lambda x: x[0] if x[1].default is inspect.Parameter.empty else f"{x[0]}=default_{x[0]}"
    func_body = f"""
def dynamic_func({", ".join(list(map(arg, sig.parameters.items())) + ["**options"])}):
    params = {{{', '.join([f"'{name}': {name}" for name in sig.parameters.keys()])}}}
    specialization = [{','.join(specialization)}]
    return params, specialization, options
"""
    # Prepare defaults to be inserted into function namespace
    func_namespace = {
        f"default_{name}": param.default
        for name, param in sig.parameters.items()
        if param.default is not inspect.Parameter.empty
    }

    # Names the generated source refers to.
    specialize_impl = native_specialize_impl
    func_namespace["specialize_impl"] = specialize_impl
    func_namespace["backend"] = backend
    func_namespace["JITCallable"] = JITCallable

    # Execute the function string in func_namespace to create the function
    exec(func_body, func_namespace)

    # Extract the newly created function from the namespace
    return func_namespace['dynamic_func']
  373. def get_full_name(fn):
  374. return f"{fn.__module__}.{fn.__qualname__}"
  375. class JITCallable:
  376. def __init__(self, fn):
  377. self.fn = fn
  378. self.signature = inspect.signature(fn)
  379. try:
  380. self.raw_src, self.starting_line_number = inspect.getsourcelines(fn)
  381. except OSError as e:
  382. raise ValueError("@jit functions should be defined in a Python file") from e
  383. self._fn_name = get_full_name(fn)
  384. self._hash_lock = threading.RLock()
  385. # function source code (without decorators)
  386. src = textwrap.dedent("".join(self.raw_src))
  387. src = src[re.search(r"^def\s+\w+\s*\(", src, re.MULTILINE).start():]
  388. self._src = src
  389. self.hash = None
  390. # Map of global variables used by the function and any functions it
  391. # transitively calls, plus their values. The values are collected when
  392. # the function is first compiled. Then every time we run the function,
  393. # we check that the values of the globals match what's expected,
  394. # otherwise we raise an error.
  395. #
  396. # Different functions can have different __globals__ maps, so the map
  397. # key is actually (var name, id(__globals__)), and the map value is
  398. # (value, __globals__).
  399. self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {}
  400. # reuse docs of wrapped function
  401. self.__doc__ = fn.__doc__
  402. self.__name__ = fn.__name__
  403. self.__qualname__ = fn.__qualname__
  404. self.__globals__ = fn.__globals__
  405. self.__module__ = fn.__module__
  406. def get_capture_scope(self):
  407. return self.__globals__ | inspect.getclosurevars(self.fn).nonlocals
  408. @property
  409. def cache_key(self) -> str:
  410. # TODO : hash should be attribute of `self`
  411. with self._hash_lock:
  412. if self.hash is not None:
  413. return self.hash
  414. # Set a placeholder hash to break recursion in case the function
  415. # transitively calls itself. The full hash is set after.
  416. self.hash = f"recursion:{self._fn_name}"
  417. nonlocals = inspect.getclosurevars(self.fn).nonlocals
  418. dependencies_finder = DependenciesFinder(name=self._fn_name, globals=self.__globals__, nonlocals=nonlocals,
  419. src=self.src)
  420. dependencies_finder.visit(self.parse())
  421. self.hash = dependencies_finder.ret + str(self.starting_line_number)
  422. self.used_global_vals = dict(sorted(dependencies_finder.used_global_vals.items()))
  423. from triton.language.core import constexpr
  424. self.hash += str([(name, val)
  425. for (name, _), (val, _) in self.used_global_vals.items()
  426. if isinstance(val, constexpr)])
  427. self.hash = hashlib.sha256(self.hash.encode("utf-8")).hexdigest()
  428. return self.hash
  429. def __hash__(self):
  430. return hash(self.cache_key)
  431. # we do not parse `src` in the constructor because
  432. # the user might want to monkey-patch self.src dynamically.
  433. # Our unit tests do this, for example.
  434. def parse(self):
  435. tree = ast.parse(self._src)
  436. assert isinstance(tree, ast.Module)
  437. assert len(tree.body) == 1
  438. assert isinstance(tree.body[0], ast.FunctionDef)
  439. return tree
  440. @property
  441. def type(self):
  442. from triton.language.core import constexpr_type
  443. return constexpr_type(self)
  444. def _unsafe_update_src(self, new_src):
  445. """
  446. The only method allowed to modify src.
  447. Bypasses the __setattr__ restriction by calling super().__setattr__ directly.
  448. Note that it is the callers responsibility to make sure any triton functions that call this function have the `.hash` value reset to None.
  449. """
  450. self.hash = None
  451. self._src = new_src
  452. def _set_src(self):
  453. raise AttributeError("Cannot set attribute 'src' directly. "
  454. "Use '_unsafe_update_src()' and manually clear `.hash` of all callers"
  455. "instead.")
  456. def _get_src(self):
  457. return self._src
  458. src = property(fget=_get_src, fset=_set_src)
@dataclass
class JitFunctionInfo:
    """Record passed as the ``fn`` argument to hooks by JITFunction._call_hook."""
    # _call_hook passes the wrapped function's `__module__` here.
    module: ModuleType
    # _call_hook passes the wrapped function's `__qualname__` here.
    name: str
    # The JITFunction the hook is being invoked for.
    jit_function: JITFunction
  464. def compute_cache_key(kernel_key_cache, specialization, options):
  465. key = (tuple(specialization), str(options))
  466. cache_key = kernel_key_cache.get(key, None)
  467. if cache_key is not None:
  468. return cache_key
  469. # Replace JITCallable objects with their hash, so the cache key will change if the src is updated
  470. def replace_callables(obj):
  471. if isinstance(obj, list):
  472. return [replace_callables(arg) for arg in obj]
  473. elif is_namedtuple(obj):
  474. results = [replace_callables(arg) for arg in obj]
  475. return obj.__class__(*results)
  476. elif isinstance(obj, tuple):
  477. return tuple(replace_callables(arg) for arg in obj)
  478. elif isinstance(obj, JITCallable):
  479. return obj.cache_key
  480. return obj
  481. cache_key = str(replace_callables(specialization)) + str(options)
  482. kernel_key_cache[key] = cache_key
  483. return cache_key
  484. def convert_to_tuple_if_list(item):
  485. # If the incoming item is a list, recursively iterate through it to convert all lists therein into tuples
  486. if not isinstance(item, list):
  487. return item
  488. # The value must be a list at this point
  489. for i, nested_value in enumerate(item):
  490. item[i] = convert_to_tuple_if_list(nested_value)
  491. return tuple(item)
  492. class JITFunction(JITCallable, KernelInterface[T]):
def is_gluon(self):
    """Always returns False for this class."""
    return False
  495. def _call_hook(
  496. self,
  497. hook,
  498. key,
  499. signature,
  500. device,
  501. constants,
  502. options,
  503. configs,
  504. is_warmup,
  505. ) -> bool | None:
  506. if not hook:
  507. return None
  508. name = self.fn.__qualname__
  509. module = self.fn.__module__
  510. arg_reprs = ", ".join([f"{param.name}: {ty}" for param, ty in zip(self.params, key[1])])
  511. repr = f"{name}[num_warps={options.num_warps}, num_ctas={options.num_ctas}, num_stages={options.num_stages}, enable_fp_fusion={options.enable_fp_fusion}, launch_cooperative_grid={options.launch_cooperative_grid}]({arg_reprs})"
  512. full_name = get_full_name(self.fn)
  513. specialization_data = serialize_specialization_data(full_name, signature, constants, configs[0], options, key)
  514. kwargs = {
  515. 'signature': signature,
  516. 'device': device,
  517. 'constants': constants,
  518. 'num_warps': options.num_warps,
  519. 'num_ctas': options.num_ctas,
  520. 'num_stages': options.num_stages,
  521. 'enable_fp_fusion': options.enable_fp_fusion,
  522. 'launch_cooperative_grid': options.launch_cooperative_grid,
  523. 'extern_libs': options.extern_libs,
  524. 'configs': configs,
  525. 'specialization_data': specialization_data,
  526. 'is_warmup': is_warmup,
  527. }
  528. return hook(
  529. key=key,
  530. repr=repr,
  531. fn=JitFunctionInfo(module, name, self),
  532. compile={"key": key, **kwargs},
  533. is_manual_warmup=is_warmup,
  534. already_compiled=False,
  535. )
  536. def add_pre_run_hook(self, hook):
  537. '''
  538. Add a hook that will be executed prior to the execution of run
  539. function with args and kwargs passed into the kernel
  540. '''
  541. assert callable(hook)
  542. self.pre_run_hooks.append(hook)
  543. def create_binder(self):
  544. """
  545. Precompute as much as possible.
  546. """
  547. from ..compiler import CompiledKernel, compile, ASTSource, make_backend
  548. target = driver.active.get_current_target()
  549. backend = make_backend(target)
  550. self.CompiledKernel = CompiledKernel
  551. self.compile = compile
  552. self.ASTSource = ASTSource
  553. binder = create_function_from_signature(self.signature, self.params, backend)
  554. return {}, {}, target, backend, binder
  555. def _pack_args(self, backend, kwargs, bound_args, specialization, options):
  556. # options
  557. options = backend.parse_options(kwargs)
  558. # signature
  559. sigkeys = [x.name for x in self.params]
  560. sigvals = [x[0] for x in specialization]
  561. signature = {k: v for (k, v) in zip(sigkeys, sigvals)}
  562. # check arguments
  563. assert "device_type" not in kwargs, "device_type option is deprecated; current target will be used"
  564. assert "device" not in kwargs, "device option is deprecated; current device will be used"
  565. assert "stream" not in kwargs, "stream option is deprecated; current stream will be used"
  566. for k in kwargs:
  567. if k not in options.__dict__ and k not in sigkeys:
  568. raise KeyError("Keyword argument %s was specified but unrecognised" % k)
  569. # constexprs
  570. constexprs = find_paths_if(sigvals, lambda _, val: val == "constexpr")
  571. constexprs = {path: get_iterable_path(list(bound_args.values()), path) for path in constexprs}
  572. # attributes
  573. attrvals = [x[1] for x in specialization]
  574. attrs = find_paths_if(attrvals, lambda _, x: isinstance(x, str))
  575. attrs = {k: backend.parse_attr(get_iterable_path(attrvals, k)) for k in attrs}
  576. return options, signature, constexprs, attrs
def run(self, *args, grid, warmup, **kwargs):
    """
    Bind the arguments, compile the kernel if it is not cached, and (unless
    `warmup` is true) launch it on the current device and stream.

    Args:
        *args, **kwargs: kernel arguments and launch options.
        grid: a sequence of 1-3 sizes, or a callable that receives the bound
            arguments and returns one. Only used when `warmup` is false.
        warmup: when true, stop after compilation and do not launch.

    Returns:
        The kernel object, or None when `_do_compile` returns None.

    Raises:
        RuntimeError: if a tracked global variable no longer equals the value
            recorded in `used_global_vals`.
    """
    # An explicit `debug` kwarg (else self.debug) OR-ed with the global knob.
    kwargs["debug"] = kwargs.get("debug", self.debug) or knobs.runtime.debug
    kwargs["instrumentation_mode"] = knobs.compilation.instrumentation_mode

    # parse options
    device = driver.active.get_current_device()
    stream = driver.active.get_current_stream(device)

    # Execute pre run hooks with args and kwargs
    for hook in self.pre_run_hooks:
        hook(*args, **kwargs)

    # device_caches is a defaultdict(create_binder): the first access for a
    # device builds its caches and binder.
    kernel_cache, kernel_key_cache, target, backend, binder = self.device_caches[device]
    # specialization is list[tuple[str, Any]], where first element of tuple is
    # the type and the second parameter is the 'specialization' value.
    bound_args, specialization, options = binder(*args, **kwargs)

    key = compute_cache_key(kernel_key_cache, specialization, options)
    kernel = kernel_cache.get(key, None)

    # Kernel is not cached; we have to compile.
    if kernel is None:
        options, signature, constexprs, attrs = self._pack_args(backend, kwargs, bound_args, specialization,
                                                                options)

        kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
        if kernel is None:
            return None

    # Check that used global values have not changed.
    # `not_present` is a unique sentinel so a deleted global also compares unequal.
    not_present = object()
    for (name, _), (val, globals_dict) in self.used_global_vals.items():
        if (newVal := globals_dict.get(name, not_present)) != val:
            raise RuntimeError(
                f"Global variable {name} has changed since we compiled this kernel, from {val} to {newVal}")

    if not warmup:
        # canonicalize grid
        assert grid is not None
        if callable(grid):
            grid = grid(bound_args)
        # Missing trailing grid dimensions default to 1.
        grid_size = len(grid)
        grid_0 = grid[0]
        grid_1 = grid[1] if grid_size > 1 else 1
        grid_2 = grid[2] if grid_size > 2 else 1
        # presumably a future-like handle from asynchronous compilation -- TODO confirm
        if hasattr(kernel, "result"):
            kernel = kernel.result()
        # launch kernel
        launch_metadata = kernel.launch_metadata(grid, stream, *bound_args.values())
        kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
                   knobs.runtime.launch_enter_hook, knobs.runtime.launch_exit_hook, *bound_args.values())
    return kernel
  621. def repr(self, _):
  622. return self._fn_name if self._repr is None else self._repr(_)
  623. def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_on_alignment=None, debug=None,
  624. noinline=None, repr=None, launch_metadata=None):
  625. do_not_specialize = do_not_specialize if do_not_specialize else []
  626. do_not_specialize_on_alignment = do_not_specialize_on_alignment if do_not_specialize_on_alignment else []
  627. super().__init__(fn)
  628. self.module = fn.__module__
  629. self.version = version
  630. self.do_not_specialize = do_not_specialize
  631. self.do_not_specialize_on_alignment = do_not_specialize_on_alignment
  632. self._repr = repr
  633. self.launch_metadata = launch_metadata
  634. self.params = []
  635. for i, param in enumerate(self.signature.parameters.values()):
  636. dns = i in do_not_specialize or param.name in do_not_specialize
  637. dns_oa = i in do_not_specialize_on_alignment or param.name in do_not_specialize_on_alignment
  638. self.params.append(KernelParam(i, param, dns, dns_oa))
  639. # cache of just-in-time compiled kernels
  640. self.device_caches = defaultdict(self.create_binder)
  641. # JITFunction can be instantiated as kernel
  642. # when called with a grid using __getitem__
  643. self.kernel = None
  644. self.debug = debug
  645. self.noinline = noinline
  646. # TODO(jlebar): Remove uses of these fields outside this file, then
  647. # remove the fields here.
  648. self.arg_names = [p.name for p in self.params]
  649. self.constexprs = [p.num for p in self.params if p.is_constexpr]
  650. # Hooks that will be called prior to executing "run"
  651. self.pre_run_hooks = []
  652. def preload(self, specialization_data):
  653. import json
  654. import triton.language as tl
  655. device = driver.active.get_current_device()
  656. deserialized_obj = json.loads(specialization_data)
  657. if deserialized_obj['name'] != self._fn_name:
  658. raise RuntimeError(
  659. f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self._fn_name}")
  660. constant_keys = map(tuple, deserialized_obj['constant_keys'])
  661. constant_vals = deserialized_obj['constant_vals']
  662. constexprs = {
  663. key:
  664. tl.dtype(value) if tl.dtype.is_dtype(value) else
  665. tl.constexpr(value['constexpr']) if isinstance(value, dict) and 'constexpr' in value else value
  666. for key, value in zip(constant_keys, constant_vals)
  667. }
  668. attrs_keys = map(tuple, deserialized_obj['attrs_keys'])
  669. attrs_vals = deserialized_obj['attrs_vals']
  670. attrs = dict(zip(attrs_keys, attrs_vals))
  671. # JSON serializes tuples as lists, so they need to be converted back;
  672. # This can be done unconditionally, since lists are not accepted in Triton kernel signatures.
  673. signature = {key: convert_to_tuple_if_list(value) for key, value in deserialized_obj['signature'].items()}
  674. options = {
  675. key: tuple(value) if isinstance(value, list) else value
  676. for key, value in deserialized_obj['options'].items()
  677. }
  678. key = deserialized_obj['key']
  679. _, _, _, backend, _ = self.device_caches[device]
  680. options = backend.parse_options(options)
  681. return self._do_compile(
  682. key,
  683. signature,
  684. device,
  685. constexprs,
  686. options,
  687. attrs,
  688. warmup=True,
  689. )
  690. def _do_compile(self, key, signature, device, constexprs, options, attrs, warmup):
  691. kernel_cache, _, target, backend, _ = self.device_caches[device]
  692. if self._call_hook(knobs.runtime.jit_cache_hook, key, signature, device, constexprs, options, [attrs], warmup):
  693. return None
  694. src = self.ASTSource(self, signature, constexprs, attrs)
  695. async_mode = _async_compile.active_mode.get()
  696. if async_mode is not None:
  697. env_vars = get_cache_invalidating_env_vars()
  698. cache_key = get_cache_key(src, backend, options, env_vars)
  699. def async_compile():
  700. return self.compile(src, target=target, options=options.__dict__, _env_vars=env_vars)
  701. def finalize_compile(kernel):
  702. kernel_cache[key] = kernel
  703. self._call_hook(knobs.runtime.jit_post_compile_hook, key, signature, device, constexprs, options,
  704. [attrs], warmup)
  705. kernel = async_mode.submit(cache_key, async_compile, finalize_compile)
  706. else:
  707. kernel = self.compile(src, target=target, options=options.__dict__)
  708. kernel_cache[key] = kernel
  709. self._call_hook(knobs.runtime.jit_post_compile_hook, key, signature, device, constexprs, options, [attrs],
  710. warmup)
  711. return kernel
  712. def __call__(self, *args, **kwargs):
  713. raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel")
  714. def __repr__(self):
  715. return f"JITFunction({self.module}:{self.fn.__qualname__})"
  716. # -----------------------------------------------------------------------------
  717. # `jit` decorator
  718. # -----------------------------------------------------------------------------
  719. @overload
  720. def jit(fn: T) -> JITFunction[T]:
  721. ...
  722. @overload
  723. def jit(
  724. *,
  725. version=None,
  726. repr: Optional[Callable] = None,
  727. launch_metadata: Optional[Callable] = None,
  728. do_not_specialize: Optional[Iterable[int | str]] = None,
  729. do_not_specialize_on_alignment: Optional[Iterable[int | str]] = None,
  730. debug: Optional[bool] = None,
  731. noinline: Optional[bool] = None,
  732. ) -> Callable[[T], JITFunction[T]]:
  733. ...
  734. def jit(
  735. fn: Optional[T] = None,
  736. *,
  737. version=None,
  738. repr: Optional[Callable] = None,
  739. launch_metadata: Optional[Callable] = None,
  740. do_not_specialize: Optional[Iterable[int | str]] = None,
  741. do_not_specialize_on_alignment: Optional[Iterable[int | str]] = None,
  742. debug: Optional[bool] = None,
  743. noinline: Optional[bool] = None,
  744. ) -> KernelInterface[T]:
  745. """
  746. Decorator for JIT-compiling a function using the Triton compiler.
  747. :note: When a jit'd function is called, arguments are
  748. implicitly converted to pointers if they have a :code:`.data_ptr()` method
  749. and a `.dtype` attribute.
  750. :note: This function will be compiled and run on the GPU. It will only have access to:
  751. * python primitives,
  752. * builtins within the triton package,
  753. * arguments to this function,
  754. * other jit'd functions
  755. :param fn: the function to be jit-compiled
  756. :type fn: Callable
  757. """
  758. def decorator(fn: T) -> JITFunction[T]:
  759. assert callable(fn)
  760. if knobs.runtime.interpret:
  761. from .interpreter import InterpretedFunction
  762. return InterpretedFunction(fn, version=version, do_not_specialize=do_not_specialize,
  763. do_not_specialize_on_alignment=do_not_specialize_on_alignment, debug=debug,
  764. noinline=noinline, repr=repr, launch_metadata=launch_metadata)
  765. else:
  766. return JITFunction(
  767. fn,
  768. version=version,
  769. do_not_specialize=do_not_specialize,
  770. do_not_specialize_on_alignment=do_not_specialize_on_alignment,
  771. debug=debug,
  772. noinline=noinline,
  773. repr=repr,
  774. launch_metadata=launch_metadata,
  775. )
  776. if fn is not None:
  777. return decorator(fn)
  778. else:
  779. return decorator
  780. # -----------------------------------------------------------------------------
  781. # Utilities for mocking tensors
  782. # -----------------------------------------------------------------------------
  783. class MockTensor:
  784. """
  785. Can be used in place of real tensors when calling:
  786. kernel.warmup(MockTensor(torch.float32), ...)
  787. """
  788. @staticmethod
  789. def wrap_dtype(arg):
  790. if arg.__class__.__name__ == "dtype" and arg.__module__ == "torch":
  791. return MockTensor(arg)
  792. return arg
  793. def __init__(self, dtype, shape=None):
  794. if shape is None:
  795. shape = [1]
  796. self.dtype = dtype
  797. self.shape = shape
  798. def stride(self):
  799. strides = [1]
  800. for size in self.shape[1:]:
  801. strides.append(strides[-1] * size)
  802. return tuple(reversed(strides))
  803. @staticmethod
  804. def data_ptr():
  805. return 0 # optimistically assumes multiple of 16
  806. @staticmethod
  807. def ptr_range():
  808. return 0 # optimistically assumes 32 bit pointer range
  809. class TensorWrapper:
  810. def __init__(self, base, dtype):
  811. self.dtype = dtype
  812. self.base = base
  813. self.data = base.data
  814. self.device = base.device
  815. self.shape = self.base.shape
  816. def data_ptr(self):
  817. return self.base.data_ptr()
  818. def stride(self, *args):
  819. return self.base.stride(*args)
  820. def __str__(self) -> str:
  821. return f"TensorWrapper[{self.dtype}]({self.base})"
  822. def element_size(self):
  823. return self.base.element_size()
  824. def cpu(self):
  825. return TensorWrapper(self.base.cpu(), self.dtype)
  826. def copy_(self, other):
  827. self.base.copy_(other.base)
  828. def clone(self):
  829. return TensorWrapper(self.base.clone(), self.dtype)
  830. def to(self, device):
  831. return TensorWrapper(self.base.to(device), self.dtype)
  832. def new_empty(self, sizes):
  833. return TensorWrapper(self.base.new_empty(sizes), self.dtype)
  834. def reinterpret(tensor, dtype):
  835. if isinstance(tensor, TensorWrapper):
  836. if dtype == tensor.base.dtype:
  837. # Reinterpreting to the original interpretation; return the base.
  838. return tensor.base
  839. else:
  840. # Reinterpreting a wrapped tensor to a different type.
  841. return TensorWrapper(tensor.base, dtype)
  842. elif hasattr(tensor, "data_ptr"):
  843. # A new wrapper is needed around an unwrapped tensor.
  844. return TensorWrapper(tensor, dtype)
  845. else:
  846. raise TypeError(f"Cannot reinterpret a {type(tensor)}.")
  847. def get_jit_fn_file_line(fn):
  848. base_fn = fn
  849. while not isinstance(base_fn, JITCallable):
  850. base_fn = base_fn.fn
  851. file_name = base_fn.fn.__code__.co_filename
  852. begin_line = base_fn.starting_line_number
  853. # Match the following pattern:
  854. # @triton.autotune(...) <- foo.__code__.co_firstlineno
  855. # @triton.heuristics(...)
  856. # @triton.jit
  857. # def foo(...): <- this line is the first line
  858. for idx, line in enumerate(base_fn.raw_src):
  859. if line.strip().startswith("def "):
  860. begin_line += idx
  861. break
  862. return file_name, begin_line
  863. class BoundConstexprFunction(JITCallable):
  864. def __init__(self, instance, fn):
  865. self.__self__ = instance
  866. self.__func__ = fn
  867. @property
  868. def cache_key(self):
  869. return self.__func__.cache_key
  870. def __call__(self, *args, **kwargs):
  871. return self.__func__(self.__self__, *args, **kwargs)
  872. class ConstexprFunction(JITCallable):
  873. def __init__(self, fn):
  874. super().__init__(fn)
  875. def __get__(self, obj, objclass):
  876. # Create a bound function to support constexpr_function methods
  877. if obj is not None:
  878. return BoundConstexprFunction(obj, self)
  879. return self
  880. def __call__(self, *args, _semantic=None, **kwargs):
  881. from triton.language.core import _unwrap_if_constexpr, constexpr
  882. # de-constexpr arguments and discard the _semantic keyword argument:
  883. args = [_unwrap_if_constexpr(x) for x in args]
  884. kwargs = {k: _unwrap_if_constexpr(v) for (k, v) in kwargs.items()}
  885. # call the raw Python function f:
  886. res = self.fn(*args, **kwargs)
  887. if _semantic is None:
  888. # Not called by triton code generator, e.g. in host code, another constexpr function, or even an aggreate's __init__ function
  889. return res
  890. # convert result back to a Triton constexpr:
  891. if knobs.runtime.interpret:
  892. return res # No constexpr in interpreter
  893. return constexpr(res)
  894. def constexpr_function(fn):
  895. """
  896. Wraps an arbitrary Python function so that it can be called at
  897. compile-time on constexpr arguments in a Triton function and
  898. returns a constexpr result.
  899. """
  900. return ConstexprFunction(fn)