modular_exaone4.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. # Copyright 2025 The LG AI Research and HuggingFace Inc. team. All rights reserved.
  2. #
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """LG AI Research EXAONE Lab"""
  16. from collections.abc import Callable
  17. import torch
  18. from huggingface_hub.dataclasses import strict
  19. from torch import nn
  20. from ...cache_utils import Cache, DynamicCache
  21. from ...configuration_utils import PreTrainedConfig
  22. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  23. from ...modeling_outputs import (
  24. BaseModelOutputWithPast,
  25. CausalLMOutputWithPast,
  26. )
  27. from ...modeling_rope_utils import RopeParameters
  28. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  29. from ...processing_utils import Unpack
  30. from ...utils import TransformersKwargs, auto_docstring, logging
  31. from ...utils.generic import merge_with_config_defaults
  32. from ...utils.output_capturing import capture_outputs
  33. from ..gemma2.modeling_gemma2 import Gemma2RotaryEmbedding
  34. from ..llama.modeling_llama import (
  35. LlamaForCausalLM,
  36. LlamaForQuestionAnswering,
  37. LlamaForSequenceClassification,
  38. LlamaForTokenClassification,
  39. LlamaModel,
  40. LlamaPreTrainedModel,
  41. LlamaRMSNorm,
  42. apply_rotary_pos_emb,
  43. eager_attention_forward,
  44. )
  45. from ..olmo2.modeling_olmo2 import Olmo2DecoderLayer, Olmo2MLP
# Module-level logger, named after this module.
logger = logging.get_logger(__name__)

# Checkpoint and config names kept as module constants. They are not referenced in the visible
# part of this file — presumably consumed by documentation tooling (TODO confirm).
_CHECKPOINT_FOR_DOC = "LGAI-EXAONE/EXAONE-4.0-32B"
_CONFIG_FOR_DOC = "Exaone4Config"
  49. @auto_docstring(checkpoint="LGAI-EXAONE/EXAONE-4.0-32B")
  50. @strict
  51. class Exaone4Config(PreTrainedConfig):
  52. r"""
  53. sliding_window_pattern (`str`, *optional*):
  54. The pattern to use for sliding window attention. Can be one of:
  55. - `None`: No sliding window attention is used
  56. - `int`: Every `sliding_window` layers, use global attention, else use local attention.
  57. - `str`: A sequence of "L" (local attention) and "G" (global attention) characters that defines the
  58. attention pattern. The pattern starts from layer 0 and repeats every `sliding_window` layers. The
  59. final layer always uses global attention regardless of the pattern.
  60. For instance, sliding_window_pattern="LLLG" same as sliding_window=4, which means:
  61. - Layer 0, 1, 2: local attention,
  62. - Layer 3: global attention,
  63. ...(repeated)
  64. Example:
  65. ```python
  66. >>> from transformers import Exaone4Model, Exaone4Config
  67. >>> # Initializing a EXAONE configuration
  68. >>> configuration = Exaone4Config()
  69. >>> # Initializing a model from configuration
  70. >>> model = Exaone4Model(configuration)
  71. >>> # Accessing the model configuration
  72. >>> configuration = model.config
  73. ```"""
  74. model_type = "exaone4"
  75. keys_to_ignore_at_inference = ["past_key_values"]
  76. # Default tensor parallel plan for base model `LlamaModel`
  77. base_model_tp_plan = {
  78. "layers.*.self_attn.q_proj": "colwise",
  79. "layers.*.self_attn.k_proj": "colwise",
  80. "layers.*.self_attn.v_proj": "colwise",
  81. "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce",
  82. "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce",
  83. "layers.*.self_attn.o_proj": "rowwise",
  84. "layers.*.mlp.gate_proj": "colwise",
  85. "layers.*.mlp.up_proj": "colwise",
  86. "layers.*.mlp.down_proj": "rowwise",
  87. }
  88. base_model_pp_plan = {
  89. "embed_tokens": (["input_ids"], ["inputs_embeds"]),
  90. "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
  91. "norm": (["hidden_states"], ["hidden_states"]),
  92. }
  93. vocab_size: int = 102400
  94. hidden_size: int = 4096
  95. intermediate_size: int = 16384
  96. num_hidden_layers: int = 32
  97. num_attention_heads: int = 32
  98. num_key_value_heads: int = 32
  99. hidden_act: str = "silu"
  100. max_position_embeddings: int = 2048
  101. initializer_range: float = 0.02
  102. rms_norm_eps: float = 1e-5
  103. use_cache: bool = True
  104. bos_token_id: int | None = 0
  105. eos_token_id: int | list[int] | None = 2
  106. pad_token_id: int | None = None
  107. tie_word_embeddings: bool = False
  108. rope_parameters: RopeParameters | dict | None = None
  109. attention_dropout: float | int = 0.0
  110. sliding_window: int | None = 4096
  111. sliding_window_pattern: str | int | None = 4
  112. layer_types: list[str] | None = None
  113. def __post_init__(self, **kwargs):
  114. if self.sliding_window is None:
  115. self.sliding_window_pattern = 0
  116. if self.layer_types is None:
  117. self.layer_types = [
  118. "sliding_attention"
  119. if ((i + 1) % (self.sliding_window_pattern) != 0 and i < self.num_hidden_layers)
  120. else "full_attention"
  121. for i in range(self.num_hidden_layers)
  122. ]
  123. super().__post_init__(**kwargs)
