| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724 |
- # Copyright 2023 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch CLVP model."""
- import copy
- import math
- from collections.abc import Callable
- from dataclasses import dataclass
- import torch
- from torch import nn
- from torch.nn import CrossEntropyLoss
- from ... import initialization as init
- from ...activations import ACT2FN, get_activation
- from ...cache_utils import Cache, DynamicCache
- from ...generation import GenerationConfig, GenerationMixin
- from ...masking_utils import create_bidirectional_mask, create_causal_mask
- from ...modeling_outputs import (
- BaseModelOutputWithPastAndCrossAttentions,
- BaseModelOutputWithPooling,
- CausalLMOutputWithCrossAttentions,
- )
- from ...modeling_utils import PreTrainedModel
- from ...processing_utils import Unpack
- from ...pytorch_utils import Conv1D
- from ...utils import (
- ModelOutput,
- TransformersKwargs,
- auto_docstring,
- can_return_tuple,
- logging,
- )
- from ...utils.generic import merge_with_config_defaults
- from ...utils.output_capturing import capture_outputs
- from .configuration_clvp import (
- ClvpConfig,
- ClvpDecoderConfig,
- ClvpEncoderConfig,
- )
# Module-level logger for this file, obtained from the library's logging utilities.
logger = logging.get_logger(__name__)
# Adapted from transformers.models.clip.modeling_clip.contrastive_loss
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    """
    Cross-entropy loss in which row `i` of `logits` should score highest at column `i`.

    Args:
        logits (`torch.Tensor`): similarity scores; `len(logits)` rows, each classified over its columns.

    Returns:
        `torch.Tensor`: the mean cross-entropy over all rows.
    """
    # The "correct" column for each row is simply its own index (matched pairs lie on the diagonal).
    targets = torch.arange(len(logits), device=logits.device)
    return nn.functional.cross_entropy(logits, targets)
# Adapted from transformers.models.clip.modeling_clip.clip_loss with clip->clvp, image_loss->speech_loss
def clvp_loss(similarity: torch.Tensor) -> torch.Tensor:
    """
    Symmetric contrastive loss over a 2-D similarity matrix.

    The loss is computed once treating rows as queries and once on the transposed matrix (columns as queries),
    and the two values are averaged.

    Args:
        similarity (`torch.Tensor`): 2-D matrix of pairwise similarity scores (`.t()` requires 2-D input).

    Returns:
        `torch.Tensor`: the mean of the two directional contrastive losses.
    """
    row_direction = contrastive_loss(similarity)
    column_direction = contrastive_loss(similarity.t())
    total = row_direction + column_direction
    return total / 2.0
# Adapted from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """
    Rotates half the hidden dims of the input.

    The last dimension is split at `x.shape[-1] // 2` into `(first, second)` and the result is
    `concat(-second, first)`. For an odd width the second part is the longer one.
    """
    split_at = x.shape[-1] // 2
    first = x[..., :split_at]
    second = x[..., split_at:]
    return torch.cat((-second, first), dim=-1)
def apply_rotary_pos_emb(q, k, v, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query, key and value tensors.

    CLVP rotates the value tensor in exactly the same way as the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        v (`torch.Tensor`): The value tensor (rotated exactly like `q` and `k`).
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query, key and value tensors. They are used to
            index `cos` and `sin`; for example, this can be used to pass offsetted position ids when working with a
            KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q, k and v. For example,
            if cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim] and q, k, v
            have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes them
            broadcastable. Similarly, if q, k and v have the shape [batch_size, seq_len, heads, head_dim], then set
            unsqueeze_dim=2.

    Returns:
        `tuple(torch.Tensor)` of three tensors: the query, key and value tensors rotated using the Rotary Position
        Embedding.
    """
    # Gather the per-position angles and add a broadcast axis (the heads axis for the default layout).
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    v_embed = (v * cos) + (rotate_half(v) * sin)
    return q_embed, k_embed, v_embed
def _pad_extra_bos_eos_tokens(
    input_ids,
    attention_mask=None,
    pad_token_id=0,
    bos_token_id=255,
    eos_token_id=0,
    add_bos_token=True,
    add_eos_token=True,
):
    """
    This method adds extra bos and eos tokens to input_ids and accordingly modifies the attention_mask which is used in
    `ClvpConditioningEncoder` and the generation loop of the `ClvpModelForConditionalGeneration`.

    Args:
        input_ids (`torch.Tensor`): 2-D batch of token ids, one sequence per row.
        attention_mask (`torch.Tensor`, *optional*): mask aligned with `input_ids`; `None` is passed through.
        pad_token_id (`int`): the first occurrence of this id in a row marks where that row's valid tokens end.
        bos_token_id (`int`): prepended to every row when `add_bos_token` is set.
        eos_token_id (`int`): inserted at the first pad position of each row (or appended when the row has no
            pad) when `add_eos_token` is set.

    Returns:
        `tuple`: the updated `input_ids` and the (possibly extended) `attention_mask`.
    """

    def _grow_mask(mask):
        # Every inserted token counts as a real token, so the mask gains one leading `1`. For right-padded
        # masks this yields the same ones-count as inserting the `1` at the exact insertion point.
        if mask is None:
            return mask
        return torch.nn.functional.pad(mask, (1, 0), value=1)

    if add_bos_token:
        input_ids = torch.nn.functional.pad(input_ids, (1, 0), value=bos_token_id)
        attention_mask = _grow_mask(attention_mask)

    if not add_eos_token:
        return input_ids, attention_mask

    num_rows, row_len = input_ids.shape[0], input_ids.shape[1]
    padded = torch.zeros((num_rows, row_len + 1), dtype=input_ids.dtype, device=input_ids.device)
    eos_piece = torch.tensor([eos_token_id], device=input_ids.device)
    for row_idx in range(num_rows):
        row = input_ids[row_idx]
        pad_hits = torch.where(row == pad_token_id)[0]
        if pad_hits.numel() == 0:
            # No padding in this row: the eos goes at the very end.
            padded[row_idx] = torch.nn.functional.pad(row, (0, 1), value=eos_token_id)
        else:
            # Insert the eos just before the first pad token.
            cut = pad_hits.min()
            padded[row_idx] = torch.concatenate([row[:cut], eos_piece, row[cut:]])
    return padded, _grow_mask(attention_mask)
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for CLVP encoder's outputs that contains a pooling of the last hidden states as well as a projection
    output (a linear layer on top of the pooled output).
    """
)
class ClvpEncoderOutput(ModelOutput):
    r"""
    embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
        The embeddings obtained by applying the projection layer to the pooler_output.
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        The hidden state of the last layer of the model.
    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
        Pooled output of the `last_hidden_state`.
    """

    # NOTE: field order defines the order of this output's tuple form; do not reorder.
    embeds: torch.FloatTensor | None = None
    last_hidden_state: torch.FloatTensor | None = None
    pooler_output: torch.FloatTensor | None = None
    # NOTE(review): `hidden_states` and `attentions` are not described in the docstring above; presumably they are
    # the per-layer outputs collected on request and documented by `auto_docstring` -- confirm.
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
@dataclass
@auto_docstring
class ClvpOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
        Contrastive loss for speech-text similarity.
    speech_ids (`torch.LongTensor`, *optional*):
        speech_ids (or speech candidates) generated by the `ClvpForCausalLM` model.
    logits_per_speech (`torch.FloatTensor` of shape `(speech_batch_size, text_batch_size)`):
        The scaled dot product scores between `speech_embeds` and `text_embeds`. This represents the speech-text
        similarity scores.
    logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, speech_batch_size)`):
        The scaled dot product scores between `text_embeds` and `speech_embeds`. This represents the text-speech
        similarity scores.
    text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
        The text embeddings obtained by applying the projection layer to the pooled output of the text encoder
        model.
    speech_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
        The speech embeddings obtained by applying the projection layer to the pooled output of the speech encoder
        model.
    text_model_output (`BaseModelOutputWithPooling`, *optional*):
        The pooled output of the `last_hidden_state` of the text encoder Model.
    speech_model_output (`BaseModelOutputWithPooling`, *optional*):
        The pooled output of the `last_hidden_state` of the speech encoder Model.
    decoder_hidden_states (`torch.FloatTensor`, *optional*):
        The hidden states of the decoder model.
    text_encoder_hidden_states (`torch.FloatTensor`, *optional*):
        The hidden states of the text encoder model.
    speech_encoder_hidden_states (`torch.FloatTensor`, *optional*):
        The hidden states of the speech encoder model.
    """

    # NOTE: field order defines the order of this output's tuple form; do not reorder.
    loss: torch.FloatTensor | None = None
    speech_ids: torch.LongTensor | None = None
    logits_per_speech: torch.FloatTensor | None = None
    logits_per_text: torch.FloatTensor | None = None
    text_embeds: torch.FloatTensor | None = None
    speech_embeds: torch.FloatTensor | None = None
    # These default to `None`, so their annotations must admit `None` too.
    text_model_output: BaseModelOutputWithPooling | None = None
    speech_model_output: BaseModelOutputWithPooling | None = None
    decoder_hidden_states: torch.FloatTensor | None = None
    text_encoder_hidden_states: torch.FloatTensor | None = None
    speech_encoder_hidden_states: torch.FloatTensor | None = None
