interpreter.py 64 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492
  1. from __future__ import annotations
  2. import ast
  3. import textwrap
  4. import inspect
  5. from typing import Tuple, List, Dict, Callable, TypeVar
  6. import math
  7. import numpy as np
  8. import triton
  9. import triton.language as tl
  10. import dataclasses
  11. from dataclasses import dataclass
  12. from triton.language.semantic import TritonSemantic
  13. from triton.runtime.jit import KernelInterface
  14. from triton.tools.tensor_descriptor import TensorDescriptor
  15. from .errors import InterpreterError
  16. from functools import partial
  17. from .._C.libtriton import interpreter as _interpreter
  18. from .._C.libtriton import ir as _ir
# Generic type variable for annotations.
# NOTE(review): not referenced in the portion of this module reviewed here; confirm it is still needed.
T = TypeVar("T")
  20. @dataclass
  21. class TensorHandle:
  22. '''
  23. data: numpy array
  24. dtype: triton type, either pointer_type or scalar_type.
  25. we don't store block_type here because the shape information is already available in the data field
  26. attr: a dictionary of attributes
  27. '''
  28. data: np.array
  29. dtype: tl.dtype
  30. attr: Dict = dataclasses.field(default_factory=dict)
  31. def __post_init__(self):
  32. if not _validate_np_data_size(self.data, self.dtype):
  33. raise ValueError(f"numpy data itemsize ({self.data.itemsize * 8} bits) exceeds dtype primitive_bitwidth "
  34. f"({self.dtype.primitive_bitwidth} bits) for triton type {self.dtype}")
  35. def __bool__(self):
  36. return bool(self.data.all())
  37. def get_element_ty(self):
  38. dtype = self.dtype
  39. while hasattr(dtype, "element_ty"):
  40. dtype = dtype.element_ty
  41. return dtype
  42. def clone(self):
  43. return TensorHandle(self.data.copy(), self.dtype)
  44. def set_attr(self, key, value):
  45. self.attr[key] = value
  46. class BlockPointerHandle:
  47. def __init__(self, base, shape, strides, offsets, block_shape, order):
  48. self.base = base
  49. self.shape = shape
  50. self.strides = strides
  51. self.offsets = offsets
  52. self.block_shape = block_shape
  53. self.order = order
  54. def materialize_pointers(self, boundary_check):
  55. dtype_tt = self.base.get_element_ty()
  56. n_bytes = dtype_tt.primitive_bitwidth // 8
  57. ptrs = np.broadcast_to(self.base.data, self.block_shape)
  58. masks = np.ones(self.block_shape, dtype=bool)
  59. for dim in range(len(self.block_shape)):
  60. bcast_dims = [1] * len(self.block_shape)
  61. bcast_dims[dim] = self.block_shape[dim]
  62. off = (self.offsets[dim].data + np.arange(self.block_shape[dim])).reshape(bcast_dims)
  63. ptrs = ptrs + (n_bytes * off * self.strides[dim].data).astype(np.uint64)
  64. if dim in boundary_check:
  65. masks = masks & (off < self.shape[dim].data) & (off >= 0)
  66. ptrs = TensorHandle(ptrs, self.base.dtype.scalar)
  67. return ptrs, masks
  68. class TensorDescHandle:
  69. def __init__(self, base: TensorHandle, shape: List[TensorHandle], strides: List[TensorHandle],
  70. block_shape: List[int], padding):
  71. self.base = base
  72. self.ndim = len(shape)
  73. self.shape = shape
  74. self.strides = strides
  75. self.block_shape = block_shape
  76. self.padding = padding
  77. def validate(self):
  78. assert self.base.data.item() % 16 == 0, "base must be 16-byte aligned"
  79. assert len(self.strides) == self.ndim
  80. assert len(self.block_shape) == self.ndim
  81. assert self.ndim >= 1, "descriptor cannot be 0 dimensional"
  82. scalar_ty = self.base.dtype.element_ty
  83. itemsize = scalar_ty.primitive_bitwidth // 8
  84. for stride in self.strides[:-1]:
  85. byte_stride = stride.data.item() * itemsize
  86. assert byte_stride % 16 == 0, "stride must be 16-byte aligned"
  87. assert self.strides[-1].data.item() == 1, "last dim must be contiguous"
  88. def materialize_pointers(self, offsets: List[TensorHandle]):
  89. assert len(offsets) == self.ndim
  90. scalar_ty = self.base.dtype.element_ty
  91. itemsize = scalar_ty.primitive_bitwidth // 8
  92. assert (offsets[-1].data * itemsize) % 16 == 0, "block offset start must be 16-byte aligned"
  93. ptrs = np.broadcast_to(self.base.data, self.block_shape)
  94. masks = np.ones(self.block_shape, dtype=bool)
  95. for dim in range(len(self.block_shape)):
  96. bcast_dims = [1] * len(self.block_shape)
  97. bcast_dims[dim] = self.block_shape[dim]
  98. off = (offsets[dim].data + np.arange(self.block_shape[dim])).reshape(bcast_dims)
  99. ptrs = ptrs + (itemsize * off * self.strides[dim].data).astype(np.uint64)
  100. masks = masks & (0 <= off) & (off < self.shape[dim].data)
  101. assert ptrs.dtype == np.uint64
  102. ptrs = TensorHandle(ptrs, self.base.dtype.scalar)
  103. return ptrs, masks
@dataclass(frozen=True)
class InterpreterOptions:
    """Immutable option set used when kernels are executed by the interpreter.

    NOTE(review): the fields appear to mirror the options exposed by the real
    compiler backends; confirm against those option classes before changing them.
    """
    extern_libs: dict = None  # None: no external libraries configured
    debug: bool = False
    sanitize_overflow: bool = True
    arch: str = None  # None: no target architecture
    # Variable-length tuples of dtype / precision names.
    supported_fp8_dtypes: Tuple[str, ...] = ("fp8e5", "fp8e5b16", "fp8e4nv", "fp8e4b8", "fp8e4b15")
    deprecated_fp8_dot_operand_dtypes: Tuple[str, ...] = ()
    default_dot_input_precision: str = "tf32"
    allowed_dot_input_precisions: Tuple[str, ...] = ("tf32", "tf32x3", "ieee")
    max_num_imprecise_acc_default: int = 0
    backend_name: str = "interpreter"
  116. def _validate_np_data_size(np_array, tl_dtype):
  117. if isinstance(tl_dtype, tl.pointer_type):
  118. return True
  119. np_dtype_bitwidth = np_array.itemsize * 8
  120. tl_dtype_bitwidth = tl_dtype.primitive_bitwidth
  121. # numpy lowest itemsize is at least 8 bits
  122. if tl_dtype_bitwidth < 8:
  123. tl_dtype_bitwidth = 8
  124. if np_dtype_bitwidth > tl_dtype_bitwidth:
  125. return False
  126. return True
  127. def _get_signed_np_dtype(dtype):
  128. if dtype == np.uint8:
  129. return np.int8
  130. if dtype == np.uint16:
  131. return np.int16
  132. if dtype == np.uint32:
  133. return np.int32
  134. if dtype == np.uint64:
  135. return np.int64
  136. return dtype
  137. def _get_np_dtype(tt_dtype):
  138. if isinstance(tt_dtype, tl.pointer_type):
  139. return np.dtype(np.uint64)
  140. np_types = {
  141. tl.int1: np.dtype(bool),
  142. tl.float16: np.dtype(np.float16),
  143. tl.float32: np.dtype(np.float32),
  144. tl.float64: np.dtype(np.float64),
  145. tl.int8: np.dtype(np.int8),
  146. tl.uint8: np.dtype(np.uint8),
  147. tl.int16: np.dtype(np.int16),
  148. tl.uint16: np.dtype(np.uint16),
  149. tl.int32: np.dtype(np.int32),
  150. tl.uint32: np.dtype(np.uint32),
  151. tl.int64: np.dtype(np.int64),
  152. tl.uint64: np.dtype(np.uint64),
  153. # bfloat16 types are stored as uint16
  154. tl.bfloat16: np.dtype(np.uint16),
  155. # float8 types are stored as uint8
  156. tl.float8e5: np.dtype(np.uint8),
  157. tl.float8e5b16: np.dtype(np.uint8),
  158. tl.float8e4nv: np.dtype(np.uint8),
  159. tl.float8e4b8: np.dtype(np.uint8),
  160. tl.float8e4b15: np.dtype(np.uint8),
  161. }
  162. if isinstance(tt_dtype, tl.block_type):
  163. if isinstance(tt_dtype.element_ty, tl.pointer_type):
  164. return np.dtype(np.uint64)
  165. return np_types[tt_dtype.element_ty]
  166. return np_types[tt_dtype]
