modeling_ijepa.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/ijepa/modular_ijepa.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_ijepa.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. import collections.abc
  8. from collections.abc import Callable
  9. import torch
  10. import torch.nn as nn
  11. from ... import initialization as init
  12. from ...activations import ACT2FN
  13. from ...modeling_layers import GradientCheckpointingLayer
  14. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
  15. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  16. from ...processing_utils import Unpack
  17. from ...utils import TransformersKwargs, auto_docstring, torch_int
  18. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  19. from ...utils.output_capturing import capture_outputs
  20. from .configuration_ijepa import IJepaConfig
  21. class IJepaPatchEmbeddings(nn.Module):
  22. """
  23. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  24. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  25. Transformer.
  26. """
  27. def __init__(self, config: IJepaConfig):
  28. super().__init__()
  29. image_size, patch_size = config.image_size, config.patch_size
  30. num_channels, hidden_size = config.num_channels, config.hidden_size
  31. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  32. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  33. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  34. self.image_size = image_size
  35. self.patch_size = patch_size
  36. self.num_channels = num_channels
  37. self.num_patches = num_patches
  38. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
  39. def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
  40. batch_size, num_channels, height, width = pixel_values.shape
  41. if num_channels != self.num_channels:
  42. raise ValueError(
  43. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  44. f" Expected {self.num_channels} but got {num_channels}."
  45. )
  46. if not interpolate_pos_encoding:
  47. if height != self.image_size[0] or width != self.image_size[1]:
  48. raise ValueError(
  49. f"Input image size ({height}*{width}) doesn't match model"
  50. f" ({self.image_size[0]}*{self.image_size[1]})."
  51. )
  52. embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
  53. return embeddings
  54. class IJepaEmbeddings(nn.Module):
  55. """
  56. Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
  57. """
  58. def __init__(self, config: IJepaConfig, use_mask_token: bool = False) -> None:
  59. super().__init__()
  60. self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
  61. self.patch_embeddings = IJepaPatchEmbeddings(config)
  62. num_patches = self.patch_embeddings.num_patches
  63. self.position_embeddings = nn.Parameter(torch.randn(1, num_patches, config.hidden_size))
  64. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  65. self.patch_size = config.patch_size
  66. self.config = config
  67. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  68. """
  69. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  70. images. This method is also adapted to support torch.jit tracing.
  71. Adapted from:
  72. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  73. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  74. """
  75. num_patches = embeddings.shape[1]
  76. num_positions = self.position_embeddings.shape[1]
  77. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  78. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  79. return self.position_embeddings
  80. patch_pos_embed = self.position_embeddings
  81. dim = embeddings.shape[-1]
  82. new_height = height // self.patch_size
  83. new_width = width // self.patch_size
  84. sqrt_num_positions = torch_int(num_positions**0.5)
  85. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  86. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  87. patch_pos_embed = nn.functional.interpolate(
  88. patch_pos_embed,
  89. size=(new_height, new_width),
  90. mode="bicubic",
  91. align_corners=False,
  92. )
  93. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  94. return patch_pos_embed
  95. def forward(
  96. self,
  97. pixel_values: torch.Tensor,
  98. bool_masked_pos: torch.BoolTensor | None = None,
  99. interpolate_pos_encoding: bool = False,
  100. ) -> torch.Tensor:
  101. batch_size, _, height, width = pixel_values.shape
  102. embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  103. if bool_masked_pos is not None:
  104. seq_length = embeddings.shape[1]
  105. mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
  106. # replace the masked visual tokens by mask_tokens
  107. mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
  108. embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
  109. # add positional encoding to each token
  110. if interpolate_pos_encoding:
  111. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  112. else:
  113. embeddings = embeddings + self.position_embeddings
  114. embeddings = self.dropout(embeddings)
  115. return embeddings
  116. def eager_attention_forward(
  117. module: nn.Module,
  118. query: torch.Tensor,
  119. key: torch.Tensor,
  120. value: torch.Tensor,
  121. attention_mask: torch.Tensor | None,
  122. scaling: float | None = None,
  123. dropout: float = 0.0,
  124. **kwargs: Unpack[TransformersKwargs],
  125. ):
  126. if scaling is None:
  127. scaling = query.size(-1) ** -0.5
  128. # Take the dot product between "query" and "key" to get the raw attention scores.
  129. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  130. if attention_mask is not None:
  131. attn_weights = attn_weights + attention_mask
  132. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  133. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  134. attn_output = torch.matmul(attn_weights, value)
  135. attn_output = attn_output.transpose(1, 2).contiguous()
  136. return attn_output, attn_weights
  137. class IJepaSelfAttention(nn.Module):
  138. def __init__(self, config: IJepaConfig):
  139. super().__init__()
  140. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  141. raise ValueError(
  142. f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
  143. f"heads {config.num_attention_heads}."
  144. )
  145. self.config = config
  146. self.num_attention_heads = config.num_attention_heads
  147. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  148. self.all_head_size = self.num_attention_heads * self.attention_head_size
  149. self.dropout_prob = config.attention_probs_dropout_prob
  150. self.scaling = self.attention_head_size**-0.5
  151. self.is_causal = False
  152. self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  153. self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  154. self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  155. def forward(
  156. self,
  157. hidden_states: torch.Tensor,
  158. **kwargs: Unpack[TransformersKwargs],
  159. ) -> tuple[torch.Tensor, torch.Tensor]:
  160. batch_size = hidden_states.shape[0]
  161. new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size
  162. key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2)
  163. value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2)
  164. query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2)
  165. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  166. self.config._attn_implementation, eager_attention_forward
  167. )
  168. context_layer, attention_probs = attention_interface(
  169. self,
  170. query_layer,
  171. key_layer,
  172. value_layer,
  173. None,
  174. is_causal=self.is_causal,
  175. scaling=self.scaling,
  176. dropout=0.0 if not self.training else self.dropout_prob,
  177. **kwargs,
  178. )
  179. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  180. context_layer = context_layer.reshape(new_context_layer_shape)
  181. return context_layer, attention_probs
  182. class IJepaSelfOutput(nn.Module):
  183. """
  184. The residual connection is defined in IJepaLayer instead of here (as is the case with other models), due to the
  185. layernorm applied before each block.
  186. """
  187. def __init__(self, config: IJepaConfig):
  188. super().__init__()
  189. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  190. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  191. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  192. hidden_states = self.dense(hidden_states)
  193. hidden_states = self.dropout(hidden_states)
  194. return hidden_states
  195. class IJepaAttention(nn.Module):
  196. def __init__(self, config: IJepaConfig):
  197. super().__init__()
  198. self.attention = IJepaSelfAttention(config)
  199. self.output = IJepaSelfOutput(config)
  200. def forward(
  201. self,
  202. hidden_states: torch.Tensor,
  203. **kwargs: Unpack[TransformersKwargs],
  204. ) -> torch.Tensor:
  205. self_attn_output, _ = self.attention(hidden_states, **kwargs)
  206. output = self.output(self_attn_output, hidden_states)
  207. return output
  208. class IJepaIntermediate(nn.Module):
  209. def __init__(self, config: IJepaConfig):
  210. super().__init__()
  211. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  212. if isinstance(config.hidden_act, str):
  213. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  214. else:
  215. self.intermediate_act_fn = config.hidden_act
  216. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  217. hidden_states = self.dense(hidden_states)
  218. hidden_states = self.intermediate_act_fn(hidden_states)
  219. return hidden_states
  220. class IJepaOutput(nn.Module):
  221. def __init__(self, config: IJepaConfig):
  222. super().__init__()
  223. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  224. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  225. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  226. hidden_states = self.dense(hidden_states)
  227. hidden_states = self.dropout(hidden_states)
  228. hidden_states = hidden_states + input_tensor
  229. return hidden_states
  230. class IJepaLayer(GradientCheckpointingLayer):
  231. """This corresponds to the Block class in the timm implementation."""
  232. def __init__(self, config: IJepaConfig):
  233. super().__init__()
  234. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  235. self.seq_len_dim = 1
  236. self.attention = IJepaAttention(config)
  237. self.intermediate = IJepaIntermediate(config)
  238. self.output = IJepaOutput(config)
  239. self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  240. self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  241. def forward(
  242. self,
  243. hidden_states: torch.Tensor,
  244. **kwargs: Unpack[TransformersKwargs],
  245. ) -> torch.Tensor:
  246. hidden_states_norm = self.layernorm_before(hidden_states)
  247. attention_output = self.attention(hidden_states_norm, **kwargs)
  248. # first residual connection
  249. hidden_states = attention_output + hidden_states
  250. # in IJepa, layernorm is also applied after self-attention
  251. layer_output = self.layernorm_after(hidden_states)
  252. layer_output = self.intermediate(layer_output)
  253. # second residual connection is done here
  254. layer_output = self.output(layer_output, hidden_states)
  255. return layer_output
  256. @auto_docstring
  257. class IJepaPreTrainedModel(PreTrainedModel):
  258. config: IJepaConfig
  259. base_model_prefix = "ijepa"
  260. main_input_name = "pixel_values"
  261. input_modalities = ("image",)
  262. supports_gradient_checkpointing = True
  263. _no_split_modules = ["IJepaEmbeddings", "IJepaLayer"]
  264. _supports_sdpa = True
  265. _supports_flash_attn = True
  266. _supports_flex_attn = True
  267. _supports_attention_backend = True
  268. _can_record_outputs = {
  269. "hidden_states": IJepaLayer,
  270. "attentions": IJepaSelfAttention,
  271. }
  272. @torch.no_grad()
  273. def _init_weights(self, module: nn.Linear | nn.Conv2d | nn.LayerNorm) -> None:
  274. """Initialize the weights"""
  275. if isinstance(module, (nn.Linear, nn.Conv2d)):
  276. init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  277. if module.bias is not None:
  278. init.zeros_(module.bias)
  279. elif isinstance(module, nn.LayerNorm):
  280. init.zeros_(module.bias)
  281. init.ones_(module.weight)
  282. elif isinstance(module, IJepaEmbeddings):
  283. init.trunc_normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range)
  284. if module.mask_token is not None:
  285. init.zeros_(module.mask_token)
  286. class IJepaEncoder(nn.Module):
  287. def __init__(self, config: IJepaConfig):
  288. super().__init__()
  289. self.config = config
  290. self.layer = nn.ModuleList([IJepaLayer(config) for _ in range(config.num_hidden_layers)])
  291. self.gradient_checkpointing = False
  292. def forward(
  293. self,
  294. hidden_states: torch.Tensor,
  295. **kwargs: Unpack[TransformersKwargs],
  296. ) -> BaseModelOutput:
  297. for layer_module in self.layer:
  298. hidden_states = layer_module(hidden_states, **kwargs)
  299. return BaseModelOutput(last_hidden_state=hidden_states)
  300. class IJepaPooler(nn.Module):
  301. def __init__(self, config: IJepaConfig):
  302. super().__init__()
  303. self.dense = nn.Linear(config.hidden_size, config.pooler_output_size)
  304. self.activation = ACT2FN[config.pooler_act]
  305. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  306. # We "pool" the model by simply taking the hidden state corresponding
  307. # to the first token.
  308. first_token_tensor = hidden_states[:, 0]
  309. pooled_output = self.dense(first_token_tensor)
  310. pooled_output = self.activation(pooled_output)
  311. return pooled_output
  312. @auto_docstring
  313. class IJepaModel(IJepaPreTrainedModel):
  314. def __init__(self, config: IJepaConfig, add_pooling_layer: bool = False, use_mask_token: bool = False):
  315. r"""
  316. add_pooling_layer (bool, *optional*, defaults to `True`):
  317. Whether to add a pooling layer
  318. use_mask_token (`bool`, *optional*, defaults to `False`):
  319. Whether to use a mask token for masked image modeling.
  320. """
  321. super().__init__(config)
  322. self.config = config
  323. self.embeddings = IJepaEmbeddings(config, use_mask_token=use_mask_token)
  324. self.encoder = IJepaEncoder(config)
  325. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  326. self.pooler = IJepaPooler(config) if add_pooling_layer else None
  327. # Initialize weights and apply final processing
  328. self.post_init()
  329. def get_input_embeddings(self) -> IJepaPatchEmbeddings:
  330. return self.embeddings.patch_embeddings
  331. @merge_with_config_defaults
  332. @capture_outputs(tie_last_hidden_states=False)
  333. @auto_docstring
  334. def forward(
  335. self,
  336. pixel_values: torch.Tensor | None = None,
  337. bool_masked_pos: torch.BoolTensor | None = None,
  338. interpolate_pos_encoding: bool | None = None,
  339. **kwargs: Unpack[TransformersKwargs],
  340. ) -> BaseModelOutputWithPooling:
  341. r"""
  342. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
  343. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  344. """
  345. if pixel_values is None:
  346. raise ValueError("You have to specify pixel_values")
  347. # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
  348. expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
  349. if pixel_values.dtype != expected_dtype:
  350. pixel_values = pixel_values.to(expected_dtype)
  351. embedding_output = self.embeddings(
  352. pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
  353. )
  354. encoder_outputs: BaseModelOutput = self.encoder(embedding_output)
  355. sequence_output = encoder_outputs.last_hidden_state
  356. sequence_output = self.layernorm(sequence_output)
  357. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  358. return BaseModelOutputWithPooling(last_hidden_state=sequence_output, pooler_output=pooled_output)
  359. @auto_docstring(
  360. custom_intro="""
  361. IJepa Model transformer with an image classification head on top (a linear layer on top of the final hidden states)
  362. e.g. for ImageNet.
  363. <Tip>
  364. Note that it's possible to fine-tune IJepa on higher resolution images than the ones it has been trained on, by
  365. setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
  366. position embeddings to the higher resolution.
  367. </Tip>
  368. """
  369. )
  370. class IJepaForImageClassification(IJepaPreTrainedModel):
  371. def __init__(self, config: IJepaConfig):
  372. super().__init__(config)
  373. self.num_labels = config.num_labels
  374. self.ijepa = IJepaModel(config, add_pooling_layer=False)
  375. # Classifier head
  376. self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
  377. # Initialize weights and apply final processing
  378. self.post_init()
  379. @can_return_tuple
  380. @auto_docstring
  381. def forward(
  382. self,
  383. pixel_values: torch.Tensor | None = None,
  384. labels: torch.Tensor | None = None,
  385. interpolate_pos_encoding: bool | None = None,
  386. **kwargs: Unpack[TransformersKwargs],
  387. ) -> ImageClassifierOutput:
  388. r"""
  389. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  390. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  391. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  392. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  393. """
  394. outputs: BaseModelOutputWithPooling = self.ijepa(
  395. pixel_values,
  396. interpolate_pos_encoding=interpolate_pos_encoding,
  397. **kwargs,
  398. )
  399. sequence_output = outputs.last_hidden_state
  400. logits = self.classifier(sequence_output.mean(dim=1))
  401. loss = None
  402. if labels is not None:
  403. loss = self.loss_function(labels, logits, self.config, **kwargs)
  404. return ImageClassifierOutput(
  405. loss=loss,
  406. logits=logits,
  407. hidden_states=outputs.hidden_states,
  408. attentions=outputs.attentions,
  409. )
  410. __all__ = ["IJepaPreTrainedModel", "IJepaModel", "IJepaForImageClassification"]