# Adapted from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Clvp
class ClvpRMSNorm(nn.Module):
    """
    Root-mean-square normalisation over the last dimension followed by a learned per-channel scale
    (no mean subtraction, no bias). ClvpRMSNorm is equivalent to T5LayerNorm.
    """

    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        original_dtype = hidden_states.dtype
        # Statistics are always computed in float32 for numerical stability.
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
        normalized = as_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class ClvpRotaryPositionalEmbedding(nn.Module):
    """
    Rotary Position Embedding Class for CLVP. It was proposed in the paper 'ROFORMER: ENHANCED TRANSFORMER WITH ROTARY
    POSITION EMBEDDING', Please see https://huggingface.co/papers/2104.09864.

    `forward` returns the rotation angles `position * inv_freq` for every position, duplicated along the last axis,
    as a tensor of shape `(1, sequence_length, 2 * len(inv_freq))`. The result is memoised and reused while the
    sequence length and the device/dtype of `inv_freq` are unchanged.
    """

    def __init__(self, config):
        super().__init__()
        # Rotary width: projection_dim / (2 * num_attention_heads), but never narrower than 32.
        dim = max(config.projection_dim // (config.num_attention_heads * 2), 32)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.cached_sequence_length = None
        self.cached_rotary_positional_embedding = None

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                Only `hidden_states.shape[1]` (the sequence length) and its device are read.

        Returns:
            `torch.FloatTensor` of shape `(1, sequence_length, 2 * len(inv_freq))` with the rotation angles.
        """
        sequence_length = hidden_states.shape[1]
        cached = self.cached_rotary_positional_embedding
        # The memo is a plain attribute, not a buffer, so `Module.to()` / `.half()` never move or cast it. A fresh
        # computation takes its device and dtype from `inv_freq` (via `type_as`), so only reuse the memo while it
        # still matches `inv_freq`; otherwise a stale-device/dtype tensor would be returned forever.
        if (
            cached is not None
            and sequence_length == self.cached_sequence_length
            and cached.device == self.inv_freq.device
            and cached.dtype == self.inv_freq.dtype
        ):
            return cached

        time_stamps = torch.arange(sequence_length, device=hidden_states.device).type_as(self.inv_freq)
        # Outer product: angle[i, j] = position_i * inv_freq_j
        freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
        embeddings = torch.cat((freqs, freqs), dim=-1)

        # Update the memo only after a successful computation so the length and tensor always agree.
        self.cached_sequence_length = sequence_length
        self.cached_rotary_positional_embedding = embeddings.unsqueeze(0)
        return self.cached_rotary_positional_embedding
class ClvpSelfAttention(nn.Module):
    """
    Multi-headed attention to combine Absolute and Rotary Positional Embeddings into a single Attention module.

    When a rotary embedding is passed to `forward`, it is applied to the first `rotary_pos_emb.shape[-1]` channels
    of every head of the query, key and value states; the remaining channels pass through unchanged.
    """

    def __init__(self, config, layer_idx=None):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        # Queries are pre-multiplied by this factor in `forward` (scaled dot-product attention).
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        # Slot of this layer inside `past_key_values` (passed to `Cache.update`).
        self.layer_idx = layer_idx

        # NOTE(review): this lower-triangular boolean buffer is registered whenever the config defines
        # `max_position_embeddings`, but nothing in this class reads it -- confirm whether it is still needed.
        if hasattr(config, "max_position_embeddings"):
            max_positions = config.max_position_embeddings
            bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
            bias = bias.view(1, 1, max_positions, max_positions)
            self.register_buffer("bias", bias, persistent=False)

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_attention_bias)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_attention_bias)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_attention_bias)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # View as (bsz, seq_len, num_heads, head_dim) and move the heads axis to dim 1.
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        rotary_pos_emb: torch.FloatTensor | None = None,
        attention_mask: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = False,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.FloatTensor, torch.FloatTensor | None]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): 3-D input; its first dimension is the batch size.
            rotary_pos_emb (`torch.FloatTensor`, *optional*): rotation angles; requires `position_ids`.
            attention_mask (`torch.Tensor`, *optional*):
                Additive mask of shape `(batch_size, 1, tgt_len, src_len)`; any other shape raises `ValueError`.
            position_ids (`torch.LongTensor`, *optional*): used to index the rotary angles.
            past_key_values (`Cache`, *optional*):
                When given, the new key/value states are passed to `update` and the returned states are used.
            use_cache (`bool`, *optional*): accepted for interface compatibility; not read by this method.

        Returns:
            `(attn_output, attn_weights)`; `attn_weights` are the post-softmax, pre-dropout probabilities.
        """
        # Raise error when position_ids is None but rotary_pos_emb is provided, because we need that when applying
        # rotary_pos_emb to query and key states.
        if rotary_pos_emb is not None and position_ids is None:
            raise ValueError("`position_ids` must be provided when `rotary_pos_emb` is not None.")

        # Only the batch size is needed; the 3-way unpack still rejects non 3-D inputs.
        bsz, _, _ = hidden_states.size()

        # get query proj
        query_states = self._shape(self.q_proj(hidden_states), -1, bsz) * self.scale
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if past_key_values is not None:
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

        if rotary_pos_emb is not None:
            rotary_emb_dim = rotary_pos_emb.shape[-1]

            # Partial rotary embedding
            query_rot, query_pass = (
                query_states[..., :rotary_emb_dim],
                query_states[..., rotary_emb_dim:],
            )
            key_rot, key_pass = (
                key_states[..., :rotary_emb_dim],
                key_states[..., rotary_emb_dim:],
            )
            value_rot, value_pass = (
                value_states[..., :rotary_emb_dim],
                value_states[..., rotary_emb_dim:],
            )

            cos, sin = rotary_pos_emb.cos().squeeze(0), rotary_pos_emb.sin().squeeze(0)
            query_rot, key_rot, value_rot = apply_rotary_pos_emb(query_rot, key_rot, value_rot, cos, sin, position_ids)

            # [batch_size, num_heads, seq_length, head_dim]
            query_states = torch.cat((query_rot, query_pass), dim=-1)
            key_states = torch.cat((key_rot, key_pass), dim=-1)
            value_states = torch.cat((value_rot, value_pass), dim=-1)

        tgt_len = query_states.shape[2]
        src_len = key_states.shape[2]
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Dropout only affects the probabilities used for the weighted sum; the returned weights are pre-dropout.
        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = torch.matmul(attn_probs, value_states)

        if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # Merge heads back: (bsz, num_heads, tgt_len, head_dim) -> (bsz, tgt_len, embed_dim)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
class ClvpGatedLinearUnit(nn.Module):
    """
    Gated linear unit: a single projection to `2 * intermediate_size` channels, split in two halves. The activated
    second half gates the first half element-wise.
    """

    def __init__(self, config):
        super().__init__()
        self.activation_fn = ACT2FN[config.hidden_act]
        self.proj = nn.Linear(config.hidden_size, config.intermediate_size * 2)

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        projected = self.proj(hidden_states)
        value_half, gate_half = projected.chunk(2, dim=-1)
        gate = self.activation_fn(gate_half)
        return value_half * gate
class ClvpEncoderMLP(nn.Module):
    """
    Feed-forward block used in the CLVP speech and text encoders:
    gated projection (`fc1`) -> dropout -> output projection (`fc2`).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.fc1 = ClvpGatedLinearUnit(config)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout_layer = nn.Dropout(config.dropout)

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        gated = self.fc1(hidden_states)
        return self.fc2(self.dropout_layer(gated))
class ClvpEncoderLayer(nn.Module):
    """
    Pre-norm encoder layer:
    RMSNorm -> self-attention (with rotary embedding) -> residual add, then RMSNorm -> MLP -> residual add.
    """

    def __init__(self, config: ClvpConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.self_attn = ClvpSelfAttention(config)
        self.mlp = ClvpEncoderMLP(config)
        self.input_rmsnorm = ClvpRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.post_attention_rmsnorm = ClvpRMSNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        rotary_pos_emb: torch.FloatTensor,
        attention_mask: torch.LongTensor,
        position_ids: torch.LongTensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.FloatTensor:
        # Attention sub-block; the attention weights are not needed here.
        attn_output, _ = self.self_attn(
            self.input_rmsnorm(hidden_states),
            rotary_pos_emb=rotary_pos_emb,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **kwargs,
        )
        hidden_states = hidden_states + attn_output

        # Feed-forward sub-block.
        mlp_output = self.mlp(self.post_attention_rmsnorm(hidden_states))
        return hidden_states + mlp_output
- # Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->Clvp
- class ClvpSequenceSummary(nn.Module):
- r"""
- Compute a single vector summary of a sequence hidden states.
- Args:
- config ([`ClvpConfig`]):
- The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
- config class of your model for the default values it uses):
- - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
- - `"last"` -- Take the last token hidden state (like XLNet)
- - `"first"` -- Take the first token hidden state (like Bert)
- - `"mean"` -- Take the mean of all tokens hidden states
- - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
- - `"attn"` -- Not implemented now, use multi-head attention
- - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
- - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
- (otherwise to `config.hidden_size`).
- - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
- another string or `None` will add no activation.
- - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
- - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
- """
- def __init__(self, config: ClvpConfig):
- super().__init__()
- self.summary_type = getattr(config, "summary_type", "last")
- if self.summary_type == "attn":
- # We should use a standard multi-head attention module with absolute positional embedding for that.
- # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
- # We can probably just use the multi-head attention module of PyTorch >=1.1.0
- raise NotImplementedError
- self.summary = nn.Identity()
- if hasattr(config, "summary_use_proj") and config.summary_use_proj:
- if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
- num_classes = config.num_labels
- else:
- num_classes = config.hidden_size
- self.summary = nn.Linear(config.hidden_size, num_classes)
- activation_string = getattr(config, "summary_activation", None)
- self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()
- self.first_dropout = nn.Identity()
- if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
- self.first_dropout = nn.Dropout(config.summary_first_dropout)
- self.last_dropout = nn.Identity()
- if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
- self.last_dropout = nn.Dropout(config.summary_last_dropout)
- def forward(
- self, hidden_states: torch.FloatTensor, cls_index: torch.LongTensor | None = None
- ) -> torch.FloatTensor:
- """
- Compute a single vector summary of a sequence hidden states.
- Args:
- hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
- The hidden states of the last layer.
- cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
- Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
- Returns:
- `torch.FloatTensor`: The summary of the sequence hidden states.
- """
- if self.summary_type == "last":
- output = hidden_states[:, -1]
- elif self.summary_type == "first":
- output = hidden_states[:, 0]
- elif self.summary_type == "mean":
- output = hidden_states.mean(dim=1)
- elif self.summary_type == "cls_index":
- if cls_index is None:
- cls_index = torch.full_like(
- hidden_states[..., :1, :],
- hidden_states.shape[-2] - 1,
- dtype=torch.long,
- )
- else:
- cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
- cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
- # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
- output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
- elif self.summary_type == "attn":
- raise NotImplementedError
- output = self.first_dropout(output)
- output = self.summary(output)
- output = self.activation(output)
- output = self.last_dropout(output)
- return output
# Adapted from transformers.models.gpt2.modeling_gpt2.GPT2MLP with GPT2->ClvpDecoderMLP
class ClvpDecoderMLP(nn.Module):
    """
    Decoder feed-forward block: `c_fc` (hidden -> intermediate) -> activation -> `c_proj` (intermediate -> hidden)
    -> dropout.
    """

    def __init__(self, intermediate_size, config):
        super().__init__()
        embed_dim = config.hidden_size
        self.c_fc = Conv1D(intermediate_size, embed_dim)
        self.c_proj = Conv1D(embed_dim, intermediate_size)
        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

    # The input is a single tensor fed straight into `c_fc`; it was previously mis-annotated as
    # `tuple[torch.FloatTensor] | None`, neither of which this method can handle.
    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states
class ClvpDecoderLayer(nn.Module):
    """
    Pre-LayerNorm decoder layer:
    LayerNorm -> `ClvpSelfAttention` -> residual add, then LayerNorm -> `ClvpDecoderMLP` -> residual add.
    """

    def __init__(self, config, layer_idx=None):
        super().__init__()
        width = config.hidden_size
        # The MLP's hidden width falls back to 4x the model width when `n_inner` is unset.
        mlp_width = 4 * width if config.n_inner is None else config.n_inner

        self.input_layernorm = nn.LayerNorm(width, eps=config.layer_norm_epsilon)
        self.attn = ClvpSelfAttention(config, layer_idx=layer_idx)
        self.post_attention_layernorm = nn.LayerNorm(width, eps=config.layer_norm_epsilon)
        self.mlp = ClvpDecoderMLP(mlp_width, config)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        past_key_values: Cache | None = None,
        attention_mask: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        use_cache: bool | None = False,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        # Attention sub-block; the attention weights are discarded.
        attn_output, _ = self.attn(
            self.input_layernorm(hidden_states),
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = attn_output + hidden_states

        # Feed-forward sub-block.
        mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_output
class ClvpConditioningEncoder(nn.Module):
    """
    This class processes the log-mel spectrograms (extracted by the Feature Extractor) and text tokens (produced by the
    tokenizer) as inputs for the decoder model.

    First each log-mel spectrogram is processed into a single vector which captures valuable characteristics from each
    of them, then the text tokens are converted into token embeddings and position embeddings are added afterwards.
    Both of these vectors are concatenated and then passed to the decoder model.

    The text tokens help to incorporate the "text information" and the log-mel spectrogram is used to specify the
    "voice characteristics" into the generated mel tokens.
    """

    def __init__(self, config: ClvpConfig):
        super().__init__()

        self.text_config = config.text_config
        self.decoder_config = config.decoder_config

        self.text_token_embedding = nn.Embedding(self.text_config.vocab_size, self.decoder_config.hidden_size)
        self.text_position_embedding = nn.Embedding(
            self.decoder_config.max_text_tokens, self.decoder_config.hidden_size
        )

        self.mel_conv = nn.Conv1d(self.decoder_config.feature_size, self.decoder_config.hidden_size, kernel_size=1)

        # define group norms to be used before each attention layer
        num_groups = self.compute_groupnorm_groups(self.decoder_config.hidden_size)
        self.group_norms = nn.ModuleList(
            [
                nn.GroupNorm(num_groups, self.decoder_config.hidden_size, eps=1e-5, affine=True)
                for _ in range(self.decoder_config.num_mel_attn_blocks)
            ]
        )

        # define the attention layers
        self.mel_attn_blocks = nn.ModuleList(
            [ClvpSelfAttention(self.decoder_config) for _ in range(self.decoder_config.num_mel_attn_blocks)]
        )

        self.gradient_checkpointing = False

    def compute_groupnorm_groups(self, channels: int, groups: int = 32) -> int:
        """
        Calculates the value of `num_groups` for nn.GroupNorm. This logic is taken from the official tortoise
        repository. link :
        https://github.com/neonbjb/tortoise-tts/blob/4003544b6ff4b68c09856e04d3eff9da26d023c2/tortoise/models/arch_util.py#L26

        Starts from `groups` (overridden to 8 for `channels <= 16` and to 16 for `channels <= 64`) and halves it
        until it divides `channels`.

        Raises:
            ValueError: if the resulting number of groups is 2 or fewer.
        """
        if channels <= 16:
            groups = 8
        elif channels <= 64:
            groups = 16
        while channels % groups != 0:
            groups //= 2

        if groups <= 2:
            raise ValueError(
                f"Number of groups for the GroupNorm must be greater than 2, but it is {groups}. "
                f"Please consider using a different `hidden_size`"
            )

        return groups

    def _checkpoint_or_call(self, function, *args):
        """Run `function(*args)`, recomputing it in the backward pass when gradient checkpointing is active."""
        if self.gradient_checkpointing and self.training:
            # Non-reentrant checkpointing still propagates gradients to the parameters of `function` when none of
            # `args` require grad (e.g. raw `input_features` fed into `mel_conv`); the reentrant variant does not.
            return torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=False)
        return function(*args)

    def forward(
        self,
        input_features: torch.FloatTensor,
        input_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        attention_mask: torch.LongTensor | None = None,
    ):
        """
        Build the conditioning sequence for the decoder: one vector summarising the audio, followed by the text
        embeddings.

        Args:
            input_features (`torch.FloatTensor`):
                Log-mel spectrograms fed to `mel_conv` (a `Conv1d` over `decoder_config.feature_size` channels);
                presumably of shape `(batch_size, feature_size, num_frames)` — TODO confirm.
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Text token ids. Extra bos/eos tokens are added here before embedding.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Pre-computed text token embeddings, used as-is in place of `input_ids`. No bos/eos tokens are
                added on this path, so include their embeddings yourself if needed.
            attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Text padding mask; defaults to all ones. Used to derive the text position ids.

        Returns:
            `torch.FloatTensor`: the audio vector (sequence position 0) concatenated along dim 1 with the text
            embeddings. A batch of size 1 on either side is repeated to match the other side.

        Raises:
            ValueError: if both or neither of `input_ids`/`inputs_embeds` are given, or if the text and audio batch
            sizes differ and neither is 1.
        """
        # process text
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.size()
            device = input_ids.device
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.size()[:-1]
            device = inputs_embeds.device
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # construct attention mask if not given
        if attention_mask is None:
            attention_mask = torch.ones([batch_size, seq_length], dtype=torch.long, device=device)

        if input_ids is not None:
            # We add bos and eos input_ids in the modeling file instead of the tokenizer file to keep the logic simple
            # This logic is specific to ClvpConditioningEncoder and not used by other modules.
            input_ids, attention_mask = _pad_extra_bos_eos_tokens(
                input_ids,
                attention_mask,
                bos_token_id=self.text_config.bos_token_id,
                eos_token_id=self.text_config.eos_token_id,
            )
            inputs_embeds = self.text_token_embedding(input_ids)

        # Positions count only attended tokens. Clamp so that leading padding (mask 0) maps to index 0 instead of
        # the invalid embedding index -1.
        position_ids = (attention_mask.cumsum(-1) - 1).clamp(min=0)
        position_embeds = self.text_position_embedding(position_ids)
        text_embeds = inputs_embeds + position_embeds

        # process each log-mel spectrogram into a single vector
        mel_spec = self._checkpoint_or_call(self.mel_conv, input_features)
        for group_norm, mel_attn_block in zip(self.group_norms, self.mel_attn_blocks):
            # GroupNorm works on (batch, channels, time); the attention block on (batch, time, channels).
            residual_mel_spec = mel_spec.transpose(1, 2)
            mel_spec = self._checkpoint_or_call(group_norm, mel_spec).transpose(1, 2)
            mel_spec = self._checkpoint_or_call(mel_attn_block, mel_spec)[0] + residual_mel_spec
            mel_spec = mel_spec.transpose(1, 2)

        # keep only the first time step as the summary vector, as a length-1 sequence
        mel_spec = mel_spec[:, :, 0]
        mel_spec = mel_spec.unsqueeze(1)

        # repeat if there is either (1 text vs N audios) or (N texts vs 1 audio)
        if text_embeds.shape[0] == 1 and mel_spec.shape[0] != 1:
            text_embeds = text_embeds.repeat(mel_spec.shape[0], 1, 1)
        elif text_embeds.shape[0] != 1 and mel_spec.shape[0] == 1:
            mel_spec = mel_spec.repeat(text_embeds.shape[0], 1, 1)
        # If there are N texts and M audios we raise an error since the number of texts and audios must be the same.
        elif text_embeds.shape[0] != mel_spec.shape[0]:
            raise ValueError(
                f"The number of texts and number of audios must be same. "
                f"Found {text_embeds.shape[0]} texts vs {mel_spec.shape[0]} audios"
            )

        return torch.concat([mel_spec, text_embeds], dim=1)
@auto_docstring
class ClvpPreTrainedModel(PreTrainedModel):
    config: ClvpConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    # Output name -> module class(es) whose outputs are recorded under that name (presumably consumed by
    # `capture_outputs`).
    _can_record_outputs = {
        "hidden_states": (ClvpEncoderLayer, ClvpDecoderLayer),
        "attentions": ClvpSelfAttention,
    }

    @torch.no_grad()
    def _init_weights(self, module: nn.Module):
        """
        Initialize the weights of `module` according to its type.

        Embedding, linear and 1D-convolution weights are drawn from a normal distribution scaled by
        `config.initializer_factor`; norms get unit weights and zero biases; several CLVP-specific modules get the
        dedicated initializations below.
        """
        factor = self.config.initializer_factor
        if isinstance(module, nn.Embedding):
            init.normal_(module.weight, mean=0.0, std=factor * 0.02)
        elif isinstance(module, (nn.Linear, Conv1D, nn.Conv1d)):
            init.normal_(module.weight, mean=0.0, std=factor * 0.02)
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, ClvpRMSNorm):
            init.ones_(module.weight)
        elif isinstance(module, ClvpEncoderMLP):
            in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
            fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
            # `fc1` either wraps its linear layer in a `proj` attribute or is the linear layer itself; pass a
            # default to `getattr` so the plain-layer case does not raise AttributeError.
            fc1_proj = getattr(module.fc1, "proj", None)
            init.normal_(fc1_proj.weight if fc1_proj is not None else module.fc1.weight, std=fc_std)
            init.normal_(module.fc2.weight, std=in_proj_std)
        elif isinstance(module, ClvpEncoder):
            config = self.config.get_text_config()
            factor = config.initializer_factor
            init.normal_(module.projection.weight, mean=0.0, std=factor * (config.hidden_size**-0.5))
        elif isinstance(module, ClvpConditioningEncoder):
            init.normal_(module.mel_conv.weight, mean=0.0, std=factor)
            init.zeros_(module.mel_conv.bias)
        elif isinstance(module, ClvpForCausalLM):
            # Scaled init for the `c_proj` output projections (std shrinks with depth). `named_parameters()` yields
            # fully-qualified names such as "model.decoder.layers.0.<...>.c_proj.weight", so match on the suffix;
            # an exact comparison with "c_proj.weight" never matches here. Use the decoder's own config, which
            # carries `initializer_range` and `num_hidden_layers`.
            decoder_config = module.config
            c_proj_std = decoder_config.initializer_range / math.sqrt(2 * decoder_config.num_hidden_layers)
            for name, p in module.named_parameters():
                if name == "c_proj.weight" or name.endswith(".c_proj.weight"):
                    init.normal_(p, mean=0.0, std=c_proj_std)
        elif isinstance(module, ClvpModelForConditionalGeneration):
            init.constant_(module.logit_scale, self.config.logit_scale_init_value)
        elif isinstance(module, ClvpSelfAttention):
            if hasattr(module.config, "max_position_embeddings"):
                # lower-triangular boolean causal mask buffer of shape (1, 1, max_positions, max_positions)
                max_positions = module.config.max_position_embeddings
                bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
                bias = bias.view(1, 1, max_positions, max_positions)
                init.copy_(module.bias, bias)
        elif isinstance(module, ClvpRotaryPositionalEmbedding):
            dim = max(self.config.projection_dim // (self.config.num_attention_heads * 2), 32)
            inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
            init.copy_(module.inv_freq, inv_freq)

        if isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            # Non-affine norms have no weight/bias parameters.
            if module.bias is not None:
                init.zeros_(module.bias)
            if module.weight is not None:
                init.ones_(module.weight)
class ClvpEncoder(ClvpPreTrainedModel):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`ClvpEncoderLayer`]. The final hidden states are layer-normed, pooled by `ClvpSequenceSummary`, and projected
    to `config.projection_dim` without bias.

    Args:
        config: ClvpEncoderConfig
    """

    config: ClvpEncoderConfig

    def __init__(self, config: ClvpEncoderConfig):
        super().__init__(config)

        self.config = config
        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.rotary_pos_emb = ClvpRotaryPositionalEmbedding(config) if config.use_rotary_embedding else None
        self.layers = nn.ModuleList([ClvpEncoderLayer(config) for _ in range(config.num_hidden_layers)])

        self.sequence_summary = ClvpSequenceSummary(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.token_embedding

    def set_input_embeddings(self, value):
        self.token_embedding = value

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        attention_mask: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> ClvpEncoderOutput:
        # Exactly one of `input_ids` / `inputs_embeds` must be given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        # expand the 2D padding mask into the attention mask consumed by the layers
        attention_mask = create_bidirectional_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )

        if position_ids is None:
            # `inputs_embeds` is always populated at this point, and is on the same device as `input_ids` when
            # it was produced from them.
            position_ids = torch.arange(inputs_embeds.shape[1], dtype=torch.long, device=inputs_embeds.device)
            position_ids = position_ids.unsqueeze(0)

        rotary_pos_emb = self.rotary_pos_emb(inputs_embeds) if self.rotary_pos_emb is not None else None

        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(
                hidden_states,
                rotary_pos_emb,
                attention_mask,
                position_ids,
                **kwargs,
            )

        last_hidden_state = self.final_layer_norm(hidden_states)

        # pool the sequence into a single vector per example
        pooled_output = self.sequence_summary(last_hidden_state)

        # apply the projection layer
        embeds = self.projection(pooled_output)

        return ClvpEncoderOutput(
            embeds=embeds,
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
        )
