| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454 |
- # Copyright 2022 Microsoft Research, Inc. and The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch ResNet model."""
- import math
- import torch
- from torch import Tensor, nn
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...backbone_utils import BackboneMixin, filter_output_hidden_states
- from ...modeling_outputs import (
- BackboneOutput,
- BaseModelOutputWithNoAttention,
- BaseModelOutputWithPoolingAndNoAttention,
- ImageClassifierOutputWithNoAttention,
- )
- from ...modeling_utils import PreTrainedModel
- from ...utils import auto_docstring, logging
- from ...utils.generic import can_return_tuple
- from .configuration_resnet import ResNetConfig
# Module-level logger obtained through the library's `logging.get_logger(__name__)` helper.
logger = logging.get_logger(__name__)
- class ResNetConvLayer(nn.Module):
- def __init__(
- self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: str = "relu"
- ):
- super().__init__()
- self.convolution = nn.Conv2d(
- in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, bias=False
- )
- self.normalization = nn.BatchNorm2d(out_channels)
- self.activation = ACT2FN[activation] if activation is not None else nn.Identity()
- def forward(self, input: Tensor) -> Tensor:
- hidden_state = self.convolution(input)
- hidden_state = self.normalization(hidden_state)
- hidden_state = self.activation(hidden_state)
- return hidden_state
class ResNetEmbeddings(nn.Module):
    """
    ResNet stem: a single `7x7`, stride-2 `ResNetConvLayer` followed by a `3x3`, stride-2 max pooling.

    Args:
        config ([`ResNetConfig`]):
            Supplies `num_channels` (expected input channels), `embedding_size` (stem output channels) and
            `hidden_act` (stem activation).
    """

    def __init__(self, config: ResNetConfig):
        super().__init__()
        stem_conv = ResNetConvLayer(
            config.num_channels,
            config.embedding_size,
            kernel_size=7,
            stride=2,
            activation=config.hidden_act,
        )
        self.embedder = stem_conv
        self.pooler = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Remembered so `forward` can validate the channel dimension of incoming pixel values.
        self.num_channels = config.num_channels

    def forward(self, pixel_values: Tensor) -> Tensor:
        # Fail early with a readable message instead of a shape error inside the convolution.
        if pixel_values.shape[1] != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        return self.pooler(self.embedder(pixel_values))
class ResNetShortCut(nn.Module):
    """
    Projection shortcut for a residual connection: a bias-free `1x1` convolution followed by batch normalization.

    It maps the residual input to `out_channels` and, when `stride > 1` (the default is 2), also downsamples it
    spatially so it can be added to the output of the residual branch.
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int = 2):
        super().__init__()
        projection = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
        batch_norm = nn.BatchNorm2d(out_channels)
        self.convolution = projection
        self.normalization = batch_norm

    def forward(self, input: Tensor) -> Tensor:
        return self.normalization(self.convolution(input))
class ResNetBasicLayer(nn.Module):
    """
    Classic ResNet residual layer: two `3x3` convolutions plus a shortcut, followed by a final activation.

    The shortcut is a `ResNetShortCut` projection whenever the channel count changes or `stride != 1`, and an
    identity otherwise. `stride` is applied by the first convolution.
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1, activation: str = "relu"):
        super().__init__()
        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = ResNetShortCut(in_channels, out_channels, stride=stride)
        else:
            self.shortcut = nn.Identity()
        # NOTE: the first conv uses `ResNetConvLayer`'s default activation rather than `activation`. The second conv
        # is left un-activated because `activation` is applied after the residual addition.
        self.layer = nn.Sequential(
            ResNetConvLayer(in_channels, out_channels, stride=stride),
            ResNetConvLayer(out_channels, out_channels, activation=None),
        )
        self.activation = ACT2FN[activation]

    def forward(self, hidden_state):
        branch = self.layer(hidden_state)
        branch += self.shortcut(hidden_state)
        return self.activation(branch)
