modular_siglip2.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616
  1. # Copyright 2025 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import torch
  15. import torch.nn as nn
  16. import torch.nn.functional as F
  17. from huggingface_hub.dataclasses import strict
  18. from tokenizers import normalizers
  19. from transformers.models.gemma.tokenization_gemma import GemmaTokenizer
  20. from transformers.models.siglip.configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
  21. from transformers.models.siglip.modeling_siglip import (
  22. BaseModelOutput,
  23. BaseModelOutputWithPooling,
  24. ImageClassifierOutput,
  25. SiglipForImageClassification,
  26. SiglipModel,
  27. SiglipMultiheadAttentionPoolingHead,
  28. SiglipOutput,
  29. SiglipPreTrainedModel,
  30. SiglipTextModel,
  31. SiglipTextModelOutput,
  32. SiglipVisionModel,
  33. SiglipVisionModelOutput,
  34. SiglipVisionTransformer,
  35. )
  36. from ...masking_utils import create_bidirectional_mask
  37. from ...processing_utils import Unpack
  38. from ...utils import (
  39. TransformersKwargs,
  40. auto_docstring,
  41. torch_compilable_check,
  42. )
  43. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  44. from ...utils.output_capturing import capture_outputs
  45. class Siglip2Tokenizer(GemmaTokenizer):
  46. """
  47. Gemma tokenizer + SigLIP2 training default: lowercase normalization.
  48. """
  49. def __init__(
  50. self,
  51. vocab: str | dict[str, int] | None = None,
  52. merges: str | list[str] | None = None,
  53. unk_token: str = "<unk>",
  54. bos_token: str = "<bos>",
  55. eos_token: str = "<eos>",
  56. pad_token: str = "<pad>",
  57. mask_token: str = "<mask>",
  58. **kwargs,
  59. ):
  60. super().__init__(
  61. vocab=vocab,
  62. merges=merges,
  63. unk_token=unk_token,
  64. bos_token=bos_token,
  65. eos_token=eos_token,
  66. pad_token=pad_token,
  67. mask_token=mask_token,
  68. **kwargs,
  69. )
  70. # Persist for save/load + push_to_hub dynamic tokenizer test
  71. if hasattr(self, "init_kwargs") and isinstance(self.init_kwargs, dict):
  72. self.init_kwargs.setdefault("tokenizer_class", self.__class__.__name__)
  73. backend = getattr(self, "_tokenizer", None)
  74. if backend is not None and backend.normalizer is not None:
  75. backend.normalizer = normalizers.Sequence([normalizers.Lowercase(), backend.normalizer])
@auto_docstring(checkpoint="google/siglip2-base-patch16-naflex")
@strict
class Siglip2TextConfig(SiglipTextConfig):
    # SigLIP2 text configuration: declares no fields of its own, everything comes from `SiglipTextConfig`.
    pass
