# modular_mistral.py (6.9 KB)

  1. from collections.abc import Callable
  2. import torch
  3. from torch import nn
  4. from ...cache_utils import Cache, DynamicCache
  5. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  6. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  7. from ...modeling_layers import (
  8. GenericForQuestionAnswering,
  9. )
  10. from ...modeling_outputs import BaseModelOutputWithPast
  11. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  12. from ...processing_utils import Unpack
  13. from ...utils import TransformersKwargs, auto_docstring, logging
  14. from ...utils.generic import merge_with_config_defaults
  15. from ...utils.output_capturing import capture_outputs
  16. from ..llama.modeling_llama import (
  17. LlamaAttention,
  18. LlamaDecoderLayer,
  19. LlamaForCausalLM,
  20. LlamaForSequenceClassification,
  21. LlamaForTokenClassification,
  22. LlamaMLP,
  23. LlamaModel,
  24. LlamaPreTrainedModel,
  25. apply_rotary_pos_emb,
  26. eager_attention_forward,
  27. )
  28. from .configuration_mistral import MistralConfig
# Module-level logger from the library's ``logging`` utilities, named after this module.
# NOTE(review): not referenced by any code visible in this file — presumably kept for
# parity with the generated modeling file; confirm before removing.
logger = logging.get_logger(__name__)
  30. class MistralMLP(LlamaMLP):
  31. def __init__(self, config):
  32. super().__init__(config)
  33. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  34. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  35. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  36. class MistralAttention(LlamaAttention):
  37. def __init__(self, config: MistralConfig, layer_idx: int):
  38. super().__init__(config, layer_idx)
  39. self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  40. self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
  41. self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
  42. self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
  43. self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
  44. def forward(
  45. self,
  46. hidden_states: torch.Tensor,
  47. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  48. attention_mask: torch.Tensor | None,
  49. past_key_values: Cache | None = None,
  50. **kwargs: Unpack[FlashAttentionKwargs],
  51. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  52. input_shape = hidden_states.shape[:-1]
  53. hidden_shape = (*input_shape, -1, self.head_dim)
  54. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  55. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  56. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  57. cos, sin = position_embeddings
  58. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  59. if past_key_values is not None:
  60. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  61. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  62. self.config._attn_implementation, eager_attention_forward
  63. )
  64. attn_output, attn_weights = attention_interface(
  65. self,
  66. query_states,
  67. key_states,
  68. value_states,
  69. attention_mask,
  70. dropout=0.0 if not self.training else self.attention_dropout,
  71. scaling=self.scaling,
  72. sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama
  73. **kwargs,
  74. )
  75. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  76. attn_output = self.o_proj(attn_output)
  77. return attn_output, attn_weights
  78. class MistralDecoderLayer(LlamaDecoderLayer):
  79. def __init__(self, config: MistralConfig, layer_idx: int):
  80. super().__init__(config, layer_idx)
  81. self.self_attn = MistralAttention(config=config, layer_idx=layer_idx)
  82. self.mlp = MistralMLP(config)
class MistralPreTrainedModel(LlamaPreTrainedModel):
    # Base class for Mistral models; relative to ``LlamaPreTrainedModel`` it only overrides
    # ``_can_record_outputs``.
    # ``_can_record_outputs`` maps an output name to the module class associated with it.
    # NOTE(review): presumably consumed by the ``capture_outputs`` decorator applied to
    # ``MistralModel.forward`` to collect per-layer hidden states / attentions — confirm.
    # Key order is kept as-is in case consumers iterate the mapping.
    _can_record_outputs = {
        "hidden_states": MistralDecoderLayer,
        "attentions": MistralAttention,
    }
  88. class MistralModel(LlamaModel):
  89. @merge_with_config_defaults
  90. @capture_outputs
  91. @auto_docstring
  92. def forward(
  93. self,
  94. input_ids: torch.LongTensor | None = None,
  95. attention_mask: torch.Tensor | None = None,
  96. position_ids: torch.LongTensor | None = None,
  97. past_key_values: Cache | None = None,
  98. inputs_embeds: torch.FloatTensor | None = None,
  99. use_cache: bool | None = None,
  100. **kwargs: Unpack[TransformersKwargs],
  101. ) -> BaseModelOutputWithPast:
  102. if (input_ids is None) ^ (inputs_embeds is not None):
  103. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  104. if inputs_embeds is None:
  105. inputs_embeds = self.embed_tokens(input_ids)
  106. if use_cache and past_key_values is None:
  107. past_key_values = DynamicCache(config=self.config)
  108. if position_ids is None:
  109. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  110. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  111. position_ids = position_ids.unsqueeze(0)
  112. mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
  113. causal_mask = mask_function(
  114. config=self.config,
  115. inputs_embeds=inputs_embeds,
  116. attention_mask=attention_mask,
  117. past_key_values=past_key_values,
  118. position_ids=position_ids,
  119. )
  120. hidden_states = inputs_embeds
  121. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  122. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  123. hidden_states = decoder_layer(
  124. hidden_states,
  125. attention_mask=causal_mask,
  126. position_ids=position_ids,
  127. past_key_values=past_key_values,
  128. use_cache=use_cache,
  129. position_embeddings=position_embeddings,
  130. **kwargs,
  131. )
  132. hidden_states = self.norm(hidden_states)
  133. return BaseModelOutputWithPast(
  134. last_hidden_state=hidden_states,
  135. past_key_values=past_key_values if use_cache else None,
  136. )
class MistralForCausalLM(LlamaForCausalLM):
    # Adds nothing of its own: all behavior is inherited from ``LlamaForCausalLM``.
    # NOTE(review): in a modular_* file an empty body presumably tells the code generator to
    # copy the parent class under the Mistral name — confirm before adding members.
    pass
class MistralForTokenClassification(LlamaForTokenClassification):
    # Adds nothing of its own: all behavior is inherited from ``LlamaForTokenClassification``.
    # NOTE(review): in a modular_* file an empty body presumably tells the code generator to
    # copy the parent class under the Mistral name — confirm before adding members.
    pass
class MistralForSequenceClassification(LlamaForSequenceClassification):
    # Adds nothing of its own: all behavior is inherited from ``LlamaForSequenceClassification``.
    # NOTE(review): in a modular_* file an empty body presumably tells the code generator to
    # copy the parent class under the Mistral name — confirm before adding members.
    pass
class MistralForQuestionAnswering(GenericForQuestionAnswering, MistralPreTrainedModel):
    # Adds no members of its own; it combines ``GenericForQuestionAnswering`` and
    # ``MistralPreTrainedModel`` via multiple inheritance, with ``GenericForQuestionAnswering``
    # first in the method resolution order.
    ...
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
    "MistralForCausalLM",
    "MistralForQuestionAnswering",
    "MistralModel",
    "MistralPreTrainedModel",
    "MistralForSequenceClassification",
    "MistralForTokenClassification",
]