link.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. from collections import defaultdict
  2. from pathlib import Path
  3. from typing import Sequence, Union
  4. from dataclasses import dataclass
  5. def _exists(x):
  6. return x is not None
  7. class LinkerError(Exception):
  8. pass
  9. @dataclass
  10. class KernelLinkerMeta:
  11. orig_kernel_name: str
  12. arg_names: Sequence[str]
  13. arg_ctypes: Sequence[str]
  14. sizes: Sequence[Union[int, None]]
  15. sig_hash: str
  16. triton_suffix: str
  17. suffix: str
  18. num_specs: int
  19. """ number of specialized arguments """
  20. class HeaderParser:
  21. def __init__(self) -> None:
  22. import re
  23. # [kernel_name, c signature]
  24. self.linker_directives = re.compile("//[\\s]*tt-linker:[\\s]*([\\w]+):(.+):(.+)")
  25. # [name, hash, suffix]
  26. self.kernel_name = re.compile("^([\\w]+)_([\\w]+)_([\\w]+)$")
  27. # [(type, name)]
  28. self.c_sig = re.compile("[\\s]*(\\w+)\\s(\\w+)[,]?")
  29. # [d|c]
  30. self.arg_suffix = re.compile("[c,d]")
  31. self.kernels = defaultdict(list)
  32. def extract_linker_meta(self, header: str):
  33. for ln in header.splitlines():
  34. if ln.startswith("//"):
  35. m = self.linker_directives.match(ln)
  36. if _exists(m):
  37. ker_name, c_sig, algo_info = m.group(1), m.group(2), m.group(3)
  38. name, sig_hash, suffix = self._match_name(ker_name)
  39. c_types, arg_names = self._match_c_sig(c_sig)
  40. num_specs, sizes = self._match_suffix(suffix, c_sig)
  41. self._add_kernel(
  42. "_".join([name, algo_info]),
  43. KernelLinkerMeta(
  44. orig_kernel_name=name,
  45. arg_names=arg_names,
  46. arg_ctypes=c_types,
  47. sizes=sizes,
  48. sig_hash=sig_hash,
  49. triton_suffix=suffix,
  50. suffix=suffix,
  51. num_specs=num_specs,
  52. ),
  53. )
  54. def _match_name(self, ker_name: str):
  55. m = self.kernel_name.match(ker_name)
  56. if _exists(m):
  57. name, sig_hash, suffix = m.group(1), m.group(2), m.group(3)
  58. return name, sig_hash, suffix
  59. raise LinkerError(f"{ker_name} is not a valid kernel name")
  60. def _match_c_sig(self, c_sig: str):
  61. m = self.c_sig.findall(c_sig)
  62. if len(m):
  63. tys, args = [], []
  64. for ty, arg_name in m:
  65. tys.append(ty)
  66. args.append(arg_name)
  67. return tys, args
  68. raise LinkerError(f"{c_sig} is not a valid argument signature")
  69. def _match_suffix(self, suffix: str, c_sig: str):
  70. args = c_sig.split(",")
  71. s2i = {"c": 1, "d": 16}
  72. num_specs = 0
  73. sizes = []
  74. # scan through suffix, first find the index,
  75. # then see if it is followed by d or c
  76. for i in range(len(args)):
  77. pos = suffix.find(str(i))
  78. if pos == -1:
  79. raise LinkerError(f"{suffix} is not a valid kernel suffix")
  80. pos += len(str(i))
  81. if self.arg_suffix.match(suffix, pos):
  82. num_specs += 1
  83. sizes.extend([None] * (i - len(sizes)))
  84. sizes.append(s2i[suffix[pos]])
  85. pos += 1
  86. if i < len(args) - 1:
  87. suffix = suffix[pos:]
  88. else:
  89. sizes.extend([None] * (len(args) - len(sizes)))
  90. return num_specs, sizes
  91. def _add_kernel(self, name: str, ker: KernelLinkerMeta):
  92. if name in self.kernels:
  93. last: KernelLinkerMeta = self.kernels[name][-1]
  94. for cur, new_ in zip(last.arg_ctypes, ker.arg_ctypes):
  95. if cur != new_:
  96. raise LinkerError(
  97. f"Mismatched signature for kernel {name}: \n\texisting sig is: {','.join(last.arg_ctypes)}\n\tcurrent is: {','.join(ker.arg_ctypes)}"
  98. )
  99. self.kernels[name].append(ker)
  100. def gen_signature_with_full_args(m):
  101. return ", ".join([f"{ty} {arg}" for ty, arg in zip(m.arg_ctypes, m.arg_names)])
  102. def gen_signature(m):
  103. arg_types = [ty for ty, hint in zip(m.arg_ctypes, m.sizes) if hint != 1]
  104. arg_names = [arg for arg, hint in zip(m.arg_names, m.sizes) if hint != 1]
  105. sig = ", ".join([f"{ty} {arg}" for ty, arg in zip(arg_types, arg_names)])
  106. return sig
  107. # generate declarations of kernels with meta-parameter and constant values
  108. def make_algo_decls(name: str, metas: Sequence[KernelLinkerMeta]) -> str:
  109. return f"""
  110. CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])});
  111. void load_{name}();
  112. void unload_{name}();
  113. """
  114. # generate declarations of kernels with meta-parameter and constant values
  115. def make_global_decl(meta: KernelLinkerMeta) -> str:
  116. return f"""
  117. CUresult {meta.orig_kernel_name}_default(CUstream stream, {gen_signature_with_full_args(meta)});
  118. CUresult {meta.orig_kernel_name}(CUstream stream, {gen_signature_with_full_args(meta)}, int algo_id);
  119. void load_{meta.orig_kernel_name}();
  120. void unload_{meta.orig_kernel_name}();
  121. """
  122. # generate dispatcher function for kernels with different meta-parameter and constant values
  123. def make_default_algo_kernel(meta: KernelLinkerMeta) -> str:
  124. src = f"CUresult {meta.orig_kernel_name}_default(CUstream stream, {gen_signature_with_full_args(meta)}){{\n"
  125. src += (f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n")
  126. src += "}\n"
  127. return src
  128. # generate dispatcher function for kernels with different integer value hints
  129. def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) -> str:
  130. src = f"// launcher for: {name}\n"
  131. for meta in sorted(metas, key=lambda m: -m.num_specs):
  132. src += f"CUresult {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(CUstream stream, {gen_signature(meta)});\n"
  133. src += "\n"
  134. src += (f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{")
  135. src += "\n"
  136. for meta in sorted(metas, key=lambda m: -m.num_specs):
  137. cond_fn = ( #
  138. lambda val, hint: f"({val} % {hint} == 0)" #
  139. if hint == 16 #
  140. else f"({val} == {hint})" #
  141. if hint == 1 #
  142. else None)
  143. conds = " && ".join([ #
  144. cond_fn(val, hint) #
  145. for val, hint in zip(meta.arg_names, meta.sizes) #
  146. if hint is not None
  147. ])
  148. src += (f" if ({conds})\n" if any(meta.sizes) else "if (1)\n"
  149. ) # Edge case where no specializations hence no dispatching required
  150. arg_names = [arg for arg, hint in zip(meta.arg_names, meta.sizes) if hint != 1]
  151. src += f" return {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(stream, {', '.join(arg_names)});\n"
  152. src += "\n"
  153. src += " return CUDA_ERROR_INVALID_VALUE;\n"
  154. src += "}\n"
  155. for mode in ["load", "unload"]:
  156. src += f"\n// {mode} for: {name}\n"
  157. for meta in sorted(metas, key=lambda m: -m.num_specs):
  158. src += f"void {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n"
  159. src += f"void {mode}_{name}() {{"
  160. src += "\n"
  161. for meta in sorted(metas, key=lambda m: -m.num_specs):
  162. src += (f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n")
  163. src += "}\n"
  164. return src
  165. # generate dispatcher function for kernels with different meta-parameter and constant values
  166. def make_kernel_meta_const_dispatcher(meta: KernelLinkerMeta) -> str:
  167. src = f"CUresult {meta.orig_kernel_name}(CUstream stream, {gen_signature_with_full_args(meta)}, int algo_id){{\n"
  168. src += f" assert (algo_id < (int)sizeof({meta.orig_kernel_name}_kernels));\n"
  169. src += f" return {meta.orig_kernel_name}_kernels[algo_id](stream, {', '.join(meta.arg_names)});\n"
  170. src += "}\n"
  171. return src
  172. # generate definition of function pointers of kernel dispatchers based on meta-parameter and constant values
  173. def make_func_pointers(names: str, meta: KernelLinkerMeta) -> str:
  174. # the table of hint dispatchers
  175. src = f"typedef CUresult (*kernel_func_t)(CUstream stream, {gen_signature_with_full_args(meta)});\n"
  176. src += f"kernel_func_t {meta.orig_kernel_name}_kernels[] = {{\n"
  177. for name in names:
  178. src += f" {name},\n"
  179. src += "};\n"
  180. return src
  181. # generate definition for load/unload functions for kernels with different meta-parameter and constant values
  182. def make_kernel_load_def(names: str, meta: KernelLinkerMeta) -> str:
  183. src = ""
  184. for mode in ["load", "unload"]:
  185. src += f"void {mode}_{meta.orig_kernel_name}(void){{\n"
  186. for name in names:
  187. src += f" {mode}_{name}();\n"
  188. src += "}\n\n"
  189. return src
  190. def make_get_num_algos_decl(meta: KernelLinkerMeta) -> str:
  191. src = f"int {meta.orig_kernel_name}_get_num_algos(void);"
  192. return src
  193. def make_get_num_algos_def(meta: KernelLinkerMeta) -> str:
  194. src = f"int {meta.orig_kernel_name}_get_num_algos(void){{\n"
  195. src += f" return (int)(sizeof({meta.orig_kernel_name}_kernels) / sizeof({meta.orig_kernel_name}_kernels[0]));\n"
  196. src += "}\n"
  197. return src
# Description text for the command-line interface; it is passed to
# ArgumentParser(description=...) when the module runs as a script.
desc = """
Triton ahead-of-time linker:
This program takes in header files generated by compile.py, and generates a
single entry-point responsible for dispatching the user's input to the right
kernel given the specializations that were compiled.
Example usage:
python link.py /path/to/headers/*.h -o kernel_name
"""
  206. if __name__ == "__main__":
  207. from argparse import ArgumentParser
  208. parser = ArgumentParser(description=desc)
  209. parser.add_argument(
  210. "headers",
  211. nargs="+",
  212. help="Paths to header files to link. Must include linker directive annotations (autogenerated by ttc)",
  213. )
  214. parser.add_argument("--out", "-o", type=Path, help="Out filename")
  215. parser.add_argument(
  216. "--prefix",
  217. type=str,
  218. default="",
  219. help="String to prefix kernel dispatcher names",
  220. )
  221. args = parser.parse_args()
  222. # metadata
  223. parser = HeaderParser()
  224. includes = []
  225. for header in args.headers:
  226. h_path = Path(header)
  227. h_str = h_path.read_text()
  228. includes.append(h_path.name)
  229. parser.extract_linker_meta(h_str)
  230. # generate headers
  231. algo_decls = [make_algo_decls(name, meta) for name, meta in parser.kernels.items()]
  232. meta_lists = [meta for name, meta in parser.kernels.items()]
  233. meta = meta_lists[0][0]
  234. get_num_algos_decl = make_get_num_algos_decl(meta)
  235. global_decl = make_global_decl(meta)
  236. with args.out.with_suffix(".h").open("w") as fp:
  237. out = "#include <cuda.h>\n"
  238. out += "\n".join(algo_decls)
  239. out += "\n"
  240. out += get_num_algos_decl
  241. out += "\n"
  242. out += global_decl
  243. fp.write(out)
  244. # generate source
  245. defs = [make_kernel_hints_dispatcher(name, meta) for name, meta in parser.kernels.items()]
  246. names = [name for name in parser.kernels.keys()]
  247. func_pointers_def = make_func_pointers(names, meta)
  248. meta_const_def = make_kernel_meta_const_dispatcher(meta)
  249. load_unload_def = make_kernel_load_def(names, meta)
  250. get_num_algos_def = make_get_num_algos_def(meta)
  251. default_algo_kernel = make_default_algo_kernel(meta)
  252. with args.out.with_suffix(".c").open("w") as fp:
  253. out = ""
  254. out += "#include <cuda.h>\n"
  255. out += "#include <stdint.h>\n"
  256. out += "#include <assert.h>\n"
  257. out += "\n"
  258. out += "\n".join(defs)
  259. out += "\n"
  260. out += func_pointers_def
  261. out += "\n"
  262. out += get_num_algos_def
  263. out += "\n"
  264. out += meta_const_def
  265. out += "\n"
  266. out += load_unload_def
  267. out += "\n"
  268. out += default_algo_kernel
  269. fp.write(out)