# upsampling.py (11 KB)
  1. # mypy: allow-untyped-defs
  2. import torch.nn.functional as F
  3. from torch import Tensor
  4. from torch.nn.common_types import _ratio_2_t, _ratio_any_t, _size_2_t, _size_any_t
  5. from .module import Module
  6. __all__ = ["Upsample", "UpsamplingNearest2d", "UpsamplingBilinear2d"]
  7. class Upsample(Module):
  8. r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
  9. The input data is assumed to be of the form
  10. `minibatch x channels x [optional depth] x [optional height] x width`.
  11. Hence, for spatial inputs, we expect a 4D Tensor and for volumetric inputs, we expect a 5D Tensor.
  12. The algorithms available for upsampling are nearest neighbor and linear,
  13. bilinear, bicubic and trilinear for 3D, 4D and 5D input Tensor,
  14. respectively.
  15. One can either give a :attr:`scale_factor` or the target output :attr:`size` to
  16. calculate the output size. (You cannot give both, as it is ambiguous)
  17. Args:
  18. size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int], optional):
  19. output spatial sizes
  20. scale_factor (float or Tuple[float] or Tuple[float, float] or Tuple[float, float, float], optional):
  21. multiplier for spatial size. Has to match input size if it is a tuple.
  22. mode (str, optional): the upsampling algorithm: one of ``'nearest'``,
  23. ``'linear'``, ``'bilinear'``, ``'bicubic'`` and ``'trilinear'``.
  24. Default: ``'nearest'``
  25. align_corners (bool, optional): if ``True``, the corner pixels of the input
  26. and output tensors are aligned, and thus preserving the values at
  27. those pixels. This only has effect when :attr:`mode` is
  28. ``'linear'``, ``'bilinear'``, ``'bicubic'``, or ``'trilinear'``.
  29. Default: ``False``
  30. recompute_scale_factor (bool, optional): recompute the scale_factor for use in the
  31. interpolation calculation. If `recompute_scale_factor` is ``True``, then
  32. `scale_factor` must be passed in and `scale_factor` is used to compute the
  33. output `size`. The computed output `size` will be used to infer new scales for
  34. the interpolation. Note that when `scale_factor` is floating-point, it may differ
  35. from the recomputed `scale_factor` due to rounding and precision issues.
  36. If `recompute_scale_factor` is ``False``, then `size` or `scale_factor` will
  37. be used directly for interpolation.
  38. Shape:
  39. - Input: :math:`(N, C, W_{in})`, :math:`(N, C, H_{in}, W_{in})` or :math:`(N, C, D_{in}, H_{in}, W_{in})`
  40. - Output: :math:`(N, C, W_{out})`, :math:`(N, C, H_{out}, W_{out})`
  41. or :math:`(N, C, D_{out}, H_{out}, W_{out})`, where
  42. .. math::
  43. D_{out} = \left\lfloor D_{in} \times \text{scale\_factor} \right\rfloor
  44. .. math::
  45. H_{out} = \left\lfloor H_{in} \times \text{scale\_factor} \right\rfloor
  46. .. math::
  47. W_{out} = \left\lfloor W_{in} \times \text{scale\_factor} \right\rfloor
  48. .. warning::
  49. With ``align_corners = True``, the linearly interpolating modes
  50. (`linear`, `bilinear`, `bicubic`, and `trilinear`) don't proportionally
  51. align the output and input pixels, and thus the output values can depend
  52. on the input size. This was the default behavior for these modes up to
  53. version 0.3.1. Since then, the default behavior is
  54. ``align_corners = False``. See below for concrete examples on how this
  55. affects the outputs.
  56. .. note::
  57. If you want downsampling/general resizing, you should use :func:`~nn.functional.interpolate`.
  58. Examples::
  59. >>> input = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2)
  60. >>> input
  61. tensor([[[[1., 2.],
  62. [3., 4.]]]])
  63. >>> m = nn.Upsample(scale_factor=2, mode='nearest')
  64. >>> m(input)
  65. tensor([[[[1., 1., 2., 2.],
  66. [1., 1., 2., 2.],
  67. [3., 3., 4., 4.],
  68. [3., 3., 4., 4.]]]])
  69. >>> # xdoctest: +IGNORE_WANT("other tests seem to modify printing styles")
  70. >>> m = nn.Upsample(scale_factor=2, mode='bilinear') # align_corners=False
  71. >>> m(input)
  72. tensor([[[[1.0000, 1.2500, 1.7500, 2.0000],
  73. [1.5000, 1.7500, 2.2500, 2.5000],
  74. [2.5000, 2.7500, 3.2500, 3.5000],
  75. [3.0000, 3.2500, 3.7500, 4.0000]]]])
  76. >>> m = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
  77. >>> m(input)
  78. tensor([[[[1.0000, 1.3333, 1.6667, 2.0000],
  79. [1.6667, 2.0000, 2.3333, 2.6667],
  80. [2.3333, 2.6667, 3.0000, 3.3333],
  81. [3.0000, 3.3333, 3.6667, 4.0000]]]])
  82. >>> # Try scaling the same data in a larger tensor
  83. >>> input_3x3 = torch.zeros(3, 3).view(1, 1, 3, 3)
  84. >>> input_3x3[:, :, :2, :2].copy_(input)
  85. tensor([[[[1., 2.],
  86. [3., 4.]]]])
  87. >>> input_3x3
  88. tensor([[[[1., 2., 0.],
  89. [3., 4., 0.],
  90. [0., 0., 0.]]]])
  91. >>> # xdoctest: +IGNORE_WANT("seems to fail when other tests are run in the same session")
  92. >>> m = nn.Upsample(scale_factor=2, mode='bilinear') # align_corners=False
  93. >>> # Notice that values in top left corner are the same with the small input (except at boundary)
  94. >>> m(input_3x3)
  95. tensor([[[[1.0000, 1.2500, 1.7500, 1.5000, 0.5000, 0.0000],
  96. [1.5000, 1.7500, 2.2500, 1.8750, 0.6250, 0.0000],
  97. [2.5000, 2.7500, 3.2500, 2.6250, 0.8750, 0.0000],
  98. [2.2500, 2.4375, 2.8125, 2.2500, 0.7500, 0.0000],
  99. [0.7500, 0.8125, 0.9375, 0.7500, 0.2500, 0.0000],
  100. [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])
  101. >>> m = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
  102. >>> # Notice that values in top left corner are now changed
  103. >>> m(input_3x3)
  104. tensor([[[[1.0000, 1.4000, 1.8000, 1.6000, 0.8000, 0.0000],
  105. [1.8000, 2.2000, 2.6000, 2.2400, 1.1200, 0.0000],
  106. [2.6000, 3.0000, 3.4000, 2.8800, 1.4400, 0.0000],
  107. [2.4000, 2.7200, 3.0400, 2.5600, 1.2800, 0.0000],
  108. [1.2000, 1.3600, 1.5200, 1.2800, 0.6400, 0.0000],
  109. [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])
  110. """
  111. __constants__ = [
  112. "size",
  113. "scale_factor",
  114. "mode",
  115. "align_corners",
  116. "name",
  117. "recompute_scale_factor",
  118. ]
  119. name: str
  120. size: _size_any_t | None
  121. scale_factor: _ratio_any_t | None
  122. mode: str
  123. align_corners: bool | None
  124. recompute_scale_factor: bool | None
  125. def __init__(
  126. self,
  127. size: _size_any_t | None = None,
  128. scale_factor: _ratio_any_t | None = None,
  129. mode: str = "nearest",
  130. align_corners: bool | None = None,
  131. recompute_scale_factor: bool | None = None,
  132. ) -> None:
  133. super().__init__()
  134. self.name = type(self).__name__
  135. self.size = size
  136. if isinstance(scale_factor, tuple):
  137. self.scale_factor = tuple(float(factor) for factor in scale_factor)
  138. else:
  139. self.scale_factor = float(scale_factor) if scale_factor else None
  140. self.mode = mode
  141. self.align_corners = align_corners
  142. self.recompute_scale_factor = recompute_scale_factor
  143. def forward(self, input: Tensor) -> Tensor:
  144. """
  145. Runs the forward pass.
  146. """
  147. return F.interpolate(
  148. input,
  149. self.size,
  150. self.scale_factor,
  151. self.mode,
  152. self.align_corners,
  153. recompute_scale_factor=self.recompute_scale_factor,
  154. )
  155. def __setstate__(self, state):
  156. if "recompute_scale_factor" not in state:
  157. state["recompute_scale_factor"] = True
  158. super().__setstate__(state)
  159. def extra_repr(self) -> str:
  160. """
  161. Return the extra representation of the module.
  162. """
  163. if self.scale_factor is not None:
  164. info = "scale_factor=" + repr(self.scale_factor)
  165. else:
  166. info = "size=" + repr(self.size)
  167. info += ", mode=" + repr(self.mode)
  168. return info
  169. class UpsamplingNearest2d(Upsample):
  170. r"""Applies a 2D nearest neighbor upsampling to an input signal composed of several input channels.
  171. To specify the scale, it takes either the :attr:`size` or the :attr:`scale_factor`
  172. as it's constructor argument.
  173. When :attr:`size` is given, it is the output size of the image `(h, w)`.
  174. Args:
  175. size (int or Tuple[int, int], optional): output spatial sizes
  176. scale_factor (float or Tuple[float, float], optional): multiplier for
  177. spatial size.
  178. .. warning::
  179. This class is deprecated in favor of :func:`~nn.functional.interpolate`.
  180. Shape:
  181. - Input: :math:`(N, C, H_{in}, W_{in})`
  182. - Output: :math:`(N, C, H_{out}, W_{out})` where
  183. .. math::
  184. H_{out} = \left\lfloor H_{in} \times \text{scale\_factor} \right\rfloor
  185. .. math::
  186. W_{out} = \left\lfloor W_{in} \times \text{scale\_factor} \right\rfloor
  187. Examples::
  188. >>> input = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2)
  189. >>> input
  190. tensor([[[[1., 2.],
  191. [3., 4.]]]])
  192. >>> m = nn.UpsamplingNearest2d(scale_factor=2)
  193. >>> m(input)
  194. tensor([[[[1., 1., 2., 2.],
  195. [1., 1., 2., 2.],
  196. [3., 3., 4., 4.],
  197. [3., 3., 4., 4.]]]])
  198. """
  199. def __init__(
  200. self,
  201. size: _size_2_t | None = None,
  202. scale_factor: _ratio_2_t | None = None,
  203. ) -> None:
  204. super().__init__(size, scale_factor, mode="nearest")
  205. class UpsamplingBilinear2d(Upsample):
  206. r"""Applies a 2D bilinear upsampling to an input signal composed of several input channels.
  207. To specify the scale, it takes either the :attr:`size` or the :attr:`scale_factor`
  208. as it's constructor argument.
  209. When :attr:`size` is given, it is the output size of the image `(h, w)`.
  210. Args:
  211. size (int or Tuple[int, int], optional): output spatial sizes
  212. scale_factor (float or Tuple[float, float], optional): multiplier for
  213. spatial size.
  214. .. warning::
  215. This class is deprecated in favor of :func:`~nn.functional.interpolate`. It is
  216. equivalent to ``nn.functional.interpolate(..., mode='bilinear', align_corners=True)``.
  217. Shape:
  218. - Input: :math:`(N, C, H_{in}, W_{in})`
  219. - Output: :math:`(N, C, H_{out}, W_{out})` where
  220. .. math::
  221. H_{out} = \left\lfloor H_{in} \times \text{scale\_factor} \right\rfloor
  222. .. math::
  223. W_{out} = \left\lfloor W_{in} \times \text{scale\_factor} \right\rfloor
  224. Examples::
  225. >>> input = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2)
  226. >>> input
  227. tensor([[[[1., 2.],
  228. [3., 4.]]]])
  229. >>> # xdoctest: +IGNORE_WANT("do other tests modify the global state?")
  230. >>> m = nn.UpsamplingBilinear2d(scale_factor=2)
  231. >>> m(input)
  232. tensor([[[[1.0000, 1.3333, 1.6667, 2.0000],
  233. [1.6667, 2.0000, 2.3333, 2.6667],
  234. [2.3333, 2.6667, 3.0000, 3.3333],
  235. [3.0000, 3.3333, 3.6667, 4.0000]]]])
  236. """
  237. def __init__(
  238. self,
  239. size: _size_2_t | None = None,
  240. scale_factor: _ratio_2_t | None = None,
  241. ) -> None:
  242. super().__init__(size, scale_factor, mode="bilinear", align_corners=True)