| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550 |
- import torch.nn.functional as F
- from torch import Tensor
- from torch.nn.common_types import (
- _ratio_2_t,
- _ratio_3_t,
- _size_1_t,
- _size_2_opt_t,
- _size_2_t,
- _size_3_opt_t,
- _size_3_t,
- _size_any_opt_t,
- _size_any_t,
- )
- from .module import Module
- from .utils import _pair, _single, _triple
# Names exported by ``from <this module> import *``, grouped by pooling family.
__all__ = [
    # Max pooling and its partial inverses.
    "MaxPool1d",
    "MaxPool2d",
    "MaxPool3d",
    "MaxUnpool1d",
    "MaxUnpool2d",
    "MaxUnpool3d",
    # Average pooling.
    "AvgPool1d",
    "AvgPool2d",
    "AvgPool3d",
    # Fractional max pooling.
    "FractionalMaxPool2d",
    "FractionalMaxPool3d",
    # Power-average (Lp) pooling.
    "LPPool1d",
    "LPPool2d",
    "LPPool3d",
    # Adaptive pooling (output size fixed, window derived from input).
    "AdaptiveMaxPool1d",
    "AdaptiveMaxPool2d",
    "AdaptiveMaxPool3d",
    "AdaptiveAvgPool1d",
    "AdaptiveAvgPool2d",
    "AdaptiveAvgPool3d",
]
- class _MaxPoolNd(Module):
- __constants__ = [
- "kernel_size",
- "stride",
- "padding",
- "dilation",
- "return_indices",
- "ceil_mode",
- ]
- return_indices: bool
- ceil_mode: bool
- def __init__(
- self,
- kernel_size: _size_any_t,
- stride: _size_any_t | None = None,
- padding: _size_any_t = 0,
- dilation: _size_any_t = 1,
- return_indices: bool = False,
- ceil_mode: bool = False,
- ) -> None:
- super().__init__()
- self.kernel_size = kernel_size
- self.stride = stride if (stride is not None) else kernel_size
- self.padding = padding
- self.dilation = dilation
- self.return_indices = return_indices
- self.ceil_mode = ceil_mode
- def extra_repr(self) -> str:
- return (
- "kernel_size={kernel_size}, stride={stride}, padding={padding}"
- ", dilation={dilation}, ceil_mode={ceil_mode}".format(**self.__dict__)
- )
class MaxPool1d(_MaxPoolNd):
    r"""Applies a 1D max pooling over an input signal composed of several input planes.

    In the simplest case, the output value of the layer with input size :math:`(N, C, L)`
    and output :math:`(N, C, L_{out})` can be precisely described as:

    .. math::
        out(N_i, C_j, k) = \max_{m=0, \ldots, \text{kernel\_size} - 1}
                input(N_i, C_j, stride \times k + m)

    If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides
    for :attr:`padding` number of points. :attr:`dilation` is the stride between the elements within the
    sliding window. This `link`_ has a nice visualization of the pooling parameters.

    Note:
        When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding
        or the input. Sliding windows that would start in the right padded region are ignored.

    Args:
        kernel_size: The size of the sliding window, must be > 0.
        stride: The stride of the sliding window, must be > 0. Default value is :attr:`kernel_size`.
        padding: Implicit negative infinity padding to be added on both sides, must be >= 0 and <= kernel_size / 2.
        dilation: The stride between elements within a sliding window, must be > 0.
        return_indices: If ``True``, will return the argmax along with the max values.
                        Useful for :class:`torch.nn.MaxUnpool1d` later
        ceil_mode: If ``True``, will use `ceil` instead of `floor` to compute the output shape. This
                   ensures that every element in the input tensor is covered by a sliding window.

    Shape:
        - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`.
        - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`,

          where ``ceil_mode = False``

          .. math::
              L_{out} = \left\lfloor \frac{L_{in} + 2 \times \text{padding} - \text{dilation}
                    \times (\text{kernel\_size} - 1) - 1}{\text{stride}}\right\rfloor + 1

          where ``ceil_mode = True``

          .. math::
              L_{out} = \left\lceil \frac{L_{in} + 2 \times \text{padding} - \text{dilation}
                    \times (\text{kernel\_size} - 1) - 1}{\text{stride}}\right\rceil + 1

        - With ``ceil_mode = True``, the last window must start inside the input or the left padding.
          If :math:`(L_{out} - 1) \times \text{stride} \geq L_{in} + \text{padding}`, that window would
          start in the right padded region. It is dropped and :math:`L_{out}` is reduced by one.

    Examples::

        >>> # pool of size=3, stride=2
        >>> m = nn.MaxPool1d(3, stride=2)
        >>> input = torch.randn(20, 16, 50)
        >>> output = m(input)

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """

    kernel_size: _size_1_t
    stride: _size_1_t
    padding: _size_1_t
    dilation: _size_1_t

    def forward(self, input: Tensor):
        """Runs the forward pass.

        Returns the pooled tensor. When ``return_indices`` is ``True`` it returns a
        ``(values, indices)`` pair instead, which is why no return annotation is given.
        """
        return F.max_pool1d(
            input,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            ceil_mode=self.ceil_mode,
            return_indices=self.return_indices,
        )
class MaxPool2d(_MaxPoolNd):
    r"""Applies a 2D max pooling over an input signal composed of several input planes.

    In the simplest case, the output value of the layer with input size :math:`(N, C, H, W)`,
    output :math:`(N, C, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kH, kW)`
    can be precisely described as:

    .. math::
        \begin{aligned}
            out(N_i, C_j, h, w) ={} & \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
                                    & \text{input}(N_i, C_j, \text{stride[0]} \times h + m,
                                                   \text{stride[1]} \times w + n)
        \end{aligned}

    If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides
    for :attr:`padding` number of points. :attr:`dilation` controls the spacing between the kernel points.
    It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.

    Note:
        When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding
        or the input. Sliding windows that would start in the right padded region are ignored.

    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:

        - a single ``int`` -- in which case the same value is used for the height and width dimension
        - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
          and the second `int` for the width dimension

    Args:
        kernel_size: the size of the window to take a max over
        stride: the stride of the window. Default value is :attr:`kernel_size`
        padding: Implicit negative infinity padding to be added on both sides
        dilation: a parameter that controls the stride of elements in the window
        return_indices: if ``True``, will return the max indices along with the outputs.
                        Useful for :class:`torch.nn.MaxUnpool2d` later
        ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`
        - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding[0]} - \text{dilation[0]}
                    \times (\text{kernel\_size[0]} - 1) - 1}{\text{stride[0]}} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding[1]} - \text{dilation[1]}
                    \times (\text{kernel\_size[1]} - 1) - 1}{\text{stride[1]}} + 1\right\rfloor

    Examples::

        >>> # pool of square window of size=3, stride=2
        >>> m = nn.MaxPool2d(3, stride=2)
        >>> # pool of non-square window
        >>> m = nn.MaxPool2d((3, 2), stride=(2, 1))
        >>> input = torch.randn(20, 16, 50, 32)
        >>> output = m(input)

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """

    kernel_size: _size_2_t
    stride: _size_2_t
    padding: _size_2_t
    dilation: _size_2_t

    def forward(self, input: Tensor):
        """Runs the forward pass."""
        # Hyper-parameters are passed by keyword so each functional argument is
        # visibly paired with the module attribute that feeds it.
        return F.max_pool2d(
            input,
            self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            ceil_mode=self.ceil_mode,
            return_indices=self.return_indices,
        )
