| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577 |
- import ast
- import contextlib
- import inspect
- import logging
- import threading
- from collections.abc import Callable, Generator, Iterable
- from typing import Any, Optional, Union
- from torch.utils._exposed_in import exposed_in
- from .custom_ops import custom_op, CustomOpDef
- from .infer_schema import infer_schema
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Registry mapping a triton_op's name to the triton kernels discovered in its
# implementation. triton_op() writes an entry at registration time and
# get_triton_kernels_for_op() reads from it.
triton_ops_to_kernels: dict[str, list[object]] = {}
def get_triton_kernels_for_op(name: str) -> list[object]:
    """Look up the triton kernels recorded for the op called ``name``.

    Args:
        name: The op name used as the key in ``triton_ops_to_kernels``.

    Returns:
        The list stored for ``name``, or a new empty list when nothing has
        been recorded under that name.
    """
    try:
        return triton_ops_to_kernels[name]
    except KeyError:
        return []
def get_inner_triton_kernels(fn: Callable[..., Any]) -> list[object]:
    """
    Inspect the source of an arbitrary callable passed to torch._library.triton_op,
    and grab all of the triton kernels that are wrapped inside of it.

    This function traces local variable assignments to handle patterns like:

        kernel_fn = _my_kernel  # global JITFunction
        wrapped = some_wrapper(kernel_fn)
        capture_triton(wrapped)[grid](...)

    It also recursively analyzes called functions to find triton kernels hidden
    behind helper function calls.

    That said, it is best effort. There are cases (e.g., recursion > MAX_RECURSION_DEPTH)
    that are not accounted for, so keep that in mind.

    Args:
        fn: The callable whose source code is analyzed.

    Returns:
        The kernels that were found, de-duplicated by object identity. The
        list is empty when triton cannot be imported, when the source of
        ``fn`` is unavailable, or when that source cannot be parsed.
    """

    # prevent infinite recursion
    MAX_RECURSION_DEPTH = 5

    def find_triton_kernels(
        fn: Callable[..., Any],
        visited_fns: set[int] | None = None,
        depth: int = 0,
    ) -> list[object]:
        try:
            from triton.runtime.autotuner import Autotuner
            from triton.runtime.jit import JITFunction
        except ImportError:
            logger.warning("Triton not available, find_triton_kernels = []")
            return []

        # unwrap decorated fn's (e.g., @lru_cache) to get the original.
        # inspect.unwrap raises ValueError on a cycle in the __wrapped__ chain;
        # in that case analyze fn as-is, matching the other unwrap call sites
        # in this function.
        try:
            fn = inspect.unwrap(fn)
        except ValueError:
            pass

        # init visited set and check for cycles/depth limit
        if visited_fns is None:
            visited_fns = set()

        fn_id = id(fn)
        if fn_id in visited_fns:
            return []

        if depth > MAX_RECURSION_DEPTH:
            logger.debug(
                "reached max recursion depth (%s) in find_triton_kernels",
                MAX_RECURSION_DEPTH,
            )
            return []

        visited_fns.add(fn_id)

        try:
            source = inspect.getsource(fn)
        except (OSError, TypeError):
            return []  # Source code not available

        from torch._inductor.utils import IndentedBuffer

        buffer = IndentedBuffer()
        buffer.splice(source, strip=True)
        try:
            tree = ast.parse(buffer.getrawvalue())
        except SyntaxError:
            # The retrieved source is not guaranteed to be a standalone,
            # parseable unit. This analysis is best effort, so report no
            # kernels instead of failing the caller.
            logger.debug(
                "failed to parse source in find_triton_kernels", exc_info=True
            )
            return []

        # Visitor to collect function calls, assignments, and triton kernels
        class Visitor(ast.NodeVisitor):
            def __init__(self) -> None:
                self.triton_kernels: list[Any] = []
                # track local variable assignments: var_name -> list of RHS expressions
                self.assignments: dict[str, list[ast.expr]] = {}
                # track function calls
                self.called_functions: list[str] = []
                # track return statement expressions
                self.return_exprs: list[ast.expr] = []

            def visit_Assign(self, node: ast.Assign) -> None:
                for target in node.targets:
                    if isinstance(target, ast.Name):
                        self.assignments.setdefault(target.id, []).append(node.value)
                self.generic_visit(node)

            def visit_Return(self, node: ast.Return) -> None:
                if node.value is not None:
                    self.return_exprs.append(node.value)
                self.generic_visit(node)

            def visit_Call(self, node: ast.Call) -> None:
                triton_func_names = ("capture_triton", "wrap_triton")
                if isinstance(node.func, ast.Attribute):
                    attr = node.func
                    if isinstance(attr.value, ast.Attribute):
                        if (
                            # torch._library.capture_triton / torch._library.wrap_triton
                            isinstance(attr.value.value, ast.Name)
                            and attr.value.value.id == "torch"
                            and attr.value.attr == "_library"
                            and attr.attr in triton_func_names
                        ):
                            if node.args and isinstance(node.args[0], ast.Name):
                                self.triton_kernels.append(node.args[0].id)
                        elif (
                            # torch.ops.<namespace>.<op>: recorded as
                            # "<namespace>::<op>" for the OPDEFS lookup below
                            isinstance(attr.value.value, ast.Attribute)
                            and isinstance(attr.value.value.value, ast.Name)
                            and attr.value.value.value.id == "torch"
                            and attr.value.value.attr == "ops"
                        ):
                            self.called_functions.append(
                                f"{attr.value.attr}::{attr.attr}"
                            )
                # Catch capture_triton, wrap_triton that's been
                # imported directly
                elif isinstance(node.func, ast.Name):
                    if node.func.id in triton_func_names:
                        if node.args and isinstance(node.args[0], ast.Name):
                            self.triton_kernels.append(node.args[0].id)
                    else:
                        # track regular function calls for recursive analysis
                        self.called_functions.append(node.func.id)

                self.generic_visit(node)

        collector = Visitor()
        collector.visit(tree)

        def extract_names_from_expr(expr: ast.expr) -> list[str]:
            """Extract all Name references from an AST expression."""
            names: list[str] = []

            class NameExtractor(ast.NodeVisitor):
                def visit_Name(self, node: ast.Name) -> None:
                    names.append(node.id)

                def visit_Call(self, node: ast.Call) -> None:
                    # for function calls, visit the function and all args
                    self.generic_visit(node)

            NameExtractor().visit(expr)
            return names

        def resolve_to_kernel(obj: object) -> object | None:
            """Check if obj is a triton kernel or wrapper and return the kernel."""
            if isinstance(obj, (JITFunction, Autotuner)):
                return obj
            # handle wrappers that have a .fn attribute pointing to JITFunction
            if callable(obj) and hasattr(obj, "fn"):
                inner = obj.fn
                if isinstance(inner, JITFunction):
                    return inner
            return None

        def build_namespace(func_obj: object) -> dict[str, Any]:
            """Build a combined namespace from a function's globals and closures."""
            # unwrap decorated fns (e.g., @lru_cache)
            if callable(func_obj):
                try:
                    func_obj = inspect.unwrap(func_obj)
                except ValueError:
                    pass

            if not callable(func_obj) or not hasattr(func_obj, "__code__"):
                return {}

            func_closure_vars = inspect.getclosurevars(func_obj)

            # Layer the scopes from weakest to strongest binding so that a
            # later update shadows an earlier one the way Python resolves
            # names: builtins < module globals < closure (nonlocal) variables.
            namespace: dict[str, Any] = {}
            namespace.update(func_closure_vars.builtins)
            if hasattr(func_obj, "__globals__"):
                namespace.update(func_obj.__globals__)
            namespace.update(func_closure_vars.globals)
            namespace.update(func_closure_vars.nonlocals)

            return namespace

        all_names = build_namespace(fn)

        def resolve_names_to_kernels(
            names: list[str],
            namespace: dict[str, Any],
            assignments: dict[str, list[ast.expr]] | None = None,
            visited: set[str] | None = None,
        ) -> list[object]:
            """
            Resolve a list of names to triton kernels using the given namespace.

            Names found in ``namespace`` are resolved directly, or by analyzing
            the callable they refer to. Other names are traced through
            ``assignments`` (local variable -> RHS expressions). ``visited``
            holds the names already handled so each is resolved at most once.
            """
            if visited is None:
                visited = set()

            results: list[object] = []

            for name in names:
                if name in visited:
                    continue
                visited.add(name)

                if name in namespace:
                    obj = namespace[name]
                    kernel = resolve_to_kernel(obj)
                    if kernel is not None:
                        results.append(kernel)
                        continue

                    # recurse into callable objects (factory fn's),
                    # unwrapping decorators if applicable
                    if callable(obj):
                        try:
                            unwrapped = inspect.unwrap(obj)
                        except ValueError:
                            unwrapped = obj
                        if hasattr(unwrapped, "__code__"):
                            nested = find_triton_kernels(
                                unwrapped, visited_fns, depth + 1
                            )
                            if nested:
                                results.extend(nested)
                                continue

                    logger.debug("failed to resolve %s to a triton kernel", name)
                elif assignments is not None and name in assignments:
                    # trace through local assignments
                    for rhs_expr in assignments[name]:
                        referenced = extract_names_from_expr(rhs_expr)
                        traced = resolve_names_to_kernels(
                            referenced, namespace, assignments, visited
                        )
                        results.extend(traced)
                else:
                    logger.debug("%s not found in namespace or assignments", name)

            return results

        # resolve kernel names, tracing through local variables if needed
        resolved: list[object] = []
        seen_ids: set[int] = set()

        # Candidates are the names passed to capture_triton/wrap_triton plus
        # every name referenced by a return statement.
        names_to_resolve: list[str] = list(collector.triton_kernels)
        for expr in collector.return_exprs:
            names_to_resolve.extend(extract_names_from_expr(expr))

        for name in names_to_resolve:
            traced_objects = resolve_names_to_kernels(
                [name], all_names, collector.assignments
            )
            for obj in traced_objects:
                obj_id = id(obj)
                if obj_id not in seen_ids:
                    seen_ids.add(obj_id)
                    resolved.append(obj)

        for func_name in collector.called_functions:
            func_obj = all_names.get(func_name)
            if func_obj is None:
                # not a visible Python name: fall back to the custom op registry
                from torch._library.custom_ops import OPDEFS

                if func_name in OPDEFS:
                    func_obj = OPDEFS[func_name]._abstract_fn

            # skip if not a callable or if it's a triton kernel itself
            if func_obj is None or not callable(func_obj):
                continue

            # skip built-in functions and C extensions (they can't contain triton kernels)
            if not hasattr(func_obj, "__code__"):
                continue

            try:
                nested_kernels = find_triton_kernels(func_obj, visited_fns, depth + 1)
                for kernel in nested_kernels:
                    kernel_id = id(kernel)
                    if kernel_id not in seen_ids:
                        seen_ids.add(kernel_id)
                        resolved.append(kernel)
            except Exception:
                # best effort: a helper that cannot be analyzed must not
                # abort the analysis of the remaining calls
                logger.debug(
                    "failed to analyze called function %s", func_name, exc_info=True
                )

        return resolved

    return find_triton_kernels(fn)