class ClvpDecoder(ClvpPreTrainedModel):
    """
    CLVP decoder: a stack of `config.num_hidden_layers` [`ClvpDecoderLayer`] blocks applied to token embeddings plus
    learned absolute position embeddings, followed by a final `nn.LayerNorm`.
    """

    config: ClvpDecoderConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        cfg = self.config

        # Sub-modules are created in a fixed order; keep it so parameter registration order stays the same.
        self.input_embeds_layer = nn.Embedding(cfg.vocab_size, cfg.hidden_size)
        self.position_embeds_layer = nn.Embedding(cfg.max_position_embeddings, cfg.hidden_size)
        self.drop = nn.Dropout(cfg.embd_pdrop)
        self.layers = nn.ModuleList(
            [ClvpDecoderLayer(cfg, layer_idx=idx) for idx in range(cfg.num_hidden_layers)]
        )
        self.layer_norm = nn.LayerNorm(cfg.hidden_size, eps=cfg.layer_norm_epsilon)

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.input_embeds_layer

    def set_input_embeddings(self, new_embeddings):
        self.input_embeds_layer = new_embeddings

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPastAndCrossAttentions:
        # Exactly one of the two inputs may be supplied: both-missing and both-present are rejected alike.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        embeds = self.input_embeds_layer(input_ids) if inputs_embeds is None else inputs_embeds
        seq_len = embeds.shape[1]

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, seq_len)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if position_ids is None:
            # Number new tokens after those already held in the cache.
            start = 0 if past_key_values is None else past_key_values.get_seq_length()
            position_ids = (torch.arange(seq_len, device=embeds.device) + start).unsqueeze(0)

        embeds = embeds + self.position_embeds_layer(position_ids)

        causal_mask = create_causal_mask(
            config=self.config,
            inputs_embeds=embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
        )

        hidden_states = embeds
        if token_type_ids is not None:
            # token types share the token embedding table
            hidden_states = hidden_states + self.input_embeds_layer(token_type_ids)
        hidden_states = self.drop(hidden_states)

        output_shape = (-1, seq_len, hidden_states.size(-1))

        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                hidden_states,
                past_key_values=past_key_values,
                attention_mask=causal_mask,
                position_ids=position_ids,
                use_cache=use_cache,
                **kwargs,
            )

        hidden_states = self.layer_norm(hidden_states).view(output_shape)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )
