__init__.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524
  1. # mypy: allow-untyped-defs
  2. from typing import Optional, Union
  3. import torch
  4. import torch.nn.functional as F
  5. from torch import SymInt, Tensor
  6. from torch._C import _add_docstr, _nested # type: ignore[attr-defined]
  7. from torch.types import _device as Device, _dtype as DType
# Public API of this module: the names exported by a star-import.
__all__ = [
    "to_padded_tensor",
    "as_nested_tensor",
    "nested_tensor",
    "nested_tensor_from_jagged",
    "narrow",
    "masked_select",
]

# Allowlist these for weights_only load of NJT
# (registers the jagged NestedTensor class and its rebuild helper as safe globals with
# torch.serialization, so that a weights_only load is allowed to reconstruct NJTs).
from ._internal.nested_tensor import _rebuild_njt, NestedTensor as _NestedTensor

torch.serialization.add_safe_globals([_NestedTensor, _rebuild_njt])
  19. def as_nested_tensor(
  20. ts: Tensor | list[Tensor] | tuple[Tensor, ...],
  21. dtype: DType | None = None,
  22. device: Device | None = None,
  23. layout=None,
  24. ) -> Tensor:
  25. r"""
  26. Constructs a nested tensor preserving autograd history from a tensor or a list / tuple of
  27. tensors.
  28. If a nested tensor is passed, it will be returned directly unless the device / dtype / layout
  29. differ. Note that converting device / dtype will result in a copy, while converting layout
  30. is not currently supported by this function.
  31. If a non-nested tensor is passed, it is treated as a batch of constituents of consistent size.
  32. A copy will be incurred if the passed device / dtype differ from those of the input OR if
  33. the input is non-contiguous. Otherwise, the input's storage will be used directly.
  34. If a tensor list is provided, tensors in the list are always copied during construction of
  35. the nested tensor.
  36. Args:
  37. ts (Tensor or List[Tensor] or Tuple[Tensor]): a tensor to treat as a nested tensor OR a
  38. list / tuple of tensors with the same ndim
  39. Keyword arguments:
  40. dtype (:class:`torch.dtype`, optional): the desired type of returned nested tensor.
  41. Default: if None, same :class:`torch.dtype` as leftmost tensor in the list.
  42. device (:class:`torch.device`, optional): the desired device of returned nested tensor.
  43. Default: if None, same :class:`torch.device` as leftmost tensor in the list
  44. layout (:class:`torch.layout`, optional): the desired layout of returned nested tensor.
  45. Only strided and jagged layouts are supported. Default: if None, the strided layout.
  46. Example::
  47. >>> a = torch.arange(3, dtype=torch.float, requires_grad=True)
  48. >>> b = torch.arange(5, dtype=torch.float, requires_grad=True)
  49. >>> nt = torch.nested.as_nested_tensor([a, b])
  50. >>> nt.is_leaf
  51. False
  52. >>> fake_grad = torch.nested.nested_tensor([torch.ones_like(a), torch.zeros_like(b)])
  53. >>> nt.backward(fake_grad)
  54. >>> a.grad
  55. tensor([1., 1., 1.])
  56. >>> b.grad
  57. tensor([0., 0., 0., 0., 0.])
  58. >>> c = torch.randn(3, 5, requires_grad=True)
  59. >>> nt2 = torch.nested.as_nested_tensor(c)
  60. """
  61. is_tensor_list = isinstance(ts, (list, tuple)) and all(
  62. isinstance(t, Tensor) for t in ts
  63. )
  64. if not isinstance(ts, Tensor) and not is_tensor_list:
  65. raise TypeError(
  66. "as_nested_tensor(): Expected first argument to be a tensor or a list / tuple of tensors "
  67. )
  68. # convert tuple -> list if needed
  69. if is_tensor_list and not isinstance(ts, list):
  70. ts = list(ts)
  71. if isinstance(ts, Tensor) and ts.dim() < 2:
  72. raise RuntimeError(
  73. "as_nested_tensor(): Expected tensor argument to have dim() > 1"
  74. )
  75. if isinstance(ts, Tensor) and ts.is_nested:
  76. if layout == ts.layout:
  77. # return input directly or input copied to device / dtype
  78. return ts.to(device=device, dtype=dtype)
  79. else:
  80. # TODO: Just use nt.to(layout=layout) when it exists.
  81. raise RuntimeError(
  82. "as_nested_tensor(): Converting between nested tensor layouts is not supported"
  83. )
  84. if layout is None:
  85. layout = torch.strided
  86. if layout == torch.strided:
  87. if isinstance(ts, Tensor):
  88. # contiguous() might be necessary to get flattened view.
  89. # we could probably be more precise about when to do this as an optimization
  90. buffer = ts.contiguous().view(-1).to(device=device, dtype=dtype)
  91. nested_sizes = torch.tensor([t.shape for t in ts])
  92. return torch._nested_view_from_buffer(
  93. buffer,
  94. nested_sizes,
  95. *torch._nested_compute_contiguous_strides_offsets(nested_sizes),
  96. )
  97. else:
  98. if not isinstance(ts, list):
  99. raise AssertionError(
  100. f"Expected ts to be a list, but got {type(ts).__name__}"
  101. )
  102. return torch._nested_tensor_from_tensor_list(ts, dtype, None, device, None)
  103. elif layout == torch.jagged:
  104. if isinstance(ts, Tensor):
  105. if device is None:
  106. device = ts.device
  107. # contiguous() might be necessary to get flattened view.
  108. # we could probably be more precise about when to do this as an optimization
  109. values = ts.contiguous().flatten(0, 1).to(device=device, dtype=dtype)
  110. batch_size = ts.shape[0]
  111. seq_len = ts.shape[1]
  112. offsets = torch.arange(
  113. 0, batch_size * seq_len + 1, seq_len, device=device, dtype=torch.int64
  114. )
  115. from torch.nested._internal.nested_tensor import (
  116. nested_view_from_values_offsets,
  117. )
  118. return nested_view_from_values_offsets(
  119. values, offsets, min_seqlen=seq_len, max_seqlen=seq_len
  120. )
  121. else:
  122. from torch.nested._internal.nested_tensor import jagged_from_list
  123. if not isinstance(ts, list):
  124. raise AssertionError(
  125. f"Expected ts to be a list, but got {type(ts).__name__}"
  126. )
  127. nt, _ = jagged_from_list(ts, offsets=None, device=device, dtype=dtype)
  128. return nt
  129. else:
  130. raise RuntimeError(
  131. f"Specified layout is unsupported for nested tensors: {layout}"
  132. )
  133. # Note: This not only adds doc strings for the nested ops, but
  134. # also connects the torch.nested Python namespace to the torch._C._nested builtins.
  135. to_padded_tensor = _add_docstr(
  136. _nested.nested_to_padded_tensor,
  137. r"""
  138. to_padded_tensor(input, padding, output_size=None, out=None) -> Tensor
  139. Returns a new (non-nested) Tensor by padding the :attr:`input` nested tensor.
  140. The leading entries will be filled with the nested data,
  141. while the trailing entries will be padded.
  142. .. warning::
  143. :func:`to_padded_tensor` always copies the underlying data,
  144. since the nested and the non-nested tensors differ in memory layout.
  145. Args:
  146. padding (float): The padding value for the trailing entries.
  147. Keyword args:
  148. output_size (Tuple[int]): The size of the output tensor.
  149. If given, it must be large enough to contain all nested data;
  150. else, will infer by taking the max size of each nested sub-tensor along each dimension.
  151. out (Tensor, optional): the output tensor.
  152. Example::
  153. >>> nt = torch.nested.nested_tensor([torch.randn((2, 5)), torch.randn((3, 4))])
  154. nested_tensor([
  155. tensor([[ 1.6862, -1.1282, 1.1031, 0.0464, -1.3276],
  156. [-1.9967, -1.0054, 1.8972, 0.9174, -1.4995]]),
  157. tensor([[-1.8546, -0.7194, -0.2918, -0.1846],
  158. [ 0.2773, 0.8793, -0.5183, -0.6447],
  159. [ 1.8009, 1.8468, -0.9832, -1.5272]])
  160. ])
  161. >>> pt_infer = torch.nested.to_padded_tensor(nt, 0.0)
  162. tensor([[[ 1.6862, -1.1282, 1.1031, 0.0464, -1.3276],
  163. [-1.9967, -1.0054, 1.8972, 0.9174, -1.4995],
  164. [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
  165. [[-1.8546, -0.7194, -0.2918, -0.1846, 0.0000],
  166. [ 0.2773, 0.8793, -0.5183, -0.6447, 0.0000],
  167. [ 1.8009, 1.8468, -0.9832, -1.5272, 0.0000]]])
  168. >>> pt_large = torch.nested.to_padded_tensor(nt, 1.0, (2, 4, 6))
  169. tensor([[[ 1.6862, -1.1282, 1.1031, 0.0464, -1.3276, 1.0000],
  170. [-1.9967, -1.0054, 1.8972, 0.9174, -1.4995, 1.0000],
  171. [ 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
  172. [ 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]],
  173. [[-1.8546, -0.7194, -0.2918, -0.1846, 1.0000, 1.0000],
  174. [ 0.2773, 0.8793, -0.5183, -0.6447, 1.0000, 1.0000],
  175. [ 1.8009, 1.8468, -0.9832, -1.5272, 1.0000, 1.0000],
  176. [ 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]]])
  177. >>> pt_small = torch.nested.to_padded_tensor(nt, 2.0, (2, 2, 2))
  178. RuntimeError: Value in output_size is less than NestedTensor padded size. Truncation is not supported.
  179. """,
  180. )
  181. def nested_tensor(
  182. tensor_list,
  183. *,
  184. dtype=None,
  185. layout=None,
  186. device=None,
  187. requires_grad=False,
  188. pin_memory=False,
  189. ) -> Tensor:
  190. r"""
  191. Constructs a nested tensor with no autograd history (also known as a "leaf tensor", see
  192. :ref:`Autograd mechanics <autograd-mechanics>`) from :attr:`tensor_list` a list of tensors.
  193. Args:
  194. tensor_list (List[array_like]): a list of tensors, or anything that can be passed to torch.tensor,
  195. where each element of the list has the same dimensionality.
  196. Keyword arguments:
  197. dtype (:class:`torch.dtype`, optional): the desired type of returned nested tensor.
  198. Default: if None, same :class:`torch.dtype` as leftmost tensor in the list.
  199. layout (:class:`torch.layout`, optional): the desired layout of returned nested tensor.
  200. Only strided and jagged layouts are supported. Default: if None, the strided layout.
  201. device (:class:`torch.device`, optional): the desired device of returned nested tensor.
  202. Default: if None, same :class:`torch.device` as leftmost tensor in the list
  203. requires_grad (bool, optional): If autograd should record operations on the
  204. returned nested tensor. Default: ``False``.
  205. pin_memory (bool, optional): If set, returned nested tensor would be allocated in
  206. the pinned memory. Works only for CPU tensors. Default: ``False``.
  207. Example::
  208. >>> a = torch.arange(3, dtype=torch.float, requires_grad=True)
  209. >>> b = torch.arange(5, dtype=torch.float, requires_grad=True)
  210. >>> nt = torch.nested.nested_tensor([a, b], requires_grad=True)
  211. >>> nt.is_leaf
  212. True
  213. """
  214. if layout is None:
  215. layout = torch.strided
  216. if layout == torch.strided:
  217. return _nested.nested_tensor(
  218. tensor_list,
  219. dtype=dtype,
  220. device=device,
  221. requires_grad=requires_grad,
  222. pin_memory=pin_memory,
  223. )
  224. elif layout == torch.jagged:
  225. # Need to wrap lists of scalars as tensors
  226. list_of_tensors = [
  227. t if isinstance(t, Tensor) else torch.as_tensor(t) for t in tensor_list
  228. ]
  229. from torch.nested._internal.nested_tensor import jagged_from_list
  230. with torch.no_grad():
  231. nt, _ = jagged_from_list(
  232. list_of_tensors, offsets=None, device=device, dtype=dtype
  233. )
  234. nt.requires_grad_(requires_grad)
  235. if pin_memory:
  236. nt = nt.pin_memory() # type: ignore[assignment]
  237. return nt
  238. else:
  239. raise RuntimeError(
  240. f"Specified layout is unsupported for nested tensors: {layout}"
  241. )
  242. def narrow(
  243. tensor: Tensor,
  244. dim: int,
  245. start: int | Tensor,
  246. length: int | Tensor,
  247. layout=torch.strided,
  248. ) -> Tensor:
  249. r"""
  250. Constructs a nested tensor (which might be a view) from :attr:`tensor`, a strided tensor. This follows
  251. similar semantics to torch.Tensor.narrow, where in the :attr:`dim`-th dimension the new nested tensor
  252. shows only the elements in the interval `[start, start+length)`. As nested representations
  253. allow for a different `start` and `length` at each 'row' of that dimension, :attr:`start` and :attr:`length`
  254. can also be tensors of shape `tensor.shape[0]`.
  255. There's some differences depending on the layout you use for the nested tensor. If using strided layout,
  256. torch.narrow will do a copy of the narrowed data into a contiguous NT with strided layout, while
  257. jagged layout narrow() will create a non-contiguous view of your original strided tensor. This particular
  258. representation is really useful for representing kv-caches in Transformer models, as specialized
  259. SDPA kernels can deal with format easily, resulting in performance improvements.
  260. Args:
  261. tensor (:class:`torch.Tensor`): a strided tensor, which will be used as the underlying data
  262. for the nested tensor if using the jagged layout or will be copied for the strided layout.
  263. dim (int): the dimension where narrow will be applied. Only `dim=1` is supported for the
  264. jagged layout, while strided supports all dim
  265. start (Union[int, :class:`torch.Tensor`]): starting element for the narrow operation
  266. length (Union[int, :class:`torch.Tensor`]): number of elements taken during the narrow op
  267. Keyword arguments:
  268. layout (:class:`torch.layout`, optional): the desired layout of returned nested tensor.
  269. Only strided and jagged layouts are supported. Default: if None, the strided layout.
  270. Example::
  271. >>> starts = torch.tensor([0, 1, 2, 3, 4], dtype=torch.int64)
  272. >>> lengths = torch.tensor([3, 2, 2, 1, 5], dtype=torch.int64)
  273. >>> narrow_base = torch.randn(5, 10, 20)
  274. >>> nt_narrowed = torch.nested.narrow(narrow_base, 1, starts, lengths, layout=torch.jagged)
  275. >>> nt_narrowed.is_contiguous()
  276. False
  277. """
  278. if not isinstance(start, (int, SymInt, Tensor)):
  279. raise RuntimeError("start must be an integer or a tensor")
  280. if not isinstance(length, (int, SymInt, Tensor)):
  281. raise RuntimeError("length must be an integer or a tensor")
  282. if layout == torch.strided:
  283. if isinstance(start, Tensor) or isinstance(length, Tensor):
  284. raise RuntimeError(
  285. "start and length must be integers for the strided layout NT impl"
  286. )
  287. # TODO: switch to as_nested_tensor(tensor) when it is available
  288. nt = as_nested_tensor(torch.unbind(tensor), layout=torch.strided).narrow(
  289. dim, start, length
  290. )
  291. elif layout == torch.jagged:
  292. if dim != 1:
  293. raise RuntimeError("jagged layout only supports dim=1")
  294. from torch.nested._internal.nested_tensor import jagged_from_tensor_and_lengths
  295. if isinstance(start, (int, SymInt)):
  296. start = torch.tensor([start], device=tensor.device, dtype=torch.int64)
  297. if isinstance(length, (int, SymInt)):
  298. length = torch.tensor([length], device=tensor.device, dtype=torch.int64)
  299. nt, _, _ = jagged_from_tensor_and_lengths(tensor, start, length)
  300. else:
  301. raise RuntimeError(
  302. f"Specified layout is unsupported for nested narrow: {layout}"
  303. )
  304. return nt
  305. def nested_tensor_from_jagged(
  306. values: Tensor,
  307. offsets: Tensor | None = None,
  308. lengths: Tensor | None = None,
  309. jagged_dim: int | None = None,
  310. min_seqlen: int | None = None,
  311. max_seqlen: int | None = None,
  312. ) -> Tensor:
  313. r"""
  314. Constructs a jagged layout nested tensor from the given jagged components. The jagged layout
  315. consists of a required values buffer with the jagged dimension packed into a single dimension.
  316. The offsets / lengths metadata determines how this dimension is split into batch elements
  317. and are expected to be allocated on the same device as the values buffer.
  318. Expected metadata formats:
  319. * offsets: Indices within the packed dimension splitting it into heterogeneously-sized
  320. batch elements. Example: [0, 2, 3, 6] indicates that a packed jagged dim of size 6
  321. should be conceptually split into batch elements of length [2, 1, 3]. Note that both the
  322. beginning and ending offsets are required for kernel convenience (i.e. shape batch_size + 1).
  323. * lengths: Lengths of the individual batch elements; shape == batch_size. Example: [2, 1, 3]
  324. indicates that a packed jagged dim of size 6 should be conceptually split into batch
  325. elements of length [2, 1, 3].
  326. Note that it can be useful to provide both offsets and lengths. This describes a nested tensor
  327. with "holes", where the offsets indicate the start position of each batch item and the length
  328. specifies the total number of elements (see example below).
  329. The returned jagged layout nested tensor will be a view of the input values tensor.
  330. Args:
  331. values (:class:`torch.Tensor`): The underlying buffer in the shape of
  332. (sum_B(*), D_1, ..., D_N). The jagged dimension is packed into a single dimension,
  333. with the offsets / lengths metadata used to distinguish batch elements.
  334. offsets (optional :class:`torch.Tensor`): Offsets into the jagged dimension of shape B + 1.
  335. lengths (optional :class:`torch.Tensor`): Lengths of the batch elements of shape B.
  336. jagged_dim (optional int): Indicates which dimension in values is the packed jagged
  337. dimension. Must be >= 1 as the batch dimension (dim=0) cannot be ragged.
  338. If None, this is set to dim=1 (i.e. the dimension immediately following the batch dimension). Default: None
  339. min_seqlen (optional int): If set, uses the specified value as the cached minimum sequence
  340. length for the returned nested tensor. This can be a useful alternative to computing
  341. this value on-demand, possibly avoiding a GPU -> CPU sync. Default: None
  342. max_seqlen (optional int): If set, uses the specified value as the cached maximum sequence
  343. length for the returned nested tensor. This can be a useful alternative to computing
  344. this value on-demand, possibly avoiding a GPU -> CPU sync. Default: None
  345. Example::
  346. >>> values = torch.randn(12, 5)
  347. >>> offsets = torch.tensor([0, 3, 5, 6, 10, 12])
  348. >>> nt = nested_tensor_from_jagged(values, offsets)
  349. >>> # 3D shape with the middle dimension jagged
  350. >>> nt.shape
  351. torch.Size([5, j2, 5])
  352. >>> # Length of each item in the batch:
  353. >>> offsets.diff()
  354. tensor([3, 2, 1, 4, 2])
  355. >>> values = torch.randn(6, 5)
  356. >>> offsets = torch.tensor([0, 2, 3, 6])
  357. >>> lengths = torch.tensor([1, 1, 2])
  358. >>> # NT with holes
  359. >>> nt = nested_tensor_from_jagged(values, offsets, lengths)
  360. >>> a, b, c = nt.unbind()
  361. >>> # Batch item 1 consists of indices [0, 1)
  362. >>> torch.equal(a, values[0:1, :])
  363. True
  364. >>> # Batch item 2 consists of indices [2, 3)
  365. >>> torch.equal(b, values[2:3, :])
  366. True
  367. >>> # Batch item 3 consists of indices [3, 5)
  368. >>> torch.equal(c, values[3:5, :])
  369. True
  370. """
  371. from torch.fx._symbolic_trace import is_fx_tracing
  372. if is_fx_tracing():
  373. raise RuntimeError(
  374. "torch.nested.nested_tensor_from_jagged does not support tracing with fx.symbolic_trace. "
  375. "Use fx.wrap to wrap the function that calls nested_tensor_from_jagged."
  376. )
  377. if offsets is None:
  378. if lengths is None:
  379. raise RuntimeError(
  380. "nested_tensor_from_jagged(): At least one of offsets or lengths is required."
  381. )
  382. else:
  383. # TODO: Truly support offsets=None at some point?
  384. # For now, just convert lengths -> offsets for kernel convenience
  385. offsets = F.pad(lengths.cumsum(0), (1, 0))
  386. lengths = None
  387. if jagged_dim is None:
  388. jagged_dim = 1
  389. elif jagged_dim < 1:
  390. raise ValueError(f"Expected jagged_dim >=1, but got {jagged_dim}.")
  391. from torch.nested._internal.nested_tensor import (
  392. nested_view_from_values_offsets_lengths,
  393. )
  394. return nested_view_from_values_offsets_lengths(
  395. values,
  396. offsets,
  397. lengths,
  398. ragged_idx=jagged_dim,
  399. min_seqlen=min_seqlen,
  400. max_seqlen=max_seqlen,
  401. )
  402. def masked_select(tensor: Tensor, mask: Tensor) -> Tensor:
  403. r"""
  404. Constructs a nested tensor given a strided tensor input and a strided mask, the resulting jagged layout nested tensor
  405. will have values retain values where the mask is equal to True. The dimensionality of the mask is preserved and is
  406. represented with the offsets, this is unlike :func:`masked_select` where the output is collapsed to a 1D tensor.
  407. Args:
  408. tensor (:class:`torch.Tensor`): a strided tensor from which the jagged layout nested tensor is constructed from.
  409. mask (:class:`torch.Tensor`): a strided mask tensor which is applied to the tensor input
  410. Example::
  411. >>> tensor = torch.randn(3, 3)
  412. >>> mask = torch.tensor([[False, False, True], [True, False, True], [False, False, True]])
  413. >>> nt = torch.nested.masked_select(tensor, mask)
  414. >>> nt.shape
  415. torch.Size([3, j4])
  416. >>> # Length of each item in the batch:
  417. >>> nt.offsets().diff()
  418. tensor([1, 2, 1])
  419. >>> tensor = torch.randn(6, 5)
  420. >>> mask = torch.tensor([False])
  421. >>> nt = torch.nested.masked_select(tensor, mask)
  422. >>> nt.shape
  423. torch.Size([6, j5])
  424. >>> # Length of each item in the batch:
  425. >>> nt.offsets().diff()
  426. tensor([0, 0, 0, 0, 0, 0])
  427. """
  428. if tensor.layout != torch.strided:
  429. raise RuntimeError(
  430. f"torch.nested.masked_select requires a strided tensor, given {tensor.layout}"
  431. )
  432. if mask.layout != torch.strided:
  433. raise RuntimeError(
  434. f"torch.nested.masked_select requires a strided mask, given: {mask.layout}"
  435. )
  436. res_values = tensor.masked_select(mask)
  437. expanded_mask = mask.expand(tensor.shape)
  438. res_lengths = expanded_mask.sum(dim=tensor.ndim - 1).view(-1)
  439. from torch.nested._internal.nested_tensor import nested_view_from_values_offsets
  440. return nested_view_from_values_offsets(
  441. values=res_values,
  442. offsets=F.pad(res_lengths.cumsum(dim=0), (1, 0)),
  443. )