modeling_qwen3.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/qwen3/modular_qwen3.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_qwen3.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. from collections.abc import Callable
  21. from typing import Optional
  22. import torch
  23. from torch import nn
  24. from ...activations import ACT2FN
  25. from ...cache_utils import Cache, DynamicCache
  26. from ...generation import GenerationMixin
  27. from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
  28. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  29. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  30. from ...modeling_layers import (
  31. GenericForQuestionAnswering,
  32. GenericForSequenceClassification,
  33. GenericForTokenClassification,
  34. GradientCheckpointingLayer,
  35. )
  36. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  37. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  38. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  39. from ...processing_utils import Unpack
  40. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  41. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  42. from ...utils.output_capturing import capture_outputs
  43. from .configuration_qwen3 import Qwen3Config
  44. @use_kernel_forward_from_hub("RMSNorm")
  45. class Qwen3RMSNorm(nn.Module):
  46. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  47. """
  48. Qwen3RMSNorm is equivalent to T5LayerNorm
  49. """
  50. super().__init__()
  51. self.weight = nn.Parameter(torch.ones(hidden_size))
  52. self.variance_epsilon = eps
  53. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  54. input_dtype = hidden_states.dtype
  55. hidden_states = hidden_states.to(torch.float32)
  56. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  57. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  58. return self.weight * hidden_states.to(input_dtype)
  59. def extra_repr(self):
  60. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  61. class Qwen3MLP(nn.Module):
  62. def __init__(self, config):
  63. super().__init__()
  64. self.config = config
  65. self.hidden_size = config.hidden_size
  66. self.intermediate_size = config.intermediate_size
  67. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  68. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  69. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  70. self.act_fn = ACT2FN[config.hidden_act]
  71. def forward(self, x):
  72. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  73. return down_proj
