modeling_bridgetower.py 74 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761
  1. # Copyright 2023 The Intel Labs Team Authors, The Microsoft Research Team Authors and HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch BridgeTower Model"""
  15. from collections import OrderedDict
  16. from collections.abc import Callable
  17. from dataclasses import dataclass
  18. import torch
  19. from torch import nn
  20. from torch.nn import CrossEntropyLoss
  21. from ... import initialization as init
  22. from ...activations import ACT2FN, QuickGELUActivation
  23. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  24. from ...masking_utils import create_bidirectional_mask, create_causal_mask
  25. from ...modeling_layers import GradientCheckpointingLayer
  26. from ...modeling_outputs import (
  27. BaseModelOutputWithPastAndCrossAttentions,
  28. BaseModelOutputWithPoolingAndCrossAttentions,
  29. MaskedLMOutput,
  30. ModelOutput,
  31. SequenceClassifierOutput,
  32. )
  33. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  34. from ...processing_utils import Unpack
  35. from ...pytorch_utils import apply_chunking_to_forward
  36. from ...utils import TransformersKwargs, auto_docstring, logging, torch_int
  37. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  38. from ...utils.output_capturing import capture_outputs
  39. from .configuration_bridgetower import BridgeTowerConfig, BridgeTowerTextConfig, BridgeTowerVisionConfig
