| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188 |
- from collections.abc import Callable
- import torch
- from torch import nn
- from ...cache_utils import Cache, DynamicCache
- from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
- from ...modeling_flash_attention_utils import FlashAttentionKwargs
- from ...modeling_layers import (
- GenericForQuestionAnswering,
- )
- from ...modeling_outputs import BaseModelOutputWithPast
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
- from ...processing_utils import Unpack
- from ...utils import TransformersKwargs, auto_docstring, logging
- from ...utils.generic import merge_with_config_defaults
- from ...utils.output_capturing import capture_outputs
- from ..llama.modeling_llama import (
- LlamaAttention,
- LlamaDecoderLayer,
- LlamaForCausalLM,
- LlamaForSequenceClassification,
- LlamaForTokenClassification,
- LlamaMLP,
- LlamaModel,
- LlamaPreTrainedModel,
- apply_rotary_pos_emb,
- eager_attention_forward,
- )
- from .configuration_mistral import MistralConfig
# Module-level logger for this file, obtained from the library's `logging.get_logger` helper.
logger = logging.get_logger(__name__)
class MistralMLP(LlamaMLP):
    """Feed-forward block that reuses ``LlamaMLP`` but rebuilds its three projections without bias terms."""

    def __init__(self, config):
        super().__init__(config)
        # Re-create gate/up/down projections with bias disabled (same creation order as before).
        model_dim, ffn_dim = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(model_dim, ffn_dim, bias=False)
        self.up_proj = nn.Linear(model_dim, ffn_dim, bias=False)
        self.down_proj = nn.Linear(ffn_dim, model_dim, bias=False)
class MistralAttention(LlamaAttention):
    """Attention layer derived from ``LlamaAttention``.

    Differences visible here: the per-head size may come from ``config.head_dim``,
    all q/k/v/o projections are bias-free, and the config's ``sliding_window``
    (if any) is forwarded to the attention kernel.
    """

    def __init__(self, config: MistralConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        # Use an explicit (truthy) `head_dim` from the config, else derive it from the hidden size.
        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        query_width = config.num_attention_heads * self.head_dim
        kv_width = config.num_key_value_heads * self.head_dim
        self.q_proj = nn.Linear(config.hidden_size, query_width, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, kv_width, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, kv_width, bias=False)
        self.o_proj = nn.Linear(query_width, config.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run attention over ``hidden_states``.

        Args:
            hidden_states: input activations; presumably ``(batch, seq_len, hidden_size)`` —
                TODO confirm (the code splits the last dim into heads and swaps dims 1 and 2).
            position_embeddings: ``(cos, sin)`` pair handed to ``apply_rotary_pos_emb``.
            attention_mask: passed through unchanged to the attention implementation.
            past_key_values: optional cache; when given, this layer's keys/values are fed to
                ``update`` and the states it returns are used for attention.
            **kwargs: forwarded to the attention implementation.

        Returns:
            ``(attn_output, attn_weights)``; the output has been projected by ``o_proj`` and the
            weights are whatever the selected attention implementation returns.
        """
        leading_dims = hidden_states.shape[:-1]
        split_heads_shape = (*leading_dims, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(split_heads_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(split_heads_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(split_heads_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

        # Kernel is chosen by `config._attn_implementation`, falling back to the eager version.
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        # Dropout is applied only in training mode.
        dropout_p = self.attention_dropout if self.training else 0.0
        window = getattr(self.config, "sliding_window", None)  # main diff with Llama

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=dropout_p,
            scaling=self.scaling,
            sliding_window=window,
            **kwargs,
        )
        # Merge the heads back into one feature dimension before the output projection.
        merged = attn_output.reshape(*leading_dims, -1).contiguous()
        return self.o_proj(merged), attn_weights
class MistralDecoderLayer(LlamaDecoderLayer):
    """``LlamaDecoderLayer`` whose ``self_attn`` and ``mlp`` submodules are replaced by the Mistral variants."""

    def __init__(self, config: MistralConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        # Attention is built before the MLP, matching the original construction order.
        self.self_attn = MistralAttention(config, layer_idx)
        self.mlp = MistralMLP(config=config)
class MistralPreTrainedModel(LlamaPreTrainedModel):
    # Output name -> module class; presumably consumed by the `capture_outputs` machinery to
    # record per-layer hidden states and attention outputs — TODO confirm.
    _can_record_outputs = {"hidden_states": MistralDecoderLayer, "attentions": MistralAttention}
class MistralModel(LlamaModel):
    # Same decoder stack as LlamaModel; the forward below picks a sliding-window causal mask
    # whenever `config.sliding_window` is set.

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        # Exactly one input source is allowed: reject both-given and neither-given.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        inputs_embeds = self.embed_tokens(input_ids) if inputs_embeds is None else inputs_embeds

        # Lazily create a cache when caching is requested but none was supplied.
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        # Default positions continue from however many tokens the cache already holds.
        if position_ids is None:
            already_seen = 0 if past_key_values is None else past_key_values.get_seq_length()
            step_positions = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
            position_ids = (step_positions + already_seen).unsqueeze(0)

        if self.config.sliding_window is None:
            build_mask = create_causal_mask
        else:
            build_mask = create_sliding_window_causal_mask
        causal_mask = build_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        # Rotary embeddings are computed once and shared by every decoder layer.
        position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)

        for layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        # The cache is only surfaced when `use_cache` is truthy.
        output_cache = past_key_values if use_cache else None
        return BaseModelOutputWithPast(
            last_hidden_state=self.norm(hidden_states),
            past_key_values=output_cache,
        )
class MistralForCausalLM(LlamaForCausalLM):
    # No overrides here; everything is inherited from LlamaForCausalLM.
    ...
class MistralForTokenClassification(LlamaForTokenClassification):
    # No overrides here; everything is inherited from LlamaForTokenClassification.
    ...
class MistralForSequenceClassification(LlamaForSequenceClassification):
    # No overrides here; everything is inherited from LlamaForSequenceClassification.
    ...
class MistralForQuestionAnswering(GenericForQuestionAnswering, MistralPreTrainedModel):
    # Combines GenericForQuestionAnswering with MistralPreTrainedModel; adds no members of its own.
    pass
# Public API of this module: the names exported by `from <module> import *`.
__all__ = [
    "MistralForCausalLM",
    "MistralForQuestionAnswering",
    "MistralModel",
    "MistralPreTrainedModel",
    "MistralForSequenceClassification",
    "MistralForTokenClassification",
]
|