| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639 |
- import ast
- import builtins
- import contextlib
- import copy
- import inspect
- import re
- import warnings
- import textwrap
- from dataclasses import dataclass
- from types import ModuleType
- from typing import Any, Callable, Dict, Optional, Tuple, Type, Union, Iterable, List
- from .. import knobs, language
- from .._C.libtriton import ir, gluon_ir
- from ..language import constexpr, str_to_ty, tensor, tuple as tl_tuple
- from ..language.core import _unwrap_if_constexpr, base_value, base_type
- # ideally we wouldn't need any runtime component
- from ..runtime.jit import get_jit_fn_file_line, get_full_name, JITCallable, BoundConstexprFunction, ConstexprFunction, JITFunction
- from .._utils import find_paths_if, get_iterable_path, set_iterable_path, is_namedtuple
- from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
def check_identifier_legality(name, type):
    """Validate that ``name`` can be used as a generated symbol identifier.

    Args:
        name: Candidate identifier.
        type: Human-readable kind of identifier (e.g. ``"function"``), used only
            in the error message.

    Returns:
        ``name`` unchanged when it is legal.

    Raises:
        CompilationError: If ``name`` is not an ASCII letter or underscore
            followed by ASCII letters, digits or underscores.
    """
    # `re.fullmatch` anchors at the real end of the string. The previous
    # `re.match(r'^...$')` also accepted one trailing newline, because `$`
    # matches just before a final "\n".
    if re.fullmatch(r'[a-zA-Z_][a-zA-Z0-9_]*', name) is None:
        raise CompilationError(f"invalid {type} identifier: {name}", name)
    return name
def mangle_fn(name, arg_tys, constants, caller_context):
    """Build a unique symbol name for one specialization of function ``name``.

    The name encodes each argument type (via ``ty.mangle()``) and every constexpr
    value (by index, via ``repr``). The return type is deliberately left out,
    because it must be a function of the argument types. When ``caller_context``
    is given, its own mangling is appended.
    """
    arg_part = '_'.join(ty.mangle() for ty in arg_tys)
    const_part = '_'.join(f'{idx}c{constants[idx]!r}' for idx in sorted(constants))
    # Rewrite characters that must not appear in the symbol; [ and ] are not
    # allowed in LLVM identifiers.
    for old, new in (('.', '_d_'), ("'", '_sq_'), ('[', '_'), (']', '_')):
        const_part = const_part.replace(old, new)
    mangled = f'{name}__{arg_part}__{const_part}'
    if caller_context is None:
        return mangled
    return mangled + caller_context.mangle()
def _is_triton_value(o: Any) -> bool:
    """Return True when ``o`` is a Triton frontend value (a ``base_value`` instance)."""
    return isinstance(o, (base_value, ))
def _is_triton_tensor(o: Any) -> bool:
    """Return True when ``o`` is a Triton ``tensor`` instance."""
    return isinstance(o, (tensor, ))
- def _is_constexpr(o: Any) -> bool:
- return o is None or isinstance(o, (constexpr, language.core.dtype, JITCallable))
def _is_non_scalar_tensor(o: Any) -> bool:
    """Return True for a block-typed Triton tensor with more than one element."""
    if not _is_triton_tensor(o):
        return False
    return o.type.is_block() and o.type.numel != 1
- def _is_list_like(o: Any) -> bool:
- return isinstance(o, (list, tuple))
- def _check_fn_args(node, fn, args):
- if fn.noinline:
- for idx, arg in enumerate(args):
- if not _is_constexpr(arg) and _is_non_scalar_tensor(arg):
- raise UnsupportedLanguageConstruct(
- fn.src, node,
- f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}'
- )
def _apply_to_tuple_values(value, fn):
    """Map ``fn`` over a namedtuple or ``language.tuple`` and rebuild a ``language.tuple``.

    The field names of the input are kept. A ``None`` produced by ``fn`` is
    wrapped as ``constexpr(None)``, so that every element has a ``.type``.
    """
    if is_namedtuple(type(value)):
        fields = value._fields
    elif isinstance(value, language.tuple):
        fields = value.type.fields
    else:
        assert False, f"Unsupported type {type(value)}"
    mapped = []
    for item in value:
        result = fn(item)
        mapped.append(constexpr(result) if result is None else result)
    element_types = [elem.type for elem in mapped]
    return language.tuple(mapped, language.tuple_type(element_types, fields))
def flatten_values_to_ir(values: Iterable[base_value]):
    """Return the IR handles of all ``values`` concatenated, in order.

    Each value appends its own handle(s) to the shared list through ``_flatten_ir``.
    """
    collected = []
    for val in values:
        val._flatten_ir(collected)
    return collected
def unflatten_ir_values(handles: List[ir.value], types: List[base_type]):
    """Yield one frontend value per entry of ``types``, rebuilt from ``handles``.

    Each type consumes as many handles as it needs, starting at the running
    position. Once the generator is exhausted, it asserts that every handle
    was consumed.
    """
    pos = 0
    for ty in types:
        rebuilt, pos = ty._unflatten_ir(handles, pos)
        yield rebuilt
    assert pos == len(handles)
- _condition_types = {bool, int, type(None)} # Python types accepted for conditionals inside kernels
- def _clone_triton_value(val):
- handles = []
- val._flatten_ir(handles)
- clone, _ = val.type._unflatten_ir(handles, 0)
- return clone
def _clone_scope(scope):
    """Copy a name -> value scope, cloning Triton values and keeping other values as-is."""
    cloned = {}
    for name, val in scope.items():
        cloned[name] = _clone_triton_value(val) if _is_triton_value(val) else val
    return cloned
class enter_sub_region:
    """Context manager that saves and restores code-generator state around a nested region.

    On entry it clones the generator's ``lscope`` and ``local_defs``, resets
    ``local_defs`` to empty, and records the builder's current insertion block
    and point. ``__enter__`` returns ``(liveins, insert_block)``.

    On exit it restores the insertion point and puts back the saved scope and
    definitions. Exceptions are not suppressed.
    """

    def __init__(self, generator):
        self.generator = generator

    def __enter__(self):
        gen = self.generator
        builder = gen.builder
        # record lscope & local_defs in the parent scope
        self.liveins = _clone_scope(gen.lscope)
        self.prev_defs = _clone_scope(gen.local_defs)
        gen.local_defs = {}
        self.insert_block = builder.get_insertion_block()
        self.insert_point = builder.get_insertion_point()
        return self.liveins, self.insert_block

    def __exit__(self, *args, **kwargs):
        gen = self.generator
        gen.builder.restore_insertion_point(self.insert_point)
        gen.lscope = self.liveins
        gen.local_defs = self.prev_defs
# Check if the given syntax node has an "early" return
class ContainsReturnChecker(ast.NodeVisitor):
    """AST visitor that reports whether a syntax node contains a ``return`` statement.

    Every ``visit_*`` method returns a bool. Functions reached through names or
    attributes resolved in ``gscope`` are not inspected (``_visit_function``
    always returns False). Assignments are treated as never containing a return.
    """

    def __init__(self, gscope):
        # Mapping used to resolve global names/attributes seen in the checked code.
        self.gscope = gscope

    def _visit_stmts(self, body) -> bool:
        return any(self.visit(s) for s in body)

    def _visit_function(self, fn) -> bool:
        # No need to check within the function as it won't cause an early return.
        # If the function itself has unstructured control flow we may not be able to inline it causing poor performance,
        # we should check for this and emit a warning.
        return False

    def generic_visit(self, node) -> bool:
        # Visit AST children (stops visiting further children once one reports True).
        ret = False
        for _, value in ast.iter_fields(node):
            if isinstance(value, list):
                for item in value:
                    if isinstance(item, ast.AST):
                        ret = ret or self.visit(item)
            elif isinstance(value, ast.AST):
                ret = ret or self.visit(value)
        return ret

    def visit_Attribute(self, node: ast.Attribute) -> bool:
        # If the left part is a name, it's possible that
        # we call triton native function or a jit function from another module.
        # If the left part is not a name, it must return a tensor or a constexpr
        # whose methods do not contain return statements
        # e.g., (tl.load(x)).to(y)
        # So we only check if the expressions within value have return or not
        if isinstance(node.value, ast.Name):
            if node.value.id in self.gscope:
                value = self.gscope[node.value.id]
                # The attribute may not exist (e.g. a typo in user code). Don't crash
                # here; the code generator reports such errors itself.
                fn = getattr(value, node.attr, None)
                return self._visit_function(fn)
            return False
        return self.visit(node.value)

    def visit_Name(self, node: ast.Name) -> bool:
        if type(node.ctx) is ast.Store:
            return False
        if node.id in self.gscope:
            fn = self.gscope[node.id]
            return self._visit_function(fn)
        return False

    def visit_Return(self, node: ast.Return) -> bool:
        return True

    def visit_Assign(self, node: ast.Assign) -> bool:
        # There couldn't be an early return
        # x = ...
        return False

    def visit_AugAssign(self, node: ast.AugAssign) -> bool:
        # There couldn't be an early return
        # x += ...
        return False

    def visit_Module(self, node: ast.Module) -> bool:
        return self._visit_stmts(node.body)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> bool:
        return self._visit_stmts(node.body)

    def visit_If(self, node: ast.If) -> bool:
        # TODO: optimize the following case in which we actually don't have
        # a return when static_cond is false:
        # if dynamic_cond
        #   if static_cond
        #     func_with_return
        #   else
        #     func_without_return
        ret = self._visit_stmts(node.body)
        if node.orelse:
            ret = ret or self._visit_stmts(node.orelse)
        return ret

    def visit_IfExp(self, node: ast.IfExp) -> bool:
        return self.visit(node.body) or self.visit(node.orelse)

    def visit_Call(self, node: ast.Call) -> bool:
        return self.visit(node.func)
