| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
7037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063 |
- # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from collections.abc import Callable
- from dataclasses import dataclass
- import torch
- from torch import nn
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
- from ...processing_utils import Unpack
- from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging
- from ...utils.generic import merge_with_config_defaults
- from ...utils.output_capturing import OutputRecorder, capture_outputs
- from .configuration_vjepa2 import VJEPA2Config
# Module-level logger from the library's logging utilities, named after this module.
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
    VJEPA Predictor outputs that also contains the masked encoder outputs
    """
)
class VJEPA2WithMaskedInputPredictorOutput(ModelOutput):
    r"""
    masked_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, returned when `context_mask` is provided which is applied on VJEPA2Encoder outputs):
        The masked hidden state of the model.
    target_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, returned when `target_mask` is provided which is applied on VJEPA2Encoder outputs):
        The target hidden state of the model.
    """

    # NOTE(review): the ModelOutput tuple view presumably follows declaration order, so avoid
    # reordering these fields.
    # Main output tensor; presumably the predictor's projected predictions - confirm in VJEPA2Model.forward.
    last_hidden_state: torch.FloatTensor
    # Encoder output restricted to the `context_mask` tokens (see docstring above).
    masked_hidden_state: torch.FloatTensor | None = None
    # Optional tuple of intermediate hidden states.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    # Optional tuple of attention-weight tensors.
    attentions: tuple[torch.FloatTensor, ...] | None = None
    # Encoder output restricted to the `target_mask` tokens (see docstring above).
    target_hidden_state: torch.FloatTensor | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    VJEPA outputs that also contains the masked encoder outputs
    Optionally contains the predictor outputs
    """
)
class VJEPA2WithMaskedInputModelOutput(ModelOutput):
    r"""
    masked_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, returned when `context_mask` is provided which is applied on VJEPA2Encoder outputs):
        Encoder hidden states restricted to the tokens selected by `context_mask`.
    predictor_output (`VJEPA2WithMaskedInputPredictorOutput`, *optional*):
        Everything returned by the predictor module, when it was run.
    """

    last_hidden_state: torch.FloatTensor
    masked_hidden_state: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
    # Declared last on purpose: `to_tuple` relies on it being the final entry when present.
    predictor_output: VJEPA2WithMaskedInputPredictorOutput | None = None

    def to_tuple(self):
        """Return the tuple view; a nested predictor output is itself converted to a tuple."""
        flat = list(super().to_tuple())
        tail = flat[-1]
        if isinstance(tail, VJEPA2WithMaskedInputPredictorOutput):
            flat[-1] = tail.to_tuple()
        return tuple(flat)
class VJEPA2PatchEmbeddings3D(nn.Module):
    """
    Turn a video tensor into a sequence of spatio-temporal patch embeddings.

    A single `nn.Conv3d` whose kernel and stride are both `(tubelet_size, patch_size, patch_size)`
    tiles the clip into non-overlapping tubelets and projects each one to `hidden_size` channels.
    """

    def __init__(
        self,
        config: VJEPA2Config,
        hidden_size: int = 1024,
    ):
        super().__init__()
        self.patch_size = config.patch_size
        self.tubelet_size = config.tubelet_size
        self.hidden_size = hidden_size

        tubelet_shape = (config.tubelet_size, config.patch_size, config.patch_size)
        self.proj = nn.Conv3d(
            in_channels=config.in_chans,
            out_channels=hidden_size,
            kernel_size=tubelet_shape,
            stride=tubelet_shape,
        )

    @staticmethod
    def num_patches(config):
        """Number of tubelets for a `frames_per_clip x crop_size x crop_size` clip."""
        patches_per_side = config.crop_size // config.patch_size
        temporal_patches = config.frames_per_clip // config.tubelet_size
        return temporal_patches * patches_per_side * patches_per_side

    def forward(self, pixel_values_videos: torch.Tensor) -> torch.Tensor:
        # Conv3d yields (batch, hidden, T', H', W'); merge the three grid axes into one token axis
        # and move it in front of the channel axis -> (batch, T' * H' * W', hidden).
        patch_grid = self.proj(pixel_values_videos)
        return patch_grid.flatten(2).transpose(1, 2)
class VJEPA2Embeddings(nn.Module):
    """
    Patch-embed a video clip with `VJEPA2PatchEmbeddings3D`.

    Only patch embeddings are produced here; no mask token or position embedding is created
    (token positions are injected later through rotary embeddings in `VJEPA2RopeAttention`).
    """

    def __init__(self, config: VJEPA2Config, hidden_size: int = 1024):
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size
        self.patch_embeddings = VJEPA2PatchEmbeddings3D(config, hidden_size=hidden_size)
        # Re-export the static patch-count helper of the patch embedder.
        self.num_patches = self.patch_embeddings.num_patches
        self.patch_size = config.patch_size

    def forward(self, pixel_values_videos: torch.Tensor) -> torch.Tensor:
        tubelet_size = self.config.tubelet_size
        frame_count = pixel_values_videos.shape[1]

        # Assumed input layout is (batch, frames, channels, height, width); Conv3d needs
        # (batch, channels, frames, height, width), so swap axes 1 and 2.
        videos = pixel_values_videos.permute(0, 2, 1, 3, 4)

        # With fewer frames than one tubelet the temporal convolution would have nothing to
        # cover, so tile the frame axis `tubelet_size` times.
        if frame_count < tubelet_size:
            videos = videos.repeat(1, 1, tubelet_size, 1, 1)

        # Match the conv weights' dtype (e.g. when the model runs in half precision).
        videos = videos.to(dtype=self.patch_embeddings.proj.weight.dtype)
        return self.patch_embeddings(videos)
