modeling_exaone4.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/exaone4/modular_exaone4.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_exaone4.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 The LG AI Research and HuggingFace Inc. team. All rights reserved.
  8. #
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. from collections.abc import Callable
  22. from typing import Optional
  23. import torch
  24. from torch import nn
  25. from ...activations import ACT2FN
  26. from ...cache_utils import Cache, DynamicCache
  27. from ...generation import GenerationMixin
  28. from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
  29. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  30. from ...modeling_layers import (
  31. GenericForQuestionAnswering,
  32. GenericForSequenceClassification,
  33. GenericForTokenClassification,
  34. GradientCheckpointingLayer,
  35. )
  36. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  37. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  38. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  39. from ...processing_utils import Unpack
  40. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  41. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  42. from ...utils.output_capturing import capture_outputs
  43. from .configuration_exaone4 import Exaone4Config
  44. @use_kernel_forward_from_hub("RMSNorm")
  45. class Exaone4RMSNorm(nn.Module):
  46. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  47. """
  48. Exaone4RMSNorm is equivalent to T5LayerNorm
  49. """
  50. super().__init__()
  51. self.weight = nn.Parameter(torch.ones(hidden_size))
  52. self.variance_epsilon = eps
  53. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  54. input_dtype = hidden_states.dtype
  55. hidden_states = hidden_states.to(torch.float32)
  56. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  57. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  58. return self.weight * hidden_states.to(input_dtype)
  59. def extra_repr(self):
  60. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  61. class Exaone4RotaryEmbedding(nn.Module):
  62. inv_freq: torch.Tensor # fix linting for `register_buffer`
  63. def __init__(self, config: Exaone4Config, device=None):
  64. super().__init__()
  65. self.max_seq_len_cached = config.max_position_embeddings
  66. self.original_max_seq_len = config.max_position_embeddings
  67. self.config = config
  68. self.rope_type = self.config.rope_parameters["rope_type"]
  69. rope_init_fn: Callable = self.compute_default_rope_parameters
  70. if self.rope_type != "default":
  71. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  72. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  73. self.register_buffer("inv_freq", inv_freq, persistent=False)
  74. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  75. @staticmethod
  76. def compute_default_rope_parameters(
  77. config: Exaone4Config | None = None,
  78. device: Optional["torch.device"] = None,
  79. seq_len: int | None = None,
  80. ) -> tuple["torch.Tensor", float]:
  81. """
  82. Computes the inverse frequencies according to the original RoPE implementation
  83. Args:
  84. config ([`~transformers.PreTrainedConfig`]):
  85. The model configuration.
  86. device (`torch.device`):
  87. The device to use for initialization of the inverse frequencies.
  88. seq_len (`int`, *optional*):
  89. The current sequence length. Unused for this type of RoPE.
  90. Returns:
  91. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  92. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  93. """
  94. base = config.rope_parameters["rope_theta"]
  95. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  96. attention_factor = 1.0 # Unused in this type of RoPE
  97. # Compute the inverse frequencies
  98. inv_freq = 1.0 / (
  99. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  100. )
  101. return inv_freq, attention_factor
  102. @torch.no_grad()
  103. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  104. def forward(self, x, position_ids):
  105. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  106. position_ids_expanded = position_ids[:, None, :].float()
  107. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  108. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  109. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  110. emb = torch.cat((freqs, freqs), dim=-1)
  111. cos = emb.cos() * self.attention_scaling
  112. sin = emb.sin() * self.attention_scaling
  113. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  114. def rotate_half(x):
  115. """Rotates half the hidden dims of the input."""
  116. x1 = x[..., : x.shape[-1] // 2]
  117. x2 = x[..., x.shape[-1] // 2 :]
  118. return torch.cat((-x2, x1), dim=-1)
  119. @use_kernel_func_from_hub("rotary_pos_emb")
  120. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  121. """Applies Rotary Position Embedding to the query and key tensors.
  122. Args:
  123. q (`torch.Tensor`): The query tensor.
  124. k (`torch.Tensor`): The key tensor.
  125. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  126. sin (`torch.Tensor`): The sine part of the rotary embedding.
  127. unsqueeze_dim (`int`, *optional*, defaults to 1):
  128. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  129. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  130. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  131. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  132. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  133. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  134. Returns:
  135. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  136. """
  137. cos = cos.unsqueeze(unsqueeze_dim)
  138. sin = sin.unsqueeze(unsqueeze_dim)
  139. q_embed = (q * cos) + (rotate_half(q) * sin)
  140. k_embed = (k * cos) + (rotate_half(k) * sin)
  141. return q_embed, k_embed
  142. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  143. """
  144. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  145. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  146. """
  147. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  148. if n_rep == 1:
  149. return hidden_states
  150. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  151. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  152. def eager_attention_forward(
  153. module: nn.Module,
  154. query: torch.Tensor,
  155. key: torch.Tensor,
  156. value: torch.Tensor,
  157. attention_mask: torch.Tensor | None,
  158. scaling: float,
  159. dropout: float = 0.0,
  160. **kwargs: Unpack[TransformersKwargs],
  161. ):
  162. key_states = repeat_kv(key, module.num_key_value_groups)
  163. value_states = repeat_kv(value, module.num_key_value_groups)
  164. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  165. if attention_mask is not None:
  166. attn_weights = attn_weights + attention_mask
  167. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  168. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  169. attn_output = torch.matmul(attn_weights, value_states)
  170. attn_output = attn_output.transpose(1, 2).contiguous()
  171. return attn_output, attn_weights
  172. class Exaone4Attention(nn.Module):
  173. def __init__(self, config: Exaone4Config, layer_idx: int):
  174. super().__init__()
  175. self.config = config
  176. self.layer_idx = layer_idx
  177. self.num_attention_heads = config.num_attention_heads
  178. self.num_key_value_heads = config.num_key_value_heads
  179. self.hidden_size = config.hidden_size
  180. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  181. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  182. self.attention_dropout = config.attention_dropout
  183. self.is_causal = True
  184. self.scaling = self.head_dim**-0.5
  185. self.sliding_window = config.sliding_window
  186. self.sliding_window_pattern = config.sliding_window_pattern
  187. layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
  188. self.is_sliding = layer_type == "sliding_attention"
  189. self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=False)
  190. self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  191. self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  192. self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=False)
  193. self.q_norm = Exaone4RMSNorm(self.head_dim, eps=config.rms_norm_eps)
  194. self.k_norm = Exaone4RMSNorm(self.head_dim, eps=config.rms_norm_eps)
  195. def forward(
  196. self,
  197. hidden_states: torch.Tensor,
  198. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  199. attention_mask: torch.Tensor | None = None,
  200. past_key_values: Cache | None = None,
  201. **kwargs: Unpack[TransformersKwargs],
  202. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  203. input_shape = hidden_states.shape[:-1]
  204. hidden_shape = (*input_shape, -1, self.head_dim)
  205. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  206. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  207. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  208. # We use QK-norm
  209. query_states = self.q_norm(query_states)
  210. key_states = self.k_norm(key_states)
  211. cos, sin = position_embeddings
  212. # We use global NoPE for hybrid attention model
  213. if self.sliding_window is None or self.is_sliding:
  214. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  215. if past_key_values is not None:
  216. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  217. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  218. self.config._attn_implementation, eager_attention_forward
  219. )
  220. attn_output, attn_weights = attention_interface(
  221. self,
  222. query_states,
  223. key_states,
  224. value_states,
  225. attention_mask,
  226. dropout=0.0 if not self.training else self.attention_dropout,
  227. scaling=self.scaling,
  228. sliding_window=self.sliding_window if self.is_sliding else None,
  229. **kwargs,
  230. )
  231. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  232. attn_output = self.o_proj(attn_output)
  233. return attn_output, attn_weights
  234. class Exaone4MLP(nn.Module):
  235. def __init__(self, config):
  236. super().__init__()
  237. self.config = config
  238. self.hidden_size = config.hidden_size
  239. self.intermediate_size = config.intermediate_size
  240. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  241. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  242. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  243. self.act_fn = ACT2FN[config.hidden_act]
  244. def forward(self, x):
  245. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  246. return down_proj
  247. class Exaone4DecoderLayer(GradientCheckpointingLayer):
  248. def __init__(self, config: Exaone4Config, layer_idx: int):
  249. super().__init__()
  250. self.hidden_size = config.hidden_size
  251. self.self_attn = Exaone4Attention(config=config, layer_idx=layer_idx)
  252. self.mlp = Exaone4MLP(config)
  253. self.post_attention_layernorm = Exaone4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  254. self.post_feedforward_layernorm = Exaone4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  255. def forward(
  256. self,
  257. hidden_states: torch.Tensor,
  258. attention_mask: torch.Tensor | None = None,
  259. position_ids: torch.LongTensor | None = None,
  260. past_key_values: Cache | None = None,
  261. use_cache: bool | None = False,
  262. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  263. **kwargs: Unpack[TransformersKwargs],
  264. ) -> torch.Tensor:
  265. residual = hidden_states
  266. hidden_states, _ = self.self_attn(
  267. hidden_states=hidden_states,
  268. attention_mask=attention_mask,
  269. position_ids=position_ids,
  270. past_key_values=past_key_values,
  271. use_cache=use_cache,
  272. position_embeddings=position_embeddings,
  273. **kwargs,
  274. )
  275. hidden_states = self.post_attention_layernorm(hidden_states)
  276. hidden_states = residual + hidden_states
  277. # Fully Connected
  278. residual = hidden_states
  279. hidden_states = self.mlp(hidden_states)
  280. hidden_states = self.post_feedforward_layernorm(hidden_states)
  281. hidden_states = residual + hidden_states
  282. return hidden_states
