vgg.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import Any, Dict, List, Optional, Union, cast
  18. import torch
  19. from torch import nn
  20. from kornia.core import Module
  21. class VGG(nn.Module):
  22. def __init__(
  23. self, features: Module, num_classes: int = 1000, init_weights: bool = True, dropout: float = 0.5
  24. ) -> None:
  25. super().__init__()
  26. self.features = features
  27. self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
  28. self.classifier = nn.Sequential(
  29. nn.Linear(512 * 7 * 7, 4096),
  30. nn.ReLU(True),
  31. nn.Dropout(p=dropout),
  32. nn.Linear(4096, 4096),
  33. nn.ReLU(True),
  34. nn.Dropout(p=dropout),
  35. nn.Linear(4096, num_classes),
  36. )
  37. if init_weights:
  38. for m in self.modules():
  39. if isinstance(m, nn.Conv2d):
  40. nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
  41. if m.bias is not None:
  42. nn.init.constant_(m.bias, 0)
  43. elif isinstance(m, nn.BatchNorm2d):
  44. nn.init.constant_(m.weight, 1)
  45. nn.init.constant_(m.bias, 0)
  46. elif isinstance(m, nn.Linear):
  47. nn.init.normal_(m.weight, 0, 0.01)
  48. nn.init.constant_(m.bias, 0)
  49. def forward(self, x: torch.Tensor) -> torch.Tensor:
  50. x = self.features(x)
  51. x = self.avgpool(x)
  52. x = torch.flatten(x, 1)
  53. x = self.classifier(x)
  54. return x
  55. def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential:
  56. """Make model layers."""
  57. layers: List[nn.Module] = []
  58. in_channels = 3
  59. for v in cfg:
  60. if v == "M":
  61. layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
  62. else:
  63. v = cast(int, v)
  64. conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
  65. if batch_norm:
  66. layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
  67. else:
  68. layers += [conv2d, nn.ReLU(inplace=True)]
  69. in_channels = v
  70. return nn.Sequential(*layers)
  71. cfgs: Dict[str, List[Union[str, int]]] = {
  72. "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
  73. "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
  74. "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],
  75. "E": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
  76. }
  77. def _vgg(cfg: str, batch_norm: bool, weights: Optional[Any] = None, **kwargs: Any) -> VGG:
  78. model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
  79. return model
  80. def vgg11(*, weights: Optional[Any] = None, **kwargs: Any) -> VGG:
  81. """VGG-11 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
  82. Args:
  83. weights (:class:`~torchvision.models.VGG11_Weights`, optional): The
  84. pretrained weights to use. See
  85. :class:`~torchvision.models.VGG11_Weights` below for
  86. more details, and possible values. By default, no pre-trained
  87. weights are used.
  88. progress (bool, optional): If True, displays a progress bar of the
  89. download to stderr. Default is True.
  90. **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
  91. base class. Please refer to the `source code
  92. <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
  93. for more details about this class.
  94. .. autoclass:: torchvision.models.VGG11_Weights
  95. :members:
  96. """
  97. return _vgg("A", False, weights, **kwargs)
  98. def vgg11_bn(*, weights: Optional[Any] = None, **kwargs: Any) -> VGG:
  99. """VGG-11-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
  100. Args:
  101. weights (:class:`~torchvision.models.VGG11_BN_Weights`, optional): The
  102. pretrained weights to use. See
  103. :class:`~torchvision.models.VGG11_BN_Weights` below for
  104. more details, and possible values. By default, no pre-trained
  105. weights are used.
  106. progress (bool, optional): If True, displays a progress bar of the
  107. download to stderr. Default is True.
  108. **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
  109. base class. Please refer to the `source code
  110. <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
  111. for more details about this class.
  112. .. autoclass:: torchvision.models.VGG11_BN_Weights
  113. :members:
  114. """
  115. return _vgg("A", True, weights, **kwargs)
  116. def vgg13(*, weights: Optional[Any] = None, **kwargs: Any) -> VGG:
  117. """VGG-13 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
  118. Args:
  119. weights (:class:`~torchvision.models.VGG13_Weights`, optional): The
  120. pretrained weights to use. See
  121. :class:`~torchvision.models.VGG13_Weights` below for
  122. more details, and possible values. By default, no pre-trained
  123. weights are used.
  124. progress (bool, optional): If True, displays a progress bar of the
  125. download to stderr. Default is True.
  126. **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
  127. base class. Please refer to the `source code
  128. <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
  129. for more details about this class.
  130. .. autoclass:: torchvision.models.VGG13_Weights
  131. :members:
  132. """
  133. return _vgg("B", False, weights, **kwargs)
  134. def vgg13_bn(*, weights: Optional[Any] = None, **kwargs: Any) -> VGG:
  135. """VGG-13-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
  136. Args:
  137. weights (:class:`~torchvision.models.VGG13_BN_Weights`, optional): The
  138. pretrained weights to use. See
  139. :class:`~torchvision.models.VGG13_BN_Weights` below for
  140. more details, and possible values. By default, no pre-trained
  141. weights are used.
  142. progress (bool, optional): If True, displays a progress bar of the
  143. download to stderr. Default is True.
  144. **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
  145. base class. Please refer to the `source code
  146. <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
  147. for more details about this class.
  148. .. autoclass:: torchvision.models.VGG13_BN_Weights
  149. :members:
  150. """
  151. return _vgg("B", True, weights, **kwargs)
  152. def vgg16(*, weights: Optional[Any] = None, **kwargs: Any) -> VGG:
  153. """VGG-16 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
  154. Args:
  155. weights (:class:`~torchvision.models.VGG16_Weights`, optional): The
  156. pretrained weights to use. See
  157. :class:`~torchvision.models.VGG16_Weights` below for
  158. more details, and possible values. By default, no pre-trained
  159. weights are used.
  160. progress (bool, optional): If True, displays a progress bar of the
  161. download to stderr. Default is True.
  162. **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
  163. base class. Please refer to the `source code
  164. <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
  165. for more details about this class.
  166. .. autoclass:: torchvision.models.VGG16_Weights
  167. :members:
  168. """
  169. return _vgg("D", False, weights, **kwargs)
  170. def vgg16_bn(*, weights: Optional[Any] = None, **kwargs: Any) -> VGG:
  171. """VGG-16-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
  172. Args:
  173. weights (:class:`~torchvision.models.VGG16_BN_Weights`, optional): The
  174. pretrained weights to use. See
  175. :class:`~torchvision.models.VGG16_BN_Weights` below for
  176. more details, and possible values. By default, no pre-trained
  177. weights are used.
  178. progress (bool, optional): If True, displays a progress bar of the
  179. download to stderr. Default is True.
  180. **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
  181. base class. Please refer to the `source code
  182. <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
  183. for more details about this class.
  184. .. autoclass:: torchvision.models.VGG16_BN_Weights
  185. :members:
  186. """
  187. return _vgg("D", True, weights, **kwargs)
  188. def vgg19(*, weights: Optional[Any] = None, **kwargs: Any) -> VGG:
  189. """VGG-19 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
  190. Args:
  191. weights (:class:`~torchvision.models.VGG19_Weights`, optional): The
  192. pretrained weights to use. See
  193. :class:`~torchvision.models.VGG19_Weights` below for
  194. more details, and possible values. By default, no pre-trained
  195. weights are used.
  196. progress (bool, optional): If True, displays a progress bar of the
  197. download to stderr. Default is True.
  198. **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
  199. base class. Please refer to the `source code
  200. <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
  201. for more details about this class.
  202. .. autoclass:: torchvision.models.VGG19_Weights
  203. :members:
  204. """
  205. return _vgg("E", False, weights, **kwargs)
  206. def vgg19_bn(*, weights: Optional[Any] = None, **kwargs: Any) -> VGG:
  207. """VGG-19_BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.
  208. Args:
  209. weights (:class:`~torchvision.models.VGG19_BN_Weights`, optional): The
  210. pretrained weights to use. See
  211. :class:`~torchvision.models.VGG19_BN_Weights` below for
  212. more details, and possible values. By default, no pre-trained
  213. weights are used.
  214. progress (bool, optional): If True, displays a progress bar of the
  215. download to stderr. Default is True.
  216. **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
  217. base class. Please refer to the `source code
  218. <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
  219. for more details about this class.
  220. .. autoclass:: torchvision.models.VGG19_BN_Weights
  221. :members:
  222. """
  223. return _vgg("E", True, weights, **kwargs)