modeling_csm.py 50 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/csm/modular_csm.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_csm.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 Sesame and The HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. from collections.abc import Callable
  21. from dataclasses import dataclass
  22. from typing import Optional
  23. import torch
  24. import torch.nn as nn
  25. from ... import initialization as init
  26. from ...activations import ACT2FN
  27. from ...cache_utils import Cache, DynamicCache
  28. from ...generation import GenerationMixin
  29. from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
  30. from ...masking_utils import create_causal_mask
  31. from ...modeling_layers import GradientCheckpointingLayer
  32. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  33. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  34. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  35. from ...processing_utils import Unpack
  36. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging
  37. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  38. from ...utils.import_utils import is_torchdynamo_compiling
  39. from ...utils.output_capturing import capture_outputs
  40. from ..auto import AutoModel
  41. from .configuration_csm import CsmConfig, CsmDepthDecoderConfig
  42. from .generation_csm import CsmGenerationMixin
  43. logger = logging.get_logger(__name__)
# Output container returned by the full CSM model: it carries the backbone results
# (`loss`, `logits`, cache, hidden states, attentions) next to the depth-decoder results
# and the individual backbone loss.
@dataclass
@auto_docstring(
    custom_intro="""
Base class for the model autoregressive outputs.
"""
)
class CsmOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    depth_decoder_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction) of the depth decoder model.
    depth_decoder_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the depth decoder (scores for each vocabulary token before SoftMax).
    depth_decoder_past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
    depth_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
        one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

        Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    depth_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.
    backbone_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction) of the backbone model.
    """

    # Field order is part of the interface (tuple conversion / positional access), do not reorder.
    # NOTE(review): `hidden_states` and `attentions` have no entry in the docstring above;
    # presumably `auto_docstring` supplies the standard descriptions - confirm.
    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    # Depth-decoder counterparts of the fields above.
    depth_decoder_loss: torch.FloatTensor | None = None
    depth_decoder_logits: torch.FloatTensor | None = None
    depth_decoder_past_key_values: Cache | None = None
    depth_decoder_hidden_states: tuple[torch.FloatTensor, ...] | None = None
    depth_decoder_attentions: tuple[torch.FloatTensor, ...] | None = None
    backbone_loss: torch.FloatTensor | None = None
  87. @use_kernel_forward_from_hub("RMSNorm")
  88. class CsmRMSNorm(nn.Module):
  89. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  90. """
  91. CsmRMSNorm is equivalent to T5LayerNorm
  92. """
  93. super().__init__()
  94. self.weight = nn.Parameter(torch.ones(hidden_size))
  95. self.variance_epsilon = eps
  96. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  97. input_dtype = hidden_states.dtype
  98. hidden_states = hidden_states.to(torch.float32)
  99. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  100. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  101. return self.weight * hidden_states.to(input_dtype)
  102. def extra_repr(self):
  103. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  104. class CsmRotaryEmbedding(nn.Module):
  105. inv_freq: torch.Tensor # fix linting for `register_buffer`
  106. def __init__(self, config: CsmConfig, device=None):
  107. super().__init__()
  108. self.max_seq_len_cached = config.max_position_embeddings
  109. self.original_max_seq_len = config.max_position_embeddings
  110. self.config = config
  111. self.rope_type = self.config.rope_parameters["rope_type"]
  112. rope_init_fn: Callable = self.compute_default_rope_parameters
  113. if self.rope_type != "default":
  114. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  115. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  116. self.register_buffer("inv_freq", inv_freq, persistent=False)
  117. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  118. @staticmethod
  119. def compute_default_rope_parameters(
  120. config: CsmConfig | None = None,
  121. device: Optional["torch.device"] = None,
  122. seq_len: int | None = None,
  123. ) -> tuple["torch.Tensor", float]:
  124. """
  125. Computes the inverse frequencies according to the original RoPE implementation
  126. Args:
  127. config ([`~transformers.PreTrainedConfig`]):
  128. The model configuration.
  129. device (`torch.device`):
  130. The device to use for initialization of the inverse frequencies.
  131. seq_len (`int`, *optional*):
  132. The current sequence length. Unused for this type of RoPE.
  133. Returns:
  134. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  135. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  136. """
  137. base = config.rope_parameters["rope_theta"]
  138. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  139. attention_factor = 1.0 # Unused in this type of RoPE
  140. # Compute the inverse frequencies
  141. inv_freq = 1.0 / (
  142. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  143. )
  144. return inv_freq, attention_factor
  145. @torch.no_grad()
  146. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  147. def forward(self, x, position_ids):
  148. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  149. position_ids_expanded = position_ids[:, None, :].float()
  150. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  151. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  152. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  153. emb = torch.cat((freqs, freqs), dim=-1)
  154. cos = emb.cos() * self.attention_scaling
  155. sin = emb.sin() * self.attention_scaling
  156. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  157. class CsmMLP(nn.Module):
  158. def __init__(self, config):
  159. super().__init__()
  160. self.config = config
  161. self.hidden_size = config.hidden_size
  162. self.intermediate_size = config.intermediate_size
  163. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  164. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  165. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
  166. self.act_fn = ACT2FN[config.hidden_act]
  167. def forward(self, x):
  168. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  169. return down_proj
  170. def rotate_half(x):
  171. """Rotates half the hidden dims of the input."""
  172. x1 = x[..., : x.shape[-1] // 2]
  173. x2 = x[..., x.shape[-1] // 2 :]
  174. return torch.cat((-x2, x1), dim=-1)
  175. @use_kernel_func_from_hub("rotary_pos_emb")
  176. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  177. """Applies Rotary Position Embedding to the query and key tensors.
  178. Args:
  179. q (`torch.Tensor`): The query tensor.
  180. k (`torch.Tensor`): The key tensor.
  181. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  182. sin (`torch.Tensor`): The sine part of the rotary embedding.
  183. unsqueeze_dim (`int`, *optional*, defaults to 1):
  184. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  185. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  186. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  187. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  188. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  189. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  190. Returns:
  191. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  192. """
  193. cos = cos.unsqueeze(unsqueeze_dim)
  194. sin = sin.unsqueeze(unsqueeze_dim)
  195. q_embed = (q * cos) + (rotate_half(q) * sin)
  196. k_embed = (k * cos) + (rotate_half(k) * sin)
  197. return q_embed, k_embed
  198. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  199. """
  200. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  201. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  202. """
  203. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  204. if n_rep == 1:
  205. return hidden_states
  206. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  207. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  208. def eager_attention_forward(
  209. module: nn.Module,
  210. query: torch.Tensor,
  211. key: torch.Tensor,
  212. value: torch.Tensor,
  213. attention_mask: torch.Tensor | None,
  214. scaling: float,
  215. dropout: float = 0.0,
  216. **kwargs: Unpack[TransformersKwargs],
  217. ):
  218. key_states = repeat_kv(key, module.num_key_value_groups)
  219. value_states = repeat_kv(value, module.num_key_value_groups)
  220. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  221. if attention_mask is not None:
  222. attn_weights = attn_weights + attention_mask
  223. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  224. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  225. attn_output = torch.matmul(attn_weights, value_states)
  226. attn_output = attn_output.transpose(1, 2).contiguous()
  227. return attn_output, attn_weights
  228. @use_kernelized_func(apply_rotary_pos_emb)
  229. class CsmAttention(nn.Module):
  230. """Multi-headed attention from 'Attention Is All You Need' paper"""
  231. def __init__(self, config: CsmConfig, layer_idx: int):
  232. super().__init__()
  233. self.config = config
  234. self.layer_idx = layer_idx
  235. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  236. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  237. self.scaling = self.head_dim**-0.5
  238. self.attention_dropout = config.attention_dropout
  239. self.is_causal = True
  240. self.q_proj = nn.Linear(
  241. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  242. )
  243. self.k_proj = nn.Linear(
  244. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  245. )
  246. self.v_proj = nn.Linear(
  247. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  248. )
  249. self.o_proj = nn.Linear(
  250. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  251. )
  252. def forward(
  253. self,
  254. hidden_states: torch.Tensor,
  255. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  256. attention_mask: torch.Tensor | None = None,
  257. past_key_values: Cache | None = None,
  258. **kwargs: Unpack[TransformersKwargs],
  259. ) -> tuple[torch.Tensor, torch.Tensor]:
  260. input_shape = hidden_states.shape[:-1]
  261. hidden_shape = (*input_shape, -1, self.head_dim)
  262. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  263. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  264. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  265. cos, sin = position_embeddings
  266. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  267. if past_key_values is not None:
  268. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  269. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  270. self.config._attn_implementation, eager_attention_forward
  271. )
  272. attn_output, attn_weights = attention_interface(
  273. self,
  274. query_states,
  275. key_states,
  276. value_states,
  277. attention_mask,
  278. dropout=0.0 if not self.training else self.attention_dropout,
  279. scaling=self.scaling,
  280. **kwargs,
  281. )
  282. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  283. attn_output = self.o_proj(attn_output)
  284. return attn_output, attn_weights
  285. class CsmDecoderLayer(GradientCheckpointingLayer):
  286. def __init__(self, config: CsmConfig, layer_idx: int):
  287. super().__init__()
  288. self.hidden_size = config.hidden_size
  289. self.self_attn = CsmAttention(config=config, layer_idx=layer_idx)
  290. self.mlp = CsmMLP(config)
  291. self.input_layernorm = CsmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  292. self.post_attention_layernorm = CsmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  293. def forward(
  294. self,
  295. hidden_states: torch.Tensor,
  296. attention_mask: torch.Tensor | None = None,
  297. position_ids: torch.LongTensor | None = None,
  298. past_key_values: Cache | None = None,
  299. use_cache: bool | None = False,
  300. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  301. **kwargs: Unpack[TransformersKwargs],
  302. ) -> torch.Tensor:
  303. residual = hidden_states
  304. hidden_states = self.input_layernorm(hidden_states)
  305. # Self Attention
  306. hidden_states, _ = self.self_attn(
  307. hidden_states=hidden_states,
  308. attention_mask=attention_mask,
  309. position_ids=position_ids,
  310. past_key_values=past_key_values,
  311. use_cache=use_cache,
  312. position_embeddings=position_embeddings,
  313. **kwargs,
  314. )
  315. hidden_states = residual + hidden_states
  316. # Fully Connected
  317. residual = hidden_states
  318. hidden_states = self.post_attention_layernorm(hidden_states)
  319. hidden_states = self.mlp(hidden_states)
  320. hidden_states = residual + hidden_states
  321. return hidden_states
  322. @auto_docstring(
  323. custom_intro="""
  324. The bare Csm Model outputting raw hidden-states without any specific head on top.
  325. """
  326. )
  327. @auto_docstring
  328. class CsmPreTrainedModel(PreTrainedModel):
  329. config: CsmConfig
  330. base_model_prefix = "model"
  331. input_modalities = ("audio", "text")
  332. supports_gradient_checkpointing = True
  333. _no_split_modules = ["CsmDecoderLayer"]
  334. _skip_keys_device_placement = ["past_key_values"]
  335. _supports_flash_attn = True
  336. _supports_sdpa = True
  337. # does not because of Mimi codec model
  338. # _supports_flex_attn = True
  339. _can_compile_fullgraph = True
  340. _supports_attention_backend = True
  341. _can_record_outputs = {
  342. "hidden_states": CsmDecoderLayer,
  343. "attentions": CsmAttention,
  344. }
  345. @torch.no_grad()
  346. def _init_weights(self, module):
  347. super()._init_weights(module)
  348. if isinstance(module, CsmCodebooksHead):
  349. num_codebooks = module.num_codebooks
  350. for i in range(num_codebooks - 1):
  351. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  352. elif isinstance(module, CsmBackboneModelEmbeddings):
  353. init.copy_(module.audio_tokens_offsets, torch.arange(self.config.num_codebooks) * self.config.vocab_size)
  354. @auto_docstring
  355. class CsmDepthDecoderModel(CsmPreTrainedModel):
  356. config: CsmDepthDecoderConfig
  357. def __init__(self, config):
  358. super().__init__(config)
  359. self.padding_idx = config.pad_token_id
  360. self.vocab_size = config.vocab_size
  361. self.embed_tokens = nn.Embedding((config.num_codebooks * config.vocab_size), config.backbone_hidden_size)
  362. self.layers = nn.ModuleList(
  363. [CsmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  364. )
  365. self.norm = CsmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  366. self.rotary_emb = CsmRotaryEmbedding(config=config)
  367. self.gradient_checkpointing = False
  368. self.inputs_embeds_projector = nn.Linear(config.backbone_hidden_size, config.hidden_size, bias=False)
  369. # Initialize weights and apply final processing
  370. self.post_init()
  371. @merge_with_config_defaults
  372. @capture_outputs
  373. @auto_docstring
  374. def forward(
  375. self,
  376. input_ids: torch.LongTensor | None = None,
  377. backbone_last_hidden_state: torch.FloatTensor | None = None,
  378. attention_mask: torch.Tensor | None = None,
  379. position_ids: torch.LongTensor | None = None,
  380. past_key_values: Cache | None = None,
  381. inputs_embeds: torch.FloatTensor | None = None,
  382. use_cache: bool | None = None,
  383. **kwargs: Unpack[TransformersKwargs],
  384. ) -> tuple | BaseModelOutputWithPast:
  385. r"""
  386. backbone_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, backbone_hidden_size)`, *optional*):
  387. The last hidden state of the backbone model. Such input is required when the first codebook token (the one generated by the backbone model)
  388. is provided in the `input_ids` argument.
  389. """
  390. if position_ids is not None and not is_torchdynamo_compiling():
  391. logger.warning_once(
  392. "Custom `position_ids` were provided but will be ignored. CSM depth decoder automatically determines position_ids "
  393. "and as it requires them to be identical across the batch, the provided position_ids will be ignored."
  394. )
  395. position_ids = None
  396. if (input_ids is None) ^ (inputs_embeds is not None):
  397. raise ValueError("You must specify exactly one of input_ids or inputs_embeds.")
  398. if use_cache and past_key_values is None:
  399. past_key_values = DynamicCache(config=self.config)
  400. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  401. inputs_seq_length = inputs_embeds.shape[1] if inputs_embeds is not None else input_ids.shape[1]
  402. device = inputs_embeds.device if inputs_embeds is not None else input_ids.device
  403. position_ids = torch.arange(past_seen_tokens, past_seen_tokens + inputs_seq_length, device=device)
  404. if inputs_embeds is None:
  405. codebook_idxs = torch.clamp(position_ids - 1, min=0)
  406. offset = codebook_idxs * self.vocab_size
  407. inputs_embeds = self.embed_tokens(input_ids + offset)
  408. input_ids_are_first_codebook = position_ids[0] == 0
  409. if backbone_last_hidden_state is not None:
  410. inputs_embeds[:, 0] = backbone_last_hidden_state
  411. else:
  412. if not is_torchdynamo_compiling() and input_ids_are_first_codebook:
  413. logger.warning(
  414. "When the first codebook token is provided, `backbone_last_hidden_state` should also be provided for correct inference."
  415. )
  416. inputs_embeds = self.inputs_embeds_projector(inputs_embeds)
  417. causal_mask = create_causal_mask(
  418. config=self.config,
  419. inputs_embeds=inputs_embeds,
  420. attention_mask=attention_mask,
  421. past_key_values=past_key_values,
  422. position_ids=position_ids,
  423. )
  424. hidden_states = inputs_embeds
  425. # create position embeddings to be shared across the decoder layers
  426. position_ids = position_ids.unsqueeze(0)
  427. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  428. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  429. hidden_states = decoder_layer(
  430. hidden_states,
  431. attention_mask=causal_mask,
  432. position_ids=position_ids,
  433. past_key_values=past_key_values,
  434. use_cache=use_cache,
  435. position_embeddings=position_embeddings,
  436. **kwargs,
  437. )
  438. hidden_states = self.norm(hidden_states)
  439. return BaseModelOutputWithPast(
  440. last_hidden_state=hidden_states,
  441. past_key_values=past_key_values if use_cache else None,
  442. )
  443. class CsmCodebooksHead(nn.Module):
  444. def __init__(self, hidden_size, num_codebooks, vocab_size):
  445. super().__init__()
  446. self.num_codebooks = num_codebooks
  447. self.weight = nn.Parameter(torch.empty(self.num_codebooks - 1, hidden_size, vocab_size))
  448. def forward(self, hidden_states, codebook_indices=None):
  449. # -1 because of the concatenated backbone last hidden state
  450. codebook_indices = codebook_indices - 1
  451. codebook_weight = self.weight[codebook_indices]
  452. hidden_states = [
  453. nn.functional.linear(hidden_states[:, codebook_idx, :], codebook_weight[codebook_idx].T)
  454. for codebook_idx in range(codebook_weight.shape[0])
  455. ]
  456. hidden_states = torch.stack(hidden_states, dim=1)
  457. return hidden_states
  458. @auto_docstring(
  459. custom_intro="""
  460. The CsmDepthDecoder Model transformer, with a [`CsmCodebooksHead`] on top,
  461. which can be seen a position-specific language modeling head, allowing to use a different linear layer for each codebook
  462. (e.g. position 0 is the first codebook and uses the first codebook head, etc.)
  463. """
  464. )
  465. class CsmDepthDecoderForCausalLM(CsmPreTrainedModel, GenerationMixin):
  466. _tied_weights_keys = None
  467. _tp_plan = None
  468. _pp_plan = None
  469. def __init__(self, config):
  470. super().__init__(config)
  471. self.model = CsmDepthDecoderModel(config)
  472. self.vocab_size = config.vocab_size
  473. self.codebooks_head = CsmCodebooksHead(config.hidden_size, config.num_codebooks, config.vocab_size)
  474. # Initialize weights and apply final processing
  475. self.post_init()
  476. @can_return_tuple
  477. @auto_docstring
  478. def forward(
  479. self,
  480. input_ids: torch.LongTensor | None = None,
  481. backbone_last_hidden_state: torch.FloatTensor | None = None,
  482. attention_mask: torch.Tensor | None = None,
  483. position_ids: torch.LongTensor | None = None,
  484. past_key_values: Cache | None = None,
  485. inputs_embeds: torch.FloatTensor | None = None,
  486. labels: torch.LongTensor | None = None,
  487. use_cache: bool | None = None,
  488. logits_to_keep: int | torch.Tensor = 0,
  489. **kwargs: Unpack[TransformersKwargs],
  490. ) -> tuple | CausalLMOutputWithPast:
  491. r"""
  492. backbone_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, backbone_hidden_size)`, *optional*):
  493. The last hidden state of the backbone model. Such input is required when the first codebook token (the one generated by the backbone model)
  494. is provided in the `input_ids` argument.
  495. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  496. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  497. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  498. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  499. """
  500. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  501. seq_len = inputs_embeds.shape[1] if inputs_embeds is not None else input_ids.shape[1]
  502. device = inputs_embeds.device if inputs_embeds is not None else input_ids.device
  503. codebook_indices = torch.arange(seq_len, device=device) + past_seen_tokens
  504. outputs = self.model(
  505. input_ids=input_ids,
  506. backbone_last_hidden_state=backbone_last_hidden_state,
  507. attention_mask=attention_mask,
  508. position_ids=position_ids,
  509. past_key_values=past_key_values,
  510. inputs_embeds=inputs_embeds,
  511. use_cache=use_cache,
  512. **kwargs,
  513. )
  514. hidden_states = outputs[0]
  515. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  516. if isinstance(logits_to_keep, int):
  517. if logits_to_keep == 0:
  518. # skip idx 0 logits since it's for the concatenated backbone last hidden state
  519. slice_indices = slice(1, None)
  520. else:
  521. slice_indices = slice(-logits_to_keep, None)
  522. else:
  523. slice_indices = logits_to_keep
  524. logits = self.codebooks_head(hidden_states[:, slice_indices, :], codebook_indices[slice_indices])
  525. logits = logits.contiguous()
  526. loss = None
  527. if labels is not None:
  528. shift_labels = labels[..., 1:].contiguous()
  529. loss = self.loss_function(
  530. logits=logits, labels=None, vocab_size=self.config.vocab_size, shift_labels=shift_labels, **kwargs
  531. )
  532. return CausalLMOutputWithPast(
  533. loss=loss,
  534. logits=logits,
  535. past_key_values=outputs.past_key_values,
  536. hidden_states=outputs.hidden_states,
  537. attentions=outputs.attentions,
  538. )
  539. def prepare_inputs_for_generation(
  540. self,
  541. input_ids: torch.LongTensor,
  542. next_sequence_length: int | None = None,
  543. past_key_values: Cache | None = None,
  544. attention_mask: torch.LongTensor | None = None,
  545. inputs_embeds: torch.FloatTensor | None = None,
  546. is_first_iteration: bool | None = False,
  547. **kwargs,
  548. ):
  549. model_inputs = super().prepare_inputs_for_generation(
  550. input_ids, next_sequence_length, past_key_values, attention_mask, inputs_embeds, **kwargs
  551. )
  552. if not is_first_iteration:
  553. model_inputs.pop("backbone_last_hidden_state")
  554. # csm depth decoder does not use position_ids
  555. model_inputs.pop("position_ids")
  556. return model_inputs
  557. class CsmBackboneModelEmbeddings(nn.Module):
  558. def __init__(self, config):
  559. super().__init__()
  560. self.embed_audio_tokens = nn.Embedding((config.num_codebooks * config.codebook_size), config.hidden_size)
  561. self.register_buffer(
  562. "audio_tokens_offsets", torch.arange(config.num_codebooks) * config.codebook_size, persistent=False
  563. )
  564. def forward(self, input_ids):
  565. inputs_embeds = self.embed_audio_tokens(input_ids + self.audio_tokens_offsets)
  566. inputs_embeds = inputs_embeds.sum(dim=2)
  567. return inputs_embeds
  568. @auto_docstring
  569. class CsmBackboneModel(CsmPreTrainedModel):
  570. def __init__(self, config):
  571. super().__init__(config)
  572. self.padding_idx = config.pad_token_id
  573. self.vocab_size = config.vocab_size
  574. self.embed_tokens = CsmBackboneModelEmbeddings(config)
  575. self.layers = nn.ModuleList(
  576. [CsmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  577. )
  578. self.norm = CsmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  579. self.rotary_emb = CsmRotaryEmbedding(config=config)
  580. self.gradient_checkpointing = False
  581. # Initialize weights and apply final processing
  582. self.post_init()
  583. @merge_with_config_defaults
  584. @capture_outputs
  585. @auto_docstring
  586. def forward(
  587. self,
  588. input_ids: torch.LongTensor | None = None,
  589. attention_mask: torch.Tensor | None = None,
  590. position_ids: torch.LongTensor | None = None,
  591. past_key_values: Cache | None = None,
  592. inputs_embeds: torch.FloatTensor | None = None,
  593. use_cache: bool | None = None,
  594. **kwargs: Unpack[TransformersKwargs],
  595. ) -> BaseModelOutputWithPast:
  596. r"""
  597. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks) or (batch_size, sequence_length)`):
  598. 1. (batch_size, sequence_length): corresponds to the input sequence prepared with the processor from the text prompt. Such input
  599. requires `input_values` to be provided so that audio can be encoded in codebook tokens and then merged with the text tokens.
  600. 2. (batch_size, sequence_length, num_codebooks): codebook tokens generated during the autoregressive decoding. Such input is not meant to be used by end users.
  601. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  602. [`PreTrainedTokenizer.__call__`] for details.
  603. [What are input IDs?](../glossary#input-ids)
  604. """
  605. if (input_ids is None) ^ (inputs_embeds is not None):
  606. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  607. if inputs_embeds is None:
  608. inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
  609. if use_cache and past_key_values is None:
  610. past_key_values = DynamicCache(config=self.config)
  611. if position_ids is None:
  612. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  613. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  614. position_ids = position_ids.unsqueeze(0)
  615. causal_mask = create_causal_mask(
  616. config=self.config,
  617. inputs_embeds=inputs_embeds,
  618. attention_mask=attention_mask,
  619. past_key_values=past_key_values,
  620. position_ids=position_ids,
  621. )
  622. hidden_states = inputs_embeds
  623. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  624. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  625. hidden_states = decoder_layer(
  626. hidden_states,
  627. attention_mask=causal_mask,
  628. position_embeddings=position_embeddings,
  629. position_ids=position_ids,
  630. past_key_values=past_key_values,
  631. use_cache=use_cache,
  632. **kwargs,
  633. )
  634. hidden_states = self.norm(hidden_states)
  635. return BaseModelOutputWithPast(
  636. last_hidden_state=hidden_states,
  637. past_key_values=past_key_values,
  638. )
@auto_docstring(
    custom_intro="""
    The Csm model consists of two llama-like auto-regressive transformer models: a backbone model that predicts the first codebook token and a depth decoder that predicts the other codebook tokens.
    """
)
class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
    # Declares the backbone audio-token embedding table and the depth decoder token embedding table
    # as tied weights. NOTE(review): which side is the source and which the target follows the
    # base-class convention for `_tied_weights_keys`, which is not visible in this file section.
    _tied_weights_keys = {
        "backbone_model.embed_tokens.embed_audio_tokens.weight": "depth_decoder.model.embed_tokens.weight"
    }
  648. def __init__(self, config):
  649. super().__init__(config)
  650. self.vocab_size = config.vocab_size
  651. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  652. self.embed_text_tokens = nn.Embedding(config.text_vocab_size, config.hidden_size)
  653. self.backbone_model = CsmBackboneModel._from_config(config)
  654. self.depth_decoder = CsmDepthDecoderForCausalLM._from_config(config.depth_decoder_config)
  655. self.codec_model = AutoModel.from_config(config.codec_config)
  656. self.post_init()
  657. def get_input_embeddings(self):
  658. return self.backbone_model.embed_tokens
  659. def set_input_embeddings(self, value):
  660. self.backbone_model.embed_tokens = value
  661. @classmethod
  662. def from_pretrained(cls, *args, **kwargs):
  663. if kwargs.get("output_loading_info", False):
  664. model, loading_info = super().from_pretrained(*args, **kwargs)
  665. else:
  666. model = super().from_pretrained(*args, **kwargs)
  667. # copy depth decoder generation conf attr to the depth decoder generation config
  668. prefix = "depth_decoder_"
  669. prefix_len = len(prefix)
  670. depth_decoder_attrs = {
  671. attr[prefix_len:]: value
  672. for attr, value in vars(model.generation_config).items()
  673. if attr.startswith(prefix)
  674. }
  675. vars(model.depth_decoder.generation_config).update({"_from_model_config": False, **depth_decoder_attrs})
  676. # remove the depth decoder generation conf attr from the model generation config
  677. for attr in depth_decoder_attrs:
  678. delattr(model.generation_config, prefix + attr)
  679. if "output_loading_info" in kwargs:
  680. return model, loading_info
  681. else:
  682. return model
  683. def save_pretrained(self, *args, **kwargs):
  684. # copy the depth decoder generation config attributes to the model generation config
  685. prefix = "depth_decoder_"
  686. depth_decoder_attrs = self.depth_decoder.generation_config.to_diff_dict()
  687. depth_decoder_attrs.pop("transformers_version", None)
  688. for attr, value in depth_decoder_attrs.items():
  689. setattr(self.generation_config, prefix + attr, value)
  690. super().save_pretrained(*args, **kwargs)
    def _merge_input_ids_with_input_values(
        self,
        input_ids: torch.Tensor | None = None,
        input_values: torch.Tensor | None = None,
        input_values_cutoffs: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor | None]:
        """
        Merges the input_ids and input_values to produce a single inputs_embeds tensor:
        1 - Infers the codec model on the input_values to retrieve codebook token.
        2 - Embeds codebook tokens and places them at the correct positions in the inputs_embeds tensor.
        3 - If labels are provided, expands them to match codebook dimensions and places the target codebook tokens in the expanded labels tensor.

        Args:
            input_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`):
                The input ids to embed.
            input_values (`torch.Tensor` of shape `(batch_size, channels, audio_sequence_length)`):
                The audio input values to embed.
            input_values_cutoffs (`torch.Tensor` of shape `(batch_size, max_num_audio)`):
                The cutoffs of the audio input values relative to its batch index, padded with -1 when no audio.
            labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels to expand along a trailing codebook dimension. They are only expanded when
                `input_values` is provided; otherwise they are returned unchanged.

        Returns:
            `dict` with two keys: `"inputs_embeds"` (text embeddings with audio and audio-eos positions
            overwritten) and `"labels"` (the expanded labels, or the `labels` argument as given).
        """
        inputs_embeds = self.embed_text_tokens(input_ids)

        if input_values is not None:
            # infer input_values_mask
            # prepend a 0 to every row of cutoffs so that consecutive differences give segment lengths
            input_values_cutoffs = nn.functional.pad(input_values_cutoffs, (1, 0))
            # diff over the flattened non-negative cutoffs; non-positive differences are discarded
            # (presumably the ones straddling two batch rows — TODO confirm)
            audio_lengths = input_values_cutoffs[input_values_cutoffs >= 0].diff()
            audio_lengths = audio_lengths[audio_lengths > 0]
            # one row per audio segment, True on the first `audio_length` positions of that row
            input_values_mask = torch.arange(input_values_cutoffs.max(), device=input_values.device).expand(
                len(audio_lengths), -1
            )
            input_values_mask = input_values_mask < audio_lengths.unsqueeze(1)

            # =======================================
            # TODO: @eustlb, this should be batched !!!
            # but requires making sure batched inference of the codec model works as intended
            with torch.no_grad():
                audio_tokens_list = []
                # encode every audio segment of every batch entry on its own
                for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
                    # drop the -1 padding entries of this row
                    batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
                    for i in range(batch_input_values_cutoffs.shape[0] - 1):
                        start_idx = batch_input_values_cutoffs[i]
                        end_idx = batch_input_values_cutoffs[i + 1]
                        audio_batch = batch_input_values[..., start_idx:end_idx]
                        codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0))
                        # assumes audio_codes becomes (1, frames, num_codebooks) after the transpose — TODO confirm
                        codebook_ids = codec_outputs.audio_codes.transpose(1, -1)
                        audio_tokens_list.append(codebook_ids[0])

                # zero-pad every segment along dim 0 to the longest one so that they can be stacked
                max_audio_frames = max(el.shape[0] for el in audio_tokens_list)
                batched_audio_token_ids = torch.stack(
                    [nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list]
                )
                # NOTE(review): presumably maps the sample-level mask to a frame-level mask; verify in the codec model
                audio_codes_mask = self.codec_model.get_audio_codes_mask(input_values_mask)
            # =======================================
            audio_token_id = self.config.audio_token_id
            audio_token_mask = input_ids == audio_token_id

            # overwrite the audio placeholder positions with the embedded codebook frames
            audio_embeds = self.backbone_model.embed_tokens(batched_audio_token_ids)
            inputs_embeds[audio_token_mask] = audio_embeds[audio_codes_mask]

            # same for the audio eos token
            audio_eos_frame_ids = (
                torch.ones((1, 1, self.config.num_codebooks), device=input_ids.device, dtype=torch.long)
                * self.config.codebook_eos_token_id
            )
            audio_eos_embeds = self.backbone_model.embed_tokens(audio_eos_frame_ids).squeeze(1)

            audio_eos_token_mask = input_ids == self.config.audio_eos_token_id
            inputs_embeds[audio_eos_token_mask] = audio_eos_embeds.repeat(audio_eos_token_mask.sum(), 1)

            # if the labels are provided, we need to expand the labels to (batch_size, seq_length, num_codebooks)
            if labels is not None:
                labels_expanded = labels.unsqueeze(-1).repeat(1, 1, self.config.num_codebooks)
                labels_expanded[audio_token_mask] = batched_audio_token_ids[audio_codes_mask]
                labels_expanded[audio_eos_token_mask] = audio_eos_frame_ids
                # mask depth decoder
                # frames labelled -101 get -100 on every codebook after the first
                depth_decoder_ignore_frames_idxs = (labels == -101).nonzero(as_tuple=True)
                labels_expanded[depth_decoder_ignore_frames_idxs[0], depth_decoder_ignore_frames_idxs[1], 1:] = -100
                labels = labels_expanded

        return {"inputs_embeds": inputs_embeds, "labels": labels}
  763. def prepare_inputs_for_generation(
  764. self,
  765. input_ids: torch.LongTensor,
  766. next_sequence_length: int | None = None,
  767. past_key_values: Cache | None = None,
  768. attention_mask: torch.LongTensor | None = None,
  769. inputs_embeds: torch.FloatTensor | None = None,
  770. **kwargs,
  771. ):
  772. model_inputs = super().prepare_inputs_for_generation(
  773. input_ids=input_ids,
  774. next_sequence_length=next_sequence_length,
  775. past_key_values=past_key_values,
  776. attention_mask=attention_mask,
  777. inputs_embeds=inputs_embeds,
  778. **kwargs,
  779. )
  780. if input_ids is not None and input_ids.ndim == 2 and model_inputs.get("inputs_embeds") is None:
  781. merged_inputs = self._merge_input_ids_with_input_values(
  782. input_ids=input_ids,
  783. input_values=kwargs.get("input_values"),
  784. input_values_cutoffs=kwargs.get("input_values_cutoffs"),
  785. labels=kwargs.get("labels"),
  786. )
  787. model_inputs.update(
  788. {"inputs_embeds": merged_inputs["inputs_embeds"], "labels": merged_inputs["labels"], "input_ids": None}
  789. )
  790. return model_inputs
  791. @can_return_tuple
  792. @auto_docstring
  793. def forward(
  794. self,
  795. input_ids: torch.LongTensor | None = None,
  796. input_values: torch.Tensor | None = None,
  797. attention_mask: torch.Tensor | None = None,
  798. input_values_cutoffs: torch.Tensor | None = None,
  799. position_ids: torch.LongTensor | None = None,
  800. past_key_values: Cache | None = None,
  801. inputs_embeds: torch.FloatTensor | None = None,
  802. labels: torch.LongTensor | None = None,
  803. use_cache: bool | None = None,
  804. logits_to_keep: int | torch.Tensor = 0,
  805. **kwargs: Unpack[TransformersKwargs],
  806. ) -> tuple | CsmOutputWithPast:
  807. r"""
  808. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks) or (batch_size, sequence_length)`):
  809. 1. (batch_size, sequence_length): corresponds to the input sequence prepared with the processor from the text prompt. Such input
  810. requires `input_values` to be provided so that audio can be encoded in codebook tokens and then merged with the text tokens.
  811. 2. (batch_size, sequence_length, num_codebooks): codebook tokens generated during the autoregressive decoding. Such input is not meant to be used by end users.
  812. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  813. [`PreTrainedTokenizer.__call__`] for details.
  814. [What are input IDs?](../glossary#input-ids)
  815. input_values_cutoffs (`torch.Tensor` of shape `(batch_size, max_num_audio)`, *optional*):
  816. Specify the end positions of audio segments within each batch entry, relative to the concatenated audio input.
  817. If a batch entry has fewer segments than the maximum, it is padded with -1. For example, in a batch of 2 sequences
  818. where the first contains 2 audio segments of length l1, and the second contains 1 audio segment of length l2,
  819. the input_values_cutoffs would be: [[l1, 2 * l1], [l2, -1]].
  820. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  821. Labels for computing the masked language modeling loss. Indices should be in `[config.audio_token_id, -100, -101]`.
  822. Requires targeted `input_values` to be provided as audio tokens will be inferred from it using the `codec_model`.
  823. - `config.audio_token_id` indicates an audio frames (considering sequence length elements as frames)
  824. - `-100` will be ignored in the loss computation
  825. - `-101` indicates the audio frame will be used only for the backbone model (using the first codebook token as labels)
  826. Such labels can be prepared using `output_labels=True` when calling [`CsmProcessor`].
  827. logits_to_keep (`int` or `torch.Tensor`, *optional*):
  828. Kept for compatibility. Does not support another value than:
  829. 1. `0`, which is equivalent to keeping all logits, used in the training regime
  830. 2. `1`, which is equivalent to keeping only the last logit, used in the generation regime
  831. Example:
  832. ```python
  833. >>> import torch
  834. >>> from transformers import CsmForConditionalGeneration, AutoProcessor
  835. >>> from datasets import load_dataset, Audio
  836. >>> model_id = "sesame/csm-1b"
  837. >>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"
  838. >>> processor = AutoProcessor.from_pretrained(model_id)
  839. >>> ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
  840. >>> # ensure the audio is 24kHz
  841. >>> ds = ds.cast_column("audio", Audio(sampling_rate=24000))
  842. >>> conversation = []
  843. >>> # prepare a conversation with text and corresponding audio
  844. >>> for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
  845. ... conversation.append(
  846. ... {
  847. ... "role": f"{speaker_id}",
  848. ... "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
  849. ... }
  850. ... )
  851. >>> inputs = processor.apply_chat_template(
  852. ... conversation,
  853. ... tokenize=True,
  854. ... return_dict=True,
  855. ... output_labels=True,
  856. ... ).to(torch_device)
  857. >>> model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
  858. >>> output = model(**inputs)
  859. >>> output.loss.backward()
  860. ```"""
  861. if input_ids is not None and input_ids.ndim == 2:
  862. merged_inputs = self._merge_input_ids_with_input_values(
  863. input_ids, input_values, input_values_cutoffs, labels
  864. )
  865. inputs_embeds = merged_inputs["inputs_embeds"]
  866. labels = merged_inputs["labels"]
  867. input_ids = None
  868. backbone_outputs = self.backbone_model(
  869. input_ids=input_ids,
  870. attention_mask=attention_mask,
  871. position_ids=position_ids,
  872. past_key_values=past_key_values,
  873. inputs_embeds=inputs_embeds,
  874. use_cache=use_cache,
  875. **kwargs,
  876. )
  877. backbone_hidden_states = backbone_outputs[0]
  878. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  879. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  880. backbone_logits = self.lm_head(backbone_hidden_states[:, slice_indices, :])
  881. loss = None
  882. backbone_loss = None
  883. depth_decoder_loss = None
  884. depth_decoder_outputs = None
  885. if labels is not None:
  886. # select first codebook as labels for the backbone model
  887. backbone_labels = labels[:, :, 0]
  888. backbone_loss = self.loss_function(
  889. logits=backbone_logits, labels=backbone_labels, vocab_size=self.config.vocab_size, **kwargs
  890. )
  891. # for the depth decoder, we need to select the frames to train on
  892. # those are frames where the label is not uniformly `ignore_index` along the codebook dimension
  893. train_mask = ~(labels[:, :, 1:] == -100).all(dim=-1)
  894. depth_decoder_input_ids = labels[train_mask][..., : self.config.num_codebooks - 1]
  895. # add place holder in position 0 that will be replaced by the backbone_last_hidden_state
  896. depth_decoder_input_ids = nn.functional.pad(depth_decoder_input_ids, (1, 0), value=0)
  897. train_idxs = train_mask.nonzero(as_tuple=True)
  898. backbone_last_hidden_states = backbone_hidden_states[train_idxs[0], train_idxs[1] - 1, :]
  899. depth_decoder_labels = labels[train_mask]
  900. depth_decoder_outputs = self.depth_decoder(
  901. input_ids=depth_decoder_input_ids,
  902. backbone_last_hidden_state=backbone_last_hidden_states,
  903. use_cache=use_cache,
  904. return_dict=True,
  905. labels=depth_decoder_labels,
  906. **kwargs,
  907. )
  908. depth_decoder_loss = depth_decoder_outputs.loss
  909. loss = backbone_loss + depth_decoder_loss
  910. return CsmOutputWithPast(
  911. loss=loss,
  912. backbone_loss=backbone_loss,
  913. depth_decoder_loss=depth_decoder_loss,
  914. logits=backbone_logits,
  915. past_key_values=backbone_outputs.past_key_values,
  916. hidden_states=backbone_outputs.hidden_states,
  917. attentions=backbone_outputs.attentions,
  918. depth_decoder_logits=depth_decoder_outputs.logits if depth_decoder_outputs is not None else None,
  919. depth_decoder_past_key_values=depth_decoder_outputs.past_key_values
  920. if depth_decoder_outputs is not None
  921. else None,
  922. depth_decoder_hidden_states=depth_decoder_outputs.hidden_states
  923. if depth_decoder_outputs is not None
  924. else None,
  925. depth_decoder_attentions=depth_decoder_outputs.attentions if depth_decoder_outputs is not None else None,
  926. )
# Names exported by `from <module> import *`.
__all__ = [
    "CsmPreTrainedModel",
    "CsmBackboneModel",
    "CsmDepthDecoderModel",
    "CsmDepthDecoderForCausalLM",
    "CsmForConditionalGeneration",
]