| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263 |
- # Copyright 2026 the HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from dataclasses import dataclass
- import torch
- from torch import nn
- from ...file_utils import ModelOutput
- from ...processing_utils import Unpack
- from ...utils import TransformersKwargs, auto_docstring
- from ..eomt.configuration_eomt import EomtConfig
- from ..eomt.modeling_eomt import (
- EomtEmbeddings,
- EomtForUniversalSegmentation,
- EomtLayer,
- EomtLayerNorm2d,
- EomtLayerScale,
- EomtMLP,
- EomtPatchEmbeddings,
- EomtPreTrainedModel,
- EomtScaleBlock,
- EomtScaleLayer,
- EomtSwiGLUFFN,
- )
- class VideomtConfig(EomtConfig):
- model_type = "videomt"
- class VideomtPatchEmbeddings(EomtPatchEmbeddings):
- def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
- num_channels = pixel_values.shape[1]
- if num_channels != self.num_channels:
- raise ValueError(
- "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
- f" Expected {self.num_channels} but got {num_channels}."
- )
- pixel_values = pixel_values.to(dtype=self.projection.weight.dtype)
- embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
- return embeddings
- class VideomtEmbeddings(EomtEmbeddings):
- def __init__(self, config: VideomtConfig):
- super().__init__(config)
- self.patch_embeddings = VideomtPatchEmbeddings(config)
- self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
- def forward(self, pixel_values: torch.Tensor, bool_masked_pos: torch.Tensor | None = None) -> torch.Tensor:
- if pixel_values.ndim == 5:
- batch_size, num_frames, num_channels, height, width = pixel_values.shape
- pixel_values = pixel_values.reshape(batch_size * num_frames, num_channels, height, width)
- if bool_masked_pos is not None:
- bool_masked_pos = bool_masked_pos.reshape(batch_size * num_frames, -1)
- elif bool_masked_pos is not None and bool_masked_pos.ndim > 2:
- bool_masked_pos = bool_masked_pos.reshape(bool_masked_pos.shape[0], -1)
- batch_size = pixel_values.shape[0]
- embeddings = self.patch_embeddings(pixel_values)
- if bool_masked_pos is not None:
- mask = bool_masked_pos.to(device=embeddings.device, dtype=torch.bool).unsqueeze(-1)
- embeddings = torch.where(mask, self.mask_token, embeddings)
- cls_tokens = self.cls_token.expand(batch_size, -1, -1)
- register_tokens = self.register_tokens.expand(batch_size, -1, -1)
- embeddings = embeddings + self.position_embeddings(self.position_ids)
- embeddings = torch.cat([cls_tokens, register_tokens, embeddings], dim=1)
- embeddings = self.dropout(embeddings)
- return embeddings
- class VideomtMLP(EomtMLP):
- pass
- class VideomtGatedMLP(EomtSwiGLUFFN):
- pass
- class VideomtLayer(EomtLayer):
- pass
- class VideomtLayerScale(EomtLayerScale):
- pass
- @dataclass
- @auto_docstring(
- custom_intro="""
- Class for outputs of [`VideomtForUniversalSegmentationOutput`].
- This output can be directly passed to [`~VideomtVideoProcessor.post_process_semantic_segmentation`] or
- [`~VideomtVideoProcessor.post_process_instance_segmentation`] or
- [`~VideomtVideoProcessor.post_process_panoptic_segmentation`] to compute final segmentation maps. Please, see
- [`~VideomtVideoProcessor`] for details regarding usage.
- """
- )
- class VideomtForUniversalSegmentationOutput(ModelOutput):
- r"""
- loss (`torch.Tensor`, *optional*):
- The computed loss, returned when labels are present.
- class_queries_logits (`torch.FloatTensor`):
- A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each
- query. Note the `+ 1` is needed because we incorporate the null class.
- masks_queries_logits (`torch.FloatTensor`):
- A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each
- query.
- last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
- Last hidden states (final feature map) of the last layer.
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
- shape `(batch_size, sequence_length, hidden_size)`. Hidden-states all layers of the model.
- attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`. Self and Cross Attentions weights from transformer decoder.
- """
- loss: torch.FloatTensor | None = None
- class_queries_logits: torch.FloatTensor | None = None
- masks_queries_logits: torch.FloatTensor | None = None
- last_hidden_state: torch.FloatTensor | None = None
- hidden_states: tuple[torch.FloatTensor] | None = None
- attentions: tuple[torch.FloatTensor] | None = None
- class VideomtPreTrainedModel(EomtPreTrainedModel):
- main_input_name = "pixel_values_videos"
- input_modalities = ("video",)
- @torch.no_grad()
- def _init_weights(self, module: nn.Module) -> None:
- super()._init_weights(module)
- if isinstance(module, VideomtEmbeddings):
- nn.init.zeros_(module.mask_token)
- class VideomtLayerNorm2d(EomtLayerNorm2d):
- pass
- class VideomtScaleLayer(EomtScaleLayer):
- pass
- class VideomtScaleBlock(EomtScaleBlock):
- pass
- class VideomtForUniversalSegmentation(EomtForUniversalSegmentation):
- main_input_name = "pixel_values_videos"
- def __init__(self, config: VideomtConfig):
- super().__init__(config)
- self.query_updater = nn.Linear(config.hidden_size, config.hidden_size)
- def _disable_attention_mask(attn_mask, prob, num_query_tokens, encoder_start_tokens, device):
- raise AttributeError("Not needed for Videomt")
- def forward(
- self,
- pixel_values_videos: torch.Tensor | None = None,
- mask_labels: list[torch.Tensor] | None = None,
- class_labels: list[torch.Tensor] | None = None,
- patch_offsets: list[torch.Tensor] | None = None, # Unused, kept for modular compatibility.
- **kwargs: Unpack[TransformersKwargs],
- ) -> VideomtForUniversalSegmentationOutput:
- r"""
- pixel_values_videos (`torch.Tensor`, *optional*):
- Video inputs of shape `(batch_size, num_frames, num_channels, height, width)`.
- mask_labels (`list[torch.Tensor]`, *optional*):
- Not supported for 5D video inputs.
- class_labels (`list[torch.LongTensor]`, *optional*):
- Not supported for 5D video inputs.
- patch_offsets (`list[torch.Tensor]`, *optional*):
- Unused for video inputs and only kept for modular compatibility.
- """
- if "pixel_values" in kwargs:
- raise ValueError("Use `pixel_values_videos` with `VideomtForUniversalSegmentation`.")
- if pixel_values_videos is None:
- raise ValueError("You have to specify pixel_values_videos")
- if pixel_values_videos.ndim != 5:
- raise ValueError(
- "VideomtForUniversalSegmentation only supports 5D video inputs of shape "
- "(batch_size, num_frames, channels, height, width)."
- )
- if mask_labels is not None or class_labels is not None:
- raise ValueError(
- "Training with 5D video inputs is not supported in `VideomtForUniversalSegmentation`. "
- "Flatten frames and use `EomtForUniversalSegmentation` instead."
- )
- batch_size, num_frames, num_channels, height, width = pixel_values_videos.shape
- flat_pixel_values = pixel_values_videos.reshape(batch_size * num_frames, num_channels, height, width)
- hidden_states = self.embeddings(flat_pixel_values)
- query_start_idx = self.num_hidden_layers - self.config.num_blocks
- for layer_module in self.layers[:query_start_idx]:
- hidden_states = layer_module(hidden_states)
- hidden_states = hidden_states.view(batch_size, num_frames, hidden_states.shape[1], hidden_states.shape[2])
- all_masks_queries_logits = []
- all_class_queries_logits = []
- all_last_hidden_states = []
- propagated_query = None
- for frame_idx in range(num_frames):
- frame_hidden_states = hidden_states[:, frame_idx]
- if propagated_query is None:
- query_tokens = self.query.weight[None, :, :].expand(batch_size, -1, -1)
- else:
- query_tokens = self.query_updater(propagated_query) + self.query.weight[None, :, :].to(
- frame_hidden_states.device
- )
- frame_hidden_states = torch.cat((query_tokens.to(frame_hidden_states.device), frame_hidden_states), dim=1)
- for layer_module in self.layers[query_start_idx:]:
- frame_hidden_states = layer_module(frame_hidden_states)
- sequence_output = self.layernorm(frame_hidden_states)
- masks_queries_logits, class_queries_logits = self.predict(sequence_output)
- all_masks_queries_logits.append(masks_queries_logits)
- all_class_queries_logits.append(class_queries_logits)
- all_last_hidden_states.append(sequence_output)
- propagated_query = frame_hidden_states[:, : self.config.num_queries, :]
- return VideomtForUniversalSegmentationOutput(
- loss=None, # Training not supported yet
- masks_queries_logits=torch.cat(all_masks_queries_logits, dim=0),
- class_queries_logits=torch.cat(all_class_queries_logits, dim=0),
- last_hidden_state=torch.cat(all_last_hidden_states, dim=0),
- )
- __all__ = [
- "VideomtConfig",
- "VideomtPreTrainedModel",
- "VideomtForUniversalSegmentation",
- ]
|