modular_qwen2.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. from collections.abc import Callable
  2. import torch
  3. from torch import nn
  4. from ...cache_utils import Cache, DynamicCache
  5. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  6. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  7. from ...modeling_outputs import (
  8. BaseModelOutputWithPast,
  9. )
  10. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  11. from ...processing_utils import Unpack
  12. from ...utils import TransformersKwargs, auto_docstring, logging
  13. from ...utils.generic import merge_with_config_defaults
  14. from ...utils.output_capturing import capture_outputs
  15. from ..gemma2.modeling_gemma2 import Gemma2RotaryEmbedding
  16. from ..llama.modeling_llama import (
  17. LlamaAttention,
  18. LlamaDecoderLayer,
  19. LlamaForCausalLM,
  20. LlamaForQuestionAnswering,
  21. LlamaForSequenceClassification,
  22. LlamaForTokenClassification,
  23. LlamaMLP,
  24. LlamaPreTrainedModel,
  25. apply_rotary_pos_emb,
  26. eager_attention_forward,
  27. )
  28. from ..mistral.modeling_mistral import MistralModel
  29. from .configuration_qwen2 import Qwen2Config
  30. logger = logging.get_logger(__name__)
  31. class Qwen2MLP(LlamaMLP):
  32. def __init__(self, config):
  33. super().__init__(config)
  34. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  35. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  36. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  37. class Qwen2RotaryEmbedding(Gemma2RotaryEmbedding):
  38. pass
  39. class Qwen2Attention(LlamaAttention):
  40. def __init__(self, config: Qwen2Config, layer_idx: int):
  41. self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
  42. super().__init__(config, layer_idx)
  43. self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True)
  44. self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
  45. self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
  46. self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
  47. self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
  48. def forward(
  49. self,
  50. hidden_states: torch.Tensor,
  51. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  52. attention_mask: torch.Tensor | None,
  53. past_key_values: Cache | None = None,
  54. **kwargs: Unpack[FlashAttentionKwargs],
  55. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  56. input_shape = hidden_states.shape[:-1]
  57. hidden_shape = (*input_shape, -1, self.head_dim)
  58. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  59. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  60. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  61. cos, sin = position_embeddings
  62. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  63. if past_key_values is not None:
  64. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  65. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  66. self.config._attn_implementation, eager_attention_forward
  67. )
  68. attn_output, attn_weights = attention_interface(
  69. self,
  70. query_states,
  71. key_states,
  72. value_states,
  73. attention_mask,
  74. dropout=0.0 if not self.training else self.attention_dropout,
  75. scaling=self.scaling,
  76. sliding_window=self.sliding_window, # main diff with Llama
  77. **kwargs,
  78. )
  79. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  80. attn_output = self.o_proj(attn_output)
  81. return attn_output, attn_weights
  82. class Qwen2DecoderLayer(LlamaDecoderLayer):
  83. pass
  84. class Qwen2PreTrainedModel(LlamaPreTrainedModel):
  85. pass
  86. class Qwen2Model(MistralModel):
  87. def __init__(self, config: Qwen2Config):
  88. super().__init__(config)
  89. self.has_sliding_layers = "sliding_attention" in self.config.layer_types
  90. @merge_with_config_defaults
  91. @capture_outputs
  92. @auto_docstring
  93. def forward(
  94. self,
  95. input_ids: torch.LongTensor | None = None,
  96. attention_mask: torch.Tensor | None = None,
  97. position_ids: torch.LongTensor | None = None,
  98. past_key_values: Cache | None = None,
  99. inputs_embeds: torch.FloatTensor | None = None,
  100. use_cache: bool | None = None,
  101. **kwargs: Unpack[TransformersKwargs],
  102. ) -> BaseModelOutputWithPast:
  103. if (input_ids is None) ^ (inputs_embeds is not None):
  104. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  105. if inputs_embeds is None:
  106. inputs_embeds = self.embed_tokens(input_ids)
  107. if use_cache and past_key_values is None:
  108. past_key_values = DynamicCache(config=self.config)
  109. if position_ids is None:
  110. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  111. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  112. position_ids = position_ids.unsqueeze(0)
  113. # It may already have been prepared by e.g. `generate`
  114. if not isinstance(causal_mask_mapping := attention_mask, dict):
  115. # Prepare mask arguments
  116. mask_kwargs = {
  117. "config": self.config,
  118. "inputs_embeds": inputs_embeds,
  119. "attention_mask": attention_mask,
  120. "past_key_values": past_key_values,
  121. "position_ids": position_ids,
  122. }
  123. # Create the masks
  124. causal_mask_mapping = {
  125. "full_attention": create_causal_mask(**mask_kwargs),
  126. }
  127. # The sliding window alternating layers are not always activated depending on the config
  128. if self.has_sliding_layers:
  129. causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
  130. hidden_states = inputs_embeds
  131. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  132. for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
  133. hidden_states = decoder_layer(
  134. hidden_states,
  135. attention_mask=causal_mask_mapping[self.config.layer_types[i]],
  136. position_embeddings=position_embeddings,
  137. position_ids=position_ids,
  138. past_key_values=past_key_values,
  139. use_cache=use_cache,
  140. **kwargs,
  141. )
  142. hidden_states = self.norm(hidden_states)
  143. return BaseModelOutputWithPast(
  144. last_hidden_state=hidden_states,
  145. past_key_values=past_key_values if use_cache else None,
  146. )
  147. class Qwen2ForCausalLM(LlamaForCausalLM):
  148. pass
  149. class Qwen2ForSequenceClassification(LlamaForSequenceClassification):
  150. pass
  151. class Qwen2ForTokenClassification(LlamaForTokenClassification):
  152. pass
  153. class Qwen2ForQuestionAnswering(LlamaForQuestionAnswering):
  154. pass
  155. __all__ = [
  156. "Qwen2PreTrainedModel",
  157. "Qwen2Model",
  158. "Qwen2ForCausalLM",
  159. "Qwen2RMSNorm", # noqa: F822
  160. "Qwen2ForSequenceClassification",
  161. "Qwen2ForTokenClassification",
  162. "Qwen2ForQuestionAnswering",
  163. ]