| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507 |
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # This file was automatically generated from src/transformers/models/pixio/modular_pixio.py.
- # Do NOT edit this file manually as any edits will be overwritten by the generation of
- # the file from the modular. If any change should be done, please apply the change to the
- # modular_pixio.py file directly. One of our CI enforces this.
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # Copyright 2025 Meta AI and The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import collections.abc
- from collections.abc import Callable
- import torch
- from torch import nn
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...backbone_utils import BackboneMixin, filter_output_hidden_states
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
- from ...processing_utils import Unpack
- from ...utils import TransformersKwargs, auto_docstring, is_tracing
- from ...utils.generic import can_return_tuple, merge_with_config_defaults
- from ...utils.output_capturing import capture_outputs
- from .configuration_pixio import PixioConfig
class PixioPatchEmbeddings(nn.Module):
    """
    Patchify an image batch with a strided convolution.

    `pixel_values` of shape `(batch_size, num_channels, height, width)` go through a `Conv2d` whose kernel size and
    stride both equal the patch size. The resulting grid is flattened into a token sequence of shape
    `(batch_size, seq_length, hidden_size)` for the Transformer.
    """

    def __init__(self, config: PixioConfig):
        super().__init__()

        def _as_pair(value):
            # Iterables (e.g. `(h, w)`) are kept untouched; scalars are duplicated into `(value, value)`.
            return value if isinstance(value, collections.abc.Iterable) else (value, value)

        image_size = _as_pair(config.image_size)
        patch_size = _as_pair(config.patch_size)
        grid_height = image_size[0] // patch_size[0]
        grid_width = image_size[1] // patch_size[1]

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = config.num_channels
        self.num_patches = grid_width * grid_height
        self.projection = nn.Conv2d(
            config.num_channels, config.hidden_size, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        """
        Args:
            pixel_values: 4-D image batch `(batch_size, num_channels, height, width)`.
            interpolate_pos_encoding: when `False`, the spatial size must equal the configured `image_size`.

        Raises:
            ValueError: on a channel-count mismatch, or on a size mismatch when `interpolate_pos_encoding` is False.
        """
        _, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
                f" Expected {self.num_channels} but got {num_channels}."
            )
        size_is_wrong = height != self.image_size[0] or width != self.image_size[1]
        if not interpolate_pos_encoding and size_is_wrong:
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model"
                f" ({self.image_size[0]}*{self.image_size[1]})."
            )
        # Conv2d output is (batch, hidden_size, grid_h, grid_w); flatten the grid and move it to the sequence axis.
        patch_grid = self.projection(pixel_values)
        return patch_grid.flatten(start_dim=2).transpose(1, 2)