@exposed_in("torch.library")
def triton_op(
    name: str,
    fn: Optional[Callable] = None,
    /,
    *,
    mutates_args: Union[str, Iterable[str]],
    schema: Optional[str] = None,
) -> Callable:
    """Create a custom operator whose implementation is backed by 1+ triton kernels.

    This is a more structured way of using triton kernels with PyTorch.
    Prefer using triton kernels with no ``torch.library`` custom operator wrappers
    (like :func:`torch.library.custom_op`, :func:`torch.library.triton_op`) because
    that is simpler;
    only use :func:`torch.library.custom_op`/:func:`torch.library.triton_op` if you
    want to create an operator that behaves like PyTorch built-in operators.
    For example, you may use a ``torch.library`` wrapper API to define the
    behavior of the triton kernel when passed a tensor subclass or under
    a TorchDispatchMode.

    Use :func:`torch.library.triton_op` instead of :func:`torch.library.custom_op`
    when the implementation
    consists of 1+ triton kernels. :func:`torch.library.custom_op` treats
    custom operators as opaque (:func:`torch.compile` and
    :func:`torch.export.export` will never trace into them), but ``triton_op``
    makes the implementation visible to these subsystems, allowing them
    to optimize the triton kernel(s).

    Note that ``fn`` must only consist of calls to PyTorch-understood
    operators and triton kernels. Any triton kernels called inside ``fn``
    must be wrapped in a call to :func:`torch.library.wrap_triton`.

    Args:
        name (str): A name for the custom op that looks like "{namespace}::{name}",
            e.g. "mylib::my_linear". The name is used as the op's stable identifier
            in PyTorch subsystems (e.g. torch.export, FX graphs).
            To avoid name collisions, please use your project name as the namespace;
            e.g. all custom ops in pytorch/fbgemm use "fbgemm" as the namespace.
        mutates_args (Iterable[str] or "unknown"): The names of args that the function mutates.
            This MUST be accurate, otherwise, the behavior is undefined. If "unknown",
            it pessimistically assumes that all inputs to the operator are being mutated.
        schema (str | None): A schema string for the operator. If None
            (recommended) we'll infer a schema for the operator from its type
            annotations. We recommend letting us infer a schema unless you
            have a specific reason not to.
            Example: "(Tensor x, int y) -> (Tensor, Tensor)".

    Returns:
        The custom op definition when ``fn`` is given; otherwise a decorator
        that takes the function and returns the custom op definition.

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import torch
        >>> from torch.library import triton_op, wrap_triton
        >>>
        >>> import triton
        >>> from triton import language as tl
        >>>
        >>> @triton.jit
        >>> def add_kernel(
        >>>     in_ptr0,
        >>>     in_ptr1,
        >>>     out_ptr,
        >>>     n_elements,
        >>>     BLOCK_SIZE: "tl.constexpr",
        >>> ):
        >>>     pid = tl.program_id(axis=0)
        >>>     block_start = pid * BLOCK_SIZE
        >>>     offsets = block_start + tl.arange(0, BLOCK_SIZE)
        >>>     mask = offsets < n_elements
        >>>     x = tl.load(in_ptr0 + offsets, mask=mask)
        >>>     y = tl.load(in_ptr1 + offsets, mask=mask)
        >>>     output = x + y
        >>>     tl.store(out_ptr + offsets, output, mask=mask)
        >>>
        >>> @triton_op("mylib::add", mutates_args={})
        >>> def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        >>>     output = torch.empty_like(x)
        >>>     n_elements = output.numel()
        >>>
        >>>     def grid(meta):
        >>>         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        >>>
        >>>     # NB: we need to wrap the triton kernel in a call to wrap_triton
        >>>     wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16)
        >>>     return output
        >>>
        >>> @torch.compile
        >>> def f(x, y):
        >>>     return add(x, y)
        >>>
        >>> x = torch.randn(3, device="cuda")
        >>> y = torch.randn(3, device="cuda")
        >>>
        >>> z = f(x, y)
        >>> assert torch.allclose(z, x + y)

    """

    def dec(fn: Callable[..., object]) -> CustomOpDef:
        def backend_fn(*args, **kwargs):  # type: ignore[no-untyped-def]
            # Optimization: we're passing regular Tensors into the triton kernel, so
            # no need to go through HOP dispatch
            with set_wrap_triton_enabled(False):
                return fn(*args, **kwargs)

        # Honor a user-supplied schema; only infer one when none was given.
        # The inference runs on ``fn`` because ``backend_fn`` only exposes a
        # (*args, **kwargs) signature.
        op_schema = (
            schema
            if schema is not None
            else infer_schema(fn, mutates_args=mutates_args)
        )
        result = custom_op(
            name,
            backend_fn,
            mutates_args=mutates_args,
            schema=op_schema,
        )
        from .._subclasses.functional_tensor import FunctionalTensorMode

        # We require that the user pass us a function that is make_fx traceable,
        # so we can just register it as the Fake/meta kernel.
        result.register_fake(fn)

        # We decompose the operator when FunctionalTensorMode is active.
        # The goal is to decompose the operator in AOTDispatcher.
        # - With torch.compile, this means that the backend (usually Inductor)
        #   can see a call to the triton kernel(s) and so it can directly optimize
        #   them by inlining them into the lowering process.
        def functional_decomp(  # type: ignore[no-untyped-def]
            mode, op, types, args, kwargs
        ):
            # NOTE [Export custom triton op]
            # For torch.export (strict and non-strict), we don't do functional decomposition.
            # Instead, we preserve the custom triton ops as custom ops. This is because we want
            # the exported program to be high-level and serializable. If we decompose
            # the custom op to a functional hop and make it a node in exported program,
            # we need to figure out ways of serializing the hop and its arguments, which can be triton.jited
            # functions and triton dtypes. This is undesirable because:
            # - it can be tedious to maintain a layer that serializes the jited function (e.g. with a string) and dtypes.
            # - exported program will contain the implementation detail (e.g. triton source code) for a specific
            #   backend (GPU), which is probably at a wrong level of abstraction.
            # - changes to triton or the serialization logic for triton arguments can be BC breaking
            #
            # In the short term, we expect users to have a separate aot_compile stage that compiles the exported program
            # into a Cubin file on the same machine that users call export, which does autotuning and removes triton
            # dependency and serve the model with Cubin. This guarantees that triton changes won't break BC.
            # In the long term, we may export multiple cubins for the triton op directly
            from torch.export._trace import custom_triton_ops_decomposition_disabled

            if custom_triton_ops_decomposition_disabled():
                return mode.__torch_dispatch__(op, types, args, kwargs)
            else:
                # TODO: https://github.com/pytorch/pytorch/issues/160333
                # We should deduplicate the unrecognized_types logic.
                import torch._subclasses

                unrecognized_types = [
                    t
                    for t in types
                    if not issubclass(t, torch._subclasses.FakeTensor)
                    and t
                    not in [
                        torch.Tensor,
                        torch._subclasses.functional_tensor.FunctionalTensor,
                    ]
                ]

                if unrecognized_types:
                    return NotImplemented
                with mode:
                    return fn(*args, **kwargs)

        # Record which triton kernels this op uses, keyed by op name, so that
        # get_triton_kernels_for_op() can report them later.
        triton_kernels = get_inner_triton_kernels(fn)
        triton_ops_to_kernels[name] = triton_kernels
        result.register_torch_dispatch(FunctionalTensorMode, functional_decomp)
        return result

    # Support both ``triton_op(name, fn, ...)`` and ``@triton_op(name, ...)``.
    if fn is None:
        return dec
    else:
        return dec(fn)