class MaxPool3d(_MaxPoolNd):
    r"""Applies a 3D max pooling over an input signal composed of several input planes.

    In the simplest case, the output value of the layer with input size :math:`(N, C, D, H, W)`,
    output :math:`(N, C, D_{out}, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kD, kH, kW)`
    can be precisely described as:

    .. math::
        \begin{aligned}
            \text{out}(N_i, C_j, d, h, w) ={} & \max_{k=0, \ldots, kD-1} \max_{m=0, \ldots, kH-1} \max_{n=0, \ldots, kW-1} \\
                                              & \text{input}(N_i, C_j, \text{stride[0]} \times d + k,
                                                             \text{stride[1]} \times h + m, \text{stride[2]} \times w + n)
        \end{aligned}

    If :attr:`padding` is non-zero, then the input is implicitly padded with negative infinity on both sides
    of each of the three spatial dimensions for :attr:`padding` number of points. :attr:`dilation` controls
    the spacing between the kernel points. It is harder to describe, but this `link`_ has a nice
    visualization of what :attr:`dilation` does.

    Note:
        When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding
        or the input. Sliding windows that would start in the right padded region are ignored.

    The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, :attr:`dilation` can either be:

        - a single ``int`` -- in which case the same value is used for the depth, height and width dimension
        - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
          the second `int` for the height dimension and the third `int` for the width dimension

    Args:
        kernel_size: the size of the window to take a max over
        stride: the stride of the window. Default value is :attr:`kernel_size`
        padding: Implicit negative infinity padding to be added on both sides of each spatial dimension
        dilation: a parameter that controls the stride of elements in the window
        return_indices: if ``True``, will return the max indices along with the outputs.
                        Useful for :class:`torch.nn.MaxUnpool3d` later
        ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape

    Shape:
        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
        - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`, where

          .. math::
              D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] \times
                (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] \times
                (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] - \text{dilation}[2] \times
                (\text{kernel\_size}[2] - 1) - 1}{\text{stride}[2]} + 1\right\rfloor

          With ``ceil_mode = True`` the floor is replaced by a ceiling. In each dimension, a trailing
          window that would start in the right padded region is then dropped (see the note above).

    Examples::

        >>> # pool of cubic window of size=3, stride=2
        >>> m = nn.MaxPool3d(3, stride=2)
        >>> # pool of non-cubic window
        >>> m = nn.MaxPool3d((3, 2, 2), stride=(2, 1, 2))
        >>> input = torch.randn(20, 16, 50, 44, 31)
        >>> output = m(input)

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
    """

    kernel_size: _size_3_t
    stride: _size_3_t
    padding: _size_3_t
    dilation: _size_3_t

    def forward(self, input: Tensor):
        """Runs the forward pass.

        Returns the pooled tensor. When ``return_indices`` is ``True`` it returns a
        ``(values, indices)`` pair instead, which is why no return annotation is given.
        """
        return F.max_pool3d(
            input,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            ceil_mode=self.ceil_mode,
            return_indices=self.return_indices,
        )
class _MaxUnpoolNd(Module):
    """Shared base for the ``MaxUnpool{1,2,3}d`` modules; supplies ``extra_repr``.

    Subclasses are expected to set ``kernel_size``, ``stride`` and ``padding``.
    """

    def extra_repr(self) -> str:
        fields = (
            ("kernel_size", self.kernel_size),
            ("stride", self.stride),
            ("padding", self.padding),
        )
        return ", ".join(f"{key}={value}" for key, value in fields)
