modeling_dinov2.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627
  1. # Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch DINOv2 model."""
  15. import collections.abc
  16. from collections.abc import Callable
  17. import torch
  18. from torch import nn
  19. from ... import initialization as init
  20. from ...activations import ACT2FN
  21. from ...backbone_utils import BackboneMixin, filter_output_hidden_states
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
  24. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  25. from ...processing_utils import Unpack
  26. from ...utils import TransformersKwargs, auto_docstring, logging, torch_int
  27. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  28. from ...utils.output_capturing import capture_outputs
  29. from .configuration_dinov2 import Dinov2Config
  30. logger = logging.get_logger(__name__)
  31. class Dinov2Embeddings(nn.Module):
  32. """
  33. Construct the CLS token, mask token, position and patch embeddings.
  34. """
  35. def __init__(self, config: Dinov2Config) -> None:
  36. super().__init__()
  37. self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
  38. if config.use_mask_token:
  39. self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
  40. self.patch_embeddings = Dinov2PatchEmbeddings(config)
  41. num_patches = self.patch_embeddings.num_patches
  42. self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
  43. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  44. self.patch_size = config.patch_size
  45. self.use_mask_token = config.use_mask_token
  46. self.config = config
  47. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  48. """
  49. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  50. images. This method is also adapted to support torch.jit tracing and interpolation at torch.float32 precision.
  51. Adapted from:
  52. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  53. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  54. """
  55. num_patches = embeddings.shape[1] - 1
  56. num_positions = self.position_embeddings.shape[1] - 1
  57. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  58. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  59. return self.position_embeddings
  60. class_pos_embed = self.position_embeddings[:, :1]
  61. patch_pos_embed = self.position_embeddings[:, 1:]
  62. dim = embeddings.shape[-1]
  63. new_height = height // self.patch_size
  64. new_width = width // self.patch_size
  65. sqrt_num_positions = torch_int(num_positions**0.5)
  66. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  67. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  68. target_dtype = patch_pos_embed.dtype
  69. patch_pos_embed = nn.functional.interpolate(
  70. patch_pos_embed.to(torch.float32),
  71. size=(new_height, new_width),
  72. mode="bicubic",
  73. align_corners=False,
  74. ).to(dtype=target_dtype)
  75. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  76. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  77. def forward(self, pixel_values: torch.Tensor, bool_masked_pos: torch.Tensor | None = None) -> torch.Tensor:
  78. batch_size, _, height, width = pixel_values.shape
  79. target_dtype = self.patch_embeddings.projection.weight.dtype
  80. embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
  81. if bool_masked_pos is not None and self.use_mask_token:
  82. embeddings = torch.where(
  83. bool_masked_pos.unsqueeze(-1), self.mask_token.to(embeddings.dtype).unsqueeze(0), embeddings
  84. )
  85. # add the [CLS] token to the embedded patch tokens
  86. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  87. embeddings = torch.cat((cls_tokens, embeddings), dim=1)
  88. # add positional encoding to each token
  89. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  90. embeddings = self.dropout(embeddings)
  91. return embeddings
  92. class Dinov2PatchEmbeddings(nn.Module):
  93. """
  94. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  95. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  96. Transformer.
  97. """
  98. def __init__(self, config):
  99. super().__init__()
  100. image_size, patch_size = config.image_size, config.patch_size
  101. num_channels, hidden_size = config.num_channels, config.hidden_size
  102. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  103. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  104. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  105. self.image_size = image_size
  106. self.patch_size = patch_size
  107. self.num_channels = num_channels
  108. self.num_patches = num_patches
  109. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
  110. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  111. num_channels = pixel_values.shape[1]
  112. if num_channels != self.num_channels:
  113. raise ValueError(
  114. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  115. f" Expected {self.num_channels} but got {num_channels}."
  116. )
  117. embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
  118. return embeddings
  119. # Copied from transformers.models.bert.modeling_bert.eager_attention_forward
  120. def eager_attention_forward(
  121. module: nn.Module,
  122. query: torch.Tensor,
  123. key: torch.Tensor,
  124. value: torch.Tensor,
  125. attention_mask: torch.Tensor | None,
  126. scaling: float | None = None,
  127. dropout: float = 0.0,
  128. **kwargs: Unpack[TransformersKwargs],
  129. ):
  130. if scaling is None:
  131. scaling = query.size(-1) ** -0.5
  132. # Take the dot product between "query" and "key" to get the raw attention scores.
  133. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  134. if attention_mask is not None:
  135. attn_weights = attn_weights + attention_mask
  136. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  137. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  138. attn_output = torch.matmul(attn_weights, value)
  139. attn_output = attn_output.transpose(1, 2).contiguous()
  140. return attn_output, attn_weights
  141. # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dinov2
  142. class Dinov2SelfAttention(nn.Module):
  143. def __init__(self, config: Dinov2Config):
  144. super().__init__()
  145. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  146. raise ValueError(
  147. f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
  148. f"heads {config.num_attention_heads}."
  149. )
  150. self.config = config
  151. self.num_attention_heads = config.num_attention_heads
  152. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  153. self.all_head_size = self.num_attention_heads * self.attention_head_size
  154. self.dropout_prob = config.attention_probs_dropout_prob
  155. self.scaling = self.attention_head_size**-0.5
  156. self.is_causal = False
  157. self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  158. self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  159. self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  160. def forward(
  161. self,
  162. hidden_states: torch.Tensor,
  163. **kwargs: Unpack[TransformersKwargs],
  164. ) -> tuple[torch.Tensor, torch.Tensor]:
  165. batch_size = hidden_states.shape[0]
  166. new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size
  167. key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2)
  168. value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2)
  169. query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2)
  170. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  171. self.config._attn_implementation, eager_attention_forward
  172. )
  173. context_layer, attention_probs = attention_interface(
  174. self,
  175. query_layer,
  176. key_layer,
  177. value_layer,
  178. None,
  179. is_causal=self.is_causal,
  180. scaling=self.scaling,
  181. dropout=0.0 if not self.training else self.dropout_prob,
  182. **kwargs,
  183. )
  184. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  185. context_layer = context_layer.reshape(new_context_layer_shape)
  186. return context_layer, attention_probs
  187. # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2
  188. class Dinov2SelfOutput(nn.Module):
  189. """
  190. The residual connection is defined in Dinov2Layer instead of here (as is the case with other models), due to the
  191. layernorm applied before each block.
  192. """
  193. def __init__(self, config: Dinov2Config):
  194. super().__init__()
  195. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  196. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  197. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  198. hidden_states = self.dense(hidden_states)
  199. hidden_states = self.dropout(hidden_states)
  200. return hidden_states
  201. # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Dinov2
  202. class Dinov2Attention(nn.Module):
  203. def __init__(self, config: Dinov2Config):
  204. super().__init__()
  205. self.attention = Dinov2SelfAttention(config)
  206. self.output = Dinov2SelfOutput(config)
  207. def forward(
  208. self,
  209. hidden_states: torch.Tensor,
  210. **kwargs: Unpack[TransformersKwargs],
  211. ) -> torch.Tensor:
  212. self_attn_output, _ = self.attention(hidden_states, **kwargs)
  213. output = self.output(self_attn_output, hidden_states)
  214. return output
  215. class Dinov2LayerScale(nn.Module):
  216. def __init__(self, config) -> None:
  217. super().__init__()
  218. self.lambda1 = nn.Parameter(config.layerscale_value * torch.ones(config.hidden_size))
  219. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  220. return hidden_state * self.lambda1
  221. # Copied from transformers.models.beit.modeling_beit.drop_path
  222. def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
  223. """
  224. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  225. """
  226. if drop_prob == 0.0 or not training:
  227. return input
  228. keep_prob = 1 - drop_prob
  229. shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  230. random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
  231. random_tensor.floor_() # binarize
  232. output = input.div(keep_prob) * random_tensor
  233. return output
  234. # Copied from transformers.models.beit.modeling_beit.BeitDropPath
  235. class Dinov2DropPath(nn.Module):
  236. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  237. def __init__(self, drop_prob: float | None = None) -> None:
  238. super().__init__()
  239. self.drop_prob = drop_prob
  240. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  241. return drop_path(hidden_states, self.drop_prob, self.training)
  242. def extra_repr(self) -> str:
  243. return f"p={self.drop_prob}"
  244. class Dinov2MLP(nn.Module):
  245. def __init__(self, config) -> None:
  246. super().__init__()
  247. in_features = out_features = config.hidden_size
  248. hidden_features = int(config.hidden_size * config.mlp_ratio)
  249. self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
  250. if isinstance(config.hidden_act, str):
  251. self.activation = ACT2FN[config.hidden_act]
  252. else:
  253. self.activation = config.hidden_act
  254. self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
  255. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  256. hidden_state = self.fc1(hidden_state)
  257. hidden_state = self.activation(hidden_state)
  258. hidden_state = self.fc2(hidden_state)
  259. return hidden_state
  260. class Dinov2SwiGLUFFN(nn.Module):
  261. def __init__(self, config) -> None:
  262. super().__init__()
  263. in_features = out_features = config.hidden_size
  264. hidden_features = int(config.hidden_size * config.mlp_ratio)
  265. hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
  266. self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True)
  267. self.weights_out = nn.Linear(hidden_features, out_features, bias=True)
  268. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  269. hidden_state = self.weights_in(hidden_state)
  270. x1, x2 = hidden_state.chunk(2, dim=-1)
  271. hidden = nn.functional.silu(x1) * x2
  272. return self.weights_out(hidden)
  273. class Dinov2Layer(GradientCheckpointingLayer):
  274. """This corresponds to the Block class in the original implementation."""
  275. def __init__(self, config: Dinov2Config) -> None:
  276. super().__init__()
  277. self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  278. self.attention = Dinov2Attention(config)
  279. self.layer_scale1 = Dinov2LayerScale(config)
  280. self.drop_path = Dinov2DropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
  281. self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  282. if config.use_swiglu_ffn:
  283. self.mlp = Dinov2SwiGLUFFN(config)
  284. else:
  285. self.mlp = Dinov2MLP(config)
  286. self.layer_scale2 = Dinov2LayerScale(config)
  287. def forward(
  288. self,
  289. hidden_states: torch.Tensor,
  290. ) -> torch.Tensor:
  291. hidden_states_norm = self.norm1(hidden_states)
  292. self_attention_output = self.attention(hidden_states_norm)
  293. self_attention_output = self.layer_scale1(self_attention_output)
  294. # first residual connection
  295. hidden_states = self.drop_path(self_attention_output) + hidden_states
  296. # in Dinov2, layernorm is also applied after self-attention
  297. layer_output = self.norm2(hidden_states)
  298. layer_output = self.mlp(layer_output)
  299. layer_output = self.layer_scale2(layer_output)
  300. # second residual connection
  301. layer_output = self.drop_path(layer_output) + hidden_states
  302. return layer_output
  303. @auto_docstring
  304. class Dinov2PreTrainedModel(PreTrainedModel):
  305. config: Dinov2Config
  306. base_model_prefix = "dinov2"
  307. main_input_name = "pixel_values"
  308. input_modalities = ("image",)
  309. supports_gradient_checkpointing = True
  310. _no_split_modules = ["Dinov2Layer"]
  311. _supports_sdpa = True
  312. _supports_flash_attn = True
  313. _supports_flex_attn = True
  314. _supports_attention_backend = True
  315. _can_record_outputs = {
  316. "hidden_states": Dinov2Layer,
  317. "attentions": Dinov2SelfAttention,
  318. }
  319. @torch.no_grad()
  320. def _init_weights(self, module: nn.Linear | nn.Conv2d | nn.LayerNorm) -> None:
  321. """Initialize the weights"""
  322. if isinstance(module, (nn.Linear, nn.Conv2d)):
  323. init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  324. if module.bias is not None:
  325. init.zeros_(module.bias)
  326. elif isinstance(module, nn.LayerNorm):
  327. init.zeros_(module.bias)
  328. init.ones_(module.weight)
  329. elif isinstance(module, Dinov2Embeddings):
  330. init.trunc_normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range)
  331. init.trunc_normal_(module.cls_token, mean=0.0, std=self.config.initializer_range)
  332. if self.config.use_mask_token:
  333. init.zeros_(module.mask_token)
  334. elif isinstance(module, Dinov2LayerScale):
  335. init.constant_(module.lambda1, self.config.layerscale_value)
  336. class Dinov2Encoder(Dinov2PreTrainedModel):
  337. def __init__(self, config: Dinov2Config):
  338. super().__init__(config)
  339. self.layer = nn.ModuleList([Dinov2Layer(config) for _ in range(config.num_hidden_layers)])
  340. self.post_init()
  341. @merge_with_config_defaults
  342. @capture_outputs(tie_last_hidden_states=False)
  343. def forward(self, hidden_states: torch.Tensor, **kwargs: Unpack[TransformersKwargs]) -> BaseModelOutput:
  344. for layer_module in self.layer:
  345. hidden_states = layer_module(hidden_states)
  346. return BaseModelOutput(last_hidden_state=hidden_states)
  347. @auto_docstring
  348. class Dinov2Model(Dinov2PreTrainedModel):
  349. def __init__(self, config: Dinov2Config):
  350. super().__init__(config)
  351. self.config = config
  352. self.embeddings = Dinov2Embeddings(config)
  353. self.encoder = Dinov2Encoder(config)
  354. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  355. # Initialize weights and apply final processing
  356. self.post_init()
  357. def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
  358. return self.embeddings.patch_embeddings
  359. @can_return_tuple
  360. @auto_docstring
  361. def forward(
  362. self,
  363. pixel_values: torch.Tensor | None = None,
  364. bool_masked_pos: torch.Tensor | None = None,
  365. **kwargs: Unpack[TransformersKwargs],
  366. ) -> BaseModelOutputWithPooling:
  367. r"""
  368. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
  369. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for
  370. pre-training.
  371. """
  372. if pixel_values is None:
  373. raise ValueError("You have to specify pixel_values")
  374. embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
  375. encoder_outputs: BaseModelOutput = self.encoder(embedding_output, **kwargs)
  376. sequence_output = encoder_outputs.last_hidden_state
  377. sequence_output = self.layernorm(sequence_output)
  378. pooled_output = sequence_output[:, 0, :]
  379. return BaseModelOutputWithPooling(
  380. last_hidden_state=sequence_output,
  381. pooler_output=pooled_output,
  382. hidden_states=encoder_outputs.hidden_states,
  383. attentions=encoder_outputs.attentions,
  384. )
  385. @auto_docstring(
  386. custom_intro="""
  387. Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
  388. of the [CLS] token) e.g. for ImageNet.
  389. """
  390. )
  391. class Dinov2ForImageClassification(Dinov2PreTrainedModel):
  392. def __init__(self, config: Dinov2Config) -> None:
  393. super().__init__(config)
  394. self.num_labels = config.num_labels
  395. self.dinov2 = Dinov2Model(config)
  396. # Classifier head
  397. self.classifier = (
  398. nn.Linear(config.hidden_size * 2, config.num_labels) if config.num_labels > 0 else nn.Identity()
  399. )
  400. # Initialize weights and apply final processing
  401. self.post_init()
  402. @can_return_tuple
  403. @auto_docstring
  404. def forward(
  405. self,
  406. pixel_values: torch.Tensor | None = None,
  407. labels: torch.Tensor | None = None,
  408. **kwargs: Unpack[TransformersKwargs],
  409. ) -> ImageClassifierOutput:
  410. r"""
  411. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  412. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  413. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  414. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  415. """
  416. outputs: BaseModelOutputWithPooling = self.dinov2(pixel_values, **kwargs)
  417. sequence_output = outputs.last_hidden_state # batch_size, sequence_length, hidden_size
  418. cls_token = sequence_output[:, 0]
  419. patch_tokens = sequence_output[:, 1:]
  420. linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
  421. logits = self.classifier(linear_input)
  422. loss = None
  423. if labels is not None:
  424. loss = self.loss_function(labels, logits, self.config, **kwargs)
  425. return ImageClassifierOutput(
  426. loss=loss,
  427. logits=logits,
  428. hidden_states=outputs.hidden_states,
  429. attentions=outputs.attentions,
  430. )
  431. @auto_docstring(
  432. custom_intro="""
  433. Dinov2 backbone, to be used with frameworks like DETR and MaskFormer.
  434. """
  435. )
  436. class Dinov2Backbone(BackboneMixin, Dinov2PreTrainedModel):
  437. def __init__(self, config):
  438. super().__init__(config)
  439. self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]
  440. self.embeddings = Dinov2Embeddings(config)
  441. self.encoder = Dinov2Encoder(config)
  442. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  443. # Initialize weights and apply final processing
  444. self.post_init()
  445. def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
  446. return self.embeddings.patch_embeddings
  447. @can_return_tuple
  448. @filter_output_hidden_states
  449. @auto_docstring
  450. def forward(
  451. self,
  452. pixel_values: torch.Tensor,
  453. **kwargs: Unpack[TransformersKwargs],
  454. ) -> BackboneOutput:
  455. r"""
  456. Examples:
  457. ```python
  458. >>> from transformers import AutoImageProcessor, AutoBackbone
  459. >>> import torch
  460. >>> from PIL import Image
  461. >>> import httpx
  462. >>> from io import BytesIO
  463. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  464. >>> with httpx.stream("GET", url) as response:
  465. ... image = Image.open(BytesIO(response.read()))
  466. >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
  467. >>> model = AutoBackbone.from_pretrained(
  468. ... "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"]
  469. ... )
  470. >>> inputs = processor(image, return_tensors="pt")
  471. >>> outputs = model(**inputs)
  472. >>> feature_maps = outputs.feature_maps
  473. >>> list(feature_maps[-1].shape)
  474. [1, 768, 16, 16]
  475. ```"""
  476. kwargs["output_hidden_states"] = True # required to extract layers for the stages
  477. embedding_output = self.embeddings(pixel_values)
  478. output: BaseModelOutput = self.encoder(embedding_output, **kwargs)
  479. hidden_states = output.hidden_states
  480. feature_maps = []
  481. for stage, hidden_state in zip(self.stage_names, hidden_states):
  482. if stage in self.out_features:
  483. if self.config.apply_layernorm:
  484. hidden_state = self.layernorm(hidden_state)
  485. if self.config.reshape_hidden_states:
  486. hidden_state = hidden_state[:, 1:]
  487. # this was actually a bug in the original implementation that we copied here,
  488. # cause normally the order is height, width
  489. batch_size, _, height, width = pixel_values.shape
  490. patch_size = self.config.patch_size
  491. hidden_state = hidden_state.reshape(batch_size, height // patch_size, width // patch_size, -1)
  492. hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
  493. feature_maps.append(hidden_state)
  494. return BackboneOutput(
  495. feature_maps=tuple(feature_maps),
  496. hidden_states=hidden_states,
  497. attentions=output.attentions,
  498. )
  499. __all__ = ["Dinov2ForImageClassification", "Dinov2Model", "Dinov2PreTrainedModel", "Dinov2Backbone"]