# Thread-local holder for the wrap_triton switch. The flag itself is stored on
# its ``value`` attribute by set_wrap_triton_enabled().
wrap_triton_enabled = threading.local()
# Value reported by is_wrap_triton_enabled() when no ``value`` has been set on
# the thread-local.
wrap_triton_enabled_default = True
@contextlib.contextmanager
def set_wrap_triton_enabled(enabled: bool) -> Generator[None, None, None]:
    """Temporarily choose whether wrap_triton dispatches via HOP.

    When enabled, kernels passed to wrap_triton dispatch via HOP; when
    disabled, wrap_triton hands back the triton kernel itself so it executes
    directly.

    We have this switch because eager-mode performance of HOP dispatch is slow
    enough to matter (~1ms) and we know that wrap_triton isn't necessary in
    some situations (eager-mode with regular Tensors)

    Args:
        enabled: The value of the switch inside the ``with`` block.

    Yields:
        None. The previous value of the switch is restored on exit, including
        when the body raises.
    """
    # Save and set the state before entering the try block so the finally
    # clause can never reference an unbound ``prev``.
    prev = is_wrap_triton_enabled()
    wrap_triton_enabled.value = enabled
    try:
        yield
    finally:
        wrap_triton_enabled.value = prev
def is_wrap_triton_enabled() -> bool:
    """Report the current state of the wrap_triton switch.

    Returns:
        The ``value`` stored on the ``wrap_triton_enabled`` thread-local, or
        ``wrap_triton_enabled_default`` when no value has been stored.
    """
    try:
        return wrap_triton_enabled.value
    except AttributeError:
        # nothing stored on the thread-local yet
        return wrap_triton_enabled_default
