flatten.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. # mypy: allow-untyped-defs
  2. from torch import Tensor
  3. from torch.types import _size
  4. from .module import Module
# Public API of this module: the two reshaping layers defined in this file.
__all__ = ["Flatten", "Unflatten"]
  6. class Flatten(Module):
  7. r"""
  8. Flattens a contiguous range of dims into a tensor.
  9. For use with :class:`~nn.Sequential`, see :meth:`torch.flatten` for details.
  10. Shape:
  11. - Input: :math:`(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *)`,'
  12. where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any
  13. number of dimensions including none.
  14. - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`.
  15. Args:
  16. start_dim: first dim to flatten (default = 1).
  17. end_dim: last dim to flatten (default = -1).
  18. Examples::
  19. >>> input = torch.randn(32, 1, 5, 5)
  20. >>> # With default parameters
  21. >>> m = nn.Flatten()
  22. >>> output = m(input)
  23. >>> output.size()
  24. torch.Size([32, 25])
  25. >>> # With non-default parameters
  26. >>> m = nn.Flatten(0, 2)
  27. >>> output = m(input)
  28. >>> output.size()
  29. torch.Size([160, 5])
  30. """
  31. __constants__ = ["start_dim", "end_dim"]
  32. start_dim: int
  33. end_dim: int
  34. def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
  35. super().__init__()
  36. self.start_dim = start_dim
  37. self.end_dim = end_dim
  38. def forward(self, input: Tensor) -> Tensor:
  39. """
  40. Runs the forward pass.
  41. """
  42. return input.flatten(self.start_dim, self.end_dim)
  43. def extra_repr(self) -> str:
  44. """
  45. Returns the extra representation of the module.
  46. """
  47. return f"start_dim={self.start_dim}, end_dim={self.end_dim}"
  48. class Unflatten(Module):
  49. r"""
  50. Unflattens a tensor dim expanding it to a desired shape. For use with :class:`~nn.Sequential`.
  51. * :attr:`dim` specifies the dimension of the input tensor to be unflattened, and it can
  52. be either `int` or `str` when `Tensor` or `NamedTensor` is used, respectively.
  53. * :attr:`unflattened_size` is the new shape of the unflattened dimension of the tensor and it can be
  54. a `tuple` of ints or a `list` of ints or `torch.Size` for `Tensor` input; a `NamedShape`
  55. (tuple of `(name, size)` tuples) for `NamedTensor` input.
  56. Shape:
  57. - Input: :math:`(*, S_{\text{dim}}, *)`, where :math:`S_{\text{dim}}` is the size at
  58. dimension :attr:`dim` and :math:`*` means any number of dimensions including none.
  59. - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`unflattened_size` and
  60. :math:`\prod_{i=1}^n U_i = S_{\text{dim}}`.
  61. Args:
  62. dim (Union[int, str]): Dimension to be unflattened
  63. unflattened_size (Union[torch.Size, Tuple, List, NamedShape]): New shape of the unflattened dimension
  64. Examples:
  65. >>> input = torch.randn(2, 50)
  66. >>> # With tuple of ints
  67. >>> m = nn.Sequential(
  68. >>> nn.Linear(50, 50),
  69. >>> nn.Unflatten(1, (2, 5, 5))
  70. >>> )
  71. >>> output = m(input)
  72. >>> output.size()
  73. torch.Size([2, 2, 5, 5])
  74. >>> # With torch.Size
  75. >>> m = nn.Sequential(
  76. >>> nn.Linear(50, 50),
  77. >>> nn.Unflatten(1, torch.Size([2, 5, 5]))
  78. >>> )
  79. >>> output = m(input)
  80. >>> output.size()
  81. torch.Size([2, 2, 5, 5])
  82. >>> # With namedshape (tuple of tuples)
  83. >>> input = torch.randn(2, 50, names=("N", "features"))
  84. >>> unflatten = nn.Unflatten("features", (("C", 2), ("H", 5), ("W", 5)))
  85. >>> output = unflatten(input)
  86. >>> output.size()
  87. torch.Size([2, 2, 5, 5])
  88. """
  89. NamedShape = tuple[tuple[str, int]]
  90. __constants__ = ["dim", "unflattened_size"]
  91. dim: int | str
  92. unflattened_size: _size | NamedShape
  93. def __init__(self, dim: int | str, unflattened_size: _size | NamedShape) -> None:
  94. super().__init__()
  95. if isinstance(dim, int):
  96. self._require_tuple_int(unflattened_size)
  97. elif isinstance(dim, str):
  98. self._require_tuple_tuple(unflattened_size)
  99. else:
  100. raise TypeError("invalid argument type for dim parameter")
  101. self.dim = dim
  102. self.unflattened_size = unflattened_size
  103. def _require_tuple_tuple(self, input) -> None:
  104. if isinstance(input, tuple):
  105. for idx, elem in enumerate(input):
  106. if not isinstance(elem, tuple):
  107. raise TypeError(
  108. "unflattened_size must be tuple of tuples, "
  109. + f"but found element of type {type(elem).__name__} at pos {idx}"
  110. )
  111. return
  112. raise TypeError(
  113. "unflattened_size must be a tuple of tuples, "
  114. + f"but found type {type(input).__name__}"
  115. )
  116. def _require_tuple_int(self, input) -> None:
  117. if isinstance(input, (tuple, list)):
  118. for idx, elem in enumerate(input):
  119. if not isinstance(elem, int):
  120. raise TypeError(
  121. "unflattened_size must be tuple of ints, "
  122. + f"but found element of type {type(elem).__name__} at pos {idx}"
  123. )
  124. return
  125. raise TypeError(
  126. f"unflattened_size must be a tuple of ints, but found type {type(input).__name__}"
  127. )
  128. def forward(self, input: Tensor) -> Tensor:
  129. """
  130. Runs the forward pass.
  131. """
  132. return input.unflatten(self.dim, self.unflattened_size)
  133. def extra_repr(self) -> str:
  134. """
  135. Returns the extra representation of the module.
  136. """
  137. return f"dim={self.dim}, unflattened_size={self.unflattened_size}"