core.py 120 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
7327832793280328132823283328432853286328732883289329032913292329332943295329632973298329933003301330233033304330533063307330833093310331133123313331433153316331733183319332033213322332333243325332633273328332933303331333233333334333533363337333833393340334133423343334433453346334733483349335033513352335333543355335633573358335933603361336233633364336533663367336833693370337133723373337433753376337733783379338033813382338333843385338633873388338933903391339233933394339533963397339833993400340134023403340434053406340734083409341034113412341334143415341634173418341934203421342234233424342534263427342834293430343134323433343434353436343734383439344034413442344334443445344634473448344934503451345234533454345534563457345834593460346134623463346434653466346734683469347034713472347334743475347634773478347934803481348234833484348534863487348834893490
  1. from __future__ import annotations
  2. import math
  3. from warnings import warn
  4. from contextlib import contextmanager
  5. from enum import Enum
  6. from functools import partial, wraps
  7. import typing
  8. from typing import Union, Callable, List, Sequence, TypeVar, Optional, Tuple
  9. from dataclasses import dataclass
  10. import builtins
  11. from .. import knobs
  12. from ..runtime.jit import JITCallable
  13. import inspect
  14. from .._C.libtriton import ir
  15. from .._utils import TRITON_MAX_TENSOR_NUMEL, validate_block_shape, get_primitive_bitwidth
  16. T = TypeVar('T')
  17. TRITON_BUILTIN = "__triton_builtin__"
  18. PropagateNan = ir.PROPAGATE_NAN
  19. def must_use_result(x, s=True):
  20. """If the result of this function is unused, throw an error."""
  21. if isinstance(x, str):
  22. return (lambda fn: must_use_result(fn, x))
  23. x._must_use_result = s
  24. return x
  25. def builtin(fn: T) -> T:
  26. """Mark a function as a builtin."""
  27. assert callable(fn)
  28. @wraps(fn)
  29. def wrapper(*args, **kwargs):
  30. if "_semantic" not in kwargs or kwargs["_semantic"] is None:
  31. raise ValueError("Did you forget to add @triton.jit ? "
  32. "(`_semantic` argument must be provided outside of JIT functions.)")
  33. return fn(*args, **kwargs)
  34. setattr(wrapper, TRITON_BUILTIN, True)
  35. return wrapper
  36. def _tensor_member_fn(fn: T) -> T:
  37. """Decorator that adds this free function as a member fn on class tensor.
  38. When called as a member function on class tensor, the first argument to `fn`
  39. is `self`, i.e. the tensor object.
  40. If there are multiple decorators on a function, you probably want this one
  41. to be the highest one (i.e. furthest from the function's `def`), so it's
  42. applied last.
  43. Unfortunately you still need to add a type stub to the body of class tensor
  44. in order for pytype to know about it.
  45. """
  46. assert callable(fn)
  47. orig_sig = inspect.signature(fn)
  48. # Does fn take args other than _semantic, _generator, and the tensor itself?
  49. has_args = len(orig_sig.parameters.keys() - {"_semantic", "_generator"}) > 1
  50. if not fn.__doc__:
  51. fn.__doc__ = ""
  52. fn.__doc__ += f"""
  53. This function can also be called as a member function on :py:class:`tensor`,
  54. as :code:`x.{fn.__name__}({"..." if has_args else ""})` instead of
  55. :code:`{fn.__name__}(x{", ..." if has_args else ""})`.
  56. """
  57. def wrapper(*args, **kwargs):
  58. return fn(*args, **kwargs)
  59. # Match the signature of `fn`, but change the first arg to `self` so the
  60. # docs are a little less weird.
  61. new_params = list(orig_sig.parameters.values())
  62. new_params[0] = new_params[0].replace(name='self')
  63. new_sig = orig_sig.replace(parameters=new_params)
  64. wrapper.__signature__ = new_sig
  65. wrapper.__doc__ = f"Forwards to :py:func:`{fn.__name__}` free function"
  66. # If fn is a builtin, mark the wrapper as a builtin too.
  67. if is_builtin(fn):
  68. setattr(wrapper, TRITON_BUILTIN, True)
  69. setattr(tensor, fn.__name__, fn if isinstance(fn, JITCallable) else wrapper)
  70. return fn
  71. def _unwrap_iterable(x):
  72. """Returns x[0] if x has one element and x[0] is iterable."""
  73. if len(x) == 1:
  74. # Determine whether x[0] is iterable.
  75. #
  76. # You might want to use collections.abc.Iterable instead of this
  77. # try/except block. Unfortunately, this doesn't work with constexpr.
  78. #
  79. # The problem is that abc.Iterable checks for __iter__ on the *class*.
  80. # But we want constexpr to expose an __iter__ method if and only if the
  81. # wrapped *object* (i.e. self.value) is iterable. Therefore there's no
  82. # right answer for whether the class constexpr defines __iter__, and
  83. # abc.Iterable doesn't work (at least not without some metaclass magic).
  84. try:
  85. iter(x[0])
  86. return x[0]
  87. except TypeError:
  88. pass
  89. return x
  90. def is_builtin(fn) -> bool:
  91. """Is this a registered triton builtin function?"""
  92. return getattr(fn, TRITON_BUILTIN, False)
