# Copyright 2025 The LG AI Research and HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LG AI Research EXAONE Lab"""

from collections.abc import Callable

import torch
from huggingface_hub.dataclasses import strict
from torch import nn

from ...cache_utils import Cache, DynamicCache
from ...configuration_utils import PreTrainedConfig
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from ...modeling_rope_utils import RopeParameters
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, logging
from ...utils.generic import merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
from ..gemma2.modeling_gemma2 import Gemma2RotaryEmbedding
from ..llama.modeling_llama import (
    LlamaForCausalLM,
    LlamaForQuestionAnswering,
    LlamaForSequenceClassification,
    LlamaForTokenClassification,
    LlamaModel,
    LlamaPreTrainedModel,
    LlamaRMSNorm,
    apply_rotary_pos_emb,
    eager_attention_forward,
)
from ..olmo2.modeling_olmo2 import Olmo2DecoderLayer, Olmo2MLP


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "LGAI-EXAONE/EXAONE-4.0-32B"
_CONFIG_FOR_DOC = "Exaone4Config"


@auto_docstring(checkpoint="LGAI-EXAONE/EXAONE-4.0-32B")
@strict
class Exaone4Config(PreTrainedConfig):
    r"""
    sliding_window_pattern (`str` or `int`, *optional*, defaults to 4):
        The pattern to use for sliding window attention. Can be one of:
            - `None`: No sliding window attention is used
            - `int`: Every `sliding_window_pattern` layers, use global attention, else use local attention.
            - `str`: A sequence of "L" (local attention) and "G" (global attention) characters that defines the
              attention pattern. The pattern starts from layer 0 and repeats every `len(sliding_window_pattern)`
              layers. The final layer always uses global attention regardless of the pattern.
        For instance, sliding_window_pattern="LLLG" same as sliding_window_pattern=4, which means:
            - Layer 0, 1, 2: local attention,
            - Layer 3: global attention,
            ...(repeated)
        The pattern is only used to derive `layer_types` when `layer_types` is not given explicitly, and it is
        ignored (every layer uses global attention) when `sliding_window` is `None`.

    Example:

    ```python
    >>> from transformers import Exaone4Model, Exaone4Config

    >>> # Initializing a EXAONE configuration
    >>> configuration = Exaone4Config()

    >>> # Initializing a model from configuration
    >>> model = Exaone4Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "exaone4"
    keys_to_ignore_at_inference = ["past_key_values"]

    # Default tensor parallel plan for base model `Exaone4Model`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce",
        "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    vocab_size: int = 102400
    hidden_size: int = 4096
    intermediate_size: int = 16384
    num_hidden_layers: int = 32
    num_attention_heads: int = 32
    num_key_value_heads: int = 32
    hidden_act: str = "silu"
    max_position_embeddings: int = 2048
    initializer_range: float = 0.02
    rms_norm_eps: float = 1e-5
    use_cache: bool = True
    bos_token_id: int | None = 0
    eos_token_id: int | list[int] | None = 2
    pad_token_id: int | None = None
    tie_word_embeddings: bool = False
    rope_parameters: RopeParameters | dict | None = None
    attention_dropout: float | int = 0.0
    sliding_window: int | None = 4096
    sliding_window_pattern: str | int | None = 4
    layer_types: list[str] | None = None

    def __post_init__(self, **kwargs):
        """Derive `layer_types` from the sliding-window settings when it was not provided.

        Raises:
            ValueError: If `sliding_window_pattern` is a string containing characters other than "L" and "G".
        """
        # Without a window size there is nothing to slide over, so the pattern is disabled.
        # 0 is kept as the stored "no sliding layers" value.
        if self.sliding_window is None:
            self.sliding_window_pattern = 0

        if self.layer_types is None:
            pattern = self.sliding_window_pattern
            num_layers = self.num_hidden_layers

            if not pattern:
                # `None`, `0` or `""`: no local attention at all. Taking the modulo by such a value
                # would raise (ZeroDivisionError / TypeError), so this case is handled explicitly.
                self.layer_types = ["full_attention"] * num_layers
            elif isinstance(pattern, str):
                unknown = set(pattern) - {"L", "G"}
                if unknown:
                    raise ValueError(
                        "`sliding_window_pattern` must only contain 'L' (local) and 'G' (global) characters, "
                        f"but got {pattern!r}."
                    )
                # The pattern is tiled over the layers; the final layer is always global.
                last_layer = num_layers - 1
                self.layer_types = [
                    "sliding_attention"
                    if (i != last_layer and pattern[i % len(pattern)] == "L")
                    else "full_attention"
                    for i in range(num_layers)
                ]
            else:
                # Integer pattern: every `pattern`-th layer (1-based) is global, the others are local.
                self.layer_types = [
                    "sliding_attention" if (i + 1) % pattern != 0 else "full_attention" for i in range(num_layers)
                ]

        super().__post_init__(**kwargs)


# RMSNorm is taken from Llama without modification.
class Exaone4RMSNorm(LlamaRMSNorm):
    pass


# Rotary embedding is taken from Gemma2 without modification.
class Exaone4RotaryEmbedding(Gemma2RotaryEmbedding):
    pass


class Exaone4Attention(nn.Module):
    """Attention block with per-head RMS normalization of queries and keys (QK-norm).

    Layers whose entry in `config.layer_types` is `"sliding_attention"` use the configured sliding window.
    Rotary position embeddings are applied only on sliding layers, or on every layer when
    `config.sliding_window` is `None`.
    """

    def __init__(self, config: Exaone4Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.hidden_size = config.hidden_size
        # `head_dim` may be set explicitly on the config; otherwise it is derived from the hidden size.
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.attention_dropout = config.attention_dropout
        self.is_causal = True
        self.scaling = self.head_dim**-0.5
        self.sliding_window = config.sliding_window
        self.sliding_window_pattern = config.sliding_window_pattern
        layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
        self.is_sliding = layer_type == "sliding_attention"

        self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=False)

        self.q_norm = Exaone4RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Exaone4RMSNorm(self.head_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run attention over `hidden_states`.

        Args:
            hidden_states: Input activations; the last dimension is projected by q/k/v.
            position_embeddings: `(cos, sin)` pair passed to `apply_rotary_pos_emb`.
            attention_mask: Mask forwarded unchanged to the attention implementation.
            past_key_values: Optional cache, updated in place with this layer's key/value states.
            **kwargs: Extra arguments forwarded to the attention implementation.

        Returns:
            `(attn_output, attn_weights)` where `attn_output` has been passed through `o_proj` and
            `attn_weights` is whatever the selected attention implementation returns.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        # We use QK-norm
        query_states = self.q_norm(query_states)
        key_states = self.k_norm(key_states)

        cos, sin = position_embeddings
        # We use global NoPE for hybrid attention model
        if self.sliding_window is None or self.is_sliding:
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            # Only sliding layers restrict the attention span.
            sliding_window=self.sliding_window if self.is_sliding else None,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


# The MLP is taken from OLMo2 without modification.
class Exaone4MLP(Olmo2MLP):
    pass


# The decoder layer is taken from OLMo2 without modification.
class Exaone4DecoderLayer(Olmo2DecoderLayer):
    pass


class Exaone4PreTrainedModel(LlamaPreTrainedModel):
    config_class = Exaone4Config
    _no_split_modules = ["Exaone4DecoderLayer"]


class Exaone4Model(Exaone4PreTrainedModel, LlamaModel):
    def __init__(self, config: Exaone4Config):
        super().__init__(config)
        # Use the EXAONE4 decoder layers and final norm for this model.
        self.layers = nn.ModuleList(
            [Exaone4DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Exaone4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPast:
        """Embed the inputs, build the per-layer-type attention masks and run the decoder stack.

        Raises:
            ValueError: If both or neither of `input_ids` and `inputs_embeds` are provided.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if position_ids is None:
            # Continue numbering positions after the tokens already stored in the cache.
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
            position_ids = position_ids.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "inputs_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            # Create the masks
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
            }
            # The sliding mask is only built when at least one layer needs it.
            if "sliding_attention" in self.config.layer_types:
                causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for i, decoder_layer in enumerate(self.layers):
            # Each layer receives the mask that matches its attention type.
            layer_type = self.config.layer_types[i]
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[layer_type],
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )


