modeling_siglip.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014
  1. # Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch Siglip model."""
  15. from collections.abc import Callable
  16. from dataclasses import dataclass
  17. from typing import Any
  18. import numpy as np
  19. import torch
  20. from torch import nn
  21. from ... import initialization as init
  22. from ...activations import ACT2FN
  23. from ...masking_utils import create_bidirectional_mask
  24. from ...modeling_layers import GradientCheckpointingLayer
  25. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
  26. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  27. from ...processing_utils import Unpack
  28. from ...utils import (
  29. ModelOutput,
  30. TransformersKwargs,
  31. auto_docstring,
  32. can_return_tuple,
  33. torch_int,
  34. )
  35. from ...utils.generic import merge_with_config_defaults
  36. from ...utils.output_capturing import capture_outputs
  37. from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
  38. @dataclass
  39. @auto_docstring(
  40. custom_intro="""
  41. Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
  42. """
  43. )
  44. # Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip
  45. class SiglipVisionModelOutput(ModelOutput):
  46. r"""
  47. image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
  48. The image embeddings obtained by applying the projection layer to the pooler_output.
  49. """
  50. image_embeds: torch.FloatTensor | None = None
  51. last_hidden_state: torch.FloatTensor | None = None
  52. hidden_states: tuple[torch.FloatTensor, ...] | None = None
  53. attentions: tuple[torch.FloatTensor, ...] | None = None
  54. @dataclass
  55. @auto_docstring(
  56. custom_intro="""
  57. Base class for text model's outputs that also contains a pooling of the last hidden states.
  58. """
  59. )
  60. # Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip
  61. class SiglipTextModelOutput(ModelOutput):
  62. r"""
  63. text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
  64. The text embeddings obtained by applying the projection layer to the pooler_output.
  65. """
  66. text_embeds: torch.FloatTensor | None = None
  67. last_hidden_state: torch.FloatTensor | None = None
  68. hidden_states: tuple[torch.FloatTensor, ...] | None = None
  69. attentions: tuple[torch.FloatTensor, ...] | None = None
  70. @dataclass
  71. @auto_docstring
  72. # Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip
  73. class SiglipOutput(ModelOutput):
  74. r"""
  75. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
  76. Contrastive loss for image-text similarity.
  77. logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
  78. The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
  79. similarity scores.
  80. logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
  81. The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
  82. similarity scores.
  83. text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
  84. The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
  85. image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
  86. The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`].
  87. text_model_output (`BaseModelOutputWithPooling`):
  88. The output of the [`SiglipTextModel`].
  89. vision_model_output (`BaseModelOutputWithPooling`):
  90. The output of the [`SiglipVisionModel`].
  91. """
  92. loss: torch.FloatTensor | None = None
  93. logits_per_image: torch.FloatTensor | None = None
  94. logits_per_text: torch.FloatTensor | None = None
  95. text_embeds: torch.FloatTensor | None = None
  96. image_embeds: torch.FloatTensor | None = None
  97. text_model_output: BaseModelOutputWithPooling = None
  98. vision_model_output: BaseModelOutputWithPooling = None
  99. def to_tuple(self) -> tuple[Any]:
  100. return tuple(v.to_tuple() if isinstance(v, ModelOutput) else v for v in self.values())
  101. class SiglipVisionEmbeddings(nn.Module):
  102. def __init__(self, config: SiglipVisionConfig):
  103. super().__init__()
  104. self.config = config
  105. self.embed_dim = config.hidden_size
  106. self.image_size = config.image_size
  107. self.patch_size = config.patch_size
  108. self.patch_embedding = nn.Conv2d(
  109. in_channels=config.num_channels,
  110. out_channels=self.embed_dim,
  111. kernel_size=self.patch_size,
  112. stride=self.patch_size,
  113. padding="valid",
  114. )
  115. self.num_patches = (self.image_size // self.patch_size) ** 2
  116. self.num_positions = self.num_patches
  117. self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
  118. self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
  119. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  120. """
  121. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  122. images. This method is also adapted to support torch.jit tracing and no class embeddings.
  123. Adapted from:
  124. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  125. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  126. """
  127. num_patches = embeddings.shape[1]
  128. num_positions = self.position_embedding.weight.shape[0]
  129. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  130. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  131. return self.position_embedding(self.position_ids)
  132. patch_pos_embed = self.position_embedding.weight.unsqueeze(0)
  133. dim = embeddings.shape[-1]
  134. new_height = height // self.patch_size
  135. new_width = width // self.patch_size
  136. sqrt_num_positions = torch_int(num_positions**0.5)
  137. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  138. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  139. patch_pos_embed = nn.functional.interpolate(
  140. patch_pos_embed,
  141. size=(new_height, new_width),
  142. mode="bicubic",
  143. align_corners=False,
  144. )
  145. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  146. return patch_pos_embed
  147. def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
  148. _, _, height, width = pixel_values.shape
  149. target_dtype = self.patch_embedding.weight.dtype
  150. patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
  151. embeddings = patch_embeds.flatten(2).transpose(1, 2)
  152. if interpolate_pos_encoding:
  153. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  154. else:
  155. embeddings = embeddings + self.position_embedding(self.position_ids)
  156. return embeddings
  157. # Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip
  158. class SiglipTextEmbeddings(nn.Module):
  159. def __init__(self, config: SiglipTextConfig):
  160. super().__init__()
  161. embed_dim = config.hidden_size
  162. self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
  163. self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
  164. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  165. self.register_buffer(
  166. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  167. )
  168. def forward(
  169. self,
  170. input_ids: torch.LongTensor | None = None,
  171. position_ids: torch.LongTensor | None = None,
  172. inputs_embeds: torch.FloatTensor | None = None,
  173. ) -> torch.Tensor:
  174. seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
  175. max_position_embedding = self.position_embedding.weight.shape[0]
  176. if seq_length > max_position_embedding:
  177. raise ValueError(
  178. f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
  179. f"{seq_length} and max_position_embeddings: {max_position_embedding}"
  180. )
  181. if position_ids is None:
  182. position_ids = self.position_ids[:, :seq_length]
  183. if inputs_embeds is None:
  184. inputs_embeds = self.token_embedding(input_ids)
  185. position_embeddings = self.position_embedding(position_ids)
  186. embeddings = inputs_embeds + position_embeddings
  187. return embeddings
  188. def eager_attention_forward(
  189. module: nn.Module,
  190. query: torch.Tensor,
  191. key: torch.Tensor,
  192. value: torch.Tensor,
  193. attention_mask: torch.Tensor | None,
  194. scaling: float,
  195. dropout: float = 0.0,
  196. **kwargs,
  197. ):
  198. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  199. if attention_mask is not None:
  200. attn_weights = attn_weights + attention_mask
  201. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  202. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  203. attn_output = torch.matmul(attn_weights, value)
  204. attn_output = attn_output.transpose(1, 2).contiguous()
  205. return attn_output, attn_weights
  206. class SiglipAttention(nn.Module):
  207. """Multi-headed attention from 'Attention Is All You Need' paper"""
  208. def __init__(self, config):
  209. super().__init__()
  210. self.config = config
  211. self.embed_dim = config.hidden_size
  212. self.num_heads = config.num_attention_heads
  213. self.head_dim = self.embed_dim // self.num_heads
  214. if self.head_dim * self.num_heads != self.embed_dim:
  215. raise ValueError(
  216. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  217. f" {self.num_heads})."
  218. )
  219. self.scale = self.head_dim**-0.5
  220. self.dropout = config.attention_dropout
  221. self.is_causal = False
  222. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  223. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  224. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  225. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  226. def forward(
  227. self,
  228. hidden_states: torch.Tensor,
  229. attention_mask: torch.Tensor | None = None,
  230. **kwargs,
  231. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  232. """Input shape: Batch x Time x Channel"""
  233. input_shape = hidden_states.shape[:-1]
  234. hidden_shape = (*input_shape, -1, self.head_dim)
  235. queries = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  236. keys = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  237. values = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  238. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  239. self.config._attn_implementation, eager_attention_forward
  240. )
  241. attn_output, attn_weights = attention_interface(
  242. self,
  243. queries,
  244. keys,
  245. values,
  246. attention_mask,
  247. is_causal=self.is_causal,
  248. scaling=self.scale,
  249. dropout=0.0 if not self.training else self.dropout,
  250. )
  251. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  252. attn_output = self.out_proj(attn_output)
  253. return attn_output, attn_weights
  254. # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
  255. class SiglipMLP(nn.Module):
  256. def __init__(self, config):
  257. super().__init__()
  258. self.config = config
  259. self.activation_fn = ACT2FN[config.hidden_act]
  260. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  261. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  262. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  263. hidden_states = self.fc1(hidden_states)
  264. hidden_states = self.activation_fn(hidden_states)
  265. hidden_states = self.fc2(hidden_states)
  266. return hidden_states
  267. class SiglipEncoderLayer(GradientCheckpointingLayer):
  268. def __init__(self, config: SiglipVisionConfig | SiglipTextConfig):
  269. super().__init__()
  270. self.embed_dim = config.hidden_size
  271. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  272. self.self_attn = SiglipAttention(config)
  273. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  274. self.mlp = SiglipMLP(config)
  275. @auto_docstring
  276. def forward(
  277. self,
  278. hidden_states: torch.Tensor,
  279. attention_mask: torch.Tensor,
  280. **kwargs: Unpack[TransformersKwargs],
  281. ) -> torch.FloatTensor:
  282. residual = hidden_states
  283. hidden_states = self.layer_norm1(hidden_states)
  284. hidden_states, _ = self.self_attn(
  285. hidden_states=hidden_states,
  286. attention_mask=attention_mask,
  287. **kwargs,
  288. )
  289. hidden_states = residual + hidden_states
  290. residual = hidden_states
  291. hidden_states = self.layer_norm2(hidden_states)
  292. hidden_states = self.mlp(hidden_states)
  293. hidden_states = residual + hidden_states
  294. return hidden_states
  295. @auto_docstring
  296. class SiglipPreTrainedModel(PreTrainedModel):
  297. config: SiglipConfig
  298. base_model_prefix = "siglip"
  299. input_modalities = ("image", "text")
  300. supports_gradient_checkpointing = True
  301. _no_split_modules = [
  302. "SiglipTextEmbeddings",
  303. "SiglipVisionEmbeddings",
  304. "SiglipEncoderLayer",
  305. "SiglipMultiheadAttentionPoolingHead",
  306. ]
  307. _supports_flash_attn = True
  308. _supports_sdpa = True
  309. _supports_flex_attn = True
  310. _supports_attention_backend = True
  311. _can_record_outputs = {
  312. "hidden_states": SiglipEncoderLayer,
  313. "attentions": SiglipAttention,
  314. }
  315. @torch.no_grad()
  316. def _init_weights(self, module):
  317. """Initialize the weights"""
  318. if isinstance(module, SiglipVisionEmbeddings):
  319. width = (
  320. self.config.vision_config.hidden_size
  321. if isinstance(self.config, SiglipConfig)
  322. else self.config.hidden_size
  323. )
  324. init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
  325. if hasattr(module, "position_ids"):
  326. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  327. elif isinstance(module, nn.Embedding):
  328. init.default_flax_embed_init_(module.weight)
  329. elif isinstance(module, SiglipAttention):
  330. init.xavier_uniform_(module.q_proj.weight)
  331. init.xavier_uniform_(module.k_proj.weight)
  332. init.xavier_uniform_(module.v_proj.weight)
  333. init.xavier_uniform_(module.out_proj.weight)
  334. init.zeros_(module.q_proj.bias)
  335. init.zeros_(module.k_proj.bias)
  336. init.zeros_(module.v_proj.bias)
  337. init.zeros_(module.out_proj.bias)
  338. elif isinstance(module, SiglipMLP):
  339. init.xavier_uniform_(module.fc1.weight)
  340. init.xavier_uniform_(module.fc2.weight)
  341. init.normal_(module.fc1.bias, std=1e-6)
  342. init.normal_(module.fc2.bias, std=1e-6)
  343. elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
  344. init.xavier_uniform_(module.probe)
  345. init.xavier_uniform_(module.attention.in_proj_weight)
  346. init.zeros_(module.attention.in_proj_bias)
  347. elif isinstance(module, SiglipModel):
  348. init.zeros_(module.logit_scale)
  349. init.zeros_(module.logit_bias)
  350. elif isinstance(module, SiglipForImageClassification):
  351. init.normal_(
  352. module.classifier.weight,
  353. std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor,
  354. )
  355. elif isinstance(module, (nn.Linear, nn.Conv2d)):
  356. init.lecun_normal_(module.weight)
  357. if module.bias is not None:
  358. init.zeros_(module.bias)
  359. elif isinstance(module, nn.LayerNorm):
  360. init.zeros_(module.bias)
  361. init.ones_(module.weight)
  362. elif isinstance(module, SiglipTextEmbeddings):
  363. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  364. # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->Siglip
  365. class SiglipEncoder(nn.Module):
  366. """
  367. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  368. [`SiglipEncoderLayer`].
  369. Args:
  370. config: SiglipConfig
  371. """
  372. def __init__(self, config: SiglipConfig):
  373. super().__init__()
  374. self.config = config
  375. self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  376. self.gradient_checkpointing = False
  377. # Ignore copy
  378. @auto_docstring
  379. def forward(
  380. self,
  381. inputs_embeds,
  382. attention_mask: torch.Tensor | None = None,
  383. **kwargs: Unpack[TransformersKwargs],
  384. ) -> BaseModelOutput:
  385. hidden_states = inputs_embeds
  386. for encoder_layer in self.layers:
  387. hidden_states = encoder_layer(
  388. hidden_states,
  389. attention_mask,
  390. **kwargs,
  391. )
  392. return BaseModelOutput(last_hidden_state=hidden_states)
  393. class SiglipTextTransformer(SiglipPreTrainedModel):
  394. _input_embed_layer = "token_embedding"
  395. def __init__(self, config: SiglipTextConfig):
  396. super().__init__(config)
  397. self.config = config
  398. embed_dim = config.hidden_size
  399. self.embeddings = SiglipTextEmbeddings(config)
  400. self.encoder = SiglipEncoder(config)
  401. self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  402. self.head = nn.Linear(embed_dim, config.projection_size)
  403. self.post_init()
  404. @can_return_tuple
  405. @auto_docstring
  406. def forward(
  407. self,
  408. input_ids: torch.Tensor | None = None,
  409. attention_mask: torch.Tensor | None = None,
  410. position_ids: torch.Tensor | None = None,
  411. **kwargs: Unpack[TransformersKwargs],
  412. ) -> BaseModelOutputWithPooling:
  413. if input_ids is None:
  414. raise ValueError("You have to specify input_ids")
  415. input_shape = input_ids.size()
  416. input_ids = input_ids.view(-1, input_shape[-1])
  417. hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
  418. # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
  419. attention_mask = create_bidirectional_mask(
  420. config=self.config,
  421. inputs_embeds=hidden_states,
  422. attention_mask=attention_mask,
  423. )
  424. encoder_outputs: BaseModelOutput = self.encoder(
  425. inputs_embeds=hidden_states,
  426. attention_mask=attention_mask,
  427. **kwargs,
  428. )
  429. last_hidden_state = encoder_outputs.last_hidden_state
  430. last_hidden_state = self.final_layer_norm(last_hidden_state)
  431. # The model uses the last token's hidden state, which may be padding.
  432. pooled_output = last_hidden_state[:, -1, :]
  433. pooled_output = self.head(pooled_output)
  434. return BaseModelOutputWithPooling(
  435. last_hidden_state=last_hidden_state,
  436. pooler_output=pooled_output,
  437. )
  438. @auto_docstring(
  439. custom_intro="""
  440. The text model from SigLIP without any head or projection on top.
  441. """
  442. )
  443. class SiglipTextModel(SiglipPreTrainedModel):
  444. config: SiglipTextConfig
  445. input_modalities = ("text",)
  446. def __init__(self, config: SiglipTextConfig):
  447. super().__init__(config)
  448. self.text_model = SiglipTextTransformer(config)
  449. # Initialize weights and apply final processing
  450. self.post_init()
  451. def get_input_embeddings(self) -> nn.Module:
  452. return self.text_model.embeddings.token_embedding
  453. def set_input_embeddings(self, value):
  454. self.text_model.embeddings.token_embedding = value
  455. @merge_with_config_defaults
  456. @capture_outputs(tie_last_hidden_states=False)
  457. @auto_docstring
  458. def forward(
  459. self,
  460. input_ids: torch.Tensor | None = None,
  461. attention_mask: torch.Tensor | None = None,
  462. position_ids: torch.Tensor | None = None,
  463. **kwargs: Unpack[TransformersKwargs],
  464. ) -> BaseModelOutputWithPooling:
  465. r"""
  466. Examples:
  467. ```python
  468. >>> from transformers import AutoTokenizer, SiglipTextModel
  469. >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224")
  470. >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
  471. >>> # important: make sure to set padding="max_length" as that's how the model was trained
  472. >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
  473. >>> outputs = model(**inputs)
  474. >>> last_hidden_state = outputs.last_hidden_state
  475. >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
  476. ```"""
  477. return self.text_model(
  478. input_ids=input_ids,
  479. attention_mask=attention_mask,
  480. position_ids=position_ids,
  481. **kwargs,
  482. )
  483. class SiglipVisionTransformer(SiglipPreTrainedModel):
  484. _input_embed_layer = "patch_embedding"
  485. def __init__(self, config: SiglipVisionConfig):
  486. super().__init__(config)
  487. self.config = config
  488. embed_dim = config.hidden_size
  489. self.embeddings = SiglipVisionEmbeddings(config)
  490. self.encoder = SiglipEncoder(config)
  491. self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  492. self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head
  493. if self.use_head:
  494. self.head = SiglipMultiheadAttentionPoolingHead(config)
  495. self.post_init()
  496. @auto_docstring
  497. def forward(
  498. self,
  499. pixel_values,
  500. interpolate_pos_encoding: bool | None = False,
  501. **kwargs: Unpack[TransformersKwargs],
  502. ) -> BaseModelOutputWithPooling:
  503. hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  504. encoder_outputs: BaseModelOutput = self.encoder(
  505. inputs_embeds=hidden_states,
  506. **kwargs,
  507. )
  508. last_hidden_state = encoder_outputs.last_hidden_state
  509. last_hidden_state = self.post_layernorm(last_hidden_state)
  510. pooler_output = self.head(last_hidden_state) if self.use_head else None
  511. return BaseModelOutputWithPooling(
  512. last_hidden_state=last_hidden_state,
  513. pooler_output=pooler_output,
  514. )
  515. class SiglipMultiheadAttentionPoolingHead(nn.Module):
  516. """Multihead Attention Pooling."""
  517. def __init__(self, config: SiglipVisionConfig):
  518. super().__init__()
  519. self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
  520. self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
  521. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  522. self.mlp = SiglipMLP(config)
  523. def forward(self, hidden_state):
  524. batch_size = hidden_state.shape[0]
  525. probe = self.probe.repeat(batch_size, 1, 1)
  526. hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
  527. residual = hidden_state
  528. hidden_state = self.layernorm(hidden_state)
  529. hidden_state = residual + self.mlp(hidden_state)
  530. return hidden_state[:, 0]
  531. @auto_docstring(
  532. custom_intro="""
  533. The vision model from SigLIP without any head or projection on top.
  534. """
  535. )
  536. class SiglipVisionModel(SiglipPreTrainedModel):
  537. config: SiglipVisionConfig
  538. main_input_name = "pixel_values"
  539. input_modalities = ("image",)
  540. def __init__(self, config: SiglipVisionConfig):
  541. super().__init__(config)
  542. self.vision_model = SiglipVisionTransformer(config)
  543. # Initialize weights and apply final processing
  544. self.post_init()
  545. def get_input_embeddings(self) -> nn.Module:
  546. return self.vision_model.embeddings.patch_embedding
  547. @merge_with_config_defaults
  548. @capture_outputs(tie_last_hidden_states=False)
  549. @auto_docstring
  550. def forward(
  551. self,
  552. pixel_values,
  553. interpolate_pos_encoding: bool = False,
  554. **kwargs: Unpack[TransformersKwargs],
  555. ) -> BaseModelOutputWithPooling:
  556. r"""
  557. Examples:
  558. ```python
  559. >>> from PIL import Image
  560. >>> import httpx
  561. >>> from io import BytesIO
  562. >>> from transformers import AutoProcessor, SiglipVisionModel
  563. >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
  564. >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
  565. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  566. >>> with httpx.stream("GET", url) as response:
  567. ... image = Image.open(BytesIO(response.read()))
  568. >>> inputs = processor(images=image, return_tensors="pt")
  569. >>> outputs = model(**inputs)
  570. >>> last_hidden_state = outputs.last_hidden_state
  571. >>> pooled_output = outputs.pooler_output # pooled features
  572. ```"""
  573. return self.vision_model(
  574. pixel_values=pixel_values,
  575. interpolate_pos_encoding=interpolate_pos_encoding,
  576. **kwargs,
  577. )
  578. @auto_docstring
  579. class SiglipModel(SiglipPreTrainedModel):
  580. config: SiglipConfig
  581. def __init__(self, config: SiglipConfig):
  582. super().__init__(config)
  583. if not isinstance(config.text_config, SiglipTextConfig):
  584. raise TypeError(
  585. "config.text_config is expected to be of type SiglipTextConfig but is of type"
  586. f" {type(config.text_config)}."
  587. )
  588. if not isinstance(config.vision_config, SiglipVisionConfig):
  589. raise TypeError(
  590. "config.vision_config is expected to be of type SiglipVisionConfig but is of type"
  591. f" {type(config.vision_config)}."
  592. )
  593. text_config = config.text_config
  594. vision_config = config.vision_config
  595. # First, initialize the text and vision models with proper attention implementation
  596. text_model = SiglipTextModel._from_config(text_config)
  597. vision_model = SiglipVisionModel._from_config(vision_config)
  598. # Second, get the text and vision submodules (for backward compatibility)
  599. self.text_model = text_model.text_model
  600. self.vision_model = vision_model.vision_model
  601. self.logit_scale = nn.Parameter(torch.randn(1))
  602. self.logit_bias = nn.Parameter(torch.randn(1))
  603. # Initialize weights and apply final processing
  604. self.post_init()
  605. def get_input_embeddings(self) -> nn.Module:
  606. return self.text_model.embeddings.token_embedding
  607. def set_input_embeddings(self, value: nn.Module):
  608. self.text_model.embeddings.token_embedding = value
  609. @can_return_tuple
  610. @auto_docstring
  611. def get_text_features(
  612. self,
  613. input_ids: torch.Tensor,
  614. attention_mask: torch.Tensor | None = None,
  615. position_ids: torch.Tensor | None = None,
  616. **kwargs: Unpack[TransformersKwargs],
  617. ) -> tuple | BaseModelOutputWithPooling:
  618. r"""
  619. Examples:
  620. ```python
  621. >>> from transformers import AutoTokenizer, AutoModel
  622. >>> import torch
  623. >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
  624. >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
  625. >>> # important: make sure to set padding="max_length" as that's how the model was trained
  626. >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
  627. >>> with torch.no_grad():
  628. ... text_features = model.get_text_features(**inputs)
  629. ```"""
  630. return self.text_model(
  631. input_ids=input_ids,
  632. attention_mask=attention_mask,
  633. position_ids=position_ids,
  634. **kwargs,
  635. )
  636. @can_return_tuple
  637. @auto_docstring
  638. def get_image_features(
  639. self,
  640. pixel_values: torch.FloatTensor,
  641. interpolate_pos_encoding: bool = False,
  642. **kwargs: Unpack[TransformersKwargs],
  643. ) -> tuple | BaseModelOutputWithPooling:
  644. r"""
  645. Examples:
  646. ```python
  647. >>> import torch
  648. >>> from transformers import AutoProcessor, AutoModel
  649. >>> from transformers.image_utils import load_image
  650. >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
  651. >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
  652. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  653. >>> image = load_image(url)
  654. >>> inputs = processor(images=image, return_tensors="pt")
  655. >>> with torch.no_grad():
  656. ... image_features = model.get_image_features(**inputs)
  657. ```"""
  658. return self.vision_model(
  659. pixel_values=pixel_values,
  660. interpolate_pos_encoding=interpolate_pos_encoding,
  661. **kwargs,
  662. )
  663. # NOTE: SiglipModel uses Pretrained backbones, so we don't need to add `capture_outputs` here
  664. @can_return_tuple
  665. @auto_docstring
  666. def forward(
  667. self,
  668. input_ids: torch.LongTensor | None = None,
  669. pixel_values: torch.FloatTensor | None = None,
  670. attention_mask: torch.Tensor | None = None,
  671. position_ids: torch.LongTensor | None = None,
  672. return_loss: bool | None = None,
  673. interpolate_pos_encoding: bool = False,
  674. **kwargs: Unpack[TransformersKwargs],
  675. ) -> SiglipOutput:
  676. r"""
  677. return_loss (`bool`, *optional*):
  678. Whether or not to return the contrastive loss.
  679. Examples:
  680. ```python
  681. >>> from PIL import Image
  682. >>> import httpx
  683. >>> from io import BytesIO
  684. >>> from transformers import AutoProcessor, AutoModel
  685. >>> import torch
  686. >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
  687. >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
  688. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  689. >>> with httpx.stream("GET", url) as response:
  690. ... image = Image.open(BytesIO(response.read()))
  691. >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
  692. >>> # important: we pass `padding=max_length` since the model was trained with this
  693. >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")
  694. >>> with torch.no_grad():
  695. ... outputs = model(**inputs)
  696. >>> logits_per_image = outputs.logits_per_image
  697. >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
  698. >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
  699. 31.9% that image 0 is 'a photo of 2 cats'
  700. ```"""
  701. vision_outputs: BaseModelOutputWithPooling = self.vision_model(
  702. pixel_values=pixel_values,
  703. interpolate_pos_encoding=interpolate_pos_encoding,
  704. **kwargs,
  705. )
  706. text_outputs: BaseModelOutputWithPooling = self.text_model(
  707. input_ids=input_ids,
  708. attention_mask=attention_mask,
  709. position_ids=position_ids,
  710. **kwargs,
  711. )
  712. image_embeds = vision_outputs.pooler_output
  713. text_embeds = text_outputs.pooler_output
  714. # normalized features
  715. image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
  716. text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
  717. # cosine similarity as logits
  718. logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device))
  719. logit_scale, logit_bias = self.logit_scale.to(text_embeds.device), self.logit_bias.to(text_embeds.device)
  720. logits_per_text = logits_per_text * logit_scale.exp() + logit_bias
  721. logits_per_image = logits_per_text.t()
  722. loss = None
  723. if return_loss:
  724. # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip.py#L287
  725. eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device)
  726. m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye
  727. loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text)
  728. nll = -torch.sum(loglik, dim=-1)
  729. loss = nll.mean()
  730. return SiglipOutput(
  731. loss=loss,
  732. logits_per_image=logits_per_image,
  733. logits_per_text=logits_per_text,
  734. text_embeds=text_embeds,
  735. image_embeds=image_embeds,
  736. text_model_output=text_outputs,
  737. vision_model_output=vision_outputs,
  738. )
  739. @auto_docstring(
  740. custom_intro="""
  741. SigLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of
  742. the patch tokens) e.g. for ImageNet.
  743. """
  744. )
  745. class SiglipForImageClassification(SiglipPreTrainedModel):
  746. main_input_name = "pixel_values"
  747. input_modalities = ("image",)
  748. def __init__(self, config: SiglipConfig) -> None:
  749. super().__init__(config)
  750. self.num_labels = config.num_labels
  751. # Create the vision model with proper attention
  752. # and take only vision_model submodule (for backward compatibility)
  753. vision_model = SiglipVisionModel._from_config(config.vision_config)
  754. self.vision_model = vision_model.vision_model
  755. # Classifier head
  756. self.classifier = (
  757. nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
  758. )
  759. # Initialize weights and apply final processing
  760. self.post_init()
  761. def get_input_embeddings(self) -> nn.Module:
  762. return self.vision_model.embeddings.patch_embedding
  763. def set_input_embeddings(self, value: nn.Module):
  764. self.vision_model.embeddings.patch_embedding = value
  765. @merge_with_config_defaults
  766. @capture_outputs
  767. @auto_docstring
  768. def forward(
  769. self,
  770. pixel_values: torch.Tensor | None = None,
  771. labels: torch.Tensor | None = None,
  772. interpolate_pos_encoding: bool = False,
  773. **kwargs: Unpack[TransformersKwargs],
  774. ) -> ImageClassifierOutput:
  775. r"""
  776. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  777. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  778. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  779. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  780. Examples:
  781. ```python
  782. >>> from transformers import AutoImageProcessor, SiglipForImageClassification
  783. >>> import torch
  784. >>> from PIL import Image
  785. >>> import httpx
  786. >>> from io import BytesIO
  787. >>> torch.manual_seed(3) # doctest: +IGNORE_RESULT
  788. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  789. >>> with httpx.stream("GET", url) as response:
  790. ... image = Image.open(BytesIO(response.read()))
  791. >>> # note: we are loading a `SiglipModel` from the hub here,
  792. >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above.
  793. >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224")
  794. >>> model = SiglipForImageClassification.from_pretrained("google/siglip-base-patch16-224")
  795. >>> inputs = image_processor(images=image, return_tensors="pt")
  796. >>> outputs = model(**inputs)
  797. >>> logits = outputs.logits
  798. >>> # model predicts one of the two classes
  799. >>> predicted_class_idx = logits.argmax(-1).item()
  800. >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
  801. Predicted class: LABEL_1
  802. ```"""
  803. outputs: BaseModelOutputWithPooling = self.vision_model(
  804. pixel_values,
  805. interpolate_pos_encoding=interpolate_pos_encoding,
  806. **kwargs,
  807. )
  808. sequence_output = outputs.last_hidden_state
  809. # average pool the patch tokens
  810. sequence_output = torch.mean(sequence_output, dim=1)
  811. # apply classifier
  812. logits = self.classifier(sequence_output)
  813. loss = None
  814. if labels is not None:
  815. loss = self.loss_function(labels, logits, self.config)
  816. return ImageClassifierOutput(
  817. loss=loss,
  818. logits=logits,
  819. )
# Public API of this module: the names exported by `from <module> import *`.
# NOTE(review): repo tooling may read this list textually — keep one quoted name per line
# and do not put comments inside the brackets; confirm before reformatting.
__all__ = [
    "SiglipModel",
    "SiglipPreTrainedModel",
    "SiglipTextModel",
    "SiglipVisionModel",
    "SiglipForImageClassification",
]