def _convert_float(input, input_dtype, output_dtype, rounding_mode):
    """Convert floating-point bit patterns between two Triton float formats.

    The conversion works field by field (sign, exponent, mantissa) on the raw
    bits, so it also handles formats numpy cannot represent (bfloat16, fp8).

    Args:
        input: numpy array whose bytes hold ``input_dtype`` values; it is
            reinterpreted as unsigned integers of ``input_dtype``'s width.
        input_dtype, output_dtype: Triton float types; ``primitive_bitwidth``,
            ``fp_mantissa_width`` and ``exponent_bias`` describe their layouts.
        rounding_mode: only ``_ir.ROUNDING_MODE.RTNE`` changes behaviour, and only
            when the output is narrower than the input; otherwise the mantissa
            is truncated.

    Returns:
        Unsigned-integer array of ``output_dtype``'s width holding the output bit
        patterns, reshaped to ``input.shape`` (callers ``.view`` it as the
        destination numpy dtype).
    """
    input_uint_dtype = getattr(np, f"uint{input_dtype.primitive_bitwidth}")
    output_unint_dtype = getattr(np, f"uint{output_dtype.primitive_bitwidth}")
    # Unpack the sign, mantissa and biased exponent fields of the input.
    input_bin = np.frombuffer(input.tobytes(), dtype=input_uint_dtype)
    sign = (input_bin >> (input_dtype.primitive_bitwidth - 1)) & 0x01
    input_exponent_width = input_dtype.primitive_bitwidth - input_dtype.fp_mantissa_width - 1
    output_exponent_width = output_dtype.primitive_bitwidth - output_dtype.fp_mantissa_width - 1
    significand = input_bin & ((1 << input_dtype.fp_mantissa_width) - 1)
    bias_input = input_dtype.exponent_bias
    bias_output = output_dtype.exponent_bias
    exponent = ((input_bin >> input_dtype.fp_mantissa_width) & ((1 << input_exponent_width) - 1)).astype(np.int32)
    subnormal_index = exponent == 0
    if np.any(subnormal_index):
        # Credit to Phil: phil@openai.com
        # subnormal repr: ((-1.0)**sign) * (2.0**(1 - exp_bias)) * (2^(m0) + 2^(m1) + ... + 2^(mn))
        # where m0, m1, ..., mn are the 1-bit of the mantissa
        # convert it to normal repr: ((-1.0)**sign) * (2.0**(1 + m0 - exp_bias)) * (1 + 2^(m1 - m0) + ... + 2^(mn - m0))
        bit_pos = np.zeros_like(input_bin, dtype=np.int32)
        # Find the most significant bit of the mantissa in the significand
        for i in range(input_dtype.fp_mantissa_width):
            bit_index = ((significand >> i) & 0x01)
            # pos should be >= 1
            bit_pos[bit_index == 1] = input_dtype.fp_mantissa_width - i
        zero_significand_index = significand == 0
        exponent[subnormal_index] = 1 - bit_pos[subnormal_index]
        # 0 significand and subnormal should be treated as 0
        exponent[zero_significand_index & subnormal_index] = bias_input - bias_output
        significand[subnormal_index] = (significand[subnormal_index] << bit_pos[subnormal_index]) & (
            (1 << input_dtype.fp_mantissa_width) - 1)
    # Prevent overflow and underflow
    # Rebias and clamp into the output exponent field. Too-large values saturate to the
    # all-ones exponent (Inf or NaN, depending on the mantissa bits kept below); too-small
    # values clamp to 0 and are rebuilt as output subnormals further down.
    exponent_output = np.maximum(0, np.minimum((exponent - bias_input + bias_output), (1 << output_exponent_width) - 1))
    exponent_output = exponent_output.astype(output_unint_dtype)
    sign_output = sign.astype(output_unint_dtype)
    if input_dtype.primitive_bitwidth > output_dtype.primitive_bitwidth:  # Downcast
        significand_output = (significand >> (input_dtype.fp_mantissa_width - output_dtype.fp_mantissa_width)) & (
            (1 << output_dtype.fp_mantissa_width) - 1)
        if rounding_mode == _ir.ROUNDING_MODE.RTNE:  # Round to nearest even
            # find the cut-off bit
            # NOTE(review): adding the highest discarded bit rounds ties away from zero
            # (in magnitude), not to even. A carry out of the mantissa
            # (significand_output == 1 << fp_mantissa_width) is merged with `|` when the
            # output is assembled, so it only reaches the exponent if the exponent's
            # lowest bit is 0. Confirm whether exact RTNE is required here.
            cut_off = significand & (1 << (input_dtype.fp_mantissa_width - output_dtype.fp_mantissa_width - 1))
            significand_output = significand_output + (cut_off > 0)
        significand_output = significand_output.astype(output_unint_dtype)
    else:  # Upcast
        # NOTE(review): equal-width conversions also land here; if the output mantissa is
        # narrower than the input's, the shift count below is negative. Confirm such
        # conversions never reach this function.
        significand_output = (significand.astype(output_unint_dtype) <<
                              (output_dtype.fp_mantissa_width - input_dtype.fp_mantissa_width)) & (
                                  (1 << output_dtype.fp_mantissa_width) - 1)
    subnormal_index = exponent_output == 0
    if np.any(subnormal_index):  # underflow
        # normal repr: ((-1.0)**sign) * (2.0**(exp - exp_bias_input)) * (1 + 2^(m0) + 2^(m1) + ... + 2^(mn))
        # where m0, m1, ..., mn are the 1-bit of the mantissa
        # shift = (1 - exp_bias_output) - (exp - exp_bias_input)
        # convert it to subnormal repr: ((-1.0)**sign) * (2.0**(1 - exp_bias_output)) * (2^(-shift) + 2^(m0 - shift) + 2^(m1 - shift) + ... + 2^(mn - shift))
        exponent = ((input_bin >> input_dtype.fp_mantissa_width) & ((1 << input_exponent_width) - 1)).astype(np.int32)
        non_zero_exponent_index = exponent != 0
        # If the original exponent is not zero, we still need to shift the significand and consider the 1.0 part in mantissa
        subnormal_index = subnormal_index & non_zero_exponent_index
        shift = np.zeros_like(input_bin, dtype=np.int32)
        shift[subnormal_index] = (1 - bias_output) - (exponent[subnormal_index] - bias_input)
        significand_output[subnormal_index] = (significand_output[subnormal_index] >> shift[subnormal_index]) | (
            1 << (output_dtype.fp_mantissa_width - shift[subnormal_index]))
    # Reassemble sign | exponent | mantissa in the output layout.
    output = (sign_output << (output_dtype.primitive_bitwidth - 1)) | (
        exponent_output << output_dtype.fp_mantissa_width) | significand_output
    return output.reshape(input.shape)
  229. def _erf(x):
  230. # Numpy does not support erf
  231. return math.erf(x)
  232. def _umulhi_64(a, b):
  233. # Numpy does not support 128-bit multiplication
  234. # So we have to implement it manually
  235. return (int(a) * int(b)) >> 64
