triton.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577
  1. import ast
  2. import contextlib
  3. import inspect
  4. import logging
  5. import threading
  6. from collections.abc import Callable, Generator, Iterable
  7. from typing import Any, Optional, Union
  8. from torch.utils._exposed_in import exposed_in
  9. from .custom_ops import custom_op, CustomOpDef
  10. from .infer_schema import infer_schema
  11. logger = logging.getLogger(__name__)
  12. triton_ops_to_kernels: dict[str, list[object]] = {}
  13. def get_triton_kernels_for_op(name: str) -> list[object]:
  14. return triton_ops_to_kernels.get(name, [])
  15. def get_inner_triton_kernels(fn: Callable[..., Any]) -> list[object]:
  16. """
  17. Inspect the source of an arbitrary callable passed to torch._library.triton_op,
  18. and grab all of the triton kernels that are wrapped inside of it.
  19. This function traces local variable assignments to handle patterns like:
  20. kernel_fn = _my_kernel # global JITFunction
  21. wrapped = some_wrapper(kernel_fn)
  22. capture_triton(wrapped)[grid](...)
  23. It also recursively analyzes called functions to find triton kernels hidden
  24. behind helper function calls.
  25. That said, it is best effort. There are cases (e.g., recursion > MAX_RECURSION_DEPTH)
  26. that are not accounted for, so keep that in mind.
  27. """
  28. # prevent infinite recursion
  29. MAX_RECURSION_DEPTH = 5
  30. def find_triton_kernels(
  31. fn: Callable[..., Any],
  32. visited_fns: set[int] | None = None,
  33. depth: int = 0,
  34. ) -> list[object]:
  35. try:
  36. from triton.runtime.autotuner import Autotuner
  37. from triton.runtime.jit import JITFunction
  38. except ImportError:
  39. logger.warning("Triton not available, find_triton_kernels = []")
  40. return []
  41. # unwrap decorated fn's (e.g., @lru_cache) to get the original
  42. fn = inspect.unwrap(fn)
  43. # init visited set and check for cycles/depth limit
  44. if visited_fns is None:
  45. visited_fns = set()
  46. fn_id = id(fn)
  47. if fn_id in visited_fns:
  48. return []
  49. if depth > MAX_RECURSION_DEPTH:
  50. logger.debug(
  51. "reached max recursion depth (%s) in find_triton_kernels",
  52. MAX_RECURSION_DEPTH,
  53. )
  54. return []
  55. visited_fns.add(fn_id)
  56. try:
  57. source = inspect.getsource(fn)
  58. except (OSError, TypeError):
  59. return [] # Source code not available
  60. from torch._inductor.utils import IndentedBuffer
  61. buffer = IndentedBuffer()
  62. buffer.splice(source, strip=True)
  63. tree = ast.parse(buffer.getrawvalue())
  64. # Visitor to collect function calls, assignments, and triton kernels
  65. class Visitor(ast.NodeVisitor):
  66. def __init__(self) -> None:
  67. self.triton_kernels: list[Any] = []
  68. # track local variable assignments: var_name -> list of RHS expressions
  69. self.assignments: dict[str, list[ast.expr]] = {}
  70. # track function calls
  71. self.called_functions: list[str] = []
  72. # track return statement expressions
  73. self.return_exprs: list[ast.expr] = []
  74. def visit_Assign(self, node: ast.Assign) -> None:
  75. for target in node.targets:
  76. if isinstance(target, ast.Name):
  77. self.assignments.setdefault(target.id, []).append(node.value)
  78. self.generic_visit(node)
  79. def visit_Return(self, node: ast.Return) -> None:
  80. if node.value is not None:
  81. self.return_exprs.append(node.value)
  82. self.generic_visit(node)
  83. def visit_Call(self, node: ast.Call) -> None:
  84. triton_func_names = ("capture_triton", "wrap_triton")
  85. if isinstance(node.func, ast.Attribute):
  86. attr = node.func
  87. if isinstance(attr.value, ast.Attribute):
  88. if (
  89. isinstance(attr.value.value, ast.Name)
  90. and attr.value.value.id == "torch"
  91. and attr.value.attr == "_library"
  92. and attr.attr in triton_func_names
  93. ):
  94. if node.args and isinstance(node.args[0], ast.Name):
  95. self.triton_kernels.append(node.args[0].id)
  96. elif (
  97. isinstance(attr.value.value, ast.Attribute)
  98. and isinstance(attr.value.value.value, ast.Name)
  99. and attr.value.value.value.id == "torch"
  100. and attr.value.value.attr == "ops"
  101. ):
  102. self.called_functions.append(
  103. f"{attr.value.attr}::{attr.attr}"
  104. )
  105. # Catch capture_triton, wrap_triton that's been
  106. # imported directly
  107. elif isinstance(node.func, ast.Name):
  108. if node.func.id in triton_func_names:
  109. if node.args and isinstance(node.args[0], ast.Name):
  110. self.triton_kernels.append(node.args[0].id)
  111. else:
  112. # track regular function calls for recursive analysis
  113. self.called_functions.append(node.func.id)
  114. self.generic_visit(node)
  115. collector = Visitor()
  116. collector.visit(tree)
  117. def extract_names_from_expr(expr: ast.expr) -> list[str]:
  118. """Extract all Name references from an AST expression."""
  119. names: list[str] = []
  120. class NameExtractor(ast.NodeVisitor):
  121. def visit_Name(self, node: ast.Name) -> None:
  122. names.append(node.id)
  123. def visit_Call(self, node: ast.Call) -> None:
  124. # for function calls, visit the function and all args
  125. self.generic_visit(node)
  126. NameExtractor().visit(expr)
  127. return names
  128. def resolve_to_kernel(obj: object) -> object | None:
  129. """Check if obj is a triton kernel or wrapper and return the kernel."""
  130. if isinstance(obj, (JITFunction, Autotuner)):
  131. return obj
  132. # handle wrappers that have a .fn attribute pointing to JITFunction
  133. if callable(obj) and hasattr(obj, "fn"):
  134. inner = obj.fn
  135. if isinstance(inner, JITFunction):
  136. return inner
  137. return None
  138. def build_namespace(func_obj: object) -> dict[str, Any]:
  139. """Build a combined namespace from a function's globals and closures."""
  140. # unwrap decorated fns (e.g., @lru_cache)
  141. if callable(func_obj):
  142. try:
  143. func_obj = inspect.unwrap(func_obj)
  144. except ValueError:
  145. pass
  146. if not callable(func_obj) or not hasattr(func_obj, "__code__"):
  147. return {}
  148. func_closure_vars = inspect.getclosurevars(func_obj)
  149. namespace: dict[str, Any] = {}
  150. namespace.update(func_closure_vars.builtins)
  151. namespace.update(func_closure_vars.globals)
  152. namespace.update(func_closure_vars.nonlocals)
  153. if hasattr(func_obj, "__globals__"):
  154. namespace.update(func_obj.__globals__)
  155. return namespace
  156. all_names = build_namespace(fn)
  157. def resolve_names_to_kernels(
  158. names: list[str],
  159. namespace: dict[str, Any],
  160. assignments: dict[str, list[ast.expr]] | None = None,
  161. visited: set[str] | None = None,
  162. ) -> list[object]:
  163. """
  164. Resolve a list of names to triton kernels using the given namespace.
  165. """
  166. if visited is None:
  167. visited = set()
  168. results: list[object] = []
  169. for name in names:
  170. if name in visited:
  171. continue
  172. visited.add(name)
  173. if name in namespace:
  174. obj = namespace[name]
  175. kernel = resolve_to_kernel(obj)
  176. if kernel is not None:
  177. results.append(kernel)
  178. continue
  179. # recurse into callable objects (factory fn's),
  180. # unwrapping decorators if applicable
  181. if callable(obj):
  182. try:
  183. unwrapped = inspect.unwrap(obj)
  184. except ValueError:
  185. unwrapped = obj
  186. if hasattr(unwrapped, "__code__"):
  187. nested = find_triton_kernels(
  188. unwrapped, visited_fns, depth + 1
  189. )
  190. if nested:
  191. results.extend(nested)
  192. continue
  193. logger.debug("failed to resolve %s to a triton kernel", name)
  194. elif assignments is not None and name in assignments:
  195. # trace through local assignments
  196. for rhs_expr in assignments[name]:
  197. referenced = extract_names_from_expr(rhs_expr)
  198. traced = resolve_names_to_kernels(
  199. referenced, namespace, assignments, visited
  200. )
  201. results.extend(traced)
  202. else:
  203. logger.debug("%s not found in namespace or assignments", name)
  204. return results
  205. # resolve kernel names, tracing through local variables if needed
  206. resolved: list[object] = []
  207. seen_ids: set[int] = set()
  208. names_to_resolve: list[str] = list(collector.triton_kernels)
  209. for expr in collector.return_exprs:
  210. names_to_resolve.extend(extract_names_from_expr(expr))
  211. for name in names_to_resolve:
  212. traced_objects = resolve_names_to_kernels(
  213. [name], all_names, collector.assignments
  214. )
  215. for obj in traced_objects:
  216. obj_id = id(obj)
  217. if obj_id not in seen_ids:
  218. seen_ids.add(obj_id)
  219. resolved.append(obj)
  220. for func_name in collector.called_functions:
  221. func_obj = all_names.get(func_name)
  222. if func_obj is None:
  223. from torch._library.custom_ops import OPDEFS
  224. if func_name in OPDEFS:
  225. func_obj = OPDEFS[func_name]._abstract_fn
  226. # skip if not a callable or if it's a triton kernel itself
  227. if func_obj is None or not callable(func_obj):
  228. continue
  229. # skip built-in functions and C extensions (they can't contain triton kernels)
  230. if not hasattr(func_obj, "__code__"):
  231. continue
  232. try:
  233. nested_kernels = find_triton_kernels(func_obj, visited_fns, depth + 1)
  234. for kernel in nested_kernels:
  235. kernel_id = id(kernel)
  236. if kernel_id not in seen_ids:
  237. seen_ids.add(kernel_id)
  238. resolved.append(kernel)
  239. except Exception:
  240. logger.debug(
  241. "failed to analyze called function %s", func_name, exc_info=True
  242. )
  243. return resolved
  244. return find_triton_kernels(fn)
  245. @exposed_in("torch.library")
  246. def triton_op(
  247. name: str,
  248. fn: Optional[Callable] = None,
  249. /,
  250. *,
  251. mutates_args: Union[str, Iterable[str]],
  252. schema: Optional[str] = None,
  253. ) -> Callable:
  254. """Create a custom operator whose implementation is backed by 1+ triton kernels.
  255. This is a more structured way of using triton kernels with PyTorch.
  256. Prefer using triton kernels with no ``torch.library`` custom operator wrappers
  257. (like :func:`torch.library.custom_op`, :func:`torch.library.triton_op`) because
  258. that is simpler;
  259. only use :func:`torch.library.custom_op`/:func:`torch.library.triton_op` if you
  260. want to create an operator that behaves like PyTorch built-in operators.
  261. For example, you may use a ``torch.library`` wrapper API to define the
  262. behavior of the triton kernel when passed a tensor subclass or under
  263. a TorchDispatchMode.
  264. Use :func:`torch.library.triton_op` instead of :func:`torch.library.custom_op`
  265. when the implementation
  266. consists of 1+ triton kernels. :func:`torch.library.custom_op` treats
  267. custom operators as opaque (:func:`torch.compile` and
  268. :func:`torch.export.export` will never trace into them), but ``triton_op``
  269. makes the implementation visible to these subsystems, allowing them
  270. to optimize the triton kernel(s).
  271. Note that ``fn`` must only consist of calls to PyTorch-understood
  272. operators and triton kernels. Any triton kernels called inside ``fn``
  273. must be wrapped in a call to :func:`torch.library.wrap_triton`.
  274. Args:
  275. name (str): A name for the custom op that looks like "{namespace}::{name}",
  276. e.g. "mylib::my_linear". The name is used as the op's stable identifier
  277. in PyTorch subsystems (e.g. torch.export, FX graphs).
  278. To avoid name collisions, please use your project name as the namespace;
  279. e.g. all custom ops in pytorch/fbgemm use "fbgemm" as the namespace.
  280. mutates_args (Iterable[str] or "unknown"): The names of args that the function mutates.
  281. This MUST be accurate, otherwise, the behavior is undefined. If "unknown",
  282. it pessimistically assumes that all inputs to the operator are being mutated.
  283. schema (str | None): A schema string for the operator. If None
  284. (recommended) we'll infer a schema for the operator from its type
  285. annotations. We recommend letting us infer a schema unless you
  286. have a specific reason not to.
  287. Example: "(Tensor x, int y) -> (Tensor, Tensor)".
  288. Example::
  289. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
  290. >>> import torch
  291. >>> from torch.library import triton_op, wrap_triton
  292. >>>
  293. >>> import triton
  294. >>> from triton import language as tl
  295. >>>
  296. >>> @triton.jit
  297. >>> def add_kernel(
  298. >>> in_ptr0,
  299. >>> in_ptr1,
  300. >>> out_ptr,
  301. >>> n_elements,
  302. >>> BLOCK_SIZE: "tl.constexpr",
  303. >>> ):
  304. >>> pid = tl.program_id(axis=0)
  305. >>> block_start = pid * BLOCK_SIZE
  306. >>> offsets = block_start + tl.arange(0, BLOCK_SIZE)
  307. >>> mask = offsets < n_elements
  308. >>> x = tl.load(in_ptr0 + offsets, mask=mask)
  309. >>> y = tl.load(in_ptr1 + offsets, mask=mask)
  310. >>> output = x + y
  311. >>> tl.store(out_ptr + offsets, output, mask=mask)
  312. >>>
  313. >>> @triton_op("mylib::add", mutates_args={})
  314. >>> def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
  315. >>> output = torch.empty_like(x)
  316. >>> n_elements = output.numel()
  317. >>>
  318. >>> def grid(meta):
  319. >>> return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
  320. >>>
  321. >>> # NB: we need to wrap the triton kernel in a call to wrap_triton
  322. >>> wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16)
  323. >>> return output
  324. >>>
  325. >>> @torch.compile
  326. >>> def f(x, y):
  327. >>> return add(x, y)
  328. >>>
  329. >>> x = torch.randn(3, device="cuda")
  330. >>> y = torch.randn(3, device="cuda")
  331. >>>
  332. >>> z = f(x, y)
  333. >>> assert torch.allclose(z, x + y)
  334. """
  335. def dec(fn: Callable[..., object]) -> CustomOpDef:
  336. def backend_fn(*args, **kwargs): # type: ignore[no-untyped-def]
  337. # Optimization: we're passing regular Tensors into the triton kernel, so
  338. # no need to go through HOP dispatch
  339. with set_wrap_triton_enabled(False):
  340. return fn(*args, **kwargs)
  341. result = custom_op(
  342. name,
  343. backend_fn,
  344. mutates_args=mutates_args,
  345. schema=infer_schema(fn, mutates_args=mutates_args),
  346. )
  347. from .._subclasses.functional_tensor import FunctionalTensorMode
  348. # We require that the user pass us a function that is make_fx traceable,
  349. # so we can just register it as the Fake/meta kernel.
  350. result.register_fake(fn)
  351. # We decompose the operator when FunctionalTensorMode is active.
  352. # The goal is to decompose the operator in AOTDispatcher.
  353. # - With torch.compile, this means that the backend (usually Inductor)
  354. # can see a call to the triton kernel(s) and so it can directly optimize
  355. # them by inlining them into the lowering process.
  356. def functional_decomp( # type: ignore[no-untyped-def]
  357. mode, op, types, args, kwargs
  358. ):
  359. # NOTE [Export custom triton op]
  360. # For torch.export (strict and non-strict), we don't do functional decomposition.
  361. # Instead, we preserve the custom triton ops as custom ops. This is because we want
  362. # the exported program to be high-level and serializable. If we decompose
  363. # the custom op to a functional hop and make it a node in exported program,
  364. # we need to figure out ways of serializing the hop and its arguments, which can be triton.jited
  365. # functions and triton dtypes. This is undesirable because:
  366. # - it can be tedious to maintain a layer that serializes the jited function (e.g. with a string) and dtypes.
  367. # - exported program will contain the implementation detail (e.g. triton source code) for a specific
  368. # backend (GPU), which is probably at a wrong level of abstraction.
  369. # - changes to triton or the serialization logic for triton arguments can be BC breaking
  370. #
  371. # In the short term, we expect users to have a separate aot_compile stage that compiles the exported program
  372. # into a Cubin file on the same machine that users call export, which does autotuning and removes triton
  373. # dependency and serve the model with Cubin. This guarantees that triton changes won't break BC.
  374. # In the long term, we may export multiple cubins for the triton op directly
  375. from torch.export._trace import custom_triton_ops_decomposition_disabled
  376. if custom_triton_ops_decomposition_disabled():
  377. return mode.__torch_dispatch__(op, types, args, kwargs)
  378. else:
  379. # TODO: https://github.com/pytorch/pytorch/issues/160333
  380. # We should deduplicate the unrecognized_types logic.
  381. import torch._subclasses
  382. unrecognized_types = [
  383. t
  384. for t in types
  385. if not issubclass(t, torch._subclasses.FakeTensor)
  386. and t
  387. not in [
  388. torch.Tensor,
  389. torch._subclasses.functional_tensor.FunctionalTensor,
  390. ]
  391. ]
  392. if unrecognized_types:
  393. return NotImplemented
  394. with mode:
  395. return fn(*args, **kwargs)
  396. triton_kernels = get_inner_triton_kernels(fn)
  397. triton_ops_to_kernels[name] = triton_kernels
  398. result.register_torch_dispatch(FunctionalTensorMode, functional_decomp)
  399. return result
  400. if fn is None:
  401. return dec
  402. else:
  403. return dec(fn)
  404. wrap_triton_enabled = threading.local()
  405. wrap_triton_enabled_default = True
  406. @contextlib.contextmanager
  407. def set_wrap_triton_enabled(enabled: bool) -> Generator[None, None, None]:
  408. """If triton kernels annotated with @wrap_triton should dispatch via HOP
  409. or go straight to the triton kernel execution.
  410. We have this switch because eager-mode performance of HOP dispatch is slow
  411. enough to matter (~1ms) and we know that wrap_triton isn't necessary in
  412. some situations (eager-mode with regular Tensors)
  413. """
  414. try:
  415. prev = is_wrap_triton_enabled()
  416. wrap_triton_enabled.value = enabled
  417. yield
  418. finally:
  419. wrap_triton_enabled.value = prev
  420. def is_wrap_triton_enabled() -> bool:
  421. return getattr(wrap_triton_enabled, "value", wrap_triton_enabled_default)
  422. def capture_triton(triton_kernel: Callable, /) -> Any:
  423. """This API has been renamed to wrap_triton"""
  424. return wrap_triton(triton_kernel)
  425. @exposed_in("torch.library")
  426. def wrap_triton(triton_kernel: Callable, /) -> Any:
  427. """Allows capture of a triton kernel into a graph via make_fx or
  428. non-strict ``torch.export``.
  429. These technologies perform Dispatcher-based tracing (via
  430. ``__torch_dispatch__``) and cannot see calls to raw triton kernels.
  431. The ``wrap_triton`` API wraps a triton kernel into a callable that
  432. can actually be traced into a graph.
  433. Please use this API together with :func:`torch.library.triton_op`.
  434. Examples:
  435. >>> # xdoctest: +SKIP
  436. >>> import torch
  437. >>> import triton
  438. >>> from triton import language as tl
  439. >>> from torch.fx.experimental.proxy_tensor import make_fx
  440. >>> from torch.library import wrap_triton
  441. >>>
  442. >>> @triton.jit
  443. >>> def add_kernel(
  444. >>> in_ptr0,
  445. >>> in_ptr1,
  446. >>> out_ptr,
  447. >>> n_elements,
  448. >>> BLOCK_SIZE: "tl.constexpr",
  449. >>> ):
  450. >>> pid = tl.program_id(axis=0)
  451. >>> block_start = pid * BLOCK_SIZE
  452. >>> offsets = block_start + tl.arange(0, BLOCK_SIZE)
  453. >>> mask = offsets < n_elements
  454. >>> x = tl.load(in_ptr0 + offsets, mask=mask)
  455. >>> y = tl.load(in_ptr1 + offsets, mask=mask)
  456. >>> output = x + y
  457. >>> tl.store(out_ptr + offsets, output, mask=mask)
  458. >>>
  459. >>> def add(x, y):
  460. >>> output = torch.empty_like(x)
  461. >>> n_elements = output.numel()
  462. >>>
  463. >>> def grid_fn(meta):
  464. >>> return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
  465. >>>
  466. >>> wrap_triton(add_kernel)[grid_fn](x, y, output, n_elements, 16)
  467. >>> return output
  468. >>>
  469. >>> x = torch.randn(3, device="cuda")
  470. >>> y = torch.randn(3, device="cuda")
  471. >>> gm = make_fx(add)(x, y)
  472. >>> print(gm.code)
  473. >>> # def forward(self, x_1, y_1):
  474. >>> # empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False)
  475. >>> # triton_kernel_wrapper_mutation_proxy = triton_kernel_wrapper_mutation(
  476. >>> # kernel_idx = 0, constant_args_idx = 0,
  477. >>> # grid = [(1, 1, 1)], kwargs = {
  478. >>> # 'in_ptr0': x_1, 'in_ptr1': y_1, 'out_ptr': empty_like,
  479. >>> # 'n_elements': 3, 'BLOCK_SIZE': 16
  480. >>> # })
  481. >>> # return empty_like
  482. """
  483. from triton.runtime.autotuner import Autotuner
  484. from triton.runtime.jit import JITFunction
  485. from torch._higher_order_ops.triton_kernel_wrap import TraceableTritonKernelWrapper
  486. if not isinstance(triton_kernel, (JITFunction, Autotuner)):
  487. raise RuntimeError(
  488. "wrap_triton only works on functions annotated with triton.jit or triton.autotune"
  489. )
  490. if not is_wrap_triton_enabled():
  491. return triton_kernel
  492. return TraceableTritonKernelWrapper(triton_kernel, None, None)