# densenet.py
  1. """Pytorch Densenet implementation w/ tweaks
  2. This file is a copy of https://github.com/pytorch/vision 'densenet.py' (BSD-3-Clause) with
  3. fixed kwargs passthrough and addition of dynamic global avg/max pool.
  4. """
  5. import re
  6. from collections import OrderedDict
  7. from typing import Any, Dict, Optional, Tuple, Type, Union
  8. import torch
  9. import torch.nn as nn
  10. import torch.nn.functional as F
  11. from torch.jit.annotations import List
  12. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  13. from timm.layers import BatchNormAct2d, get_norm_act_layer, BlurPool2d, create_classifier
  14. from ._builder import build_model_with_cfg
  15. from ._manipulate import MATCH_PREV_GROUP, checkpoint
  16. from ._registry import register_model, generate_default_cfgs, register_model_deprecations
# Names exported by `from <this module> import *`; only the model class is listed.
__all__ = ['DenseNet']
  18. class DenseLayer(nn.Module):
  19. """Dense layer for DenseNet.
  20. Implements the bottleneck layer with 1x1 and 3x3 convolutions.
  21. """
  22. def __init__(
  23. self,
  24. num_input_features: int,
  25. growth_rate: int,
  26. bn_size: int,
  27. norm_layer: Type[nn.Module] = BatchNormAct2d,
  28. drop_rate: float = 0.,
  29. grad_checkpointing: bool = False,
  30. device=None,
  31. dtype=None,
  32. ) -> None:
  33. """Initialize DenseLayer.
  34. Args:
  35. num_input_features: Number of input features.
  36. growth_rate: Growth rate (k) of the layer.
  37. bn_size: Bottleneck size multiplier.
  38. norm_layer: Normalization layer class.
  39. drop_rate: Dropout rate.
  40. grad_checkpointing: Use gradient checkpointing.
  41. """
  42. dd = {'device': device, 'dtype': dtype}
  43. super().__init__()
  44. self.add_module('norm1', norm_layer(num_input_features, **dd)),
  45. self.add_module('conv1', nn.Conv2d(
  46. num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False, **dd)),
  47. self.add_module('norm2', norm_layer(bn_size * growth_rate, **dd)),
  48. self.add_module('conv2', nn.Conv2d(
  49. bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False, **dd)),
  50. self.drop_rate = float(drop_rate)
  51. self.grad_checkpointing = grad_checkpointing
  52. def bottleneck_fn(self, xs: List[torch.Tensor]) -> torch.Tensor:
  53. """Bottleneck function for concatenated features."""
  54. concated_features = torch.cat(xs, 1)
  55. bottleneck_output = self.conv1(self.norm1(concated_features)) # noqa: T484
  56. return bottleneck_output
  57. # todo: rewrite when torchscript supports any
  58. def any_requires_grad(self, x: List[torch.Tensor]) -> bool:
  59. """Check if any tensor in list requires gradient."""
  60. for tensor in x:
  61. if tensor.requires_grad:
  62. return True
  63. return False
  64. def call_checkpoint_bottleneck(self, x: List[torch.Tensor]) -> torch.Tensor:
  65. """Call bottleneck function with gradient checkpointing."""
  66. def closure(*xs):
  67. return self.bottleneck_fn(xs)
  68. return checkpoint(closure, *x)
  69. # torchscript does not yet support *args, so we overload method
  70. # allowing it to take either a List[Tensor] or single Tensor
  71. def forward(self, x: Union[torch.Tensor, List[torch.Tensor]]) -> torch.Tensor: # noqa: F811
  72. """Forward pass.
  73. Args:
  74. x: Input features (single tensor or list of tensors).
  75. Returns:
  76. New features to be concatenated.
  77. """
  78. if isinstance(x, torch.Tensor):
  79. prev_features = [x]
  80. else:
  81. prev_features = x
  82. if self.grad_checkpointing and self.any_requires_grad(prev_features):
  83. if torch.jit.is_scripting():
  84. raise Exception("Memory Efficient not supported in JIT")
  85. bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
  86. else:
  87. bottleneck_output = self.bottleneck_fn(prev_features)
  88. new_features = self.conv2(self.norm2(bottleneck_output))
  89. if self.drop_rate > 0:
  90. new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
  91. return new_features
  92. class DenseBlock(nn.ModuleDict):
  93. """DenseNet Block.
  94. Contains multiple dense layers with concatenated features.
  95. """
  96. _version = 2
  97. def __init__(
  98. self,
  99. num_layers: int,
  100. num_input_features: int,
  101. bn_size: int,
  102. growth_rate: int,
  103. norm_layer: Type[nn.Module] = BatchNormAct2d,
  104. drop_rate: float = 0.,
  105. grad_checkpointing: bool = False,
  106. device=None,
  107. dtype=None,
  108. ) -> None:
  109. """Initialize DenseBlock.
  110. Args:
  111. num_layers: Number of layers in the block.
  112. num_input_features: Number of input features.
  113. bn_size: Bottleneck size multiplier.
  114. growth_rate: Growth rate (k) for each layer.
  115. norm_layer: Normalization layer class.
  116. drop_rate: Dropout rate.
  117. grad_checkpointing: Use gradient checkpointing.
  118. """
  119. dd = {'device': device, 'dtype': dtype}
  120. super().__init__()
  121. for i in range(num_layers):
  122. layer = DenseLayer(
  123. num_input_features + i * growth_rate,
  124. growth_rate=growth_rate,
  125. bn_size=bn_size,
  126. norm_layer=norm_layer,
  127. drop_rate=drop_rate,
  128. grad_checkpointing=grad_checkpointing,
  129. **dd,
  130. )
  131. self.add_module('denselayer%d' % (i + 1), layer)
  132. def forward(self, init_features: torch.Tensor) -> torch.Tensor:
  133. """Forward pass through all layers in the block.
  134. Args:
  135. init_features: Initial features from previous layer.
  136. Returns:
  137. Concatenated features from all layers.
  138. """
  139. features = [init_features]
  140. for name, layer in self.items():
  141. new_features = layer(features)
  142. features.append(new_features)
  143. return torch.cat(features, 1)
  144. class DenseTransition(nn.Sequential):
  145. """Transition layer between DenseNet blocks.
  146. Reduces feature dimensions and spatial resolution.
  147. """
  148. def __init__(
  149. self,
  150. num_input_features: int,
  151. num_output_features: int,
  152. norm_layer: Type[nn.Module] = BatchNormAct2d,
  153. aa_layer: Optional[Type[nn.Module]] = None,
  154. device=None,
  155. dtype=None,
  156. ) -> None:
  157. """Initialize DenseTransition.
  158. Args:
  159. num_input_features: Number of input features.
  160. num_output_features: Number of output features.
  161. norm_layer: Normalization layer class.
  162. aa_layer: Anti-aliasing layer class.
  163. """
  164. dd = {'device': device, 'dtype': dtype}
  165. super().__init__()
  166. self.add_module('norm', norm_layer(num_input_features, **dd))
  167. self.add_module('conv', nn.Conv2d(
  168. num_input_features, num_output_features, kernel_size=1, stride=1, bias=False, **dd))
  169. if aa_layer is not None:
  170. self.add_module('pool', aa_layer(num_output_features, stride=2, **dd))
  171. else:
  172. self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
  173. class DenseNet(nn.Module):
  174. """Densenet-BC model class.
  175. Based on `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
  176. Args:
  177. growth_rate: How many filters to add each layer (`k` in paper).
  178. block_config: How many layers in each pooling block.
  179. bn_size: Multiplicative factor for number of bottle neck layers
  180. (i.e. bn_size * k features in the bottleneck layer).
  181. drop_rate: Dropout rate before classifier layer.
  182. proj_drop_rate: Dropout rate after each dense layer.
  183. num_classes: Number of classification classes.
  184. memory_efficient: If True, uses checkpointing. Much more memory efficient,
  185. but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
  186. """
  187. def __init__(
  188. self,
  189. growth_rate: int = 32,
  190. block_config: Tuple[int, ...] = (6, 12, 24, 16),
  191. num_classes: int = 1000,
  192. in_chans: int = 3,
  193. global_pool: str = 'avg',
  194. bn_size: int = 4,
  195. stem_type: str = '',
  196. act_layer: str = 'relu',
  197. norm_layer: str = 'batchnorm2d',
  198. aa_layer: Optional[Type[nn.Module]] = None,
  199. drop_rate: float = 0.,
  200. proj_drop_rate: float = 0.,
  201. memory_efficient: bool = False,
  202. aa_stem_only: bool = True,
  203. device=None,
  204. dtype=None,
  205. ) -> None:
  206. """Initialize DenseNet.
  207. Args:
  208. growth_rate: How many filters to add each layer (k in paper).
  209. block_config: How many layers in each pooling block.
  210. num_classes: Number of classification classes.
  211. in_chans: Number of input channels.
  212. global_pool: Global pooling type.
  213. bn_size: Multiplicative factor for number of bottle neck layers.
  214. stem_type: Type of stem ('', 'deep', 'deep_tiered').
  215. act_layer: Activation layer.
  216. norm_layer: Normalization layer.
  217. aa_layer: Anti-aliasing layer.
  218. drop_rate: Dropout rate before classifier layer.
  219. proj_drop_rate: Dropout rate after each dense layer.
  220. memory_efficient: If True, uses checkpointing for memory efficiency.
  221. aa_stem_only: Apply anti-aliasing only to stem.
  222. """
  223. dd = {'device': device, 'dtype': dtype}
  224. self.num_classes = num_classes
  225. self.in_chans = in_chans
  226. super().__init__()
  227. norm_layer = get_norm_act_layer(norm_layer, act_layer=act_layer)
  228. # Stem
  229. deep_stem = 'deep' in stem_type # 3x3 deep stem
  230. num_init_features = growth_rate * 2
  231. if aa_layer is None:
  232. stem_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  233. else:
  234. stem_pool = nn.Sequential(*[
  235. nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
  236. aa_layer(channels=num_init_features, stride=2, **dd)])
  237. if deep_stem:
  238. stem_chs_1 = stem_chs_2 = growth_rate
  239. if 'tiered' in stem_type:
  240. stem_chs_1 = 3 * (growth_rate // 4)
  241. stem_chs_2 = num_init_features if 'narrow' in stem_type else 6 * (growth_rate // 4)
  242. self.features = nn.Sequential(OrderedDict([
  243. ('conv0', nn.Conv2d(in_chans, stem_chs_1, 3, stride=2, padding=1, bias=False, **dd)),
  244. ('norm0', norm_layer(stem_chs_1, **dd)),
  245. ('conv1', nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False, **dd)),
  246. ('norm1', norm_layer(stem_chs_2, **dd)),
  247. ('conv2', nn.Conv2d(stem_chs_2, num_init_features, 3, stride=1, padding=1, bias=False, **dd)),
  248. ('norm2', norm_layer(num_init_features, **dd)),
  249. ('pool0', stem_pool),
  250. ]))
  251. else:
  252. self.features = nn.Sequential(OrderedDict([
  253. ('conv0', nn.Conv2d(in_chans, num_init_features, kernel_size=7, stride=2, padding=3, bias=False, **dd)),
  254. ('norm0', norm_layer(num_init_features, **dd)),
  255. ('pool0', stem_pool),
  256. ]))
  257. self.feature_info = [
  258. dict(num_chs=num_init_features, reduction=2, module=f'features.norm{2 if deep_stem else 0}')]
  259. current_stride = 4
  260. # DenseBlocks
  261. num_features = num_init_features
  262. for i, num_layers in enumerate(block_config):
  263. block = DenseBlock(
  264. num_layers=num_layers,
  265. num_input_features=num_features,
  266. bn_size=bn_size,
  267. growth_rate=growth_rate,
  268. norm_layer=norm_layer,
  269. drop_rate=proj_drop_rate,
  270. grad_checkpointing=memory_efficient,
  271. **dd,
  272. )
  273. module_name = f'denseblock{(i + 1)}'
  274. self.features.add_module(module_name, block)
  275. num_features = num_features + num_layers * growth_rate
  276. transition_aa_layer = None if aa_stem_only else aa_layer
  277. if i != len(block_config) - 1:
  278. self.feature_info += [
  279. dict(num_chs=num_features, reduction=current_stride, module='features.' + module_name)]
  280. current_stride *= 2
  281. trans = DenseTransition(
  282. num_input_features=num_features,
  283. num_output_features=num_features // 2,
  284. norm_layer=norm_layer,
  285. aa_layer=transition_aa_layer,
  286. **dd,
  287. )
  288. self.features.add_module(f'transition{i + 1}', trans)
  289. num_features = num_features // 2
  290. # Final batch norm
  291. self.features.add_module('norm5', norm_layer(num_features, **dd))
  292. self.feature_info += [dict(num_chs=num_features, reduction=current_stride, module='features.norm5')]
  293. self.num_features = self.head_hidden_size = num_features
  294. # Linear layer
  295. global_pool, classifier = create_classifier(
  296. self.num_features,
  297. self.num_classes,
  298. pool_type=global_pool,
  299. **dd,
  300. )
  301. self.global_pool = global_pool
  302. self.head_drop = nn.Dropout(drop_rate)
  303. self.classifier = classifier
  304. # Official init from torch repo.
  305. for m in self.modules():
  306. if isinstance(m, nn.Conv2d):
  307. nn.init.kaiming_normal_(m.weight)
  308. elif isinstance(m, nn.BatchNorm2d):
  309. nn.init.constant_(m.weight, 1)
  310. nn.init.constant_(m.bias, 0)
  311. elif isinstance(m, nn.Linear):
  312. nn.init.constant_(m.bias, 0)
  313. @torch.jit.ignore
  314. def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
  315. """Group parameters for optimization."""
  316. matcher = dict(
  317. stem=r'^features\.conv[012]|features\.norm[012]|features\.pool[012]',
  318. blocks=r'^features\.(?:denseblock|transition)(\d+)' if coarse else [
  319. (r'^features\.denseblock(\d+)\.denselayer(\d+)', None),
  320. (r'^features\.transition(\d+)', MATCH_PREV_GROUP) # FIXME combine with previous denselayer
  321. ]
  322. )
  323. return matcher
  324. @torch.jit.ignore
  325. def set_grad_checkpointing(self, enable: bool = True) -> None:
  326. """Enable or disable gradient checkpointing."""
  327. for b in self.features.modules():
  328. if isinstance(b, DenseLayer):
  329. b.grad_checkpointing = enable
  330. @torch.jit.ignore
  331. def get_classifier(self) -> nn.Module:
  332. """Get the classifier head."""
  333. return self.classifier
  334. def reset_classifier(self, num_classes: int, global_pool: str = 'avg') -> None:
  335. """Reset the classifier head.
  336. Args:
  337. num_classes: Number of classes for new classifier.
  338. global_pool: Global pooling type.
  339. """
  340. self.num_classes = num_classes
  341. self.global_pool, self.classifier = create_classifier(
  342. self.num_features, self.num_classes, pool_type=global_pool)
  343. def forward_features(self, x: torch.Tensor) -> torch.Tensor:
  344. """Forward pass through feature extraction layers."""
  345. return self.features(x)
  346. def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
  347. """Forward pass through classifier head.
  348. Args:
  349. x: Feature tensor.
  350. pre_logits: Return features before final classifier.
  351. Returns:
  352. Output tensor.
  353. """
  354. x = self.global_pool(x)
  355. x = self.head_drop(x)
  356. return x if pre_logits else self.classifier(x)
  357. def forward(self, x: torch.Tensor) -> torch.Tensor:
  358. """Forward pass.
  359. Args:
  360. x: Input tensor.
  361. Returns:
  362. Output logits.
  363. """
  364. x = self.forward_features(x)
  365. x = self.forward_head(x)
  366. return x
  367. def _filter_torchvision_pretrained(state_dict: dict) -> Dict[str, torch.Tensor]:
  368. """Filter torchvision pretrained state dict for compatibility.
  369. Args:
  370. state_dict: State dictionary from torchvision checkpoint.
  371. Returns:
  372. Filtered state dictionary.
  373. """
  374. pattern = re.compile(
  375. r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
  376. for key in list(state_dict.keys()):
  377. res = pattern.match(key)
  378. if res:
  379. new_key = res.group(1) + res.group(2)
  380. state_dict[new_key] = state_dict[key]
  381. del state_dict[key]
  382. return state_dict
  383. def _create_densenet(
  384. variant: str,
  385. growth_rate: int,
  386. block_config: Tuple[int, ...],
  387. pretrained: bool,
  388. **kwargs,
  389. ) -> DenseNet:
  390. """Create a DenseNet model.
  391. Args:
  392. variant: Model variant name.
  393. growth_rate: Growth rate parameter.
  394. block_config: Block configuration.
  395. pretrained: Load pretrained weights.
  396. **kwargs: Additional model arguments.
  397. Returns:
  398. DenseNet model instance.
  399. """
  400. kwargs['growth_rate'] = growth_rate
  401. kwargs['block_config'] = block_config
  402. return build_model_with_cfg(
  403. DenseNet,
  404. variant,
  405. pretrained,
  406. feature_cfg=dict(flatten_sequential=True),
  407. pretrained_filter_fn=_filter_torchvision_pretrained,
  408. **kwargs,
  409. )
  410. def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
  411. """Create default configuration for DenseNet models."""
  412. return {
  413. 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  414. 'crop_pct': 0.875, 'interpolation': 'bicubic',
  415. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  416. 'first_conv': 'features.conv0', 'classifier': 'classifier', 'license': 'apache-2.0',
  417. **kwargs,
  418. }
  419. default_cfgs = generate_default_cfgs({
  420. 'densenet121.ra_in1k': _cfg(
  421. hf_hub_id='timm/',
  422. test_input_size=(3, 288, 288), test_crop_pct=0.95),
  423. 'densenetblur121d.ra_in1k': _cfg(
  424. hf_hub_id='timm/',
  425. test_input_size=(3, 288, 288), test_crop_pct=0.95),
  426. 'densenet264d.untrained': _cfg(),
  427. 'densenet121.tv_in1k': _cfg(hf_hub_id='timm/'),
  428. 'densenet169.tv_in1k': _cfg(hf_hub_id='timm/'),
  429. 'densenet201.tv_in1k': _cfg(hf_hub_id='timm/'),
  430. 'densenet161.tv_in1k': _cfg(hf_hub_id='timm/'),
  431. })
  432. @register_model
  433. def densenet121(pretrained=False, **kwargs) -> DenseNet:
  434. r"""Densenet-121 model from
  435. `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
  436. """
  437. model_args = dict(growth_rate=32, block_config=(6, 12, 24, 16))
  438. model = _create_densenet('densenet121', pretrained=pretrained, **dict(model_args, **kwargs))
  439. return model
  440. @register_model
  441. def densenetblur121d(pretrained=False, **kwargs) -> DenseNet:
  442. r"""Densenet-121 w/ blur-pooling & 3-layer 3x3 stem
  443. `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
  444. """
  445. model_args = dict(growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep', aa_layer=BlurPool2d)
  446. model = _create_densenet('densenetblur121d', pretrained=pretrained, **dict(model_args, **kwargs))
  447. return model
  448. @register_model
  449. def densenet169(pretrained=False, **kwargs) -> DenseNet:
  450. r"""Densenet-169 model from
  451. `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
  452. """
  453. model_args = dict(growth_rate=32, block_config=(6, 12, 32, 32))
  454. model = _create_densenet('densenet169', pretrained=pretrained, **dict(model_args, **kwargs))
  455. return model
  456. @register_model
  457. def densenet201(pretrained=False, **kwargs) -> DenseNet:
  458. r"""Densenet-201 model from
  459. `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
  460. """
  461. model_args = dict(growth_rate=32, block_config=(6, 12, 48, 32))
  462. model = _create_densenet('densenet201', pretrained=pretrained, **dict(model_args, **kwargs))
  463. return model
  464. @register_model
  465. def densenet161(pretrained=False, **kwargs) -> DenseNet:
  466. r"""Densenet-161 model from
  467. `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
  468. """
  469. model_args = dict(growth_rate=48, block_config=(6, 12, 36, 24))
  470. model = _create_densenet('densenet161', pretrained=pretrained, **dict(model_args, **kwargs))
  471. return model
  472. @register_model
  473. def densenet264d(pretrained=False, **kwargs) -> DenseNet:
  474. r"""Densenet-264 model from
  475. `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
  476. """
  477. model_args = dict(growth_rate=48, block_config=(6, 12, 64, 48), stem_type='deep')
  478. model = _create_densenet('densenet264d', pretrained=pretrained, **dict(model_args, **kwargs))
  479. return model
  480. register_model_deprecations(__name__, {
  481. 'tv_densenet121': 'densenet121.tv_in1k',
  482. })