modeling_musicgen.py 102 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122212321242125212621272128212921302131213221332134213521362137213821392140214121422143214421452146214721482149215021512152215321542155215621572158215921602161216221632164216521662167216821692170217121722173
  1. # Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch Musicgen model."""
  15. import copy
  16. import inspect
  17. import math
  18. import random
  19. from collections.abc import Callable
  20. from dataclasses import dataclass
  21. from typing import TYPE_CHECKING, Any, Optional
  22. import torch
  23. import torch.nn as nn
  24. from torch.nn import CrossEntropyLoss
  25. from ... import initialization as init
  26. from ...activations import ACT2FN
  27. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  28. from ...generation import (
  29. ClassifierFreeGuidanceLogitsProcessor,
  30. GenerationConfig,
  31. GenerationMixin,
  32. GenerationMode,
  33. LogitsProcessorList,
  34. StoppingCriteriaList,
  35. )
  36. from ...masking_utils import create_bidirectional_mask, create_causal_mask
  37. from ...modeling_flash_attention_utils import (
  38. FlashAttentionKwargs,
  39. )
  40. from ...modeling_layers import GradientCheckpointingLayer
  41. from ...modeling_outputs import (
  42. BaseModelOutput,
  43. BaseModelOutputWithPastAndCrossAttentions,
  44. CausalLMOutputWithCrossAttentions,
  45. ModelOutput,
  46. Seq2SeqLMOutput,
  47. )
  48. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  49. from ...processing_utils import Unpack
  50. from ...utils import TransformersKwargs, auto_docstring, logging
  51. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  52. from ...utils.output_capturing import OutputRecorder, capture_outputs
  53. from ..auto.configuration_auto import AutoConfig
  54. from ..auto.modeling_auto import AutoModel
  55. from .configuration_musicgen import MusicgenConfig, MusicgenDecoderConfig
  56. if TYPE_CHECKING:
  57. from ...generation.streamers import BaseStreamer
# Module-level logger for this file, created through the library's `logging` utilities.
logger = logging.get_logger(__name__)
  59. @dataclass
  60. @auto_docstring
  61. class MusicgenUnconditionalInput(ModelOutput):
  62. r"""
  63. encoder_outputs (`tuple[torch.FloatTensor]` of length 1, with tensor shape `(batch_size, sequence_length, hidden_size)`):
  64. Sequence of hidden-states at the output of the last layer of the text encoder model.
  65. attention_mask (`torch.LongTensor`) of shape `(batch_size, sequence_length)`, *optional*):
  66. Encoder attention mask to avoid performing attention on padding token indices. Mask values selected in `[0,
  67. 1]`: 1 for tokens that are **not masked**, 0 for tokens that are **masked**.
  68. guidance_scale (`float`, *optional*):
  69. Guidance scale for classifier free guidance, setting the balance between the conditional logits (predicted
  70. from the prompts) and the unconditional logits (predicted without prompts).
  71. """
  72. encoder_outputs: tuple[torch.FloatTensor] | None = None
  73. attention_mask: torch.LongTensor | None = None
  74. guidance_scale: float | None = None
  75. def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
  76. """
  77. Shift input ids one token to the right.
  78. """
  79. # transpose to get (bsz, num_codebooks, seq_len)
  80. input_ids = input_ids.transpose(1, 2)
  81. shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  82. shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
  83. if decoder_start_token_id is None:
  84. raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
  85. shifted_input_ids[..., 0] = decoder_start_token_id
  86. if pad_token_id is None:
  87. raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
  88. # replace possible -100 values in labels by `pad_token_id`
  89. shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
  90. return shifted_input_ids
  91. class MusicgenSinusoidalPositionalEmbedding(nn.Module):
  92. """This module produces sinusoidal positional embeddings of any length."""
  93. def __init__(self, num_positions: int, embedding_dim: int):
  94. super().__init__()
  95. self.embedding_dim = embedding_dim
  96. self.num_positions = num_positions
  97. self.make_weights(num_positions, embedding_dim)
  98. def make_weights(self, num_embeddings: int, embedding_dim: int):
  99. emb_weights = self.get_embedding(num_embeddings, embedding_dim)
  100. if hasattr(self, "weights"):
  101. # in forward put the weights on the correct dtype and device of the param
  102. emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
  103. self.register_buffer("weights", emb_weights, persistent=False)
  104. @staticmethod
  105. def get_embedding(num_embeddings: int, embedding_dim: int):
  106. """
  107. Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the
  108. description in Section 3.5 of "Attention Is All You Need".
  109. """
  110. half_dim = embedding_dim // 2
  111. emb = math.log(10000) / (half_dim - 1)
  112. emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
  113. emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
  114. emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=1).view(num_embeddings, -1)
  115. if embedding_dim % 2 == 1:
  116. # zero pad
  117. emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
  118. return emb.to(torch.get_default_dtype())
  119. @torch.no_grad()
  120. def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
  121. bsz, codebooks, seq_len = input_ids.size()
  122. # Create the position ids from the input token ids.
  123. position_ids = (torch.arange(seq_len) + past_key_values_length).to(input_ids.device)
  124. # expand embeddings if needed
  125. if seq_len > self.weights.size(0):
  126. self.make_weights(seq_len, self.embedding_dim)
  127. return self.weights.index_select(0, position_ids.view(-1)).detach()
