configuration_t5gemma.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/t5gemma/modular_t5gemma.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_t5gemma.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
  8. #
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. from typing import Any
  22. from huggingface_hub.dataclasses import strict
  23. from ...configuration_utils import PreTrainedConfig
  24. from ...modeling_rope_utils import RopeParameters
  25. from ...utils import auto_docstring
  26. @auto_docstring(checkpoint="google/t5_gemma_module-7b")
  27. @strict
  28. class T5GemmaModuleConfig(PreTrainedConfig):
  29. r"""
  30. query_pre_attn_scalar (`float`, *optional*, defaults to 256):
  31. scaling factor used on the attention scores
  32. final_logit_softcapping (`float`, *optional*, defaults to 30.0):
  33. scaling factor when applying tanh softcapping on the logits.
  34. attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
  35. scaling factor when applying tanh softcapping on the attention scores.
  36. ```python
  37. >>> from transformers import T5GemmaModuleModel, T5GemmaModuleConfig
  38. >>> # Initializing a T5GemmaModule t5_gemma_module-7b style configuration
  39. >>> configuration = T5GemmaModuleConfig()
  40. >>> # Initializing a model from the t5_gemma_module-7b style configuration
  41. >>> model = T5GemmaModuleModel(configuration)
  42. >>> # Accessing the model configuration
  43. >>> configuration = model.config
  44. ```"""
  45. model_type = "t5_gemma_module"
  46. keys_to_ignore_at_inference = ["past_key_values"]
  47. base_model_tp_plan = {
  48. "layers.*.self_attn.q_proj": "colwise",
  49. "layers.*.self_attn.k_proj": "colwise",
  50. "layers.*.self_attn.v_proj": "colwise",
  51. "layers.*.self_attn.o_proj": "rowwise",
  52. "layers.*.mlp.gate_proj": "colwise",
  53. "layers.*.mlp.up_proj": "colwise",
  54. "layers.*.mlp.down_proj": "rowwise",
  55. }
  56. base_model_pp_plan = {
  57. "embed_tokens": (["input_ids"], ["inputs_embeds"]),
  58. "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
  59. "norm": (["hidden_states"], ["hidden_states"]),
  60. }
  61. vocab_size: int = 256000
  62. hidden_size: int = 2304
  63. intermediate_size: int = 9216
  64. num_hidden_layers: int = 26
  65. num_attention_heads: int = 8
  66. num_key_value_heads: int = 4
  67. head_dim: int = 256
  68. hidden_activation: str = "gelu_pytorch_tanh"
  69. max_position_embeddings: int = 8192
  70. initializer_range: float = 0.02
  71. rms_norm_eps: float = 1e-6
  72. use_cache: bool = True
  73. pad_token_id: int | None = 0
  74. eos_token_id: int | list[int] | None = 1
  75. bos_token_id: int | None = 2
  76. tie_word_embeddings: bool = True
  77. rope_parameters: RopeParameters | dict | None = None
  78. attention_bias: bool = False
  79. attention_dropout: int | float | None = 0.0
  80. query_pre_attn_scalar: int = 256
  81. sliding_window: int | None = 4096
  82. layer_types: list[str] | None = None
  83. final_logit_softcapping: float | None = 30.0
  84. attn_logit_softcapping: float | None = 50.0
  85. is_decoder: bool = False
  86. def __post_init__(self, **kwargs):
  87. if self.layer_types is None:
  88. self.layer_types = [
  89. "sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers)
  90. ]
  91. super().__post_init__(**kwargs)
  92. def validate_architecture(self):
  93. """Part of `@strict`-powered validation. Validates the architecture of the config."""
  94. if self.hidden_size % self.num_attention_heads != 0:
  95. raise ValueError(
  96. f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
  97. f"heads ({self.num_attention_heads})."
  98. )
  99. @auto_docstring(checkpoint="google/t5_gemma_module-7b")
  100. @strict
  101. class T5GemmaConfig(PreTrainedConfig):
  102. r"""
  103. encoder (`Union[T5GemmaModuleConfig, dict]`, optional, *optional*):
  104. Configuration for the encoder.
  105. decoder (`Union[T5GemmaModuleConfig, dict]`, optional, *optional*):
  106. Configuration for the decoder.
  107. Example:
  108. ```python
  109. >>> from transformers import T5GemmaConfig, T5GemmaModel
  110. >>> t5gemma_config = T5GemmaConfig.from_pretrained("google/t5gemma-2b-2b-prefixlm-it")
  111. >>> model = T5GemmaModel(t5gemma_config)
  112. ```"""
  113. model_type = "t5gemma"
  114. keys_to_ignore_at_inference = ["past_key_values"]
  115. sub_configs = {"encoder": T5GemmaModuleConfig, "decoder": T5GemmaModuleConfig}
  116. encoder: T5GemmaModuleConfig | dict[Any, Any] | None = None
  117. decoder: T5GemmaModuleConfig | dict[Any, Any] | None = None
  118. is_encoder_decoder: bool = True
  119. dropout_rate: int | float = 0.0
  120. classifier_dropout_rate: int | float = 0.0
  121. attention_dropout: float | int = 0.0
  122. tie_word_embeddings: bool = True
  123. vocab_size: int = 256000
  124. def __post_init__(self, **kwargs):
  125. if isinstance(self.encoder, dict):
  126. self.encoder = T5GemmaModuleConfig(**self.encoder)
  127. elif self.encoder is None:
  128. self.encoder = T5GemmaModuleConfig()
  129. if isinstance(self.decoder, dict):
  130. self.decoder = T5GemmaModuleConfig(**self.decoder)
  131. elif self.decoder is None:
  132. self.decoder = T5GemmaModuleConfig()
  133. self.encoder.is_decoder = False
  134. self.encoder.dropout_rate = self.dropout_rate
  135. self.encoder.attention_dropout = self.attention_dropout
  136. self.decoder.is_decoder = True
  137. self.decoder.use_cache = True
  138. self.decoder.dropout_rate = self.dropout_rate
  139. self.decoder.attention_dropout = self.attention_dropout
  140. self.decoder.cross_attention_hidden_size = self.encoder.hidden_size
  141. self.initializer_range = kwargs.pop("initializer_range", self.decoder.initializer_range)
  142. for special_token_key in ["bos_token_id", "pad_token_id", "eos_token_id"]:
  143. if special_token_key not in kwargs:
  144. kwargs[special_token_key] = getattr(self.decoder, special_token_key)
  145. super().__post_init__(**kwargs)
# Public names of this module: the two configuration classes defined above.
__all__ = ["T5GemmaConfig", "T5GemmaModuleConfig"]