modeling_qwen2.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/qwen2/modular_qwen2.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_qwen2.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. from collections.abc import Callable
  8. from typing import Optional
  9. import torch
  10. from torch import nn
  11. from ...activations import ACT2FN
  12. from ...cache_utils import Cache, DynamicCache
  13. from ...generation import GenerationMixin
  14. from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
  15. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  16. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  17. from ...modeling_layers import (
  18. GenericForQuestionAnswering,
  19. GenericForSequenceClassification,
  20. GenericForTokenClassification,
  21. GradientCheckpointingLayer,
  22. )
  23. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  24. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  25. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  26. from ...processing_utils import Unpack
  27. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  28. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  29. from ...utils.output_capturing import capture_outputs
  30. from .configuration_qwen2 import Qwen2Config
  31. class Qwen2MLP(nn.Module):
  32. def __init__(self, config):
  33. super().__init__()
  34. self.config = config
  35. self.hidden_size = config.hidden_size
  36. self.intermediate_size = config.intermediate_size
  37. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  38. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  39. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  40. self.act_fn = ACT2FN[config.hidden_act]
  41. def forward(self, x):
  42. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  43. return down_proj
  44. class Qwen2RotaryEmbedding(nn.Module):
  45. inv_freq: torch.Tensor # fix linting for `register_buffer`
  46. def __init__(self, config: Qwen2Config, device=None):
  47. super().__init__()
  48. self.max_seq_len_cached = config.max_position_embeddings
  49. self.original_max_seq_len = config.max_position_embeddings
  50. self.config = config
  51. self.rope_type = self.config.rope_parameters["rope_type"]
  52. rope_init_fn: Callable = self.compute_default_rope_parameters
  53. if self.rope_type != "default":
  54. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  55. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  56. self.register_buffer("inv_freq", inv_freq, persistent=False)
  57. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  58. @staticmethod
  59. def compute_default_rope_parameters(
  60. config: Qwen2Config | None = None,
  61. device: Optional["torch.device"] = None,
  62. seq_len: int | None = None,
  63. ) -> tuple["torch.Tensor", float]:
  64. """
  65. Computes the inverse frequencies according to the original RoPE implementation
  66. Args:
  67. config ([`~transformers.PreTrainedConfig`]):
  68. The model configuration.
  69. device (`torch.device`):
  70. The device to use for initialization of the inverse frequencies.
  71. seq_len (`int`, *optional*):
  72. The current sequence length. Unused for this type of RoPE.
  73. Returns:
  74. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  75. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  76. """
  77. base = config.rope_parameters["rope_theta"]
  78. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  79. attention_factor = 1.0 # Unused in this type of RoPE
  80. # Compute the inverse frequencies
  81. inv_freq = 1.0 / (
  82. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  83. )
  84. return inv_freq, attention_factor
  85. @torch.no_grad()
  86. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  87. def forward(self, x, position_ids):
  88. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  89. position_ids_expanded = position_ids[:, None, :].float()
  90. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  91. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  92. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  93. emb = torch.cat((freqs, freqs), dim=-1)
  94. cos = emb.cos() * self.attention_scaling
  95. sin = emb.sin() * self.attention_scaling
  96. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  97. def rotate_half(x):
  98. """Rotates half the hidden dims of the input."""
  99. x1 = x[..., : x.shape[-1] // 2]
  100. x2 = x[..., x.shape[-1] // 2 :]
  101. return torch.cat((-x2, x1), dim=-1)
  102. @use_kernel_func_from_hub("rotary_pos_emb")
  103. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  104. """Applies Rotary Position Embedding to the query and key tensors.
  105. Args:
  106. q (`torch.Tensor`): The query tensor.
  107. k (`torch.Tensor`): The key tensor.
  108. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  109. sin (`torch.Tensor`): The sine part of the rotary embedding.
  110. unsqueeze_dim (`int`, *optional*, defaults to 1):
  111. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  112. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  113. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  114. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  115. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  116. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  117. Returns:
  118. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  119. """
  120. cos = cos.unsqueeze(unsqueeze_dim)
  121. sin = sin.unsqueeze(unsqueeze_dim)
  122. q_embed = (q * cos) + (rotate_half(q) * sin)
  123. k_embed = (k * cos) + (rotate_half(k) * sin)
  124. return q_embed, k_embed
  125. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  126. """
  127. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  128. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  129. """
  130. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  131. if n_rep == 1:
  132. return hidden_states
  133. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  134. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  135. def eager_attention_forward(
  136. module: nn.Module,
  137. query: torch.Tensor,
  138. key: torch.Tensor,
  139. value: torch.Tensor,
  140. attention_mask: torch.Tensor | None,
  141. scaling: float,
  142. dropout: float = 0.0,
  143. **kwargs: Unpack[TransformersKwargs],
  144. ):
  145. key_states = repeat_kv(key, module.num_key_value_groups)
  146. value_states = repeat_kv(value, module.num_key_value_groups)
  147. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  148. if attention_mask is not None:
  149. attn_weights = attn_weights + attention_mask
  150. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  151. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  152. attn_output = torch.matmul(attn_weights, value_states)
  153. attn_output = attn_output.transpose(1, 2).contiguous()
  154. return attn_output, attn_weights
  155. @use_kernelized_func(apply_rotary_pos_emb)
  156. class Qwen2Attention(nn.Module):
  157. """Multi-headed attention from 'Attention Is All You Need' paper"""
  158. def __init__(self, config: Qwen2Config, layer_idx: int):
  159. super().__init__()
  160. self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
  161. self.config = config
  162. self.layer_idx = layer_idx
  163. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  164. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  165. self.scaling = self.head_dim**-0.5
  166. self.attention_dropout = config.attention_dropout
  167. self.is_causal = True
  168. self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True)
  169. self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
  170. self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
  171. self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
  172. self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
  173. def forward(
  174. self,
  175. hidden_states: torch.Tensor,
  176. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  177. attention_mask: torch.Tensor | None,
  178. past_key_values: Cache | None = None,
  179. **kwargs: Unpack[FlashAttentionKwargs],
  180. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  181. input_shape = hidden_states.shape[:-1]
  182. hidden_shape = (*input_shape, -1, self.head_dim)
  183. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  184. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  185. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  186. cos, sin = position_embeddings
  187. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  188. if past_key_values is not None:
  189. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  190. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  191. self.config._attn_implementation, eager_attention_forward
  192. )
  193. attn_output, attn_weights = attention_interface(
  194. self,
  195. query_states,
  196. key_states,
  197. value_states,
  198. attention_mask,
  199. dropout=0.0 if not self.training else self.attention_dropout,
  200. scaling=self.scaling,
  201. sliding_window=self.sliding_window, # main diff with Llama
  202. **kwargs,
  203. )
  204. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  205. attn_output = self.o_proj(attn_output)
  206. return attn_output, attn_weights
  207. @use_kernel_forward_from_hub("RMSNorm")
  208. class Qwen2RMSNorm(nn.Module):
  209. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  210. """
  211. Qwen2RMSNorm is equivalent to T5LayerNorm
  212. """
  213. super().__init__()
  214. self.weight = nn.Parameter(torch.ones(hidden_size))
  215. self.variance_epsilon = eps
  216. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  217. input_dtype = hidden_states.dtype
  218. hidden_states = hidden_states.to(torch.float32)
  219. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  220. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  221. return self.weight * hidden_states.to(input_dtype)
  222. def extra_repr(self):
  223. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  224. class Qwen2DecoderLayer(GradientCheckpointingLayer):
  225. def __init__(self, config: Qwen2Config, layer_idx: int):
  226. super().__init__()
  227. self.hidden_size = config.hidden_size
  228. self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx)
  229. self.mlp = Qwen2MLP(config)
  230. self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  231. self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  232. def forward(
  233. self,
  234. hidden_states: torch.Tensor,
  235. attention_mask: torch.Tensor | None = None,
  236. position_ids: torch.LongTensor | None = None,
  237. past_key_values: Cache | None = None,
  238. use_cache: bool | None = False,
  239. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  240. **kwargs: Unpack[TransformersKwargs],
  241. ) -> torch.Tensor:
  242. residual = hidden_states
  243. hidden_states = self.input_layernorm(hidden_states)
  244. # Self Attention
  245. hidden_states, _ = self.self_attn(
  246. hidden_states=hidden_states,
  247. attention_mask=attention_mask,
  248. position_ids=position_ids,
  249. past_key_values=past_key_values,
  250. use_cache=use_cache,
  251. position_embeddings=position_embeddings,
  252. **kwargs,
  253. )
  254. hidden_states = residual + hidden_states
  255. # Fully Connected
  256. residual = hidden_states
  257. hidden_states = self.post_attention_layernorm(hidden_states)
  258. hidden_states = self.mlp(hidden_states)
  259. hidden_states = residual + hidden_states
  260. return hidden_states
  261. @auto_docstring
  262. class Qwen2PreTrainedModel(PreTrainedModel):
  263. config: Qwen2Config
  264. base_model_prefix = "model"
  265. supports_gradient_checkpointing = True
  266. _no_split_modules = ["Qwen2DecoderLayer"]
  267. _skip_keys_device_placement = ["past_key_values"]
  268. _supports_flash_attn = True
  269. _supports_sdpa = True
  270. _supports_flex_attn = True
  271. _can_compile_fullgraph = True
  272. _supports_attention_backend = True
  273. _can_record_outputs = {
  274. "hidden_states": Qwen2DecoderLayer,
  275. "attentions": Qwen2Attention,
  276. }
  277. @auto_docstring
  278. class Qwen2Model(Qwen2PreTrainedModel):
  279. def __init__(self, config: Qwen2Config):
  280. super().__init__(config)
  281. self.padding_idx = config.pad_token_id
  282. self.vocab_size = config.vocab_size
  283. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  284. self.layers = nn.ModuleList(
  285. [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  286. )
  287. self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  288. self.rotary_emb = Qwen2RotaryEmbedding(config=config)
  289. self.gradient_checkpointing = False
  290. self.has_sliding_layers = "sliding_attention" in self.config.layer_types
  291. # Initialize weights and apply final processing
  292. self.post_init()
  293. @merge_with_config_defaults
  294. @capture_outputs
  295. @auto_docstring
  296. def forward(
  297. self,
  298. input_ids: torch.LongTensor | None = None,
  299. attention_mask: torch.Tensor | None = None,
  300. position_ids: torch.LongTensor | None = None,
  301. past_key_values: Cache | None = None,
  302. inputs_embeds: torch.FloatTensor | None = None,
  303. use_cache: bool | None = None,
  304. **kwargs: Unpack[TransformersKwargs],
  305. ) -> BaseModelOutputWithPast:
  306. if (input_ids is None) ^ (inputs_embeds is not None):
  307. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  308. if inputs_embeds is None:
  309. inputs_embeds = self.embed_tokens(input_ids)
  310. if use_cache and past_key_values is None:
  311. past_key_values = DynamicCache(config=self.config)
  312. if position_ids is None:
  313. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  314. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  315. position_ids = position_ids.unsqueeze(0)
  316. # It may already have been prepared by e.g. `generate`
  317. if not isinstance(causal_mask_mapping := attention_mask, dict):
  318. # Prepare mask arguments
  319. mask_kwargs = {
  320. "config": self.config,
  321. "inputs_embeds": inputs_embeds,
  322. "attention_mask": attention_mask,
  323. "past_key_values": past_key_values,
  324. "position_ids": position_ids,
  325. }
  326. # Create the masks
  327. causal_mask_mapping = {
  328. "full_attention": create_causal_mask(**mask_kwargs),
  329. }
  330. # The sliding window alternating layers are not always activated depending on the config
  331. if self.has_sliding_layers:
  332. causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
  333. hidden_states = inputs_embeds
  334. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  335. for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
  336. hidden_states = decoder_layer(
  337. hidden_states,
  338. attention_mask=causal_mask_mapping[self.config.layer_types[i]],
  339. position_embeddings=position_embeddings,
  340. position_ids=position_ids,
  341. past_key_values=past_key_values,
  342. use_cache=use_cache,
  343. **kwargs,
  344. )
  345. hidden_states = self.norm(hidden_states)
  346. return BaseModelOutputWithPast(
  347. last_hidden_state=hidden_states,
  348. past_key_values=past_key_values if use_cache else None,
  349. )
  350. @auto_docstring
  351. class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
  352. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  353. _tp_plan = {"lm_head": "colwise_gather_output"}
  354. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  355. def __init__(self, config):
  356. super().__init__(config)
  357. self.model = Qwen2Model(config)
  358. self.vocab_size = config.vocab_size
  359. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  360. # Initialize weights and apply final processing
  361. self.post_init()
  362. @can_return_tuple
  363. @auto_docstring
  364. def forward(
  365. self,
  366. input_ids: torch.LongTensor | None = None,
  367. attention_mask: torch.Tensor | None = None,
  368. position_ids: torch.LongTensor | None = None,
  369. past_key_values: Cache | None = None,
  370. inputs_embeds: torch.FloatTensor | None = None,
  371. labels: torch.LongTensor | None = None,
  372. use_cache: bool | None = None,
  373. logits_to_keep: int | torch.Tensor = 0,
  374. **kwargs: Unpack[TransformersKwargs],
  375. ) -> CausalLMOutputWithPast:
  376. r"""
  377. Example:
  378. ```python
  379. >>> from transformers import AutoTokenizer, Qwen2ForCausalLM
  380. >>> model = Qwen2ForCausalLM.from_pretrained("meta-qwen2/Qwen2-2-7b-hf")
  381. >>> tokenizer = AutoTokenizer.from_pretrained("meta-qwen2/Qwen2-2-7b-hf")
  382. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  383. >>> inputs = tokenizer(prompt, return_tensors="pt")
  384. >>> # Generate
  385. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  386. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  387. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  388. ```"""
  389. outputs: BaseModelOutputWithPast = self.model(
  390. input_ids=input_ids,
  391. attention_mask=attention_mask,
  392. position_ids=position_ids,
  393. past_key_values=past_key_values,
  394. inputs_embeds=inputs_embeds,
  395. use_cache=use_cache,
  396. **kwargs,
  397. )
  398. hidden_states = outputs.last_hidden_state
  399. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  400. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  401. logits = self.lm_head(hidden_states[:, slice_indices, :])
  402. loss = None
  403. if labels is not None:
  404. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  405. return CausalLMOutputWithPast(
  406. loss=loss,
  407. logits=logits,
  408. past_key_values=outputs.past_key_values,
  409. hidden_states=outputs.hidden_states,
  410. attentions=outputs.attentions,
  411. )
  412. class Qwen2ForSequenceClassification(GenericForSequenceClassification, Qwen2PreTrainedModel):
  413. pass
  414. class Qwen2ForTokenClassification(GenericForTokenClassification, Qwen2PreTrainedModel):
  415. pass
  416. class Qwen2ForQuestionAnswering(GenericForQuestionAnswering, Qwen2PreTrainedModel):
  417. base_model_prefix = "transformer" # For BC, where `transformer` was used instead of `model`
  418. __all__ = [
  419. "Qwen2PreTrainedModel",
  420. "Qwen2Model",
  421. "Qwen2ForCausalLM",
  422. "Qwen2RMSNorm",
  423. "Qwen2ForSequenceClassification",
  424. "Qwen2ForTokenClassification",
  425. "Qwen2ForQuestionAnswering",
  426. ]