| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417 |
- # Copyright 2024 the Fast authors and The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch TextNet model."""
- from typing import Any
- import torch
- import torch.nn as nn
- from torch import Tensor
- from ...activations import ACT2CLS
- from ...backbone_utils import BackboneMixin, filter_output_hidden_states
- from ...modeling_outputs import (
- BackboneOutput,
- BaseModelOutputWithNoAttention,
- BaseModelOutputWithPoolingAndNoAttention,
- ImageClassifierOutputWithNoAttention,
- )
- from ...modeling_utils import PreTrainedModel
- from ...utils import auto_docstring, logging
- from ...utils.generic import can_return_tuple
- from .configuration_textnet import TextNetConfig
# Module-level logger from the library's `logging` utilities, named after this module
# (not referenced by the code visible in this file).
logger = logging.get_logger(__name__)
class TextNetConvLayer(nn.Module):
    """Stem of TextNet: one convolution, batch norm, then an optional activation.

    The convolution maps `config.stem_num_channels` input channels to `config.stem_out_channels`
    output channels with kernel `config.stem_kernel_size` and stride `config.stem_stride`. Padding is
    `kernel_size // 2` per spatial dimension, which keeps the spatial size for odd kernels at stride 1.
    `config.stem_act_func` selects the activation via `ACT2CLS`; `None` means no activation.
    """

    def __init__(self, config: TextNetConfig):
        super().__init__()

        self.kernel_size = config.stem_kernel_size
        self.stride = config.stem_stride
        self.activation_function = config.stem_act_func

        # Padding must be derived from the *stem* kernel size. The previous code read the unrelated
        # `config.kernel_size` attribute in the sequence case. Lists are accepted alongside tuples
        # because tuples stored in a config come back as lists after a JSON round-trip.
        stem_kernel_size = config.stem_kernel_size
        if isinstance(stem_kernel_size, (tuple, list)):
            padding = (stem_kernel_size[0] // 2, stem_kernel_size[1] // 2)
        else:
            padding = stem_kernel_size // 2

        self.conv = nn.Conv2d(
            config.stem_num_channels,
            config.stem_out_channels,
            kernel_size=stem_kernel_size,
            stride=config.stem_stride,
            padding=padding,
            bias=False,
        )
        self.batch_norm = nn.BatchNorm2d(config.stem_out_channels, eps=config.batch_norm_eps)

        self.activation = nn.Identity()
        if self.activation_function is not None:
            self.activation = ACT2CLS[self.activation_function]()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply conv -> batch norm -> activation to `hidden_states`."""
        hidden_states = self.conv(hidden_states)
        hidden_states = self.batch_norm(hidden_states)
        return self.activation(hidden_states)
class TextNetRepConvLayer(nn.Module):
    r"""
    This layer supports re-parameterization by combining multiple convolutional branches
    (e.g., main convolution, vertical, horizontal, and identity branches) during training.
    At inference time, these branches can be collapsed into a single convolution for
    efficiency, as per the re-parameterization paradigm.

    The "Rep" in the name stands for "re-parameterization" (introduced by RepVGG).

    Args:
        config (`TextNetConfig`): model configuration; only `batch_norm_eps` is read here.
        in_channels (`int`): number of input channels.
        out_channels (`int`): number of output channels.
        kernel_size (`int` or `(int, int)`): main kernel size as `(height, width)`; a bare int means a
            square kernel.
        stride (`int`): stride shared by every convolutional branch.
    """

    def __init__(
        self,
        config: TextNetConfig,
        in_channels: int,
        out_channels: int,
        kernel_size: int | tuple[int, int],
        stride: int,
    ):
        super().__init__()

        self.num_channels = in_channels
        self.out_channels = out_channels
        # Keep the value exactly as passed, for backward compatibility with code that inspects it.
        self.kernel_size = kernel_size
        self.stride = stride

        # The branch layout needs the kernel height and width separately. Indexing a bare int (as the
        # old `int` annotation invited) used to raise TypeError, so normalize it to a square kernel.
        if isinstance(kernel_size, int):
            kernel_height, kernel_width = kernel_size, kernel_size
        else:
            kernel_height, kernel_width = kernel_size

        padding = ((kernel_height - 1) // 2, (kernel_width - 1) // 2)

        self.activation_function = nn.ReLU()

        self.main_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(kernel_height, kernel_width),
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.main_batch_norm = nn.BatchNorm2d(num_features=out_channels, eps=config.batch_norm_eps)

        vertical_padding = ((kernel_height - 1) // 2, 0)
        horizontal_padding = (0, (kernel_width - 1) // 2)

        # A (kernel_height, 1) branch would duplicate the main conv when the main kernel is already
        # one column wide, so it is only built when kernel_width != 1.
        if kernel_width != 1:
            self.vertical_conv = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(kernel_height, 1),
                stride=stride,
                padding=vertical_padding,
                bias=False,
            )
            self.vertical_batch_norm = nn.BatchNorm2d(num_features=out_channels, eps=config.batch_norm_eps)
        else:
            self.vertical_conv, self.vertical_batch_norm = None, None

        # Likewise, a (1, kernel_width) branch is only built when the main kernel is taller than one row.
        if kernel_height != 1:
            self.horizontal_conv = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(1, kernel_width),
                stride=stride,
                padding=horizontal_padding,
                bias=False,
            )
            self.horizontal_batch_norm = nn.BatchNorm2d(num_features=out_channels, eps=config.batch_norm_eps)
        else:
            self.horizontal_conv, self.horizontal_batch_norm = None, None

        # The identity branch (batch norm of the input) can only be added when the output has the same
        # channel count and no stride is applied.
        self.rbr_identity = (
            nn.BatchNorm2d(num_features=in_channels, eps=config.batch_norm_eps)
            if out_channels == in_channels and stride == 1
            else None
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Sum every active branch (each conv followed by its own batch norm) and apply ReLU."""
        main_outputs = self.main_conv(hidden_states)
        main_outputs = self.main_batch_norm(main_outputs)

        # applies a convolution with a vertical kernel
        if self.vertical_conv is not None:
            vertical_outputs = self.vertical_conv(hidden_states)
            vertical_outputs = self.vertical_batch_norm(vertical_outputs)
            main_outputs = main_outputs + vertical_outputs

        # applies a convolution with a horizontal kernel
        if self.horizontal_conv is not None:
            horizontal_outputs = self.horizontal_conv(hidden_states)
            horizontal_outputs = self.horizontal_batch_norm(horizontal_outputs)
            main_outputs = main_outputs + horizontal_outputs

        if self.rbr_identity is not None:
            main_outputs = main_outputs + self.rbr_identity(hidden_states)

        return self.activation_function(main_outputs)
class TextNetStage(nn.Module):
    """A chain of `TextNetRepConvLayer`s forming stage number `depth`.

    The per-layer kernel sizes and strides come from `config.conv_layer_kernel_sizes[depth]` and
    `config.conv_layer_strides[depth]`. The first layer goes from `config.hidden_sizes[depth]` to
    `config.hidden_sizes[depth + 1]` channels; every later layer keeps the output width.
    """

    def __init__(self, config: TextNetConfig, depth: int):
        super().__init__()

        layer_kernel_sizes = config.conv_layer_kernel_sizes[depth]
        layer_strides = config.conv_layer_strides[depth]
        in_size = config.hidden_sizes[depth]
        out_size = config.hidden_sizes[depth + 1]

        layer_count = len(layer_kernel_sizes)
        in_sizes = [in_size, *([out_size] * (layer_count - 1))]
        out_sizes = [out_size for _ in range(layer_count)]

        self.stage = nn.ModuleList(
            [
                TextNetRepConvLayer(config, layer_in, layer_out, layer_kernel, layer_stride)
                for layer_in, layer_out, layer_kernel, layer_stride in zip(
                    in_sizes, out_sizes, layer_kernel_sizes, layer_strides
                )
            ]
        )

    def forward(self, hidden_state):
        """Run `hidden_state` through every layer of the stage in order."""
        output = hidden_state
        for layer in self.stage:
            output = layer(output)
        return output
class TextNetEncoder(nn.Module):
    """Sequence of `TextNetStage`s, one per entry of `config.conv_layer_kernel_sizes`."""

    def __init__(self, config: TextNetConfig):
        super().__init__()
        self.stages = nn.ModuleList(
            [TextNetStage(config, stage_ix) for stage_ix in range(len(config.conv_layer_kernel_sizes))]
        )

    def forward(
        self,
        hidden_state: torch.Tensor,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
    ) -> tuple | BaseModelOutputWithNoAttention:
        """Run all stages on `hidden_state`.

        Args:
            hidden_state: input to the first stage (the stem output when called from `TextNetModel`).
            output_hidden_states: when truthy, also return a list holding the input followed by the
                output of every stage.
            return_dict: when truthy, return a `BaseModelOutputWithNoAttention`; otherwise a tuple
                `(last_hidden_state,)` or `(last_hidden_state, hidden_states)`.
        """
        # Collect per-stage activations only when requested. This makes the ModelOutput path honour
        # `output_hidden_states` just like the tuple path; it used to return them unconditionally.
        all_hidden_states = [hidden_state] if output_hidden_states else None
        for stage in self.stages:
            hidden_state = stage(hidden_state)
            if all_hidden_states is not None:
                all_hidden_states.append(hidden_state)

        if not return_dict:
            return (hidden_state, all_hidden_states) if output_hidden_states else (hidden_state,)

        return BaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=all_hidden_states)