# RMS normalization layer for EXAONE 4.0. Inherits `LlamaRMSNorm` without overriding anything;
# the subclass only provides the model-specific class name.
class Exaone4RMSNorm(LlamaRMSNorm):
    pass
# Rotary position embedding module for EXAONE 4.0. Inherits `Gemma2RotaryEmbedding` without
# overriding anything; the subclass only provides the model-specific class name.
class Exaone4RotaryEmbedding(Gemma2RotaryEmbedding):
    pass
  128. class Exaone4Attention(nn.Module):
  129. def __init__(self, config: Exaone4Config, layer_idx: int):
  130. super().__init__()
  131. self.config = config
  132. self.layer_idx = layer_idx
  133. self.num_attention_heads = config.num_attention_heads
  134. self.num_key_value_heads = config.num_key_value_heads
  135. self.hidden_size = config.hidden_size
  136. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  137. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  138. self.attention_dropout = config.attention_dropout
  139. self.is_causal = True
  140. self.scaling = self.head_dim**-0.5
  141. self.sliding_window = config.sliding_window
  142. self.sliding_window_pattern = config.sliding_window_pattern
  143. layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
  144. self.is_sliding = layer_type == "sliding_attention"
  145. self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=False)
  146. self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  147. self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  148. self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=False)
  149. self.q_norm = Exaone4RMSNorm(self.head_dim, eps=config.rms_norm_eps)
  150. self.k_norm = Exaone4RMSNorm(self.head_dim, eps=config.rms_norm_eps)
  151. def forward(
  152. self,
  153. hidden_states: torch.Tensor,
  154. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  155. attention_mask: torch.Tensor | None = None,
  156. past_key_values: Cache | None = None,
  157. **kwargs: Unpack[TransformersKwargs],
  158. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  159. input_shape = hidden_states.shape[:-1]
  160. hidden_shape = (*input_shape, -1, self.head_dim)
  161. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  162. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  163. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  164. # We use QK-norm
  165. query_states = self.q_norm(query_states)
  166. key_states = self.k_norm(key_states)
  167. cos, sin = position_embeddings
  168. # We use global NoPE for hybrid attention model
  169. if self.sliding_window is None or self.is_sliding:
  170. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  171. if past_key_values is not None:
  172. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  173. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  174. self.config._attn_implementation, eager_attention_forward
  175. )
  176. attn_output, attn_weights = attention_interface(
  177. self,
  178. query_states,
  179. key_states,
  180. value_states,
  181. attention_mask,
  182. dropout=0.0 if not self.training else self.attention_dropout,
  183. scaling=self.scaling,
  184. sliding_window=self.sliding_window if self.is_sliding else None,
  185. **kwargs,
  186. )
  187. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  188. attn_output = self.o_proj(attn_output)
  189. return attn_output, attn_weights