# Copied from transformers.models.bert.modeling_bert.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float | None = None,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Plain PyTorch (matmul + softmax) attention; `MusicgenAttention` passes it as the default to
    `ALL_ATTENTION_FUNCTIONS.get_interface`.

    Args:
        module: the calling module; only its `training` flag is read, to decide whether dropout is active.
        query, key, value: projected states. `key.transpose(2, 3)` and the final `transpose(1, 2)` imply 4-D
            tensors, presumably `(batch, num_heads, seq_len, head_dim)` — TODO confirm.
        attention_mask: mask that is *added* to the raw scores (presumably large negative values at masked
            positions), or `None`.
        scaling: multiplier applied to the raw scores; defaults to `query.size(-1) ** -0.5`.
        dropout: dropout probability applied to the attention probabilities.
        **kwargs: accepted for interface compatibility; not read here.

    Returns:
        `(attn_output, attn_weights)`: `attn_output` has axes 1 and 2 swapped relative to the matmul result and is
        made contiguous; `attn_weights` are the post-softmax, post-dropout probabilities.
    """
    if scaling is None:
        scaling = query.size(-1) ** -0.5
    # Take the dot product between "query" and "key" to get the raw attention scores.
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value)
    # swap the head and sequence axes back so the caller can merge heads
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights
class MusicgenAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Acts as self-attention by default and as cross-attention when `key_value_states` is passed to `forward`.
    The score/softmax computation is delegated to the function selected by `config._attn_implementation`,
    falling back to `eager_attention_forward`.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float | None = 0.0,
        is_decoder: bool | None = False,
        bias: bool | None = True,
        is_causal: bool | None = False,
        config: MusicgenConfig | None = None,
        layer_idx: int | None = None,
    ):
        """
        Args:
            embed_dim: model width; must be divisible by `num_heads`.
            num_heads: number of heads; `head_dim = embed_dim // num_heads`.
            dropout: attention-probability dropout, only used while `self.training` is true.
            is_decoder: stored as `self.is_decoder`; not read elsewhere in this class.
            bias: whether the q/k/v/out projections have a bias term.
            is_causal: stored as `self.is_causal`; not read in this class itself. The attention function receives
                `self` as `module` and may inspect it — TODO confirm per backend.
            config: provides `_attn_implementation`. It is annotated as optional, but `forward` dereferences it.
            layer_idx: this layer's index, used to address its slot in the KV cache.

        Raises:
            ValueError: if `embed_dim` is not divisible by `num_heads`.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config
        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        # 1/sqrt(head_dim) score scaling, passed to the attention function as `scaling`
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal
        self.layer_idx = layer_idx
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        attention_mask: torch.Tensor | None = None,
        output_attentions: bool | None = False,
        # TODO: we need a refactor so that the different attention modules can get their specific kwargs
        # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: queries are projected from these.
            key_value_states: if given, keys/values are projected from these instead (cross-attention).
            past_key_values: optional cache. With an `EncoderDecoderCache`, the self-attention sub-cache is
                extended on every call. The cross-attention sub-cache is filled once per layer and then reused,
                tracked through `past_key_values.is_updated`.
            attention_mask: forwarded unchanged to the attention function.
            output_attentions: forwarded to the attention function.

        Returns:
            `(attn_output, attn_weights)`. `attn_output` keeps the leading dims of `hidden_states` and has
            `embed_dim` channels. `attn_weights` is whatever the attention function returns, possibly `None` for
            some backends.
        """
        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
        # determine input shapes
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)
        # get query proj; the channel dim is split into heads, then the head axis is moved before the time axis
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        # only an EncoderDecoderCache can report that this layer's cross-attention K/V are already cached
        is_updated = False
        if past_key_values is not None:
            if isinstance(past_key_values, EncoderDecoderCache):
                is_updated = past_key_values.is_updated.get(self.layer_idx)
                if is_cross_attention:
                    # after the first generated id, we can subsequently re-use all key/value_layer from cache
                    curr_past_key_values = past_key_values.cross_attention_cache
                else:
                    curr_past_key_values = past_key_values.self_attention_cache
            else:
                curr_past_key_values = past_key_values
        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_values is not None and is_updated:
            # reuse k,v, cross_attentions
            key_states = curr_past_key_values.layers[self.layer_idx].keys
            value_states = curr_past_key_values.layers[self.layer_idx].values
        else:
            kv_shape = (*current_states.shape[:-1], -1, self.head_dim)
            key_states = self.k_proj(current_states).view(kv_shape).transpose(1, 2)
            value_states = self.v_proj(current_states).view(kv_shape).transpose(1, 2)
            if past_key_values is not None:
                # save all key/value_states to cache to be re-used for fast auto-regressive generation
                key_states, value_states = curr_past_key_values.update(key_states, value_states, self.layer_idx)
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
                    past_key_values.is_updated[self.layer_idx] = True
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        # dropout is disabled outside training regardless of the configured probability
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.dropout,
            scaling=self.scaling,
            output_attentions=output_attentions,
            **kwargs,
        )
        # merge heads back into a single channel dimension
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.out_proj(attn_output)
        return attn_output, attn_weights
  245. class MusicgenDecoderLayer(GradientCheckpointingLayer):
  246. def __init__(self, config: MusicgenDecoderConfig, layer_idx=None):
  247. super().__init__()
  248. self.embed_dim = config.hidden_size
  249. self.self_attn = MusicgenAttention(
  250. embed_dim=self.embed_dim,
  251. num_heads=config.num_attention_heads,
  252. dropout=config.attention_dropout,
  253. is_decoder=True,
  254. bias=False,
  255. is_causal=True,
  256. config=config,
  257. layer_idx=layer_idx,
  258. )
  259. self.dropout = config.dropout
  260. self.activation_fn = ACT2FN[config.activation_function]
  261. self.activation_dropout = config.activation_dropout
  262. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  263. self.encoder_attn = MusicgenAttention(
  264. self.embed_dim,
  265. config.num_attention_heads,
  266. dropout=config.attention_dropout,
  267. is_decoder=True,
  268. bias=False,
  269. config=config,
  270. layer_idx=layer_idx,
  271. )
  272. self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  273. self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=False)
  274. self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=False)
  275. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  276. def forward(
  277. self,
  278. hidden_states: torch.Tensor,
  279. attention_mask: torch.Tensor | None = None,
  280. encoder_hidden_states: torch.Tensor | None = None,
  281. encoder_attention_mask: torch.Tensor | None = None,
  282. past_key_values: Cache | None = None,
  283. use_cache: bool | None = True,
  284. **kwargs: Unpack[TransformersKwargs],
  285. ) -> torch.Tensor:
  286. """
  287. Args:
  288. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  289. attention_mask (`torch.FloatTensor`): attention mask of size
  290. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  291. encoder_hidden_states (`torch.FloatTensor`):
  292. cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
  293. encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
  294. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  295. past_key_values (`Cache`): cached past key and value projection states
  296. """
  297. residual = hidden_states
  298. hidden_states = self.self_attn_layer_norm(hidden_states)
  299. # Self Attention
  300. hidden_states, _ = self.self_attn(
  301. hidden_states,
  302. past_key_values=past_key_values,
  303. attention_mask=attention_mask,
  304. **kwargs,
  305. )
  306. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  307. hidden_states = residual + hidden_states
  308. # Cross-Attention Block
  309. if encoder_hidden_states is not None:
  310. residual = hidden_states
  311. hidden_states = self.encoder_attn_layer_norm(hidden_states)
  312. hidden_states, _ = self.encoder_attn(
  313. hidden_states,
  314. key_value_states=encoder_hidden_states,
  315. attention_mask=encoder_attention_mask,
  316. past_key_values=past_key_values,
  317. **kwargs,
  318. )
  319. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  320. hidden_states = residual + hidden_states
  321. # Fully Connected
  322. residual = hidden_states
  323. hidden_states = self.final_layer_norm(hidden_states)
  324. hidden_states = self.activation_fn(self.fc1(hidden_states))
  325. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  326. hidden_states = self.fc2(hidden_states)
  327. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  328. hidden_states = residual + hidden_states
  329. return hidden_states
  330. @auto_docstring
  331. class MusicgenPreTrainedModel(PreTrainedModel):
  332. config: MusicgenDecoderConfig
  333. base_model_prefix = "model"
  334. supports_gradient_checkpointing = True
  335. _no_split_modules = ["MusicgenDecoderLayer", "MusicgenAttention"]
  336. _supports_flash_attn = True
  337. _supports_sdpa = True
  338. _supports_flex_attn = True
  339. @torch.no_grad()
  340. def _init_weights(self, module):
  341. std = self.config.initializer_factor
  342. if isinstance(module, nn.Linear):
  343. init.normal_(module.weight, mean=0.0, std=std)
  344. if module.bias is not None:
  345. init.zeros_(module.bias)
  346. elif isinstance(module, nn.LayerNorm):
  347. init.ones_(module.weight)
  348. init.zeros_(module.bias)
  349. elif isinstance(module, nn.Embedding):
  350. init.normal_(module.weight, mean=0.0, std=std)
  351. # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
  352. if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
  353. init.zeros_(module.weight[module.padding_idx])
  354. elif isinstance(module, MusicgenSinusoidalPositionalEmbedding):
  355. emb_weights = module.get_embedding(module.num_positions, module.embedding_dim)
  356. init.copy_(module.weights, emb_weights)
  357. class MusicgenDecoder(MusicgenPreTrainedModel):
  358. """
  359. Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MusicgenDecoderLayer`]
  360. """
  361. _can_record_outputs = {
  362. "hidden_states": MusicgenDecoderLayer,
  363. "attentions": OutputRecorder(MusicgenAttention, index=1, layer_name="self_attn"),
  364. "cross_attentions": OutputRecorder(MusicgenAttention, index=1, layer_name="encoder_attn"),
  365. }
  366. def __init__(self, config: MusicgenDecoderConfig):
  367. super().__init__(config)
  368. self.dropout = config.dropout
  369. self.layerdrop = config.layerdrop
  370. self.max_target_positions = config.max_position_embeddings
  371. self.d_model = config.hidden_size
  372. self.num_codebooks = config.num_codebooks
  373. self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
  374. embed_dim = config.vocab_size + 1
  375. self.embed_tokens = nn.ModuleList(
  376. [nn.Embedding(embed_dim, config.hidden_size) for _ in range(config.num_codebooks)]
  377. )
  378. self.embed_positions = MusicgenSinusoidalPositionalEmbedding(
  379. config.max_position_embeddings,
  380. config.hidden_size,
  381. )
  382. self.layers = nn.ModuleList(
  383. [MusicgenDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]
  384. )
  385. self.layer_norm = nn.LayerNorm(config.hidden_size)
  386. self.attn_implementation = config._attn_implementation
  387. self.gradient_checkpointing = False
  388. # Initialize weights and apply final processing
  389. self.post_init()
  390. @merge_with_config_defaults
  391. @capture_outputs
  392. @auto_docstring
  393. def forward(
  394. self,
  395. input_ids: torch.LongTensor | None = None,
  396. attention_mask: torch.Tensor | None = None,
  397. encoder_hidden_states: torch.FloatTensor | None = None,
  398. encoder_attention_mask: torch.LongTensor | None = None,
  399. past_key_values: Cache | None = None,
  400. inputs_embeds: torch.FloatTensor | None = None,
  401. use_cache: bool | None = None,
  402. **kwargs: Unpack[TransformersKwargs],
  403. ) -> tuple | BaseModelOutputWithPastAndCrossAttentions:
  404. r"""
  405. input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`):
  406. Indices of input sequence tokens in the vocabulary, corresponding to the sequence of audio codes.
  407. Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes,
  408. such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details.
  409. [What are input IDs?](../glossary#input-ids)
  410. <Tip warning={true}>
  411. The `input_ids` will automatically be converted from shape `(batch_size * num_codebooks,
  412. target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If
  413. you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of
  414. frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks,
  415. target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as
  416. `input_ids`.
  417. </Tip>
  418. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
  419. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of
  420. the decoder.
  421. encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
  422. Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
  423. selected in `[0, 1]`:
  424. - 1 for tokens that are **not masked**,
  425. - 0 for tokens that are **masked**.
  426. [What are attention masks?](../glossary#attention-mask)
  427. """
  428. if input_ids is not None and inputs_embeds is not None:
  429. raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
  430. elif input_ids is not None:
  431. # (bsz * codebooks, seq_len) -> (bsz, codebooks, seq_len)
  432. input = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1])
  433. bsz, num_codebooks, seq_len = input.shape
  434. elif inputs_embeds is not None:
  435. input = inputs_embeds[:, :, -1:]
  436. else:
  437. raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
  438. if use_cache and past_key_values is None:
  439. past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  440. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  441. if inputs_embeds is None:
  442. inputs_embeds = sum(self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks))
  443. attention_mask = create_causal_mask(
  444. config=self.config,
  445. inputs_embeds=inputs_embeds,
  446. attention_mask=attention_mask,
  447. past_key_values=past_key_values,
  448. )
  449. encoder_attention_mask = create_bidirectional_mask(
  450. config=self.config,
  451. inputs_embeds=inputs_embeds,
  452. attention_mask=encoder_attention_mask,
  453. encoder_hidden_states=encoder_hidden_states,
  454. )
  455. # embed positions
  456. positions = self.embed_positions(input, past_key_values_length)
  457. hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
  458. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  459. for idx, decoder_layer in enumerate(self.layers):
  460. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  461. dropout_probability = random.uniform(0, 1)
  462. if self.training and (dropout_probability < self.layerdrop):
  463. continue
  464. hidden_states = decoder_layer(
  465. hidden_states,
  466. attention_mask,
  467. encoder_hidden_states, # as a positional argument for gradient checkpointing
  468. encoder_attention_mask=encoder_attention_mask,
  469. past_key_values=past_key_values,
  470. use_cache=use_cache,
  471. **kwargs,
  472. )
  473. hidden_states = self.layer_norm(hidden_states)
  474. return BaseModelOutputWithPastAndCrossAttentions(
  475. last_hidden_state=hidden_states,
  476. past_key_values=past_key_values,
  477. )
  478. @auto_docstring
  479. class MusicgenModel(MusicgenPreTrainedModel):
    def __init__(self, config: MusicgenDecoderConfig):
        # Thin wrapper: builds a single `MusicgenDecoder` from the same config and exposes it as `self.decoder`.
        super().__init__(config)
        self.decoder = MusicgenDecoder(config)
        # Initialize weights and apply final processing
        self.post_init()
    def get_input_embeddings(self):
        """
        Return the decoder's token embeddings.

        This is `self.decoder.embed_tokens`, an `nn.ModuleList` holding one `nn.Embedding` per codebook, not a
        single embedding layer.
        """
        return self.decoder.embed_tokens
    def set_input_embeddings(self, value):
        """Replace the decoder's token embeddings (`self.decoder.embed_tokens`) with `value`."""
        # NOTE(review): the decoder indexes `embed_tokens[codebook]` for every codebook, so `value` presumably must
        # keep that per-codebook, indexable layout — confirm for custom replacements.
        self.decoder.embed_tokens = value
  489. @merge_with_config_defaults
  490. @capture_outputs
  491. @auto_docstring
  492. def forward(
  493. self,
  494. input_ids: torch.LongTensor | None = None,
  495. attention_mask: torch.Tensor | None = None,
  496. encoder_hidden_states: torch.FloatTensor | None = None,
  497. encoder_attention_mask: torch.LongTensor | None = None,
  498. past_key_values: Cache | None = None,
  499. inputs_embeds: torch.FloatTensor | None = None,
  500. use_cache: bool | None = None,
  501. **kwargs: Unpack[TransformersKwargs],
  502. ) -> tuple | BaseModelOutputWithPastAndCrossAttentions:
  503. r"""
  504. input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`):
  505. Indices of input sequence tokens in the vocabulary, corresponding to the sequence of audio codes.
  506. Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes,
  507. such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details.
  508. [What are input IDs?](../glossary#input-ids)
  509. <Tip warning={true}>
  510. The `input_ids` will automatically be converted from shape `(batch_size * num_codebooks,
  511. target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If
  512. you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of
  513. frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks,
  514. target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as
  515. `input_ids`.
  516. </Tip>
  517. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
  518. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of
  519. the decoder.
  520. encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
  521. Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
  522. selected in `[0, 1]`:
  523. - 1 for tokens that are **not masked**,
  524. - 0 for tokens that are **masked**.
  525. [What are attention masks?](../glossary#attention-mask)
  526. """
  527. decoder_outputs: BaseModelOutputWithPastAndCrossAttentions = self.decoder(
  528. input_ids=input_ids,
  529. attention_mask=attention_mask,
  530. encoder_attention_mask=encoder_attention_mask,
  531. encoder_hidden_states=encoder_hidden_states,
  532. past_key_values=past_key_values,
  533. inputs_embeds=inputs_embeds,
  534. use_cache=use_cache,
  535. **kwargs,
  536. )
  537. return decoder_outputs
  538. @auto_docstring(
  539. custom_intro="""
  540. The MusicGen decoder model with a language modelling head on top.
  541. """
  542. )
  543. class MusicgenForCausalLM(MusicgenPreTrainedModel, GenerationMixin):
  544. output_modalities = ("audio",)
  545. def __init__(self, config: MusicgenDecoderConfig):
  546. super().__init__(config)
  547. self.model = MusicgenModel(config)
  548. self.num_codebooks = config.num_codebooks
  549. self.lm_heads = nn.ModuleList(
  550. [nn.Linear(config.hidden_size, config.vocab_size, bias=False) for _ in range(config.num_codebooks)]
  551. )
  552. # Initialize weights and apply final processing
  553. self.post_init()
  554. def get_input_embeddings(self):
  555. return self.model.decoder.embed_tokens
  556. def set_input_embeddings(self, value):
  557. self.model.decoder.embed_tokens = value
  558. def get_output_embeddings(self):
  559. return self.lm_heads
  560. def set_output_embeddings(self, new_embeddings):
  561. self.lm_heads = new_embeddings
  562. @merge_with_config_defaults
  563. @capture_outputs
  564. @auto_docstring
  565. def forward(
  566. self,
  567. input_ids: torch.LongTensor | None = None,
  568. attention_mask: torch.Tensor | None = None,
  569. encoder_hidden_states: torch.FloatTensor | None = None,
  570. encoder_attention_mask: torch.LongTensor | None = None,
  571. past_key_values: Cache | None = None,
  572. inputs_embeds: torch.FloatTensor | None = None,
  573. labels: torch.LongTensor | None = None,
  574. use_cache: bool | None = None,
  575. **kwargs: Unpack[TransformersKwargs],
  576. ) -> tuple | CausalLMOutputWithCrossAttentions:
  577. r"""
  578. input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, sequence_length)`):
  579. Indices of input sequence tokens in the vocabulary, corresponding to the sequence of audio codes.
  580. Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes,
  581. such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details.
  582. [What are input IDs?](../glossary#input-ids)
  583. <Tip warning={true}>
  584. The `input_ids` will automatically be converted from shape `(batch_size * num_codebooks,
  585. target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If
  586. you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of
  587. frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks,
  588. target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as
  589. `input_ids`.
  590. </Tip>
  591. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
  592. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of
  593. the decoder.
  594. encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
  595. Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
  596. selected in `[0, 1]`:
  597. - 1 for tokens that are **not masked**,
  598. - 0 for tokens that are **masked**.
  599. [What are attention masks?](../glossary#attention-mask)
  600. labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
  601. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  602. `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  603. are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  604. """
  605. if (labels is not None) and (input_ids is None and inputs_embeds is None):
  606. input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.bos_token_id)
  607. outputs: BaseModelOutputWithPastAndCrossAttentions = self.model(
  608. input_ids,
  609. attention_mask=attention_mask,
  610. encoder_hidden_states=encoder_hidden_states,
  611. encoder_attention_mask=encoder_attention_mask,
  612. past_key_values=past_key_values,
  613. inputs_embeds=inputs_embeds,
  614. use_cache=use_cache,
  615. **kwargs,
  616. )
  617. hidden_states = outputs.last_hidden_state
  618. lm_logits = torch.stack([head(hidden_states) for head in self.lm_heads], dim=1)
  619. loss = None
  620. if labels is not None:
  621. # since encoder hidden states have been concatenated to the decoder hidden states,
  622. # we take the last timestamps corresponding to labels
  623. logits = lm_logits[:, :, -labels.shape[1] :]
  624. loss_fct = CrossEntropyLoss()
  625. loss = torch.zeros([], device=self.device)
  626. # per codebook cross-entropy
  627. # -100 labels are ignored
  628. labels = labels.masked_fill(labels == self.config.pad_token_id, -100)
  629. # per codebook cross-entropy
  630. # ref: https://github.com/facebookresearch/audiocraft/blob/69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85/audiocraft/solvers/musicgen.py#L242-L243
  631. for codebook in range(self.config.num_codebooks):
  632. codebook_logits = logits[:, codebook].contiguous().view(-1, logits.shape[-1])
  633. codebook_labels = labels[..., codebook].contiguous().view(-1)
  634. loss += loss_fct(codebook_logits, codebook_labels)
  635. loss = loss / self.config.num_codebooks
  636. # (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size)
  637. lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:])
  638. return CausalLMOutputWithCrossAttentions(
  639. loss=loss,
  640. logits=lm_logits,
  641. past_key_values=outputs.past_key_values,
  642. hidden_states=outputs.hidden_states,
  643. attentions=outputs.attentions,
  644. cross_attentions=outputs.cross_attentions,
  645. )
  646. def prepare_inputs_for_generation(
  647. self,
  648. input_ids,
  649. attention_mask=None,
  650. encoder_hidden_states=None,
  651. encoder_attention_mask=None,
  652. past_key_values=None,
  653. use_cache=True,
  654. delay_pattern_mask=None,
  655. guidance_scale=None,
  656. **kwargs,
  657. ):
  658. # Overwritten -- MusicGen has custom processing
  659. if delay_pattern_mask is None:
  660. input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
  661. input_ids,
  662. pad_token_id=self.generation_config.pad_token_id,
  663. max_length=self.generation_config.max_length,
  664. )
  665. # apply the delay pattern mask
  666. input_ids = self.apply_delay_pattern_mask(input_ids, delay_pattern_mask)
  667. if guidance_scale is not None and guidance_scale > 1:
  668. # for classifier free guidance we need to replicate the decoder args across the batch dim (we'll split these
  669. # before sampling)
  670. input_ids = input_ids.repeat((2, 1))
  671. if attention_mask is not None:
  672. attention_mask = attention_mask.repeat((2, 1))
  673. if past_key_values is not None:
  674. input_ids = input_ids[:, -1:]
  675. return {
  676. "input_ids": input_ids,
  677. "attention_mask": attention_mask,
  678. "encoder_hidden_states": encoder_hidden_states,
  679. "encoder_attention_mask": encoder_attention_mask,
  680. "past_key_values": past_key_values,
  681. "use_cache": use_cache,
  682. }
  683. def build_delay_pattern_mask(self, input_ids: torch.LongTensor, pad_token_id: int, max_length: int | None = None):
  684. """Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by
  685. one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
  686. are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks,
  687. seq_len)`:
  688. - [P, -1, -1, -1, -1, P, P, P]
  689. - [P, P, -1, -1, -1, -1, P, P]
  690. - [P, P, P, -1, -1, -1, -1, P]
  691. - [P, P, P, P, -1, -1, -1, -1]
  692. where P is the special padding token id and -1 indicates that the token is valid for prediction. If we include
  693. a prompt (decoder input ids), the -1 positions indicate where new tokens should be predicted. Otherwise, the
  694. mask is set to the value in the prompt:
  695. - [P, a, b, -1, -1, P, P, P]
  696. - [P, P, c, d, -1, -1, P, P]
  697. - [P, P, P, e, f, -1, -1, P]
  698. - [P, P, P, P, g, h, -1, -1]
  699. where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only override the -1
  700. tokens in our prediction.
  701. """
  702. # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
  703. input_ids = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1])
  704. bsz, num_codebooks, seq_len = input_ids.shape
  705. max_length = max_length if max_length is not None else self.generation_config.max_length
  706. input_ids_shifted = (
  707. torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1
  708. )
  709. channel_codebooks = num_codebooks // 2 if self.config.audio_channels == 2 else num_codebooks
  710. # we only apply the mask if we have a large enough seq len - otherwise we return as is
  711. if max_length < 2 * channel_codebooks - 1:
  712. return input_ids.reshape(bsz * num_codebooks, -1), input_ids_shifted.reshape(bsz * num_codebooks, -1)
  713. # fill the shifted ids with the prompt entries, offset by the codebook idx
  714. for codebook in range(channel_codebooks):
  715. if self.config.audio_channels == 1:
  716. # mono channel - loop over the codebooks one-by-one
  717. input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook]
  718. else:
  719. # left/right channels are interleaved in the generated codebooks, so handle one then the other
  720. input_ids_shifted[:, 2 * codebook, codebook : seq_len + codebook] = input_ids[:, 2 * codebook]
  721. input_ids_shifted[:, 2 * codebook + 1, codebook : seq_len + codebook] = input_ids[:, 2 * codebook + 1]
  722. # construct a pattern mask that indicates the positions of padding tokens for each codebook
  723. # first fill the upper triangular part (the EOS padding)
  724. delay_pattern = torch.triu(
  725. torch.ones((channel_codebooks, max_length), dtype=torch.bool), diagonal=max_length - channel_codebooks + 1
  726. )
  727. # then fill the lower triangular part (the BOS padding)
  728. delay_pattern = delay_pattern + torch.tril(torch.ones((channel_codebooks, max_length), dtype=torch.bool))
  729. if self.config.audio_channels == 2:
  730. # for left/right channel we need to duplicate every row of the pattern mask in an interleaved fashion
  731. delay_pattern = delay_pattern.repeat_interleave(2, dim=0)
  732. mask = ~delay_pattern.to(input_ids.device)
  733. input_ids = mask * input_ids_shifted + ~mask * pad_token_id
  734. # find the first position to start generating - this is the first place we have the -1 token
  735. # and will always be in the first codebook (since it has no codebook offset)
  736. first_codebook_ids = input_ids[:, 0, :]
  737. start_ids = (first_codebook_ids == -1).nonzero()[:, 1]
  738. if len(start_ids) > 0:
  739. first_start_id = min(start_ids)
  740. else:
  741. # we have no tokens that need to be filled - return entire matrix of input ids
  742. first_start_id = seq_len
  743. # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
  744. pattern_mask = input_ids.reshape(bsz * num_codebooks, -1)
  745. input_ids = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1)
  746. return input_ids, pattern_mask
  747. @staticmethod
  748. def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
  749. """Apply a delay pattern mask to the decoder input ids, only preserving predictions where
  750. the mask is set to -1, and otherwise setting to the value detailed in the mask."""
  751. seq_len = input_ids.shape[-1]
  752. decoder_pad_token_mask = decoder_pad_token_mask[..., :seq_len]
  753. input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask)
  754. return input_ids
  755. @torch.no_grad()
  756. def generate(
  757. self,
  758. inputs: torch.Tensor | None = None,
  759. generation_config: GenerationConfig | None = None,
  760. logits_processor: LogitsProcessorList | None = None,
  761. stopping_criteria: StoppingCriteriaList | None = None,
  762. synced_gpus: bool | None = None,
  763. streamer: Optional["BaseStreamer"] = None,
  764. **kwargs,
  765. ):
  766. """
  767. Generates sequences of token ids for models with a language modeling head.
  768. <Tip warning={true}>
  769. Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
  770. model's default generation configuration. You can override any `generation_config` by passing the corresponding
  771. parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
  772. For an overview of generation strategies and code examples, check out the [following
  773. guide](./generation_strategies).
  774. </Tip>
  775. Parameters:
  776. inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
  777. The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
  778. method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
  779. should be in the format `input_ids`. For encoder-decoder models *inputs* can represent any of
  780. `input_ids`, `input_values`, `input_features`, or `pixel_values`.
  781. generation_config (`~generation.GenerationConfig`, *optional*):
  782. The generation configuration to be used as base parametrization for the generation call. `**kwargs`
  783. passed to generate matching the attributes of `generation_config` will override them. If
  784. `generation_config` is not provided, the default will be used, which had the following loading
  785. priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
  786. configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
  787. default values, whose documentation should be checked to parameterize generation.
  788. logits_processor (`LogitsProcessorList`, *optional*):
  789. Custom logits processors that complement the default logits processors built from arguments and
  790. generation config. If a logit processor is passed that is already created with the arguments or a
  791. generation config an error is thrown. This feature is intended for advanced users.
  792. stopping_criteria (`StoppingCriteriaList`, *optional*):
  793. Custom stopping criteria that complement the default stopping criteria built from arguments and a
  794. generation config. If a stopping criteria is passed that is already created with the arguments or a
  795. generation config an error is thrown. This feature is intended for advanced users.
  796. synced_gpus (`bool`, *optional*, defaults to `False`):
  797. Whether to continue running the while loop until max_length (needed to avoid deadlocking with
  798. `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
  799. streamer (`BaseStreamer`, *optional*):
  800. Streamer object that will be used to stream the generated sequences. Generated tokens are passed
  801. through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
  802. kwargs (`dict[str, Any]`, *optional*):
  803. Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
  804. forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
  805. specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
  806. Return:
  807. [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
  808. or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
  809. If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
  810. [`~utils.ModelOutput`] types are:
  811. - [`~generation.GenerateDecoderOnlyOutput`],
  812. - [`~generation.GenerateBeamDecoderOnlyOutput`]
  813. If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
  814. [`~utils.ModelOutput`] types are:
  815. - [`~generation.GenerateEncoderDecoderOutput`],
  816. - [`~generation.GenerateBeamEncoderDecoderOutput`]
  817. """
  818. # 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects
  819. if generation_config is None:
  820. generation_config = self.generation_config
  821. generation_config = copy.deepcopy(generation_config)
  822. model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
  823. generation_config.validate()
  824. self._validate_model_kwargs(model_kwargs.copy())
  825. # 2. Set generation parameters if not already defined
  826. logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
  827. stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
  828. requires_attention_mask = "encoder_outputs" not in model_kwargs
  829. kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
  830. # 3. Define model inputs`
  831. input_ids, model_input_name, model_kwargs = self._prepare_model_inputs(
  832. inputs, generation_config.bos_token_id, model_kwargs
  833. )
  834. batch_size = input_ids.shape[0] // self.num_codebooks
  835. self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device)
  836. # 4. Define other model kwargs
  837. model_kwargs["use_cache"] = generation_config.use_cache
  838. model_kwargs["guidance_scale"] = generation_config.guidance_scale
  839. if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
  840. model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
  841. input_ids, generation_config, model_kwargs
  842. )
  843. # 5. Prepare `max_length` depending on other stopping criteria.
  844. input_ids_length = input_ids.shape[-1]
  845. has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
  846. has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
  847. generation_config = self._prepare_generated_length(
  848. generation_config=generation_config,
  849. has_default_max_length=has_default_max_length,
  850. has_default_min_length=has_default_min_length,
  851. model_input_name=model_input_name,
  852. inputs_tensor=input_ids,
  853. input_ids_length=input_ids_length,
  854. )
  855. self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
  856. # 6. Prepare the cache.
  857. # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
  858. # - different models have a different cache name expected by the model (default = "past_key_values")
  859. # - `max_length`, prepared above, is used to determine the maximum cache length
  860. max_cache_length = generation_config.max_length - 1
  861. if (
  862. input_ids_length.shape[1] != input_ids_length
  863. and model_input_name == "inputs_embeds"
  864. and not self.config.is_encoder_decoder
  865. ):
  866. max_cache_length += input_ids_length.shape[1]
  867. self._prepare_cache_for_generation(
  868. generation_config,
  869. model_kwargs,
  870. generation_mode=None,
  871. batch_size=batch_size,
  872. max_cache_length=max_cache_length,
  873. )
  874. # 7. Prepare `input_ids` which will be used for auto-regressive generation
  875. # Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen)
  876. input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
  877. input_ids,
  878. pad_token_id=generation_config._decoder_start_token_tensor,
  879. max_length=generation_config.max_length,
  880. )
  881. if streamer is not None:
  882. streamer.put(input_ids.cpu())
  883. # stash the delay mask so that we don't have to recompute it in each forward pass
  884. model_kwargs["delay_pattern_mask"] = delay_pattern_mask
  885. # 8. determine generation mode
  886. generation_mode = generation_config.get_generation_mode()
  887. # 9. prepare batched CFG externally (to enable coexistence with the unbatched CFG)
  888. if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
  889. logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
  890. generation_config.guidance_scale = None
  891. # 10. prepare distribution pre_processing samplers
  892. logits_processor = self._get_logits_processor(
  893. generation_config=generation_config,
  894. input_ids_seq_length=input_ids_length,
  895. encoder_input_ids=input_ids,
  896. prefix_allowed_tokens_fn=None,
  897. logits_processor=logits_processor,
  898. device=input_ids.device,
  899. )
  900. # 10. prepare stopping criteria
  901. stopping_criteria = self._get_stopping_criteria(
  902. generation_config=generation_config, stopping_criteria=stopping_criteria
  903. )
  904. if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
  905. # expand input_ids with `num_return_sequences` additional sequences per batch
  906. input_ids, model_kwargs = self._expand_inputs_for_generation(
  907. input_ids=input_ids,
  908. expand_size=generation_config.num_return_sequences,
  909. **model_kwargs,
  910. )
  911. # 11. run sample
  912. outputs = self._sample(
  913. input_ids,
  914. logits_processor=logits_processor,
  915. stopping_criteria=stopping_criteria,
  916. generation_config=generation_config,
  917. synced_gpus=synced_gpus,
  918. streamer=streamer,
  919. **model_kwargs,
  920. )
  921. else:
  922. raise ValueError(
  923. "Got incompatible mode for generation, should be one of greedy or sampling. "
  924. "Ensure that beam search is de-activated by setting `num_beams=1`."
  925. )
  926. if generation_config.return_dict_in_generate:
  927. output_ids = outputs.sequences
  928. else:
  929. output_ids = outputs
  930. # apply the pattern mask to the final ids
  931. output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])
  932. # revert the pattern delay mask by filtering the pad token id
  933. output_ids = output_ids[output_ids != generation_config._pad_token_tensor].reshape(
  934. batch_size, self.num_codebooks, -1
  935. )
  936. if generation_config.return_dict_in_generate:
  937. outputs.sequences = output_ids
  938. return outputs
  939. else:
  940. return output_ids
  941. @auto_docstring(
  942. custom_intro="""
  943. The composite MusicGen model with a text encoder, audio encoder and Musicgen decoder,
  944. """
  945. )
  946. class MusicgenForConditionalGeneration(MusicgenPreTrainedModel, GenerationMixin):
  947. config: MusicgenConfig
  948. output_modalities = ("audio",)
  949. base_model_prefix = "encoder_decoder"
  950. main_input_name = "input_ids"
  951. supports_gradient_checkpointing = True
  952. def __init__(
  953. self,
  954. config: MusicgenConfig | None = None,
  955. text_encoder: PreTrainedModel | None = None,
  956. audio_encoder: PreTrainedModel | None = None,
  957. decoder: MusicgenForCausalLM | None = None,
  958. ):
  959. r"""
  960. text_encoder (`PreTrainedModel`, *optional*):
  961. The text encoder model that encodes text into hidden states for conditioning.
  962. audio_encoder (`PreTrainedModel`, *optional*):
  963. The audio encoder model that encodes audio into hidden states for conditioning.
  964. decoder (`MusicgenForCausalLM`, *optional*):
  965. The decoder model that generates audio tokens based on conditioning signals.
  966. """
  967. if config is None and (text_encoder is None or audio_encoder is None or decoder is None):
  968. raise ValueError(
  969. "Either a configuration has to be provided, or all three of text encoder, audio encoder and MusicGen decoder."
  970. )
  971. if config is None:
  972. config = MusicgenConfig(
  973. text_encoder=text_encoder.config, audio_encoder=audio_encoder.config, decoder=decoder.config
  974. )
  975. else:
  976. if not isinstance(config, self.config_class):
  977. raise ValueError(f"Config: {config} has to be of type {self.config_class}")
  978. if config.decoder.cross_attention_hidden_size is not None:
  979. if config.decoder.cross_attention_hidden_size != config.text_encoder.hidden_size:
  980. raise ValueError(
  981. "If `cross_attention_hidden_size` is specified in the MusicGen decoder's configuration, it has to be equal"
  982. f" to the text encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
  983. f" `config.decoder.cross_attention_hidden_size` and {config.text_encoder.hidden_size} for"
  984. " `config.text_encoder.hidden_size`."
  985. )
  986. # initialize with config
  987. super().__init__(config)
  988. if text_encoder is None:
  989. from ..auto.modeling_auto import AutoModelForTextEncoding
  990. text_encoder = AutoModelForTextEncoding.from_config(config.text_encoder)
  991. if audio_encoder is None:
  992. from ..auto.modeling_auto import AutoModel
  993. audio_encoder = AutoModel.from_config(config.audio_encoder)
  994. if decoder is None:
  995. decoder = MusicgenForCausalLM._from_config(config.decoder)
  996. self.text_encoder = text_encoder
  997. self.audio_encoder = audio_encoder
  998. self.decoder = decoder
  999. if self.text_encoder.config.to_dict() != self.config.text_encoder.to_dict():
  1000. logger.warning(
  1001. f"Config of the text_encoder: {self.text_encoder.__class__} is overwritten by shared text_encoder config:"
  1002. f" {self.config.text_encoder}"
  1003. )
  1004. if self.audio_encoder.config.to_dict() != self.config.audio_encoder.to_dict():
  1005. logger.warning(
  1006. f"Config of the audio_encoder: {self.audio_encoder.__class__} is overwritten by shared audio_encoder config:"
  1007. f" {self.config.audio_encoder}"
  1008. )
  1009. if self.decoder.config.to_dict() != self.config.decoder.to_dict():
  1010. logger.warning(
  1011. f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
  1012. f" {self.config.decoder}"
  1013. )
  1014. # make sure that the individual model's config refers to the shared config
  1015. # so that the updates to the config will be synced
  1016. self.config.text_encoder._attn_implementation = self.text_encoder.config._attn_implementation
  1017. self.config.audio_encoder._attn_implementation = self.audio_encoder.config._attn_implementation
  1018. self.config.decoder._attn_implementation = self.decoder.config._attn_implementation
  1019. self.text_encoder.config = self.config.text_encoder
  1020. self.audio_encoder.config = self.config.audio_encoder
  1021. self.decoder.config = self.config.decoder
  1022. # text encoder outputs might need to be projected to different dimension for decoder
  1023. if (
  1024. self.text_encoder.config.hidden_size != self.decoder.config.hidden_size
  1025. and self.decoder.config.cross_attention_hidden_size is None
  1026. ):
  1027. self.enc_to_dec_proj = nn.Linear(self.text_encoder.config.hidden_size, self.decoder.config.hidden_size)
  1028. if self.text_encoder.get_output_embeddings() is not None:
  1029. raise ValueError(
  1030. f"The encoder {self.text_encoder} should not have a LM Head. Please use a model without and LM Head"
  1031. )
  1032. decoder_signature = set(inspect.signature(self.decoder.forward).parameters.keys())
  1033. if "encoder_hidden_states" not in decoder_signature:
  1034. raise ValueError(
  1035. "The selected decoder is not prepared for the encoder hidden states to be passed. Please see the "
  1036. "following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350"
  1037. )
  1038. # tie text encoder, decoder weights if config set accordingly
  1039. self.post_init()
  1040. def get_input_embeddings(self):
  1041. return self.text_encoder.get_input_embeddings()
  1042. def get_output_embeddings(self):
  1043. return self.decoder.get_output_embeddings()
  1044. def set_output_embeddings(self, new_embeddings):
  1045. return self.decoder.set_output_embeddings(new_embeddings)
  @classmethod
  def from_sub_models_pretrained(
      cls,
      text_encoder_pretrained_model_name_or_path: str | None = None,
      audio_encoder_pretrained_model_name_or_path: str | None = None,
      decoder_pretrained_model_name_or_path: str | None = None,
      *model_args,
      **kwargs,
  ) -> PreTrainedModel:
      r"""
      Instantiate a text encoder, an audio encoder, and a MusicGen decoder from one, two or three base classes of the
      library from pretrained model checkpoints, and wrap them in a new `cls` instance.

      The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
      the model, you need to first set it back in training mode with `model.train()`.

      Any sub-model can instead be passed already instantiated through the keyword arguments `text_encoder_model`,
      `audio_encoder_model` or `decoder_model`. The matching `*_pretrained_model_name_or_path` is then not required.

      Params:
          text_encoder_pretrained_model_name_or_path (`str`, *optional*):
              Information necessary to initiate the text encoder (loaded with `AutoModel`). Can be either:

                  - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                  - A path to a *directory* containing model weights saved using
                    [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.

          audio_encoder_pretrained_model_name_or_path (`str`, *optional*):
              Information necessary to initiate the audio encoder (loaded with `AutoModel`). Can be either:

                  - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                  - A path to a *directory* containing model weights saved using
                    [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.

          decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
              Information necessary to initiate the decoder (loaded with `MusicgenForCausalLM`). If the checkpoint's
              configuration is a full `MusicgenConfig`, only its `decoder` sub-configuration is used. Can be either:

                  - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                  - A path to a *directory* containing model weights saved using
                    [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.

          model_args (remaining positional arguments, *optional*):
              Forwarded to `AutoModel.from_pretrained` for the text encoder and the audio encoder. They are not
              forwarded to the decoder.

          kwargs (remaining dictionary of keyword arguments, *optional*):
              Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
              `output_attentions=True`).

                  - To update the text encoder configuration, use the prefix *text_encoder_* for each configuration
                    parameter.
                  - To update the audio encoder configuration, use the prefix *audio_encoder_* for each configuration
                    parameter.
                  - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
                  - To update the parent model configuration, do not use a prefix for each configuration parameter.

              Passing `text_encoder_config`, `audio_encoder_config` or `decoder_config` skips loading that
              sub-model's configuration from the checkpoint, and the given object is used unchanged. For a
              user-supplied `decoder_config`, a warning is logged if `is_decoder` or `add_cross_attention` is `False`.

              Behaves differently depending on whether a `config` is provided or automatically loaded.

      Returns:
          A `cls` instance holding the three sub-models. Its `MusicgenConfig` combines the three sub-model
          configurations with the un-prefixed keyword arguments.

      Raises:
          ValueError: if a sub-model is neither passed in (`*_model`) nor given a pretrained name or path.

      Example:

      ```python
      >>> from transformers import MusicgenForConditionalGeneration

      >>> # initialize a musicgen model from a t5 text encoder, encodec audio encoder, and musicgen decoder
      >>> model = MusicgenForConditionalGeneration.from_sub_models_pretrained(
      ...     text_encoder_pretrained_model_name_or_path="google-t5/t5-base",
      ...     audio_encoder_pretrained_model_name_or_path="facebook/encodec_24khz",
      ...     decoder_pretrained_model_name_or_path="facebook/musicgen-small",
      ... )
      >>> # saving model after fine-tuning
      >>> model.save_pretrained("./musicgen-ft")
      >>> # load fine-tuned model
      >>> model = MusicgenForConditionalGeneration.from_pretrained("./musicgen-ft")
      ```"""

      # Route prefixed kwargs to each sub-model and strip the prefix (e.g. `decoder_config` -> `config`).
      kwargs_text_encoder = {
          argument[len("text_encoder_") :]: value
          for argument, value in kwargs.items()
          if argument.startswith("text_encoder_")
      }

      kwargs_audio_encoder = {
          argument[len("audio_encoder_") :]: value
          for argument, value in kwargs.items()
          if argument.startswith("audio_encoder_")
      }

      kwargs_decoder = {
          argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
      }

      # remove text encoder, audio encoder and decoder kwargs from kwargs; whatever remains is passed to the parent
      # `MusicgenConfig` below
      for key in kwargs_text_encoder:
          del kwargs["text_encoder_" + key]
      for key in kwargs_audio_encoder:
          del kwargs["audio_encoder_" + key]
      for key in kwargs_decoder:
          del kwargs["decoder_" + key]

      # Load and initialize the encoder and decoder
      # The distinction between encoder and decoder at the model level is made
      # by the value of the flag `is_decoder` that we need to set correctly.
      # An already-built text encoder may be supplied as `text_encoder_model`.
      text_encoder = kwargs_text_encoder.pop("model", None)
      if text_encoder is None:
          if text_encoder_pretrained_model_name_or_path is None:
              raise ValueError(
                  "If `text_encoder_model` is not defined as an argument, a `text_encoder_pretrained_model_name_or_path` has "
                  "to be defined."
              )

          if "config" not in kwargs_text_encoder:
              # kwargs the config does not consume are returned and later passed on to the model loader
              encoder_config, kwargs_text_encoder = AutoConfig.from_pretrained(
                  text_encoder_pretrained_model_name_or_path, **kwargs_text_encoder, return_unused_kwargs=True
              )

              # force plain encoder behaviour even if the checkpoint was saved as a decoder
              if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
                  logger.info(
                      f"Initializing {text_encoder_pretrained_model_name_or_path} as a text_encoder model "
                      "from a decoder model. Cross-attention and causal mask are disabled."
                  )
                  encoder_config.is_decoder = False
                  encoder_config.add_cross_attention = False

              kwargs_text_encoder["config"] = encoder_config

          text_encoder = AutoModel.from_pretrained(
              text_encoder_pretrained_model_name_or_path, *model_args, **kwargs_text_encoder
          )

      # Same procedure for the audio encoder (may be supplied as `audio_encoder_model`).
      audio_encoder = kwargs_audio_encoder.pop("model", None)
      if audio_encoder is None:
          if audio_encoder_pretrained_model_name_or_path is None:
              raise ValueError(
                  "If `audio_encoder_model` is not defined as an argument, an `audio_encoder_pretrained_model_name_or_path` has "
                  "to be defined."
              )

          if "config" not in kwargs_audio_encoder:
              encoder_config, kwargs_audio_encoder = AutoConfig.from_pretrained(
                  audio_encoder_pretrained_model_name_or_path, **kwargs_audio_encoder, return_unused_kwargs=True
              )

              if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
                  logger.info(
                      f"Initializing {audio_encoder_pretrained_model_name_or_path} as an audio_encoder model "
                      "from a decoder model. Cross-attention and causal mask are disabled."
                  )
                  encoder_config.is_decoder = False
                  encoder_config.add_cross_attention = False

              kwargs_audio_encoder["config"] = encoder_config

          audio_encoder = AutoModel.from_pretrained(
              audio_encoder_pretrained_model_name_or_path, *model_args, **kwargs_audio_encoder
          )

      # Decoder (may be supplied as `decoder_model`).
      decoder = kwargs_decoder.pop("model", None)
      if decoder is None:
          if decoder_pretrained_model_name_or_path is None:
              raise ValueError(
                  "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
                  "to be defined."
              )

          if "config" not in kwargs_decoder:
              decoder_config, kwargs_decoder = AutoConfig.from_pretrained(
                  decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
              )

              # a full MusicGen checkpoint yields a composite config; keep only its decoder part
              if isinstance(decoder_config, MusicgenConfig):
                  decoder_config = decoder_config.decoder

              # the decoder must run as a decoder with cross-attention onto the text-encoder states
              if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
                  logger.info(
                      f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
                      f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
                      f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
                  )
                  decoder_config.is_decoder = True
                  decoder_config.add_cross_attention = True

              kwargs_decoder["config"] = decoder_config

          # a user-supplied `decoder_config` is not modified above, so only warn about it
          if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
              logger.warning(
                  f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
                  f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
                  "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
                  "passed to `.from_sub_models_pretrained(...)` are set to `True` or do not pass a "
                  "`decoder_config` to `.from_sub_models_pretrained(...)`"
              )

          # NOTE: unlike the encoders, `*model_args` is not forwarded here
          decoder = MusicgenForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)

      # instantiate config with corresponding kwargs (the un-prefixed leftovers configure the parent model)
      config = MusicgenConfig(
          text_encoder=text_encoder.config, audio_encoder=audio_encoder.config, decoder=decoder.config, **kwargs
      )

      return cls(text_encoder=text_encoder, audio_encoder=audio_encoder, decoder=decoder, config=config)
  @can_return_tuple
  @auto_docstring
  def forward(
      self,
      input_ids: torch.LongTensor | None = None,
      attention_mask: torch.BoolTensor | None = None,
      input_values: torch.FloatTensor | None = None,
      padding_mask: torch.BoolTensor | None = None,
      decoder_input_ids: torch.LongTensor | None = None,
      decoder_attention_mask: torch.BoolTensor | None = None,
      encoder_outputs: tuple[torch.FloatTensor] | None = None,
      past_key_values: Cache | None = None,
      inputs_embeds: torch.FloatTensor | None = None,
      decoder_inputs_embeds: torch.FloatTensor | None = None,
      labels: torch.LongTensor | None = None,
      use_cache: bool | None = None,
      **kwargs: Unpack[TransformersKwargs],
  ) -> tuple | Seq2SeqLMOutput:
      r"""
      padding_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
          Mask for padded positions in `input_values`. It is only forwarded to the audio encoder, which is called
          when none of `labels`, `decoder_input_ids` or `decoder_inputs_embeds` is given. Mask values selected in
          `[0, 1]`:

          - 1 for tokens that are **not masked**,
          - 0 for tokens that are **masked**.

          [What are attention masks?](../glossary#attention-mask)
      decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)`, *optional*):
          Indices of decoder input sequence tokens in the vocabulary, corresponding to the sequence of audio codes.

          Indices can be obtained by encoding an audio prompt with an audio encoder model to predict audio codes,
          such as with the [`EncodecModel`]. See [`EncodecModel.encode`] for details.

          [What are decoder input IDs?](../glossary#decoder-input-ids)

          <Tip warning={true}>

          The `decoder_input_ids` will automatically be converted from shape `(batch_size * num_codebooks,
          target_sequence_length)` to `(batch_size, num_codebooks, target_sequence_length)` in the forward pass. If
          you obtain audio codes from an audio encoding model, such as [`EncodecModel`], ensure that the number of
          frames is equal to 1, and that you reshape the audio codes from `(frames, batch_size, num_codebooks,
          target_sequence_length)` to `(batch_size * num_codebooks, target_sequence_length)` prior to passing them as
          `decoder_input_ids`.

          </Tip>
      decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
          Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
          be used by default.
      labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
          Labels for language modeling over the audio codes. When neither `decoder_input_ids` nor
          `decoder_inputs_embeds` is passed, the decoder inputs are built by shifting the labels to the right inside
          the model. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100` are
          ignored (masked), and the loss is only computed for labels in `[0, ..., config.vocab_size]`.

      Examples:
      ```python
      >>> from transformers import AutoProcessor, MusicgenForConditionalGeneration
      >>> import torch

      >>> processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
      >>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")

      >>> inputs = processor(
      ...     text=["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"],
      ...     padding=True,
      ...     return_tensors="pt",
      ... )

      >>> pad_token_id = model.generation_config.pad_token_id
      >>> decoder_input_ids = (
      ...     torch.ones((inputs.input_ids.shape[0] * model.decoder.num_codebooks, 1), dtype=torch.long)
      ...     * pad_token_id
      ... )

      >>> logits = model(**inputs, decoder_input_ids=decoder_input_ids).logits
      >>> logits.shape  # (bsz * num_codebooks, tgt_len, vocab_size)
      torch.Size([8, 1, 2048])
      ```"""
      # Route prefixed kwargs (`text_encoder_*`, `audio_encoder_*`, `decoder_*`) to the matching sub-model with the
      # prefix stripped. Un-prefixed kwargs go to both the text encoder and the decoder, but not to the audio encoder.
      kwargs_text_encoder = {}
      kwargs_audio_encoder = {}
      kwargs_decoder = {}
      common_kwargs = {}
      for key, value in kwargs.items():
          if key.startswith("text_encoder_"):
              kwargs_text_encoder[key[len("text_encoder_") :]] = value
          elif key.startswith("audio_encoder_"):
              kwargs_audio_encoder[key[len("audio_encoder_") :]] = value
          elif key.startswith("decoder_"):
              kwargs_decoder[key[len("decoder_") :]] = value
          else:
              common_kwargs[key] = value

      # Encode the text conditioning unless precomputed `encoder_outputs` were given. Tuples are wrapped so that
      # the attribute access in the return statement works.
      if encoder_outputs is None:
          encoder_outputs = self.text_encoder(
              input_ids=input_ids,
              attention_mask=attention_mask,
              inputs_embeds=inputs_embeds,
              **kwargs_text_encoder,
              **common_kwargs,
          )
      elif isinstance(encoder_outputs, tuple):
          encoder_outputs = BaseModelOutput(*encoder_outputs)

      encoder_hidden_states = encoder_outputs[0]

      # optionally project encoder_hidden_states: only when the hidden sizes differ and the decoder does not declare
      # its own `cross_attention_hidden_size`
      if (
          self.text_encoder.config.hidden_size != self.decoder.config.hidden_size
          and self.decoder.config.cross_attention_hidden_size is None
      ):
          encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)

      # multiply by the text attention mask so that masked (0) positions contribute zero conditioning
      if attention_mask is not None:
          encoder_hidden_states = encoder_hidden_states * attention_mask[..., None]

      # Decoder inputs, in order of precedence:
      #   1. explicit `decoder_input_ids` / `decoder_inputs_embeds` (used as-is),
      #   2. `labels` shifted one step to the right,
      #   3. audio codes obtained by encoding `input_values` with the audio encoder.
      if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
          decoder_input_ids = shift_tokens_right(
              labels, self.config.decoder.pad_token_id, self.config.decoder.decoder_start_token_id
          )

      elif decoder_input_ids is None and decoder_inputs_embeds is None:
          audio_encoder_outputs = self.audio_encoder(
              input_values=input_values,
              padding_mask=padding_mask,
              **kwargs_audio_encoder,
          )
          audio_codes = audio_encoder_outputs.audio_codes
          frames, bsz, codebooks, seq_len = audio_codes.shape
          if frames != 1:
              raise ValueError(
                  f"Expected 1 frame in the audio code outputs, got {frames} frames. Ensure chunking is "
                  "disabled by setting `chunk_length=None` in the audio encoder."
              )

          # mono input through encodec that we convert to stereo: when a stereo decoder expects twice as many
          # codebooks as were produced, each codebook is duplicated along dim 2
          if self.config.decoder.audio_channels == 2 and audio_codes.shape[2] == self.decoder.num_codebooks // 2:
              audio_codes = audio_codes.repeat_interleave(2, dim=2)

          # drop the single frame dimension and fold the codebooks into the batch dimension
          decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len)

      # Decode (the loss, if any, is computed by the decoder from `labels`)
      decoder_outputs: CausalLMOutputWithCrossAttentions = self.decoder(
          input_ids=decoder_input_ids,
          attention_mask=decoder_attention_mask,
          encoder_hidden_states=encoder_hidden_states,
          encoder_attention_mask=attention_mask,
          inputs_embeds=decoder_inputs_embeds,
          use_cache=use_cache,
          past_key_values=past_key_values,
          labels=labels,
          **kwargs_decoder,
          **common_kwargs,
      )

      # NOTE: `encoder_last_hidden_state` is the raw text-encoder output, i.e. taken before the optional projection
      # and the attention-mask multiplication applied above.
      return Seq2SeqLMOutput(
          loss=decoder_outputs.loss,
          logits=decoder_outputs.logits,
          past_key_values=decoder_outputs.past_key_values,
          decoder_hidden_states=decoder_outputs.hidden_states,
          decoder_attentions=decoder_outputs.attentions,
          cross_attentions=decoder_outputs.cross_attentions,
          encoder_last_hidden_state=encoder_outputs.last_hidden_state,
          encoder_hidden_states=encoder_outputs.hidden_states,
          encoder_attentions=encoder_outputs.attentions,
      )
  1346. def prepare_inputs_for_generation(
  1347. self,
  1348. decoder_input_ids,
  1349. next_sequence_length: int | None = None,
  1350. past_key_values=None,
  1351. attention_mask=None,
  1352. decoder_attention_mask=None,
  1353. use_cache=None,
  1354. encoder_outputs=None,
  1355. decoder_delay_pattern_mask=None,
  1356. guidance_scale=None,
  1357. **kwargs,
  1358. ):
  1359. # Overwritten -- MusicGen has custom processing
  1360. if decoder_delay_pattern_mask is None:
  1361. decoder_input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
  1362. decoder_input_ids,
  1363. self.generation_config.pad_token_id,
  1364. max_length=self.generation_config.max_length,
  1365. )
  1366. # apply the delay pattern mask
  1367. decoder_input_ids = self.decoder.apply_delay_pattern_mask(decoder_input_ids, decoder_delay_pattern_mask)
  1368. if guidance_scale is not None and guidance_scale > 1:
  1369. # for classifier free guidance we need to replicate the decoder args across the batch dim (we'll split these
  1370. # before sampling)
  1371. decoder_input_ids = decoder_input_ids.repeat((2, 1))
  1372. if decoder_attention_mask is not None:
  1373. decoder_attention_mask = decoder_attention_mask.repeat((2, 1))
  1374. if past_key_values is not None:
  1375. decoder_input_ids = (
  1376. decoder_input_ids[:, -next_sequence_length:] if next_sequence_length is not None else decoder_input_ids
  1377. )
  1378. return {
  1379. "input_ids": None, # encoder_outputs is defined. input_ids not needed
  1380. "encoder_outputs": encoder_outputs,
  1381. "past_key_values": past_key_values,
  1382. "decoder_input_ids": decoder_input_ids,
  1383. "attention_mask": attention_mask,
  1384. "decoder_attention_mask": decoder_attention_mask,
  1385. "use_cache": use_cache,
  1386. }
  1387. def _prepare_decoder_input_ids_for_generation(
  1388. self,
  1389. batch_size: int,
  1390. model_input_name: str,
  1391. model_kwargs: dict[str, torch.Tensor],
  1392. decoder_start_token_id: int | None = None,
  1393. bos_token_id: int | None = None,
  1394. device: torch.device | None = None,
  1395. ) -> tuple[torch.LongTensor, dict[str, torch.Tensor]]:
  1396. """Prepares `decoder_input_ids` for generation with encoder-decoder models"""
  1397. # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming,
  1398. # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input.
  1399. if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
  1400. decoder_input_ids = model_kwargs.pop("decoder_input_ids")
  1401. elif "input_ids" in model_kwargs and model_input_name != "input_ids":
  1402. decoder_input_ids = model_kwargs.pop("input_ids")
  1403. else:
  1404. decoder_input_ids = None
  1405. # 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that.
  1406. decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
  1407. if device is None:
  1408. device = self.device
  1409. decoder_input_ids_start = (
  1410. torch.ones((batch_size * self.decoder.num_codebooks, 1), dtype=torch.long, device=device)
  1411. * decoder_start_token_id
  1412. )
  1413. # no user input -> use decoder_start_token_id as decoder_input_ids
  1414. if decoder_input_ids is None:
  1415. decoder_input_ids = decoder_input_ids_start
  1416. # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
  1417. # decoder_attention_mask if provided)
  1418. elif (decoder_input_ids[..., 0] != decoder_start_token_id).all().item():
  1419. decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1)
  1420. if "decoder_attention_mask" in model_kwargs:
  1421. decoder_attention_mask = model_kwargs["decoder_attention_mask"]
  1422. decoder_attention_mask = torch.cat(
  1423. (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask),
  1424. dim=-1,
  1425. )
  1426. model_kwargs["decoder_attention_mask"] = decoder_attention_mask
  1427. return decoder_input_ids, model_kwargs
  def _prepare_text_encoder_kwargs_for_generation(
      self,
      inputs_tensor: torch.Tensor,
      model_kwargs,
      model_input_name: str | None,
      generation_config: GenerationConfig,
  ) -> dict[str, Any]:
      """Run the text encoder once and store its result as `model_kwargs["encoder_outputs"]`.

      Kwargs starting with `decoder_`, `cross_attn` or `use_cache` are never forwarded. If the encoder's `forward`
      has no `kwargs`/`model_kwargs` catch-all, only kwargs named in its signature are forwarded.
      `output_attentions`/`output_hidden_states` are taken from `generation_config`.

      For classifier-free guidance (`generation_config.guidance_scale > 1`), an all-zero copy of the last hidden
      state is appended along the batch dimension. If present, `model_kwargs["attention_mask"]` is extended the
      same way with zeros.

      Returns:
          `model_kwargs`, updated in place with `encoder_outputs` (a `BaseModelOutput` holding only
          `last_hidden_state`).
      """
      text_encoder = self.get_encoder()
      # Compatibility with Accelerate big model inference: the encoder must return outputs on the inputs' device.
      if hasattr(text_encoder, "_hf_hook"):
          text_encoder._hf_hook.io_same_device = True

      excluded_prefixes = ("decoder_", "cross_attn", "use_cache")
      forward_params = set(inspect.signature(text_encoder.forward).parameters)
      accepts_everything = "kwargs" in forward_params or "model_kwargs" in forward_params
      encoder_kwargs = {}
      for name, value in model_kwargs.items():
          if name.startswith(excluded_prefixes):
              continue
          if accepts_everything or name in forward_params:
              encoder_kwargs[name] = value

      encoder_kwargs["output_attentions"] = generation_config.output_attentions
      encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states
      guidance_scale = generation_config.guidance_scale

      # make sure that the encoder returns a `ModelOutput`
      if model_input_name is None:
          model_input_name = self.text_encoder.main_input_name
      encoder_kwargs["return_dict"] = True
      encoder_kwargs[model_input_name] = inputs_tensor
      hidden = text_encoder(**encoder_kwargs).last_hidden_state

      if guidance_scale is not None and guidance_scale > 1:
          # classifier-free guidance: append a 'null' (all-zero) conditioning batch
          hidden = torch.concatenate([hidden, torch.zeros_like(hidden)], dim=0)
          if "attention_mask" in model_kwargs:
              text_mask = model_kwargs["attention_mask"]
              model_kwargs["attention_mask"] = torch.concatenate([text_mask, torch.zeros_like(text_mask)], dim=0)

      model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=hidden)
      return model_kwargs
  1471. def _prepare_audio_encoder_kwargs_for_generation(
  1472. self, input_values, model_kwargs, model_input_name: str | None = None
  1473. ):
  1474. # 1. get audio encoder
  1475. encoder = self.get_encoder(modality="audio")
  1476. # Compatibility with Accelerate big model inference: we need the encoder to outputs stuff on the same device
  1477. # as the inputs.
  1478. if hasattr(encoder, "_hf_hook"):
  1479. encoder._hf_hook.io_same_device = True
  1480. # 2. Prepare encoder args and encoder kwargs from model kwargs.
  1481. irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
  1482. encoder_kwargs = {
  1483. argument: value
  1484. for argument, value in model_kwargs.items()
  1485. if not any(argument.startswith(p) for p in irrelevant_prefix)
  1486. }
  1487. encoder_signature = set(inspect.signature(encoder.forward).parameters)
  1488. encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
  1489. if not encoder_accepts_wildcard:
  1490. encoder_kwargs = {
  1491. argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
  1492. }
  1493. # 3. make sure that encoder returns `ModelOutput`
  1494. model_input_name = model_input_name if model_input_name is not None else self.audio_encoder.main_input_name
  1495. encoder_kwargs["return_dict"] = True
  1496. if self.decoder.config.audio_channels == 1:
  1497. encoder_kwargs[model_input_name] = input_values
  1498. audio_encoder_outputs = encoder.encode(**encoder_kwargs)
  1499. audio_codes = audio_encoder_outputs.audio_codes
  1500. audio_scales = audio_encoder_outputs.audio_scales
  1501. frames, bsz, codebooks, seq_len = audio_codes.shape
  1502. else:
  1503. if input_values.shape[1] != 2:
  1504. raise ValueError(
  1505. f"Expected stereo audio (2-channels) but example has {input_values.shape[1]} channel."
  1506. )
  1507. encoder_kwargs[model_input_name] = input_values[:, :1, :]
  1508. audio_encoder_outputs_left = encoder.encode(**encoder_kwargs)
  1509. audio_codes_left = audio_encoder_outputs_left.audio_codes
  1510. audio_scales_left = audio_encoder_outputs_left.audio_scales
  1511. encoder_kwargs[model_input_name] = input_values[:, 1:, :]
  1512. audio_encoder_outputs_right = encoder.encode(**encoder_kwargs)
  1513. audio_codes_right = audio_encoder_outputs_right.audio_codes
  1514. audio_scales_right = audio_encoder_outputs_right.audio_scales
  1515. frames, bsz, codebooks, seq_len = audio_codes_left.shape
  1516. # copy alternating left/right channel codes into stereo codebook
  1517. audio_codes = audio_codes_left.new_ones((frames, bsz, 2 * codebooks, seq_len))
  1518. audio_codes[:, :, ::2, :] = audio_codes_left
  1519. audio_codes[:, :, 1::2, :] = audio_codes_right
  1520. if audio_scales_left != [None] or audio_scales_right != [None]:
  1521. audio_scales = torch.stack([audio_scales_left, audio_scales_right], dim=1)
  1522. else:
  1523. audio_scales = [None] * bsz
  1524. if frames != 1:
  1525. raise ValueError(
  1526. f"Expected 1 frame in the audio code outputs, got {frames} frames. Ensure chunking is "
  1527. "disabled by setting `chunk_length=None` in the audio encoder."
  1528. )
  1529. decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len)
  1530. model_kwargs["decoder_input_ids"] = decoder_input_ids
  1531. model_kwargs["audio_scales"] = audio_scales
  1532. return model_kwargs
  def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
      """Build teacher-forcing `decoder_input_ids` by shifting `labels` one step to the right.

      The decoder configuration's `pad_token_id` and `bos_token_id` are handed to `shift_tokens_right`.
      """
      decoder_config = self.config.decoder
      # NOTE(review): `forward` shifts labels with `decoder_start_token_id`, whereas this helper uses
      # `bos_token_id` -- confirm the two are meant to agree.
      return shift_tokens_right(labels, decoder_config.pad_token_id, decoder_config.bos_token_id)
  def resize_token_embeddings(self, *args, **kwargs):
      """Not supported on the composite model.

      The embeddings belong to the wrapped sub-models, so resizing has to be done on them directly, e.g.
      `model.text_encoder.resize_token_embeddings(...)` or `model.decoder.resize_token_embeddings(...)`.

      Raises:
          NotImplementedError: always.
      """
      # The message names this class and its real sub-model attributes (`text_encoder`, `decoder`); this model
      # has no `encoder` attribute.
      raise NotImplementedError(
          f"Resizing the embedding layers via the {type(self).__name__} directly is not supported. Please use the"
          " respective methods of the wrapped objects (model.text_encoder.resize_token_embeddings(...) or"
          " model.decoder.resize_token_embeddings(...))"
      )
  def freeze_audio_encoder(self):
      """
      Freeze the audio encoder weights.

      Turns off `requires_grad` on every parameter of `self.audio_encoder` and sets its `_requires_grad` flag to
      `False`.
      """
      audio_encoder = self.audio_encoder
      for weight in audio_encoder.parameters():
          weight.requires_grad = False
      audio_encoder._requires_grad = False
  def freeze_text_encoder(self):
      """
      Freeze the text encoder weights.

      Turns off `requires_grad` on every parameter of `self.text_encoder` and sets its `_requires_grad` flag to
      `False`.
      """
      text_encoder = self.text_encoder
      for weight in text_encoder.parameters():
          weight.requires_grad = False
      text_encoder._requires_grad = False
  1555. def _maybe_initialize_input_ids_for_generation(
  1556. self,
  1557. inputs: torch.Tensor | None,
  1558. bos_token_id: int | None,
  1559. model_kwargs: dict[str, torch.Tensor],
  1560. ) -> torch.LongTensor:
  1561. """Initializes input ids for generation, if necessary."""
  1562. if inputs is not None:
  1563. return inputs
  1564. encoder_outputs = model_kwargs.get("encoder_outputs")
  1565. if encoder_outputs is not None:
  1566. # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
  1567. shape = encoder_outputs[0].size()[:-1]
  1568. return torch.ones(shape, dtype=torch.long, device=self.device) * -100
  1569. if bos_token_id is None:
  1570. raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
  1571. # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with
  1572. # soft-prompting or in multimodal implementations built on top of decoder-only language models.
  1573. batch_size = 1
  1574. for value in model_kwargs.values():
  1575. if isinstance(value, torch.Tensor):
  1576. batch_size = value.shape[0]
  1577. break
  1578. return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
  1579. def _get_decoder_start_token_id(
  1580. self, decoder_start_token_id: int | list[int] | None = None, bos_token_id: int | None = None
  1581. ) -> int:
  1582. decoder_start_token_id = (
  1583. decoder_start_token_id
  1584. if decoder_start_token_id is not None
  1585. else self.generation_config.decoder_start_token_id
  1586. )
  1587. bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
  1588. if decoder_start_token_id is not None:
  1589. return decoder_start_token_id
  1590. elif bos_token_id is not None:
  1591. return bos_token_id
  1592. raise ValueError(
  1593. "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
  1594. )
  1595. @torch.no_grad()
  1596. def generate(
  1597. self,
  1598. inputs: torch.Tensor | None = None,
  1599. generation_config: GenerationConfig | None = None,
  1600. logits_processor: LogitsProcessorList | None = None,
  1601. stopping_criteria: StoppingCriteriaList | None = None,
  1602. synced_gpus: bool | None = None,
  1603. streamer: Optional["BaseStreamer"] = None,
  1604. **kwargs,
  1605. ):
  1606. """
  1607. Generates sequences of token ids for models with a language modeling head.
  1608. <Tip warning={true}>
  1609. Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
  1610. model's default generation configuration. You can override any `generation_config` by passing the corresponding
  1611. parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
  1612. For an overview of generation strategies and code examples, check out the [following
  1613. guide](./generation_strategies).
  1614. </Tip>
  1615. Parameters:
  1616. inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
  1617. The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
  1618. method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
  1619. should be in the format `input_ids`. For encoder-decoder models *inputs* can represent any of
  1620. `input_ids`, `input_values`, `input_features`, or `pixel_values`.
  1621. generation_config (`~generation.GenerationConfig`, *optional*):
  1622. The generation configuration to be used as base parametrization for the generation call. `**kwargs`
  1623. passed to generate matching the attributes of `generation_config` will override them. If
  1624. `generation_config` is not provided, the default will be used, which had the following loading
  1625. priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
  1626. configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
  1627. default values, whose documentation should be checked to parameterize generation.
  1628. logits_processor (`LogitsProcessorList`, *optional*):
  1629. Custom logits processors that complement the default logits processors built from arguments and
  1630. generation config. If a logit processor is passed that is already created with the arguments or a
  1631. generation config an error is thrown. This feature is intended for advanced users.
  1632. stopping_criteria (`StoppingCriteriaList`, *optional*):
  1633. Custom stopping criteria that complement the default stopping criteria built from arguments and a
  1634. generation config. If a stopping criteria is passed that is already created with the arguments or a
  1635. generation config an error is thrown. This feature is intended for advanced users.
  1636. synced_gpus (`bool`, *optional*, defaults to `False`):
  1637. Whether to continue running the while loop until max_length (needed to avoid deadlocking with
  1638. `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
  1639. streamer (`BaseStreamer`, *optional*):
  1640. Streamer object that will be used to stream the generated sequences. Generated tokens are passed
  1641. through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
  1642. kwargs (`dict[str, Any]`, *optional*):
  1643. Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
  1644. forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
  1645. specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
  1646. Return:
  1647. [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
  1648. or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
  1649. If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
  1650. [`~utils.ModelOutput`] types are:
  1651. - [`~generation.GenerateDecoderOnlyOutput`],
  1652. - [`~generation.GenerateBeamDecoderOnlyOutput`]
  1653. If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
  1654. [`~utils.ModelOutput`] types are:
  1655. - [`~generation.GenerateEncoderDecoderOutput`],
  1656. - [`~generation.GenerateBeamEncoderDecoderOutput`]
  1657. """
  1658. # 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects
  1659. generation_mode_kwargs = self._extract_generation_mode_kwargs(None, kwargs, False, None, None)
  1660. generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
  1661. generation_mode = generation_config.get_generation_mode()
  1662. if generation_mode not in [GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH]:
  1663. raise ValueError(
  1664. "Got incompatible mode for generation, should be one of greedy or sampling. "
  1665. "Ensure that beam search is de-activated by setting `num_beams=1`."
  1666. )
  1667. self._validate_model_kwargs(model_kwargs.copy())
  1668. self._validate_generation_mode(generation_mode, generation_config, generation_mode_kwargs)
  1669. if model_kwargs.get("encoder_outputs") is not None and type(model_kwargs["encoder_outputs"]) is tuple:
  1670. # wrap the unconditional outputs as a BaseModelOutput for compatibility with the rest of generate
  1671. model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=model_kwargs["encoder_outputs"][0])
  1672. # 2. Set generation parameters if not already defined
  1673. logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
  1674. stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
  1675. requires_attention_mask = "encoder_outputs" not in model_kwargs
  1676. kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
  1677. # 3. Define model inputs
  1678. inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
  1679. inputs, generation_config.bos_token_id, model_kwargs
  1680. )
  1681. batch_size = inputs_tensor.shape[0]
  1682. self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device)
  1683. # 4. Define other model kwargs
  1684. model_kwargs["use_cache"] = generation_config.use_cache
  1685. model_kwargs["guidance_scale"] = generation_config.guidance_scale
  1686. if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
  1687. model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
  1688. inputs_tensor, generation_config, model_kwargs
  1689. )
  1690. if "encoder_outputs" not in model_kwargs:
  1691. # encoder_outputs are created and added to `model_kwargs`
  1692. model_kwargs = self._prepare_text_encoder_kwargs_for_generation(
  1693. inputs_tensor, model_kwargs, model_input_name, generation_config
  1694. )
  1695. if "decoder_input_ids" not in model_kwargs and "input_values" in model_kwargs:
  1696. model_kwargs = self._prepare_audio_encoder_kwargs_for_generation(
  1697. model_kwargs["input_values"],
  1698. model_kwargs,
  1699. )
  1700. # 5. Prepare `input_ids` which will be used for auto-regressive generation
  1701. input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
  1702. batch_size=batch_size,
  1703. model_input_name=model_input_name,
  1704. model_kwargs=model_kwargs,
  1705. decoder_start_token_id=generation_config._decoder_start_token_tensor,
  1706. bos_token_id=generation_config._bos_token_tensor,
  1707. device=inputs_tensor.device,
  1708. )
  1709. # 6. Prepare `max_length` depending on other stopping criteria.
  1710. input_ids_length = input_ids.shape[-1]
  1711. has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
  1712. has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
  1713. generation_config = self._prepare_generated_length(
  1714. generation_config=generation_config,
  1715. has_default_max_length=has_default_max_length,
  1716. has_default_min_length=has_default_min_length,
  1717. model_input_name=model_input_name,
  1718. inputs_tensor=inputs_tensor,
  1719. input_ids_length=input_ids_length,
  1720. )
  1721. # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen)
  1722. input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
  1723. input_ids,
  1724. pad_token_id=generation_config._decoder_start_token_tensor,
  1725. max_length=generation_config.max_length,
  1726. )
  1727. # stash the delay mask so that we don't have to recompute in each forward pass
  1728. model_kwargs["decoder_delay_pattern_mask"] = decoder_delay_pattern_mask
  1729. # input_ids are ready to be placed on the streamer (if used)
  1730. if streamer is not None:
  1731. streamer.put(input_ids.cpu())
  1732. # 7. determine generation mode
  1733. generation_mode = generation_config.get_generation_mode()
  1734. # 8. prepare batched CFG externally (to enable coexistence with the unbatched CFG)
  1735. if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
  1736. logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
  1737. generation_config.guidance_scale = None
  1738. # 9. prepare distribution pre_processing samplers
  1739. logits_processor = self._get_logits_processor(
  1740. generation_config=generation_config,
  1741. input_ids_seq_length=input_ids_length,
  1742. encoder_input_ids=inputs_tensor,
  1743. prefix_allowed_tokens_fn=None,
  1744. logits_processor=logits_processor,
  1745. device=input_ids.device,
  1746. )
  1747. # 10. prepare stopping criteria
  1748. stopping_criteria = self._get_stopping_criteria(
  1749. generation_config=generation_config, stopping_criteria=stopping_criteria
  1750. )
  1751. # expand input_ids with `num_return_sequences` additional sequences per batch
  1752. input_ids, model_kwargs = self._expand_inputs_for_generation(
  1753. input_ids=input_ids,
  1754. expand_size=generation_config.num_return_sequences,
  1755. is_encoder_decoder=self.config.is_encoder_decoder,
  1756. **model_kwargs,
  1757. )
  1758. # 10b. prepare prefill outputs
  1759. generation_mode_kwargs["prefill_outputs"] = self._prefill(input_ids, generation_config, model_kwargs)
  1760. # 11. run sample
  1761. outputs = self._sample(
  1762. input_ids,
  1763. logits_processor=logits_processor,
  1764. stopping_criteria=stopping_criteria,
  1765. generation_config=generation_config,
  1766. **generation_mode_kwargs,
  1767. **model_kwargs,
  1768. )
  1769. if generation_config.return_dict_in_generate:
  1770. output_ids = outputs.sequences
  1771. else:
  1772. output_ids = outputs
  1773. # apply the pattern mask to the final ids
  1774. output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])
  1775. # revert the pattern delay mask by filtering the pad token id
  1776. output_ids = output_ids[output_ids != generation_config._pad_token_tensor].reshape(
  1777. batch_size, self.decoder.num_codebooks, -1
  1778. )
  1779. # append the frame dimension back to the audio codes
  1780. output_ids = output_ids[None, ...]
  1781. audio_scales = model_kwargs.get("audio_scales")
  1782. if audio_scales is None:
  1783. audio_scales = [None] * batch_size
  1784. if self.decoder.config.audio_channels == 1:
  1785. output_values = self.audio_encoder.decode(
  1786. output_ids,
  1787. audio_scales=audio_scales,
  1788. ).audio_values
  1789. else:
  1790. codec_outputs_left = self.audio_encoder.decode(output_ids[:, :, ::2, :], audio_scales=audio_scales)
  1791. output_values_left = codec_outputs_left.audio_values
  1792. codec_outputs_right = self.audio_encoder.decode(output_ids[:, :, 1::2, :], audio_scales=audio_scales)
  1793. output_values_right = codec_outputs_right.audio_values
  1794. output_values = torch.cat([output_values_left, output_values_right], dim=1)
  1795. if generation_config.return_dict_in_generate:
  1796. outputs.sequences = output_values
  1797. return outputs
  1798. else:
  1799. return output_values
  def get_unconditional_inputs(self, num_samples=1):
      """
      Helper function to get null inputs for unconditional generation, enabling the model to be used without the
      feature extractor or tokenizer.

      Args:
          num_samples (`int`, *optional*, defaults to 1):
              Number of audio samples to unconditionally generate. This sets the leading (batch) dimension of the
              returned tensors.

      Returns:
          `MusicgenUnconditionalInput`: the null inputs, holding
              - `encoder_outputs`: a 1-tuple containing an all-zeros tensor of shape
                `(num_samples, 1, config.text_encoder.hidden_size)`, created on the model's device and dtype;
              - `attention_mask`: an all-zeros `torch.long` tensor of shape `(num_samples, 1)`;
              - `guidance_scale`: `1.0`.

      Note:
          This helper does not take a `max_new_tokens` argument. The length of the generated audio is controlled
          by passing `max_new_tokens` to `generate` instead. More tokens means longer audio samples, at the expense
          of longer inference (since more audio tokens need to be generated per sample).

      Example:
      ```python
      >>> from transformers import MusicgenForConditionalGeneration

      >>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")

      >>> # get the unconditional (or 'null') inputs for the model
      >>> unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
      >>> audio_samples = model.generate(**unconditional_inputs, max_new_tokens=256)
      ```"""
      # One all-zeros placeholder "text encoder" state per sample stands in for real text conditioning.
      last_hidden_state = torch.zeros(
          (num_samples, 1, self.config.text_encoder.hidden_size), device=self.device, dtype=self.dtype
      )
      # All-zeros attention mask over the single placeholder position
      # (presumably so the null state is masked out; confirm against the decoder's cross-attention).
      attention_mask = torch.zeros((num_samples, 1), device=self.device, dtype=torch.long)

      return MusicgenUnconditionalInput(
          # Passed as a plain tuple: `generate` wraps tuple encoder outputs in a `BaseModelOutput`.
          encoder_outputs=(last_hidden_state,),
          attention_mask=attention_mask,
          # `generate` only adds the batched classifier-free-guidance processor when guidance_scale > 1,
          # so 1.0 leaves CFG disabled for unconditional sampling.
          guidance_scale=1.0,
      )
  # Names exported by `from ... import *`; order is preserved from the original declaration.
  __all__ = [
      "MusicgenForConditionalGeneration",
      "MusicgenForCausalLM",
      "MusicgenModel",
      "MusicgenPreTrainedModel",
  ]