# Module-level logger obtained through the library's `logging` helper.
logger = logging.get_logger(__name__)
# Tokenizer class name string. NOTE(review): not referenced anywhere in the visible part of this file —
# presumably consumed by documentation tooling or by code outside this chunk; confirm before removing.
_TOKENIZER_FOR_DOC = "RobertaTokenizer"
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`BridgeTowerModel`].
    """
)
class BridgeTowerModelOutput(ModelOutput):
    r"""
    text_features (`torch.FloatTensor` of shape `(batch_size, text_sequence_length, hidden_size)`):
        Sequence of hidden-states at the text output of the last layer of the model.
    image_features (`torch.FloatTensor` of shape `(batch_size, image_sequence_length, hidden_size)`):
        Sequence of hidden-states at the image output of the last layer of the model.
    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size x 2)`):
        Concatenation of last layer hidden-state of the first token of the text and image sequence (classification
        token), respectively, after further processing through layers used for auxiliary pretraining tasks.
    """

    # Field declaration order defines the dataclass field order; for a `ModelOutput` this is presumably also
    # the positional order when the output is used as a tuple — do not reorder.
    text_features: torch.FloatTensor | None = None
    image_features: torch.FloatTensor | None = None
    pooler_output: torch.FloatTensor | None = None
    # Not described in the docstring above. NOTE(review): their exact per-layer / per-modality layout is set by
    # the model that builds this output, which is outside this chunk.
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
  63. @dataclass
  64. @auto_docstring(
  65. custom_intro="""
  66. Output type of ['BridgeTowerForContrastiveLearning']
  67. """
  68. )
  69. class BridgeTowerContrastiveOutput(ModelOutput):
  70. r"""
  71. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
  72. Image-text contrastive loss.
  73. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  74. Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  75. text_embeds (`torch.FloatTensor)`, *optional*, returned when model is initialized with `with_projection=True`):
  76. The text embeddings obtained by applying the projection layer to the pooler_output.
  77. image_embeds (`torch.FloatTensor)`, *optional*, returned when model is initialized with `with_projection=True`):
  78. The image embeddings obtained by applying the projection layer to the pooler_output.
  79. cross_embeds (`torch.FloatTensor)`, *optional*, returned when model is initialized with `with_projection=True`):
  80. The text-image cross-modal embeddings obtained by applying the projection layer to the pooler_output.
  81. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  82. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  83. sequence_length)`.
  84. """
  85. loss: torch.FloatTensor | None = None
  86. logits: torch.FloatTensor | None = None
  87. text_embeds: tuple[torch.FloatTensor] | None = None
  88. image_embeds: tuple[torch.FloatTensor] | None = None
  89. cross_embeds: tuple[torch.FloatTensor] | None = None
  90. hidden_states: tuple[torch.FloatTensor] | None = None
  91. attentions: tuple[torch.FloatTensor] | None = None
  92. class BridgeTowerResidualAttention(nn.Module):
  93. def __init__(self, config):
  94. super().__init__()
  95. self.attn = nn.MultiheadAttention(config.hidden_size, config.hidden_size // 64)
  96. self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  97. self.mlp = nn.ModuleDict(
  98. OrderedDict(
  99. [
  100. ("c_fc", nn.Linear(config.hidden_size, config.hidden_size * 4)),
  101. ("gelu", QuickGELUActivation()),
  102. ("c_proj", nn.Linear(config.hidden_size * 4, config.hidden_size)),
  103. ]
  104. )
  105. )
  106. self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  107. self.attn_mask = None
  108. def attention(self, hidden_state: torch.Tensor, attention_mask: torch.Tensor):
  109. if attention_mask is not None:
  110. attention_mask = attention_mask.to(dtype=torch.bool, device=hidden_state.device)
  111. self.attn_mask = (
  112. self.attn_mask.to(dtype=hidden_state.dtype, device=hidden_state.device)
  113. if self.attn_mask is not None
  114. else None
  115. )
  116. return self.attn(
  117. hidden_state,
  118. hidden_state,
  119. hidden_state,
  120. need_weights=False,
  121. attn_mask=self.attn_mask,
  122. key_padding_mask=attention_mask,
  123. )[0]
  124. def forward(self, hidden_state: torch.Tensor, attention_mask: torch.Tensor | None = None):
  125. residual_state = hidden_state + self.attention(self.ln_1(hidden_state), attention_mask)
  126. hidden_state = self.ln_2(residual_state)
  127. for layer in self.mlp.values():
  128. hidden_state = layer(hidden_state)
  129. hidden_state = residual_state + hidden_state
  130. return hidden_state
  131. class BridgeTowerTransformer(nn.Module):
  132. def __init__(self, config):
  133. super().__init__()
  134. self.hidden_size = config.hidden_size
  135. self.num_hidden_layers = config.num_hidden_layers
  136. if config.remove_last_layer:
  137. self.resblocks = nn.ModuleList(
  138. [BridgeTowerResidualAttention(config) for _ in range(self.num_hidden_layers - 1)]
  139. )
  140. else:
  141. self.resblocks = nn.ModuleList(
  142. [BridgeTowerResidualAttention(config) for _ in range(self.num_hidden_layers)]
  143. )
  144. self.stop_gradient = config.stop_gradient
  145. def forward(self, hidden_state: torch.Tensor, attention_mask: torch.Tensor | None = None):
  146. hidden_states = []
  147. for block in self.resblocks:
  148. hidden_state = block(hidden_state, attention_mask)
  149. if self.stop_gradient:
  150. hidden_states.append(hidden_state.detach())
  151. else:
  152. hidden_states.append(hidden_state)
  153. return hidden_states
# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->BridgeTower
class BridgeTowerVisionEmbeddings(nn.Module):
    """Turn pixel values into a token sequence: a class token followed by patch embeddings, plus positions.

    A strided `Conv2d` (kernel size == stride == `patch_size`) embeds non-overlapping patches, a learnable class
    embedding is prepended, and a learned position embedding (one per patch plus one for the class token) is
    added. For inputs whose size differs from `config.image_size`, the position embeddings can be resized.
    """

    def __init__(self, config: BridgeTowerVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Learnable vector prepended to every patch sequence in `forward`.
        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        # kernel_size == stride == patch_size: each output position covers exactly one non-overlapping patch.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        # +1 for the class-token position.
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        # Non-persistent buffer: rebuilt at construction time and not saved in the state dict.
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Returns position embeddings of shape `(1, 1 + (height // patch_size) * (width // patch_size), dim)`, with the
        class-token position first. The pretrained patch grid is assumed square (it is reshaped to
        `sqrt(N) x sqrt(N)`) and is resized with bicubic interpolation.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """

        # `embeddings` includes the class token at index 0, hence the -1.
        num_patches = embeddings.shape[1] - 1
        position_embedding = self.position_embedding.weight.unsqueeze(0)
        num_positions = position_embedding.shape[1] - 1

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embedding(self.position_ids)

        class_pos_embed = position_embedding[:, :1]
        patch_pos_embed = position_embedding[:, 1:]

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        # [1, N, dim] -> [1, sqrt(N), sqrt(N), dim] -> [1, dim, sqrt(N), sqrt(N)] so `interpolate` sees a 2-D grid.
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        # Back to [1, new_height * new_width, dim].
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
        """Embed `pixel_values` (unpacked below as `(batch, channels, height, width)`).

        Returns:
            Tensor of shape `(batch, 1 + num_patches, embed_dim)`, class token first.

        Raises:
            ValueError: if the image size differs from `config.image_size` and `interpolate_pos_encoding` is False.
        """
        batch_size, _, height, width = pixel_values.shape
        if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})."
            )
        # Match the conv weight dtype (e.g. half-precision weights with float pixel input).
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [batch, embed_dim, grid, grid]
        # [batch, embed_dim, grid, grid] -> [batch, grid * grid, embed_dim]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embedding(self.position_ids)
        return embeddings
  220. class BridgeTowerVisionTransformer(nn.Module):
  221. def __init__(self, config):
  222. super().__init__()
  223. self.embeddings = BridgeTowerVisionEmbeddings(config)
  224. self.ln_pre = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  225. self.transformer = BridgeTowerTransformer(config)
  226. self.ln_post = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  227. self.share_layernorm = config.share_layernorm
  228. if not config.share_layernorm:
  229. self.ln_separate = nn.ModuleList(
  230. [nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) for _ in range(config.num_hidden_layers)]
  231. )
  232. def forward(
  233. self,
  234. pixel_values: torch.Tensor,
  235. attention_mask,
  236. interpolate_pos_encoding: bool = False,
  237. ):
  238. hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding)
  239. hidden_states = self.ln_pre(hidden_states)
  240. # NLD -> LND
  241. hidden_states = hidden_states.permute(1, 0, 2)
  242. hidden_states = self.transformer(hidden_states, attention_mask)
  243. # shape = [num_hidden_layers, hidden_size, *, grid ** 2]
  244. hidden_states = torch.stack(hidden_states, dim=0)
  245. # shape = [num_hidden_layers, *, hidden_size, grid ** 2]
  246. hidden_states = hidden_states.permute(0, 2, 1, 3)
  247. if self.share_layernorm:
  248. hidden_states = self.ln_post(hidden_states)
  249. else:
  250. hidden_states_stack = []
  251. for hidden_states, ln in zip(hidden_states, self.ln_separate):
  252. hidden_states = ln(hidden_states)
  253. hidden_states_stack.append(hidden_states)
  254. # shape = [num_hidden_layers, *, hidden_size, grid ** 2]
  255. hidden_states = torch.stack(hidden_states_stack, dim=0)
  256. return hidden_states
  257. def forward_pre(
  258. self,
  259. pixel_values: torch.Tensor,
  260. interpolate_pos_encoding: bool = False,
  261. ):
  262. hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  263. hidden_states = self.ln_pre(hidden_states)
  264. # NLD -> LND
  265. hidden_states = hidden_states.permute(1, 0, 2)
  266. return hidden_states
  267. def forward_post(self, hidden_state: torch.Tensor):
  268. visual_output_post = hidden_state.permute(1, 0, 2)
  269. visual_output_post = self.ln_post(visual_output_post)
  270. return visual_output_post
  271. class BridgeTowerLinkTower(nn.Module):
  272. def __init__(self, config):
  273. super().__init__()
  274. self.link_tower_type = config.link_tower_type
  275. self.hidden_size = config.hidden_size
  276. if config.link_tower_type in ["add", "scaled_add", "interpolate"]:
  277. if config.link_tower_type == "scaled_add":
  278. self.scaled_factor = nn.Parameter(torch.tensor(1.0))
  279. elif config.link_tower_type == "interpolate":
  280. self.beta = nn.Parameter(torch.tensor(0.5))
  281. self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
  282. else:
  283. raise NotImplementedError(f"link_tower_type {config.link_tower_type} is not implemented")
  284. def forward(self, hidden_states, cross_modal_hidden_states, attention_mask):
  285. if self.link_tower_type == "add":
  286. return self.LayerNorm(hidden_states + cross_modal_hidden_states)
  287. elif self.link_tower_type == "scaled_add":
  288. return self.LayerNorm(hidden_states * self.scaled_factor + cross_modal_hidden_states)
  289. elif self.link_tower_type == "interpolate":
  290. return self.LayerNorm(hidden_states * (1 - self.beta) + cross_modal_hidden_states * self.beta)
  291. else:
  292. raise NotImplementedError(f"link_tower_type {self.link_tower_type} is not implemented")
  293. # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->BridgeTower
  294. class BridgeTowerSelfOutput(nn.Module):
  295. def __init__(self, config):
  296. super().__init__()
  297. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  298. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  299. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  300. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  301. hidden_states = self.dense(hidden_states)
  302. hidden_states = self.dropout(hidden_states)
  303. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  304. return hidden_states
  305. # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->BridgeTower
  306. class BridgeTowerIntermediate(nn.Module):
  307. def __init__(self, config):
  308. super().__init__()
  309. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  310. if isinstance(config.hidden_act, str):
  311. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  312. else:
  313. self.intermediate_act_fn = config.hidden_act
  314. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  315. hidden_states = self.dense(hidden_states)
  316. hidden_states = self.intermediate_act_fn(hidden_states)
  317. return hidden_states
  318. # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->BridgeTower
  319. class BridgeTowerOutput(nn.Module):
  320. def __init__(self, config):
  321. super().__init__()
  322. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  323. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  324. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  325. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  326. hidden_states = self.dense(hidden_states)
  327. hidden_states = self.dropout(hidden_states)
  328. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  329. return hidden_states
  330. # Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->BridgeTower
  331. class BridgeTowerPooler(nn.Module):
  332. def __init__(self, config):
  333. super().__init__()
  334. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  335. self.activation = nn.Tanh()
  336. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  337. # We "pool" the model by simply taking the hidden state corresponding
  338. # to the first token.
  339. first_token_tensor = hidden_states[:, 0]
  340. pooled_output = self.dense(first_token_tensor)
  341. pooled_output = self.activation(pooled_output)
  342. return pooled_output
  343. # Copied from transformers.models.bert.modeling_bert.eager_attention_forward
  344. def eager_attention_forward(
  345. module: nn.Module,
  346. query: torch.Tensor,
  347. key: torch.Tensor,
  348. value: torch.Tensor,
  349. attention_mask: torch.Tensor | None,
  350. scaling: float | None = None,
  351. dropout: float = 0.0,
  352. **kwargs: Unpack[TransformersKwargs],
  353. ):
  354. if scaling is None:
  355. scaling = query.size(-1) ** -0.5
  356. # Take the dot product between "query" and "key" to get the raw attention scores.
  357. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  358. if attention_mask is not None:
  359. attn_weights = attn_weights + attention_mask
  360. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  361. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  362. attn_output = torch.matmul(attn_weights, value)
  363. attn_output = attn_output.transpose(1, 2).contiguous()
  364. return attn_output, attn_weights
  365. # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->BridgeTower
  366. class BridgeTowerSelfAttention(nn.Module):
  367. def __init__(self, config, is_causal=False, layer_idx=None):
  368. super().__init__()
  369. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  370. raise ValueError(
  371. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  372. f"heads ({config.num_attention_heads})"
  373. )
  374. self.config = config
  375. self.num_attention_heads = config.num_attention_heads
  376. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  377. self.all_head_size = self.num_attention_heads * self.attention_head_size
  378. self.scaling = self.attention_head_size**-0.5
  379. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  380. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  381. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  382. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  383. self.is_decoder = config.is_decoder
  384. self.is_causal = is_causal
  385. self.layer_idx = layer_idx
  386. def forward(
  387. self,
  388. hidden_states: torch.Tensor,
  389. attention_mask: torch.FloatTensor | None = None,
  390. past_key_values: Cache | None = None,
  391. **kwargs: Unpack[TransformersKwargs],
  392. ) -> tuple[torch.Tensor]:
  393. input_shape = hidden_states.shape[:-1]
  394. hidden_shape = (*input_shape, -1, self.attention_head_size)
  395. # get all proj
  396. query_layer = self.query(hidden_states).view(*hidden_shape).transpose(1, 2)
  397. key_layer = self.key(hidden_states).view(*hidden_shape).transpose(1, 2)
  398. value_layer = self.value(hidden_states).view(*hidden_shape).transpose(1, 2)
  399. if past_key_values is not None:
  400. # decoder-only roberta can have a simple dynamic cache for example
  401. current_past_key_values = past_key_values
  402. if isinstance(past_key_values, EncoderDecoderCache):
  403. current_past_key_values = past_key_values.self_attention_cache
  404. # save all key/value_layer to cache to be re-used for fast auto-regressive generation
  405. key_layer, value_layer = current_past_key_values.update(key_layer, value_layer, self.layer_idx)
  406. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  407. self.config._attn_implementation, eager_attention_forward
  408. )
  409. attn_output, attn_weights = attention_interface(
  410. self,
  411. query_layer,
  412. key_layer,
  413. value_layer,
  414. attention_mask,
  415. dropout=0.0 if not self.training else self.dropout.p,
  416. scaling=self.scaling,
  417. **kwargs,
  418. )
  419. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  420. return attn_output, attn_weights
