| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546 |
- from dataclasses import dataclass
- from typing import List, Any
- from triton._utils import validate_block_shape
- from triton.experimental.gluon.language._layouts import PaddedSharedLayout, SwizzledSharedLayout
__all__ = ["TensorDescriptor"]


@dataclass
class TensorDescriptor:
    """Describe a strided tensor, the block shape used to access it, and its shared-memory layout.

    All fields are validated in ``__post_init__``; invalid combinations raise
    ``AssertionError`` (or whatever ``validate_block_shape`` raises).

    Attributes:
        base: The underlying tensor object. ``from_tensor`` stores the tensor
            itself here.
        shape: Size of each dimension. Only 2 dimensions are currently accepted.
        strides: Stride of each dimension, as reported by ``tensor.stride()``
            when built via ``from_tensor``. The last stride must be 1.
        block_shape: Shape of one block. It must have the same rank as
            ``shape`` and pass ``validate_block_shape``.
        layout: Shared-memory layout for the blocks. ``SwizzledSharedLayout``
            is only accepted with ``max_phase == 1``.
        padding: Padding mode. Only ``"zero"`` is supported.
    """

    base: Any
    shape: List[int]
    strides: List[int]
    block_shape: List[int]
    layout: PaddedSharedLayout | SwizzledSharedLayout
    padding: str = "zero"

    def __post_init__(self):
        """Validate rank, strides, block shape, layout and padding.

        Raises:
            AssertionError: If any of the constraints documented on the class
                are violated.
        """
        ndim = len(self.shape)
        # TODO: support 1D-5D tensor descriptors
        assert ndim == 2, f"Expected 2 dimensions but got {ndim} dimensions"
        assert len(self.strides) == ndim, f"Expected {ndim} strides but got {len(self.strides)}"
        # Report the block_shape length (not the strides length) when this check fails.
        assert len(self.block_shape) == ndim, \
            f"Expected block_shape to have {ndim} dimensions but got {len(self.block_shape)}"
        validate_block_shape(self.block_shape)
        assert self.strides[-1] == 1, "Last dimension must be contiguous"
        assert isinstance(self.layout, (PaddedSharedLayout, SwizzledSharedLayout)), \
            "Expected layout to be a PaddedSharedLayout or SwizzledSharedLayout"
        if isinstance(self.layout, SwizzledSharedLayout):
            # Only the unswizzled configuration is accepted here.
            assert self.layout.max_phase == 1, "Expected max_phase to be 1 for SwizzledSharedLayout"
        assert self.padding == "zero", "Only 'zero' padding is supported"

    @staticmethod
    def from_tensor(tensor: Any, block_shape: List[int], layout: PaddedSharedLayout | SwizzledSharedLayout):
        """Create a TensorDescriptor object from a tensor.

        The tensor's ``shape`` attribute and ``stride()`` result are stored
        as-is. ``padding`` is left at its default of ``"zero"``.

        Args:
            tensor (torch.Tensor): The input tensor; it must expose ``.shape``
                and ``.stride()``.
            block_shape (List[int]): The block shape of the tensor.
            layout (PaddedSharedLayout | SwizzledSharedLayout): The layout of
                the tensor in shared memory.

        Returns:
            TensorDescriptor: The created and validated descriptor.

        Raises:
            AssertionError: If the resulting descriptor fails validation in
                ``__post_init__``.
        """
        return TensorDescriptor(tensor, tensor.shape, tensor.stride(), block_shape, layout)
|