modeling_clap.py 73 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757175817591760176117621763176417651766176717681769177017711772177317741775
  1. # Copyright 2023 The LAION-AI Team and The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch CLAP model."""
  15. import collections
  16. import math
  17. from collections.abc import Callable
  18. from dataclasses import dataclass
  19. from typing import Any
  20. import torch
  21. import torch.nn.functional as F
  22. from torch import nn
  23. from ... import initialization as init
  24. from ...activations import ACT2FN
  25. from ...modeling_layers import GradientCheckpointingLayer
  26. from ...modeling_outputs import (
  27. BaseModelOutput,
  28. BaseModelOutputWithPooling,
  29. )
  30. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  31. from ...processing_utils import Unpack
  32. from ...pytorch_utils import apply_chunking_to_forward
  33. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int
  34. from ...utils.generic import merge_with_config_defaults
  35. from ...utils.output_capturing import capture_outputs
  36. from .configuration_clap import ClapAudioConfig, ClapConfig, ClapTextConfig
  # Module-level logger obtained from the package's `logging` utility (imported from `...utils`).
logger = logging.get_logger(__name__)
  # Adapted from: https://github.com/LAION-AI/CLAP/blob/6ad05a971ba0622f6acee8c41993e0d02bbed639/src/open_clip/utils.py#L191
def interpolate(hidden_states, ratio):
    """
    Upsample a sequence along its time axis by repeating every frame `ratio` times in a row. This compensates for
    the temporal resolution lost when a CNN downsamples its input.

    Args:
        hidden_states (`torch.FloatTensor` of shape `(batch_size, time_length, classes_num)`):
            Frame-level features or predictions to upsample.
        ratio (`int`):
            Number of consecutive copies emitted per input frame; the output is `ratio` times longer.

    Returns:
        `torch.FloatTensor` of shape `(batch_size, time_length * ratio, classes_num)`.
    """
    # Unpacking the shape also rejects inputs that are not rank 3.
    batch_size, time_length, classes_num = hidden_states.shape
    target_shape = (batch_size, time_length * ratio, classes_num)
    return hidden_states.repeat_interleave(ratio, dim=1).reshape(target_shape)
  # Adapted from https://github.com/LAION-AI/CLAP/blob/6ad05a971ba0622f6acee8c41993e0d02bbed639/src/open_clip/htsat.py#L249
def window_partition(hidden_states, window_size):
    """
    Split a channels-last feature map into non-overlapping square windows.

    Args:
        hidden_states (`torch.FloatTensor` of shape `(batch_size, height, width, num_channels)`):
            Feature map; `height` and `width` must be multiples of `window_size` for the reshape to succeed.
        window_size (`int`):
            Side length of each window.

    Returns:
        `torch.FloatTensor` of shape `(batch_size * num_windows, window_size, window_size, num_channels)`, ordered
        by batch, then window row, then window column.
    """
    batch_size, height, width, num_channels = hidden_states.shape
    rows = height // window_size
    cols = width // window_size
    # (B, rows, ws, cols, ws, C) -> (B, rows, cols, ws, ws, C) -> (B * rows * cols, ws, ws, C)
    grid = hidden_states.view(batch_size, rows, window_size, cols, window_size, num_channels)
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(-1, window_size, window_size, num_channels)
  # Adapted from https://github.com/LAION-AI/CLAP/blob/6ad05a971ba0622f6acee8c41993e0d02bbed639/src/open_clip/htsat.py#L263
def window_reverse(windows, window_size, height, width):
    """
    Stitch square windows back into full channels-last feature maps (the inverse of `window_partition`).

    Args:
        windows (`torch.FloatTensor` of shape `(num_windows * batch_size, window_size, window_size, num_channels)`):
            Windows ordered by batch, then window row, then window column.
        window_size (`int`):
            Side length of each window.
        height (`int`):
            Height of the reassembled feature map.
        width (`int`):
            Width of the reassembled feature map.

    Returns:
        `torch.FloatTensor` of shape `(batch_size, height, width, num_channels)`.
    """
    num_channels = windows.shape[-1]
    rows = height // window_size
    cols = width // window_size
    # (B, rows, cols, ws, ws, C) -> (B, rows, ws, cols, ws, C) -> (B, height, width, C)
    tiles = windows.view(-1, rows, cols, window_size, window_size, num_channels)
    tiles = tiles.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiles.view(-1, height, width, num_channels)
  # contrastive loss function, adapted from
# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html#CLIP-loss-function
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    """
    Cross-entropy for a similarity matrix whose matching pairs sit on the diagonal.

    Row `i` of `logits` is scored against target class `i`. The result is the mean over rows, which is the default
    reduction of `cross_entropy`.
    """
    num_rows = len(logits)
    diagonal_targets = torch.arange(num_rows, device=logits.device)
    loss = nn.functional.cross_entropy(logits, diagonal_targets)
    return loss
  # Output container of the CLAP text model; `text_embeds` holds the projected embeddings described below.
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for text model's outputs that also contains a pooling of the last hidden states.
    """
)
# Adapted from transformers.models.clip.modeling_clip.CLIPTextModelOutput (CLIP -> Clap)
class ClapTextModelOutput(ModelOutput):
    r"""
    text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
        The text embeddings obtained by applying the projection layer to the pooler_output.
    """

    # Projected text embeddings; stays None when no projection output was produced.
    text_embeds: torch.FloatTensor | None = None
    # Final-layer hidden states of the text encoder; presumably (batch_size, sequence_length, hidden_size) -- confirm.
    last_hidden_state: torch.FloatTensor | None = None
    # Per-layer hidden states; presumably only filled when requested by the caller -- verify against the model forward.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    # Per-layer attention weights; presumably only filled when requested by the caller -- verify against the model forward.
    attentions: tuple[torch.FloatTensor, ...] | None = None
  # Output container of the CLAP audio model, shaped to mirror the original LAION-AI implementation.
@dataclass
@auto_docstring(
    custom_intro="""
    ClapAudio model output to mimic the output of the original implementation.
    """
)
class ClapAudioModelOutput(ModelOutput):
    r"""
    audio_embeds (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
        The Audio embeddings obtained by applying the projection layer to the pooler_output.
    """

    # Projected audio embeddings.
    # NOTE(review): the docstring says `(batch_size, hidden_size)`; since these come from a projection layer the
    # trailing dim may be the projection size instead -- confirm against the projection module.
    audio_embeds: torch.FloatTensor | None = None
    # Final-layer hidden states of the audio encoder.
    last_hidden_state: torch.FloatTensor | None = None
    # Per-layer hidden states; presumably only filled when requested by the caller -- verify.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    # Per-layer attention weights; presumably only filled when requested by the caller -- verify.
    attentions: tuple[torch.FloatTensor, ...] | None = None
  # Combined CLAP output: contrastive logits, projected embeddings, and both sub-model outputs.
@dataclass
@auto_docstring
# Adapted from transformers.models.clip.modeling_clip.CLIPOutput
class ClapOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
        Contrastive loss for audio-text similarity.
    logits_per_audio (`torch.FloatTensor` of shape `(audio_batch_size, text_batch_size)`):
        The scaled dot product scores between `audio_embeds` and `text_embeds`. This represents the audio-text
        similarity scores.
    logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, audio_batch_size)`):
        The scaled dot product scores between `text_embeds` and `audio_embeds`. This represents the text-audio
        similarity scores.
    text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
        The text embeddings obtained by applying the projection layer to the pooled output of [`ClapTextModel`].
    audio_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
        The audio embeddings obtained by applying the projection layer to the pooled output of [`ClapAudioModel`].
    text_model_output (`BaseModelOutputWithPooling`):
        The output of the [`ClapTextModel`].
    audio_model_output (`BaseModelOutputWithPooling`):
        The output of the [`ClapAudioModel`].
    """

    loss: torch.FloatTensor | None = None
    logits_per_audio: torch.FloatTensor | None = None
    logits_per_text: torch.FloatTensor | None = None
    text_embeds: torch.FloatTensor | None = None
    audio_embeds: torch.FloatTensor | None = None
    text_model_output: BaseModelOutputWithPooling | None = None
    audio_model_output: BaseModelOutputWithPooling | None = None

    def to_tuple(self) -> tuple[Any]:
        """Return the populated fields in order, turning nested `ModelOutput` values into tuples as well."""
        items = []
        for value in self.values():
            if isinstance(value, ModelOutput):
                value = value.to_tuple()
            items.append(value)
        return tuple(items)
  # Adapted from transformers.models.swin.modeling_swin.SwinDropPath
class ClapDropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is a slightly
    refactored version of the `SwinDropPath` implementation.

    Args:
        drop_prob (`float`, *optional*):
            Probability of zeroing a whole sample during training. `None` (the default) and `0.0` both disable
            dropping, so the module is then an identity.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states):
        # `not self.drop_prob` covers both 0.0 and the `None` default. Previously `None` reached `1 - None` and
        # raised a TypeError as soon as the module was in training mode.
        if not self.drop_prob or not self.training:
            return hidden_states
        keep_prob = 1 - self.drop_prob
        # One random draw per sample, broadcast over all remaining dims (works for tensors of any rank).
        shape = (hidden_states.shape[0],) + (1,) * (hidden_states.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=hidden_states.dtype, device=hidden_states.device)
        random_tensor.floor_()  # binarize: 1 with probability keep_prob, otherwise 0
        # Rescale the kept samples so the expected output matches eval mode.
        output = hidden_states.div(keep_prob) * random_tensor
        return output

    def extra_repr(self) -> str:
        return f"p={self.drop_prob}"
  # Adapted from https://github.com/LAION-AI/CLAP/blob/6ad05a971ba0622f6acee8c41993e0d02bbed639/src/open_clip/feature_fusion.py#L133
