mobilenetv2.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. from functools import partial
  2. from typing import Any, Callable, Optional
  3. import torch
  4. from torch import nn, Tensor
  5. from ..ops.misc import Conv2dNormActivation
  6. from ..transforms._presets import ImageClassification
  7. from ..utils import _log_api_usage_once
  8. from ._api import register_model, Weights, WeightsEnum
  9. from ._meta import _IMAGENET_CATEGORIES
  10. from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface
  11. __all__ = ["MobileNetV2", "MobileNet_V2_Weights", "mobilenet_v2"]
  12. # necessary for backwards compatibility
  13. class InvertedResidual(nn.Module):
  14. def __init__(
  15. self, inp: int, oup: int, stride: int, expand_ratio: int, norm_layer: Optional[Callable[..., nn.Module]] = None
  16. ) -> None:
  17. super().__init__()
  18. self.stride = stride
  19. if stride not in [1, 2]:
  20. raise ValueError(f"stride should be 1 or 2 instead of {stride}")
  21. if norm_layer is None:
  22. norm_layer = nn.BatchNorm2d
  23. hidden_dim = int(round(inp * expand_ratio))
  24. self.use_res_connect = self.stride == 1 and inp == oup
  25. layers: list[nn.Module] = []
  26. if expand_ratio != 1:
  27. # pw
  28. layers.append(
  29. Conv2dNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6)
  30. )
  31. layers.extend(
  32. [
  33. # dw
  34. Conv2dNormActivation(
  35. hidden_dim,
  36. hidden_dim,
  37. stride=stride,
  38. groups=hidden_dim,
  39. norm_layer=norm_layer,
  40. activation_layer=nn.ReLU6,
  41. ),
  42. # pw-linear
  43. nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
  44. norm_layer(oup),
  45. ]
  46. )
  47. self.conv = nn.Sequential(*layers)
  48. self.out_channels = oup
  49. self._is_cn = stride > 1
  50. def forward(self, x: Tensor) -> Tensor:
  51. if self.use_res_connect:
  52. return x + self.conv(x)
  53. else:
  54. return self.conv(x)
  55. class MobileNetV2(nn.Module):
  56. def __init__(
  57. self,
  58. num_classes: int = 1000,
  59. width_mult: float = 1.0,
  60. inverted_residual_setting: Optional[list[list[int]]] = None,
  61. round_nearest: int = 8,
  62. block: Optional[Callable[..., nn.Module]] = None,
  63. norm_layer: Optional[Callable[..., nn.Module]] = None,
  64. dropout: float = 0.2,
  65. ) -> None:
  66. """
  67. MobileNet V2 main class
  68. Args:
  69. num_classes (int): Number of classes
  70. width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
  71. inverted_residual_setting: Network structure
  72. round_nearest (int): Round the number of channels in each layer to be a multiple of this number
  73. Set to 1 to turn off rounding
  74. block: Module specifying inverted residual building block for mobilenet
  75. norm_layer: Module specifying the normalization layer to use
  76. dropout (float): The droupout probability
  77. """
  78. super().__init__()
  79. _log_api_usage_once(self)
  80. if block is None:
  81. block = InvertedResidual
  82. if norm_layer is None:
  83. norm_layer = nn.BatchNorm2d
  84. input_channel = 32
  85. last_channel = 1280
  86. if inverted_residual_setting is None:
  87. inverted_residual_setting = [
  88. # t, c, n, s
  89. [1, 16, 1, 1],
  90. [6, 24, 2, 2],
  91. [6, 32, 3, 2],
  92. [6, 64, 4, 2],
  93. [6, 96, 3, 1],
  94. [6, 160, 3, 2],
  95. [6, 320, 1, 1],
  96. ]
  97. # only check the first element, assuming user knows t,c,n,s are required
  98. if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
  99. raise ValueError(
  100. f"inverted_residual_setting should be non-empty or a 4-element list, got {inverted_residual_setting}"
  101. )
  102. # building first layer
  103. input_channel = _make_divisible(input_channel * width_mult, round_nearest)
  104. self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
  105. features: list[nn.Module] = [
  106. Conv2dNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, activation_layer=nn.ReLU6)
  107. ]
  108. # building inverted residual blocks
  109. for t, c, n, s in inverted_residual_setting:
  110. output_channel = _make_divisible(c * width_mult, round_nearest)
  111. for i in range(n):
  112. stride = s if i == 0 else 1
  113. features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
  114. input_channel = output_channel
  115. # building last several layers
  116. features.append(
  117. Conv2dNormActivation(
  118. input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6
  119. )
  120. )
  121. # make it nn.Sequential
  122. self.features = nn.Sequential(*features)
  123. # building classifier
  124. self.classifier = nn.Sequential(
  125. nn.Dropout(p=dropout),
  126. nn.Linear(self.last_channel, num_classes),
  127. )
  128. # weight initialization
  129. for m in self.modules():
  130. if isinstance(m, nn.Conv2d):
  131. nn.init.kaiming_normal_(m.weight, mode="fan_out")
  132. if m.bias is not None:
  133. nn.init.zeros_(m.bias)
  134. elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
  135. nn.init.ones_(m.weight)
  136. nn.init.zeros_(m.bias)
  137. elif isinstance(m, nn.Linear):
  138. nn.init.normal_(m.weight, 0, 0.01)
  139. nn.init.zeros_(m.bias)
  140. def _forward_impl(self, x: Tensor) -> Tensor:
  141. # This exists since TorchScript doesn't support inheritance, so the superclass method
  142. # (this one) needs to have a name other than `forward` that can be accessed in a subclass
  143. x = self.features(x)
  144. # Cannot use "squeeze" as batch-size can be 1
  145. x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
  146. x = torch.flatten(x, 1)
  147. x = self.classifier(x)
  148. return x
  149. def forward(self, x: Tensor) -> Tensor:
  150. return self._forward_impl(x)
  151. _COMMON_META = {
  152. "num_params": 3504872,
  153. "min_size": (1, 1),
  154. "categories": _IMAGENET_CATEGORIES,
  155. }
