_runtime.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. from __future__ import annotations
  2. from triton.compiler.compiler import ASTSource
  3. from triton.backends.compiler import Language
  4. from triton.runtime.jit import JITFunction, constexpr_function
  5. from typing import TypeVar, Optional, Callable, Iterable, Union
  6. from triton._C.libtriton import ir
  7. T = TypeVar("T")
  8. __all__ = ["constexpr_function", "jit"]
  9. class GluonASTSource(ASTSource):
  10. def __init__(self, fn, signature, constexprs=None, attrs=None) -> None:
  11. super().__init__(fn, signature, constexprs, attrs)
  12. self.language = Language.GLUON
  13. self.ext = "ttgir"
  14. def make_ir(self, target, options, codegen_fns, module_map, context):
  15. from triton.compiler.compiler import make_backend
  16. from triton.compiler.code_generator import ast_to_ttir
  17. builder = ir.builder(context)
  18. module = builder.create_module()
  19. # Assign module attributes eagerly, as they are needed to verify layouts
  20. backend = make_backend(target)
  21. target = backend.get_target_name(options)
  22. module.set_attr("ttg.target", builder.get_string_attr(target))
  23. module.set_attr("ttg.num-warps", builder.get_int32_attr(options.num_warps))
  24. module.set_attr("ttg.num-ctas", builder.get_int32_attr(options.num_ctas))
  25. module.set_attr("ttg.threads-per-warp", builder.get_int32_attr(options.warp_size))
  26. is_cuda = options.backend_name == "cuda"
  27. if is_cuda and options.maxnreg is not None:
  28. module.set_attr("ttg.maxnreg", builder.get_int32_attr(options.maxnreg))
  29. module = ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
  30. module_map=module_map, module=module)
  31. return module
  32. class GluonJITFunction(JITFunction[T]):
  33. def create_binder(self):
  34. result = super().create_binder()
  35. self.ASTSource = GluonASTSource
  36. return result
  37. def is_gluon(self):
  38. return True
  39. def jit(
  40. fn: Optional[T] = None,
  41. *,
  42. version=None,
  43. repr: Optional[Callable] = None,
  44. launch_metadata: Optional[Callable] = None,
  45. do_not_specialize: Optional[Iterable[int | str]] = None,
  46. do_not_specialize_on_alignment: Optional[Iterable[int | str]] = None,
  47. debug: Optional[bool] = None,
  48. noinline: Optional[bool] = None,
  49. ) -> Union[GluonJITFunction[T], Callable[[T], JITFunction[T]]]:
  50. """
  51. Decorator for JIT-compiling a function using the Triton compiler.
  52. :note: When a jit'd function is called, arguments are
  53. implicitly converted to pointers if they have a :code:`.data_ptr()` method
  54. and a `.dtype` attribute.
  55. :note: This function will be compiled and run on the GPU. It will only have access to:
  56. * python primitives,
  57. * builtins within the triton package,
  58. * arguments to this function,
  59. * other jit'd functions
  60. :param fn: the function to be jit-compiled
  61. :type fn: Callable
  62. """
  63. def decorator(fn: T) -> JITFunction[T]:
  64. assert callable(fn)
  65. return GluonJITFunction(
  66. fn,
  67. version=version,
  68. do_not_specialize=do_not_specialize,
  69. do_not_specialize_on_alignment=do_not_specialize_on_alignment,
  70. debug=debug,
  71. noinline=noinline,
  72. repr=repr,
  73. launch_metadata=launch_metadata,
  74. )
  75. if fn is not None:
  76. return decorator(fn)
  77. else:
  78. return decorator