class ClapAudioAFFBlock(nn.Module):
    r"""
    Attentional Feature Fusion (AFF) block from CLAP. CLAP always runs it in 2D mode, so the 1D variant is not
    implemented.

    A sigmoid gate is computed from `hidden_states + residual` by two branches. The local branch works per
    position; the global branch works on an average-pooled map. The gate then weights the two inputs in the
    fused output.
    """

    def __init__(self, config: ClapAudioConfig):
        super().__init__()
        channels = config.patch_embeds_hidden_size
        inter_channels = int(channels // config.aff_block_r)

        def bottleneck(*leading_layers):
            # [leading_layers] -> 1x1 conv squeeze -> BN -> ReLU -> 1x1 conv expand -> BN
            return nn.Sequential(
                *leading_layers,
                nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm2d(inter_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm2d(channels),
            )

        self.local_att = bottleneck()
        # The global branch pools to 1x1 first, so its contribution is shared by every spatial position.
        self.global_att = bottleneck(nn.AdaptiveAvgPool2d(1))
        self.sigmoid = nn.Sigmoid()

    def forward(self, hidden_states, residual):
        mixed = hidden_states + residual
        gate = self.sigmoid(self.local_att(mixed) + self.global_att(mixed))
        # Gate-weighted blend of the two inputs, scaled by 2 as in the original implementation.
        return 2 * hidden_states * gate + 2 * residual * (1 - gate)
  # Spectrogram-to-patch embedding for the CLAP audio encoder, with optional local/global feature fusion.
class ClapAudioPatchEmbed(nn.Module):
    """
    This module converts the hidden states reshaped as an image to patch embeddings ready to be passed to the
    Transformer block.

    With `config.enable_fusion`, channel 0 of the input is embedded as the global view. For the samples selected by
    `is_longer_idx`, channels `1:` are embedded separately as local views and merged into the global embedding
    with [`ClapAudioAFFBlock`].
    """

    def __init__(self, config: ClapAudioConfig):
        super().__init__()
        img_size = (config.spec_size, config.spec_size) if isinstance(config.spec_size, int) else config.spec_size
        patch_size = (
            (config.patch_size, config.patch_size) if isinstance(config.patch_size, int) else config.patch_size
        )
        patch_stride = (
            (config.patch_stride, config.patch_stride) if isinstance(config.patch_stride, int) else config.patch_stride
        )

        self.img_size = img_size
        self.patch_stride = patch_stride

        self.grid_size = (img_size[0] // patch_stride[0], img_size[1] // patch_stride[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        self.flatten = config.flatten_patch_embeds
        self.enable_fusion = config.enable_fusion

        # Symmetric padding of (kernel - stride) // 2 on each spatial axis.
        padding = ((patch_size[0] - patch_stride[0]) // 2, (patch_size[1] - patch_stride[1]) // 2)

        # The "channel_map" fusion variant feeds 4x the input channels to the main projection.
        scale_factor = 4 if self.enable_fusion and config.fusion_type == "channel_map" else 1

        self.proj = nn.Conv2d(
            config.patch_embed_input_channels * scale_factor,
            config.patch_embeds_hidden_size,
            kernel_size=patch_size,
            stride=patch_stride,
            padding=padding,
        )

        self.norm = nn.LayerNorm(config.patch_embeds_hidden_size) if config.enable_patch_layer_norm else nn.Identity()
        if self.enable_fusion:
            self.fusion_model = ClapAudioAFFBlock(config)
            # Local views use a kernel/stride 3x wider along the width axis than the main projection.
            self.mel_conv2d = nn.Conv2d(
                config.patch_embed_input_channels,
                config.patch_embeds_hidden_size,
                kernel_size=(patch_size[0], patch_size[1] * 3),
                stride=(patch_stride[0], patch_stride[1] * 3),
                padding=padding,
            )

    def _check_input_size(self, height, width):
        """Raise a `ValueError` when the input spectrogram size differs from the configured `spec_size`."""
        if height != self.img_size[0] or width != self.img_size[1]:
            raise ValueError(
                f"Input audio size ({height}*{width}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
            )

    def forward(self, hidden_states, is_longer_idx=None):
        """
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                Spectrogram "image". With fusion enabled, channel 0 is the global view and channels `1:` are the
                local views.
            is_longer_idx (*optional*):
                Batch indices whose local views are fused into the global embedding (fusion mode only). `None` or an
                empty selection means no sample is fused.

        Returns:
            `torch.FloatTensor`: the patch embeddings, flattened to `(batch_size, num_positions, hidden_size)` when
            `config.flatten_patch_embeds` is set, otherwise the raw `(batch_size, hidden_size, H', W')` conv output.

        Raises:
            `ValueError`: if the spatial size of the input does not match `config.spec_size`.
        """
        if self.enable_fusion:
            # Global view: channel 0, embedded with the main projection.
            global_hidden_states = hidden_states[:, 0:1, :, :]
            _, _, height, width = global_hidden_states.shape
            self._check_input_size(height, width)

            global_hidden_states = self.proj(global_hidden_states)
            output_width = global_hidden_states.size(-1)

            # `is_longer_idx` defaults to None. Treat that like an empty selection instead of failing on len(None).
            if is_longer_idx is not None and len(is_longer_idx) > 0:
                # Local views: every remaining channel of the selected samples, embedded one channel at a time.
                local_hidden_states = hidden_states[is_longer_idx, 1:, :, :].contiguous()
                batch_size, num_channels, height, width = local_hidden_states.shape
                local_hidden_states = local_hidden_states.view(batch_size * num_channels, 1, height, width)

                local_hidden_states = self.mel_conv2d(local_hidden_states)

                _, features, height, width = local_hidden_states.shape
                local_hidden_states = local_hidden_states.view(batch_size, num_channels, features, height, width)
                # Place the local views side by side along the last (width) axis.
                local_hidden_states = local_hidden_states.permute((0, 2, 3, 1, 4)).contiguous().flatten(3)

                # Right-pad (or crop, when negative) to the global embedding's width so the two can be fused.
                local_width = local_hidden_states.size(-1)
                local_hidden_states = torch.nn.functional.pad(
                    local_hidden_states, (0, output_width - local_width), "constant", 0
                )

                global_hidden_states[is_longer_idx] = self.fusion_model(
                    global_hidden_states[is_longer_idx], local_hidden_states
                )
            hidden_states = global_hidden_states
        else:
            _, _, height, width = hidden_states.shape
            self._check_input_size(height, width)
            hidden_states = self.proj(hidden_states)

        if self.flatten:
            # (batch, channels, H', W') -> (batch, H' * W', channels)
            hidden_states = hidden_states.flatten(2).transpose(1, 2)
        hidden_states = self.norm(hidden_states)
        return hidden_states
  285. # Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->ClapAudio
  286. class ClapAudioSelfAttention(nn.Module):
  287. def __init__(self, config, dim, num_heads, window_size):
  288. super().__init__()
  289. if dim % num_heads != 0:
  290. raise ValueError(
  291. f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
  292. )
  293. self.num_attention_heads = num_heads
  294. self.attention_head_size = int(dim / num_heads)
  295. self.all_head_size = self.num_attention_heads * self.attention_head_size
  296. self.window_size = (
  297. window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
  298. )
  299. self.relative_position_bias_table = nn.Parameter(
  300. torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
  301. )
  302. self.register_buffer("relative_position_index", self.create_relative_position_index())
  303. self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
  304. self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
  305. self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
  306. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  307. def forward(
  308. self,
  309. hidden_states: torch.Tensor,
  310. attention_mask: torch.FloatTensor | None = None,
  311. output_attentions: bool | None = False,
  312. ) -> tuple[torch.Tensor]:
  313. batch_size, dim, num_channels = hidden_states.shape
  314. hidden_shape = (batch_size, dim, -1, self.attention_head_size)
  315. query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  316. key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
  317. value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
  318. # Take the dot product between "query" and "key" to get the raw attention scores.
  319. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  320. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  321. relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
  322. relative_position_bias = relative_position_bias.view(
  323. self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
  324. )
  325. relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
  326. attention_scores = attention_scores + relative_position_bias.unsqueeze(0)
  327. if attention_mask is not None:
  328. # Apply the attention mask is (precomputed for all layers in ClapAudioModel forward() function)
  329. mask_shape = attention_mask.shape[0]
  330. attention_scores = attention_scores.view(
  331. batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
  332. )
  333. attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0)
  334. attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim)
  335. # Normalize the attention scores to probabilities.
  336. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  337. # This is actually dropping out entire tokens to attend to, which might
  338. # seem a bit unusual, but is taken from the original Transformer paper.
  339. attention_probs = self.dropout(attention_probs)
  340. context_layer = torch.matmul(attention_probs, value_layer)
  341. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  342. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  343. context_layer = context_layer.view(new_context_layer_shape)
  344. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  345. return outputs
  346. def create_relative_position_index(self):
  347. # get pair-wise relative position index for each token inside the window
  348. coords_h = torch.arange(self.window_size[0])
  349. coords_w = torch.arange(self.window_size[1])
  350. coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))
  351. coords_flatten = torch.flatten(coords, 1)
  352. relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
  353. relative_coords = relative_coords.permute(1, 2, 0).contiguous()
  354. relative_coords[:, :, 0] += self.window_size[0] - 1
  355. relative_coords[:, :, 1] += self.window_size[1] - 1
  356. relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
  357. relative_position_index = relative_coords.sum(-1)
  358. return relative_position_index
  # Adapted from transformers.models.swin.modeling_swin.SwinSelfOutput
