modular_csm.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756
  1. # Copyright 2025 Sesame and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from dataclasses import dataclass
  15. import torch
  16. import torch.nn as nn
  17. from ... import initialization as init
  18. from ...cache_utils import Cache, DynamicCache
  19. from ...generation import GenerationMixin
  20. from ...masking_utils import create_causal_mask
  21. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  22. from ...modeling_utils import PreTrainedModel
  23. from ...processing_utils import Unpack
  24. from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
  25. from ...utils.generic import merge_with_config_defaults
  26. from ...utils.import_utils import is_torchdynamo_compiling
  27. from ...utils.output_capturing import capture_outputs
  28. from ..auto import AutoModel
  29. from ..llama.modeling_llama import (
  30. LlamaAttention,
  31. LlamaDecoderLayer,
  32. LlamaForCausalLM,
  33. LlamaMLP,
  34. LlamaModel,
  35. LlamaRMSNorm,
  36. LlamaRotaryEmbedding,
  37. TransformersKwargs,
  38. )
  39. from .configuration_csm import CsmConfig, CsmDepthDecoderConfig
  40. from .generation_csm import CsmGenerationMixin
# Module-level logger; used for the position_ids / backbone_last_hidden_state warnings in the depth decoder.
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for the model autoregressive outputs.
    """
)
class CsmOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    depth_decoder_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction) of the depth decoder model.
    depth_decoder_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the depth decoder (scores for each vocabulary token before SoftMax).
    depth_decoder_past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
    depth_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
        one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

        Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    depth_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.
    backbone_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction) of the backbone model.
    """

    # Field order is part of the ModelOutput interface (tuple conversion / positional indexing); do not reorder.
    # --- overall loss and backbone outputs ---
    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    # --- depth decoder outputs ---
    depth_decoder_loss: torch.FloatTensor | None = None
    depth_decoder_logits: torch.FloatTensor | None = None
    depth_decoder_past_key_values: Cache | None = None
    depth_decoder_hidden_states: tuple[torch.FloatTensor, ...] | None = None
    depth_decoder_attentions: tuple[torch.FloatTensor, ...] | None = None
    # --- backbone-only loss term ---
    backbone_loss: torch.FloatTensor | None = None
# Re-declare the Llama building blocks under Csm-prefixed names so that the modular converter
# generates correctly named classes. They add no behavior of their own.
class CsmRMSNorm(LlamaRMSNorm):
    pass


class CsmRotaryEmbedding(LlamaRotaryEmbedding):
    pass


class CsmMLP(LlamaMLP):
    pass


class CsmAttention(LlamaAttention):
    pass


class CsmDecoderLayer(LlamaDecoderLayer):
    pass
  96. @auto_docstring(
  97. custom_intro="""
  98. The bare Csm Model outputting raw hidden-states without any specific head on top.
  99. """
  100. )
  101. @auto_docstring
  102. class CsmPreTrainedModel(PreTrainedModel):
  103. config: CsmConfig
  104. base_model_prefix = "model"
  105. input_modalities = ("audio", "text")
  106. supports_gradient_checkpointing = True
  107. _no_split_modules = ["CsmDecoderLayer"]
  108. _skip_keys_device_placement = ["past_key_values"]
  109. _supports_flash_attn = True
  110. _supports_sdpa = True
  111. # does not because of Mimi codec model
  112. # _supports_flex_attn = True
  113. _can_compile_fullgraph = True
  114. _supports_attention_backend = True
  115. _can_record_outputs = {
  116. "hidden_states": CsmDecoderLayer,
  117. "attentions": CsmAttention,
  118. }
  119. @torch.no_grad()
  120. def _init_weights(self, module):
  121. super()._init_weights(module)
  122. if isinstance(module, CsmCodebooksHead):
  123. num_codebooks = module.num_codebooks
  124. for i in range(num_codebooks - 1):
  125. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  126. elif isinstance(module, CsmBackboneModelEmbeddings):
  127. init.copy_(module.audio_tokens_offsets, torch.arange(self.config.num_codebooks) * self.config.vocab_size)
