__init__.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730
  1. # mypy: allow-untyped-defs
  2. # The Tensor classes are added to this module by python_tensor.cpp
  3. # A workaround to support both TorchScript and MyPy:
  4. from typing import Any, Optional, TYPE_CHECKING, Union
  5. import torch
  6. from torch import Tensor
  7. from torch._C import _add_docstr, _sparse # type: ignore[attr-defined]
  8. # Semi structured sparsity support
  9. from .semi_structured import (
  10. SparseSemiStructuredTensor,
  11. SparseSemiStructuredTensorCUSPARSELT,
  12. SparseSemiStructuredTensorCUTLASS,
  13. to_sparse_semi_structured,
  14. )
  15. if TYPE_CHECKING:
  16. from torch.types import _dtype as DType
  17. DimOrDims = Optional[int | tuple[int, ...] | list[int]]
  18. else:
  19. # The JIT doesn't understand Union, nor torch.dtype here
  20. DType = int
  21. DimOrDims = Optional[tuple[int]]
  22. __all__ = [
  23. "addmm",
  24. "check_sparse_tensor_invariants",
  25. "mm",
  26. "sum",
  27. "softmax",
  28. # pyrefly: ignore [bad-dunder-all]
  29. "solve",
  30. "log_softmax",
  31. "SparseSemiStructuredTensor",
  32. "SparseSemiStructuredTensorCUTLASS",
  33. "SparseSemiStructuredTensorCUSPARSELT",
  34. "to_sparse_semi_structured",
  35. "as_sparse_gradcheck",
  36. ]
  37. addmm = _add_docstr(
  38. _sparse._sparse_addmm,
  39. r"""
  40. sparse.addmm(mat, mat1, mat2, *, beta=1., alpha=1.) -> Tensor
  41. This function does exact same thing as :func:`torch.addmm` in the forward,
  42. except that it supports backward for sparse COO and CSR matrix :attr:`mat1`.
  43. When :attr:`mat1` is a COO tensor it must have `sparse_dim = 2`.
  44. Supports both CSR and COO storage formats.
  45. .. note::
  46. **Gradient support:**
  47. - **COO @ Dense**: Backward is supported for both inputs. The gradient for the
  48. sparse input is returned as a sparse COO tensor.
  49. - **CSR @ Dense**: Backward is supported for both inputs. The gradient for the
  50. sparse input is returned as a sparse CSR tensor.
  51. - **CSC/BSR/BSC @ Dense**: Not supported.
  52. - **Sparse @ Sparse** (COO @ COO, CSR @ CSR): Forward works, but backward is
  53. not supported.
  54. Args:
  55. mat (Tensor): a dense matrix to be added
  56. mat1 (Tensor): a sparse matrix to be multiplied
  57. mat2 (Tensor): a dense matrix to be multiplied
  58. beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`)
  59. alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`)
  60. """,
  61. )
  62. mm = _add_docstr(
  63. _sparse._sparse_mm,
  64. r"""
  65. Performs a matrix multiplication of the sparse matrix :attr:`mat1`
  66. and the (sparse or strided) matrix :attr:`mat2`. Similar to :func:`torch.mm`, if :attr:`mat1` is a
  67. :math:`(n \times m)` tensor, :attr:`mat2` is a :math:`(m \times p)` tensor, out will be a
  68. :math:`(n \times p)` tensor.
  69. When :attr:`mat1` is a COO tensor it must have `sparse_dim = 2`.
  70. Supports both CSR and COO storage formats.
  71. .. note::
  72. **Gradient support:**
  73. - **COO @ Dense**: Backward is supported for both inputs. The gradient for the
  74. sparse input is returned as a sparse COO tensor.
  75. - **CSR @ Dense**: Backward is supported for both inputs. The gradient for the
  76. sparse input is returned as a sparse CSR tensor.
  77. - **CSC/BSR/BSC @ Dense**: Not supported.
  78. - **Sparse @ Sparse** (COO @ COO, CSR @ CSR): Forward works, but backward is
  79. not supported.
  80. - **Mixed formats** (COO @ CSR, CSR @ COO): Not supported.
  81. This function also additionally accepts an optional :attr:`reduce` argument that allows
  82. specification of an optional reduction operation, mathematically performs the following operation:
  83. .. math::
  84. z_{ij} = \bigoplus_{k = 0}^{K - 1} x_{ik} y_{kj}
  85. where :math:`\bigoplus` defines the reduce operator. :attr:`reduce` is implemented only for
  86. CSR storage format on CPU device.
  87. Args:
  88. mat1 (Tensor): the first sparse matrix to be multiplied
  89. mat2 (Tensor): the second matrix to be multiplied, which could be sparse or dense
  90. reduce (str, optional): the reduction operation to apply for non-unique indices
  91. (:obj:`"sum"`, :obj:`"mean"`, :obj:`"amax"`, :obj:`"amin"`). Default :obj:`"sum"`.
  92. Shape:
  93. The format of the output tensor of this function follows:
  94. - sparse x sparse -> sparse
  95. - sparse x dense -> dense
  96. Example::
  97. >>> a = torch.tensor([[1., 0, 2], [0, 3, 0]]).to_sparse().requires_grad_()
  98. >>> a
  99. tensor(indices=tensor([[0, 0, 1],
  100. [0, 2, 1]]),
  101. values=tensor([1., 2., 3.]),
  102. size=(2, 3), nnz=3, layout=torch.sparse_coo, requires_grad=True)
  103. >>> b = torch.tensor([[0, 1.], [2, 0], [0, 0]], requires_grad=True)
  104. >>> b
  105. tensor([[0., 1.],
  106. [2., 0.],
  107. [0., 0.]], requires_grad=True)
  108. >>> y = torch.sparse.mm(a, b)
  109. >>> y
  110. tensor([[0., 1.],
  111. [6., 0.]], grad_fn=<SparseAddmmBackward0>)
  112. >>> y.sum().backward()
  113. >>> a.grad
  114. tensor(indices=tensor([[0, 0, 1],
  115. [0, 2, 1]]),
  116. values=tensor([1., 0., 2.]),
  117. size=(2, 3), nnz=3, layout=torch.sparse_coo)
  118. >>> c = a.detach().to_sparse_csr()
  119. >>> c
  120. tensor(crow_indices=tensor([0, 2, 3]),
  121. col_indices=tensor([0, 2, 1]),
  122. values=tensor([1., 2., 3.]), size=(2, 3), nnz=3,
  123. layout=torch.sparse_csr)
  124. >>> y1 = torch.sparse.mm(c, b, 'sum')
  125. >>> y1
  126. tensor([[0., 1.],
  127. [6., 0.]], grad_fn=<SparseMmReduceImplBackward0>)
  128. >>> y2 = torch.sparse.mm(c, b, 'max')
  129. >>> y2
  130. tensor([[0., 1.],
  131. [6., 0.]], grad_fn=<SparseMmReduceImplBackward0>)
  132. """,
  133. )
  134. sampled_addmm = _add_docstr(
  135. _sparse.sparse_sampled_addmm,
  136. r"""
  137. sparse.sampled_addmm(input, mat1, mat2, *, beta=1., alpha=1., out=None) -> Tensor
  138. Performs a matrix multiplication of the dense matrices :attr:`mat1` and :attr:`mat2` at the locations
  139. specified by the sparsity pattern of :attr:`input`. The matrix :attr:`input` is added to the final result.
  140. Mathematically this performs the following operation:
  141. .. math::
  142. \text{out} = \alpha\ (\text{mat1} \mathbin{@} \text{mat2})*\text{spy}(\text{input}) + \beta\ \text{input}
  143. where :math:`\text{spy}(\text{input})` is the sparsity pattern matrix of :attr:`input`, :attr:`alpha`
  144. and :attr:`beta` are the scaling factors.
  145. :math:`\text{spy}(\text{input})` has value 1 at the positions where :attr:`input` has non-zero values, and 0 elsewhere.
  146. .. note::
  147. :attr:`input` must be a sparse CSR tensor. :attr:`mat1` and :attr:`mat2` must be dense tensors.
  148. Args:
  149. input (Tensor): a sparse CSR matrix of shape `(m, n)` to be added and used to compute
  150. the sampled matrix multiplication
  151. mat1 (Tensor): a dense matrix of shape `(m, k)` to be multiplied
  152. mat2 (Tensor): a dense matrix of shape `(k, n)` to be multiplied
  153. Keyword args:
  154. beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`)
  155. alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`)
  156. out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`.
  157. Examples::
  158. >>> input = torch.eye(3, device='cuda').to_sparse_csr()
  159. >>> mat1 = torch.randn(3, 5, device='cuda')
  160. >>> mat2 = torch.randn(5, 3, device='cuda')
  161. >>> torch.sparse.sampled_addmm(input, mat1, mat2)
  162. tensor(crow_indices=tensor([0, 1, 2, 3]),
  163. col_indices=tensor([0, 1, 2]),
  164. values=tensor([ 0.2847, -0.7805, -0.1900]), device='cuda:0',
  165. size=(3, 3), nnz=3, layout=torch.sparse_csr)
  166. >>> torch.sparse.sampled_addmm(input, mat1, mat2).to_dense()
  167. tensor([[ 0.2847, 0.0000, 0.0000],
  168. [ 0.0000, -0.7805, 0.0000],
  169. [ 0.0000, 0.0000, -0.1900]], device='cuda:0')
  170. >>> torch.sparse.sampled_addmm(input, mat1, mat2, beta=0.5, alpha=0.5)
  171. tensor(crow_indices=tensor([0, 1, 2, 3]),
  172. col_indices=tensor([0, 1, 2]),
  173. values=tensor([ 0.1423, -0.3903, -0.0950]), device='cuda:0',
  174. size=(3, 3), nnz=3, layout=torch.sparse_csr)
  175. """,
  176. )
  177. def sum(input: Tensor, dim: DimOrDims = None, dtype: DType | None = None) -> Tensor:
  178. r"""Return the sum of each row of the given sparse tensor.
  179. Returns the sum of each row of the sparse tensor :attr:`input` in the given
  180. dimensions :attr:`dim`. If :attr:`dim` is a list of dimensions,
  181. reduce over all of them. When sum over all ``sparse_dim``, this method
  182. returns a dense tensor instead of a sparse tensor.
  183. All summed :attr:`dim` are squeezed (see :func:`torch.squeeze`), resulting an output
  184. tensor having :attr:`dim` fewer dimensions than :attr:`input`.
  185. During backward, only gradients at ``nnz`` locations of :attr:`input`
  186. will propagate back. Note that the gradients of :attr:`input` is coalesced.
  187. Args:
  188. input (Tensor): the input sparse tensor
  189. dim (int or tuple of ints): a dimension or a list of dimensions to reduce. Default: reduce
  190. over all dims.
  191. dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor.
  192. Default: dtype of :attr:`input`.
  193. Example::
  194. >>> nnz = 3
  195. >>> dims = [5, 5, 2, 3]
  196. >>> I = torch.cat([torch.randint(0, dims[0], size=(nnz,)),
  197. torch.randint(0, dims[1], size=(nnz,))], 0).reshape(2, nnz)
  198. >>> V = torch.randn(nnz, dims[2], dims[3])
  199. >>> size = torch.Size(dims)
  200. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  201. >>> S = torch.sparse_coo_tensor(I, V, size)
  202. >>> S
  203. tensor(indices=tensor([[2, 0, 3],
  204. [2, 4, 1]]),
  205. values=tensor([[[-0.6438, -1.6467, 1.4004],
  206. [ 0.3411, 0.0918, -0.2312]],
  207. [[ 0.5348, 0.0634, -2.0494],
  208. [-0.7125, -1.0646, 2.1844]],
  209. [[ 0.1276, 0.1874, -0.6334],
  210. [-1.9682, -0.5340, 0.7483]]]),
  211. size=(5, 5, 2, 3), nnz=3, layout=torch.sparse_coo)
  212. # when sum over only part of sparse_dims, return a sparse tensor
  213. >>> torch.sparse.sum(S, [1, 3])
  214. tensor(indices=tensor([[0, 2, 3]]),
  215. values=tensor([[-1.4512, 0.4073],
  216. [-0.8901, 0.2017],
  217. [-0.3183, -1.7539]]),
  218. size=(5, 2), nnz=3, layout=torch.sparse_coo)
  219. # when sum over all sparse dim, return a dense tensor
  220. # with summed dims squeezed
  221. >>> torch.sparse.sum(S, [0, 1, 3])
  222. tensor([-2.6596, -1.1450])
  223. """
  224. if dtype is None:
  225. if dim is not None:
  226. return torch._sparse_sum(input, dim)
  227. else:
  228. return torch._sparse_sum(input)
  229. else:
  230. if dim is not None:
  231. return torch._sparse_sum(input, dim, dtype=dtype)
  232. else:
  233. return torch._sparse_sum(input, dtype=dtype)
  234. softmax = _add_docstr(
  235. _sparse._sparse_softmax,
  236. r"""
  237. sparse.softmax(input, dim, *, dtype=None) -> Tensor
  238. Applies a softmax function.
  239. Softmax is defined as:
  240. :math:`\text{Softmax}(x_{i}) = \frac{exp(x_i)}{\sum_j exp(x_j)}`
  241. where :math:`i, j` run over sparse tensor indices and unspecified
  242. entries are ignores. This is equivalent to defining unspecified
  243. entries as negative infinity so that :math:`exp(x_k) = 0` when the
  244. entry with index :math:`k` has not specified.
  245. It is applied to all slices along `dim`, and will re-scale them so
  246. that the elements lie in the range `[0, 1]` and sum to 1.
  247. Args:
  248. input (Tensor): input
  249. dim (int): A dimension along which softmax will be computed.
  250. dtype (:class:`torch.dtype`, optional): the desired data type
  251. of returned tensor. If specified, the input tensor is
  252. casted to :attr:`dtype` before the operation is
  253. performed. This is useful for preventing data type
  254. overflows. Default: None
  255. """,
  256. )
  257. spsolve = _add_docstr(
  258. _sparse._spsolve,
  259. r"""
  260. sparse.spsolve(input, other, *, left=True) -> Tensor
  261. Computes the solution of a square system of linear equations with
  262. a unique solution. Its purpose is similar to :func:`torch.linalg.solve`,
  263. except that the system is defined by a sparse CSR matrix with layout
  264. `sparse_csr`.
  265. Args:
  266. input (Tensor): a sparse CSR matrix of shape `(n, n)` representing the
  267. coefficients of the linear system.
  268. other (Tensor): a dense matrix of shape `(n, )` representing the right-hand
  269. side of the linear system.
  270. left (bool, optional): whether to solve the system for `input @ out = other`
  271. (default) or `out @ input = other`. Only `left=True` is supported.
  272. """,
  273. )
  274. log_softmax = _add_docstr(
  275. _sparse._sparse_log_softmax,
  276. r"""
  277. sparse.log_softmax(input, dim, *, dtype=None) -> Tensor
  278. Applies a softmax function followed by logarithm.
  279. See :class:`~torch.sparse.softmax` for more details.
  280. Args:
  281. input (Tensor): input
  282. dim (int): A dimension along which softmax will be computed.
  283. dtype (:class:`torch.dtype`, optional): the desired data type
  284. of returned tensor. If specified, the input tensor is
  285. casted to :attr:`dtype` before the operation is
  286. performed. This is useful for preventing data type
  287. overflows. Default: None
  288. """,
  289. )
  290. spdiags = _add_docstr(
  291. _sparse._spdiags,
  292. r"""
  293. sparse.spdiags(diagonals, offsets, shape, layout=None) -> Tensor
  294. Creates a sparse 2D tensor by placing the values from rows of
  295. :attr:`diagonals` along specified diagonals of the output
  296. The :attr:`offsets` tensor controls which diagonals are set.
  297. - If :attr:`offsets[i]` = 0, it is the main diagonal
  298. - If :attr:`offsets[i]` < 0, it is below the main diagonal
  299. - If :attr:`offsets[i]` > 0, it is above the main diagonal
  300. The number of rows in :attr:`diagonals` must match the length of :attr:`offsets`,
  301. and an offset may not be repeated.
  302. Args:
  303. diagonals (Tensor): Matrix storing diagonals row-wise
  304. offsets (Tensor): The diagonals to be set, stored as a vector
  305. shape (2-tuple of ints): The desired shape of the result
  306. Keyword args:
  307. layout (:class:`torch.layout`, optional): The desired layout of the
  308. returned tensor. ``torch.sparse_coo``, ``torch.sparse_csc`` and ``torch.sparse_csr``
  309. are supported. Default: ``torch.sparse_coo``
  310. Examples:
  311. Set the main and first two lower diagonals of a matrix::
  312. >>> diags = torch.arange(9).reshape(3, 3)
  313. >>> diags
  314. tensor([[0, 1, 2],
  315. [3, 4, 5],
  316. [6, 7, 8]])
  317. >>> s = torch.sparse.spdiags(diags, torch.tensor([0, -1, -2]), (3, 3))
  318. >>> s
  319. tensor(indices=tensor([[0, 1, 2, 1, 2, 2],
  320. [0, 1, 2, 0, 1, 0]]),
  321. values=tensor([0, 1, 2, 3, 4, 6]),
  322. size=(3, 3), nnz=6, layout=torch.sparse_coo)
  323. >>> s.to_dense()
  324. tensor([[0, 0, 0],
  325. [3, 1, 0],
  326. [6, 4, 2]])
  327. Change the output layout::
  328. >>> diags = torch.arange(9).reshape(3, 3)
  329. >>> diags
  330. tensor([[0, 1, 2],[3, 4, 5], [6, 7, 8])
  331. >>> s = torch.sparse.spdiags(diags, torch.tensor([0, -1, -2]), (3, 3), layout=torch.sparse_csr)
  332. >>> s
  333. tensor(crow_indices=tensor([0, 1, 3, 6]),
  334. col_indices=tensor([0, 0, 1, 0, 1, 2]),
  335. values=tensor([0, 3, 1, 6, 4, 2]), size=(3, 3), nnz=6,
  336. layout=torch.sparse_csr)
  337. >>> s.to_dense()
  338. tensor([[0, 0, 0],
  339. [3, 1, 0],
  340. [6, 4, 2]])
  341. Set partial diagonals of a large output::
  342. >>> diags = torch.tensor([[1, 2], [3, 4]])
  343. >>> offsets = torch.tensor([0, -1])
  344. >>> torch.sparse.spdiags(diags, offsets, (5, 5)).to_dense()
  345. tensor([[1, 0, 0, 0, 0],
  346. [3, 2, 0, 0, 0],
  347. [0, 4, 0, 0, 0],
  348. [0, 0, 0, 0, 0],
  349. [0, 0, 0, 0, 0]])
  350. .. note::
  351. When setting the values along a given diagonal the index into the diagonal
  352. and the index into the row of :attr:`diagonals` is taken as the
  353. column index in the output. This has the effect that when setting a diagonal
  354. with a positive offset `k` the first value along that diagonal will be
  355. the value in position `k` of the row of :attr:`diagonals`
  356. Specifying a positive offset::
  357. >>> diags = torch.tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]])
  358. >>> torch.sparse.spdiags(diags, torch.tensor([0, 1, 2]), (5, 5)).to_dense()
  359. tensor([[1, 2, 3, 0, 0],
  360. [0, 2, 3, 0, 0],
  361. [0, 0, 3, 0, 0],
  362. [0, 0, 0, 0, 0],
  363. [0, 0, 0, 0, 0]])
  364. """,
  365. )
  366. class check_sparse_tensor_invariants:
  367. """A tool to control checking sparse tensor invariants.
  368. The following options exists to manage sparsr tensor invariants
  369. checking in sparse tensor construction:
  370. 1. Using a context manager:
  371. .. code:: python
  372. with torch.sparse.check_sparse_tensor_invariants():
  373. run_my_model()
  374. 2. Using a procedural approach:
  375. .. code:: python
  376. prev_checks_enabled = torch.sparse.check_sparse_tensor_invariants.is_enabled()
  377. torch.sparse.check_sparse_tensor_invariants.enable()
  378. run_my_model()
  379. if not prev_checks_enabled:
  380. torch.sparse.check_sparse_tensor_invariants.disable()
  381. 3. Using function decoration:
  382. .. code:: python
  383. @torch.sparse.check_sparse_tensor_invariants()
  384. def run_my_model():
  385. ...
  386. run_my_model()
  387. 4. Using ``check_invariants`` keyword argument in sparse tensor constructor call.
  388. For example:
  389. >>> torch.sparse_csr_tensor([0, 1, 3], [0, 1], [1, 2], check_invariants=True)
  390. Traceback (most recent call last):
  391. File "<stdin>", line 1, in <module>
  392. RuntimeError: `crow_indices[..., -1] == nnz` is not satisfied.
  393. """
  394. @staticmethod
  395. def is_enabled():
  396. r"""Return True if the sparse tensor invariants checking is enabled.
  397. .. note::
  398. Use :func:`torch.sparse.check_sparse_tensor_invariants.enable` or
  399. :func:`torch.sparse.check_sparse_tensor_invariants.disable` to
  400. manage the state of the sparse tensor invariants checks.
  401. """
  402. return torch._C._check_sparse_tensor_invariants()
  403. @staticmethod
  404. def enable():
  405. r"""Enable sparse tensor invariants checking in sparse tensor constructors.
  406. .. note::
  407. By default, the sparse tensor invariants checks are disabled. Use
  408. :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled` to
  409. retrieve the current state of sparse tensor invariants checking.
  410. .. note::
  411. The sparse tensor invariants check flag is effective to all sparse
  412. tensor constructors, both in Python and ATen.
  413. The flag can be locally overridden by the ``check_invariants``
  414. optional argument of the sparse tensor constructor functions.
  415. """
  416. torch._C._set_check_sparse_tensor_invariants(True)
  417. @staticmethod
  418. def disable():
  419. r"""Disable sparse tensor invariants checking in sparse tensor constructors.
  420. See :func:`torch.sparse.check_sparse_tensor_invariants.enable` for more information.
  421. """
  422. torch._C._set_check_sparse_tensor_invariants(False)
  423. # context manager support
  424. def __init__(self, enable=True):
  425. self.state = enable
  426. self.saved_state: bool | None = None
  427. def __enter__(self):
  428. if self.saved_state is not None:
  429. raise RuntimeError(
  430. "This context manager instance is already activated."
  431. " Use a different context manager instance for context nesting."
  432. )
  433. self.saved_state = self.is_enabled()
  434. torch._C._set_check_sparse_tensor_invariants(self.state)
  435. def __exit__(self, type, value, traceback):
  436. if self.saved_state is None:
  437. raise AssertionError("saved_state should not be None on exit")
  438. torch._C._set_check_sparse_tensor_invariants(self.saved_state)
  439. self.saved_state = None
  440. # decorator support
  441. def __call__(self, mth):
  442. def test_mth(*args, **kwargs):
  443. with type(self)(self.state):
  444. return mth(*args, **kwargs)
  445. return test_mth
  446. def as_sparse_gradcheck(gradcheck):
  447. """Decorate function, to extend gradcheck for sparse tensors.
  448. Decorator for torch.autograd.gradcheck or its functools.partial
  449. variants that extends the gradcheck function with support to input
  450. functions that operate on or/and return sparse tensors.
  451. The specified gradcheck function itself is guaranteed to operate
  452. on strided tensors only.
  453. For example:
  454. >>> gradcheck = torch.sparse.as_sparse_gradcheck(torch.autograd.gradcheck)
  455. >>> x = (
  456. ... torch.tensor([[0, 1], [2, 3]], dtype=torch.float64)
  457. ... .to_sparse_coo()
  458. ... .requires_grad_(True)
  459. ... )
  460. >>> gradcheck(lambda x: x.to_sparse_csr(), x)
  461. True
  462. """
  463. def gradcheck_with_sparse_support(func, inputs, **kwargs):
  464. """
  465. Create gradcheck with support for sparse tensors.
  466. Same as :func:`torch.autograd.gradcheck` but with sparse tensors inputs and outputs support.
  467. """
  468. masked = kwargs.pop("masked", False)
  469. sparse_layouts = {
  470. torch.sparse_coo,
  471. torch.sparse_csr,
  472. torch.sparse_csc,
  473. torch.sparse_bsr,
  474. torch.sparse_bsc,
  475. }
  476. sparse_compressed_layouts = {
  477. torch.sparse_csr,
  478. torch.sparse_csc,
  479. torch.sparse_bsr,
  480. torch.sparse_bsc,
  481. }
  482. sparse_block_layouts = {torch.sparse_bsr, torch.sparse_bsc}
  483. STRIDED_REPRESENTATION = "__STRIDED_REPRESENTATION__"
  484. def convert_to_strided_representation(args):
  485. """Convert differentiable non-strided tensors to a representation containing differentiable strided tensors."""
  486. if not isinstance(args, (list, tuple)):
  487. args = (args,)
  488. new_args: list[Any] = []
  489. for obj in args:
  490. if (
  491. isinstance(obj, torch.Tensor)
  492. and obj.requires_grad
  493. and obj.layout in sparse_layouts
  494. ):
  495. d = {
  496. "layout": obj.layout,
  497. "shape": obj.shape,
  498. }
  499. if not masked:
  500. # Materialize unspecified elements with zero values
  501. batch_dim = obj.ndim - obj.dense_dim() - obj.sparse_dim()
  502. blocksize = (
  503. obj.values().shape[batch_dim + 1 : batch_dim + 3]
  504. if obj.layout in sparse_block_layouts
  505. else None
  506. )
  507. full_mask = torch.ones(
  508. obj.shape, device=obj.device, dtype=torch.bool
  509. ).to_sparse(
  510. layout=obj.layout,
  511. blocksize=blocksize,
  512. dense_dim=obj.dense_dim(),
  513. )
  514. obj = obj.to_dense().sparse_mask(full_mask)
  515. if obj.layout is torch.sparse_coo:
  516. # pyrefly: ignore [no-matching-overload]
  517. d.update(
  518. indices=obj._indices(), is_coalesced=obj.is_coalesced()
  519. )
  520. values = obj._values()
  521. elif obj.layout in {torch.sparse_csr, torch.sparse_bsr}:
  522. # pyrefly: ignore [no-matching-overload]
  523. d.update(
  524. compressed_indices=obj.crow_indices(),
  525. plain_indices=obj.col_indices(),
  526. )
  527. values = obj.values()
  528. else:
  529. # pyrefly: ignore [no-matching-overload]
  530. d.update(
  531. compressed_indices=obj.ccol_indices(),
  532. plain_indices=obj.row_indices(),
  533. )
  534. values = obj.values()
  535. new_args.extend(
  536. (STRIDED_REPRESENTATION, d, values.requires_grad_(True))
  537. )
  538. else:
  539. new_args.append(obj)
  540. return tuple(new_args)
  541. def restore_from_strided_representation(args):
  542. """Restore non-strided differentiable tensosr from their strided representations."""
  543. new_args = []
  544. args = list(args)
  545. while args:
  546. a = args.pop(0)
  547. if a == STRIDED_REPRESENTATION:
  548. d, values = args.pop(0), args.pop(0)
  549. if d["layout"] is torch.sparse_coo:
  550. a = torch.sparse_coo_tensor(
  551. d["indices"],
  552. values,
  553. size=d["shape"],
  554. is_coalesced=d["is_coalesced"],
  555. )
  556. elif d["layout"] in sparse_compressed_layouts:
  557. a = torch.sparse_compressed_tensor(
  558. d["compressed_indices"],
  559. d["plain_indices"],
  560. values,
  561. size=d["shape"],
  562. layout=d["layout"],
  563. )
  564. else:
  565. raise NotImplementedError(
  566. f"conversion of {d['layout']} strided representation to tensor"
  567. )
  568. new_args.append(a)
  569. return tuple(new_args)
  570. def func_wrapper(*args, **kwargs):
  571. restored_args = restore_from_strided_representation(args)
  572. # convert differentiable output sparse tensors to strided
  573. # tensors:
  574. outputs = func(*restored_args, **kwargs)
  575. strided_outputs = (
  576. tuple(outputs) if isinstance(outputs, (list, tuple)) else (outputs,)
  577. )
  578. strided_outputs = tuple(
  579. (
  580. o.to_dense(masked_grad=masked)
  581. if isinstance(o, torch.Tensor)
  582. and o.requires_grad
  583. and o.layout in sparse_layouts
  584. else o
  585. )
  586. for o in strided_outputs
  587. )
  588. return (
  589. strided_outputs
  590. if isinstance(outputs, (list, tuple))
  591. else strided_outputs[0]
  592. )
  593. args = (func_wrapper, convert_to_strided_representation(inputs))
  594. return gradcheck(*args, **kwargs)
  595. return gradcheck_with_sparse_support