modeling_speecht5.py 136 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
7277827792780278127822783278427852786278727882789279027912792279327942795279627972798279928002801280228032804280528062807280828092810281128122813281428152816281728182819282028212822282328242825282628272828282928302831283228332834283528362837283828392840284128422843284428452846284728482849285028512852285328542855285628572858285928602861286228632864286528662867286828692870287128722873287428752876287728782879288028812882288328842885288628872888288928902891289228932894289528962897289828992900290129022903290429052906290729082909291029112912291329142915291629172918291929202921292229232924292529262927292829292930293129322933293429352936293729382939294029412942294329442945294629472948294929502951295229532954295529562957295829592960296129622963296429652966296729682969297029712972297329742975297629772978297929802981298229832984298529862987298829892990299129922993299429952996299729982999300030013002300330043005300630073008300930103011301230133014301530163017301830193020302130223023302430253026302730283029303030313032303330343035303630373038303930403041304230433044304530463047304830493050305130523053305430553056305730583059306030613062306330643065306630673068306930703071307230733074307530763077307830793080308130823083308430853086308730883089309030913092309330943095
  1. # Copyright 2023 The Fairseq Authors, Microsoft Research, and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch SpeechT5 model."""
  15. import math
  16. import numpy as np
  17. import torch
  18. from torch import nn
  19. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, L1Loss
  20. from ... import initialization as init
  21. from ...activations import ACT2FN
  22. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  23. from ...generation import GenerationMixin
  24. from ...integrations.deepspeed import is_deepspeed_zero3_enabled
  25. from ...integrations.fsdp import is_fsdp_managed_module
  26. from ...masking_utils import create_bidirectional_mask, create_causal_mask
  27. from ...modeling_layers import GradientCheckpointingLayer
  28. from ...modeling_outputs import (
  29. BaseModelOutput,
  30. BaseModelOutputWithPastAndCrossAttentions,
  31. Seq2SeqLMOutput,
  32. Seq2SeqModelOutput,
  33. Seq2SeqSpectrogramOutput,
  34. )
  35. from ...modeling_utils import EmbeddingAccessMixin, PreTrainedModel
  36. from ...utils import auto_docstring, logging
  37. from .configuration_speecht5 import SpeechT5Config, SpeechT5HifiGanConfig
# Module logger, obtained through the library's `utils.logging` wrapper.
logger = logging.get_logger(__name__)
# NOTE(review): not referenced in the visible part of this file; presumably the index of the first hidden-states
# entry in model output tuples used further down — confirm against its users before changing it.
_HIDDEN_STATES_START_POSITION = 1
  41. # Copied from transformers.models.bart.modeling_bart.shift_tokens_right
  42. def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
  43. """
  44. Shift input ids one token to the right.
  45. """
  46. shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  47. shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
  48. shifted_input_ids[:, 0] = decoder_start_token_id
  49. if pad_token_id is None:
  50. raise ValueError("self.model.config.pad_token_id has to be defined.")
  51. # replace possible -100 values in labels by `pad_token_id`
  52. shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
  53. return shifted_input_ids
  54. def shift_spectrograms_right(
  55. input_values: torch.Tensor, reduction_factor: int = 1, attention_mask: torch.Tensor | None = None
  56. ):
  57. """
  58. Shift input spectrograms one timestep to the right. Also applies the reduction factor to the sequence length.
  59. """
  60. # thin out frames for reduction factor
  61. if reduction_factor > 1:
  62. input_values = input_values[:, reduction_factor - 1 :: reduction_factor]
  63. if attention_mask is not None:
  64. attention_mask = attention_mask[:, reduction_factor - 1 :: reduction_factor]
  65. shifted_input_values = input_values.new_zeros(input_values.shape)
  66. shifted_input_values[:, 1:] = input_values[:, :-1].clone()
  67. # replace possible -100 values in labels by zeros
  68. shifted_input_values.masked_fill_(shifted_input_values == -100.0, 0.0)
  69. return shifted_input_values, attention_mask
  70. # Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
def _compute_mask_indices(
    shape: tuple[int, int],
    mask_prob: float,
    mask_length: int,
    attention_mask: torch.LongTensor | None = None,
    min_masks: int = 0,
) -> np.ndarray:
    """
    Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
    ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on
    CPU as part of the preprocessing during training.

    Randomness comes from NumPy's global RNG (`np.random.rand` and `np.random.choice`). Seed `np.random` for
    reproducible masks.

    Args:
        shape: The shape for which to compute masks. This should be of a tuple of size 2 where
            the first element is the batch size and the second element is the length of the axis to span.
        mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
            independently generated mask spans of length `mask_length` is computed by
            `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
            actual percentage will be smaller.
        mask_length: size of the mask
        attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
            each batch dimension.
        min_masks: minimum number of masked spans

    Returns:
        A boolean `np.ndarray` of shape `shape`; `True` marks a masked position.

    Raises:
        ValueError: if `mask_length < 1` or `mask_length > shape[1]`.
    """
    batch_size, sequence_length = shape

    if mask_length < 1:
        raise ValueError("`mask_length` has to be bigger than 0.")

    if mask_length > sequence_length:
        raise ValueError(
            f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
            f" and `sequence_length`: {sequence_length}`"
        )

    # epsilon is used for probabilistic rounding: adding a uniform [0, 1) draw before `int()` rounds the span count
    # up with probability equal to its fractional part. One draw is shared by every row of the batch.
    epsilon = np.random.rand(1).item()

    def compute_num_masked_span(input_length):
        """Given input length, compute how many spans should be masked"""
        num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
        num_masked_span = max(num_masked_span, min_masks)

        # make sure num masked span <= sequence_length
        if num_masked_span * mask_length > sequence_length:
            num_masked_span = sequence_length // mask_length

        # make sure num_masked span is also <= input_length - (mask_length - 1)
        if input_length - (mask_length - 1) < num_masked_span:
            num_masked_span = max(input_length - (mask_length - 1), 0)

        return num_masked_span

    # compute number of masked spans in batch
    input_lengths = (
        attention_mask.detach().sum(-1).tolist()
        if attention_mask is not None
        else [sequence_length for _ in range(batch_size)]
    )

    # SpecAugment mask to fill
    spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
    spec_aug_mask_idxs = []

    # Every row is padded up to this many span starts so the index array is rectangular.
    max_num_masked_span = compute_num_masked_span(sequence_length)

    if max_num_masked_span == 0:
        return spec_aug_mask

    for input_length in input_lengths:
        # compute num of masked spans for this input
        num_masked_span = compute_num_masked_span(input_length)

        # get random indices to mask (distinct span starts that leave room for a full span)
        spec_aug_mask_idx = np.random.choice(
            np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
        )

        # pick first sampled index that will serve as a dummy index to pad vector
        # to ensure same dimension for all batches due to probabilistic rounding
        # Picking first sample just pads those vectors twice.
        if len(spec_aug_mask_idx) == 0:
            # this case can only happen if `input_length` is strictly smaller than
            # `sequence_length` in which case the last token has to be a padding
            # token which we can use as a dummy mask id
            dummy_mask_idx = sequence_length - 1
        else:
            dummy_mask_idx = spec_aug_mask_idx[0]

        spec_aug_mask_idx = np.concatenate(
            [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
        )
        spec_aug_mask_idxs.append(spec_aug_mask_idx)

    spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)

    # expand masked indices to masked spans
    spec_aug_mask_idxs = np.broadcast_to(
        spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
    )
    spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)

    # add offset to the starting indexes so that indexes now create a span
    offsets = np.arange(mask_length)[None, None, :]
    offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
        batch_size, max_num_masked_span * mask_length
    )
    spec_aug_mask_idxs = spec_aug_mask_idxs + offsets

    # ensure that we cannot have indices larger than sequence_length
    # (a dummy start of `sequence_length - 1` plus its offsets overshoots the axis)
    if spec_aug_mask_idxs.max() > sequence_length - 1:
        spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1

    # scatter indices to mask
    np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)

    return spec_aug_mask
  166. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->SpeechT5
  167. class SpeechT5NoLayerNormConvLayer(GradientCheckpointingLayer):
  168. def __init__(self, config, layer_id=0):
  169. super().__init__()
  170. self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
  171. self.out_conv_dim = config.conv_dim[layer_id]
  172. self.conv = nn.Conv1d(
  173. self.in_conv_dim,
  174. self.out_conv_dim,
  175. kernel_size=config.conv_kernel[layer_id],
  176. stride=config.conv_stride[layer_id],
  177. bias=config.conv_bias,
  178. )
  179. self.activation = ACT2FN[config.feat_extract_activation]
  180. def forward(self, hidden_states):
  181. hidden_states = self.conv(hidden_states)
  182. hidden_states = self.activation(hidden_states)
  183. return hidden_states
  184. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->SpeechT5
  185. class SpeechT5LayerNormConvLayer(GradientCheckpointingLayer):
  186. def __init__(self, config, layer_id=0):
  187. super().__init__()
  188. self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
  189. self.out_conv_dim = config.conv_dim[layer_id]
  190. self.conv = nn.Conv1d(
  191. self.in_conv_dim,
  192. self.out_conv_dim,
  193. kernel_size=config.conv_kernel[layer_id],
  194. stride=config.conv_stride[layer_id],
  195. bias=config.conv_bias,
  196. )
  197. self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
  198. self.activation = ACT2FN[config.feat_extract_activation]
  199. def forward(self, hidden_states):
  200. hidden_states = self.conv(hidden_states)
  201. hidden_states = hidden_states.transpose(-2, -1)
  202. hidden_states = self.layer_norm(hidden_states)
  203. hidden_states = hidden_states.transpose(-2, -1)
  204. hidden_states = self.activation(hidden_states)
  205. return hidden_states
  206. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->SpeechT5
  207. class SpeechT5GroupNormConvLayer(GradientCheckpointingLayer):
  208. def __init__(self, config, layer_id=0):
  209. super().__init__()
  210. self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
  211. self.out_conv_dim = config.conv_dim[layer_id]
  212. self.conv = nn.Conv1d(
  213. self.in_conv_dim,
  214. self.out_conv_dim,
  215. kernel_size=config.conv_kernel[layer_id],
  216. stride=config.conv_stride[layer_id],
  217. bias=config.conv_bias,
  218. )
  219. self.activation = ACT2FN[config.feat_extract_activation]
  220. self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
  221. def forward(self, hidden_states):
  222. hidden_states = self.conv(hidden_states)
  223. hidden_states = self.layer_norm(hidden_states)
  224. hidden_states = self.activation(hidden_states)
  225. return hidden_states
  226. # Copied from transformers.models.speech_to_text.modeling_speech_to_text.Speech2TextSinusoidalPositionalEmbedding with Speech2Text->SpeechT5
  227. class SpeechT5SinusoidalPositionalEmbedding(nn.Module):
  228. """This module produces sinusoidal positional embeddings of any length."""
  229. def __init__(self, num_positions: int, embedding_dim: int, padding_idx: int | None = None):
  230. super().__init__()
  231. self.offset = 2
  232. self.num_positions = num_positions
  233. self.embedding_dim = embedding_dim
  234. self.padding_idx = padding_idx
  235. self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
  236. def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: int | None = None):
  237. emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
  238. if hasattr(self, "weights"):
  239. # in forward put the weights on the correct dtype and device of the param
  240. emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
  241. self.register_buffer("weights", emb_weights, persistent=False)
  242. @staticmethod
  243. def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: int | None = None):
  244. """
  245. Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the
  246. description in Section 3.5 of "Attention Is All You Need".
  247. """
  248. half_dim = embedding_dim // 2
  249. emb = math.log(10000) / (half_dim - 1)
  250. emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
  251. emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
  252. emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
  253. if embedding_dim % 2 == 1:
  254. # zero pad
  255. emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
  256. if padding_idx is not None:
  257. emb[padding_idx, :] = 0
  258. return emb.to(torch.get_default_dtype())
  259. @torch.no_grad()
  260. def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
  261. bsz, seq_len = input_ids.size()
  262. # Create the position ids from the input token ids. Any padded tokens remain padded.
  263. position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to(
  264. input_ids.device
  265. )
  266. # expand embeddings if needed
  267. max_pos = self.padding_idx + 1 + seq_len
  268. if max_pos > self.weights.size(0):
  269. self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)
  270. return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
  271. def create_position_ids_from_input_ids(
  272. self, input_ids: torch.Tensor, padding_idx: int, past_key_values_length: int | None = 0
  273. ):
  274. """
  275. Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
  276. symbols are ignored. This is modified from fairseq's `utils.make_positions`.
  277. Args:
  278. x: torch.Tensor x:
  279. Returns: torch.Tensor
  280. """
  281. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  282. mask = input_ids.ne(padding_idx).int()
  283. incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
  284. return incremental_indices.long() + padding_idx
  285. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->SpeechT5
