modeling_aimv2.py 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/aimv2/modular_aimv2.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_aimv2.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2025 Apple Inc. and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

import torch
import torch.nn.functional as F
from torch import nn

from ... import initialization as init
from ...activations import ACT2FN
from ...integrations import use_kernel_forward_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple
from ...utils.generic import merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
from .configuration_aimv2 import Aimv2Config, Aimv2TextConfig, Aimv2VisionConfig
  39. @dataclass
  40. @auto_docstring
  41. class Aimv2Output(ModelOutput):
  42. r"""
  43. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
  44. Contrastive loss for image-text similarity.
  45. logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
  46. The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
  47. similarity scores.
  48. logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
  49. The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
  50. similarity scores.
  51. text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
  52. The text embeddings obtained by applying the projection layer to the pooled output of [`Aimv2TextModel`].
  53. image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
  54. The image embeddings obtained by applying the projection layer to the pooled output of [`Aimv2VisionModel`].
  55. text_model_output (`BaseModelOutputWithPooling`):
  56. The output of the [`Aimv2TextModel`].
  57. vision_model_output (`BaseModelOutputWithPooling`):
  58. The output of the [`Aimv2VisionModel`].
  59. """
  60. loss: torch.FloatTensor | None = None
  61. logits_per_image: torch.FloatTensor | None = None
  62. logits_per_text: torch.FloatTensor | None = None
  63. text_embeds: torch.FloatTensor | None = None
  64. image_embeds: torch.FloatTensor | None = None
  65. text_model_output: BaseModelOutputWithPooling = None
  66. vision_model_output: BaseModelOutputWithPooling = None
  67. def to_tuple(self) -> tuple[Any]:
  68. return tuple(v.to_tuple() if isinstance(v, ModelOutput) else v for v in self.values())
  69. @use_kernel_forward_from_hub("RMSNorm")
  70. class Aimv2RMSNorm(nn.Module):
  71. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  72. """
  73. Aimv2RMSNorm is equivalent to T5LayerNorm
  74. """
  75. super().__init__()
  76. self.weight = nn.Parameter(torch.ones(hidden_size))
  77. self.variance_epsilon = eps
  78. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  79. input_dtype = hidden_states.dtype
  80. hidden_states = hidden_states.to(torch.float32)
  81. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  82. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  83. return self.weight * hidden_states.to(input_dtype)
  84. def extra_repr(self):
  85. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  86. class Aimv2MLP(nn.Module):
  87. def __init__(self, config):
  88. super().__init__()
  89. self.config = config
  90. self.hidden_size = config.hidden_size
  91. self.intermediate_size = config.intermediate_size
  92. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  93. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  94. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
  95. self.act_fn = ACT2FN[config.hidden_act]
  96. def forward(self, x):
  97. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  98. return down_proj
  99. class Aimv2VisionEmbeddings(nn.Module):
  100. def __init__(self, config: Aimv2VisionConfig):
  101. super().__init__()
  102. self.config = config
  103. self.patch_size = config.patch_size
  104. self.patch_embed = nn.Conv2d(
  105. config.num_channels, config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size
  106. )
  107. self.rms_norm = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
  108. num_patches = (config.image_size // config.patch_size) ** 2
  109. if not self.config.is_native:
  110. self.position_embedding = nn.Embedding(num_patches, config.hidden_size)
  111. self.register_buffer("position_ids", torch.arange(num_patches).expand((1, -1)), persistent=False)
  112. @staticmethod
  113. def build_2d_sincos_position_embedding(
  114. height, width, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32
  115. ) -> torch.Tensor:
  116. grid_w = torch.arange(int(width), dtype=dtype, device=device)
  117. grid_h = torch.arange(int(height), dtype=dtype, device=device)
  118. grid_h, grid_w = torch.meshgrid(grid_w, grid_h, indexing="xy")
  119. pos_dim = embed_dim // 4
  120. omega = torch.arange(pos_dim, dtype=dtype, device=device) / pos_dim
  121. omega = 1.0 / (temperature**omega)
  122. out_h = grid_h.flatten()[..., None] @ omega[None, :]
  123. out_w = grid_w.flatten()[..., None] @ omega[None, :]
  124. return torch.concat([out_h.sin(), out_h.cos(), out_w.sin(), out_w.cos()], dim=1)[None, :, :]
  125. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  126. _, _, height, width = pixel_values.size()
  127. hidden_states = self.patch_embed(pixel_values).flatten(2).transpose(1, 2)
  128. hidden_states = self.rms_norm(hidden_states)
  129. if self.config.is_native:
  130. pos_embed = self.build_2d_sincos_position_embedding(
  131. height // self.patch_size,
  132. width // self.patch_size,
  133. embed_dim=self.config.hidden_size,
  134. device=hidden_states.device,
  135. dtype=hidden_states.dtype,
  136. )
  137. else:
  138. pos_embed = self.position_embedding(self.position_ids)
  139. hidden_states = hidden_states + pos_embed
  140. return hidden_states
  141. class Aimv2TextEmbeddings(nn.Module):
  142. def __init__(self, config: Aimv2TextConfig):
  143. super().__init__()
  144. embed_dim = config.hidden_size
  145. self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
  146. self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
  147. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  148. self.register_buffer(
  149. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  150. )
  151. def forward(
  152. self,
  153. input_ids: torch.LongTensor | None = None,
  154. position_ids: torch.LongTensor | None = None,
  155. inputs_embeds: torch.FloatTensor | None = None,
  156. ) -> torch.Tensor:
  157. seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
  158. max_position_embedding = self.position_embedding.weight.shape[0]
  159. if seq_length > max_position_embedding:
  160. raise ValueError(
  161. f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
  162. f"{seq_length} and max_position_embeddings: {max_position_embedding}"
  163. )
  164. if position_ids is None:
  165. position_ids = self.position_ids[:, :seq_length]
  166. if inputs_embeds is None:
  167. inputs_embeds = self.token_embedding(input_ids)
  168. position_embeddings = self.position_embedding(position_ids)
  169. embeddings = inputs_embeds + position_embeddings
  170. return embeddings
  171. def eager_attention_forward(
  172. module: nn.Module,
  173. query: torch.Tensor,
  174. key: torch.Tensor,
  175. value: torch.Tensor,
  176. attention_mask: torch.Tensor | None,
  177. scaling: float,
  178. dropout: float = 0.0,
  179. **kwargs,
  180. ):
  181. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  182. if attention_mask is not None:
  183. attn_weights = attn_weights + attention_mask
  184. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  185. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  186. attn_output = torch.matmul(attn_weights, value)
  187. attn_output = attn_output.transpose(1, 2).contiguous()
  188. return attn_output, attn_weights
  189. class Aimv2Attention(nn.Module):
  190. """Multi-headed attention from 'Attention Is All You Need' paper"""
  191. def __init__(self, config):
  192. super().__init__()
  193. self.config = config
  194. self.embed_dim = config.hidden_size
  195. self.num_heads = config.num_attention_heads
  196. self.head_dim = self.embed_dim // self.num_heads
  197. if self.head_dim * self.num_heads != self.embed_dim:
  198. raise ValueError(
  199. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  200. f" {self.num_heads})."
  201. )
  202. self.scale = self.head_dim**-0.5
  203. self.dropout = config.attention_dropout
  204. self.is_causal = False
  205. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
  206. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
  207. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
  208. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
  209. def forward(
  210. self,
  211. hidden_states: torch.Tensor,
  212. attention_mask: torch.Tensor | None = None,
  213. **kwargs,
  214. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  215. """Input shape: Batch x Time x Channel"""
  216. input_shape = hidden_states.shape[:-1]
  217. hidden_shape = (*input_shape, -1, self.head_dim)
  218. queries = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  219. keys = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  220. values = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  221. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  222. self.config._attn_implementation, eager_attention_forward
  223. )
  224. attn_output, attn_weights = attention_interface(
  225. self,
  226. queries,
  227. keys,
  228. values,
  229. attention_mask,
  230. is_causal=self.is_causal,
  231. scaling=self.scale,
  232. dropout=0.0 if not self.training else self.dropout,
  233. )
  234. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  235. attn_output = self.out_proj(attn_output)
  236. return attn_output, attn_weights
  237. class Aimv2EncoderLayer(GradientCheckpointingLayer):
  238. def __init__(self, config: Aimv2VisionConfig):
  239. super().__init__()
  240. self.attention = Aimv2Attention(config)
  241. self.ffn = Aimv2MLP(config)
  242. self.rms_norm1 = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
  243. self.rms_norm2 = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
  244. def forward(
  245. self,
  246. hidden_states: torch.Tensor,
  247. attention_mask: torch.Tensor | None = None,
  248. **kwargs: Unpack[TransformersKwargs],
  249. ) -> torch.Tensor:
  250. norm_hidden_states = self.rms_norm1(hidden_states)
  251. attn_output, _ = self.attention(hidden_states=norm_hidden_states, attention_mask=attention_mask, **kwargs)
  252. hidden_states = hidden_states + attn_output
  253. norm_hidden_states = self.rms_norm2(hidden_states)
  254. mlp_output = self.ffn(norm_hidden_states)
  255. hidden_states = hidden_states + mlp_output
  256. return hidden_states
  257. class Aimv2Encoder(nn.Module):
  258. """
  259. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  260. [`Aimv2EncoderLayer`].
  261. Args:
  262. config: Aimv2Config
  263. """
  264. def __init__(self, config: Aimv2Config):
  265. super().__init__()
  266. self.config = config
  267. self.layers = nn.ModuleList([Aimv2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
  268. self.gradient_checkpointing = False
  269. # Ignore copy
  270. @auto_docstring
  271. def forward(
  272. self,
  273. inputs_embeds,
  274. attention_mask: torch.Tensor | None = None,
  275. **kwargs: Unpack[TransformersKwargs],
  276. ) -> BaseModelOutput:
  277. hidden_states = inputs_embeds
  278. for encoder_layer in self.layers:
  279. hidden_states = encoder_layer(
  280. hidden_states,
  281. attention_mask,
  282. **kwargs,
  283. )
  284. return BaseModelOutput(last_hidden_state=hidden_states)
  285. class Aimv2AttentionPoolingHead(nn.Module):
  286. def __init__(self, config: Aimv2VisionConfig):
  287. super().__init__()
  288. self.hidden_size = config.hidden_size
  289. self.num_heads = config.num_attention_heads
  290. self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.qkv_bias)
  291. self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.qkv_bias)
  292. self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_size))
  293. self.output_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
  294. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  295. batch_size, seq_len, hidden_dim = hidden_states.shape
  296. cls_token = self.cls_token.expand(batch_size, -1, -1)
  297. key = self.k_proj(hidden_states).reshape(batch_size, seq_len, self.num_heads, hidden_dim // self.num_heads)
  298. value = self.v_proj(hidden_states).reshape(batch_size, seq_len, self.num_heads, hidden_dim // self.num_heads)
  299. query = cls_token.reshape(batch_size, 1, self.num_heads, hidden_dim // self.num_heads)
  300. key = key.permute(0, 2, 1, 3)
  301. value = value.permute(0, 2, 1, 3)
  302. query = query.permute(0, 2, 1, 3)
  303. attn_output = F.scaled_dot_product_attention(query, key, value)
  304. attn_output = attn_output.transpose(1, 2).reshape(batch_size, 1, hidden_dim)
  305. attn_output = attn_output.mean(dim=1)
  306. output = self.output_proj(attn_output)
  307. return output
  308. @auto_docstring
  309. class Aimv2PreTrainedModel(PreTrainedModel):
  310. """
  311. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  312. models. The model is only intended for inference and doesn't support finetuning.
  313. """
  314. config: Aimv2Config
  315. base_model_prefix = "aimv2"
  316. input_modalities = ("image",)
  317. supports_gradient_checkpointing = True
  318. _no_split_modules = [
  319. "Aimv2EncoderLayer",
  320. "Aimv2AttentionPoolingHead",
  321. "Aimv2VisionEmbeddings",
  322. "Aimv2TextEmbeddings",
  323. ]
  324. _supports_sdpa = True
  325. _supports_flash_attn = True
  326. _supports_flex_attn = True
  327. @torch.no_grad()
  328. def _init_weights(self, module):
  329. super()._init_weights(module)
  330. if hasattr(module, "logit_scale"):
  331. if isinstance(module.logit_scale, nn.Parameter):
  332. init.constant_(module.logit_scale, math.log(1 / 0.07))
  333. elif isinstance(module, Aimv2AttentionPoolingHead):
  334. init.normal_(module.cls_token, mean=0.0, std=self.config.initializer_range)
  335. elif isinstance(module, Aimv2VisionEmbeddings):
  336. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  337. elif isinstance(module, Aimv2TextEmbeddings):
  338. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  339. @auto_docstring(
  340. custom_intro="""
  341. The Vision model from AIMv2 without any head or projection on top.
  342. """
  343. )
  344. class Aimv2VisionModel(Aimv2PreTrainedModel):
  345. config: Aimv2VisionConfig
  346. main_input_name = "pixel_values"
  347. _can_record_outputs = {
  348. "hidden_states": Aimv2EncoderLayer,
  349. "attentions": Aimv2Attention,
  350. }
  351. def __init__(self, config: Aimv2VisionConfig):
  352. super().__init__(config)
  353. self.config = config
  354. self.embeddings = Aimv2VisionEmbeddings(config)
  355. self.encoder = Aimv2Encoder(config)
  356. # The only change from SiglipVisionTransformer is, layernorm -> rms_norm.
  357. self.rms_norm = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
  358. self.use_head = config.use_head
  359. if self.use_head:
  360. self.head = Aimv2AttentionPoolingHead(config)
  361. self.post_init()
  362. def get_input_embeddings(self) -> nn.Module:
  363. return self.embeddings.patch_embed
  364. @merge_with_config_defaults
  365. @capture_outputs(tie_last_hidden_states=False)
  366. @auto_docstring
  367. def forward(
  368. self,
  369. pixel_values,
  370. **kwargs: Unpack[TransformersKwargs],
  371. ) -> BaseModelOutputWithPooling:
  372. r"""
  373. Examples:
  374. ```python
  375. >>> from PIL import Image
  376. >>> import httpx
  377. >>> from io import BytesIO
  378. >>> from transformers import AutoProcessor, Siglip2VisionModel
  379. >>> model = Aimv2VisionModel.from_pretrained("apple/aimv2-large-patch14-native")
  380. >>> processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-native")
  381. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  382. >>> with httpx.stream("GET", url) as response:
  383. ... image = Image.open(BytesIO(response.read()))
  384. >>> inputs = processor(images=image, return_tensors="pt")
  385. >>> outputs = model(**inputs)
  386. >>> last_hidden_state = outputs.last_hidden_state
  387. >>> pooled_output = outputs.pooler_output # pooled features
  388. ```"""
  389. hidden_states = self.embeddings(pixel_values)
  390. encoder_outputs: BaseModelOutput = self.encoder(
  391. inputs_embeds=hidden_states,
  392. **kwargs,
  393. )
  394. last_hidden_state = encoder_outputs.last_hidden_state
  395. last_hidden_state = self.rms_norm(last_hidden_state)
  396. pooler_output = self.head(last_hidden_state) if self.use_head else None
  397. return BaseModelOutputWithPooling(
  398. last_hidden_state=last_hidden_state,
  399. pooler_output=pooler_output,
  400. )
  401. @auto_docstring(
  402. custom_intro="""
  403. The text model from AIMv2 without any head or projection on top.
  404. """
  405. )
  406. class Aimv2TextModel(Aimv2PreTrainedModel):
  407. main_input_name = "input_ids"
  408. _can_record_outputs = {
  409. "hidden_states": Aimv2EncoderLayer,
  410. "attentions": Aimv2Attention,
  411. }
  412. def __init__(self, config: Aimv2TextConfig):
  413. super().__init__(config)
  414. self.config = config
  415. self.embeddings = Aimv2TextEmbeddings(config)
  416. self.encoder = Aimv2Encoder(config)
  417. self.rms_norm = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
  418. self.eos_token_id = config.eos_token_id
  419. self.post_init()
  420. def get_input_embeddings(self) -> nn.Module:
  421. return self.embeddings.token_embedding
  422. def set_input_embeddings(self, value):
  423. self.embeddings.token_embedding = value
  424. @merge_with_config_defaults
  425. @capture_outputs(tie_last_hidden_states=False)
  426. @auto_docstring
  427. def forward(
  428. self,
  429. input_ids,
  430. attention_mask: torch.Tensor | None = None,
  431. **kwargs: Unpack[TransformersKwargs],
  432. ) -> BaseModelOutputWithPooling:
  433. hidden_states = self.embeddings(input_ids)
  434. batch_size, seq_len, _ = hidden_states.shape
  435. position_ids = torch.arange(seq_len, dtype=torch.long, device=hidden_states.device)
  436. position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
  437. if attention_mask is not None:
  438. attention_mask = create_causal_mask(
  439. config=self.config,
  440. inputs_embeds=hidden_states,
  441. position_ids=position_ids,
  442. attention_mask=attention_mask,
  443. past_key_values=None,
  444. )
  445. encoder_outputs = self.encoder(
  446. inputs_embeds=hidden_states,
  447. attention_mask=attention_mask,
  448. **kwargs,
  449. )
  450. last_hidden_state = encoder_outputs.last_hidden_state
  451. last_hidden_state = self.rms_norm(last_hidden_state)
  452. # Get pooled output
  453. pooled_output = last_hidden_state[
  454. torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
  455. (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id).int().argmax(dim=-1),
  456. ]
  457. return BaseModelOutputWithPooling(
  458. last_hidden_state=last_hidden_state,
  459. pooler_output=pooled_output,
  460. )
  461. def _get_vector_norm(tensor: torch.Tensor) -> torch.Tensor:
  462. """
  463. This method is equivalent to tensor.norm(p=2, dim=-1, keepdim=True) and used to make
  464. model `executorch` exportable. See issue https://github.com/pytorch/executorch/issues/3566
  465. """
  466. square_tensor = torch.pow(tensor, 2)
  467. sum_tensor = torch.sum(square_tensor, dim=-1, keepdim=True)
  468. normed_tensor = torch.pow(sum_tensor, 0.5)
  469. return normed_tensor
  470. @auto_docstring
  471. class Aimv2Model(Aimv2PreTrainedModel):
  472. config: Aimv2Config
  473. _no_split_modules = ["Aimv2TextEmbeddings", "Aimv2EncoderLayer", "Aimv2VisionEmbeddings"]
  474. _supports_flash_attn = True
  475. def __init__(self, config: Aimv2Config):
  476. super().__init__(config)
  477. self.projection_dim = config.projection_dim
  478. self.vision_embed_dim = config.vision_config.hidden_size
  479. self.text_embed_dim = config.text_config.hidden_size
  480. self.vision_model = Aimv2VisionModel._from_config(config.vision_config)
  481. self.text_model = Aimv2TextModel._from_config(config.text_config)
  482. self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
  483. self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
  484. self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
  485. self.max_log_logit_scale = math.log(config.max_logit_scale)
  486. self.post_init()
  487. @can_return_tuple
  488. @auto_docstring
  489. def get_text_features(
  490. self,
  491. input_ids: torch.Tensor,
  492. attention_mask: torch.Tensor | None = None,
  493. position_ids: torch.Tensor | None = None,
  494. **kwargs: Unpack[TransformersKwargs],
  495. ) -> tuple | BaseModelOutputWithPooling:
  496. r"""
  497. Examples:
  498. ```python
  499. >>> import torch
  500. >>> from transformers import AutoTokenizer, Aimv2Model
  501. >>> model = Aimv2Model.from_pretrained("openai/aimv2-vit-base-patch32")
  502. >>> tokenizer = AutoTokenizer.from_pretrained("openai/aimv2-vit-base-patch32")
  503. >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
  504. >>> with torch.inference_mode():
  505. ... text_features = model.get_text_features(**inputs)
  506. ```"""
  507. text_outputs: BaseModelOutputWithPooling = self.text_model(
  508. input_ids=input_ids,
  509. attention_mask=attention_mask,
  510. position_ids=position_ids,
  511. return_dict=True,
  512. **kwargs,
  513. )
  514. pooled_output = text_outputs.pooler_output
  515. text_outputs.pooler_output = self.text_projection(pooled_output)
  516. return text_outputs
  517. @can_return_tuple
  518. @auto_docstring
  519. def get_image_features(
  520. self,
  521. pixel_values: torch.FloatTensor,
  522. interpolate_pos_encoding: bool = False,
  523. **kwargs: Unpack[TransformersKwargs],
  524. ) -> tuple | BaseModelOutputWithPooling:
  525. r"""
  526. Examples:
  527. ```python
  528. >>> import torch
  529. >>> from transformers import AutoProcessor, Aimv2Model
  530. >>> from transformers.image_utils import load_image
  531. >>> model = Aimv2Model.from_pretrained("openai/aimv2-vit-base-patch32")
  532. >>> processor = AutoProcessor.from_pretrained("openai/aimv2-vit-base-patch32")
  533. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  534. >>> image = load_image(url)
  535. >>> inputs = processor(images=image, return_tensors="pt")
  536. >>> with torch.inference_mode():
  537. ... image_features = model.get_image_features(**inputs)
  538. ```"""
  539. vision_outputs: BaseModelOutputWithPooling = self.vision_model(
  540. pixel_values=pixel_values,
  541. interpolate_pos_encoding=interpolate_pos_encoding,
  542. return_dict=True,
  543. **kwargs,
  544. )
  545. pooled_output = vision_outputs.pooler_output
  546. vision_outputs.pooler_output = self.visual_projection(pooled_output)
  547. return vision_outputs
  548. @auto_docstring
  549. @can_return_tuple
  550. def forward(
  551. self,
  552. input_ids: torch.LongTensor | None = None,
  553. pixel_values: torch.FloatTensor | None = None,
  554. attention_mask: torch.Tensor | None = None,
  555. **kwargs: Unpack[TransformersKwargs],
  556. ) -> Aimv2Output:
  557. r"""
  558. Examples:
  559. ```python
  560. >>> from PIL import Image
  561. >>> import httpx
  562. >>> from io import BytesIO
  563. >>> from transformers import AutoProcessor, Aimv2Model
  564. >>> model = Aimv2Model.from_pretrained("apple/aimv2-large-patch14-224-lit")
  565. >>> processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-224-lit")
  566. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  567. >>> with httpx.stream("GET", url) as response:
  568. ... image = Image.open(BytesIO(response.read()))
  569. >>> inputs = processor(
  570. ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
  571. ... )
  572. >>> outputs = model(**inputs)
  573. >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
  574. >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
  575. ```"""
  576. vision_outputs: BaseModelOutputWithPooling = self.vision_model(
  577. pixel_values=pixel_values,
  578. **kwargs,
  579. )
  580. text_outputs: BaseModelOutputWithPooling = self.text_model(
  581. input_ids=input_ids,
  582. attention_mask=attention_mask,
  583. **kwargs,
  584. )
  585. image_embeds = vision_outputs.pooler_output
  586. image_embeds = self.visual_projection(image_embeds)
  587. text_embeds = text_outputs.pooler_output
  588. text_embeds = self.text_projection(text_embeds)
  589. # normalized features
  590. image_embeds = image_embeds / _get_vector_norm(image_embeds)
  591. text_embeds = text_embeds / _get_vector_norm(text_embeds)
  592. logit_scale = self.logit_scale.clamp(0.0, self.max_log_logit_scale).exp().to(text_embeds.device)
  593. logits_per_text = (logit_scale * text_embeds) @ image_embeds.t()
  594. logits_per_image = logits_per_text.t()
  595. return Aimv2Output(
  596. logits_per_image=logits_per_image,
  597. logits_per_text=logits_per_text,
  598. text_embeds=text_embeds,
  599. image_embeds=image_embeds,
  600. text_model_output=text_outputs,
  601. vision_model_output=vision_outputs,
  602. )
  603. __all__ = ["Aimv2VisionModel", "Aimv2Model", "Aimv2PreTrainedModel", "Aimv2TextModel"]