modeling_pixio.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/pixio/modular_pixio.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_pixio.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 Meta AI and The HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. import collections.abc
  21. from collections.abc import Callable
  22. import torch
  23. from torch import nn
  24. from ... import initialization as init
  25. from ...activations import ACT2FN
  26. from ...backbone_utils import BackboneMixin, filter_output_hidden_states
  27. from ...modeling_layers import GradientCheckpointingLayer
  28. from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling
  29. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  30. from ...processing_utils import Unpack
  31. from ...utils import TransformersKwargs, auto_docstring, is_tracing
  32. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  33. from ...utils.output_capturing import capture_outputs
  34. from .configuration_pixio import PixioConfig
  35. class PixioPatchEmbeddings(nn.Module):
  36. """
  37. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  38. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  39. Transformer.
  40. """
  41. def __init__(self, config: PixioConfig):
  42. super().__init__()
  43. image_size, patch_size = config.image_size, config.patch_size
  44. num_channels, hidden_size = config.num_channels, config.hidden_size
  45. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  46. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  47. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  48. self.image_size = image_size
  49. self.patch_size = patch_size
  50. self.num_channels = num_channels
  51. self.num_patches = num_patches
  52. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
  53. def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
  54. batch_size, num_channels, height, width = pixel_values.shape
  55. if num_channels != self.num_channels:
  56. raise ValueError(
  57. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  58. f" Expected {self.num_channels} but got {num_channels}."
  59. )
  60. if not interpolate_pos_encoding:
  61. if height != self.image_size[0] or width != self.image_size[1]:
  62. raise ValueError(
  63. f"Input image size ({height}*{width}) doesn't match model"
  64. f" ({self.image_size[0]}*{self.image_size[1]})."
  65. )
  66. embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
  67. return embeddings
  68. class PixioEmbeddings(nn.Module):
  69. """
  70. Construct the CLS tokens, position and patch embeddings.
  71. """
  72. def __init__(self, config: PixioConfig) -> None:
  73. super().__init__()
  74. self.cls_token = nn.Parameter(torch.randn(1, config.n_cls_tokens, config.hidden_size))
  75. self.mask_token = None
  76. self.patch_embeddings = PixioPatchEmbeddings(config)
  77. num_patches = self.patch_embeddings.num_patches
  78. self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + config.n_cls_tokens, config.hidden_size))
  79. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  80. self.n_cls_tokens = config.n_cls_tokens
  81. self.patch_size = config.patch_size
  82. self.config = config
  83. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  84. """
  85. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  86. images. This method is also adapted to support tracing and interpolation at torch.float32 precision.
  87. Adapted from:
  88. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  89. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  90. """
  91. num_patches = embeddings.shape[1] - self.n_cls_tokens
  92. num_positions = self.position_embeddings.shape[1] - self.n_cls_tokens
  93. if not is_tracing() and num_patches == num_positions and height == width:
  94. return self.position_embeddings
  95. class_pos_embed = self.position_embeddings[:, : self.n_cls_tokens]
  96. patch_pos_embed = self.position_embeddings[:, self.n_cls_tokens :]
  97. dim = embeddings.shape[-1]
  98. new_height = height // self.patch_size
  99. new_width = width // self.patch_size
  100. sqrt_num_positions = int(num_positions**0.5)
  101. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  102. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  103. target_dtype = patch_pos_embed.dtype
  104. patch_pos_embed = nn.functional.interpolate(
  105. patch_pos_embed.to(torch.float32),
  106. size=(new_height, new_width),
  107. mode="bicubic",
  108. align_corners=False,
  109. ).to(dtype=target_dtype)
  110. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  111. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  112. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  113. batch_size, _, height, width = pixel_values.shape
  114. target_dtype = self.patch_embeddings.projection.weight.dtype
  115. embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
  116. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  117. embeddings = torch.cat((cls_tokens, embeddings), dim=1)
  118. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  119. embeddings = self.dropout(embeddings)
  120. return embeddings
  121. def eager_attention_forward(
  122. module: nn.Module,
  123. query: torch.Tensor,
  124. key: torch.Tensor,
  125. value: torch.Tensor,
  126. attention_mask: torch.Tensor | None,
  127. scaling: float | None = None,
  128. dropout: float = 0.0,
  129. **kwargs: Unpack[TransformersKwargs],
  130. ):
  131. if scaling is None:
  132. scaling = query.size(-1) ** -0.5
  133. # Take the dot product between "query" and "key" to get the raw attention scores.
  134. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  135. if attention_mask is not None:
  136. attn_weights = attn_weights + attention_mask
  137. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  138. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  139. attn_output = torch.matmul(attn_weights, value)
  140. attn_output = attn_output.transpose(1, 2).contiguous()
  141. return attn_output, attn_weights
  142. class PixioSelfAttention(nn.Module):
  143. def __init__(self, config: PixioConfig):
  144. super().__init__()
  145. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  146. raise ValueError(
  147. f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
  148. f"heads {config.num_attention_heads}."
  149. )
  150. self.config = config
  151. self.num_attention_heads = config.num_attention_heads
  152. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  153. self.all_head_size = self.num_attention_heads * self.attention_head_size
  154. self.dropout_prob = config.attention_probs_dropout_prob
  155. self.scaling = self.attention_head_size**-0.5
  156. self.is_causal = False
  157. self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  158. self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  159. self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  160. def forward(
  161. self,
  162. hidden_states: torch.Tensor,
  163. **kwargs: Unpack[TransformersKwargs],
  164. ) -> tuple[torch.Tensor, torch.Tensor]:
  165. batch_size = hidden_states.shape[0]
  166. new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size
  167. key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2)
  168. value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2)
  169. query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2)
  170. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  171. self.config._attn_implementation, eager_attention_forward
  172. )
  173. context_layer, attention_probs = attention_interface(
  174. self,
  175. query_layer,
  176. key_layer,
  177. value_layer,
  178. None,
  179. is_causal=self.is_causal,
  180. scaling=self.scaling,
  181. dropout=0.0 if not self.training else self.dropout_prob,
  182. **kwargs,
  183. )
  184. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  185. context_layer = context_layer.reshape(new_context_layer_shape)
  186. return context_layer, attention_probs
  187. class PixioSelfOutput(nn.Module):
  188. """
  189. The residual connection is defined in PixioLayer instead of here (as is the case with other models), due to the
  190. layernorm applied before each block.
  191. """
  192. def __init__(self, config: PixioConfig):
  193. super().__init__()
  194. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  195. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  196. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  197. hidden_states = self.dense(hidden_states)
  198. hidden_states = self.dropout(hidden_states)
  199. return hidden_states
  200. class PixioAttention(nn.Module):
  201. def __init__(self, config: PixioConfig):
  202. super().__init__()
  203. self.attention = PixioSelfAttention(config)
  204. self.output = PixioSelfOutput(config)
  205. def forward(
  206. self,
  207. hidden_states: torch.Tensor,
  208. **kwargs: Unpack[TransformersKwargs],
  209. ) -> torch.Tensor:
  210. self_attn_output, _ = self.attention(hidden_states, **kwargs)
  211. output = self.output(self_attn_output, hidden_states)
  212. return output
  213. def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
  214. """
  215. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  216. """
  217. if drop_prob == 0.0 or not training:
  218. return input
  219. keep_prob = 1 - drop_prob
  220. shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  221. random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
  222. random_tensor.floor_() # binarize
  223. output = input.div(keep_prob) * random_tensor
  224. return output
  225. class PixioDropPath(nn.Module):
  226. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  227. def __init__(self, drop_prob: float | None = None) -> None:
  228. super().__init__()
  229. self.drop_prob = drop_prob
  230. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  231. return drop_path(hidden_states, self.drop_prob, self.training)
  232. def extra_repr(self) -> str:
  233. return f"p={self.drop_prob}"
  234. class PixioMLP(nn.Module):
  235. def __init__(self, config) -> None:
  236. super().__init__()
  237. in_features = out_features = config.hidden_size
  238. hidden_features = int(config.hidden_size * config.mlp_ratio)
  239. self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
  240. if isinstance(config.hidden_act, str):
  241. self.activation = ACT2FN[config.hidden_act]
  242. else:
  243. self.activation = config.hidden_act
  244. self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
  245. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  246. hidden_state = self.fc1(hidden_state)
  247. hidden_state = self.activation(hidden_state)
  248. hidden_state = self.fc2(hidden_state)
  249. return hidden_state
  250. class PixioLayer(GradientCheckpointingLayer):
  251. def __init__(self, config: PixioConfig) -> None:
  252. super().__init__()
  253. self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  254. self.attention = PixioAttention(config)
  255. self.drop_path = PixioDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
  256. self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  257. self.mlp = PixioMLP(config)
  258. def forward(self, hidden_states: torch.Tensor, **kwargs: Unpack[TransformersKwargs]) -> torch.Tensor:
  259. hidden_states_norm = self.norm1(hidden_states)
  260. self_attention_output = self.attention(hidden_states_norm, **kwargs)
  261. hidden_states = self.drop_path(self_attention_output) + hidden_states
  262. layer_output = self.norm2(hidden_states)
  263. layer_output = self.mlp(layer_output)
  264. layer_output = self.drop_path(layer_output) + hidden_states
  265. return layer_output
  266. @auto_docstring
  267. class PixioPreTrainedModel(PreTrainedModel):
  268. config: PixioConfig
  269. base_model_prefix = "pixio"
  270. main_input_name = "pixel_values"
  271. input_modalities = ("image",)
  272. supports_gradient_checkpointing = True
  273. _no_split_modules = ["PixioEmbeddings", "PixioLayer"]
  274. _supports_sdpa = True
  275. _supports_flash_attn = True
  276. _supports_flex_attn = True
  277. _supports_attention_backend = True
  278. _can_record_outputs = {
  279. "hidden_states": PixioLayer,
  280. "attentions": PixioSelfAttention,
  281. }
  282. @torch.no_grad()
  283. def _init_weights(self, module: nn.Linear | nn.Conv2d | nn.LayerNorm):
  284. """Initialize the weights"""
  285. if isinstance(module, nn.Linear | nn.Conv2d):
  286. init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  287. if module.bias is not None:
  288. init.zeros_(module.bias)
  289. elif isinstance(module, nn.LayerNorm):
  290. init.zeros_(module.bias)
  291. init.ones_(module.weight)
  292. elif isinstance(module, PixioEmbeddings):
  293. init.trunc_normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range)
  294. init.trunc_normal_(module.cls_token, mean=0.0, std=self.config.initializer_range)
  295. if module.mask_token is not None:
  296. init.zeros_(module.mask_token)
  297. class PixioEncoder(PixioPreTrainedModel):
  298. def __init__(self, config: PixioConfig):
  299. super().__init__(config)
  300. self.layer = nn.ModuleList([PixioLayer(config) for _ in range(config.num_hidden_layers)])
  301. self.gradient_checkpointing = False
  302. self.post_init()
  303. @merge_with_config_defaults
  304. @capture_outputs(tie_last_hidden_states=False)
  305. def forward(self, hidden_states: torch.Tensor, **kwargs: Unpack[TransformersKwargs]) -> BaseModelOutput:
  306. for layer_module in self.layer:
  307. hidden_states = layer_module(hidden_states, **kwargs)
  308. return BaseModelOutput(last_hidden_state=hidden_states)
  309. @auto_docstring
  310. class PixioModel(PixioPreTrainedModel):
  311. def __init__(self, config: PixioConfig):
  312. super().__init__(config)
  313. self.config = config
  314. self.embeddings = PixioEmbeddings(config)
  315. self.encoder = PixioEncoder(config)
  316. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  317. self.post_init()
  318. def get_input_embeddings(self) -> PixioPatchEmbeddings:
  319. return self.embeddings.patch_embeddings
  320. @can_return_tuple
  321. @auto_docstring
  322. def forward(
  323. self,
  324. pixel_values: torch.Tensor | None = None,
  325. **kwargs: Unpack[TransformersKwargs],
  326. ) -> BaseModelOutputWithPooling:
  327. if pixel_values is None:
  328. raise ValueError("You have to specify pixel_values")
  329. embedding_output = self.embeddings(pixel_values)
  330. encoder_outputs: BaseModelOutput = self.encoder(embedding_output, **kwargs)
  331. sequence_output = encoder_outputs.last_hidden_state
  332. sequence_output = self.layernorm(sequence_output)
  333. pooled_output = sequence_output[:, : self.embeddings.n_cls_tokens, :].mean(dim=1)
  334. return BaseModelOutputWithPooling(
  335. last_hidden_state=sequence_output,
  336. pooler_output=pooled_output,
  337. hidden_states=encoder_outputs.hidden_states,
  338. attentions=encoder_outputs.attentions,
  339. )
  340. @auto_docstring(
  341. custom_intro="""
  342. Pixio backbone, to be used with frameworks like DETR and MaskFormer.
  343. """
  344. )
  345. class PixioBackbone(BackboneMixin, PixioPreTrainedModel):
  346. def __init__(self, config):
  347. super().__init__(config)
  348. self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]
  349. self.embeddings = PixioEmbeddings(config)
  350. self.encoder = PixioEncoder(config)
  351. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  352. # Initialize weights and apply final processing
  353. self.post_init()
  354. def get_input_embeddings(self) -> PixioPatchEmbeddings:
  355. return self.embeddings.patch_embeddings
  356. @can_return_tuple
  357. @filter_output_hidden_states
  358. @auto_docstring
  359. def forward(self, pixel_values: torch.Tensor, **kwargs: Unpack[TransformersKwargs]) -> BackboneOutput:
  360. r"""
  361. Examples:
  362. ```python
  363. >>> from transformers import AutoImageProcessor, AutoBackbone
  364. >>> import torch
  365. >>> from PIL import Image
  366. >>> import httpx
  367. >>> from io import BytesIO
  368. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  369. >>> with httpx.stream("GET", url) as response:
  370. ... image = Image.open(BytesIO(response.read()))
  371. >>> processor = AutoImageProcessor.from_pretrained("facebook/pixio-huge")
  372. >>> model = AutoBackbone.from_pretrained(
  373. ... "facebook/pixio-huge", out_features=["stage7", "stage15", "stage23", "stage31"]
  374. ... )
  375. >>> inputs = processor(image, return_tensors="pt")
  376. >>> outputs = model(**inputs)
  377. >>> feature_maps = outputs.feature_maps
  378. >>> list(feature_maps[-1].shape)
  379. [1, 1280, 16, 16]
  380. ```"""
  381. kwargs["output_hidden_states"] = True # required to extract layers for the stages
  382. embedding_output = self.embeddings(pixel_values)
  383. output: BaseModelOutput = self.encoder(embedding_output, **kwargs)
  384. hidden_states = output.hidden_states
  385. feature_maps = []
  386. for stage, hidden_state in zip(self.stage_names, hidden_states):
  387. if stage in self.out_features:
  388. if self.config.apply_layernorm:
  389. hidden_state = self.layernorm(hidden_state)
  390. if self.config.reshape_hidden_states:
  391. hidden_state = hidden_state[:, self.embeddings.n_cls_tokens :]
  392. batch_size, _, height, width = pixel_values.shape
  393. patch_size = self.config.patch_size
  394. hidden_state = hidden_state.reshape(batch_size, height // patch_size, width // patch_size, -1)
  395. hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
  396. feature_maps.append(hidden_state)
  397. return BackboneOutput(
  398. feature_maps=tuple(feature_maps),
  399. hidden_states=hidden_states,
  400. attentions=output.attentions,
  401. )
# Names exported by this module when it is star-imported.
__all__ = ["PixioModel", "PixioPreTrainedModel", "PixioBackbone"]