modeling_lxmert.py 57 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297
  1. # Copyright 2018 Hao Tan, Mohit Bansal, and the HuggingFace team
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch LXMERT model."""
  15. import math
  16. from dataclasses import dataclass
  17. import torch
  18. from torch import nn
  19. from torch.nn import CrossEntropyLoss, SmoothL1Loss
  20. from ... import initialization as init
  21. from ...activations import ACT2FN, gelu
  22. from ...modeling_utils import PreTrainedModel
  23. from ...utils import ModelOutput, auto_docstring, logging
  24. from .configuration_lxmert import LxmertConfig
  25. logger = logging.get_logger(__name__)
  26. class GeLU(nn.Module):
  27. def __init__(self):
  28. super().__init__()
  29. def forward(self, x):
  30. return gelu(x)
  31. @dataclass
  32. @auto_docstring(
  33. custom_intro="""
  34. Lxmert's outputs that contain the last hidden states, pooled outputs, and attention probabilities for the language,
  35. visual, and, cross-modality encoders. (note: the visual encoder in Lxmert is referred to as the "relation-ship"
  36. encoder")
  37. """
  38. )
  39. class LxmertModelOutput(ModelOutput):
  40. r"""
  41. language_output (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  42. Sequence of hidden-states at the output of the last layer of the language encoder.
  43. vision_output (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  44. Sequence of hidden-states at the output of the last layer of the visual encoder.
  45. pooled_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
  46. Last layer hidden-state of the first token of the sequence (classification, CLS, token) further processed
  47. by a Linear layer and a Tanh activation function. The Linear
  48. language_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  49. Tuple of `torch.FloatTensor` (one for input features + one for the output of each cross-modality layer) of
  50. shape `(batch_size, sequence_length, hidden_size)`.
  51. vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  52. Tuple of `torch.FloatTensor` (one for input features + one for the output of each cross-modality layer) of
  53. shape `(batch_size, sequence_length, hidden_size)`.
  54. language_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  55. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  56. sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
  57. the self-attention heads.
  58. vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  59. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  60. sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
  61. the self-attention heads.
  62. cross_encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  63. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  64. sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
  65. the self-attention heads.
  66. """
  67. language_output: torch.FloatTensor | None = None
  68. vision_output: torch.FloatTensor | None = None
  69. pooled_output: torch.FloatTensor | None = None
  70. language_hidden_states: tuple[torch.FloatTensor] | None = None
  71. vision_hidden_states: tuple[torch.FloatTensor] | None = None
  72. language_attentions: tuple[torch.FloatTensor] | None = None
  73. vision_attentions: tuple[torch.FloatTensor] | None = None
  74. cross_encoder_attentions: tuple[torch.FloatTensor] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`LxmertForQuestionAnswering`].
    """
)
class LxmertForQuestionAnsweringOutput(ModelOutput):
    r"""
    loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
        Total loss as the sum of the masked language modeling loss and the next sequence prediction
        (classification) loss.
    question_answering_score (`torch.FloatTensor` of shape `(batch_size, n_qa_answers)`, *optional*):
        Prediction scores of question answering objective (classification).
    language_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for input features + one for the output of each cross-modality layer) of
        shape `(batch_size, sequence_length, hidden_size)`.
    vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for input features + one for the output of each cross-modality layer) of
        shape `(batch_size, sequence_length, hidden_size)`.
    language_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
        the self-attention heads.
    vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
        the self-attention heads.
    cross_encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
        the self-attention heads.
    """

    # NOTE(review): field order presumably defines positional/tuple access on ModelOutput; keep it stable.
    loss: torch.FloatTensor | None = None
    question_answering_score: torch.FloatTensor | None = None
    language_hidden_states: tuple[torch.FloatTensor] | None = None
    vision_hidden_states: tuple[torch.FloatTensor] | None = None
    language_attentions: tuple[torch.FloatTensor] | None = None
    vision_attentions: tuple[torch.FloatTensor] | None = None
    cross_encoder_attentions: tuple[torch.FloatTensor] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`LxmertForPreTraining`].
    """
)
class LxmertForPreTrainingOutput(ModelOutput):
    r"""
    loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
        Total loss as the sum of the masked language modeling loss and the next sequence prediction
        (classification) loss.
    prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    cross_relationship_score (`torch.FloatTensor` of shape `(batch_size, 2)`):
        Prediction scores of the textual matching objective (classification) head (scores of True/False
        continuation before SoftMax).
    question_answering_score (`torch.FloatTensor` of shape `(batch_size, n_qa_answers)`):
        Prediction scores of question answering objective (classification).
    language_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for input features + one for the output of each cross-modality layer) of
        shape `(batch_size, sequence_length, hidden_size)`.
    vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for input features + one for the output of each cross-modality layer) of
        shape `(batch_size, sequence_length, hidden_size)`.
    language_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
        the self-attention heads.
    vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
        the self-attention heads.
    cross_encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
        the self-attention heads.
    """

    # NOTE(review): field order presumably defines positional/tuple access on ModelOutput; keep it stable.
    loss: torch.FloatTensor | None = None
    prediction_logits: torch.FloatTensor | None = None
    cross_relationship_score: torch.FloatTensor | None = None
    question_answering_score: torch.FloatTensor | None = None
    language_hidden_states: tuple[torch.FloatTensor] | None = None
    vision_hidden_states: tuple[torch.FloatTensor] | None = None
    language_attentions: tuple[torch.FloatTensor] | None = None
    vision_attentions: tuple[torch.FloatTensor] | None = None
    cross_encoder_attentions: tuple[torch.FloatTensor] | None = None
  160. class LxmertEmbeddings(nn.Module):
  161. """Construct the embeddings from word, position and token_type embeddings."""
  162. def __init__(self, config):
  163. super().__init__()
  164. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
  165. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size, padding_idx=0)
  166. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size, padding_idx=0)
  167. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
  168. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  169. def forward(self, input_ids, token_type_ids=None, inputs_embeds=None):
  170. if input_ids is not None:
  171. input_shape = input_ids.size()
  172. device = input_ids.device
  173. else:
  174. input_shape = inputs_embeds.size()[:-1]
  175. device = inputs_embeds.device
  176. seq_length = input_shape[1]
  177. position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
  178. position_ids = position_ids.unsqueeze(0).expand(input_shape)
  179. if token_type_ids is None:
  180. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  181. if inputs_embeds is None:
  182. inputs_embeds = self.word_embeddings(input_ids)
  183. position_embeddings = self.position_embeddings(position_ids)
  184. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  185. embeddings = inputs_embeds + position_embeddings + token_type_embeddings
  186. embeddings = self.LayerNorm(embeddings)
  187. embeddings = self.dropout(embeddings)
  188. return embeddings
  189. class LxmertAttention(nn.Module):
  190. def __init__(self, config, ctx_dim=None):
  191. super().__init__()
  192. if config.hidden_size % config.num_attention_heads != 0:
  193. raise ValueError(
  194. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  195. f"heads ({config.num_attention_heads})"
  196. )
  197. self.num_attention_heads = config.num_attention_heads
  198. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  199. self.head_size = self.num_attention_heads * self.attention_head_size
  200. # visual_dim = 2048
  201. if ctx_dim is None:
  202. ctx_dim = config.hidden_size
  203. self.query = nn.Linear(config.hidden_size, self.head_size)
  204. self.key = nn.Linear(ctx_dim, self.head_size)
  205. self.value = nn.Linear(ctx_dim, self.head_size)
  206. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  207. def forward(self, hidden_states, context, attention_mask=None, output_attentions=False):
  208. input_shape = hidden_states.shape[:-1]
  209. hidden_shape = (*input_shape, -1, self.attention_head_size)
  210. query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  211. kv_shape = (*context.shape[:-1], -1, self.attention_head_size)
  212. key_layer = self.key(context).view(kv_shape).transpose(1, 2)
  213. value_layer = self.value(context).view(kv_shape).transpose(1, 2)
  214. # Take the dot product between "query" and "key" to get the raw attention scores.
  215. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  216. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  217. # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
  218. if attention_mask is not None:
  219. attention_scores = attention_scores + attention_mask
  220. # Normalize the attention scores to probabilities.
  221. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  222. # This is actually dropping out entire tokens to attend to, which might
  223. # seem a bit unusual, but is taken from the original Transformer paper.
  224. attention_probs = self.dropout(attention_probs)
  225. context_layer = torch.matmul(attention_probs, value_layer)
  226. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  227. new_context_layer_shape = context_layer.size()[:-2] + (self.head_size,)
  228. context_layer = context_layer.view(new_context_layer_shape)
  229. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  230. return outputs
  231. class LxmertAttentionOutput(nn.Module):
  232. def __init__(self, config):
  233. super().__init__()
  234. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  235. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
  236. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  237. def forward(self, hidden_states, input_tensor):
  238. hidden_states = self.dense(hidden_states)
  239. hidden_states = self.dropout(hidden_states)
  240. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  241. return hidden_states
  242. class LxmertCrossAttentionLayer(nn.Module):
  243. def __init__(self, config):
  244. super().__init__()
  245. self.att = LxmertAttention(config)
  246. self.output = LxmertAttentionOutput(config)
  247. def forward(self, input_tensor, ctx_tensor, ctx_att_mask=None, output_attentions=False):
  248. output = self.att(input_tensor, ctx_tensor, ctx_att_mask, output_attentions=output_attentions)
  249. if output_attentions:
  250. attention_probs = output[1]
  251. attention_output = self.output(output[0], input_tensor)
  252. outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
  253. return outputs
  254. class LxmertSelfAttentionLayer(nn.Module):
  255. def __init__(self, config):
  256. super().__init__()
  257. self.self = LxmertAttention(config)
  258. self.output = LxmertAttentionOutput(config)
  259. def forward(self, input_tensor, attention_mask, output_attentions=False):
  260. # Self attention attends to itself, thus keys and queries are the same (input_tensor).
  261. output = self.self(
  262. input_tensor,
  263. input_tensor,
  264. attention_mask,
  265. output_attentions=output_attentions,
  266. )
  267. if output_attentions:
  268. attention_probs = output[1]
  269. attention_output = self.output(output[0], input_tensor)
  270. outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
  271. return outputs
  272. class LxmertIntermediate(nn.Module):
  273. def __init__(self, config):
  274. super().__init__()
  275. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  276. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  277. def forward(self, hidden_states):
  278. hidden_states = self.dense(hidden_states)
  279. hidden_states = self.intermediate_act_fn(hidden_states)
  280. return hidden_states
  281. class LxmertOutput(nn.Module):
  282. def __init__(self, config):
  283. super().__init__()
  284. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  285. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
  286. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  287. def forward(self, hidden_states, input_tensor):
  288. hidden_states = self.dense(hidden_states)
  289. hidden_states = self.dropout(hidden_states)
  290. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  291. return hidden_states
  292. class LxmertLayer(nn.Module):
  293. def __init__(self, config):
  294. super().__init__()
  295. self.attention = LxmertSelfAttentionLayer(config)
  296. self.intermediate = LxmertIntermediate(config)
  297. self.output = LxmertOutput(config)
  298. def forward(self, hidden_states, attention_mask=None, output_attentions=False):
  299. outputs = self.attention(hidden_states, attention_mask, output_attentions=output_attentions)
  300. attention_output = outputs[0]
  301. intermediate_output = self.intermediate(attention_output)
  302. layer_output = self.output(intermediate_output, attention_output)
  303. outputs = (layer_output,) + outputs[1:] # add attentions if we output them
  304. return outputs
  305. class LxmertXLayer(nn.Module):
  306. def __init__(self, config):
  307. super().__init__()
  308. # The cross-attention Layer
  309. self.visual_attention = LxmertCrossAttentionLayer(config)
  310. # Self-attention Layers
  311. self.lang_self_att = LxmertSelfAttentionLayer(config)
  312. self.visn_self_att = LxmertSelfAttentionLayer(config)
  313. # Intermediate and Output Layers (FFNs)
  314. self.lang_inter = LxmertIntermediate(config)
  315. self.lang_output = LxmertOutput(config)
  316. self.visn_inter = LxmertIntermediate(config)
  317. self.visn_output = LxmertOutput(config)
  318. def cross_att(
  319. self,
  320. lang_input,
  321. lang_attention_mask,
  322. visual_input,
  323. visual_attention_mask,
  324. output_x_attentions=False,
  325. ):
  326. # Cross Attention
  327. lang_att_output = self.visual_attention(
  328. lang_input,
  329. visual_input,
  330. ctx_att_mask=visual_attention_mask,
  331. output_attentions=output_x_attentions,
  332. )
  333. visual_att_output = self.visual_attention(
  334. visual_input,
  335. lang_input,
  336. ctx_att_mask=lang_attention_mask,
  337. output_attentions=False,
  338. )
  339. return lang_att_output, visual_att_output
  340. def self_att(self, lang_input, lang_attention_mask, visual_input, visual_attention_mask):
  341. # Self Attention
  342. lang_att_output = self.lang_self_att(lang_input, lang_attention_mask, output_attentions=False)
  343. visual_att_output = self.visn_self_att(visual_input, visual_attention_mask, output_attentions=False)
  344. return lang_att_output[0], visual_att_output[0]
  345. def output_fc(self, lang_input, visual_input):
  346. # FC layers
  347. lang_inter_output = self.lang_inter(lang_input)
  348. visual_inter_output = self.visn_inter(visual_input)
  349. # Layer output
  350. lang_output = self.lang_output(lang_inter_output, lang_input)
  351. visual_output = self.visn_output(visual_inter_output, visual_input)
  352. return lang_output, visual_output
  353. def forward(
  354. self,
  355. lang_feats,
  356. lang_attention_mask,
  357. visual_feats,
  358. visual_attention_mask,
  359. output_attentions=False,
  360. ):
  361. lang_att_output, visual_att_output = self.cross_att(
  362. lang_input=lang_feats,
  363. lang_attention_mask=lang_attention_mask,
  364. visual_input=visual_feats,
  365. visual_attention_mask=visual_attention_mask,
  366. output_x_attentions=output_attentions,
  367. )
  368. attention_probs = lang_att_output[1:]
  369. lang_att_output, visual_att_output = self.self_att(
  370. lang_att_output[0],
  371. lang_attention_mask,
  372. visual_att_output[0],
  373. visual_attention_mask,
  374. )
  375. lang_output, visual_output = self.output_fc(lang_att_output, visual_att_output)
  376. return (
  377. (
  378. lang_output,
  379. visual_output,
  380. attention_probs[0],
  381. )
  382. if output_attentions
  383. else (lang_output, visual_output)
  384. )
  385. class LxmertVisualFeatureEncoder(nn.Module):
  386. def __init__(self, config):
  387. super().__init__()
  388. feat_dim = config.visual_feat_dim
  389. pos_dim = config.visual_pos_dim
  390. # Object feature encoding
  391. self.visn_fc = nn.Linear(feat_dim, config.hidden_size)
  392. self.visn_layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
  393. # Box position encoding
  394. self.box_fc = nn.Linear(pos_dim, config.hidden_size)
  395. self.box_layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
  396. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  397. def forward(self, visual_feats, visual_pos):
  398. x = self.visn_fc(visual_feats)
  399. x = self.visn_layer_norm(x)
  400. y = self.box_fc(visual_pos)
  401. y = self.box_layer_norm(y)
  402. output = (x + y) / 2
  403. output = self.dropout(output)
  404. return output
  405. class LxmertEncoder(nn.Module):
  406. def __init__(self, config):
  407. super().__init__()
  408. # Obj-level image embedding layer
  409. self.visn_fc = LxmertVisualFeatureEncoder(config)
  410. self.config = config
  411. # Number of layers
  412. self.num_l_layers = config.l_layers
  413. self.num_x_layers = config.x_layers
  414. self.num_r_layers = config.r_layers
  415. # Layers
  416. # Using self.layer instead of self.l_layer to support loading BERT weights.
  417. self.layer = nn.ModuleList([LxmertLayer(config) for _ in range(self.num_l_layers)])
  418. self.x_layers = nn.ModuleList([LxmertXLayer(config) for _ in range(self.num_x_layers)])
  419. self.r_layers = nn.ModuleList([LxmertLayer(config) for _ in range(self.num_r_layers)])
  420. def forward(
  421. self,
  422. lang_feats,
  423. lang_attention_mask,
  424. visual_feats,
  425. visual_pos,
  426. visual_attention_mask=None,
  427. output_attentions=None,
  428. ):
  429. vision_hidden_states = ()
  430. language_hidden_states = ()
  431. vision_attentions = () if output_attentions or self.config.output_attentions else None
  432. language_attentions = () if output_attentions or self.config.output_attentions else None
  433. cross_encoder_attentions = () if output_attentions or self.config.output_attentions else None
  434. visual_feats = self.visn_fc(visual_feats, visual_pos)
  435. # Run language layers
  436. for layer_module in self.layer:
  437. l_outputs = layer_module(lang_feats, lang_attention_mask, output_attentions=output_attentions)
  438. lang_feats = l_outputs[0]
  439. language_hidden_states = language_hidden_states + (lang_feats,)
  440. if language_attentions is not None:
  441. language_attentions = language_attentions + (l_outputs[1],)
  442. # Run relational layers
  443. for layer_module in self.r_layers:
  444. v_outputs = layer_module(visual_feats, visual_attention_mask, output_attentions=output_attentions)
  445. visual_feats = v_outputs[0]
  446. vision_hidden_states = vision_hidden_states + (visual_feats,)
  447. if vision_attentions is not None:
  448. vision_attentions = vision_attentions + (v_outputs[1],)
  449. # Run cross-modality layers
  450. for layer_module in self.x_layers:
  451. x_outputs = layer_module(
  452. lang_feats,
  453. lang_attention_mask,
  454. visual_feats,
  455. visual_attention_mask,
  456. output_attentions=output_attentions,
  457. )
  458. lang_feats, visual_feats = x_outputs[:2]
  459. vision_hidden_states = vision_hidden_states + (visual_feats,)
  460. language_hidden_states = language_hidden_states + (lang_feats,)
  461. if cross_encoder_attentions is not None:
  462. cross_encoder_attentions = cross_encoder_attentions + (x_outputs[2],)
  463. visual_encoder_outputs = (
  464. vision_hidden_states,
  465. vision_attentions if output_attentions else None,
  466. )
  467. lang_encoder_outputs = (
  468. language_hidden_states,
  469. language_attentions if output_attentions else None,
  470. )
  471. return (
  472. visual_encoder_outputs,
  473. lang_encoder_outputs,
  474. cross_encoder_attentions if output_attentions else None,
  475. )
  476. class LxmertPooler(nn.Module):
  477. def __init__(self, config):
  478. super().__init__()
  479. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  480. self.activation = nn.Tanh()
  481. def forward(self, hidden_states):
  482. # We "pool" the model by simply taking the hidden state corresponding
  483. # to the first token.
  484. first_token_tensor = hidden_states[:, 0]
  485. pooled_output = self.dense(first_token_tensor)
  486. pooled_output = self.activation(pooled_output)
  487. return pooled_output
  488. class LxmertPredictionHeadTransform(nn.Module):
  489. def __init__(self, config):
  490. super().__init__()
  491. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  492. self.transform_act_fn = ACT2FN[config.hidden_act]
  493. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)
  494. def forward(self, hidden_states):
  495. hidden_states = self.dense(hidden_states)
  496. hidden_states = self.transform_act_fn(hidden_states)
  497. hidden_states = self.LayerNorm(hidden_states)
  498. return hidden_states
  499. class LxmertLMPredictionHead(nn.Module):
  500. def __init__(self, config):
  501. super().__init__()
  502. self.transform = LxmertPredictionHeadTransform(config)
  503. self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  504. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  505. def forward(self, hidden_states):
  506. hidden_states = self.transform(hidden_states)
  507. hidden_states = self.decoder(hidden_states) + self.bias
  508. return hidden_states
  509. class LxmertVisualAnswerHead(nn.Module):
  510. def __init__(self, config, num_labels):
  511. super().__init__()
  512. hid_dim = config.hidden_size
  513. self.logit_fc = nn.Sequential(
  514. nn.Linear(hid_dim, hid_dim * 2),
  515. GeLU(),
  516. nn.LayerNorm(hid_dim * 2, eps=1e-12),
  517. nn.Linear(hid_dim * 2, num_labels),
  518. )
  519. def forward(self, hidden_states):
  520. return self.logit_fc(hidden_states)
  521. class LxmertVisualObjHead(nn.Module):
  522. def __init__(self, config):
  523. super().__init__()
  524. self.transform = LxmertPredictionHeadTransform(config)
  525. # Decide the use of visual losses
  526. visual_losses = {}
  527. if config.visual_obj_loss:
  528. visual_losses["obj"] = {"shape": (-1,), "num": config.num_object_labels}
  529. if config.visual_attr_loss:
  530. visual_losses["attr"] = {"shape": (-1,), "num": config.num_attr_labels}
  531. if config.visual_feat_loss:
  532. visual_losses["feat"] = {
  533. "shape": (-1, config.visual_feat_dim),
  534. "num": config.visual_feat_dim,
  535. }
  536. self.visual_losses = visual_losses
  537. # The output weights are the same as the input embeddings, but there is
  538. # an output-only bias for each token.
  539. self.decoder_dict = nn.ModuleDict(
  540. {key: nn.Linear(config.hidden_size, self.visual_losses[key]["num"]) for key in self.visual_losses}
  541. )
  542. def forward(self, hidden_states):
  543. hidden_states = self.transform(hidden_states)
  544. output = {}
  545. for key in self.visual_losses:
  546. output[key] = self.decoder_dict[key](hidden_states)
  547. return output
  548. class LxmertPreTrainingHeads(nn.Module):
  549. def __init__(self, config):
  550. super().__init__()
  551. self.predictions = LxmertLMPredictionHead(config)
  552. self.seq_relationship = nn.Linear(config.hidden_size, 2)
  553. def forward(self, sequence_output, pooled_output):
  554. prediction_scores = self.predictions(sequence_output)
  555. seq_relationship_score = self.seq_relationship(pooled_output)
  556. return prediction_scores, seq_relationship_score
  557. @auto_docstring
  558. class LxmertPreTrainedModel(PreTrainedModel):
  559. config: LxmertConfig
  560. base_model_prefix = "lxmert"
  561. input_modalities = ("image", "text")
  562. @torch.no_grad()
  563. def _init_weights(self, module):
  564. """Initialize the weights"""
  565. super()._init_weights(module)
  566. if isinstance(module, LxmertLMPredictionHead):
  567. init.zeros_(module.bias)
  568. @auto_docstring
  569. class LxmertModel(LxmertPreTrainedModel):
  570. def __init__(self, config):
  571. super().__init__(config)
  572. self.embeddings = LxmertEmbeddings(config)
  573. self.encoder = LxmertEncoder(config)
  574. self.pooler = LxmertPooler(config)
  575. # Initialize weights and apply final processing
  576. self.post_init()
  577. def get_input_embeddings(self):
  578. return self.embeddings.word_embeddings
  579. def set_input_embeddings(self, new_embeddings):
  580. self.embeddings.word_embeddings = new_embeddings
  581. @auto_docstring
  582. def forward(
  583. self,
  584. input_ids: torch.LongTensor | None = None,
  585. visual_feats: torch.FloatTensor | None = None,
  586. visual_pos: torch.FloatTensor | None = None,
  587. attention_mask: torch.FloatTensor | None = None,
  588. visual_attention_mask: torch.FloatTensor | None = None,
  589. token_type_ids: torch.LongTensor | None = None,
  590. inputs_embeds: torch.FloatTensor | None = None,
  591. output_attentions: bool | None = None,
  592. output_hidden_states: bool | None = None,
  593. return_dict: bool | None = None,
  594. **kwargs,
  595. ) -> LxmertModelOutput | tuple[torch.FloatTensor]:
  596. r"""
  597. visual_feats (`torch.FloatTensor` of shape `(batch_size, num_visual_features, visual_feat_dim)`):
  598. This input represents visual features. They ROI pooled object features from bounding boxes using a
  599. faster-RCNN model)
  600. These are currently not provided by the transformers library.
  601. visual_pos (`torch.FloatTensor` of shape `(batch_size, num_visual_features, visual_pos_dim)`):
  602. This input represents spatial features corresponding to their relative (via index) visual features. The
  603. pre-trained LXMERT model expects these spatial features to be normalized bounding boxes on a scale of 0 to
  604. 1.
  605. These are currently not provided by the transformers library.
  606. visual_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  607. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  608. - 1 for tokens that are **not masked**,
  609. - 0 for tokens that are **masked**.
  610. [What are attention masks?](../glossary#attention-mask)
  611. """
  612. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  613. output_hidden_states = (
  614. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  615. )
  616. return_dict = return_dict if return_dict is not None else self.config.return_dict
  617. if input_ids is not None and inputs_embeds is not None:
  618. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  619. elif input_ids is not None:
  620. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  621. input_shape = input_ids.size()
  622. elif inputs_embeds is not None:
  623. input_shape = inputs_embeds.size()[:-1]
  624. else:
  625. raise ValueError("You have to specify either input_ids or inputs_embeds")
  626. if visual_feats is None:
  627. raise ValueError("`visual_feats` cannot be `None`")
  628. if visual_pos is None:
  629. raise ValueError("`visual_pos` cannot be `None`")
  630. device = input_ids.device if input_ids is not None else inputs_embeds.device
  631. if attention_mask is None:
  632. attention_mask = torch.ones(input_shape, device=device)
  633. if token_type_ids is None:
  634. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  635. # We create a 3D attention mask from a 2D tensor mask.
  636. # Sizes are [batch_size, 1, 1, to_seq_length]
  637. # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
  638. # this attention mask is more simple than the triangular masking of causal attention
  639. # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
  640. extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
  641. # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
  642. # masked positions, this operation will create a tensor which is 0.0 for
  643. # positions we want to attend and the dtype's smallest value for masked positions.
  644. # Since we are adding it to the raw scores before the softmax, this is
  645. # effectively the same as removing these entirely.
  646. extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
  647. extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
  648. # Process the visual attention mask
  649. if visual_attention_mask is not None:
  650. extended_visual_attention_mask = visual_attention_mask.unsqueeze(1).unsqueeze(2)
  651. extended_visual_attention_mask = extended_visual_attention_mask.to(dtype=self.dtype)
  652. extended_visual_attention_mask = (1.0 - extended_visual_attention_mask) * torch.finfo(self.dtype).min
  653. else:
  654. extended_visual_attention_mask = None
  655. # Positional Word Embeddings
  656. embedding_output = self.embeddings(input_ids, token_type_ids, inputs_embeds)
  657. # Run Lxmert encoder
  658. encoder_outputs = self.encoder(
  659. embedding_output,
  660. extended_attention_mask,
  661. visual_feats=visual_feats,
  662. visual_pos=visual_pos,
  663. visual_attention_mask=extended_visual_attention_mask,
  664. output_attentions=output_attentions,
  665. )
  666. visual_encoder_outputs, lang_encoder_outputs = encoder_outputs[:2]
  667. vision_hidden_states = visual_encoder_outputs[0]
  668. language_hidden_states = lang_encoder_outputs[0]
  669. all_attentions = ()
  670. if output_attentions:
  671. language_attentions = lang_encoder_outputs[1]
  672. vision_attentions = visual_encoder_outputs[1]
  673. cross_encoder_attentions = encoder_outputs[2]
  674. all_attentions = (
  675. language_attentions,
  676. vision_attentions,
  677. cross_encoder_attentions,
  678. )
  679. hidden_states = (language_hidden_states, vision_hidden_states) if output_hidden_states else ()
  680. visual_output = vision_hidden_states[-1]
  681. lang_output = language_hidden_states[-1]
  682. pooled_output = self.pooler(lang_output)
  683. if not return_dict:
  684. return (lang_output, visual_output, pooled_output) + hidden_states + all_attentions
  685. return LxmertModelOutput(
  686. pooled_output=pooled_output,
  687. language_output=lang_output,
  688. vision_output=visual_output,
  689. language_hidden_states=language_hidden_states if output_hidden_states else None,
  690. vision_hidden_states=vision_hidden_states if output_hidden_states else None,
  691. language_attentions=language_attentions if output_attentions else None,
  692. vision_attentions=vision_attentions if output_attentions else None,
  693. cross_encoder_attentions=cross_encoder_attentions if output_attentions else None,
  694. )
  695. @auto_docstring
  696. class LxmertForPreTraining(LxmertPreTrainedModel):
  697. # help saving them
  698. _tied_weights_keys = {
  699. "cls.predictions.decoder.weight": "lxmert.embeddings.word_embeddings.weight",
  700. }
  701. def __init__(self, config):
  702. super().__init__(config)
  703. # Configuration
  704. self.config = config
  705. self.num_qa_labels = config.num_qa_labels
  706. self.visual_loss_normalizer = config.visual_loss_normalizer
  707. # Use of pretraining tasks
  708. self.task_mask_lm = config.task_mask_lm
  709. self.task_obj_predict = config.task_obj_predict
  710. self.task_matched = config.task_matched
  711. self.task_qa = config.task_qa
  712. # Lxmert backbone
  713. self.lxmert = LxmertModel(config)
  714. # Pre-training heads
  715. self.cls = LxmertPreTrainingHeads(config)
  716. if self.task_obj_predict:
  717. self.obj_predict_head = LxmertVisualObjHead(config)
  718. if self.task_qa:
  719. self.answer_head = LxmertVisualAnswerHead(config, self.num_qa_labels)
  720. # Weight initialization
  721. # Initialize weights and apply final processing
  722. self.post_init()
  723. # Loss functions
  724. self.loss_fcts = {
  725. "l2": SmoothL1Loss(reduction="none"),
  726. "visual_ce": CrossEntropyLoss(reduction="none"),
  727. "ce": CrossEntropyLoss(),
  728. }
  729. visual_losses = {}
  730. if config.visual_obj_loss:
  731. visual_losses["obj"] = {
  732. "shape": (-1,),
  733. "num": config.num_object_labels,
  734. "loss": "visual_ce",
  735. }
  736. if config.visual_attr_loss:
  737. visual_losses["attr"] = {
  738. "shape": (-1,),
  739. "num": config.num_attr_labels,
  740. "loss": "visual_ce",
  741. }
  742. if config.visual_feat_loss:
  743. visual_losses["feat"] = {
  744. "shape": (-1, config.visual_feat_dim),
  745. "num": config.visual_feat_dim,
  746. "loss": "l2",
  747. }
  748. self.visual_losses = visual_losses
  749. def resize_token_embeddings(
  750. self, new_num_tokens: int, pad_to_multiple_of: int | None = None, mean_resizing: bool = True
  751. ) -> nn.Embedding:
  752. # Adding the following steps to resize bias to match the shape of resized embeddings
  753. new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
  754. self.cls.predictions.bias = self._resize_bias(self.cls.predictions.bias, new_num_tokens)
  755. return new_embeddings
  756. def _resize_bias(self, bias, new_num_tokens: int):
  757. old_num_tokens = bias.shape[0]
  758. if new_num_tokens <= old_num_tokens:
  759. new_bias = bias[:new_num_tokens]
  760. else:
  761. extra_bias = torch.zeros(new_num_tokens - old_num_tokens, device=bias.device)
  762. new_bias = torch.cat([bias, extra_bias])
  763. new_bias = nn.Parameter(new_bias)
  764. return new_bias
  765. def resize_num_qa_labels(self, num_labels):
  766. """
  767. Build a resized question answering linear layer Module from a provided new linear layer. Increasing the size
  768. will add newly initialized weights. Reducing the size will remove weights from the end
  769. Args:
  770. num_labels (`int`, *optional*):
  771. New number of labels in the linear layer weight matrix. Increasing the size will add newly initialized
  772. weights at the end. Reducing the size will remove weights from the end. If not provided or `None`, just
  773. returns a pointer to the qa labels ``torch.nn.Linear``` module of the model without doing anything.
  774. Return:
  775. `torch.nn.Linear`: Pointer to the resized Linear layer or the old Linear layer
  776. """
  777. cur_qa_logit_layer = self.get_qa_logit_layer()
  778. if num_labels is None or cur_qa_logit_layer is None:
  779. return
  780. new_qa_logit_layer = self._resize_qa_labels(num_labels)
  781. self.config.num_qa_labels = num_labels
  782. self.num_qa_labels = num_labels
  783. return new_qa_logit_layer
  784. def _resize_qa_labels(self, num_labels):
  785. cur_qa_logit_layer = self.get_qa_logit_layer()
  786. new_qa_logit_layer = self._get_resized_qa_labels(cur_qa_logit_layer, num_labels)
  787. self._set_qa_logit_layer(new_qa_logit_layer)
  788. return self.get_qa_logit_layer()
  789. def get_qa_logit_layer(self) -> nn.Module:
  790. """
  791. Returns the linear layer that produces question answering logits.
  792. Returns:
  793. `nn.Module`: A torch module mapping the question answering prediction hidden states or `None` if LXMERT
  794. does not have a visual answering head.
  795. """
  796. if hasattr(self, "answer_head"):
  797. return self.answer_head.logit_fc[-1]
  798. def _set_qa_logit_layer(self, qa_logit_layer):
  799. self.answer_head.logit_fc[-1] = qa_logit_layer
  800. def _get_resized_qa_labels(self, cur_qa_logit_layer, num_labels):
  801. if num_labels is None:
  802. return cur_qa_logit_layer
  803. cur_qa_labels, hidden_dim = cur_qa_logit_layer.weight.size()
  804. if cur_qa_labels == num_labels:
  805. return cur_qa_logit_layer
  806. # Build new linear output
  807. if getattr(cur_qa_logit_layer, "bias", None) is not None:
  808. new_qa_logit_layer = nn.Linear(hidden_dim, num_labels)
  809. else:
  810. new_qa_logit_layer = nn.Linear(hidden_dim, num_labels, bias=False)
  811. new_qa_logit_layer.to(cur_qa_logit_layer.weight.device)
  812. # initialize all new labels
  813. self._init_weights(new_qa_logit_layer)
  814. # Copy labels from the previous weights
  815. num_labels_to_copy = min(cur_qa_labels, num_labels)
  816. new_qa_logit_layer.weight.data[:num_labels_to_copy, :] = cur_qa_logit_layer.weight.data[:num_labels_to_copy, :]
  817. if getattr(cur_qa_logit_layer, "bias", None) is not None:
  818. new_qa_logit_layer.bias.data[:num_labels_to_copy] = cur_qa_logit_layer.bias.data[:num_labels_to_copy]
  819. return new_qa_logit_layer
  820. @auto_docstring
  821. def forward(
  822. self,
  823. input_ids: torch.LongTensor | None = None,
  824. visual_feats: torch.FloatTensor | None = None,
  825. visual_pos: torch.FloatTensor | None = None,
  826. attention_mask: torch.FloatTensor | None = None,
  827. visual_attention_mask: torch.FloatTensor | None = None,
  828. token_type_ids: torch.LongTensor | None = None,
  829. inputs_embeds: torch.FloatTensor | None = None,
  830. labels: torch.LongTensor | None = None,
  831. obj_labels: dict[str, tuple[torch.FloatTensor, torch.FloatTensor]] | None = None,
  832. matched_label: torch.LongTensor | None = None,
  833. ans: torch.Tensor | None = None,
  834. output_attentions: bool | None = None,
  835. output_hidden_states: bool | None = None,
  836. return_dict: bool | None = None,
  837. **kwargs,
  838. ) -> LxmertForPreTrainingOutput | tuple[torch.FloatTensor]:
  839. r"""
  840. visual_feats (`torch.FloatTensor` of shape `(batch_size, num_visual_features, visual_feat_dim)`):
  841. This input represents visual features. They ROI pooled object features from bounding boxes using a
  842. faster-RCNN model)
  843. These are currently not provided by the transformers library.
  844. visual_pos (`torch.FloatTensor` of shape `(batch_size, num_visual_features, visual_pos_dim)`):
  845. This input represents spatial features corresponding to their relative (via index) visual features. The
  846. pre-trained LXMERT model expects these spatial features to be normalized bounding boxes on a scale of 0 to
  847. 1.
  848. These are currently not provided by the transformers library.
  849. visual_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  850. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  851. - 1 for tokens that are **not masked**,
  852. - 0 for tokens that are **masked**.
  853. [What are attention masks?](../glossary#attention-mask)
  854. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  855. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  856. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  857. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  858. obj_labels (`dict[Str: tuple[Torch.FloatTensor, Torch.FloatTensor]]`, *optional*):
  859. each key is named after each one of the visual losses and each element of the tuple is of the shape
  860. `(batch_size, num_features)` and `(batch_size, num_features, visual_feature_dim)` for each the label id and
  861. the label score respectively
  862. matched_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  863. Labels for computing the whether or not the text input matches the image (classification) loss. Input
  864. should be a sequence pair (see `input_ids` docstring) Indices should be in `[0, 1]`:
  865. - 0 indicates that the sentence does not match the image,
  866. - 1 indicates that the sentence does match the image.
  867. ans (`Torch.Tensor` of shape `(batch_size)`, *optional*):
  868. a one hot representation hof the correct answer *optional*
  869. """
  870. return_dict = return_dict if return_dict is not None else self.config.return_dict
  871. device = input_ids.device if input_ids is not None else inputs_embeds.device
  872. lxmert_output = self.lxmert(
  873. input_ids=input_ids,
  874. visual_feats=visual_feats,
  875. visual_pos=visual_pos,
  876. token_type_ids=token_type_ids,
  877. attention_mask=attention_mask,
  878. visual_attention_mask=visual_attention_mask,
  879. inputs_embeds=inputs_embeds,
  880. output_hidden_states=output_hidden_states,
  881. output_attentions=output_attentions,
  882. return_dict=return_dict,
  883. )
  884. lang_output, visual_output, pooled_output = (
  885. lxmert_output[0],
  886. lxmert_output[1],
  887. lxmert_output[2],
  888. )
  889. lang_prediction_scores, cross_relationship_score = self.cls(lang_output, pooled_output)
  890. if self.task_qa:
  891. answer_score = self.answer_head(pooled_output)
  892. else:
  893. answer_score = pooled_output[0][0]
  894. total_loss = (
  895. None
  896. if (labels is None and matched_label is None and obj_labels is None and ans is None)
  897. else torch.tensor(0.0, device=device)
  898. )
  899. if labels is not None and self.task_mask_lm:
  900. masked_lm_loss = self.loss_fcts["ce"](
  901. lang_prediction_scores.view(-1, self.config.vocab_size),
  902. labels.view(-1),
  903. )
  904. total_loss += masked_lm_loss
  905. if matched_label is not None and self.task_matched:
  906. matched_loss = self.loss_fcts["ce"](cross_relationship_score.view(-1, 2), matched_label.view(-1))
  907. total_loss += matched_loss
  908. if obj_labels is not None and self.task_obj_predict:
  909. total_visual_loss = torch.tensor(0.0, device=input_ids.device)
  910. visual_prediction_scores_dict = self.obj_predict_head(visual_output)
  911. for key, key_info in self.visual_losses.items():
  912. label, mask_conf = obj_labels[key]
  913. output_dim = key_info["num"]
  914. loss_fct_name = key_info["loss"]
  915. label_shape = key_info["shape"]
  916. weight = self.visual_loss_normalizer
  917. visual_loss_fct = self.loss_fcts[loss_fct_name]
  918. visual_prediction_scores = visual_prediction_scores_dict[key]
  919. visual_loss = visual_loss_fct(
  920. visual_prediction_scores.view(-1, output_dim),
  921. label.view(label_shape),
  922. )
  923. if visual_loss.dim() > 1: # Regression Losses
  924. visual_loss = visual_loss.mean(1)
  925. visual_loss = (visual_loss * mask_conf.view(-1)).mean() * weight
  926. total_visual_loss += visual_loss
  927. total_loss += total_visual_loss
  928. if ans is not None and self.task_qa:
  929. answer_loss = self.loss_fcts["ce"](answer_score.view(-1, self.num_qa_labels), ans.view(-1))
  930. total_loss += answer_loss
  931. if not return_dict:
  932. output = (
  933. lang_prediction_scores,
  934. cross_relationship_score,
  935. answer_score,
  936. ) + lxmert_output[3:]
  937. return ((total_loss,) + output) if total_loss is not None else output
  938. return LxmertForPreTrainingOutput(
  939. loss=total_loss,
  940. prediction_logits=lang_prediction_scores,
  941. cross_relationship_score=cross_relationship_score,
  942. question_answering_score=answer_score,
  943. language_hidden_states=lxmert_output.language_hidden_states,
  944. vision_hidden_states=lxmert_output.vision_hidden_states,
  945. language_attentions=lxmert_output.language_attentions,
  946. vision_attentions=lxmert_output.vision_attentions,
  947. cross_encoder_attentions=lxmert_output.cross_encoder_attentions,
  948. )
  949. @auto_docstring(
  950. custom_intro="""
  951. Lxmert Model with a visual-answering head on top for downstream QA tasks
  952. """
  953. )
  954. class LxmertForQuestionAnswering(LxmertPreTrainedModel):
  955. def __init__(self, config):
  956. super().__init__(config)
  957. # Configuration
  958. self.config = config
  959. self.num_qa_labels = config.num_qa_labels
  960. self.visual_loss_normalizer = config.visual_loss_normalizer
  961. # Lxmert backbone
  962. self.lxmert = LxmertModel(config)
  963. self.answer_head = LxmertVisualAnswerHead(config, self.num_qa_labels)
  964. # Weight initialization
  965. # Initialize weights and apply final processing
  966. self.post_init()
  967. # Loss function
  968. self.loss = CrossEntropyLoss()
  969. def resize_num_qa_labels(self, num_labels):
  970. """
  971. Build a resized question answering linear layer Module from a provided new linear layer. Increasing the size
  972. will add newly initialized weights. Reducing the size will remove weights from the end
  973. Args:
  974. num_labels (`int`, *optional*):
  975. New number of labels in the linear layer weight matrix. Increasing the size will add newly initialized
  976. weights at the end. Reducing the size will remove weights from the end. If not provided or `None`, just
  977. returns a pointer to the qa labels ``torch.nn.Linear``` module of the model without doing anything.
  978. Return:
  979. `torch.nn.Linear`: Pointer to the resized Linear layer or the old Linear layer
  980. """
  981. cur_qa_logit_layer = self.get_qa_logit_layer()
  982. if num_labels is None or cur_qa_logit_layer is None:
  983. return
  984. new_qa_logit_layer = self._resize_qa_labels(num_labels)
  985. self.config.num_qa_labels = num_labels
  986. self.num_qa_labels = num_labels
  987. return new_qa_logit_layer
  988. def _resize_qa_labels(self, num_labels):
  989. cur_qa_logit_layer = self.get_qa_logit_layer()
  990. new_qa_logit_layer = self._get_resized_qa_labels(cur_qa_logit_layer, num_labels)
  991. self._set_qa_logit_layer(new_qa_logit_layer)
  992. return self.get_qa_logit_layer()
  993. def get_qa_logit_layer(self) -> nn.Module:
  994. """
  995. Returns the linear layer that produces question answering logits
  996. Returns:
  997. `nn.Module`: A torch module mapping the question answering prediction hidden states. `None`: A NoneType
  998. object if Lxmert does not have the visual answering head.
  999. """
  1000. if hasattr(self, "answer_head"):
  1001. return self.answer_head.logit_fc[-1]
  1002. def _set_qa_logit_layer(self, qa_logit_layer):
  1003. self.answer_head.logit_fc[-1] = qa_logit_layer
  1004. def _get_resized_qa_labels(self, cur_qa_logit_layer, num_labels):
  1005. if num_labels is None:
  1006. return cur_qa_logit_layer
  1007. cur_qa_labels, hidden_dim = cur_qa_logit_layer.weight.size()
  1008. if cur_qa_labels == num_labels:
  1009. return cur_qa_logit_layer
  1010. # Build new linear output
  1011. if getattr(cur_qa_logit_layer, "bias", None) is not None:
  1012. new_qa_logit_layer = nn.Linear(hidden_dim, num_labels)
  1013. else:
  1014. new_qa_logit_layer = nn.Linear(hidden_dim, num_labels, bias=False)
  1015. new_qa_logit_layer.to(cur_qa_logit_layer.weight.device)
  1016. # initialize all new labels
  1017. self._init_weights(new_qa_logit_layer)
  1018. # Copy labels from the previous weights
  1019. num_labels_to_copy = min(cur_qa_labels, num_labels)
  1020. new_qa_logit_layer.weight.data[:num_labels_to_copy, :] = cur_qa_logit_layer.weight.data[:num_labels_to_copy, :]
  1021. if getattr(cur_qa_logit_layer, "bias", None) is not None:
  1022. new_qa_logit_layer.bias.data[:num_labels_to_copy] = cur_qa_logit_layer.bias.data[:num_labels_to_copy]
  1023. return new_qa_logit_layer
  1024. @auto_docstring
  1025. def forward(
  1026. self,
  1027. input_ids: torch.LongTensor | None = None,
  1028. visual_feats: torch.FloatTensor | None = None,
  1029. visual_pos: torch.FloatTensor | None = None,
  1030. attention_mask: torch.FloatTensor | None = None,
  1031. visual_attention_mask: torch.FloatTensor | None = None,
  1032. token_type_ids: torch.LongTensor | None = None,
  1033. inputs_embeds: torch.FloatTensor | None = None,
  1034. labels: torch.Tensor | None = None,
  1035. output_attentions: bool | None = None,
  1036. output_hidden_states: bool | None = None,
  1037. return_dict: bool | None = None,
  1038. **kwargs,
  1039. ) -> LxmertForQuestionAnsweringOutput | tuple[torch.FloatTensor]:
  1040. r"""
  1041. visual_feats (`torch.FloatTensor` of shape `(batch_size, num_visual_features, visual_feat_dim)`):
  1042. This input represents visual features. They ROI pooled object features from bounding boxes using a
  1043. faster-RCNN model)
  1044. These are currently not provided by the transformers library.
  1045. visual_pos (`torch.FloatTensor` of shape `(batch_size, num_visual_features, visual_pos_dim)`):
  1046. This input represents spatial features corresponding to their relative (via index) visual features. The
  1047. pre-trained LXMERT model expects these spatial features to be normalized bounding boxes on a scale of 0 to
  1048. 1.
  1049. These are currently not provided by the transformers library.
  1050. visual_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1051. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  1052. - 1 for tokens that are **not masked**,
  1053. - 0 for tokens that are **masked**.
  1054. [What are attention masks?](../glossary#attention-mask)
  1055. labels (`Torch.Tensor` of shape `(batch_size)`, *optional*):
  1056. A one-hot representation of the correct answer
  1057. """
  1058. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1059. lxmert_output = self.lxmert(
  1060. input_ids=input_ids,
  1061. visual_feats=visual_feats,
  1062. visual_pos=visual_pos,
  1063. token_type_ids=token_type_ids,
  1064. attention_mask=attention_mask,
  1065. visual_attention_mask=visual_attention_mask,
  1066. inputs_embeds=inputs_embeds,
  1067. output_hidden_states=output_hidden_states,
  1068. output_attentions=output_attentions,
  1069. return_dict=return_dict,
  1070. )
  1071. pooled_output = lxmert_output[2]
  1072. answer_score = self.answer_head(pooled_output)
  1073. loss = None
  1074. if labels is not None:
  1075. loss = self.loss(answer_score.view(-1, self.num_qa_labels), labels.view(-1))
  1076. if not return_dict:
  1077. output = (answer_score,) + lxmert_output[3:]
  1078. return (loss,) + output if loss is not None else output
  1079. return LxmertForQuestionAnsweringOutput(
  1080. loss=loss,
  1081. question_answering_score=answer_score,
  1082. language_hidden_states=lxmert_output.language_hidden_states,
  1083. vision_hidden_states=lxmert_output.vision_hidden_states,
  1084. language_attentions=lxmert_output.language_attentions,
  1085. vision_attentions=lxmert_output.vision_attentions,
  1086. cross_encoder_attentions=lxmert_output.cross_encoder_attentions,
  1087. )
# Public API of this module: the names exported by `from ... import *`.
__all__ = [
    "LxmertEncoder",
    "LxmertForPreTraining",
    "LxmertForQuestionAnswering",
    "LxmertModel",
    "LxmertPreTrainedModel",
    "LxmertVisualFeatureEncoder",
    "LxmertXLayer",
]