semantic.py 99 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966
  1. from __future__ import annotations # remove after python 3.11
  2. import warnings
  3. from typing import List, Optional, Sequence, Tuple, TypeVar, Generic, Type
  4. import numbers
  5. from triton.runtime import driver
  6. from .._C.libtriton import ir
  7. from . import core as tl
  8. T = TypeVar('T')
  9. TensorTy = TypeVar('TensorTy')
  10. class IncompatibleTypeErrorImpl(Exception):
  11. def __init__(self, type_a, type_b):
  12. self.type_a = type_a
  13. self.type_b = type_b
  14. self.message = "invalid operands of type " + self.type_a.__repr__() + " and " + self.type_b.__repr__()
  15. super(IncompatibleTypeErrorImpl, self).__init__(self.message)
  16. class TritonSemantic(Generic[TensorTy]):
  17. tensor: Type[TensorTy] = tl.tensor
  18. lang = tl
  19. builder: ir.builder
  20. def __init__(self, builder):
  21. self.builder = builder
  22. # ===----------------------------------------------------------------------===##
  23. # Programming Model
  24. # ===----------------------------------------------------------------------===##
  25. def program_id(self, axis: int) -> TensorTy:
  26. if axis not in (0, 1, 2):
  27. raise ValueError(f"program_id axis must be 0, 1, or 2 but got {axis}")
  28. return self.tensor(self.builder.create_get_program_id(axis), tl.int32)
  29. def num_programs(self, axis: int) -> TensorTy:
  30. if axis not in (0, 1, 2):
  31. raise ValueError(f"num_programs axis must be 0, 1, or 2 but got {axis}")
  32. return self.tensor(self.builder.create_get_num_programs(axis), tl.int32)
  33. # ===----------------------------------------------------------------------===//
  34. # Implicit Casting Utilities
  35. # ===----------------------------------------------------------------------===//
  36. def integer_promote_impl(self, a_ty: tl.dtype, b_ty: tl.dtype) -> tl.dtype:
  37. a_rank = a_ty.int_bitwidth
  38. b_rank = b_ty.int_bitwidth
  39. a_sn = a_ty.int_signedness
  40. b_sn = b_ty.int_signedness
  41. # Rules for signedness taken from "Usual arithmetic conversions" on
  42. # https://en.cppreference.com/w/c/language/conversion.
  43. if a_sn == b_sn:
  44. return a_ty if a_rank > b_rank else b_ty
  45. elif a_sn == tl.dtype.SIGNEDNESS.UNSIGNED:
  46. return a_ty if a_rank >= b_rank else b_ty
  47. elif b_sn == tl.dtype.SIGNEDNESS.UNSIGNED:
  48. return b_ty if b_rank >= a_rank else a_ty
  49. raise TypeError(f"unexpected signedness {a_sn} and {b_sn}")
  50. def computation_type_impl(self, a_ty: tl.dtype, a_is_scalar: bool, b_ty: tl.dtype, b_is_scalar: bool,
  51. div_or_mod: bool) -> tl.dtype:
  52. # 0) For scalars we follow semantics similar to PyTorch, namely:
  53. # - If the scalar is of a lower or equal kind (bool < uint < int < fp),
  54. # it doesn't participate in the promotion
  55. if a_is_scalar != b_is_scalar:
  56. scalar_ty, tensor_ty = (a_ty, b_ty) if a_is_scalar else (b_ty, a_ty)
  57. if scalar_ty.kind().value <= tensor_ty.kind().value:
  58. # Upcast because of 3) and 4) below!
  59. if div_or_mod and (tensor_ty in (tl.float16, tl.bfloat16)):
  60. return tl.float32
  61. return tensor_ty
  62. # 1) if one operand is double, the other is implicitly
  63. # converted to double
  64. if a_ty.is_fp64() or b_ty.is_fp64():
  65. return tl.float64
  66. # 2) if one operand is float, the other is implicitly
  67. # converted to float
  68. if a_ty.is_fp32() or b_ty.is_fp32():
  69. return tl.float32
  70. # 3 ) if one operand is half, the other is implicitly converted to half
  71. # unless we're doing / or %, which do not exist natively in PTX for fp16.
  72. # Supported PTX op: add, sub, mul, fma, neg, abs, min, max, tanh, ex2, setp
  73. if a_ty.is_fp16() or b_ty.is_fp16():
  74. if div_or_mod:
  75. return tl.float32
  76. else:
  77. return tl.float16
  78. # 4) return bf16 only if both operands are of bf16
  79. if a_ty.is_bf16() and b_ty.is_bf16():
  80. if div_or_mod:
  81. return tl.float32
  82. else:
  83. return tl.bfloat16
  84. if a_ty.is_bf16() or b_ty.is_bf16():
  85. return tl.float32
  86. # 5) return fp16 if operands are different fp8
  87. if a_ty.is_fp8() and b_ty.is_fp8():
  88. return a_ty if a_ty == b_ty else tl.float16
  89. if not a_ty.is_int() or not b_ty.is_int():
  90. raise TypeError(f"unexpected type {a_ty} and {b_ty}")
  91. # 6 ) both operands are integer and undergo
  92. # integer promotion
  93. if div_or_mod and a_ty.int_signedness != b_ty.int_signedness:
  94. raise TypeError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() +
  95. " because they have different signedness;"
  96. "this is unlikely to result in a useful answer. Cast them to the same signedness.")
  97. return self.integer_promote_impl(a_ty, b_ty)
  98. def to_tensor(self, x, check_type: bool = True):
  99. if isinstance(x, bool):
  100. return self.tensor(self.builder.get_int1(x), tl.int1)
  101. # Note: compile-time const integers are represented by unsigned values
  102. elif isinstance(x, int):
  103. if -2**31 <= x < 2**31:
  104. dtype = tl.int32
  105. elif 2**31 <= x < 2**32:
  106. dtype = tl.uint32
  107. elif -2**63 <= x < 2**63:
  108. dtype = tl.int64
  109. elif 2**63 <= x < 2**64:
  110. dtype = tl.uint64
  111. else:
  112. raise ValueError(f'Nonrepresentable integer {x}.')
  113. return self.scalar_constant(x, dtype=dtype)
  114. elif isinstance(x, float):
  115. min_float32 = 2**-126
  116. max_float32 = (2 - 2**-23) * 2**127
  117. abs_x = __builtins__['abs'](x)
  118. if abs_x == float("inf") or\
  119. abs_x == 0.0 or \
  120. x != x or \
  121. min_float32 <= abs_x <= max_float32:
  122. dtype = tl.float32
  123. else:
  124. dtype = tl.float64
  125. return self.scalar_constant(x, dtype=dtype)
  126. elif isinstance(x, tl.constexpr):
  127. return self.to_tensor(x.value)
  128. elif isinstance(x, self.tensor):
  129. return x
  130. if check_type:
  131. raise TypeError(f"cannot convert {x} of type {type(x)} to tensor")
  132. return x
  133. # ===----------------------------------------------------------------------===//
  134. # Binary Operators
  135. # ===----------------------------------------------------------------------===//
  136. def check_ptr_type_impl(self, type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -> None:
  137. if type_a.is_ptr():
  138. if not allow_ptr_a:
  139. raise IncompatibleTypeErrorImpl(type_a, type_b)
  140. # T* + U* with T != U
  141. if type_b.is_ptr() and (type_a != type_b):
  142. raise IncompatibleTypeErrorImpl(type_a, type_b)
  143. # T* + float
  144. if type_b.is_floating():
  145. raise IncompatibleTypeErrorImpl(type_a, type_b)
  146. def binary_op_type_checking_impl(self, lhs: TensorTy | numbers.Number, rhs: TensorTy | numbers.Number,
  147. allow_lhs_ptr=False, allow_rhs_ptr=False, arithmetic_check=True,
  148. div_or_mod=False) -> Tuple[TensorTy, TensorTy]:
  149. lhs_is_scalar = isinstance(lhs, numbers.Number)
  150. rhs_is_scalar = isinstance(rhs, numbers.Number)
  151. if lhs_is_scalar:
  152. lhs_scalar = lhs
  153. lhs = self.to_tensor(lhs)
  154. if rhs_is_scalar:
  155. rhs_scalar = rhs
  156. rhs = self.to_tensor(rhs)
  157. # implicit typecasting
  158. lhs_sca_ty = lhs.type.scalar
  159. rhs_sca_ty = rhs.type.scalar
  160. self.check_ptr_type_impl(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr)
  161. self.check_ptr_type_impl(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr)
  162. if arithmetic_check and not lhs_sca_ty.is_ptr() and not rhs_sca_ty.is_ptr():
  163. ret_sca_ty = self.computation_type_impl(lhs_sca_ty, lhs_is_scalar, rhs_sca_ty, rhs_is_scalar, div_or_mod)
  164. if (lhs_is_scalar and lhs_scalar < 0 and ret_sca_ty.is_int_unsigned()
  165. or rhs_is_scalar and rhs_scalar < 0 and ret_sca_ty.is_int_unsigned()):
  166. raise ValueError("Cannot perform a binary operation between an unsigned tensor and a negative scalar. "
  167. "Perform a explicit cast on one of them.")
  168. if ret_sca_ty.is_int():
  169. if lhs_is_scalar and not (ret_sca_ty.get_int_min_value() <= lhs_scalar <=
  170. ret_sca_ty.get_int_max_value()):
  171. raise ValueError(f"Scalar {lhs_scalar} is out of range for type {ret_sca_ty}")
  172. if rhs_is_scalar and not (ret_sca_ty.get_int_min_value() <= rhs_scalar <=
  173. ret_sca_ty.get_int_max_value()):
  174. raise ValueError(f"Scalar {rhs_scalar} is out of range for type {ret_sca_ty}")
  175. lhs = self.scalar_constant(lhs_scalar, dtype=ret_sca_ty) if lhs_is_scalar else self.cast(lhs, ret_sca_ty)
  176. rhs = self.scalar_constant(rhs_scalar, dtype=ret_sca_ty) if rhs_is_scalar else self.cast(rhs, ret_sca_ty)
  177. # implicit broadcasting
  178. lhs, rhs = self.broadcast_impl_value(lhs, rhs)
  179. return lhs, rhs
  180. def binary_op_sanitize_overflow_impl(self, lhs: TensorTy, rhs: TensorTy, binary_op: callable):
  181. if lhs.type.scalar.int_bitwidth >= 64 or not self.builder.options.sanitize_overflow:
  182. return
  183. lhs_sca_ty = lhs.type.scalar
  184. rhs_sca_ty = rhs.type.scalar
  185. assert lhs_sca_ty == rhs_sca_ty
  186. assert lhs_sca_ty.is_int()
  187. lhs = self.cast(lhs, tl.int64)
  188. rhs = self.cast(rhs, tl.int64)
  189. ret = binary_op(lhs, rhs, False)
  190. max_value = lhs_sca_ty.get_int_max_value()
  191. max_value = self.scalar_constant(max_value, tl.int64)
  192. min_value = lhs_sca_ty.get_int_min_value()
  193. min_value = self.scalar_constant(min_value, tl.int64)
  194. cond = self.and_(self.less_equal(ret, max_value), self.greater_equal(ret, min_value))
  195. msg = f"int{lhs_sca_ty.int_bitwidth} overflow detected for operation {binary_op.__name__}"
  196. self.device_assert(cond, msg, None)
  197. def add(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number,
  198. sanitize_overflow: bool) -> TensorTy:
  199. input, other = self.binary_op_type_checking_impl(input, other, True, True)
  200. input_scalar_ty = input.type.scalar
  201. other_scalar_ty = other.type.scalar
  202. if input_scalar_ty.is_ptr() and other_scalar_ty.is_ptr():
  203. raise TypeError("cannot add pointers together")
  204. # offset + ptr
  205. # ptr + offset
  206. if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr():
  207. input, other = other, input
  208. input_scalar_ty = input.type.scalar
  209. other_scalar_ty = other.type.scalar
  210. if input_scalar_ty.is_ptr():
  211. other_handle = other.handle
  212. if other.dtype.is_int_unsigned() and other.dtype.int_bitwidth < 64:
  213. # addptr treats offset as signed. Zero-extend unsigned offsets to ensure they're positive
  214. i64_ty = other.type.with_element_ty(tl.int64).to_ir(self.builder)
  215. other_handle = self.builder.create_int_cast(other.handle, i64_ty, False)
  216. return self.tensor(self.builder.create_addptr(input.handle, other_handle), input.type)
  217. # float + float
  218. elif input_scalar_ty.is_floating():
  219. return self.tensor(self.builder.create_fadd(input.handle, other.handle), input.type)
  220. # int + int
  221. elif input_scalar_ty.is_int():
  222. if sanitize_overflow:
  223. self.binary_op_sanitize_overflow_impl(input, other, self.add)
  224. return self.tensor(self.builder.create_add(input.handle, other.handle), input.type)
  225. raise TypeError(f"unexpected type {input_scalar_ty}")
  226. def sub(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number,
  227. sanitize_overflow: bool) -> TensorTy:
  228. input, other = self.binary_op_type_checking_impl(input, other, True, False)
  229. scalar_ty = input.type.scalar
  230. # ptr - offset
  231. if scalar_ty.is_ptr():
  232. return self.add(input, self.minus(other), sanitize_overflow=False)
  233. # float - float
  234. if scalar_ty.is_floating():
  235. return self.tensor(self.builder.create_fsub(input.handle, other.handle), input.type)
  236. # int - int
  237. elif scalar_ty.is_int():
  238. if sanitize_overflow:
  239. self.binary_op_sanitize_overflow_impl(input, other, self.sub)
  240. return self.tensor(self.builder.create_sub(input.handle, other.handle), input.type)
  241. raise TypeError(f"unexpected type {scalar_ty}")
  242. def mul(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number,
  243. sanitize_overflow: bool) -> TensorTy:
  244. input, other = self.binary_op_type_checking_impl(input, other)
  245. scalar_ty = input.type.scalar
  246. # float * float
  247. if scalar_ty.is_floating():
  248. return self.tensor(self.builder.create_fmul(input.handle, other.handle), input.type)
  249. # int * int
  250. elif scalar_ty.is_int():
  251. if sanitize_overflow:
  252. self.binary_op_sanitize_overflow_impl(input, other, self.mul)
  253. return self.tensor(self.builder.create_mul(input.handle, other.handle), input.type)
  254. raise TypeError(f"unexpected type {scalar_ty}")
  255. def truediv(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number) -> TensorTy:
  256. input, other = self.binary_op_type_checking_impl(input, other, False, False, True, True)
  257. input_scalar_ty = input.type.scalar
  258. other_scalar_ty = other.type.scalar
  259. # float / int
  260. if input_scalar_ty.is_floating() and other_scalar_ty.is_int():
  261. other = self.cast(other, input_scalar_ty)
  262. # int / float
  263. elif input_scalar_ty.is_int() and other_scalar_ty.is_floating():
  264. input = self.cast(input, other_scalar_ty)
  265. # int / int (cast to tl.float32)
  266. elif input_scalar_ty.is_int() and other_scalar_ty.is_int():
  267. input = self.cast(input, tl.float32)
  268. other = self.cast(other, tl.float32)
  269. # float / float (cast to the highest exponent type)
  270. elif input_scalar_ty.is_floating() and other_scalar_ty.is_floating():
  271. if input_scalar_ty.fp_mantissa_width > other_scalar_ty.fp_mantissa_width:
  272. other = self.cast(other, input_scalar_ty)
  273. else:
  274. input = self.cast(input, other_scalar_ty)
  275. # unreachable
  276. else:
  277. raise TypeError(f"unexpected type {input_scalar_ty}")
  278. return self.tensor(self.builder.create_fdiv(input.handle, other.handle), input.type)
  279. def floordiv(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number) -> TensorTy:
  280. input, other = self.binary_op_type_checking_impl(input, other, False, False, True, True)
  281. input_scalar_ty = input.type.scalar
  282. other_scalar_ty = other.type.scalar
  283. if input_scalar_ty.is_int() and other_scalar_ty.is_int():
  284. ret_ty = self.integer_promote_impl(input_scalar_ty, other_scalar_ty)
  285. input = self.cast(input, ret_ty)
  286. other = self.cast(other, ret_ty)
  287. if ret_ty.is_int_signed():
  288. return self.tensor(self.builder.create_sdiv(input.handle, other.handle), input.type)
  289. else:
  290. return self.tensor(self.builder.create_udiv(input.handle, other.handle), input.type)
  291. raise TypeError(f"unexpected type {input_scalar_ty}")
  292. def fdiv(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number, ieee_rounding: bool) -> TensorTy:
  293. input_scalar_ty = input.type.scalar
  294. other_scalar_ty = other.type.scalar
  295. if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating():
  296. raise TypeError("both operands of fdiv must have floating scalar type")
  297. input, other = self.binary_op_type_checking_impl(input, other, False, False, False, True)
  298. ret = self.builder.create_fdiv(input.handle, other.handle)
  299. return self.tensor(ret, input.type)
  300. def mod(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number) -> TensorTy:
  301. input, other = self.binary_op_type_checking_impl(input, other, False, False, True, True)
  302. scalar_ty = input.type.scalar
  303. other_scalar_ty = other.type.scalar
  304. # float % float
  305. if scalar_ty.is_floating():
  306. return self.tensor(self.builder.create_frem(input.handle, other.handle), input.type)
  307. # % int
  308. elif scalar_ty.is_int():
  309. if scalar_ty.int_signedness != other_scalar_ty.int_signedness:
  310. raise TypeError("Cannot mod " + scalar_ty.__repr__() + " by " + other_scalar_ty.__repr__() + " "
  311. "because they have different signedness;"
  312. "this is unlikely to result in a useful answer. Cast them to the same signedness.")
  313. if scalar_ty.is_int_signed():
  314. return self.tensor(self.builder.create_srem(input.handle, other.handle), input.type)
  315. else:
  316. return self.tensor(self.builder.create_urem(input.handle, other.handle), input.type)
  317. raise TypeError(f"unexpected type {scalar_ty}")
  318. ##############
  319. # other arithmetic ops
  320. ##############
  321. def minimum(self, x: TensorTy, y: TensorTy, propagate_nan: tl.PropagateNan):
  322. x, y = self.binary_op_type_checking_impl(x, y)
  323. dtype = x.dtype
  324. if dtype.is_floating():
  325. if propagate_nan == tl.PropagateNan.ALL:
  326. return self.tensor(self.builder.create_minimumf(x.handle, y.handle), x.type)
  327. elif propagate_nan == tl.PropagateNan.NONE:
  328. return self.tensor(self.builder.create_minnumf(x.handle, y.handle), x.type)
  329. else:
  330. raise ValueError(f"Unexpected propagate_nan {propagate_nan}")
  331. elif dtype.is_int_signed():
  332. return self.tensor(self.builder.create_minsi(x.handle, y.handle), x.type)
  333. elif dtype.is_int_unsigned():
  334. return self.tensor(self.builder.create_minui(x.handle, y.handle), x.type)
  335. else:
  336. raise TypeError(f"Unexpected dtype {dtype}")
  337. def maximum(self, x: TensorTy, y: TensorTy, propagate_nan: tl.PropagateNan):
  338. x, y = self.binary_op_type_checking_impl(x, y)
  339. dtype = x.dtype
  340. if dtype.is_floating():
  341. if propagate_nan == tl.PropagateNan.ALL:
  342. return self.tensor(self.builder.create_maximumf(x.handle, y.handle), x.type)
  343. elif propagate_nan == tl.PropagateNan.NONE:
  344. return self.tensor(self.builder.create_maxnumf(x.handle, y.handle), x.type)
  345. else:
  346. raise ValueError(f"Unexpected propagate_nan {propagate_nan}")
  347. elif dtype.is_int_signed():
  348. return self.tensor(self.builder.create_maxsi(x.handle, y.handle), x.type)
  349. elif dtype.is_int_unsigned():
  350. return self.tensor(self.builder.create_maxui(x.handle, y.handle), x.type)
  351. else:
  352. raise TypeError(f"Unexpected dtype {dtype}")
  353. def clamp(self, x: TensorTy, min: TensorTy, max: TensorTy, propagate_nan: tl.PropagateNan):
  354. min, max = self.binary_op_type_checking_impl(min, max)
  355. x, min = self.binary_op_type_checking_impl(x, min)
  356. x, max = self.binary_op_type_checking_impl(x, max)
  357. dtype = x.dtype
  358. if dtype.is_floating():
  359. return self.tensor(self.builder.create_clampf(x.handle, min.handle, max.handle, propagate_nan), x.type)
  360. else:
  361. raise TypeError(f"Unexpected dtype {dtype}. Only floating point clamp is supported")
  362. ##############
  363. # bitwise ops
  364. ##############
  365. def bitwise_op_type_checking_impl(self, input: TensorTy, other: TensorTy) -> Tuple[TensorTy, TensorTy]:
  366. input, other = self.binary_op_type_checking_impl(input, other)
  367. input_sca_ty = input.type.scalar
  368. other_sca_ty = other.type.scalar
  369. if not input_sca_ty.is_int() or not other_sca_ty.is_int():
  370. raise IncompatibleTypeErrorImpl(input_sca_ty, other_sca_ty)
  371. ret_sca_ty = self.integer_promote_impl(input_sca_ty, other_sca_ty)
  372. if ret_sca_ty != input_sca_ty:
  373. input = self.cast(input, ret_sca_ty)
  374. if ret_sca_ty != other_sca_ty:
  375. other = self.cast(other, ret_sca_ty)
  376. return input, other
  377. def and_(self, input: TensorTy, other: TensorTy) -> TensorTy:
  378. input, other = self.bitwise_op_type_checking_impl(input, other)
  379. return self.tensor(self.builder.create_and(input.handle, other.handle), input.type)
  380. def or_(self, input: TensorTy, other: TensorTy) -> TensorTy:
  381. input, other = self.bitwise_op_type_checking_impl(input, other)
  382. return self.tensor(self.builder.create_or(input.handle, other.handle), input.type)
  383. def xor_(self, input: TensorTy, other: TensorTy) -> TensorTy:
  384. input, other = self.bitwise_op_type_checking_impl(input, other)
  385. return self.tensor(self.builder.create_xor(input.handle, other.handle), input.type)
  386. def logical_and(self, input: TensorTy, other: TensorTy) -> TensorTy:
  387. if not input.type.is_int1():
  388. input = self.bitcast(input, tl.int1)
  389. if not other.type.is_int1():
  390. other = self.bitcast(other, tl.int1)
  391. return self.and_(input, other)
  392. def logical_or(self, input: TensorTy, other: TensorTy) -> TensorTy:
  393. if not input.type.is_int1():
  394. input = self.bitcast(input, tl.int1)
  395. if not other.type.is_int1():
  396. other = self.bitcast(other, tl.int1)
  397. return self.or_(input, other)
  398. def not_(self, input: TensorTy):
  399. if not input.type.is_int1():
  400. input = self.bitcast(input, tl.int1)
  401. return self.invert(input)
  402. def lshr(self, input: TensorTy, other: TensorTy) -> TensorTy:
  403. input, other = self.bitwise_op_type_checking_impl(input, other)
  404. return self.tensor(self.builder.create_lshr(input.handle, other.handle), input.type)
  405. def ashr(self, input: TensorTy, other: TensorTy) -> TensorTy:
  406. input, other = self.bitwise_op_type_checking_impl(input, other)
  407. return self.tensor(self.builder.create_ashr(input.handle, other.handle), input.type)
  408. def shl(self, input: TensorTy, other: TensorTy) -> TensorTy:
  409. input, other = self.bitwise_op_type_checking_impl(input, other)
  410. return self.tensor(self.builder.create_shl(input.handle, other.handle), input.type)
  411. # ===----------------------------------------------------------------------===//
  412. # Unary Operators
  413. # ===----------------------------------------------------------------------===//
  414. def plus(self, input: TensorTy) -> TensorTy:
  415. return input
  416. def minus(self, input: TensorTy) -> TensorTy:
  417. input_sca_ty = input.type.scalar
  418. if input_sca_ty.is_ptr():
  419. raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")")
  420. _0 = self.tensor(self.builder.get_null_value(input_sca_ty.to_ir(self.builder)), input_sca_ty)
  421. return self.sub(_0, input, True)
  422. def invert(self, input: TensorTy) -> TensorTy:
  423. input_sca_ty = input.type.scalar
  424. if input_sca_ty.is_ptr() or input_sca_ty.is_floating():
  425. raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")")
  426. _1 = self.tensor(self.builder.get_all_ones_value(input_sca_ty.to_ir(self.builder)), input_sca_ty)
  427. return self.xor_(input, _1)
  428. # ===----------------------------------------------------------------------===//
  429. # Comparison Operators
  430. # ===----------------------------------------------------------------------===//
  431. def _bool_like(self, v: TensorTy) -> tl.block_type:
  432. return v.type.with_element_ty(tl.int1)
  433. def greater_than(self, input: TensorTy, other: TensorTy) -> TensorTy:
  434. input, other = self.binary_op_type_checking_impl(input, other)
  435. scalar_ty = input.type.scalar
  436. # float > float
  437. if scalar_ty.is_floating():
  438. return self.tensor(self.builder.create_fcmpOGT(input.handle, other.handle), self._bool_like(input))
  439. # > int
  440. elif scalar_ty.is_int():
  441. if scalar_ty.is_int_signed():
  442. return self.tensor(self.builder.create_icmpSGT(input.handle, other.handle), self._bool_like(input))
  443. else:
  444. return self.tensor(self.builder.create_icmpUGT(input.handle, other.handle), self._bool_like(input))
  445. raise TypeError(f"unexpected type {scalar_ty}")
  446. def greater_equal(self, input: TensorTy, other: TensorTy) -> TensorTy:
  447. input, other = self.binary_op_type_checking_impl(input, other)
  448. scalar_ty = input.type.scalar
  449. # float >= float
  450. if scalar_ty.is_floating():
  451. return self.tensor(self.builder.create_fcmpOGE(input.handle, other.handle), self._bool_like(input))
  452. # >= int
  453. elif scalar_ty.is_int():
  454. if scalar_ty.is_int_signed():
  455. return self.tensor(self.builder.create_icmpSGE(input.handle, other.handle), self._bool_like(input))
  456. else:
  457. return self.tensor(self.builder.create_icmpUGE(input.handle, other.handle), self._bool_like(input))
  458. raise TypeError(f"unexpected type {scalar_ty}")
  459. def less_than(self, input: TensorTy, other: TensorTy) -> TensorTy:
  460. input, other = self.binary_op_type_checking_impl(input, other)
  461. scalar_ty = input.type.scalar
  462. # float < float
  463. if scalar_ty.is_floating():
  464. return self.tensor(self.builder.create_fcmpOLT(input.handle, other.handle), self._bool_like(input))
  465. # < int
  466. elif scalar_ty.is_int():
  467. if scalar_ty.is_int_signed():
  468. return self.tensor(self.builder.create_icmpSLT(input.handle, other.handle), self._bool_like(input))
  469. else:
  470. return self.tensor(self.builder.create_icmpULT(input.handle, other.handle), self._bool_like(input))
  471. raise TypeError(f"unexpected type {scalar_ty}")
  def less_equal(self, input: TensorTy, other: TensorTy) -> TensorTy:
      """Emit the element-wise comparison ``input <= other``.

      Both operands first go through ``binary_op_type_checking_impl``; the
      comparison builder is then chosen from the scalar type of the checked
      ``input``.

      :param input: left-hand operand.
      :param other: right-hand operand.
      :return: tensor wrapping the comparison, typed by ``self._bool_like(input)``.
      :raises TypeError: if the scalar type is neither floating point nor integer.
      """
      input, other = self.binary_op_type_checking_impl(input, other)
      elem_ty = input.type.scalar
      if elem_ty.is_floating():
          # float <= float
          emit_cmp = self.builder.create_fcmpOLE
      elif elem_ty.is_int():
          # int <= int: signedness selects the signed or unsigned builder
          emit_cmp = self.builder.create_icmpSLE if elem_ty.is_int_signed() else self.builder.create_icmpULE
      else:
          raise TypeError(f"unexpected type {elem_ty}")
      return self.tensor(emit_cmp(input.handle, other.handle), self._bool_like(input))
  def equal(self, input: TensorTy, other: TensorTy) -> TensorTy:
      """Emit the element-wise comparison ``input == other``.

      Both operands first go through ``binary_op_type_checking_impl``.
      Floating-point operands use the builder's ``create_fcmpOEQ``; integer
      operands use ``create_icmpEQ`` regardless of signedness.

      :param input: left-hand operand.
      :param other: right-hand operand.
      :return: tensor wrapping the comparison, typed by ``self._bool_like(input)``.
      :raises TypeError: if the scalar type is neither floating point nor integer.
      """
      input, other = self.binary_op_type_checking_impl(input, other)
      elem_ty = input.type.scalar
      if elem_ty.is_floating():
          # float == float
          emit_cmp = self.builder.create_fcmpOEQ
      elif elem_ty.is_int():
          # int == int
          emit_cmp = self.builder.create_icmpEQ
      else:
          raise TypeError(f"unexpected type {elem_ty}")
      return self.tensor(emit_cmp(input.handle, other.handle), self._bool_like(input))
  def not_equal(self, input: TensorTy, other: TensorTy) -> TensorTy:
      """Emit the element-wise comparison ``input != other``.

      Both operands first go through ``binary_op_type_checking_impl``.
      Floating-point operands use the builder's ``create_fcmpUNE``; integer
      operands use ``create_icmpNE`` regardless of signedness.

      :param input: left-hand operand.
      :param other: right-hand operand.
      :return: tensor wrapping the comparison, typed by ``self._bool_like(input)``.
      :raises TypeError: if the scalar type is neither floating point nor integer.
      """
      input, other = self.binary_op_type_checking_impl(input, other)
      elem_ty = input.type.scalar
      if elem_ty.is_floating():
          # float != float
          emit_cmp = self.builder.create_fcmpUNE
      elif elem_ty.is_int():
          # int != int
          emit_cmp = self.builder.create_icmpNE
      else:
          raise TypeError(f"unexpected type {elem_ty}")
      return self.tensor(emit_cmp(input.handle, other.handle), self._bool_like(input))
  505. # ===----------------------------------------------------------------------===//
  506. # Block Creation
  507. # ===----------------------------------------------------------------------===//
  def arange(self, start: int, end: int, *, ret_ty: tl.block_type = None) -> TensorTy:
      """Create a 1D block holding the half-open integer range ``[start, end)``.

      :param start: first value of the range; must be a Python ``int``.
      :param end: one past the last value; must be a Python ``int``.
      :param ret_ty: optional result block type; defaults to an ``int32``
          block of shape ``[end - start]``.
      :return: tensor wrapping the builder's ``create_make_range`` value.
      :raises ValueError: if a bound is not an ``int``, if a bound has bits
          left after shifting out the low 32, if ``end <= start``, or if
          ``end - start`` is not a power of two.
      """
      if not (isinstance(start, int) and isinstance(end, int)):
          raise ValueError("arange's arguments must be of type tl.constexpr")
      # Python's `>>` is an arithmetic shift, so negative bounds also leave a
      # non-zero remainder here and are rejected by this check.
      start_too_wide = (start >> 32) != 0
      end_too_wide = (end >> 32) != 0
      if start_too_wide or end_too_wide:
          raise ValueError("arange must fit in int32")
      if end <= start:
          raise ValueError("arange's end argument must be greater than the start argument")
      extent = end - start
      # A power of two shares no set bit with its predecessor.
      if extent & (extent - 1):
          raise ValueError("arange's range must be a power of 2")
      if ret_ty is None:
          ret_ty = tl.block_type(tl.int32, [extent])
      range_handle = self.builder.create_make_range(ret_ty.to_ir(self.builder), start, end)
      return self.tensor(range_handle, ret_ty)
  def scalar_constant(self, value, dtype: tl.dtype) -> TensorTy:
      """Materialise a Python scalar as a constant tensor of type ``dtype``.

      :param value: the Python value to embed.
      :param dtype: target element type; required.
      :return: tensor wrapping the constant.
      :raises ValueError: if ``dtype`` is ``None``.
      """
      if dtype is None:
          raise ValueError("dtype must be specified when value is not a tensor")
      if value == 0:
          # NOTE(review): anything comparing equal to 0 (False, 0.0, -0.0) takes
          # the null-value path -- confirm that dropping the sign of -0.0 is intended.
          handle = self.builder.get_null_value(dtype.to_ir(self.builder))
      else:
          # Dispatch to the builder getter named after the dtype, i.e. ``get_<dtype.name>``.
          make_constant = getattr(self.builder, f"get_{dtype.name}")
          handle = make_constant(value)
      return self.tensor(handle, dtype)
  def make_scalar(self, value, dtype: tl.dtype) -> TensorTy:
      """Turn ``value`` into a scalar tensor of type ``dtype``.

      :param value: either a ``tl.tensor`` with exactly one element or a plain
          Python scalar.
      :param dtype: target element type.
      :return: ``self.cast(value, dtype)`` for tensors, otherwise
          ``self.scalar_constant(value, dtype)``.
      """
      if not isinstance(value, tl.tensor):
          # plain Python scalar
          return self.scalar_constant(value, dtype)
      assert value.numel.value == 1, "only accepts size-1 tensor"
      return self.cast(value, dtype)
  def full(self, shape: List[int], value, dtype: tl.dtype) -> TensorTy:
      """Create a tensor of the given ``shape`` filled with ``value``.

      :param shape: target shape, forwarded to ``splat``.
      :param value: fill value, converted by ``make_scalar``.
      :param dtype: element type of the result.
      :return: the scalar from ``make_scalar`` splatted to ``shape``.
      """
      fill = self.make_scalar(value, dtype)
      return self.splat(fill, shape)
  543. # ===----------------------------------------------------------------------===//
  544. # Shape Manipulation
  545. # ===----------------------------------------------------------------------===//
  def splat(self, value: TensorTy, shape: List[int]) -> TensorTy:
      """Replicate a non-block ``value`` into a block of the given ``shape``.

      :param value: tensor to replicate; must not have a block type.
      :param shape: target block shape; an empty shape returns ``value`` as is.
      :return: ``value`` itself for an empty shape, otherwise a tensor wrapping
          the builder's ``create_splat`` value.
      """
      assert not value.type.is_block(), "Cannot splat a block tensor"
      if len(shape) == 0:
          return value
      block_ty = tl.block_type(value.dtype, shape)
      splat_handle = self.builder.create_splat(block_ty.to_ir(self.builder), value.handle)
      return self.tensor(splat_handle, block_ty)
  def unsplat(self, value: TensorTy) -> TensorTy:
      """Wrap the builder's ``create_unsplat`` of ``value`` as a tensor typed ``value.dtype``."""
      scalar_handle = self.builder.create_unsplat(value.handle)
      return self.tensor(scalar_handle, value.dtype)
  def reshape(self, input: TensorTy, dst_shape: List[int], can_reorder: bool) -> TensorTy:
      """Reshape ``input`` to ``dst_shape`` without changing its element count.

      :param input: tensor to reshape.
      :param dst_shape: target shape; its product must equal ``input.type.numel``.
      :param can_reorder: forwarded unchanged to the builder's ``create_reshape``.
      :return: tensor of block type ``(input.type.scalar, dst_shape)``.
      :raises ValueError: if the element counts differ.
      """
      total = 1
      for dim in dst_shape:
          total *= dim
      if input.type.numel != total:
          raise ValueError("reshape() cannot change total number of elements in tensor")
      reshaped_ty = tl.block_type(input.type.scalar, dst_shape)
      reshaped = self.builder.create_reshape(input.handle, dst_shape, can_reorder)
      return self.tensor(reshaped, reshaped_ty)
  def expand_dims(self, input: TensorTy, axis: int) -> TensorTy:
      """Insert a new dimension of size 1 at position ``axis``.

      :param input: tensor to expand.
      :param axis: index at which the size-1 dimension is inserted into the
          shape list (``list.insert`` semantics).
      :return: for block inputs, a tensor wrapping ``create_expand_dims``;
          non-block inputs are splatted to the new shape instead.
      """
      expanded_shape = [tl._unwrap_if_constexpr(dim) for dim in input.shape]
      expanded_shape.insert(axis, 1)
      if input.type.is_block():
          expanded_ty = tl.block_type(input.type.scalar, expanded_shape)
          return self.tensor(self.builder.create_expand_dims(input.handle, axis), expanded_ty)
      # A non-block value has no axes to shift, so it is splatted instead.
      return self.splat(input, shape=expanded_shape)
  def cat(self, lhs: TensorTy, rhs: TensorTy, can_reorder: bool) -> TensorTy:
      """Concatenate two tensors along their single dimension.

      :param lhs: first operand; asserted to be 1D.
      :param rhs: second operand; only ``rhs.shape[0]`` is read here.
      :param can_reorder: must be truthy, since this implementation may
          reorder elements.
      :return: tensor of element type ``lhs.type.scalar`` whose length is the
          sum of both leading dimensions.
      """
      assert can_reorder, "current implementation of `cat` always may reorder elements"
      assert len(lhs.shape) == 1
      total_len = lhs.shape[0] + rhs.shape[0]
      cat_ty = tl.block_type(lhs.type.scalar, [total_len])
      cat_handle = self.builder.create_cat(lhs.handle, rhs.handle)
      return self.tensor(cat_handle, cat_ty)
  def join(self, a: TensorTy, b: TensorTy) -> TensorTy:
      """Join ``a`` and ``b`` along a new trailing dimension of size 2.

      The operands are first broadcast against each other with
      ``broadcast_impl_value``.

      :param a: first operand.
      :param b: second operand.
      :return: tensor of shape ``a.shape + [2]`` (after broadcasting); when
          both operands were scalars the result is reshaped to ``[2]``.
      """
      a, b = self.broadcast_impl_value(a, b)

      # The IR can't handle joining two scalars, so lift them to 1D tensors
      # first and collapse the result again afterwards.
      operands_were_scalar = a.shape == []
      if operands_were_scalar:
          a = self.expand_dims(a, 0)
          b = self.expand_dims(b, 0)

      # Keep the new dimension in the same representation as the existing ones.
      two = tl.constexpr(2) if isinstance(a.shape[-1], tl.constexpr) else 2

      joined_ty = tl.block_type(a.type.scalar, a.shape + [two])
      joined = self.tensor(self.builder.create_join(a.handle, b.handle), joined_ty)
      if operands_were_scalar:
          return self.reshape(joined, [2], can_reorder=False)
      return joined
  def split(self, a: TensorTy) -> Tuple[TensorTy, TensorTy]:
      """Split ``a`` along its last dimension, which must have size 2.

      :param a: tensor with at least one dimension whose last dimension is 2.
      :return: pair of tensors, each with shape ``a.shape[:-1]`` and element
          type ``a.type.scalar``.
      """
      assert (len(a.shape) > 0)
      assert (tl._unwrap_if_constexpr(a.shape[-1]) == 2)

      half_ty = tl.block_type(a.type.scalar, a.shape[:-1])
      lhs_handle, rhs_handle = self.builder.create_split(a.handle)
      first = self.tensor(lhs_handle, half_ty)
      second = self.tensor(rhs_handle, half_ty)
      return (first, second)
  def permute(self, input: TensorTy, dims: Tuple[int]) -> TensorTy:
      """Reorder the dimensions of ``input`` according to ``dims``.

      :param input: tensor to permute.
      :param dims: the new order; must be a permutation of ``0 .. n-1`` where
          ``n`` is the rank of ``input``.
      :return: tensor whose ``i``-th dimension is ``input.shape[dims[i]]``.
      :raises ValueError: if ``dims`` has the wrong length or is not a permutation.
      """
      rank = len(dims)
      if len(input.shape) != rank:
          raise ValueError("permute dims must have the same length as input shape")
      ordered = sorted(tl._unwrap_if_constexpr(d) for d in dims)
      if ordered != list(range(rank)):
          raise ValueError(f"permute dims must be a permutation of 0, 1, ..., n-1, but were {dims}")
      permuted_shape = [input.shape[d] for d in dims]
      permuted_ty = tl.block_type(input.type.scalar, permuted_shape)
      return self.tensor(self.builder.create_trans(input.handle, dims), permuted_ty)
  def broadcast_impl_shape(self, input: TensorTy, shape: Tuple[int]) -> TensorTy:
      """Broadcast ``input`` to the explicit target ``shape``.

      :param input: tensor to broadcast; non-block inputs are splatted.
      :param shape: target shape; must have the same rank as a block input.
      :return: ``input`` itself when the shapes already match, otherwise the
          splatted or broadcast tensor.
      :raises ValueError: on rank mismatch, or when a dimension differs from
          the target and is not 1.
      """
      if not input.type.is_block():
          return self.splat(input, shape)

      src_shape = input.type.get_block_shapes()
      if len(src_shape) != len(shape):
          raise ValueError(f"Cannot broadcast, rank mismatch: {src_shape}, {shape}")
      if shape == src_shape:
          return input

      # Only size-1 dimensions may be stretched to the target size.
      for i, src_dim in enumerate(src_shape):
          if shape[i] != src_dim and src_dim != 1:
              raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})"
                               f" must match the existing size ({src_dim}) at non-singleton dimension"
                               f" {i}: {src_shape}, {shape}")

      broadcast_ty = tl.block_type(input.type.scalar, shape)
      return self.tensor(self.builder.create_broadcast(input.handle, shape), broadcast_ty)
  def broadcast_impl_value(self, lhs: TensorTy, rhs: TensorTy) -> TensorTy:
      """Make ``lhs`` and ``rhs`` shape-compatible with each other.

      * (block, scalar) / (scalar, block): the scalar is splatted to the
        block's shape, keeping its own element type.
      * (block, block): the lower-rank operand gets leading size-1 axes, then
        each operand is broadcast to the merged shape if it differs from it.
      * (scalar, scalar): both operands are returned untouched.

      :param lhs: left operand.
      :param rhs: right operand.
      :return: the pair ``(lhs, rhs)`` after reconciliation.
      :raises ValueError: if two block dimensions differ and neither is 1.
      """
      lhs_ty = lhs.type
      rhs_ty = rhs.type
      lhs_is_block = lhs_ty.is_block()
      rhs_is_block = rhs_ty.is_block()

      if not lhs_is_block and not rhs_is_block:
          # (scalar, scalar) => returns original operands
          return lhs, rhs

      if lhs_is_block and not rhs_is_block:
          # make_shape_compatible(block, scalar)
          rhs_ty = lhs_ty.with_element_ty(rhs_ty.scalar)
          rhs = self.tensor(self.builder.create_splat(rhs_ty.to_ir(self.builder), rhs.handle), rhs_ty)
          return lhs, rhs

      if rhs_is_block and not lhs_is_block:
          # make_shape_compatible(scalar, block)
          lhs_ty = rhs_ty.with_element_ty(lhs_ty.scalar)
          lhs = self.tensor(self.builder.create_splat(lhs_ty.to_ir(self.builder), lhs.handle), lhs_ty)
          return lhs, rhs

      # make_shape_compatible(block, block)
      def pad_rank(operand, target_rank):
          # Prepend size-1 axes until `operand` reaches `target_rank`.
          ty = operand.type
          dims = ty.get_block_shapes()
          for _ in range(len(dims), target_rank):
              operand = self.tensor(self.builder.create_expand_dims(operand.handle, 0),
                                    tl.block_type(ty.scalar, [1] + dims.values))
              ty = operand.type
              dims = ty.get_block_shapes()
          return operand

      lhs_rank = len(lhs_ty.get_block_shapes())
      rhs_rank = len(rhs_ty.get_block_shapes())
      # At most one of these actually adds axes; the other call is a no-op.
      lhs = pad_rank(lhs, rhs_rank)
      rhs = pad_rank(rhs, lhs_rank)
      lhs_ty = lhs.type
      rhs_ty = rhs.type
      lhs_shape = lhs_ty.get_block_shapes()
      rhs_shape = rhs_ty.get_block_shapes()
      assert len(rhs_shape) == len(lhs_shape)

      merged_shape = []
      for i, left in enumerate(lhs_shape):
          right = rhs_shape[i]
          if left == 1:
              merged_shape.append(right)
              continue
          if not ((right == 1) or (right == left)):
              raise ValueError("Cannot make_shape_compatible: incompatible dimensions "
                               f"at index {i!s}: {left!s} and {right!s}")
          merged_shape.append(left)

      if lhs_shape != merged_shape:
          lhs_bcast_ty = tl.block_type(lhs_ty.scalar, merged_shape)
          lhs = self.tensor(self.builder.create_broadcast(lhs.handle, merged_shape), lhs_bcast_ty)
      if rhs_shape != merged_shape:
          rhs_bcast_ty = tl.block_type(rhs_ty.scalar, merged_shape)
          rhs = self.tensor(self.builder.create_broadcast(rhs.handle, merged_shape), rhs_bcast_ty)
      return lhs, rhs
  672. #######
  673. # cast
  674. #######
  def _str_to_rounding_mode(self, rounding_mode: Optional[str]):
      """Translate a rounding-mode name into its ``ir.ROUNDING_MODE`` value.

      :param rounding_mode: ``'rtne'``, ``'rtz'`` or ``None``.
      :return: the matching ``ir.ROUNDING_MODE`` member, or ``None`` when no
          mode was given.
      :raises ValueError: for any other value.
      """
      if rounding_mode is None:
          mode = None
      elif rounding_mode == 'rtne':
          mode = ir.ROUNDING_MODE.RTNE
      elif rounding_mode == 'rtz':
          mode = ir.ROUNDING_MODE.RTZ
      else:
          raise ValueError(f"Invalid rounding mode: {rounding_mode}. Supported rounding modes are 'rtne' and 'rtz'.")
      return mode
  def bitcast(self, input: TensorTy, dst_ty: tl.dtype) -> TensorTy:
      """Reinterpret the bits of ``input`` as ``dst_ty``.

      :param input: tensor to reinterpret.
      :param dst_ty: target type; for block inputs only its scalar part is
          used and the input's block shape is kept.
      :return: ``input`` itself when the types already match; the result of
          ``self.cast`` when either side is a pointer; otherwise a tensor
          wrapping the builder's ``create_bitcast``.
      :raises ValueError: if source and destination bit widths differ.
      """
      src_ty = input.type
      if src_ty.is_block():
          dst_ty = src_ty.with_element_ty(dst_ty.scalar)
      if src_ty == dst_ty:
          return input

      src_sca_ty = src_ty.scalar
      dst_sca_ty = dst_ty.scalar
      # Pointer conversions are delegated to the regular cast logic.
      if src_sca_ty.is_ptr() or dst_sca_ty.is_ptr():
          return self.cast(input, dst_ty)

      # Plain bitcast: both sides must be equally wide.
      src_bits = src_sca_ty.primitive_bitwidth
      dst_bits = dst_sca_ty.primitive_bitwidth
      if src_bits != dst_bits:
          raise ValueError(f"Cannot bitcast data-type of size {src_bits!s} to data-type of size {dst_bits!s}")
      reinterpreted = self.builder.create_bitcast(input.handle, dst_ty.to_ir(self.builder))
      return self.tensor(reinterpreted, dst_ty)
  def cast(self, input: TensorTy, dst_ty: tl.dtype, fp_downcast_rounding: Optional[str] = None) -> TensorTy:
      """Convert ``input`` to ``dst_ty``, choosing the conversion from the scalar types.

      The cases are tried in this order and the first match wins:

      1. identical scalar types: ``input`` is returned unchanged;
      2. either side is ``fp8e4b15``: the target's ``convert_custom_types`` hook;
      3. fp8 <-> float, or a downcast with a non-RTNE rounding mode: ``create_fp_to_fp``;
      4. fp16/bf16 source with a non-fp32 destination: routed through ``tl.float32``;
      5. float truncation / float extension;
      6. integer <-> integer (to bool is a ``!= 0`` comparison);
      7. standard float -> integer (to bool is a ``!= 0`` comparison);
      8. integer -> standard float;
      9. pointer -> 64-bit integer, pointer -> bool, integer -> pointer, pointer -> pointer.

      :param input: tensor to convert.
      :param dst_ty: target type; for block inputs only its scalar part is
          used and the input's block shape is kept.
      :param fp_downcast_rounding: ``'rtne'``, ``'rtz'`` or ``None``; only
          allowed for float-to-narrower-float conversions.
      :return: the converted tensor.
      :raises ValueError: if ``fp_downcast_rounding`` is invalid or given for a
          conversion that is not a truncating float conversion.
      :raises AssertionError: if no case applies.
      """
      src_ty = input.type
      src_sca_ty = src_ty.scalar
      dst_sca_ty = dst_ty.scalar
      if src_sca_ty == dst_sca_ty:
          return input
      if src_ty.is_block():
          dst_ty = src_ty.with_element_ty(dst_sca_ty)

      def compare_to_zero():
          # Conversion to bool is expressed as `input != 0` in the input's own dtype.
          zero = self.tensor(self.builder.get_null_value(input.dtype.to_ir(self.builder)), input.dtype)
          return self.not_equal(input, zero)

      # For fp downcasting the default rounding mode is RTNE; for every other
      # conversion no rounding mode may be set.
      rounding = self._str_to_rounding_mode(fp_downcast_rounding)
      needs_custom_rounding = False
      is_fp_downcast = (dst_sca_ty.is_floating() and src_sca_ty.is_floating()
                        and dst_sca_ty.primitive_bitwidth < src_sca_ty.primitive_bitwidth)
      if is_fp_downcast:
          if rounding is None:
              rounding = ir.ROUNDING_MODE.RTNE
          elif rounding != ir.ROUNDING_MODE.RTNE:
              needs_custom_rounding = True
      elif rounding is not None:
          raise ValueError("fp_downcast_rounding should be set only for truncating fp conversions. "
                           f"Source scalar type is {src_sca_ty!s} and destination type is {dst_sca_ty!s}")

      # fp8e4b15 is converted by a target-provided hook.
      if src_sca_ty.is_fp8e4b15() or dst_sca_ty.is_fp8e4b15():
          assert self.builder.codegen_fns.get(
              "convert_custom_types") is not None, "target doesn't provide conversion for this type."
          return self.builder.codegen_fns["convert_custom_types"](input, dst_ty, rounding, _semantic=self)

      # Customized floating types (fp8 <=> bf16, fp16, fp32, fp64) and
      # non-default rounding modes for downcasting.
      fp8_involved = (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \
          (src_sca_ty.is_floating() and dst_sca_ty.is_fp8())
      if fp8_involved or needs_custom_rounding:
          return self.tensor(self.builder.create_fp_to_fp(input.handle, dst_ty.to_ir(self.builder), rounding),
                             dst_ty)

      # fp16 / bf16 => anything but fp32: go through fp32 first.
      if (src_sca_ty.is_fp16() or src_sca_ty.is_bf16()) and not dst_sca_ty.is_fp32():
          widened = self.cast(input, tl.float32)
          return self.cast(widened, dst_sca_ty)

      # Standard floating types: truncation
      # fp64 => fp32, fp16, bf16
      # fp32 => fp16, bf16
      if is_fp_downcast:
          return self.tensor(self.builder.create_fp_trunc(input.handle, dst_ty.to_ir(self.builder)), dst_ty)

      # Standard floating types: extension
      # fp32 => fp64
      # fp16 => fp32, fp64
      # bf16 => fp32, fp64
      is_fp_upcast = (src_sca_ty.is_floating() and dst_sca_ty.is_floating()
                      and src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth)
      if is_fp_upcast:
          return self.tensor(self.builder.create_fp_ext(input.handle, dst_ty.to_ir(self.builder)), dst_ty)

      # Integer => integer with a different width or signedness.
      if src_sca_ty.is_int() and dst_sca_ty.is_int() and \
              (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth
               or src_sca_ty.int_signedness != dst_sca_ty.int_signedness):
          if dst_sca_ty.is_bool():
              return compare_to_zero()
          # bool sources are never sign-extended.
          sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool()
          return self.tensor(self.builder.create_int_cast(input.handle, dst_ty.to_ir(self.builder), sign_extend),
                             dst_ty)

      # Standard floating types => integer.
      if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int():
          if dst_sca_ty.is_bool():
              return compare_to_zero()
          if dst_sca_ty.is_int_signed():
              return self.tensor(self.builder.create_fp_to_si(input.handle, dst_ty.to_ir(self.builder)), dst_ty)
          return self.tensor(self.builder.create_fp_to_ui(input.handle, dst_ty.to_ir(self.builder)), dst_ty)

      # Integer => standard floating types; bool and unsigned use the unsigned conversion.
      if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating():
          if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed():
              return self.tensor(self.builder.create_ui_to_fp(input.handle, dst_ty.to_ir(self.builder)), dst_ty)
          return self.tensor(self.builder.create_si_to_fp(input.handle, dst_ty.to_ir(self.builder)), dst_ty)

      # Pointer => integer: only 64-bit and 1-bit destinations are handled;
      # other widths fall through to the final assertion.
      if src_sca_ty.is_ptr() and dst_sca_ty.is_int():
          bitwidth = dst_sca_ty.int_bitwidth
          if bitwidth == 64:
              return self.tensor(self.builder.create_ptr_to_int(input.handle, dst_ty.to_ir(self.builder)), dst_ty)
          if bitwidth == 1:
              as_int64 = self.cast(input, tl.int64)
              return self.not_equal(as_int64, self.tensor(self.builder.get_int64(0), tl.int64))

      # Integer => pointer.
      if src_sca_ty.is_int() and dst_sca_ty.is_ptr():
          return self.tensor(self.builder.create_int_to_ptr(input.handle, dst_ty.to_ir(self.builder)), dst_ty)

      # Pointer => pointer.
      if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr():
          return self.tensor(self.builder.create_bitcast(input.handle, dst_ty.to_ir(self.builder)), dst_ty)

      assert False, f'cannot cast {input} to {dst_ty}'
  794. # ===----------------------------------------------------------------------===//
  795. # Memory Operators
  796. # ===----------------------------------------------------------------------===//
  def _str_to_load_cache_modifier(self, cache_modifier):
      """Translate a load cache-modifier string into ``ir.CACHE_MODIFIER``.

      :param cache_modifier: ``".ca"``, ``".cg"``, ``".cv"``, or a falsy value
          for the default.
      :return: the matching member; ``ir.CACHE_MODIFIER.NONE`` for falsy input.
      :raises ValueError: for any other value.
      """
      if not cache_modifier:
          return ir.CACHE_MODIFIER.NONE  # default
      if cache_modifier == ".ca":
          return ir.CACHE_MODIFIER.CA
      if cache_modifier == ".cg":
          return ir.CACHE_MODIFIER.CG
      if cache_modifier == ".cv":
          return ir.CACHE_MODIFIER.CV
      raise ValueError(f"Cache modifier {cache_modifier} not supported")
  def _str_to_store_cache_modifier(self, cache_modifier):
      """Translate a store cache-modifier string into ``ir.CACHE_MODIFIER``.

      :param cache_modifier: ``".wb"``, ``".cg"``, ``".cs"``, ``".wt"``, or a
          falsy value for the default.
      :return: the matching member; ``ir.CACHE_MODIFIER.NONE`` for falsy input.
      :raises ValueError: for any other value.
      """
      if not cache_modifier:
          return ir.CACHE_MODIFIER.NONE  # default
      if cache_modifier == ".wb":
          return ir.CACHE_MODIFIER.WB
      if cache_modifier == ".cg":
          return ir.CACHE_MODIFIER.CG
      if cache_modifier == ".cs":
          return ir.CACHE_MODIFIER.CS
      if cache_modifier == ".wt":
          return ir.CACHE_MODIFIER.WT
      raise ValueError(f"Cache modifier {cache_modifier} not supported")
  def _str_to_eviction_policy(self, eviction_policy):
      """Translate an eviction-policy string into ``ir.EVICTION_POLICY``.

      :param eviction_policy: ``"evict_last"``, ``"evict_first"``, or a falsy
          value for the default.
      :return: the matching member; ``ir.EVICTION_POLICY.NORMAL`` for falsy input.
      :raises ValueError: for any other value.
      """
      if not eviction_policy:
          return ir.EVICTION_POLICY.NORMAL  # default
      if eviction_policy == "evict_last":
          return ir.EVICTION_POLICY.EVICT_LAST
      if eviction_policy == "evict_first":
          return ir.EVICTION_POLICY.EVICT_FIRST
      raise ValueError(f"Eviction policy {eviction_policy} not supported")
  def _str_to_padding_option(self, padding_option):
      """Translate a padding-option string into ``ir.PADDING_OPTION``.

      :param padding_option: ``"zero"``, ``"nan"``, or a falsy value for no padding.
      :return: the matching member, or ``None`` for falsy input.
      :raises ValueError: for any other value.
      """
      if not padding_option:
          return None  # default
      if padding_option == "zero":
          return ir.PADDING_OPTION.PAD_ZERO
      if padding_option == "nan":
          return ir.PADDING_OPTION.PAD_NAN
      raise ValueError(f"Padding option {padding_option} not supported")
  def _str_to_sem(self, sem_option):
      """Translate a memory-semantic string into ``ir.MEM_SEMANTIC``.

      :param sem_option: ``"acquire"``, ``"release"``, ``"acq_rel"``,
          ``"relaxed"``, or a falsy value for the default.
      :return: the matching member; ``ir.MEM_SEMANTIC.ACQUIRE_RELEASE`` for
          falsy input.
      :raises ValueError: for any other value.
      """
      if not sem_option:
          return ir.MEM_SEMANTIC.ACQUIRE_RELEASE  # default
      if sem_option == "acquire":
          return ir.MEM_SEMANTIC.ACQUIRE
      if sem_option == "release":
          return ir.MEM_SEMANTIC.RELEASE
      if sem_option == "acq_rel":
          return ir.MEM_SEMANTIC.ACQUIRE_RELEASE
      if sem_option == "relaxed":
          return ir.MEM_SEMANTIC.RELAXED
      raise ValueError(f"Memory semantic {sem_option} not supported")
  def _str_to_scope(self, scope_option):
      """Translate a synchronisation-scope string into ``ir.MEM_SYNC_SCOPE``.

      :param scope_option: ``"gpu"``, ``"cta"``, ``"sys"``, or a falsy value
          for the default.
      :return: the matching member; ``ir.MEM_SYNC_SCOPE.GPU`` for falsy input.
      :raises ValueError: for any other value.
      """
      scope = ir.MEM_SYNC_SCOPE.GPU  # default
      if scope_option:
          if scope_option == "gpu":
              scope = ir.MEM_SYNC_SCOPE.GPU
          elif scope_option == "cta":
              scope = ir.MEM_SYNC_SCOPE.CTA
          elif scope_option == "sys":
              scope = ir.MEM_SYNC_SCOPE.SYSTEM
          else:
              # The rejected value is a sync scope, not a memory semantic, so
              # the message names the right option.
              raise ValueError(f"Memory sync scope {scope_option} not supported")
      return scope
  def _canonicalize_boundary_check(self, boundary_check, block_shape):
      """Normalise ``boundary_check`` into a sorted list of dimension indices.

      :param boundary_check: a single dimension, an iterable of dimensions
          (plain ints or ``tl.constexpr``), or a falsy value for "no check".
      :param block_shape: shape of the block; only its length is used, to
          range-check each dimension.
      :return: ``()`` for falsy input, otherwise the unwrapped dimensions in
          ascending order.
      """
      # NOTE(review): a bare `0` is falsy, so `boundary_check=0` yields `()`
      # rather than `[0]` -- confirm that is intended.
      if not boundary_check:
          return ()

      if not hasattr(boundary_check, "__iter__"):
          boundary_check = [boundary_check]
      dims = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in boundary_check]

      rank = len(block_shape)
      for dim in dims:
          assert isinstance(dim, int) and 0 <= dim < rank
      assert len(dims) > 0
      assert len(dims) == len(set(dims)), "Duplicate dimension in `boundary_check`"
      return sorted(dims)
  def _load_block_pointer(self, ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile):
      """Load through a block pointer, i.e. ``pointer_type<block_type<>>``.

      :param ptr: the block pointer.
      :param mask: must be ``None``; block pointers take no mask.
      :param other: must be ``None``; block pointers take no fallback value.
      :param boundary_check: dimensions to bounds-check, canonicalised here.
      :param padding: padding option already converted to its IR value.
      :param cache: cache modifier already converted to its IR value.
      :param eviction: eviction policy already converted to its IR value.
      :param is_volatile: forwarded to the builder.
      :return: tensor of the pointer's de-referenced block type.
      :raises ValueError: if ``mask`` or ``other`` is given, or NaN padding is
          requested for an integer element type.
      """
      if not (mask is None and other is None):
          raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers")

      # The loaded value has the de-referenced type of the pointer.
      block_ty = ptr.type.element_ty
      elt_ty = block_ty.element_ty
      assert elt_ty != tl.int1, "`tl.int1` should be rewritten in `tl.make_block_ptr`"
      if elt_ty.is_int() and padding == ir.PADDING_OPTION.PAD_NAN:
          raise ValueError("Padding option `nan` is not supported for integer block pointers")

      checked_dims = self._canonicalize_boundary_check(boundary_check, block_ty.get_block_shapes())
      loaded = self.builder.create_tensor_pointer_load(ptr.handle, checked_dims, padding, cache, eviction,
                                                       is_volatile)
      return self.tensor(loaded, block_ty)
  def _load_legacy(self, ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile):
      """Load through a tensor of pointers or a scalar pointer.

      Handles ``block_type<pointer_type<>>`` and ``pointer_type<>``.

      :param ptr: scalar pointer or block of pointers.
      :param mask: optional mask; broadcast against ``ptr`` when ``ptr`` is a block.
      :param other: optional fallback value; requires ``mask`` and is cast to
          the element type.
      :param boundary_check: must be falsy; only block pointers support it.
      :param padding: must be falsy; only block pointers support it.
      :param cache: cache modifier already converted to its IR value.
      :param eviction: eviction policy already converted to its IR value.
      :param is_volatile: forwarded to the builder.
      :return: the loaded tensor; ``int1`` pointees are loaded as ``int8`` and
          cast back to ``int1``.
      :raises ValueError: on a non-pointer ``ptr``, ``other`` without ``mask``,
          use of ``padding``/``boundary_check``, or a block ``mask``/``other``
          combined with a scalar pointer.
      """
      if not ptr.type.scalar.is_ptr():
          raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.load`")

      # Check `mask`, `other`, `boundary_check`, and `padding` arguments
      if mask is None and other is not None:
          raise ValueError("`other` cannot be provided without `mask`")
      if padding or boundary_check:
          raise ValueError("`padding_option` or `boundary_check` argument is not supported for loading a tensor of "
                           "pointers or loading a scalar. Because the compiler does not know the boundary; please "
                           "use block pointers (defined by `make_block_ptr`) instead")

      # For a pointer of scalar, check the type of `mask` and `other`.
      # Presence is tested with `is not None`, like every other check in this
      # function, rather than relying on the truthiness of a tensor object.
      if not ptr.type.is_block():
          if mask is not None and mask.type.is_block():
              raise ValueError("Mask argument cannot be block type if pointer argument is not a block")
          if other is not None and other.type.is_block():
              raise ValueError("Other argument cannot be block type if pointer argument is not a block")

      # Make `mask` and `other` into the same shape as `ptr`
      if ptr.type.is_block():
          if mask is not None:
              ptr, mask = self.broadcast_impl_value(ptr, mask)
          if other is not None:
              ptr, other = self.broadcast_impl_value(ptr, other)

      # Get `pointer_type<elt_ty>` and `elt_ty`
      ptr_ty = ptr.type.scalar
      elt_ty = ptr_ty.element_ty

      # Treat `pointer_type<tl.int1>` as `pointer_type<tl.int8>`
      is_bool = elt_ty == tl.int1
      if is_bool:
          elt_ty = tl.int8
          ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
          ptr = self.cast(ptr, ptr_ty)

      # Cast `other` into `elt_ty` type
      if other is not None:
          other = self.cast(other, elt_ty)

      # Create loaded result type `dst_ty`
      if ptr.type.is_block():
          dst_ty = ptr.type.with_element_ty(elt_ty)
      else:
          # Load by de-referencing the pointer of scalar
          dst_ty = elt_ty

      # Build IR
      if mask is None:
          ret = self.tensor(self.builder.create_load(ptr.handle, cache, eviction, is_volatile), dst_ty)
      else:
          other_handle = other.handle if other is not None else None
          ret = self.tensor(
              self.builder.create_masked_load(ptr.handle, mask.handle, other_handle, cache, eviction, is_volatile),
              dst_ty)
      if is_bool:
          ret = self.cast(ret, tl.int1)
      return ret
  def load(self, ptr: TensorTy, mask: Optional[TensorTy], other: Optional[TensorTy], boundary_check: Tuple,
           padding_option: str, cache_modifier: str, eviction_policy: str, is_volatile: bool) -> TensorTy:
      """Load from ``ptr``, dispatching on the kind of pointer.

      The string options are converted to their IR values first. A pointer to
      a block (``pointer_type<block_type<>>``) goes to ``_load_block_pointer``;
      a tensor of pointers or a scalar pointer goes to ``_load_legacy``.

      :param ptr: block pointer, tensor of pointers, or scalar pointer.
      :param mask: optional mask, forwarded to the chosen loader.
      :param other: optional fallback value, forwarded to the chosen loader.
      :param boundary_check: dimensions to bounds-check, forwarded unchanged.
      :param padding_option: padding option name.
      :param cache_modifier: load cache-modifier name.
      :param eviction_policy: eviction-policy name.
      :param is_volatile: forwarded unchanged.
      :return: the loaded tensor.
      """
      # Cache, eviction and padding options
      cache = self._str_to_load_cache_modifier(cache_modifier)
      eviction = self._str_to_eviction_policy(eviction_policy)
      padding = self._str_to_padding_option(padding_option)

      is_block_pointer = ptr.type.is_ptr() and ptr.type.element_ty.is_block()
      loader = self._load_block_pointer if is_block_pointer else self._load_legacy
      return loader(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile)
  def descriptor_load(self, desc: tl.tensor_descriptor_base, offsets, cache_modifier: str,
                      eviction_policy: str) -> TensorTy:
      """Load one block through a tensor descriptor.

      :param desc: the tensor descriptor.
      :param offsets: one offset per dimension of ``desc.block_shape``.
      :param cache_modifier: load cache-modifier name.
      :param eviction_policy: eviction-policy name.
      :return: tensor of type ``desc.block_type``.
      """
      assert isinstance(desc, tl.tensor_descriptor_base)
      ndim = len(desc.block_shape)
      assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}"

      ir_offsets = self._convert_to_ir_values(offsets, require_i64=False)
      cache = self._str_to_load_cache_modifier(cache_modifier)
      eviction = self._str_to_eviction_policy(eviction_policy)
      loaded = self.builder.create_descriptor_load(desc.handle, ir_offsets, cache, eviction)
      return self.tensor(loaded, desc.block_type)
  def validate_store_like(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> None:
      """Assert the argument invariants shared by descriptor store and reduce ops.

      :param desc: must be a ``tl.tensor_descriptor_base``.
      :param value: its shape must equal ``desc.block_shape``.
      :param offsets: must hold one entry per dimension of ``desc.block_shape``.
      """
      assert isinstance(desc, tl.tensor_descriptor_base)
      block_shape = desc.block_shape
      ndim = len(block_shape)
      assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}"
      assert value.shape == block_shape
  def descriptor_store(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy:
      """Store ``value`` through a tensor descriptor at ``offsets``.

      :param desc: the tensor descriptor.
      :param value: block to store; implicitly cast to ``desc.dtype``.
      :param offsets: one offset per dimension of the descriptor block.
      :return: a ``tl.void`` tensor wrapping the store operation.
      """
      self.validate_store_like(desc, value, offsets)
      # implicitly cast to the descriptor's type
      casted = self.cast(value, desc.dtype)
      ir_offsets = self._convert_to_ir_values(offsets, require_i64=False)
      store_op = self.builder.create_descriptor_store(desc.handle, casted.handle, ir_offsets)
      return self.tensor(store_op, tl.void)
  def descriptor_atomic_add(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy:
      """Emit a descriptor reduce of kind ``ADD`` with ``value`` at ``offsets``.

      :param desc: the tensor descriptor; its dtype must be one of uint32,
          int32, uint64, float32, float16 or bfloat16.
      :param value: block to accumulate; passed to the builder without a cast.
      :param offsets: one offset per dimension of the descriptor block.
      :return: a ``tl.void`` tensor wrapping the reduce operation.
      """
      self.validate_store_like(desc, value, offsets)
      assert desc.dtype in {tl.uint32, tl.int32, tl.uint64, tl.float32, tl.float16, tl.bfloat16}, "Unsupported dtype"
      ir_offsets = self._convert_to_ir_values(offsets, require_i64=False)
      reduce_op = self.builder.create_descriptor_reduce(ir.DESCRIPTOR_REDUCE_KIND.ADD, desc.handle, value.handle,
                                                        ir_offsets)
      return self.tensor(reduce_op, tl.void)
  def _has_native_tma(self, ):
      """Return whether the active target is a ``"cuda"`` backend with ``arch >= 90``."""
      current = driver.active.get_current_target()
      is_cuda = current.backend == "cuda"
      return (is_cuda and current.arch >= 90)
  def _descriptor_atomic_min_max_supported(self, dtype):
      """Assert that ``dtype`` may be used with descriptor atomic min/max.

      Accepted are uint32, int32, uint64, int64, float16 and bfloat16; the two
      16-bit float types additionally require ``self._has_native_tma()``.

      :param dtype: element type of the descriptor.
      """
      sixteen_bit_floats = {tl.float16, tl.bfloat16}
      assert dtype in {tl.uint32, tl.int32, tl.uint64, tl.int64, tl.float16, tl.bfloat16}, "Unsupported dtype"
      if dtype in sixteen_bit_floats:
          assert self._has_native_tma(), "16-bit float types require native tma support"
  def descriptor_atomic_min(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy:
      """Emit a descriptor reduce of kind ``MIN`` with ``value`` at ``offsets``.

      :param desc: the tensor descriptor; its dtype is checked by
          ``_descriptor_atomic_min_max_supported``.
      :param value: block to combine; passed to the builder without a cast.
      :param offsets: one offset per dimension of the descriptor block.
      :return: a ``tl.void`` tensor wrapping the reduce operation.
      """
      self.validate_store_like(desc, value, offsets)
      self._descriptor_atomic_min_max_supported(desc.dtype)
      ir_offsets = self._convert_to_ir_values(offsets, require_i64=False)
      reduce_op = self.builder.create_descriptor_reduce(ir.DESCRIPTOR_REDUCE_KIND.MIN, desc.handle, value.handle,
                                                        ir_offsets)
      return self.tensor(reduce_op, tl.void)
  def descriptor_atomic_max(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy:
      """Emit a descriptor reduce of kind ``MAX`` with ``value`` at ``offsets``.

      :param desc: the tensor descriptor; its dtype is checked by
          ``_descriptor_atomic_min_max_supported``.
      :param value: block to combine; passed to the builder without a cast.
      :param offsets: one offset per dimension of the descriptor block.
      :return: a ``tl.void`` tensor wrapping the reduce operation.
      """
      self.validate_store_like(desc, value, offsets)
      self._descriptor_atomic_min_max_supported(desc.dtype)
      ir_offsets = self._convert_to_ir_values(offsets, require_i64=False)
      reduce_op = self.builder.create_descriptor_reduce(ir.DESCRIPTOR_REDUCE_KIND.MAX, desc.handle, value.handle,
                                                        ir_offsets)
      return self.tensor(reduce_op, tl.void)
  def descriptor_atomic_and(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy:
      """Emit a descriptor reduce of kind ``AND`` with ``value`` at ``offsets``.

      :param desc: the tensor descriptor; its dtype must be one of uint32,
          int32, uint64 or int64.
      :param value: block to combine; passed to the builder without a cast.
      :param offsets: one offset per dimension of the descriptor block.
      :return: a ``tl.void`` tensor wrapping the reduce operation.
      """
      self.validate_store_like(desc, value, offsets)
      assert desc.dtype in {tl.uint32, tl.int32, tl.uint64, tl.int64}, "Unsupported dtype"
      ir_offsets = self._convert_to_ir_values(offsets, require_i64=False)
      reduce_op = self.builder.create_descriptor_reduce(ir.DESCRIPTOR_REDUCE_KIND.AND, desc.handle, value.handle,
                                                        ir_offsets)
      return self.tensor(reduce_op, tl.void)
  def descriptor_atomic_or(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy:
      """Emit a descriptor reduce of kind ``OR`` with ``value`` at ``offsets``.

      :param desc: the tensor descriptor; its dtype must be one of uint32,
          int32, uint64 or int64.
      :param value: block to combine; passed to the builder without a cast.
      :param offsets: one offset per dimension of the descriptor block.
      :return: a ``tl.void`` tensor wrapping the reduce operation.
      """
      self.validate_store_like(desc, value, offsets)
      assert desc.dtype in {tl.uint32, tl.int32, tl.uint64, tl.int64}, "Unsupported dtype"
      ir_offsets = self._convert_to_ir_values(offsets, require_i64=False)
      reduce_op = self.builder.create_descriptor_reduce(ir.DESCRIPTOR_REDUCE_KIND.OR, desc.handle, value.handle,
                                                        ir_offsets)
      return self.tensor(reduce_op, tl.void)
  def descriptor_atomic_xor(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy:
      """Emit a descriptor reduce of kind ``XOR`` with ``value`` at ``offsets``.

      :param desc: the tensor descriptor; its dtype must be one of uint32,
          int32, uint64 or int64.
      :param value: block to combine; passed to the builder without a cast.
      :param offsets: one offset per dimension of the descriptor block.
      :return: a ``tl.void`` tensor wrapping the reduce operation.
      """
      self.validate_store_like(desc, value, offsets)
      assert desc.dtype in {tl.uint32, tl.int32, tl.uint64, tl.int64}, "Unsupported dtype"
      ir_offsets = self._convert_to_ir_values(offsets, require_i64=False)
      reduce_op = self.builder.create_descriptor_reduce(ir.DESCRIPTOR_REDUCE_KIND.XOR, desc.handle, value.handle,
                                                        ir_offsets)
      return self.tensor(reduce_op, tl.void)
  def descriptor_gather(self, desc, x_offsets, y_offset, cache_modifier: str, eviction_policy: str) -> TensorTy:
      """Gather rows through a 2D tensor descriptor whose block has a single row.

      :param desc: tensor descriptor with a ``[1, N]`` block shape.
      :param x_offsets: 1D tensor of row offsets with at least 8 entries.
      :param y_offset: single column offset, converted to an IR value.
      :param cache_modifier: must be ``""`` (not supported yet).
      :param eviction_policy: must be ``""`` (not supported yet).
      :return: tensor of shape ``[x_offsets.shape[0], desc.block_shape[1]]``
          with element type ``desc.dtype``.
      """
      assert isinstance(desc, tl.tensor_descriptor_base)
      assert cache_modifier == "", "cache modifier is not supported yet"
      assert eviction_policy == "", "eviction policy is not supported yet"

      # Validate descriptor.
      assert len(desc.block_shape) == 2, f"descriptor must be 2D, but got {desc.block_shape}"
      assert desc.block_shape[0] == 1, f"descriptor block must have 1 row, but got {desc.block_shape}"

      # Validate offsets.
      assert len(x_offsets.shape) == 1, f"x offsets must be 1D, but got {x_offsets.shape}"

      # Validate minimum block size.
      assert x_offsets.shape[0] >= 8, f"descriptor gather must have at least 8 rows, but got {x_offsets.shape}"
      dtype = desc.dtype
      # The column minimum scales inversely with the element width.
      min_cols = 32 // dtype.primitive_bitwidth * 8
      assert desc.block_shape[
          1] >= min_cols, f"descriptor gather of {dtype} must have at least {min_cols} columns, but got {desc.block_shape[1]}"

      gathered_ty = tl.block_type(desc.dtype, [x_offsets.shape[0], desc.block_shape[1]])
      y_ir = self._convert_to_ir_values((y_offset, ), require_i64=False)[0]
      gathered = self.builder.create_descriptor_gather(desc.handle, x_offsets.handle, y_ir,
                                                       gathered_ty.to_ir(self.builder))
      return self.tensor(gathered, gathered_ty)
  def descriptor_scatter(self, desc, value: TensorTy, x_offsets, y_offset) -> TensorTy:
      """Scatter rows of ``value`` through a 2D tensor descriptor whose block has a single row.

      :param desc: tensor descriptor with a ``[1, N]`` block shape.
      :param value: tensor holding the rows to write; its handle is passed to
          the builder as is.
      :param x_offsets: 1D tensor of row offsets with at least 8 entries.
      :param y_offset: single column offset, converted to an IR value.
      :return: a ``tl.void`` tensor with no handle.
      """
      assert isinstance(desc, tl.tensor_descriptor_base)

      # Validate descriptor.
      assert len(desc.block_shape) == 2, f"descriptor must be 2D, but got {desc.block_shape}"
      assert desc.block_shape[0] == 1, f"descriptor block must have 1 row, but got {desc.block_shape}"

      # Validate offsets. The message reads `x_offsets.shape`; the former
      # misspelling `shapae` made a failing check raise AttributeError while
      # formatting the message instead of the intended AssertionError.
      assert len(x_offsets.shape) == 1, f"x offsets must be 1D, but got {x_offsets.shape}"

      # Validate minimum block size.
      assert x_offsets.shape[0] >= 8, f"descriptor scatter must have at least 8 rows, but got {x_offsets.shape}"
      dtype = desc.dtype
      # The column minimum scales inversely with the element width.
      min_cols = 32 // dtype.primitive_bitwidth * 8
      assert desc.block_shape[
          1] >= min_cols, f"descriptor scatter of {dtype} must have at least {min_cols} columns, but got {desc.block_shape[1]}"

      y_offset = self._convert_to_ir_values((y_offset, ), require_i64=False)[0]
      self.builder.create_descriptor_scatter(desc.handle, value.handle, x_offsets.handle, y_offset)
      return self.tensor(None, tl.void)
  1058. def _store_block_pointer(self, ptr, val, mask, boundary_check, cache, eviction):
  1059. # Store by a block pointer: `pointer_type<block_type<>>`
  1060. # Block pointers can not have the `mask` argument
  1061. if mask is not None:
  1062. raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers")
  1063. # Check same shape and element type
  1064. block_shape = ptr.type.element_ty.get_block_shapes()
  1065. if not val.type.is_block():
  1066. val = self.broadcast_impl_shape(val, block_shape)
  1067. assert val.type.is_block(), "Value argument must be block type or a scalar"
  1068. assert block_shape == val.type.get_block_shapes(
  1069. ), f"Block shape({block_shape}) and value shape({val.type.get_block_shapes()}) mismatch"
  1070. assert ptr.type.element_ty.element_ty == val.type.element_ty, f"Block element type({ptr.type.element_ty.element_ty}) and value element type({val.type.element_ty}) mismatch"
  1071. elt_ty = ptr.type.element_ty.element_ty
  1072. assert elt_ty != tl.int1, "`tl.int1` should be rewritten in `tl.make_block_ptr`"
  1073. # Check `boundary_check` argument
  1074. boundary_check = self._canonicalize_boundary_check(boundary_check, block_shape)
  1075. # Cast to target data type
  1076. val = self.cast(val, elt_ty)
  1077. # Build IR
  1078. return self.tensor(
  1079. self.builder.create_tensor_pointer_store(ptr.handle, val.handle, boundary_check, cache, eviction), tl.void)
  1080. def _store_legacy(self, ptr, val, mask, boundary_check, cache, eviction):
  1081. # Store by a tensor of pointers or a pointer of scalar: `block_type<pointer_type<>>` or `pointer_type<>`
  1082. if not ptr.type.scalar.is_ptr():
  1083. raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.store`")
  1084. # Check `boundary_check` argument
  1085. if boundary_check:
  1086. raise ValueError("`boundary_check` argument is not supported for storing a tensor of pointers or storing a "
  1087. "scalar. Because the compiler does not know the boundary; please use block pointers "
  1088. "(defined by `make_block_ptr`) instead")
  1089. # For a pointer of scalar, check the type of `val` and `mask`
  1090. if not ptr.type.is_block():
  1091. if val.type.is_block():
  1092. raise ValueError("Value argument cannot be block type if pointer argument is not a block")
  1093. if mask and mask.type.is_block():
  1094. raise ValueError("Mask argument cannot be block type if pointer argument is not a block")
  1095. # Make `mask` and `val` into the same shape as `ptr`
  1096. if ptr.type.is_block():
  1097. ptr_shape = ptr.shape
  1098. if mask is None:
  1099. ptr, val = self.broadcast_tensors(ptr, val)
  1100. else:
  1101. ptr, val, mask = self.broadcast_tensors(ptr, val, mask)
  1102. if ptr_shape != ptr.shape:
  1103. raise ValueError(f"Expected pointer argument to have shape {ptr.shape} but got {ptr_shape}")
  1104. ptr_ty = ptr.type.scalar
  1105. elt_ty = ptr_ty.element_ty
  1106. # Treat `pointer_type<tl.int1>` as `pointer_type<tl.int8>`
  1107. if elt_ty == tl.int1:
  1108. elt_ty = tl.int8
  1109. ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
  1110. ptr = self.cast(ptr, ptr_ty)
  1111. # Cast to target data type
  1112. val = self.cast(val, elt_ty)
  1113. # Build IR
  1114. if mask is None:
  1115. return self.tensor(self.builder.create_store(ptr.handle, val.handle, cache, eviction), tl.void)
  1116. if not mask.type.scalar.is_bool():
  1117. raise ValueError("Mask must have boolean scalar type")
  1118. return self.tensor(self.builder.create_masked_store(ptr.handle, val.handle, mask.handle, cache, eviction),
  1119. tl.void)
  1120. def store(self, ptr: TensorTy, val: TensorTy, mask: Optional[TensorTy], boundary_check, cache_modifier: str,
  1121. eviction_policy: str) -> TensorTy:
  1122. # Cache and eviction options
  1123. cache = self._str_to_store_cache_modifier(cache_modifier)
  1124. eviction = self._str_to_eviction_policy(eviction_policy)
  1125. if ptr.type.is_const() or ptr.type.scalar.is_const():
  1126. raise ValueError("Cannot store to a constant pointer")
  1127. if ptr.type.is_ptr() and ptr.type.element_ty.is_block():
  1128. # Store by a block pointer: `pointer_type<block_type<>>`
  1129. return self._store_block_pointer(ptr, val, mask, boundary_check, cache, eviction)
  1130. else:
  1131. # Store by a tensor of pointers or a pointer of scalar: `block_type<pointer_type<>>` or `pointer_type<>`
  1132. return self._store_legacy(ptr, val, mask, boundary_check, cache, eviction)
  1133. #########
  1134. # atomic
  1135. #########
  1136. def atomic_cas(self, ptr: TensorTy, cmp: TensorTy, val: TensorTy, sem: str, scope: str) -> TensorTy:
  1137. sem = self._str_to_sem(sem)
  1138. scope = self._str_to_scope(scope)
  1139. element_ty = ptr.type.scalar.element_ty
  1140. if element_ty.primitive_bitwidth not in [16, 32, 64]:
  1141. raise ValueError("atomic_cas only supports elements with width {16, 32, 64}")
  1142. return self.tensor(self.builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem, scope), val.type)
  1143. def atom_red_typechecking_impl(self, ptr: TensorTy, val: TensorTy, mask: TensorTy,
  1144. op: str) -> Tuple[TensorTy, TensorTy, TensorTy]:
  1145. if not ptr.type.scalar.is_ptr():
  1146. raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__())
  1147. if ptr.type.is_const() or ptr.type.element_ty.is_const():
  1148. raise ValueError("Cannot store to a constant pointer")
  1149. element_ty = ptr.type.scalar.element_ty
  1150. if element_ty is tl.float16 and op != 'add':
  1151. raise ValueError("atomic_" + op + " does not support fp16")
  1152. if element_ty is tl.bfloat16 and op != 'add':
  1153. raise ValueError("atomic_" + op + " does not support bf16")
  1154. if element_ty in [tl.int16, tl.uint16] or element_ty.primitive_bitwidth < 16:
  1155. raise ValueError("atomic_" + op + " does not support " + str(element_ty))
  1156. if ptr.type.is_block():
  1157. if mask is not None:
  1158. mask = self.broadcast_impl_shape(mask, ptr.type.get_block_shapes())
  1159. if val is not None:
  1160. val = self.broadcast_impl_shape(val, ptr.type.get_block_shapes())
  1161. val = self.cast(val, ptr.type.scalar.element_ty)
  1162. if mask is None:
  1163. mask_ir = self.builder.get_int1(True)
  1164. mask_ty = tl.int1
  1165. if ptr.type.is_block():
  1166. mask_ty = ptr.type.with_element_ty(tl.int1)
  1167. mask_ir = self.builder.create_splat(mask_ty.to_ir(self.builder), mask_ir)
  1168. mask = self.tensor(mask_ir, mask_ty)
  1169. return ptr, val, mask
  1170. def _signbit(self, x: TensorTy) -> TensorTy:
  1171. bitwidth = x.dtype.primitive_bitwidth
  1172. idtype = tl.get_int_dtype(bitwidth=bitwidth, signed=False)
  1173. ix = self.bitcast(x, idtype)
  1174. signbit = self.lshr(ix, bitwidth - 1)
  1175. return self.cast(signbit, tl.int1)
  1176. def atomic_max(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy:
  1177. ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'max')
  1178. sem = self._str_to_sem(sem)
  1179. scope = self._str_to_scope(scope)
  1180. sca_ty = val.type.scalar
  1181. # direct call to atomic_max for integers
  1182. if sca_ty.is_int():
  1183. if sca_ty.is_int_signed():
  1184. return self.tensor(
  1185. self.builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, ptr.handle, val.handle, mask.handle, sem, scope),
  1186. val.type)
  1187. else:
  1188. return self.tensor(
  1189. self.builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ptr.handle, val.handle, mask.handle, sem, scope),
  1190. val.type)
  1191. # for float
  1192. # return atomic_smax(i_ptr, i_val) if val >= 0
  1193. # return atomic_umin(i_ptr, i_val) if val < 0
  1194. if sca_ty not in {tl.float32, tl.float64}:
  1195. raise TypeError(f"atomic_max not supported for dtype {sca_ty}")
  1196. i_type = tl.int32 if sca_ty == tl.float32 else tl.int64
  1197. i_val = self.bitcast(val, i_type)
  1198. i_ptr = self.bitcast(ptr, tl.pointer_type(i_type, 1))
  1199. ui_type = tl.uint32 if sca_ty == tl.float32 else tl.uint64
  1200. ui_val = self.bitcast(val, ui_type)
  1201. ui_ptr = self.bitcast(ptr, tl.pointer_type(ui_type, 1))
  1202. neg = self._signbit(val)
  1203. pos = self.not_(neg)
  1204. pos_ret = self.tensor(
  1205. self.builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle,
  1206. self.and_(mask, pos).handle, sem, scope), i_val.type)
  1207. neg_ret = self.tensor(
  1208. self.builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ui_ptr.handle, ui_val.handle,
  1209. self.and_(mask, neg).handle, sem, scope), ui_val.type)
  1210. ret = self.where(pos, pos_ret, neg_ret)
  1211. return self.bitcast(ret, sca_ty)
  1212. def atomic_min(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy:
  1213. ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'min')
  1214. sem = self._str_to_sem(sem)
  1215. scope = self._str_to_scope(scope)
  1216. sca_ty = val.type.scalar
  1217. # direct call to atomic_min for integers
  1218. if sca_ty.is_int():
  1219. if sca_ty.is_int_signed():
  1220. return self.tensor(
  1221. self.builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, ptr.handle, val.handle, mask.handle, sem, scope),
  1222. val.type)
  1223. else:
  1224. return self.tensor(
  1225. self.builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ptr.handle, val.handle, mask.handle, sem, scope),
  1226. val.type)
  1227. # for float
  1228. # return atomic_smin(i_ptr, i_val) if val >= 0
  1229. # return atomic_umax(i_ptr, i_val) if val < 0
  1230. if sca_ty not in {tl.float32, tl.float64}:
  1231. raise TypeError(f"atomic_min not supported for dtype {sca_ty}")
  1232. i_type = tl.int32 if sca_ty == tl.float32 else tl.int64
  1233. i_val = self.bitcast(val, i_type)
  1234. i_ptr = self.bitcast(ptr, tl.pointer_type(i_type, 1))
  1235. ui_type = tl.uint32 if sca_ty == tl.float32 else tl.uint64
  1236. ui_val = self.bitcast(val, ui_type)
  1237. ui_ptr = self.bitcast(ptr, tl.pointer_type(ui_type, 1))
  1238. neg = self._signbit(val)
  1239. pos = self.not_(neg)
  1240. pos_ret = self.tensor(
  1241. self.builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, i_ptr.handle, i_val.handle,
  1242. self.and_(mask, pos).handle, sem, scope), i_val.type)
  1243. neg_ret = self.tensor(
  1244. self.builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ui_ptr.handle, ui_val.handle,
  1245. self.and_(mask, neg).handle, sem, scope), ui_ptr.type)
  1246. ret = self.where(pos, pos_ret, neg_ret)
  1247. return self.bitcast(ret, sca_ty)
  1248. def atomic_add(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy:
  1249. ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'add')
  1250. sem = self._str_to_sem(sem)
  1251. scope = self._str_to_scope(scope)
  1252. sca_ty = val.type.scalar
  1253. op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD
  1254. return self.tensor(self.builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle, sem, scope),
  1255. val.type)
  1256. def atomic_and(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy:
  1257. ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'and')
  1258. sem = self._str_to_sem(sem)
  1259. scope = self._str_to_scope(scope)
  1260. return self.tensor(
  1261. self.builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
  1262. def atomic_or(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy:
  1263. ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'or')
  1264. sem = self._str_to_sem(sem)
  1265. scope = self._str_to_scope(scope)
  1266. return self.tensor(
  1267. self.builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
  1268. def atomic_xor(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy:
  1269. ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'xor')
  1270. sem = self._str_to_sem(sem)
  1271. scope = self._str_to_scope(scope)
  1272. return self.tensor(
  1273. self.builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
  1274. def atomic_xchg(self, ptr: TensorTy, val: TensorTy, mask: TensorTy, sem: str, scope: str) -> TensorTy:
  1275. ptr, val, mask = self.atom_red_typechecking_impl(ptr, val, mask, 'xchg')
  1276. sem = self._str_to_sem(sem)
  1277. scope = self._str_to_scope(scope)
  1278. return self.tensor(
  1279. self.builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem, scope),
  1280. val.type)
  1281. # ===----------------------------------------------------------------------===//
  1282. # Linear Algebra
  1283. # ===----------------------------------------------------------------------===//
  1284. def _str_to_dot_input_precision(self, input_precision):
  1285. assert input_precision.lower() in self.builder.options.allowed_dot_input_precisions, \
  1286. f"input_precision must be one of {self.builder.options.allowed_dot_input_precisions}. Got {input_precision}"
  1287. input_precision = input_precision.upper()
  1288. if input_precision == "TF32X3":
  1289. input_precision = "TF32x3"
  1290. if input_precision == "BF16X3":
  1291. input_precision = "BF16x3"
  1292. if input_precision == "BF16X6":
  1293. input_precision = "BF16x6"
  1294. return getattr(ir.INPUT_PRECISION, input_precision)
    def dot(self, lhs: TensorTy, rhs: TensorTy, acc: TensorTy, input_precision: Optional[str],
            max_num_imprecise_acc: int, out_dtype: tl.dtype) -> TensorTy:
        """Type-check a matrix product of two block tensors and build the dot op.

        Args:
            lhs, rhs: block tensors, both 2D or both 3D, with `lhs.shape[-1] == rhs.shape[-2]`.
            acc: optional accumulator; when None a zero splat of the result type is used.
            input_precision: precision name, or None to use the builder's default option.
            max_num_imprecise_acc: only validated (against K) when both operands are fp8;
                when None it defaults to the builder option for fp8 x fp8 and to 0 otherwise.
            out_dtype: requested result dtype; used directly only in the final `else` branch
                of the result-type selection and in the accumulator check.

        Returns:
            A block tensor of shape [M, N] (or [B, M, N] for 3D inputs).

        Raises:
            AssertionError: on unsupported dtypes, ranks, shapes, or accumulator mismatch.
            ValueError: for `out_dtype` bf16 on the non-integer path, or when
                `max_num_imprecise_acc` exceeds K for fp8 inputs.
        """
        assert lhs.type.is_block() and rhs.type.is_block()

        if lhs.dtype.is_fp8() and rhs.dtype.is_fp8():
            # All combinations of supported fp8 x fp8 are permitted
            pass
        else:
            assert lhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16, tl.float32,
                                 tl.float64), f"Unsupported lhs dtype {lhs.dtype}"
            assert rhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16, tl.float32,
                                 tl.float64), f"Unsupported rhs dtype {rhs.dtype}"
            assert lhs.dtype == rhs.dtype, f"Both operands must be same dtype. Got {lhs.dtype} and {rhs.dtype}"

        if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15():
            # Warn only when the target's options list this dtype as deprecated.
            if "fp8e4b15" in self.builder.options.deprecated_fp8_dot_operand_dtypes:
                warnings.warn(
                    "the use of fp8e4b15 is deprecated on Hopper and later architectures and can cause significant slow down. It will be removed in a future triton release"
                )
            # We upcast because there's no fp8e4b15 type in MLIR
            lhs = self.cast(lhs, tl.float16)
            rhs = self.cast(rhs, tl.float16)

        uses_fp8e4b8 = lhs.dtype.is_fp8e4b8() or rhs.dtype.is_fp8e4b8()
        uses_fp8e5b16 = lhs.dtype.is_fp8e5b16() or rhs.dtype.is_fp8e5b16()
        if uses_fp8e4b8 or uses_fp8e5b16:
            type_name = "fp8e4b8" if uses_fp8e4b8 else "fp8e5b16"
            # NOTE(review): nesting reconstructed from context — the upcast is taken to happen
            # only when the options list the dtype as deprecated (the warning text says the
            # operands are "upcasted to fp16"); confirm against the canonical file.
            if type_name in self.builder.options.deprecated_fp8_dot_operand_dtypes:
                arch = self.builder.options.arch
                warnings.warn(
                    f"{type_name} is AMD gfx942 specific and not supported on {arch} so it's upcasted to fp16 and can cause significant slow down. "
                    f"Please use OCP fp8 variants on {arch} for performance")
                lhs = self.cast(lhs, tl.float16)
                rhs = self.cast(rhs, tl.float16)

        # Resolve the precision name (falling back to the target default) to its IR enum.
        if input_precision is None:
            input_precision = self.builder.options.default_dot_input_precision

        input_precision = self._str_to_dot_input_precision(input_precision)

        lhs_rank = len(lhs.shape)
        rhs_rank = len(rhs.shape)
        assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
        assert lhs.shape[-1].value == rhs.shape[
            -2].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[-1].value}) must be equal to first index of second shape ({rhs.shape[-2].value})"
        # The target supplies the minimum (M, N, K) it can lower, via the `min_dot_size` hook.
        assert self.builder.codegen_fns.get(
            "min_dot_size") is not None, "target doesn't provide lower shape bounds for dot."
        min_dot_size = self.builder.codegen_fns["min_dot_size"](lhs.type, rhs.type)
        assert lhs.shape[-2].value >= min_dot_size[0] and lhs.shape[-1].value >= min_dot_size[2] \
            and rhs.shape[-1].value >= min_dot_size[1], \
                f"Input shapes should have M >= {min_dot_size[0]}, N >= {min_dot_size[1]} and K >= {min_dot_size[2]}"
        # Pick the zero constant and result scalar type from the lhs element type; `out_dtype`
        # is only honoured in the final `else` branch.
        if lhs.type.scalar.is_int():
            assert lhs.type.scalar == tl.int8, "only int8 supported!"
            _0 = self.builder.get_int32(0)
            ret_scalar_ty = tl.int32
        elif out_dtype.is_bf16():
            raise ValueError(
                "out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`"
            )
        elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16():
            _0 = self.builder.get_fp32(0)
            ret_scalar_ty = tl.float32
        elif lhs.type.scalar.is_fp64():
            _0 = self.builder.get_fp64(0)
            ret_scalar_ty = tl.float64
        else:
            _0 = self.builder.get_fp16(0) if out_dtype.is_fp16() else self.builder.get_fp32(0)
            ret_scalar_ty = out_dtype

        M = lhs.type.shape[-2]
        N = rhs.type.shape[-1]
        K = lhs.type.shape[-1]
        # B is the leading (batch) dimension for 3D inputs, None for 2D.
        B = lhs.type.shape[0] if lhs_rank == 3 else None
        ret_ty = tl.block_type(ret_scalar_ty, [B, M, N] if B else [M, N])
        if acc is None:
            # No accumulator given: start from a zero-filled tensor of the result type.
            acc_handle = self.builder.create_splat(ret_ty.to_ir(self.builder), _0)
        else:
            acc_handle = acc.handle
            assert acc.type.shape == ret_ty.shape and acc.type.element_ty == out_dtype

        # max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90
        if max_num_imprecise_acc is None:
            if lhs.dtype.is_fp8() and rhs.dtype.is_fp8():
                max_num_imprecise_acc = self.builder.options.max_num_imprecise_acc_default
            else:
                max_num_imprecise_acc = 0
        else:
            if lhs.dtype.is_fp8() and rhs.dtype.is_fp8() and max_num_imprecise_acc > K:
                raise ValueError(f"max_num_imprecise_acc ({max_num_imprecise_acc}) must be <= K ({K})")

        return self.tensor(
            self.builder.create_dot(lhs.handle, rhs.handle, acc_handle, input_precision, max_num_imprecise_acc), ret_ty)
  1378. def _str_to_fp_type(self, float_format: str):
  1379. ty_enum = getattr(ir.ScaleDotElemTypeTY, float_format.upper(), None)
  1380. if ty_enum is None:
  1381. raise ValueError(f"Invalid float format: {float_format}.")
  1382. return ty_enum
  1383. def _bitcast_to_fp_type(self, val: TensorTy, float_format: str):
  1384. """
  1385. If float_format is subbyte, make sure it's packed as uint8 and return it.
  1386. Otherwise, return a tensor (perhaps bitcasting) of the specified float format.
  1387. """
  1388. triton_ty = {"e5m2": tl.float8e5, "e4m3": tl.float8e4nv, "bf16": tl.bfloat16, "fp16":
  1389. tl.float16}.get(float_format)
  1390. if triton_ty is None:
  1391. assert float_format == "e2m1", f"Internal Error: Unexpected float format: {float_format}"
  1392. assert val.dtype == tl.uint8, f"e2m1 format must be packed as uint8. Got {val.dtype}"
  1393. return val
  1394. if val.dtype == triton_ty:
  1395. return val
  1396. else:
  1397. unsigned_ty = {"e5m2": tl.uint8, "e4m3": tl.uint8, "bf16": tl.uint16, "fp16": tl.uint16}[float_format]
  1398. assert val.dtype == unsigned_ty, f"Unexpected dtype for {float_format}. Got {val.dtype}"
  1399. return self.bitcast(val, triton_ty)
  1400. def verify_scaled_shape(self, M, N, K, lhs_scale, rhs_scale):
  1401. if lhs_scale is not None:
  1402. scale_factor = 16 if lhs_scale.dtype.is_fp8e4nv() else 32
  1403. lhs_scale_shape = lhs_scale.type.shape
  1404. assert lhs_scale_shape == [
  1405. M, K // scale_factor
  1406. ], f"lhs_scale must be a tensor of shape [{M}, {K // scale_factor}]. Got {lhs_scale_shape}"
  1407. if rhs_scale is not None:
  1408. scale_factor = 16 if rhs_scale.dtype.is_fp8e4nv() else 32
  1409. rhs_scale_shape = rhs_scale.type.shape
  1410. assert rhs_scale_shape == [
  1411. N, K // scale_factor
  1412. ], f"rhs_scale must be a tensor of shape [{N}, {K // scale_factor}]. Got {rhs_scale_shape}"
  1413. def dot_scaled(self, lhs: TensorTy, lhs_scale: TensorTy, lhs_format: str, rhs: TensorTy,
  1414. rhs_scale: Optional[TensorTy], rhs_format: str, acc: TensorTy | None, fast_math: bool,
  1415. lhs_k_pack: bool, rhs_k_pack: bool, out_dtype: tl.dtype) -> TensorTy:
  1416. assert lhs.type.is_block() and rhs.type.is_block()
  1417. #TODO: validate types.
  1418. lhs_rank = len(lhs.shape)
  1419. rhs_rank = len(rhs.shape)
  1420. assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
  1421. lhs_format: str = lhs_format.value
  1422. rhs_format: str = rhs_format.value
  1423. lhs_format_enum = self._str_to_fp_type(lhs_format)
  1424. rhs_format_enum = self._str_to_fp_type(rhs_format)
  1425. allowed_formats = {"e2m1", "e4m3", "e5m2", "bf16", "fp16"}
  1426. assert lhs_format in allowed_formats, f"NYI: lhs_format {lhs_format}"
  1427. assert rhs_format in allowed_formats, f"NYI: rhs_format {rhs_format}"
  1428. rhs_scale_is_none = rhs_scale is None or (isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None)
  1429. lhs_scale_is_none = lhs_scale is None or (isinstance(lhs_scale, tl.constexpr) and lhs_scale.value is None)
  1430. lhs = self._bitcast_to_fp_type(lhs, lhs_format)
  1431. rhs = self._bitcast_to_fp_type(rhs, rhs_format)
  1432. assert lhs_k_pack or lhs_format == "e2m1", "only mxfp4 inputs can be packed along a dimension different than K"
  1433. assert rhs_k_pack or rhs_format == "e2m1", "only mxfp4 inputs can be packed along a dimension different than K"
  1434. M, K_LHS = lhs.type.shape[-2:]
  1435. K_RHS, N = rhs.type.shape[-2:]
  1436. PACKED_A = 2 if lhs_format == "e2m1" else 1
  1437. PACKED_B = 2 if rhs_format == "e2m1" else 1
  1438. PACKED_A_DIM = PACKED_A * K_LHS if lhs_k_pack else K_LHS
  1439. PACKED_B_DIM = PACKED_B * K_RHS if rhs_k_pack else K_RHS
  1440. assert PACKED_B_DIM == PACKED_A_DIM, f"Reduction dimension should pack the same number of elements; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
  1441. #assert K * PACKED_B >= 64, f"scaled_dot NYI for K < 64. Got {K=}"
  1442. B = lhs.type.shape[0] if lhs_rank == 3 else None
  1443. K = K_LHS
  1444. if not lhs_k_pack:
  1445. M = M * PACKED_A
  1446. else:
  1447. K = K * PACKED_A
  1448. if not rhs_k_pack:
  1449. N = N * PACKED_B
  1450. ret_ty = tl.block_type(out_dtype, [B, M, N] if B else [M, N])
  1451. _0 = self.builder.get_fp32(0)
  1452. if acc is None:
  1453. acc_handle = self.builder.create_splat(ret_ty.to_ir(self.builder), _0)
  1454. else:
  1455. acc_handle = acc.handle
  1456. assert acc.type.shape == ret_ty.shape and acc.type.element_ty == out_dtype
  1457. rhs_scale_handle = None if rhs_scale_is_none else rhs_scale.handle
  1458. lhs_scale_handle = None if lhs_scale_is_none else lhs_scale.handle
  1459. self.verify_scaled_shape(M, N, K, None if lhs_scale_is_none else lhs_scale,
  1460. None if rhs_scale_is_none else rhs_scale)
  1461. return self.tensor(
  1462. self.builder.create_dot_scaled(lhs.handle, lhs_scale_handle, lhs_format_enum, rhs.handle, rhs_scale_handle,
  1463. rhs_format_enum, fast_math, lhs_k_pack, rhs_k_pack, acc_handle), ret_ty)
  1464. # ===----------------------------------------------------------------------===//
  1465. # Indexing
  1466. # ===----------------------------------------------------------------------===//
  1467. def where(self, condition: TensorTy, x: TensorTy, y: TensorTy) -> TensorTy:
  1468. if condition.dtype != tl.int1:
  1469. warnings.warn(
  1470. f"tl.where with a non-boolean condition is deprecated and will error out in a future triton release. Got {condition.dtype}"
  1471. )
  1472. condition = self.cast(condition, tl.int1)
  1473. x, y = self.binary_op_type_checking_impl(x, y, True, True)
  1474. # x, y are broadcasted
  1475. if condition.type.is_block():
  1476. condition, x = self.broadcast_impl_value(condition, x)
  1477. x, y = self.broadcast_impl_value(x, y)
  1478. else:
  1479. condition, _ = self.broadcast_impl_value(condition, x)
  1480. ret_ty = x.type
  1481. return self.tensor(self.builder.create_select(condition.handle, x.handle, y.handle), ret_ty)
  1482. # ===----------------------------------------------------------------------===//
  1483. # Reduction
  1484. # ===----------------------------------------------------------------------===
  1485. def wrap_tensor(self, x, scalar_ty, ret_shape):
  1486. if ret_shape:
  1487. res_ty = tl.block_type(scalar_ty, ret_shape)
  1488. else:
  1489. # 0d-tensor -> scalar
  1490. res_ty = scalar_ty
  1491. return self.tensor(x, res_ty)
  1492. def reduction(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn) -> Tuple[TensorTy, ...]:
  1493. if axis is None:
  1494. inputs = tuple(self.reshape(t, [t.numel.value], can_reorder=True) for t in inputs)
  1495. axis = 0
  1496. # get result shape
  1497. shape = inputs[0].type.shape
  1498. rank = len(shape)
  1499. assert axis < rank, f"reduction axis must be < inputs rank ({rank})"
  1500. ret_shape = [s for i, s in enumerate(shape) if i != axis]
  1501. assert all(t.type.shape == shape for t in inputs), "all reduction inputs must have the same shape"
  1502. reduce_op = self.builder.create_reduce([t.handle for t in inputs], axis)
  1503. region_builder_fn(reduce_op)
  1504. assert reduce_op.verify()
  1505. return tuple(
  1506. self.wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape) for i in range(len(inputs)))
  1507. # ===----------------------------------------------------------------------===
  1508. # Associative Scan
  1509. # ===----------------------------------------------------------------------===
  1510. def associative_scan(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn,
  1511. reverse: bool) -> Tuple[TensorTy, ...]:
  1512. shape = inputs[0].type.shape
  1513. rank = len(shape)
  1514. assert -rank <= axis < rank, f"scan axis {axis} must be < inputs rank ({rank})"
  1515. if axis < 0:
  1516. axis += rank
  1517. for t in inputs:
  1518. assert t.type.shape == shape, "all scan inputs must have the same shape"
  1519. scan_op = self.builder.create_scan([t.handle for t in inputs], axis, reverse)
  1520. region_builder_fn(scan_op)
  1521. assert scan_op.verify()
  1522. return tuple(self.wrap_tensor(scan_op.get_result(i), inputs[i].type.scalar, shape) for i in range(len(inputs)))
  1523. # ===----------------------------------------------------------------------===
  1524. # Gather
  1525. # ===----------------------------------------------------------------------===
  1526. def gather(self, src: TensorTy, index: TensorTy, axis: int) -> TensorTy:
  1527. assert index.dtype.is_int(), "index must be an integer tensor"
  1528. rank = len(src.type.shape)
  1529. assert len(index.type.shape) == rank, "source and index tensors must have the same rank"
  1530. assert -rank <= axis < rank, f"gather axis {axis} must be < source rank ({rank})"
  1531. if axis < 0:
  1532. axis += rank
  1533. for d in range(rank):
  1534. if d == axis:
  1535. continue
  1536. assert index.type.shape[d] == src.type.shape[d], f"index dim {axis} must match the corresponding source dim"
  1537. gather = self.builder.create_gather(src.handle, index.handle, axis)
  1538. return self.wrap_tensor(gather, src.type.scalar, index.type.shape)
  1539. # ===----------------------------------------------------------------------===
  1540. # Map Elementwise
  1541. # ===----------------------------------------------------------------------===
  1542. def broadcast_tensors(self, *inputs):
  1543. if not inputs:
  1544. return ()
  1545. head, *tail = inputs
  1546. for i in range(len(tail)):
  1547. head, tail[i] = self.broadcast_impl_value(head, tail[i])
  1548. for i in range(len(tail) - 1):
  1549. head, tail[i] = self.broadcast_impl_value(head, tail[i])
  1550. return (head, *tail)
  1551. def map_elementwise(self, inputs: Sequence[tl.tensor], result_types: Sequence[tl.dtype], pack: int,
  1552. region_builder_fn) -> Tuple[tl.tensor, ...]:
  1553. inputs = self.broadcast_tensors(*inputs)
  1554. assert len(inputs) > 0, "map_elementwise must have at least 1 input tensor"
  1555. result_types = [inputs[0].type.with_element_ty(ty.scalar) for ty in result_types]
  1556. elementwise_op = self.builder.create_map_elementwise(
  1557. [t.handle for t in inputs],
  1558. [ty.to_ir(self.builder) for ty in result_types],
  1559. pack,
  1560. )
  1561. region_builder_fn(elementwise_op)
  1562. assert elementwise_op.verify()
  1563. return tuple(self.tensor(elementwise_op.get_result(i), ty) for i, ty in enumerate(result_types))
  1564. # ===----------------------------------------------------------------------===
  1565. # Histogram
  1566. # ===----------------------------------------------------------------------===
  1567. def histogram(self, input: TensorTy, num_bins: int, mask: Optional[TensorTy]) -> TensorTy:
  1568. assert len(input.shape) == 1, "histogram only supports 1D input"
  1569. assert input.dtype.is_int(), "histogram only supports integer input"
  1570. if mask is not None:
  1571. mask = self.broadcast_impl_shape(mask, input.shape)
  1572. if not mask.type.scalar.is_bool():
  1573. raise ValueError("Mask must have boolean scalar type")
  1574. mask = mask.handle
  1575. return self.tensor(self.builder.create_histogram(input.handle, num_bins, mask),
  1576. tl.block_type(tl.int32, [num_bins]))
  1577. def multiple_of(self, x: TensorTy, values: List[int]) -> TensorTy:
  1578. if max(1, len(x.shape)) != len(values):
  1579. raise ValueError("Shape of input to multiple_of does not match the length of values")
  1580. x.handle.set_attr("tt.divisibility", ir.make_attr(values, x.handle.get_context()))
  1581. return x
  1582. def max_contiguous(self, x: TensorTy, values: List[int]) -> TensorTy:
  1583. if len(x.shape) != len(values):
  1584. raise ValueError("Shape of input to max_contiguous does not match the length of values")
  1585. x.handle.set_attr("tt.contiguity", ir.make_attr(values, x.handle.get_context()))
  1586. return x
  1587. def max_constancy(self, x: TensorTy, values: List[int]) -> TensorTy:
  1588. if len(x.shape) != len(values):
  1589. raise ValueError("Shape of input to max_constancy does not match the length of values")
  1590. x.handle.set_attr("tt.constancy", ir.make_attr(values, x.handle.get_context()))
  1591. return x
  1592. def debug_barrier(self) -> TensorTy:
  1593. return self.tensor(self.builder.create_barrier(), tl.void)
  1594. def device_print(self, prefix: str, args: List[TensorTy], hex: bool) -> TensorTy:
  1595. # It makes sense visually for prefix to end in ": "; make it so. Also,
  1596. # non-empty prefixes should start with " ".
  1597. if not prefix.endswith(" ") and args:
  1598. prefix += " "
  1599. if not prefix.endswith(": ") and args:
  1600. prefix = prefix[:-1] + ": "
  1601. if len(prefix) > 2 and not prefix.startswith(" "):
  1602. prefix = " " + prefix
  1603. new_args = [arg.handle for arg in args]
  1604. is_signed = [arg.dtype.is_int_signed() for arg in args]
  1605. return self.tensor(self.builder.create_print(prefix, hex, new_args, is_signed), tl.void)
  1606. def device_assert(self, cond: TensorTy, msg: str, mask: Optional[TensorTy]) -> TensorTy:
  1607. if not self.builder.options.debug:
  1608. return
  1609. if mask is not None:
  1610. cond = self.or_(cond, self.not_(mask))
  1611. return self.tensor(self.builder.create_assert(cond.handle, msg), tl.void)
  1612. def assume(self, cond) -> TensorTy:
  1613. return self.tensor(self.builder.create_assume(cond.handle), tl.void)
  1614. def _convert_elem_to_ir_value(self, elem, require_i64):
  1615. if isinstance(elem, int):
  1616. elem = tl.constexpr(elem)
  1617. if isinstance(elem, tl.constexpr):
  1618. if isinstance(elem.value, bool):
  1619. return self.builder.get_int1(elem.value)
  1620. if require_i64:
  1621. assert -2**63 <= elem.value < 2**63, f"Block pointers only support 64 bit `shape/strides`, " \
  1622. f"got a value {elem.value} which is out of the range"
  1623. return self.builder.get_int64(elem.value)
  1624. else:
  1625. assert -2**31 <= elem.value < 2**31, f"Block pointers only support 32 bit `offsets/block_shape`, " \
  1626. f"got a value {elem.value} which is out of the range"
  1627. return self.builder.get_int32(elem.value)
  1628. elif isinstance(elem, tl.tensor):
  1629. assert elem.numel.value == 1, "Expected a scalar in shape/strides/offsets"
  1630. assert elem.dtype.is_int(), "Expected an integer scalar type in shape/strides/offsets"
  1631. if elem.dtype != tl.int64 and require_i64:
  1632. return self.builder.create_int_cast(elem.handle, self.builder.get_int64_ty(),
  1633. elem.dtype.is_int_signed())
  1634. elif elem.dtype == tl.int64 and not require_i64:
  1635. assert False, "Block pointers only support 32 bit `offsets/block_shape`, " \
  1636. "add a `.to(tl.int32)` or use regular indexing for 64 bit support"
  1637. return elem.handle
  1638. assert False, f"Unsupported element type in shape/strides/offsets: {type(elem)}"
  1639. def _convert_to_ir_values(self, list_like, require_i64=True):
  1640. if hasattr(list_like, "__iter__"):
  1641. return [self._convert_elem_to_ir_value(elem, require_i64) for elem in list_like]
  1642. return [self._convert_elem_to_ir_value(list_like, require_i64)]
  1643. def make_block_ptr(self, base: TensorTy, shape, strides, offsets, block_shape, order) -> TensorTy:
  1644. # Convert dynamic arguments to IR values
  1645. # NOTES(Chenggang): current `shape/strides` are `int64_t`, while `offsets/block_shape` are `int32_t`
  1646. shape = self._convert_to_ir_values(shape)
  1647. strides = self._convert_to_ir_values(strides)
  1648. offsets = self._convert_to_ir_values(offsets, require_i64=False)
  1649. # Check `base` type
  1650. if not base.type.is_ptr() or base.type.element_ty.is_block():
  1651. raise ValueError("Expected `base` to be a pointer type (but not a block pointer type or others)")
  1652. # Treat `pointer_type<tl.int1>` as `pointer_type<tl.int8>`
  1653. if base.type.element_ty == tl.int1:
  1654. base = self.cast(base, tl.pointer_type(tl.int8, base.type.address_space))
  1655. # Check whether `block_shape` is static
  1656. if not hasattr(block_shape, "__iter__"):
  1657. block_shape = [block_shape]
  1658. block_shape = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in block_shape]
  1659. assert all(isinstance(elem, int) and -2**31 <= elem < 2**31 for elem in block_shape), \
  1660. "Expected a list of constant integers (`int32_t` range) in `block_shape`"
  1661. # Check `order`
  1662. if not hasattr(order, "__iter__"):
  1663. order = [order]
  1664. order = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in order]
  1665. assert sorted(order) == list(range(len(order))), "Expected a permutation of (0, 1, ..., len(order)-1) in order"
  1666. # Must have same length
  1667. assert all(len(block_shape) == len(list_like) for list_like in [shape, strides, offsets, order]), \
  1668. "Expected shape/strides/offsets/block_shape to have the same length"
  1669. # Build value, the type is:
  1670. # `pointer_type<blocked<shape, element_type>>` in Python
  1671. # `tt.ptr<tensor<shape, element_type>>` in MLIR
  1672. handle = self.builder.create_make_block_ptr(base.handle, shape, strides, offsets, block_shape, order)
  1673. return self.tensor(handle, tl.pointer_type(tl.block_type(base.type.element_ty, block_shape)))
  1674. def advance(self, base: TensorTy, offsets) -> TensorTy:
  1675. # Convert dynamic offsets to IR values
  1676. offsets = self._convert_to_ir_values(offsets, require_i64=False)
  1677. # Advanced block pointer type is the same as before
  1678. return self.tensor(self.builder.create_advance(base.handle, offsets), base.type)
  1679. def make_tensor_descriptor(self, base: TensorTy, shape: List[TensorTy], strides: List[TensorTy],
  1680. block_shape: List[tl.constexpr], padding_option: str = "zero") -> tl.tensor_descriptor:
  1681. ndim = len(shape)
  1682. if not (1 <= ndim <= 5):
  1683. raise ValueError(f"Expected 1 <= ndim <= 5 but got {ndim} dimensions")
  1684. if len(strides) != ndim:
  1685. raise ValueError(f"Expected {ndim} strides but got {len(strides)}")
  1686. if len(block_shape) != ndim:
  1687. raise ValueError(f"Expected block_shape to have {ndim} dimensions but got {len(strides)}")
  1688. assert isinstance(base.dtype, tl.pointer_type)
  1689. elem_size = base.dtype.element_ty.primitive_bitwidth // 8
  1690. contig_dim_size = tl._unwrap_if_constexpr(block_shape[-1])
  1691. if contig_dim_size * elem_size < 16:
  1692. raise ValueError(
  1693. f"Descriptor block shape must have at least 16 bytes in the last dimension, but got {contig_dim_size} * {elem_size} = {contig_dim_size * elem_size} bytes"
  1694. )
  1695. last_stride = tl._unwrap_if_constexpr(strides[-1])
  1696. if last_stride != 1:
  1697. raise ValueError(f"Tensor descriptor last dim must be 1 but got {last_stride}")
  1698. shape = [self.make_scalar(x, tl.int32) for x in shape]
  1699. strides = [self.make_scalar(tl._unwrap_if_constexpr(x), tl.int64) for x in strides]
  1700. # Check whether `block_shape` is static
  1701. block_shape = tl._unwrap_shape(block_shape)
  1702. assert isinstance(base.type, tl.pointer_type)
  1703. type = tl.block_type(base.type.element_ty, block_shape)
  1704. base_handle = base.handle
  1705. is_signed_int = base.type.element_ty.is_int_signed()
  1706. padding = self._str_to_padding_option(padding_option)
  1707. if base.type.element_ty.is_int() and padding == ir.PADDING_OPTION.PAD_NAN:
  1708. raise ValueError("Padding option `nan` is not supported for integer blocks")
  1709. handle = self.builder.create_make_tensor_descriptor(base_handle, [s.handle for s in shape],
  1710. [s.handle for s in strides], block_shape, is_signed_int,
  1711. padding)
  1712. return tl.tensor_descriptor(handle, shape, strides, type)