@auto_docstring
class ClvpModel(ClvpPreTrainedModel):
    config: ClvpDecoderConfig

    def __init__(self, config: ClvpDecoderConfig):
        super().__init__(config)
        self.config = config
        # thin wrapper: all computation happens in the decoder stack
        self.decoder = ClvpDecoder(self.config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # delegates to the decoder, which returns its `input_embeds_layer`
        return self.decoder.get_input_embeddings()

    def set_input_embeddings(self, value):
        # delegates to the decoder, which replaces its `input_embeds_layer`
        self.decoder.set_input_embeddings(value)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPastAndCrossAttentions:
        dec_out: BaseModelOutputWithPastAndCrossAttentions = self.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )

        # Re-wrap the decoder's fields into a fresh output object.
        last_hidden_state = dec_out.last_hidden_state
        cache = dec_out.past_key_values
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=last_hidden_state,
            past_key_values=cache,
            hidden_states=dec_out.hidden_states,
            attentions=dec_out.attentions,
            cross_attentions=dec_out.cross_attentions,
        )
@auto_docstring(
    custom_intro="""
    The CLVP decoder model with a language modelling head on top.
    """
)
class ClvpForCausalLM(ClvpPreTrainedModel, GenerationMixin):
    config: ClvpDecoderConfig

    def __init__(self, config):
        super().__init__(config)

        self.config = config
        self.model = ClvpModel(self.config)

        self.final_norm = nn.LayerNorm(self.config.hidden_size)
        self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=True)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        # NOTE(review): returning None presumably opts the LM head out of input/output embedding tying — confirm.
        return None

    def get_input_embeddings(self):
        return self.model.decoder.input_embeds_layer

    def set_input_embeddings(self, new_embeddings):
        self.model.decoder.input_embeds_layer = new_embeddings

    def _prepare_model_inputs(
        self,
        inputs: torch.Tensor | None,
        bos_token_id: int | None,
        model_kwargs: dict[str, torch.Tensor],
    ) -> tuple[torch.Tensor, str | None, dict[str, torch.Tensor]]:
        """
        This function extracts the model-specific `inputs` for generation.

        When `model_kwargs` contains `conditioning_embeds`, generation is driven by those embeddings instead. A "mel
        start" token is appended to them: the embedding of `bos_token_id` plus the position-0 embedding. The
        result is stored as `model_kwargs["inputs_embeds"]`, and `model_kwargs["input_ids"]` becomes a single
        `bos_token_id` column.

        Returns:
            A tuple `(inputs, input_name, model_kwargs)` where `model_kwargs` has `None` values removed.

        Raises:
            ValueError: if `inputs` is passed together with `model_kwargs[self.main_input_name]`.
        """
        input_name = self.main_input_name

        model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None}

        inputs_kwarg = model_kwargs.pop(input_name, None)
        if inputs_kwarg is not None and inputs is not None:
            raise ValueError(
                f"`inputs`: {inputs} were passed alongside `{input_name}`, which is not allowed. "
                f"Make sure to either pass `inputs` or `{input_name}=...`."
            )
        elif inputs_kwarg is not None:
            inputs = inputs_kwarg

        if input_name == "input_ids" and "inputs_embeds" in model_kwargs:
            model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
                inputs, bos_token_id, model_kwargs=model_kwargs
            )
            inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"

        conditioning_embeds = model_kwargs.get("conditioning_embeds")

        if conditioning_embeds is not None:
            batch_size = conditioning_embeds.shape[0]
            device = conditioning_embeds.device
            decoder = self.model.decoder

            # "mel start" token = embedding(bos_token_id) + position embedding of index 0
            mel_start_token_embedding = decoder.input_embeds_layer(
                torch.full((batch_size, 1), fill_value=self.config.bos_token_id, device=device)
            ) + decoder.position_embeds_layer(torch.full((batch_size, 1), fill_value=0, device=device))
            conditioning_embeds = torch.concat([conditioning_embeds, mel_start_token_embedding], dim=1)

            # The decoder adds position embeddings to `inputs_embeds` in its forward pass (for `arange(seq_len)`
            # when no `position_ids` are supplied). Subtract those same embeddings here so they cancel out.
            position_ids = torch.arange(0, conditioning_embeds.shape[1], dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).repeat(batch_size, 1)
            model_kwargs["inputs_embeds"] = conditioning_embeds - decoder.position_embeds_layer(position_ids)

            model_kwargs["input_ids"] = torch.full(
                (batch_size, 1),
                fill_value=self.config.bos_token_id,
                dtype=torch.long,
                device=model_kwargs["inputs_embeds"].device,
            )

            return model_kwargs["inputs_embeds"], "inputs_embeds", model_kwargs

        inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
        return inputs, input_name, model_kwargs

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        conditioning_embeds=None,
        is_first_iteration=False,
        **kwargs,
    ):
        """
        Delegate to `GenerationMixin.prepare_inputs_for_generation`. After the first iteration, when
        `conditioning_embeds` are in use, override `position_ids` with a one-element tensor holding the current
        `input_ids` length.
        """
        input_ids_length = input_ids.shape[-1]

        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            is_first_iteration=is_first_iteration,
            **kwargs,
        )

        if conditioning_embeds is not None and not is_first_iteration:
            model_inputs["position_ids"] = torch.tensor([input_ids_length], dtype=torch.long, device=input_ids.device)

        return model_inputs

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | CausalLMOutputWithCrossAttentions:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
        outputs: BaseModelOutputWithPastAndCrossAttentions = self.model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state

        # final norm, then vocabulary projection
        lm_logits = self.lm_head(self.final_norm(hidden_states))

        loss = None
        if labels is not None:
            labels = labels.to(lm_logits.device)
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )
- @auto_docstring(
- custom_intro="""
- The composite CLVP model with a text encoder, speech encoder and speech decoder model.
- """
- )
- class ClvpModelForConditionalGeneration(ClvpPreTrainedModel, GenerationMixin):
def __init__(self, config: ClvpConfig):
    super().__init__(config)

    # Validate every sub-config up front, in a fixed order (text, speech, decoder), before building modules.
    expected_sub_configs = (
        ("text_config", ClvpEncoderConfig),
        ("speech_config", ClvpEncoderConfig),
        ("decoder_config", ClvpDecoderConfig),
    )
    for attr_name, expected_cls in expected_sub_configs:
        sub_config = getattr(config, attr_name)
        if not isinstance(sub_config, expected_cls):
            raise TypeError(
                f"config.{attr_name} is expected to be of type `{expected_cls.__name__}` but is of type"
                f" {type(sub_config)}."
            )

    # Sub-modules, created in the same order as before so parameter registration order is unchanged.
    self.conditioning_encoder = ClvpConditioningEncoder(config)
    self.speech_decoder_model = ClvpForCausalLM(config.decoder_config)
    self.text_encoder_model = ClvpEncoder(config.text_config)
    self.speech_encoder_model = ClvpEncoder(config.speech_config)

    # learnable temperature for the text/speech similarity logits
    initial_scale = torch.tensor(self.config.logit_scale_init_value)
    self.logit_scale = nn.Parameter(initial_scale)

    # Initialize weights and apply final processing
    self.post_init()