class ASTFunction:
    """Signature of a function being lowered: types, constexpr arguments and IR attributes.

    Attributes (as used below):
        ret_types: Frontend return types; ``None`` entries are skipped when lowering.
        arg_types: Argument types, possibly nested; individual arguments are
            addressed by paths (via ``find_paths_if`` / ``get_iterable_path``).
        constants: Mapping from argument path to constexpr value. These
            arguments are not passed as IR values.
        attrs: Mapping from argument path to ``(attr_name, attr_value)`` pairs,
            applied at the argument's first flattened IR index.
    """

    def __init__(self, ret_types, arg_types, constants, attrs):
        self.ret_types = ret_types
        self.arg_types = arg_types
        self.constants = constants
        self.attrs = attrs

    def _value_arg_paths(self):
        """Return the paths of arguments that are passed as IR values.

        A path qualifies when it is not bound in ``self.constants`` and its type
        is not ``None``. ``serialize`` and ``deserialize`` both use this helper,
        so they always agree on the order of IR arguments.
        """

        def is_val(path, ty):
            return path not in self.constants and ty is not None

        return list(find_paths_if(self.arg_types, is_val))

    def flatten_ir_types(self, builder: ir.builder, types: List[base_type]) -> List[ir.type]:
        """Flatten frontend ``types`` into a list of IR types, skipping ``None`` entries."""
        ir_types = []
        for ty in types:
            if ty is None:
                continue
            ty._flatten_ir_types(builder, ir_types)
        return ir_types

    def return_types_ir(self, builder: ir.builder) -> List[ir.type]:
        """Return the flattened IR types of ``self.ret_types``."""
        return self.flatten_ir_types(builder, self.ret_types)

    def serialize(self, builder: ir.builder):
        """Build the IR function type (flattened value-argument types -> flattened return types)."""
        # fill up IR values in template
        # > build function
        arg_types = [get_iterable_path(self.arg_types, path) for path in self._value_arg_paths()]
        arg_types_ir = self.flatten_ir_types(builder, arg_types)
        ret_types_ir = self.return_types_ir(builder)
        return builder.get_function_ty(arg_types_ir, ret_types_ir)

    def deserialize(self, fn):
        """Rebuild frontend argument values from the IR function ``fn``.

        Builds a template that mirrors ``self.arg_types`` (a ``language.tuple``
        when ``arg_types`` is a list/tuple). Value arguments are reconstructed
        from ``fn``'s IR arguments, with ``self.attrs`` applied. Constexpr
        arguments are filled from ``self.constants``. Every other slot stays
        ``constexpr(None)``.
        """

        # create "template"
        def make_template(ty):
            if isinstance(ty, (list, tuple, language.tuple_type)):
                return language.tuple([make_template(x) for x in ty], ty)
            return language.constexpr(None)

        vals = make_template(self.arg_types)
        val_paths = self._value_arg_paths()
        # > add IR values to the template
        cursor = 0
        handles = [fn.args(i) for i in range(fn.get_num_args())]
        for path in val_paths:
            ty = get_iterable_path(self.arg_types, path)
            # > set attributes
            attr_specs = self.attrs.get(path, [])
            for attr_name, attr_val in attr_specs:
                fn.set_arg_attr(cursor, attr_name, attr_val)
            # > build frontend value
            val, cursor = ty._unflatten_ir(handles, cursor)
            set_iterable_path(vals, path, val)
        # > add constexpr values to the template
        for path, val in self.constants.items():
            set_iterable_path(vals, path, language.constexpr(val))
        return vals
@dataclass(frozen=True)
class BoundJITMethod:
    """Frozen pair of a receiver value and a ``JITFunction``.

    The field names follow Python's bound-method attributes.
    """

    # receiver the function is bound to
    __self__: base_value
    # the underlying JIT function
    __func__: JITFunction
- class CodeGenerator(ast.NodeVisitor):
- def __init__(self, context, prototype, gscope, function_name, jit_fn: JITFunction, *, options, codegen_fns,
- module_map, is_gluon, module=None, is_kernel=False, function_types: Optional[Dict] = None,
- noinline=False, caller_context=None, file_name: Optional[str] = None, begin_line=0):
- self.context = context
- self.is_gluon = is_gluon
- if is_gluon:
- from triton.experimental.gluon.language._semantic import GluonSemantic
- self.builder = gluon_ir.GluonOpBuilder(context)
- self.semantic = GluonSemantic(self.builder)
- else:
- from triton.language.semantic import TritonSemantic
- self.builder = ir.builder(context)
- self.semantic = TritonSemantic(self.builder)
- self.name_loc_as_prefix = None
- self.file_name = file_name
- # node.lineno starts from 1, so we need to subtract 1
- self.begin_line = begin_line - 1
- self.builder.set_loc(file_name, begin_line, 0)
- self.builder.options = options
- # dict of functions provided by the backend. Below are the list of possible functions:
- # Convert custom types not natively supported on HW.
- # convert_custom_types(input_tensor, dtype, fp_downcast_rounding=None, _builder=None)
- self.builder.codegen_fns = codegen_fns
- self.builder.module_map = {} if module_map is None else module_map
- self.module = self.builder.create_module() if module is None else module
- self.function_ret_types = {} if function_types is None else function_types
- self.prototype = prototype
- self.gscope = {}
- for k, v in gscope.items():
- if isinstance(v, ModuleType):
- self.gscope[k] = module_map.get(v.__name__, v)
- continue
- module_name = getattr(v, "__module__", "")
- if module_name in module_map:
- self.gscope[k] = getattr(module_map[module_name], v.__name__)
- else:
- self.gscope[k] = v
- self.lscope = {}
- self.jit_fn = jit_fn
- # TODO: we currently generate illegal names for non-kernel functions involving constexprs!
- if is_kernel:
- function_name = function_name[function_name.rfind('.') + 1:]
- function_name = check_identifier_legality(function_name, "function")
- self.function_name = function_name
- self.is_kernel = is_kernel
- self.cur_node = None
- self.noinline = noinline
- self.caller_context = caller_context
- self.scf_stack = []
- self.ret_type = None
- # SSA-construction
- # name => language.tensor
- self.local_defs: Dict[str, tensor] = {}
- self.dereference_name: Callable[[str], Any] = self._define_name_lookup()
- self.fn = None
- # Are we currently visiting an ast.arg's default value? These have some
- # special handling.
- self.visiting_arg_default_value = False
- builtin_namespace: Dict[str, Any] = {
- _.__name__: _
- for _ in (len, list, range, float, int, isinstance, getattr, hasattr)
- }
- builtin_namespace.update((
- ('print', language.core.device_print),
- ('min', language.core.builtin_min),
- ('max', language.core.builtin_max),
- ))
- def _unsupported(self, node, message):
- return UnsupportedLanguageConstruct(self.jit_fn.src, node, message)
- def _is_constexpr_global(self, name):
- absent_marker = object()
- val = self.gscope.get(name, absent_marker)
- if val is absent_marker:
- return False
- if _is_constexpr(val):
- return True
- return False
- def _define_name_lookup(self):
- def local_lookup(name: str, absent):
- # this needs to be re-fetched from `self` every time, because it gets switched occasionally
- return self.lscope.get(name, absent)
- def global_lookup(name: str, absent):
- val = self.gscope.get(name, absent)
- # The high-level rule is that only constexpr globals are allowed.
- # But actually a bunch of other things, such as module imports, are
- # technically Python globals. We have to allow these too!
- if any([
- val is absent,
- name in self.builtin_namespace, #
- type(val) is ModuleType, #
- isinstance(val, JITCallable), #
- getattr(val, "__triton_builtin__", False), #
- getattr(val, "__triton_aggregate__", False), #
- getattr(val, "__module__", "").startswith("triton.language"), #
- getattr(val, "__module__", "").startswith("triton.experimental.gluon.language"), #
- isinstance(val, language.dtype), #
- is_namedtuple(val),
- self._is_constexpr_global(name), #
- # Allow accesses to globals while visiting an ast.arg
- # because you should be able to do
- # @triton.jit def fn(x: tl.constexpr = GLOBAL): ...
- self.visiting_arg_default_value, #
- knobs.compilation.allow_non_constexpr_globals,
- ]):
- return val
- raise NameError(
- textwrap.dedent(f"""\
- Cannot access global variable {name} from within @jit'ed
- function. Triton kernels can only access global variables that
- are instanstiated as constexpr (`x = triton.language.constexpr(42)`). Note that this is different from
- annotating a variable as constexpr (`x: triton.language.constexpr = 42`), which is not supported. Alternatively, set the
- envvar TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1, but we do not
- promise to support this forever.""").replace("\n", " "))
- absent_marker = object()
- def name_lookup(name: str) -> Any:
- absent = absent_marker
- for lookup_function in local_lookup, global_lookup, self.builtin_namespace.get:
- value = lookup_function(name, absent)
- if value is not absent:
- return value
- raise NameError(f'{name} is not defined')
- return name_lookup
- @contextlib.contextmanager
- def _name_loc_prefix(self, prefix):
- self.name_loc_as_prefix = prefix
- yield
- self.name_loc_as_prefix = None
- def _maybe_set_loc_to_name(self, val, name):
- if isinstance(val, (ir.value, ir.block_argument)):
- val.set_loc(self.builder.create_name_loc(name, val.get_loc()))
- elif _is_triton_value(val):
- handles = []
- val._flatten_ir(handles)
- for handle in handles:
- handle.set_loc(self.builder.create_name_loc(name, handle.get_loc()))
- def set_value(self, name: str, value: Union[base_value, constexpr]) -> None:
- ''' This function:
- called by visit_Assign() & visit_FunctionDef() to store left value (lvalue)
- 1. record local defined name (FIXME: should consider control flow)
- 2. store tensor in self.lvalue
- '''
- self.lscope[name] = value
- self.local_defs[name] = value
- def _get_insertion_point_and_loc(self):
- # XXX: this is a hack to get the location of the insertion point.
- # The insertion point's location could be invalid sometimes,
- # so we need to explicitly set the location
- loc = self.builder.get_loc()
- ip = self.builder.get_insertion_point()
- return ip, loc
- def _set_insertion_point_and_loc(self, ip, loc):
- self.builder.restore_insertion_point(ip)
- self.builder.set_loc(loc)
- def _find_carries(self, node, liveins, ignore: set[str] = set()):
- # create loop body block
- block = self.builder.create_block()
- self.builder.set_insertion_point_to_start(block)
- # dry visit loop body
- self.scf_stack.append(node)
- self.visit_compound_statement(node.body)
- self.scf_stack.pop()
- block.erase()
- # If a variable (name) has changed value within the loop, then it's
- # a loop-carried variable. (The new and old value must be of the
- # same type)
- init_tys = []
- init_handles = []
- names = []
- for name, live_val in liveins.items():
- if name in ignore:
- continue
- if _is_triton_value(live_val):
- loop_val = self.lscope[name]
- self._verify_loop_carried_variable(name, loop_val, live_val)
- live_handles = flatten_values_to_ir([live_val])
- loop_handles = flatten_values_to_ir([loop_val])
- if live_handles != loop_handles:
- names.append(name)
- init_tys.append(live_val.type)
- init_handles.extend(live_handles)
- else:
- assert name not in self.local_defs, f'Loop carried variable {name} is not a triton value'
- # reset local scope to not pick up local defs from the dry run.
- self.lscope = liveins.copy()
- self.local_defs = {}
- return names, init_handles, init_tys
- #
- # AST visitor
- #
- def visit_compound_statement(self, stmts):
- # Ensure that stmts is iterable
- if not _is_list_like(stmts):
- stmts = [stmts]
- for stmt in stmts:
- self.visit(stmt)
- # Stop parsing as soon as we hit a `return` statement; everything
- # after this is dead code.
- if isinstance(stmt, ast.Return):
- break
- def visit_Module(self, node):
- ast.NodeVisitor.generic_visit(self, node)
- def visit_List(self, node):
- ctx = self.visit(node.ctx)
- assert ctx is None
- elts = language.tuple([self.visit(elt) for elt in node.elts])
- return elts
- def visit_ListComp(self, node: ast.ListComp):
- if len(node.generators) != 1:
- raise ValueError("nested comprehensions are not supported")
- comp = node.generators[0]
- iter = self.visit(comp.iter)
- if not isinstance(iter, tl_tuple):
- raise NotImplementedError("only tuple comprehensions are supported")
- results = []
- for item in iter:
- self.set_value(comp.target.id, item)
- results.append(self.visit(node.elt))
- return tl_tuple(results)
- # By design, only non-kernel functions can return
- def visit_Return(self, node):
- ret_value = self.visit(node.value)
- handles = []
- def decay(value):
- if isinstance(value, language.tuple):
- return _apply_to_tuple_values(value, decay)
- elif isinstance(value, (language.constexpr, int, float)):
- return self.semantic.to_tensor(value)
- return value
- ret_value = decay(ret_value)
- if ret_value is None:
- ret_ty = language.void
- else:
- assert isinstance(ret_value, language.core.base_value)
- ret_value._flatten_ir(handles)
- ret_ty = ret_value.type
- self.builder.ret(handles)
- if self.ret_type is None:
- self.ret_type = ret_ty
- elif self.ret_type != ret_ty:
- raise TypeError(f'Inconsistent return types: {self.ret_type} and {ret_ty}')
- # A return op must always terminate the basic block, so we create a dead
- # basic block in case there are any ops after the return.
- post_ret_block = self.builder.create_block()
- self.builder.set_insertion_point_to_end(post_ret_block)
    def visit_FunctionDef(self, node):
        """Generate the IR function for ``node``, the definition of the jitted function.

        Steps:
          1. Collect the argument names.
          2. Evaluate argument default values and bind them as locals.
          3. Create (or look up) the IR function from ``self.prototype``.
          4. Bind the deserialized arguments to their names.
          5. Emit the body.
          6. Terminate the trailing block with a return: a void return, or poison
             values of the recorded return type.

        Raises:
            UnsupportedLanguageConstruct: If a function definition is nested
                inside the function currently being generated.
        """
        # NOTE(review): kwarg_names is not used below.
        arg_names, kwarg_names = self.visit(node.args)
        if self.fn:
            raise self._unsupported(node, "nested function definition is not supported.")
        # initialize defaults
        # `defaults` align with the *last* positional arguments, so walk both lists from the end.
        for i, default_value in enumerate(node.args.defaults[::-1]):
            arg_node = node.args.args[-i - 1]
            annotation = arg_node.annotation
            name = arg_node.arg
            st_target = ast.Name(id=name, ctx=ast.Store())
            # Synthesize `name = default` (or `name: annotation = default`) and visit it.
            if annotation is None:
                init_node = ast.Assign(targets=[st_target], value=default_value)
            else:
                init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation)

            try:
                # While this flag is set, global-name lookup also accepts non-constexpr globals.
                assert not self.visiting_arg_default_value
                self.visiting_arg_default_value = True
                self.visit(init_node)
            finally:
                self.visiting_arg_default_value = False

        # initialize function
        visibility = "public" if self.is_kernel else "private"
        fn_ty = self.prototype.serialize(self.builder)
        self.fn = self.builder.get_or_insert_function(self.module, self.function_name, fn_ty, visibility, self.noinline)
        self.module.push_back(self.fn)
        entry = self.fn.add_entry_block()
        arg_values = self.prototype.deserialize(self.fn)
        if self.caller_context is not None:
            self.caller_context.initialize_callee(self.fn, self.builder)
        # bind arguments to symbols
        for arg_name, arg_value in zip(arg_names, arg_values):
            self._maybe_set_loc_to_name(arg_value, arg_name)
            self.set_value(arg_name, arg_value)
        # Remember the caller's insertion block so it can be restored afterwards.
        insert_pt = self.builder.get_insertion_block()
        self.builder.set_insertion_point_to_start(entry)
        # visit function body
        self.visit_compound_statement(node.body)

        # finalize function
        # visit_Return always moves insertion into a fresh block, so the current block has no terminator yet.
        assert not self.builder.get_insertion_block().has_terminator()
        if self.ret_type is None or self.ret_type == language.void:
            self.ret_type = language.void
            self.builder.ret([])
        else:
            # The real return type is only known after the body has been visited.
            # Update the prototype and the IR function type to match, then close
            # the trailing block with poison values of that type.
            if isinstance(self.ret_type, language.tuple_type):
                self.prototype.ret_types = self.ret_type.types
            else:
                self.prototype.ret_types = [self.ret_type]
            self.fn.reset_type(self.prototype.serialize(self.builder))
            self.builder.ret([self.builder.create_poison(ty) for ty in self.prototype.return_types_ir(self.builder)])
        self.fn.finalize()

        if insert_pt:
            self.builder.set_insertion_point_to_end(insert_pt)
