modular_glm4.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. # Copyright 2025 The GLM4 & ZhipuAI team and HuggingFace Inc. team. All rights reserved.
  2. #
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import torch
  16. from ...cache_utils import Cache
  17. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  18. from ...modeling_layers import GradientCheckpointingLayer
  19. from ...modeling_outputs import CausalLMOutputWithPast
  20. from ...processing_utils import Unpack
  21. from ...utils import TransformersKwargs, logging
  22. from ..glm.modeling_glm import GlmAttention, GlmForCausalLM, GlmForSequenceClassification, GlmForTokenClassification
  23. from ..phi3.modeling_phi3 import Phi3MLP
  24. from .configuration_glm4 import Glm4Config
  25. from .modeling_glm4 import Glm4RMSNorm
  26. logger = logging.get_logger(__name__)
  27. _CHECKPOINT_FOR_DOC = "THUDM/GLM-4-9B-0414"
  28. class Glm4MLP(Phi3MLP):
  29. pass
  30. class Glm4DecoderLayer(GradientCheckpointingLayer):
  31. def __init__(self, config: Glm4Config, layer_idx: int):
  32. super().__init__()
  33. self.hidden_size = config.hidden_size
  34. self.self_attn = Glm4Attention(config=config, layer_idx=layer_idx)
  35. self.mlp = Glm4MLP(config)
  36. self.input_layernorm = Glm4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  37. self.post_attention_layernorm = Glm4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  38. self.post_self_attn_layernorm = Glm4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  39. self.post_mlp_layernorm = Glm4RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  40. def forward(
  41. self,
  42. hidden_states: torch.Tensor,
  43. attention_mask: torch.Tensor | None = None,
  44. position_ids: torch.LongTensor | None = None,
  45. past_key_values: Cache | None = None,
  46. use_cache: bool | None = False,
  47. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  48. **kwargs: Unpack[FlashAttentionKwargs],
  49. ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
  50. residual = hidden_states
  51. hidden_states = self.input_layernorm(hidden_states)
  52. hidden_states, _ = self.self_attn(
  53. hidden_states=hidden_states,
  54. attention_mask=attention_mask,
  55. position_ids=position_ids,
  56. past_key_values=past_key_values,
  57. use_cache=use_cache,
  58. position_embeddings=position_embeddings,
  59. **kwargs,
  60. )
  61. hidden_states = self.post_self_attn_layernorm(hidden_states)
  62. hidden_states = residual + hidden_states
  63. residual = hidden_states
  64. hidden_states = self.post_attention_layernorm(hidden_states)
  65. hidden_states = self.mlp(hidden_states)
  66. hidden_states = self.post_mlp_layernorm(hidden_states)
  67. hidden_states = residual + hidden_states
  68. return hidden_states
  69. class Glm4Attention(GlmAttention):
  70. pass
class Glm4ForCausalLM(GlmForCausalLM):
    """GLM-4 causal language model.

    Behavior is inherited from `GlmForCausalLM`. `forward` is overridden only to attach
    GLM-4-specific documentation; it delegates straight to the parent implementation.
    """

    def forward(
        self,
        **super_kwargs: Unpack[TransformersKwargs],
    ) -> tuple | CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Glm4ForCausalLM
        >>> model = Glm4ForCausalLM.from_pretrained("THUDM/GLM-4-9B-0414")
        >>> tokenizer = AutoTokenizer.from_pretrained("THUDM/GLM-4-9B-0414")
        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")
        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # NOTE(review): `**super_kwargs` appears to follow the modular-transformers convention, where the
        # converter expands it into the parent's full signature in the generated modeling file — confirm
        # before changing this signature or the delegating call below.
        return super().forward(**super_kwargs)
  94. class Glm4ForSequenceClassification(GlmForSequenceClassification):
  95. pass
  96. class Glm4ForTokenClassification(GlmForTokenClassification):
  97. pass
# Public API of this module. `Glm4PreTrainedModel` and `Glm4Model` are not defined in this module,
# which is why flake8's F822 ("undefined name in __all__") is suppressed for them; presumably the
# modular converter generates them in the final modeling file — confirm there.
__all__ = [
    "Glm4PreTrainedModel",  # noqa: F822
    "Glm4Model",  # noqa: F822
    "Glm4ForCausalLM",
    "Glm4ForSequenceClassification",
    "Glm4ForTokenClassification",
]