@auto_docstring
class CsmDepthDecoderModel(LlamaModel, CsmPreTrainedModel):
    # Llama-style decoder run over the codebook tokens of a single audio frame.
    # Input position 0 may be overwritten with the backbone's last hidden state; the token at position
    # p >= 1 is embedded from codebook (p - 1)'s slice of `embed_tokens`.
    config: CsmDepthDecoderConfig

    def __init__(self, config):
        super().__init__(config)
        # One table holding all codebook vocabularies back to back (codebook k -> rows starting at
        # k * vocab_size). Its width is the *backbone* hidden size so that the backbone hidden state
        # can be written directly into position 0 of `inputs_embeds`.
        self.embed_tokens = nn.Embedding((config.num_codebooks * config.vocab_size), config.backbone_hidden_size)
        # Projects backbone-width input embeddings down to the depth decoder's own hidden size.
        self.inputs_embeds_projector = nn.Linear(config.backbone_hidden_size, config.hidden_size, bias=False)

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        backbone_last_hidden_state: torch.FloatTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPast:
        r"""
        backbone_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, backbone_hidden_size)`, *optional*):
            The last hidden state of the backbone model. Such input is required when the first codebook token (the one generated by the backbone model)
            is provided in the `input_ids` argument.
        """
        # Positions are always derived from the cache length below, so caller-provided ids are discarded.
        if position_ids is not None and not is_torchdynamo_compiling():
            logger.warning_once(
                "Custom `position_ids` were provided but will be ignored. CSM depth decoder automatically determines position_ids "
                "and as it requires them to be identical across the batch, the provided position_ids will be ignored."
            )
            position_ids = None

        # XOR is True when both are None or both are given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds.")

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        # 1-D positions [past_seen_tokens, past_seen_tokens + seq_len), shared by the whole batch.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        inputs_seq_length = inputs_embeds.shape[1] if inputs_embeds is not None else input_ids.shape[1]
        device = inputs_embeds.device if inputs_embeds is not None else input_ids.device
        position_ids = torch.arange(past_seen_tokens, past_seen_tokens + inputs_seq_length, device=device)

        if inputs_embeds is None:
            # Position p uses codebook max(p - 1, 0); shift each token id into that codebook's rows.
            codebook_idxs = torch.clamp(position_ids - 1, min=0)
            offset = codebook_idxs * self.vocab_size
            inputs_embeds = self.embed_tokens(input_ids + offset)

            # True only when this chunk starts at the very beginning of the depth-decoder sequence.
            input_ids_are_first_codebook = position_ids[0] == 0
            if backbone_last_hidden_state is not None:
                # In-place: position 0's token embedding is replaced by the backbone hidden state.
                inputs_embeds[:, 0] = backbone_last_hidden_state
            else:
                if not is_torchdynamo_compiling() and input_ids_are_first_codebook:
                    logger.warning(
                        "When the first codebook token is provided, `backbone_last_hidden_state` should also be provided for correct inference."
                    )

        # Applied to both freshly embedded ids and caller-provided (backbone-width) inputs_embeds.
        inputs_embeds = self.inputs_embeds_projector(inputs_embeds)

        causal_mask = create_causal_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        # (add a leading batch dim of 1: positions are identical across the batch)
        position_ids = position_ids.unsqueeze(0)
        position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )
  207. class CsmCodebooksHead(nn.Module):
  208. def __init__(self, hidden_size, num_codebooks, vocab_size):
  209. super().__init__()
  210. self.num_codebooks = num_codebooks
  211. self.weight = nn.Parameter(torch.empty(self.num_codebooks - 1, hidden_size, vocab_size))
  212. def forward(self, hidden_states, codebook_indices=None):
  213. # -1 because of the concatenated backbone last hidden state
  214. codebook_indices = codebook_indices - 1
  215. codebook_weight = self.weight[codebook_indices]
  216. hidden_states = [
  217. nn.functional.linear(hidden_states[:, codebook_idx, :], codebook_weight[codebook_idx].T)
  218. for codebook_idx in range(codebook_weight.shape[0])
  219. ]
  220. hidden_states = torch.stack(hidden_states, dim=1)
  221. return hidden_states
  222. @auto_docstring(
  223. custom_intro="""
  224. The CsmDepthDecoder Model transformer, with a [`CsmCodebooksHead`] on top,
  225. which can be seen a position-specific language modeling head, allowing to use a different linear layer for each codebook
  226. (e.g. position 0 is the first codebook and uses the first codebook head, etc.)
  227. """
  228. )
  229. class CsmDepthDecoderForCausalLM(LlamaForCausalLM, GenerationMixin):
  230. _tied_weights_keys = None
  231. _tp_plan = None
  232. _pp_plan = None
  233. def __init__(self, config):
  234. super().__init__(config)
  235. del self.lm_head
  236. self.codebooks_head = CsmCodebooksHead(config.hidden_size, config.num_codebooks, config.vocab_size)
  237. self.model = CsmDepthDecoderModel(config)
  238. def prepare_inputs_for_generation(
  239. self,
  240. input_ids: torch.LongTensor,
  241. next_sequence_length: int | None = None,
  242. past_key_values: Cache | None = None,
  243. attention_mask: torch.LongTensor | None = None,
  244. inputs_embeds: torch.FloatTensor | None = None,
  245. is_first_iteration: bool | None = False,
  246. **kwargs,
  247. ):
  248. model_inputs = super().prepare_inputs_for_generation(
  249. input_ids, next_sequence_length, past_key_values, attention_mask, inputs_embeds, **kwargs
  250. )
  251. if not is_first_iteration:
  252. model_inputs.pop("backbone_last_hidden_state")
  253. # csm depth decoder does not use position_ids
  254. model_inputs.pop("position_ids")
  255. return model_inputs
  256. @can_return_tuple
  257. @auto_docstring
  258. def forward(
  259. self,
  260. input_ids: torch.LongTensor | None = None,
  261. backbone_last_hidden_state: torch.FloatTensor | None = None,
  262. attention_mask: torch.Tensor | None = None,
  263. position_ids: torch.LongTensor | None = None,
  264. past_key_values: Cache | None = None,
  265. inputs_embeds: torch.FloatTensor | None = None,
  266. labels: torch.LongTensor | None = None,
  267. use_cache: bool | None = None,
  268. logits_to_keep: int | torch.Tensor = 0,
  269. **kwargs: Unpack[TransformersKwargs],
  270. ) -> tuple | CausalLMOutputWithPast:
  271. r"""
  272. backbone_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, backbone_hidden_size)`, *optional*):
  273. The last hidden state of the backbone model. Such input is required when the first codebook token (the one generated by the backbone model)
  274. is provided in the `input_ids` argument.
  275. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  276. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  277. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  278. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  279. """
  280. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  281. seq_len = inputs_embeds.shape[1] if inputs_embeds is not None else input_ids.shape[1]
  282. device = inputs_embeds.device if inputs_embeds is not None else input_ids.device
  283. codebook_indices = torch.arange(seq_len, device=device) + past_seen_tokens
  284. outputs = self.model(
  285. input_ids=input_ids,
  286. backbone_last_hidden_state=backbone_last_hidden_state,
  287. attention_mask=attention_mask,
  288. position_ids=position_ids,
  289. past_key_values=past_key_values,
  290. inputs_embeds=inputs_embeds,
  291. use_cache=use_cache,
  292. **kwargs,
  293. )
  294. hidden_states = outputs[0]
  295. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  296. if isinstance(logits_to_keep, int):
  297. if logits_to_keep == 0:
  298. # skip idx 0 logits since it's for the concatenated backbone last hidden state
  299. slice_indices = slice(1, None)
  300. else:
  301. slice_indices = slice(-logits_to_keep, None)
  302. else:
  303. slice_indices = logits_to_keep
  304. logits = self.codebooks_head(hidden_states[:, slice_indices, :], codebook_indices[slice_indices])
  305. logits = logits.contiguous()
  306. loss = None
  307. if labels is not None:
  308. shift_labels = labels[..., 1:].contiguous()
  309. loss = self.loss_function(
  310. logits=logits, labels=None, vocab_size=self.config.vocab_size, shift_labels=shift_labels, **kwargs
  311. )
  312. return CausalLMOutputWithPast(
  313. loss=loss,
  314. logits=logits,
  315. past_key_values=outputs.past_key_values,
  316. hidden_states=outputs.hidden_states,
  317. attentions=outputs.attentions,
  318. )
  319. class CsmBackboneModelEmbeddings(nn.Module):
  320. def __init__(self, config):
  321. super().__init__()
  322. self.embed_audio_tokens = nn.Embedding((config.num_codebooks * config.codebook_size), config.hidden_size)
  323. self.register_buffer(
  324. "audio_tokens_offsets", torch.arange(config.num_codebooks) * config.codebook_size, persistent=False
  325. )
  326. def forward(self, input_ids):
  327. inputs_embeds = self.embed_audio_tokens(input_ids + self.audio_tokens_offsets)
  328. inputs_embeds = inputs_embeds.sum(dim=2)
  329. return inputs_embeds
