# resnest.py
""" ResNeSt Models

Paper: `ResNeSt: Split-Attention Networks` - https://arxiv.org/abs/2004.08955

Adapted from original PyTorch impl w/ weights at https://github.com/zhanghang1989/ResNeSt by Hang Zhang

Modified for torchscript compat, and consistency with timm by Ross Wightman
"""
  6. from typing import Optional, Type
  7. from torch import nn
  8. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  9. from timm.layers import SplitAttn
  10. from ._builder import build_model_with_cfg
  11. from ._registry import register_model, generate_default_cfgs
  12. from .resnet import ResNet
  13. class ResNestBottleneck(nn.Module):
  14. """ResNet Bottleneck
  15. """
  16. # pylint: disable=unused-argument
  17. expansion = 4
  18. def __init__(
  19. self,
  20. inplanes: int,
  21. planes: int,
  22. stride: int = 1,
  23. downsample: Optional[nn.Module] = None,
  24. radix: int = 1,
  25. cardinality: int = 1,
  26. base_width: int = 64,
  27. avd: bool = False,
  28. avd_first: bool = False,
  29. is_first: bool = False,
  30. reduce_first: int = 1,
  31. dilation: int = 1,
  32. first_dilation: Optional[int] = None,
  33. act_layer: Type[nn.Module] = nn.ReLU,
  34. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  35. attn_layer: Optional[Type[nn.Module]] = None,
  36. aa_layer: Optional[Type[nn.Module]] = None,
  37. drop_block: Optional[Type[nn.Module]] = None,
  38. drop_path: Optional[nn.Module] = None,
  39. device=None,
  40. dtype=None,
  41. ):
  42. dd = {'device': device, 'dtype': dtype}
  43. super().__init__()
  44. assert reduce_first == 1 # not supported
  45. assert attn_layer is None, 'attn_layer is not supported' # not supported
  46. assert aa_layer is None, 'aa_layer is not supported' # TODO not yet supported
  47. group_width = int(planes * (base_width / 64.)) * cardinality
  48. first_dilation = first_dilation or dilation
  49. if avd and (stride > 1 or is_first):
  50. avd_stride = stride
  51. stride = 1
  52. else:
  53. avd_stride = 0
  54. self.radix = radix
  55. self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False, **dd)
  56. self.bn1 = norm_layer(group_width, **dd)
  57. self.act1 = act_layer(inplace=True)
  58. self.avd_first = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and avd_first else None
  59. if self.radix >= 1:
  60. self.conv2 = SplitAttn(
  61. group_width,
  62. group_width,
  63. kernel_size=3,
  64. stride=stride,
  65. padding=first_dilation,
  66. dilation=first_dilation,
  67. groups=cardinality,
  68. radix=radix,
  69. norm_layer=norm_layer,
  70. drop_layer=drop_block,
  71. **dd,
  72. )
  73. self.bn2 = nn.Identity()
  74. self.drop_block = nn.Identity()
  75. self.act2 = nn.Identity()
  76. else:
  77. self.conv2 = nn.Conv2d(
  78. group_width,
  79. group_width,
  80. kernel_size=3,
  81. stride=stride,
  82. padding=first_dilation,
  83. dilation=first_dilation,
  84. groups=cardinality,
  85. bias=False,
  86. **dd,
  87. )
  88. self.bn2 = norm_layer(group_width, **dd)
  89. self.drop_block = drop_block() if drop_block is not None else nn.Identity()
  90. self.act2 = act_layer(inplace=True)
  91. self.avd_last = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and not avd_first else None
  92. self.conv3 = nn.Conv2d(group_width, planes * 4, kernel_size=1, bias=False, **dd)
  93. self.bn3 = norm_layer(planes * 4, **dd)
  94. self.act3 = act_layer(inplace=True)
  95. self.downsample = downsample
  96. self.drop_path = drop_path
  97. def zero_init_last(self):
  98. if getattr(self.bn3, 'weight', None) is not None:
  99. nn.init.zeros_(self.bn3.weight)
  100. def forward(self, x):
  101. shortcut = x
  102. out = self.conv1(x)
  103. out = self.bn1(out)
  104. out = self.act1(out)
  105. if self.avd_first is not None:
  106. out = self.avd_first(out)
  107. out = self.conv2(out)
  108. out = self.bn2(out)
  109. out = self.drop_block(out)
  110. out = self.act2(out)
  111. if self.avd_last is not None:
  112. out = self.avd_last(out)
  113. out = self.conv3(out)
  114. out = self.bn3(out)
  115. if self.drop_path is not None:
  116. x = self.drop_path(x)
  117. if self.downsample is not None:
  118. shortcut = self.downsample(x)
  119. out += shortcut
  120. out = self.act3(out)
  121. return out
  122. def _create_resnest(variant, pretrained=False, **kwargs):
  123. return build_model_with_cfg(
  124. ResNet,
  125. variant,
  126. pretrained,
  127. **kwargs,
  128. )
  129. def _cfg(url='', **kwargs):
  130. return {
  131. 'url': url,
  132. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  133. 'crop_pct': 0.875, 'interpolation': 'bilinear',
  134. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  135. 'first_conv': 'conv1.0', 'classifier': 'fc',
  136. 'license': 'apache-2.0',
  137. **kwargs
  138. }
  139. default_cfgs = generate_default_cfgs({
  140. 'resnest14d.gluon_in1k': _cfg(hf_hub_id='timm/'),
  141. 'resnest26d.gluon_in1k': _cfg(hf_hub_id='timm/'),
  142. 'resnest50d.in1k': _cfg(hf_hub_id='timm/'),
  143. 'resnest101e.in1k': _cfg(
  144. hf_hub_id='timm/',
  145. input_size=(3, 256, 256), pool_size=(8, 8)),
  146. 'resnest200e.in1k': _cfg(
  147. hf_hub_id='timm/',
  148. input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=0.909, interpolation='bicubic'),
  149. 'resnest269e.in1k': _cfg(
  150. hf_hub_id='timm/',
  151. input_size=(3, 416, 416), pool_size=(13, 13), crop_pct=0.928, interpolation='bicubic'),
  152. 'resnest50d_4s2x40d.in1k': _cfg(
  153. hf_hub_id='timm/',
  154. interpolation='bicubic'),
  155. 'resnest50d_1s4x24d.in1k': _cfg(
  156. hf_hub_id='timm/',
  157. interpolation='bicubic')
  158. })
  159. @register_model
  160. def resnest14d(pretrained=False, **kwargs) -> ResNet:
  161. """ ResNeSt-14d model. Weights ported from GluonCV.
  162. """
  163. model_kwargs = dict(
  164. block=ResNestBottleneck, layers=[1, 1, 1, 1],
  165. stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
  166. block_args=dict(radix=2, avd=True, avd_first=False))
  167. return _create_resnest('resnest14d', pretrained=pretrained, **dict(model_kwargs, **kwargs))
  168. @register_model
  169. def resnest26d(pretrained=False, **kwargs) -> ResNet:
  170. """ ResNeSt-26d model. Weights ported from GluonCV.
  171. """
  172. model_kwargs = dict(
  173. block=ResNestBottleneck, layers=[2, 2, 2, 2],
  174. stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
  175. block_args=dict(radix=2, avd=True, avd_first=False))
  176. return _create_resnest('resnest26d', pretrained=pretrained, **dict(model_kwargs, **kwargs))
  177. @register_model
  178. def resnest50d(pretrained=False, **kwargs) -> ResNet:
  179. """ ResNeSt-50d model. Matches paper ResNeSt-50 model, https://arxiv.org/abs/2004.08955
  180. Since this codebase supports all possible variations, 'd' for deep stem, stem_width 32, avg in downsample.
  181. """
  182. model_kwargs = dict(
  183. block=ResNestBottleneck, layers=[3, 4, 6, 3],
  184. stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
  185. block_args=dict(radix=2, avd=True, avd_first=False))
  186. return _create_resnest('resnest50d', pretrained=pretrained, **dict(model_kwargs, **kwargs))
  187. @register_model
  188. def resnest101e(pretrained=False, **kwargs) -> ResNet:
  189. """ ResNeSt-101e model. Matches paper ResNeSt-101 model, https://arxiv.org/abs/2004.08955
  190. Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample.
  191. """
  192. model_kwargs = dict(
  193. block=ResNestBottleneck, layers=[3, 4, 23, 3],
  194. stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
  195. block_args=dict(radix=2, avd=True, avd_first=False))
  196. return _create_resnest('resnest101e', pretrained=pretrained, **dict(model_kwargs, **kwargs))
  197. @register_model
  198. def resnest200e(pretrained=False, **kwargs) -> ResNet:
  199. """ ResNeSt-200e model. Matches paper ResNeSt-200 model, https://arxiv.org/abs/2004.08955
  200. Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample.
  201. """
  202. model_kwargs = dict(
  203. block=ResNestBottleneck, layers=[3, 24, 36, 3],
  204. stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
  205. block_args=dict(radix=2, avd=True, avd_first=False))
  206. return _create_resnest('resnest200e', pretrained=pretrained, **dict(model_kwargs, **kwargs))
  207. @register_model
  208. def resnest269e(pretrained=False, **kwargs) -> ResNet:
  209. """ ResNeSt-269e model. Matches paper ResNeSt-269 model, https://arxiv.org/abs/2004.08955
  210. Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample.
  211. """
  212. model_kwargs = dict(
  213. block=ResNestBottleneck, layers=[3, 30, 48, 8],
  214. stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
  215. block_args=dict(radix=2, avd=True, avd_first=False))
  216. return _create_resnest('resnest269e', pretrained=pretrained, **dict(model_kwargs, **kwargs))
  217. @register_model
  218. def resnest50d_4s2x40d(pretrained=False, **kwargs) -> ResNet:
  219. """ResNeSt-50 4s2x40d from https://github.com/zhanghang1989/ResNeSt/blob/master/ablation.md
  220. """
  221. model_kwargs = dict(
  222. block=ResNestBottleneck, layers=[3, 4, 6, 3],
  223. stem_type='deep', stem_width=32, avg_down=True, base_width=40, cardinality=2,
  224. block_args=dict(radix=4, avd=True, avd_first=True))
  225. return _create_resnest('resnest50d_4s2x40d', pretrained=pretrained, **dict(model_kwargs, **kwargs))
  226. @register_model
  227. def resnest50d_1s4x24d(pretrained=False, **kwargs) -> ResNet:
  228. """ResNeSt-50 1s4x24d from https://github.com/zhanghang1989/ResNeSt/blob/master/ablation.md
  229. """
  230. model_kwargs = dict(
  231. block=ResNestBottleneck, layers=[3, 4, 6, 3],
  232. stem_type='deep', stem_width=32, avg_down=True, base_width=24, cardinality=4,
  233. block_args=dict(radix=1, avd=True, avd_first=True))
  234. return _create_resnest('resnest50d_1s4x24d', pretrained=pretrained, **dict(model_kwargs, **kwargs))