@builtin
def to_tensor(x, _semantic=None):
    """Forward ``x`` to ``_semantic.to_tensor`` and return its result.

    :param x: the value handed to ``_semantic.to_tensor``.
    :param _semantic: object providing ``to_tensor``; the ``builtin`` wrapper
        raises ``ValueError`` when it is missing or ``None``.
    """
    return _semantic.to_tensor(x)
  96. # -----------------------
  97. # constexpr
  98. # -----------------------
  99. class const:
  100. """
  101. This class is used as a type annotation to mark pointers to constant data.
  102. The `store` function cannot be called with a pointer to const. Constness
  103. is part of the pointer type and the usual Triton type consistency rules
  104. apply. For example you cannot have a function that returns constant pointer
  105. in one return statement and non-constant pointer in another.
  106. """
  107. pass
  108. class base_value:
  109. """Base class of values that exist in the triton IR (i.e. not constexprs).
  110. """
  111. type: base_type
  112. def _flatten_ir(self, handles: List[ir.value]) -> None:
  113. """Flatten frontend value into a sequence of mlir handles, which are appended
  114. to the output list
  115. """
  116. raise NotImplementedError
  117. class base_type:
  118. def __eq__(self, other) -> bool:
  119. raise NotImplementedError("Types must implement __eq__")
  120. def __ne__(self, other) -> bool:
  121. return not (self == other)
  122. def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]:
  123. """Build a frontend value with the current dtype, wrapping a list of existing handles.
  124. cursor is the index of the first handle relevant to this value, and the function
  125. should return the updated cursor position after any handles consumed by the created value.
  126. """
  127. raise NotImplementedError
  128. def mangle(self) -> str:
  129. raise NotImplementedError(f"NYI: Type mangling for type {self.__class__}")
  130. def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
  131. raise NotImplementedError
  132. class constexpr_type(base_type):
  133. def __init__(self, value):
  134. self.value = value
  135. def __eq__(self, other):
  136. return isinstance(other, constexpr_type) and self.value == other.value
  137. def __repr__(self) -> str:
  138. return f"constexpr_type[{self.value}]"
  139. def __hash__(self):
  140. return hash(self.value)
  141. def mangle(self) -> str:
  142. return repr(self)
  143. def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
  144. return
  145. def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]:
  146. return constexpr(self.value), cursor
  147. class constexpr(base_value):
  148. """
  149. This class is used to store a value that is known at compile-time.
  150. """
  151. def __init__(self, value):
  152. while isinstance(value, constexpr):
  153. value = value.value
  154. self.value = value
  155. self.type = constexpr_type(value)
  156. def __repr__(self) -> str:
  157. return f"constexpr[{self.value}]"
  158. def __hash__(self):
  159. return hash((self.value, self.type))
  160. def _flatten_ir(self, handles: List[ir.value]) -> None:
  161. return
  162. def __index__(self):
  163. return self.value
  164. # In interpreter mode, constant values are not wrapped in constexpr,
  165. # and therefore do not have a .value attribute.
  166. # As a result, from here and below, we need to call the _unwrap_if_constexpr
  167. # function to obtain either constexpr.value or the value itself.
  168. def __add__(self, other):
  169. return constexpr(self.value + _unwrap_if_constexpr(other))
  170. def __radd__(self, other):
  171. return constexpr(_unwrap_if_constexpr(other) + self.value)
  172. def __sub__(self, other):
  173. return constexpr(self.value - _unwrap_if_constexpr(other))
  174. def __rsub__(self, other):
  175. return constexpr(_unwrap_if_constexpr(other) - self.value)
  176. def __mul__(self, other):
  177. return constexpr(self.value * _unwrap_if_constexpr(other))
  178. def __mod__(self, other):
  179. return constexpr(self.value % _unwrap_if_constexpr(other))
  180. def __rmul__(self, other):
  181. return constexpr(_unwrap_if_constexpr(other) * self.value)
  182. def __truediv__(self, other):
  183. return constexpr(self.value / _unwrap_if_constexpr(other))
  184. def __rtruediv__(self, other):
  185. return constexpr(_unwrap_if_constexpr(other) / self.value)
  186. def __floordiv__(self, other):
  187. return constexpr(self.value // _unwrap_if_constexpr(other))
  188. def __rfloordiv__(self, other):
  189. return constexpr(_unwrap_if_constexpr(other) // self.value)
  190. def __gt__(self, other):
  191. return constexpr(self.value > _unwrap_if_constexpr(other))
  192. def __rgt__(self, other):
  193. return constexpr(_unwrap_if_constexpr(other) > self.value)
  194. def __ge__(self, other):
  195. return constexpr(self.value >= _unwrap_if_constexpr(other))
  196. def __rge__(self, other):
  197. return constexpr(_unwrap_if_constexpr(other) >= self.value)
  198. def __lt__(self, other):
  199. return constexpr(self.value < _unwrap_if_constexpr(other))
  200. def __rlt__(self, other):
  201. return constexpr(_unwrap_if_constexpr(other) < self.value)
  202. def __le__(self, other):
  203. return constexpr(self.value <= _unwrap_if_constexpr(other))
  204. def __rle__(self, other):
  205. return constexpr(_unwrap_if_constexpr(other) <= self.value)
  206. def __eq__(self, other):
  207. return constexpr(self.value == _unwrap_if_constexpr(other))
  208. def __ne__(self, other):
  209. return constexpr(self.value != _unwrap_if_constexpr(other))
  210. def __bool__(self):
  211. return bool(self.value)
  212. def __neg__(self):
  213. return constexpr(-self.value)
  214. def __and__(self, other):
  215. return constexpr(self.value & _unwrap_if_constexpr(other))
  216. def logical_and(self, other):
  217. return constexpr(self.value and _unwrap_if_constexpr(other))
  218. def __or__(self, other):
  219. return constexpr(self.value | _unwrap_if_constexpr(other))
  220. def __xor__(self, other):
  221. return constexpr(self.value ^ _unwrap_if_constexpr(other))
  222. def logical_or(self, other):
  223. return constexpr(self.value or _unwrap_if_constexpr(other))
  224. def __pos__(self):
  225. return constexpr(+self.value)
  226. def __invert__(self):
  227. return constexpr(~self.value)
  228. def __pow__(self, other):
  229. return constexpr(self.value**_unwrap_if_constexpr(other))
  230. def __rpow__(self, other):
  231. return constexpr(_unwrap_if_constexpr(other)**self.value)
  232. def __rshift__(self, other):
  233. return constexpr(self.value >> _unwrap_if_constexpr(other))
  234. def __lshift__(self, other):
  235. return constexpr(self.value << _unwrap_if_constexpr(other))
  236. def __not__(self):
  237. return constexpr(not self.value)
  238. def __iter__(self):
  239. return iter(self.value)
  240. def __call__(self, *args, **kwds):
  241. return self.value(*args, **kwds)
  242. def __getitem__(self, *args):
  243. args = (_unwrap_if_constexpr(x) for x in _normalize_tuple(args))
  244. return self.value.__getitem__(*args)
  245. CONSTEXPR_0 = constexpr(0)
  246. def _unwrap_if_constexpr(o):
  247. if isinstance(o, list):
  248. return [_unwrap_if_constexpr(x) for x in o]
  249. if isinstance(o, builtins.tuple):
  250. return builtins.tuple(_unwrap_if_constexpr(x) for x in o)
  251. if isinstance(o, tuple):
  252. return tuple(_unwrap_if_constexpr(x) for x in o)
  253. return o.value if isinstance(o, constexpr) else o
  254. def _normalize_tuple(t):
  255. normalized_tuple = _unwrap_if_constexpr(t)
  256. if isinstance(normalized_tuple, (list, builtins.tuple)):
  257. normalized_tuple = tuple(normalized_tuple)
  258. return normalized_tuple
  259. def check_bit_width(value, shift_value):
  260. if isinstance(value, tensor) and isinstance(shift_value, constexpr):
  261. bitwidth = value.type.scalar.primitive_bitwidth
  262. if shift_value.value >= bitwidth:
  263. warn(
  264. f"Value {shift_value.value} exceeds the maximum bitwidth ({bitwidth}) for type '{value.dtype}'. This may result in undefined behavior."
  265. )
  266. # -----------------------
  267. # dtype
  268. # -----------------------
  269. class dtype(base_type):
  270. SINT_TYPES = ['int8', 'int16', 'int32', 'int64']
  271. UINT_TYPES = ['int1', 'uint8', 'uint16', 'uint32', 'uint64']
  272. FP_TYPES = ['fp8e4b15', 'fp8e4nv', 'fp8e4b8', 'fp8e5', 'fp8e5b16', 'fp16', 'bf16', 'fp32', 'fp64']
  273. STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64']
  274. OTHER_TYPES = ['void']
  275. class SIGNEDNESS(Enum):
  276. SIGNED = 0
  277. UNSIGNED = 1
  278. class KIND(Enum):
  279. BOOLEAN = 0
  280. INTEGRAL = 1
  281. FLOATING = 2
  282. def __init__(self, name):
  283. name = _unwrap_if_constexpr(name)
  284. self.name = name
  285. assert name in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES, name
  286. self.primitive_bitwidth = get_primitive_bitwidth(name)
  287. self.itemsize = self.primitive_bitwidth // 8
  288. if name in dtype.SINT_TYPES:
  289. self.int_signedness = dtype.SIGNEDNESS.SIGNED
  290. self.int_bitwidth = self.primitive_bitwidth
  291. elif name in dtype.UINT_TYPES:
  292. self.int_signedness = dtype.SIGNEDNESS.UNSIGNED
  293. self.int_bitwidth = self.primitive_bitwidth
  294. elif name in dtype.FP_TYPES:
  295. if name == 'fp8e4b15':
  296. self.fp_mantissa_width = 3
  297. self.exponent_bias = 15
  298. elif name == 'fp8e4nv':
  299. self.fp_mantissa_width = 3
  300. self.exponent_bias = 7
  301. elif name == 'fp8e4b8':
  302. self.fp_mantissa_width = 3
  303. self.exponent_bias = 8
  304. elif name == 'fp8e5':
  305. self.fp_mantissa_width = 2
  306. self.exponent_bias = 15
  307. elif name == 'fp8e5b16':
  308. self.fp_mantissa_width = 2
  309. self.exponent_bias = 16
  310. elif name == 'fp16':
  311. self.fp_mantissa_width = 10
  312. self.exponent_bias = 15
  313. elif name == 'bf16':
  314. self.fp_mantissa_width = 7
  315. self.exponent_bias = 127
  316. elif name == 'fp32':
  317. self.fp_mantissa_width = 23
  318. self.exponent_bias = 127
  319. elif name == 'fp64':
  320. self.fp_mantissa_width = 52
  321. self.exponent_bias = 1023
  322. else:
  323. raise RuntimeError(f'Unsupported floating-point type {name}')
  324. def is_fp8(self):
  325. return 'fp8' in self.name
  326. def is_fp8e4nv(self):
  327. return self.name == 'fp8e4nv'
  328. def is_fp8e4b8(self):
  329. return self.name == 'fp8e4b8'
  330. def is_fp8e4b15(self):
  331. return self.name == 'fp8e4b15'
  332. def is_fp8e5(self):
  333. return self.name == 'fp8e5'
  334. def is_fp8e5b16(self):
  335. return self.name == 'fp8e5b16'
  336. def is_fp16(self):
  337. return self.name == 'fp16'
  338. def is_bf16(self):
  339. return self.name == 'bf16'
  340. def is_fp32(self):
  341. return self.name == 'fp32'
  342. def is_fp64(self):
  343. return self.name == 'fp64'
  344. def is_int1(self):
  345. return self.name == 'int1'
  346. def is_int8(self):
  347. return self.name == 'int8'
  348. def is_int16(self):
  349. return self.name == 'int16'
  350. def is_int32(self):
  351. return self.name == 'int32'
  352. def is_int64(self):
  353. return self.name == 'int64'
  354. def is_uint8(self):
  355. return self.name == 'uint8'
  356. def is_uint16(self):
  357. return self.name == 'uint16'
  358. def is_uint32(self):
  359. return self.name == 'uint32'
  360. def is_uint64(self):
  361. return self.name == 'uint64'
  362. def is_floating(self):
  363. return self.name in dtype.FP_TYPES
  364. def is_standard_floating(self):
  365. return self.name in dtype.STANDARD_FP_TYPES
  366. def is_int_signed(self):
  367. return self.name in dtype.SINT_TYPES
  368. def is_int_unsigned(self):
  369. return self.name in dtype.UINT_TYPES
  370. def is_int(self):
  371. return self.name in dtype.SINT_TYPES + dtype.UINT_TYPES
  372. def is_bool(self):
  373. return self.is_int1()
  374. def kind(self):
  375. # Return int value following the type ordering bool < integer < fp
  376. if self.is_bool():
  377. return dtype.KIND.BOOLEAN
  378. elif self.is_int():
  379. return dtype.KIND.INTEGRAL
  380. else:
  381. assert self.is_floating()
  382. return dtype.KIND.FLOATING
  383. def get_int_max_value(self):
  384. if self.is_int_signed():
  385. return 2**(self.int_bitwidth - 1) - 1
  386. if self.is_int_unsigned():
  387. return 2**self.int_bitwidth - 1
  388. assert False
  389. def get_int_min_value(self):
  390. if self.is_int_signed():
  391. return -2**(self.int_bitwidth - 1)
  392. if self.is_int_unsigned():
  393. return 0
  394. assert False
  395. @staticmethod
  396. def is_dtype(type_str):
  397. return type_str in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES
  398. @staticmethod
  399. def is_void():
  400. raise RuntimeError("Not implemented")
  401. @staticmethod
  402. def is_block():
  403. return False
  404. @staticmethod
  405. def is_ptr():
  406. return False
  407. @staticmethod
  408. def is_const():
  409. return False
  410. def __eq__(self, other) -> bool:
  411. other = _unwrap_if_constexpr(other)
  412. if not isinstance(other, dtype):
  413. return False
  414. return self.name == other.name
  415. def __hash__(self):
  416. return hash((self.name, ))
  417. @property
  418. def scalar(self):
  419. return self
  420. def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
  421. out.append(self.to_ir(builder))
  422. def to_ir(self, builder: ir.builder) -> ir.type:
  423. if self.name.startswith("fp8"):
  424. if hasattr(builder, "options") and self.name not in builder.options.supported_fp8_dtypes:
  425. raise ValueError(f'type {self} not supported in this architecture. '
  426. f'The supported fp8 dtypes are {builder.options.supported_fp8_dtypes}')
  427. if self.name == 'void':
  428. return builder.get_void_ty()
  429. elif self.name == 'int1':
  430. return builder.get_int1_ty()
  431. elif self.name in ('int8', 'uint8'):
  432. return builder.get_int8_ty()
  433. elif self.name in ('int16', 'uint16'):
  434. return builder.get_int16_ty()
  435. elif self.name in ('int32', 'uint32'):
  436. return builder.get_int32_ty()
  437. elif self.name in ('int64', 'uint64'):
  438. return builder.get_int64_ty()
  439. elif self.name == 'fp8e5':
  440. return builder.get_fp8e5_ty()
  441. elif self.name == 'fp8e5b16':
  442. return builder.get_fp8e5b16_ty()
  443. elif self.name == 'fp8e4nv':
  444. return builder.get_fp8e4nv_ty()
  445. elif self.name == 'fp8e4b8':
  446. return builder.get_fp8e4b8_ty()
  447. elif self.name == 'fp8e4b15':
  448. return builder.get_fp8e4b15_ty()
  449. elif self.name == 'fp16':
  450. return builder.get_half_ty()
  451. elif self.name == 'bf16':
  452. return builder.get_bf16_ty()
  453. elif self.name == 'fp32':
  454. return builder.get_float_ty()
  455. elif self.name == 'fp64':
  456. return builder.get_double_ty()
  457. raise ValueError(f'fail to convert {self} to ir type')
  458. def __str__(self):
  459. return self.name
  def codegen_name(self):
      """Return the name with a leading ``fp``/``bf`` spelled out.

      ``fp...`` becomes ``float...`` and ``bf...`` becomes ``bfloat...``;
      every other name is returned unchanged.
      """
      name = self.name
      for short, spelled_out in (("fp", "float"), ("bf", "bfloat")):
          if name.startswith(short):
              return spelled_out + name[2:]
      return name
  @property
  def cache_key_part(self) -> str:
      """See cache_key_part() in triton.cc."""
      key_part = self.name
      return key_part
  def __repr__(self):
      """Output of repr needs to be an evaluatable expression"""
      public_name = self.codegen_name()
      return f'triton.language.{public_name}'
  def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]:
      """Wrap ``handles[cursor]`` as a tensor of this type.

      :return: the tensor and the cursor advanced by one.
      """
      value = tensor(handles[cursor], self)
      return value, cursor + 1
  def mangle(self) -> str:
      """Return the short name used for this type in mangled names.

      Integers become ``i``/``u`` followed by the bit width, floating-point
      types use their string form, void is ``V``; anything else is delegated
      to the parent class.
      """
      if self.is_int():
          is_signed = self.int_signedness == dtype.SIGNEDNESS.SIGNED
          sign_char = 'i' if is_signed else 'u'
          return sign_char + str(self.int_bitwidth)
      if self.is_floating():
          return str(self)
      if self.is_void():
          return 'V'
      return super().mangle()
  def with_element_ty(self, element_ty: dtype):
      """Return ``element_ty``; asserts that this type is not a block."""
      is_block = self.is_block()
      assert not is_block
      return element_ty
  # Several functions take a parameter called `dtype`, which hides the `dtype`
  # class inside their bodies. The parameter name is part of the public API and
  # cannot be renamed, so this alias gives those functions a way to reach the
  # class.
  _DtypeClass = dtype
  class pointer_type(dtype):
      """Type of a pointer to values of ``element_ty``.

      Two pointer types are equal when their element type, address space and
      constness all match; the hash is built from the same three fields.
      """

      def __init__(self, element_ty: dtype, address_space: int = 1, const: bool = False):
          """
          :param element_ty: pointee type; a constexpr wrapper is unwrapped first.
          :param address_space: address space passed to the IR pointer type.
          :param const: whether the pointer is a const pointer.
          :raises TypeError: if ``element_ty`` is not a ``dtype``.
          """
          element_ty = _unwrap_if_constexpr(element_ty)
          if not isinstance(element_ty, dtype):
              raise TypeError(f'element_ty has type `{type(element_ty).__name__}`; expected `dtype`.')
          self.element_ty = element_ty
          self.address_space = address_space
          self.const = const
          self.name = f'pointer<{element_ty}>' if not const else f'const_pointer<{element_ty}>'

      def to_ir(self, builder: ir.builder) -> ir.pointer_type:
          """Build the IR pointer type from the element's IR type and the address space."""
          return builder.get_ptr_ty(self.element_ty.to_ir(builder), self.address_space)

      def __str__(self):
          return self.name

      def __repr__(self):
          return self.__str__()

      def is_ptr(self):
          return True

      def is_const(self):
          return self.const

      def __eq__(self, other) -> bool:
          other = _unwrap_if_constexpr(other)
          if not isinstance(other, pointer_type):
              return False
          return self.element_ty == other.element_ty and self.address_space == other.address_space and self.const == other.const

      def __hash__(self):
          # Overriding __eq__ without __hash__ makes Python set __hash__ to None,
          # which would leave pointer types unhashable. Hash exactly the fields
          # that __eq__ compares so equal types hash alike.
          return hash((self.element_ty, self.address_space, self.const))

      @property
      def scalar(self):
          return self

      def mangle(self) -> str:
          """Return ``P`` followed by the mangled element type."""
          return f"P{self.element_ty.mangle()}"
  class block_type(dtype):
      """Type of a block of ``element_ty`` values with a fixed, non-empty shape.

      Two block types are equal when element type and shape match; the hash is
      built from the same two fields.
      """

      def __init__(self, element_ty: dtype, shape: List):
          """
          :param element_ty: type of each element of the block.
          :param shape: list or tuple of dimensions; constexpr entries are unwrapped.
          :raises TypeError: if ``shape`` is empty.
          """
          self.element_ty = element_ty

          # Note that block_type's shape is a list of int
          # while tensor's shape is a list of constexpr.
          assert (isinstance(shape, (list, tuple)))

          # An empty shape is rejected: a 0-d value is described by its scalar
          # type rather than by a block_type.
          self.shape = tuple(_unwrap_shape(shape))
          if not self.shape:
              raise TypeError('0d block_type is forbidden')

          self.numel = validate_block_shape(self.shape)
          self.name = f'<{self.shape}, {self.element_ty}>'

      def to_ir(self, builder: ir.builder) -> ir.block_type:
          """Build the IR block type from the element's IR type and the shape."""
          return builder.get_block_ty(self.element_ty.to_ir(builder), self.shape)

      def __str__(self):
          return self.name

      def __repr__(self):
          return self.__str__()

      def is_block(self):
          return True

      def get_block_shapes(self) -> Tuple[int]:
          return self.shape

      def with_element_ty(self, scalar_ty: dtype) -> block_type:
          """Return a block type of the same shape over ``scalar_ty``."""
          return block_type(scalar_ty, self.shape)

      def __eq__(self, other) -> bool:
          if not isinstance(other, block_type):
              return False
          return self.element_ty == other.element_ty and self.shape == other.shape

      def __hash__(self):
          # Overriding __eq__ without __hash__ makes Python set __hash__ to None,
          # which would leave block types unhashable. Hash exactly the fields
          # that __eq__ compares so equal types hash alike.
          return hash((self.element_ty, self.shape))

      @property
      def scalar(self):
          return self.element_ty

      @property
      def nbytes(self):
          """Element count times the element width in whole bytes."""
          return self.numel * (self.element_ty.primitive_bitwidth // 8)

      def mangle(self) -> str:
          """Return ``<elt>S<d0>_<d1>...S`` for the element type and shape."""
          elt = self.scalar.mangle()
          shape = '_'.join(map(str, self.shape))
          return f'{elt}S{shape}S'
  class tuple_type(base_type):
      """Type of a tuple: an ordered list of member types with optional field names."""

      def __init__(self, types, fields=None):
          """
          :param types: member types, in order.
          :param fields: field name per member; defaults to empty names.
          """
          self.types = types
          self.fields = fields or [''] * len(types)
          entries = [f"{k}:{v}" for k, v in zip(self.fields, self.types)]
          self.name = '[' + ','.join(entries) + ']'

      def __str__(self):
          return self.name

      def __iter__(self):
          return iter(self.types)

      def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]):
          """Flatten every member type into ``out``, skipping constexpr members."""
          for member_ty in self.types:
              if isinstance(member_ty, constexpr):
                  continue
              member_ty._flatten_ir_types(builder, out)

      def __getitem__(self, index: int) -> dtype:
          return self.types[index]

      def __eq__(self, other):
          same_class = type(self) is type(other)
          return same_class and self.types == other.types and self.fields == other.fields

      def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tuple, int]:
          """Rebuild one value per member type, consuming handles from ``cursor`` on.

          :return: the tuple value typed with ``self`` and the advanced cursor.
          """
          members = []
          for member_ty in self.types:
              member, cursor = member_ty._unflatten_ir(handles, cursor)
              members.append(member)
          return tuple(members, self), cursor

      def mangle(self):
          """Return the member manglings joined by ``_`` and wrapped in ``T...T``."""
          mangled_members = '_'.join(ty.mangle() for ty in self.types)
          return 'T' + mangled_members + 'T'
  class slice_type(dtype):
      """Type attached to :class:`slice` values; carries only a fixed name."""

      def __init__(self):
          self.name = 'slice_type'
  # scalar types
  void = dtype('void')

  # integers: boolean, signed, then unsigned
  int1 = dtype('int1')
  int8 = dtype('int8')
  int16 = dtype('int16')
  int32 = dtype('int32')
  int64 = dtype('int64')
  uint8 = dtype('uint8')
  uint16 = dtype('uint16')
  uint32 = dtype('uint32')
  uint64 = dtype('uint64')

  # 8-bit floating-point variants
  float8e5 = dtype('fp8e5')
  float8e5b16 = dtype('fp8e5b16')
  float8e4nv = dtype('fp8e4nv')
  float8e4b8 = dtype('fp8e4b8')
  float8e4b15 = dtype('fp8e4b15')

  # 16-bit and wider floating-point types
  float16 = dtype('fp16')
  bfloat16 = dtype('bf16')
  float32 = dtype('fp32')
  float64 = dtype('fp64')

  # pointer types
  pi32_t = pointer_type(int32)
  def get_int_dtype(bitwidth: int, signed: bool) -> dtype:
      """Return the integer dtype with the given bit width and signedness.

      A bit width of 1 always maps to ``int1``, whatever ``signed`` is.

      :raises ValueError: for any bit width other than 1, 8, 16, 32 or 64.
      """
      if bitwidth == 1:
          return int1
      # (width, signed dtype, unsigned dtype)
      candidates = (
          (8, int8, uint8),
          (16, int16, uint16),
          (32, int32, uint32),
          (64, int64, uint64),
      )
      for width, signed_ty, unsigned_ty in candidates:
          if bitwidth == width:
              return signed_ty if signed else unsigned_ty
      raise ValueError(f'Unsupported bitwidth {bitwidth} and signedness {signed}')
  631. # -----------------------
  632. # tensor
  633. # -----------------------
  class tensor(base_value):
      """Represents an N-dimensional array of values or pointers.

      :code:`tensor` is the fundamental data structure in Triton programs. Most
      functions in :py:mod:`triton.language` operate on and return tensors.

      Most of the named member functions here are duplicates of the free functions
      in :code:`triton.language`. For example, :code:`triton.language.sqrt(x)` is
      equivalent to :code:`x.sqrt()`.

      :code:`tensor` also defines most of the magic/dunder methods, so you can
      write :code:`x+y`, :code:`x << 2`, etc.

      .. rubric:: Constructors
      ..
         For some reason Sphinx includes __init__ before printing the full table
         of methods. Not what I want, but I can't figure out how to fix it. Give
         it its own section so it looks intentional. :)
      """

      def __init__(self, handle, type: dtype):
          """Not called by user code."""
          super().__init__()
          # IR handle
          self.handle = handle
          # Tensor type (can be block_type)
          self.type = type
          # Following the practice in pytorch, dtype is scalar type
          self.dtype = type.scalar
          # Block shape as stored on the type; non-block types have an empty shape.
          raw_shape = type.shape if type.is_block() else ()
          self.numel = constexpr(math.prod(raw_shape))
          # The tensor exposes each dimension wrapped in constexpr.
          self.shape = tuple([constexpr(s) for s in raw_shape])

      def _flatten_ir(self, handles: List[ir.value]) -> None:
          """Append this tensor's IR handle to ``handles``."""
          handles.append(self.handle)

      def __str__(self) -> str:
          # ex. "float32[16, 32]"
          dims = ', '.join(str(s) for s in self.shape)
          return str(self.dtype) + '[' + dims + ']'

      # arithmetic operators
      @builtin
      def __add__(self, other, _semantic=None):
          return add(self, other, sanitize_overflow=True, _semantic=_semantic)

      @builtin
      def __radd__(self, other, _semantic=None):
          return add(other, self, sanitize_overflow=True, _semantic=_semantic)

      @builtin
      def __sub__(self, other, _semantic=None):
          return sub(self, other, sanitize_overflow=True, _semantic=_semantic)

      @builtin
      def __rsub__(self, other, _semantic=None):
          return sub(other, self, sanitize_overflow=True, _semantic=_semantic)

      @builtin
      def __mul__(self, other, _semantic=None):
          return mul(self, other, sanitize_overflow=True, _semantic=_semantic)

      @builtin
      def __rmul__(self, other, _semantic=None):
          return mul(other, self, sanitize_overflow=True, _semantic=_semantic)

      @builtin
      def __truediv__(self, other, _semantic=None):
          return _semantic.truediv(self, _unwrap_if_constexpr(other))

      @builtin
      def __rtruediv__(self, other, _semantic=None):
          return _semantic.truediv(_unwrap_if_constexpr(other), self)

      @builtin
      def __floordiv__(self, other, _semantic=None):
          return _semantic.floordiv(self, _unwrap_if_constexpr(other))

      @builtin
      def __rfloordiv__(self, other, _semantic=None):
          return _semantic.floordiv(_unwrap_if_constexpr(other), self)

      @builtin
      def __mod__(self, other, _semantic=None):
          return _semantic.mod(self, _unwrap_if_constexpr(other))

      @builtin
      def __rmod__(self, other, _semantic=None):
          return _semantic.mod(_unwrap_if_constexpr(other), self)

      # unary operators
      @builtin
      def __neg__(self, _semantic=None):
          return _semantic.minus(self)

      @builtin
      def __invert__(self, _semantic=None):
          return _semantic.invert(self)

      # bitwise operators
      @builtin
      def __and__(self, other, _semantic=None):
          return _semantic.and_(self, _unwrap_if_constexpr(other))

      @builtin
      def __rand__(self, other, _semantic=None):
          return _semantic.and_(_unwrap_if_constexpr(other), self)

      @builtin
      def __or__(self, other, _semantic=None):
          return _semantic.or_(self, _unwrap_if_constexpr(other))

      @builtin
      def __ror__(self, other, _semantic=None):
          return _semantic.or_(_unwrap_if_constexpr(other), self)

      @builtin
      def __xor__(self, other, _semantic=None):
          return _semantic.xor_(self, _unwrap_if_constexpr(other))

      @builtin
      def __rxor__(self, other, _semantic=None):
          return _semantic.xor_(_unwrap_if_constexpr(other), self)

      # shifts: the bit-width check runs on the operands as given, before unwrapping
      @builtin
      def __lshift__(self, other, _semantic=None):
          check_bit_width(self, other)
          return _semantic.shl(self, _unwrap_if_constexpr(other))

      @builtin
      def __rlshift__(self, other, _semantic=None):
          check_bit_width(other, self)
          return _semantic.shl(_unwrap_if_constexpr(other), self)

      @builtin
      def __rshift__(self, other, _semantic=None):
          check_bit_width(self, other)
          amount = _unwrap_if_constexpr(other)
          # Signed element types use ashr, all others lshr.
          if not self.dtype.is_int_signed():
              return _semantic.lshr(self, amount)
          return _semantic.ashr(self, amount)

      @builtin
      def __rrshift__(self, other, _semantic=None):
          check_bit_width(other, self)
          shifted = _unwrap_if_constexpr(other)
          if not self.dtype.is_int_signed():
              return _semantic.lshr(shifted, self)
          return _semantic.ashr(shifted, self)

      # >
      @builtin
      def __gt__(self, other, _semantic=None):
          return _semantic.greater_than(self, _semantic.to_tensor(other))

      @builtin
      def __rgt__(self, other, _semantic=None):
          return _semantic.greater_than(_semantic.to_tensor(other), self)

      # >=
      @builtin
      def __ge__(self, other, _semantic=None):
          return _semantic.greater_equal(self, _semantic.to_tensor(other))

      @builtin
      def __rge__(self, other, _semantic=None):
          return _semantic.greater_equal(_semantic.to_tensor(other), self)

      # <
      @builtin
      def __lt__(self, other, _semantic=None):
          return _semantic.less_than(self, _semantic.to_tensor(other))

      @builtin
      def __rlt__(self, other, _semantic=None):
          return _semantic.less_than(_semantic.to_tensor(other), self)

      # <=
      @builtin
      def __le__(self, other, _semantic=None):
          return _semantic.less_equal(self, _semantic.to_tensor(other))

      @builtin
      def __rle__(self, other, _semantic=None):
          return _semantic.less_equal(_semantic.to_tensor(other), self)

      # ==
      @builtin
      def __eq__(self, other, _semantic=None):
          return _semantic.equal(self, _semantic.to_tensor(other))

      @builtin
      def __req__(self, other, _semantic=None):
          return _semantic.equal(_semantic.to_tensor(other), self)

      # !=
      @builtin
      def __ne__(self, other, _semantic=None):
          return _semantic.not_equal(self, _semantic.to_tensor(other))

      @builtin
      def __rne__(self, other, _semantic=None):
          return _semantic.not_equal(_semantic.to_tensor(other), self)

      @builtin
      def logical_and(self, other, _semantic=None):
          return _semantic.logical_and(self, _semantic.to_tensor(other))

      @builtin
      def logical_or(self, other, _semantic=None):
          return _semantic.logical_or(self, _semantic.to_tensor(other))

      # note: __not__ isn't actually a magic method in python
      # but it's ok because our ASTVisitor handles it
      @builtin
      def __not__(self, _semantic=None):
          return _semantic.not_(self)

      @builtin
      def __getitem__(self, slices, _semantic=None):
          """Index with ``None`` and full slices only.

          Each ``None`` inserts a new axis at its position; a slice whose start,
          stop and step are all ``None`` leaves the tensor unchanged.

          :raises ValueError: for any other kind of index.
          """
          if isinstance(slices, (builtins.slice, slice, constexpr)) or slices is None:
              slices = [slices]
          if isinstance(slices, tuple):
              slices = slices.values
          result = self
          for axis, index in enumerate(slices):
              if _unwrap_if_constexpr(index) is None:
                  # an unsqueeze: add a dimension at this position
                  result = _semantic.expand_dims(result, axis)
                  continue
              is_full_slice = isinstance(index, (builtins.slice, slice)) and all(
                  _unwrap_if_constexpr(bound) is None for bound in (index.start, index.stop, index.step))
              if not is_full_slice:
                  raise ValueError(f"unsupported tensor index: {index}")
          return result

      @property
      def T(self):
          """Transposes a 2D tensor."""
          assert False, "Transposition must be created by the AST Visitor"

      @builtin
      def to(self, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _semantic=None):
          """
          Alias for :py:func:`tensor.cast`.
          """
          return cast(self, dtype, fp_downcast_rounding, bitcast, _semantic=_semantic)

      # Type stubs for functions added by the _tensor_member_fn decorator.
      # (Unfortunately these can't be created automatically.)
      #
      # We couldn't write these definitions out even if we wanted to, because some
      # of these functions are defined in standard.py.
      def broadcast_to(self, *shape) -> tensor:
          ...

      def trans(self, *dims) -> tensor:
          ...

      def permute(self, *dims) -> tensor:
          ...

      def split(self) -> tuple[tensor, tensor]:
          ...

      def view(self, *shape) -> tensor:
          ...

      def reshape(self, *shape) -> tensor:
          ...

      def expand_dims(self, axis) -> tensor:
          ...

      def cast(self, dtype, fp_downcast_rounding=None, bitcast=False) -> tensor:
          ...

      def store(self, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="") -> tensor:
          ...

      def advance(self, offsets) -> tensor:
          ...

      def atomic_cas(self, cmp, val, sem=None, scope=None) -> tensor:
          ...

      def atomic_xchg(self, val, mask=None, sem=None, scope=None) -> tensor:
          ...

      def atomic_add(self, val, mask=None, sem=None, scope=None) -> tensor:
          ...

      def atomic_max(self, val, mask=None, sem=None, scope=None) -> tensor:
          ...

      def atomic_min(self, val, mask=None, sem=None, scope=None) -> tensor:
          ...

      def atomic_and(self, val, mask=None, sem=None, scope=None) -> tensor:
          ...

      def atomic_or(self, val, mask=None, sem=None, scope=None) -> tensor:
          ...

      def atomic_xor(self, val, mask=None, sem=None, scope=None) -> tensor:
          ...

      def exp(self) -> tensor:
          ...

      def log(self) -> tensor:
          ...

      def cos(self) -> tensor:
          ...

      def sin(self) -> tensor:
          ...

      def sqrt(self) -> tensor:
          ...

      def rsqrt(self) -> tensor:
          ...

      def abs(self) -> tensor:
          ...

      def reduce(self, axis, combine_fn, keep_dims=False) -> tensor:
          ...

      def associative_scan(self, axis, combine_fn, reverse=False) -> tensor:
          ...

      def gather(self, indices, axis) -> tensor:
          ...

      def histogram(self, num_bins) -> tensor:
          ...

      def cdiv(self, div) -> tensor:
          ...

      def sigmoid(self) -> tensor:
          ...

      def softmax(self, dim=None, keep_dims=False, ieee_rounding=False) -> tensor:
          ...

      def ravel(self) -> tensor:
          ...

      def max(self, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False) -> tensor:
          ...

      def argmax(self, axis, tie_break_left=True, keep_dims=False) -> tensor:
          ...

      def min(self, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False) -> tensor:
          ...

      def argmin(self, axis, tie_break_left=True, keep_dims=False) -> tensor:
          ...

      def sum(self, axis=None, keep_dims=False, dtype=None) -> tensor:
          ...

      def xor_sum(self, axis=None, keep_dims=False) -> tensor:
          ...

      def reduce_or(self, axis=None, keep_dims=False) -> tensor:
          ...

      def cumsum(self, axis=0, reverse=False) -> tensor:
          ...

      def cumprod(self, axis=0, reverse=False) -> tensor:
          ...

      def sort(self, dim: constexpr = None, descending: constexpr = CONSTEXPR_0) -> tensor:
          ...

      def flip(self, dim=None) -> tensor:
          ...
  def _type_for_tuple_values(values, fields=None):
      """Build the ``tuple_type`` describing ``values``.

      ``int``, ``float`` and ``dtype`` members are typed with ``constexpr_type``;
      every other member contributes its own ``.type``.
      """
      member_types = []
      for value in values:
          if isinstance(value, (int, float, dtype)):
              member_types.append(constexpr_type(value))
          else:
              member_types.append(value.type)
      return tuple_type(member_types, fields)
  class tuple(base_value):
      """Triton tuple value: a list of members plus a ``tuple_type`` describing them.

      Members can be read by position, by slice, or by field name when the
      type carries field names.
      """

      def __init__(self, args: Sequence, type: Optional[tuple_type] = None):
          """
          :param args: members of the tuple, copied into a new list.
          :param type: a ``tuple_type``, a sequence of member types to wrap in
              one, or ``None`` to derive the type from the members.
          """
          self.values = [i for i in args]
          if isinstance(type, tuple_type):
              self.type = type
          elif type is not None:  # make_template in ASTFunction.deserialize may pass us a list/tuple
              self.type = tuple_type(type)
          else:
              self.type = _type_for_tuple_values(self.values)

      def __getitem__(self, idx: constexpr):
          """Return one member for an int/constexpr index, or a new tuple for a slice."""
          if isinstance(idx, int):
              idx = constexpr(idx)
          if isinstance(idx, constexpr):
              return self.values[idx]
          else:
              assert isinstance(idx, (slice, builtins.slice))
              return tuple(self.values[idx.start:idx.stop:idx.step])

      def __getattr__(self, name):
          """Return the member whose field name is ``name``.

          Python only calls this after normal attribute lookup has failed.

          :raises AttributeError: if ``name`` is not a field of this tuple.
          """
          # Read ``type`` through __dict__: going through ``self.type`` while it is
          # unset (e.g. on an instance created without __init__) would re-enter
          # __getattr__ without end.
          ty = self.__dict__.get('type')
          if ty is not None:
              try:
                  position = ty.fields.index(name)
              except ValueError:
                  # Fall through to AttributeError so hasattr() and
                  # getattr(obj, name, default) behave as the protocol expects.
                  pass
              else:
                  return self.values[position]
          raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

      # TODO: remove
      def _setitem(self, idx, value):
          """Replace the member at ``idx`` and recompute the type, keeping field names."""
          idx = _unwrap_if_constexpr(idx)
          assert isinstance(idx, int)
          self.values[idx] = value
          self.type = _type_for_tuple_values(self.values, self.type.fields)

      def __add__(self, other):
          """Concatenate with ``other`` (not element-wise addition)."""
          other = _normalize_tuple(other)
          return tuple(self.values + other.values)
          # return tuple(a + b for a, b in zip(self.values, other.values))

      def __mul__(self, other):
          """Repeat the members ``other.value`` times; ``other`` must be a constexpr."""
          assert isinstance(other, constexpr)
          return tuple(self.values * other.value)

      def __eq__(self, other):
          other = _normalize_tuple(other)
          return constexpr(self.values == other.values)

      def __hash__(self):
          return hash(builtins.tuple(self.values))

      def __str__(self):
          return str([str(x) for x in self.values])

      def __iter__(self):
          return iter(self.values)

      def __len__(self):
          return len(self.values)

      def _flatten_ir(self, handles: List[ir.value]):
          """Flatten every member into ``handles``, in order."""
          for v in self.values:
              v._flatten_ir(handles)

      def __repr__(self):
          return f"({', '.join(repr(x) for x in self.values)})"
  class slice:
      """Triton slice value holding ``start``, ``stop`` and ``step`` as given."""

      def __init__(self, start, stop, step):
          self.start, self.stop, self.step = start, stop, step
          self.type = slice_type()
  class tensor_descriptor_base_type(base_type):
      """Type of a tensor descriptor over blocks of ``block_type``."""

      def __init__(self, block_type: block_type):
          self.block_type = block_type

      def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor_base, int]:
          """Wrap ``handles[cursor]`` as a descriptor value.

          :return: the descriptor and the cursor advanced by one.
          """
          value = tensor_descriptor_base(handles[cursor], self.block_type)
          return value, cursor + 1

      def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
          """Append the IR descriptor type, tagged with the element signedness, to ``out``."""
          is_signed = self.block_type.element_ty.is_int_signed()
          out.append(builder.create_tensor_descriptor_type(self.block_type.to_ir(builder), is_signed))

      def __str__(self) -> str:
          # ex. "tensor_descriptor<float32[16, 32]>"
          return f"tensor_descriptor<{self.block_type}>"

      def __eq__(self, other) -> bool:
          # Exact class match: a subclass instance never equals a base instance.
          if type(other) is not type(self):
              return False
          return self.block_type == other.block_type

      def __ne__(self, other) -> bool:
          return not (self == other)

      # The method above used to be spelled ``__neq__``, which Python never calls
      # for ``!=``. The old name stays as an alias for any direct callers.
      __neq__ = __ne__

      def mangle(self) -> str:
          """Return ``TD`` followed by the mangled block type."""
          return f"TD{self.block_type.mangle()}"
  class tensor_descriptor_base(base_value):
      """
      A tensor descriptor with unknown shape and strides
      """

      def __init__(self, handle, block_type: block_type):
          """Not called by user code."""
          super().__init__()
          # IR handle
          self.handle = handle
          # Descriptor type wrapping the block type
          self.type = tensor_descriptor_base_type(block_type)

      def _flatten_ir(self, handles: List[ir.value]) -> None:
          """Append this descriptor's IR handle to ``handles``."""
          handles.append(self.handle)

      @property
      def block_type(self):
          return self.type.block_type

      @property
      def block_shape(self):
          return self.block_type.shape

      @property
      def dtype(self):
          return self.block_type.element_ty

      def __str__(self) -> str:
          return str(self.type)

      @builtin
      def load(self, offsets: Sequence[constexpr | tensor], _semantic=None) -> tensor:
          """Load a block from the descriptor starting at the given element offsets.

          Values outside of the tensor bounds will be filled with zeros.

          :note: Offset must be a multiple of 16-bytes
          """
          return _semantic.descriptor_load(self, offsets, "", "")

      @builtin
      def store(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
          """Store a block to the descriptor starting at the given element offsets.

          Values outside of the tensor bounds will be ignored.

          :note: Offset must be a multiple of 16-bytes
          """
          return _semantic.descriptor_store(self, value, offsets)

      # Atomic read-modify-write variants; each forwards to the matching
      # ``_semantic.descriptor_atomic_*`` call.
      @builtin
      def atomic_add(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
          return _semantic.descriptor_atomic_add(self, value, offsets)

      @builtin
      def atomic_min(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
          return _semantic.descriptor_atomic_min(self, value, offsets)

      @builtin
      def atomic_max(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
          return _semantic.descriptor_atomic_max(self, value, offsets)

      @builtin
      def atomic_and(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
          return _semantic.descriptor_atomic_and(self, value, offsets)

      @builtin
      def atomic_or(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
          return _semantic.descriptor_atomic_or(self, value, offsets)

      @builtin
      def atomic_xor(self, offsets: Sequence[constexpr | tensor], value: tensor, _semantic=None) -> tensor:
          return _semantic.descriptor_atomic_xor(self, value, offsets)

      @builtin
      def gather(self, *args, _semantic=None) -> tensor:
          """Gather multiple descriptors worth of data"""
          assert len(args) == 2, f"descriptor gather only supports 2D indexing, but got {len(args)}"
          x_offsets, y_offset = args[0], args[1]
          return _semantic.descriptor_gather(self, x_offsets, y_offset, "", "")

      @builtin
      def scatter(self, value, *args, _semantic=None) -> tensor:
          """Scatter multiple descriptors worth of data"""
          assert len(args) == 2, f"descriptor scatter only supports 2D indexing, but got {len(args)}"
          x_offsets, y_offset = args[0], args[1]
          return _semantic.descriptor_scatter(self, value, x_offsets, y_offset)
  class tensor_descriptor_type(tensor_descriptor_base_type):
      """Descriptor type that also records the types of its shape and strides."""

      def __init__(self, block_type: block_type, shape_type: tuple_type, strides_type: tuple_type):
          self.block_type = block_type
          self.shape_type = shape_type
          self.strides_type = strides_type

      def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor_base, int]:
          """Rebuild a descriptor from its handle followed by shape and stride handles.

          :return: the descriptor and the cursor positioned after everything consumed.
          """
          handle = handles[cursor]
          shape, cursor = self.shape_type._unflatten_ir(handles, cursor + 1)
          strides, cursor = self.strides_type._unflatten_ir(handles, cursor)
          descriptor = tensor_descriptor(handle, shape.values, strides.values, self.block_type)
          return descriptor, cursor

      def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
          """Emit the descriptor IR type, then the shape types, then the stride types."""
          super()._flatten_ir_types(builder, out)
          for part_type in (self.shape_type, self.strides_type):
              part_type._flatten_ir_types(builder, out)

      def __eq__(self, other):
          return (super().__eq__(other) and (self.shape_type == other.shape_type)
                  and (self.strides_type == other.strides_type))
  class tensor_descriptor(tensor_descriptor_base):
      """A descriptor representing a tensor in global memory.
      """

      def __init__(self, handle, shape: List[tensor], strides: List[tensor], block_type: block_type):
          """Not called by user code."""
          # IR handle
          super().__init__(handle, block_type)
          # Global shape
          global_shape = tuple(shape)
          global_strides = tuple(strides)
          self.shape = global_shape
          self.strides = global_strides
          # Replace the base type with one that also carries shape/stride types.
          self.type = tensor_descriptor_type(
              block_type,
              shape_type=global_shape.type,
              strides_type=global_strides.type,
          )

      def _flatten_ir(self, handles: List[ir.value]) -> None:
          """Append the handle, then the shape members, then the stride members."""
          handles.append(self.handle)
          for part in (self.shape, self.strides):
              part._flatten_ir(handles)
  1135. # -----------------------
  1136. # aggregate
  1137. # -----------------------
  @dataclass(frozen=True)
  class _aggregate_type(base_type):
      """A generic base type for all Triton aggregate types.

      This class contains a reference to the original user-defined Python class
      and a list of class fields with their Triton types.
      """

      # Class used to create instances when unflattening.
      base_cls: type
      # (field name, field type) pairs, in flattening order.
      fields: List[Tuple[str, base_type]]

      def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[ir.value, int]:
          """Create an instance of ``base_cls`` and fill each field from ``handles``.

          :return: the instance and the cursor positioned after the last field.
          """
          instance = self.base_cls._get_instance()
          for field_name, field_ty in self.fields:
              field_value, cursor = field_ty._unflatten_ir(handles, cursor)
              setattr(instance, field_name, field_value)
          return instance, cursor

      def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
          """Flatten the IR types of every field into ``out``, in field order."""
          for _, field_ty in self.fields:
              field_ty._flatten_ir_types(builder, out)

      def mangle(self) -> str:
          """Return ``<module>.<qualname><f0, f1, ...>`` using the mangled field types."""
          qualified = f"{self.base_cls.__module__}.{self.base_cls.__qualname__}"
          mangled_fields = ', '.join(field_ty.mangle() for _, field_ty in self.fields)
          return f"{qualified}<{mangled_fields}>"
  def _aggregate(cls):
      """Wrap the user-defined class ``cls`` in a Triton aggregate value class.

      The returned class runs ``cls.__init__`` on construction, restricts
      attributes to the names annotated on ``cls``, and copies over the
      functions, methods and JIT callables of ``cls`` other than ``__init__``.
      """

      # Define the wrapped Triton value type.
      class aggregate_value(base_value):
          __triton_builtin__ = True
          __triton_aggregate__ = True

          @classmethod
          def _get_instance(this_cls):
              return super().__new__(this_cls)

          def __new__(this_cls, *args, _semantic=None, _generator=None, **kwargs):
              # Call into the user-defined constructor.
              instance = this_cls._get_instance()
              extra_kwargs = {}
              # `_semantic` / `_generator` are forwarded to a plain Python
              # constructor only when it declares them; a JITCallable constructor
              # gets neither.
              if not isinstance(cls.__init__, JITCallable):
                  init_params = inspect.signature(cls.__init__).parameters
                  if "_semantic" in init_params:
                      extra_kwargs["_semantic"] = _semantic
                  if "_generator" in init_params:
                      extra_kwargs["_generator"] = _generator
              cls.__init__(instance, *args, **extra_kwargs, **kwargs)

              # Require that the user-defined constructor initialized all fields.
              for field_name in cls.__annotations__.keys():
                  if not hasattr(instance, field_name):
                      raise AttributeError(f"constructor for {cls.__name__} did not initialize attribute '{field_name}'")
              return instance

          # Only allow setting attributes defined in the class annotations.
          def __setattr__(self, name, value):
              annotations = cls.__annotations__
              if name not in annotations:
                  raise AttributeError(f"{cls.__name__} has no attribute '{name}'")
              if not isinstance(value, annotations[name]):
                  raise TypeError(f"Expected {annotations[name]} for attribute '{name}', got {type(value)}")
              super().__setattr__(name, value)

          def _flatten_ir(self, handles: List[ir.value]) -> None:
              for field_name in cls.__annotations__.keys():
                  getattr(self, field_name)._flatten_ir(handles)

          @property
          def type(self):
              field_types = [(field_name, getattr(self, field_name).type) for field_name in cls.__annotations__.keys()]
              return _aggregate_type(aggregate_value, field_types)

      # Copy the user's callables onto the wrapper and record them, together
      # with the constructor, in `hash_attrs`.
      hash_attrs = [cls.__init__]
      for member_name, member in inspect.getmembers(cls):
          is_callable_member = (inspect.isfunction(member) or inspect.ismethod(member)
                                or isinstance(member, JITCallable))
          if not is_callable_member or member_name == "__init__":
              continue
          setattr(aggregate_value, member_name, member)
          hash_attrs.append(member)

      aggregate_value.hash_attrs = hash_attrs
      # Make the wrapper present itself under the user's class identity.
      for attr in ("__name__", "__module__", "__qualname__", "__doc__"):
          setattr(aggregate_value, attr, getattr(cls, attr))
      return aggregate_value
  1211. # -----------------------
  1212. # SPMD Programming Model
  1213. # -----------------------
  @builtin
  def program_id(axis, _semantic=None):
      """
      Returns the id of the current program instance along the given :code:`axis`.

      :param axis: The axis of the 3D launch grid. Must be 0, 1 or 2.
      :type axis: int
      """
      # `axis` may arrive wrapped in a constexpr; the semantic layer gets the raw value.
      return _semantic.program_id(_unwrap_if_constexpr(axis))
  1230. @builtin
  1231. def num_programs(axis, _semantic=None):
  1232. """
  1233. Returns the number of program instances launched along the given :code:`axis`.
  1234. :param axis: The axis of the 3D launch grid. Must be 0, 1 or 2.
  1235. :type axis: int
  1236. """
  1237. axis = _unwrap_if_constexpr(axis)
  1238. return _semantic.num_programs(axis)
  1239. # -----------------------
  1240. # Block Initialization
  1241. # -----------------------
  1242. @builtin
  1243. def arange(start, end, _semantic=None):
  1244. start = _unwrap_if_constexpr(start)
  1245. end = _unwrap_if_constexpr(end)
  1246. return _semantic.arange(start, end)
  1247. arange.__doc__ = f"""
  1248. Returns contiguous values within the half-open interval :code:`[start,
  1249. end)`. :code:`end - start` must be less than or equal to
  1250. :code:`TRITON_MAX_TENSOR_NUMEL = {TRITON_MAX_TENSOR_NUMEL}`
  1251. :param start: Start of the interval. Must be a power of two.
  1252. :type start: int32
  1253. :param end: End of the interval. Must be a power of two greater than
  1254. :code:`start`.
  1255. :type end: int32
  1256. """
  1257. def _unwrap_shape(shape):
  1258. shape = _unwrap_if_constexpr(shape)
  1259. return [_unwrap_if_constexpr(s) for s in shape]
  1260. def _shape_check_impl(shape):
  1261. shape = _unwrap_shape(shape)
  1262. validate_block_shape(shape)
  1263. return shape
  1264. @builtin
  1265. def full(shape, value, dtype, _semantic=None):
  1266. """
  1267. Returns a tensor filled with the scalar value for the given :code:`shape` and :code:`dtype`.
  1268. :param shape: Shape of the new array, e.g., (8, 16) or (8, )
  1269. :type shape: tuple of ints
  1270. :param value: A scalar value to fill the array with
  1271. :type value: scalar
  1272. :param dtype: Data type of the new array, e.g., :code:`tl.float16`
  1273. :type dtype: tl.dtype
  1274. """
  1275. shape = _shape_check_impl(shape)
  1276. value = _unwrap_if_constexpr(value)
  1277. dtype = _unwrap_if_constexpr(dtype)
  1278. return _semantic.full(shape, value, dtype)
  1279. # -----------------------
  1280. # Shape Manipulation
  1281. # -----------------------
  1282. @builtin
  1283. def broadcast(input, other, _semantic=None):
  1284. """
  1285. Tries to broadcast the two given blocks to a common compatible shape.
  1286. :param input: The first input tensor.
  1287. :type input: Block
  1288. :param other: The second input tensor.
  1289. :type other: Block
  1290. """
  1291. return _semantic.broadcast_impl_value(input, other)
  1292. @_tensor_member_fn
  1293. @builtin
  1294. def broadcast_to(input, *shape, _semantic=None):
  1295. """
  1296. Tries to broadcast the given tensor to a new :code:`shape`.
  1297. :param input: The input tensor.
  1298. :type input: Block
  1299. :param shape: The desired shape.
  1300. :type shape:
  1301. :code:`shape` can be passed as a tuple or as individual parameters: ::
  1302. # These are equivalent
  1303. broadcast_to(x, (32, 32))
  1304. broadcast_to(x, 32, 32)
  1305. """
  1306. shape = _shape_check_impl(_unwrap_iterable(shape))
  1307. return _semantic.broadcast_impl_shape(input, shape)
  1308. @_tensor_member_fn
  1309. @builtin
  1310. def trans(input: tensor, *dims, _semantic=None):
  1311. """
  1312. Permutes the dimensions of a tensor.
  1313. If the parameter :code:`dims` is not specified, the function defaults to
  1314. swapping the last two axes, thereby performing an (optionally batched)
  1315. 2D transpose.
  1316. :param input: The input tensor.
  1317. :param dims: The desired ordering of dimensions. For example,
  1318. :code:`(2, 1, 0)` reverses the order dims in a 3D tensor.
  1319. :code:`dims` can be passed as a tuple or as individual parameters: ::
  1320. # These are equivalent
  1321. trans(x, (2, 1, 0))
  1322. trans(x, 2, 1, 0)
  1323. :py:func:`permute` is equivalent to this function, except it doesn't
  1324. have the special case when no permutation is specified.
  1325. """
  1326. dims = _unwrap_iterable(dims)
  1327. if not dims:
  1328. n = len(input.shape)
  1329. if n < 2:
  1330. raise ValueError("tl.trans invoked with a 0- or 1-dimensional tensor")
  1331. dims = list(builtins.range(n - 2)) + [n - 1, n - 2]
  1332. return _semantic.permute(input, dims)
  1333. @_tensor_member_fn
  1334. @builtin
  1335. def permute(input, *dims, _semantic=None):
  1336. """
  1337. Permutes the dimensions of a tensor.
  1338. :param input: The input tensor.
  1339. :type input: Block
  1340. :param dims: The desired ordering of dimensions. For example,
  1341. :code:`(2, 1, 0)` reverses the order dims in a 3D tensor.
  1342. :code:`dims` can be passed as a tuple or as individual parameters: ::
  1343. # These are equivalent
  1344. permute(x, (2, 1, 0))
  1345. permute(x, 2, 1, 0)
  1346. :py:func:`trans` is equivalent to this function, except when
  1347. :code:`dims` is empty, it tries to swap the last two axes.
  1348. """
  1349. dims = _unwrap_iterable(dims)
  1350. return _semantic.permute(input, dims)
  1351. @builtin
  1352. def cat(input, other, can_reorder=False, _semantic=None):
  1353. """
  1354. Concatenate the given blocks
  1355. :param input: The first input tensor.
  1356. :type input: Tensor
  1357. :param other: The second input tensor.
  1358. :type other: Tensor
  1359. :param reorder: Compiler hint. If true, the compiler is
  1360. allowed to reorder elements while concatenating inputs. Only use if the
  1361. order does not matter (e.g., result is only used in reduction ops).
  1362. Current implementation of `cat` supports only can_reorder=True.
  1363. """
  1364. return _semantic.cat(input, other, can_reorder)
  1365. @builtin
  1366. def join(a, b, _semantic=None):
  1367. """
  1368. Join the given tensors in a new, minor dimension.
  1369. For example, given two tensors of shape (4,8), produces a new tensor of
  1370. shape (4,8,2). Given two scalars, returns a tensor of shape (2).
  1371. The two inputs are broadcasted to be the same shape.
  1372. If you want to join more than two elements, you can use multiple calls to
  1373. this function. This reflects the constraint in Triton that tensors must
  1374. have power-of-two sizes.
  1375. join is the inverse of split.
  1376. :param a: The first input tensor.
  1377. :type a: Tensor
  1378. :param b: The second input tensor.
  1379. :type b: Tensor
  1380. """
  1381. return _semantic.join(a, b)
  1382. def _unsplat(x, _semantic=None, _generator=None):
  1383. """
  1384. Convert a single-element tensor to a scalar.
  1385. """
  1386. if len(x.shape) == 0:
  1387. return x
  1388. numel = 1
  1389. for d in x.shape:
  1390. numel *= d
  1391. assert numel == 1, "can only unsplat single-element tensors"
  1392. return _semantic.unsplat(x)
  1393. @_tensor_member_fn
  1394. @builtin
  1395. def split(a, _semantic=None, _generator=None) -> tuple[tensor, tensor]:
  1396. """
  1397. Split a tensor in two along its last dim, which must have size 2.
  1398. For example, given a tensor of shape (4,8,2), produces two tensors of shape
  1399. (4,8). Given a tensor of shape (2), returns two scalars.
  1400. If you want to split into more than two pieces, you can use multiple calls
  1401. to this function (probably plus calling reshape). This reflects the
  1402. constraint in Triton that tensors must have power-of-two sizes.
  1403. split is the inverse of join.
  1404. :param a: The tensor to split.
  1405. :type a: Tensor
  1406. """
  1407. # If len(a.shape) == 1, i.e. a.shape == [2], we should return two scalars.
  1408. # But _semantic.split can only handle returning tensors. Work around this by
  1409. # expanding the input to shape [1,2] and then reducing the result.
  1410. was_rank_1 = len(a.shape) == 1
  1411. if was_rank_1:
  1412. a = _semantic.expand_dims(a, 0)
  1413. out_lhs, out_rhs = _semantic.split(a)
  1414. if was_rank_1:
  1415. # Currently `reduce` is the best way to convert a tensor of shape [1] to a scalar.
  1416. out_lhs = _unsplat(out_lhs, _semantic=_semantic, _generator=_generator)
  1417. out_rhs = _unsplat(out_rhs, _semantic=_semantic, _generator=_generator)
  1418. return out_lhs, out_rhs
  1419. @_tensor_member_fn
  1420. @builtin
  1421. def view(input, *shape, _semantic=None):
  1422. """
  1423. Returns a tensor with the same elements as `input` but a different shape.
  1424. The order of the elements may not be preserved.
  1425. :param input: The input tensor.
  1426. :type input: Block
  1427. :param shape: The desired shape.
  1428. :code:`shape` can be passed as a tuple or as individual parameters: ::
  1429. # These are equivalent
  1430. view(x, (32, 32))
  1431. view(x, 32, 32)
  1432. """
  1433. warn("view is deprecated, please use reshape with can_reorder being true.")
  1434. shape = _shape_check_impl(_unwrap_iterable(shape))
  1435. return _semantic.reshape(input, shape, can_reorder=True)
  1436. @_tensor_member_fn
  1437. @builtin
  1438. def item(input, _semantic=None, _generator=None):
  1439. """
  1440. Converts a single-element tensor into a scalar.
  1441. """
  1442. return _unsplat(input, _semantic=_semantic, _generator=_generator)
  1443. @_tensor_member_fn
  1444. @builtin
  1445. def reshape(input, *shape, can_reorder=False, _semantic=None, _generator=None):
  1446. """
  1447. Returns a tensor with the same number of elements as input but with the
  1448. provided shape.
  1449. :param input: The input tensor.
  1450. :type input: Block
  1451. :param shape: The new shape.
  1452. :code:`shape` can be passed as a tuple or as individual parameters: ::
  1453. # These are equivalent
  1454. reshape(x, (32, 32))
  1455. reshape(x, 32, 32)
  1456. """
  1457. shape = _shape_check_impl(_unwrap_iterable(shape))
  1458. if len(shape) == 0:
  1459. return _unsplat(input, _semantic=_semantic, _generator=_generator)
  1460. return _semantic.reshape(input, shape, can_reorder)
  1461. def _wrap_axis(axis, ndim):
  1462. if not (-ndim <= axis < ndim):
  1463. raise ValueError(f"invalid axis {axis}. Expected {-ndim} <= axis < {ndim}")
  1464. return axis if axis >= 0 else axis + ndim
  1465. @_tensor_member_fn
  1466. @builtin
  1467. def expand_dims(input, axis, _semantic=None):
  1468. """
  1469. Expand the shape of a tensor, by inserting new length-1 dimensions.
  1470. Axis indices are with respect to the resulting tensor, so
  1471. ``result.shape[axis]`` will be 1 for each axis.
  1472. :param input: The input tensor.
  1473. :type input: tl.tensor
  1474. :param axis: The indices to add new axes
  1475. :type axis: int | Sequence[int]
  1476. """
  1477. input = _semantic.to_tensor(input)
  1478. axis = _unwrap_if_constexpr(axis)
  1479. axes = list(axis) if isinstance(axis, (Sequence, tuple)) else [axis]
  1480. new_ndim = len(input.shape) + len(axes)
  1481. axes = [_wrap_axis(_unwrap_if_constexpr(d), new_ndim) for d in axes]
  1482. if len(set(axes)) != len(axes):
  1483. raise ValueError(f"expand_dims received duplicate axes, normalized axes = {axes}")
  1484. ret = input
  1485. for a in sorted(axes):
  1486. ret = _semantic.expand_dims(ret, a)
  1487. return ret
  1488. @_tensor_member_fn
  1489. @builtin
  1490. def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _semantic=None):
  1491. """
  1492. Casts a tensor to the given :code:`dtype`.
  1493. :param dtype: The target data type.
  1494. :type dtype: tl.dtype
  1495. :param fp_downcast_rounding: The rounding mode for downcasting
  1496. floating-point values. This parameter is only used when self is a
  1497. floating-point tensor and dtype is a floating-point type with a
  1498. smaller bitwidth. Supported values are :code:`"rtne"` (round to
  1499. nearest, ties to even) and :code:`"rtz"` (round towards zero).
  1500. :type fp_downcast_rounding: str, optional
  1501. :param bitcast: If true, the tensor is bitcasted to the given
  1502. :code:`dtype`, instead of being numerically casted.
  1503. :type bitcast: bool, optional
  1504. """
  1505. input = _semantic.to_tensor(input)
  1506. dtype = _unwrap_if_constexpr(dtype)
  1507. fp_downcast_rounding = _unwrap_if_constexpr(fp_downcast_rounding)
  1508. bitcast = _unwrap_if_constexpr(bitcast)
  1509. if bitcast:
  1510. return _semantic.bitcast(input, dtype)
  1511. return _semantic.cast(input, dtype, fp_downcast_rounding)
  1512. # -----------------------
  1513. # Linear Algebra
  1514. # -----------------------
  1515. @builtin
  1516. def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=float32,
  1517. _semantic=None):
  1518. """
  1519. Returns the matrix product of two blocks.
  1520. The two blocks must both be two-dimensional or three-dimensional and have compatible inner dimensions.
  1521. For three-dimensional blocks, `tl.dot` performs the batched matrix product,
  1522. where the first dimension of each block represents the batch dimension.
  1523. :param input: The first tensor to be multiplied.
  1524. :type input: 2D or 3D tensor of scalar-type in {:code:`int8`, :code:`float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`}
  1525. :param other: The second tensor to be multiplied.
  1526. :type other: 2D or 3D tensor of scalar-type in {:code:`int8`, :code:`float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`}
  1527. :param acc: The accumulator tensor. If not None, the result is added to this tensor.
  1528. :type acc: 2D or 3D tensor of scalar-type in {:code:`float16`, :code:`float32`, :code:`int32`}
  1529. :param input_precision: How to exercise the Tensor Cores for f32 x f32. If
  1530. the device does not have Tensor Cores or the inputs are not of dtype f32,
  1531. this option is ignored. For devices that do have tensor cores, the
  1532. default precision is tf32.
  1533. :type input_precision: string. Available options for nvidia: :code:`"tf32"`, :code:`"tf32x3"`, :code:`"ieee"`. Default: :code:`"tf32"`. Available options for amd: :code:`"ieee"`, (CDNA3 only) :code:`"tf32"`.
  1534. :param allow_tf32: *Deprecated.* If true, input_precision is set to "tf32".
  1535. Only one of :code:`input_precision` and :code:`allow_tf32` can be
  1536. specified (i.e. at least one must be :code:`None`).
  1537. """
  1538. assert input_precision is None or allow_tf32 is None, "Only one of input_precision and allow_tf32 can be specified"
  1539. if input_precision is None:
  1540. supports_tf32 = "tf32" in _semantic.builder.options.allowed_dot_input_precisions
  1541. input_precision = knobs.language.fp32_default or ("tf32" if (supports_tf32 and
  1542. (allow_tf32 or allow_tf32 is None)) else "ieee")
  1543. input_precision = _unwrap_if_constexpr(input_precision)
  1544. out_dtype = _unwrap_if_constexpr(out_dtype)
  1545. max_num_imprecise_acc = _unwrap_if_constexpr(max_num_imprecise_acc)
  1546. acc = _unwrap_if_constexpr(acc)
  1547. # check shapes make sense:
  1548. a_shape = list(input.shape)
  1549. b_shape = list(other.shape)
  1550. assert len(a_shape) == len(b_shape) >= 2, "input and other must have equal ranks >= 2"
  1551. assert a_shape[:-2] == b_shape[:-2], "input and other must have equal batch shapes"
  1552. assert a_shape[-1] == b_shape[-2], "input and other must have equal reduction dimensions"
  1553. # compute shape of accumulator:
  1554. c_shape = a_shape[:-1] + [b_shape[-1]]
  1555. if acc is not None:
  1556. assert list(acc.shape) == c_shape, "accumulator shape is incompatible"
  1557. rank = len(c_shape)
  1558. if rank >= 4:
  1559. batch_size = 1
  1560. for i in builtins.range(rank - 2):
  1561. batch_size *= c_shape[i]
  1562. input = _semantic.reshape(input, [batch_size] + a_shape[-2:], can_reorder=False)
  1563. other = _semantic.reshape(other, [batch_size] + b_shape[-2:], can_reorder=False)
  1564. if acc is not None:
  1565. acc = _semantic.reshape(acc, [batch_size] + c_shape[-2:], can_reorder=False)
  1566. res = _semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype)
  1567. if rank >= 4:
  1568. res = _semantic.reshape(res, c_shape, can_reorder=False)
  1569. assert list(res.shape) == c_shape, "output shape is unexpected"
  1570. return res
  1571. @builtin
  1572. def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, fast_math=False, lhs_k_pack=True,
  1573. rhs_k_pack=True, out_dtype=float32, _semantic=None):
  1574. """
  1575. Returns the matrix product of two blocks in microscaling format.
  1576. lhs and rhs use microscaling formats described here:
  1577. https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
  1578. Software emulation enables targeting hardware architectures without native microscaling
  1579. operation support. Right now for such case, microscaled lhs/rhs are upcasted to
  1580. :code:`bf16` element type beforehand for dot computation, with one exception:
  1581. for AMD CDNA3 specifically, if one of the inputs is of :code:`fp16` element type,
  1582. the other input is also upcasted to :code:`fp16` element type instead.
  1583. This behavior is experimental and may be subject to change in the future.
  1584. :param lhs: The first tensor to be multiplied.
  1585. :type lhs: 2D tensor representing fp4, fp8 or bf16 elements. Fp4 elements are packed into uint8 inputs with the first element in lower bits. Fp8 are stored as uint8 or the corresponding fp8 type.
  1586. :param lhs_scale: Scale factor for lhs tensor. Shape should be [M, K//group_size] when lhs is [M, K], where group_size is 32 if scales type are `e8m0`.
  1587. :type lhs_scale: e8m0 type represented as an uint8 tensor, or None.
  1588. :param lhs_format: format of the lhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`, :code:`fp16`}.
  1589. :type lhs_format: str
  1590. :param rhs: The second tensor to be multiplied.
  1591. :type rhs: 2D tensor representing fp4, fp8 or bf16 elements. Fp4 elements are packed into uint8 inputs with the first element in lower bits. Fp8 are stored as uint8 or the corresponding fp8 type.
  1592. :param rhs_scale: Scale factor for rhs tensor. Shape should be [N, K//group_size] where rhs is [K, N].
  1593. Important: Do NOT transpose rhs_scale
  1594. :type rhs_scale: e8m0 type represented as an uint8 tensor, or None.
  1595. :param rhs_format: format of the rhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`, :code:`fp16`}.
  1596. :type rhs_format: str
  1597. :param acc: The accumulator tensor. If not None, the result is added to this tensor.
  1598. :param lhs_k_pack: If false, the lhs tensor is packed into uint8 along M dimension.
  1599. :type lhs_k_pack: bool, optional
  1600. :param rhs_k_pack: If false, the rhs tensor is packed into uint8 along N dimension.
  1601. :type rhs_k_pack: bool, optional
  1602. """
  1603. out_dtype = _unwrap_if_constexpr(out_dtype)
  1604. acc = _unwrap_if_constexpr(acc)
  1605. assert out_dtype == float32, "Only float32 is supported for out_dtype at the moment"
  1606. return _semantic.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc, fast_math, lhs_k_pack,
  1607. rhs_k_pack, out_dtype)
  1608. # -----------------------
  1609. # Non-Atomic Memory Operations
  1610. # -----------------------
  1611. @builtin
  1612. def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", cache_modifier="", eviction_policy="",
  1613. volatile=False, _semantic=None):
  1614. """
  1615. Return a tensor of data whose values are loaded from memory at location defined by `pointer`:
  1616. (1) If `pointer` is a single element pointer, a scalar is be loaded. In
  1617. this case:
  1618. - `mask` and `other` must also be scalars,
  1619. - `other` is implicitly typecast to `pointer.dtype.element_ty`, and
  1620. - `boundary_check` and `padding_option` must be empty.
  1621. (2) If `pointer` is an N-dimensional tensor of pointers, an
  1622. N-dimensional tensor is loaded. In this case:
  1623. - `mask` and `other` are implicitly broadcast to `pointer.shape`,
  1624. - `other` is implicitly typecast to `pointer.dtype.element_ty`, and
  1625. - `boundary_check` and `padding_option` must be empty.
  1626. (3) If `pointer` is a block pointer defined by `make_block_ptr`, a
  1627. tensor is loaded. In this case:
  1628. - `mask` and `other` must be `None`, and
  1629. - `boundary_check` and `padding_option` can be specified to control the behavior of out-of-bound access.
  1630. :param pointer: Pointer to the data to be loaded
  1631. :type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType`
  1632. :param mask: if `mask[idx]` is false, do not load the data at address `pointer[idx]`
  1633. (must be `None` with block pointers)
  1634. :type mask: Block of `triton.int1`, optional
  1635. :param other: if `mask[idx]` is false, return `other[idx]`
  1636. :type other: Block, optional
  1637. :param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check
  1638. :type boundary_check: tuple of ints, optional
  1639. :param padding_option: should be one of {"", "zero", "nan"}, the padding value to use while out of bounds. "" means an undefined value.
  1640. :param cache_modifier: changes cache option in NVIDIA PTX
  1641. :type cache_modifier: str, optional, should be one of {"", ".ca", ".cg", ".cv"}, where ".ca" stands for
  1642. cache at all levels, ".cg" stands for cache at global level (cache in L2 and below, not L1),
  1643. and ".cv" means don’t cache and fetch again. see
  1644. `cache operator <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cache-operators>`_ for more details.
  1645. :param eviction_policy: changes eviction policy in NVIDIA PTX
  1646. :type eviction_policy: str, optional
  1647. :param volatile: changes volatile option in NVIDIA PTX
  1648. :type volatile: bool, optional
  1649. """
  1650. # `mask` and `other` can be constexpr
  1651. mask = _unwrap_if_constexpr(mask)
  1652. other = _unwrap_if_constexpr(other)
  1653. if mask is not None:
  1654. mask = _semantic.to_tensor(mask)
  1655. if other is not None:
  1656. other = _semantic.to_tensor(other)
  1657. padding_option = _unwrap_if_constexpr(padding_option)
  1658. cache_modifier = _unwrap_if_constexpr(cache_modifier)
  1659. eviction_policy = _unwrap_if_constexpr(eviction_policy)
  1660. volatile = _unwrap_if_constexpr(volatile)
  1661. return _semantic.load(pointer, mask, other, boundary_check, padding_option, cache_modifier, eviction_policy,
  1662. volatile)
  1663. @builtin
  1664. def load_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[constexpr | tensor],
  1665. _semantic=None) -> tensor:
  1666. """Load a block of data from a tensor descriptor."""
  1667. return desc.load(offsets, _semantic=_semantic)
  1668. @builtin
  1669. def store_tensor_descriptor(desc: tensor_descriptor_base, offsets: Sequence[constexpr | tensor], value: tensor,
  1670. _semantic=None) -> tensor:
  1671. """Store a block of data to a tensor descriptor."""
  1672. return desc.store(offsets, value, _semantic=_semantic)
  1673. @_tensor_member_fn
  1674. @builtin
  1675. def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="", _semantic=None):
  1676. """
  1677. Store a tensor of data into memory locations defined by `pointer`.
  1678. (1) If `pointer` is a single element pointer, a scalar is stored. In
  1679. this case:
  1680. - `mask` must also be scalar, and
  1681. - `boundary_check` and `padding_option` must be empty.
  1682. (2) If `pointer` is an N-dimensional tensor of pointers, an
  1683. N-dimensional block is stored. In this case:
  1684. - `mask` is implicitly broadcast to `pointer.shape`, and
  1685. - `boundary_check` must be empty.
  1686. (3) If `pointer` is a block pointer defined by `make_block_ptr`, a block
  1687. of data is stored. In this case:
  1688. - `mask` must be None, and
  1689. - `boundary_check` can be specified to control the behavior of out-of-bound access.
  1690. `value` is implicitly broadcast to `pointer.shape` and typecast to `pointer.dtype.element_ty`.
  1691. :param pointer: The memory location where the elements of `value` are stored
  1692. :type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType`
  1693. :param value: The tensor of elements to be stored
  1694. :type value: Block
  1695. :param mask: If `mask[idx]` is false, do not store `value[idx]` at `pointer[idx]`
  1696. :type mask: Block of triton.int1, optional
  1697. :param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check
  1698. :type boundary_check: tuple of ints, optional
  1699. :param cache_modifier: changes cache option in NVIDIA PTX
  1700. :type cache_modifier: str, optional, should be one of {"", ".wb", ".cg", ".cs", ".wt"}, where ".wb" stands for
  1701. cache write-back all coherent levels, ".cg" stands for cache global, ".cs" stands for cache streaming, ".wt"
  1702. stands for cache write-through, see `cache operator <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cache-operators>`_ for more details.
  1703. :param eviction_policy: changes eviction policy in NVIDIA PTX
  1704. :type eviction_policy: str, optional, should be one of {"", "evict_first", "evict_last"}
  1705. """
  1706. # `value` can be constexpr
  1707. value = _semantic.to_tensor(value)
  1708. mask = _unwrap_if_constexpr(mask)
  1709. if mask is not None:
  1710. mask = _semantic.to_tensor(mask)
  1711. cache_modifier = _unwrap_if_constexpr(cache_modifier)
  1712. eviction_policy = _unwrap_if_constexpr(eviction_policy)
  1713. return _semantic.store(pointer, value, mask, boundary_check, cache_modifier, eviction_policy)
  1714. @builtin
  1715. def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order, _semantic=None):
  1716. """
  1717. Returns a pointer to a block in a parent tensor
  1718. :param base: The base pointer to the parent tensor
  1719. :param shape: The shape of the parent tensor
  1720. :param strides: The strides of the parent tensor
  1721. :param offsets: The offsets to the block
  1722. :param block_shape: The shape of the block
  1723. :param order: The order of the original data format
  1724. """
  1725. return _semantic.make_block_ptr(base, shape, strides, offsets, block_shape, order)
  1726. @must_use_result(
  1727. "Note that tl.advance does not have any side effects. To move the block pointer, you need to assign the result of tl.advance to a variable."
  1728. )
  1729. @_tensor_member_fn
  1730. @builtin
  1731. def advance(base, offsets, _semantic=None):
  1732. """
  1733. Advance a block pointer
  1734. :param base: the block pointer to advance
  1735. :param offsets: the offsets to advance, a tuple by dimension
  1736. """
  1737. return _semantic.advance(base, offsets)
  1738. @builtin
  1739. def make_tensor_descriptor(
  1740. base: tensor,
  1741. shape: List[tensor],
  1742. strides: List[tensor],
  1743. block_shape: List[constexpr],
  1744. padding_option="zero",
  1745. _semantic=None,
  1746. ) -> tensor_descriptor:
  1747. """Make a tensor descriptor object
  1748. :param base: the base pointer of the tensor, must be 16-byte aligned
  1749. :param shape: A list of non-negative integers representing the tensor shape
  1750. :param strides: A list of tensor strides. Leading dimensions must be multiples
  1751. of 16-byte strides and the last dimension must be contiguous.
  1752. :param block_shape: The shape of block to be loaded/stored from global memory
  1753. Notes
  1754. *****
  1755. On NVIDIA GPUs with TMA support, this will result in a TMA descriptor object
  1756. and loads and stores from the descriptor will be backed by the TMA hardware.
  1757. Currently only 2-5 dimensional tensors are supported.
  1758. Example
  1759. *******
  1760. .. code-block:: python
  1761. @triton.jit
  1762. def inplace_abs(in_out_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
  1763. desc = tl.make_tensor_descriptor(
  1764. in_out_ptr,
  1765. shape=[M, N],
  1766. strides=[N, 1],
  1767. block_shape=[M_BLOCK, N_BLOCK],
  1768. )
  1769. moffset = tl.program_id(0) * M_BLOCK
  1770. noffset = tl.program_id(1) * N_BLOCK
  1771. value = desc.load([moffset, noffset])
  1772. desc.store([moffset, noffset], tl.abs(value))
  1773. # TMA descriptors require a global memory allocation
  1774. def alloc_fn(size: int, alignment: int, stream: Optional[int]):
  1775. return torch.empty(size, device="cuda", dtype=torch.int8)
  1776. triton.set_allocator(alloc_fn)
  1777. M, N = 256, 256
  1778. x = torch.randn(M, N, device="cuda")
  1779. M_BLOCK, N_BLOCK = 32, 32
  1780. grid = (M / M_BLOCK, N / N_BLOCK)
  1781. inplace_abs[grid](x, M, N, M_BLOCK, N_BLOCK)
  1782. """
  1783. padding_option = _unwrap_if_constexpr(padding_option)
  1784. return _semantic.make_tensor_descriptor(base, shape, strides, block_shape, padding_option)
  1785. # -----------------------
  1786. # Atomic Memory Operations
  1787. # -----------------------
  1788. def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]:
  1789. def _decorator(func: T) -> T:
  1790. docstr = f"""
  1791. Performs an atomic {name} at the memory location specified by :code:`pointer`.
  1792. Return the data stored at :code:`pointer` before the atomic operation.
  1793. :param pointer: The memory locations to operate on
  1794. :type pointer: Block of dtype=triton.PointerDType"""
  1795. if has_cmp:
  1796. docstr += """
  1797. :param cmp: The values expected to be found in the atomic object
  1798. :type cmp: Block of dtype=pointer.dtype.element_ty"""
  1799. docstr += """
  1800. :param val: The values with which to perform the atomic operation
  1801. :type val: Block of dtype=pointer.dtype.element_ty
  1802. :param sem: Specifies the memory semantics for the operation. Acceptable values are "acquire",
  1803. "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided,
  1804. the function defaults to using "acq_rel" semantics.
  1805. :type sem: str, optional
  1806. :param scope: Defines the scope of threads that observe the synchronizing effect of the atomic operation.
  1807. Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu".
  1808. :type scope: str, optional
  1809. """
  1810. func.__doc__ = docstr
  1811. return func
  1812. return _decorator
  1813. @_tensor_member_fn
  1814. @builtin
  1815. @_add_atomic_docstr("compare-and-swap", has_cmp=True)
  1816. def atomic_cas(pointer, cmp, val, sem=None, scope=None, _semantic=None):
  1817. cmp = _semantic.to_tensor(cmp)
  1818. val = _semantic.to_tensor(val)
  1819. sem = _unwrap_if_constexpr(sem)
  1820. scope = _unwrap_if_constexpr(scope)
  1821. return _semantic.atomic_cas(pointer, cmp, val, sem, scope)
  1822. @_tensor_member_fn
  1823. @builtin
  1824. @_add_atomic_docstr("exchange")
  1825. def atomic_xchg(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
  1826. val = _semantic.to_tensor(val)
  1827. sem = _unwrap_if_constexpr(sem)
  1828. scope = _unwrap_if_constexpr(scope)
  1829. mask = _unwrap_if_constexpr(mask)
  1830. return _semantic.atomic_xchg(pointer, val, mask, sem, scope)
  1831. @_tensor_member_fn
  1832. @builtin
  1833. @_add_atomic_docstr("add")
  1834. def atomic_add(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
  1835. val = _semantic.to_tensor(val)
  1836. sem = _unwrap_if_constexpr(sem)
  1837. scope = _unwrap_if_constexpr(scope)
  1838. mask = _unwrap_if_constexpr(mask)
  1839. return _semantic.atomic_add(pointer, val, mask, sem, scope)
  1840. @_tensor_member_fn
  1841. @builtin
  1842. @_add_atomic_docstr("max")
  1843. def atomic_max(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
  1844. val = _semantic.to_tensor(val)
  1845. sem = _unwrap_if_constexpr(sem)
  1846. scope = _unwrap_if_constexpr(scope)
  1847. mask = _unwrap_if_constexpr(mask)
  1848. return _semantic.atomic_max(pointer, val, mask, sem, scope)
  1849. @_tensor_member_fn
  1850. @builtin
  1851. @_add_atomic_docstr("min")
  1852. def atomic_min(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
  1853. val = _semantic.to_tensor(val)
  1854. sem = _unwrap_if_constexpr(sem)
  1855. scope = _unwrap_if_constexpr(scope)
  1856. mask = _unwrap_if_constexpr(mask)
  1857. return _semantic.atomic_min(pointer, val, mask, sem, scope)
  1858. @_tensor_member_fn
  1859. @builtin
  1860. @_add_atomic_docstr("logical and")
  1861. def atomic_and(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
  1862. val = _semantic.to_tensor(val)
  1863. sem = _unwrap_if_constexpr(sem)
  1864. scope = _unwrap_if_constexpr(scope)
  1865. mask = _unwrap_if_constexpr(mask)
  1866. return _semantic.atomic_and(pointer, val, mask, sem, scope)
  1867. @_tensor_member_fn
  1868. @builtin
  1869. @_add_atomic_docstr("logical or")
  1870. def atomic_or(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
  1871. val = _semantic.to_tensor(val)
  1872. sem = _unwrap_if_constexpr(sem)
  1873. scope = _unwrap_if_constexpr(scope)
  1874. mask = _unwrap_if_constexpr(mask)
  1875. return _semantic.atomic_or(pointer, val, mask, sem, scope)
  1876. @_tensor_member_fn
  1877. @builtin
  1878. @_add_atomic_docstr("logical xor")
  1879. def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _semantic=None):
  1880. val = _semantic.to_tensor(val)
  1881. sem = _unwrap_if_constexpr(sem)
  1882. scope = _unwrap_if_constexpr(scope)
  1883. mask = _unwrap_if_constexpr(mask)
  1884. return _semantic.atomic_xor(pointer, val, mask, sem, scope)
  1885. # -----------------------
  1886. # Conditioning
  1887. # -----------------------
  1888. @builtin
  1889. def where(condition, x, y, _semantic=None):
  1890. """
  1891. Returns a tensor of elements from either :code:`x` or :code:`y`, depending on :code:`condition`.
  1892. Note that :code:`x` and :code:`y` are always evaluated regardless of the value of :code:`condition`.
  1893. If you want to avoid unintended memory operations, use the :code:`mask` arguments in `triton.load` and `triton.store` instead.
  1894. The shape of :code:`x` and :code:`y` are both broadcast to the shape of :code:`condition`.
  1895. :code:`x` and :code:`y` must have the same data type.
  1896. :param condition: When True (nonzero), yield x, otherwise yield y.
  1897. :type condition: Block of triton.bool
  1898. :param x: values selected at indices where condition is True.
  1899. :param y: values selected at indices where condition is False.
  1900. """
  1901. condition = _semantic.to_tensor(condition)
  1902. x = _unwrap_if_constexpr(x)
  1903. y = _unwrap_if_constexpr(y)
  1904. return _semantic.where(condition, x, y)
  1905. # -----------------------
  1906. # Math
  1907. # -----------------------
  1908. @builtin
  1909. def add(x, y, sanitize_overflow: constexpr = True, _semantic=None):
  1910. x = _unwrap_if_constexpr(x)
  1911. y = _unwrap_if_constexpr(y)
  1912. return _semantic.add(x, y, sanitize_overflow)
  1913. @builtin
  1914. def sub(x, y, sanitize_overflow: constexpr = True, _semantic=None):
  1915. x = _unwrap_if_constexpr(x)
  1916. y = _unwrap_if_constexpr(y)
  1917. return _semantic.sub(x, y, sanitize_overflow)
  1918. @builtin
  1919. def mul(x, y, sanitize_overflow: constexpr = True, _semantic=None):
  1920. x = _unwrap_if_constexpr(x)
  1921. y = _unwrap_if_constexpr(y)
  1922. return _semantic.mul(x, y, sanitize_overflow)
  1923. @builtin
  1924. def minimum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _semantic=None):
  1925. """
  1926. Computes the element-wise minimum of :code:`x` and :code:`y`.
  1927. :param x: the first input tensor
  1928. :type x: Block
  1929. :param y: the second input tensor
  1930. :type y: Block
  1931. :param propagate_nan: whether to propagate NaN values.
  1932. :type propagate_nan: tl.PropagateNan
  1933. .. seealso:: :class:`tl.PropagateNan`
  1934. """
  1935. x = _semantic.to_tensor(x)
  1936. y = _semantic.to_tensor(y)
  1937. x = _promote_bfloat16_to_float32(x, _semantic=_semantic)
  1938. y = _promote_bfloat16_to_float32(y, _semantic=_semantic)
  1939. propagate_nan = _unwrap_if_constexpr(propagate_nan)
  1940. return _semantic.minimum(x, y, propagate_nan)
  1941. @builtin
  1942. def maximum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _semantic=None):
  1943. """
  1944. Computes the element-wise maximum of :code:`x` and :code:`y`.
  1945. :param x: the first input tensor
  1946. :type x: Block
  1947. :param y: the second input tensor
  1948. :type y: Block
  1949. :param propagate_nan: whether to propagate NaN values.
  1950. :type propagate_nan: tl.PropagateNan
  1951. .. seealso:: :class:`tl.PropagateNan`
  1952. """
  1953. x = _semantic.to_tensor(x)
  1954. y = _semantic.to_tensor(y)
  1955. x = _promote_bfloat16_to_float32(x, _semantic=_semantic)
  1956. y = _promote_bfloat16_to_float32(y, _semantic=_semantic)
  1957. propagate_nan = _unwrap_if_constexpr(propagate_nan)
  1958. return _semantic.maximum(x, y, propagate_nan)
  1959. @builtin
  1960. def clamp(x, min, max, propagate_nan: constexpr = PropagateNan.NONE, _semantic=None):
  1961. """
  1962. Clamps the input tensor :code:`x` within the range [min, max].
  1963. Behavior when :code:`min` > :code:`max` is undefined.
  1964. :param x: the input tensor
  1965. :type x: Block
  1966. :param min: the lower bound for clamping
  1967. :type min: Block
  1968. :param max: the upper bound for clamping
  1969. :type max: Block
  1970. :param propagate_nan: whether to propagate NaN values. Applies only to the :code:`x` tensor.
  1971. If either :code:`min` or :code:`max` is NaN, the result is undefined.
  1972. :type propagate_nan: tl.PropagateNan
  1973. .. seealso:: :class:`tl.PropagateNan`
  1974. """
  1975. x = _semantic.to_tensor(x)
  1976. min = _semantic.to_tensor(min)
  1977. max = _semantic.to_tensor(max)
  1978. x = _promote_bfloat16_to_float32(x, _semantic=_semantic)
  1979. min = _promote_bfloat16_to_float32(min, _semantic=_semantic)
  1980. max = _promote_bfloat16_to_float32(max, _semantic=_semantic)
  1981. propagate_nan = _unwrap_if_constexpr(propagate_nan)
  1982. return _semantic.clamp(x, min, max, propagate_nan)
  1983. # -----------------------
  1984. # Reductions
  1985. # -----------------------
  1986. def _add_reduction_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None,
  1987. dtype_arg: str = None) -> Callable[[T], T]:
  1988. def _decorator(func: T) -> T:
  1989. docstr = """
  1990. Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis`
  1991. :param input: the input values
  1992. :type input: Tensor
  1993. :param axis: the dimension along which the reduction should be done. If None, reduce all dimensions
  1994. :type axis: int
  1995. :param keep_dims: if true, keep the reduced dimensions with length 1
  1996. :type keep_dims: bool"""
  1997. if return_indices_arg is not None:
  1998. docstr += f"""
  1999. :param {return_indices_arg}: if true, return index corresponding to the {name} value
  2000. :type {return_indices_arg}: bool"""
  2001. if tie_break_arg is not None:
  2002. docstr += f"""
  2003. :param {tie_break_arg}: if true, in case of a tie (i.e., multiple elements have the same {name} value), return the left-most index for values that aren't NaN
  2004. :type {tie_break_arg}: bool"""
  2005. if dtype_arg is not None:
  2006. docstr += f"""
  2007. :param {dtype_arg}: the desired data type of the returned tensor. If specified, the input tensor is casted to :code:`{dtype_arg}` before the operation is performed. This is useful for preventing data overflows. If not specified, integer and bool dtypes are upcasted to :code:`tl.int32` and float dtypes are upcasted to at least :code:`tl.float32`.
  2008. :type {dtype_arg}: tl.dtype"""
  2009. func.__doc__ = docstr.format(name=name)
  2010. return func
  2011. return _decorator
  2012. @contextmanager
  2013. def _insertion_guard(builder):
  2014. ip = builder.get_insertion_point()
  2015. yield
  2016. builder.restore_insertion_point(ip)
  2017. @_tensor_member_fn
  2018. @builtin
  2019. def reduce(input, axis, combine_fn, keep_dims=False, _semantic=None, _generator=None):
  2020. """Applies the combine_fn to all elements in :code:`input` tensors along the provided :code:`axis`
  2021. :param input: the input tensor, or tuple of tensors
  2022. :type input: Tensor
  2023. :param axis: the dimension along which the reduction should be done. If None, reduce all dimensions
  2024. :type axis: int | None
  2025. :param combine_fn: a function to combine two groups of scalar tensors (must be marked with @triton.jit)
  2026. :type combine_fn: Callable
  2027. :param keep_dims: if true, keep the reduced dimensions with length 1
  2028. :type keep_dims: bool
  2029. """
  2030. if isinstance(input, tensor):
  2031. return reduce((input, ), axis, combine_fn, keep_dims=keep_dims, _semantic=_semantic, _generator=_generator)[0]
  2032. def make_combine_region(reduce_op):
  2033. param_types = [t.type.scalar for t in input] * 2
  2034. region = reduce_op.get_region(0)
  2035. builder = _semantic.builder
  2036. with _insertion_guard(builder):
  2037. to_ir = lambda T: T.to_ir(builder)
  2038. block = builder.create_block_with_parent(region, list(map(to_ir, param_types)))
  2039. args = [tensor(block.arg(i), ty) for i, ty in enumerate(param_types)]
  2040. results = _generator.call_JitFunction(combine_fn, args, kwargs={})
  2041. if isinstance(results, tensor):
  2042. handles = [results.handle]
  2043. else:
  2044. handles = [r.handle for r in results]
  2045. builder.create_reduce_ret(*handles)
  2046. def expand_ndims(t, ndims):
  2047. for _ in builtins.range(ndims):
  2048. t = expand_dims(t, 0, _semantic=_semantic)
  2049. return t
  2050. axis = _unwrap_if_constexpr(axis)
  2051. keep_dims = _unwrap_if_constexpr(keep_dims)
  2052. if axis is not None:
  2053. axis = _wrap_axis(axis, len(input[0].shape))
  2054. ret = _semantic.reduction(input, axis, make_combine_region)
  2055. if keep_dims:
  2056. if axis is not None:
  2057. ret = tuple(expand_dims(t, axis, _semantic=_semantic) for t in ret)
  2058. else:
  2059. ret = tuple(expand_ndims(t, len(input[0].shape)) for t in ret)
  2060. return ret
  2061. @builtin
  2062. def _promote_bfloat16_to_float32(t, _semantic=None):
  2063. scalar_ty = t.type.scalar
  2064. # hardware doesn't support FMAX, FMIN, CMP for bfloat16
  2065. if scalar_ty is bfloat16:
  2066. return t.to(float32, _semantic=_semantic)
  2067. return t
  2068. @builtin
  2069. def _reduce_with_indices(input, axis, combine_fn, keep_dims=False, _semantic=None, _generator=None):
  2070. axis = _unwrap_if_constexpr(axis)
  2071. n = input.shape[axis]
  2072. index = arange(0, n, _semantic=_semantic)
  2073. if len(input.shape) > 1:
  2074. # Broadcast index across the non-reduced axes
  2075. axes_to_expand = [constexpr(d) for d in builtins.range(len(input.shape))]
  2076. del axes_to_expand[axis]
  2077. index = expand_dims(index, axes_to_expand, _semantic=_semantic)
  2078. index = broadcast_to(index, input.shape, _semantic=_semantic)
  2079. rvalue, rindices = reduce((input, index), axis, combine_fn, keep_dims=keep_dims, _semantic=_semantic,
  2080. _generator=_generator)
  2081. return rvalue, rindices
  2082. # -----------------------
  2083. # Scans
  2084. # -----------------------
  2085. def _add_scan_docstr(name: str, dtype_arg: str = None) -> Callable[[T], T]:
  2086. def _decorator(func: T) -> T:
  2087. docstr = """
  2088. Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis`
  2089. :param input: the input values
  2090. :type input: Tensor
  2091. :param axis: the dimension along which the scan should be done
  2092. :type axis: int
  2093. :param reverse: if true, the scan is performed in the reverse direction
  2094. :type reverse: bool"""
  2095. if dtype_arg is not None:
  2096. docstr += f"""
  2097. :param {dtype_arg}: the desired data type of the returned tensor. If specified, the input tensor is casted to :code:`{dtype_arg}` before the operation is performed. If not specified, small integer types (< 32 bits) are upcasted to prevent overflow. Note that :code:`tl.bfloat16` inputs are automatically promoted to :code:`tl.float32`.
  2098. :type {dtype_arg}: tl.dtype"""
  2099. func.__doc__ = docstr.format(name=name)
  2100. return func
  2101. return _decorator
  2102. @_tensor_member_fn
  2103. @builtin
  2104. def associative_scan(input, axis, combine_fn, reverse=False, _semantic=None, _generator=None):
  2105. """Applies the combine_fn to each elements with a carry in :code:`input` tensors along the provided :code:`axis` and update the carry
  2106. :param input: the input tensor, or tuple of tensors
  2107. :type input: Tensor
  2108. :param axis: the dimension along which the reduction should be done
  2109. :type axis: int
  2110. :param combine_fn: a function to combine two groups of scalar tensors (must be marked with @triton.jit)
  2111. :type combine_fn: Callable
  2112. :param reverse: whether to apply the associative scan in the reverse direction along axis
  2113. :type reverse: bool
  2114. """
  2115. if isinstance(input, tensor):
  2116. return associative_scan((input, ), axis, combine_fn, reverse, _semantic=_semantic, _generator=_generator)[0]
  2117. def make_combine_region(scan_op):
  2118. param_types = [t.type.scalar for t in input] * 2
  2119. region = scan_op.get_region(0)
  2120. builder = _semantic.builder
  2121. with _insertion_guard(builder):
  2122. to_ir = lambda T: T.to_ir(builder)
  2123. block = builder.create_block_with_parent(region, list(map(to_ir, param_types)))
  2124. args = [tensor(block.arg(i), ty) for i, ty in enumerate(param_types)]
  2125. results = _generator.call_JitFunction(combine_fn, args, kwargs={})
  2126. if isinstance(results, tensor):
  2127. handles = [results.handle]
  2128. else:
  2129. handles = [r.handle for r in results]
  2130. builder.create_scan_ret(*handles)
  2131. axis = _unwrap_if_constexpr(axis)
  2132. if axis is not None:
  2133. axis = _wrap_axis(axis, len(input[0].shape))
  2134. return _semantic.associative_scan(input, axis, make_combine_region, reverse)
  2135. @_tensor_member_fn
  2136. @builtin
  2137. def histogram(input, num_bins, mask=None, _semantic=None, _generator=None):
  2138. """computes an histogram based on input tensor with num_bins bins, the bins have a width of 1 and start at 0.
  2139. :param input: the input tensor
  2140. :type input: Tensor
  2141. :param num_bins: number of histogram bins
  2142. :type num_bins: int
  2143. :param mask: if `mask[idx]` is false, exclude `input[idx]` from histogram
  2144. :type mask: Block of `triton.int1`, optional
  2145. """
  2146. num_bins = _unwrap_if_constexpr(num_bins)
  2147. mask = _unwrap_if_constexpr(mask)
  2148. if mask is not None:
  2149. mask = _semantic.to_tensor(mask)
  2150. return _semantic.histogram(input, num_bins, mask)
  2151. @_tensor_member_fn
  2152. @builtin
  2153. def gather(src, index, axis, _semantic=None):
  2154. """Gather from a tensor along a given dimension.
  2155. :param src: the source tensor
  2156. :type src: Tensor
  2157. :param index: the index tensor
  2158. :type index: Tensor
  2159. :param axis: the dimension to gather along
  2160. :type axis: int
  2161. """
  2162. src = _unwrap_if_constexpr(src)
  2163. index = _unwrap_if_constexpr(index)
  2164. axis = _unwrap_if_constexpr(axis)
  2165. return _semantic.gather(src, index, axis)
  2166. @builtin
  2167. def map_elementwise(
  2168. scalar_fn: Callable[..., Tuple[tensor, ...]],
  2169. *args: tensor,
  2170. pack=1,
  2171. _semantic=None,
  2172. _generator=None,
  2173. ):
  2174. '''
  2175. Map a scalar function over a tensor.
  2176. The input tensors :code:`args` are implicitly broadcasted to the same shape.
  2177. This may be useful in allowing control flow over single elements in a tensor,
  2178. for example a multi-branch function where one branch is more expensive. With
  2179. :code:`tl.where` you are forced to calculate both sides of the branch, but
  2180. with an if we only execute one side.
  2181. .. highlight:: python
  2182. .. code-block:: python
  2183. @triton.jit
  2184. def selu_scalar(x, alpha):
  2185. if x > 0:
  2186. return a
  2187. else:
  2188. return alpha * (tl.exp(x) - 1)
  2189. @triton.jit
  2190. def selu(x, alpha):
  2191. return tl.map_elementwise(selu_scalar, x, alpha)
  2192. :param scalar_fn: the function to map over.
  2193. :param pack: the number of elements to be processed by one function call.
  2194. :return: one tensor or a tuple of tensors, depending on the mapped function.
  2195. '''
  2196. # Build the block for the nested region first to discover the return types
  2197. assert pack >= 1
  2198. in_scalar_tys = [t.type.scalar for t in args]
  2199. builder = _semantic.builder
  2200. block = builder.new_block()
  2201. scalar_args = []
  2202. original_loc = builder.get_loc()
  2203. for i, ty in enumerate(in_scalar_tys):
  2204. for j in builtins.range(pack):
  2205. block.add_argument_at(ty.to_ir(builder), original_loc)
  2206. scalar_args.append(tensor(block.arg(i * pack + j), ty))
  2207. with _insertion_guard(builder):
  2208. builder.set_insertion_point_to_start(block)
  2209. scalar_results = _generator.call_JitFunction(scalar_fn, scalar_args, kwargs={})
  2210. is_single = isinstance(scalar_results, tensor)
  2211. if is_single:
  2212. scalar_results = scalar_results,
  2213. handles = [r.handle for r in scalar_results]
  2214. builder.set_loc(original_loc)
  2215. builder.create_map_elementwise_ret(handles)
  2216. fn_result_types = [x.type for x in scalar_results]
  2217. scalar_result_types = fn_result_types
  2218. if pack > 1:
  2219. scalar_result_types = fn_result_types[::pack]
  2220. for offset in builtins.range(1, pack):
  2221. assert scalar_result_types == fn_result_types[offset::pack], "type mismatch in unpacked results"
  2222. def make_elementwise_region(elementwise_op):
  2223. region = elementwise_op.get_region(0)
  2224. region.push_back(block)
  2225. builder.set_loc(original_loc)
  2226. result = _semantic.map_elementwise(args, scalar_result_types, pack, make_elementwise_region)
  2227. return result[0] if is_single else result
  2228. # -----------------------
  2229. # Compiler Hint Ops
  2230. # -----------------------
  2231. @builtin
  2232. def debug_barrier(_semantic=None):
  2233. '''
  2234. Insert a barrier to synchronize all threads in a block.
  2235. '''
  2236. return _semantic.debug_barrier()
  2237. @builtin
  2238. def multiple_of(input, values, _semantic=None):
  2239. """
  2240. Let the compiler know that the values in :code:`input` are all multiples of :code:`value`.
  2241. """
  2242. if isinstance(values, constexpr):
  2243. values = [values]
  2244. for i, d in enumerate(values):
  2245. if not isinstance(d, constexpr):
  2246. raise TypeError(f"values element {i} must have type `constexpr`")
  2247. if not isinstance(d.value, int):
  2248. raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
  2249. values = [x.value for x in values]
  2250. return _semantic.multiple_of(input, values)
  2251. @builtin
  2252. def max_contiguous(input, values, _semantic=None):
  2253. """
  2254. Let the compiler know that the `value` first values in :code:`input` are contiguous.
  2255. """
  2256. if isinstance(values, constexpr):
  2257. values = [values]
  2258. for i, d in enumerate(values):
  2259. if not isinstance(d, constexpr):
  2260. raise TypeError(f"values element {i} must have type `constexpr`")
  2261. if not isinstance(d.value, int):
  2262. raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
  2263. values = [x.value for x in values]
  2264. return _semantic.max_contiguous(input, values)
  2265. @builtin
  2266. def max_constancy(input, values, _semantic=None):
  2267. """
  2268. Let the compiler know that the `value` first values in :code:`input` are constant.
  2269. e.g. if :code:`values` is [4], then each group of 4 values in :code:`input` should all be equal,
  2270. for example [0, 0, 0, 0, 1, 1, 1, 1].
  2271. """
  2272. if isinstance(values, constexpr):
  2273. values = [values]
  2274. for i, d in enumerate(values):
  2275. if not isinstance(d, constexpr):
  2276. raise TypeError(f"values element {i} must have type `constexpr`")
  2277. if not isinstance(d.value, int):
  2278. raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
  2279. values = [x.value for x in values]
  2280. return _semantic.max_constancy(input, values)
  2281. @builtin
  2282. def assume(cond, _semantic=None):
  2283. '''
  2284. Allow compiler to assume the :code:`cond` is True.
  2285. '''
  2286. return _semantic.assume(_semantic.to_tensor(cond))
  2287. # -----------------------
  2288. # Debugging functions
  2289. # -----------------------
  2290. @builtin
  2291. def static_print(*values, sep: str = " ", end: str = "\n", file=None, flush=False, _semantic=None):
  2292. '''
  2293. Print the values at compile time. The parameters are the same as the builtin :code:`print`.
  2294. NOTE: Calling the Python builtin :code:`print` is not the same as calling this, it instead maps to :code:`device_print`,
  2295. which has special requirements for the arguments.
  2296. .. highlight:: python
  2297. .. code-block:: python
  2298. tl.static_print(f"BLOCK_SIZE={BLOCK_SIZE}")
  2299. '''
  2300. pass
  2301. @builtin
  2302. def static_assert(cond, msg="", _semantic=None):
  2303. '''
  2304. Assert the condition at compile time. Does not require that the :code:`TRITON_DEBUG` environment variable
  2305. is set.
  2306. .. highlight:: python
  2307. .. code-block:: python
  2308. tl.static_assert(BLOCK_SIZE == 1024)
  2309. '''
  2310. pass
  2311. @builtin
  2312. def device_print(prefix, *args, hex=False, _semantic=None):
  2313. '''
  2314. Print the values at runtime from the device. String formatting does not work for runtime values, so you should
  2315. provide the values you want to print as arguments. The first value must be a string, all following values must
  2316. be scalars or tensors.
  2317. Calling the Python builtin :code:`print` is the same as calling this function, and the requirements for the arguments will match
  2318. this function (not the normal requirements for :code:`print`).
  2319. .. highlight:: python
  2320. .. code-block:: python
  2321. tl.device_print("pid", pid)
  2322. print("pid", pid)
  2323. On CUDA, printfs are streamed through a buffer of limited size (on one host,
  2324. we measured the default as 6912 KiB, but this may not be consistent across
  2325. GPUs and CUDA versions). If you notice some printfs are being dropped, you
  2326. can increase the buffer size by calling
  2327. .. highlight:: python
  2328. .. code-block:: python
  2329. triton.runtime.driver.active.utils.set_printf_fifo_size(size_bytes)
  2330. CUDA may raise an error if you try to change this value after running a
  2331. kernel that uses printfs. The value set here may only affect the current
  2332. device (so if you have multiple GPUs, you'd need to call it multiple times).
  2333. :param prefix: a prefix to print before the values. This is required to be a string literal.
  2334. :param args: the values to print. They can be any tensor or scalar.
  2335. :param hex: print all values as hex instead of decimal
  2336. '''
  2337. import string
  2338. prefix = _unwrap_if_constexpr(prefix)
  2339. assert isinstance(prefix, str), f"{prefix} is not string"
  2340. b_ascii = True
  2341. for ch in prefix:
  2342. if ch not in string.printable:
  2343. b_ascii = False
  2344. break
  2345. assert b_ascii, f"{prefix} is not an ascii string"
  2346. new_args = []
  2347. for arg in args:
  2348. new_args.append(_semantic.to_tensor(arg))
  2349. return _semantic.device_print(prefix, new_args, hex)
  2350. @builtin
  2351. def device_assert(cond, msg="", mask=None, _semantic=None):
  2352. '''
  2353. Assert the condition at runtime from the device. Requires that the environment variable :code:`TRITON_DEBUG`
  2354. is set to a value besides :code:`0` in order for this to have any effect.
  2355. Using the Python :code:`assert` statement is the same as calling this function, except that the second argument
  2356. must be provided and must be a string, e.g. :code:`assert pid == 0, "pid != 0"`. The environment variable must
  2357. be set for this :code:`assert` statement to have any effect.
  2358. .. highlight:: python
  2359. .. code-block:: python
  2360. tl.device_assert(pid == 0)
  2361. assert pid == 0, f"pid != 0"
  2362. :param cond: the condition to assert. This is required to be a boolean tensor.
  2363. :param msg: the message to print if the assertion fails. This is required to be a string literal.
  2364. '''
  2365. msg = _unwrap_if_constexpr(msg)
  2366. mask = _unwrap_if_constexpr(mask)
  2367. if mask is not None:
  2368. mask = _semantic.to_tensor(mask)
  2369. return _semantic.device_assert(_semantic.to_tensor(cond), msg, mask)
  2370. @builtin
  2371. def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Union[dtype, Sequence[dtype]],
  2372. is_pure: bool, pack: int, _semantic=None):
  2373. '''
  2374. Execute inline assembly over a tensor. Essentially, this is :code:`map`
  2375. where the function is inline assembly.
  2376. The input tensors :code:`args` are implicitly broadcasted to the same shape.
  2377. :code:`dtype` can be a tuple of types, in which case the output is a
  2378. tuple of tensors.
  2379. Each invocation of the inline asm processes :code:`pack` elements at a
  2380. time. Exactly which set of inputs a block receives is unspecified.
  2381. Input elements of size less than 4 bytes are packed into 4-byte
  2382. registers.
  2383. This op does not support empty :code:`dtype` -- the inline asm must
  2384. return at least one tensor, even if you don't need it. You can work
  2385. around this by returning a dummy tensor of arbitrary type; it shouldn't
  2386. cost you anything if you don't use it.
  2387. Example using
  2388. `PTX <https://docs.nvidia.com/cuda/parallel-thread-execution/index.html>`_
  2389. assembly:
  2390. .. highlight:: python
  2391. .. code-block:: python
  2392. @triton.jit
  2393. def kernel(A, B, C, D, BLOCK: tl.constexpr):
  2394. a = tl.load(A + tl.arange(0, BLOCK)) # uint8 tensor
  2395. b = tl.load(B + tl.arange(0, BLOCK)) # float32 tensor
  2396. # For each (a,b) in zip(a,b), perform the following:
  2397. # - Let ai be `a` converted to int32.
  2398. # - Let af be `a` converted to float.
  2399. # - Let m be the max of ai and b.
  2400. # - Return ai and mi.
  2401. # Do the above 4 elements at a time.
  2402. (c, d) = tl.inline_asm_elementwise(
  2403. asm="""
  2404. {
  2405. // Unpack `a` into `ai`.
  2406. .reg .b8 tmp<4>;
  2407. mov.b32 {tmp0, tmp1, tmp2, tmp3}, $8;
  2408. cvt.u32.u8 $0, tmp0;
  2409. cvt.u32.u8 $1, tmp1;
  2410. cvt.u32.u8 $2, tmp2;
  2411. cvt.u32.u8 $3, tmp3;
  2412. }
  2413. // Convert `ai` to float.
  2414. cvt.rn.f32.s32 $4, $0;
  2415. cvt.rn.f32.s32 $5, $1;
  2416. cvt.rn.f32.s32 $6, $2;
  2417. cvt.rn.f32.s32 $7, $3;
  2418. // Take max of `ai` and `b`.
  2419. max.f32 $4, $4, $9;
  2420. max.f32 $5, $5, $10;
  2421. max.f32 $6, $6, $11;
  2422. max.f32 $7, $7, $12;
  2423. """,
  2424. constraints=(
  2425. # 8 output registers, namely
  2426. # $0=ai0, $1=ai1, $2=ai2, $3=ai3,
  2427. # $4=m0, $5=m1, $6=m2, $7=m3.
  2428. "=r,=r,=r,=r,=r,=r,=r,=r,"
  2429. # 5 input registers, namely
  2430. # $8=ai,
  2431. # $9=b0, $10=b1, $11=b2, $12=b3.
  2432. # The four elements from `a` are all packed into one register.
  2433. "r,r,r,r,r"),
  2434. args=[a, b],
  2435. dtype=(tl.int32, tl.float32),
  2436. is_pure=True,
  2437. pack=4,
  2438. )
  2439. tl.store(C + tl.arange(0, BLOCK), c)
  2440. tl.store(D + tl.arange(0, BLOCK), d)
  2441. :param asm: assembly to run. Must match target's assembly format.
  2442. :param constraints: asm constraints in
  2443. `LLVM format <https://llvm.org/docs/LangRef.html#inline-asm-constraint-string>`_
  2444. :param args: the input tensors, whose values are passed to the asm block
  2445. :param dtype: the element type(s) of the returned tensor(s)
  2446. :param is_pure: if true, the compiler assumes the asm block has no side-effects
  2447. :param pack: the number of elements to be processed by one instance of inline assembly
  2448. :return: one tensor or a tuple of tensors of the given dtypes
  2449. '''
  2450. asm = _unwrap_if_constexpr(asm)
  2451. constraints = _unwrap_if_constexpr(constraints)
  2452. pack = _unwrap_if_constexpr(pack)
  2453. is_pure = _unwrap_if_constexpr(is_pure)
  2454. # Wrap `dtype` in a tuple if it's not already.
  2455. try:
  2456. iter(dtype) # type: ignore
  2457. has_multiple_outputs = True
  2458. except TypeError:
  2459. has_multiple_outputs = False
  2460. dtype = (dtype, ) # type: ignore
  2461. dtype = typing.cast(Sequence[_DtypeClass], dtype)
  2462. res_tys = dtype
  2463. if dispatch_args := [_semantic.to_tensor(arg) for arg in args]:
  2464. bin_op_type_checking = partial(
  2465. _semantic.binary_op_type_checking_impl,
  2466. arithmetic_check=False,
  2467. allow_lhs_ptr=True,
  2468. allow_rhs_ptr=True,
  2469. )
  2470. broadcast_arg = dispatch_args[0]
  2471. # Get the broadcast shape over all the arguments
  2472. for item in dispatch_args:
  2473. _, broadcast_arg = bin_op_type_checking(item, broadcast_arg)
  2474. if broadcast_arg.shape:
  2475. # Change the shape of each argument based on the broadcast shape
  2476. for i, item in enumerate(dispatch_args):
  2477. dispatch_args[i], _ = bin_op_type_checking(item, broadcast_arg)
  2478. res_tys = [broadcast_arg.type.with_element_ty(dt) for dt in dtype]
  2479. handles = [t.handle for t in dispatch_args]
  2480. builder = _semantic.builder
  2481. call = builder.create_inline_asm(asm, constraints, handles, [ty.to_ir(builder) for ty in res_tys], is_pure, pack)
  2482. if not has_multiple_outputs:
  2483. return tensor(call.get_result(0), res_tys[0])
  2484. return tuple(tensor(call.get_result(i), ty) for i, ty in enumerate(res_tys))
  2485. # -----------------------
  2486. # Iterators
  2487. # -----------------------
  2488. class static_range(base_value):
  2489. """
  2490. Iterator that counts upward forever.
  2491. .. highlight:: python
  2492. .. code-block:: python
  2493. @triton.jit
  2494. def kernel(...):
  2495. for i in tl.static_range(10):
  2496. ...
  2497. :note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of
  2498. :code:`triton.jit` functions. In addition, it also guides the compiler to unroll the loop aggressively.
  2499. :param arg1: the start value.
  2500. :param arg2: the end value.
  2501. :param step: the step value.
  2502. """
  2503. def __init__(self, arg1, arg2=None, step=None):
  2504. assert isinstance(arg1, constexpr), f"{arg1} used as tl.static_range start value is not a constexpr"
  2505. if step is None:
  2506. self.step = constexpr(1)
  2507. else:
  2508. assert isinstance(step, constexpr), f"{step} used as tl.static_range step value is not a constexpr"
  2509. self.step = step
  2510. if arg2 is None:
  2511. self.start = constexpr(0)
  2512. self.end = arg1
  2513. else:
  2514. assert isinstance(arg2, constexpr), f"{arg2} used as tl.static_range end value is not a constexpr"
  2515. self.start = arg1
  2516. self.end = arg2
  2517. def __iter__(self):
  2518. raise RuntimeError("static_range can only be used in @triton.jit'd functions")
  2519. def __next__(self):
  2520. raise RuntimeError("static_range can only be used in @triton.jit'd functions")
  2521. class range(base_value):
  2522. """
  2523. Iterator that counts upward forever.
  2524. .. highlight:: python
  2525. .. code-block:: python
  2526. @triton.jit
  2527. def kernel(...):
  2528. for i in tl.range(10, num_stages=3):
  2529. ...
  2530. :note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of
  2531. :code:`triton.jit` functions. In addition, it allows user to pass extra attributes to the compiler.
  2532. :param arg1: the start value.
  2533. :param arg2: the end value.
  2534. :param step: the step value.
  2535. :param num_stages: pipeline the loop into this many stages (so there are
  2536. :code:`num_stages` iterations of the loop in flight at once).
  2537. Note this is subtly different than passing :code:`num_stages` as a
  2538. kernel argument. The kernel argument only pipelines loads that feed
  2539. into :code:`dot` operations, while this attribute tries to pipeline most
  2540. (though not all) loads in this loop.
  2541. :param loop_unroll_factor: Tells the Triton IR level loop unroller how many
  2542. times to unroll a for loop that this range is used with. Less than 2 for
  2543. this value implies no unrolling.
  2544. :param disallow_acc_multi_buffer: If true, prevent the accumulator of the dot
  2545. operation in the loop to be multi-buffered, if applicable.
  2546. :param flatten: automatically flatten the loop nest starting at this loop to
  2547. create a single flattened loop. The compiler will try to pipeline the
  2548. flattened loop which can avoid stage stalling.
  2549. :param warp_specialize: Enable automatic warp specialization on the loop.
  2550. The compiler will attempt to partition memory, MMA, and vector
  2551. operations in the loop into separate async partitions. This will
  2552. increase the total number of warps required by the kernel.
  2553. :param disable_licm: Tells the compiler it shouldn't hoist loop invariant
  2554. code outside the loop. This is often useful to avoid creating long liveranges
  2555. within a loop.
  2556. Note that warp specialization is only supported on Blackwell GPUs and
  2557. only works on simple matmul loops. Support for arbitrary loops will be
  2558. expanded over time.
  2559. """
  2560. def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None,
  2561. disallow_acc_multi_buffer=False, flatten=False, warp_specialize=False, disable_licm=False):
  2562. if step is None:
  2563. self.step = constexpr(1)
  2564. else:
  2565. self.step = step
  2566. if arg2 is None:
  2567. self.start = constexpr(0)
  2568. self.end = arg1
  2569. else:
  2570. self.start = arg1
  2571. self.end = arg2
  2572. self.num_stages = num_stages
  2573. self.loop_unroll_factor = loop_unroll_factor
  2574. self.disallow_acc_multi_buffer = disallow_acc_multi_buffer
  2575. self.flatten = flatten
  2576. self.warp_specialize = warp_specialize
  2577. self.disable_licm = disable_licm
  2578. def __iter__(self):
  2579. raise RuntimeError("tl.range can only be used in @triton.jit'd functions")
  2580. def __next__(self):
  2581. raise RuntimeError("tl.range can only be used in @triton.jit'd functions")
  2582. class condition(base_value):
  2583. """
  2584. While loop condition wrapper.
  2585. .. highlight:: python
  2586. .. code-block:: python
  2587. @triton.jit
  2588. def kernel(...):
  2589. while tl.condition(c, disable_licm)
  2590. ...
  2591. :note: This is a special wrapper used to annotate while loops in the context of
  2592. :code:`triton.jit` functions. It allows user to pass extra attributes to the compiler.
  2593. :param disable_licm: Tells the compiler it shouldn't hoist loop invariant
  2594. code outside the loop. This is often useful to avoid creating long liveranges
  2595. within a loop.
  2596. """
  2597. def __init__(self, arg1, disable_licm=False):
  2598. self.condition = arg1
  2599. self.disable_licm = disable_licm
  2600. # -----------------------
  2601. # Extern functions
  2602. # -----------------------
  2603. def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_type: dtype, is_pure: bool,
  2604. _semantic):
  2605. '''
  2606. Dispatch a function to a library
  2607. :param func: the function to dispatch
  2608. :param lib_name: the name of the library
  2609. :param lib_path: the path of the library
  2610. :param args: the arguments of the function
  2611. :param arg_type_symbol_dict: the type of the arguments
  2612. :param ret_type: the type of the return value
  2613. :return: the return value of the function
  2614. '''
  2615. if len(arg_type_symbol_dict) == 0:
  2616. raise ValueError("arg_type_symbol_dict is empty")
  2617. num_args = len(list(arg_type_symbol_dict.keys())[0])
  2618. if len(args) != num_args:
  2619. raise ValueError(f"length of input args does not match."
  2620. f"Expect {len(args)}, got {num_args}")
  2621. arg_types = []
  2622. arg_list = []
  2623. for arg in args:
  2624. if isinstance(arg, tensor):
  2625. arg_types.append(arg.dtype)
  2626. arg_list.append(arg.handle)
  2627. else:
  2628. arg_types.append(type(arg))
  2629. arg_list.append(arg)
  2630. arg_types = tuple(arg_types)
  2631. if arg_types not in arg_type_symbol_dict:
  2632. raise ValueError(f"input arg type does not match."
  2633. f"Expect one of {arg_type_symbol_dict.keys()}, got {arg_types}")
  2634. else:
  2635. symbol = arg_type_symbol_dict[arg_types][0]
  2636. builder = _semantic.builder
  2637. return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(builder), is_pure), ret_type)
  2638. @builtin
  2639. def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool,
  2640. _semantic=None):
  2641. '''
  2642. Dispatch an elementwise function to a library
  2643. :param lib_name: the name of the library
  2644. :param lib_path: the path of the library
  2645. :param args: the arguments of the function
  2646. :param arg_type_symbol_dict: the type of the arguments
  2647. :param is_pure: whether the function is pure
  2648. :return: the return value of the function
  2649. '''
  2650. dispatch_args = args.copy()
  2651. all_scalar = True
  2652. arg_types = []
  2653. for i in builtins.range(len(dispatch_args)):
  2654. dispatch_args[i] = _semantic.to_tensor(dispatch_args[i])
  2655. arg_types.append(dispatch_args[i].dtype)
  2656. if dispatch_args[i].type.is_block():
  2657. all_scalar = False
  2658. arg_types = tuple(arg_types)
  2659. ret_type = arg_type_symbol_dict[arg_types][1]
  2660. if len(arg_types) > 0:
  2661. arithmetic_check = True
  2662. # If there's a type tuple that is not supported by the library, we will do arithmetic check
  2663. if arg_types in arg_type_symbol_dict:
  2664. arithmetic_check = False
  2665. broadcast_arg = dispatch_args[0]
  2666. # Get the broadcast shape over all the arguments
  2667. for item in dispatch_args:
  2668. _, broadcast_arg = _semantic.binary_op_type_checking_impl(item, broadcast_arg,
  2669. arithmetic_check=arithmetic_check)
  2670. # Change the shape of each argument based on the broadcast shape
  2671. for i in builtins.range(len(dispatch_args)):
  2672. dispatch_args[i], _ = _semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg,
  2673. arithmetic_check=arithmetic_check)
  2674. if not all_scalar:
  2675. ret_type = broadcast_arg.type.with_element_ty(ret_type)
  2676. func = _semantic.builder.create_extern_elementwise
  2677. return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_type, is_pure, _semantic)
  2678. def binary_op_type_legalization(lhs, rhs, semantic):
  2679. '''
  2680. Convert both operands to a single common type
  2681. :param lhs: the left operand
  2682. :param rhs: the right operand
  2683. :param builder: the builder
  2684. '''
  2685. return semantic.binary_op_type_checking_impl(lhs, rhs)
  2686. def extern(fn):
  2687. """A decorator for external functions."""
  2688. return builtin(fn)
