modeling_glm4.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/glm4/modular_glm4.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_glm4.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 The GLM4 & ZhipuAI team and HuggingFace Inc. team. All rights reserved.
  8. #
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. from collections.abc import Callable
  22. from typing import Optional
  23. import torch
  24. import torch.nn as nn
  25. from ...activations import ACT2FN
  26. from ...cache_utils import Cache, DynamicCache
  27. from ...generation import GenerationMixin
  28. from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
  29. from ...masking_utils import create_causal_mask
  30. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  31. from ...modeling_layers import (
  32. GenericForSequenceClassification,
  33. GenericForTokenClassification,
  34. GradientCheckpointingLayer,
  35. )
  36. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  37. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  38. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  39. from ...processing_utils import Unpack
  40. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  41. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  42. from ...utils.output_capturing import capture_outputs
  43. from .configuration_glm4 import Glm4Config
  44. class Glm4MLP(nn.Module):
  45. def __init__(self, config):
  46. super().__init__()
  47. self.config = config
  48. self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
  49. self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
  50. self.activation_fn = ACT2FN[config.hidden_act]
  51. def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
  52. up_states = self.gate_up_proj(hidden_states)
  53. gate, up_states = up_states.chunk(2, dim=-1)
  54. up_states = up_states * self.activation_fn(gate)
  55. return self.down_proj(up_states)
  56. class Glm4DecoderLayer(GradientCheckpointingLayer):
  57. def __init__(self, config: Glm4Config, layer_idx: int):
  58. super().__init__()
  59. self.hidden_size = config.hidden_size
  60. self.self_attn = Glm4Attention(config=config, layer_idx=layer_idx)
  61. self.mlp = Glm4MLP(config)
  62. self.input_layernorm = Glm4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  63. self.post_attention_layernorm = Glm4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  64. self.post_self_attn_layernorm = Glm4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  65. self.post_mlp_layernorm = Glm4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  66. def forward(
  67. self,
  68. hidden_states: torch.Tensor,
  69. attention_mask: torch.Tensor | None = None,
  70. position_ids: torch.LongTensor | None = None,
  71. past_key_values: Cache | None = None,
  72. use_cache: bool | None = False,
  73. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  74. **kwargs: Unpack[FlashAttentionKwargs],
  75. ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
  76. residual = hidden_states
  77. hidden_states = self.input_layernorm(hidden_states)
  78. hidden_states, _ = self.self_attn(
  79. hidden_states=hidden_states,
  80. attention_mask=attention_mask,
  81. position_ids=position_ids,
  82. past_key_values=past_key_values,
  83. use_cache=use_cache,
  84. position_embeddings=position_embeddings,
  85. **kwargs,
  86. )
  87. hidden_states = self.post_self_attn_layernorm(hidden_states)
  88. hidden_states = residual + hidden_states
  89. residual = hidden_states
  90. hidden_states = self.post_attention_layernorm(hidden_states)
  91. hidden_states = self.mlp(hidden_states)
  92. hidden_states = self.post_mlp_layernorm(hidden_states)
  93. hidden_states = residual + hidden_states
  94. return hidden_states
  95. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  96. """
  97. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  98. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  99. """
  100. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  101. if n_rep == 1:
  102. return hidden_states
  103. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  104. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  105. def eager_attention_forward(
  106. module: nn.Module,
  107. query: torch.Tensor,
  108. key: torch.Tensor,
  109. value: torch.Tensor,
  110. attention_mask: torch.Tensor | None,
  111. scaling: float,
  112. dropout: float = 0.0,
  113. **kwargs: Unpack[TransformersKwargs],
  114. ):
  115. key_states = repeat_kv(key, module.num_key_value_groups)
  116. value_states = repeat_kv(value, module.num_key_value_groups)
  117. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  118. if attention_mask is not None:
  119. attn_weights = attn_weights + attention_mask
  120. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  121. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  122. attn_output = torch.matmul(attn_weights, value_states)
  123. attn_output = attn_output.transpose(1, 2).contiguous()
  124. return attn_output, attn_weights
  125. def rotate_half(x):
  126. """Rotates half the hidden dims of the input."""
  127. x1 = x[..., 0::2]
  128. x2 = x[..., 1::2]
  129. return torch.stack((-x2, x1), dim=-1).flatten(-2)
  130. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  131. """Applies Rotary Position Embedding to the query and key tensors.
  132. Args:
  133. q (`torch.Tensor`): The query tensor.
  134. k (`torch.Tensor`): The key tensor.
  135. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  136. sin (`torch.Tensor`): The sine part of the rotary embedding.
  137. unsqueeze_dim (`int`, *optional*, defaults to 1):
  138. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  139. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  140. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  141. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  142. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  143. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  144. Returns:
  145. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  146. """
  147. cos = cos.unsqueeze(unsqueeze_dim)
  148. sin = sin.unsqueeze(unsqueeze_dim)
  149. # Interleave them instead of usual shape
  150. cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1)
  151. sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1)
  152. # Keep half or full tensor for later concatenation
  153. rotary_dim = cos.shape[-1]
  154. q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
  155. k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
  156. # Apply rotary embeddings on the first half or full tensor
  157. q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
  158. k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
  159. # Concatenate back to full shape
  160. q_embed = torch.cat([q_embed, q_pass], dim=-1)
  161. k_embed = torch.cat([k_embed, k_pass], dim=-1)
  162. return q_embed, k_embed
  163. @use_kernelized_func(apply_rotary_pos_emb)
  164. class Glm4Attention(nn.Module):
  165. """Multi-headed attention from 'Attention Is All You Need' paper"""
  166. def __init__(self, config: Glm4Config, layer_idx: int | None = None):
  167. super().__init__()
  168. self.config = config
  169. self.layer_idx = layer_idx
  170. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  171. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  172. self.scaling = self.head_dim**-0.5
  173. self.attention_dropout = config.attention_dropout
  174. self.is_causal = True
  175. self.q_proj = nn.Linear(
  176. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  177. )
  178. self.k_proj = nn.Linear(
  179. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  180. )
  181. self.v_proj = nn.Linear(
  182. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  183. )
  184. self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
  185. def forward(
  186. self,
  187. hidden_states: torch.Tensor,
  188. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  189. attention_mask: torch.Tensor | None = None,
  190. past_key_values: Cache | None = None,
  191. **kwargs: Unpack[TransformersKwargs],
  192. ) -> tuple[torch.Tensor, torch.Tensor]:
  193. input_shape = hidden_states.shape[:-1]
  194. hidden_shape = (*input_shape, -1, self.head_dim)
  195. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  196. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  197. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  198. cos, sin = position_embeddings
  199. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  200. if past_key_values is not None:
  201. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  202. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  203. self.config._attn_implementation, eager_attention_forward
  204. )
  205. attn_output, attn_weights = attention_interface(
  206. self,
  207. query_states,
  208. key_states,
  209. value_states,
  210. attention_mask,
  211. dropout=0.0 if not self.training else self.attention_dropout,
  212. scaling=self.scaling,
  213. **kwargs,
  214. )
  215. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  216. attn_output = self.o_proj(attn_output)
  217. return attn_output, attn_weights
  218. class Glm4RotaryEmbedding(nn.Module):
  219. inv_freq: torch.Tensor # fix linting for `register_buffer`
  220. def __init__(self, config: Glm4Config, device=None):
  221. super().__init__()
  222. self.max_seq_len_cached = config.max_position_embeddings
  223. self.original_max_seq_len = config.max_position_embeddings
  224. self.config = config
  225. self.rope_type = self.config.rope_parameters["rope_type"]
  226. rope_init_fn: Callable = self.compute_default_rope_parameters
  227. if self.rope_type != "default":
  228. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  229. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  230. self.register_buffer("inv_freq", inv_freq, persistent=False)
  231. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  232. @staticmethod
  233. def compute_default_rope_parameters(
  234. config: Glm4Config | None = None,
  235. device: Optional["torch.device"] = None,
  236. seq_len: int | None = None,
  237. ) -> tuple["torch.Tensor", float]:
  238. """
  239. Computes the inverse frequencies according to the original RoPE implementation
  240. Args:
  241. config ([`~transformers.PreTrainedConfig`]):
  242. The model configuration.
  243. device (`torch.device`):
  244. The device to use for initialization of the inverse frequencies.
  245. seq_len (`int`, *optional*):
  246. The current sequence length. Unused for this type of RoPE.
  247. Returns:
  248. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  249. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  250. """
  251. base = config.rope_parameters["rope_theta"]
  252. partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
  253. head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  254. dim = int(head_dim * partial_rotary_factor)
  255. attention_factor = 1.0 # Unused in this type of RoPE
  256. # Compute the inverse frequencies
  257. inv_freq = 1.0 / (
  258. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  259. )
  260. return inv_freq, attention_factor
  261. @torch.no_grad()
  262. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  263. def forward(self, x, position_ids):
  264. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  265. position_ids_expanded = position_ids[:, None, :].float()
  266. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  267. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  268. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  269. emb = torch.cat((freqs, freqs), dim=-1)
  270. cos = emb.cos() * self.attention_scaling
  271. sin = emb.sin() * self.attention_scaling
  272. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  273. @use_kernel_forward_from_hub("RMSNorm")
  274. class Glm4RMSNorm(nn.Module):
  275. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  276. """
  277. Glm4RMSNorm is equivalent to T5LayerNorm
  278. """
  279. super().__init__()
  280. self.weight = nn.Parameter(torch.ones(hidden_size))
  281. self.variance_epsilon = eps
  282. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  283. input_dtype = hidden_states.dtype
  284. hidden_states = hidden_states.to(torch.float32)
  285. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  286. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  287. return self.weight * hidden_states.to(input_dtype)
  288. def extra_repr(self):
  289. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