class Qwen3RotaryEmbedding(nn.Module):
    """Builds the cos/sin tables used for rotary position embeddings (RoPE).

    The inverse frequencies come from ``compute_default_rope_parameters`` when
    ``config.rope_parameters["rope_type"] == "default"``, otherwise from ``ROPE_INIT_FUNCTIONS[rope_type]``.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: Qwen3Config, device=None):
        super().__init__()
        # Both start at config.max_position_embeddings. Presumably `dynamic_rope_update` reads/updates them for
        # dynamic RoPE variants (its implementation is not shown here) — TODO confirm.
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.rope_type = self.config.rope_parameters["rope_type"]
        rope_init_fn: Callable = self.compute_default_rope_parameters
        if self.rope_type != "default":
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        # rope_init_fn returns (inverse frequencies, scaling applied to cos/sin in forward()).
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
        # persistent=False: the buffers move with the module (.to / .cuda) but are not saved in the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # An untouched copy of the initial frequencies, presumably so dynamic RoPE can restore them — TODO confirm.
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: Qwen3Config | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation:
        ``inv_freq[i] = 1 / rope_theta ** (2 * i / dim)`` for ``i`` in ``range(dim // 2)`` (rounded up for odd ``dim``).

        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration. Required in practice despite the ``None`` default:
                ``rope_parameters["rope_theta"]`` is read unconditionally. ``dim`` is ``config.head_dim`` when it
                is set and truthy, otherwise ``hidden_size // num_attention_heads``.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.

        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        base = config.rope_parameters["rope_theta"]
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        attention_factor = 1.0  # Unused in this type of RoPE
        # Compute the inverse frequencies (the exponents are built as int64 then cast to float, on `device`)
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
        )
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for ``position_ids``, both cast to ``x.dtype``.

        ``x`` is used only for its device and dtype. ``position_ids`` must be rank 2, assumed ``(batch, seq_len)``,
        because of the ``[:, None, :]`` indexing. The outputs have shape ``(batch, seq_len, 2 * len(inv_freq))``:
        the frequencies are concatenated with themselves and multiplied by ``self.attention_scaling``.
        """
        # (batch, dim/2, 1) frequencies times (batch, 1, seq_len) positions -> (batch, dim/2, seq_len).
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()
        # Autocast is disabled below; "mps" (or a non-string device type) falls back to "cpu" as the
        # autocast device type.
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  127. def rotate_half(x):
  128. """Rotates half the hidden dims of the input."""
  129. x1 = x[..., : x.shape[-1] // 2]
  130. x2 = x[..., x.shape[-1] // 2 :]
  131. return torch.cat((-x2, x1), dim=-1)
  132. @use_kernel_func_from_hub("rotary_pos_emb")
  133. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  134. """Applies Rotary Position Embedding to the query and key tensors.
  135. Args:
  136. q (`torch.Tensor`): The query tensor.
  137. k (`torch.Tensor`): The key tensor.
  138. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  139. sin (`torch.Tensor`): The sine part of the rotary embedding.
  140. unsqueeze_dim (`int`, *optional*, defaults to 1):
  141. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  142. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  143. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  144. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  145. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  146. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  147. Returns:
  148. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  149. """
  150. cos = cos.unsqueeze(unsqueeze_dim)
  151. sin = sin.unsqueeze(unsqueeze_dim)
  152. q_embed = (q * cos) + (rotate_half(q) * sin)
  153. k_embed = (k * cos) + (rotate_half(k) * sin)
  154. return q_embed, k_embed
  155. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  156. """
  157. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  158. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  159. """
  160. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  161. if n_rep == 1:
  162. return hidden_states
  163. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  164. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  165. def eager_attention_forward(
  166. module: nn.Module,
  167. query: torch.Tensor,
  168. key: torch.Tensor,
  169. value: torch.Tensor,
  170. attention_mask: torch.Tensor | None,
  171. scaling: float,
  172. dropout: float = 0.0,
  173. **kwargs: Unpack[TransformersKwargs],
  174. ):
  175. key_states = repeat_kv(key, module.num_key_value_groups)
  176. value_states = repeat_kv(value, module.num_key_value_groups)
  177. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  178. if attention_mask is not None:
  179. attn_weights = attn_weights + attention_mask
  180. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  181. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  182. attn_output = torch.matmul(attn_weights, value_states)
  183. attn_output = attn_output.transpose(1, 2).contiguous()
  184. return attn_output, attn_weights
  185. @use_kernelized_func(apply_rotary_pos_emb)
  186. class Qwen3Attention(nn.Module):
  187. """Multi-headed attention from 'Attention Is All You Need' paper"""
  188. def __init__(self, config: Qwen3Config, layer_idx: int):
  189. super().__init__()
  190. self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
  191. self.config = config
  192. self.layer_idx = layer_idx
  193. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  194. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  195. self.scaling = self.head_dim**-0.5
  196. self.attention_dropout = config.attention_dropout
  197. self.is_causal = True
  198. self.q_proj = nn.Linear(
  199. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  200. )
  201. self.k_proj = nn.Linear(
  202. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  203. )
  204. self.v_proj = nn.Linear(
  205. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  206. )
  207. self.o_proj = nn.Linear(
  208. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  209. )
  210. self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim!
  211. self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape
  212. self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
  213. def forward(
  214. self,
  215. hidden_states: torch.Tensor,
  216. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  217. attention_mask: torch.Tensor | None,
  218. past_key_values: Cache | None = None,
  219. **kwargs: Unpack[FlashAttentionKwargs],
  220. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  221. input_shape = hidden_states.shape[:-1]
  222. hidden_shape = (*input_shape, -1, self.head_dim)
  223. query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
  224. key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
  225. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  226. cos, sin = position_embeddings
  227. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  228. if past_key_values is not None:
  229. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  230. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  231. self.config._attn_implementation, eager_attention_forward
  232. )
  233. attn_output, attn_weights = attention_interface(
  234. self,
  235. query_states,
  236. key_states,
  237. value_states,
  238. attention_mask,
  239. dropout=0.0 if not self.training else self.attention_dropout,
  240. scaling=self.scaling,
  241. sliding_window=self.sliding_window, # diff with Llama
  242. **kwargs,
  243. )
  244. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  245. attn_output = self.o_proj(attn_output)
  246. return attn_output, attn_weights
  247. class Qwen3DecoderLayer(GradientCheckpointingLayer):
  248. def __init__(self, config: Qwen3Config, layer_idx: int):
  249. super().__init__()
  250. self.hidden_size = config.hidden_size
  251. self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx)
  252. self.mlp = Qwen3MLP(config)
  253. self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  254. self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  255. def forward(
  256. self,
  257. hidden_states: torch.Tensor,
  258. attention_mask: torch.Tensor | None = None,
  259. position_ids: torch.LongTensor | None = None,
  260. past_key_values: Cache | None = None,
  261. use_cache: bool | None = False,
  262. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  263. **kwargs: Unpack[TransformersKwargs],
  264. ) -> torch.Tensor:
  265. residual = hidden_states
  266. hidden_states = self.input_layernorm(hidden_states)
  267. # Self Attention
  268. hidden_states, _ = self.self_attn(
  269. hidden_states=hidden_states,
  270. attention_mask=attention_mask,
  271. position_ids=position_ids,
  272. past_key_values=past_key_values,
  273. use_cache=use_cache,
  274. position_embeddings=position_embeddings,
  275. **kwargs,
  276. )
  277. hidden_states = residual + hidden_states
  278. # Fully Connected
  279. residual = hidden_states
  280. hidden_states = self.post_attention_layernorm(hidden_states)
  281. hidden_states = self.mlp(hidden_states)
  282. hidden_states = residual + hidden_states
  283. return hidden_states
  284. @auto_docstring
  285. class Qwen3PreTrainedModel(PreTrainedModel):
  286. config: Qwen3Config
  287. base_model_prefix = "model"
  288. supports_gradient_checkpointing = True
  289. _no_split_modules = ["Qwen3DecoderLayer"]
  290. _skip_keys_device_placement = ["past_key_values"]
  291. _supports_flash_attn = True
  292. _supports_sdpa = True
  293. _supports_flex_attn = True
  294. _can_compile_fullgraph = True
  295. _supports_attention_backend = True
  296. _can_record_outputs = {
  297. "hidden_states": Qwen3DecoderLayer,
  298. "attentions": Qwen3Attention,
  299. }
@auto_docstring
class Qwen3Model(Qwen3PreTrainedModel):
    # Decoder-only backbone: token embedding -> `num_hidden_layers` Qwen3DecoderLayer blocks -> final RMSNorm.
    # The rotary cos/sin tables are computed once per forward call and passed to every layer.
    def __init__(self, config: Qwen3Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Qwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen3RotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        # forward() only builds the sliding-window mask when at least one layer uses it.
        self.has_sliding_layers = "sliding_attention" in self.config.layer_types
        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        # Exactly one of `input_ids` / `inputs_embeds` must be given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Create a cache on demand when caching is requested but none was supplied.
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        # Default positions continue after the tokens already stored in the cache; shape (1, seq_len).
        if position_ids is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
            position_ids = position_ids.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        # (a dict passed as `attention_mask` is used as-is as the layer-type -> mask mapping).
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "inputs_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            # Create the masks
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
            }
            # The sliding window alternating layers are not always activated depending on the config
            if self.has_sliding_layers:
                causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # Each layer receives the mask matching its entry in config.layer_types.
        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                position_embeddings=position_embeddings,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        # The cache is only returned when caching was requested.
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )
  373. @auto_docstring
  374. class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin):
  375. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  376. _tp_plan = {"lm_head": "colwise_gather_output"}
  377. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  378. def __init__(self, config):
  379. super().__init__(config)
  380. self.model = Qwen3Model(config)
  381. self.vocab_size = config.vocab_size
  382. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  383. # Initialize weights and apply final processing
  384. self.post_init()
  385. @can_return_tuple
  386. @auto_docstring
  387. def forward(
  388. self,
  389. input_ids: torch.LongTensor | None = None,
  390. attention_mask: torch.Tensor | None = None,
  391. position_ids: torch.LongTensor | None = None,
  392. past_key_values: Cache | None = None,
  393. inputs_embeds: torch.FloatTensor | None = None,
  394. labels: torch.LongTensor | None = None,
  395. use_cache: bool | None = None,
  396. logits_to_keep: int | torch.Tensor = 0,
  397. **kwargs: Unpack[TransformersKwargs],
  398. ) -> CausalLMOutputWithPast:
  399. r"""
  400. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  401. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  402. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  403. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  404. Example:
  405. ```python
  406. >>> from transformers import AutoTokenizer, Qwen3ForCausalLM
  407. >>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B")
  408. >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
  409. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  410. >>> inputs = tokenizer(prompt, return_tensors="pt")
  411. >>> # Generate
  412. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  413. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  414. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  415. ```"""
  416. outputs: BaseModelOutputWithPast = self.model(
  417. input_ids=input_ids,
  418. attention_mask=attention_mask,
  419. position_ids=position_ids,
  420. past_key_values=past_key_values,
  421. inputs_embeds=inputs_embeds,
  422. use_cache=use_cache,
  423. **kwargs,
  424. )
  425. hidden_states = outputs.last_hidden_state
  426. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  427. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  428. logits = self.lm_head(hidden_states[:, slice_indices, :])
  429. loss = None
  430. if labels is not None:
  431. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  432. return CausalLMOutputWithPast(
  433. loss=loss,
  434. logits=logits,
  435. past_key_values=outputs.past_key_values,
  436. hidden_states=outputs.hidden_states,
  437. attentions=outputs.attentions,
  438. )
  439. class Qwen3ForSequenceClassification(GenericForSequenceClassification, Qwen3PreTrainedModel):
  440. pass
  441. class Qwen3ForTokenClassification(GenericForTokenClassification, Qwen3PreTrainedModel):
  442. pass
class Qwen3ForQuestionAnswering(GenericForQuestionAnswering, Qwen3PreTrainedModel):
    # Overrides the "model" prefix inherited from Qwen3PreTrainedModel.
    base_model_prefix = "transformer"  # For BC, where `transformer` was used instead of `model`
# Public API of this module (names exported by `from ... import *`).
__all__ = [
    "Qwen3ForCausalLM",
    "Qwen3ForQuestionAnswering",
    "Qwen3PreTrainedModel",
    "Qwen3Model",
    "Qwen3ForSequenceClassification",
    "Qwen3ForTokenClassification",
]