modeling_tvp.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880
# Copyright 2023 The Intel AIA Team Authors, and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. """PyTorch TVP Model"""
  15. import math
  16. from dataclasses import dataclass
  17. import torch
  18. from torch import nn
  19. from ... import initialization as init
  20. from ...activations import ACT2FN
  21. from ...backbone_utils import load_backbone
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput
  24. from ...modeling_utils import PreTrainedModel
  25. from ...utils import auto_docstring, logging
  26. from .configuration_tvp import TvpConfig
# Module-level logger, named after this module through the project's logging utilities.
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring
class TvpVideoGroundingOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
        Temporal-Distance IoU loss for video grounding.
    logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
        Contains start_time/duration and end_time/duration. It is the time slot of the videos corresponding to the
        input texts.
    attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.
    """

    # Every field defaults to None, so an output can be built with only the entries that were computed.
    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    # NOTE(review): `hidden_states` is not described in the docstring above; presumably `auto_docstring`
    # supplies its standard description — confirm.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
  45. class TvpLoss(nn.Module):
  46. """
  47. This class computes the losses for `TvpForVideoGrounding`. The process happens in two steps: 1) we compute
  48. hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair of matched
  49. ground-truth / prediction (supervise class and box).
  50. Args:
  51. losses (`list[str]`):
  52. List of all the losses to be applied.
  53. """
  54. def __init__(self, losses):
  55. super().__init__()
  56. self.loss_map = {
  57. "iou": self.loss_iou,
  58. "distance": self.loss_distance,
  59. "duration": self.loss_duration,
  60. }
  61. for loss in losses:
  62. if loss not in self.loss_map:
  63. raise ValueError(f"Loss {loss} not supported")
  64. self.losses = losses
  65. def loss_iou(self, start_time, end_time, candidates_start_time, candidates_end_time, duration):
  66. """
  67. Measure the intersection over union.
  68. """
  69. inter = torch.min(candidates_end_time, end_time) - torch.max(candidates_start_time, start_time)
  70. union = torch.max(candidates_end_time, end_time) - torch.min(candidates_start_time, start_time)
  71. iou = 1 - inter.clamp(min=0) / union
  72. return iou
  73. def loss_distance(self, start_time, end_time, candidates_start_time, candidates_end_time, duration):
  74. """
  75. Measure the distance of mid points.
  76. """
  77. mid_candidates = torch.div(torch.add(candidates_start_time, candidates_end_time), 2.0)
  78. mid_groundtruth = torch.div(torch.add(start_time, end_time), 2.0)
  79. distance_diff = torch.div(
  80. torch.max(mid_candidates, mid_groundtruth) - torch.min(mid_candidates, mid_groundtruth), duration
  81. ).clamp(min=0.2)
  82. return distance_diff
  83. def loss_duration(self, start_time, end_time, candidates_start_time, candidates_end_time, duration):
  84. """
  85. Measure the difference of duration.
  86. """
  87. duration_candidates = torch.sub(candidates_end_time, candidates_start_time)
  88. duration_groundtruth = torch.sub(end_time, start_time)
  89. duration_diff = torch.square(torch.div(torch.sub(duration_candidates, duration_groundtruth), duration))
  90. duration_diff = duration_diff.clamp(min=0.4)
  91. return duration_diff
  92. def forward(self, logits, labels):
  93. """
  94. This performs the loss computation.
  95. Args:
  96. logits (`torch.FloatTensor`):
  97. The output logits of head module.
  98. labels (`list[torch.FloatTensor]`):
  99. List of tensors ([start, end, duration]), which contains start time, end time of the video corresponding to the text, and also the duration.
  100. """
  101. duration, start_time, end_time = labels
  102. candidates = torch.mul(logits, duration)
  103. candidates_start_time, candidates_end_time = candidates[:, 0].float(), candidates[:, 1].float()
  104. losses_dict = {}
  105. for loss in self.losses:
  106. losses_dict.update(
  107. {loss: self.loss_map[loss](start_time, end_time, candidates_start_time, candidates_end_time, duration)}
  108. )
  109. return losses_dict
  110. class TvpVisionModel(nn.Module):
  111. def __init__(self, config):
  112. super().__init__()
  113. self.backbone = load_backbone(config)
  114. if config.backbone_config is not None:
  115. in_channels = config.backbone_config.hidden_sizes[-1]
  116. elif hasattr(self.backbone, "config") and hasattr(self.backbone.config, "hidden_sizes"):
  117. in_channels = self.backbone.config.hidden_sizes[-1]
  118. elif hasattr(self.backbone, "config") and hasattr(self.backbone.config, "hidden_size"):
  119. in_channels = self.backbone.config.hidden_size
  120. else:
  121. raise ValueError("Backbone config not found")
  122. self.grid_encoder_conv = nn.Conv2d(
  123. in_channels,
  124. config.hidden_size,
  125. kernel_size=3,
  126. stride=1,
  127. padding=1,
  128. groups=1,
  129. bias=False,
  130. )
  131. def forward(self, pixel_values):
  132. batch_size, num_frames, num_channels, height, width = pixel_values.shape
  133. # (batch_size * num_frames, num_channels, height, width)
  134. pixel_values = pixel_values.view(batch_size * num_frames, num_channels, height, width)
  135. grid_feat_outputs = self.backbone(pixel_values)["feature_maps"][0]
  136. grid = self.grid_encoder_conv(grid_feat_outputs)
  137. grid = nn.functional.max_pool2d(grid, kernel_size=2, stride=2)
  138. grid = nn.functional.relu(grid, inplace=True)
  139. new_channel, new_height, new_width = grid.shape[-3:]
  140. # (batch_size, num_frames, num_channels, height, width)
  141. grid = grid.view(batch_size, num_frames, new_channel, new_height, new_width)
  142. # (batch_size, num_frames, height, width, num_channels)
  143. grid = grid.permute(0, 1, 3, 4, 2)
  144. return grid
  145. class TvpVisualInputEmbedding(nn.Module):
  146. """
  147. Takes input of both image and video (multi-frame)
  148. """
  149. def __init__(self, config):
  150. super().__init__()
  151. # sequence embedding
  152. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  153. self.row_position_embeddings = nn.Embedding(config.max_grid_row_position_embeddings, config.hidden_size)
  154. self.col_position_embeddings = nn.Embedding(config.max_grid_col_position_embeddings, config.hidden_size)
  155. self.token_type_embeddings = nn.Embedding(1, config.hidden_size)
  156. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  157. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  158. self.max_grid_row_position_embeddings = config.max_grid_row_position_embeddings
  159. self.max_grid_col_position_embeddings = config.max_grid_col_position_embeddings
  160. def interpolate_pos_encoding(self, embedding: torch.Tensor, height: int, width: int) -> torch.Tensor:
  161. """
  162. This method allows to interpolate the pre-trained pad weights , to be able to use the model on collection of high
  163. resolution images (high resolution videos).
  164. """
  165. h0 = w0 = 1
  166. # if height dimension is to be interpolated
  167. if height > self.max_grid_row_position_embeddings:
  168. h0 = height / self.max_grid_row_position_embeddings
  169. # if width dimension is to be interpolated
  170. if width > self.max_grid_col_position_embeddings:
  171. w0 = width / self.max_grid_col_position_embeddings
  172. embedding = embedding.permute(0, 3, 1, 2) # (batch_size, hidden_dim, height, width)
  173. embedding = nn.functional.interpolate(
  174. embedding,
  175. scale_factor=(h0, w0),
  176. mode="bicubic",
  177. align_corners=False,
  178. )
  179. embedding = embedding.permute(0, 2, 3, 1) # (batch_size, height, width, hidden_dim)
  180. return embedding
  181. def add_2d_positional_embeddings(self, grid, interpolate_pos_encoding: bool = False):
  182. """
  183. Args:
  184. grid: (batch_size, height, width, hidden_dim)
  185. interpolate_pos_encoding: (`bool`, *optional*, defaults to `False`):
  186. Whether to interpolate the pre-trained position encodings.
  187. Returns:
  188. grid + col_position_embeddings.view(*col_shape): (batch_size, *, height, width, hidden_dim)
  189. """
  190. batch_size, height, width, hidden_dim = grid.shape
  191. # add row-wise position embeddings
  192. # (height, )
  193. row_height = min(self.max_grid_row_position_embeddings, height)
  194. row_position_ids = torch.arange(row_height, dtype=torch.long, device=grid.device)
  195. # (height, hidden_dim)
  196. row_position_embeddings = self.row_position_embeddings(row_position_ids)
  197. row_shape = (1,) * (len(grid.shape) - 3) + (row_height, 1, hidden_dim)
  198. # (batch_size, height, 1, hidden_dim)
  199. row_position_embeddings = row_position_embeddings.view(*row_shape)
  200. # add column-wise position embeddings
  201. row_width = min(self.max_grid_col_position_embeddings, width)
  202. col_position_ids = torch.arange(row_width, dtype=torch.long, device=grid.device)
  203. # (width, hidden_dim)
  204. col_position_embeddings = self.col_position_embeddings(col_position_ids)
  205. col_shape = (batch_size, 1, row_width, hidden_dim)
  206. # (batch_size, 1, width, hidden_dim)
  207. col_position_embeddings = col_position_embeddings.view(*col_shape)
  208. # (batch_size, height, width, hidden_dim)
  209. positional_embeddings = row_position_embeddings + col_position_embeddings
  210. # This interpolation gets triggered ONLY when the input image dim is larger in any dimension than the original position embeddings
  211. if interpolate_pos_encoding and (
  212. height > self.max_grid_row_position_embeddings or width > self.max_grid_col_position_embeddings
  213. ):
  214. grid = grid + self.interpolate_pos_encoding(positional_embeddings, height, width)
  215. else:
  216. grid = grid + positional_embeddings
  217. return grid
  218. def forward(self, grid, interpolate_pos_encoding: bool = False):
  219. """
  220. Args:
  221. grid: Array of shape (batch_size, num_frames, height, width, num_channels).
  222. It contains processed frames extracted from videos, and is generated by Tvp image preprocessor. Note,
  223. num_frames can be 1
  224. interpolate_pos_encoding: (bool, *optional*, defaults to `False`):
  225. Whether to interpolate the pre-trained position encodings.
  226. Returns:
  227. embeddings: The embedding of grid with size (batch_size, height*width, num_channels)
  228. """
  229. batch_size, num_frames, height, width, num_channels = grid.shape
  230. # temporal mean pooling, (batch_size, height, width, hidden_size)
  231. grid = grid.mean(1)
  232. grid = self.add_2d_positional_embeddings(grid, interpolate_pos_encoding=interpolate_pos_encoding)
  233. # image token sequence, (batch_size, height*width, num_channels)
  234. visual_tokens = grid.view(batch_size, -1, num_channels)
  235. visual_tokens_shape = visual_tokens.shape[:-1]
  236. device = visual_tokens.device
  237. # image token type embeddings.
  238. token_type_ids = torch.zeros(visual_tokens_shape, dtype=torch.long, device=device)
  239. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  240. embeddings = visual_tokens + token_type_embeddings
  241. embeddings = self.layer_norm(embeddings)
  242. embeddings = self.dropout(embeddings)
  243. return embeddings
  244. class TvpTextInputEmbeddings(nn.Module):
  245. """Construct the embeddings from word, position and token_type embeddings."""
  246. def __init__(self, config):
  247. super().__init__()
  248. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  249. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  250. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  251. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  252. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  253. def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
  254. if input_ids is not None:
  255. input_shape = input_ids.size()
  256. else:
  257. input_shape = inputs_embeds.size()[:-1]
  258. seq_length = input_shape[1]
  259. device = input_ids.device if input_ids is not None else inputs_embeds.device
  260. if position_ids is None:
  261. position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
  262. position_ids = position_ids.unsqueeze(0).expand(input_shape)
  263. if token_type_ids is None:
  264. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  265. if inputs_embeds is None:
  266. inputs_embeds = self.word_embeddings(input_ids)
  267. position_embeddings = self.position_embeddings(position_ids)
  268. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  269. embeddings = inputs_embeds + position_embeddings + token_type_embeddings
  270. embeddings = self.layer_norm(embeddings)
  271. embeddings = self.dropout(embeddings)
  272. return embeddings
  273. class TvpAttention(nn.Module):
  274. def __init__(self, config):
  275. super().__init__()
  276. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  277. raise ValueError(
  278. f"The hidden size {config.hidden_size} is not a multiple of the number of attention heads {config.num_attention_heads}"
  279. )
  280. self.num_attention_heads = config.num_attention_heads
  281. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  282. self.all_head_size = self.num_attention_heads * self.attention_head_size
  283. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  284. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  285. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  286. self.attn_dropout = nn.Dropout(config.attention_probs_dropout_prob)
  287. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  288. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  289. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  290. def _reshape(self, tensor: torch.Tensor, sequence_length: int, batch_size: int):
  291. return (
  292. tensor.view(batch_size, sequence_length, self.num_attention_heads, self.attention_head_size)
  293. .transpose(1, 2)
  294. .contiguous()
  295. )
  296. def forward(
  297. self,
  298. hidden_states,
  299. attention_mask=None,
  300. output_attentions: bool | None = None,
  301. ):
  302. batch_size, sequence_length = hidden_states.shape[:2]
  303. mixed_query_layer = self.query(hidden_states)
  304. mixed_key_layer = self.key(hidden_states)
  305. mixed_value_layer = self.value(hidden_states)
  306. query_layer = self._reshape(mixed_query_layer, sequence_length, batch_size)
  307. key_layer = self._reshape(mixed_key_layer, sequence_length, batch_size)
  308. value_layer = self._reshape(mixed_value_layer, sequence_length, batch_size)
  309. # Take the dot product between "query" and "key" to get the raw attention scores.
  310. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  311. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  312. if attention_mask is not None:
  313. attention_scores = attention_scores + attention_mask
  314. # Normalize the attention scores to probabilities.
  315. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  316. # This is actually dropping out entire tokens to attend to, which might
  317. # seem a bit unusual, but is taken from the original Transformer paper.
  318. attention_probs = self.attn_dropout(attention_probs)
  319. attn_output = torch.matmul(attention_probs, value_layer)
  320. attn_output = attn_output.transpose(1, 2).contiguous()
  321. attn_output = attn_output.reshape(batch_size, sequence_length, self.all_head_size)
  322. attn_output = self.dense(attn_output)
  323. attn_output = self.dropout(attn_output)
  324. attn_output = self.layer_norm(attn_output + hidden_states)
  325. # add attentions if we output them
  326. outputs = (attn_output, attention_probs) if output_attentions else (attn_output,)
  327. return outputs
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Tvp
class TvpIntermediate(nn.Module):
    """Feed-forward expansion: project `hidden_size` -> `intermediate_size`, then apply the configured activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # `hidden_act` is either a name looked up in ACT2FN or an already-callable activation.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states
  341. class TvpOutputLayer(nn.Module):
  342. def __init__(self, config):
  343. super().__init__()
  344. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  345. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  346. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  347. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  348. hidden_states = self.dense(hidden_states)
  349. hidden_states = self.dropout(hidden_states)
  350. hidden_states = self.layer_norm(hidden_states + input_tensor)
  351. return hidden_states
  352. class TvpEncodeLayer(GradientCheckpointingLayer):
  353. def __init__(self, config):
  354. super().__init__()
  355. self.attention = TvpAttention(config)
  356. self.intermediate = TvpIntermediate(config)
  357. self.output = TvpOutputLayer(config)
  358. def forward(
  359. self,
  360. hidden_states,
  361. attention_mask=None,
  362. output_attentions: bool | None = None,
  363. ):
  364. self_attention_outputs = self.attention(
  365. hidden_states,
  366. attention_mask,
  367. output_attentions=output_attentions,
  368. )
  369. attention_output = self_attention_outputs[0]
  370. outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
  371. intermediate_output = self.intermediate(attention_output)
  372. layer_output = self.output(intermediate_output, attention_output)
  373. outputs = (layer_output,) + outputs
  374. return outputs
  375. class TvpEncoder(nn.Module):
  376. def __init__(self, config):
  377. super().__init__()
  378. self.config = config
  379. self.layer = nn.ModuleList([TvpEncodeLayer(config) for _ in range(config.num_hidden_layers)])
  380. self.gradient_checkpointing = False
  381. def forward(
  382. self,
  383. hidden_states,
  384. attention_mask=None,
  385. output_attentions: bool | None = None,
  386. output_hidden_states: bool | None = None,
  387. return_dict: bool | None = None,
  388. ) -> tuple | BaseModelOutput:
  389. return_dict = return_dict if return_dict is not None else self.config.return_dict
  390. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  391. output_hidden_states = (
  392. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  393. )
  394. all_hidden_states = ()
  395. all_attentions = ()
  396. for i, layer_module in enumerate(self.layer):
  397. if output_hidden_states:
  398. all_hidden_states = all_hidden_states + (hidden_states,)
  399. layer_outputs = layer_module(hidden_states, attention_mask, output_attentions)
  400. hidden_states = layer_outputs[0]
  401. if output_attentions:
  402. all_attentions = all_attentions + (layer_outputs[1],)
  403. # Add last layer
  404. if output_hidden_states:
  405. all_hidden_states = all_hidden_states + (hidden_states,)
  406. if not return_dict:
  407. outputs = (hidden_states,)
  408. if output_hidden_states:
  409. outputs = outputs + (all_hidden_states,)
  410. if output_attentions:
  411. outputs = outputs + (all_attentions,)
  412. return outputs # last-layer hidden state, (all hidden states), (all attentions)
  413. return BaseModelOutput(
  414. last_hidden_state=hidden_states,
  415. hidden_states=all_hidden_states if output_hidden_states else None,
  416. attentions=all_attentions if output_attentions else None,
  417. )
# Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->Tvp
class TvpPooler(nn.Module):
    """Pool a sequence into one vector by applying a dense layer and tanh to the first position's hidden state."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Assumes hidden_states is (batch_size, sequence_length, hidden_size) — `[:, 0]` picks position 0 of dim 1.
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
  431. @auto_docstring
  432. class TvpPreTrainedModel(PreTrainedModel):
  433. config: TvpConfig
  434. base_model_prefix = "model"
  435. input_modalities = ("video", "text")
  436. supports_gradient_checkpointing = True
  437. @torch.no_grad()
  438. def _init_weights(self, module: nn.Module):
  439. """Initialize the weights"""
  440. if isinstance(module, (nn.Linear, nn.Embedding)):
  441. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  442. elif isinstance(module, nn.LayerNorm):
  443. init.zeros_(module.bias)
  444. init.ones_(module.weight)
  445. elif isinstance(module, nn.Conv2d):
  446. init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
  447. if module.bias is not None:
  448. init.constant_(module.bias, 0)
  449. elif isinstance(module, TvpModel):
  450. init.normal_(module.text_prompt)
  451. if isinstance(module, nn.Linear) and module.bias is not None:
  452. init.zeros_(module.bias)
  453. if hasattr(module, "pad_up"):
  454. init.normal_(module.pad_up)
  455. if hasattr(module, "pad_down"):
  456. init.normal_(module.pad_down)
  457. if hasattr(module, "pad_left"):
  458. init.normal_(module.pad_left)
  459. if hasattr(module, "pad_right"):
  460. init.normal_(module.pad_right)
  461. class TvpFrameDownPadPrompter(nn.Module):
  462. """
  463. Pad frames extracted from videos only at the bottom.
  464. """
  465. def __init__(self, config):
  466. if config.visual_prompter_apply not in ("add", "replace", "remove"):
  467. raise ValueError("`visual_prompter_apply` must be in (add, replace, remove)")
  468. super().__init__()
  469. self.visual_prompt_size = config.visual_prompt_size
  470. self.frame_num = config.frame_num
  471. self.max_img_size = config.max_img_size
  472. self.visual_prompter_apply = config.visual_prompter_apply
  473. self.pad_down = nn.Parameter(
  474. torch.randn([1, config.frame_num, 3, config.visual_prompt_size, config.max_img_size])
  475. )
  476. def forward(self, pixel_values):
  477. if self.visual_prompter_apply != "add":
  478. visual_prompt_mask = torch.ones(
  479. [self.max_img_size, self.max_img_size], dtype=pixel_values.dtype, device=pixel_values.device
  480. )
  481. visual_prompt_mask[self.max_img_size - self.visual_prompt_size : self.max_img_size, :] = 0.0
  482. pixel_values *= visual_prompt_mask
  483. if self.visual_prompter_apply != "remove":
  484. prompt = torch.zeros(
  485. [pixel_values.shape[0], pixel_values.shape[1], 3, self.max_img_size, self.max_img_size],
  486. device=pixel_values.device,
  487. )
  488. start_point = self.max_img_size - self.visual_prompt_size
  489. prompt[:, :, :, start_point : self.max_img_size, :] = self.pad_down
  490. pixel_values += prompt.to(pixel_values.dtype)
  491. return pixel_values
  492. class TvpFramePadPrompter(nn.Module):
  493. """
  494. Pad frames extracted from videos in the surroundings.
  495. """
  496. def __init__(self, config):
  497. if config.visual_prompter_apply not in ("add", "replace", "remove"):
  498. raise ValueError("`visual_prompter_apply` must be in (add, replace, remove)")
  499. super().__init__()
  500. self.num_frames = config.num_frames
  501. self.max_img_size = config.max_img_size
  502. self.visual_prompter_apply = config.visual_prompter_apply
  503. self.base_size = config.max_img_size - config.visual_prompt_size * 2
  504. self.pad_up = nn.Parameter(
  505. torch.randn([1, config.num_frames, 3, config.visual_prompt_size, config.max_img_size])
  506. )
  507. self.pad_down = nn.Parameter(
  508. torch.randn([1, config.num_frames, 3, config.visual_prompt_size, config.max_img_size])
  509. )
  510. self.pad_left = nn.Parameter(
  511. torch.randn(
  512. [
  513. 1,
  514. config.num_frames,
  515. 3,
  516. config.max_img_size - config.visual_prompt_size * 2,
  517. config.visual_prompt_size,
  518. ]
  519. )
  520. )
  521. self.pad_right = nn.Parameter(
  522. torch.randn(
  523. [
  524. 1,
  525. config.num_frames,
  526. 3,
  527. config.max_img_size - config.visual_prompt_size * 2,
  528. config.visual_prompt_size,
  529. ]
  530. )
  531. )
  532. def interpolate_pad_encoding(self, prompt: torch.Tensor, height: int, width: int) -> torch.Tensor:
  533. """
  534. This method allows to interpolate the pre-trained pad weights, to be able to use the model on collection of high
  535. resolution images (high resolution videos).
  536. """
  537. # creates scale factor from height and width of original image wrt to the config.max_img_size
  538. h0, w0 = height / self.max_img_size, width / self.max_img_size
  539. batch, num_frames, channels, prompt_height, prompt_width = prompt.shape
  540. # reshaping the batch and num_frames dimension into a single one (i.e (b,frames,c,h,w)-->(b*frames,c,h,w)), to apply bicubic interpolation
  541. prompt = prompt.reshape(batch * num_frames, channels, prompt_height, prompt_width)
  542. prompt = nn.functional.interpolate(
  543. prompt,
  544. scale_factor=(h0, w0),
  545. mode="bicubic",
  546. align_corners=False,
  547. )
  548. # reversing back to (batch,frames,channels,height,width), where height and width is the new interpolated height and width
  549. prompt = prompt.reshape(batch, num_frames, channels, height, width)
  550. return prompt
  551. def forward(self, pixel_values, interpolate_pad_encoding: bool = False):
  552. height, width = (
  553. (pixel_values.shape[-2], pixel_values.shape[-1])
  554. if interpolate_pad_encoding
  555. else (self.max_img_size, self.max_img_size)
  556. )
  557. if self.visual_prompter_apply not in ("add", "remove", "replace"):
  558. raise ValueError(f"Invalid visual_prompter_apply value {self.visual_prompter_apply}")
  559. if self.visual_prompter_apply in ("replace", "remove"):
  560. visual_prompt_mask = torch.ones([height, width], dtype=pixel_values.dtype, device=pixel_values.device)
  561. pixel_values *= visual_prompt_mask
  562. if self.visual_prompter_apply in ("replace", "add"):
  563. base = torch.zeros(1, self.num_frames, 3, self.base_size, self.base_size, device=pixel_values.device)
  564. prompt = torch.cat([self.pad_left, base, self.pad_right], dim=4)
  565. prompt = torch.cat([self.pad_up, prompt, self.pad_down], dim=3)
  566. prompt = torch.cat(pixel_values.size(0) * [prompt])
  567. if interpolate_pad_encoding:
  568. prompt = self.interpolate_pad_encoding(prompt, height, width)
  569. pixel_values = pixel_values + prompt.to(pixel_values.dtype)
  570. return pixel_values
  571. TVP_PROMPTER_CLASSES_MAPPING = {
  572. "framedownpad": TvpFrameDownPadPrompter,
  573. "framepad": TvpFramePadPrompter,
  574. }
  575. @auto_docstring(
  576. custom_intro="""
  577. The bare Tvp Model transformer outputting BaseModelOutputWithPooling object without any specific head on top.
  578. """
  579. )
  580. class TvpModel(TvpPreTrainedModel):
  581. def __init__(self, config):
  582. super().__init__(config)
  583. self.config = config
  584. self.vision_model = TvpVisionModel(config)
  585. self.embeddings = TvpTextInputEmbeddings(config)
  586. self.visual_embeddings = TvpVisualInputEmbedding(config)
  587. self.encoder = TvpEncoder(config)
  588. self.pooler = TvpPooler(config)
  589. self.text_prompt = nn.Parameter(torch.randn([1, 10, config.hidden_size]))
  590. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  591. if config.visual_prompter_type not in TVP_PROMPTER_CLASSES_MAPPING:
  592. raise ValueError("`visual_prompter_type` must be in (framedownpad, framepad)")
  593. self.visual_prompter = TVP_PROMPTER_CLASSES_MAPPING[config.visual_prompter_type](config)
  594. self.post_init()
  595. def get_input_embeddings(self):
  596. return self.embeddings.word_embeddings
  597. def set_input_embeddings(self, value):
  598. self.embeddings.word_embeddings = value
  599. @auto_docstring
  600. def forward(
  601. self,
  602. input_ids: torch.LongTensor | None = None,
  603. pixel_values: torch.FloatTensor | None = None,
  604. attention_mask: torch.LongTensor | None = None,
  605. output_attentions: bool | None = None,
  606. output_hidden_states: bool | None = None,
  607. return_dict: bool | None = None,
  608. interpolate_pos_encoding: bool = False,
  609. **kwargs,
  610. ) -> tuple | BaseModelOutputWithPooling:
  611. r"""
  612. Examples:
  613. ```python
  614. >>> import torch
  615. >>> from transformers import AutoConfig, AutoTokenizer, TvpModel
  616. >>> model = TvpModel.from_pretrained("Jiqing/tiny-random-tvp")
  617. >>> tokenizer = AutoTokenizer.from_pretrained("Jiqing/tiny-random-tvp")
  618. >>> pixel_values = torch.rand(1, 1, 3, 448, 448)
  619. >>> text_inputs = tokenizer("This is an example input", return_tensors="pt")
  620. >>> output = model(text_inputs.input_ids, pixel_values, text_inputs.attention_mask)
  621. ```"""
  622. return_dict = return_dict if return_dict is not None else self.config.return_dict
  623. # Add visual prompt, it compensates for the spatiotemporal information loss in 2D visual features.
  624. pixel_values = self.vision_model(
  625. self.visual_prompter(pixel_values, interpolate_pad_encoding=interpolate_pos_encoding)
  626. )
  627. # (batch_size, sequence_length, hidden_size)
  628. text_embedding_output = self.embeddings(input_ids=input_ids)
  629. # (batch_size, visual_sequence_length, hidden_size)
  630. visual_embedding_output = self.visual_embeddings(
  631. pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
  632. )
  633. if attention_mask is not None:
  634. # (batch_size, visual_sequence_length)
  635. visual_attention_mask = attention_mask.new_ones(visual_embedding_output.shape[:2])
  636. pt_mask = torch.ones(attention_mask.shape[0], 10).to(
  637. device=attention_mask.device, dtype=attention_mask.dtype
  638. )
  639. attention_mask = torch.cat([pt_mask, attention_mask, visual_attention_mask], dim=-1)
  640. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  641. # ourselves in which case we just need to make it broadcastable to all heads.
  642. attention_mask = self.get_extended_attention_mask(attention_mask, input_ids.size()).to(input_ids.device)
  643. text_prompt = self.text_prompt.expand(text_embedding_output.shape[0], -1, -1)
  644. # (batch_size, sequence_length + visual_sequence_length, hidden_size)
  645. embedding_output = torch.cat([text_prompt, text_embedding_output, visual_embedding_output], dim=1)
  646. encoder_outputs = self.encoder(
  647. embedding_output,
  648. attention_mask=attention_mask,
  649. output_attentions=output_attentions,
  650. output_hidden_states=output_hidden_states,
  651. return_dict=return_dict,
  652. )
  653. last_hidden_state = encoder_outputs.last_hidden_state if return_dict else encoder_outputs[0]
  654. pooled_output = self.pooler(last_hidden_state)
  655. last_hidden_state = self.dropout(last_hidden_state)
  656. pooled_output = self.dropout(pooled_output)
  657. if not return_dict:
  658. return (last_hidden_state, pooled_output) + encoder_outputs[1:]
  659. return BaseModelOutputWithPooling(
  660. last_hidden_state=last_hidden_state,
  661. pooler_output=pooled_output,
  662. hidden_states=encoder_outputs.hidden_states,
  663. attentions=encoder_outputs.attentions,
  664. )
  665. class TvpVideoGroundingHead(nn.Module):
  666. def __init__(self, config):
  667. super().__init__()
  668. self.layer_0 = nn.Linear(config.hidden_size, config.hidden_size * 2)
  669. self.layer_1 = nn.Linear(config.hidden_size * 2, 2)
  670. self.activation_0 = nn.ReLU()
  671. self.activation_1 = nn.Sigmoid()
  672. def forward(self, pooler_output):
  673. logits = self.activation_0(self.layer_0(pooler_output))
  674. logits = self.activation_1(self.layer_1(logits))
  675. return logits
  676. @auto_docstring(
  677. custom_intro="""
  678. Tvp Model with a video grounding head on top computing IoU, distance, and duration loss.
  679. """
  680. )
  681. class TvpForVideoGrounding(TvpPreTrainedModel):
  682. def __init__(self, config):
  683. super().__init__(config)
  684. self.config = config
  685. self.model = TvpModel(config)
  686. self.video_grounding_head = TvpVideoGroundingHead(config)
  687. self.post_init()
  688. @auto_docstring
  689. def forward(
  690. self,
  691. input_ids: torch.LongTensor | None = None,
  692. pixel_values: torch.FloatTensor | None = None,
  693. attention_mask: torch.LongTensor | None = None,
  694. labels: tuple[torch.Tensor] | None = None,
  695. output_attentions: bool | None = None,
  696. output_hidden_states: bool | None = None,
  697. return_dict: bool | None = None,
  698. interpolate_pos_encoding: bool = False,
  699. **kwargs,
  700. ) -> tuple | TvpVideoGroundingOutput:
  701. r"""
  702. labels (`torch.FloatTensor` of shape `(batch_size, 3)`, *optional*):
  703. The labels contains duration, start time, and end time of the video corresponding to the text.
  704. Examples:
  705. ```python
  706. >>> import torch
  707. >>> from transformers import AutoConfig, AutoTokenizer, TvpForVideoGrounding
  708. >>> model = TvpForVideoGrounding.from_pretrained("Jiqing/tiny-random-tvp")
  709. >>> tokenizer = AutoTokenizer.from_pretrained("Jiqing/tiny-random-tvp")
  710. >>> pixel_values = torch.rand(1, 1, 3, 448, 448)
  711. >>> text_inputs = tokenizer("This is an example input", return_tensors="pt")
  712. >>> output = model(text_inputs.input_ids, pixel_values, text_inputs.attention_mask)
  713. ```"""
  714. return_dict = return_dict if return_dict is not None else self.config.return_dict
  715. outputs = self.model(
  716. input_ids,
  717. pixel_values,
  718. attention_mask,
  719. output_attentions=output_attentions,
  720. output_hidden_states=output_hidden_states,
  721. return_dict=return_dict,
  722. interpolate_pos_encoding=interpolate_pos_encoding,
  723. )
  724. pooler_output = outputs[1]
  725. logits = self.video_grounding_head(pooler_output)
  726. loss = None
  727. if labels is not None:
  728. criterion = TvpLoss(["iou", "distance", "duration"])
  729. criterion.to(self.device)
  730. loss_dict = criterion(logits, labels)
  731. loss = (
  732. loss_dict["iou"]
  733. + self.config.distance_loss_weight * loss_dict["distance"]
  734. + self.config.duration_loss_weight * loss_dict["duration"]
  735. )
  736. if not return_dict:
  737. outputs = (logits,) + outputs[2:]
  738. if loss is not None:
  739. outputs = (loss,) + outputs
  740. return outputs
  741. return TvpVideoGroundingOutput(
  742. loss=loss,
  743. logits=logits,
  744. hidden_states=outputs.hidden_states,
  745. attentions=outputs.attentions,
  746. )
  747. __all__ = ["TvpModel", "TvpPreTrainedModel", "TvpForVideoGrounding"]