- def visit_arguments(self, node):
- arg_names = []
- for arg in node.args:
- arg_names += [self.visit(arg)]
- kwarg_names = self.visit(node.kwarg)
- return arg_names, kwarg_names
- def visit_arg(self, node):
- ast.NodeVisitor.generic_visit(self, node)
- param = next(p for p in self.jit_fn.params if p.name == node.arg)
- if param.is_constexpr and (param.do_not_specialize or param.do_not_specialize_on_alignment):
- raise CompilationError(
- self.jit_fn.src, node,
- f"{node.arg} marked as constexpr and listed in do_not_specialize/do_not_specialize_on_alignment. "
- "Remove constexpr designation to skip specialization.")
- return node.arg
- def visit_AnnAssign(self, node):
- # extract attributes
- annotation = self.visit(node.annotation)
- target = self.visit(node.target)
- value = self.visit(node.value)
- # constexpr
- if annotation == constexpr:
- if target in self.lscope:
- raise ValueError(f'{target} is already defined.'
- f' constexpr cannot be reassigned.')
- value = constexpr(value)
- self.lscope[target] = value
- return self.lscope[target]
- # default: call visit_Assign
- return self.visit_Assign(node)
- def assignTarget(self, target, value):
- assert isinstance(target.ctx, ast.Store)
- if isinstance(target, ast.Subscript):
- return self.visit_Subscript_Store(target, value)
- if isinstance(target, ast.Tuple):
- for i, target in enumerate(target.elts):
- self.assignTarget(target, value.values[i])
- return
- if isinstance(target, ast.Attribute):
- raise NotImplementedError("Attribute assignment is not supported in triton")
- assert isinstance(target, ast.Name)
- self.set_value(self.visit(target), value)
- def visit_Assign(self, node):
- # construct values to assign
- def _sanitize_value(value):
- if isinstance(value, language.tuple):
- return _apply_to_tuple_values(value, _sanitize_value)
- native_nontensor_types = (language.dtype, language.tuple)
- value = _unwrap_if_constexpr(value)
- if value is not None and \
- not _is_triton_value(value) and \
- not isinstance(value, native_nontensor_types):
- value = self.semantic.to_tensor(value)
- return value
- targets = [node.target] if isinstance(node, ast.AnnAssign) else node.targets
- assert len(targets) == 1
- target = targets[0]
- if isinstance(target, ast.Name):
- with self._name_loc_prefix(target.id):
- values = _sanitize_value(self.visit(node.value))
- else:
- values = _sanitize_value(self.visit(node.value))
- self.assignTarget(target, values)
- def visit_AugAssign(self, node):
- lhs = copy.deepcopy(node.target)
- lhs.ctx = ast.Load()
- rhs = ast.BinOp(lhs, node.op, node.value)
- assign = ast.Assign(targets=[node.target], value=rhs)
- for x in ['lineno', 'col_offset', 'end_lineno', 'end_col_offset']:
- if hasattr(node, x):
- y = getattr(node, x)
- setattr(rhs, x, y)
- setattr(assign, x, y)
- self.visit(assign)
- return self.visit(lhs)
- def visit_Name(self, node):
- if type(node.ctx) is ast.Store:
- return node.id
- return self.dereference_name(node.id)
- def visit_Store(self, node):
- ast.NodeVisitor.generic_visit(self, node)
- def visit_Load(self, node):
- ast.NodeVisitor.generic_visit(self, node)
- def visit_Tuple(self, node):
- args = [self.visit(x) for x in node.elts]
- return language.tuple(args)
- def _apply_binary_method(self, node, method_name, lhs, rhs):
- # TODO: raise something meaningful if getattr fails below, esp for reverse method
- if _is_triton_tensor(lhs):
- return getattr(lhs, method_name)(rhs, _semantic=self.semantic)
- if _is_triton_tensor(rhs):
- reverse_method_name = re.sub(r"__(.*)__", r"__r\1__", method_name)
- return getattr(rhs, reverse_method_name)(lhs, _semantic=self.semantic)
- if not isinstance(lhs, (constexpr, language.tuple)) and isinstance(rhs, constexpr):
- lhs = constexpr(lhs)
- if isinstance(lhs, constexpr):
- fn = getattr(lhs, method_name)
- else:
- fn = self.get_Attribute(lhs, method_name)
- return self.call_Function(node, fn, [rhs], {})
- def visit_BinOp(self, node):
- lhs = self.visit(node.left)
- rhs = self.visit(node.right)
- method_name = self._method_name_for_bin_op.get(type(node.op))
- if method_name is None:
- raise self._unsupported(node,
- "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__))
- return self._apply_binary_method(node, method_name, lhs, rhs)
- _method_name_for_bin_op: Dict[Type[ast.operator], str] = {
- ast.Add: '__add__',
- ast.Sub: '__sub__',
- ast.Mult: '__mul__',
- ast.Div: '__truediv__',
- ast.FloorDiv: '__floordiv__',
- ast.Mod: '__mod__',
- ast.Pow: '__pow__',
- ast.LShift: '__lshift__',
- ast.RShift: '__rshift__',
- ast.BitAnd: '__and__',
- ast.BitOr: '__or__',
- ast.BitXor: '__xor__',
- }
    def visit_then_else_blocks(self, node, liveins, then_block, else_block):
        """Emit the ``then``/``else`` bodies of a conditional and compute the values to merge.

        Args:
            node: Presumably an ``ast.If``; its ``body`` and ``orelse`` are visited.
            liveins: Name -> value mapping live on entry to the conditional.
            then_block: Block the ``then`` body is emitted into.
            else_block: Block the ``else`` body is emitted into; only used when
                ``node.orelse`` is non-empty.

        Returns:
            ``(then_defs, else_defs, then_block, else_block, names)``:
              - ``then_block``/``else_block`` are the insertion blocks at the end of
                each branch.
              - ``names`` lists the variables to merge after the conditional: first
                the live-in Triton values changed in either branch, then (sorted)
                the names newly defined in both branches.
        """
        # then block
        self.builder.set_insertion_point_to_start(then_block)
        self.visit_compound_statement(node.body)
        # The body may have moved insertion into another block (e.g. the dead block after a return).
        then_block = self.builder.get_insertion_block()
        then_defs = self.local_defs.copy()
        then_vals = self.lscope.copy()
        # else block
        else_defs = {}
        else_vals = liveins.copy()
        if node.orelse:
            self.builder.set_insertion_point_to_start(else_block)
            # The else branch starts from the live-in state, not from the then branch's results.
            self.lscope = liveins.copy()
            self.local_defs = {}
            self.visit_compound_statement(node.orelse)
            else_defs = self.local_defs.copy()
            else_block = self.builder.get_insertion_block()
            else_vals = self.lscope.copy()

        # update block arguments
        names = []
        # variables in livein whose value is updated in `if`
        for name, value in liveins.items():
            # livein variable changed value in either then or else
            if not _is_triton_value(value):
                continue
            then_handles = flatten_values_to_ir([then_vals[name]])
            else_handles = flatten_values_to_ir([else_vals[name]])
            if then_handles == else_handles:
                continue
            names.append(name)
            then_defs[name] = then_vals[name]
            else_defs[name] = else_vals[name]
            # check type
            for defs, block_name in [(then_defs, 'then'), (else_defs, 'else')]:
                type_equal = type(defs[name]) == type(value)  # noqa: E721
                # NOTE(review): the message interpolates the values themselves, not their types.
                assert type_equal and defs[name].type == value.type, \
                    f'initial value for `{name}` is of type {value}, '\
                    f'but the {block_name} block redefines it as {defs[name]}'

        # variables that are both in then and else but not in liveins
        # TODO: could probably be cleaned up
        for name in sorted(then_defs.keys() & else_defs.keys()):
            if name in names:
                continue
            then_val = then_defs[name]
            then_ty = then_val.type
            else_val = else_defs[name]
            else_ty = else_val.type
            type_equal = type(then_val) == type(else_val)  # noqa: E721
            assert type_equal and then_ty == else_ty, \
                f'Mismatched type for {name} between then block ({then_ty}) '\
                f'and else block ({else_ty})'
            names.append(name)

        return then_defs, else_defs, then_block, else_block, names
- def visit_if_top_level(self, cond, node):  # lower an `if` with explicit blocks: cond_branch to then/else blocks, both branching to a new endif block whose arguments carry the merged variables (visit_If uses this when the branch contains `return`)
- with enter_sub_region(self) as sr:
- liveins, ip_block = sr
- then_block = self.builder.create_block()
- else_block = self.builder.create_block()
- # create branch
- self.builder.set_insertion_point_to_end(ip_block)
- self.builder.create_cond_branch(cond.handle, then_block, else_block)
- # visit then and else blocks
- then_defs, else_defs, then_block, else_block, names = \
- self.visit_then_else_blocks(node, liveins, then_block, else_block)
- # create basic-block after conditional
- endif_block = self.builder.create_block()
- # then terminator
- self.builder.set_insertion_point_to_end(then_block)
- assert not then_block.has_terminator(), f"{then_block}"
- then_handles = flatten_values_to_ir(then_defs[name] for name in names)
- self.builder.create_branch(endif_block, then_handles)
- # else terminator
- self.builder.set_insertion_point_to_end(else_block)
- assert not else_block.has_terminator(), f"{else_block}"
- else_handles = flatten_values_to_ir(else_defs[name] for name in names)
- self.builder.create_branch(endif_block, else_handles)
- assert len(then_handles) == len(else_handles)
- for then_h, else_h in zip(then_handles, else_handles):  # one endif block argument per flattened handle; both arms must agree on its IR type
- ty = then_h.get_type()
- assert ty == else_h.get_type()
- endif_block.add_argument(ty)
- # change block
- self.builder.set_insertion_point_to_start(endif_block)
- # update value
- res_handles = [endif_block.arg(i) for i in range(len(then_handles))]
- types = [then_defs[name].type for name in names]
- new_values = unflatten_ir_values(res_handles, types)  # rebuild frontend values from the endif block arguments using the then-arm types
- for name, new_value in zip(names, new_values):
- self.set_value(name, new_value)
- # TODO: refactor
- def visit_if_scf(self, cond, node):  # lower an `if` without `return` into an scf.if op; variables modified in either arm become the op's results
- with enter_sub_region(self) as sr:
- liveins, _ = sr
- ip, last_loc = self._get_insertion_point_and_loc()
- then_block = self.builder.create_block()
- else_block = self.builder.create_block() if node.orelse else None  # no else arm -> nothing to visit; the scf.if's own else block is used below
- then_defs, else_defs, then_block, else_block, names = \
- self.visit_then_else_blocks(node, liveins, then_block, else_block)
- # create if op
- then_handles = flatten_values_to_ir(then_defs[name] for name in names)
- for name, val in zip(names, then_handles):  # NOTE(review): pairs names with *flattened* handles; if a merged value flattens to several handles (e.g. a tuple) the pairing drifts
- self._maybe_set_loc_to_name(val, name)
- self._set_insertion_point_and_loc(ip, last_loc)  # the scf.if is created at the insertion point saved before the arms were visited
- if_op = self.builder.create_if_op([h.get_type() for h in then_handles], cond.handle, True)  # result types = IR types of the merged handles; NOTE(review): third argument presumably requests an else region — confirm
- then_block.merge_block_before(if_op.get_then_block())  # presumably moves the visited then-arm ops into the scf.if then region — confirm
- self.builder.set_insertion_point_to_end(if_op.get_then_block())
- if len(names) > 0:  # only emit a yield when the op has results
- self.builder.create_yield_op(then_handles)
- if not node.orelse:
- else_block = if_op.get_else_block()  # NOTE(review): this value is not read afterwards
- else:
- else_block.merge_block_before(if_op.get_else_block())
- self.builder.set_insertion_point_to_end(if_op.get_else_block())
- if len(names) > 0:
- else_handles = flatten_values_to_ir(else_defs[name] for name in names)
- for name, val in zip(names, else_handles):
- self._maybe_set_loc_to_name(val, name)
- self.builder.create_yield_op(else_handles)
- # update values
- res_handles = [if_op.get_result(i) for i in range(len(then_handles))]
- types = [then_defs[name].type for name in names]
- new_values = unflatten_ir_values(res_handles, types)
- for name, new_value in zip(names, new_values):
- self.set_value(name, new_value)
- def visit_If(self, node):  # lower `if`: tensor conditions become IR control flow; constexpr conditions are resolved now and only the taken arm is visited
- cond = self.visit(node.test)
- if _is_triton_tensor(cond):
- if _is_non_scalar_tensor(cond):
- raise self._unsupported(node, "Boolean value of Tensor with more than one value is ambiguous")
- if cond.type.is_block():
- warnings.warn(
- "If conditional called with multidimensional Tensor instead of scalar; please use \"if (%s).item()\" instead"
- % ast.unparse(node.test))
- cond = language.core._unsplat(cond, _semantic=self.semantic, _generator=self)  # NOTE(review): presumably turns a single-element block tensor into a scalar — confirm
- cond = cond.to(language.int1, _semantic=self.semantic)  # the condition is lowered as int1
- if ContainsReturnChecker(self.gscope).visit(node):  # arms containing `return` are lowered with explicit blocks instead of scf.if
- if self.scf_stack:  # not supported inside a loop body (scf_stack is pushed by visit_While / visit_For)
- raise self._unsupported(
- node, "Cannot have `return` statements inside `while` or `for` statements in triton.")
- self.visit_if_top_level(cond, node)
- else:
- self.visit_if_scf(cond, node)
- else:
- cond = _unwrap_if_constexpr(cond)
- # not isinstance - we insist the real thing, no subclasses and no ducks
- if type(cond) not in _condition_types:
- raise self._unsupported(
- node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
- ', '.join(_.__name__ for _ in _condition_types),
- type(cond).__name__))
- active_block = node.body if cond else node.orelse  # compile only the selected arm
- self.visit_compound_statement(active_block)
- def visit_IfExp(self, node):  # lower `a if cond else b`: a tensor condition becomes an scf.if yielding at most one value; a constexpr condition visits only the selected branch
- cond = self.visit(node.test)
- if _is_triton_tensor(cond):
- cond = cond.to(language.int1, _semantic=self.semantic)
- # TODO: Deal w/ more complicated return types (e.g tuple)
- with enter_sub_region(self):
- ip, last_loc = self._get_insertion_point_and_loc()
- then_block = self.builder.create_block()
- self.builder.set_insertion_point_to_start(then_block)
- then_val = self.semantic.to_tensor(self.visit(node.body))  # both branch results are converted to tensors
- then_block = self.builder.get_insertion_block()
- else_block = self.builder.create_block()
- self.builder.set_insertion_point_to_start(else_block)
- # do not need to reset lscope since
- # ternary expressions cannot define new variables
- else_val = self.semantic.to_tensor(self.visit(node.orelse))
- else_block = self.builder.get_insertion_block()
- self._set_insertion_point_and_loc(ip, last_loc)
- assert then_val.type == else_val.type, \
- f'Ternary expression with dynamic condition has inconsistent types {then_val.type} and {else_val.type}'
- ret_type = then_val.type
- ret_type_ir = [ret_type.to_ir(self.builder)] if ret_type != language.void else []  # a void result gives an scf.if with no results
- if_op = self.builder.create_if_op(ret_type_ir, cond.handle, True)
- then_block.merge_block_before(if_op.get_then_block())
- if ret_type_ir:
- self.builder.set_insertion_point_to_end(if_op.get_then_block())
- self.builder.create_yield_op([then_val.handle])
- self.builder.set_insertion_point_to_end(if_op.get_then_block())  # NOTE(review): repeats the call above when ret_type_ir is non-empty; looks redundant
- else_block.merge_block_before(if_op.get_else_block())
- if ret_type_ir:
- self.builder.set_insertion_point_to_end(if_op.get_else_block())
- self.builder.create_yield_op([else_val.handle])
- return language.core.tensor(if_op.get_result(0), ret_type) if ret_type_ir else None
- else:
- cond = _unwrap_if_constexpr(cond)
- # not isinstance - we insist the real thing, no subclasses and no ducks
- if type(cond) not in _condition_types:
- raise self._unsupported(
- node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
- ', '.join(_.__name__ for _ in _condition_types),
- type(cond).__name__))
- if cond:  # only the selected branch is visited (and thus evaluated)
- return self.visit(node.body)
- else:
- return self.visit(node.orelse)
- def visit_With(self, node):
- # Lower `with` statements by constructing context managers and calling their enter/exit hooks
- # Instantiate each context manager with builder injection
- cm_list = []
- for item in node.items:
- call = item.context_expr  # NOTE(review): assumes each context expression is a call `f(...)`; other forms fail at `.func`
- fn = self.visit(call.func)
- args = [self.visit(arg) for arg in call.args]
- kws = dict(self.visit(kw) for kw in call.keywords)
- cm = fn(*args, _semantic=self.semantic, **kws)  # the factory is called directly in Python with the semantic object injected
- cm_list.append(cm)
- for cm, item in zip(cm_list, node.items):  # every manager is constructed before any __enter__ runs
- res = cm.__enter__()
- if item.optional_vars is not None:
- var_name = self.visit(item.optional_vars)
- self.set_value(var_name, res)
- if ContainsReturnChecker(self.gscope).visit(node):  # NOTE(review): checked only after the __enter__ hooks above have already run
- raise self._unsupported(node, "Cannot have `return` statements inside `with` statements in triton ")
- self.visit_compound_statement(node.body)
- for cm in reversed(cm_list):  # exit in reverse order of entry; __exit__ is not reached if visiting the body raises
- cm.__exit__(None, None, None)
def visit_Pass(self, node):
    """Lower ``pass``: nothing is emitted, and the visit yields ``None``."""
    return None
