modeling_flava.py 87 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939
  1. # Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch FLAVA model."""
  15. import collections
  16. import math
  17. from collections import OrderedDict
  18. from dataclasses import dataclass
  19. from typing import Any
  20. import torch
  21. from torch import nn
  22. from ... import initialization as init
  23. from ...activations import ACT2FN
  24. from ...modeling_layers import GradientCheckpointingLayer
  25. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
  26. from ...modeling_utils import PreTrainedModel
  27. from ...processing_utils import Unpack
  28. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int
  29. from .configuration_flava import (
  30. FlavaConfig,
  31. FlavaImageCodebookConfig,
  32. FlavaImageConfig,
  33. FlavaMultimodalConfig,
  34. FlavaTextConfig,
  35. )
# Module-level logger for this file.
logger = logging.get_logger(__name__)
# Hub checkpoint id for the FLAVA image codebook; the name suggests it is referenced by docstring examples
# (NOTE(review): its uses are not visible in this chunk — confirm).
_CHECKPOINT_FOR_CODEBOOK_DOC = "facebook/flava-image-codebook"
# Clamp bounds, presumably for a learned (log-space) logit scale — their use is not visible here, confirm.
# Numerically, 4.6052 ~= ln(100).
LOGIT_SCALE_CLAMP_MIN = 0
LOGIT_SCALE_CLAMP_MAX = 4.6052
# Union of the sub-model configs accepted by the shared transformer building blocks below
# (it is the `config` annotation of FlavaSelfAttention, FlavaSelfOutput, FlavaAttention, ...).
FlavaPossibleConfigs = FlavaTextConfig | FlavaImageConfig | FlavaMultimodalConfig
  41. @dataclass
  42. @auto_docstring(
  43. custom_intro="""
  44. Output from FlavaModel containing embeddings and outputs from individual encoders.
  45. Note that `image_embeddings` and `text_embeddigns` returned are similar to pooled output returned from a
  46. transformer. If you want embeddings for contrastive loss or retrieval use a FLAVA model's `image_projection` and
  47. `text_projection` layers on `image_embeddings` and `text_embeddings` respectively.
  48. """
  49. )
  50. class FlavaModelOutput(ModelOutput):
  51. r"""
  52. image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `pixel_values` are present):
  53. The image embeddings which are basically the pooled output of [`FlavaImageModel`].
  54. image_output (`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present):
  55. The output of the [`FlavaImageModel`].
  56. text_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` are present):
  57. The text embeddings which are basically the pooled output of [`FlavaTextModel`].
  58. text_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids` are present):
  59. The output of the [`FlavaTextModel`].
  60. multimodal_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` and `pixel_values` are present and `skip_multimodal_encoder` is `None` or `False`):
  61. The multimodal embeddings which are basically the pooled output of [`FlavaTextModel`].
  62. multimodal_output (`BaseModelOutputWithPooling`, returned when `input_ids` and `pixel_values` are present and `skip_multimodal_encoder` is `None` or `False`):
  63. The output of the [`FlavaMultimodalModel`].
  64. """
  65. image_embeddings: torch.FloatTensor | None = None
  66. image_output: BaseModelOutputWithPooling | None = None
  67. text_embeddings: torch.FloatTensor | None = None
  68. text_output: BaseModelOutputWithPooling | None = None
  69. multimodal_embeddings: torch.FloatTensor | None = None
  70. multimodal_output: BaseModelOutputWithPooling | None = None
  71. def to_tuple(self) -> tuple[Any]:
  72. return tuple(
  73. self[k] if k not in ["text_output", "image_output", "multimodal_output"] else getattr(self, k).to_tuple()
  74. for k in self.keys()
  75. )
  76. @dataclass
  77. @auto_docstring(
  78. custom_intro="""
  79. Class representing pretraining losses from FLAVA model
  80. """
  81. )
  82. class FlavaLosses(ModelOutput):
  83. r"""
  84. mim (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mim_labels` and `pixel_values` are present, `input_ids_masked` is absent and `mim_weight` > 0.):
  85. Masked Image Modeling loss as used in BeIT calculated only for unimodal image data.
  86. mlm (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mlm_labels` and `input_ids_masked` are present, `pixel_values` is absent and `mlm_weight` > 0.):
  87. Masked Language Modeling loss as used in BERT calculated only for unimodal text data.
  88. itm (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `itm_labels`, `input_ids_masked`, `pixel_values` are present and `itm_weight` > 0.):
  89. Image Text Matching (ITM) loss calculated for paired image-text data. Note that ITM loss is calculated on
  90. masked pairs in FLAVA.
  91. global_contrastive (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `input_ids` and `pixel_values` are present and `global_contrastive_weight` > 0.):
  92. Contrastive loss for image-text similarity similar to CLIP but calculated globally for paired image-text
  93. data. This is calculated on unmasked images and texts.
  94. mmm_image (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mim_labels`, `pixel_values` and `input_ids_masked` are present and `mmm_image_weight` > 0.):
  95. Masked Multimodal Modeling loss's image component calculated on paired image-text data.
  96. mmm_text (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mlm_labels`, `pixel_values` and `input_ids_masked` are present and `mmm_text_weight` > 0.):
  97. Masked Multimodal Modeling loss's text component calculated on paired image-text data.
  98. """
  99. mim: torch.FloatTensor | None = None
  100. mlm: torch.FloatTensor | None = None
  101. itm: torch.FloatTensor | None = None
  102. global_contrastive: torch.FloatTensor | None = None
  103. mmm_image: torch.FloatTensor | None = None
  104. mmm_text: torch.FloatTensor | None = None
  105. def all_none(self) -> bool:
  106. all_none = True
  107. for v in self.values():
  108. if v is not None:
  109. all_none = False
  110. break
  111. return all_none
  112. @dataclass
  113. @auto_docstring(
  114. custom_intro="""
  115. Output from FlavaForPreTraining containing embeddings, and outputs from individual encoders.
  116. Note that `image_embeddings` and `text_embeddings` returned are similar to pooled output returned from a
  117. transformer. If you want embeddings for contrastive loss or retrieval use a FLAVA model's `image_projection` and
  118. `text_projection` layers on `image_embeddings` and `text_embeddings` respectively.
  119. """
  120. )
  121. class FlavaForPreTrainingOutput(ModelOutput):
  122. r"""
  123. loss (`torch.FloatTensor`, *optional*, returned when `return_loss` is True):
  124. Total loss calculated for this model.
  125. loss_info (`FlavaLosses`):
  126. Detailed info for FLAVA Pretraining losses. Check `FlavaLosses` class description for the information on
  127. the keys.
  128. image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `pixel_values` are present):
  129. The image embeddings which are basically the pooled output of [`FlavaImageModel`].
  130. image_output (`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present):
  131. The output of the [`FlavaImageModel`].
  132. text_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` are present):
  133. The text embeddings which are basically the pooled output of [`FlavaTextModel`].
  134. text_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids` are present):
  135. The output of the [`FlavaTextModel`].
  136. multimodal_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` and `pixel_values` are present and `skip_unmasked_multimodal_encoder` is `None` or `False`):
  137. The multimodal embeddings which are basically the pooled output of [`FlavaTextModel`].
  138. multimodal_output (`BaseModelOutputWithPooling`, returned when `input_ids` and `pixel_values` are present and `skip_unmasked_multimodal_encoder` is `None` or `False`):
  139. The output of the [`FlavaMultimodalModel`].
  140. image_masked_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `pixel_values` are present):
  141. The image embeddings which are basically the pooled output of [`FlavaImageModel`]. Uses `bool_masked_pos`
  142. to create masked images.
  143. image_masked_output (`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present):
  144. The output of the [`FlavaImageModel`]. Uses `bool_masked_pos` to create masked images.
  145. text_masked_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids_masked` are present):
  146. The text embeddings which are basically the pooled output of [`FlavaTextModel`].
  147. text_masked_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids_masked` are present):
  148. The output of the [`FlavaTextModel`].
  149. multimodal_masked_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` and `pixel_values` are present):
  150. The multimodal embeddings which are basically the pooled output of [`FlavaTextModel`].
  151. multimodal_masked_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids_masked` and `pixel_values` are present):
  152. The output of the [`FlavaMultimodalModel`].
  153. mim_logits (`torch.FloatTensor` of shape `(batch_size, num_image_patches, image_vocab_size)` or of shape `(total_masked_patches, image_vocab_size)` , *optional*, returned when `pixel_values` are present and `input_ids_masked` are not):
  154. The logits for MIM unimodal loss. Uses `book_masked_pos` to get masked patches. The flattened output is
  155. returned when `bool_masked_pos` has some of the patches masked.
  156. mlm_logits (`torch.FloatTensor` of shape `(batch_size, text_seq_length, text_vocab_size)` or of shape `(total_masked_seq_length, text_vocab_size)`, *optional*, returned when `input_ids_masked` are present and `pixel_values` are not):
  157. The logits for MLM unimodal loss. The flattened output is returned when `input_ids_masked` has some of
  158. the tokens masked.
  159. itm_logits (`torch.FloatTensor` of shape `(batch_size, 2)`, *optional*, returned when `input_ids_masked` and `pixel_values` are present):
  160. The logits for ITM loss. Note that ITM loss is calculated on masked pairs in FLAVA.
  161. contrastive_logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
  162. The scaled dot product scores between `image_embeddings` and `text_embeddings` but passed through FLAVA's
  163. `image_projection` and `text_projection` layers respectively. This represents the image-text similarity
  164. scores. This is calculated on unmasked images and texts.
  165. contrastive_logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
  166. The scaled dot product scores between `text_embeddings` and `image_embeddings` but passed through FLAVA's
  167. `text_projection` and `image_projection` layers respectively. This is calculated on unmasked images and
  168. texts.
  169. mmm_image_logits (`torch.FloatTensor` of shape `(batch_size, num_image_patches, image_vocab_size)` or of shape`(total_masked_patches, image_vocab_size)`, *optional*, returned when `pixel_values` and `input_ids_masked` are present):
  170. The logits for MMM image multimodal loss. Uses `book_masked_pos` to get masked patches. The flattened
  171. output is returned when `bool_masked_pos` has some of the patches masked.
  172. mmm_text_logits (`torch.FloatTensor` of shape `(batch_size, text_seq_length, text_vocab_size)` or of shape `(`(total_masked_seq_length, text_vocab_size)`), *optional*, returned when `pixel_values` and `input_ids_masked` are present):
  173. The logits for MMM text multimodal loss. The flattened output is returned when `input_ids_masked` has
  174. some of the tokens masked.
  175. """
  176. loss: torch.FloatTensor | None = None
  177. loss_info: FlavaLosses = None
  178. image_embeddings: torch.FloatTensor | None = None
  179. image_output: BaseModelOutputWithPooling | None = None
  180. text_embeddings: torch.FloatTensor | None = None
  181. text_output: BaseModelOutputWithPooling | None = None
  182. multimodal_embeddings: torch.FloatTensor | None = None
  183. multimodal_output: BaseModelOutputWithPooling | None = None
  184. image_masked_embeddings: torch.FloatTensor | None = None
  185. image_masked_output: BaseModelOutputWithPooling | None = None
  186. text_masked_embeddings: torch.FloatTensor | None = None
  187. text_masked_output: BaseModelOutputWithPooling | None = None
  188. multimodal_masked_embeddings: torch.FloatTensor | None = None
  189. multimodal_masked_output: BaseModelOutputWithPooling | None = None
  190. mim_logits: torch.FloatTensor | None = None
  191. mlm_logits: torch.FloatTensor | None = None
  192. itm_logits: torch.FloatTensor | None = None
  193. contrastive_logits_per_image: torch.FloatTensor | None = None
  194. contrastive_logits_per_text: torch.FloatTensor | None = None
  195. mmm_image_logits: torch.FloatTensor | None = None
  196. mmm_text_logits: torch.FloatTensor | None = None
  197. def to_tuple(self) -> tuple[Any]:
  198. transformer_outputs = [
  199. "text_output",
  200. "image_output",
  201. "multimodal_output",
  202. "text_masked_output",
  203. "image_masked_output",
  204. "multimodal_masked_output",
  205. ]
  206. return tuple(self[k] if k not in transformer_outputs else getattr(self, k).to_tuple() for k in self.keys())
  207. # Based on timm implementation, which can be found here:
  208. # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/image_transformer.py
  209. class FlavaImageEmbeddings(nn.Module):
  210. """
  211. Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
  212. """
  213. def __init__(self, config: FlavaImageConfig, use_mask_token: bool = False) -> None:
  214. super().__init__()
  215. use_mask_token = use_mask_token or config.mask_token
  216. self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  217. self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
  218. self.patch_embeddings = PatchEmbeddings(
  219. image_size=config.image_size,
  220. patch_size=config.patch_size,
  221. num_channels=config.num_channels,
  222. embed_dim=config.hidden_size,
  223. )
  224. num_patches = self.patch_embeddings.num_patches
  225. self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
  226. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  227. self.patch_size = config.patch_size
  228. self.config = config
  229. # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
  230. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  231. """
  232. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  233. images. This method is also adapted to support torch.jit tracing.
  234. Adapted from:
  235. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  236. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  237. """
  238. num_patches = embeddings.shape[1] - 1
  239. num_positions = self.position_embeddings.shape[1] - 1
  240. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  241. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  242. return self.position_embeddings
  243. class_pos_embed = self.position_embeddings[:, :1]
  244. patch_pos_embed = self.position_embeddings[:, 1:]
  245. dim = embeddings.shape[-1]
  246. new_height = height // self.patch_size
  247. new_width = width // self.patch_size
  248. sqrt_num_positions = torch_int(num_positions**0.5)
  249. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  250. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  251. patch_pos_embed = nn.functional.interpolate(
  252. patch_pos_embed,
  253. size=(new_height, new_width),
  254. mode="bicubic",
  255. align_corners=False,
  256. )
  257. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  258. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  259. def forward(
  260. self,
  261. pixel_values: torch.Tensor,
  262. bool_masked_pos: torch.BoolTensor | None = None,
  263. interpolate_pos_encoding: bool = False,
  264. ) -> torch.Tensor:
  265. batch_size, num_channels, height, width = pixel_values.shape
  266. embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  267. batch_size, seq_len, _ = embeddings.size()
  268. if bool_masked_pos is not None:
  269. mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
  270. # B X H X W = B X HW
  271. if bool_masked_pos.dim() == 3:
  272. bool_masked_pos = bool_masked_pos.view(bool_masked_pos.size(0), -1)
  273. # replace the masked visual tokens by mask_tokens
  274. mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
  275. embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
  276. # add the [CLS] token to the embedded patch tokens
  277. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  278. embeddings = torch.cat((cls_tokens, embeddings), dim=1)
  279. # add positional encoding to each token
  280. if interpolate_pos_encoding:
  281. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  282. else:
  283. embeddings = embeddings + self.position_embeddings
  284. embeddings = self.dropout(embeddings)
  285. return embeddings
  286. # Based on timm implementation, which can be found here:
  287. # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/image_transformer.py
  288. class PatchEmbeddings(nn.Module):
  289. """
  290. Image to Patch Embedding.
  291. """
  292. def __init__(
  293. self,
  294. image_size: int | list[int] | tuple[int, int] = 224,
  295. patch_size: int | tuple[int, int] = 16,
  296. num_channels: int = 3,
  297. embed_dim: int = 768,
  298. ):
  299. super().__init__()
  300. if not isinstance(image_size, collections.abc.Iterable):
  301. image_size = (image_size, image_size)
  302. if not isinstance(patch_size, collections.abc.Iterable):
  303. patch_size = (patch_size, patch_size)
  304. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  305. self.image_size = image_size
  306. self.patch_size = patch_size
  307. self.num_patches = num_patches
  308. self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
  309. def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
  310. batch_size, num_channels, height, width = pixel_values.shape
  311. if not interpolate_pos_encoding:
  312. if height != self.image_size[0] or width != self.image_size[1]:
  313. raise ValueError(
  314. f"Input image size ({height}*{width}) doesn't match model"
  315. f" ({self.image_size[0]}*{self.image_size[1]})."
  316. )
  317. x = self.projection(pixel_values).flatten(2).transpose(1, 2)
  318. return x
  319. class FlavaTextEmbeddings(nn.Module):
  320. """Construct the embeddings from word, position and token_type embeddings."""
  321. def __init__(self, config):
  322. super().__init__()
  323. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  324. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  325. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  326. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  327. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  328. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  329. self.register_buffer(
  330. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  331. )
  332. self.register_buffer(
  333. "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
  334. )
  335. def forward(
  336. self,
  337. input_ids: torch.Tensor | None = None,
  338. token_type_ids: torch.Tensor | None = None,
  339. position_ids: torch.Tensor | None = None,
  340. ):
  341. input_shape = input_ids.size()
  342. seq_length = input_shape[1]
  343. if position_ids is None:
  344. position_ids = self.position_ids[:, :seq_length]
  345. # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
  346. # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
  347. # issue #5664
  348. if token_type_ids is None:
  349. if hasattr(self, "token_type_ids"):
  350. buffered_token_type_ids = self.token_type_ids[:, :seq_length]
  351. buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
  352. token_type_ids = buffered_token_type_ids_expanded
  353. else:
  354. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  355. inputs_embeds = self.word_embeddings(input_ids)
  356. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  357. embeddings = inputs_embeds + token_type_embeddings
  358. position_embeddings = self.position_embeddings(position_ids)
  359. embeddings += position_embeddings
  360. embeddings = self.LayerNorm(embeddings)
  361. embeddings = self.dropout(embeddings)
  362. return embeddings
  363. class FlavaSelfAttention(nn.Module):
  364. def __init__(self, config: FlavaPossibleConfigs) -> None:
  365. super().__init__()
  366. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  367. raise ValueError(
  368. f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
  369. f"heads {config.num_attention_heads}."
  370. )
  371. self.num_attention_heads = config.num_attention_heads
  372. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  373. self.all_head_size = self.num_attention_heads * self.attention_head_size
  374. self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  375. self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  376. self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  377. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  378. def forward(
  379. self,
  380. hidden_states: torch.Tensor,
  381. attention_mask: torch.Tensor | None = None,
  382. output_attentions: bool = False,
  383. ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor]:
  384. input_shape = hidden_states.shape[:-1]
  385. hidden_shape = (*input_shape, -1, self.attention_head_size)
  386. query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  387. key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
  388. value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
  389. # Take the dot product between "query" and "key" to get the raw attention scores.
  390. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  391. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  392. if attention_mask is not None:
  393. # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
  394. attention_scores = attention_scores + attention_mask
  395. # Normalize the attention scores to probabilities.
  396. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  397. # This is actually dropping out entire tokens to attend to, which might
  398. # seem a bit unusual, but is taken from the original Transformer paper.
  399. attention_probs = self.dropout(attention_probs)
  400. context_layer = torch.matmul(attention_probs, value_layer)
  401. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  402. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  403. context_layer = context_layer.view(*new_context_layer_shape)
  404. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  405. return outputs
  406. class FlavaSelfOutput(nn.Module):
  407. """
  408. The residual connection is defined in FlavaLayer (same as ViTLayer) instead of here (as is the case with other
  409. models), due to the layernorm applied before each block.
  410. """
  411. def __init__(self, config: FlavaPossibleConfigs) -> None:
  412. super().__init__()
  413. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  414. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  415. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  416. hidden_states = self.dense(hidden_states)
  417. hidden_states = self.dropout(hidden_states)
  418. return hidden_states
  419. class FlavaAttention(nn.Module):
  420. def __init__(self, config: FlavaPossibleConfigs) -> None:
  421. super().__init__()
  422. self.attention = FlavaSelfAttention(config)
  423. self.output = FlavaSelfOutput(config)
  424. def forward(
  425. self,
  426. hidden_states: torch.Tensor,
  427. attention_mask: torch.Tensor | None = None,
  428. output_attentions: bool = False,
  429. ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor]:
  430. self_outputs = self.attention(
  431. hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
  432. )
  433. attention_output = self.output(self_outputs[0], hidden_states)
  434. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  435. return outputs
  436. class FlavaIntermediate(nn.Module):
  437. def __init__(self, config: FlavaPossibleConfigs) -> None:
  438. super().__init__()
  439. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  440. if isinstance(config.hidden_act, str):
  441. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  442. else:
  443. self.intermediate_act_fn = config.hidden_act
  444. # Copied from transformers.models.vit.modeling_vit.ViTIntermediate.forward
  445. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  446. hidden_states = self.dense(hidden_states)
  447. hidden_states = self.intermediate_act_fn(hidden_states)
  448. return hidden_states
  449. class FlavaOutput(nn.Module):
  450. def __init__(self, config: FlavaPossibleConfigs) -> None:
  451. super().__init__()
  452. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  453. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  454. # Copied from transformers.models.vit.modeling_vit.ViTOutput.forward
  455. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  456. hidden_states = self.dense(hidden_states)
  457. hidden_states = self.dropout(hidden_states)
  458. hidden_states = hidden_states + input_tensor
  459. return hidden_states
  460. class FlavaLayer(GradientCheckpointingLayer):
  461. """This corresponds to the Block class in the timm implementation."""
  462. def __init__(self, config: FlavaPossibleConfigs) -> None:
  463. super().__init__()
  464. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  465. self.seq_len_dim = 1
  466. self.attention = FlavaAttention(config)
  467. self.intermediate = FlavaIntermediate(config)
  468. self.output = FlavaOutput(config)
  469. # TODO: Check fp32 layer norm possibility
  470. self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  471. self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  472. def forward(
  473. self,
  474. hidden_states: torch.Tensor,
  475. attention_mask: torch.Tensor | None = None,
  476. output_attentions: bool = False,
  477. ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor]:
  478. self_attention_outputs = self.attention(
  479. self.layernorm_before(hidden_states), # in ViT, layernorm is applied before self-attention
  480. attention_mask=attention_mask,
  481. output_attentions=output_attentions,
  482. )
  483. attention_output = self_attention_outputs[0]
  484. outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
  485. # first residual connection
  486. hidden_states = attention_output + hidden_states
  487. # in ViT, layernorm is also applied after self-attention
  488. layer_output = self.layernorm_after(hidden_states)
  489. layer_output = self.intermediate(layer_output)
  490. # second residual connection is done here
  491. layer_output = self.output(layer_output, hidden_states)
  492. outputs = (layer_output,) + outputs
  493. return outputs
  494. class FlavaEncoder(nn.Module):
  495. def __init__(self, config: FlavaConfig) -> None:
  496. super().__init__()
  497. self.config = config
  498. self.layer = nn.ModuleList([FlavaLayer(config) for _ in range(config.num_hidden_layers)])
  499. self.gradient_checkpointing = False
  500. def forward(
  501. self,
  502. hidden_states: torch.Tensor,
  503. attention_mask: torch.Tensor | None = None,
  504. output_attentions: bool = False,
  505. output_hidden_states: bool = False,
  506. return_dict: bool = True,
  507. ) -> tuple | BaseModelOutput:
  508. all_hidden_states = () if output_hidden_states else None
  509. all_self_attentions = () if output_attentions else None
  510. for i, layer_module in enumerate(self.layer):
  511. if output_hidden_states:
  512. all_hidden_states = all_hidden_states + (hidden_states,)
  513. layer_outputs = layer_module(hidden_states, attention_mask, output_attentions)
  514. hidden_states = layer_outputs[0]
  515. if output_attentions:
  516. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  517. if output_hidden_states:
  518. all_hidden_states = all_hidden_states + (hidden_states,)
  519. if not return_dict:
  520. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  521. return BaseModelOutput(
  522. last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions
  523. )
  524. class FlavaPooler(nn.Module):
  525. def __init__(self, config: FlavaPossibleConfigs):
  526. super().__init__()
  527. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  528. self.activation = nn.Tanh()
  529. def forward(self, hidden_states: torch.Tensor):
  530. # We "pool" the model by simply taking the hidden state corresponding
  531. # to the first token.
  532. first_token_tensor = hidden_states[:, 0]
  533. pooled_output = self.dense(first_token_tensor)
  534. pooled_output = self.activation(pooled_output)
  535. return pooled_output
  536. @auto_docstring
  537. class FlavaPreTrainedModel(PreTrainedModel):
  538. config: FlavaConfig
  539. base_model_prefix = "flava"
  540. input_modalities = ("image", "text")
  541. supports_gradient_checkpointing = True
  542. @torch.no_grad()
  543. def _init_weights(self, module: nn.Linear | nn.Conv2d | nn.LayerNorm) -> None:
  544. """Initialize the weights"""
  545. super()._init_weights(module)
  546. if isinstance(module, FlavaMaskedPredictionHead):
  547. init.zeros_(module.bias)
  548. elif isinstance(module, FlavaImageEmbeddings):
  549. init.zeros_(module.cls_token)
  550. init.zeros_(module.position_embeddings)
  551. if module.mask_token is not None:
  552. init.zeros_(module.mask_token)
  553. elif isinstance(module, FlavaTextEmbeddings):
  554. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  555. init.zeros_(module.token_type_ids)
  556. elif isinstance(module, FlavaMultimodalModel):
  557. if module.use_cls_token:
  558. init.zeros_(module.cls_token)
  559. elif isinstance(module, FlavaModel):
  560. init.constant_(module.logit_scale, self.config.logit_scale_init_value)
  561. @auto_docstring
  562. class FlavaImageModel(FlavaPreTrainedModel):
  563. config: FlavaImageConfig
  564. # This override allows us to load FlavaImageModel from FlavaModel/FlavaForPreTraining checkpoints.
  565. base_model_prefix = "flava.image_model"
  566. main_input_name = "pixel_values"
  567. input_modalities = ("image",)
  568. def __init__(self, config: FlavaImageConfig, add_pooling_layer: bool = True):
  569. r"""
  570. add_pooling_layer (bool, *optional*, defaults to `True`):
  571. Whether to add a pooling layer
  572. """
  573. super().__init__(config)
  574. self.config = config
  575. self.embeddings = FlavaImageEmbeddings(config)
  576. self.encoder = FlavaEncoder(config)
  577. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  578. self.pooler = FlavaPooler(config) if add_pooling_layer else None
  579. self.post_init()
  580. def get_input_embeddings(self) -> nn.Module:
  581. return self.embeddings.patch_embeddings
  582. def set_input_embeddings(self, value: nn.Module):
  583. self.embeddings.patch_embeddings = value
  584. @auto_docstring
  585. def forward(
  586. self,
  587. pixel_values: torch.Tensor | None = None,
  588. bool_masked_pos: torch.BoolTensor | None = None,
  589. interpolate_pos_encoding: bool | None = None,
  590. attention_mask: torch.Tensor | None = None,
  591. output_attentions: bool | None = None,
  592. output_hidden_states: bool | None = None,
  593. return_dict: bool | None = None,
  594. **kwargs,
  595. ) -> tuple | BaseModelOutputWithPooling:
  596. r"""
  597. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, image_num_patches)`):
  598. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  599. """
  600. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  601. output_hidden_states = (
  602. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  603. )
  604. return_dict = return_dict if return_dict is not None else self.config.return_dict
  605. if pixel_values is None:
  606. raise ValueError("You have to specify pixel_values")
  607. embedding_output = self.embeddings(
  608. pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
  609. )
  610. encoder_outputs = self.encoder(
  611. embedding_output,
  612. attention_mask=attention_mask,
  613. output_attentions=output_attentions,
  614. output_hidden_states=output_hidden_states,
  615. return_dict=return_dict,
  616. )
  617. sequence_output = encoder_outputs[0]
  618. sequence_output = self.layernorm(sequence_output)
  619. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  620. if not return_dict:
  621. return (sequence_output, pooled_output) + encoder_outputs[1:]
  622. return BaseModelOutputWithPooling(
  623. last_hidden_state=sequence_output,
  624. pooler_output=pooled_output,
  625. hidden_states=encoder_outputs.hidden_states,
  626. attentions=encoder_outputs.attentions,
  627. )
  628. @auto_docstring
  629. class FlavaTextModel(FlavaPreTrainedModel):
  630. config: FlavaTextConfig
  631. # This override allows us to load FlavaTextModel from FlavaModel/FlavaForPreTraining checkpoints.
  632. base_model_prefix = "flava.text_model"
  633. input_modalities = ("text",)
  634. def __init__(self, config: FlavaTextConfig, add_pooling_layer: bool = True):
  635. r"""
  636. add_pooling_layer (bool, *optional*, defaults to `True`):
  637. Whether to add a pooling layer
  638. """
  639. super().__init__(config)
  640. self.config = config
  641. self.embeddings = FlavaTextEmbeddings(config)
  642. self.encoder = FlavaEncoder(config)
  643. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  644. self.pooler = FlavaPooler(config) if add_pooling_layer else None
  645. self.post_init()
  646. def get_input_embeddings(self) -> PatchEmbeddings:
  647. return self.embeddings.word_embeddings
  648. def set_input_embeddings(self, value: nn.Module):
  649. self.embeddings.word_embeddings = value
  650. @auto_docstring
  651. def forward(
  652. self,
  653. input_ids: torch.Tensor | None = None,
  654. attention_mask: torch.Tensor | None = None,
  655. token_type_ids: torch.Tensor | None = None,
  656. position_ids: torch.Tensor | None = None,
  657. output_attentions: bool | None = None,
  658. output_hidden_states: bool | None = None,
  659. return_dict: bool | None = None,
  660. **kwargs,
  661. ) -> tuple | BaseModelOutputWithPooling:
  662. r"""
  663. input_ids (`torch.LongTensor` of shape `(batch_size, text_seq_length)`):
  664. Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
  665. [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
  666. IDs?](../glossary#input-ids)
  667. token_type_ids (`torch.LongTensor` of shape `(batch_size, text_seq_length)`, *optional*):
  668. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  669. 1]`:
  670. - 0 corresponds to a *sentence A* token,
  671. - 1 corresponds to a *sentence B* token.
  672. [What are token type IDs?](../glossary#token-type-ids)
  673. """
  674. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  675. output_hidden_states = (
  676. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  677. )
  678. return_dict = return_dict if return_dict is not None else self.config.return_dict
  679. if input_ids is None:
  680. raise ValueError("You have to specify input_ids")
  681. input_shape = input_ids.size()
  682. if attention_mask is None:
  683. attention_mask = torch.ones(input_shape, device=input_ids.device)
  684. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
  685. attention_mask,
  686. input_shape,
  687. )
  688. embedding_output = self.embeddings(
  689. input_ids=input_ids,
  690. token_type_ids=token_type_ids,
  691. position_ids=position_ids,
  692. )
  693. encoder_outputs = self.encoder(
  694. embedding_output,
  695. attention_mask=extended_attention_mask,
  696. output_attentions=output_attentions,
  697. output_hidden_states=output_hidden_states,
  698. return_dict=return_dict,
  699. )
  700. sequence_output = encoder_outputs[0]
  701. sequence_output = self.layernorm(sequence_output)
  702. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  703. if not return_dict:
  704. return (sequence_output, pooled_output) + encoder_outputs[1:]
  705. return BaseModelOutputWithPooling(
  706. last_hidden_state=sequence_output,
  707. pooler_output=pooled_output,
  708. hidden_states=encoder_outputs.hidden_states,
  709. attentions=encoder_outputs.attentions,
  710. )
  711. @auto_docstring
  712. class FlavaMultimodalModel(FlavaPreTrainedModel):
  713. config: FlavaMultimodalConfig
  714. # This override allows us to load FlavaMultimodalModel from FlavaModel/FlavaForPreTraining checkpoints.
  715. base_model_prefix = "flava.multimodal_model"
  716. main_input_name = "hidden_states"
  717. def __init__(self, config: FlavaMultimodalConfig, add_pooling_layer=True):
  718. r"""
  719. add_pooling_layer (bool, *optional*, defaults to `True`):
  720. Whether to add a pooling layer
  721. """
  722. super().__init__(config)
  723. self.config = config
  724. self.use_cls_token = self.config.use_cls_token
  725. if self.use_cls_token:
  726. self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  727. self.encoder = FlavaEncoder(config)
  728. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  729. self.pooler = FlavaPooler(config) if add_pooling_layer else None
  730. self.post_init()
  731. @auto_docstring
  732. def forward(
  733. self,
  734. hidden_states: torch.Tensor,
  735. attention_mask: torch.Tensor | None = None,
  736. output_attentions: bool | None = None,
  737. output_hidden_states: bool | None = None,
  738. return_dict: bool | None = None,
  739. **kwargs,
  740. ) -> tuple | BaseModelOutputWithPooling:
  741. r"""
  742. hidden_states (`torch.FloatTensor` of shape `(batch_size, image_num_patches + text_seq_len, hidden_size)`):
  743. The concatenated hidden states of unimodal encoders.
  744. """
  745. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  746. output_hidden_states = (
  747. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  748. )
  749. return_dict = return_dict if return_dict is not None else self.config.return_dict
  750. batch_size, seq_length, _ = hidden_states.size()
  751. if self.use_cls_token:
  752. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  753. hidden_states = torch.cat((cls_tokens, hidden_states), dim=1)
  754. seq_length += 1
  755. if attention_mask is None:
  756. attention_mask = torch.ones((batch_size, seq_length), device=hidden_states.device)
  757. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
  758. attention_mask,
  759. (batch_size, seq_length),
  760. )
  761. encoder_outputs = self.encoder(
  762. hidden_states,
  763. attention_mask=extended_attention_mask,
  764. output_attentions=output_attentions,
  765. output_hidden_states=output_hidden_states,
  766. return_dict=return_dict,
  767. )
  768. sequence_output = encoder_outputs[0]
  769. sequence_output = self.layernorm(sequence_output)
  770. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  771. if not return_dict:
  772. return (sequence_output, pooled_output) + encoder_outputs[1:]
  773. return BaseModelOutputWithPooling(
  774. last_hidden_state=sequence_output,
  775. pooler_output=pooled_output,
  776. hidden_states=encoder_outputs.hidden_states,
  777. attentions=encoder_outputs.attentions,
  778. )
  779. @auto_docstring
  780. class FlavaModel(FlavaPreTrainedModel):
  781. config: FlavaConfig
  782. def __init__(self, config: FlavaConfig):
  783. super().__init__(config)
  784. if not isinstance(config.text_config, FlavaTextConfig):
  785. raise TypeError(
  786. "config.text_config is expected to be of type FlavaTextConfig but is of type"
  787. f" {type(config.text_config)}."
  788. )
  789. if not isinstance(config.image_config, FlavaImageConfig):
  790. raise TypeError(
  791. "config.image_config is expected to be of type FlavaImageConfig but is of type"
  792. f" {type(config.image_config)}."
  793. )
  794. if not isinstance(config.multimodal_config, FlavaMultimodalConfig):
  795. raise TypeError(
  796. "config.multimodal_config is expected to be of type FlavaMultimodalConfig but "
  797. + f"is of type {type(config.multimodal_config)}."
  798. )
  799. text_config = config.text_config
  800. image_config = config.image_config
  801. multimodal_config = config.multimodal_config
  802. self.projection_dim = config.projection_dim
  803. self.text_hidden_size = text_config.hidden_size
  804. self.image_hidden_size = image_config.hidden_size
  805. self.mm_hidden_size = multimodal_config.hidden_size
  806. self.text_model = FlavaTextModel(text_config)
  807. self.image_model = FlavaImageModel(image_config)
  808. self.multimodal_model = FlavaMultimodalModel(multimodal_config)
  809. self.image_projection = nn.Linear(self.image_hidden_size, self.projection_dim)
  810. self.text_projection = nn.Linear(self.text_hidden_size, self.projection_dim)
  811. self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
  812. self.image_to_mm_projection = nn.Linear(self.image_hidden_size, self.mm_hidden_size)
  813. self.text_to_mm_projection = nn.Linear(self.text_hidden_size, self.mm_hidden_size)
  814. # Initialize weights and apply final processing
  815. self.post_init()
  816. @can_return_tuple
  817. @auto_docstring
  818. def get_text_features(
  819. self,
  820. input_ids: torch.Tensor,
  821. attention_mask: torch.Tensor | None = None,
  822. token_type_ids: torch.Tensor | None = None,
  823. position_ids: torch.Tensor | None = None,
  824. **kwargs: Unpack[TransformersKwargs],
  825. ) -> tuple | BaseModelOutputWithPooling:
  826. r"""
  827. input_ids (`torch.LongTensor` of shape `(batch_size, text_seq_length)`):
  828. Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
  829. [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
  830. IDs?](../glossary#input-ids)
  831. token_type_ids (`torch.LongTensor` of shape `(batch_size, text_seq_length)`, *optional*):
  832. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  833. 1]`:
  834. - 0 corresponds to a *sentence A* token,
  835. - 1 corresponds to a *sentence B* token.
  836. [What are token type IDs?](../glossary#token-type-ids)
  837. Examples:
  838. ```python
  839. >>> import torch
  840. >>> from transformers import AutoProcessor, FlavaModel
  841. >>> model = FlavaModel.from_pretrained("{0}")
  842. >>> processor = AutoProcessor.from_pretrained("{0}")
  843. >>> inputs = processor(
  844. ... text=["a photo of a cat", "a photo of a dog"], max_length=77, padding="max_length", return_tensors="pt"
  845. ... )
  846. >>> with torch.inference_mode():
  847. ... text_features = model.get_text_features(**inputs)
  848. ```
  849. """
  850. text_outputs: BaseModelOutputWithPooling = self.text_model(
  851. input_ids=input_ids,
  852. attention_mask=attention_mask,
  853. token_type_ids=token_type_ids,
  854. position_ids=position_ids,
  855. return_dict=True,
  856. **kwargs,
  857. )
  858. last_hidden_state = text_outputs.last_hidden_state
  859. text_outputs.pooler_output = self.text_projection(last_hidden_state)
  860. return text_outputs
  861. @can_return_tuple
  862. @auto_docstring
  863. def get_image_features(
  864. self,
  865. pixel_values: torch.Tensor,
  866. bool_masked_pos: torch.BoolTensor | None = None,
  867. interpolate_pos_encoding: bool | None = None,
  868. attention_mask: torch.Tensor | None = None,
  869. **kwargs: Unpack[TransformersKwargs],
  870. ) -> tuple | BaseModelOutputWithPooling:
  871. r"""
  872. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, image_num_patches)`):
  873. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  874. Examples:
  875. ```python
  876. >>> import torch
  877. >>> from transformers import AutoProcessor, FlavaModel
  878. >>> from transformers.image_utils import load_image
  879. >>> model = FlavaModel.from_pretrained("{0}")
  880. >>> processor = AutoProcessor.from_pretrained("{0}")
  881. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  882. >>> image = load_image(url)
  883. >>> inputs = processor(images=image, return_tensors="pt")
  884. >>> with torch.inference_mode():
  885. ... image_features = model.get_image_features(**inputs)
  886. ```
  887. """
  888. image_outputs: BaseModelOutputWithPooling = self.image_model(
  889. pixel_values=pixel_values,
  890. bool_masked_pos=bool_masked_pos,
  891. attention_mask=attention_mask,
  892. interpolate_pos_encoding=interpolate_pos_encoding,
  893. return_dict=True,
  894. **kwargs,
  895. )
  896. last_hidden_state = image_outputs.last_hidden_state
  897. image_outputs.pooler_output = self.image_projection(last_hidden_state)
  898. return image_outputs
  899. @auto_docstring
  900. def forward(
  901. self,
  902. input_ids: torch.LongTensor | None = None,
  903. pixel_values: torch.FloatTensor | None = None,
  904. attention_mask: torch.Tensor | None = None,
  905. token_type_ids: torch.Tensor | None = None,
  906. bool_masked_pos: torch.Tensor | None = None,
  907. position_ids: torch.LongTensor | None = None,
  908. image_attention_mask: torch.Tensor | None = None,
  909. skip_multimodal_encoder: bool | None = None,
  910. output_attentions: bool | None = None,
  911. output_hidden_states: bool = True,
  912. return_dict: bool | None = None,
  913. **kwargs,
  914. ) -> tuple | FlavaModelOutput:
  915. r"""
  916. input_ids (`torch.LongTensor` of shape `(batch_size, image_num_patches + text_seq_len)`):
  917. Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
  918. [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
  919. IDs?](../glossary#input-ids)
  920. token_type_ids (`torch.LongTensor` of shape `(batch_size, image_num_patches + text_seq_len)`, *optional*):
  921. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  922. 1]`:
  923. - 0 corresponds to a *sentence A* token,
  924. - 1 corresponds to a *sentence B* token.
  925. [What are token type IDs?](../glossary#token-type-ids)
  926. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, image_num_patches)`):
  927. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  928. image_attention_mask (`torch.Tensor` of shape `(batch_size, image_num_patches)`, *optional*):
  929. Mask to avoid performing attention on padding pixel values for image inputs. Mask values selected in `[0, 1]`:
  930. - 1 for pixel values that are real (i.e., **not masked**),
  931. - 0 for pixel values that are padding (i.e., **masked**).
  932. skip_multimodal_encoder (*bool*, *optional*):
  933. Skip any calculations for multimodal encoder. Useful if multimodal encoding is not going to be used.
  934. Examples:
  935. ```python
  936. >>> from PIL import Image
  937. >>> import httpx
  938. >>> from io import BytesIO
  939. >>> from transformers import AutoProcessor, FlavaModel
  940. >>> model = FlavaModel.from_pretrained("facebook/flava-full")
  941. >>> processor = AutoProcessor.from_pretrained("facebook/flava-full")
  942. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  943. >>> with httpx.stream("GET", url) as response:
  944. ... image = Image.open(BytesIO(response.read()))
  945. >>> inputs = processor(text=["a photo of a cat"], images=image, return_tensors="pt", padding=True)
  946. >>> outputs = model(**inputs)
  947. >>> image_embeddings = outputs.image_embeddings
  948. >>> text_embeddings = outputs.text_embeddings
  949. >>> multimodal_embeddings = outputs.multimodal_embeddings
  950. >>> outputs.image_embeddings.shape
  951. torch.Size([1, 197, 768])
  952. >>> text_embeddings.shape
  953. torch.Size([1, 7, 768])
  954. >>> multimodal_embeddings.shape
  955. torch.Size([1, 205, 768])
  956. ```
  957. """
  958. return_dict = return_dict if return_dict is not None else self.config.return_dict
  959. if not output_hidden_states:
  960. raise ValueError("FLAVA model requires hidden states to work. Please set `output_hidden_states=True`")
  961. image_embeddings = None
  962. image_states = None
  963. image_mm_projection = None
  964. image_output = None
  965. if pixel_values is not None:
  966. image_output = self.image_model(
  967. pixel_values=pixel_values,
  968. bool_masked_pos=bool_masked_pos,
  969. attention_mask=image_attention_mask,
  970. output_attentions=output_attentions,
  971. output_hidden_states=output_hidden_states,
  972. return_dict=return_dict,
  973. )
  974. image_embeddings, image_states = image_output[0], image_output[2]
  975. # Note that these states don't use final layernorm in the transformer model
  976. image_mm_projection = self.image_to_mm_projection(image_states[-1])
  977. text_embeddings = None
  978. text_states = None
  979. text_mm_projection = None
  980. text_output = None
  981. if input_ids is not None:
  982. text_output = self.text_model(
  983. input_ids=input_ids,
  984. attention_mask=attention_mask,
  985. position_ids=position_ids,
  986. token_type_ids=token_type_ids,
  987. output_attentions=output_attentions,
  988. output_hidden_states=output_hidden_states,
  989. return_dict=return_dict,
  990. )
  991. text_embeddings, text_states = text_output[0], text_output[2]
  992. # Note that these states don't use final layernorm in the transformer model
  993. text_mm_projection = self.text_to_mm_projection(text_states[-1])
  994. multimodal_embeddings = None
  995. multimodal_output = None
  996. if image_mm_projection is not None and text_mm_projection is not None and not skip_multimodal_encoder:
  997. if attention_mask is not None:
  998. batch_size, seq_len, _ = image_mm_projection.shape
  999. if self.multimodal_model.use_cls_token:
  1000. seq_len += 1
  1001. attention_mask_image = torch.ones(batch_size, seq_len, device=image_mm_projection.device)
  1002. attention_multimodal = torch.cat([attention_mask_image, attention_mask], dim=1)
  1003. else:
  1004. attention_multimodal = None
  1005. multimodal_input = torch.cat([image_mm_projection, text_mm_projection], dim=1)
  1006. multimodal_output = self.multimodal_model(
  1007. multimodal_input, attention_mask=attention_multimodal, return_dict=return_dict
  1008. )
  1009. multimodal_embeddings = multimodal_output[0]
  1010. if not return_dict:
  1011. return (
  1012. image_embeddings,
  1013. image_output,
  1014. text_embeddings,
  1015. text_output,
  1016. multimodal_embeddings,
  1017. multimodal_output,
  1018. )
  1019. return FlavaModelOutput(
  1020. image_embeddings=image_embeddings,
  1021. image_output=image_output,
  1022. text_embeddings=text_embeddings,
  1023. text_output=text_output,
  1024. multimodal_embeddings=multimodal_embeddings,
  1025. multimodal_output=multimodal_output,
  1026. )
  1027. class FlavaImageCodebookResPath(nn.Module):
  1028. def __init__(self, in_size: int, out_size: int, **kwargs):
  1029. super().__init__()
  1030. hid_size = out_size // 4
  1031. path = OrderedDict()
  1032. path["relu_1"] = nn.ReLU()
  1033. path["conv_1"] = nn.Conv2d(in_size, hid_size, kernel_size=3, padding=1)
  1034. path["relu_2"] = nn.ReLU()
  1035. path["conv_2"] = nn.Conv2d(hid_size, hid_size, kernel_size=3, padding=1)
  1036. path["relu_3"] = nn.ReLU()
  1037. path["conv_3"] = nn.Conv2d(hid_size, hid_size, kernel_size=3, padding=1)
  1038. path["relu_4"] = nn.ReLU()
  1039. path["conv_4"] = nn.Conv2d(hid_size, out_size, kernel_size=1, padding=0)
  1040. self.path = nn.Sequential(path)
  1041. def forward(self, x: torch.Tensor) -> torch.Tensor:
  1042. return self.path(x)
  1043. class FlavaImageCodebookBlock(nn.Module):
  1044. def __init__(self, in_size: int, out_size: int, num_layers: int, **kwargs):
  1045. super().__init__()
  1046. self.post_gain = 1 / (num_layers**2)
  1047. if in_size != out_size:
  1048. self.id_path = nn.Conv2d(in_size, out_size, kernel_size=1, padding=0)
  1049. else:
  1050. self.id_path = nn.Identity()
  1051. self.res_path = FlavaImageCodebookResPath(in_size, out_size)
  1052. def forward(self, x: torch.Tensor) -> torch.Tensor:
  1053. return self.id_path(x) + self.post_gain * self.res_path(x)
  1054. class FlavaImageCodebookLayerGroup(nn.Module):
  1055. def __init__(self, num_blocks: int, num_layers: int, in_size: int, out_size: int, use_pool: bool = True):
  1056. super().__init__()
  1057. blocks = OrderedDict()
  1058. for i in range(num_blocks):
  1059. if i == 0:
  1060. blocks[f"block_{i + 1}"] = FlavaImageCodebookBlock(in_size, out_size, num_layers)
  1061. else:
  1062. blocks[f"block_{i + 1}"] = FlavaImageCodebookBlock(out_size, out_size, num_layers)
  1063. if use_pool:
  1064. blocks["pool"] = nn.MaxPool2d(kernel_size=2)
  1065. self.group = nn.Sequential(blocks)
  1066. def forward(self, x: torch.Tensor) -> torch.Tensor:
  1067. return self.group(x)
  1068. # Inspired by DALLE Encoder in https://github.com/openai/DALL-E/blob/5be4b236bc3ade6943662354117a0e83752cc322/dall_e/encoder.py#L42
  1069. @auto_docstring(
  1070. custom_intro="""
  1071. The FLAVA's image codebook model inspired from DALL-E's original encoder. Outputs raw hidden states and can be used
  1072. to generate image tokens for an image based on DALL-E's vocab. Used to generate labels for MIM. Use
  1073. `get_codebook_indices` to get image tokens for an image.
  1074. """
  1075. )
  1076. class FlavaImageCodebook(FlavaPreTrainedModel):
  1077. base_model_prefix = "model"
  1078. config: FlavaImageCodebookConfig
  1079. main_input_name = "pixel_values"
  1080. input_modalities = ("image",)
  1081. supports_gradient_checkpointing = False
  1082. def __init__(
  1083. self,
  1084. config: FlavaImageCodebookConfig,
  1085. **kwargs: Any,
  1086. ):
  1087. super().__init__(config)
  1088. self.config = config
  1089. self.num_groups = config.num_groups
  1090. self.input_channels = config.input_channels
  1091. self.num_blocks_per_group = config.num_blocks_per_group
  1092. self.hidden_size = config.hidden_size
  1093. self.vocab_size = config.vocab_size
  1094. num_layers = self.num_groups * self.num_blocks_per_group
  1095. output_blocks = OrderedDict()
  1096. output_blocks["relu"] = nn.ReLU()
  1097. output_blocks["conv"] = nn.Conv2d(8 * self.hidden_size, self.vocab_size, kernel_size=1, padding=0)
  1098. blocks = OrderedDict()
  1099. blocks["input"] = nn.Conv2d(self.input_channels, 1 * self.hidden_size, kernel_size=7, padding=3)
  1100. blocks["group_1"] = FlavaImageCodebookLayerGroup(
  1101. self.num_blocks_per_group, num_layers, 1 * self.hidden_size, 1 * self.hidden_size
  1102. )
  1103. blocks["group_2"] = FlavaImageCodebookLayerGroup(
  1104. self.num_blocks_per_group, num_layers, 1 * self.hidden_size, 2 * self.hidden_size
  1105. )
  1106. blocks["group_3"] = FlavaImageCodebookLayerGroup(
  1107. self.num_blocks_per_group, num_layers, 2 * self.hidden_size, 4 * self.hidden_size
  1108. )
  1109. blocks["group_4"] = FlavaImageCodebookLayerGroup(
  1110. self.num_blocks_per_group, num_layers, 4 * self.hidden_size, 8 * self.hidden_size, use_pool=False
  1111. )
  1112. blocks["output"] = nn.Sequential(output_blocks)
  1113. self.blocks = nn.Sequential(blocks)
  1114. self.post_init()
  1115. if self.config.freeze:
  1116. for param in self.parameters():
  1117. param.requires_grad = False
  1118. def get_codebook_indices(self, pixel_values: torch.Tensor) -> torch.Tensor:
  1119. f"""
  1120. Args:
  1121. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  1122. Pixel values. Codebook pixel values can be obtained using [`AutoImageProcessor`] by passing
  1123. `return_codebook_pixels=True`. See [`FlavaImageProcessor.__call__`] for details.
  1124. Examples:
  1125. ```python
  1126. >>> from PIL import Image
  1127. >>> import httpx
  1128. >>> from io import BytesIO
  1129. >>> from transformers import AutoImageProcessor, FlavaImageCodebook
  1130. >>> model = FlavaImageCodebook.from_pretrained("{_CHECKPOINT_FOR_CODEBOOK_DOC}")
  1131. >>> image_processor = AutoImageProcessor.from_pretrained("{_CHECKPOINT_FOR_CODEBOOK_DOC}")
  1132. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1133. >>> with httpx.stream("GET", url) as response:
  1134. ... image = Image.open(BytesIO(response.read()))
  1135. >>> inputs = image_processor([image], return_codebook_pixels=True, return_tensors="pt")
  1136. >>> inputs = dict(pixel_values=inputs.codebook_pixel_values)
  1137. >>> outputs = model.get_codebook_indices(**inputs)
  1138. ```
  1139. """
  1140. z_logits = self.blocks(pixel_values)
  1141. return torch.argmax(z_logits, axis=1)
  1142. def get_codebook_probs(self, pixel_values: torch.Tensor) -> torch.Tensor:
  1143. z_logits = self.blocks(pixel_values)
  1144. return nn.Softmax(dim=1)(z_logits)
  1145. def forward(self, pixel_values: torch.FloatTensor, **kwargs) -> torch.Tensor:
  1146. f"""
  1147. Args:
  1148. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  1149. Pixel values. Codebook pixel values can be obtained using [`AutoImageProcessor`] by passing
  1150. `return_codebook_pixels=True`. See [`FlavaImageProcessor.__call__`] for details.
  1151. Examples:
  1152. ```python
  1153. >>> from PIL import Image
  1154. >>> import httpx
  1155. >>> from io import BytesIO
  1156. >>> from transformers import AutoImageProcessor, FlavaImageCodebook
  1157. >>> model = FlavaImageCodebook.from_pretrained("{_CHECKPOINT_FOR_CODEBOOK_DOC}")
  1158. >>> image_processor = AutoImageProcessor.from_pretrained("{_CHECKPOINT_FOR_CODEBOOK_DOC}")
  1159. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1160. >>> with httpx.stream("GET", url) as response:
  1161. ... image = Image.open(BytesIO(response.read()))
  1162. >>> inputs = image_processor([image], return_codebook_pixels=True, return_tensors="pt")
  1163. >>> inputs = dict(pixel_values=inputs.codebook_pixel_values)
  1164. >>> outputs = model(**inputs)
  1165. >>> print(outputs.shape)
  1166. (1, 196)
  1167. ```
  1168. """
  1169. if len(pixel_values.shape) != 4:
  1170. raise ValueError(f"input shape {pixel_values.shape} is not 4d")
  1171. if pixel_values.shape[1] != self.input_channels:
  1172. raise ValueError(f"input has {pixel_values.shape[1]} channels but model built for {self.input_channels}")
  1173. return self.blocks(pixel_values)
  1174. class FlavaPredictionHeadTransform(nn.Module):
  1175. def __init__(self, config):
  1176. super().__init__()
  1177. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  1178. if isinstance(config.hidden_act, str):
  1179. self.transform_act_fn = ACT2FN[config.hidden_act]
  1180. else:
  1181. self.transform_act_fn = config.hidden_act
  1182. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  1183. def forward(self, hidden_states):
  1184. hidden_states = self.dense(hidden_states)
  1185. hidden_states = self.transform_act_fn(hidden_states)
  1186. hidden_states = self.LayerNorm(hidden_states)
  1187. return hidden_states
  1188. class FlavaMaskedPredictionHead(nn.Module):
  1189. def __init__(self, config, weight=None):
  1190. super().__init__()
  1191. self.config = config
  1192. self.transform = FlavaPredictionHeadTransform(config)
  1193. self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
  1194. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  1195. if weight is not None:
  1196. self.decoder.weight = weight
  1197. def forward(self, x):
  1198. x = self.transform(x)
  1199. x = self.decoder(x)
  1200. return x
  1201. class FlavaITMHead(nn.Module):
  1202. def __init__(self, config):
  1203. super().__init__()
  1204. self.config = config
  1205. self.pooler = FlavaPooler(config)
  1206. self.seq_relationship = nn.Linear(config.hidden_size, 2)
  1207. def forward(self, x):
  1208. x = self.pooler(x)
  1209. x = self.seq_relationship(x)
  1210. return x
  1211. class FlavaGlobalContrastiveHead(nn.Module):
  1212. def __init__(self, config):
  1213. super().__init__()
  1214. self.config = config
  1215. self.global_backprop_contrastive = config.global_backprop_contrastive
  1216. def forward(self, image_embeddings, text_embeddings, logit_scale):
  1217. temperature = torch.exp(logit_scale)
  1218. if not torch.distributed.is_available() or not torch.distributed.is_initialized():
  1219. labels = torch.arange(image_embeddings.size(0), device=image_embeddings.device)
  1220. image_embeddings_all = [image_embeddings]
  1221. text_embeddings_all = [text_embeddings]
  1222. else:
  1223. local_batch_size = image_embeddings.size(0)
  1224. world_size = torch.distributed.get_world_size()
  1225. if self.global_backprop_contrastive:
  1226. # `torch.distributed.nn.functional.all_gather` does backprop on all active workers
  1227. # whereas `torch.distributed.all_gather` does only backpropagates on the current worker.
  1228. image_embeddings_all = torch.distributed.nn.functional.all_gather(image_embeddings)
  1229. text_embeddings_all = torch.distributed.nn.functional.all_gather(text_embeddings)
  1230. else:
  1231. image_embeddings_all = [torch.zeros_like(text_embeddings) for _ in range(world_size)]
  1232. text_embeddings_all = [torch.zeros_like(image_embeddings) for _ in range(world_size)]
  1233. torch.distributed.all_gather(image_embeddings_all, image_embeddings)
  1234. torch.distributed.all_gather(text_embeddings_all, text_embeddings)
  1235. labels = local_batch_size * torch.distributed.get_rank() + torch.arange(
  1236. local_batch_size, device=image_embeddings.device
  1237. )
  1238. image_embeddings_all = torch.cat(image_embeddings_all)
  1239. text_embeddings_all = torch.cat(text_embeddings_all)
  1240. logits_per_image = torch.matmul(image_embeddings, text_embeddings_all.transpose(0, 1)) * temperature
  1241. logits_per_text = torch.matmul(text_embeddings, image_embeddings_all.transpose(0, 1)) * temperature
  1242. return logits_per_image, logits_per_text, labels
  1243. @auto_docstring(
  1244. custom_intro="""
  1245. The FLAVA model for pretraining which outputs losses, embeddings, logits and transformer outputs.
  1246. """
  1247. )
  1248. class FlavaForPreTraining(FlavaPreTrainedModel):
  1249. # Those are linked to xxx.bias
  1250. _tied_weights_keys = {
  1251. "mmm_text_head.bias": "mmm_text_head.decoder.bias",
  1252. "mim_head.bias": "mim_head.decoder.bias",
  1253. "mlm_head.bias": "mlm_head.decoder.bias",
  1254. "mmm_image_head.bias": "mmm_image_head.decoder.bias",
  1255. }
  1256. def __init__(self, config: FlavaConfig, image_codebook: nn.Module | None = None):
  1257. r"""
  1258. image_codebook ([`nn.Module`]):
  1259. If passed, the image codebook will be set to this. Otherwise, it will be initialized using the
  1260. image_codebook_config defined in the config first as the first parameter.
  1261. """
  1262. super().__init__(config)
  1263. self.flava = FlavaModel(config)
  1264. self.image_codebook = image_codebook
  1265. if self.image_codebook is None and config.init_codebook:
  1266. self.image_codebook = FlavaImageCodebook(config.image_codebook_config)
  1267. # Levarage text and image encoder configs to create the masked
  1268. # head since it has the right vocab
  1269. self.mim_head = FlavaMaskedPredictionHead(config.image_config)
  1270. self.mlm_head = FlavaMaskedPredictionHead(config.text_config)
  1271. self.itm_head = FlavaITMHead(config)
  1272. self.mmm_image_head = FlavaMaskedPredictionHead(config.image_config)
  1273. self.mmm_text_head = FlavaMaskedPredictionHead(config.text_config)
  1274. self.global_contrastive_head = FlavaGlobalContrastiveHead(config)
  1275. self.image_vocab_size = config.image_config.vocab_size
  1276. self.text_vocab_size = config.text_config.vocab_size
  1277. self.mlm_weight = config.mlm_weight
  1278. self.mim_weight = config.mim_weight
  1279. self.global_contrastive_weight = config.global_contrastive_weight
  1280. self.ce_ignore_index = config.ce_ignore_index
  1281. self.itm_weight = config.itm_weight
  1282. self.mmm_image_weight = config.mmm_image_weight
  1283. self.mmm_text_weight = config.mmm_text_weight
  1284. self.skip_unmasked_multimodal_encoder = config.skip_unmasked_multimodal_encoder
  1285. self.post_init()
  1286. def _resize_to_2d(self, x: torch.Tensor):
  1287. if x.dim() > 2:
  1288. x = x.view(x.size(0), -1)
  1289. return x
  1290. @auto_docstring
  1291. def forward(
  1292. self,
  1293. input_ids: torch.LongTensor | None = None,
  1294. input_ids_masked: torch.LongTensor | None = None,
  1295. pixel_values: torch.FloatTensor | None = None,
  1296. codebook_pixel_values: torch.FloatTensor | None = None,
  1297. attention_mask: torch.Tensor | None = None,
  1298. token_type_ids: torch.Tensor | None = None,
  1299. bool_masked_pos: torch.Tensor | None = None,
  1300. position_ids: torch.LongTensor | None = None,
  1301. image_attention_mask: torch.Tensor | None = None,
  1302. skip_unmasked_multimodal_encoder: bool | None = None,
  1303. mlm_labels: torch.Tensor | None = None,
  1304. mim_labels: torch.Tensor | None = None,
  1305. itm_labels: torch.Tensor | None = None,
  1306. output_attentions: bool | None = None,
  1307. output_hidden_states: bool = True,
  1308. return_dict: bool | None = None,
  1309. return_loss: bool | None = None,
  1310. **kwargs,
  1311. ) -> tuple[torch.Tensor] | FlavaForPreTrainingOutput:
  1312. r"""
  1313. input_ids (`torch.LongTensor` of shape `(batch_size, text_seq_len)`):
  1314. Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
  1315. [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
  1316. IDs?](../glossary#input-ids)
  1317. input_ids_masked (`torch.LongTensor` of shape `(batch_size, text_seq_len)`):
  1318. Indices of input sequence tokens in the vocabulary. These ones are the masked version of the original task
  1319. to be used with MLM. Indices can be obtained using [`AutoTokenizer`] along with
  1320. [`DataCollatorForMaskedLanguageModeling`]. See [`PreTrainedTokenizer.encode`] and
  1321. [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids)
  1322. codebook_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_image_patches, patch_size, patch_size, 3)`, *optional*):
  1323. Pixel values for image patches that are used to compute the image codebook labels for masked image modeling.
  1324. token_type_ids (`torch.LongTensor` of shape `(batch_size, text_seq_len)`, *optional*):
  1325. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  1326. 1]`:
  1327. - 0 corresponds to a *sentence A* token,
  1328. - 1 corresponds to a *sentence B* token.
  1329. [What are token type IDs?](../glossary#token-type-ids)
  1330. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, image_num_patches)`):
  1331. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  1332. image_attention_mask (`torch.FloatTensor` of shape `(batch_size, image_num_patches)`, *optional*):
  1333. Mask to avoid performing attention on padding token indices specifically for images. Mask values selected
  1334. in `[0, 1]`:
  1335. - 1 for tokens that are **not masked**,
  1336. - 0 for tokens that are **masked**.
  1337. [What are attention masks?](../glossary#attention-mask)
  1338. skip_unmasked_multimodal_encoder (*bool*, *optional*):
  1339. Skip any calculations for multimodal encoder for unmasked inputs. FLAVA pretraining doesn't need unmasked
  1340. multimodal embeddings or outputs as of now.
  1341. mlm_labels (`torch.LongTensor` of shape `(batch_size, text_seq_len)`, *optional*):
  1342. Labels for computing the left-to-right language and multimodal masked modeling loss (next word prediction).
  1343. Indices should be in `[-100, 0, ..., text_config.vocab_size - 1]` (see `input_ids` docstring). Tokens with
  1344. indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0,
  1345. ..., text_config.vocab_size - 1]`.
  1346. mim_labels (`torch.LongTensor` of shape `(batch_size, image_num_patches)`, *optional*):
  1347. Labels for computing the image and multimodal masked modeling loss. Indices should be in `[-100, 0, ...,
  1348. image_config.vocab_size - 1]`. Tokens with indices set to `-100` are ignored (masked), the loss is only
  1349. computed for the tokens with labels in `[0, ..., image_config.vocab_size - 1]`. If not passed, they are
  1350. generated automatically using the image codebook assigned to the model. By default, it uses
  1351. [`FlavaImageCodebook`]. See [`FlavaImageCodebook`] to understand how to generate mim_labels.
  1352. itm_labels (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
  1353. Labels for computing the image-text matching loss. 0 means the pairs don't match and 1 means they match.
  1354. The pairs with 0 will be skipped for calculation of MMM and global contrastive losses as well.
  1355. return_loss (`bool`, *optional*, default to None):
  1356. Whether to return calculated loss or not.
  1357. Examples:
  1358. ```python
  1359. >>> from PIL import Image
  1360. >>> import httpx
  1361. >>> from io import BytesIO
  1362. >>> from transformers import FlavaForPreTraining, AutoProcessor
  1363. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1364. >>> with httpx.stream("GET", url) as response:
  1365. ... image = Image.open(BytesIO(response.read()))
  1366. >>> model = FlavaForPreTraining.from_pretrained("facebook/flava-full")
  1367. >>> processor = AutoProcessor.from_pretrained("facebook/flava-full")
  1368. >>> text = ["a photo of a cat"]
  1369. >>> inputs = processor(
  1370. ... images=[image],
  1371. ... text=text,
  1372. ... return_masks=True,
  1373. ... return_codebook_pixels=True,
  1374. ... padding=True,
  1375. ... max_length=77,
  1376. ... return_tensors="pt",
  1377. ... )
  1378. >>> output = model(**inputs)
  1379. ```
  1380. """
  1381. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1382. return_loss = return_loss if return_loss is not None else self.config.return_loss
  1383. skip_unmasked_multimodal_encoder = (
  1384. skip_unmasked_multimodal_encoder
  1385. if skip_unmasked_multimodal_encoder is not None
  1386. else self.skip_unmasked_multimodal_encoder
  1387. )
  1388. if input_ids_masked is None and input_ids is not None:
  1389. logger.warning(
  1390. "`input_ids_masked` isn't passed which means MLM loss won't be calculated correctlySetting it to"
  1391. " `input_ids` so that model can work. Please pass it if this is unintentional. This is usually OKAY if"
  1392. " you are doing inference on unmasked text..."
  1393. )
  1394. input_ids_masked = input_ids
  1395. flava_output = self.flava(
  1396. input_ids=input_ids,
  1397. pixel_values=pixel_values,
  1398. attention_mask=attention_mask,
  1399. token_type_ids=token_type_ids,
  1400. position_ids=position_ids,
  1401. image_attention_mask=image_attention_mask,
  1402. # Don't need unmasked multimodal embedding for anything so skip it
  1403. # NOTE: ITM uses masked version
  1404. skip_multimodal_encoder=skip_unmasked_multimodal_encoder,
  1405. output_attentions=output_attentions,
  1406. output_hidden_states=output_hidden_states,
  1407. # Pass true to have deterministic outputs
  1408. return_dict=True,
  1409. )
  1410. flava_masked_output = self.flava(
  1411. input_ids=input_ids_masked,
  1412. pixel_values=pixel_values,
  1413. attention_mask=attention_mask,
  1414. token_type_ids=token_type_ids,
  1415. image_attention_mask=image_attention_mask,
  1416. bool_masked_pos=bool_masked_pos,
  1417. output_attentions=output_attentions,
  1418. output_hidden_states=output_hidden_states,
  1419. return_dict=True,
  1420. )
  1421. pos_mask = None
  1422. image_embeddings = flava_output.image_embeddings
  1423. text_embeddings = flava_output.text_embeddings
  1424. image_masked_embeddings = flava_masked_output.image_embeddings
  1425. text_masked_embeddings = flava_masked_output.text_embeddings
  1426. multimodal_masked_embeddings = flava_masked_output.multimodal_embeddings
  1427. total_loss = mim_loss = mlm_loss = mmm_text_loss = mmm_image_loss = gc_loss = itm_loss = None
  1428. mim_logits = mlm_logits = mmm_text_logits = mmm_image_logits = None
  1429. itm_logits = logits_per_image = logits_per_text = None
  1430. # Calculate mim_labels if necessary from the image_codebook
  1431. if image_masked_embeddings is not None or multimodal_masked_embeddings is not None:
  1432. if mim_labels is None and return_loss:
  1433. if self.image_codebook is None:
  1434. raise RuntimeError(
  1435. "`return_loss` is set to True but the image codebook is not initialized and no `mim_labels` "
  1436. " have been passed. Reinstantiate the model with `init_codebook` set to True or "
  1437. "pass in your custom `mim_labels`"
  1438. )
  1439. if codebook_pixel_values is None:
  1440. raise ValueError(
  1441. "`codebook_pixel_value` are required to generate `mim_labels` if loss is expected. "
  1442. "Call `AutoProcessor` with `return_codebook_pixels` set to True"
  1443. )
  1444. mim_labels = self.image_codebook.get_codebook_indices(codebook_pixel_values)
  1445. # Unimodal MIM Loss
  1446. # If multimodal embeddings are present, we will calculate MMM loss
  1447. if self.mim_weight > 0 and image_masked_embeddings is not None and multimodal_masked_embeddings is None:
  1448. sequence_for_image = image_masked_embeddings
  1449. if mim_labels is not None:
  1450. mim_labels = self._resize_to_2d(mim_labels)
  1451. bool_masked_pos = self._resize_to_2d(bool_masked_pos)
  1452. mim_labels[bool_masked_pos.ne(True)] = self.ce_ignore_index
  1453. sequence_for_image = sequence_for_image[:, -mim_labels.size(1) :, :]
  1454. masked_tokens = mim_labels.ne(self.ce_ignore_index)
  1455. mim_labels_filtered = mim_labels[masked_tokens]
  1456. sequence_for_image = sequence_for_image[masked_tokens, :]
  1457. mim_logits = self.mim_head(sequence_for_image)
  1458. if return_loss:
  1459. mim_loss = nn.functional.cross_entropy(
  1460. mim_logits.view(-1, self.image_vocab_size), mim_labels_filtered.view(-1)
  1461. )
  1462. mim_loss *= self.mim_weight
  1463. else:
  1464. mim_logits = self.mim_head(sequence_for_image)
  1465. # Unimodal MLM Loss
  1466. if self.mlm_weight > 0 and text_masked_embeddings is not None and multimodal_masked_embeddings is None:
  1467. sequence_for_text = text_masked_embeddings
  1468. if mlm_labels is not None:
  1469. mlm_labels = self._resize_to_2d(mlm_labels)
  1470. sequence_for_text = sequence_for_text[:, -mlm_labels.size(1) :, :]
  1471. masked_tokens = mlm_labels.ne(self.ce_ignore_index)
  1472. mlm_labels_filtered = mlm_labels[masked_tokens]
  1473. sequence_for_text = sequence_for_text[masked_tokens, :]
  1474. mlm_logits = self.mlm_head(sequence_for_text)
  1475. if return_loss:
  1476. mlm_loss = nn.functional.cross_entropy(
  1477. mlm_logits.view(-1, self.text_vocab_size), mlm_labels_filtered.view(-1)
  1478. )
  1479. mlm_loss *= self.mlm_weight
  1480. else:
  1481. mlm_logits = self.mlm_head(sequence_for_text)
  1482. # ITM Loss
  1483. if self.itm_weight > 0 and multimodal_masked_embeddings is not None:
  1484. itm_logits = self.itm_head(multimodal_masked_embeddings)
  1485. if itm_labels is not None:
  1486. pos_pairs = itm_labels.ne(0)
  1487. pos_mask = torch.where(pos_pairs.any(), pos_pairs, pos_pairs.new([True]))
  1488. if return_loss:
  1489. itm_loss = nn.functional.cross_entropy(itm_logits, itm_labels)
  1490. itm_loss *= self.itm_weight
  1491. if multimodal_masked_embeddings is not None:
  1492. multimodal_masked_embeddings = multimodal_masked_embeddings[pos_mask]
  1493. if mlm_labels is not None:
  1494. mlm_labels = mlm_labels[pos_mask]
  1495. if mim_labels is not None:
  1496. mim_labels = mim_labels[pos_mask]
  1497. bool_masked_pos = bool_masked_pos[pos_mask]
  1498. # MMM Image Loss
  1499. if multimodal_masked_embeddings is not None and self.mmm_image_weight > 0:
  1500. sequence_for_image = multimodal_masked_embeddings
  1501. end_index = image_masked_embeddings.size(1) - 1
  1502. sequence_for_image = sequence_for_image[:, 2 : 2 + end_index, :]
  1503. if mim_labels is not None:
  1504. mim_labels = self._resize_to_2d(mim_labels)
  1505. bool_masked_pos = self._resize_to_2d(bool_masked_pos)
  1506. mim_labels[bool_masked_pos.ne(True)] = self.ce_ignore_index
  1507. masked_tokens = mim_labels.ne(self.ce_ignore_index)
  1508. mim_labels_filtered = mim_labels[masked_tokens]
  1509. sequence_for_image = sequence_for_image[masked_tokens, :]
  1510. mmm_image_logits = self.mmm_image_head(sequence_for_image)
  1511. if return_loss:
  1512. mmm_image_loss = nn.functional.cross_entropy(
  1513. mmm_image_logits.view(-1, self.image_vocab_size), mim_labels_filtered.view(-1)
  1514. )
  1515. mmm_image_loss *= self.mmm_image_weight
  1516. else:
  1517. mmm_image_logits = self.mmm_image_head(sequence_for_image)
  1518. # MMM Text Loss
  1519. if multimodal_masked_embeddings is not None and self.mmm_text_weight > 0:
  1520. sequence_for_text = multimodal_masked_embeddings
  1521. sequence_for_text = sequence_for_text[:, -text_masked_embeddings.size(1) :, :]
  1522. if mlm_labels is not None:
  1523. mlm_labels = self._resize_to_2d(mlm_labels)
  1524. masked_tokens = mlm_labels.ne(self.ce_ignore_index)
  1525. mlm_labels_filtered = mlm_labels[masked_tokens]
  1526. sequence_for_text = sequence_for_text[masked_tokens, :]
  1527. mmm_text_logits = self.mmm_text_head(sequence_for_text)
  1528. if return_loss:
  1529. mmm_text_loss = nn.functional.cross_entropy(
  1530. mmm_text_logits.view(-1, self.text_vocab_size), mlm_labels_filtered.view(-1)
  1531. )
  1532. mmm_text_loss *= self.mmm_text_weight
  1533. else:
  1534. mmm_text_logits = self.mmm_text_head(sequence_for_text)
  1535. # Global Contrastive Loss
  1536. if image_embeddings is not None and text_embeddings is not None and self.global_contrastive_weight > 0:
  1537. text_embedding = self.flava.text_projection(text_embeddings[:, 0, :])
  1538. text_embedding = nn.functional.normalize(text_embedding, dim=-1)
  1539. image_embedding = self.flava.image_projection(image_embeddings[:, 0, :])
  1540. image_embedding = nn.functional.normalize(image_embedding, dim=-1)
  1541. if self.training:
  1542. self.flava.logit_scale.data.clamp_(LOGIT_SCALE_CLAMP_MIN, LOGIT_SCALE_CLAMP_MAX)
  1543. logits_per_image, logits_per_text, gc_labels = self.global_contrastive_head(
  1544. image_embedding, text_embedding, self.flava.logit_scale
  1545. )
  1546. # Apply ITM negative mask if any
  1547. if pos_mask is not None:
  1548. logits_per_image = logits_per_image[pos_mask]
  1549. logits_per_text = logits_per_text[pos_mask]
  1550. gc_labels = gc_labels[pos_mask]
  1551. if return_loss:
  1552. gc_loss_image = nn.functional.cross_entropy(logits_per_image, gc_labels)
  1553. gc_loss_text = nn.functional.cross_entropy(logits_per_text, gc_labels)
  1554. gc_loss = (gc_loss_image + gc_loss_text) / 2
  1555. gc_loss *= self.global_contrastive_weight
  1556. flava_losses = FlavaLosses(
  1557. mim=mim_loss,
  1558. mlm=mlm_loss,
  1559. itm=itm_loss,
  1560. global_contrastive=gc_loss,
  1561. mmm_image=mmm_image_loss,
  1562. mmm_text=mmm_text_loss,
  1563. )
  1564. if return_loss and not flava_losses.all_none():
  1565. total_loss = sum(loss if loss is not None else 0 for loss in flava_losses.values())
  1566. if not return_dict:
  1567. output = (
  1568. image_embeddings,
  1569. flava_output.image_output.to_tuple() if flava_output.image_output is not None else None,
  1570. text_embeddings,
  1571. flava_output.text_output.to_tuple() if flava_output.text_output is not None else None,
  1572. flava_output.multimodal_embeddings,
  1573. flava_output.multimodal_output.to_tuple() if flava_output.multimodal_output is not None else None,
  1574. image_masked_embeddings,
  1575. flava_masked_output.image_output.to_tuple() if flava_masked_output.image_output is not None else None,
  1576. text_masked_embeddings,
  1577. flava_masked_output.text_output.to_tuple() if flava_masked_output.text_output is not None else None,
  1578. multimodal_masked_embeddings,
  1579. flava_masked_output.multimodal_output.to_tuple()
  1580. if flava_masked_output.multimodal_output is not None
  1581. else None,
  1582. mim_logits,
  1583. mlm_logits,
  1584. itm_logits,
  1585. logits_per_image,
  1586. logits_per_image,
  1587. mmm_image_logits,
  1588. mmm_text_logits,
  1589. )
  1590. if return_loss and not flava_losses.all_none():
  1591. output = (
  1592. total_loss,
  1593. flava_losses,
  1594. ) + output
  1595. # Filter None as transformer by default won't handle it
  1596. return tuple(x for x in output if x is None)
  1597. return FlavaForPreTrainingOutput(
  1598. loss=total_loss,
  1599. loss_info=flava_losses,
  1600. image_embeddings=image_embeddings,
  1601. image_output=flava_output.image_output,
  1602. text_embeddings=text_embeddings,
  1603. text_output=flava_output.text_output,
  1604. multimodal_embeddings=flava_output.multimodal_embeddings,
  1605. multimodal_output=flava_output.multimodal_output,
  1606. image_masked_embeddings=image_masked_embeddings,
  1607. image_masked_output=flava_masked_output.image_output,
  1608. text_masked_embeddings=text_masked_embeddings,
  1609. text_masked_output=flava_masked_output.text_output,
  1610. multimodal_masked_embeddings=multimodal_masked_embeddings,
  1611. multimodal_masked_output=flava_masked_output.multimodal_output,
  1612. mim_logits=mim_logits,
  1613. mlm_logits=mlm_logits,
  1614. itm_logits=itm_logits,
  1615. contrastive_logits_per_image=logits_per_image,
  1616. contrastive_logits_per_text=logits_per_text,
  1617. mmm_image_logits=mmm_image_logits,
  1618. mmm_text_logits=mmm_text_logits,
  1619. )
  1620. __all__ = [
  1621. "FlavaForPreTraining",
  1622. "FlavaImageCodebook",
  1623. "FlavaImageModel",
  1624. "FlavaModel",
  1625. "FlavaMultimodalModel",
  1626. "FlavaPreTrainedModel",
  1627. "FlavaTextModel",
  1628. ]