- # Adapted from transformers.models.vit.modeling_vit.eager_attention_forward
- def eager_attention_forward(
- module: nn.Module,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attention_mask: torch.Tensor | None,
- scaling: float,
- dropout: float = 0.0,
- **kwargs,
- ):
- # Take the dot product between "query" and "key" to get the raw attention scores.
- attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
- # Normalize the attention scores to probabilities.
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
- # This is actually dropping out entire tokens to attend to, which might
- # seem a bit unusual, but is taken from the original Transformer paper.
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
- attn_output = torch.matmul(attn_weights, value)
- attn_output = attn_output.transpose(1, 2).contiguous()
- return attn_output, attn_weights
def rotate_queries_or_keys(x, pos):
    """
    Apply rotary position embedding along the last axis of `x`.

    Args:
        x: 4-D tensor `(batch, num_heads, seq_len, dim)`; the shape unpacking below rejects other ranks.
        pos: position ids, either `(seq_len,)` or `(batch, num_heads, seq_len)` as produced by
            `VJEPA2RopeAttention.get_position_ids`.

    Returns:
        Tensor shaped like `x` with each channel pair rotated by `pos * inv_freq`.

    NOTE(review): the sin/cos tables are *tiled* (`repeat`) rather than interleaved, while the
    partner vector pairs adjacent channels. This presumably mirrors the reference V-JEPA 2 code the
    checkpoints were trained with - confirm before switching to `repeat_interleave`.
    """
    _, _, _, dim = x.size()

    # Inverse frequencies 1 / 10000 ** (i / (dim / 2)), computed in x's dtype on x's device.
    # (Recomputed on every call, unlike the usual HF pattern of caching `inv_freq`.)
    inv_freq = torch.arange(dim // 2, dtype=x.dtype, device=x.device)
    inv_freq.div_(dim / 2.0)
    inv_freq = 1.0 / 10000**inv_freq

    angles = pos.unsqueeze(-1) * inv_freq  # (..., seq_len, dim / 2): outer product
    sin_table = angles.sin().repeat(1, 1, 1, 2)
    cos_table = angles.cos().repeat(1, 1, 1, 2)

    # Partner vector: every adjacent channel pair (a, b) becomes (-b, a).
    channel_pairs = x.unflatten(-1, (-1, 2))
    first, second = channel_pairs.unbind(dim=-1)
    partner = torch.stack((-second, first), dim=-1).flatten(-2)

    return (x * cos_table) + (partner * sin_table)
class VJEPA2RopeAttention(nn.Module):
    """
    Multi-head self-attention with 3-D (frame / row / column) rotary position embeddings.

    The first `3 * d_dim` channels of every head are split into three equal chunks, rotated by
    the token's frame, row and column index respectively (see `get_position_ids`); any remaining
    channels are passed through unrotated.

    Args:
        config (`VJEPA2Config`): supplies `qkv_bias`, `attention_probs_dropout_prob`, `crop_size`,
            `patch_size`, `frames_per_clip`, `tubelet_size` and `_attn_implementation`.
        hidden_size (`int`): model width; must be divisible by `num_attention_heads`.
        num_attention_heads (`int`): number of attention heads.

    Raises:
        ValueError: if `hidden_size` is not a multiple of `num_attention_heads`.
    """

    def __init__(
        self,
        config: VJEPA2Config,
        hidden_size: int = 1024,
        num_attention_heads: int = 16,
    ):
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"The hidden size {hidden_size} is not a multiple of the number of attention "
                f"heads {num_attention_heads}."
            )
        # Exact integer division: divisibility was checked above.
        self.attention_head_size = hidden_size // num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.proj = nn.Linear(hidden_size, hidden_size)
        self.dropout_prob = config.attention_probs_dropout_prob
        self.dropout = nn.Dropout(self.dropout_prob)

        # Token grid implied by the config: `grid_size` patches per side, `grid_depth` tubelets in time.
        self.grid_size = self.config.crop_size // self.config.patch_size
        self.grid_depth = self.config.frames_per_clip // self.config.tubelet_size

        # Each axis gets the largest even channel count <= head_size / 3 (rotation needs pairs).
        self.d_dim = int(2 * ((self.attention_head_size // 3) // 2))
        self.h_dim = int(2 * ((self.attention_head_size // 3) // 2))
        self.w_dim = int(2 * ((self.attention_head_size // 3) // 2))

        self.scaling = self.attention_head_size**-0.5
        self.is_causal = False

    def _get_frame_pos(self, ids):
        """Frame index of each flat token id: `ids // grid_size**2`."""
        tokens_per_frame = int(self.grid_size * self.grid_size)
        return ids // tokens_per_frame

    def _get_height_pos(self, ids):
        """Row index of each flat token id within its frame."""
        # Remove frame component from ids
        tokens_per_frame = int(self.grid_size * self.grid_size)
        frame_ids = self._get_frame_pos(ids)
        ids = ids - tokens_per_frame * frame_ids
        # --
        tokens_per_row = self.grid_size
        return ids // tokens_per_row

    def get_position_ids(self, x, masks=None):
        """
        Decompose flat token ids into (frame, row, column) indices on a `grid_size x grid_size` grid.

        Args:
            x: hidden states; only `x.size(1)` (token count) and `x.device` are read.
            masks: optional tensor of flat patch ids, one per token (presumably
                `(batch, seq_len)`). When `None`, tokens are taken to be in raster order
                `0 .. seq_len - 1`.

        Returns:
            `(frame_ids, height_ids, width_ids)`, each `(seq_len,)` when `masks` is None,
            otherwise with a head axis inserted at dim 1.
        """
        device = x.device
        token_size = x.size(1)

        # Note: when masks is none, we use a 1d id instead of Bxnum_attention_heads mask,
        # as 1d vector is broadcasted to the correct shapes.
        if masks is not None:
            ids = masks.unsqueeze(1).repeat(1, self.num_attention_heads, 1)
        else:
            ids = torch.arange(token_size, device=device)
        # Frame ids come from integer division, so ids past the configured grid depth still map
        # to (extrapolated) frame indices.
        tokens_per_frame = int(self.grid_size * self.grid_size)
        frame_ids = self._get_frame_pos(ids)
        # --
        tokens_per_row = self.grid_size
        height_ids = self._get_height_pos(ids)
        # --
        # Remove frame component from ids (1st term) and height component (2nd term)
        width_ids = (ids - tokens_per_frame * frame_ids) - tokens_per_row * height_ids
        return frame_ids, height_ids, width_ids

    def apply_rotary_embeddings(self, qk, pos_ids):
        """
        Rotate consecutive channel chunks of `qk` (batch, heads, seq_len, head_size) by the
        frame, row and column ids in `pos_ids`; trailing channels are left untouched.
        """
        d_mask, h_mask, w_mask = pos_ids
        s = 0
        qkd = rotate_queries_or_keys(qk[..., s : s + self.d_dim], pos=d_mask)
        s += self.d_dim
        qkh = rotate_queries_or_keys(qk[..., s : s + self.h_dim], pos=h_mask)
        s += self.h_dim
        qkw = rotate_queries_or_keys(qk[..., s : s + self.w_dim], pos=w_mask)
        s += self.w_dim
        # Combine rotated dimension
        if s < self.attention_head_size:
            qkr = qk[..., s:]
            qk = torch.cat([qkd, qkh, qkw, qkr], dim=-1)
        else:
            qk = torch.cat([qkd, qkh, qkw], dim=-1)
        return qk

    def forward(
        self,
        hidden_states,
        position_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states: `(batch, seq_len, hidden_size)`.
            position_mask: optional flat patch id per token (see `get_position_ids`).

        Returns:
            `(output, attention_probs)`: `output` is `(batch, seq_len, hidden_size)`;
            `attention_probs` is whatever the selected attention backend returns as weights.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.attention_head_size)

        query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
        key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
        value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)

        pos_ids = self.get_position_ids(hidden_states, masks=position_mask)
        key_layer = self.apply_rotary_embeddings(key_layer, pos_ids)
        query_layer = self.apply_rotary_embeddings(query_layer, pos_ids)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        context_layer, attention_probs = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            None,
            is_causal=self.is_causal,
            scaling=self.scaling,
            dropout=0.0 if not self.training else self.dropout_prob,
        )

        # Merge heads back into the feature axis before the output projection.
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = self.proj(context_layer.reshape(new_context_layer_shape))

        return context_layer, attention_probs
# Adapted from transformers.models.beit.modeling_beit.drop_path
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Stochastic depth: zero the entire residual branch for a random subset of samples.

    Each sample along dim 0 is kept with probability `1 - drop_prob`, and kept samples are
    rescaled by `1 / (1 - drop_prob)`. Returns `input` itself when not training or when
    `drop_prob == 0.0`.
    """
    if not training or drop_prob == 0.0:
        return input

    keep_prob = 1 - drop_prob
    # One draw per sample, broadcast over every remaining axis (works for any tensor rank).
    mask_shape = (input.shape[0],) + (1,) * (input.ndim - 1)
    keep_mask = (keep_prob + torch.rand(mask_shape, dtype=input.dtype, device=input.device)).floor_()
    return input.div(keep_prob) * keep_mask
# Adapted from transformers.models.beit.modeling_beit.BeitDropPath
class VJEPA2DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Args:
        drop_prob: probability of dropping a sample's residual branch. `None` (the default) is
            treated as `0.0`, i.e. the module behaves as an identity.
    """

    def __init__(self, drop_prob: float | None = None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # `drop_path` evaluates `1 - drop_prob` in training mode, which raises TypeError for the
        # default `None`; map it to "never drop" instead.
        drop_prob = 0.0 if self.drop_prob is None else self.drop_prob
        return drop_path(hidden_states, drop_prob, self.training)

    def extra_repr(self) -> str:
        return f"p={self.drop_prob}"
class VJEPA2MLP(nn.Module):
    """Feed-forward block: Linear -> activation -> Linear, inner width `int(hidden_size * mlp_ratio)`."""

    def __init__(self, config: VJEPA2Config, hidden_size: int = 1024, mlp_ratio: float = 4.0):
        super().__init__()
        inner_size = int(hidden_size * mlp_ratio)
        self.fc1 = nn.Linear(hidden_size, inner_size, bias=True)
        # Activation is looked up by name from the config.
        self.activation = ACT2FN[config.hidden_act]
        self.fc2 = nn.Linear(inner_size, hidden_size, bias=True)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        expanded = self.activation(self.fc1(hidden_state))
        return self.fc2(expanded)
class VJEPA2Layer(GradientCheckpointingLayer):
    """
    Pre-norm transformer block (RoPE self-attention followed by an MLP), each sub-block wrapped
    in a residual connection with optional stochastic depth.

    This corresponds to the Block class in the original implementation.
    """

    def __init__(
        self,
        config: VJEPA2Config,
        drop_path_rate: float = 0.0,
        hidden_size: int = 1024,
        num_attention_heads: int = 16,
        mlp_ratio: float = 4.0,
    ):
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.mlp_ratio = mlp_ratio

        self.norm1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.attention = VJEPA2RopeAttention(config, hidden_size, num_attention_heads)
        # NOTE: the switch tests the global `config.drop_path_rate`, not this layer's rate, so a
        # layer scheduled at 0.0 still gets a (no-op) VJEPA2DropPath when the global rate is > 0.
        if config.drop_path_rate > 0.0:
            self.drop_path = VJEPA2DropPath(drop_path_rate)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.mlp = VJEPA2MLP(config, hidden_size=hidden_size, mlp_ratio=mlp_ratio)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, ...]:
        """
        Args:
            hidden_states: input tokens.
            position_mask: optional flat patch ids used for RoPE (context/target selection).

        Returns:
            `(hidden_states, attn_weights)`.
        """
        attention_output, attn_weights = self.attention(
            self.norm1(hidden_states),
            position_mask=position_mask,
        )
        hidden_states = self.drop_path(attention_output) + hidden_states

        mlp_output = self.mlp(self.norm2(hidden_states))
        hidden_states = self.drop_path(mlp_output) + hidden_states

        return hidden_states, attn_weights
class VJEPA2Encoder(nn.Module):
    """Video encoder: patch embedding, a stack of `VJEPA2Layer`s, then a final LayerNorm."""

    def __init__(self, config: VJEPA2Config):
        super().__init__()
        self.config = config
        self.embeddings = VJEPA2Embeddings(config, hidden_size=config.hidden_size)

        depth = config.num_hidden_layers
        # Stochastic-depth rate grows linearly from 0 (first layer) to `drop_path_rate` (last layer).
        if depth > 1:
            drop_path_rates = [config.drop_path_rate * i / (depth - 1) for i in range(depth)]
        else:
            drop_path_rates = [0.0] * depth

        self.layer = nn.ModuleList(
            [
                VJEPA2Layer(
                    config,
                    drop_path_rate=rate,
                    hidden_size=config.hidden_size,
                    num_attention_heads=config.num_attention_heads,
                    mlp_ratio=config.mlp_ratio,
                )
                for rate in drop_path_rates
            ]
        )

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.gradient_checkpointing = False

    def forward(
        self,
        pixel_values_videos: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        hidden_states = self.embeddings(pixel_values_videos)
        for block in self.layer:
            # No position mask: encoder tokens are in their natural raster order.
            hidden_states = block(hidden_states, None, **kwargs)[0]
        return BaseModelOutput(last_hidden_state=self.layernorm(hidden_states))
def apply_masks(tensor: torch.Tensor, masks: list[torch.Tensor]) -> torch.Tensor:
    """
    Select patches by index and stack the per-mask selections along the batch axis.

    Args:
        tensor (`torch.Tensor`):
            Tensor of shape [batch_size, num_patches, feature_dim]
        masks (`List[torch.Tensor]`):
            List of index tensors of shape [batch_size, num_kept]; each row lists the patch
            positions to keep. All masks must keep the same number of patches.

    Returns:
        `torch.Tensor` of shape [len(masks) * batch_size, num_kept, feature_dim], the selections
        concatenated in the order of `masks`.
    """
    feature_dim = tensor.size(-1)
    selections = [
        torch.gather(
            tensor,
            dim=1,
            index=mask.to(tensor.device).unsqueeze(-1).repeat(1, 1, feature_dim),
        )
        for mask in masks
    ]
    return torch.cat(selections, dim=0)
class VJEPA2PredictorEmbeddings(nn.Module):
    """
    Build the predictor's input sequence: projected context tokens followed by mask tokens.

    Context tokens (already-masked encoder outputs) are projected from `hidden_size` to
    `pred_hidden_size`, and a learnable mask token is placed at every target position. The flat
    patch index of every token is returned alongside, for the predictor's rotary embeddings.
    """

    def __init__(self, config: VJEPA2Config):
        super().__init__()
        self.config = config
        self.predictor_embeddings = nn.Linear(config.hidden_size, config.pred_hidden_size)
        self.zero_init_mask_tokens = config.pred_zero_init_mask_tokens
        self.num_mask_tokens = config.pred_num_mask_tokens
        # One (1, 1, pred_hidden_size) token per mask index; values are set by `_init_weights`.
        self.mask_tokens = nn.Parameter(torch.zeros(self.num_mask_tokens, 1, 1, config.pred_hidden_size))
        self.patch_size = config.patch_size

    @staticmethod
    def num_patches(config):
        """Patch count for a full clip, or for a single image when `frames_per_clip == 1`."""
        if config.frames_per_clip > 1:
            return (
                (config.frames_per_clip // config.tubelet_size)
                * (config.crop_size // config.patch_size)
                * (config.crop_size // config.patch_size)
            )
        else:
            return (config.crop_size // config.patch_size) * (config.crop_size // config.patch_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        context_mask: list[torch.Tensor],
        target_mask: list[torch.Tensor],
        mask_index: int = 1,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states: encoder outputs for the context tokens.
            context_mask: index tensors giving the patch position of each context token.
            target_mask: index tensors giving the patch positions to predict.
            mask_index: which learnable mask token to use (taken modulo `num_mask_tokens`;
                intended for multi-clip setups).

        Returns:
            `(embeddings, masks)`: context followed by target embeddings, and the matching patch
            indices, both concatenated along the token axis (dim 1).
        """
        B = hidden_states.size(0)
        context = self.predictor_embeddings(hidden_states)

        # Make target tokens
        mask_index = mask_index % self.num_mask_tokens
        target = self.mask_tokens[mask_index]

        # Size the mask-token buffer from the masks instead of `num_patches(config)`, so the
        # predictor can run on more tokens than the config's frames_per_clip implies. Every target
        # mask is gathered from this buffer, so it must cover the largest index over *all* masks.
        max_patch_num = max(int(mask.max()) for mask in target_mask) + 1
        target = target.repeat(B, max_patch_num, 1)
        target = apply_masks(target, target_mask)

        # Concatenate context & target tokens
        context = context.repeat(len(context_mask), 1, 1)
        embeddings = torch.cat([context, target], dim=1)

        # Positions of context & target tokens
        cm = torch.cat(context_mask, dim=0)
        tm = torch.cat(target_mask, dim=0)
        masks = torch.cat([cm, tm], dim=1)

        return embeddings, masks
class VJEPA2Predictor(nn.Module):
    """
    Predict target-token representations from masked context tokens.

    Context and mask tokens are sorted by patch position and run through a stack of
    `VJEPA2Layer`s, whose RoPE is driven by those positions. The sequence is then un-sorted, and
    the target part is projected back to `hidden_size`.
    """

    def __init__(self, config: VJEPA2Config):
        super().__init__()
        self.config = config
        self.gradient_checkpointing = False
        self.embeddings = VJEPA2PredictorEmbeddings(config)

        depth = config.pred_num_hidden_layers
        # Linear stochastic-depth schedule from 0 to `drop_path_rate`.
        if depth > 1:
            drop_path_rates = [config.drop_path_rate * i / (depth - 1) for i in range(depth)]
        else:
            drop_path_rates = [0.0] * depth

        self.layer = nn.ModuleList(
            [
                VJEPA2Layer(
                    config,
                    drop_path_rate=rate,
                    hidden_size=config.pred_hidden_size,
                    num_attention_heads=config.pred_num_attention_heads,
                    mlp_ratio=config.pred_mlp_ratio,
                )
                for rate in drop_path_rates
            ]
        )
        self.layernorm = nn.LayerNorm(config.pred_hidden_size, eps=config.layer_norm_eps)
        self.proj = nn.Linear(config.pred_hidden_size, config.hidden_size, bias=True)

    def sort_tokens(self, hidden_states, position_masks, argsort):
        """Reorder tokens and their position ids along dim 1 according to `argsort`."""
        sorted_positions = torch.gather(position_masks, dim=1, index=argsort.to(position_masks.device))
        token_index = argsort.to(hidden_states.device).unsqueeze(-1).expand(-1, -1, hidden_states.size(-1))
        sorted_states = torch.gather(hidden_states, dim=1, index=token_index)
        return sorted_states, sorted_positions

    def unsort_tokens(self, hidden_states, argsort):
        """Undo `sort_tokens` by gathering with the inverse permutation of `argsort`."""
        inverse = torch.argsort(argsort.to(hidden_states.device), dim=1)
        inverse = inverse.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1))
        return torch.gather(hidden_states, dim=1, index=inverse)

    def forward(
        self,
        encoder_hidden_states: torch.Tensor,
        context_mask: list[torch.Tensor],
        target_mask: list[torch.Tensor],
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        # Masking happens here because V-JEPA training uses a separate encoder for the targets.
        context_states = apply_masks(encoder_hidden_states, context_mask)
        num_context_tokens = context_states.shape[1]

        hidden_states, position_masks = self.embeddings(context_states, context_mask, target_mask)

        # Process tokens in ascending patch-position order.
        order = torch.argsort(position_masks, dim=1)  # [B, N]
        hidden_states, position_masks = self.sort_tokens(hidden_states, position_masks, order)

        for block in self.layer:
            hidden_states = block(hidden_states, position_masks, **kwargs)[0]
        hidden_states = self.layernorm(hidden_states)

        # Restore the [context, target] layout and keep only the predicted target tokens.
        predicted = self.unsort_tokens(hidden_states, order)[:, num_context_tokens:]
        return BaseModelOutput(last_hidden_state=self.proj(predicted))
class VJEPA2PoolerSelfAttention(nn.Module):
    """Standard multi-head self-attention (q/k/v/out projections, no positional encoding) for the attentive pooler."""

    def __init__(self, config: VJEPA2Config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states: `(batch, seq_len, embed_dim)`.
            attention_mask: optional mask forwarded unchanged to the attention backend.

        Returns:
            `(attn_output, attn_weights)` with `attn_output` shaped like `hidden_states`.
        """
        leading_shape = hidden_states.shape[:-1]
        per_head_shape = (*leading_shape, -1, self.head_dim)

        queries, keys, values = (
            projection(hidden_states).view(per_head_shape).transpose(1, 2)
            for projection in (self.q_proj, self.k_proj, self.v_proj)
        )

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            queries,
            keys,
            values,
            attention_mask,
            is_causal=self.is_causal,
            scaling=self.scale,
            dropout=self.dropout if self.training else 0.0,
        )

        merged = attn_output.reshape(*leading_shape, -1).contiguous()
        return self.out_proj(merged), attn_weights
class VJEPA2PoolerCrossAttention(nn.Module):
    """
    Multi-head cross-attention for the attentive pooler.

    Unlike most cross-attention layers it has no output projection (`o_proj`): the merged head
    outputs are returned directly.
    """

    # In case of modular refactoring, the missing o_proj can be replaced with nn.Identity().

    def __init__(self, config: VJEPA2Config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        queries: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            queries: `(batch, q_len, embed_dim)`.
            keys, values: `(batch, kv_len, embed_dim)`.
            attention_mask: optional mask forwarded unchanged to the attention backend.

        Returns:
            `(attn_output, attn_weights)` with `attn_output` of shape `(batch, q_len, embed_dim)`.
        """
        batch_size, q_seq_length, embed_dim = queries.shape
        kv_seq_length = keys.shape[1]

        def split_heads(projected, length):
            # (batch, length, embed_dim) -> (batch, num_heads, length, head_dim)
            return projected.view(batch_size, length, self.num_heads, self.head_dim).transpose(1, 2)

        q_heads = split_heads(self.q_proj(queries), q_seq_length)
        k_heads = split_heads(self.k_proj(keys), kv_seq_length)
        v_heads = split_heads(self.v_proj(values), kv_seq_length)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            q_heads,
            k_heads,
            v_heads,
            attention_mask,
            is_causal=self.is_causal,
            scaling=self.scale,
            dropout=self.dropout if self.training else 0.0,
        )

        merged = attn_output.reshape(batch_size, q_seq_length, embed_dim).contiguous()
        return merged, attn_weights
# Modified from SiglipEncoderLayer, but we have to propagate proper hidden_size to VJEPA2MLP
class VJEPA2PoolerSelfAttentionLayer(GradientCheckpointingLayer):
    """Pre-norm transformer block (self-attention + MLP) used inside the attentive pooler."""

    def __init__(self, config: VJEPA2Config):
        super().__init__()
        width = config.hidden_size
        self.layer_norm1 = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.self_attn = VJEPA2PoolerSelfAttention(config)
        self.layer_norm2 = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.mlp = VJEPA2MLP(config, hidden_size=width)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                Layer input of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`):
                Mask of shape `(batch, 1, q_len, k_v_seq_len)`; padding positions carry very large
                negative values. May be `None`.

        Returns:
            `(hidden_states, attn_weights)`.
        """
        attn_out, attn_weights = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
        )
        hidden_states = hidden_states + attn_out
        hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states))
        return hidden_states, attn_weights
class VJEPA2PoolerCrossAttentionLayer(GradientCheckpointingLayer):
    """
    Cross-attention block of the attentive pooler.

    Queries attend to the layer-normed `hidden_state` (used as both keys and values). The queries
    themselves are not normed, and they form the residual. An MLP with a pre-norm residual follows.
    """

    def __init__(self, config: VJEPA2Config):
        super().__init__()
        width = config.hidden_size
        self.layer_norm1 = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.cross_attn = VJEPA2PoolerCrossAttention(config)
        self.layer_norm2 = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.mlp = VJEPA2MLP(config, hidden_size=width)

    def forward(
        self,
        queries: torch.Tensor,
        hidden_state: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        normed_context = self.layer_norm1(hidden_state)
        attended, *extras = self.cross_attn(
            queries,
            normed_context,
            normed_context,
            attention_mask=attention_mask,
        )
        pooled = queries + attended
        pooled = pooled + self.mlp(self.layer_norm2(pooled))
        # `extras` holds the attention weights returned by the cross-attention module.
        return pooled, *extras
class VJEPA2AttentivePooler(nn.Module):
    """
    Attentive pooler: refine tokens with self-attention layers, then pool them into a single vector
    by cross-attending from one learnable query token.
    """

    def __init__(self, config: VJEPA2Config):
        super().__init__()
        self.query_tokens = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.cross_attention_layer = VJEPA2PoolerCrossAttentionLayer(config)
        self.self_attention_layers = nn.ModuleList(
            [VJEPA2PoolerSelfAttentionLayer(config) for _ in range(config.num_pooler_layers)]
        )

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        for self_attn_layer in self.self_attention_layers:
            hidden_state = self_attn_layer(hidden_state, attention_mask=None)[0]

        batch_size = hidden_state.shape[0]
        queries = self.query_tokens.repeat(batch_size, 1, 1)  # (batch, 1, hidden_size)
        pooled = self.cross_attention_layer(queries, hidden_state)[0]
        # Drop the single-query axis -> (batch, hidden_size).
        return pooled.squeeze(1)
@auto_docstring
class VJEPA2PreTrainedModel(PreTrainedModel):
    # Shared configuration / weight-initialisation base for all V-JEPA 2 models.
    config: VJEPA2Config
    base_model_prefix = "vjepa2"
    main_input_name = "pixel_values_videos"
    input_modalities = "video"
    supports_gradient_checkpointing = True
    _supports_sdpa = True
    _supports_flash_attn = True
    _no_split_modules = [
        "VJEPA2Layer",
        "VJEPA2PoolerSelfAttentionLayer",
        "VJEPA2PoolerCrossAttentionLayer",
        "VJEPA2PredictorEmbeddings",
    ]
    # Hidden states are taken from the encoder layers; attentions are element 1 of each
    # VJEPA2RopeAttention output inside those layers.
    _can_record_outputs = {
        "hidden_states": OutputRecorder(VJEPA2Layer, layer_name="encoder.layer"),
        "attentions": OutputRecorder(VJEPA2RopeAttention, index=1, layer_name="encoder.layer"),
    }

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize the parameters of `module` in place, according to its type."""
        base_std = self.config.initializer_range

        if isinstance(module, VJEPA2AttentivePooler):
            # The query token uses the base std. The residual output projections are depth-scaled:
            # the i-th (1-based) self-attention layer uses std / sqrt(i), and the cross-attention
            # MLP, which follows all of them, uses std / sqrt(num_self_attention_layers + 1).
            init.trunc_normal_(module.query_tokens, std=base_std)
            for depth, pooler_layer in enumerate(module.self_attention_layers, start=1):
                scaled_std = base_std / (depth**0.5)
                init.trunc_normal_(pooler_layer.self_attn.out_proj.weight, std=scaled_std)
                init.trunc_normal_(pooler_layer.mlp.fc2.weight, std=scaled_std)
            final_std = base_std / (len(module.self_attention_layers) + 1) ** 0.5
            init.trunc_normal_(module.cross_attention_layer.mlp.fc2.weight, std=final_std)
        elif isinstance(module, VJEPA2PredictorEmbeddings):
            if not module.zero_init_mask_tokens:
                init.trunc_normal_(module.mask_tokens, std=base_std)
            else:
                init.zeros_(module.mask_tokens)
        elif isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)):
            init.trunc_normal_(module.weight, std=base_std)
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            init.ones_(module.weight)
            init.zeros_(module.bias)
@auto_docstring
class VJEPA2Model(VJEPA2PreTrainedModel):
    # Video encoder plus the JEPA predictor; `skip_predictor=True` runs the encoder only.

    def __init__(self, config: VJEPA2Config):
        super().__init__(config)
        self.config = config
        self.encoder = VJEPA2Encoder(config)
        self.predictor = VJEPA2Predictor(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> VJEPA2PatchEmbeddings3D:
        return self.encoder.embeddings.patch_embeddings

    @merge_with_config_defaults
    @capture_outputs(tie_last_hidden_states=False)
    @auto_docstring
    def forward(
        self,
        pixel_values_videos: torch.Tensor,
        context_mask: list[torch.Tensor] | None = None,
        target_mask: list[torch.Tensor] | None = None,
        skip_predictor: bool = False,
        **kwargs: Unpack[TransformersKwargs],
    ) -> VJEPA2WithMaskedInputModelOutput:
        r"""
        context_mask (`list[torch.Tensor]`, *optional*):
            List of patch-index tensors indicating which encoder output patches are going to be exposed to the
            predictor. If not provided, it defaults to `[torch.arange(N).unsqueeze(0).repeat(B, 1)]` (a single tensor
            of shape `(batch_size, num_patches)`), indicating full context available to the predictor.
        target_mask (`list[torch.Tensor]`, *optional*):
            List of patch-index tensors indicating which encoder output patches are going to be used as a prediction
            target for the predictor. If not provided, it defaults to `[torch.arange(N).unsqueeze(0).repeat(B, 1)]`,
            indicating that the predictor should predict all encoder patches. Each mask falls back to its default
            independently, so either one may be supplied on its own.
        skip_predictor (`bool`, *optional*, defaults to `False`):
            Flag to skip the predictor forward, useful if you just need the encoder outputs.
        """
        if pixel_values_videos is None:
            raise ValueError("You have to specify pixel_values_videos")

        encoder_outputs: BaseModelOutput = self.encoder(
            pixel_values_videos=pixel_values_videos,
            **kwargs,
        )
        sequence_output = encoder_outputs.last_hidden_state

        # Each missing mask independently falls back to "every patch", so callers may pass only one of them.
        if context_mask is None or target_mask is None:
            batch_size = pixel_values_videos.size(0)
            num_patches = sequence_output.size(1)  # ensure we are using dynamic patch size
            all_patches = (
                torch.arange(num_patches, device=pixel_values_videos.device).unsqueeze(0).repeat((batch_size, 1))
            )
            if context_mask is None:
                context_mask = [all_patches]
            if target_mask is None:
                # Separate tensor so the two default masks never alias each other.
                target_mask = [all_patches.clone()]

        if not skip_predictor:
            predictor_outputs: BaseModelOutput = self.predictor(
                encoder_hidden_states=sequence_output,
                context_mask=context_mask,
                target_mask=target_mask,
                **kwargs,
            )
            predictor_output = VJEPA2WithMaskedInputPredictorOutput(
                last_hidden_state=predictor_outputs.last_hidden_state,
                target_hidden_state=apply_masks(sequence_output, target_mask),
                hidden_states=predictor_outputs.hidden_states,
                attentions=predictor_outputs.attentions,
            )
        else:
            predictor_output = None

        encoder_output = VJEPA2WithMaskedInputModelOutput(
            last_hidden_state=sequence_output,
            masked_hidden_state=apply_masks(sequence_output, context_mask),
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            predictor_output=predictor_output,
        )
        return encoder_output

    def get_vision_features(self, pixel_values_videos) -> torch.Tensor:
        """Return the encoder's last hidden state for `pixel_values_videos`, skipping the predictor."""
        encoder_output = self.forward(pixel_values_videos, skip_predictor=True)
        return encoder_output.last_hidden_state
@auto_docstring(
    custom_intro="""
    V-JEPA 2 Model transformer with a video classification head on top (a linear layer on top of the attentive pooler).
    """
)
class VJEPA2ForVideoClassification(VJEPA2PreTrainedModel):
    def __init__(self, config: VJEPA2Config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.vjepa2 = VJEPA2Model(config)

        # Classifier head
        self.pooler = VJEPA2AttentivePooler(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels, bias=True)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values_videos: torch.Tensor,
        labels: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | ImageClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the video classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), if
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Examples:

        ```python
        >>> import torch
        >>> import numpy as np
        >>> from transformers import AutoVideoProcessor, VJEPA2ForVideoClassification

        >>> device = "cuda"

        >>> video_processor = AutoVideoProcessor.from_pretrained("facebook/vjepa2-vitl-fpc16-256-ssv2")
        >>> model = VJEPA2ForVideoClassification.from_pretrained("facebook/vjepa2-vitl-fpc16-256-ssv2").to(device)

        >>> video = np.ones((64, 256, 256, 3))  # 64 frames, 256x256 RGB
        >>> inputs = video_processor(video, return_tensors="pt").to(device)

        >>> # For inference
        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        >>> logits = outputs.logits

        >>> predicted_label = logits.argmax(-1).item()
        >>> print(model.config.id2label[predicted_label])

        >>> # For training
        >>> labels = torch.ones(1, dtype=torch.long, device=device)
        >>> loss = model(**inputs, labels=labels).loss

        ```"""
        # Only the encoder features feed the classifier, so the predictor is skipped.
        outputs = self.vjepa2(
            pixel_values_videos=pixel_values_videos,
            skip_predictor=True,
            **kwargs,
        )

        # Attentive pooling reduces the encoder sequence to one vector per video, then a linear head gives logits.
        last_hidden_state = outputs.last_hidden_state
        pooler_output = self.pooler(last_hidden_state)
        logits = self.classifier(pooler_output)

        loss = None
        if labels is not None:
            loss = self.loss_function(pooled_logits=logits, labels=labels, config=self.config)

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Public API of this module.
__all__ = [
    "VJEPA2Model",
    "VJEPA2PreTrainedModel",
    "VJEPA2ForVideoClassification",
]
|