compile.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. import binascii
  2. import hashlib
  3. import importlib.util
  4. import sys
  5. from argparse import ArgumentParser
  6. from dataclasses import dataclass
  7. from pathlib import Path
  8. from typing import List
  9. import triton
  10. import triton.backends
  11. @dataclass
  12. class CompileArgs:
  13. '''
  14. A class to contain arguments from command-line parser.
  15. '''
  16. path: str = ''
  17. kernel_name: str = ''
  18. signature: str = ''
  19. grid: str = ''
  20. target: str | None = None
  21. num_warps: int = 1
  22. num_stages: int = 3
  23. out_name: str | None = None
  24. out_path: Path | None = None
  25. desc = """
  26. Triton ahead-of-time compiler:
  27. This program compiles the kernel with name `kernel-name` in the file at the
  28. provided `path` into self-contained C source-code that embeds the `cubin`
  29. data along with utilities to load, unload and launch the kernel.
  30. signature is provided as a list of (optionally divisibility-hinted) types
  31. or constexpr values, e.g.
  32. `compile.py --kernel-name kernel --signature "*fp32:16, i32:16, 1024, i32" --out-name kernel /path/to/kernel.py`
  33. will compile triton.JITFunction of name `kernel` inside the file `/path/to/kernel.py`.
  34. Said kernel will be specialized such that argument 0, 1 are assumed to be multiple of 16,
  35. and argument 2 is assumed to be a compile-time constant of value 1024, i.e. it won't be part of the generated prototype.
  36. The resulting entry point will have signature
  37. CUresult kernel_{specialization_suffix}(CUstream stream, unsigned gX, unsigned gY, unsigned gZ, float* arg0, int32_t arg1, int32_t arg2)
  38. Different such specialized entry points can be combined using the `linker.py` script.
  39. NOTE: when resolving the scope of /path/to/kernel.py, the file will be executed from within its parent directory with the python interpreter
  40. used to run this `compile.py` script
  41. """
  42. def main():
  43. # command-line arguments
  44. parser = ArgumentParser(description=desc)
  45. parser.add_argument("path",
  46. help="Path to Python source containing desired kernel in its scope. File will be executed.")
  47. parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile",
  48. required=True)
  49. parser.add_argument(
  50. "--target", "-t", type=str, default=None,
  51. help="The target to compile towards, in format of '<backend>:<arch>:<warp-size>'; "
  52. "e.g., 'cuda:80:32', 'hip:gfx942:64'. Default to None, which means using current machine's GPU target")
  53. parser.add_argument("--num-warps", "-w", type=int, default=1, help="Number of warps to launch the kernel")
  54. parser.add_argument("--num-stages", "-ns", type=int, default=3,
  55. help="Number of stages (meta-parameter of the kernel)")
  56. parser.add_argument("--out-name", "-on", type=str, default=None, help="Out name for the compiled kernel")
  57. parser.add_argument("--out-path", "-o", type=Path, default=None, help="Out filename")
  58. parser.add_argument("--signature", "-s", type=str, help="Signature of the kernel", required=True)
  59. parser.add_argument("--grid", "-g", type=str, help="Launch grid of the kernel", required=True)
  60. cli_args = parser.parse_args()
  61. args = CompileArgs(**vars(cli_args)) # A sanity check to ensure class CompileArgs is updated as well.
  62. compile_kernel(args)
  63. def compile_kernel(args: CompileArgs):
  64. out_name = args.out_name if args.out_name else args.kernel_name
  65. out_path = args.out_path if args.out_path else Path(out_name)
  66. # execute python sources and extract functions wrapped in JITFunction
  67. arg_path = Path(args.path)
  68. sys.path.insert(0, str(arg_path.parent))
  69. spec = importlib.util.spec_from_file_location(arg_path.stem, arg_path)
  70. mod = importlib.util.module_from_spec(spec)
  71. spec.loader.exec_module(mod)
  72. kernel = getattr(mod, args.kernel_name)
  73. grid = args.grid.split(",")
  74. assert len(grid) == 3
  75. # validate and parse signature
  76. signature = list(map(lambda s: s.strip(" "), args.signature.split(",")))
  77. def hash_signature(signature: List[str]):
  78. m = hashlib.sha256()
  79. m.update(" ".join(signature).encode())
  80. return m.hexdigest()[:8]
  81. meta_sig = f"warps{args.num_warps}xstages{args.num_stages}"
  82. sig_hash = hash_signature(signature + [meta_sig])
  83. def constexpr(s):
  84. try:
  85. ret = int(s)
  86. return ret
  87. except ValueError:
  88. pass
  89. try:
  90. ret = float(s)
  91. return ret
  92. except ValueError:
  93. pass
  94. return None
  95. hints = {(i, ): constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s}
  96. hints = {k: v for k, v in hints.items() if v is not None}
  97. constants = {kernel.arg_names[i]: constexpr(s) for i, s in enumerate(signature)}
  98. constants = {k: v for k, v in constants.items() if v is not None}
  99. for key, value in hints.items():
  100. if value == 1:
  101. constants[kernel.arg_names[key[0]]] = value
  102. signature = {kernel.arg_names[i]: s.split(":")[0] for i, s in enumerate(signature)}
  103. for key in constants:
  104. signature[key] = 'constexpr'
  105. const_sig = 'x'.join([str(v) for v in constants.values()])
  106. doc_string = [f"{k}={v}" for k, v in constants.items()]
  107. doc_string += [f"num_warps={args.num_warps}", f"num_stages={args.num_stages}"]
  108. # compile ast into cubin
  109. for h in hints.values():
  110. assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}"
  111. attrs = {k: [["tt.divisibility", 16]] for k, v in hints.items() if v == 16}
  112. kernel.create_binder()
  113. src = kernel.ASTSource(fn=kernel, constexprs=constants, signature=signature, attrs=attrs)
  114. target = triton.backends.compiler.GPUTarget(*args.target.split(":")) \
  115. if args.target else triton.runtime.driver.active.get_current_target()
  116. backend = triton.compiler.make_backend(target)
  117. kwargs = {"num_warps": args.num_warps, "num_stages": args.num_stages}
  118. options = backend.parse_options(kwargs)
  119. ccinfo = triton.compile(src, target=target, options=options.__dict__)
  120. if getattr(ccinfo.metadata, "global_scratch_size", 0) > 0:
  121. raise RuntimeError("AOT compiling kernels with global scratch requirements is not yet implemented")
  122. if ccinfo.metadata.profile_scratch_size > 0:
  123. raise RuntimeError("AOT compiling kernels with profile scratch requirements is not yet implemented")
  124. arg_names = []
  125. arg_types = []
  126. arg_names_not_1 = []
  127. arg_types_not_1 = []
  128. for i, arg_name in enumerate(kernel.arg_names):
  129. if arg_name not in constants:
  130. arg_names.append(arg_name)
  131. arg_types.append(signature[arg_name])
  132. arg_names_not_1.append(arg_name)
  133. arg_types_not_1.append(signature[arg_name])
  134. elif hints.get((i, ), None) == 1:
  135. arg_names.append(arg_name)
  136. arg_types.append("i32")
  137. # dump C stub code
  138. suffix = ''
  139. for i, ty in enumerate(signature.values()):
  140. suffix += str(i)
  141. if hints.get((i, ), None) == 1:
  142. suffix += 'c'
  143. if hints.get((i, ), None) == 16:
  144. suffix += 'd'
  145. func_name = '_'.join([out_name, sig_hash, suffix])
  146. asm = ccinfo.asm[backend.binary_ext] # store binary data once
  147. hex_ = str(binascii.hexlify(asm))[2:-1]
  148. ty_to_cpp = triton.runtime.driver.active.map_python_to_cpp_type
  149. params = {
  150. "kernel_name": func_name,
  151. "triton_kernel_name": args.kernel_name,
  152. "bin_size": len(asm),
  153. "bin_data": ", ".join([f"0x{x}{y}" for x, y in zip(hex_[::2], hex_[1::2])]),
  154. "signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names_not_1, arg_types_not_1)]),
  155. "full_signature": ", ".join([f"{ty_to_cpp(ty)} {name}" for name, ty in zip(arg_names, arg_types)]),
  156. "arg_pointers": ", ".join([f"&{arg}" for arg in arg_names_not_1] + ["&global_scratch"] + ["&profile_scratch"]),
  157. "num_args": len(arg_names_not_1) + 2, # +2 for global and profile scratch
  158. "kernel_docstring": doc_string,
  159. "shared": ccinfo.metadata.shared,
  160. "num_warps": args.num_warps,
  161. "algo_info": "_".join([const_sig, meta_sig]),
  162. "gridX": grid[0],
  163. "gridY": grid[1],
  164. "gridZ": grid[2],
  165. "_placeholder": "",
  166. "warp_size": target.warp_size,
  167. }
  168. output_files = []
  169. backend_name = target.backend
  170. template_dir = Path(__file__).parent / "extra" / backend_name
  171. for template_path in template_dir.glob('compile.*'):
  172. ext = template_path.suffix
  173. output_file = out_path.with_suffix(f".{sig_hash}_{suffix}{ext}")
  174. with output_file.open("w") as fp:
  175. fp.write(template_path.read_text().format(**params))
  176. output_files.append(output_file)
  177. return func_name, output_files
# Run the command-line interface only when executed as a script, so importing
# this module (e.g. to call compile_kernel directly) does not run the compiler.
if __name__ == "__main__":
    main()