hopper.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. from dataclasses import dataclass
  2. from typing import List, Any
  3. from triton._utils import validate_block_shape, canonicalize_dtype, get_primitive_bitwidth
  4. from triton.experimental.gluon.language._layouts import NVMMASharedLayout
  5. __all__ = ["TensorDescriptor"]
  6. @dataclass
  7. class TensorDescriptor:
  8. base: Any
  9. shape: List[int]
  10. strides: List[int]
  11. block_shape: List[int]
  12. layout: NVMMASharedLayout
  13. padding: str = "zero"
  14. def __post_init__(self):
  15. rank = len(self.shape)
  16. assert len(self.strides) == rank, f"rank mismatch: {self}"
  17. assert len(self.block_shape) == rank, f"rank mismatch: {self}"
  18. assert rank > 0, "rank must not be zero"
  19. assert rank <= 5, "rank cannot be more than 5"
  20. assert self.base.data_ptr() % 16 == 0, "base must be 16-byte aligned"
  21. validate_block_shape(self.block_shape)
  22. dtype_str = canonicalize_dtype(self.base.dtype)
  23. elem_bytes = get_primitive_bitwidth(dtype_str) // 8
  24. for stride in self.strides[:-1]:
  25. assert (stride * elem_bytes) % 16 == 0, "strides must be 16-byte aligned"
  26. for shape_dim in self.shape:
  27. assert shape_dim > 0, "shape must be positive"
  28. assert self.strides[-1] == 1, "Last dimension must be contiguous"
  29. assert isinstance(self.layout, NVMMASharedLayout), "Layout must be NVMMASharedLayout"
  30. assert self.padding == "zero" or self.padding == "nan", "Illegal value for padding"
  31. if self.padding == "nan":
  32. assert self.base.dtype.is_floating_point, "Padding option `nan` is only supported for floating point tensors"
  33. @staticmethod
  34. def from_tensor(tensor: Any, block_shape: List[int], layout: NVMMASharedLayout, padding="zero"):
  35. return TensorDescriptor(
  36. tensor,
  37. tensor.shape,
  38. tensor.stride(),
  39. block_shape,
  40. layout,
  41. padding,
  42. )