gfx1250.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. from dataclasses import dataclass
  2. from typing import List, Any
  3. from triton._utils import validate_block_shape
  4. from triton.experimental.gluon.language._layouts import PaddedSharedLayout, SwizzledSharedLayout
  # Public API of this module: names exported by `from <module> import *`.
  __all__ = ["TensorDescriptor"]
  @dataclass
  class TensorDescriptor:
      """Descriptor bundling a tensor with its shape, strides, block shape and shared-memory layout.

      All fields are validated in ``__post_init__``. The checks use ``assert``,
      so they are skipped when Python runs with ``-O``. ``validate_block_shape``
      (defined elsewhere) performs additional checks on ``block_shape``.
      """

      # Object the descriptor refers to; ``from_tensor`` stores the tensor itself here.
      base: Any
      # Size of each dimension; exactly 2 dimensions are currently accepted.
      shape: List[int]
      # Stride of each dimension (``from_tensor`` uses ``tensor.stride()``);
      # the last stride must be 1.
      strides: List[int]
      # Shape of one block; must have as many dimensions as ``shape``.
      block_shape: List[int]
      # Layout of the block in shared memory.
      layout: PaddedSharedLayout | SwizzledSharedLayout
      # Padding mode; only "zero" is currently accepted.
      # NOTE(review): presumably the fill for out-of-bounds elements -- confirm.
      padding: str = "zero"

      def __post_init__(self):
          """Validate the descriptor's fields.

          Raises:
              AssertionError: if the tensor is not 2D, ``strides`` or ``block_shape``
                  have the wrong number of dimensions, the last dimension is not
                  contiguous, ``layout`` has an unsupported type, a
                  ``SwizzledSharedLayout`` has ``max_phase != 1``, or ``padding``
                  is not ``"zero"``.
          """
          ndim = len(self.shape)
          # TODO: support 1D-5D tensor descriptors
          assert ndim == 2, f"Expected 2 dimensions but got {ndim} dimensions"
          assert len(self.strides) == ndim, f"Expected {ndim} strides but got {len(self.strides)}"
          # Report the block_shape length here (the message previously printed len(self.strides)).
          assert len(self.block_shape) == ndim, \
              f"Expected block_shape to have {ndim} dimensions but got {len(self.block_shape)}"
          validate_block_shape(self.block_shape)
          assert self.strides[-1] == 1, "Last dimension must be contiguous"
          assert isinstance(self.layout, (PaddedSharedLayout, SwizzledSharedLayout)), \
              "Expected layout to be a PaddedSharedLayout or SwizzledSharedLayout"
          if isinstance(self.layout, SwizzledSharedLayout):
              # NOTE(review): max_phase == 1 presumably means swizzling is effectively
              # disabled; confirm against the SwizzledSharedLayout definition.
              assert self.layout.max_phase == 1, "Expected max_phase to be 1 for SwizzledSharedLayout"
          assert self.padding == "zero", "Only 'zero' padding is supported"

      @staticmethod
      def from_tensor(tensor: Any, block_shape: List[int],
                      layout: PaddedSharedLayout | SwizzledSharedLayout) -> "TensorDescriptor":
          """Create a TensorDescriptor object from a tensor.

          Shape and strides are read from ``tensor.shape`` and ``tensor.stride()``.
          The tensor itself becomes ``base``, and ``padding`` keeps its default ``"zero"``.

          Args:
              tensor (torch.Tensor): The input tensor.
              block_shape (List[int]): The block shape of the tensor.
              layout (PaddedSharedLayout | SwizzledSharedLayout): The layout of the tensor in shared memory.

          Returns:
              TensorDescriptor: the created (and validated) descriptor.

          Raises:
              AssertionError: if the resulting fields fail the checks in ``__post_init__``.
          """
          return TensorDescriptor(tensor, tensor.shape, tensor.stride(), block_shape, layout)