- class MaxUnpool1d(_MaxUnpoolNd):
- r"""Computes a partial inverse of :class:`MaxPool1d`.
- :class:`MaxPool1d` is not fully invertible, since the non-maximal values are lost.
- :class:`MaxUnpool1d` takes in as input the output of :class:`MaxPool1d`
- including the indices of the maximal values and computes a partial inverse
- in which all non-maximal values are set to zero.
- Note:
- This operation may behave nondeterministically when the input indices has repeat values.
- See https://github.com/pytorch/pytorch/issues/80827 and :doc:`/notes/randomness` for more information.
- .. note:: :class:`MaxPool1d` can map several input sizes to the same output
- sizes. Hence, the inversion process can get ambiguous.
- To accommodate this, you can provide the needed output size
- as an additional argument :attr:`output_size` in the forward call.
- See the Inputs and Example below.
- Args:
- kernel_size (int or tuple): Size of the max pooling window.
- stride (int or tuple): Stride of the max pooling window.
- It is set to :attr:`kernel_size` by default.
- padding (int or tuple): Padding that was added to the input
- Inputs:
- - `input`: the input Tensor to invert
- - `indices`: the indices given out by :class:`~torch.nn.MaxPool1d`
- - `output_size` (optional): the targeted output size
- Shape:
- - Input: :math:`(N, C, H_{in})` or :math:`(C, H_{in})`.
- - Output: :math:`(N, C, H_{out})` or :math:`(C, H_{out})`, where
- .. math::
- H_{out} = (H_{in} - 1) \times \text{stride}[0] - 2 \times \text{padding}[0] + \text{kernel\_size}[0]
- or as given by :attr:`output_size` in the call operator
- Example::
- >>> # xdoctest: +IGNORE_WANT("do other tests modify the global state?")
- >>> pool = nn.MaxPool1d(2, stride=2, return_indices=True)
- >>> unpool = nn.MaxUnpool1d(2, stride=2)
- >>> input = torch.tensor([[[1., 2, 3, 4, 5, 6, 7, 8]]])
- >>> output, indices = pool(input)
- >>> unpool(output, indices)
- tensor([[[ 0., 2., 0., 4., 0., 6., 0., 8.]]])
- >>> # Example showcasing the use of output_size
- >>> input = torch.tensor([[[1., 2, 3, 4, 5, 6, 7, 8, 9]]])
- >>> output, indices = pool(input)
- >>> unpool(output, indices, output_size=input.size())
- tensor([[[ 0., 2., 0., 4., 0., 6., 0., 8., 0.]]])
- >>> unpool(output, indices)
- tensor([[[ 0., 2., 0., 4., 0., 6., 0., 8.]]])
- """
- kernel_size: _size_1_t
- stride: _size_1_t
- padding: _size_1_t
- def __init__(
- self,
- kernel_size: _size_1_t,
- stride: _size_1_t | None = None,
- padding: _size_1_t = 0,
- ) -> None:
- super().__init__()
- self.kernel_size = _single(kernel_size)
- self.stride = _single(stride if (stride is not None) else kernel_size)
- self.padding = _single(padding)
- def forward(
- self, input: Tensor, indices: Tensor, output_size: list[int] | None = None
- ) -> Tensor:
- """Runs the forward pass."""
- return F.max_unpool1d(
- input, indices, self.kernel_size, self.stride, self.padding, output_size
- )
- class MaxUnpool2d(_MaxUnpoolNd):
- r"""Computes a partial inverse of :class:`MaxPool2d`.
- :class:`MaxPool2d` is not fully invertible, since the non-maximal values are lost.
- :class:`MaxUnpool2d` takes in as input the output of :class:`MaxPool2d`
- including the indices of the maximal values and computes a partial inverse
- in which all non-maximal values are set to zero.
- Note:
- This operation may behave nondeterministically when the input indices has repeat values.
- See https://github.com/pytorch/pytorch/issues/80827 and :doc:`/notes/randomness` for more information.
- .. note:: :class:`MaxPool2d` can map several input sizes to the same output
- sizes. Hence, the inversion process can get ambiguous.
- To accommodate this, you can provide the needed output size
- as an additional argument :attr:`output_size` in the forward call.
- See the Inputs and Example below.
- Args:
- kernel_size (int or tuple): Size of the max pooling window.
- stride (int or tuple): Stride of the max pooling window.
- It is set to :attr:`kernel_size` by default.
- padding (int or tuple): Padding that was added to the input
- Inputs:
- - `input`: the input Tensor to invert
- - `indices`: the indices given out by :class:`~torch.nn.MaxPool2d`
- - `output_size` (optional): the targeted output size
- Shape:
- - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
- - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where
- .. math::
- H_{out} = (H_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]}
- .. math::
- W_{out} = (W_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]}
- or as given by :attr:`output_size` in the call operator
- Example::
- >>> pool = nn.MaxPool2d(2, stride=2, return_indices=True)
- >>> unpool = nn.MaxUnpool2d(2, stride=2)
- >>> input = torch.tensor([[[[ 1., 2., 3., 4.],
- [ 5., 6., 7., 8.],
- [ 9., 10., 11., 12.],
- [13., 14., 15., 16.]]]])
- >>> output, indices = pool(input)
- >>> unpool(output, indices)
- tensor([[[[ 0., 0., 0., 0.],
- [ 0., 6., 0., 8.],
- [ 0., 0., 0., 0.],
- [ 0., 14., 0., 16.]]]])
- >>> # Now using output_size to resolve an ambiguous size for the inverse
- >>> input = torch.tensor([[[[ 1., 2., 3., 4., 5.],
- [ 6., 7., 8., 9., 10.],
- [11., 12., 13., 14., 15.],
- [16., 17., 18., 19., 20.]]]])
- >>> output, indices = pool(input)
- >>> # This call will not work without specifying output_size
- >>> unpool(output, indices, output_size=input.size())
- tensor([[[[ 0., 0., 0., 0., 0.],
- [ 0., 7., 0., 9., 0.],
- [ 0., 0., 0., 0., 0.],
- [ 0., 17., 0., 19., 0.]]]])
- """
- kernel_size: _size_2_t
- stride: _size_2_t
- padding: _size_2_t
- def __init__(
- self,
- kernel_size: _size_2_t,
- stride: _size_2_t | None = None,
- padding: _size_2_t = 0,
- ) -> None:
- super().__init__()
- self.kernel_size = _pair(kernel_size)
- self.stride = _pair(stride if (stride is not None) else kernel_size)
- self.padding = _pair(padding)
- def forward(
- self, input: Tensor, indices: Tensor, output_size: list[int] | None = None
- ) -> Tensor:
- """Runs the forward pass."""
- return F.max_unpool2d(
- input, indices, self.kernel_size, self.stride, self.padding, output_size
- )
- class MaxUnpool3d(_MaxUnpoolNd):
- r"""Computes a partial inverse of :class:`MaxPool3d`.
- :class:`MaxPool3d` is not fully invertible, since the non-maximal values are lost.
- :class:`MaxUnpool3d` takes in as input the output of :class:`MaxPool3d`
- including the indices of the maximal values and computes a partial inverse
- in which all non-maximal values are set to zero.
- Note:
- This operation may behave nondeterministically when the input indices has repeat values.
- See https://github.com/pytorch/pytorch/issues/80827 and :doc:`/notes/randomness` for more information.
- .. note:: :class:`MaxPool3d` can map several input sizes to the same output
- sizes. Hence, the inversion process can get ambiguous.
- To accommodate this, you can provide the needed output size
- as an additional argument :attr:`output_size` in the forward call.
- See the Inputs section below.
- Args:
- kernel_size (int or tuple): Size of the max pooling window.
- stride (int or tuple): Stride of the max pooling window.
- It is set to :attr:`kernel_size` by default.
- padding (int or tuple): Padding that was added to the input
- Inputs:
- - `input`: the input Tensor to invert
- - `indices`: the indices given out by :class:`~torch.nn.MaxPool3d`
- - `output_size` (optional): the targeted output size
- Shape:
- - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
- - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`, where
- .. math::
- D_{out} = (D_{in} - 1) \times \text{stride[0]} - 2 \times \text{padding[0]} + \text{kernel\_size[0]}
- .. math::
- H_{out} = (H_{in} - 1) \times \text{stride[1]} - 2 \times \text{padding[1]} + \text{kernel\_size[1]}
- .. math::
- W_{out} = (W_{in} - 1) \times \text{stride[2]} - 2 \times \text{padding[2]} + \text{kernel\_size[2]}
- or as given by :attr:`output_size` in the call operator
- Example::
- >>> # pool of square window of size=3, stride=2
- >>> pool = nn.MaxPool3d(3, stride=2, return_indices=True)
- >>> unpool = nn.MaxUnpool3d(3, stride=2)
- >>> output, indices = pool(torch.randn(20, 16, 51, 33, 15))
- >>> unpooled_output = unpool(output, indices)
- >>> unpooled_output.size()
- torch.Size([20, 16, 51, 33, 15])
- """
- kernel_size: _size_3_t
- stride: _size_3_t
- padding: _size_3_t
- def __init__(
- self,
- kernel_size: _size_3_t,
- stride: _size_3_t | None = None,
- padding: _size_3_t = 0,
- ) -> None:
- super().__init__()
- self.kernel_size = _triple(kernel_size)
- self.stride = _triple(stride if (stride is not None) else kernel_size)
- self.padding = _triple(padding)
- def forward(
- self, input: Tensor, indices: Tensor, output_size: list[int] | None = None
- ) -> Tensor:
- """Runs the forward pass."""
- return F.max_unpool3d(
- input, indices, self.kernel_size, self.stride, self.padding, output_size
- )
class _AvgPoolNd(Module):
    """Shared base for the ``AvgPool{1,2,3}d`` modules.

    Lists the pooling hyper-parameters in ``__constants__`` and provides the
    common ``extra_repr``. Subclasses are expected to set ``kernel_size``,
    ``stride`` and ``padding``.
    """

    __constants__ = [
        "kernel_size",
        "stride",
        "padding",
        "ceil_mode",
        "count_include_pad",
    ]

    def extra_repr(self) -> str:
        pieces = [
            f"kernel_size={self.kernel_size}",
            f"stride={self.stride}",
            f"padding={self.padding}",
        ]
        return ", ".join(pieces)
