# build_extern.py
import argparse
import os
import subprocess
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
  5. class Symbol:
  6. _name: str
  7. _op_name: str
  8. _ret_type: str
  9. _arg_names: List[str]
  10. _arg_types: List[str]
  11. def __init__(
  12. self,
  13. name: str,
  14. op_name: str,
  15. ret_type: str,
  16. arg_names: List[str],
  17. arg_types: List[str],
  18. ) -> None:
  19. '''
  20. A symbol is a function declaration.
  21. :param name: name of the symbol
  22. :param op_name: name of the operation
  23. :param ret_type: return type of the operation
  24. :param arg_names: names of the arguments
  25. :param arg_types: types of the arguments
  26. '''
  27. self._name = name
  28. self._op_name = op_name
  29. self._ret_type = ret_type
  30. self._arg_names = list(arg_names)
  31. self._arg_types = list(arg_types)
  32. @property
  33. def name(self) -> str:
  34. return self._name
  35. @property
  36. def op_name(self) -> str:
  37. return self._op_name
  38. @property
  39. def ret_type(self) -> str:
  40. return self._ret_type
  41. @property
  42. def arg_names(self) -> List[str]:
  43. return self._arg_names
  44. @property
  45. def arg_types(self) -> List[str]:
  46. return self._arg_types
  47. def convert_type(type_str) -> Optional[str]:
  48. if type_str == "i32":
  49. return "int32"
  50. elif type_str == "u32":
  51. return "uint32"
  52. elif type_str == "i64":
  53. return "int64"
  54. elif type_str == "u64":
  55. return "uint64"
  56. elif type_str == "float":
  57. return "fp32"
  58. elif type_str == "double":
  59. return "fp64"
  60. else:
  61. # ignore other types, such as pointer types
  62. return None
  63. def to_unsigned(type_str) -> str:
  64. if type_str == "int32":
  65. return "uint32"
  66. elif type_str == "int64":
  67. return "uint64"
  68. else:
  69. return type_str
  70. class ExternLibrary(ABC):
  71. _name: str
  72. _path: str
  73. _symbols: Dict[str, Symbol]
  74. _format: bool
  75. _grouping: bool
  76. def __init__(
  77. self,
  78. name: str,
  79. path: str,
  80. format: bool = True,
  81. grouping: bool = True,
  82. ) -> None:
  83. '''
  84. Abstract class for extern library.
  85. :param name: name of the library
  86. :param path: path of the library
  87. :param format: whether to format the generated stub file
  88. '''
  89. self._name = name
  90. self._path = path
  91. self._symbols = {}
  92. self._format = format
  93. self._grouping = grouping
  94. @property
  95. def name(self) -> str:
  96. return self._name
  97. @property
  98. def path(self) -> str:
  99. return self._path
  100. @property
  101. def symbols(self) -> Dict[str, Symbol]:
  102. return self._symbols
  103. @property
  104. def grouping(self) -> bool:
  105. return self._grouping
  106. @abstractmethod
  107. def parse_symbols(self, input_file) -> None:
  108. pass
  109. @abstractmethod
  110. def _output_stubs(self) -> str:
  111. pass
  112. def generate_stub_file(self, output_dir) -> None:
  113. file_str = self._output_stubs()
  114. if file_str is None or len(file_str) == 0:
  115. raise Exception("file_str is empty")
  116. output_file = f"{output_dir}/{self._name}.py"
  117. with open(output_file, "w") as f:
  118. f.write(file_str)
  119. f.close()
  120. if self._format:
  121. subprocess.Popen(["autopep8", "-a", "-r", "-i", output_file], stdout=subprocess.PIPE).communicate()
  122. subprocess.Popen(["isort", output_file], stdout=subprocess.PIPE).communicate()
  123. class Libdevice(ExternLibrary):
  124. _symbol_groups: Dict[str, List[Symbol]]
  125. def __init__(self, path) -> None:
  126. '''
  127. Constructor for Libdevice.
  128. :param path: path of the libdevice library
  129. '''
  130. super().__init__("libdevice", path)
  131. self._symbol_groups = {}
  132. self.is_pure = True
  133. @staticmethod
  134. def _extract_symbol(line) -> Optional[Symbol]:
  135. # Extract symbols from line in the following format:
  136. # "define [internal] <ret_type> @<name>(<arg_types>,)"
  137. entries = line.split("@")
  138. ret_str = entries[0]
  139. func_str = entries[1]
  140. # Get ret_type, skip internal symbols
  141. ret_strs = ret_str.split()
  142. if ret_strs[1] == "internal":
  143. return None
  144. ret_type = convert_type(ret_strs[1])
  145. if ret_type is None:
  146. return None
  147. # Get function name
  148. func_strs = func_str.split("(")
  149. func_name = func_strs[0].replace("@", "")
  150. op_name = func_name.replace("__nv_", "")
  151. if 'ieee' in op_name:
  152. return None
  153. # Get arg_types
  154. arg_strs = func_strs[1].split(",")
  155. arg_types = []
  156. arg_names = []
  157. for i, arg_str in enumerate(arg_strs):
  158. arg_type = convert_type(arg_str.split()[0])
  159. if arg_type is None:
  160. return None
  161. arg_name = 'arg' + str(i)
  162. arg_types.append(arg_type)
  163. arg_names.append(arg_name)
  164. if op_name == "sad":
  165. # Special case for sad, where the last argument is an unsigned int
  166. arg_types[-1] = to_unsigned(arg_types[-1])
  167. elif op_name.startswith("u"):
  168. # LLVM does not differentiate between signed and unsigned integer type.
  169. # We have to convert the types to unsigned
  170. ret_type = to_unsigned(ret_type)
  171. for i, arg_type in enumerate(arg_types):
  172. arg_types[i] = to_unsigned(arg_type)
  173. return Symbol(func_name, op_name, ret_type, arg_names, arg_types)
  174. def _group_symbols(self) -> None:
  175. symbol_set = {}
  176. for symbol in self._symbols.values():
  177. op_name = symbol.op_name
  178. symbol_set[op_name] = symbol
  179. # Group functions together by renaming.
  180. renaming = {
  181. 'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh', 'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn':
  182. 'add_rn', 'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru', 'dadd_rz': 'add_rz', 'fadd_rz':
  183. 'add_rz', 'asinf': 'asin', 'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2', 'atanhf': 'atanh',
  184. 'brevll': 'brev', 'cbrtf': 'cbrt', 'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign', 'cosf': 'cos',
  185. 'coshf': 'cosh', 'cospif': 'cospi', 'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1',
  186. 'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn', 'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru',
  187. 'ddiv_ru': 'div_ru', 'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf', 'erfcf': 'erfc', 'erfcinvf':
  188. 'erfcinv', 'erfcxf': 'erfcx', 'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10', 'exp2f': 'exp2',
  189. 'expm1f': 'expm1', 'fabsf': 'abs', 'fabs': 'abs', 'fast_fdividef': 'fast_dividef', 'fdimf': 'fdim', 'ffsll':
  190. 'ffs', 'floorf': 'floor', 'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn', 'fmaf_ru': 'fma_ru',
  191. 'fmaf_rz': 'fma_rz', 'fmodf': 'fmod', 'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb', 'isinff':
  192. 'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan', 'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn',
  193. 'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint', 'llroundf': 'llround', 'logf': 'log', 'log10f':
  194. 'log10', 'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb', 'umax': 'max', 'llmax': 'max', 'ullmax':
  195. 'max', 'fmaxf': 'max', 'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min', 'fminf': 'min',
  196. 'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd', 'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn',
  197. 'dmul_ru': 'mul_ru', 'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz', 'umul24': 'mul24',
  198. 'umulhi': 'mulhi', 'mul64hi': 'mulhi', 'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf':
  199. 'nextafter', 'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf', 'normcdfinvf': 'normcdfinv',
  200. 'popcll': 'popc', 'powif': 'pow', 'powi': 'pow', 'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd',
  201. 'drcp_rd': 'rcp_rd', 'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru', 'drcp_ru': 'rcp_ru',
  202. 'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz', 'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot',
  203. 'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d', 'roundf': 'round', 'rsqrtf': 'rsqrt',
  204. 'frsqrt_rn': 'rsqrt_rn', 'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit', 'signbitd': 'signbit',
  205. 'sinf': 'sin', 'sinhf': 'sinh', 'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd', 'dsqrt_rd':
  206. 'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn', 'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru',
  207. 'fsqrt_rz': 'sqrt_rz', 'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd', 'fsub_rn': 'sub_rn',
  208. 'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru', 'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz',
  209. 'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc', 'y0f': 'y0', 'y1f': 'y1', 'ynf':
  210. 'yn'
  211. }
  212. for symbol in self._symbols.values():
  213. op_name = symbol.op_name
  214. if op_name in renaming:
  215. op_name = renaming[op_name]
  216. symbol._op_name = op_name
  217. if op_name in self._symbol_groups:
  218. self._symbol_groups[op_name].append(symbol)
  219. else:
  220. self._symbol_groups[op_name] = [symbol]
  221. def parse_symbols(self, input_file) -> None:
  222. if len(self.symbols) > 0:
  223. return
  224. output = subprocess.check_output(["grep", "define", input_file]).decode().splitlines()
  225. for line in output:
  226. symbol = self._extract_symbol(line)
  227. if symbol is None:
  228. continue
  229. self._symbols[symbol.name] = symbol
  230. self._group_symbols()
  231. def _output_stubs(self) -> str:
  232. # Generate python functions in the following format:
  233. # @extern.extern
  234. # def <op_name>(<args>, _builder=None):
  235. # arg_type_symbol_dict = {[arg_type]: {(symbol, ret_type)}}
  236. # return core.extern_elementwise("libdevice", <path>, <args>, <arg_type_symbol_dict>, _builder)
  237. import_str = "from . import core\n"
  238. header_str = ""
  239. func_str = ""
  240. for symbols in self._symbol_groups.values():
  241. func_str += "@core.extern\n"
  242. func_name_str = f"def {symbols[0].op_name}("
  243. for arg_name in symbols[0].arg_names:
  244. func_name_str += f"{arg_name}, "
  245. func_name_str += "_builder=None):\n"
  246. return_str = f"\treturn core.extern_elementwise(\"{self._name}\", libdevice_path(), ["
  247. for arg_name in symbols[0].arg_names:
  248. return_str += f"{arg_name}, "
  249. return_str += "], \n"
  250. arg_type_symbol_dict_str = "{"
  251. for symbol in symbols:
  252. arg_type_symbol_dict_str += "("
  253. for arg_type in symbol.arg_types:
  254. arg_type_symbol_dict_str += f'core.dtype("{arg_type}"),'
  255. ret_type = f'core.dtype("{symbol.ret_type}")'
  256. arg_type_symbol_dict_str += "): (\"" + symbol.name + "\", " + ret_type + "),\n"
  257. arg_type_symbol_dict_str += "}"
  258. return_str += arg_type_symbol_dict_str
  259. return_str += f", is_pure={self.is_pure}"
  260. return_str += ", _builder=_builder)\n"
  261. func_str += func_name_str + return_str + "\n"
  262. file_str = import_str + header_str + func_str
  263. return file_str
  264. class LLVMDisassembler:
  265. _path: str
  266. _ll_file: str
  267. def __init__(self, path) -> None:
  268. '''
  269. Invoke llvm-dis to disassemble the given file.
  270. :param path: path to llvm-dis
  271. '''
  272. self._path = path
  273. self._ll_file = "/tmp/extern_lib.ll"
  274. def disasm(self, lib_path: str) -> None:
  275. subprocess.Popen([self._path, lib_path, "-o", self.ll_file], stdout=subprocess.PIPE).communicate()
  276. @property
  277. def ll_file(self) -> str:
  278. return self._ll_file
  279. @property
  280. def path(self) -> str:
  281. return self._path
  282. extern_libs = ["libdevice"]
  283. def build(
  284. llvm_dis_path: str,
  285. lib_path: str,
  286. lib_name: str,
  287. output_dir: str,
  288. ) -> None:
  289. '''
  290. Interface function to build the library file.
  291. :param llvm_dis_path: path to the llvm-dis binary
  292. :param lib_path: path to the external library file
  293. :param lib_name: name of the library
  294. :param output_dir: path to the output directory
  295. '''
  296. if lib_name == "libdevice":
  297. extern_lib = Libdevice(lib_path)
  298. else:
  299. raise Exception(f"Unknown extern library: {lib_name}")
  300. llvm_disassembler = LLVMDisassembler(llvm_dis_path)
  301. llvm_disassembler.disasm(lib_path)
  302. extern_lib.parse_symbols(llvm_disassembler.ll_file)
  303. extern_lib.generate_stub_file(output_dir)
  304. if __name__ == "__main__":
  305. parser = argparse.ArgumentParser()
  306. parser.add_argument("--llvm-dis", dest="llvm_dis_path", help="Path to llvm-dis", default="llvm-dis")
  307. parser.add_argument("--lib-path", dest="lib_path", help="Path to the extern library")
  308. parser.add_argument("--lib-name", dest="lib_name", help="Name of the extern library")
  309. parser.add_argument("--output", dest="output_dir", help="Output file path", default="/tmp/")
  310. args = parser.parse_args()
  311. build(args.llvm_dis_path, args.lib_path, args.lib_name, args.output_dir)