def capture_triton(triton_kernel: Callable, /) -> Any:
    """Former name of wrap_triton; this API has been renamed to wrap_triton.

    Forwards ``triton_kernel`` to :func:`wrap_triton` and returns its result
    unchanged.
    """
    wrapped_kernel = wrap_triton(triton_kernel)
    return wrapped_kernel
@exposed_in("torch.library")
def wrap_triton(triton_kernel: Callable, /) -> Any:
    """Wrap a triton kernel so that it can be captured into a graph via
    make_fx or non-strict ``torch.export``.

    These technologies perform Dispatcher-based tracing (via
    ``__torch_dispatch__``) and cannot see calls to raw triton kernels.
    The callable returned by ``wrap_triton`` stands in for the triton kernel
    and can actually be traced into a graph.

    Please use this API together with :func:`torch.library.triton_op`.

    Args:
        triton_kernel: A function annotated with ``triton.jit`` or
            ``triton.autotune``.

    Returns:
        A traceable wrapper around ``triton_kernel``, or ``triton_kernel``
        itself when wrapping is currently disabled.

    Raises:
        RuntimeError: If ``triton_kernel`` is neither a triton ``JITFunction``
            nor an ``Autotuner``.

    Examples:

        >>> # xdoctest: +SKIP
        >>> import torch
        >>> import triton
        >>> from triton import language as tl
        >>> from torch.fx.experimental.proxy_tensor import make_fx
        >>> from torch.library import wrap_triton
        >>>
        >>> @triton.jit
        >>> def add_kernel(
        >>>     in_ptr0,
        >>>     in_ptr1,
        >>>     out_ptr,
        >>>     n_elements,
        >>>     BLOCK_SIZE: "tl.constexpr",
        >>> ):
        >>>     pid = tl.program_id(axis=0)
        >>>     block_start = pid * BLOCK_SIZE
        >>>     offsets = block_start + tl.arange(0, BLOCK_SIZE)
        >>>     mask = offsets < n_elements
        >>>     x = tl.load(in_ptr0 + offsets, mask=mask)
        >>>     y = tl.load(in_ptr1 + offsets, mask=mask)
        >>>     output = x + y
        >>>     tl.store(out_ptr + offsets, output, mask=mask)
        >>>
        >>> def add(x, y):
        >>>     output = torch.empty_like(x)
        >>>     n_elements = output.numel()
        >>>
        >>>     def grid_fn(meta):
        >>>         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        >>>
        >>>     wrap_triton(add_kernel)[grid_fn](x, y, output, n_elements, 16)
        >>>     return output
        >>>
        >>> x = torch.randn(3, device="cuda")
        >>> y = torch.randn(3, device="cuda")
        >>> gm = make_fx(add)(x, y)
        >>> print(gm.code)
        >>> # def forward(self, x_1, y_1):
        >>> #     empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False)
        >>> #     triton_kernel_wrapper_mutation_proxy = triton_kernel_wrapper_mutation(
        >>> #         kernel_idx = 0, constant_args_idx = 0,
        >>> #         grid = [(1, 1, 1)], kwargs = {
        >>> #             'in_ptr0': x_1, 'in_ptr1': y_1, 'out_ptr': empty_like,
        >>> #             'n_elements': 3, 'BLOCK_SIZE': 16
        >>> #         })
        >>> #     return empty_like

    """
    from triton.runtime.autotuner import Autotuner
    from triton.runtime.jit import JITFunction

    from torch._higher_order_ops.triton_kernel_wrap import TraceableTritonKernelWrapper

    supported_kernel_types = (JITFunction, Autotuner)
    if isinstance(triton_kernel, supported_kernel_types):
        if is_wrap_triton_enabled():
            return TraceableTritonKernelWrapper(triton_kernel, None, None)
        # wrapping is switched off: hand back the kernel untouched
        return triton_kernel

    raise RuntimeError(
        "wrap_triton only works on functions annotated with triton.jit or triton.autotune"
    )
|