class ResNetBottleNeckLayer(nn.Module):
    """
    Classic ResNet bottleneck layer: a `1x1` -> `3x3` -> `1x1` convolution stack plus a shortcut, followed by a
    final activation.

    The first `1x1` convolution reduces the channels to `out_channels // reduction` so the `3x3` convolution is
    cheaper. The last `1x1` convolution maps them back to `out_channels`. The stride is applied by the `3x3`
    convolution, or by the first `1x1` convolution when `downsample_in_bottleneck` is true. The shortcut is a
    `ResNetShortCut` projection whenever the channel count changes or `stride != 1`, and an identity otherwise.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        activation: str = "relu",
        reduction: int = 4,
        downsample_in_bottleneck: bool = False,
    ):
        super().__init__()
        needs_projection = stride != 1 or in_channels != out_channels
        bottleneck_channels = out_channels // reduction
        # Only one of the first two convolutions carries `stride`; the other always uses 1.
        first_stride, middle_stride = (stride, 1) if downsample_in_bottleneck else (1, stride)
        self.shortcut = (
            ResNetShortCut(in_channels, out_channels, stride=stride) if needs_projection else nn.Identity()
        )
        # NOTE: the first two convs use `ResNetConvLayer`'s default activation rather than `activation`. The last one
        # is left un-activated because `activation` is applied after the residual addition.
        self.layer = nn.Sequential(
            ResNetConvLayer(in_channels, bottleneck_channels, kernel_size=1, stride=first_stride),
            ResNetConvLayer(bottleneck_channels, bottleneck_channels, stride=middle_stride),
            ResNetConvLayer(bottleneck_channels, out_channels, kernel_size=1, activation=None),
        )
        self.activation = ACT2FN[activation]

    def forward(self, hidden_state):
        branch = self.layer(hidden_state)
        branch += self.shortcut(hidden_state)
        return self.activation(branch)
class ResNetStage(nn.Module):
    """
    A ResNet stage: `depth` residual layers applied one after another.

    Only the first layer changes the channel count (`in_channels` -> `out_channels`) and applies `stride`. The
    remaining `depth - 1` layers map `out_channels` to `out_channels` with their default stride. The layer class
    (bottleneck or basic) is chosen from `config.layer_type`.
    """

    def __init__(
        self,
        config: ResNetConfig,
        in_channels: int,
        out_channels: int,
        stride: int = 2,
        depth: int = 2,
    ):
        super().__init__()
        use_bottleneck = config.layer_type == "bottleneck"
        layer_cls = ResNetBottleNeckLayer if use_bottleneck else ResNetBasicLayer

        first_layer_kwargs = {"stride": stride, "activation": config.hidden_act}
        if use_bottleneck:
            # Only the bottleneck layer accepts this option.
            first_layer_kwargs["downsample_in_bottleneck"] = config.downsample_in_bottleneck
        stacked = [layer_cls(in_channels, out_channels, **first_layer_kwargs)]
        stacked.extend(layer_cls(out_channels, out_channels, activation=config.hidden_act) for _ in range(depth - 1))
        self.layers = nn.Sequential(*stacked)

    def forward(self, input: Tensor) -> Tensor:
        output = input
        for residual_layer in self.layers:
            output = residual_layer(output)
        return output
class ResNetEncoder(nn.Module):
    """
    Chain of `ResNetStage`s built from `config.hidden_sizes` and `config.depths`.

    The first stage maps `config.embedding_size` channels to `config.hidden_sizes[0]` and downsamples only when
    `config.downsample_in_first_stage` is set. Each following stage maps `hidden_sizes[i - 1]` to `hidden_sizes[i]`
    with `ResNetStage`'s default stride.
    """

    def __init__(self, config: ResNetConfig):
        super().__init__()
        self.stages = nn.ModuleList()
        # `downsample_in_first_stage` decides whether the first layer of the first stage downsamples its input.
        self.stages.append(
            ResNetStage(
                config,
                config.embedding_size,
                config.hidden_sizes[0],
                stride=2 if config.downsample_in_first_stage else 1,
                depth=config.depths[0],
            )
        )
        for stage_in, stage_out, stage_depth in zip(config.hidden_sizes, config.hidden_sizes[1:], config.depths[1:]):
            self.stages.append(ResNetStage(config, stage_in, stage_out, depth=stage_depth))

    def forward(
        self, hidden_state: Tensor, output_hidden_states: bool = False, return_dict: bool = True
    ) -> BaseModelOutputWithNoAttention:
        # When requested, collect the input of every stage plus the final output (num_stages + 1 entries).
        all_hidden_states = [] if output_hidden_states else None
        for stage in self.stages:
            if all_hidden_states is not None:
                all_hidden_states.append(hidden_state)
            hidden_state = stage(hidden_state)
        if all_hidden_states is not None:
            all_hidden_states = tuple(all_hidden_states) + (hidden_state,)

        if return_dict:
            return BaseModelOutputWithNoAttention(
                last_hidden_state=hidden_state,
                hidden_states=all_hidden_states,
            )
        return tuple(item for item in (hidden_state, all_hidden_states) if item is not None)
@auto_docstring
class ResNetPreTrainedModel(PreTrainedModel):
    config: ResNetConfig
    base_model_prefix = "resnet"
    main_input_name = "pixel_values"
    input_modalities = ("image",)
    _no_split_modules = ["ResNetConvLayer", "ResNetShortCut"]

    @torch.no_grad()
    def _init_weights(self, module):
        # Convolutions: Kaiming-normal init in "fan_out" mode with the ReLU gain.
        if isinstance(module, nn.Conv2d):
            init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            return

        # Linear layers: copied from the `reset_parameters` method of `class Linear(Module)` in `torch`.
        if isinstance(module, nn.Linear):
            init.kaiming_uniform_(module.weight, a=math.sqrt(5))
            if module.bias is None:
                return
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(module.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(module.bias, -bound, bound)
            return

        # Matched by class name rather than `isinstance`, as some Detr models replace the BatchNorm2d by their own.
        if "BatchNorm" not in module.__class__.__name__:
            return
        init.ones_(module.weight)
        init.zeros_(module.bias)
        init.zeros_(module.running_mean)
        init.ones_(module.running_var)
        if getattr(module, "num_batches_tracked", None) is not None:
            init.zeros_(module.num_batches_tracked)
@auto_docstring
class ResNetModel(ResNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # stem -> residual stages -> global average pooling to a `(1, 1)` spatial map
        self.embedder = ResNetEmbeddings(config)
        self.encoder = ResNetEncoder(config)
        self.pooler = nn.AdaptiveAvgPool2d((1, 1))
        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: Tensor,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> BaseModelOutputWithPoolingAndNoAttention:
        # Unspecified flags fall back to the model configuration.
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.return_dict

        encoder_outputs = self.encoder(
            self.embedder(pixel_values), output_hidden_states=output_hidden_states, return_dict=return_dict
        )
        # Index 0 is the last hidden state for both the tuple and the ModelOutput form.
        last_hidden_state = encoder_outputs[0]
        pooled_output = self.pooler(last_hidden_state)

        if return_dict:
            return BaseModelOutputWithPoolingAndNoAttention(
                last_hidden_state=last_hidden_state,
                pooler_output=pooled_output,
                hidden_states=encoder_outputs.hidden_states,
            )
        return (last_hidden_state, pooled_output) + encoder_outputs[1:]
@auto_docstring(
    custom_intro="""
    ResNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
    ImageNet.
    """
)
class ResNetForImageClassification(ResNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.resnet = ResNetModel(config)
        # classification head: flatten the pooled features, then project them to `num_labels` logits.
        # With `num_labels <= 0` the projection is an identity, so the head returns the flattened features.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity(),
        )
        # initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> ImageClassifierOutputWithNoAttention:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        # `pixel_values` defaults to `None` only to keep the common signature; without this check a missing input
        # fails later with an opaque `AttributeError` inside the stem.
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        return_dict = return_dict if return_dict is not None else self.config.return_dict

        outputs = self.resnet(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
        # In tuple form the pooled output sits at index 1 (after the last hidden state).
        pooled_output = outputs.pooler_output if return_dict else outputs[1]
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss = self.loss_function(labels, logits, self.config)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return (loss,) + output if loss is not None else output

        return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
@auto_docstring(
    custom_intro="""
    ResNet backbone, to be used with frameworks like DETR and MaskFormer.
    """
)
class ResNetBackbone(BackboneMixin, ResNetPreTrainedModel):
    has_attentions = False

    def __init__(self, config):
        super().__init__(config)

        # Channel counts of the stem output followed by those of every stage.
        self.num_features = [config.embedding_size] + config.hidden_sizes
        self.embedder = ResNetEmbeddings(config)
        self.encoder = ResNetEncoder(config)

        # initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @filter_output_hidden_states
    @auto_docstring
    def forward(
        self,
        pixel_values: Tensor,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> BackboneOutput:
        r"""
        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))
        >>> processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
        >>> model = AutoBackbone.from_pretrained(
        ...     "microsoft/resnet-50", out_features=["stage1", "stage2", "stage3", "stage4"]
        ... )
        >>> inputs = processor(image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> feature_maps = outputs.feature_maps
        >>> list(feature_maps[-1].shape)
        [1, 2048, 7, 7]
        ```"""
        if return_dict is None:
            return_dict = self.config.return_dict
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        embedding_output = self.embedder(pixel_values)
        # Hidden states are always collected: the feature maps are picked from them by stage position.
        encoder_outputs = self.encoder(embedding_output, output_hidden_states=True, return_dict=True)
        all_hidden_states = encoder_outputs.hidden_states

        feature_maps = tuple(
            all_hidden_states[idx] for idx, stage_name in enumerate(self.stage_names) if stage_name in self.out_features
        )

        if return_dict:
            return BackboneOutput(
                feature_maps=feature_maps,
                hidden_states=all_hidden_states if output_hidden_states else None,
                attentions=None,
            )

        output = (feature_maps,)
        if output_hidden_states:
            output += (all_hidden_states,)
        return output
# Public API exported by this module.
__all__ = ["ResNetForImageClassification", "ResNetModel", "ResNetPreTrainedModel", "ResNetBackbone"]
|