conv2d_same.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. """ Conv2d w/ Same Padding
  2. Hacked together by / Copyright 2020 Ross Wightman
  3. """
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from typing import Tuple, Optional, Union
  8. from ._fx import register_notrace_module
  9. from .config import is_exportable, is_scriptable
  10. from .padding import pad_same, pad_same_arg, get_padding_value
# When True, create_conv2d_pad() builds Conv2dSameExport instead of Conv2dSame for
# dynamic ('same') padding -- but only if is_exportable() is also enabled.
_USE_EXPORT_CONV = False
  12. def conv2d_same(
  13. x,
  14. weight: torch.Tensor,
  15. bias: Optional[torch.Tensor] = None,
  16. stride: Tuple[int, int] = (1, 1),
  17. padding: Tuple[int, int] = (0, 0),
  18. dilation: Tuple[int, int] = (1, 1),
  19. groups: int = 1,
  20. ):
  21. x = pad_same(x, weight.shape[-2:], stride, dilation)
  22. return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
  23. @register_notrace_module
  24. class Conv2dSame(nn.Conv2d):
  25. """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions
  26. """
  27. def __init__(
  28. self,
  29. in_channels: int,
  30. out_channels: int,
  31. kernel_size: Union[int, Tuple[int, int]],
  32. stride: Union[int, Tuple[int, int]] = 1,
  33. padding: Union[int, Tuple[int, int], str] = 0,
  34. dilation: Union[int, Tuple[int, int]] = 1,
  35. groups: int = 1,
  36. bias: bool = True,
  37. device=None,
  38. dtype=None,
  39. ):
  40. super().__init__(
  41. in_channels,
  42. out_channels,
  43. kernel_size,
  44. stride,
  45. 0, # padding
  46. dilation,
  47. groups,
  48. bias,
  49. device=device,
  50. dtype=dtype,
  51. )
  52. def forward(self, x):
  53. return conv2d_same(
  54. x,
  55. self.weight,
  56. self.bias,
  57. self.stride,
  58. self.padding,
  59. self.dilation,
  60. self.groups,
  61. )
  62. class Conv2dSameExport(nn.Conv2d):
  63. """ ONNX export friendly Tensorflow like 'SAME' convolution wrapper for 2D convolutions
  64. NOTE: This does not currently work with torch.jit.script
  65. """
  66. # pylint: disable=unused-argument
  67. def __init__(
  68. self,
  69. in_channels: int,
  70. out_channels: int,
  71. kernel_size: Union[int, Tuple[int, int]],
  72. stride: Union[int, Tuple[int, int]] = 1,
  73. padding: Union[int, Tuple[int, int], str] = 0,
  74. dilation: Union[int, Tuple[int, int]] = 1,
  75. groups: int = 1,
  76. bias: bool = True,
  77. device=None,
  78. dtype=None,
  79. ):
  80. super().__init__(
  81. in_channels,
  82. out_channels,
  83. kernel_size,
  84. stride,
  85. 0, # padding
  86. dilation,
  87. groups,
  88. bias,
  89. device=device,
  90. dtype=dtype,
  91. )
  92. self.pad = None
  93. self.pad_input_size = (0, 0)
  94. def forward(self, x):
  95. input_size = x.size()[-2:]
  96. if self.pad is None:
  97. pad_arg = pad_same_arg(input_size, self.weight.size()[-2:], self.stride, self.dilation)
  98. self.pad = nn.ZeroPad2d(pad_arg)
  99. self.pad_input_size = input_size
  100. x = self.pad(x)
  101. return F.conv2d(
  102. x,
  103. self.weight,
  104. self.bias,
  105. self.stride,
  106. self.padding,
  107. self.dilation,
  108. self.groups,
  109. )
  110. def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
  111. padding = kwargs.pop('padding', '')
  112. kwargs.setdefault('bias', False)
  113. padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
  114. if is_dynamic:
  115. if _USE_EXPORT_CONV and is_exportable():
  116. # older PyTorch ver needed this to export same padding reasonably
  117. assert not is_scriptable() # Conv2DSameExport does not work with jit
  118. return Conv2dSameExport(in_chs, out_chs, kernel_size, **kwargs)
  119. else:
  120. return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
  121. else:
  122. return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)