modeling_vit.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656
  1. # Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch ViT model."""
  15. import collections.abc
  16. import math
  17. from collections.abc import Callable
  18. import torch
  19. from torch import nn
  20. from ... import initialization as init
  21. from ...activations import ACT2FN
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_outputs import (
  24. BaseModelOutput,
  25. BaseModelOutputWithPooling,
  26. ImageClassifierOutput,
  27. MaskedImageModelingOutput,
  28. )
  29. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  30. from ...processing_utils import Unpack
  31. from ...utils import TransformersKwargs, auto_docstring, logging, torch_int
  32. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  33. from ...utils.output_capturing import capture_outputs
  34. from .configuration_vit import ViTConfig
# Module-level logger obtained from the package's `logging` utility, named after this module.
logger = logging.get_logger(__name__)
  36. class ViTEmbeddings(nn.Module):
  37. """
  38. Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
  39. """
  40. def __init__(self, config: ViTConfig, use_mask_token: bool = False):
  41. super().__init__()
  42. self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
  43. self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
  44. self.patch_embeddings = ViTPatchEmbeddings(config)
  45. num_patches = self.patch_embeddings.num_patches
  46. self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
  47. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  48. self.patch_size = config.patch_size
  49. self.config = config
  50. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  51. """
  52. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  53. images. This method is also adapted to support torch.jit tracing.
  54. Adapted from:
  55. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  56. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  57. """
  58. num_patches = embeddings.shape[1] - 1
  59. num_positions = self.position_embeddings.shape[1] - 1
  60. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  61. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  62. return self.position_embeddings
  63. class_pos_embed = self.position_embeddings[:, :1]
  64. patch_pos_embed = self.position_embeddings[:, 1:]
  65. dim = embeddings.shape[-1]
  66. new_height = height // self.patch_size
  67. new_width = width // self.patch_size
  68. sqrt_num_positions = torch_int(num_positions**0.5)
  69. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  70. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  71. patch_pos_embed = nn.functional.interpolate(
  72. patch_pos_embed,
  73. size=(new_height, new_width),
  74. mode="bicubic",
  75. align_corners=False,
  76. )
  77. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  78. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  79. def forward(
  80. self,
  81. pixel_values: torch.Tensor,
  82. bool_masked_pos: torch.BoolTensor | None = None,
  83. interpolate_pos_encoding: bool = False,
  84. ) -> torch.Tensor:
  85. batch_size, num_channels, height, width = pixel_values.shape
  86. embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  87. if bool_masked_pos is not None:
  88. seq_length = embeddings.shape[1]
  89. mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
  90. # replace the masked visual tokens by mask_tokens
  91. mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
  92. embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
  93. # add the [CLS] token to the embedded patch tokens
  94. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  95. embeddings = torch.cat((cls_tokens, embeddings), dim=1)
  96. # add positional encoding to each token
  97. if interpolate_pos_encoding:
  98. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  99. else:
  100. embeddings = embeddings + self.position_embeddings
  101. embeddings = self.dropout(embeddings)
  102. return embeddings
  103. class ViTPatchEmbeddings(nn.Module):
  104. """
  105. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  106. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  107. Transformer.
  108. """
  109. def __init__(self, config: ViTConfig):
  110. super().__init__()
  111. image_size, patch_size = config.image_size, config.patch_size
  112. num_channels, hidden_size = config.num_channels, config.hidden_size
  113. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  114. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  115. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  116. self.image_size = image_size
  117. self.patch_size = patch_size
  118. self.num_channels = num_channels
  119. self.num_patches = num_patches
  120. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
  121. def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
  122. batch_size, num_channels, height, width = pixel_values.shape
  123. if num_channels != self.num_channels:
  124. raise ValueError(
  125. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  126. f" Expected {self.num_channels} but got {num_channels}."
  127. )
  128. if not interpolate_pos_encoding:
  129. if height != self.image_size[0] or width != self.image_size[1]:
  130. raise ValueError(
  131. f"Input image size ({height}*{width}) doesn't match model"
  132. f" ({self.image_size[0]}*{self.image_size[1]})."
  133. )
  134. embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
  135. return embeddings
