xception_aligned.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483
"""PyTorch implementation of Aligned Xception 41, 65, 71
This is a correct, from-scratch implementation of the Aligned Xception (DeepLab) models, compatible with the TF weights at
https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zoo.md
Hacked together by / Copyright 2020 Ross Wightman
"""
  6. from functools import partial
  7. from typing import List, Dict, Type, Optional
  8. import torch
  9. import torch.nn as nn
  10. from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
  11. from timm.layers import ClassifierHead, ConvNormAct, DropPath, PadType, create_conv2d, get_norm_act_layer
  12. from timm.layers.helpers import to_3tuple
  13. from ._builder import build_model_with_cfg
  14. from ._manipulate import checkpoint_seq
  15. from ._registry import register_model, generate_default_cfgs
  16. __all__ = ['XceptionAligned']
  17. class SeparableConv2d(nn.Module):
  18. def __init__(
  19. self,
  20. in_chs: int,
  21. out_chs: int,
  22. kernel_size: int = 3,
  23. stride: int = 1,
  24. dilation: int = 1,
  25. padding: PadType = '',
  26. act_layer: Type[nn.Module] = nn.ReLU,
  27. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  28. device=None,
  29. dtype=None,
  30. ):
  31. dd = {'device': device, 'dtype': dtype}
  32. super().__init__()
  33. self.kernel_size = kernel_size
  34. self.dilation = dilation
  35. # depthwise convolution
  36. self.conv_dw = create_conv2d(
  37. in_chs,
  38. in_chs,
  39. kernel_size,
  40. stride=stride,
  41. padding=padding,
  42. dilation=dilation,
  43. depthwise=True,
  44. **dd,
  45. )
  46. self.bn_dw = norm_layer(in_chs, **dd)
  47. self.act_dw = act_layer(inplace=True) if act_layer is not None else nn.Identity()
  48. # pointwise convolution
  49. self.conv_pw = create_conv2d(in_chs, out_chs, kernel_size=1, **dd)
  50. self.bn_pw = norm_layer(out_chs, **dd)
  51. self.act_pw = act_layer(inplace=True) if act_layer is not None else nn.Identity()
  52. def forward(self, x):
  53. x = self.conv_dw(x)
  54. x = self.bn_dw(x)
  55. x = self.act_dw(x)
  56. x = self.conv_pw(x)
  57. x = self.bn_pw(x)
  58. x = self.act_pw(x)
  59. return x
  60. class PreSeparableConv2d(nn.Module):
  61. def __init__(
  62. self,
  63. in_chs: int,
  64. out_chs: int,
  65. kernel_size: int = 3,
  66. stride: int = 1,
  67. dilation: int = 1,
  68. padding: PadType = '',
  69. act_layer: Type[nn.Module] = nn.ReLU,
  70. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  71. first_act: bool = True,
  72. device=None,
  73. dtype=None,
  74. ):
  75. dd = {'device': device, 'dtype': dtype}
  76. super().__init__()
  77. norm_act_layer = get_norm_act_layer(norm_layer, act_layer=act_layer)
  78. self.kernel_size = kernel_size
  79. self.dilation = dilation
  80. self.norm = norm_act_layer(in_chs, inplace=True, **dd) if first_act else nn.Identity()
  81. # depthwise convolution
  82. self.conv_dw = create_conv2d(
  83. in_chs,
  84. in_chs,
  85. kernel_size,
  86. stride=stride,
  87. padding=padding,
  88. dilation=dilation,
  89. depthwise=True,
  90. **dd,
  91. )
  92. # pointwise convolution
  93. self.conv_pw = create_conv2d(in_chs, out_chs, kernel_size=1, **dd)
  94. def forward(self, x):
  95. x = self.norm(x)
  96. x = self.conv_dw(x)
  97. x = self.conv_pw(x)
  98. return x
  99. class XceptionModule(nn.Module):
  100. def __init__(
  101. self,
  102. in_chs: int,
  103. out_chs: int,
  104. stride: int = 1,
  105. dilation: int = 1,
  106. pad_type: PadType = '',
  107. start_with_relu: bool = True,
  108. no_skip: bool = False,
  109. act_layer: Type[nn.Module] = nn.ReLU,
  110. norm_layer: Optional[Type[nn.Module]] = None,
  111. drop_path: Optional[nn.Module] = None,
  112. device=None,
  113. dtype=None,
  114. ):
  115. dd = {'device': device, 'dtype': dtype}
  116. super().__init__()
  117. out_chs = to_3tuple(out_chs)
  118. self.in_channels = in_chs
  119. self.out_channels = out_chs[-1]
  120. self.no_skip = no_skip
  121. if not no_skip and (self.out_channels != self.in_channels or stride != 1):
  122. self.shortcut = ConvNormAct(
  123. in_chs,
  124. self.out_channels,
  125. 1,
  126. stride=stride,
  127. norm_layer=norm_layer,
  128. apply_act=False,
  129. **dd,
  130. )
  131. else:
  132. self.shortcut = None
  133. separable_act_layer = None if start_with_relu else act_layer
  134. self.stack = nn.Sequential()
  135. for i in range(3):
  136. if start_with_relu:
  137. self.stack.add_module(f'act{i + 1}', act_layer(inplace=i > 0))
  138. self.stack.add_module(f'conv{i + 1}', SeparableConv2d(
  139. in_chs,
  140. out_chs[i],
  141. 3,
  142. stride=stride if i == 2 else 1,
  143. dilation=dilation,
  144. padding=pad_type,
  145. act_layer=separable_act_layer,
  146. norm_layer=norm_layer,
  147. **dd,
  148. ))
  149. in_chs = out_chs[i]
  150. self.drop_path = drop_path
  151. def forward(self, x):
  152. skip = x
  153. x = self.stack(x)
  154. if self.shortcut is not None:
  155. skip = self.shortcut(skip)
  156. if not self.no_skip:
  157. if self.drop_path is not None:
  158. x = self.drop_path(x)
  159. x = x + skip
  160. return x
  161. class PreXceptionModule(nn.Module):
  162. def __init__(
  163. self,
  164. in_chs: int,
  165. out_chs: int,
  166. stride: int = 1,
  167. dilation: int = 1,
  168. pad_type: PadType = '',
  169. no_skip: bool = False,
  170. act_layer: Type[nn.Module] = nn.ReLU,
  171. norm_layer: Optional[Type[nn.Module]] = None,
  172. drop_path: Optional[nn.Module] = None,
  173. device=None,
  174. dtype=None,
  175. ):
  176. dd = {'device': device, 'dtype': dtype}
  177. super().__init__()
  178. out_chs = to_3tuple(out_chs)
  179. self.in_channels = in_chs
  180. self.out_channels = out_chs[-1]
  181. self.no_skip = no_skip
  182. if not no_skip and (self.out_channels != self.in_channels or stride != 1):
  183. self.shortcut = create_conv2d(in_chs, self.out_channels, 1, stride=stride, **dd)
  184. else:
  185. self.shortcut = nn.Identity()
  186. self.norm = get_norm_act_layer(norm_layer, act_layer=act_layer)(in_chs, inplace=True, **dd)
  187. self.stack = nn.Sequential()
  188. for i in range(3):
  189. self.stack.add_module(f'conv{i + 1}', PreSeparableConv2d(
  190. in_chs,
  191. out_chs[i],
  192. 3,
  193. stride=stride if i == 2 else 1,
  194. dilation=dilation,
  195. padding=pad_type,
  196. act_layer=act_layer,
  197. norm_layer=norm_layer,
  198. first_act=i > 0,
  199. **dd,
  200. ))
  201. in_chs = out_chs[i]
  202. self.drop_path = drop_path
  203. def forward(self, x):
  204. x = self.norm(x)
  205. skip = x
  206. x = self.stack(x)
  207. if not self.no_skip:
  208. if self.drop_path is not None:
  209. x = self.drop_path(x)
  210. x = x + self.shortcut(skip)
  211. return x
  212. class XceptionAligned(nn.Module):
  213. """Modified Aligned Xception
  214. """
  215. def __init__(
  216. self,
  217. block_cfg: List[Dict],
  218. num_classes: int = 1000,
  219. in_chans: int = 3,
  220. output_stride: int = 32,
  221. preact: bool = False,
  222. act_layer: Type[nn.Module] = nn.ReLU,
  223. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  224. drop_rate: float = 0.,
  225. drop_path_rate: float = 0.,
  226. global_pool: str = 'avg',
  227. device=None,
  228. dtype=None,
  229. ):
  230. super().__init__()
  231. dd = {'device': device, 'dtype': dtype}
  232. assert output_stride in (8, 16, 32)
  233. self.num_classes = num_classes
  234. self.in_chans = in_chans
  235. self.drop_rate = drop_rate
  236. self.grad_checkpointing = False
  237. layer_args = dict(act_layer=act_layer, norm_layer=norm_layer, **dd)
  238. self.stem = nn.Sequential(*[
  239. ConvNormAct(in_chans, 32, kernel_size=3, stride=2, **layer_args),
  240. create_conv2d(32, 64, kernel_size=3, stride=1, **dd) if preact else
  241. ConvNormAct(32, 64, kernel_size=3, stride=1, **layer_args)
  242. ])
  243. curr_dilation = 1
  244. curr_stride = 2
  245. self.feature_info = []
  246. self.blocks = nn.Sequential()
  247. module_fn = PreXceptionModule if preact else XceptionModule
  248. net_num_blocks = len(block_cfg)
  249. net_block_idx = 0
  250. for i, b in enumerate(block_cfg):
  251. block_dpr = drop_path_rate * net_block_idx / (net_num_blocks - 1) # stochastic depth linear decay rule
  252. b['drop_path'] = DropPath(block_dpr) if block_dpr > 0. else None
  253. b['dilation'] = curr_dilation
  254. if b['stride'] > 1:
  255. name = f'blocks.{i}.stack.conv2' if preact else f'blocks.{i}.stack.act3'
  256. self.feature_info += [dict(num_chs=to_3tuple(b['out_chs'])[-2], reduction=curr_stride, module=name)]
  257. next_stride = curr_stride * b['stride']
  258. if next_stride > output_stride:
  259. curr_dilation *= b['stride']
  260. b['stride'] = 1
  261. else:
  262. curr_stride = next_stride
  263. self.blocks.add_module(str(i), module_fn(**b, **layer_args))
  264. self.num_features = self.blocks[-1].out_channels
  265. net_block_idx += 1
  266. self.feature_info += [dict(
  267. num_chs=self.num_features, reduction=curr_stride, module='blocks.' + str(len(self.blocks) - 1))]
  268. self.act = act_layer(inplace=True) if preact else nn.Identity()
  269. self.head_hidden_size = self.num_features
  270. self.head = ClassifierHead(
  271. in_features=self.num_features,
  272. num_classes=num_classes,
  273. pool_type=global_pool,
  274. drop_rate=drop_rate,
  275. **dd,
  276. )
  277. @torch.jit.ignore
  278. def group_matcher(self, coarse=False):
  279. return dict(
  280. stem=r'^stem',
  281. blocks=r'^blocks\.(\d+)',
  282. )
  283. @torch.jit.ignore
  284. def set_grad_checkpointing(self, enable=True):
  285. self.grad_checkpointing = enable
  286. @torch.jit.ignore
  287. def get_classifier(self) -> nn.Module:
  288. return self.head.fc
  289. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  290. self.num_classes = num_classes
  291. self.head.reset(num_classes, pool_type=global_pool)
  292. def forward_features(self, x):
  293. x = self.stem(x)
  294. if self.grad_checkpointing and not torch.jit.is_scripting():
  295. x = checkpoint_seq(self.blocks, x)
  296. else:
  297. x = self.blocks(x)
  298. x = self.act(x)
  299. return x
  300. def forward_head(self, x, pre_logits: bool = False):
  301. return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
  302. def forward(self, x):
  303. x = self.forward_features(x)
  304. x = self.forward_head(x)
  305. return x
  306. def _xception(variant, pretrained=False, **kwargs):
  307. return build_model_with_cfg(
  308. XceptionAligned,
  309. variant,
  310. pretrained,
  311. feature_cfg=dict(flatten_sequential=True, feature_cls='hook'),
  312. **kwargs,
  313. )
  314. def _cfg(url='', **kwargs):
  315. return {
  316. 'url': url,
  317. 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (10, 10),
  318. 'crop_pct': 0.903, 'interpolation': 'bicubic',
  319. 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
  320. 'first_conv': 'stem.0.conv', 'classifier': 'head.fc', 'license': 'apache-2.0',
  321. **kwargs
  322. }
  323. default_cfgs = generate_default_cfgs({
  324. 'xception65.ra3_in1k': _cfg(
  325. hf_hub_id='timm/',
  326. crop_pct=0.94,
  327. ),
  328. 'xception41.tf_in1k': _cfg(hf_hub_id='timm/'),
  329. 'xception65.tf_in1k': _cfg(hf_hub_id='timm/'),
  330. 'xception71.tf_in1k': _cfg(hf_hub_id='timm/'),
  331. 'xception41p.ra3_in1k': _cfg(
  332. hf_hub_id='timm/',
  333. crop_pct=0.94,
  334. ),
  335. 'xception65p.ra3_in1k': _cfg(
  336. hf_hub_id='timm/',
  337. crop_pct=0.94,
  338. ),
  339. })
  340. @register_model
  341. def xception41(pretrained=False, **kwargs) -> XceptionAligned:
  342. """ Modified Aligned Xception-41
  343. """
  344. block_cfg = [
  345. # entry flow
  346. dict(in_chs=64, out_chs=128, stride=2),
  347. dict(in_chs=128, out_chs=256, stride=2),
  348. dict(in_chs=256, out_chs=728, stride=2),
  349. # middle flow
  350. *([dict(in_chs=728, out_chs=728, stride=1)] * 8),
  351. # exit flow
  352. dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2),
  353. dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False),
  354. ]
  355. model_args = dict(block_cfg=block_cfg, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1))
  356. return _xception('xception41', pretrained=pretrained, **dict(model_args, **kwargs))
  357. @register_model
  358. def xception65(pretrained=False, **kwargs) -> XceptionAligned:
  359. """ Modified Aligned Xception-65
  360. """
  361. block_cfg = [
  362. # entry flow
  363. dict(in_chs=64, out_chs=128, stride=2),
  364. dict(in_chs=128, out_chs=256, stride=2),
  365. dict(in_chs=256, out_chs=728, stride=2),
  366. # middle flow
  367. *([dict(in_chs=728, out_chs=728, stride=1)] * 16),
  368. # exit flow
  369. dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2),
  370. dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False),
  371. ]
  372. model_args = dict(block_cfg=block_cfg, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1))
  373. return _xception('xception65', pretrained=pretrained, **dict(model_args, **kwargs))
  374. @register_model
  375. def xception71(pretrained=False, **kwargs) -> XceptionAligned:
  376. """ Modified Aligned Xception-71
  377. """
  378. block_cfg = [
  379. # entry flow
  380. dict(in_chs=64, out_chs=128, stride=2),
  381. dict(in_chs=128, out_chs=256, stride=1),
  382. dict(in_chs=256, out_chs=256, stride=2),
  383. dict(in_chs=256, out_chs=728, stride=1),
  384. dict(in_chs=728, out_chs=728, stride=2),
  385. # middle flow
  386. *([dict(in_chs=728, out_chs=728, stride=1)] * 16),
  387. # exit flow
  388. dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2),
  389. dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False),
  390. ]
  391. model_args = dict(block_cfg=block_cfg, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1))
  392. return _xception('xception71', pretrained=pretrained, **dict(model_args, **kwargs))
  393. @register_model
  394. def xception41p(pretrained=False, **kwargs) -> XceptionAligned:
  395. """ Modified Aligned Xception-41 w/ Pre-Act
  396. """
  397. block_cfg = [
  398. # entry flow
  399. dict(in_chs=64, out_chs=128, stride=2),
  400. dict(in_chs=128, out_chs=256, stride=2),
  401. dict(in_chs=256, out_chs=728, stride=2),
  402. # middle flow
  403. *([dict(in_chs=728, out_chs=728, stride=1)] * 8),
  404. # exit flow
  405. dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2),
  406. dict(in_chs=1024, out_chs=(1536, 1536, 2048), no_skip=True, stride=1),
  407. ]
  408. model_args = dict(block_cfg=block_cfg, preact=True, norm_layer=nn.BatchNorm2d)
  409. return _xception('xception41p', pretrained=pretrained, **dict(model_args, **kwargs))
  410. @register_model
  411. def xception65p(pretrained=False, **kwargs) -> XceptionAligned:
  412. """ Modified Aligned Xception-65 w/ Pre-Act
  413. """
  414. block_cfg = [
  415. # entry flow
  416. dict(in_chs=64, out_chs=128, stride=2),
  417. dict(in_chs=128, out_chs=256, stride=2),
  418. dict(in_chs=256, out_chs=728, stride=2),
  419. # middle flow
  420. *([dict(in_chs=728, out_chs=728, stride=1)] * 16),
  421. # exit flow
  422. dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2),
  423. dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True),
  424. ]
  425. model_args = dict(
  426. block_cfg=block_cfg, preact=True, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1))
  427. return _xception('xception65p', pretrained=pretrained, **dict(model_args, **kwargs))