modeling_mistral.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/mistral/modular_mistral.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_mistral.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. from collections.abc import Callable
  8. from typing import Optional
  9. import torch
  10. from torch import nn
  11. from ...activations import ACT2FN
  12. from ...cache_utils import Cache, DynamicCache
  13. from ...generation import GenerationMixin
  14. from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
  15. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  16. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  17. from ...modeling_layers import (
  18. GenericForQuestionAnswering,
  19. GenericForSequenceClassification,
  20. GenericForTokenClassification,
  21. GradientCheckpointingLayer,
  22. )
  23. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  24. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  25. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  26. from ...processing_utils import Unpack
  27. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  28. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  29. from ...utils.output_capturing import capture_outputs
  30. from .configuration_mistral import MistralConfig
  31. class MistralMLP(nn.Module):
  32. def __init__(self, config):
  33. super().__init__()
  34. self.config = config
  35. self.hidden_size = config.hidden_size
  36. self.intermediate_size = config.intermediate_size
  37. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  38. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  39. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  40. self.act_fn = ACT2FN[config.hidden_act]
  41. def forward(self, x):
  42. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  43. return down_proj
  44. def rotate_half(x):
  45. """Rotates half the hidden dims of the input."""
  46. x1 = x[..., : x.shape[-1] // 2]
  47. x2 = x[..., x.shape[-1] // 2 :]
  48. return torch.cat((-x2, x1), dim=-1)
  49. @use_kernel_func_from_hub("rotary_pos_emb")
  50. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  51. """Applies Rotary Position Embedding to the query and key tensors.
  52. Args:
  53. q (`torch.Tensor`): The query tensor.
  54. k (`torch.Tensor`): The key tensor.
  55. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  56. sin (`torch.Tensor`): The sine part of the rotary embedding.
  57. unsqueeze_dim (`int`, *optional*, defaults to 1):
  58. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  59. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  60. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  61. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  62. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  63. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  64. Returns:
  65. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  66. """
  67. cos = cos.unsqueeze(unsqueeze_dim)
  68. sin = sin.unsqueeze(unsqueeze_dim)
  69. q_embed = (q * cos) + (rotate_half(q) * sin)
  70. k_embed = (k * cos) + (rotate_half(k) * sin)
  71. return q_embed, k_embed
  72. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  73. """
  74. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  75. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  76. """
  77. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  78. if n_rep == 1:
  79. return hidden_states
  80. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  81. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  82. def eager_attention_forward(
  83. module: nn.Module,
  84. query: torch.Tensor,
  85. key: torch.Tensor,
  86. value: torch.Tensor,
  87. attention_mask: torch.Tensor | None,
  88. scaling: float,
  89. dropout: float = 0.0,
  90. **kwargs: Unpack[TransformersKwargs],
  91. ):
  92. key_states = repeat_kv(key, module.num_key_value_groups)
  93. value_states = repeat_kv(value, module.num_key_value_groups)
  94. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  95. if attention_mask is not None:
  96. attn_weights = attn_weights + attention_mask
  97. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  98. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  99. attn_output = torch.matmul(attn_weights, value_states)
  100. attn_output = attn_output.transpose(1, 2).contiguous()
  101. return attn_output, attn_weights
  102. @use_kernelized_func(apply_rotary_pos_emb)
  103. class MistralAttention(nn.Module):
  104. """Multi-headed attention from 'Attention Is All You Need' paper"""
  105. def __init__(self, config: MistralConfig, layer_idx: int):
  106. super().__init__()
  107. self.config = config
  108. self.layer_idx = layer_idx
  109. self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  110. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  111. self.scaling = self.head_dim**-0.5
  112. self.attention_dropout = config.attention_dropout
  113. self.is_causal = True
  114. self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
  115. self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
  116. self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
  117. self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
  118. def forward(
  119. self,
  120. hidden_states: torch.Tensor,
  121. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  122. attention_mask: torch.Tensor | None,
  123. past_key_values: Cache | None = None,
  124. **kwargs: Unpack[FlashAttentionKwargs],
  125. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  126. input_shape = hidden_states.shape[:-1]
  127. hidden_shape = (*input_shape, -1, self.head_dim)
  128. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  129. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  130. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  131. cos, sin = position_embeddings
  132. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  133. if past_key_values is not None:
  134. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  135. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  136. self.config._attn_implementation, eager_attention_forward
  137. )
  138. attn_output, attn_weights = attention_interface(
  139. self,
  140. query_states,
  141. key_states,
  142. value_states,
  143. attention_mask,
  144. dropout=0.0 if not self.training else self.attention_dropout,
  145. scaling=self.scaling,
  146. sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama
  147. **kwargs,
  148. )
  149. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  150. attn_output = self.o_proj(attn_output)
  151. return attn_output, attn_weights
  152. @use_kernel_forward_from_hub("RMSNorm")
  153. class MistralRMSNorm(nn.Module):
  154. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  155. """
  156. MistralRMSNorm is equivalent to T5LayerNorm
  157. """
  158. super().__init__()
  159. self.weight = nn.Parameter(torch.ones(hidden_size))
  160. self.variance_epsilon = eps
  161. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  162. input_dtype = hidden_states.dtype
  163. hidden_states = hidden_states.to(torch.float32)
  164. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  165. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  166. return self.weight * hidden_states.to(input_dtype)
  167. def extra_repr(self):
  168. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  169. class MistralDecoderLayer(GradientCheckpointingLayer):
  170. def __init__(self, config: MistralConfig, layer_idx: int):
  171. super().__init__()
  172. self.hidden_size = config.hidden_size
  173. self.self_attn = MistralAttention(config=config, layer_idx=layer_idx)
  174. self.mlp = MistralMLP(config)
  175. self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  176. self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  177. def forward(
  178. self,
  179. hidden_states: torch.Tensor,
  180. attention_mask: torch.Tensor | None = None,
  181. position_ids: torch.LongTensor | None = None,
  182. past_key_values: Cache | None = None,
  183. use_cache: bool | None = False,
  184. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  185. **kwargs: Unpack[TransformersKwargs],
  186. ) -> torch.Tensor:
  187. residual = hidden_states
  188. hidden_states = self.input_layernorm(hidden_states)
  189. # Self Attention
  190. hidden_states, _ = self.self_attn(
  191. hidden_states=hidden_states,
  192. attention_mask=attention_mask,
  193. position_ids=position_ids,
  194. past_key_values=past_key_values,
  195. use_cache=use_cache,
  196. position_embeddings=position_embeddings,
  197. **kwargs,
  198. )
  199. hidden_states = residual + hidden_states
  200. # Fully Connected
  201. residual = hidden_states
  202. hidden_states = self.post_attention_layernorm(hidden_states)
  203. hidden_states = self.mlp(hidden_states)
  204. hidden_states = residual + hidden_states
  205. return hidden_states
  206. @auto_docstring
  207. class MistralPreTrainedModel(PreTrainedModel):
  208. config: MistralConfig
  209. base_model_prefix = "model"
  210. supports_gradient_checkpointing = True
  211. _no_split_modules = ["MistralDecoderLayer"]
  212. _skip_keys_device_placement = ["past_key_values"]
  213. _supports_flash_attn = True
  214. _supports_sdpa = True
  215. _supports_flex_attn = True
  216. _can_compile_fullgraph = True
  217. _supports_attention_backend = True
  218. _can_record_outputs = {
  219. "hidden_states": MistralDecoderLayer,
  220. "attentions": MistralAttention,
  221. }
  222. class MistralRotaryEmbedding(nn.Module):
  223. inv_freq: torch.Tensor # fix linting for `register_buffer`
  224. def __init__(self, config: MistralConfig, device=None):
  225. super().__init__()
  226. self.max_seq_len_cached = config.max_position_embeddings
  227. self.original_max_seq_len = config.max_position_embeddings
  228. self.config = config
  229. self.rope_type = self.config.rope_parameters["rope_type"]
  230. rope_init_fn: Callable = self.compute_default_rope_parameters
  231. if self.rope_type != "default":
  232. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  233. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  234. self.register_buffer("inv_freq", inv_freq, persistent=False)
  235. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  236. @staticmethod
  237. def compute_default_rope_parameters(
  238. config: MistralConfig | None = None,
  239. device: Optional["torch.device"] = None,
  240. seq_len: int | None = None,
  241. ) -> tuple["torch.Tensor", float]:
  242. """
  243. Computes the inverse frequencies according to the original RoPE implementation
  244. Args:
  245. config ([`~transformers.PreTrainedConfig`]):
  246. The model configuration.
  247. device (`torch.device`):
  248. The device to use for initialization of the inverse frequencies.
  249. seq_len (`int`, *optional*):
  250. The current sequence length. Unused for this type of RoPE.
  251. Returns:
  252. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  253. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  254. """
  255. base = config.rope_parameters["rope_theta"]
  256. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  257. attention_factor = 1.0 # Unused in this type of RoPE
  258. # Compute the inverse frequencies
  259. inv_freq = 1.0 / (
  260. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  261. )
  262. return inv_freq, attention_factor
  263. @torch.no_grad()
  264. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  265. def forward(self, x, position_ids):
  266. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  267. position_ids_expanded = position_ids[:, None, :].float()
  268. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  269. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  270. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  271. emb = torch.cat((freqs, freqs), dim=-1)
  272. cos = emb.cos() * self.attention_scaling
  273. sin = emb.sin() * self.attention_scaling
  274. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  275. @auto_docstring
  276. class MistralModel(MistralPreTrainedModel):
  277. def __init__(self, config: MistralConfig):
  278. super().__init__(config)
  279. self.padding_idx = config.pad_token_id
  280. self.vocab_size = config.vocab_size
  281. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  282. self.layers = nn.ModuleList(
  283. [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  284. )
  285. self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  286. self.rotary_emb = MistralRotaryEmbedding(config=config)
  287. self.gradient_checkpointing = False
  288. # Initialize weights and apply final processing
  289. self.post_init()
  290. @merge_with_config_defaults
  291. @capture_outputs
  292. @auto_docstring
  293. def forward(
  294. self,
  295. input_ids: torch.LongTensor | None = None,
  296. attention_mask: torch.Tensor | None = None,
  297. position_ids: torch.LongTensor | None = None,
  298. past_key_values: Cache | None = None,
  299. inputs_embeds: torch.FloatTensor | None = None,
  300. use_cache: bool | None = None,
  301. **kwargs: Unpack[TransformersKwargs],
  302. ) -> BaseModelOutputWithPast:
  303. if (input_ids is None) ^ (inputs_embeds is not None):
  304. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  305. if inputs_embeds is None:
  306. inputs_embeds = self.embed_tokens(input_ids)
  307. if use_cache and past_key_values is None:
  308. past_key_values = DynamicCache(config=self.config)
  309. if position_ids is None:
  310. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  311. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  312. position_ids = position_ids.unsqueeze(0)
  313. mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
  314. causal_mask = mask_function(
  315. config=self.config,
  316. inputs_embeds=inputs_embeds,
  317. attention_mask=attention_mask,
  318. past_key_values=past_key_values,
  319. position_ids=position_ids,
  320. )
  321. hidden_states = inputs_embeds
  322. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  323. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  324. hidden_states = decoder_layer(
  325. hidden_states,
  326. attention_mask=causal_mask,
  327. position_ids=position_ids,
  328. past_key_values=past_key_values,
  329. use_cache=use_cache,
  330. position_embeddings=position_embeddings,
  331. **kwargs,
  332. )
  333. hidden_states = self.norm(hidden_states)
  334. return BaseModelOutputWithPast(
  335. last_hidden_state=hidden_states,
  336. past_key_values=past_key_values if use_cache else None,
  337. )
  338. @auto_docstring
  339. class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin):
  340. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  341. _tp_plan = {"lm_head": "colwise_gather_output"}
  342. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  343. def __init__(self, config):
  344. super().__init__(config)
  345. self.model = MistralModel(config)
  346. self.vocab_size = config.vocab_size
  347. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  348. # Initialize weights and apply final processing
  349. self.post_init()
  350. @can_return_tuple
  351. @auto_docstring
  352. def forward(
  353. self,
  354. input_ids: torch.LongTensor | None = None,
  355. attention_mask: torch.Tensor | None = None,
  356. position_ids: torch.LongTensor | None = None,
  357. past_key_values: Cache | None = None,
  358. inputs_embeds: torch.FloatTensor | None = None,
  359. labels: torch.LongTensor | None = None,
  360. use_cache: bool | None = None,
  361. logits_to_keep: int | torch.Tensor = 0,
  362. **kwargs: Unpack[TransformersKwargs],
  363. ) -> CausalLMOutputWithPast:
  364. r"""
  365. Example:
  366. ```python
  367. >>> from transformers import AutoTokenizer, MistralForCausalLM
  368. >>> model = MistralForCausalLM.from_pretrained("meta-mistral/Mistral-2-7b-hf")
  369. >>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral/Mistral-2-7b-hf")
  370. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  371. >>> inputs = tokenizer(prompt, return_tensors="pt")
  372. >>> # Generate
  373. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  374. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  375. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  376. ```"""
  377. outputs: BaseModelOutputWithPast = self.model(
  378. input_ids=input_ids,
  379. attention_mask=attention_mask,
  380. position_ids=position_ids,
  381. past_key_values=past_key_values,
  382. inputs_embeds=inputs_embeds,
  383. use_cache=use_cache,
  384. **kwargs,
  385. )
  386. hidden_states = outputs.last_hidden_state
  387. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  388. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  389. logits = self.lm_head(hidden_states[:, slice_indices, :])
  390. loss = None
  391. if labels is not None:
  392. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  393. return CausalLMOutputWithPast(
  394. loss=loss,
  395. logits=logits,
  396. past_key_values=outputs.past_key_values,
  397. hidden_states=outputs.hidden_states,
  398. attentions=outputs.attentions,
  399. )
  400. class MistralForTokenClassification(GenericForTokenClassification, MistralPreTrainedModel):
  401. pass
  402. class MistralForSequenceClassification(GenericForSequenceClassification, MistralPreTrainedModel):
  403. pass
  404. class MistralForQuestionAnswering(GenericForQuestionAnswering, MistralPreTrainedModel): ...
  405. __all__ = [
  406. "MistralForCausalLM",
  407. "MistralForQuestionAnswering",
  408. "MistralModel",
  409. "MistralPreTrainedModel",
  410. "MistralForSequenceClassification",
  411. "MistralForTokenClassification",
  412. ]