modeling_vivit.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650
  1. # Copyright 2023 Google AI and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch ViViT model."""
  15. from collections.abc import Callable
  16. import torch
  17. from torch import nn
  18. from ... import initialization as init
  19. from ...activations import ACT2FN
  20. from ...modeling_layers import GradientCheckpointingLayer
  21. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
  22. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  23. from ...processing_utils import Unpack
  24. from ...utils import TransformersKwargs, auto_docstring, logging, torch_int
  25. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  26. from ...utils.output_capturing import capture_outputs
  27. from .configuration_vivit import VivitConfig
# Module-level logger, obtained from the library's logging helper and keyed by this module's name.
logger = logging.get_logger(__name__)
  29. class VivitTubeletEmbeddings(nn.Module):
  30. """
  31. Construct Vivit Tubelet embeddings.
  32. This module turns a batch of videos of shape (batch_size, num_frames, num_channels, height, width) into a tensor of
  33. shape (batch_size, seq_len, hidden_size) to be consumed by a Transformer encoder.
  34. The seq_len (the number of patches) equals (number of frames // tubelet_size[0]) * (height // tubelet_size[1]) *
  35. (width // tubelet_size[2]).
  36. """
  37. def __init__(self, config: VivitConfig):
  38. super().__init__()
  39. self.num_frames = config.num_frames
  40. self.image_size = config.image_size
  41. self.patch_size = config.tubelet_size
  42. self.num_patches = (
  43. (self.image_size // self.patch_size[2])
  44. * (self.image_size // self.patch_size[1])
  45. * (self.num_frames // self.patch_size[0])
  46. )
  47. self.embed_dim = config.hidden_size
  48. self.projection = nn.Conv3d(
  49. config.num_channels, config.hidden_size, kernel_size=config.tubelet_size, stride=config.tubelet_size
  50. )
  51. def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
  52. batch_size, num_frames, num_channels, height, width = pixel_values.shape
  53. if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
  54. raise ValueError(
  55. f"Image image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
  56. )
  57. # permute to (batch_size, num_channels, num_frames, height, width)
  58. pixel_values = pixel_values.permute(0, 2, 1, 3, 4)
  59. x = self.projection(pixel_values)
  60. # out_batch_size, out_num_channels, out_num_frames, out_height, out_width = x.shape
  61. # flattens time and space dimensions, transposes to (out_batch_size, flat_tokens, out_num_channels)
  62. x = x.flatten(2).transpose(1, 2)
  63. return x
  64. class VivitEmbeddings(nn.Module):
  65. """
  66. Vivit Embeddings.
  67. Creates embeddings from a video using VivitTubeletEmbeddings, adds CLS token and positional embeddings.
  68. """
  69. def __init__(self, config: VivitConfig):
  70. super().__init__()
  71. self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  72. self.patch_embeddings = VivitTubeletEmbeddings(config)
  73. self.position_embeddings = nn.Parameter(
  74. torch.zeros(1, self.patch_embeddings.num_patches + 1, config.hidden_size)
  75. )
  76. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  77. self.patch_size = config.tubelet_size[1:]
  78. self.config = config
  79. # Adapted from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
  80. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  81. """
  82. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  83. images. This method is also adapted to support torch.jit tracing.
  84. Adapted from:
  85. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  86. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  87. """
  88. num_patches = embeddings.shape[1] - 1
  89. num_positions = self.position_embeddings.shape[1] - 1
  90. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  91. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  92. return self.position_embeddings
  93. class_pos_embed = self.position_embeddings[:, :1]
  94. patch_pos_embed = self.position_embeddings[:, 1:]
  95. dim = embeddings.shape[-1]
  96. new_height = height // self.patch_size[0]
  97. new_width = width // self.patch_size[1]
  98. sqrt_num_positions = torch_int(num_positions**0.5)
  99. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  100. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  101. patch_pos_embed = nn.functional.interpolate(
  102. patch_pos_embed,
  103. size=(new_height, new_width),
  104. mode="bicubic",
  105. align_corners=False,
  106. )
  107. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  108. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  109. def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
  110. batch_size, num_frames, num_channels, height, width = pixel_values.shape
  111. embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  112. cls_tokens = self.cls_token.tile([batch_size, 1, 1])
  113. embeddings = torch.cat((cls_tokens, embeddings), dim=1)
  114. # add positional encoding to each token
  115. if interpolate_pos_encoding:
  116. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  117. else:
  118. embeddings = embeddings + self.position_embeddings
  119. embeddings = self.dropout(embeddings)
  120. return embeddings
  121. # Copied from transformers.models.bert.modeling_bert.eager_attention_forward
  122. def eager_attention_forward(
  123. module: nn.Module,
  124. query: torch.Tensor,
  125. key: torch.Tensor,
  126. value: torch.Tensor,
  127. attention_mask: torch.Tensor | None,
  128. scaling: float | None = None,
  129. dropout: float = 0.0,
  130. **kwargs: Unpack[TransformersKwargs],
  131. ):
  132. if scaling is None:
  133. scaling = query.size(-1) ** -0.5
  134. # Take the dot product between "query" and "key" to get the raw attention scores.
  135. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  136. if attention_mask is not None:
  137. attn_weights = attn_weights + attention_mask
  138. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  139. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  140. attn_output = torch.matmul(attn_weights, value)
  141. attn_output = attn_output.transpose(1, 2).contiguous()
  142. return attn_output, attn_weights
  143. # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Vivit
  144. class VivitSelfAttention(nn.Module):
  145. def __init__(self, config: VivitConfig):
  146. super().__init__()
  147. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  148. raise ValueError(
  149. f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
  150. f"heads {config.num_attention_heads}."
  151. )
  152. self.config = config
  153. self.num_attention_heads = config.num_attention_heads
  154. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  155. self.all_head_size = self.num_attention_heads * self.attention_head_size
  156. self.dropout_prob = config.attention_probs_dropout_prob
  157. self.scaling = self.attention_head_size**-0.5
  158. self.is_causal = False
  159. self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  160. self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  161. self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  162. def forward(
  163. self,
  164. hidden_states: torch.Tensor,
  165. **kwargs: Unpack[TransformersKwargs],
  166. ) -> tuple[torch.Tensor, torch.Tensor]:
  167. batch_size = hidden_states.shape[0]
  168. new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size
  169. key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2)
  170. value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2)
  171. query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2)
  172. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  173. self.config._attn_implementation, eager_attention_forward
  174. )
  175. context_layer, attention_probs = attention_interface(
  176. self,
  177. query_layer,
  178. key_layer,
  179. value_layer,
  180. None,
  181. is_causal=self.is_causal,
  182. scaling=self.scaling,
  183. dropout=0.0 if not self.training else self.dropout_prob,
  184. **kwargs,
  185. )
  186. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  187. context_layer = context_layer.reshape(new_context_layer_shape)
  188. return context_layer, attention_probs
  189. # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Vivit
  190. class VivitSelfOutput(nn.Module):
  191. """
  192. The residual connection is defined in VivitLayer instead of here (as is the case with other models), due to the
  193. layernorm applied before each block.
  194. """
  195. def __init__(self, config: VivitConfig):
  196. super().__init__()
  197. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  198. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  199. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  200. hidden_states = self.dense(hidden_states)
  201. hidden_states = self.dropout(hidden_states)
  202. return hidden_states
  203. # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Vivit
  204. class VivitAttention(nn.Module):
  205. def __init__(self, config: VivitConfig):
  206. super().__init__()
  207. self.attention = VivitSelfAttention(config)
  208. self.output = VivitSelfOutput(config)
  209. def forward(
  210. self,
  211. hidden_states: torch.Tensor,
  212. **kwargs: Unpack[TransformersKwargs],
  213. ) -> torch.Tensor:
  214. self_attn_output, _ = self.attention(hidden_states, **kwargs)
  215. output = self.output(self_attn_output, hidden_states)
  216. return output
  217. class VivitIntermediate(nn.Module):
  218. def __init__(self, config: VivitConfig):
  219. super().__init__()
  220. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  221. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  222. if isinstance(config.hidden_act, str):
  223. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  224. else:
  225. self.intermediate_act_fn = config.hidden_act
  226. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  227. hidden_states = self.dense(hidden_states)
  228. hidden_states = self.intermediate_act_fn(hidden_states)
  229. hidden_states = self.dropout(hidden_states)
  230. return hidden_states
  231. class VivitOutput(nn.Module):
  232. def __init__(self, config: VivitConfig):
  233. super().__init__()
  234. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  235. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  236. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  237. hidden_states = self.dense(hidden_states)
  238. hidden_states = self.dropout(hidden_states)
  239. hidden_states = hidden_states + input_tensor
  240. return hidden_states
  241. class VivitLayer(GradientCheckpointingLayer):
  242. """This corresponds to the EncoderBlock class in the scenic/vivit implementation."""
  243. def __init__(self, config: VivitConfig):
  244. super().__init__()
  245. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  246. self.seq_len_dim = 1
  247. self.attention = VivitAttention(config)
  248. self.intermediate = VivitIntermediate(config)
  249. self.output = VivitOutput(config)
  250. self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  251. self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  252. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  253. hidden_states_norm = self.layernorm_before(hidden_states)
  254. attention_output = self.attention(hidden_states_norm)
  255. # first residual connection
  256. hidden_states = attention_output + hidden_states
  257. # in Vivit, layernorm is also applied after self-attention
  258. layer_output = self.layernorm_after(hidden_states)
  259. layer_output = self.intermediate(layer_output)
  260. # second residual connection is done here
  261. layer_output = self.output(layer_output, hidden_states)
  262. return layer_output
  263. class VivitEncoder(nn.Module):
  264. def __init__(self, config: VivitConfig):
  265. super().__init__()
  266. self.config = config
  267. self.layer = nn.ModuleList([VivitLayer(config) for _ in range(config.num_hidden_layers)])
  268. self.gradient_checkpointing = False
  269. def forward(self, hidden_states: torch.Tensor) -> BaseModelOutput:
  270. for i, layer_module in enumerate(self.layer):
  271. hidden_states = layer_module(hidden_states)
  272. return BaseModelOutput(last_hidden_state=hidden_states)
  273. class VivitPooler(nn.Module):
  274. def __init__(self, config: VivitConfig):
  275. super().__init__()
  276. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  277. self.activation = nn.Tanh()
  278. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  279. # We "pool" the model by simply taking the hidden state corresponding
  280. # to the first token.
  281. first_token_tensor = hidden_states[:, 0]
  282. pooled_output = self.dense(first_token_tensor)
  283. pooled_output = self.activation(pooled_output)
  284. return pooled_output
  285. @auto_docstring
  286. class VivitPreTrainedModel(PreTrainedModel):
  287. config: VivitConfig
  288. base_model_prefix = "vivit"
  289. main_input_name = "pixel_values"
  290. input_modalities = "video"
  291. supports_gradient_checkpointing = True
  292. _no_split_modules = ["VivitLayer"]
  293. _supports_sdpa = True
  294. _supports_flash_attn = True
  295. _supports_flex_attn = True
  296. _supports_attention_backend = True
  297. _can_record_outputs = {
  298. "hidden_states": VivitLayer,
  299. "attentions": VivitSelfAttention,
  300. }
  301. @torch.no_grad()
  302. def _init_weights(self, module):
  303. """Initialize the weights"""
  304. super()._init_weights(module)
  305. if isinstance(module, VivitEmbeddings):
  306. init.zeros_(module.cls_token)
  307. init.zeros_(module.position_embeddings)
  308. @auto_docstring
  309. class VivitModel(VivitPreTrainedModel):
  310. def __init__(self, config: VivitConfig, add_pooling_layer: bool = True):
  311. r"""
  312. add_pooling_layer (bool, *optional*, defaults to `True`):
  313. Whether to add a pooling layer
  314. """
  315. super().__init__(config)
  316. self.config = config
  317. self.embeddings = VivitEmbeddings(config)
  318. self.encoder = VivitEncoder(config)
  319. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  320. self.pooler = VivitPooler(config) if add_pooling_layer else None
  321. # Initialize weights and apply final processing
  322. self.post_init()
  323. def get_input_embeddings(self):
  324. return self.embeddings.patch_embeddings
  325. @merge_with_config_defaults
  326. @capture_outputs(tie_last_hidden_states=False)
  327. @auto_docstring
  328. def forward(
  329. self,
  330. pixel_values: torch.FloatTensor | None = None,
  331. interpolate_pos_encoding: bool = False,
  332. **kwargs: Unpack[TransformersKwargs],
  333. ) -> BaseModelOutputWithPooling:
  334. r"""
  335. Examples:
  336. ```python
  337. >>> import av
  338. >>> import numpy as np
  339. >>> from transformers import VivitImageProcessor, VivitModel
  340. >>> from huggingface_hub import hf_hub_download
  341. >>> np.random.seed(0)
  342. >>> def read_video_pyav(container, indices):
  343. ... '''
  344. ... Decode the video with PyAV decoder.
  345. ... Args:
  346. ... container (`av.container.input.InputContainer`): PyAV container.
  347. ... indices (`list[int]`): List of frame indices to decode.
  348. ... Returns:
  349. ... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
  350. ... '''
  351. ... frames = []
  352. ... container.seek(0)
  353. ... start_index = indices[0]
  354. ... end_index = indices[-1]
  355. ... for i, frame in enumerate(container.decode(video=0)):
  356. ... if i > end_index:
  357. ... break
  358. ... if i >= start_index and i in indices:
  359. ... frames.append(frame)
  360. ... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
  361. >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
  362. ... '''
  363. ... Sample a given number of frame indices from the video.
  364. ... Args:
  365. ... clip_len (`int`): Total number of frames to sample.
  366. ... frame_sample_rate (`int`): Sample every n-th frame.
  367. ... seg_len (`int`): Maximum allowed index of sample's last frame.
  368. ... Returns:
  369. ... indices (`list[int]`): List of sampled frame indices
  370. ... '''
  371. ... converted_len = int(clip_len * frame_sample_rate)
  372. ... end_idx = np.random.randint(converted_len, seg_len)
  373. ... start_idx = end_idx - converted_len
  374. ... indices = np.linspace(start_idx, end_idx, num=clip_len)
  375. ... indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
  376. ... return indices
  377. >>> # video clip consists of 300 frames (10 seconds at 30 FPS)
  378. >>> file_path = hf_hub_download(
  379. ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
  380. ... )
  381. >>> container = av.open(file_path)
  382. >>> # sample 32 frames
  383. >>> indices = sample_frame_indices(clip_len=32, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
  384. >>> video = read_video_pyav(container=container, indices=indices)
  385. >>> image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
  386. >>> model = VivitModel.from_pretrained("google/vivit-b-16x2-kinetics400")
  387. >>> # prepare video for the model
  388. >>> inputs = image_processor(list(video), return_tensors="pt")
  389. >>> # forward pass
  390. >>> outputs = model(**inputs)
  391. >>> last_hidden_states = outputs.last_hidden_state
  392. >>> list(last_hidden_states.shape)
  393. [1, 3137, 768]
  394. ```"""
  395. if pixel_values is None:
  396. raise ValueError("You have to specify pixel_values")
  397. embedding_output = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  398. encoder_outputs: BaseModelOutput = self.encoder(embedding_output)
  399. sequence_output = encoder_outputs.last_hidden_state
  400. sequence_output = self.layernorm(sequence_output)
  401. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  402. return BaseModelOutputWithPooling(last_hidden_state=sequence_output, pooler_output=pooled_output)
  403. @auto_docstring(
  404. custom_intro="""
  405. ViViT Transformer model with a video classification head on top (a linear layer on top of the final hidden state of the
  406. [CLS] token) e.g. for Kinetics-400.
  407. <Tip>
  408. Note that it's possible to fine-tune ViT on higher resolution images than the ones it has been trained on, by
  409. setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
  410. position embeddings to the higher resolution.
  411. </Tip>
  412. """
  413. )
  414. class VivitForVideoClassification(VivitPreTrainedModel):
  415. def __init__(self, config: VivitConfig):
  416. super().__init__(config)
  417. self.num_labels = config.num_labels
  418. self.vivit = VivitModel(config, add_pooling_layer=False)
  419. # Classifier head
  420. self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
  421. # Initialize weights and apply final processing
  422. self.post_init()
  423. @can_return_tuple
  424. @auto_docstring
  425. def forward(
  426. self,
  427. pixel_values: torch.FloatTensor | None = None,
  428. labels: torch.LongTensor | None = None,
  429. interpolate_pos_encoding: bool = False,
  430. **kwargs: Unpack[TransformersKwargs],
  431. ) -> ImageClassifierOutput:
  432. r"""
  433. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  434. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  435. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  436. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  437. Examples:
  438. ```python
  439. >>> import av
  440. >>> import numpy as np
  441. >>> import torch
  442. >>> from transformers import VivitImageProcessor, VivitForVideoClassification
  443. >>> from huggingface_hub import hf_hub_download
  444. >>> np.random.seed(0)
  445. >>> def read_video_pyav(container, indices):
  446. ... '''
  447. ... Decode the video with PyAV decoder.
  448. ... Args:
  449. ... container (`av.container.input.InputContainer`): PyAV container.
  450. ... indices (`list[int]`): List of frame indices to decode.
  451. ... Returns:
  452. ... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
  453. ... '''
  454. ... frames = []
  455. ... container.seek(0)
  456. ... start_index = indices[0]
  457. ... end_index = indices[-1]
  458. ... for i, frame in enumerate(container.decode(video=0)):
  459. ... if i > end_index:
  460. ... break
  461. ... if i >= start_index and i in indices:
  462. ... frames.append(frame)
  463. ... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
  464. >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
  465. ... '''
  466. ... Sample a given number of frame indices from the video.
  467. ... Args:
  468. ... clip_len (`int`): Total number of frames to sample.
  469. ... frame_sample_rate (`int`): Sample every n-th frame.
  470. ... seg_len (`int`): Maximum allowed index of sample's last frame.
  471. ... Returns:
  472. ... indices (`list[int]`): List of sampled frame indices
  473. ... '''
  474. ... converted_len = int(clip_len * frame_sample_rate)
  475. ... end_idx = np.random.randint(converted_len, seg_len)
  476. ... start_idx = end_idx - converted_len
  477. ... indices = np.linspace(start_idx, end_idx, num=clip_len)
  478. ... indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
  479. ... return indices
  480. >>> # video clip consists of 300 frames (10 seconds at 30 FPS)
  481. >>> file_path = hf_hub_download(
  482. ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
  483. ... )
  484. >>> container = av.open(file_path)
  485. >>> # sample 32 frames
  486. >>> indices = sample_frame_indices(clip_len=32, frame_sample_rate=4, seg_len=container.streams.video[0].frames)
  487. >>> video = read_video_pyav(container=container, indices=indices)
  488. >>> image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
  489. >>> model = VivitForVideoClassification.from_pretrained("google/vivit-b-16x2-kinetics400")
  490. >>> inputs = image_processor(list(video), return_tensors="pt")
  491. >>> with torch.no_grad():
  492. ... outputs = model(**inputs)
  493. ... logits = outputs.logits
  494. >>> # model predicts one of the 400 Kinetics-400 classes
  495. >>> predicted_label = logits.argmax(-1).item()
  496. >>> print(model.config.id2label[predicted_label])
  497. LABEL_116
  498. ```"""
  499. outputs: BaseModelOutput = self.vivit(
  500. pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, **kwargs
  501. )
  502. sequence_output = outputs.last_hidden_state
  503. logits = self.classifier(sequence_output[:, 0, :])
  504. loss = None
  505. if labels is not None:
  506. loss = self.loss_function(labels, logits, self.config, **kwargs)
  507. return ImageClassifierOutput(
  508. loss=loss,
  509. logits=logits,
  510. hidden_states=outputs.hidden_states,
  511. attentions=outputs.attentions,
  512. )
# Public API of this module: the names exported by `from ... import *`.
__all__ = ["VivitModel", "VivitPreTrainedModel", "VivitForVideoClassification"]