modeling_glm.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/glm/modular_glm.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_glm.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2024 The GLM & ZhipuAI team and HuggingFace Inc. team. All rights reserved.
  8. #
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. from collections.abc import Callable
  22. from typing import Optional
  23. import torch
  24. import torch.nn as nn
  25. from ...activations import ACT2FN
  26. from ...cache_utils import Cache, DynamicCache
  27. from ...generation import GenerationMixin
  28. from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
  29. from ...masking_utils import create_causal_mask
  30. from ...modeling_layers import (
  31. GenericForSequenceClassification,
  32. GenericForTokenClassification,
  33. GradientCheckpointingLayer,
  34. )
  35. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  36. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  37. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  38. from ...processing_utils import Unpack
  39. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  40. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  41. from ...utils.output_capturing import capture_outputs
  42. from .configuration_glm import GlmConfig
  class GlmMLP(nn.Module):
      """Gated feed-forward block of a GLM decoder layer.

      A single fused projection (`gate_up_proj`) produces `2 * intermediate_size` features, which are
      split into a gate half and an up half. The result `up * act(gate)` is projected back to `hidden_size`.
      """

      def __init__(self, config):
          super().__init__()
          self.config = config
          hidden_size, inner_size = config.hidden_size, config.intermediate_size
          self.gate_up_proj = nn.Linear(hidden_size, 2 * inner_size, bias=False)
          self.down_proj = nn.Linear(inner_size, hidden_size, bias=False)
          self.activation_fn = ACT2FN[config.hidden_act]

      def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
          # The first chunk is the gate and the second is the value ("up") branch.
          gate, up = self.gate_up_proj(hidden_states).chunk(2, dim=-1)
          gated = up * self.activation_fn(gate)
          return self.down_proj(gated)
  class GlmRotaryEmbedding(nn.Module):
      """Builds the rotary-position-embedding `cos` / `sin` tables for GLM.

      The inverse frequencies come from `compute_default_rope_parameters` when
      `config.rope_parameters["rope_type"] == "default"`. Otherwise they come from
      `ROPE_INIT_FUNCTIONS[rope_type]`. Two non-persistent buffers are registered: `inv_freq`, and an
      untouched copy `original_inv_freq`.
      """

      inv_freq: torch.Tensor  # fix linting for `register_buffer`

      def __init__(self, config: GlmConfig, device=None):
          super().__init__()
          # Both start at the configured context length. NOTE(review): presumably read/updated by
          # `dynamic_rope_update` (defined outside this file).
          self.max_seq_len_cached = config.max_position_embeddings
          self.original_max_seq_len = config.max_position_embeddings
          self.config = config
          self.rope_type = self.config.rope_parameters["rope_type"]

          if self.rope_type == "default":
              rope_init_fn: Callable = self.compute_default_rope_parameters
          else:
              rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

          inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
          self.register_buffer("inv_freq", inv_freq, persistent=False)
          self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

      @staticmethod
      def compute_default_rope_parameters(
          config: GlmConfig | None = None,
          device: Optional["torch.device"] = None,
          seq_len: int | None = None,
      ) -> tuple["torch.Tensor", float]:
          """
          Compute the inverse frequencies of the original (unscaled) RoPE formulation.

          Args:
              config ([`~transformers.PreTrainedConfig`]):
                  Model configuration. Reads `rope_parameters["rope_theta"]`, the optional
                  `rope_parameters["partial_rotary_factor"]` (default 1.0), and `head_dim` (falling back
                  to `hidden_size // num_attention_heads` when missing or falsy).
              device (`torch.device`):
                  Device on which the inverse frequencies are created.
              seq_len (`int`, *optional*):
                  Current sequence length. Ignored by this RoPE type.
          Returns:
              Tuple of (`torch.Tensor`, `float`): the inverse frequencies, of length
              `ceil(rotary_dim / 2)`, and the cos/sin post-scaling factor, which is always 1.0 here.
          """
          rope_params = config.rope_parameters
          base = rope_params["rope_theta"]
          partial_rotary_factor = rope_params.get("partial_rotary_factor", 1.0)
          head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
          rotary_dim = int(head_dim * partial_rotary_factor)

          # inv_freq[i] = 1 / base^(2i / rotary_dim) for even indices 2i < rotary_dim.
          exponents = torch.arange(0, rotary_dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / rotary_dim
          inv_freq = 1.0 / (base**exponents)
          return inv_freq, 1.0  # scaling factor is unused by this RoPE type

      @torch.no_grad()
      @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
      def forward(self, x, position_ids):
          batch_size = position_ids.shape[0]
          # Column of inverse frequencies (batch, n_freq, 1) and row of positions (batch, 1, seq).
          freq_column = self.inv_freq[None, :, None].float().expand(batch_size, -1, 1).to(x.device)
          position_row = position_ids[:, None, :].float()

          device_type = x.device.type
          if not isinstance(device_type, str) or device_type == "mps":
              device_type = "cpu"

          with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
              # Outer product of frequencies and positions -> (batch, seq, n_freq).
              freqs = torch.matmul(freq_column, position_row).transpose(1, 2)
              emb = torch.cat((freqs, freqs), dim=-1)
              cos = emb.cos() * self.attention_scaling
              sin = emb.sin() * self.attention_scaling

          return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
      """
      Repeat every key/value head `n_rep` times along dimension 1.

      Equivalent to `torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)`. A tensor of shape
      (batch, num_key_value_heads, seqlen, head_dim) becomes
      (batch, num_key_value_heads * n_rep, seqlen, head_dim). When `n_rep == 1` the input object itself
      is returned.
      """
      batch, kv_heads, seq_len, head_dim = hidden_states.shape
      if n_rep != 1:
          # Insert a repeat axis right after the head axis, broadcast it, then fold it into the head axis.
          widened = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
          hidden_states = widened.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
      return hidden_states
  def eager_attention_forward(
      module: nn.Module,
      query: torch.Tensor,
      key: torch.Tensor,
      value: torch.Tensor,
      attention_mask: torch.Tensor | None,
      scaling: float,
      dropout: float = 0.0,
      **kwargs: Unpack[TransformersKwargs],
  ):
      """Plain PyTorch scaled dot-product attention used as the fallback attention backend.

      Key/value heads are first repeated `module.num_key_value_groups` times. `attention_mask`, when
      given, is added to the scaled scores before the softmax. NOTE(review): this presumes an additive
      mask, as produced by `create_causal_mask`; confirm for other callers.

      Returns:
          `(attn_output, attn_weights)`. `attn_output` has dims 1 and 2 swapped relative to
          `attn_weights @ value` and is made contiguous.
      """
      n_groups = module.num_key_value_groups
      keys = repeat_kv(key, n_groups)
      values = repeat_kv(value, n_groups)

      scores = torch.matmul(query, keys.transpose(2, 3)) * scaling
      if attention_mask is not None:
          scores = scores + attention_mask

      # Softmax is taken in float32, then cast back to the query dtype.
      probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
      probs = nn.functional.dropout(probs, p=dropout, training=module.training)

      context = torch.matmul(probs, values).transpose(1, 2).contiguous()
      return context, probs
  def rotate_half(x):
      """Rotate interleaved feature pairs of the last dimension by 90 degrees.

      Each adjacent pair `(x[2i], x[2i+1])` becomes `(-x[2i+1], x[2i])`. Despite the name, this is the
      interleaved rotation matching the `repeat_interleave`-d cos/sin tables, not a half-split. The last
      dimension must be even.
      """
      evens = x[..., ::2]
      odds = x[..., 1::2]
      return torch.stack((odds.neg(), evens), dim=-1).flatten(start_dim=-2)
  def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
      """Apply (possibly partial) interleaved rotary position embeddings to queries and keys.

      Args:
          q (`torch.Tensor`): The query tensor.
          k (`torch.Tensor`): The key tensor.
          cos (`torch.Tensor`): Cosine table, e.g. [batch_size, seq_len, dim].
          sin (`torch.Tensor`): Sine table with the same layout as `cos`.
          unsqueeze_dim (`int`, *optional*, defaults to 1):
              Axis at which `cos`/`sin` get a singleton dimension so they broadcast against `q`/`k`.
              Use 1 for q/k shaped [batch_size, heads, seq_len, head_dim] and 2 for
              [batch_size, seq_len, heads, head_dim].

      Only the first half of the cos/sin tables is used. Each entry is duplicated
      (`repeat_interleave(2)`) to match the interleaved pairs rotated by `rotate_half`. The resulting
      width sets how many leading features of q/k are rotated; the remaining features pass through
      unchanged.

      Returns:
          `tuple(torch.Tensor)`: the rotated query and key tensors, same shapes as the inputs.
      """
      cos = cos.unsqueeze(unsqueeze_dim)
      sin = sin.unsqueeze(unsqueeze_dim)
      # When produced by GlmRotaryEmbedding the tables are cat(freqs, freqs); keep one copy, interleaved.
      cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1)
      sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1)
      rotary_dim = cos.shape[-1]

      def _embed(t):
          # Rotate the leading `rotary_dim` features and pass the rest through untouched.
          rotated, passthrough = t[..., :rotary_dim], t[..., rotary_dim:]
          rotated = (rotated * cos) + (rotate_half(rotated) * sin)
          return torch.cat([rotated, passthrough], dim=-1)

      return _embed(q), _embed(k)
  @use_kernelized_func(apply_rotary_pos_emb)
  class GlmAttention(nn.Module):
      """Multi-headed attention from 'Attention Is All You Need' paper.

      Queries use `num_attention_heads` heads; keys/values use `num_key_value_heads` heads, which the
      attention backend repeats `num_key_value_groups` times. Rotary embeddings are applied to q/k
      before the optional KV-cache update.
      """

      def __init__(self, config: GlmConfig, layer_idx: int | None = None):
          super().__init__()
          self.config = config
          self.layer_idx = layer_idx
          # `head_dim` may be missing from the config *or* present but set to None. Fall back to
          # hidden_size // num_attention_heads in both cases. This is the same rule
          # GlmRotaryEmbedding.compute_default_rope_parameters uses, so both compute the same head_dim
          # (a bare getattr default would yield None here and break the projections below).
          self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
          self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
          self.scaling = self.head_dim**-0.5
          self.attention_dropout = config.attention_dropout
          self.is_causal = True

          self.q_proj = nn.Linear(
              config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
          )
          self.k_proj = nn.Linear(
              config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
          )
          self.v_proj = nn.Linear(
              config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
          )
          self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)

      def forward(
          self,
          hidden_states: torch.Tensor,
          position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
          attention_mask: torch.Tensor | None = None,
          past_key_values: Cache | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple[torch.Tensor, torch.Tensor]:
          """Attend over `hidden_states` and return `(attn_output, attn_weights)`.

          `position_embeddings` is the `(cos, sin)` pair from the rotary module and is required despite
          its `None` default. When `past_key_values` is given, the rotated keys/values are passed through
          `past_key_values.update(...)` for this layer before attention.
          """
          input_shape = hidden_states.shape[:-1]
          # Split the last dim into (num_heads, head_dim) and move heads before the sequence axis.
          hidden_shape = (*input_shape, -1, self.head_dim)

          query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
          key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
          value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

          cos, sin = position_embeddings
          query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

          if past_key_values is not None:
              key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

          attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
              self.config._attn_implementation, eager_attention_forward
          )

          attn_output, attn_weights = attention_interface(
              self,
              query_states,
              key_states,
              value_states,
              attention_mask,
              dropout=0.0 if not self.training else self.attention_dropout,
              scaling=self.scaling,
              **kwargs,
          )

          # Merge heads back into the feature axis before the output projection.
          attn_output = attn_output.reshape(*input_shape, -1).contiguous()
          attn_output = self.o_proj(attn_output)
          return attn_output, attn_weights
  @use_kernel_forward_from_hub("RMSNorm")
  class GlmRMSNorm(nn.Module):
      def __init__(self, hidden_size, eps: float = 1e-6) -> None:
          """
          Root-mean-square layer norm with a learned per-feature scale (equivalent to T5LayerNorm).
          """
          super().__init__()
          self.weight = nn.Parameter(torch.ones(hidden_size))
          self.variance_epsilon = eps

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          original_dtype = hidden_states.dtype
          # Normalise in float32, then cast back before applying the learned scale.
          upcast = hidden_states.to(torch.float32)
          mean_square = upcast.pow(2).mean(-1, keepdim=True)
          normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
          return self.weight * normalized.to(original_dtype)

      def extra_repr(self):
          shape = tuple(self.weight.shape)
          return f"{shape}, eps={self.variance_epsilon}"
  class GlmDecoderLayer(GradientCheckpointingLayer):
      """One pre-norm decoder block.

      The input is normalized, passed through self-attention, and added back as a residual. The result
      is normalized again, passed through the MLP, and added back as a second residual.
      """

      def __init__(self, config: GlmConfig, layer_idx: int):
          super().__init__()
          self.hidden_size = config.hidden_size
          # Submodules are created in this fixed order (attention, MLP, norms).
          self.self_attn = GlmAttention(config=config, layer_idx=layer_idx)
          self.mlp = GlmMLP(config)
          self.input_layernorm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
          self.post_attention_layernorm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          past_key_values: Cache | None = None,
          use_cache: bool | None = False,
          position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> torch.Tensor:
          # --- attention sub-block (attention weights are discarded) ---
          attn_input = self.input_layernorm(hidden_states)
          attn_out, _ = self.self_attn(
              hidden_states=attn_input,
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_values=past_key_values,
              use_cache=use_cache,
              position_embeddings=position_embeddings,
              **kwargs,
          )
          after_attn = hidden_states + attn_out

          # --- feed-forward sub-block ---
          mlp_out = self.mlp(self.post_attention_layernorm(after_attn))
          return after_attn + mlp_out
  @auto_docstring
  class GlmPreTrainedModel(PreTrainedModel):
      # Class-level flags and settings read by the `PreTrainedModel` machinery
      # (their semantics are defined upstream, not in this file).
      config: GlmConfig
      base_model_prefix = "model"

      _no_split_modules = ["GlmDecoderLayer"]
      _skip_keys_device_placement = ["past_key_values"]

      supports_gradient_checkpointing = True
      _can_compile_fullgraph = True

      # Attention implementations this model declares support for.
      _supports_flash_attn = True
      _supports_sdpa = True
      _supports_flex_attn = True
      _supports_attention_backend = True

      # Output name -> module class whose outputs are recorded for it
      # (NOTE(review): presumably consumed by `capture_outputs`; confirm upstream).
      _can_record_outputs = {
          "hidden_states": GlmDecoderLayer,
          "attentions": GlmAttention,
      }
  @auto_docstring
  class GlmModel(GlmPreTrainedModel):
      def __init__(self, config: GlmConfig):
          super().__init__(config)
          self.padding_idx = config.pad_token_id
          self.vocab_size = config.vocab_size

          # Submodules are created in this fixed order: embedding, decoder stack, final norm, rotary table.
          self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
          self.layers = nn.ModuleList(
              GlmDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)
          )
          self.norm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
          self.rotary_emb = GlmRotaryEmbedding(config=config)
          self.gradient_checkpointing = False

          # Initialize weights and apply final processing
          self.post_init()

      @merge_with_config_defaults
      @capture_outputs
      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          past_key_values: Cache | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          use_cache: bool | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> BaseModelOutputWithPast:
          # Exactly one of `input_ids` / `inputs_embeds` must be supplied.
          if (input_ids is None) == (inputs_embeds is None):
              raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

          if inputs_embeds is None:
              inputs_embeds = self.embed_tokens(input_ids)

          if use_cache and past_key_values is None:
              past_key_values = DynamicCache(config=self.config)

          if position_ids is None:
              # New positions continue from however many tokens the cache already holds.
              offset = 0 if past_key_values is None else past_key_values.get_seq_length()
              seq_positions = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + offset
              position_ids = seq_positions.unsqueeze(0)

          causal_mask = create_causal_mask(
              config=self.config,
              inputs_embeds=inputs_embeds,
              attention_mask=attention_mask,
              past_key_values=past_key_values,
              position_ids=position_ids,
          )

          hidden_states = inputs_embeds
          # The rotary tables are computed once and shared by every decoder layer.
          position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)

          for layer in self.layers[: self.config.num_hidden_layers]:
              hidden_states = layer(
                  hidden_states,
                  attention_mask=causal_mask,
                  position_embeddings=position_embeddings,
                  position_ids=position_ids,
                  past_key_values=past_key_values,
                  use_cache=use_cache,
                  **kwargs,
              )

          return BaseModelOutputWithPast(
              last_hidden_state=self.norm(hidden_states),
              past_key_values=past_key_values,
          )
  @auto_docstring
  class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin):
      _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
      _tp_plan = {"lm_head": "colwise_gather_output"}
      _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

      def __init__(self, config):
          super().__init__(config)
          self.model = GlmModel(config)
          self.vocab_size = config.vocab_size
          # Projects final hidden states to vocabulary logits (no bias).
          self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

          # Initialize weights and apply final processing
          self.post_init()

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          past_key_values: Cache | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          labels: torch.LongTensor | None = None,
          use_cache: bool | None = None,
          logits_to_keep: int | torch.Tensor = 0,
          **kwargs: Unpack[TransformersKwargs],
      ) -> CausalLMOutputWithPast:
          r"""
          Example:

          ```python
          >>> from transformers import AutoTokenizer, GlmForCausalLM

          >>> model = GlmForCausalLM.from_pretrained("THUDM/glm-4-9b-chat-hf")
          >>> tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat-hf")

          >>> prompt = "Hey, are you conscious? Can you talk to me?"
          >>> inputs = tokenizer(prompt, return_tensors="pt")

          >>> # Generate
          >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
          >>> text = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
          ```"""
          outputs: BaseModelOutputWithPast = self.model(
              input_ids=input_ids,
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_values=past_key_values,
              inputs_embeds=inputs_embeds,
              use_cache=use_cache,
              **kwargs,
          )

          hidden_states = outputs.last_hidden_state
          # Only compute necessary logits, and do not upcast them to float if we are not computing the loss.
          # An int keeps the last `logits_to_keep` positions (0 -> slice(0, None), i.e. every position);
          # a tensor is used directly as the index over the sequence axis.
          slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
          logits = self.lm_head(hidden_states[:, slice_indices, :])

          loss = None
          if labels is not None:
              loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

          return CausalLMOutputWithPast(
              loss=loss,
              logits=logits,
              past_key_values=outputs.past_key_values,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
          )
  class GlmForSequenceClassification(GenericForSequenceClassification, GlmPreTrainedModel):
      # No GLM-specific code: all behaviour comes from the two base classes.
      ...
  class GlmForTokenClassification(GenericForTokenClassification, GlmPreTrainedModel):
      # No GLM-specific code: all behaviour comes from the two base classes.
      ...
  # Public names exported by this module.
  # NOTE(review): keep one quoted name per line and no comments inside the brackets; packaging tooling
  # may parse this list textually - confirm before reformatting.
  __all__ = [
      "GlmPreTrainedModel",
      "GlmModel",
      "GlmForCausalLM",
      "GlmForSequenceClassification",
      "GlmForTokenClassification",
  ]