configuration_musicgen.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. # Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """MusicGen model configuration"""
  15. from typing import ClassVar
  16. from huggingface_hub.dataclasses import strict
  17. from ...configuration_utils import PreTrainedConfig
  18. from ...utils import auto_docstring
  19. from ..auto.configuration_auto import AutoConfig
  20. @auto_docstring(checkpoint="facebook/musicgen-small")
  21. @strict
  22. class MusicgenDecoderConfig(PreTrainedConfig):
  23. model_type = "musicgen_decoder"
  24. base_config_key = "decoder_config"
  25. keys_to_ignore_at_inference = ["past_key_values"]
  26. vocab_size: int = 2048
  27. max_position_embeddings: int = 2048
  28. num_hidden_layers: int = 24
  29. ffn_dim: int = 4096
  30. num_attention_heads: int = 16
  31. layerdrop: float | int = 0.0
  32. use_cache: bool = True
  33. activation_function: str = "gelu"
  34. hidden_size: int = 1024
  35. dropout: float | int = 0.1
  36. attention_dropout: float | int = 0.0
  37. activation_dropout: float | int = 0.0
  38. initializer_factor: float = 0.02
  39. scale_embedding: bool = False
  40. num_codebooks: int = 4
  41. audio_channels: int = 1
  42. pad_token_id: int | None = 2048
  43. bos_token_id: int | None = 2048
  44. eos_token_id: int | list[int] | None = None
  45. tie_word_embeddings: bool = False
  46. is_decoder: bool = False
  47. add_cross_attention: bool = False
  48. cross_attention_hidden_size: int | None = None
  49. def validate_architecture(self):
  50. """Part of `@strict`-powered validation. Validates the architecture of the config."""
  51. if self.audio_channels not in [1, 2]:
  52. raise ValueError(f"Expected 1 (mono) or 2 (stereo) audio channels, got {self.audio_channels} channels.")
  53. @auto_docstring(checkpoint="facebook/musicgen-small")
  54. @strict
  55. class MusicgenConfig(PreTrainedConfig):
  56. r"""
  57. text_encoder (`Union[dict, `PretrainedConfig`]`):
  58. An instance of a configuration object that defines the text encoder config.
  59. audio_encoder (`Union[dict, `PretrainedConfig`]`):
  60. An instance of a configuration object that defines the audio encoder config.
  61. decoder (`Union[dict, `PretrainedConfig`]`):
  62. An instance of a configuration object that defines the decoder config.
  63. Example:
  64. ```python
  65. >>> from transformers import (
  66. ... MusicgenConfig,
  67. ... MusicgenDecoderConfig,
  68. ... T5Config,
  69. ... EncodecConfig,
  70. ... MusicgenForConditionalGeneration,
  71. ... )
  72. >>> # Initializing text encoder, audio encoder, and decoder model configurations
  73. >>> text_encoder_config = T5Config()
  74. >>> audio_encoder_config = EncodecConfig()
  75. >>> decoder_config = MusicgenDecoderConfig()
  76. >>> configuration = MusicgenConfig(
  77. ... text_encoder=text_encoder_config,
  78. ... audio_encoder=audio_encoder_config,
  79. ... decoder=decoder_config,
  80. ... )
  81. >>> # Initializing a MusicgenForConditionalGeneration (with random weights) from the facebook/musicgen-small style configuration
  82. >>> model = MusicgenForConditionalGeneration(configuration)
  83. >>> # Accessing the model configuration
  84. >>> configuration = model.config
  85. >>> config_text_encoder = model.config.text_encoder
  86. >>> config_audio_encoder = model.config.audio_encoder
  87. >>> config_decoder = model.config.decoder
  88. >>> # Saving the model, including its configuration
  89. >>> model.save_pretrained("musicgen-model")
  90. >>> # loading model and config from pretrained folder
  91. >>> musicgen_config = MusicgenConfig.from_pretrained("musicgen-model")
  92. >>> model = MusicgenForConditionalGeneration.from_pretrained("musicgen-model", config=musicgen_config)
  93. ```"""
  94. model_type: ClassVar[str] = "musicgen"
  95. sub_configs: ClassVar[dict[str, type[PreTrainedConfig]]] = {
  96. "text_encoder": AutoConfig,
  97. "audio_encoder": AutoConfig,
  98. "decoder": MusicgenDecoderConfig,
  99. }
  100. has_no_defaults_at_init: ClassVar[bool] = True
  101. text_encoder: dict | PreTrainedConfig = None
  102. audio_encoder: dict | PreTrainedConfig = None
  103. decoder: dict | PreTrainedConfig = None
  104. initializer_factor: float = 0.02
  105. def __post_init__(self, **kwargs):
  106. if isinstance(self.text_encoder, dict):
  107. text_encoder_model_type = self.text_encoder.pop("model_type")
  108. self.text_encoder = AutoConfig.for_model(text_encoder_model_type, **self.text_encoder)
  109. elif self.text_encoder is None:
  110. raise ValueError(
  111. f"A configuration of type {self.model_type} cannot be instantiated because text_encoder is not passed"
  112. )
  113. if isinstance(self.audio_encoder, dict):
  114. audio_encoder_model_type = self.audio_encoder.pop("model_type")
  115. self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **self.audio_encoder)
  116. elif self.audio_encoder is None:
  117. raise ValueError(
  118. f"A configuration of type {self.model_type} cannot be instantiated because audio_encoder is not passed"
  119. )
  120. if isinstance(self.decoder, dict):
  121. self.decoder = MusicgenDecoderConfig(**self.decoder)
  122. elif self.decoder is None:
  123. self.decoder = MusicgenDecoderConfig()
  124. self.is_encoder_decoder = True
  125. super().__post_init__(**kwargs)
  126. @property
  127. # This is a property because you might want to change the codec model on the fly
  128. def sampling_rate(self):
  129. return self.audio_encoder.sampling_rate
  130. __all__ = ["MusicgenConfig", "MusicgenDecoderConfig"]