vision.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396
  1. # Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch IdeficsVision model: a copy of CLIPVisionModel using a simpler config object"""
  15. import math
  16. from collections.abc import Callable
  17. from dataclasses import dataclass
  18. import torch
  19. from torch import nn
  20. from ...activations import ACT2FN
  21. from ...modeling_layers import GradientCheckpointingLayer
  22. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
  23. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  24. from ...processing_utils import Unpack
  25. from ...utils import (
  26. ModelOutput,
  27. TransformersKwargs,
  28. logging,
  29. )
  30. from .configuration_idefics import IdeficsVisionConfig
  31. logger = logging.get_logger(__name__)
  32. @dataclass
  33. class IdeficsVisionModelOutput(ModelOutput):
  34. """
  35. Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
  36. Args:
  37. image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
  38. The image embeddings obtained by applying the projection layer to the pooler_output.
  39. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  40. Sequence of hidden-states at the output of the last layer of the model.
  41. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  42. Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
  43. one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
  44. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
  45. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  46. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  47. sequence_length)`.
  48. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
  49. heads.
  50. """
  51. image_embeds: torch.FloatTensor | None = None
  52. last_hidden_state: torch.FloatTensor | None = None
  53. hidden_states: tuple[torch.FloatTensor, ...] | None = None
  54. attentions: tuple[torch.FloatTensor, ...] | None = None
  55. # Adapted from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings
  56. class IdeficsVisionEmbeddings(nn.Module):
  57. def __init__(self, config: IdeficsVisionConfig):
  58. super().__init__()
  59. self.config = config
  60. self.embed_dim = config.hidden_size
  61. self.image_size = config.image_size
  62. self.patch_size = config.patch_size
  63. self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
  64. self.patch_embedding = nn.Conv2d(
  65. in_channels=config.num_channels,
  66. out_channels=self.embed_dim,
  67. kernel_size=self.patch_size,
  68. stride=self.patch_size,
  69. bias=False,
  70. )
  71. self.num_patches = (self.image_size // self.patch_size) ** 2
  72. self.num_positions = self.num_patches + 1
  73. self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
  74. self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
  75. # Heavily inspired from https://github.com/huggingface/transformers/blob/v4.33.0/src/transformers/models/vit/modeling_vit.py#L82
  76. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  77. """
  78. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
  79. resolution images.
  80. Source:
  81. https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
  82. """
  83. num_patches = embeddings.shape[1] - 1
  84. pos_embed = self.position_embedding(self.position_ids)
  85. num_positions = pos_embed.shape[1] - 1
  86. if num_patches == num_positions and height == width:
  87. return pos_embed
  88. class_pos_embed = pos_embed[:, 0]
  89. patch_pos_embed = pos_embed[:, 1:]
  90. embed_dim = embeddings.shape[-1]
  91. num_h_patches = height // self.config.patch_size
  92. num_w_patches = width // self.config.patch_size
  93. # we add a small number to avoid floating point error in the interpolation
  94. # see discussion at https://github.com/facebookresearch/dino/issues/8
  95. num_h_patches, num_w_patches = num_h_patches + 0.1, num_w_patches + 0.1
  96. sqrt_num_positions = math.sqrt(num_positions)
  97. patch_pos_embed = patch_pos_embed.reshape(1, int(sqrt_num_positions), int(sqrt_num_positions), embed_dim)
  98. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  99. fp32_upcasting = patch_pos_embed.dtype == torch.bfloat16
  100. if fp32_upcasting:
  101. logger.warning_once(
  102. "Upcasting patch_pos_embed to fp32 for interpolation since `upsample_bicubic2d_out_frame` in nn.functional.interpolate "
  103. "is not implemented for 'torch.bfloat16' dtype. This will result in a slight overhead."
  104. )
  105. patch_pos_embed = patch_pos_embed.to(torch.float)
  106. patch_pos_embed = nn.functional.interpolate(
  107. patch_pos_embed,
  108. scale_factor=(num_h_patches / sqrt_num_positions, num_w_patches / sqrt_num_positions),
  109. mode="bicubic",
  110. align_corners=False,
  111. )
  112. if fp32_upcasting:
  113. patch_pos_embed = patch_pos_embed.to(torch.bfloat16)
  114. if int(num_h_patches) != patch_pos_embed.shape[-2] or int(num_w_patches) != patch_pos_embed.shape[-1]:
  115. raise ValueError(
  116. f"Number of patches for images ({int(num_h_patches), int(num_w_patches)}) don't match the "
  117. f"shape of position embedding ({patch_pos_embed.shape[-2], patch_pos_embed.shape[-1]})"
  118. )
  119. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, embed_dim)
  120. return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
  121. def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
  122. batch_size, num_channels, height, width = pixel_values.shape
  123. if not interpolate_pos_encoding:
  124. if height != self.image_size or width != self.image_size:
  125. raise ValueError(
  126. f"Input image size ({height}*{width}) doesn't match model"
  127. f" ({self.image_size}*{self.image_size}). You should try to set `interpolate_pos_encoding=True`"
  128. )
  129. target_dtype = self.patch_embedding.weight.dtype
  130. patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
  131. patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
  132. class_embeds = self.class_embedding.expand(batch_size, 1, -1)
  133. embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
  134. # add positional encoding to each token
  135. if interpolate_pos_encoding:
  136. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  137. else:
  138. embeddings = embeddings + self.position_embedding(self.position_ids)
  139. return embeddings
  140. # Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward
  141. def eager_attention_forward(
  142. module: nn.Module,
  143. query: torch.Tensor,
  144. key: torch.Tensor,
  145. value: torch.Tensor,
  146. attention_mask: torch.Tensor | None,
  147. scaling: float,
  148. dropout: float = 0.0,
  149. **kwargs,
  150. ):
  151. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  152. if attention_mask is not None:
  153. attn_weights = attn_weights + attention_mask
  154. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  155. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  156. attn_output = torch.matmul(attn_weights, value)
  157. attn_output = attn_output.transpose(1, 2).contiguous()
  158. return attn_output, attn_weights
  159. class IdeficsVisionAttention(nn.Module):
  160. """Multi-headed attention from 'Attention Is All You Need' paper"""
  161. def __init__(self, config: IdeficsVisionConfig):
  162. super().__init__()
  163. self.config = config
  164. self.embed_dim = config.hidden_size
  165. self.num_heads = config.num_attention_heads
  166. self.head_dim = self.embed_dim // self.num_heads
  167. if self.head_dim * self.num_heads != self.embed_dim:
  168. raise ValueError(
  169. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  170. f" {self.num_heads})."
  171. )
  172. self.scale = self.head_dim**-0.5
  173. self.dropout = config.attention_dropout
  174. self.is_causal = False
  175. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  176. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  177. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  178. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  179. def forward(
  180. self,
  181. hidden_states: torch.Tensor,
  182. attention_mask: torch.Tensor | None = None,
  183. **kwargs: Unpack[TransformersKwargs],
  184. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  185. """Input shape: Batch x Time x Channel"""
  186. input_shape = hidden_states.shape[:-1]
  187. hidden_shape = (*input_shape, -1, self.head_dim)
  188. queries = self.q_proj(hidden_states)
  189. keys = self.k_proj(hidden_states)
  190. values = self.v_proj(hidden_states)
  191. queries = queries.view(hidden_shape).transpose(1, 2)
  192. keys = keys.view(hidden_shape).transpose(1, 2)
  193. values = values.view(hidden_shape).transpose(1, 2)
  194. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  195. self.config._attn_implementation, eager_attention_forward
  196. )
  197. attn_output, attn_weights = attention_interface(
  198. self,
  199. queries,
  200. keys,
  201. values,
  202. attention_mask,
  203. is_causal=self.is_causal,
  204. scaling=self.scale,
  205. dropout=0.0 if not self.training else self.dropout,
  206. **kwargs,
  207. )
  208. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  209. attn_output = self.out_proj(attn_output)
  210. return attn_output, attn_weights
  211. # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->IdeficsVision
  212. class IdeficsVisionMLP(nn.Module):
  213. def __init__(self, config):
  214. super().__init__()
  215. self.config = config
  216. self.activation_fn = ACT2FN[config.hidden_act]
  217. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  218. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  219. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  220. hidden_states = self.fc1(hidden_states)
  221. hidden_states = self.activation_fn(hidden_states)
  222. hidden_states = self.fc2(hidden_states)
  223. return hidden_states
  224. # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->IdeficsVision
  225. class IdeficsVisionEncoderLayer(GradientCheckpointingLayer):
  226. def __init__(self, config: IdeficsVisionConfig):
  227. super().__init__()
  228. self.embed_dim = config.hidden_size
  229. self.self_attn = IdeficsVisionAttention(config)
  230. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  231. self.mlp = IdeficsVisionMLP(config)
  232. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  233. def forward(
  234. self,
  235. hidden_states: torch.Tensor,
  236. attention_mask: torch.Tensor,
  237. **kwargs: Unpack[TransformersKwargs],
  238. ) -> tuple[torch.FloatTensor, torch.Tensor | None]:
  239. residual = hidden_states
  240. hidden_states = self.layer_norm1(hidden_states)
  241. hidden_states, _ = self.self_attn(
  242. hidden_states=hidden_states,
  243. attention_mask=attention_mask,
  244. **kwargs,
  245. )
  246. hidden_states = residual + hidden_states
  247. residual = hidden_states
  248. hidden_states = self.layer_norm2(hidden_states)
  249. hidden_states = self.mlp(hidden_states)
  250. hidden_states = residual + hidden_states
  251. return hidden_states
  252. # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->IdeficsVision
  253. class IdeficsVisionEncoder(nn.Module):
  254. """
  255. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  256. [`IdeficsVisionEncoderLayer`].
  257. Args:
  258. config: IdeficsVisionConfig
  259. """
  260. def __init__(self, config: IdeficsVisionConfig):
  261. super().__init__()
  262. self.config = config
  263. self.layers = nn.ModuleList([IdeficsVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  264. self.gradient_checkpointing = False
  265. def forward(
  266. self,
  267. inputs_embeds,
  268. attention_mask: torch.Tensor | None = None,
  269. **kwargs: Unpack[TransformersKwargs],
  270. ) -> tuple | BaseModelOutput:
  271. r"""
  272. Args:
  273. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  274. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  275. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  276. than the model's internal embedding lookup matrix.
  277. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  278. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  279. - 1 for tokens that are **not masked**,
  280. - 0 for tokens that are **masked**.
  281. [What are attention masks?](../glossary#attention-mask)
  282. """
  283. hidden_states = inputs_embeds
  284. for encoder_layer in self.layers:
  285. hidden_states = encoder_layer(
  286. hidden_states,
  287. attention_mask,
  288. **kwargs,
  289. )
  290. return BaseModelOutput(
  291. last_hidden_state=hidden_states,
  292. )
  293. # Adapted from transformers.models.clip.modeling_clip.CLIPVisionTransformer
  294. class IdeficsVisionTransformer(nn.Module):
  295. def __init__(self, config: IdeficsVisionConfig):
  296. super().__init__()
  297. self.config = config
  298. embed_dim = config.hidden_size
  299. self.embeddings = IdeficsVisionEmbeddings(config)
  300. self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  301. self.encoder = IdeficsVisionEncoder(config)
  302. self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  303. # Adapted from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward
  304. def forward(
  305. self,
  306. pixel_values: torch.FloatTensor | None = None,
  307. interpolate_pos_encoding: bool | None = False,
  308. **kwargs,
  309. ) -> tuple | BaseModelOutputWithPooling:
  310. r"""
  311. Returns:
  312. """
  313. if pixel_values is None:
  314. raise ValueError("You have to specify pixel_values")
  315. hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  316. hidden_states = self.pre_layrnorm(hidden_states)
  317. encoder_outputs: BaseModelOutput = self.encoder(
  318. inputs_embeds=hidden_states,
  319. **kwargs,
  320. )
  321. last_hidden_state = encoder_outputs.last_hidden_state
  322. pooled_output = last_hidden_state[:, 0, :]
  323. pooled_output = self.post_layernorm(pooled_output)
  324. return BaseModelOutputWithPooling(
  325. last_hidden_state=last_hidden_state,
  326. pooler_output=pooled_output,
  327. )