code_generator.py 71 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639
  1. import ast
  2. import builtins
  3. import contextlib
  4. import copy
  5. import inspect
  6. import re
  7. import warnings
  8. import textwrap
  9. from dataclasses import dataclass
  10. from types import ModuleType
  11. from typing import Any, Callable, Dict, Optional, Tuple, Type, Union, Iterable, List
  12. from .. import knobs, language
  13. from .._C.libtriton import ir, gluon_ir
  14. from ..language import constexpr, str_to_ty, tensor, tuple as tl_tuple
  15. from ..language.core import _unwrap_if_constexpr, base_value, base_type
  16. # ideally we wouldn't need any runtime component
  17. from ..runtime.jit import get_jit_fn_file_line, get_full_name, JITCallable, BoundConstexprFunction, ConstexprFunction, JITFunction
  18. from .._utils import find_paths_if, get_iterable_path, set_iterable_path, is_namedtuple
  19. from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
  20. def check_identifier_legality(name, type):
  21. pattern = r'^[a-zA-Z_][a-zA-Z0-9_]*$'
  22. if not re.match(pattern, name):
  23. raise CompilationError(f"invalid {type} identifier: {name}", name)
  24. return name
  25. def mangle_fn(name, arg_tys, constants, caller_context):
  26. # doesn't mangle ret type, which must be a function of arg tys
  27. mangled_arg_names = '_'.join([ty.mangle() for ty in arg_tys])
  28. mangled_constants = '_'.join([f'{i}c{repr(constants[i])}' for i in sorted(constants)])
  29. mangled_constants = mangled_constants.replace('.', '_d_')
  30. mangled_constants = mangled_constants.replace("'", '_sq_')
  31. # [ and ] are not allowed in LLVM identifiers
  32. mangled_constants = mangled_constants.replace('[', '_').replace(']', '_')
  33. ret = f'{name}__{mangled_arg_names}__{mangled_constants}'
  34. if caller_context is not None:
  35. ret += caller_context.mangle()
  36. return ret
  37. def _is_triton_value(o: Any) -> bool:
  38. return isinstance(o, base_value)
  39. def _is_triton_tensor(o: Any) -> bool:
  40. return isinstance(o, tensor)
  41. def _is_constexpr(o: Any) -> bool:
  42. return o is None or isinstance(o, (constexpr, language.core.dtype, JITCallable))
  43. def _is_non_scalar_tensor(o: Any) -> bool:
  44. return _is_triton_tensor(o) and (o.type.is_block() and o.type.numel != 1)
  45. def _is_list_like(o: Any) -> bool:
  46. return isinstance(o, (list, tuple))
  47. def _check_fn_args(node, fn, args):
  48. if fn.noinline:
  49. for idx, arg in enumerate(args):
  50. if not _is_constexpr(arg) and _is_non_scalar_tensor(arg):
  51. raise UnsupportedLanguageConstruct(
  52. fn.src, node,
  53. f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}'
  54. )
  55. def _apply_to_tuple_values(value, fn):
  56. if is_namedtuple(type(value)):
  57. fields = value._fields
  58. elif isinstance(value, language.tuple):
  59. fields = value.type.fields
  60. else:
  61. assert False, f"Unsupported type {type(value)}"
  62. vals = [fn(v) for v in value]
  63. vals = [constexpr(v) if v is None else v for v in vals]
  64. types = [v.type for v in vals]
  65. return language.tuple(vals, language.tuple_type(types, fields))
  66. def flatten_values_to_ir(values: Iterable[base_value]):
  67. handles = []
  68. for v in values:
  69. v._flatten_ir(handles)
  70. return handles
  71. def unflatten_ir_values(handles: List[ir.value], types: List[base_type]):
  72. cursor = 0
  73. for ty in types:
  74. value, cursor = ty._unflatten_ir(handles, cursor)
  75. yield value
  76. assert cursor == len(handles)
  77. _condition_types = {bool, int, type(None)} # Python types accepted for conditionals inside kernels
  78. def _clone_triton_value(val):
  79. handles = []
  80. val._flatten_ir(handles)
  81. clone, _ = val.type._unflatten_ir(handles, 0)
  82. return clone
  83. def _clone_scope(scope):
  84. return {name: _clone_triton_value(val) if _is_triton_value(val) else val for name, val in scope.items()}
  85. class enter_sub_region:
  86. def __init__(self, generator):
  87. self.generator = generator
  88. def __enter__(self):
  89. # record lscope & local_defs in the parent scope
  90. self.liveins = _clone_scope(self.generator.lscope)
  91. self.prev_defs = _clone_scope(self.generator.local_defs)
  92. self.generator.local_defs = {}
  93. self.insert_block = self.generator.builder.get_insertion_block()
  94. self.insert_point = self.generator.builder.get_insertion_point()
  95. return self.liveins, self.insert_block
  96. def __exit__(self, *args, **kwargs):
  97. self.generator.builder.restore_insertion_point(self.insert_point)
  98. self.generator.lscope = self.liveins
  99. self.generator.local_defs = self.prev_defs
  100. # Check if the given syntax node has an "early" return
  101. class ContainsReturnChecker(ast.NodeVisitor):
  102. def __init__(self, gscope):
  103. self.gscope = gscope
  104. def _visit_stmts(self, body) -> bool:
  105. return any(self.visit(s) for s in body)
  106. def _visit_function(self, fn) -> bool:
  107. # No need to check within the function as it won't cause an early return.
  108. # If the function itself has unstructured control flow we may not be able to inline it causing poor performance,
  109. # we should check for this and emit a warning.
  110. return False
  111. def generic_visit(self, node) -> bool:
  112. ret = False
  113. for _, value in ast.iter_fields(node):
  114. if isinstance(value, list):
  115. for item in value:
  116. if isinstance(item, ast.AST):
  117. ret = ret or self.visit(item)
  118. elif isinstance(value, ast.AST):
  119. ret = ret or self.visit(value)
  120. return ret
  121. def visit_Attribute(self, node: ast.Attribute) -> bool:
  122. # If the left part is a name, it's possible that
  123. # we call triton native function or a jit function from another module.
  124. # If the left part is not a name, it must return a tensor or a constexpr
  125. # whose methods do not contain return statements
  126. # e.g., (tl.load(x)).to(y)
  127. # So we only check if the expressions within value have return or not
  128. if isinstance(node.value, ast.Name):
  129. if node.value.id in self.gscope:
  130. value = self.gscope[node.value.id]
  131. fn = getattr(value, node.attr)
  132. return self._visit_function(fn)
  133. return False
  134. return self.visit(node.value)
  135. def visit_Name(self, node: ast.Name) -> bool:
  136. if type(node.ctx) is ast.Store:
  137. return False
  138. if node.id in self.gscope:
  139. fn = self.gscope[node.id]
  140. return self._visit_function(fn)
  141. return False
  142. def visit_Return(self, node: ast.Return) -> bool:
  143. return True
  144. def visit_Assign(self, node: ast.Assign) -> bool:
  145. # There couldn't be an early return
  146. # x = ...
  147. return False
  148. def visit_AugAssign(self, node: ast.AugAssign) -> bool:
  149. # There couldn't be an early return
  150. # x += ...
  151. return False
  152. def visit_Module(self, node: ast.Module) -> bool:
  153. return self._visit_stmts(node.body)
  154. def visit_FunctionDef(self, node: ast.FunctionDef) -> bool:
  155. return self._visit_stmts(node.body)
  156. def visit_If(self, node: ast.If) -> bool:
  157. # TODO: optimize the following case in which we actually don't have
  158. # a return when static_cond is false:
  159. # if dynamic_cond
  160. # if static_cond
  161. # func_with_return
  162. # else
  163. # func_without_return
  164. ret = self._visit_stmts(node.body)
  165. if node.orelse:
  166. ret = ret or self._visit_stmts(node.orelse)
  167. return ret
  168. def visit_IfExp(self, node: ast.IfExp) -> bool:
  169. return self.visit(node.body) or self.visit(node.orelse)
  170. def visit_Call(self, node: ast.Call) -> bool:
  171. return self.visit(node.func)
  172. class ASTFunction:
  173. def __init__(self, ret_types, arg_types, constants, attrs):
  174. self.ret_types = ret_types
  175. self.arg_types = arg_types
  176. self.constants = constants
  177. self.attrs = attrs
  178. def flatten_ir_types(self, builder: ir.builder, types: List[base_type]) -> List[ir.type]:
  179. ir_types = []
  180. for ty in types:
  181. if ty is None:
  182. continue
  183. ty._flatten_ir_types(builder, ir_types)
  184. return ir_types
  185. def return_types_ir(self, builder: ir.builder) -> List[ir.type]:
  186. return self.flatten_ir_types(builder, self.ret_types)
  187. def serialize(self, builder: ir.builder):
  188. # fill up IR values in template
  189. # > build function
  190. is_val = lambda path, _: path not in self.constants and _ is not None
  191. val_paths = list(find_paths_if(self.arg_types, is_val))
  192. arg_types = [get_iterable_path(self.arg_types, path) for path in val_paths]
  193. arg_types_ir = self.flatten_ir_types(builder, arg_types)
  194. ret_types_ir = self.return_types_ir(builder)
  195. return builder.get_function_ty(arg_types_ir, ret_types_ir)
  196. def deserialize(self, fn):
  197. # create "template"
  198. def make_template(ty):
  199. if isinstance(ty, (list, tuple, language.tuple_type)):
  200. return language.tuple([make_template(x) for x in ty], ty)
  201. return language.constexpr(None)
  202. vals = make_template(self.arg_types)
  203. is_val = lambda path, _: path not in self.constants and _ is not None
  204. val_paths = list(find_paths_if(self.arg_types, is_val))
  205. # > add IR values to the template
  206. cursor = 0
  207. handles = [fn.args(i) for i in range(fn.get_num_args())]
  208. for path in val_paths:
  209. ty = get_iterable_path(self.arg_types, path)
  210. # > set attributes
  211. attr_specs = self.attrs.get(path, [])
  212. for attr_name, attr_val in attr_specs:
  213. fn.set_arg_attr(cursor, attr_name, attr_val)
  214. # > build frontend value
  215. val, cursor = ty._unflatten_ir(handles, cursor)
  216. set_iterable_path(vals, path, val)
  217. # > add constexpr values to the template
  218. constants = self.constants
  219. for path, val in constants.items():
  220. set_iterable_path(vals, path, language.constexpr(val))
  221. return vals
  222. @dataclass(frozen=True)
  223. class BoundJITMethod:
  224. __self__: base_value
  225. __func__: JITFunction
  226. class CodeGenerator(ast.NodeVisitor):
  227. def __init__(self, context, prototype, gscope, function_name, jit_fn: JITFunction, *, options, codegen_fns,
  228. module_map, is_gluon, module=None, is_kernel=False, function_types: Optional[Dict] = None,
  229. noinline=False, caller_context=None, file_name: Optional[str] = None, begin_line=0):
  230. self.context = context
  231. self.is_gluon = is_gluon
  232. if is_gluon:
  233. from triton.experimental.gluon.language._semantic import GluonSemantic
  234. self.builder = gluon_ir.GluonOpBuilder(context)
  235. self.semantic = GluonSemantic(self.builder)
  236. else:
  237. from triton.language.semantic import TritonSemantic
  238. self.builder = ir.builder(context)
  239. self.semantic = TritonSemantic(self.builder)
  240. self.name_loc_as_prefix = None
  241. self.file_name = file_name
  242. # node.lineno starts from 1, so we need to subtract 1
  243. self.begin_line = begin_line - 1
  244. self.builder.set_loc(file_name, begin_line, 0)
  245. self.builder.options = options
  246. # dict of functions provided by the backend. Below are the list of possible functions:
  247. # Convert custom types not natively supported on HW.
  248. # convert_custom_types(input_tensor, dtype, fp_downcast_rounding=None, _builder=None)
  249. self.builder.codegen_fns = codegen_fns
  250. self.builder.module_map = {} if module_map is None else module_map
  251. self.module = self.builder.create_module() if module is None else module
  252. self.function_ret_types = {} if function_types is None else function_types
  253. self.prototype = prototype
  254. self.gscope = {}
  255. for k, v in gscope.items():
  256. if isinstance(v, ModuleType):
  257. self.gscope[k] = module_map.get(v.__name__, v)
  258. continue
  259. module_name = getattr(v, "__module__", "")
  260. if module_name in module_map:
  261. self.gscope[k] = getattr(module_map[module_name], v.__name__)
  262. else:
  263. self.gscope[k] = v
  264. self.lscope = {}
  265. self.jit_fn = jit_fn
  266. # TODO: we currently generate illegal names for non-kernel functions involving constexprs!
  267. if is_kernel:
  268. function_name = function_name[function_name.rfind('.') + 1:]
  269. function_name = check_identifier_legality(function_name, "function")
  270. self.function_name = function_name
  271. self.is_kernel = is_kernel
  272. self.cur_node = None
  273. self.noinline = noinline
  274. self.caller_context = caller_context
  275. self.scf_stack = []
  276. self.ret_type = None
  277. # SSA-construction
  278. # name => language.tensor
  279. self.local_defs: Dict[str, tensor] = {}
  280. self.dereference_name: Callable[[str], Any] = self._define_name_lookup()
  281. self.fn = None
  282. # Are we currently visiting an ast.arg's default value? These have some
  283. # special handling.
  284. self.visiting_arg_default_value = False
  285. builtin_namespace: Dict[str, Any] = {
  286. _.__name__: _
  287. for _ in (len, list, range, float, int, isinstance, getattr, hasattr)
  288. }
  289. builtin_namespace.update((
  290. ('print', language.core.device_print),
  291. ('min', language.core.builtin_min),
  292. ('max', language.core.builtin_max),
  293. ))
  294. def _unsupported(self, node, message):
  295. return UnsupportedLanguageConstruct(self.jit_fn.src, node, message)
  296. def _is_constexpr_global(self, name):
  297. absent_marker = object()
  298. val = self.gscope.get(name, absent_marker)
  299. if val is absent_marker:
  300. return False
  301. if _is_constexpr(val):
  302. return True
  303. return False
  304. def _define_name_lookup(self):
  305. def local_lookup(name: str, absent):
  306. # this needs to be re-fetched from `self` every time, because it gets switched occasionally
  307. return self.lscope.get(name, absent)
  308. def global_lookup(name: str, absent):
  309. val = self.gscope.get(name, absent)
  310. # The high-level rule is that only constexpr globals are allowed.
  311. # But actually a bunch of other things, such as module imports, are
  312. # technically Python globals. We have to allow these too!
  313. if any([
  314. val is absent,
  315. name in self.builtin_namespace, #
  316. type(val) is ModuleType, #
  317. isinstance(val, JITCallable), #
  318. getattr(val, "__triton_builtin__", False), #
  319. getattr(val, "__triton_aggregate__", False), #
  320. getattr(val, "__module__", "").startswith("triton.language"), #
  321. getattr(val, "__module__", "").startswith("triton.experimental.gluon.language"), #
  322. isinstance(val, language.dtype), #
  323. is_namedtuple(val),
  324. self._is_constexpr_global(name), #
  325. # Allow accesses to globals while visiting an ast.arg
  326. # because you should be able to do
  327. # @triton.jit def fn(x: tl.constexpr = GLOBAL): ...
  328. self.visiting_arg_default_value, #
  329. knobs.compilation.allow_non_constexpr_globals,
  330. ]):
  331. return val
  332. raise NameError(
  333. textwrap.dedent(f"""\
  334. Cannot access global variable {name} from within @jit'ed
  335. function. Triton kernels can only access global variables that
  336. are instanstiated as constexpr (`x = triton.language.constexpr(42)`). Note that this is different from
  337. annotating a variable as constexpr (`x: triton.language.constexpr = 42`), which is not supported. Alternatively, set the
  338. envvar TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1, but we do not
  339. promise to support this forever.""").replace("\n", " "))
  340. absent_marker = object()
  341. def name_lookup(name: str) -> Any:
  342. absent = absent_marker
  343. for lookup_function in local_lookup, global_lookup, self.builtin_namespace.get:
  344. value = lookup_function(name, absent)
  345. if value is not absent:
  346. return value
  347. raise NameError(f'{name} is not defined')
  348. return name_lookup
  349. @contextlib.contextmanager
  350. def _name_loc_prefix(self, prefix):
  351. self.name_loc_as_prefix = prefix
  352. yield
  353. self.name_loc_as_prefix = None
  354. def _maybe_set_loc_to_name(self, val, name):
  355. if isinstance(val, (ir.value, ir.block_argument)):
  356. val.set_loc(self.builder.create_name_loc(name, val.get_loc()))
  357. elif _is_triton_value(val):
  358. handles = []
  359. val._flatten_ir(handles)
  360. for handle in handles:
  361. handle.set_loc(self.builder.create_name_loc(name, handle.get_loc()))
  362. def set_value(self, name: str, value: Union[base_value, constexpr]) -> None:
  363. ''' This function:
  364. called by visit_Assign() & visit_FunctionDef() to store left value (lvalue)
  365. 1. record local defined name (FIXME: should consider control flow)
  366. 2. store tensor in self.lvalue
  367. '''
  368. self.lscope[name] = value
  369. self.local_defs[name] = value
  370. def _get_insertion_point_and_loc(self):
  371. # XXX: this is a hack to get the location of the insertion point.
  372. # The insertion point's location could be invalid sometimes,
  373. # so we need to explicitly set the location
  374. loc = self.builder.get_loc()
  375. ip = self.builder.get_insertion_point()
  376. return ip, loc
  377. def _set_insertion_point_and_loc(self, ip, loc):
  378. self.builder.restore_insertion_point(ip)
  379. self.builder.set_loc(loc)
  380. def _find_carries(self, node, liveins, ignore: set[str] = set()):
  381. # create loop body block
  382. block = self.builder.create_block()
  383. self.builder.set_insertion_point_to_start(block)
  384. # dry visit loop body
  385. self.scf_stack.append(node)
  386. self.visit_compound_statement(node.body)
  387. self.scf_stack.pop()
  388. block.erase()
  389. # If a variable (name) has changed value within the loop, then it's
  390. # a loop-carried variable. (The new and old value must be of the
  391. # same type)
  392. init_tys = []
  393. init_handles = []
  394. names = []
  395. for name, live_val in liveins.items():
  396. if name in ignore:
  397. continue
  398. if _is_triton_value(live_val):
  399. loop_val = self.lscope[name]
  400. self._verify_loop_carried_variable(name, loop_val, live_val)
  401. live_handles = flatten_values_to_ir([live_val])
  402. loop_handles = flatten_values_to_ir([loop_val])
  403. if live_handles != loop_handles:
  404. names.append(name)
  405. init_tys.append(live_val.type)
  406. init_handles.extend(live_handles)
  407. else:
  408. assert name not in self.local_defs, f'Loop carried variable {name} is not a triton value'
  409. # reset local scope to not pick up local defs from the dry run.
  410. self.lscope = liveins.copy()
  411. self.local_defs = {}
  412. return names, init_handles, init_tys
  413. #
  414. # AST visitor
  415. #
  416. def visit_compound_statement(self, stmts):
  417. # Ensure that stmts is iterable
  418. if not _is_list_like(stmts):
  419. stmts = [stmts]
  420. for stmt in stmts:
  421. self.visit(stmt)
  422. # Stop parsing as soon as we hit a `return` statement; everything
  423. # after this is dead code.
  424. if isinstance(stmt, ast.Return):
  425. break
  426. def visit_Module(self, node):
  427. ast.NodeVisitor.generic_visit(self, node)
  428. def visit_List(self, node):
  429. ctx = self.visit(node.ctx)
  430. assert ctx is None
  431. elts = language.tuple([self.visit(elt) for elt in node.elts])
  432. return elts
  433. def visit_ListComp(self, node: ast.ListComp):
  434. if len(node.generators) != 1:
  435. raise ValueError("nested comprehensions are not supported")
  436. comp = node.generators[0]
  437. iter = self.visit(comp.iter)
  438. if not isinstance(iter, tl_tuple):
  439. raise NotImplementedError("only tuple comprehensions are supported")
  440. results = []
  441. for item in iter:
  442. self.set_value(comp.target.id, item)
  443. results.append(self.visit(node.elt))
  444. return tl_tuple(results)
  445. # By design, only non-kernel functions can return
  446. def visit_Return(self, node):
  447. ret_value = self.visit(node.value)
  448. handles = []
  449. def decay(value):
  450. if isinstance(value, language.tuple):
  451. return _apply_to_tuple_values(value, decay)
  452. elif isinstance(value, (language.constexpr, int, float)):
  453. return self.semantic.to_tensor(value)
  454. return value
  455. ret_value = decay(ret_value)
  456. if ret_value is None:
  457. ret_ty = language.void
  458. else:
  459. assert isinstance(ret_value, language.core.base_value)
  460. ret_value._flatten_ir(handles)
  461. ret_ty = ret_value.type
  462. self.builder.ret(handles)
  463. if self.ret_type is None:
  464. self.ret_type = ret_ty
  465. elif self.ret_type != ret_ty:
  466. raise TypeError(f'Inconsistent return types: {self.ret_type} and {ret_ty}')
  467. # A return op must always terminate the basic block, so we create a dead
  468. # basic block in case there are any ops after the return.
  469. post_ret_block = self.builder.create_block()
  470. self.builder.set_insertion_point_to_end(post_ret_block)
    def visit_FunctionDef(self, node):
        """Lower the function definition being compiled into an IR function.

        Steps:
        1. Evaluate argument default values into the local scope.
        2. Get or insert the IR function typed by ``self.prototype``.
        3. Bind the deserialized argument values to their names.
        4. Visit the body.
        5. Emit a terminating ``ret``, updating the prototype's return types and the
           function type from ``self.ret_type`` when the function returns a value.

        Raises:
            UnsupportedLanguageConstruct: if another function definition is
                encountered after ``self.fn`` has been set (nested definitions).
        """
        # kwarg_names is currently unused.
        arg_names, kwarg_names = self.visit(node.args)
        if self.fn:
            raise self._unsupported(node, "nested function definition is not supported.")
        # initialize defaults
        # `defaults` align with the *last* positional args, so both are walked from the end.
        for i, default_value in enumerate(node.args.defaults[::-1]):
            arg_node = node.args.args[-i - 1]
            annotation = arg_node.annotation
            name = arg_node.arg
            st_target = ast.Name(id=name, ctx=ast.Store())
            # Synthesize `name = default` (or `name: annotation = default`) and lower it.
            if annotation is None:
                init_node = ast.Assign(targets=[st_target], value=default_value)
            else:
                init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)

            try:
                # The flag lets global_lookup accept non-constexpr globals in defaults.
                assert not self.visiting_arg_default_value
                self.visiting_arg_default_value = True
                self.visit(init_node)
            finally:
                self.visiting_arg_default_value = False

        # initialize function
        visibility = "public" if self.is_kernel else "private"
        fn_ty = self.prototype.serialize(self.builder)
        self.fn = self.builder.get_or_insert_function(self.module, self.function_name, fn_ty, visibility, self.noinline)
        self.module.push_back(self.fn)
        entry = self.fn.add_entry_block()
        arg_values = self.prototype.deserialize(self.fn)
        if self.caller_context is not None:
            self.caller_context.initialize_callee(self.fn, self.builder)
        # bind arguments to symbols
        # (overwrites any lscope entries written while evaluating defaults above)
        for arg_name, arg_value in zip(arg_names, arg_values):
            self._maybe_set_loc_to_name(arg_value, arg_name)
            self.set_value(arg_name, arg_value)
        # Remember the current insertion block so it can be restored after lowering.
        insert_pt = self.builder.get_insertion_block()
        self.builder.set_insertion_point_to_start(entry)
        # visit function body
        self.visit_compound_statement(node.body)

        # finalize function
        # visit_Return always moves insertion to a fresh block, so the current
        # block is expected to be unterminated here.
        assert not self.builder.get_insertion_block().has_terminator()
        if self.ret_type is None or self.ret_type == language.void:
            self.ret_type = language.void
            self.builder.ret([])
        else:
            # Record the observed return type(s) on the prototype and retype the function.
            if isinstance(self.ret_type, language.tuple_type):
                self.prototype.ret_types = self.ret_type.types
            else:
                self.prototype.ret_types = [self.ret_type]
            self.fn.reset_type(self.prototype.serialize(self.builder))
            # Terminate the trailing block with poison values of the return types.
            self.builder.ret([self.builder.create_poison(ty) for ty in self.prototype.return_types_ir(self.builder)])
        self.fn.finalize()

        if insert_pt:
            self.builder.set_insertion_point_to_end(insert_pt)
  523. def visit_arguments(self, node):
  524. arg_names = []
  525. for arg in node.args:
  526. arg_names += [self.visit(arg)]
  527. kwarg_names = self.visit(node.kwarg)
  528. return arg_names, kwarg_names
  529. def visit_arg(self, node):
  530. ast.NodeVisitor.generic_visit(self, node)
  531. param = next(p for p in self.jit_fn.params if p.name == node.arg)
  532. if param.is_constexpr and (param.do_not_specialize or param.do_not_specialize_on_alignment):
  533. raise CompilationError(
  534. self.jit_fn.src, node,
  535. f"{node.arg} marked as constexpr and listed in do_not_specialize/do_not_specialize_on_alignment. "
  536. "Remove constexpr designation to skip specialization.")
  537. return node.arg
  538. def visit_AnnAssign(self, node):
  539. # extract attributes
  540. annotation = self.visit(node.annotation)
  541. target = self.visit(node.target)
  542. value = self.visit(node.value)
  543. # constexpr
  544. if annotation == constexpr:
  545. if target in self.lscope:
  546. raise ValueError(f'{target} is already defined.'
  547. f' constexpr cannot be reassigned.')
  548. value = constexpr(value)
  549. self.lscope[target] = value
  550. return self.lscope[target]
  551. # default: call visit_Assign
  552. return self.visit_Assign(node)
  553. def assignTarget(self, target, value):
  554. assert isinstance(target.ctx, ast.Store)
  555. if isinstance(target, ast.Subscript):
  556. return self.visit_Subscript_Store(target, value)
  557. if isinstance(target, ast.Tuple):
  558. for i, target in enumerate(target.elts):
  559. self.assignTarget(target, value.values[i])
  560. return
  561. if isinstance(target, ast.Attribute):
  562. raise NotImplementedError("Attribute assignment is not supported in triton")
  563. assert isinstance(target, ast.Name)
  564. self.set_value(self.visit(target), value)
  565. def visit_Assign(self, node):
  566. # construct values to assign
  567. def _sanitize_value(value):
  568. if isinstance(value, language.tuple):
  569. return _apply_to_tuple_values(value, _sanitize_value)
  570. native_nontensor_types = (language.dtype, language.tuple)
  571. value = _unwrap_if_constexpr(value)
  572. if value is not None and \
  573. not _is_triton_value(value) and \
  574. not isinstance(value, native_nontensor_types):
  575. value = self.semantic.to_tensor(value)
  576. return value
  577. targets = [node.target] if isinstance(node, ast.AnnAssign) else node.targets
  578. assert len(targets) == 1
  579. target = targets[0]
  580. if isinstance(target, ast.Name):
  581. with self._name_loc_prefix(target.id):
  582. values = _sanitize_value(self.visit(node.value))
  583. else:
  584. values = _sanitize_value(self.visit(node.value))
  585. self.assignTarget(target, values)
  def visit_AugAssign(self, node):
      """Lower ``target op= value`` by rewriting it as ``target = target op value``.

      Returns the result of visiting the (Load-context) target after the assignment.
      """
      load_target = copy.deepcopy(node.target)
      load_target.ctx = ast.Load()
      binop = ast.BinOp(left=load_target, op=node.op, right=node.value)
      rewritten = ast.Assign(targets=[node.target], value=binop)
      # Copy the source position so diagnostics point at the original statement.
      for attr in ('lineno', 'col_offset', 'end_lineno', 'end_col_offset'):
          if not hasattr(node, attr):
              continue
          position = getattr(node, attr)
          setattr(binop, attr, position)
          setattr(rewritten, attr, position)
      self.visit(rewritten)
      return self.visit(load_target)
  def visit_Name(self, node):
      """Resolve a name: Store contexts yield the identifier, all others its current value."""
      if type(node.ctx) is not ast.Store:
          return self.dereference_name(node.id)
      return node.id
  def visit_Store(self, node):
      # Store contexts carry no information of their own; fall back to the generic traversal.
      return ast.NodeVisitor.generic_visit(self, node)
  def visit_Load(self, node):
      # Load contexts carry no information of their own; fall back to the generic traversal.
      return ast.NodeVisitor.generic_visit(self, node)
  def visit_Tuple(self, node):
      """Visit each element in order and pack the results into a ``language.tuple``."""
      return language.tuple(list(map(self.visit, node.elts)))
  def _apply_binary_method(self, node, method_name, lhs, rhs):
      """Evaluate ``lhs <op> rhs`` by calling the dunder ``method_name`` (e.g. ``'__add__'``).

      * If ``lhs`` is a tensor, ``lhs.<method_name>(rhs)`` is called with ``_semantic``.
      * Otherwise, if ``rhs`` is a tensor, the reflected method (``__radd__`` for ``__add__``)
        is called on ``rhs`` with ``lhs`` as the argument.
      * Otherwise both operands are compile-time values. A non-constexpr, non-tuple ``lhs``
        is wrapped in ``constexpr`` when ``rhs`` is one, and the method is invoked through
        ``call_Function``.

      When the tensor operand does not define the required method, a
      ``self._unsupported`` diagnostic is raised instead of a bare ``AttributeError``.
      """
      if _is_triton_tensor(lhs):
          receiver, other, dunder = lhs, rhs, method_name
      elif _is_triton_tensor(rhs):
          # `x <op> tensor`: dispatch to the tensor's reflected method, e.g. __add__ -> __radd__.
          receiver, other, dunder = rhs, lhs, re.sub(r"__(.*)__", r"__r\1__", method_name)
      else:
          receiver = None
      if receiver is not None:
          try:
              bound_method = getattr(receiver, dunder)
          except AttributeError:
              raise self._unsupported(
                  node, f"operator '{method_name}' is not supported between operands of type "
                  f"{type(lhs).__name__} and {type(rhs).__name__}") from None
          return bound_method(other, _semantic=self.semantic)
      if not isinstance(lhs, (constexpr, language.tuple)) and isinstance(rhs, constexpr):
          lhs = constexpr(lhs)
      if isinstance(lhs, constexpr):
          fn = getattr(lhs, method_name)
      else:
          fn = self.get_Attribute(lhs, method_name)
      return self.call_Function(node, fn, [rhs], {})
  def visit_BinOp(self, node):
      """Lower ``left <op> right`` by mapping the AST operator to its dunder method name.

      Operators missing from ``_method_name_for_bin_op`` are rejected via ``self._unsupported``.
      """
      lhs = self.visit(node.left)
      rhs = self.visit(node.right)
      method_name = self._method_name_for_bin_op.get(type(node.op))
      if method_name is None:
          # `node.op` is an operator *instance* (e.g. ast.MatMult()); instances have no
          # `__name__`, so take the readable name from its class.
          raise self._unsupported(
              node, "AST binary operator '{}' is not (currently) implemented.".format(type(node.op).__name__))
      return self._apply_binary_method(node, method_name, lhs, rhs)
  631. _method_name_for_bin_op: Dict[Type[ast.operator], str] = {
  632. ast.Add: '__add__',
  633. ast.Sub: '__sub__',
  634. ast.Mult: '__mul__',
  635. ast.Div: '__truediv__',
  636. ast.FloorDiv: '__floordiv__',
  637. ast.Mod: '__mod__',
  638. ast.Pow: '__pow__',
  639. ast.LShift: '__lshift__',
  640. ast.RShift: '__rshift__',
  641. ast.BitAnd: '__and__',
  642. ast.BitOr: '__or__',
  643. ast.BitXor: '__xor__',
  644. }
  def visit_then_else_blocks(self, node, liveins, then_block, else_block):
      """Visit both branches of an ``if`` and work out which variables it must merge.

      Args:
          node: the ``ast.If`` being lowered.
          liveins: snapshot of ``lscope`` from before the ``if``. It is the baseline each
              branch is compared against, and ``lscope`` is reset to it before the else
              branch is visited.
          then_block: IR block the then-branch is emitted into.
          else_block: IR block the else-branch is emitted into (only used when
              ``node.orelse`` is non-empty).

      Returns:
          ``(then_defs, else_defs, then_block, else_block, names)``. The blocks are the
          insertion blocks after each branch was visited. ``names`` lists the variables the
          caller merges after the ``if``: first the triton-valued live-ins whose IR handles
          differ between the two branches, then (sorted) the names defined in both branches.

      Type mismatches between a live-in and its redefinition, or between the two branches,
      fail with an ``AssertionError``.
      """
      # then block
      self.builder.set_insertion_point_to_start(then_block)
      self.visit_compound_statement(node.body)
      # the branch body may have moved the insertion point to another block
      then_block = self.builder.get_insertion_block()
      # NOTE(review): assumes the caller reset local_defs for this region, so it only holds
      # definitions made in the then-branch — confirm against enter_sub_region.
      then_defs = self.local_defs.copy()
      then_vals = self.lscope.copy()
      # else block
      else_defs = {}
      else_vals = liveins.copy()
      if node.orelse:
          self.builder.set_insertion_point_to_start(else_block)
          # start the else branch from the pre-`if` symbol table, not the then-branch's
          self.lscope = liveins.copy()
          self.local_defs = {}
          self.visit_compound_statement(node.orelse)
          else_defs = self.local_defs.copy()
          else_block = self.builder.get_insertion_block()
          else_vals = self.lscope.copy()
      # update block arguments
      names = []
      # variables in livein whose value is updated in `if`
      for name, value in liveins.items():
          # livein variable changed value in either then or else
          if not _is_triton_value(value):
              continue
          # compare by underlying IR handles: identical handles mean neither branch reassigned it
          then_handles = flatten_values_to_ir([then_vals[name]])
          else_handles = flatten_values_to_ir([else_vals[name]])
          if then_handles == else_handles:
              continue
          names.append(name)
          then_defs[name] = then_vals[name]
          else_defs[name] = else_vals[name]
          # check type
          for defs, block_name in [(then_defs, 'then'), (else_defs, 'else')]:
              type_equal = type(defs[name]) == type(value)  # noqa: E721
              assert type_equal and defs[name].type == value.type, \
                  f'initial value for `{name}` is of type {value}, '\
                  f'but the {block_name} block redefines it as {defs[name]}'
      # variables that are both in then and else but not in liveins
      # TODO: could probably be cleaned up
      for name in sorted(then_defs.keys() & else_defs.keys()):
          if name in names:
              continue
          then_val = then_defs[name]
          then_ty = then_val.type
          else_val = else_defs[name]
          else_ty = else_val.type
          type_equal = type(then_val) == type(else_val)  # noqa: E721
          assert type_equal and then_ty == else_ty, \
              f'Mismatched type for {name} between then block ({then_ty}) '\
              f'and else block ({else_ty})'
          names.append(name)
      return then_defs, else_defs, then_block, else_block, names
  def visit_if_top_level(self, cond, node):
      """Lower a dynamic ``if`` whose branches contain ``return`` using unstructured branches.

      A conditional branch jumps from the current block to fresh then/else blocks. Each branch
      ends with an unconditional branch to a new ``endif_block`` that carries the merged
      variables (see ``visit_then_else_blocks``) as block arguments. Afterwards, emission
      continues in ``endif_block`` and every merged name is rebound to its block argument.

      Args:
          cond: scalar tensor condition (already converted to ``int1`` by the caller).
          node: the ``ast.If`` being lowered.
      """
      with enter_sub_region(self) as sr:
          liveins, ip_block = sr
          then_block = self.builder.create_block()
          else_block = self.builder.create_block()
          # create branch
          self.builder.set_insertion_point_to_end(ip_block)
          self.builder.create_cond_branch(cond.handle, then_block, else_block)
          # visit then and else blocks
          then_defs, else_defs, then_block, else_block, names = \
              self.visit_then_else_blocks(node, liveins, then_block, else_block)
          # create basic-block after conditional
          endif_block = self.builder.create_block()
          # then terminator
          self.builder.set_insertion_point_to_end(then_block)
          assert not then_block.has_terminator(), f"{then_block}"
          then_handles = flatten_values_to_ir(then_defs[name] for name in names)
          self.builder.create_branch(endif_block, then_handles)
          # else terminator
          self.builder.set_insertion_point_to_end(else_block)
          assert not else_block.has_terminator(), f"{else_block}"
          else_handles = flatten_values_to_ir(else_defs[name] for name in names)
          self.builder.create_branch(endif_block, else_handles)
          # one endif block argument per flattened handle; both branches must agree on its type
          assert len(then_handles) == len(else_handles)
          for then_h, else_h in zip(then_handles, else_handles):
              ty = then_h.get_type()
              assert ty == else_h.get_type()
              endif_block.add_argument(ty)
      # change block
      self.builder.set_insertion_point_to_start(endif_block)
      # update value
      res_handles = [endif_block.arg(i) for i in range(len(then_handles))]
      types = [then_defs[name].type for name in names]
      # regroup the flat block arguments into one front-end value per merged name
      new_values = unflatten_ir_values(res_handles, types)
      for name, new_value in zip(names, new_values):
          self.set_value(name, new_value)
  # TODO: refactor
  def visit_if_scf(self, cond, node):
      """Lower a dynamic ``if`` without ``return`` statements to an ``scf.if``-style op.

      Both branches are first emitted into standalone blocks by ``visit_then_else_blocks``.
      The if-op is then created at the saved insertion point with one result per flattened
      handle of the merged names, and the branch blocks are merged into the op's then/else
      regions. Yields are added only when there is something to merge. After the sub-region
      closes, each merged name is rebound to the corresponding if-op results.

      Args:
          cond: scalar tensor condition (already converted to ``int1`` by the caller).
          node: the ``ast.If`` being lowered.
      """
      with enter_sub_region(self) as sr:
          liveins, _ = sr
          # remember where the if-op itself must be inserted
          ip, last_loc = self._get_insertion_point_and_loc()
          then_block = self.builder.create_block()
          else_block = self.builder.create_block() if node.orelse else None
          then_defs, else_defs, then_block, else_block, names = \
              self.visit_then_else_blocks(node, liveins, then_block, else_block)
          # create if op
          then_handles = flatten_values_to_ir(then_defs[name] for name in names)
          # NOTE(review): if any merged value flattens to more than one handle (e.g. a tuple),
          # `names` and `then_handles` no longer line up here, so location names may be
          # attached to the wrong handles — verify.
          for name, val in zip(names, then_handles):
              self._maybe_set_loc_to_name(val, name)
          self._set_insertion_point_and_loc(ip, last_loc)
          # third argument presumably requests an else region (get_else_block() is used below
          # even without `orelse`) — confirm against the builder binding.
          if_op = self.builder.create_if_op([h.get_type() for h in then_handles], cond.handle, True)
          then_block.merge_block_before(if_op.get_then_block())
          self.builder.set_insertion_point_to_end(if_op.get_then_block())
          if len(names) > 0:
              self.builder.create_yield_op(then_handles)
          if not node.orelse:
              else_block = if_op.get_else_block()
          else:
              else_block.merge_block_before(if_op.get_else_block())
          self.builder.set_insertion_point_to_end(if_op.get_else_block())
          if len(names) > 0:
              else_handles = flatten_values_to_ir(else_defs[name] for name in names)
              for name, val in zip(names, else_handles):
                  self._maybe_set_loc_to_name(val, name)
              self.builder.create_yield_op(else_handles)
      # update values
      res_handles = [if_op.get_result(i) for i in range(len(then_handles))]
      types = [then_defs[name].type for name in names]
      new_values = unflatten_ir_values(res_handles, types)
      for name, new_value in zip(names, new_values):
          self.set_value(name, new_value)
  def visit_If(self, node):
      """Lower an ``if`` statement.

      A non-tensor condition is resolved at compile time and only the taken branch is
      visited. A tensor condition emits IR: ``visit_if_scf`` normally, or
      ``visit_if_top_level`` when the statement contains a ``return``. The latter is only
      allowed outside loops.
      """
      cond = self.visit(node.test)
      if not _is_triton_tensor(cond):
          static_cond = _unwrap_if_constexpr(cond)
          # Exact type match on purpose: no subclasses and no duck-typed truthiness.
          if type(static_cond) not in _condition_types:
              allowed = ', '.join(t.__name__ for t in _condition_types)
              raise self._unsupported(
                  node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
                      allowed, type(static_cond).__name__))
          self.visit_compound_statement(node.body if static_cond else node.orelse)
          return
      if _is_non_scalar_tensor(cond):
          raise self._unsupported(node, "Boolean value of Tensor with more than one value is ambiguous")
      if cond.type.is_block():
          # presumably a block tensor holding a single element: warn, then reduce it to a scalar
          warnings.warn(
              "If conditional called with multidimensional Tensor instead of scalar; please use \"if (%s).item()\" instead"
              % ast.unparse(node.test))
          cond = language.core._unsplat(cond, _semantic=self.semantic, _generator=self)
      cond = cond.to(language.int1, _semantic=self.semantic)
      if not ContainsReturnChecker(self.gscope).visit(node):
          self.visit_if_scf(cond, node)
          return
      # `return` inside the branches needs the unstructured lowering, rejected within loops.
      if self.scf_stack:
          raise self._unsupported(
              node, "Cannot have `return` statements inside `while` or `for` statements in triton.")
      self.visit_if_top_level(cond, node)
  def visit_IfExp(self, node):
      """Lower a conditional expression ``body if test else orelse``.

      With a tensor condition, both branches are emitted, converted to tensors and required
      to share a type. An if-op then yields the selected branch's value (``None`` for a void
      type). Tensor conditions with more than one element are rejected, and block-typed
      conditions that pass that check are reduced to a scalar first, matching ``visit_If``.
      Any other condition is resolved at compile time and only the selected branch is
      visited. Its type must be one of ``_condition_types``.
      """
      cond = self.visit(node.test)
      if _is_triton_tensor(cond):
          # Mirror visit_If: the if-op needs a scalar condition.
          if _is_non_scalar_tensor(cond):
              raise self._unsupported(node, "Boolean value of Tensor with more than one value is ambiguous")
          if cond.type.is_block():
              cond = language.core._unsplat(cond, _semantic=self.semantic, _generator=self)
          cond = cond.to(language.int1, _semantic=self.semantic)
          # TODO: Deal w/ more complicated return types (e.g tuple)
          with enter_sub_region(self):
              ip, last_loc = self._get_insertion_point_and_loc()
              then_block = self.builder.create_block()
              self.builder.set_insertion_point_to_start(then_block)
              then_val = self.semantic.to_tensor(self.visit(node.body))
              then_block = self.builder.get_insertion_block()
              else_block = self.builder.create_block()
              self.builder.set_insertion_point_to_start(else_block)
              # do not need to reset lscope since
              # ternary expressions cannot define new variables
              else_val = self.semantic.to_tensor(self.visit(node.orelse))
              else_block = self.builder.get_insertion_block()
              self._set_insertion_point_and_loc(ip, last_loc)
              assert then_val.type == else_val.type, \
                  f'Ternary expression with dynamic condition has inconsistent types {then_val.type} and {else_val.type}'
              ret_type = then_val.type
              ret_type_ir = [ret_type.to_ir(self.builder)] if ret_type != language.void else []
              if_op = self.builder.create_if_op(ret_type_ir, cond.handle, True)
              then_block.merge_block_before(if_op.get_then_block())
              if ret_type_ir:
                  self.builder.set_insertion_point_to_end(if_op.get_then_block())
                  self.builder.create_yield_op([then_val.handle])
              self.builder.set_insertion_point_to_end(if_op.get_then_block())
              else_block.merge_block_before(if_op.get_else_block())
              if ret_type_ir:
                  self.builder.set_insertion_point_to_end(if_op.get_else_block())
                  self.builder.create_yield_op([else_val.handle])
              return language.core.tensor(if_op.get_result(0), ret_type) if ret_type_ir else None
      else:
          cond = _unwrap_if_constexpr(cond)
          # not isinstance - we insist the real thing, no subclasses and no ducks
          if type(cond) not in _condition_types:
              raise self._unsupported(
                  node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
                      ', '.join(_.__name__ for _ in _condition_types),
                      type(cond).__name__))
          if cond:
              return self.visit(node.body)
          else:
              return self.visit(node.orelse)
  def visit_With(self, node):
      """Lower a ``with`` statement by running the context managers' hooks at compile time.

      Each item must be a call. Its callee is invoked with the visited arguments plus
      ``_semantic=self.semantic``. All managers are constructed first, then entered in order,
      and ``as`` targets are bound to the ``__enter__`` result. The body is visited next, and
      finally the managers are exited in reverse order with no exception information.
      """
      # Reject `return` before any context manager is constructed or entered, so a failed
      # check leaves no __enter__ side effects behind.
      if ContainsReturnChecker(self.gscope).visit(node):
          raise self._unsupported(node, "Cannot have `return` statements inside `with` statements in triton ")
      cm_list = []
      for item in node.items:
          call = item.context_expr
          if not isinstance(call, ast.Call):
              raise self._unsupported(
                  node, "`with` statements in triton require a context manager created by a call, e.g. `with cm(...):`")
          fn = self.visit(call.func)
          args = [self.visit(arg) for arg in call.args]
          kws = dict(self.visit(kw) for kw in call.keywords)
          cm_list.append(fn(*args, _semantic=self.semantic, **kws))
      for cm, item in zip(cm_list, node.items):
          res = cm.__enter__()
          if item.optional_vars is not None:
              var_name = self.visit(item.optional_vars)
              self.set_value(var_name, res)
      self.visit_compound_statement(node.body)
      for cm in reversed(cm_list):
          cm.__exit__(None, None, None)
  def visit_Pass(self, node):
      """`pass` emits nothing."""
      return None
  def visit_Compare(self, node):
      """Lower a single comparison ``left <op> right``.

      ``is`` / ``is not`` are decided at compile time on the unwrapped operands and return a
      ``constexpr``. The other supported operators dispatch through
      ``_method_name_for_comp_op`` and ``_apply_binary_method``. Chained comparisons and
      operators without a mapping (such as ``in``) are rejected via ``self._unsupported``.
      """
      if not (len(node.comparators) == 1 and len(node.ops) == 1):
          raise self._unsupported(node, "simultaneous multiple comparison is not supported")
      op_type = type(node.ops[0])
      lhs = self.visit(node.left)
      rhs = self.visit(node.comparators[0])
      lhs_value = _unwrap_if_constexpr(lhs)
      rhs_value = _unwrap_if_constexpr(rhs)
      if op_type is ast.Is:
          return constexpr(lhs_value is rhs_value)
      if op_type is ast.IsNot:
          return constexpr(lhs_value is not rhs_value)
      method_name = self._method_name_for_comp_op.get(op_type)
      if method_name is None:
          # `node.ops[0]` is an operator instance; its readable name lives on the class.
          raise self._unsupported(
              node, "AST comparison operator '{}' is not (currently) implemented.".format(op_type.__name__))
      return self._apply_binary_method(node, method_name, lhs, rhs)
  881. _method_name_for_comp_op: Dict[Type[ast.cmpop], str] = {
  882. ast.Eq: '__eq__', ast.NotEq: '__ne__', ast.Lt: '__lt__', ast.LtE: '__le__', ast.Gt: '__gt__', ast.GtE: '__ge__'
  883. }
  def visit_UnaryOp(self, node):
      """Lower ``-x``, ``+x``, ``not x`` and ``~x``.

      Tensors get the mapped dunder method called with ``_semantic``. Other operands get the
      plain method call. When such an operand lacks the method, ``not`` falls back to
      ``constexpr(not operand)`` and everything else is rejected via ``self._unsupported``.
      """
      operand = self.visit(node.operand)
      fn = self._method_name_for_unary_op.get(type(node.op))
      if fn is None:
          # `node.op` is an operator instance; its readable name lives on the class.
          raise self._unsupported(node, f"AST unary operator '{type(node.op).__name__}' is not (currently) implemented.")
      if _is_triton_tensor(operand):
          return getattr(operand, fn)(_semantic=self.semantic)
      # Only the lookup may fail over to the fallback; errors raised by the method itself propagate.
      method = getattr(operand, fn, None)
      if method is not None:
          return method()
      if fn == "__not__":
          return constexpr(not operand)
      raise self._unsupported(
          node, f"AST unary operator '{fn}' is not (currently) implemented on type {type(operand).__name__}")
  898. _method_name_for_unary_op: Dict[Type[ast.unaryop], str] = {
  899. ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__'
  900. }
  def _verify_loop_carried_variable(self, name, loop_val, live_val):
      """Check that loop-carried variable ``name`` keeps its kind and type across a loop.

      ``live_val`` is the value from before the loop, and ``loop_val`` is the value assigned
      inside it. Both must be triton values (constexprs cannot be reassigned in the loop) of
      the same Python type, and tensors must also keep the same triton type.
      """
      assert _is_triton_value(loop_val), f'cannot reassign constexpr {name} in the loop'
      assert _is_triton_value(live_val), f'cannot reassign constexpr {name} in the loop'
      # live_val is the initial value and loop_val the in-loop one (same convention as the
      # message below), so report them in that order.
      assert type(loop_val) is type(live_val), (
          f'Loop carried variable {name} changed type, was {type(live_val)} but is now {type(loop_val)}')
      assert not _is_triton_tensor(loop_val) or loop_val.type == live_val.type, \
          f'Loop-carried variable {name} has initial type {live_val.type} '\
          f'but is re-assigned to {loop_val.type} in loop! '\
          f'Please make sure that the type stays consistent.'
  def visit_While(self, node):
      """Lower a ``while`` loop to a while-op with "before" (condition) and "after" (body) regions.

      Loop-carried variables come from ``_find_carries``. Their flattened IR handles are the
      op's initial operands and the block arguments of both regions. The condition is
      evaluated in the before-block; a ``language.condition`` wrapper can also disable LICM
      for the loop. The body yields the updated carried values. Afterwards, the carried names
      are rebound to the op results.

      ``while ... else`` is rejected via ``self._unsupported`` before any IR is emitted.
      """
      # Reject the unsupported else-clause up front rather than after emitting the whole loop
      # (the old `assert False` check also disappeared under `python -O`).
      if node.orelse:
          raise self._unsupported(node, "`while ... else` is not supported in triton")
      with enter_sub_region(self) as sr:
          liveins, insert_block = sr
          ip, last_loc = self._get_insertion_point_and_loc()
          names, init_handles, init_fe_tys = self._find_carries(node, liveins)
          init_tys = [h.get_type() for h in init_handles]
          self._set_insertion_point_and_loc(ip, last_loc)
          while_op = self.builder.create_while_op(init_tys, init_handles)
          # merge the condition region
          before_block = self.builder.create_block_with_parent(while_op.get_before(), init_tys)
          self.builder.set_insertion_point_to_start(before_block)
          block_args = [before_block.arg(i) for i in range(len(init_handles))]
          condition_args = unflatten_ir_values(block_args, init_fe_tys)
          # inside the condition region the carried names refer to the region's block arguments
          for name, val in zip(names, condition_args):
              self.lscope[name] = val
              self.local_defs[name] = val
              self._maybe_set_loc_to_name(val, name)
          cond = self.visit(node.test)
          if isinstance(cond, language.condition):
              if cond.disable_licm:
                  while_op.set_attr("llvm.loop_annotation", self.builder.get_disable_loop_licm_attr())
              cond = cond.condition
          self.builder.set_insertion_point_to_end(before_block)
          # create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ...
          self.builder.create_condition_op(cond.handle, block_args)
          # merge the loop body
          after_block = self.builder.create_block_with_parent(while_op.get_after(), init_tys)
          # generate loop body
          self.builder.set_insertion_point_to_start(after_block)
          body_handles = [after_block.arg(i) for i in range(len(init_handles))]
          body_args = unflatten_ir_values(body_handles, init_fe_tys)
          for name, val in zip(names, body_args):
              self.lscope[name] = val
              self.local_defs[name] = val
              self._maybe_set_loc_to_name(val, name)
          # a non-empty scf_stack makes visit_If reject `return` inside the body
          self.scf_stack.append(node)
          self.visit_compound_statement(node.body)
          self.scf_stack.pop()
          yield_handles = flatten_values_to_ir(self.lscope[name] for name in names)
          self.builder.create_yield_op(yield_handles)
      # WhileOp defines new values, update the symbol table (lscope, local_defs)
      result_handles = [while_op.get_result(i) for i in range(len(init_handles))]
      result_vals = unflatten_ir_values(result_handles, init_fe_tys)
      for name, new_def in zip(names, result_vals):
          self.lscope[name] = new_def
          self.local_defs[name] = new_def
          self._maybe_set_loc_to_name(new_def, name)
  def visit_Subscript_Load(self, node):
      """Lower ``value[slice]`` in Load context.

      Triton values go through ``call_Method`` on their ``__getitem__``; anything else is
      indexed directly in Python.
      """
      assert isinstance(node.ctx, ast.Load)
      base = self.visit(node.value)
      index = self.visit(node.slice)
      if not _is_triton_value(base):
          return base[index]
      return self.call_Method(node, base.__getitem__, base, [index], {})
  def visit_Subscript_Store(self, node, value):
      """Item assignment (``x[i] = v``) is not supported; always raises NotImplementedError."""
      message = "__setitem__ is not supported in triton"
      raise NotImplementedError(message)
  def visit_Subscript(self, node):
      # Subscript stores are routed through assignTarget -> visit_Subscript_Store, so a plain
      # visit of a subscript is treated as a load.
      load = self.visit_Subscript_Load
      return load(node)
  def visit_ExtSlice(self, node):
      """Visit every dimension of an extended slice, in order, returning a list."""
      return list(map(self.visit, node.dims))
  def visit_For(self, node):
      """Lower a ``for`` loop over ``tl.static_range``, ``tl.range`` or Python's ``range``.

      * ``static_range`` is unrolled at compile time: the body is visited once per iteration
        with the loop variable bound to a ``constexpr``. An ``else`` clause is visited once
        after the last iteration (Python runs ``for``-``else`` when the loop is not broken).
      * ``tl.range`` / ``range`` become a for-op over integer bounds. A negative constant step
        is handled by iterating the mirrored range and remapping the induction variable.
        ``tl.range`` additionally forwards its scheduling options as op attributes.
        Loop-carried values come from ``_find_carries``.

      Raises:
          TypeError: non-integer or non-scalar bounds/step, or a builtin ``range`` call with
              other than 1 to 3 positional arguments.
          RuntimeError: any other iterator.
      An ``else`` clause on a non-static loop is rejected via ``self._unsupported`` before
      any IR is emitted.
      """
      IteratorClass = self.visit(node.iter.func)
      iter_args = [self.visit(arg) for arg in node.iter.args]
      iter_kwargs = dict(self.visit(keyword) for keyword in node.iter.keywords)
      if IteratorClass == language.static_range:
          iterator = IteratorClass(*iter_args, **iter_kwargs)
          static_range = range(iterator.start.value, iterator.end.value, iterator.step.value)
          for i in static_range:
              self.lscope[node.target.id] = constexpr(i)
              self.visit_compound_statement(node.body)
          if node.orelse:
              # Executed once, after the unrolled iterations, like Python's for-else.
              self.visit_compound_statement(node.orelse)
          return
      num_stages = None
      loop_unroll_factor = None
      disallow_acc_multi_buffer = False
      flatten = False
      warp_specialize = False
      disable_licm = False
      if IteratorClass is language.range:
          iterator = IteratorClass(*iter_args, **iter_kwargs)
          # collect lower bound (lb), upper bound (ub), and step
          lb = iterator.start
          ub = iterator.end
          step = iterator.step
          num_stages = iterator.num_stages
          loop_unroll_factor = iterator.loop_unroll_factor
          disallow_acc_multi_buffer = iterator.disallow_acc_multi_buffer
          flatten = iterator.flatten
          warp_specialize = iterator.warp_specialize
          disable_licm = iterator.disable_licm
      elif IteratorClass is range:
          # Python's range(stop) / range(start, stop[, step]).
          if not 1 <= len(iter_args) <= 3:
              raise TypeError(f"range expected 1 to 3 arguments, got {len(iter_args)}")
          lb = iter_args[0] if len(iter_args) > 1 else self.visit(ast.Constant(0))
          # Reuse the already-visited argument: visiting node.iter.args[0] again would emit
          # its IR (and any side effects) a second time.
          ub = iter_args[1] if len(iter_args) > 1 else iter_args[0]
          step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Constant(1))
      else:
          raise RuntimeError('Only `range` and `static_range` iterators are currently supported')
      if node.orelse:
          raise self._unsupported(node, "Don't know what to do with else after for")
      # handle negative constant step (not supported by scf.for in MLIR)
      negative_step = False
      if _is_constexpr(step) and step.value < 0:
          step = constexpr(-step.value)
          negative_step = True
          lb, ub = ub, lb
      lb = self.semantic.to_tensor(lb)
      ub = self.semantic.to_tensor(ub)
      step = self.semantic.to_tensor(step)
      # induction variable type
      if not lb.dtype.is_int() or not ub.dtype.is_int() or not step.dtype.is_int():
          raise TypeError(f"For loop bounds and step must all be ints, are ({lb.dtype}, {ub.dtype}, {step.dtype})")
      if _is_non_scalar_tensor(lb):
          raise TypeError(f"For lower bound must be a scalar, got {lb.type}")
      if _is_non_scalar_tensor(ub):
          raise TypeError(f"For upper bound must be a scalar, got {ub.type}")
      if _is_non_scalar_tensor(step):
          raise TypeError(f"For step must be a scalar, got {step.type}")
      iv_type = self.semantic.integer_promote_impl(lb.dtype, ub.dtype)
      iv_type = self.semantic.integer_promote_impl(iv_type, step.dtype)
      iv_ir_type = iv_type.to_ir(self.builder)
      iv_is_signed = iv_type.int_signedness == language.core.dtype.SIGNEDNESS.SIGNED
      # lb/ub/step might be constexpr, we need to cast them to tensor
      lb = lb.handle
      ub = ub.handle
      step = step.handle
      # ForOp can only accept IndexType as lb/ub/step. Cast integer to Index
      lb = self.builder.create_int_cast(lb, iv_ir_type, iv_is_signed)
      ub = self.builder.create_int_cast(ub, iv_ir_type, iv_is_signed)
      step = self.builder.create_int_cast(step, iv_ir_type, iv_is_signed)
      # Create placeholder for the loop induction variable
      iv_placeholder = self.builder.create_poison(iv_ir_type)
      self.set_value(node.target.id, language.core.tensor(iv_placeholder, iv_type))
      with enter_sub_region(self) as sr:
          liveins, insert_block = sr
          ip, last_loc = self._get_insertion_point_and_loc()
          names, init_handles, init_tys = self._find_carries(node, liveins, ignore={node.target.id})
          # create ForOp
          self._set_insertion_point_and_loc(ip, last_loc)
          for_op = self.builder.create_for_op(lb, ub, step, init_handles)
          if _unwrap_if_constexpr(num_stages) is not None:
              for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages))
          if _unwrap_if_constexpr(loop_unroll_factor) is not None:
              for_op.set_attr("tt.loop_unroll_factor", self.builder.get_int32_attr(loop_unroll_factor))
          if disallow_acc_multi_buffer:
              for_op.set_attr("tt.disallow_acc_multi_buffer", self.builder.get_unit_attr())
          if flatten:
              for_op.set_attr("tt.flatten", self.builder.get_unit_attr())
          if warp_specialize:
              for_op.set_attr("tt.warp_specialize", self.builder.get_unit_attr())
          if disable_licm:
              for_op.set_attr("llvm.loop_annotation", self.builder.get_disable_loop_licm_attr())
          self.scf_stack.append(node)
          for_op_body = for_op.get_body(0)
          self.builder.set_insertion_point_to_start(for_op_body)
          # body block argument 0 is the induction variable; carried values follow it
          block_handles = [for_op_body.arg(i + 1) for i in range(len(init_handles))]
          block_args = unflatten_ir_values(block_handles, init_tys)
          for name, val in zip(names, block_args):
              self._maybe_set_loc_to_name(val, name)
              self.set_value(name, val)
          self.visit_compound_statement(node.body)
          self.scf_stack.pop()
          yield_handles = flatten_values_to_ir(self.lscope[name] for name in names)
          # create YieldOp
          if len(yield_handles) > 0:
              self.builder.create_yield_op(yield_handles)
          for_op_region = for_op_body.get_parent()
          assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block"
          # update induction variable with actual value, and replace all uses
          self.builder.set_insertion_point_to_start(for_op_body)
          iv = for_op.get_induction_var()
          if negative_step:
              # bounds were swapped above; map the ascending iv back: ub - iv + lb
              iv = self.builder.create_sub(ub, iv)
              iv = self.builder.create_add(iv, lb)
          iv_placeholder.replace_all_uses_with(iv)
          self.set_value(node.target.id, language.core.tensor(iv, iv_type))
          self._maybe_set_loc_to_name(iv, node.target.id)
      # update lscope & local_defs (ForOp defines new values)
      result_handles = [for_op.get_result(i) for i in range(len(init_handles))]
      result_values = unflatten_ir_values(result_handles, init_tys)
      for name, val in zip(names, result_values):
          self.set_value(name, val)
          self._maybe_set_loc_to_name(val, name)
  def visit_Slice(self, node):
      """Visit ``lower:upper:step`` (in that order) and build a ``language.slice``."""
      bounds = [self.visit(part) for part in (node.lower, node.upper, node.step)]
      return language.slice(*bounds)
  def visit_Index(self, node):
      # Legacy `ast.Index` wrapper: just visit the wrapped value.
      inner = node.value
      return self.visit(inner)
  def visit_keyword(self, node) -> Tuple[str, Any]:
      """Turn a call keyword ``name=expr`` into the pair ``(name, visited expr)``."""
      keyword_name = node.arg
      return (keyword_name, self.visit(node.value))
  def visit_Assert(self, node) -> Any:
      """Lower ``assert test[, msg]`` to ``device_assert``; a missing message becomes ``""``."""
      test = self.visit(node.test)
      if node.msg is None:
          msg = ""
      else:
          msg = self.visit(node.msg)
      return language.core.device_assert(test, msg, _semantic=self.semantic)
  def call_JitFunction(self, fn: JITFunction, args, kwargs, caller_context=None):
      """Emit a call to the ``@jit`` function ``fn``, generating its body on first use.

      Arguments are bound to ``fn.arg_names`` via ``inspect.getcallargs``. Python scalars,
      dtypes and JIT functions are wrapped in ``constexpr``. Constexpr leaves (keyed by their
      path inside possibly nested arguments) and the runtime types of the remaining leaves
      feed ``mangle_fn``, so each specialisation gets its own symbol. If the module has no
      function of that name yet, a nested ``CodeGenerator`` emits it and its return type is
      cached in ``function_ret_types``.

      Returns:
          ``None`` for a void callee, otherwise the call result regrouped into one value of
          the callee's return type.

      Raises:
          CompilationError: wrapping any error from generating the callee, located at the
              current call node (re-raised unchanged when front-end debugging is enabled).
      """
      args = inspect.getcallargs(fn.fn, *args, **kwargs)
      args = [args[name] for name in fn.arg_names]
      for i, arg in enumerate(args):
          if isinstance(arg, (language.dtype, float, int, bool, JITFunction)):
              args[i] = language.core.constexpr(arg)
      # constexpr leaves become part of the specialisation; the rest are passed at run time
      args_cst = find_paths_if(args, lambda _, x: _is_constexpr(x))
      args_cst = {path: get_iterable_path(args, path) for path in args_cst}
      args_path = find_paths_if(args, lambda _, x: not _is_constexpr(x))
      args_val = [get_iterable_path(args, path) for path in args_path]
      # mangle
      caller_context = caller_context or self.caller_context
      fn_name = mangle_fn(get_full_name(fn), [arg.type for arg in args_val], args_cst, caller_context)
      # generate function def if necessary
      if not self.module.has_function(fn_name):
          # If the callee is not set, we use the same debug setting as the caller
          # NOTE(review): no explicit debug flag is passed below (presumably it travels in
          # `options`); this comment may be stale — confirm.
          file_name, begin_line = get_jit_fn_file_line(fn)
          arg_types = [
              language.core.constexpr if arg is None or isinstance(arg, (bool, int, language.core.dtype)) else arg.type
              for arg in args
          ]
          prototype = ASTFunction([], arg_types, args_cst, dict())
          generator = CodeGenerator(self.context, prototype, fn.get_capture_scope(), module=self.module, jit_fn=fn,
                                    function_name=fn_name, function_types=self.function_ret_types,
                                    noinline=fn.noinline, file_name=file_name, begin_line=begin_line,
                                    options=self.builder.options, codegen_fns=self.builder.codegen_fns,
                                    module_map=self.builder.module_map, caller_context=caller_context,
                                    is_gluon=self.is_gluon)
          try:
              generator.visit(fn.parse())
          except Exception as e:
              # Wrap the error in the callee with the location of the call.
              if knobs.compilation.front_end_debugging:
                  raise
              raise CompilationError(self.jit_fn.src, self.cur_node, None) from e
          callee_ret_type = generator.ret_type
          self.function_ret_types[fn_name] = callee_ret_type
      else:
          callee_ret_type = self.function_ret_types[fn_name]
      symbol = self.module.get_function(fn_name)
      args_val = flatten_values_to_ir(args_val)
      call_op = self.builder.call(symbol, args_val)
      if callee_ret_type == language.void:
          return None
      handles = [call_op.get_result(i) for i in range(call_op.get_num_results())]
      return next(unflatten_ir_values(handles, [callee_ret_type]))
  def call_Function(self, node, fn, args, kws):
      """Call ``fn`` from triton code and return the resulting value.

      Dispatch order:
        1. Bound JIT methods / bound constexpr functions are unbound first. Their
           ``__self__`` is prepended to ``args`` (the caller's list is mutated) and
           ``__func__`` becomes ``fn``.
        2. ``JITFunction``: arguments are checked, then ``call_JitFunction`` emits a call.
        3. Methods of triton values, builtins and ``ConstexprFunction``s are called directly.
           ``_semantic`` / ``_generator`` are injected when the signature declares them, and
           plain tuple results are wrapped in ``language.tuple``. Exceptions become a
           ``CompilationError`` at ``node`` unless front-end debugging is enabled.
        4. Anything else is called as ordinary Python. For members of the builtin namespace
           and methods of non-triton objects, the arguments are unwrapped from ``constexpr``
           first. Non-triton results (tuple elements included) are wrapped in ``constexpr``.
      """
      if isinstance(fn, (BoundJITMethod, BoundConstexprFunction)):
          args.insert(0, fn.__self__)
          fn = fn.__func__
      if isinstance(fn, JITFunction):
          _check_fn_args(node, fn, args)
          return self.call_JitFunction(fn, args, kws)
      if (hasattr(fn, '__self__') and _is_triton_value(fn.__self__)) or language.core.is_builtin(fn) or isinstance(
              fn, ConstexprFunction):
          extra_kwargs = dict()
          # ConstexprFunction objects are callables; inspect their __call__ signature
          if isinstance(fn, ConstexprFunction):
              sig = inspect.signature(fn.__call__)
          else:
              sig = inspect.signature(fn)
          if '_semantic' in sig.parameters:
              extra_kwargs["_semantic"] = self.semantic
          if '_generator' in sig.parameters:
              extra_kwargs['_generator'] = self
          try:
              ret = fn(*args, **extra_kwargs, **kws)
              # builtin functions return plain tuples for readability
              if isinstance(ret, tuple):
                  ret = language.tuple(ret)
              return ret
          except Exception as e:
              if knobs.compilation.front_end_debugging:
                  raise
              # Normally when we raise a CompilationError, we raise it as
              # `from None`, because the original fileline from the exception
              # is not relevant (and often points into code_generator.py
              # itself). But when calling a function, we raise as `from e` to
              # preserve the traceback of the original error, which may e.g.
              # be in core.py.
              raise CompilationError(self.jit_fn.src, node, str(e)) from e
      # Plain Python callables: builtins and foreign methods receive unwrapped values
      if fn in self.builtin_namespace.values() or (hasattr(fn, '__self__') and not _is_triton_value(fn.__self__)):
          args = map(_unwrap_if_constexpr, args)
      ret = fn(*args, **kws)

      def wrap_constexpr(x):
          if _is_triton_value(x):
              return x
          return constexpr(x)

      if isinstance(ret, (builtins.tuple, language.tuple)):
          return _apply_to_tuple_values(ret, wrap_constexpr)
      return wrap_constexpr(ret)
  def call_Method(self, node, fn, fn_self, args, kws):
      """Call method ``fn`` of ``fn_self`` through ``call_Function``.

      For a ``JITFunction`` the receiver is prepended to ``args`` (in place), since it is
      invoked like a free function.
      """
      if not isinstance(fn, JITFunction):
          return self.call_Function(node, fn, args, kws)
      args.insert(0, fn_self)
      return self.call_Function(node, fn, args, kws)
  def visit_Call(self, node):
      """Lower a call expression.

      Statically implemented functions get the raw AST node. A callee marked
      ``_must_use_result`` is an error when the call's result is unused. Otherwise keyword
      arguments are visited first, then positional arguments (``*tuple`` is expanded), and
      the call is dispatched through ``call_Function``.
      """
      fn = _unwrap_if_constexpr(self.visit(node.func))
      if not isinstance(fn, BoundJITMethod):
          handler = self.statically_implemented_functions.get(fn)
          if handler is not None:
              return handler(self, node)
      must_use = getattr(fn, '_must_use_result', False)
      if must_use and getattr(node, '_is_unused', False):
          parts = ["The result of %s is not being used." % ast.unparse(node.func)]
          if isinstance(must_use, str):
              parts.append(must_use)
          raise CompilationError(self.jit_fn.src, node, " ".join(parts))
      kws = {name: value for name, value in map(self.visit, node.keywords)}
      args = []
      for arg in node.args:
          if not isinstance(arg, ast.Starred):
              args.append(self.visit(arg))
              continue
          unpacked = self.visit(arg.value)
          assert isinstance(unpacked, language.core.tuple)
          args.extend(unpacked.values)
      return self.call_Function(node, fn, args, kws)
  1231. def visit_Constant(self, node):
  1232. return constexpr(node.value)
  1233. def visit_BoolOp(self, node: ast.BoolOp):
  1234. method_name = self._method_name_for_bool_op.get(type(node.op))
  1235. if method_name is None:
  1236. raise self._unsupported(
  1237. node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__))
  1238. nontrivial_values = []
  1239. for subnode in node.values:
  1240. # we visit the values in order, executing their side-effects
  1241. # and possibly early-exiting:
  1242. value = self.visit(subnode)
  1243. if not _is_triton_tensor(value):
  1244. # this is a constexpr, so we might be able to short-circuit:
  1245. bv = bool(value)
  1246. if (bv is False) and (method_name == "logical_and"):
  1247. # value is falsey so return that:
  1248. return value
  1249. if (bv is True) and (method_name == "logical_or"):
  1250. # value is truthy so return that:
  1251. return value
  1252. # otherwise, our constexpr has no effect on the output of the
  1253. # expression so we do not append it to nontrivial_values.
  1254. else:
  1255. if value.type.is_block():
  1256. lineno = getattr(node, "lineno", None)
  1257. if lineno is not None:
  1258. lineno += self.begin_line
  1259. warnings.warn_explicit(
  1260. "Logical operators 'and' and 'or' are deprecated for non-scalar tensors; please use '&' or '|' instead",
  1261. category=UserWarning,
  1262. filename=self.file_name,
  1263. lineno=lineno,
  1264. source=ast.unparse(node),
  1265. )
  1266. # not a constexpr so we must append it:
  1267. nontrivial_values.append(value)
  1268. if len(nontrivial_values) == 0:
  1269. # the semantics of a disjunction of falsey values or conjunction
  1270. # of truthy values is to return the final value:
  1271. nontrivial_values.append(value)
  1272. while len(nontrivial_values) >= 2:
  1273. rhs = nontrivial_values.pop()
  1274. lhs = nontrivial_values.pop()
  1275. res = self._apply_binary_method(node, method_name, lhs, rhs)
  1276. nontrivial_values.append(res)
  1277. assert len(nontrivial_values) == 1
  1278. return nontrivial_values[0]
  1279. _method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'}
  1280. def get_Attribute(self, lhs, attr):
  1281. if _is_triton_tensor(lhs) and attr == "T":
  1282. return self.semantic.permute(lhs, (1, 0))
  1283. # NOTE: special case ".value" for BC
  1284. if isinstance(lhs, constexpr) and attr not in ("value", "type"):
  1285. lhs = lhs.value
  1286. attr = getattr(lhs, attr)
  1287. if _is_triton_value(lhs) and isinstance(attr, JITFunction):
  1288. return BoundJITMethod(lhs, attr)
  1289. return attr
  1290. def visit_Attribute(self, node):
  1291. lhs = self.visit(node.value)
  1292. if isinstance(lhs, ModuleType):
  1293. # follow module_map until reaching fixed-point:
  1294. while (name := lhs.__name__) in self.builder.module_map:
  1295. lhs = self.builder.module_map[name]
  1296. if lhs.__name__ == name:
  1297. break
  1298. return self.get_Attribute(lhs, node.attr)
  1299. def visit_Expr(self, node):
  1300. node.value._is_unused = True
  1301. ast.NodeVisitor.generic_visit(self, node)
  1302. def visit_NoneType(self, node):
  1303. return None
  1304. def visit_JoinedStr(self, node):
  1305. values = list(node.values)
  1306. for i, value in enumerate(values):
  1307. if isinstance(value, ast.Constant):
  1308. values[i] = str(value.value)
  1309. elif isinstance(value, ast.FormattedValue):
  1310. conversion_code = value.conversion
  1311. evaluated = self.visit(value.value)
  1312. if not _is_constexpr(evaluated):
  1313. raise self._unsupported(
  1314. node,
  1315. "Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type "
  1316. + str(type(evaluated)))
  1317. values[i] = ("{}" if conversion_code < 0 else "{!" + chr(conversion_code) + "}").format(evaluated.value)
  1318. else:
  1319. raise AssertionError("encountered unexpected node of type {} in a JoinedStr node".format(type(value)))
  1320. return ''.join(values)
    def visit(self, node):
        """Visit ``node`` with source-location tracking and error wrapping.

        Before dispatching to the ``visit_*`` handler via ``ast.NodeVisitor.visit``,
        this sets ``self.cur_node``. When the node has ``lineno``/``col_offset``,
        it also points the builder's current location at that source position,
        wrapped in a name location when ``self.name_loc_as_prefix`` is set.
        After a successful visit, ``self.cur_node`` and the builder location are
        reset, but only if a location was available (``last_loc`` truthy).

        Returns:
            The handler's return value, or ``None`` when ``node`` is ``None``.

        Raises:
            CompilationError: re-raised unchanged from the handler, or wrapping
                any other exception. With
                ``knobs.compilation.front_end_debugging`` enabled, the original
                exception propagates instead.
        """
        if node is None:
            return
        with warnings.catch_warnings():
            # The ast library added visit_Constant and deprecated some other
            # methods but we can't move to that without breaking Python 3.6 and 3.7.
            warnings.simplefilter("ignore", DeprecationWarning)  # python 3.9
            warnings.simplefilter("ignore", PendingDeprecationWarning)  # python 3.8
            # Save the caller's node/location so they can be restored afterwards.
            last_node = self.cur_node
            last_loc = self.builder.get_loc()
            self.cur_node = node
            if hasattr(node, 'lineno') and hasattr(node, 'col_offset'):
                # AST line numbers are offset by self.begin_line (the same offset
                # visit_BoolOp applies to warning line numbers).
                here_loc = self.builder.create_loc(self.file_name, self.begin_line + node.lineno, node.col_offset)
                if self.name_loc_as_prefix is not None:
                    self.builder.set_loc(self.builder.create_name_loc(self.name_loc_as_prefix, here_loc))
                else:
                    self.builder.set_loc(here_loc)
                # NOTE(review): this overwrites the caller's location saved above
                # with the location just set for `node`. The reset after the visit
                # therefore restores `node`'s own location, not the pre-visit one
                # the comment below describes. Confirm this is intended.
                last_loc = self.builder.get_loc()
            try:
                ret = super().visit(node)
            except CompilationError:
                raise
            except Exception as e:
                if knobs.compilation.front_end_debugging:
                    raise
                # Wrap the error in a CompilationError which contains the source
                # of the @jit function.
                raise CompilationError(self.jit_fn.src, self.cur_node, repr(e)) from None
            # Reset the location to the last one before the visit
            # (when last_loc is falsy, self.cur_node is not restored either).
            if last_loc:
                self.cur_node = last_node
                self.builder.set_loc(last_loc)
            return ret
  1354. def generic_visit(self, node):
  1355. raise self._unsupported(node, "unsupported AST node type: {}".format(type(node).__name__))
  1356. def execute_static_assert(self, node: ast.Call) -> None:
  1357. arg_count = len(node.args)
  1358. if not (0 < arg_count <= 2) or len(node.keywords):
  1359. raise TypeError("`static_assert` requires one or two positional arguments only")
  1360. passed = _unwrap_if_constexpr(self.visit(node.args[0]))
  1361. if not isinstance(passed, bool):
  1362. raise NotImplementedError(
  1363. "Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values"
  1364. )
  1365. if not passed:
  1366. if arg_count == 1:
  1367. message = ""
  1368. else:
  1369. try:
  1370. message = self.visit(node.args[1])
  1371. except Exception as e:
  1372. message = "<failed to evaluate assertion message: " + repr(e) + ">"
  1373. raise CompileTimeAssertionFailure(self.jit_fn.src, node, _unwrap_if_constexpr(message))
  1374. return None
  1375. def static_executor(python_fn):
  1376. def ret(self, node: ast.Call):
  1377. kws = {
  1378. name: _unwrap_if_constexpr(value)
  1379. for name, value in (self.visit(keyword) for keyword in node.keywords)
  1380. }
  1381. args = [_unwrap_if_constexpr(self.visit(arg)) for arg in node.args]
  1382. return constexpr(python_fn(*args, **kws))
  1383. return ret
    # Imported inside the class body, so `ttgl` is bound in the class namespace;
    # Gluon's static_assert/static_print must map to the same handlers as the
    # Triton ones below.
    from ..experimental.gluon import language as ttgl

    # Callables evaluated entirely at compile time. visit_Call looks the callee
    # up here and, on a hit, calls `handler(self, node)` with the raw ast.Call
    # instead of lowering the call.
    # NOTE(review): handlers actually take (self, node), so the annotation's
    # Callable[[ast.Call], Any] understates their arity.
    statically_implemented_functions: Dict[object, Callable[[ast.Call], Any]] = {
        language.core.static_assert: execute_static_assert,
        language.core.static_print: static_executor(print),
        ttgl.static_assert: execute_static_assert,
        ttgl.static_print: static_executor(print),
        int: static_executor(int),
        len: static_executor(len),
    }
  1393. def ast_to_ttir(fn, src, context, options, codegen_fns, module_map, module=None):
  1394. arg_types = [None] * len(fn.arg_names)
  1395. for k, v in src.signature.items():
  1396. idx = fn.arg_names.index(k)
  1397. arg_types[idx] = str_to_ty(v, None)
  1398. def apply_constexpr_types(argument, indices, value):
  1399. index = indices.pop()
  1400. if len(indices) == 0:
  1401. if isinstance(argument, list):
  1402. argument[index] = constexpr(value).type
  1403. else:
  1404. argument.types[index] = constexpr(value).type
  1405. else:
  1406. apply_constexpr_types(argument[index], indices, value)
  1407. for path, value in src.constants.items():
  1408. apply_constexpr_types(arg_types, list(path)[::-1], value)
  1409. prototype = ASTFunction([], arg_types, src.constants, src.attrs)
  1410. file_name, begin_line = get_jit_fn_file_line(fn)
  1411. # query function representation
  1412. from collections import namedtuple
  1413. leaves = filter(lambda v: len(v) == 1, src.constants)
  1414. constants = {fn.arg_names[i[0]]: src.constants[i] for i in leaves}
  1415. signature = src.signature
  1416. proxy = namedtuple("SpecializationProxy", ["constants", "signature"])(constants, signature)
  1417. generator = CodeGenerator(context, prototype, gscope=fn.get_capture_scope(), function_name=fn.repr(proxy),
  1418. jit_fn=fn, is_kernel=True, file_name=file_name, begin_line=begin_line, options=options,
  1419. codegen_fns=codegen_fns, module_map=module_map, module=module, is_gluon=fn.is_gluon())
  1420. generator.visit(fn.parse())
  1421. module = generator.module
  1422. # module takes ownership of the context
  1423. module.context = context
  1424. if not module.verify():
  1425. if not fn.is_gluon():
  1426. print(module)
  1427. raise RuntimeError("error encountered during parsing")
  1428. return module