# deform_conv.py
import torch
import torch.nn as nn
from torchvision.ops import deform_conv2d


  4. class DeformableConv2d(nn.Module):
  5. def __init__(self,
  6. in_channels,
  7. out_channels,
  8. kernel_size=3,
  9. stride=1,
  10. padding=1,
  11. bias=False):
  12. super(DeformableConv2d, self).__init__()
  13. assert type(kernel_size) == tuple or type(kernel_size) == int
  14. kernel_size = kernel_size if type(kernel_size) == tuple else (kernel_size, kernel_size)
  15. self.stride = stride if type(stride) == tuple else (stride, stride)
  16. self.padding = padding
  17. self.offset_conv = nn.Conv2d(in_channels,
  18. 2 * kernel_size[0] * kernel_size[1],
  19. kernel_size=kernel_size,
  20. stride=stride,
  21. padding=self.padding,
  22. bias=True)
  23. nn.init.constant_(self.offset_conv.weight, 0.)
  24. nn.init.constant_(self.offset_conv.bias, 0.)
  25. self.modulator_conv = nn.Conv2d(in_channels,
  26. 1 * kernel_size[0] * kernel_size[1],
  27. kernel_size=kernel_size,
  28. stride=stride,
  29. padding=self.padding,
  30. bias=True)
  31. nn.init.constant_(self.modulator_conv.weight, 0.)
  32. nn.init.constant_(self.modulator_conv.bias, 0.)
  33. self.regular_conv = nn.Conv2d(in_channels,
  34. out_channels=out_channels,
  35. kernel_size=kernel_size,
  36. stride=stride,
  37. padding=self.padding,
  38. bias=bias)
  39. def forward(self, x):
  40. #h, w = x.shape[2:]
  41. #max_offset = max(h, w)/4.
  42. offset = self.offset_conv(x)#.clamp(-max_offset, max_offset)
  43. modulator = 2. * torch.sigmoid(self.modulator_conv(x))
  44. x = deform_conv2d(
  45. input=x,
  46. offset=offset,
  47. weight=self.regular_conv.weight,
  48. bias=self.regular_conv.bias,
  49. padding=self.padding,
  50. mask=modulator,
  51. stride=self.stride,
  52. )
  53. return x