- class AvgPool1d(_AvgPoolNd):
- r"""Applies a 1D average pooling over an input signal composed of several input planes.
- In the simplest case, the output value of the layer with input size :math:`(N, C, L)`,
- output :math:`(N, C, L_{out})` and :attr:`kernel_size` :math:`k`
- can be precisely described as:
- .. math::
- \text{out}(N_i, C_j, l) = \frac{1}{k} \sum_{m=0}^{k-1}
- \text{input}(N_i, C_j, \text{stride} \times l + m)
- If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
- for :attr:`padding` number of points.
- Note:
- When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding
- or the input. Sliding windows that would start in the right padded region are ignored.
- .. note::
- pad should be at most half of effective kernel size.
- The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding` can each be
- an ``int`` or a one-element tuple.
- Args:
- kernel_size: the size of the window
- stride: the stride of the window. Default value is :attr:`kernel_size`
- padding: implicit zero padding to be added on both sides
- ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
- count_include_pad: when True, will include the zero-padding in the averaging calculation
- Shape:
- - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`.
- - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, where
- .. math::
- L_{out} = \left\lfloor \frac{L_{in} +
- 2 \times \text{padding} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor
- Per the note above, if ``ceil_mode`` is True and :math:`(L_{out} - 1) \times \text{stride} \geq L_{in}
- + \text{padding}`, we skip the last window as it would start in the right padded region, resulting in
- :math:`L_{out}` being reduced by one.
- Examples::
- >>> # pool with window of size=3, stride=2
- >>> m = nn.AvgPool1d(3, stride=2)
- >>> m(torch.tensor([[[1., 2, 3, 4, 5, 6, 7]]]))
- tensor([[[2., 4., 6.]]])
- """
- kernel_size: _size_1_t
- stride: _size_1_t
- padding: _size_1_t
- ceil_mode: bool
- count_include_pad: bool
- def __init__(
- self,
- kernel_size: _size_1_t,
- stride: _size_1_t = None,
- padding: _size_1_t = 0,
- ceil_mode: bool = False,
- count_include_pad: bool = True,
- ) -> None:
- super().__init__()
- self.kernel_size = _single(kernel_size)
- self.stride = _single(stride if stride is not None else kernel_size)
- self.padding = _single(padding)
- self.ceil_mode = ceil_mode
- self.count_include_pad = count_include_pad
- def forward(self, input: Tensor) -> Tensor:
- """Runs the forward pass."""
- return F.avg_pool1d(
- input,
- self.kernel_size,
- self.stride,
- self.padding,
- self.ceil_mode,
- self.count_include_pad,
- )
- class AvgPool2d(_AvgPoolNd):
- r"""Applies a 2D average pooling over an input signal composed of several input planes.
- In the simplest case, the output value of the layer with input size :math:`(N, C, H, W)`,
- output :math:`(N, C, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kH, kW)`
- can be precisely described as:
- .. math::
- out(N_i, C_j, h, w) = \frac{1}{kH * kW} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1}
- input(N_i, C_j, stride[0] \times h + m, stride[1] \times w + n)
- If :attr:`padding` is non-zero, then the input is implicitly zero-padded on both sides
- for :attr:`padding` number of points.
- Note:
- When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding
- or the input. Sliding windows that would start in the right padded region are ignored.
- .. note::
- pad should be at most half of effective kernel size.
- The parameters :attr:`kernel_size`, :attr:`stride`, :attr:`padding` can either be:
- - a single ``int`` or a single-element tuple -- in which case the same value is used for the height and width dimension
- - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
- and the second `int` for the width dimension
- Args:
- kernel_size: the size of the window
- stride: the stride of the window. Default value is :attr:`kernel_size`
- padding: implicit zero padding to be added on both sides
- ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
- count_include_pad: when True, will include the zero-padding in the averaging calculation
- divisor_override: if specified, it will be used as divisor, otherwise size of the pooling region will be used.
- Shape:
- - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
- - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where
- .. math::
- H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] -
- \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
- .. math::
- W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] -
- \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
- Per the note above, if ``ceil_mode`` is True and :math:`(H_{out} - 1)\times \text{stride}[0]\geq H_{in}
- + \text{padding}[0]`, we skip the last window as it would start in the bottom padded region,
- resulting in :math:`H_{out}` being reduced by one.
- The same applies for :math:`W_{out}`.
- Examples::
- >>> # pool of square window of size=3, stride=2
- >>> m = nn.AvgPool2d(3, stride=2)
- >>> # pool of non-square window
- >>> m = nn.AvgPool2d((3, 2), stride=(2, 1))
- >>> input = torch.randn(20, 16, 50, 32)
- >>> output = m(input)
- """
- __constants__ = [
- "kernel_size",
- "stride",
- "padding",
- "ceil_mode",
- "count_include_pad",
- "divisor_override",
- ]
- kernel_size: _size_2_t
- stride: _size_2_t
- padding: _size_2_t
- ceil_mode: bool
- count_include_pad: bool
- def __init__(
- self,
- kernel_size: _size_2_t,
- stride: _size_2_t | None = None,
- padding: _size_2_t = 0,
- ceil_mode: bool = False,
- count_include_pad: bool = True,
- divisor_override: int | None = None,
- ) -> None:
- super().__init__()
- self.kernel_size = kernel_size
- self.stride = stride if (stride is not None) else kernel_size
- self.padding = padding
- self.ceil_mode = ceil_mode
- self.count_include_pad = count_include_pad
- self.divisor_override = divisor_override
- def forward(self, input: Tensor) -> Tensor:
- """Runs the forward pass."""
- return F.avg_pool2d(
- input,
- self.kernel_size,
- self.stride,
- self.padding,
- self.ceil_mode,
- self.count_include_pad,
- self.divisor_override,
- )
- class AvgPool3d(_AvgPoolNd):
- r"""Applies a 3D average pooling over an input signal composed of several input planes.
- In the simplest case, the output value of the layer with input size :math:`(N, C, D, H, W)`,
- output :math:`(N, C, D_{out}, H_{out}, W_{out})` and :attr:`kernel_size` :math:`(kD, kH, kW)`
- can be precisely described as:
- .. math::
- \begin{aligned}
- \text{out}(N_i, C_j, d, h, w) ={} & \sum_{k=0}^{kD-1} \sum_{m=0}^{kH-1} \sum_{n=0}^{kW-1} \\
- & \frac{\text{input}(N_i, C_j, \text{stride}[0] \times d + k,
- \text{stride}[1] \times h + m, \text{stride}[2] \times w + n)}
- {kD \times kH \times kW}
- \end{aligned}
- If :attr:`padding` is non-zero, then the input is implicitly zero-padded on all three sides
- for :attr:`padding` number of points.
- Note:
- When ceil_mode=True, sliding windows are allowed to go off-bounds if they start within the left padding
- or the input. Sliding windows that would start in the right padded region are ignored.
- .. note::
- pad should be at most half of effective kernel size.
- The parameters :attr:`kernel_size`, :attr:`stride` can either be:
- - a single ``int`` -- in which case the same value is used for the depth, height and width dimension
- - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
- the second `int` for the height dimension and the third `int` for the width dimension
- Args:
- kernel_size: the size of the window
- stride: the stride of the window. Default value is :attr:`kernel_size`
- padding: implicit zero padding to be added on all three sides
- ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
- count_include_pad: when True, will include the zero-padding in the averaging calculation
- divisor_override: if specified, it will be used as divisor, otherwise :attr:`kernel_size` will be used
- Shape:
- - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
- - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or
- :math:`(C, D_{out}, H_{out}, W_{out})`, where
- .. math::
- D_{out} = \left\lfloor\frac{D_{in} + 2 \times \text{padding}[0] -
- \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
- .. math::
- H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[1] -
- \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
- .. math::
- W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[2] -
- \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor
- Per the note above, if ``ceil_mode`` is True and :math:`(D_{out} - 1)\times \text{stride}[0]\geq D_{in}
- + \text{padding}[0]`, we skip the last window as it would start in the padded region,
- resulting in :math:`D_{out}` being reduced by one.
- The same applies for :math:`W_{out}` and :math:`H_{out}`.
- Examples::
- >>> # pool of square window of size=3, stride=2
- >>> m = nn.AvgPool3d(3, stride=2)
- >>> # pool of non-square window
- >>> m = nn.AvgPool3d((3, 2, 2), stride=(2, 1, 2))
- >>> input = torch.randn(20, 16, 50, 44, 31)
- >>> output = m(input)
- """
- __constants__ = [
- "kernel_size",
- "stride",
- "padding",
- "ceil_mode",
- "count_include_pad",
- "divisor_override",
- ]
- kernel_size: _size_3_t
- stride: _size_3_t
- padding: _size_3_t
- ceil_mode: bool
- count_include_pad: bool
- def __init__(
- self,
- kernel_size: _size_3_t,
- stride: _size_3_t | None = None,
- padding: _size_3_t = 0,
- ceil_mode: bool = False,
- count_include_pad: bool = True,
- divisor_override: int | None = None,
- ) -> None:
- super().__init__()
- self.kernel_size = kernel_size
- self.stride = stride if (stride is not None) else kernel_size
- self.padding = padding
- self.ceil_mode = ceil_mode
- self.count_include_pad = count_include_pad
- self.divisor_override = divisor_override
- def forward(self, input: Tensor) -> Tensor:
- """Runs the forward pass."""
- return F.avg_pool3d(
- input,
- self.kernel_size,
- self.stride,
- self.padding,
- self.ceil_mode,
- self.count_include_pad,
- self.divisor_override,
- )
- def __setstate__(self, d):
- super().__setstate__(d)
- self.__dict__.setdefault("padding", 0)
- self.__dict__.setdefault("ceil_mode", False)
- self.__dict__.setdefault("count_include_pad", True)
- class FractionalMaxPool2d(Module):
- r"""Applies a 2D fractional max pooling over an input signal composed of several input planes.
- Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham
- The max-pooling operation is applied in :math:`kH \times kW` regions by a stochastic
- step size determined by the target output size.
- The number of output features is equal to the number of input planes.
- .. note:: Exactly one of ``output_size`` or ``output_ratio`` must be defined.
- Args:
- kernel_size: the size of the window to take a max over.
- Can be a single number k (for a square kernel of k x k) or a tuple `(kh, kw)`
- output_size: the target output size of the image of the form `oH x oW`.
- Can be a tuple `(oH, oW)` or a single number oH for a square image `oH x oH`.
- Note that we must have :math:`kH + oH - 1 <= H_{in}` and :math:`kW + oW - 1 <= W_{in}`
- output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given.
- This has to be a number or tuple in the range (0, 1).
- Note that we must have :math:`kH + (output\_ratio\_H * H_{in}) - 1 <= H_{in}`
- and :math:`kW + (output\_ratio\_W * W_{in}) - 1 <= W_{in}`
- return_indices: if ``True``, will return the indices along with the outputs.
- Useful to pass to :meth:`nn.MaxUnpool2d`. Default: ``False``
- Shape:
- - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
- - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where
- :math:`(H_{out}, W_{out})=\text{output\_size}` or
- :math:`(H_{out}, W_{out})=\text{output\_ratio} \times (H_{in}, W_{in})`.
- Examples:
- >>> # pool of square window of size=3, and target output size 13x12
- >>> m = nn.FractionalMaxPool2d(3, output_size=(13, 12))
- >>> # pool of square window and target output size being half of input image size
- >>> m = nn.FractionalMaxPool2d(3, output_ratio=(0.5, 0.5))
- >>> input = torch.randn(20, 16, 50, 32)
- >>> output = m(input)
- .. _Fractional MaxPooling:
- https://arxiv.org/abs/1412.6071
- """
- __constants__ = ["kernel_size", "return_indices", "output_size", "output_ratio"]
- kernel_size: _size_2_t
- return_indices: bool
- output_size: _size_2_t
- output_ratio: _ratio_2_t
- def __init__(
- self,
- kernel_size: _size_2_t,
- output_size: _size_2_t | None = None,
- output_ratio: _ratio_2_t | None = None,
- return_indices: bool = False,
- _random_samples=None,
- ) -> None:
- super().__init__()
- self.kernel_size = _pair(kernel_size)
- self.return_indices = return_indices
- self.register_buffer("_random_samples", _random_samples)
- self.output_size = _pair(output_size) if output_size is not None else None
- self.output_ratio = _pair(output_ratio) if output_ratio is not None else None
- if output_size is None and output_ratio is None:
- raise ValueError(
- "FractionalMaxPool2d requires specifying either "
- "an output size, or a pooling ratio"
- )
- if output_size is not None and output_ratio is not None:
- raise ValueError(
- "only one of output_size and output_ratio may be specified"
- )
- if self.output_ratio is not None:
- if not (0 < self.output_ratio[0] < 1 and 0 < self.output_ratio[1] < 1):
- raise ValueError(
- f"output_ratio must be between 0 and 1 (got {output_ratio})"
- )
- def forward(self, input: Tensor):
- return F.fractional_max_pool2d(
- input,
- self.kernel_size,
- self.output_size,
- self.output_ratio,
- self.return_indices,
- _random_samples=self._random_samples,
- )
- class FractionalMaxPool3d(Module):
- r"""Applies a 3D fractional max pooling over an input signal composed of several input planes.
- Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham
- The max-pooling operation is applied in :math:`kT \times kH \times kW` regions by a stochastic
- step size determined by the target output size.
- The number of output features is equal to the number of input planes.
- .. note:: Exactly one of ``output_size`` or ``output_ratio`` must be defined.
- Args:
- kernel_size: the size of the window to take a max over.
- Can be a single number `k` (for a square kernel of `k x k x k`) or a tuple `(kt x kh x kw)`,
- `k` must greater than 0.
- output_size: the target output size of the image of the form `oT x oH x oW`.
- Can be a tuple `(oT, oH, oW)` or a single number oH for a square image `oH x oH x oH`
- output_ratio: If one wants to have an output size as a ratio of the input size, this option can be given.
- This has to be a number or tuple in the range (0, 1)
- return_indices: if ``True``, will return the indices along with the outputs.
- Useful to pass to :meth:`nn.MaxUnpool3d`. Default: ``False``
- Shape:
- - Input: :math:`(N, C, T_{in}, H_{in}, W_{in})` or :math:`(C, T_{in}, H_{in}, W_{in})`.
- - Output: :math:`(N, C, T_{out}, H_{out}, W_{out})` or :math:`(C, T_{out}, H_{out}, W_{out})`, where
- :math:`(T_{out}, H_{out}, W_{out})=\text{output\_size}` or
- :math:`(T_{out}, H_{out}, W_{out})=\text{output\_ratio} \times (T_{in}, H_{in}, W_{in})`
- Examples:
- >>> # pool of cubic window of size=3, and target output size 13x12x11
- >>> m = nn.FractionalMaxPool3d(3, output_size=(13, 12, 11))
- >>> # pool of cubic window and target output size being half of input size
- >>> m = nn.FractionalMaxPool3d(3, output_ratio=(0.5, 0.5, 0.5))
- >>> input = torch.randn(20, 16, 50, 32, 16)
- >>> output = m(input)
- .. _Fractional MaxPooling:
- https://arxiv.org/abs/1412.6071
- """
- __constants__ = ["kernel_size", "return_indices", "output_size", "output_ratio"]
- kernel_size: _size_3_t
- return_indices: bool
- output_size: _size_3_t
- output_ratio: _ratio_3_t
- def __init__(
- self,
- kernel_size: _size_3_t,
- output_size: _size_3_t | None = None,
- output_ratio: _ratio_3_t | None = None,
- return_indices: bool = False,
- _random_samples=None,
- ) -> None:
- super().__init__()
- if (isinstance(kernel_size, int) and kernel_size <= 0) or (
- isinstance(kernel_size, (tuple, list))
- and not all(k > 0 for k in kernel_size)
- ):
- raise ValueError(f"kernel_size must greater than 0, but got {kernel_size}")
- self.kernel_size = _triple(kernel_size)
- self.return_indices = return_indices
- self.register_buffer("_random_samples", _random_samples)
- self.output_size = _triple(output_size) if output_size is not None else None
- self.output_ratio = _triple(output_ratio) if output_ratio is not None else None
- if output_size is None and output_ratio is None:
- raise ValueError(
- "FractionalMaxPool3d requires specifying either "
- "an output size, or a pooling ratio"
- )
- if output_size is not None and output_ratio is not None:
- raise ValueError(
- "only one of output_size and output_ratio may be specified"
- )
- if self.output_ratio is not None:
- if not (
- 0 < self.output_ratio[0] < 1
- and 0 < self.output_ratio[1] < 1
- and 0 < self.output_ratio[2] < 1
- ):
- raise ValueError(
- f"output_ratio must be between 0 and 1 (got {output_ratio})"
- )
- def forward(self, input: Tensor):
- return F.fractional_max_pool3d(
- input,
- self.kernel_size,
- self.output_size,
- self.output_ratio,
- self.return_indices,
- _random_samples=self._random_samples,
- )
- class _LPPoolNd(Module):
- __constants__ = ["norm_type", "kernel_size", "stride", "ceil_mode"]
- norm_type: float
- ceil_mode: bool
- def __init__(
- self,
- norm_type: float,
- kernel_size: _size_any_t,
- stride: _size_any_t | None = None,
- ceil_mode: bool = False,
- ) -> None:
- super().__init__()
- self.norm_type = norm_type
- self.kernel_size = kernel_size
- self.stride = stride
- self.ceil_mode = ceil_mode
- def extra_repr(self) -> str:
- return (
- "norm_type={norm_type}, kernel_size={kernel_size}, stride={stride}, "
- "ceil_mode={ceil_mode}".format(**self.__dict__)
- )
class LPPool1d(_LPPoolNd):
    r"""Applies a 1D power-average pooling over an input signal composed of several input planes.

    On each window, the function computed is:

    .. math::
        f(X) = \sqrt[p]{\sum_{x \in X} x^{p}}

    - At p = :math:`\infty`, one gets Max Pooling
    - At p = 1, one gets Sum Pooling (which is proportional to Average Pooling)

    .. note:: If the sum to the power of `p` is zero, the gradient of this function is
              not defined. This implementation will set the gradient to zero in this case.

    Args:
        kernel_size: a single int, the size of the window
        stride: a single int, the stride of the window. Default value is :attr:`kernel_size`
        ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape

    Note:
        When :attr:`ceil_mode` is ``True``, sliding windows may go off-bounds if they start within
        the input. Sliding windows that would start past the end of the input are ignored.

    Shape:
        - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`.
        - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, where

          .. math::
              L_{out} = \left\lfloor\frac{L_{in} - \text{kernel\_size}}{\text{stride}} + 1\right\rfloor

    Examples::
        >>> # power-2 pool of window of length 3, with stride 2.
        >>> m = nn.LPPool1d(2, 3, stride=2)
        >>> input = torch.randn(20, 16, 50)
        >>> output = m(input)
    """

    kernel_size: _size_1_t
    stride: _size_1_t

    def forward(self, input: Tensor) -> Tensor:
        """Runs the forward pass."""
        # The functional op is always handed the exponent as a float.
        p = float(self.norm_type)
        return F.lp_pool1d(
            input,
            p,
            kernel_size=self.kernel_size,
            stride=self.stride,
            ceil_mode=self.ceil_mode,
        )
