configuration_chameleon.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. # Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """chameleon model configuration"""
  15. from huggingface_hub.dataclasses import strict
  16. from ...configuration_utils import PreTrainedConfig
  17. from ...modeling_rope_utils import RopeParameters
  18. from ...utils import auto_docstring, logging
# Module-level logger, created through the package's `logging` helper and named after this module.
logger = logging.get_logger(__name__)
  20. @auto_docstring(checkpoint="facebook/chameleon-7b")
  21. @strict
  22. class ChameleonVQVAEConfig(PreTrainedConfig):
  23. r"""
  24. resolution (`int`, *optional*, defaults to 512):
  25. Resolution of the input images.
  26. base_channels (`int`, *optional*, defaults to 128):
  27. Base channel count.
  28. channel_multiplier (`list[int]`, *optional*, defaults to `[1, 1, 2, 2, 4]`):
  29. Channel multipliers for each resolution.
  30. num_res_blocks (`int`, *optional*, defaults to 2):
  31. Number of residual blocks.
  32. attn_resolutions (`list[int]`, *optional*):
  33. Resolutions to apply attention.
  34. dropout (`float`, *optional*, defaults to 0.0):
  35. Dropout rate.
  36. attn_type (`str`, *optional*, defaults to `"vanilla"`):
  37. Attention type used in VQ-GAN encoder. Can be "vanilla" or None
  38. """
  39. model_type = "chameleon_vqgan"
  40. base_config_key = "vq_config"
  41. embed_dim: int = 256
  42. num_embeddings: int = 8192
  43. double_latent: bool = False
  44. latent_channels: int = 256
  45. resolution: int = 512
  46. in_channels: int = 3
  47. base_channels: int = 128
  48. channel_multiplier: list[int] | tuple[int, ...] = (1, 1, 2, 2, 4)
  49. num_res_blocks: int = 2
  50. attn_resolutions: list[int] | None = None
  51. dropout: float | int = 0.0
  52. attn_type: str = "vanilla"
  53. initializer_range = 0.02
  54. @auto_docstring(checkpoint="facebook/chameleon-7b")
  55. @strict
  56. class ChameleonConfig(PreTrainedConfig):
  57. r"""
  58. model_parallel_size (`int`, *optional*, defaults to 1):
  59. Number of shards used when training the model. This will be used in qk layernorm because the original Chameleon inference
  60. doesn't do reduction in those layers and each rank has its own biases.
  61. swin_norm (`bool`, *optional*, defaults to `False`):
  62. Use Swin Transformer normalization.
  63. vocabulary_map (`dict`, *optional*):
  64. A dictionary containing the vocabulary map from the tokenizer. Used to obtain tokens from the image inputs.
  65. ```python
  66. >>> from transformers import ChameleonModel, ChameleonConfig
  67. >>> # Initializing a chameleon chameleon-7b style configuration
  68. >>> configuration = ChameleonConfig()
  69. >>> # Initializing a model from the chameleon-7b style configuration
  70. >>> model = ChameleonModel(configuration)
  71. >>> # Accessing the model configuration
  72. >>> configuration = model.config
  73. ```
  74. """
  75. model_type = "chameleon"
  76. sub_configs = {"vq_config": ChameleonVQVAEConfig}
  77. keys_to_ignore_at_inference = ["past_key_values"]
  78. vocab_size: int = 65536
  79. hidden_size: int = 4096
  80. intermediate_size: int = 11008
  81. num_hidden_layers: int = 32
  82. num_attention_heads: int = 32
  83. num_key_value_heads: int | None = 32
  84. hidden_act: str = "silu"
  85. max_position_embeddings: int = 4096
  86. initializer_range: float = 0.02
  87. rms_norm_eps: float = 1e-05
  88. use_cache: bool = True
  89. pad_token_id: int | None = None
  90. bos_token_id: int | None = 1
  91. eos_token_id: int | list[int] | None = 2
  92. tie_word_embeddings: bool = False
  93. rope_parameters: RopeParameters | dict | None = None
  94. attention_bias: bool | None = False
  95. attention_dropout: float | int | None = 0.0
  96. model_parallel_size: int | None = 1
  97. swin_norm: bool | None = False
  98. vq_config: dict | PreTrainedConfig | None = None
  99. vocabulary_map: dict | None = None
  100. mlp_bias: bool = False
  101. def __post_init__(self, **kwargs):
  102. if self.vq_config is None:
  103. logger.info("vq_config is None. initializing the ChameleonVQConfig with default values.")
  104. self.vq_config = ChameleonVQVAEConfig()
  105. elif isinstance(self.vq_config, dict):
  106. self.vq_config = ChameleonVQVAEConfig(**self.vq_config)
  107. self.image_token_id = self.vocabulary_map.get("<image>") if self.vocabulary_map is not None else None
  108. super().__post_init__(**kwargs)
  109. __all__ = ["ChameleonConfig", "ChameleonVQVAEConfig"]