standard.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536
  1. from __future__ import annotations
  2. from ..runtime.jit import jit, constexpr_function
  3. from . import core
  4. from . import math
  5. # constexpr utilities
  6. @constexpr_function
  7. def _log2(i):
  8. log2 = 0
  9. n = i
  10. while n > 1:
  11. n >>= 1
  12. log2 += 1
  13. return log2
  14. @constexpr_function
  15. def _is_power_of_two(i):
  16. return (i & (i - 1)) == 0 and i != 0
# core.get_int_dtype wrapped with constexpr_function; used further down
# (with `bitwidth=` / `signed=` keywords) to pick the integer dtype for bitcasts.
_get_int_dtype = constexpr_function(core.get_int_dtype)
  18. # -----------------------
  19. # Standard library
  20. # -----------------------
@core._tensor_member_fn
@jit
def cdiv(x, div):
    """
    Computes the ceiling division of :code:`x` by :code:`div`

    Implemented as :code:`(x + (div - 1)) // div`.

    :param x: the input number
    :type x: Block
    :param div: the divisor
    :type div: Block
    """
    # Adding `div - 1` before the floor division rounds the quotient up.
    return (x + (div - 1)) // div
# Element-wise logistic function 1 / (1 + exp(-x)).
@core._tensor_member_fn
@jit
@math._add_math_1arg_docstr("sigmoid")
def sigmoid(x):
    return 1 / (1 + math.exp(-x))
  37. @core._tensor_member_fn
  38. @jit
  39. @math._add_math_1arg_docstr("softmax")
  40. def softmax(x, dim=None, keep_dims=False, ieee_rounding=False):
  41. if dim is None:
  42. _dim: core.constexpr = 0
  43. else:
  44. _dim: core.constexpr = dim
  45. z = x - max(x, _dim, keep_dims=keep_dims)
  46. num = math.exp(z)
  47. den = sum(num, _dim, keep_dims=keep_dims)
  48. return math.fdiv(num, den, ieee_rounding)
@core._tensor_member_fn
@jit
def ravel(x, can_reorder=False):
    """
    Returns a contiguous flattened view of :code:`x`.

    The result is :code:`x` reshaped to the 1-D shape :code:`[x.numel]`.

    :param x: the input tensor
    :type x: Block
    :param can_reorder: forwarded unchanged to :code:`core.reshape`
    """
    return core.reshape(x, [x.numel], can_reorder=can_reorder)
  58. @jit
  59. def swizzle2d(i, j, size_i, size_j, size_g):
  60. """
  61. Transforms the indices of a row-major `size_i * size_j` matrix into
  62. the indices of a column-major matrix for each group of `size_g` rows.
  63. For example, for :code:`size_i = size_j = 4` and :code:`size_g = 2`, it will
  64. transform ::
  65. [[0 , 1 , 2 , 3 ],
  66. [4 , 5 , 6 , 7 ],
  67. [8 , 9 , 10, 11],
  68. [12, 13, 14, 15]]
  69. into ::
  70. [[0, 2, 4 , 6 ],
  71. [1, 3, 5 , 7 ],
  72. [8, 10, 12, 14],
  73. [9, 11, 13, 15]]
  74. """
  75. # "unrolled index in array"
  76. ij = i * size_j + j
  77. # number of elements in `size_g` groups
  78. # of `size_j` columns
  79. size_gj = size_g * size_j
  80. # index of the group in which (i,j) is
  81. group_id = ij // size_gj
  82. # row-index of the first element of this group
  83. off_i = group_id * size_g
  84. # last group may have fewer rows
  85. size_g = core.minimum(size_i - off_i, size_g)
  86. # linear index with respect to the first element in this group
  87. ij = ij % size_gj
  88. # new row and column indices
  89. new_i = off_i + ij % size_g
  90. new_j = ij // size_g
  91. return new_i, new_j
@jit
def zeros(shape, dtype):
    """
    Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`.

    Thin wrapper over :code:`core.full(shape, 0, dtype)`.

    :param shape: Shape of the new array, e.g., (8, 16) or (8, )
    :type shape: tuple of ints
    :param dtype: Data-type of the new array, e.g., :code:`tl.float16`
    :type dtype: DType
    """
    return core.full(shape, 0, dtype)