class LPPool2d(_LPPoolNd):
    r"""Applies a 2D power-average pooling over an input signal composed of several input planes.

    On each window, the function computed is:

    .. math::
        f(X) = \sqrt[p]{\sum_{x \in X} x^{p}}

    - At p = :math:`\infty`, one gets Max Pooling
    - At p = 1, one gets Sum Pooling (which is proportional to average pooling)

    The parameters :attr:`kernel_size`, :attr:`stride` can either be:

        - a single ``int`` -- in which case the same value is used for the height and width dimension
        - a ``tuple`` of two ints -- in which case, the first `int` is used for the height dimension,
          and the second `int` for the width dimension

    .. note:: If the sum to the power of `p` is zero, the gradient of this function is
              not defined. This implementation will set the gradient to zero in this case.

    Args:
        kernel_size: the size of the window
        stride: the stride of the window. Default value is :attr:`kernel_size`
        ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape

    Note:
        When :attr:`ceil_mode` is ``True``, sliding windows may go off-bounds if they start within
        the input. Sliding windows that would start past the end of the input are ignored.

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
        - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} - \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor

    Examples::

        >>> # power-2 pool of square window of size=3, stride=2
        >>> m = nn.LPPool2d(2, 3, stride=2)
        >>> # pool of non-square window of power 1.2
        >>> m = nn.LPPool2d(1.2, (3, 2), stride=(2, 1))
        >>> input = torch.randn(20, 16, 50, 32)
        >>> output = m(input)
    """

    kernel_size: _size_2_t
    stride: _size_2_t

    def forward(self, input: Tensor) -> Tensor:
        """Runs the forward pass."""
        # The functional op is always handed the exponent as a float.
        p = float(self.norm_type)
        return F.lp_pool2d(
            input,
            p,
            kernel_size=self.kernel_size,
            stride=self.stride,
            ceil_mode=self.ceil_mode,
        )
