__init__.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. """isort:skip_file"""
  2. # Import order is significant here.
  3. from . import math
  4. from . import extra
  5. from .standard import (
  6. argmax,
  7. argmin,
  8. bitonic_merge,
  9. cdiv,
  10. cumprod,
  11. cumsum,
  12. flip,
  13. interleave,
  14. max,
  15. min,
  16. ravel,
  17. reduce_or,
  18. sigmoid,
  19. softmax,
  20. sort,
  21. sum,
  22. swizzle2d,
  23. topk,
  24. xor_sum,
  25. zeros,
  26. zeros_like,
  27. )
  28. from .core import (
  29. PropagateNan,
  30. TRITON_MAX_TENSOR_NUMEL,
  31. load_tensor_descriptor,
  32. store_tensor_descriptor,
  33. make_tensor_descriptor,
  34. tensor_descriptor,
  35. tensor_descriptor_type,
  36. add,
  37. advance,
  38. arange,
  39. associative_scan,
  40. assume,
  41. atomic_add,
  42. atomic_and,
  43. atomic_cas,
  44. atomic_max,
  45. atomic_min,
  46. atomic_or,
  47. atomic_xchg,
  48. atomic_xor,
  49. bfloat16,
  50. block_type,
  51. broadcast,
  52. broadcast_to,
  53. cat,
  54. cast,
  55. clamp,
  56. condition,
  57. const,
  58. constexpr,
  59. constexpr_type,
  60. debug_barrier,
  61. device_assert,
  62. device_print,
  63. dot,
  64. dot_scaled,
  65. dtype,
  66. expand_dims,
  67. float16,
  68. float32,
  69. float64,
  70. float8e4b15,
  71. float8e4nv,
  72. float8e4b8,
  73. float8e5,
  74. float8e5b16,
  75. full,
  76. gather,
  77. histogram,
  78. inline_asm_elementwise,
  79. int1,
  80. int16,
  81. int32,
  82. int64,
  83. int8,
  84. join,
  85. load,
  86. make_block_ptr,
  87. map_elementwise,
  88. max_constancy,
  89. max_contiguous,
  90. maximum,
  91. minimum,
  92. mul,
  93. multiple_of,
  94. num_programs,
  95. permute,
  96. pi32_t,
  97. pointer_type,
  98. program_id,
  99. range,
  100. reduce,
  101. reshape,
  102. slice,
  103. split,
  104. static_assert,
  105. static_print,
  106. static_range,
  107. store,
  108. sub,
  109. tensor,
  110. trans,
  111. tuple,
  112. tuple_type,
  113. uint16,
  114. uint32,
  115. uint64,
  116. uint8,
  117. view,
  118. void,
  119. where,
  120. )
  121. from .math import (umulhi, exp, exp2, fma, log, log2, cos, rsqrt, sin, sqrt, sqrt_rn, abs, fdiv, div_rn, erf, floor,
  122. ceil)
  123. from .random import (
  124. pair_uniform_to_normal,
  125. philox,
  126. philox_impl,
  127. rand,
  128. rand4x,
  129. randint,
  130. randint4x,
  131. randn,
  132. randn4x,
  133. uint_to_uniform_float,
  134. )
  135. from . import target_info
  136. __all__ = [
  137. "PropagateNan",
  138. "TRITON_MAX_TENSOR_NUMEL",
  139. "load_tensor_descriptor",
  140. "store_tensor_descriptor",
  141. "make_tensor_descriptor",
  142. "tensor_descriptor",
  143. "abs",
  144. "add",
  145. "advance",
  146. "arange",
  147. "argmax",
  148. "argmin",
  149. "associative_scan",
  150. "assume",
  151. "atomic_add",
  152. "atomic_and",
  153. "atomic_cas",
  154. "atomic_max",
  155. "atomic_min",
  156. "atomic_or",
  157. "atomic_xchg",
  158. "atomic_xor",
  159. "bfloat16",
  160. "bitonic_merge",
  161. "block_type",
  162. "broadcast",
  163. "broadcast_to",
  164. "cat",
  165. "cast",
  166. "cdiv",
  167. "ceil",
  168. "clamp",
  169. "condition",
  170. "const",
  171. "constexpr",
  172. "constexpr_type",
  173. "cos",
  174. "cumprod",
  175. "cumsum",
  176. "debug_barrier",
  177. "device_assert",
  178. "device_print",
  179. "div_rn",
  180. "dot",
  181. "dot_scaled",
  182. "dtype",
  183. "erf",
  184. "exp",
  185. "exp2",
  186. "expand_dims",
  187. "extra",
  188. "fdiv",
  189. "flip",
  190. "float16",
  191. "float32",
  192. "float64",
  193. "float8e4b15",
  194. "float8e4nv",
  195. "float8e4b8",
  196. "float8e5",
  197. "float8e5b16",
  198. "floor",
  199. "fma",
  200. "full",
  201. "gather",
  202. "histogram",
  203. "inline_asm_elementwise",
  204. "interleave",
  205. "int1",
  206. "int16",
  207. "int32",
  208. "int64",
  209. "int8",
  210. "join",
  211. "load",
  212. "log",
  213. "log2",
  214. "make_block_ptr",
  215. "map_elementwise",
  216. "math",
  217. "max",
  218. "max_constancy",
  219. "max_contiguous",
  220. "maximum",
  221. "min",
  222. "minimum",
  223. "mul",
  224. "multiple_of",
  225. "num_programs",
  226. "pair_uniform_to_normal",
  227. "permute",
  228. "philox",
  229. "philox_impl",
  230. "pi32_t",
  231. "pointer_type",
  232. "program_id",
  233. "rand",
  234. "rand4x",
  235. "randint",
  236. "randint4x",
  237. "randn",
  238. "randn4x",
  239. "range",
  240. "ravel",
  241. "reduce",
  242. "reduce_or",
  243. "reshape",
  244. "rsqrt",
  245. "slice",
  246. "sigmoid",
  247. "sin",
  248. "softmax",
  249. "sort",
  250. "split",
  251. "sqrt",
  252. "sqrt_rn",
  253. "static_assert",
  254. "static_print",
  255. "static_range",
  256. "store",
  257. "sub",
  258. "sum",
  259. "swizzle2d",
  260. "target_info",
  261. "tensor",
  262. "topk",
  263. "trans",
  264. "tuple",
  265. "uint16",
  266. "uint32",
  267. "uint64",
  268. "uint8",
  269. "uint_to_uniform_float",
  270. "umulhi",
  271. "view",
  272. "void",
  273. "where",
  274. "xor_sum",
  275. "zeros",
  276. "zeros_like",
  277. ]
  278. def str_to_ty(name, c):
  279. from builtins import tuple
  280. if isinstance(name, tuple):
  281. fields = type(name).__dict__.get("_fields", None)
  282. return tuple_type([str_to_ty(x, c) for x in name], fields)
  283. if name[0] == "*":
  284. name = name[1:]
  285. const = False
  286. if name[0] == "k":
  287. name = name[1:]
  288. const = True
  289. ty = str_to_ty(name, c)
  290. return pointer_type(element_ty=ty, const=const)
  291. if name.startswith("tensordesc"):
  292. inner = name.split("<")[1].rstrip(">")
  293. dtype, rest = inner.split("[", maxsplit=1)
  294. block_shape, rest = rest.split("]", maxsplit=1)
  295. block_shape = [int(s.strip()) for s in block_shape.rstrip("]").split(",")]
  296. layout = rest.lstrip(",")
  297. is_gluon = len(layout)
  298. dtype = str_to_ty(dtype, None)
  299. ndim = len(block_shape)
  300. shape_type = tuple_type([int32] * ndim)
  301. # FIXME: Last dim stride should be constexpr(1)
  302. stride_type = tuple_type(([int64] * ndim))
  303. block = block_type(dtype, block_shape)
  304. if is_gluon:
  305. from triton.experimental.gluon.language._layouts import NVMMASharedLayout, PaddedSharedLayout, SwizzledSharedLayout
  306. from triton.experimental.gluon.language.nvidia.hopper.tma import tensor_descriptor_type as nvidia_tensor_descriptor_type
  307. from triton.experimental.gluon.language.amd.gfx1250.tdm import tensor_descriptor_type as amd_tensor_descriptor_type
  308. layout = eval(
  309. layout,
  310. dict(NVMMASharedLayout=NVMMASharedLayout, PaddedSharedLayout=PaddedSharedLayout,
  311. SwizzledSharedLayout=SwizzledSharedLayout))
  312. if isinstance(layout, NVMMASharedLayout):
  313. return nvidia_tensor_descriptor_type(block, shape_type, stride_type, layout)
  314. else:
  315. return amd_tensor_descriptor_type(block, shape_type, stride_type, layout)
  316. return tensor_descriptor_type(block, shape_type, stride_type)
  317. if name.startswith("constexpr"):
  318. return constexpr_type(c)
  319. tys = {
  320. "fp8e4nv": float8e4nv,
  321. "fp8e4b8": float8e4b8,
  322. "fp8e5": float8e5,
  323. "fp8e5b16": float8e5b16,
  324. "fp8e4b15": float8e4b15,
  325. "fp16": float16,
  326. "bf16": bfloat16,
  327. "fp32": float32,
  328. "fp64": float64,
  329. "i1": int1,
  330. "i8": int8,
  331. "i16": int16,
  332. "i32": int32,
  333. "i64": int64,
  334. "u1": int1,
  335. "u8": uint8,
  336. "u16": uint16,
  337. "u32": uint32,
  338. "u64": uint64,
  339. "B": int1,
  340. }
  341. return tys[name]