def visit_Compare(self, node):
    """Lower a single comparison ``lhs <op> rhs``.

    ``is`` / ``is not`` are folded at compile time on the unwrapped operand
    values and produce a ``constexpr`` bool. Every other operator is dispatched
    to the dunder method listed in ``_method_name_for_comp_op``.

    Chained comparisons (``a < b < c``) are rejected. So are operators with no
    registered method (e.g. ``in``), via the error built by ``self._unsupported``.
    """
    if len(node.comparators) != 1 or len(node.ops) != 1:
        raise self._unsupported(node, "simultaneous multiple comparison is not supported")
    lhs = self.visit(node.left)
    rhs = self.visit(node.comparators[0])
    lhs_value = _unwrap_if_constexpr(lhs)
    rhs_value = _unwrap_if_constexpr(rhs)
    op_type = type(node.ops[0])
    if op_type is ast.Is:
        return constexpr(lhs_value is rhs_value)
    if op_type is ast.IsNot:
        return constexpr(lhs_value is not rhs_value)
    method_name = self._method_name_for_comp_op.get(op_type)
    if method_name is None:
        # `node.ops[0]` is an operator *instance*; only its class exposes `__name__`.
        raise self._unsupported(
            node, "AST comparison operator '{}' is not (currently) implemented.".format(op_type.__name__))
    return self._apply_binary_method(node, method_name, lhs, rhs)