- class LPPool3d(_LPPoolNd):
- r"""Applies a 3D power-average pooling over an input signal composed of several input planes.
- On each window, the function computed is:
- .. math::
- f(X) = \sqrt[p]{\sum_{x \in X} x^{p}}
- - At p = :math:`\infty`, one gets Max Pooling
- - At p = 1, one gets Sum Pooling (which is proportional to average pooling)
- The parameters :attr:`kernel_size`, :attr:`stride` can either be:
- - a single ``int`` -- in which case the same value is used for the height, width and depth dimension
- - a ``tuple`` of three ints -- in which case, the first `int` is used for the depth dimension,
- the second `int` for the height dimension and the third `int` for the width dimension
- .. note:: If the sum to the power of `p` is zero, the gradient of this function is
- not defined. This implementation will set the gradient to zero in this case.
- Args:
- kernel_size: the size of the window
- stride: the stride of the window. Default value is :attr:`kernel_size`
- ceil_mode: when True, will use `ceil` instead of `floor` to compute the output shape
- Note:
- When :attr:`ceil_mode` is ``True``, sliding windows may go off-bounds if they start within the
- left padding or the input. Sliding windows that would start in the right padded region are ignored.
- Shape:
- - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
- - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or
- :math:`(C, D_{out}, H_{out}, W_{out})`, where
- .. math::
- D_{out} = \left\lfloor\frac{D_{in} - \text{kernel\_size}[0]}{\text{stride}[0]} + 1\right\rfloor
- .. math::
- H_{out} = \left\lfloor\frac{H_{in} - \text{kernel\_size}[1]}{\text{stride}[1]} + 1\right\rfloor
- .. math::
- W_{out} = \left\lfloor\frac{W_{in} - \text{kernel\_size}[2]}{\text{stride}[2]} + 1\right\rfloor
- Examples::
- >>> # power-2 pool of square window of size=3, stride=2
- >>> m = nn.LPPool3d(2, 3, stride=2)
- >>> # pool of non-square window of power 1.2
- >>> m = nn.LPPool3d(1.2, (3, 2, 2), stride=(2, 1, 2))
- >>> input = torch.randn(20, 16, 50, 44, 31)
- >>> output = m(input)
- """
- kernel_size: _size_3_t
- stride: _size_3_t
- def forward(self, input: Tensor) -> Tensor:
- """Runs the forward pass."""
- return F.lp_pool3d(
- input, float(self.norm_type), self.kernel_size, self.stride, self.ceil_mode
- )
class _AdaptiveMaxPoolNd(Module):
    """Shared state for the ``AdaptiveMaxPool{1,2,3}d`` modules.

    Holds the requested ``output_size`` and whether pooling indices are
    returned alongside the values.
    """

    __constants__ = ["output_size", "return_indices"]

    return_indices: bool

    def __init__(
        self, output_size: _size_any_opt_t, return_indices: bool = False
    ) -> None:
        super().__init__()
        self.output_size = output_size
        self.return_indices = return_indices

    def extra_repr(self) -> str:
        return "output_size={}".format(self.output_size)