# Sentinel used as the default for `propagate_nan` in builtin_max/builtin_min, so
# they can tell "argument not passed" apart from any value a caller could supply.
_NOTHING = object()
  2690. def is_negative_zero(x):
  2691. return x == 0.0 and math.copysign(1.0, x) < 0
  2692. @builtin
  2693. def builtin_max(*args, propagate_nan=_NOTHING, _semantic=None):
  2694. args = _unwrap_if_constexpr(args)
  2695. is_constexpr = all(not isinstance(x, base_value) for x in args)
  2696. if is_constexpr:
  2697. assert propagate_nan is _NOTHING, "propagate_nan is not supported on builtin max"
  2698. assert not any(math.isnan(x) for x in args)
  2699. assert not any(is_negative_zero(x) for x in args)
  2700. return constexpr(builtins.max(_unwrap_if_constexpr(args)))
  2701. if propagate_nan is _NOTHING:
  2702. propagate_nan = PropagateNan.NONE
  2703. else:
  2704. warn("passing propagate_nan to builtin max is deprecated, use tl.minimum instead", DeprecationWarning)
  2705. assert len(args) >= 2, "min requires at least 2 values"
  2706. max_val = args[0]
  2707. for arg in args[1:]:
  2708. max_val = maximum(max_val, arg, propagate_nan=propagate_nan, _semantic=_semantic)
  2709. if max_val.type.is_block():
  2710. warn("builtin max on non-scalar tensor values is deprecated, use tl.maximum instead", DeprecationWarning)
  2711. return max_val
  2712. @builtin
  2713. def builtin_min(*args, propagate_nan=_NOTHING, _semantic=None):
  2714. args = _unwrap_if_constexpr(args)
  2715. is_constexpr = all(not isinstance(x, base_value) for x in args)
  2716. if is_constexpr:
  2717. assert propagate_nan is _NOTHING, "propagate_nan is not supported on builtin min"
  2718. assert not any(math.isnan(x) for x in args)
  2719. assert not any(is_negative_zero(x) for x in args)
  2720. return constexpr(builtins.min(_unwrap_if_constexpr(args)))
  2721. if propagate_nan is _NOTHING:
  2722. propagate_nan = PropagateNan.NONE
  2723. else:
  2724. warn("passing propagate_nan to builtin min is deprecated, use tl.minimum instead", DeprecationWarning)
  2725. assert len(args) >= 2, "min requires at least 2 values"
  2726. min_val = args[0]
  2727. for arg in args[1:]:
  2728. min_val = minimum(min_val, arg, propagate_nan=propagate_nan, _semantic=_semantic)
  2729. if min_val.type.is_block():
  2730. warn("builtin min on non-scalar tensor values is deprecated, use tl.minimum instead", DeprecationWarning)
  2731. return min_val