@jit
def zeros_like(input):
    """
    Returns a tensor of zeros with the same shape and type as a given tensor.

    Delegates to :code:`zeros(input.shape, input.dtype)`.

    :param input: input tensor
    :type input: Tensor
    """
    return zeros(input.shape, input.dtype)
  110. # max and argmax
# Combine step for the index-returning max reduction: given two (value, index)
# candidates, returns the pair with the larger value. When `tie_break_left` is
# set, equal values resolve to the candidate with the smaller index; otherwise
# ties fall through to the second candidate.
@jit
def _argmax_combine(value1, index1, value2, index2, tie_break_left):
    if tie_break_left:
        tie = value1 == value2 and index1 < index2
    else:
        tie = False
    # `gt` selects candidate 1; the same mask drives both the value and the index.
    gt = value1 > value2 or tie
    v_ret = core.where(gt, value1, value2)
    i_ret = core.where(gt, index1, index2)
    return v_ret, i_ret
# Four-argument form of _argmax_combine with tie_break_left fixed to True
# (passed to core._reduce_with_indices by `max`).
@jit
def _argmax_combine_tie_break_left(value1, index1, value2, index2):
    return _argmax_combine(value1, index1, value2, index2, True)
# Four-argument form of _argmax_combine with tie_break_left fixed to False
# (no index comparison on ties).
@jit
def _argmax_combine_tie_break_fast(value1, index1, value2, index2):
    return _argmax_combine(value1, index1, value2, index2, False)
# Combine function used by `max` when indices are not requested.
@jit
def _elementwise_max(a, b):
    return core.maximum(a, b)
# Maximum reduction along `axis`. With `return_indices` the reduction goes
# through core._reduce_with_indices using one of the argmax combine functions;
# otherwise it is a plain core.reduce over _elementwise_max.
@core._tensor_member_fn
@jit
@core._add_reduction_docstr("maximum", return_indices_arg="return_indices",
                            tie_break_arg="return_indices_tie_break_left")
def max(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False):
    input = core._promote_bfloat16_to_float32(input)
    if return_indices:
        # Pick the combine function according to the requested tie-breaking rule.
        if return_indices_tie_break_left:
            return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_left, keep_dims=keep_dims)
        else:
            return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_fast, keep_dims=keep_dims)
    else:
        # Inputs narrower than 32 bits are widened to float32 / int32 before reducing.
        if core.constexpr(input.dtype.primitive_bitwidth) < core.constexpr(32):
            if core.constexpr(input.dtype.is_floating()):
                input = input.to(core.float32)
            else:
                assert input.dtype.is_int(), "Expecting input to be integer type"
                input = input.to(core.int32)
        return core.reduce(input, axis, _elementwise_max, keep_dims=keep_dims)
  149. @core._tensor_member_fn
  150. @jit
  151. @core._add_reduction_docstr("maximum index", tie_break_arg="tie_break_left")
  152. def argmax(input, axis, tie_break_left=True, keep_dims=False):
  153. (_, ret) = max(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left, keep_dims=keep_dims)
  154. return ret
  155. # min and argmin
# Combine step for the index-returning min reduction: given two (value, index)
# candidates, returns the pair with the smaller value. When `tie_break_left` is
# set, equal values resolve to the candidate with the smaller index; otherwise
# ties fall through to the second candidate.
@jit
def _argmin_combine(value1, index1, value2, index2, tie_break_left):
    if tie_break_left:
        tie = value1 == value2 and index1 < index2
    else:
        tie = False
    # `lt` selects candidate 1; the same mask drives both the value and the index.
    lt = value1 < value2 or tie
    value_ret = core.where(lt, value1, value2)
    index_ret = core.where(lt, index1, index2)
    return value_ret, index_ret
# Four-argument form of _argmin_combine with tie_break_left fixed to True
# (passed to core._reduce_with_indices by `min`).
@jit
def _argmin_combine_tie_break_left(value1, index1, value2, index2):
    return _argmin_combine(value1, index1, value2, index2, True)
# Four-argument form of _argmin_combine with tie_break_left fixed to False
# (no index comparison on ties).
@jit
def _argmin_combine_tie_break_fast(value1, index1, value2, index2):
    return _argmin_combine(value1, index1, value2, index2, False)
