modeling_siglip2.py 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/siglip2/modular_siglip2.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_siglip2.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 The HuggingFace Inc. team.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. from collections.abc import Callable
  21. from dataclasses import dataclass
  22. from typing import Any
  23. import numpy as np
  24. import torch
  25. import torch.nn as nn
  26. import torch.nn.functional as F
  27. from ... import initialization as init
  28. from ...activations import ACT2FN
  29. from ...masking_utils import create_bidirectional_mask
  30. from ...modeling_layers import GradientCheckpointingLayer
  31. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
  32. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  33. from ...processing_utils import Unpack
  34. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, torch_compilable_check
  35. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  36. from ...utils.output_capturing import capture_outputs
  37. from .configuration_siglip2 import Siglip2Config, Siglip2TextConfig, Siglip2VisionConfig
  38. @dataclass
  39. @auto_docstring(
  40. custom_intro="""
  41. Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
  42. """
  43. )
  44. class Siglip2VisionOutput(ModelOutput):
  45. r"""
  46. image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
  47. The image embeddings obtained by applying the projection layer to the pooler_output.
  48. """
  49. image_embeds: torch.FloatTensor | None = None
  50. last_hidden_state: torch.FloatTensor | None = None
  51. hidden_states: tuple[torch.FloatTensor, ...] | None = None
  52. attentions: tuple[torch.FloatTensor, ...] | None = None
  53. @dataclass
  54. @auto_docstring(
  55. custom_intro="""
  56. Base class for text model's outputs that also contains a pooling of the last hidden states.
  57. """
  58. )
  59. class Siglip2TextOutput(ModelOutput):
  60. r"""
  61. text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
  62. The text embeddings obtained by applying the projection layer to the pooler_output.
  63. """
  64. text_embeds: torch.FloatTensor | None = None
  65. last_hidden_state: torch.FloatTensor | None = None
  66. hidden_states: tuple[torch.FloatTensor, ...] | None = None
  67. attentions: tuple[torch.FloatTensor, ...] | None = None
  68. @dataclass
  69. @auto_docstring
  70. class Siglip2Output(ModelOutput):
  71. r"""
  72. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
  73. Contrastive loss for image-text similarity.
  74. logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
  75. The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
  76. similarity scores.
  77. logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
  78. The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
  79. similarity scores.
  80. text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
  81. The text embeddings obtained by applying the projection layer to the pooled output of [`Siglip2TextModel`].
  82. image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
  83. The image embeddings obtained by applying the projection layer to the pooled output of [`Siglip2VisionModel`].
  84. text_model_output (`BaseModelOutputWithPooling`):
  85. The output of the [`Siglip2TextModel`].
  86. vision_model_output (`BaseModelOutputWithPooling`):
  87. The output of the [`Siglip2VisionModel`].
  88. """
  89. loss: torch.FloatTensor | None = None
  90. logits_per_image: torch.FloatTensor | None = None
  91. logits_per_text: torch.FloatTensor | None = None
  92. text_embeds: torch.FloatTensor | None = None
  93. image_embeds: torch.FloatTensor | None = None
  94. text_model_output: BaseModelOutputWithPooling = None
  95. vision_model_output: BaseModelOutputWithPooling = None
  96. def to_tuple(self) -> tuple[Any]:
  97. return tuple(v.to_tuple() if isinstance(v, ModelOutput) else v for v in self.values())
  98. class Siglip2VisionEmbeddings(nn.Module):
  99. def __init__(self, config: Siglip2VisionConfig):
  100. super().__init__()
  101. self.config = config
  102. self.embed_dim = config.hidden_size
  103. self.patch_size = config.patch_size
  104. self.patch_embedding = nn.Linear(
  105. in_features=config.num_channels * self.patch_size * self.patch_size,
  106. out_features=self.embed_dim,
  107. )
  108. self.num_patches = config.num_patches
  109. self.position_embedding_size = int(self.num_patches**0.5)
  110. self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)
  111. @staticmethod
  112. def resize_positional_embeddings(
  113. positional_embeddings: torch.Tensor,
  114. spatial_shapes: torch.LongTensor,
  115. max_length: int,
  116. ) -> torch.Tensor:
  117. """
  118. Resize positional embeddings to image-specific size and pad to a fixed size.
  119. Args:
  120. positional_embeddings (`torch.Tensor`):
  121. Position embeddings of shape (height, width, embed_dim)
  122. spatial_shapes (`torch.LongTensor`):
  123. Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
  124. max_length (`int`):
  125. Maximum length of the positional embeddings to pad resized positional embeddings to
  126. Returns:
  127. `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim)
  128. """
  129. batch_size = spatial_shapes.shape[0]
  130. embed_dim = positional_embeddings.shape[-1]
  131. source_dtype = positional_embeddings.dtype
  132. resulted_positional_embeddings = torch.empty(
  133. (batch_size, max_length, embed_dim),
  134. device=positional_embeddings.device,
  135. dtype=source_dtype,
  136. )
  137. # (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation
  138. positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0)
  139. # Upcast to float32 on CPU because antialias is not supported for bfloat16/float16 on CPU
  140. if positional_embeddings.device.type == "cpu":
  141. positional_embeddings = positional_embeddings.to(torch.float32)
  142. for i in range(batch_size):
  143. # (1, dim, height, width) -> (1, dim, target_height, target_width)
  144. height, width = spatial_shapes[i].tolist() # will be itemized in F.interpolate either way
  145. torch_compilable_check((width > 0), "Width of resized positional embeddings must be positive.")
  146. torch_compilable_check((height > 0), "Height of resized positional embeddings must be positive.")
  147. torch_compilable_check((height * width) <= max_length, "Resized positional embeddings exceed max_length.")
  148. resized_embeddings = F.interpolate(
  149. positional_embeddings,
  150. size=(height, width),
  151. mode="bilinear",
  152. align_corners=False,
  153. antialias=True,
  154. )
  155. # (1, dim, target_height, target_width) -> (target_height * target_width, dim)
  156. resized_embeddings = resized_embeddings.reshape(embed_dim, height * width).transpose(0, 1)
  157. # Cast to original dtype
  158. resized_embeddings = resized_embeddings.to(source_dtype)
  159. resulted_positional_embeddings[i, : height * width] = resized_embeddings
  160. resulted_positional_embeddings[i, height * width :] = resized_embeddings[0]
  161. return resulted_positional_embeddings
  162. def forward(self, pixel_values: torch.FloatTensor, spatial_shapes: torch.LongTensor) -> torch.Tensor:
  163. """
  164. Args:
  165. pixel_values (`torch.FloatTensor`):
  166. Pixel values of shape (batch_size, max_num_patches, num_channels * patch_size * patch_size)
  167. spatial_shapes (`list[tuple[int, int]]`):
  168. Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
  169. """
  170. # Apply patch embeddings to already patchified pixel values
  171. target_dtype = self.patch_embedding.weight.dtype
  172. patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
  173. # Get positional resized and padded positional embeddings
  174. positional_embeddings = self.position_embedding.weight.reshape(
  175. self.position_embedding_size, self.position_embedding_size, -1
  176. )
  177. resized_positional_embeddings = self.resize_positional_embeddings(
  178. positional_embeddings, spatial_shapes, max_length=pixel_values.shape[1]
  179. )
  180. # Add positional embeddings to patch embeddings
  181. embeddings = patch_embeds + resized_positional_embeddings
  182. return embeddings
  183. class Siglip2TextEmbeddings(nn.Module):
  184. def __init__(self, config: Siglip2TextConfig):
  185. super().__init__()
  186. embed_dim = config.hidden_size
  187. self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
  188. self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
  189. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  190. self.register_buffer(
  191. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  192. )
  193. def forward(
  194. self,
  195. input_ids: torch.LongTensor | None = None,
  196. position_ids: torch.LongTensor | None = None,
  197. inputs_embeds: torch.FloatTensor | None = None,
  198. ) -> torch.Tensor:
  199. seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
  200. max_position_embedding = self.position_embedding.weight.shape[0]
  201. if seq_length > max_position_embedding:
  202. raise ValueError(
  203. f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
  204. f"{seq_length} and max_position_embeddings: {max_position_embedding}"
  205. )
  206. if position_ids is None:
  207. position_ids = self.position_ids[:, :seq_length]
  208. if inputs_embeds is None:
  209. inputs_embeds = self.token_embedding(input_ids)
  210. position_embeddings = self.position_embedding(position_ids)
  211. embeddings = inputs_embeds + position_embeddings
  212. return embeddings
  213. def eager_attention_forward(
  214. module: nn.Module,
  215. query: torch.Tensor,
  216. key: torch.Tensor,
  217. value: torch.Tensor,
  218. attention_mask: torch.Tensor | None,
  219. scaling: float,
  220. dropout: float = 0.0,
  221. **kwargs,
  222. ):
  223. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  224. if attention_mask is not None:
  225. attn_weights = attn_weights + attention_mask
  226. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  227. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  228. attn_output = torch.matmul(attn_weights, value)
  229. attn_output = attn_output.transpose(1, 2).contiguous()
  230. return attn_output, attn_weights
  231. class Siglip2Attention(nn.Module):
  232. """Multi-headed attention from 'Attention Is All You Need' paper"""
  233. def __init__(self, config):
  234. super().__init__()
  235. self.config = config
  236. self.embed_dim = config.hidden_size
  237. self.num_heads = config.num_attention_heads
  238. self.head_dim = self.embed_dim // self.num_heads
  239. if self.head_dim * self.num_heads != self.embed_dim:
  240. raise ValueError(
  241. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  242. f" {self.num_heads})."
  243. )
  244. self.scale = self.head_dim**-0.5
  245. self.dropout = config.attention_dropout
  246. self.is_causal = False
  247. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  248. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  249. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  250. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  251. def forward(
  252. self,
  253. hidden_states: torch.Tensor,
  254. attention_mask: torch.Tensor | None = None,
  255. **kwargs,
  256. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  257. """Input shape: Batch x Time x Channel"""
  258. input_shape = hidden_states.shape[:-1]
  259. hidden_shape = (*input_shape, -1, self.head_dim)
  260. queries = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  261. keys = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  262. values = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  263. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  264. self.config._attn_implementation, eager_attention_forward
  265. )
  266. attn_output, attn_weights = attention_interface(
  267. self,
  268. queries,
  269. keys,
  270. values,
  271. attention_mask,
  272. is_causal=self.is_causal,
  273. scaling=self.scale,
  274. dropout=0.0 if not self.training else self.dropout,
  275. )
  276. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  277. attn_output = self.out_proj(attn_output)
  278. return attn_output, attn_weights
  279. class Siglip2MLP(nn.Module):
  280. def __init__(self, config):
  281. super().__init__()
  282. self.config = config
  283. self.activation_fn = ACT2FN[config.hidden_act]
  284. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  285. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  286. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  287. hidden_states = self.fc1(hidden_states)
  288. hidden_states = self.activation_fn(hidden_states)
  289. hidden_states = self.fc2(hidden_states)
  290. return hidden_states
  291. class Siglip2EncoderLayer(GradientCheckpointingLayer):
  292. def __init__(self, config: Siglip2VisionConfig | Siglip2TextConfig):
  293. super().__init__()
  294. self.embed_dim = config.hidden_size
  295. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  296. self.self_attn = Siglip2Attention(config)
  297. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  298. self.mlp = Siglip2MLP(config)
  299. @auto_docstring
  300. def forward(
  301. self,
  302. hidden_states: torch.Tensor,
  303. attention_mask: torch.Tensor,
  304. **kwargs: Unpack[TransformersKwargs],
  305. ) -> torch.FloatTensor:
  306. residual = hidden_states
  307. hidden_states = self.layer_norm1(hidden_states)
  308. hidden_states, _ = self.self_attn(
  309. hidden_states=hidden_states,
  310. attention_mask=attention_mask,
  311. **kwargs,
  312. )
  313. hidden_states = residual + hidden_states
  314. residual = hidden_states
  315. hidden_states = self.layer_norm2(hidden_states)
  316. hidden_states = self.mlp(hidden_states)
  317. hidden_states = residual + hidden_states
  318. return hidden_states
  319. @auto_docstring
  320. class Siglip2PreTrainedModel(PreTrainedModel):
  321. config: Siglip2Config
  322. base_model_prefix = "siglip2"
  323. input_modalities = ("image", "text")
  324. supports_gradient_checkpointing = True
  325. _no_split_modules = [
  326. "Siglip2TextEmbeddings",
  327. "Siglip2VisionEmbeddings",
  328. "Siglip2EncoderLayer",
  329. "Siglip2MultiheadAttentionPoolingHead",
  330. ]
  331. _supports_flash_attn = False
  332. _supports_sdpa = True
  333. # nn.MultiHeadAttention mask doesn't allow for non 4d mask
  334. _supports_flex_attn = False
  335. _supports_attention_backend = True
  336. _can_record_outputs = {
  337. "hidden_states": Siglip2EncoderLayer,
  338. "attentions": Siglip2Attention,
  339. }
  340. @torch.no_grad()
  341. def _init_weights(self, module):
  342. """Initialize the weights"""
  343. if isinstance(module, Siglip2VisionEmbeddings):
  344. width = (
  345. self.config.vision_config.hidden_size
  346. if isinstance(self.config, Siglip2Config)
  347. else self.config.hidden_size
  348. )
  349. init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
  350. if hasattr(module, "position_ids"):
  351. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  352. elif isinstance(module, nn.Embedding):
  353. init.default_flax_embed_init_(module.weight)
  354. elif isinstance(module, Siglip2Attention):
  355. init.xavier_uniform_(module.q_proj.weight)
  356. init.xavier_uniform_(module.k_proj.weight)
  357. init.xavier_uniform_(module.v_proj.weight)
  358. init.xavier_uniform_(module.out_proj.weight)
  359. init.zeros_(module.q_proj.bias)
  360. init.zeros_(module.k_proj.bias)
  361. init.zeros_(module.v_proj.bias)
  362. init.zeros_(module.out_proj.bias)
  363. elif isinstance(module, Siglip2MLP):
  364. init.xavier_uniform_(module.fc1.weight)
  365. init.xavier_uniform_(module.fc2.weight)
  366. init.normal_(module.fc1.bias, std=1e-6)
  367. init.normal_(module.fc2.bias, std=1e-6)
  368. elif isinstance(module, Siglip2MultiheadAttentionPoolingHead):
  369. init.xavier_uniform_(module.probe)
  370. init.xavier_uniform_(module.attention.in_proj_weight)
  371. init.zeros_(module.attention.in_proj_bias)
  372. elif isinstance(module, Siglip2Model):
  373. init.zeros_(module.logit_scale)
  374. init.zeros_(module.logit_bias)
  375. elif isinstance(module, Siglip2ForImageClassification):
  376. init.normal_(
  377. module.classifier.weight,
  378. std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor,
  379. )
  380. elif isinstance(module, (nn.Linear, nn.Conv2d)):
  381. init.lecun_normal_(module.weight)
  382. if module.bias is not None:
  383. init.zeros_(module.bias)
  384. elif isinstance(module, nn.LayerNorm):
  385. init.zeros_(module.bias)
  386. init.ones_(module.weight)
  387. elif isinstance(module, Siglip2TextEmbeddings):
  388. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  389. class Siglip2Encoder(nn.Module):
  390. """
  391. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  392. [`Siglip2EncoderLayer`].
  393. Args:
  394. config: Siglip2Config
  395. """
  396. def __init__(self, config: Siglip2Config):
  397. super().__init__()
  398. self.config = config
  399. self.layers = nn.ModuleList([Siglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
  400. self.gradient_checkpointing = False
  401. # Ignore copy
  402. @auto_docstring
  403. def forward(
  404. self,
  405. inputs_embeds,
  406. attention_mask: torch.Tensor | None = None,
  407. **kwargs: Unpack[TransformersKwargs],
  408. ) -> BaseModelOutput:
  409. hidden_states = inputs_embeds
  410. for encoder_layer in self.layers:
  411. hidden_states = encoder_layer(
  412. hidden_states,
  413. attention_mask,
  414. **kwargs,
  415. )
  416. return BaseModelOutput(last_hidden_state=hidden_states)
  417. class Siglip2VisionTransformer(Siglip2PreTrainedModel):
  418. _input_embed_layer = "patch_embedding"
  419. def __init__(self, config: Siglip2VisionConfig):
  420. super().__init__(config)
  421. self.config = config
  422. embed_dim = config.hidden_size
  423. self.embeddings = Siglip2VisionEmbeddings(config)
  424. self.encoder = Siglip2Encoder(config)
  425. self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  426. self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head
  427. if self.use_head:
  428. self.head = Siglip2MultiheadAttentionPoolingHead(config)
  429. self.post_init()
  430. @merge_with_config_defaults
  431. @capture_outputs(tie_last_hidden_states=False)
  432. @auto_docstring
  433. def forward(
  434. self,
  435. pixel_values: torch.FloatTensor,
  436. attention_mask: torch.Tensor,
  437. spatial_shapes: torch.LongTensor,
  438. **kwargs: Unpack[TransformersKwargs],
  439. ) -> BaseModelOutputWithPooling:
  440. r"""
  441. spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
  442. Tensor containing the spatial dimensions (height, width) of the input images.
  443. """
  444. hidden_states = self.embeddings(pixel_values, spatial_shapes)
  445. encoder_attention_mask = create_bidirectional_mask(
  446. config=self.config,
  447. inputs_embeds=hidden_states,
  448. attention_mask=attention_mask,
  449. )
  450. encoder_outputs: BaseModelOutput = self.encoder(
  451. inputs_embeds=hidden_states,
  452. attention_mask=encoder_attention_mask,
  453. **kwargs,
  454. )
  455. last_hidden_state = encoder_outputs.last_hidden_state
  456. last_hidden_state = self.post_layernorm(last_hidden_state)
  457. pooler_output = self.head(last_hidden_state, attention_mask) if self.use_head else None
  458. return BaseModelOutputWithPooling(
  459. last_hidden_state=last_hidden_state,
  460. pooler_output=pooler_output,
  461. )
  462. class Siglip2TextTransformer(Siglip2PreTrainedModel):
  463. _input_embed_layer = "token_embedding"
  464. def __init__(self, config: Siglip2TextConfig):
  465. super().__init__(config)
  466. self.config = config
  467. embed_dim = config.hidden_size
  468. self.embeddings = Siglip2TextEmbeddings(config)
  469. self.encoder = Siglip2Encoder(config)
  470. self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  471. self.head = nn.Linear(embed_dim, config.projection_size)
  472. self.post_init()
  473. @can_return_tuple
  474. @auto_docstring
  475. def forward(
  476. self,
  477. input_ids: torch.Tensor | None = None,
  478. attention_mask: torch.Tensor | None = None,
  479. position_ids: torch.Tensor | None = None,
  480. **kwargs: Unpack[TransformersKwargs],
  481. ) -> BaseModelOutputWithPooling:
  482. if input_ids is None:
  483. raise ValueError("You have to specify input_ids")
  484. input_shape = input_ids.size()
  485. input_ids = input_ids.view(-1, input_shape[-1])
  486. hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
  487. # note: Siglip2's text model does not use a causal mask, unlike the original CLIP model.
  488. attention_mask = create_bidirectional_mask(
  489. config=self.config,
  490. inputs_embeds=hidden_states,
  491. attention_mask=attention_mask,
  492. )
  493. encoder_outputs: BaseModelOutput = self.encoder(
  494. inputs_embeds=hidden_states,
  495. attention_mask=attention_mask,
  496. **kwargs,
  497. )
  498. last_hidden_state = encoder_outputs.last_hidden_state
  499. last_hidden_state = self.final_layer_norm(last_hidden_state)
  500. # The model uses the last token's hidden state, which may be padding.
  501. pooled_output = last_hidden_state[:, -1, :]
  502. pooled_output = self.head(pooled_output)
  503. return BaseModelOutputWithPooling(
  504. last_hidden_state=last_hidden_state,
  505. pooler_output=pooled_output,
  506. )
  507. @auto_docstring(
  508. custom_intro="""
  509. The text model from Siglip2 without any head or projection on top.
  510. """
  511. )
  512. class Siglip2TextModel(Siglip2PreTrainedModel):
  513. config: Siglip2TextConfig
  514. input_modalities = ("text",)
  515. def __init__(self, config: Siglip2TextConfig):
  516. super().__init__(config)
  517. self.text_model = Siglip2TextTransformer(config)
  518. # Initialize weights and apply final processing
  519. self.post_init()
  520. def get_input_embeddings(self) -> nn.Module:
  521. return self.text_model.embeddings.token_embedding
  522. def set_input_embeddings(self, value):
  523. self.text_model.embeddings.token_embedding = value
  524. @merge_with_config_defaults
  525. @capture_outputs(tie_last_hidden_states=False)
  526. @auto_docstring
  527. def forward(
  528. self,
  529. input_ids: torch.Tensor | None = None,
  530. attention_mask: torch.Tensor | None = None,
  531. position_ids: torch.Tensor | None = None,
  532. **kwargs: Unpack[TransformersKwargs],
  533. ) -> BaseModelOutputWithPooling:
  534. r"""
  535. Examples:
  536. ```python
  537. >>> from transformers import AutoTokenizer, Siglip2TextModel
  538. >>> model = Siglip2TextModel.from_pretrained("google/siglip2-base-patch16-224")
  539. >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip2-base-patch16-224")
  540. >>> # important: make sure to set padding="max_length" as that's how the model was trained
  541. >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
  542. >>> outputs = model(**inputs)
  543. >>> last_hidden_state = outputs.last_hidden_state
  544. >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
  545. ```"""
  546. return self.text_model(
  547. input_ids=input_ids,
  548. attention_mask=attention_mask,
  549. position_ids=position_ids,
  550. **kwargs,
  551. )
  552. class Siglip2MultiheadAttentionPoolingHead(nn.Module):
  553. """Multihead Attention Pooling."""
  554. def __init__(self, config: Siglip2VisionConfig):
  555. super().__init__()
  556. self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
  557. self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
  558. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  559. self.mlp = Siglip2MLP(config)
  560. self.config = config
  561. self.num_heads = config.num_attention_heads
  562. def forward(self, hidden_state: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor:
  563. batch_size = hidden_state.shape[0]
  564. probe = self.probe.repeat(batch_size, 1, 1)
  565. if attention_mask is not None:
  566. target_len, source_len = probe.shape[1], hidden_state.shape[1]
  567. attention_mask = create_bidirectional_mask(
  568. config=self.config,
  569. inputs_embeds=probe,
  570. attention_mask=attention_mask,
  571. encoder_hidden_states=hidden_state,
  572. )
  573. if attention_mask is not None:
  574. attention_mask = attention_mask.repeat(1, self.num_heads, target_len, 1)
  575. attention_mask = attention_mask.reshape(-1, target_len, source_len)
  576. # `nn.MultiheadAttention` cannot handle boolean masks (which SDPA can)
  577. if attention_mask.dtype == torch.bool:
  578. attention_mask = torch.where(
  579. attention_mask,
  580. torch.tensor(0.0, device=attention_mask.device, dtype=probe.dtype),
  581. torch.finfo(probe.dtype).min,
  582. )
  583. hidden_state = self.attention(probe, hidden_state, hidden_state, attn_mask=attention_mask)[0]
  584. residual = hidden_state
  585. hidden_state = self.layernorm(hidden_state)
  586. hidden_state = residual + self.mlp(hidden_state)
  587. return hidden_state[:, 0]
  588. @auto_docstring(
  589. custom_intro="""
  590. The vision model from Siglip2 without any head or projection on top.
  591. """
  592. )
  593. class Siglip2VisionModel(Siglip2PreTrainedModel):
  594. config: Siglip2VisionConfig
  595. main_input_name = "pixel_values"
  596. input_modalities = ("image",)
  597. def __init__(self, config: Siglip2VisionConfig):
  598. super().__init__(config)
  599. self.vision_model = Siglip2VisionTransformer(config)
  600. # Initialize weights and apply final processing
  601. self.post_init()
  602. def get_input_embeddings(self) -> nn.Module:
  603. return self.vision_model.embeddings.patch_embedding
  604. @can_return_tuple
  605. @auto_docstring
  606. def forward(
  607. self,
  608. pixel_values: torch.FloatTensor,
  609. pixel_attention_mask: torch.Tensor,
  610. spatial_shapes: torch.LongTensor,
  611. **kwargs: Unpack[TransformersKwargs],
  612. ) -> BaseModelOutputWithPooling:
  613. r"""
  614. pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
  615. Mask to avoid performing attention on padding pixel indices.
  616. spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
  617. Tensor containing the spatial dimensions (height, width) of the input images.
  618. Examples:
  619. ```python
  620. >>> from PIL import Image
  621. >>> import httpx
  622. >>> from io import BytesIO
  623. >>> from transformers import AutoProcessor, Siglip2VisionModel
  624. >>> model = Siglip2VisionModel.from_pretrained("google/siglip2-base-patch16-224")
  625. >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")
  626. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  627. >>> with httpx.stream("GET", url) as response:
  628. ... image = Image.open(BytesIO(response.read()))
  629. >>> inputs = processor(images=image, return_tensors="pt")
  630. >>> outputs = model(**inputs)
  631. >>> last_hidden_state = outputs.last_hidden_state
  632. >>> pooled_output = outputs.pooler_output # pooled features
  633. ```"""
  634. return self.vision_model(
  635. pixel_values=pixel_values,
  636. attention_mask=pixel_attention_mask,
  637. spatial_shapes=spatial_shapes,
  638. **kwargs,
  639. )
  640. @auto_docstring
  641. class Siglip2Model(Siglip2PreTrainedModel):
  642. config: Siglip2Config
  643. def __init__(self, config: Siglip2Config):
  644. super().__init__(config)
  645. if not isinstance(config.text_config, Siglip2TextConfig):
  646. raise TypeError(
  647. "config.text_config is expected to be of type Siglip2TextConfig but is of type"
  648. f" {type(config.text_config)}."
  649. )
  650. if not isinstance(config.vision_config, Siglip2VisionConfig):
  651. raise TypeError(
  652. "config.vision_config is expected to be of type Siglip2VisionConfig but is of type"
  653. f" {type(config.vision_config)}."
  654. )
  655. text_config = config.text_config
  656. vision_config = config.vision_config
  657. # First, initialize the text and vision models with proper attention implementation
  658. text_model = Siglip2TextModel._from_config(text_config)
  659. vision_model = Siglip2VisionModel._from_config(vision_config)
  660. # Second, get the text and vision submodules (for backward compatibility)
  661. self.text_model = text_model.text_model
  662. self.vision_model = vision_model.vision_model
  663. self.logit_scale = nn.Parameter(torch.randn(1))
  664. self.logit_bias = nn.Parameter(torch.randn(1))
  665. # Initialize weights and apply final processing
  666. self.post_init()
  667. def get_input_embeddings(self) -> nn.Module:
  668. return self.text_model.embeddings.token_embedding
  669. def set_input_embeddings(self, value: nn.Module):
  670. self.text_model.embeddings.token_embedding = value
  671. @can_return_tuple
  672. @auto_docstring
  673. def get_text_features(
  674. self,
  675. input_ids: torch.Tensor,
  676. attention_mask: torch.Tensor | None = None,
  677. position_ids: torch.Tensor | None = None,
  678. **kwargs: Unpack[TransformersKwargs],
  679. ) -> tuple | BaseModelOutputWithPooling:
  680. r"""
  681. Examples:
  682. ```python
  683. >>> from transformers import AutoTokenizer, AutoModel
  684. >>> import torch
  685. >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224")
  686. >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip2-base-patch16-224")
  687. >>> # important: make sure to set padding="max_length" as that's how the model was trained
  688. >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
  689. >>> with torch.no_grad():
  690. ... text_features = model.get_text_features(**inputs)
  691. ```"""
  692. return self.text_model(
  693. input_ids=input_ids,
  694. attention_mask=attention_mask,
  695. position_ids=position_ids,
  696. **kwargs,
  697. )
  698. @can_return_tuple
  699. @auto_docstring
  700. def get_image_features(
  701. self,
  702. pixel_values: torch.FloatTensor | None = None,
  703. pixel_attention_mask: torch.Tensor | None = None,
  704. spatial_shapes: torch.LongTensor | None = None,
  705. **kwargs: Unpack[TransformersKwargs],
  706. ) -> tuple | BaseModelOutputWithPooling:
  707. r"""
  708. pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
  709. Mask to avoid performing attention on padding pixel indices.
  710. spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
  711. Tensor containing the spatial dimensions (height, width) of the input images.
  712. Examples:
  713. ```python
  714. >>> import torch
  715. >>> from transformers import AutoProcessor, AutoModel
  716. >>> from transformers.image_utils import load_image
  717. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  718. >>> image = load_image(url)
  719. >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224")
  720. >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")
  721. >>> inputs = processor(images=image, return_tensors="pt")
  722. >>> with torch.no_grad():
  723. ... image_features = model.get_image_features(**inputs)
  724. ```
  725. """
  726. return self.vision_model(
  727. pixel_values=pixel_values,
  728. attention_mask=pixel_attention_mask,
  729. spatial_shapes=spatial_shapes,
  730. **kwargs,
  731. )
  732. # NOTE: Siglip2Model uses Pretrained backbones, so we don't need to add `capture_outputs` here
  733. @can_return_tuple
  734. @auto_docstring
  735. def forward(
  736. self,
  737. input_ids: torch.LongTensor | None = None,
  738. pixel_values: torch.FloatTensor | None = None,
  739. pixel_attention_mask: torch.Tensor | None = None,
  740. spatial_shapes: torch.LongTensor | None = None,
  741. attention_mask: torch.Tensor | None = None,
  742. position_ids: torch.LongTensor | None = None,
  743. return_loss: bool | None = None,
  744. **kwargs: Unpack[TransformersKwargs],
  745. ) -> Siglip2Output:
  746. r"""
  747. pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
  748. Mask to avoid performing attention on padding pixel indices.
  749. spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
  750. Tensor containing the spatial dimensions (height, width) of the input images.
  751. return_loss (`bool`, *optional*):
  752. Whether or not to return the contrastive loss.
  753. Examples:
  754. ```python
  755. >>> from PIL import Image
  756. >>> import httpx
  757. >>> from io import BytesIO
  758. >>> from transformers import AutoProcessor, AutoModel
  759. >>> import torch
  760. >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224")
  761. >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")
  762. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  763. >>> with httpx.stream("GET", url) as response:
  764. ... image = Image.open(BytesIO(response.read()))
  765. >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
  766. >>> # important: we pass `padding=max_length` since the model was trained with this
  767. >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")
  768. >>> with torch.no_grad():
  769. ... outputs = model(**inputs)
  770. >>> logits_per_image = outputs.logits_per_image
  771. >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
  772. >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
  773. 31.9% that image 0 is 'a photo of 2 cats'
  774. ```
  775. """
  776. vision_outputs: BaseModelOutputWithPooling = self.vision_model(
  777. pixel_values=pixel_values,
  778. attention_mask=pixel_attention_mask,
  779. spatial_shapes=spatial_shapes,
  780. **kwargs,
  781. )
  782. text_outputs: BaseModelOutputWithPooling = self.text_model(
  783. input_ids=input_ids,
  784. attention_mask=attention_mask,
  785. position_ids=position_ids,
  786. **kwargs,
  787. )
  788. image_embeds = vision_outputs.pooler_output
  789. text_embeds = text_outputs.pooler_output
  790. # normalized features
  791. image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
  792. text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
  793. # cosine similarity as logits
  794. logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device))
  795. logit_scale, logit_bias = self.logit_scale.to(text_embeds.device), self.logit_bias.to(text_embeds.device)
  796. logits_per_text = logits_per_text * logit_scale.exp() + logit_bias
  797. logits_per_image = logits_per_text.t()
  798. loss = None
  799. if return_loss:
  800. # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip2.py#L287
  801. eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device)
  802. m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye
  803. loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text)
  804. nll = -torch.sum(loglik, dim=-1)
  805. loss = nll.mean()
  806. return Siglip2Output(
  807. loss=loss,
  808. logits_per_image=logits_per_image,
  809. logits_per_text=logits_per_text,
  810. text_embeds=text_embeds,
  811. image_embeds=image_embeds,
  812. text_model_output=text_outputs,
  813. vision_model_output=vision_outputs,
  814. )
  815. @auto_docstring(
  816. custom_intro="""
  817. Siglip2 vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of
  818. the patch tokens) e.g. for ImageNet.
  819. """
  820. )
  821. class Siglip2ForImageClassification(Siglip2PreTrainedModel):
  822. main_input_name = "pixel_values"
  823. input_modalities = ("image",)
  824. def __init__(self, config: Siglip2Config) -> None:
  825. super().__init__(config)
  826. self.num_labels = config.num_labels
  827. # Create the vision model with proper attention
  828. # and take only vision_model submodule (for backward compatibility)
  829. vision_model = Siglip2VisionModel._from_config(config.vision_config)
  830. self.vision_model = vision_model.vision_model
  831. # Classifier head
  832. self.classifier = (
  833. nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
  834. )
  835. # Initialize weights and apply final processing
  836. self.post_init()
  837. def get_input_embeddings(self) -> nn.Module:
  838. return self.vision_model.embeddings.patch_embedding
  839. def set_input_embeddings(self, value: nn.Module):
  840. self.vision_model.embeddings.patch_embedding = value
  841. @can_return_tuple
  842. @auto_docstring
  843. def forward(
  844. self,
  845. pixel_values: torch.Tensor | None = None,
  846. pixel_attention_mask: torch.Tensor | None = None,
  847. spatial_shapes: torch.LongTensor | None = None,
  848. labels: torch.Tensor | None = None,
  849. **kwargs: Unpack[TransformersKwargs],
  850. ) -> ImageClassifierOutput:
  851. r"""
  852. pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
  853. Mask to avoid performing attention on padding pixel indices.
  854. spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
  855. Tensor containing the spatial dimensions (height, width) of the input images.
  856. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  857. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  858. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  859. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  860. Examples:
  861. ```python
  862. >>> from transformers import AutoImageProcessor, Siglip2ForImageClassification
  863. >>> import torch
  864. >>> from PIL import Image
  865. >>> import httpx
  866. >>> from io import BytesIO
  867. >>> torch.manual_seed(3) # doctest: +IGNORE_RESULT
  868. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  869. >>> with httpx.stream("GET", url) as response:
  870. ... image = Image.open(BytesIO(response.read()))
  871. >>> # note: we are loading a `Siglip2Model` from the hub here,
  872. >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above.
  873. >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip2-base-patch16-224")
  874. >>> model = Siglip2ForImageClassification.from_pretrained("google/siglip2-base-patch16-224")
  875. >>> inputs = image_processor(images=image, return_tensors="pt")
  876. >>> outputs = model(**inputs)
  877. >>> logits = outputs.logits
  878. >>> # model predicts one of the two classes
  879. >>> predicted_class_idx = logits.argmax(-1).item()
  880. >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
  881. Predicted class: LABEL_1
  882. ```
  883. """
  884. outputs: BaseModelOutputWithPooling = self.vision_model(
  885. pixel_values,
  886. attention_mask=pixel_attention_mask,
  887. spatial_shapes=spatial_shapes,
  888. **kwargs,
  889. )
  890. sequence_output = outputs.last_hidden_state
  891. # average pool the patch tokens
  892. if pixel_attention_mask is not None:
  893. pool_mask = pixel_attention_mask[..., None].to(sequence_output.device)
  894. sequence_output = torch.sum(sequence_output * pool_mask, dim=1) / torch.sum(pool_mask, dim=1)
  895. else:
  896. sequence_output = torch.mean(sequence_output, dim=1)
  897. # apply classifier
  898. logits = self.classifier(sequence_output)
  899. loss = None
  900. if labels is not None:
  901. loss = self.loss_function(labels, logits, self.config)
  902. return ImageClassifierOutput(
  903. loss=loss,
  904. logits=logits,
  905. hidden_states=outputs.hidden_states,
  906. attentions=outputs.attentions,
  907. )
# Public API of this module: the names exported by a star-import.
__all__ = [
    "Siglip2Model",
    "Siglip2PreTrainedModel",
    "Siglip2TextModel",
    "Siglip2VisionModel",
    "Siglip2ForImageClassification",
]