# Copied from transformers.models.bert.modeling_bert.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float | None = None,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Reference (pure-PyTorch) scaled dot-product attention.

    Args:
        module: the calling module; only its `training` flag is read (to decide whether dropout is active).
        query, key, value: per-head projections. They must be rank 4, because of `key.transpose(2, 3)` and
            `attn_output.transpose(1, 2)`. `ViTSelfAttention` passes them as (batch, num_heads, seq_len, head_dim).
        attention_mask: optional additive mask, added to the raw scores before softmax.
        scaling: multiplier for the raw scores; defaults to `query.size(-1) ** -0.5`.
        dropout: dropout probability applied to the attention weights (only while `module.training`).
        **kwargs: accepted and ignored here, so callers can pass backend-specific options (e.g. `is_causal`).

    Returns:
        `(attn_output, attn_weights)`. `attn_output` has dims 1 and 2 swapped relative to the
        `attn_weights @ value` product and is made contiguous. `attn_weights` are the post-softmax,
        post-dropout probabilities.
    """
    if scaling is None:
        scaling = query.size(-1) ** -0.5

    # Take the dot product between "query" and "key" to get the raw attention scores.
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, value)
    # Swap the head and sequence dims back so heads can be merged by the caller.
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
  158. class ViTSelfAttention(nn.Module):
  159. def __init__(self, config: ViTConfig):
  160. super().__init__()
  161. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  162. raise ValueError(
  163. f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
  164. f"heads {config.num_attention_heads}."
  165. )
  166. self.config = config
  167. self.num_attention_heads = config.num_attention_heads
  168. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  169. self.all_head_size = self.num_attention_heads * self.attention_head_size
  170. self.dropout_prob = config.attention_probs_dropout_prob
  171. self.scaling = self.attention_head_size**-0.5
  172. self.is_causal = False
  173. self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  174. self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  175. self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  176. def forward(
  177. self,
  178. hidden_states: torch.Tensor,
  179. **kwargs: Unpack[TransformersKwargs],
  180. ) -> tuple[torch.Tensor, torch.Tensor]:
  181. batch_size = hidden_states.shape[0]
  182. new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size
  183. key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2)
  184. value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2)
  185. query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2)
  186. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  187. self.config._attn_implementation, eager_attention_forward
  188. )
  189. context_layer, attention_probs = attention_interface(
  190. self,
  191. query_layer,
  192. key_layer,
  193. value_layer,
  194. None,
  195. is_causal=self.is_causal,
  196. scaling=self.scaling,
  197. dropout=0.0 if not self.training else self.dropout_prob,
  198. **kwargs,
  199. )
  200. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  201. context_layer = context_layer.reshape(new_context_layer_shape)
  202. return context_layer, attention_probs
  203. class ViTSelfOutput(nn.Module):
  204. """
  205. The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the
  206. layernorm applied before each block.
  207. """
  208. def __init__(self, config: ViTConfig):
  209. super().__init__()
  210. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  211. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  212. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  213. hidden_states = self.dense(hidden_states)
  214. hidden_states = self.dropout(hidden_states)
  215. return hidden_states
  216. class ViTAttention(nn.Module):
  217. def __init__(self, config: ViTConfig):
  218. super().__init__()
  219. self.attention = ViTSelfAttention(config)
  220. self.output = ViTSelfOutput(config)
  221. def forward(
  222. self,
  223. hidden_states: torch.Tensor,
  224. **kwargs: Unpack[TransformersKwargs],
  225. ) -> torch.Tensor:
  226. self_attn_output, _ = self.attention(hidden_states, **kwargs)
  227. output = self.output(self_attn_output, hidden_states)
  228. return output
  229. class ViTIntermediate(nn.Module):
  230. def __init__(self, config: ViTConfig):
  231. super().__init__()
  232. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  233. if isinstance(config.hidden_act, str):
  234. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  235. else:
  236. self.intermediate_act_fn = config.hidden_act
  237. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  238. hidden_states = self.dense(hidden_states)
  239. hidden_states = self.intermediate_act_fn(hidden_states)
  240. return hidden_states
  241. class ViTOutput(nn.Module):
  242. def __init__(self, config: ViTConfig):
  243. super().__init__()
  244. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  245. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  246. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  247. hidden_states = self.dense(hidden_states)
  248. hidden_states = self.dropout(hidden_states)
  249. hidden_states = hidden_states + input_tensor
  250. return hidden_states
  251. class ViTLayer(GradientCheckpointingLayer):
  252. """This corresponds to the Block class in the timm implementation."""
  253. def __init__(self, config: ViTConfig):
  254. super().__init__()
  255. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  256. self.seq_len_dim = 1
  257. self.attention = ViTAttention(config)
  258. self.intermediate = ViTIntermediate(config)
  259. self.output = ViTOutput(config)
  260. self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  261. self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  262. def forward(
  263. self,
  264. hidden_states: torch.Tensor,
  265. **kwargs: Unpack[TransformersKwargs],
  266. ) -> torch.Tensor:
  267. hidden_states_norm = self.layernorm_before(hidden_states)
  268. attention_output = self.attention(hidden_states_norm, **kwargs)
  269. # first residual connection
  270. hidden_states = attention_output + hidden_states
  271. # in ViT, layernorm is also applied after self-attention
  272. layer_output = self.layernorm_after(hidden_states)
  273. layer_output = self.intermediate(layer_output)
  274. # second residual connection is done here
  275. layer_output = self.output(layer_output, hidden_states)
  276. return layer_output
  277. class ViTEncoder(nn.Module):
  278. def __init__(self, config: ViTConfig):
  279. super().__init__()
  280. self.config = config
  281. self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)])
  282. self.gradient_checkpointing = False
  283. def forward(
  284. self,
  285. hidden_states: torch.Tensor,
  286. **kwargs: Unpack[TransformersKwargs],
  287. ) -> BaseModelOutput:
  288. for layer_module in self.layer:
  289. hidden_states = layer_module(hidden_states, **kwargs)
  290. return BaseModelOutput(last_hidden_state=hidden_states)
  291. @auto_docstring
  292. class ViTPreTrainedModel(PreTrainedModel):
  293. config: ViTConfig
  294. base_model_prefix = "vit"
  295. main_input_name = "pixel_values"
  296. input_modalities = ("image",)
  297. supports_gradient_checkpointing = True
  298. _no_split_modules = ["ViTEmbeddings", "ViTLayer"]
  299. _supports_sdpa = True
  300. _supports_flash_attn = True
  301. _supports_flex_attn = True
  302. _supports_attention_backend = True
  303. _can_record_outputs = {
  304. "hidden_states": ViTLayer,
  305. "attentions": ViTSelfAttention,
  306. }
  307. @torch.no_grad()
  308. def _init_weights(self, module: nn.Linear | nn.Conv2d | nn.LayerNorm):
  309. """Initialize the weights"""
  310. if isinstance(module, nn.Linear | nn.Conv2d):
  311. init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  312. if module.bias is not None:
  313. init.zeros_(module.bias)
  314. elif isinstance(module, nn.LayerNorm):
  315. init.zeros_(module.bias)
  316. init.ones_(module.weight)
  317. elif isinstance(module, ViTEmbeddings):
  318. init.trunc_normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range)
  319. init.trunc_normal_(module.cls_token, mean=0.0, std=self.config.initializer_range)
  320. if module.mask_token is not None:
  321. init.zeros_(module.mask_token)
  322. @auto_docstring
  323. class ViTModel(ViTPreTrainedModel):
  324. def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False):
  325. r"""
  326. add_pooling_layer (bool, *optional*, defaults to `True`):
  327. Whether to add a pooling layer
  328. use_mask_token (`bool`, *optional*, defaults to `False`):
  329. Whether to use a mask token for masked image modeling.
  330. """
  331. super().__init__(config)
  332. self.config = config
  333. self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token)
  334. self.encoder = ViTEncoder(config)
  335. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  336. self.pooler = ViTPooler(config) if add_pooling_layer else None
  337. # Initialize weights and apply final processing
  338. self.post_init()
  339. def get_input_embeddings(self) -> ViTPatchEmbeddings:
  340. return self.embeddings.patch_embeddings
  341. @merge_with_config_defaults
  342. @capture_outputs(tie_last_hidden_states=False)
  343. @auto_docstring
  344. def forward(
  345. self,
  346. pixel_values: torch.Tensor | None = None,
  347. bool_masked_pos: torch.BoolTensor | None = None,
  348. interpolate_pos_encoding: bool | None = None,
  349. **kwargs: Unpack[TransformersKwargs],
  350. ) -> BaseModelOutputWithPooling:
  351. r"""
  352. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
  353. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  354. """
  355. if pixel_values is None:
  356. raise ValueError("You have to specify pixel_values")
  357. # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
  358. expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
  359. if pixel_values.dtype != expected_dtype:
  360. pixel_values = pixel_values.to(expected_dtype)
  361. embedding_output = self.embeddings(
  362. pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
  363. )
  364. encoder_outputs: BaseModelOutput = self.encoder(embedding_output)
  365. sequence_output = encoder_outputs.last_hidden_state
  366. sequence_output = self.layernorm(sequence_output)
  367. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  368. return BaseModelOutputWithPooling(last_hidden_state=sequence_output, pooler_output=pooled_output)
  369. class ViTPooler(nn.Module):
  370. def __init__(self, config: ViTConfig):
  371. super().__init__()
  372. self.dense = nn.Linear(config.hidden_size, config.pooler_output_size)
  373. self.activation = ACT2FN[config.pooler_act]
  374. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  375. # We "pool" the model by simply taking the hidden state corresponding
  376. # to the first token.
  377. first_token_tensor = hidden_states[:, 0]
  378. pooled_output = self.dense(first_token_tensor)
  379. pooled_output = self.activation(pooled_output)
  380. return pooled_output
  381. @auto_docstring(
  382. custom_intro="""
  383. ViT Model with a decoder on top for masked image modeling, as proposed in [SimMIM](https://huggingface.co/papers/2111.09886).
  384. <Tip>
  385. Note that we provide a script to pre-train this model on custom data in our [examples
  386. directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).
  387. </Tip>
  388. """
  389. )
  390. class ViTForMaskedImageModeling(ViTPreTrainedModel):
  391. def __init__(self, config: ViTConfig):
  392. super().__init__(config)
  393. self.vit = ViTModel(config, add_pooling_layer=False, use_mask_token=True)
  394. self.decoder = nn.Sequential(
  395. nn.Conv2d(
  396. in_channels=config.hidden_size,
  397. out_channels=config.encoder_stride**2 * config.num_channels,
  398. kernel_size=1,
  399. ),
  400. nn.PixelShuffle(config.encoder_stride),
  401. )
  402. # Initialize weights and apply final processing
  403. self.post_init()
  404. @can_return_tuple
  405. @auto_docstring
  406. def forward(
  407. self,
  408. pixel_values: torch.Tensor | None = None,
  409. bool_masked_pos: torch.BoolTensor | None = None,
  410. interpolate_pos_encoding: bool | None = None,
  411. **kwargs: Unpack[TransformersKwargs],
  412. ) -> MaskedImageModelingOutput:
  413. r"""
  414. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
  415. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  416. Examples:
  417. ```python
  418. >>> from transformers import AutoImageProcessor, ViTForMaskedImageModeling
  419. >>> import torch
  420. >>> from PIL import Image
  421. >>> import httpx
  422. >>> from io import BytesIO
  423. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  424. >>> with httpx.stream("GET", url) as response:
  425. ... image = Image.open(BytesIO(response.read()))
  426. >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
  427. >>> model = ViTForMaskedImageModeling.from_pretrained("google/vit-base-patch16-224-in21k")
  428. >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
  429. >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
  430. >>> # create random boolean mask of shape (batch_size, num_patches)
  431. >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
  432. >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
  433. >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
  434. >>> list(reconstructed_pixel_values.shape)
  435. [1, 3, 224, 224]
  436. ```"""
  437. if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride):
  438. raise ValueError(
  439. "When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that "
  440. "the reconstructed image has the same dimensions as the input. "
  441. f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}."
  442. )
  443. outputs: BaseModelOutputWithPooling = self.vit(
  444. pixel_values,
  445. bool_masked_pos=bool_masked_pos,
  446. interpolate_pos_encoding=interpolate_pos_encoding,
  447. **kwargs,
  448. )
  449. sequence_output = outputs.last_hidden_state
  450. # Reshape to (batch_size, num_channels, height, width)
  451. sequence_output = sequence_output[:, 1:]
  452. batch_size, sequence_length, num_channels = sequence_output.shape
  453. height = width = math.floor(sequence_length**0.5)
  454. sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
  455. # Reconstruct pixel values
  456. reconstructed_pixel_values = self.decoder(sequence_output)
  457. masked_im_loss = None
  458. if bool_masked_pos is not None:
  459. size = self.config.image_size // self.config.patch_size
  460. bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
  461. mask = (
  462. bool_masked_pos.repeat_interleave(self.config.patch_size, 1)
  463. .repeat_interleave(self.config.patch_size, 2)
  464. .unsqueeze(1)
  465. .contiguous()
  466. )
  467. reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
  468. masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels
  469. return MaskedImageModelingOutput(
  470. loss=masked_im_loss,
  471. reconstruction=reconstructed_pixel_values,
  472. hidden_states=outputs.hidden_states,
  473. attentions=outputs.attentions,
  474. )
  475. @auto_docstring(
  476. custom_intro="""
  477. ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
  478. the [CLS] token) e.g. for ImageNet.
  479. <Tip>
  480. Note that it's possible to fine-tune ViT on higher resolution images than the ones it has been trained on, by
  481. setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
  482. position embeddings to the higher resolution.
  483. </Tip>
  484. """
  485. )
  486. class ViTForImageClassification(ViTPreTrainedModel):
  487. def __init__(self, config: ViTConfig):
  488. super().__init__(config)
  489. self.num_labels = config.num_labels
  490. self.vit = ViTModel(config, add_pooling_layer=False)
  491. # Classifier head
  492. self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
  493. # Initialize weights and apply final processing
  494. self.post_init()
  495. @can_return_tuple
  496. @auto_docstring
  497. def forward(
  498. self,
  499. pixel_values: torch.Tensor | None = None,
  500. labels: torch.Tensor | None = None,
  501. interpolate_pos_encoding: bool | None = None,
  502. **kwargs: Unpack[TransformersKwargs],
  503. ) -> ImageClassifierOutput:
  504. r"""
  505. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  506. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  507. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  508. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  509. """
  510. outputs: BaseModelOutputWithPooling = self.vit(
  511. pixel_values,
  512. interpolate_pos_encoding=interpolate_pos_encoding,
  513. **kwargs,
  514. )
  515. sequence_output = outputs.last_hidden_state
  516. pooled_output = sequence_output[:, 0, :]
  517. logits = self.classifier(pooled_output)
  518. loss = None
  519. if labels is not None:
  520. loss = self.loss_function(labels, logits, self.config, **kwargs)
  521. return ImageClassifierOutput(
  522. loss=loss,
  523. logits=logits,
  524. hidden_states=outputs.hidden_states,
  525. attentions=outputs.attentions,
  526. )
  527. __all__ = ["ViTForImageClassification", "ViTForMaskedImageModeling", "ViTModel", "ViTPreTrainedModel"]