configuration_llama4.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. # Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved.
  2. #
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from huggingface_hub.dataclasses import strict
  16. from ...configuration_utils import PreTrainedConfig
  17. from ...modeling_rope_utils import RopeParameters
  18. from ...utils import auto_docstring, logging
# Module-level logger; used below to report when a default sub-config is substituted.
logger = logging.get_logger(__name__)
  20. @auto_docstring(checkpoint="meta-llama/Llama-4-Scout-17B-16E")
  21. @strict
  22. class Llama4VisionConfig(PreTrainedConfig):
  23. r"""
  24. vision_output_dim (`int`, *optional*, defaults to 7680):
  25. Dimensionality of the vision model output. Includes output of transformer
  26. encoder with intermediate layers and global transformer encoder.
  27. pixel_shuffle_ratio (`float`, *optional*, defaults to 0.5):
  28. Pixel-shuffle ratio for downsampling patch tokens. Smaller values produce fewer tokens (more downsampling).
  29. projector_input_dim (`int`, *optional*, defaults to 4096):
  30. Width of the vision adapter MLP before pixel shuffle. Larger value increases capacity and compute.
  31. projector_output_dim (`int`, *optional*, defaults to 4096):
  32. Output width of the vision adapter. Larger value yields higher-dimensional image features.
  33. projector_dropout (`float`, *optional*, defaults to 0.0):
  34. Dropout rate inside the vision adapter MLP. Higher value adds more regularization.
  35. """
  36. base_model_tp_plan = {
  37. "model.layers.*.self_attn.q_proj": "colwise",
  38. "model.layers.*.self_attn.k_proj": "colwise",
  39. "model.layers.*.self_attn.v_proj": "colwise",
  40. "model.layers.*.self_attn.o_proj": "rowwise",
  41. "vision_adapter.mlp.fc1": "colwise",
  42. "vision_adapter.mlp.fc2": "rowwise",
  43. "patch_embedding.linear": "colwise_gather_output",
  44. }
  45. model_type = "llama4_vision_model"
  46. base_config_key = "vision_config"
  47. hidden_size: int = 768
  48. hidden_act: str = "gelu"
  49. num_hidden_layers: int = 34
  50. num_attention_heads: int = 16
  51. num_channels: int = 3
  52. intermediate_size: int = 5632
  53. vision_output_dim: int = 7680
  54. image_size: int | list[int] | tuple[int, int] = 448
  55. patch_size: int | list[int] | tuple[int, int] = 14
  56. norm_eps: float = 1e-5
  57. vision_feature_select_strategy: str = "default"
  58. initializer_range: float = 0.02
  59. pixel_shuffle_ratio: float = 0.5
  60. projector_input_dim: int = 4096
  61. projector_output_dim: int = 4096
  62. multi_modal_projector_bias: bool = False
  63. projector_dropout: float | int = 0.0
  64. attention_dropout: float | int = 0.0
  65. rope_parameters: RopeParameters | dict | None = None
  66. @auto_docstring(checkpoint="meta-llama/Llama-4-Scout-17B-16E")
  67. @strict
  68. class Llama4TextConfig(PreTrainedConfig):
  69. r"""
  70. intermediate_size_mlp (`int`, *optional*, defaults to 16384):
  71. Intermediate size of dense MLP layers. Larger value increases FFN capacity and compute.
  72. moe_layers (`list[int]`, *optional*):
  73. List of layer indices that use MoE. Overrides `interleave_moe_layer_step` when set.
  74. interleave_moe_layer_step (`int`, *optional*, defaults to 1):
  75. Spacing between MoE layers when `moe_layers` is `None`. Larger value means fewer MoE layers.
  76. use_qk_norm (`bool`, *optional*, defaults to `True`):
  77. Whether to L2-normalize queries/keys on RoPE layers. Can stabilize attention when enabled.
  78. no_rope_layers (`list[int]`, *optional*):
  79. List with at least the same length as the number of layers in the model.
  80. A `1` at an index position indicates that the corresponding layer will use RoPE,
  81. while a `0` indicates that it's a NoPE layer.
  82. no_rope_layer_interval (`int`, *optional*, defaults to 4):
  83. If `no_rope_layers` is `None`, it will be created using a NoPE layer every
  84. `no_rope_layer_interval` layers.
  85. attention_chunk_size (`int`, *optional*, defaults to 8192):
  86. Chunk size for the attention computation. Smaller value enforces more local attention and lowers memory.
  87. attn_temperature_tuning (`bool`, *optional*, defaults to `True`):
  88. Whether to dynamically scale the attention temperature for each query token based on sequence length.
  89. Recommended for long sequences (e.g., >32k tokens) to maintain stable output results.
  90. floor_scale (`int`, *optional*, defaults to 8192):
  91. Base scale (in tokens) for attention temperature tuning. Larger value delays scaling to longer positions.
  92. attn_scale (`float`, *optional*, defaults to 0.1):
  93. Strength of attention temperature tuning. Larger value increases scaling at long positions.
  94. Example:
  95. """
  96. model_type = "llama4_text"
  97. keys_to_ignore_at_inference = ["past_key_values"]
  98. default_theta = 500000.0
  99. base_model_tp_plan = {
  100. "layers.*.self_attn.q_proj": "colwise",
  101. "layers.*.self_attn.k_proj": "colwise",
  102. "layers.*.self_attn.v_proj": "colwise",
  103. "layers.*.self_attn.o_proj": "rowwise",
  104. "layers.*.feed_forward.shared_expert.gate_proj": "colwise",
  105. "layers.*.feed_forward.shared_expert.up_proj": "colwise",
  106. "layers.*.feed_forward.shared_expert.down_proj": "rowwise",
  107. "layers.*.feed_forward.experts.gate_up_proj": "packed_rowwise", # row because not linear
  108. "layers.*.feed_forward.experts.down_proj": "colwise", # col because not linear
  109. "layers.*.feed_forward.gate_proj": "colwise",
  110. "layers.*.feed_forward.up_proj": "colwise",
  111. "layers.*.feed_forward.down_proj": "rowwise",
  112. }
  113. base_model_ep_plan = {
  114. "layers.*.self_attn.q_proj": "colwise",
  115. "layers.*.self_attn.k_proj": "colwise",
  116. "layers.*.self_attn.v_proj": "colwise",
  117. "layers.*.self_attn.o_proj": "rowwise",
  118. "layers.*.feed_forward.experts.gate_up_proj": "grouped_gemm", # row because not linear
  119. "layers.*.feed_forward.experts.down_proj": "grouped_gemm", # col because not linear
  120. "layers.*.feed_forward.gate_proj": "colwise",
  121. "layers.*.feed_forward.up_proj": "colwise",
  122. "layers.*.feed_forward.down_proj": "rowwise",
  123. "layers.*.feed_forward.router": "ep_router",
  124. }
  125. vocab_size: int = 202048
  126. hidden_size: int = 5120
  127. intermediate_size: int = 8192
  128. intermediate_size_mlp: int = 16384
  129. num_hidden_layers: int = 48
  130. num_attention_heads: int = 40
  131. num_key_value_heads: int = 8
  132. head_dim: int = 128
  133. hidden_act: str = "silu"
  134. max_position_embeddings: int = 4096 * 32
  135. initializer_range: float = 0.02
  136. rms_norm_eps: float = 1e-5
  137. use_cache: bool = True
  138. pad_token_id: int | None = None
  139. bos_token_id: int | None = 1
  140. eos_token_id: int | list[int] | None = 2
  141. tie_word_embeddings: bool = False
  142. attention_dropout: float | int = 0.0
  143. num_experts_per_tok: int = 1
  144. num_local_experts: int = 16
  145. moe_layers: list[int] | None = None
  146. interleave_moe_layer_step: int = 1
  147. use_qk_norm: bool = True
  148. output_router_logits: bool = False
  149. router_aux_loss_coef: float = 0.001
  150. router_jitter_noise: float = 0.0
  151. rope_parameters: RopeParameters | dict | None = None
  152. no_rope_layers: list[int] | None = None
  153. no_rope_layer_interval: int = 4
  154. attention_chunk_size: int | None = 8192
  155. layer_types: list[str] | None = None
  156. attn_temperature_tuning: bool = True
  157. floor_scale: int = 8192
  158. attn_scale: float = 0.1
  159. attention_bias: bool = False
  160. def __post_init__(self, **kwargs):
  161. if self.num_key_value_heads is None:
  162. self.num_key_value_heads = self.num_attention_heads
  163. default_no_rope_layers = [
  164. int((layer_idx + 1) % self.no_rope_layer_interval != 0) for layer_idx in range(self.num_hidden_layers)
  165. ]
  166. self.no_rope_layers = self.no_rope_layers if self.no_rope_layers else default_no_rope_layers
  167. self.head_dim = self.head_dim if self.head_dim is not None else self.hidden_size // self.num_attention_heads
  168. self.moe_layers = (
  169. self.moe_layers
  170. if self.moe_layers is not None
  171. else list(
  172. range(
  173. self.interleave_moe_layer_step - 1,
  174. self.num_hidden_layers,
  175. self.interleave_moe_layer_step,
  176. )
  177. )
  178. )
  179. if self.layer_types is None:
  180. self.layer_types = [
  181. "chunked_attention" if no_rope else "full_attention" for no_rope in self.no_rope_layers
  182. ]
  183. super().__post_init__(**kwargs)
  184. @auto_docstring(checkpoint="meta-llama/Llama-4-Scout-17B-16E")
  185. @strict
  186. class Llama4Config(PreTrainedConfig):
  187. r"""
  188. boi_token_index (`int`, *optional*, defaults to 200080):
  189. The begin-of-image token index to wrap the image prompt.
  190. eoi_token_index (`int`, *optional*, defaults to 200081):
  191. The end-of-image token index to wrap the image prompt.
  192. ```python
  193. >>> from transformers import Llama4Model, Llama4Config
  194. >>> # Initializing a Llama4 7B style configuration
  195. >>> configuration = Llama4Config()
  196. >>> # Initializing a model from the Llama4 7B style configuration
  197. >>> model = Llama4Model(configuration)
  198. >>> # Accessing the model configuration
  199. >>> configuration = model.config
  200. ```
  201. """
  202. model_type = "llama4"
  203. attribute_map = {
  204. "image_token_id": "image_token_index",
  205. "boi_token_id": "boi_token_index",
  206. "eoi_token_id": "eoi_token_index",
  207. }
  208. sub_configs = {"text_config": Llama4TextConfig, "vision_config": Llama4VisionConfig}
  209. base_model_tp_plan = {
  210. "multi_modal_projector.linear_1": "colwise_rep",
  211. }
  212. vision_config: dict | PreTrainedConfig | None = None
  213. text_config: dict | PreTrainedConfig | None = None
  214. boi_token_index: int = 200080
  215. eoi_token_index: int = 200081
  216. image_token_index: int = 200092
  217. tie_word_embeddings: bool = False
  218. def __post_init__(self, **kwargs):
  219. if self.vision_config is None:
  220. self.vision_config = Llama4VisionConfig()
  221. logger.info("vision_config is None, using default llama4 vision config")
  222. elif isinstance(self.vision_config, dict):
  223. self.vision_config = Llama4VisionConfig(**self.vision_config)
  224. if self.text_config is None:
  225. self.text_config = Llama4TextConfig()
  226. logger.info("text_config is None, using default llama4 text config")
  227. elif isinstance(self.text_config, dict):
  228. self.text_config = Llama4TextConfig(**self.text_config)
  229. super().__post_init__(**kwargs)
# Names exported by `from <this module> import *`: the three config classes defined above.
__all__ = ["Llama4Config", "Llama4TextConfig", "Llama4VisionConfig"]