dpn.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387
  1. """ PyTorch implementation of DualPathNetworks
  2. Based on original MXNet implementation https://github.com/cypw/DPNs with
  3. many ideas from another PyTorch implementation https://github.com/oyam/pytorch-DPNs.
  4. This implementation is compatible with the pretrained weights from cypw's MXNet implementation.
  5. Hacked together by / Copyright 2020 Ross Wightman
  6. """
  7. from collections import OrderedDict
  8. from functools import partial
  9. from typing import Tuple, Type, Optional
  10. import torch
  11. import torch.nn as nn
  12. import torch.nn.functional as F
  13. from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  14. from timm.layers import BatchNormAct2d, ConvNormAct, create_conv2d, create_classifier, get_norm_act_layer
  15. from ._builder import build_model_with_cfg
  16. from ._registry import register_model, generate_default_cfgs
  17. __all__ = ['DPN']
  18. class CatBnAct(nn.Module):
  19. def __init__(
  20. self,
  21. in_chs: int,
  22. norm_layer: Type[nn.Module] = BatchNormAct2d,
  23. device=None,
  24. dtype=None,
  25. ):
  26. dd = {'device': device, 'dtype': dtype}
  27. super().__init__()
  28. self.bn = norm_layer(in_chs, eps=0.001, **dd)
  29. def forward(self, x):
  30. if isinstance(x, tuple):
  31. x = torch.cat(x, dim=1)
  32. return self.bn(x)
  33. class BnActConv2d(nn.Module):
  34. def __init__(
  35. self,
  36. in_chs: int,
  37. out_chs: int,
  38. kernel_size: int,
  39. stride: int,
  40. groups: int = 1,
  41. norm_layer: Type[nn.Module] = BatchNormAct2d,
  42. device=None,
  43. dtype=None,
  44. ):
  45. dd = {'device': device, 'dtype': dtype}
  46. super().__init__()
  47. self.bn = norm_layer(in_chs, eps=0.001, **dd)
  48. self.conv = create_conv2d(in_chs, out_chs, kernel_size, stride=stride, groups=groups, **dd)
  49. def forward(self, x):
  50. return self.conv(self.bn(x))
  51. class DualPathBlock(nn.Module):
  52. def __init__(
  53. self,
  54. in_chs: int,
  55. num_1x1_a: int,
  56. num_3x3_b: int,
  57. num_1x1_c: int,
  58. inc: int,
  59. groups: int,
  60. block_type: str = 'normal',
  61. b: bool = False,
  62. device=None,
  63. dtype=None,
  64. ):
  65. dd = {'device': device, 'dtype': dtype}
  66. super().__init__()
  67. self.num_1x1_c = num_1x1_c
  68. self.inc = inc
  69. self.b = b
  70. if block_type == 'proj':
  71. self.key_stride = 1
  72. self.has_proj = True
  73. elif block_type == 'down':
  74. self.key_stride = 2
  75. self.has_proj = True
  76. else:
  77. assert block_type == 'normal'
  78. self.key_stride = 1
  79. self.has_proj = False
  80. self.c1x1_w_s1 = None
  81. self.c1x1_w_s2 = None
  82. if self.has_proj:
  83. # Using different member names here to allow easier parameter key matching for conversion
  84. if self.key_stride == 2:
  85. self.c1x1_w_s2 = BnActConv2d(
  86. in_chs=in_chs, out_chs=num_1x1_c + 2 * inc, kernel_size=1, stride=2, **dd)
  87. else:
  88. self.c1x1_w_s1 = BnActConv2d(
  89. in_chs=in_chs, out_chs=num_1x1_c + 2 * inc, kernel_size=1, stride=1, **dd)
  90. self.c1x1_a = BnActConv2d(in_chs=in_chs, out_chs=num_1x1_a, kernel_size=1, stride=1, **dd)
  91. self.c3x3_b = BnActConv2d(
  92. in_chs=num_1x1_a, out_chs=num_3x3_b, kernel_size=3, stride=self.key_stride, groups=groups, **dd)
  93. if b:
  94. self.c1x1_c = CatBnAct(in_chs=num_3x3_b, **dd)
  95. self.c1x1_c1 = create_conv2d(num_3x3_b, num_1x1_c, kernel_size=1, **dd)
  96. self.c1x1_c2 = create_conv2d(num_3x3_b, inc, kernel_size=1, **dd)
  97. else:
  98. self.c1x1_c = BnActConv2d(in_chs=num_3x3_b, out_chs=num_1x1_c + inc, kernel_size=1, stride=1, **dd)
  99. self.c1x1_c1 = None
  100. self.c1x1_c2 = None
  101. def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
  102. if isinstance(x, tuple):
  103. x_in = torch.cat(x, dim=1)
  104. else:
  105. x_in = x
  106. if self.c1x1_w_s1 is None and self.c1x1_w_s2 is None:
  107. # self.has_proj == False, torchscript requires condition on module == None
  108. x_s1 = x[0]
  109. x_s2 = x[1]
  110. else:
  111. # self.has_proj == True
  112. if self.c1x1_w_s1 is not None:
  113. # self.key_stride = 1
  114. x_s = self.c1x1_w_s1(x_in)
  115. else:
  116. # self.key_stride = 2
  117. x_s = self.c1x1_w_s2(x_in)
  118. x_s1 = x_s[:, :self.num_1x1_c, :, :]
  119. x_s2 = x_s[:, self.num_1x1_c:, :, :]
  120. x_in = self.c1x1_a(x_in)
  121. x_in = self.c3x3_b(x_in)
  122. x_in = self.c1x1_c(x_in)
  123. if self.c1x1_c1 is not None:
  124. # self.b == True, using None check for torchscript compat
  125. out1 = self.c1x1_c1(x_in)
  126. out2 = self.c1x1_c2(x_in)
  127. else:
  128. out1 = x_in[:, :self.num_1x1_c, :, :]
  129. out2 = x_in[:, self.num_1x1_c:, :, :]
  130. resid = x_s1 + out1
  131. dense = torch.cat([x_s2, out2], dim=1)
  132. return resid, dense
  133. class DPN(nn.Module):
  134. def __init__(
  135. self,
  136. k_sec: Tuple[int, ...] = (3, 4, 20, 3),
  137. inc_sec: Tuple[int, ...] = (16, 32, 24, 128),
  138. k_r: int = 96,
  139. groups: int = 32,
  140. num_classes: int = 1000,
  141. in_chans: int = 3,
  142. output_stride: int = 32,
  143. global_pool: str = 'avg',
  144. small: bool = False,
  145. num_init_features: int = 64,
  146. b: bool = False,
  147. drop_rate: float = 0.,
  148. norm_layer: str = 'batchnorm2d',
  149. act_layer: str = 'relu',
  150. fc_act_layer: str = 'elu',
  151. device=None,
  152. dtype=None,
  153. ):
  154. super().__init__()
  155. dd = {'device': device, 'dtype': dtype}
  156. self.num_classes = num_classes
  157. self.in_chans = in_chans
  158. self.drop_rate = drop_rate
  159. self.b = b
  160. assert output_stride == 32 # FIXME look into dilation support
  161. norm_layer = partial(get_norm_act_layer(norm_layer, act_layer=act_layer), eps=.001)
  162. fc_norm_layer = partial(get_norm_act_layer(norm_layer, act_layer=fc_act_layer), eps=.001, inplace=False)
  163. bw_factor = 1 if small else 4
  164. blocks = OrderedDict()
  165. # conv1
  166. blocks['conv1_1'] = ConvNormAct(
  167. in_chans,
  168. num_init_features,
  169. kernel_size=3 if small else 7,
  170. stride=2,
  171. norm_layer=norm_layer,
  172. **dd,
  173. )
  174. blocks['conv1_pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  175. self.feature_info = [dict(num_chs=num_init_features, reduction=2, module='features.conv1_1')]
  176. # conv2
  177. bw = 64 * bw_factor
  178. inc = inc_sec[0]
  179. r = (k_r * bw) // (64 * bw_factor)
  180. blocks['conv2_1'] = DualPathBlock(num_init_features, r, r, bw, inc, groups, 'proj', b, **dd)
  181. in_chs = bw + 3 * inc
  182. for i in range(2, k_sec[0] + 1):
  183. blocks['conv2_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b, **dd)
  184. in_chs += inc
  185. self.feature_info += [dict(num_chs=in_chs, reduction=4, module=f'features.conv2_{k_sec[0]}')]
  186. # conv3
  187. bw = 128 * bw_factor
  188. inc = inc_sec[1]
  189. r = (k_r * bw) // (64 * bw_factor)
  190. blocks['conv3_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b, **dd)
  191. in_chs = bw + 3 * inc
  192. for i in range(2, k_sec[1] + 1):
  193. blocks['conv3_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b, **dd)
  194. in_chs += inc
  195. self.feature_info += [dict(num_chs=in_chs, reduction=8, module=f'features.conv3_{k_sec[1]}')]
  196. # conv4
  197. bw = 256 * bw_factor
  198. inc = inc_sec[2]
  199. r = (k_r * bw) // (64 * bw_factor)
  200. blocks['conv4_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b, **dd)
  201. in_chs = bw + 3 * inc
  202. for i in range(2, k_sec[2] + 1):
  203. blocks['conv4_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b, **dd)
  204. in_chs += inc
  205. self.feature_info += [dict(num_chs=in_chs, reduction=16, module=f'features.conv4_{k_sec[2]}')]
  206. # conv5
  207. bw = 512 * bw_factor
  208. inc = inc_sec[3]
  209. r = (k_r * bw) // (64 * bw_factor)
  210. blocks['conv5_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b, **dd)
  211. in_chs = bw + 3 * inc
  212. for i in range(2, k_sec[3] + 1):
  213. blocks['conv5_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b, **dd)
  214. in_chs += inc
  215. self.feature_info += [dict(num_chs=in_chs, reduction=32, module=f'features.conv5_{k_sec[3]}')]
  216. blocks['conv5_bn_ac'] = CatBnAct(in_chs, norm_layer=fc_norm_layer, **dd)
  217. self.num_features = self.head_hidden_size = in_chs
  218. self.features = nn.Sequential(blocks)
  219. # Using 1x1 conv for the FC layer to allow the extra pooling scheme
  220. self.global_pool, self.classifier = create_classifier(
  221. self.num_features,
  222. self.num_classes,
  223. pool_type=global_pool,
  224. use_conv=True,
  225. **dd,
  226. )
  227. self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
  228. @torch.jit.ignore
  229. def group_matcher(self, coarse=False):
  230. matcher = dict(
  231. stem=r'^features\.conv1',
  232. blocks=[
  233. (r'^features\.conv(\d+)' if coarse else r'^features\.conv(\d+)_(\d+)', None),
  234. (r'^features\.conv5_bn_ac', (99999,))
  235. ]
  236. )
  237. return matcher
  238. @torch.jit.ignore
  239. def set_grad_checkpointing(self, enable=True):
  240. assert not enable, 'gradient checkpointing not supported'
  241. @torch.jit.ignore
  242. def get_classifier(self) -> nn.Module:
  243. return self.classifier
  244. def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
  245. self.num_classes = num_classes
  246. self.global_pool, self.classifier = create_classifier(
  247. self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
  248. self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
  249. def forward_features(self, x):
  250. return self.features(x)
  251. def forward_head(self, x, pre_logits: bool = False):
  252. x = self.global_pool(x)
  253. if self.drop_rate > 0.:
  254. x = F.dropout(x, p=self.drop_rate, training=self.training)
  255. if pre_logits:
  256. return self.flatten(x)
  257. x = self.classifier(x)
  258. return self.flatten(x)
  259. def forward(self, x):
  260. x = self.forward_features(x)
  261. x = self.forward_head(x)
  262. return x
  263. def _create_dpn(variant, pretrained=False, **kwargs):
  264. return build_model_with_cfg(
  265. DPN,
  266. variant,
  267. pretrained,
  268. feature_cfg=dict(feature_concat=True, flatten_sequential=True),
  269. **kwargs,
  270. )
  271. def _cfg(url='', **kwargs):
  272. return {
  273. 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  274. 'crop_pct': 0.875, 'interpolation': 'bicubic',
  275. 'mean': IMAGENET_DPN_MEAN, 'std': IMAGENET_DPN_STD,
  276. 'first_conv': 'features.conv1_1.conv', 'classifier': 'classifier', 'license': 'apache-2.0',
  277. **kwargs
  278. }
  279. default_cfgs = generate_default_cfgs({
  280. 'dpn48b.untrained': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
  281. 'dpn68.mx_in1k': _cfg(hf_hub_id='timm/'),
  282. 'dpn68b.ra_in1k': _cfg(
  283. hf_hub_id='timm/',
  284. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
  285. crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
  286. 'dpn68b.mx_in1k': _cfg(hf_hub_id='timm/'),
  287. 'dpn92.mx_in1k': _cfg(hf_hub_id='timm/'),
  288. 'dpn98.mx_in1k': _cfg(hf_hub_id='timm/'),
  289. 'dpn131.mx_in1k': _cfg(hf_hub_id='timm/'),
  290. 'dpn107.mx_in1k': _cfg(hf_hub_id='timm/')
  291. })
  292. @register_model
  293. def dpn48b(pretrained=False, **kwargs) -> DPN:
  294. model_args = dict(
  295. small=True, num_init_features=10, k_r=128, groups=32,
  296. b=True, k_sec=(3, 4, 6, 3), inc_sec=(16, 32, 32, 64), act_layer='silu')
  297. return _create_dpn('dpn48b', pretrained=pretrained, **dict(model_args, **kwargs))
  298. @register_model
  299. def dpn68(pretrained=False, **kwargs) -> DPN:
  300. model_args = dict(
  301. small=True, num_init_features=10, k_r=128, groups=32,
  302. k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64))
  303. return _create_dpn('dpn68', pretrained=pretrained, **dict(model_args, **kwargs))
  304. @register_model
  305. def dpn68b(pretrained=False, **kwargs) -> DPN:
  306. model_args = dict(
  307. small=True, num_init_features=10, k_r=128, groups=32,
  308. b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64))
  309. return _create_dpn('dpn68b', pretrained=pretrained, **dict(model_args, **kwargs))
  310. @register_model
  311. def dpn92(pretrained=False, **kwargs) -> DPN:
  312. model_args = dict(
  313. num_init_features=64, k_r=96, groups=32,
  314. k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128))
  315. return _create_dpn('dpn92', pretrained=pretrained, **dict(model_args, **kwargs))
  316. @register_model
  317. def dpn98(pretrained=False, **kwargs) -> DPN:
  318. model_args = dict(
  319. num_init_features=96, k_r=160, groups=40,
  320. k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128))
  321. return _create_dpn('dpn98', pretrained=pretrained, **dict(model_args, **kwargs))
  322. @register_model
  323. def dpn131(pretrained=False, **kwargs) -> DPN:
  324. model_args = dict(
  325. num_init_features=128, k_r=160, groups=40,
  326. k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128))
  327. return _create_dpn('dpn131', pretrained=pretrained, **dict(model_args, **kwargs))
  328. @register_model
  329. def dpn107(pretrained=False, **kwargs) -> DPN:
  330. model_args = dict(
  331. num_init_features=128, k_r=200, groups=50,
  332. k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128))
  333. return _create_dpn('dpn107', pretrained=pretrained, **dict(model_args, **kwargs))