| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162 |
- import torch.nn.functional as F
- from torch import Tensor
- from .module import Module
__all__ = ["ChannelShuffle"]


class ChannelShuffle(Module):
    r"""Split the channels of a tensor into groups and interleave them.

    For an input of shape :math:`(N, C, *)`, the ``C`` channels are divided
    into ``g`` groups, arranged as :math:`(N, \frac{C}{g}, g, *)`, and
    shuffled across the groups. The result keeps the original input shape.

    Args:
        groups (int): number of groups to divide the channels into.

    Examples::

        >>> channel_shuffle = nn.ChannelShuffle(2)
        >>> input = torch.arange(1, 17, dtype=torch.float32).view(1, 4, 2, 2)
        >>> input
        tensor([[[[ 1.,  2.],
                  [ 3.,  4.]],

                 [[ 5.,  6.],
                  [ 7.,  8.]],

                 [[ 9., 10.],
                  [11., 12.]],

                 [[13., 14.],
                  [15., 16.]]]])
        >>> output = channel_shuffle(input)
        >>> output
        tensor([[[[ 1.,  2.],
                  [ 3.,  4.]],

                 [[ 9., 10.],
                  [11., 12.]],

                 [[ 5.,  6.],
                  [ 7.,  8.]],

                 [[13., 14.],
                  [15., 16.]]]])
    """

    groups: int
    # ``groups`` is declared through the nn.Module ``__constants__`` convention
    # (used by TorchScript) because it does not change after construction.
    __constants__ = ["groups"]

    def __init__(self, groups: int) -> None:
        super().__init__()
        # Read by both ``forward`` and ``extra_repr``.
        self.groups = groups

    def forward(self, input: Tensor) -> Tensor:
        """Shuffle the channels of ``input`` across ``self.groups`` groups.

        Args:
            input: tensor laid out as :math:`(N, C, *)`.

        Returns:
            The tensor produced by ``F.channel_shuffle``, which has the same
            shape as ``input``.
        """
        num_groups = self.groups
        return F.channel_shuffle(input, num_groups)

    def extra_repr(self) -> str:
        """Return the module-specific part of its repr, e.g. ``groups=2``."""
        return f"groups={self.groups}"
|