modeling_vilt.py 53 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286
  1. # Copyright 2022 NAVER AI Labs and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch ViLT model."""
  15. import collections.abc
  16. import math
  17. from dataclasses import dataclass
  18. import torch
  19. from torch import nn
  20. from torch.nn import CrossEntropyLoss
  21. from ... import initialization as init
  22. from ...activations import ACT2FN
  23. from ...modeling_layers import GradientCheckpointingLayer
  24. from ...modeling_outputs import (
  25. BaseModelOutput,
  26. BaseModelOutputWithPooling,
  27. MaskedLMOutput,
  28. ModelOutput,
  29. SequenceClassifierOutput,
  30. TokenClassifierOutput,
  31. )
  32. from ...modeling_utils import PreTrainedModel
  33. from ...utils import auto_docstring, logging
  34. from .configuration_vilt import ViltConfig
# Module-level logger, obtained from the library's logging utilities and keyed by this module's name.
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
    Class for outputs of [`ViltForImagesAndTextClassification`].
    """
)
class ViltForImagesAndTextClassificationOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Classification (or regression if config.num_labels==1) loss.
    logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
        Classification (or regression if config.num_labels==1) scores (before SoftMax).
    hidden_states (`list[tuple(torch.FloatTensor)]`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        List of tuples of `torch.FloatTensor` (one for each image-text pair, each tuple containing the output of
        the embeddings + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
        Hidden-states of the model at the output of each layer plus the initial embedding outputs.
    attentions (`list[tuple(torch.FloatTensor)]`, *optional*):
        List of tuples of `torch.FloatTensor` holding attention weights (presumably one tuple per image-text
        pair, mirroring `hidden_states` -- to be confirmed against the model that fills this field).
    """

    # Every field defaults to `None`, so an output object can be built with any subset of them.
    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    hidden_states: list[tuple[torch.FloatTensor]] | None = None
    attentions: list[tuple[torch.FloatTensor]] | None = None
  57. class ViltEmbeddings(nn.Module):
  58. """
  59. Construct the text and patch embeddings.
  60. Text embeddings are equivalent to BERT embeddings.
  61. Patch embeddings are equivalent to ViT embeddings.
  62. """
  63. def __init__(self, config):
  64. super().__init__()
  65. # text embeddings
  66. self.text_embeddings = TextEmbeddings(config)
  67. # patch embeddings
  68. self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  69. self.patch_embeddings = ViltPatchEmbeddings(config)
  70. num_patches = self.patch_embeddings.num_patches
  71. self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
  72. # modality type (text/patch) embeddings
  73. self.token_type_embeddings = nn.Embedding(config.modality_type_vocab_size, config.hidden_size)
  74. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  75. self.config = config
  76. def visual_embed(self, pixel_values, pixel_mask, max_image_length=200):
  77. _, _, ph, pw = self.patch_embeddings.projection.weight.shape
  78. x = self.patch_embeddings(pixel_values)
  79. x_mask = pixel_mask[:, None, :, :].float()
  80. x_mask = nn.functional.interpolate(x_mask, size=(x.shape[2], x.shape[3])).long()
  81. x_h = x_mask[:, 0].sum(dim=1)[:, 0]
  82. x_w = x_mask[:, 0].sum(dim=2)[:, 0]
  83. batch_size, num_channels, height, width = x.shape
  84. patch_dim = self.config.image_size // self.config.patch_size
  85. spatial_pos = self.position_embeddings[:, 1:, :].transpose(1, 2).view(1, num_channels, patch_dim, patch_dim)
  86. pos_embed = torch.cat(
  87. [
  88. nn.functional.pad(
  89. nn.functional.interpolate(
  90. spatial_pos,
  91. size=(h, w),
  92. mode="bilinear",
  93. align_corners=True,
  94. ),
  95. (0, width - w, 0, height - h),
  96. )
  97. for h, w in zip(x_h, x_w)
  98. ],
  99. dim=0,
  100. )
  101. pos_embed = pos_embed.flatten(2).transpose(1, 2)
  102. x = x.flatten(2).transpose(1, 2)
  103. # Set `device` here, otherwise `patch_index` will always be on `CPU` and will fail near the end for torch>=1.13
  104. patch_index = torch.stack(
  105. torch.meshgrid(torch.arange(x_mask.shape[-2]), torch.arange(x_mask.shape[-1]), indexing="ij"), dim=-1
  106. ).to(device=x_mask.device)
  107. patch_index = patch_index[None, None, :, :, :]
  108. patch_index = patch_index.expand(x_mask.shape[0], x_mask.shape[1], -1, -1, -1)
  109. patch_index = patch_index.flatten(1, 3)
  110. x_mask = x_mask.flatten(1)
  111. if max_image_length < 0 or max_image_length is None or not isinstance(max_image_length, int):
  112. # suppose aug is 800 x 1333, then, maximum effective res is 800 x 1333 (if one side gets bigger, the other will be constrained and be shrunk)
  113. # (800 // self.patch_size) * (1333 // self.patch_size) is the maximum number of patches that single image can get.
  114. # if self.patch_size = 32, 25 * 41 = 1025
  115. # if res is 384 x 640, 12 * 20 = 240
  116. effective_resolution = x_h * x_w
  117. max_image_length = effective_resolution.max()
  118. else:
  119. effective_resolution = x_h * x_w
  120. max_image_length = min(effective_resolution.max(), max_image_length)
  121. valid_idx = x_mask.nonzero(as_tuple=False)
  122. non_valid_idx = (1 - x_mask).nonzero(as_tuple=False)
  123. unique_rows = valid_idx[:, 0].unique()
  124. valid_row_idx = [valid_idx[valid_idx[:, 0] == u] for u in unique_rows]
  125. non_valid_row_idx = [non_valid_idx[non_valid_idx[:, 0] == u] for u in unique_rows]
  126. valid_nums = [v.size(0) for v in valid_row_idx]
  127. non_valid_nums = [v.size(0) for v in non_valid_row_idx]
  128. pad_nums = [max_image_length - v for v in valid_nums]
  129. select = []
  130. for i, (v, nv, p) in enumerate(zip(valid_nums, non_valid_nums, pad_nums)):
  131. if p <= 0:
  132. valid_choice = torch.multinomial(torch.ones(v).float(), max_image_length)
  133. select.append(valid_row_idx[i][valid_choice])
  134. else:
  135. pad_choice = torch.multinomial(torch.ones(nv).float(), p, replacement=True)
  136. select.append(torch.cat([valid_row_idx[i], non_valid_row_idx[i][pad_choice]], dim=0))
  137. select = torch.cat(select, dim=0)
  138. x = x[select[:, 0], select[:, 1]].view(batch_size, -1, num_channels)
  139. x_mask = x_mask[select[:, 0], select[:, 1]].view(batch_size, -1)
  140. # `patch_index` should be on the same device as `select`, which is ensured at definition time.
  141. patch_index = patch_index[select[:, 0], select[:, 1]].view(batch_size, -1, 2)
  142. pos_embed = pos_embed[select[:, 0], select[:, 1]].view(batch_size, -1, num_channels)
  143. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  144. x = torch.cat((cls_tokens, x), dim=1)
  145. pos_embed = torch.cat(
  146. (self.position_embeddings[:, 0, :][:, None, :].expand(batch_size, -1, -1), pos_embed), dim=1
  147. )
  148. x = x + pos_embed
  149. x = self.dropout(x)
  150. x_mask = torch.cat([torch.ones(x_mask.shape[0], 1).to(x_mask), x_mask], dim=1)
  151. return x, x_mask, (patch_index, (height, width))
  152. def forward(
  153. self,
  154. input_ids,
  155. attention_mask,
  156. token_type_ids,
  157. pixel_values,
  158. pixel_mask,
  159. inputs_embeds,
  160. image_embeds,
  161. image_token_type_idx=1,
  162. ):
  163. # PART 1: text embeddings
  164. text_embeds = self.text_embeddings(
  165. input_ids=input_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
  166. )
  167. # PART 2: patch embeddings (with interpolated position encodings)
  168. if image_embeds is None:
  169. image_embeds, image_masks, patch_index = self.visual_embed(
  170. pixel_values, pixel_mask, max_image_length=self.config.max_image_length
  171. )
  172. else:
  173. image_masks = pixel_mask.flatten(1)
  174. # PART 3: add modality type embeddings
  175. # 0 indicates text, 1 indicates image, 2 is optionally used when a second image is provided (NLVR2)
  176. if image_token_type_idx is None:
  177. image_token_type_idx = 1
  178. text_embeds = text_embeds + self.token_type_embeddings(
  179. torch.zeros_like(attention_mask, dtype=torch.long, device=text_embeds.device)
  180. )
  181. image_embeds = image_embeds + self.token_type_embeddings(
  182. torch.full_like(image_masks, image_token_type_idx, dtype=torch.long, device=text_embeds.device)
  183. )
  184. # PART 4: concatenate
  185. embeddings = torch.cat([text_embeds, image_embeds], dim=1)
  186. masks = torch.cat([attention_mask, image_masks], dim=1)
  187. return embeddings, masks
  188. class TextEmbeddings(nn.Module):
  189. """Construct the embeddings from word, position and token_type embeddings."""
  190. def __init__(self, config):
  191. super().__init__()
  192. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  193. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  194. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  195. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  196. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  197. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  198. self.register_buffer(
  199. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  200. )
  201. self.register_buffer(
  202. "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
  203. )
  204. def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
  205. if input_ids is not None:
  206. input_shape = input_ids.size()
  207. else:
  208. input_shape = inputs_embeds.size()[:-1]
  209. seq_length = input_shape[1]
  210. if position_ids is None:
  211. position_ids = self.position_ids[:, :seq_length]
  212. # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
  213. # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
  214. # issue #5664
  215. if token_type_ids is None:
  216. if hasattr(self, "token_type_ids"):
  217. buffered_token_type_ids = self.token_type_ids[:, :seq_length]
  218. buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
  219. token_type_ids = buffered_token_type_ids_expanded
  220. else:
  221. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  222. if inputs_embeds is None:
  223. inputs_embeds = self.word_embeddings(input_ids)
  224. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  225. embeddings = inputs_embeds + token_type_embeddings
  226. position_embeddings = self.position_embeddings(position_ids)
  227. embeddings += position_embeddings
  228. embeddings = self.LayerNorm(embeddings)
  229. embeddings = self.dropout(embeddings)
  230. return embeddings
  231. class ViltPatchEmbeddings(nn.Module):
  232. """
  233. Image to Patch Embedding.
  234. """
  235. def __init__(self, config):
  236. super().__init__()
  237. image_size, patch_size = config.image_size, config.patch_size
  238. num_channels, hidden_size = config.num_channels, config.hidden_size
  239. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  240. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  241. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  242. self.image_size = image_size
  243. self.patch_size = patch_size
  244. self.num_channels = num_channels
  245. self.num_patches = num_patches
  246. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
  247. def forward(self, pixel_values):
  248. batch_size, num_channels, height, width = pixel_values.shape
  249. if num_channels != self.num_channels:
  250. raise ValueError(
  251. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  252. )
  253. target_dtype = self.projection.weight.dtype
  254. x = self.projection(pixel_values.to(dtype=target_dtype))
  255. return x
  256. class ViltSelfAttention(nn.Module):
  257. def __init__(self, config):
  258. super().__init__()
  259. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  260. raise ValueError(
  261. f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
  262. f"heads {config.num_attention_heads}."
  263. )
  264. self.num_attention_heads = config.num_attention_heads
  265. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  266. self.all_head_size = self.num_attention_heads * self.attention_head_size
  267. self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  268. self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  269. self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  270. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  271. def forward(self, hidden_states, attention_mask=None, output_attentions=False):
  272. input_shape = hidden_states.shape[:-1]
  273. hidden_shape = (*input_shape, -1, self.attention_head_size)
  274. query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  275. key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
  276. value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
  277. # Take the dot product between "query" and "key" to get the raw attention scores.
  278. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  279. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  280. if attention_mask is not None:
  281. # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
  282. attention_scores = attention_scores + attention_mask
  283. # Normalize the attention scores to probabilities.
  284. attention_probs = nn.Softmax(dim=-1)(attention_scores)
  285. # This is actually dropping out entire tokens to attend to, which might
  286. # seem a bit unusual, but is taken from the original Transformer paper.
  287. attention_probs = self.dropout(attention_probs)
  288. context_layer = torch.matmul(attention_probs, value_layer)
  289. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  290. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  291. context_layer = context_layer.view(*new_context_layer_shape)
  292. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  293. return outputs
  294. # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Vilt
  295. class ViltSelfOutput(nn.Module):
  296. """
  297. The residual connection is defined in ViltLayer instead of here (as is the case with other models), due to the
  298. layernorm applied before each block.
  299. """
  300. def __init__(self, config: ViltConfig):
  301. super().__init__()
  302. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  303. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  304. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  305. hidden_states = self.dense(hidden_states)
  306. hidden_states = self.dropout(hidden_states)
  307. return hidden_states
  308. class ViltAttention(nn.Module):
  309. def __init__(self, config):
  310. super().__init__()
  311. self.attention = ViltSelfAttention(config)
  312. self.output = ViltSelfOutput(config)
  313. def forward(self, hidden_states, attention_mask=None, output_attentions=False):
  314. self_outputs = self.attention(hidden_states, attention_mask, output_attentions)
  315. attention_output = self.output(self_outputs[0], hidden_states)
  316. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  317. return outputs
  318. # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->Vilt
  319. class ViltIntermediate(nn.Module):
  320. def __init__(self, config: ViltConfig):
  321. super().__init__()
  322. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  323. if isinstance(config.hidden_act, str):
  324. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  325. else:
  326. self.intermediate_act_fn = config.hidden_act
  327. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  328. hidden_states = self.dense(hidden_states)
  329. hidden_states = self.intermediate_act_fn(hidden_states)
  330. return hidden_states
  331. # Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->Vilt
  332. class ViltOutput(nn.Module):
  333. def __init__(self, config: ViltConfig):
  334. super().__init__()
  335. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  336. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  337. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  338. hidden_states = self.dense(hidden_states)
  339. hidden_states = self.dropout(hidden_states)
  340. hidden_states = hidden_states + input_tensor
  341. return hidden_states
  342. class ViltLayer(GradientCheckpointingLayer):
  343. """This corresponds to the Block class in the timm implementation."""
  344. def __init__(self, config):
  345. super().__init__()
  346. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  347. self.seq_len_dim = 1
  348. self.attention = ViltAttention(config)
  349. self.intermediate = ViltIntermediate(config)
  350. self.output = ViltOutput(config)
  351. self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  352. self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  353. def forward(self, hidden_states, attention_mask=None, output_attentions=False):
  354. self_attention_outputs = self.attention(
  355. self.layernorm_before(hidden_states), # in ViLT, layernorm is applied before self-attention
  356. attention_mask,
  357. output_attentions=output_attentions,
  358. )
  359. attention_output = self_attention_outputs[0]
  360. outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
  361. # first residual connection
  362. hidden_states = attention_output + hidden_states.to(attention_output.device)
  363. # in ViLT, layernorm is also applied after self-attention
  364. layer_output = self.layernorm_after(hidden_states)
  365. layer_output = self.intermediate(layer_output)
  366. # second residual connection is done here
  367. layer_output = self.output(layer_output, hidden_states)
  368. outputs = (layer_output,) + outputs
  369. return outputs
  370. class ViltEncoder(nn.Module):
  371. def __init__(self, config):
  372. super().__init__()
  373. self.config = config
  374. self.layer = nn.ModuleList([ViltLayer(config) for _ in range(config.num_hidden_layers)])
  375. self.gradient_checkpointing = False
  376. def forward(
  377. self,
  378. hidden_states,
  379. attention_mask=None,
  380. output_attentions=False,
  381. output_hidden_states=False,
  382. return_dict=True,
  383. ):
  384. all_hidden_states = () if output_hidden_states else None
  385. all_self_attentions = () if output_attentions else None
  386. for i, layer_module in enumerate(self.layer):
  387. if output_hidden_states:
  388. all_hidden_states = all_hidden_states + (hidden_states,)
  389. layer_outputs = layer_module(hidden_states, attention_mask, output_attentions)
  390. hidden_states = layer_outputs[0]
  391. if output_attentions:
  392. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  393. if output_hidden_states:
  394. all_hidden_states = all_hidden_states + (hidden_states,)
  395. if not return_dict:
  396. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  397. return BaseModelOutput(
  398. last_hidden_state=hidden_states,
  399. hidden_states=all_hidden_states,
  400. attentions=all_self_attentions,
  401. )
  402. @auto_docstring
  403. class ViltPreTrainedModel(PreTrainedModel):
  404. config: ViltConfig
  405. base_model_prefix = "vilt"
  406. input_modalities = ("image", "text")
  407. supports_gradient_checkpointing = True
  408. _no_split_modules = ["ViltEmbeddings", "ViltSelfAttention"]
  409. def _init_weights(self, module):
  410. super()._init_weights(module)
  411. if isinstance(module, TextEmbeddings):
  412. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  413. init.zeros_(module.token_type_ids)
  414. @auto_docstring
  415. class ViltModel(ViltPreTrainedModel):
  416. def __init__(self, config, add_pooling_layer=True):
  417. r"""
  418. add_pooling_layer (bool, *optional*, defaults to `True`):
  419. Whether to add a pooling layer
  420. """
  421. super().__init__(config)
  422. self.config = config
  423. self.embeddings = ViltEmbeddings(config)
  424. self.encoder = ViltEncoder(config)
  425. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  426. self.pooler = ViltPooler(config) if add_pooling_layer else None
  427. # Initialize weights and apply final processing
  428. self.post_init()
  429. def get_input_embeddings(self):
  430. return self.embeddings.text_embeddings.word_embeddings
  431. def set_input_embeddings(self, value):
  432. self.embeddings.text_embeddings.word_embeddings = value
  433. @auto_docstring
  434. def forward(
  435. self,
  436. input_ids: torch.LongTensor | None = None,
  437. attention_mask: torch.FloatTensor | None = None,
  438. token_type_ids: torch.LongTensor | None = None,
  439. pixel_values: torch.FloatTensor | None = None,
  440. pixel_mask: torch.LongTensor | None = None,
  441. inputs_embeds: torch.FloatTensor | None = None,
  442. image_embeds: torch.FloatTensor | None = None,
  443. image_token_type_idx: int | None = None,
  444. output_attentions: bool | None = None,
  445. output_hidden_states: bool | None = None,
  446. return_dict: bool | None = None,
  447. **kwargs,
  448. ) -> BaseModelOutputWithPooling | tuple[torch.FloatTensor]:
  449. r"""
  450. image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
  451. Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
  452. This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
  453. image_token_type_idx (`int`, *optional*):
  454. - The token type ids for images.
  455. Examples:
  456. ```python
  457. >>> from transformers import ViltProcessor, ViltModel
  458. >>> from PIL import Image
  459. >>> import httpx
  460. >>> from io import BytesIO
  461. >>> # prepare image and text
  462. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  463. >>> with httpx.stream("GET", url) as response:
  464. ... image = Image.open(BytesIO(response.read()))
  465. >>> text = "hello world"
  466. >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
  467. >>> model = ViltModel.from_pretrained("dandelin/vilt-b32-mlm")
  468. >>> inputs = processor(image, text, return_tensors="pt")
  469. >>> outputs = model(**inputs)
  470. >>> last_hidden_states = outputs.last_hidden_state
  471. ```"""
  472. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  473. output_hidden_states = (
  474. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  475. )
  476. return_dict = return_dict if return_dict is not None else self.config.return_dict
  477. if input_ids is not None and inputs_embeds is not None:
  478. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  479. elif input_ids is not None:
  480. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  481. input_shape = input_ids.size()
  482. elif inputs_embeds is not None:
  483. input_shape = inputs_embeds.size()[:-1]
  484. else:
  485. raise ValueError("You have to specify either input_ids or inputs_embeds")
  486. text_batch_size, seq_length = input_shape
  487. device = input_ids.device if input_ids is not None else inputs_embeds.device
  488. if attention_mask is None:
  489. attention_mask = torch.ones(((text_batch_size, seq_length)), device=device)
  490. if pixel_values is not None and image_embeds is not None:
  491. raise ValueError("You cannot specify both pixel_values and image_embeds at the same time")
  492. elif pixel_values is None and image_embeds is None:
  493. raise ValueError("You have to specify either pixel_values or image_embeds")
  494. image_batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeds.shape[0]
  495. if image_batch_size != text_batch_size:
  496. raise ValueError("The text inputs and image inputs need to have the same batch size")
  497. if pixel_mask is None:
  498. pixel_mask = torch.ones((image_batch_size, self.config.image_size, self.config.image_size), device=device)
  499. embedding_output, attention_mask = self.embeddings(
  500. input_ids,
  501. attention_mask,
  502. token_type_ids,
  503. pixel_values,
  504. pixel_mask,
  505. inputs_embeds,
  506. image_embeds,
  507. image_token_type_idx=image_token_type_idx,
  508. )
  509. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  510. # ourselves in which case we just need to make it broadcastable to all heads.
  511. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
  512. encoder_outputs = self.encoder(
  513. embedding_output,
  514. attention_mask=extended_attention_mask,
  515. output_attentions=output_attentions,
  516. output_hidden_states=output_hidden_states,
  517. return_dict=return_dict,
  518. )
  519. sequence_output = encoder_outputs[0]
  520. sequence_output = self.layernorm(sequence_output)
  521. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  522. if not return_dict:
  523. return (sequence_output, pooled_output) + encoder_outputs[1:]
  524. return BaseModelOutputWithPooling(
  525. last_hidden_state=sequence_output,
  526. pooler_output=pooled_output,
  527. hidden_states=encoder_outputs.hidden_states,
  528. attentions=encoder_outputs.attentions,
  529. )
  530. class ViltPooler(nn.Module):
  531. def __init__(self, config):
  532. super().__init__()
  533. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  534. self.activation = nn.Tanh()
  535. def forward(self, hidden_states):
  536. # We "pool" the model by simply taking the hidden state corresponding
  537. # to the first token.
  538. first_token_tensor = hidden_states[:, 0]
  539. pooled_output = self.dense(first_token_tensor)
  540. pooled_output = self.activation(pooled_output)
  541. return pooled_output
  542. @auto_docstring(
  543. custom_intro="""
  544. ViLT Model with a language modeling head on top as done during pretraining.
  545. """
  546. )
  547. class ViltForMaskedLM(ViltPreTrainedModel):
  548. _tied_weights_keys = {
  549. "mlm_score.decoder.weight": "vilt.embeddings.text_embeddings.word_embeddings.weight",
  550. }
  551. def __init__(self, config):
  552. super().__init__(config)
  553. self.vilt = ViltModel(config)
  554. self.mlm_score = ViltMLMHead(config)
  555. # Initialize weights and apply final processing
  556. self.post_init()
  557. def get_output_embeddings(self):
  558. return self.mlm_score.decoder
  559. def set_output_embeddings(self, new_embeddings):
  560. self.mlm_score.decoder = new_embeddings
  561. self.mlm_score.bias = new_embeddings.bias
  562. @auto_docstring
  563. def forward(
  564. self,
  565. input_ids: torch.LongTensor | None = None,
  566. attention_mask: torch.FloatTensor | None = None,
  567. token_type_ids: torch.LongTensor | None = None,
  568. pixel_values: torch.FloatTensor | None = None,
  569. pixel_mask: torch.LongTensor | None = None,
  570. inputs_embeds: torch.FloatTensor | None = None,
  571. image_embeds: torch.FloatTensor | None = None,
  572. labels: torch.LongTensor | None = None,
  573. output_attentions: bool | None = None,
  574. output_hidden_states: bool | None = None,
  575. return_dict: bool | None = None,
  576. **kwargs,
  577. ) -> MaskedLMOutput | tuple[torch.FloatTensor]:
  578. r"""
  579. image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
  580. Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
  581. This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
  582. labels (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
  583. Labels for computing the masked language modeling loss. Indices should be in *[-100, 0, ...,
  584. config.vocab_size]* (see *input_ids* docstring) Tokens with indices set to *-100* are ignored (masked), the
  585. loss is only computed for the tokens with labels in *[0, ..., config.vocab_size]*
  586. Examples:
  587. ```python
  588. >>> from transformers import ViltProcessor, ViltForMaskedLM
  589. >>> import httpx
  590. >>> from io import BytesIO
  591. >>> from PIL import Image
  592. >>> import re
  593. >>> import torch
  594. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  595. >>> with httpx.stream("GET", url) as response:
  596. ... image = Image.open(BytesIO(response.read()))
  597. >>> text = "a bunch of [MASK] laying on a [MASK]."
  598. >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
  599. >>> model = ViltForMaskedLM.from_pretrained("dandelin/vilt-b32-mlm")
  600. >>> # prepare inputs
  601. >>> encoding = processor(image, text, return_tensors="pt")
  602. >>> # forward pass
  603. >>> outputs = model(**encoding)
  604. >>> tl = len(re.findall("\[MASK\]", text))
  605. >>> inferred_token = [text]
  606. >>> # gradually fill in the MASK tokens, one by one
  607. >>> with torch.no_grad():
  608. ... for i in range(tl):
  609. ... encoded = processor.tokenizer(inferred_token)
  610. ... input_ids = torch.tensor(encoded.input_ids)
  611. ... encoded = encoded["input_ids"][0][1:-1]
  612. ... outputs = model(input_ids=input_ids, pixel_values=encoding.pixel_values)
  613. ... mlm_logits = outputs.logits[0] # shape (seq_len, vocab_size)
  614. ... # only take into account text features (minus CLS and SEP token)
  615. ... mlm_logits = mlm_logits[1 : input_ids.shape[1] - 1, :]
  616. ... mlm_values, mlm_ids = mlm_logits.softmax(dim=-1).max(dim=-1)
  617. ... # only take into account text
  618. ... mlm_values[torch.tensor(encoded) != 103] = 0
  619. ... select = mlm_values.argmax().item()
  620. ... encoded[select] = mlm_ids[select].item()
  621. ... inferred_token = [processor.decode(encoded)]
  622. >>> selected_token = ""
  623. >>> encoded = processor.tokenizer(inferred_token)
  624. >>> output = processor.decode(encoded.input_ids[0], skip_special_tokens=True)
  625. >>> print(output)
  626. a bunch of cats laying on a couch.
  627. ```"""
  628. return_dict = return_dict if return_dict is not None else self.config.return_dict
  629. outputs = self.vilt(
  630. input_ids,
  631. attention_mask=attention_mask,
  632. token_type_ids=token_type_ids,
  633. pixel_values=pixel_values,
  634. pixel_mask=pixel_mask,
  635. inputs_embeds=inputs_embeds,
  636. image_embeds=image_embeds,
  637. output_attentions=output_attentions,
  638. output_hidden_states=output_hidden_states,
  639. return_dict=return_dict,
  640. )
  641. sequence_output, pooled_output = outputs[:2]
  642. # split up final hidden states into text and image features
  643. text_seq_len = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  644. text_features, _ = (sequence_output[:, :text_seq_len], sequence_output[:, text_seq_len:])
  645. mlm_logits = self.mlm_score(text_features)
  646. masked_lm_loss = None
  647. if labels is not None:
  648. loss_fct = CrossEntropyLoss() # -100 index = padding token
  649. # move labels to correct device to enable PP
  650. labels = labels.to(mlm_logits.device)
  651. masked_lm_loss = loss_fct(mlm_logits.view(-1, self.config.vocab_size), labels.view(-1))
  652. if not return_dict:
  653. output = (mlm_logits,) + outputs[2:]
  654. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  655. return MaskedLMOutput(
  656. loss=masked_lm_loss,
  657. logits=mlm_logits,
  658. hidden_states=outputs.hidden_states,
  659. attentions=outputs.attentions,
  660. )
  661. class ViltPredictionHeadTransform(nn.Module):
  662. def __init__(self, config):
  663. super().__init__()
  664. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  665. if isinstance(config.hidden_act, str):
  666. self.transform_act_fn = ACT2FN[config.hidden_act]
  667. else:
  668. self.transform_act_fn = config.hidden_act
  669. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  670. def forward(self, hidden_states):
  671. hidden_states = self.dense(hidden_states)
  672. hidden_states = self.transform_act_fn(hidden_states)
  673. hidden_states = self.LayerNorm(hidden_states)
  674. return hidden_states
  675. class ViltMLMHead(nn.Module):
  676. def __init__(self, config):
  677. super().__init__()
  678. self.config = config
  679. self.transform = ViltPredictionHeadTransform(config)
  680. self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
  681. def forward(self, x):
  682. x = self.transform(x)
  683. x = self.decoder(x)
  684. return x
  685. @auto_docstring(
  686. custom_intro="""
  687. Vilt Model transformer with a classifier head on top (a linear layer on top of the final hidden state of the [CLS]
  688. token) for visual question answering, e.g. for VQAv2.
  689. """
  690. )
  691. class ViltForQuestionAnswering(ViltPreTrainedModel):
  692. def __init__(self, config):
  693. super().__init__(config)
  694. self.num_labels = config.num_labels
  695. self.vilt = ViltModel(config)
  696. # Classifier head
  697. self.classifier = nn.Sequential(
  698. nn.Linear(config.hidden_size, config.hidden_size * 2),
  699. nn.LayerNorm(config.hidden_size * 2),
  700. nn.GELU(),
  701. nn.Linear(config.hidden_size * 2, config.num_labels),
  702. )
  703. # Initialize weights and apply final processing
  704. self.post_init()
  705. @auto_docstring
  706. def forward(
  707. self,
  708. input_ids: torch.LongTensor | None = None,
  709. attention_mask: torch.FloatTensor | None = None,
  710. token_type_ids: torch.LongTensor | None = None,
  711. pixel_values: torch.FloatTensor | None = None,
  712. pixel_mask: torch.LongTensor | None = None,
  713. inputs_embeds: torch.FloatTensor | None = None,
  714. image_embeds: torch.FloatTensor | None = None,
  715. labels: torch.LongTensor | None = None,
  716. output_attentions: bool | None = None,
  717. output_hidden_states: bool | None = None,
  718. return_dict: bool | None = None,
  719. **kwargs,
  720. ) -> SequenceClassifierOutput | tuple[torch.FloatTensor]:
  721. r"""
  722. image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
  723. Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
  724. This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
  725. labels (`torch.FloatTensor` of shape `(batch_size, num_labels)`, *optional*):
  726. Labels for computing the visual question answering loss. This tensor must be either a one-hot encoding of
  727. all answers that are applicable for a given example in the batch, or a soft encoding indicating which
  728. answers are applicable, where 1.0 is the highest score.
  729. Examples:
  730. ```python
  731. >>> from transformers import ViltProcessor, ViltForQuestionAnswering
  732. >>> import httpx
  733. >>> from io import BytesIO
  734. >>> from PIL import Image
  735. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  736. >>> with httpx.stream("GET", url) as response:
  737. ... image = Image.open(BytesIO(response.read()))
  738. >>> text = "How many cats are there?"
  739. >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
  740. >>> model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
  741. >>> # prepare inputs
  742. >>> encoding = processor(image, text, return_tensors="pt")
  743. >>> # forward pass
  744. >>> outputs = model(**encoding)
  745. >>> logits = outputs.logits
  746. >>> idx = logits.argmax(-1).item()
  747. >>> print("Predicted answer:", model.config.id2label[idx])
  748. Predicted answer: 2
  749. ```"""
  750. return_dict = return_dict if return_dict is not None else self.config.return_dict
  751. outputs = self.vilt(
  752. input_ids,
  753. attention_mask=attention_mask,
  754. token_type_ids=token_type_ids,
  755. pixel_values=pixel_values,
  756. pixel_mask=pixel_mask,
  757. inputs_embeds=inputs_embeds,
  758. image_embeds=image_embeds,
  759. output_attentions=output_attentions,
  760. output_hidden_states=output_hidden_states,
  761. return_dict=return_dict,
  762. )
  763. pooler_output = outputs.pooler_output if return_dict else outputs[1]
  764. logits = self.classifier(pooler_output)
  765. loss = None
  766. if labels is not None:
  767. # move labels to correct device to enable PP
  768. labels = labels.to(logits.device)
  769. loss = nn.functional.binary_cross_entropy_with_logits(logits, labels) * labels.shape[1]
  770. # see https://github.com/jnhwkim/ban-vqa/blob/master/train.py#L19
  771. if not return_dict:
  772. output = (logits,) + outputs[2:]
  773. return ((loss,) + output) if loss is not None else output
  774. return SequenceClassifierOutput(
  775. loss=loss,
  776. logits=logits,
  777. hidden_states=outputs.hidden_states,
  778. attentions=outputs.attentions,
  779. )
  780. @auto_docstring(
  781. custom_intro="""
  782. Vilt Model transformer with a classifier head on top (a linear layer on top of the final hidden state of the [CLS]
  783. token) for image-to-text or text-to-image retrieval, e.g. MSCOCO and F30K.
  784. """
  785. )
  786. class ViltForImageAndTextRetrieval(ViltPreTrainedModel):
  787. def __init__(self, config):
  788. super().__init__(config)
  789. self.vilt = ViltModel(config)
  790. # Classifier head
  791. self.rank_output = nn.Linear(config.hidden_size, 1)
  792. # Initialize weights and apply final processing
  793. self.post_init()
  794. @auto_docstring
  795. def forward(
  796. self,
  797. input_ids: torch.LongTensor | None = None,
  798. attention_mask: torch.FloatTensor | None = None,
  799. token_type_ids: torch.LongTensor | None = None,
  800. pixel_values: torch.FloatTensor | None = None,
  801. pixel_mask: torch.LongTensor | None = None,
  802. inputs_embeds: torch.FloatTensor | None = None,
  803. image_embeds: torch.FloatTensor | None = None,
  804. labels: torch.LongTensor | None = None,
  805. output_attentions: bool | None = None,
  806. output_hidden_states: bool | None = None,
  807. return_dict: bool | None = None,
  808. **kwargs,
  809. ) -> SequenceClassifierOutput | tuple[torch.FloatTensor]:
  810. r"""
  811. image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
  812. Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
  813. This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
  814. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  815. Labels are currently not supported.
  816. Examples:
  817. ```python
  818. >>> from transformers import ViltProcessor, ViltForImageAndTextRetrieval
  819. >>> import httpx
  820. >>> from io import BytesIO
  821. >>> from PIL import Image
  822. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  823. >>> with httpx.stream("GET", url) as response:
  824. ... image = Image.open(BytesIO(response.read()))
  825. >>> texts = ["An image of two cats chilling on a couch", "A football player scoring a goal"]
  826. >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-coco")
  827. >>> model = ViltForImageAndTextRetrieval.from_pretrained("dandelin/vilt-b32-finetuned-coco")
  828. >>> # forward pass
  829. >>> scores = dict()
  830. >>> for text in texts:
  831. ... # prepare inputs
  832. ... encoding = processor(image, text, return_tensors="pt")
  833. ... outputs = model(**encoding)
  834. ... scores[text] = outputs.logits[0, :].item()
  835. ```"""
  836. return_dict = return_dict if return_dict is not None else self.config.return_dict
  837. loss = None
  838. if labels is not None:
  839. raise NotImplementedError("Training is not yet supported.")
  840. outputs = self.vilt(
  841. input_ids,
  842. attention_mask=attention_mask,
  843. token_type_ids=token_type_ids,
  844. pixel_values=pixel_values,
  845. pixel_mask=pixel_mask,
  846. inputs_embeds=inputs_embeds,
  847. image_embeds=image_embeds,
  848. output_attentions=output_attentions,
  849. output_hidden_states=output_hidden_states,
  850. return_dict=return_dict,
  851. )
  852. pooler_output = outputs.pooler_output if return_dict else outputs[1]
  853. logits = self.rank_output(pooler_output)
  854. if not return_dict:
  855. output = (logits,) + outputs[2:]
  856. return ((loss,) + output) if loss is not None else output
  857. return SequenceClassifierOutput(
  858. loss=loss,
  859. logits=logits,
  860. hidden_states=outputs.hidden_states,
  861. attentions=outputs.attentions,
  862. )
  863. @auto_docstring(
  864. custom_intro="""
  865. Vilt Model transformer with a classifier head on top for natural language visual reasoning, e.g. NLVR2.
  866. """
  867. )
  868. class ViltForImagesAndTextClassification(ViltPreTrainedModel):
  869. def __init__(self, config):
  870. super().__init__(config)
  871. self.num_labels = config.num_labels
  872. self.vilt = ViltModel(config)
  873. # Classifier head
  874. num_images = config.num_images
  875. self.classifier = nn.Sequential(
  876. nn.Linear(config.hidden_size * num_images, config.hidden_size * num_images),
  877. nn.LayerNorm(config.hidden_size * num_images),
  878. nn.GELU(),
  879. nn.Linear(config.hidden_size * num_images, config.num_labels),
  880. )
  881. # Initialize weights and apply final processing
  882. self.post_init()
  883. @auto_docstring
  884. def forward(
  885. self,
  886. input_ids: torch.LongTensor | None = None,
  887. attention_mask: torch.FloatTensor | None = None,
  888. token_type_ids: torch.LongTensor | None = None,
  889. pixel_values: torch.FloatTensor | None = None,
  890. pixel_mask: torch.LongTensor | None = None,
  891. inputs_embeds: torch.FloatTensor | None = None,
  892. image_embeds: torch.FloatTensor | None = None,
  893. labels: torch.LongTensor | None = None,
  894. output_attentions: bool | None = None,
  895. output_hidden_states: bool | None = None,
  896. return_dict: bool | None = None,
  897. **kwargs,
  898. ) -> ViltForImagesAndTextClassificationOutput | tuple[torch.FloatTensor]:
  899. r"""
  900. image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
  901. Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
  902. This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
  903. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  904. Binary classification labels.
  905. Examples:
  906. ```python
  907. >>> from transformers import ViltProcessor, ViltForImagesAndTextClassification
  908. >>> import httpx
  909. >>> from io import BytesIO
  910. >>> from PIL import Image
  911. >>> url_1 = "https://lil.nlp.cornell.edu/nlvr/exs/ex0_0.jpg"
  912. >>> with httpx.stream("GET", url_1) as response:
  913. ... image_1 = Image.open(BytesIO(response.read()))
  914. >>> url_2 = "https://lil.nlp.cornell.edu/nlvr/exs/ex0_1.jpg"
  915. >>> with httpx.stream("GET", url_2) as response:
  916. ... image_2 = Image.open(BytesIO(response.read()))
  917. >>> text = "The left image contains twice the number of dogs as the right image."
  918. >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-nlvr2")
  919. >>> model = ViltForImagesAndTextClassification.from_pretrained("dandelin/vilt-b32-finetuned-nlvr2")
  920. >>> # prepare inputs
  921. >>> encoding = processor([image_1, image_2], text, return_tensors="pt")
  922. >>> # forward pass
  923. >>> outputs = model(input_ids=encoding.input_ids, pixel_values=encoding.pixel_values.unsqueeze(0))
  924. >>> logits = outputs.logits
  925. >>> idx = logits.argmax(-1).item()
  926. >>> print("Predicted answer:", model.config.id2label[idx])
  927. Predicted answer: True
  928. ```"""
  929. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  930. output_hidden_states = (
  931. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  932. )
  933. return_dict = return_dict if return_dict is not None else self.config.return_dict
  934. if pixel_values is not None and pixel_values.ndim == 4:
  935. # add dummy num_images dimension
  936. pixel_values = pixel_values.unsqueeze(1)
  937. if image_embeds is not None and image_embeds.ndim == 3:
  938. # add dummy num_images dimension
  939. image_embeds = image_embeds.unsqueeze(1)
  940. num_images = pixel_values.shape[1] if pixel_values is not None else None
  941. if num_images is None:
  942. num_images = image_embeds.shape[1] if image_embeds is not None else None
  943. if num_images != self.config.num_images:
  944. raise ValueError(
  945. "Make sure to match the number of images in the model with the number of images in the input."
  946. )
  947. pooler_outputs = []
  948. hidden_states = [] if output_hidden_states else None
  949. attentions = [] if output_attentions else None
  950. for i in range(num_images):
  951. # forward every image through the model
  952. outputs = self.vilt(
  953. input_ids,
  954. attention_mask=attention_mask,
  955. token_type_ids=token_type_ids,
  956. pixel_values=pixel_values[:, i, :, :, :] if pixel_values is not None else None,
  957. pixel_mask=pixel_mask[:, i, :, :] if pixel_mask is not None else None,
  958. inputs_embeds=inputs_embeds,
  959. image_embeds=image_embeds[:, i, :, :] if image_embeds is not None else None,
  960. image_token_type_idx=i + 1,
  961. output_attentions=output_attentions,
  962. output_hidden_states=output_hidden_states,
  963. return_dict=return_dict,
  964. )
  965. pooler_output = outputs.pooler_output if return_dict else outputs[1]
  966. pooler_outputs.append(pooler_output)
  967. if output_hidden_states:
  968. hidden_states.append(outputs.hidden_states)
  969. if output_attentions:
  970. attentions.append(outputs.attentions)
  971. pooled_output = torch.cat(pooler_outputs, dim=-1)
  972. logits = self.classifier(pooled_output)
  973. loss = None
  974. if labels is not None:
  975. loss_fct = CrossEntropyLoss()
  976. # move labels to correct device to enable PP
  977. labels = labels.to(logits.device)
  978. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  979. if not return_dict:
  980. output = (logits, hidden_states, attentions)
  981. return ((loss,) + output) if loss is not None else output
  982. return ViltForImagesAndTextClassificationOutput(
  983. loss=loss,
  984. logits=logits,
  985. hidden_states=hidden_states,
  986. attentions=attentions,
  987. )
  988. @auto_docstring
  989. class ViltForTokenClassification(ViltPreTrainedModel):
  990. def __init__(self, config):
  991. super().__init__(config)
  992. self.num_labels = config.num_labels
  993. self.vilt = ViltModel(config, add_pooling_layer=False)
  994. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  995. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  996. # Initialize weights and apply final processing
  997. self.post_init()
  998. @auto_docstring
  999. def forward(
  1000. self,
  1001. input_ids: torch.LongTensor | None = None,
  1002. attention_mask: torch.FloatTensor | None = None,
  1003. token_type_ids: torch.LongTensor | None = None,
  1004. pixel_values: torch.FloatTensor | None = None,
  1005. pixel_mask: torch.LongTensor | None = None,
  1006. inputs_embeds: torch.FloatTensor | None = None,
  1007. image_embeds: torch.FloatTensor | None = None,
  1008. labels: torch.LongTensor | None = None,
  1009. output_attentions: bool | None = None,
  1010. output_hidden_states: bool | None = None,
  1011. return_dict: bool | None = None,
  1012. **kwargs,
  1013. ) -> TokenClassifierOutput | tuple[torch.FloatTensor]:
  1014. r"""
  1015. image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
  1016. Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
  1017. This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
  1018. labels (`torch.LongTensor` of shape `(batch_size, text_sequence_length)`, *optional*):
  1019. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  1020. """
  1021. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1022. outputs = self.vilt(
  1023. input_ids,
  1024. attention_mask=attention_mask,
  1025. token_type_ids=token_type_ids,
  1026. pixel_values=pixel_values,
  1027. pixel_mask=pixel_mask,
  1028. inputs_embeds=inputs_embeds,
  1029. image_embeds=image_embeds,
  1030. output_attentions=output_attentions,
  1031. output_hidden_states=output_hidden_states,
  1032. return_dict=return_dict,
  1033. )
  1034. sequence_output = outputs[0]
  1035. text_input_size = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  1036. sequence_output = self.dropout(sequence_output)
  1037. logits = self.classifier(sequence_output[:, :text_input_size])
  1038. loss = None
  1039. if labels is not None:
  1040. loss_fct = CrossEntropyLoss()
  1041. # move labels to correct device to enable PP
  1042. labels = labels.to(logits.device)
  1043. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1044. if not return_dict:
  1045. output = (logits,) + outputs[2:]
  1046. return ((loss,) + output) if loss is not None else output
  1047. return TokenClassifierOutput(
  1048. loss=loss,
  1049. logits=logits,
  1050. hidden_states=outputs.hidden_states,
  1051. attentions=outputs.attentions,
  1052. )
# Public API of this module: the names exported by a star-import.
__all__ = [
    "ViltForImageAndTextRetrieval",
    "ViltForImagesAndTextClassification",
    "ViltForTokenClassification",
    "ViltForMaskedLM",
    "ViltForQuestionAnswering",
    "ViltLayer",
    "ViltModel",
    "ViltPreTrainedModel",
]