# Copyright 2026 the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from dataclasses import dataclass
  15. import torch
  16. from torch import nn
  17. from ...file_utils import ModelOutput
  18. from ...processing_utils import Unpack
  19. from ...utils import TransformersKwargs, auto_docstring
  20. from ..eomt.configuration_eomt import EomtConfig
  21. from ..eomt.modeling_eomt import (
  22. EomtEmbeddings,
  23. EomtForUniversalSegmentation,
  24. EomtLayer,
  25. EomtLayerNorm2d,
  26. EomtLayerScale,
  27. EomtMLP,
  28. EomtPatchEmbeddings,
  29. EomtPreTrainedModel,
  30. EomtScaleBlock,
  31. EomtScaleLayer,
  32. EomtSwiGLUFFN,
  33. )
class VideomtConfig(EomtConfig):
    # Reuses every `EomtConfig` argument unchanged; only `model_type` is overridden for Videomt.
    model_type = "videomt"
  36. class VideomtPatchEmbeddings(EomtPatchEmbeddings):
  37. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  38. num_channels = pixel_values.shape[1]
  39. if num_channels != self.num_channels:
  40. raise ValueError(
  41. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  42. f" Expected {self.num_channels} but got {num_channels}."
  43. )
  44. pixel_values = pixel_values.to(dtype=self.projection.weight.dtype)
  45. embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
  46. return embeddings
  47. class VideomtEmbeddings(EomtEmbeddings):
  48. def __init__(self, config: VideomtConfig):
  49. super().__init__(config)
  50. self.patch_embeddings = VideomtPatchEmbeddings(config)
  51. self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  52. def forward(self, pixel_values: torch.Tensor, bool_masked_pos: torch.Tensor | None = None) -> torch.Tensor:
  53. if pixel_values.ndim == 5:
  54. batch_size, num_frames, num_channels, height, width = pixel_values.shape
  55. pixel_values = pixel_values.reshape(batch_size * num_frames, num_channels, height, width)
  56. if bool_masked_pos is not None:
  57. bool_masked_pos = bool_masked_pos.reshape(batch_size * num_frames, -1)
  58. elif bool_masked_pos is not None and bool_masked_pos.ndim > 2:
  59. bool_masked_pos = bool_masked_pos.reshape(bool_masked_pos.shape[0], -1)
  60. batch_size = pixel_values.shape[0]
  61. embeddings = self.patch_embeddings(pixel_values)
  62. if bool_masked_pos is not None:
  63. mask = bool_masked_pos.to(device=embeddings.device, dtype=torch.bool).unsqueeze(-1)
  64. embeddings = torch.where(mask, self.mask_token, embeddings)
  65. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  66. register_tokens = self.register_tokens.expand(batch_size, -1, -1)
  67. embeddings = embeddings + self.position_embeddings(self.position_ids)
  68. embeddings = torch.cat([cls_tokens, register_tokens, embeddings], dim=1)
  69. embeddings = self.dropout(embeddings)
  70. return embeddings
  71. class VideomtMLP(EomtMLP):
  72. pass
  73. class VideomtGatedMLP(EomtSwiGLUFFN):
  74. pass
  75. class VideomtLayer(EomtLayer):
  76. pass
  77. class VideomtLayerScale(EomtLayerScale):
  78. pass
  79. @dataclass
  80. @auto_docstring(
  81. custom_intro="""
  82. Class for outputs of [`VideomtForUniversalSegmentationOutput`].
  83. This output can be directly passed to [`~VideomtVideoProcessor.post_process_semantic_segmentation`] or
  84. [`~VideomtVideoProcessor.post_process_instance_segmentation`] or
  85. [`~VideomtVideoProcessor.post_process_panoptic_segmentation`] to compute final segmentation maps. Please, see
  86. [`~VideomtVideoProcessor`] for details regarding usage.
  87. """
  88. )
  89. class VideomtForUniversalSegmentationOutput(ModelOutput):
  90. r"""
  91. loss (`torch.Tensor`, *optional*):
  92. The computed loss, returned when labels are present.
  93. class_queries_logits (`torch.FloatTensor`):
  94. A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each
  95. query. Note the `+ 1` is needed because we incorporate the null class.
  96. masks_queries_logits (`torch.FloatTensor`):
  97. A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each
  98. query.
  99. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  100. Last hidden states (final feature map) of the last layer.
  101. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  102. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
  103. shape `(batch_size, sequence_length, hidden_size)`. Hidden-states all layers of the model.
  104. attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  105. Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  106. sequence_length)`. Self and Cross Attentions weights from transformer decoder.
  107. """
  108. loss: torch.FloatTensor | None = None
  109. class_queries_logits: torch.FloatTensor | None = None
  110. masks_queries_logits: torch.FloatTensor | None = None
  111. last_hidden_state: torch.FloatTensor | None = None
  112. hidden_states: tuple[torch.FloatTensor] | None = None
  113. attentions: tuple[torch.FloatTensor] | None = None
  114. class VideomtPreTrainedModel(EomtPreTrainedModel):
  115. main_input_name = "pixel_values_videos"
  116. input_modalities = ("video",)
  117. @torch.no_grad()
  118. def _init_weights(self, module: nn.Module) -> None:
  119. super()._init_weights(module)
  120. if isinstance(module, VideomtEmbeddings):
  121. nn.init.zeros_(module.mask_token)
  122. class VideomtLayerNorm2d(EomtLayerNorm2d):
  123. pass
  124. class VideomtScaleLayer(EomtScaleLayer):
  125. pass
  126. class VideomtScaleBlock(EomtScaleBlock):
  127. pass
  128. class VideomtForUniversalSegmentation(EomtForUniversalSegmentation):
  129. main_input_name = "pixel_values_videos"
  130. def __init__(self, config: VideomtConfig):
  131. super().__init__(config)
  132. self.query_updater = nn.Linear(config.hidden_size, config.hidden_size)
  133. def _disable_attention_mask(attn_mask, prob, num_query_tokens, encoder_start_tokens, device):
  134. raise AttributeError("Not needed for Videomt")
  135. def forward(
  136. self,
  137. pixel_values_videos: torch.Tensor | None = None,
  138. mask_labels: list[torch.Tensor] | None = None,
  139. class_labels: list[torch.Tensor] | None = None,
  140. patch_offsets: list[torch.Tensor] | None = None, # Unused, kept for modular compatibility.
  141. **kwargs: Unpack[TransformersKwargs],
  142. ) -> VideomtForUniversalSegmentationOutput:
  143. r"""
  144. pixel_values_videos (`torch.Tensor`, *optional*):
  145. Video inputs of shape `(batch_size, num_frames, num_channels, height, width)`.
  146. mask_labels (`list[torch.Tensor]`, *optional*):
  147. Not supported for 5D video inputs.
  148. class_labels (`list[torch.LongTensor]`, *optional*):
  149. Not supported for 5D video inputs.
  150. patch_offsets (`list[torch.Tensor]`, *optional*):
  151. Unused for video inputs and only kept for modular compatibility.
  152. """
  153. if "pixel_values" in kwargs:
  154. raise ValueError("Use `pixel_values_videos` with `VideomtForUniversalSegmentation`.")
  155. if pixel_values_videos is None:
  156. raise ValueError("You have to specify pixel_values_videos")
  157. if pixel_values_videos.ndim != 5:
  158. raise ValueError(
  159. "VideomtForUniversalSegmentation only supports 5D video inputs of shape "
  160. "(batch_size, num_frames, channels, height, width)."
  161. )
  162. if mask_labels is not None or class_labels is not None:
  163. raise ValueError(
  164. "Training with 5D video inputs is not supported in `VideomtForUniversalSegmentation`. "
  165. "Flatten frames and use `EomtForUniversalSegmentation` instead."
  166. )
  167. batch_size, num_frames, num_channels, height, width = pixel_values_videos.shape
  168. flat_pixel_values = pixel_values_videos.reshape(batch_size * num_frames, num_channels, height, width)
  169. hidden_states = self.embeddings(flat_pixel_values)
  170. query_start_idx = self.num_hidden_layers - self.config.num_blocks
  171. for layer_module in self.layers[:query_start_idx]:
  172. hidden_states = layer_module(hidden_states)
  173. hidden_states = hidden_states.view(batch_size, num_frames, hidden_states.shape[1], hidden_states.shape[2])
  174. all_masks_queries_logits = []
  175. all_class_queries_logits = []
  176. all_last_hidden_states = []
  177. propagated_query = None
  178. for frame_idx in range(num_frames):
  179. frame_hidden_states = hidden_states[:, frame_idx]
  180. if propagated_query is None:
  181. query_tokens = self.query.weight[None, :, :].expand(batch_size, -1, -1)
  182. else:
  183. query_tokens = self.query_updater(propagated_query) + self.query.weight[None, :, :].to(
  184. frame_hidden_states.device
  185. )
  186. frame_hidden_states = torch.cat((query_tokens.to(frame_hidden_states.device), frame_hidden_states), dim=1)
  187. for layer_module in self.layers[query_start_idx:]:
  188. frame_hidden_states = layer_module(frame_hidden_states)
  189. sequence_output = self.layernorm(frame_hidden_states)
  190. masks_queries_logits, class_queries_logits = self.predict(sequence_output)
  191. all_masks_queries_logits.append(masks_queries_logits)
  192. all_class_queries_logits.append(class_queries_logits)
  193. all_last_hidden_states.append(sequence_output)
  194. propagated_query = frame_hidden_states[:, : self.config.num_queries, :]
  195. return VideomtForUniversalSegmentationOutput(
  196. loss=None, # Training not supported yet
  197. masks_queries_logits=torch.cat(all_masks_queries_logits, dim=0),
  198. class_queries_logits=torch.cat(all_class_queries_logits, dim=0),
  199. last_hidden_state=torch.cat(all_last_hidden_states, dim=0),
  200. )
  201. __all__ = [
  202. "VideomtConfig",
  203. "VideomtPreTrainedModel",
  204. "VideomtForUniversalSegmentation",
  205. ]