median_pool.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. """ Median Pool
  2. Hacked together by / Copyright 2020 Ross Wightman
  3. """
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from .helpers import to_2tuple, to_4tuple
  7. class MedianPool2d(nn.Module):
  8. """ Median pool (usable as median filter when stride=1) module.
  9. Args:
  10. kernel_size: size of pooling kernel, int or 2-tuple
  11. stride: pool stride, int or 2-tuple
  12. padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad
  13. same: override padding and enforce same padding, boolean
  14. """
  15. def __init__(self, kernel_size=3, stride=1, padding=0, same=False):
  16. super().__init__()
  17. self.k = to_2tuple(kernel_size)
  18. self.stride = to_2tuple(stride)
  19. self.padding = to_4tuple(padding) # convert to l, r, t, b
  20. self.same = same
  21. def _padding(self, x):
  22. if self.same:
  23. ih, iw = x.size()[2:]
  24. if ih % self.stride[0] == 0:
  25. ph = max(self.k[0] - self.stride[0], 0)
  26. else:
  27. ph = max(self.k[0] - (ih % self.stride[0]), 0)
  28. if iw % self.stride[1] == 0:
  29. pw = max(self.k[1] - self.stride[1], 0)
  30. else:
  31. pw = max(self.k[1] - (iw % self.stride[1]), 0)
  32. pl = pw // 2
  33. pr = pw - pl
  34. pt = ph // 2
  35. pb = ph - pt
  36. padding = (pl, pr, pt, pb)
  37. else:
  38. padding = self.padding
  39. return padding
  40. def forward(self, x):
  41. x = F.pad(x, self._padding(x), mode='reflect')
  42. x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1])
  43. x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0]
  44. return x