# taken from the original repo,
# link : https://github.com/neonbjb/tortoise-tts/blob/4003544b6ff4b68c09856e04d3eff9da26d023c2/tortoise/api.py#L117
def fix_speech_decoder_output(self, speech_ids: torch.LongTensor) -> torch.LongTensor:
    """
    Post-process the decoder's raw output ids before they are passed to the speech encoder.

    The first column is dropped. In every row that contains the decoder's `eos_token_id`, each position from the
    first stop token onwards is replaced by `decoder_fixing_codes[0]`. The last three positions of that row are then
    overwritten with `decoder_fixing_codes[1:]`, which is expected to hold three codes. Rows without a stop token
    pass through unchanged apart from the dropped first column. The input tensor itself is not modified.

    Args:
        speech_ids (`torch.LongTensor`):
            This refers to the output of the decoder model, indexed as `(batch, sequence)`.
    """
    fixing_codes = self.config.decoder_config.decoder_fixing_codes
    stop_id = self.speech_decoder_model.config.eos_token_id
    fill_code = fixing_codes[0]

    trimmed = speech_ids[:, 1:]
    is_stop = trimmed == stop_id
    # out-of-place fill: `fixed` is a new tensor, so the writes below never touch `speech_ids`
    fixed = trimmed.masked_fill(is_stop, fill_code)

    for row, row_is_stop in enumerate(is_stop):
        # A row without a stop token was still being generated: leave it as is.
        if not row_is_stop.any():
            continue
        first_stop = int(row_is_stop.nonzero()[0])
        fixed[row, first_stop:] = fill_code
        # Mirrors the reference implementation; `first_stop` indexes into the row, so this always holds.
        if first_stop - 3 < fixed.shape[1]:
            fixed[row, -3:] = torch.tensor([fixing_codes[1:]], device=fixed.device, dtype=torch.long)

    return fixed
