| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108 |
- import triton
- import triton.language as tl
- from triton.tools.tensor_descriptor import TensorDescriptor
- # fmt: off
def create_ragged_descriptor(T, block_shape, ragged_dim=0):
    """
    Build a 'ragged descriptor' over a 2- or 3-dimensional tensor T.

    The descriptor lets T be treated as a concatenation, along
    ``ragged_dim``, of subarrays of potentially unequal size. The
    load_ragged, store_ragged and atomic_add_ragged device functions use it
    to address a subarray T[batch_offset : batch_offset + batch_size], with
    hardware bounds-checking preventing any sort of leakage outside the
    subarray.

    Args:
        T: tensor to wrap; only ``T.shape`` and ``T.stride(i)`` are read here.
        block_shape: box (tile) shape, one entry per dimension of T.
        ragged_dim: index of the ragged dimension. Negative values count
            from the end. The last dimension cannot be ragged.

    Returns:
        A TensorDescriptor of rank ``len(T.shape) + 2`` whose box shape is
        ``[1, 1] + block_shape``.

    Raises:
        AssertionError: if ragged_dim is out of range or names the last
            dimension, if T has more than 3 dimensions, if block_shape has
            the wrong length, or if T has more than 2**30 entries along the
            ragged dimension.
    """
    box = list(block_shape)
    dims = list(T.shape)
    ndim = len(dims)

    # Normalise a negative axis so the range check below sees a plain index.
    axis = ragged_dim + ndim if ragged_dim < 0 else ragged_dim

    assert 0 <= axis < ndim - 1, "last dimension cannot be ragged"
    assert ndim <= 3, "read-write ragged descriptors must have at most 3 dimensions"
    assert len(box) == ndim, "block shape must have same length as tensor shape"

    outer_extent = 0x7fff0000  # extent given to each of the two prepended dimensions
    row_limit = 0x40000000  # == 2**30

    assert dims[axis] <= row_limit, "number of rows may not exceed 2**30"

    # Advertise exactly 2**30 entries along the ragged axis regardless of the
    # real extent: to_ragged_indices shifts the row coordinate so that the end
    # of the selected subarray lands on this bound.
    padded_dims = dims[:axis] + [row_limit] + dims[axis + 1:]
    row_stride = T.stride(axis)

    # Two extra leading dimensions with strides (2**34 - s) and s. Indexing
    # them at (2**30, y) contributes 2**64 - 2**30 * s + y * s elements, and
    # the 2**64 term disappears because pointers have 64-bit wraparound
    # semantics.
    strides = [2**34 - row_stride, row_stride]
    strides.extend(T.stride(d) for d in range(ndim))

    return TensorDescriptor(
        T,
        [outer_extent, outer_extent] + padded_dims,
        strides,
        [1, 1] + box,
    )
@triton.jit
def to_ragged_indices(batch_offset, batch_size, row):
    """
    Helper function for load_ragged, store_ragged and atomic_add_ragged.

    Converts a row index relative to the subarray
    T[batch_offset : batch_offset + batch_size] into the three coordinates
    used with a descriptor from create_ragged_descriptor. The first two
    returned values index the two prepended dimensions and the third indexes
    the ragged dimension.

    With ragged stride s, those coordinates contribute
    2**30 * (2**34 - s) + (batch_offset + batch_size) * s
    + (2**30 - batch_size + row) * s
    == 2**64 + (batch_offset + row) * s elements, i.e. row `row` of the
    subarray once the 2**64 term wraps away. The ragged coordinate reaches
    the descriptor's 2**30 bound exactly when row == batch_size.
    """
    two_pow_30 = 0x40000000  # == 2**30
    batch_end = batch_offset + batch_size
    shifted_row = two_pow_30 - batch_size + row
    return two_pow_30, batch_end, shifted_row
@triton.jit
def load_ragged(TMA, batch_offset, batch_size, coords, ragged_dim: tl.constexpr = 0):
    """
    Read a tile from the subarray T[batch_offset : batch_offset + batch_size]
    with hardware bounds-checking, where reading outside the subarray gives
    zeros.

    Coords should be an appropriately-sized list of integers, just like in
    TMA.load(); coords[ragged_dim] is the position along the ragged
    dimension, relative to the start of the subarray.

    Returns the loaded tile with the two leading dimensions introduced by
    the ragged descriptor removed.
    """
    tl.static_assert(len(TMA.shape) == len(coords) + 2, "TMA must be a read-write ragged descriptor")

    wrap_idx, end_idx, row_idx = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim])

    # Swap the ragged coordinate for its shifted counterpart and prepend the
    # two extra coordinates expected by the descriptor.
    before = coords[:ragged_dim]
    after = coords[ragged_dim + 1:]
    tile = TMA.load([wrap_idx, end_idx] + before + [row_idx] + after)

    # Strip the two leading dimensions added by the descriptor.
    return tl.reshape(tile, tile.shape[2:])
@triton.jit
def store_ragged(TMA, batch_offset, batch_size, coords, data, ragged_dim: tl.constexpr = 0):
    """
    Write to a subarray T[batch_offset : batch_offset + batch_size] with
    hardware bounds-checking, where writes outside the subarray are masked
    correctly.

    Coords should be an appropriately-sized list of integers, just like in
    TMA.store(); coords[ragged_dim] is the position along the ragged
    dimension, relative to the start of the subarray.

    `data` is the tile to write, without the two leading dimensions that
    the ragged descriptor adds; they are inserted here before the store.
    """
    # Same compile-time rank check as load_ragged: the descriptor must carry
    # exactly two more dimensions than the caller supplies coordinates for.
    tl.static_assert(len(TMA.shape) == len(coords) + 2, "TMA must be a read-write ragged descriptor")

    c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim])

    # Pad the tile with two leading unit dimensions to match the descriptor.
    data = tl.reshape(data, [1, 1] + data.shape)

    # Replace the ragged coordinate with its shifted counterpart and prepend
    # the two extra coordinates.
    TMA.store([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:], data)
@triton.jit
def atomic_add_ragged(TMA, batch_offset, batch_size, coords, data, ragged_dim: tl.constexpr = 0):
    """
    Atomic add into a subarray T[batch_offset : batch_offset + batch_size] with
    hardware bounds-checking, where adds outside the subarray are masked
    correctly.

    Coords should be an appropriately-sized list of integers, just like in
    TMA.atomic_add(); coords[ragged_dim] is the position along the ragged
    dimension, relative to the start of the subarray.

    `data` is the tile to accumulate, without the two leading dimensions
    that the ragged descriptor adds; they are inserted here before the add.
    """
    # Same compile-time rank check as load_ragged: the descriptor must carry
    # exactly two more dimensions than the caller supplies coordinates for.
    tl.static_assert(len(TMA.shape) == len(coords) + 2, "TMA must be a read-write ragged descriptor")

    c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim])

    # Pad the tile with two leading unit dimensions to match the descriptor.
    data = tl.reshape(data, [1, 1] + data.shape)

    # Replace the ragged coordinate with its shifted counterpart and prepend
    # the two extra coordinates.
    TMA.atomic_add([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:], data)
|