class ClapAudioSelfOutput(nn.Module):
    """Output projection (linear layer followed by dropout) applied to the window-attention context."""

    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # `input_tensor` is accepted for interface compatibility but not used here; the residual connection is
        # added by the enclosing layer (`ClapAudioLayer`).
        return self.dropout(self.dense(hidden_states))
  # Adapted from transformers.models.swin.modeling_swin.SwinAttention
class ClapAudioAttention(nn.Module):
    """Window self-attention (`self.self`) followed by its output projection (`self.output`)."""

    def __init__(self, config, dim, num_heads, window_size):
        super().__init__()
        self.self = ClapAudioSelfAttention(config, dim, num_heads, window_size)
        self.output = ClapAudioSelfOutput(config, dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.FloatTensor | None = None,
        output_attentions: bool | None = False,
    ) -> tuple[torch.Tensor]:
        context, *extras = self.self(hidden_states, attention_mask, output_attentions)
        # `extras` carries the attention probabilities when they were requested and is empty otherwise.
        return (self.output(context, hidden_states), *extras)
  # Adapted from transformers.models.swin.modeling_swin.SwinIntermediate
class ClapAudioIntermediate(nn.Module):
    """First half of the feed-forward block: expand `dim` by `config.mlp_ratio`, then apply the activation."""

    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
        activation = config.hidden_act
        # `hidden_act` is either a registered activation name or a callable that is used as-is.
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.intermediate_act_fn(self.dense(hidden_states))
  # Adapted from transformers.models.swin.modeling_swin.SwinOutput
class ClapAudioOutput(nn.Module):
    """Second half of the feed-forward block: project back down to `dim`, then apply dropout."""

    def __init__(self, config, dim):
        super().__init__()
        hidden_features = int(config.mlp_ratio * dim)
        self.dense = nn.Linear(hidden_features, dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.dense(hidden_states))
  # Adapted from transformers.models.swin.modeling_swin.SwinLayer
