modeling_align.py 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180
  1. # Copyright 2023 The Google Research Team Authors and The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch ALIGN model."""
  15. import math
  16. from collections.abc import Callable
  17. from dataclasses import dataclass
  18. from typing import Any
  19. import torch
  20. from torch import nn
  21. from ... import initialization as init
  22. from ...activations import ACT2FN
  23. from ...modeling_layers import GradientCheckpointingLayer
  24. from ...modeling_outputs import (
  25. BaseModelOutput,
  26. BaseModelOutputWithNoAttention,
  27. BaseModelOutputWithPooling,
  28. BaseModelOutputWithPoolingAndNoAttention,
  29. )
  30. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  31. from ...processing_utils import Unpack
  32. from ...pytorch_utils import apply_chunking_to_forward
  33. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, logging
  34. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  35. from ...utils.output_capturing import capture_outputs
  36. from .configuration_align import AlignConfig, AlignTextConfig, AlignVisionConfig
  37. logger = logging.get_logger(__name__)
  38. @dataclass
  39. @auto_docstring(
  40. custom_intro="""
  41. Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
  42. """
  43. )
  44. class AlignVisionModelOutput(ModelOutput):
  45. r"""
  46. image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
  47. The image embeddings obtained by applying the projection layer to the pooler_output.
  48. """
  49. image_embeds: torch.FloatTensor | None = None
  50. last_hidden_state: torch.FloatTensor | None = None
  51. hidden_states: tuple[torch.FloatTensor] | None = None
  52. @dataclass
  53. @auto_docstring(
  54. custom_intro="""
  55. Base class for text model's outputs that also contains a pooling of the last hidden states.
  56. """
  57. )
  58. class AlignTextModelOutput(ModelOutput):
  59. r"""
  60. text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
  61. The text embeddings obtained by applying the projection layer to the pooler_output.
  62. """
  63. text_embeds: torch.FloatTensor | None = None
  64. last_hidden_state: torch.FloatTensor | None = None
  65. hidden_states: tuple[torch.FloatTensor] | None = None
  66. attentions: tuple[torch.FloatTensor] | None = None
  67. @dataclass
  68. @auto_docstring
  69. class AlignOutput(ModelOutput):
  70. r"""
  71. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
  72. Contrastive loss for image-text similarity.
  73. logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
  74. The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
  75. similarity scores.
  76. logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
  77. The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
  78. similarity scores.
  79. text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
  80. The text embeddings obtained by applying the projection layer to the pooled output of [`AlignTextModel`].
  81. image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
  82. The output of [`AlignVisionModel`].
  83. text_model_output (`BaseModelOutputWithPooling`):
  84. The output of the [`AlignTextModel`].
  85. vision_model_output (`BaseModelOutputWithPoolingAndNoAttention`):
  86. The output of the [`AlignVisionModel`].
  87. """
  88. loss: torch.FloatTensor | None = None
  89. logits_per_image: torch.FloatTensor | None = None
  90. logits_per_text: torch.FloatTensor | None = None
  91. text_embeds: torch.FloatTensor | None = None
  92. image_embeds: torch.FloatTensor | None = None
  93. text_model_output: BaseModelOutputWithPooling = None
  94. vision_model_output: BaseModelOutputWithPoolingAndNoAttention = None
  95. def to_tuple(self) -> tuple[Any]:
  96. return tuple(
  97. self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
  98. for k in self.keys()
  99. )
  100. # contrastive loss function, adapted from
  101. # https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
  102. def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
  103. return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device), label_smoothing=0.1)
  104. def align_loss(similarity: torch.Tensor) -> torch.Tensor:
  105. caption_loss = contrastive_loss(similarity)
  106. image_loss = contrastive_loss(similarity.t())
  107. return (caption_loss + image_loss) / 2.0
  108. # Copied from transformers.models.efficientnet.modeling_efficientnet.round_filters with EfficientNet->AlignVision
  109. def round_filters(config: AlignVisionConfig, num_channels: int):
  110. r"""
  111. Round number of filters based on depth multiplier.
  112. """
  113. divisor = config.depth_divisor
  114. num_channels *= config.width_coefficient
  115. new_dim = max(divisor, int(num_channels + divisor / 2) // divisor * divisor)
  116. # Make sure that round down does not go down by more than 10%.
  117. if new_dim < 0.9 * num_channels:
  118. new_dim += divisor
  119. return int(new_dim)
  120. # Copied from transformers.models.efficientnet.modeling_efficientnet.correct_pad
  121. def correct_pad(kernel_size: int | tuple, adjust: bool = True):
  122. r"""
  123. Utility function to get the tuple padding value for the depthwise convolution.
  124. Args:
  125. kernel_size (`int` or `tuple`):
  126. Kernel size of the convolution layers.
  127. adjust (`bool`, *optional*, defaults to `True`):
  128. Adjusts padding value to apply to right and bottom sides of the input.
  129. """
  130. if isinstance(kernel_size, int):
  131. kernel_size = (kernel_size, kernel_size)
  132. correct = (kernel_size[0] // 2, kernel_size[1] // 2)
  133. if adjust:
  134. return (correct[1] - 1, correct[1], correct[0] - 1, correct[0])
  135. else:
  136. return (correct[1], correct[1], correct[0], correct[0])
  137. # Copied from transformers.models.efficientnet.modeling_efficientnet.EfficientNetEmbeddings with EfficientNet->AlignVision
  138. class AlignVisionEmbeddings(nn.Module):
  139. r"""
  140. A module that corresponds to the stem module of the original work.
  141. """
  142. def __init__(self, config: AlignVisionConfig):
  143. super().__init__()
  144. self.out_dim = round_filters(config, 32)
  145. self.padding = nn.ZeroPad2d(padding=(0, 1, 0, 1))
  146. self.convolution = nn.Conv2d(
  147. config.num_channels, self.out_dim, kernel_size=3, stride=2, padding="valid", bias=False
  148. )
  149. self.batchnorm = nn.BatchNorm2d(self.out_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum)
  150. self.activation = ACT2FN[config.hidden_act]
  151. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  152. features = self.padding(pixel_values)
  153. features = self.convolution(features)
  154. features = self.batchnorm(features)
  155. features = self.activation(features)
  156. return features
  157. # Copied from transformers.models.efficientnet.modeling_efficientnet.EfficientNetDepthwiseConv2d with EfficientNet->AlignVision
  158. class AlignVisionDepthwiseConv2d(nn.Conv2d):
  159. def __init__(
  160. self,
  161. in_channels,
  162. depth_multiplier=1,
  163. kernel_size=3,
  164. stride=1,
  165. padding=0,
  166. dilation=1,
  167. bias=True,
  168. padding_mode="zeros",
  169. ):
  170. out_channels = in_channels * depth_multiplier
  171. super().__init__(
  172. in_channels=in_channels,
  173. out_channels=out_channels,
  174. kernel_size=kernel_size,
  175. stride=stride,
  176. padding=padding,
  177. dilation=dilation,
  178. groups=in_channels,
  179. bias=bias,
  180. padding_mode=padding_mode,
  181. )
  182. # Copied from transformers.models.efficientnet.modeling_efficientnet.EfficientNetExpansionLayer with EfficientNet->AlignVision
  183. class AlignVisionExpansionLayer(nn.Module):
  184. r"""
  185. This corresponds to the expansion phase of each block in the original implementation.
  186. """
  187. def __init__(self, config: AlignVisionConfig, in_dim: int, out_dim: int, stride: int):
  188. super().__init__()
  189. self.expand_conv = nn.Conv2d(
  190. in_channels=in_dim,
  191. out_channels=out_dim,
  192. kernel_size=1,
  193. padding="same",
  194. bias=False,
  195. )
  196. self.expand_bn = nn.BatchNorm2d(num_features=out_dim, eps=config.batch_norm_eps)
  197. self.expand_act = ACT2FN[config.hidden_act]
  198. def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
  199. # Expand phase
  200. hidden_states = self.expand_conv(hidden_states)
  201. hidden_states = self.expand_bn(hidden_states)
  202. hidden_states = self.expand_act(hidden_states)
  203. return hidden_states
  204. # Copied from transformers.models.efficientnet.modeling_efficientnet.EfficientNetDepthwiseLayer with EfficientNet->AlignVision
  205. class AlignVisionDepthwiseLayer(nn.Module):
  206. r"""
  207. This corresponds to the depthwise convolution phase of each block in the original implementation.
  208. """
  209. def __init__(
  210. self,
  211. config: AlignVisionConfig,
  212. in_dim: int,
  213. stride: int,
  214. kernel_size: int,
  215. adjust_padding: bool,
  216. ):
  217. super().__init__()
  218. self.stride = stride
  219. conv_pad = "valid" if self.stride == 2 else "same"
  220. padding = correct_pad(kernel_size, adjust=adjust_padding)
  221. self.depthwise_conv_pad = nn.ZeroPad2d(padding=padding)
  222. self.depthwise_conv = AlignVisionDepthwiseConv2d(
  223. in_dim, kernel_size=kernel_size, stride=stride, padding=conv_pad, bias=False
  224. )
  225. self.depthwise_norm = nn.BatchNorm2d(
  226. num_features=in_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum
  227. )
  228. self.depthwise_act = ACT2FN[config.hidden_act]
  229. def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
  230. # Depthwise convolution
  231. if self.stride == 2:
  232. hidden_states = self.depthwise_conv_pad(hidden_states)
  233. hidden_states = self.depthwise_conv(hidden_states)
  234. hidden_states = self.depthwise_norm(hidden_states)
  235. hidden_states = self.depthwise_act(hidden_states)
  236. return hidden_states
  237. # Copied from transformers.models.efficientnet.modeling_efficientnet.EfficientNetSqueezeExciteLayer with EfficientNet->AlignVision
  238. class AlignVisionSqueezeExciteLayer(nn.Module):
  239. r"""
  240. This corresponds to the Squeeze and Excitement phase of each block in the original implementation.
  241. """
  242. def __init__(self, config: AlignVisionConfig, in_dim: int, expand_dim: int, expand: bool = False):
  243. super().__init__()
  244. self.dim = expand_dim if expand else in_dim
  245. self.dim_se = max(1, int(in_dim * config.squeeze_expansion_ratio))
  246. self.squeeze = nn.AdaptiveAvgPool2d(output_size=1)
  247. self.reduce = nn.Conv2d(
  248. in_channels=self.dim,
  249. out_channels=self.dim_se,
  250. kernel_size=1,
  251. padding="same",
  252. )
  253. self.expand = nn.Conv2d(
  254. in_channels=self.dim_se,
  255. out_channels=self.dim,
  256. kernel_size=1,
  257. padding="same",
  258. )
  259. self.act_reduce = ACT2FN[config.hidden_act]
  260. self.act_expand = nn.Sigmoid()
  261. def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
  262. inputs = hidden_states
  263. hidden_states = self.squeeze(hidden_states)
  264. hidden_states = self.reduce(hidden_states)
  265. hidden_states = self.act_reduce(hidden_states)
  266. hidden_states = self.expand(hidden_states)
  267. hidden_states = self.act_expand(hidden_states)
  268. hidden_states = torch.mul(inputs, hidden_states)
  269. return hidden_states
  270. class AlignVisionFinalBlockLayer(nn.Module):
  271. r"""
  272. This corresponds to the final phase of each block in the original implementation.
  273. """
  274. def __init__(
  275. self, config: AlignVisionConfig, in_dim: int, out_dim: int, stride: int, drop_rate: float, id_skip: bool
  276. ):
  277. super().__init__()
  278. self.apply_dropout = stride == 1 and not id_skip
  279. self.project_conv = nn.Conv2d(
  280. in_channels=in_dim,
  281. out_channels=out_dim,
  282. kernel_size=1,
  283. padding="same",
  284. bias=False,
  285. )
  286. self.project_bn = nn.BatchNorm2d(
  287. num_features=out_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum
  288. )
  289. self.dropout = nn.Dropout(p=drop_rate)
  290. def forward(self, embeddings: torch.FloatTensor, hidden_states: torch.FloatTensor) -> torch.Tensor:
  291. hidden_states = self.project_conv(hidden_states)
  292. hidden_states = self.project_bn(hidden_states)
  293. if self.apply_dropout:
  294. hidden_states = self.dropout(hidden_states)
  295. hidden_states = hidden_states + embeddings
  296. return hidden_states
  297. class AlignVisionBlock(nn.Module):
  298. r"""
  299. This corresponds to the block module of original the EfficientNet vision encoder implementation.
  300. Args:
  301. config ([`AlignVisionConfig`]):
  302. Model configuration class.
  303. in_dim (`int`):
  304. Number of input channels.
  305. out_dim (`int`):
  306. Number of output channels.
  307. stride (`int`):
  308. Stride size to be used in convolution layers.
  309. expand_ratio (`int`):
  310. Expand ratio to set the output dimensions for the expansion and squeeze-excite layers.
  311. kernel_size (`int`):
  312. Kernel size for the depthwise convolution layer.
  313. drop_rate (`float`):
  314. Dropout rate to be used in the final phase of each block.
  315. id_skip (`bool`):
  316. Whether to apply dropout and sum the final hidden states with the input embeddings during the final phase
  317. of each block. Set to `True` for the first block of each stage.
  318. adjust_padding (`bool`):
  319. Whether to apply padding to only right and bottom side of the input kernel before the depthwise convolution
  320. operation, set to `True` for inputs with odd input sizes.
  321. """
  322. def __init__(
  323. self,
  324. config: AlignVisionConfig,
  325. in_dim: int,
  326. out_dim: int,
  327. stride: int,
  328. expand_ratio: int,
  329. kernel_size: int,
  330. drop_rate: float,
  331. id_skip: bool,
  332. adjust_padding: bool,
  333. ):
  334. super().__init__()
  335. self.expand_ratio = expand_ratio
  336. self.expand = self.expand_ratio != 1
  337. expand_in_dim = in_dim * expand_ratio
  338. if self.expand:
  339. self.expansion = AlignVisionExpansionLayer(
  340. config=config, in_dim=in_dim, out_dim=expand_in_dim, stride=stride
  341. )
  342. self.depthwise_conv = AlignVisionDepthwiseLayer(
  343. config=config,
  344. in_dim=expand_in_dim if self.expand else in_dim,
  345. stride=stride,
  346. kernel_size=kernel_size,
  347. adjust_padding=adjust_padding,
  348. )
  349. self.squeeze_excite = AlignVisionSqueezeExciteLayer(
  350. config=config, in_dim=in_dim, expand_dim=expand_in_dim, expand=self.expand
  351. )
  352. self.projection = AlignVisionFinalBlockLayer(
  353. config=config,
  354. in_dim=expand_in_dim if self.expand else in_dim,
  355. out_dim=out_dim,
  356. stride=stride,
  357. drop_rate=drop_rate,
  358. id_skip=id_skip,
  359. )
  360. def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
  361. embeddings = hidden_states
  362. # Expansion and depthwise convolution phase
  363. if self.expand_ratio != 1:
  364. hidden_states = self.expansion(hidden_states)
  365. hidden_states = self.depthwise_conv(hidden_states)
  366. # Squeeze and excite phase
  367. hidden_states = self.squeeze_excite(hidden_states)
  368. hidden_states = self.projection(embeddings, hidden_states)
  369. return hidden_states
  370. class AlignVisionEncoder(nn.Module):
  371. r"""
  372. Forward propagates the embeddings through each vision encoder (EfficientNet) block.
  373. Args:
  374. config ([`AlignVisionConfig`]):
  375. Model configuration class.
  376. """
  377. def __init__(self, config: AlignVisionConfig):
  378. super().__init__()
  379. self.depth_coefficient = config.depth_coefficient
  380. def round_repeats(repeats):
  381. # Round number of block repeats based on depth multiplier.
  382. return int(math.ceil(self.depth_coefficient * repeats))
  383. num_base_blocks = len(config.in_channels)
  384. num_blocks = sum(round_repeats(n) for n in config.num_block_repeats)
  385. curr_block_num = 0
  386. blocks = []
  387. for i in range(num_base_blocks):
  388. in_dim = round_filters(config, config.in_channels[i])
  389. out_dim = round_filters(config, config.out_channels[i])
  390. stride = config.strides[i]
  391. kernel_size = config.kernel_sizes[i]
  392. expand_ratio = config.expand_ratios[i]
  393. for j in range(round_repeats(config.num_block_repeats[i])):
  394. id_skip = j == 0
  395. stride = 1 if j > 0 else stride
  396. in_dim = out_dim if j > 0 else in_dim
  397. adjust_padding = curr_block_num not in config.depthwise_padding
  398. drop_rate = config.drop_connect_rate * curr_block_num / num_blocks
  399. block = AlignVisionBlock(
  400. config=config,
  401. in_dim=in_dim,
  402. out_dim=out_dim,
  403. stride=stride,
  404. kernel_size=kernel_size,
  405. expand_ratio=expand_ratio,
  406. drop_rate=drop_rate,
  407. id_skip=id_skip,
  408. adjust_padding=adjust_padding,
  409. )
  410. blocks.append(block)
  411. curr_block_num += 1
  412. self.blocks = nn.ModuleList(blocks)
  413. def forward(
  414. self,
  415. hidden_states: torch.FloatTensor,
  416. **kwargs: Unpack[TransformersKwargs],
  417. ) -> BaseModelOutputWithNoAttention:
  418. for block in self.blocks:
  419. hidden_states = block(hidden_states)
  420. return BaseModelOutputWithNoAttention(
  421. last_hidden_state=hidden_states,
  422. )
  423. class AlignTextEmbeddings(nn.Module):
  424. """Construct the embeddings from word, position and token_type embeddings."""
  425. def __init__(self, config):
  426. super().__init__()
  427. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  428. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  429. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  430. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  431. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  432. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  433. self.register_buffer(
  434. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  435. )
  436. self.register_buffer(
  437. "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
  438. )
  439. def forward(
  440. self,
  441. input_ids: torch.LongTensor | None = None,
  442. token_type_ids: torch.LongTensor | None = None,
  443. position_ids: torch.LongTensor | None = None,
  444. inputs_embeds: torch.FloatTensor | None = None,
  445. ) -> torch.Tensor:
  446. if input_ids is not None:
  447. input_shape = input_ids.size()
  448. else:
  449. input_shape = inputs_embeds.size()[:-1]
  450. seq_length = input_shape[1]
  451. if position_ids is None:
  452. position_ids = self.position_ids[:, :seq_length]
  453. # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
  454. # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
  455. # issue #5664
  456. if token_type_ids is None:
  457. if hasattr(self, "token_type_ids"):
  458. buffered_token_type_ids = self.token_type_ids[:, :seq_length]
  459. buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
  460. token_type_ids = buffered_token_type_ids_expanded
  461. else:
  462. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  463. if inputs_embeds is None:
  464. inputs_embeds = self.word_embeddings(input_ids)
  465. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  466. embeddings = inputs_embeds + token_type_embeddings
  467. position_embeddings = self.position_embeddings(position_ids)
  468. embeddings += position_embeddings
  469. embeddings = self.LayerNorm(embeddings)
  470. embeddings = self.dropout(embeddings)
  471. return embeddings
  472. def eager_attention_forward(
  473. module: nn.Module,
  474. query: torch.Tensor,
  475. key: torch.Tensor,
  476. value: torch.Tensor,
  477. attention_mask: torch.Tensor | None,
  478. scaling: float,
  479. dropout: float = 0.0,
  480. **kwargs,
  481. ):
  482. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  483. if attention_mask is not None:
  484. attn_weights = attn_weights + attention_mask
  485. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  486. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  487. attn_output = torch.matmul(attn_weights, value)
  488. attn_output = attn_output.transpose(1, 2).contiguous()
  489. return attn_output, attn_weights
  490. class AlignTextSelfAttention(nn.Module):
  491. def __init__(self, config):
  492. super().__init__()
  493. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  494. raise ValueError(
  495. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  496. f"heads ({config.num_attention_heads})"
  497. )
  498. self.config = config
  499. self.num_attention_heads = config.num_attention_heads
  500. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  501. self.all_head_size = self.num_attention_heads * self.attention_head_size
  502. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  503. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  504. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  505. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  506. self.attention_dropout = config.attention_probs_dropout_prob
  507. self.scaling = self.attention_head_size**-0.5
  508. def forward(
  509. self,
  510. hidden_states: torch.Tensor,
  511. attention_mask: torch.FloatTensor | None = None,
  512. **kwargs: Unpack[TransformersKwargs],
  513. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  514. input_shape = hidden_states.shape[:-1]
  515. hidden_shape = (*input_shape, -1, self.attention_head_size)
  516. query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  517. key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
  518. value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
  519. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  520. self.config._attn_implementation, eager_attention_forward
  521. )
  522. attn_output, attn_weights = attention_interface(
  523. self,
  524. query_states,
  525. key_states,
  526. value_states,
  527. attention_mask,
  528. dropout=0.0 if not self.training else self.attention_dropout,
  529. scaling=self.scaling,
  530. **kwargs,
  531. )
  532. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  533. return attn_output, attn_weights
  534. # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->AlignText
  535. class AlignTextSelfOutput(nn.Module):
  536. def __init__(self, config):
  537. super().__init__()
  538. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  539. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  540. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  541. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  542. hidden_states = self.dense(hidden_states)
  543. hidden_states = self.dropout(hidden_states)
  544. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  545. return hidden_states
  546. class AlignTextAttention(nn.Module):
  547. def __init__(self, config):
  548. super().__init__()
  549. self.self = AlignTextSelfAttention(config)
  550. self.output = AlignTextSelfOutput(config)
  551. def forward(
  552. self,
  553. hidden_states: torch.Tensor,
  554. attention_mask: torch.FloatTensor | None = None,
  555. **kwargs: Unpack[TransformersKwargs],
  556. ) -> torch.Tensor:
  557. residual = hidden_states
  558. hidden_states, _ = self.self(
  559. hidden_states,
  560. attention_mask=attention_mask,
  561. **kwargs,
  562. )
  563. hidden_states = self.output(hidden_states, residual)
  564. return hidden_states
  565. # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->AlignText
  566. class AlignTextIntermediate(nn.Module):
  567. def __init__(self, config):
  568. super().__init__()
  569. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  570. if isinstance(config.hidden_act, str):
  571. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  572. else:
  573. self.intermediate_act_fn = config.hidden_act
  574. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  575. hidden_states = self.dense(hidden_states)
  576. hidden_states = self.intermediate_act_fn(hidden_states)
  577. return hidden_states
  578. # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->AlignText
  579. class AlignTextOutput(nn.Module):
  580. def __init__(self, config):
  581. super().__init__()
  582. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  583. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  584. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  585. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  586. hidden_states = self.dense(hidden_states)
  587. hidden_states = self.dropout(hidden_states)
  588. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  589. return hidden_states
  590. class AlignTextLayer(GradientCheckpointingLayer):
  591. def __init__(self, config):
  592. super().__init__()
  593. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  594. self.seq_len_dim = 1
  595. self.attention = AlignTextAttention(config)
  596. self.intermediate = AlignTextIntermediate(config)
  597. self.output = AlignTextOutput(config)
  598. def forward(
  599. self,
  600. hidden_states: torch.Tensor,
  601. attention_mask: torch.FloatTensor | None = None,
  602. **kwargs: Unpack[TransformersKwargs],
  603. ) -> torch.Tensor:
  604. hidden_states = self.attention(
  605. hidden_states,
  606. attention_mask=attention_mask,
  607. **kwargs,
  608. )
  609. hidden_states = apply_chunking_to_forward(
  610. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, hidden_states
  611. )
  612. return hidden_states
  613. def feed_forward_chunk(self, attention_output):
  614. intermediate_output = self.intermediate(attention_output)
  615. layer_output = self.output(intermediate_output, attention_output)
  616. return layer_output
  617. class AlignTextEncoder(nn.Module):
  618. def __init__(self, config):
  619. super().__init__()
  620. self.config = config
  621. self.layer = nn.ModuleList([AlignTextLayer(config) for i in range(config.num_hidden_layers)])
  622. self.gradient_checkpointing = False
  623. def forward(
  624. self,
  625. hidden_states: torch.Tensor,
  626. attention_mask: torch.FloatTensor | None = None,
  627. **kwargs: Unpack[TransformersKwargs],
  628. ) -> BaseModelOutput:
  629. for layer_module in self.layer:
  630. hidden_states = layer_module(
  631. hidden_states,
  632. attention_mask,
  633. **kwargs,
  634. )
  635. return BaseModelOutput(
  636. last_hidden_state=hidden_states,
  637. )
  638. # Copied from transformers.models.bert.modeling_bert.BertPooler with Bert -> AlignText
  639. class AlignTextPooler(nn.Module):
  640. def __init__(self, config):
  641. super().__init__()
  642. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  643. self.activation = nn.Tanh()
  644. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  645. # We "pool" the model by simply taking the hidden state corresponding
  646. # to the first token.
  647. first_token_tensor = hidden_states[:, 0]
  648. pooled_output = self.dense(first_token_tensor)
  649. pooled_output = self.activation(pooled_output)
  650. return pooled_output
  651. @auto_docstring
  652. class AlignPreTrainedModel(PreTrainedModel):
  653. config: AlignConfig
  654. base_model_prefix = "align"
  655. input_modalities = ("image", "text")
  656. supports_gradient_checkpointing = True
  657. @torch.no_grad()
  658. def _init_weights(self, module: nn.Module):
  659. """Initialize the weights"""
  660. std = self.config.initializer_range
  661. if isinstance(module, (nn.Linear, nn.Conv2d)):
  662. init.normal_(module.weight, mean=0.0, std=std)
  663. if module.bias is not None:
  664. init.zeros_(module.bias)
  665. elif isinstance(module, AlignModel):
  666. init.xavier_uniform_(module.text_projection.weight)
  667. init.zeros_(module.text_projection.bias)
  668. init.constant_(module.temperature, self.config.temperature_init_value)
  669. elif isinstance(module, nn.Embedding):
  670. init.normal_(module.weight, mean=0.0, std=std)
  671. # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
  672. if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
  673. init.zeros_(module.weight[module.padding_idx])
  674. if isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
  675. init.zeros_(module.bias)
  676. init.ones_(module.weight)
  677. if getattr(module, "running_mean", None) is not None:
  678. init.zeros_(module.running_mean)
  679. init.ones_(module.running_var)
  680. init.zeros_(module.num_batches_tracked)
  681. elif isinstance(module, AlignTextEmbeddings):
  682. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  683. init.zeros_(module.token_type_ids)
  684. @auto_docstring(
  685. custom_intro="""
  686. The text model from ALIGN without any head or projection on top.
  687. """
  688. )
  689. class AlignTextModel(AlignPreTrainedModel):
  690. config: AlignTextConfig
  691. input_modalities = ("text",)
  692. _no_split_modules = ["AlignTextEmbeddings"]
  693. _can_record_outputs = {
  694. "hidden_states": AlignTextLayer,
  695. "attentions": AlignTextSelfAttention,
  696. }
  697. def __init__(self, config: AlignTextConfig, add_pooling_layer: bool = True):
  698. r"""
  699. add_pooling_layer (bool, *optional*, defaults to `True`):
  700. Whether to add a pooling layer
  701. """
  702. super().__init__(config)
  703. self.config = config
  704. self.embeddings = AlignTextEmbeddings(config)
  705. self.encoder = AlignTextEncoder(config)
  706. self.pooler = AlignTextPooler(config) if add_pooling_layer else None
  707. # Initialize weights and apply final processing
  708. self.post_init()
  709. def get_input_embeddings(self):
  710. return self.embeddings.word_embeddings
  711. def set_input_embeddings(self, value):
  712. self.embeddings.word_embeddings = value
  713. @merge_with_config_defaults
  714. @capture_outputs
  715. @auto_docstring
  716. def forward(
  717. self,
  718. input_ids: torch.Tensor | None = None,
  719. attention_mask: torch.Tensor | None = None,
  720. token_type_ids: torch.Tensor | None = None,
  721. position_ids: torch.Tensor | None = None,
  722. inputs_embeds: torch.Tensor | None = None,
  723. **kwargs: Unpack[TransformersKwargs],
  724. ) -> tuple | BaseModelOutputWithPooling:
  725. r"""
  726. Examples:
  727. ```python
  728. >>> from transformers import AutoTokenizer, AlignTextModel
  729. >>> model = AlignTextModel.from_pretrained("kakaobrain/align-base")
  730. >>> tokenizer = AutoTokenizer.from_pretrained("kakaobrain/align-base")
  731. >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
  732. >>> outputs = model(**inputs)
  733. >>> last_hidden_state = outputs.last_hidden_state
  734. >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
  735. ```"""
  736. if input_ids is not None and inputs_embeds is not None:
  737. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  738. elif input_ids is not None:
  739. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  740. input_shape = input_ids.size()
  741. elif inputs_embeds is not None:
  742. input_shape = inputs_embeds.size()[:-1]
  743. else:
  744. raise ValueError("You have to specify either input_ids or inputs_embeds")
  745. batch_size, seq_length = input_shape
  746. device = input_ids.device if input_ids is not None else inputs_embeds.device
  747. if attention_mask is None:
  748. attention_mask = torch.ones(((batch_size, seq_length)), device=device)
  749. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  750. # ourselves in which case we just need to make it broadcastable to all heads.
  751. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
  752. embedding_output = self.embeddings(
  753. input_ids=input_ids,
  754. position_ids=position_ids,
  755. token_type_ids=token_type_ids,
  756. inputs_embeds=inputs_embeds,
  757. )
  758. encoder_outputs = self.encoder(
  759. embedding_output,
  760. attention_mask=extended_attention_mask,
  761. **kwargs,
  762. )
  763. sequence_output = encoder_outputs[0]
  764. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  765. return BaseModelOutputWithPooling(
  766. last_hidden_state=sequence_output,
  767. pooler_output=pooled_output,
  768. )
  769. @auto_docstring(
  770. custom_intro="""
  771. The vision model from ALIGN without any head or projection on top.
  772. """
  773. )
  774. class AlignVisionModel(AlignPreTrainedModel):
  775. config: AlignVisionConfig
  776. main_input_name = "pixel_values"
  777. input_modalities = ("image",)
  778. supports_gradient_checkpointing = False
  779. _input_embed_layer = "convolution"
  780. _no_split_modules = ["AlignVisionBlock"]
  781. _can_record_outputs = {
  782. "hidden_states": AlignVisionBlock,
  783. }
  784. def __init__(self, config: AlignVisionConfig):
  785. super().__init__(config)
  786. self.config = config
  787. self.embeddings = AlignVisionEmbeddings(config)
  788. self.encoder = AlignVisionEncoder(config)
  789. # Final pooling layer
  790. if config.pooling_type == "mean":
  791. self.pooler = nn.AvgPool2d(config.hidden_dim, ceil_mode=True)
  792. elif config.pooling_type == "max":
  793. self.pooler = nn.MaxPool2d(config.hidden_dim, ceil_mode=True)
  794. else:
  795. raise ValueError(f"config.pooling must be one of ['mean', 'max'] got {config.pooling}")
  796. # Initialize weights and apply final processing
  797. self.post_init()
  798. @merge_with_config_defaults
  799. @capture_outputs
  800. @auto_docstring
  801. def forward(
  802. self,
  803. pixel_values: torch.FloatTensor | None = None,
  804. **kwargs: Unpack[TransformersKwargs],
  805. ) -> tuple | BaseModelOutputWithPoolingAndNoAttention:
  806. r"""
  807. Examples:
  808. ```python
  809. >>> from PIL import Image
  810. >>> import httpx
  811. >>> from io import BytesIO
  812. >>> from transformers import AutoProcessor, AlignVisionModel
  813. >>> model = AlignVisionModel.from_pretrained("kakaobrain/align-base")
  814. >>> processor = AutoProcessor.from_pretrained("kakaobrain/align-base")
  815. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  816. >>> with httpx.stream("GET", url) as response:
  817. ... image = Image.open(BytesIO(response.read()))
  818. >>> inputs = processor(images=image, return_tensors="pt")
  819. >>> outputs = model(**inputs)
  820. >>> last_hidden_state = outputs.last_hidden_state
  821. >>> pooled_output = outputs.pooler_output # pooled CLS states
  822. ```"""
  823. if pixel_values is None:
  824. raise ValueError("You have to specify pixel_values")
  825. embedding_output = self.embeddings(pixel_values)
  826. encoder_outputs = self.encoder(
  827. embedding_output,
  828. **kwargs,
  829. )
  830. last_hidden_state = encoder_outputs[0]
  831. pooled_output = self.pooler(last_hidden_state)
  832. pooled_output = pooled_output.reshape(pooled_output.shape[:2])
  833. return BaseModelOutputWithPoolingAndNoAttention(
  834. last_hidden_state=last_hidden_state,
  835. pooler_output=pooled_output,
  836. )
  837. @auto_docstring
  838. class AlignModel(AlignPreTrainedModel):
  839. config: AlignConfig
  840. def __init__(self, config: AlignConfig):
  841. super().__init__(config)
  842. if not isinstance(config.text_config, AlignTextConfig):
  843. raise TypeError(
  844. "config.text_config is expected to be of type AlignTextConfig but is of type"
  845. f" {type(config.text_config)}."
  846. )
  847. if not isinstance(config.vision_config, AlignVisionConfig):
  848. raise TypeError(
  849. "config.vision_config is expected to be of type AlignVisionConfig but is of type"
  850. f" {type(config.vision_config)}."
  851. )
  852. text_config = config.text_config
  853. vision_config = config.vision_config
  854. self.projection_dim = config.projection_dim
  855. self.text_embed_dim = text_config.hidden_size
  856. self.text_model = AlignTextModel(text_config)
  857. self.vision_model = AlignVisionModel(vision_config)
  858. self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim)
  859. self.temperature = nn.Parameter(torch.tensor(self.config.temperature_init_value))
  860. # Initialize weights and apply final processing
  861. self.post_init()
  862. @can_return_tuple
  863. @auto_docstring
  864. def get_text_features(
  865. self,
  866. input_ids: torch.Tensor | None = None,
  867. attention_mask: torch.Tensor | None = None,
  868. token_type_ids: torch.Tensor | None = None,
  869. position_ids: torch.Tensor | None = None,
  870. inputs_embeds: torch.Tensor | None = None,
  871. **kwargs: Unpack[TransformersKwargs],
  872. ) -> tuple | BaseModelOutputWithPooling:
  873. r"""
  874. Examples:
  875. ```python
  876. >>> import torch
  877. >>> from transformers import AutoTokenizer, AlignModel
  878. >>> model = AlignModel.from_pretrained("kakaobrain/align-base")
  879. >>> tokenizer = AutoTokenizer.from_pretrained("kakaobrain/align-base")
  880. >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
  881. >>> with torch.inference_mode():
  882. ... text_features = model.get_text_features(**inputs)
  883. ```"""
  884. text_outputs: BaseModelOutputWithPooling = self.text_model(
  885. input_ids=input_ids,
  886. attention_mask=attention_mask,
  887. token_type_ids=token_type_ids,
  888. position_ids=position_ids,
  889. inputs_embeds=inputs_embeds,
  890. **kwargs,
  891. )
  892. last_hidden_state = text_outputs[0][:, 0, :]
  893. text_outputs.pooler_output = self.text_projection(last_hidden_state)
  894. return text_outputs
  895. @can_return_tuple
  896. @auto_docstring
  897. def get_image_features(
  898. self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
  899. ) -> tuple | BaseModelOutputWithPooling:
  900. r"""
  901. Examples:
  902. ```python
  903. >>> import torch
  904. >>> from transformers import AutoProcessor, AlignModel
  905. >>> from transformers.image_utils import load_image
  906. >>> model = AlignModel.from_pretrained("kakaobrain/align-base")
  907. >>> processor = AutoProcessor.from_pretrained("kakaobrain/align-base")
  908. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  909. >>> image = load_image(url)
  910. >>> inputs = processor(images=image, return_tensors="pt")
  911. >>> with torch.inference_mode():
  912. ... image_features = model.get_image_features(**inputs)
  913. ```"""
  914. return self.vision_model(pixel_values=pixel_values, **kwargs)
  915. @can_return_tuple
  916. @auto_docstring
  917. def forward(
  918. self,
  919. input_ids: torch.LongTensor | None = None,
  920. pixel_values: torch.FloatTensor | None = None,
  921. attention_mask: torch.Tensor | None = None,
  922. token_type_ids: torch.Tensor | None = None,
  923. position_ids: torch.Tensor | None = None,
  924. inputs_embeds: torch.Tensor | None = None,
  925. return_loss: bool | None = None,
  926. **kwargs: Unpack[TransformersKwargs],
  927. ) -> tuple | AlignOutput:
  928. r"""
  929. return_loss (`bool`, *optional*):
  930. Whether or not to return the contrastive loss.
  931. Examples:
  932. ```python
  933. >>> import torch
  934. >>> from transformers import AutoProcessor, AlignModel
  935. >>> from transformers.image_utils import load_image
  936. >>> model = AlignModel.from_pretrained("kakaobrain/align-base")
  937. >>> processor = AutoProcessor.from_pretrained("kakaobrain/align-base")
  938. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  939. >>> image = load_image(url)
  940. >>> inputs = processor(
  941. ... images=image, text=["a photo of a cat", "a photo of a dog"], return_tensors="pt", padding=True
  942. ... )
  943. >>> with torch.inference_mode():
  944. ... outputs = model(**inputs)
  945. >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
  946. >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
  947. ```"""
  948. vision_outputs = self.vision_model(
  949. pixel_values=pixel_values,
  950. **kwargs,
  951. )
  952. text_outputs = self.text_model(
  953. input_ids=input_ids,
  954. attention_mask=attention_mask,
  955. token_type_ids=token_type_ids,
  956. position_ids=position_ids,
  957. inputs_embeds=inputs_embeds,
  958. **kwargs,
  959. )
  960. image_embeds = vision_outputs[1]
  961. text_embeds = text_outputs[0][:, 0, :]
  962. text_embeds = self.text_projection(text_embeds)
  963. # normalized features
  964. image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
  965. text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
  966. # cosine similarity as logits
  967. logits_per_text = torch.matmul(text_embeds, image_embeds.t()) / self.temperature
  968. logits_per_image = logits_per_text.t()
  969. loss = None
  970. if return_loss:
  971. loss = align_loss(logits_per_text)
  972. return AlignOutput(
  973. loss=loss,
  974. logits_per_image=logits_per_image,
  975. logits_per_text=logits_per_text,
  976. text_embeds=text_embeds,
  977. image_embeds=image_embeds,
  978. text_model_output=text_outputs,
  979. vision_model_output=vision_outputs,
  980. )
# Public names of this module: these are what `from <module> import *` exports.
__all__ = ["AlignPreTrainedModel", "AlignTextModel", "AlignVisionModel", "AlignModel"]