class SpeechT5PositionalConvEmbedding(nn.Module):
    """
    Convolutional positional signal: a grouped, weight-normalized `nn.Conv1d` over the sequence axis (kernel size
    `config.num_conv_pos_embeddings`), followed by same-padding trimming and `config.feat_extract_activation`.

    `forward` returns only this signal and does not add it to its input — presumably the caller does; confirm at the
    call site.
    """

    def __init__(self, config):
        super().__init__()
        self.conv = nn.Conv1d(
            config.hidden_size,
            config.hidden_size,
            kernel_size=config.num_conv_pos_embeddings,
            padding=config.num_conv_pos_embeddings // 2,
            groups=config.num_conv_pos_embedding_groups,
        )

        # Prefer the parametrization-based weight_norm when this PyTorch version provides it.
        weight_norm = nn.utils.weight_norm
        if hasattr(nn.utils.parametrizations, "weight_norm"):
            weight_norm = nn.utils.parametrizations.weight_norm

        if is_deepspeed_zero3_enabled():
            import deepspeed

            # Under ZeRO-3 the conv weight is partitioned; gather it so weight_norm can read the full tensor.
            # dim=2 keeps one magnitude per kernel position (Conv1d weight is (out, in / groups, kernel)).
            with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
                self.conv = weight_norm(self.conv, name="weight", dim=2)
            # The magnitude/direction tensors live under different names depending on which weight_norm was used.
            if hasattr(self.conv, "parametrizations"):
                weight_g = self.conv.parametrizations.weight.original0
                weight_v = self.conv.parametrizations.weight.original1
            else:
                weight_g = self.conv.weight_g
                weight_v = self.conv.weight_v
            # Tell DeepSpeed this module uses these parameters so it gathers them around its forward/backward.
            deepspeed.zero.register_external_parameter(self, weight_v)
            deepspeed.zero.register_external_parameter(self, weight_g)
        else:
            self.conv = weight_norm(self.conv, name="weight", dim=2)

        # An even kernel with padding k // 2 yields one extra output frame; this layer trims it.
        self.padding = SpeechT5SamePadLayer(config.num_conv_pos_embeddings)
        self.activation = ACT2FN[config.feat_extract_activation]

    def forward(self, hidden_states):
        # hidden_states is (batch, seq_len, hidden_size): Conv1d wants channels on dim 1, so swap in and back out.
        hidden_states = hidden_states.transpose(1, 2)

        hidden_states = self.conv(hidden_states)
        hidden_states = self.padding(hidden_states)
        hidden_states = self.activation(hidden_states)

        hidden_states = hidden_states.transpose(1, 2)
        return hidden_states
  322. class SpeechT5ScaledPositionalEncoding(nn.Module):
  323. """
  324. Scaled positional encoding, see §3.2 in https://huggingface.co/papers/1809.08895
  325. """
  326. def __init__(self, dropout, dim, max_len=5000):
  327. pe = torch.zeros(max_len, dim)
  328. position = torch.arange(0, max_len).unsqueeze(1)
  329. div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.int64).float() * -(math.log(10000.0) / dim))
  330. pe[:, 0::2] = torch.sin(position.float() * div_term)
  331. pe[:, 1::2] = torch.cos(position.float() * div_term)
  332. pe = pe.unsqueeze(0)
  333. super().__init__()
  334. self.register_buffer("pe", pe, persistent=False)
  335. self.dropout = nn.Dropout(p=dropout)
  336. self.dim = dim
  337. self.max_len = max_len
  338. self.alpha = nn.Parameter(torch.tensor(1.0))
  339. def forward(self, emb):
  340. emb = emb + self.alpha * self.pe[:, : emb.size(1)]
  341. emb = self.dropout(emb)
  342. return emb
  343. class SpeechT5RelativePositionalEncoding(torch.nn.Module):
  344. def __init__(self, dim, max_length=1000):
  345. super().__init__()
  346. self.dim = dim
  347. self.max_length = max_length
  348. self.pe_k = torch.nn.Embedding(2 * max_length, dim)
  349. def forward(self, hidden_states):
  350. seq_len = hidden_states.shape[1]
  351. pos_seq = torch.arange(0, seq_len).to(device=hidden_states.device, dtype=torch.long)
  352. pos_seq = pos_seq[:, None] - pos_seq[None, :]
  353. pos_seq = torch.where(pos_seq < -self.max_length, -self.max_length, pos_seq)
  354. pos_seq = torch.where(pos_seq >= self.max_length, self.max_length - 1, pos_seq)
  355. pos_seq = pos_seq + self.max_length
  356. return self.pe_k(pos_seq)
  357. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->SpeechT5
  358. class SpeechT5SamePadLayer(nn.Module):
  359. def __init__(self, num_conv_pos_embeddings):
  360. super().__init__()
  361. self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
  362. def forward(self, hidden_states):
  363. if self.num_pad_remove > 0:
  364. hidden_states = hidden_states[:, :, : -self.num_pad_remove]
  365. return hidden_states
  366. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->SpeechT5
  367. class SpeechT5FeatureEncoder(nn.Module):
  368. """Construct the features from raw audio waveform"""
  369. def __init__(self, config):
  370. super().__init__()
  371. if config.feat_extract_norm == "group":
  372. conv_layers = [SpeechT5GroupNormConvLayer(config, layer_id=0)] + [
  373. SpeechT5NoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1)
  374. ]
  375. elif config.feat_extract_norm == "layer":
  376. conv_layers = [
  377. SpeechT5LayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)
  378. ]
  379. else:
  380. raise ValueError(
  381. f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
  382. )
  383. self.conv_layers = nn.ModuleList(conv_layers)
  384. self.gradient_checkpointing = False
  385. self._requires_grad = True
  386. def _freeze_parameters(self):
  387. for param in self.parameters():
  388. param.requires_grad = False
  389. self._requires_grad = False
  390. def forward(self, input_values):
  391. hidden_states = input_values[:, None]
  392. # make sure hidden_states require grad for gradient_checkpointing
  393. if self._requires_grad and self.training:
  394. hidden_states.requires_grad = True
  395. for conv_layer in self.conv_layers:
  396. hidden_states = conv_layer(hidden_states)
  397. return hidden_states
  398. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->SpeechT5
  399. class SpeechT5FeatureProjection(nn.Module):
  400. def __init__(self, config):
  401. super().__init__()
  402. self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
  403. self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
  404. self.dropout = nn.Dropout(config.feat_proj_dropout)
  405. def forward(self, hidden_states):
  406. # non-projected hidden states are needed for quantization
  407. norm_hidden_states = self.layer_norm(hidden_states)
  408. hidden_states = self.projection(norm_hidden_states)
  409. hidden_states = self.dropout(hidden_states)
  410. return hidden_states, norm_hidden_states
  411. class SpeechT5SpeechEncoderPrenet(nn.Module):
  412. def __init__(self, config):
  413. super().__init__()
  414. self.config = config
  415. self.feature_encoder = SpeechT5FeatureEncoder(config)
  416. self.feature_projection = SpeechT5FeatureProjection(config)
  417. # model only needs masking vector if mask prob is > 0.0
  418. if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
  419. self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
  420. self.pos_conv_embed = SpeechT5PositionalConvEmbedding(config)
  421. self.pos_sinusoidal_embed = SpeechT5SinusoidalPositionalEmbedding(
  422. config.max_speech_positions + config.pad_token_id + 1,
  423. config.hidden_size,
  424. config.pad_token_id,
  425. )
  426. def freeze_feature_encoder(self):
  427. self.feature_encoder._freeze_parameters()
  428. def forward(
  429. self,
  430. input_values: torch.Tensor,
  431. attention_mask: torch.LongTensor | None = None,
  432. mask_time_indices: torch.FloatTensor | None = None,
  433. ):
  434. extract_features = self.feature_encoder(input_values)
  435. extract_features = extract_features.transpose(1, 2)
  436. if attention_mask is not None:
  437. # compute reduced attention_mask corresponding to feature vectors
  438. attention_mask = self._get_feature_vector_attention_mask(
  439. extract_features.shape[1],
  440. attention_mask,
  441. )
  442. hidden_states, extract_features = self.feature_projection(extract_features)
  443. hidden_states = self._mask_hidden_states(
  444. hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
  445. )
  446. positional_conv_embedding = self.pos_conv_embed(hidden_states)
  447. hidden_states = hidden_states + positional_conv_embedding
  448. if attention_mask is not None:
  449. padding_mask = attention_mask.ne(1).long()
  450. else:
  451. padding_mask = torch.zeros(hidden_states.shape[:2], dtype=torch.long, device=hidden_states.device)
  452. positional_sinusoidal_embeddings = self.pos_sinusoidal_embed(padding_mask)
  453. hidden_states = hidden_states + positional_sinusoidal_embeddings
  454. return hidden_states, attention_mask
  455. # Copied from transformers.models.unispeech.modeling_unispeech.UniSpeechPreTrainedModel._get_feature_vector_attention_mask
  456. def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
  457. # Effectively attention_mask.sum(-1), but not inplace to be able to run
  458. # on inference mode.
  459. non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
  460. output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long)
  461. batch_size = attention_mask.shape[0]
  462. attention_mask = torch.zeros(
  463. (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
  464. )
  465. # these two operations makes sure that all values before the output lengths idxs are attended to
  466. attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
  467. attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
  468. return attention_mask
  469. # Copied from transformers.models.unispeech.modeling_unispeech.UniSpeechPreTrainedModel._get_feat_extract_output_lengths
  470. def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor | int):
  471. """
  472. Computes the output length of the convolutional layers
  473. """
  474. def _conv_out_length(input_length, kernel_size, stride):
  475. # 1D convolutional layer output length formula taken
  476. # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
  477. return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
  478. for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
  479. input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
  480. return input_lengths
  481. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
  482. def _mask_hidden_states(
  483. self,
  484. hidden_states: torch.FloatTensor,
  485. mask_time_indices: torch.FloatTensor | None = None,
  486. attention_mask: torch.LongTensor | None = None,
  487. ):
  488. """
  489. Masks extracted features along time axis and/or along feature axis according to
  490. [SpecAugment](https://huggingface.co/papers/1904.08779).
  491. """
  492. # `config.apply_spec_augment` can set masking to False
  493. if not getattr(self.config, "apply_spec_augment", True):
  494. return hidden_states
  495. # generate indices & apply SpecAugment along time axis
  496. batch_size, sequence_length, hidden_size = hidden_states.size()
  497. if mask_time_indices is not None:
  498. # apply SpecAugment along time axis with given mask_time_indices
  499. hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
  500. elif self.config.mask_time_prob > 0 and self.training:
  501. mask_time_indices = _compute_mask_indices(
  502. (batch_size, sequence_length),
  503. mask_prob=self.config.mask_time_prob,
  504. mask_length=self.config.mask_time_length,
  505. attention_mask=attention_mask,
  506. min_masks=self.config.mask_time_min_masks,
  507. )
  508. mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
  509. hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
  510. if self.config.mask_feature_prob > 0 and self.training:
  511. # generate indices & apply SpecAugment along feature axis
  512. mask_feature_indices = _compute_mask_indices(
  513. (batch_size, hidden_size),
  514. mask_prob=self.config.mask_feature_prob,
  515. mask_length=self.config.mask_feature_length,
  516. min_masks=self.config.mask_feature_min_masks,
  517. )
  518. mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
  519. mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
  520. hidden_states[mask_feature_indices] = 0
  521. return hidden_states
  522. class SpeechT5SpeechDecoderPrenet(nn.Module):
  523. def __init__(self, config):
  524. super().__init__()
  525. self.config = config
  526. self.layers = nn.ModuleList(
  527. [
  528. nn.Linear(
  529. config.num_mel_bins if i == 0 else config.speech_decoder_prenet_units,
  530. config.speech_decoder_prenet_units,
  531. )
  532. for i in range(config.speech_decoder_prenet_layers)
  533. ]
  534. )
  535. self.final_layer = nn.Linear(config.speech_decoder_prenet_units, config.hidden_size)
  536. self.encode_positions = SpeechT5ScaledPositionalEncoding(
  537. config.positional_dropout,
  538. config.hidden_size,
  539. config.max_speech_positions,
  540. )
  541. self.speaker_embeds_layer = nn.Linear(config.speaker_embedding_dim + config.hidden_size, config.hidden_size)
  542. def _consistent_dropout(self, inputs_embeds, p):
  543. mask = torch.bernoulli(inputs_embeds[0], p=p)
  544. all_masks = mask.unsqueeze(0).repeat(inputs_embeds.size(0), 1, 1)
  545. return torch.where(all_masks == 1, inputs_embeds, 0) * 1 / (1 - p)
  546. def forward(
  547. self,
  548. input_values: torch.Tensor,
  549. speaker_embeddings: torch.Tensor | None = None,
  550. ):
  551. # Dropout is always applied, even when evaluating. See §2.2 in https://huggingface.co/papers/1712.05884.
  552. inputs_embeds = input_values
  553. for layer in self.layers:
  554. inputs_embeds = nn.functional.relu(layer(inputs_embeds))
  555. inputs_embeds = self._consistent_dropout(inputs_embeds, self.config.speech_decoder_prenet_dropout)
  556. inputs_embeds = self.final_layer(inputs_embeds)
  557. inputs_embeds = self.encode_positions(inputs_embeds)
  558. if speaker_embeddings is not None:
  559. speaker_embeddings = nn.functional.normalize(speaker_embeddings)
  560. speaker_embeddings = speaker_embeddings.unsqueeze(1).expand(-1, inputs_embeds.size(1), -1)
  561. inputs_embeds = torch.cat([inputs_embeds, speaker_embeddings], dim=-1)
  562. inputs_embeds = nn.functional.relu(self.speaker_embeds_layer(inputs_embeds))
  563. return inputs_embeds
  564. class SpeechT5BatchNormConvLayer(nn.Module):
  565. def __init__(self, config, layer_id=0):
  566. super().__init__()
  567. if layer_id == 0:
  568. in_conv_dim = config.num_mel_bins
  569. else:
  570. in_conv_dim = config.speech_decoder_postnet_units
  571. if layer_id == config.speech_decoder_postnet_layers - 1:
  572. out_conv_dim = config.num_mel_bins
  573. else:
  574. out_conv_dim = config.speech_decoder_postnet_units
  575. self.conv = nn.Conv1d(
  576. in_conv_dim,
  577. out_conv_dim,
  578. kernel_size=config.speech_decoder_postnet_kernel,
  579. stride=1,
  580. padding=(config.speech_decoder_postnet_kernel - 1) // 2,
  581. bias=False,
  582. )
  583. self.batch_norm = nn.BatchNorm1d(out_conv_dim)
  584. if layer_id < config.speech_decoder_postnet_layers - 1:
  585. self.activation = nn.Tanh()
  586. else:
  587. self.activation = None
  588. self.dropout = nn.Dropout(config.speech_decoder_postnet_dropout)
  589. def forward(self, hidden_states):
  590. hidden_states = self.conv(hidden_states)
  591. hidden_states = self.batch_norm(hidden_states)
  592. if self.activation is not None:
  593. hidden_states = self.activation(hidden_states)
  594. hidden_states = self.dropout(hidden_states)
  595. return hidden_states
  596. class SpeechT5SpeechDecoderPostnet(nn.Module):
  597. def __init__(self, config):
  598. super().__init__()
  599. self.config = config
  600. self.feat_out = nn.Linear(config.hidden_size, config.num_mel_bins * config.reduction_factor)
  601. self.prob_out = nn.Linear(config.hidden_size, config.reduction_factor)
  602. self.layers = nn.ModuleList(
  603. [SpeechT5BatchNormConvLayer(config, i) for i in range(config.speech_decoder_postnet_layers)]
  604. )
  605. def forward(self, hidden_states: torch.Tensor):
  606. outputs_before_postnet = self.feat_out(hidden_states).view(hidden_states.size(0), -1, self.config.num_mel_bins)
  607. outputs_after_postnet = self.postnet(outputs_before_postnet)
  608. logits = self.prob_out(hidden_states).view(hidden_states.size(0), -1)
  609. return outputs_before_postnet, outputs_after_postnet, logits
  610. def postnet(self, hidden_states: torch.Tensor):
  611. layer_output = hidden_states.transpose(1, 2)
  612. for layer in self.layers:
  613. layer_output = layer(layer_output)
  614. return hidden_states + layer_output.transpose(1, 2)
  615. class SpeechT5TextEncoderPrenet(nn.Module, EmbeddingAccessMixin):
  616. def __init__(self, config):
  617. super().__init__()
  618. self.config = config
  619. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
  620. self.encode_positions = SpeechT5ScaledPositionalEncoding(
  621. config.positional_dropout,
  622. config.hidden_size,
  623. config.max_text_positions,
  624. )
  625. def forward(self, input_ids: torch.Tensor):
  626. inputs_embeds = self.embed_tokens(input_ids)
  627. inputs_embeds = self.encode_positions(inputs_embeds)
  628. return inputs_embeds
  629. class SpeechT5TextDecoderPrenet(nn.Module, EmbeddingAccessMixin):
  630. def __init__(self, config):
  631. super().__init__()
  632. self.config = config
  633. self.dropout = nn.Dropout(config.positional_dropout)
  634. self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
  635. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
  636. self.embed_positions = SpeechT5SinusoidalPositionalEmbedding(
  637. config.max_text_positions + config.pad_token_id + 1,
  638. config.hidden_size,
  639. config.pad_token_id,
  640. )
  641. def forward(
  642. self,
  643. input_ids: torch.Tensor,
  644. attention_mask: torch.LongTensor | None = None,
  645. past_key_values: Cache | None = None,
  646. ):
  647. if input_ids is not None:
  648. input_shape = input_ids.size()
  649. input_ids = input_ids.view(-1, input_shape[-1])
  650. else:
  651. raise ValueError("You have to specify `decoder_input_ids`")
  652. past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
  653. positions = self.embed_positions(input_ids, past_key_values_length)
  654. inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
  655. inputs_embeds += positions
  656. inputs_embeds = self.dropout(inputs_embeds)
  657. return inputs_embeds, attention_mask
  658. class SpeechT5TextDecoderPostnet(nn.Module, EmbeddingAccessMixin):
  659. def __init__(self, config):
  660. super().__init__()
  661. self.config = config
  662. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  663. def forward(self, hidden_states: torch.Tensor):
  664. return self.lm_head(hidden_states)
  665. def get_output_embeddings(self):
  666. # Post-net has no token embeddings, but its lm_head must still be
  667. # tied to the decoder weights when `tie_word_embeddings=True`.
  668. return self.lm_head
  669. def set_output_embeddings(self, new_embeddings):
  670. self.lm_head = new_embeddings
  671. class SpeechT5Attention(nn.Module):
  672. """
  673. Multi-headed attention from 'Attention Is All You Need' paper with relative position bias (see
  674. https://aclanthology.org/N18-2074.pdf)
  675. """
  676. def __init__(
  677. self,
  678. embed_dim: int,
  679. num_heads: int,
  680. dropout: float | None = 0.0,
  681. is_decoder: bool | None = False,
  682. bias: bool | None = True,
  683. layer_idx: bool | None = None,
  684. ):
  685. super().__init__()
  686. self.embed_dim = embed_dim
  687. self.num_heads = num_heads
  688. self.dropout = dropout
  689. self.head_dim = embed_dim // num_heads
  690. if (self.head_dim * num_heads) != self.embed_dim:
  691. raise ValueError(
  692. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
  693. f" and `num_heads`: {num_heads})."
  694. )
  695. self.scaling = self.head_dim**-0.5
  696. self.is_decoder = is_decoder
  697. self.layer_idx = layer_idx
  698. self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  699. self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  700. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  701. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  702. def forward(
  703. self,
  704. hidden_states: torch.Tensor,
  705. key_value_states: torch.Tensor | None = None,
  706. past_key_values: Cache | None = None,
  707. attention_mask: torch.Tensor | None = None,
  708. position_bias: torch.Tensor | None = None,
  709. output_attentions: bool = False,
  710. **kwargs,
  711. ) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]:
  712. """Input shape: Batch x Time x Channel"""
  713. # if key_value_states are provided this layer is used as a cross-attention layer
  714. # for the decoder
  715. is_cross_attention = key_value_states is not None
  716. bsz, tgt_len, _ = hidden_states.size()
  717. # get query proj
  718. query_states = self.q_proj(hidden_states) * self.scaling
  719. is_updated = False
  720. if past_key_values is not None:
  721. if isinstance(past_key_values, EncoderDecoderCache):
  722. is_updated = past_key_values.is_updated.get(self.layer_idx)
  723. if is_cross_attention:
  724. # after the first generated id, we can subsequently re-use all key/value_states from cache
  725. curr_past_key_values = past_key_values.cross_attention_cache
  726. else:
  727. curr_past_key_values = past_key_values.self_attention_cache
  728. else:
  729. curr_past_key_values = past_key_values
  730. current_states = key_value_states if is_cross_attention else hidden_states
  731. if is_cross_attention and past_key_values is not None and is_updated:
  732. # reuse k,v, cross_attentions
  733. key_states = curr_past_key_values.layers[self.layer_idx].keys
  734. value_states = curr_past_key_values.layers[self.layer_idx].values
  735. else:
  736. key_states = self.k_proj(current_states)
  737. value_states = self.v_proj(current_states)
  738. key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
  739. value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
  740. if past_key_values is not None:
  741. # save all key/value_states to cache to be re-used for fast auto-regressive generation
  742. key_states, value_states = curr_past_key_values.update(key_states, value_states, self.layer_idx)
  743. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  744. if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
  745. past_key_values.is_updated[self.layer_idx] = True
  746. proj_shape = (bsz * self.num_heads, -1, self.head_dim)
  747. query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
  748. query_states = query_states.reshape(*proj_shape)
  749. key_states = key_states.reshape(*proj_shape)
  750. value_states = value_states.reshape(*proj_shape)
  751. src_len = key_states.size(1)
  752. attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
  753. if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
  754. raise ValueError(
  755. f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
  756. f" {attn_weights.size()}"
  757. )
  758. # relative attention bias
  759. if position_bias is not None:
  760. reshape_q = query_states.contiguous().view(bsz * self.num_heads, -1, self.head_dim).transpose(0, 1)
  761. rel_pos_bias = torch.matmul(reshape_q, position_bias.transpose(-2, -1))
  762. rel_pos_bias = rel_pos_bias.transpose(0, 1).view(
  763. bsz * self.num_heads, position_bias.size(0), position_bias.size(1)
  764. )
  765. attn_weights += rel_pos_bias
  766. if attention_mask is not None:
  767. if attention_mask.size() != (bsz, 1, tgt_len, src_len):
  768. raise ValueError(
  769. f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
  770. )
  771. attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
  772. attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
  773. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  774. if output_attentions:
  775. # this operation is a bit awkward, but it's required to
  776. # make sure that attn_weights keeps its gradient.
  777. # In order to do so, attn_weights have to be reshaped
  778. # twice and have to be reused in the following
  779. attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
  780. attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
  781. else:
  782. attn_weights_reshaped = None
  783. attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
  784. attn_output = torch.bmm(attn_probs, value_states)
  785. if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
  786. raise ValueError(
  787. f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
  788. f" {attn_output.size()}"
  789. )
  790. attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
  791. attn_output = attn_output.transpose(1, 2)
  792. # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
  793. # partitioned across GPUs when using tensor-parallelism.
  794. attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
  795. attn_output = self.out_proj(attn_output)
  796. return attn_output, attn_weights_reshaped
  797. class SpeechT5FeedForward(nn.Module):
  798. def __init__(self, config, intermediate_size):
  799. super().__init__()
  800. self.intermediate_dropout = nn.Dropout(config.activation_dropout)
  801. self.intermediate_dense = nn.Linear(config.hidden_size, intermediate_size)
  802. if isinstance(config.hidden_act, str):
  803. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  804. else:
  805. self.intermediate_act_fn = config.hidden_act
  806. self.output_dense = nn.Linear(intermediate_size, config.hidden_size)
  807. self.output_dropout = nn.Dropout(config.hidden_dropout)
  808. def forward(self, hidden_states):
  809. hidden_states = self.intermediate_dense(hidden_states)
  810. hidden_states = self.intermediate_act_fn(hidden_states)
  811. hidden_states = self.intermediate_dropout(hidden_states)
  812. hidden_states = self.output_dense(hidden_states)
  813. hidden_states = self.output_dropout(hidden_states)
  814. return hidden_states
  815. class SpeechT5EncoderLayer(GradientCheckpointingLayer):
  816. def __init__(self, config: SpeechT5Config):
  817. super().__init__()
  818. self.attention = SpeechT5Attention(
  819. embed_dim=config.hidden_size,
  820. num_heads=config.encoder_attention_heads,
  821. dropout=config.attention_dropout,
  822. is_decoder=False,
  823. )
  824. self.dropout = nn.Dropout(config.hidden_dropout)
  825. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  826. self.feed_forward = SpeechT5FeedForward(config, config.encoder_ffn_dim)
  827. self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  828. def forward(
  829. self,
  830. hidden_states: torch.Tensor,
  831. attention_mask: torch.Tensor | None = None,
  832. position_bias: torch.Tensor | None = None,
  833. output_attentions: bool = False,
  834. ):
  835. """
  836. Args:
  837. hidden_states (`torch.FloatTensor`):
  838. input to the layer of shape `(batch, seq_len, hidden_size)`
  839. attention_mask (`torch.FloatTensor`):
  840. attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very
  841. large negative values.
  842. position_bias (`torch.FloatTensor`):
  843. relative position embeddings of size `(seq_len, seq_len, hidden_size // encoder_attention_heads)`
  844. output_attentions (`bool`, *optional*):
  845. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  846. returned tensors for more detail.
  847. """
  848. residual = hidden_states
  849. hidden_states, attn_weights = self.attention(
  850. hidden_states=hidden_states,
  851. attention_mask=attention_mask,
  852. position_bias=position_bias,
  853. output_attentions=output_attentions,
  854. )
  855. hidden_states = self.dropout(hidden_states)
  856. hidden_states = residual + hidden_states
  857. hidden_states = self.layer_norm(hidden_states)
  858. hidden_states = hidden_states + self.feed_forward(hidden_states)
  859. hidden_states = self.final_layer_norm(hidden_states)
  860. outputs = (hidden_states,)
  861. if output_attentions:
  862. outputs += (attn_weights,)
  863. return outputs
  864. class SpeechT5DecoderLayer(GradientCheckpointingLayer):
  865. def __init__(self, config: SpeechT5Config, layer_idx=None):
  866. super().__init__()
  867. self.self_attn = SpeechT5Attention(
  868. embed_dim=config.hidden_size,
  869. num_heads=config.decoder_attention_heads,
  870. dropout=config.attention_dropout,
  871. is_decoder=True,
  872. layer_idx=layer_idx,
  873. )
  874. self.dropout = nn.Dropout(config.hidden_dropout)
  875. self.self_attn_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  876. self.encoder_attn = SpeechT5Attention(
  877. config.hidden_size,
  878. config.decoder_attention_heads,
  879. dropout=config.attention_dropout,
  880. is_decoder=True,
  881. layer_idx=layer_idx,
  882. )
  883. self.encoder_attn_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  884. self.feed_forward = SpeechT5FeedForward(config, config.decoder_ffn_dim)
  885. self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  886. def forward(
  887. self,
  888. hidden_states: torch.Tensor,
  889. attention_mask: torch.Tensor | None = None,
  890. encoder_hidden_states: torch.Tensor | None = None,
  891. encoder_attention_mask: torch.Tensor | None = None,
  892. past_key_values: Cache | None = None,
  893. output_attentions: bool | None = False,
  894. use_cache: bool | None = True,
  895. **kwargs,
  896. ):
  897. """
  898. Args:
  899. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
  900. attention_mask (`torch.FloatTensor`): attention mask of size
  901. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  902. encoder_hidden_states (`torch.FloatTensor`):
  903. cross attention input to the layer of shape `(batch, seq_len, hidden_size)`
  904. encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
  905. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  906. past_key_values (`Cache`): cached past key and value projection states
  907. output_attentions (`bool`, *optional*):
  908. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  909. returned tensors for more detail.
  910. """
  911. residual = hidden_states
  912. # Self Attention
  913. hidden_states, self_attn_weights = self.self_attn(
  914. hidden_states=hidden_states,
  915. past_key_values=past_key_values,
  916. attention_mask=attention_mask,
  917. output_attentions=output_attentions,
  918. )
  919. hidden_states = self.dropout(hidden_states)
  920. hidden_states = residual + hidden_states
  921. hidden_states = self.self_attn_layer_norm(hidden_states)
  922. # Cross-Attention Block
  923. cross_attn_weights = None
  924. if encoder_hidden_states is not None:
  925. residual = hidden_states
  926. hidden_states, cross_attn_weights = self.encoder_attn(
  927. hidden_states=hidden_states,
  928. key_value_states=encoder_hidden_states,
  929. attention_mask=encoder_attention_mask,
  930. past_key_values=past_key_values,
  931. output_attentions=output_attentions,
  932. )
  933. hidden_states = self.dropout(hidden_states)
  934. hidden_states = residual + hidden_states
  935. hidden_states = self.encoder_attn_layer_norm(hidden_states)
  936. # Fully Connected
  937. hidden_states = hidden_states + self.feed_forward(hidden_states)
  938. hidden_states = self.final_layer_norm(hidden_states)
  939. outputs = (hidden_states,)
  940. if output_attentions:
  941. outputs += (self_attn_weights, cross_attn_weights)
  942. return outputs
  943. @auto_docstring
  944. class SpeechT5PreTrainedModel(PreTrainedModel):
  945. config: SpeechT5Config
  946. base_model_prefix = "speecht5"
  947. main_input_name = "input_values"
  948. input_modalities = "audio"
  949. supports_gradient_checkpointing = True
  950. @torch.no_grad()
  951. def _init_weights(self, module: nn.Module):
  952. """Initialize the weights"""
  953. std = self.config.initializer_range
  954. if isinstance(module, SpeechT5PositionalConvEmbedding):
  955. init.normal_(
  956. module.conv.weight,
  957. mean=0,
  958. std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
  959. )
  960. init.constant_(module.conv.bias, 0)
  961. elif isinstance(module, SpeechT5ScaledPositionalEncoding):
  962. init.ones_(module.alpha)
  963. dim, max_len = module.dim, module.max_len
  964. pe = torch.zeros(max_len, dim)
  965. position = torch.arange(0, max_len).unsqueeze(1)
  966. div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.int64).float() * -(math.log(10000.0) / dim))
  967. pe[:, 0::2] = torch.sin(position.float() * div_term)
  968. pe[:, 1::2] = torch.cos(position.float() * div_term)
  969. pe = pe.unsqueeze(0)
  970. init.copy_(module.pe, pe)
  971. elif isinstance(module, SpeechT5FeatureProjection):
  972. k = math.sqrt(1 / module.projection.in_features)
  973. init.uniform_(module.projection.weight, a=-k, b=k)
  974. init.uniform_(module.projection.bias, a=-k, b=k)
  975. elif isinstance(module, nn.Linear):
  976. init.normal_(module.weight, mean=0.0, std=std)
  977. if module.bias is not None:
  978. init.zeros_(module.bias)
  979. elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm1d)):
  980. init.zeros_(module.bias)
  981. init.ones_(module.weight)
  982. if getattr(module, "running_mean", None) is not None:
  983. init.zeros_(module.running_mean)
  984. init.ones_(module.running_var)
  985. init.zeros_(module.num_batches_tracked)
  986. elif isinstance(module, nn.Conv1d):
  987. init.kaiming_normal_(module.weight)
  988. if module.bias is not None:
  989. k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
  990. init.uniform_(module.bias, a=-k, b=k)
  991. elif isinstance(module, nn.Embedding):
  992. init.normal_(module.weight, mean=0.0, std=std)
  993. # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
  994. if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
  995. init.zeros_(module.weight[module.padding_idx])
  996. elif isinstance(module, SpeechT5SinusoidalPositionalEmbedding):
  997. emb_weights = module.get_embedding(
  998. module.num_positions + module.offset, module.embedding_dim, module.padding_idx
  999. )
  1000. init.copy_(module.weights, emb_weights)
  1001. elif isinstance(module, SpeechT5HifiGan):
  1002. init.zeros_(module.mean)
  1003. init.ones_(module.scale)
  1004. if hasattr(module, "masked_spec_embed"):
  1005. init.uniform_(module.masked_spec_embed)
  1006. class SpeechT5Encoder(SpeechT5PreTrainedModel):
  1007. """
  1008. Transformer encoder consisting of *config.encoder_layers* layers. Each layer is a [`SpeechT5EncoderLayer`].
  1009. """
  1010. def __init__(self, config: SpeechT5Config):
  1011. super().__init__(config)
  1012. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  1013. self.dropout = nn.Dropout(config.hidden_dropout)
  1014. self.layerdrop = config.encoder_layerdrop
  1015. self.layers = nn.ModuleList([SpeechT5EncoderLayer(config) for _ in range(config.encoder_layers)])
  1016. self.embed_positions = SpeechT5RelativePositionalEncoding(
  1017. config.hidden_size // config.encoder_attention_heads, config.encoder_max_relative_position
  1018. )
  1019. self.gradient_checkpointing = False
  1020. # Initialize weights and apply final processing
  1021. self.post_init()
  1022. def forward(
  1023. self,
  1024. hidden_states: torch.FloatTensor,
  1025. attention_mask: torch.Tensor | None = None,
  1026. output_attentions: bool | None = None,
  1027. output_hidden_states: bool | None = None,
  1028. return_dict: bool | None = None,
  1029. **kwargs,
  1030. ) -> tuple | BaseModelOutput:
  1031. """
  1032. Args:
  1033. hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, feature_size)`):
  1034. Features extracted from the speech or text input by the encoder prenet.
  1035. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1036. Mask to avoid performing convolution and attention on padding token indices. Mask values selected in
  1037. `[0, 1]`:
  1038. - 1 for tokens that are **not masked**,
  1039. - 0 for tokens that are **masked**.
  1040. [What are attention masks?](../glossary#attention-mask)
  1041. output_attentions (`bool`, *optional*):
  1042. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  1043. returned tensors for more detail.
  1044. output_hidden_states (`bool`, *optional*):
  1045. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  1046. for more detail.
  1047. return_dict (`bool`, *optional*):
  1048. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  1049. """
  1050. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1051. output_hidden_states = (
  1052. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1053. )
  1054. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1055. attention_mask = create_bidirectional_mask(
  1056. config=self.config,
  1057. inputs_embeds=hidden_states,
  1058. attention_mask=attention_mask,
  1059. )
  1060. hidden_states = self.layer_norm(hidden_states)
  1061. hidden_states = self.dropout(hidden_states)
  1062. position_bias = self.embed_positions(hidden_states)
  1063. synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
  1064. all_hidden_states = () if output_hidden_states else None
  1065. all_self_attentions = () if output_attentions else None
  1066. for idx, encoder_layer in enumerate(self.layers):
  1067. if output_hidden_states:
  1068. all_hidden_states = all_hidden_states + (hidden_states,)
  1069. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  1070. skip_the_layer = False
  1071. if self.training:
  1072. dropout_probability = torch.rand([])
  1073. skip_the_layer = dropout_probability < self.layerdrop
  1074. if not skip_the_layer or synced_gpus:
  1075. # under fsdp or deepspeed zero3 all gpus must run in sync
  1076. layer_outputs = encoder_layer(
  1077. hidden_states,
  1078. attention_mask=attention_mask,
  1079. position_bias=position_bias,
  1080. output_attentions=output_attentions,
  1081. )
  1082. hidden_states = layer_outputs[0]
  1083. if skip_the_layer:
  1084. layer_outputs = (None, None)
  1085. if output_attentions:
  1086. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  1087. if output_hidden_states:
  1088. all_hidden_states = all_hidden_states + (hidden_states,)
  1089. if not return_dict:
  1090. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  1091. return BaseModelOutput(
  1092. last_hidden_state=hidden_states,
  1093. hidden_states=all_hidden_states,
  1094. attentions=all_self_attentions,
  1095. )
  1096. class SpeechT5EncoderWithSpeechPrenet(SpeechT5PreTrainedModel):
  1097. """
  1098. Wrapper around SpeechT5Encoder that applies SpeechT5SpeechEncoderPrenet to convert the audio waveform data to
  1099. hidden features.
  1100. """
  1101. def __init__(self, config: SpeechT5Config):
  1102. super().__init__(config)
  1103. self.prenet = SpeechT5SpeechEncoderPrenet(config)
  1104. self.wrapped_encoder = SpeechT5Encoder(config)
  1105. # Initialize weights and apply final processing
  1106. self.post_init()
  1107. def forward(
  1108. self,
  1109. input_values: torch.FloatTensor,
  1110. attention_mask: torch.Tensor | None = None,
  1111. output_attentions: bool | None = None,
  1112. output_hidden_states: bool | None = None,
  1113. return_dict: bool | None = None,
  1114. **kwargs,
  1115. ) -> tuple | BaseModelOutput:
  1116. hidden_states, attention_mask = self.prenet(input_values, attention_mask)
  1117. outputs = self.wrapped_encoder(
  1118. hidden_states=hidden_states,
  1119. attention_mask=attention_mask,
  1120. output_attentions=output_attentions,
  1121. output_hidden_states=output_hidden_states,
  1122. return_dict=return_dict,
  1123. )
  1124. return outputs
  1125. class SpeechT5EncoderWithTextPrenet(SpeechT5PreTrainedModel):
  1126. """
  1127. Wrapper around SpeechT5Encoder that applies SpeechT5TextEncoderPrenet to convert the input_ids to hidden features.
  1128. """
  1129. def __init__(self, config: SpeechT5Config):
  1130. super().__init__(config)
  1131. self.prenet = SpeechT5TextEncoderPrenet(config)
  1132. self.wrapped_encoder = SpeechT5Encoder(config)
  1133. # Initialize weights and apply final processing
  1134. self.post_init()
  1135. def get_input_embeddings(self):
  1136. return self.prenet.get_input_embeddings()
  1137. def set_input_embeddings(self, value):
  1138. self.prenet.set_input_embeddings(value)
  1139. def forward(
  1140. self,
  1141. input_values: torch.FloatTensor,
  1142. attention_mask: torch.Tensor | None = None,
  1143. output_attentions: bool | None = None,
  1144. output_hidden_states: bool | None = None,
  1145. return_dict: bool | None = None,
  1146. **kwargs,
  1147. ) -> tuple | BaseModelOutput:
  1148. hidden_states = self.prenet(input_values)
  1149. outputs = self.wrapped_encoder(
  1150. hidden_states=hidden_states,
  1151. attention_mask=attention_mask,
  1152. output_attentions=output_attentions,
  1153. output_hidden_states=output_hidden_states,
  1154. return_dict=return_dict,
  1155. )
  1156. return outputs
  1157. class SpeechT5EncoderWithoutPrenet(SpeechT5PreTrainedModel):
  1158. """
  1159. This wrapper class is a helper class to correctly load pretrained checkpoints when used in combination with
  1160. [`SpeechT5Model`].
  1161. """
  1162. def __init__(self, config: SpeechT5Config):
  1163. super().__init__(config)
  1164. self.wrapped_encoder = SpeechT5Encoder(config)
  1165. # Initialize weights and apply final processing
  1166. self.post_init()
  1167. def forward(
  1168. self,
  1169. input_values: torch.FloatTensor,
  1170. attention_mask: torch.Tensor | None = None,
  1171. output_attentions: bool | None = None,
  1172. output_hidden_states: bool | None = None,
  1173. return_dict: bool | None = None,
  1174. **kwargs,
  1175. ) -> tuple | BaseModelOutput:
  1176. return self.wrapped_encoder(
  1177. hidden_states=input_values,
  1178. attention_mask=attention_mask,
  1179. output_attentions=output_attentions,
  1180. output_hidden_states=output_hidden_states,
  1181. return_dict=return_dict,
  1182. )
  1183. class SpeechT5Decoder(SpeechT5PreTrainedModel):
  1184. """
  1185. Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`SpeechT5DecoderLayer`]
  1186. """
  1187. def __init__(self, config: SpeechT5Config):
  1188. super().__init__(config)
  1189. self.layerdrop = config.decoder_layerdrop
  1190. self.layers = nn.ModuleList([SpeechT5DecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)])
  1191. self.gradient_checkpointing = False
  1192. # Initialize weights and apply final processing
  1193. self.post_init()
  1194. def forward(
  1195. self,
  1196. hidden_states: torch.FloatTensor | None = None,
  1197. attention_mask: torch.LongTensor | None = None,
  1198. encoder_hidden_states: torch.FloatTensor | None = None,
  1199. encoder_attention_mask: torch.LongTensor | None = None,
  1200. past_key_values: Cache | None = None,
  1201. use_cache: bool | None = None,
  1202. output_attentions: bool | None = None,
  1203. output_hidden_states: bool | None = None,
  1204. return_dict: bool | None = None,
  1205. **kwargs,
  1206. ) -> tuple | BaseModelOutputWithPastAndCrossAttentions:
  1207. r"""
  1208. Args:
  1209. hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, feature_size)`):
  1210. Features extracted from the speech or text input by the decoder prenet.
  1211. attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1212. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  1213. - 1 for tokens that are **not masked**,
  1214. - 0 for tokens that are **masked**.
  1215. [What are attention masks?](../glossary#attention-mask)
  1216. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
  1217. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
  1218. of the decoder.
  1219. encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
  1220. Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
  1221. selected in `[0, 1]`:
  1222. - 1 for tokens that are **not masked**,
  1223. - 0 for tokens that are **masked**.
  1224. [What are attention masks?](../glossary#attention-mask)
  1225. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  1226. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  1227. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
  1228. cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
  1229. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
  1230. that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
  1231. all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
  1232. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  1233. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  1234. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  1235. than the model's internal embedding lookup matrix.
  1236. output_attentions (`bool`, *optional*):
  1237. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  1238. returned tensors for more detail.
  1239. output_hidden_states (`bool`, *optional*):
  1240. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  1241. for more detail.
  1242. return_dict (`bool`, *optional*):
  1243. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  1244. """
  1245. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1246. output_hidden_states = (
  1247. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1248. )
  1249. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1250. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1251. if self.gradient_checkpointing and self.training:
  1252. if use_cache:
  1253. logger.warning_once(
  1254. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  1255. )
  1256. use_cache = False
  1257. if use_cache and past_key_values is None:
  1258. past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  1259. attention_mask = create_causal_mask(
  1260. config=self.config,
  1261. inputs_embeds=hidden_states,
  1262. attention_mask=attention_mask,
  1263. past_key_values=past_key_values,
  1264. )
  1265. # expand encoder attention mask
  1266. if encoder_hidden_states is not None and encoder_attention_mask is not None:
  1267. encoder_attention_mask = create_bidirectional_mask(
  1268. config=self.config,
  1269. inputs_embeds=hidden_states,
  1270. attention_mask=encoder_attention_mask,
  1271. encoder_hidden_states=encoder_hidden_states,
  1272. )
  1273. synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
  1274. # decoder layers
  1275. all_hidden_states = () if output_hidden_states else None
  1276. all_self_attentions = () if output_attentions else None
  1277. all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
  1278. for idx, decoder_layer in enumerate(self.layers):
  1279. if output_hidden_states:
  1280. all_hidden_states = all_hidden_states + (hidden_states,)
  1281. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  1282. skip_the_layer = False
  1283. if self.training:
  1284. dropout_probability = torch.rand([])
  1285. skip_the_layer = dropout_probability < self.layerdrop
  1286. if skip_the_layer and not synced_gpus:
  1287. continue
  1288. layer_outputs = decoder_layer(
  1289. hidden_states,
  1290. attention_mask,
  1291. encoder_hidden_states, # as a positional argument for gradient checkpointing
  1292. encoder_attention_mask=encoder_attention_mask,
  1293. past_key_values=past_key_values,
  1294. output_attentions=output_attentions,
  1295. use_cache=use_cache,
  1296. )
  1297. hidden_states = layer_outputs[0]
  1298. if output_attentions:
  1299. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  1300. if encoder_hidden_states is not None:
  1301. all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
  1302. if output_hidden_states:
  1303. all_hidden_states = all_hidden_states + (hidden_states,)
  1304. if not return_dict:
  1305. return tuple(
  1306. v
  1307. for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions]
  1308. if v is not None
  1309. )
  1310. return BaseModelOutputWithPastAndCrossAttentions(
  1311. last_hidden_state=hidden_states,
  1312. past_key_values=past_key_values,
  1313. hidden_states=all_hidden_states,
  1314. attentions=all_self_attentions,
  1315. cross_attentions=all_cross_attentions,
  1316. )
  1317. class SpeechT5DecoderWithSpeechPrenet(SpeechT5PreTrainedModel):
  1318. """
  1319. Wrapper around SpeechT5Decoder that applies SpeechT5SpeechDecoderPrenet to convert log-mel filterbanks to hidden
  1320. features.
  1321. """
  1322. def __init__(self, config: SpeechT5Config):
  1323. super().__init__(config)
  1324. self.prenet = SpeechT5SpeechDecoderPrenet(config)
  1325. self.wrapped_decoder = SpeechT5Decoder(config)
  1326. # Initialize weights and apply final processing
  1327. self.post_init()
  1328. def forward(
  1329. self,
  1330. input_values: torch.FloatTensor | None = None,
  1331. attention_mask: torch.LongTensor | None = None,
  1332. encoder_hidden_states: torch.FloatTensor | None = None,
  1333. encoder_attention_mask: torch.LongTensor | None = None,
  1334. speaker_embeddings: torch.Tensor | None = None,
  1335. past_key_values: Cache | None = None,
  1336. use_cache: bool | None = None,
  1337. output_attentions: bool | None = None,
  1338. output_hidden_states: bool | None = None,
  1339. return_dict: bool | None = None,
  1340. **kwargs,
  1341. ) -> tuple | BaseModelOutputWithPastAndCrossAttentions:
  1342. decoder_hidden_states = self.prenet(input_values, speaker_embeddings)
  1343. outputs = self.wrapped_decoder(
  1344. hidden_states=decoder_hidden_states,
  1345. attention_mask=attention_mask,
  1346. encoder_hidden_states=encoder_hidden_states,
  1347. encoder_attention_mask=encoder_attention_mask,
  1348. past_key_values=past_key_values,
  1349. use_cache=use_cache,
  1350. output_attentions=output_attentions,
  1351. output_hidden_states=output_hidden_states,
  1352. return_dict=return_dict,
  1353. )
  1354. return outputs
  1355. class SpeechT5DecoderWithTextPrenet(SpeechT5PreTrainedModel):
  1356. """
  1357. Wrapper around SpeechT5Decoder that applies SpeechT5TextDecoderPrenet to convert input tokens to hidden features.
  1358. """
  1359. def __init__(self, config: SpeechT5Config):
  1360. super().__init__(config)
  1361. self.prenet = SpeechT5TextDecoderPrenet(config)
  1362. self.wrapped_decoder = SpeechT5Decoder(config)
  1363. # Initialize weights and apply final processing
  1364. self.post_init()
  1365. def get_input_embeddings(self):
  1366. return self.prenet.get_input_embeddings()
  1367. def set_input_embeddings(self, value):
  1368. self.prenet.set_input_embeddings(value)
  1369. def forward(
  1370. self,
  1371. input_values: torch.FloatTensor | None = None,
  1372. attention_mask: torch.LongTensor | None = None,
  1373. encoder_hidden_states: torch.FloatTensor | None = None,
  1374. encoder_attention_mask: torch.LongTensor | None = None,
  1375. past_key_values: Cache | None = None,
  1376. use_cache: bool | None = None,
  1377. output_attentions: bool | None = None,
  1378. output_hidden_states: bool | None = None,
  1379. return_dict: bool | None = None,
  1380. **kwargs,
  1381. ) -> tuple | BaseModelOutputWithPastAndCrossAttentions:
  1382. decoder_hidden_states, attention_mask = self.prenet(input_values, attention_mask, past_key_values)
  1383. outputs = self.wrapped_decoder(
  1384. hidden_states=decoder_hidden_states,
  1385. attention_mask=attention_mask,
  1386. encoder_hidden_states=encoder_hidden_states,
  1387. encoder_attention_mask=encoder_attention_mask,
  1388. past_key_values=past_key_values,
  1389. use_cache=use_cache,
  1390. output_attentions=output_attentions,
  1391. output_hidden_states=output_hidden_states,
  1392. return_dict=return_dict,
  1393. )
  1394. return outputs
  1395. class SpeechT5DecoderWithoutPrenet(SpeechT5PreTrainedModel):
  1396. """
  1397. This wrapper class is a helper class to correctly load pretrained checkpoints when used in combination with
  1398. [`SpeechT5Model`].
  1399. """
  1400. def __init__(self, config: SpeechT5Config):
  1401. super().__init__(config)
  1402. self.wrapped_decoder = SpeechT5Decoder(config)
  1403. # Initialize weights and apply final processing
  1404. self.post_init()
  1405. def forward(
  1406. self,
  1407. input_values: torch.FloatTensor | None = None,
  1408. attention_mask: torch.LongTensor | None = None,
  1409. encoder_hidden_states: torch.FloatTensor | None = None,
  1410. encoder_attention_mask: torch.LongTensor | None = None,
  1411. past_key_values: Cache | None = None,
  1412. use_cache: bool | None = None,
  1413. output_attentions: bool | None = None,
  1414. output_hidden_states: bool | None = None,
  1415. return_dict: bool | None = None,
  1416. **kwargs,
  1417. ) -> tuple | BaseModelOutputWithPastAndCrossAttentions:
  1418. outputs = self.wrapped_decoder(
  1419. hidden_states=input_values,
  1420. attention_mask=attention_mask,
  1421. encoder_hidden_states=encoder_hidden_states,
  1422. encoder_attention_mask=encoder_attention_mask,
  1423. past_key_values=past_key_values,
  1424. use_cache=use_cache,
  1425. output_attentions=output_attentions,
  1426. output_hidden_states=output_hidden_states,
  1427. return_dict=return_dict,
  1428. )
  1429. return outputs
  1430. class SpeechT5GuidedMultiheadAttentionLoss(nn.Module):
  1431. """
  1432. Guided attention loss from the paper [Efficiently Trainable Text-to-Speech System Based on Deep Convolutional
  1433. Networks with Guided Attention](https://huggingface.co/papers/1710.08969), adapted for multi-head attention.
  1434. """
  1435. def __init__(self, config: SpeechT5Config):
  1436. super().__init__()
  1437. self.sigma = config.guided_attention_loss_sigma
  1438. self.scale = config.guided_attention_loss_scale
  1439. def forward(
  1440. self, attentions: torch.FloatTensor, input_masks: torch.BoolTensor, output_masks: torch.BoolTensor
  1441. ) -> torch.Tensor:
  1442. """
  1443. Compute the attention loss.
  1444. Args:
  1445. attentions (`torch.FloatTensor` of shape `(batch_size, layers * heads, output_sequence_length, input_sequence_length)`):
  1446. Batch of multi-head attention weights
  1447. input_masks (`torch.BoolTensor` of shape `(batch_size, input_sequence_length)`):
  1448. Input attention mask as booleans.
  1449. output_masks (`torch.BoolTensor` of shape `(batch_size, output_sequence_length)`):
  1450. Target attention mask as booleans.
  1451. Returns:
  1452. `torch.Tensor` with the loss value
  1453. """
  1454. guided_attn_masks = self._make_guided_attention_masks(input_masks, output_masks, attentions.device)
  1455. masks = output_masks.unsqueeze(-1) & input_masks.unsqueeze(-2)
  1456. masks = masks.to(attentions.device).unsqueeze(1)
  1457. losses = guided_attn_masks * attentions
  1458. loss = torch.mean(losses.masked_select(masks))
  1459. return self.scale * loss
  1460. def _make_guided_attention_masks(self, input_masks, output_masks, device):
  1461. input_lengths = input_masks.sum(-1)
  1462. output_lengths = output_masks.sum(-1)
  1463. guided_attn_masks = torch.zeros((len(input_masks), output_masks.shape[1], input_masks.shape[1]), device=device)
  1464. for idx, (ilen, olen) in enumerate(zip(input_lengths, output_lengths)):
  1465. guided_attn_masks[idx, :olen, :ilen] = self._make_guided_attention_mask(ilen, olen, self.sigma, device)
  1466. return guided_attn_masks.unsqueeze(1)
  1467. @staticmethod
  1468. def _make_guided_attention_mask(input_length, output_length, sigma, device):
  1469. grid_y, grid_x = torch.meshgrid(
  1470. torch.arange(input_length, device=device),
  1471. torch.arange(output_length, device=device),
  1472. indexing="xy",
  1473. )
  1474. grid_x = grid_x.float() / output_length
  1475. grid_y = grid_y.float() / input_length
  1476. return 1.0 - torch.exp(-((grid_y - grid_x) ** 2) / (2 * (sigma**2)))
  1477. class SpeechT5SpectrogramLoss(nn.Module):
  1478. """
  1479. Loss computation used by SpeechT5ForTextToSpeech.
  1480. """
  1481. def __init__(self, config: SpeechT5Config):
  1482. super().__init__()
  1483. self.use_guided_attention_loss = config.use_guided_attention_loss
  1484. self.guided_attention_loss_num_heads = config.guided_attention_loss_num_heads
  1485. self.reduction_factor = config.reduction_factor
  1486. self.l1_criterion = L1Loss()
  1487. self.bce_criterion = BCEWithLogitsLoss(pos_weight=torch.tensor(5.0))
  1488. if self.use_guided_attention_loss:
  1489. self.attn_criterion = SpeechT5GuidedMultiheadAttentionLoss(config)
  1490. def forward(
  1491. self,
  1492. attention_mask: torch.LongTensor,
  1493. outputs_before_postnet: torch.FloatTensor,
  1494. outputs_after_postnet: torch.FloatTensor,
  1495. logits: torch.FloatTensor,
  1496. labels: torch.FloatTensor,
  1497. cross_attentions: torch.FloatTensor | None = None,
  1498. ) -> torch.Tensor:
  1499. padding_mask = labels != -100.0
  1500. # mask out the padded portions
  1501. labels = labels.masked_select(padding_mask)
  1502. outputs_before_postnet = outputs_before_postnet.masked_select(padding_mask)
  1503. outputs_after_postnet = outputs_after_postnet.masked_select(padding_mask)
  1504. # spectrogram loss
  1505. l1_loss = self.l1_criterion(outputs_after_postnet, labels) + self.l1_criterion(outputs_before_postnet, labels)
  1506. # construct stop labels from the padding mask
  1507. masks = padding_mask[:, :, 0]
  1508. stop_labels = torch.cat([~masks * 1.0, torch.ones(masks.size(0), 1).to(masks.device)], dim=1)
  1509. stop_labels = stop_labels[:, 1:].masked_select(masks)
  1510. logits = logits.masked_select(masks)
  1511. # stop token loss
  1512. bce_loss = self.bce_criterion(logits, stop_labels)
  1513. # combined loss
  1514. loss = l1_loss + bce_loss
  1515. # guided attention loss
  1516. if self.use_guided_attention_loss:
  1517. attn = torch.cat([x[:, : self.guided_attention_loss_num_heads] for x in cross_attentions], dim=1)
  1518. input_masks = attention_mask == 1
  1519. output_masks = padding_mask[:, :, 0]
  1520. if self.reduction_factor > 1:
  1521. output_masks = output_masks[:, self.reduction_factor - 1 :: self.reduction_factor]
  1522. attn_loss = self.attn_criterion(attn, input_masks, output_masks)
  1523. loss += attn_loss
  1524. return loss
  1525. @auto_docstring(
  1526. custom_intro="""
  1527. The bare SpeechT5 Encoder-Decoder Model outputting raw hidden-states without any specific pre- or post-nets.
  1528. """
  1529. )
  1530. class SpeechT5Model(SpeechT5PreTrainedModel):
  1531. def __init__(
  1532. self,
  1533. config: SpeechT5Config,
  1534. encoder: nn.Module | None = None,
  1535. decoder: nn.Module | None = None,
  1536. ):
  1537. r"""
  1538. encoder (`PreTrainedModel`, *optional*):
  1539. The encoder model to use.
  1540. decoder (`PreTrainedModel`, *optional*):
  1541. The decoder model to use.
  1542. """
  1543. super().__init__(config)
  1544. self.config = config
  1545. self.encoder = SpeechT5EncoderWithoutPrenet(config) if encoder is None else encoder
  1546. self.decoder = SpeechT5DecoderWithoutPrenet(config) if decoder is None else decoder
  1547. # Initialize weights and apply final processing
  1548. self.post_init()
  1549. def get_input_embeddings(self):
  1550. if isinstance(self.encoder, SpeechT5EncoderWithTextPrenet):
  1551. return self.encoder.get_input_embeddings()
  1552. if isinstance(self.decoder, SpeechT5DecoderWithTextPrenet):
  1553. return self.decoder.get_input_embeddings()
  1554. raise NotImplementedError
  1555. def set_input_embeddings(self, value):
  1556. if isinstance(self.encoder, SpeechT5EncoderWithTextPrenet):
  1557. self.encoder.set_input_embeddings(value)
  1558. if isinstance(self.decoder, SpeechT5DecoderWithTextPrenet):
  1559. self.decoder.set_input_embeddings(value)
  1560. def freeze_feature_encoder(self):
  1561. """
  1562. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1563. not be updated during training.
  1564. """
  1565. if isinstance(self.encoder, SpeechT5EncoderWithSpeechPrenet):
  1566. self.encoder.prenet.freeze_feature_encoder()
  1567. @auto_docstring
  1568. def forward(
  1569. self,
  1570. input_values: torch.Tensor | None = None,
  1571. attention_mask: torch.LongTensor | None = None,
  1572. decoder_input_values: torch.Tensor | None = None,
  1573. decoder_attention_mask: torch.LongTensor | None = None,
  1574. encoder_outputs: tuple[tuple[torch.FloatTensor]] | None = None,
  1575. past_key_values: Cache | None = None,
  1576. use_cache: bool | None = None,
  1577. speaker_embeddings: torch.FloatTensor | None = None,
  1578. output_attentions: bool | None = None,
  1579. output_hidden_states: bool | None = None,
  1580. return_dict: bool | None = None,
  1581. **kwargs,
  1582. ) -> tuple[torch.FloatTensor] | Seq2SeqModelOutput:
  1583. r"""
  1584. input_values (`torch.Tensor` of shape `(batch_size, sequence_length)`):
  1585. Depending on which encoder is being used, the `input_values` are either: float values of the input raw
  1586. speech waveform, or indices of input sequence tokens in the vocabulary, or hidden states.
  1587. decoder_input_values (`torch.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1588. Depending on which decoder is being used, the `decoder_input_values` are either: float values of log-mel
  1589. filterbank features extracted from the raw speech waveform, or indices of decoder input sequence tokens in
  1590. the vocabulary, or hidden states.
  1591. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1592. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_values`. Causal mask will
  1593. also be used by default.
  1594. If you want to change padding behavior, you should read [`SpeechT5Decoder._prepare_decoder_attention_mask`]
  1595. and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
  1596. information on the default strategy.
  1597. speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*):
  1598. Tensor containing the speaker embeddings.
  1599. """
  1600. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1601. output_hidden_states = (
  1602. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1603. )
  1604. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1605. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1606. # Encode if needed (training, first prediction pass)
  1607. if encoder_outputs is None:
  1608. encoder_outputs = self.encoder(
  1609. input_values=input_values,
  1610. attention_mask=attention_mask,
  1611. output_attentions=output_attentions,
  1612. output_hidden_states=output_hidden_states,
  1613. return_dict=return_dict,
  1614. )
  1615. # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
  1616. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  1617. encoder_outputs = BaseModelOutput(
  1618. last_hidden_state=encoder_outputs[0],
  1619. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  1620. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  1621. )
  1622. # downsample encoder attention mask (only for encoders with speech input)
  1623. if attention_mask is not None and isinstance(self.encoder, SpeechT5EncoderWithSpeechPrenet):
  1624. encoder_attention_mask = self.encoder.prenet._get_feature_vector_attention_mask(
  1625. encoder_outputs[0].shape[1], attention_mask
  1626. )
  1627. else:
  1628. encoder_attention_mask = attention_mask
  1629. if isinstance(self.decoder, SpeechT5DecoderWithSpeechPrenet):
  1630. decoder_args = {"speaker_embeddings": speaker_embeddings}
  1631. else:
  1632. decoder_args = {}
  1633. decoder_outputs = self.decoder(
  1634. input_values=decoder_input_values,
  1635. attention_mask=decoder_attention_mask,
  1636. encoder_hidden_states=encoder_outputs[0],
  1637. encoder_attention_mask=encoder_attention_mask,
  1638. past_key_values=past_key_values,
  1639. use_cache=use_cache,
  1640. output_attentions=output_attentions,
  1641. output_hidden_states=output_hidden_states,
  1642. return_dict=return_dict,
  1643. **decoder_args,
  1644. )
  1645. if not return_dict:
  1646. return decoder_outputs + encoder_outputs
  1647. return Seq2SeqModelOutput(
  1648. last_hidden_state=decoder_outputs.last_hidden_state,
  1649. past_key_values=decoder_outputs.past_key_values,
  1650. decoder_hidden_states=decoder_outputs.hidden_states,
  1651. decoder_attentions=decoder_outputs.attentions,
  1652. cross_attentions=decoder_outputs.cross_attentions,
  1653. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1654. encoder_hidden_states=encoder_outputs.hidden_states,
  1655. encoder_attentions=encoder_outputs.attentions,
  1656. )
  1657. @auto_docstring(
  1658. custom_intro="""
  1659. SpeechT5 Model with a speech encoder and a text decoder.
  1660. """
  1661. )
  1662. class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel, GenerationMixin):
  1663. _tied_weights_keys = {"text_decoder_postnet.lm_head.weight": "speecht5.decoder.prenet.embed_tokens.weight"}
  1664. def __init__(self, config: SpeechT5Config):
  1665. super().__init__(config)
  1666. if config.vocab_size is None:
  1667. raise ValueError(
  1668. f"You are trying to instantiate {self.__class__} with a configuration that does not define the"
  1669. " vocabulary size of the language model head. Please instantiate the model as follows:"
  1670. " `SpeechT5ForSpeechToText.from_pretrained(..., vocab_size=vocab_size)`. or define `vocab_size` of"
  1671. " your model's configuration."
  1672. )
  1673. speech_encoder = SpeechT5EncoderWithSpeechPrenet(config)
  1674. text_decoder = SpeechT5DecoderWithTextPrenet(config)
  1675. self.speecht5 = SpeechT5Model(config, speech_encoder, text_decoder)
  1676. self.text_decoder_postnet = SpeechT5TextDecoderPostnet(config)
  1677. # Initialize weights and apply final processing
  1678. self.post_init()
  1679. def freeze_feature_encoder(self):
  1680. """
  1681. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1682. not be updated during training.
  1683. """
  1684. self.get_encoder().prenet.freeze_feature_encoder()
  1685. def get_output_embeddings(self):
  1686. return self.text_decoder_postnet.get_output_embeddings()
  1687. def set_output_embeddings(self, new_embeddings):
  1688. self.text_decoder_postnet.set_output_embeddings(new_embeddings)
  1689. @auto_docstring
  1690. def forward(
  1691. self,
  1692. input_values: torch.FloatTensor | None = None,
  1693. attention_mask: torch.LongTensor | None = None,
  1694. decoder_input_ids: torch.LongTensor | None = None,
  1695. decoder_attention_mask: torch.LongTensor | None = None,
  1696. encoder_outputs: tuple[tuple[torch.FloatTensor]] | None = None,
  1697. past_key_values: Cache | None = None,
  1698. use_cache: bool | None = None,
  1699. output_attentions: bool | None = None,
  1700. output_hidden_states: bool | None = None,
  1701. return_dict: bool | None = None,
  1702. labels: torch.LongTensor | None = None,
  1703. **kwargs,
  1704. ) -> tuple | Seq2SeqLMOutput:
  1705. r"""
  1706. input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
  1707. Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
  1708. into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
  1709. (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
  1710. To prepare the array into `input_values`, the [`SpeechT5Processor`] should be used for padding
  1711. and conversion into a tensor of type `torch.FloatTensor`. See [`SpeechT5Processor.__call__`] for details.
  1712. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1713. Indices of decoder input sequence tokens in the vocabulary.
  1714. Indices can be obtained using [`SpeechT5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1715. [`PreTrainedTokenizer.__call__`] for details.
  1716. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1717. SpeechT5 uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If
  1718. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  1719. `past_key_values`).
  1720. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1721. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_values`. Causal mask will
  1722. also be used by default.
  1723. If you want to change padding behavior, you should read [`SpeechT5Decoder._prepare_decoder_attention_mask`]
  1724. and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
  1725. information on the default strategy.
  1726. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1727. Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
  1728. or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
  1729. only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  1730. Label indices can be obtained using [`SpeechT5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1731. [`PreTrainedTokenizer.__call__`] for details.
  1732. Example:
  1733. ```python
  1734. >>> from transformers import SpeechT5Processor, SpeechT5ForSpeechToText
  1735. >>> from datasets import load_dataset
  1736. >>> dataset = load_dataset(
  1737. ... "hf-internal-testing/librispeech_asr_demo", "clean", split="validation"
  1738. ... ) # doctest: +IGNORE_RESULT
  1739. >>> dataset = dataset.sort("id")
  1740. >>> sampling_rate = dataset.features["audio"].sampling_rate
  1741. >>> processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_asr")
  1742. >>> model = SpeechT5ForSpeechToText.from_pretrained("microsoft/speecht5_asr")
  1743. >>> # audio file is decoded on the fly
  1744. >>> inputs = processor(audio=dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
  1745. >>> predicted_ids = model.generate(**inputs, max_length=100)
  1746. >>> # transcribe speech
  1747. >>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
  1748. >>> transcription[0]
  1749. 'mister quilter is the apostle of the middle classes and we are glad to welcome his gospel'
  1750. ```
  1751. ```python
  1752. >>> inputs["labels"] = processor(text_target=dataset[0]["text"], return_tensors="pt").input_ids
  1753. >>> # compute loss
  1754. >>> loss = model(**inputs).loss
  1755. >>> round(loss.item(), 2)
  1756. 19.68
  1757. ```
  1758. """
  1759. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1760. if labels is not None:
  1761. if decoder_input_ids is None:
  1762. decoder_input_ids = shift_tokens_right(
  1763. labels, self.config.pad_token_id, self.config.decoder_start_token_id
  1764. )
  1765. outputs = self.speecht5(
  1766. input_values=input_values,
  1767. attention_mask=attention_mask,
  1768. decoder_input_values=decoder_input_ids,
  1769. decoder_attention_mask=decoder_attention_mask,
  1770. encoder_outputs=encoder_outputs,
  1771. past_key_values=past_key_values,
  1772. use_cache=use_cache,
  1773. output_attentions=output_attentions,
  1774. output_hidden_states=output_hidden_states,
  1775. return_dict=True,
  1776. )
  1777. logits = self.text_decoder_postnet(outputs[0])
  1778. loss = None
  1779. if labels is not None:
  1780. loss_fct = CrossEntropyLoss()
  1781. loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
  1782. if not return_dict:
  1783. output = (logits,) + outputs[1:]
  1784. return ((loss,) + output) if loss is not None else output
  1785. return Seq2SeqLMOutput(
  1786. loss=loss,
  1787. logits=logits,
  1788. past_key_values=outputs.past_key_values,
  1789. decoder_hidden_states=outputs.decoder_hidden_states,
  1790. decoder_attentions=outputs.decoder_attentions,
  1791. cross_attentions=outputs.cross_attentions,
  1792. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  1793. encoder_hidden_states=outputs.encoder_hidden_states,
  1794. encoder_attentions=outputs.encoder_attentions,
  1795. )
  1796. def _generate_speech(
  1797. model: SpeechT5PreTrainedModel,
  1798. input_values: torch.FloatTensor,
  1799. speaker_embeddings: torch.FloatTensor | None = None,
  1800. attention_mask: torch.LongTensor | None = None,
  1801. threshold: float = 0.5,
  1802. minlenratio: float = 0.0,
  1803. maxlenratio: float = 20.0,
  1804. vocoder: nn.Module | None = None,
  1805. output_cross_attentions: bool = False,
  1806. return_output_lengths: bool = False,
  1807. ) -> torch.FloatTensor | tuple[torch.FloatTensor, torch.FloatTensor]:
  1808. if speaker_embeddings is None:
  1809. raise ValueError(
  1810. """`speaker_embeddings` must be specified. For example, you can use a speaker embeddings by following
  1811. the code snippet provided in this link:
  1812. https://huggingface.co/datasets/Matthijs/cmu-arctic-xvectors
  1813. """
  1814. )
  1815. if attention_mask is None:
  1816. encoder_attention_mask = 1 - (input_values == model.config.pad_token_id).int()
  1817. else:
  1818. encoder_attention_mask = attention_mask
  1819. bsz = input_values.size(0)
  1820. encoder_out = model.speecht5.encoder(
  1821. input_values=input_values,
  1822. attention_mask=encoder_attention_mask,
  1823. return_dict=True,
  1824. )
  1825. encoder_last_hidden_state = encoder_out.last_hidden_state
  1826. # downsample encoder attention mask
  1827. if isinstance(model.speecht5.encoder, SpeechT5EncoderWithSpeechPrenet):
  1828. encoder_attention_mask = model.speecht5.encoder.prenet._get_feature_vector_attention_mask(
  1829. encoder_out[0].shape[1], encoder_attention_mask
  1830. )
  1831. maxlen = int(encoder_last_hidden_state.size(1) * maxlenratio / model.config.reduction_factor)
  1832. minlen = int(encoder_last_hidden_state.size(1) * minlenratio / model.config.reduction_factor)
  1833. # Start the output sequence with a mel spectrum that is all zeros.
  1834. output_sequence = encoder_last_hidden_state.new_zeros(bsz, 1, model.config.num_mel_bins)
  1835. spectrogram = []
  1836. cross_attentions = []
  1837. past_key_values = None
  1838. idx = 0
  1839. result_spectrogram = {}
  1840. while True:
  1841. idx += 1
  1842. # Run the decoder prenet on the entire output sequence.
  1843. decoder_hidden_states = model.speecht5.decoder.prenet(output_sequence, speaker_embeddings)
  1844. # Run the decoder layers on the last element of the prenet output.
  1845. decoder_out = model.speecht5.decoder.wrapped_decoder(
  1846. hidden_states=decoder_hidden_states[:, -1:],
  1847. attention_mask=None,
  1848. encoder_hidden_states=encoder_last_hidden_state,
  1849. encoder_attention_mask=encoder_attention_mask,
  1850. past_key_values=past_key_values,
  1851. use_cache=True,
  1852. output_attentions=output_cross_attentions,
  1853. return_dict=True,
  1854. )
  1855. if output_cross_attentions:
  1856. cross_attentions.append(torch.cat(decoder_out.cross_attentions, dim=0))
  1857. last_decoder_output = decoder_out.last_hidden_state.squeeze(1)
  1858. past_key_values = decoder_out.past_key_values
  1859. # Predict the new mel spectrum for this step in the sequence.
  1860. spectrum = model.speech_decoder_postnet.feat_out(last_decoder_output)
  1861. spectrum = spectrum.view(bsz, model.config.reduction_factor, model.config.num_mel_bins)
  1862. spectrogram.append(spectrum)
  1863. # Extend the output sequence with the new mel spectrum.
  1864. new_spectrogram = spectrum[:, -1, :].view(bsz, 1, model.config.num_mel_bins)
  1865. output_sequence = torch.cat((output_sequence, new_spectrogram), dim=1)
  1866. # Predict the probability that this is the stop token.
  1867. prob = torch.sigmoid(model.speech_decoder_postnet.prob_out(last_decoder_output))
  1868. if idx < minlen:
  1869. continue
  1870. else:
  1871. # If the generation loop is less than maximum length time, check the ones in the batch that have met
  1872. # the prob threshold. Otherwise, assume all have met thresholds and fill other spectrograms for the batch.
  1873. if idx < maxlen:
  1874. meet_thresholds = torch.sum(prob, dim=-1) >= threshold
  1875. meet_indexes = torch.where(meet_thresholds)[0].tolist()
  1876. else:
  1877. meet_indexes = range(len(prob))
  1878. meet_indexes = [i for i in meet_indexes if i not in result_spectrogram]
  1879. if len(meet_indexes) > 0:
  1880. spectrograms = torch.stack(spectrogram)
  1881. spectrograms = spectrograms.transpose(0, 1).flatten(1, 2)
  1882. spectrograms = model.speech_decoder_postnet.postnet(spectrograms)
  1883. for meet_index in meet_indexes:
  1884. result_spectrogram[meet_index] = spectrograms[meet_index]
  1885. if len(result_spectrogram) >= bsz:
  1886. break
  1887. spectrograms = [result_spectrogram[i] for i in range(len(result_spectrogram))]
  1888. if not return_output_lengths:
  1889. spectrogram = spectrograms[0] if bsz == 1 else torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
  1890. if vocoder is not None:
  1891. outputs = vocoder(spectrogram)
  1892. else:
  1893. outputs = spectrogram
  1894. if output_cross_attentions:
  1895. cross_attentions = torch.cat(cross_attentions, dim=2)
  1896. if bsz > 1:
  1897. cross_attentions = cross_attentions.view(
  1898. bsz, int(cross_attentions.size(0) / bsz), *cross_attentions.size()[-3:]
  1899. )
  1900. outputs = (outputs, cross_attentions)
  1901. else:
  1902. # batched return values should also include the spectrogram/waveform lengths
  1903. spectrogram_lengths = []
  1904. for i in range(bsz):
  1905. spectrogram_lengths.append(spectrograms[i].size(0))
  1906. if vocoder is None:
  1907. spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
  1908. outputs = (spectrograms, spectrogram_lengths)
  1909. else:
  1910. waveforms = []
  1911. spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
  1912. waveforms = vocoder(spectrograms)
  1913. waveform_lengths = [int(waveforms.size(1) / max(spectrogram_lengths)) * i for i in spectrogram_lengths]
  1914. outputs = (waveforms, waveform_lengths)
  1915. if output_cross_attentions:
  1916. cross_attentions = torch.cat(cross_attentions, dim=2)
  1917. cross_attentions = cross_attentions.view(
  1918. bsz, int(cross_attentions.size(0) / bsz), *cross_attentions.size()[-3:]
  1919. )
  1920. outputs = (*outputs, cross_attentions)
  1921. return outputs
  1922. @auto_docstring(
  1923. custom_intro="""
  1924. SpeechT5 Model with a text encoder and a speech decoder.
  1925. """
  1926. )
  1927. class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
  1928. input_modalities = ("text",)
  1929. main_input_name = "input_ids"
  1930. def __init__(self, config: SpeechT5Config):
  1931. super().__init__(config)
  1932. if config.vocab_size is None:
  1933. raise ValueError(
  1934. f"You are trying to instantiate {self.__class__} with a configuration that does not define the"
  1935. " vocabulary size of the language model head. Please instantiate the model as follows:"
  1936. " `SpeechT5ForTextToSpeech.from_pretrained(..., vocab_size=vocab_size)`. or define `vocab_size` of"
  1937. " your model's configuration."
  1938. )
  1939. text_encoder = SpeechT5EncoderWithTextPrenet(config)
  1940. speech_decoder = SpeechT5DecoderWithSpeechPrenet(config)
  1941. self.speecht5 = SpeechT5Model(config, text_encoder, speech_decoder)
  1942. self.speech_decoder_postnet = SpeechT5SpeechDecoderPostnet(config)
  1943. # Initialize weights and apply final processing
  1944. self.post_init()
  1945. @classmethod
  1946. def can_generate(cls) -> bool:
  1947. # Speecht5 has a unique model structure, where the external class (`SpeechT5ForTextToSpeech`) doesn't need to inherit from
  1948. # `GenerationMixin` (it has a non-standard generation method). This means that the base `can_generate()` will return `False`,
  1949. # but we need to override it so as to do `GenerationConfig` handling in multiple parts of the codebase.
  1950. return True
  1951. @auto_docstring
  1952. def forward(
  1953. self,
  1954. input_ids: torch.LongTensor | None = None,
  1955. attention_mask: torch.LongTensor | None = None,
  1956. decoder_input_values: torch.FloatTensor | None = None,
  1957. decoder_attention_mask: torch.LongTensor | None = None,
  1958. encoder_outputs: tuple[tuple[torch.FloatTensor]] | None = None,
  1959. past_key_values: Cache | None = None,
  1960. use_cache: bool | None = None,
  1961. output_attentions: bool | None = None,
  1962. output_hidden_states: bool | None = None,
  1963. return_dict: bool | None = None,
  1964. speaker_embeddings: torch.FloatTensor | None = None,
  1965. labels: torch.FloatTensor | None = None,
  1966. stop_labels: torch.Tensor | None = None,
  1967. **kwargs,
  1968. ) -> tuple | Seq2SeqSpectrogramOutput:
  1969. r"""
  1970. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1971. Indices of input sequence tokens in the vocabulary.
  1972. Indices can be obtained using [`SpeechT5Tokenizer`]. See [`~PreTrainedTokenizer.encode`] and
  1973. [`~PreTrainedTokenizer.__call__`] for details.
  1974. [What are input IDs?](../glossary#input-ids)
  1975. decoder_input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_mel_bins)`):
  1976. Float values of input mel spectrogram.
  1977. SpeechT5 uses an all-zero spectrum as the starting token for `decoder_input_values` generation. If
  1978. `past_key_values` is used, optionally only the last `decoder_input_values` have to be input (see
  1979. `past_key_values`).
  1980. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1981. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_values`. Causal mask will
  1982. also be used by default.
  1983. If you want to change padding behavior, you should read [`SpeechT5Decoder._prepare_decoder_attention_mask`]
  1984. and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
  1985. information on the default strategy.
  1986. speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*):
  1987. Tensor containing the speaker embeddings.
  1988. labels (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_mel_bins)`, *optional*):
  1989. Float values of target mel spectrogram. Timesteps set to `-100.0` are ignored (masked) for the loss
  1990. computation. Spectrograms can be obtained using [`SpeechT5Processor`]. See [`SpeechT5Processor.__call__`]
  1991. for details.
  1992. stop_labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1993. Binary tensor indicating the position of the stop token in the sequence.
  1994. Example:
  1995. ```python
  1996. >>> from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan, set_seed
  1997. >>> import torch
  1998. >>> processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
  1999. >>> model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
  2000. >>> vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
  2001. >>> inputs = processor(text="Hello, my dog is cute", return_tensors="pt")
  2002. >>> speaker_embeddings = torch.zeros((1, 512)) # or load xvectors from a file
  2003. >>> set_seed(555) # make deterministic
  2004. >>> # generate speech
  2005. >>> speech = model.generate(inputs["input_ids"], speaker_embeddings=speaker_embeddings, vocoder=vocoder)
  2006. >>> speech.shape
  2007. torch.Size([15872])
  2008. ```
  2009. """
  2010. return_dict = return_dict if return_dict is not None else self.config.return_dict
  2011. if labels is not None:
  2012. if decoder_input_values is None:
  2013. decoder_input_values, decoder_attention_mask = shift_spectrograms_right(
  2014. labels, self.config.reduction_factor, decoder_attention_mask
  2015. )
  2016. if self.config.use_guided_attention_loss:
  2017. output_attentions = True
  2018. outputs = self.speecht5(
  2019. input_values=input_ids,
  2020. attention_mask=attention_mask,
  2021. decoder_input_values=decoder_input_values,
  2022. decoder_attention_mask=decoder_attention_mask,
  2023. encoder_outputs=encoder_outputs,
  2024. past_key_values=past_key_values,
  2025. use_cache=use_cache,
  2026. speaker_embeddings=speaker_embeddings,
  2027. output_attentions=output_attentions,
  2028. output_hidden_states=output_hidden_states,
  2029. return_dict=True,
  2030. )
  2031. outputs_before_postnet, outputs_after_postnet, logits = self.speech_decoder_postnet(outputs[0])
  2032. loss = None
  2033. if labels is not None:
  2034. criterion = SpeechT5SpectrogramLoss(self.config)
  2035. loss = criterion(
  2036. attention_mask,
  2037. outputs_before_postnet,
  2038. outputs_after_postnet,
  2039. logits,
  2040. labels,
  2041. outputs.cross_attentions,
  2042. )
  2043. if not return_dict:
  2044. output = (outputs_after_postnet,) + outputs[1:]
  2045. return ((loss,) + output) if loss is not None else output
  2046. return Seq2SeqSpectrogramOutput(
  2047. loss=loss,
  2048. spectrogram=outputs_after_postnet,
  2049. past_key_values=outputs.past_key_values,
  2050. decoder_hidden_states=outputs.decoder_hidden_states,
  2051. decoder_attentions=outputs.decoder_attentions,
  2052. cross_attentions=outputs.cross_attentions,
  2053. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  2054. encoder_hidden_states=outputs.encoder_hidden_states,
  2055. encoder_attentions=outputs.encoder_attentions,
  2056. )
  2057. @torch.no_grad()
  2058. def generate(
  2059. self,
  2060. input_ids: torch.LongTensor,
  2061. attention_mask: torch.LongTensor | None = None,
  2062. speaker_embeddings: torch.FloatTensor | None = None,
  2063. threshold: float = 0.5,
  2064. minlenratio: float = 0.0,
  2065. maxlenratio: float = 20.0,
  2066. vocoder: nn.Module | None = None,
  2067. output_cross_attentions: bool = False,
  2068. return_output_lengths: bool = False,
  2069. **kwargs,
  2070. ) -> torch.FloatTensor | tuple[torch.FloatTensor, torch.FloatTensor]:
  2071. r"""
  2072. Converts a sequence of input tokens into a sequence of mel spectrograms, which are subsequently turned into a
  2073. speech waveform using a vocoder.
  2074. Args:
  2075. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  2076. Indices of input sequence tokens in the vocabulary.
  2077. Indices can be obtained using [`SpeechT5Tokenizer`]. See [`~PreTrainedTokenizer.encode`] and
  2078. [`~PreTrainedTokenizer.__call__`] for details.
  2079. [What are input IDs?](../glossary#input-ids)
  2080. attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  2081. Attention mask from the tokenizer, required for batched inference to signal to the model where to
  2082. ignore padded tokens from the input_ids.
  2083. speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*):
  2084. Tensor containing the speaker embeddings.
  2085. threshold (`float`, *optional*, defaults to 0.5):
  2086. The generated sequence ends when the predicted stop token probability exceeds this value.
  2087. minlenratio (`float`, *optional*, defaults to 0.0):
  2088. Used to calculate the minimum required length for the output sequence.
  2089. maxlenratio (`float`, *optional*, defaults to 20.0):
  2090. Used to calculate the maximum allowed length for the output sequence.
  2091. vocoder (`nn.Module`, *optional*):
  2092. The vocoder that converts the mel spectrogram into a speech waveform. If `None`, the output is the mel
  2093. spectrogram.
  2094. output_cross_attentions (`bool`, *optional*, defaults to `False`):
  2095. Whether or not to return the attentions tensors of the decoder's cross-attention layers.
  2096. return_output_lengths (`bool`, *optional*, defaults to `False`):
  2097. Whether or not to return the concrete spectrogram/waveform lengths.
  2098. Returns:
  2099. `tuple(torch.FloatTensor)` comprising various elements depending on the inputs:
  2100. - when `return_output_lengths` is False
  2101. - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
  2102. `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram.
  2103. - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
  2104. `(num_frames,)` -- The predicted speech waveform.
  2105. - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`)
  2106. `torch.FloatTensor` of shape `(config.decoder_layers, config.decoder_attention_heads,
  2107. output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
  2108. - when `return_output_lengths` is True
  2109. - **spectrograms** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
  2110. `(batch_size, output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrograms that
  2111. are padded to the maximum length.
  2112. - **spectrogram_lengths** (*optional*, returned when no `vocoder` is provided) `list[Int]` -- A list of
  2113. all the concrete lengths for each spectrogram.
  2114. - **waveforms** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
  2115. `(batch_size, num_frames)` -- The predicted speech waveforms that are padded to the maximum length.
  2116. - **waveform_lengths** (*optional*, returned when a `vocoder` is provided) `list[Int]` -- A list of all
  2117. the concrete lengths for each waveform.
  2118. - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`)
  2119. `torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.decoder_attention_heads,
  2120. output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
  2121. """
  2122. if speaker_embeddings is not None:
  2123. batch_size = input_ids.size(0)
  2124. if speaker_embeddings.size(0) != batch_size:
  2125. if speaker_embeddings.size(0) == 1:
  2126. speaker_embeddings = speaker_embeddings.repeat(batch_size, 1)
  2127. else:
  2128. raise ValueError(
  2129. "The first dimension of speaker_embeddings must be either 1 or the same as batch_size."
  2130. )
  2131. return _generate_speech(
  2132. self,
  2133. input_ids,
  2134. speaker_embeddings,
  2135. attention_mask,
  2136. threshold,
  2137. minlenratio,
  2138. maxlenratio,
  2139. vocoder,
  2140. output_cross_attentions,
  2141. return_output_lengths,
  2142. )
  2143. @torch.no_grad()
  2144. def generate_speech(
  2145. self,
  2146. input_ids: torch.LongTensor,
  2147. speaker_embeddings: torch.FloatTensor | None = None,
  2148. attention_mask: torch.LongTensor | None = None,
  2149. threshold: float = 0.5,
  2150. minlenratio: float = 0.0,
  2151. maxlenratio: float = 20.0,
  2152. vocoder: nn.Module | None = None,
  2153. output_cross_attentions: bool = False,
  2154. return_output_lengths: bool = False,
  2155. ) -> torch.FloatTensor | tuple[torch.FloatTensor, torch.FloatTensor]:
  2156. r"""
  2157. Converts a sequence of input tokens into a sequence of mel spectrograms, which are subsequently turned into a
  2158. speech waveform using a vocoder.
  2159. Args:
  2160. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  2161. Indices of input sequence tokens in the vocabulary.
  2162. Indices can be obtained using [`SpeechT5Tokenizer`]. See [`~PreTrainedTokenizer.encode`] and
  2163. [`~PreTrainedTokenizer.__call__`] for details.
  2164. [What are input IDs?](../glossary#input-ids)
  2165. speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*):
  2166. Tensor containing the speaker embeddings.
  2167. attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  2168. Mask to avoid performing convolution and attention on padding token indices. Mask values selected in
  2169. `[0, 1]`:
  2170. - 1 for tokens that are **not masked**,
  2171. - 0 for tokens that are **masked**.
  2172. [What are attention masks?](../glossary#attention-mask)
  2173. threshold (`float`, *optional*, defaults to 0.5):
  2174. The generated sequence ends when the predicted stop token probability exceeds this value.
  2175. minlenratio (`float`, *optional*, defaults to 0.0):
  2176. Used to calculate the minimum required length for the output sequence.
  2177. maxlenratio (`float`, *optional*, defaults to 20.0):
  2178. Used to calculate the maximum allowed length for the output sequence.
  2179. vocoder (`nn.Module`, *optional*, defaults to `None`):
  2180. The vocoder that converts the mel spectrogram into a speech waveform. If `None`, the output is the mel
  2181. spectrogram.
  2182. output_cross_attentions (`bool`, *optional*, defaults to `False`):
  2183. Whether or not to return the attentions tensors of the decoder's cross-attention layers.
  2184. return_output_lengths (`bool`, *optional*, defaults to `False`):
  2185. Whether or not to return the concrete spectrogram/waveform lengths.
  2186. Returns:
  2187. `tuple(torch.FloatTensor)` comprising various elements depending on the inputs:
  2188. - when `return_output_lengths` is False
  2189. - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
  2190. `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram.
  2191. - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
  2192. `(num_frames,)` -- The predicted speech waveform.
  2193. - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`)
  2194. `torch.FloatTensor` of shape `(config.decoder_layers, config.decoder_attention_heads,
  2195. output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
  2196. - when `return_output_lengths` is True
  2197. - **spectrograms** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
  2198. `(batch_size, output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrograms that
  2199. are padded to the maximum length.
  2200. - **spectrogram_lengths** (*optional*, returned when no `vocoder` is provided) `list[Int]` -- A list of
  2201. all the concrete lengths for each spectrogram.
  2202. - **waveforms** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
  2203. `(batch_size, num_frames)` -- The predicted speech waveforms that are padded to the maximum length.
  2204. - **waveform_lengths** (*optional*, returned when a `vocoder` is provided) `list[Int]` -- A list of all
  2205. the concrete lengths for each waveform.
  2206. - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`)
  2207. `torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.decoder_attention_heads,
  2208. output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
  2209. """
  2210. if speaker_embeddings is not None:
  2211. batch_size = input_ids.size(0)
  2212. if speaker_embeddings.size(0) != batch_size:
  2213. if speaker_embeddings.size(0) == 1:
  2214. speaker_embeddings = speaker_embeddings.repeat(batch_size, 1)
  2215. else:
  2216. raise ValueError(
  2217. "The first dimension of speaker_embeddings must be either 1 or the same as batch size."
  2218. )
  2219. return _generate_speech(
  2220. self,
  2221. input_ids,
  2222. speaker_embeddings,
  2223. attention_mask,
  2224. threshold,
  2225. minlenratio,
  2226. maxlenratio,
  2227. vocoder,
  2228. output_cross_attentions,
  2229. return_output_lengths,
  2230. )
  2231. @auto_docstring(
  2232. custom_intro="""
  2233. SpeechT5 Model with a speech encoder and a speech decoder.
  2234. """
  2235. )
  2236. class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel):
  2237. def __init__(self, config: SpeechT5Config):
  2238. super().__init__(config)
  2239. speech_encoder = SpeechT5EncoderWithSpeechPrenet(config)
  2240. speech_decoder = SpeechT5DecoderWithSpeechPrenet(config)
  2241. self.speecht5 = SpeechT5Model(config, speech_encoder, speech_decoder)
  2242. self.speech_decoder_postnet = SpeechT5SpeechDecoderPostnet(config)
  2243. # Initialize weights and apply final processing
  2244. self.post_init()
  2245. def freeze_feature_encoder(self):
  2246. """
  2247. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  2248. not be updated during training.
  2249. """
  2250. self.get_encoder().prenet.freeze_feature_encoder()
  2251. @auto_docstring
  2252. def forward(
  2253. self,
  2254. input_values: torch.FloatTensor | None = None,
  2255. attention_mask: torch.LongTensor | None = None,
  2256. decoder_input_values: torch.FloatTensor | None = None,
  2257. decoder_attention_mask: torch.LongTensor | None = None,
  2258. encoder_outputs: tuple[tuple[torch.FloatTensor]] | None = None,
  2259. past_key_values: Cache | None = None,
  2260. use_cache: bool | None = None,
  2261. output_attentions: bool | None = None,
  2262. output_hidden_states: bool | None = None,
  2263. return_dict: bool | None = None,
  2264. speaker_embeddings: torch.FloatTensor | None = None,
  2265. labels: torch.FloatTensor | None = None,
  2266. stop_labels: torch.Tensor | None = None,
  2267. **kwargs,
  2268. ) -> tuple | Seq2SeqSpectrogramOutput:
  2269. r"""
  2270. input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
  2271. Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
  2272. into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
  2273. (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
  2274. To prepare the array into `input_values`, the [`SpeechT5Processor`] should be used for padding and conversion into
  2275. a tensor of type `torch.FloatTensor`. See [`SpeechT5Processor.__call__`] for details.
  2276. decoder_input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_mel_bins)`):
  2277. Float values of input mel spectrogram.
  2278. SpeechT5 uses an all-zero spectrum as the starting token for `decoder_input_values` generation. If
  2279. `past_key_values` is used, optionally only the last `decoder_input_values` have to be input (see
  2280. `past_key_values`).
  2281. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  2282. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_values`. Causal mask will
  2283. also be used by default.
  2284. If you want to change padding behavior, you should read [`SpeechT5Decoder._prepare_decoder_attention_mask`]
  2285. and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
  2286. information on the default strategy.
  2287. speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*):
  2288. Tensor containing the speaker embeddings.
  2289. labels (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_mel_bins)`, *optional*):
  2290. Float values of target mel spectrogram. Spectrograms can be obtained using [`SpeechT5Processor`]. See
  2291. [`SpeechT5Processor.__call__`] for details.
  2292. stop_labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  2293. Binary tensor indicating the position of the stop token in the sequence.
  2294. Example:
  2295. ```python
  2296. >>> from transformers import SpeechT5Processor, SpeechT5ForSpeechToSpeech, SpeechT5HifiGan, set_seed
  2297. >>> from datasets import load_dataset
  2298. >>> import torch
  2299. >>> dataset = load_dataset(
  2300. ... "hf-internal-testing/librispeech_asr_demo", "clean", split="validation"
  2301. ... ) # doctest: +IGNORE_RESULT
  2302. >>> dataset = dataset.sort("id")
  2303. >>> sampling_rate = dataset.features["audio"].sampling_rate
  2304. >>> processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_vc")
  2305. >>> model = SpeechT5ForSpeechToSpeech.from_pretrained("microsoft/speecht5_vc")
  2306. >>> vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
  2307. >>> # audio file is decoded on the fly
  2308. >>> inputs = processor(audio=dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
  2309. >>> speaker_embeddings = torch.zeros((1, 512)) # or load xvectors from a file
  2310. >>> set_seed(555) # make deterministic
  2311. >>> # generate speech
  2312. >>> speech = model.generate_speech(inputs["input_values"], speaker_embeddings, vocoder=vocoder)
  2313. >>> speech.shape
  2314. torch.Size([77824])
  2315. ```
  2316. """
  2317. return_dict = return_dict if return_dict is not None else self.config.return_dict
  2318. if labels is not None:
  2319. if decoder_input_values is None:
  2320. decoder_input_values, decoder_attention_mask = shift_spectrograms_right(
  2321. labels, self.config.reduction_factor, decoder_attention_mask
  2322. )
  2323. outputs = self.speecht5(
  2324. input_values=input_values,
  2325. attention_mask=attention_mask,
  2326. decoder_input_values=decoder_input_values,
  2327. decoder_attention_mask=decoder_attention_mask,
  2328. encoder_outputs=encoder_outputs,
  2329. past_key_values=past_key_values,
  2330. use_cache=use_cache,
  2331. speaker_embeddings=speaker_embeddings,
  2332. output_attentions=output_attentions,
  2333. output_hidden_states=output_hidden_states,
  2334. return_dict=True,
  2335. )
  2336. _, spectrogram, logits = self.speech_decoder_postnet(outputs[0])
  2337. loss = None
  2338. if not return_dict:
  2339. output = (spectrogram,) + outputs[1:]
  2340. return ((loss,) + output) if loss is not None else output
  2341. return Seq2SeqSpectrogramOutput(
  2342. loss=loss,
  2343. spectrogram=spectrogram,
  2344. past_key_values=outputs.past_key_values,
  2345. decoder_hidden_states=outputs.decoder_hidden_states,
  2346. decoder_attentions=outputs.decoder_attentions,
  2347. cross_attentions=outputs.cross_attentions,
  2348. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  2349. encoder_hidden_states=outputs.encoder_hidden_states,
  2350. encoder_attentions=outputs.encoder_attentions,
  2351. )
  2352. @torch.no_grad()
  2353. def generate_speech(
  2354. self,
  2355. input_values: torch.FloatTensor,
  2356. speaker_embeddings: torch.FloatTensor | None = None,
  2357. attention_mask: torch.LongTensor | None = None,
  2358. threshold: float = 0.5,
  2359. minlenratio: float = 0.0,
  2360. maxlenratio: float = 20.0,
  2361. vocoder: nn.Module | None = None,
  2362. output_cross_attentions: bool = False,
  2363. return_output_lengths: bool = False,
  2364. ) -> torch.FloatTensor:
  2365. r"""
  2366. Converts a raw speech waveform into a sequence of mel spectrograms, which are subsequently turned back into a
  2367. speech waveform using a vocoder.
  2368. Args:
  2369. input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
  2370. Float values of input raw speech waveform.
  2371. Values can be obtained by loading a *.flac* or *.wav* audio file into an array of type `list[float]`,
  2372. a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`)
  2373. or the soundfile library (`pip install soundfile`).
  2374. To prepare the array into `input_values`, the [`SpeechT5Processor`] should be used for padding and
  2375. conversion into a tensor of type `torch.FloatTensor`. See [`SpeechT5Processor.__call__`] for details.
  2376. speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*):
  2377. Tensor containing the speaker embeddings.
  2378. attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  2379. Mask to avoid performing convolution and attention on padding token indices. Mask values selected in
  2380. `[0, 1]`:
  2381. - 1 for tokens that are **not masked**,
  2382. - 0 for tokens that are **masked**.
  2383. [What are attention masks?](../glossary#attention-mask)
  2384. threshold (`float`, *optional*, defaults to 0.5):
  2385. The generated sequence ends when the predicted stop token probability exceeds this value.
  2386. minlenratio (`float`, *optional*, defaults to 0.0):
  2387. Used to calculate the minimum required length for the output sequence.
  2388. maxlenratio (`float`, *optional*, defaults to 20.0):
  2389. Used to calculate the maximum allowed length for the output sequence.
  2390. vocoder (`nn.Module`, *optional*, defaults to `None`):
  2391. The vocoder that converts the mel spectrogram into a speech waveform. If `None`, the output is the mel
  2392. spectrogram.
  2393. output_cross_attentions (`bool`, *optional*, defaults to `False`):
  2394. Whether or not to return the attentions tensors of the decoder's cross-attention layers.
  2395. return_output_lengths (`bool`, *optional*, defaults to `False`):
  2396. Whether or not to return the concrete spectrogram/waveform lengths.
  2397. Returns:
  2398. `tuple(torch.FloatTensor)` comprising various elements depending on the inputs:
  2399. - when `return_output_lengths` is False
  2400. - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
  2401. `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram.
  2402. - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
  2403. `(num_frames,)` -- The predicted speech waveform.
  2404. - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`)
  2405. `torch.FloatTensor` of shape `(config.decoder_layers, config.decoder_attention_heads,
  2406. output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
  2407. - when `return_output_lengths` is True
  2408. - **spectrograms** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
  2409. `(batch_size, output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrograms that
  2410. are padded to the maximum length.
  2411. - **spectrogram_lengths** (*optional*, returned when no `vocoder` is provided) `list[Int]` -- A list of
  2412. all the concrete lengths for each spectrogram.
  2413. - **waveforms** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
  2414. `(batch_size, num_frames)` -- The predicted speech waveforms that are padded to the maximum length.
  2415. - **waveform_lengths** (*optional*, returned when a `vocoder` is provided) `list[Int]` -- A list of all
  2416. the concrete lengths for each waveform.
  2417. - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`)
  2418. `torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.decoder_attention_heads,
  2419. output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
  2420. """
  2421. if speaker_embeddings is None:
  2422. speaker_embeddings = torch.zeros((1, 512), device=input_values.device)
  2423. return _generate_speech(
  2424. self,
  2425. input_values,
  2426. speaker_embeddings,
  2427. attention_mask,
  2428. threshold,
  2429. minlenratio,
  2430. maxlenratio,
  2431. vocoder,
  2432. output_cross_attentions,
  2433. return_output_lengths,
  2434. )
  2435. class HifiGanResidualBlock(nn.Module):
  2436. def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), leaky_relu_slope=0.1):
  2437. super().__init__()
  2438. self.leaky_relu_slope = leaky_relu_slope
  2439. self.convs1 = nn.ModuleList(
  2440. [
  2441. nn.Conv1d(
  2442. channels,
  2443. channels,
  2444. kernel_size,
  2445. stride=1,
  2446. dilation=dilation[i],
  2447. padding=self.get_padding(kernel_size, dilation[i]),
  2448. )
  2449. for i in range(len(dilation))
  2450. ]
  2451. )
  2452. self.convs2 = nn.ModuleList(
  2453. [
  2454. nn.Conv1d(
  2455. channels,
  2456. channels,
  2457. kernel_size,
  2458. stride=1,
  2459. dilation=1,
  2460. padding=self.get_padding(kernel_size, 1),
  2461. )
  2462. for _ in range(len(dilation))
  2463. ]
  2464. )
  2465. def get_padding(self, kernel_size, dilation=1):
  2466. return (kernel_size * dilation - dilation) // 2
  2467. def apply_weight_norm(self):
  2468. weight_norm = nn.utils.weight_norm
  2469. if hasattr(nn.utils.parametrizations, "weight_norm"):
  2470. weight_norm = nn.utils.parametrizations.weight_norm
  2471. for layer in self.convs1:
  2472. weight_norm(layer)
  2473. for layer in self.convs2:
  2474. weight_norm(layer)
  2475. def remove_weight_norm(self):
  2476. for layer in self.convs1:
  2477. nn.utils.remove_weight_norm(layer)
  2478. for layer in self.convs2:
  2479. nn.utils.remove_weight_norm(layer)
  2480. def forward(self, hidden_states):
  2481. for conv1, conv2 in zip(self.convs1, self.convs2):
  2482. residual = hidden_states
  2483. hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
  2484. hidden_states = conv1(hidden_states)
  2485. hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
  2486. hidden_states = conv2(hidden_states)
  2487. hidden_states = hidden_states + residual
  2488. return hidden_states
  2489. @auto_docstring(
  2490. custom_intro="""
  2491. HiFi-GAN vocoder.
  2492. """
  2493. )
  2494. class SpeechT5HifiGan(PreTrainedModel):
  2495. config: SpeechT5HifiGanConfig
  2496. main_input_name = "spectrogram"
  2497. def __init__(self, config: SpeechT5HifiGanConfig):
  2498. super().__init__(config)
  2499. self.num_kernels = len(config.resblock_kernel_sizes)
  2500. self.num_upsamples = len(config.upsample_rates)
  2501. self.conv_pre = nn.Conv1d(
  2502. config.model_in_dim,
  2503. config.upsample_initial_channel,
  2504. kernel_size=7,
  2505. stride=1,
  2506. padding=3,
  2507. )
  2508. self.upsampler = nn.ModuleList()
  2509. for i, (upsample_rate, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)):
  2510. self.upsampler.append(
  2511. nn.ConvTranspose1d(
  2512. config.upsample_initial_channel // (2**i),
  2513. config.upsample_initial_channel // (2 ** (i + 1)),
  2514. kernel_size=kernel_size,
  2515. stride=upsample_rate,
  2516. padding=(kernel_size - upsample_rate) // 2,
  2517. )
  2518. )
  2519. self.resblocks = nn.ModuleList()
  2520. for i in range(len(self.upsampler)):
  2521. channels = config.upsample_initial_channel // (2 ** (i + 1))
  2522. for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes):
  2523. self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation, config.leaky_relu_slope))
  2524. self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3)
  2525. self.register_buffer("mean", torch.zeros(config.model_in_dim))
  2526. self.register_buffer("scale", torch.ones(config.model_in_dim))
  2527. # Initialize weights and apply final processing
  2528. self.post_init()
  2529. def _init_weights(self, module):
  2530. super()._init_weights(module)
  2531. if isinstance(module, SpeechT5HifiGan):
  2532. init.zeros_(module.mean)
  2533. init.ones_(module.scale)
  2534. def apply_weight_norm(self):
  2535. weight_norm = nn.utils.weight_norm
  2536. if hasattr(nn.utils.parametrizations, "weight_norm"):
  2537. weight_norm = nn.utils.parametrizations.weight_norm
  2538. weight_norm(self.conv_pre)
  2539. for layer in self.upsampler:
  2540. weight_norm(layer)
  2541. for layer in self.resblocks:
  2542. layer.apply_weight_norm()
  2543. weight_norm(self.conv_post)
  2544. def remove_weight_norm(self):
  2545. nn.utils.remove_weight_norm(self.conv_pre)
  2546. for layer in self.upsampler:
  2547. nn.utils.remove_weight_norm(layer)
  2548. for layer in self.resblocks:
  2549. layer.remove_weight_norm()
  2550. nn.utils.remove_weight_norm(self.conv_post)
  2551. @auto_docstring(
  2552. custom_intro="""
  2553. Converts a log-mel spectrogram into a speech waveform. Passing a batch of log-mel spectrograms returns a batch
  2554. of speech waveforms. Passing a single, un-batched log-mel spectrogram returns a single, un-batched speech
  2555. waveform.
  2556. """
  2557. )
  2558. def forward(self, spectrogram: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
  2559. r"""
  2560. spectrogram (`torch.FloatTensor`):
  2561. Tensor containing the log-mel spectrograms. Can be batched and of shape `(batch_size, sequence_length,
  2562. config.model_in_dim)`, or un-batched and of shape `(sequence_length, config.model_in_dim)`.
  2563. Returns:
  2564. `torch.FloatTensor`: Tensor containing the speech waveform. If the input spectrogram is batched, will be of
  2565. shape `(batch_size, num_frames,)`. If un-batched, will be of shape `(num_frames,)`.
  2566. """
  2567. if self.config.normalize_before:
  2568. spectrogram = (spectrogram - self.mean) / self.scale
  2569. is_batched = spectrogram.dim() == 3
  2570. if not is_batched:
  2571. spectrogram = spectrogram.unsqueeze(0)
  2572. hidden_states = spectrogram.transpose(2, 1)
  2573. hidden_states = self.conv_pre(hidden_states)
  2574. for i in range(self.num_upsamples):
  2575. hidden_states = nn.functional.leaky_relu(hidden_states, self.config.leaky_relu_slope)
  2576. hidden_states = self.upsampler[i](hidden_states)
  2577. res_state = self.resblocks[i * self.num_kernels](hidden_states)
  2578. for j in range(1, self.num_kernels):
  2579. res_state += self.resblocks[i * self.num_kernels + j](hidden_states)
  2580. hidden_states = res_state / self.num_kernels
  2581. hidden_states = nn.functional.leaky_relu(hidden_states)
  2582. hidden_states = self.conv_post(hidden_states)
  2583. hidden_states = torch.tanh(hidden_states)
  2584. if not is_batched:
  2585. # remove batch dim and collapse tensor to 1-d audio waveform
  2586. waveform = hidden_states.squeeze(0).transpose(1, 0).view(-1)
  2587. else:
  2588. # remove seq-len dim since this collapses to 1
  2589. waveform = hidden_states.squeeze(1)
  2590. return waveform
# Public API of this module: the names exported by `from <module> import *`.
__all__ = [
    "SpeechT5ForSpeechToText",
    "SpeechT5ForSpeechToSpeech",
    "SpeechT5ForTextToSpeech",
    "SpeechT5Model",
    "SpeechT5PreTrainedModel",
    "SpeechT5HifiGan",
]