- _method_name_for_comp_op: Dict[Type[ast.cmpop], str] = {  # comparison-operator class -> dunder method; `is` / `is not` are folded directly in visit_Compare
- ast.Eq: '__eq__', ast.NotEq: '__ne__', ast.Lt: '__lt__', ast.LtE: '__le__', ast.Gt: '__gt__', ast.GtE: '__ge__'
- }
def visit_UnaryOp(self, node):
    """Lower a unary expression: ``-x``, ``+x``, ``not x`` or ``~x``.

    Tensor operands call the matching dunder method with ``_semantic``. Other
    operands (e.g. ``constexpr``) call it with no arguments. When such an
    operand has no such method, ``not`` falls back to Python truthiness and
    any other operator is reported as unsupported.
    """
    operand = self.visit(node.operand)
    op_type = type(node.op)
    fn = self._method_name_for_unary_op.get(op_type)
    if fn is None:
        # `node.op` is an operator *instance*; only its class exposes `__name__`.
        raise self._unsupported(node, f"AST unary operator '{op_type.__name__}' is not (currently) implemented.")
    if _is_triton_tensor(operand):
        return getattr(operand, fn)(_semantic=self.semantic)
    # Look the method up separately from calling it, so an AttributeError raised
    # *inside* the method is not mistaken for "operand has no such method".
    try:
        method = getattr(operand, fn)
    except AttributeError:
        if fn == "__not__":
            return constexpr(not operand)
        raise self._unsupported(
            node, f"AST unary operator '{fn}' is not (currently) implemented on type {type(operand).__name__}")
    return method()
- _method_name_for_unary_op: Dict[Type[ast.unaryop], str] = {  # unary-operator class -> dunder method looked up by visit_UnaryOp (`not` falls back to Python truthiness there)
- ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__'
- }
def _verify_loop_carried_variable(self, name, loop_val, live_val):
    """Check that a loop-carried variable keeps a consistent type.

    :param name: source-level variable name, used in the error messages.
    :param loop_val: the value the variable is re-assigned to inside the loop.
    :param live_val: the variable's initial value on entry to the loop.
    :raises AssertionError: if either value is not a Triton value (e.g. a
        constexpr), if the Python types differ, or, for tensors, if the
        Triton types differ.
    """
    assert _is_triton_value(loop_val), f'cannot reassign constexpr {name} in the loop'
    assert _is_triton_value(live_val), f'cannot reassign constexpr {name} in the loop'
    # `live_val` is the initial value and `loop_val` the in-loop one, so "was" must
    # report `live_val` and "now" `loop_val` (this used to be reversed).
    assert type(loop_val) is type(live_val), (
        f'Loop carried variable {name} changed type, was {type(live_val)} but is now {type(loop_val)}')
    assert not _is_triton_tensor(loop_val) or loop_val.type == live_val.type, \
        f'Loop-carried variable {name} has initial type {live_val.type} '\
        f'but is re-assigned to {loop_val.type} in loop! '\
        f'Please make sure that the type stays consistent.'
- def visit_While(self, node):  # lower `while` into an scf.while op: carried variables are op operands, the test is visited in the "before" region and the body in the "after" region
- with enter_sub_region(self) as sr:
- liveins, insert_block = sr
- ip, last_loc = self._get_insertion_point_and_loc()
- names, init_handles, init_fe_tys = self._find_carries(node, liveins)  # carried variable names, their initial IR handles, and their frontend types
- init_tys = [h.get_type() for h in init_handles]
- self._set_insertion_point_and_loc(ip, last_loc)
- while_op = self.builder.create_while_op(init_tys, init_handles)
- # merge the condition region
- before_block = self.builder.create_block_with_parent(while_op.get_before(), init_tys)
- self.builder.set_insertion_point_to_start(before_block)
- block_args = [before_block.arg(i) for i in range(len(init_handles))]
- condition_args = unflatten_ir_values(block_args, init_fe_tys)
- for name, val in zip(names, condition_args):  # inside the condition region the carried names refer to the block arguments
- self.lscope[name] = val
- self.local_defs[name] = val
- self._maybe_set_loc_to_name(val, name)
- cond = self.visit(node.test)
- if isinstance(cond, language.condition):  # a language.condition wrapper may ask to disable LICM for this loop
- if cond.disable_licm:
- while_op.set_attr("llvm.loop_annotation", self.builder.get_disable_loop_licm_attr())
- cond = cond.condition
- self.builder.set_insertion_point_to_end(before_block)
- # create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ...
- self.builder.create_condition_op(cond.handle, block_args)
- # merge the loop body
- after_block = self.builder.create_block_with_parent(while_op.get_after(), init_tys)
- # generate loop body
- self.builder.set_insertion_point_to_start(after_block)
- body_handles = [after_block.arg(i) for i in range(len(init_handles))]
- body_args = unflatten_ir_values(body_handles, init_fe_tys)
- for name, val in zip(names, body_args):
- self.lscope[name] = val
- self.local_defs[name] = val
- self._maybe_set_loc_to_name(val, name)
- self.scf_stack.append(node)  # lets visit_If reject `return` inside the loop body
- self.visit_compound_statement(node.body)
- self.scf_stack.pop()
- yield_handles = flatten_values_to_ir(self.lscope[name] for name in names)  # values of the carried variables at the end of the body
- self.builder.create_yield_op(yield_handles)
- # WhileOp defines new values, update the symbol table (lscope, local_defs)
- result_handles = [while_op.get_result(i) for i in range(len(init_handles))]
- result_vals = unflatten_ir_values(result_handles, init_fe_tys)
- for name, new_def in zip(names, result_vals):
- self.lscope[name] = new_def
- self.local_defs[name] = new_def
- self._maybe_set_loc_to_name(new_def, name)
- for stmt in node.orelse:  # `while ... else` is not supported
- assert False, "Not implemented"
- ast.NodeVisitor.generic_visit(self, stmt)
def visit_Subscript_Load(self, node):
    """Lower a subscript read ``value[slice]``.

    Triton values dispatch to their own ``__getitem__`` via ``call_Method``;
    any other object is indexed directly in Python.
    """
    assert isinstance(node.ctx, ast.Load)
    target = self.visit(node.value)
    index = self.visit(node.slice)
    if not _is_triton_value(target):
        return target[index]
    return self.call_Method(node, target.__getitem__, target, [index], {})