- # FIXME (by @ssnl): Improve adaptive pooling docs: specify what the input and
- # output shapes are, and how the operation computes output.
class AdaptiveMaxPool1d(_AdaptiveMaxPoolNd):
    r"""Applies a 1D adaptive max pooling over an input signal composed of several input planes.

    The output size is :math:`L_{out}`, for any input size.
    The number of output features is equal to the number of input planes.

    Args:
        output_size: the target output size :math:`L_{out}`.
        return_indices: if ``True``, will return the indices along with the outputs.
                        Useful to pass to nn.MaxUnpool1d. Default: ``False``

    Shape:
        - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`.
        - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, where
          :math:`L_{out}=\text{output\_size}`.

    Examples:
        >>> # target output size of 5
        >>> m = nn.AdaptiveMaxPool1d(5)
        >>> input = torch.randn(1, 64, 8)
        >>> output = m(input)
    """

    output_size: _size_1_t

    def forward(self, input: Tensor):
        """Runs the forward pass."""
        # ``return_indices`` stays a direct attribute read: it selects which
        # overload of the functional op runs (values only, or values + indices).
        target = self.output_size
        return F.adaptive_max_pool1d(input, target, self.return_indices)
class AdaptiveMaxPool2d(_AdaptiveMaxPoolNd):
    r"""Applies a 2D adaptive max pooling over an input signal composed of several input planes.

    The output is of size :math:`H_{out} \times W_{out}`, for any input size.
    The number of output features is equal to the number of input planes.

    Args:
        output_size: the target output size of the image of the form :math:`H_{out} \times W_{out}`.
                     Can be a tuple :math:`(H_{out}, W_{out})` or a single :math:`H_{out}` for a
                     square image :math:`H_{out} \times H_{out}`. :math:`H_{out}` and :math:`W_{out}`
                     can be either a ``int``, or ``None`` which means the size will be the same as that
                     of the input.
        return_indices: if ``True``, will return the indices along with the outputs.
                        Useful to pass to nn.MaxUnpool2d. Default: ``False``

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
        - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(C, H_{out}, W_{out})`, where
          :math:`(H_{out}, W_{out})=\text{output\_size}`.

    Examples:
        >>> # target output size of 5x7
        >>> m = nn.AdaptiveMaxPool2d((5, 7))
        >>> input = torch.randn(1, 64, 8, 9)
        >>> output = m(input)
        >>> # target output size of 7x7 (square)
        >>> m = nn.AdaptiveMaxPool2d(7)
        >>> input = torch.randn(1, 64, 10, 9)
        >>> output = m(input)
        >>> # target output size of 10x7
        >>> m = nn.AdaptiveMaxPool2d((None, 7))
        >>> input = torch.randn(1, 64, 10, 9)
        >>> output = m(input)
    """

    output_size: _size_2_opt_t

    def forward(self, input: Tensor):
        """Runs the forward pass."""
        # ``return_indices`` stays a direct attribute read: it selects which
        # overload of the functional op runs (values only, or values + indices).
        target = self.output_size
        return F.adaptive_max_pool2d(input, target, self.return_indices)
