modeling_pixtral.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485
  1. # Copyright 2024 Mistral and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch Pixtral model."""
  15. from collections.abc import Callable
  16. from typing import Optional
  17. import torch
  18. from torch import nn
  19. from ...activations import ACT2FN
  20. from ...modeling_layers import GradientCheckpointingLayer
  21. from ...modeling_outputs import BaseModelOutput
  22. from ...modeling_rope_utils import dynamic_rope_update
  23. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  24. from ...processing_utils import Unpack
  25. from ...utils import TransformersKwargs, auto_docstring, logging
  26. from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults
  27. from ...utils.output_capturing import capture_outputs
  28. from .configuration_pixtral import PixtralVisionConfig
  29. logger = logging.get_logger(__name__)
  30. def position_ids_in_meshgrid(patch_embeds_list, max_width):
  31. positions = []
  32. for patch in patch_embeds_list:
  33. height, width = patch.shape[-2:]
  34. mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
  35. h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1)
  36. ids = h_grid * max_width + v_grid
  37. positions.append(ids[:, 0])
  38. return torch.cat(positions)
  39. class PixtralRotaryEmbedding(nn.Module):
  40. """
  41. The key with pixtral embedding is just that you have a frequency for each pixel positions.
  42. If you have height x width pixels (or embedding pixels), then the frequency used for ROPE
  43. is given by indexing the pre_computed frequency on the width and height.
  44. What you output is of dimension (batch, height * width, dim) with dim the embed dim.
  45. This simply means that for each image hidden state, you are going to add
  46. a corresponding positional embedding, based on its index in the grid.
  47. """
  48. inv_freq: torch.Tensor # fix linting for `register_buffer`
  49. def __init__(self, config: PixtralVisionConfig, device=None, layer_type=None):
  50. super().__init__()
  51. self.config = config
  52. self.rope_type = self.config.rope_parameters["rope_type"]
  53. rope_init_fn: Callable = self.compute_default_rope_parameters
  54. if self.rope_type != "default":
  55. raise ValueError(
  56. f"{self.__class__.__name__} does not support non-default RoPE, but got `rope_type={self.rope_type}`"
  57. )
  58. inv_freq, attention_scaling = rope_init_fn(self.config, device)
  59. self.register_buffer("inv_freq", inv_freq, persistent=False)
  60. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  61. @staticmethod
  62. def compute_default_rope_parameters(
  63. config: PixtralVisionConfig | None = None,
  64. device: Optional["torch.device"] = None,
  65. seq_len: int | None = None,
  66. ) -> tuple["torch.Tensor", float]:
  67. """
  68. Computes the inverse frequencies according to the original RoPE implementation
  69. Args:
  70. config ([`~transformers.PreTrainedConfig`]):
  71. The model configuration.
  72. device (`torch.device`):
  73. The device to use for initialization of the inverse frequencies.
  74. seq_len (`int`, *optional*):
  75. The current sequence length. Unused for this type of RoPE.
  76. Returns:
  77. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  78. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  79. """
  80. base = config.rope_parameters["rope_theta"]
  81. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  82. attention_factor = 1.0 # Unused in this type of RoPE
  83. # Here is the diff from Llama RoPE
  84. max_patches_per_side = config.image_size // config.patch_size
  85. h = torch.arange(max_patches_per_side)
  86. w = torch.arange(max_patches_per_side)
  87. freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
  88. freqs_h = torch.outer(h, freqs[::2]).float()
  89. freqs_w = torch.outer(w, freqs[1::2]).float()
  90. inv_freq = torch.cat(
  91. [
  92. freqs_h[:, None, :].repeat(1, max_patches_per_side, 1),
  93. freqs_w[None, :, :].repeat(max_patches_per_side, 1, 1),
  94. ],
  95. dim=-1,
  96. ).reshape(-1, dim // 2) # we reshape to only index on the position indexes, not tuple of indexes
  97. # Different from paper, but it uses a different permutation in order to obtain the same calculation
  98. # TODO maybe make it torch compatible later on. We can also just slice
  99. inv_freq = torch.cat((inv_freq, inv_freq), dim=-1)
  100. return inv_freq, attention_factor
  101. @torch.no_grad()
  102. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  103. def forward(self, x, position_ids):
  104. freqs = self.inv_freq[position_ids]
  105. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  106. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  107. emb = freqs
  108. cos = emb.cos()
  109. sin = emb.sin()
  110. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  111. # Copied from transformers.models.llama.modeling_llama.rotate_half
  112. def rotate_half(x):
  113. """Rotates half the hidden dims of the input."""
  114. x1 = x[..., : x.shape[-1] // 2]
  115. x2 = x[..., x.shape[-1] // 2 :]
  116. return torch.cat((-x2, x1), dim=-1)
  117. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  118. """Applies Rotary Position Embedding to the query and key tensors.
  119. Args:
  120. q (`torch.Tensor`): The query tensor.
  121. k (`torch.Tensor`): The key tensor.
  122. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  123. sin (`torch.Tensor`): The sine part of the rotary embedding.
  124. unsqueeze_dim (`int`, *optional*, defaults to 1):
  125. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  126. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  127. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  128. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  129. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  130. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  131. Returns:
  132. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  133. """
  134. cos = cos.unsqueeze(unsqueeze_dim)
  135. sin = sin.unsqueeze(unsqueeze_dim)
  136. q_embed = (q * cos) + (rotate_half(q) * sin)
  137. k_embed = (k * cos) + (rotate_half(k) * sin)
  138. return q_embed, k_embed
  139. # Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward
  140. def eager_attention_forward(
  141. module: nn.Module,
  142. query: torch.Tensor,
  143. key: torch.Tensor,
  144. value: torch.Tensor,
  145. attention_mask: torch.Tensor | None,
  146. scaling: float,
  147. dropout: float = 0.0,
  148. **kwargs,
  149. ):
  150. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  151. if attention_mask is not None:
  152. attn_weights = attn_weights + attention_mask
  153. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  154. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  155. attn_output = torch.matmul(attn_weights, value)
  156. attn_output = attn_output.transpose(1, 2).contiguous()
  157. return attn_output, attn_weights
  158. class PixtralAttention(nn.Module):
  159. """
  160. Multi-headed attention compatible with ALL_ATTENTION_FUNCTIONS.
  161. """
  162. def __init__(self, config):
  163. super().__init__()
  164. self.config = config
  165. self.embed_dim = config.hidden_size
  166. self.num_heads = config.num_attention_heads
  167. self.head_dim = self.embed_dim // self.num_heads
  168. self.is_causal = False
  169. self.scaling = self.head_dim**-0.5
  170. self.is_causal = False
  171. self.dropout = config.attention_dropout
  172. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
  173. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
  174. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
  175. self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
  176. def forward(
  177. self,
  178. hidden_states: torch.Tensor,
  179. attention_mask: torch.Tensor | None = None,
  180. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  181. **kwargs: Unpack[TransformersKwargs],
  182. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  183. """Input shape: Batch x Time x Channel"""
  184. batch_size, patches, _ = hidden_states.size()
  185. query_states = self.q_proj(hidden_states)
  186. key_states = self.k_proj(hidden_states)
  187. value_states = self.v_proj(hidden_states)
  188. query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
  189. key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
  190. value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
  191. cos, sin = position_embeddings
  192. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=0)
  193. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  194. self.config._attn_implementation, eager_attention_forward
  195. )
  196. attn_output, attn_weights = attention_interface(
  197. self,
  198. query_states,
  199. key_states,
  200. value_states,
  201. attention_mask,
  202. dropout=0.0 if not self.training else self.dropout,
  203. scaling=self.scaling,
  204. **kwargs,
  205. )
  206. attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()
  207. attn_output = self.o_proj(attn_output)
  208. return attn_output, attn_weights
  209. # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Pixtral
  210. class PixtralMLP(nn.Module):
  211. def __init__(self, config):
  212. super().__init__()
  213. self.config = config
  214. self.hidden_size = config.hidden_size
  215. self.intermediate_size = config.intermediate_size
  216. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  217. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  218. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  219. self.act_fn = ACT2FN[config.hidden_act]
  220. def forward(self, x):
  221. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  222. return down_proj
  223. # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Pixtral
  224. class PixtralRMSNorm(nn.Module):
  225. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  226. """
  227. PixtralRMSNorm is equivalent to T5LayerNorm
  228. """
  229. super().__init__()
  230. self.weight = nn.Parameter(torch.ones(hidden_size))
  231. self.variance_epsilon = eps
  232. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  233. input_dtype = hidden_states.dtype
  234. hidden_states = hidden_states.to(torch.float32)
  235. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  236. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  237. return self.weight * hidden_states.to(input_dtype)
  238. def extra_repr(self):
  239. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  240. class PixtralAttentionLayer(GradientCheckpointingLayer):
  241. def __init__(self, config):
  242. super().__init__()
  243. self.attention_norm = PixtralRMSNorm(config.hidden_size, eps=1e-5)
  244. self.feed_forward = PixtralMLP(config)
  245. self.attention = PixtralAttention(config)
  246. self.ffn_norm = PixtralRMSNorm(config.hidden_size, eps=1e-5)
  247. def forward(
  248. self,
  249. hidden_states: torch.Tensor,
  250. attention_mask: torch.Tensor,
  251. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  252. **kwargs: Unpack[TransformersKwargs],
  253. ) -> torch.Tensor:
  254. """
  255. Args:
  256. hidden_states (`torch.FloatTensor`):
  257. Input to the layer of shape `(batch, seq_len, embed_dim)`.
  258. attention_mask (`torch.FloatTensor`):
  259. Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
  260. """
  261. residual = hidden_states
  262. hidden_states = self.attention_norm(hidden_states)
  263. hidden_states, _ = self.attention(
  264. hidden_states=hidden_states,
  265. attention_mask=attention_mask,
  266. position_embeddings=position_embeddings,
  267. **kwargs,
  268. )
  269. hidden_states = residual + hidden_states
  270. residual = hidden_states
  271. hidden_states = self.ffn_norm(hidden_states)
  272. hidden_states = self.feed_forward(hidden_states)
  273. hidden_states = residual + hidden_states
  274. return hidden_states
  275. class PixtralTransformer(nn.Module):
  276. def __init__(self, config):
  277. super().__init__()
  278. self.config = config
  279. self.layers = torch.nn.ModuleList()
  280. for _ in range(config.num_hidden_layers):
  281. self.layers.append(PixtralAttentionLayer(config))
  282. self.gradient_checkpointing = False
  283. def forward(
  284. self,
  285. inputs_embeds,
  286. attention_mask: torch.Tensor | None = None,
  287. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  288. **kwargs: Unpack[TransformersKwargs],
  289. ) -> tuple | BaseModelOutput:
  290. r"""
  291. Args:
  292. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  293. Embeddings which serve as input to the Transformer.
  294. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  295. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  296. - 1 for tokens that are **not masked**,
  297. - 0 for tokens that are **masked**.
  298. [What are attention masks?](../glossary#attention-mask)
  299. """
  300. hidden_states = inputs_embeds
  301. for encoder_layer in self.layers:
  302. hidden_states = encoder_layer(
  303. hidden_states,
  304. attention_mask,
  305. position_embeddings=position_embeddings,
  306. **kwargs,
  307. )
  308. return BaseModelOutput(last_hidden_state=hidden_states)
  309. @auto_docstring
  310. class PixtralPreTrainedModel(PreTrainedModel):
  311. config: PixtralVisionConfig
  312. base_model_prefix = "model"
  313. main_input_name = "pixel_values"
  314. input_modalities = ("image",)
  315. supports_gradient_checkpointing = True
  316. _supports_attention_backend = True
  317. _supports_flash_attn = True
  318. _supports_sdpa = True
  319. _supports_flex_attn = True
  320. _no_split_modules = ["PixtralAttentionLayer"]
  321. _can_record_outputs = {
  322. "hidden_states": PixtralAttentionLayer,
  323. "attentions": PixtralAttention,
  324. }
  325. def generate_block_attention_mask(patch_embeds_list, tensor):
  326. dtype = tensor.dtype
  327. device = tensor.device
  328. seq_len = tensor.shape[1]
  329. d_min = torch.finfo(dtype).min
  330. causal_mask = torch.full((seq_len, seq_len), fill_value=d_min, dtype=dtype, device=device)
  331. block_end_idx = torch.tensor(patch_embeds_list).cumsum(-1)
  332. block_start_idx = torch.tensor([0] + patch_embeds_list[:-1]).cumsum(-1)
  333. for start, end in zip(block_start_idx, block_end_idx):
  334. causal_mask[start:end, start:end] = 0
  335. causal_mask = causal_mask[None, None, :, :].expand(tensor.shape[0], 1, -1, -1)
  336. return causal_mask
  337. @auto_docstring
  338. class PixtralVisionModel(PixtralPreTrainedModel):
  339. base_model_prefix = "vision_encoder"
  340. def __init__(self, config):
  341. super().__init__(config)
  342. self.config = config
  343. self.patch_conv = nn.Conv2d(
  344. in_channels=config.num_channels,
  345. out_channels=config.hidden_size,
  346. kernel_size=config.patch_size,
  347. stride=config.patch_size,
  348. bias=False,
  349. )
  350. self.patch_size = config.patch_size
  351. self.ln_pre = PixtralRMSNorm(config.hidden_size, eps=1e-5)
  352. self.transformer = PixtralTransformer(config)
  353. self.patch_positional_embedding = PixtralRotaryEmbedding(config)
  354. self.post_init()
  355. def get_input_embeddings(self):
  356. return self.patch_conv
  357. @merge_with_config_defaults
  358. @capture_outputs
  359. @auto_docstring
  360. def forward(
  361. self,
  362. pixel_values: torch.Tensor,
  363. image_sizes: torch.Tensor | None = None,
  364. **kwargs: Unpack[TransformersKwargs],
  365. ) -> tuple | BaseModelOutput:
  366. if image_sizes is None:
  367. batch_size, _, height, width = pixel_values.shape
  368. image_sizes = [(height, width)] * batch_size
  369. # pass images through initial convolution independently
  370. target_dtype = self.patch_conv.weight.dtype
  371. patch_embeds = self.patch_conv(pixel_values.to(dtype=target_dtype))
  372. patch_embeds_list = [
  373. embed[..., : (size[0] // self.patch_size), : (size[1] // self.patch_size)]
  374. for embed, size in zip(patch_embeds, image_sizes)
  375. ]
  376. # flatten to a single sequence
  377. patch_embeds = torch.cat([p.flatten(1).T for p in patch_embeds_list], dim=0).unsqueeze(0)
  378. patch_embeds = self.ln_pre(patch_embeds)
  379. # positional embeddings
  380. position_ids = position_ids_in_meshgrid(
  381. patch_embeds_list, max_width=self.config.image_size // self.config.patch_size
  382. )
  383. kwargs["position_ids"] = position_ids.unsqueeze(0).to(patch_embeds.device, non_blocking=True)
  384. position_embeddings = self.patch_positional_embedding(patch_embeds, position_ids)
  385. if is_flash_attention_requested(self.config):
  386. # We only rely on position_ids when using flash attention
  387. attention_mask = None
  388. else:
  389. attention_mask = generate_block_attention_mask(
  390. [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds
  391. )
  392. return self.transformer(
  393. patch_embeds,
  394. attention_mask=attention_mask,
  395. position_embeddings=position_embeddings,
  396. **kwargs,
  397. )
  398. __all__ = ["PixtralVisionModel", "PixtralPreTrainedModel"]