tresnet.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448
"""
TResNet: High Performance GPU-Dedicated Architecture
https://arxiv.org/pdf/2003.13630.pdf

Original model: https://github.com/mrT23/TResNet
"""
  6. from collections import OrderedDict
  7. from functools import partial
  8. from typing import List, Optional, Tuple, Union, Type
  9. import torch
  10. import torch.nn as nn
  11. from timm.layers import SpaceToDepth, BlurPool2d, ClassifierHead, SEModule, ConvNormAct, DropPath, calculate_drop_path_rates
  12. from ._builder import build_model_with_cfg
  13. from ._features import feature_take_indices
  14. from ._manipulate import checkpoint, checkpoint_seq
  15. from ._registry import register_model, generate_default_cfgs, register_model_deprecations
  16. __all__ = ['TResNet'] # model_registry will add each entrypoint fn to this
  17. class BasicBlock(nn.Module):
  18. expansion = 1
  19. def __init__(
  20. self,
  21. inplanes: int,
  22. planes: int,
  23. stride: int = 1,
  24. downsample: Optional[nn.Module] = None,
  25. use_se: bool = True,
  26. aa_layer: Optional[Type[nn.Module]] = None,
  27. drop_path_rate: float = 0.,
  28. device=None,
  29. dtype=None,
  30. ) -> None:
  31. dd = {'device': device, 'dtype': dtype}
  32. super().__init__()
  33. self.downsample = downsample
  34. self.stride = stride
  35. act_layer = partial(nn.LeakyReLU, negative_slope=1e-3)
  36. self.conv1 = ConvNormAct(
  37. inplanes,
  38. planes,
  39. kernel_size=3,
  40. stride=stride,
  41. act_layer=act_layer,
  42. aa_layer=aa_layer,
  43. **dd,
  44. )
  45. self.conv2 = ConvNormAct(planes, planes, kernel_size=3, stride=1, apply_act=False, **dd)
  46. self.act = nn.ReLU(inplace=True)
  47. rd_chs = max(planes * self.expansion // 4, 64)
  48. self.se = SEModule(planes * self.expansion, rd_channels=rd_chs, **dd) if use_se else None
  49. self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
  50. def forward(self, x):
  51. if self.downsample is not None:
  52. shortcut = self.downsample(x)
  53. else:
  54. shortcut = x
  55. out = self.conv1(x)
  56. out = self.conv2(out)
  57. if self.se is not None:
  58. out = self.se(out)
  59. out = self.drop_path(out) + shortcut
  60. out = self.act(out)
  61. return out
  62. class Bottleneck(nn.Module):
  63. expansion = 4
  64. def __init__(
  65. self,
  66. inplanes: int,
  67. planes: int,
  68. stride: int = 1,
  69. downsample: Optional[nn.Module] = None,
  70. use_se: bool = True,
  71. act_layer: Optional[Type[nn.Module]] = None,
  72. aa_layer: Optional[Type[nn.Module]] = None,
  73. drop_path_rate: float = 0.,
  74. device=None,
  75. dtype=None,
  76. ) -> None:
  77. dd = {'device': device, 'dtype': dtype}
  78. super().__init__()
  79. self.downsample = downsample
  80. self.stride = stride
  81. act_layer = act_layer or partial(nn.LeakyReLU, negative_slope=1e-3)
  82. self.conv1 = ConvNormAct(inplanes, planes, kernel_size=1, stride=1, act_layer=act_layer, **dd)
  83. self.conv2 = ConvNormAct(
  84. planes,
  85. planes,
  86. kernel_size=3,
  87. stride=stride,
  88. act_layer=act_layer,
  89. aa_layer=aa_layer,
  90. **dd,
  91. )
  92. reduction_chs = max(planes * self.expansion // 8, 64)
  93. self.se = SEModule(planes, rd_channels=reduction_chs, **dd) if use_se else None
  94. self.conv3 = ConvNormAct(planes, planes * self.expansion, kernel_size=1, stride=1, apply_act=False, **dd)
  95. self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
  96. self.act = nn.ReLU(inplace=True)
  97. def forward(self, x):
  98. if self.downsample is not None:
  99. shortcut = self.downsample(x)
  100. else:
  101. shortcut = x
  102. out = self.conv1(x)
  103. out = self.conv2(out)
  104. if self.se is not None:
  105. out = self.se(out)
  106. out = self.conv3(out)
  107. out = self.drop_path(out) + shortcut
  108. out = self.act(out)
  109. return out
  110. class TResNet(nn.Module):
  111. def __init__(
  112. self,
  113. layers: List[int],
  114. in_chans: int = 3,
  115. num_classes: int = 1000,
  116. width_factor: float = 1.0,
  117. v2: bool = False,
  118. global_pool: str = 'fast',
  119. drop_rate: float = 0.,
  120. drop_path_rate: float = 0.,
  121. device=None,
  122. dtype=None,
  123. ) -> None:
  124. super().__init__()
  125. dd = {'device': device, 'dtype': dtype}
  126. self.num_classes = num_classes
  127. self.in_chans = in_chans
  128. self.drop_rate = drop_rate
  129. self.grad_checkpointing = False
  130. aa_layer = BlurPool2d
  131. act_layer = nn.LeakyReLU
  132. # TResnet stages
  133. self.inplanes = int(64 * width_factor)
  134. self.planes = int(64 * width_factor)
  135. if v2:
  136. self.inplanes = self.inplanes // 8 * 8
  137. self.planes = self.planes // 8 * 8
  138. dpr = calculate_drop_path_rates(drop_path_rate, layers, stagewise=True)
  139. conv1 = ConvNormAct(in_chans * 16, self.planes, stride=1, kernel_size=3, act_layer=act_layer, **dd)
  140. layer1 = self._make_layer(
  141. Bottleneck if v2 else BasicBlock,
  142. self.planes, layers[0], stride=1, use_se=True, aa_layer=aa_layer, drop_path_rate=dpr[0], **dd)
  143. layer2 = self._make_layer(
  144. Bottleneck if v2 else BasicBlock,
  145. self.planes * 2, layers[1], stride=2, use_se=True, aa_layer=aa_layer, drop_path_rate=dpr[1], **dd)
  146. layer3 = self._make_layer(
  147. Bottleneck,
  148. self.planes * 4, layers[2], stride=2, use_se=True, aa_layer=aa_layer, drop_path_rate=dpr[2], **dd)
  149. layer4 = self._make_layer(
  150. Bottleneck,
  151. self.planes * 8, layers[3], stride=2, use_se=False, aa_layer=aa_layer, drop_path_rate=dpr[3], **dd)
  152. # body
  153. self.body = nn.Sequential(OrderedDict([
  154. ('s2d', SpaceToDepth()),
  155. ('conv1', conv1),
  156. ('layer1', layer1),
  157. ('layer2', layer2),
  158. ('layer3', layer3),
  159. ('layer4', layer4),
  160. ]))
  161. self.feature_info = [
  162. dict(num_chs=self.planes, reduction=2, module=''), # Not with S2D?
  163. dict(num_chs=self.planes * (Bottleneck.expansion if v2 else 1), reduction=4, module='body.layer1'),
  164. dict(num_chs=self.planes * 2 * (Bottleneck.expansion if v2 else 1), reduction=8, module='body.layer2'),
  165. dict(num_chs=self.planes * 4 * Bottleneck.expansion, reduction=16, module='body.layer3'),
  166. dict(num_chs=self.planes * 8 * Bottleneck.expansion, reduction=32, module='body.layer4'),
  167. ]
  168. # head
  169. self.num_features = self.head_hidden_size = (self.planes * 8) * Bottleneck.expansion
  170. self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate, **dd)
  171. # model initialization
  172. for m in self.modules():
  173. if isinstance(m, nn.Conv2d):
  174. nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
  175. if isinstance(m, nn.Linear):
  176. m.weight.data.normal_(0, 0.01)
  177. # residual connections special initialization
  178. for m in self.modules():
  179. if isinstance(m, BasicBlock):
  180. nn.init.zeros_(m.conv2.bn.weight)
  181. if isinstance(m, Bottleneck):
  182. nn.init.zeros_(m.conv3.bn.weight)
  183. def _make_layer(
  184. self,
  185. block,
  186. planes,
  187. blocks,
  188. stride=1,
  189. use_se=True,
  190. aa_layer=None,
  191. drop_path_rate=0.,
  192. device=None,
  193. dtype=None,
  194. ):
  195. dd = {'device': device, 'dtype': dtype}
  196. downsample = None
  197. if stride != 1 or self.inplanes != planes * block.expansion:
  198. layers = []
  199. if stride == 2:
  200. # avg pooling before 1x1 conv
  201. layers.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True, count_include_pad=False))
  202. layers += [ConvNormAct(
  203. self.inplanes, planes * block.expansion, kernel_size=1, stride=1, apply_act=False, **dd)]
  204. downsample = nn.Sequential(*layers)
  205. layers = []
  206. for i in range(blocks):
  207. layers.append(block(
  208. self.inplanes,
  209. planes,
  210. stride=stride if i == 0 else 1,
  211. downsample=downsample if i == 0 else None,
  212. use_se=use_se,
  213. aa_layer=aa_layer,
  214. drop_path_rate=drop_path_rate[i] if isinstance(drop_path_rate, list) else drop_path_rate,
  215. **dd,
  216. ))
  217. self.inplanes = planes * block.expansion
  218. return nn.Sequential(*layers)
  219. @torch.jit.ignore
  220. def group_matcher(self, coarse=False):
  221. matcher = dict(stem=r'^body\.conv1', blocks=r'^body\.layer(\d+)' if coarse else r'^body\.layer(\d+)\.(\d+)')
  222. return matcher
  223. @torch.jit.ignore
  224. def set_grad_checkpointing(self, enable=True):
  225. self.grad_checkpointing = enable
  226. @torch.jit.ignore
  227. def get_classifier(self) -> nn.Module:
  228. return self.head.fc
  229. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  230. self.num_classes = num_classes
  231. self.head.reset(num_classes, pool_type=global_pool)
  232. def forward_intermediates(
  233. self,
  234. x: torch.Tensor,
  235. indices: Optional[Union[int, List[int]]] = None,
  236. norm: bool = False,
  237. stop_early: bool = False,
  238. output_fmt: str = 'NCHW',
  239. intermediates_only: bool = False,
  240. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  241. """ Forward features that returns intermediates.
  242. Args:
  243. x: Input image tensor
  244. indices: Take last n blocks if int, all if None, select matching indices if sequence
  245. norm: Apply norm layer to compatible intermediates
  246. stop_early: Stop iterating over blocks when last desired intermediate hit
  247. output_fmt: Shape of intermediate feature outputs
  248. intermediates_only: Only return intermediate features
  249. Returns:
  250. """
  251. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  252. intermediates = []
  253. stage_ends = [1, 2, 3, 4, 5]
  254. take_indices, max_index = feature_take_indices(len(stage_ends), indices)
  255. take_indices = [stage_ends[i] for i in take_indices]
  256. max_index = stage_ends[max_index]
  257. # forward pass
  258. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  259. stages = self.body
  260. else:
  261. stages = self.body[:max_index + 1]
  262. for feat_idx, stage in enumerate(stages):
  263. if self.grad_checkpointing and not torch.jit.is_scripting():
  264. x = checkpoint(stage, x)
  265. else:
  266. x = stage(x)
  267. if feat_idx in take_indices:
  268. intermediates.append(x)
  269. if intermediates_only:
  270. return intermediates
  271. return x, intermediates
  272. def prune_intermediate_layers(
  273. self,
  274. indices: Union[int, List[int]] = 1,
  275. prune_norm: bool = False,
  276. prune_head: bool = True,
  277. ):
  278. """ Prune layers not required for specified intermediates.
  279. """
  280. stage_ends = [1, 2, 3, 4, 5]
  281. take_indices, max_index = feature_take_indices(len(stage_ends), indices)
  282. max_index = stage_ends[max_index]
  283. self.body = self.body[:max_index + 1] # truncate blocks w/ stem as idx 0
  284. if prune_head:
  285. self.reset_classifier(0, '')
  286. return take_indices
  287. def forward_features(self, x):
  288. if self.grad_checkpointing and not torch.jit.is_scripting():
  289. x = self.body.s2d(x)
  290. x = self.body.conv1(x)
  291. x = checkpoint_seq([
  292. self.body.layer1,
  293. self.body.layer2,
  294. self.body.layer3,
  295. self.body.layer4],
  296. x, flatten=True)
  297. else:
  298. x = self.body(x)
  299. return x
  300. def forward_head(self, x, pre_logits: bool = False):
  301. return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
  302. def forward(self, x):
  303. x = self.forward_features(x)
  304. x = self.forward_head(x)
  305. return x
  306. def checkpoint_filter_fn(state_dict, model):
  307. if 'body.conv1.conv.weight' in state_dict:
  308. return state_dict
  309. import re
  310. state_dict = state_dict.get('model', state_dict)
  311. state_dict = state_dict.get('state_dict', state_dict)
  312. out_dict = {}
  313. for k, v in state_dict.items():
  314. k = re.sub(r'conv(\d+)\.0.0', lambda x: f'conv{int(x.group(1))}.conv', k)
  315. k = re.sub(r'conv(\d+)\.0.1', lambda x: f'conv{int(x.group(1))}.bn', k)
  316. k = re.sub(r'conv(\d+)\.0', lambda x: f'conv{int(x.group(1))}.conv', k)
  317. k = re.sub(r'conv(\d+)\.1', lambda x: f'conv{int(x.group(1))}.bn', k)
  318. k = re.sub(r'downsample\.(\d+)\.0', lambda x: f'downsample.{int(x.group(1))}.conv', k)
  319. k = re.sub(r'downsample\.(\d+)\.1', lambda x: f'downsample.{int(x.group(1))}.bn', k)
  320. if k.endswith('bn.weight'):
  321. # convert weight from inplace_abn to batchnorm
  322. v = v.abs().add(1e-5)
  323. out_dict[k] = v
  324. return out_dict
  325. def _create_tresnet(variant, pretrained=False, **kwargs):
  326. return build_model_with_cfg(
  327. TResNet,
  328. variant,
  329. pretrained,
  330. pretrained_filter_fn=checkpoint_filter_fn,
  331. feature_cfg=dict(out_indices=(1, 2, 3, 4), flatten_sequential=True),
  332. **kwargs,
  333. )
  334. def _cfg(url='', **kwargs):
  335. return {
  336. 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  337. 'crop_pct': 0.875, 'interpolation': 'bilinear',
  338. 'mean': (0., 0., 0.), 'std': (1., 1., 1.),
  339. 'first_conv': 'body.conv1.conv', 'classifier': 'head.fc',
  340. 'license': 'apache-2.0',
  341. **kwargs
  342. }
# Pretrained configurations, keyed as '<architecture>.<tag>'. Every entry starts
# from the _cfg() defaults and sets hf_hub_id='timm/'.
default_cfgs = generate_default_cfgs({
    'tresnet_m.miil_in21k_ft_in1k': _cfg(hf_hub_id='timm/'),
    # the in21k entries override the default class count
    'tresnet_m.miil_in21k': _cfg(hf_hub_id='timm/', num_classes=11221),
    'tresnet_m.miil_in1k': _cfg(hf_hub_id='timm/'),
    'tresnet_l.miil_in1k': _cfg(hf_hub_id='timm/'),
    'tresnet_xl.miil_in1k': _cfg(hf_hub_id='timm/'),
    # the *_448 entries use a 448x448 input and a 14x14 pool size
    'tresnet_m.miil_in1k_448': _cfg(
        input_size=(3, 448, 448), pool_size=(14, 14),
        hf_hub_id='timm/'),
    'tresnet_l.miil_in1k_448': _cfg(
        input_size=(3, 448, 448), pool_size=(14, 14),
        hf_hub_id='timm/'),
    'tresnet_xl.miil_in1k_448': _cfg(
        input_size=(3, 448, 448), pool_size=(14, 14),
        hf_hub_id='timm/'),

    'tresnet_v2_l.miil_in21k_ft_in1k': _cfg(hf_hub_id='timm/'),
    'tresnet_v2_l.miil_in21k': _cfg(hf_hub_id='timm/', num_classes=11221),
})
  361. @register_model
  362. def tresnet_m(pretrained=False, **kwargs) -> TResNet:
  363. model_args = dict(layers=[3, 4, 11, 3])
  364. return _create_tresnet('tresnet_m', pretrained=pretrained, **dict(model_args, **kwargs))
  365. @register_model
  366. def tresnet_l(pretrained=False, **kwargs) -> TResNet:
  367. model_args = dict(layers=[4, 5, 18, 3], width_factor=1.2)
  368. return _create_tresnet('tresnet_l', pretrained=pretrained, **dict(model_args, **kwargs))
  369. @register_model
  370. def tresnet_xl(pretrained=False, **kwargs) -> TResNet:
  371. model_args = dict(layers=[4, 5, 24, 3], width_factor=1.3)
  372. return _create_tresnet('tresnet_xl', pretrained=pretrained, **dict(model_args, **kwargs))
  373. @register_model
  374. def tresnet_v2_l(pretrained=False, **kwargs) -> TResNet:
  375. model_args = dict(layers=[3, 4, 23, 3], width_factor=1.0, v2=True)
  376. return _create_tresnet('tresnet_v2_l', pretrained=pretrained, **dict(model_args, **kwargs))
# Map old flat model names to their current '<architecture>.<tag>' names.
register_model_deprecations(__name__, {
    'tresnet_m_miil_in21k': 'tresnet_m.miil_in21k',
    'tresnet_m_448': 'tresnet_m.miil_in1k_448',
    'tresnet_l_448': 'tresnet_l.miil_in1k_448',
    'tresnet_xl_448': 'tresnet_xl.miil_in1k_448',
})