convnext.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415
  1. from collections.abc import Sequence
  2. from functools import partial
  3. from typing import Any, Callable, Optional
  4. import torch
  5. from torch import nn, Tensor
  6. from torch.nn import functional as F
  7. from ..ops.misc import Conv2dNormActivation, Permute
  8. from ..ops.stochastic_depth import StochasticDepth
  9. from ..transforms._presets import ImageClassification
  10. from ..utils import _log_api_usage_once
  11. from ._api import register_model, Weights, WeightsEnum
  12. from ._meta import _IMAGENET_CATEGORIES
  13. from ._utils import _ovewrite_named_param, handle_legacy_interface
  # Public API of this module: the names exported by ``from <module> import *``.
  __all__ = [
      # Network class
      "ConvNeXt",
      # Pretrained-weight enums
      "ConvNeXt_Tiny_Weights",
      "ConvNeXt_Small_Weights",
      "ConvNeXt_Base_Weights",
      "ConvNeXt_Large_Weights",
      # Model builders
      "convnext_tiny",
      "convnext_small",
      "convnext_base",
      "convnext_large",
  ]
  class LayerNorm2d(nn.LayerNorm):
      """``nn.LayerNorm`` applied along dimension 1 of a 4-D tensor.

      ``F.layer_norm`` normalizes the trailing ``normalized_shape`` dimensions, so
      the input is permuted with ``(0, 2, 3, 1)`` to move dimension 1 last,
      normalized with this module's ``weight``/``bias``/``eps``, and permuted back
      with ``(0, 3, 1, 2)`` so the output has the input's layout.
      """

      def forward(self, x: Tensor) -> Tensor:
          moved_last = x.permute(0, 2, 3, 1)
          normed = F.layer_norm(moved_last, self.normalized_shape, self.weight, self.bias, self.eps)
          return normed.permute(0, 3, 1, 2)
  class CNBlock(nn.Module):
      """Residual ConvNeXt block.

      Residual branch, in order: 7x7 grouped conv with ``groups=dim``,
      ``Permute([0, 2, 3, 1])``, the norm layer, ``Linear(dim, 4*dim)``, GELU,
      ``Linear(4*dim, dim)``, ``Permute([0, 3, 1, 2])``. The branch output is
      multiplied by a learnable ``(dim, 1, 1)`` scale initialised to
      ``layer_scale``, passed through ``StochasticDepth(..., "row")`` and added to
      the block input.

      Args:
          dim: number of input and output channels.
          layer_scale: initial value of every entry of the learnable scale.
          stochastic_depth_prob: drop probability given to ``StochasticDepth``.
          norm_layer: factory called with ``dim`` to build the norm layer that
              sits between the two ``Permute`` modules. Defaults to
              ``nn.LayerNorm`` with ``eps=1e-6``.
      """

      def __init__(
          self,
          dim: int,
          layer_scale: float,
          stochastic_depth_prob: float,
          norm_layer: Optional[Callable[..., nn.Module]] = None,
      ) -> None:
          super().__init__()
          make_norm = norm_layer if norm_layer is not None else partial(nn.LayerNorm, eps=1e-6)
          expanded = 4 * dim
          # The child order fixes the ``block.<i>`` state_dict keys that checkpoints
          # are loaded against, so it must not change.
          self.block = nn.Sequential(
              nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True),
              Permute([0, 2, 3, 1]),
              make_norm(dim),
              nn.Linear(in_features=dim, out_features=expanded, bias=True),
              nn.GELU(),
              nn.Linear(in_features=expanded, out_features=dim, bias=True),
              Permute([0, 3, 1, 2]),
          )
          self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
          self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")

      def forward(self, input: Tensor) -> Tensor:
          branch = self.stochastic_depth(self.layer_scale * self.block(input))
          branch += input
          return branch
  class CNBlockConfig:
      """Settings for one ConvNeXt stage (the values listed in Section 3 of the paper).

      Args:
          input_channels: channel width of the blocks in the stage.
          out_channels: channel width produced by the downsampling layer placed
              after the stage, or ``None`` when no downsampling layer follows.
          num_layers: number of blocks in the stage.
      """

      def __init__(
          self,
          input_channels: int,
          out_channels: Optional[int],
          num_layers: int,
      ) -> None:
          self.input_channels = input_channels
          self.out_channels = out_channels
          self.num_layers = num_layers

      def __repr__(self) -> str:
          shown = ("input_channels", "out_channels", "num_layers")
          body = ", ".join(f"{attr}={self.__dict__[attr]}" for attr in shown)
          return f"{type(self).__name__}({body})"
  class ConvNeXt(nn.Module):
      """ConvNeXt classification network from "A ConvNet for the 2020s".

      ``self.features`` is built as:

      * a ``Conv2dNormActivation`` stem with 3 input channels, ``kernel_size=4``,
        ``stride=4``, ``padding=0``, ``norm_layer`` and no activation;
      * for each ``CNBlockConfig``: one ``nn.Sequential`` of ``num_layers`` blocks,
        followed, when ``out_channels`` is not ``None``, by
        ``norm_layer`` + ``Conv2d(kernel_size=2, stride=2)`` downsampling.

      ``forward`` then applies ``AdaptiveAvgPool2d(1)`` and ``self.classifier``
      (``norm_layer`` -> ``Flatten(1)`` -> ``Linear``).

      Args:
          block_setting: non-empty sequence of ``CNBlockConfig``, one per stage.
          stochastic_depth_prob: drop probability assigned to the last block; the
              i-th block overall gets ``stochastic_depth_prob * i / (total - 1)``,
              so the first block gets 0.
          layer_scale: initial layer-scale value passed to every block.
          num_classes: output size of the final ``Linear`` layer.
          block: block factory called as ``block(channels, layer_scale, sd_prob)``.
              Defaults to ``CNBlock``.
          norm_layer: norm factory used in the stem, downsampling layers and
              classifier. Defaults to ``LayerNorm2d`` with ``eps=1e-6``.
          **kwargs: accepted and ignored.

      Raises:
          ValueError: if ``block_setting`` is empty.
          TypeError: if ``block_setting`` is not a sequence of ``CNBlockConfig``.
      """

      def __init__(
          self,
          block_setting: list[CNBlockConfig],
          stochastic_depth_prob: float = 0.0,
          layer_scale: float = 1e-6,
          num_classes: int = 1000,
          block: Optional[Callable[..., nn.Module]] = None,
          norm_layer: Optional[Callable[..., nn.Module]] = None,
          **kwargs: Any,
      ) -> None:
          super().__init__()
          _log_api_usage_once(self)

          if not block_setting:
              raise ValueError("The block_setting should not be empty")
          elif not (isinstance(block_setting, Sequence) and all(isinstance(s, CNBlockConfig) for s in block_setting)):
              raise TypeError("The block_setting should be List[CNBlockConfig]")

          if block is None:
              block = CNBlock
          if norm_layer is None:
              norm_layer = partial(LayerNorm2d, eps=1e-6)

          layers: list[nn.Module] = []

          # Stem
          firstconv_output_channels = block_setting[0].input_channels
          layers.append(
              Conv2dNormActivation(
                  3,
                  firstconv_output_channels,
                  kernel_size=4,
                  stride=4,
                  padding=0,
                  norm_layer=norm_layer,
                  activation_layer=None,
                  bias=True,
              )
          )

          total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
          # Stochastic depth grows linearly from 0 (first block) to
          # stochastic_depth_prob (last block). With exactly one block the
          # original denominator is 0.0 and ``0.0 / 0.0`` raises ZeroDivisionError;
          # clamping to 1.0 gives that lone block probability 0 and leaves every
          # multi-block configuration numerically unchanged.
          sd_denominator = max(total_stage_blocks - 1.0, 1.0)
          stage_block_id = 0
          for cnf in block_setting:
              # Bottlenecks
              stage: list[nn.Module] = []
              for _ in range(cnf.num_layers):
                  # adjust stochastic depth probability based on the depth of the stage block
                  sd_prob = stochastic_depth_prob * stage_block_id / sd_denominator
                  stage.append(block(cnf.input_channels, layer_scale, sd_prob))
                  stage_block_id += 1
              layers.append(nn.Sequential(*stage))
              if cnf.out_channels is not None:
                  # Downsampling
                  layers.append(
                      nn.Sequential(
                          norm_layer(cnf.input_channels),
                          nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2),
                      )
                  )

          self.features = nn.Sequential(*layers)
          self.avgpool = nn.AdaptiveAvgPool2d(1)

          # The last stage's width, or the width after its downsampling layer if it has one.
          lastblock = block_setting[-1]
          lastconv_output_channels = (
              lastblock.out_channels if lastblock.out_channels is not None else lastblock.input_channels
          )
          self.classifier = nn.Sequential(
              norm_layer(lastconv_output_channels), nn.Flatten(1), nn.Linear(lastconv_output_channels, num_classes)
          )

          # Truncated-normal (std=0.02) weights and zero biases for every conv / linear layer.
          for m in self.modules():
              if isinstance(m, (nn.Conv2d, nn.Linear)):
                  nn.init.trunc_normal_(m.weight, std=0.02)
                  if m.bias is not None:
                      nn.init.zeros_(m.bias)

      def _forward_impl(self, x: Tensor) -> Tensor:
          # features -> global average pool -> classifier head
          x = self.features(x)
          x = self.avgpool(x)
          x = self.classifier(x)
          return x

      def forward(self, x: Tensor) -> Tensor:
          return self._forward_impl(x)
  def _convnext(
      block_setting: list[CNBlockConfig],
      stochastic_depth_prob: float,
      weights: Optional[WeightsEnum],
      progress: bool,
      **kwargs: Any,
  ) -> ConvNeXt:
      """Construct a ``ConvNeXt`` and, if ``weights`` is given, load its checkpoint.

      With ``weights``, ``kwargs`` is first passed to ``_ovewrite_named_param``
      together with ``"num_classes"`` and the number of categories in
      ``weights.meta``. The model is then built and its state dict is loaded from
      ``weights.get_state_dict(progress=progress, check_hash=True)``.
      """
      if weights is None:
          return ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)

      _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
      net = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)
      net.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
      return net
  # Metadata shared by every ConvNeXt ImageNet-1K checkpoint; each weights entry
  # extends a copy of this dict with its model-specific fields.
  _COMMON_META = dict(
      min_size=(32, 32),
      categories=_IMAGENET_CATEGORIES,
      recipe="https://github.com/pytorch/vision/tree/main/references/classification#convnext",
      _docs="""
          These weights improve upon the results of the original paper by using a modified version of TorchVision's
          `new training recipe
          <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
      """,
  )
  # Pretrained ImageNet-1K checkpoint(s) for convnext_tiny.
  class ConvNeXt_Tiny_Weights(WeightsEnum):
      IMAGENET1K_V1 = Weights(
          url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth",
          transforms=partial(ImageClassification, crop_size=224, resize_size=236),
          # Shared ConvNeXt metadata followed by this checkpoint's own fields.
          meta=dict(
              _COMMON_META,
              num_params=28589128,
              _metrics={"ImageNet-1K": {"acc@1": 82.520, "acc@5": 96.146}},
              _ops=4.456,
              _file_size=109.119,
          ),
      )
      DEFAULT = IMAGENET1K_V1
  # Pretrained ImageNet-1K checkpoint(s) for convnext_small.
  class ConvNeXt_Small_Weights(WeightsEnum):
      IMAGENET1K_V1 = Weights(
          url="https://download.pytorch.org/models/convnext_small-0c510722.pth",
          transforms=partial(ImageClassification, crop_size=224, resize_size=230),
          # Shared ConvNeXt metadata followed by this checkpoint's own fields.
          meta=dict(
              _COMMON_META,
              num_params=50223688,
              _metrics={"ImageNet-1K": {"acc@1": 83.616, "acc@5": 96.650}},
              _ops=8.684,
              _file_size=191.703,
          ),
      )
      DEFAULT = IMAGENET1K_V1
  # Pretrained ImageNet-1K checkpoint(s) for convnext_base.
  class ConvNeXt_Base_Weights(WeightsEnum):
      IMAGENET1K_V1 = Weights(
          url="https://download.pytorch.org/models/convnext_base-6075fbad.pth",
          transforms=partial(ImageClassification, crop_size=224, resize_size=232),
          # Shared ConvNeXt metadata followed by this checkpoint's own fields.
          meta=dict(
              _COMMON_META,
              num_params=88591464,
              _metrics={"ImageNet-1K": {"acc@1": 84.062, "acc@5": 96.870}},
              _ops=15.355,
              _file_size=338.064,
          ),
      )
      DEFAULT = IMAGENET1K_V1
  # Pretrained ImageNet-1K checkpoint(s) for convnext_large.
  class ConvNeXt_Large_Weights(WeightsEnum):
      IMAGENET1K_V1 = Weights(
          url="https://download.pytorch.org/models/convnext_large-ea097f82.pth",
          transforms=partial(ImageClassification, crop_size=224, resize_size=232),
          # Shared ConvNeXt metadata followed by this checkpoint's own fields.
          meta=dict(
              _COMMON_META,
              num_params=197767336,
              _metrics={"ImageNet-1K": {"acc@1": 84.414, "acc@5": 96.976}},
              _ops=34.361,
              _file_size=754.537,
          ),
      )
      DEFAULT = IMAGENET1K_V1
  @register_model()
  @handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1))
  def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
      """ConvNeXt Tiny model architecture from the
      `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.

      Stage widths are (96, 192, 384, 768) with (3, 3, 9, 3) blocks per stage. The
      default ``stochastic_depth_prob`` is 0.1 and can be overridden via ``kwargs``.

      Args:
          weights (:class:`~torchvision.models.convnext.ConvNeXt_Tiny_Weights`, optional): The pretrained
              weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Tiny_Weights`
              below for more details and possible values. By default, no pre-trained weights are used.
          progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
          **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNeXt``
              base class. Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
              for more details about this class.

      .. autoclass:: torchvision.models.ConvNeXt_Tiny_Weights
          :members:
      """
      weights = ConvNeXt_Tiny_Weights.verify(weights)

      block_setting = [
          CNBlockConfig(96, 192, 3),
          CNBlockConfig(192, 384, 3),
          CNBlockConfig(384, 768, 9),
          CNBlockConfig(768, None, 3),
      ]
      stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1)
      return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
  @register_model()
  @handle_legacy_interface(weights=("pretrained", ConvNeXt_Small_Weights.IMAGENET1K_V1))
  def convnext_small(
      *, weights: Optional[ConvNeXt_Small_Weights] = None, progress: bool = True, **kwargs: Any
  ) -> ConvNeXt:
      """ConvNeXt Small model architecture from the
      `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.

      Stage widths are (96, 192, 384, 768) with (3, 3, 27, 3) blocks per stage. The
      default ``stochastic_depth_prob`` is 0.4 and can be overridden via ``kwargs``.

      Args:
          weights (:class:`~torchvision.models.convnext.ConvNeXt_Small_Weights`, optional): The pretrained
              weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Small_Weights`
              below for more details and possible values. By default, no pre-trained weights are used.
          progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
          **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNeXt``
              base class. Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
              for more details about this class.

      .. autoclass:: torchvision.models.ConvNeXt_Small_Weights
          :members:
      """
      weights = ConvNeXt_Small_Weights.verify(weights)

      block_setting = [
          CNBlockConfig(96, 192, 3),
          CNBlockConfig(192, 384, 3),
          CNBlockConfig(384, 768, 27),
          CNBlockConfig(768, None, 3),
      ]
      stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.4)
      return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
  @register_model()
  @handle_legacy_interface(weights=("pretrained", ConvNeXt_Base_Weights.IMAGENET1K_V1))
  def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
      """ConvNeXt Base model architecture from the
      `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.

      Stage widths are (128, 256, 512, 1024) with (3, 3, 27, 3) blocks per stage. The
      default ``stochastic_depth_prob`` is 0.5 and can be overridden via ``kwargs``.

      Args:
          weights (:class:`~torchvision.models.convnext.ConvNeXt_Base_Weights`, optional): The pretrained
              weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Base_Weights`
              below for more details and possible values. By default, no pre-trained weights are used.
          progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
          **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNeXt``
              base class. Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
              for more details about this class.

      .. autoclass:: torchvision.models.ConvNeXt_Base_Weights
          :members:
      """
      weights = ConvNeXt_Base_Weights.verify(weights)

      block_setting = [
          CNBlockConfig(128, 256, 3),
          CNBlockConfig(256, 512, 3),
          CNBlockConfig(512, 1024, 27),
          CNBlockConfig(1024, None, 3),
      ]
      stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
      return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
  @register_model()
  @handle_legacy_interface(weights=("pretrained", ConvNeXt_Large_Weights.IMAGENET1K_V1))
  def convnext_large(
      *, weights: Optional[ConvNeXt_Large_Weights] = None, progress: bool = True, **kwargs: Any
  ) -> ConvNeXt:
      """ConvNeXt Large model architecture from the
      `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.

      Stage widths are (192, 384, 768, 1536) with (3, 3, 27, 3) blocks per stage. The
      default ``stochastic_depth_prob`` is 0.5 and can be overridden via ``kwargs``.

      Args:
          weights (:class:`~torchvision.models.convnext.ConvNeXt_Large_Weights`, optional): The pretrained
              weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Large_Weights`
              below for more details and possible values. By default, no pre-trained weights are used.
          progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
          **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNeXt``
              base class. Please refer to the `source code
              <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
              for more details about this class.

      .. autoclass:: torchvision.models.ConvNeXt_Large_Weights
          :members:
      """
      weights = ConvNeXt_Large_Weights.verify(weights)

      block_setting = [
          CNBlockConfig(192, 384, 3),
          CNBlockConfig(384, 768, 3),
          CNBlockConfig(768, 1536, 27),
          CNBlockConfig(1536, None, 3),
      ]
      stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
      return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)