@auto_docstring
class Glm4PreTrainedModel(PreTrainedModel):
    # Common base for the GLM-4 model classes in this file. It only declares
    # class-level attributes; NOTE(review): these are presumably read by the
    # PreTrainedModel machinery (defined elsewhere) - confirm there.
    config: Glm4Config
    # Glm4ForCausalLM stores its Glm4Model under `self.model`.
    base_model_prefix = "model"
    # Glm4DecoderLayer derives from GradientCheckpointingLayer.
    supports_gradient_checkpointing = True
    # Presumably: module classes that must stay on a single device when the model
    # is split across devices - confirm against PreTrainedModel.
    _no_split_modules = ["Glm4DecoderLayer"]
    # Presumably: inputs exempt from automatic device placement - confirm.
    _skip_keys_device_placement = ["past_key_values"]
    # Attention backends advertised as supported. Glm4Attention picks its backend
    # via ALL_ATTENTION_FUNCTIONS using config._attn_implementation.
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = True
    _supports_attention_backend = True
    # Output name -> module class whose outputs are recorded under that name.
    # NOTE(review): presumably consumed by `capture_outputs` on Glm4Model.forward.
    _can_record_outputs = {
        "hidden_states": Glm4DecoderLayer,
        "attentions": Glm4Attention,
    }
  306. @auto_docstring
  307. class Glm4Model(Glm4PreTrainedModel):
  308. def __init__(self, config: Glm4Config):
  309. super().__init__(config)
  310. self.padding_idx = config.pad_token_id
  311. self.vocab_size = config.vocab_size
  312. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  313. self.layers = nn.ModuleList(
  314. [Glm4DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  315. )
  316. self.norm = Glm4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  317. self.rotary_emb = Glm4RotaryEmbedding(config=config)
  318. self.gradient_checkpointing = False
  319. # Initialize weights and apply final processing
  320. self.post_init()
  321. @merge_with_config_defaults
  322. @capture_outputs
  323. @auto_docstring
  324. def forward(
  325. self,
  326. input_ids: torch.LongTensor | None = None,
  327. attention_mask: torch.Tensor | None = None,
  328. position_ids: torch.LongTensor | None = None,
  329. past_key_values: Cache | None = None,
  330. inputs_embeds: torch.FloatTensor | None = None,
  331. use_cache: bool | None = None,
  332. **kwargs: Unpack[TransformersKwargs],
  333. ) -> BaseModelOutputWithPast:
  334. if (input_ids is None) ^ (inputs_embeds is not None):
  335. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  336. if inputs_embeds is None:
  337. inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
  338. if use_cache and past_key_values is None:
  339. past_key_values = DynamicCache(config=self.config)
  340. if position_ids is None:
  341. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  342. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  343. position_ids = position_ids.unsqueeze(0)
  344. causal_mask = create_causal_mask(
  345. config=self.config,
  346. inputs_embeds=inputs_embeds,
  347. attention_mask=attention_mask,
  348. past_key_values=past_key_values,
  349. position_ids=position_ids,
  350. )
  351. hidden_states = inputs_embeds
  352. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  353. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  354. hidden_states = decoder_layer(
  355. hidden_states,
  356. attention_mask=causal_mask,
  357. position_embeddings=position_embeddings,
  358. position_ids=position_ids,
  359. past_key_values=past_key_values,
  360. use_cache=use_cache,
  361. **kwargs,
  362. )
  363. hidden_states = self.norm(hidden_states)
  364. return BaseModelOutputWithPast(
  365. last_hidden_state=hidden_states,
  366. past_key_values=past_key_values,
  367. )
  368. @auto_docstring
  369. class Glm4ForCausalLM(Glm4PreTrainedModel, GenerationMixin):
  370. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  371. _tp_plan = {"lm_head": "colwise_gather_output"}
  372. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  373. def __init__(self, config):
  374. super().__init__(config)
  375. self.model = Glm4Model(config)
  376. self.vocab_size = config.vocab_size
  377. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  378. # Initialize weights and apply final processing
  379. self.post_init()
  380. @can_return_tuple
  381. @auto_docstring
  382. def forward(
  383. self,
  384. input_ids: torch.LongTensor | None = None,
  385. attention_mask: torch.Tensor | None = None,
  386. position_ids: torch.LongTensor | None = None,
  387. past_key_values: Cache | None = None,
  388. inputs_embeds: torch.FloatTensor | None = None,
  389. labels: torch.LongTensor | None = None,
  390. use_cache: bool | None = None,
  391. logits_to_keep: int | torch.Tensor = 0,
  392. **kwargs: Unpack[TransformersKwargs],
  393. ) -> tuple | CausalLMOutputWithPast:
  394. r"""
  395. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  396. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  397. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  398. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  399. Example:
  400. ```python
  401. >>> from transformers import AutoTokenizer, Glm4ForCausalLM
  402. >>> model = Glm4ForCausalLM.from_pretrained("THUDM/GLM-4-9B-0414")
  403. >>> tokenizer = AutoTokenizer.from_pretrained("THUDM/GLM-4-9B-0414")
  404. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  405. >>> inputs = tokenizer(prompt, return_tensors="pt")
  406. >>> # Generate
  407. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  408. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  409. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  410. ```"""
  411. outputs: BaseModelOutputWithPast = self.model(
  412. input_ids=input_ids,
  413. attention_mask=attention_mask,
  414. position_ids=position_ids,
  415. past_key_values=past_key_values,
  416. inputs_embeds=inputs_embeds,
  417. use_cache=use_cache,
  418. **kwargs,
  419. )
  420. hidden_states = outputs.last_hidden_state
  421. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  422. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  423. logits = self.lm_head(hidden_states[:, slice_indices, :])
  424. loss = None
  425. if labels is not None:
  426. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  427. return CausalLMOutputWithPast(
  428. loss=loss,
  429. logits=logits,
  430. past_key_values=outputs.past_key_values,
  431. hidden_states=outputs.hidden_states,
  432. attentions=outputs.attentions,
  433. )
class Glm4ForSequenceClassification(GenericForSequenceClassification, Glm4PreTrainedModel):
    # No GLM-4-specific code. All behaviour is inherited, with
    # GenericForSequenceClassification (defined elsewhere) first in the MRO and
    # Glm4PreTrainedModel supplying the GLM-4 class-level settings.
    pass
class Glm4ForTokenClassification(GenericForTokenClassification, Glm4PreTrainedModel):
    # No GLM-4-specific code. All behaviour is inherited, with
    # GenericForTokenClassification (defined elsewhere) first in the MRO and
    # Glm4PreTrainedModel supplying the GLM-4 class-level settings.
    pass
# Public names of this module, i.e. what `from <module> import *` exports.
# Helpers such as Glm4Attention, Glm4MLP and Glm4RMSNorm are intentionally not listed.
__all__ = [
    "Glm4PreTrainedModel",
    "Glm4Model",
    "Glm4ForCausalLM",
    "Glm4ForSequenceClassification",
    "Glm4ForTokenClassification",
]