modeling_resnet.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454
  1. # Copyright 2022 Microsoft Research, Inc. and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch ResNet model."""
  15. import math
  16. import torch
  17. from torch import Tensor, nn
  18. from ... import initialization as init
  19. from ...activations import ACT2FN
  20. from ...backbone_utils import BackboneMixin, filter_output_hidden_states
  21. from ...modeling_outputs import (
  22. BackboneOutput,
  23. BaseModelOutputWithNoAttention,
  24. BaseModelOutputWithPoolingAndNoAttention,
  25. ImageClassifierOutputWithNoAttention,
  26. )
  27. from ...modeling_utils import PreTrainedModel
  28. from ...utils import auto_docstring, logging
  29. from ...utils.generic import can_return_tuple
  30. from .configuration_resnet import ResNetConfig
  31. logger = logging.get_logger(__name__)
  32. class ResNetConvLayer(nn.Module):
  33. def __init__(
  34. self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: str = "relu"
  35. ):
  36. super().__init__()
  37. self.convolution = nn.Conv2d(
  38. in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, bias=False
  39. )
  40. self.normalization = nn.BatchNorm2d(out_channels)
  41. self.activation = ACT2FN[activation] if activation is not None else nn.Identity()
  42. def forward(self, input: Tensor) -> Tensor:
  43. hidden_state = self.convolution(input)
  44. hidden_state = self.normalization(hidden_state)
  45. hidden_state = self.activation(hidden_state)
  46. return hidden_state
  47. class ResNetEmbeddings(nn.Module):
  48. """
  49. ResNet Embeddings (stem) composed of a single aggressive convolution.
  50. """
  51. def __init__(self, config: ResNetConfig):
  52. super().__init__()
  53. self.embedder = ResNetConvLayer(
  54. config.num_channels, config.embedding_size, kernel_size=7, stride=2, activation=config.hidden_act
  55. )
  56. self.pooler = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  57. self.num_channels = config.num_channels
  58. def forward(self, pixel_values: Tensor) -> Tensor:
  59. num_channels = pixel_values.shape[1]
  60. if num_channels != self.num_channels:
  61. raise ValueError(
  62. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  63. )
  64. embedding = self.embedder(pixel_values)
  65. embedding = self.pooler(embedding)
  66. return embedding
  67. class ResNetShortCut(nn.Module):
  68. """
  69. ResNet shortcut, used to project the residual features to the correct size. If needed, it is also used to
  70. downsample the input using `stride=2`.
  71. """
  72. def __init__(self, in_channels: int, out_channels: int, stride: int = 2):
  73. super().__init__()
  74. self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
  75. self.normalization = nn.BatchNorm2d(out_channels)
  76. def forward(self, input: Tensor) -> Tensor:
  77. hidden_state = self.convolution(input)
  78. hidden_state = self.normalization(hidden_state)
  79. return hidden_state
  80. class ResNetBasicLayer(nn.Module):
  81. """
  82. A classic ResNet's residual layer composed by two `3x3` convolutions.
  83. """
  84. def __init__(self, in_channels: int, out_channels: int, stride: int = 1, activation: str = "relu"):
  85. super().__init__()
  86. should_apply_shortcut = in_channels != out_channels or stride != 1
  87. self.shortcut = (
  88. ResNetShortCut(in_channels, out_channels, stride=stride) if should_apply_shortcut else nn.Identity()
  89. )
  90. self.layer = nn.Sequential(
  91. ResNetConvLayer(in_channels, out_channels, stride=stride),
  92. ResNetConvLayer(out_channels, out_channels, activation=None),
  93. )
  94. self.activation = ACT2FN[activation]
  95. def forward(self, hidden_state):
  96. residual = hidden_state
  97. hidden_state = self.layer(hidden_state)
  98. residual = self.shortcut(residual)
  99. hidden_state += residual
  100. hidden_state = self.activation(hidden_state)
  101. return hidden_state
  102. class ResNetBottleNeckLayer(nn.Module):
  103. """
  104. A classic ResNet's bottleneck layer composed by three `3x3` convolutions.
  105. The first `1x1` convolution reduces the input by a factor of `reduction` in order to make the second `3x3`
  106. convolution faster. The last `1x1` convolution remaps the reduced features to `out_channels`. If
  107. `downsample_in_bottleneck` is true, downsample will be in the first layer instead of the second layer.
  108. """
  109. def __init__(
  110. self,
  111. in_channels: int,
  112. out_channels: int,
  113. stride: int = 1,
  114. activation: str = "relu",
  115. reduction: int = 4,
  116. downsample_in_bottleneck: bool = False,
  117. ):
  118. super().__init__()
  119. should_apply_shortcut = in_channels != out_channels or stride != 1
  120. reduces_channels = out_channels // reduction
  121. self.shortcut = (
  122. ResNetShortCut(in_channels, out_channels, stride=stride) if should_apply_shortcut else nn.Identity()
  123. )
  124. self.layer = nn.Sequential(
  125. ResNetConvLayer(
  126. in_channels, reduces_channels, kernel_size=1, stride=stride if downsample_in_bottleneck else 1
  127. ),
  128. ResNetConvLayer(reduces_channels, reduces_channels, stride=stride if not downsample_in_bottleneck else 1),
  129. ResNetConvLayer(reduces_channels, out_channels, kernel_size=1, activation=None),
  130. )
  131. self.activation = ACT2FN[activation]
  132. def forward(self, hidden_state):
  133. residual = hidden_state
  134. hidden_state = self.layer(hidden_state)
  135. residual = self.shortcut(residual)
  136. hidden_state += residual
  137. hidden_state = self.activation(hidden_state)
  138. return hidden_state
  139. class ResNetStage(nn.Module):
  140. """
  141. A ResNet stage composed by stacked layers.
  142. """
  143. def __init__(
  144. self,
  145. config: ResNetConfig,
  146. in_channels: int,
  147. out_channels: int,
  148. stride: int = 2,
  149. depth: int = 2,
  150. ):
  151. super().__init__()
  152. layer = ResNetBottleNeckLayer if config.layer_type == "bottleneck" else ResNetBasicLayer
  153. if config.layer_type == "bottleneck":
  154. first_layer = layer(
  155. in_channels,
  156. out_channels,
  157. stride=stride,
  158. activation=config.hidden_act,
  159. downsample_in_bottleneck=config.downsample_in_bottleneck,
  160. )
  161. else:
  162. first_layer = layer(in_channels, out_channels, stride=stride, activation=config.hidden_act)
  163. self.layers = nn.Sequential(
  164. first_layer, *[layer(out_channels, out_channels, activation=config.hidden_act) for _ in range(depth - 1)]
  165. )
  166. def forward(self, input: Tensor) -> Tensor:
  167. hidden_state = input
  168. for layer in self.layers:
  169. hidden_state = layer(hidden_state)
  170. return hidden_state
  171. class ResNetEncoder(nn.Module):
  172. def __init__(self, config: ResNetConfig):
  173. super().__init__()
  174. self.stages = nn.ModuleList([])
  175. # based on `downsample_in_first_stage` the first layer of the first stage may or may not downsample the input
  176. self.stages.append(
  177. ResNetStage(
  178. config,
  179. config.embedding_size,
  180. config.hidden_sizes[0],
  181. stride=2 if config.downsample_in_first_stage else 1,
  182. depth=config.depths[0],
  183. )
  184. )
  185. in_out_channels = zip(config.hidden_sizes, config.hidden_sizes[1:])
  186. for (in_channels, out_channels), depth in zip(in_out_channels, config.depths[1:]):
  187. self.stages.append(ResNetStage(config, in_channels, out_channels, depth=depth))
  188. def forward(
  189. self, hidden_state: Tensor, output_hidden_states: bool = False, return_dict: bool = True
  190. ) -> BaseModelOutputWithNoAttention:
  191. hidden_states = () if output_hidden_states else None
  192. for stage_module in self.stages:
  193. if output_hidden_states:
  194. hidden_states = hidden_states + (hidden_state,)
  195. hidden_state = stage_module(hidden_state)
  196. if output_hidden_states:
  197. hidden_states = hidden_states + (hidden_state,)
  198. if not return_dict:
  199. return tuple(v for v in [hidden_state, hidden_states] if v is not None)
  200. return BaseModelOutputWithNoAttention(
  201. last_hidden_state=hidden_state,
  202. hidden_states=hidden_states,
  203. )
  204. @auto_docstring
  205. class ResNetPreTrainedModel(PreTrainedModel):
  206. config: ResNetConfig
  207. base_model_prefix = "resnet"
  208. main_input_name = "pixel_values"
  209. input_modalities = ("image",)
  210. _no_split_modules = ["ResNetConvLayer", "ResNetShortCut"]
  211. @torch.no_grad()
  212. def _init_weights(self, module):
  213. if isinstance(module, nn.Conv2d):
  214. init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
  215. # copied from the `reset_parameters` method of `class Linear(Module)` in `torch`.
  216. elif isinstance(module, nn.Linear):
  217. init.kaiming_uniform_(module.weight, a=math.sqrt(5))
  218. if module.bias is not None:
  219. fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(module.weight)
  220. bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
  221. init.uniform_(module.bias, -bound, bound)
  222. # We need to check it like that as some Detr models replace the BatchNorm2d by their own
  223. elif "BatchNorm" in module.__class__.__name__:
  224. init.ones_(module.weight)
  225. init.zeros_(module.bias)
  226. init.zeros_(module.running_mean)
  227. init.ones_(module.running_var)
  228. if getattr(module, "num_batches_tracked", None) is not None:
  229. init.zeros_(module.num_batches_tracked)
  230. @auto_docstring
  231. class ResNetModel(ResNetPreTrainedModel):
  232. def __init__(self, config):
  233. super().__init__(config)
  234. self.config = config
  235. self.embedder = ResNetEmbeddings(config)
  236. self.encoder = ResNetEncoder(config)
  237. self.pooler = nn.AdaptiveAvgPool2d((1, 1))
  238. # Initialize weights and apply final processing
  239. self.post_init()
  240. @auto_docstring
  241. def forward(
  242. self,
  243. pixel_values: Tensor,
  244. output_hidden_states: bool | None = None,
  245. return_dict: bool | None = None,
  246. **kwargs,
  247. ) -> BaseModelOutputWithPoolingAndNoAttention:
  248. output_hidden_states = (
  249. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  250. )
  251. return_dict = return_dict if return_dict is not None else self.config.return_dict
  252. embedding_output = self.embedder(pixel_values)
  253. encoder_outputs = self.encoder(
  254. embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict
  255. )
  256. last_hidden_state = encoder_outputs[0]
  257. pooled_output = self.pooler(last_hidden_state)
  258. if not return_dict:
  259. return (last_hidden_state, pooled_output) + encoder_outputs[1:]
  260. return BaseModelOutputWithPoolingAndNoAttention(
  261. last_hidden_state=last_hidden_state,
  262. pooler_output=pooled_output,
  263. hidden_states=encoder_outputs.hidden_states,
  264. )
  265. @auto_docstring(
  266. custom_intro="""
  267. ResNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
  268. ImageNet.
  269. """
  270. )
  271. class ResNetForImageClassification(ResNetPreTrainedModel):
  272. def __init__(self, config):
  273. super().__init__(config)
  274. self.num_labels = config.num_labels
  275. self.resnet = ResNetModel(config)
  276. # classification head
  277. self.classifier = nn.Sequential(
  278. nn.Flatten(),
  279. nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity(),
  280. )
  281. # initialize weights and apply final processing
  282. self.post_init()
  283. @auto_docstring
  284. def forward(
  285. self,
  286. pixel_values: torch.FloatTensor | None = None,
  287. labels: torch.LongTensor | None = None,
  288. output_hidden_states: bool | None = None,
  289. return_dict: bool | None = None,
  290. **kwargs,
  291. ) -> ImageClassifierOutputWithNoAttention:
  292. r"""
  293. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  294. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  295. config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  296. """
  297. return_dict = return_dict if return_dict is not None else self.config.return_dict
  298. outputs = self.resnet(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
  299. pooled_output = outputs.pooler_output if return_dict else outputs[1]
  300. logits = self.classifier(pooled_output)
  301. loss = None
  302. if labels is not None:
  303. loss = self.loss_function(labels, logits, self.config)
  304. if not return_dict:
  305. output = (logits,) + outputs[2:]
  306. return (loss,) + output if loss is not None else output
  307. return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
  308. @auto_docstring(
  309. custom_intro="""
  310. ResNet backbone, to be used with frameworks like DETR and MaskFormer.
  311. """
  312. )
  313. class ResNetBackbone(BackboneMixin, ResNetPreTrainedModel):
  314. has_attentions = False
  315. def __init__(self, config):
  316. super().__init__(config)
  317. self.num_features = [config.embedding_size] + config.hidden_sizes
  318. self.embedder = ResNetEmbeddings(config)
  319. self.encoder = ResNetEncoder(config)
  320. # initialize weights and apply final processing
  321. self.post_init()
  322. @can_return_tuple
  323. @filter_output_hidden_states
  324. @auto_docstring
  325. def forward(
  326. self,
  327. pixel_values: Tensor,
  328. output_hidden_states: bool | None = None,
  329. return_dict: bool | None = None,
  330. **kwargs,
  331. ) -> BackboneOutput:
  332. r"""
  333. Examples:
  334. ```python
  335. >>> from transformers import AutoImageProcessor, AutoBackbone
  336. >>> import torch
  337. >>> from PIL import Image
  338. >>> import httpx
  339. >>> from io import BytesIO
  340. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  341. >>> with httpx.stream("GET", url) as response:
  342. ... image = Image.open(BytesIO(response.read()))
  343. >>> processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
  344. >>> model = AutoBackbone.from_pretrained(
  345. ... "microsoft/resnet-50", out_features=["stage1", "stage2", "stage3", "stage4"]
  346. ... )
  347. >>> inputs = processor(image, return_tensors="pt")
  348. >>> outputs = model(**inputs)
  349. >>> feature_maps = outputs.feature_maps
  350. >>> list(feature_maps[-1].shape)
  351. [1, 2048, 7, 7]
  352. ```"""
  353. return_dict = return_dict if return_dict is not None else self.config.return_dict
  354. output_hidden_states = (
  355. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  356. )
  357. embedding_output = self.embedder(pixel_values)
  358. outputs = self.encoder(embedding_output, output_hidden_states=True, return_dict=True)
  359. hidden_states = outputs.hidden_states
  360. feature_maps = ()
  361. for idx, stage in enumerate(self.stage_names):
  362. if stage in self.out_features:
  363. feature_maps += (hidden_states[idx],)
  364. if not return_dict:
  365. output = (feature_maps,)
  366. if output_hidden_states:
  367. output += (outputs.hidden_states,)
  368. return output
  369. return BackboneOutput(
  370. feature_maps=feature_maps,
  371. hidden_states=outputs.hidden_states if output_hidden_states else None,
  372. attentions=None,
  373. )
  374. __all__ = ["ResNetForImageClassification", "ResNetModel", "ResNetPreTrainedModel", "ResNetBackbone"]