class Exaone4ForCausalLM(LlamaForCausalLM):
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer
        >>> model = AutoModelForCausalLM.from_pretrained("LGAI-EXAONE/EXAONE-4.0-32B")
        >>> tokenizer = AutoTokenizer.from_pretrained("LGAI-EXAONE/EXAONE-4.0-32B")

        >>> prompt = "Explain how wonderful you are"
        >>> messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt}
        ]
        >>> input_ids = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            enable_thinking=False,
        )

        >>> output = model.generate(input_ids, max_new_tokens=128)
        >>> tokenizer.decode(output[0], skip_special_tokens=False)
        "[|system|]\nYou are a helpful assistant.[|endofturn|]\n[|user|]\nExplain how wonderful you are[|endofturn|]\n[|assistant|]\n\n\n\n\nOh, thank you for such a kind and lovely question! 😊 \n\nI’m *so* wonderful because I’m here to make your life easier, brighter, and more fun! Whether you need help with: \n\n✨ **Learning** – I can explain anything, from quantum physics to baking the perfect cake! \n💡 **Creativity** – Need a poem, story, or a wild idea? I’ve got you covered! \n🤖 **Problem-solving** – Stuck on a math problem or a tricky decision? I’ll help you figure it out"
        ```
        """
        # Return the parent's output; without `return`, calling this method directly yields `None`.
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )


# The task heads below are taken from Llama without modification.
class Exaone4ForSequenceClassification(LlamaForSequenceClassification):
    pass


class Exaone4ForTokenClassification(LlamaForTokenClassification):
    pass


class Exaone4ForQuestionAnswering(LlamaForQuestionAnswering):
    pass


__all__ = [
    "Exaone4Config",
    "Exaone4PreTrainedModel",
    "Exaone4Model",
    "Exaone4ForCausalLM",
    "Exaone4ForSequenceClassification",
    "Exaone4ForTokenClassification",
    "Exaone4ForQuestionAnswering",
]