mobilenetv3.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. from collections.abc import Sequence
  2. from functools import partial
  3. from typing import Any, Callable, Optional
  4. import torch
  5. from torch import nn, Tensor
  6. from ..ops.misc import Conv2dNormActivation, SqueezeExcitation as SElayer
  7. from ..transforms._presets import ImageClassification
  8. from ..utils import _log_api_usage_once
  9. from ._api import register_model, Weights, WeightsEnum
  10. from ._meta import _IMAGENET_CATEGORIES
  11. from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface
  12. __all__ = [
  13. "MobileNetV3",
  14. "MobileNet_V3_Large_Weights",
  15. "MobileNet_V3_Small_Weights",
  16. "mobilenet_v3_large",
  17. "mobilenet_v3_small",
  18. ]
  19. class InvertedResidualConfig:
  20. # Stores information listed at Tables 1 and 2 of the MobileNetV3 paper
  21. def __init__(
  22. self,
  23. input_channels: int,
  24. kernel: int,
  25. expanded_channels: int,
  26. out_channels: int,
  27. use_se: bool,
  28. activation: str,
  29. stride: int,
  30. dilation: int,
  31. width_mult: float,
  32. ):
  33. self.input_channels = self.adjust_channels(input_channels, width_mult)
  34. self.kernel = kernel
  35. self.expanded_channels = self.adjust_channels(expanded_channels, width_mult)
  36. self.out_channels = self.adjust_channels(out_channels, width_mult)
  37. self.use_se = use_se
  38. self.use_hs = activation == "HS"
  39. self.stride = stride
  40. self.dilation = dilation
  41. @staticmethod
  42. def adjust_channels(channels: int, width_mult: float):
  43. return _make_divisible(channels * width_mult, 8)
  44. class InvertedResidual(nn.Module):
  45. # Implemented as described at section 5 of MobileNetV3 paper
  46. def __init__(
  47. self,
  48. cnf: InvertedResidualConfig,
  49. norm_layer: Callable[..., nn.Module],
  50. se_layer: Callable[..., nn.Module] = partial(SElayer, scale_activation=nn.Hardsigmoid),
  51. ):
  52. super().__init__()
  53. if not (1 <= cnf.stride <= 2):
  54. raise ValueError("illegal stride value")
  55. self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels
  56. layers: list[nn.Module] = []
  57. activation_layer = nn.Hardswish if cnf.use_hs else nn.ReLU
  58. # expand
  59. if cnf.expanded_channels != cnf.input_channels:
  60. layers.append(
  61. Conv2dNormActivation(
  62. cnf.input_channels,
  63. cnf.expanded_channels,
  64. kernel_size=1,
  65. norm_layer=norm_layer,
  66. activation_layer=activation_layer,
  67. )
  68. )
  69. # depthwise
  70. stride = 1 if cnf.dilation > 1 else cnf.stride
  71. layers.append(
  72. Conv2dNormActivation(
  73. cnf.expanded_channels,
  74. cnf.expanded_channels,
  75. kernel_size=cnf.kernel,
  76. stride=stride,
  77. dilation=cnf.dilation,
  78. groups=cnf.expanded_channels,
  79. norm_layer=norm_layer,
  80. activation_layer=activation_layer,
  81. )
  82. )
  83. if cnf.use_se:
  84. squeeze_channels = _make_divisible(cnf.expanded_channels // 4, 8)
  85. layers.append(se_layer(cnf.expanded_channels, squeeze_channels))
  86. # project
  87. layers.append(
  88. Conv2dNormActivation(
  89. cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
  90. )
  91. )
  92. self.block = nn.Sequential(*layers)
  93. self.out_channels = cnf.out_channels
  94. self._is_cn = cnf.stride > 1
  95. def forward(self, input: Tensor) -> Tensor:
  96. result = self.block(input)
  97. if self.use_res_connect:
  98. result += input
  99. return result
  100. class MobileNetV3(nn.Module):
  101. def __init__(
  102. self,
  103. inverted_residual_setting: list[InvertedResidualConfig],
  104. last_channel: int,
  105. num_classes: int = 1000,
  106. block: Optional[Callable[..., nn.Module]] = None,
  107. norm_layer: Optional[Callable[..., nn.Module]] = None,
  108. dropout: float = 0.2,
  109. **kwargs: Any,
  110. ) -> None:
  111. """
  112. MobileNet V3 main class
  113. Args:
  114. inverted_residual_setting (List[InvertedResidualConfig]): Network structure
  115. last_channel (int): The number of channels on the penultimate layer
  116. num_classes (int): Number of classes
  117. block (Optional[Callable[..., nn.Module]]): Module specifying inverted residual building block for mobilenet
  118. norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use
  119. dropout (float): The droupout probability
  120. """
  121. super().__init__()
  122. _log_api_usage_once(self)
  123. if not inverted_residual_setting:
  124. raise ValueError("The inverted_residual_setting should not be empty")
  125. elif not (
  126. isinstance(inverted_residual_setting, Sequence)
  127. and all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting])
  128. ):
  129. raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]")
  130. if block is None:
  131. block = InvertedResidual
  132. if norm_layer is None:
  133. norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)
  134. layers: list[nn.Module] = []
  135. # building first layer
  136. firstconv_output_channels = inverted_residual_setting[0].input_channels
  137. layers.append(
  138. Conv2dNormActivation(
  139. 3,
  140. firstconv_output_channels,
  141. kernel_size=3,
  142. stride=2,
  143. norm_layer=norm_layer,
  144. activation_layer=nn.Hardswish,
  145. )
  146. )
  147. # building inverted residual blocks
  148. for cnf in inverted_residual_setting:
  149. layers.append(block(cnf, norm_layer))
  150. # building last several layers
  151. lastconv_input_channels = inverted_residual_setting[-1].out_channels
  152. lastconv_output_channels = 6 * lastconv_input_channels
  153. layers.append(
  154. Conv2dNormActivation(
  155. lastconv_input_channels,
  156. lastconv_output_channels,
  157. kernel_size=1,
  158. norm_layer=norm_layer,
  159. activation_layer=nn.Hardswish,
  160. )
  161. )
  162. self.features = nn.Sequential(*layers)
  163. self.avgpool = nn.AdaptiveAvgPool2d(1)
  164. self.classifier = nn.Sequential(
  165. nn.Linear(lastconv_output_channels, last_channel),
  166. nn.Hardswish(inplace=True),
  167. nn.Dropout(p=dropout, inplace=True),
  168. nn.Linear(last_channel, num_classes),
  169. )
  170. for m in self.modules():
  171. if isinstance(m, nn.Conv2d):
  172. nn.init.kaiming_normal_(m.weight, mode="fan_out")
  173. if m.bias is not None:
  174. nn.init.zeros_(m.bias)
  175. elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
  176. nn.init.ones_(m.weight)
  177. nn.init.zeros_(m.bias)
  178. elif isinstance(m, nn.Linear):
  179. nn.init.normal_(m.weight, 0, 0.01)
  180. nn.init.zeros_(m.bias)
  181. def _forward_impl(self, x: Tensor) -> Tensor:
  182. x = self.features(x)
  183. x = self.avgpool(x)
  184. x = torch.flatten(x, 1)
  185. x = self.classifier(x)
  186. return x
  187. def forward(self, x: Tensor) -> Tensor:
  188. return self._forward_impl(x)
  189. def _mobilenet_v3_conf(
  190. arch: str, width_mult: float = 1.0, reduced_tail: bool = False, dilated: bool = False, **kwargs: Any
  191. ):
  192. reduce_divider = 2 if reduced_tail else 1
  193. dilation = 2 if dilated else 1
  194. bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
  195. adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)
  196. if arch == "mobilenet_v3_large":
  197. inverted_residual_setting = [
  198. bneck_conf(16, 3, 16, 16, False, "RE", 1, 1),
  199. bneck_conf(16, 3, 64, 24, False, "RE", 2, 1), # C1
  200. bneck_conf(24, 3, 72, 24, False, "RE", 1, 1),
  201. bneck_conf(24, 5, 72, 40, True, "RE", 2, 1), # C2
  202. bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
  203. bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
  204. bneck_conf(40, 3, 240, 80, False, "HS", 2, 1), # C3
  205. bneck_conf(80, 3, 200, 80, False, "HS", 1, 1),
  206. bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
  207. bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
  208. bneck_conf(80, 3, 480, 112, True, "HS", 1, 1),
  209. bneck_conf(112, 3, 672, 112, True, "HS", 1, 1),
  210. bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2, dilation), # C4
  211. bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
  212. bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
  213. ]
  214. last_channel = adjust_channels(1280 // reduce_divider) # C5
  215. elif arch == "mobilenet_v3_small":
  216. inverted_residual_setting = [
  217. bneck_conf(16, 3, 16, 16, True, "RE", 2, 1), # C1
  218. bneck_conf(16, 3, 72, 24, False, "RE", 2, 1), # C2
  219. bneck_conf(24, 3, 88, 24, False, "RE", 1, 1),
  220. bneck_conf(24, 5, 96, 40, True, "HS", 2, 1), # C3
  221. bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
  222. bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
  223. bneck_conf(40, 5, 120, 48, True, "HS", 1, 1),
  224. bneck_conf(48, 5, 144, 48, True, "HS", 1, 1),
  225. bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2, dilation), # C4
  226. bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
  227. bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
  228. ]
  229. last_channel = adjust_channels(1024 // reduce_divider) # C5
  230. else:
  231. raise ValueError(f"Unsupported model type {arch}")
  232. return inverted_residual_setting, last_channel
  233. def _mobilenet_v3(
  234. inverted_residual_setting: list[InvertedResidualConfig],
  235. last_channel: int,
  236. weights: Optional[WeightsEnum],
  237. progress: bool,
  238. **kwargs: Any,
  239. ) -> MobileNetV3:
  240. if weights is not None:
  241. _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
  242. model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)
  243. if weights is not None:
  244. model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
  245. return model
  246. _COMMON_META = {
  247. "min_size": (1, 1),
  248. "categories": _IMAGENET_CATEGORIES,
  249. }
  250. class MobileNet_V3_Large_Weights(WeightsEnum):
  251. IMAGENET1K_V1 = Weights(
  252. url="https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
  253. transforms=partial(ImageClassification, crop_size=224),
  254. meta={
  255. **_COMMON_META,
  256. "num_params": 5483032,
  257. "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small",
  258. "_metrics": {
  259. "ImageNet-1K": {
  260. "acc@1": 74.042,
  261. "acc@5": 91.340,
  262. }
  263. },
  264. "_ops": 0.217,
  265. "_file_size": 21.114,
  266. "_docs": """These weights were trained from scratch by using a simple training recipe.""",
  267. },
  268. )
  269. IMAGENET1K_V2 = Weights(
  270. url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth",
  271. transforms=partial(ImageClassification, crop_size=224, resize_size=232),
  272. meta={
  273. **_COMMON_META,
  274. "num_params": 5483032,
  275. "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning",
  276. "_metrics": {
  277. "ImageNet-1K": {
  278. "acc@1": 75.274,
  279. "acc@5": 92.566,
  280. }
  281. },
  282. "_ops": 0.217,
  283. "_file_size": 21.107,
  284. "_docs": """
  285. These weights improve marginally upon the results of the original paper by using a modified version of
  286. TorchVision's `new training recipe
  287. <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
  288. """,
  289. },
  290. )
  291. DEFAULT = IMAGENET1K_V2
  292. class MobileNet_V3_Small_Weights(WeightsEnum):
  293. IMAGENET1K_V1 = Weights(
  294. url="https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",
  295. transforms=partial(ImageClassification, crop_size=224),
  296. meta={
  297. **_COMMON_META,
  298. "num_params": 2542856,
  299. "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small",
  300. "_metrics": {
  301. "ImageNet-1K": {
  302. "acc@1": 67.668,
  303. "acc@5": 87.402,
  304. }
  305. },
  306. "_ops": 0.057,
  307. "_file_size": 9.829,
  308. "_docs": """
  309. These weights improve upon the results of the original paper by using a simple training recipe.
  310. """,
  311. },
  312. )
  313. DEFAULT = IMAGENET1K_V1
  314. @register_model()
  315. @handle_legacy_interface(weights=("pretrained", MobileNet_V3_Large_Weights.IMAGENET1K_V1))
  316. def mobilenet_v3_large(
  317. *, weights: Optional[MobileNet_V3_Large_Weights] = None, progress: bool = True, **kwargs: Any
  318. ) -> MobileNetV3:
  319. """
  320. Constructs a large MobileNetV3 architecture from
  321. `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`__.
  322. Args:
  323. weights (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
  324. pretrained weights to use. See
  325. :class:`~torchvision.models.MobileNet_V3_Large_Weights` below for
  326. more details, and possible values. By default, no pre-trained
  327. weights are used.
  328. progress (bool, optional): If True, displays a progress bar of the
  329. download to stderr. Default is True.
  330. **kwargs: parameters passed to the ``torchvision.models.mobilenet.MobileNetV3``
  331. base class. Please refer to the `source code
  332. <https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv3.py>`_
  333. for more details about this class.
  334. .. autoclass:: torchvision.models.MobileNet_V3_Large_Weights
  335. :members:
  336. """
  337. weights = MobileNet_V3_Large_Weights.verify(weights)
  338. inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs)
  339. return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs)
  340. @register_model()
  341. @handle_legacy_interface(weights=("pretrained", MobileNet_V3_Small_Weights.IMAGENET1K_V1))
  342. def mobilenet_v3_small(
  343. *, weights: Optional[MobileNet_V3_Small_Weights] = None, progress: bool = True, **kwargs: Any
  344. ) -> MobileNetV3:
  345. """
  346. Constructs a small MobileNetV3 architecture from
  347. `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`__.
  348. Args:
  349. weights (:class:`~torchvision.models.MobileNet_V3_Small_Weights`, optional): The
  350. pretrained weights to use. See
  351. :class:`~torchvision.models.MobileNet_V3_Small_Weights` below for
  352. more details, and possible values. By default, no pre-trained
  353. weights are used.
  354. progress (bool, optional): If True, displays a progress bar of the
  355. download to stderr. Default is True.
  356. **kwargs: parameters passed to the ``torchvision.models.mobilenet.MobileNetV3``
  357. base class. Please refer to the `source code
  358. <https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv3.py>`_
  359. for more details about this class.
  360. .. autoclass:: torchvision.models.MobileNet_V3_Small_Weights
  361. :members:
  362. """
  363. weights = MobileNet_V3_Small_Weights.verify(weights)
  364. inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_small", **kwargs)
  365. return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs)