modeling_bark.py 65 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521
  1. # Copyright 2023 The Suno AI Authors and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch BARK model."""
  15. import math
  16. import numpy as np
  17. import torch
  18. from torch import nn
  19. from torch.nn import functional as F
  20. from ... import initialization as init
  21. from ...cache_utils import Cache, DynamicCache
  22. from ...generation import GenerationMixin
  23. from ...generation.logits_process import (
  24. AlternatingCodebooksLogitsProcessor,
  25. BarkEosPrioritizerLogitsProcessor,
  26. SuppressTokensLogitsProcessor,
  27. )
  28. from ...masking_utils import create_bidirectional_mask
  29. from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
  30. from ...modeling_layers import GradientCheckpointingLayer
  31. from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput
  32. from ...modeling_utils import PreTrainedModel
  33. from ...utils import (
  34. auto_docstring,
  35. is_accelerate_available,
  36. is_torch_accelerator_available,
  37. logging,
  38. )
  39. from ...utils.deprecation import deprecate_kwarg
  40. from ..auto import AutoModel
  41. from .configuration_bark import (
  42. BarkCoarseConfig,
  43. BarkConfig,
  44. BarkFineConfig,
  45. BarkSemanticConfig,
  46. BarkSubModelConfig,
  47. )
  48. from .generation_configuration_bark import (
  49. BarkCoarseGenerationConfig,
  50. BarkFineGenerationConfig,
  51. BarkSemanticGenerationConfig,
  52. )
  53. if is_flash_attn_available():
  54. from ...integrations.flash_attention import get_target_dtype
  55. from ...modeling_flash_attention_utils import _flash_attention_forward
  56. logger = logging.get_logger(__name__)
  57. class BarkSelfAttention(nn.Module):
  58. # adapted from GPTNeoSelfAttention and Bark code
  59. # BarkSelfAttention can have two attention type, i.e full attention or causal attention
  60. def __init__(self, config, is_causal=False, layer_idx=None):
  61. super().__init__()
  62. # regularization
  63. self.dropout = config.dropout
  64. self.attn_dropout = nn.Dropout(config.dropout)
  65. self.resid_dropout = nn.Dropout(config.dropout)
  66. self.embed_dim = config.hidden_size
  67. self.num_heads = config.num_heads
  68. self.head_dim = self.embed_dim // self.num_heads
  69. self.config = config
  70. if config.hidden_size % config.num_heads != 0:
  71. raise ValueError(
  72. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  73. f" {self.num_heads})."
  74. )
  75. # key, query, value projections for all heads, but in a batch
  76. self.att_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=config.bias)
  77. # output projection
  78. self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.bias)
  79. self.is_causal = is_causal
  80. self.layer_idx = layer_idx
  81. if is_causal:
  82. block_size = config.block_size
  83. bias = torch.tril(torch.ones((block_size, block_size), dtype=bool)).view(1, 1, block_size, block_size)
  84. self.register_buffer("bias", bias)
  85. # Copied from transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoSelfAttention._split_heads
  86. def _split_heads(self, tensor, num_heads, attn_head_size):
  87. """
  88. Splits hidden_size dim into attn_head_size and num_heads
  89. """
  90. new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
  91. tensor = tensor.view(new_shape)
  92. return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
  93. def _merge_heads(self, tensor, num_heads, attn_head_size):
  94. """
  95. Merges attn_head_size dim and num_attn_heads dim into hidden_size
  96. """
  97. # re-assemble all head outputs side by side
  98. # (batch, num_heads, seq_len, attn_head_size) -> (batch, seq_len, num_heads*attn_head_size)
  99. tensor = tensor.transpose(1, 2).contiguous()
  100. tensor = tensor.view(tensor.size()[:-2] + (num_heads * attn_head_size,))
  101. return tensor
  102. def _attn(self, query, key, value, attention_mask=None):
  103. # unlike GPTNeo's SelfAttention, divide by the square root of the dimension of the query and the key
  104. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * (1.0 / math.sqrt(self.head_dim))
  105. if self.is_causal:
  106. query_length, key_length = query.size(-2), key.size(-2)
  107. # fill the upper left part of the attention weights with inf
  108. attn_weights = attn_weights.masked_fill(
  109. self.bias[:, :, key_length - query_length : key_length, :key_length] == 0,
  110. torch.finfo(attn_weights.dtype).min,
  111. )
  112. if attention_mask is not None:
  113. # Apply the attention mask
  114. attn_weights = attn_weights + attention_mask
  115. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  116. attn_weights = attn_weights.to(value.dtype)
  117. attn_weights = self.attn_dropout(attn_weights)
  118. # (batch, num_heads, seq_len, seq_len) x (batch, num_heads, seq_len, attn_head_size)
  119. # -> (batch, num_heads, seq_len, attn_head_size)
  120. attn_output = torch.matmul(attn_weights, value)
  121. return attn_output, attn_weights
  122. def forward(
  123. self,
  124. hidden_states,
  125. attention_mask=None,
  126. past_key_values=None,
  127. use_cache=False,
  128. output_attentions=False,
  129. **kwargs,
  130. ):
  131. # calculate query, key, values for all heads in batch and move head forward to be the batch dim
  132. query, key, value = self.att_proj(hidden_states).split(self.embed_dim, dim=2)
  133. query = self._split_heads(query, self.num_heads, self.head_dim)
  134. key = self._split_heads(key, self.num_heads, self.head_dim)
  135. value = self._split_heads(value, self.num_heads, self.head_dim)
  136. if past_key_values is not None:
  137. key, value = past_key_values.update(key, value, self.layer_idx)
  138. attn_output, attn_weights = self._attn(query, key, value, attention_mask)
  139. attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
  140. attn_output = self.out_proj(attn_output)
  141. attn_output = self.resid_dropout(attn_output)
  142. return attn_output, attn_weights
  143. class BarkSelfFlashAttention2(BarkSelfAttention):
  144. """
  145. Bark flash attention module. This module inherits from `BarkSelfAttention` as the weights of the module stays
  146. untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
  147. flash attention and deal with padding tokens in case the input contains any of them.
  148. """
  149. def __init__(self, *args, **kwargs):
  150. super().__init__(*args, **kwargs)
  151. # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
  152. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
  153. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
  154. self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
  155. def _split_heads(self, tensor, num_heads, attn_head_size):
  156. """
  157. Splits hidden_size dim into attn_head_size and num_heads
  158. """
  159. new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
  160. tensor = tensor.view(new_shape)
  161. # Flash attention requires the input to have the shape
  162. # batch_size x seq_length x head_dim x hidden_dim - (batch, seq_length, head, head_features)
  163. return tensor
  164. def _merge_heads(self, tensor, num_heads, attn_head_size):
  165. """
  166. Merges attn_head_size dim and num_attn_heads dim into hidden_size
  167. """
  168. # re-assemble all head outputs side by side
  169. # (batch, seq_len, num_heads, attn_head_size) -> (batch, seq_len, num_heads*attn_head_size)
  170. tensor = tensor.view(tensor.size()[:-2] + (num_heads * attn_head_size,))
  171. return tensor
  172. def forward(
  173. self,
  174. hidden_states,
  175. attention_mask=None,
  176. past_key_values=None,
  177. use_cache=False,
  178. output_attentions=False,
  179. **kwargs,
  180. ):
  181. batch_size, query_len, _ = hidden_states.size()
  182. # calculate query, key, values for all heads in batch and move head forward to be the batch dim
  183. query, key, value = self.att_proj(hidden_states).split(self.embed_dim, dim=2)
  184. query = self._split_heads(query, self.num_heads, self.head_dim)
  185. key = self._split_heads(key, self.num_heads, self.head_dim)
  186. value = self._split_heads(value, self.num_heads, self.head_dim)
  187. if past_key_values is not None:
  188. key, value = past_key_values.update(key, value, self.layer_idx)
  189. target_dtype = get_target_dtype(query, self) # if the query is in float32, this is the dtype to cast to for FA
  190. attn_output = _flash_attention_forward(
  191. query,
  192. key,
  193. value,
  194. attention_mask,
  195. query_len,
  196. dropout=self.dropout if self.training else 0.0,
  197. use_top_left_mask=self._flash_attn_uses_top_left_mask,
  198. is_causal=self.is_causal,
  199. target_dtype=target_dtype,
  200. )
  201. attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
  202. attn_output = self.out_proj(attn_output)
  203. attn_output = self.resid_dropout(attn_output)
  204. return attn_output, None
  205. BARK_ATTENTION_CLASSES = {
  206. "eager": BarkSelfAttention,
  207. "flash_attention_2": BarkSelfFlashAttention2,
  208. }
  209. class BarkMLP(nn.Module):
  210. def __init__(self, config):
  211. super().__init__()
  212. self.in_proj = nn.Linear(config.hidden_size, 4 * config.hidden_size, bias=config.bias)
  213. self.out_proj = nn.Linear(4 * config.hidden_size, config.hidden_size, bias=config.bias)
  214. self.dropout = nn.Dropout(config.dropout)
  215. self.gelu = nn.GELU()
  216. def forward(self, hidden_states):
  217. hidden_states = self.in_proj(hidden_states)
  218. hidden_states = self.gelu(hidden_states)
  219. hidden_states = self.out_proj(hidden_states)
  220. hidden_states = self.dropout(hidden_states)
  221. return hidden_states
  222. class BarkBlock(GradientCheckpointingLayer):
  223. def __init__(self, config, is_causal=False, layer_idx=None):
  224. super().__init__()
  225. if is_causal:
  226. # if causal, the layerNorm bias is optional to stick with Bark choice of leaving optional bias
  227. # in AutoRegressive models (corresponding to the "Text" and the "Coarse" modules)
  228. self.layernorm_1 = nn.LayerNorm(config.hidden_size, bias=config.bias)
  229. self.layernorm_2 = nn.LayerNorm(config.hidden_size, bias=config.bias)
  230. else:
  231. self.layernorm_1 = nn.LayerNorm(config.hidden_size)
  232. self.layernorm_2 = nn.LayerNorm(config.hidden_size)
  233. self.attn = BARK_ATTENTION_CLASSES[config._attn_implementation](
  234. config, is_causal=is_causal, layer_idx=layer_idx
  235. )
  236. self.mlp = BarkMLP(config)
  237. def forward(
  238. self,
  239. hidden_states,
  240. past_key_values=None,
  241. attention_mask=None,
  242. use_cache=False,
  243. output_attentions=False,
  244. **kwargs,
  245. ):
  246. intermediary_hidden_states = self.layernorm_1(hidden_states)
  247. attn_outputs = self.attn(
  248. intermediary_hidden_states,
  249. past_key_values=past_key_values,
  250. attention_mask=attention_mask,
  251. use_cache=use_cache,
  252. output_attentions=output_attentions,
  253. )
  254. attn_output = attn_outputs[0] # output_attn: output, present_key_values, (attn_weights)
  255. outputs = attn_outputs[1:]
  256. intermediary_hidden_states = hidden_states + attn_output
  257. intermediary_hidden_states = intermediary_hidden_states + self.mlp(
  258. self.layernorm_2(intermediary_hidden_states)
  259. )
  260. return (intermediary_hidden_states,) + outputs
  261. @auto_docstring
  262. class BarkPreTrainedModel(PreTrainedModel):
  263. config: BarkConfig
  264. supports_gradient_checkpointing = False
  265. _supports_flash_attn = True
  266. @property
  267. def device(self) -> torch.device:
  268. """
  269. `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
  270. device).
  271. """
  272. # if has _hf_hook, has been offloaded so the device has to be found in the hook
  273. if not hasattr(self, "_hf_hook"):
  274. return super().device
  275. for module in self.modules():
  276. if (
  277. hasattr(module, "_hf_hook")
  278. and hasattr(module._hf_hook, "execution_device")
  279. and module._hf_hook.execution_device is not None
  280. ):
  281. return torch.device(module._hf_hook.execution_device)
  282. return super().device
  283. def _init_weights(self, module):
  284. super()._init_weights(module)
  285. if isinstance(module, BarkSelfAttention):
  286. if module.is_causal:
  287. block_size = module.config.block_size
  288. bias = torch.tril(torch.ones((block_size, block_size), dtype=bool)).view(1, 1, block_size, block_size)
  289. init.copy_(module.bias, bias)
  290. # GPT2-like autoregressive model
  291. class BarkCausalModel(BarkPreTrainedModel, GenerationMixin):
  292. config: BarkSubModelConfig
  293. output_modalities = ("audio",)
  294. def __init__(self, config):
  295. super().__init__(config)
  296. self.config = config
  297. # initialize as an autoregressive GPT-like model
  298. self.input_embeds_layer = nn.Embedding(config.input_vocab_size, config.hidden_size)
  299. self.position_embeds_layer = nn.Embedding(config.block_size, config.hidden_size)
  300. self.drop = nn.Dropout(config.dropout)
  301. self.layers = nn.ModuleList([BarkBlock(config, is_causal=True, layer_idx=i) for i in range(config.num_layers)])
  302. self.layernorm_final = nn.LayerNorm(config.hidden_size, bias=config.bias)
  303. self.lm_head = nn.Linear(config.hidden_size, config.output_vocab_size, bias=False)
  304. self.gradient_checkpointing = False
  305. # Initialize weights and apply final processing
  306. self.post_init()
  307. def get_output_embeddings(self):
  308. # NOTE: get_output_embeddings() must return None to prevent accidental weight tying.
  309. # See e.g. https://github.com/huggingface/transformers/pull/39339#discussion_r2219126400
  310. return None
  311. def get_input_embeddings(self):
  312. return self.input_embeds_layer
  313. def set_input_embeddings(self, new_embeddings):
  314. self.input_embeds_layer = new_embeddings
  315. @deprecate_kwarg("input_embeds", version="5.6.0", new_name="inputs_embeds")
  316. @auto_docstring
  317. def forward(
  318. self,
  319. input_ids: torch.Tensor | None = None,
  320. past_key_values: Cache | None = None,
  321. attention_mask: torch.Tensor | None = None,
  322. position_ids: torch.Tensor | None = None,
  323. labels: torch.LongTensor | None = None,
  324. inputs_embeds: torch.Tensor | None = None,
  325. use_cache: bool | None = None,
  326. output_attentions: bool | None = None,
  327. output_hidden_states: bool | None = None,
  328. return_dict: bool | None = None,
  329. **kwargs,
  330. ) -> tuple[torch.Tensor] | CausalLMOutputWithPast:
  331. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  332. output_hidden_states = (
  333. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  334. )
  335. use_cache = use_cache if use_cache is not None else self.config.use_cache
  336. return_dict = return_dict if return_dict is not None else self.config.return_dict
  337. loss = None
  338. if labels is not None:
  339. raise NotImplementedError(
  340. "Training is not implemented yet for Bark - ensure you do not pass `labels` to the model."
  341. )
  342. # Verify if inputs_embeds already exists
  343. # then compute embeddings.
  344. if input_ids is not None and inputs_embeds is not None:
  345. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  346. elif inputs_embeds is not None and past_key_values is None:
  347. # we want to return the inputs_embeds in priority so that it is in line with a weird hack
  348. # of Bark which concatenate two bits of the inputs_embeds on the first forward pass of the semantic model
  349. pass
  350. elif input_ids is not None:
  351. inputs_embeds = self.input_embeds_layer(input_ids) # token embeddings of shape (b, t, n_embd)
  352. elif inputs_embeds is not None:
  353. pass
  354. else:
  355. raise ValueError("You have to specify either input_ids or inputs_embeds")
  356. input_shape = inputs_embeds.size()[:-1]
  357. seq_length = input_shape[-1]
  358. if self.gradient_checkpointing and self.training:
  359. if use_cache:
  360. logger.warning_once(
  361. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  362. )
  363. use_cache = False
  364. if use_cache and past_key_values is None:
  365. past_key_values = DynamicCache(config=self.config)
  366. past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  367. inputs_embeds = inputs_embeds.to(self.position_embeds_layer.weight.device)
  368. if position_ids is None:
  369. position_ids = torch.arange(
  370. past_length,
  371. seq_length + past_length,
  372. dtype=torch.long,
  373. device=self.position_embeds_layer.weight.device,
  374. )
  375. position_ids = position_ids.unsqueeze(0) # shape (1, seq_length)
  376. position_ids = position_ids.to(self.position_embeds_layer.weight.device)
  377. position_embeds = self.position_embeds_layer(position_ids) # position embeddings of shape (1, t, n_embd)
  378. attention_mask = create_bidirectional_mask(
  379. config=self.config,
  380. inputs_embeds=inputs_embeds,
  381. attention_mask=attention_mask,
  382. )
  383. hidden_states = self.drop(inputs_embeds + position_embeds)
  384. output_shape = input_shape + (hidden_states.size(-1),)
  385. all_self_attentions = () if output_attentions else None
  386. all_hidden_states = () if output_hidden_states else None
  387. for i, block in enumerate(self.layers):
  388. if output_hidden_states:
  389. all_hidden_states = all_hidden_states + (hidden_states,)
  390. outputs = block(
  391. hidden_states,
  392. past_key_values=past_key_values,
  393. attention_mask=attention_mask,
  394. use_cache=use_cache,
  395. output_attentions=output_attentions,
  396. )
  397. hidden_states = outputs[0]
  398. if output_attentions:
  399. all_self_attentions = all_self_attentions + (outputs[1],)
  400. hidden_states = self.layernorm_final(hidden_states)
  401. hidden_states = hidden_states.view(output_shape)
  402. # Add last hidden state
  403. if output_hidden_states:
  404. all_hidden_states = all_hidden_states + (hidden_states,)
  405. logits = self.lm_head(hidden_states)
  406. if not return_dict:
  407. return tuple(
  408. v for v in [None, logits, past_key_values, all_hidden_states, all_self_attentions] if v is not None
  409. )
  410. return CausalLMOutputWithPast(
  411. loss=loss,
  412. logits=logits,
  413. past_key_values=past_key_values,
  414. hidden_states=all_hidden_states,
  415. attentions=all_self_attentions,
  416. )
  417. @auto_docstring(
  418. custom_intro="""
  419. Bark semantic (or text) model. It shares the same architecture as the coarse model.
  420. It is a GPT-2 like autoregressive model with a language modeling head on top.
  421. """
  422. )
  423. class BarkSemanticModel(BarkCausalModel):
  424. base_model_prefix = "semantic"
  425. config: BarkSemanticConfig
  426. def generate(
  427. self,
  428. input_ids: torch.Tensor,
  429. semantic_generation_config: BarkSemanticGenerationConfig | None = None,
  430. history_prompt: dict[str, torch.Tensor] | None = None,
  431. attention_mask: torch.Tensor | None = None,
  432. **kwargs,
  433. ) -> torch.LongTensor:
  434. """
  435. Generates text semantic tokens from an input prompt and an additional optional `Bark` speaker prompt.
  436. Args:
  437. input_ids (`Optional[torch.Tensor]` of shape (batch_size, seq_len), *optional*):
  438. Input ids, i.e tokenized input sentences. Will be truncated up to
  439. semantic_generation_config.max_input_semantic_length tokens. Note that the output audios will be as
  440. long as the longest generation among the batch.
  441. semantic_generation_config (`BarkSemanticGenerationConfig`):
  442. Generation config indicating how to generate the semantic tokens.
  443. history_prompt (`Optional[dict[str,torch.Tensor]]`, *optional*):
  444. Optional `Bark` speaker prompt.
  445. attention_mask (`Optional[torch.Tensor]`, *optional*):
  446. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  447. - 1 for tokens that are **not masked**,
  448. - 0 for tokens that are **masked**.
  449. [What are attention masks?](../glossary#attention-mask)
  450. Returns:
  451. torch.LongTensor: Output semantic tokens.
  452. """
  453. if semantic_generation_config is None:
  454. raise ValueError("`semantic_generation_config` has to be provided")
  455. batch_size = input_ids.shape[0]
  456. max_input_semantic_length = semantic_generation_config.max_input_semantic_length
  457. input_ids = input_ids + semantic_generation_config.text_encoding_offset
  458. if attention_mask is not None:
  459. input_ids = input_ids.masked_fill((1 - attention_mask).bool(), semantic_generation_config.text_pad_token)
  460. if history_prompt is not None:
  461. semantic_history = history_prompt["semantic_prompt"][-max_input_semantic_length:]
  462. semantic_history = nn.functional.pad(
  463. semantic_history,
  464. (0, max_input_semantic_length - len(semantic_history)),
  465. value=semantic_generation_config.semantic_pad_token,
  466. mode="constant",
  467. )
  468. else:
  469. semantic_history = torch.full(
  470. (max_input_semantic_length,),
  471. semantic_generation_config.semantic_pad_token,
  472. device=self.device,
  473. dtype=torch.int,
  474. )
  475. semantic_history = torch.repeat_interleave(semantic_history[None], batch_size, dim=0)
  476. infer_array = torch.tensor(
  477. [[semantic_generation_config.semantic_infer_token]] * batch_size, dtype=torch.int
  478. ).to(self.device)
  479. inputs_embeds = torch.cat(
  480. [
  481. self.input_embeds_layer(input_ids[:, :max_input_semantic_length])
  482. + self.input_embeds_layer(semantic_history[:, : max_input_semantic_length + 1]),
  483. self.input_embeds_layer(infer_array),
  484. ],
  485. dim=1,
  486. )
  487. tokens_to_suppress = list(
  488. range(semantic_generation_config.semantic_vocab_size, semantic_generation_config.semantic_pad_token)
  489. )
  490. tokens_to_suppress.extend(
  491. list(range(semantic_generation_config.semantic_pad_token + 1, self.config.output_vocab_size))
  492. )
  493. suppress_tokens_logits_processor = SuppressTokensLogitsProcessor(tokens_to_suppress, device=input_ids.device)
  494. min_eos_p = kwargs.get("min_eos_p", semantic_generation_config.min_eos_p)
  495. early_stopping_logits_processor = BarkEosPrioritizerLogitsProcessor(
  496. eos_token_id=semantic_generation_config.eos_token_id, min_eos_p=min_eos_p, device=input_ids.device
  497. )
  498. # pass input_ids in order to stay consistent with the transformers generate method even though it is not used
  499. # (except to get the input seq_len - that's why we keep the first 257 tokens)
  500. semantic_output = super().generate(
  501. torch.ones((batch_size, max_input_semantic_length + 1), dtype=torch.int, device=self.device),
  502. inputs_embeds=inputs_embeds,
  503. logits_processor=[suppress_tokens_logits_processor, early_stopping_logits_processor],
  504. generation_config=semantic_generation_config,
  505. **kwargs,
  506. ) # size: 10048
  507. # take the generated semantic tokens
  508. if kwargs.get("return_dict_in_generate", False):
  509. semantic_output = semantic_output.sequences[:, max_input_semantic_length + 1 :]
  510. else:
  511. semantic_output = semantic_output[:, max_input_semantic_length + 1 :]
  512. return semantic_output
  513. @auto_docstring(
  514. custom_intro="""
  515. Bark coarse acoustics model.
  516. It shares the same architecture as the semantic (or text) model. It is a GPT-2 like autoregressive model with a
  517. language modeling head on top.
  518. """
  519. )
  520. class BarkCoarseModel(BarkCausalModel):
  521. base_model_prefix = "coarse_acoustics"
  522. config: BarkCoarseConfig
  523. def preprocess_histories(
  524. self,
  525. max_coarse_history: int,
  526. semantic_to_coarse_ratio: int,
  527. batch_size: int,
  528. semantic_generation_config: int,
  529. codebook_size: int,
  530. history_prompt: dict[str, torch.Tensor] | None = None,
  531. ):
  532. """
  533. Preprocess the optional `Bark` speaker prompts before `self.generate`.
  534. Args:
  535. max_coarse_history (`int`):
  536. Maximum size of coarse tokens used.
  537. semantic_to_coarse_ratio (`int`):
  538. Ratio of semantic to coarse frequency
  539. batch_size (`int`):
  540. Batch size, i.e the number of samples.
  541. semantic_generation_config (`BarkSemanticGenerationConfig`):
  542. Generation config indicating how to generate the semantic tokens.
  543. codebook_size (`int`):
  544. Codebook channel size, i.e. the size of the output vocabulary per codebook channel.
  545. history_prompt (`Optional[dict[str,torch.Tensor]]`):
  546. Optional `Bark` speaker prompt.
  547. Returns: Returns:
  548. `tuple(torch.FloatTensor)`:
  549. - **x_semantic_history** (`torch.FloatTensor` -- Processed semantic speaker prompt.
  550. - **x_coarse_history** (`torch.FloatTensor`) -- Processed coarse speaker prompt.
  551. """
  552. if history_prompt is not None:
  553. x_semantic_history = torch.repeat_interleave(history_prompt["semantic_prompt"][None], batch_size, dim=0)
  554. # clone to avoid modifying history_prompt.coarse_prompt
  555. x_coarse_history = history_prompt["coarse_prompt"].clone()
  556. # offset x_coarse_history
  557. if codebook_size is not None:
  558. for n in range(1, x_coarse_history.shape[0]):
  559. # offset
  560. x_coarse_history[n, :] += codebook_size * n
  561. # flatten x_coarse_history
  562. x_coarse_history = torch.transpose(x_coarse_history, 0, 1).reshape(-1)
  563. x_coarse_history = x_coarse_history + semantic_generation_config.semantic_vocab_size
  564. x_coarse_history = torch.repeat_interleave(x_coarse_history[None], batch_size, dim=0)
  565. # e.g: after SEMANTIC_VOCAB_SIZE (10000), 1024 tokens dedicated to first codebook, 1024 next tokens
  566. # dedicated to second codebook.
  567. max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio))
  568. # trim histories correctly
  569. n_semantic_hist_provided = min(
  570. [
  571. max_semantic_history,
  572. x_semantic_history.shape[1] - x_semantic_history.shape[1] % 2,
  573. int(np.floor(x_coarse_history.shape[1] / semantic_to_coarse_ratio)),
  574. ]
  575. )
  576. n_coarse_hist_provided = int(round(n_semantic_hist_provided * semantic_to_coarse_ratio))
  577. x_semantic_history = x_semantic_history[:, -n_semantic_hist_provided:].int()
  578. x_coarse_history = x_coarse_history[:, -n_coarse_hist_provided:].int()
  579. # bit of a hack for time alignment (sounds better) - from Bark original implementation
  580. x_coarse_history = x_coarse_history[:, :-2]
  581. else:
  582. # shape: (batch_size, 0)
  583. x_semantic_history = torch.tensor([[]] * batch_size, dtype=torch.int, device=self.device)
  584. x_coarse_history = torch.tensor([[]] * batch_size, dtype=torch.int, device=self.device)
  585. return x_semantic_history, x_coarse_history
  586. def generate(
  587. self,
  588. semantic_output: torch.Tensor,
  589. semantic_generation_config: BarkSemanticGenerationConfig | None = None,
  590. coarse_generation_config: BarkCoarseGenerationConfig | None = None,
  591. codebook_size: int = 1024,
  592. history_prompt: dict[str, torch.Tensor] | None = None,
  593. return_output_lengths: bool | None = None,
  594. **kwargs,
  595. ) -> torch.LongTensor | tuple[torch.LongTensor, torch.LongTensor]:
  596. """
  597. Generates coarse acoustics tokens from input text semantic tokens and an additional optional `Bark` speaker
  598. prompt.
  599. Args:
  600. semantic_output (`torch.Tensor` of shape (batch_size, seq_len), *optional*):
  601. Input text semantic ids, i.e the output of `BarkSemanticModel.generate`.
  602. semantic_generation_config (`BarkSemanticGenerationConfig`):
  603. Generation config indicating how to generate the semantic tokens.
  604. coarse_generation_config (`BarkCoarseGenerationConfig`):
  605. Generation config indicating how to generate the coarse tokens.
  606. codebook_size (`int`, *optional*, defaults to 1024):
  607. Codebook channel size, i.e. the size of the output vocabulary per codebook channel.
  608. history_prompt (`Optional[dict[str,torch.Tensor]]`, *optional*):
  609. Optional `Bark` speaker prompt.
  610. return_output_lengths (`bool`, *optional*):
  611. Whether or not to return the output lengths. Useful when batching.
  612. Returns:
  613. By default:
  614. torch.LongTensor: Output coarse acoustics tokens.
  615. If `return_output_lengths=True`:
  616. `Tuple(torch.Tensor, torch.Tensor): The output coarse acoustics tokens, and the length of each sample
  617. of the batch.
  618. """
  619. if semantic_generation_config is None:
  620. raise ValueError("`semantic_generation_config` has to be provided")
  621. if coarse_generation_config is None:
  622. raise ValueError("`coarse_generation_config` has to be provided")
  623. max_coarse_input_length = coarse_generation_config.max_coarse_input_length
  624. max_coarse_history = coarse_generation_config.max_coarse_history
  625. sliding_window_len = coarse_generation_config.sliding_window_len
  626. # replace semantic_pad_token (eos_tok and pad_tok here) with coarse_semantic_pad_token i.e the pad_token
  627. # used in the next model
  628. semantic_output.masked_fill_(
  629. semantic_output == semantic_generation_config.semantic_pad_token,
  630. coarse_generation_config.coarse_semantic_pad_token,
  631. )
  632. semantic_to_coarse_ratio = (
  633. coarse_generation_config.coarse_rate_hz
  634. / semantic_generation_config.semantic_rate_hz
  635. * coarse_generation_config.n_coarse_codebooks
  636. )
  637. max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio))
  638. output_lengths = (semantic_output != coarse_generation_config.coarse_semantic_pad_token).sum(1)
  639. output_lengths = torch.floor(
  640. output_lengths * semantic_to_coarse_ratio / coarse_generation_config.n_coarse_codebooks
  641. )
  642. output_lengths = torch.round(output_lengths * coarse_generation_config.n_coarse_codebooks).int()
  643. max_generated_len = torch.max(output_lengths).item()
  644. batch_size = semantic_output.shape[0]
  645. x_semantic_history, x_coarse = self.preprocess_histories(
  646. history_prompt=history_prompt,
  647. max_coarse_history=max_coarse_history,
  648. semantic_to_coarse_ratio=semantic_to_coarse_ratio,
  649. batch_size=batch_size,
  650. semantic_generation_config=semantic_generation_config,
  651. codebook_size=codebook_size,
  652. )
  653. base_semantic_idx = x_semantic_history.shape[1]
  654. semantic_output = torch.hstack([x_semantic_history, semantic_output])
  655. n_window_steps = int(np.ceil(max_generated_len / sliding_window_len))
  656. total_generated_len = 0
  657. len_coarse_history = x_coarse.shape[1]
  658. for _ in range(n_window_steps):
  659. semantic_idx = base_semantic_idx + int(round(total_generated_len / semantic_to_coarse_ratio))
  660. # pad from right side
  661. input_coarse = semantic_output[:, np.max([0, semantic_idx - max_semantic_history]) :]
  662. input_coarse = input_coarse[:, :max_coarse_input_length]
  663. input_coarse = F.pad(
  664. input_coarse,
  665. (0, max_coarse_input_length - input_coarse.shape[-1]),
  666. "constant",
  667. coarse_generation_config.coarse_semantic_pad_token,
  668. )
  669. input_coarse = torch.hstack(
  670. [
  671. input_coarse,
  672. torch.tensor([[coarse_generation_config.coarse_infer_token]] * batch_size, device=self.device),
  673. x_coarse[:, -max_coarse_history:],
  674. ]
  675. )
  676. alternatingLogitsProcessor = AlternatingCodebooksLogitsProcessor(
  677. input_coarse.shape[1],
  678. semantic_generation_config.semantic_vocab_size,
  679. codebook_size,
  680. )
  681. output_coarse = super().generate(
  682. input_coarse,
  683. logits_processor=[alternatingLogitsProcessor],
  684. max_new_tokens=min(sliding_window_len, max_generated_len - total_generated_len),
  685. generation_config=coarse_generation_config,
  686. **kwargs,
  687. )
  688. input_coarse_len = input_coarse.shape[1]
  689. if kwargs.get("return_dict_in_generate", False):
  690. x_coarse = torch.hstack([x_coarse, output_coarse.sequences[:, input_coarse_len:]])
  691. else:
  692. x_coarse = torch.hstack([x_coarse, output_coarse[:, input_coarse_len:]])
  693. total_generated_len = x_coarse.shape[1] - len_coarse_history
  694. del output_coarse
  695. coarse_output = x_coarse[:, len_coarse_history:]
  696. if return_output_lengths:
  697. return coarse_output, output_lengths
  698. return coarse_output
  699. @auto_docstring(
  700. custom_intro="""
  701. Bark fine acoustics model. It is a non-causal GPT-like model with `config.n_codes_total` embedding layers and
  702. language modeling heads, one for each codebook.
  703. """
  704. )
  705. class BarkFineModel(BarkPreTrainedModel):
  706. base_model_prefix = "fine_acoustics"
  707. config: BarkFineConfig
  708. main_input_name = "codebook_idx"
  709. def __init__(self, config):
  710. # non-causal gpt-like model with one embedding layer and one lm_head for each codebook of Encodec
  711. super().__init__(config)
  712. self.config = config
  713. self._tied_weights_keys = {}
  714. for i in range(self.config.n_codes_total - self.config.n_codes_given):
  715. self._tied_weights_keys[f"lm_heads.{i}.weight"] = f"input_embeds_layers.{i + 1}.weight"
  716. # initialize a modified non causal GPT-like model
  717. # note that for there is one embedding layer and one lm_head for each codebook of Encodec
  718. self.input_embeds_layers = nn.ModuleList(
  719. [nn.Embedding(config.input_vocab_size, config.hidden_size) for _ in range(config.n_codes_total)]
  720. )
  721. self.position_embeds_layer = nn.Embedding(config.block_size, config.hidden_size)
  722. self.drop = nn.Dropout(config.dropout)
  723. self.layers = nn.ModuleList(
  724. [BarkBlock(config, is_causal=False, layer_idx=i) for i in range(config.num_layers)]
  725. )
  726. self.layernorm_final = nn.LayerNorm(config.hidden_size)
  727. self.lm_heads = nn.ModuleList(
  728. [
  729. nn.Linear(config.hidden_size, config.output_vocab_size, bias=False)
  730. for _ in range(config.n_codes_given, config.n_codes_total)
  731. ]
  732. )
  733. self.gradient_checkpointing = False
  734. self.n_codes_total = config.n_codes_total
  735. # Initialize weights and apply final processing
  736. self.post_init()
  737. def get_input_embeddings(self):
  738. # one embedding layers for each codebook
  739. return self.input_embeds_layers
  740. def set_input_embeddings(self, new_embeddings):
  741. # one embedding layers for each codebook
  742. self.input_embeds_layers = new_embeddings
  743. def get_output_embeddings(self):
  744. # one lm_head for each codebook
  745. return self.lm_heads
  746. def set_output_embeddings(self, new_output_embeddings):
  747. # one lm_head for each codebook
  748. self.lm_heads = new_output_embeddings
  749. def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, mean_resizing=True):
  750. old_embeddings_list = self.get_input_embeddings()
  751. new_embeddings_list = nn.ModuleList(
  752. [
  753. self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of, mean_resizing)
  754. for old_embeddings in old_embeddings_list
  755. ]
  756. )
  757. self.set_input_embeddings(new_embeddings_list)
  758. new_num_tokens = new_embeddings_list[0].weight.shape[0]
  759. # if word embeddings are not tied, make sure that lm head is resized as well
  760. if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
  761. old_lm_head_list = self.get_output_embeddings()
  762. new_lm_head_list = nn.ModuleList(
  763. [self._get_resized_lm_head(old_lm_head, new_num_tokens) for old_lm_head in old_lm_head_list]
  764. )
  765. self.set_output_embeddings(new_lm_head_list)
  766. return self.get_input_embeddings()
  767. def resize_token_embeddings(
  768. self,
  769. new_num_tokens: int | None = None,
  770. pad_to_multiple_of: int | None = None,
  771. mean_resizing: bool = True,
  772. ) -> nn.Embedding:
  773. """
  774. Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.
  775. Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
  776. Arguments:
  777. new_num_tokens (`int`, *optional*):
  778. The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
  779. vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
  780. returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything.
  781. pad_to_multiple_of (`int`, *optional*):
  782. If set will pad the embedding matrix to a multiple of the provided value.
  783. This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
  784. `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
  785. details about this, or help on choosing the correct value for resizing, refer to this guide:
  786. https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
  787. mean_resizing (`bool`):
  788. Whether to initialize the added embeddings from a multivariate normal distribution that has old embeddings' mean and
  789. covariance or to initialize them with a normal distribution that has a mean of zero and std equals `config.initializer_range`.
  790. Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models,
  791. where the generated tokens' probabilities won't be affected by the added embeddings because initializing the new embeddings with the
  792. old embeddings' mean will reduce the kl-divergence between the next token probability before and after adding the new embeddings.
  793. Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
  794. Return:
  795. `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
  796. """
  797. model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
  798. if new_num_tokens is None and pad_to_multiple_of is None:
  799. return model_embeds
  800. # Update base model and current model config
  801. self.config.output_vocab_size = model_embeds[0].weight.shape[0]
  802. self.config.vocab_size = model_embeds[0].weight.shape[0]
  803. self.output_vocab_size = model_embeds[0].weight.shape[0]
  804. self.vocab_size = model_embeds[0].weight.shape[0]
  805. # Tie weights again if needed
  806. self.tie_weights()
  807. return model_embeds
  808. @deprecate_kwarg("input_embeds", version="5.6.0", new_name="inputs_embeds")
  809. @auto_docstring
  810. def forward(
  811. self,
  812. codebook_idx: int, # an additional idx corresponding to the id of the codebook that will be predicted
  813. input_ids: torch.Tensor | None = None,
  814. attention_mask: torch.Tensor | None = None,
  815. position_ids: torch.Tensor | None = None,
  816. labels: torch.LongTensor | None = None,
  817. inputs_embeds: torch.Tensor | None = None,
  818. output_attentions: bool | None = None,
  819. output_hidden_states: bool | None = None,
  820. return_dict: bool | None = None,
  821. **kwargs,
  822. ) -> tuple[torch.Tensor] | MaskedLMOutput:
  823. r"""
  824. codebook_idx (`int`):
  825. Index of the codebook that will be predicted.
  826. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  827. NOT IMPLEMENTED YET.
  828. """
  829. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  830. output_hidden_states = (
  831. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  832. )
  833. return_dict = return_dict if return_dict is not None else self.config.return_dict
  834. loss = None
  835. if labels is not None:
  836. raise NotImplementedError("Training is not implemented yet")
  837. if codebook_idx == 0:
  838. raise ValueError("Cannot predict 0th codebook - 0th codebook should be predicted by the coarse model")
  839. if input_ids is not None and inputs_embeds is not None:
  840. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  841. if input_ids is None and inputs_embeds is None:
  842. raise ValueError("You have to specify either input_ids or inputs_embeds")
  843. if input_ids is not None:
  844. # the input_embeddings are the sum of the j previous codebooks embeddings before
  845. # the current codebook_idx codebook
  846. # forward the GPT model itself
  847. inputs_embeds = [
  848. input_embeds_layer(input_ids[:, :, i]).unsqueeze(-1)
  849. for i, input_embeds_layer in enumerate(self.input_embeds_layers)
  850. ] # token embeddings of shape (b, t, n_embd)
  851. inputs_embeds = torch.cat(inputs_embeds, dim=-1)
  852. inputs_embeds = inputs_embeds[:, :, :, : codebook_idx + 1].sum(dim=-1)
  853. input_shape = inputs_embeds.size()[:-1]
  854. seq_length = input_shape[1]
  855. inputs_embeds = inputs_embeds.to(self.position_embeds_layer.weight.device)
  856. if position_ids is None:
  857. position_ids = torch.arange(
  858. 0, seq_length, dtype=torch.long, device=self.position_embeds_layer.weight.device
  859. )
  860. position_ids = position_ids.unsqueeze(0) # shape (1, seq_length)
  861. position_ids = position_ids.to(self.position_embeds_layer.weight.device)
  862. position_embeds = self.position_embeds_layer(position_ids) # position embeddings of shape (1, t, n_embd)
  863. attention_mask = create_bidirectional_mask(
  864. config=self.config,
  865. inputs_embeds=inputs_embeds,
  866. attention_mask=attention_mask,
  867. )
  868. hidden_states = self.drop(inputs_embeds + position_embeds)
  869. output_shape = input_shape + (hidden_states.size(-1),)
  870. all_self_attentions = () if output_attentions else None
  871. all_hidden_states = () if output_hidden_states else None
  872. for i, block in enumerate(self.layers):
  873. if output_hidden_states:
  874. all_hidden_states = all_hidden_states + (hidden_states,)
  875. outputs = block(
  876. hidden_states,
  877. attention_mask=attention_mask,
  878. output_attentions=output_attentions,
  879. )
  880. hidden_states = outputs[0]
  881. if output_attentions:
  882. all_self_attentions = all_self_attentions + (outputs[1],)
  883. hidden_states = self.layernorm_final(hidden_states)
  884. hidden_states = hidden_states.view(output_shape)
  885. # Add last hidden state
  886. if output_hidden_states:
  887. all_hidden_states = all_hidden_states + (hidden_states,)
  888. logits = self.lm_heads[codebook_idx - self.config.n_codes_given](hidden_states)
  889. if not return_dict:
  890. return tuple(v for v in [None, logits, all_hidden_states, all_self_attentions] if v is not None)
  891. return MaskedLMOutput(
  892. loss=loss,
  893. logits=logits,
  894. hidden_states=all_hidden_states,
  895. attentions=all_self_attentions,
  896. )
  897. @torch.no_grad()
  898. def generate(
  899. self,
  900. coarse_output: torch.Tensor,
  901. semantic_generation_config: BarkSemanticGenerationConfig | None = None,
  902. coarse_generation_config: BarkCoarseGenerationConfig | None = None,
  903. fine_generation_config: BarkFineGenerationConfig = None,
  904. codebook_size: int = 1024,
  905. history_prompt: dict[str, torch.Tensor] | None = None,
  906. **kwargs,
  907. ) -> torch.LongTensor:
  908. """
  909. Generates fine acoustics tokens from input coarse acoustics tokens and an additional optional `Bark` speaker
  910. prompt.
  911. Args:
  912. coarse_output (`torch.Tensor` of shape (batch_size, seq_len)):
  913. Input coarse acoustics ids, i.e the output of `BarkCoarseModel.generate`.
  914. semantic_generation_config (`BarkSemanticGenerationConfig`):
  915. Generation config indicating how to generate the semantic tokens.
  916. coarse_generation_config (`BarkCoarseGenerationConfig`):
  917. Generation config indicating how to generate the coarse tokens.
  918. fine_generation_config (`BarkFineGenerationConfig`):
  919. Generation config indicating how to generate the fine tokens.
  920. codebook_size (`int`, *optional*, defaults to 1024):
  921. Codebook channel size, i.e. the size of the output vocabulary per codebook channel.
  922. history_prompt (`Optional[dict[str,torch.Tensor]]`, *optional*):
  923. Optional `Bark` speaker prompt.
  924. Returns:
  925. torch.LongTensor: Output fine acoustics tokens.
  926. """
  927. if semantic_generation_config is None:
  928. raise ValueError("`semantic_generation_config` has to be provided")
  929. if coarse_generation_config is None:
  930. raise ValueError("`coarse_generation_config` has to be provided")
  931. if fine_generation_config is None:
  932. raise ValueError("`fine_generation_config` has to be provided")
  933. # since we don't really use GenerationConfig through the fine model (autoencoder)
  934. # and since only temperature is used from the classic GenerationConfig parameters
  935. # manually impose the kwargs priority over the generation config
  936. temperature = kwargs.get("temperature", fine_generation_config.temperature)
  937. max_fine_history_length = fine_generation_config.max_fine_history_length
  938. max_fine_input_length = fine_generation_config.max_fine_input_length
  939. # shape: (batch, n_coarse_codebooks * seq_len)
  940. # new_shape: (batch, seq_len, n_coarse_codebooks)
  941. coarse_output = coarse_output.view(coarse_output.shape[0], -1, coarse_generation_config.n_coarse_codebooks)
  942. # brings ids into the range [0, codebook_size -1]
  943. coarse_output = torch.remainder(coarse_output - semantic_generation_config.semantic_vocab_size, codebook_size)
  944. batch_size = coarse_output.shape[0]
  945. if history_prompt is not None:
  946. x_fine_history = torch.repeat_interleave(history_prompt["fine_prompt"].T[None], batch_size, dim=0)
  947. # transpose to get to shape (seq_len, n_fine_codebooks)
  948. else:
  949. x_fine_history = None
  950. n_coarse = coarse_generation_config.n_coarse_codebooks
  951. # pad the last 6th codebooks
  952. fine_input = F.pad(
  953. coarse_output,
  954. (0, fine_generation_config.n_fine_codebooks - n_coarse),
  955. "constant",
  956. codebook_size,
  957. )
  958. # prepend history if available (max max_fine_history_length)
  959. if x_fine_history is not None:
  960. fine_input = torch.cat([x_fine_history[:, -max_fine_history_length:, :], fine_input], dim=1)
  961. # len of the fine_history that has been added to fine_input
  962. n_history = x_fine_history[:, -max_fine_history_length:, :].shape[1]
  963. else:
  964. n_history = 0
  965. n_remove_from_end = 0
  966. # need to pad if too short (since non-causal model)
  967. if fine_input.shape[1] < max_fine_input_length:
  968. n_remove_from_end = max_fine_input_length - fine_input.shape[1]
  969. fine_input = F.pad(fine_input, (0, 0, 0, n_remove_from_end), mode="constant", value=codebook_size)
  970. # we can be lazy about fractional loop and just keep overwriting codebooks.
  971. # seems that coarse_output.shape[1] - (max_fine_input_length - n_history) is equal to minus n_remove_from_end
  972. # So if we needed to pad because too short, n_loops is always 1 (because n_remove_from_end > 0)
  973. # If not, we loop over at least twice.
  974. n_loops = (coarse_output.shape[1] - (max_fine_input_length - n_history)) / max_fine_history_length
  975. n_loops = int(np.ceil(n_loops))
  976. n_loops = max(0, n_loops) + 1
  977. for n_outer in range(n_loops):
  978. start_idx = min([n_outer * max_fine_history_length, fine_input.shape[1] - max_fine_input_length])
  979. start_fill_idx = min(
  980. [n_history + n_outer * max_fine_history_length, fine_input.shape[1] - max_fine_history_length]
  981. )
  982. rel_start_fill_idx = start_fill_idx - start_idx
  983. input_buffer = fine_input[:, start_idx : start_idx + max_fine_input_length, :]
  984. for n_inner in range(n_coarse, fine_generation_config.n_fine_codebooks):
  985. logits = self.forward(n_inner, input_buffer).logits
  986. if temperature is None or temperature == 1.0:
  987. relevant_logits = logits[:, rel_start_fill_idx:, :codebook_size]
  988. codebook_preds = torch.argmax(relevant_logits, -1)
  989. else:
  990. relevant_logits = logits[:, :, :codebook_size] / temperature
  991. # apply softmax
  992. probs = F.softmax(relevant_logits, dim=-1)[:, rel_start_fill_idx:max_fine_input_length]
  993. # reshape to 2D: (batch_size, seq_len, codebook_size) -> (batch_size*seq_len, codebook_size)
  994. probs = probs.reshape((-1, codebook_size))
  995. # multinomial then reshape : (batch_size*seq_len)-> (batch_size,seq_len)
  996. codebook_preds = torch.multinomial(probs, num_samples=1).view(batch_size, -1)
  997. codebook_preds = codebook_preds.to(torch.int32)
  998. input_buffer[:, rel_start_fill_idx:, n_inner] = codebook_preds
  999. del logits, codebook_preds
  1000. # transfer into fine_input
  1001. for n_inner in range(n_coarse, fine_generation_config.n_fine_codebooks):
  1002. fine_input[
  1003. :, start_fill_idx : start_fill_idx + (max_fine_input_length - rel_start_fill_idx), n_inner
  1004. ] = input_buffer[:, rel_start_fill_idx:, n_inner]
  1005. del input_buffer
  1006. fine_input = fine_input.transpose(1, 2)[:, :, n_history:]
  1007. if n_remove_from_end > 0:
  1008. fine_input = fine_input[:, :, :-n_remove_from_end]
  1009. if fine_input.shape[-1] != coarse_output.shape[-2]:
  1010. raise ValueError("input and output should have the same seq_len")
  1011. return fine_input
  1012. @auto_docstring(
  1013. custom_intro="""
  1014. The full Bark model, a text-to-speech model composed of 4 sub-models:
  1015. - [`BarkSemanticModel`] (also referred to as the 'text' model): a causal auto-regressive transformer model that
  1016. takes
  1017. as input tokenized text, and predicts semantic text tokens that capture the meaning of the text.
  1018. - [`BarkCoarseModel`] (also referred to as the 'coarse acoustics' model), also a causal autoregressive transformer,
  1019. that takes into input the results of the last model. It aims at regressing the first two audio codebooks necessary
  1020. to `encodec`.
  1021. - [`BarkFineModel`] (the 'fine acoustics' model), this time a non-causal autoencoder transformer, which iteratively
  1022. predicts the last codebooks based on the sum of the previous codebooks embeddings.
  1023. - having predicted all the codebook channels from the [`EncodecModel`], Bark uses it to decode the output audio
  1024. array.
  1025. It should be noted that each of the first three modules can support conditional speaker embeddings to condition the
  1026. output sound according to specific predefined voice.
  1027. """
  1028. )
  1029. class BarkModel(BarkPreTrainedModel, GenerationMixin):
  1030. config: BarkConfig
  1031. def __init__(self, config):
  1032. super().__init__(config)
  1033. self.semantic = BarkSemanticModel(config.semantic_config)
  1034. self.coarse_acoustics = BarkCoarseModel(config.coarse_acoustics_config)
  1035. self.fine_acoustics = BarkFineModel(config.fine_acoustics_config)
  1036. self.codec_model = AutoModel.from_config(config.codec_config)
  1037. self.config = config
  1038. self.post_init()
  1039. @classmethod
  1040. def can_generate(cls) -> bool:
  1041. # Bark has a unique model structure, where the external class (`BarkModel`) doesn't need to inherit from
  1042. # `GenerationMixin` (it has a non-standard generation method), but one of the internal models do
  1043. # (`BarkSemanticModel`). This means that the base `can_generate()` will return `False`, but we need to
  1044. # override it so as to do `GenerationConfig` handling in multiple parts of the codebase.
  1045. return True
  1046. @property
  1047. def device(self) -> torch.device:
  1048. """
  1049. `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
  1050. device).
  1051. """
  1052. # for bark_model, device must be verified on its sub-models
  1053. # if has _hf_hook, has been offloaded so the device has to be found in the hook
  1054. if not hasattr(self.semantic, "_hf_hook"):
  1055. return super().device
  1056. for module in self.semantic.modules():
  1057. if (
  1058. hasattr(module, "_hf_hook")
  1059. and hasattr(module._hf_hook, "execution_device")
  1060. and module._hf_hook.execution_device is not None
  1061. ):
  1062. return torch.device(module._hf_hook.execution_device)
  1063. def enable_cpu_offload(
  1064. self,
  1065. accelerator_id: int | None = 0,
  1066. **kwargs,
  1067. ):
  1068. r"""
  1069. Offloads all sub-models to CPU using accelerate, reducing memory usage with a low impact on performance. This
  1070. method moves one whole sub-model at a time to the accelerator when it is used, and the sub-model remains in accelerator until the next sub-model runs.
  1071. Args:
  1072. accelerator_id (`int`, *optional*, defaults to 0):
  1073. accelerator id on which the sub-models will be loaded and offloaded.
  1074. """
  1075. if is_accelerate_available():
  1076. from accelerate import cpu_offload_with_hook
  1077. else:
  1078. raise ImportError("`enable_model_cpu_offload` requires `accelerate`.")
  1079. device_type = "cuda"
  1080. if is_torch_accelerator_available():
  1081. device_type = torch.accelerator.current_accelerator().type
  1082. device = torch.device(f"{device_type}:{accelerator_id}")
  1083. torch_accelerator_module = getattr(torch, device_type)
  1084. if self.device.type != "cpu":
  1085. self.to("cpu")
  1086. torch_accelerator_module.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
  1087. # this layer is used outside the first forward pass of semantic so need to be loaded before semantic
  1088. self.semantic.input_embeds_layer, _ = cpu_offload_with_hook(self.semantic.input_embeds_layer, device)
  1089. hook = None
  1090. for cpu_offloaded_model in [
  1091. self.semantic,
  1092. self.coarse_acoustics,
  1093. self.fine_acoustics,
  1094. ]:
  1095. _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
  1096. self.fine_acoustics_hook = hook
  1097. _, hook = cpu_offload_with_hook(self.codec_model, device, prev_module_hook=hook)
  1098. # We'll offload the last model manually.
  1099. self.codec_model_hook = hook
  1100. def codec_decode(self, fine_output, output_lengths=None):
  1101. """Turn quantized audio codes into audio array using encodec."""
  1102. fine_output = fine_output.transpose(0, 1)
  1103. emb = self.codec_model.quantizer.decode(fine_output)
  1104. if output_lengths is not None:
  1105. # encodec uses LSTMs which behaves differently with appended padding
  1106. # decoding with encodec takes around 0.1% of the total generation time
  1107. # to keep generation quality, we break batching
  1108. out = [sample[:, :l].unsqueeze(0) for (sample, l) in zip(emb, output_lengths)]
  1109. audio_arr = [self.codec_model.decoder(sample).squeeze() for sample in out]
  1110. else:
  1111. out = self.codec_model.decoder(emb)
  1112. audio_arr = out.squeeze(1) # squeeze the codebook dimension
  1113. return audio_arr
  1114. @torch.no_grad()
  1115. def generate(
  1116. self,
  1117. input_ids: torch.Tensor | None = None,
  1118. history_prompt: dict[str, torch.Tensor] | None = None,
  1119. return_output_lengths: bool | None = None,
  1120. **kwargs,
  1121. ) -> torch.LongTensor:
  1122. """
  1123. Generates audio from an input prompt and an additional optional `Bark` speaker prompt.
  1124. Args:
  1125. input_ids (`Optional[torch.Tensor]` of shape (batch_size, seq_len), *optional*):
  1126. Input ids. Will be truncated up to 256 tokens. Note that the output audios will be as long as the
  1127. longest generation among the batch.
  1128. history_prompt (`Optional[dict[str,torch.Tensor]]`, *optional*):
  1129. Optional `Bark` speaker prompt. Note that for now, this model takes only one speaker prompt per batch.
  1130. kwargs (*optional*): Remaining dictionary of keyword arguments. Keyword arguments are of two types:
  1131. - Without a prefix, they will be entered as `**kwargs` for the `generate` method of each sub-model.
  1132. - With a *semantic_*, *coarse_*, *fine_* prefix, they will be input for the `generate` method of the
  1133. semantic, coarse and fine respectively. It has the priority over the keywords without a prefix.
  1134. This means you can, for example, specify a generation strategy for all sub-models except one.
  1135. return_output_lengths (`bool`, *optional*):
  1136. Whether or not to return the waveform lengths. Useful when batching.
  1137. Returns:
  1138. By default:
  1139. - **audio_waveform** (`torch.Tensor` of shape (batch_size, seq_len)): Generated audio waveform.
  1140. When `return_output_lengths=True`:
  1141. Returns a tuple made of:
  1142. - **audio_waveform** (`torch.Tensor` of shape (batch_size, seq_len)): Generated audio waveform.
  1143. - **output_lengths** (`torch.Tensor` of shape (batch_size)): The length of each waveform in the batch
  1144. Example:
  1145. ```python
  1146. >>> from transformers import AutoProcessor, BarkModel
  1147. >>> processor = AutoProcessor.from_pretrained("suno/bark-small")
  1148. >>> model = BarkModel.from_pretrained("suno/bark-small")
  1149. >>> # To add a voice preset, you can pass `voice_preset` to `BarkProcessor.__call__(...)`
  1150. >>> voice_preset = "v2/en_speaker_6"
  1151. >>> inputs = processor("Hello, my dog is cute, I need him in my life", voice_preset=voice_preset)
  1152. >>> audio_array = model.generate(**inputs, semantic_max_new_tokens=100)
  1153. >>> audio_array = audio_array.cpu().numpy().squeeze()
  1154. ```
  1155. """
  1156. # TODO (joao):workaround until nested generation config is compatible with PreTrained Model
  1157. # todo: dict
  1158. semantic_generation_config = BarkSemanticGenerationConfig(**self.generation_config.semantic_config)
  1159. coarse_generation_config = BarkCoarseGenerationConfig(**self.generation_config.coarse_acoustics_config)
  1160. fine_generation_config = BarkFineGenerationConfig(**self.generation_config.fine_acoustics_config)
  1161. kwargs_semantic = {
  1162. # if "attention_mask" is set, it should not be passed to CoarseModel and FineModel
  1163. "attention_mask": kwargs.pop("attention_mask", None),
  1164. "min_eos_p": kwargs.pop("min_eos_p", None),
  1165. }
  1166. kwargs_coarse = {}
  1167. kwargs_fine = {}
  1168. for key, value in kwargs.items():
  1169. if key.startswith("semantic_"):
  1170. key = key[len("semantic_") :]
  1171. kwargs_semantic[key] = value
  1172. elif key.startswith("coarse_"):
  1173. key = key[len("coarse_") :]
  1174. kwargs_coarse[key] = value
  1175. elif key.startswith("fine_"):
  1176. key = key[len("fine_") :]
  1177. kwargs_fine[key] = value
  1178. else:
  1179. # If the key is already in a specific config, then it's been set with a
  1180. # submodules specific value and we don't override
  1181. if key not in kwargs_semantic:
  1182. kwargs_semantic[key] = value
  1183. if key not in kwargs_coarse:
  1184. kwargs_coarse[key] = value
  1185. if key not in kwargs_fine:
  1186. kwargs_fine[key] = value
  1187. # 1. Generate from the semantic model
  1188. if "generation_config" in kwargs_semantic:
  1189. kwargs_semantic.pop("generation_config")
  1190. semantic_output = self.semantic.generate(
  1191. input_ids,
  1192. history_prompt=history_prompt,
  1193. semantic_generation_config=semantic_generation_config,
  1194. **kwargs_semantic,
  1195. )
  1196. # 2. Generate from the coarse model
  1197. if "generation_config" in kwargs_coarse:
  1198. kwargs_coarse.pop("generation_config")
  1199. coarse_output = self.coarse_acoustics.generate(
  1200. semantic_output,
  1201. history_prompt=history_prompt,
  1202. semantic_generation_config=semantic_generation_config,
  1203. coarse_generation_config=coarse_generation_config,
  1204. codebook_size=self.generation_config.codebook_size,
  1205. return_output_lengths=return_output_lengths,
  1206. **kwargs_coarse,
  1207. )
  1208. output_lengths = None
  1209. if return_output_lengths:
  1210. coarse_output, output_lengths = coarse_output
  1211. # (batch_size, seq_len*coarse_codebooks) -> (batch_size, seq_len)
  1212. output_lengths = output_lengths // coarse_generation_config.n_coarse_codebooks
  1213. # 3. "generate" from the fine model
  1214. if "generation_config" in kwargs_fine:
  1215. kwargs_fine.pop("generation_config")
  1216. output = self.fine_acoustics.generate(
  1217. coarse_output,
  1218. history_prompt=history_prompt,
  1219. semantic_generation_config=semantic_generation_config,
  1220. coarse_generation_config=coarse_generation_config,
  1221. fine_generation_config=fine_generation_config,
  1222. codebook_size=self.generation_config.codebook_size,
  1223. **kwargs_fine,
  1224. )
  1225. if getattr(self, "fine_acoustics_hook", None) is not None:
  1226. # Manually offload fine_acoustics to CPU
  1227. # and load codec_model to GPU
  1228. # since bark doesn't use codec_model forward pass
  1229. self.fine_acoustics_hook.offload()
  1230. self.codec_model = self.codec_model.to(self.device)
  1231. # 4. Decode the output and generate audio array
  1232. audio = self.codec_decode(output, output_lengths)
  1233. if getattr(self, "codec_model_hook", None) is not None:
  1234. # Offload codec_model to CPU
  1235. self.codec_model_hook.offload()
  1236. if return_output_lengths:
  1237. output_lengths = [len(sample) for sample in audio]
  1238. audio = nn.utils.rnn.pad_sequence(audio, batch_first=True, padding_value=0)
  1239. return audio, output_lengths
  1240. return audio
  # Public API of this module: the names bound by `from <module> import *`.
  # NOTE(review): keep this a plain literal list of string names; presumably
  # import tooling reads it statically rather than executing it -- confirm
  # before turning it into a computed expression.
  1241. __all__ = [
  1242. "BarkFineModel",
  1243. "BarkSemanticModel",
  1244. "BarkCoarseModel",
  1245. "BarkModel",
  1246. "BarkPreTrainedModel",
  1247. "BarkCausalModel",
  1248. ]