modeling_youtu.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/youtu/modular_youtu.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_youtu.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2026 the Tencent and HuggingFace Inc. team. All rights reserved.
  8. #
  9. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  10. # and OPT implementations in this library. It has been modified from its
  11. # original forms to accommodate minor architectural differences compared
  12. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  13. #
  14. # Licensed under the Apache License, Version 2.0 (the "License");
  15. # you may not use this file except in compliance with the License.
  16. # You may obtain a copy of the License at
  17. #
  18. # http://www.apache.org/licenses/LICENSE-2.0
  19. #
  20. # Unless required by applicable law or agreed to in writing, software
  21. # distributed under the License is distributed on an "AS IS" BASIS,
  22. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  23. # See the License for the specific language governing permissions and
  24. # limitations under the License.
  25. import math
  26. from collections.abc import Callable
  27. from typing import Optional
  28. import torch
  29. import torch.nn.functional as F
  30. from torch import nn
  31. from ... import initialization as init
  32. from ...activations import ACT2FN
  33. from ...cache_utils import Cache, DynamicCache
  34. from ...generation import GenerationMixin
  35. from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
  36. from ...masking_utils import create_causal_mask
  37. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  38. from ...modeling_layers import GradientCheckpointingLayer
  39. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  40. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  41. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  42. from ...processing_utils import Unpack
  43. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  44. from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults
  45. from ...utils.output_capturing import capture_outputs
  46. from .configuration_youtu import YoutuConfig
  47. @use_kernel_forward_from_hub("RMSNorm")
  48. class YoutuRMSNorm(nn.Module):
  49. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  50. """
  51. YoutuRMSNorm is equivalent to T5LayerNorm
  52. """
  53. super().__init__()
  54. self.weight = nn.Parameter(torch.ones(hidden_size))
  55. self.variance_epsilon = eps
  56. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  57. input_dtype = hidden_states.dtype
  58. hidden_states = hidden_states.to(torch.float32)
  59. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  60. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  61. return self.weight * hidden_states.to(input_dtype)
  62. def extra_repr(self):
  63. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  64. class YoutuRotaryEmbedding(nn.Module):
  65. inv_freq: torch.Tensor # fix linting for `register_buffer`
  66. def __init__(self, config: YoutuConfig, device=None):
  67. super().__init__()
  68. self.max_seq_len_cached = config.max_position_embeddings
  69. self.original_max_seq_len = config.max_position_embeddings
  70. self.config = config
  71. self.rope_type = self.config.rope_parameters["rope_type"]
  72. rope_init_fn: Callable = self.compute_default_rope_parameters
  73. if self.rope_type != "default":
  74. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  75. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  76. self.register_buffer("inv_freq", inv_freq, persistent=False)
  77. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  78. @staticmethod
  79. def compute_default_rope_parameters(
  80. config: YoutuConfig | None = None,
  81. device: Optional["torch.device"] = None,
  82. seq_len: int | None = None,
  83. ) -> tuple["torch.Tensor", float]:
  84. """
  85. Computes the inverse frequencies according to the original RoPE implementation
  86. Args:
  87. config ([`~transformers.PreTrainedConfig`]):
  88. The model configuration.
  89. device (`torch.device`):
  90. The device to use for initialization of the inverse frequencies.
  91. seq_len (`int`, *optional*):
  92. The current sequence length. Unused for this type of RoPE.
  93. Returns:
  94. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  95. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  96. """
  97. base = config.rope_parameters["rope_theta"]
  98. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  99. attention_factor = 1.0 # Unused in this type of RoPE
  100. # Compute the inverse frequencies
  101. inv_freq = 1.0 / (
  102. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  103. )
  104. return inv_freq, attention_factor
  105. @torch.no_grad()
  106. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  107. def forward(self, x, position_ids):
  108. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  109. position_ids_expanded = position_ids[:, None, :].float()
  110. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  111. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  112. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  113. emb = torch.cat((freqs, freqs), dim=-1)
  114. cos = emb.cos() * self.attention_scaling
  115. sin = emb.sin() * self.attention_scaling
  116. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  117. class YoutuMLP(nn.Module):
  118. def __init__(self, config):
  119. super().__init__()
  120. self.config = config
  121. self.hidden_size = config.hidden_size
  122. self.intermediate_size = config.intermediate_size
  123. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  124. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  125. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  126. self.act_fn = ACT2FN[config.hidden_act]
  127. def forward(self, x):
  128. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  129. return down_proj
  130. def rotate_half(x):
  131. """Rotates half the hidden dims of the input."""
  132. x1 = x[..., : x.shape[-1] // 2]
  133. x2 = x[..., x.shape[-1] // 2 :]
  134. return torch.cat((-x2, x1), dim=-1)
  135. @use_kernel_func_from_hub("rotary_pos_emb")
  136. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  137. """Applies Rotary Position Embedding to the query and key tensors.
  138. Args:
  139. q (`torch.Tensor`): The query tensor.
  140. k (`torch.Tensor`): The key tensor.
  141. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  142. sin (`torch.Tensor`): The sine part of the rotary embedding.
  143. unsqueeze_dim (`int`, *optional*, defaults to 1):
  144. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  145. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  146. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  147. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  148. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  149. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  150. Returns:
  151. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  152. """
  153. cos = cos.unsqueeze(unsqueeze_dim)
  154. sin = sin.unsqueeze(unsqueeze_dim)
  155. q_embed = (q * cos) + (rotate_half(q) * sin)
  156. k_embed = (k * cos) + (rotate_half(k) * sin)
  157. return q_embed, k_embed
  158. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  159. """
  160. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  161. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  162. """
  163. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  164. if n_rep == 1:
  165. return hidden_states
  166. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  167. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  168. def eager_attention_forward(
  169. module: nn.Module,
  170. query: torch.Tensor,
  171. key: torch.Tensor,
  172. value: torch.Tensor,
  173. attention_mask: torch.Tensor | None,
  174. scaling: float,
  175. dropout: float = 0.0,
  176. **kwargs: Unpack[TransformersKwargs],
  177. ):
  178. key_states = repeat_kv(key, module.num_key_value_groups)
  179. value_states = repeat_kv(value, module.num_key_value_groups)
  180. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  181. if attention_mask is not None:
  182. attn_weights = attn_weights + attention_mask
  183. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  184. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  185. attn_output = torch.matmul(attn_weights, value_states)
  186. attn_output = attn_output.transpose(1, 2).contiguous()
  187. return attn_output, attn_weights
  188. def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  189. r"""
  190. TODO let's just use the original freqcis computation to not have the view
  191. transpose + reshape! This is not optimized!
  192. Applies Rotary Position Embedding to the query and key tensors.
  193. Args:
  194. q (`torch.Tensor`): The query tensor.
  195. k (`torch.Tensor`): The key tensor.
  196. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  197. sin (`torch.Tensor`): The sine part of the rotary embedding.
  198. position_ids (`torch.Tensor`):
  199. The position indices of the tokens corresponding to the query and key tensors. For example, this can be
  200. used to pass offsetted position ids when working with a KV-cache.
  201. unsqueeze_dim (`int`, *optional*, defaults to 1):
  202. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  203. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  204. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  205. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  206. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  207. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  208. Returns:
  209. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  210. """
  211. cos = cos.unsqueeze(unsqueeze_dim)
  212. sin = sin.unsqueeze(unsqueeze_dim)
  213. b, h, s, d = q.shape
  214. q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
  215. b, h, s, d = k.shape
  216. k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
  217. q_embed = (q * cos) + (rotate_half(q) * sin)
  218. k_embed = (k * cos) + (rotate_half(k) * sin)
  219. return q_embed, k_embed
  220. def yarn_get_mscale(scale=1, mscale=1):
  221. if scale <= 1:
  222. return 1.0
  223. return 0.1 * mscale * math.log(scale) + 1.0
class YoutuAttention(nn.Module):
    """Multi-head attention with low-rank compressed projections and a split rotary / non-rotary head dimension.

    Structure, as built in ``__init__``:
      * Queries come from a single ``q_proj`` or, when ``config.q_lora_rank`` is set, from the low-rank path
        ``q_a_proj -> q_a_layernorm -> q_b_proj``.
      * ``kv_a_proj_with_mqa`` produces a ``kv_lora_rank``-wide latent plus one ``qk_rope_head_dim``-wide
        rotary key that ``forward`` broadcasts to every head. The latent is normalized and expanded by
        ``kv_b_proj`` into per-head non-rotary keys and values.
      * Each query/key head is ``qk_nope_head_dim`` non-rotary features followed by ``qk_rope_head_dim``
        rotary features. Value heads are ``v_head_dim`` wide.
    """

    def __init__(self, config: YoutuConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx  # passed to past_key_values.update() to address this layer's cache entry
        # NOTE(review): the keys/values built in forward() already have num_heads heads, so this is presumably 1
        # (num_key_value_heads == num_attention_heads in the config) — confirm; a value > 1 would make
        # eager_attention_forward repeat them again.
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.attention_dropout = config.attention_dropout
        self.num_heads = config.num_attention_heads
        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.qk_head_dim = config.qk_head_dim
        self.is_causal = True

        if self.q_lora_rank is None:
            self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False)
        else:
            # Low-rank query path: compress to q_lora_rank, normalize, expand to all heads.
            self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
            # NOTE(review): uses YoutuRMSNorm's default eps (1e-6), not config.rms_norm_eps — confirm intended.
            self.q_a_layernorm = YoutuRMSNorm(config.q_lora_rank)
            self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)

        # One projection yields the compressed KV latent (kv_lora_rank) plus a single rotary key (qk_rope_head_dim).
        self.kv_a_proj_with_mqa = nn.Linear(
            config.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = YoutuRMSNorm(self.kv_lora_rank)
        # Expands the normalized latent into per-head non-rotary keys and values.
        self.kv_b_proj = nn.Linear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            config.hidden_size,
            bias=config.attention_bias,
        )

        # Softmax scale uses the full query/key head width (non-rotary + rotary parts).
        self.scaling = self.qk_head_dim ** (-0.5)
        # For non-default RoPE types, when rope_parameters["mscale_all_dim"] is set, the scale is multiplied by
        # mscale**2, with mscale = yarn_get_mscale(rope_parameters["factor"], mscale_all_dim).
        if self.config.rope_parameters.get("rope_type", "default") != "default":
            mscale_all_dim = self.config.rope_parameters.get("mscale_all_dim", 0)
            scaling_factor = self.config.rope_parameters["factor"]
            if mscale_all_dim:
                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                self.scaling = self.scaling * mscale * mscale

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Attend over ``hidden_states``, whose leading two dims are unpacked as (batch, seq_len).

        Args:
            hidden_states: Input activations of shape (batch, seq_len, hidden_size).
            position_embeddings: ``(cos, sin)`` tables applied to the rotary slices of queries and keys.
            attention_mask: Passed unchanged to the selected attention implementation.
            past_key_values: Optional cache. When given, this layer's keys/values go through
                ``update(..., self.layer_idx)`` and the returned tensors are attended over.
            **kwargs: Forwarded to the attention implementation.

        Returns:
            ``(attn_output, attn_weights)``. ``attn_output`` is (batch, seq_len, hidden_size) after ``o_proj``.
            ``attn_weights`` is whatever the attention implementation returns; the eager path returns the
            post-softmax, post-dropout weights.
        """
        batch_size, seq_length = hidden_states.shape[:-1]
        query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
        key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)

        if self.q_lora_rank is None:
            q_states = self.q_proj(hidden_states)
        else:
            q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        # -> (batch, heads, seq, qk_head_dim), then split into non-rotary and rotary slices.
        q_states = q_states.view(query_shape).transpose(1, 2)
        q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        # Split into the latent (k_pass) and the single shared rotary key (k_rot).
        k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        # Latent -> (batch, heads, seq, qk_nope_head_dim + v_head_dim) -> non-rotary keys and values.
        k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)
        k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        # The rotary key has a single head: (batch, 1, seq, qk_rope_head_dim).
        k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)

        cos, sin = position_embeddings
        if self.config.rope_interleave:  # support using interleaved weights for efficiency
            q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
        else:
            q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
        # Broadcast the shared rotary key across all heads (expand returns a view, no copy).
        k_rot = k_rot.expand(*k_pass.shape[:-1], -1)

        query_states = torch.cat((q_pass, q_rot), dim=-1)
        key_states = torch.cat((k_pass, k_rot), dim=-1)

        if past_key_values is not None:
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

        # Flash-attention path: values are zero-padded up to qk_head_dim when widths differ (presumably because
        # the kernel needs equal q/k/v head dims — confirm). The padding is sliced off the output below.
        if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim:
            value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim:
            attn_output = attn_output[:, :, :, : self.v_head_dim]

        # Merge heads back into the feature dimension and project to hidden_size.
        attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
  321. class YoutuDecoderLayer(GradientCheckpointingLayer):
  322. def __init__(self, config: YoutuConfig, layer_idx: int):
  323. super().__init__()
  324. self.hidden_size = config.hidden_size
  325. self.self_attn = YoutuAttention(config=config, layer_idx=layer_idx)
  326. self.mlp = YoutuMLP(config)
  327. self.input_layernorm = YoutuRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  328. self.post_attention_layernorm = YoutuRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  329. def forward(
  330. self,
  331. hidden_states: torch.Tensor,
  332. attention_mask: torch.Tensor | None = None,
  333. position_ids: torch.LongTensor | None = None,
  334. past_key_values: Cache | None = None,
  335. use_cache: bool | None = False,
  336. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  337. **kwargs: Unpack[TransformersKwargs],
  338. ) -> torch.Tensor:
  339. residual = hidden_states
  340. hidden_states = self.input_layernorm(hidden_states)
  341. # Self Attention
  342. hidden_states, _ = self.self_attn(
  343. hidden_states=hidden_states,
  344. attention_mask=attention_mask,
  345. position_ids=position_ids,
  346. past_key_values=past_key_values,
  347. use_cache=use_cache,
  348. position_embeddings=position_embeddings,
  349. **kwargs,
  350. )
  351. hidden_states = residual + hidden_states
  352. # Fully Connected
  353. residual = hidden_states
  354. hidden_states = self.post_attention_layernorm(hidden_states)
  355. hidden_states = self.mlp(hidden_states)
  356. hidden_states = residual + hidden_states
  357. return hidden_states
  358. @auto_docstring
  359. class YoutuPreTrainedModel(PreTrainedModel):
  360. config: YoutuConfig
  361. base_model_prefix = "model"
  362. supports_gradient_checkpointing = True
  363. _no_split_modules = ["YoutuDecoderLayer"]
  364. _skip_keys_device_placement = ["past_key_values"]
  365. _supports_flash_attn = True
  366. _supports_sdpa = True
  367. _supports_flex_attn = True
  368. _can_compile_fullgraph = True
  369. _supports_attention_backend = True
  370. _can_record_outputs = {
  371. "hidden_states": YoutuDecoderLayer,
  372. "attentions": YoutuAttention,
  373. }
  374. @torch.no_grad()
  375. def _init_weights(self, module):
  376. super()._init_weights(module)
  377. std = getattr(self.config, "initializer_range", 0.02)
  378. embed_std = getattr(self.config, "embedding_initializer_range", 2 * std)
  379. if isinstance(module, nn.Embedding):
  380. init.normal_(module.weight, mean=0.0, std=embed_std)
  381. if module.padding_idx is not None:
  382. init.zeros_(module.weight.data[module.padding_idx])
@auto_docstring
class YoutuModel(YoutuPreTrainedModel):
    # Decoder-only backbone: token embeddings -> `num_hidden_layers` YoutuDecoderLayer blocks -> final RMSNorm.
    # The rotary cos/sin tables are computed once per forward pass and shared by every layer.

    def __init__(self, config: YoutuConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # padding_idx is passed as nn.Embedding's third positional argument.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [YoutuDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = YoutuRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = YoutuRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        # Exactly one input source is allowed: the XOR is True when both are None or both are provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

        # Create a fresh cache when caching is requested but none was supplied.
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if position_ids is None:
            # Default positions continue after the tokens already in the cache; result has shape (1, seq_len).
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
            position_ids = position_ids.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        # cos/sin are computed once here and reused by every decoder layer.
        position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_embeddings=position_embeddings,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        # Only the final hidden state and the cache are set here. Per-layer hidden states / attentions are
        # presumably filled in by the @capture_outputs decorator — confirm.
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )
@auto_docstring
class YoutuForCausalLM(YoutuPreTrainedModel, GenerationMixin):
    # Causal language-model head: YoutuModel backbone followed by a bias-free linear projection to the vocabulary.
    # lm_head.weight is declared tied to the input embedding (whether tying is applied presumably depends on the
    # config — confirm).
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
    # Parallelism hints for lm_head (presumably tensor-parallel / pipeline-parallel plans, per HF naming).
    _tp_plan = {"lm_head": "colwise_gather_output"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = YoutuModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        Example:

        ```python
        >>> from transformers import AutoTokenizer, YoutuForCausalLM

        >>> model = YoutuForCausalLM.from_pretrained("meta-youtu/Youtu-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-youtu/Youtu-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # NOTE(review): the checkpoint id in the docstring example ("meta-youtu/Youtu-2-7b-hf") looks like a
        # template placeholder — confirm it exists on the Hub.
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only project the positions that are needed. An int N keeps the last N positions (0 keeps all of them,
        # since slice(-0, None) == slice(0, None)); a tensor is used directly as the sequence index. No float
        # upcast is done here.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


# Public names exported by this module.
__all__ = ["YoutuPreTrainedModel", "YoutuModel", "YoutuForCausalLM"]