@can_return_tuple
@auto_docstring(
    custom_intro="""
    This method can be used to extract text_embeds from a text. The text embeddings obtained by applying the
    projection layer to the pooled output of the CLVP text encoder model.
    """
)
def get_text_features(
    self,
    input_ids: torch.LongTensor | None = None,
    text_encoder_inputs_embeds: torch.FloatTensor | None = None,
    attention_mask: torch.LongTensor | None = None,
    **kwargs: Unpack[TransformersKwargs],
) -> tuple | ClvpEncoderOutput:
    r"""
    text_encoder_inputs_embeds (`torch.FloatTensor`, *optional*):
        inputs_embeds for the text encoder model passed in place of `input_ids`.

    Examples:

    ```python
    >>> from transformers import ClvpProcessor, ClvpModelForConditionalGeneration

    >>> # Define the Text
    >>> text = "This is an example text."

    >>> # Define processor and model
    >>> processor = ClvpProcessor.from_pretrained("susnato/clvp_dev")
    >>> model = ClvpModelForConditionalGeneration.from_pretrained("susnato/clvp_dev")

    >>> # Generate processor output and text embeds
    >>> processor_output = processor(text=text, return_tensors="pt")
    >>> text_embeds = model.get_text_features(input_ids=processor_output["input_ids"])
    ```
    """
    # Map this method's argument names onto the text encoder's own signature and forward everything else.
    encoder_inputs = {
        "input_ids": input_ids,
        "inputs_embeds": text_encoder_inputs_embeds,
        "attention_mask": attention_mask,
    }
    return self.text_encoder_model(**encoder_inputs, **kwargs)
