res2net.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. """ Res2Net and Res2NeXt
  2. Adapted from Official Pytorch impl at: https://github.com/gasvn/Res2Net/
  3. Paper: `Res2Net: A New Multi-scale Backbone Architecture` - https://arxiv.org/abs/1904.01169
  4. """
  5. import math
  6. from typing import Optional, Type
  7. import torch
  8. import torch.nn as nn
  9. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  10. from ._builder import build_model_with_cfg
  11. from ._registry import register_model, generate_default_cfgs
  12. from .resnet import ResNet
  13. __all__ = []
  14. class Bottle2neck(nn.Module):
  15. """ Res2Net/Res2NeXT Bottleneck
  16. Adapted from https://github.com/gasvn/Res2Net/blob/master/res2net.py
  17. """
  18. expansion = 4
  19. def __init__(
  20. self,
  21. inplanes: int,
  22. planes: int,
  23. stride: int = 1,
  24. downsample: Optional[nn.Module] = None,
  25. cardinality: int = 1,
  26. base_width: int = 26,
  27. scale: int = 4,
  28. dilation: int = 1,
  29. first_dilation: Optional[int] = None,
  30. act_layer: Type[nn.Module] = nn.ReLU,
  31. norm_layer: Optional[Type[nn.Module]] = None,
  32. attn_layer: Optional[Type[nn.Module]] = None,
  33. device=None,
  34. dtype=None,
  35. **_,
  36. ):
  37. dd = {'device': device, 'dtype': dtype}
  38. super().__init__()
  39. self.scale = scale
  40. self.is_first = stride > 1 or downsample is not None
  41. self.num_scales = max(1, scale - 1)
  42. width = int(math.floor(planes * (base_width / 64.0))) * cardinality
  43. self.width = width
  44. outplanes = planes * self.expansion
  45. first_dilation = first_dilation or dilation
  46. self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False, **dd)
  47. self.bn1 = norm_layer(width * scale, **dd)
  48. convs = []
  49. bns = []
  50. for i in range(self.num_scales):
  51. convs.append(nn.Conv2d(
  52. width,
  53. width,
  54. kernel_size=3,
  55. stride=stride,
  56. padding=first_dilation,
  57. dilation=first_dilation,
  58. groups=cardinality,
  59. bias=False,
  60. **dd,
  61. ))
  62. bns.append(norm_layer(width, **dd))
  63. self.convs = nn.ModuleList(convs)
  64. self.bns = nn.ModuleList(bns)
  65. if self.is_first:
  66. # FIXME this should probably have count_include_pad=False, but hurts original weights
  67. self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)
  68. else:
  69. self.pool = None
  70. self.conv3 = nn.Conv2d(width * scale, outplanes, kernel_size=1, bias=False, **dd)
  71. self.bn3 = norm_layer(outplanes, **dd)
  72. self.se = attn_layer(outplanes, **dd) if attn_layer is not None else None
  73. self.relu = act_layer(inplace=True)
  74. self.downsample = downsample
  75. def zero_init_last(self):
  76. if getattr(self.bn3, 'weight', None) is not None:
  77. nn.init.zeros_(self.bn3.weight)
  78. def forward(self, x):
  79. shortcut = x
  80. out = self.conv1(x)
  81. out = self.bn1(out)
  82. out = self.relu(out)
  83. spx = torch.split(out, self.width, 1)
  84. spo = []
  85. sp = spx[0] # redundant, for torchscript
  86. for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
  87. if i == 0 or self.is_first:
  88. sp = spx[i]
  89. else:
  90. sp = sp + spx[i]
  91. sp = conv(sp)
  92. sp = bn(sp)
  93. sp = self.relu(sp)
  94. spo.append(sp)
  95. if self.scale > 1:
  96. if self.pool is not None: # self.is_first == True, None check for torchscript
  97. spo.append(self.pool(spx[-1]))
  98. else:
  99. spo.append(spx[-1])
  100. out = torch.cat(spo, 1)
  101. out = self.conv3(out)
  102. out = self.bn3(out)
  103. if self.se is not None:
  104. out = self.se(out)
  105. if self.downsample is not None:
  106. shortcut = self.downsample(x)
  107. out += shortcut
  108. out = self.relu(out)
  109. return out
  110. def _create_res2net(variant, pretrained=False, **kwargs):
  111. return build_model_with_cfg(ResNet, variant, pretrained, **kwargs)
  112. def _cfg(url='', **kwargs):
  113. return {
  114. 'url': url,
  115. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  116. 'crop_pct': 0.875, 'interpolation': 'bilinear',
  117. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  118. 'first_conv': 'conv1', 'classifier': 'fc',
  119. 'license': 'unknown',
  120. **kwargs
  121. }
  122. default_cfgs = generate_default_cfgs({
  123. 'res2net50_26w_4s.in1k': _cfg(hf_hub_id='timm/'),
  124. 'res2net50_48w_2s.in1k': _cfg(hf_hub_id='timm/'),
  125. 'res2net50_14w_8s.in1k': _cfg(hf_hub_id='timm/'),
  126. 'res2net50_26w_6s.in1k': _cfg(hf_hub_id='timm/'),
  127. 'res2net50_26w_8s.in1k': _cfg(hf_hub_id='timm/'),
  128. 'res2net101_26w_4s.in1k': _cfg(hf_hub_id='timm/'),
  129. 'res2next50.in1k': _cfg(hf_hub_id='timm/'),
  130. 'res2net50d.in1k': _cfg(hf_hub_id='timm/', first_conv='conv1.0'),
  131. 'res2net101d.in1k': _cfg(hf_hub_id='timm/', first_conv='conv1.0'),
  132. })
  133. @register_model
  134. def res2net50_26w_4s(pretrained=False, **kwargs) -> ResNet:
  135. """Constructs a Res2Net-50 26w4s model.
  136. """
  137. model_args = dict(
  138. block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=4))
  139. return _create_res2net('res2net50_26w_4s', pretrained, **dict(model_args, **kwargs))
  140. @register_model
  141. def res2net101_26w_4s(pretrained=False, **kwargs) -> ResNet:
  142. """Constructs a Res2Net-101 26w4s model.
  143. """
  144. model_args = dict(
  145. block=Bottle2neck, layers=[3, 4, 23, 3], base_width=26, block_args=dict(scale=4))
  146. return _create_res2net('res2net101_26w_4s', pretrained, **dict(model_args, **kwargs))
  147. @register_model
  148. def res2net50_26w_6s(pretrained=False, **kwargs) -> ResNet:
  149. """Constructs a Res2Net-50 26w6s model.
  150. """
  151. model_args = dict(
  152. block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=6))
  153. return _create_res2net('res2net50_26w_6s', pretrained, **dict(model_args, **kwargs))
  154. @register_model
  155. def res2net50_26w_8s(pretrained=False, **kwargs) -> ResNet:
  156. """Constructs a Res2Net-50 26w8s model.
  157. """
  158. model_args = dict(
  159. block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=8))
  160. return _create_res2net('res2net50_26w_8s', pretrained, **dict(model_args, **kwargs))
  161. @register_model
  162. def res2net50_48w_2s(pretrained=False, **kwargs) -> ResNet:
  163. """Constructs a Res2Net-50 48w2s model.
  164. """
  165. model_args = dict(
  166. block=Bottle2neck, layers=[3, 4, 6, 3], base_width=48, block_args=dict(scale=2))
  167. return _create_res2net('res2net50_48w_2s', pretrained, **dict(model_args, **kwargs))
  168. @register_model
  169. def res2net50_14w_8s(pretrained=False, **kwargs) -> ResNet:
  170. """Constructs a Res2Net-50 14w8s model.
  171. """
  172. model_args = dict(
  173. block=Bottle2neck, layers=[3, 4, 6, 3], base_width=14, block_args=dict(scale=8))
  174. return _create_res2net('res2net50_14w_8s', pretrained, **dict(model_args, **kwargs))
  175. @register_model
  176. def res2next50(pretrained=False, **kwargs) -> ResNet:
  177. """Construct Res2NeXt-50 4s
  178. """
  179. model_args = dict(
  180. block=Bottle2neck, layers=[3, 4, 6, 3], base_width=4, cardinality=8, block_args=dict(scale=4))
  181. return _create_res2net('res2next50', pretrained, **dict(model_args, **kwargs))
  182. @register_model
  183. def res2net50d(pretrained=False, **kwargs) -> ResNet:
  184. """Construct Res2Net-50
  185. """
  186. model_args = dict(
  187. block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, stem_type='deep',
  188. avg_down=True, stem_width=32, block_args=dict(scale=4))
  189. return _create_res2net('res2net50d', pretrained, **dict(model_args, **kwargs))
  190. @register_model
  191. def res2net101d(pretrained=False, **kwargs) -> ResNet:
  192. """Construct Res2Net-50
  193. """
  194. model_args = dict(
  195. block=Bottle2neck, layers=[3, 4, 23, 3], base_width=26, stem_type='deep',
  196. avg_down=True, stem_width=32, block_args=dict(scale=4))
  197. return _create_res2net('res2net101d', pretrained, **dict(model_args, **kwargs))