modeling_deit.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765
  1. # Copyright 2021 Facebook AI Research (FAIR), Ross Wightman, The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch DeiT model."""
  15. import collections.abc
  16. from collections.abc import Callable
  17. from dataclasses import dataclass
  18. import torch
  19. from torch import nn
  20. from ... import initialization as init
  21. from ...activations import ACT2FN
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_outputs import (
  24. BaseModelOutput,
  25. BaseModelOutputWithPooling,
  26. ImageClassifierOutput,
  27. MaskedImageModelingOutput,
  28. )
  29. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  30. from ...processing_utils import Unpack
  31. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, logging, torch_int
  32. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  33. from ...utils.output_capturing import capture_outputs
  34. from .configuration_deit import DeiTConfig
# Module-level logger, keyed on this module's import path.
logger = logging.get_logger(__name__)
  36. class DeiTEmbeddings(nn.Module):
  37. """
  38. Construct the CLS token, distillation token, position and patch embeddings. Optionally, also the mask token.
  39. """
  40. def __init__(self, config: DeiTConfig, use_mask_token: bool = False) -> None:
  41. super().__init__()
  42. self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  43. self.distillation_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  44. self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
  45. self.patch_embeddings = DeiTPatchEmbeddings(config)
  46. num_patches = self.patch_embeddings.num_patches
  47. self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size))
  48. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  49. self.patch_size = config.patch_size
  50. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  51. """
  52. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  53. images. This method is also adapted to support torch.jit tracing and 2 class embeddings.
  54. Adapted from:
  55. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  56. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  57. """
  58. num_patches = embeddings.shape[1] - 2
  59. num_positions = self.position_embeddings.shape[1] - 2
  60. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  61. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  62. return self.position_embeddings
  63. class_and_dist_pos_embed = self.position_embeddings[:, :2]
  64. patch_pos_embed = self.position_embeddings[:, 2:]
  65. dim = embeddings.shape[-1]
  66. new_height = height // self.patch_size
  67. new_width = width // self.patch_size
  68. sqrt_num_positions = torch_int(num_positions**0.5)
  69. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  70. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  71. patch_pos_embed = nn.functional.interpolate(
  72. patch_pos_embed,
  73. size=(new_height, new_width),
  74. mode="bicubic",
  75. align_corners=False,
  76. )
  77. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  78. return torch.cat((class_and_dist_pos_embed, patch_pos_embed), dim=1)
  79. def forward(
  80. self,
  81. pixel_values: torch.Tensor,
  82. bool_masked_pos: torch.BoolTensor | None = None,
  83. interpolate_pos_encoding: bool = False,
  84. ) -> torch.Tensor:
  85. _, _, height, width = pixel_values.shape
  86. embeddings = self.patch_embeddings(pixel_values)
  87. batch_size, seq_length, _ = embeddings.size()
  88. if bool_masked_pos is not None:
  89. mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
  90. # replace the masked visual tokens by mask_tokens
  91. mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
  92. embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
  93. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  94. distillation_tokens = self.distillation_token.expand(batch_size, -1, -1)
  95. embeddings = torch.cat((cls_tokens, distillation_tokens, embeddings), dim=1)
  96. position_embedding = self.position_embeddings
  97. if interpolate_pos_encoding:
  98. position_embedding = self.interpolate_pos_encoding(embeddings, height, width)
  99. embeddings = embeddings + position_embedding
  100. embeddings = self.dropout(embeddings)
  101. return embeddings
  102. class DeiTPatchEmbeddings(nn.Module):
  103. """
  104. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  105. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  106. Transformer.
  107. """
  108. def __init__(self, config):
  109. super().__init__()
  110. image_size, patch_size = config.image_size, config.patch_size
  111. num_channels, hidden_size = config.num_channels, config.hidden_size
  112. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  113. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  114. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  115. self.image_size = image_size
  116. self.patch_size = patch_size
  117. self.num_channels = num_channels
  118. self.num_patches = num_patches
  119. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
  120. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  121. batch_size, num_channels, height, width = pixel_values.shape
  122. if num_channels != self.num_channels:
  123. raise ValueError(
  124. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  125. )
  126. x = self.projection(pixel_values).flatten(2).transpose(1, 2)
  127. return x
  128. # Copied from transformers.models.bert.modeling_bert.eager_attention_forward
  129. def eager_attention_forward(
  130. module: nn.Module,
  131. query: torch.Tensor,
  132. key: torch.Tensor,
  133. value: torch.Tensor,
  134. attention_mask: torch.Tensor | None,
  135. scaling: float | None = None,
  136. dropout: float = 0.0,
  137. **kwargs: Unpack[TransformersKwargs],
  138. ):
  139. if scaling is None:
  140. scaling = query.size(-1) ** -0.5
  141. # Take the dot product between "query" and "key" to get the raw attention scores.
  142. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  143. if attention_mask is not None:
  144. attn_weights = attn_weights + attention_mask
  145. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  146. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  147. attn_output = torch.matmul(attn_weights, value)
  148. attn_output = attn_output.transpose(1, 2).contiguous()
  149. return attn_output, attn_weights
  150. # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->DeiT
  151. class DeiTSelfAttention(nn.Module):
  152. def __init__(self, config: DeiTConfig):
  153. super().__init__()
  154. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  155. raise ValueError(
  156. f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
  157. f"heads {config.num_attention_heads}."
  158. )
  159. self.config = config
  160. self.num_attention_heads = config.num_attention_heads
  161. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  162. self.all_head_size = self.num_attention_heads * self.attention_head_size
  163. self.dropout_prob = config.attention_probs_dropout_prob
  164. self.scaling = self.attention_head_size**-0.5
  165. self.is_causal = False
  166. self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  167. self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  168. self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  169. def forward(
  170. self,
  171. hidden_states: torch.Tensor,
  172. **kwargs: Unpack[TransformersKwargs],
  173. ) -> tuple[torch.Tensor, torch.Tensor]:
  174. batch_size = hidden_states.shape[0]
  175. new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size
  176. key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2)
  177. value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2)
  178. query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2)
  179. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  180. self.config._attn_implementation, eager_attention_forward
  181. )
  182. context_layer, attention_probs = attention_interface(
  183. self,
  184. query_layer,
  185. key_layer,
  186. value_layer,
  187. None,
  188. is_causal=self.is_causal,
  189. scaling=self.scaling,
  190. dropout=0.0 if not self.training else self.dropout_prob,
  191. **kwargs,
  192. )
  193. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  194. context_layer = context_layer.reshape(new_context_layer_shape)
  195. return context_layer, attention_probs
  196. # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->DeiT
  197. class DeiTSelfOutput(nn.Module):
  198. """
  199. The residual connection is defined in DeiTLayer instead of here (as is the case with other models), due to the
  200. layernorm applied before each block.
  201. """
  202. def __init__(self, config: DeiTConfig):
  203. super().__init__()
  204. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  205. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  206. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  207. hidden_states = self.dense(hidden_states)
  208. hidden_states = self.dropout(hidden_states)
  209. return hidden_states
  210. # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->DeiT
  211. class DeiTAttention(nn.Module):
  212. def __init__(self, config: DeiTConfig):
  213. super().__init__()
  214. self.attention = DeiTSelfAttention(config)
  215. self.output = DeiTSelfOutput(config)
  216. def forward(
  217. self,
  218. hidden_states: torch.Tensor,
  219. **kwargs: Unpack[TransformersKwargs],
  220. ) -> torch.Tensor:
  221. self_attn_output, _ = self.attention(hidden_states, **kwargs)
  222. output = self.output(self_attn_output, hidden_states)
  223. return output
  224. # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->DeiT
  225. class DeiTIntermediate(nn.Module):
  226. def __init__(self, config: DeiTConfig):
  227. super().__init__()
  228. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  229. if isinstance(config.hidden_act, str):
  230. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  231. else:
  232. self.intermediate_act_fn = config.hidden_act
  233. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  234. hidden_states = self.dense(hidden_states)
  235. hidden_states = self.intermediate_act_fn(hidden_states)
  236. return hidden_states
  237. # Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->DeiT
  238. class DeiTOutput(nn.Module):
  239. def __init__(self, config: DeiTConfig):
  240. super().__init__()
  241. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  242. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  243. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  244. hidden_states = self.dense(hidden_states)
  245. hidden_states = self.dropout(hidden_states)
  246. hidden_states = hidden_states + input_tensor
  247. return hidden_states
  248. # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->DeiT,VIT->DEIT
  249. class DeiTLayer(GradientCheckpointingLayer):
  250. """This corresponds to the Block class in the timm implementation."""
  251. def __init__(self, config: DeiTConfig):
  252. super().__init__()
  253. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  254. self.seq_len_dim = 1
  255. self.attention = DeiTAttention(config)
  256. self.intermediate = DeiTIntermediate(config)
  257. self.output = DeiTOutput(config)
  258. self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  259. self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  260. def forward(
  261. self,
  262. hidden_states: torch.Tensor,
  263. **kwargs: Unpack[TransformersKwargs],
  264. ) -> torch.Tensor:
  265. hidden_states_norm = self.layernorm_before(hidden_states)
  266. attention_output = self.attention(hidden_states_norm, **kwargs)
  267. # first residual connection
  268. hidden_states = attention_output + hidden_states
  269. # in DeiT, layernorm is also applied after self-attention
  270. layer_output = self.layernorm_after(hidden_states)
  271. layer_output = self.intermediate(layer_output)
  272. # second residual connection is done here
  273. layer_output = self.output(layer_output, hidden_states)
  274. return layer_output
  275. # Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->DeiT
  276. class DeiTEncoder(nn.Module):
  277. def __init__(self, config: DeiTConfig):
  278. super().__init__()
  279. self.config = config
  280. self.layer = nn.ModuleList([DeiTLayer(config) for _ in range(config.num_hidden_layers)])
  281. self.gradient_checkpointing = False
  282. def forward(
  283. self,
  284. hidden_states: torch.Tensor,
  285. **kwargs: Unpack[TransformersKwargs],
  286. ) -> BaseModelOutput:
  287. for layer_module in self.layer:
  288. hidden_states = layer_module(hidden_states, **kwargs)
  289. return BaseModelOutput(last_hidden_state=hidden_states)
  290. @auto_docstring
  291. class DeiTPreTrainedModel(PreTrainedModel):
  292. config: DeiTConfig
  293. base_model_prefix = "deit"
  294. main_input_name = "pixel_values"
  295. input_modalities = ("image",)
  296. supports_gradient_checkpointing = True
  297. _no_split_modules = ["DeiTLayer"]
  298. _supports_sdpa = True
  299. _supports_flash_attn = True
  300. _supports_flex_attn = True
  301. _supports_attention_backend = True
  302. _can_record_outputs = {
  303. "hidden_states": DeiTLayer,
  304. "attentions": DeiTSelfAttention,
  305. }
  306. @torch.no_grad()
  307. def _init_weights(self, module: nn.Linear | nn.Conv2d | nn.LayerNorm) -> None:
  308. """Initialize the weights"""
  309. if isinstance(module, nn.Linear | nn.Conv2d):
  310. init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  311. if module.bias is not None:
  312. init.zeros_(module.bias)
  313. elif isinstance(module, nn.LayerNorm):
  314. init.zeros_(module.bias)
  315. init.ones_(module.weight)
  316. elif isinstance(module, DeiTEmbeddings):
  317. init.zeros_(module.cls_token)
  318. init.zeros_(module.position_embeddings)
  319. init.zeros_(module.distillation_token)
  320. if module.mask_token is not None:
  321. init.zeros_(module.mask_token)
  322. @auto_docstring
  323. class DeiTModel(DeiTPreTrainedModel):
  324. def __init__(self, config: DeiTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False) -> None:
  325. r"""
  326. add_pooling_layer (bool, *optional*, defaults to `True`):
  327. Whether to add a pooling layer
  328. use_mask_token (`bool`, *optional*, defaults to `False`):
  329. Whether to use a mask token for masked image modeling.
  330. """
  331. super().__init__(config)
  332. self.config = config
  333. self.embeddings = DeiTEmbeddings(config, use_mask_token=use_mask_token)
  334. self.encoder = DeiTEncoder(config)
  335. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  336. self.pooler = DeiTPooler(config) if add_pooling_layer else None
  337. # Initialize weights and apply final processing
  338. self.post_init()
  339. def get_input_embeddings(self) -> DeiTPatchEmbeddings:
  340. return self.embeddings.patch_embeddings
  341. @merge_with_config_defaults
  342. @capture_outputs(tie_last_hidden_states=False)
  343. @auto_docstring
  344. def forward(
  345. self,
  346. pixel_values: torch.Tensor | None = None,
  347. bool_masked_pos: torch.BoolTensor | None = None,
  348. interpolate_pos_encoding: bool = False,
  349. **kwargs: Unpack[TransformersKwargs],
  350. ) -> BaseModelOutputWithPooling:
  351. r"""
  352. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
  353. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  354. """
  355. if pixel_values is None:
  356. raise ValueError("You have to specify pixel_values")
  357. # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
  358. expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
  359. if pixel_values.dtype != expected_dtype:
  360. pixel_values = pixel_values.to(expected_dtype)
  361. embedding_output = self.embeddings(
  362. pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
  363. )
  364. encoder_outputs: BaseModelOutput = self.encoder(embedding_output)
  365. sequence_output = encoder_outputs.last_hidden_state
  366. sequence_output = self.layernorm(sequence_output)
  367. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  368. return BaseModelOutputWithPooling(
  369. last_hidden_state=sequence_output,
  370. pooler_output=pooled_output,
  371. )
  372. # Copied from transformers.models.vit.modeling_vit.ViTPooler with ViT->DeiT
  373. class DeiTPooler(nn.Module):
  374. def __init__(self, config: DeiTConfig):
  375. super().__init__()
  376. self.dense = nn.Linear(config.hidden_size, config.pooler_output_size)
  377. self.activation = ACT2FN[config.pooler_act]
  378. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  379. # We "pool" the model by simply taking the hidden state corresponding
  380. # to the first token.
  381. first_token_tensor = hidden_states[:, 0]
  382. pooled_output = self.dense(first_token_tensor)
  383. pooled_output = self.activation(pooled_output)
  384. return pooled_output
  385. @auto_docstring(
  386. custom_intro="""
  387. DeiT Model with a decoder on top for masked image modeling, as proposed in [SimMIM](https://huggingface.co/papers/2111.09886).
  388. <Tip>
  389. Note that we provide a script to pre-train this model on custom data in our [examples
  390. directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).
  391. </Tip>
  392. """
  393. )
  394. class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
  395. def __init__(self, config: DeiTConfig) -> None:
  396. super().__init__(config)
  397. self.deit = DeiTModel(config, add_pooling_layer=False, use_mask_token=True)
  398. self.decoder = nn.Sequential(
  399. nn.Conv2d(
  400. in_channels=config.hidden_size,
  401. out_channels=config.encoder_stride**2 * config.num_channels,
  402. kernel_size=1,
  403. ),
  404. nn.PixelShuffle(config.encoder_stride),
  405. )
  406. # Initialize weights and apply final processing
  407. self.post_init()
  408. @can_return_tuple
  409. @auto_docstring
  410. def forward(
  411. self,
  412. pixel_values: torch.Tensor | None = None,
  413. bool_masked_pos: torch.BoolTensor | None = None,
  414. interpolate_pos_encoding: bool = False,
  415. **kwargs: Unpack[TransformersKwargs],
  416. ) -> MaskedImageModelingOutput:
  417. r"""
  418. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
  419. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  420. Examples:
  421. ```python
  422. >>> from transformers import AutoImageProcessor, DeiTForMaskedImageModeling
  423. >>> import torch
  424. >>> from PIL import Image
  425. >>> import httpx
  426. >>> from io import BytesIO
  427. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  428. >>> with httpx.stream("GET", url) as response:
  429. ... image = Image.open(BytesIO(response.read()))
  430. >>> image_processor = AutoImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224")
  431. >>> model = DeiTForMaskedImageModeling.from_pretrained("facebook/deit-base-distilled-patch16-224")
  432. >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
  433. >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
  434. >>> # create random boolean mask of shape (batch_size, num_patches)
  435. >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
  436. >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
  437. >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
  438. >>> list(reconstructed_pixel_values.shape)
  439. [1, 3, 224, 224]
  440. ```"""
  441. outputs: BaseModelOutputWithPooling = self.deit(
  442. pixel_values,
  443. bool_masked_pos=bool_masked_pos,
  444. interpolate_pos_encoding=interpolate_pos_encoding,
  445. **kwargs,
  446. )
  447. sequence_output = outputs.last_hidden_state
  448. # Reshape to (batch_size, num_channels, height, width)
  449. sequence_output = sequence_output[:, 1:-1]
  450. batch_size, sequence_length, num_channels = sequence_output.shape
  451. height = width = int(sequence_length**0.5)
  452. sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
  453. # Reconstruct pixel values
  454. reconstructed_pixel_values = self.decoder(sequence_output)
  455. masked_im_loss = None
  456. if bool_masked_pos is not None:
  457. size = self.config.image_size // self.config.patch_size
  458. bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
  459. mask = (
  460. bool_masked_pos.repeat_interleave(self.config.patch_size, 1)
  461. .repeat_interleave(self.config.patch_size, 2)
  462. .unsqueeze(1)
  463. .contiguous()
  464. )
  465. reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
  466. masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels
  467. return MaskedImageModelingOutput(
  468. loss=masked_im_loss,
  469. reconstruction=reconstructed_pixel_values,
  470. hidden_states=outputs.hidden_states,
  471. attentions=outputs.attentions,
  472. )
  473. @auto_docstring(
  474. custom_intro="""
  475. DeiT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
  476. the [CLS] token) e.g. for ImageNet.
  477. """
  478. )
  479. class DeiTForImageClassification(DeiTPreTrainedModel):
  480. def __init__(self, config: DeiTConfig) -> None:
  481. super().__init__(config)
  482. self.num_labels = config.num_labels
  483. self.deit = DeiTModel(config, add_pooling_layer=False)
  484. # Classifier head
  485. self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
  486. # Initialize weights and apply final processing
  487. self.post_init()
  488. @can_return_tuple
  489. @auto_docstring
  490. def forward(
  491. self,
  492. pixel_values: torch.Tensor | None = None,
  493. labels: torch.Tensor | None = None,
  494. interpolate_pos_encoding: bool = False,
  495. **kwargs: Unpack[TransformersKwargs],
  496. ) -> ImageClassifierOutput:
  497. r"""
  498. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  499. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  500. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  501. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  502. Examples:
  503. ```python
  504. >>> from transformers import AutoImageProcessor, DeiTForImageClassification
  505. >>> import torch
  506. >>> from PIL import Image
  507. >>> import httpx
  508. >>> from io import BytesIO
  509. >>> torch.manual_seed(3) # doctest: +IGNORE_RESULT
  510. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  511. >>> with httpx.stream("GET", url) as response:
  512. ... image = Image.open(BytesIO(response.read()))
  513. >>> # note: we are loading a DeiTForImageClassificationWithTeacher from the hub here,
  514. >>> # so the head will be randomly initialized, hence the predictions will be random
  515. >>> image_processor = AutoImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224")
  516. >>> model = DeiTForImageClassification.from_pretrained("facebook/deit-base-distilled-patch16-224")
  517. >>> inputs = image_processor(images=image, return_tensors="pt")
  518. >>> outputs = model(**inputs)
  519. >>> logits = outputs.logits
  520. >>> # model predicts one of the 1000 ImageNet classes
  521. >>> predicted_class_idx = logits.argmax(-1).item()
  522. >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
  523. Predicted class: Polaroid camera, Polaroid Land camera
  524. ```"""
  525. outputs: BaseModelOutputWithPooling = self.deit(
  526. pixel_values,
  527. interpolate_pos_encoding=interpolate_pos_encoding,
  528. **kwargs,
  529. )
  530. sequence_output = outputs.last_hidden_state
  531. logits = self.classifier(sequence_output[:, 0, :])
  532. # we don't use the distillation token
  533. loss = None
  534. if labels is not None:
  535. loss = self.loss_function(labels, logits, self.config, **kwargs)
  536. return ImageClassifierOutput(
  537. loss=loss,
  538. logits=logits,
  539. hidden_states=outputs.hidden_states,
  540. attentions=outputs.attentions,
  541. )
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`DeiTForImageClassificationWithTeacher`].
    """
)
class DeiTForImageClassificationWithTeacherOutput(ModelOutput):
    r"""
    logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
        Prediction scores as the average of the cls_logits and distillation logits.
    cls_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
        Prediction scores of the classification head (i.e. the linear layer on top of the final hidden state of the
        class token).
    distillation_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
        Prediction scores of the distillation head (i.e. the linear layer on top of the final hidden state of the
        distillation token).
    """

    # Every field defaults to None, so an instance can be built with any subset populated.
    logits: torch.FloatTensor | None = None
    cls_logits: torch.FloatTensor | None = None
    distillation_logits: torch.FloatTensor | None = None
    # Filled by DeiTForImageClassificationWithTeacher from the backbone output's `hidden_states`.
    hidden_states: tuple[torch.FloatTensor] | None = None
    # Filled by DeiTForImageClassificationWithTeacher from the backbone output's `attentions`.
    attentions: tuple[torch.FloatTensor] | None = None
  564. @auto_docstring(
  565. custom_intro="""
  566. DeiT Model transformer with image classification heads on top (a linear layer on top of the final hidden state of
  567. the [CLS] token and a linear layer on top of the final hidden state of the distillation token) e.g. for ImageNet.
  568. .. warning::
  569. This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet
  570. supported.
  571. """
  572. )
  573. class DeiTForImageClassificationWithTeacher(DeiTPreTrainedModel):
  574. def __init__(self, config: DeiTConfig) -> None:
  575. super().__init__(config)
  576. self.num_labels = config.num_labels
  577. self.deit = DeiTModel(config, add_pooling_layer=False)
  578. # Classifier heads
  579. self.cls_classifier = (
  580. nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
  581. )
  582. self.distillation_classifier = (
  583. nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
  584. )
  585. # Initialize weights and apply final processing
  586. self.post_init()
  587. @can_return_tuple
  588. @auto_docstring
  589. def forward(
  590. self,
  591. pixel_values: torch.Tensor | None = None,
  592. interpolate_pos_encoding: bool = False,
  593. **kwargs: Unpack[TransformersKwargs],
  594. ) -> DeiTForImageClassificationWithTeacherOutput:
  595. outputs: BaseModelOutputWithPooling = self.deit(
  596. pixel_values,
  597. interpolate_pos_encoding=interpolate_pos_encoding,
  598. **kwargs,
  599. )
  600. sequence_output = outputs.last_hidden_state
  601. cls_logits = self.cls_classifier(sequence_output[:, 0, :])
  602. distillation_logits = self.distillation_classifier(sequence_output[:, 1, :])
  603. # during inference, return the average of both classifier predictions
  604. logits = (cls_logits + distillation_logits) / 2
  605. return DeiTForImageClassificationWithTeacherOutput(
  606. logits=logits,
  607. cls_logits=cls_logits,
  608. distillation_logits=distillation_logits,
  609. hidden_states=outputs.hidden_states,
  610. attentions=outputs.attentions,
  611. )
# Public API of this module: the names exported by `from ... import *`.
__all__ = [
    "DeiTForImageClassification",
    "DeiTForImageClassificationWithTeacher",
    "DeiTForMaskedImageModeling",
    "DeiTModel",
    "DeiTPreTrainedModel",
]
  618. ]