| 123456789101112131415161718192021222324252627282930313233343536 |
- from dataclasses import dataclass
- from typing import List, Any
- from triton._utils import validate_block_shape
@dataclass
class TensorDescriptor:
    """Describe a strided tensor together with the block shape used to access it.

    All invariants are checked once, at construction time, in
    ``__post_init__``.

    Attributes:
        base: The underlying tensor-like object. It must expose ``dtype``
            (with ``itemsize`` and ``is_floating_point``) and, unless its
            type is named ``FakeTensor`` or ``FunctionalTensor``, a
            ``data_ptr()`` method.
        shape: Size of each dimension; every entry must be positive.
        strides: Stride of each dimension, in elements. The last stride
            must be 1 and every other stride must span a multiple of
            16 bytes.
        block_shape: Block size per dimension; must have the same rank as
            ``shape`` and is additionally checked by
            ``validate_block_shape``.
        padding: Either ``"zero"`` or ``"nan"``. ``"nan"`` requires a
            floating-point ``base``.
    """

    base: Any
    shape: List[int]
    strides: List[int]
    block_shape: List[int]
    padding: str = "zero"

    def __post_init__(self):
        """Validate rank, alignment, contiguity and padding.

        Raises:
            AssertionError: If any invariant listed on the class is violated.
        """
        # NOTE: the checks are plain ``assert`` statements, so they are
        # skipped entirely when Python runs with ``-O``.
        rank = len(self.shape)
        assert len(self.strides) == rank, f"rank mismatch: {self}"
        assert len(self.block_shape) == rank, f"rank mismatch: {self}"
        assert rank > 0, "rank must not be zero"
        assert rank <= 5, "rank cannot be more than 5"

        # Types named FakeTensor / FunctionalTensor are exempt from the
        # pointer check (presumably they have no real storage address --
        # TODO confirm). Matching on the type name avoids importing them.
        ty = type(self.base)
        if ty.__name__ not in ("FakeTensor", "FunctionalTensor"):
            assert self.base.data_ptr() % 16 == 0, "base must be 16-byte aligned"

        validate_block_shape(self.block_shape)

        # Strides are in elements; convert to bytes before checking the
        # 16-byte alignment of every stride except the innermost one.
        elem_bytes = self.base.dtype.itemsize
        for stride in self.strides[:-1]:
            assert (stride * elem_bytes) % 16 == 0, "strides must be 16-byte aligned"
        for shape_dim in self.shape:
            assert shape_dim > 0, "shape must be positive"
        assert self.strides[-1] == 1, "Last dimension must be contiguous"

        assert self.padding in ("zero", "nan"), "Illegal value for padding"
        if self.padding == "nan":
            assert self.base.dtype.is_floating_point, "Padding option `nan` is only supported for floating point tensors"

    @classmethod
    def from_tensor(cls, tensor: Any, block_shape: List[int], padding="zero"):
        """Build a descriptor from a tensor's own shape and strides.

        This is a classmethod so that calling it on a subclass yields an
        instance of that subclass rather than of ``TensorDescriptor``.

        Args:
            tensor: Object exposing ``shape`` and ``stride()``; it becomes
                the descriptor's ``base``.
            block_shape: Block size per dimension.
            padding: ``"zero"`` (default) or ``"nan"``.

        Returns:
            A new, validated instance of ``cls``.
        """
        return cls(tensor, tensor.shape, tensor.stride(), block_shape, padding)
|