@auto_docstring
class CsmBackboneModel(LlamaModel):
    # Llama decoder whose token embedding is replaced by CsmBackboneModelEmbeddings, which sums the
    # embeddings of every codebook token of a frame into one vector.
    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = CsmBackboneModelEmbeddings(config)

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(self, **super_kwargs):
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks) or (batch_size, sequence_length)`):
            1. (batch_size, sequence_length): corresponds to the input sequence prepared with the processor from the text prompt. Such input
            requires `input_values` to be provided so that audio can be encoded in codebook tokens and then merged with the text tokens.

            2. (batch_size, sequence_length, num_codebooks): codebook tokens generated during the autoregressive decoding. Such input is not meant to be used by end users.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        """
        # NOTE(review): in modular transformers files this exact `**super_kwargs` pass-through is presumably the
        # converter idiom for "reuse the parent's forward body, override only the docstring" — keep it verbatim.
        return super().forward(**super_kwargs)
  349. @auto_docstring(
  350. custom_intro="""
  351. The Csm model consists of two llama-like auto-regressive transformer models: a backbone model that predicts the first codebook token and a depth decoder that predicts the other codebook tokens.
  352. """
  353. )
  354. class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
  355. _tied_weights_keys = {
  356. "backbone_model.embed_tokens.embed_audio_tokens.weight": "depth_decoder.model.embed_tokens.weight"
  357. }
  358. def __init__(self, config):
  359. super().__init__(config)
  360. self.vocab_size = config.vocab_size
  361. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  362. self.embed_text_tokens = nn.Embedding(config.text_vocab_size, config.hidden_size)
  363. self.backbone_model = CsmBackboneModel._from_config(config)
  364. self.depth_decoder = CsmDepthDecoderForCausalLM._from_config(config.depth_decoder_config)
  365. self.codec_model = AutoModel.from_config(config.codec_config)
  366. self.post_init()
  367. def get_input_embeddings(self):
  368. return self.backbone_model.embed_tokens
  369. def set_input_embeddings(self, value):
  370. self.backbone_model.embed_tokens = value
  371. @classmethod
  372. def from_pretrained(cls, *args, **kwargs):
  373. if kwargs.get("output_loading_info", False):
  374. model, loading_info = super().from_pretrained(*args, **kwargs)
  375. else:
  376. model = super().from_pretrained(*args, **kwargs)
  377. # copy depth decoder generation conf attr to the depth decoder generation config
  378. prefix = "depth_decoder_"
  379. prefix_len = len(prefix)
  380. depth_decoder_attrs = {
  381. attr[prefix_len:]: value
  382. for attr, value in vars(model.generation_config).items()
  383. if attr.startswith(prefix)
  384. }
  385. vars(model.depth_decoder.generation_config).update({"_from_model_config": False, **depth_decoder_attrs})
  386. # remove the depth decoder generation conf attr from the model generation config
  387. for attr in depth_decoder_attrs:
  388. delattr(model.generation_config, prefix + attr)
  389. if "output_loading_info" in kwargs:
  390. return model, loading_info
  391. else:
  392. return model
  393. def save_pretrained(self, *args, **kwargs):
  394. # copy the depth decoder generation config attributes to the model generation config
  395. prefix = "depth_decoder_"
  396. depth_decoder_attrs = self.depth_decoder.generation_config.to_diff_dict()
  397. depth_decoder_attrs.pop("transformers_version", None)
  398. for attr, value in depth_decoder_attrs.items():
  399. setattr(self.generation_config, prefix + attr, value)
  400. super().save_pretrained(*args, **kwargs)
  401. def _merge_input_ids_with_input_values(
  402. self,
  403. input_ids: torch.Tensor | None = None,
  404. input_values: torch.Tensor | None = None,
  405. input_values_cutoffs: torch.Tensor | None = None,
  406. labels: torch.Tensor | None = None,
  407. ) -> torch.Tensor | None:
  408. """
  409. Merges the input_ids and input_values to produce a single inputs_embeds tensor:
  410. 1 - Infers the codec model on the input_values to retrieve codebook token.
  411. 2 - Embeds codebook tokens and places them at the correct positions in the inputs_embeds tensor.
  412. 3 - If labels are provided, expands them to match codebook dimensions and position the target codebook tokens in the inputs_embeds tensor.
  413. Args:
  414. input_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`):
  415. The input ids to embed.
  416. input_values (`torch.Tensor` of shape `(batch_size, channels, audio_sequence_length)`):
  417. The audio input values to embed.
  418. input_values_cutoffs (`torch.Tensor` of shape `(batch_size, max_num_audio)`):
  419. The cutoffs of the audio input values relative to its batch index, padded with -1 when no audio.
  420. """
  421. inputs_embeds = self.embed_text_tokens(input_ids)
  422. if input_values is not None:
  423. # infer input_values_mask
  424. input_values_cutoffs = nn.functional.pad(input_values_cutoffs, (1, 0))
  425. audio_lengths = input_values_cutoffs[input_values_cutoffs >= 0].diff()
  426. audio_lengths = audio_lengths[audio_lengths > 0]
  427. input_values_mask = torch.arange(input_values_cutoffs.max(), device=input_values.device).expand(
  428. len(audio_lengths), -1
  429. )
  430. input_values_mask = input_values_mask < audio_lengths.unsqueeze(1)
  431. # =======================================
  432. # TODO: @eustlb, this should be batched !!!
  433. # but requires making sure batched inference of the codec model works as intended
  434. with torch.no_grad():
  435. audio_tokens_list = []
  436. for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
  437. batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
  438. for i in range(batch_input_values_cutoffs.shape[0] - 1):
  439. start_idx = batch_input_values_cutoffs[i]
  440. end_idx = batch_input_values_cutoffs[i + 1]
  441. audio_batch = batch_input_values[..., start_idx:end_idx]
  442. codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0))
  443. codebook_ids = codec_outputs.audio_codes.transpose(1, -1)
  444. audio_tokens_list.append(codebook_ids[0])
  445. max_audio_frames = max(el.shape[0] for el in audio_tokens_list)
  446. batched_audio_token_ids = torch.stack(
  447. [nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list]
  448. )
  449. audio_codes_mask = self.codec_model.get_audio_codes_mask(input_values_mask)
  450. # =======================================
  451. audio_token_id = self.config.audio_token_id
  452. audio_token_mask = input_ids == audio_token_id
  453. audio_embeds = self.backbone_model.embed_tokens(batched_audio_token_ids)
  454. inputs_embeds[audio_token_mask] = audio_embeds[audio_codes_mask]
  455. # same for the audio eos token
  456. audio_eos_frame_ids = (
  457. torch.ones((1, 1, self.config.num_codebooks), device=input_ids.device, dtype=torch.long)
  458. * self.config.codebook_eos_token_id
  459. )
  460. audio_eos_embeds = self.backbone_model.embed_tokens(audio_eos_frame_ids).squeeze(1)
  461. audio_eos_token_mask = input_ids == self.config.audio_eos_token_id
  462. inputs_embeds[audio_eos_token_mask] = audio_eos_embeds.repeat(audio_eos_token_mask.sum(), 1)
  463. # if the labels are provided, we need to expand the labels to (batch_size, seq_length, num_codebooks)
  464. if labels is not None:
  465. labels_expanded = labels.unsqueeze(-1).repeat(1, 1, self.config.num_codebooks)
  466. labels_expanded[audio_token_mask] = batched_audio_token_ids[audio_codes_mask]
  467. labels_expanded[audio_eos_token_mask] = audio_eos_frame_ids
  468. # mask depth decoder
  469. depth_decoder_ignore_frames_idxs = (labels == -101).nonzero(as_tuple=True)
  470. labels_expanded[depth_decoder_ignore_frames_idxs[0], depth_decoder_ignore_frames_idxs[1], 1:] = -100
  471. labels = labels_expanded
  472. return {"inputs_embeds": inputs_embeds, "labels": labels}
  473. def prepare_inputs_for_generation(
  474. self,
  475. input_ids: torch.LongTensor,
  476. next_sequence_length: int | None = None,
  477. past_key_values: Cache | None = None,
  478. attention_mask: torch.LongTensor | None = None,
  479. inputs_embeds: torch.FloatTensor | None = None,
  480. **kwargs,
  481. ):
  482. model_inputs = super().prepare_inputs_for_generation(
  483. input_ids=input_ids,
  484. next_sequence_length=next_sequence_length,
  485. past_key_values=past_key_values,
  486. attention_mask=attention_mask,
  487. inputs_embeds=inputs_embeds,
  488. **kwargs,
  489. )
  490. if input_ids is not None and input_ids.ndim == 2 and model_inputs.get("inputs_embeds") is None:
  491. merged_inputs = self._merge_input_ids_with_input_values(
  492. input_ids=input_ids,
  493. input_values=kwargs.get("input_values"),
  494. input_values_cutoffs=kwargs.get("input_values_cutoffs"),
  495. labels=kwargs.get("labels"),
  496. )
  497. model_inputs.update(
  498. {"inputs_embeds": merged_inputs["inputs_embeds"], "labels": merged_inputs["labels"], "input_ids": None}
  499. )
  500. return model_inputs
  501. @can_return_tuple
  502. @auto_docstring
  503. def forward(
  504. self,
  505. input_ids: torch.LongTensor | None = None,
  506. input_values: torch.Tensor | None = None,
  507. attention_mask: torch.Tensor | None = None,
  508. input_values_cutoffs: torch.Tensor | None = None,
  509. position_ids: torch.LongTensor | None = None,
  510. past_key_values: Cache | None = None,
  511. inputs_embeds: torch.FloatTensor | None = None,
  512. labels: torch.LongTensor | None = None,
  513. use_cache: bool | None = None,
  514. logits_to_keep: int | torch.Tensor = 0,
  515. **kwargs: Unpack[TransformersKwargs],
  516. ) -> tuple | CsmOutputWithPast:
  517. r"""
  518. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks) or (batch_size, sequence_length)`):
  519. 1. (batch_size, sequence_length): corresponds to the input sequence prepared with the processor from the text prompt. Such input
  520. requires `input_values` to be provided so that audio can be encoded in codebook tokens and then merged with the text tokens.
  521. 2. (batch_size, sequence_length, num_codebooks): codebook tokens generated during the autoregressive decoding. Such input is not meant to be used by end users.
  522. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  523. [`PreTrainedTokenizer.__call__`] for details.
  524. [What are input IDs?](../glossary#input-ids)
  525. input_values_cutoffs (`torch.Tensor` of shape `(batch_size, max_num_audio)`, *optional*):
  526. Specify the end positions of audio segments within each batch entry, relative to the concatenated audio input.
  527. If a batch entry has fewer segments than the maximum, it is padded with -1. For example, in a batch of 2 sequences
  528. where the first contains 2 audio segments of length l1, and the second contains 1 audio segment of length l2,
  529. the input_values_cutoffs would be: [[l1, 2 * l1], [l2, -1]].
  530. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  531. Labels for computing the masked language modeling loss. Indices should be in `[config.audio_token_id, -100, -101]`.
  532. Requires targeted `input_values` to be provided as audio tokens will be inferred from it using the `codec_model`.
  533. - `config.audio_token_id` indicates an audio frames (considering sequence length elements as frames)
  534. - `-100` will be ignored in the loss computation
  535. - `-101` indicates the audio frame will be used only for the backbone model (using the first codebook token as labels)
  536. Such labels can be prepared using `output_labels=True` when calling [`CsmProcessor`].
  537. logits_to_keep (`int` or `torch.Tensor`, *optional*):
  538. Kept for compatibility. Does not support another value than:
  539. 1. `0`, which is equivalent to keeping all logits, used in the training regime
  540. 2. `1`, which is equivalent to keeping only the last logit, used in the generation regime
  541. Example:
  542. ```python
  543. >>> import torch
  544. >>> from transformers import CsmForConditionalGeneration, AutoProcessor
  545. >>> from datasets import load_dataset, Audio
  546. >>> model_id = "sesame/csm-1b"
  547. >>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"
  548. >>> processor = AutoProcessor.from_pretrained(model_id)
  549. >>> ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
  550. >>> # ensure the audio is 24kHz
  551. >>> ds = ds.cast_column("audio", Audio(sampling_rate=24000))
  552. >>> conversation = []
  553. >>> # prepare a conversation with text and corresponding audio
  554. >>> for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
  555. ... conversation.append(
  556. ... {
  557. ... "role": f"{speaker_id}",
  558. ... "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
  559. ... }
  560. ... )
  561. >>> inputs = processor.apply_chat_template(
  562. ... conversation,
  563. ... tokenize=True,
  564. ... return_dict=True,
  565. ... output_labels=True,
  566. ... ).to(torch_device)
  567. >>> model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
  568. >>> output = model(**inputs)
  569. >>> output.loss.backward()
  570. ```"""
  571. if input_ids is not None and input_ids.ndim == 2:
  572. merged_inputs = self._merge_input_ids_with_input_values(
  573. input_ids, input_values, input_values_cutoffs, labels
  574. )
  575. inputs_embeds = merged_inputs["inputs_embeds"]
  576. labels = merged_inputs["labels"]
  577. input_ids = None
  578. backbone_outputs = self.backbone_model(
  579. input_ids=input_ids,
  580. attention_mask=attention_mask,
  581. position_ids=position_ids,
  582. past_key_values=past_key_values,
  583. inputs_embeds=inputs_embeds,
  584. use_cache=use_cache,
  585. **kwargs,
  586. )
  587. backbone_hidden_states = backbone_outputs[0]
  588. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  589. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  590. backbone_logits = self.lm_head(backbone_hidden_states[:, slice_indices, :])
  591. loss = None
  592. backbone_loss = None
  593. depth_decoder_loss = None
  594. depth_decoder_outputs = None
  595. if labels is not None:
  596. # select first codebook as labels for the backbone model
  597. backbone_labels = labels[:, :, 0]
  598. backbone_loss = self.loss_function(
  599. logits=backbone_logits, labels=backbone_labels, vocab_size=self.config.vocab_size, **kwargs
  600. )
  601. # for the depth decoder, we need to select the frames to train on
  602. # those are frames where the label is not uniformly `ignore_index` along the codebook dimension
  603. train_mask = ~(labels[:, :, 1:] == -100).all(dim=-1)
  604. depth_decoder_input_ids = labels[train_mask][..., : self.config.num_codebooks - 1]
  605. # add place holder in position 0 that will be replaced by the backbone_last_hidden_state
  606. depth_decoder_input_ids = nn.functional.pad(depth_decoder_input_ids, (1, 0), value=0)
  607. train_idxs = train_mask.nonzero(as_tuple=True)
  608. backbone_last_hidden_states = backbone_hidden_states[train_idxs[0], train_idxs[1] - 1, :]
  609. depth_decoder_labels = labels[train_mask]
  610. depth_decoder_outputs = self.depth_decoder(
  611. input_ids=depth_decoder_input_ids,
  612. backbone_last_hidden_state=backbone_last_hidden_states,
  613. use_cache=use_cache,
  614. return_dict=True,
  615. labels=depth_decoder_labels,
  616. **kwargs,
  617. )
  618. depth_decoder_loss = depth_decoder_outputs.loss
  619. loss = backbone_loss + depth_decoder_loss
  620. return CsmOutputWithPast(
  621. loss=loss,
  622. backbone_loss=backbone_loss,
  623. depth_decoder_loss=depth_decoder_loss,
  624. logits=backbone_logits,
  625. past_key_values=backbone_outputs.past_key_values,
  626. hidden_states=backbone_outputs.hidden_states,
  627. attentions=backbone_outputs.attentions,
  628. depth_decoder_logits=depth_decoder_outputs.logits if depth_decoder_outputs is not None else None,
  629. depth_decoder_past_key_values=depth_decoder_outputs.past_key_values
  630. if depth_decoder_outputs is not None
  631. else None,
  632. depth_decoder_hidden_states=depth_decoder_outputs.hidden_states
  633. if depth_decoder_outputs is not None
  634. else None,
  635. depth_decoder_attentions=depth_decoder_outputs.attentions if depth_decoder_outputs is not None else None,
  636. )
  637. __all__ = [
  638. "CsmPreTrainedModel",
  639. "CsmBackboneModel",
  640. "CsmDepthDecoderModel",
  641. "CsmDepthDecoderForCausalLM",
  642. "CsmForConditionalGeneration",
  643. ]