| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142 |
- """ Conv2d w/ Same Padding
- Hacked together by / Copyright 2020 Ross Wightman
- """
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from typing import Tuple, Optional, Union
- from ._fx import register_notrace_module
- from .config import is_exportable, is_scriptable
- from .padding import pad_same, pad_same_arg, get_padding_value
# Opt-in switch read by create_conv2d_pad: when True (and is_exportable() is also true),
# dynamic 'SAME' padding is served by Conv2dSameExport instead of Conv2dSame.
_USE_EXPORT_CONV = False
def conv2d_same(
        x,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        stride: Tuple[int, int] = (1, 1),
        padding: Tuple[int, int] = (0, 0),
        dilation: Tuple[int, int] = (1, 1),
        groups: int = 1,
):
    """ Functional 2D convolution that pads its input with `pad_same` first

    The input is padded by `pad_same`, using the last two dims of `weight` as the
    kernel size together with `stride` and `dilation`. The convolution itself is
    then run with an explicit padding of (0, 0).

    NOTE: `padding` is accepted for signature parity with `F.conv2d` but is not used.
    """
    kernel_hw = weight.shape[-2:]
    padded = pad_same(x, kernel_hw, stride, dilation)
    return F.conv2d(
        padded,
        weight,
        bias,
        stride=stride,
        padding=(0, 0),
        dilation=dilation,
        groups=groups,
    )
@register_notrace_module
class Conv2dSame(nn.Conv2d):
    """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions

    A `nn.Conv2d` subclass whose forward pass delegates to `conv2d_same`. The base
    class is always constructed with a padding of 0; the `padding` constructor
    argument is accepted but not forwarded.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: Union[int, Tuple[int, int]],
            stride: Union[int, Tuple[int, int]] = 1,
            padding: Union[int, Tuple[int, int], str] = 0,
            dilation: Union[int, Tuple[int, int]] = 1,
            groups: int = 1,
            bias: bool = True,
            device=None,
            dtype=None,
    ):
        # `padding` is intentionally dropped here: the base conv is built with 0 padding.
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
            device=device,
            dtype=dtype,
        )

    def forward(self, x):
        """ Run `conv2d_same` with this module's weight, bias and conv hyper-parameters. """
        return conv2d_same(
            x,
            self.weight,
            self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )
class Conv2dSameExport(nn.Conv2d):
    """ ONNX export friendly Tensorflow like 'SAME' convolution wrapper for 2D convolutions

    The pad amounts are computed with `pad_same_arg` from the input's last two dims
    and held in a cached `nn.ZeroPad2d` module. The cache is rebuilt whenever the
    input's spatial size differs from the size it was built for, so repeated calls
    with different resolutions are padded correctly.

    The base class is always constructed with a padding of 0; the `padding`
    constructor argument is accepted but not forwarded.

    NOTE: This does not currently work with torch.jit.script
    """

    # pylint: disable=unused-argument
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: Union[int, Tuple[int, int]],
            stride: Union[int, Tuple[int, int]] = 1,
            padding: Union[int, Tuple[int, int], str] = 0,
            dilation: Union[int, Tuple[int, int]] = 1,
            groups: int = 1,
            bias: bool = True,
            device=None,
            dtype=None,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            0,  # padding
            dilation,
            groups,
            bias,
            device=device,
            dtype=dtype,
        )
        # Pad module is built lazily on the first forward() call.
        self.pad = None
        # Spatial size (last two dims) the cached pad module was built for.
        self.pad_input_size = (0, 0)

    def forward(self, x):
        """ Zero-pad `x` with the cached pad module, then convolve. """
        input_size = x.size()[-2:]
        # `self.pad is None` is tested first so the first call never compares sizes.
        # Rebuilding on a size mismatch avoids reusing padding computed for another
        # resolution (previously pad_input_size was recorded but never checked).
        if self.pad is None or self.pad_input_size != input_size:
            pad_arg = pad_same_arg(input_size, self.weight.size()[-2:], self.stride, self.dilation)
            self.pad = nn.ZeroPad2d(pad_arg)
            self.pad_input_size = input_size

        x = self.pad(x)
        # self.padding is (0, 0) here because the base class was built with padding=0.
        return F.conv2d(
            x,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
    """ Build a 2D convolution module for the requested padding

    `padding` is popped from kwargs (default '') and resolved by `get_padding_value`,
    which returns the padding value and a flag saying whether it is dynamic. `bias`
    defaults to False when the caller does not supply it.

    Returns a plain `nn.Conv2d` for non-dynamic padding. For dynamic padding it
    returns `Conv2dSame`, or `Conv2dSameExport` when `_USE_EXPORT_CONV` is set and
    `is_exportable()` is true.
    """
    requested_padding = kwargs.pop('padding', '')
    kwargs.setdefault('bias', False)
    resolved_padding, needs_dynamic_pad = get_padding_value(requested_padding, kernel_size, **kwargs)

    if not needs_dynamic_pad:
        return nn.Conv2d(in_chs, out_chs, kernel_size, padding=resolved_padding, **kwargs)

    if not (_USE_EXPORT_CONV and is_exportable()):
        return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)

    # older PyTorch ver needed this to export same padding reasonably
    assert not is_scriptable()  # Conv2DSameExport does not work with jit
    return Conv2dSameExport(in_chs, out_chs, kernel_size, **kwargs)
|