modular_ministral3.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. from collections.abc import Callable
  2. import torch
  3. from ...cache_utils import Cache
  4. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  5. from ...modeling_layers import (
  6. GenericForQuestionAnswering,
  7. GenericForSequenceClassification,
  8. GenericForTokenClassification,
  9. )
  10. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  11. from ...processing_utils import Unpack
  12. from ...utils import auto_docstring, logging
  13. from ..mistral.modeling_mistral import (
  14. MistralAttention,
  15. MistralDecoderLayer,
  16. MistralForCausalLM,
  17. MistralModel,
  18. MistralPreTrainedModel,
  19. apply_rotary_pos_emb,
  20. eager_attention_forward,
  21. )
# Module-level logger, created through the project's logging utilities and named after this module.
logger = logging.get_logger(__name__)
  23. def get_llama_4_attn_scale(positions_ids: torch.Tensor, beta: float, max_position_embeddings: int) -> torch.Tensor:
  24. scaling = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings))
  25. return scaling[:, None, :, None]
  26. class Ministral3Attention(MistralAttention):
  27. def forward(
  28. self,
  29. hidden_states: torch.Tensor,
  30. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  31. attention_mask: torch.Tensor | None,
  32. position_ids: torch.Tensor,
  33. past_key_values: Cache | None = None,
  34. **kwargs: Unpack[FlashAttentionKwargs],
  35. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  36. input_shape = hidden_states.shape[:-1]
  37. hidden_shape = (*input_shape, -1, self.head_dim)
  38. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  39. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  40. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  41. cos, sin = position_embeddings
  42. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  43. query_states = query_states * get_llama_4_attn_scale(
  44. position_ids,
  45. self.config.rope_parameters.get("llama_4_scaling_beta"),
  46. self.config.rope_parameters.get("original_max_position_embeddings"),
  47. ).to(query_states.dtype)
  48. if past_key_values is not None:
  49. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  50. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  51. self.config._attn_implementation, eager_attention_forward
  52. )
  53. attn_output, attn_weights = attention_interface(
  54. self,
  55. query_states,
  56. key_states,
  57. value_states,
  58. attention_mask,
  59. dropout=0.0 if not self.training else self.attention_dropout,
  60. scaling=self.scaling,
  61. sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama
  62. **kwargs,
  63. )
  64. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  65. attn_output = self.o_proj(attn_output)
  66. return attn_output, attn_weights
class Ministral3DecoderLayer(MistralDecoderLayer):
    # Reuses MistralDecoderLayer unchanged; no Ministral3-specific overrides are defined here.
    pass
@auto_docstring
class Ministral3PreTrainedModel(MistralPreTrainedModel):
    # Inherits all behavior from MistralPreTrainedModel unchanged. It is also the base paired
    # with the Generic* task heads declared in this module.
    pass
@auto_docstring
class Ministral3Model(MistralModel):
    # Inherits all behavior from MistralModel unchanged; no overrides are defined here.
    pass
@auto_docstring
class Ministral3ForCausalLM(MistralForCausalLM):
    # Inherits all behavior from MistralForCausalLM unchanged; no overrides are defined here.
    pass
class Ministral3ForTokenClassification(GenericForTokenClassification, Ministral3PreTrainedModel):
    # Token-classification head built by combining the GenericForTokenClassification mixin
    # with Ministral3PreTrainedModel. The mixin is listed first, so it precedes the
    # pretrained-model base in the MRO.
    pass
class Ministral3ForSequenceClassification(GenericForSequenceClassification, Ministral3PreTrainedModel):
    # Sequence-classification head built by combining the GenericForSequenceClassification
    # mixin with Ministral3PreTrainedModel. The mixin is listed first, so it precedes the
    # pretrained-model base in the MRO.
    pass
class Ministral3ForQuestionAnswering(GenericForQuestionAnswering, Ministral3PreTrainedModel):
    # Question-answering head built by combining the GenericForQuestionAnswering mixin with
    # Ministral3PreTrainedModel. The mixin is listed first, so it precedes the
    # pretrained-model base in the MRO.
    pass
# Public API of this module. Ministral3Attention, Ministral3DecoderLayer and
# get_llama_4_attn_scale are intentionally not exported.
__all__ = [
    "Ministral3ForCausalLM",
    "Ministral3ForQuestionAnswering",
    "Ministral3Model",
    "Ministral3PreTrainedModel",
    "Ministral3ForSequenceClassification",
    "Ministral3ForTokenClassification",
]