class ClapAudioLayer(nn.Module):
    """
    One Swin-style block: (shifted) window self-attention followed by a feed-forward network. Each sub-block has a
    pre-LayerNorm and a residual connection.

    Args:
        config: provides `chunk_size_feed_forward`, `window_size`, `layer_norm_eps`, plus the fields read by the
            attention and feed-forward sub-modules.
        dim (`int`): channel size of the tokens.
        input_resolution: grid size this layer was built for. It is stored as `self.input_resolution`; `forward`
            uses its `input_dimensions` argument instead.
        num_heads (`int`): number of attention heads.
        drop_path_rate (`float`): stochastic-depth rate on the attention residual branch; `0.0` disables it.
        shift_size (`int`): cyclic shift, in tokens, applied before windowing; `0` means regular windows.
    """

    def __init__(self, config, dim, input_resolution, num_heads, drop_path_rate=0.0, shift_size=0):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.shift_size = shift_size
        self.window_size = config.window_size
        self.input_resolution = input_resolution
        self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.attention = ClapAudioAttention(config, dim, num_heads, window_size=self.window_size)
        self.drop_path = ClapDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
        self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.intermediate = ClapAudioIntermediate(config, dim)
        self.output = ClapAudioOutput(config, dim)

    def set_shift_and_window_size(self, input_resolution):
        # When the grid is no larger than a window, shrink the window to the grid's smaller side and disable
        # shifting.
        # NOTE: this overwrites `self.shift_size` / `self.window_size` in place, so the reduced values persist
        # for later calls.
        if min(input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = torch_int(0)
            self.window_size = (
                torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution)
            )

    def get_attn_mask(self, height, width, dtype, device):
        """
        Build the additive attention mask for shifted windows, or return None when `shift_size` is 0.

        The padded grid is split into 3x3 regions by the window and shift boundaries, and every token is labelled
        with its region. Within each window, token pairs with different labels get -100 (suppressing attention
        across regions) and matching pairs get 0. The result has shape
        `(num_windows, window_size**2, window_size**2)`.
        """
        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
            height_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            width_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            count = 0
            for height_slice in height_slices:
                for width_slice in width_slices:
                    img_mask[:, height_slice, width_slice, :] = count
                    count += 1

            mask_windows = window_partition(img_mask, self.window_size)
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
        else:
            attn_mask = None
        return attn_mask

    def maybe_pad(self, hidden_states, height, width):
        # Zero-pad the right and bottom edges of the (batch, height, width, channels) map so height and width become
        # multiples of the window size. F.pad lists pads from the last dim backwards, so the channel dim gets (0, 0).
        pad_right = (self.window_size - width % self.window_size) % self.window_size
        pad_bottom = (self.window_size - height % self.window_size) % self.window_size
        pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
        hidden_states = nn.functional.pad(hidden_states, pad_values)
        return hidden_states, pad_values

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: tuple[int, int],
        output_attentions: bool | None = False,
        always_partition: bool | None = False,
    ) -> tuple[torch.Tensor, ...]:
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size, height * width, channels)`):
                Token sequence of a `height x width` grid.
            input_dimensions (`tuple[int, int]`): the grid's `(height, width)`.
            output_attentions (`bool`, *optional*): also return the attention probabilities.
            always_partition (`bool`, *optional*): keep the current window/shift sizes even when the grid is not
                larger than a window.

        Returns:
            `(layer_output,)`, or `(layer_output, attention_probs)` when `output_attentions` is set. `layer_output`
            has the same shape as `hidden_states`.
        """
        # Unless partitioning is forced, adapt the window/shift sizes to small grids.
        if not always_partition:
            self.set_shift_and_window_size(input_dimensions)
        height, width = input_dimensions
        batch_size, _, channels = hidden_states.size()
        shortcut = hidden_states

        hidden_states = self.layernorm_before(hidden_states)

        hidden_states = hidden_states.view(batch_size, height, width, channels)

        # pad hidden_states to multiples of window size
        hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)

        _, height_pad, width_pad, _ = hidden_states.shape
        # cyclic shift
        if self.shift_size > 0:
            shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_hidden_states = hidden_states

        # partition windows
        hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
        hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
        attn_mask = self.get_attn_mask(
            height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device
        )

        attention_outputs = self.attention(hidden_states_windows, attn_mask, output_attentions=output_attentions)

        attention_output = attention_outputs[0]

        attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
        shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad)

        # reverse cyclic shift
        if self.shift_size > 0:
            attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            attention_windows = shifted_windows

        # Drop the padding added by maybe_pad before restoring the token-sequence layout.
        was_padded = pad_values[3] > 0 or pad_values[5] > 0
        if was_padded:
            attention_windows = attention_windows[:, :height, :width, :].contiguous()

        attention_windows = attention_windows.view(batch_size, height * width, channels)

        # Residual around attention (with stochastic depth), then the pre-norm MLP with its own residual.
        hidden_states = shortcut + self.drop_path(attention_windows)

        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)
        layer_output = hidden_states + self.output(layer_output)

        layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
        return layer_outputs
  510. # Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->ClapAudio
  511. class ClapAudioStage(GradientCheckpointingLayer):
  512. def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample):
  513. super().__init__()
  514. self.config = config
  515. self.dim = dim
  516. self.blocks = nn.ModuleList(
  517. [
  518. ClapAudioLayer(
  519. config=config,
  520. dim=dim,
  521. input_resolution=input_resolution,
  522. num_heads=num_heads,
  523. drop_path_rate=drop_path[i],
  524. shift_size=0 if (i % 2 == 0) else config.window_size // 2,
  525. )
  526. for i in range(depth)
  527. ]
  528. )
  529. # patch merging layer
  530. if downsample is not None:
  531. self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm)
  532. else:
  533. self.downsample = None
  534. self.pointing = False
  535. def forward(
  536. self,
  537. hidden_states: torch.Tensor,
  538. input_dimensions: tuple[int, int],
  539. output_attentions: bool | None = False,
  540. always_partition: bool | None = False,
  541. ) -> tuple[torch.Tensor]:
  542. height, width = input_dimensions
  543. for i, layer_module in enumerate(self.blocks):
  544. layer_outputs = layer_module(hidden_states, input_dimensions, output_attentions, always_partition)
  545. hidden_states = layer_outputs[0]
  546. hidden_states_before_downsampling = hidden_states
  547. if self.downsample is not None:
  548. height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
  549. output_dimensions = (height, width, height_downsampled, width_downsampled)
  550. hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions)
  551. else:
  552. output_dimensions = (height, width, height, width)
  553. stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)
  554. if output_attentions:
  555. stage_outputs += layer_outputs[1:]
  556. return stage_outputs
  557. # Copied from transformers.models.swin.modeling_swin.SwinPatchMerging with Swin->ClapAudio
  558. class ClapAudioPatchMerging(nn.Module):
  559. """
  560. Patch Merging Layer.
  561. Args:
  562. input_resolution (`tuple[int]`):
  563. Resolution of input feature.
  564. dim (`int`):
  565. Number of input channels.
  566. norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
  567. Normalization layer class.
  568. """
  569. def __init__(self, input_resolution: tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
  570. super().__init__()
  571. self.input_resolution = input_resolution
  572. self.dim = dim
  573. self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
  574. self.norm = norm_layer(4 * dim)
  575. def maybe_pad(self, input_feature, height, width):
  576. should_pad = (height % 2 == 1) or (width % 2 == 1)
  577. if should_pad:
  578. pad_values = (0, 0, 0, width % 2, 0, height % 2)
  579. input_feature = nn.functional.pad(input_feature, pad_values)
  580. return input_feature
  581. def forward(self, input_feature: torch.Tensor, input_dimensions: tuple[int, int]) -> torch.Tensor:
  582. height, width = input_dimensions
  583. # `dim` is height * width
  584. batch_size, dim, num_channels = input_feature.shape
  585. input_feature = input_feature.view(batch_size, height, width, num_channels)
  586. # pad input to be divisible by width and height, if needed
  587. input_feature = self.maybe_pad(input_feature, height, width)
  588. # [batch_size, height/2, width/2, num_channels]
  589. input_feature_0 = input_feature[:, 0::2, 0::2, :]
  590. # [batch_size, height/2, width/2, num_channels]
  591. input_feature_1 = input_feature[:, 1::2, 0::2, :]
  592. # [batch_size, height/2, width/2, num_channels]
  593. input_feature_2 = input_feature[:, 0::2, 1::2, :]
  594. # [batch_size, height/2, width/2, num_channels]
  595. input_feature_3 = input_feature[:, 1::2, 1::2, :]
  596. # batch_size height/2 width/2 4*num_channels
  597. input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
  598. input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # batch_size height/2*width/2 4*C
  599. input_feature = self.norm(input_feature)
  600. input_feature = self.reduction(input_feature)
  601. return input_feature
  602. class ClapAudioEncoder(nn.Module):
  603. def __init__(self, config):
  604. super().__init__()
  605. self.num_layers = len(config.depths)
  606. self.config = config
  607. self.patch_embed = ClapAudioPatchEmbed(config)
  608. self.enable_fusion = config.enable_fusion
  609. self.patch_stride = self.patch_embed.patch_stride
  610. self.spec_size = config.spec_size
  611. self.freq_ratio = config.spec_size // config.num_mel_bins
  612. self.num_features = int(config.patch_embeds_hidden_size * 2 ** (self.num_layers - 1))
  613. drop_path_rate = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
  614. grid_size = self.patch_embed.grid_size
  615. self.input_resolutions = [(grid_size[0] // (2**i), grid_size[1] // (2**i)) for i in range(self.num_layers)]
  616. self.layers = nn.ModuleList(
  617. [
  618. ClapAudioStage(
  619. config=config,
  620. dim=int(config.patch_embeds_hidden_size * 2**i_layer),
  621. input_resolution=self.input_resolutions[i_layer],
  622. depth=config.depths[i_layer],
  623. num_heads=config.num_attention_heads[i_layer],
  624. drop_path=drop_path_rate[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
  625. downsample=ClapAudioPatchMerging if (i_layer < self.num_layers - 1) else None,
  626. )
  627. for i_layer in range(self.num_layers)
  628. ]
  629. )
  630. self.gradient_checkpointing = False
  631. self.batch_norm = nn.BatchNorm2d(config.num_mel_bins)
  632. self.norm = nn.LayerNorm(self.num_features)
  633. self.depths = config.depths
  634. self.avgpool = nn.AdaptiveAvgPool1d(1)
  635. def reshape_mel2img(self, normalized_input_features):
  636. """
  637. The input is 4 normalized log mel spectrograms. It is reshape to the common shape of images. Each channel
  638. should represent 1 of the 4 crops of the spectrogram. For more details, refer to the [`ClapFeatureExtractor`].
  639. """
  640. _, _, time_length, freq_length = normalized_input_features.shape
  641. spec_width = int(self.spec_size * self.freq_ratio)
  642. spec_height = self.spec_size // self.freq_ratio
  643. if time_length > spec_width or freq_length > spec_height:
  644. raise ValueError("the wav size should be less than or equal to the swin input size")
  645. # to avoid bicubic zero error
  646. if time_length < spec_width:
  647. normalized_input_features = nn.functional.interpolate(
  648. normalized_input_features, (spec_width, freq_length), mode="bicubic", align_corners=True
  649. )
  650. if freq_length < spec_height:
  651. normalized_input_features = nn.functional.interpolate(
  652. normalized_input_features, (time_length, spec_height), mode="bicubic", align_corners=True
  653. )
  654. batch, channels, time, freq = normalized_input_features.shape
  655. # batch_size, channels, spec_width, spec_height --> batch_size, channels, spec_height * freq_ratio, spec_width // freq_ratio
  656. normalized_input_features = normalized_input_features.reshape(
  657. batch, channels * self.freq_ratio, time // self.freq_ratio, freq
  658. )
  659. normalized_input_features = normalized_input_features.permute(0, 1, 3, 2).contiguous()
  660. normalized_input_features = normalized_input_features.reshape(
  661. batch, channels, freq * self.freq_ratio, time // self.freq_ratio
  662. )
  663. return normalized_input_features
  664. @can_return_tuple
  665. def forward(
  666. self,
  667. input_features,
  668. is_longer: torch.FloatTensor | None = None,
  669. output_attentions: bool | None = False,
  670. output_hidden_states: bool | None = False,
  671. output_hidden_states_before_downsampling: bool | None = False,
  672. always_partition: bool | None = False,
  673. return_dict: bool | None = True,
  674. ) -> tuple | ClapAudioModelOutput:
  675. # Unique logic so no refactor here yet
  676. output_hidden_states = output_hidden_states or self.config.output_hidden_states
  677. output_attentions = output_attentions or self.config.output_attentions
  678. input_features = input_features.transpose(1, 3)
  679. normalized_input_features = self.batch_norm(input_features)
  680. normalized_input_features = normalized_input_features.transpose(1, 3)
  681. is_longer_list_idx = None
  682. if self.enable_fusion:
  683. is_longer_list = is_longer.to(input_features.device)
  684. is_longer_list_idx = torch.where(is_longer_list == 1)[0]
  685. hidden_states = self.reshape_mel2img(normalized_input_features)
  686. frames_num = hidden_states.shape[2]
  687. hidden_states = self.patch_embed(hidden_states, is_longer_list_idx)
  688. all_hidden_states = () if output_hidden_states else None
  689. all_reshaped_hidden_states = () if output_hidden_states else None
  690. all_self_attentions = () if output_attentions else None
  691. input_dimensions = self.input_resolutions[0]
  692. if output_hidden_states:
  693. batch_size, _, hidden_size = hidden_states.shape
  694. # rearrange batch_size (height width) channels -> batch_size channel height width
  695. reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
  696. reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
  697. all_hidden_states += (hidden_states,)
  698. all_reshaped_hidden_states += (reshaped_hidden_state,)
  699. for i, layer_module in enumerate(self.layers):
  700. input_dimensions = self.input_resolutions[i]
  701. layer_outputs = layer_module(hidden_states, input_dimensions, output_attentions, always_partition)
  702. hidden_states = layer_outputs[0]
  703. hidden_states_before_downsampling = layer_outputs[1]
  704. output_dimensions = layer_outputs[2]
  705. input_dimensions = (output_dimensions[-2], output_dimensions[-1])
  706. if output_hidden_states and output_hidden_states_before_downsampling:
  707. batch_size, _, hidden_size = hidden_states_before_downsampling.shape
  708. # rearrange batch_size (height width) channels -> batch_size channel height width
  709. # here we use the original (not downsampled) height and width
  710. reshaped_hidden_state = hidden_states_before_downsampling.view(
  711. batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
  712. )
  713. reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
  714. all_hidden_states += (hidden_states_before_downsampling,)
  715. all_reshaped_hidden_states += (reshaped_hidden_state,)
  716. elif output_hidden_states and not output_hidden_states_before_downsampling:
  717. batch_size, _, hidden_size = hidden_states.shape
  718. # rearrange batch_size (height width) channels -> batch_size channel height width
  719. reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
  720. reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
  721. all_hidden_states += (hidden_states,)
  722. all_reshaped_hidden_states += (reshaped_hidden_state,)
  723. if output_attentions:
  724. all_self_attentions += layer_outputs[3:]
  725. last_hidden_state = self.norm(hidden_states)
  726. batch_size, _, n_channels = last_hidden_state.shape
  727. freq_shape = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]
  728. temporal_shape = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]
  729. last_hidden_state = (
  730. last_hidden_state.permute(0, 2, 1).contiguous().reshape(batch_size, n_channels, freq_shape, temporal_shape)
  731. )
  732. batch_size, n_channels, n_frequencies, n_temp = last_hidden_state.shape
  733. # group 2D CNN
  734. c_freq_bin = n_frequencies // self.freq_ratio
  735. last_hidden_state = last_hidden_state.reshape(
  736. batch_size, n_channels, n_frequencies // c_freq_bin, c_freq_bin, n_temp
  737. )
  738. last_hidden_state = (
  739. last_hidden_state.permute(0, 1, 3, 2, 4).contiguous().reshape(batch_size, n_channels, c_freq_bin, -1)
  740. )
  741. latent_output = self.avgpool(torch.flatten(last_hidden_state, 2))
  742. latent_output = torch.flatten(latent_output, 1)
  743. return BaseModelOutputWithPooling(
  744. last_hidden_state=last_hidden_state,
  745. pooler_output=latent_output,
  746. hidden_states=all_reshaped_hidden_states,
  747. attentions=all_self_attentions,
  748. )
  749. class ClapProjectionLayer(nn.Module):
  750. def __init__(self, config: ClapAudioConfig | ClapTextConfig):
  751. super().__init__()
  752. self.config = config
  753. hidden_size = config.hidden_size
  754. projection_dim = config.projection_dim
  755. self.linear1 = nn.Linear(hidden_size, projection_dim)
  756. self.activation = ACT2FN[config.projection_hidden_act]
  757. self.linear2 = nn.Linear(projection_dim, projection_dim)
  758. def forward(self, hidden_states):
  759. hidden_states = self.linear1(hidden_states)
  760. hidden_states = self.activation(hidden_states)
  761. hidden_states = self.linear2(hidden_states)
  762. return hidden_states
  763. # Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings with Roberta->ClapText, persistent=False->persistent=True
  764. class ClapTextEmbeddings(nn.Module):
  765. """Construct the embeddings from word, position and token_type embeddings."""
  766. def __init__(self, config):
  767. super().__init__()
  768. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  769. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  770. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  771. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  772. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  773. self.register_buffer(
  774. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=True
  775. )
  776. self.register_buffer(
  777. "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=True
  778. )
  779. self.padding_idx = config.pad_token_id
  780. self.position_embeddings = nn.Embedding(
  781. config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
  782. )
  783. def forward(
  784. self,
  785. input_ids: torch.LongTensor | None = None,
  786. token_type_ids: torch.LongTensor | None = None,
  787. position_ids: torch.LongTensor | None = None,
  788. inputs_embeds: torch.FloatTensor | None = None,
  789. past_key_values_length: int = 0,
  790. ) -> torch.Tensor:
  791. if position_ids is None:
  792. if input_ids is not None:
  793. # Create the position ids from the input token ids. Any padded tokens remain padded.
  794. position_ids = self.create_position_ids_from_input_ids(
  795. input_ids, self.padding_idx, past_key_values_length
  796. )
  797. else:
  798. position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, self.padding_idx)
  799. if input_ids is not None:
  800. input_shape = input_ids.size()
  801. else:
  802. input_shape = inputs_embeds.size()[:-1]
  803. batch_size, seq_length = input_shape
  804. # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
  805. # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
  806. # issue #5664
  807. if token_type_ids is None:
  808. if hasattr(self, "token_type_ids"):
  809. # NOTE: We assume either pos ids to have bsz == 1 (broadcastable) or bsz == effective bsz (input_shape[0])
  810. buffered_token_type_ids = self.token_type_ids.expand(position_ids.shape[0], -1)
  811. buffered_token_type_ids = torch.gather(buffered_token_type_ids, dim=1, index=position_ids)
  812. token_type_ids = buffered_token_type_ids.expand(batch_size, seq_length)
  813. else:
  814. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  815. if inputs_embeds is None:
  816. inputs_embeds = self.word_embeddings(input_ids)
  817. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  818. embeddings = inputs_embeds + token_type_embeddings
  819. position_embeddings = self.position_embeddings(position_ids)
  820. embeddings = embeddings + position_embeddings
  821. embeddings = self.LayerNorm(embeddings)
  822. embeddings = self.dropout(embeddings)
  823. return embeddings
  824. @staticmethod
  825. def create_position_ids_from_inputs_embeds(inputs_embeds, padding_idx):
  826. """
  827. We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
  828. Args:
  829. inputs_embeds: torch.Tensor
  830. Returns: torch.Tensor
  831. """
  832. input_shape = inputs_embeds.size()[:-1]
  833. sequence_length = input_shape[1]
  834. position_ids = torch.arange(
  835. padding_idx + 1, sequence_length + padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
  836. )
  837. return position_ids.unsqueeze(0).expand(input_shape)
  838. @staticmethod
  839. def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
  840. """
  841. Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
  842. are ignored. This is modified from fairseq's `utils.make_positions`.
  843. Args:
  844. x: torch.Tensor x:
  845. Returns: torch.Tensor
  846. """
  847. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  848. mask = input_ids.ne(padding_idx).int()
  849. incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
  850. return incremental_indices.long() + padding_idx
  851. # Copied from transformers.models.align.modeling_align.eager_attention_forward
  852. def eager_attention_forward(
  853. module: nn.Module,
  854. query: torch.Tensor,
  855. key: torch.Tensor,
  856. value: torch.Tensor,
  857. attention_mask: torch.Tensor | None,
  858. scaling: float,
  859. dropout: float = 0.0,
  860. **kwargs,
  861. ):
  862. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  863. if attention_mask is not None:
  864. attn_weights = attn_weights + attention_mask
  865. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  866. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  867. attn_output = torch.matmul(attn_weights, value)
  868. attn_output = attn_output.transpose(1, 2).contiguous()
  869. return attn_output, attn_weights
  870. # Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with Align->Clap
  871. class ClapTextSelfAttention(nn.Module):
  872. def __init__(self, config):
  873. super().__init__()
  874. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  875. raise ValueError(
  876. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  877. f"heads ({config.num_attention_heads})"
  878. )
  879. self.config = config
  880. self.num_attention_heads = config.num_attention_heads
  881. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  882. self.all_head_size = self.num_attention_heads * self.attention_head_size
  883. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  884. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  885. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  886. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  887. self.attention_dropout = config.attention_probs_dropout_prob
  888. self.scaling = self.attention_head_size**-0.5
  889. def forward(
  890. self,
  891. hidden_states: torch.Tensor,
  892. attention_mask: torch.FloatTensor | None = None,
  893. **kwargs: Unpack[TransformersKwargs],
  894. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  895. input_shape = hidden_states.shape[:-1]
  896. hidden_shape = (*input_shape, -1, self.attention_head_size)
  897. query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  898. key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
  899. value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
  900. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  901. self.config._attn_implementation, eager_attention_forward
  902. )
  903. attn_output, attn_weights = attention_interface(
  904. self,
  905. query_states,
  906. key_states,
  907. value_states,
  908. attention_mask,
  909. dropout=0.0 if not self.training else self.attention_dropout,
  910. scaling=self.scaling,
  911. **kwargs,
  912. )
  913. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  914. return attn_output, attn_weights
  915. # Copied from transformers.models.bert.modeling_bert.BertSelfOutput
  916. class ClapTextSelfOutput(nn.Module):
  917. def __init__(self, config):
  918. super().__init__()
  919. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  920. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  921. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  922. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  923. hidden_states = self.dense(hidden_states)
  924. hidden_states = self.dropout(hidden_states)
  925. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  926. return hidden_states
  927. # Copied from transformers.models.align.modeling_align.AlignTextAttention with Align->Clap
  928. class ClapTextAttention(nn.Module):
  929. def __init__(self, config):
  930. super().__init__()
  931. self.self = ClapTextSelfAttention(config)
  932. self.output = ClapTextSelfOutput(config)
  933. def forward(
  934. self,
  935. hidden_states: torch.Tensor,
  936. attention_mask: torch.FloatTensor | None = None,
  937. **kwargs: Unpack[TransformersKwargs],
  938. ) -> torch.Tensor:
  939. residual = hidden_states
  940. hidden_states, _ = self.self(
  941. hidden_states,
  942. attention_mask=attention_mask,
  943. **kwargs,
  944. )
  945. hidden_states = self.output(hidden_states, residual)
  946. return hidden_states
  947. # Copied from transformers.models.bert.modeling_bert.BertIntermediate
  948. class ClapTextIntermediate(nn.Module):
  949. def __init__(self, config):
  950. super().__init__()
  951. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  952. if isinstance(config.hidden_act, str):
  953. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  954. else:
  955. self.intermediate_act_fn = config.hidden_act
  956. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  957. hidden_states = self.dense(hidden_states)
  958. hidden_states = self.intermediate_act_fn(hidden_states)
  959. return hidden_states
  960. # Copied from transformers.models.bert.modeling_bert.BertOutput
  961. class ClapTextOutput(nn.Module):
  962. def __init__(self, config):
  963. super().__init__()
  964. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  965. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  966. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  967. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  968. hidden_states = self.dense(hidden_states)
  969. hidden_states = self.dropout(hidden_states)
  970. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  971. return hidden_states
  972. # Copied from transformers.models.align.modeling_align.AlignTextLayer with Align->Clap
  973. class ClapTextLayer(GradientCheckpointingLayer):
  974. def __init__(self, config):
  975. super().__init__()
  976. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  977. self.seq_len_dim = 1
  978. self.attention = ClapTextAttention(config)
  979. self.intermediate = ClapTextIntermediate(config)
  980. self.output = ClapTextOutput(config)
  981. def forward(
  982. self,
  983. hidden_states: torch.Tensor,
  984. attention_mask: torch.FloatTensor | None = None,
  985. **kwargs: Unpack[TransformersKwargs],
  986. ) -> torch.Tensor:
  987. hidden_states = self.attention(
  988. hidden_states,
  989. attention_mask=attention_mask,
  990. **kwargs,
  991. )
  992. hidden_states = apply_chunking_to_forward(
  993. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, hidden_states
  994. )
  995. return hidden_states
  996. def feed_forward_chunk(self, attention_output):
  997. intermediate_output = self.intermediate(attention_output)
  998. layer_output = self.output(intermediate_output, attention_output)
  999. return layer_output
  1000. # Copied from transformers.models.align.modeling_align.AlignTextEncoder with Align->Clap
  1001. class ClapTextEncoder(nn.Module):
  1002. def __init__(self, config):
  1003. super().__init__()
  1004. self.config = config
  1005. self.layer = nn.ModuleList([ClapTextLayer(config) for i in range(config.num_hidden_layers)])
  1006. self.gradient_checkpointing = False
  1007. def forward(
  1008. self,
  1009. hidden_states: torch.Tensor,
  1010. attention_mask: torch.FloatTensor | None = None,
  1011. **kwargs: Unpack[TransformersKwargs],
  1012. ) -> BaseModelOutput:
  1013. for layer_module in self.layer:
  1014. hidden_states = layer_module(
  1015. hidden_states,
  1016. attention_mask,
  1017. **kwargs,
  1018. )
  1019. return BaseModelOutput(
  1020. last_hidden_state=hidden_states,
  1021. )
  1022. # Copied from transformers.models.bert.modeling_bert.BertPooler
  1023. class ClapTextPooler(nn.Module):
  1024. def __init__(self, config):
  1025. super().__init__()
  1026. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  1027. self.activation = nn.Tanh()
  1028. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  1029. # We "pool" the model by simply taking the hidden state corresponding
  1030. # to the first token.
  1031. first_token_tensor = hidden_states[:, 0]
  1032. pooled_output = self.dense(first_token_tensor)
  1033. pooled_output = self.activation(pooled_output)
  1034. return pooled_output
@auto_docstring
class ClapPreTrainedModel(PreTrainedModel):
    # Shared base for the CLAP models in this file: declares the config type, model prefix, supported input
    # modalities, and the per-module weight-initialization rules in `_init_weights`.
    config: ClapConfig
    base_model_prefix = "clap"
    input_modalities = ("audio", "text")
    supports_gradient_checkpointing = False

    @torch.no_grad()
    def _init_weights(self, module: nn.Module):
        """
        Initialize the weights

        Dispatches on the type of `module`; only the first matching `isinstance` branch runs. Normal-init
        standard deviations are scaled by `config.initializer_factor`. This is presumably invoked once per
        submodule by `PreTrainedModel` (e.g. through `post_init`) — that machinery is not visible here.
        """
        factor = self.config.initializer_factor

        if isinstance(module, ClapTextEmbeddings):
            # position / token-type tables get N(0, 0.02 * factor); the buffers are reset to their defaults
            # (position_ids = arange, token_type_ids = zeros)
            init.normal_(module.position_embeddings.weight, mean=0.0, std=factor * 0.02)
            init.normal_(module.token_type_embeddings.weight, mean=0.0, std=factor * 0.02)
            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
            init.zeros_(module.token_type_ids)
        elif isinstance(module, ClapModel):
            # both learnable logit scales start at log(logit_scale_init_value)
            init.constant_(module.logit_scale_a, math.log(self.config.logit_scale_init_value))
            init.constant_(module.logit_scale_t, math.log(self.config.logit_scale_init_value))
        elif isinstance(module, nn.Embedding):
            # NOTE(review): the padding_idx row is not zeroed here after the normal init — confirm intended
            init.normal_(module.weight, mean=0.0, std=factor * 0.02)
        elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
            init.zeros_(module.bias)
            init.ones_(module.weight)
            # BatchNorm layers also carry running statistics that need resetting
            if getattr(module, "running_mean", None) is not None:
                init.zeros_(module.running_mean)
                init.ones_(module.running_var)
                init.zeros_(module.num_batches_tracked)
        elif isinstance(module, (nn.Conv2d, nn.Linear)):
            # std = hidden_size^-0.5 * (2 * num_hidden_layers)^-0.5 * initializer_factor
            in_proj_std = (self.config.hidden_size**-0.5) * ((2 * self.config.num_hidden_layers) ** -0.5) * factor
            init.normal_(module.weight, std=in_proj_std)
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, ClapAudioSelfAttention):
            # zero the relative-position bias table and rebuild the (non-learned) relative position index
            init.zeros_(module.relative_position_bias_table)
            init.copy_(module.relative_position_index, module.create_relative_position_index())
  1070. class ClapAudioModel(ClapPreTrainedModel):
  1071. config: ClapAudioConfig
  1072. main_input_name = "input_features"
  1073. input_modalities = "audio"
  1074. def __init__(self, config: ClapAudioConfig):
  1075. super().__init__(config)
  1076. self.audio_encoder = ClapAudioEncoder(config)
  1077. # Initialize weights and apply final processing
  1078. self.post_init()
  1079. def get_input_embeddings(self) -> nn.Module:
  1080. return self.audio_encoder.patch_embed.proj
  1081. @auto_docstring
  1082. def forward(
  1083. self,
  1084. input_features: torch.FloatTensor | None = None,
  1085. is_longer: torch.BoolTensor | None = None,
  1086. **kwargs: Unpack[TransformersKwargs],
  1087. ) -> tuple | BaseModelOutputWithPooling:
  1088. r"""
  1089. is_longer (`torch.FloatTensor`, of shape `(batch_size, 1)`, *optional*):
  1090. Whether the audio clip is longer than `max_length`. If `True`, a feature fusion will be enabled to enhance
  1091. the features.
  1092. Examples:
  1093. ```python
  1094. >>> from datasets import load_dataset
  1095. >>> from transformers import AutoProcessor, ClapAudioModel
  1096. >>> dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example")
  1097. >>> audio_sample = dataset["train"]["audio"][0]["array"]
  1098. >>> model = ClapAudioModel.from_pretrained("laion/clap-htsat-fused")
  1099. >>> processor = AutoProcessor.from_pretrained("laion/clap-htsat-fused")
  1100. >>> inputs = processor(audio=audio_sample, return_tensors="pt")
  1101. >>> outputs = model(**inputs)
  1102. >>> last_hidden_state = outputs.last_hidden_state
  1103. ```"""
  1104. return self.audio_encoder(
  1105. input_features=input_features,
  1106. is_longer=is_longer,
  1107. **kwargs,
  1108. )
  1109. @auto_docstring(
  1110. custom_intro="""
  1111. The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
  1112. cross-attention is added between the self-attention layers, following the architecture described in *Attention is
  1113. all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
  1114. Kaiser and Illia Polosukhin.
  1115. To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
  1116. to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
  1117. `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
  1118. .. _*Attention is all you need*: https://huggingface.co/papers/1706.03762
  1119. """
  1120. )
  1121. class ClapTextModel(ClapPreTrainedModel):
  1122. config: ClapTextConfig
  1123. input_modalities = ("text",)
  1124. _can_record_outputs = {
  1125. "hidden_states": ClapTextLayer,
  1126. "attentions": ClapTextSelfAttention,
  1127. }
  1128. def __init__(self, config, add_pooling_layer=True):
  1129. r"""
  1130. add_pooling_layer (bool, *optional*, defaults to `True`):
  1131. Whether to add a pooling layer
  1132. """
  1133. super().__init__(config)
  1134. self.config = config
  1135. self.embeddings = ClapTextEmbeddings(config)
  1136. self.encoder = ClapTextEncoder(config)
  1137. self.pooler = ClapTextPooler(config) if add_pooling_layer else None
  1138. # Initialize weights and apply final processing
  1139. self.post_init()
  1140. def get_input_embeddings(self):
  1141. return self.embeddings.word_embeddings
  1142. def set_input_embeddings(self, value):
  1143. self.embeddings.word_embeddings = value
  1144. @merge_with_config_defaults
  1145. @capture_outputs
  1146. @auto_docstring
  1147. def forward(
  1148. self,
  1149. input_ids: torch.Tensor | None = None,
  1150. attention_mask: torch.Tensor | None = None,
  1151. token_type_ids: torch.Tensor | None = None,
  1152. position_ids: torch.Tensor | None = None,
  1153. inputs_embeds: torch.Tensor | None = None,
  1154. **kwargs: Unpack[TransformersKwargs],
  1155. ) -> BaseModelOutputWithPooling:
  1156. if input_ids is not None and inputs_embeds is not None:
  1157. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  1158. elif input_ids is not None:
  1159. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  1160. input_shape = input_ids.size()
  1161. elif inputs_embeds is not None:
  1162. input_shape = inputs_embeds.size()[:-1]
  1163. else:
  1164. raise ValueError("You have to specify either input_ids or inputs_embeds")
  1165. batch_size, seq_length = input_shape
  1166. device = input_ids.device if input_ids is not None else inputs_embeds.device
  1167. if attention_mask is None:
  1168. attention_mask = torch.ones(((batch_size, seq_length)), device=device)
  1169. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  1170. # ourselves in which case we just need to make it broadcastable to all heads.
  1171. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
  1172. embedding_output = self.embeddings(
  1173. input_ids=input_ids,
  1174. position_ids=position_ids,
  1175. token_type_ids=token_type_ids,
  1176. inputs_embeds=inputs_embeds,
  1177. )
  1178. encoder_outputs = self.encoder(
  1179. embedding_output,
  1180. attention_mask=extended_attention_mask,
  1181. **kwargs,
  1182. )
  1183. sequence_output = encoder_outputs[0]
  1184. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  1185. return BaseModelOutputWithPooling(
  1186. last_hidden_state=sequence_output,
  1187. pooler_output=pooled_output,
  1188. )
  1189. @auto_docstring
  1190. class ClapModel(ClapPreTrainedModel):
  1191. config: ClapConfig
  1192. def __init__(self, config: ClapConfig):
  1193. super().__init__(config)
  1194. if not isinstance(config.text_config, ClapTextConfig):
  1195. raise TypeError(
  1196. "config.text_config is expected to be of type ClapTextConfig but is of type"
  1197. f" {type(config.text_config)}."
  1198. )
  1199. if not isinstance(config.audio_config, ClapAudioConfig):
  1200. raise TypeError(
  1201. "config.audio_config is expected to be of type ClapAudioConfig but is of type"
  1202. f" {type(config.audio_config)}."
  1203. )
  1204. text_config = config.text_config
  1205. audio_config = config.audio_config
  1206. self.logit_scale_a = nn.Parameter(torch.tensor(math.log(config.logit_scale_init_value)))
  1207. self.logit_scale_t = nn.Parameter(torch.tensor(math.log(config.logit_scale_init_value)))
  1208. self.projection_dim = config.projection_dim
  1209. self.text_model = ClapTextModel(text_config)
  1210. self.text_projection = ClapProjectionLayer(text_config)
  1211. self.audio_model = ClapAudioModel(audio_config)
  1212. self.audio_projection = ClapProjectionLayer(audio_config)
  1213. # Initialize weights and apply final processing
  1214. self.post_init()
  1215. @can_return_tuple
  1216. @auto_docstring
  1217. def get_text_features(
  1218. self,
  1219. input_ids: torch.Tensor,
  1220. attention_mask: torch.Tensor | None = None,
  1221. position_ids: torch.Tensor | None = None,
  1222. **kwargs: Unpack[TransformersKwargs],
  1223. ) -> tuple | BaseModelOutputWithPooling:
  1224. r"""
  1225. Examples:
  1226. ```python
  1227. >>> import torch
  1228. >>> from transformers import AutoTokenizer, ClapModel
  1229. >>> model = ClapModel.from_pretrained("laion/clap-htsat-unfused")
  1230. >>> tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")
  1231. >>> inputs = tokenizer(["the sound of a cat", "the sound of a dog"], padding=True, return_tensors="pt")
  1232. >>> with torch.inference_mode():
  1233. ... text_features = model.get_text_features(**inputs)
  1234. ```"""
  1235. text_outputs: BaseModelOutputWithPooling = self.text_model(
  1236. input_ids=input_ids,
  1237. attention_mask=attention_mask,
  1238. position_ids=position_ids,
  1239. **kwargs,
  1240. )
  1241. text_features = self.text_projection(text_outputs.pooler_output)
  1242. text_outputs.pooler_output = F.normalize(text_features, dim=-1)
  1243. return text_outputs
  1244. @can_return_tuple
  1245. @auto_docstring
  1246. def get_audio_features(
  1247. self,
  1248. input_features: torch.Tensor,
  1249. is_longer: torch.Tensor | None = None,
  1250. attention_mask: torch.Tensor | None = None,
  1251. **kwargs: Unpack[TransformersKwargs],
  1252. ) -> tuple | BaseModelOutputWithPooling:
  1253. r"""
  1254. is_longer (`torch.FloatTensor`, of shape `(batch_size, 1)`, *optional*):
  1255. Whether the audio clip is longer than `max_length`. If `True`, a feature fusion will be enabled to enhance
  1256. the features.
  1257. Examples:
  1258. ```python
  1259. >>> import torch
  1260. >>> from transformers import AutoFeatureExtractor, ClapModel
  1261. >>> model = ClapModel.from_pretrained("laion/clap-htsat-unfused")
  1262. >>> feature_extractor = AutoFeatureExtractor.from_pretrained("laion/clap-htsat-unfused")
  1263. >>> random_audio = torch.rand((16_000))
  1264. >>> inputs = feature_extractor(random_audio, return_tensors="pt")
  1265. >>> with torch.inference_mode():
  1266. ... audio_features = model.get_audio_features(**inputs)
  1267. ```"""
  1268. audio_outputs: BaseModelOutputWithPooling = self.audio_model(
  1269. input_features=input_features, is_longer=is_longer, **kwargs
  1270. )
  1271. audio_features = self.audio_projection(audio_outputs.pooler_output)
  1272. audio_outputs.pooler_output = F.normalize(audio_features, dim=-1)
  1273. return audio_outputs
  1274. @can_return_tuple
  1275. @auto_docstring
  1276. def forward(
  1277. self,
  1278. input_ids: torch.LongTensor | None = None,
  1279. input_features: torch.FloatTensor | None = None,
  1280. is_longer: torch.BoolTensor | None = None,
  1281. attention_mask: torch.Tensor | None = None,
  1282. position_ids: torch.LongTensor | None = None,
  1283. return_loss: bool | None = None,
  1284. **kwargs: Unpack[TransformersKwargs],
  1285. ) -> tuple | ClapOutput:
  1286. r"""
  1287. is_longer (`torch.FloatTensor`, of shape `(batch_size, 1)`, *optional*):
  1288. Whether the audio clip is longer than `max_length`. If `True`, a feature fusion will be enabled to enhance
  1289. the features.
  1290. return_loss (`bool`, *optional*):
  1291. Whether or not to return the contrastive loss.
  1292. Examples:
  1293. ```python
  1294. >>> from datasets import load_dataset
  1295. >>> from transformers import AutoProcessor, ClapModel
  1296. >>> dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example")
  1297. >>> audio_sample = dataset["train"]["audio"][0]["array"]
  1298. >>> model = ClapModel.from_pretrained("laion/clap-htsat-unfused")
  1299. >>> processor = AutoProcessor.from_pretrained("laion/clap-htsat-unfused")
  1300. >>> input_text = ["Sound of a dog", "Sound of vacuum cleaner"]
  1301. >>> inputs = processor(text=input_text, audio=audio_sample, return_tensors="pt", padding=True)
  1302. >>> outputs = model(**inputs)
  1303. >>> logits_per_audio = outputs.logits_per_audio # this is the audio-text similarity score
  1304. >>> probs = logits_per_audio.softmax(dim=-1) # we can take the softmax to get the label probabilities
  1305. ```"""
  1306. audio_outputs = self.audio_model(
  1307. input_features=input_features,
  1308. is_longer=is_longer,
  1309. **kwargs,
  1310. )
  1311. text_outputs = self.text_model(
  1312. input_ids=input_ids,
  1313. attention_mask=attention_mask,
  1314. position_ids=position_ids,
  1315. **kwargs,
  1316. )
  1317. audio_embeds = audio_outputs.pooler_output
  1318. audio_embeds = self.audio_projection(audio_embeds)
  1319. text_embeds = text_outputs.pooler_output
  1320. text_embeds = self.text_projection(text_embeds)
  1321. # normalized features
  1322. audio_embeds = audio_embeds / audio_embeds.norm(p=2, dim=-1, keepdim=True)
  1323. text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
  1324. # cosine similarity as logits
  1325. logit_scale_text = self.logit_scale_t.exp()
  1326. logit_scale_audio = self.logit_scale_a.exp()
  1327. logits_per_text = torch.matmul(text_embeds, audio_embeds.t()) * logit_scale_text
  1328. logits_per_audio = torch.matmul(audio_embeds, text_embeds.t()) * logit_scale_audio
  1329. loss = None
  1330. if return_loss:
  1331. caption_loss = contrastive_loss(logits_per_text)
  1332. audio_loss = contrastive_loss(logits_per_audio.t())
  1333. loss = (caption_loss + audio_loss) / 2.0
  1334. return ClapOutput(
  1335. loss=loss,
  1336. logits_per_audio=logits_per_audio,
  1337. logits_per_text=logits_per_text,
  1338. text_embeds=text_embeds,
  1339. audio_embeds=audio_embeds,
  1340. text_model_output=text_outputs,
  1341. audio_model_output=audio_outputs,
  1342. )
  1343. @auto_docstring
  1344. class ClapTextModelWithProjection(ClapPreTrainedModel):
  1345. config: ClapTextConfig
  1346. input_modalities = ("text",)
  1347. _can_record_outputs = {
  1348. "hidden_states": ClapTextLayer,
  1349. "attentions": ClapTextSelfAttention,
  1350. }
  1351. def __init__(self, config: ClapTextConfig):
  1352. super().__init__(config)
  1353. self.text_model = ClapTextModel(config)
  1354. self.text_projection = ClapProjectionLayer(config)
  1355. # Initialize weights and apply final processing
  1356. self.post_init()
  1357. def get_input_embeddings(self) -> nn.Module:
  1358. return self.text_model.embeddings.word_embeddings
  1359. def set_input_embeddings(self, value):
  1360. self.text_model.embeddings.word_embeddings = value
  1361. @can_return_tuple
  1362. @auto_docstring
  1363. def forward(
  1364. self,
  1365. input_ids: torch.Tensor | None = None,
  1366. attention_mask: torch.Tensor | None = None,
  1367. position_ids: torch.Tensor | None = None,
  1368. **kwargs: Unpack[TransformersKwargs],
  1369. ) -> tuple | ClapTextModelOutput:
  1370. r"""
  1371. Examples:
  1372. ```python
  1373. >>> from transformers import AutoTokenizer, ClapTextModelWithProjection
  1374. >>> model = ClapTextModelWithProjection.from_pretrained("laion/clap-htsat-unfused")
  1375. >>> tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")
  1376. >>> inputs = tokenizer(["a sound of a cat", "a sound of a dog"], padding=True, return_tensors="pt")
  1377. >>> outputs = model(**inputs)
  1378. >>> text_embeds = outputs.text_embeds
  1379. ```"""
  1380. text_outputs: BaseModelOutputWithPooling = self.text_model(
  1381. input_ids=input_ids,
  1382. attention_mask=attention_mask,
  1383. position_ids=position_ids,
  1384. **kwargs,
  1385. )
  1386. pooled_output = text_outputs.pooler_output
  1387. text_embeds = self.text_projection(pooled_output)
  1388. return ClapTextModelOutput(
  1389. text_embeds=text_embeds,
  1390. last_hidden_state=text_outputs.last_hidden_state,
  1391. hidden_states=text_outputs.hidden_states,
  1392. attentions=text_outputs.attentions,
  1393. )
  1394. @auto_docstring
  1395. class ClapAudioModelWithProjection(ClapPreTrainedModel):
  1396. config: ClapAudioConfig
  1397. main_input_name = "input_features"
  1398. input_modalities = "audio"
  1399. def __init__(self, config: ClapAudioConfig):
  1400. super().__init__(config)
  1401. self.audio_model = ClapAudioModel(config)
  1402. self.audio_projection = ClapProjectionLayer(config)
  1403. # Initialize weights and apply final processing
  1404. self.post_init()
  1405. def get_input_embeddings(self) -> nn.Module:
  1406. return self.audio_model.audio_encoder.patch_embed.proj
  1407. @can_return_tuple
  1408. @auto_docstring
  1409. def forward(
  1410. self,
  1411. input_features: torch.FloatTensor | None = None,
  1412. is_longer: torch.BoolTensor | None = None,
  1413. **kwargs: Unpack[TransformersKwargs],
  1414. ) -> tuple | ClapAudioModelOutput:
  1415. r"""
  1416. is_longer (`torch.FloatTensor`, of shape `(batch_size, 1)`, *optional*):
  1417. Whether the audio clip is longer than `max_length`. If `True`, a feature fusion will be enabled to enhance
  1418. the features.
  1419. Examples:
  1420. ```python
  1421. >>> from datasets import load_dataset
  1422. >>> from transformers import ClapAudioModelWithProjection, ClapProcessor
  1423. >>> model = ClapAudioModelWithProjection.from_pretrained("laion/clap-htsat-fused")
  1424. >>> processor = ClapProcessor.from_pretrained("laion/clap-htsat-fused")
  1425. >>> dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example")
  1426. >>> audio_sample = dataset["train"]["audio"][0]["array"]
  1427. >>> inputs = processor(audio=audio_sample, return_tensors="pt")
  1428. >>> outputs = model(**inputs)
  1429. >>> audio_embeds = outputs.audio_embeds
  1430. ```"""
  1431. audio_outputs: BaseModelOutputWithPooling = self.audio_model(
  1432. input_features=input_features,
  1433. is_longer=is_longer,
  1434. **kwargs,
  1435. )
  1436. audio_embeds = self.audio_projection(audio_outputs.pooler_output)
  1437. return ClapAudioModelOutput(
  1438. audio_embeds=audio_embeds,
  1439. last_hidden_state=audio_outputs.last_hidden_state,
  1440. attentions=audio_outputs.attentions,
  1441. hidden_states=audio_outputs.hidden_states,
  1442. )
  1443. __all__ = [
  1444. "ClapModel",
  1445. "ClapPreTrainedModel",
  1446. "ClapTextModel",
  1447. "ClapTextModelWithProjection",
  1448. "ClapAudioModel",
  1449. "ClapAudioModelWithProjection",
  1450. ]