# Combine function used by `min` when indices are not requested.
@jit
def _elementwise_min(a, b):
    return core.minimum(a, b)
# Minimum reduction along `axis`. With `return_indices` the reduction goes
# through core._reduce_with_indices using one of the argmin combine functions;
# otherwise it is a plain core.reduce over _elementwise_min.
@core._tensor_member_fn
@jit
@core._add_reduction_docstr("minimum", return_indices_arg="return_indices",
                            tie_break_arg="return_indices_tie_break_left")
def min(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False):
    input = core._promote_bfloat16_to_float32(input)
    if return_indices:
        # Pick the combine function according to the requested tie-breaking rule.
        if return_indices_tie_break_left:
            return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_left, keep_dims=keep_dims)
        else:
            return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_fast, keep_dims=keep_dims)
    else:
        # Inputs narrower than 32 bits are widened to float32 / int32 before reducing.
        if core.constexpr(input.dtype.primitive_bitwidth) < 32:
            if core.constexpr(input.dtype.is_floating()):
                input = input.to(core.float32)
            else:
                assert input.dtype.is_int(), "Expecting input to be integer type"
                input = input.to(core.int32)
        return core.reduce(input, axis, _elementwise_min, keep_dims=keep_dims)
  194. @core._tensor_member_fn
  195. @jit
  196. @core._add_reduction_docstr("minimum index", tie_break_arg="tie_break_left")
  197. def argmin(input, axis, tie_break_left=True, keep_dims=False):
  198. _, ret = min(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left, keep_dims=keep_dims)
  199. return ret
# Addition combine function shared by `sum` (core.reduce) and `cumsum`
# (core.associative_scan).
@jit
def _sum_combine(a, b):
    return a + b
  203. # sum
  204. @constexpr_function
  205. def _pick_sum_dtype(in_dtype, dtype):
  206. if dtype is not None:
  207. return dtype
  208. # For integer bitwidths less than 32, pick int32 with the same sign to
  209. # avoid overflow.
  210. out_dtype = None
  211. if in_dtype.is_int_signed():
  212. out_dtype = core.int32 if in_dtype.int_bitwidth < 32 else None
  213. elif in_dtype.is_int_unsigned():
  214. out_dtype = core.uint32 if in_dtype.int_bitwidth < 32 else None
  215. return out_dtype
  216. @core._tensor_member_fn
  217. @jit
  218. @core._add_reduction_docstr("sum", dtype_arg="dtype")
  219. def sum(input, axis=None, keep_dims=False, dtype: core.constexpr = None):
  220. # Pick a default dtype for the reduction if one was not specified.
  221. out_dtype: core.constexpr = _pick_sum_dtype(input.dtype, dtype)
  222. if out_dtype is not None:
  223. input = input.to(out_dtype)
  224. return core.reduce(input, axis, _sum_combine, keep_dims=keep_dims)
# Bitwise-xor combine function used by `xor_sum`.
@jit
def _xor_combine(a, b):
    return a ^ b
  228. # xor sum
# Xor reduction along `axis`; rejected at compile time for non-integer inputs.
@core._tensor_member_fn
@jit
@core._add_reduction_docstr("xor sum")
def xor_sum(input, axis=None, keep_dims=False):
    core.static_assert(input.type.scalar.is_int(), "xor_sum only supported for integers")
    return core.reduce(input, axis, _xor_combine, keep_dims=keep_dims)
  235. # or reduction
# Bitwise-or combine function used by `reduce_or`.
@jit
def _or_combine(x, y):
    return x | y
# Bitwise-or reduction along `axis`; rejected at compile time for non-integer
# inputs. Unlike the other reductions in this file, `axis` has no default.
@core._tensor_member_fn
@jit
@core._add_reduction_docstr("reduce_or")
def reduce_or(input, axis, keep_dims=False):
    core.static_assert(input.type.scalar.is_int(), "reduce_or only supported for integers")
    return core.reduce(input, axis, _or_combine, keep_dims=keep_dims)
  245. # cumsum
  246. @core._tensor_member_fn
  247. @jit
  248. @core._add_scan_docstr("cumsum", dtype_arg="dtype")
  249. def cumsum(input, axis=0, reverse=False, dtype: core.constexpr = None):
  250. # todo rename this to a generic function name
  251. input = core._promote_bfloat16_to_float32(input)
  252. out_dtype: core.constexpr = _pick_sum_dtype(input.dtype, dtype)
  253. if out_dtype is not None:
  254. input = input.to(out_dtype)
  255. return core.associative_scan(input, axis, _sum_combine, reverse)
  256. # cumprod