# Copied from transformers.models.roberta.modeling_roberta.RobertaCrossAttention with Roberta->BridgeTower
class BridgeTowerCrossAttention(nn.Module):
    """Multi-head cross-attention: queries from `hidden_states`, keys/values from `encoder_hidden_states`.

    Keys/values can be stored in the `cross_attention_cache` of an `EncoderDecoderCache`. Once a layer's entry
    has been written (tracked via `past_key_values.is_updated[layer_idx]`), later calls reuse the cached tensors
    instead of re-projecting `encoder_hidden_states`.
    """

    def __init__(self, config, is_causal=False, layer_idx=None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
        self.config = config

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        # 1/sqrt(head_dim) score scaling, passed explicitly to the attention backend.
        self.scaling = self.attention_head_size**-0.5

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        # Only `.p` is read (as the backend dropout rate); the module itself is not called in this class.
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        # Not read in this class. NOTE(review): presumably consulted by attention backends via `module`.
        self.is_causal = is_causal
        # Key for this layer's entries in the cache (`is_updated`, `cross_attention_cache.layers`).
        self.layer_idx = layer_idx

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.FloatTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        past_key_values: EncoderDecoderCache | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor]:
        """
        Args:
            hidden_states: query-side input; all dims but the last are kept as the output's leading dims.
            encoder_hidden_states: key/value-side input. Must be provided unless this layer's cross-attention cache
                entry is already populated, because its `.shape` is read otherwise.
            attention_mask: forwarded unchanged to the attention backend.
            past_key_values: optional cache whose `cross_attention_cache` is read or filled for `self.layer_idx`.

        Returns:
            `(attn_output, attn_weights)` from the attention backend, with `attn_output` reshaped back to the leading
            dims of `hidden_states` and the heads merged into the last dim.
        """
        # determine input shapes
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.attention_head_size)

        # get query proj
        # Split the last dim into heads, then swap dims 1 and 2 (presumably -> (batch, heads, seq, head_dim)).
        query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)

        is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
        if past_key_values is not None and is_updated:
            # reuse k,v, cross_attentions
            key_layer = past_key_values.cross_attention_cache.layers[self.layer_idx].keys
            value_layer = past_key_values.cross_attention_cache.layers[self.layer_idx].values
        else:
            kv_shape = (*encoder_hidden_states.shape[:-1], -1, self.attention_head_size)
            key_layer = self.key(encoder_hidden_states).view(kv_shape).transpose(1, 2)
            value_layer = self.value(encoder_hidden_states).view(kv_shape).transpose(1, 2)

            if past_key_values is not None:
                # save all states to the cache
                key_layer, value_layer = past_key_values.cross_attention_cache.update(
                    key_layer, value_layer, self.layer_idx
                )
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                past_key_values.is_updated[self.layer_idx] = True

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            attention_mask,
            dropout=0.0 if not self.training else self.dropout.p,
            scaling=self.scaling,
            **kwargs,
        )
        # Merge heads back into the last dim, restoring the query-side leading dims.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        return attn_output, attn_weights
  # Adapted from transformers.models.bert.modeling_bert.BertAttention with Bert->BridgeTower,BERT->BRIDGE_TOWER
  class BridgeTowerAttention(nn.Module):
      """Attention sub-block: an attention core (`self.self`) followed by `self.output`.

      Args:
          config: model configuration, forwarded to both sub-modules.
          is_causal: forwarded to the attention core.
          layer_idx: index of the enclosing layer, forwarded to the attention core.
          is_cross_attention: selects `BridgeTowerCrossAttention` (True) or `BridgeTowerSelfAttention` (False)
              as the core, and makes `forward` mask with `encoder_attention_mask` instead of `attention_mask`.
      """

      def __init__(self, config, is_causal=False, layer_idx=None, is_cross_attention=False):
          super().__init__()
          self.is_cross_attention = is_cross_attention
          if is_cross_attention:
              core_cls = BridgeTowerCrossAttention
          else:
              core_cls = BridgeTowerSelfAttention
          self.self = core_cls(config, is_causal=is_causal, layer_idx=layer_idx)
          self.output = BridgeTowerSelfOutput(config)

      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: torch.FloatTensor | None = None,
          encoder_hidden_states: torch.FloatTensor | None = None,
          encoder_attention_mask: torch.FloatTensor | None = None,
          past_key_values: Cache | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple[torch.Tensor]:
          """Run the attention core, then `self.output` on its result and the block input.

          Returns:
              `(attention_output, attn_weights)`, where `attn_weights` is whatever the core returned.
          """
          # A cross-attention core attends over the encoder states, so it is masked with the encoder mask.
          effective_mask = encoder_attention_mask if self.is_cross_attention else attention_mask
          core_output, attn_weights = self.self(
              hidden_states,
              encoder_hidden_states=encoder_hidden_states,
              attention_mask=effective_mask,
              past_key_values=past_key_values,
              **kwargs,
          )
          # `self.output` receives both the core output and the original input of the block.
          return self.output(core_output, hidden_states), attn_weights
  class BridgeTowerBertCrossLayer(nn.Module):
      """Cross-modal layer: self-attention, cross-attention over another stream, then a chunked feed-forward.

      `BridgeTowerModel` stacks these for both the text and the image streams of its cross-modal encoder.
      """

      def __init__(self, config, layer_idx=None):
          super().__init__()
          self.chunk_size_feed_forward = config.chunk_size_feed_forward
          self.seq_len_dim = 1
          self.attention = BridgeTowerAttention(config, is_causal=True, layer_idx=layer_idx)
          self.is_decoder = config.is_decoder
          self.add_cross_attention = config.add_cross_attention
          self.crossattention = BridgeTowerAttention(
              config, is_causal=False, layer_idx=layer_idx, is_cross_attention=True
          )
          self.intermediate = BridgeTowerIntermediate(config)
          self.output = BridgeTowerOutput(config)

      def forward(
          self,
          hidden_states,
          encoder_hidden_states,
          attention_mask=None,
          encoder_attention_mask=None,
          past_key_values=None,
          **kwargs: Unpack[TransformersKwargs],
      ):
          """Apply self-attention, cross-attention and the feed-forward block.

          Args:
              hidden_states: the stream being updated.
              encoder_hidden_states: the other stream, attended to by the cross-attention.
              attention_mask: mask for the self-attention over `hidden_states`.
              encoder_attention_mask: mask for the cross-attention over `encoder_hidden_states`.
              past_key_values: forwarded to the cross-attention only.

          Returns:
              `(layer_output, self_attn_weights, cross_attn_weights)`.
          """
          # The self-attention sub-layer never uses a cache here.
          self_output, self_attn_weights = self.attention(
              hidden_states,
              attention_mask=attention_mask,
              past_key_values=None,
              **kwargs,
          )
          # `attention_mask` is passed along but a cross-attention `BridgeTowerAttention` masks with
          # `encoder_attention_mask` instead.
          cross_output, cross_attn_weights = self.crossattention(
              self_output,
              attention_mask=attention_mask,
              encoder_hidden_states=encoder_hidden_states,
              encoder_attention_mask=encoder_attention_mask,
              past_key_values=past_key_values,
              **kwargs,
          )
          layer_output = apply_chunking_to_forward(
              self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, cross_output
          )
          return layer_output, self_attn_weights, cross_attn_weights

      def feed_forward_chunk(self, attention_output):
          """Feed-forward on one chunk: `self.intermediate`, then `self.output` with the chunk as second input."""
          return self.output(self.intermediate(attention_output), attention_output)
  class BridgeTowerTextLayer(GradientCheckpointingLayer):
      """Text encoder layer: self-attention, optional cross-attention (decoders only), then a chunked feed-forward.

      Raises:
          ValueError: if `config.add_cross_attention` is set while `config.is_decoder` is not.
      """

      def __init__(self, config, layer_idx=None):
          super().__init__()
          self.chunk_size_feed_forward = config.chunk_size_feed_forward
          self.seq_len_dim = 1
          # The self-attention is flagged causal exactly when the model is configured as a decoder.
          self.attention = BridgeTowerAttention(config, is_causal=config.is_decoder, layer_idx=layer_idx)
          self.is_decoder = config.is_decoder
          self.add_cross_attention = config.add_cross_attention
          if self.add_cross_attention and not self.is_decoder:
              raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
          if self.add_cross_attention:
              self.crossattention = BridgeTowerAttention(
                  config,
                  is_causal=False,
                  layer_idx=layer_idx,
                  is_cross_attention=True,
              )
          self.intermediate = BridgeTowerIntermediate(config)
          self.output = BridgeTowerOutput(config)

      # adapted from transformers.models.bert.modeling_bert.BertLayer.forward
      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: torch.FloatTensor | None = None,
          encoder_hidden_states: torch.FloatTensor | None = None,
          encoder_attention_mask: torch.FloatTensor | None = None,
          past_key_values: Cache | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> torch.Tensor:
          """Return the layer output tensor (attention weights are discarded).

          Cross-attention runs only when the layer is a decoder and `encoder_hidden_states` is given.

          Raises:
              ValueError: if cross-attention is requested but the layer was built without it.
          """
          attention_output, _ = self.attention(
              hidden_states,
              attention_mask,
              past_key_values=past_key_values,
              **kwargs,
          )

          wants_cross_attention = self.is_decoder and encoder_hidden_states is not None
          if wants_cross_attention:
              if not hasattr(self, "crossattention"):
                  raise ValueError(
                      f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                      " by setting `config.add_cross_attention=True`"
                  )
              # Positional order: hidden_states, attention_mask (None here), encoder states, encoder mask.
              attention_output, _ = self.crossattention(
                  attention_output,
                  None,
                  encoder_hidden_states,
                  encoder_attention_mask,
                  past_key_values=past_key_values,
                  **kwargs,
              )

          return apply_chunking_to_forward(
              self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
          )

      def feed_forward_chunk(self, attention_output):
          """Feed-forward on one chunk: `self.intermediate`, then `self.output` with the chunk as second input."""
          return self.output(self.intermediate(attention_output), attention_output)
  # adapted from transformers.models.roberta.modeling_roberta.RobertaEncoder with Roberta->BridgeTowerText
  class BridgeTowerTextEncoder(nn.Module):
      """Stack of `config.num_hidden_layers` `BridgeTowerTextLayer`s applied one after another."""

      def __init__(self, config):
          super().__init__()
          self.config = config
          text_layers = [BridgeTowerTextLayer(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]
          self.layer = nn.ModuleList(text_layers)

      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: torch.FloatTensor | None = None,
          encoder_hidden_states: torch.FloatTensor | None = None,
          encoder_attention_mask: torch.FloatTensor | None = None,
          past_key_values: Cache | None = None,
          use_cache: bool | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> BaseModelOutputWithPastAndCrossAttentions:
          """Feed `hidden_states` through every layer.

          Returns:
              `BaseModelOutputWithPastAndCrossAttentions` holding the final `last_hidden_state`; its
              `past_key_values` is the given cache when `use_cache` is truthy and `None` otherwise.
          """
          for text_layer in self.layer:
              hidden_states = text_layer(
                  hidden_states,
                  attention_mask,
                  encoder_hidden_states,  # as a positional argument for gradient checkpointing
                  encoder_attention_mask=encoder_attention_mask,
                  past_key_values=past_key_values,
                  **kwargs,
              )

          returned_cache = past_key_values if use_cache else None
          return BaseModelOutputWithPastAndCrossAttentions(
              last_hidden_state=hidden_states,
              past_key_values=returned_cache,
          )
  655. # Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings with Roberta->BridgeTowerText
  656. class BridgeTowerTextEmbeddings(nn.Module):
  657. """Construct the embeddings from word, position and token_type embeddings."""
  658. def __init__(self, config):
  659. super().__init__()
  660. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  661. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  662. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  663. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  664. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  665. self.register_buffer(
  666. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  667. )
  668. self.register_buffer(
  669. "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
  670. )
  671. self.padding_idx = config.pad_token_id
  672. self.position_embeddings = nn.Embedding(
  673. config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
  674. )
  675. def forward(
  676. self,
  677. input_ids: torch.LongTensor | None = None,
  678. token_type_ids: torch.LongTensor | None = None,
  679. position_ids: torch.LongTensor | None = None,
  680. inputs_embeds: torch.FloatTensor | None = None,
  681. past_key_values_length: int = 0,
  682. ) -> torch.Tensor:
  683. if position_ids is None:
  684. if input_ids is not None:
  685. # Create the position ids from the input token ids. Any padded tokens remain padded.
  686. position_ids = self.create_position_ids_from_input_ids(
  687. input_ids, self.padding_idx, past_key_values_length
  688. )
  689. else:
  690. position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, self.padding_idx)
  691. if input_ids is not None:
  692. input_shape = input_ids.size()
  693. else:
  694. input_shape = inputs_embeds.size()[:-1]
  695. batch_size, seq_length = input_shape
  696. # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
  697. # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
  698. # issue #5664
  699. if token_type_ids is None:
  700. if hasattr(self, "token_type_ids"):
  701. # NOTE: We assume either pos ids to have bsz == 1 (broadcastable) or bsz == effective bsz (input_shape[0])
  702. buffered_token_type_ids = self.token_type_ids.expand(position_ids.shape[0], -1)
  703. buffered_token_type_ids = torch.gather(buffered_token_type_ids, dim=1, index=position_ids)
  704. token_type_ids = buffered_token_type_ids.expand(batch_size, seq_length)
  705. else:
  706. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  707. if inputs_embeds is None:
  708. inputs_embeds = self.word_embeddings(input_ids)
  709. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  710. embeddings = inputs_embeds + token_type_embeddings
  711. position_embeddings = self.position_embeddings(position_ids)
  712. embeddings = embeddings + position_embeddings
  713. embeddings = self.LayerNorm(embeddings)
  714. embeddings = self.dropout(embeddings)
  715. return embeddings
  716. @staticmethod
  717. def create_position_ids_from_inputs_embeds(inputs_embeds, padding_idx):
  718. """
  719. We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
  720. Args:
  721. inputs_embeds: torch.Tensor
  722. Returns: torch.Tensor
  723. """
  724. input_shape = inputs_embeds.size()[:-1]
  725. sequence_length = input_shape[1]
  726. position_ids = torch.arange(
  727. padding_idx + 1, sequence_length + padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
  728. )
  729. return position_ids.unsqueeze(0).expand(input_shape)
  730. @staticmethod
  731. def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
  732. """
  733. Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
  734. are ignored. This is modified from fairseq's `utils.make_positions`.
  735. Args:
  736. x: torch.Tensor x:
  737. Returns: torch.Tensor
  738. """
  739. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  740. mask = input_ids.ne(padding_idx).int()
  741. incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
  742. return incremental_indices.long() + padding_idx
  @auto_docstring
  class BridgeTowerPreTrainedModel(PreTrainedModel):
      # Common base for the BridgeTower models in this file: declares the config class, hints consumed by the
      # `PreTrainedModel` machinery (not visible here), which sub-modules' outputs can be recorded, and the
      # weight initialisation in `_init_weights`.
      config: BridgeTowerConfig
      base_model_prefix = "bridgetower"
      input_modalities = ("image", "text")
      supports_gradient_checkpointing = False
      # NOTE(review): by HF naming convention, module classes that should not be split across devices.
      _no_split_modules = ["BridgeTowerSelfAttention", "BridgeTowerResidualAttention"]
      _skip_keys_device_placement = "past_key_values"
      # Output name -> module class whose outputs are recorded under that name.
      _can_record_outputs = {
          "hidden_states": BridgeTowerTextLayer,
          "attentions": BridgeTowerSelfAttention,
          "cross_attentions": BridgeTowerCrossAttention,
      }

      @torch.no_grad()
      def _init_weights(self, module: nn.Module):
          # Initialise the parameters/buffers of `module` in place (presumably invoked by the base class for
          # each sub-module). The order of the `init.*` calls is kept as-is on purpose: reordering random
          # draws would change the initial weights.
          # Global scale multiplied into every std below.
          std = self.config.initializer_factor
          if isinstance(module, BridgeTowerVisionTransformer):
              # Width/depth-scaled stds for the vision transformer's residual blocks and embeddings.
              proj_std = (self.config.hidden_size**-0.5) * ((2 * self.config.num_hidden_layers) ** -0.5)
              attn_std = self.config.hidden_size**-0.5
              fc_std = (2 * self.config.hidden_size) ** -0.5
              for block in module.transformer.resblocks:
                  init.normal_(block.attn.in_proj_weight, std=attn_std * std)
                  init.zeros_(block.attn.in_proj_bias)
                  init.normal_(block.attn.out_proj.weight, std=proj_std * std)
                  init.normal_(block.mlp.c_fc.weight, std=fc_std * std)
                  init.normal_(block.mlp.c_proj.weight, std=proj_std * std)

              init.normal_(module.embeddings.class_embedding, std=attn_std * std)
              init.normal_(module.embeddings.position_embedding.weight, std=attn_std * std)
          elif isinstance(module, (nn.Linear, nn.Conv2d, nn.Embedding)):
              # NOTE(review): assuming `init.normal_` fills the whole tensor, this also overwrites the padding
              # row of embeddings built with `padding_idx` — confirm that is intended.
              init.normal_(module.weight, mean=0.0, std=0.05 * std)
          elif isinstance(module, nn.LayerNorm):
              init.zeros_(module.bias)
              init.ones_(module.weight)
          elif isinstance(module, BridgeTowerForContrastiveLearning):
              init.constant_(module.logit_scale, self.config.logit_scale_init_value)
          elif isinstance(module, BridgeTowerVisionEmbeddings):
              # Reset `position_ids` to 0..num_positions-1 with a leading batch dim of 1.
              init.copy_(module.position_ids, torch.arange(module.num_positions).expand((1, -1)))
          elif isinstance(module, BridgeTowerTextEmbeddings):
              # Rebuild the text embeddings' non-persistent buffers.
              init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
              init.zeros_(module.token_type_ids)
          # Runs in addition to the branch above: zero the biases of linear layers and of the MLM head.
          if isinstance(module, (nn.Linear, BridgeTowerMLMHead)) and module.bias is not None:
              init.zeros_(module.bias)
  class BridgeTowerVisionModel(BridgeTowerPreTrainedModel):
      """Thin `PreTrainedModel` wrapper around a `BridgeTowerVisionTransformer` stored as `self.visual`."""

      config: BridgeTowerVisionConfig
      input_modalities = ("image",)

      def __init__(self, config):
          super().__init__(config)
          self.visual = BridgeTowerVisionTransformer(config)
          self.post_init()

      @property
      def dtype(self):
          """Dtype of `visual.embeddings.patch_embedding.weight`; `forward` casts images to it."""
          patch_embedding = self.visual.embeddings.patch_embedding
          return patch_embedding.weight.dtype

      def forward(self, image, image_mask=None, interpolate_pos_encoding=False, **kwargs):
          """Cast `image` to `self.dtype` and run the visual transformer; extra `kwargs` are ignored."""
          pixels = image.type(self.dtype)
          return self.visual(pixels, image_mask, interpolate_pos_encoding)
  @auto_docstring(
      custom_intro="""
      The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
      cross-attention is added between the self-attention layers, following the architecture described in *Attention is
      all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
      Kaiser and Illia Polosukhin.

      To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
      to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument and
      `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.

      .. _*Attention is all you need*: https://huggingface.co/papers/1706.03762
      """
  )
  class BridgeTowerTextModel(BridgeTowerPreTrainedModel):
      config: BridgeTowerTextConfig
      input_modalities = ("text",)

      def __init__(self, config, add_pooling_layer=True):
          r"""
          add_pooling_layer (bool, *optional*, defaults to `True`):
              Whether to add a pooling layer (`BridgeTowerPooler`) on top of the encoder output.
          """
          super().__init__(config)
          self.config = config
          self.gradient_checkpointing = False

          self.embeddings = BridgeTowerTextEmbeddings(config)
          self.encoder = BridgeTowerTextEncoder(config)

          self.pooler = BridgeTowerPooler(config) if add_pooling_layer else None

          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self):
          return self.embeddings.word_embeddings

      def set_input_embeddings(self, value):
          self.embeddings.word_embeddings = value

      @merge_with_config_defaults
      @capture_outputs
      @auto_docstring
      # NOTE: bridgetower with its multimodality has a more complicated scheme making records harder
      # for now we skip the copies from bert but stay close to the original
      # copied from transformers.models.bert.modeling_bert.BertModel.forward
      def forward(
          self,
          input_ids: torch.Tensor | None = None,
          attention_mask: torch.Tensor | None = None,
          token_type_ids: torch.Tensor | None = None,
          position_ids: torch.Tensor | None = None,
          inputs_embeds: torch.Tensor | None = None,
          encoder_hidden_states: torch.Tensor | None = None,
          encoder_attention_mask: torch.Tensor | None = None,
          past_key_values: Cache | None = None,
          use_cache: bool | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> BaseModelOutputWithPoolingAndCrossAttentions:
          # Exactly one of `input_ids` / `inputs_embeds` must be provided.
          if (input_ids is None) ^ (inputs_embeds is not None):
              raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

          # Only a decoder keeps a cache; encoders always run without one.
          if not self.config.is_decoder:
              use_cache = False

          if use_cache and past_key_values is None:
              past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))

          # Offset for generated position ids when continuing from a cache.
          past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0

          embedding_output = self.embeddings(
              input_ids=input_ids,
              position_ids=position_ids,
              token_type_ids=token_type_ids,
              inputs_embeds=inputs_embeds,
              past_key_values_length=past_key_values_length,
          )

          attention_mask, encoder_attention_mask = self._create_attention_masks(
              attention_mask=attention_mask,
              encoder_attention_mask=encoder_attention_mask,
              embedding_output=embedding_output,
              encoder_hidden_states=encoder_hidden_states,
              past_key_values=past_key_values,
          )

          encoder_outputs = self.encoder(
              embedding_output,
              attention_mask=attention_mask,
              encoder_hidden_states=encoder_hidden_states,
              encoder_attention_mask=encoder_attention_mask,
              past_key_values=past_key_values,
              use_cache=use_cache,
              position_ids=position_ids,
              **kwargs,
          )
          sequence_output = encoder_outputs[0]
          pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

          return BaseModelOutputWithPoolingAndCrossAttentions(
              last_hidden_state=sequence_output,
              pooler_output=pooled_output,
              past_key_values=encoder_outputs.past_key_values,
          )

      # Copied from transformers.models.bert.modeling_bert.BertModel._create_attention_masks
      def _create_attention_masks(
          self,
          attention_mask,
          encoder_attention_mask,
          embedding_output,
          encoder_hidden_states,
          past_key_values,
      ):
          # Self-attention mask: causal for decoders, bidirectional otherwise.
          if self.config.is_decoder:
              attention_mask = create_causal_mask(
                  config=self.config,
                  inputs_embeds=embedding_output,
                  attention_mask=attention_mask,
                  past_key_values=past_key_values,
              )
          else:
              attention_mask = create_bidirectional_mask(
                  config=self.config,
                  inputs_embeds=embedding_output,
                  attention_mask=attention_mask,
              )

          # Cross-attention mask over the encoder states, only when one was provided.
          if encoder_attention_mask is not None:
              encoder_attention_mask = create_bidirectional_mask(
                  config=self.config,
                  inputs_embeds=embedding_output,
                  attention_mask=encoder_attention_mask,
                  encoder_hidden_states=encoder_hidden_states,
              )

          return attention_mask, encoder_attention_mask
  @auto_docstring(
      custom_intro="""
      The bare BridgeTower Model transformer outputting BridgeTowerModelOutput object without any specific head on top.
      """
  )
  class BridgeTowerModel(BridgeTowerPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)
          self.config = config
          vision_config = config.vision_config
          text_config = config.text_config

          # Projections of the uni-modal states into the cross-modal width: one shared Linear per modality,
          # or one Linear per cross-modal layer.
          if config.share_cross_modal_transformer_layers:
              self.cross_modal_text_transform = nn.Linear(text_config.hidden_size, config.hidden_size)
              self.cross_modal_image_transform = nn.Linear(vision_config.hidden_size, config.hidden_size)
          else:
              self.cross_modal_text_transform = nn.ModuleList(
                  [nn.Linear(text_config.hidden_size, config.hidden_size) for _ in range(config.num_hidden_layers)]
              )
              self.cross_modal_image_transform = nn.ModuleList(
                  [nn.Linear(vision_config.hidden_size, config.hidden_size) for _ in range(config.num_hidden_layers)]
              )

          # Row 0 marks text tokens; image tokens use `image_token_type_idx` (default 1).
          self.token_type_embeddings = nn.Embedding(2, config.hidden_size)

          self.vision_model = BridgeTowerVisionModel(vision_config)

          self.text_model = BridgeTowerTextModel(text_config)

          if not vision_config.share_layernorm and config.init_layernorm_from_vision_encoder:
              # NOTE(review): `.data` assignment makes these LayerNorms share ln_post's tensors rather than copy them.
              for ln in self.vision_model.visual.cross_modal_ln_separate:
                  ln.weight.data = self.vision_model.visual.ln_post.weight.data
                  ln.bias.data = self.vision_model.visual.ln_post.bias.data

          self.cross_modal_image_layers = nn.ModuleList(
              [BridgeTowerBertCrossLayer(text_config) for _ in range(config.num_hidden_layers)]
          )
          self.cross_modal_text_layers = nn.ModuleList(
              [BridgeTowerBertCrossLayer(text_config) for _ in range(config.num_hidden_layers)]
          )

          # Class token => Linear => Tanh
          self.cross_modal_image_pooler = BridgeTowerPooler(config)
          self.cross_modal_text_pooler = BridgeTowerPooler(config)

          # Initialize BridgeTower Components
          self.cross_modal_text_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
          self.cross_modal_image_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

          # Bridge ("link tower") modules feeding each cross-modal layer after the first: one shared module per
          # modality, or one per bridged layer.
          if config.share_link_tower_layers:
              self.cross_modal_text_link_tower = BridgeTowerLinkTower(config)
              self.cross_modal_image_link_tower = BridgeTowerLinkTower(config)
          else:
              self.cross_modal_text_link_tower = nn.ModuleList(
                  [BridgeTowerLinkTower(config) for _ in range(config.num_hidden_layers - 1)]
              )
              self.cross_modal_image_link_tower = nn.ModuleList(
                  [BridgeTowerLinkTower(config) for _ in range(config.num_hidden_layers - 1)]
              )

          self.post_init()

      def get_input_embeddings(self):
          return self.text_model.get_input_embeddings()

      def set_input_embeddings(self, value):
          self.text_model.set_input_embeddings(value)

      def _apply_text_transform(self, hidden_states: torch.Tensor, layer_idx: int) -> torch.Tensor:
          """Project text states with the shared transform, or the one for `layer_idx` when not shared."""
          if self.config.share_cross_modal_transformer_layers:
              return self.cross_modal_text_transform(hidden_states)
          return self.cross_modal_text_transform[layer_idx](hidden_states)

      def _apply_image_transform(self, hidden_states: torch.Tensor, layer_idx: int) -> torch.Tensor:
          """Project image states with the shared transform, or the one for `layer_idx` when not shared."""
          if self.config.share_cross_modal_transformer_layers:
              return self.cross_modal_image_transform(hidden_states)
          return self.cross_modal_image_transform[layer_idx](hidden_states)

      def _get_link_towers(self, layer_idx: int):
          """Return the `(text, image)` link towers for bridge `layer_idx`; shared towers ignore the index."""
          if self.config.share_link_tower_layers:
              # Single modules, not ModuleLists: indexing them would raise a TypeError.
              return self.cross_modal_text_link_tower, self.cross_modal_image_link_tower
          return self.cross_modal_text_link_tower[layer_idx], self.cross_modal_image_link_tower[layer_idx]

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          attention_mask: torch.FloatTensor | None = None,
          token_type_ids: torch.LongTensor | None = None,
          pixel_values: torch.FloatTensor | None = None,
          pixel_mask: torch.LongTensor | None = None,
          inputs_embeds: torch.FloatTensor | None = None,
          image_embeds: torch.FloatTensor | None = None,
          image_token_type_idx: int | None = None,
          labels: torch.LongTensor | None = None,
          interpolate_pos_encoding: bool = False,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple[torch.Tensor] | BridgeTowerModelOutput:
          r"""
          image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
              Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
              This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
          image_token_type_idx (`int`, *optional*):
              - The token type id used for image tokens. Defaults to `1` when not given.
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Labels are currently not supported.

          Examples:

          ```python
          >>> from transformers import BridgeTowerProcessor, BridgeTowerModel
          >>> from PIL import Image
          >>> import httpx
          >>> from io import BytesIO

          >>> # prepare image and text
          >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
          >>> with httpx.stream("GET", url) as response:
          ...     image = Image.open(BytesIO(response.read()))
          >>> text = "hello world"
          >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-base")
          >>> model = BridgeTowerModel.from_pretrained("BridgeTower/bridgetower-base")

          >>> inputs = processor(image, text, return_tensors="pt")
          >>> outputs = model(**inputs)
          >>> outputs.keys()
          odict_keys(['text_features', 'image_features', 'pooler_output'])
          ```"""
          # Per-layer states are always collected and returned (no output_* flags are consulted here).
          all_hidden_states_text = []
          all_hidden_states_image = []
          all_hidden_states_cross = []
          all_self_attentions = []

          if inputs_embeds is not None and input_ids is None:
              raise NotImplementedError(
                  "BridgeTowerModel does not use `inputs_embeds`. Make sure to pass in `input_ids` instead."
              )
          if input_ids is None:
              raise ValueError("You have to specify `input_ids`.")
          if pixel_values is None and image_embeds is None:
              raise ValueError("You have to specify either `pixel_values` or `image_embeds`.")

          # An explicit 0 is a valid token type, so only fall back to the default when nothing was passed.
          if image_token_type_idx is None:
              image_token_type_idx = 1

          input_shape = input_ids.size()
          text_embeds = self.text_model.embeddings(input_ids=input_ids)
          all_hidden_states_text.append(text_embeds)

          if attention_mask is None:
              attention_mask = torch.ones(input_shape, dtype=torch.long, device=input_ids.device)
          extend_text_masks = self.text_model.get_extended_attention_mask(attention_mask, input_shape).to(
              input_ids.device
          )

          # The split_index determines how many layers of the uni-modal encoder are applied before the cross-modal encoder
          split_index = len(self.text_model.encoder.layer) - self.config.num_hidden_layers + 1

          # Run the first 'split_index' layers of the textual encoder
          for layer in self.text_model.encoder.layer[:split_index]:
              text_embeds = layer(text_embeds, extend_text_masks)
              all_hidden_states_text.append(text_embeds)

          if image_embeds is None:
              image_embeds = self.vision_model.visual.forward_pre(
                  pixel_values.type(self.vision_model.dtype), interpolate_pos_encoding=interpolate_pos_encoding
              )
          else:
              # Permute as BridgeTowerResidualAttention has batch_first=True
              # NOTE(review): swaps dims 0 and 1 of the given embeddings; confirm the layout `forward_pre` produces.
              image_embeds = image_embeds.permute(1, 0, 2)

          all_hidden_states_image.append(image_embeds)

          # Run the first 'split_index' layers of the visual encoder
          for block in self.vision_model.visual.transformer.resblocks[:split_index]:
              image_embeds = block(image_embeds)
              all_hidden_states_image.append(image_embeds)

          image_embeds_with_ln = self.vision_model.visual.forward_post(image_embeds.type(self.vision_model.dtype))

          # first layer is a special case because we don't have the output from the cross-encoder yet
          cross_modal_text = self._apply_text_transform(text_embeds, layer_idx=0)

          text_token_type_embeddings = self.token_type_embeddings(
              torch.zeros(1, dtype=torch.long, device=input_ids.device)
          ).expand_as(cross_modal_text)

          cross_modal_text = self.cross_modal_text_layernorm(cross_modal_text + text_token_type_embeddings)

          image_embeds_with_ln = self._apply_image_transform(image_embeds_with_ln, layer_idx=0)
          image_token_type_embeddings = self.token_type_embeddings(
              torch.full((1,), image_token_type_idx, dtype=torch.long, device=input_ids.device)
          ).expand_as(image_embeds_with_ln)

          image_embeds_with_ln = image_embeds_with_ln + image_token_type_embeddings
          cross_modal_image = self.cross_modal_image_layernorm(image_embeds_with_ln)

          # NOTE(review): any caller-supplied `pixel_mask` is replaced by an all-ones mask here.
          pixel_mask = torch.ones(
              (cross_modal_image.size(0), cross_modal_image.size(1)),
              dtype=torch.long,
              device=input_ids.device,
          )
          extend_image_masks = self.text_model.get_extended_attention_mask(pixel_mask, pixel_mask.size()).to(
              input_ids.device
          )

          layer_outputs_text = self.cross_modal_text_layers[0](
              cross_modal_text,
              cross_modal_image,
              attention_mask=extend_text_masks,
              encoder_attention_mask=extend_image_masks,
          )
          cross_text_features = layer_outputs_text[0]

          layer_outputs_image = self.cross_modal_image_layers[0](
              cross_modal_image,
              cross_modal_text,
              attention_mask=extend_image_masks,
              encoder_attention_mask=extend_text_masks,
          )
          cross_image_features = layer_outputs_image[0]

          all_hidden_states_cross.append((cross_text_features, cross_image_features))
          all_self_attentions.append((layer_outputs_text[1], layer_outputs_image[1]))

          link_layer_index = 0

          # Each of the top 6 layers of the visual and textual encoders ([split_index:]) is connected to each layer of
          # the cross-modal encoder via bridge layers, which brings bottom-up alignment and fusion to the cross-modal encoder.
          for i in range(split_index, len(self.text_model.encoder.layer)):
              text_embeds = self.text_model.encoder.layer[i](text_embeds, extend_text_masks)
              image_embeds = self.vision_model.visual.transformer.resblocks[i](image_embeds).type(
                  self.vision_model.dtype
              )
              image_embeds_with_ln = (
                  self._apply_image_transform(self.vision_model.visual.forward_post(image_embeds), link_layer_index + 1)
                  + image_token_type_embeddings
              )

              text_link_tower, image_link_tower = self._get_link_towers(link_layer_index)

              # Bridge layers for textual and visual encoders
              transformed_text_embeds = self._apply_text_transform(text_embeds, link_layer_index + 1)
              cross_text_features_ = text_link_tower(
                  transformed_text_embeds + text_token_type_embeddings,
                  cross_text_features,
                  extend_text_masks,
              )
              cross_image_features_ = image_link_tower(image_embeds_with_ln, cross_image_features, extend_image_masks)

              # Cross-modal encoder via bridge layers of textual and visual encoders
              layer_outputs_text = self.cross_modal_text_layers[link_layer_index + 1](
                  cross_text_features_,
                  cross_image_features_,
                  attention_mask=extend_text_masks,
                  encoder_attention_mask=extend_image_masks,
              )
              cross_text_features = layer_outputs_text[0]

              layer_outputs_image = self.cross_modal_image_layers[link_layer_index + 1](
                  cross_image_features_,
                  cross_text_features_,
                  attention_mask=extend_image_masks,
                  encoder_attention_mask=extend_text_masks,
              )
              cross_image_features = layer_outputs_image[0]

              link_layer_index += 1

              all_hidden_states_text.append(text_embeds)
              all_hidden_states_image.append(image_embeds)
              all_hidden_states_cross.append((cross_text_features, cross_image_features))
              all_self_attentions.append((layer_outputs_text[1], layer_outputs_image[1]))

          # Concatenate the cls token of the text and image features to get the final representation
          text_features, image_features = cross_text_features, cross_image_features
          cls_features = self.get_cls_features(text_features, image_features)

          return BridgeTowerModelOutput(
              text_features=text_features,
              image_features=image_features,
              pooler_output=cls_features,
              hidden_states=(
                  tuple(all_hidden_states_text),
                  tuple(all_hidden_states_image),
                  tuple(all_hidden_states_cross),
              ),
              attentions=tuple(all_self_attentions),
          )

      def get_cls_features(self, text_features, image_features):
          """Pool each modality with its cross-modal pooler and concatenate the results on the last dim."""
          cls_features_text = self.cross_modal_text_pooler(text_features)
          cls_features_image = self.cross_modal_image_pooler(image_features)
          return torch.cat([cls_features_text, cls_features_image], dim=-1)
  1153. # Copied from transformers.models.vilt.modeling_vilt.ViltPredictionHeadTransform with Vilt->BridgeTower
  1154. class BridgeTowerPredictionHeadTransform(nn.Module):
  1155. def __init__(self, config):
  1156. super().__init__()
  1157. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  1158. if isinstance(config.hidden_act, str):
  1159. self.transform_act_fn = ACT2FN[config.hidden_act]
  1160. else:
  1161. self.transform_act_fn = config.hidden_act
  1162. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  1163. def forward(self, hidden_states):
  1164. hidden_states = self.dense(hidden_states)
  1165. hidden_states = self.transform_act_fn(hidden_states)
  1166. hidden_states = self.LayerNorm(hidden_states)
  1167. return hidden_states
  1168. class BridgeTowerMLMHead(nn.Module):
  1169. def __init__(self, config, weight=None):
  1170. super().__init__()
  1171. self.config = config
  1172. self.transform = BridgeTowerPredictionHeadTransform(config)
  1173. self.decoder = nn.Linear(config.hidden_size, config.text_config.vocab_size, bias=False)
  1174. self.bias = nn.Parameter(torch.zeros(config.text_config.vocab_size))
  1175. if weight is not None:
  1176. self.decoder.weight = weight
  1177. def forward(self, x):
  1178. mlm_score = self.transform(x)
  1179. mlm_score = self.decoder(mlm_score) + self.bias
  1180. return mlm_score
  1181. class BridgeTowerITMHead(nn.Module):
  1182. def __init__(self, hidden_size):
  1183. super().__init__()
  1184. self.fc = nn.Linear(hidden_size, 2)
  1185. def forward(self, x):
  1186. itm_score = self.fc(x)
  1187. return itm_score
  1188. @auto_docstring(
  1189. custom_intro="""
  1190. BridgeTower Model with a language modeling head on top as done during pretraining.
  1191. """
  1192. )
  1193. class BridgeTowerForMaskedLM(BridgeTowerPreTrainedModel):
  1194. _tied_weights_keys = {"mlm_score.decoder.weight": "bridgetower.text_model.embeddings.word_embeddings.weight"}
  1195. def __init__(self, config):
  1196. super().__init__(config)
  1197. self.bridgetower = BridgeTowerModel(config)
  1198. self.mlm_score = BridgeTowerMLMHead(config)
  1199. # Initialize weights and apply final processing
  1200. self.post_init()
  1201. def get_output_embeddings(self):
  1202. return self.mlm_score.decoder
  1203. def set_output_embeddings(self, new_embeddings):
  1204. self.mlm_score.decoder = new_embeddings
  1205. @can_return_tuple
  1206. @auto_docstring
  1207. def forward(
  1208. self,
  1209. input_ids: torch.LongTensor | None = None,
  1210. attention_mask: torch.FloatTensor | None = None,
  1211. token_type_ids: torch.LongTensor | None = None,
  1212. pixel_values: torch.FloatTensor | None = None,
  1213. pixel_mask: torch.LongTensor | None = None,
  1214. inputs_embeds: torch.FloatTensor | None = None,
  1215. image_embeds: torch.FloatTensor | None = None,
  1216. labels: torch.LongTensor | None = None,
  1217. **kwargs: Unpack[TransformersKwargs],
  1218. ) -> MaskedLMOutput:
  1219. r"""
  1220. image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
  1221. Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
  1222. This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
  1223. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1224. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  1225. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  1226. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  1227. Examples:
  1228. ```python
  1229. >>> from transformers import BridgeTowerProcessor, BridgeTowerForMaskedLM
  1230. >>> from PIL import Image
  1231. >>> import httpx
  1232. >>> from io import BytesIO
  1233. >>> url = "http://images.cocodataset.org/val2017/000000360943.jpg"
  1234. >>> with httpx.stream("GET", url) as response:
  1235. ... image = Image.open(BytesIO(response.read())).convert("RGB")
  1236. >>> text = "a <mask> looking out of the window"
  1237. >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-base-itm-mlm")
  1238. >>> model = BridgeTowerForMaskedLM.from_pretrained("BridgeTower/bridgetower-base-itm-mlm")
  1239. >>> # prepare inputs
  1240. >>> encoding = processor(image, text, return_tensors="pt")
  1241. >>> # forward pass
  1242. >>> outputs = model(**encoding)
  1243. >>> results = processor.decode(outputs.logits.argmax(dim=-1).squeeze(0).tolist())
  1244. >>> print(results)
  1245. .a cat looking out of the window.
  1246. ```"""
  1247. outputs = self.bridgetower(
  1248. input_ids=input_ids,
  1249. attention_mask=attention_mask,
  1250. token_type_ids=token_type_ids,
  1251. pixel_values=pixel_values,
  1252. pixel_mask=pixel_mask,
  1253. inputs_embeds=inputs_embeds,
  1254. image_embeds=image_embeds,
  1255. **kwargs,
  1256. )
  1257. mlm_logits = self.mlm_score(outputs.text_features)
  1258. masked_lm_loss = None
  1259. if labels is not None:
  1260. loss_fct = CrossEntropyLoss() # -100 index = padding token
  1261. labels = labels.to(mlm_logits.device)
  1262. masked_lm_loss = loss_fct(mlm_logits.view(-1, self.config.text_config.vocab_size), labels.view(-1))
  1263. return MaskedLMOutput(
  1264. loss=masked_lm_loss,
  1265. logits=mlm_logits,
  1266. hidden_states=outputs.hidden_states,
  1267. attentions=outputs.attentions,
  1268. )
  1269. @auto_docstring(
  1270. custom_intro="""
  1271. BridgeTower Model transformer with a classifier head on top (a linear layer on top of the final hidden state of the
  1272. [CLS] token) for image-to-text matching.
  1273. """
  1274. )
  1275. class BridgeTowerForImageAndTextRetrieval(BridgeTowerPreTrainedModel):
  1276. def __init__(self, config):
  1277. super().__init__(config)
  1278. self.bridgetower = BridgeTowerModel(config)
  1279. self.itm_score = BridgeTowerITMHead(config.hidden_size * 2)
  1280. # Initialize weights and apply final processing
  1281. self.post_init()
  1282. @can_return_tuple
  1283. @auto_docstring
  1284. def forward(
  1285. self,
  1286. input_ids: torch.LongTensor | None = None,
  1287. attention_mask: torch.FloatTensor | None = None,
  1288. token_type_ids: torch.LongTensor | None = None,
  1289. pixel_values: torch.FloatTensor | None = None,
  1290. pixel_mask: torch.LongTensor | None = None,
  1291. inputs_embeds: torch.FloatTensor | None = None,
  1292. image_embeds: torch.FloatTensor | None = None,
  1293. labels: torch.LongTensor | None = None,
  1294. **kwargs: Unpack[TransformersKwargs],
  1295. ) -> SequenceClassifierOutput:
  1296. r"""
  1297. image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
  1298. Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
  1299. This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
  1300. labels (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
  1301. Labels for computing the image-text matching loss. 0 means the pairs don't match and 1 means they match.
  1302. The pairs with 0 will be skipped for calculation.
  1303. Examples:
  1304. ```python
  1305. >>> from transformers import BridgeTowerProcessor, BridgeTowerForImageAndTextRetrieval
  1306. >>> import httpx
  1307. >>> from io import BytesIO
  1308. >>> from PIL import Image
  1309. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1310. >>> with httpx.stream("GET", url) as response:
  1311. ... image = Image.open(BytesIO(response.read()))
  1312. >>> texts = ["An image of two cats chilling on a couch", "A football player scoring a goal"]
  1313. >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-base-itm-mlm")
  1314. >>> model = BridgeTowerForImageAndTextRetrieval.from_pretrained("BridgeTower/bridgetower-base-itm-mlm")
  1315. >>> # forward pass
  1316. >>> scores = dict()
  1317. >>> for text in texts:
  1318. ... # prepare inputs
  1319. ... encoding = processor(image, text, return_tensors="pt")
  1320. ... outputs = model(**encoding)
  1321. ... scores[text] = outputs.logits[0, 1].item()
  1322. ```"""
  1323. outputs = self.bridgetower(
  1324. input_ids=input_ids,
  1325. attention_mask=attention_mask,
  1326. token_type_ids=token_type_ids,
  1327. pixel_values=pixel_values,
  1328. pixel_mask=pixel_mask,
  1329. inputs_embeds=inputs_embeds,
  1330. image_embeds=image_embeds,
  1331. **kwargs,
  1332. )
  1333. pooler_output = outputs.pooler_output
  1334. logits = self.itm_score(pooler_output)
  1335. itm_loss = None
  1336. if labels is not None:
  1337. loss_fct = CrossEntropyLoss()
  1338. labels = labels.to(logits.device)
  1339. itm_loss = loss_fct(logits, labels)
  1340. return SequenceClassifierOutput(
  1341. loss=itm_loss,
  1342. logits=logits,
  1343. hidden_states=outputs.hidden_states,
  1344. attentions=outputs.attentions,
  1345. )
  1346. class BridgeTowerContrastiveHead(nn.Module):
  1347. def __init__(self, hidden_size, embed_size):
  1348. super().__init__()
  1349. self.fc = nn.Linear(hidden_size, embed_size)
  1350. def forward(self, x):
  1351. x = self.fc(x)
  1352. return x
  1353. @auto_docstring(
  1354. custom_intro="""
  1355. BridgeTower Model with a image-text contrastive head on top computing image-text contrastive loss.
  1356. """
  1357. )
  1358. class BridgeTowerForContrastiveLearning(BridgeTowerPreTrainedModel):
  1359. def __init__(self, config):
  1360. super().__init__(config)
  1361. self.bridgetower = BridgeTowerModel(config)
  1362. self.itc_text_head = BridgeTowerContrastiveHead(config.hidden_size, config.contrastive_hidden_size)
  1363. self.itc_image_head = BridgeTowerContrastiveHead(config.hidden_size, config.contrastive_hidden_size)
  1364. self.itc_cross_modal_head = BridgeTowerContrastiveHead(config.hidden_size * 2, config.contrastive_hidden_size)
  1365. self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
  1366. # Initialize weights and apply final processing
  1367. self.post_init()
  1368. @can_return_tuple
  1369. @auto_docstring
  1370. def forward(
  1371. self,
  1372. input_ids: torch.LongTensor | None = None,
  1373. attention_mask: torch.FloatTensor | None = None,
  1374. token_type_ids: torch.LongTensor | None = None,
  1375. pixel_values: torch.FloatTensor | None = None,
  1376. pixel_mask: torch.LongTensor | None = None,
  1377. inputs_embeds: torch.FloatTensor | None = None,
  1378. image_embeds: torch.FloatTensor | None = None,
  1379. return_loss: bool | None = None,
  1380. **kwargs: Unpack[TransformersKwargs],
  1381. ) -> BridgeTowerContrastiveOutput:
  1382. r"""
  1383. image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
  1384. Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
  1385. This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
  1386. return_loss (`bool`, *optional*):
  1387. Whether or not to return the contrastive loss.
  1388. Examples:
  1389. ```python
  1390. >>> from transformers import BridgeTowerProcessor, BridgeTowerForContrastiveLearning
  1391. >>> import httpx
  1392. >>> from io import BytesIO
  1393. >>> from PIL import Image
  1394. >>> import torch
  1395. >>> image_urls = [
  1396. ... "https://farm4.staticflickr.com/3395/3428278415_81c3e27f15_z.jpg",
  1397. ... "http://images.cocodataset.org/val2017/000000039769.jpg",
  1398. ... ]
  1399. >>> texts = ["two dogs in a car", "two cats sleeping on a couch"]
  1400. >>> with httpx.stream("GET", urls[0]) as response:
  1401. ... image1 = Image.open(BytesIO(response.read()))
  1402. >>> with httpx.stream("GET", urls[1]) as response:
  1403. ... image2 = Image.open(BytesIO(response.read()))
  1404. >>> images = [image1, image2]
  1405. >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
  1406. >>> model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
  1407. >>> inputs = processor(images, texts, padding=True, return_tensors="pt")
  1408. >>> loss = model(**inputs, return_loss=True).loss
  1409. >>> inputs = processor(images, texts[::-1], padding=True, return_tensors="pt")
  1410. >>> loss_swapped = model(**inputs, return_loss=True).loss
  1411. >>> print("Loss", round(loss.item(), 4))
  1412. Loss 0.0019
  1413. >>> print("Loss with swapped images", round(loss_swapped.item(), 4))
  1414. Loss with swapped images 2.126
  1415. ```"""
  1416. kwargs.setdefault("output_hidden_states", True)
  1417. outputs = self.bridgetower(
  1418. input_ids=input_ids,
  1419. attention_mask=attention_mask,
  1420. token_type_ids=token_type_ids,
  1421. pixel_values=pixel_values,
  1422. pixel_mask=pixel_mask,
  1423. inputs_embeds=inputs_embeds,
  1424. image_embeds=image_embeds,
  1425. **kwargs,
  1426. )
  1427. pooler_output = outputs.pooler_output
  1428. hidden_states_txt, hidden_states_img, hidden_states_cross_modal = outputs.hidden_states
  1429. text_embeds = hidden_states_txt[-1]
  1430. image_embeds = hidden_states_img[-1]
  1431. image_embeds_with_ln = self.bridgetower.vision_model.visual.forward_post(image_embeds)
  1432. image_token_type_embeddings = self.bridgetower.token_type_embeddings(
  1433. torch.full((1,), 1, dtype=torch.long, device=self.bridgetower.token_type_embeddings.weight.device)
  1434. ).expand_as(image_embeds_with_ln)
  1435. image_embeds = self.bridgetower.cross_modal_image_transform(image_embeds_with_ln) + image_token_type_embeddings
  1436. # normalized features
  1437. text_embeds = nn.functional.normalize(self.itc_text_head(text_embeds[:, 0, :]), dim=-1, p=2)
  1438. image_embeds = nn.functional.normalize(self.itc_image_head(image_embeds[:, 0, :]), dim=-1, p=2).to(
  1439. device=text_embeds.device
  1440. )
  1441. cross_embeds = nn.functional.normalize(self.itc_cross_modal_head(pooler_output), dim=-1, p=2).to(
  1442. device=text_embeds.device
  1443. )
  1444. logits = torch.stack([text_embeds, image_embeds, cross_embeds], dim=-2)
  1445. logit_scale = self.logit_scale.exp().to(device=text_embeds.device)
  1446. logits_text_to_image = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
  1447. logits_text_to_cross = torch.matmul(text_embeds, cross_embeds.t()) * logit_scale
  1448. logits_image_to_cross = torch.matmul(image_embeds, cross_embeds.t()) * logit_scale
  1449. itc_loss = None
  1450. if return_loss:
  1451. labels = torch.arange(len(logits), device=logits.device)
  1452. text_to_image_loss = nn.functional.cross_entropy(logits_text_to_image, labels)
  1453. text_to_cross_loss = nn.functional.cross_entropy(logits_text_to_cross, labels)
  1454. image_to_cross_loss = nn.functional.cross_entropy(logits_image_to_cross, labels)
  1455. itc_loss = (text_to_image_loss + text_to_cross_loss + image_to_cross_loss) / 3.0
  1456. return BridgeTowerContrastiveOutput(
  1457. loss=itc_loss,
  1458. logits=logits,
  1459. text_embeds=text_embeds,
  1460. image_embeds=image_embeds,
  1461. cross_embeds=cross_embeds,
  1462. hidden_states=outputs.hidden_states,
  1463. attentions=outputs.attentions,
  1464. )
# Public API of this module: the names exported by `from ... import *` and surfaced by the package init.
__all__ = [
    "BridgeTowerForContrastiveLearning",
    "BridgeTowerForImageAndTextRetrieval",
    "BridgeTowerForMaskedLM",
    "BridgeTowerModel",
    "BridgeTowerPreTrainedModel",
]