tensor_descriptor.py 1.5 KB

123456789101112131415161718192021222324252627282930313233343536
  1. from dataclasses import dataclass
  2. from typing import List, Any
  3. from triton._utils import validate_block_shape
  4. @dataclass
  5. class TensorDescriptor:
  6. base: Any
  7. shape: List[int]
  8. strides: List[int]
  9. block_shape: List[int]
  10. padding: str = "zero"
  11. def __post_init__(self):
  12. rank = len(self.shape)
  13. assert len(self.strides) == rank, f"rank mismatch: {self}"
  14. assert len(self.block_shape) == rank, f"rank mismatch: {self}"
  15. assert rank > 0, "rank must not be zero"
  16. assert rank <= 5, "rank cannot be more than 5"
  17. ty = type(self.base)
  18. if ty.__name__ not in ("FakeTensor", "FunctionalTensor"):
  19. assert self.base.data_ptr() % 16 == 0, "base must be 16-byte aligned"
  20. validate_block_shape(self.block_shape)
  21. elem_bytes = self.base.dtype.itemsize
  22. for stride in self.strides[:-1]:
  23. assert (stride * elem_bytes) % 16 == 0, "strides must be 16-byte aligned"
  24. for shape_dim in self.shape:
  25. assert shape_dim > 0, "shape must be positive"
  26. assert self.strides[-1] == 1, "Last dimension must be contiguous"
  27. assert self.padding == "zero" or self.padding == "nan", "Illegal value for padding"
  28. if self.padding == "nan":
  29. assert self.base.dtype.is_floating_point, "Padding option `nan` is only supported for floating point tensors"
  30. @staticmethod
  31. def from_tensor(tensor: Any, block_shape: List[int], padding="zero"):
  32. return TensorDescriptor(tensor, tensor.shape, tensor.stride(), block_shape, padding)