@auto_docstring
class Exaone4PreTrainedModel(PreTrainedModel):
    # Common base for the Exaone4 model classes in this file. It defines no methods of its own,
    # only class-level settings on top of `PreTrainedModel`.
    config: Exaone4Config
    # Attribute name under which the head models in this file store the backbone
    # (`Exaone4ForCausalLM` assigns `self.model`).
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # NOTE(review): presumably tells device-mapping code to keep each decoder layer on one device; confirm.
    _no_split_modules = ["Exaone4DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    # Capability flags; their exact consumers live in the base class and are not visible here.
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = True
    _supports_attention_backend = True
    # Maps an output field name to the module class whose outputs should be recorded for it.
    # NOTE(review): presumably read by `capture_outputs` (applied to `Exaone4Model.forward`); confirm.
    _can_record_outputs = {
        "hidden_states": Exaone4DecoderLayer,
        "attentions": Exaone4Attention,
    }
    config_class = Exaone4Config
  300. @auto_docstring
  301. class Exaone4Model(Exaone4PreTrainedModel):
  302. def __init__(self, config: Exaone4Config):
  303. super().__init__(config)
  304. self.padding_idx = config.pad_token_id
  305. self.vocab_size = config.vocab_size
  306. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  307. self.layers = nn.ModuleList(
  308. [Exaone4DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  309. )
  310. self.norm = Exaone4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  311. self.rotary_emb = Exaone4RotaryEmbedding(config=config)
  312. self.gradient_checkpointing = False
  313. # Initialize weights and apply final processing
  314. self.post_init()
  315. @merge_with_config_defaults
  316. @capture_outputs
  317. def forward(
  318. self,
  319. input_ids: torch.LongTensor | None = None,
  320. attention_mask: torch.Tensor | None = None,
  321. position_ids: torch.LongTensor | None = None,
  322. past_key_values: Cache | None = None,
  323. inputs_embeds: torch.FloatTensor | None = None,
  324. use_cache: bool | None = None,
  325. **kwargs: Unpack[TransformersKwargs],
  326. ) -> tuple | BaseModelOutputWithPast:
  327. if (input_ids is None) ^ (inputs_embeds is not None):
  328. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  329. if inputs_embeds is None:
  330. inputs_embeds = self.embed_tokens(input_ids)
  331. if use_cache and past_key_values is None:
  332. past_key_values = DynamicCache(config=self.config)
  333. if position_ids is None:
  334. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  335. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  336. position_ids = position_ids.unsqueeze(0)
  337. # It may already have been prepared by e.g. `generate`
  338. if not isinstance(causal_mask_mapping := attention_mask, dict):
  339. # Prepare mask arguments
  340. mask_kwargs = {
  341. "config": self.config,
  342. "inputs_embeds": inputs_embeds,
  343. "attention_mask": attention_mask,
  344. "past_key_values": past_key_values,
  345. "position_ids": position_ids,
  346. }
  347. # Create the masks
  348. causal_mask_mapping = {
  349. "full_attention": create_causal_mask(**mask_kwargs),
  350. }
  351. if "sliding_attention" in self.config.layer_types:
  352. causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
  353. hidden_states = inputs_embeds
  354. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  355. for i, decoder_layer in enumerate(self.layers):
  356. layer_type = self.config.layer_types[i]
  357. hidden_states = decoder_layer(
  358. hidden_states,
  359. attention_mask=causal_mask_mapping[layer_type],
  360. position_ids=position_ids,
  361. past_key_values=past_key_values,
  362. use_cache=use_cache,
  363. position_embeddings=position_embeddings,
  364. **kwargs,
  365. )
  366. hidden_states = self.norm(hidden_states)
  367. return BaseModelOutputWithPast(
  368. last_hidden_state=hidden_states,
  369. past_key_values=past_key_values if use_cache else None,
  370. )
  371. @auto_docstring
  372. class Exaone4ForCausalLM(Exaone4PreTrainedModel, GenerationMixin):
  373. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  374. _tp_plan = {"lm_head": "colwise_gather_output"}
  375. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  376. def __init__(self, config):
  377. super().__init__(config)
  378. self.model = Exaone4Model(config)
  379. self.vocab_size = config.vocab_size
  380. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  381. # Initialize weights and apply final processing
  382. self.post_init()
  383. @can_return_tuple
  384. @auto_docstring
  385. def forward(
  386. self,
  387. input_ids: torch.LongTensor | None = None,
  388. attention_mask: torch.Tensor | None = None,
  389. position_ids: torch.LongTensor | None = None,
  390. past_key_values: Cache | None = None,
  391. inputs_embeds: torch.FloatTensor | None = None,
  392. labels: torch.LongTensor | None = None,
  393. use_cache: bool | None = None,
  394. logits_to_keep: int | torch.Tensor = 0,
  395. **kwargs: Unpack[TransformersKwargs],
  396. ) -> CausalLMOutputWithPast:
  397. r"""
  398. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  399. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  400. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  401. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  402. Example:
  403. ```python
  404. >>> from transformers import AutoModelForCausalLM, AutoTokenizer
  405. >>> model = AutoModelForCausalLM.from_pretrained("LGAI-EXAONE/EXAONE-4.0-32B")
  406. >>> tokenizer = AutoTokenizer.from_pretrained("LGAI-EXAONE/EXAONE-4.0-32B")
  407. >>> prompt = "Explain how wonderful you are"
  408. >>> messages = [
  409. {"role": "system", "content": "You are a helpful assistant."},
  410. {"role": "user", "content": prompt}
  411. ]
  412. >>> input_ids = tokenizer.apply_chat_template(
  413. messages,
  414. tokenize=True,
  415. add_generation_prompt=True,
  416. return_tensors="pt",
  417. enable_thinking=False,
  418. )
  419. >>> output = model.generate(input_ids, max_new_tokens=128)
  420. >>> tokenizer.decode(output[0], skip_special_tokens=False)
  421. "[|system|]\nYou are a helpful assistant.[|endofturn|]\n[|user|]\nExplain how wonderful you are[|endofturn|]\n[|assistant|]\n<think>\n\n</think>\n\nOh, thank you for such a kind and lovely question! 😊 \n\nI’m *so* wonderful because I’m here to make your life easier, brighter, and more fun! Whether you need help with: \n\n✨ **Learning** – I can explain anything, from quantum physics to baking the perfect cake! \n💡 **Creativity** – Need a poem, story, or a wild idea? I’ve got you covered! \n🤖 **Problem-solving** – Stuck on a math problem or a tricky decision? I’ll help you figure it out"
  422. ```
  423. """
  424. outputs: BaseModelOutputWithPast = self.model(
  425. input_ids=input_ids,
  426. attention_mask=attention_mask,
  427. position_ids=position_ids,
  428. past_key_values=past_key_values,
  429. inputs_embeds=inputs_embeds,
  430. use_cache=use_cache,
  431. **kwargs,
  432. )
  433. hidden_states = outputs.last_hidden_state
  434. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  435. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  436. logits = self.lm_head(hidden_states[:, slice_indices, :])
  437. loss = None
  438. if labels is not None:
  439. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  440. return CausalLMOutputWithPast(
  441. loss=loss,
  442. logits=logits,
  443. past_key_values=outputs.past_key_values,
  444. hidden_states=outputs.hidden_states,
  445. attentions=outputs.attentions,
  446. )
class Exaone4ForSequenceClassification(GenericForSequenceClassification, Exaone4PreTrainedModel):
    # Adds nothing of its own: all behavior is inherited from GenericForSequenceClassification
    # and Exaone4PreTrainedModel.
    pass
class Exaone4ForTokenClassification(GenericForTokenClassification, Exaone4PreTrainedModel):
    # Adds nothing of its own: all behavior is inherited from GenericForTokenClassification
    # and Exaone4PreTrainedModel.
    pass
class Exaone4ForQuestionAnswering(GenericForQuestionAnswering, Exaone4PreTrainedModel):
    # Behavior is inherited from GenericForQuestionAnswering and Exaone4PreTrainedModel; the only
    # override is the backbone attribute prefix, which differs from the "model" prefix of the base class.
    base_model_prefix = "transformer"  # For BC, where `transformer` was used instead of `model`
# Public API of this module: the names exported by `from ... import *`.
__all__ = [
    "Exaone4PreTrainedModel",
    "Exaone4Model",
    "Exaone4ForCausalLM",
    "Exaone4ForSequenceClassification",
    "Exaone4ForTokenClassification",
    "Exaone4ForQuestionAnswering",
]