def visit_Subscript_Store(self, node, value):
    """Reject subscript assignment (``x[i] = v``), which Triton does not support."""
    message = "__setitem__ is not supported in triton"
    raise NotImplementedError(message)
def visit_Subscript(self, node):
    """Lower a subscript expression; only the read (load) form is handled here."""
    load = self.visit_Subscript_Load
    return load(node)
def visit_ExtSlice(self, node):
    """Lower an ``ast.ExtSlice`` into a list holding one visited value per entry of ``node.dims``."""
    return list(map(self.visit, node.dims))
- def visit_For(self, node):  # lower `for`: static_range loops are unrolled at compile time; builtin range / language.range loops become an scf.for op
- IteratorClass = self.visit(node.iter.func)  # assumes the iterable is a call such as range(...)
- iter_args = [self.visit(arg) for arg in node.iter.args]
- iter_kwargs = dict(self.visit(keyword) for keyword in node.iter.keywords)
- if IteratorClass == language.static_range:  # compile-time unrolling: the body is visited once per value with a constexpr induction variable
- iterator = IteratorClass(*iter_args, **iter_kwargs)
- static_range = range(iterator.start.value, iterator.end.value, iterator.step.value)
- for i in static_range:
- self.lscope[node.target.id] = constexpr(i)
- self.visit_compound_statement(node.body)
- for stmt in node.orelse:
- ast.NodeVisitor.generic_visit(self, stmt)
- return
- num_stages = None  # loop annotations; only populated from a language.range iterator
- loop_unroll_factor = None
- disallow_acc_multi_buffer = False
- flatten = False
- warp_specialize = False
- disable_licm = False
- if IteratorClass is language.range:
- iterator = IteratorClass(*iter_args, **iter_kwargs)
- # visit iterator arguments
- # note: only `range` iterator is supported now
- # collect lower bound (lb), upper bound (ub), and step
- lb = iterator.start
- ub = iterator.end
- step = iterator.step
- num_stages = iterator.num_stages
- loop_unroll_factor = iterator.loop_unroll_factor
- disallow_acc_multi_buffer = iterator.disallow_acc_multi_buffer
- flatten = iterator.flatten
- warp_specialize = iterator.warp_specialize
- disable_licm = iterator.disable_licm
- elif IteratorClass is range:  # Python builtin range
- # visit iterator arguments
- # note: only `range` iterator is supported now
- # collect lower bound (lb), upper bound (ub), and step
- lb = iter_args[0] if len(iter_args) > 1 else self.visit(ast.Constant(0))  # range(ub) starts at 0
- ub = iter_args[1] if len(iter_args) > 1 else self.visit(node.iter.args[0])  # NOTE(review): re-visits node.iter.args[0] although iter_args[0] already holds its value, so that expression is evaluated twice
- step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Constant(1))
- else:
- raise RuntimeError('Only `range` and `static_range` iterators are currently supported')
- # handle negative constant step (not supported by scf.for in MLIR)
- negative_step = False
- if _is_constexpr(step) and step.value < 0:  # only constexpr steps are flipped; a negative runtime step is not handled here
- step = constexpr(-step.value)
- negative_step = True
- lb, ub = ub, lb  # iterate the mirrored range with a positive step; the induction variable is mapped back below
- lb = self.semantic.to_tensor(lb)
- ub = self.semantic.to_tensor(ub)
- step = self.semantic.to_tensor(step)
- # induction variable type
- if not lb.dtype.is_int() or not ub.dtype.is_int() or not step.dtype.is_int():
- raise TypeError(f"For loop bounds and step must all be ints, are ({lb.dtype}, {ub.dtype}, {step.dtype})")
- if _is_non_scalar_tensor(lb):
- raise TypeError(f"For lower bound must be a scalar, got {lb.type}")
- if _is_non_scalar_tensor(ub):
- raise TypeError(f"For upper bound must be a scalar, got {ub.type}")
- if _is_non_scalar_tensor(step):
- raise TypeError(f"For step must be a scalar, got {step.type}")
- iv_type = self.semantic.integer_promote_impl(lb.dtype, ub.dtype)  # promote lb, ub and step to one common integer type
- iv_type = self.semantic.integer_promote_impl(iv_type, step.dtype)
- iv_ir_type = iv_type.to_ir(self.builder)
- iv_is_signed = iv_type.int_signedness == language.core.dtype.SIGNEDNESS.SIGNED
- # lb/ub/step were converted to tensors above; take their IR handles
- lb = lb.handle
- ub = ub.handle
- step = step.handle
- # Cast lb/ub/step to the common induction-variable integer type (iv_ir_type)
- lb = self.builder.create_int_cast(lb, iv_ir_type, iv_is_signed)
- ub = self.builder.create_int_cast(ub, iv_ir_type, iv_is_signed)
- step = self.builder.create_int_cast(step, iv_ir_type, iv_is_signed)
- # Create placeholder for the loop induction variable
- iv_placeholder = self.builder.create_poison(iv_ir_type)  # stands in for the induction variable while the body is visited; replaced via replace_all_uses_with below
- self.set_value(node.target.id, language.core.tensor(iv_placeholder, iv_type))
- with enter_sub_region(self) as sr:
- liveins, insert_block = sr
- ip, last_loc = self._get_insertion_point_and_loc()
- names, init_handles, init_tys = self._find_carries(node, liveins, ignore={node.target.id})  # the induction variable itself is not carried
- # create ForOp
- self._set_insertion_point_and_loc(ip, last_loc)
- for_op = self.builder.create_for_op(lb, ub, step, init_handles)
- if _unwrap_if_constexpr(num_stages) is not None:  # NOTE(review): the check unwraps a constexpr, but the raw value is what gets passed on
- for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages))
- if _unwrap_if_constexpr(loop_unroll_factor) is not None:
- for_op.set_attr("tt.loop_unroll_factor", self.builder.get_int32_attr(loop_unroll_factor))
- if disallow_acc_multi_buffer:
- for_op.set_attr("tt.disallow_acc_multi_buffer", self.builder.get_unit_attr())
- if flatten:
- for_op.set_attr("tt.flatten", self.builder.get_unit_attr())
- if warp_specialize:
- for_op.set_attr("tt.warp_specialize", self.builder.get_unit_attr())
- if disable_licm:
- for_op.set_attr("llvm.loop_annotation", self.builder.get_disable_loop_licm_attr())
- self.scf_stack.append(node)  # lets visit_If reject `return` inside the loop body
- for_op_body = for_op.get_body(0)
- self.builder.set_insertion_point_to_start(for_op_body)
- block_handles = [for_op_body.arg(i + 1) for i in range(len(init_handles))]  # carried values start at block arg 1 (arg 0 presumably being the induction variable)
- block_args = unflatten_ir_values(block_handles, init_tys)
- for name, val in zip(names, block_args):
- self._maybe_set_loc_to_name(val, name)
- self.set_value(name, val)
- self.visit_compound_statement(node.body)
- self.scf_stack.pop()
- yield_handles = flatten_values_to_ir(self.lscope[name] for name in names)
- # create YieldOp
- if len(yield_handles) > 0:
- self.builder.create_yield_op(yield_handles)
- for_op_region = for_op_body.get_parent()
- assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block"
- # update induction variable with actual value, and replace all uses
- self.builder.set_insertion_point_to_start(for_op_body)
- iv = for_op.get_induction_var()
- if negative_step:  # undo the mirroring: iv' = ub - iv + lb (lb/ub were swapped above)
- iv = self.builder.create_sub(ub, iv)
- iv = self.builder.create_add(iv, lb)
- iv_placeholder.replace_all_uses_with(iv)
- self.set_value(node.target.id, language.core.tensor(iv, iv_type))
- self._maybe_set_loc_to_name(iv, node.target.id)
- # update lscope & local_defs (ForOp defines new values)
- result_handles = [for_op.get_result(i) for i in range(len(init_handles))]
- result_values = unflatten_ir_values(result_handles, init_tys)
- for name, val in zip(names, result_values):
- self.set_value(name, val)
- self._maybe_set_loc_to_name(val, name)
- for stmt in node.orelse:  # `for ... else` is not supported
- assert False, "Don't know what to do with else after for"
- ast.NodeVisitor.generic_visit(self, stmt)
def visit_Slice(self, node):
    """Lower ``lower:upper:step`` into a ``language.slice``; omitted parts visit to ``None``."""
    bounds = [self.visit(part) for part in (node.lower, node.upper, node.step)]
    return language.slice(*bounds)
def visit_Index(self, node):
    """Unwrap a legacy ``ast.Index`` node by visiting the expression it wraps."""
    inner = node.value
    return self.visit(inner)
def visit_keyword(self, node) -> Tuple[str, Any]:
    """Lower a call keyword argument ``name=value`` into a ``(name, value)`` pair."""
    value = self.visit(node.value)
    return (node.arg, value)