# Multiplication combine function used by `cumprod`.
@jit
def _prod_combine(a, b):
    return a * b
# Cumulative product along `axis`, implemented as an associative scan over
# _prod_combine; `reverse` is forwarded to core.associative_scan.
@core._tensor_member_fn
@jit
@core._add_scan_docstr("cumprod")
def cumprod(input, axis=0, reverse=False):
    # todo rename this to a generic function name
    input = core._promote_bfloat16_to_float32(input)
    return core.associative_scan(input, axis, _prod_combine, reverse)
  267. # sort
  268. @jit
  269. def _indicator(n_dims: core.constexpr, j: core.constexpr):
  270. ar = core.arange(0, 2)
  271. ar = core.reshape(ar, [1] * (n_dims - j - 1) + [2] + [1] * j)
  272. return ar
  273. @jit
  274. def _compare_and_swap(x, flip, i: core.constexpr):
  275. # compare-and-swap on the ith *innermost* dimension
  276. n_dims: core.constexpr = _log2(x.numel)
  277. # flip along middle dimension (the bitwise XORs will be optimised away):
  278. idtype = _get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
  279. ix = x.to(idtype, bitcast=True)
  280. iy = ix ^ xor_sum(ix, n_dims - 1 - i, True)
  281. y = iy.to(x.dtype, bitcast=True)
  282. # determines whether we are in the right (rather than left) position along the axis:
  283. is_right = _indicator(n_dims, i)
  284. # conditional swap:
  285. ret = core.where((x > y) != (flip ^ is_right), y, x)
  286. return ret
@jit
def _bitonic_merge_hypercube(x, stage: core.constexpr, order: core.constexpr):
    '''
    Runs `stage` compare-and-swap rounds over `x`.

    `order` selects the direction:
    order 0 == ascending
    order 1 == descending
    order 2 == alternating
    '''
    # flip denotes whether to re-arrange sub-sequences of elements in ascending or
    # descending order.
    # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
    # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
    # a stride of 2) at this stage
    if order == 2:
        flip = _indicator(_log2(x.numel), stage)
    else:
        flip = order
    # perform `stage` rounds of `compare-and-swap`
    # (dimension index counts down from `stage - 1` to 0)
    for i in core.static_range(stage):
        x = _compare_and_swap(x, flip, stage - 1 - i)
    return x
  307. @jit
  308. def _bitonic_merge(x, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr):
  309. h = core.reshape(x, [2] * _log2(x.numel))
  310. h = _bitonic_merge_hypercube(h, stage, order)
  311. x = core.reshape(h, x.shape)
  312. return x
