| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566 |
- import torch
- import torch.nn as nn
- from torchvision.ops import deform_conv2d
class DeformableConv2d(nn.Module):
    """Modulated deformable 2-D convolution built on ``torchvision.ops.deform_conv2d``.

    Two auxiliary convolutions look at the input and predict, per output
    position, a sampling offset (2 channels per kernel tap, ``offset_conv``)
    and a modulation mask (1 channel per kernel tap, ``modulator_conv``).
    These are passed to ``deform_conv2d`` together with the weight and bias
    of ``regular_conv``.

    Both auxiliary convolutions are zero-initialised, so right after
    construction the offsets are 0 and the mask is ``2 * sigmoid(0) == 1``.
    The layer therefore starts out equivalent to a plain convolution with
    ``regular_conv``'s weights.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size, an int or a pair ``(kh, kw)``.
        stride: Convolution stride, an int or a pair.
        padding: Zero padding, forwarded unchanged to ``nn.Conv2d`` and
            ``deform_conv2d`` (int or pair).
        bias: Whether ``regular_conv`` carries a learnable bias.
        dilation: Kernel dilation, an int or a pair. It is applied to all
            three convolutions so the offset/mask maps have the same spatial
            size as the output.

    Raises:
        TypeError: If ``kernel_size``, ``stride`` or ``dilation`` is neither
            an integer nor a 2-element tuple/list.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 bias=False,
                 dilation=1):
        super().__init__()

        kernel_size = self._to_pair(kernel_size, "kernel_size")
        self.stride = self._to_pair(stride, "stride")
        self.dilation = self._to_pair(dilation, "dilation")
        self.padding = padding
        num_taps = kernel_size[0] * kernel_size[1]

        # Predicts 2 offset channels per kernel tap. It shares the geometry of
        # regular_conv so its output map lines up with the convolution output.
        self.offset_conv = nn.Conv2d(in_channels,
                                     2 * num_taps,
                                     kernel_size=kernel_size,
                                     stride=self.stride,
                                     padding=self.padding,
                                     dilation=self.dilation,
                                     bias=True)
        nn.init.constant_(self.offset_conv.weight, 0.)
        nn.init.constant_(self.offset_conv.bias, 0.)

        # Predicts 1 modulation channel per kernel tap (same geometry as above).
        self.modulator_conv = nn.Conv2d(in_channels,
                                        num_taps,
                                        kernel_size=kernel_size,
                                        stride=self.stride,
                                        padding=self.padding,
                                        dilation=self.dilation,
                                        bias=True)
        nn.init.constant_(self.modulator_conv.weight, 0.)
        nn.init.constant_(self.modulator_conv.bias, 0.)

        # Only its weight/bias are used, by deform_conv2d in forward().
        self.regular_conv = nn.Conv2d(in_channels,
                                      out_channels=out_channels,
                                      kernel_size=kernel_size,
                                      stride=self.stride,
                                      padding=self.padding,
                                      dilation=self.dilation,
                                      bias=bias)

    @staticmethod
    def _to_pair(value, name):
        """Normalise an integer or a 2-element tuple/list to a 2-tuple."""
        import numbers

        if isinstance(value, (tuple, list)):
            if len(value) != 2:
                raise TypeError(
                    f"{name} must be an int or a pair of ints, got {value!r}")
            return tuple(value)
        # bool is an Integral subclass but is never a meaningful size here.
        if isinstance(value, numbers.Integral) and not isinstance(value, bool):
            return (value, value)
        raise TypeError(
            f"{name} must be an int or a pair of ints, got {value!r}")

    def forward(self, x):
        """Apply the modulated deformable convolution to ``x``.

        Args:
            x: Input batch in the layout ``deform_conv2d`` expects,
                ``(N, in_channels, H, W)``.

        Returns:
            The ``deform_conv2d`` result with ``out_channels`` channels.
        """
        # NOTE: offsets are used unbounded; no clamping is applied.
        offset = self.offset_conv(x)
        # Scale sigmoid to (0, 2) so the zero-initialised mask starts at 1.
        modulator = 2. * torch.sigmoid(self.modulator_conv(x))

        return deform_conv2d(
            input=x,
            offset=offset,
            weight=self.regular_conv.weight,
            bias=self.regular_conv.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            mask=modulator,
        )
|