selective_kernel.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. """ Selective Kernel Convolution/Attention
  2. Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586)
  3. Hacked together by / Copyright 2020 Ross Wightman
  4. """
  5. from typing import List, Optional, Tuple, Type, Union
  6. import torch
  7. from torch import nn as nn
  8. from .conv_bn_act import ConvNormAct
  9. from .helpers import make_divisible
  10. from .trace_utils import _assert
  11. def _kernel_valid(k):
  12. if isinstance(k, (list, tuple)):
  13. for ki in k:
  14. return _kernel_valid(ki)
  15. assert k >= 3 and k % 2
  16. class SelectiveKernelAttn(nn.Module):
  17. def __init__(
  18. self,
  19. channels: int,
  20. num_paths: int = 2,
  21. attn_channels: int = 32,
  22. act_layer: Type[nn.Module] = nn.ReLU,
  23. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  24. device=None,
  25. dtype=None,
  26. ):
  27. """ Selective Kernel Attention Module
  28. Selective Kernel attention mechanism factored out into its own module.
  29. """
  30. dd = {'device': device, 'dtype': dtype}
  31. super().__init__()
  32. self.num_paths = num_paths
  33. self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False, **dd)
  34. self.bn = norm_layer(attn_channels, **dd)
  35. self.act = act_layer(inplace=True)
  36. self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False, **dd)
  37. def forward(self, x):
  38. _assert(x.shape[1] == self.num_paths, '')
  39. x = x.sum(1).mean((2, 3), keepdim=True)
  40. x = self.fc_reduce(x)
  41. x = self.bn(x)
  42. x = self.act(x)
  43. x = self.fc_select(x)
  44. B, C, H, W = x.shape
  45. x = x.view(B, self.num_paths, C // self.num_paths, H, W)
  46. x = torch.softmax(x, dim=1)
  47. return x
  48. class SelectiveKernel(nn.Module):
  49. def __init__(
  50. self,
  51. in_channels: int,
  52. out_channels: Optional[int] = None,
  53. kernel_size: Optional[Union[int, List[int]]] = None,
  54. stride: int = 1,
  55. dilation: int = 1,
  56. groups: int = 1,
  57. rd_ratio: float = 1./16,
  58. rd_channels: Optional[int] = None,
  59. rd_divisor: int = 8,
  60. keep_3x3: bool = True,
  61. split_input: bool = True,
  62. act_layer: Type[nn.Module] = nn.ReLU,
  63. norm_layer: Type[nn.Module]= nn.BatchNorm2d,
  64. aa_layer: Optional[Type[nn.Module]] = None,
  65. drop_layer: Optional[Type[nn.Module]] = None,
  66. device=None,
  67. dtype=None,
  68. ):
  69. """ Selective Kernel Convolution Module
  70. As described in Selective Kernel Networks (https://arxiv.org/abs/1903.06586) with some modifications.
  71. Largest change is the input split, which divides the input channels across each convolution path, this can
  72. be viewed as a grouping of sorts, but the output channel counts expand to the module level value. This keeps
  73. the parameter count from ballooning when the convolutions themselves don't have groups, but still provides
  74. a noteworthy increase in performance over similar param count models without this attention layer. -Ross W
  75. Args:
  76. in_channels: module input (feature) channel count
  77. out_channels: module output (feature) channel count
  78. kernel_size: kernel size for each convolution branch
  79. stride: stride for convolutions
  80. dilation: dilation for module as a whole, impacts dilation of each branch
  81. groups: number of groups for each branch
  82. rd_ratio: reduction factor for attention features
  83. keep_3x3: keep all branch convolution kernels as 3x3, changing larger kernels for dilations
  84. split_input: split input channels evenly across each convolution branch, keeps param count lower,
  85. can be viewed as grouping by path, output expands to module out_channels count
  86. act_layer: activation layer to use
  87. norm_layer: batchnorm/norm layer to use
  88. aa_layer: anti-aliasing module
  89. drop_layer: spatial drop module in convs (drop block, etc)
  90. """
  91. dd = {'device': device, 'dtype': dtype}
  92. super().__init__()
  93. out_channels = out_channels or in_channels
  94. kernel_size = kernel_size or [3, 5] # default to one 3x3 and one 5x5 branch. 5x5 -> 3x3 + dilation
  95. _kernel_valid(kernel_size)
  96. if not isinstance(kernel_size, list):
  97. kernel_size = [kernel_size] * 2
  98. if keep_3x3:
  99. dilation = [dilation * (k - 1) // 2 for k in kernel_size]
  100. kernel_size = [3] * len(kernel_size)
  101. else:
  102. dilation = [dilation] * len(kernel_size)
  103. self.num_paths = len(kernel_size)
  104. self.in_channels = in_channels
  105. self.out_channels = out_channels
  106. self.split_input = split_input
  107. if self.split_input:
  108. assert in_channels % self.num_paths == 0
  109. in_channels = in_channels // self.num_paths
  110. groups = min(out_channels, groups)
  111. conv_kwargs = dict(
  112. stride=stride, groups=groups, act_layer=act_layer, norm_layer=norm_layer,
  113. aa_layer=aa_layer, drop_layer=drop_layer, **dd)
  114. self.paths = nn.ModuleList([
  115. ConvNormAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
  116. for k, d in zip(kernel_size, dilation)])
  117. attn_channels = rd_channels or make_divisible(out_channels * rd_ratio, divisor=rd_divisor)
  118. self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels, **dd)
  119. def forward(self, x):
  120. if self.split_input:
  121. x_split = torch.split(x, self.in_channels // self.num_paths, 1)
  122. x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)]
  123. else:
  124. x_paths = [op(x) for op in self.paths]
  125. x = torch.stack(x_paths, dim=1)
  126. x_attn = self.attn(x)
  127. x = x * x_attn
  128. x = torch.sum(x, dim=1)
  129. return x