configuration_emu3.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. # Copyright 2024 HuggingFace Inc. team. All rights reserved.
  2. #
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from huggingface_hub.dataclasses import strict
  16. from ...configuration_utils import PreTrainedConfig
  17. from ...modeling_rope_utils import RopeParameters
  18. from ...utils import auto_docstring
  19. @auto_docstring(checkpoint="Emu3-community/Emu3-Chat-hf")
  20. @strict
  21. class Emu3VQVAEConfig(PreTrainedConfig):
  22. r"""
  23. embed_dim (`int`, *optional*, defaults to 4):
  24. Dimension of the quantized vector in codebook.
  25. out_channels (`int`, *optional*, defaults to 3):
  26. Output channel of decoder.
  27. temporal_downsample_factor (`int`, *optional*, defaults to 4):
  28. Temporal downsample factor.
  29. base_channels (`int`, *optional*, defaults to 256):
  30. Basic channel number of the intermediate blocks.
  31. channel_multiplier (`list[int]`, *optional*, defaults to `[1, 2, 2, 4]`):
  32. Channel scaling factor of the intermediate blocks.
  33. num_res_blocks (`int`, *optional*, defaults to 2):
  34. Residual block number in each stage.
  35. attn_resolutions (`list[int]`, *optional*, defaults to `[3]`):
  36. Stage indices to apply attention.
  37. ```python
  38. >>> from transformers import Emu3VQVAE, Emu3VQVAEConfig
  39. >>> # Initializing a video VQ model of Emu3 configuration
  40. >>> configuration = Emu3VQVAEConfig()
  41. >>> # Initializing a model from the Emu3 VQ model style configuration
  42. >>> model = Emu3VQVAE(configuration)
  43. >>> # Accessing the model configuration
  44. >>> configuration = model.config
  45. ```
  46. """
  47. model_type = "emu3_vqgan"
  48. base_config_key = "vq_config"
  49. codebook_size: int = 32768
  50. embed_dim: int = 4
  51. latent_channels: int = 4
  52. double_latent: bool = False
  53. in_channels: int = 3
  54. out_channels: int = 3
  55. temporal_downsample_factor: int = 4
  56. base_channels: int = 256
  57. channel_multiplier: list[int] | tuple[int, ...] = (1, 2, 2, 4)
  58. num_res_blocks: int = 2
  59. attn_resolutions: list[int] | tuple[int, ...] = (3,)
  60. hidden_size: int = 1024
  61. num_attention_heads: int = 1
  62. attention_dropout: float | int = 0.0
  63. @auto_docstring(checkpoint="Emu3-community/Emu3-Chat-hf")
  64. @strict
  65. class Emu3TextConfig(PreTrainedConfig):
  66. r"""
  67. Example:
  68. ```python
  69. >>> from transformers import Emu3Model, Emu3Config
  70. >>> # Initializing a Emu3-community/Emu3-Chat-hf style configuration
  71. >>> configuration = Emu3Config()
  72. >>> # Initializing a model from the Emu3-community/Emu3-Chat-hf style configuration
  73. >>> model = Emu3Model(configuration)
  74. >>> # Accessing the model configuration
  75. >>> configuration = model.config
  76. ```"""
  77. model_type = "emu3_text_model"
  78. base_config_key = "text_config"
  79. keys_to_ignore_at_inference = ["past_key_values"]
  80. default_theta = 1000000.0
  81. vocab_size: int = 184622
  82. hidden_size: int = 4096
  83. intermediate_size: int = 14336
  84. num_hidden_layers: int = 32
  85. num_attention_heads: int = 32
  86. num_key_value_heads: int | None = 8
  87. hidden_act: str = "silu"
  88. max_position_embeddings: int = 9216
  89. rms_norm_eps: float = 1e-5
  90. use_cache: bool = True
  91. pad_token_id: int = 151643
  92. bos_token_id: int = 151849
  93. eos_token_id: int | list[int] | None = 151850
  94. rope_parameters: RopeParameters | dict | None = None
  95. mlp_bias = False
  96. attention_bias = False
  97. attention_dropout: float | int = 0.1
  98. initializer_range: float = 0.02
  99. tie_word_embeddings: bool = False
  100. @auto_docstring(checkpoint="Emu3-community/Emu3-Chat-hf")
  101. @strict
  102. class Emu3Config(PreTrainedConfig):
  103. r"""
  104. vocabulary_map (`dict`, *optional*):
  105. A dictionary containing the vocabulary map from the tokenizer. Used to obtain tokens from the image inputs.
  106. """
  107. model_type = "emu3"
  108. keys_to_ignore_at_inference = ["past_key_values"]
  109. sub_configs = {"text_config": Emu3TextConfig, "vq_config": Emu3VQVAEConfig}
  110. vq_config: dict | Emu3VQVAEConfig | None = None
  111. text_config: dict | Emu3TextConfig | None = None
  112. vocabulary_map: dict[str, int] | None = None
  113. tie_word_embeddings: bool = False
  114. def __post_init__(self, **kwargs):
  115. if self.vq_config is None:
  116. self.vq_config = Emu3VQVAEConfig()
  117. elif isinstance(self.vq_config, dict):
  118. self.vq_config = Emu3VQVAEConfig(**self.vq_config)
  119. if self.text_config is None:
  120. self.text_config = Emu3TextConfig()
  121. elif isinstance(self.text_config, dict):
  122. self.text_config = Emu3TextConfig(**self.text_config)
  123. self.image_token_id = self.vocabulary_map.get("<image>") if self.vocabulary_map is not None else None
  124. super().__post_init__(**kwargs)
  125. __all__ = ["Emu3Config", "Emu3TextConfig", "Emu3VQVAEConfig"]