ragged_tma.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. import triton
  2. import triton.language as tl
  3. from triton.tools.tensor_descriptor import TensorDescriptor
  4. # fmt: off
  5. def create_ragged_descriptor(T, block_shape, ragged_dim=0):
  6. """
  7. Given a 2- or 3-dimensional tensor T, this creates a 'ragged descriptor'
  8. which behaves like a concatenation (along the first axis) of subarrays
  9. of potentially unequal size.
  10. The load_ragged and store_ragged device functions can be used to read
  11. and write from subarrays T[batch_offset : batch_offset + batch_size]
  12. with hardware bounds-checking preventing any sort of leakage outside
  13. the subarray.
  14. """
  15. block_shape = list(block_shape)
  16. tensor_shape = list(T.shape)
  17. rank = len(tensor_shape)
  18. if ragged_dim < 0:
  19. ragged_dim += rank
  20. assert 0 <= ragged_dim < rank - 1, "last dimension cannot be ragged"
  21. assert rank <= 3, "read-write ragged descriptors must have at most 3 dimensions"
  22. assert len(block_shape) == rank, "block shape must have same length as tensor shape"
  23. max_int = 0x7fff0000
  24. billion = 0x40000000 # == 2**30
  25. assert tensor_shape[ragged_dim] <= billion, "number of rows may not exceed 2**30"
  26. tensor_shape[ragged_dim] = billion
  27. ragged_stride = T.stride(ragged_dim)
  28. # we prepend an extra two dimensions and rely on the fact that pointers
  29. # have 64-bit wraparound semantics:
  30. tma_stride = [2**34 - ragged_stride, ragged_stride] + [T.stride(i) for i in range(rank)]
  31. tma_shape = [max_int, max_int] + tensor_shape
  32. box_shape = [1, 1] + block_shape
  33. return TensorDescriptor(T, tma_shape, tma_stride, box_shape)
  34. @triton.jit
  35. def to_ragged_indices(batch_offset, batch_size, row):
  36. """
  37. Helper function for load_ragged and store_ragged.
  38. """
  39. billion = 0x40000000 # == 2**30
  40. x = billion - batch_size + row
  41. y = batch_offset + batch_size
  42. return billion, y, x
  43. @triton.jit
  44. def load_ragged(TMA, batch_offset, batch_size, coords, ragged_dim: tl.constexpr = 0):
  45. """
  46. Read from a subarray T[batch_offset : batch_offset + batch_size] with
  47. hardware bounds-checking, where reading outside the subarray gives zeros.
  48. Coords should be an appropriately-sized list of integers, just like in
  49. TMA.load().
  50. """
  51. tl.static_assert(len(TMA.shape) == len(coords) + 2, "TMA must be a read-write ragged descriptor")
  52. c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim])
  53. data = TMA.load([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:])
  54. data = tl.reshape(data, data.shape[2:])
  55. return data
  56. @triton.jit
  57. def store_ragged(TMA, batch_offset, batch_size, coords, data, ragged_dim: tl.constexpr = 0):
  58. """
  59. Write to a subarray T[batch_offset : batch_offset + batch_size] with
  60. hardware bounds-checking, where writes outside the subarray are masked
  61. correctly.
  62. Coords should be an appropriately-sized list of integers, just like in
  63. TMA.store().
  64. """
  65. c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim])
  66. data = tl.reshape(data, [1, 1] + data.shape)
  67. TMA.store([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:], data)
  68. @triton.jit
  69. def atomic_add_ragged(TMA, batch_offset, batch_size, coords, data, ragged_dim: tl.constexpr = 0):
  70. """
  71. Atomic add into a subarray T[batch_offset : batch_offset + batch_size] with
  72. hardware bounds-checking, where adds outside the subarray are masked
  73. correctly.
  74. Coords should be an appropriately-sized list of integers, just like in
  75. TMA.atomic_add().
  76. """
  77. c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim])
  78. data = tl.reshape(data, [1, 1] + data.shape)
  79. TMA.atomic_add([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:], data)