modular_aimv2.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565
  1. # Copyright 2025 Apple Inc. and The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Pytorch implementation of AIMv2 Model"""
  15. import math
  16. import torch
  17. import torch.nn.functional as F
  18. from huggingface_hub.dataclasses import strict
  19. from torch import nn
  20. from ... import initialization as init
  21. from ...configuration_utils import PreTrainedConfig
  22. from ...masking_utils import create_causal_mask
  23. from ...modeling_layers import GradientCheckpointingLayer
  24. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
  25. from ...modeling_utils import PreTrainedModel
  26. from ...processing_utils import Unpack
  27. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  28. from ...utils.generic import merge_with_config_defaults
  29. from ...utils.output_capturing import capture_outputs
  30. from ..clip.modeling_clip import CLIPModel, CLIPTextEmbeddings, _get_vector_norm
  31. from ..llama.modeling_llama import LlamaMLP, LlamaRMSNorm
  32. from ..siglip.configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
  33. from ..siglip.modeling_siglip import SiglipAttention, SiglipEncoder, SiglipOutput
  34. @auto_docstring(checkpoint="apple/aimv2-large-patch14-224-lit")
  35. @strict
  36. class Aimv2VisionConfig(SiglipVisionConfig):
  37. r"""
  38. use_head (`str`, *optional*, defaults to `True`):
  39. Whether to use Attention Pooling Head or Not.
  40. is_native (`str`, *optional*, defaults to `False`):
  41. Whether to use ckpt trained for image native resolution or not.
  42. Example:
  43. ```python
  44. >>> from transformers import SiglipVisionConfig, SiglipVisionModel
  45. >>> # Initializing a Aimv2VisionConfig with apple/aimv2-large-patch14-224 style configuration
  46. >>> configuration = Aimv2VisionConfig()
  47. >>> # Initializing a Aimv2VisionModel (with random weights) from the apple/aimv2-large-patch14-224 style configuration
  48. >>> model = Aimv2VisionModel(configuration)
  49. >>> # Accessing the model configuration
  50. >>> configuration = model.config
  51. ```"""
  52. hidden_size: int = 1024
  53. intermediate_size: int = 2816
  54. num_hidden_layers: int = 24
  55. num_attention_heads: int = 8
  56. patch_size: int | list[int] | tuple[int, int] = 14
  57. rms_norm_eps: float = 1e-5
  58. attention_dropout: float | int = 0.0
  59. qkv_bias: bool = False
  60. mlp_bias: bool = False
  61. hidden_act: str = "silu"
  62. initializer_range: float = 0.02
  63. use_head: bool = True
  64. is_native: bool = False
  65. layer_norm_eps = AttributeError()
  66. @auto_docstring(checkpoint="apple/aimv2-large-patch14-224-lit")
  67. @strict
  68. class Aimv2TextConfig(SiglipTextConfig):
  69. vocab_size: int = 49408
  70. hidden_size: int = 768
  71. intermediate_size: int = 2048
  72. num_hidden_layers: int = 12
  73. num_attention_heads: int = 6
  74. max_position_embeddings: int = 77
  75. hidden_act: str = "silu"
  76. rms_norm_eps: float = 1e-5
  77. qkv_bias: bool = False
  78. mlp_bias: bool = False
  79. initializer_range: float = 0.02
  80. bos_token_id = AttributeError()
  81. pad_token_id = AttributeError()
  82. layer_norm_eps = AttributeError()
  83. projection_size = AttributeError()
  84. def __post_init__(self, **kwargs):
  85. PreTrainedConfig.__post_init__(**kwargs)
