# space_to_depth.py
  1. import torch
  2. import torch.nn as nn
  3. class SpaceToDepth(nn.Module):
  4. """Rearrange spatial dimensions into channel dimension.
  5. Divides spatial dimensions by block_size and multiplies channels by block_size^2.
  6. Used in TResNet as an efficient stem operation.
  7. Args:
  8. block_size: Spatial reduction factor.
  9. """
  10. bs: torch.jit.Final[int]
  11. def __init__(self, block_size: int = 4):
  12. super().__init__()
  13. assert block_size == 4
  14. self.bs = block_size
  15. def forward(self, x):
  16. N, C, H, W = x.size()
  17. x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs)
  18. x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs)
  19. x = x.view(N, C * self.bs * self.bs, H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs)
  20. return x
  21. class DepthToSpace(nn.Module):
  22. """Rearrange channel dimension into spatial dimensions.
  23. Inverse of SpaceToDepth. Divides channels by block_size^2 and multiplies
  24. spatial dimensions by block_size.
  25. Args:
  26. block_size: Spatial expansion factor.
  27. """
  28. def __init__(self, block_size):
  29. super().__init__()
  30. self.bs = block_size
  31. def forward(self, x):
  32. N, C, H, W = x.size()
  33. x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W)
  34. x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs)
  35. x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs)
  36. return x