semi_structured.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690
  1. # mypy: allow-untyped-defs
  2. import warnings
  3. from collections import namedtuple
  4. from collections.abc import Callable
  5. from typing import Any
  6. import torch
  7. from torch.sparse._semi_structured_conversions import (
  8. sparse_semi_structured_from_dense_cutlass,
  9. sparse_semi_structured_to_dense_cutlass,
  10. )
  11. from torch.sparse._semi_structured_ops import (
  12. fallback_dispatcher,
  13. semi_sparse_addmm,
  14. semi_sparse_detach,
  15. semi_sparse_indices,
  16. semi_sparse_linear,
  17. semi_sparse_mm,
  18. semi_sparse_scaled_mm,
  19. semi_sparse_t,
  20. semi_sparse_values,
  21. semi_sparse_view,
  22. )
  23. __all__ = [
  24. "SparseSemiStructuredTensor",
  25. "SparseSemiStructuredTensorCUTLASS",
  26. "SparseSemiStructuredTensorCUSPARSELT",
  27. "to_sparse_semi_structured",
  28. ]
  29. _SEMI_STRUCTURED_SPARSE_CONFIG = namedtuple(
  30. "_SEMI_STRUCTURED_SPARSE_CONFIG",
  31. "sparse_min_rows sparse_min_cols dense_min_rows dense_min_cols",
  32. )
  33. class SparseSemiStructuredTensor(torch.Tensor):
  34. """
  35. This class implements semi-structured sparsity as a Tensor subclass.
  36. Semi-structured sparsity describes a sparsity pattern where n in every 2n elements are sparse,
  37. depending on the datatype. It is also referred to as 2:4 sparsity or fine-grained
  38. structured sparsity.
  39. There are two backends available for semi_structred sparsity, either cuSPARSELt or CUTLASS.
  40. This class is meant to serve as a base class for both implementations. SparseSemiStructuredCUTLASS
  41. and SparseSemiStructuredCUSPARSELT both inherit from this class and define three backend-specific items.
  42. Note that as such, this class cannot be instantiated directly.
  43. -`_DTYPE_SHAPE_CONSTRAINTS` - A dictionary holding backend specific dense/sparse min shape constraints
  44. - `def from_dense()` - backend specific compression routines
  45. - `def _mm()` - backend specific mm op (either torch._cslt_sparse_mm or torch._sparse_semi_structured_(mm|addmm))
  46. """
  47. _DEFAULT_ALG_ID: int = 0
  48. _DTYPE_SHAPE_CONSTRAINTS: dict[torch.dtype, _SEMI_STRUCTURED_SPARSE_CONFIG]
  49. _FORCE_CUTLASS: bool = False
  50. _FUSE_TRANSPOSE: bool = False
  51. _PROTOTYPE_WARNING_SHOWN: bool = False
  52. BACKEND: str
  53. SPARSE_DISPATCH: dict[Callable, Callable]
  54. packed: torch.Tensor | None
  55. meta: torch.Tensor | None
  56. packed_t: torch.Tensor | None
  57. meta_t: torch.Tensor | None
  58. compressed_swizzled_bitmask: torch.Tensor | None
  59. fuse_transpose_cusparselt: bool
  60. alg_id_cusparselt: int
  61. __slots__ = ["packed", "meta", "packed_t", "meta_t", "compressed_swizzled_bitmask"]
  62. @staticmethod
  63. def __new__( # noqa: PYI034
  64. cls,
  65. shape: torch.Size,
  66. packed: torch.Tensor | None,
  67. meta: torch.Tensor | None,
  68. packed_t: torch.Tensor | None,
  69. meta_t: torch.Tensor | None,
  70. compressed_swizzled_bitmask: torch.Tensor | None,
  71. fuse_transpose_cusparselt: bool = False,
  72. alg_id_cusparselt: int = 0,
  73. requires_grad: bool = False,
  74. ):
  75. """
  76. Create a new instance of the tensor subclass from the compressed sparse representation.
  77. We have the option to create the subclass with the compressed representations of both X and X', for training.
  78. For inference, we only need a single representation (either X or X'), while the corresponding other set will be None.
  79. Depending on the backend selected, certain fields will be set to None. (CUSPARSELT vs CUTLASS)
  80. Args:
  81. shape: The shape of the original dense tensor
  82. packed: The compressed representation of the original dense tensor
  83. meta: The metadata of the original dense tensor, if it is stored separately
  84. packed_t: The compressed representation of the transposed original dense tensor
  85. meta_t: The metadata of the transposed original dense tensor, if it is stored separately
  86. compressed_swizzled_bitmask: The masks used by the CUTLASS backend to determine which threads should
  87. participate in the computation. Used for pointwise ops.
  88. fuse_transpose_cusparselt: When running with cuSPARSELt, we have the option to fuse a transposition
  89. with a matmul, which is useful in the case of 2:4 sparse training.
  90. alg_id_cusparselt: The algorithm id to use when using cuSPARSELT, will have effect on performance
  91. Returns:
  92. torch.Tensor: A torch.Tensor wrapper subclass.
  93. Raises:
  94. ValueError: If all of the tensor arguments are None.
  95. """
  96. if not cls._PROTOTYPE_WARNING_SHOWN:
  97. warnings.warn(
  98. (
  99. "The PyTorch API of SparseSemiStructuredTensor is in prototype stage "
  100. "and will change in the near future. Please open a Github issue "
  101. "for features requests and see our documentation on the torch.sparse "
  102. "module for further information about the project."
  103. ),
  104. UserWarning,
  105. stacklevel=2,
  106. )
  107. cls._PROTOTYPE_WARNING_SHOWN = True
  108. # Because this only runs once, we also load the dispatch table here as well.
  109. # We can't define the dispatch table explicitly because of torch.ops import errors, so we do this instead
  110. # But this is useful since it allows users to overload the dispatch table for debugging / testing.
  111. cls._load_dispatch_table()
  112. # we can also register the classes with dynamo when the warning is shown.
  113. torch._dynamo.allow_in_graph(cls)
  114. if packed is not None:
  115. previous_tensor = packed
  116. elif packed_t is not None:
  117. previous_tensor = packed_t
  118. else:
  119. raise ValueError("At least one of packed or packed_t must be provided")
  120. tensor = torch.Tensor._make_wrapper_subclass(
  121. cls,
  122. shape,
  123. device=previous_tensor.device,
  124. dtype=previous_tensor.dtype,
  125. layout=previous_tensor.layout,
  126. requires_grad=requires_grad,
  127. )
  128. tensor.packed = packed
  129. tensor.meta = meta
  130. tensor.packed_t = packed_t
  131. tensor.meta_t = meta_t
  132. tensor.compressed_swizzled_bitmask = compressed_swizzled_bitmask
  133. tensor.fuse_transpose_cusparselt = fuse_transpose_cusparselt
  134. tensor.alg_id_cusparselt = alg_id_cusparselt
  135. return tensor
  136. def __repr__(self) -> str: # type: ignore[override]
  137. if not hasattr(self, "shape"):
  138. raise AssertionError("tensor has no shape attribute")
  139. return f"{self.__class__.__name__}(shape={self.shape})"
  140. def __tensor_flatten__(
  141. self,
  142. ) -> tuple[list[str], tuple[torch.Size, bool, int, bool]]:
  143. inner_tensors = list(
  144. filter(lambda x: getattr(self, x) is not None, self.__slots__)
  145. )
  146. tensor_meta = (
  147. self.shape,
  148. self.fuse_transpose_cusparselt,
  149. self.alg_id_cusparselt,
  150. self.requires_grad,
  151. )
  152. return inner_tensors, tensor_meta
  153. @classmethod
  154. def __tensor_unflatten__(
  155. cls,
  156. inner_tensors,
  157. tensor_meta: tuple[torch.Size, bool, int, bool],
  158. outer_size,
  159. outer_stride,
  160. ) -> torch.Tensor:
  161. shape, fuse_transpose_cusparselt, alg_id_cusparselt, requires_grad = tensor_meta
  162. # pyrefly: ignore [no-matching-overload]
  163. return cls(
  164. shape=shape,
  165. packed=inner_tensors.get("packed", None),
  166. meta=inner_tensors.get("meta", None),
  167. packed_t=inner_tensors.get("packed_t", None),
  168. meta_t=inner_tensors.get("meta_t", None),
  169. compressed_swizzled_bitmask=inner_tensors.get(
  170. "compressed_swizzled_bitmask", None
  171. ),
  172. fuse_transpose_cusparselt=fuse_transpose_cusparselt,
  173. alg_id_cusparselt=alg_id_cusparselt,
  174. requires_grad=requires_grad,
  175. )
  176. __torch_function__ = torch._C._disabled_torch_function_impl # type: ignore[assignment]
  177. @classmethod
  178. def __torch_dispatch__(cls, func, types, args, kwargs) -> Any: # type: ignore[override]
  179. if func._overloadpacket not in cls.SPARSE_DISPATCH:
  180. raise NotImplementedError(
  181. f"{cls.__name__} only supports a specific set of operations, "
  182. f"can't perform requested op ({func.__name__})"
  183. )
  184. return cls.SPARSE_DISPATCH[func._overloadpacket](func, types, args, kwargs)
  185. @classmethod
  186. def _load_dispatch_table(cls, custom_dispatch_table=None) -> None:
  187. """
  188. Loads the op overload sparse dispatch table for the current class.
  189. """
  190. if getattr(cls, "SPARSE_DISPATCH", None) is None:
  191. cls.SPARSE_DISPATCH = {
  192. torch.ops.aten.values: semi_sparse_values,
  193. torch.ops.aten.indices: semi_sparse_indices,
  194. torch.ops.aten.is_same_size: fallback_dispatcher,
  195. torch.ops.aten.detach_: fallback_dispatcher,
  196. torch.ops.aten.detach: semi_sparse_detach,
  197. torch.ops.aten.t: semi_sparse_t,
  198. torch.ops.aten.view: semi_sparse_view,
  199. torch.ops.aten.mm: semi_sparse_mm,
  200. torch.ops.aten.matmul: semi_sparse_mm,
  201. torch.ops.aten.addmm: semi_sparse_addmm,
  202. torch.ops.aten.linear: semi_sparse_linear,
  203. torch.ops.aten._to_copy: fallback_dispatcher,
  204. torch.ops.aten._scaled_mm: semi_sparse_scaled_mm,
  205. }
  206. if custom_dispatch_table is not None:
  207. cls.SPARSE_DISPATCH.update(custom_dispatch_table)
  208. @classmethod
  209. def _validate_device_dim_dtype_shape(cls, original_tensor: torch.Tensor) -> None:
  210. """
  211. Assert that the given tensor is valid for semi-structured sparse compression.
  212. """
  213. # check device
  214. if not original_tensor.is_cuda:
  215. raise RuntimeError(
  216. f"Error original_tensor.device= {original_tensor.device} is not supported! "
  217. "Only CUDA tensors are currently supported."
  218. )
  219. # check dim
  220. if original_tensor.dim() != 2:
  221. raise RuntimeError(
  222. f"Error original_tensor.dim = {original_tensor.dim()} is not supported! "
  223. "Only 2d tensors are currently supported."
  224. )
  225. # check contiguous
  226. if not original_tensor.is_contiguous():
  227. raise RuntimeError(
  228. "Error original_tensor is not contiguous!"
  229. "Only contiguous tensors are currently supported."
  230. )
  231. # check dtype
  232. if original_tensor.dtype not in cls._DTYPE_SHAPE_CONSTRAINTS:
  233. raise RuntimeError(
  234. f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype for {cls}!"
  235. )
  236. # check shape
  237. m, n = original_tensor.shape
  238. min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[original_tensor.dtype].sparse_min_rows
  239. min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[original_tensor.dtype].sparse_min_cols
  240. if m < min_rows or m % min_rows or n < min_cols or n % min_cols:
  241. # TODO in the future we can add in padding to support sparse dimensions that aren't perfect multiples
  242. raise RuntimeError(
  243. f"Error original_tensor.shape {original_tensor.shape} is not supported! "
  244. f"Both dimensions must be larger or equal than and a multiple of ({min_rows}, {min_cols})"
  245. )
  246. @classmethod
  247. def _pad_dense_input(cls, dense_input: torch.Tensor) -> torch.Tensor:
  248. """
  249. Calculates padding for dense tensor and pads tensor if necessary.
  250. If padding is not required, this function returns the original tensor.
  251. """
  252. # only 2d matmul
  253. if dense_input.dim() != 2:
  254. raise AssertionError(f"dense_input must be 2D, got {dense_input.dim()}D")
  255. # check shape
  256. m, n = dense_input.shape
  257. min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_rows
  258. min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_cols
  259. # calculate padding
  260. to_pad_m = -m % min_rows if m < min_rows or m % min_rows else 0
  261. to_pad_n = -n % min_cols if n < min_cols or n % min_rows else 0
  262. if to_pad_m or to_pad_n:
  263. return torch.nn.functional.pad(dense_input, (0, to_pad_n, 0, to_pad_m))
  264. else:
  265. return dense_input
  266. def to_dense(self): # type:ignore[override]
  267. col = self.shape[-1]
  268. return torch.mm(self, torch.eye(col, dtype=self.dtype, device=self.device))
  269. @classmethod
  270. def from_dense(cls, original_tensor: torch.Tensor) -> "SparseSemiStructuredTensor":
  271. raise NotImplementedError
  272. def _mm(
  273. self,
  274. B: torch.Tensor,
  275. *,
  276. bias: torch.Tensor | None = None,
  277. **kwargs,
  278. ) -> torch.Tensor:
  279. raise NotImplementedError
  280. def to_sparse_semi_structured(
  281. original_tensor: torch.Tensor,
  282. transposed: bool = False,
  283. ) -> SparseSemiStructuredTensor:
  284. """
  285. This function converts a dense tensor into a sparse semi-structured tensor.
  286. It will return a SparseSemiStructuredTensor, a subclass of torch.Tensor.
  287. This function will check to ensure the dense tensor has the right dtype, size, dims, and device.
  288. We currently only support semi-structured sparse tensors for 2d CUDA tensors.
  289. Additionally, your tensor must be a positive multiple of the minimum sparse block size, given in
  290. `_DTYPE_TO_SHAPE_CONSTRAINTS` for each dtype (float32, float16, bfloat16, int8).
  291. Args:
  292. original_tensor (Tensor): the dense tensor to convert
  293. transposed (bool, optional): deprecated arg to be removed in another release. Do not use.
  294. Returns:
  295. SparseSemiStructuredTensor: A sparse semi-structured tensor created from the given original_tensor
  296. Raises:
  297. None
  298. Example:
  299. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
  300. >>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()
  301. tensor([[0., 0., 1., ..., 0., 1., 1.],
  302. [0., 0., 1., ..., 0., 1., 1.],
  303. [0., 0., 1., ..., 0., 1., 1.],
  304. ...,
  305. [0., 0., 1., ..., 0., 1., 1.],
  306. [0., 0., 1., ..., 0., 1., 1.],
  307. [0., 0., 1., ..., 0., 1., 1.]], device='cuda:0', dtype=torch.float16)
  308. >>> A_sparse = to_sparse_semi_structured(A)
  309. SparseSemiStructuredTensor(shape=torch.Size([128, 128]))
  310. >>> A_sparse.values()
  311. tensor([[1., 1., 1., ..., 1., 1., 1.],
  312. [1., 1., 1., ..., 1., 1., 1.],
  313. [1., 1., 1., ..., 1., 1., 1.],
  314. ...,
  315. [1., 1., 1., ..., 1., 1., 1.],
  316. [1., 1., 1., ..., 1., 1., 1.],
  317. [1., 1., 1., ..., 1., 1., 1.]], device='cuda:0', dtype=torch.float16),
  318. >>> A_sparse.indices()
  319. tensor([[-4370, -4370, -4370, ..., -4370, -4370, -4370],
  320. [-4370, -4370, -4370, ..., -4370, -4370, -4370],
  321. [-4370, -4370, -4370, ..., -4370, -4370, -4370],
  322. ...,
  323. [-4370, -4370, -4370, ..., -4370, -4370, -4370],
  324. [-4370, -4370, -4370, ..., -4370, -4370, -4370],
  325. [-4370, -4370, -4370, ..., -4370, -4370, -4370]], device='cuda:0', dtype=torch.int16))
  326. """
  327. if transposed:
  328. warnings.warn(
  329. "Setting transpose from `to_sparse_semi_structured` is deprecated "
  330. "and will be removed in a future release. "
  331. "`SparseSemiStructuredTensor` only support contiguous input tensors.",
  332. FutureWarning,
  333. stacklevel=2,
  334. )
  335. # set from _FORCE_CUTLASS flag
  336. SPARSE_SUBCLASS = (
  337. torch.sparse.SparseSemiStructuredTensorCUTLASS
  338. if SparseSemiStructuredTensor._FORCE_CUTLASS
  339. else torch.sparse.SparseSemiStructuredTensorCUSPARSELT
  340. )
  341. return SPARSE_SUBCLASS.from_dense(original_tensor)
  342. class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
  343. """
  344. This class implements semi-structured sparsity for the CUTLASS backend.
  345. In this implementation, the specified elements and metadata are stored separately,
  346. in packed and meta respectively.
  347. When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_(mm|addmm) and
  348. sparse_semi_structured_from_dense for conversion to the compressed format.
  349. """
  350. BACKEND = "cutlass"
  351. _DTYPE_SHAPE_CONSTRAINTS = {
  352. torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 128, 16, 16),
  353. torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
  354. torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
  355. torch.float32: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 4, 4),
  356. }
  357. @classmethod
  358. def from_dense(
  359. cls, original_tensor: torch.Tensor
  360. ) -> "SparseSemiStructuredTensorCUTLASS":
  361. cls._validate_device_dim_dtype_shape(original_tensor)
  362. (
  363. sparse_tensor_cutlass,
  364. meta_tensor_cutlass,
  365. ) = sparse_semi_structured_from_dense_cutlass(original_tensor)
  366. # pyrefly: ignore [no-matching-overload]
  367. return cls(
  368. original_tensor.shape,
  369. packed=sparse_tensor_cutlass,
  370. meta=meta_tensor_cutlass,
  371. packed_t=None,
  372. meta_t=None,
  373. compressed_swizzled_bitmask=None,
  374. requires_grad=original_tensor.requires_grad,
  375. )
  376. def to_dense(self): # type: ignore[override]
  377. if self.meta is None or self.packed is None:
  378. raise AssertionError("meta and packed must not be None")
  379. return (
  380. sparse_semi_structured_to_dense_cutlass(
  381. self.packed,
  382. self.meta,
  383. )
  384. if self.meta.ndim == 2
  385. else super().to_dense()
  386. )
  387. @classmethod
  388. def prune_dense_static_sort(
  389. cls, original_tensor: torch.Tensor, algorithm=""
  390. ) -> "SparseSemiStructuredTensor":
  391. """
  392. This function takes in a unpruned dense tensor and runs a (branchless) static sort across a 4x4 tile.
  393. It greedily picks the largest values in the tile, upholding the 2:4 sparsity constraint across both rows and columns.
  394. The algorithm used to prune the matrix is implemented in `_sparse_semi_structured_tile`.
  395. Then it creates the packed and meta tensors for the compressed sparse representation of the pruned dense tensor.
  396. It also calculates the packed_t and meta_t tensors for the compressed sparse representation of the transposed
  397. pruned dense tensor.
  398. Since we cannot transpose the compressed representations, we store both for the fw/bw pass respectively.
  399. Finally, this function also computes a compressed swizzled bitmask that encodes the sparsity pattern
  400. This can be used in the backward pass to mask the gradients.
  401. [9 1 7 4] [9 0 7 0]
  402. [1 2 3 0] [0 2 0 0]
  403. [8 3 5 4] -> prune 4x4 tile -> [8 0 0 4] -> pack to CUTLASS semi-structured -> packed
  404. [1 2 6 2] [0 0 6 2] -> metadata
  405. -> pack to transposed CUTLASS -> packed_t
  406. semi-structured representation -> metadata_t
  407. -> compute swizzled bitmask -> compressed_swizzled_bitmask
  408. The equivalent PyTorch code to create the same five outputs from the dense tensor can be found below:
  409. ```
  410. from torch.sparse import SparseSemiStructuredTensorCUTLASS
  411. from torch.sparse._semi_structured_conversions import (
  412. _sparse_semi_structured_tile,
  413. _compute_compressed_swizzled_bitmask,
  414. )
  415. pruned = _sparse_semi_structured_tile(dense)
  416. packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned)
  417. packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(
  418. pruned.t().contiguous()
  419. )
  420. bitmask = _compute_compressed_swizzled_bitmask(pruned)
  421. SparseSemiStructuredTensorCUTLASS(
  422. dense.shape,
  423. packed_cutlass,
  424. meta_cutlass,
  425. packed_t_cutlass,
  426. meta_t_cutlass,
  427. bitmask,
  428. )
  429. ```
  430. """
  431. # We can either pack to the CUTLASS or cuSPARSELt representation, depending on the use_cutlass flag.
  432. (
  433. packed,
  434. meta,
  435. packed_t,
  436. meta_t,
  437. compressed_swizzled_bitmask,
  438. ) = torch._sparse_semi_structured_tile(
  439. original_tensor, algorithm=algorithm, use_cutlass=True
  440. )
  441. # pyrefly: ignore [no-matching-overload]
  442. return cls(
  443. original_tensor.shape,
  444. packed=packed,
  445. meta=meta,
  446. packed_t=packed_t,
  447. meta_t=meta_t,
  448. compressed_swizzled_bitmask=compressed_swizzled_bitmask,
  449. requires_grad=False,
  450. )
  451. def _mm(
  452. self, B: torch.Tensor, *, bias: torch.Tensor | None = None, **kwargs
  453. ) -> torch.Tensor:
  454. if isinstance(B, SparseSemiStructuredTensor):
  455. raise ValueError(
  456. "`SparseSemiStructuredTensor @ SparseSemiStructuredTensor` is not supported by the hardware"
  457. )
  458. cls_name = self.__class__.__name__
  459. if self.ndim != 2 or B.ndim != 2:
  460. raise NotImplementedError(
  461. f"`{cls_name}` matmul: Broadcasting is not implemented"
  462. )
  463. if self.packed is None or self.meta is None:
  464. raise NotImplementedError(
  465. f"`{cls_name}` matmul: operation is not supported"
  466. )
  467. else:
  468. if bias is None:
  469. res = torch._sparse_semi_structured_mm(self.packed, self.meta, B)
  470. else:
  471. res = torch._sparse_semi_structured_addmm(
  472. bias, self.packed, self.meta, B
  473. )
  474. return res[: self.shape[0]]
  475. class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
  476. """
  477. The cuSPARSELt backend expects the specified elements and the metadata to be stored in a single tensor:
  478. packed = [ specified elements of original tensor | metadata ]
  479. For an original tensor of size (m, k) we expect the first m * k // 2 elements to be the kept elements
  480. The rest of the tensor is metadata. Since there is only one tensor, we only use the packed and packed_t
  481. attributes respectively.
  482. cuSPARSELt also supports transposition fusion, which is necessary for performant 2:4 sparse training, as well
  483. as specifying alg_id, a config that affects the performance of the matmul depending on matmul sizes.
  484. """
  485. BACKEND = "cusparselt"
  486. _DTYPE_SHAPE_CONSTRAINTS = {
  487. torch.float8_e4m3fn: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16),
  488. torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16),
  489. torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
  490. torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
  491. }
  492. @classmethod
  493. def from_dense(
  494. cls, original_tensor: torch.Tensor
  495. ) -> "SparseSemiStructuredTensorCUSPARSELT":
  496. cls._validate_device_dim_dtype_shape(original_tensor)
  497. # pyrefly: ignore [no-matching-overload]
  498. return cls(
  499. shape=original_tensor.shape,
  500. packed=torch._cslt_compress(original_tensor),
  501. meta=None,
  502. packed_t=None,
  503. meta_t=None,
  504. compressed_swizzled_bitmask=None,
  505. fuse_transpose_cusparselt=SparseSemiStructuredTensor._FUSE_TRANSPOSE,
  506. alg_id_cusparselt=SparseSemiStructuredTensor._DEFAULT_ALG_ID,
  507. requires_grad=original_tensor.requires_grad,
  508. )
  509. @classmethod
  510. def prune_dense_static_sort(
  511. cls, original_tensor: torch.Tensor, algorithm=""
  512. ) -> "SparseSemiStructuredTensor":
  513. """
  514. This function does the same thing as described in SparseSemiStructuredCUTLASS, but uses the cuSPARSELt metadata
  515. layout and sparse matmul.
  516. The only functional difference is that cuSPARSELt stores `metadata` and `packed` together into a single tensor.
  517. [9 1 7 4] [9 0 7 0]
  518. [1 2 3 0] [0 2 0 0]
  519. [8 3 5 4] -> prune 4x4 tile -> [8 0 0 4] -> pack to cuSPARSELT semi-structured -> packed
  520. [1 2 6 2] [0 0 6 2]
  521. -> pack to transposed cuSPARSELt -> packed_t
  522. semi-structured representation
  523. -> compute swizzled bitmask -> compressed_swizzled_bitmask
  524. The equivalent PyTorch code to create the same three outputs from the dense tensor can be found below:
  525. ```
  526. from torch.sparse import SparseSemiStructuredTensorCUSPARSELT
  527. from torch.sparse._semi_structured_conversions import (
  528. _sparse_semi_structured_tile,
  529. _compute_compressed_swizzled_bitmask,
  530. )
  531. pruned = _sparse_semi_structured_tile(dense)
  532. packed_cusparselt = torch._cslt_compress(pruned)
  533. packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous())
  534. bitmask = _compute_compressed_swizzled_bitmask(pruned)
  535. SparseSemiStructuredTensorCUSPARSELT(
  536. dense.shape, packed_cutlass, None, packed_t_cutlass, None, bitmask
  537. )
  538. ```
  539. """
  540. (
  541. packed,
  542. meta,
  543. packed_t,
  544. meta_t,
  545. compressed_swizzled_bitmask,
  546. ) = torch._sparse_semi_structured_tile(
  547. original_tensor, algorithm=algorithm, use_cutlass=False
  548. )
  549. # Map this two 2-dim view of packed data.
  550. # TODO: is this proper cuSPARSELt metadata?
  551. packed = packed.view(original_tensor.shape[0], -1)
  552. packed_t = packed_t.view(original_tensor.shape[1], -1)
  553. # pyrefly: ignore [no-matching-overload]
  554. return cls(
  555. original_tensor.shape,
  556. packed=packed,
  557. meta=meta,
  558. packed_t=packed_t,
  559. meta_t=meta_t,
  560. compressed_swizzled_bitmask=compressed_swizzled_bitmask,
  561. requires_grad=False,
  562. )
  563. def _mm(
  564. self, B: torch.Tensor, *, bias: torch.Tensor | None = None, **kwargs
  565. ) -> torch.Tensor:
  566. if isinstance(B, SparseSemiStructuredTensor):
  567. raise ValueError(
  568. "`SparseSemiStructuredTensor @ SparseSemiStructuredTensor` is not supported by the hardware"
  569. )
  570. if self.ndim != 2 or B.ndim != 2:
  571. raise NotImplementedError(
  572. f"`{self.__class__.__name__}` matmul: Broadcasting is not implemented"
  573. )
  574. if B.dtype != self.dtype:
  575. raise NotImplementedError(
  576. f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)}`, "
  577. f"with A.dtype={self.dtype} and B.dtype={B.dtype}. "
  578. "This operation is only supported when A and B have the same data type."
  579. )
  580. if bias is not None and bias.dtype != self.dtype:
  581. raise NotImplementedError(
  582. f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)} + C`, "
  583. f"with A.dtype=B.dtype={self.dtype} and C.dtype={B.dtype}. "
  584. "This operation is only supported when A, B and C have the same data type."
  585. )
  586. # Force fp8 mm to error to be consistent with torch
  587. if self.dtype == torch.float8_e4m3fn:
  588. raise NotImplementedError(
  589. f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)}`, "
  590. f"with A.dtype=B.dtype={self.dtype}. "
  591. "mm is not supported for float8_e4m3fn, please use `torch._scaled_mm` instead."
  592. )
  593. if self.packed is None:
  594. raise NotImplementedError(
  595. f"`{self.__class__.__name__}` matmul: operation is not supported"
  596. )
  597. else:
  598. res = torch._cslt_sparse_mm(
  599. self.packed,
  600. B,
  601. bias=bias,
  602. transpose_result=self.fuse_transpose_cusparselt,
  603. alg_id=self.alg_id_cusparselt,
  604. )
  605. return res.t() if self.fuse_transpose_cusparselt else res