@auto_docstring
class TextNetPreTrainedModel(PreTrainedModel):
    # Shared base class for the TextNet models in this file; it only declares class-level metadata
    # read by `PreTrainedModel` and defines no methods of its own.

    # Config class for this model family, declared via annotation (presumably how `PreTrainedModel`
    # discovers the config class — confirm against the PreTrainedModel version in use).
    config: TextNetConfig
    # Matches the `self.textnet` attribute under which the subclasses in this file store their base
    # `TextNetModel`.
    base_model_prefix = "textnet"
    # Name of the primary input tensor taken by the `forward` methods of the subclasses.
    main_input_name = "pixel_values"
@auto_docstring
class TextNetModel(TextNetPreTrainedModel):
    # Bare TextNet: stem convolution -> encoder stages -> 2x2 adaptive average pooling of the final
    # feature map.

    def __init__(self, config):
        super().__init__(config)
        self.stem = TextNetConvLayer(config)
        self.encoder = TextNetEncoder(config)
        self.pooler = nn.AdaptiveAvgPool2d((2, 2))
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: Tensor,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple[Any, list[Any]] | tuple[Any] | BaseModelOutputWithPoolingAndNoAttention:
        # Any flag left unset falls back to the config default.
        if return_dict is None:
            return_dict = self.config.return_dict
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        stem_output = self.stem(pixel_values)
        encoder_outputs = self.encoder(
            stem_output, output_hidden_states=output_hidden_states, return_dict=return_dict
        )

        last_hidden_state = encoder_outputs[0]
        pooled_output = self.pooler(last_hidden_state)
        # The per-stage activations are read from the encoder output only when the caller asked for them.
        hidden_states = encoder_outputs[1] if output_hidden_states else None

        if return_dict:
            return BaseModelOutputWithPoolingAndNoAttention(
                last_hidden_state=last_hidden_state,
                pooler_output=pooled_output,
                hidden_states=hidden_states,
            )

        if output_hidden_states:
            return (last_hidden_state, pooled_output, hidden_states)
        return (last_hidden_state, pooled_output)
