modular_mlcd.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422
  1. # Copyright 2025 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Callable
  15. import torch
  16. import torch.nn as nn
  17. from huggingface_hub.dataclasses import strict
  18. from ... import initialization as init
  19. from ...configuration_utils import PreTrainedConfig
  20. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
  21. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  22. from ...processing_utils import Unpack
  23. from ...utils import TransformersKwargs, auto_docstring, logging
  24. from ..clip.modeling_clip import (
  25. CLIPMLP,
  26. CLIPAttention,
  27. CLIPEncoder,
  28. CLIPEncoderLayer,
  29. CLIPVisionEmbeddings,
  30. CLIPVisionModel,
  31. CLIPVisionTransformer,
  32. )
  33. from ..llama.modeling_llama import eager_attention_forward
  34. from ..qwen2_vl.modeling_qwen2_vl import VisionRotaryEmbedding, apply_rotary_pos_emb_vision
  # Module-level logger from the package's `logging` utility, named after this module.
  # NOTE(review): not referenced by any code below; presumably kept for parity with generated modeling files.
  logger = logging.get_logger(__name__)
  @auto_docstring(checkpoint="DeepGlint-AI/mlcd-vit-bigG-patch14-336")
  @strict
  class MLCDVisionConfig(PreTrainedConfig):
      r"""
      num_key_value_groups (`int`, *optional*, defaults to 1):
          Number of key-value groups used in Attention.

      Example:

      ```python
      >>> from transformers import MLCDVisionConfig, MLCDVisionModel

      >>> # Initializing a MLCDVisionConfig with DeepGlint-AI/mlcd-vit-bigG-patch14-336 style configuration
      >>> configuration = MLCDVisionConfig()

      >>> # Initializing a MLCDVisionModel (with random weights) from the DeepGlint-AI/mlcd-vit-bigG-patch14-336 style configuration
      >>> model = MLCDVisionModel(configuration)

      >>> # Accessing the model configuration
      >>> configuration = model.config
      ```"""

      model_type = "mlcd_vision_model"
      base_config_key = "vision_config"

      # Field order below defines the generated constructor signature; do not reorder.
      # Encoder width. `hidden_size // num_attention_heads // 2` sizes the rotary embedding and the
      # learned class-token rotary angles in MLCDVisionTransformer.
      hidden_size: int = 1664
      intermediate_size: int = 8192
      # Number of MLCDEncoderLayer blocks; also scales the attention/MLP init std in _init_weights.
      num_hidden_layers: int = 48
      num_attention_heads: int = 16
      # Copied onto MLCDAttention.num_key_value_groups.
      # NOTE(review): q/k/v are all reshaped to `num_heads` heads in MLCDAttention, so values other
      # than 1 look unsupported — confirm before changing.
      num_key_value_groups: int = 1
      num_channels: int = 3
      image_size: int | list[int] | tuple[int, int] = 336
      # NOTE(review): MLCDVisionTransformer computes the patch grid with `shape // patch_size`,
      # so only an int patch size works there despite the wider annotation.
      patch_size: int | list[int] | tuple[int, int] = 14
      hidden_act: str = "gelu"
      layer_norm_eps: float = 1e-5
      attention_dropout: float | int = 0.0
      # Std of the patch-embedding weight init (multiplied by initializer_factor).
      initializer_range: float = 0.02
      # Global multiplier applied to every init std in MLCDPreTrainedModel._init_weights.
      initializer_factor: float = 1.0
  class MLCDMLP(CLIPMLP):
      """MLP block (`fc1` -> activation -> `fc2`) of an MLCD encoder layer, inherited unchanged from `CLIPMLP`."""
  class MLCDRotaryEmbedding(VisionRotaryEmbedding):
      def forward(self, num_patches_height: int, num_patches_width: int) -> torch.Tensor:
          """
          Build the 2D rotary angle table for a `num_patches_height` x `num_patches_width` patch grid.

          Patches are enumerated in row-major order. Each patch receives the rotary angles of its row
          index followed by the rotary angles of its column index.

          Args:
              num_patches_height (int): Number of patches in the height dimension.
              num_patches_width (int): Number of patches in the width dimension.

          Returns:
              torch.Tensor: Angles of shape `(num_patches_height * num_patches_width, 2 * len(self.inv_freq))`.
          """
          device = self.inv_freq.device

          # (row, col) coordinate of every patch; "ij" indexing keeps rows on the first axis.
          row_index = torch.arange(num_patches_height, device=device)
          col_index = torch.arange(num_patches_width, device=device)
          grid_rows, grid_cols = torch.meshgrid(row_index, col_index, indexing="ij")
          coords = torch.stack([grid_rows.flatten(), grid_cols.flatten()], dim=-1)

          # One angle row per position up to the longer grid side, shared by both axes.
          longest_side = max(num_patches_height, num_patches_width)
          positions = torch.arange(longest_side, device=device, dtype=self.inv_freq.dtype)
          angle_table = torch.outer(positions, self.inv_freq)

          # [num_patches, 2, len(inv_freq)] -> [num_patches, 2 * len(inv_freq)]
          return angle_table[coords].flatten(1)
  class MLCDVisionEmbeddings(CLIPVisionEmbeddings):
      def __init__(self, config: MLCDVisionConfig):
          super().__init__(config)
          # MLCD injects position information through rotary embeddings inside attention, so the
          # inherited learned absolute position table is dropped (the `position_ids` buffer is kept).
          del self.position_embedding

      def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
          """Patchify `pixel_values` and prepend the class embedding.

          Returns:
              `torch.Tensor` of shape `(batch_size, 1 + num_patches, embed_dim)`; no positional
              information is added here.
          """
          # Run the patch convolution in the dtype of its own weights.
          conv_dtype = self.patch_embedding.weight.dtype
          patch_grid = self.patch_embedding(pixel_values.to(dtype=conv_dtype))  # [batch, width, grid, grid]
          patch_tokens = patch_grid.flatten(2).transpose(1, 2)  # [batch, grid * grid, width]
          class_token = self.class_embedding.expand(pixel_values.shape[0], 1, -1)
          return torch.cat([class_token, patch_tokens], dim=1)
  class MLCDAttention(CLIPAttention):
      """Multi-headed attention with RoPE. Refer to papers:
      - Attention is all you need:
          https://huggingface.co/papers/1706.03762
      - RoFormer: Enhanced Transformer with Rotary Position Embedding:
          https://huggingface.co/papers/2104.09864
      """

      def __init__(self, config: MLCDVisionConfig):
          super().__init__(config)
          # Stored on the module because the attention interface receives `self`; presumably read by the
          # eager path (imported from llama) to repeat key/value heads — confirm there.
          self.num_key_value_groups = config.num_key_value_groups
          # Vision encoder: every token may attend to every other token.
          self.is_causal = False

      def forward(
          self,
          hidden_states: torch.Tensor,
          position_embeddings: tuple[torch.Tensor, torch.Tensor],
          attention_mask: torch.Tensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple[torch.Tensor, torch.Tensor | None]:
          """
          Args:
              hidden_states (`torch.Tensor` of shape `(batch_size, seq_length, embed_dim)`):
                  Input sequence.
              position_embeddings (`tuple[torch.Tensor, torch.Tensor]`):
                  Rotary `(cos, sin)` tables, each of shape `(seq_length, head_dim)`. They are broadcast over
                  the batch and applied to queries and keys only.
              attention_mask (`torch.Tensor`, *optional*):
                  Mask forwarded unchanged to the selected attention backend.

          Returns:
              `tuple[torch.Tensor, torch.Tensor | None]`: the projected attention output of shape
              `(batch_size, seq_length, embed_dim)` and the attention weights as returned by the backend
              (may be `None`).
          """
          batch_size, seq_length = hidden_states.shape[:-1]
          head_shape = (batch_size, seq_length, self.num_heads, self.head_dim)

          # Each of shape: [batch_size, seq_length, num_heads, head_dim]
          query_states = self.q_proj(hidden_states).reshape(head_shape)
          key_states = self.k_proj(hidden_states).reshape(head_shape)
          value_states = self.v_proj(hidden_states).reshape(head_shape)

          # Apply rotary position embeddings; cos/sin are cast to float32 and given a broadcastable batch dim.
          cos = position_embeddings[0].unsqueeze(0).float()
          sin = position_embeddings[1].unsqueeze(0).float()
          query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)

          # Each of shape: [batch_size, num_heads, seq_length, head_dim]
          query_states = query_states.permute(0, 2, 1, 3).contiguous()
          key_states = key_states.permute(0, 2, 1, 3).contiguous()
          value_states = value_states.permute(0, 2, 1, 3).contiguous()

          attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
              self.config._attn_implementation, eager_attention_forward
          )

          attn_output, attn_weights = attention_interface(
              self,
              query_states,
              key_states,
              value_states,
              attention_mask,
              dropout=0.0 if not self.training else self.dropout,
              scaling=self.scale,
              is_causal=self.is_causal,
              **kwargs,
          )

          # The backend returns [batch_size, seq_length, num_heads, head_dim]. Merging the heads and projecting
          # directly in batch-first layout gives the same result as a seq-first permute round trip, without
          # its two extra full-tensor copies.
          attn_output = attn_output.reshape(batch_size, seq_length, -1)  # [batch_size, seq_length, embedding_dim]
          attn_output = self.out_proj(attn_output)
          return attn_output, attn_weights
  class MLCDEncoderLayer(CLIPEncoderLayer):
      def __init__(self, config: MLCDVisionConfig):
          super().__init__(config)
          # Replace the inherited attention with the rotary-aware variant.
          self.self_attn = MLCDAttention(config)

      def forward(
          self,
          hidden_states: torch.Tensor,
          position_embeddings: tuple[torch.Tensor, torch.Tensor],
          attention_mask: torch.Tensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> torch.Tensor:
          """
          Args:
              hidden_states (`torch.FloatTensor`):
                  Input to the layer of shape `(batch, seq_len, embed_dim)`.
                  Represents the hidden states from the previous layer or the input embeddings.
              position_embeddings (`tuple[torch.Tensor, torch.Tensor]`):
                  Rotary `(cos, sin)` tables, each of shape `(seq_len, head_dim)`, applied to the query and key
                  inside the attention block.
              attention_mask (`torch.FloatTensor`, *optional*):
                  Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.

          Returns:
              `torch.Tensor`: the updated hidden states, with the same shape as `hidden_states`.
          """
          # Pre-norm self-attention sub-block with a residual connection.
          residual = hidden_states
          hidden_states = self.layer_norm1(hidden_states)
          # Attention weights are dropped here; presumably they are collected through the
          # `_can_record_outputs` hooks registered on MLCDAttention instead.
          hidden_states, _ = self.self_attn(
              hidden_states=hidden_states,
              position_embeddings=position_embeddings,
              attention_mask=attention_mask,
              **kwargs,
          )
          hidden_states = residual + hidden_states

          # Pre-norm MLP sub-block with a residual connection.
          residual = hidden_states
          hidden_states = self.layer_norm2(hidden_states)
          hidden_states = self.mlp(hidden_states)
          hidden_states = residual + hidden_states
          return hidden_states
  class MLCDEncoder(CLIPEncoder):
      """
      Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
      [`MLCDEncoderLayer`].

      Args:
          config: MLCDVisionConfig
      """

      def __init__(self, config: MLCDVisionConfig):
          """Overwrite dummy `MLCDConfig` to `MLCDVisionConfig`."""
          super().__init__(config)

      def forward(
          self,
          inputs_embeds: torch.FloatTensor,
          position_embeddings: tuple[torch.Tensor, torch.Tensor],
          attention_mask: torch.Tensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple | BaseModelOutput:
          r"""
          Args:
              inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                  Embedded token sequence to encode. When called from `MLCDVisionTransformer`, this is the
                  class token followed by the patch embeddings, after the pre-encoder layer norm.
              position_embeddings (`tuple[torch.Tensor, torch.Tensor]`):
                  Rotary `(cos, sin)` tables, each of shape `(sequence_length, head_dim)`, shared by every layer.
              attention_mask (`torch.Tensor`, *optional*):
                  Passed unchanged to every [`MLCDEncoderLayer`] and on to the attention backend, so it must
                  already be in the form the backend expects (e.g. an additive mask of shape
                  `(batch_size, 1, query_length, key_length)` with very large negative values at masked
                  positions). `MLCDVisionTransformer` does not pass one.

          Returns:
              [`BaseModelOutput`]: only `last_hidden_state` is populated.
          """
          hidden_states = inputs_embeds
          for encoder_layer in self.layers:
              # NOTE(review): arguments stay positional; the layer wrapper presumably relies on this when
              # gradient checkpointing is enabled — confirm before switching to keywords.
              hidden_states = encoder_layer(
                  hidden_states,
                  position_embeddings,
                  attention_mask,
                  **kwargs,
              )

          return BaseModelOutput(
              last_hidden_state=hidden_states,
          )
  @auto_docstring
  class MLCDPreTrainedModel(PreTrainedModel):
      config: MLCDVisionConfig
      base_model_prefix = "mlcd"
      supports_gradient_checkpointing = True
      accepts_loss_kwargs = False
      _supports_flash_attn = True
      _supports_sdpa = True
      _supports_flex_attn = True
      _supports_attention_backend = True
      _can_record_outputs = {
          "hidden_states": MLCDEncoderLayer,
          "attentions": MLCDAttention,
      }

      @torch.no_grad()
      def _init_weights(self, module):
          """Initialize the weights of `module` in place, dispatching on its type.

          Every standard deviation is scaled by `config.initializer_factor`. Linear biases are zeroed and
          the rotary inverse frequencies are recomputed from the module's `dim` and `theta`.
          """
          factor = self.config.initializer_factor

          if isinstance(module, MLCDVisionEmbeddings):
              class_std = module.embed_dim**-0.5 * factor
              init.normal_(module.class_embedding, mean=0.0, std=class_std)
              init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
              num_positions = module.position_ids.shape[-1]
              init.copy_(module.position_ids, torch.arange(num_positions).expand((1, -1)))

          elif isinstance(module, MLCDAttention):
              # Input projections are additionally shrunk with encoder depth.
              depth_scale = (2 * module.config.num_hidden_layers) ** -0.5
              in_proj_std = (module.embed_dim**-0.5) * depth_scale * factor
              out_proj_std = (module.embed_dim**-0.5) * factor
              for projection in (module.q_proj, module.k_proj, module.v_proj):
                  init.normal_(projection.weight, std=in_proj_std)
              init.normal_(module.out_proj.weight, std=out_proj_std)

          elif isinstance(module, MLCDMLP):
              width = module.config.hidden_size
              depth_scale = (2 * module.config.num_hidden_layers) ** -0.5
              fc_std = (2 * width) ** -0.5 * factor
              in_proj_std = (width**-0.5) * depth_scale * factor
              init.normal_(module.fc1.weight, std=fc_std)
              init.normal_(module.fc2.weight, std=in_proj_std)

          elif isinstance(module, MLCDVisionTransformer):
              rotary_dim = module.config.hidden_size // module.config.num_attention_heads // 2
              init.normal_(module.class_pos_emb, mean=0.0, std=rotary_dim**-0.5 * factor)

          elif isinstance(module, nn.LayerNorm):
              init.zeros_(module.bias)
              init.ones_(module.weight)

          elif isinstance(module, nn.Linear) and module.bias is not None:
              init.zeros_(module.bias)

          elif isinstance(module, MLCDRotaryEmbedding):
              exponents = torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim
              init.copy_(module.inv_freq, 1.0 / (module.theta**exponents))
  class MLCDVisionTransformer(CLIPVisionTransformer):
      def __init__(self, config: MLCDVisionConfig):
          super().__init__(config)
          # Number of rotary angles per token (height and width halves together). `forward` duplicates
          # them before taking cos/sin.
          rotary_dim = config.hidden_size // config.num_attention_heads // 2
          self.vision_rotary_embedding = MLCDRotaryEmbedding(rotary_dim)
          # Learned rotary angles for the class token, which has no grid position.
          self.class_pos_emb = nn.Parameter(torch.randn(1, rotary_dim))

      def forward(
          self,
          pixel_values: torch.FloatTensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple | BaseModelOutputWithPooling:
          """Encode `pixel_values` and pool the (layer-normalized) class token.

          Raises:
              ValueError: if `pixel_values` is None.
          """
          if pixel_values is None:
              raise ValueError("You have to specify pixel_values")

          patch_size = self.config.patch_size
          grid_height = pixel_values.shape[-2] // patch_size
          grid_width = pixel_values.shape[-1] // patch_size

          # Rotary angles: the class token's learned row first, then one row per patch (row-major).
          patch_angles = self.vision_rotary_embedding(grid_height, grid_width).to(self.class_pos_emb.device)
          token_angles = torch.cat([self.class_pos_emb, patch_angles], dim=0)
          token_angles = torch.cat((token_angles, token_angles), dim=-1)
          position_embeddings = (token_angles.cos(), token_angles.sin())

          embedded = self.pre_layrnorm(self.embeddings(pixel_values))
          encoder_outputs = self.encoder(
              inputs_embeds=embedded,
              position_embeddings=position_embeddings,
              **kwargs,
          )

          last_hidden_state = encoder_outputs[0]
          # Pool by taking the class token (position 0) and normalizing it.
          pooled_output = self.post_layernorm(last_hidden_state[:, 0, :])
          return BaseModelOutputWithPooling(
              last_hidden_state=last_hidden_state,
              pooler_output=pooled_output,
          )
  class MLCDVisionModel(CLIPVisionModel):
      def forward(
          self,
          pixel_values: torch.FloatTensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple | BaseModelOutputWithPooling:
          r"""
          Example:

          ```python
          >>> import httpx
          >>> import torch
          >>> from io import BytesIO
          >>> from PIL import Image
          >>> from transformers import AutoProcessor, MLCDVisionModel

          >>> model = MLCDVisionModel.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448")
          >>> processor = AutoProcessor.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448")

          >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
          >>> with httpx.stream("GET", url) as response:
          ...     image = Image.open(BytesIO(response.read()))
          >>> inputs = processor(images=image, return_tensors="pt")

          >>> with torch.no_grad():
          ...     outputs = model(**inputs, output_attentions=True)

          >>> features = outputs.last_hidden_state
          >>> print(f"Extracted features shape: {features.shape}")
          >>> print(f"Number of attention layers: {len(outputs.attentions)}")
          >>> print(f"Attention shape: {outputs.attentions[0].shape}")
          ```"""
          # Thin wrapper: all work happens in the wrapped vision transformer.
          return self.vision_model(
              pixel_values=pixel_values,
              **kwargs,
          )
  # Public names exported by this module.
  __all__ = [
      "MLCDVisionConfig",
      "MLCDPreTrainedModel",
      "MLCDVisionModel",
  ]