| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102 |
- from __future__ import annotations
- from triton.compiler.compiler import ASTSource
- from triton.backends.compiler import Language
- from triton.runtime.jit import JITFunction, constexpr_function
- from typing import TypeVar, Optional, Callable, Iterable, Union
- from triton._C.libtriton import ir
# Type variable standing for the Python callable handed to ``jit``.
T = TypeVar("T")

# Public names of this module.
__all__ = [
    "constexpr_function",
    "jit",
]
class GluonASTSource(ASTSource):
    """``ASTSource`` subclass tagged as Gluon, using the ``ttgir`` extension."""

    def __init__(self, fn, signature, constexprs=None, attrs=None) -> None:
        """Run the base initialiser, then record the Gluon language and extension."""
        super().__init__(fn, signature, constexprs, attrs)
        # Assigned once the base initialiser has finished.
        self.language = Language.GLUON
        self.ext = "ttgir"

    def make_ir(self, target, options, codegen_fns, module_map, context):
        """Create a module, set its ``ttg.*`` attributes, and lower ``self.fn`` into it.

        :param target: passed to ``make_backend`` to obtain the backend
        :param options: source of the target name, warp/CTA counts, warp size,
            backend name and ``maxnreg``
        :param codegen_fns: forwarded unchanged to ``ast_to_ttir``
        :param module_map: forwarded unchanged to ``ast_to_ttir``
        :param context: used to construct the IR builder and forwarded to ``ast_to_ttir``
        :return: the value returned by ``ast_to_ttir``
        """
        from triton.compiler.compiler import make_backend
        from triton.compiler.code_generator import ast_to_ttir

        ir_builder = ir.builder(context)
        gluon_module = ir_builder.create_module()

        # Assign module attributes eagerly, as they are needed to verify layouts
        target_name = make_backend(target).get_target_name(options)
        gluon_module.set_attr("ttg.target", ir_builder.get_string_attr(target_name))

        # (module attribute name, field of ``options`` holding its int32 value)
        int32_attrs = (
            ("ttg.num-warps", "num_warps"),
            ("ttg.num-ctas", "num_ctas"),
            ("ttg.threads-per-warp", "warp_size"),
        )
        for attr_name, option_field in int32_attrs:
            gluon_module.set_attr(attr_name, ir_builder.get_int32_attr(getattr(options, option_field)))

        # ``ttg.maxnreg`` is only written for the "cuda" backend, and only when a value is set.
        if options.backend_name == "cuda" and options.maxnreg is not None:
            gluon_module.set_attr("ttg.maxnreg", ir_builder.get_int32_attr(options.maxnreg))

        return ast_to_ttir(
            self.fn,
            self,
            context=context,
            options=options,
            codegen_fns=codegen_fns,
            module_map=module_map,
            module=gluon_module,
        )
class GluonJITFunction(JITFunction[T]):
    """``JITFunction`` subclass that uses ``GluonASTSource`` and reports itself as Gluon."""

    def create_binder(self):
        """Call the base ``create_binder``, then set ``self.ASTSource`` to ``GluonASTSource``.

        :return: whatever the base ``create_binder`` returned
        """
        binder = super().create_binder()
        # Assigned after the base call, so this is the value left on the instance.
        self.ASTSource = GluonASTSource
        return binder

    def is_gluon(self):
        """Return ``True`` unconditionally."""
        return True
def jit(
    fn: Optional[T] = None,
    *,
    version=None,
    repr: Optional[Callable] = None,
    launch_metadata: Optional[Callable] = None,
    do_not_specialize: Optional[Iterable[int | str]] = None,
    do_not_specialize_on_alignment: Optional[Iterable[int | str]] = None,
    debug: Optional[bool] = None,
    noinline: Optional[bool] = None,
) -> Union[GluonJITFunction[T], Callable[[T], JITFunction[T]]]:
    """
    Decorator for JIT-compiling a function using the Triton compiler.

    Usable both bare (``@jit``) and with keyword options (``@jit(debug=True)``):
    when ``fn`` is given it is wrapped immediately, otherwise a decorator that
    performs the wrapping is returned. Every keyword option is forwarded
    unchanged to :code:`GluonJITFunction`.

    :note: When a jit'd function is called, arguments are
        implicitly converted to pointers if they have a :code:`.data_ptr()` method
        and a `.dtype` attribute.

    :note: This function will be compiled and run on the GPU. It will only have access to:

           * python primitives,
           * builtins within the triton package,
           * arguments to this function,
           * other jit'd functions

    :param fn: the function to be jit-compiled
    :type fn: Callable
    """
    # Collected once; the same options apply to whichever function ends up wrapped.
    jit_options = {
        "version": version,
        "do_not_specialize": do_not_specialize,
        "do_not_specialize_on_alignment": do_not_specialize_on_alignment,
        "debug": debug,
        "noinline": noinline,
        "repr": repr,
        "launch_metadata": launch_metadata,
    }

    def wrap(target_fn: T) -> JITFunction[T]:
        assert callable(target_fn)
        return GluonJITFunction(target_fn, **jit_options)

    return wrap if fn is None else wrap(fn)
|