vgg.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426
  1. """VGG
  2. Adapted from https://github.com/pytorch/vision 'vgg.py' (BSD-3-Clause) with a few changes for
  3. timm functionality.
  4. Copyright 2021 Ross Wightman
  5. """
  6. from typing import Any, Dict, List, Optional, Type, Union, cast
  7. import torch
  8. import torch.nn as nn
  9. import torch.nn.functional as F
  10. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  11. from timm.layers import ClassifierHead
  12. from ._builder import build_model_with_cfg
  13. from ._features_fx import register_notrace_module
  14. from ._registry import register_model, generate_default_cfgs
  15. __all__ = ['VGG']
  16. cfgs: Dict[str, List[Union[str, int]]] = {
  17. 'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
  18. 'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
  19. 'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
  20. 'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
  21. }
  22. @register_notrace_module # reason: FX can't symbolically trace control flow in forward method
  23. class ConvMlp(nn.Module):
  24. """Convolutional MLP block for VGG head.
  25. Replaces traditional Linear layers with Conv2d layers in the classifier.
  26. """
  27. def __init__(
  28. self,
  29. in_features: int = 512,
  30. out_features: int = 4096,
  31. kernel_size: int = 7,
  32. mlp_ratio: float = 1.0,
  33. drop_rate: float = 0.2,
  34. act_layer: Type[nn.Module] = nn.ReLU,
  35. conv_layer: Type[nn.Module] = nn.Conv2d,
  36. device=None,
  37. dtype=None,
  38. ) -> None:
  39. """Initialize ConvMlp.
  40. Args:
  41. in_features: Number of input features.
  42. out_features: Number of output features.
  43. kernel_size: Kernel size for first conv layer.
  44. mlp_ratio: Ratio for hidden layer size.
  45. drop_rate: Dropout rate.
  46. act_layer: Activation layer type.
  47. conv_layer: Convolution layer type.
  48. """
  49. dd = {'device': device, 'dtype': dtype}
  50. super().__init__()
  51. self.input_kernel_size = kernel_size
  52. mid_features = int(out_features * mlp_ratio)
  53. self.fc1 = conv_layer(in_features, mid_features, kernel_size, bias=True, **dd)
  54. self.act1 = act_layer(True)
  55. self.drop = nn.Dropout(drop_rate)
  56. self.fc2 = conv_layer(mid_features, out_features, 1, bias=True, **dd)
  57. self.act2 = act_layer(True)
  58. def forward(self, x: torch.Tensor) -> torch.Tensor:
  59. """Forward pass.
  60. Args:
  61. x: Input tensor.
  62. Returns:
  63. Output tensor.
  64. """
  65. if x.shape[-2] < self.input_kernel_size or x.shape[-1] < self.input_kernel_size:
  66. # keep the input size >= 7x7
  67. output_size = (max(self.input_kernel_size, x.shape[-2]), max(self.input_kernel_size, x.shape[-1]))
  68. x = F.adaptive_avg_pool2d(x, output_size)
  69. x = self.fc1(x)
  70. x = self.act1(x)
  71. x = self.drop(x)
  72. x = self.fc2(x)
  73. x = self.act2(x)
  74. return x
  75. class VGG(nn.Module):
  76. """VGG model architecture.
  77. Based on `Very Deep Convolutional Networks for Large-Scale Image Recognition`
  78. - https://arxiv.org/abs/1409.1556
  79. """
  80. def __init__(
  81. self,
  82. cfg: List[Any],
  83. num_classes: int = 1000,
  84. in_chans: int = 3,
  85. output_stride: int = 32,
  86. mlp_ratio: float = 1.0,
  87. act_layer: Type[nn.Module] = nn.ReLU,
  88. conv_layer: Type[nn.Module] = nn.Conv2d,
  89. norm_layer: Optional[Type[nn.Module]] = None,
  90. global_pool: str = 'avg',
  91. drop_rate: float = 0.,
  92. device=None,
  93. dtype=None,
  94. ) -> None:
  95. """Initialize VGG model.
  96. Args:
  97. cfg: Configuration list defining network architecture.
  98. num_classes: Number of classes for classification.
  99. in_chans: Number of input channels.
  100. output_stride: Output stride of network.
  101. mlp_ratio: Ratio for MLP hidden layer size.
  102. act_layer: Activation layer type.
  103. conv_layer: Convolution layer type.
  104. norm_layer: Normalization layer type.
  105. global_pool: Global pooling type.
  106. drop_rate: Dropout rate.
  107. """
  108. super().__init__()
  109. dd = {'device': device, 'dtype': dtype}
  110. assert output_stride == 32
  111. self.num_classes = num_classes
  112. self.in_chans = in_chans
  113. self.drop_rate = drop_rate
  114. self.grad_checkpointing = False
  115. self.use_norm = norm_layer is not None
  116. self.feature_info = []
  117. prev_chs = in_chans
  118. net_stride = 1
  119. pool_layer = nn.MaxPool2d
  120. layers: List[nn.Module] = []
  121. for v in cfg:
  122. last_idx = len(layers) - 1
  123. if v == 'M':
  124. self.feature_info.append(dict(num_chs=prev_chs, reduction=net_stride, module=f'features.{last_idx}'))
  125. layers += [pool_layer(kernel_size=2, stride=2)]
  126. net_stride *= 2
  127. else:
  128. v = cast(int, v)
  129. conv2d = conv_layer(prev_chs, v, kernel_size=3, padding=1, **dd)
  130. if norm_layer is not None:
  131. layers += [conv2d, norm_layer(v, **dd), act_layer(inplace=True)]
  132. else:
  133. layers += [conv2d, act_layer(inplace=True)]
  134. prev_chs = v
  135. self.features = nn.Sequential(*layers)
  136. self.feature_info.append(dict(num_chs=prev_chs, reduction=net_stride, module=f'features.{len(layers) - 1}'))
  137. self.num_features = prev_chs
  138. self.head_hidden_size = 4096
  139. self.pre_logits = ConvMlp(
  140. prev_chs,
  141. self.head_hidden_size,
  142. 7,
  143. mlp_ratio=mlp_ratio,
  144. drop_rate=drop_rate,
  145. act_layer=act_layer,
  146. conv_layer=conv_layer,
  147. **dd,
  148. )
  149. self.head = ClassifierHead(
  150. self.head_hidden_size,
  151. num_classes,
  152. pool_type=global_pool,
  153. drop_rate=drop_rate,
  154. **dd,
  155. )
  156. self._initialize_weights()
  157. @torch.jit.ignore
  158. def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
  159. """Group matcher for parameter groups.
  160. Args:
  161. coarse: Whether to use coarse grouping.
  162. Returns:
  163. Dictionary of grouped parameters.
  164. """
  165. # this treats BN layers as separate groups for bn variants, a lot of effort to fix that
  166. return dict(stem=r'^features\.0', blocks=r'^features\.(\d+)')
  167. @torch.jit.ignore
  168. def set_grad_checkpointing(self, enable: bool = True) -> None:
  169. """Enable or disable gradient checkpointing.
  170. Args:
  171. enable: Whether to enable gradient checkpointing.
  172. """
  173. assert not enable, 'gradient checkpointing not supported'
  174. @torch.jit.ignore
  175. def get_classifier(self) -> nn.Module:
  176. """Get the classifier module.
  177. Returns:
  178. Classifier module.
  179. """
  180. return self.head.fc
  181. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
  182. """Reset the classifier.
  183. Args:
  184. num_classes: Number of classes for new classifier.
  185. global_pool: Global pooling type.
  186. """
  187. self.num_classes = num_classes
  188. self.head.reset(num_classes, global_pool)
  189. def forward_features(self, x: torch.Tensor) -> torch.Tensor:
  190. """Forward pass through feature extraction layers.
  191. Args:
  192. x: Input tensor.
  193. Returns:
  194. Feature tensor.
  195. """
  196. x = self.features(x)
  197. return x
  198. def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
  199. """Forward pass through head.
  200. Args:
  201. x: Input features.
  202. pre_logits: Return features before final linear layer.
  203. Returns:
  204. Classification logits or features.
  205. """
  206. x = self.pre_logits(x)
  207. return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
  208. def forward(self, x: torch.Tensor) -> torch.Tensor:
  209. """Forward pass.
  210. Args:
  211. x: Input tensor.
  212. Returns:
  213. Output logits.
  214. """
  215. x = self.forward_features(x)
  216. x = self.forward_head(x)
  217. return x
  218. def _initialize_weights(self) -> None:
  219. """Initialize model weights."""
  220. for m in self.modules():
  221. if isinstance(m, nn.Conv2d):
  222. nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  223. if m.bias is not None:
  224. nn.init.constant_(m.bias, 0)
  225. elif isinstance(m, nn.BatchNorm2d):
  226. nn.init.constant_(m.weight, 1)
  227. nn.init.constant_(m.bias, 0)
  228. elif isinstance(m, nn.Linear):
  229. nn.init.normal_(m.weight, 0, 0.01)
  230. nn.init.constant_(m.bias, 0)
  231. def _filter_fn(state_dict: dict) -> Dict[str, torch.Tensor]:
  232. """Convert patch embedding weight from manual patchify + linear proj to conv.
  233. Args:
  234. state_dict: State dictionary to filter.
  235. Returns:
  236. Filtered state dictionary.
  237. """
  238. out_dict = {}
  239. for k, v in state_dict.items():
  240. k_r = k
  241. k_r = k_r.replace('classifier.0', 'pre_logits.fc1')
  242. k_r = k_r.replace('classifier.3', 'pre_logits.fc2')
  243. k_r = k_r.replace('classifier.6', 'head.fc')
  244. if 'classifier.0.weight' in k:
  245. v = v.reshape(-1, 512, 7, 7)
  246. if 'classifier.3.weight' in k:
  247. v = v.reshape(-1, 4096, 1, 1)
  248. out_dict[k_r] = v
  249. return out_dict
  250. def _create_vgg(variant: str, pretrained: bool, **kwargs: Any) -> VGG:
  251. """Create a VGG model.
  252. Args:
  253. variant: Model variant name.
  254. pretrained: Load pretrained weights.
  255. **kwargs: Additional model arguments.
  256. Returns:
  257. VGG model instance.
  258. """
  259. cfg = variant.split('_')[0]
  260. # NOTE: VGG is one of few models with stride==1 features w/ 6 out_indices [0..5]
  261. out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4, 5))
  262. model = build_model_with_cfg(
  263. VGG,
  264. variant,
  265. pretrained,
  266. model_cfg=cfgs[cfg],
  267. feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
  268. pretrained_filter_fn=_filter_fn,
  269. **kwargs,
  270. )
  271. return model
  272. def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
  273. """Create default configuration dictionary.
  274. Args:
  275. url: Model weight URL.
  276. **kwargs: Additional configuration options.
  277. Returns:
  278. Configuration dictionary.
  279. """
  280. return {
  281. 'url': url,
  282. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  283. 'crop_pct': 0.875, 'interpolation': 'bilinear',
  284. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  285. 'first_conv': 'features.0', 'classifier': 'head.fc',
  286. 'license': 'bsd-3-clause',
  287. **kwargs
  288. }
  289. default_cfgs = generate_default_cfgs({
  290. 'vgg11.tv_in1k': _cfg(hf_hub_id='timm/'),
  291. 'vgg13.tv_in1k': _cfg(hf_hub_id='timm/'),
  292. 'vgg16.tv_in1k': _cfg(hf_hub_id='timm/'),
  293. 'vgg19.tv_in1k': _cfg(hf_hub_id='timm/'),
  294. 'vgg11_bn.tv_in1k': _cfg(hf_hub_id='timm/'),
  295. 'vgg13_bn.tv_in1k': _cfg(hf_hub_id='timm/'),
  296. 'vgg16_bn.tv_in1k': _cfg(hf_hub_id='timm/'),
  297. 'vgg19_bn.tv_in1k': _cfg(hf_hub_id='timm/'),
  298. })
  299. @register_model
  300. def vgg11(pretrained: bool = False, **kwargs: Any) -> VGG:
  301. r"""VGG 11-layer model (configuration "A") from
  302. `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`._
  303. """
  304. model_args = dict(**kwargs)
  305. return _create_vgg('vgg11', pretrained=pretrained, **model_args)
  306. @register_model
  307. def vgg11_bn(pretrained: bool = False, **kwargs: Any) -> VGG:
  308. r"""VGG 11-layer model (configuration "A") with batch normalization
  309. `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`._
  310. """
  311. model_args = dict(norm_layer=nn.BatchNorm2d, **kwargs)
  312. return _create_vgg('vgg11_bn', pretrained=pretrained, **model_args)
  313. @register_model
  314. def vgg13(pretrained: bool = False, **kwargs: Any) -> VGG:
  315. r"""VGG 13-layer model (configuration "B")
  316. `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`._
  317. """
  318. model_args = dict(**kwargs)
  319. return _create_vgg('vgg13', pretrained=pretrained, **model_args)
  320. @register_model
  321. def vgg13_bn(pretrained: bool = False, **kwargs: Any) -> VGG:
  322. r"""VGG 13-layer model (configuration "B") with batch normalization
  323. `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`._
  324. """
  325. model_args = dict(norm_layer=nn.BatchNorm2d, **kwargs)
  326. return _create_vgg('vgg13_bn', pretrained=pretrained, **model_args)
  327. @register_model
  328. def vgg16(pretrained: bool = False, **kwargs: Any) -> VGG:
  329. r"""VGG 16-layer model (configuration "D")
  330. `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`._
  331. """
  332. model_args = dict(**kwargs)
  333. return _create_vgg('vgg16', pretrained=pretrained, **model_args)
  334. @register_model
  335. def vgg16_bn(pretrained: bool = False, **kwargs: Any) -> VGG:
  336. r"""VGG 16-layer model (configuration "D") with batch normalization
  337. `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`._
  338. """
  339. model_args = dict(norm_layer=nn.BatchNorm2d, **kwargs)
  340. return _create_vgg('vgg16_bn', pretrained=pretrained, **model_args)
  341. @register_model
  342. def vgg19(pretrained: bool = False, **kwargs: Any) -> VGG:
  343. r"""VGG 19-layer model (configuration "E")
  344. `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`._
  345. """
  346. model_args = dict(**kwargs)
  347. return _create_vgg('vgg19', pretrained=pretrained, **model_args)
  348. @register_model
  349. def vgg19_bn(pretrained: bool = False, **kwargs: Any) -> VGG:
  350. r"""VGG 19-layer model (configuration 'E') with batch normalization
  351. `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`._
  352. """
  353. model_args = dict(norm_layer=nn.BatchNorm2d, **kwargs)
  354. return _create_vgg('vgg19_bn', pretrained=pretrained, **model_args)