class PixioEmbeddings(nn.Module):
    """
    Construct the CLS tokens, position and patch embeddings.

    The output sequence is `n_cls_tokens` learned CLS tokens followed by the patch embeddings. Position embeddings
    are added, bicubically resized when the input patch grid differs from the pre-trained one, and dropout is
    applied last.
    """

    def __init__(self, config: PixioConfig) -> None:
        super().__init__()
        # Parameter creation order is unchanged so random initialization consumes the RNG identically.
        self.cls_token = nn.Parameter(torch.randn(1, config.n_cls_tokens, config.hidden_size))
        # No mask token is created; `_init_weights` only touches it when it is not None.
        self.mask_token = None
        self.patch_embeddings = PixioPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        # One learned position vector per CLS token and per patch of the configured image size.
        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + config.n_cls_tokens, config.hidden_size))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.n_cls_tokens = config.n_cls_tokens
        self.patch_size = config.patch_size
        self.config = config

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        Interpolate the pre-trained position encodings so the model can be used on images of a different (e.g.
        higher) resolution. Interpolation runs at torch.float32 precision and is always performed while tracing.

        Args:
            embeddings: concatenated CLS + patch embeddings. Only its sequence length (dim 1) and feature size
                (last dim) are read.
            height: input image height in pixels.
            width: input image width in pixels.

        Returns:
            Position embeddings of shape `(1, n_cls_tokens + (height // patch_size) * (width // patch_size), dim)`.
            The stored parameter is returned unchanged when no resizing is needed.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """
        num_patches = embeddings.shape[1] - self.n_cls_tokens
        num_positions = self.position_embeddings.shape[1] - self.n_cls_tokens

        # Fast path: same patch count on a square input. It is skipped while tracing so the interpolation is always
        # recorded (presumably so traced graphs also work for other input sizes).
        if not is_tracing() and num_patches == num_positions and height == width:
            return self.position_embeddings

        class_pos_embed = self.position_embeddings[:, : self.n_cls_tokens]
        patch_pos_embed = self.position_embeddings[:, self.n_cls_tokens :]

        dim = embeddings.shape[-1]
        new_height = height // self.patch_size
        new_width = width // self.patch_size

        # The stored patch positions are assumed to form a square grid.
        sqrt_num_positions = int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        # Resize in float32 for precision, then cast back to the parameter dtype.
        target_dtype = patch_pos_embed.dtype
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.to(torch.float32),
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        ).to(dtype=target_dtype)
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pixel_values: `(batch_size, num_channels, height, width)` images. Any spatial size is accepted; the
                position embeddings are interpolated to the `(height // patch_size, width // patch_size)` grid.

        Returns:
            Embeddings of shape `(batch_size, n_cls_tokens + grid_h * grid_w, hidden_size)`.
        """
        batch_size, _, height, width = pixel_values.shape
        target_dtype = self.patch_embeddings.projection.weight.dtype
        # Position embeddings are always resized below to match the input. The patch embedder's fixed-size check
        # must therefore be relaxed; otherwise every non-default resolution is rejected and
        # `interpolate_pos_encoding` can never actually resize anything.
        embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype), interpolate_pos_encoding=True)

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)
        embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        return self.dropout(embeddings)
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float | None = None,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Reference (pure PyTorch) scaled dot-product attention.

    Args:
        module: the calling module. Only its `training` flag is read, to decide whether dropout is active.
        query, key, value: per-head tensors. `PixioSelfAttention` passes them as `(batch, heads, seq, head_size)`.
            Axes 2 and 3 of `key` are swapped for the score product.
        attention_mask: optional additive mask, broadcast onto the raw scores before the softmax.
        scaling: score multiplier; defaults to `query.size(-1) ** -0.5`.
        dropout: dropout probability applied to the attention probabilities (only in training mode).

    Returns:
        `(attn_output, attn_weights)`: `attn_output` has axes 1 and 2 swapped relative to `attn_weights @ value`
        and is made contiguous; `attn_weights` are the (post-dropout) attention probabilities.
    """
    scale = query.size(-1) ** -0.5 if scaling is None else scaling

    scores = torch.matmul(query, key.transpose(2, 3)) * scale
    if attention_mask is not None:
        scores = scores + attention_mask

    probs = nn.functional.dropout(
        nn.functional.softmax(scores, dim=-1),
        p=dropout,
        training=module.training,
    )
    context = torch.matmul(probs, value).transpose(1, 2).contiguous()
    return context, probs
class PixioSelfAttention(nn.Module):
    """
    Multi-head self-attention.

    Projects hidden states to queries, keys and values and splits them into `num_attention_heads` heads. It then
    runs the attention function obtained from `ALL_ATTENTION_FUNCTIONS.get_interface` for
    `config._attn_implementation`, with `eager_attention_forward` as the default.
    """

    def __init__(self, config: PixioConfig):
        super().__init__()
        heads = config.num_attention_heads
        # A non-divisible hidden size is tolerated only when the config defines `embedding_size`.
        if config.hidden_size % heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.config = config
        self.num_attention_heads = heads
        self.attention_head_size = int(config.hidden_size / heads)
        self.all_head_size = heads * self.attention_head_size
        self.dropout_prob = config.attention_probs_dropout_prob
        self.scaling = self.attention_head_size**-0.5
        self.is_causal = False

        # Creation order (query, key, value) is kept so weight initialization draws from the RNG identically.
        use_bias = config.qkv_bias
        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=use_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=use_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=use_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return `(context_layer, attention_probs)`, where `context_layer` has the heads merged back into its last axis."""
        per_head = (hidden_states.shape[0], -1, self.num_attention_heads, self.attention_head_size)

        def split_heads(projected: torch.Tensor) -> torch.Tensor:
            # Reshape to (batch, -1, heads, head_size), then swap axes 1 and 2 -> (batch, heads, -1, head_size).
            return projected.view(*per_head).transpose(1, 2)

        key_layer = split_heads(self.key(hidden_states))
        value_layer = split_heads(self.value(hidden_states))
        query_layer = split_heads(self.query(hidden_states))

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        context_layer, attention_probs = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            None,
            is_causal=self.is_causal,
            scaling=self.scaling,
            dropout=self.dropout_prob if self.training else 0.0,
            **kwargs,
        )

        # Merge the trailing (heads, head_size) axes into a single all_head_size axis.
        merged_shape = (*context_layer.size()[:-2], self.all_head_size)
        return context_layer.reshape(merged_shape), attention_probs
