# modeling_regnet.py
#
# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch RegNet model."""
  15. import math
  16. import torch
  17. from torch import Tensor, nn
  18. from ... import initialization as init
  19. from ...activations import ACT2FN
  20. from ...modeling_outputs import (
  21. BaseModelOutputWithNoAttention,
  22. BaseModelOutputWithPoolingAndNoAttention,
  23. ImageClassifierOutputWithNoAttention,
  24. )
  25. from ...modeling_utils import PreTrainedModel
  26. from ...utils import auto_docstring, logging
  27. from .configuration_regnet import RegNetConfig
# Module-level logger obtained from the library's `logging` helper, keyed by this module's name.
logger = logging.get_logger(__name__)
  29. class RegNetConvLayer(nn.Module):
  30. def __init__(
  31. self,
  32. in_channels: int,
  33. out_channels: int,
  34. kernel_size: int = 3,
  35. stride: int = 1,
  36. groups: int = 1,
  37. activation: str | None = "relu",
  38. ):
  39. super().__init__()
  40. self.convolution = nn.Conv2d(
  41. in_channels,
  42. out_channels,
  43. kernel_size=kernel_size,
  44. stride=stride,
  45. padding=kernel_size // 2,
  46. groups=groups,
  47. bias=False,
  48. )
  49. self.normalization = nn.BatchNorm2d(out_channels)
  50. self.activation = ACT2FN[activation] if activation is not None else nn.Identity()
  51. def forward(self, hidden_state):
  52. hidden_state = self.convolution(hidden_state)
  53. hidden_state = self.normalization(hidden_state)
  54. hidden_state = self.activation(hidden_state)
  55. return hidden_state
  56. class RegNetEmbeddings(nn.Module):
  57. """
  58. RegNet Embeddings (stem) composed of a single aggressive convolution.
  59. """
  60. def __init__(self, config: RegNetConfig):
  61. super().__init__()
  62. self.embedder = RegNetConvLayer(
  63. config.num_channels, config.embedding_size, kernel_size=3, stride=2, activation=config.hidden_act
  64. )
  65. self.num_channels = config.num_channels
  66. def forward(self, pixel_values):
  67. num_channels = pixel_values.shape[1]
  68. if num_channels != self.num_channels:
  69. raise ValueError(
  70. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  71. )
  72. hidden_state = self.embedder(pixel_values)
  73. return hidden_state
  74. # Copied from transformers.models.resnet.modeling_resnet.ResNetShortCut with ResNet->RegNet
  75. class RegNetShortCut(nn.Module):
  76. """
  77. RegNet shortcut, used to project the residual features to the correct size. If needed, it is also used to
  78. downsample the input using `stride=2`.
  79. """
  80. def __init__(self, in_channels: int, out_channels: int, stride: int = 2):
  81. super().__init__()
  82. self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
  83. self.normalization = nn.BatchNorm2d(out_channels)
  84. def forward(self, input: Tensor) -> Tensor:
  85. hidden_state = self.convolution(input)
  86. hidden_state = self.normalization(hidden_state)
  87. return hidden_state
  88. class RegNetSELayer(nn.Module):
  89. """
  90. Squeeze and Excitation layer (SE) proposed in [Squeeze-and-Excitation Networks](https://huggingface.co/papers/1709.01507).
  91. """
  92. def __init__(self, in_channels: int, reduced_channels: int):
  93. super().__init__()
  94. self.pooler = nn.AdaptiveAvgPool2d((1, 1))
  95. self.attention = nn.Sequential(
  96. nn.Conv2d(in_channels, reduced_channels, kernel_size=1),
  97. nn.ReLU(),
  98. nn.Conv2d(reduced_channels, in_channels, kernel_size=1),
  99. nn.Sigmoid(),
  100. )
  101. def forward(self, hidden_state):
  102. # b c h w -> b c 1 1
  103. pooled = self.pooler(hidden_state)
  104. attention = self.attention(pooled)
  105. hidden_state = hidden_state * attention
  106. return hidden_state
  107. class RegNetXLayer(nn.Module):
  108. """
  109. RegNet's layer composed by three `3x3` convolutions, same as a ResNet bottleneck layer with reduction = 1.
  110. """
  111. def __init__(self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 1):
  112. super().__init__()
  113. should_apply_shortcut = in_channels != out_channels or stride != 1
  114. groups = max(1, out_channels // config.groups_width)
  115. self.shortcut = (
  116. RegNetShortCut(in_channels, out_channels, stride=stride) if should_apply_shortcut else nn.Identity()
  117. )
  118. self.layer = nn.Sequential(
  119. RegNetConvLayer(in_channels, out_channels, kernel_size=1, activation=config.hidden_act),
  120. RegNetConvLayer(out_channels, out_channels, stride=stride, groups=groups, activation=config.hidden_act),
  121. RegNetConvLayer(out_channels, out_channels, kernel_size=1, activation=None),
  122. )
  123. self.activation = ACT2FN[config.hidden_act]
  124. def forward(self, hidden_state):
  125. residual = hidden_state
  126. hidden_state = self.layer(hidden_state)
  127. residual = self.shortcut(residual)
  128. hidden_state += residual
  129. hidden_state = self.activation(hidden_state)
  130. return hidden_state
  131. class RegNetYLayer(nn.Module):
  132. """
  133. RegNet's Y layer: an X layer with Squeeze and Excitation.
  134. """
  135. def __init__(self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 1):
  136. super().__init__()
  137. should_apply_shortcut = in_channels != out_channels or stride != 1
  138. groups = max(1, out_channels // config.groups_width)
  139. self.shortcut = (
  140. RegNetShortCut(in_channels, out_channels, stride=stride) if should_apply_shortcut else nn.Identity()
  141. )
  142. self.layer = nn.Sequential(
  143. RegNetConvLayer(in_channels, out_channels, kernel_size=1, activation=config.hidden_act),
  144. RegNetConvLayer(out_channels, out_channels, stride=stride, groups=groups, activation=config.hidden_act),
  145. RegNetSELayer(out_channels, reduced_channels=int(round(in_channels / 4))),
  146. RegNetConvLayer(out_channels, out_channels, kernel_size=1, activation=None),
  147. )
  148. self.activation = ACT2FN[config.hidden_act]
  149. def forward(self, hidden_state):
  150. residual = hidden_state
  151. hidden_state = self.layer(hidden_state)
  152. residual = self.shortcut(residual)
  153. hidden_state += residual
  154. hidden_state = self.activation(hidden_state)
  155. return hidden_state
  156. class RegNetStage(nn.Module):
  157. """
  158. A RegNet stage composed by stacked layers.
  159. """
  160. def __init__(
  161. self,
  162. config: RegNetConfig,
  163. in_channels: int,
  164. out_channels: int,
  165. stride: int = 2,
  166. depth: int = 2,
  167. ):
  168. super().__init__()
  169. layer = RegNetXLayer if config.layer_type == "x" else RegNetYLayer
  170. self.layers = nn.Sequential(
  171. # downsampling is done in the first layer with stride of 2
  172. layer(
  173. config,
  174. in_channels,
  175. out_channels,
  176. stride=stride,
  177. ),
  178. *[layer(config, out_channels, out_channels) for _ in range(depth - 1)],
  179. )
  180. def forward(self, hidden_state):
  181. hidden_state = self.layers(hidden_state)
  182. return hidden_state
  183. class RegNetEncoder(nn.Module):
  184. def __init__(self, config: RegNetConfig):
  185. super().__init__()
  186. self.stages = nn.ModuleList([])
  187. # based on `downsample_in_first_stage`, the first layer of the first stage may or may not downsample the input
  188. self.stages.append(
  189. RegNetStage(
  190. config,
  191. config.embedding_size,
  192. config.hidden_sizes[0],
  193. stride=2 if config.downsample_in_first_stage else 1,
  194. depth=config.depths[0],
  195. )
  196. )
  197. in_out_channels = zip(config.hidden_sizes, config.hidden_sizes[1:])
  198. for (in_channels, out_channels), depth in zip(in_out_channels, config.depths[1:]):
  199. self.stages.append(RegNetStage(config, in_channels, out_channels, depth=depth))
  200. def forward(
  201. self, hidden_state: Tensor, output_hidden_states: bool = False, return_dict: bool = True
  202. ) -> BaseModelOutputWithNoAttention:
  203. hidden_states = () if output_hidden_states else None
  204. for stage_module in self.stages:
  205. if output_hidden_states:
  206. hidden_states = hidden_states + (hidden_state,)
  207. hidden_state = stage_module(hidden_state)
  208. if output_hidden_states:
  209. hidden_states = hidden_states + (hidden_state,)
  210. if not return_dict:
  211. return tuple(v for v in [hidden_state, hidden_states] if v is not None)
  212. return BaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states)
  213. @auto_docstring
  214. class RegNetPreTrainedModel(PreTrainedModel):
  215. config: RegNetConfig
  216. base_model_prefix = "regnet"
  217. main_input_name = "pixel_values"
  218. _no_split_modules = ["RegNetYLayer"]
  219. @torch.no_grad()
  220. def _init_weights(self, module):
  221. if isinstance(module, nn.Conv2d):
  222. init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
  223. # copied from the `reset_parameters` method of `class Linear(Module)` in `torch`.
  224. elif isinstance(module, nn.Linear):
  225. init.kaiming_uniform_(module.weight, a=math.sqrt(5))
  226. if module.bias is not None:
  227. fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(module.weight)
  228. bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
  229. init.uniform_(module.bias, -bound, bound)
  230. elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
  231. init.constant_(module.weight, 1)
  232. init.constant_(module.bias, 0)
  233. if getattr(module, "running_mean", None) is not None:
  234. init.zeros_(module.running_mean)
  235. init.ones_(module.running_var)
  236. init.zeros_(module.num_batches_tracked)
  237. @auto_docstring
  238. # Copied from transformers.models.resnet.modeling_resnet.ResNetModel with RESNET->REGNET,ResNet->RegNet
  239. class RegNetModel(RegNetPreTrainedModel):
  240. def __init__(self, config):
  241. super().__init__(config)
  242. self.config = config
  243. self.embedder = RegNetEmbeddings(config)
  244. self.encoder = RegNetEncoder(config)
  245. self.pooler = nn.AdaptiveAvgPool2d((1, 1))
  246. # Initialize weights and apply final processing
  247. self.post_init()
  248. @auto_docstring
  249. def forward(
  250. self,
  251. pixel_values: Tensor,
  252. output_hidden_states: bool | None = None,
  253. return_dict: bool | None = None,
  254. **kwargs,
  255. ) -> BaseModelOutputWithPoolingAndNoAttention:
  256. output_hidden_states = (
  257. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  258. )
  259. return_dict = return_dict if return_dict is not None else self.config.return_dict
  260. embedding_output = self.embedder(pixel_values)
  261. encoder_outputs = self.encoder(
  262. embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict
  263. )
  264. last_hidden_state = encoder_outputs[0]
  265. pooled_output = self.pooler(last_hidden_state)
  266. if not return_dict:
  267. return (last_hidden_state, pooled_output) + encoder_outputs[1:]
  268. return BaseModelOutputWithPoolingAndNoAttention(
  269. last_hidden_state=last_hidden_state,
  270. pooler_output=pooled_output,
  271. hidden_states=encoder_outputs.hidden_states,
  272. )
  273. @auto_docstring(
  274. custom_intro="""
  275. RegNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
  276. ImageNet.
  277. """
  278. )
  279. # Copied from transformers.models.resnet.modeling_resnet.ResNetForImageClassification with RESNET->REGNET,ResNet->RegNet,resnet->regnet
  280. class RegNetForImageClassification(RegNetPreTrainedModel):
  281. def __init__(self, config):
  282. super().__init__(config)
  283. self.num_labels = config.num_labels
  284. self.regnet = RegNetModel(config)
  285. # classification head
  286. self.classifier = nn.Sequential(
  287. nn.Flatten(),
  288. nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity(),
  289. )
  290. # initialize weights and apply final processing
  291. self.post_init()
  292. @auto_docstring
  293. def forward(
  294. self,
  295. pixel_values: torch.FloatTensor | None = None,
  296. labels: torch.LongTensor | None = None,
  297. output_hidden_states: bool | None = None,
  298. return_dict: bool | None = None,
  299. **kwargs,
  300. ) -> ImageClassifierOutputWithNoAttention:
  301. r"""
  302. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  303. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  304. config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  305. """
  306. return_dict = return_dict if return_dict is not None else self.config.return_dict
  307. outputs = self.regnet(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
  308. pooled_output = outputs.pooler_output if return_dict else outputs[1]
  309. logits = self.classifier(pooled_output)
  310. loss = None
  311. if labels is not None:
  312. loss = self.loss_function(labels, logits, self.config)
  313. if not return_dict:
  314. output = (logits,) + outputs[2:]
  315. return (loss,) + output if loss is not None else output
  316. return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
# Names exported by `from <this module> import *`.
__all__ = ["RegNetForImageClassification", "RegNetModel", "RegNetPreTrainedModel"]