def get_speech_features(
    self,
    speech_ids: torch.LongTensor | None = None,
    input_ids: torch.LongTensor | None = None,
    input_features: torch.FloatTensor | None = None,
    conditioning_encoder_inputs_embeds: torch.FloatTensor | None = None,
    attention_mask: torch.Tensor | None = None,
    generation_config: GenerationConfig | None = None,
    **kwargs,
) -> torch.FloatTensor:
    r"""
    This method can be used to extract speech_embeds. The speech embeddings are obtained by applying the speech
    model on speech_ids. If speech_ids is not present but both input_ids and input_features are given then the
    decoder model will be used to first generate the speech_ids and then applying the speech model.

    Args:
        speech_ids (`torch.LongTensor` of shape `(batch_size, num_speech_ids)`, *optional*):
            Speech Tokens. Padding will be ignored by default should you provide it. If speech_ids are provided
            then input_ids and input_features will be automatically ignored.
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Input text Tokens. Processed from the [`ClvpTokenizer`]. If speech_ids is not provided, then input_ids
            and input_features will be used.
        conditioning_encoder_inputs_embeds (`torch.FloatTensor`, *optional*):
            inputs_embeds for `ClvpConditioningEncoder`. Can be used in place of `input_ids`.
        attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Padding mask. When `speech_ids` are provided it masks padding speech token indices. Otherwise it
            masks the text tokens given to the conditioning encoder, and it is not forwarded to the speech encoder,
            because it does not line up with the generated speech ids. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        generation_config (`GenerationConfig`, *optional*):
            generation config to control the generation of speech_ids if they are not provided. A copy is used,
            so neither this object nor `self.generation_config` is modified.
        kwargs:
            Overrides applied to the (copied) generation config when speech_ids have to be generated.

    Returns:
        `torch.FloatTensor` of shape `(batch_size, output_dim)`:
            The speech embeddings obtained by applying the projection layer to the pooled output of the CLVP Speech
            Model.

    Raises:
        ValueError: if `speech_ids` is missing and neither `input_ids` nor `conditioning_encoder_inputs_embeds` is
            given, or `input_features` is missing.

    Examples:

    ```python
    >>> import datasets
    >>> from transformers import ClvpProcessor, ClvpModelForConditionalGeneration

    >>> # Define the Text and Load the Audio (We are taking an audio example from HuggingFace Hub using `datasets` library)
    >>> text = "This is an example text."
    >>> ds = datasets.load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    >>> ds = ds.cast_column("audio", datasets.Audio(sampling_rate=22050))
    >>> audio = ds.sort("id")["audio"][0]
    >>> audio_sample, sr = audio["array"], audio["sampling_rate"]

    >>> # Define processor and model
    >>> processor = ClvpProcessor.from_pretrained("susnato/clvp_dev")
    >>> model = ClvpModelForConditionalGeneration.from_pretrained("susnato/clvp_dev")

    >>> # Generate processor output and model output
    >>> processor_output = processor(raw_speech=audio_sample, sampling_rate=sr, text=text, return_tensors="pt")
    >>> speech_embeds = model.get_speech_features(
    ...     input_ids=processor_output["input_ids"], input_features=processor_output["input_features"]
    ... )
    ```
    """
    import copy

    speech_attention_mask = attention_mask

    if speech_ids is None:
        if (input_ids is None and conditioning_encoder_inputs_embeds is None) or input_features is None:
            raise ValueError(
                "Either speech_ids or input_ids/conditioning_encoder_inputs_embeds and input_features must be provided."
            )

        # Work on a copy so that `**kwargs` overrides do not leak into `self.generation_config` (or the caller's
        # config) and persist across calls.
        base_config = self.generation_config if generation_config is None else generation_config
        generation_config = copy.deepcopy(base_config)
        generation_config.update(**kwargs)

        conditioning_embeds = self.conditioning_encoder(
            input_features=input_features,
            input_ids=input_ids,
            inputs_embeds=conditioning_encoder_inputs_embeds,
            attention_mask=attention_mask,
        )

        generated = self.speech_decoder_model.generate(
            conditioning_embeds=conditioning_embeds,
            generation_config=generation_config,
        )
        # `generate` may return the sequences tensor directly; otherwise, presumably, an output object whose first
        # entry holds the sequences (the previous code always indexed `[0]`, which turns a plain tensor into a
        # single 1-D row).
        sequences = generated if isinstance(generated, torch.Tensor) else generated[0]
        speech_ids = self.fix_speech_decoder_output(sequences)

        # `attention_mask` described the text tokens here, not the freshly generated speech ids.
        speech_attention_mask = None

    outputs = self.speech_encoder_model(
        input_ids=speech_ids,
        attention_mask=speech_attention_mask,
    )

    return outputs[0]
@can_return_tuple
@auto_docstring
def forward(
    self,
    input_ids: torch.LongTensor | None = None,
    input_features: torch.FloatTensor | None = None,
    conditioning_encoder_inputs_embeds: torch.FloatTensor | None = None,
    text_encoder_inputs_embeds: torch.FloatTensor | None = None,
    attention_mask: torch.LongTensor | None = None,
    return_loss: bool | None = None,
    **kwargs: Unpack[TransformersKwargs],
) -> tuple | ClvpOutput:
    r"""
    conditioning_encoder_inputs_embeds (`torch.FloatTensor`, *optional*):
        inputs_embeds for `ClvpConditioningEncoder`. Can be used in place of `input_ids`.
    text_encoder_inputs_embeds (`torch.FloatTensor`, *optional*):
        inputs_embeds for the text encoder model passed in place of `input_ids`.
    return_loss (`bool`, *optional*):
        Whether or not to return the contrastive loss.

    Examples:

    ```python
    >>> import datasets
    >>> from transformers import ClvpProcessor, ClvpModelForConditionalGeneration

    >>> # Define the Text and Load the Audio (We are taking an audio example from HuggingFace Hub using `datasets` library)
    >>> text = "This is an example text."

    >>> ds = datasets.load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    >>> ds = ds.cast_column("audio", datasets.Audio(sampling_rate=22050))
    >>> audio = ds.sort("id")["audio"][0]
    >>> audio_sample, sr = audio["array"], audio["sampling_rate"]

    >>> # Define processor and model
    >>> processor = ClvpProcessor.from_pretrained("susnato/clvp_dev")
    >>> model = ClvpModelForConditionalGeneration.from_pretrained("susnato/clvp_dev")

    >>> # processor outputs and model outputs
    >>> processor_output = processor(raw_speech=audio_sample, sampling_rate=sr, text=text, return_tensors="pt")
    >>> outputs = model(
    ...     input_ids=processor_output["input_ids"],
    ...     input_features=processor_output["input_features"],
    ...     return_dict=True,
    ... )
    ```
    """
    # 1) audio + text -> conditioning sequence -> decoder logits
    conditioning_embeds = self.conditioning_encoder(
        input_features=input_features,
        input_ids=input_ids,
        inputs_embeds=conditioning_encoder_inputs_embeds,
        attention_mask=attention_mask,
    )
    decoder_outputs: CausalLMOutputWithCrossAttentions = self.speech_decoder_model(
        inputs_embeds=conditioning_embeds,
        **kwargs,
    )

    # 2) The speech encoder consumes token ids, so collapse 3-D logits (batch, seq, vocab) to ids by argmax
    #    over the vocabulary axis, then apply the stop-token fix-up.
    predicted = decoder_outputs.logits
    if predicted.ndim == 3:
        predicted = predicted.argmax(2)
    speech_ids = self.fix_speech_decoder_output(predicted)

    # 3) encode both modalities (speech first, then text)
    speech_outputs: ClvpEncoderOutput = self.speech_encoder_model(
        input_ids=speech_ids,
        **kwargs,
    )
    text_outputs: ClvpEncoderOutput = self.text_encoder_model(
        input_ids=input_ids,
        inputs_embeds=text_encoder_inputs_embeds,
        attention_mask=attention_mask,
        **kwargs,
    )

    # 4) L2-normalise and score with a learned temperature (cosine similarity logits)
    speech_embeds = speech_outputs.embeds
    speech_embeds = speech_embeds / speech_embeds.norm(p=2, dim=-1, keepdim=True)
    text_embeds = text_outputs.embeds
    text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

    temperature = self.logit_scale.exp()
    logits_per_text = (text_embeds @ speech_embeds.t()) * temperature
    logits_per_speech = logits_per_text.t()

    loss = clvp_loss(logits_per_text) if return_loss else None

    return ClvpOutput(
        loss=loss,
        logits_per_speech=logits_per_speech,
        logits_per_text=logits_per_text,
        text_embeds=text_embeds,
        speech_embeds=speech_embeds,
        text_model_output=text_outputs.pooler_output,
        speech_model_output=speech_outputs.pooler_output,
        decoder_hidden_states=decoder_outputs.hidden_states,
        text_encoder_hidden_states=text_outputs.hidden_states,
        speech_encoder_hidden_states=speech_outputs.hidden_states,
    )