def visit_Assert(self, node) -> Any:
    """Lower ``assert test[, msg]`` into a device-side assertion (empty message when omitted)."""
    test = self.visit(node.test)
    if node.msg is None:
        msg = ""
    else:
        msg = self.visit(node.msg)
    return language.core.device_assert(test, msg, _semantic=self.semantic)
- def call_JitFunction(self, fn: JITFunction, args, kwargs, caller_context=None):  # call another @jit function: generate a callee in this module on first use of its mangled name, then emit a call op
- args = inspect.getcallargs(fn.fn, *args, **kwargs)  # bind positional/keyword args to parameter names (defaults applied)
- args = [args[name] for name in fn.arg_names]
- for i, arg in enumerate(args):
- if isinstance(arg, (language.dtype, float, int, bool, JITFunction)):  # plain Python scalars, dtypes and jit functions are passed as compile-time constants
- args[i] = language.core.constexpr(arg)
- args_cst = find_paths_if(args, lambda _, x: _is_constexpr(x))  # index paths of constexpr leaves (baked into the callee)
- args_cst = {path: get_iterable_path(args, path) for path in args_cst}
- args_path = find_paths_if(args, lambda _, x: not _is_constexpr(x))  # index paths of runtime values (passed to the call op)
- args_val = [get_iterable_path(args, path) for path in args_path]
- # mangle
- caller_context = caller_context or self.caller_context
- fn_name = mangle_fn(get_full_name(fn), [arg.type for arg in args_val], args_cst, caller_context)  # mangled from the callee name, runtime arg types, constexpr values and caller context
- # generate function def if necessary
- if not self.module.has_function(fn_name):
- # If the callee is not set, we use the same debug setting as the caller
- file_name, begin_line = get_jit_fn_file_line(fn)
- arg_types = [  # None / bool / int / dtype arguments are typed as constexpr in the callee prototype
- language.core.constexpr if arg is None or isinstance(arg,
- (bool, int, language.core.dtype)) else arg.type
- for arg in args
- ]
- prototype = ASTFunction([], arg_types, args_cst, dict())
- generator = CodeGenerator(self.context, prototype, fn.get_capture_scope(), module=self.module, jit_fn=fn,
- function_name=fn_name, function_types=self.function_ret_types,
- noinline=fn.noinline, file_name=file_name, begin_line=begin_line,
- options=self.builder.options, codegen_fns=self.builder.codegen_fns,
- module_map=self.builder.module_map, caller_context=caller_context,
- is_gluon=self.is_gluon)
- try:
- generator.visit(fn.parse())
- except Exception as e:
- # Wrap the error in the callee with the location of the call.
- if knobs.compilation.front_end_debugging:
- raise
- raise CompilationError(self.jit_fn.src, self.cur_node, None) from e
- callee_ret_type = generator.ret_type
- self.function_ret_types[fn_name] = callee_ret_type  # recorded so later calls to the same mangled name reuse it (else branch)
- else:
- callee_ret_type = self.function_ret_types[fn_name]
- symbol = self.module.get_function(fn_name)
- args_val = flatten_values_to_ir(args_val)
- call_op = self.builder.call(symbol, args_val)
- if callee_ret_type == language.void:
- return None
- handles = [call_op.get_result(i) for i in range(call_op.get_num_results())]
- return next(unflatten_ir_values(handles, [callee_ret_type]))  # one (possibly tuple) value rebuilt from all call results
- def call_Function(self, node, fn, args, kws):  # dispatch a call: jit functions are compiled and called; Triton builtins, methods of Triton values and ConstexprFunctions get semantic hooks; anything else runs as plain Python with results wrapped in constexpr
- if isinstance(fn, (BoundJITMethod, BoundConstexprFunction)):  # bound methods: pass the receiver explicitly as the first argument (mutates `args`)
- args.insert(0, fn.__self__)
- fn = fn.__func__
- if isinstance(fn, JITFunction):
- _check_fn_args(node, fn, args)
- return self.call_JitFunction(fn, args, kws)
- if (hasattr(fn, '__self__') and _is_triton_value(fn.__self__)) or language.core.is_builtin(fn) or isinstance(
- fn, ConstexprFunction):
- extra_kwargs = dict()
- if isinstance(fn, ConstexprFunction):
- sig = inspect.signature(fn.__call__)
- else:
- sig = inspect.signature(fn)
- if '_semantic' in sig.parameters:  # callees opt in to `_semantic` / `_generator` by declaring those parameters
- extra_kwargs["_semantic"] = self.semantic
- if '_generator' in sig.parameters:
- extra_kwargs['_generator'] = self
- try:
- ret = fn(*args, **extra_kwargs, **kws)
- # builtin functions return plain tuples for readability
- if isinstance(ret, tuple):
- ret = language.tuple(ret)
- return ret
- except Exception as e:
- if knobs.compilation.front_end_debugging:
- raise
- # Normally when we raise a CompilationError, we raise it as
- # `from None`, because the original fileline from the exception
- # is not relevant (and often points into code_generator.py
- # itself). But when calling a function, we raise as `from e` to
- # preserve the traceback of the original error, which may e.g.
- # be in core.py.
- raise CompilationError(self.jit_fn.src, node, str(e)) from e
- if fn in self.builtin_namespace.values() or (hasattr(fn, '__self__') and not _is_triton_value(fn.__self__)):  # Python builtins and methods of non-Triton objects receive unwrapped positional args (keywords are passed as-is)
- args = map(_unwrap_if_constexpr, args)
- ret = fn(*args, **kws)
- def wrap_constexpr(x):  # results that are not Triton values are wrapped back into constexpr
- if _is_triton_value(x):
- return x
- return constexpr(x)
- if isinstance(ret, (builtins.tuple, language.tuple)):
- return _apply_to_tuple_values(ret, wrap_constexpr)
- return wrap_constexpr(ret)
def call_Method(self, node, fn, fn_self, args, kws):
    """Call ``fn`` as a method of ``fn_self``.

    JIT functions are not bound, so the receiver is prepended to ``args``
    (mutating the list in place) before delegating to ``call_Function``.
    Other callables are assumed to already carry their receiver.
    """
    needs_receiver = isinstance(fn, JITFunction)
    if needs_receiver:
        args.insert(0, fn_self)
    return self.call_Function(node, fn, args, kws)
- def visit_Call(self, node):  # lower a call expression
- fn = _unwrap_if_constexpr(self.visit(node.func))
- if not isinstance(fn, BoundJITMethod):  # bound JIT methods skip the static-implementation lookup
- static_implementation = self.statically_implemented_functions.get(fn)
- if static_implementation is not None:  # compile-time helpers (static_assert, static_print, int, len, ...) handle the raw call AST themselves
- return static_implementation(self, node)
- mur = getattr(fn, '_must_use_result', False)
- if mur and getattr(node, '_is_unused', False):  # `_is_unused` is set by visit_Expr on bare expression statements
- error_message = ["The result of %s is not being used." % ast.unparse(node.func)]
- if isinstance(mur, str):  # a string flag is appended to the error as an extra hint
- error_message.append(mur)
- raise CompilationError(self.jit_fn.src, node, " ".join(error_message))
- kws = dict(self.visit(keyword) for keyword in node.keywords)
- args = []
- for arg in node.args:
- if isinstance(arg, ast.Starred):  # `*xs` must evaluate to a Triton tuple whose elements are spliced into the arguments
- arg = self.visit(arg.value)
- assert isinstance(arg, language.core.tuple)
- args.extend(arg.values)
- else:
- args.append(self.visit(arg))
- return self.call_Function(node, fn, args, kws)
def visit_Constant(self, node):
    """Wrap a Python literal (number, string, bool, None, ...) as a compile-time ``constexpr``."""
    literal = node.value
    return constexpr(literal)
def visit_BoolOp(self, node: ast.BoolOp):
    """Lower ``a and b ...`` / ``a or b ...``.

    Operands are visited left to right. A non-tensor (constexpr) operand is
    resolved immediately:

    - a falsy operand short-circuits ``and`` and is returned;
    - a truthy operand short-circuits ``or`` and is returned;
    - otherwise the operand cannot affect the result and is dropped.

    The remaining tensor operands are combined pairwise with ``logical_and`` /
    ``logical_or``, starting from the right. If every operand was dropped, the
    last operand is returned. Block (non-scalar) tensor operands trigger a
    deprecation warning that points users at ``&`` / ``|``.
    """
    op_type = type(node.op)
    method_name = self._method_name_for_bool_op.get(op_type)
    if method_name is None:
        # `node.op` is an operator *instance*; only its class exposes `__name__`.
        raise self._unsupported(
            node, "AST boolean operator '{}' is not (currently) implemented.".format(op_type.__name__))
    nontrivial_values = []
    for subnode in node.values:
        # Visit in order, executing side effects and possibly exiting early.
        value = self.visit(subnode)
        if not _is_triton_tensor(value):
            # A constexpr: it may short-circuit; otherwise it has no effect on the result.
            bv = bool(value)
            if (bv is False) and (method_name == "logical_and"):
                return value
            if (bv is True) and (method_name == "logical_or"):
                return value
        else:
            if value.type.is_block():
                lineno = getattr(node, "lineno", None)
                if lineno is not None:
                    lineno += self.begin_line
                warnings.warn_explicit(
                    "Logical operators 'and' and 'or' are deprecated for non-scalar tensors; please use '&' or '|' instead",
                    category=UserWarning,
                    filename=self.file_name,
                    lineno=lineno,
                    source=ast.unparse(node),
                )
            nontrivial_values.append(value)
    if len(nontrivial_values) == 0:
        # A conjunction of truthy / disjunction of falsy values evaluates to the final value.
        nontrivial_values.append(value)
    while len(nontrivial_values) >= 2:
        rhs = nontrivial_values.pop()
        lhs = nontrivial_values.pop()
        res = self._apply_binary_method(node, method_name, lhs, rhs)
        nontrivial_values.append(res)
    assert len(nontrivial_values) == 1
    return nontrivial_values[0]
