sknet.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
""" Selective Kernel Networks (ResNet base)
Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586)
This was inspired by reading 'Compounding the Performance Improvements...' (https://arxiv.org/abs/2001.06268)
and a streamlined impl at https://github.com/clovaai/assembled-cnn but I ended up building something closer
to the original paper with some modifications of my own to better balance param count vs accuracy.
Hacked together by / Copyright 2020 Ross Wightman
"""
import math
from typing import Any, Dict, Optional, Type

from torch import nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import SelectiveKernel, ConvNormAct, create_attn
from ._builder import build_model_with_cfg
from ._registry import register_model, generate_default_cfgs
from .resnet import ResNet
  16. class SelectiveKernelBasic(nn.Module):
  17. expansion = 1
  18. def __init__(
  19. self,
  20. inplanes: int,
  21. planes: int,
  22. stride: int = 1,
  23. downsample: Optional[nn.Module] = None,
  24. cardinality: int = 1,
  25. base_width: int = 64,
  26. sk_kwargs: Optional[dict] = None,
  27. reduce_first: int = 1,
  28. dilation: int = 1,
  29. first_dilation: Optional[int] = None,
  30. act_layer: Type[nn.Module] = nn.ReLU,
  31. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  32. attn_layer: Optional[Type[nn.Module]] = None,
  33. aa_layer: Optional[Type[nn.Module]] = None,
  34. drop_block: Optional[nn.Module] = None,
  35. drop_path: Optional[nn.Module] = None,
  36. device=None,
  37. dtype=None,
  38. ):
  39. dd = {'device': device, 'dtype': dtype}
  40. super().__init__()
  41. sk_kwargs = sk_kwargs or {}
  42. conv_kwargs = dict(act_layer=act_layer, norm_layer=norm_layer, **dd)
  43. assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
  44. assert base_width == 64, 'BasicBlock doest not support changing base width'
  45. first_planes = planes // reduce_first
  46. outplanes = planes * self.expansion
  47. first_dilation = first_dilation or dilation
  48. self.conv1 = SelectiveKernel(
  49. inplanes,
  50. first_planes,
  51. stride=stride,
  52. dilation=first_dilation,
  53. aa_layer=aa_layer,
  54. drop_layer=drop_block,
  55. **conv_kwargs,
  56. **sk_kwargs,
  57. )
  58. self.conv2 = ConvNormAct(
  59. first_planes,
  60. outplanes,
  61. kernel_size=3,
  62. dilation=dilation,
  63. apply_act=False,
  64. **conv_kwargs,
  65. )
  66. self.se = create_attn(attn_layer, outplanes, **dd)
  67. self.act = act_layer(inplace=True)
  68. self.downsample = downsample
  69. self.drop_path = drop_path
  70. def zero_init_last(self):
  71. if getattr(self.conv2.bn, 'weight', None) is not None:
  72. nn.init.zeros_(self.conv2.bn.weight)
  73. def forward(self, x):
  74. shortcut = x
  75. x = self.conv1(x)
  76. x = self.conv2(x)
  77. if self.se is not None:
  78. x = self.se(x)
  79. if self.drop_path is not None:
  80. x = self.drop_path(x)
  81. if self.downsample is not None:
  82. shortcut = self.downsample(shortcut)
  83. x += shortcut
  84. x = self.act(x)
  85. return x
  86. class SelectiveKernelBottleneck(nn.Module):
  87. expansion = 4
  88. def __init__(
  89. self,
  90. inplanes: int,
  91. planes: int,
  92. stride: int = 1,
  93. downsample: Optional[nn.Module] = None,
  94. cardinality: int = 1,
  95. base_width: int = 64,
  96. sk_kwargs: Optional[dict] = None,
  97. reduce_first: int = 1,
  98. dilation: int = 1,
  99. first_dilation: Optional[int] = None,
  100. act_layer: Type[nn.Module] = nn.ReLU,
  101. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  102. attn_layer: Optional[Type[nn.Module]] = None,
  103. aa_layer: Optional[Type[nn.Module]] = None,
  104. drop_block: Optional[nn.Module] = None,
  105. drop_path: Optional[nn.Module] = None,
  106. device=None,
  107. dtype=None,
  108. ):
  109. dd = {'device': device, 'dtype': dtype}
  110. super().__init__()
  111. sk_kwargs = sk_kwargs or {}
  112. conv_kwargs = dict(act_layer=act_layer, norm_layer=norm_layer, **dd)
  113. width = int(math.floor(planes * (base_width / 64)) * cardinality)
  114. first_planes = width // reduce_first
  115. outplanes = planes * self.expansion
  116. first_dilation = first_dilation or dilation
  117. self.conv1 = ConvNormAct(inplanes, first_planes, kernel_size=1, **conv_kwargs)
  118. self.conv2 = SelectiveKernel(
  119. first_planes,
  120. width,
  121. stride=stride,
  122. dilation=first_dilation,
  123. groups=cardinality,
  124. aa_layer=aa_layer,
  125. drop_layer=drop_block,
  126. **conv_kwargs,
  127. **sk_kwargs,
  128. )
  129. self.conv3 = ConvNormAct(width, outplanes, kernel_size=1, apply_act=False, **conv_kwargs)
  130. self.se = create_attn(attn_layer, outplanes, **dd)
  131. self.act = act_layer(inplace=True)
  132. self.downsample = downsample
  133. self.drop_path = drop_path
  134. def zero_init_last(self):
  135. if getattr(self.conv3.bn, 'weight', None) is not None:
  136. nn.init.zeros_(self.conv3.bn.weight)
  137. def forward(self, x):
  138. shortcut = x
  139. x = self.conv1(x)
  140. x = self.conv2(x)
  141. x = self.conv3(x)
  142. if self.se is not None:
  143. x = self.se(x)
  144. if self.drop_path is not None:
  145. x = self.drop_path(x)
  146. if self.downsample is not None:
  147. shortcut = self.downsample(shortcut)
  148. x += shortcut
  149. x = self.act(x)
  150. return x
  151. def _create_skresnet(variant, pretrained=False, **kwargs):
  152. return build_model_with_cfg(
  153. ResNet,
  154. variant,
  155. pretrained,
  156. **kwargs,
  157. )
  158. def _cfg(url='', **kwargs):
  159. return {
  160. 'url': url,
  161. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  162. 'crop_pct': 0.875, 'interpolation': 'bicubic',
  163. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  164. 'first_conv': 'conv1', 'classifier': 'fc',
  165. 'license': 'apache-2.0',
  166. **kwargs
  167. }
  168. default_cfgs = generate_default_cfgs({
  169. 'skresnet18.ra_in1k': _cfg(hf_hub_id='timm/'),
  170. 'skresnet34.ra_in1k': _cfg(hf_hub_id='timm/'),
  171. 'skresnet50.untrained': _cfg(),
  172. 'skresnet50d.untrained': _cfg(
  173. first_conv='conv1.0'),
  174. 'skresnext50_32x4d.ra_in1k': _cfg(hf_hub_id='timm/'),
  175. })
  176. @register_model
  177. def skresnet18(pretrained=False, **kwargs) -> ResNet:
  178. """Constructs a Selective Kernel ResNet-18 model.
  179. Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
  180. variation splits the input channels to the selective convolutions to keep param count down.
  181. """
  182. sk_kwargs = dict(rd_ratio=1 / 8, rd_divisor=16, split_input=True)
  183. model_args = dict(
  184. block=SelectiveKernelBasic, layers=[2, 2, 2, 2], block_args=dict(sk_kwargs=sk_kwargs),
  185. zero_init_last=False)
  186. return _create_skresnet('skresnet18', pretrained, **dict(model_args, **kwargs))
  187. @register_model
  188. def skresnet34(pretrained=False, **kwargs) -> ResNet:
  189. """Constructs a Selective Kernel ResNet-34 model.
  190. Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
  191. variation splits the input channels to the selective convolutions to keep param count down.
  192. """
  193. sk_kwargs = dict(rd_ratio=1 / 8, rd_divisor=16, split_input=True)
  194. model_args = dict(
  195. block=SelectiveKernelBasic, layers=[3, 4, 6, 3], block_args=dict(sk_kwargs=sk_kwargs),
  196. zero_init_last=False)
  197. return _create_skresnet('skresnet34', pretrained, **dict(model_args, **kwargs))
  198. @register_model
  199. def skresnet50(pretrained=False, **kwargs) -> ResNet:
  200. """Constructs a Select Kernel ResNet-50 model.
  201. Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
  202. variation splits the input channels to the selective convolutions to keep param count down.
  203. """
  204. sk_kwargs = dict(split_input=True)
  205. model_args = dict(
  206. block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], block_args=dict(sk_kwargs=sk_kwargs),
  207. zero_init_last=False)
  208. return _create_skresnet('skresnet50', pretrained, **dict(model_args, **kwargs))
  209. @register_model
  210. def skresnet50d(pretrained=False, **kwargs) -> ResNet:
  211. """Constructs a Select Kernel ResNet-50-D model.
  212. Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
  213. variation splits the input channels to the selective convolutions to keep param count down.
  214. """
  215. sk_kwargs = dict(split_input=True)
  216. model_args = dict(
  217. block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
  218. block_args=dict(sk_kwargs=sk_kwargs), zero_init_last=False)
  219. return _create_skresnet('skresnet50d', pretrained, **dict(model_args, **kwargs))
  220. @register_model
  221. def skresnext50_32x4d(pretrained=False, **kwargs) -> ResNet:
  222. """Constructs a Select Kernel ResNeXt50-32x4d model. This should be equivalent to
  223. the SKNet-50 model in the Select Kernel Paper
  224. """
  225. sk_kwargs = dict(rd_ratio=1/16, rd_divisor=32, split_input=False)
  226. model_args = dict(
  227. block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4,
  228. block_args=dict(sk_kwargs=sk_kwargs), zero_init_last=False)
  229. return _create_skresnet('skresnext50_32x4d', pretrained, **dict(model_args, **kwargs))