modeling_vjepa2.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063
  1. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Callable
  15. from dataclasses import dataclass
  16. import torch
  17. from torch import nn
  18. from ... import initialization as init
  19. from ...activations import ACT2FN
  20. from ...modeling_layers import GradientCheckpointingLayer
  21. from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput
  22. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  23. from ...processing_utils import Unpack
  24. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging
  25. from ...utils.generic import merge_with_config_defaults
  26. from ...utils.output_capturing import OutputRecorder, capture_outputs
  27. from .configuration_vjepa2 import VJEPA2Config
# Module-level logger for this file, named after the module (`__name__`).
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
    VJEPA Predictor outputs that also contains the masked encoder outputs
    """
)
class VJEPA2WithMaskedInputPredictorOutput(ModelOutput):
    r"""
    masked_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, returned when `context_mask` is provided which is applied on VJEPA2Encoder outputs):
        The masked hidden state of the model.
    target_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, returned when `target_mask` is provided which is applied on VJEPA2Encoder outputs):
        The target hidden state of the model.
    """

    # NOTE(review): field order is presumably significant (ModelOutput-style containers are
    # typically converted to tuples in declaration order) -- confirm before reordering.
    # Predictor output sequence; the only field without a default.
    last_hidden_state: torch.FloatTensor
    # Encoder output after applying `context_mask` (see docstring above).
    masked_hidden_state: torch.FloatTensor | None = None
    # Optional per-layer hidden states.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    # Optional per-layer attention weights.
    attentions: tuple[torch.FloatTensor, ...] | None = None
    # Encoder output after applying `target_mask` (see docstring above).
    target_hidden_state: torch.FloatTensor | None = None
  47. @dataclass
  48. @auto_docstring(
  49. custom_intro="""
  50. VJEPA outputs that also contains the masked encoder outputs
  51. Optionally contains the predictor outputs
  52. """
  53. )
  54. class VJEPA2WithMaskedInputModelOutput(ModelOutput):
  55. r"""
  56. masked_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, returned when `context_mask` is provided which is applied on VJEPA2Encoder outputs):
  57. The masked hidden state of the model.
  58. predictor_output (`VJEPA2WithMaskedInputPredictorOutput`, *optional*):
  59. The output from the Predictor module.
  60. """
  61. last_hidden_state: torch.FloatTensor
  62. masked_hidden_state: torch.FloatTensor | None = None
  63. hidden_states: tuple[torch.FloatTensor, ...] | None = None
  64. attentions: tuple[torch.FloatTensor, ...] | None = None
  65. predictor_output: VJEPA2WithMaskedInputPredictorOutput | None = None
  66. def to_tuple(self):
  67. output = list(super().to_tuple())
  68. if isinstance(output[-1], VJEPA2WithMaskedInputPredictorOutput):
  69. output[-1] = output[-1].to_tuple()
  70. return tuple(output)
  71. class VJEPA2PatchEmbeddings3D(nn.Module):
  72. """
  73. Image to Patch Embedding
  74. """
  75. def __init__(
  76. self,
  77. config: VJEPA2Config,
  78. hidden_size: int = 1024,
  79. ):
  80. super().__init__()
  81. self.patch_size = config.patch_size
  82. self.tubelet_size = config.tubelet_size
  83. self.hidden_size = hidden_size
  84. self.proj = nn.Conv3d(
  85. in_channels=config.in_chans,
  86. out_channels=hidden_size,
  87. kernel_size=(config.tubelet_size, config.patch_size, config.patch_size),
  88. stride=(config.tubelet_size, config.patch_size, config.patch_size),
  89. )
  90. @staticmethod
  91. def num_patches(config):
  92. return (
  93. (config.frames_per_clip // config.tubelet_size)
  94. * (config.crop_size // config.patch_size)
  95. * (config.crop_size // config.patch_size)
  96. )
  97. def forward(self, pixel_values_videos: torch.Tensor) -> torch.Tensor:
  98. x = self.proj(pixel_values_videos).flatten(2).transpose(1, 2)
  99. return x
  100. class VJEPA2Embeddings(nn.Module):
  101. """
  102. Construct mask token, position and patch embeddings.
  103. """
  104. def __init__(self, config: VJEPA2Config, hidden_size: int = 1024):
  105. super().__init__()
  106. self.config = config
  107. self.hidden_size = hidden_size
  108. self.patch_embeddings = VJEPA2PatchEmbeddings3D(config, hidden_size=hidden_size)
  109. self.num_patches = self.patch_embeddings.num_patches
  110. self.patch_size = config.patch_size
  111. def forward(self, pixel_values_videos: torch.Tensor) -> torch.Tensor:
  112. num_frames = pixel_values_videos.shape[1]
  113. # Swap `frames` and `channels` dims, the result is:
  114. # (batch_size, channels, num_frames, height, width)
  115. pixel_values_videos = pixel_values_videos.permute(0, 2, 1, 3, 4)
  116. # For some cases, if the input vision (image/video) consists of num_frames < tubelet_size,
  117. # then embedding lookup fails. In these cases, we duplicate the frames.
  118. if num_frames < self.config.tubelet_size:
  119. pixel_values_videos = pixel_values_videos.repeat(1, 1, self.config.tubelet_size, 1, 1)
  120. target_dtype = self.patch_embeddings.proj.weight.dtype
  121. pixel_values_videos = pixel_values_videos.to(dtype=target_dtype)
  122. embeddings = self.patch_embeddings(pixel_values_videos)
  123. return embeddings
  124. # Adapted from transformers.models.vit.modeling_vit.eager_attention_forward
  125. def eager_attention_forward(
  126. module: nn.Module,
  127. query: torch.Tensor,
  128. key: torch.Tensor,
  129. value: torch.Tensor,
  130. attention_mask: torch.Tensor | None,
  131. scaling: float,
  132. dropout: float = 0.0,
  133. **kwargs,
  134. ):
  135. # Take the dot product between "query" and "key" to get the raw attention scores.
  136. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  137. # Normalize the attention scores to probabilities.
  138. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  139. # This is actually dropping out entire tokens to attend to, which might
  140. # seem a bit unusual, but is taken from the original Transformer paper.
  141. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  142. attn_output = torch.matmul(attn_weights, value)
  143. attn_output = attn_output.transpose(1, 2).contiguous()
  144. return attn_output, attn_weights
  145. def rotate_queries_or_keys(x, pos):
  146. B, num_heads, N, D = x.size()
  147. # similar to inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
  148. # they are computing this every time. instead HF style is to compute the inv_freq once and store it
  149. # -- compute angle for each position
  150. omega = torch.arange(D // 2, dtype=x.dtype, device=x.device)
  151. omega /= D / 2.0
  152. omega = 1.0 / 10000**omega # (D/2,)
  153. freq = pos.unsqueeze(-1) * omega # (..., N, D/2), outer product
  154. # -- build rotation matrix and apply
  155. emb_sin = freq.sin() # (..., N, D/2)
  156. emb_cos = freq.cos() # (..., N, D/2)
  157. emb_sin = emb_sin.repeat(1, 1, 1, 2)
  158. emb_cos = emb_cos.repeat(1, 1, 1, 2)
  159. # --
  160. y = x.unflatten(-1, (-1, 2))
  161. y1, y2 = y.unbind(dim=-1)
  162. y = torch.stack((-y2, y1), dim=-1)
  163. y = y.flatten(-2)
  164. return (x * emb_cos) + (y * emb_sin)
  165. class VJEPA2RopeAttention(nn.Module):
  166. def __init__(
  167. self,
  168. config: VJEPA2Config,
  169. hidden_size: int = 1024,
  170. num_attention_heads: int = 16,
  171. ):
  172. super().__init__()
  173. self.config = config
  174. self.hidden_size = hidden_size
  175. self.num_attention_heads = num_attention_heads
  176. if hidden_size % num_attention_heads != 0:
  177. raise ValueError(
  178. f"The hidden size {(hidden_size,)} is not a multiple of the number of attention "
  179. f"heads {num_attention_heads}."
  180. )
  181. self.attention_head_size = int(hidden_size / num_attention_heads)
  182. self.all_head_size = self.num_attention_heads * self.attention_head_size
  183. self.query = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
  184. self.key = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
  185. self.value = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
  186. self.proj = nn.Linear(hidden_size, hidden_size)
  187. self.dropout_prob = config.attention_probs_dropout_prob
  188. self.dropout = nn.Dropout(self.dropout_prob)
  189. self.grid_size = self.config.crop_size // self.config.patch_size
  190. self.grid_depth = self.config.frames_per_clip // self.config.tubelet_size
  191. self.d_dim = int(2 * ((self.attention_head_size // 3) // 2))
  192. self.h_dim = int(2 * ((self.attention_head_size // 3) // 2))
  193. self.w_dim = int(2 * ((self.attention_head_size // 3) // 2))
  194. self.scaling = self.attention_head_size**-0.5
  195. self.is_causal = False
  196. def _get_frame_pos(self, ids):
  197. tokens_per_frame = int(self.grid_size * self.grid_size)
  198. return ids // tokens_per_frame
  199. def _get_height_pos(self, ids):
  200. # Remove frame component from ids
  201. tokens_per_frame = int(self.grid_size * self.grid_size)
  202. frame_ids = self._get_frame_pos(ids)
  203. ids = ids - tokens_per_frame * frame_ids
  204. # --
  205. tokens_per_row = self.grid_size
  206. return ids // tokens_per_row
  207. def get_position_ids(self, x, masks=None):
  208. device = x.device
  209. token_size = x.size(1)
  210. # Note: when masks is none, we use a 1d id instead of Bxnum_attention_heads mask,
  211. # as 1d vector is broadcasted to the correct shapes.
  212. if masks is not None:
  213. ids = masks.unsqueeze(1).repeat(1, self.num_attention_heads, 1)
  214. else:
  215. ids = torch.arange(token_size, device=device)
  216. # change to allow for extrapolation
  217. tokens_per_frame = int(self.grid_size * self.grid_size)
  218. frame_ids = self._get_frame_pos(ids)
  219. # --
  220. tokens_per_row = self.grid_size
  221. height_ids = self._get_height_pos(ids)
  222. # --
  223. # Remove frame component from ids (1st term) and height component (2nd term)
  224. width_ids = (ids - tokens_per_frame * frame_ids) - tokens_per_row * height_ids
  225. return frame_ids, height_ids, width_ids
  226. def apply_rotary_embeddings(self, qk, pos_ids):
  227. d_mask, h_mask, w_mask = pos_ids
  228. s = 0
  229. qkd = rotate_queries_or_keys(qk[..., s : s + self.d_dim], pos=d_mask)
  230. s += self.d_dim
  231. qkh = rotate_queries_or_keys(qk[..., s : s + self.h_dim], pos=h_mask)
  232. s += self.h_dim
  233. qkw = rotate_queries_or_keys(qk[..., s : s + self.w_dim], pos=w_mask)
  234. s += self.w_dim
  235. # Combine rotated dimension
  236. if s < self.attention_head_size:
  237. qkr = qk[..., s:]
  238. qk = torch.cat([qkd, qkh, qkw, qkr], dim=-1)
  239. else:
  240. qk = torch.cat([qkd, qkh, qkw], dim=-1)
  241. return qk
  242. def forward(
  243. self,
  244. hidden_states,
  245. position_mask: torch.Tensor | None = None,
  246. ) -> tuple[torch.Tensor, torch.Tensor]:
  247. input_shape = hidden_states.shape[:-1]
  248. hidden_shape = (*input_shape, -1, self.attention_head_size)
  249. query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  250. key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
  251. value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
  252. pos_ids = self.get_position_ids(hidden_states, masks=position_mask)
  253. key_layer = self.apply_rotary_embeddings(key_layer, pos_ids)
  254. query_layer = self.apply_rotary_embeddings(query_layer, pos_ids)
  255. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  256. self.config._attn_implementation, eager_attention_forward
  257. )
  258. context_layer, attention_probs = attention_interface(
  259. self,
  260. query_layer,
  261. key_layer,
  262. value_layer,
  263. None,
  264. is_causal=self.is_causal,
  265. scaling=self.scaling,
  266. dropout=0.0 if not self.training else self.dropout_prob,
  267. )
  268. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  269. context_layer = self.proj(context_layer.reshape(new_context_layer_shape))
  270. return context_layer, attention_probs
  271. # Adapted from transformers.models.beit.modeling_dinov2.drop_path
  272. def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
  273. """
  274. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  275. """
  276. if drop_prob == 0.0 or not training:
  277. return input
  278. keep_prob = 1 - drop_prob
  279. shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  280. random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
  281. random_tensor.floor_() # binarize
  282. output = input.div(keep_prob) * random_tensor
  283. return output
  284. # Adapted from transformers.models.beit.modeling_beit.BeitDropPath
  285. class VJEPA2DropPath(nn.Module):
  286. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  287. def __init__(self, drop_prob: float | None = None):
  288. super().__init__()
  289. self.drop_prob = drop_prob
  290. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  291. return drop_path(hidden_states, self.drop_prob, self.training)
  292. def extra_repr(self) -> str:
  293. return f"p={self.drop_prob}"
  294. class VJEPA2MLP(nn.Module):
  295. def __init__(self, config: VJEPA2Config, hidden_size: int = 1024, mlp_ratio: float = 4.0):
  296. super().__init__()
  297. in_features = out_features = hidden_size
  298. hidden_features = int(hidden_size * mlp_ratio)
  299. self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
  300. self.activation = ACT2FN[config.hidden_act]
  301. self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
  302. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  303. hidden_state = self.fc1(hidden_state)
  304. hidden_state = self.activation(hidden_state)
  305. hidden_state = self.fc2(hidden_state)
  306. return hidden_state
  307. class VJEPA2Layer(GradientCheckpointingLayer):
  308. """This corresponds to the Block class in the original implementation."""
  309. def __init__(
  310. self,
  311. config: VJEPA2Config,
  312. drop_path_rate: float = 0.0,
  313. hidden_size: int = 1024,
  314. num_attention_heads: int = 16,
  315. mlp_ratio: float = 4.0,
  316. ):
  317. super().__init__()
  318. self.config = config
  319. self.hidden_size = hidden_size
  320. self.num_attention_heads = num_attention_heads
  321. self.mlp_ratio = mlp_ratio
  322. self.norm1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
  323. self.attention = VJEPA2RopeAttention(config, hidden_size, num_attention_heads)
  324. self.drop_path = VJEPA2DropPath(drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
  325. self.norm2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
  326. self.mlp = VJEPA2MLP(config, hidden_size=hidden_size, mlp_ratio=mlp_ratio)
  327. def forward(
  328. self,
  329. hidden_states: torch.Tensor,
  330. position_mask: torch.Tensor | None = None,
  331. **kwargs: Unpack[TransformersKwargs],
  332. ) -> tuple[torch.Tensor, ...]:
  333. # Self-Attention
  334. residual = hidden_states
  335. hidden_states = self.norm1(hidden_states)
  336. attention_output, attn_weights = self.attention(
  337. hidden_states,
  338. position_mask=position_mask, # position mask for context/target selection
  339. )
  340. hidden_states = self.drop_path(attention_output) + residual
  341. # MLP
  342. residual = hidden_states
  343. hidden_states = self.norm2(hidden_states)
  344. hidden_states = self.mlp(hidden_states)
  345. hidden_states = self.drop_path(hidden_states) + residual
  346. # Add self attentions if we output attention weights
  347. return hidden_states, attn_weights
  348. class VJEPA2Encoder(nn.Module):
  349. def __init__(self, config: VJEPA2Config):
  350. super().__init__()
  351. self.config = config
  352. self.embeddings = VJEPA2Embeddings(config, hidden_size=config.hidden_size)
  353. drop_path_rates = [
  354. (config.drop_path_rate * i / (config.num_hidden_layers - 1) if config.num_hidden_layers > 1 else 0.0)
  355. for i in range(config.num_hidden_layers)
  356. ]
  357. self.layer = nn.ModuleList(
  358. [
  359. VJEPA2Layer(
  360. config,
  361. drop_path_rate=drop_path_rates[i],
  362. hidden_size=config.hidden_size,
  363. num_attention_heads=config.num_attention_heads,
  364. mlp_ratio=config.mlp_ratio,
  365. )
  366. for i in range(config.num_hidden_layers)
  367. ]
  368. )
  369. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  370. self.gradient_checkpointing = False
  371. def forward(
  372. self,
  373. pixel_values_videos: torch.Tensor | None = None,
  374. **kwargs: Unpack[TransformersKwargs],
  375. ) -> BaseModelOutput:
  376. hidden_states = self.embeddings(pixel_values_videos)
  377. for i, layer_module in enumerate(self.layer):
  378. layer_outputs = layer_module(hidden_states, None, **kwargs)
  379. hidden_states = layer_outputs[0]
  380. hidden_states = self.layernorm(hidden_states)
  381. return BaseModelOutput(
  382. last_hidden_state=hidden_states,
  383. )
  384. def apply_masks(tensor: torch.Tensor, masks: list[torch.Tensor]) -> torch.Tensor:
  385. """
  386. Args:
  387. tensor (`torch.Tensor`):
  388. Tensor of shape [batch_size, num_patches, feature_dim]
  389. masks (`List[torch.Tensor]`):
  390. List of tensors of shape [batch_size, num_patches] containing indices of patches to keep
  391. """
  392. all_masked_tensors = []
  393. for mask in masks:
  394. mask = mask.to(tensor.device)
  395. mask_keep = mask.unsqueeze(-1).repeat(1, 1, tensor.size(-1))
  396. all_masked_tensors += [torch.gather(tensor, dim=1, index=mask_keep)]
  397. return torch.cat(all_masked_tensors, dim=0)
  398. class VJEPA2PredictorEmbeddings(nn.Module):
  399. """
  400. Construct mask token, position and patch embeddings.
  401. """
  402. def __init__(self, config: VJEPA2Config):
  403. super().__init__()
  404. self.config = config
  405. self.predictor_embeddings = nn.Linear(config.hidden_size, config.pred_hidden_size)
  406. self.num_mask_tokens = 0
  407. self.zero_init_mask_tokens = config.pred_zero_init_mask_tokens
  408. self.num_mask_tokens = config.pred_num_mask_tokens
  409. self.mask_tokens = nn.Parameter(torch.zeros(self.num_mask_tokens, 1, 1, config.pred_hidden_size))
  410. self.patch_size = config.patch_size
  411. self.config = config
  412. @staticmethod
  413. def num_patches(config):
  414. if config.frames_per_clip > 1:
  415. return (
  416. (config.frames_per_clip // config.tubelet_size)
  417. * (config.crop_size // config.patch_size)
  418. * (config.crop_size // config.patch_size)
  419. )
  420. else:
  421. return (config.crop_size // config.patch_size) * (config.crop_size // config.patch_size)
  422. def forward(
  423. self,
  424. hidden_states: torch.Tensor,
  425. context_mask: list[torch.Tensor],
  426. target_mask: list[torch.Tensor],
  427. mask_index: int = 1,
  428. ) -> tuple[torch.Tensor, torch.Tensor]:
  429. """
  430. hidden_states : encoder outputs (context)
  431. context_mask: tokens of the context (outputs from the encoder)
  432. target_mask: tokens to predict
  433. mask_index: index of the target mask to choose (useful for multiclip?)
  434. """
  435. B = hidden_states.size(0)
  436. context = self.predictor_embeddings(hidden_states)
  437. # Make target tokens
  438. mask_index = mask_index % self.num_mask_tokens
  439. target = self.mask_tokens[mask_index]
  440. # Note: this is problematic if the config isn't initialized with the right frames_per_clip value,
  441. # e.g. for scenarios if we want to run predictor for more tokens than in the config.
  442. # target = target.repeat(B, self.num_patches(self.config), 1)
  443. # Remedy: use the provided target mask to get the max patch num
  444. max_patch_num = target_mask[0].max() + 1 # one extra to include the last patch
  445. target = target.repeat(B, max_patch_num, 1)
  446. target = apply_masks(target, target_mask)
  447. # Concatenate context & target tokens
  448. context = context.repeat(len(context_mask), 1, 1)
  449. embeddings = torch.cat([context, target], dim=1)
  450. # Positions of context & target tokens
  451. cm = torch.cat(context_mask, dim=0)
  452. tm = torch.cat(target_mask, dim=0)
  453. masks = torch.cat([cm, tm], dim=1)
  454. return embeddings, masks
  455. class VJEPA2Predictor(nn.Module):
  456. def __init__(self, config: VJEPA2Config):
  457. super().__init__()
  458. self.config = config
  459. self.gradient_checkpointing = False
  460. self.embeddings = VJEPA2PredictorEmbeddings(config)
  461. drop_path_rates = [
  462. (
  463. config.drop_path_rate * i / (config.pred_num_hidden_layers - 1)
  464. if config.pred_num_hidden_layers > 1
  465. else 0.0
  466. )
  467. for i in range(config.pred_num_hidden_layers)
  468. ]
  469. self.layer = nn.ModuleList(
  470. [
  471. VJEPA2Layer(
  472. config,
  473. drop_path_rate=drop_path_rates[i],
  474. hidden_size=config.pred_hidden_size,
  475. num_attention_heads=config.pred_num_attention_heads,
  476. mlp_ratio=config.pred_mlp_ratio,
  477. )
  478. for i in range(config.pred_num_hidden_layers)
  479. ]
  480. )
  481. self.layernorm = nn.LayerNorm(config.pred_hidden_size, eps=config.layer_norm_eps)
  482. self.proj = nn.Linear(config.pred_hidden_size, config.hidden_size, bias=True)
  483. def sort_tokens(self, hidden_states, position_masks, argsort):
  484. # gather position masks
  485. argsort = argsort.to(position_masks.device)
  486. position_masks = torch.gather(position_masks, dim=1, index=argsort)
  487. # gather hidden states
  488. argsort = argsort.to(hidden_states.device)
  489. hidden_states_argsort = argsort.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1))
  490. hidden_states = torch.gather(hidden_states, dim=1, index=hidden_states_argsort)
  491. return hidden_states, position_masks
  492. def unsort_tokens(self, hidden_states, argsort):
  493. argsort = argsort.to(hidden_states.device)
  494. reverse_argsort = torch.argsort(argsort, dim=1)
  495. reverse_argsort = reverse_argsort.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1))
  496. hidden_states = torch.gather(hidden_states, dim=1, index=reverse_argsort)
  497. return hidden_states
  498. def forward(
  499. self,
  500. encoder_hidden_states: torch.Tensor,
  501. context_mask: list[torch.Tensor],
  502. target_mask: list[torch.Tensor],
  503. **kwargs: Unpack[TransformersKwargs],
  504. ) -> BaseModelOutput:
  505. # mask out the encoder hidden states
  506. # this is implemented here as in VJEPA training a separate encoder is used for target
  507. encoder_hidden_states = apply_masks(encoder_hidden_states, context_mask)
  508. _, N_ctxt, D = encoder_hidden_states.shape
  509. hidden_states, position_masks = self.embeddings(encoder_hidden_states, context_mask, target_mask)
  510. # Put tokens in sorted order
  511. argsort = torch.argsort(position_masks, dim=1) # [B, N]
  512. hidden_states, position_masks = self.sort_tokens(hidden_states, position_masks, argsort)
  513. for i, layer_module in enumerate(self.layer):
  514. layer_outputs = layer_module(hidden_states, position_masks, **kwargs)
  515. hidden_states = layer_outputs[0]
  516. hidden_states = self.layernorm(hidden_states)
  517. # unsort and extract the predicted tokens
  518. hidden_states = self.unsort_tokens(hidden_states, argsort)
  519. hidden_states = hidden_states[:, N_ctxt:]
  520. # projection
  521. hidden_states = self.proj(hidden_states)
  522. return BaseModelOutput(
  523. last_hidden_state=hidden_states,
  524. )
  525. class VJEPA2PoolerSelfAttention(nn.Module):
  526. """Multi-headed attention from 'Attention Is All You Need' paper"""
  527. def __init__(self, config: VJEPA2Config):
  528. super().__init__()
  529. self.config = config
  530. self.embed_dim = config.hidden_size
  531. self.num_heads = config.num_attention_heads
  532. self.head_dim = self.embed_dim // self.num_heads
  533. if self.head_dim * self.num_heads != self.embed_dim:
  534. raise ValueError(
  535. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  536. f" {self.num_heads})."
  537. )
  538. self.scale = self.head_dim**-0.5
  539. self.dropout = config.attention_dropout
  540. self.is_causal = False
  541. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  542. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  543. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  544. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  545. def forward(
  546. self,
  547. hidden_states: torch.Tensor,
  548. attention_mask: torch.Tensor | None = None,
  549. ) -> tuple[torch.Tensor, torch.Tensor]:
  550. """Input shape: Batch x Time x Channel"""
  551. input_shape = hidden_states.shape[:-1]
  552. hidden_shape = (*input_shape, -1, self.head_dim)
  553. queries = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  554. keys = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  555. values = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  556. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  557. self.config._attn_implementation, eager_attention_forward
  558. )
  559. attn_output, attn_weights = attention_interface(
  560. self,
  561. queries,
  562. keys,
  563. values,
  564. attention_mask,
  565. is_causal=self.is_causal,
  566. scaling=self.scale,
  567. dropout=0.0 if not self.training else self.dropout,
  568. )
  569. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  570. attn_output = self.out_proj(attn_output)
  571. return attn_output, attn_weights
  572. class VJEPA2PoolerCrossAttention(nn.Module):
  573. """It's different from other cross-attention layers, doesn't have output projection layer (o_proj)"""
  574. # in case of modular refactoring - o_proj can be replaces with nn.Identity()
  575. def __init__(self, config: VJEPA2Config):
  576. super().__init__()
  577. self.config = config
  578. self.embed_dim = config.hidden_size
  579. self.num_heads = config.num_attention_heads
  580. self.head_dim = self.embed_dim // self.num_heads
  581. if self.head_dim * self.num_heads != self.embed_dim:
  582. raise ValueError(
  583. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  584. f" {self.num_heads})."
  585. )
  586. self.scale = self.head_dim**-0.5
  587. self.dropout = config.attention_dropout
  588. self.is_causal = False
  589. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  590. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  591. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  592. def forward(
  593. self,
  594. queries: torch.Tensor,
  595. keys: torch.Tensor,
  596. values: torch.Tensor,
  597. attention_mask: torch.Tensor | None = None,
  598. ) -> tuple[torch.Tensor, torch.Tensor]:
  599. """Input shape: Batch x Time x Channel"""
  600. batch_size, q_seq_length, embed_dim = queries.shape
  601. kv_seq_length = keys.shape[1]
  602. queries = self.q_proj(queries)
  603. keys = self.k_proj(keys)
  604. values = self.v_proj(values)
  605. queries = queries.view(batch_size, q_seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  606. keys = keys.view(batch_size, kv_seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  607. values = values.view(batch_size, kv_seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  608. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  609. self.config._attn_implementation, eager_attention_forward
  610. )
  611. attn_output, attn_weights = attention_interface(
  612. self,
  613. queries,
  614. keys,
  615. values,
  616. attention_mask,
  617. is_causal=self.is_causal,
  618. scaling=self.scale,
  619. dropout=0.0 if not self.training else self.dropout,
  620. )
  621. attn_output = attn_output.reshape(batch_size, q_seq_length, embed_dim).contiguous()
  622. return attn_output, attn_weights
  623. # Modified from SiglipEncoderLayer, but we have to propagate proper hidden_size to VJEPA2MLP
  624. class VJEPA2PoolerSelfAttentionLayer(GradientCheckpointingLayer):
  625. def __init__(self, config: VJEPA2Config):
  626. super().__init__()
  627. self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  628. self.self_attn = VJEPA2PoolerSelfAttention(config)
  629. self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  630. self.mlp = VJEPA2MLP(config, hidden_size=config.hidden_size)
  631. def forward(
  632. self,
  633. hidden_states: torch.Tensor,
  634. attention_mask: torch.Tensor,
  635. ) -> tuple[torch.Tensor, torch.Tensor]:
  636. """
  637. Args:
  638. hidden_states (`torch.FloatTensor`):
  639. Input to the layer of shape `(batch, seq_len, embed_dim)`.
  640. attention_mask (`torch.FloatTensor`):
  641. Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
  642. """
  643. residual = hidden_states
  644. hidden_states = self.layer_norm1(hidden_states)
  645. hidden_states, attn_weights = self.self_attn(
  646. hidden_states=hidden_states,
  647. attention_mask=attention_mask,
  648. )
  649. hidden_states = residual + hidden_states
  650. residual = hidden_states
  651. hidden_states = self.layer_norm2(hidden_states)
  652. hidden_states = self.mlp(hidden_states)
  653. hidden_states = residual + hidden_states
  654. return hidden_states, attn_weights
  655. class VJEPA2PoolerCrossAttentionLayer(GradientCheckpointingLayer):
  656. def __init__(self, config: VJEPA2Config):
  657. super().__init__()
  658. self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  659. self.cross_attn = VJEPA2PoolerCrossAttention(config)
  660. self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  661. self.mlp = VJEPA2MLP(config, hidden_size=config.hidden_size)
  662. def forward(
  663. self,
  664. queries: torch.Tensor,
  665. hidden_state: torch.Tensor,
  666. attention_mask: torch.Tensor | None = None,
  667. ) -> tuple[torch.Tensor, torch.Tensor]:
  668. # Apply cross-attention
  669. residual = queries
  670. hidden_state = self.layer_norm1(hidden_state)
  671. hidden_state, *attn_weights = self.cross_attn(
  672. queries,
  673. hidden_state,
  674. hidden_state,
  675. attention_mask=attention_mask,
  676. )
  677. hidden_state = residual + hidden_state
  678. # Apply MLP
  679. residual = hidden_state
  680. hidden_state = self.layer_norm2(hidden_state)
  681. hidden_state = self.mlp(hidden_state)
  682. hidden_state = residual + hidden_state
  683. return hidden_state, *attn_weights
  684. class VJEPA2AttentivePooler(nn.Module):
  685. """Attentive Pooler"""
  686. def __init__(self, config: VJEPA2Config):
  687. super().__init__()
  688. self.query_tokens = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  689. self.cross_attention_layer = VJEPA2PoolerCrossAttentionLayer(config)
  690. self.self_attention_layers = nn.ModuleList(
  691. [VJEPA2PoolerSelfAttentionLayer(config) for _ in range(config.num_pooler_layers)]
  692. )
  693. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  694. for layer in self.self_attention_layers:
  695. hidden_state = layer(hidden_state, attention_mask=None)[0]
  696. queries = self.query_tokens.repeat(hidden_state.shape[0], 1, 1)
  697. hidden_state = self.cross_attention_layer(queries, hidden_state)[0]
  698. return hidden_state.squeeze(1)
  699. @auto_docstring
  700. class VJEPA2PreTrainedModel(PreTrainedModel):
  701. config: VJEPA2Config
  702. base_model_prefix = "vjepa2"
  703. main_input_name = "pixel_values_videos"
  704. input_modalities = "video"
  705. supports_gradient_checkpointing = True
  706. _no_split_modules = [
  707. "VJEPA2Layer",
  708. "VJEPA2PoolerSelfAttentionLayer",
  709. "VJEPA2PoolerCrossAttentionLayer",
  710. "VJEPA2PredictorEmbeddings",
  711. ]
  712. _supports_sdpa = True
  713. _supports_flash_attn = True
  714. _can_record_outputs = {
  715. "hidden_states": OutputRecorder(VJEPA2Layer, layer_name="encoder.layer"),
  716. "attentions": OutputRecorder(VJEPA2RopeAttention, index=1, layer_name="encoder.layer"),
  717. }
  718. @torch.no_grad()
  719. def _init_weights(self, module):
  720. """Initialize the weights"""
  721. init_std = self.config.initializer_range
  722. if isinstance(module, VJEPA2AttentivePooler):
  723. init.trunc_normal_(module.query_tokens, std=init_std)
  724. for i, layer in enumerate(module.self_attention_layers, 1):
  725. std = init_std / (i**0.5)
  726. init.trunc_normal_(layer.self_attn.out_proj.weight, std=std)
  727. init.trunc_normal_(layer.mlp.fc2.weight, std=std)
  728. std = init_std / (len(module.self_attention_layers) + 1) ** 0.5
  729. init.trunc_normal_(module.cross_attention_layer.mlp.fc2.weight, std=std)
  730. elif isinstance(module, VJEPA2PredictorEmbeddings):
  731. if module.zero_init_mask_tokens:
  732. init.zeros_(module.mask_tokens)
  733. else:
  734. init.trunc_normal_(module.mask_tokens, std=init_std)
  735. elif isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)):
  736. init.trunc_normal_(module.weight, std=init_std)
  737. if module.bias is not None:
  738. init.zeros_(module.bias)
  739. elif isinstance(module, nn.LayerNorm):
  740. init.zeros_(module.bias)
  741. init.ones_(module.weight)
  742. @auto_docstring
  743. class VJEPA2Model(VJEPA2PreTrainedModel):
  744. def __init__(self, config: VJEPA2Config):
  745. super().__init__(config)
  746. self.config = config
  747. self.encoder = VJEPA2Encoder(config)
  748. self.predictor = VJEPA2Predictor(config)
  749. # Initialize weights and apply final processing
  750. self.post_init()
  751. def get_input_embeddings(self) -> VJEPA2PatchEmbeddings3D:
  752. return self.encoder.embeddings.patch_embeddings
  753. @merge_with_config_defaults
  754. @capture_outputs(tie_last_hidden_states=False)
  755. @auto_docstring
  756. def forward(
  757. self,
  758. pixel_values_videos: torch.Tensor,
  759. context_mask: list[torch.Tensor] | None = None,
  760. target_mask: list[torch.Tensor] | None = None,
  761. skip_predictor: bool = False,
  762. **kwargs: Unpack[TransformersKwargs],
  763. ) -> VJEPA2WithMaskedInputModelOutput:
  764. r"""
  765. context_mask (`torch.Tensor` with shape `[batch_size, patch_size, 1]`, *optional*):
  766. The mask position ids indicating which encoder output patches are going to be exposed to the predictor.
  767. By default, this mask is created as torch.arange(N).unsqueeze(0).repeat(B,1), indicating full context
  768. available to the predictor.
  769. target_mask (`torch.Tensor` with shape `[batch_size, patch_size, 1]`, *optional*):
  770. The mask position ids indicating which encoder output patches are going to be used as a prediction target
  771. for the predictor. By default, this mask is created as torch.arange(N).unsqueeze(0).repeat(B,1), indicating
  772. that the predictor should predict all encoder patches.
  773. skip_predictor (bool):
  774. flag to skip the predictor forward, useful if you just need the encoder outputs
  775. """
  776. if pixel_values_videos is None:
  777. raise ValueError("You have to specify pixel_values_videos")
  778. encoder_outputs: BaseModelOutput = self.encoder(
  779. pixel_values_videos=pixel_values_videos,
  780. **kwargs,
  781. )
  782. sequence_output = encoder_outputs.last_hidden_state
  783. if context_mask is None and target_mask is None:
  784. B = pixel_values_videos.size(0)
  785. N = sequence_output.size(1) # ensure we are using dynamic patch size
  786. context_mask = [torch.arange(N, device=pixel_values_videos.device).unsqueeze(0).repeat((B, 1))]
  787. target_mask = [torch.arange(N, device=pixel_values_videos.device).unsqueeze(0).repeat((B, 1))]
  788. if not skip_predictor:
  789. predictor_outputs: BaseModelOutput = self.predictor(
  790. encoder_hidden_states=sequence_output,
  791. context_mask=context_mask,
  792. target_mask=target_mask,
  793. **kwargs,
  794. )
  795. predictor_output = VJEPA2WithMaskedInputPredictorOutput(
  796. last_hidden_state=predictor_outputs.last_hidden_state,
  797. target_hidden_state=apply_masks(sequence_output, target_mask),
  798. hidden_states=predictor_outputs.hidden_states,
  799. attentions=predictor_outputs.attentions,
  800. )
  801. else:
  802. predictor_output = None
  803. encoder_output = VJEPA2WithMaskedInputModelOutput(
  804. last_hidden_state=sequence_output,
  805. masked_hidden_state=apply_masks(sequence_output, context_mask),
  806. hidden_states=encoder_outputs.hidden_states,
  807. attentions=encoder_outputs.attentions,
  808. predictor_output=predictor_output,
  809. )
  810. return encoder_output
  811. def get_vision_features(self, pixel_values_videos) -> torch.Tensor:
  812. encoder_output = self.forward(pixel_values_videos, skip_predictor=True)
  813. return encoder_output.last_hidden_state
  814. @auto_docstring(
  815. custom_intro="""
  816. V-JEPA 2 Model transformer with a video classification head on top (a linear layer on top of the attentive pooler).
  817. """
  818. )
  819. class VJEPA2ForVideoClassification(VJEPA2PreTrainedModel):
  820. def __init__(self, config: VJEPA2Config):
  821. super().__init__(config)
  822. self.num_labels = config.num_labels
  823. self.vjepa2 = VJEPA2Model(config)
  824. # Classifier head
  825. self.pooler = VJEPA2AttentivePooler(config)
  826. self.classifier = nn.Linear(config.hidden_size, config.num_labels, bias=True)
  827. # Initialize weights and apply final processing
  828. self.post_init()
  829. @can_return_tuple
  830. @auto_docstring
  831. def forward(
  832. self,
  833. pixel_values_videos: torch.Tensor,
  834. labels: torch.Tensor | None = None,
  835. **kwargs: Unpack[TransformersKwargs],
  836. ) -> tuple | ImageClassifierOutput:
  837. r"""
  838. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  839. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  840. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  841. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  842. Examples:
  843. ```python
  844. >>> import torch
  845. >>> import numpy as np
  846. >>> from transformers import AutoVideoProcessor, VJEPA2ForVideoClassification
  847. >>> device = "cuda"
  848. >>> video_processor = AutoVideoProcessor.from_pretrained("facebook/vjepa2-vitl-fpc16-256-ssv2")
  849. >>> model = VJEPA2ForVideoClassification.from_pretrained("facebook/vjepa2-vitl-fpc16-256-ssv2").to(device)
  850. >>> video = np.ones((64, 256, 256, 3)) # 64 frames, 256x256 RGB
  851. >>> inputs = video_processor(video, return_tensors="pt").to(device)
  852. >>> # For inference
  853. >>> with torch.no_grad():
  854. ... outputs = model(**inputs)
  855. >>> logits = outputs.logits
  856. >>> predicted_label = logits.argmax(-1).item()
  857. >>> print(model.config.id2label[predicted_label])
  858. >>> # For training
  859. >>> labels = torch.ones(1, dtype=torch.long, device=device)
  860. >>> loss = model(**inputs, labels=labels).loss
  861. ```"""
  862. outputs = self.vjepa2(
  863. pixel_values_videos=pixel_values_videos,
  864. skip_predictor=True,
  865. **kwargs,
  866. )
  867. last_hidden_state = outputs.last_hidden_state
  868. pooler_output = self.pooler(last_hidden_state)
  869. logits = self.classifier(pooler_output)
  870. loss = None
  871. if labels is not None:
  872. loss = self.loss_function(pooled_logits=logits, labels=labels, config=self.config)
  873. return ImageClassifierOutput(
  874. loss=loss,
  875. logits=logits,
  876. hidden_states=outputs.hidden_states,
  877. attentions=outputs.attentions,
  878. )
  879. __all__ = ["VJEPA2Model", "VJEPA2PreTrainedModel", "VJEPA2ForVideoClassification"]