- _method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'}  # `and` / `or` -> binary method applied by visit_BoolOp to the non-constexpr operands
def get_Attribute(self, lhs, attr):
    """Resolve ``lhs.attr`` for the code generator.

    ``tensor.T`` becomes a 2-D transpose (``permute`` with ``(1, 0)``).
    ``constexpr`` receivers are unwrapped first, except for ``value`` and
    ``type``, which are read from the wrapper itself for backward
    compatibility. A ``JITFunction`` fetched from a Triton value is returned
    bound to it as a ``BoundJITMethod``.
    """
    if _is_triton_tensor(lhs) and attr == "T":
        return self.semantic.permute(lhs, (1, 0))
    keep_wrapper = attr in ("value", "type")
    receiver = lhs.value if isinstance(lhs, constexpr) and not keep_wrapper else lhs
    member = getattr(receiver, attr)
    if _is_triton_value(receiver) and isinstance(member, JITFunction):
        return BoundJITMethod(receiver, member)
    return member
def visit_Attribute(self, node):
    """Lower ``value.attr``.

    When ``value`` is a module, ``self.builder.module_map`` is followed by
    module name until the name is no longer mapped or maps to a module with
    the same name (a fixed point). The attribute is then resolved with
    ``get_Attribute``.
    """
    base = self.visit(node.value)
    if isinstance(base, ModuleType):
        while True:
            key = base.__name__
            if key not in self.builder.module_map:
                break
            base = self.builder.module_map[key]
            if base.__name__ == key:
                break
    return self.get_Attribute(base, node.attr)
def visit_Expr(self, node):
    """Visit an expression statement whose value is discarded.

    The value node is tagged ``_is_unused`` before its children are visited,
    so ``visit_Call`` can reject ignored results of ``_must_use_result`` functions.
    """
    discarded = node.value
    discarded._is_unused = True
    ast.NodeVisitor.generic_visit(self, node)
def visit_NoneType(self, node):
    """Visiting a missing (``None``) node yields ``None``."""
    result = None
    return result
def visit_JoinedStr(self, node):
    """Evaluate an f-string at compile time and return the resulting ``str``.

    Literal pieces are copied verbatim. Each replacement field must evaluate to
    a ``constexpr``. Its ``!s`` / ``!r`` / ``!a`` conversion and its format
    spec (e.g. ``:.2f``, possibly containing nested fields) are applied as
    Python would; previously the format spec was silently dropped.

    :raises: the error built by ``self._unsupported`` if a field is not a constexpr.
    :raises AssertionError: on an unexpected child node type.
    """
    pieces = []
    for part in node.values:
        if isinstance(part, ast.Constant):
            pieces.append(str(part.value))
        elif isinstance(part, ast.FormattedValue):
            evaluated = self.visit(part.value)
            if not _is_constexpr(evaluated):
                raise self._unsupported(
                    node,
                    "Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type "
                    + str(type(evaluated)))
            rendered = evaluated.value
            # `conversion` is -1 when absent, otherwise ord('s'), ord('r') or ord('a').
            if part.conversion == ord("s"):
                rendered = str(rendered)
            elif part.conversion == ord("r"):
                rendered = repr(rendered)
            elif part.conversion == ord("a"):
                rendered = ascii(rendered)
            # `format_spec` is None or a nested JoinedStr; visiting it yields the spec string.
            spec = "" if part.format_spec is None else self.visit(part.format_spec)
            # format(x, "") is exactly what "{}".format(x) did before.
            pieces.append(format(rendered, spec))
        else:
            raise AssertionError("encountered unexpected node of type {} in a JoinedStr node".format(type(part)))
    return ''.join(pieces)
- def visit(self, node):  # visit with source-location tracking: point the builder at this node's location, dispatch, and wrap unexpected errors in CompilationError
- if node is None:  # optional AST fields (e.g. an omitted slice bound) visit to None
- return
- with warnings.catch_warnings():
- # The ast library added visit_Constant and deprecated some other
- # methods but we can't move to that without breaking Python 3.6 and 3.7.
- warnings.simplefilter("ignore", DeprecationWarning) # python 3.9
- warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8
- last_node = self.cur_node
- last_loc = self.builder.get_loc()
- self.cur_node = node
- if hasattr(node, 'lineno') and hasattr(node, 'col_offset'):
- here_loc = self.builder.create_loc(self.file_name, self.begin_line + node.lineno, node.col_offset)  # node.lineno is offset by begin_line
- if self.name_loc_as_prefix is not None:
- self.builder.set_loc(self.builder.create_name_loc(self.name_loc_as_prefix, here_loc))
- else:
- self.builder.set_loc(here_loc)
- last_loc = self.builder.get_loc()  # NOTE(review): overwrites the pre-visit location saved above, so the reset below restores this node's location rather than the caller's — confirm intended
- try:
- ret = super().visit(node)
- except CompilationError:  # already carries source context; propagate unchanged
- raise
- except Exception as e:
- if knobs.compilation.front_end_debugging:
- raise
- # Wrap the error in a CompilationError which contains the source
- # of the @jit function.
- raise CompilationError(self.jit_fn.src, self.cur_node, repr(e)) from None
- # Reset the location to the last one before the visit
- if last_loc:  # only reached on success; after an error, cur_node still points at the failing node
- self.cur_node = last_node
- self.builder.set_loc(last_loc)
- return ret
def generic_visit(self, node):
    """Reject any AST node type that has no dedicated ``visit_*`` handler."""
    node_kind = type(node).__name__
    raise self._unsupported(node, "unsupported AST node type: {}".format(node_kind))
def execute_static_assert(self, node: ast.Call) -> None:
    """Evaluate ``static_assert(cond[, msg])`` at compile time.

    :raises TypeError: unless called with one or two positional arguments and no keywords.
    :raises NotImplementedError: if ``cond`` does not reduce to a Python ``bool``.
    :raises CompileTimeAssertionFailure: if ``cond`` is ``False``. The message
        is the visited second argument (or ``""``), or a placeholder describing
        why that argument could not be evaluated.
    """
    arg_count = len(node.args)
    if node.keywords or arg_count < 1 or arg_count > 2:
        raise TypeError("`static_assert` requires one or two positional arguments only")
    passed = _unwrap_if_constexpr(self.visit(node.args[0]))
    if not isinstance(passed, bool):
        raise NotImplementedError(
            "Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values"
        )
    if passed:
        return None
    message = ""
    if arg_count == 2:
        try:
            message = self.visit(node.args[1])
        except Exception as e:
            # Best effort: a broken message must not hide the assertion failure itself.
            message = "<failed to evaluate assertion message: " + repr(e) + ">"
    raise CompileTimeAssertionFailure(self.jit_fn.src, node, _unwrap_if_constexpr(message))
def static_executor(python_fn):
    """Build a compile-time handler that runs ``python_fn`` on constexpr arguments.

    The returned ``ret(self, node)`` visits the call's keyword arguments, then
    its positional arguments, and unwraps each ``constexpr``. It then calls
    ``python_fn`` with them and wraps the result in a ``constexpr``.
    """
    def ret(self, node: ast.Call):
        kws = {}
        for keyword in node.keywords:
            name, value = self.visit(keyword)
            kws[name] = _unwrap_if_constexpr(value)
        args = [_unwrap_if_constexpr(self.visit(arg)) for arg in node.args]
        return constexpr(python_fn(*args, **kws))

    return ret
- from ..experimental.gluon import language as ttgl
- statically_implemented_functions: Dict[object, Callable[[Any, ast.Call], Any]] = {  # callables that visit_Call evaluates directly on the call AST instead of lowering; each handler is invoked as handler(generator, node)
- language.core.static_assert: execute_static_assert,
- language.core.static_print: static_executor(print),  # prints at compile time
- ttgl.static_assert: execute_static_assert,  # gluon's variants share the same handlers
- ttgl.static_print: static_executor(print),
- int: static_executor(int),  # int(...) / len(...) on constexpr arguments are folded to constexpr results
- len: static_executor(len),
- }
def ast_to_ttir(fn, src, context, options, codegen_fns, module_map, module=None):
    """Compile the @jit function ``fn`` into a Triton IR module.

    The kernel prototype is built from ``src.signature`` (type strings for
    runtime arguments, keyed by name) and ``src.constants`` (constexpr values,
    keyed by index paths into possibly nested arguments). ``CodeGenerator`` is
    then run over the parsed function and the result is verified.

    :returns: the generated module, which takes ownership of ``context``.
    :raises RuntimeError: if verification fails; for non-gluon functions the
        module is printed first.
    """
    arg_names = fn.arg_names
    arg_types = [None] * len(arg_names)
    for arg_name, ty_str in src.signature.items():
        position = arg_names.index(arg_name)
        arg_types[position] = str_to_ty(ty_str, None)

    # Each constexpr path walks into nested argument types; the last index names the slot to overwrite.
    for path, value in src.constants.items():
        indices = list(path)
        container = arg_types
        for step in indices[:-1]:
            container = container[step]
        leaf = indices[-1]
        constexpr_ty = constexpr(value).type
        if isinstance(container, list):
            container[leaf] = constexpr_ty
        else:
            container.types[leaf] = constexpr_ty

    prototype = ASTFunction([], arg_types, src.constants, src.attrs)
    file_name, begin_line = get_jit_fn_file_line(fn)

    # fn.repr() receives a lightweight view of the specialization: top-level constexprs by name, plus the signature.
    from collections import namedtuple
    top_level_constants = {arg_names[path[0]]: src.constants[path] for path in src.constants if len(path) == 1}
    SpecializationProxy = namedtuple("SpecializationProxy", ["constants", "signature"])
    proxy = SpecializationProxy(top_level_constants, src.signature)

    generator = CodeGenerator(context, prototype, gscope=fn.get_capture_scope(), function_name=fn.repr(proxy),
                              jit_fn=fn, is_kernel=True, file_name=file_name, begin_line=begin_line, options=options,
                              codegen_fns=codegen_fns, module_map=module_map, module=module, is_gluon=fn.is_gluon())
    generator.visit(fn.parse())
    result = generator.module
    # module takes ownership of the context
    result.context = context
    if not result.verify():
        if not fn.is_gluon():
            print(result)
        raise RuntimeError("error encountered during parsing")
    return result
|