modular_pixio.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. # Copyright 2025 Meta AI and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch Pixio model."""
  15. import torch
  16. from huggingface_hub.dataclasses import strict
  17. from torch import nn
  18. from ...modeling_layers import GradientCheckpointingLayer
  19. from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling
  20. from ...processing_utils import Unpack
  21. from ...utils import TransformersKwargs, auto_docstring, is_tracing
  22. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  23. from ...utils.output_capturing import capture_outputs
  24. from ..dinov2.configuration_dinov2 import Dinov2Config
  25. from ..dinov2.modeling_dinov2 import (
  26. Dinov2Backbone,
  27. Dinov2DropPath,
  28. Dinov2MLP,
  29. )
  30. from ..vit.modeling_vit import ViTAttention, ViTPatchEmbeddings, ViTPreTrainedModel, ViTSelfAttention
  @auto_docstring(checkpoint="facebook/pixio-huge")
  @strict
  class PixioConfig(Dinov2Config):
      r"""
      apply_layernorm (`bool`, *optional*, defaults to `True`):
          Whether to apply layer normalization to the feature maps in case the model is used as backbone.
      reshape_hidden_states (`bool`, *optional*, defaults to `True`):
          Whether to reshape the feature maps to 4D tensors of shape `(batch_size, hidden_size, height, width)` in
          case the model is used as backbone. If `False`, the feature maps will be 3D tensors of shape `(batch_size,
          seq_len, hidden_size)`.
      n_cls_tokens (`int`, *optional*, defaults to 8):
          Number of class tokens in the Transformer encoder.

      Example:

      ```python
      >>> from transformers import PixioConfig, PixioModel

      >>> # Initializing a Pixio pixio-huge style configuration
      >>> configuration = PixioConfig()

      >>> # Initializing a model (with random weights) from the pixio-huge style configuration
      >>> model = PixioModel(configuration)

      >>> # Accessing the model configuration
      >>> configuration = model.config
      ```"""

      model_type = "pixio"

      # Encoder dimensions overriding the values inherited from `Dinov2Config`.
      hidden_size: int = 1280
      num_hidden_layers: int = 32
      num_attention_heads: int = 16

      # Number of class tokens prepended to the patch sequence.
      n_cls_tokens: int = 8

      # Input geometry; both fields accept a single int or a pair of ints.
      image_size: int | list[int] | tuple[int, int] = 256
      patch_size: int | list[int] | tuple[int, int] = 16

      # NOTE(review): these Dinov2 fields are replaced by `AttributeError()` instances,
      # presumably the modular-converter convention for dropping inherited attributes
      # from the generated config - confirm against the converter documentation.
      layerscale_value = AttributeError()
      use_swiglu_ffn = AttributeError()
      use_mask_token = AttributeError()
  class PixioPatchEmbeddings(ViTPatchEmbeddings):
      """Patch embedding module for Pixio; inherits `ViTPatchEmbeddings` without any change."""
  class PixioEmbeddings(nn.Module):
      """
      Construct the CLS tokens, position and patch embeddings.

      The module owns `n_cls_tokens` learnable class tokens, a patch projection
      (`PixioPatchEmbeddings`) and one learnable position embedding per token
      (class tokens first, then patches). `mask_token` is kept as an attribute
      but is always `None`.
      """

      def __init__(self, config: PixioConfig) -> None:
          super().__init__()

          self.cls_token = nn.Parameter(torch.randn(1, config.n_cls_tokens, config.hidden_size))
          self.mask_token = None
          self.patch_embeddings = PixioPatchEmbeddings(config)
          num_patches = self.patch_embeddings.num_patches
          self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + config.n_cls_tokens, config.hidden_size))
          self.dropout = nn.Dropout(config.hidden_dropout_prob)
          self.n_cls_tokens = config.n_cls_tokens
          self.patch_size = config.patch_size
          self.config = config

      def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
          """
          This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
          images. This method is also adapted to support tracing and interpolation at torch.float32 precision.

          Adapted from:
          - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
          - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211

          Args:
              embeddings: Token embeddings with the class tokens already prepended; only
                  its sequence length (`shape[1]`) and hidden size (`shape[-1]`) are read.
              height: Height of the input image in pixels.
              width: Width of the input image in pixels.

          Returns:
              The stored position embeddings unchanged when the token count matches and
              the image is square (and we are not tracing); otherwise the class-token
              position embeddings concatenated with the patch position embeddings
              resampled to `(height // patch_height, width // patch_width)`.
          """
          num_patches = embeddings.shape[1] - self.n_cls_tokens
          num_positions = self.position_embeddings.shape[1] - self.n_cls_tokens

          # Fast path: nothing to resample. Skipped while tracing so the interpolation
          # stays part of the traced graph.
          if not is_tracing() and num_patches == num_positions and height == width:
              return self.position_embeddings

          class_pos_embed = self.position_embeddings[:, : self.n_cls_tokens]
          patch_pos_embed = self.position_embeddings[:, self.n_cls_tokens :]

          dim = embeddings.shape[-1]

          # `config.patch_size` may be a single int or a (height, width) pair; dividing
          # by the raw attribute raised TypeError for the pair form.
          patch_size = self.patch_size
          if isinstance(patch_size, int):
              patch_height = patch_width = patch_size
          else:
              patch_height, patch_width = patch_size

          new_height = height // patch_height
          new_width = width // patch_width

          # NOTE(review): the stored patch grid is assumed to be square here.
          sqrt_num_positions = int(num_positions**0.5)
          patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
          patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

          # Interpolate in float32 and cast back to the parameter dtype afterwards.
          target_dtype = patch_pos_embed.dtype
          patch_pos_embed = nn.functional.interpolate(
              patch_pos_embed.to(torch.float32),
              size=(new_height, new_width),
              mode="bicubic",
              align_corners=False,
          ).to(dtype=target_dtype)

          patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

          return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

      def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
          """
          Embed an image batch into a token sequence.

          Args:
              pixel_values: Image batch unpacked as `(batch_size, _, height, width)`.

          Returns:
              Patch embeddings with the class tokens prepended, position embeddings
              added and dropout applied.
          """
          batch_size, _, height, width = pixel_values.shape

          # Cast the input to the projection weight dtype before embedding.
          target_dtype = self.patch_embeddings.projection.weight.dtype
          embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))

          # Prepend the class tokens, broadcast over the batch dimension.
          cls_tokens = self.cls_token.expand(batch_size, -1, -1)
          embeddings = torch.cat((cls_tokens, embeddings), dim=1)

          embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)

          embeddings = self.dropout(embeddings)

          return embeddings
  class PixioSelfAttention(ViTSelfAttention):
      """Self-attention module for Pixio; inherits `ViTSelfAttention` without any change."""
  class PixioAttention(ViTAttention):
      """`ViTAttention` whose `attention` submodule is a `PixioSelfAttention`."""

      def __init__(self, config: PixioConfig):
          super().__init__(config)
          # Replace the submodule created by the parent constructor with the Pixio class.
          self.attention = PixioSelfAttention(config)
  class PixioDropPath(Dinov2DropPath):
      """Drop-path module for Pixio; inherits `Dinov2DropPath` without any change."""
  class PixioMLP(Dinov2MLP):
      """Feed-forward module for Pixio; inherits `Dinov2MLP` without any change."""
  class PixioLayer(GradientCheckpointingLayer):
      """
      One encoder block: LayerNorm -> attention -> residual add, then
      LayerNorm -> MLP -> residual add. The same `drop_path` module is applied
      to both branch outputs before they are added to their residual.
      """

      def __init__(self, config: PixioConfig) -> None:
          super().__init__()

          self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
          self.attention = PixioAttention(config)
          # Drop-path is only instantiated when a positive rate is configured.
          if config.drop_path_rate > 0.0:
              self.drop_path = PixioDropPath(config.drop_path_rate)
          else:
              self.drop_path = nn.Identity()

          self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
          self.mlp = PixioMLP(config)

      def forward(self, hidden_states: torch.Tensor, **kwargs: Unpack[TransformersKwargs]) -> torch.Tensor:
          """
          Run the block on `hidden_states`; `kwargs` are forwarded to the attention module.

          Returns:
              The output of the MLP branch added to the post-attention hidden states.
          """
          # Attention branch.
          residual = hidden_states
          attention_output = self.attention(self.norm1(hidden_states), **kwargs)
          hidden_states = self.drop_path(attention_output) + residual

          # MLP branch.
          residual = hidden_states
          mlp_output = self.mlp(self.norm2(hidden_states))
          return self.drop_path(mlp_output) + residual
  class PixioPreTrainedModel(ViTPreTrainedModel):
      """Base class for Pixio models; extends `ViTPreTrainedModel` with the output-recording map below."""

      # Associates each recordable output name with the module class that produces it.
      # NOTE(review): presumably read by the `capture_outputs` machinery - confirm.
      _can_record_outputs = {
          "hidden_states": PixioLayer,
          "attentions": PixioSelfAttention,
      }
  class PixioEncoder(PixioPreTrainedModel):
      """Stack of `config.num_hidden_layers` `PixioLayer` blocks applied one after another."""

      def __init__(self, config: PixioConfig):
          super().__init__(config)
          blocks = [PixioLayer(config) for _ in range(config.num_hidden_layers)]
          self.layer = nn.ModuleList(blocks)
          self.gradient_checkpointing = False
          self.post_init()

      @merge_with_config_defaults
      @capture_outputs(tie_last_hidden_states=False)
      def forward(self, hidden_states: torch.Tensor, **kwargs: Unpack[TransformersKwargs]) -> BaseModelOutput:
          """
          Feed `hidden_states` through every layer in order, forwarding `kwargs` to each.

          Returns:
              A `BaseModelOutput` whose `last_hidden_state` is the output of the final layer.
          """
          for block in self.layer:
              hidden_states = block(hidden_states, **kwargs)

          return BaseModelOutput(last_hidden_state=hidden_states)
  @auto_docstring
  class PixioModel(PixioPreTrainedModel):
      def __init__(self, config: PixioConfig):
          super().__init__(config)
          self.config = config

          self.embeddings = PixioEmbeddings(config)
          self.encoder = PixioEncoder(config)

          # Final normalization applied to the encoder output.
          self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

          self.post_init()

      def get_input_embeddings(self) -> PixioPatchEmbeddings:
          return self.embeddings.patch_embeddings

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          pixel_values: torch.Tensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> BaseModelOutputWithPooling:
          if pixel_values is None:
              raise ValueError("You have to specify pixel_values")

          encoder_outputs: BaseModelOutput = self.encoder(self.embeddings(pixel_values), **kwargs)
          normed_sequence = self.layernorm(encoder_outputs.last_hidden_state)

          # Pooled representation: mean over the leading class-token positions.
          num_cls_tokens = self.embeddings.n_cls_tokens
          cls_mean = normed_sequence[:, :num_cls_tokens, :].mean(dim=1)

          return BaseModelOutputWithPooling(
              last_hidden_state=normed_sequence,
              pooler_output=cls_mean,
              hidden_states=encoder_outputs.hidden_states,
              attentions=encoder_outputs.attentions,
          )
  @auto_docstring(
      custom_intro="""
      Pixio backbone, to be used with frameworks like DETR and MaskFormer.
      """
  )
  class PixioBackbone(Dinov2Backbone):
      def forward(self, pixel_values: torch.Tensor, **kwargs: Unpack[TransformersKwargs]) -> BackboneOutput:
          r"""
          Examples:

          ```python
          >>> from transformers import AutoImageProcessor, AutoBackbone
          >>> import torch
          >>> from PIL import Image
          >>> import httpx
          >>> from io import BytesIO

          >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
          >>> with httpx.stream("GET", url) as response:
          ...     image = Image.open(BytesIO(response.read()))

          >>> processor = AutoImageProcessor.from_pretrained("facebook/pixio-huge")
          >>> model = AutoBackbone.from_pretrained(
          ...     "facebook/pixio-huge", out_features=["stage7", "stage15", "stage23", "stage31"]
          ... )

          >>> inputs = processor(image, return_tensors="pt")

          >>> outputs = model(**inputs)
          >>> feature_maps = outputs.feature_maps
          >>> list(feature_maps[-1].shape)
          [1, 1280, 16, 16]
          ```"""
          kwargs["output_hidden_states"] = True  # required to extract layers for the stages

          embedding_output = self.embeddings(pixel_values)
          output: BaseModelOutput = self.encoder(embedding_output, **kwargs)
          hidden_states = output.hidden_states

          apply_layernorm = self.config.apply_layernorm
          reshape_hidden_states = self.config.reshape_hidden_states

          if reshape_hidden_states:
              # The target grid is the same for every stage, so compute it once.
              batch_size, _, height, width = pixel_values.shape
              # `config.patch_size` may be a single int or a (height, width) pair;
              # dividing by the raw value raised TypeError for the pair form.
              patch_size = self.config.patch_size
              if isinstance(patch_size, int):
                  patch_height = patch_width = patch_size
              else:
                  patch_height, patch_width = patch_size
              grid_height = height // patch_height
              grid_width = width // patch_width

          feature_maps = []
          for stage, hidden_state in zip(self.stage_names, hidden_states):
              if stage not in self.out_features:
                  continue

              if apply_layernorm:
                  hidden_state = self.layernorm(hidden_state)

              if reshape_hidden_states:
                  # Drop the class tokens, then lay the patch tokens out as a
                  # (batch, channels, grid_height, grid_width) map.
                  hidden_state = hidden_state[:, self.embeddings.n_cls_tokens :]
                  hidden_state = hidden_state.reshape(batch_size, grid_height, grid_width, -1)
                  hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()

              feature_maps.append(hidden_state)

          return BackboneOutput(
              feature_maps=tuple(feature_maps),
              hidden_states=hidden_states,
              attentions=output.attentions,
          )
  # Public names exported by this module.
  __all__ = [
      "PixioConfig",
      "PixioModel",
      "PixioPreTrainedModel",
      "PixioBackbone",
  ]