class AdaptiveMaxPool3d(_AdaptiveMaxPoolNd):
    r"""Applies a 3D adaptive max pooling over an input signal composed of several input planes.

    The output is of size :math:`D_{out} \times H_{out} \times W_{out}`, for any input size.
    The number of output features is equal to the number of input planes.

    Args:
        output_size: the target output size of the image of the form :math:`D_{out} \times H_{out} \times W_{out}`.
                     Can be a tuple :math:`(D_{out}, H_{out}, W_{out})` or a single
                     :math:`D_{out}` for a cube :math:`D_{out} \times D_{out} \times D_{out}`.
                     :math:`D_{out}`, :math:`H_{out}` and :math:`W_{out}` can be either a
                     ``int``, or ``None`` which means the size will be the same as that of the input.

        return_indices: if ``True``, will return the indices along with the outputs.
                        Useful to pass to nn.MaxUnpool3d. Default: ``False``

    Shape:
        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
        - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` or :math:`(C, D_{out}, H_{out}, W_{out})`,
          where :math:`(D_{out}, H_{out}, W_{out})=\text{output\_size}`.

    Examples:
        >>> # target output size of 5x7x9
        >>> m = nn.AdaptiveMaxPool3d((5, 7, 9))
        >>> input = torch.randn(1, 64, 8, 9, 10)
        >>> output = m(input)
        >>> # target output size of 7x7x7 (cube)
        >>> m = nn.AdaptiveMaxPool3d(7)
        >>> input = torch.randn(1, 64, 10, 9, 8)
        >>> output = m(input)
        >>> # target output size of 7x9x8
        >>> m = nn.AdaptiveMaxPool3d((7, None, None))
        >>> input = torch.randn(1, 64, 10, 9, 8)
        >>> output = m(input)
    """

    output_size: _size_3_opt_t

    def forward(self, input: Tensor):
        """Runs the forward pass."""
        # ``return_indices`` stays a direct attribute read: it selects which
        # overload of the functional op runs (values only, or values + indices).
        target = self.output_size
        return F.adaptive_max_pool3d(input, target, self.return_indices)
class _AdaptiveAvgPoolNd(Module):
    """Shared state for the ``AdaptiveAvgPool{1,2,3}d`` modules: the requested ``output_size``."""

    __constants__ = ["output_size"]

    def __init__(self, output_size: _size_any_opt_t) -> None:
        super().__init__()
        self.output_size = output_size

    def extra_repr(self) -> str:
        return "output_size={}".format(self.output_size)
class AdaptiveAvgPool1d(_AdaptiveAvgPoolNd):
    r"""Applies a 1D adaptive average pooling over an input signal composed of several input planes.

    The output size is :math:`L_{out}`, for any input size.
    The number of output features is equal to the number of input planes.

    Args:
        output_size: the target output size :math:`L_{out}`.

    Shape:
        - Input: :math:`(N, C, L_{in})` or :math:`(C, L_{in})`.
        - Output: :math:`(N, C, L_{out})` or :math:`(C, L_{out})`, where
          :math:`L_{out}=\text{output\_size}`.

    Examples:
        >>> # target output size of 5
        >>> m = nn.AdaptiveAvgPool1d(5)
        >>> input = torch.randn(1, 64, 8)
        >>> output = m(input)
    """

    output_size: _size_1_t

    def forward(self, input: Tensor) -> Tensor:
        """Runs the forward pass."""
        target = self.output_size
        return F.adaptive_avg_pool1d(input, target)
class AdaptiveAvgPool2d(_AdaptiveAvgPoolNd):
    r"""Applies a 2D adaptive average pooling over an input signal composed of several input planes.

    The output is of size H x W, for any input size.
    The number of output features is equal to the number of input planes.

    Args:
        output_size: the target output size of the image of the form H x W.
                     Can be a tuple (H, W) or a single H for a square image H x H.
                     H and W can be either a ``int``, or ``None`` which means the size will
                     be the same as that of the input.

    Shape:
        - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(C, H_{in}, W_{in})`.
        - Output: :math:`(N, C, S_{0}, S_{1})` or :math:`(C, S_{0}, S_{1})`, where
          :math:`S=\text{output\_size}`.

    Examples:
        >>> # target output size of 5x7
        >>> m = nn.AdaptiveAvgPool2d((5, 7))
        >>> input = torch.randn(1, 64, 8, 9)
        >>> output = m(input)
        >>> # target output size of 7x7 (square)
        >>> m = nn.AdaptiveAvgPool2d(7)
        >>> input = torch.randn(1, 64, 10, 9)
        >>> output = m(input)
        >>> # target output size of 10x7
        >>> m = nn.AdaptiveAvgPool2d((None, 7))
        >>> input = torch.randn(1, 64, 10, 9)
        >>> output = m(input)
    """

    output_size: _size_2_opt_t

    def forward(self, input: Tensor) -> Tensor:
        """Runs the forward pass."""
        target = self.output_size
        return F.adaptive_avg_pool2d(input, target)
class AdaptiveAvgPool3d(_AdaptiveAvgPoolNd):
    r"""Applies a 3D adaptive average pooling over an input signal composed of several input planes.

    The output is of size D x H x W, for any input size.
    The number of output features is equal to the number of input planes.

    Args:
        output_size: the target output size of the form D x H x W.
                     Can be a tuple (D, H, W) or a single number D for a cube D x D x D.
                     D, H and W can be either a ``int``, or ``None`` which means the size will
                     be the same as that of the input.

    Shape:
        - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})` or :math:`(C, D_{in}, H_{in}, W_{in})`.
        - Output: :math:`(N, C, S_{0}, S_{1}, S_{2})` or :math:`(C, S_{0}, S_{1}, S_{2})`,
          where :math:`S=\text{output\_size}`.

    Examples:
        >>> # target output size of 5x7x9
        >>> m = nn.AdaptiveAvgPool3d((5, 7, 9))
        >>> input = torch.randn(1, 64, 8, 9, 10)
        >>> output = m(input)
        >>> # target output size of 7x7x7 (cube)
        >>> m = nn.AdaptiveAvgPool3d(7)
        >>> input = torch.randn(1, 64, 10, 9, 8)
        >>> output = m(input)
        >>> # target output size of 7x9x8
        >>> m = nn.AdaptiveAvgPool3d((7, None, None))
        >>> input = torch.randn(1, 64, 10, 9, 8)
        >>> output = m(input)
    """

    output_size: _size_3_opt_t

    def forward(self, input: Tensor) -> Tensor:
        """Runs the forward pass."""
        target = self.output_size
        return F.adaptive_avg_pool3d(input, target)
|