# Copyright 2022 Multimedia Computing Group, Nanjing University and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch VideoMAE (masked autoencoder) model."""

# Standard library
import collections.abc
from collections.abc import Callable
from copy import deepcopy
from dataclasses import dataclass

# Numerical / deep-learning dependencies
import numpy as np
import torch
from torch import nn
from torch.nn import MSELoss

# Package-internal building blocks (relative imports within the library)
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import ModelOutput, TransformersKwargs, auto_docstring, logging
from ...utils.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from ...utils.generic import can_return_tuple, merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
from .configuration_videomae import VideoMAEConfig


# Module-level logger obtained from the library's logging helper.
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
    Class for VideoMAEDecoder's outputs, with potential hidden states and attentions.
    """
)
class VideoMAEDecoderOutput(ModelOutput):
    r"""
    logits (`torch.FloatTensor` of shape `(batch_size, num_returned_patches, tubelet_size * patch_size ** 2 * num_channels)`):
        Pixel reconstruction logits. `VideoMAEDecoder` keeps only the trailing `return_token_num` tokens and projects
        each of them to `num_channels * tubelet_size * patch_size ** 2` values.
    """

    # All fields default to None so an instance can be built with only the values that were computed.
    logits: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Class for VideoMAEForPreTraining's outputs, with potential hidden states and attentions.
    """
)
class VideoMAEForPreTrainingOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`):
        Pixel reconstruction loss.
    logits (`torch.FloatTensor` of shape `(batch_size, num_masked_patches, tubelet_size * patch_size ** 2 * num_channels)`):
        Pixel reconstruction logits, one row per masked patch (the decoder is asked to return exactly the masked
        tokens).
    """

    # All fields default to None so an instance can be built with only the values that were computed.
    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
  65. # sin-cos position encoding
  66. # https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31
  67. def get_sinusoid_encoding_table(n_position, d_hid):
  68. """Sinusoid position encoding table"""
  69. # TODO: make it with torch instead of numpy
  70. def get_position_angle_vec(position):
  71. return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
  72. sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
  73. sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
  74. sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
  75. return torch.FloatTensor(sinusoid_table).unsqueeze(0)
  76. class VideoMAEEmbeddings(nn.Module):
  77. """
  78. Construct the patch and position embeddings.
  79. """
  80. def __init__(self, config):
  81. super().__init__()
  82. self.patch_embeddings = VideoMAEPatchEmbeddings(config)
  83. self.num_patches = self.patch_embeddings.num_patches
  84. # fixed sin-cos embedding
  85. self.position_embeddings = get_sinusoid_encoding_table(self.num_patches, config.hidden_size)
  86. self.config = config
  87. def forward(self, pixel_values, bool_masked_pos):
  88. # create patch embeddings
  89. embeddings = self.patch_embeddings(pixel_values)
  90. # add position embeddings
  91. embeddings = embeddings + self.position_embeddings.detach().type_as(embeddings).to(
  92. device=embeddings.device, copy=True
  93. )
  94. # only keep visible patches
  95. # ~bool_masked_pos means visible
  96. if bool_masked_pos is not None:
  97. batch_size, _, num_channels = embeddings.shape
  98. embeddings = embeddings[~bool_masked_pos]
  99. embeddings = embeddings.reshape(batch_size, -1, num_channels)
  100. return embeddings
  101. class VideoMAEPatchEmbeddings(nn.Module):
  102. """
  103. Video to Patch Embedding. This module turns a batch of videos of shape (batch_size, num_frames, num_channels,
  104. height, width) into a tensor of shape (batch_size, seq_len, hidden_size) to be consumed by a Transformer encoder.
  105. The seq_len (the number of patches) equals (number of frames // tubelet_size) * (height // patch_size) * (width //
  106. patch_size).
  107. """
  108. def __init__(self, config):
  109. super().__init__()
  110. image_size = config.image_size
  111. patch_size = config.patch_size
  112. num_channels = config.num_channels
  113. hidden_size = config.hidden_size
  114. num_frames = config.num_frames
  115. tubelet_size = config.tubelet_size
  116. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  117. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  118. self.image_size = image_size
  119. self.patch_size = patch_size
  120. self.tubelet_size = int(tubelet_size)
  121. num_patches = (
  122. (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) * (num_frames // self.tubelet_size)
  123. )
  124. self.num_channels = num_channels
  125. self.num_patches = num_patches
  126. self.projection = nn.Conv3d(
  127. in_channels=num_channels,
  128. out_channels=hidden_size,
  129. kernel_size=(self.tubelet_size, patch_size[0], patch_size[1]),
  130. stride=(self.tubelet_size, patch_size[0], patch_size[1]),
  131. )
  132. def forward(self, pixel_values):
  133. batch_size, num_frames, num_channels, height, width = pixel_values.shape
  134. if num_channels != self.num_channels:
  135. raise ValueError(
  136. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  137. )
  138. if height != self.image_size[0] or width != self.image_size[1]:
  139. raise ValueError(
  140. f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
  141. )
  142. # permute to (batch_size, num_channels, num_frames, height, width)
  143. pixel_values = pixel_values.permute(0, 2, 1, 3, 4)
  144. embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
  145. return embeddings
  146. # Copied from transformers.models.bert.modeling_bert.eager_attention_forward
  147. def eager_attention_forward(
  148. module: nn.Module,
  149. query: torch.Tensor,
  150. key: torch.Tensor,
  151. value: torch.Tensor,
  152. attention_mask: torch.Tensor | None,
  153. scaling: float | None = None,
  154. dropout: float = 0.0,
  155. **kwargs: Unpack[TransformersKwargs],
  156. ):
  157. if scaling is None:
  158. scaling = query.size(-1) ** -0.5
  159. # Take the dot product between "query" and "key" to get the raw attention scores.
  160. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  161. if attention_mask is not None:
  162. attn_weights = attn_weights + attention_mask
  163. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  164. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  165. attn_output = torch.matmul(attn_weights, value)
  166. attn_output = attn_output.transpose(1, 2).contiguous()
  167. return attn_output, attn_weights
  168. class VideoMAESelfAttention(nn.Module):
  169. def __init__(self, config: VideoMAEConfig) -> None:
  170. super().__init__()
  171. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  172. raise ValueError(
  173. f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
  174. f"heads {config.num_attention_heads}."
  175. )
  176. self.config = config
  177. self.num_attention_heads = config.num_attention_heads
  178. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  179. self.all_head_size = self.num_attention_heads * self.attention_head_size
  180. self.dropout_prob = config.attention_probs_dropout_prob
  181. self.scaling = self.attention_head_size**-0.5
  182. self.is_causal = False
  183. self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  184. self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  185. self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  186. def forward(
  187. self, hidden_states: torch.Tensor | None = None
  188. ) -> tuple[torch.Tensor, torch.Tensor]: # TODO: siglip attention 1-1
  189. input_shape = hidden_states.shape[:-1]
  190. hidden_shape = (*input_shape, -1, self.attention_head_size)
  191. keys = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
  192. values = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
  193. queries = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  194. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  195. self.config._attn_implementation, eager_attention_forward
  196. )
  197. context_layer, attention_probs = attention_interface(
  198. self,
  199. queries,
  200. keys,
  201. values,
  202. None,
  203. is_causal=self.is_causal,
  204. scaling=self.scaling,
  205. dropout=0.0 if not self.training else self.dropout_prob,
  206. )
  207. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  208. context_layer = context_layer.reshape(new_context_layer_shape)
  209. return context_layer, attention_probs
  210. # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->VideoMAE
  211. class VideoMAESelfOutput(nn.Module):
  212. """
  213. The residual connection is defined in VideoMAELayer instead of here (as is the case with other models), due to the
  214. layernorm applied before each block.
  215. """
  216. def __init__(self, config: VideoMAEConfig):
  217. super().__init__()
  218. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  219. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  220. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  221. hidden_states = self.dense(hidden_states)
  222. hidden_states = self.dropout(hidden_states)
  223. return hidden_states
  224. # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->VideoMAE
  225. class VideoMAEAttention(nn.Module):
  226. def __init__(self, config: VideoMAEConfig):
  227. super().__init__()
  228. self.attention = VideoMAESelfAttention(config)
  229. self.output = VideoMAESelfOutput(config)
  230. def forward(
  231. self,
  232. hidden_states: torch.Tensor,
  233. **kwargs: Unpack[TransformersKwargs],
  234. ) -> torch.Tensor:
  235. self_attn_output, _ = self.attention(hidden_states, **kwargs)
  236. output = self.output(self_attn_output, hidden_states)
  237. return output
  238. # Copied from transformers.models.vit.modeling_vit.ViTIntermediate ViT->VideoMAE
  239. class VideoMAEIntermediate(nn.Module):
  240. def __init__(self, config: VideoMAEConfig):
  241. super().__init__()
  242. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  243. if isinstance(config.hidden_act, str):
  244. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  245. else:
  246. self.intermediate_act_fn = config.hidden_act
  247. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  248. hidden_states = self.dense(hidden_states)
  249. hidden_states = self.intermediate_act_fn(hidden_states)
  250. return hidden_states
  251. # Copied from transformers.models.vit.modeling_vit.ViTOutput ViT->VideoMAE
  252. class VideoMAEOutput(nn.Module):
  253. def __init__(self, config: VideoMAEConfig):
  254. super().__init__()
  255. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  256. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  257. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  258. hidden_states = self.dense(hidden_states)
  259. hidden_states = self.dropout(hidden_states)
  260. hidden_states = hidden_states + input_tensor
  261. return hidden_states
  262. # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->VideoMAE,VIT->VIDEOMAE
  263. class VideoMAELayer(GradientCheckpointingLayer):
  264. """This corresponds to the Block class in the timm implementation."""
  265. def __init__(self, config: VideoMAEConfig):
  266. super().__init__()
  267. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  268. self.seq_len_dim = 1
  269. self.attention = VideoMAEAttention(config)
  270. self.intermediate = VideoMAEIntermediate(config)
  271. self.output = VideoMAEOutput(config)
  272. self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  273. self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  274. def forward(
  275. self,
  276. hidden_states: torch.Tensor,
  277. **kwargs: Unpack[TransformersKwargs],
  278. ) -> torch.Tensor:
  279. hidden_states_norm = self.layernorm_before(hidden_states)
  280. attention_output = self.attention(hidden_states_norm, **kwargs)
  281. # first residual connection
  282. hidden_states = attention_output + hidden_states
  283. # in VideoMAE, layernorm is also applied after self-attention
  284. layer_output = self.layernorm_after(hidden_states)
  285. layer_output = self.intermediate(layer_output)
  286. # second residual connection is done here
  287. layer_output = self.output(layer_output, hidden_states)
  288. return layer_output
  289. # Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->VideoMAE
  290. class VideoMAEEncoder(nn.Module):
  291. def __init__(self, config: VideoMAEConfig):
  292. super().__init__()
  293. self.config = config
  294. self.layer = nn.ModuleList([VideoMAELayer(config) for _ in range(config.num_hidden_layers)])
  295. self.gradient_checkpointing = False
  296. def forward(
  297. self,
  298. hidden_states: torch.Tensor,
  299. **kwargs: Unpack[TransformersKwargs],
  300. ) -> BaseModelOutput:
  301. for layer_module in self.layer:
  302. hidden_states = layer_module(hidden_states, **kwargs)
  303. return BaseModelOutput(last_hidden_state=hidden_states)
@auto_docstring
class VideoMAEPreTrainedModel(PreTrainedModel):
    # Shared base class of the VideoMAE models in this file. It only declares class-level
    # settings; NOTE(review): how `PreTrainedModel` consumes each attribute is not visible
    # in this file — confirm against the base class before changing any of them.
    config: VideoMAEConfig
    # Matches the `self.videomae` attribute under which the head models store the backbone.
    base_model_prefix = "videomae"
    # Name of the first `forward` argument of every model defined in this file.
    main_input_name = "pixel_values"
    input_modalities = "video"
    supports_gradient_checkpointing = True
    # presumably the modules that must not be split across devices — TODO confirm
    _no_split_modules = ["VideoMAEEmbeddings", "VideoMAELayer"]
    # Attention-backend capability flags; VideoMAESelfAttention selects its kernel through
    # `ALL_ATTENTION_FUNCTIONS` based on `config._attn_implementation`.
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_flex_attn = True
    _supports_attention_backend = True
    # presumably maps output names to the module classes whose outputs get recorded
    # (see the `capture_outputs` decorator on `VideoMAEModel.forward`) — TODO confirm
    _can_record_outputs = {
        "hidden_states": VideoMAELayer,
        "attentions": VideoMAESelfAttention,
    }
  320. @auto_docstring
  321. class VideoMAEModel(VideoMAEPreTrainedModel):
  322. def __init__(self, config):
  323. super().__init__(config)
  324. self.config = config
  325. self.embeddings = VideoMAEEmbeddings(config)
  326. self.encoder = VideoMAEEncoder(config)
  327. if config.use_mean_pooling:
  328. self.layernorm = None
  329. else:
  330. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  331. # Initialize weights and apply final processing
  332. self.post_init()
  333. def get_input_embeddings(self):
  334. return self.embeddings.patch_embeddings
  335. @merge_with_config_defaults
  336. @capture_outputs(tie_last_hidden_states=False)
  337. @auto_docstring
  338. def forward(
  339. self,
  340. pixel_values: torch.FloatTensor,
  341. bool_masked_pos: torch.BoolTensor | None = None,
  342. **kwargs: Unpack[TransformersKwargs],
  343. ) -> BaseModelOutput:
  344. r"""
  345. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
  346. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Each video in the
  347. batch must have the same number of masked patches. If `None`, then all patches are considered. Sequence
  348. length is `(num_frames // tubelet_size) * (image_size // patch_size) ** 2`.
  349. Examples:
  350. ```python
  351. >>> import torch
  352. >>> from transformers import VideoMAEVideoProcessor, VideoMAEModel
  353. >>> from huggingface_hub import hf_hub_download
  354. >>> # replace this with your own video file
  355. >>> video_path = hf_hub_download(
  356. ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
  357. ... )
  358. >>> video_processor = VideoMAEVideoProcessor.from_pretrained("MCG-NJU/videomae-base")
  359. >>> model = VideoMAEModel.from_pretrained("MCG-NJU/videomae-base")
  360. >>> # prepare video for the model
  361. >>> inputs = video_processor(video_path, return_tensors="pt")
  362. >>> # forward pass
  363. >>> with torch.no_grad():
  364. ... outputs = model(**inputs)
  365. >>> last_hidden_states = outputs.last_hidden_state
  366. >>> list(last_hidden_states.shape)
  367. [1, 1568, 768]
  368. ```"""
  369. embedding_output = self.embeddings(pixel_values, bool_masked_pos)
  370. encoder_outputs: BaseModelOutput = self.encoder(embedding_output)
  371. sequence_output = encoder_outputs.last_hidden_state
  372. if self.layernorm is not None:
  373. sequence_output = self.layernorm(sequence_output)
  374. return BaseModelOutput(last_hidden_state=sequence_output)
  375. class VideoMAEDecoder(nn.Module):
  376. def __init__(self, config: VideoMAEConfig):
  377. super().__init__()
  378. decoder_num_labels = config.num_channels * config.tubelet_size * config.patch_size**2
  379. decoder_config = deepcopy(config)
  380. decoder_config.hidden_size = config.decoder_hidden_size
  381. decoder_config.num_hidden_layers = config.decoder_num_hidden_layers
  382. decoder_config.num_attention_heads = config.decoder_num_attention_heads
  383. decoder_config.intermediate_size = config.decoder_intermediate_size
  384. self.decoder_layers = nn.ModuleList(
  385. [VideoMAELayer(decoder_config) for _ in range(config.decoder_num_hidden_layers)]
  386. )
  387. self.norm = nn.LayerNorm(config.decoder_hidden_size)
  388. self.head = (
  389. nn.Linear(config.decoder_hidden_size, decoder_num_labels) if decoder_num_labels > 0 else nn.Identity()
  390. )
  391. self.gradient_checkpointing = False
  392. self.config = decoder_config
  393. def forward(self, hidden_states: torch.Tensor, return_token_num: int):
  394. # Apply transformer layers
  395. for layer_module in self.decoder_layers:
  396. hidden_states = layer_module(hidden_states)
  397. hidden_states = hidden_states[:, -return_token_num:]
  398. # predictor projection
  399. hidden_states = self.norm(hidden_states)
  400. logits = self.head(hidden_states)
  401. return VideoMAEDecoderOutput(logits=logits)
  402. @auto_docstring(
  403. custom_intro="""
  404. The VideoMAE Model transformer with the decoder on top for self-supervised pre-training.
  405. """
  406. )
  407. class VideoMAEForPreTraining(VideoMAEPreTrainedModel):
  408. def __init__(self, config):
  409. super().__init__(config)
  410. self.config = config
  411. self.videomae = VideoMAEModel(config)
  412. self.encoder_to_decoder = nn.Linear(config.hidden_size, config.decoder_hidden_size, bias=False)
  413. self.mask_token = nn.Parameter(torch.zeros(1, 1, config.decoder_hidden_size))
  414. self.position_embeddings = get_sinusoid_encoding_table(
  415. self.videomae.embeddings.num_patches, config.decoder_hidden_size
  416. )
  417. self.decoder = VideoMAEDecoder(config)
  418. # Initialize weights and apply final processing
  419. self.post_init()
  420. @can_return_tuple
  421. @auto_docstring
  422. def forward(
  423. self,
  424. pixel_values: torch.FloatTensor,
  425. bool_masked_pos: torch.BoolTensor,
  426. **kwargs: Unpack[TransformersKwargs],
  427. ) -> VideoMAEForPreTrainingOutput:
  428. r"""
  429. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
  430. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Each video in the
  431. batch must have the same number of masked patches. Sequence length is `(num_frames // tubelet_size) *
  432. (image_size // patch_size) ** 2`.
  433. Examples:
  434. ```python
  435. >>> from transformers import AutoImageProcessor, VideoMAEForPreTraining
  436. >>> import numpy as np
  437. >>> import torch
  438. >>> num_frames = 16
  439. >>> video = list(np.random.randint(0, 256, (num_frames, 3, 224, 224)))
  440. >>> image_processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base")
  441. >>> model = VideoMAEForPreTraining.from_pretrained("MCG-NJU/videomae-base")
  442. >>> pixel_values = image_processor(video, return_tensors="pt").pixel_values
  443. >>> num_patches_per_frame = (model.config.image_size // model.config.patch_size) ** 2
  444. >>> seq_length = (num_frames // model.config.tubelet_size) * num_patches_per_frame
  445. >>> bool_masked_pos = torch.randint(0, 2, (1, seq_length)).bool()
  446. >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
  447. >>> loss = outputs.loss
  448. ```"""
  449. outputs: BaseModelOutput = self.videomae(pixel_values, bool_masked_pos=bool_masked_pos, **kwargs)
  450. sequence_output = outputs.last_hidden_state
  451. sequence_output = self.encoder_to_decoder(sequence_output)
  452. # [batch_size, num_visible_patches, decoder_hidden_size]
  453. batch_size, _, num_channels = sequence_output.shape
  454. # we don't unshuffle the correct visible token order, but shuffle the position embeddings accordingly.
  455. if bool_masked_pos is None:
  456. raise ValueError("One must provided a boolean mask ")
  457. expanded_position_embeddings = self.position_embeddings.expand(batch_size, -1, -1).type_as(pixel_values)
  458. expanded_position_embeddings = expanded_position_embeddings.detach().to(device=pixel_values.device, copy=True)
  459. pos_emb_visible = expanded_position_embeddings[~bool_masked_pos].reshape(batch_size, -1, num_channels)
  460. pos_emb_mask = expanded_position_embeddings[bool_masked_pos].reshape(batch_size, -1, num_channels)
  461. # [batch_size, num_patches, decoder_hidden_size]
  462. x_full = torch.cat([sequence_output + pos_emb_visible, self.mask_token + pos_emb_mask], dim=1)
  463. # [batch_size, num_masked_patches, num_channels * patch_size * patch_size]
  464. decoder_outputs: VideoMAEDecoderOutput = self.decoder(x_full, pos_emb_mask.shape[1])
  465. logits = decoder_outputs.logits
  466. loss = None
  467. with torch.no_grad():
  468. # calculate the labels to be predicted
  469. if self.config.num_channels != 3:
  470. # Can't unnormalize with default means/stds
  471. frames = pixel_values
  472. else:
  473. # first, unnormalize the frames
  474. device = pixel_values.device
  475. dtype = pixel_values.dtype
  476. mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
  477. std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
  478. frames = pixel_values * std + mean # in [0, 1]
  479. batch_size, time, num_channels, height, width = frames.shape
  480. tubelet_size, patch_size = self.config.tubelet_size, self.config.patch_size
  481. if self.config.norm_pix_loss:
  482. # step 1: split up dimensions (time by tubelet_size, height by patch_size, width by patch_size)
  483. frames = frames.view(
  484. batch_size,
  485. time // tubelet_size,
  486. tubelet_size,
  487. num_channels,
  488. height // patch_size,
  489. patch_size,
  490. width // patch_size,
  491. patch_size,
  492. )
  493. # step 2: move dimensions to concatenate:
  494. frames = frames.permute(0, 1, 4, 6, 2, 5, 7, 3).contiguous()
  495. # step 3: concatenate:
  496. frames = frames.view(
  497. batch_size,
  498. time // tubelet_size * height // patch_size * width // patch_size,
  499. tubelet_size * patch_size * patch_size,
  500. num_channels,
  501. )
  502. # step 4: normalize. The authors find that the mean is about 0.48 and standard deviation is about 0.08.
  503. frames_norm = (frames - frames.mean(dim=-2, keepdim=True)) / (
  504. frames.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6
  505. )
  506. # step 5: reshape to (batch_size, T//ts * H//ps * W//ps, ts * ps * ps * C)
  507. videos_patch = frames_norm.view(
  508. batch_size,
  509. time // tubelet_size * height // patch_size * width // patch_size,
  510. tubelet_size * patch_size * patch_size * num_channels,
  511. )
  512. else:
  513. if self.config.num_channels != 3:
  514. raise ValueError(
  515. "Can't unnormalize non-RGB images. Consider setting config.norm_pix_loss to False."
  516. )
  517. # step 1: split up dimensions (time by tubelet_size, height by patch_size, width by patch_size)
  518. frames = frames.view(
  519. batch_size,
  520. time // tubelet_size,
  521. tubelet_size,
  522. num_channels,
  523. height // patch_size,
  524. patch_size,
  525. width // patch_size,
  526. patch_size,
  527. )
  528. # step 2: move dimensions to concatenate: (batch_size, T//ts, H//ps, W//ps, ts, ps, ps, C)
  529. frames = frames.permute(0, 1, 4, 6, 2, 5, 7, 3).contiguous()
  530. # step 3: concatenate
  531. videos_patch = frames.view(
  532. batch_size,
  533. time // tubelet_size * height // patch_size * width // patch_size,
  534. tubelet_size * patch_size * patch_size * num_channels,
  535. )
  536. batch_size, _, num_channels = videos_patch.shape
  537. labels = videos_patch[bool_masked_pos].reshape(batch_size, -1, num_channels)
  538. loss_fct = MSELoss()
  539. loss = loss_fct(logits, labels)
  540. return VideoMAEForPreTrainingOutput(
  541. loss=loss,
  542. logits=logits,
  543. hidden_states=outputs.hidden_states,
  544. attentions=outputs.attentions,
  545. )
  546. @auto_docstring(
  547. custom_intro="""
  548. VideoMAE Model transformer with a video classification head on top (a linear layer on top of the average pooled hidden
  549. states of all tokens) e.g. for ImageNet.
  550. """
  551. )
  552. class VideoMAEForVideoClassification(VideoMAEPreTrainedModel):
  553. def __init__(self, config):
  554. super().__init__(config)
  555. self.num_labels = config.num_labels
  556. self.videomae = VideoMAEModel(config)
  557. # Classifier head
  558. self.fc_norm = nn.LayerNorm(config.hidden_size) if config.use_mean_pooling else None
  559. self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
  560. # Initialize weights and apply final processing
  561. self.post_init()
  562. @can_return_tuple
  563. @auto_docstring
  564. def forward(
  565. self,
  566. pixel_values: torch.Tensor | None = None,
  567. labels: torch.Tensor | None = None,
  568. **kwargs: Unpack[TransformersKwargs],
  569. ) -> ImageClassifierOutput:
  570. r"""
  571. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  572. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  573. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  574. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  575. Examples:
  576. ```python
  577. >>> import torch
  578. >>> from transformers import VideoMAEVideoProcessor, VideoMAEForVideoClassification
  579. >>> from huggingface_hub import hf_hub_download
  580. >>> # replace this with your own video file
  581. >>> video_path = hf_hub_download(
  582. ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
  583. ... )
  584. >>> video_processor = VideoMAEVideoProcessor.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics")
  585. >>> model = VideoMAEForVideoClassification.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics")
  586. >>> inputs = video_processor(video_path, return_tensors="pt")
  587. >>> with torch.no_grad():
  588. ... outputs = model(**inputs)
  589. ... logits = outputs.logits
  590. >>> # model predicts one of the 400 Kinetics-400 classes
  591. >>> predicted_label = logits.argmax(-1).item()
  592. >>> print(model.config.id2label[predicted_label])
  593. eating spaghetti
  594. ```"""
  595. outputs: BaseModelOutput = self.videomae(pixel_values, **kwargs)
  596. sequence_output = outputs.last_hidden_state
  597. if self.fc_norm is not None:
  598. output = sequence_output.mean(1)
  599. output = self.fc_norm(output)
  600. else:
  601. output = sequence_output[:, 0]
  602. logits = self.classifier(output)
  603. loss = None
  604. if labels is not None:
  605. loss = self.loss_function(labels, logits, self.config, **kwargs)
  606. return ImageClassifierOutput(
  607. loss=loss,
  608. logits=logits,
  609. hidden_states=outputs.hidden_states,
  610. attentions=outputs.attentions,
  611. )
# Names exported by `from <this module> import *`: the three model classes and their shared base class.
__all__ = ["VideoMAEForPreTraining", "VideoMAEModel", "VideoMAEPreTrainedModel", "VideoMAEForVideoClassification"]