@auto_docstring(checkpoint="google/siglip2-base-patch16-naflex")
@strict
class Siglip2VisionConfig(SiglipVisionConfig):
    r"""
    num_patches (`int`, *optional*, defaults to 256):
        The number of patches in the image with the size of (`patch_size`, `patch_size`).
        The image is resized to fill maximum of this number of patches, and to preserve
        the aspect ratio. In case the resulted number of patches is lower, the image is
        padded in "patch" dimension.

    Example:

    ```python
    >>> from transformers import Siglip2VisionConfig, Siglip2VisionModel

    >>> # Initializing a Siglip2VisionConfig with google/siglip2-base-patch16-naflex style configuration
    >>> configuration = Siglip2VisionConfig()

    >>> # Initializing a Siglip2VisionModel (with random weights) from the google/siglip2-base-patch16-naflex style configuration
    >>> model = Siglip2VisionModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    # Read by `Siglip2VisionEmbeddings` as the number of rows of its learned position-embedding table.
    num_patches: int = 256
    # NOTE(review): presumably a marker telling the modular converter to drop the inherited
    # `image_size` field from the generated config - confirm against the converter's conventions.
    image_size = AttributeError()
@auto_docstring(checkpoint="google/siglip2-base-patch16-naflex")
@strict
class Siglip2Config(SiglipConfig):
    # Composite SigLIP2 configuration: declares nothing of its own, everything comes from `SiglipConfig`.
    pass
class Siglip2VisionOutput(SiglipVisionModelOutput):
    # SigLIP2-named subclass of `SiglipVisionModelOutput`; adds no fields.
    pass
class Siglip2TextOutput(SiglipTextModelOutput):
    # SigLIP2-named subclass of `SiglipTextModelOutput`; adds no fields.
    pass
class Siglip2Output(SiglipOutput):
    # SigLIP2-named subclass of `SiglipOutput`; this is the type `Siglip2Model.forward` returns.
    pass
  111. class Siglip2VisionEmbeddings(nn.Module):
  112. def __init__(self, config: Siglip2VisionConfig):
  113. super().__init__()
  114. self.config = config
  115. self.embed_dim = config.hidden_size
  116. self.patch_size = config.patch_size
  117. self.patch_embedding = nn.Linear(
  118. in_features=config.num_channels * self.patch_size * self.patch_size,
  119. out_features=self.embed_dim,
  120. )
  121. self.num_patches = config.num_patches
  122. self.position_embedding_size = int(self.num_patches**0.5)
  123. self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)
  124. @staticmethod
  125. def resize_positional_embeddings(
  126. positional_embeddings: torch.Tensor,
  127. spatial_shapes: torch.LongTensor,
  128. max_length: int,
  129. ) -> torch.Tensor:
  130. """
  131. Resize positional embeddings to image-specific size and pad to a fixed size.
  132. Args:
  133. positional_embeddings (`torch.Tensor`):
  134. Position embeddings of shape (height, width, embed_dim)
  135. spatial_shapes (`torch.LongTensor`):
  136. Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
  137. max_length (`int`):
  138. Maximum length of the positional embeddings to pad resized positional embeddings to
  139. Returns:
  140. `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim)
  141. """
  142. batch_size = spatial_shapes.shape[0]
  143. embed_dim = positional_embeddings.shape[-1]
  144. source_dtype = positional_embeddings.dtype
  145. resulted_positional_embeddings = torch.empty(
  146. (batch_size, max_length, embed_dim),
  147. device=positional_embeddings.device,
  148. dtype=source_dtype,
  149. )
  150. # (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation
  151. positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0)
  152. # Upcast to float32 on CPU because antialias is not supported for bfloat16/float16 on CPU
  153. if positional_embeddings.device.type == "cpu":
  154. positional_embeddings = positional_embeddings.to(torch.float32)
  155. for i in range(batch_size):
  156. # (1, dim, height, width) -> (1, dim, target_height, target_width)
  157. height, width = spatial_shapes[i].tolist() # will be itemized in F.interpolate either way
  158. torch_compilable_check((width > 0), "Width of resized positional embeddings must be positive.")
  159. torch_compilable_check((height > 0), "Height of resized positional embeddings must be positive.")
  160. torch_compilable_check((height * width) <= max_length, "Resized positional embeddings exceed max_length.")
  161. resized_embeddings = F.interpolate(
  162. positional_embeddings,
  163. size=(height, width),
  164. mode="bilinear",
  165. align_corners=False,
  166. antialias=True,
  167. )
  168. # (1, dim, target_height, target_width) -> (target_height * target_width, dim)
  169. resized_embeddings = resized_embeddings.reshape(embed_dim, height * width).transpose(0, 1)
  170. # Cast to original dtype
  171. resized_embeddings = resized_embeddings.to(source_dtype)
  172. resulted_positional_embeddings[i, : height * width] = resized_embeddings
  173. resulted_positional_embeddings[i, height * width :] = resized_embeddings[0]
  174. return resulted_positional_embeddings
  175. def forward(self, pixel_values: torch.FloatTensor, spatial_shapes: torch.LongTensor) -> torch.Tensor:
  176. """
  177. Args:
  178. pixel_values (`torch.FloatTensor`):
  179. Pixel values of shape (batch_size, max_num_patches, num_channels * patch_size * patch_size)
  180. spatial_shapes (`list[tuple[int, int]]`):
  181. Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
  182. """
  183. # Apply patch embeddings to already patchified pixel values
  184. target_dtype = self.patch_embedding.weight.dtype
  185. patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
  186. # Get positional resized and padded positional embeddings
  187. positional_embeddings = self.position_embedding.weight.reshape(
  188. self.position_embedding_size, self.position_embedding_size, -1
  189. )
  190. resized_positional_embeddings = self.resize_positional_embeddings(
  191. positional_embeddings, spatial_shapes, max_length=pixel_values.shape[1]
  192. )
  193. # Add positional embeddings to patch embeddings
  194. embeddings = patch_embeds + resized_positional_embeddings
  195. return embeddings
class Siglip2PreTrainedModel(SiglipPreTrainedModel):
    # Flex attention and flash attention are switched off for every SigLIP2 model.
    # Original rationale: nn.MultiHeadAttention mask doesn't allow for non 4d mask
    # NOTE(review): this presumably refers to the attention-pooling head's mask handling - confirm.
    _supports_flex_attn = False
    _supports_flash_attn = False
  200. class Siglip2VisionTransformer(SiglipVisionTransformer):
  201. def __init__(self, config: Siglip2VisionConfig):
  202. super().__init__(config)
  203. @merge_with_config_defaults
  204. @capture_outputs(tie_last_hidden_states=False)
  205. @auto_docstring
  206. def forward(
  207. self,
  208. pixel_values: torch.FloatTensor,
  209. attention_mask: torch.Tensor,
  210. spatial_shapes: torch.LongTensor,
  211. **kwargs: Unpack[TransformersKwargs],
  212. ) -> BaseModelOutputWithPooling:
  213. r"""
  214. spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
  215. Tensor containing the spatial dimensions (height, width) of the input images.
  216. """
  217. hidden_states = self.embeddings(pixel_values, spatial_shapes)
  218. encoder_attention_mask = create_bidirectional_mask(
  219. config=self.config,
  220. inputs_embeds=hidden_states,
  221. attention_mask=attention_mask,
  222. )
  223. encoder_outputs: BaseModelOutput = self.encoder(
  224. inputs_embeds=hidden_states,
  225. attention_mask=encoder_attention_mask,
  226. **kwargs,
  227. )
  228. last_hidden_state = encoder_outputs.last_hidden_state
  229. last_hidden_state = self.post_layernorm(last_hidden_state)
  230. pooler_output = self.head(last_hidden_state, attention_mask) if self.use_head else None
  231. return BaseModelOutputWithPooling(
  232. last_hidden_state=last_hidden_state,
  233. pooler_output=pooler_output,
  234. )
class Siglip2TextModel(SiglipTextModel):
    # SigLIP2-named subclass of `SiglipTextModel`; overrides nothing.
    pass
  237. class Siglip2MultiheadAttentionPoolingHead(SiglipMultiheadAttentionPoolingHead):
  238. def __init__(self, config: Siglip2VisionConfig):
  239. super().__init__(config)
  240. self.config = config
  241. self.num_heads = config.num_attention_heads
  242. def forward(self, hidden_state: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor:
  243. batch_size = hidden_state.shape[0]
  244. probe = self.probe.repeat(batch_size, 1, 1)
  245. if attention_mask is not None:
  246. target_len, source_len = probe.shape[1], hidden_state.shape[1]
  247. attention_mask = create_bidirectional_mask(
  248. config=self.config,
  249. inputs_embeds=probe,
  250. attention_mask=attention_mask,
  251. encoder_hidden_states=hidden_state,
  252. )
  253. if attention_mask is not None:
  254. attention_mask = attention_mask.repeat(1, self.num_heads, target_len, 1)
  255. attention_mask = attention_mask.reshape(-1, target_len, source_len)
  256. # `nn.MultiheadAttention` cannot handle boolean masks (which SDPA can)
  257. if attention_mask.dtype == torch.bool:
  258. attention_mask = torch.where(
  259. attention_mask,
  260. torch.tensor(0.0, device=attention_mask.device, dtype=probe.dtype),
  261. torch.finfo(probe.dtype).min,
  262. )
  263. hidden_state = self.attention(probe, hidden_state, hidden_state, attn_mask=attention_mask)[0]
  264. residual = hidden_state
  265. hidden_state = self.layernorm(hidden_state)
  266. hidden_state = residual + self.mlp(hidden_state)
  267. return hidden_state[:, 0]
class Siglip2VisionModel(SiglipVisionModel):
    # Stand-alone vision model; `forward` only renames `pixel_attention_mask` to the
    # `attention_mask` argument of the wrapped `self.vision_model`.
    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        pixel_attention_mask: torch.Tensor,
        spatial_shapes: torch.LongTensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPooling:
        r"""
        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, max_num_patches)`, *optional*):
            Mask to avoid performing attention on padding patches.
        spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
            Tensor containing the spatial dimensions (height, width) of the input images.

        Examples:

        ```python
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO
        >>> from transformers import AutoProcessor, Siglip2VisionModel

        >>> model = Siglip2VisionModel.from_pretrained("google/siglip2-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled features
        ```"""
        # Pure delegation: all computation happens in the wrapped vision transformer.
        return self.vision_model(
            pixel_values=pixel_values,
            attention_mask=pixel_attention_mask,
            spatial_shapes=spatial_shapes,
            **kwargs,
        )
  305. class Siglip2Model(SiglipModel):
  306. # Update: add `spatial_shapes` and `pixel_attention_mask`
  307. @can_return_tuple
  308. @auto_docstring
  309. def get_image_features(
  310. self,
  311. pixel_values: torch.FloatTensor | None = None,
  312. pixel_attention_mask: torch.Tensor | None = None,
  313. spatial_shapes: torch.LongTensor | None = None,
  314. **kwargs: Unpack[TransformersKwargs],
  315. ) -> tuple | BaseModelOutputWithPooling:
  316. r"""
  317. pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
  318. Mask to avoid performing attention on padding pixel indices.
  319. spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
  320. Tensor containing the spatial dimensions (height, width) of the input images.
  321. Examples:
  322. ```python
  323. >>> import torch
  324. >>> from transformers import AutoProcessor, AutoModel
  325. >>> from transformers.image_utils import load_image
  326. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  327. >>> image = load_image(url)
  328. >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224")
  329. >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")
  330. >>> inputs = processor(images=image, return_tensors="pt")
  331. >>> with torch.no_grad():
  332. ... image_features = model.get_image_features(**inputs)
  333. ```
  334. """
  335. return self.vision_model(
  336. pixel_values=pixel_values,
  337. attention_mask=pixel_attention_mask,
  338. spatial_shapes=spatial_shapes,
  339. **kwargs,
  340. )
  341. # Update: add `spatial_shapes` and `pixel_attention_mask`
  342. @can_return_tuple
  343. @auto_docstring
  344. def forward(
  345. self,
  346. input_ids: torch.LongTensor | None = None,
  347. pixel_values: torch.FloatTensor | None = None,
  348. pixel_attention_mask: torch.Tensor | None = None,
  349. spatial_shapes: torch.LongTensor | None = None,
  350. attention_mask: torch.Tensor | None = None,
  351. position_ids: torch.LongTensor | None = None,
  352. return_loss: bool | None = None,
  353. **kwargs: Unpack[TransformersKwargs],
  354. ) -> Siglip2Output:
  355. r"""
  356. pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
  357. Mask to avoid performing attention on padding pixel indices.
  358. spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
  359. Tensor containing the spatial dimensions (height, width) of the input images.
  360. return_loss (`bool`, *optional*):
  361. Whether or not to return the contrastive loss.
  362. Examples:
  363. ```python
  364. >>> from PIL import Image
  365. >>> import httpx
  366. >>> from io import BytesIO
  367. >>> from transformers import AutoProcessor, AutoModel
  368. >>> import torch
  369. >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224")
  370. >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")
  371. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  372. >>> with httpx.stream("GET", url) as response:
  373. ... image = Image.open(BytesIO(response.read()))
  374. >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
  375. >>> # important: we pass `padding=max_length` since the model was trained with this
  376. >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")
  377. >>> with torch.no_grad():
  378. ... outputs = model(**inputs)
  379. >>> logits_per_image = outputs.logits_per_image
  380. >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
  381. >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
  382. 31.9% that image 0 is 'a photo of 2 cats'
  383. ```
  384. """
  385. vision_outputs: BaseModelOutputWithPooling = self.vision_model(
  386. pixel_values=pixel_values,
  387. attention_mask=pixel_attention_mask,
  388. spatial_shapes=spatial_shapes,
  389. **kwargs,
  390. )
  391. text_outputs: BaseModelOutputWithPooling = self.text_model(
  392. input_ids=input_ids,
  393. attention_mask=attention_mask,
  394. position_ids=position_ids,
  395. **kwargs,
  396. )
  397. image_embeds = vision_outputs.pooler_output
  398. text_embeds = text_outputs.pooler_output
  399. # normalized features
  400. image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
  401. text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
  402. # cosine similarity as logits
  403. logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device))
  404. logit_scale, logit_bias = self.logit_scale.to(text_embeds.device), self.logit_bias.to(text_embeds.device)
  405. logits_per_text = logits_per_text * logit_scale.exp() + logit_bias
  406. logits_per_image = logits_per_text.t()
  407. loss = None
  408. if return_loss:
  409. # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip2.py#L287
  410. eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device)
  411. m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye
  412. loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text)
  413. nll = -torch.sum(loglik, dim=-1)
  414. loss = nll.mean()
  415. return Siglip2Output(
  416. loss=loss,
  417. logits_per_image=logits_per_image,
  418. logits_per_text=logits_per_text,
  419. text_embeds=text_embeds,
  420. image_embeds=image_embeds,
  421. text_model_output=text_outputs,
  422. vision_model_output=vision_outputs,
  423. )
  424. class Siglip2ForImageClassification(SiglipForImageClassification):
  425. # Update: add `spatial_shapes` and `pixel_attention_mask`
  426. @can_return_tuple
  427. @auto_docstring
  428. def forward(
  429. self,
  430. pixel_values: torch.Tensor | None = None,
  431. pixel_attention_mask: torch.Tensor | None = None,
  432. spatial_shapes: torch.LongTensor | None = None,
  433. labels: torch.Tensor | None = None,
  434. **kwargs: Unpack[TransformersKwargs],
  435. ) -> ImageClassifierOutput:
  436. r"""
  437. pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
  438. Mask to avoid performing attention on padding pixel indices.
  439. spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
  440. Tensor containing the spatial dimensions (height, width) of the input images.
  441. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  442. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  443. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  444. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  445. Examples:
  446. ```python
  447. >>> from transformers import AutoImageProcessor, Siglip2ForImageClassification
  448. >>> import torch
  449. >>> from PIL import Image
  450. >>> import httpx
  451. >>> from io import BytesIO
  452. >>> torch.manual_seed(3) # doctest: +IGNORE_RESULT
  453. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  454. >>> with httpx.stream("GET", url) as response:
  455. ... image = Image.open(BytesIO(response.read()))
  456. >>> # note: we are loading a `Siglip2Model` from the hub here,
  457. >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above.
  458. >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip2-base-patch16-224")
  459. >>> model = Siglip2ForImageClassification.from_pretrained("google/siglip2-base-patch16-224")
  460. >>> inputs = image_processor(images=image, return_tensors="pt")
  461. >>> outputs = model(**inputs)
  462. >>> logits = outputs.logits
  463. >>> # model predicts one of the two classes
  464. >>> predicted_class_idx = logits.argmax(-1).item()
  465. >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
  466. Predicted class: LABEL_1
  467. ```
  468. """
  469. outputs: BaseModelOutputWithPooling = self.vision_model(
  470. pixel_values,
  471. attention_mask=pixel_attention_mask,
  472. spatial_shapes=spatial_shapes,
  473. **kwargs,
  474. )
  475. sequence_output = outputs.last_hidden_state
  476. # average pool the patch tokens
  477. if pixel_attention_mask is not None:
  478. pool_mask = pixel_attention_mask[..., None].to(sequence_output.device)
  479. sequence_output = torch.sum(sequence_output * pool_mask, dim=1) / torch.sum(pool_mask, dim=1)
  480. else:
  481. sequence_output = torch.mean(sequence_output, dim=1)
  482. # apply classifier
  483. logits = self.classifier(sequence_output)
  484. loss = None
  485. if labels is not None:
  486. loss = self.loss_function(labels, logits, self.config)
  487. return ImageClassifierOutput(
  488. loss=loss,
  489. logits=logits,
  490. hidden_states=outputs.hidden_states,
  491. attentions=outputs.attentions,
  492. )
# Public names exported by this module (what `from <module> import *` exposes).
__all__ = [
    "Siglip2Config",
    "Siglip2TextConfig",
    "Siglip2VisionConfig",
    "Siglip2Model",
    "Siglip2PreTrainedModel",
    "Siglip2TextModel",
    "Siglip2VisionModel",
    "Siglip2ForImageClassification",
    "Siglip2Tokenizer",
]