# Composite configuration for the full AIMv2 image-text model. Extends `SiglipConfig`.
@auto_docstring(checkpoint="apple/aimv2-large-patch14-224-lit")
@strict
class Aimv2Config(SiglipConfig):
    r"""
    max_logit_scale (`float`, *optional*, defaults to `100.0`):
        The maximum logit scale to use

    Example:

    ```python
    >>> from transformers import Aimv2Config, Aimv2Model

    >>> # Initializing a Aimv2Config with apple/aimv2-large-patch14-224-lit style configuration
    >>> configuration = Aimv2Config()

    >>> # Initializing a Aimv2Model (with random weights) from the apple/aimv2-large-patch14-224-lit style configuration
    >>> model = Aimv2Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config

    >>> # We can also initialize a Aimv2Config from a Aimv2TextConfig and a Aimv2VisionConfig
    >>> from transformers import Aimv2TextConfig, Aimv2VisionConfig

    >>> # Initializing a AIMv2Text and AIMv2Vision configuration
    >>> config_text = Aimv2TextConfig()
    >>> config_vision = Aimv2VisionConfig()

    >>> config = Aimv2Config(text_config=config_text, vision_config=config_vision)
    ```"""

    # Output width of both `visual_projection` and `text_projection` in `Aimv2Model`.
    projection_dim: int = 512
    # Initial value of the learnable `logit_scale` parameter in `Aimv2Model`.
    logit_scale_init_value: float = 2.6592
    # `Aimv2Model` takes the log of this value as the upper clamp bound for `logit_scale`.
    max_logit_scale: float = 100.0
# Output container returned by `Aimv2Model.forward`; inherits `SiglipOutput` without modification.
class Aimv2Output(SiglipOutput):
    pass
# RMS normalisation layer used throughout AIMv2; inherits `LlamaRMSNorm` without modification.
class Aimv2RMSNorm(LlamaRMSNorm):
    pass
# Feed-forward block used by `Aimv2EncoderLayer`; inherits `LlamaMLP` without modification.
class Aimv2MLP(LlamaMLP):
    pass
  117. class Aimv2VisionEmbeddings(nn.Module):
  118. def __init__(self, config: Aimv2VisionConfig):
  119. super().__init__()
  120. self.config = config
  121. self.patch_size = config.patch_size
  122. self.patch_embed = nn.Conv2d(
  123. config.num_channels, config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size
  124. )
  125. self.rms_norm = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
  126. num_patches = (config.image_size // config.patch_size) ** 2
  127. if not self.config.is_native:
  128. self.position_embedding = nn.Embedding(num_patches, config.hidden_size)
  129. self.register_buffer("position_ids", torch.arange(num_patches).expand((1, -1)), persistent=False)
  130. @staticmethod
  131. def build_2d_sincos_position_embedding(
  132. height, width, embed_dim=256, temperature=10000.0, device="cpu", dtype=torch.float32
  133. ) -> torch.Tensor:
  134. grid_w = torch.arange(int(width), dtype=dtype, device=device)
  135. grid_h = torch.arange(int(height), dtype=dtype, device=device)
  136. grid_h, grid_w = torch.meshgrid(grid_w, grid_h, indexing="xy")
  137. pos_dim = embed_dim // 4
  138. omega = torch.arange(pos_dim, dtype=dtype, device=device) / pos_dim
  139. omega = 1.0 / (temperature**omega)
  140. out_h = grid_h.flatten()[..., None] @ omega[None, :]
  141. out_w = grid_w.flatten()[..., None] @ omega[None, :]
  142. return torch.concat([out_h.sin(), out_h.cos(), out_w.sin(), out_w.cos()], dim=1)[None, :, :]
  143. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  144. _, _, height, width = pixel_values.size()
  145. hidden_states = self.patch_embed(pixel_values).flatten(2).transpose(1, 2)
  146. hidden_states = self.rms_norm(hidden_states)
  147. if self.config.is_native:
  148. pos_embed = self.build_2d_sincos_position_embedding(
  149. height // self.patch_size,
  150. width // self.patch_size,
  151. embed_dim=self.config.hidden_size,
  152. device=hidden_states.device,
  153. dtype=hidden_states.dtype,
  154. )
  155. else:
  156. pos_embed = self.position_embedding(self.position_ids)
  157. hidden_states = hidden_states + pos_embed
  158. return hidden_states
# Token + position embeddings for the text tower; inherits `CLIPTextEmbeddings` without modification.
class Aimv2TextEmbeddings(CLIPTextEmbeddings):
    pass
# Self-attention for AIMv2: `SiglipAttention` with its four linear projections rebuilt so that the
# presence of a bias term follows `config.qkv_bias`.
class Aimv2Attention(SiglipAttention):
    def __init__(self, config):
        super().__init__(config)
        # Replace the projections created by the parent constructor. Note that the output projection
        # also takes its bias flag from `qkv_bias`.
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_bias)
  168. class Aimv2EncoderLayer(GradientCheckpointingLayer):
  169. def __init__(self, config: Aimv2VisionConfig):
  170. super().__init__()
  171. self.attention = Aimv2Attention(config)
  172. self.ffn = Aimv2MLP(config)
  173. self.rms_norm1 = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
  174. self.rms_norm2 = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
  175. def forward(
  176. self,
  177. hidden_states: torch.Tensor,
  178. attention_mask: torch.Tensor | None = None,
  179. **kwargs: Unpack[TransformersKwargs],
  180. ) -> torch.Tensor:
  181. norm_hidden_states = self.rms_norm1(hidden_states)
  182. attn_output, _ = self.attention(hidden_states=norm_hidden_states, attention_mask=attention_mask, **kwargs)
  183. hidden_states = hidden_states + attn_output
  184. norm_hidden_states = self.rms_norm2(hidden_states)
  185. mlp_output = self.ffn(norm_hidden_states)
  186. hidden_states = hidden_states + mlp_output
  187. return hidden_states
