configuration_falcon.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. # Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Falcon configuration"""
  15. from huggingface_hub.dataclasses import strict
  16. from ...configuration_utils import PreTrainedConfig
  17. from ...modeling_rope_utils import RopeParameters
  18. from ...utils import auto_docstring
  19. @auto_docstring(checkpoint="tiiuae/falcon-7b")
  20. @strict
  21. class FalconConfig(PreTrainedConfig):
  22. r"""
  23. num_ln_in_parallel_attn (`int`, *optional*):
  24. Set to 2 if separate layer norms are to be used for the MLP and the attention output when using parallel
  25. attention, otherwise, 1.
  26. alibi (`bool`, *optional*, defaults to `False`):
  27. Whether to use ALiBi positional biases during self-attention.
  28. new_decoder_architecture (`bool`, *optional*, defaults to `False`):
  29. Whether to use the new (Falcon-40B) decoder architecture. If `True`, the `multi_query` and `parallel_attn`
  30. arguments are ignored, as the new decoder always uses parallel attention.
  31. multi_query (`bool`, *optional*, defaults to `True`):
  32. Whether to use multi-query attention in the decoder. Ignored when `new_decoder_architecture` is `True`.
  33. parallel_attn (`bool`, *optional*, defaults to `True`):
  34. Whether to compute attention in parallel with the feedforward layer. If False, they are consecutive
  35. instead, as in the original Transformer architecture. Ignored when `new_decoder_architecture` is `True`.
  36. bias (`bool`, *optional*, defaults to `False`):
  37. Whether to use bias on Linear layers.
  38. ffn_hidden_size (`int`, *optional*):
  39. The hidden size of the feedforward layer in the Transformer decoder.
  40. defaults to 4x hidden dim
  41. activation (`str`, *optional*, defaults to `"gelu"`):
  42. The activation function used in the feedforward layer.
  43. Example:
  44. ```python
  45. >>> from transformers import FalconModel, FalconConfig
  46. >>> # Initializing a small (2-layer) Falcon configuration
  47. >>> configuration = FalconConfig(num_hidden_layers=2)
  48. >>> # Initializing a model from the small configuration
  49. >>> model = FalconModel(configuration)
  50. >>> # Accessing the model configuration
  51. >>> configuration = model.config
  52. ```"""
  53. model_type = "falcon"
  54. keys_to_ignore_at_inference = ["past_key_values"]
  55. vocab_size: int = 65024
  56. hidden_size: int = 4544
  57. num_hidden_layers: int = 32
  58. num_attention_heads: int = 71
  59. num_ln_in_parallel_attn: int | None = None
  60. layer_norm_epsilon: float | None = 1e-5
  61. initializer_range: float = 0.02
  62. use_cache: bool = True
  63. hidden_dropout: float | int | None = 0.0
  64. attention_dropout: float | int | None = 0.0
  65. num_kv_heads: int | None = None
  66. alibi: bool | None = False
  67. new_decoder_architecture: bool | None = False
  68. multi_query: bool | None = True
  69. parallel_attn: bool | None = True
  70. bias: bool | None = False
  71. max_position_embeddings: int = 2048
  72. rope_parameters: RopeParameters | dict | None = None
  73. bos_token_id: int | None = 11
  74. eos_token_id: int | list[int] | None = 11
  75. pad_token_id: int | None = None
  76. ffn_hidden_size: int | None = None
  77. activation: str | None = "gelu"
  78. tie_word_embeddings: bool = True
  79. def __post_init__(self, **kwargs):
  80. # Backward compatibility with n_embed kwarg
  81. n_embed = kwargs.pop("n_embed", None)
  82. self.hidden_size = self.hidden_size if n_embed is None else n_embed
  83. self.num_kv_heads = self.num_attention_heads if self.num_kv_heads is None else self.num_kv_heads
  84. if self.ffn_hidden_size is None:
  85. self.ffn_hidden_size = self.hidden_size * 4
  86. super().__post_init__(**kwargs)
  87. @property
  88. def head_dim(self):
  89. return self.hidden_size // self.num_attention_heads
  90. @property
  91. def rotary(self):
  92. return not self.alibi
# Public names exported by this module via `from <module> import *`.
__all__ = ["FalconConfig"]