modeling_jais2.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/jais2/modular_jais2.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_jais2.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 the HuggingFace Team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. from collections.abc import Callable
  21. from typing import Optional
  22. import torch
  23. import torch.nn as nn
  24. from ...activations import ACT2FN
  25. from ...cache_utils import Cache, DynamicCache
  26. from ...generation import GenerationMixin
  27. from ...integrations import use_kernel_func_from_hub, use_kernelized_func
  28. from ...masking_utils import create_causal_mask
  29. from ...modeling_layers import GradientCheckpointingLayer
  30. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  31. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  32. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  33. from ...processing_utils import Unpack
  34. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  35. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  36. from ...utils.output_capturing import capture_outputs
  37. from .configuration_jais2 import Jais2Config
  38. class Jais2MLP(nn.Module):
  39. def __init__(self, config):
  40. super().__init__()
  41. self.config = config
  42. self.hidden_size = config.hidden_size
  43. self.intermediate_size = config.intermediate_size
  44. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  45. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
  46. self.act_fn = ACT2FN[config.hidden_act]
  47. def forward(self, x):
  48. return self.down_proj(self.act_fn(self.up_proj(x)))
  49. def rotate_half(x):
  50. """Rotates half the hidden dims of the input."""
  51. x1 = x[..., : x.shape[-1] // 2]
  52. x2 = x[..., x.shape[-1] // 2 :]
  53. return torch.cat((-x2, x1), dim=-1)
  54. @use_kernel_func_from_hub("rotary_pos_emb")
  55. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  56. """Applies Rotary Position Embedding to the query and key tensors.
  57. Args:
  58. q (`torch.Tensor`): The query tensor.
  59. k (`torch.Tensor`): The key tensor.
  60. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  61. sin (`torch.Tensor`): The sine part of the rotary embedding.
  62. unsqueeze_dim (`int`, *optional*, defaults to 1):
  63. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  64. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  65. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  66. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  67. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  68. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  69. Returns:
  70. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  71. """
  72. cos = cos.unsqueeze(unsqueeze_dim)
  73. sin = sin.unsqueeze(unsqueeze_dim)
  74. q_embed = (q * cos) + (rotate_half(q) * sin)
  75. k_embed = (k * cos) + (rotate_half(k) * sin)
  76. return q_embed, k_embed
  77. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  78. """
  79. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  80. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  81. """
  82. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  83. if n_rep == 1:
  84. return hidden_states
  85. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  86. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  87. def eager_attention_forward(
  88. module: nn.Module,
  89. query: torch.Tensor,
  90. key: torch.Tensor,
  91. value: torch.Tensor,
  92. attention_mask: torch.Tensor | None,
  93. scaling: float,
  94. dropout: float = 0.0,
  95. **kwargs: Unpack[TransformersKwargs],
  96. ):
  97. key_states = repeat_kv(key, module.num_key_value_groups)
  98. value_states = repeat_kv(value, module.num_key_value_groups)
  99. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  100. if attention_mask is not None:
  101. attn_weights = attn_weights + attention_mask
  102. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  103. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  104. attn_output = torch.matmul(attn_weights, value_states)
  105. attn_output = attn_output.transpose(1, 2).contiguous()
  106. return attn_output, attn_weights
  107. @use_kernelized_func(apply_rotary_pos_emb)
  108. class Jais2Attention(nn.Module):
  109. """Multi-headed attention from 'Attention Is All You Need' paper"""
  110. def __init__(self, config: Jais2Config, layer_idx: int):
  111. super().__init__()
  112. self.config = config
  113. self.layer_idx = layer_idx
  114. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  115. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  116. self.scaling = self.head_dim**-0.5
  117. self.attention_dropout = config.attention_dropout
  118. self.is_causal = True
  119. self.q_proj = nn.Linear(
  120. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  121. )
  122. self.k_proj = nn.Linear(
  123. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  124. )
  125. self.v_proj = nn.Linear(
  126. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  127. )
  128. self.o_proj = nn.Linear(
  129. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  130. )
  131. def forward(
  132. self,
  133. hidden_states: torch.Tensor,
  134. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  135. attention_mask: torch.Tensor | None = None,
  136. past_key_values: Cache | None = None,
  137. **kwargs: Unpack[TransformersKwargs],
  138. ) -> tuple[torch.Tensor, torch.Tensor]:
  139. input_shape = hidden_states.shape[:-1]
  140. hidden_shape = (*input_shape, -1, self.head_dim)
  141. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  142. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  143. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  144. cos, sin = position_embeddings
  145. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  146. if past_key_values is not None:
  147. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  148. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  149. self.config._attn_implementation, eager_attention_forward
  150. )
  151. attn_output, attn_weights = attention_interface(
  152. self,
  153. query_states,
  154. key_states,
  155. value_states,
  156. attention_mask,
  157. dropout=0.0 if not self.training else self.attention_dropout,
  158. scaling=self.scaling,
  159. **kwargs,
  160. )
  161. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  162. attn_output = self.o_proj(attn_output)
  163. return attn_output, attn_weights
  164. class Jais2DecoderLayer(GradientCheckpointingLayer):
  165. def __init__(self, config: Jais2Config, layer_idx: int):
  166. super().__init__()
  167. self.hidden_size = config.hidden_size
  168. self.self_attn = Jais2Attention(config=config, layer_idx=layer_idx)
  169. self.mlp = Jais2MLP(config)
  170. self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  171. self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  172. def forward(
  173. self,
  174. hidden_states: torch.Tensor,
  175. attention_mask: torch.Tensor | None = None,
  176. position_ids: torch.LongTensor | None = None,
  177. past_key_values: Cache | None = None,
  178. use_cache: bool | None = False,
  179. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  180. **kwargs: Unpack[TransformersKwargs],
  181. ) -> torch.Tensor:
  182. residual = hidden_states
  183. hidden_states = self.input_layernorm(hidden_states)
  184. # Self Attention
  185. hidden_states, _ = self.self_attn(
  186. hidden_states=hidden_states,
  187. attention_mask=attention_mask,
  188. position_ids=position_ids,
  189. past_key_values=past_key_values,
  190. use_cache=use_cache,
  191. position_embeddings=position_embeddings,
  192. **kwargs,
  193. )
  194. hidden_states = residual + hidden_states
  195. # Fully Connected
  196. residual = hidden_states
  197. hidden_states = self.post_attention_layernorm(hidden_states)
  198. hidden_states = self.mlp(hidden_states)
  199. hidden_states = residual + hidden_states
  200. return hidden_states
  201. @auto_docstring
  202. class Jais2PreTrainedModel(PreTrainedModel):
  203. config: Jais2Config
  204. base_model_prefix = "model"
  205. supports_gradient_checkpointing = True
  206. _no_split_modules = ["Jais2DecoderLayer"]
  207. _skip_keys_device_placement = ["past_key_values"]
  208. _supports_flash_attn = True
  209. _supports_sdpa = True
  210. _supports_flex_attn = True
  211. _can_compile_fullgraph = True
  212. _supports_attention_backend = True
  213. _can_record_outputs = {
  214. "hidden_states": Jais2DecoderLayer,
  215. "attentions": Jais2Attention,
  216. }
  217. class Jais2RotaryEmbedding(nn.Module):
  218. inv_freq: torch.Tensor # fix linting for `register_buffer`
  219. def __init__(self, config: Jais2Config, device=None):
  220. super().__init__()
  221. self.max_seq_len_cached = config.max_position_embeddings
  222. self.original_max_seq_len = config.max_position_embeddings
  223. self.config = config
  224. self.rope_type = self.config.rope_parameters["rope_type"]
  225. rope_init_fn: Callable = self.compute_default_rope_parameters
  226. if self.rope_type != "default":
  227. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  228. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  229. self.register_buffer("inv_freq", inv_freq, persistent=False)
  230. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  231. @staticmethod
  232. def compute_default_rope_parameters(
  233. config: Jais2Config | None = None,
  234. device: Optional["torch.device"] = None,
  235. seq_len: int | None = None,
  236. ) -> tuple["torch.Tensor", float]:
  237. """
  238. Computes the inverse frequencies according to the original RoPE implementation
  239. Args:
  240. config ([`~transformers.PreTrainedConfig`]):
  241. The model configuration.
  242. device (`torch.device`):
  243. The device to use for initialization of the inverse frequencies.
  244. seq_len (`int`, *optional*):
  245. The current sequence length. Unused for this type of RoPE.
  246. Returns:
  247. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  248. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  249. """
  250. base = config.rope_parameters["rope_theta"]
  251. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  252. attention_factor = 1.0 # Unused in this type of RoPE
  253. # Compute the inverse frequencies
  254. inv_freq = 1.0 / (
  255. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  256. )
  257. return inv_freq, attention_factor
  258. @torch.no_grad()
  259. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  260. def forward(self, x, position_ids):
  261. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  262. position_ids_expanded = position_ids[:, None, :].float()
  263. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  264. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  265. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  266. emb = torch.cat((freqs, freqs), dim=-1)
  267. cos = emb.cos() * self.attention_scaling
  268. sin = emb.sin() * self.attention_scaling
  269. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  270. @auto_docstring
  271. class Jais2Model(Jais2PreTrainedModel):
  272. def __init__(self, config: Jais2Config):
  273. super().__init__(config)
  274. self.padding_idx = config.pad_token_id
  275. self.vocab_size = config.vocab_size
  276. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  277. self.layers = nn.ModuleList(
  278. [Jais2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  279. )
  280. self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  281. self.rotary_emb = Jais2RotaryEmbedding(config=config)
  282. self.gradient_checkpointing = False
  283. # Initialize weights and apply final processing
  284. self.post_init()
  285. @merge_with_config_defaults
  286. @capture_outputs
  287. @auto_docstring
  288. def forward(
  289. self,
  290. input_ids: torch.LongTensor | None = None,
  291. attention_mask: torch.Tensor | None = None,
  292. position_ids: torch.LongTensor | None = None,
  293. past_key_values: Cache | None = None,
  294. inputs_embeds: torch.FloatTensor | None = None,
  295. use_cache: bool | None = None,
  296. **kwargs: Unpack[TransformersKwargs],
  297. ) -> BaseModelOutputWithPast:
  298. if (input_ids is None) ^ (inputs_embeds is not None):
  299. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  300. if inputs_embeds is None:
  301. inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
  302. if use_cache and past_key_values is None:
  303. past_key_values = DynamicCache(config=self.config)
  304. if position_ids is None:
  305. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  306. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  307. position_ids = position_ids.unsqueeze(0)
  308. causal_mask = create_causal_mask(
  309. config=self.config,
  310. inputs_embeds=inputs_embeds,
  311. attention_mask=attention_mask,
  312. past_key_values=past_key_values,
  313. position_ids=position_ids,
  314. )
  315. hidden_states = inputs_embeds
  316. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  317. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  318. hidden_states = decoder_layer(
  319. hidden_states,
  320. attention_mask=causal_mask,
  321. position_embeddings=position_embeddings,
  322. position_ids=position_ids,
  323. past_key_values=past_key_values,
  324. use_cache=use_cache,
  325. **kwargs,
  326. )
  327. hidden_states = self.norm(hidden_states)
  328. return BaseModelOutputWithPast(
  329. last_hidden_state=hidden_states,
  330. past_key_values=past_key_values,
  331. )
  332. @auto_docstring
  333. class Jais2ForCausalLM(Jais2PreTrainedModel, GenerationMixin):
  334. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  335. _tp_plan = {"lm_head": "colwise_gather_output"}
  336. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  337. def __init__(self, config):
  338. super().__init__(config)
  339. self.model = Jais2Model(config)
  340. self.vocab_size = config.vocab_size
  341. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  342. # Initialize weights and apply final processing
  343. self.post_init()
  344. @can_return_tuple
  345. @auto_docstring
  346. def forward(
  347. self,
  348. input_ids: torch.LongTensor | None = None,
  349. attention_mask: torch.Tensor | None = None,
  350. position_ids: torch.LongTensor | None = None,
  351. past_key_values: Cache | None = None,
  352. inputs_embeds: torch.FloatTensor | None = None,
  353. labels: torch.LongTensor | None = None,
  354. use_cache: bool | None = None,
  355. logits_to_keep: int | torch.Tensor = 0,
  356. **kwargs: Unpack[TransformersKwargs],
  357. ) -> CausalLMOutputWithPast:
  358. r"""
  359. Example:
  360. ```python
  361. >>> from transformers import AutoTokenizer, Jais2ForCausalLM
  362. >>> model = Jais2ForCausalLM.from_pretrained("inceptionai/Jais-2-8B-Chat")
  363. >>> tokenizer = AutoTokenizer.from_pretrained("inceptionai/Jais-2-8B-Chat")
  364. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  365. >>> inputs = tokenizer(prompt, return_tensors="pt")
  366. >>> # Generate
  367. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  368. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  369. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  370. ```"""
  371. outputs: BaseModelOutputWithPast = self.model(
  372. input_ids=input_ids,
  373. attention_mask=attention_mask,
  374. position_ids=position_ids,
  375. past_key_values=past_key_values,
  376. inputs_embeds=inputs_embeds,
  377. use_cache=use_cache,
  378. **kwargs,
  379. )
  380. hidden_states = outputs.last_hidden_state
  381. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  382. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  383. logits = self.lm_head(hidden_states[:, slice_indices, :])
  384. loss = None
  385. if labels is not None:
  386. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  387. return CausalLMOutputWithPast(
  388. loss=loss,
  389. logits=logits,
  390. past_key_values=outputs.past_key_values,
  391. hidden_states=outputs.hidden_states,
  392. attentions=outputs.attentions,
  393. )
  394. __all__ = ["Jais2Model", "Jais2ForCausalLM", "Jais2PreTrainedModel"]