modeling_convnextv2.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434
  1. # Copyright 2023 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch ConvNextV2 model."""
  15. import torch
  16. from torch import nn
  17. from ... import initialization as init
  18. from ...activations import ACT2FN
  19. from ...backbone_utils import BackboneMixin, filter_output_hidden_states
  20. from ...modeling_outputs import (
  21. BackboneOutput,
  22. BaseModelOutputWithNoAttention,
  23. BaseModelOutputWithPoolingAndNoAttention,
  24. ImageClassifierOutputWithNoAttention,
  25. )
  26. from ...modeling_utils import PreTrainedModel
  27. from ...processing_utils import Unpack
  28. from ...utils import TransformersKwargs, auto_docstring, logging
  29. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  30. from ...utils.output_capturing import capture_outputs
  31. from .configuration_convnextv2 import ConvNextV2Config
# Module-level logger, created through the library's `logging` helper and keyed on this module's name.
logger = logging.get_logger(__name__)
  33. # Copied from transformers.models.beit.modeling_beit.drop_path
  34. def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
  35. """
  36. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  37. """
  38. if drop_prob == 0.0 or not training:
  39. return input
  40. keep_prob = 1 - drop_prob
  41. shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  42. random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
  43. random_tensor.floor_() # binarize
  44. output = input.div(keep_prob) * random_tensor
  45. return output
  46. # Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->ConvNextV2
  47. class ConvNextV2DropPath(nn.Module):
  48. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  49. def __init__(self, drop_prob: float | None = None) -> None:
  50. super().__init__()
  51. self.drop_prob = drop_prob
  52. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  53. return drop_path(hidden_states, self.drop_prob, self.training)
  54. def extra_repr(self) -> str:
  55. return f"p={self.drop_prob}"
  56. class ConvNextV2GRN(nn.Module):
  57. """GRN (Global Response Normalization) layer"""
  58. def __init__(self, dim: int):
  59. super().__init__()
  60. self.weight = nn.Parameter(torch.zeros(1, 1, 1, dim))
  61. self.bias = nn.Parameter(torch.zeros(1, 1, 1, dim))
  62. def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
  63. # Compute and normalize global spatial feature maps
  64. global_features = torch.linalg.vector_norm(hidden_states, ord=2, dim=(1, 2), keepdim=True)
  65. norm_features = global_features / (global_features.mean(dim=-1, keepdim=True) + 1e-6)
  66. hidden_states = self.weight * (hidden_states * norm_features) + self.bias + hidden_states
  67. return hidden_states
  68. # Copied from transformers.models.convnext.modeling_convnext.ConvNextLayerNorm with ConvNext->ConvNextV2
  69. class ConvNextV2LayerNorm(nn.LayerNorm):
  70. r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
  71. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
  72. width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
  73. """
  74. def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs):
  75. super().__init__(normalized_shape, eps=eps, **kwargs)
  76. if data_format not in ["channels_last", "channels_first"]:
  77. raise NotImplementedError(f"Unsupported data format: {data_format}")
  78. self.data_format = data_format
  79. def forward(self, features: torch.Tensor) -> torch.Tensor:
  80. """
  81. Args:
  82. features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels)
  83. """
  84. if self.data_format == "channels_first":
  85. features = features.permute(0, 2, 3, 1)
  86. features = super().forward(features)
  87. features = features.permute(0, 3, 1, 2)
  88. else:
  89. features = super().forward(features)
  90. return features
  91. # Copied from transformers.models.convnext.modeling_convnext.ConvNextEmbeddings with ConvNext->ConvNextV2
  92. class ConvNextV2Embeddings(nn.Module):
  93. """This class is comparable to (and inspired by) the SwinEmbeddings class
  94. found in src/transformers/models/swin/modeling_swin.py.
  95. """
  96. def __init__(self, config):
  97. super().__init__()
  98. self.patch_embeddings = nn.Conv2d(
  99. config.num_channels, config.hidden_sizes[0], kernel_size=config.patch_size, stride=config.patch_size
  100. )
  101. self.layernorm = ConvNextV2LayerNorm(config.hidden_sizes[0], eps=1e-6, data_format="channels_first")
  102. self.num_channels = config.num_channels
  103. def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
  104. num_channels = pixel_values.shape[1]
  105. if num_channels != self.num_channels:
  106. raise ValueError(
  107. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  108. )
  109. embeddings = self.patch_embeddings(pixel_values)
  110. embeddings = self.layernorm(embeddings)
  111. return embeddings
  112. class ConvNextV2Layer(nn.Module):
  113. """This corresponds to the `Block` class in the original implementation.
  114. There are two equivalent implementations: [DwConv, LayerNorm (channels_first), Conv, GELU,1x1 Conv]; all in (N, C,
  115. H, W) (2) [DwConv, Permute to (N, H, W, C), LayerNorm (channels_last), Linear, GELU, Linear]; Permute back
  116. The authors used (2) as they find it slightly faster in PyTorch.
  117. Args:
  118. config ([`ConvNextV2Config`]): Model configuration class.
  119. dim (`int`): Number of input channels.
  120. drop_path (`float`): Stochastic depth rate. Default: 0.0.
  121. """
  122. def __init__(self, config, dim, drop_path=0):
  123. super().__init__()
  124. # depthwise conv
  125. self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
  126. self.layernorm = ConvNextV2LayerNorm(dim, eps=1e-6)
  127. # pointwise/1x1 convs, implemented with linear layers
  128. self.pwconv1 = nn.Linear(dim, 4 * dim)
  129. self.act = ACT2FN[config.hidden_act]
  130. self.grn = ConvNextV2GRN(4 * dim)
  131. self.pwconv2 = nn.Linear(4 * dim, dim)
  132. self.drop_path = ConvNextV2DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
  133. def forward(self, features: torch.Tensor) -> torch.Tensor:
  134. residual = features
  135. features = self.dwconv(features)
  136. # (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels)
  137. features = features.permute(0, 2, 3, 1)
  138. features = self.layernorm(features)
  139. features = self.pwconv1(features)
  140. features = self.act(features)
  141. features = self.grn(features)
  142. features = self.pwconv2(features)
  143. # (batch_size, height, width, num_channels) -> (batch_size, num_channels, height, width)
  144. features = features.permute(0, 3, 1, 2)
  145. features = residual + self.drop_path(features)
  146. return features
  147. # Copied from transformers.models.convnext.modeling_convnext.ConvNextStage with ConvNeXT->ConvNeXTV2, ConvNext->ConvNextV2
  148. class ConvNextV2Stage(nn.Module):
  149. """ConvNeXTV2 stage, consisting of an optional downsampling layer + multiple residual blocks.
  150. Args:
  151. config ([`ConvNextV2Config`]): Model configuration class.
  152. in_channels (`int`): Number of input channels.
  153. out_channels (`int`): Number of output channels.
  154. depth (`int`): Number of residual blocks.
  155. drop_path_rates(`list[float]`): Stochastic depth rates for each layer.
  156. """
  157. def __init__(self, config, in_channels, out_channels, kernel_size=2, stride=2, depth=2, drop_path_rates=None):
  158. super().__init__()
  159. if in_channels != out_channels or stride > 1:
  160. self.downsampling_layer = nn.ModuleList(
  161. [
  162. ConvNextV2LayerNorm(in_channels, eps=1e-6, data_format="channels_first"),
  163. nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride),
  164. ]
  165. )
  166. else:
  167. self.downsampling_layer = nn.ModuleList()
  168. drop_path_rates = drop_path_rates or [0.0] * depth
  169. self.layers = nn.ModuleList(
  170. [ConvNextV2Layer(config, dim=out_channels, drop_path=drop_path_rates[j]) for j in range(depth)]
  171. )
  172. def forward(self, features: torch.Tensor) -> torch.Tensor:
  173. for layer in self.downsampling_layer:
  174. features = layer(features)
  175. for layer in self.layers:
  176. features = layer(features)
  177. return features
  178. @auto_docstring
  179. class ConvNextV2PreTrainedModel(PreTrainedModel):
  180. config: ConvNextV2Config
  181. base_model_prefix = "convnextv2"
  182. main_input_name = "pixel_values"
  183. input_modalities = ("image",)
  184. _no_split_modules = ["ConvNextV2Layer"]
  185. @torch.no_grad()
  186. def _init_weights(self, module):
  187. """Initialize the weights"""
  188. super()._init_weights(module)
  189. if isinstance(module, ConvNextV2GRN):
  190. init.zeros_(module.weight)
  191. init.zeros_(module.bias)
  192. # Copied from transformers.models.convnext.modeling_convnext.ConvNextEncoder with CONVNEXT->CONVNEXTV2, ConvNext->ConvNextV2
  193. class ConvNextV2Encoder(ConvNextV2PreTrainedModel):
  194. main_input_name = "hidden_states"
  195. _can_record_outputs = {"hidden_states": ConvNextV2Stage}
  196. def __init__(self, config):
  197. super().__init__(config)
  198. self.stages = nn.ModuleList()
  199. drop_path_rates = [
  200. x.tolist()
  201. for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu").split(config.depths)
  202. ]
  203. prev_chs = config.hidden_sizes[0]
  204. for i in range(config.num_stages):
  205. out_chs = config.hidden_sizes[i]
  206. stage = ConvNextV2Stage(
  207. config,
  208. in_channels=prev_chs,
  209. out_channels=out_chs,
  210. stride=2 if i > 0 else 1,
  211. depth=config.depths[i],
  212. drop_path_rates=drop_path_rates[i],
  213. )
  214. self.stages.append(stage)
  215. prev_chs = out_chs
  216. self.post_init()
  217. @merge_with_config_defaults
  218. @capture_outputs(tie_last_hidden_states=False)
  219. def forward(
  220. self,
  221. hidden_states: torch.Tensor,
  222. **kwargs: Unpack[TransformersKwargs],
  223. ) -> BaseModelOutputWithNoAttention:
  224. for layer_module in self.stages:
  225. hidden_states = layer_module(hidden_states)
  226. return BaseModelOutputWithNoAttention(last_hidden_state=hidden_states)
  227. @auto_docstring
  228. # Copied from transformers.models.convnext.modeling_convnext.ConvNextModel with CONVNEXT->CONVNEXTV2, ConvNext->ConvNextV2
  229. class ConvNextV2Model(ConvNextV2PreTrainedModel):
  230. def __init__(self, config):
  231. super().__init__(config)
  232. self.config = config
  233. self.embeddings = ConvNextV2Embeddings(config)
  234. self.encoder = ConvNextV2Encoder(config)
  235. # final layernorm layer
  236. self.layernorm = nn.LayerNorm(config.hidden_sizes[-1], eps=config.layer_norm_eps)
  237. # Initialize weights and apply final processing
  238. self.post_init()
  239. @can_return_tuple
  240. @auto_docstring
  241. def forward(
  242. self, pixel_values: torch.FloatTensor | None = None, **kwargs: Unpack[TransformersKwargs]
  243. ) -> BaseModelOutputWithPoolingAndNoAttention:
  244. if pixel_values is None:
  245. raise ValueError("You have to specify pixel_values")
  246. embedding_output = self.embeddings(pixel_values)
  247. encoder_outputs: BaseModelOutputWithNoAttention = self.encoder(embedding_output, **kwargs)
  248. last_hidden_state = encoder_outputs.last_hidden_state
  249. # global average pooling, (N, C, H, W) -> (N, C)
  250. pooled_output = self.layernorm(last_hidden_state.mean([-2, -1]))
  251. return BaseModelOutputWithPoolingAndNoAttention(
  252. last_hidden_state=last_hidden_state,
  253. pooler_output=pooled_output,
  254. hidden_states=encoder_outputs.hidden_states,
  255. )
  256. @auto_docstring(
  257. custom_intro="""
  258. ConvNextV2 Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
  259. ImageNet.
  260. """
  261. )
  262. # Copied from transformers.models.convnext.modeling_convnext.ConvNextForImageClassification with CONVNEXT->CONVNEXTV2,ConvNext->ConvNextV2,convnext->convnextv2
  263. class ConvNextV2ForImageClassification(ConvNextV2PreTrainedModel):
  264. accepts_loss_kwargs = False
  265. def __init__(self, config):
  266. super().__init__(config)
  267. self.num_labels = config.num_labels
  268. self.convnextv2 = ConvNextV2Model(config)
  269. # Classifier head
  270. if config.num_labels > 0:
  271. self.classifier = nn.Linear(config.hidden_sizes[-1], config.num_labels)
  272. else:
  273. self.classifier = nn.Identity()
  274. # Initialize weights and apply final processing
  275. self.post_init()
  276. @can_return_tuple
  277. @auto_docstring
  278. def forward(
  279. self, pixel_values: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, **kwargs
  280. ) -> ImageClassifierOutputWithNoAttention:
  281. r"""
  282. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  283. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  284. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  285. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  286. """
  287. outputs: BaseModelOutputWithPoolingAndNoAttention = self.convnextv2(pixel_values, **kwargs)
  288. pooled_output = outputs.pooler_output
  289. logits = self.classifier(pooled_output)
  290. loss = None
  291. if labels is not None:
  292. loss = self.loss_function(labels=labels, pooled_logits=logits, config=self.config)
  293. return ImageClassifierOutputWithNoAttention(
  294. loss=loss,
  295. logits=logits,
  296. hidden_states=outputs.hidden_states,
  297. )
  298. @auto_docstring(
  299. custom_intro="""
  300. ConvNeXT V2 backbone, to be used with frameworks like DETR and MaskFormer.
  301. """
  302. )
  303. # Copied from transformers.models.convnext.modeling_convnext.ConvNextBackbone with CONVNEXT->CONVNEXTV2,ConvNext->ConvNextV2,facebook/convnext-tiny-224->facebook/convnextv2-tiny-1k-224
  304. class ConvNextV2Backbone(BackboneMixin, ConvNextV2PreTrainedModel):
  305. has_attentions = False
  306. def __init__(self, config):
  307. super().__init__(config)
  308. self.embeddings = ConvNextV2Embeddings(config)
  309. self.encoder = ConvNextV2Encoder(config)
  310. self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes
  311. # Add layer norms to hidden states of out_features
  312. hidden_states_norms = {}
  313. for stage, num_channels in zip(self.out_features, self.channels):
  314. hidden_states_norms[stage] = ConvNextV2LayerNorm(num_channels, data_format="channels_first")
  315. self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
  316. # initialize weights and apply final processing
  317. self.post_init()
  318. @can_return_tuple
  319. @filter_output_hidden_states
  320. @auto_docstring
  321. def forward(
  322. self,
  323. pixel_values: torch.Tensor,
  324. **kwargs: Unpack[TransformersKwargs],
  325. ) -> BackboneOutput:
  326. r"""
  327. Examples:
  328. ```python
  329. >>> from transformers import AutoImageProcessor, AutoBackbone
  330. >>> import torch
  331. >>> from PIL import Image
  332. >>> import httpx
  333. >>> from io import BytesIO
  334. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  335. >>> with httpx.stream("GET", url) as response:
  336. ... image = Image.open(BytesIO(response.read()))
  337. >>> processor = AutoImageProcessor.from_pretrained("facebook/convnextv2-tiny-1k-224")
  338. >>> model = AutoBackbone.from_pretrained("facebook/convnextv2-tiny-1k-224")
  339. >>> inputs = processor(image, return_tensors="pt")
  340. >>> outputs = model(**inputs)
  341. ```"""
  342. kwargs["output_hidden_states"] = True # required to extract layers for the stages
  343. embedding_output = self.embeddings(pixel_values)
  344. encoder_outputs: BaseModelOutputWithNoAttention = self.encoder(embedding_output, **kwargs)
  345. hidden_states = encoder_outputs.hidden_states
  346. feature_maps = []
  347. for stage, hidden_state in zip(self.stage_names, hidden_states):
  348. if stage in self.out_features:
  349. hidden_state = self.hidden_states_norms[stage](hidden_state)
  350. feature_maps.append(hidden_state)
  351. return BackboneOutput(feature_maps=tuple(feature_maps), hidden_states=hidden_states)
# Names exported by this module; the encoder, stage, layer and helper classes defined above are not listed.
__all__ = ["ConvNextV2ForImageClassification", "ConvNextV2Model", "ConvNextV2PreTrainedModel", "ConvNextV2Backbone"]