@torch.no_grad()
def generate(
    self,
    input_ids: torch.LongTensor | None = None,
    input_features: torch.FloatTensor | None = None,
    attention_mask: torch.LongTensor | None = None,
    generation_config: GenerationConfig | None = None,
    pad_to_max_mel_tokens: int | None = None,
    output_hidden_states: bool | None = None,
    **kwargs,
):
    """
    Generate method for `ClvpModelForConditionalGeneration`. It calls the `generate` method of `ClvpForCausalLM`
    and then feeds the generated `speech_ids` and the text tokens through the `ClvpEncoder` models to obtain
    `speech_embeds` and `text_embeds`, whose scaled cosine similarities are returned as logits.

    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Input text tokens, processed from the [`ClvpTokenizer`]. Required despite the `None` default. The
            length must not exceed `config.decoder_config.max_text_tokens - 3`.
        input_features (`torch.FloatTensor`, *optional*):
            Audio features passed to `self.conditioning_encoder` together with the text tokens (presumably the
            log-mel output of [`ClvpFeatureExtractor`] — confirm against the conditioning encoder).
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding text token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        generation_config (`~generation.GenerationConfig`, *optional*):
            The generation configuration to be used as base parametrization for the generation call. `**kwargs`
            passed to generate matching the attributes of `generation_config` will override them. If
            `generation_config` is not provided, `self.generation_config` is used. The config is deep-copied, so
            the caller's object is never modified.
        pad_to_max_mel_tokens (`int`, *optional*):
            Pads generated speech_ids to the specified length with the EOS token. This replicates the logic of
            the official repo, link: https://github.com/neonbjb/tortoise-tts/blob/80f89987a5abda5e2b082618cd74f9c7411141dc/tortoise/api.py#L430
            and makes sure the logits are the same. It does not affect generation quality and is less efficient,
            so it is only useful when reproducing the reference implementation exactly.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of the decoder model and the text and speech encoders.
        kwargs:
            Overrides for matching `generation_config` attributes. Any remaining entries are validated as model
            kwargs via `self._validate_model_kwargs`.

    Returns:
        `ClvpOutput` if `generation_config.return_dict_in_generate` is `True`. Otherwise a tuple
        `(speech_ids, logits_per_speech, logits_per_text, text_embeds, speech_embeds, text_model_output,
        speech_model_output)`. When `output_hidden_states` is set, the tuple is extended with
        `(decoder_hidden_states, text_encoder_hidden_states, speech_encoder_hidden_states)`.

    Raises:
        ValueError: If `input_ids` is missing or longer than `config.decoder_config.max_text_tokens - 3`.
    """
    if input_ids is None:
        raise ValueError("`input_ids` must be provided to `ClvpModelForConditionalGeneration.generate`.")

    # 3 tokens (1 bos token and 2 eos tokens) are added to the input_ids in ClvpConditioningEncoder to properly
    # sample, so the raw text must leave room for them.
    max_input_length = self.config.decoder_config.max_text_tokens - 3
    sequence_length = input_ids.shape[-1]
    if sequence_length > max_input_length:
        raise ValueError(
            f"Maximum sequence length reached! Found input_ids of length {sequence_length}. "
            f"Please make sure that the maximum length of input_ids is {max_input_length}"
        )

    if generation_config is None:
        generation_config = self.generation_config
    # Work on a copy so per-call kwargs never mutate the stored/passed generation config.
    generation_config = copy.deepcopy(generation_config)
    model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
    generation_config.validate()
    self._validate_model_kwargs(model_kwargs.copy())
    return_dict = generation_config.return_dict_in_generate

    # pad input_ids as specified in the original repo
    # link: https://github.com/neonbjb/tortoise-tts/blob/80f89987a5abda5e2b082618cd74f9c7411141dc/tortoise/api.py#L380
    input_ids, attention_mask = _pad_extra_bos_eos_tokens(
        input_ids,
        attention_mask,
        add_bos_token=False,
        bos_token_id=self.config.text_config.bos_token_id,
        eos_token_id=self.config.text_config.eos_token_id,
    )

    conditioning_embeds = self.conditioning_encoder(
        input_features=input_features,
        input_ids=input_ids,
        attention_mask=attention_mask,
    )

    decoder_outputs = self.speech_decoder_model.generate(
        conditioning_embeds=conditioning_embeds,
        generation_config=generation_config,
        output_hidden_states=output_hidden_states,
    )
    # The decoder may return a ModelOutput (dict-style) or the generated ids directly. Handle both so that
    # `speech_ids` is always bound. Previously the non-ModelOutput case left it undefined.
    if isinstance(decoder_outputs, ModelOutput):
        speech_ids = decoder_outputs.sequences
        decoder_hidden_states = decoder_outputs.hidden_states
    else:
        speech_ids = decoder_outputs
        decoder_hidden_states = None

    # pad to pad_to_max_mel_tokens if given, to replicate the original repo logic
    # link: https://github.com/neonbjb/tortoise-tts/blob/80f89987a5abda5e2b082618cd74f9c7411141dc/tortoise/api.py#L430
    if pad_to_max_mel_tokens is not None:
        padding_needed = pad_to_max_mel_tokens - speech_ids.shape[-1]
        # NOTE(review): if the generated ids are longer than `pad_to_max_mel_tokens`, `padding_needed` is
        # negative and constant padding trims the sequence instead of padding it — confirm this is intended.
        speech_ids = torch.nn.functional.pad(
            speech_ids, (0, padding_needed), value=self.generation_config.eos_token_id
        )

    speech_ids = self.fix_speech_decoder_output(speech_ids)

    # Always request ModelOutputs from the encoders. The code below reads `.embeds`, `.pooler_output` and
    # `.hidden_states` by attribute, so the tuple format is assembled at the end instead.
    speech_outputs: ClvpEncoderOutput = self.speech_encoder_model(
        input_ids=speech_ids,
        output_hidden_states=output_hidden_states,
        return_dict=True,
    )
    text_outputs: ClvpEncoderOutput = self.text_encoder_model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_hidden_states=output_hidden_states,
        return_dict=True,
    )

    speech_embeds = speech_outputs.embeds
    text_embeds = text_outputs.embeds
    # normalized features (L2 along the last dimension)
    speech_embeds = speech_embeds / speech_embeds.norm(p=2, dim=-1, keepdim=True)
    text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

    # cosine similarity as logits, scaled by the learned temperature
    logit_scale = self.logit_scale.exp()
    logits_per_text = torch.matmul(text_embeds, speech_embeds.t()) * logit_scale
    logits_per_speech = logits_per_text.t()

    if not return_dict:
        output = (
            speech_ids,
            logits_per_speech,
            logits_per_text,
            text_embeds,
            speech_embeds,
            text_outputs.pooler_output,
            speech_outputs.pooler_output,
        )
        if output_hidden_states:
            output += (
                decoder_hidden_states,
                text_outputs.hidden_states,
                speech_outputs.hidden_states,
            )
        return output

    return ClvpOutput(
        speech_ids=speech_ids,
        logits_per_speech=logits_per_speech,
        logits_per_text=logits_per_text,
        text_embeds=text_embeds,
        speech_embeds=speech_embeds,
        text_model_output=text_outputs.pooler_output,
        speech_model_output=speech_outputs.pooler_output,
        decoder_hidden_states=decoder_hidden_states,
        text_encoder_hidden_states=text_outputs.hidden_states,
        speech_encoder_hidden_states=speech_outputs.hidden_states,
    )
# Public API of this module: the names exported by `from <module> import *`.
# Keep this a plain list literal so tools that read `__all__` without importing the module can parse it.
__all__ = [
    "ClvpModelForConditionalGeneration",
    "ClvpForCausalLM",
    "ClvpModel",
    "ClvpPreTrainedModel",
    "ClvpEncoder",
    "ClvpDecoder",
]
|