configuration_mimi.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. # Copyright 2024 Meta Platforms, Inc. and affiliates, and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Mimi model configuration"""
  15. import math
  16. import numpy as np
  17. from huggingface_hub.dataclasses import strict
  18. from ...configuration_utils import PreTrainedConfig
  19. from ...modeling_rope_utils import RopeParameters
  20. from ...utils import auto_docstring
  21. @auto_docstring(checkpoint="kyutai/mimi")
  22. @strict
  23. class MimiConfig(PreTrainedConfig):
  24. r"""
  25. audio_channels (`int`, *optional*, defaults to 1):
  26. Number of channels in the audio data. Either 1 for mono or 2 for stereo.
  27. num_filters (`int`, *optional*, defaults to 64):
  28. Number of convolution kernels of first `MimiConv1d` down sampling layer.
  29. num_residual_layers (`int`, *optional*, defaults to 1):
  30. Number of residual layers.
  31. upsampling_ratios (`Sequence[int]`, *optional*):
  32. Kernel size and stride ratios. The encoder uses downsampling ratios instead of upsampling ratios, hence it
  33. will use the ratios in the reverse order to the ones specified here that must match the decoder order.
  34. If not specified, will defaults to `[8, 6, 5, 4]`
  35. last_kernel_size (`int`, *optional*, defaults to 3):
  36. Kernel size for the last convolution layer.
  37. residual_kernel_size (`int`, *optional*, defaults to 3):
  38. Kernel size for the residual layers.
  39. dilation_growth_rate (`int`, *optional*, defaults to 2):
  40. How much to increase the dilation with each layer.
  41. use_causal_conv (`bool`, *optional*, defaults to `True`):
  42. Whether to use fully causal convolution.
  43. pad_mode (`str`, *optional*, defaults to `"constant"`):
  44. Padding mode for the convolutions.
  45. compress (`int`, *optional*, defaults to 2):
  46. Reduced dimensionality in residual branches.
  47. trim_right_ratio (`float`, *optional*, defaults to 1.0):
  48. Ratio for trimming at the right of the transposed convolution under the `use_causal_conv = True` setup. If
  49. equal to 1.0, it means that all the trimming is done at the right.
  50. num_quantizers (`int`, *optional*, defaults to 32):
  51. Number of quantizer channels, or codebooks, in the quantizer.
  52. use_conv_shortcut (`bool`, *optional*, defaults to `False`):
  53. Whether to use a convolutional layer as the 'skip' connection in the `MimiResnetBlock` block. If False,
  54. an identity function will be used, giving a generic residual connection.
  55. vector_quantization_hidden_dimension (`int`, *optional*, defaults to 256):
  56. Intermediate representation dimension in the residual vector quantization space.
  57. num_semantic_quantizers (`int`, *optional*, defaults to 1):
  58. Number of semantic quantizer channels, or codebooks, in the semantic quantizer. Must be lower than `num_quantizers`.
  59. upsample_groups (`int`, *optional*, defaults to 512):
  60. If `frame_rate!=encodec_frame_rate`, indicates the number of groups used in the upsampling operation to go from one rate to another.
  61. use_streaming (`bool`, *optional*, defaults to `False`):
  62. Whether to use streaming mode. If `True`, the model encode method will return the padding cache that can be used in a subsequent call to the encode method.
  63. Example:
  64. ```python
  65. >>> from transformers import MimiModel, MimiConfig
  66. >>> # Initializing a "kyutai/mimi" style configuration
  67. >>> configuration = MimiConfig()
  68. >>> # Initializing a model (with random weights) from the "kyutai/mimi" style configuration
  69. >>> model = MimiModel(configuration)
  70. >>> # Accessing the model configuration
  71. >>> configuration = model.config
  72. ```"""
  73. model_type = "mimi"
  74. sampling_rate: int = 24_000
  75. audio_channels: int = 1
  76. hidden_size: int = 512
  77. num_filters: int = 64
  78. num_residual_layers: int = 1
  79. upsampling_ratios: list[int] | None = None
  80. kernel_size: int = 7
  81. last_kernel_size: int = 3
  82. residual_kernel_size: int = 3
  83. dilation_growth_rate: int = 2
  84. use_causal_conv: bool = True
  85. pad_mode: str = "constant"
  86. compress: int = 2
  87. trim_right_ratio: float = 1.0
  88. codebook_size: int = 2048
  89. codebook_dim: int = 256
  90. num_quantizers: int = 32
  91. use_conv_shortcut: bool = False
  92. vector_quantization_hidden_dimension: int = 256
  93. num_semantic_quantizers: int = 1
  94. upsample_groups: int = 512
  95. num_hidden_layers: int = 8
  96. intermediate_size: int = 2048
  97. num_attention_heads: int = 8
  98. num_key_value_heads: int = 8
  99. head_dim: int | None = None
  100. hidden_act: str = "gelu"
  101. max_position_embeddings: int = 8000
  102. initializer_range: float = 0.02
  103. norm_eps: float = 1e-5
  104. use_cache: bool = False
  105. use_streaming: bool = False
  106. rope_parameters: RopeParameters | dict | None = None
  107. sliding_window: int = 250
  108. attention_dropout: float | int = 0.0
  109. layer_scale_initial_scale: float = 0.01
  110. attention_bias: bool = False
  111. tie_word_embeddings: bool = True
  112. def __post_init__(self, **kwargs):
  113. self.upsampling_ratios = self.upsampling_ratios if self.upsampling_ratios else [8, 6, 5, 4]
  114. self.codebook_dim = self.codebook_dim if self.codebook_dim is not None else self.hidden_size
  115. self.head_dim = self.head_dim or self.hidden_size // self.num_attention_heads
  116. # Handle backward compatibility for frame_rate:
  117. # If frame_rate is explicitly provided, use it (backward compatibility)
  118. # Otherwise, compute it from other parameters (correctly)
  119. self._frame_rate = kwargs.pop("frame_rate", None)
  120. super().__post_init__(**kwargs)
  121. def validate_architecture(self):
  122. """Part of `@strict`-powered validation. Validates the architecture of the config."""
  123. if self.num_semantic_quantizers >= self.num_quantizers:
  124. raise ValueError(
  125. f"The number of semantic quantizers should be lower than the total number of quantizers {self.num_quantizers}, but is currently {self.num_semantic_quantizers}."
  126. )
  127. @property
  128. def encodec_frame_rate(self) -> int:
  129. hop_length = np.prod(self.upsampling_ratios)
  130. return math.ceil(self.sampling_rate / hop_length)
  131. @property
  132. def num_codebooks(self) -> int:
  133. # alias to num_quantizers
  134. return self.num_quantizers
  135. @property
  136. def frame_size(self) -> int:
  137. # 1. we need each encoder conv stride
  138. # first conv
  139. strides = [1]
  140. # layer convs
  141. for ratio in reversed(self.upsampling_ratios):
  142. for j in range(self.num_residual_layers):
  143. len_kernel_sizes = len(self.residual_kernel_size) if isinstance(self.residual_kernel_size, list) else 1
  144. strides.extend([1] * (len_kernel_sizes + 1))
  145. if self.use_conv_shortcut: # skip connection
  146. strides.append(1)
  147. strides.append(ratio)
  148. # last conv
  149. strides.append(1)
  150. # downsampling layer
  151. strides.append(2)
  152. return math.prod(strides)
  153. @property
  154. def frame_rate(self) -> float:
  155. # handle backward compatibility
  156. if self._frame_rate is not None:
  157. return self._frame_rate
  158. return self.sampling_rate / self.frame_size
# Names exported by `from <this module> import *`.
__all__ = ["MimiConfig"]