pool2d_same.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. """ AvgPool2d w/ Same Padding
  2. Hacked together by / Copyright 2020 Ross Wightman
  3. """
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from typing import List, Tuple, Optional, Union
  8. from ._fx import register_notrace_module
  9. from .helpers import to_2tuple
  10. from .padding import pad_same, get_padding_value
  11. def avg_pool2d_same(
  12. x: torch.Tensor,
  13. kernel_size: List[int],
  14. stride: List[int],
  15. padding: List[int] = (0, 0),
  16. ceil_mode: bool = False,
  17. count_include_pad: bool = True,
  18. ):
  19. # FIXME how to deal with count_include_pad vs not for external padding?
  20. x = pad_same(x, kernel_size, stride)
  21. return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
  22. @register_notrace_module
  23. class AvgPool2dSame(nn.AvgPool2d):
  24. """ Tensorflow like 'SAME' wrapper for 2D average pooling
  25. """
  26. def __init__(
  27. self,
  28. kernel_size: Union[int, Tuple[int, int]],
  29. stride: Optional[Union[int, Tuple[int, int]]] = None,
  30. padding: Union[int, Tuple[int, int], str] = 0,
  31. ceil_mode: bool = False,
  32. count_include_pad: bool = True,
  33. ):
  34. kernel_size = to_2tuple(kernel_size)
  35. stride = to_2tuple(stride)
  36. super().__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
  37. def forward(self, x):
  38. x = pad_same(x, self.kernel_size, self.stride)
  39. return F.avg_pool2d(
  40. x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)
  41. def max_pool2d_same(
  42. x: torch.Tensor,
  43. kernel_size: List[int],
  44. stride: List[int],
  45. padding: List[int] = (0, 0),
  46. dilation: List[int] = (1, 1),
  47. ceil_mode: bool = False,
  48. ):
  49. x = pad_same(x, kernel_size, stride, value=-float('inf'))
  50. return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode)
  51. @register_notrace_module
  52. class MaxPool2dSame(nn.MaxPool2d):
  53. """ Tensorflow like 'SAME' wrapper for 2D max pooling
  54. """
  55. def __init__(
  56. self,
  57. kernel_size: Union[int, Tuple[int, int]],
  58. stride: Optional[Union[int, Tuple[int, int]]] = None,
  59. padding: Union[int, Tuple[int, int], str] = 0,
  60. dilation: Union[int, Tuple[int, int]] = 1,
  61. ceil_mode: bool = False,
  62. ):
  63. kernel_size = to_2tuple(kernel_size)
  64. stride = to_2tuple(stride)
  65. dilation = to_2tuple(dilation)
  66. super().__init__(kernel_size, stride, (0, 0), dilation, ceil_mode)
  67. def forward(self, x):
  68. x = pad_same(x, self.kernel_size, self.stride, value=-float('inf'))
  69. return F.max_pool2d(x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode)
  70. def create_pool2d(pool_type, kernel_size, stride=None, **kwargs):
  71. stride = stride or kernel_size
  72. padding = kwargs.pop('padding', '')
  73. padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs)
  74. if is_dynamic:
  75. if pool_type == 'avg':
  76. return AvgPool2dSame(kernel_size, stride=stride, **kwargs)
  77. elif pool_type == 'max':
  78. return MaxPool2dSame(kernel_size, stride=stride, **kwargs)
  79. else:
  80. assert False, f'Unsupported pool type {pool_type}'
  81. else:
  82. if pool_type == 'avg':
  83. return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs)
  84. elif pool_type == 'max':
  85. return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs)
  86. else:
  87. assert False, f'Unsupported pool type {pool_type}'