modeling_convnext.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410
  1. # Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch ConvNext model."""
  15. import torch
  16. from torch import nn
  17. from ... import initialization as init
  18. from ...activations import ACT2FN
  19. from ...backbone_utils import BackboneMixin, filter_output_hidden_states
  20. from ...modeling_outputs import (
  21. BackboneOutput,
  22. BaseModelOutputWithNoAttention,
  23. BaseModelOutputWithPoolingAndNoAttention,
  24. ImageClassifierOutputWithNoAttention,
  25. )
  26. from ...modeling_utils import PreTrainedModel
  27. from ...processing_utils import Unpack
  28. from ...utils import TransformersKwargs, auto_docstring, logging
  29. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  30. from ...utils.output_capturing import capture_outputs
  31. from .configuration_convnext import ConvNextConfig
# Module-level logger, named after this module.
logger = logging.get_logger(__name__)
  33. # Copied from transformers.models.beit.modeling_beit.drop_path
  34. def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
  35. """
  36. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  37. """
  38. if drop_prob == 0.0 or not training:
  39. return input
  40. keep_prob = 1 - drop_prob
  41. shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  42. random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
  43. random_tensor.floor_() # binarize
  44. output = input.div(keep_prob) * random_tensor
  45. return output
  46. # Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->ConvNext
  47. class ConvNextDropPath(nn.Module):
  48. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  49. def __init__(self, drop_prob: float | None = None) -> None:
  50. super().__init__()
  51. self.drop_prob = drop_prob
  52. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  53. return drop_path(hidden_states, self.drop_prob, self.training)
  54. def extra_repr(self) -> str:
  55. return f"p={self.drop_prob}"
  56. class ConvNextLayerNorm(nn.LayerNorm):
  57. r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
  58. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
  59. width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
  60. """
  61. def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs):
  62. super().__init__(normalized_shape, eps=eps, **kwargs)
  63. if data_format not in ["channels_last", "channels_first"]:
  64. raise NotImplementedError(f"Unsupported data format: {data_format}")
  65. self.data_format = data_format
  66. def forward(self, features: torch.Tensor) -> torch.Tensor:
  67. """
  68. Args:
  69. features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels)
  70. """
  71. if self.data_format == "channels_first":
  72. features = features.permute(0, 2, 3, 1)
  73. features = super().forward(features)
  74. features = features.permute(0, 3, 1, 2)
  75. else:
  76. features = super().forward(features)
  77. return features
  78. class ConvNextEmbeddings(nn.Module):
  79. """This class is comparable to (and inspired by) the SwinEmbeddings class
  80. found in src/transformers/models/swin/modeling_swin.py.
  81. """
  82. def __init__(self, config):
  83. super().__init__()
  84. self.patch_embeddings = nn.Conv2d(
  85. config.num_channels, config.hidden_sizes[0], kernel_size=config.patch_size, stride=config.patch_size
  86. )
  87. self.layernorm = ConvNextLayerNorm(config.hidden_sizes[0], eps=1e-6, data_format="channels_first")
  88. self.num_channels = config.num_channels
  89. def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
  90. num_channels = pixel_values.shape[1]
  91. if num_channels != self.num_channels:
  92. raise ValueError(
  93. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  94. )
  95. embeddings = self.patch_embeddings(pixel_values)
  96. embeddings = self.layernorm(embeddings)
  97. return embeddings
  98. class ConvNextLayer(nn.Module):
  99. """This corresponds to the `Block` class in the original implementation.
  100. There are two equivalent implementations: [DwConv, LayerNorm (channels_first), Conv, GELU,1x1 Conv]; all in (N, C,
  101. H, W) (2) [DwConv, Permute to (N, H, W, C), LayerNorm (channels_last), Linear, GELU, Linear]; Permute back
  102. The authors used (2) as they find it slightly faster in PyTorch.
  103. Args:
  104. config ([`ConvNextConfig`]): Model configuration class.
  105. dim (`int`): Number of input channels.
  106. drop_path (`float`): Stochastic depth rate. Default: 0.0.
  107. """
  108. def __init__(self, config, dim, drop_path=0):
  109. super().__init__()
  110. self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
  111. self.layernorm = ConvNextLayerNorm(dim, eps=1e-6)
  112. self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
  113. self.act = ACT2FN[config.hidden_act]
  114. self.pwconv2 = nn.Linear(4 * dim, dim)
  115. self.layer_scale_parameter = (
  116. nn.Parameter(config.layer_scale_init_value * torch.ones(dim), requires_grad=True)
  117. if config.layer_scale_init_value > 0
  118. else None
  119. )
  120. self.drop_path = ConvNextDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
  121. def forward(self, features: torch.Tensor) -> torch.Tensor:
  122. residual = features
  123. features = self.dwconv(features)
  124. features = features.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
  125. features = self.layernorm(features)
  126. features = self.pwconv1(features)
  127. features = self.act(features)
  128. features = self.pwconv2(features)
  129. if self.layer_scale_parameter is not None:
  130. features = self.layer_scale_parameter * features
  131. features = features.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
  132. features = residual + self.drop_path(features)
  133. return features
  134. class ConvNextStage(nn.Module):
  135. """ConvNeXT stage, consisting of an optional downsampling layer + multiple residual blocks.
  136. Args:
  137. config ([`ConvNextConfig`]): Model configuration class.
  138. in_channels (`int`): Number of input channels.
  139. out_channels (`int`): Number of output channels.
  140. depth (`int`): Number of residual blocks.
  141. drop_path_rates(`list[float]`): Stochastic depth rates for each layer.
  142. """
  143. def __init__(self, config, in_channels, out_channels, kernel_size=2, stride=2, depth=2, drop_path_rates=None):
  144. super().__init__()
  145. if in_channels != out_channels or stride > 1:
  146. self.downsampling_layer = nn.ModuleList(
  147. [
  148. ConvNextLayerNorm(in_channels, eps=1e-6, data_format="channels_first"),
  149. nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride),
  150. ]
  151. )
  152. else:
  153. self.downsampling_layer = nn.ModuleList()
  154. drop_path_rates = drop_path_rates or [0.0] * depth
  155. self.layers = nn.ModuleList(
  156. [ConvNextLayer(config, dim=out_channels, drop_path=drop_path_rates[j]) for j in range(depth)]
  157. )
  158. def forward(self, features: torch.Tensor) -> torch.Tensor:
  159. for layer in self.downsampling_layer:
  160. features = layer(features)
  161. for layer in self.layers:
  162. features = layer(features)
  163. return features
  164. @auto_docstring
  165. class ConvNextPreTrainedModel(PreTrainedModel):
  166. config: ConvNextConfig
  167. base_model_prefix = "convnext"
  168. main_input_name = "pixel_values"
  169. input_modalities = ("image",)
  170. _no_split_modules = ["ConvNextLayer", "ConvNextStage"]
  171. @torch.no_grad()
  172. def _init_weights(self, module):
  173. """Initialize the weights"""
  174. super()._init_weights(module)
  175. if isinstance(module, ConvNextLayer):
  176. if module.layer_scale_parameter is not None:
  177. init.constant_(module.layer_scale_parameter, self.config.layer_scale_init_value)
  178. class ConvNextEncoder(ConvNextPreTrainedModel):
  179. main_input_name = "hidden_states"
  180. _can_record_outputs = {"hidden_states": ConvNextStage}
  181. def __init__(self, config):
  182. super().__init__(config)
  183. self.stages = nn.ModuleList()
  184. drop_path_rates = [
  185. x.tolist()
  186. for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu").split(config.depths)
  187. ]
  188. prev_chs = config.hidden_sizes[0]
  189. for i in range(config.num_stages):
  190. out_chs = config.hidden_sizes[i]
  191. stage = ConvNextStage(
  192. config,
  193. in_channels=prev_chs,
  194. out_channels=out_chs,
  195. stride=2 if i > 0 else 1,
  196. depth=config.depths[i],
  197. drop_path_rates=drop_path_rates[i],
  198. )
  199. self.stages.append(stage)
  200. prev_chs = out_chs
  201. self.post_init()
  202. @merge_with_config_defaults
  203. @capture_outputs(tie_last_hidden_states=False)
  204. def forward(
  205. self,
  206. hidden_states: torch.Tensor,
  207. **kwargs: Unpack[TransformersKwargs],
  208. ) -> BaseModelOutputWithNoAttention:
  209. for layer_module in self.stages:
  210. hidden_states = layer_module(hidden_states)
  211. return BaseModelOutputWithNoAttention(last_hidden_state=hidden_states)
  212. @auto_docstring
  213. class ConvNextModel(ConvNextPreTrainedModel):
  214. def __init__(self, config):
  215. super().__init__(config)
  216. self.config = config
  217. self.embeddings = ConvNextEmbeddings(config)
  218. self.encoder = ConvNextEncoder(config)
  219. # final layernorm layer
  220. self.layernorm = nn.LayerNorm(config.hidden_sizes[-1], eps=config.layer_norm_eps)
  221. # Initialize weights and apply final processing
  222. self.post_init()
  223. @can_return_tuple
  224. @auto_docstring
  225. def forward(
  226. self, pixel_values: torch.FloatTensor | None = None, **kwargs: Unpack[TransformersKwargs]
  227. ) -> BaseModelOutputWithPoolingAndNoAttention:
  228. if pixel_values is None:
  229. raise ValueError("You have to specify pixel_values")
  230. embedding_output = self.embeddings(pixel_values)
  231. encoder_outputs: BaseModelOutputWithNoAttention = self.encoder(embedding_output, **kwargs)
  232. last_hidden_state = encoder_outputs.last_hidden_state
  233. # global average pooling, (N, C, H, W) -> (N, C)
  234. pooled_output = self.layernorm(last_hidden_state.mean([-2, -1]))
  235. return BaseModelOutputWithPoolingAndNoAttention(
  236. last_hidden_state=last_hidden_state,
  237. pooler_output=pooled_output,
  238. hidden_states=encoder_outputs.hidden_states,
  239. )
  240. @auto_docstring(
  241. custom_intro="""
  242. ConvNext Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
  243. ImageNet.
  244. """
  245. )
  246. class ConvNextForImageClassification(ConvNextPreTrainedModel):
  247. accepts_loss_kwargs = False
  248. def __init__(self, config):
  249. super().__init__(config)
  250. self.num_labels = config.num_labels
  251. self.convnext = ConvNextModel(config)
  252. # Classifier head
  253. if config.num_labels > 0:
  254. self.classifier = nn.Linear(config.hidden_sizes[-1], config.num_labels)
  255. else:
  256. self.classifier = nn.Identity()
  257. # Initialize weights and apply final processing
  258. self.post_init()
  259. @can_return_tuple
  260. @auto_docstring
  261. def forward(
  262. self, pixel_values: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, **kwargs
  263. ) -> ImageClassifierOutputWithNoAttention:
  264. r"""
  265. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  266. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  267. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  268. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  269. """
  270. outputs: BaseModelOutputWithPoolingAndNoAttention = self.convnext(pixel_values, **kwargs)
  271. pooled_output = outputs.pooler_output
  272. logits = self.classifier(pooled_output)
  273. loss = None
  274. if labels is not None:
  275. loss = self.loss_function(labels=labels, pooled_logits=logits, config=self.config)
  276. return ImageClassifierOutputWithNoAttention(
  277. loss=loss,
  278. logits=logits,
  279. hidden_states=outputs.hidden_states,
  280. )
  281. @auto_docstring(
  282. custom_intro="""
  283. ConvNeXt backbone, to be used with frameworks like DETR and MaskFormer.
  284. """
  285. )
  286. class ConvNextBackbone(BackboneMixin, ConvNextPreTrainedModel):
  287. has_attentions = False
  288. def __init__(self, config):
  289. super().__init__(config)
  290. self.embeddings = ConvNextEmbeddings(config)
  291. self.encoder = ConvNextEncoder(config)
  292. self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes
  293. # Add layer norms to hidden states of out_features
  294. hidden_states_norms = {}
  295. for stage, num_channels in zip(self.out_features, self.channels):
  296. hidden_states_norms[stage] = ConvNextLayerNorm(num_channels, data_format="channels_first")
  297. self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
  298. # initialize weights and apply final processing
  299. self.post_init()
  300. @can_return_tuple
  301. @filter_output_hidden_states
  302. @auto_docstring
  303. def forward(
  304. self,
  305. pixel_values: torch.Tensor,
  306. **kwargs: Unpack[TransformersKwargs],
  307. ) -> BackboneOutput:
  308. r"""
  309. Examples:
  310. ```python
  311. >>> from transformers import AutoImageProcessor, AutoBackbone
  312. >>> import torch
  313. >>> from PIL import Image
  314. >>> import httpx
  315. >>> from io import BytesIO
  316. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  317. >>> with httpx.stream("GET", url) as response:
  318. ... image = Image.open(BytesIO(response.read()))
  319. >>> processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224")
  320. >>> model = AutoBackbone.from_pretrained("facebook/convnext-tiny-224")
  321. >>> inputs = processor(image, return_tensors="pt")
  322. >>> outputs = model(**inputs)
  323. ```"""
  324. kwargs["output_hidden_states"] = True # required to extract layers for the stages
  325. embedding_output = self.embeddings(pixel_values)
  326. encoder_outputs: BaseModelOutputWithNoAttention = self.encoder(embedding_output, **kwargs)
  327. hidden_states = encoder_outputs.hidden_states
  328. feature_maps = []
  329. for stage, hidden_state in zip(self.stage_names, hidden_states):
  330. if stage in self.out_features:
  331. hidden_state = self.hidden_states_norms[stage](hidden_state)
  332. feature_maps.append(hidden_state)
  333. return BackboneOutput(feature_maps=tuple(feature_maps), hidden_states=hidden_states)
# Names exported by `from <this module> import *`.
__all__ = ["ConvNextForImageClassification", "ConvNextModel", "ConvNextPreTrainedModel", "ConvNextBackbone"]