@auto_docstring(
    custom_intro="""
    TextNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
    ImageNet.
    """
)
class TextNetForImageClassification(TextNetPreTrainedModel):
    # Classification head: global average pool -> flatten -> linear layer (or identity when the config
    # declares no labels).

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.textnet = TextNetModel(config)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        if config.num_labels > 0:
            self.fc = nn.Linear(config.hidden_sizes[-1], config.num_labels)
        else:
            self.fc = nn.Identity()

        # classification head: pooling then flattening, applied in this order in `forward`
        self.classifier = nn.ModuleList([self.avg_pool, self.flatten])

        # initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> ImageClassifierOutputWithNoAttention:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Examples:
        ```python
        >>> import torch
        >>> import httpx
        >>> from io import BytesIO
        >>> from transformers import TextNetForImageClassification, TextNetImageProcessor
        >>> from PIL import Image

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> processor = TextNetImageProcessor.from_pretrained("czczup/textnet-base")
        >>> model = TextNetForImageClassification.from_pretrained("czczup/textnet-base")

        >>> inputs = processor(images=image, return_tensors="pt")
        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        >>> outputs.logits.shape
        torch.Size([1, 2])
        ```"""
        if return_dict is None:
            return_dict = self.config.return_dict

        outputs = self.textnet(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)

        # Reduce the last feature map with the head layers, then project to logits.
        features = outputs[0]
        for head_layer in self.classifier:
            features = head_layer(features)
        logits = self.fc(features)

        loss = self.loss_function(labels, logits, self.config) if labels is not None else None

        if return_dict:
            return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)

        # Tuple mode: logits followed by whatever the base model returned after its pooled output.
        output = (logits,) + outputs[2:]
        if loss is None:
            return output
        return (loss,) + output
@auto_docstring(
    custom_intro="""
    TextNet backbone, to be used with frameworks like DETR and MaskFormer.
    """
)
class TextNetBackbone(BackboneMixin, TextNetPreTrainedModel):
    # Exposes selected intermediate feature maps of `TextNetModel`; the model has no attention layers.
    has_attentions = False

    def __init__(self, config):
        super().__init__(config)

        self.textnet = TextNetModel(config)
        self.num_features = config.hidden_sizes

        # initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @filter_output_hidden_states
    @auto_docstring
    def forward(
        self,
        pixel_values: Tensor,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple[tuple] | BackboneOutput:
        r"""
        Examples:

        ```python
        >>> import torch
        >>> import httpx
        >>> from io import BytesIO
        >>> from PIL import Image

        >>> from transformers import AutoImageProcessor, AutoBackbone

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> processor = AutoImageProcessor.from_pretrained("czczup/textnet-base")
        >>> model = AutoBackbone.from_pretrained("czczup/textnet-base")

        >>> inputs = processor(image, return_tensors="pt")
        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # Hidden states are always requested from the base model because the feature maps are picked
        # from them; whether they are also returned is decided by `output_hidden_states` below.
        outputs = self.textnet(pixel_values, output_hidden_states=True, return_dict=return_dict)
        hidden_states = outputs.hidden_states if return_dict else outputs[2]

        # hidden_states[0] is the stem output (the encoder's input), followed by one entry per stage.
        # This assumes `self.stage_names` (set up by BackboneMixin, not visible here) uses the same order.
        feature_maps = tuple(
            hidden_states[idx] for idx, stage in enumerate(self.stage_names) if stage in self.out_features
        )

        if not return_dict:
            output = (feature_maps,)
            if output_hidden_states:
                output += (hidden_states,)
            return output

        return BackboneOutput(
            feature_maps=feature_maps,
            hidden_states=hidden_states if output_hidden_states else None,
            attentions=None,
        )
# Public API of this module: the names exported by `from <module> import *`.
__all__ = ["TextNetBackbone", "TextNetModel", "TextNetPreTrainedModel", "TextNetForImageClassification"]
|