# Stack of encoder layers shared by the vision and text towers; inherits `SiglipEncoder` without modification.
class Aimv2Encoder(SiglipEncoder):
    pass
  190. class Aimv2AttentionPoolingHead(nn.Module):
  191. def __init__(self, config: Aimv2VisionConfig):
  192. super().__init__()
  193. self.hidden_size = config.hidden_size
  194. self.num_heads = config.num_attention_heads
  195. self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.qkv_bias)
  196. self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.qkv_bias)
  197. self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_size))
  198. self.output_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
  199. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  200. batch_size, seq_len, hidden_dim = hidden_states.shape
  201. cls_token = self.cls_token.expand(batch_size, -1, -1)
  202. key = self.k_proj(hidden_states).reshape(batch_size, seq_len, self.num_heads, hidden_dim // self.num_heads)
  203. value = self.v_proj(hidden_states).reshape(batch_size, seq_len, self.num_heads, hidden_dim // self.num_heads)
  204. query = cls_token.reshape(batch_size, 1, self.num_heads, hidden_dim // self.num_heads)
  205. key = key.permute(0, 2, 1, 3)
  206. value = value.permute(0, 2, 1, 3)
  207. query = query.permute(0, 2, 1, 3)
  208. attn_output = F.scaled_dot_product_attention(query, key, value)
  209. attn_output = attn_output.transpose(1, 2).reshape(batch_size, 1, hidden_dim)
  210. attn_output = attn_output.mean(dim=1)
  211. output = self.output_proj(attn_output)
  212. return output
  213. @auto_docstring
  214. class Aimv2PreTrainedModel(PreTrainedModel):
  215. """
  216. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  217. models. The model is only intended for inference and doesn't support finetuning.
  218. """
  219. config: Aimv2Config
  220. base_model_prefix = "aimv2"
  221. input_modalities = ("image",)
  222. supports_gradient_checkpointing = True
  223. _no_split_modules = [
  224. "Aimv2EncoderLayer",
  225. "Aimv2AttentionPoolingHead",
  226. "Aimv2VisionEmbeddings",
  227. "Aimv2TextEmbeddings",
  228. ]
  229. _supports_sdpa = True
  230. _supports_flash_attn = True
  231. _supports_flex_attn = True
  232. @torch.no_grad()
  233. def _init_weights(self, module):
  234. super()._init_weights(module)
  235. if hasattr(module, "logit_scale"):
  236. if isinstance(module.logit_scale, nn.Parameter):
  237. init.constant_(module.logit_scale, math.log(1 / 0.07))
  238. elif isinstance(module, Aimv2AttentionPoolingHead):
  239. init.normal_(module.cls_token, mean=0.0, std=self.config.initializer_range)
  240. elif isinstance(module, Aimv2VisionEmbeddings):
  241. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  242. elif isinstance(module, Aimv2TextEmbeddings):
  243. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  244. @auto_docstring(
  245. custom_intro="""
  246. The Vision model from AIMv2 without any head or projection on top.
  247. """
  248. )
  249. class Aimv2VisionModel(Aimv2PreTrainedModel):
  250. config: Aimv2VisionConfig
  251. main_input_name = "pixel_values"
  252. _can_record_outputs = {
  253. "hidden_states": Aimv2EncoderLayer,
  254. "attentions": Aimv2Attention,
  255. }
  256. def __init__(self, config: Aimv2VisionConfig):
  257. super().__init__(config)
  258. self.config = config
  259. self.embeddings = Aimv2VisionEmbeddings(config)
  260. self.encoder = Aimv2Encoder(config)
  261. # The only change from SiglipVisionTransformer is, layernorm -> rms_norm.
  262. self.rms_norm = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
  263. self.use_head = config.use_head
  264. if self.use_head:
  265. self.head = Aimv2AttentionPoolingHead(config)
  266. self.post_init()
  267. def get_input_embeddings(self) -> nn.Module:
  268. return self.embeddings.patch_embed
  269. @merge_with_config_defaults
  270. @capture_outputs(tie_last_hidden_states=False)
  271. @auto_docstring
  272. def forward(
  273. self,
  274. pixel_values,
  275. **kwargs: Unpack[TransformersKwargs],
  276. ) -> BaseModelOutputWithPooling:
  277. r"""
  278. Examples:
  279. ```python
  280. >>> from PIL import Image
  281. >>> import httpx
  282. >>> from io import BytesIO
  283. >>> from transformers import AutoProcessor, Siglip2VisionModel
  284. >>> model = Aimv2VisionModel.from_pretrained("apple/aimv2-large-patch14-native")
  285. >>> processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-native")
  286. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  287. >>> with httpx.stream("GET", url) as response:
  288. ... image = Image.open(BytesIO(response.read()))
  289. >>> inputs = processor(images=image, return_tensors="pt")
  290. >>> outputs = model(**inputs)
  291. >>> last_hidden_state = outputs.last_hidden_state
  292. >>> pooled_output = outputs.pooler_output # pooled features
  293. ```"""
  294. hidden_states = self.embeddings(pixel_values)
  295. encoder_outputs: BaseModelOutput = self.encoder(
  296. inputs_embeds=hidden_states,
  297. **kwargs,
  298. )
  299. last_hidden_state = encoder_outputs.last_hidden_state
  300. last_hidden_state = self.rms_norm(last_hidden_state)
  301. pooler_output = self.head(last_hidden_state) if self.use_head else None
  302. return BaseModelOutputWithPooling(
  303. last_hidden_state=last_hidden_state,
  304. pooler_output=pooler_output,
  305. )
  306. @auto_docstring(
  307. custom_intro="""
  308. The text model from AIMv2 without any head or projection on top.
  309. """
  310. )
  311. class Aimv2TextModel(Aimv2PreTrainedModel):
  312. main_input_name = "input_ids"
  313. _can_record_outputs = {
  314. "hidden_states": Aimv2EncoderLayer,
  315. "attentions": Aimv2Attention,
  316. }
  317. def __init__(self, config: Aimv2TextConfig):
  318. super().__init__(config)
  319. self.config = config
  320. self.embeddings = Aimv2TextEmbeddings(config)
  321. self.encoder = Aimv2Encoder(config)
  322. self.rms_norm = Aimv2RMSNorm(config.hidden_size, config.rms_norm_eps)
  323. self.eos_token_id = config.eos_token_id
  324. self.post_init()
  325. def get_input_embeddings(self) -> nn.Module:
  326. return self.embeddings.token_embedding
  327. def set_input_embeddings(self, value):
  328. self.embeddings.token_embedding = value
  329. @merge_with_config_defaults
  330. @capture_outputs(tie_last_hidden_states=False)
  331. @auto_docstring
  332. def forward(
  333. self,
  334. input_ids,
  335. attention_mask: torch.Tensor | None = None,
  336. **kwargs: Unpack[TransformersKwargs],
  337. ) -> BaseModelOutputWithPooling:
  338. hidden_states = self.embeddings(input_ids)
  339. batch_size, seq_len, _ = hidden_states.shape
  340. position_ids = torch.arange(seq_len, dtype=torch.long, device=hidden_states.device)
  341. position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
  342. if attention_mask is not None:
  343. attention_mask = create_causal_mask(
  344. config=self.config,
  345. inputs_embeds=hidden_states,
  346. position_ids=position_ids,
  347. attention_mask=attention_mask,
  348. past_key_values=None,
  349. )
  350. encoder_outputs = self.encoder(
  351. inputs_embeds=hidden_states,
  352. attention_mask=attention_mask,
  353. **kwargs,
  354. )
  355. last_hidden_state = encoder_outputs.last_hidden_state
  356. last_hidden_state = self.rms_norm(last_hidden_state)
  357. # Get pooled output
  358. pooled_output = last_hidden_state[
  359. torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
  360. (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id).int().argmax(dim=-1),
  361. ]
  362. return BaseModelOutputWithPooling(
  363. last_hidden_state=last_hidden_state,
  364. pooler_output=pooled_output,
  365. )
  366. @auto_docstring
  367. class Aimv2Model(CLIPModel):
  368. _supports_flash_attn = True
  369. def __init__(self, config: Aimv2Config):
  370. PreTrainedModel.__init__(self, config)
  371. self.projection_dim = config.projection_dim
  372. self.vision_embed_dim = config.vision_config.hidden_size
  373. self.text_embed_dim = config.text_config.hidden_size
  374. self.vision_model = Aimv2VisionModel._from_config(config.vision_config)
  375. self.text_model = Aimv2TextModel._from_config(config.text_config)
  376. self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
  377. self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
  378. self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
  379. self.max_log_logit_scale = math.log(config.max_logit_scale)
  380. self.post_init()
  381. @auto_docstring
  382. @can_return_tuple
  383. def forward(
  384. self,
  385. input_ids: torch.LongTensor | None = None,
  386. pixel_values: torch.FloatTensor | None = None,
  387. attention_mask: torch.Tensor | None = None,
  388. **kwargs: Unpack[TransformersKwargs],
  389. ) -> Aimv2Output:
  390. r"""
  391. Examples:
  392. ```python
  393. >>> from PIL import Image
  394. >>> import httpx
  395. >>> from io import BytesIO
  396. >>> from transformers import AutoProcessor, Aimv2Model
  397. >>> model = Aimv2Model.from_pretrained("apple/aimv2-large-patch14-224-lit")
  398. >>> processor = AutoProcessor.from_pretrained("apple/aimv2-large-patch14-224-lit")
  399. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  400. >>> with httpx.stream("GET", url) as response:
  401. ... image = Image.open(BytesIO(response.read()))
  402. >>> inputs = processor(
  403. ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
  404. ... )
  405. >>> outputs = model(**inputs)
  406. >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
  407. >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
  408. ```"""
  409. vision_outputs: BaseModelOutputWithPooling = self.vision_model(
  410. pixel_values=pixel_values,
  411. **kwargs,
  412. )
  413. text_outputs: BaseModelOutputWithPooling = self.text_model(
  414. input_ids=input_ids,
  415. attention_mask=attention_mask,
  416. **kwargs,
  417. )
  418. image_embeds = vision_outputs.pooler_output
  419. image_embeds = self.visual_projection(image_embeds)
  420. text_embeds = text_outputs.pooler_output
  421. text_embeds = self.text_projection(text_embeds)
  422. # normalized features
  423. image_embeds = image_embeds / _get_vector_norm(image_embeds)
  424. text_embeds = text_embeds / _get_vector_norm(text_embeds)
  425. logit_scale = self.logit_scale.clamp(0.0, self.max_log_logit_scale).exp().to(text_embeds.device)
  426. logits_per_text = (logit_scale * text_embeds) @ image_embeds.t()
  427. logits_per_image = logits_per_text.t()
  428. return Aimv2Output(
  429. logits_per_image=logits_per_image,
  430. logits_per_text=logits_per_text,
  431. text_embeds=text_embeds,
  432. image_embeds=image_embeds,
  433. text_model_output=text_outputs,
  434. vision_model_output=vision_outputs,
  435. )
# Public API of this module: the three configuration classes and the four model classes.
__all__ = [
    "Aimv2Config",
    "Aimv2VisionConfig",
    "Aimv2TextConfig",
    "Aimv2VisionModel",
    "Aimv2Model",
    "Aimv2PreTrainedModel",
    "Aimv2TextModel",
]