rnn.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589
  1. import warnings
  2. from collections.abc import Callable, Iterable
  3. from typing import Any, NamedTuple, TypeVar
  4. from typing_extensions import Self
  5. import torch
  6. from torch import _VF, Tensor
  7. from torch.utils._typing_utils import copy_method_params
  8. __all__ = [
  9. "PackedSequence",
  10. "invert_permutation",
  11. "pack_padded_sequence",
  12. "pad_packed_sequence",
  13. "pad_sequence",
  14. "unpad_sequence",
  15. "pack_sequence",
  16. "unpack_sequence",
  17. ]
  18. _T = TypeVar("_T")
  19. _R = TypeVar("_R")
  20. class PackedSequence_(NamedTuple):
  21. data: torch.Tensor
  22. batch_sizes: torch.Tensor
  23. sorted_indices: torch.Tensor | None
  24. unsorted_indices: torch.Tensor | None
  25. def bind(optional: _T | None, fn: Callable[[_T], _R]) -> _R | None:
  26. if optional is None:
  27. return None
  28. return fn(optional)
  29. class PackedSequence(PackedSequence_):
  30. r"""Holds the data and list of :attr:`batch_sizes` of a packed sequence.
  31. All RNN modules accept packed sequences as inputs.
  32. Note:
  33. Instances of this class should never be created manually. They are meant
  34. to be instantiated by functions like :func:`pack_padded_sequence`.
  35. Batch sizes represent the number elements at each sequence step in
  36. the batch, not the varying sequence lengths passed to
  37. :func:`pack_padded_sequence`. For instance, given data ``abc`` and ``x``
  38. the :class:`PackedSequence` would contain data ``axbc`` with
  39. ``batch_sizes=[2,1,1]``.
  40. Attributes:
  41. data (Tensor): Tensor containing packed sequence
  42. batch_sizes (Tensor): Tensor of integers holding
  43. information about the batch size at each sequence step
  44. sorted_indices (Tensor, optional): Tensor of integers holding how this
  45. :class:`PackedSequence` is constructed from sequences.
  46. unsorted_indices (Tensor, optional): Tensor of integers holding how this
  47. to recover the original sequences with correct order.
  48. .. note::
  49. :attr:`data` can be on arbitrary device and of arbitrary dtype.
  50. :attr:`sorted_indices` and :attr:`unsorted_indices` must be ``torch.int64``
  51. tensors on the same device as :attr:`data`.
  52. However, :attr:`batch_sizes` should always be a CPU ``torch.int64`` tensor.
  53. This invariant is maintained throughout :class:`PackedSequence` class,
  54. and all functions that construct a :class:`PackedSequence` in PyTorch
  55. (i.e., they only pass in tensors conforming to this constraint).
  56. """
  57. def __new__(
  58. cls,
  59. data: Tensor,
  60. batch_sizes: Tensor | None = None,
  61. sorted_indices: Tensor | None = None,
  62. unsorted_indices: Tensor | None = None,
  63. ) -> Self:
  64. return super().__new__(
  65. cls,
  66. *_packed_sequence_init_args(
  67. data, batch_sizes, sorted_indices, unsorted_indices
  68. ),
  69. )
  70. # NOTE [ device and dtype of a PackedSequence ]
  71. #
  72. # See the note above in doc string (starting with ":attr:`data` can be on
  73. # arbitrary device...").
  74. def pin_memory(self) -> Self:
  75. # Why not convert `batch_sizes`?
  76. # See NOTE [ device and dtype of a PackedSequence ]
  77. return type(self)(
  78. self.data.pin_memory(),
  79. self.batch_sizes,
  80. bind(self.sorted_indices, lambda t: t.pin_memory()),
  81. bind(self.unsorted_indices, lambda t: t.pin_memory()),
  82. )
  83. @copy_method_params(torch.Tensor.to)
  84. def to(self, *args: Any, **kwargs: Any) -> Self:
  85. r"""Perform dtype and/or device conversion on `self.data`.
  86. It has similar signature as :meth:`torch.Tensor.to`
  87. .. note::
  88. If the ``self.data`` Tensor already has the correct :class:`torch.dtype`
  89. and :class:`torch.device`, then ``self`` is returned.
  90. Otherwise, returns a copy with the desired configuration.
  91. """
  92. # Why not convert `batch_sizes`?
  93. # See NOTE [ device and dtype of a PackedSequence ]
  94. data = self.data.to(*args, **kwargs)
  95. if data is self.data:
  96. return self
  97. else:
  98. _device, _dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
  99. *args, **kwargs
  100. )
  101. # Does not forward device or dtype arg/kwargs, device is set from data.device
  102. def call_to(t: torch.Tensor) -> torch.Tensor:
  103. return t.to(
  104. data.device,
  105. non_blocking=non_blocking,
  106. memory_format=convert_to_format,
  107. )
  108. sorted_indices = bind(self.sorted_indices, call_to)
  109. unsorted_indices = bind(self.unsorted_indices, call_to)
  110. return type(self)(data, self.batch_sizes, sorted_indices, unsorted_indices)
  111. @copy_method_params(torch.Tensor.cuda)
  112. def cuda(self, *args: Any, **kwargs: Any) -> Self:
  113. # Tests to see if 'cuda' should be added to kwargs
  114. ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to(
  115. *args, **kwargs
  116. )
  117. if ex.is_cuda:
  118. return self.to(*args, **kwargs)
  119. kwargs["device"] = "cuda"
  120. return self.to(*args, **kwargs)
  121. @copy_method_params(torch.Tensor.cpu)
  122. def cpu(self, *args: Any, **kwargs: Any) -> Self:
  123. ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to(
  124. *args, **kwargs
  125. )
  126. if ex.device.type == "cpu":
  127. return self.to(*args, **kwargs)
  128. kwargs["device"] = "cpu"
  129. return self.to(*args, **kwargs)
  130. def double(self) -> Self:
  131. return self.to(dtype=torch.double)
  132. def float(self) -> Self:
  133. return self.to(dtype=torch.float)
  134. def half(self) -> Self:
  135. return self.to(dtype=torch.half)
  136. def long(self) -> Self:
  137. return self.to(dtype=torch.long)
  138. def int(self) -> Self:
  139. return self.to(dtype=torch.int)
  140. def short(self) -> Self:
  141. return self.to(dtype=torch.short)
  142. def char(self) -> Self:
  143. return self.to(dtype=torch.int8)
  144. def byte(self) -> Self:
  145. return self.to(dtype=torch.uint8)
  146. @property
  147. def is_cuda(self) -> bool:
  148. r"""Return true if `self.data` stored on a gpu."""
  149. return self.data.is_cuda
  150. def is_pinned(self) -> bool:
  151. r"""Return true if `self.data` stored on in pinned memory."""
  152. return self.data.is_pinned()
  153. # TorchScript doesn't support constructors on named tuples, so we use this helper
  154. # method to construct PackedSequence
  155. def _packed_sequence_init_args(
  156. data: Tensor,
  157. batch_sizes: Tensor | None = None,
  158. sorted_indices: Tensor | None = None,
  159. unsorted_indices: Tensor | None = None,
  160. ) -> tuple[Tensor, Tensor, Tensor | None, Tensor | None]:
  161. # NB: if unsorted_indices is provided, it should be the inverse permutation
  162. # to sorted_indices. Don't assert it here because the PackedSequence ctor
  163. # should only be used internally.
  164. if unsorted_indices is None:
  165. unsorted_indices = invert_permutation(sorted_indices)
  166. # support being called as `PackedSequence(data, batch_sizes, sorted_indices)`
  167. if batch_sizes is not None:
  168. # TODO: Re-enable this check (.type isn't supported in TorchScript)
  169. if batch_sizes.device.type != "cpu":
  170. raise ValueError(
  171. "batch_sizes should always be on CPU. "
  172. "Instances of PackedSequence should never be created manually. "
  173. "They should be instantiated by functions like pack_sequence "
  174. "and pack_padded_sequences in nn.utils.rnn. "
  175. "https://pytorch.org/docs/stable/nn.html#torch.nn.utils.rnn.pack_sequence"
  176. )
  177. return data, batch_sizes, sorted_indices, unsorted_indices
  178. # support being called as `PackedSequence((data, batch_sizes), *, sorted_indices)`
  179. else:
  180. if not (isinstance(data, (list, tuple)) and len(data) == 2):
  181. raise AssertionError("Expected data to be a list or tuple of length 2")
  182. return data[0], data[1], sorted_indices, unsorted_indices
  183. def _packed_sequence_init(
  184. data: Tensor,
  185. batch_sizes: Tensor | None = None,
  186. sorted_indices: Tensor | None = None,
  187. unsorted_indices: Tensor | None = None,
  188. ) -> PackedSequence:
  189. data, batch_sizes, sorted_indices, unsorted_indices = _packed_sequence_init_args(
  190. data, batch_sizes, sorted_indices, unsorted_indices
  191. )
  192. return PackedSequence(data, batch_sizes, sorted_indices, unsorted_indices)
  193. def invert_permutation(permutation: Tensor | None) -> Tensor | None:
  194. """Returns the inverse of ``permutation``.
  195. This is useful for converting between sorted and unsorted indices in
  196. a :class:`~nn.utils.rnn.PackedSequence`.
  197. Args:
  198. permutation (Tensor, optional): a 1-D tensor of indices to invert
  199. """
  200. if permutation is None:
  201. return None
  202. output = torch.empty_like(permutation, memory_format=torch.legacy_contiguous_format)
  203. output.scatter_(
  204. 0, permutation, torch.arange(0, permutation.numel(), device=permutation.device)
  205. )
  206. return output
  207. def pack_padded_sequence(
  208. input: Tensor,
  209. lengths: Tensor | list[int],
  210. batch_first: bool = False,
  211. enforce_sorted: bool = True,
  212. ) -> PackedSequence:
  213. r"""Packs a Tensor containing padded sequences of variable length.
  214. :attr:`input` can be of size ``T x B x *`` (if :attr:`batch_first` is ``False``)
  215. or ``B x T x *`` (if :attr:`batch_first` is ``True``) where ``T`` is the length
  216. of the longest sequence, ``B`` is the batch size, and ``*`` is any number of dimensions
  217. (including 0).
  218. For unsorted sequences, use `enforce_sorted = False`. If :attr:`enforce_sorted` is
  219. ``True``, the sequences should be sorted by length in a decreasing order, i.e.
  220. ``input[:,0]`` should be the longest sequence, and ``input[:,B-1]`` the shortest
  221. one. `enforce_sorted = True` is only necessary for ONNX export.
  222. It is an inverse operation to :func:`pad_packed_sequence`, and hence :func:`pad_packed_sequence`
  223. can be used to recover the underlying tensor packed in :class:`PackedSequence`.
  224. Note:
  225. This function accepts any input that has at least two dimensions. You
  226. can apply it to pack the labels, and use the output of the RNN with
  227. them to compute the loss directly. A Tensor can be retrieved from
  228. a :class:`PackedSequence` object by accessing its ``.data`` attribute.
  229. Args:
  230. input (Tensor): padded batch of variable length sequences.
  231. lengths (Tensor or list(int)): list of sequence lengths of each batch
  232. element (must be on the CPU if provided as a tensor).
  233. batch_first (bool, optional): if ``True``, the input is expected in ``B x T x *``
  234. format, ``T x B x *`` otherwise. Default: ``False``.
  235. enforce_sorted (bool, optional): if ``True``, the input is expected to
  236. contain sequences sorted by length in a decreasing order. If
  237. ``False``, the input will get sorted unconditionally. Default: ``True``.
  238. .. warning::
  239. The dim of ``input`` tensor will be truncated if its length larger than
  240. correspond value in ``length``.
  241. Returns:
  242. a :class:`PackedSequence` object
  243. """
  244. if not isinstance(lengths, torch.Tensor):
  245. if torch._C._get_tracing_state():
  246. warnings.warn(
  247. "pack_padded_sequence has been called with a Python list of "
  248. "sequence lengths. The tracer cannot track the data flow of Python "
  249. "values, and it will treat them as constants, likely rendering "
  250. "the trace incorrect for any other combination of lengths.",
  251. stacklevel=2,
  252. )
  253. lengths = torch.as_tensor(lengths, dtype=torch.int64, device="cpu")
  254. else:
  255. lengths = lengths.to(dtype=torch.int64)
  256. if enforce_sorted:
  257. sorted_indices = None
  258. else:
  259. lengths, sorted_indices = torch.sort(lengths, descending=True)
  260. sorted_indices = sorted_indices.to(input.device)
  261. batch_dim = 0 if batch_first else 1
  262. input = input.index_select(batch_dim, sorted_indices)
  263. data, batch_sizes = _VF._pack_padded_sequence(input, lengths, batch_first)
  264. return _packed_sequence_init(data, batch_sizes, sorted_indices, None)
  265. def pad_packed_sequence(
  266. sequence: PackedSequence,
  267. batch_first: bool = False,
  268. padding_value: float = 0.0,
  269. total_length: int | None = None,
  270. ) -> tuple[Tensor, Tensor]:
  271. r"""Pad a packed batch of variable length sequences.
  272. It is an inverse operation to :func:`pack_padded_sequence`.
  273. The returned Tensor's data will be of size ``T x B x *`` (if :attr:`batch_first` is ``False``)
  274. or ``B x T x *`` (if :attr:`batch_first` is ``True``) , where ``T`` is the length of the longest
  275. sequence and ``B`` is the batch size.
  276. Example:
  277. >>> from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
  278. >>> seq = torch.tensor([[1, 2, 0], [3, 0, 0], [4, 5, 6]])
  279. >>> lens = [2, 1, 3]
  280. >>> packed = pack_padded_sequence(
  281. ... seq, lens, batch_first=True, enforce_sorted=False
  282. ... )
  283. >>> packed
  284. PackedSequence(data=tensor([4, 1, 3, 5, 2, 6]), batch_sizes=tensor([3, 2, 1]),
  285. sorted_indices=tensor([2, 0, 1]), unsorted_indices=tensor([1, 2, 0]))
  286. >>> seq_unpacked, lens_unpacked = pad_packed_sequence(packed, batch_first=True)
  287. >>> seq_unpacked
  288. tensor([[1, 2, 0],
  289. [3, 0, 0],
  290. [4, 5, 6]])
  291. >>> lens_unpacked
  292. tensor([2, 1, 3])
  293. .. note::
  294. :attr:`total_length` is useful to implement the
  295. ``pack sequence -> recurrent network -> unpack sequence`` pattern in a
  296. :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`.
  297. See :ref:`this FAQ section <pack-rnn-unpack-with-data-parallelism>` for
  298. details.
  299. Args:
  300. sequence (PackedSequence): batch to pad
  301. batch_first (bool, optional): if ``True``, the output will be in ``B x T x *``
  302. format, ``T x B x *`` otherwise.
  303. padding_value (float, optional): values for padded elements.
  304. total_length (int, optional): if not ``None``, the output will be padded to
  305. have length :attr:`total_length`. This method will throw :class:`ValueError`
  306. if :attr:`total_length` is less than the max sequence length in
  307. :attr:`sequence`.
  308. Returns:
  309. Tuple of Tensor containing the padded sequence, and a Tensor
  310. containing the list of lengths of each sequence in the batch.
  311. Batch elements will be re-ordered as they were ordered originally when
  312. the batch was passed to ``pack_padded_sequence`` or ``pack_sequence``.
  313. """
  314. max_seq_length = sequence.batch_sizes.size(0)
  315. if total_length is not None:
  316. if total_length < max_seq_length:
  317. raise ValueError(
  318. "Expected total_length to be at least the length "
  319. "of the longest sequence in input, but got "
  320. f"total_length={total_length} and max sequence length being {max_seq_length}"
  321. )
  322. max_seq_length = total_length
  323. padded_output, lengths = _VF._pad_packed_sequence(
  324. sequence.data, sequence.batch_sizes, batch_first, padding_value, max_seq_length
  325. )
  326. unsorted_indices = sequence.unsorted_indices
  327. if unsorted_indices is not None:
  328. batch_dim = 0 if batch_first else 1
  329. return (
  330. padded_output.index_select(batch_dim, unsorted_indices),
  331. lengths[unsorted_indices.cpu()],
  332. )
  333. return padded_output, lengths
  334. # NOTE: for JIT-compatibility, we need to be more restrictive here and use specific types instead of Iterable.
  335. def pad_sequence(
  336. sequences: Tensor | list[Tensor],
  337. batch_first: bool = False,
  338. padding_value: float = 0.0,
  339. padding_side: str = "right",
  340. ) -> Tensor:
  341. r"""Pad a list of variable length Tensors with :attr:`padding_value`.
  342. ``pad_sequence`` stacks a list of Tensors along a new dimension, and pads them
  343. to equal length. :attr:`sequences` can be list of sequences with size ``L x *``,
  344. where `L` is length of the sequence and ``*`` is any number of dimensions
  345. (including ``0``). If :attr:`batch_first` is ``False``, the output is of size
  346. ``T x B x *``, and ``B x T x *`` otherwise, where ``B`` is the batch size
  347. (the number of elements in :attr:`sequences`), ``T`` is the length of the longest
  348. sequence.
  349. Example:
  350. >>> from torch.nn.utils.rnn import pad_sequence
  351. >>> a = torch.ones(25, 300)
  352. >>> b = torch.ones(22, 300)
  353. >>> c = torch.ones(15, 300)
  354. >>> pad_sequence([a, b, c]).size()
  355. torch.Size([25, 3, 300])
  356. Note:
  357. This function returns a Tensor of size ``T x B x *`` or ``B x T x *``
  358. where `T` is the length of the longest sequence. This function assumes
  359. trailing dimensions and type of all the Tensors in sequences are same.
  360. Args:
  361. sequences (list[Tensor]): list of variable length sequences.
  362. batch_first (bool, optional): if ``True``, the output will be in ``B x T x *``
  363. format, ``T x B x *`` otherwise.
  364. padding_value (float, optional): value for padded elements. Default: ``0``.
  365. padding_side (str, optional): the side to pad the sequences on.
  366. Default: ``'right'``.
  367. Returns:
  368. Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``.
  369. Tensor of size ``B x T x *`` otherwise
  370. """
  371. if not (torch.jit.is_tracing() or torch.jit.is_scripting()):
  372. # JIT doesn't support `Iterable`
  373. if not isinstance(sequences, Iterable):
  374. msg = (
  375. "pad_sequence: Expected iterable for input sequences, but got arg of type: "
  376. f"{type(sequences)}"
  377. )
  378. raise RuntimeError(msg)
  379. # In JIT context this leads to,
  380. # RuntimeError: cannot statically infer the expected size of a list in this context
  381. sequences = tuple(sequences) # type: ignore[assignment]
  382. else:
  383. # For JIT, we only support Union[Tensor, Tuple[Tensor]]
  384. if isinstance(sequences, torch.Tensor):
  385. sequences = sequences.unbind(0) # type: ignore[assignment]
  386. # assuming trailing dimensions and type of all the Tensors
  387. # in sequences are same and fetching those from sequences[0]
  388. return torch._C._nn.pad_sequence(
  389. sequences, # type: ignore[arg-type]
  390. batch_first,
  391. padding_value,
  392. padding_side, # type: ignore[arg-type]
  393. )
  394. def unpad_sequence(
  395. padded_sequences: Tensor,
  396. lengths: Tensor,
  397. batch_first: bool = False,
  398. ) -> list[Tensor]:
  399. r"""Unpad padded Tensor into a list of variable length Tensors.
  400. ``unpad_sequence`` unstacks padded Tensor into a list of variable length Tensors.
  401. Example:
  402. >>> from torch.nn.utils.rnn import pad_sequence, unpad_sequence
  403. >>> a = torch.ones(25, 300)
  404. >>> b = torch.ones(22, 300)
  405. >>> c = torch.ones(15, 300)
  406. >>> sequences = [a, b, c]
  407. >>> padded_sequences = pad_sequence(sequences)
  408. >>> lengths = torch.as_tensor([v.size(0) for v in sequences])
  409. >>> unpadded_sequences = unpad_sequence(padded_sequences, lengths)
  410. >>> torch.allclose(sequences[0], unpadded_sequences[0])
  411. True
  412. >>> torch.allclose(sequences[1], unpadded_sequences[1])
  413. True
  414. >>> torch.allclose(sequences[2], unpadded_sequences[2])
  415. True
  416. Args:
  417. padded_sequences (Tensor): padded sequences.
  418. lengths (Tensor): length of original (unpadded) sequences.
  419. batch_first (bool, optional): whether batch dimension first or not. Default: ``False``.
  420. Returns:
  421. a list of :class:`Tensor` objects
  422. """
  423. unpadded_sequences = []
  424. if not batch_first:
  425. padded_sequences.transpose_(0, 1)
  426. max_length = padded_sequences.shape[1]
  427. idx = torch.arange(max_length, device=lengths.device)
  428. for seq, length in zip(padded_sequences, lengths, strict=True):
  429. mask = idx < length
  430. unpacked_seq = seq[mask]
  431. unpadded_sequences.append(unpacked_seq)
  432. return unpadded_sequences
  433. def pack_sequence(
  434. sequences: list[Tensor],
  435. enforce_sorted: bool = True,
  436. ) -> PackedSequence:
  437. r"""Packs a list of variable length Tensors.
  438. Consecutive call of the next functions: ``pad_sequence``, ``pack_padded_sequence``.
  439. ``sequences`` should be a list of Tensors of size ``L x *``, where `L` is
  440. the length of a sequence and `*` is any number of trailing dimensions,
  441. including ``0``.
  442. For unsorted sequences, use `enforce_sorted = False`. If ``enforce_sorted``
  443. is ``True``, the sequences should be sorted in the order of decreasing length.
  444. ``enforce_sorted = True`` is only necessary for ONNX export.
  445. Example:
  446. >>> from torch.nn.utils.rnn import pack_sequence
  447. >>> a = torch.tensor([1, 2, 3])
  448. >>> b = torch.tensor([4, 5])
  449. >>> c = torch.tensor([6])
  450. >>> pack_sequence([a, b, c])
  451. PackedSequence(data=tensor([1, 4, 6, 2, 5, 3]), batch_sizes=tensor([3, 2, 1]), sorted_indices=None, unsorted_indices=None)
  452. Args:
  453. sequences (list[Tensor]): A list of sequences of decreasing length.
  454. enforce_sorted (bool, optional): if ``True``, checks that the input
  455. contains sequences sorted by length in a decreasing order. If
  456. ``False``, this condition is not checked. Default: ``True``.
  457. Returns:
  458. a :class:`PackedSequence` object
  459. """
  460. lengths = torch.as_tensor([v.size(0) for v in sequences])
  461. return pack_padded_sequence(
  462. pad_sequence(sequences), lengths, enforce_sorted=enforce_sorted
  463. )
  464. def unpack_sequence(packed_sequences: PackedSequence) -> list[Tensor]:
  465. r"""Unpack PackedSequence into a list of variable length Tensors.
  466. ``packed_sequences`` should be a PackedSequence object.
  467. Example:
  468. >>> from torch.nn.utils.rnn import pack_sequence, unpack_sequence
  469. >>> a = torch.tensor([1, 2, 3])
  470. >>> b = torch.tensor([4, 5])
  471. >>> c = torch.tensor([6])
  472. >>> sequences = [a, b, c]
  473. >>> print(sequences)
  474. [tensor([1, 2, 3]), tensor([4, 5]), tensor([6])]
  475. >>> packed_sequences = pack_sequence(sequences)
  476. >>> print(packed_sequences)
  477. PackedSequence(data=tensor([1, 4, 6, 2, 5, 3]), batch_sizes=tensor([3, 2, 1]), sorted_indices=None, unsorted_indices=None)
  478. >>> unpacked_sequences = unpack_sequence(packed_sequences)
  479. >>> print(unpacked_sequences)
  480. [tensor([1, 2, 3]), tensor([4, 5]), tensor([6])]
  481. Args:
  482. packed_sequences (PackedSequence): A PackedSequence object.
  483. Returns:
  484. a list of :class:`Tensor` objects
  485. """
  486. padded_sequences, lengths = pad_packed_sequence(packed_sequences, batch_first=True)
  487. unpacked_sequences = unpad_sequence(padded_sequences, lengths, batch_first=True)
  488. return unpacked_sequences