modeling_mimi.py 77 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757175817591760176117621763176417651766176717681769177017711772
  1. # Copyright 2024 Kyutai, and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch Mimi model."""
  15. import math
  16. from collections.abc import Callable
  17. from dataclasses import dataclass
  18. from typing import Optional
  19. import torch
  20. from torch import nn
  21. from ... import initialization as init
  22. from ...activations import ACT2FN
  23. from ...cache_utils import Cache, DynamicCache, StaticCache
  24. from ...masking_utils import create_sliding_window_causal_mask
  25. from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
  26. from ...modeling_layers import GradientCheckpointingLayer
  27. from ...modeling_outputs import BaseModelOutputWithPast
  28. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  29. from ...modeling_utils import PreTrainedModel
  30. from ...utils import ModelOutput, auto_docstring, logging
  31. from ...utils.generic import maybe_autocast
  32. from .configuration_mimi import MimiConfig
  33. if is_flash_attn_available():
  34. from ...modeling_flash_attention_utils import _flash_attention_forward
# Module-level logger from the library's own `logging` utilities (imported from `...utils`); used for
# configuration warnings emitted by the modules below.
logger = logging.get_logger(__name__)
  36. @dataclass
  37. @auto_docstring
  38. class MimiOutput(ModelOutput):
  39. r"""
  40. audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`, *optional*):
  41. Discret code embeddings computed using `model.encode`.
  42. audio_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  43. Decoded audio values, obtained using the decoder part of Mimi.
  44. encoder_past_key_values (`Cache`, *optional*):
  45. Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the encoder transformer.
  46. This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
  47. The model will output the same cache format that is fed as input.
  48. If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes (those that don't
  49. have their past key value states given to this model).
  50. decoder_past_key_values (`Cache`, *optional*):
  51. Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the decoder transformer.
  52. This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
  53. The model will output the same cache format that is fed as input.
  54. If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes (those that don't
  55. have their past key value states given to this model).
  56. """
  57. audio_codes: torch.LongTensor | None = None
  58. audio_values: torch.FloatTensor | None = None
  59. encoder_past_key_values: Cache | None = None
  60. decoder_past_key_values: Cache | None = None
  61. class MimiConv1dPaddingCache:
  62. """
  63. Padding cache for MimiConv1d causal convolutions in order to support streaming via cache padding.
  64. See: https://huggingface.co/papers/2005.06720 & https://huggingface.co/papers/2204.07064
  65. A padding cache is a list of cached partial hidden states for each convolution layer.
  66. Hidden states are cached from the previous call to the MimiConv1d forward pass, given the padding size.
  67. """
  68. def __init__(
  69. self,
  70. num_layers: int,
  71. per_layer_padding: list[int],
  72. per_layer_padding_mode: list[str],
  73. per_layer_in_channels: list[int],
  74. ):
  75. # ensure correct number of layers for each arg
  76. from_args_num_layers = {len(per_layer_padding), len(per_layer_padding_mode), len(per_layer_in_channels)}
  77. if len(from_args_num_layers) != 1 or from_args_num_layers.pop() != num_layers:
  78. raise ValueError(
  79. f"Expected `num_layers` ({num_layers}) values in `per_layer_padding`, `per_layer_padding_mode` and `per_layer_in_channels`"
  80. )
  81. self.per_layer_padding = per_layer_padding
  82. self.per_layer_padding_mode = per_layer_padding_mode
  83. self.per_layer_in_channels = per_layer_in_channels
  84. self.padding_cache = [None] * num_layers
  85. def _cache_init(self, hidden_states: torch.Tensor, layer_idx: int):
  86. """
  87. Initialize the cache for a specific layer.
  88. Parameters:
  89. hidden_states (`torch.Tensor`):
  90. The hidden states to initialize the cache with.
  91. layer_idx (`int`):
  92. The index of the layer to initialize the cache for.
  93. Returns:
  94. `torch.Tensor`, the initialized cache.
  95. """
  96. batch_size, dtype, device = hidden_states.shape[0], hidden_states.dtype, hidden_states.device
  97. padding, padding_mode, in_channels = (
  98. self.per_layer_padding[layer_idx],
  99. self.per_layer_padding_mode[layer_idx],
  100. self.per_layer_in_channels[layer_idx],
  101. )
  102. if padding_mode == "constant":
  103. current_cache = torch.zeros(batch_size, in_channels, padding, device=device, dtype=dtype)
  104. elif padding_mode == "replicate":
  105. current_cache = (
  106. torch.ones(batch_size, in_channels, padding, device=device, dtype=dtype) * hidden_states[..., :1]
  107. )
  108. else:
  109. raise NotImplementedError(f"Padding mode {padding_mode} not supported")
  110. return current_cache
  111. def update(self, hidden_states: torch.Tensor, layer_idx: int):
  112. """
  113. Updates the padding cache with the new padding states for the layer `layer_idx` and returns the current cache.
  114. Parameters:
  115. hidden_states (`torch.Tensor`):
  116. The hidden states to be partially cached.
  117. layer_idx (`int`):
  118. The index of the layer to cache the states for.
  119. Returns:
  120. `torch.Tensor` or `None`, the current padding cache.
  121. """
  122. batch_size, dtype, device = hidden_states.shape[0], hidden_states.dtype, hidden_states.device
  123. padding, in_channels = self.per_layer_padding[layer_idx], self.per_layer_in_channels[layer_idx]
  124. if self.padding_cache[layer_idx] is None:
  125. current_cache = self._cache_init(hidden_states, layer_idx)
  126. else:
  127. current_cache = self.padding_cache[layer_idx]
  128. # update the cache
  129. if padding > 0:
  130. shortfall = max(0, padding - hidden_states.shape[-1])
  131. if shortfall > 0:
  132. padding_states = torch.cat([current_cache[:, :, -shortfall:], hidden_states], dim=-1)
  133. else:
  134. padding_states = hidden_states[:, :, -padding:]
  135. else:
  136. padding_states = torch.empty(batch_size, in_channels, 0, dtype=dtype, device=device)
  137. self.padding_cache[layer_idx] = padding_states
  138. return current_cache
  139. @dataclass
  140. @auto_docstring
  141. class MimiEncoderOutput(ModelOutput):
  142. r"""
  143. audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`, *optional*):
  144. Discret code embeddings computed using `model.encode`.
  145. encoder_past_key_values (`Cache`, *optional*):
  146. Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the encoder transformer.
  147. This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
  148. The model will output the same cache format that is fed as input.
  149. If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes (those that don't
  150. have their past key value states given to this model).
  151. padding_cache (`MimiConv1dPaddingCache`, *optional*):
  152. Padding cache for MimiConv1d causal convolutions in order to support streaming via cache padding.
  153. """
  154. audio_codes: torch.LongTensor | None = None
  155. encoder_past_key_values: Cache | None = None
  156. padding_cache: MimiConv1dPaddingCache | None = None
  157. @dataclass
  158. @auto_docstring
  159. class MimiDecoderOutput(ModelOutput):
  160. r"""
  161. audio_values (`torch.FloatTensor` of shape `(batch_size, segment_length)`, *optional*):
  162. Decoded audio values, obtained using the decoder part of Mimi.
  163. decoder_past_key_values (`Cache`, *optional*):
  164. Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the decoder transformer.
  165. This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
  166. The model will output the same cache format that is fed as input.
  167. If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes (those that don't
  168. have their past key value states given to this model).
  169. """
  170. audio_values: torch.FloatTensor | None = None
  171. decoder_past_key_values: Cache | None = None
  172. class MimiConv1d(nn.Module):
  173. """Conv1d with asymmetric or causal padding and normalization."""
  174. def __init__(
  175. self,
  176. config,
  177. in_channels: int,
  178. out_channels: int,
  179. kernel_size: int,
  180. stride: int = 1,
  181. dilation: int = 1,
  182. groups: int = 1,
  183. pad_mode: str | None = None,
  184. bias: bool = True,
  185. layer_idx: int | None = None,
  186. ):
  187. super().__init__()
  188. self.causal = config.use_causal_conv
  189. self.pad_mode = config.pad_mode if pad_mode is None else pad_mode
  190. self.layer_idx = layer_idx
  191. self.in_channels = in_channels
  192. # warn user on unusual setup between dilation and stride
  193. if stride > 1 and dilation > 1:
  194. logger.warning(
  195. "MimiConv1d has been initialized with stride > 1 and dilation > 1"
  196. f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})."
  197. )
  198. self.conv = nn.Conv1d(
  199. in_channels, out_channels, kernel_size, stride, dilation=dilation, groups=groups, bias=bias
  200. )
  201. kernel_size = self.conv.kernel_size[0]
  202. stride = torch.tensor(self.conv.stride[0], dtype=torch.int64)
  203. dilation = self.conv.dilation[0]
  204. # Effective kernel size with dilations.
  205. kernel_size = torch.tensor((kernel_size - 1) * dilation + 1, dtype=torch.int64)
  206. self.register_buffer("stride", stride, persistent=False)
  207. self.register_buffer("kernel_size", kernel_size, persistent=False)
  208. self.register_buffer("padding_total", kernel_size - stride, persistent=False)
  209. # Asymmetric padding required for odd strides
  210. self.padding_right = self.padding_total // 2
  211. self.padding_left = self.padding_total - self.padding_right
  212. def apply_weight_norm(self):
  213. weight_norm = nn.utils.weight_norm
  214. if hasattr(nn.utils.parametrizations, "weight_norm"):
  215. weight_norm = nn.utils.parametrizations.weight_norm
  216. weight_norm(self.conv)
  217. def remove_weight_norm(self):
  218. nn.utils.remove_weight_norm(self.conv)
  219. # Copied from transformers.models.encodec.modeling_encodec.EncodecConv1d._get_extra_padding_for_conv1d
  220. def _get_extra_padding_for_conv1d(
  221. self,
  222. hidden_states: torch.Tensor,
  223. ) -> torch.Tensor:
  224. """See `pad_for_conv1d`."""
  225. length = hidden_states.shape[-1]
  226. n_frames = (length - self.kernel_size + self.padding_total) / self.stride + 1
  227. n_frames = torch.ceil(n_frames).to(torch.int64) - 1
  228. ideal_length = n_frames * self.stride + self.kernel_size - self.padding_total
  229. return ideal_length - length
  230. @staticmethod
  231. # Copied from transformers.models.encodec.modeling_encodec.EncodecConv1d._pad1d
  232. def _pad1d(hidden_states: torch.Tensor, paddings: tuple[int, int], mode: str = "zero", value: float = 0.0):
  233. """Tiny wrapper around torch.nn.functional.pad, just to allow for reflect padding on small input.
  234. If this is the case, we insert extra 0 padding to the right before the reflection happens.
  235. """
  236. length = hidden_states.shape[-1]
  237. padding_left, padding_right = paddings
  238. if mode != "reflect":
  239. return nn.functional.pad(hidden_states, paddings, mode, value)
  240. max_pad = max(padding_left, padding_right)
  241. extra_pad = 0
  242. if length <= max_pad:
  243. extra_pad = max_pad - length + 1
  244. hidden_states = nn.functional.pad(hidden_states, (0, extra_pad))
  245. padded = nn.functional.pad(hidden_states, paddings, mode, value)
  246. end = padded.shape[-1] - extra_pad
  247. return padded[..., :end]
  248. def _get_output_length(self, input_length: torch.LongTensor) -> torch.LongTensor:
  249. """
  250. Return the length of the output of the MimiConv1d.
  251. """
  252. # padding size
  253. n_frames = (input_length - self.kernel_size + self.padding_total) / self.stride + 1
  254. n_frames = torch.ceil(n_frames).to(torch.int64) - 1
  255. ideal_length = n_frames * self.stride + self.kernel_size - self.padding_total
  256. extra_padding = ideal_length - input_length
  257. if self.causal:
  258. padding_left = self.padding_total
  259. padding_right = extra_padding
  260. else:
  261. padding_left = self.padding_left
  262. padding_right = self.padding_right + extra_padding
  263. # padding
  264. input_length = input_length + padding_left + padding_right
  265. # conv
  266. output_length = (
  267. input_length + 2 * self.conv.padding[0] - self.conv.dilation[0] * (self.conv.kernel_size[0] - 1) - 1
  268. ) // self.conv.stride[0] + 1
  269. return output_length
  270. def forward(self, hidden_states, padding_cache=None):
  271. extra_padding = self._get_extra_padding_for_conv1d(hidden_states)
  272. if not self.causal and padding_cache is not None:
  273. raise ValueError("`padding_cache` is not supported for non-causal convolutions.")
  274. if self.causal and padding_cache is not None:
  275. layer_padding_cache = padding_cache.update(hidden_states, self.layer_idx)
  276. hidden_states = torch.cat([layer_padding_cache, hidden_states], dim=2)
  277. elif self.causal:
  278. # Left padding for causal
  279. hidden_states = self._pad1d(hidden_states, (self.padding_total, extra_padding), mode=self.pad_mode)
  280. else:
  281. hidden_states = self._pad1d(
  282. hidden_states, (self.padding_left, self.padding_right + extra_padding), mode=self.pad_mode
  283. )
  284. hidden_states = self.conv(hidden_states)
  285. return hidden_states
  286. class MimiConvTranspose1d(nn.Module):
  287. """ConvTranspose1d with asymmetric or causal padding and normalization."""
  288. def __init__(
  289. self,
  290. config,
  291. in_channels: int,
  292. out_channels: int,
  293. kernel_size: int,
  294. stride: int = 1,
  295. groups: int = 1,
  296. bias=True,
  297. ):
  298. super().__init__()
  299. self.causal = config.use_causal_conv
  300. self.trim_right_ratio = config.trim_right_ratio
  301. self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, groups=groups, bias=bias)
  302. if not (self.causal or self.trim_right_ratio == 1.0):
  303. raise ValueError("`trim_right_ratio` != 1.0 only makes sense for causal convolutions")
  304. kernel_size = self.conv.kernel_size[0]
  305. stride = self.conv.stride[0]
  306. padding_total = kernel_size - stride
  307. # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
  308. # removed at the very end, when keeping only the right length for the output,
  309. # as removing it here would require also passing the length at the matching layer
  310. # in the encoder.
  311. if self.causal:
  312. # Trim the padding on the right according to the specified ratio
  313. # if trim_right_ratio = 1.0, trim everything from right
  314. self.padding_right = math.ceil(padding_total * self.trim_right_ratio)
  315. else:
  316. # Asymmetric padding required for odd strides
  317. self.padding_right = padding_total // 2
  318. self.padding_left = padding_total - self.padding_right
  319. def apply_weight_norm(self):
  320. weight_norm = nn.utils.weight_norm
  321. if hasattr(nn.utils.parametrizations, "weight_norm"):
  322. weight_norm = nn.utils.parametrizations.weight_norm
  323. weight_norm(self.conv)
  324. def remove_weight_norm(self):
  325. nn.utils.remove_weight_norm(self.conv)
  326. def forward(self, hidden_states):
  327. hidden_states = self.conv(hidden_states)
  328. # unpad
  329. end = hidden_states.shape[-1] - self.padding_right
  330. hidden_states = hidden_states[..., self.padding_left : end]
  331. return hidden_states
  332. class MimiResnetBlock(nn.Module):
  333. """
  334. Residual block from SEANet model as used by Mimi.
  335. """
  336. def __init__(self, config: MimiConfig, dim: int, dilations: list[int]):
  337. super().__init__()
  338. kernel_sizes = (config.residual_kernel_size, 1)
  339. if len(kernel_sizes) != len(dilations):
  340. raise ValueError("Number of kernel sizes should match number of dilations")
  341. hidden = dim // config.compress
  342. block = []
  343. for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
  344. in_chs = dim if i == 0 else hidden
  345. out_chs = dim if i == len(kernel_sizes) - 1 else hidden
  346. block += [nn.ELU()]
  347. block += [MimiConv1d(config, in_chs, out_chs, kernel_size, dilation=dilation)]
  348. self.block = nn.ModuleList(block)
  349. if config.use_conv_shortcut:
  350. self.shortcut = MimiConv1d(config, dim, dim, kernel_size=1)
  351. else:
  352. self.shortcut = nn.Identity()
  353. def forward(self, hidden_states, padding_cache=None):
  354. residual = hidden_states
  355. for layer in self.block:
  356. if isinstance(layer, MimiConv1d):
  357. hidden_states = layer(hidden_states, padding_cache=padding_cache)
  358. else:
  359. hidden_states = layer(hidden_states)
  360. if isinstance(self.shortcut, MimiConv1d):
  361. residual = self.shortcut(residual, padding_cache=padding_cache)
  362. else:
  363. residual = self.shortcut(residual)
  364. return residual + hidden_states
  365. class MimiEncoder(nn.Module):
  366. """SEANet encoder as used by Mimi."""
  367. def __init__(self, config: MimiConfig):
  368. super().__init__()
  369. model = [MimiConv1d(config, config.audio_channels, config.num_filters, config.kernel_size)]
  370. scaling = 1
  371. # keep track of MimiConv1d submodule layer names for easy encoded length computation
  372. mimiconv1d_layer_names = ["layers.0"]
  373. # Downsample to raw audio scale
  374. for ratio in reversed(config.upsampling_ratios):
  375. current_scale = scaling * config.num_filters
  376. # Add residual layers
  377. for j in range(config.num_residual_layers):
  378. mimiconv1d_layer_names.extend([f"layers.{len(model)}.block.1", f"layers.{len(model)}.block.3"])
  379. model += [MimiResnetBlock(config, current_scale, [config.dilation_growth_rate**j, 1])]
  380. # Add downsampling layers
  381. model += [nn.ELU()]
  382. mimiconv1d_layer_names.append(f"layers.{len(model)}")
  383. model += [MimiConv1d(config, current_scale, current_scale * 2, kernel_size=ratio * 2, stride=ratio)]
  384. scaling *= 2
  385. model += [nn.ELU()]
  386. mimiconv1d_layer_names.append(f"layers.{len(model)}")
  387. model += [MimiConv1d(config, scaling * config.num_filters, config.hidden_size, config.last_kernel_size)]
  388. self.layers = nn.ModuleList(model)
  389. self._mimiconv1d_layer_names = mimiconv1d_layer_names
  390. # initialize layer_idx for MimiConv1d submodules, necessary for padding_cache
  391. for layer_idx, layername in enumerate(self._mimiconv1d_layer_names):
  392. conv_layer = self.get_submodule(layername)
  393. setattr(conv_layer, "layer_idx", layer_idx)
  394. def forward(self, hidden_states, padding_cache=None):
  395. for layer in self.layers:
  396. if isinstance(layer, (MimiConv1d, MimiResnetBlock)):
  397. hidden_states = layer(hidden_states, padding_cache=padding_cache)
  398. else:
  399. hidden_states = layer(hidden_states)
  400. return hidden_states
  401. class MimiLayerScale(nn.Module):
  402. """Layer scale from [Touvron et al 2021] (https://huggingface.co/papers/2103.17239).
  403. This rescales diagonally the residual outputs close to 0, with a learnt scale.
  404. """
  405. def __init__(self, config):
  406. super().__init__()
  407. channels = config.hidden_size
  408. initial_scale = config.layer_scale_initial_scale
  409. self.scale = nn.Parameter(torch.full((channels,), initial_scale, requires_grad=True))
  410. def forward(self, x: torch.Tensor):
  411. return self.scale * x
  412. # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mimi
  413. class MimiRotaryEmbedding(nn.Module):
  414. inv_freq: torch.Tensor # fix linting for `register_buffer`
  415. def __init__(self, config: MimiConfig, device=None):
  416. super().__init__()
  417. self.max_seq_len_cached = config.max_position_embeddings
  418. self.original_max_seq_len = config.max_position_embeddings
  419. self.config = config
  420. self.rope_type = self.config.rope_parameters["rope_type"]
  421. rope_init_fn: Callable = self.compute_default_rope_parameters
  422. if self.rope_type != "default":
  423. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  424. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  425. self.register_buffer("inv_freq", inv_freq, persistent=False)
  426. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  427. @staticmethod
  428. def compute_default_rope_parameters(
  429. config: MimiConfig | None = None,
  430. device: Optional["torch.device"] = None,
  431. seq_len: int | None = None,
  432. ) -> tuple["torch.Tensor", float]:
  433. """
  434. Computes the inverse frequencies according to the original RoPE implementation
  435. Args:
  436. config ([`~transformers.PreTrainedConfig`]):
  437. The model configuration.
  438. device (`torch.device`):
  439. The device to use for initialization of the inverse frequencies.
  440. seq_len (`int`, *optional*):
  441. The current sequence length. Unused for this type of RoPE.
  442. Returns:
  443. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  444. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  445. """
  446. base = config.rope_parameters["rope_theta"]
  447. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  448. attention_factor = 1.0 # Unused in this type of RoPE
  449. # Compute the inverse frequencies
  450. inv_freq = 1.0 / (
  451. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  452. )
  453. return inv_freq, attention_factor
  454. @torch.no_grad()
  455. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  456. def forward(self, x, position_ids):
  457. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  458. position_ids_expanded = position_ids[:, None, :].float()
  459. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  460. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  461. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  462. emb = torch.cat((freqs, freqs), dim=-1)
  463. cos = emb.cos() * self.attention_scaling
  464. sin = emb.sin() * self.attention_scaling
  465. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  466. # Copied from transformers.models.llama.modeling_llama.rotate_half
  467. def rotate_half(x):
  468. """Rotates half the hidden dims of the input."""
  469. x1 = x[..., : x.shape[-1] // 2]
  470. x2 = x[..., x.shape[-1] // 2 :]
  471. return torch.cat((-x2, x1), dim=-1)
  472. # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
  473. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  474. """Applies Rotary Position Embedding to the query and key tensors.
  475. Args:
  476. q (`torch.Tensor`): The query tensor.
  477. k (`torch.Tensor`): The key tensor.
  478. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  479. sin (`torch.Tensor`): The sine part of the rotary embedding.
  480. unsqueeze_dim (`int`, *optional*, defaults to 1):
  481. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  482. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  483. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  484. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  485. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  486. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  487. Returns:
  488. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  489. """
  490. cos = cos.unsqueeze(unsqueeze_dim)
  491. sin = sin.unsqueeze(unsqueeze_dim)
  492. q_embed = (q * cos) + (rotate_half(q) * sin)
  493. k_embed = (k * cos) + (rotate_half(k) * sin)
  494. return q_embed, k_embed
  495. class MimiMLP(nn.Module):
  496. def __init__(self, config):
  497. super().__init__()
  498. self.config = config
  499. self.activation_fn = ACT2FN[config.hidden_act]
  500. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
  501. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
  502. # Copied from transformers.models.clip.modeling_clip.CLIPMLP.forward
  503. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  504. hidden_states = self.fc1(hidden_states)
  505. hidden_states = self.activation_fn(hidden_states)
  506. hidden_states = self.fc2(hidden_states)
  507. return hidden_states
  508. # Copied from transformers.models.llama.modeling_llama.repeat_kv
  509. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  510. """
  511. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  512. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  513. """
  514. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  515. if n_rep == 1:
  516. return hidden_states
  517. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  518. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  519. # copied from transformers.models.gemma.modeling_gemma.GemmaAttention with Gemma->Mimi
  520. # no longer copied after attention refactors
  521. class MimiAttention(nn.Module):
  522. """Multi-headed attention from 'Attention Is All You Need' paper"""
  523. def __init__(self, config: MimiConfig, layer_idx: int | None = None):
  524. super().__init__()
  525. self.config = config
  526. self.layer_idx = layer_idx
  527. if layer_idx is None:
  528. logger.warning_once(
  529. f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
  530. "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
  531. "when creating this class."
  532. )
  533. self.attention_dropout = config.attention_dropout
  534. self.hidden_size = config.hidden_size
  535. self.num_heads = config.num_attention_heads
  536. self.head_dim = config.head_dim
  537. self.num_key_value_heads = config.num_key_value_heads
  538. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  539. self.max_position_embeddings = config.max_position_embeddings
  540. self.is_causal = True
  541. self.scaling = 1 / math.sqrt(config.head_dim)
  542. if self.hidden_size % self.num_heads != 0:
  543. raise ValueError(
  544. f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
  545. f" and `num_heads`: {self.num_heads})."
  546. )
  547. self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
  548. self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
  549. self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
  550. self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
  551. self.rotary_emb = MimiRotaryEmbedding(config)
  552. self.sliding_window = config.sliding_window # Ignore copy
  553. def forward(
  554. self,
  555. hidden_states: torch.Tensor,
  556. attention_mask: torch.Tensor | None = None,
  557. position_ids: torch.LongTensor | None = None,
  558. past_key_values: Cache | None = None,
  559. output_attentions: bool = False,
  560. use_cache: bool = False,
  561. **kwargs,
  562. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  563. bsz, q_len, _ = hidden_states.size()
  564. query_states = self.q_proj(hidden_states)
  565. key_states = self.k_proj(hidden_states)
  566. value_states = self.v_proj(hidden_states)
  567. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  568. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  569. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  570. cos, sin = self.rotary_emb(value_states, position_ids)
  571. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  572. if past_key_values is not None:
  573. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  574. key_states = repeat_kv(key_states, self.num_key_value_groups)
  575. value_states = repeat_kv(value_states, self.num_key_value_groups)
  576. attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
  577. if attention_mask is not None:
  578. attn_weights = attn_weights + attention_mask
  579. # upcast attention to fp32
  580. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
  581. attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
  582. attn_output = torch.matmul(attn_weights, value_states)
  583. if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
  584. raise ValueError(
  585. f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
  586. f" {attn_output.size()}"
  587. )
  588. attn_output = attn_output.transpose(1, 2).contiguous()
  589. attn_output = attn_output.view(bsz, q_len, -1)
  590. attn_output = self.o_proj(attn_output)
  591. if not output_attentions:
  592. attn_weights = None
  593. return attn_output, attn_weights
  594. # NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->Mimi
  595. # TODO cyril: modular
  596. class MimiFlashAttention2(MimiAttention):
  597. """
  598. Mimi flash attention module. This module inherits from `MimiAttention` as the weights of the module stays
  599. untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
  600. flash attention and deal with padding tokens in case the input contains any of them.
  601. """
  602. def __init__(self, *args, **kwargs):
  603. super().__init__(*args, **kwargs)
  604. # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
  605. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
  606. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
  607. self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
  608. def forward(
  609. self,
  610. hidden_states: torch.Tensor,
  611. attention_mask: torch.LongTensor | None = None,
  612. position_ids: torch.LongTensor | None = None,
  613. past_key_values: Cache | None = None,
  614. output_attentions: bool = False,
  615. use_cache: bool = False,
  616. **kwargs,
  617. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  618. if isinstance(past_key_values, StaticCache):
  619. raise ValueError(
  620. "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
  621. "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
  622. )
  623. output_attentions = False
  624. bsz, q_len, _ = hidden_states.size()
  625. query_states = self.q_proj(hidden_states)
  626. key_states = self.k_proj(hidden_states)
  627. value_states = self.v_proj(hidden_states)
  628. # Flash attention requires the input to have the shape
  629. # batch_size x seq_length x head_dim x hidden_dim
  630. # therefore we just need to keep the original shape
  631. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  632. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  633. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  634. cos, sin = self.rotary_emb(value_states, position_ids)
  635. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  636. if past_key_values is not None:
  637. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  638. # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
  639. # to be able to avoid many of these transpose/reshape/view.
  640. query_states = query_states.transpose(1, 2)
  641. key_states = key_states.transpose(1, 2)
  642. value_states = value_states.transpose(1, 2)
  643. dropout_rate = self.attention_dropout if self.training else 0.0
  644. # In PEFT, usually we cast the layer norms in float32 for training stability reasons
  645. # therefore the input hidden states gets silently casted in float32. Hence, we need
  646. # cast them back in the correct dtype just to be sure everything works as expected.
  647. # This might slowdown training & inference so it is recommended to not cast the LayerNorms
  648. # in fp32. (MimiRMSNorm handles it correctly)
  649. input_dtype = query_states.dtype
  650. device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
  651. if input_dtype == torch.float32:
  652. if torch.is_autocast_enabled(device_type):
  653. target_dtype = torch.get_autocast_dtype(device_type)
  654. # Handle the case where the model is quantized
  655. elif hasattr(self.config, "_is_quantized"):
  656. target_dtype = self.config.dtype
  657. else:
  658. target_dtype = self.q_proj.weight.dtype
  659. logger.warning_once(
  660. f"The input hidden states seems to be silently casted in float32, this might be related to"
  661. f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
  662. f" {target_dtype}."
  663. )
  664. query_states = query_states.to(target_dtype)
  665. key_states = key_states.to(target_dtype)
  666. value_states = value_states.to(target_dtype)
  667. attn_output = _flash_attention_forward(
  668. query_states,
  669. key_states,
  670. value_states,
  671. attention_mask,
  672. q_len,
  673. position_ids=position_ids,
  674. dropout=dropout_rate,
  675. sliding_window=getattr(self, "sliding_window", None),
  676. is_causal=self.is_causal,
  677. use_top_left_mask=self._flash_attn_uses_top_left_mask,
  678. )
  679. attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
  680. attn_output = self.o_proj(attn_output)
  681. if not output_attentions:
  682. attn_weights = None
  683. return attn_output, attn_weights
  684. # NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaSdpaAttention with Gemma->Mimi
  685. # TODO cyril: modular
  686. class MimiSdpaAttention(MimiAttention):
  687. """
  688. Mimi attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
  689. `MimiAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
  690. SDPA API.
  691. """
  692. # Adapted from MimiAttention.forward
  693. def forward(
  694. self,
  695. hidden_states: torch.Tensor,
  696. attention_mask: torch.Tensor | None = None,
  697. position_ids: torch.LongTensor | None = None,
  698. past_key_values: Cache | None = None,
  699. output_attentions: bool = False,
  700. use_cache: bool = False,
  701. **kwargs,
  702. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  703. if output_attentions:
  704. logger.warning_once(
  705. f"{self.__class__.__name__} does not support `output_attentions=True`. The returned attention weights will "
  706. "be `None`. If you want to get attention weights, please set `attn_implementation='eager'` when loading the model."
  707. )
  708. bsz, q_len, _ = hidden_states.size()
  709. query_states = self.q_proj(hidden_states)
  710. key_states = self.k_proj(hidden_states)
  711. value_states = self.v_proj(hidden_states)
  712. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  713. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  714. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  715. cos, sin = self.rotary_emb(value_states, position_ids)
  716. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  717. if past_key_values is not None:
  718. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  719. key_states = repeat_kv(key_states, self.num_key_value_groups)
  720. value_states = repeat_kv(value_states, self.num_key_value_groups)
  721. causal_mask = attention_mask
  722. if attention_mask is not None:
  723. causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
  724. # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
  725. # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
  726. is_causal = causal_mask is None and q_len > 1
  727. attn_output = torch.nn.functional.scaled_dot_product_attention(
  728. query_states,
  729. key_states,
  730. value_states,
  731. attn_mask=causal_mask,
  732. dropout_p=self.attention_dropout if self.training else 0.0,
  733. is_causal=is_causal,
  734. )
  735. attn_output = attn_output.transpose(1, 2).contiguous()
  736. attn_output = attn_output.view(bsz, q_len, -1)
  737. attn_output = self.o_proj(attn_output)
  738. return attn_output, None
  739. MIMI_ATTENTION_CLASSES = {
  740. "eager": MimiAttention,
  741. "flash_attention_2": MimiFlashAttention2,
  742. "sdpa": MimiSdpaAttention,
  743. }
  744. class MimiTransformerLayer(GradientCheckpointingLayer):
  745. def __init__(self, config: MimiConfig, layer_idx: int):
  746. super().__init__()
  747. self.hidden_size = config.hidden_size
  748. self.self_attn = MIMI_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
  749. self.mlp = MimiMLP(config)
  750. self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
  751. self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
  752. self.self_attn_layer_scale = MimiLayerScale(config)
  753. self.mlp_layer_scale = MimiLayerScale(config)
  754. def forward(
  755. self,
  756. hidden_states: torch.Tensor,
  757. attention_mask: torch.Tensor | None = None,
  758. position_ids: torch.LongTensor | None = None,
  759. past_key_values: Cache | None = None,
  760. output_attentions: bool | None = False,
  761. use_cache: bool | None = False,
  762. **kwargs,
  763. ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
  764. """
  765. Args:
  766. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  767. attention_mask (`torch.FloatTensor`, *optional*):
  768. attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
  769. query_sequence_length, key_sequence_length)` if default attention is used.
  770. output_attentions (`bool`, *optional*):
  771. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  772. returned tensors for more detail.
  773. use_cache (`bool`, *optional*):
  774. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  775. (see `past_key_values`).
  776. past_key_values (`Cache`, *optional*): cached past key and value projection states
  777. kwargs (`dict`, *optional*):
  778. Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
  779. into the model
  780. """
  781. residual = hidden_states
  782. hidden_states = self.input_layernorm(hidden_states)
  783. # Self Attention
  784. hidden_states, self_attn_weights = self.self_attn(
  785. hidden_states=hidden_states,
  786. attention_mask=attention_mask,
  787. position_ids=position_ids,
  788. past_key_values=past_key_values,
  789. output_attentions=output_attentions,
  790. use_cache=use_cache,
  791. **kwargs,
  792. )
  793. hidden_states = residual + self.self_attn_layer_scale(hidden_states)
  794. # Fully Connected
  795. residual = hidden_states
  796. hidden_states = self.post_attention_layernorm(hidden_states)
  797. hidden_states = self.mlp(hidden_states)
  798. hidden_states = residual + self.mlp_layer_scale(hidden_states)
  799. outputs = (hidden_states,)
  800. if output_attentions:
  801. outputs += (self_attn_weights,)
  802. return outputs
  803. class MimiTransformerModel(nn.Module):
  804. """
  805. Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MimiTransformerLayer`]
  806. Args:
  807. config: MimiConfig
  808. """
  809. def __init__(self, config: MimiConfig):
  810. super().__init__()
  811. self.layers = nn.ModuleList(
  812. [MimiTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  813. )
  814. self._attn_implementation = config._attn_implementation
  815. self.gradient_checkpointing = False
  816. self.config = config
  817. def forward(
  818. self,
  819. hidden_states: torch.LongTensor | None = None,
  820. attention_mask: torch.Tensor | None = None,
  821. position_ids: torch.LongTensor | None = None,
  822. past_key_values: Cache | None = None,
  823. use_cache: bool | None = None,
  824. output_attentions: bool | None = None,
  825. output_hidden_states: bool | None = None,
  826. return_dict: bool | None = None,
  827. **kwargs,
  828. ) -> tuple | BaseModelOutputWithPast:
  829. """
  830. Args:
  831. hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  832. Embedded representation that will be contextualized by the model
  833. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  834. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  835. - 1 for tokens that are **not masked**,
  836. - 0 for tokens that are **masked**.
  837. [What are attention masks?](../glossary#attention-mask)
  838. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  839. [`PreTrainedTokenizer.__call__`] for details.
  840. If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  841. `past_key_values`).
  842. If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
  843. and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
  844. information on the default strategy.
  845. - 1 indicates the head is **not masked**,
  846. - 0 indicates the head is **masked**.
  847. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  848. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  849. config.n_positions - 1]`.
  850. [What are position IDs?](../glossary#position-ids)
  851. past_key_values (`Cache`, *optional*):
  852. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  853. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
  854. have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
  855. of shape `(batch_size, sequence_length)`.
  856. use_cache (`bool`, *optional*):
  857. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
  858. `past_key_values`).
  859. output_attentions (`bool`, *optional*):
  860. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  861. tensors for more detail.
  862. output_hidden_states (`bool`, *optional*):
  863. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  864. more detail.
  865. return_dict (`bool`, *optional*):
  866. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  867. """
  868. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  869. output_hidden_states = (
  870. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  871. )
  872. use_cache = use_cache if use_cache is not None else self.config.use_cache
  873. return_dict = return_dict if return_dict is not None else self.config.return_dict
  874. if self.gradient_checkpointing and self.training and use_cache:
  875. logger.warning_once(
  876. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
  877. )
  878. use_cache = False
  879. if use_cache and past_key_values is None:
  880. past_key_values = DynamicCache(config=self.config)
  881. if position_ids is None:
  882. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  883. position_ids = torch.arange(hidden_states.shape[1], device=hidden_states.device) + past_seen_tokens
  884. position_ids = position_ids.unsqueeze(0)
  885. causal_mask = create_sliding_window_causal_mask(
  886. config=self.config,
  887. inputs_embeds=hidden_states,
  888. attention_mask=attention_mask,
  889. past_key_values=past_key_values,
  890. position_ids=position_ids,
  891. )
  892. # decoder layers
  893. all_hidden_states = () if output_hidden_states else None
  894. all_self_attns = () if output_attentions else None
  895. for decoder_layer in self.layers:
  896. if output_hidden_states:
  897. all_hidden_states += (hidden_states,)
  898. layer_outputs = decoder_layer(
  899. hidden_states,
  900. attention_mask=causal_mask,
  901. position_ids=position_ids,
  902. past_key_values=past_key_values,
  903. output_attentions=output_attentions,
  904. use_cache=use_cache,
  905. )
  906. hidden_states = layer_outputs[0]
  907. if output_attentions:
  908. all_self_attns += (layer_outputs[1],)
  909. # add hidden states from the last decoder layer
  910. if output_hidden_states:
  911. all_hidden_states += (hidden_states,)
  912. if not return_dict:
  913. return tuple(
  914. v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None
  915. )
  916. return BaseModelOutputWithPast(
  917. last_hidden_state=hidden_states,
  918. past_key_values=past_key_values,
  919. hidden_states=all_hidden_states,
  920. attentions=all_self_attns,
  921. )
  922. class MimiDecoder(nn.Module):
  923. """SEANet decoder as used by Mimi."""
  924. def __init__(self, config: MimiConfig):
  925. super().__init__()
  926. scaling = int(2 ** len(config.upsampling_ratios))
  927. model = [MimiConv1d(config, config.hidden_size, scaling * config.num_filters, config.kernel_size)]
  928. # Upsample to raw audio scale
  929. for ratio in config.upsampling_ratios:
  930. current_scale = scaling * config.num_filters
  931. # Add upsampling layers
  932. model += [nn.ELU()]
  933. model += [
  934. MimiConvTranspose1d(config, current_scale, current_scale // 2, kernel_size=ratio * 2, stride=ratio)
  935. ]
  936. # Add residual layers
  937. for j in range(config.num_residual_layers):
  938. model += [MimiResnetBlock(config, current_scale // 2, (config.dilation_growth_rate**j, 1))]
  939. scaling //= 2
  940. # Add final layers
  941. model += [nn.ELU()]
  942. model += [MimiConv1d(config, config.num_filters, config.audio_channels, config.last_kernel_size)]
  943. self.layers = nn.ModuleList(model)
  944. # Copied from transformers.models.encodec.modeling_encodec.EncodecDecoder.forward
  945. def forward(self, hidden_states):
  946. for layer in self.layers:
  947. hidden_states = layer(hidden_states)
  948. return hidden_states
  949. class MimiEuclideanCodebook(nn.Module):
  950. """Codebook with Euclidean distance."""
  951. def __init__(self, config: MimiConfig, epsilon: float = 1e-5):
  952. super().__init__()
  953. embed = torch.zeros(config.codebook_size, config.codebook_dim)
  954. self.codebook_size = config.codebook_size
  955. self.register_buffer("initialized", torch.tensor([True], dtype=torch.float32))
  956. self.register_buffer("cluster_usage", torch.ones(config.codebook_size))
  957. self.register_buffer("embed_sum", embed)
  958. self._embed = None
  959. self.epsilon = epsilon
  960. @property
  961. def embed(self) -> torch.Tensor:
  962. if self._embed is None:
  963. self._embed = self.embed_sum / self.cluster_usage.clamp(min=self.epsilon)[:, None]
  964. return self._embed
  965. def quantize(self, hidden_states):
  966. # Projects each vector in `hidden_states` over the nearest centroid and return its index.
  967. # `hidden_states` should be `[N, D]` with `N` the number of input vectors and `D` the dimension.
  968. dists = torch.cdist(hidden_states[None].float(), self.embed[None].float(), p=2)[0]
  969. embed_ind = dists.argmin(dim=-1)
  970. return embed_ind
  971. # Copied from transformers.models.encodec.modeling_encodec.EncodecEuclideanCodebook.encode
  972. def encode(self, hidden_states):
  973. shape = hidden_states.shape
  974. # pre-process
  975. hidden_states = hidden_states.reshape((-1, shape[-1]))
  976. # quantize
  977. embed_ind = self.quantize(hidden_states)
  978. # post-process
  979. embed_ind = embed_ind.view(*shape[:-1])
  980. return embed_ind
  981. # Copied from transformers.models.encodec.modeling_encodec.EncodecEuclideanCodebook.decode
  982. def decode(self, embed_ind):
  983. quantize = nn.functional.embedding(embed_ind, self.embed)
  984. return quantize
  985. # Copied from transformers.models.encodec.modeling_encodec.EncodecVectorQuantization with Encodec->Mimi
  986. class MimiVectorQuantization(nn.Module):
  987. """
  988. Vector quantization implementation. Currently supports only euclidean distance.
  989. """
  990. def __init__(self, config: MimiConfig):
  991. super().__init__()
  992. self.codebook = MimiEuclideanCodebook(config)
  993. def encode(self, hidden_states):
  994. hidden_states = hidden_states.permute(0, 2, 1)
  995. embed_in = self.codebook.encode(hidden_states)
  996. return embed_in
  997. def decode(self, embed_ind):
  998. quantize = self.codebook.decode(embed_ind)
  999. quantize = quantize.permute(0, 2, 1)
  1000. return quantize
  1001. class MimiResidualVectorQuantizer(nn.Module):
  1002. """Residual Vector Quantizer."""
  1003. def __init__(self, config: MimiConfig, num_quantizers: int | None = None):
  1004. super().__init__()
  1005. self.codebook_size = config.codebook_size
  1006. self.frame_rate = config.frame_rate
  1007. self.num_quantizers = num_quantizers if num_quantizers is not None else config.num_quantizers
  1008. self.layers = nn.ModuleList([MimiVectorQuantization(config) for _ in range(self.num_quantizers)])
  1009. self.input_proj = None
  1010. self.output_proj = None
  1011. if config.vector_quantization_hidden_dimension != config.hidden_size:
  1012. self.input_proj = torch.nn.Conv1d(
  1013. config.hidden_size, config.vector_quantization_hidden_dimension, 1, bias=False
  1014. )
  1015. self.output_proj = torch.nn.Conv1d(
  1016. config.vector_quantization_hidden_dimension, config.hidden_size, 1, bias=False
  1017. )
  1018. def encode(self, embeddings: torch.Tensor, num_quantizers: int | None = None) -> torch.Tensor:
  1019. """
  1020. Encode a given input tensor with the specified frame rate at the given number of quantizers / codebooks. The RVQ encode method sets
  1021. the appropriate number of quantizers to use and returns indices for each quantizer.
  1022. """
  1023. if self.input_proj is not None:
  1024. embeddings = self.input_proj(embeddings)
  1025. num_quantizers = num_quantizers if num_quantizers is not None else self.num_quantizers
  1026. residual = embeddings
  1027. all_indices = []
  1028. for layer in self.layers[:num_quantizers]:
  1029. indices = layer.encode(residual)
  1030. quantized = layer.decode(indices)
  1031. residual = residual - quantized
  1032. all_indices.append(indices)
  1033. out_indices = torch.stack(all_indices)
  1034. return out_indices
  1035. def decode(self, codes: torch.Tensor) -> torch.Tensor:
  1036. """Decode the given codes of shape [B, K, T] to the quantized representation."""
  1037. quantized_out = torch.tensor(0.0, device=codes.device)
  1038. codes = codes.transpose(0, 1)
  1039. for i, indices in enumerate(codes):
  1040. layer = self.layers[i]
  1041. quantized = layer.decode(indices)
  1042. quantized_out = quantized_out + quantized
  1043. if self.output_proj is not None:
  1044. quantized_out = self.output_proj(quantized_out)
  1045. return quantized_out
  1046. class MimiSplitResidualVectorQuantizer(nn.Module):
  1047. """Split Residual Vector Quantizer."""
  1048. def __init__(self, config: MimiConfig):
  1049. super().__init__()
  1050. self.codebook_size = config.codebook_size
  1051. self.frame_rate = config.frame_rate
  1052. self.max_num_quantizers = config.num_quantizers
  1053. self.num_semantic_quantizers = config.num_semantic_quantizers
  1054. self.num_acoustic_quantizers = config.num_quantizers - config.num_semantic_quantizers
  1055. self.semantic_residual_vector_quantizer = MimiResidualVectorQuantizer(config, self.num_semantic_quantizers)
  1056. self.acoustic_residual_vector_quantizer = MimiResidualVectorQuantizer(config, self.num_acoustic_quantizers)
  1057. def encode(self, embeddings: torch.Tensor, num_quantizers: float | None = None) -> torch.Tensor:
  1058. """
  1059. Encode a given input tensor with the specified frame rate at the given number of quantizers / codebooks. The RVQ encode method sets
  1060. the appropriate number of quantizers to use and returns indices for each quantizer.
  1061. """
  1062. num_quantizers = self.max_num_quantizers if num_quantizers is None else num_quantizers
  1063. if num_quantizers > self.max_num_quantizers:
  1064. raise ValueError(
  1065. f"The number of quantizers (i.e codebooks) asked should be lower than the total number of quantizers {self.max_num_quantizers}, but is currently {num_quantizers}."
  1066. )
  1067. if num_quantizers < self.num_semantic_quantizers:
  1068. raise ValueError(
  1069. f"The number of quantizers (i.e codebooks) asked should be higher than the number of semantic quantizers {self.num_semantic_quantizers}, but is currently {num_quantizers}."
  1070. )
  1071. # codes is [K, B, T], with T frames, K nb of codebooks.
  1072. codes = self.semantic_residual_vector_quantizer.encode(embeddings)
  1073. if num_quantizers > self.num_semantic_quantizers:
  1074. acoustic_codes = self.acoustic_residual_vector_quantizer.encode(
  1075. embeddings, num_quantizers=num_quantizers - self.num_semantic_quantizers
  1076. )
  1077. codes = torch.cat([codes, acoustic_codes], dim=0)
  1078. return codes
  1079. def decode(self, codes: torch.Tensor) -> torch.Tensor:
  1080. """Decode the given codes to the quantized representation."""
  1081. # The first num_semantic_quantizers codebooks are decoded using the semantic RVQ
  1082. quantized_out = self.semantic_residual_vector_quantizer.decode(codes[:, : self.num_semantic_quantizers])
  1083. # The rest of the codebooks are decoded using the acoustic RVQ
  1084. if codes.shape[1] > self.num_semantic_quantizers:
  1085. quantized_out += self.acoustic_residual_vector_quantizer.decode(codes[:, self.num_semantic_quantizers :])
  1086. return quantized_out
  1087. @auto_docstring
  1088. class MimiPreTrainedModel(PreTrainedModel):
  1089. config: MimiConfig
  1090. base_model_prefix = "mimi"
  1091. main_input_name = "input_values"
  1092. input_modalities = "audio"
  1093. supports_gradient_checkpointing = True
  1094. _no_split_modules = ["MimiResidualVectorQuantizer", "MimiTransformerLayer"]
  1095. _skip_keys_device_placement = "past_key_values"
  1096. _supports_flash_attn = True
  1097. _supports_sdpa = True
  1098. _can_compile_fullgraph = True
  1099. @torch.no_grad()
  1100. def _init_weights(self, module):
  1101. """Initialize the weights"""
  1102. if isinstance(module, nn.Linear):
  1103. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  1104. if module.bias is not None:
  1105. init.zeros_(module.bias)
  1106. elif isinstance(module, nn.LayerNorm):
  1107. init.zeros_(module.bias)
  1108. init.ones_(module.weight)
  1109. elif isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)):
  1110. init.kaiming_normal_(module.weight)
  1111. if module.bias is not None:
  1112. k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
  1113. init.uniform_(module.bias, a=-k, b=k)
  1114. elif isinstance(module, MimiLayerScale):
  1115. init.constant_(module.scale, self.config.layer_scale_initial_scale)
  1116. elif isinstance(module, MimiConv1d):
  1117. kernel_size = module.conv.kernel_size[0]
  1118. stride = module.conv.stride[0]
  1119. dilation = module.conv.dilation[0]
  1120. kernel_size = (kernel_size - 1) * dilation + 1
  1121. init.constant_(module.stride, stride)
  1122. init.constant_(module.kernel_size, kernel_size)
  1123. init.constant_(module.padding_total, kernel_size - stride)
  1124. elif isinstance(module, MimiEuclideanCodebook):
  1125. init.ones_(module.initialized)
  1126. init.ones_(module.cluster_usage)
  1127. init.zeros_(module.embed_sum)
  1128. elif isinstance(module, MimiRotaryEmbedding):
  1129. rope_fn = (
  1130. ROPE_INIT_FUNCTIONS[module.rope_type]
  1131. if module.rope_type != "default"
  1132. else module.compute_default_rope_parameters
  1133. )
  1134. buffer_value, _ = rope_fn(module.config)
  1135. init.copy_(module.inv_freq, buffer_value)
  1136. init.copy_(module.original_inv_freq, buffer_value)
  1137. @auto_docstring(
  1138. custom_intro="""
  1139. The Mimi neural audio codec model.
  1140. """
  1141. )
  1142. class MimiModel(MimiPreTrainedModel):
  1143. def __init__(self, config: MimiConfig):
  1144. super().__init__(config)
  1145. self.config = config
  1146. self.encoder = MimiEncoder(config)
  1147. self.encoder_transformer = MimiTransformerModel(config)
  1148. self.downsample = None
  1149. self.upsample = None
  1150. if config.frame_rate != config.encodec_frame_rate:
  1151. self.downsample = MimiConv1d(
  1152. config,
  1153. config.hidden_size,
  1154. config.hidden_size,
  1155. kernel_size=2 * int(config.encodec_frame_rate / config.frame_rate),
  1156. stride=2,
  1157. bias=False,
  1158. pad_mode="replicate",
  1159. layer_idx=len(self.encoder._mimiconv1d_layer_names),
  1160. )
  1161. self.upsample = MimiConvTranspose1d(
  1162. config,
  1163. config.hidden_size,
  1164. config.hidden_size,
  1165. kernel_size=2 * int(config.encodec_frame_rate / config.frame_rate),
  1166. stride=2,
  1167. bias=False,
  1168. groups=config.upsample_groups,
  1169. )
  1170. self.decoder_transformer = MimiTransformerModel(config)
  1171. self.decoder = MimiDecoder(config)
  1172. self.quantizer = MimiSplitResidualVectorQuantizer(config)
  1173. self.bits_per_codebook = int(math.log2(self.config.codebook_size))
  1174. if 2**self.bits_per_codebook != self.config.codebook_size:
  1175. raise ValueError("The codebook_size must be a power of 2.")
  1176. # Initialize weights and apply final processing
  1177. self.post_init()
  1178. def _encode_frame(
  1179. self,
  1180. input_values: torch.Tensor,
  1181. num_quantizers: int,
  1182. padding_mask: int,
  1183. past_key_values: Cache | None = None,
  1184. padding_cache: MimiConv1dPaddingCache | None = None,
  1185. use_streaming: bool | None = None,
  1186. return_dict: bool | None = None,
  1187. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  1188. """
  1189. Encodes the given input using the underlying VQVAE. The padding mask is required to compute the correct scale.
  1190. """
  1191. # TODO: @eustlb, let's make the encoder support padding_mask so that batched inputs are supported.
  1192. embeddings = self.encoder(input_values, padding_cache=padding_cache)
  1193. # TODO: @eustlb, convert the padding mask to attention mask.
  1194. encoder_outputs = self.encoder_transformer(
  1195. embeddings.transpose(1, 2),
  1196. past_key_values=past_key_values,
  1197. use_cache=use_streaming,
  1198. return_dict=return_dict,
  1199. )
  1200. if return_dict:
  1201. past_key_values = encoder_outputs.get("past_key_values")
  1202. elif len(encoder_outputs) > 1:
  1203. past_key_values = encoder_outputs[1]
  1204. embeddings = encoder_outputs[0].transpose(1, 2)
  1205. embeddings = self.downsample(embeddings, padding_cache=padding_cache)
  1206. codes = self.quantizer.encode(embeddings, num_quantizers)
  1207. codes = codes.transpose(0, 1)
  1208. return codes, past_key_values, padding_cache
  1209. def get_encoded_length(self, input_length: torch.LongTensor) -> torch.LongTensor:
  1210. """
  1211. Return the number of frames of the encoded audio waveform.
  1212. """
  1213. output_length = input_length
  1214. # encoder
  1215. for layer_name in self.encoder._mimiconv1d_layer_names:
  1216. output_length = self.encoder.get_submodule(layer_name)._get_output_length(output_length)
  1217. # downsample
  1218. output_length = self.downsample._get_output_length(output_length)
  1219. return output_length
  1220. def get_audio_codes_mask(self, padding_mask: torch.Tensor, padding_side: str = "right"):
  1221. """
  1222. Get the mask for the audio codes from the original padding mask.
  1223. """
  1224. encoded_lengths = self.get_encoded_length(padding_mask.sum(dim=-1))
  1225. audio_codes_mask = torch.arange(encoded_lengths.max(), device=encoded_lengths.device).expand(
  1226. len(encoded_lengths), -1
  1227. )
  1228. audio_codes_mask = audio_codes_mask < encoded_lengths.unsqueeze(1)
  1229. audio_codes_mask = audio_codes_mask.to(padding_mask.device)
  1230. if padding_side == "right":
  1231. return audio_codes_mask
  1232. else:
  1233. return audio_codes_mask.flip(dims=[-1])
  1234. def encode(
  1235. self,
  1236. input_values: torch.Tensor,
  1237. padding_mask: torch.Tensor | None = None,
  1238. num_quantizers: float | None = None,
  1239. encoder_past_key_values: Cache | None = None,
  1240. padding_cache: MimiConv1dPaddingCache | None = None,
  1241. use_streaming: bool | None = None,
  1242. return_dict: bool | None = None,
  1243. ) -> tuple[torch.Tensor, torch.Tensor | None] | MimiEncoderOutput:
  1244. """
  1245. Encodes the input audio waveform into discrete codes.
  1246. Args:
  1247. input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
  1248. Float values of the input audio waveform.
  1249. padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
  1250. Indicates which inputs are to be ignored due to padding, where elements are either 1 for *not masked* or 0
  1251. for *masked*.
  1252. num_quantizers (`int`, *optional*):
  1253. Number of quantizers (i.e codebooks) to use. By default, all quantizers are used.
  1254. encoder_past_key_values (`Cache`, *optional*):
  1255. Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the encoder transformer.
  1256. This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
  1257. The model will output the same cache format that is fed as input.
  1258. If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes (those that don't
  1259. have their past key value states given to this model).
  1260. return_dict (`bool`, *optional*):
  1261. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  1262. Returns:
  1263. `codebook` of shape `[batch_size, num_codebooks, frames]`, the discrete encoded codes for the input audio waveform.
  1264. """
  1265. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1266. use_streaming = use_streaming if use_streaming is not None else self.config.use_streaming
  1267. num_quantizers = self.config.num_quantizers if num_quantizers is None else num_quantizers
  1268. if num_quantizers > self.config.num_quantizers:
  1269. raise ValueError(
  1270. f"The number of quantizers (i.e codebooks) asked should be lower than the total number of quantizers {self.config.num_quantizers}, but is currently {num_quantizers}."
  1271. )
  1272. _, channels, input_length = input_values.shape
  1273. if channels < 1 or channels > 2:
  1274. raise ValueError(f"Number of audio channels must be 1 or 2, but got {channels}")
  1275. if padding_mask is None:
  1276. padding_mask = torch.ones_like(input_values).bool()
  1277. if use_streaming and padding_cache is None:
  1278. per_layer_padding, per_layer_padding_mode, per_layer_in_channels = [], [], []
  1279. for layer_name in self.encoder._mimiconv1d_layer_names:
  1280. per_layer_padding.append(self.encoder.get_submodule(layer_name).padding_total)
  1281. per_layer_padding_mode.append(self.encoder.get_submodule(layer_name).pad_mode)
  1282. per_layer_in_channels.append(self.encoder.get_submodule(layer_name).in_channels)
  1283. # downsample layer
  1284. per_layer_padding.append(self.downsample.padding_total)
  1285. per_layer_padding_mode.append(self.downsample.pad_mode)
  1286. per_layer_in_channels.append(self.downsample.in_channels)
  1287. padding_cache = MimiConv1dPaddingCache(
  1288. num_layers=len(self.encoder._mimiconv1d_layer_names) + 1,
  1289. per_layer_padding=per_layer_padding,
  1290. per_layer_padding_mode=per_layer_padding_mode,
  1291. per_layer_in_channels=per_layer_in_channels,
  1292. )
  1293. encoded_frames, encoder_past_key_values, padding_cache = self._encode_frame(
  1294. input_values,
  1295. num_quantizers,
  1296. padding_mask.bool(),
  1297. past_key_values=encoder_past_key_values,
  1298. padding_cache=padding_cache,
  1299. use_streaming=use_streaming,
  1300. return_dict=return_dict,
  1301. )
  1302. if not return_dict:
  1303. return (
  1304. encoded_frames,
  1305. encoder_past_key_values,
  1306. padding_cache,
  1307. )
  1308. return MimiEncoderOutput(encoded_frames, encoder_past_key_values, padding_cache)
  1309. def _decode_frame(
  1310. self,
  1311. codes: torch.Tensor,
  1312. past_key_values: Cache | None = None,
  1313. return_dict: bool | None = None,
  1314. ) -> torch.Tensor:
  1315. embeddings = self.quantizer.decode(codes)
  1316. embeddings = self.upsample(embeddings)
  1317. decoder_outputs = self.decoder_transformer(
  1318. embeddings.transpose(1, 2), past_key_values=past_key_values, return_dict=return_dict
  1319. )
  1320. if return_dict:
  1321. past_key_values = decoder_outputs.get("past_key_values")
  1322. elif len(decoder_outputs) > 1:
  1323. past_key_values = decoder_outputs[1]
  1324. embeddings = decoder_outputs[0].transpose(1, 2)
  1325. outputs = self.decoder(embeddings)
  1326. return outputs, past_key_values
  1327. def decode(
  1328. self,
  1329. audio_codes: torch.Tensor,
  1330. padding_mask: torch.Tensor | None = None,
  1331. decoder_past_key_values: Cache | None = None,
  1332. return_dict: bool | None = None,
  1333. ) -> tuple[torch.Tensor, torch.Tensor] | MimiDecoderOutput:
  1334. """
  1335. Decodes the given frames into an output audio waveform.
  1336. Note that the output might be a bit bigger than the input. In that case, any extra steps at the end can be
  1337. trimmed.
  1338. Args:
  1339. audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`, *optional*):
  1340. Discret code embeddings computed using `model.encode`.
  1341. padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
  1342. Indicates which inputs are to be ignored due to padding, where elements are either 1 for *not masked* or 0
  1343. for *masked*.
  1344. decoder_past_key_values (`Cache`, *optional*):
  1345. Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the decoder transformer.
  1346. This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
  1347. The model will output the same cache format that is fed as input.
  1348. If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes (those that don't
  1349. have their past key value states given to this model).
  1350. return_dict (`bool`, *optional*):
  1351. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  1352. """
  1353. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1354. audio_values, decoder_past_key_values = self._decode_frame(
  1355. audio_codes, past_key_values=decoder_past_key_values, return_dict=return_dict
  1356. )
  1357. # truncate based on padding mask
  1358. if padding_mask is not None and padding_mask.shape[-1] < audio_values.shape[-1]:
  1359. audio_values = audio_values[..., : padding_mask.shape[-1]]
  1360. if not return_dict:
  1361. return (
  1362. audio_values,
  1363. decoder_past_key_values,
  1364. )
  1365. return MimiDecoderOutput(audio_values, decoder_past_key_values)
  1366. @auto_docstring
  1367. def forward(
  1368. self,
  1369. input_values: torch.Tensor,
  1370. padding_mask: torch.Tensor | None = None,
  1371. num_quantizers: int | None = None,
  1372. audio_codes: torch.Tensor | None = None,
  1373. encoder_past_key_values: Cache | None = None,
  1374. decoder_past_key_values: Cache | None = None,
  1375. return_dict: bool | None = None,
  1376. **kwargs,
  1377. ) -> tuple[torch.Tensor, torch.Tensor] | MimiOutput:
  1378. r"""
  1379. input_values (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*):
  1380. Raw audio input converted to Float.
  1381. padding_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1382. Indicates which inputs are to be ignored due to padding, where elements are either 1 for *not masked* or 0
  1383. for *masked*.
  1384. num_quantizers (`int`, *optional*):
  1385. Number of quantizers (i.e codebooks) to use. By default, all quantizers are used.
  1386. audio_codes (`torch.LongTensor` of shape `(batch_size, num_quantizers, codes_length)`, *optional*):
  1387. Discret code embeddings computed using `model.encode`.
  1388. encoder_past_key_values (`Cache`, *optional*):
  1389. Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the encoder transformer.
  1390. This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
  1391. The model will output the same cache format that is fed as input.
  1392. If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes (those that don't
  1393. have their past key value states given to this model).
  1394. decoder_past_key_values (`Cache`, *optional*):
  1395. Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the decoder transformer.
  1396. This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
  1397. The model will output the same cache format that is fed as input.
  1398. If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes (those that don't
  1399. have their past key value states given to this model).
  1400. Examples:
  1401. ```python
  1402. >>> from datasets import load_dataset
  1403. >>> from transformers import AutoFeatureExtractor, MimiModel
  1404. >>> dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example")
  1405. >>> audio_sample = dataset["train"]["audio"][0]["array"]
  1406. >>> model_id = "kyutai/mimi"
  1407. >>> model = MimiModel.from_pretrained(model_id)
  1408. >>> feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
  1409. >>> inputs = feature_extractor(raw_audio=audio_sample, return_tensors="pt")
  1410. >>> outputs = model(**inputs)
  1411. >>> audio_codes = outputs.audio_codes
  1412. >>> audio_values = outputs.audio_values
  1413. ```"""
  1414. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1415. if padding_mask is None:
  1416. padding_mask = torch.ones_like(input_values).bool()
  1417. if audio_codes is None:
  1418. encoder_outputs = self.encode(
  1419. input_values, padding_mask, num_quantizers, encoder_past_key_values, return_dict=return_dict
  1420. )
  1421. audio_codes = encoder_outputs[0]
  1422. if return_dict:
  1423. encoder_past_key_values = encoder_outputs.get("past_key_values")
  1424. elif len(encoder_outputs) > 1:
  1425. encoder_past_key_values = encoder_outputs[1]
  1426. decoder_outputs = self.decode(audio_codes, padding_mask, decoder_past_key_values, return_dict=return_dict)
  1427. audio_values = decoder_outputs[0]
  1428. if return_dict:
  1429. decoder_past_key_values = decoder_outputs.get("past_key_values")
  1430. elif len(decoder_outputs) > 1:
  1431. decoder_past_key_values = decoder_outputs[1]
  1432. if not return_dict:
  1433. return (audio_codes, audio_values, encoder_past_key_values, decoder_past_key_values)
  1434. return MimiOutput(
  1435. audio_codes=audio_codes,
  1436. audio_values=audio_values,
  1437. encoder_past_key_values=encoder_past_key_values,
  1438. decoder_past_key_values=decoder_past_key_values,
  1439. )
  1440. __all__ = ["MimiModel", "MimiPreTrainedModel"]