channelshuffle.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. import torch.nn.functional as F
  2. from torch import Tensor
  3. from .module import Module
  4. __all__ = ["ChannelShuffle"]
  5. class ChannelShuffle(Module):
  6. r"""Divides and rearranges the channels in a tensor.
  7. This operation divides the channels in a tensor of shape :math:`(N, C, *)`
  8. into g groups as :math:`(N, \frac{C}{g}, g, *)` and shuffles them,
  9. while retaining the original tensor shape in the final output.
  10. Args:
  11. groups (int): number of groups to divide channels in.
  12. Examples::
  13. >>> channel_shuffle = nn.ChannelShuffle(2)
  14. >>> input = torch.arange(1, 17, dtype=torch.float32).view(1, 4, 2, 2)
  15. >>> input
  16. tensor([[[[ 1., 2.],
  17. [ 3., 4.]],
  18. [[ 5., 6.],
  19. [ 7., 8.]],
  20. [[ 9., 10.],
  21. [11., 12.]],
  22. [[13., 14.],
  23. [15., 16.]]]])
  24. >>> output = channel_shuffle(input)
  25. >>> output
  26. tensor([[[[ 1., 2.],
  27. [ 3., 4.]],
  28. [[ 9., 10.],
  29. [11., 12.]],
  30. [[ 5., 6.],
  31. [ 7., 8.]],
  32. [[13., 14.],
  33. [15., 16.]]]])
  34. """
  35. __constants__ = ["groups"]
  36. groups: int
  37. def __init__(self, groups: int) -> None:
  38. super().__init__()
  39. self.groups = groups
  40. def forward(self, input: Tensor) -> Tensor:
  41. """
  42. Runs the forward pass.
  43. """
  44. return F.channel_shuffle(input, self.groups)
  45. def extra_repr(self) -> str:
  46. """
  47. Return the extra representation of the module.
  48. """
  49. return f"groups={self.groups}"