configuration_bamba.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. # Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Bamba model configuration"""
  15. from huggingface_hub.dataclasses import strict
  16. from ...configuration_utils import PreTrainedConfig
  17. from ...modeling_rope_utils import RopeParameters
  18. from ...utils import auto_docstring
  19. @strict
  20. @auto_docstring(
  21. custom_intro="""
  22. The BambaModel is a hybrid [mamba2](https://github.com/state-spaces/mamba) architecture with SwiGLU.
  23. The checkpoints are jointly trained by IBM, Princeton, and UIUC.
  24. """,
  25. checkpoint="ibm-fms/Bamba-9.8b-2.2T-hf",
  26. )
  27. class BambaConfig(PreTrainedConfig):
  28. r"""
  29. num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
  30. Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an
  31. integer value, only last `num_logits_to_keep` logits will be calculated. Default is 1 because only the
  32. logits of the last prompt token are needed for generation. For long sequences, the logits for the entire
  33. sequence may use a lot of memory so, setting `num_logits_to_keep=1` will reduce memory footprint
  34. significantly.
  35. attn_layer_indices (`list`, *optional*):
  36. Specifies the layer indices that will have full attention. Must contain values at most num_hidden_layers.
  37. z_loss_coefficient (`float`, *optional*, defaults to 0.0):
  38. Coefficient for auxiliary z-loss used to control logit growth during training
  39. """
  40. model_type = "bamba"
  41. attribute_map = {"layer_types": "layers_block_type"}
  42. keys_to_ignore_at_inference = ["past_key_values"]
  43. vocab_size: int = 128000
  44. tie_word_embeddings: bool = False
  45. hidden_size: int = 4096
  46. intermediate_size: int = 14336
  47. num_hidden_layers: int = 32
  48. num_attention_heads: int = 32
  49. num_key_value_heads: int | None = 8
  50. hidden_act: str = "silu"
  51. initializer_range: float = 0.02
  52. rms_norm_eps: float = 1e-5
  53. use_cache: bool = True
  54. num_logits_to_keep: int | None = 1
  55. pad_token_id: int | None = 0
  56. bos_token_id: int | None = 1
  57. eos_token_id: int | list[int] | None = 2
  58. max_position_embeddings: int = 262144
  59. attention_dropout: float | int | None = 0.0
  60. attn_layer_indices: list[int] | None = None
  61. mamba_n_heads: int | None = 128
  62. mamba_d_head: str | int | None = "auto"
  63. mamba_n_groups: int | None = 1
  64. mamba_d_state: int | None = 256
  65. mamba_d_conv: int | None = 4
  66. mamba_expand: int | None = 2
  67. mamba_chunk_size: int | None = 256
  68. mamba_conv_bias: bool | None = True
  69. mamba_proj_bias: bool | None = False
  70. time_step_min: float | None = 0.001
  71. time_step_max: float | None = 0.1
  72. time_step_limit: list[float] | tuple[float, float] | None = (0.0, float("inf"))
  73. z_loss_coefficient: float | None = 0.0
  74. rope_parameters: RopeParameters | dict | None = None
  75. attention_bias: bool = False
  76. mlp_bias: bool = False
  77. def __post_init__(self, **kwargs):
  78. # for backward compatibility
  79. if self.num_key_value_heads is None:
  80. self.num_key_value_heads = self.num_attention_heads
  81. # for the mamba_v2, must satisfy the following
  82. if self.mamba_d_head == "auto":
  83. self.mamba_d_head = self.mamba_expand * self.hidden_size // self.mamba_n_heads
  84. self.time_step_limit = tuple(self.time_step_limit) if self.time_step_limit is not None else None
  85. kwargs["partial_rotary_factor"] = 0.5 # hardcode for BC
  86. super().__post_init__(**kwargs)
  87. @property
  88. def layers_block_type(self):
  89. return [
  90. "attention" if (self.attn_layer_indices and i in self.attn_layer_indices) else "mamba"
  91. for i in range(self.num_hidden_layers)
  92. ]
  93. def validate_architecture(self):
  94. """Part of `@strict`-powered validation. Validates the architecture of the config."""
  95. mamba_intermediate = self.mamba_expand * self.hidden_size
  96. if mamba_intermediate % self.mamba_n_heads != 0:
  97. raise ValueError("mamba_n_heads must divide mamba_expand * hidden_size")
  98. if self.mamba_d_head * self.mamba_n_heads != mamba_intermediate:
  99. raise ValueError("The dimensions for the Mamba head state do not match the model intermediate_size")
  100. __all__ = ["BambaConfig"]