@jit
def sort_impl(x, k: core.constexpr = None, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
    """
    Sorts a tensor along a specified dimension.

    When :code:`k` is given, only the top :code:`k` elements along the dimension are
    kept (the largest if :code:`descending`, otherwise the smallest) and the last
    dimension of the result has size :code:`2**log2(k)`.

    :param x: The input tensor to be sorted.
    :type x: Tensor
    :param dim: The dimension along which to sort the tensor. If None, the tensor is sorted along the last dimension. Currently, only sorting along the last dimension is supported.
    :type dim: int, optional
    :param k: the number of top elements to select. If none, assume k = x.shape[dim]
    :type k: int, optional
    :param descending: If set to True, the tensor is sorted in descending order. If set to False, the tensor is sorted in ascending order.
    :type descending: bool, optional
    """
    # handle default dimension or check that it is the most minor dim
    _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim
    core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported")

    # NOTE(review): _log2 floors, so a `k` (or sort-dimension size) that is not a
    # power of two is silently rounded down here — confirm callers guarantee powers of two.
    log_n: core.constexpr = _log2(x.shape[_dim])
    log_k: core.constexpr = log_n if k is None else _log2(k)

    n_dims: core.constexpr = _log2(x.numel)

    # reshape to hypercube:
    # (a single-element input has n_dims == 0, hence the [1] fallback)
    h = core.reshape(x, [2] * n_dims if n_dims else [1])

    # run first log_k bitonic sort iterations:
    # (alternating order for every stage except the one that spans the full dimension)
    for i in core.static_range(1, log_k + 1):
        h = _bitonic_merge_hypercube(h, i, 2 if i < log_n else descending)

    # select top k elements using bitonic top-k
    # https://www.doc.ic.ac.uk/~hlgr/pdfs/MassivelyParallelTopK.pdf
    # each iteration reduces away one size-2 axis (max for descending, min otherwise)
    # and then re-merges the surviving runs of length 2**log_k
    for i in core.static_range(log_k + 1, log_n + 1):
        h = max(h, axis=(_log2(h.numel) - 1 - log_k)) if descending else min(h, axis=(_log2(h.numel) - 1 - log_k))
        h = _bitonic_merge_hypercube(h, log_k, 2 if i < log_n else descending)

    # reshape back:
    x = core.reshape(h, x.shape[:-1] + [2**log_k])
    return x
# Full sort along `dim`: sort_impl without a top-k limit.
# Ascending unless `descending` is set (default core.CONSTEXPR_0).
@jit
def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
    return sort_impl(x, dim=dim, descending=descending)
# Top-k selection along `dim`: sort_impl with `k` set and descending order,
# so the `k` largest elements are kept.
@jit
def topk(x, k: core.constexpr, dim: core.constexpr = None):
    return sort_impl(x, k=k, dim=dim, descending=True)
  351. @jit
  352. def bitonic_merge(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
  353. # handle default dimension or check that it is the most minor dim
  354. _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim
  355. core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported")
  356. n_dims: core.constexpr = _log2(x.shape[-1])
  357. return _bitonic_merge(x, n_dims, descending, n_dims)
  358. @constexpr_function
  359. def _get_flip_dim(dim, shape):
  360. if dim is None:
  361. dim = len(shape) - 1
  362. if dim < 0: # flip doesn't work if dim < 0 because the xor-swap for loop will start/end at the wrong index
  363. dim += len(shape)
  364. return dim
  365. @core._tensor_member_fn
  366. @jit
  367. def flip(x, dim=None):
  368. """
  369. Flips a tensor `x` along the dimension `dim`.
  370. :param x: the first input tensor
  371. :type x: Block
  372. :param dim: the dimension to flip along
  373. :type dim: int
  374. """
  375. core.static_assert(-len(x.shape) <= dim and dim < len(x.shape))
  376. _dim: core.constexpr = _get_flip_dim(dim, x.shape)
  377. core.static_assert(_is_power_of_two(x.shape[_dim]))
  378. steps: core.constexpr = _log2(x.shape[_dim])
  379. # reshape the swap dimension to (2, 2, ..., 2)
  380. idtype = _get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
  381. y = core.reshape(x.to(idtype, bitcast=True), x.shape[:_dim] + [2] * steps + x.shape[_dim + 1:])
  382. for i in core.static_range(steps):
  383. y = y ^ xor_sum(y, _dim + i, True)
  384. x = core.reshape(y, x.shape).to(x.dtype, bitcast=True)
  385. return x
@jit
def interleave(a, b):
    """
    Interleaves the values of two tensors along their last dimension. The two tensors must have the same shape.
    Equivalent to `tl.join(a, b).reshape(a.shape[:-1] + [2 * a.shape[-1]])`

    :param a: The first input tensor.
    :type a: Tensor
    :param b: The second input tensor.
    :type b: Tensor
    """
    c = core.join(a, b)

    if len(c.shape) == 1:
        # We must have interleaved two scalars.
        return c
    else:
        # This `else` is necessary because Triton's AST parser doesn't
        # understand that if we take the `if` above we definitely don't run this
        # `else`.
        # Fold the trailing axis added by `join` into the previous one, doubling its size.
        return core.reshape(c, c.shape[:-2] + [2 * c.shape[-2]])