configuration_jamba.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. # Copyright 2024 AI21 Labs Ltd. and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Jamba model configuration"""
  15. import math
  16. from huggingface_hub.dataclasses import strict
  17. from ...configuration_utils import PreTrainedConfig
  18. from ...utils import auto_docstring
  @auto_docstring(checkpoint="ai21labs/Jamba-v0.1")
  @strict
  class JambaConfig(PreTrainedConfig):
      r"""
      expert_layer_period (`int`, *optional*, defaults to 2):
          Period, in layers, of the expert (MoE) layers: layer `i` is an expert layer when
          `i % expert_layer_period == expert_layer_offset`. Must be a positive integer.
      expert_layer_offset (`int`, *optional*, defaults to 1):
          The first layer index that contains an expert mlp layer. Must satisfy
          `0 <= expert_layer_offset < expert_layer_period`.
      attn_layer_period (`int`, *optional*, defaults to 8):
          Period, in layers, of the vanilla attention layers: layer `i` is an attention layer when
          `i % attn_layer_period == attn_layer_offset`; every other layer is a mamba layer. Must be a positive integer.
      attn_layer_offset (`int`, *optional*, defaults to 4):
          The first layer index that contains a vanilla attention layer. Must satisfy
          `0 <= attn_layer_offset < attn_layer_period`.
      use_mamba_kernels (`bool`, *optional*, defaults to `True`):
          Flag indicating whether or not to use the fast mamba kernels. These are available only if `mamba-ssm` and
          `causal-conv1d` are installed, and the mamba modules are running on a CUDA device. Raises ValueError if
          `True` and kernels are not available
      mamba_dt_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
          Rank of the mamba discretization projection matrix. `"auto"` means that it will default to
          `math.ceil(hidden_size / 16)`; any other string is rejected.
      """

      model_type = "jamba"
      keys_to_ignore_at_inference = ["past_key_values"]
      attribute_map = {
          "num_local_experts": "num_experts",
      }

      vocab_size: int = 65536
      tie_word_embeddings: bool = False
      hidden_size: int = 4096
      intermediate_size: int = 14336
      num_hidden_layers: int = 32
      num_attention_heads: int = 32
      # `None` is accepted so that `__post_init__` can fall back to `num_attention_heads`; with an `int`-only
      # annotation the strict field validation would reject `None` before that fallback could run.
      num_key_value_heads: int | None = 8
      hidden_act: str = "silu"
      initializer_range: float = 0.02
      rms_norm_eps: float = 1e-6
      use_cache: bool = True
      output_router_logits: bool = False
      router_aux_loss_coef: float = 0.001
      pad_token_id: int | None = 0
      bos_token_id: int | None = 1
      eos_token_id: int | list[int] | None = 2
      max_position_embeddings: int = 262144
      attention_dropout: float | int = 0.0
      num_experts_per_tok: int = 2
      num_experts: int = 16
      expert_layer_period: int = 2
      expert_layer_offset: int = 1
      attn_layer_period: int = 8
      attn_layer_offset: int = 4
      use_mamba_kernels: bool = True
      mamba_d_state: int = 16
      mamba_d_conv: int = 4
      mamba_expand: int = 2
      mamba_dt_rank: int | str = "auto"
      mamba_conv_bias: bool = True
      mamba_proj_bias: bool = False

      def __post_init__(self, **kwargs):
          """Resolve derived defaults, then defer to `PreTrainedConfig.__post_init__`.

          - `num_key_value_heads=None` is replaced by `num_attention_heads`.
          - `mamba_dt_rank="auto"` is replaced by `math.ceil(hidden_size / 16)`.

          Raises:
              ValueError: if `mamba_dt_rank` is a string other than `"auto"`.
          """
          if self.num_key_value_heads is None:
              self.num_key_value_heads = self.num_attention_heads
          if self.mamba_dt_rank == "auto":
              self.mamba_dt_rank = math.ceil(self.hidden_size / 16)
          elif isinstance(self.mamba_dt_rank, str):
              # The annotation admits `str` only for the "auto" sentinel; any other string would otherwise be
              # kept as-is and used where an integer rank is expected.
              raise ValueError(f'`mamba_dt_rank` must be an int or "auto", got {self.mamba_dt_rank!r}')
          super().__post_init__(**kwargs)

      @property
      def layers_block_type(self):
          """Per-layer mixer type: `"attention"` when `i % attn_layer_period == attn_layer_offset`, else `"mamba"`."""
          return [
              "attention" if i % self.attn_layer_period == self.attn_layer_offset else "mamba"
              for i in range(self.num_hidden_layers)
          ]

      @property
      def layer_types(self):
          """`layers_block_type` with `"attention"` renamed to `"full_attention"`."""
          # Follow the `layer_types` conventions
          layer_types = self.layers_block_type
          return ["full_attention" if x == "attention" else x for x in layer_types]

      @property
      def layers_num_experts(self):
          """Per-layer expert count: `num_experts` when `i % expert_layer_period == expert_layer_offset`, else 1."""
          return [
              self.num_experts if i % self.expert_layer_period == self.expert_layer_offset else 1
              for i in range(self.num_hidden_layers)
          ]

      def validate_architecture(self):
          """Part of `@strict`-powered validation. Validates the architecture of the config.

          Both layer layouts select layer `i` when `i % period == offset`, so each period must be positive
          (a zero period makes the `%` in the layout properties raise `ZeroDivisionError`) and each offset must
          lie in `[0, period)` (otherwise no layer is ever selected).

          Raises:
              ValueError: if a period is not positive or an offset is outside `[0, period)`.
          """
          if self.attn_layer_period <= 0:
              raise ValueError(f"attention layer period ({self.attn_layer_period}) must be a positive integer")
          if self.attn_layer_offset < 0:
              raise ValueError(f"attention layer offset ({self.attn_layer_offset}) must be non-negative")
          if self.attn_layer_offset >= self.attn_layer_period:
              raise ValueError(
                  f"attention layer offset ({self.attn_layer_offset}) must be smaller than attention layer period ({self.attn_layer_period})"
              )
          if self.expert_layer_period <= 0:
              raise ValueError(f"expert layer period ({self.expert_layer_period}) must be a positive integer")
          if self.expert_layer_offset < 0:
              raise ValueError(f"expert layer offset ({self.expert_layer_offset}) must be non-negative")
          if self.expert_layer_offset >= self.expert_layer_period:
              raise ValueError(
                  f"expert layer offset ({self.expert_layer_offset}) must be smaller than expert layer period ({self.expert_layer_period})"
              )


  __all__ = ["JambaConfig"]