class PixioSelfOutput(nn.Module):
    """
    Output projection of the attention block: a linear layer followed by dropout.

    `input_tensor` is accepted but unused. The residual connection is added in `PixioLayer` instead, because that
    layer applies a layernorm before each sub-block.
    """

    def __init__(self, config: PixioConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # `input_tensor` is intentionally ignored (no residual here).
        projected = self.dense(hidden_states)
        return self.dropout(projected)
class PixioAttention(nn.Module):
    """Attention sub-block: `PixioSelfAttention` followed by the `PixioSelfOutput` projection."""

    def __init__(self, config: PixioConfig):
        super().__init__()
        self.attention = PixioSelfAttention(config)
        self.output = PixioSelfOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        # The attention weights are discarded here. Capturing "attentions" is configured on PixioSelfAttention
        # through `_can_record_outputs`.
        attended, _weights = self.attention(hidden_states, **kwargs)
        return self.output(attended, hidden_states)
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Each sample along the first dimension is zeroed with probability `drop_prob`. Surviving samples are scaled by
    `1 / (1 - drop_prob)`.

    Args:
        input: tensor whose first dimension indexes samples; any rank is supported.
        drop_prob: probability of dropping a sample's path. `0.0` disables the op.
        training: the op is the identity when `False`.

    Returns:
        A tensor with the same shape as `input`. When `drop_prob >= 1`, every path is dropped and zeros are returned.
        The unguarded formula would compute `input / 0 * 0` and produce NaN.
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    if keep_prob <= 0.0:
        # All paths dropped: avoid the division by zero that would turn the output into NaN.
        return torch.zeros_like(input)
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, otherwise 0
    return input.div(keep_prob) * random_tensor
class PixioDropPath(nn.Module):
    """Module wrapper around `drop_path` (per-sample stochastic depth); it only acts in training mode."""

    def __init__(self, drop_prob: float | None = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return drop_path(hidden_states, drop_prob=self.drop_prob, training=self.training)

    def extra_repr(self) -> str:
        # Appears inside the module's printed repr.
        return f"p={self.drop_prob}"
class PixioMLP(nn.Module):
    """
    Two-layer feed-forward network.

    `fc1` expands `hidden_size` to `int(hidden_size * mlp_ratio)` features, the activation is applied, and `fc2`
    projects back to `hidden_size`.
    """

    def __init__(self, config) -> None:
        super().__init__()
        width = config.hidden_size
        expanded = int(config.hidden_size * config.mlp_ratio)
        self.fc1 = nn.Linear(width, expanded, bias=True)
        # `hidden_act` is either a key into ACT2FN or an already-callable activation.
        act = config.hidden_act
        self.activation = ACT2FN[act] if isinstance(act, str) else act
        self.fc2 = nn.Linear(expanded, width, bias=True)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.activation(self.fc1(hidden_state)))
class PixioLayer(GradientCheckpointingLayer):
    """
    One pre-norm transformer block:

        x = x + drop_path(attention(norm1(x)))
        x = x + drop_path(mlp(norm2(x)))
    """

    def __init__(self, config: PixioConfig) -> None:
        super().__init__()
        eps = config.layer_norm_eps
        self.norm1 = nn.LayerNorm(config.hidden_size, eps=eps)
        self.attention = PixioAttention(config)
        # Stochastic depth only when a positive rate is configured; otherwise a no-op module.
        if config.drop_path_rate > 0.0:
            self.drop_path = PixioDropPath(config.drop_path_rate)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = nn.LayerNorm(config.hidden_size, eps=eps)
        self.mlp = PixioMLP(config)

    def forward(self, hidden_states: torch.Tensor, **kwargs: Unpack[TransformersKwargs]) -> torch.Tensor:
        residual = hidden_states
        attention_out = self.attention(self.norm1(residual), **kwargs)
        residual = self.drop_path(attention_out) + residual

        mlp_out = self.mlp(self.norm2(residual))
        return self.drop_path(mlp_out) + residual
@auto_docstring
class PixioPreTrainedModel(PreTrainedModel):
    # Shared base for Pixio models: config binding, capability flags and weight initialization.
    config: PixioConfig
    base_model_prefix = "pixio"
    main_input_name = "pixel_values"
    input_modalities = ("image",)

    # Training / device-placement capabilities.
    supports_gradient_checkpointing = True
    _no_split_modules = ["PixioEmbeddings", "PixioLayer"]

    # Supported attention backends.
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_flex_attn = True
    _supports_attention_backend = True

    # Which submodules' outputs are collected as hidden states / attentions.
    _can_record_outputs = {
        "hidden_states": PixioLayer,
        "attentions": PixioSelfAttention,
    }

    @torch.no_grad()
    def _init_weights(self, module: nn.Linear | nn.Conv2d | nn.LayerNorm):
        """Initialize `module`'s weights in place according to its type; other module types are left untouched."""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                init.zeros_(module.bias)
            return

        if isinstance(module, nn.LayerNorm):
            init.zeros_(module.bias)
            init.ones_(module.weight)
            return

        if isinstance(module, PixioEmbeddings):
            std = self.config.initializer_range
            # Draw order (position embeddings, then CLS token) is preserved for RNG reproducibility.
            init.trunc_normal_(module.position_embeddings, mean=0.0, std=std)
            init.trunc_normal_(module.cls_token, mean=0.0, std=std)
            if module.mask_token is not None:
                init.zeros_(module.mask_token)
class PixioEncoder(PixioPreTrainedModel):
    """Stack of `config.num_hidden_layers` `PixioLayer` blocks applied one after another."""

    def __init__(self, config: PixioConfig):
        super().__init__(config)
        self.layer = nn.ModuleList(PixioLayer(config) for _ in range(config.num_hidden_layers))
        self.gradient_checkpointing = False
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs(tie_last_hidden_states=False)
    def forward(self, hidden_states: torch.Tensor, **kwargs: Unpack[TransformersKwargs]) -> BaseModelOutput:
        """
        Run `hidden_states` through every layer in order and return the final state as `last_hidden_state`.

        Intermediate hidden states / attentions are presumably filled in by the `capture_outputs` decorator.
        """
        current = hidden_states
        for block in self.layer:
            current = block(current, **kwargs)
        return BaseModelOutput(last_hidden_state=current)
@auto_docstring
class PixioModel(PixioPreTrainedModel):
    # Embeddings -> encoder -> final LayerNorm; the pooled output is the mean of the CLS tokens.
    def __init__(self, config: PixioConfig):
        super().__init__(config)
        self.config = config

        self.embeddings = PixioEmbeddings(config)
        self.encoder = PixioEncoder(config)
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.post_init()

    def get_input_embeddings(self) -> PixioPatchEmbeddings:
        return self.embeddings.patch_embeddings

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPooling:
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        encoder_outputs: BaseModelOutput = self.encoder(self.embeddings(pixel_values), **kwargs)
        normed = self.layernorm(encoder_outputs.last_hidden_state)

        # Pool by averaging the leading `n_cls_tokens` positions of the normalized sequence.
        cls_states = normed[:, : self.embeddings.n_cls_tokens, :]

        return BaseModelOutputWithPooling(
            last_hidden_state=normed,
            pooler_output=cls_states.mean(dim=1),
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
@auto_docstring(
    custom_intro="""
    Pixio backbone, to be used with frameworks like DETR and MaskFormer.
    """
)
class PixioBackbone(BackboneMixin, PixioPreTrainedModel):
    # Exposes selected per-layer hidden states ("stages") as feature maps.
    def __init__(self, config):
        super().__init__(config)
        # One feature width per stage: the embedding output plus every encoder layer.
        self.num_features = [config.hidden_size] * (config.num_hidden_layers + 1)

        self.embeddings = PixioEmbeddings(config)
        self.encoder = PixioEncoder(config)
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> PixioPatchEmbeddings:
        return self.embeddings.patch_embeddings

    @can_return_tuple
    @filter_output_hidden_states
    @auto_docstring
    def forward(self, pixel_values: torch.Tensor, **kwargs: Unpack[TransformersKwargs]) -> BackboneOutput:
        r"""
        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))
        >>> processor = AutoImageProcessor.from_pretrained("facebook/pixio-huge")
        >>> model = AutoBackbone.from_pretrained(
        ...     "facebook/pixio-huge", out_features=["stage7", "stage15", "stage23", "stage31"]
        ... )
        >>> inputs = processor(image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> feature_maps = outputs.feature_maps
        >>> list(feature_maps[-1].shape)
        [1, 1280, 16, 16]
        ```"""
        kwargs["output_hidden_states"] = True  # required to extract layers for the stages
        encoder_output: BaseModelOutput = self.encoder(self.embeddings(pixel_values), **kwargs)
        hidden_states = encoder_output.hidden_states

        # Loop-invariant input geometry, used when folding tokens back into spatial maps.
        batch_size, _, height, width = pixel_values.shape
        patch_size = self.config.patch_size

        feature_maps = []
        for stage_name, stage_state in zip(self.stage_names, hidden_states):
            if stage_name not in self.out_features:
                continue
            if self.config.apply_layernorm:
                stage_state = self.layernorm(stage_state)
            if self.config.reshape_hidden_states:
                # Drop the CLS tokens, then fold the patch sequence into (batch, channels, grid_h, grid_w).
                patch_tokens = stage_state[:, self.embeddings.n_cls_tokens :]
                grid = patch_tokens.reshape(batch_size, height // patch_size, width // patch_size, -1)
                stage_state = grid.permute(0, 3, 1, 2).contiguous()
            feature_maps.append(stage_state)

        return BackboneOutput(
            feature_maps=tuple(feature_maps),
            hidden_states=hidden_states,
            attentions=encoder_output.attentions,
        )
# Public API of this module.
__all__ = [
    "PixioModel",
    "PixioPreTrainedModel",
    "PixioBackbone",
]
|