# Feed-forward block for EXAONE 4.0. Inherits `Olmo2MLP` without overriding anything;
# the subclass only provides the model-specific class name.
class Exaone4MLP(Olmo2MLP):
    pass
# Decoder layer for EXAONE 4.0. Inherits `Olmo2DecoderLayer` without overriding anything;
# the subclass only provides the model-specific class name.
class Exaone4DecoderLayer(Olmo2DecoderLayer):
    pass
# Shared base class for the EXAONE 4.0 models in this file, built on `LlamaPreTrainedModel`.
class Exaone4PreTrainedModel(LlamaPreTrainedModel):
    # Configuration class associated with these models.
    config_class = Exaone4Config
    # Module class names that should not be split — presumably honoured when the model is
    # sharded across devices (TODO confirm against `LlamaPreTrainedModel`).
    _no_split_modules = ["Exaone4DecoderLayer"]
  197. class Exaone4Model(Exaone4PreTrainedModel, LlamaModel):
  198. def __init__(self, config: Exaone4Config):
  199. super().__init__(config)
  200. self.layers = nn.ModuleList(
  201. [Exaone4DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  202. )
  203. self.norm = Exaone4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  204. # Initialize weights and apply final processing
  205. self.post_init()
  206. @merge_with_config_defaults
  207. @capture_outputs
  208. def forward(
  209. self,
  210. input_ids: torch.LongTensor | None = None,
  211. attention_mask: torch.Tensor | None = None,
  212. position_ids: torch.LongTensor | None = None,
  213. past_key_values: Cache | None = None,
  214. inputs_embeds: torch.FloatTensor | None = None,
  215. use_cache: bool | None = None,
  216. **kwargs: Unpack[TransformersKwargs],
  217. ) -> tuple | BaseModelOutputWithPast:
  218. if (input_ids is None) ^ (inputs_embeds is not None):
  219. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  220. if inputs_embeds is None:
  221. inputs_embeds = self.embed_tokens(input_ids)
  222. if use_cache and past_key_values is None:
  223. past_key_values = DynamicCache(config=self.config)
  224. if position_ids is None:
  225. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  226. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  227. position_ids = position_ids.unsqueeze(0)
  228. # It may already have been prepared by e.g. `generate`
  229. if not isinstance(causal_mask_mapping := attention_mask, dict):
  230. # Prepare mask arguments
  231. mask_kwargs = {
  232. "config": self.config,
  233. "inputs_embeds": inputs_embeds,
  234. "attention_mask": attention_mask,
  235. "past_key_values": past_key_values,
  236. "position_ids": position_ids,
  237. }
  238. # Create the masks
  239. causal_mask_mapping = {
  240. "full_attention": create_causal_mask(**mask_kwargs),
  241. }
  242. if "sliding_attention" in self.config.layer_types:
  243. causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
  244. hidden_states = inputs_embeds
  245. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  246. for i, decoder_layer in enumerate(self.layers):
  247. layer_type = self.config.layer_types[i]
  248. hidden_states = decoder_layer(
  249. hidden_states,
  250. attention_mask=causal_mask_mapping[layer_type],
  251. position_ids=position_ids,
  252. past_key_values=past_key_values,
  253. use_cache=use_cache,
  254. position_embeddings=position_embeddings,
  255. **kwargs,
  256. )
  257. hidden_states = self.norm(hidden_states)
  258. return BaseModelOutputWithPast(
  259. last_hidden_state=hidden_states,
  260. past_key_values=past_key_values if use_cache else None,
  261. )
  262. class Exaone4ForCausalLM(LlamaForCausalLM):
  263. def forward(
  264. self,
  265. input_ids: torch.LongTensor | None = None,
  266. attention_mask: torch.Tensor | None = None,
  267. position_ids: torch.LongTensor | None = None,
  268. past_key_values: Cache | None = None,
  269. inputs_embeds: torch.FloatTensor | None = None,
  270. labels: torch.LongTensor | None = None,
  271. use_cache: bool | None = None,
  272. logits_to_keep: int | torch.Tensor = 0,
  273. **kwargs: Unpack[TransformersKwargs],
  274. ) -> CausalLMOutputWithPast:
  275. r"""
  276. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  277. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  278. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  279. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  280. Example:
  281. ```python
  282. >>> from transformers import AutoModelForCausalLM, AutoTokenizer
  283. >>> model = AutoModelForCausalLM.from_pretrained("LGAI-EXAONE/EXAONE-4.0-32B")
  284. >>> tokenizer = AutoTokenizer.from_pretrained("LGAI-EXAONE/EXAONE-4.0-32B")
  285. >>> prompt = "Explain how wonderful you are"
  286. >>> messages = [
  287. {"role": "system", "content": "You are a helpful assistant."},
  288. {"role": "user", "content": prompt}
  289. ]
  290. >>> input_ids = tokenizer.apply_chat_template(
  291. messages,
  292. tokenize=True,
  293. add_generation_prompt=True,
  294. return_tensors="pt",
  295. enable_thinking=False,
  296. )
  297. >>> output = model.generate(input_ids, max_new_tokens=128)
  298. >>> tokenizer.decode(output[0], skip_special_tokens=False)
  299. "[|system|]\nYou are a helpful assistant.[|endofturn|]\n[|user|]\nExplain how wonderful you are[|endofturn|]\n[|assistant|]\n<think>\n\n</think>\n\nOh, thank you for such a kind and lovely question! 😊 \n\nI’m *so* wonderful because I’m here to make your life easier, brighter, and more fun! Whether you need help with: \n\n✨ **Learning** – I can explain anything, from quantum physics to baking the perfect cake! \n💡 **Creativity** – Need a poem, story, or a wild idea? I’ve got you covered! \n🤖 **Problem-solving** – Stuck on a math problem or a tricky decision? I’ll help you figure it out"
  300. ```
  301. """
  302. super().forward(
  303. input_ids=input_ids,
  304. attention_mask=attention_mask,
  305. position_ids=position_ids,
  306. past_key_values=past_key_values,
  307. inputs_embeds=inputs_embeds,
  308. labels=labels,
  309. use_cache=use_cache,
  310. logits_to_keep=logits_to_keep,
  311. **kwargs,
  312. )
# Sequence-classification model for EXAONE 4.0. Inherits `LlamaForSequenceClassification`
# without overriding anything; the subclass only provides the model-specific class name.
class Exaone4ForSequenceClassification(LlamaForSequenceClassification):
    pass
# Token-classification model for EXAONE 4.0. Inherits `LlamaForTokenClassification`
# without overriding anything; the subclass only provides the model-specific class name.
class Exaone4ForTokenClassification(LlamaForTokenClassification):
    pass
# Question-answering model for EXAONE 4.0. Inherits `LlamaForQuestionAnswering`
# without overriding anything; the subclass only provides the model-specific class name.
class Exaone4ForQuestionAnswering(LlamaForQuestionAnswering):
    pass
# Public API of this module: the configuration class plus the model classes defined above.
# The building-block classes (attention, MLP, norm, rotary embedding, decoder layer) are
# not listed.
__all__ = [
    "Exaone4Config",
    "Exaone4PreTrainedModel",
    "Exaone4Model",
    "Exaone4ForCausalLM",
    "Exaone4ForSequenceClassification",
    "Exaone4ForTokenClassification",
    "Exaone4ForQuestionAnswering",
]