| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492 |
- from __future__ import annotations
- import ast
- import textwrap
- import inspect
- from typing import Tuple, List, Dict, Callable, TypeVar
- import math
- import numpy as np
- import triton
- import triton.language as tl
- import dataclasses
- from dataclasses import dataclass
- from triton.language.semantic import TritonSemantic
- from triton.runtime.jit import KernelInterface
- from triton.tools.tensor_descriptor import TensorDescriptor
- from .errors import InterpreterError
- from functools import partial
- from .._C.libtriton import interpreter as _interpreter
- from .._C.libtriton import ir as _ir
- T = TypeVar("T")
- @dataclass
- class TensorHandle:
- '''
- data: numpy array
- dtype: triton type, either pointer_type or scalar_type.
- we don't store block_type here because the shape information is already available in the data field
- attr: a dictionary of attributes
- '''
- data: np.array
- dtype: tl.dtype
- attr: Dict = dataclasses.field(default_factory=dict)
- def __post_init__(self):
- if not _validate_np_data_size(self.data, self.dtype):
- raise ValueError(f"numpy data itemsize ({self.data.itemsize * 8} bits) exceeds dtype primitive_bitwidth "
- f"({self.dtype.primitive_bitwidth} bits) for triton type {self.dtype}")
- def __bool__(self):
- return bool(self.data.all())
- def get_element_ty(self):
- dtype = self.dtype
- while hasattr(dtype, "element_ty"):
- dtype = dtype.element_ty
- return dtype
- def clone(self):
- return TensorHandle(self.data.copy(), self.dtype)
- def set_attr(self, key, value):
- self.attr[key] = value
- class BlockPointerHandle:
- def __init__(self, base, shape, strides, offsets, block_shape, order):
- self.base = base
- self.shape = shape
- self.strides = strides
- self.offsets = offsets
- self.block_shape = block_shape
- self.order = order
- def materialize_pointers(self, boundary_check):
- dtype_tt = self.base.get_element_ty()
- n_bytes = dtype_tt.primitive_bitwidth // 8
- ptrs = np.broadcast_to(self.base.data, self.block_shape)
- masks = np.ones(self.block_shape, dtype=bool)
- for dim in range(len(self.block_shape)):
- bcast_dims = [1] * len(self.block_shape)
- bcast_dims[dim] = self.block_shape[dim]
- off = (self.offsets[dim].data + np.arange(self.block_shape[dim])).reshape(bcast_dims)
- ptrs = ptrs + (n_bytes * off * self.strides[dim].data).astype(np.uint64)
- if dim in boundary_check:
- masks = masks & (off < self.shape[dim].data) & (off >= 0)
- ptrs = TensorHandle(ptrs, self.base.dtype.scalar)
- return ptrs, masks
- class TensorDescHandle:
- def __init__(self, base: TensorHandle, shape: List[TensorHandle], strides: List[TensorHandle],
- block_shape: List[int], padding):
- self.base = base
- self.ndim = len(shape)
- self.shape = shape
- self.strides = strides
- self.block_shape = block_shape
- self.padding = padding
- def validate(self):
- assert self.base.data.item() % 16 == 0, "base must be 16-byte aligned"
- assert len(self.strides) == self.ndim
- assert len(self.block_shape) == self.ndim
- assert self.ndim >= 1, "descriptor cannot be 0 dimensional"
- scalar_ty = self.base.dtype.element_ty
- itemsize = scalar_ty.primitive_bitwidth // 8
- for stride in self.strides[:-1]:
- byte_stride = stride.data.item() * itemsize
- assert byte_stride % 16 == 0, "stride must be 16-byte aligned"
- assert self.strides[-1].data.item() == 1, "last dim must be contiguous"
- def materialize_pointers(self, offsets: List[TensorHandle]):
- assert len(offsets) == self.ndim
- scalar_ty = self.base.dtype.element_ty
- itemsize = scalar_ty.primitive_bitwidth // 8
- assert (offsets[-1].data * itemsize) % 16 == 0, "block offset start must be 16-byte aligned"
- ptrs = np.broadcast_to(self.base.data, self.block_shape)
- masks = np.ones(self.block_shape, dtype=bool)
- for dim in range(len(self.block_shape)):
- bcast_dims = [1] * len(self.block_shape)
- bcast_dims[dim] = self.block_shape[dim]
- off = (offsets[dim].data + np.arange(self.block_shape[dim])).reshape(bcast_dims)
- ptrs = ptrs + (itemsize * off * self.strides[dim].data).astype(np.uint64)
- masks = masks & (0 <= off) & (off < self.shape[dim].data)
- assert ptrs.dtype == np.uint64
- ptrs = TensorHandle(ptrs, self.base.dtype.scalar)
- return ptrs, masks
- @dataclass(frozen=True)
- class InterpreterOptions:
- extern_libs: dict = None
- debug: bool = False
- sanitize_overflow: bool = True
- arch: str = None
- supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e5b16", "fp8e4nv", "fp8e4b8", "fp8e4b15")
- deprecated_fp8_dot_operand_dtypes: Tuple[str] = ()
- default_dot_input_precision: str = "tf32"
- allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee")
- max_num_imprecise_acc_default: int = 0
- backend_name: str = "interpreter"
- def _validate_np_data_size(np_array, tl_dtype):
- if isinstance(tl_dtype, tl.pointer_type):
- return True
- np_dtype_bitwidth = np_array.itemsize * 8
- tl_dtype_bitwidth = tl_dtype.primitive_bitwidth
- # numpy lowest itemsize is at least 8 bits
- if tl_dtype_bitwidth < 8:
- tl_dtype_bitwidth = 8
- if np_dtype_bitwidth > tl_dtype_bitwidth:
- return False
- return True
- def _get_signed_np_dtype(dtype):
- if dtype == np.uint8:
- return np.int8
- if dtype == np.uint16:
- return np.int16
- if dtype == np.uint32:
- return np.int32
- if dtype == np.uint64:
- return np.int64
- return dtype
- def _get_np_dtype(tt_dtype):
- if isinstance(tt_dtype, tl.pointer_type):
- return np.dtype(np.uint64)
- np_types = {
- tl.int1: np.dtype(bool),
- tl.float16: np.dtype(np.float16),
- tl.float32: np.dtype(np.float32),
- tl.float64: np.dtype(np.float64),
- tl.int8: np.dtype(np.int8),
- tl.uint8: np.dtype(np.uint8),
- tl.int16: np.dtype(np.int16),
- tl.uint16: np.dtype(np.uint16),
- tl.int32: np.dtype(np.int32),
- tl.uint32: np.dtype(np.uint32),
- tl.int64: np.dtype(np.int64),
- tl.uint64: np.dtype(np.uint64),
- # bfloat16 types are stored as uint16
- tl.bfloat16: np.dtype(np.uint16),
- # float8 types are stored as uint8
- tl.float8e5: np.dtype(np.uint8),
- tl.float8e5b16: np.dtype(np.uint8),
- tl.float8e4nv: np.dtype(np.uint8),
- tl.float8e4b8: np.dtype(np.uint8),
- tl.float8e4b15: np.dtype(np.uint8),
- }
- if isinstance(tt_dtype, tl.block_type):
- if isinstance(tt_dtype.element_ty, tl.pointer_type):
- return np.dtype(np.uint64)
- return np_types[tt_dtype.element_ty]
- return np_types[tt_dtype]
- def _convert_float(input, input_dtype, output_dtype, rounding_mode):
- input_uint_dtype = getattr(np, f"uint{input_dtype.primitive_bitwidth}")
- output_unint_dtype = getattr(np, f"uint{output_dtype.primitive_bitwidth}")
- input_bin = np.frombuffer(input.tobytes(), dtype=input_uint_dtype)
- sign = (input_bin >> (input_dtype.primitive_bitwidth - 1)) & 0x01
- input_exponent_width = input_dtype.primitive_bitwidth - input_dtype.fp_mantissa_width - 1
- output_exponent_width = output_dtype.primitive_bitwidth - output_dtype.fp_mantissa_width - 1
- significand = input_bin & ((1 << input_dtype.fp_mantissa_width) - 1)
- bias_input = input_dtype.exponent_bias
- bias_output = output_dtype.exponent_bias
- exponent = ((input_bin >> input_dtype.fp_mantissa_width) & ((1 << input_exponent_width) - 1)).astype(np.int32)
- subnormal_index = exponent == 0
- if np.any(subnormal_index):
- # Credit to Phil: phil@openai.com
- # subnormal repr: ((-1.0)**sign) * (2.0**(1 - exp_bias)) * (2^(m0) + 2^(m1) + ... + 2^(mn))
- # where m0, m1, ..., mn are the 1-bit of the mantissa
- # convert it to normal repr: ((-1.0)**sign) * (2.0**(1 + m0 - exp_bias)) * (1 + 2^(m1 - m0) + ... + 2^(mn - m0))
- bit_pos = np.zeros_like(input_bin, dtype=np.int32)
- # Find the most significant bit of the mantissa in the significand
- for i in range(input_dtype.fp_mantissa_width):
- bit_index = ((significand >> i) & 0x01)
- # pos should be >= 1
- bit_pos[bit_index == 1] = input_dtype.fp_mantissa_width - i
- zero_significand_index = significand == 0
- exponent[subnormal_index] = 1 - bit_pos[subnormal_index]
- # 0 significand and subnormal should be treated as 0
- exponent[zero_significand_index & subnormal_index] = bias_input - bias_output
- significand[subnormal_index] = (significand[subnormal_index] << bit_pos[subnormal_index]) & (
- (1 << input_dtype.fp_mantissa_width) - 1)
- # Prevent overflow and underflow
- exponent_output = np.maximum(0, np.minimum((exponent - bias_input + bias_output), (1 << output_exponent_width) - 1))
- exponent_output = exponent_output.astype(output_unint_dtype)
- sign_output = sign.astype(output_unint_dtype)
- if input_dtype.primitive_bitwidth > output_dtype.primitive_bitwidth: # Downcast
- significand_output = (significand >> (input_dtype.fp_mantissa_width - output_dtype.fp_mantissa_width)) & (
- (1 << output_dtype.fp_mantissa_width) - 1)
- if rounding_mode == _ir.ROUNDING_MODE.RTNE: # Round to nearst even
- # find the cut-off bit
- cut_off = significand & (1 << (input_dtype.fp_mantissa_width - output_dtype.fp_mantissa_width - 1))
- significand_output = significand_output + (cut_off > 0)
- significand_output = significand_output.astype(output_unint_dtype)
- else: # Upcast
- significand_output = (significand.astype(output_unint_dtype) <<
- (output_dtype.fp_mantissa_width - input_dtype.fp_mantissa_width)) & (
- (1 << output_dtype.fp_mantissa_width) - 1)
- subnormal_index = exponent_output == 0
- if np.any(subnormal_index): # underflow
- # normal repr: ((-1.0)**sign) * (2.0**(exp - exp_bias_input)) * (1 + 2^(m0) + 2^(m1) + ... + 2^(mn))
- # where m0, m1, ..., mn are the 1-bit of the mantissa
- # shift = (1 - exp_bias_output) - (exp - exp_bias_input)
- # convert it to subnormal repr: ((-1.0)**sign) * (2.0**(1 - exp_bias_output)) * (2^(-shift) + 2^(m0 - shift) + 2^(m1 - shift) + ... + 2^(mn - shift))
- exponent = ((input_bin >> input_dtype.fp_mantissa_width) & ((1 << input_exponent_width) - 1)).astype(np.int32)
- non_zero_exponent_index = exponent != 0
- # If the original exponent is not zero, we still need to shift the significand and consider the 1.0 part in mantissa
- subnormal_index = subnormal_index & non_zero_exponent_index
- shift = np.zeros_like(input_bin, dtype=np.int32)
- shift[subnormal_index] = (1 - bias_output) - (exponent[subnormal_index] - bias_input)
- significand_output[subnormal_index] = (significand_output[subnormal_index] >> shift[subnormal_index]) | (
- 1 << (output_dtype.fp_mantissa_width - shift[subnormal_index]))
- output = (sign_output << (output_dtype.primitive_bitwidth - 1)) | (
- exponent_output << output_dtype.fp_mantissa_width) | significand_output
- return output.reshape(input.shape)
- def _erf(x):
- # Numpy does not support erf
- return math.erf(x)
- def _umulhi_64(a, b):
- # Numpy does not support 128-bit multiplication
- # So we have to implement it manually
- return (int(a) * int(b)) >> 64
- np_erf_fp32 = np.vectorize(_erf, otypes=[np.float32])
- np_erf_fp64 = np.vectorize(_erf, otypes=[np.float64])
- np_umulhi_u64 = np.vectorize(_umulhi_64, otypes=[np.uint64])
- class ExtraFunctions:
- @staticmethod
- def _convert_custom_types(input, dst_ty, fp_downcast_rounding, _semantic):
- return tl.tensor(_semantic.builder.create_fp_to_fp(input.handle, dst_ty, fp_downcast_rounding), dst_ty)
- class InterpreterBuilder:
- ir_sem_to_interpreter_sem = {
- _ir.MEM_SEMANTIC.ACQUIRE: _interpreter.MEM_SEMANTIC.ACQUIRE,
- _ir.MEM_SEMANTIC.RELEASE: _interpreter.MEM_SEMANTIC.RELEASE,
- _ir.MEM_SEMANTIC.RELAXED: _interpreter.MEM_SEMANTIC.RELAXED,
- _ir.MEM_SEMANTIC.ACQUIRE_RELEASE: _interpreter.MEM_SEMANTIC.ACQUIRE_RELEASE,
- }
- ir_rmw_op_to_interpreter_rmw_op = {
- _ir.ATOMIC_OP.ADD: _interpreter.RMW_OP.ADD,
- _ir.ATOMIC_OP.FADD: _interpreter.RMW_OP.FADD,
- _ir.ATOMIC_OP.MIN: _interpreter.RMW_OP.MIN,
- _ir.ATOMIC_OP.UMIN: _interpreter.RMW_OP.UMIN,
- _ir.ATOMIC_OP.MAX: _interpreter.RMW_OP.MAX,
- _ir.ATOMIC_OP.UMAX: _interpreter.RMW_OP.UMAX,
- _ir.ATOMIC_OP.AND: _interpreter.RMW_OP.AND,
- _ir.ATOMIC_OP.OR: _interpreter.RMW_OP.OR,
- _ir.ATOMIC_OP.XOR: _interpreter.RMW_OP.XOR,
- _ir.ATOMIC_OP.XCHG: _interpreter.RMW_OP.XCHG,
- }
- def __init__(self) -> None:
- self.arch = None
- self.options = InterpreterOptions()
- self.codegen_fns = {}
- self.codegen_fns["convert_custom_types"] = ExtraFunctions._convert_custom_types
- self.codegen_fns["min_dot_size"] = lambda lhsType, rhsType: (1, 1, 1)
- def set_grid_idx(self, x, y, z):
- if not x < self.grid_dim[0]:
- raise ValueError("x >= grid_dim[0]")
- if not y < self.grid_dim[1]:
- raise ValueError("y >= grid_dim[1]")
- if not z < self.grid_dim[2]:
- raise ValueError("z >= grid_dim[2]")
- self.grid_idx = (x, y, z)
- def set_grid_dim(self, nx, ny, nz):
- self.grid_dim = (nx, ny, nz)
- # constants
- def get_half_ty(self):
- return tl.float16
- def get_bf16_ty(self):
- return tl.bfloat16
- def get_float_ty(self):
- return tl.float32
- def get_double_ty(self):
- return tl.float64
- def get_int1_ty(self):
- return tl.int1
- def get_int8_ty(self):
- return tl.int8
- def get_uint8_ty(self):
- return tl.uint8
- def get_int16_ty(self):
- return tl.int16
- def get_uint16_ty(self):
- return tl.uint16
- def get_int32_ty(self):
- return tl.int32
- def get_uint32_ty(self):
- return tl.uint32
- def get_int64_ty(self):
- return tl.int64
- def get_uint64_ty(self):
- return tl.uint64
- def get_fp8e4nv_ty(self):
- return tl.float8e4nv
- def get_fp8e4b15_ty(self):
- return tl.float8e4b15
- def get_fp8e4b8_ty(self):
- return tl.float8e4b8
- def get_fp8e5_ty(self):
- return tl.float8e5
- def get_fp8e5b16_ty(self):
- return tl.float8e5b16
- def get_ptr_ty(self, elt_ty, addr_space):
- return tl.pointer_type(elt_ty, addr_space)
- def get_block_ty(self, dtype, shape):
- return tl.block_type(dtype, shape)
- def get_int1(self, value):
- return TensorHandle(np.array([value], dtype=np.bool_), tl.int1)
- def get_uint8(self, value):
- return TensorHandle(np.array([value], dtype=np.uint8), tl.uint8)
- def get_int8(self, value):
- return TensorHandle(np.array([value], dtype=np.int8), tl.int8)
- def get_uint16(self, value):
- return TensorHandle(np.array([value], dtype=np.uint16), tl.uint16)
- def get_int16(self, value):
- return TensorHandle(np.array([value], dtype=np.int16), tl.int16)
- def get_uint32(self, value):
- return TensorHandle(np.array([value], dtype=np.uint32), tl.uint32)
- def get_int32(self, value):
- return TensorHandle(np.array([value], dtype=np.int32), tl.int32)
- def get_uint64(self, value):
- return TensorHandle(np.array([value], dtype=np.uint64), tl.uint64)
- def get_int64(self, value):
- return TensorHandle(np.array([value], dtype=np.int64), tl.int64)
- def get_fp16(self, value):
- return TensorHandle(np.array([value], dtype=np.float16), tl.float16)
- def get_fp32(self, value):
- return TensorHandle(np.array([value], dtype=np.float32), tl.float32)
- def get_fp64(self, value):
- return TensorHandle(np.array([value], dtype=np.float64), tl.float64)
- def get_null_value(self, type):
- return TensorHandle(np.array([0], dtype=_get_np_dtype(type)), type)
- # programming model
- def create_get_program_id(self, axis):
- if self.grid_idx is None:
- raise ValueError("grid_idx is None")
- return TensorHandle(np.array([self.grid_idx[axis]], dtype=np.int32), tl.int32)
- def create_get_num_programs(self, axis):
- return TensorHandle(np.array([self.grid_dim[axis]], dtype=np.int32), tl.int32)
- # memory ops
- def create_load(self, ptr, _0, _1, is_volatile):
- mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1)
- other = None
- return self.create_masked_load(ptr, mask, other, _0, _1, is_volatile)
- def create_store(self, ptr, val, _0, _1):
- mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1)
- return self.create_masked_store(ptr, val, mask, None, None)
- def create_masked_load(self, ptrs, mask, other, cache_modifier, eviction_policy, is_volatile):
- dtype_tt = ptrs.get_element_ty()
- dtype_np = _get_np_dtype(dtype_tt)
- if other is None:
- other = TensorHandle(np.zeros_like(ptrs.data, dtype=dtype_np), dtype_tt)
- ret = _interpreter.load(ptrs.data, mask.data, other.data, dtype_np)
- return TensorHandle(ret, dtype_tt)
- def create_masked_store(self, ptrs, value, mask, cache_modifier, eviction_policy):
- return _interpreter.store(ptrs.data, value.data, mask.data)
- # casting ops
- def cast_impl(self, src, dst_type):
- src_element_type = src.dtype.scalar
- dst_element_type = dst_type.scalar
- if (src_element_type == tl.bfloat16 and dst_element_type == tl.float32) or \
- (src_element_type == tl.float32 and dst_element_type == tl.bfloat16):
- data = _convert_float(src.data, src_element_type, dst_element_type, None).view(_get_np_dtype(dst_type))
- return TensorHandle(data, dst_type.scalar)
- else:
- return TensorHandle(src.data.astype(_get_np_dtype(dst_type)), dst_type.scalar)
- create_si_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type)
- create_ui_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type)
- create_fp_to_si = lambda self, src, dst_type: self.cast_impl(src, dst_type)
- create_fp_to_ui = lambda self, src, dst_type: self.cast_impl(src, dst_type)
- create_fp_ext = lambda self, src, dst_type: self.cast_impl(src, dst_type)
- create_fp_trunc = lambda self, src, dst_type: self.cast_impl(src, dst_type)
- create_int_cast = lambda self, src, dst_type, is_signed: self.cast_impl(src, dst_type)
- def create_fp_to_fp(self, src, dst_type, rounding_mode):
- src_element_type = src.dtype.scalar
- dst_element_type = dst_type.scalar
- data = _convert_float(src.data, src_element_type, dst_element_type, rounding_mode).view(_get_np_dtype(dst_type))
- return TensorHandle(data, dst_type.scalar)
- def create_bitcast(self, src, dst_type):
- return TensorHandle(src.data.view(_get_np_dtype(dst_type)), dst_type.scalar)
- # binary operators
- def binary_op(self, lhs, rhs, op):
- output = op(lhs.data, rhs.data)
- tl_dtype = lhs.dtype.scalar
- if not _validate_np_data_size(output, tl_dtype):
- output = output.astype(_get_np_dtype(tl_dtype))
- return TensorHandle(output, tl_dtype)
- create_fadd = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add)
- create_fmul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply)
- create_fdiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide)
- create_frem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod)
- create_fsub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract)
- create_mul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply)
- create_precise_divf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide)
- create_sdiv = lambda self, lhs, rhs: self.create_idiv(lhs, rhs)
- create_udiv = lambda self, lhs, rhs: self.create_idiv(lhs, rhs)
- # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders.
- create_srem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod)
- create_urem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod)
- create_add = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add)
- create_sub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract)
- create_shl = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.left_shift)
- create_lshr = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.right_shift)
- create_minsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum)
- create_minui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum)
- create_minimumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum)
- create_minnumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum)
- create_maxsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum)
- create_maxui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum)
- create_maximumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum)
- create_maxnumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum)
- create_icmpSLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
- create_icmpSLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
- create_icmpSGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
- create_icmpSGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
- create_icmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
- create_icmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
- create_icmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
- create_icmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
- create_icmpEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal)
- create_icmpNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal)
- create_fcmpOLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
- create_fcmpOGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
- create_fcmpOLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
- create_fcmpOGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
- create_fcmpOEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal)
- create_fcmpONE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal)
- create_fcmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
- create_fcmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
- create_fcmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
- create_fcmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
- create_fcmpUEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal)
- create_fcmpUNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal)
- create_and = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_and)
- create_xor = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_xor)
- create_or = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_or)
- create_int_to_ptr = create_bitcast
- create_ptr_to_int = create_bitcast
- def create_idiv(self, lhs, rhs):
- # Triton has IEEE, not numpy/torch, semantics for %, and those carry
- # through to //, so we have to use a nonstandard expression to get a
- # reference result for //.
- return TensorHandle((lhs.data - np.fmod(lhs.data, rhs.data)) // rhs.data, lhs.dtype.scalar)
- def create_ashr(self, lhs, rhs):
- # Triton's rshift operator depends on the signedness of the left operand
- lhs_dtype = _get_signed_np_dtype(lhs.data.dtype)
- rhs_dtype = _get_signed_np_dtype(rhs.data.dtype)
- lhs.data = lhs.data.astype(lhs_dtype)
- rhs.data = rhs.data.astype(rhs_dtype)
- return self.binary_op(lhs, rhs, np.right_shift)
- def create_umulhi(self, lhs, rhs):
- dtype = lhs.data.dtype
- if dtype == np.int64 or dtype == np.uint64:
- return TensorHandle(np_umulhi_u64(lhs.data, rhs.data), lhs.dtype.scalar)
- else:
- compute_dtype = getattr(np, f"uint{dtype.itemsize * 8 * 2}")
- lhs_data = lhs.data.astype(compute_dtype)
- rhs_data = rhs.data.astype(compute_dtype)
- ret_data = np.multiply(lhs_data, rhs_data) >> (dtype.itemsize * 8)
- return TensorHandle(ret_data.astype(dtype), lhs.dtype.scalar)
- # ternary functions
- def ternary_op(self, lhs, rhs, other, op):
- output = op(lhs.data, rhs.data, other.data)
- tl_dtype = other.dtype.scalar
- if not _validate_np_data_size(output, tl_dtype):
- output = output.astype(_get_np_dtype(tl_dtype))
- return TensorHandle(output, tl_dtype)
- create_clampf = lambda self, arg, lo, hi, propagate_nans: self.ternary_op(arg, lo, hi, np.clip)
- create_select = lambda self, cond, lhs, rhs: self.ternary_op(cond, lhs, rhs, np.where)
- def create_fma(self, x, y, z):
- return TensorHandle(x.data * y.data + z.data, z.dtype.scalar)
- # unary functions
- def unary_op(self, arg, op):
- return TensorHandle(op(arg.data), arg.dtype.scalar)
- def create_fabs(self, arg):
- # Mask out the sign bit based on the primitive length
- dtype_tt = arg.dtype
- mask_bitwidth = dtype_tt.primitive_bitwidth - 1
- np_uint_dtype = getattr(np, f"uint{dtype_tt.primitive_bitwidth}")
- data = arg.data.view(np_uint_dtype)
- mask = (1 << mask_bitwidth) - 1
- ret = (data & mask).view(_get_np_dtype(dtype_tt))
- return TensorHandle(ret, arg.dtype.scalar)
- create_cos = lambda self, arg: self.unary_op(arg, np.cos)
- create_exp = lambda self, arg: self.unary_op(arg, np.exp)
- create_exp2 = lambda self, arg: self.unary_op(arg, np.exp2)
- create_iabs = lambda self, arg: self.unary_op(arg, np.abs)
- create_floor = lambda self, arg: self.unary_op(arg, np.floor)
- create_ceil = lambda self, arg: self.unary_op(arg, np.ceil)
- create_log = lambda self, arg: self.unary_op(arg, np.log)
- create_log2 = lambda self, arg: self.unary_op(arg, np.log2)
- create_precise_sqrt = lambda self, arg: self.unary_op(arg, np.sqrt)
- create_sqrt = lambda self, arg: self.unary_op(arg, np.sqrt)
- create_sin = lambda self, arg: self.unary_op(arg, np.sin)
- def create_erf(self, arg):
- ret = np_erf_fp32(arg.data) if arg.data.dtype == np.float32 else np_erf_fp64(arg.data)
- return TensorHandle(ret, arg.dtype.scalar)
- def create_rsqrt(self, arg):
- return TensorHandle(1 / np.sqrt(arg.data), arg.dtype.scalar)
- # tensor operators
- create_reshape = lambda self, arg, shape, allow_reorder: TensorHandle(arg.data.reshape(shape), arg.dtype.scalar)
- def create_trans(self, arg, perm):
- return TensorHandle(np.transpose(arg.data, perm), arg.dtype.scalar)
- def create_dot(self, a, b, d, input_precision, max_num_imprecise_acc):
- a_data = a.data
- b_data = b.data
- if (a.dtype.primitive_bitwidth == 8 and a.dtype.is_floating()) or \
- (b.dtype.primitive_bitwidth == 8 and b.dtype.is_floating()):
- a_data = _convert_float(a_data, a.dtype, tl.float16, None).view(np.float16)
- b_data = _convert_float(b_data, b.dtype, tl.float16, None).view(np.float16)
- return TensorHandle(np.matmul(a_data, b_data, dtype=d.data.dtype) + d.data, d.dtype.scalar)
- def create_make_range(self, ret_ty, start, stop):
- return TensorHandle(np.arange(start, stop, dtype=np.int32), tl.int32)
- def create_histogram(self, data, bins, mask):
- if mask is None:
- mask = TensorHandle(np.ones_like(data.data, dtype=bool), tl.int1)
- # By default np.histogram returns int64 dtype values
- # Docs specify that returned dtype is taken based on optional weights.dtype
- # This is fix for interpreter cases where for example int32 tensor is being passed
- # But unexpectedly int64 values are being returned causing
- # tl.store to write 8 bytes instead of 4 bytes which lead to silent data corruption
- dummy_weights = np.ones_like(data.data, dtype=data.data.dtype)
- # force all masked elements to zero
- data = np.where(mask.data, data.data, np.zeros_like(data.data))
- histogram = np.histogram(data, bins=bins, range=(0, bins), weights=dummy_weights)[0]
- # remove overcounted elements
- histogram[0] -= np.logical_not(mask.data).sum()
- return TensorHandle(histogram, tl.int32)
- def create_gather(self, src, indices, axis):
- return TensorHandle(np.take_along_axis(src.data, indices.data, axis=axis), src.dtype.scalar)
- # pointer arithmetic
- def create_addptr(self, ptr, offset):
- dtype_tt = ptr.get_element_ty()
- element_bitwidth = dtype_tt.primitive_bitwidth
- # int1's bitwidth is 1, but we need to use 8 for pointer arithmetic
- element_bytewidth = max(1, element_bitwidth // 8)
- return TensorHandle(ptr.data + element_bytewidth * offset.data.astype(np.uint64), ptr.dtype)
- def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy,
- is_volatile):
- ptrs, masks = ptr.materialize_pointers(boundary_check)
- dtype_tt = ptrs.get_element_ty()
- dtype_np = _get_np_dtype(dtype_tt)
- if padding_option is None:
- other = None
- elif padding_option == _ir.PADDING_OPTION.PAD_ZERO:
- other = TensorHandle(np.zeros_like(ptrs.data, dtype=dtype_np), dtype_tt)
- elif padding_option == _ir.PADDING_OPTION.PAD_NAN:
- other = TensorHandle(np.full_like(ptrs.data, float('nan'), dtype=dtype_np), dtype_tt)
- else:
- raise ValueError(f"unsupported padding option {padding_option}")
- return self.create_masked_load(ptrs, masks, other, cache_modifier, eviction_policy, is_volatile)
- def create_tensor_pointer_store(self, ptr, value, boundary_check, cache_modifier, eviction_policy):
- ptrs, masks = ptr.materialize_pointers(boundary_check)
- return self.create_masked_store(ptrs, value, masks, cache_modifier, eviction_policy)
- def create_expand_dims(self, arg, axis):
- return TensorHandle(np.expand_dims(arg.data, axis), arg.dtype.scalar)
- def create_broadcast(self, arg, shape):
- return TensorHandle(np.broadcast_to(arg.data, shape), arg.dtype.scalar)
- def create_cat(self, lhs, rhs):
- return TensorHandle(np.concatenate([lhs.data, rhs.data]), lhs.dtype.scalar)
- def create_join(self, lhs, rhs):
- # Triton only supports joining two original tensors into a new one along the last axis
- return TensorHandle(np.stack([lhs.data, rhs.data], axis=-1), lhs.dtype.scalar)
- def create_split(self, val):
- # Triton only supports splitting the original tensor into two along the last axis
- return (TensorHandle(val.data[..., 0], val.dtype.scalar), TensorHandle(val.data[..., 1], val.dtype.scalar))
- def create_splat(self, ret_ty, arg):
- shape = ret_ty.shape
- if isinstance(arg.dtype, tl.block_type):
- return TensorHandle(np.full(shape, arg.data[0], dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar)
- else: # scalar
- return TensorHandle(np.full(shape, arg.data, dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar)
- def create_unsplat(self, arg):
- return TensorHandle(np.full((1, ), arg.data[0], dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar)
- def create_atomic_cas(self, ptr, cmp, val, sem, scope):
- if sem not in self.ir_sem_to_interpreter_sem:
- raise ValueError(f"unsupported semantic {sem}")
- sem = self.ir_sem_to_interpreter_sem[sem]
- return TensorHandle(_interpreter.atomic_cas(ptr.data, cmp.data, val.data, sem), cmp.dtype.scalar)
- def create_atomic_rmw(self, rmwOp, ptr, val, mask, sem, scope):
- if rmwOp not in self.ir_rmw_op_to_interpreter_rmw_op:
- raise ValueError(f"unsupported rmwOp {rmwOp}")
- if sem not in self.ir_sem_to_interpreter_sem:
- raise ValueError(f"unsupported semantic {sem}")
- rmwOp = self.ir_rmw_op_to_interpreter_rmw_op[rmwOp]
- sem = self.ir_sem_to_interpreter_sem[sem]
- return TensorHandle(_interpreter.atomic_rmw(rmwOp, ptr.data, val.data, mask.data, sem), val.dtype.scalar)
- def create_extern_elementwise(self, libName, libPath, symbol, argList, retType, isPure):
- raise NotImplementedError("extern_elementwise not supported in interpreter mode")
- def create_inline_asm(self, inlineAsm, constraints, values, type, isPure, pack):
- raise NotImplementedError("inline_asm not supported in interpreter mode")
- def create_print(self, prefix, hex, values, isSigned):
- # NOTE: the `isSigned` variable is not really used here; because Signness is already known
- # by `values` themselves in python interpreter, thus not really needed here;
- # it is only used for triton PrintOpToLLVM to correctly construct the format specifier.
- # Interpreter's device_print function has a different format than Triton's device_print
- msg = f"({self.grid_idx[0]}, {self.grid_idx[1]}, {self.grid_idx[2]})"
- if prefix:
- msg += f" {prefix}"
- if hex:
- np.set_printoptions(formatter={'all': lambda x: f"0x{x:02x}"})
- for value in values:
- print(msg + f" {value.data}")
- if hex:
- np.set_printoptions(formatter=None)
- def create_assert(self, condition, message):
- # Interpreter's device_assert function has a different format than Triton's device_assert
- assert condition, f"{message}"
- def create_assume(self, condition):
- assert condition, "Assume failed"
- def create_barrier(self):
- # Triton's barrier applies to each program in a grid, so it's a no-op in the interpreter
- pass
- def create_make_block_ptr(self, base, shape, strides, offsets, block_shape, order):
- # Create new offsets to avoid modifying the original
- new_offsets = [offset.clone() for offset in offsets]
- return BlockPointerHandle(base, shape, strides, new_offsets, block_shape, order)
- def create_advance(self, ptr, offsets):
- if len(ptr.offsets) != len(offsets):
- raise ValueError("len(ptr.offsets) != len(offsets)")
- # Create new offsets to avoid modifying the original
- new_offsets = [offset.clone() for offset in ptr.offsets]
- ret = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides, new_offsets, ptr.block_shape, ptr.order)
- for i in range(len(offsets)):
- ret.offsets[i].data += offsets[i].data
- return ret
- def create_make_tensor_descriptor(self, base: TensorHandle, shape: List[TensorHandle], strides: List[TensorHandle],
- tensor_shape: List[int], is_signed: bool, padding: str = "zero"):
- desc = TensorDescHandle(base, shape, strides, tensor_shape, padding)
- desc.validate()
- return desc
- def create_descriptor_load(self, desc: TensorDescHandle, indices: List[TensorHandle], cache_modifier,
- eviction_policy):
- assert isinstance(desc, TensorDescHandle)
- ptrs, mask = desc.materialize_pointers(indices)
- dtype_tt = ptrs.get_element_ty()
- dtype_np = _get_np_dtype(dtype_tt)
- padding = desc.padding
- if padding == _ir.PADDING_OPTION.PAD_ZERO:
- other = TensorHandle(np.zeros_like(ptrs.data, dtype=dtype_np), dtype_tt)
- elif padding == _ir.PADDING_OPTION.PAD_NAN:
- other = TensorHandle(np.full_like(ptrs.data, float('nan'), dtype=dtype_np), dtype_tt)
- else:
- raise ValueError(f"unsupported padding {padding}")
- return self.create_masked_load(ptrs, mask, other, cache_modifier=cache_modifier,
- eviction_policy=eviction_policy, is_volatile=False)
- def create_descriptor_store(self, desc: TensorDescHandle, value: TensorHandle, indices: List[TensorHandle]):
- ptrs, mask = desc.materialize_pointers(indices)
- return self.create_masked_store(ptrs, value, mask, None, None)
- def create_descriptor_gather(self, desc: TensorDescHandle, x_offsets: TensorHandle, y_offset: TensorHandle, type):
- dtype = desc.base.dtype.element_ty
- np_dtype = _get_np_dtype(dtype)
- result = np.zeros([x_offsets.data.shape[0], desc.block_shape[-1]], dtype=np_dtype)
- cache_modifier = None
- eviction_policy = None
- for i, x_offset in enumerate(x_offsets.data):
- indices = [TensorHandle(x_offset, tl.int32), y_offset]
- result[i, :] = self.create_descriptor_load(desc, indices, cache_modifier, eviction_policy).data
- return TensorHandle(result, dtype)
- def create_descriptor_scatter(self, desc: TensorDescHandle, value: TensorHandle, x_offsets: TensorHandle,
- y_offset: TensorHandle):
- for i, x_offset in enumerate(x_offsets.data):
- slice = TensorHandle(value.data[i], value.dtype)
- indices = [TensorHandle(x_offset, tl.int32), y_offset]
- self.create_descriptor_store(desc, slice, indices)
- def get_all_ones_value(self, type):
- np_type = _get_np_dtype(type)
- if "int" in np_type.name:
- return TensorHandle(np.full(1, -1, dtype=np_type), type.scalar)
- elif np_type == np.bool_:
- return TensorHandle(np.full(1, True, dtype=np_type), type.scalar)
- else:
- raise TypeError(f"unsupported type {type}")
- _MISSING = object()
- class _LangPatchScope:
- """Tracks patched attributes so they can be restored."""
- def __init__(self) -> None:
- self._changes: list[tuple[object, str, object]] = []
- def set_attr(self, obj: object, name: str, value: object) -> None:
- original = getattr(obj, name, _MISSING)
- self._changes.append((obj, name, original))
- setattr(obj, name, value)
- def restore(self) -> None:
- while self._changes:
- obj, name, original = self._changes.pop()
- if original is _MISSING:
- delattr(obj, name)
- else:
- setattr(obj, name, original)
- def _patch_attr(obj, name, member, builder, scope: _LangPatchScope):
- semantic = TritonSemantic(builder)
- new_member = lambda *args, member=member, **kwargs: (member(*args, **
- {k: v
- for k, v in kwargs.items()
- if k != "_semantic"}, _semantic=semantic))
- scope.set_attr(obj, name, new_member)
- def _patch_builtin(pkg, builder, scope: _LangPatchScope):
- for name, member in inspect.getmembers(pkg):
- if tl.core.is_builtin(member):
- _patch_attr(pkg, name, member, builder, scope)
- def _patch_lang_tensor(tensor, scope: _LangPatchScope):
- def _get_bool(self):
- data = self.handle.data
- # in triton, only scalars can be converted to booleans
- # here we need this hack because all scalars are tensors
- return bool(data) if data.size == 1 else True
- def _get_transpose(self):
- handle = TensorHandle(np.transpose(self.handle.data), self.handle.dtype)
- assert self.type.is_block()
- block_shape = list(self.type.shape)
- block_shape[-1], block_shape[-2] = block_shape[-2], block_shape[-1]
- res_ty = tl.core.block_type(self.dtype, block_shape)
- return tl.core.tensor(handle, res_ty)
- scope.set_attr(tensor, "__index__", lambda self: int(self.handle.data))
- scope.set_attr(tensor, "__bool__", lambda self: _get_bool(self))
- scope.set_attr(tensor, "__repr__", lambda self: repr(self.handle.data))
- scope.set_attr(tensor, "__str__", lambda self: str(self.handle.data))
- scope.set_attr(tensor, "T", property(_get_transpose))
- class ReduceScanOpInterface:
- def __init__(self, axis, combine_fn):
- self.axis = axis
- self.combine_fn = combine_fn
- def check_axis(self, shape, axis):
- if axis is not None and axis >= len(shape):
- raise ValueError(f"axis {axis} out of bounds for shape {shape}")
- def check_tensor(self, input):
- for arg in input:
- if not isinstance(arg, tl.core.tensor):
- raise ValueError(f"input must be a tensor, got {type(arg)}")
- self.check_axis(arg.shape, self.axis)
- def to_tensor(self, ret, dtype):
- np_dtype = _get_np_dtype(dtype)
- if hasattr(ret, "shape") and ret.shape:
- ret = ret.astype(np_dtype)
- ret_type = tl.block_type(dtype, list(ret.shape))
- else:
- ret = np.array([ret], dtype=np_dtype)
- ret_type = dtype
- return tl.core.tensor(TensorHandle(ret, dtype.scalar), ret_type)
- def apply(self, input):
- if not isinstance(input, tuple):
- return self.apply((input, ))[0]
- self.check_tensor(input)
- ret = self.apply_impl(input)
- return tuple(ret) if isinstance(ret, (list, tuple)) else (ret, )
- class ReduceOps(ReduceScanOpInterface):
- def __init__(self, axis, combine_fn, keep_dims):
- super().__init__(axis, combine_fn)
- self.keep_dims = keep_dims
- def unravel(self, input, axis):
- ret = []
- for data in input:
- if axis is not None:
- ret.append(data)
- else:
- axis = 0
- ret.append(self.to_tensor(data.handle.data.flatten(), data.dtype))
- return tuple(ret), axis
- def generic_reduce(self, input):
- original_axis = self.axis
- input, axis = self.unravel(input, self.axis)
- input_data = []
- output_data = []
- input_shape = input[0].handle.data.shape
- output_shape = input_shape[0:axis] + input_shape[axis + 1:]
- for arg in input:
- input_data.append(arg.handle.data)
- output_data.append(np.zeros(output_shape, dtype=arg.handle.data.dtype))
- # Reduce on axis
- for i in range(input_data[0].size):
- # Recover input_index from i using input_shape
- input_index = np.unravel_index(i, input_shape)
- output_index = input_index[0:axis] + input_index[axis + 1:]
- input_tuple = tuple(self.to_tensor(d[input_index], input[ii].dtype) for ii, d in enumerate(input_data))
- if input_index[axis] == 0:
- # First element
- for j in range(len(output_data)):
- output_data[j][output_index] = input_tuple[j].handle.data.item()
- else:
- acc_tuple = tuple(self.to_tensor(o[output_index], input[oi].dtype) for oi, o in enumerate(output_data))
- combine_fn_ret = self.combine_fn.fn(*acc_tuple, *input_tuple)
- acc_tuple = (combine_fn_ret, ) if not isinstance(combine_fn_ret, tuple) else combine_fn_ret
- for j in range(len(output_data)):
- output_data[j][output_index] = acc_tuple[j].handle.data.item() if isinstance(
- acc_tuple[j], tl.core.tensor) else acc_tuple[j]
- # Pack output
- ret = []
- for i, data in enumerate(output_data):
- if self.keep_dims:
- if original_axis is not None:
- data = np.expand_dims(data, axis)
- else:
- for _ in range(len(input_shape)):
- data = np.expand_dims(data, 0)
- elif original_axis is None:
- # Take a scalar
- data = data.item()
- ret.append(self.to_tensor(data, input[i].dtype))
- return ret
- def min_max(self, input, val_reduce_op, idx_reduce_op=None):
- # If input is a tuple, it must be (val, index), and we only take val
- input = input[0] if isinstance(input, tuple) else input
- val = None
- idx = None
- if val_reduce_op:
- val = self.to_tensor(val_reduce_op(input.handle.data, axis=self.axis, keepdims=self.keep_dims), input.dtype)
- if idx_reduce_op:
- idx = self.to_tensor(idx_reduce_op(input.handle.data, axis=self.axis, keepdims=self.keep_dims), tl.int32)
- if val is not None and idx is not None:
- return val, idx
- elif val is not None:
- return val
- elif idx is not None:
- return idx
- else:
- raise ValueError("val_reduce_op and idx_reduce_op are both None")
- def sum(self, input):
- return self.to_tensor(np.sum(input.handle.data, axis=self.axis, keepdims=self.keep_dims), input.dtype)
- def apply_impl(self, input):
- if self.combine_fn == tl.standard._argmin_combine_tie_break_left:
- return self.min_max(input[0], val_reduce_op=np.min, idx_reduce_op=np.argmin)
- elif self.combine_fn == tl.standard._argmax_combine_tie_break_left:
- return self.min_max(input[0], val_reduce_op=np.max, idx_reduce_op=np.argmax)
- elif self.combine_fn == tl.standard._elementwise_max:
- return self.min_max(input[0], val_reduce_op=np.nanmax, idx_reduce_op=None)
- elif self.combine_fn == tl.standard._elementwise_min:
- return self.min_max(input[0], val_reduce_op=np.nanmin, idx_reduce_op=None)
- elif self.combine_fn == tl.standard._sum_combine:
- return self.sum(input[0])
- else:
- # Fall back to the slow mode
- return self.generic_reduce(input)
- class ScanOps(ReduceScanOpInterface):
- def __init__(self, axis, combine_fn, reverse):
- super().__init__(axis, combine_fn)
- self.reverse = reverse
- def cumsum(self, input):
- return [self.to_tensor(np.cumsum(input.handle.data, axis=self.axis), dtype=input.dtype)]
- def cumprod(self, input):
- return [self.to_tensor(np.cumprod(input.handle.data, axis=self.axis), dtype=input.dtype)]
- def generic_scan(self, input):
- input_data = []
- output_data = []
- shape = input[0].handle.data.shape
- for arg in input:
- input_data.append(arg.handle.data)
- output_data.append(np.zeros(shape, dtype=arg.handle.data.dtype))
- # Scan on axis
- for i in range(input_data[0].size):
- # Recover index from i using shape
- index = np.unravel_index(i, shape)
- data = tuple(self.to_tensor(d[index], input[ii].dtype) for ii, d in enumerate(input_data))
- if index[self.axis] == 0:
- # First element
- for j in range(len(output_data)):
- output_data[j][index] = data[j].handle.data.item()
- else:
- prev_index = tuple(index[i] - 1 if i == self.axis else index[i] for i in range(len(index)))
- acc_tuple = tuple(self.to_tensor(o[prev_index], input[oi].dtype) for oi, o in enumerate(output_data))
- combine_fn_ret = self.combine_fn.fn(*acc_tuple, *data)
- acc_tuple = (combine_fn_ret, ) if not isinstance(combine_fn_ret, tuple) else combine_fn_ret
- for j in range(len(output_data)):
- output_data[j][index] = acc_tuple[j].handle.data.item() if isinstance(
- acc_tuple[j], tl.core.tensor) else acc_tuple[j]
- # Pack output
- ret = []
- for i, data in enumerate(output_data):
- ret.append(self.to_tensor(data, input[i].dtype))
- return ret
- def apply_impl(self, input):
- new_input = []
- if self.reverse:
- for arg in input:
- new_input.append(self.to_tensor(np.flip(arg.handle.data, axis=self.axis), arg.dtype))
- else:
- new_input = input
- if self.combine_fn == tl.standard._sum_combine:
- ret = self.cumsum(new_input[0])
- elif self.combine_fn == tl.standard._prod_combine:
- ret = self.cumprod(new_input[0])
- else:
- # Fall back to the slow mode
- ret = self.generic_scan(new_input)
- if self.reverse:
- for arg in ret:
- arg.handle.data = np.flip(arg.handle.data, axis=self.axis)
- return ret
- def _patch_reduce_scan(scope: _LangPatchScope):
- # Because interpreter doesn't support region_builder_fn, we cannot patch the builder
- # to use the new reduce and scan functions.
- # Instead, we need to patch reduce and reduce functions in tl and tl.core
- def _new_reduce(input, axis, combine_fn, keep_dims=False, **kwargs):
- return ReduceOps(axis, combine_fn, keep_dims).apply(input)
- def _new_scan(input, axis, combine_fn, reverse=False, **kwargs):
- return ScanOps(axis, combine_fn, reverse).apply(input)
- scope.set_attr(tl, "reduce", _new_reduce)
- scope.set_attr(tl, "associative_scan", _new_scan)
- scope.set_attr(tl.core, "reduce", _new_reduce)
- scope.set_attr(tl.core, "associative_scan", _new_scan)
- def _patch_lang_core(lang, scope: _LangPatchScope):
- def _new_to_ir(self, builder):
- # We need to specify signedness for integer types in the numpy mode
- if self.name == 'void':
- return builder.get_void_ty()
- elif self.name == 'int1':
- return builder.get_int1_ty()
- elif self.name == 'int8':
- return builder.get_int8_ty()
- elif self.name == 'uint8':
- return builder.get_uint8_ty()
- elif self.name == 'int16':
- return builder.get_int16_ty()
- elif self.name == 'uint16':
- return builder.get_uint16_ty()
- elif self.name == 'int32':
- return builder.get_int32_ty()
- elif self.name == 'uint32':
- return builder.get_uint32_ty()
- elif self.name == 'int64':
- return builder.get_int64_ty()
- elif self.name == 'uint64':
- return builder.get_uint64_ty()
- elif self.name == 'fp8e5':
- return builder.get_fp8e5_ty()
- elif self.name == 'fp8e4nv':
- return builder.get_fp8e4nv_ty()
- elif self.name == 'fp8e4b15':
- return builder.get_fp8e4b15_ty()
- elif self.name == 'fp16':
- return builder.get_half_ty()
- elif self.name == 'bf16':
- return builder.get_bf16_ty()
- elif self.name == 'fp32':
- return builder.get_float_ty()
- elif self.name == 'fp64':
- return builder.get_double_ty()
- raise ValueError(f'fail to convert {self} to ir type')
- # can't just map lang.static_range to `range`, because `tl.static_range`
- # can get `step` passed by keyword
- def _new_range(arg1, arg2=None, step=None, **kwargs):
- if step is None:
- step = 1
- if arg2 is None:
- start, end = 0, arg1
- else:
- start, end = arg1, arg2
- return range(start, end, step)
- def _new_static_assert(cond, msg=""):
- assert cond, msg
- def _set_attr(input, values, name):
- # skip non tensor types. This may happen for induction variables.
- if not isinstance(input, tl.tensor):
- return input
- # Unwrap constexpr
- values = [values] if not isinstance(values, (list, tuple)) else values
- values = [v.value if isinstance(v, tl.constexpr) else v for v in values]
- if len(values) != max(1, len(input.shape)):
- raise ValueError(f"len(values) != len(input.shape) for {name}")
- input.handle.set_attr(name, values)
- return input
- scope.set_attr(lang, "range", _new_range)
- scope.set_attr(lang, "static_range", _new_range)
- scope.set_attr(lang, "static_assert", _new_static_assert)
- scope.set_attr(lang, "static_print", print)
- scope.set_attr(lang.dtype, "to_ir", _new_to_ir)
- scope.set_attr(lang, "multiple_of", partial(_set_attr, name="tt.divisibility"))
- scope.set_attr(lang, "max_contiguous", partial(_set_attr, name="tt.contiguity"))
- scope.set_attr(lang, "max_constancy", partial(_set_attr, name="tt.constancy"))
- _patch_reduce_scan(scope)
- def _patch_lang(fn):
- scope = _LangPatchScope()
- langs = [value for _, value in fn.__globals__.items() if inspect.ismodule(value) and value in [tl, tl.core]]
- assert len(langs) >= 1, "triton.language must be visible from within jit'd function"
- for lang in langs:
- _patch_builtin(lang, interpreter_builder, scope)
- _patch_builtin(lang.tensor, interpreter_builder, scope)
- if lang == tl:
- _patch_builtin(lang.math, interpreter_builder, scope)
- _patch_lang_tensor(lang.tensor, scope)
- _patch_lang_core(lang, scope)
- _patch_builtin(tl.core.tensor_descriptor_base, interpreter_builder, scope)
- return scope
- def _tuple_create(arg, contents):
- # NamedTuples and tuples have different construction semantics. NamedTuple
- # has a constructor that takes individual arguments, while tuple takes an
- # iterable. Both have type "tuple" making it difficult to distinguish
- # between them, but only NamedTuple has "_fields" and apparently this is how
- # everyone does the check.
- return type(arg)(*contents) if hasattr(arg, "_fields") else type(arg)(contents)
- # TODO: wrap everything in triton tensors
- def _implicit_cvt(arg):
- if isinstance(arg, int):
- ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg), None)
- dtype = np.int32
- if -2**31 <= arg < 2**31:
- dtype = np.int32
- elif 2**31 <= arg < 2**32:
- dtype = np.uint32
- elif -2**63 <= arg < 2**63:
- dtype = np.int64
- elif 2**63 <= arg < 2**64:
- dtype = np.uint64
- else:
- raise ValueError(f"Unsupported integer value {arg}")
- handle = TensorHandle(np.array([arg], dtype=dtype), ty)
- return tl.tensor(handle, ty)
- if hasattr(arg, "data_ptr"):
- ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg), None)
- handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty)
- return tl.tensor(handle, ty)
- elif isinstance(arg, tuple):
- return _tuple_create(arg, map(_implicit_cvt, arg))
- elif isinstance(arg, TensorDescriptor):
- strides = [_implicit_cvt(s) for s in arg.strides]
- assert arg.strides[-1] == 1
- strides[-1] = tl.constexpr(1)
- semantic = TritonSemantic(InterpreterBuilder())
- return semantic.make_tensor_descriptor(base=_implicit_cvt(arg.base),
- shape=[_implicit_cvt(s) for s in arg.shape], strides=strides,
- block_shape=[tl.constexpr(b)
- for b in arg.block_shape], padding_option=arg.padding)
- return arg
- interpreter_builder = InterpreterBuilder()
- interpreter_semantic = TritonSemantic(interpreter_builder)
- def _unwrap_tensor(t):
- if isinstance(t, triton.runtime.jit.TensorWrapper):
- return t.base
- return t
- def _rewrap_tensor(t, original_tensor):
- if isinstance(original_tensor, triton.runtime.jit.TensorWrapper):
- return triton.runtime.jit.TensorWrapper(t, original_tensor.dtype)
- return t
- class GridExecutor:
- def __init__(self, fn, arg_names, grid, pre_run_hooks=[]):
- from .jit import _normalize_ty # TODO: modularize
- self.fn = fn
- self.arg_names = arg_names
- self.grid = grid
- self.pre_run_hooks = pre_run_hooks
- __annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()}
- self.constexprs = [name for name in arg_names if __annotations__.get(name) == "constexpr"]
- def _init_args_hst(self, args_dev, kwargs):
- storages = {}
- def _to_cpu(arg):
- if isinstance(arg, tuple):
- return _tuple_create(arg, map(_to_cpu, arg))
- elif isinstance(arg, TensorDescriptor):
- return TensorDescriptor(
- _to_cpu(arg.base),
- arg.shape,
- arg.strides,
- arg.block_shape,
- arg.padding,
- )
- elif not hasattr(arg, "data_ptr"):
- return arg
- unwrapped_arg = _unwrap_tensor(arg)
- if unwrapped_arg.untyped_storage().data_ptr() not in storages:
- storage = unwrapped_arg.untyped_storage()
- storages[storage.data_ptr()] = storage.cpu()
- storage = storages[unwrapped_arg.untyped_storage().data_ptr()]
- cpu_arg = unwrapped_arg.new_empty(0, device='cpu')
- cpu_arg.set_(storage, unwrapped_arg.storage_offset(), unwrapped_arg.size(), unwrapped_arg.stride())
- cpu_arg = _rewrap_tensor(cpu_arg, original_tensor=arg)
- return cpu_arg
- args_hst = [_to_cpu(arg) for arg in args_dev]
- # Process keyword arguments
- kwargs_hst = {}
- for key, value in kwargs.items():
- kwargs_hst[key] = _to_cpu(value)
- return args_hst, kwargs_hst
- def _restore_args_dev(self, args_dev, args_hst, kwargs, kwargs_hst):
- storages = {}
- def _from_cpu(arg_dev, arg_hst):
- if hasattr(arg_dev, "data_ptr"):
- # No need to rewrap because this just modifies internal
- arg_dev, arg_hst = _unwrap_tensor(arg_dev), _unwrap_tensor(arg_hst)
- storages[arg_dev.untyped_storage().data_ptr()] = (arg_dev.untyped_storage(), arg_hst.untyped_storage())
- elif isinstance(arg_dev, tuple):
- for (arg_dev, arg_hst) in zip(arg_dev, arg_hst):
- _from_cpu(arg_dev, arg_hst)
- elif isinstance(arg_dev, TensorDescriptor):
- _from_cpu(arg_dev.base, arg_hst.base)
- for arg_dev, arg_hst in zip(args_dev, args_hst):
- _from_cpu(arg_dev, arg_hst)
- # Restore keyword arguments
- for key, kwarg_dev in kwargs.items():
- kwarg_hst = kwargs_hst[key]
- _from_cpu(kwarg_dev, kwarg_hst)
- for (arg_dev, arg_hst) in storages.values():
- arg_dev.copy_(arg_hst)
- def __call__(self, *args_dev, **kwargs):
- # Removes not used reserved keywords from kwargs
- # Triton doesn't support keyword-only, variable positional or variable keyword arguments
- # It's safe to inspect only positional or keyword arguments (i.e., argspec.args)
- argspec = inspect.getfullargspec(self.fn)
- kwargs = {k: v for k, v in kwargs.items() if k in argspec.args}
- # copy arguments to the host
- args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs)
- # run pre-run hooks
- for hook in self.pre_run_hooks:
- hook(*args_hst, **kwargs_hst)
- # remaps core language functions to interpreted ones
- patch_scope = _patch_lang(self.fn)
- try:
- # we need to copy arguments to the host for the interpreter
- # implicitly convert tensor arguments to their base pointers
- args = inspect.getcallargs(self.fn, *args_hst, **kwargs_hst)
- args = {name: arg if name in self.constexprs else _implicit_cvt(arg) for name, arg in args.items()}
- # iterate through grid
- grid = self.grid(args) if callable(self.grid) else self.grid
- assert len(grid) <= 3, "grid must have at most 3 dimensions"
- grid = grid + (1, ) * (3 - len(grid))
- interpreter_builder.set_grid_dim(*grid)
- try:
- for x in range(grid[0]):
- for y in range(grid[1]):
- for z in range(grid[2]):
- interpreter_builder.set_grid_idx(x, y, z)
- self.fn(**args)
- except Exception as e:
- if triton.knobs.compilation.front_end_debugging:
- raise
- raise InterpreterError(repr(e)) from e
- finally:
- patch_scope.restore()
- # copy arguments back to propagate side-effects
- self._restore_args_dev(args_dev, args_hst, kwargs, kwargs_hst)
- class ASTTransformer(ast.NodeTransformer):
- def visit_Assign(self, node):
- names = []
- for target in node.targets:
- names += [self.visit(target)]
- if len(names) > 1:
- raise ValueError("Multiple assignments are not supported")
- # Modify the assignment x = value to
- # interpreter_semantic.to_tensor(value, False)
- node.value = ast.Call(
- func=ast.Attribute(value=ast.Name(id="interpreter_semantic", ctx=ast.Load()), attr="to_tensor",
- ctx=ast.Load()), args=[node.value, ast.Constant(value=False)], keywords=[])
- return node
- class FunctionRewriter:
- ast_transformer = ASTTransformer()
- def __init__(self, fn, **kwargs):
- self.fn = fn
- self.kwargs = kwargs
- self.filename: str = ""
- # Absolute line number in the file
- self.def_file_lineno: int = 0
- def rewrite_ast(self):
- # If exception is raise, it means the function does not have source code available,
- # e.g., dynamically generated functions, we cannot rewrite it so just return the original function
- try:
- lines, _ = inspect.getsourcelines(self.fn)
- except Exception:
- return self.fn
- # truncate lines before def
- # @triton.autotune(...)
- # ...
- # @triton.jit
- # ...
- # def foo(...): <- this line is the function definition
- self.filename, self.def_file_lineno = self._get_jit_fn_file_line()
- self.def_lineno = self._find_def(lines)
- src = self._prepare_source(lines)
- transformed_ast = self._transform_ast(src)
- return self._compile_and_exec(transformed_ast)
- def _get_jit_fn_file_line(self):
- from .jit import get_jit_fn_file_line, JITFunction
- return get_jit_fn_file_line(JITFunction(self.fn))
- def _find_def(self, lines):
- def_lineno = 0
- # Line numbers start from 1
- for i, line in enumerate(lines):
- if line.strip().startswith("def "):
- def_lineno = i + 1
- return def_lineno
- def _prepare_source(self, lines):
- lines = lines[self.def_lineno - 1:]
- src = ''.join(lines)
- return textwrap.dedent(src)
- def _transform_ast(self, src):
- # src is like:
- # 1: def foo(...):
- # 2: ...
- parsed_ast = ast.parse(src)
- transformed_ast = self.ast_transformer.visit(parsed_ast)
- ast.fix_missing_locations(transformed_ast)
- inc_lineno = self.def_file_lineno - 1
- ast.increment_lineno(transformed_ast, inc_lineno)
- return transformed_ast
- def _compile_and_exec(self, transformed_ast):
- compiled_code = compile(transformed_ast, filename=self.filename, mode='exec')
- local_namespace = {**self.kwargs}
- fn_globals = self.fn.__globals__
- for key, value in globals().items():
- if key not in fn_globals:
- fn_globals[key] = value
- exec(compiled_code, fn_globals, local_namespace)
- return local_namespace[self.fn.__name__]
- class InterpretedFunction(KernelInterface[T]):
- # Cache all rewritten functions
- rewritten_fn: Dict[Callable, Callable] = {}
- def __init__(self, fn, **kwargs) -> None:
- self.fn = fn
- self.rewriter = FunctionRewriter(fn, **kwargs)
- self.kwargs = kwargs
- self.pre_run_hooks = []
- signature = inspect.signature(fn)
- self.arg_names = [v.name for v in signature.parameters.values()]
- def run(self, *args, grid, warmup, **kwargs):
- if warmup:
- return
- fn = self.rewrite()
- return GridExecutor(fn, self.arg_names, grid, self.pre_run_hooks)(*args, **kwargs)
- def add_pre_run_hook(self, hook):
- assert callable(hook)
- self.pre_run_hooks.append(hook)
- def rewrite(self):
- if self.fn not in self.rewritten_fn:
- self.rewritten_fn[self.fn] = self.rewriter.rewrite_ast()
- return self.rewritten_fn[self.fn]
- @property
- def __name__(self):
- return self.fn.__name__
- def __call__(self, *args, **kwargs):
- # This is a device function call
- _patch_lang(self.fn)
- fn = self.rewrite()
- try:
- return fn(*args, **kwargs)
- except Exception as e:
- raise InterpreterError(repr(e)) from e
|