# Elementwise wrappers around the scalar helpers above; `otypes` pins each
# wrapper's output dtype.
np_erf_fp32 = np.vectorize(_erf, otypes=[np.float32])
np_erf_fp64 = np.vectorize(_erf, otypes=[np.float64])
np_umulhi_u64 = np.vectorize(_umulhi_64, otypes=[np.uint64])
  239. class ExtraFunctions:
  240. @staticmethod
  241. def _convert_custom_types(input, dst_ty, fp_downcast_rounding, _semantic):
  242. return tl.tensor(_semantic.builder.create_fp_to_fp(input.handle, dst_ty, fp_downcast_rounding), dst_ty)
  243. class InterpreterBuilder:
    # Translate IR memory-semantic enum values into the native interpreter's equivalents.
    ir_sem_to_interpreter_sem = {
        _ir.MEM_SEMANTIC.ACQUIRE: _interpreter.MEM_SEMANTIC.ACQUIRE,
        _ir.MEM_SEMANTIC.RELEASE: _interpreter.MEM_SEMANTIC.RELEASE,
        _ir.MEM_SEMANTIC.RELAXED: _interpreter.MEM_SEMANTIC.RELAXED,
        _ir.MEM_SEMANTIC.ACQUIRE_RELEASE: _interpreter.MEM_SEMANTIC.ACQUIRE_RELEASE,
    }
    # Translate IR atomic read-modify-write opcodes into the native interpreter's RMW opcodes.
    ir_rmw_op_to_interpreter_rmw_op = {
        _ir.ATOMIC_OP.ADD: _interpreter.RMW_OP.ADD,
        _ir.ATOMIC_OP.FADD: _interpreter.RMW_OP.FADD,
        _ir.ATOMIC_OP.MIN: _interpreter.RMW_OP.MIN,
        _ir.ATOMIC_OP.UMIN: _interpreter.RMW_OP.UMIN,
        _ir.ATOMIC_OP.MAX: _interpreter.RMW_OP.MAX,
        _ir.ATOMIC_OP.UMAX: _interpreter.RMW_OP.UMAX,
        _ir.ATOMIC_OP.AND: _interpreter.RMW_OP.AND,
        _ir.ATOMIC_OP.OR: _interpreter.RMW_OP.OR,
        _ir.ATOMIC_OP.XOR: _interpreter.RMW_OP.XOR,
        _ir.ATOMIC_OP.XCHG: _interpreter.RMW_OP.XCHG,
    }
  257. _ir.ATOMIC_OP.AND: _interpreter.RMW_OP.AND,
  258. _ir.ATOMIC_OP.OR: _interpreter.RMW_OP.OR,
  259. _ir.ATOMIC_OP.XOR: _interpreter.RMW_OP.XOR,
  260. _ir.ATOMIC_OP.XCHG: _interpreter.RMW_OP.XCHG,
  261. }
  262. def __init__(self) -> None:
  263. self.arch = None
  264. self.options = InterpreterOptions()
  265. self.codegen_fns = {}
  266. self.codegen_fns["convert_custom_types"] = ExtraFunctions._convert_custom_types
  267. self.codegen_fns["min_dot_size"] = lambda lhsType, rhsType: (1, 1, 1)
  268. def set_grid_idx(self, x, y, z):
  269. if not x < self.grid_dim[0]:
  270. raise ValueError("x >= grid_dim[0]")
  271. if not y < self.grid_dim[1]:
  272. raise ValueError("y >= grid_dim[1]")
  273. if not z < self.grid_dim[2]:
  274. raise ValueError("z >= grid_dim[2]")
  275. self.grid_idx = (x, y, z)
  276. def set_grid_dim(self, nx, ny, nz):
  277. self.grid_dim = (nx, ny, nz)
  278. # constants
  279. def get_half_ty(self):
  280. return tl.float16
  281. def get_bf16_ty(self):
  282. return tl.bfloat16
  283. def get_float_ty(self):
  284. return tl.float32
  285. def get_double_ty(self):
  286. return tl.float64
  287. def get_int1_ty(self):
  288. return tl.int1
  289. def get_int8_ty(self):
  290. return tl.int8
  291. def get_uint8_ty(self):
  292. return tl.uint8
  293. def get_int16_ty(self):
  294. return tl.int16
  295. def get_uint16_ty(self):
  296. return tl.uint16
  297. def get_int32_ty(self):
  298. return tl.int32
  299. def get_uint32_ty(self):
  300. return tl.uint32
  301. def get_int64_ty(self):
  302. return tl.int64
  303. def get_uint64_ty(self):
  304. return tl.uint64
  305. def get_fp8e4nv_ty(self):
  306. return tl.float8e4nv
  307. def get_fp8e4b15_ty(self):
  308. return tl.float8e4b15
  309. def get_fp8e4b8_ty(self):
  310. return tl.float8e4b8
  311. def get_fp8e5_ty(self):
  312. return tl.float8e5
  313. def get_fp8e5b16_ty(self):
  314. return tl.float8e5b16
  315. def get_ptr_ty(self, elt_ty, addr_space):
  316. return tl.pointer_type(elt_ty, addr_space)
  317. def get_block_ty(self, dtype, shape):
  318. return tl.block_type(dtype, shape)
  319. def get_int1(self, value):
  320. return TensorHandle(np.array([value], dtype=np.bool_), tl.int1)
  321. def get_uint8(self, value):
  322. return TensorHandle(np.array([value], dtype=np.uint8), tl.uint8)
  323. def get_int8(self, value):
  324. return TensorHandle(np.array([value], dtype=np.int8), tl.int8)
  325. def get_uint16(self, value):
  326. return TensorHandle(np.array([value], dtype=np.uint16), tl.uint16)
  327. def get_int16(self, value):
  328. return TensorHandle(np.array([value], dtype=np.int16), tl.int16)
  329. def get_uint32(self, value):
  330. return TensorHandle(np.array([value], dtype=np.uint32), tl.uint32)
  331. def get_int32(self, value):
  332. return TensorHandle(np.array([value], dtype=np.int32), tl.int32)
  333. def get_uint64(self, value):
  334. return TensorHandle(np.array([value], dtype=np.uint64), tl.uint64)
  335. def get_int64(self, value):
  336. return TensorHandle(np.array([value], dtype=np.int64), tl.int64)
  337. def get_fp16(self, value):
  338. return TensorHandle(np.array([value], dtype=np.float16), tl.float16)
  339. def get_fp32(self, value):
  340. return TensorHandle(np.array([value], dtype=np.float32), tl.float32)
  341. def get_fp64(self, value):
  342. return TensorHandle(np.array([value], dtype=np.float64), tl.float64)
  343. def get_null_value(self, type):
  344. return TensorHandle(np.array([0], dtype=_get_np_dtype(type)), type)
  345. # programming model
  346. def create_get_program_id(self, axis):
  347. if self.grid_idx is None:
  348. raise ValueError("grid_idx is None")
  349. return TensorHandle(np.array([self.grid_idx[axis]], dtype=np.int32), tl.int32)
  350. def create_get_num_programs(self, axis):
  351. return TensorHandle(np.array([self.grid_dim[axis]], dtype=np.int32), tl.int32)
  352. # memory ops
  353. def create_load(self, ptr, _0, _1, is_volatile):
  354. mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1)
  355. other = None
  356. return self.create_masked_load(ptr, mask, other, _0, _1, is_volatile)
  357. def create_store(self, ptr, val, _0, _1):
  358. mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1)
  359. return self.create_masked_store(ptr, val, mask, None, None)
  360. def create_masked_load(self, ptrs, mask, other, cache_modifier, eviction_policy, is_volatile):
  361. dtype_tt = ptrs.get_element_ty()
  362. dtype_np = _get_np_dtype(dtype_tt)
  363. if other is None:
  364. other = TensorHandle(np.zeros_like(ptrs.data, dtype=dtype_np), dtype_tt)
  365. ret = _interpreter.load(ptrs.data, mask.data, other.data, dtype_np)
  366. return TensorHandle(ret, dtype_tt)
  367. def create_masked_store(self, ptrs, value, mask, cache_modifier, eviction_policy):
  368. return _interpreter.store(ptrs.data, value.data, mask.data)
  369. # casting ops
  370. def cast_impl(self, src, dst_type):
  371. src_element_type = src.dtype.scalar
  372. dst_element_type = dst_type.scalar
  373. if (src_element_type == tl.bfloat16 and dst_element_type == tl.float32) or \
  374. (src_element_type == tl.float32 and dst_element_type == tl.bfloat16):
  375. data = _convert_float(src.data, src_element_type, dst_element_type, None).view(_get_np_dtype(dst_type))
  376. return TensorHandle(data, dst_type.scalar)
  377. else:
  378. return TensorHandle(src.data.astype(_get_np_dtype(dst_type)), dst_type.scalar)
    # Every numeric conversion funnels into cast_impl; the source signedness comes from
    # the numpy dtype of the source data. NOTE(review): create_int_cast ignores `is_signed`.
    create_si_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type)
    create_ui_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type)
    create_fp_to_si = lambda self, src, dst_type: self.cast_impl(src, dst_type)
    create_fp_to_ui = lambda self, src, dst_type: self.cast_impl(src, dst_type)
    create_fp_ext = lambda self, src, dst_type: self.cast_impl(src, dst_type)
    create_fp_trunc = lambda self, src, dst_type: self.cast_impl(src, dst_type)
    create_int_cast = lambda self, src, dst_type, is_signed: self.cast_impl(src, dst_type)
  386. def create_fp_to_fp(self, src, dst_type, rounding_mode):
  387. src_element_type = src.dtype.scalar
  388. dst_element_type = dst_type.scalar
  389. data = _convert_float(src.data, src_element_type, dst_element_type, rounding_mode).view(_get_np_dtype(dst_type))
  390. return TensorHandle(data, dst_type.scalar)
  391. def create_bitcast(self, src, dst_type):
  392. return TensorHandle(src.data.view(_get_np_dtype(dst_type)), dst_type.scalar)
  393. # binary operators
  394. def binary_op(self, lhs, rhs, op):
  395. output = op(lhs.data, rhs.data)
  396. tl_dtype = lhs.dtype.scalar
  397. if not _validate_np_data_size(output, tl_dtype):
  398. output = output.astype(_get_np_dtype(tl_dtype))
  399. return TensorHandle(output, tl_dtype)
  400. create_fadd = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add)
  401. create_fmul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply)
  402. create_fdiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide)
  403. create_frem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod)
  404. create_fsub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract)
  405. create_mul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply)
  406. create_precise_divf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide)
  407. create_sdiv = lambda self, lhs, rhs: self.create_idiv(lhs, rhs)
  408. create_udiv = lambda self, lhs, rhs: self.create_idiv(lhs, rhs)
  409. # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders.
  410. create_srem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod)
  411. create_urem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod)
  412. create_add = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add)
  413. create_sub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract)
  414. create_shl = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.left_shift)
  415. create_lshr = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.right_shift)
  416. create_minsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum)
  417. create_minui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum)
  418. create_minimumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum)
  419. create_minnumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum)
  420. create_maxsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum)
  421. create_maxui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum)
  422. create_maximumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum)
  423. create_maxnumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum)
  424. create_icmpSLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
  425. create_icmpSLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
  426. create_icmpSGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
  427. create_icmpSGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
  428. create_icmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
  429. create_icmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
  430. create_icmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
  431. create_icmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
  432. create_icmpEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal)
  433. create_icmpNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal)
  434. create_fcmpOLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
  435. create_fcmpOGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
  436. create_fcmpOLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
  437. create_fcmpOGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
  438. create_fcmpOEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal)
  439. create_fcmpONE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal)
  440. create_fcmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
  441. create_fcmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
  442. create_fcmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
  443. create_fcmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
  444. create_fcmpUEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal)
  445. create_fcmpUNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal)
  446. create_and = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_and)
  447. create_xor = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_xor)
  448. create_or = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_or)
  449. create_int_to_ptr = create_bitcast
  450. create_ptr_to_int = create_bitcast
  451. def create_idiv(self, lhs, rhs):
  452. # Triton has IEEE, not numpy/torch, semantics for %, and those carry
  453. # through to //, so we have to use a nonstandard expression to get a
  454. # reference result for //.
  455. return TensorHandle((lhs.data - np.fmod(lhs.data, rhs.data)) // rhs.data, lhs.dtype.scalar)
  456. def create_ashr(self, lhs, rhs):
  457. # Triton's rshift operator depends on the signedness of the left operand
  458. lhs_dtype = _get_signed_np_dtype(lhs.data.dtype)
  459. rhs_dtype = _get_signed_np_dtype(rhs.data.dtype)
  460. lhs.data = lhs.data.astype(lhs_dtype)
  461. rhs.data = rhs.data.astype(rhs_dtype)
  462. return self.binary_op(lhs, rhs, np.right_shift)
  463. def create_umulhi(self, lhs, rhs):
  464. dtype = lhs.data.dtype
  465. if dtype == np.int64 or dtype == np.uint64:
  466. return TensorHandle(np_umulhi_u64(lhs.data, rhs.data), lhs.dtype.scalar)
  467. else:
  468. compute_dtype = getattr(np, f"uint{dtype.itemsize * 8 * 2}")
  469. lhs_data = lhs.data.astype(compute_dtype)
  470. rhs_data = rhs.data.astype(compute_dtype)
  471. ret_data = np.multiply(lhs_data, rhs_data) >> (dtype.itemsize * 8)
  472. return TensorHandle(ret_data.astype(dtype), lhs.dtype.scalar)
  473. # ternary functions
  474. def ternary_op(self, lhs, rhs, other, op):
  475. output = op(lhs.data, rhs.data, other.data)
  476. tl_dtype = other.dtype.scalar
  477. if not _validate_np_data_size(output, tl_dtype):
  478. output = output.astype(_get_np_dtype(tl_dtype))
  479. return TensorHandle(output, tl_dtype)
    # clamp(arg, lo, hi) via np.clip; the result takes hi's scalar type (see ternary_op).
    # NOTE(review): `propagate_nans` is accepted but ignored.
    create_clampf = lambda self, arg, lo, hi, propagate_nans: self.ternary_op(arg, lo, hi, np.clip)
    # select(cond, lhs, rhs) via np.where; the result takes rhs's scalar type.
    create_select = lambda self, cond, lhs, rhs: self.ternary_op(cond, lhs, rhs, np.where)
  482. def create_fma(self, x, y, z):
  483. return TensorHandle(x.data * y.data + z.data, z.dtype.scalar)
  484. # unary functions
  485. def unary_op(self, arg, op):
  486. return TensorHandle(op(arg.data), arg.dtype.scalar)
  487. def create_fabs(self, arg):
  488. # Mask out the sign bit based on the primitive length
  489. dtype_tt = arg.dtype
  490. mask_bitwidth = dtype_tt.primitive_bitwidth - 1
  491. np_uint_dtype = getattr(np, f"uint{dtype_tt.primitive_bitwidth}")
  492. data = arg.data.view(np_uint_dtype)
  493. mask = (1 << mask_bitwidth) - 1
  494. ret = (data & mask).view(_get_np_dtype(dtype_tt))
  495. return TensorHandle(ret, arg.dtype.scalar)
    # Elementwise numpy ufuncs; each result keeps the argument's scalar type (via unary_op).
    create_cos = lambda self, arg: self.unary_op(arg, np.cos)
    create_exp = lambda self, arg: self.unary_op(arg, np.exp)
    create_exp2 = lambda self, arg: self.unary_op(arg, np.exp2)
    create_iabs = lambda self, arg: self.unary_op(arg, np.abs)
    create_floor = lambda self, arg: self.unary_op(arg, np.floor)
    create_ceil = lambda self, arg: self.unary_op(arg, np.ceil)
    create_log = lambda self, arg: self.unary_op(arg, np.log)
    create_log2 = lambda self, arg: self.unary_op(arg, np.log2)
    # Both sqrt variants map to np.sqrt; no precision distinction is made here.
    create_precise_sqrt = lambda self, arg: self.unary_op(arg, np.sqrt)
    create_sqrt = lambda self, arg: self.unary_op(arg, np.sqrt)
    create_sin = lambda self, arg: self.unary_op(arg, np.sin)
  507. def create_erf(self, arg):
  508. ret = np_erf_fp32(arg.data) if arg.data.dtype == np.float32 else np_erf_fp64(arg.data)
  509. return TensorHandle(ret, arg.dtype.scalar)
  510. def create_rsqrt(self, arg):
  511. return TensorHandle(1 / np.sqrt(arg.data), arg.dtype.scalar)
  512. # tensor operators
  513. create_reshape = lambda self, arg, shape, allow_reorder: TensorHandle(arg.data.reshape(shape), arg.dtype.scalar)
  514. def create_trans(self, arg, perm):
  515. return TensorHandle(np.transpose(arg.data, perm), arg.dtype.scalar)
  516. def create_dot(self, a, b, d, input_precision, max_num_imprecise_acc):
  517. a_data = a.data
  518. b_data = b.data
  519. if (a.dtype.primitive_bitwidth == 8 and a.dtype.is_floating()) or \
  520. (b.dtype.primitive_bitwidth == 8 and b.dtype.is_floating()):
  521. a_data = _convert_float(a_data, a.dtype, tl.float16, None).view(np.float16)
  522. b_data = _convert_float(b_data, b.dtype, tl.float16, None).view(np.float16)
  523. return TensorHandle(np.matmul(a_data, b_data, dtype=d.data.dtype) + d.data, d.dtype.scalar)
  524. def create_make_range(self, ret_ty, start, stop):
  525. return TensorHandle(np.arange(start, stop, dtype=np.int32), tl.int32)
  526. def create_histogram(self, data, bins, mask):
  527. if mask is None:
  528. mask = TensorHandle(np.ones_like(data.data, dtype=bool), tl.int1)
  529. # By default np.histogram returns int64 dtype values
  530. # Docs specify that returned dtype is taken based on optional weights.dtype
  531. # This is fix for interpreter cases where for example int32 tensor is being passed
  532. # But unexpectedly int64 values are being returned causing
  533. # tl.store to write 8 bytes instead of 4 bytes which lead to silent data corruption
  534. dummy_weights = np.ones_like(data.data, dtype=data.data.dtype)
  535. # force all masked elements to zero
  536. data = np.where(mask.data, data.data, np.zeros_like(data.data))
  537. histogram = np.histogram(data, bins=bins, range=(0, bins), weights=dummy_weights)[0]
  538. # remove overcounted elements
  539. histogram[0] -= np.logical_not(mask.data).sum()
  540. return TensorHandle(histogram, tl.int32)
  541. def create_gather(self, src, indices, axis):
  542. return TensorHandle(np.take_along_axis(src.data, indices.data, axis=axis), src.dtype.scalar)
  543. # pointer arithmetic
  544. def create_addptr(self, ptr, offset):
  545. dtype_tt = ptr.get_element_ty()
  546. element_bitwidth = dtype_tt.primitive_bitwidth
  547. # int1's bitwidth is 1, but we need to use 8 for pointer arithmetic
  548. element_bytewidth = max(1, element_bitwidth // 8)
  549. return TensorHandle(ptr.data + element_bytewidth * offset.data.astype(np.uint64), ptr.dtype)
  550. def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy,
  551. is_volatile):
  552. ptrs, masks = ptr.materialize_pointers(boundary_check)
  553. dtype_tt = ptrs.get_element_ty()
  554. dtype_np = _get_np_dtype(dtype_tt)
  555. if padding_option is None:
  556. other = None
  557. elif padding_option == _ir.PADDING_OPTION.PAD_ZERO:
  558. other = TensorHandle(np.zeros_like(ptrs.data, dtype=dtype_np), dtype_tt)
  559. elif padding_option == _ir.PADDING_OPTION.PAD_NAN:
  560. other = TensorHandle(np.full_like(ptrs.data, float('nan'), dtype=dtype_np), dtype_tt)
  561. else:
  562. raise ValueError(f"unsupported padding option {padding_option}")
  563. return self.create_masked_load(ptrs, masks, other, cache_modifier, eviction_policy, is_volatile)
  564. def create_tensor_pointer_store(self, ptr, value, boundary_check, cache_modifier, eviction_policy):
  565. ptrs, masks = ptr.materialize_pointers(boundary_check)
  566. return self.create_masked_store(ptrs, value, masks, cache_modifier, eviction_policy)
  567. def create_expand_dims(self, arg, axis):
  568. return TensorHandle(np.expand_dims(arg.data, axis), arg.dtype.scalar)
  569. def create_broadcast(self, arg, shape):
  570. return TensorHandle(np.broadcast_to(arg.data, shape), arg.dtype.scalar)
  571. def create_cat(self, lhs, rhs):
  572. return TensorHandle(np.concatenate([lhs.data, rhs.data]), lhs.dtype.scalar)
  573. def create_join(self, lhs, rhs):
  574. # Triton only supports joining two original tensors into a new one along the last axis
  575. return TensorHandle(np.stack([lhs.data, rhs.data], axis=-1), lhs.dtype.scalar)
  576. def create_split(self, val):
  577. # Triton only supports splitting the original tensor into two along the last axis
  578. return (TensorHandle(val.data[..., 0], val.dtype.scalar), TensorHandle(val.data[..., 1], val.dtype.scalar))
  579. def create_splat(self, ret_ty, arg):
  580. shape = ret_ty.shape
  581. if isinstance(arg.dtype, tl.block_type):
  582. return TensorHandle(np.full(shape, arg.data[0], dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar)
  583. else: # scalar
  584. return TensorHandle(np.full(shape, arg.data, dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar)
  585. def create_unsplat(self, arg):
  586. return TensorHandle(np.full((1, ), arg.data[0], dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar)
  587. def create_atomic_cas(self, ptr, cmp, val, sem, scope):
  588. if sem not in self.ir_sem_to_interpreter_sem:
  589. raise ValueError(f"unsupported semantic {sem}")
  590. sem = self.ir_sem_to_interpreter_sem[sem]
  591. return TensorHandle(_interpreter.atomic_cas(ptr.data, cmp.data, val.data, sem), cmp.dtype.scalar)
  592. def create_atomic_rmw(self, rmwOp, ptr, val, mask, sem, scope):
  593. if rmwOp not in self.ir_rmw_op_to_interpreter_rmw_op:
  594. raise ValueError(f"unsupported rmwOp {rmwOp}")
  595. if sem not in self.ir_sem_to_interpreter_sem:
  596. raise ValueError(f"unsupported semantic {sem}")
  597. rmwOp = self.ir_rmw_op_to_interpreter_rmw_op[rmwOp]
  598. sem = self.ir_sem_to_interpreter_sem[sem]
  599. return TensorHandle(_interpreter.atomic_rmw(rmwOp, ptr.data, val.data, mask.data, sem), val.dtype.scalar)
  600. def create_extern_elementwise(self, libName, libPath, symbol, argList, retType, isPure):
  601. raise NotImplementedError("extern_elementwise not supported in interpreter mode")
  602. def create_inline_asm(self, inlineAsm, constraints, values, type, isPure, pack):
  603. raise NotImplementedError("inline_asm not supported in interpreter mode")
  604. def create_print(self, prefix, hex, values, isSigned):
  605. # NOTE: the `isSigned` variable is not really used here; because Signness is already known
  606. # by `values` themselves in python interpreter, thus not really needed here;
  607. # it is only used for triton PrintOpToLLVM to correctly construct the format specifier.
  608. # Interpreter's device_print function has a different format than Triton's device_print
  609. msg = f"({self.grid_idx[0]}, {self.grid_idx[1]}, {self.grid_idx[2]})"
  610. if prefix:
  611. msg += f" {prefix}"
  612. if hex:
  613. np.set_printoptions(formatter={'all': lambda x: f"0x{x:02x}"})
  614. for value in values:
  615. print(msg + f" {value.data}")
  616. if hex:
  617. np.set_printoptions(formatter=None)
  618. def create_assert(self, condition, message):
  619. # Interpreter's device_assert function has a different format than Triton's device_assert
  620. assert condition, f"{message}"
  621. def create_assume(self, condition):
  622. assert condition, "Assume failed"
  623. def create_barrier(self):
  624. # Triton's barrier applies to each program in a grid, so it's a no-op in the interpreter
  625. pass
  626. def create_make_block_ptr(self, base, shape, strides, offsets, block_shape, order):
  627. # Create new offsets to avoid modifying the original
  628. new_offsets = [offset.clone() for offset in offsets]
  629. return BlockPointerHandle(base, shape, strides, new_offsets, block_shape, order)
  630. def create_advance(self, ptr, offsets):
  631. if len(ptr.offsets) != len(offsets):
  632. raise ValueError("len(ptr.offsets) != len(offsets)")
  633. # Create new offsets to avoid modifying the original
  634. new_offsets = [offset.clone() for offset in ptr.offsets]
  635. ret = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides, new_offsets, ptr.block_shape, ptr.order)
  636. for i in range(len(offsets)):
  637. ret.offsets[i].data += offsets[i].data
  638. return ret
  639. def create_make_tensor_descriptor(self, base: TensorHandle, shape: List[TensorHandle], strides: List[TensorHandle],
  640. tensor_shape: List[int], is_signed: bool, padding: str = "zero"):
  641. desc = TensorDescHandle(base, shape, strides, tensor_shape, padding)
  642. desc.validate()
  643. return desc
  644. def create_descriptor_load(self, desc: TensorDescHandle, indices: List[TensorHandle], cache_modifier,
  645. eviction_policy):
  646. assert isinstance(desc, TensorDescHandle)
  647. ptrs, mask = desc.materialize_pointers(indices)
  648. dtype_tt = ptrs.get_element_ty()
  649. dtype_np = _get_np_dtype(dtype_tt)
  650. padding = desc.padding
  651. if padding == _ir.PADDING_OPTION.PAD_ZERO:
  652. other = TensorHandle(np.zeros_like(ptrs.data, dtype=dtype_np), dtype_tt)
  653. elif padding == _ir.PADDING_OPTION.PAD_NAN:
  654. other = TensorHandle(np.full_like(ptrs.data, float('nan'), dtype=dtype_np), dtype_tt)
  655. else:
  656. raise ValueError(f"unsupported padding {padding}")
  657. return self.create_masked_load(ptrs, mask, other, cache_modifier=cache_modifier,
  658. eviction_policy=eviction_policy, is_volatile=False)
  659. def create_descriptor_store(self, desc: TensorDescHandle, value: TensorHandle, indices: List[TensorHandle]):
  660. ptrs, mask = desc.materialize_pointers(indices)
  661. return self.create_masked_store(ptrs, value, mask, None, None)
  662. def create_descriptor_gather(self, desc: TensorDescHandle, x_offsets: TensorHandle, y_offset: TensorHandle, type):
  663. dtype = desc.base.dtype.element_ty
  664. np_dtype = _get_np_dtype(dtype)
  665. result = np.zeros([x_offsets.data.shape[0], desc.block_shape[-1]], dtype=np_dtype)
  666. cache_modifier = None
  667. eviction_policy = None
  668. for i, x_offset in enumerate(x_offsets.data):
  669. indices = [TensorHandle(x_offset, tl.int32), y_offset]
  670. result[i, :] = self.create_descriptor_load(desc, indices, cache_modifier, eviction_policy).data
  671. return TensorHandle(result, dtype)
  672. def create_descriptor_scatter(self, desc: TensorDescHandle, value: TensorHandle, x_offsets: TensorHandle,
  673. y_offset: TensorHandle):
  674. for i, x_offset in enumerate(x_offsets.data):
  675. slice = TensorHandle(value.data[i], value.dtype)
  676. indices = [TensorHandle(x_offset, tl.int32), y_offset]
  677. self.create_descriptor_store(desc, slice, indices)
  678. def get_all_ones_value(self, type):
  679. np_type = _get_np_dtype(type)
  680. if "int" in np_type.name:
  681. return TensorHandle(np.full(1, -1, dtype=np_type), type.scalar)
  682. elif np_type == np.bool_:
  683. return TensorHandle(np.full(1, True, dtype=np_type), type.scalar)
  684. else:
  685. raise TypeError(f"unsupported type {type}")
  686. _MISSING = object()
  687. class _LangPatchScope:
  688. """Tracks patched attributes so they can be restored."""
  689. def __init__(self) -> None:
  690. self._changes: list[tuple[object, str, object]] = []
  691. def set_attr(self, obj: object, name: str, value: object) -> None:
  692. original = getattr(obj, name, _MISSING)
  693. self._changes.append((obj, name, original))
  694. setattr(obj, name, value)
  695. def restore(self) -> None:
  696. while self._changes:
  697. obj, name, original = self._changes.pop()
  698. if original is _MISSING:
  699. delattr(obj, name)
  700. else:
  701. setattr(obj, name, original)
  702. def _patch_attr(obj, name, member, builder, scope: _LangPatchScope):
  703. semantic = TritonSemantic(builder)
  704. new_member = lambda *args, member=member, **kwargs: (member(*args, **
  705. {k: v
  706. for k, v in kwargs.items()
  707. if k != "_semantic"}, _semantic=semantic))
  708. scope.set_attr(obj, name, new_member)
  709. def _patch_builtin(pkg, builder, scope: _LangPatchScope):
  710. for name, member in inspect.getmembers(pkg):
  711. if tl.core.is_builtin(member):
  712. _patch_attr(pkg, name, member, builder, scope)
  713. def _patch_lang_tensor(tensor, scope: _LangPatchScope):
  714. def _get_bool(self):
  715. data = self.handle.data
  716. # in triton, only scalars can be converted to booleans
  717. # here we need this hack because all scalars are tensors
  718. return bool(data) if data.size == 1 else True
  719. def _get_transpose(self):
  720. handle = TensorHandle(np.transpose(self.handle.data), self.handle.dtype)
  721. assert self.type.is_block()
  722. block_shape = list(self.type.shape)
  723. block_shape[-1], block_shape[-2] = block_shape[-2], block_shape[-1]
  724. res_ty = tl.core.block_type(self.dtype, block_shape)
  725. return tl.core.tensor(handle, res_ty)
  726. scope.set_attr(tensor, "__index__", lambda self: int(self.handle.data))
  727. scope.set_attr(tensor, "__bool__", lambda self: _get_bool(self))
  728. scope.set_attr(tensor, "__repr__", lambda self: repr(self.handle.data))
  729. scope.set_attr(tensor, "__str__", lambda self: str(self.handle.data))
  730. scope.set_attr(tensor, "T", property(_get_transpose))
  731. class ReduceScanOpInterface:
  732. def __init__(self, axis, combine_fn):
  733. self.axis = axis
  734. self.combine_fn = combine_fn
  735. def check_axis(self, shape, axis):
  736. if axis is not None and axis >= len(shape):
  737. raise ValueError(f"axis {axis} out of bounds for shape {shape}")
  738. def check_tensor(self, input):
  739. for arg in input:
  740. if not isinstance(arg, tl.core.tensor):
  741. raise ValueError(f"input must be a tensor, got {type(arg)}")
  742. self.check_axis(arg.shape, self.axis)
  743. def to_tensor(self, ret, dtype):
  744. np_dtype = _get_np_dtype(dtype)
  745. if hasattr(ret, "shape") and ret.shape:
  746. ret = ret.astype(np_dtype)
  747. ret_type = tl.block_type(dtype, list(ret.shape))
  748. else:
  749. ret = np.array([ret], dtype=np_dtype)
  750. ret_type = dtype
  751. return tl.core.tensor(TensorHandle(ret, dtype.scalar), ret_type)
  752. def apply(self, input):
  753. if not isinstance(input, tuple):
  754. return self.apply((input, ))[0]
  755. self.check_tensor(input)
  756. ret = self.apply_impl(input)
  757. return tuple(ret) if isinstance(ret, (list, tuple)) else (ret, )
  758. class ReduceOps(ReduceScanOpInterface):
  759. def __init__(self, axis, combine_fn, keep_dims):
  760. super().__init__(axis, combine_fn)
  761. self.keep_dims = keep_dims
  762. def unravel(self, input, axis):
  763. ret = []
  764. for data in input:
  765. if axis is not None:
  766. ret.append(data)
  767. else:
  768. axis = 0
  769. ret.append(self.to_tensor(data.handle.data.flatten(), data.dtype))
  770. return tuple(ret), axis
  771. def generic_reduce(self, input):
  772. original_axis = self.axis
  773. input, axis = self.unravel(input, self.axis)
  774. input_data = []
  775. output_data = []
  776. input_shape = input[0].handle.data.shape
  777. output_shape = input_shape[0:axis] + input_shape[axis + 1:]
  778. for arg in input:
  779. input_data.append(arg.handle.data)
  780. output_data.append(np.zeros(output_shape, dtype=arg.handle.data.dtype))
  781. # Reduce on axis
  782. for i in range(input_data[0].size):
  783. # Recover input_index from i using input_shape
  784. input_index = np.unravel_index(i, input_shape)
  785. output_index = input_index[0:axis] + input_index[axis + 1:]
  786. input_tuple = tuple(self.to_tensor(d[input_index], input[ii].dtype) for ii, d in enumerate(input_data))
  787. if input_index[axis] == 0:
  788. # First element
  789. for j in range(len(output_data)):
  790. output_data[j][output_index] = input_tuple[j].handle.data.item()
  791. else:
  792. acc_tuple = tuple(self.to_tensor(o[output_index], input[oi].dtype) for oi, o in enumerate(output_data))
  793. combine_fn_ret = self.combine_fn.fn(*acc_tuple, *input_tuple)
  794. acc_tuple = (combine_fn_ret, ) if not isinstance(combine_fn_ret, tuple) else combine_fn_ret
  795. for j in range(len(output_data)):
  796. output_data[j][output_index] = acc_tuple[j].handle.data.item() if isinstance(
  797. acc_tuple[j], tl.core.tensor) else acc_tuple[j]
  798. # Pack output
  799. ret = []
  800. for i, data in enumerate(output_data):
  801. if self.keep_dims:
  802. if original_axis is not None:
  803. data = np.expand_dims(data, axis)
  804. else:
  805. for _ in range(len(input_shape)):
  806. data = np.expand_dims(data, 0)
  807. elif original_axis is None:
  808. # Take a scalar
  809. data = data.item()
  810. ret.append(self.to_tensor(data, input[i].dtype))
  811. return ret
  812. def min_max(self, input, val_reduce_op, idx_reduce_op=None):
  813. # If input is a tuple, it must be (val, index), and we only take val
  814. input = input[0] if isinstance(input, tuple) else input
  815. val = None
  816. idx = None
  817. if val_reduce_op:
  818. val = self.to_tensor(val_reduce_op(input.handle.data, axis=self.axis, keepdims=self.keep_dims), input.dtype)
  819. if idx_reduce_op:
  820. idx = self.to_tensor(idx_reduce_op(input.handle.data, axis=self.axis, keepdims=self.keep_dims), tl.int32)
  821. if val is not None and idx is not None:
  822. return val, idx
  823. elif val is not None:
  824. return val
  825. elif idx is not None:
  826. return idx
  827. else:
  828. raise ValueError("val_reduce_op and idx_reduce_op are both None")
  829. def sum(self, input):
  830. return self.to_tensor(np.sum(input.handle.data, axis=self.axis, keepdims=self.keep_dims), input.dtype)
  831. def apply_impl(self, input):
  832. if self.combine_fn == tl.standard._argmin_combine_tie_break_left:
  833. return self.min_max(input[0], val_reduce_op=np.min, idx_reduce_op=np.argmin)
  834. elif self.combine_fn == tl.standard._argmax_combine_tie_break_left:
  835. return self.min_max(input[0], val_reduce_op=np.max, idx_reduce_op=np.argmax)
  836. elif self.combine_fn == tl.standard._elementwise_max:
  837. return self.min_max(input[0], val_reduce_op=np.nanmax, idx_reduce_op=None)
  838. elif self.combine_fn == tl.standard._elementwise_min:
  839. return self.min_max(input[0], val_reduce_op=np.nanmin, idx_reduce_op=None)
  840. elif self.combine_fn == tl.standard._sum_combine:
  841. return self.sum(input[0])
  842. else:
  843. # Fall back to the slow mode
  844. return self.generic_reduce(input)
  845. class ScanOps(ReduceScanOpInterface):
  846. def __init__(self, axis, combine_fn, reverse):
  847. super().__init__(axis, combine_fn)
  848. self.reverse = reverse
  849. def cumsum(self, input):
  850. return [self.to_tensor(np.cumsum(input.handle.data, axis=self.axis), dtype=input.dtype)]
  851. def cumprod(self, input):
  852. return [self.to_tensor(np.cumprod(input.handle.data, axis=self.axis), dtype=input.dtype)]
  853. def generic_scan(self, input):
  854. input_data = []
  855. output_data = []
  856. shape = input[0].handle.data.shape
  857. for arg in input:
  858. input_data.append(arg.handle.data)
  859. output_data.append(np.zeros(shape, dtype=arg.handle.data.dtype))
  860. # Scan on axis
  861. for i in range(input_data[0].size):
  862. # Recover index from i using shape
  863. index = np.unravel_index(i, shape)
  864. data = tuple(self.to_tensor(d[index], input[ii].dtype) for ii, d in enumerate(input_data))
  865. if index[self.axis] == 0:
  866. # First element
  867. for j in range(len(output_data)):
  868. output_data[j][index] = data[j].handle.data.item()
  869. else:
  870. prev_index = tuple(index[i] - 1 if i == self.axis else index[i] for i in range(len(index)))
  871. acc_tuple = tuple(self.to_tensor(o[prev_index], input[oi].dtype) for oi, o in enumerate(output_data))
  872. combine_fn_ret = self.combine_fn.fn(*acc_tuple, *data)
  873. acc_tuple = (combine_fn_ret, ) if not isinstance(combine_fn_ret, tuple) else combine_fn_ret
  874. for j in range(len(output_data)):
  875. output_data[j][index] = acc_tuple[j].handle.data.item() if isinstance(
  876. acc_tuple[j], tl.core.tensor) else acc_tuple[j]
  877. # Pack output
  878. ret = []
  879. for i, data in enumerate(output_data):
  880. ret.append(self.to_tensor(data, input[i].dtype))
  881. return ret
  882. def apply_impl(self, input):
  883. new_input = []
  884. if self.reverse:
  885. for arg in input:
  886. new_input.append(self.to_tensor(np.flip(arg.handle.data, axis=self.axis), arg.dtype))
  887. else:
  888. new_input = input
  889. if self.combine_fn == tl.standard._sum_combine:
  890. ret = self.cumsum(new_input[0])
  891. elif self.combine_fn == tl.standard._prod_combine:
  892. ret = self.cumprod(new_input[0])
  893. else:
  894. # Fall back to the slow mode
  895. ret = self.generic_scan(new_input)
  896. if self.reverse:
  897. for arg in ret:
  898. arg.handle.data = np.flip(arg.handle.data, axis=self.axis)
  899. return ret
  900. def _patch_reduce_scan(scope: _LangPatchScope):
  901. # Because interpreter doesn't support region_builder_fn, we cannot patch the builder
  902. # to use the new reduce and scan functions.
  903. # Instead, we need to patch reduce and reduce functions in tl and tl.core
  904. def _new_reduce(input, axis, combine_fn, keep_dims=False, **kwargs):
  905. return ReduceOps(axis, combine_fn, keep_dims).apply(input)
  906. def _new_scan(input, axis, combine_fn, reverse=False, **kwargs):
  907. return ScanOps(axis, combine_fn, reverse).apply(input)
  908. scope.set_attr(tl, "reduce", _new_reduce)
  909. scope.set_attr(tl, "associative_scan", _new_scan)
  910. scope.set_attr(tl.core, "reduce", _new_reduce)
  911. scope.set_attr(tl.core, "associative_scan", _new_scan)
  912. def _patch_lang_core(lang, scope: _LangPatchScope):
  913. def _new_to_ir(self, builder):
  914. # We need to specify signedness for integer types in the numpy mode
  915. if self.name == 'void':
  916. return builder.get_void_ty()
  917. elif self.name == 'int1':
  918. return builder.get_int1_ty()
  919. elif self.name == 'int8':
  920. return builder.get_int8_ty()
  921. elif self.name == 'uint8':
  922. return builder.get_uint8_ty()
  923. elif self.name == 'int16':
  924. return builder.get_int16_ty()
  925. elif self.name == 'uint16':
  926. return builder.get_uint16_ty()
  927. elif self.name == 'int32':
  928. return builder.get_int32_ty()
  929. elif self.name == 'uint32':
  930. return builder.get_uint32_ty()
  931. elif self.name == 'int64':
  932. return builder.get_int64_ty()
  933. elif self.name == 'uint64':
  934. return builder.get_uint64_ty()
  935. elif self.name == 'fp8e5':
  936. return builder.get_fp8e5_ty()
  937. elif self.name == 'fp8e4nv':
  938. return builder.get_fp8e4nv_ty()
  939. elif self.name == 'fp8e4b15':
  940. return builder.get_fp8e4b15_ty()
  941. elif self.name == 'fp16':
  942. return builder.get_half_ty()
  943. elif self.name == 'bf16':
  944. return builder.get_bf16_ty()
  945. elif self.name == 'fp32':
  946. return builder.get_float_ty()
  947. elif self.name == 'fp64':
  948. return builder.get_double_ty()
  949. raise ValueError(f'fail to convert {self} to ir type')
  950. # can't just map lang.static_range to `range`, because `tl.static_range`
  951. # can get `step` passed by keyword
  952. def _new_range(arg1, arg2=None, step=None, **kwargs):
  953. if step is None:
  954. step = 1
  955. if arg2 is None:
  956. start, end = 0, arg1
  957. else:
  958. start, end = arg1, arg2
  959. return range(start, end, step)
  960. def _new_static_assert(cond, msg=""):
  961. assert cond, msg
  962. def _set_attr(input, values, name):
  963. # skip non tensor types. This may happen for induction variables.
  964. if not isinstance(input, tl.tensor):
  965. return input
  966. # Unwrap constexpr
  967. values = [values] if not isinstance(values, (list, tuple)) else values
  968. values = [v.value if isinstance(v, tl.constexpr) else v for v in values]
  969. if len(values) != max(1, len(input.shape)):
  970. raise ValueError(f"len(values) != len(input.shape) for {name}")
  971. input.handle.set_attr(name, values)
  972. return input
  973. scope.set_attr(lang, "range", _new_range)
  974. scope.set_attr(lang, "static_range", _new_range)
  975. scope.set_attr(lang, "static_assert", _new_static_assert)
  976. scope.set_attr(lang, "static_print", print)
  977. scope.set_attr(lang.dtype, "to_ir", _new_to_ir)
  978. scope.set_attr(lang, "multiple_of", partial(_set_attr, name="tt.divisibility"))
  979. scope.set_attr(lang, "max_contiguous", partial(_set_attr, name="tt.contiguity"))
  980. scope.set_attr(lang, "max_constancy", partial(_set_attr, name="tt.constancy"))
  981. _patch_reduce_scan(scope)
  982. def _patch_lang(fn):
  983. scope = _LangPatchScope()
  984. langs = [value for _, value in fn.__globals__.items() if inspect.ismodule(value) and value in [tl, tl.core]]
  985. assert len(langs) >= 1, "triton.language must be visible from within jit'd function"
  986. for lang in langs:
  987. _patch_builtin(lang, interpreter_builder, scope)
  988. _patch_builtin(lang.tensor, interpreter_builder, scope)
  989. if lang == tl:
  990. _patch_builtin(lang.math, interpreter_builder, scope)
  991. _patch_lang_tensor(lang.tensor, scope)
  992. _patch_lang_core(lang, scope)
  993. _patch_builtin(tl.core.tensor_descriptor_base, interpreter_builder, scope)
  994. return scope
  995. def _tuple_create(arg, contents):
  996. # NamedTuples and tuples have different construction semantics. NamedTuple
  997. # has a constructor that takes individual arguments, while tuple takes an
  998. # iterable. Both have type "tuple" making it difficult to distinguish
  999. # between them, but only NamedTuple has "_fields" and apparently this is how
  1000. # everyone does the check.
  1001. return type(arg)(*contents) if hasattr(arg, "_fields") else type(arg)(contents)
  1002. # TODO: wrap everything in triton tensors
  1003. def _implicit_cvt(arg):
  1004. if isinstance(arg, int):
  1005. ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg), None)
  1006. dtype = np.int32
  1007. if -2**31 <= arg < 2**31:
  1008. dtype = np.int32
  1009. elif 2**31 <= arg < 2**32:
  1010. dtype = np.uint32
  1011. elif -2**63 <= arg < 2**63:
  1012. dtype = np.int64
  1013. elif 2**63 <= arg < 2**64:
  1014. dtype = np.uint64
  1015. else:
  1016. raise ValueError(f"Unsupported integer value {arg}")
  1017. handle = TensorHandle(np.array([arg], dtype=dtype), ty)
  1018. return tl.tensor(handle, ty)
  1019. if hasattr(arg, "data_ptr"):
  1020. ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg), None)
  1021. handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty)
  1022. return tl.tensor(handle, ty)
  1023. elif isinstance(arg, tuple):
  1024. return _tuple_create(arg, map(_implicit_cvt, arg))
  1025. elif isinstance(arg, TensorDescriptor):
  1026. strides = [_implicit_cvt(s) for s in arg.strides]
  1027. assert arg.strides[-1] == 1
  1028. strides[-1] = tl.constexpr(1)
  1029. semantic = TritonSemantic(InterpreterBuilder())
  1030. return semantic.make_tensor_descriptor(base=_implicit_cvt(arg.base),
  1031. shape=[_implicit_cvt(s) for s in arg.shape], strides=strides,
  1032. block_shape=[tl.constexpr(b)
  1033. for b in arg.block_shape], padding_option=arg.padding)
  1034. return arg
