modular_qwen3.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. # Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch Qwen3 model."""
  15. from collections.abc import Callable
  16. import torch
  17. from ...cache_utils import Cache
  18. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  19. from ...modeling_outputs import CausalLMOutputWithPast
  20. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  21. from ...processing_utils import Unpack
  22. from ...utils import TransformersKwargs, logging
  23. from ..gemma.modeling_gemma import GemmaMLP
  24. from ..llama.modeling_llama import (
  25. LlamaAttention,
  26. )
  27. from ..qwen2.modeling_qwen2 import (
  28. Qwen2ForCausalLM,
  29. Qwen2ForQuestionAnswering,
  30. Qwen2ForSequenceClassification,
  31. Qwen2ForTokenClassification,
  32. Qwen2RMSNorm,
  33. Qwen2RotaryEmbedding,
  34. apply_rotary_pos_emb,
  35. eager_attention_forward,
  36. )
  37. from .configuration_qwen3 import Qwen3Config
  # Hub id of the reference checkpoint. NOTE(review): not referenced by the code visible
  # in this module — presumably consumed by documentation tooling / the generated file.
  _CHECKPOINT_FOR_DOC = "Qwen/Qwen3-8B"
  # Module-level logger built with the package's own `logging` helpers (imported from `...utils`).
  logger = logging.get_logger(__name__)
  class Qwen3RMSNorm(Qwen2RMSNorm):
      """Qwen3 RMS normalization layer; inherits `Qwen2RMSNorm` without any overrides."""
  class Qwen3MLP(GemmaMLP):
      """Qwen3 feed-forward (MLP) block; inherits `GemmaMLP` without any overrides."""
  class Qwen3RotaryEmbedding(Qwen2RotaryEmbedding):
      """Qwen3 rotary-embedding module; inherits `Qwen2RotaryEmbedding` without any overrides."""
  class Qwen3Attention(LlamaAttention):
      """Qwen3 self-attention layer, built on `LlamaAttention`.

      Differences from the parent that are visible in this class:
      - queries and keys get a per-head RMSNorm (`q_norm` / `k_norm`, sized by `head_dim`)
        before rotary position embeddings are applied;
      - a `sliding_window` value is forwarded to the attention kernel, set only for layers whose
        `config.layer_types` entry is `"sliding_attention"` (otherwise `None`).
      """

      def __init__(self, config: Qwen3Config, layer_idx: int):
          """Build the attention layer for decoder layer `layer_idx`.

          Args:
              config: Model configuration. This method reads `layer_types` (if the attribute exists),
                  `rms_norm_eps` and `sliding_window`; the remaining fields are passed to
                  `LlamaAttention.__init__` (assumed to set `head_dim`, the projections, `scaling`, etc.).
              layer_idx: Index of this layer. Selects the entry of `config.layer_types` and is passed to
                  `LlamaAttention.__init__`; `forward` later uses `self.layer_idx` when updating the cache.
          """
          # Resolved before super().__init__(). NOTE(review): the modular converter inlines the parent
          # __init__ at the super() call, so keep this ordering unless the generated file is re-checked.
          # NOTE(review): if `layer_types` exists but is None this indexing raises TypeError —
          # presumably the config always populates it; confirm.
          self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
          super().__init__(config, layer_idx)
          self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
          self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)  # thus post q_norm does not need reshape
          # Only "sliding_attention" layers carry a window; None presumably means "no window limit".
          self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None

      def forward(
          self,
          hidden_states: torch.Tensor,
          position_embeddings: tuple[torch.Tensor, torch.Tensor],
          attention_mask: torch.Tensor | None,
          past_key_values: Cache | None = None,
          **kwargs: Unpack[FlashAttentionKwargs],
      ) -> tuple[torch.Tensor, torch.Tensor | None]:
          """Run attention for this layer.

          Args:
              hidden_states: Input activations; the last dim is fed to `q_proj` / `k_proj` / `v_proj`.
                  Assumed to be `(batch, seq_len, hidden_size)` — TODO confirm.
              position_embeddings: `(cos, sin)` pair passed to `apply_rotary_pos_emb`.
              attention_mask: Forwarded unchanged to the attention interface (may be None).
              past_key_values: Optional cache. When given, this layer's keys/values are passed to
                  `past_key_values.update(...)` and the keys/values it returns are attended over.
              **kwargs: Extra keyword arguments forwarded to the attention interface.

          Returns:
              `(attn_output, attn_weights)`: the `o_proj` output, whose leading dims match
              `hidden_states`, and whatever weights the attention interface returned (may be None).
          """
          input_shape = hidden_states.shape[:-1]
          # Split the projected last dim into (heads, head_dim); the head count is inferred via -1.
          hidden_shape = (*input_shape, -1, self.head_dim)
          # q/k RMSNorm runs on the (..., heads, head_dim) view, i.e. before transpose(1, 2) moves the head axis.
          query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
          key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
          value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)  # values are not normalized
          cos, sin = position_embeddings
          # Rotary embeddings are applied after the q/k norms.
          query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
          if past_key_values is not None:
              # The cache hands back the keys/values to attend over (presumably past + current tokens).
              key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
          # Kernel is selected by config._attn_implementation; eager_attention_forward is passed as the
          # presumed default when no other implementation applies.
          attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
              self.config._attn_implementation, eager_attention_forward
          )
          attn_output, attn_weights = attention_interface(
              self,
              query_states,
              key_states,
              value_states,
              attention_mask,
              dropout=0.0 if not self.training else self.attention_dropout,  # dropout only in training mode
              scaling=self.scaling,
              sliding_window=self.sliding_window,  # diff with Llama
              **kwargs,
          )
          # Flatten heads back into the last dim. NOTE(review): assumes the interface returns heads adjacent
          # to head_dim (e.g. (batch, seq, heads, head_dim)) — confirm per backend.
          attn_output = attn_output.reshape(*input_shape, -1).contiguous()
          attn_output = self.o_proj(attn_output)
          return attn_output, attn_weights
  # Causal-LM model for Qwen3. All computation is delegated to `Qwen2ForCausalLM`; the `forward`
  # override below only re-dispatches to the parent and appears to exist to carry a
  # Qwen3-specific docstring/example.
  class Qwen3ForCausalLM(Qwen2ForCausalLM):
      def forward(
          self,
          **super_kwargs: Unpack[TransformersKwargs],
      ) -> CausalLMOutputWithPast:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
              config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
              (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

          Example:

          ```python
          >>> from transformers import AutoTokenizer, Qwen3ForCausalLM

          >>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B")
          >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")

          >>> prompt = "Hey, are you conscious? Can you talk to me?"
          >>> inputs = tokenizer(prompt, return_tensors="pt")

          >>> # Generate
          >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
          >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
          "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
          ```"""
          # NOTE(review): `**super_kwargs` looks like the modular-converter convention for "reuse the parent's
          # full signature"; keep the exact parameter name and this call shape — confirm before renaming.
          return super().forward(**super_kwargs)
  class Qwen3ForSequenceClassification(Qwen2ForSequenceClassification):
      """Qwen3 sequence-classification model; inherits `Qwen2ForSequenceClassification` without overrides."""
  class Qwen3ForTokenClassification(Qwen2ForTokenClassification):
      """Qwen3 token-classification model; inherits `Qwen2ForTokenClassification` without overrides."""
  class Qwen3ForQuestionAnswering(Qwen2ForQuestionAnswering):
      """Qwen3 question-answering model; inherits `Qwen2ForQuestionAnswering` without overrides."""
  # Public API of this module. `Qwen3PreTrainedModel` and `Qwen3Model` are not defined in this file;
  # NOTE(review): presumably they are produced from their Qwen2 parents by the modular converter in the
  # generated modeling file. `noqa: F822` silences the "undefined name in __all__" lint for those entries.
  __all__ = [
      "Qwen3ForCausalLM",
      "Qwen3ForQuestionAnswering",
      "Qwen3PreTrainedModel",  # noqa: F822
      "Qwen3Model",  # noqa: F822
      "Qwen3ForSequenceClassification",
      "Qwen3ForTokenClassification",
  ]