class MobileNet_V2_Weights(WeightsEnum):
    # Pretrained ImageNet-1K checkpoints for MobileNetV2. Each member pairs a checkpoint URL with
    # its preprocessing transform and a metadata dict made of _COMMON_META plus per-checkpoint
    # fields (training recipe link, ImageNet-1K acc@1/acc@5, "_ops", "_file_size", "_docs" text).
    # NOTE(review): units of "_ops" / "_file_size" are not stated here (presumably GFLOPs and MB) — confirm.

    # Checkpoint that reproduces the paper's results with a simple training recipe (see "_docs").
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 71.878,
                    "acc@5": 90.286,
                }
            },
            "_ops": 0.301,
            "_file_size": 13.555,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    # Checkpoint trained with the newer recipe. Unlike V1, its transform also passes
    # resize_size=232 (V1 leaves resize_size at ImageClassification's default).
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 72.154,
                    "acc@5": 90.822,
                }
            },
            "_ops": 0.301,
            "_file_size": 13.598,
            "_docs": """
These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
""",
        },
    )
    # DEFAULT is bound to the same Weights value as IMAGENET1K_V2, i.e. V2 is the default checkpoint.
    DEFAULT = IMAGENET1K_V2
  196. @register_model()
  197. @handle_legacy_interface(weights=("pretrained", MobileNet_V2_Weights.IMAGENET1K_V1))
  198. def mobilenet_v2(
  199. *, weights: Optional[MobileNet_V2_Weights] = None, progress: bool = True, **kwargs: Any
  200. ) -> MobileNetV2:
  201. """MobileNetV2 architecture from the `MobileNetV2: Inverted Residuals and Linear
  202. Bottlenecks <https://arxiv.org/abs/1801.04381>`_ paper.
  203. Args:
  204. weights (:class:`~torchvision.models.MobileNet_V2_Weights`, optional): The
  205. pretrained weights to use. See
  206. :class:`~torchvision.models.MobileNet_V2_Weights` below for
  207. more details, and possible values. By default, no pre-trained
  208. weights are used.
  209. progress (bool, optional): If True, displays a progress bar of the
  210. download to stderr. Default is True.
  211. **kwargs: parameters passed to the ``torchvision.models.mobilenetv2.MobileNetV2``
  212. base class. Please refer to the `source code
  213. <https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv2.py>`_
  214. for more details about this class.
  215. .. autoclass:: torchvision.models.MobileNet_V2_Weights
  216. :members:
  217. """
  218. weights = MobileNet_V2_Weights.verify(weights)
  219. if weights is not None:
  220. _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
  221. model = MobileNetV2(**kwargs)
  222. if weights is not None:
  223. model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
  224. return model