# Module-wide builder used for interpreted execution: GridExecutor sets its grid dimensions and the current
# program index before each call, and _patch_lang passes it when wrapping Triton builtins.
interpreter_builder = InterpreterBuilder()
# Semantic bound to that builder; rewritten assignments call interpreter_semantic.to_tensor(value, False).
interpreter_semantic = TritonSemantic(interpreter_builder)
  1037. def _unwrap_tensor(t):
  1038. if isinstance(t, triton.runtime.jit.TensorWrapper):
  1039. return t.base
  1040. return t
  1041. def _rewrap_tensor(t, original_tensor):
  1042. if isinstance(original_tensor, triton.runtime.jit.TensorWrapper):
  1043. return triton.runtime.jit.TensorWrapper(t, original_tensor.dtype)
  1044. return t
  1045. class GridExecutor:
  1046. def __init__(self, fn, arg_names, grid, pre_run_hooks=[]):
  1047. from .jit import _normalize_ty # TODO: modularize
  1048. self.fn = fn
  1049. self.arg_names = arg_names
  1050. self.grid = grid
  1051. self.pre_run_hooks = pre_run_hooks
  1052. __annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()}
  1053. self.constexprs = [name for name in arg_names if __annotations__.get(name) == "constexpr"]
  1054. def _init_args_hst(self, args_dev, kwargs):
  1055. storages = {}
  1056. def _to_cpu(arg):
  1057. if isinstance(arg, tuple):
  1058. return _tuple_create(arg, map(_to_cpu, arg))
  1059. elif isinstance(arg, TensorDescriptor):
  1060. return TensorDescriptor(
  1061. _to_cpu(arg.base),
  1062. arg.shape,
  1063. arg.strides,
  1064. arg.block_shape,
  1065. arg.padding,
  1066. )
  1067. elif not hasattr(arg, "data_ptr"):
  1068. return arg
  1069. unwrapped_arg = _unwrap_tensor(arg)
  1070. if unwrapped_arg.untyped_storage().data_ptr() not in storages:
  1071. storage = unwrapped_arg.untyped_storage()
  1072. storages[storage.data_ptr()] = storage.cpu()
  1073. storage = storages[unwrapped_arg.untyped_storage().data_ptr()]
  1074. cpu_arg = unwrapped_arg.new_empty(0, device='cpu')
  1075. cpu_arg.set_(storage, unwrapped_arg.storage_offset(), unwrapped_arg.size(), unwrapped_arg.stride())
  1076. cpu_arg = _rewrap_tensor(cpu_arg, original_tensor=arg)
  1077. return cpu_arg
  1078. args_hst = [_to_cpu(arg) for arg in args_dev]
  1079. # Process keyword arguments
  1080. kwargs_hst = {}
  1081. for key, value in kwargs.items():
  1082. kwargs_hst[key] = _to_cpu(value)
  1083. return args_hst, kwargs_hst
  1084. def _restore_args_dev(self, args_dev, args_hst, kwargs, kwargs_hst):
  1085. storages = {}
  1086. def _from_cpu(arg_dev, arg_hst):
  1087. if hasattr(arg_dev, "data_ptr"):
  1088. # No need to rewrap because this just modifies internal
  1089. arg_dev, arg_hst = _unwrap_tensor(arg_dev), _unwrap_tensor(arg_hst)
  1090. storages[arg_dev.untyped_storage().data_ptr()] = (arg_dev.untyped_storage(), arg_hst.untyped_storage())
  1091. elif isinstance(arg_dev, tuple):
  1092. for (arg_dev, arg_hst) in zip(arg_dev, arg_hst):
  1093. _from_cpu(arg_dev, arg_hst)
  1094. elif isinstance(arg_dev, TensorDescriptor):
  1095. _from_cpu(arg_dev.base, arg_hst.base)
  1096. for arg_dev, arg_hst in zip(args_dev, args_hst):
  1097. _from_cpu(arg_dev, arg_hst)
  1098. # Restore keyword arguments
  1099. for key, kwarg_dev in kwargs.items():
  1100. kwarg_hst = kwargs_hst[key]
  1101. _from_cpu(kwarg_dev, kwarg_hst)
  1102. for (arg_dev, arg_hst) in storages.values():
  1103. arg_dev.copy_(arg_hst)
  1104. def __call__(self, *args_dev, **kwargs):
  1105. # Removes not used reserved keywords from kwargs
  1106. # Triton doesn't support keyword-only, variable positional or variable keyword arguments
  1107. # It's safe to inspect only positional or keyword arguments (i.e., argspec.args)
  1108. argspec = inspect.getfullargspec(self.fn)
  1109. kwargs = {k: v for k, v in kwargs.items() if k in argspec.args}
  1110. # copy arguments to the host
  1111. args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs)
  1112. # run pre-run hooks
  1113. for hook in self.pre_run_hooks:
  1114. hook(*args_hst, **kwargs_hst)
  1115. # remaps core language functions to interpreted ones
  1116. patch_scope = _patch_lang(self.fn)
  1117. try:
  1118. # we need to copy arguments to the host for the interpreter
  1119. # implicitly convert tensor arguments to their base pointers
  1120. args = inspect.getcallargs(self.fn, *args_hst, **kwargs_hst)
  1121. args = {name: arg if name in self.constexprs else _implicit_cvt(arg) for name, arg in args.items()}
  1122. # iterate through grid
  1123. grid = self.grid(args) if callable(self.grid) else self.grid
  1124. assert len(grid) <= 3, "grid must have at most 3 dimensions"
  1125. grid = grid + (1, ) * (3 - len(grid))
  1126. interpreter_builder.set_grid_dim(*grid)
  1127. try:
  1128. for x in range(grid[0]):
  1129. for y in range(grid[1]):
  1130. for z in range(grid[2]):
  1131. interpreter_builder.set_grid_idx(x, y, z)
  1132. self.fn(**args)
  1133. except Exception as e:
  1134. if triton.knobs.compilation.front_end_debugging:
  1135. raise
  1136. raise InterpreterError(repr(e)) from e
  1137. finally:
  1138. patch_scope.restore()
  1139. # copy arguments back to propagate side-effects
  1140. self._restore_args_dev(args_dev, args_hst, kwargs, kwargs_hst)
  1141. class ASTTransformer(ast.NodeTransformer):
  1142. def visit_Assign(self, node):
  1143. names = []
  1144. for target in node.targets:
  1145. names += [self.visit(target)]
  1146. if len(names) > 1:
  1147. raise ValueError("Multiple assignments are not supported")
  1148. # Modify the assignment x = value to
  1149. # interpreter_semantic.to_tensor(value, False)
  1150. node.value = ast.Call(
  1151. func=ast.Attribute(value=ast.Name(id="interpreter_semantic", ctx=ast.Load()), attr="to_tensor",
  1152. ctx=ast.Load()), args=[node.value, ast.Constant(value=False)], keywords=[])
  1153. return node
  1154. class FunctionRewriter:
  1155. ast_transformer = ASTTransformer()
  1156. def __init__(self, fn, **kwargs):
  1157. self.fn = fn
  1158. self.kwargs = kwargs
  1159. self.filename: str = ""
  1160. # Absolute line number in the file
  1161. self.def_file_lineno: int = 0
  1162. def rewrite_ast(self):
  1163. # If exception is raise, it means the function does not have source code available,
  1164. # e.g., dynamically generated functions, we cannot rewrite it so just return the original function
  1165. try:
  1166. lines, _ = inspect.getsourcelines(self.fn)
  1167. except Exception:
  1168. return self.fn
  1169. # truncate lines before def
  1170. # @triton.autotune(...)
  1171. # ...
  1172. # @triton.jit
  1173. # ...
  1174. # def foo(...): <- this line is the function definition
  1175. self.filename, self.def_file_lineno = self._get_jit_fn_file_line()
  1176. self.def_lineno = self._find_def(lines)
  1177. src = self._prepare_source(lines)
  1178. transformed_ast = self._transform_ast(src)
  1179. return self._compile_and_exec(transformed_ast)
  1180. def _get_jit_fn_file_line(self):
  1181. from .jit import get_jit_fn_file_line, JITFunction
  1182. return get_jit_fn_file_line(JITFunction(self.fn))
  1183. def _find_def(self, lines):
  1184. def_lineno = 0
  1185. # Line numbers start from 1
  1186. for i, line in enumerate(lines):
  1187. if line.strip().startswith("def "):
  1188. def_lineno = i + 1
  1189. return def_lineno
  1190. def _prepare_source(self, lines):
  1191. lines = lines[self.def_lineno - 1:]
  1192. src = ''.join(lines)
  1193. return textwrap.dedent(src)
  1194. def _transform_ast(self, src):
  1195. # src is like:
  1196. # 1: def foo(...):
  1197. # 2: ...
  1198. parsed_ast = ast.parse(src)
  1199. transformed_ast = self.ast_transformer.visit(parsed_ast)
  1200. ast.fix_missing_locations(transformed_ast)
  1201. inc_lineno = self.def_file_lineno - 1
  1202. ast.increment_lineno(transformed_ast, inc_lineno)
  1203. return transformed_ast
  1204. def _compile_and_exec(self, transformed_ast):
  1205. compiled_code = compile(transformed_ast, filename=self.filename, mode='exec')
  1206. local_namespace = {**self.kwargs}
  1207. fn_globals = self.fn.__globals__
  1208. for key, value in globals().items():
  1209. if key not in fn_globals:
  1210. fn_globals[key] = value
  1211. exec(compiled_code, fn_globals, local_namespace)
  1212. return local_namespace[self.fn.__name__]
  1213. class InterpretedFunction(KernelInterface[T]):
  1214. # Cache all rewritten functions
  1215. rewritten_fn: Dict[Callable, Callable] = {}
  1216. def __init__(self, fn, **kwargs) -> None:
  1217. self.fn = fn
  1218. self.rewriter = FunctionRewriter(fn, **kwargs)
  1219. self.kwargs = kwargs
  1220. self.pre_run_hooks = []
  1221. signature = inspect.signature(fn)
  1222. self.arg_names = [v.name for v in signature.parameters.values()]
  1223. def run(self, *args, grid, warmup, **kwargs):
  1224. if warmup:
  1225. return
  1226. fn = self.rewrite()
  1227. return GridExecutor(fn, self.arg_names, grid, self.pre_run_hooks)(*args, **kwargs)
  1228. def add_pre_run_hook(self, hook):
  1229. assert callable(hook)
  1230. self.pre_run_hooks.append(hook)
  1231. def rewrite(self):
  1232. if self.fn not in self.rewritten_fn:
  1233. self.rewritten_fn[self.fn] = self.rewriter.rewrite_ast()
  1234. return self.rewritten_fn[self.fn]
  1235. @property
  1236. def __name__(self):
  1237. return self.fn.__name__
  1238. def __call__(self, *args, **kwargs):
  1239. # This is a device function call
  1240. _patch_lang(self.fn)
  1241. fn = self.rewrite()
  1242. try:
  1243. return fn(*args, **kwargs)
  1244. except Exception as e:
  1245. raise InterpreterError(repr(e)) from e