modeling_videomt.py 53 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/videomt/modular_videomt.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_videomt.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2026 the HuggingFace Team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. import collections.abc
  21. import math
  22. from collections.abc import Callable
  23. from dataclasses import dataclass
  24. import numpy as np
  25. import torch
  26. import torch.nn.functional as F
  27. from torch import Tensor, nn
  28. from ... import initialization as init
  29. from ...activations import ACT2FN
  30. from ...file_utils import ModelOutput, is_scipy_available, requires_backends
  31. from ...modeling_layers import GradientCheckpointingLayer
  32. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  33. from ...processing_utils import Unpack
  34. from ...utils import TransformersKwargs, auto_docstring, is_accelerate_available
  35. from ...utils.generic import merge_with_config_defaults
  36. from ...utils.output_capturing import capture_outputs
  37. from .configuration_videomt import VideomtConfig
  38. if is_scipy_available():
  39. from scipy.optimize import linear_sum_assignment
  40. if is_accelerate_available():
  41. from accelerate import PartialState
  42. from accelerate.utils import reduce
  43. class VideomtPatchEmbeddings(nn.Module):
  44. """
  45. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  46. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  47. Transformer.
  48. """
  49. def __init__(self, config):
  50. super().__init__()
  51. image_size, patch_size = config.image_size, config.patch_size
  52. num_channels, hidden_size = config.num_channels, config.hidden_size
  53. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  54. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  55. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  56. self.image_size = image_size
  57. self.patch_size = patch_size
  58. self.num_channels = num_channels
  59. self.num_patches = num_patches
  60. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
  61. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  62. num_channels = pixel_values.shape[1]
  63. if num_channels != self.num_channels:
  64. raise ValueError(
  65. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  66. f" Expected {self.num_channels} but got {num_channels}."
  67. )
  68. pixel_values = pixel_values.to(dtype=self.projection.weight.dtype)
  69. embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
  70. return embeddings
  71. class VideomtEmbeddings(nn.Module):
  72. """
  73. Construct the CLS token, mask token, position and patch embeddings.
  74. """
  75. def __init__(self, config: VideomtConfig) -> None:
  76. super().__init__()
  77. self.config = config
  78. self.patch_size = config.patch_size
  79. self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
  80. self.register_tokens = nn.Parameter(torch.zeros(1, config.num_register_tokens, config.hidden_size))
  81. self.patch_embeddings = VideomtPatchEmbeddings(config)
  82. num_patches = self.patch_embeddings.num_patches
  83. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  84. self.num_prefix_tokens = 1 + config.num_register_tokens # 1 for [CLS]
  85. self.position_embeddings = nn.Embedding(num_patches, config.hidden_size)
  86. self.register_buffer("position_ids", torch.arange(num_patches).expand((1, -1)), persistent=False)
  87. self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  88. def forward(self, pixel_values: torch.Tensor, bool_masked_pos: torch.Tensor | None = None) -> torch.Tensor:
  89. if pixel_values.ndim == 5:
  90. batch_size, num_frames, num_channels, height, width = pixel_values.shape
  91. pixel_values = pixel_values.reshape(batch_size * num_frames, num_channels, height, width)
  92. if bool_masked_pos is not None:
  93. bool_masked_pos = bool_masked_pos.reshape(batch_size * num_frames, -1)
  94. elif bool_masked_pos is not None and bool_masked_pos.ndim > 2:
  95. bool_masked_pos = bool_masked_pos.reshape(bool_masked_pos.shape[0], -1)
  96. batch_size = pixel_values.shape[0]
  97. embeddings = self.patch_embeddings(pixel_values)
  98. if bool_masked_pos is not None:
  99. mask = bool_masked_pos.to(device=embeddings.device, dtype=torch.bool).unsqueeze(-1)
  100. embeddings = torch.where(mask, self.mask_token, embeddings)
  101. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  102. register_tokens = self.register_tokens.expand(batch_size, -1, -1)
  103. embeddings = embeddings + self.position_embeddings(self.position_ids)
  104. embeddings = torch.cat([cls_tokens, register_tokens, embeddings], dim=1)
  105. embeddings = self.dropout(embeddings)
  106. return embeddings
  107. class VideomtMLP(nn.Module):
  108. def __init__(self, config) -> None:
  109. super().__init__()
  110. in_features = out_features = config.hidden_size
  111. hidden_features = int(config.hidden_size * config.mlp_ratio)
  112. self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
  113. if isinstance(config.hidden_act, str):
  114. self.activation = ACT2FN[config.hidden_act]
  115. else:
  116. self.activation = config.hidden_act
  117. self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
  118. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  119. hidden_state = self.fc1(hidden_state)
  120. hidden_state = self.activation(hidden_state)
  121. hidden_state = self.fc2(hidden_state)
  122. return hidden_state
  123. class VideomtGatedMLP(nn.Module):
  124. def __init__(self, config) -> None:
  125. super().__init__()
  126. in_features = out_features = config.hidden_size
  127. hidden_features = int(config.hidden_size * config.mlp_ratio)
  128. hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
  129. self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True)
  130. self.weights_out = nn.Linear(hidden_features, out_features, bias=True)
  131. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  132. hidden_state = self.weights_in(hidden_state)
  133. x1, x2 = hidden_state.chunk(2, dim=-1)
  134. hidden = nn.functional.silu(x1) * x2
  135. return self.weights_out(hidden)
  136. def eager_attention_forward(
  137. module: nn.Module,
  138. query: torch.Tensor,
  139. key: torch.Tensor,
  140. value: torch.Tensor,
  141. attention_mask: torch.Tensor | None,
  142. scaling: float,
  143. dropout: float = 0.0,
  144. **kwargs,
  145. ):
  146. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  147. if attention_mask is not None:
  148. attn_weights = attn_weights + attention_mask
  149. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  150. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  151. attn_output = torch.matmul(attn_weights, value)
  152. attn_output = attn_output.transpose(1, 2).contiguous()
  153. return attn_output, attn_weights
  154. class VideomtAttention(nn.Module):
  155. """Multi-headed attention from 'Attention Is All You Need' paper"""
  156. def __init__(self, config):
  157. super().__init__()
  158. self.config = config
  159. self.embed_dim = config.hidden_size
  160. self.num_heads = config.num_attention_heads
  161. self.head_dim = self.embed_dim // self.num_heads
  162. if self.head_dim * self.num_heads != self.embed_dim:
  163. raise ValueError(
  164. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  165. f" {self.num_heads})."
  166. )
  167. self.scale = self.head_dim**-0.5
  168. self.dropout = config.attention_dropout
  169. self.is_causal = False
  170. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  171. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  172. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  173. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  174. def forward(
  175. self,
  176. hidden_states: torch.Tensor,
  177. attention_mask: torch.Tensor | None = None,
  178. **kwargs,
  179. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  180. """Input shape: Batch x Time x Channel"""
  181. input_shape = hidden_states.shape[:-1]
  182. hidden_shape = (*input_shape, -1, self.head_dim)
  183. queries = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  184. keys = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  185. values = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  186. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  187. self.config._attn_implementation, eager_attention_forward
  188. )
  189. attn_output, attn_weights = attention_interface(
  190. self,
  191. queries,
  192. keys,
  193. values,
  194. attention_mask,
  195. is_causal=self.is_causal,
  196. scaling=self.scale,
  197. dropout=0.0 if not self.training else self.dropout,
  198. )
  199. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  200. attn_output = self.out_proj(attn_output)
  201. return attn_output, attn_weights
  202. def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
  203. """
  204. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  205. """
  206. if drop_prob == 0.0 or not training:
  207. return input
  208. keep_prob = 1 - drop_prob
  209. shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  210. random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
  211. random_tensor.floor_() # binarize
  212. output = input.div(keep_prob) * random_tensor
  213. return output
  214. class VideomtDropPath(nn.Module):
  215. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  216. def __init__(self, drop_prob: float | None = None) -> None:
  217. super().__init__()
  218. self.drop_prob = drop_prob
  219. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  220. return drop_path(hidden_states, self.drop_prob, self.training)
  221. def extra_repr(self) -> str:
  222. return f"p={self.drop_prob}"
  223. class VideomtSwiGLUFFN(nn.Module):
  224. def __init__(self, config) -> None:
  225. super().__init__()
  226. in_features = out_features = config.hidden_size
  227. hidden_features = int(config.hidden_size * config.mlp_ratio)
  228. hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
  229. self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True)
  230. self.weights_out = nn.Linear(hidden_features, out_features, bias=True)
  231. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  232. hidden_state = self.weights_in(hidden_state)
  233. x1, x2 = hidden_state.chunk(2, dim=-1)
  234. hidden = nn.functional.silu(x1) * x2
  235. return self.weights_out(hidden)
  236. class VideomtLayer(GradientCheckpointingLayer):
  237. """This corresponds to the Block class in the original implementation."""
  238. def __init__(self, config: VideomtConfig) -> None:
  239. super().__init__()
  240. self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  241. self.attention = VideomtAttention(config)
  242. self.layer_scale1 = VideomtLayerScale(config)
  243. self.drop_path = VideomtDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
  244. self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  245. if config.use_swiglu_ffn:
  246. self.mlp = VideomtSwiGLUFFN(config)
  247. else:
  248. self.mlp = VideomtMLP(config)
  249. self.layer_scale2 = VideomtLayerScale(config)
  250. def forward(
  251. self,
  252. hidden_states: torch.Tensor,
  253. attention_mask: torch.Tensor | None = None,
  254. ) -> torch.Tensor:
  255. hidden_states_norm = self.norm1(hidden_states)
  256. self_attention_output, _ = self.attention(hidden_states_norm, attention_mask)
  257. self_attention_output = self.layer_scale1(self_attention_output)
  258. # first residual connection
  259. hidden_states = self.drop_path(self_attention_output) + hidden_states
  260. # in Videomt, layernorm is also applied after self-attention
  261. layer_output = self.norm2(hidden_states)
  262. layer_output = self.mlp(layer_output)
  263. layer_output = self.layer_scale2(layer_output)
  264. # second residual connection
  265. layer_output = self.drop_path(layer_output) + hidden_states
  266. return layer_output
  267. class VideomtLayerScale(nn.Module):
  268. def __init__(self, config) -> None:
  269. super().__init__()
  270. self.lambda1 = nn.Parameter(config.layerscale_value * torch.ones(config.hidden_size))
  271. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  272. return hidden_state * self.lambda1
