modeling_mlcd.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/mlcd/modular_mlcd.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_mlcd.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 The HuggingFace Inc. team.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. from collections.abc import Callable
  21. import torch
  22. import torch.nn as nn
  23. from ... import initialization as init
  24. from ...activations import ACT2FN
  25. from ...modeling_layers import GradientCheckpointingLayer
  26. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
  27. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  28. from ...processing_utils import Unpack
  29. from ...utils import TransformersKwargs, auto_docstring, torch_int
  30. from ...utils.generic import merge_with_config_defaults
  31. from ...utils.output_capturing import capture_outputs
  32. from .configuration_mlcd import MLCDVisionConfig
  33. class MLCDMLP(nn.Module):
  34. def __init__(self, config):
  35. super().__init__()
  36. self.config = config
  37. self.activation_fn = ACT2FN[config.hidden_act]
  38. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  39. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  40. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  41. hidden_states = self.fc1(hidden_states)
  42. hidden_states = self.activation_fn(hidden_states)
  43. hidden_states = self.fc2(hidden_states)
  44. return hidden_states
  45. class MLCDRotaryEmbedding(nn.Module):
  46. inv_freq: torch.Tensor # fix linting for `register_buffer`
  47. def __init__(self, dim: int, theta: float = 10000.0) -> None:
  48. super().__init__()
  49. self.dim = dim
  50. self.theta = theta
  51. inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
  52. self.register_buffer("inv_freq", inv_freq, persistent=False)
  53. def forward(self, num_patches_height: int, num_patches_width: int) -> torch.Tensor:
  54. """
  55. Calculate the Rotary Position Embedding (RoPE) for MLCDVisionModel based on the grid size.
  56. Args:
  57. num_patches_height (int): Number of patches in the height dimension.
  58. num_patches_width (int): Number of patches in the width dimension.
  59. Returns:
  60. torch.Tensor: Rotary positional embeddings for the given grid size.
  61. """
  62. # Generate position IDs for height and width dimensions
  63. hpos_ids = (
  64. torch.arange(num_patches_height, device=self.inv_freq.device).unsqueeze(1).expand(-1, num_patches_width)
  65. )
  66. wpos_ids = (
  67. torch.arange(num_patches_width, device=self.inv_freq.device).unsqueeze(0).expand(num_patches_height, -1)
  68. )
  69. # Flatten and stack the position IDs
  70. pos_ids = torch.stack([hpos_ids.flatten(), wpos_ids.flatten()], dim=-1)
  71. # Generate the full rotary positional embeddings for the maximum grid size
  72. max_grid_size = max(num_patches_height, num_patches_width)
  73. seq = torch.arange(max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
  74. rotary_pos_emb_full = torch.outer(seq, self.inv_freq)
  75. # Select and flatten the embeddings based on the position IDs
  76. rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
  77. return rotary_pos_emb
  78. class MLCDVisionEmbeddings(nn.Module):
  79. def __init__(self, config: MLCDVisionConfig):
  80. super().__init__()
  81. self.config = config
  82. self.embed_dim = config.hidden_size
  83. self.image_size = config.image_size
  84. self.patch_size = config.patch_size
  85. self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
  86. self.patch_embedding = nn.Conv2d(
  87. in_channels=config.num_channels,
  88. out_channels=self.embed_dim,
  89. kernel_size=self.patch_size,
  90. stride=self.patch_size,
  91. bias=False,
  92. )
  93. self.num_patches = (self.image_size // self.patch_size) ** 2
  94. self.num_positions = self.num_patches + 1
  95. self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
  96. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  97. """
  98. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  99. images. This method is also adapted to support torch.jit tracing.
  100. Adapted from:
  101. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  102. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  103. """
  104. num_patches = embeddings.shape[1] - 1
  105. position_embedding = self.position_embedding.weight.unsqueeze(0)
  106. num_positions = position_embedding.shape[1] - 1
  107. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  108. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  109. return self.position_embedding(self.position_ids)
  110. class_pos_embed = position_embedding[:, :1]
  111. patch_pos_embed = position_embedding[:, 1:]
  112. dim = embeddings.shape[-1]
  113. new_height = height // self.patch_size
  114. new_width = width // self.patch_size
  115. sqrt_num_positions = torch_int(num_positions**0.5)
  116. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  117. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  118. patch_pos_embed = nn.functional.interpolate(
  119. patch_pos_embed,
  120. size=(new_height, new_width),
  121. mode="bicubic",
  122. align_corners=False,
  123. )
  124. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  125. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  126. def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
  127. batch_size = pixel_values.shape[0]
  128. target_dtype = self.patch_embedding.weight.dtype
  129. # patch_embeds -> shape = [batch, width, grid, grid]
  130. patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
  131. patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
  132. class_embeds = self.class_embedding.expand(batch_size, 1, -1)
  133. embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
  134. return embeddings
  135. def eager_attention_forward(
  136. module: nn.Module,
  137. query: torch.Tensor,
  138. key: torch.Tensor,
  139. value: torch.Tensor,
  140. attention_mask: torch.Tensor | None,
  141. scaling: float,
  142. dropout: float = 0.0,
  143. **kwargs: Unpack[TransformersKwargs],
  144. ):
  145. key_states = repeat_kv(key, module.num_key_value_groups)
  146. value_states = repeat_kv(value, module.num_key_value_groups)
  147. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  148. if attention_mask is not None:
  149. attn_weights = attn_weights + attention_mask
  150. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  151. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  152. attn_output = torch.matmul(attn_weights, value_states)
  153. attn_output = attn_output.transpose(1, 2).contiguous()
  154. return attn_output, attn_weights
  155. def rotate_half(x):
  156. """Rotates half the hidden dims of the input."""
  157. x1 = x[..., : x.shape[-1] // 2]
  158. x2 = x[..., x.shape[-1] // 2 :]
  159. return torch.cat((-x2, x1), dim=-1)
  160. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  161. """
  162. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  163. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  164. """
  165. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  166. if n_rep == 1:
  167. return hidden_states
  168. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  169. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  170. def apply_rotary_pos_emb_vision(
  171. q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
  172. ) -> tuple[torch.Tensor, torch.Tensor]:
  173. orig_q_dtype = q.dtype
  174. orig_k_dtype = k.dtype
  175. q, k = q.float(), k.float()
  176. cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
  177. q_embed = (q * cos) + (rotate_half(q) * sin)
  178. k_embed = (k * cos) + (rotate_half(k) * sin)
  179. q_embed = q_embed.to(orig_q_dtype)
  180. k_embed = k_embed.to(orig_k_dtype)
  181. return q_embed, k_embed
  182. class MLCDAttention(nn.Module):
  183. """Multi-headed attention with RoPE. Refer to papers:
  184. - Attention is all you need:
  185. https://huggingface.co/papers/1706.03762
  186. - RoFormer: Enhanced Transformer with Rotary Position Embedding:
  187. https://huggingface.co/papers/2104.09864
  188. """
  189. def __init__(self, config: MLCDVisionConfig):
  190. super().__init__()
  191. self.config = config
  192. self.embed_dim = config.hidden_size
  193. self.num_heads = config.num_attention_heads
  194. self.head_dim = self.embed_dim // self.num_heads
  195. self.scale = self.head_dim**-0.5
  196. self.dropout = config.attention_dropout
  197. self.is_causal = False
  198. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  199. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  200. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  201. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  202. self.num_key_value_groups = config.num_key_value_groups
  203. def forward(
  204. self,
  205. hidden_states: torch.Tensor,
  206. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  207. attention_mask: torch.Tensor | None = None,
  208. **kwargs: Unpack[TransformersKwargs],
  209. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  210. """Input shape: Batch x Time x Channel"""
  211. batch_size, seq_length = hidden_states.shape[:-1]
  212. # Each of shape: [batch_size, seq_length, num_heads, head_dim]
  213. query_states = self.q_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))
  214. key_states = self.k_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))
  215. value_states = self.v_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))
  216. # Apply positional embeddings
  217. cos = position_embeddings[0].unsqueeze(0).float()
  218. sin = position_embeddings[1].unsqueeze(0).float()
  219. query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
  220. # Each of shape: [batch_size, num_heads, seq_length, head_dim]
  221. query_states = query_states.permute(0, 2, 1, 3).contiguous()
  222. key_states = key_states.permute(0, 2, 1, 3).contiguous()
  223. value_states = value_states.permute(0, 2, 1, 3).contiguous()
  224. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  225. self.config._attn_implementation, eager_attention_forward
  226. )
  227. attn_output, attn_weights = attention_interface(
  228. self,
  229. query_states,
  230. key_states,
  231. value_states,
  232. attention_mask,
  233. dropout=0.0 if not self.training else self.dropout,
  234. scaling=self.scale,
  235. is_causal=self.is_causal,
  236. **kwargs,
  237. )
  238. attn_output = attn_output.permute(1, 0, 2, 3).contiguous() # [seq_length, batch_size, num_heads, head_dim]
  239. attn_output = attn_output.view(seq_length, batch_size, -1) # [seq_length, batch_size, embedding_dim]
  240. attn_output = self.out_proj(attn_output)
  241. attn_output = attn_output.permute(1, 0, 2).contiguous() # [batch_size, seq_length, embedding_dim]
  242. return attn_output, attn_weights
  243. class MLCDEncoderLayer(GradientCheckpointingLayer):
  244. def __init__(self, config: MLCDVisionConfig):
  245. super().__init__()
  246. self.embed_dim = config.hidden_size
  247. self.self_attn = MLCDAttention(config)
  248. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  249. self.mlp = MLCDMLP(config)
  250. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  251. def forward(
  252. self,
  253. hidden_states: torch.Tensor,
  254. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  255. attention_mask: torch.Tensor | None = None,
  256. **kwargs: Unpack[TransformersKwargs],
  257. ) -> tuple[torch.FloatTensor]:
  258. """
  259. Args:
  260. hidden_states (`torch.FloatTensor`):
  261. Input to the layer of shape `(batch, seq_len, embed_dim)`.
  262. Represents the hidden states from the previous layer or the input embeddings.
  263. position_embeddings (`tuple[torch.Tensor, torch.Tensor]`):
  264. A tuple of two tensors, each of shape `(batch, seq_len, embed_dim)`.
  265. Represents absolute positional embeddings for the query and key in the attention mechanism.
  266. attention_mask (`torch.FloatTensor`):
  267. Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
  268. """
  269. residual = hidden_states
  270. hidden_states = self.layer_norm1(hidden_states)
  271. hidden_states, _ = self.self_attn(
  272. hidden_states=hidden_states,
  273. position_embeddings=position_embeddings,
  274. attention_mask=attention_mask,
  275. **kwargs,
  276. )
  277. hidden_states = residual + hidden_states
  278. residual = hidden_states
  279. hidden_states = self.layer_norm2(hidden_states)
  280. hidden_states = self.mlp(hidden_states)
  281. hidden_states = residual + hidden_states
  282. return hidden_states
  283. class MLCDEncoder(nn.Module):
  284. """
  285. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  286. [`MLCDEncoderLayer`].
  287. Args:
  288. config: MLCDVisionConfig
  289. """
  290. def __init__(self, config: MLCDVisionConfig):
  291. """Overwrite dummy `MLCDConfig` to `MLCDVisionConfig`."""
  292. super().__init__()
  293. self.config = config
  294. self.layers = nn.ModuleList([MLCDEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  295. self.gradient_checkpointing = False
  296. def forward(
  297. self,
  298. inputs_embeds: torch.FloatTensor,
  299. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  300. attention_mask: torch.Tensor | None = None,
  301. **kwargs: Unpack[TransformersKwargs],
  302. ) -> tuple | BaseModelOutput:
  303. r"""
  304. Args:
  305. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  306. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  307. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  308. than the model's internal embedding lookup matrix.
  309. position_embeddings (`tuple[torch.Tensor, torch.Tensor]`):
  310. A tuple of two tensors, each of shape `(batch, seq_len, embed_dim)`.
  311. Represents absolute positional embeddings for the query and key in the attention mechanism.
  312. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  313. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  314. - 1 for tokens that are **not masked**,
  315. - 0 for tokens that are **masked**.
  316. [What are attention masks?](../glossary#attention-mask)
  317. """
  318. hidden_states = inputs_embeds
  319. for encoder_layer in self.layers:
  320. hidden_states = encoder_layer(
  321. hidden_states,
  322. position_embeddings,
  323. attention_mask,
  324. **kwargs,
  325. )
  326. return BaseModelOutput(
  327. last_hidden_state=hidden_states,
  328. )
  329. @auto_docstring
  330. class MLCDPreTrainedModel(PreTrainedModel):
  331. config: MLCDVisionConfig
  332. base_model_prefix = "mlcd"
  333. supports_gradient_checkpointing = True
  334. accepts_loss_kwargs = False
  335. _supports_flash_attn = True
  336. _supports_sdpa = True
  337. _supports_flex_attn = True
  338. _supports_attention_backend = True
  339. _can_record_outputs = {
  340. "hidden_states": MLCDEncoderLayer,
  341. "attentions": MLCDAttention,
  342. }
  343. @torch.no_grad()
  344. def _init_weights(self, module):
  345. """Initialize the weights"""
  346. factor = self.config.initializer_factor
  347. if isinstance(module, MLCDVisionEmbeddings):
  348. factor = self.config.initializer_factor
  349. init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
  350. init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
  351. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  352. elif isinstance(module, MLCDAttention):
  353. factor = self.config.initializer_factor
  354. in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
  355. out_proj_std = (module.embed_dim**-0.5) * factor
  356. init.normal_(module.q_proj.weight, std=in_proj_std)
  357. init.normal_(module.k_proj.weight, std=in_proj_std)
  358. init.normal_(module.v_proj.weight, std=in_proj_std)
  359. init.normal_(module.out_proj.weight, std=out_proj_std)
  360. elif isinstance(module, MLCDMLP):
  361. factor = self.config.initializer_factor
  362. in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
  363. fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
  364. init.normal_(module.fc1.weight, std=fc_std)
  365. init.normal_(module.fc2.weight, std=in_proj_std)
  366. elif isinstance(module, MLCDVisionTransformer):
  367. factor = self.config.initializer_factor
  368. pos_emb_std = (module.config.hidden_size // module.config.num_attention_heads // 2) ** -0.5 * factor
  369. init.normal_(module.class_pos_emb, mean=0.0, std=pos_emb_std)
  370. elif isinstance(module, nn.LayerNorm):
  371. init.zeros_(module.bias)
  372. init.ones_(module.weight)
  373. elif isinstance(module, nn.Linear) and module.bias is not None:
  374. init.zeros_(module.bias)
  375. elif isinstance(module, MLCDRotaryEmbedding):
  376. inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim))
  377. init.copy_(module.inv_freq, inv_freq)
  378. class MLCDVisionTransformer(MLCDPreTrainedModel):
  379. config: MLCDVisionConfig
  380. main_input_name = "pixel_values"
  381. input_modalities = ("image",)
  382. _no_split_modules = ["MLCDEncoderLayer"]
  383. def __init__(self, config: MLCDVisionConfig):
  384. super().__init__(config)
  385. self.config = config
  386. embed_dim = config.hidden_size
  387. self.embeddings = MLCDVisionEmbeddings(config)
  388. self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  389. self.encoder = MLCDEncoder(config)
  390. self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  391. self.vision_rotary_embedding = MLCDRotaryEmbedding(config.hidden_size // config.num_attention_heads // 2)
  392. self.class_pos_emb = nn.Parameter(torch.randn(1, config.hidden_size // config.num_attention_heads // 2))
  393. self.post_init()
  394. @merge_with_config_defaults
  395. @capture_outputs(tie_last_hidden_states=False)
  396. @auto_docstring
  397. def forward(
  398. self,
  399. pixel_values: torch.FloatTensor | None = None,
  400. **kwargs: Unpack[TransformersKwargs],
  401. ) -> tuple | BaseModelOutputWithPooling:
  402. if pixel_values is None:
  403. raise ValueError("You have to specify pixel_values")
  404. num_patches_height = pixel_values.shape[-2] // self.config.patch_size
  405. num_patches_width = pixel_values.shape[-1] // self.config.patch_size
  406. rotary_pos_emb = self.vision_rotary_embedding(num_patches_height, num_patches_width)
  407. rotary_pos_emb = rotary_pos_emb.to(self.class_pos_emb.device)
  408. rotary_pos_emb = torch.cat([self.class_pos_emb, rotary_pos_emb], dim=0)
  409. emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
  410. position_embeddings = (emb.cos(), emb.sin())
  411. hidden_states = self.embeddings(pixel_values)
  412. hidden_states = self.pre_layrnorm(hidden_states)
  413. encoder_outputs = self.encoder(
  414. inputs_embeds=hidden_states,
  415. position_embeddings=position_embeddings,
  416. **kwargs,
  417. )
  418. last_hidden_state = encoder_outputs[0]
  419. pooled_output = last_hidden_state[:, 0, :]
  420. pooled_output = self.post_layernorm(pooled_output)
  421. return BaseModelOutputWithPooling(
  422. last_hidden_state=last_hidden_state,
  423. pooler_output=pooled_output,
  424. )
  425. @auto_docstring(
  426. custom_intro="""
  427. The vision model from M_L_C_D without any head or projection on top.
  428. """
  429. )
  430. class MLCDVisionModel(MLCDPreTrainedModel):
  431. config: MLCDVisionConfig
  432. main_input_name = "pixel_values"
  433. input_modalities = ("image",)
  434. _no_split_modules = ["MLCDEncoderLayer"]
  435. def __init__(self, config: MLCDVisionConfig):
  436. super().__init__(config)
  437. self.vision_model = MLCDVisionTransformer(config)
  438. # Initialize weights and apply final processing
  439. self.post_init()
  440. def get_input_embeddings(self) -> nn.Module:
  441. return self.vision_model.embeddings.patch_embedding
  442. @auto_docstring
  443. def forward(
  444. self,
  445. pixel_values: torch.FloatTensor | None = None,
  446. **kwargs: Unpack[TransformersKwargs],
  447. ) -> tuple | BaseModelOutputWithPooling:
  448. r"""
  449. Example:
  450. ```python
  451. >>> import httpx
  452. >>> from io import BytesIO
  453. >>> from PIL import Image
  454. >>> from transformers import AutoProcessor, MLCDVisionModel
  455. >>> model = MLCDVisionModel.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448")
  456. >>> processor = AutoProcessor.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448")
  457. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  458. >>> with httpx.stream("GET", url) as response:
  459. ... image = Image.open(BytesIO(response.read()))
  460. >>> inputs = processor(images=image, return_tensors="pt")
  461. >>> with torch.no_grad():
  462. ... outputs = model(**inputs, output_attentions=True)
  463. >>> features = outputs.last_hidden_state
  464. >>> print(f"Extracted features shape: {features.shape}")
  465. >>> print(f"Number of attention layers: {len(outputs.attentions)}")
  466. >>> print(f"Attention shape: {outputs.attentions[0].shape}")
  467. ```"""
  468. return self.vision_model(
  469. pixel_values=pixel_values,
  470. **kwargs,
  471. )
  472. __all__ = ["MLCDPreTrainedModel", "MLCDVisionModel"]