# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # This file was automatically generated from src/transformers/models/csm/modular_csm.py. # Do NOT edit this file manually as any edits will be overwritten by the generation of # the file from the modular. If any change should be done, please apply the change to the # modular_csm.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # Copyright 2025 Sesame and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Callable from dataclasses import dataclass from typing import Optional import torch import torch.nn as nn from ... import initialization as init from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging from ...utils.generic import maybe_autocast, merge_with_config_defaults from ...utils.import_utils import is_torchdynamo_compiling from ...utils.output_capturing import capture_outputs from ..auto import AutoModel from .configuration_csm import CsmConfig, CsmDepthDecoderConfig from .generation_csm import CsmGenerationMixin logger = logging.get_logger(__name__) @dataclass @auto_docstring( custom_intro=""" Base class for the model autoregressive outputs. """ ) class CsmOutputWithPast(ModelOutput): r""" loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. depth_decoder_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): Language modeling loss (for next-token prediction) of the depth decoder model. depth_decoder_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the depth decoder (scores for each vocabulary token before SoftMax). depth_decoder_past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). depth_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. depth_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. backbone_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): Language modeling loss (for next-token prediction) of the backbone model. """ loss: torch.FloatTensor | None = None logits: torch.FloatTensor | None = None past_key_values: Cache | None = None hidden_states: tuple[torch.FloatTensor, ...] | None = None attentions: tuple[torch.FloatTensor, ...] | None = None depth_decoder_loss: torch.FloatTensor | None = None depth_decoder_logits: torch.FloatTensor | None = None depth_decoder_past_key_values: Cache | None = None depth_decoder_hidden_states: tuple[torch.FloatTensor, ...] | None = None depth_decoder_attentions: tuple[torch.FloatTensor, ...] | None = None backbone_loss: torch.FloatTensor | None = None @use_kernel_forward_from_hub("RMSNorm") class CsmRMSNorm(nn.Module): def __init__(self, hidden_size, eps: float = 1e-6) -> None: """ CsmRMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" class CsmRotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` def __init__(self, config: CsmConfig, device=None): super().__init__() self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_type = self.config.rope_parameters["rope_type"] rope_init_fn: Callable = self.compute_default_rope_parameters if self.rope_type != "default": rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) @staticmethod def compute_default_rope_parameters( config: CsmConfig | None = None, device: Optional["torch.device"] = None, seq_len: int | None = None, ) -> tuple["torch.Tensor", float]: """ Computes the inverse frequencies according to the original RoPE implementation Args: config ([`~transformers.PreTrainedConfig`]): The model configuration. device (`torch.device`): The device to use for initialization of the inverse frequencies. seq_len (`int`, *optional*): The current sequence length. Unused for this type of RoPE. Returns: Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). """ base = config.rope_parameters["rope_theta"] dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads attention_factor = 1.0 # Unused in this type of RoPE # Compute the inverse frequencies inv_freq = 1.0 / ( base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) ) return inv_freq, attention_factor @torch.no_grad() @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) def forward(self, x, position_ids): inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) class CsmMLP(nn.Module): def __init__(self, config): super().__init__() self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) @use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: torch.Tensor | None, scaling: float, dropout: float = 0.0, **kwargs: Unpack[TransformersKwargs], ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights @use_kernelized_func(apply_rotary_pos_emb) class CsmAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: CsmConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias ) self.k_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.v_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, attention_mask: torch.Tensor | None = None, past_key_values: Cache | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( self.config._attn_implementation, eager_attention_forward ) attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights class CsmDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: CsmConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.self_attn = CsmAttention(config=config, layer_idx=layer_idx) self.mlp = CsmMLP(config) self.input_layernorm = CsmRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = CsmRMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, past_key_values: Cache | None = None, use_cache: bool | None = False, position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, **kwargs: Unpack[TransformersKwargs], ) -> torch.Tensor: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention hidden_states, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states return hidden_states @auto_docstring( custom_intro=""" The bare Csm Model outputting raw hidden-states without any specific head on top. """ ) @auto_docstring class CsmPreTrainedModel(PreTrainedModel): config: CsmConfig base_model_prefix = "model" input_modalities = ("audio", "text") supports_gradient_checkpointing = True _no_split_modules = ["CsmDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn = True _supports_sdpa = True # does not because of Mimi codec model # _supports_flex_attn = True _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { "hidden_states": CsmDecoderLayer, "attentions": CsmAttention, } @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, CsmCodebooksHead): num_codebooks = module.num_codebooks for i in range(num_codebooks - 1): init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) elif isinstance(module, CsmBackboneModelEmbeddings): init.copy_(module.audio_tokens_offsets, torch.arange(self.config.num_codebooks) * self.config.vocab_size) @auto_docstring class CsmDepthDecoderModel(CsmPreTrainedModel): config: CsmDepthDecoderConfig def __init__(self, config): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding((config.num_codebooks * config.vocab_size), config.backbone_hidden_size) self.layers = nn.ModuleList( [CsmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = CsmRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = CsmRotaryEmbedding(config=config) self.gradient_checkpointing = False self.inputs_embeds_projector = nn.Linear(config.backbone_hidden_size, config.hidden_size, bias=False) # Initialize weights and apply final processing self.post_init() @merge_with_config_defaults @capture_outputs @auto_docstring def forward( self, input_ids: torch.LongTensor | None = None, backbone_last_hidden_state: torch.FloatTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, use_cache: bool | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPast: r""" backbone_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, backbone_hidden_size)`, *optional*): The last hidden state of the backbone model. Such input is required when the first codebook token (the one generated by the backbone model) is provided in the `input_ids` argument. """ if position_ids is not None and not is_torchdynamo_compiling(): logger.warning_once( "Custom `position_ids` were provided but will be ignored. CSM depth decoder automatically determines position_ids " "and as it requires them to be identical across the batch, the provided position_ids will be ignored." ) position_ids = None if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds.") if use_cache and past_key_values is None: past_key_values = DynamicCache(config=self.config) past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 inputs_seq_length = inputs_embeds.shape[1] if inputs_embeds is not None else input_ids.shape[1] device = inputs_embeds.device if inputs_embeds is not None else input_ids.device position_ids = torch.arange(past_seen_tokens, past_seen_tokens + inputs_seq_length, device=device) if inputs_embeds is None: codebook_idxs = torch.clamp(position_ids - 1, min=0) offset = codebook_idxs * self.vocab_size inputs_embeds = self.embed_tokens(input_ids + offset) input_ids_are_first_codebook = position_ids[0] == 0 if backbone_last_hidden_state is not None: inputs_embeds[:, 0] = backbone_last_hidden_state else: if not is_torchdynamo_compiling() and input_ids_are_first_codebook: logger.warning( "When the first codebook token is provided, `backbone_last_hidden_state` should also be provided for correct inference." ) inputs_embeds = self.inputs_embeds_projector(inputs_embeds) causal_mask = create_causal_mask( config=self.config, inputs_embeds=inputs_embeds, attention_mask=attention_mask, past_key_values=past_key_values, position_ids=position_ids, ) hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers position_ids = position_ids.unsqueeze(0) position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) for decoder_layer in self.layers[: self.config.num_hidden_layers]: hidden_states = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, position_embeddings=position_embeddings, **kwargs, ) hidden_states = self.norm(hidden_states) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, ) class CsmCodebooksHead(nn.Module): def __init__(self, hidden_size, num_codebooks, vocab_size): super().__init__() self.num_codebooks = num_codebooks self.weight = nn.Parameter(torch.empty(self.num_codebooks - 1, hidden_size, vocab_size)) def forward(self, hidden_states, codebook_indices=None): # -1 because of the concatenated backbone last hidden state codebook_indices = codebook_indices - 1 codebook_weight = self.weight[codebook_indices] hidden_states = [ nn.functional.linear(hidden_states[:, codebook_idx, :], codebook_weight[codebook_idx].T) for codebook_idx in range(codebook_weight.shape[0]) ] hidden_states = torch.stack(hidden_states, dim=1) return hidden_states @auto_docstring( custom_intro=""" The CsmDepthDecoder Model transformer, with a [`CsmCodebooksHead`] on top, which can be seen a position-specific language modeling head, allowing to use a different linear layer for each codebook (e.g. position 0 is the first codebook and uses the first codebook head, etc.) """ ) class CsmDepthDecoderForCausalLM(CsmPreTrainedModel, GenerationMixin): _tied_weights_keys = None _tp_plan = None _pp_plan = None def __init__(self, config): super().__init__(config) self.model = CsmDepthDecoderModel(config) self.vocab_size = config.vocab_size self.codebooks_head = CsmCodebooksHead(config.hidden_size, config.num_codebooks, config.vocab_size) # Initialize weights and apply final processing self.post_init() @can_return_tuple @auto_docstring def forward( self, input_ids: torch.LongTensor | None = None, backbone_last_hidden_state: torch.FloatTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, logits_to_keep: int | torch.Tensor = 0, **kwargs: Unpack[TransformersKwargs], ) -> tuple | CausalLMOutputWithPast: r""" backbone_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, backbone_hidden_size)`, *optional*): The last hidden state of the backbone model. Such input is required when the first codebook token (the one generated by the backbone model) is provided in the `input_ids` argument. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. """ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 seq_len = inputs_embeds.shape[1] if inputs_embeds is not None else input_ids.shape[1] device = inputs_embeds.device if inputs_embeds is not None else input_ids.device codebook_indices = torch.arange(seq_len, device=device) + past_seen_tokens outputs = self.model( input_ids=input_ids, backbone_last_hidden_state=backbone_last_hidden_state, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, **kwargs, ) hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss if isinstance(logits_to_keep, int): if logits_to_keep == 0: # skip idx 0 logits since it's for the concatenated backbone last hidden state slice_indices = slice(1, None) else: slice_indices = slice(-logits_to_keep, None) else: slice_indices = logits_to_keep logits = self.codebooks_head(hidden_states[:, slice_indices, :], codebook_indices[slice_indices]) logits = logits.contiguous() loss = None if labels is not None: shift_labels = labels[..., 1:].contiguous() loss = self.loss_function( logits=logits, labels=None, vocab_size=self.config.vocab_size, shift_labels=shift_labels, **kwargs ) return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def prepare_inputs_for_generation( self, input_ids: torch.LongTensor, next_sequence_length: int | None = None, past_key_values: Cache | None = None, attention_mask: torch.LongTensor | None = None, inputs_embeds: torch.FloatTensor | None = None, is_first_iteration: bool | None = False, **kwargs, ): model_inputs = super().prepare_inputs_for_generation( input_ids, next_sequence_length, past_key_values, attention_mask, inputs_embeds, **kwargs ) if not is_first_iteration: model_inputs.pop("backbone_last_hidden_state") # csm depth decoder does not use position_ids model_inputs.pop("position_ids") return model_inputs class CsmBackboneModelEmbeddings(nn.Module): def __init__(self, config): super().__init__() self.embed_audio_tokens = nn.Embedding((config.num_codebooks * config.codebook_size), config.hidden_size) self.register_buffer( "audio_tokens_offsets", torch.arange(config.num_codebooks) * config.codebook_size, persistent=False ) def forward(self, input_ids): inputs_embeds = self.embed_audio_tokens(input_ids + self.audio_tokens_offsets) inputs_embeds = inputs_embeds.sum(dim=2) return inputs_embeds @auto_docstring class CsmBackboneModel(CsmPreTrainedModel): def __init__(self, config): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = CsmBackboneModelEmbeddings(config) self.layers = nn.ModuleList( [CsmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = CsmRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = CsmRotaryEmbedding(config=config) self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @merge_with_config_defaults @capture_outputs @auto_docstring def forward( self, input_ids: torch.LongTensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, use_cache: bool | None = None, **kwargs: Unpack[TransformersKwargs], ) -> BaseModelOutputWithPast: r""" input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks) or (batch_size, sequence_length)`): 1. (batch_size, sequence_length): corresponds to the input sequence prepared with the processor from the text prompt. Such input requires `input_values` to be provided so that audio can be encoded in codebook tokens and then merged with the text tokens. 2. (batch_size, sequence_length, num_codebooks): codebook tokens generated during the autoregressive decoding. Such input is not meant to be used by end users. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) """ if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if inputs_embeds is None: inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) if use_cache and past_key_values is None: past_key_values = DynamicCache(config=self.config) if position_ids is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens position_ids = position_ids.unsqueeze(0) causal_mask = create_causal_mask( config=self.config, inputs_embeds=inputs_embeds, attention_mask=attention_mask, past_key_values=past_key_values, position_ids=position_ids, ) hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) for decoder_layer in self.layers[: self.config.num_hidden_layers]: hidden_states = decoder_layer( hidden_states, attention_mask=causal_mask, position_embeddings=position_embeddings, position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, **kwargs, ) hidden_states = self.norm(hidden_states) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, ) @auto_docstring( custom_intro=""" The Csm model consists of two llama-like auto-regressive transformer models: a backbone model that predicts the first codebook token and a depth decoder that predicts the other codebook tokens. """ ) class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin): _tied_weights_keys = { "backbone_model.embed_tokens.embed_audio_tokens.weight": "depth_decoder.model.embed_tokens.weight" } def __init__(self, config): super().__init__(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.embed_text_tokens = nn.Embedding(config.text_vocab_size, config.hidden_size) self.backbone_model = CsmBackboneModel._from_config(config) self.depth_decoder = CsmDepthDecoderForCausalLM._from_config(config.depth_decoder_config) self.codec_model = AutoModel.from_config(config.codec_config) self.post_init() def get_input_embeddings(self): return self.backbone_model.embed_tokens def set_input_embeddings(self, value): self.backbone_model.embed_tokens = value @classmethod def from_pretrained(cls, *args, **kwargs): if kwargs.get("output_loading_info", False): model, loading_info = super().from_pretrained(*args, **kwargs) else: model = super().from_pretrained(*args, **kwargs) # copy depth decoder generation conf attr to the depth decoder generation config prefix = "depth_decoder_" prefix_len = len(prefix) depth_decoder_attrs = { attr[prefix_len:]: value for attr, value in vars(model.generation_config).items() if attr.startswith(prefix) } vars(model.depth_decoder.generation_config).update({"_from_model_config": False, **depth_decoder_attrs}) # remove the depth decoder generation conf attr from the model generation config for attr in depth_decoder_attrs: delattr(model.generation_config, prefix + attr) if "output_loading_info" in kwargs: return model, loading_info else: return model def save_pretrained(self, *args, **kwargs): # copy the depth decoder generation config attributes to the model generation config prefix = "depth_decoder_" depth_decoder_attrs = self.depth_decoder.generation_config.to_diff_dict() depth_decoder_attrs.pop("transformers_version", None) for attr, value in depth_decoder_attrs.items(): setattr(self.generation_config, prefix + attr, value) super().save_pretrained(*args, **kwargs) def _merge_input_ids_with_input_values( self, input_ids: torch.Tensor | None = None, input_values: torch.Tensor | None = None, input_values_cutoffs: torch.Tensor | None = None, labels: torch.Tensor | None = None, ) -> torch.Tensor | None: """ Merges the input_ids and input_values to produce a single inputs_embeds tensor: 1 - Infers the codec model on the input_values to retrieve codebook token. 2 - Embeds codebook tokens and places them at the correct positions in the inputs_embeds tensor. 3 - If labels are provided, expands them to match codebook dimensions and position the target codebook tokens in the inputs_embeds tensor. Args: input_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`): The input ids to embed. input_values (`torch.Tensor` of shape `(batch_size, channels, audio_sequence_length)`): The audio input values to embed. input_values_cutoffs (`torch.Tensor` of shape `(batch_size, max_num_audio)`): The cutoffs of the audio input values relative to its batch index, padded with -1 when no audio. """ inputs_embeds = self.embed_text_tokens(input_ids) if input_values is not None: # infer input_values_mask input_values_cutoffs = nn.functional.pad(input_values_cutoffs, (1, 0)) audio_lengths = input_values_cutoffs[input_values_cutoffs >= 0].diff() audio_lengths = audio_lengths[audio_lengths > 0] input_values_mask = torch.arange(input_values_cutoffs.max(), device=input_values.device).expand( len(audio_lengths), -1 ) input_values_mask = input_values_mask < audio_lengths.unsqueeze(1) # ======================================= # TODO: @eustlb, this should be batched !!! # but requires making sure batched inference of the codec model works as intended with torch.no_grad(): audio_tokens_list = [] for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs): batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0] for i in range(batch_input_values_cutoffs.shape[0] - 1): start_idx = batch_input_values_cutoffs[i] end_idx = batch_input_values_cutoffs[i + 1] audio_batch = batch_input_values[..., start_idx:end_idx] codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0)) codebook_ids = codec_outputs.audio_codes.transpose(1, -1) audio_tokens_list.append(codebook_ids[0]) max_audio_frames = max(el.shape[0] for el in audio_tokens_list) batched_audio_token_ids = torch.stack( [nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list] ) audio_codes_mask = self.codec_model.get_audio_codes_mask(input_values_mask) # ======================================= audio_token_id = self.config.audio_token_id audio_token_mask = input_ids == audio_token_id audio_embeds = self.backbone_model.embed_tokens(batched_audio_token_ids) inputs_embeds[audio_token_mask] = audio_embeds[audio_codes_mask] # same for the audio eos token audio_eos_frame_ids = ( torch.ones((1, 1, self.config.num_codebooks), device=input_ids.device, dtype=torch.long) * self.config.codebook_eos_token_id ) audio_eos_embeds = self.backbone_model.embed_tokens(audio_eos_frame_ids).squeeze(1) audio_eos_token_mask = input_ids == self.config.audio_eos_token_id inputs_embeds[audio_eos_token_mask] = audio_eos_embeds.repeat(audio_eos_token_mask.sum(), 1) # if the labels are provided, we need to expand the labels to (batch_size, seq_length, num_codebooks) if labels is not None: labels_expanded = labels.unsqueeze(-1).repeat(1, 1, self.config.num_codebooks) labels_expanded[audio_token_mask] = batched_audio_token_ids[audio_codes_mask] labels_expanded[audio_eos_token_mask] = audio_eos_frame_ids # mask depth decoder depth_decoder_ignore_frames_idxs = (labels == -101).nonzero(as_tuple=True) labels_expanded[depth_decoder_ignore_frames_idxs[0], depth_decoder_ignore_frames_idxs[1], 1:] = -100 labels = labels_expanded return {"inputs_embeds": inputs_embeds, "labels": labels} def prepare_inputs_for_generation( self, input_ids: torch.LongTensor, next_sequence_length: int | None = None, past_key_values: Cache | None = None, attention_mask: torch.LongTensor | None = None, inputs_embeds: torch.FloatTensor | None = None, **kwargs, ): model_inputs = super().prepare_inputs_for_generation( input_ids=input_ids, next_sequence_length=next_sequence_length, past_key_values=past_key_values, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs, ) if input_ids is not None and input_ids.ndim == 2 and model_inputs.get("inputs_embeds") is None: merged_inputs = self._merge_input_ids_with_input_values( input_ids=input_ids, input_values=kwargs.get("input_values"), input_values_cutoffs=kwargs.get("input_values_cutoffs"), labels=kwargs.get("labels"), ) model_inputs.update( {"inputs_embeds": merged_inputs["inputs_embeds"], "labels": merged_inputs["labels"], "input_ids": None} ) return model_inputs @can_return_tuple @auto_docstring def forward( self, input_ids: torch.LongTensor | None = None, input_values: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, input_values_cutoffs: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, logits_to_keep: int | torch.Tensor = 0, **kwargs: Unpack[TransformersKwargs], ) -> tuple | CsmOutputWithPast: r""" input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks) or (batch_size, sequence_length)`): 1. (batch_size, sequence_length): corresponds to the input sequence prepared with the processor from the text prompt. Such input requires `input_values` to be provided so that audio can be encoded in codebook tokens and then merged with the text tokens. 2. (batch_size, sequence_length, num_codebooks): codebook tokens generated during the autoregressive decoding. Such input is not meant to be used by end users. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) input_values_cutoffs (`torch.Tensor` of shape `(batch_size, max_num_audio)`, *optional*): Specify the end positions of audio segments within each batch entry, relative to the concatenated audio input. If a batch entry has fewer segments than the maximum, it is padded with -1. For example, in a batch of 2 sequences where the first contains 2 audio segments of length l1, and the second contains 1 audio segment of length l2, the input_values_cutoffs would be: [[l1, 2 * l1], [l2, -1]]. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should be in `[config.audio_token_id, -100, -101]`. Requires targeted `input_values` to be provided as audio tokens will be inferred from it using the `codec_model`. - `config.audio_token_id` indicates an audio frames (considering sequence length elements as frames) - `-100` will be ignored in the loss computation - `-101` indicates the audio frame will be used only for the backbone model (using the first codebook token as labels) Such labels can be prepared using `output_labels=True` when calling [`CsmProcessor`]. logits_to_keep (`int` or `torch.Tensor`, *optional*): Kept for compatibility. Does not support another value than: 1. `0`, which is equivalent to keeping all logits, used in the training regime 2. `1`, which is equivalent to keeping only the last logit, used in the generation regime Example: ```python >>> import torch >>> from transformers import CsmForConditionalGeneration, AutoProcessor >>> from datasets import load_dataset, Audio >>> model_id = "sesame/csm-1b" >>> torch_device = "cuda" if torch.cuda.is_available() else "cpu" >>> processor = AutoProcessor.from_pretrained(model_id) >>> ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train") >>> # ensure the audio is 24kHz >>> ds = ds.cast_column("audio", Audio(sampling_rate=24000)) >>> conversation = [] >>> # prepare a conversation with text and corresponding audio >>> for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]): ... conversation.append( ... { ... "role": f"{speaker_id}", ... "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}], ... } ... ) >>> inputs = processor.apply_chat_template( ... conversation, ... tokenize=True, ... return_dict=True, ... output_labels=True, ... ).to(torch_device) >>> model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=torch_device) >>> output = model(**inputs) >>> output.loss.backward() ```""" if input_ids is not None and input_ids.ndim == 2: merged_inputs = self._merge_input_ids_with_input_values( input_ids, input_values, input_values_cutoffs, labels ) inputs_embeds = merged_inputs["inputs_embeds"] labels = merged_inputs["labels"] input_ids = None backbone_outputs = self.backbone_model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, **kwargs, ) backbone_hidden_states = backbone_outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep backbone_logits = self.lm_head(backbone_hidden_states[:, slice_indices, :]) loss = None backbone_loss = None depth_decoder_loss = None depth_decoder_outputs = None if labels is not None: # select first codebook as labels for the backbone model backbone_labels = labels[:, :, 0] backbone_loss = self.loss_function( logits=backbone_logits, labels=backbone_labels, vocab_size=self.config.vocab_size, **kwargs ) # for the depth decoder, we need to select the frames to train on # those are frames where the label is not uniformly `ignore_index` along the codebook dimension train_mask = ~(labels[:, :, 1:] == -100).all(dim=-1) depth_decoder_input_ids = labels[train_mask][..., : self.config.num_codebooks - 1] # add place holder in position 0 that will be replaced by the backbone_last_hidden_state depth_decoder_input_ids = nn.functional.pad(depth_decoder_input_ids, (1, 0), value=0) train_idxs = train_mask.nonzero(as_tuple=True) backbone_last_hidden_states = backbone_hidden_states[train_idxs[0], train_idxs[1] - 1, :] depth_decoder_labels = labels[train_mask] depth_decoder_outputs = self.depth_decoder( input_ids=depth_decoder_input_ids, backbone_last_hidden_state=backbone_last_hidden_states, use_cache=use_cache, return_dict=True, labels=depth_decoder_labels, **kwargs, ) depth_decoder_loss = depth_decoder_outputs.loss loss = backbone_loss + depth_decoder_loss return CsmOutputWithPast( loss=loss, backbone_loss=backbone_loss, depth_decoder_loss=depth_decoder_loss, logits=backbone_logits, past_key_values=backbone_outputs.past_key_values, hidden_states=backbone_outputs.hidden_states, attentions=backbone_outputs.attentions, depth_decoder_logits=depth_decoder_outputs.logits if depth_decoder_outputs is not None else None, depth_decoder_past_key_values=depth_decoder_outputs.past_key_values if depth_decoder_outputs is not None else None, depth_decoder_hidden_states=depth_decoder_outputs.hidden_states if depth_decoder_outputs is not None else None, depth_decoder_attentions=depth_decoder_outputs.attentions if depth_decoder_outputs is not None else None, ) __all__ = [ "CsmPreTrainedModel", "CsmBackboneModel", "CsmDepthDecoderModel", "CsmDepthDecoderForCausalLM", "CsmForConditionalGeneration", ]