# NOTE(review): the `custom_intro` below links to this output class itself; presumably it is meant to name the model
# class that returns this output — confirm before touching the runtime string.
@dataclass
@auto_docstring(
    custom_intro="""
    Class for outputs of [`VideomtForUniversalSegmentationOutput`].
    This output can be directly passed to [`~VideomtVideoProcessor.post_process_semantic_segmentation`] or
    [`~VideomtVideoProcessor.post_process_instance_segmentation`] or
    [`~VideomtVideoProcessor.post_process_panoptic_segmentation`] to compute final segmentation maps. Please, see
    [`~VideomtVideoProcessor`] for details regarding usage.
    """
)
class VideomtForUniversalSegmentationOutput(ModelOutput):
    r"""
    loss (`torch.Tensor`, *optional*):
        The computed loss, returned when labels are present.
    class_queries_logits (`torch.FloatTensor`):
        A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each
        query. Note the `+ 1` is needed because we incorporate the null class.
    masks_queries_logits (`torch.FloatTensor`):
        A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each
        query.
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
        Last hidden states (final feature map) of the last layer.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of all layers of the model.
    attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`. Self and Cross Attentions weights from transformer decoder.
    """

    # Every field defaults to `None`, so an output can be built with only the entries that were computed.
    loss: torch.FloatTensor | None = None
    class_queries_logits: torch.FloatTensor | None = None
    masks_queries_logits: torch.FloatTensor | None = None
    last_hidden_state: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
  308. # Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py
  309. def sample_point(
  310. input_features: torch.Tensor, point_coordinates: torch.Tensor, add_dim=False, **kwargs
  311. ) -> torch.Tensor:
  312. """
  313. A wrapper around `torch.nn.functional.grid_sample` to support 3D point_coordinates tensors.
  314. Args:
  315. input_features (`torch.Tensor` of shape (batch_size, channels, height, width)):
  316. A tensor that contains features map on a height * width grid
  317. point_coordinates (`torch.Tensor` of shape (batch_size, num_points, 2) or (batch_size, grid_height, grid_width,:
  318. 2)):
  319. A tensor that contains [0, 1] * [0, 1] normalized point coordinates
  320. add_dim (`bool`):
  321. boolean value to keep track of added dimension
  322. Returns:
  323. point_features (`torch.Tensor` of shape (batch_size, channels, num_points) or (batch_size, channels,
  324. height_grid, width_grid):
  325. A tensor that contains features for points in `point_coordinates`.
  326. """
  327. if point_coordinates.dim() == 3:
  328. add_dim = True
  329. point_coordinates = point_coordinates.unsqueeze(2)
  330. # use nn.function.grid_sample to get features for points in `point_coordinates` via bilinear interpolation
  331. point_features = torch.nn.functional.grid_sample(input_features, 2.0 * point_coordinates - 1.0, **kwargs)
  332. if add_dim:
  333. point_features = point_features.squeeze(3)
  334. return point_features
  335. def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor:
  336. """
  337. A pair wise version of the dice loss, see `dice_loss` for usage.
  338. Args:
  339. inputs (`torch.Tensor`):
  340. A tensor representing a mask
  341. labels (`torch.Tensor`):
  342. A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
  343. (0 for the negative class and 1 for the positive class).
  344. Returns:
  345. `torch.Tensor`: The computed loss between each pairs.
  346. """
  347. inputs = inputs.sigmoid().flatten(1)
  348. numerator = 2 * torch.matmul(inputs, labels.T)
  349. # using broadcasting to get a [num_queries, NUM_CLASSES] matrix
  350. denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :]
  351. loss = 1 - (numerator + 1) / (denominator + 1)
  352. return loss
  353. def pair_wise_sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
  354. r"""
  355. A pair wise version of the cross entropy loss, see `sigmoid_cross_entropy_loss` for usage.
  356. Args:
  357. inputs (`torch.Tensor`):
  358. A tensor representing a mask.
  359. labels (`torch.Tensor`):
  360. A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
  361. (0 for the negative class and 1 for the positive class).
  362. Returns:
  363. loss (`torch.Tensor`): The computed loss between each pairs.
  364. """
  365. height_and_width = inputs.shape[1]
  366. criterion = nn.BCEWithLogitsLoss(reduction="none")
  367. cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs))
  368. cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs))
  369. loss_pos = torch.matmul(cross_entropy_loss_pos / height_and_width, labels.T)
  370. loss_neg = torch.matmul(cross_entropy_loss_neg / height_and_width, (1 - labels).T)
  371. loss = loss_pos + loss_neg
  372. return loss
  373. # Adapted from https://github.com/facebookresearch/Videomt/blob/main/videomt/modeling/matcher.py
  374. class VideomtHungarianMatcher(nn.Module):
  375. """This class computes an assignment between the labels and the predictions of the network.
  376. For efficiency reasons, the labels don't include the no_object. Because of this, in general, there are more
  377. predictions than labels. In this case, we do a 1-to-1 matching of the best predictions, while the others are
  378. un-matched (and thus treated as non-objects).
  379. """
  380. def __init__(
  381. self, cost_class: float = 1.0, cost_mask: float = 1.0, cost_dice: float = 1.0, num_points: int = 12544
  382. ):
  383. """Creates the matcher
  384. Params:
  385. cost_class (`float`, *optional*, defaults to 1.0):
  386. Relative weight of the classification error in the matching cost.
  387. cost_mask (`float`, *optional*, defaults to 1.0):
  388. This is the relative weight of the focal loss of the binary mask in the matching cost.
  389. cost_dice (`float`, *optional*, defaults to 1.0):
  390. This is the relative weight of the dice loss of the binary mask in the matching cost.
  391. num_points (`int`, *optional*, defaults to 12544):
  392. No. of points to sample on which the mask loss will be calculated. The same set of K points are
  393. uniformly sampled for all prediction and ground truth masks to construct the cost matrix for bipartite
  394. matching.
  395. """
  396. super().__init__()
  397. if cost_class == 0 and cost_mask == 0 and cost_dice == 0:
  398. raise ValueError("All costs can't be 0")
  399. self.num_points = num_points
  400. self.cost_class = cost_class
  401. self.cost_mask = cost_mask
  402. self.cost_dice = cost_dice
  403. @torch.no_grad()
  404. def forward(
  405. self,
  406. masks_queries_logits: torch.Tensor,
  407. class_queries_logits: torch.Tensor,
  408. mask_labels: torch.Tensor,
  409. class_labels: torch.Tensor,
  410. ) -> list[tuple[Tensor]]:
  411. """
  412. Params:
  413. masks_queries_logits (`torch.Tensor`):
  414. A tensor of dim `batch_size, num_queries, num_labels` with the classification logits.
  415. class_queries_logits (`torch.Tensor`):
  416. A tensor of dim `batch_size, num_queries, height, width` with the predicted masks.
  417. class_labels (`torch.Tensor`):
  418. A tensor of dim `num_target_boxes` (where num_target_boxes is the number of ground-truth objects in the
  419. target) containing the class labels.
  420. mask_labels (`torch.Tensor`):
  421. A tensor of dim `num_target_boxes, height, width` containing the target masks.
  422. Returns:
  423. matched_indices (`list[tuple[Tensor]]`): A list of size batch_size, containing tuples of (index_i, index_j)
  424. where:
  425. - index_i is the indices of the selected predictions (in order)
  426. - index_j is the indices of the corresponding selected labels (in order)
  427. For each batch element, it holds:
  428. len(index_i) = len(index_j) = min(num_queries, num_target_boxes).
  429. """
  430. indices: list[tuple[np.array]] = []
  431. # iterate through batch size
  432. batch_size = masks_queries_logits.shape[0]
  433. for i in range(batch_size):
  434. pred_probs = class_queries_logits[i].softmax(-1)
  435. pred_mask = masks_queries_logits[i]
  436. # Compute the classification cost. Contrary to the loss, we don't use the NLL, but approximate it in 1 - proba[target class]. The 1 is a constant that doesn't change the matching, it can be omitted.
  437. cost_class = -pred_probs[:, class_labels[i]]
  438. target_mask = mask_labels[i].to(pred_mask)
  439. target_mask = target_mask[:, None]
  440. pred_mask = pred_mask[:, None]
  441. # Sample ground truth and predicted masks
  442. point_coordinates = torch.rand(1, self.num_points, 2, device=pred_mask.device)
  443. target_coordinates = point_coordinates.repeat(target_mask.shape[0], 1, 1)
  444. target_mask = sample_point(target_mask, target_coordinates, align_corners=False).squeeze(1)
  445. pred_coordinates = point_coordinates.repeat(pred_mask.shape[0], 1, 1)
  446. pred_mask = sample_point(pred_mask, pred_coordinates, align_corners=False).squeeze(1)
  447. # compute the cross entropy loss between each mask pairs -> shape (num_queries, num_labels)
  448. cost_mask = pair_wise_sigmoid_cross_entropy_loss(pred_mask, target_mask)
  449. # Compute the dice loss between each mask pairs -> shape (num_queries, num_labels)
  450. cost_dice = pair_wise_dice_loss(pred_mask, target_mask)
  451. # final cost matrix
  452. cost_matrix = self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice
  453. # eliminate infinite values in cost_matrix to avoid the error ``ValueError: cost matrix is infeasible``
  454. cost_matrix = torch.minimum(cost_matrix, torch.tensor(1e10))
  455. cost_matrix = torch.maximum(cost_matrix, torch.tensor(-1e10))
  456. cost_matrix = torch.nan_to_num(cost_matrix, 0)
  457. # do the assignment using the hungarian algorithm in scipy
  458. assigned_indices: tuple[np.array] = linear_sum_assignment(cost_matrix.cpu())
  459. indices.append(assigned_indices)
  460. # It could be stacked in one tensor
  461. matched_indices = [
  462. (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices
  463. ]
  464. return matched_indices
  465. def dice_loss(inputs: Tensor, labels: Tensor, num_masks: int) -> Tensor:
  466. r"""
  467. Compute the DICE loss, similar to generalized IOU for masks as follows:
  468. $$ \mathcal{L}_{\text{dice}(x, y) = 1 - \frac{2 * x \cap y }{x \cup y + 1}} $$
  469. In practice, since `labels` is a binary mask, (only 0s and 1s), dice can be computed as follow
  470. $$ \mathcal{L}_{\text{dice}(x, y) = 1 - \frac{2 * x * y }{x + y + 1}} $$
  471. Args:
  472. inputs (`torch.Tensor`):
  473. A tensor representing a mask.
  474. labels (`torch.Tensor`):
  475. A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
  476. (0 for the negative class and 1 for the positive class).
  477. num_masks (`int`):
  478. The number of masks present in the current batch, used for normalization.
  479. Returns:
  480. `torch.Tensor`: The computed loss.
  481. """
  482. probs = inputs.sigmoid().flatten(1)
  483. numerator = 2 * (probs * labels).sum(-1)
  484. denominator = probs.sum(-1) + labels.sum(-1)
  485. loss = 1 - (numerator + 1) / (denominator + 1)
  486. loss = loss.sum() / num_masks
  487. return loss
  488. def sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor, num_masks: int) -> torch.Tensor:
  489. r"""
  490. Args:
  491. inputs (`torch.Tensor`):
  492. A float tensor of arbitrary shape.
  493. labels (`torch.Tensor`):
  494. A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
  495. (0 for the negative class and 1 for the positive class).
  496. Returns:
  497. loss (`torch.Tensor`): The computed loss.
  498. """
  499. criterion = nn.BCEWithLogitsLoss(reduction="none")
  500. cross_entropy_loss = criterion(inputs, labels)
  501. loss = cross_entropy_loss.mean(1).sum() / num_masks
  502. return loss
  503. # Adapted from https://github.com/facebookresearch/Videomt/blob/main/videomt/modeling/criterion.py
  504. class VideomtLoss(nn.Module):
  505. def __init__(self, config: VideomtConfig, weight_dict: dict[str, float]):
  506. """
  507. The Videomt Loss. The loss is computed very similar to DETR. The process happens in two steps: 1) we
  508. compute hungarian assignment between ground truth masks and the outputs of the model 2) we supervise each pair
  509. of matched ground-truth / prediction (supervise class and mask)
  510. Args:
  511. config (`VideomtConfig`):
  512. The configuration for Videomt model also containing loss calculation specific parameters.
  513. weight_dict (`dict[str, float]`):
  514. A dictionary of weights to be applied to the different losses.
  515. """
  516. super().__init__()
  517. requires_backends(self, ["scipy"])
  518. self.num_labels = config.num_labels
  519. self.weight_dict = weight_dict
  520. # Weight to apply to the null class
  521. self.eos_coef = config.no_object_weight
  522. empty_weight = torch.ones(self.num_labels + 1)
  523. empty_weight[-1] = self.eos_coef
  524. self.register_buffer("empty_weight", empty_weight)
  525. # pointwise mask loss parameters
  526. self.num_points = config.train_num_points
  527. self.oversample_ratio = config.oversample_ratio
  528. self.importance_sample_ratio = config.importance_sample_ratio
  529. self.matcher = VideomtHungarianMatcher(
  530. cost_class=config.class_weight,
  531. cost_dice=config.dice_weight,
  532. cost_mask=config.mask_weight,
  533. num_points=self.num_points,
  534. )
  535. def _max_by_axis(self, sizes: list[list[int]]) -> list[int]:
  536. maxes = sizes[0]
  537. for sublist in sizes[1:]:
  538. for index, item in enumerate(sublist):
  539. maxes[index] = max(maxes[index], item)
  540. return maxes
  541. # Adapted from nested_tensor_from_tensor_list() in original implementation
  542. def _pad_images_to_max_in_batch(self, tensors: list[Tensor]) -> tuple[Tensor, Tensor]:
  543. # get the maximum size in the batch
  544. max_size = self._max_by_axis([list(tensor.shape) for tensor in tensors])
  545. # compute final size
  546. batch_shape = [len(tensors)] + max_size
  547. batch_size, _, height, width = batch_shape
  548. dtype = tensors[0].dtype
  549. device = tensors[0].device
  550. padded_tensors = torch.zeros(batch_shape, dtype=dtype, device=device)
  551. padding_masks = torch.ones((batch_size, height, width), dtype=torch.bool, device=device)
  552. # pad the tensors to the size of the biggest one
  553. for tensor, padded_tensor, padding_mask in zip(tensors, padded_tensors, padding_masks):
  554. padded_tensor[: tensor.shape[0], : tensor.shape[1], : tensor.shape[2]].copy_(tensor)
  555. padding_mask[: tensor.shape[1], : tensor.shape[2]] = False
  556. return padded_tensors, padding_masks
  557. def loss_labels(
  558. self, class_queries_logits: Tensor, class_labels: list[Tensor], indices: tuple[np.array]
  559. ) -> dict[str, Tensor]:
  560. """Compute the losses related to the labels using cross entropy.
  561. Args:
  562. class_queries_logits (`torch.Tensor`):
  563. A tensor of shape `batch_size, num_queries, num_labels`
  564. class_labels (`list[torch.Tensor]`):
  565. List of class labels of shape `(labels)`.
  566. indices (`tuple[np.array])`:
  567. The indices computed by the Hungarian matcher.
  568. Returns:
  569. `dict[str, Tensor]`: A dict of `torch.Tensor` containing the following key:
  570. - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels.
  571. """
  572. pred_logits = class_queries_logits
  573. batch_size, num_queries, _ = pred_logits.shape
  574. criterion = nn.CrossEntropyLoss(weight=self.empty_weight)
  575. idx = self._get_predictions_permutation_indices(indices) # shape of (batch_size, num_queries)
  576. target_classes_o = torch.cat(
  577. [target[j] for target, (_, j) in zip(class_labels, indices)]
  578. ) # shape of (batch_size, num_queries)
  579. target_classes = torch.full(
  580. (batch_size, num_queries), fill_value=self.num_labels, dtype=torch.int64, device=pred_logits.device
  581. )
  582. target_classes[idx] = target_classes_o
  583. # Permute target_classes (batch_size, num_queries, num_labels) -> (batch_size, num_labels, num_queries)
  584. pred_logits_transposed = pred_logits.transpose(1, 2)
  585. loss_ce = criterion(pred_logits_transposed, target_classes)
  586. losses = {"loss_cross_entropy": loss_ce}
  587. return losses
  588. def loss_masks(
  589. self,
  590. masks_queries_logits: torch.Tensor,
  591. mask_labels: list[torch.Tensor],
  592. indices: tuple[np.array],
  593. num_masks: int,
  594. ) -> dict[str, torch.Tensor]:
  595. """Compute the losses related to the masks using sigmoid_cross_entropy_loss and dice loss.
  596. Args:
  597. masks_queries_logits (`torch.Tensor`):
  598. A tensor of shape `(batch_size, num_queries, height, width)`.
  599. mask_labels (`torch.Tensor`):
  600. List of mask labels of shape `(labels, height, width)`.
  601. indices (`tuple[np.array])`:
  602. The indices computed by the Hungarian matcher.
  603. num_masks (`int)`:
  604. The number of masks, used for normalization.
  605. Returns:
  606. losses (`dict[str, Tensor]`): A dict of `torch.Tensor` containing two keys:
  607. - **loss_mask** -- The loss computed using sigmoid cross entropy loss on the predicted and ground truth.
  608. masks.
  609. - **loss_dice** -- The loss computed using dice loss on the predicted on the predicted and ground truth,
  610. masks.
  611. """
  612. src_idx = self._get_predictions_permutation_indices(indices)
  613. tgt_idx = self._get_targets_permutation_indices(indices)
  614. # shape (batch_size * num_queries, height, width)
  615. pred_masks = masks_queries_logits[src_idx]
  616. # shape (batch_size, num_queries, height, width)
  617. # pad all and stack the targets to the num_labels dimension
  618. target_masks, _ = self._pad_images_to_max_in_batch(mask_labels)
  619. target_masks = target_masks[tgt_idx]
  620. # No need to upsample predictions as we are using normalized coordinates
  621. pred_masks = pred_masks[:, None]
  622. target_masks = target_masks[:, None]
  623. # Sample point coordinates
  624. with torch.no_grad():
  625. point_coordinates = self.sample_points_using_uncertainty(
  626. pred_masks,
  627. lambda logits: self.calculate_uncertainty(logits),
  628. self.num_points,
  629. self.oversample_ratio,
  630. self.importance_sample_ratio,
  631. )
  632. point_labels = sample_point(target_masks, point_coordinates, align_corners=False).squeeze(1)
  633. point_logits = sample_point(pred_masks, point_coordinates, align_corners=False).squeeze(1)
  634. losses = {
  635. "loss_mask": sigmoid_cross_entropy_loss(point_logits, point_labels, num_masks),
  636. "loss_dice": dice_loss(point_logits, point_labels, num_masks),
  637. }
  638. del pred_masks
  639. del target_masks
  640. return losses
  641. def _get_predictions_permutation_indices(self, indices):
  642. # Permute predictions following indices
  643. batch_indices = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
  644. predictions_indices = torch.cat([src for (src, _) in indices])
  645. return batch_indices, predictions_indices
  646. def _get_targets_permutation_indices(self, indices):
  647. # Permute labels following indices
  648. batch_indices = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
  649. target_indices = torch.cat([tgt for (_, tgt) in indices])
  650. return batch_indices, target_indices
  651. def calculate_uncertainty(self, logits: torch.Tensor) -> torch.Tensor:
  652. """
  653. In Videomt paper, uncertainty is estimated as L1 distance between 0.0 and the logit prediction in 'logits'
  654. for the foreground class in `classes`.
  655. Args:
  656. logits (`torch.Tensor`):
  657. A tensor of shape (R, 1, ...) for class-specific or class-agnostic, where R is the total number of predicted masks in all images and C is:
  658. the number of foreground classes. The values are logits.
  659. Returns:
  660. scores (`torch.Tensor`): A tensor of shape (R, 1, ...) that contains uncertainty scores with the most
  661. uncertain locations having the highest uncertainty score.
  662. """
  663. uncertainty_scores = -(torch.abs(logits))
  664. return uncertainty_scores
  665. def sample_points_using_uncertainty(
  666. self,
  667. logits: torch.Tensor,
  668. uncertainty_function,
  669. num_points: int,
  670. oversample_ratio: int,
  671. importance_sample_ratio: float,
  672. ) -> torch.Tensor:
  673. """
  674. This function is meant for sampling points in [0, 1] * [0, 1] coordinate space based on their uncertainty. The
  675. uncertainty is calculated for each point using the passed `uncertainty function` that takes points logit
  676. prediction as input.
  677. Args:
  678. logits (`float`):
  679. Logit predictions for P points.
  680. uncertainty_function:
  681. A function that takes logit predictions for P points and returns their uncertainties.
  682. num_points (`int`):
  683. The number of points P to sample.
  684. oversample_ratio (`int`):
  685. Oversampling parameter.
  686. importance_sample_ratio (`float`):
  687. Ratio of points that are sampled via importance sampling.
  688. Returns:
  689. point_coordinates (`torch.Tensor`):
  690. Coordinates for P sampled points.
  691. """
  692. num_boxes = logits.shape[0]
  693. num_points_sampled = int(num_points * oversample_ratio)
  694. # Get random point coordinates
  695. point_coordinates = torch.rand(num_boxes, num_points_sampled, 2, device=logits.device)
  696. # Get sampled prediction value for the point coordinates
  697. point_logits = sample_point(logits, point_coordinates, align_corners=False)
  698. # Calculate the uncertainties based on the sampled prediction values of the points
  699. point_uncertainties = uncertainty_function(point_logits)
  700. num_uncertain_points = int(importance_sample_ratio * num_points)
  701. num_random_points = num_points - num_uncertain_points
  702. idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
  703. shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device)
  704. idx += shift[:, None]
  705. point_coordinates = point_coordinates.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2)
  706. if num_random_points > 0:
  707. point_coordinates = torch.cat(
  708. [point_coordinates, torch.rand(num_boxes, num_random_points, 2, device=logits.device)],
  709. dim=1,
  710. )
  711. return point_coordinates
  712. def forward(
  713. self,
  714. masks_queries_logits: torch.Tensor,
  715. class_queries_logits: torch.Tensor,
  716. mask_labels: list[torch.Tensor],
  717. class_labels: list[torch.Tensor],
  718. auxiliary_predictions: dict[str, torch.Tensor] | None = None,
  719. ) -> dict[str, torch.Tensor]:
  720. """
  721. This performs the loss computation.
  722. Args:
  723. masks_queries_logits (`torch.Tensor`):
  724. A tensor of shape `(batch_size, num_queries, height, width)`.
  725. class_queries_logits (`torch.Tensor`):
  726. A tensor of shape `(batch_size, num_queries, num_labels)`.
  727. mask_labels (`torch.Tensor`):
  728. List of mask labels of shape `(labels, height, width)`.
  729. class_labels (`list[torch.Tensor]`):
  730. List of class labels of shape `(labels)`.
  731. auxiliary_predictions (`dict[str, torch.Tensor]`, *optional*):
  732. if `use_auxiliary_loss` was set to `true` in [`VideomtConfig`], then it contains the logits from
  733. the inner layers of the VideomtMaskedAttentionDecoder.
  734. Returns:
  735. losses (`dict[str, Tensor]`): A dict of `torch.Tensor` containing three keys:
  736. - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels.
  737. - **loss_mask** -- The loss computed using sigmoid cross_entropy loss on the predicted and ground truth
  738. masks.
  739. - **loss_dice** -- The loss computed using dice loss on the predicted on the predicted and ground truth
  740. masks.
  741. if `use_auxiliary_loss` was set to `true` in [`VideomtConfig`], the dictionary contains additional
  742. losses for each auxiliary predictions.
  743. """
  744. # retrieve the matching between the outputs of the last layer and the labels
  745. indices = self.matcher(masks_queries_logits, class_queries_logits, mask_labels, class_labels)
  746. # compute the average number of target masks for normalization purposes
  747. num_masks = self.get_num_masks(class_labels, device=class_labels[0].device)
  748. # get all the losses
  749. losses: dict[str, Tensor] = {
  750. **self.loss_masks(masks_queries_logits, mask_labels, indices, num_masks),
  751. **self.loss_labels(class_queries_logits, class_labels, indices),
  752. }
  753. # in case of auxiliary losses, we repeat this process with the output of each intermediate layer.
  754. if auxiliary_predictions is not None:
  755. for idx, aux_outputs in enumerate(auxiliary_predictions):
  756. masks_queries_logits = aux_outputs["masks_queries_logits"]
  757. class_queries_logits = aux_outputs["class_queries_logits"]
  758. loss_dict = self.forward(masks_queries_logits, class_queries_logits, mask_labels, class_labels)
  759. loss_dict = {f"{key}_{idx}": value for key, value in loss_dict.items()}
  760. losses.update(loss_dict)
  761. return losses
  762. def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> torch.Tensor:
  763. """
  764. Computes the average number of target masks across the batch, for normalization purposes.
  765. """
  766. num_masks = sum(len(classes) for classes in class_labels)
  767. num_masks = torch.as_tensor(num_masks, dtype=torch.float, device=device)
  768. world_size = 1
  769. if is_accelerate_available():
  770. if PartialState._shared_state != {}:
  771. num_masks = reduce(num_masks)
  772. world_size = PartialState().num_processes
  773. num_masks = torch.clamp(num_masks / world_size, min=1)
  774. return num_masks
  775. @auto_docstring
  776. class VideomtPreTrainedModel(PreTrainedModel):
  777. """
  778. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  779. models.
  780. """
  781. config: VideomtConfig
  782. base_model_prefix = "videomt"
  783. main_input_name = "pixel_values_videos"
  784. input_modalities = ("video",)
  785. supports_gradient_checkpointing = False
  786. _no_split_modules = ["VideomtLayer"]
  787. _supports_sdpa = True
  788. _can_record_outputs = {
  789. "hidden_states": VideomtLayer,
  790. "attentions": VideomtAttention,
  791. }
  792. @torch.no_grad()
  793. def _init_weights(self, module: nn.Module) -> None:
  794. std = self.config.initializer_range
  795. if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
  796. init.kaiming_uniform_(module.weight, a=math.sqrt(5))
  797. if module.bias is not None:
  798. fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(module.weight)
  799. bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
  800. init.uniform_(module.bias, -bound, bound)
  801. elif isinstance(module, nn.LayerNorm):
  802. init.ones_(module.weight)
  803. init.zeros_(module.bias)
  804. elif isinstance(module, nn.Embedding):
  805. init.normal_(module.weight, mean=0.0, std=1)
  806. # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
  807. if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
  808. init.zeros_(module.weight[module.padding_idx])
  809. elif isinstance(module, VideomtLayerScale):
  810. if hasattr(module, "lambda1"):
  811. init.constant_(module.lambda1, self.config.layerscale_value)
  812. elif isinstance(module, VideomtEmbeddings):
  813. init.trunc_normal_(module.cls_token, mean=0.0, std=std)
  814. init.zeros_(module.register_tokens)
  815. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  816. elif isinstance(module, VideomtLoss):
  817. empty_weight = torch.ones(module.num_labels + 1)
  818. empty_weight[-1] = module.eos_coef
  819. init.copy_(module.empty_weight, empty_weight)
  820. elif isinstance(module, VideomtForUniversalSegmentation):
  821. init.ones_(module.attn_mask_probs)
  822. if isinstance(module, VideomtEmbeddings):
  823. nn.init.zeros_(module.mask_token)
  824. class VideomtLayerNorm2d(nn.LayerNorm):
  825. def __init__(self, num_channels, eps=1e-6, affine=True):
  826. super().__init__(num_channels, eps=eps, elementwise_affine=affine)
  827. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  828. hidden_state = hidden_state.permute(0, 2, 3, 1)
  829. hidden_state = F.layer_norm(hidden_state, self.normalized_shape, self.weight, self.bias, self.eps)
  830. hidden_state = hidden_state.permute(0, 3, 1, 2)
  831. return hidden_state
  832. class VideomtScaleLayer(nn.Module):
  833. def __init__(self, config: VideomtConfig):
  834. super().__init__()
  835. hidden_size = config.hidden_size
  836. self.conv1 = nn.ConvTranspose2d(hidden_size, hidden_size, kernel_size=2, stride=2)
  837. self.activation = ACT2FN[config.hidden_act]
  838. self.conv2 = nn.Conv2d(
  839. hidden_size,
  840. hidden_size,
  841. kernel_size=3,
  842. padding=1,
  843. groups=hidden_size,
  844. bias=False,
  845. )
  846. self.layernorm2d = VideomtLayerNorm2d(hidden_size)
  847. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  848. hidden_states = self.conv1(hidden_states)
  849. hidden_states = self.activation(hidden_states)
  850. hidden_states = self.conv2(hidden_states)
  851. hidden_states = self.layernorm2d(hidden_states)
  852. return hidden_states
  853. class VideomtScaleBlock(nn.Module):
  854. def __init__(self, config: VideomtConfig):
  855. super().__init__()
  856. self.num_blocks = config.num_upscale_blocks
  857. self.block = nn.ModuleList([VideomtScaleLayer(config) for _ in range(self.num_blocks)])
  858. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  859. for block in self.block:
  860. hidden_states = block(hidden_states)
  861. return hidden_states
  862. class VideomtMaskHead(nn.Module):
  863. def __init__(self, config: VideomtConfig):
  864. super().__init__()
  865. hidden_size = config.hidden_size
  866. self.fc1 = nn.Linear(hidden_size, hidden_size)
  867. self.fc2 = nn.Linear(hidden_size, hidden_size)
  868. self.fc3 = nn.Linear(hidden_size, hidden_size)
  869. self.activation = ACT2FN[config.hidden_act]
  870. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  871. hidden_states = self.activation(self.fc1(hidden_states))
  872. hidden_states = self.activation(self.fc2(hidden_states))
  873. hidden_states = self.fc3(hidden_states)
  874. return hidden_states
  875. @auto_docstring(
  876. custom_intro="""
  877. The Videomt Model with head on top for instance/semantic/panoptic segmentation.
  878. """
  879. )
  880. class VideomtForUniversalSegmentation(VideomtPreTrainedModel):
  881. main_input_name = "pixel_values_videos"
  882. def __init__(self, config: VideomtConfig):
  883. super().__init__(config)
  884. self.config = config
  885. self.num_hidden_layers = config.num_hidden_layers
  886. self.embeddings = VideomtEmbeddings(config)
  887. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  888. self.query = nn.Embedding(config.num_queries, config.hidden_size)
  889. self.layers = nn.ModuleList([VideomtLayer(config) for _ in range(config.num_hidden_layers)])
  890. self.upscale_block = VideomtScaleBlock(config)
  891. self.mask_head = VideomtMaskHead(config)
  892. self.class_predictor = nn.Linear(config.hidden_size, config.num_labels + 1)
  893. self.grid_size = (config.image_size // config.patch_size, config.image_size // config.patch_size)
  894. self.weight_dict: dict[str, float] = {
  895. "loss_cross_entropy": config.class_weight,
  896. "loss_mask": config.mask_weight,
  897. "loss_dice": config.dice_weight,
  898. }
  899. self.criterion = VideomtLoss(config=config, weight_dict=self.weight_dict)
  900. self.register_buffer("attn_mask_probs", torch.ones(config.num_blocks))
  901. self.query_updater = nn.Linear(config.hidden_size, config.hidden_size)
  902. self.post_init()
  903. def get_loss_dict(
  904. self,
  905. masks_queries_logits: Tensor,
  906. class_queries_logits: Tensor,
  907. mask_labels: Tensor,
  908. class_labels: Tensor,
  909. auxiliary_predictions: dict[str, Tensor],
  910. ) -> dict[str, Tensor]:
  911. loss_dict: dict[str, Tensor] = self.criterion(
  912. masks_queries_logits=masks_queries_logits,
  913. class_queries_logits=class_queries_logits,
  914. mask_labels=mask_labels,
  915. class_labels=class_labels,
  916. auxiliary_predictions=auxiliary_predictions,
  917. )
  918. # weight each loss by `self.weight_dict[<LOSS_NAME>]` including auxiliary losses
  919. for key, weight in self.weight_dict.items():
  920. for loss_key, loss in loss_dict.items():
  921. if key in loss_key:
  922. loss *= weight
  923. return loss_dict
  924. def get_loss(self, loss_dict: dict[str, Tensor]) -> Tensor:
  925. return sum(loss_dict.values())
  926. @merge_with_config_defaults
  927. @capture_outputs
  928. @auto_docstring
  929. def forward(
  930. self,
  931. pixel_values_videos: torch.Tensor | None = None,
  932. mask_labels: list[torch.Tensor] | None = None,
  933. class_labels: list[torch.Tensor] | None = None,
  934. patch_offsets: list[torch.Tensor] | None = None, # Unused, kept for modular compatibility.
  935. **kwargs: Unpack[TransformersKwargs],
  936. ) -> VideomtForUniversalSegmentationOutput:
  937. r"""
  938. pixel_values_videos (`torch.Tensor`, *optional*):
  939. Video inputs of shape `(batch_size, num_frames, num_channels, height, width)`.
  940. mask_labels (`list[torch.Tensor]`, *optional*):
  941. Not supported for 5D video inputs.
  942. class_labels (`list[torch.LongTensor]`, *optional*):
  943. Not supported for 5D video inputs.
  944. patch_offsets (`list[torch.Tensor]`, *optional*):
  945. Unused for video inputs and only kept for modular compatibility.
  946. """
  947. if "pixel_values" in kwargs:
  948. raise ValueError("Use `pixel_values_videos` with `VideomtForUniversalSegmentation`.")
  949. if pixel_values_videos is None:
  950. raise ValueError("You have to specify pixel_values_videos")
  951. if pixel_values_videos.ndim != 5:
  952. raise ValueError(
  953. "VideomtForUniversalSegmentation only supports 5D video inputs of shape "
  954. "(batch_size, num_frames, channels, height, width)."
  955. )
  956. if mask_labels is not None or class_labels is not None:
  957. raise ValueError(
  958. "Training with 5D video inputs is not supported in `VideomtForUniversalSegmentation`. "
  959. "Flatten frames and use `EomtForUniversalSegmentation` instead."
  960. )
  961. batch_size, num_frames, num_channels, height, width = pixel_values_videos.shape
  962. flat_pixel_values = pixel_values_videos.reshape(batch_size * num_frames, num_channels, height, width)
  963. hidden_states = self.embeddings(flat_pixel_values)
  964. query_start_idx = self.num_hidden_layers - self.config.num_blocks
  965. for layer_module in self.layers[:query_start_idx]:
  966. hidden_states = layer_module(hidden_states)
  967. hidden_states = hidden_states.view(batch_size, num_frames, hidden_states.shape[1], hidden_states.shape[2])
  968. all_masks_queries_logits = []
  969. all_class_queries_logits = []
  970. all_last_hidden_states = []
  971. propagated_query = None
  972. for frame_idx in range(num_frames):
  973. frame_hidden_states = hidden_states[:, frame_idx]
  974. if propagated_query is None:
  975. query_tokens = self.query.weight[None, :, :].expand(batch_size, -1, -1)
  976. else:
  977. query_tokens = self.query_updater(propagated_query) + self.query.weight[None, :, :].to(
  978. frame_hidden_states.device
  979. )
  980. frame_hidden_states = torch.cat((query_tokens.to(frame_hidden_states.device), frame_hidden_states), dim=1)
  981. for layer_module in self.layers[query_start_idx:]:
  982. frame_hidden_states = layer_module(frame_hidden_states)
  983. sequence_output = self.layernorm(frame_hidden_states)
  984. masks_queries_logits, class_queries_logits = self.predict(sequence_output)
  985. all_masks_queries_logits.append(masks_queries_logits)
  986. all_class_queries_logits.append(class_queries_logits)
  987. all_last_hidden_states.append(sequence_output)
  988. propagated_query = frame_hidden_states[:, : self.config.num_queries, :]
  989. return VideomtForUniversalSegmentationOutput(
  990. loss=None, # Training not supported yet
  991. masks_queries_logits=torch.cat(all_masks_queries_logits, dim=0),
  992. class_queries_logits=torch.cat(all_class_queries_logits, dim=0),
  993. last_hidden_state=torch.cat(all_last_hidden_states, dim=0),
  994. )
  995. def get_input_embeddings(self):
  996. return self.embeddings.patch_embeddings
  997. def predict(self, logits: torch.Tensor):
  998. query_tokens = logits[:, : self.config.num_queries, :]
  999. class_logits = self.class_predictor(query_tokens)
  1000. prefix_tokens = logits[:, self.config.num_queries + self.embeddings.num_prefix_tokens :, :]
  1001. prefix_tokens = prefix_tokens.transpose(1, 2)
  1002. prefix_tokens = prefix_tokens.reshape(prefix_tokens.shape[0], -1, *self.grid_size)
  1003. query_tokens = self.mask_head(query_tokens)
  1004. prefix_tokens = self.upscale_block(prefix_tokens)
  1005. mask_logits = torch.einsum("bqc, bchw -> bqhw", query_tokens, prefix_tokens)
  1006. return mask_logits, class_logits
  1007. __all__ = ["VideomtPreTrainedModel", "VideomtForUniversalSegmentation"]