| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
2771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608 |
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # This file was automatically generated from src/transformers/models/sam2/modular_sam2.py.
- # Do NOT edit this file manually as any edits will be overwritten by the generation of
- # the file from the modular. If any change should be done, please apply the change to the
- # modular_sam2.py file directly. One of our CI enforces this.
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import math
- from collections.abc import Callable
- from dataclasses import dataclass
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torch import Tensor
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
- from ...processing_utils import Unpack
- from ...pytorch_utils import compile_compatible_method_lru_cache
- from ...utils import (
- ModelOutput,
- auto_docstring,
- can_return_tuple,
- logging,
- )
- from ...utils.generic import TransformersKwargs, is_flash_attention_requested, merge_with_config_defaults
- from ...utils.output_capturing import OutputRecorder, capture_outputs
- from ..auto import AutoModel
- from .configuration_sam2 import (
- Sam2Config,
- Sam2HieraDetConfig,
- Sam2MaskDecoderConfig,
- Sam2PromptEncoderConfig,
- Sam2VisionConfig,
- )
# Module-level logger for this file, obtained through the package's `logging.get_logger` helper.
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(custom_intro="Base class for the vision encoder's outputs.")
class Sam2VisionEncoderOutput(BaseModelOutputWithPooling):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, height, width, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the model.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
        one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`. Hidden-states of the
        model at the output of each stage.
    attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
        the self-attention heads.
    fpn_hidden_states (`tuple(torch.FloatTensor)`):
        Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
        `(batch_size, hidden_size, height, width)`. Feature maps from the Feature Pyramid Network neck.
    fpn_position_encoding (`tuple(torch.FloatTensor)`):
        Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
        `(batch_size, hidden_size, height, width)`. Positional encodings corresponding to the `fpn_hidden_states`.
    """

    # Both FPN fields carry one tensor per feature level, so they are annotated as tuples.
    fpn_hidden_states: tuple[torch.FloatTensor, ...] | None = None
    fpn_position_encoding: tuple[torch.FloatTensor, ...] | None = None
@dataclass
@auto_docstring(custom_intro="Base class for the Sam2 model's output.")
class Sam2ImageSegmentationOutput(ModelOutput):
    r"""
    iou_scores (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks)`):
        The Intersection over Union (IoU) scores of the predicted masks.
    pred_masks (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks, height, width)`):
        The predicted low-resolution masks. This is an alias for `low_res_masks`. These masks need to be post-processed
        by the processor to be brought to the original image size.
    object_score_logits (`torch.FloatTensor` of shape `(batch_size, point_batch_size, 1)`):
        Logits for the object score, indicating if an object is present.
    image_embeddings (`tuple(torch.FloatTensor)`):
        The features from the FPN, which are used by the mask decoder. This is a tuple of `torch.FloatTensor` where each
        tensor has shape `(batch_size, channels, height, width)`.
    vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`.
        Hidden-states of the vision model at the output of each stage.
    vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
        Attention weights of the vision model.
    mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
        Attention weights of the mask decoder.
    """

    iou_scores: torch.FloatTensor | None = None
    pred_masks: torch.FloatTensor | None = None
    object_score_logits: torch.FloatTensor | None = None
    # Defaults to None like every other field, so the annotation is Optional as well.
    image_embeddings: tuple[torch.FloatTensor, ...] | None = None
    vision_hidden_states: tuple[torch.FloatTensor, ...] | None = None
    vision_attentions: tuple[torch.FloatTensor, ...] | None = None
    mask_decoder_attentions: tuple[torch.FloatTensor, ...] | None = None
class Sam2PatchEmbeddings(nn.Module):
    r"""
    Turns pixel values into patch embeddings for transformer consumption.

    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`Sam2ImageProcessor.__call__`] for details.

    Returns:
        embeddings (`torch.FloatTensor` of shape `(batch_size, patch_height, patch_width, hidden_size)`):
            Channels-last patch embeddings. The spatial size depends on image_size, patch_kernel_size,
            patch_stride and patch_padding.
    """

    def __init__(self, config: Sam2HieraDetConfig):
        super().__init__()
        num_channels = config.num_channels
        hidden_size = config.hidden_size
        self.projection = nn.Conv2d(
            num_channels,
            hidden_size,
            kernel_size=config.patch_kernel_size,
            stride=config.patch_stride,
            padding=config.patch_padding,
        )

    def forward(self, pixel_values):
        # Only batched, channels-first images are supported; fail early with an explicit message.
        if pixel_values.dim() != 4:
            raise ValueError(
                "Expected pixel_values of shape (batch_size, num_channels, height, width), "
                f"got shape {tuple(pixel_values.shape)}"
            )
        # Match the conv weights' dtype so mixed-precision inputs do not trigger a dtype mismatch.
        embeddings = self.projection(pixel_values.to(self.projection.weight.dtype))
        # (batch_size, hidden_size, h, w) -> (batch_size, h, w, hidden_size)
        return embeddings.permute(0, 2, 3, 1)
# copied and adapted from original implementation, also practically equal to DetrSinePositionEmbedding
class Sam2SinePositionEmbedding(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
    need paper, generalized to work on images.

    The encoding has `2 * num_pos_feats` channels: the first `num_pos_feats` encode the row (y) position and the last
    `num_pos_feats` encode the column (x) position, with sine and cosine values interleaved inside each half.

    Args:
        num_pos_feats (`int`, *optional*, defaults to 64):
            Number of encoding channels per spatial axis.
        temperature (`int`, *optional*, defaults to 10000):
            Base of the geometric frequency progression.
        normalize (`bool`, *optional*, defaults to `False`):
            If `True`, the running positions are divided by the last position along each axis and multiplied by
            `scale`.
        scale (`float`, *optional*):
            Multiplier used when `normalize=True`; defaults to `2 * pi`. Passing it with `normalize=False` raises a
            `ValueError`.
    """

    def __init__(
        self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: float | None = None
    ):
        super().__init__()
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = 2 * math.pi if scale is None else scale

    @compile_compatible_method_lru_cache(maxsize=1)
    def forward(
        self,
        shape: torch.Size,
        device: torch.device | str,
        dtype: torch.dtype,
        mask: Tensor | None = None,
    ) -> Tensor:
        """
        Compute the sine/cosine encoding for a feature map of the given shape.

        Args:
            shape (`torch.Size`):
                Shape read as `(batch_size, _, height, width)`; only indices 0, 2 and 3 are used.
            device (`torch.device` or `str`):
                Device on which the encoding is created.
            dtype (`torch.dtype`):
                Dtype used for the position counters and the returned encoding.
            mask (`Tensor`, *optional*):
                Boolean `(batch_size, height, width)` mask; `True` entries do not advance the position counters.
                Defaults to an all-`False` mask.

        Returns:
            `Tensor` of shape `(batch_size, 2 * num_pos_feats, height, width)`.

        NOTE(review): judging by its name, the decorator is a size-1 LRU cache, so repeated calls with the same
        arguments may return the same tensor object. Confirm before mutating the result in place.
        """
        if mask is None:
            mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool)
        not_mask = (~mask).to(dtype)
        # Running count of unmasked positions along height (dim 1) and width (dim 2); 1-based when nothing is masked.
        y_embed = not_mask.cumsum(1)
        x_embed = not_mask.cumsum(2)
        if self.normalize:
            eps = 1e-6
            # Divide by the last position along each axis so values end near `scale`; eps avoids division by zero.
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
        # Frequencies temperature ** (2 * floor(i / 2) / num_pos_feats): each sin/cos channel pair shares one frequency.
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=device).to(dtype)
        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # Interleave sine (even channels) and cosine (odd channels).
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        # Concatenate y then x, then move channels first: (B, H, W, 2 * num_pos_feats) -> (B, 2 * num_pos_feats, H, W).
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos
class Sam2VisionNeck(nn.Module):
    """
    Feature Pyramid Network neck mapping every backbone stage to `config.fpn_hidden_size` channels.

    Levels are processed from the last backbone stage to the first. For every level index in
    `config.fpn_top_down_levels` (other than the first level processed), the previous output is upsampled 2x
    (nearest) and added to the lateral features. Each output level also gets a sine position encoding.
    """

    def __init__(self, config: Sam2VisionConfig):
        super().__init__()
        self.config = config
        self.position_encoding = Sam2SinePositionEmbedding(num_pos_feats=config.fpn_hidden_size // 2, normalize=True)
        self.convs = nn.ModuleList()
        for in_channels in config.backbone_channel_list:
            self.convs.append(
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=config.fpn_hidden_size,
                    kernel_size=config.fpn_kernel_size,
                    stride=config.fpn_stride,
                    padding=config.fpn_padding,
                ),
            )
        self.fpn_top_down_levels = config.fpn_top_down_levels

    def forward(
        self, hidden_states: tuple[torch.Tensor, ...]
    ) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]:
        """
        Args:
            hidden_states (`tuple(torch.Tensor)`):
                Channels-last stage outputs `(batch_size, height, width, channels)`, one per conv, ordered from the
                first backbone stage to the last. `self.convs[n - i]` is applied to `hidden_states[i]`, so
                `config.backbone_channel_list` lists channel counts from the last stage to the first.

        Returns:
            `(fpn_hidden_states, fpn_position_encoding)`: channels-first tensors ordered from the last backbone
            stage (processed first) to the first one.
        """
        fpn_hidden_states = ()
        fpn_position_encoding = ()
        # forward in top-down order (from low to high resolution)
        n = len(self.convs) - 1
        for i in range(n, -1, -1):
            conv = self.convs[n - i]
            lateral_features = hidden_states[i].permute(0, 3, 1, 2)
            # Cast to the dtype of the conv that is actually applied to this level.
            lateral_features = conv(lateral_features.to(conv.weight.dtype))
            if i not in self.fpn_top_down_levels or i == n:
                prev_features = lateral_features
            else:
                # Interpolate in float32, then return to the lateral features' dtype before the sum.
                top_down_features = F.interpolate(
                    prev_features.to(dtype=torch.float32),
                    scale_factor=2.0,
                    mode="nearest",
                    align_corners=None,
                    antialias=False,
                ).to(lateral_features.dtype)
                prev_features = lateral_features + top_down_features
            prev_position_encoding = self.position_encoding(
                prev_features.shape, prev_features.device, prev_features.dtype
            ).to(prev_features.dtype)
            fpn_hidden_states += (prev_features,)
            fpn_position_encoding += (prev_position_encoding,)
        return fpn_hidden_states, fpn_position_encoding
- def eager_attention_forward(
- module: nn.Module,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attention_mask: torch.Tensor | None,
- scaling: float,
- dropout: float = 0.0,
- **kwargs,
- ):
- attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
- if attention_mask is not None:
- attn_weights = attn_weights + attention_mask
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
- attn_output = torch.matmul(attn_weights, value)
- attn_output = attn_output.transpose(1, 2).contiguous()
- return attn_output, attn_weights
- def do_pool(x: torch.Tensor, query_stride: int | None = None) -> torch.Tensor:
- if query_stride is None:
- return x
- # (B, H, W, C) -> (B, C, H, W)
- x = x.permute(0, 3, 1, 2)
- x = nn.functional.max_pool2d(x, kernel_size=query_stride, stride=query_stride, ceil_mode=False)
- # (B, C, H', W') -> (B, H', W', C)
- x = x.permute(0, 2, 3, 1)
- return x
class Sam2MultiScaleAttention(nn.Module):
    """
    Multi-head self-attention on a channels-last feature map, with optional max-pooling of the queries.

    When `query_stride` is set, the queries are pooled spatially before attention. The output therefore takes the
    pooled spatial size, while keys and values keep the input resolution.
    """

    def __init__(
        self,
        config: Sam2HieraDetConfig,
        dim: int,
        dim_out: int,
        num_attention_heads: int,
        query_stride: tuple[int, int] | None = None,
    ):
        super().__init__()
        self.config = config
        self.dim = dim
        self.dim_out = dim_out
        self.query_stride = query_stride
        self.num_attention_heads = num_attention_heads
        head_dim = dim_out // num_attention_heads
        self.scale = head_dim**-0.5
        self.qkv = nn.Linear(dim, dim_out * 3)
        self.proj = nn.Linear(dim_out, dim_out)
        self.is_causal = False

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.Tensor`): Input of shape `(batch_size, height, width, dim)`.
            **kwargs: Forwarded to the selected attention implementation.

        Returns:
            `torch.Tensor` of shape `(batch_size, height', width', dim_out)`. `height'`/`width'` are the pooled
            sizes when `query_stride` is set, and the input sizes otherwise.
        """
        batch_size, height, width, _ = hidden_states.shape
        # qkv with shape (B, H * W, 3, nHead, C)
        qkv = self.qkv(hidden_states).reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
        # q, k, v with shape (B, H * W, nheads, C)
        query, key, value = torch.unbind(qkv, 2)
        # Q pooling (for downsample at stage changes)
        if self.query_stride:
            query = do_pool(query.reshape(batch_size, height, width, -1), self.query_stride)
            height, width = query.shape[1:3]  # downsampled shape
            query = query.reshape(batch_size, height * width, self.num_attention_heads, -1)
        # transpose query, key, value to (B, nHead, H * W, C)
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        attn_output, _ = attention_interface(
            self,
            query,
            key,
            value,
            attention_mask=None,
            is_causal=self.is_causal,
            scaling=self.scale,
            **kwargs,
        )
        attn_output = attn_output.reshape(batch_size, height, width, -1)
        attn_output = self.proj(attn_output)
        return attn_output
class Sam2FeedForward(nn.Module):
    """
    Chain of linear layers with the configured activation after every layer except the last.

    Layout: `proj_in` (input_dim -> hidden_dim), then `max(num_layers - 2, 0)` hidden layers
    (hidden_dim -> hidden_dim), then `proj_out` (hidden_dim -> output_dim). If `sigmoid_output` is set, a sigmoid
    is applied to the final result.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        activation: str = "relu",
        sigmoid_output: bool = False,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.activation = ACT2FN[activation]
        self.proj_in = nn.Linear(input_dim, hidden_dim)
        self.proj_out = nn.Linear(hidden_dim, output_dim)
        hidden_linears = [nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)]
        self.layers = nn.ModuleList(hidden_linears)
        self.sigmoid_output = sigmoid_output

    def forward(self, hidden_states):
        hidden_states = self.activation(self.proj_in(hidden_states))
        for hidden_linear in self.layers:
            hidden_states = self.activation(hidden_linear(hidden_states))
        output = self.proj_out(hidden_states)
        if self.sigmoid_output:
            output = F.sigmoid(output)
        return output
def window_partition(hidden_state, window_size):
    """
    Split a channels-last feature map into non-overlapping square windows.

    The bottom and right edges are zero-padded so both spatial sizes become multiples of `window_size`.

    Args:
        hidden_state (`torch.Tensor`):
            Input tokens with [batch_size, height, width, num_channels].
        window_size (`int`):
            Side length of each window.

    Returns:
        `tuple` of:
            - windows: tensor of shape [batch_size * num_windows, window_size, window_size, num_channels], ordered
              image by image, then row-major over the window grid.
            - (padded_height, padded_width): spatial size after padding, needed to undo the partition.
    """
    batch_size, height, width, num_channels = hidden_state.shape
    # Smallest non-negative padding that makes each spatial size divisible by window_size.
    pad_height = (-height) % window_size
    pad_width = (-width) % window_size
    # F.pad lists (before, after) pairs from the last dim backwards: channels untouched, pad width, then height.
    hidden_state = nn.functional.pad(hidden_state, (0, 0, 0, pad_width, 0, pad_height))
    padded_height = height + pad_height
    padded_width = width + pad_width
    grid_rows = padded_height // window_size
    grid_cols = padded_width // window_size
    grid = hidden_state.view(batch_size, grid_rows, window_size, grid_cols, window_size, num_channels)
    # (B, rows, ws, cols, ws, C) -> (B, rows, cols, ws, ws, C), then flatten batch and grid into one axis.
    windows = grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
    return windows, (padded_height, padded_width)
def window_unpartition(windows, window_size, pad_height_width, height_width):
    """
    Reassemble windows into full feature maps and crop away any padding.

    Args:
        windows (`torch.Tensor`):
            Input tokens with [batch_size * num_windows, window_size, window_size, num_channels].
        window_size (`int`):
            Side length of each window.
        pad_height_width (`tuple[int]`):
            Padded height and width (padded_height, padded_width).
        height_width (`tuple[int]`):
            Original height and width before padding.

    Returns:
        hidden_state: contiguous tensor of shape [batch_size, height, width, num_channels].
    """
    padded_height, padded_width = pad_height_width
    height, width = height_width
    windows_per_image = padded_height * padded_width // window_size // window_size
    batch_size = windows.shape[0] // windows_per_image
    grid_rows = padded_height // window_size
    grid_cols = padded_width // window_size
    grid = windows.view(batch_size, grid_rows, grid_cols, window_size, window_size, -1)
    # (B, rows, cols, ws, ws, C) -> (B, rows, ws, cols, ws, C) so rows/cols fold back into height/width.
    merged = grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, padded_height, padded_width, -1)
    # height <= padded_height and width <= padded_width, so this slice only removes padding.
    return merged[:, :height, :width, :].contiguous()
class Sam2MultiScaleBlock(GradientCheckpointingLayer):
    """
    One Hiera block: (windowed) multi-scale self-attention followed by an MLP, each with a residual connection.

    The first block of a stage takes the previous stage's embedding dim and window size. It may pool its queries
    (`config.query_stride`, for the first `config.num_query_pool_stages` stages after stage 0) and projects the
    residual when the width changes. Blocks listed in `config.global_attention_blocks` use `window_size == 0`,
    i.e. no window partitioning.
    """

    def __init__(
        self,
        config: Sam2HieraDetConfig,
        stage_idx: int,
        block_idx: int,
        total_block_idx: int,
    ):
        super().__init__()
        # take embed dim from previous stage if first block of stage
        self.dim = (
            config.embed_dim_per_stage[stage_idx - 1]
            if stage_idx > 0 and block_idx == 0
            else config.embed_dim_per_stage[stage_idx]
        )
        self.dim_out = config.embed_dim_per_stage[stage_idx]
        self.layer_norm1 = nn.LayerNorm(self.dim, eps=config.layer_norm_eps)
        # take window size from previous stage if first block of stage
        self.window_size = (
            config.window_size_per_stage[stage_idx - 1]
            if stage_idx > 0 and block_idx == 0
            else config.window_size_per_stage[stage_idx]
        )
        # window_size == 0 disables window partitioning (global attention)
        self.window_size = 0 if total_block_idx in config.global_attention_blocks else self.window_size
        # use query stride for first block of stage if stage is a query pool stage
        self.query_stride = (
            config.query_stride if 0 < stage_idx <= config.num_query_pool_stages and block_idx == 0 else None
        )
        self.attn = Sam2MultiScaleAttention(
            config,
            self.dim,
            self.dim_out,
            num_attention_heads=config.num_attention_heads_per_stage[stage_idx],
            query_stride=self.query_stride,
        )
        self.layer_norm2 = nn.LayerNorm(self.dim_out, eps=config.layer_norm_eps)
        self.mlp = Sam2FeedForward(
            self.dim_out,
            int(self.dim_out * config.mlp_ratio),
            self.dim_out,
            num_layers=2,
            activation=config.hidden_act,
        )
        if self.dim != self.dim_out:
            # projects the residual branch to the new stage width
            self.proj = nn.Linear(self.dim, self.dim_out)

    def forward(
        self,
        hidden_states: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.FloatTensor:
        """
        Args:
            hidden_states (`torch.Tensor`): Channels-last input of shape `(batch_size, height, width, dim)`.

        Returns:
            `torch.FloatTensor` of shape `(batch_size, height', width', dim_out)`; the spatial size shrinks when
            this block pools its queries.
        """
        residual = hidden_states  # batch_size, height, width, channel
        hidden_states = self.layer_norm1(hidden_states)
        # Skip connection
        if self.dim != self.dim_out:
            residual = do_pool(self.proj(hidden_states), self.query_stride)
        # Window partition
        window_size = self.window_size
        if window_size > 0:
            height, width = hidden_states.shape[1], hidden_states.shape[2]
            hidden_states, pad_hw = window_partition(hidden_states, window_size)
        # Window Attention + Q Pooling (if stage change)
        hidden_states = self.attn(
            hidden_states=hidden_states,
            **kwargs,
        )
        # Reverse window partition. With global attention (window_size == 0) there is nothing to undo. Recomputing
        # the pooled padding would also divide by zero there, so the whole step is skipped.
        if window_size > 0:
            if self.query_stride:
                # Shapes have changed due to Q pooling
                window_size = window_size // self.query_stride[0]
                height, width = residual.shape[1:3]
                pad_h = (window_size - height % window_size) % window_size
                pad_w = (window_size - width % window_size) % window_size
                pad_hw = (height + pad_h, width + pad_w)
            hidden_states = window_unpartition(hidden_states, window_size, pad_hw, (height, width))
        hidden_states = residual + hidden_states
        layernorm_output = self.layer_norm2(hidden_states)
        hidden_states = hidden_states + self.mlp(layernorm_output)
        return hidden_states
@dataclass
@auto_docstring(
    custom_intro="""
    Hiera model's outputs that also contains a pooling of the last hidden states.
    """
)
class Sam2HieraDetModelOutput(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, height, width, hidden_size)`):
        hidden-states at the output of the last layer of the model.
    intermediate_hidden_states (`tuple[torch.FloatTensor]` of shape `(batch_size, height, width, hidden_size)`):
        Sequence of hidden-states at the output of the intermediate layers of the model (one entry per stage, taken
        from the last block of that stage).
    """

    # Field order defines the tuple layout of this ModelOutput; do not reorder.
    # Output of the final block.
    last_hidden_state: torch.FloatTensor | None = None
    # One tensor per stage: the output of each stage's last block.
    intermediate_hidden_states: tuple[torch.FloatTensor, ...] | None = None
    # Presumably filled by the output-capturing machinery when hidden states are requested — TODO confirm.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    # Presumably filled by the output-capturing machinery when attentions are requested — TODO confirm.
    attentions: tuple[torch.FloatTensor, ...] | None = None
@auto_docstring
class Sam2PreTrainedModel(PreTrainedModel):
    config_class = Sam2Config
    base_model_prefix = "sam2"
    main_input_name = "pixel_values"
    input_modalities = ("image",)
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_attention_backend = True
    # Checkpoint keys matching these patterns are ignored on load. They presumably belong to video/memory
    # components that this image model does not build.
    _keys_to_ignore_on_load_unexpected = [
        r"^memory_.*",
        r"^mask_downsample.*",
        r"^object_pointer_proj.*",
        r"^temporal_positional_encoding_projection_layer.*",
        "no_memory_positional_encoding",
        "no_object_pointer",
        "occlusion_spatial_embedding_parameter",
    ]

    @torch.no_grad()
    def _init_weights(self, module):
        # Generic initialization first, then SAM2-specific parameters.
        super()._init_weights(module)
        if isinstance(module, Sam2HieraDetModel):
            # Both positional-embedding parameters start at zero.
            for embedding in (module.pos_embed, module.pos_embed_window):
                if embedding is not None:
                    init.zeros_(embedding)
            return
        if isinstance(module, Sam2PositionalEmbedding):
            # Gaussian projection whose std is the configured scale.
            init.normal_(module.positional_embedding, std=module.scale)
            return
        if isinstance(module, Sam2Model) and module.no_memory_embedding is not None:
            init.zeros_(module.no_memory_embedding)
class Sam2HieraDetModel(Sam2PreTrainedModel):
    """
    Hiera backbone: patch embedding plus windowed positional embedding, then a stack of multi-scale blocks
    grouped into stages. The output of each stage's last block is collected as an intermediate hidden state.
    """

    config_class = Sam2HieraDetConfig
    main_input_name = "pixel_values"
    _can_record_outputs = {
        "hidden_states": Sam2MultiScaleBlock,
        "attentions": Sam2MultiScaleAttention,
    }

    def __init__(self, config: Sam2HieraDetConfig):
        super().__init__(config)
        self.patch_embed = Sam2PatchEmbeddings(config)
        # Windowed positional embedding (https://huggingface.co/papers/2311.05613): a background embedding
        # resized to the feature map plus a first-stage-window-sized embedding tiled over it.
        first_window = config.window_size_per_stage[0]
        self.pos_embed = nn.Parameter(
            torch.zeros(1, config.hidden_size, *config.window_positional_embedding_background_size)
        )
        self.pos_embed_window = nn.Parameter(torch.zeros(1, config.hidden_size, first_window, first_window))
        # Global index of the last block of every stage.
        self.stage_ends = (np.cumsum(config.blocks_per_stage) - 1).tolist()
        self.blocks = nn.ModuleList()
        global_block_idx = 0
        for stage_idx, num_stage_blocks in enumerate(config.blocks_per_stage):
            for block_idx in range(num_stage_blocks):
                self.blocks.append(
                    Sam2MultiScaleBlock(
                        config=config,
                        stage_idx=stage_idx,
                        block_idx=block_idx,
                        total_block_idx=global_block_idx,
                    )
                )
                global_block_idx += 1
        self.post_init()

    def get_input_embeddings(self):
        return self.patch_embed

    def _get_pos_embed(self, hw: tuple[int, int]) -> torch.Tensor:
        """Return the channels-last `(1, h, w, hidden_size)` positional embedding for an `(h, w)` feature map."""
        target_height, target_width = hw
        background = F.interpolate(self.pos_embed, size=(target_height, target_width), mode="bicubic")
        repeats = [full // tile for full, tile in zip(background.shape, self.pos_embed_window.shape)]
        combined = background + self.pos_embed_window.tile(repeats)
        return combined.permute(0, 2, 3, 1)

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Sam2HieraDetModelOutput:
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")
        patch_tokens = self.patch_embed(pixel_values)
        hidden_states = patch_tokens + self._get_pos_embed(patch_tokens.shape[1:3])
        stage_outputs = []
        for block_position, block in enumerate(self.blocks):
            hidden_states = block(hidden_states, **kwargs)
            if block_position in self.stage_ends:
                stage_outputs.append(hidden_states)
        return Sam2HieraDetModelOutput(
            last_hidden_state=hidden_states,
            intermediate_hidden_states=tuple(stage_outputs),
        )
@auto_docstring(
    custom_intro="""
    The vision model from Sam without any head or projection on top.
    """
)
class Sam2VisionModel(Sam2PreTrainedModel):
    config_class = Sam2VisionConfig
    main_input_name = "pixel_values"
    _can_record_outputs = {
        "hidden_states": Sam2MultiScaleBlock,
        "attentions": Sam2MultiScaleAttention,
    }

    def __init__(self, config: Sam2VisionConfig):
        super().__init__(config)
        self.config = config
        # Backbone built from its own sub-config, followed by the FPN neck.
        self.backbone = AutoModel.from_config(config.backbone_config)
        self.neck = Sam2VisionNeck(config)
        self.num_feature_levels = config.num_feature_levels
        self.post_init()

    def get_input_embeddings(self):
        return self.backbone.get_input_embeddings()

    @can_return_tuple
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Sam2VisionEncoderOutput:
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")
        backbone_output = self.backbone(pixel_values, **kwargs)
        neck_states, neck_encodings = self.neck(backbone_output.intermediate_hidden_states)
        # Keep the last `num_feature_levels` levels and reverse them so features run from high to low resolution.
        levels = self.num_feature_levels
        selected_states = neck_states[-levels:][::-1]
        selected_encodings = neck_encodings[-levels:][::-1]
        return Sam2VisionEncoderOutput(
            last_hidden_state=backbone_output.last_hidden_state,
            fpn_hidden_states=selected_states,
            fpn_position_encoding=selected_encodings,
            hidden_states=backbone_output.hidden_states,
            attentions=backbone_output.attentions,
        )
class Sam2PositionalEmbedding(nn.Module):
    """
    Positional encoding of 2D point coordinates through a fixed Gaussian projection followed by sine/cosine.

    The projection matrix `(2, hidden_size // 2)` is a registered buffer (not a trainable parameter) drawn as
    `scale * randn`. The output concatenates the sine and cosine of the projected coordinates.
    """

    def __init__(self, config: Sam2PromptEncoderConfig):
        super().__init__()
        self.scale = config.scale
        positional_embedding = self.scale * torch.randn((2, config.hidden_size // 2))
        self.register_buffer("positional_embedding", positional_embedding)

    def forward(self, input_coords, input_shape=None):
        """
        Positionally encode points that are normalized to [0,1].

        Args:
            input_coords (`torch.Tensor`):
                Coordinates in the last dimension (size 2). When `input_shape` is given, the tensor is indexed as
                `[:, :, :, i]`, so it must be 4-D and hold unnormalized coordinates.
            input_shape (`tuple`, *optional*):
                Normalization sizes: channel 0 is divided by `input_shape[1]` and channel 1 by `input_shape[0]`
                (presumably `(height, width)` with coordinates stored as `(x, y)`).

        Returns:
            `torch.Tensor` with the same leading dimensions and `2 * (hidden_size // 2)` channels.
        """
        coordinates = input_coords.clone()
        # Integer coordinates would be truncated by the in-place division below, so promote them to float32.
        # Floating-point inputs keep their dtype.
        if not coordinates.is_floating_point():
            coordinates = coordinates.to(torch.float32)
        if input_shape is not None:
            coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
            coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape; map them to [-1, 1]
        coordinates = 2 * coordinates - 1
        coordinates = coordinates.to(self.positional_embedding.dtype)
        coordinates = coordinates @ self.positional_embedding
        coordinates = 2 * np.pi * coordinates
        # outputs d_1 x ... x d_n x channel shape
        return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
class Sam2MaskEmbedding(nn.Module):
    """
    Embeds a single-channel mask prompt into a dense embedding with `config.hidden_size` channels.

    Two kernel-2 / stride-2 convolutions (each followed by a channels-first layer norm and the configured activation)
    reduce the spatial size by 4x. A final 1x1 convolution projects to `config.hidden_size` channels.
    """

    def __init__(self, config: Sam2PromptEncoderConfig):
        super().__init__()
        # Width of the first (intermediate) convolution.
        self.mask_input_channels = config.mask_input_channels // 4
        self.activation = ACT2FN[config.hidden_act]
        self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1)
        self.layer_norm1 = Sam2LayerNorm(
            self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first"
        )
        # Must match conv2's output width. `(mask_input_channels // 4) * 4` only equals it when the configured value
        # is divisible by 4, so use the configured value directly.
        self.layer_norm2 = Sam2LayerNorm(
            config.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first"
        )

    def forward(self, masks):
        """
        Args:
            masks (`torch.Tensor`): mask prompt with a single channel, as required by `conv1` (presumably
                `(batch_size, 1, height, width)` — confirm against callers).

        Returns:
            `torch.Tensor`: dense embeddings with `config.hidden_size` channels at 1/4 of the input resolution.
        """
        hidden_states = self.conv1(masks)
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = self.activation(hidden_states)

        hidden_states = self.conv2(hidden_states)
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.activation(hidden_states)
        dense_embeddings = self.conv3(hidden_states)
        return dense_embeddings
class Sam2PromptEncoder(nn.Module):
    """
    Encodes point, box and mask prompts.

    Points and boxes become sparse, per-token embeddings. Masks (or the learned `no_mask_embed` when no mask is
    given) become a dense embedding map of size `image_embedding_size`.
    """

    def __init__(self, config: Sam2PromptEncoderConfig):
        super().__init__()
        self.shared_embedding = Sam2PositionalEmbedding(config)
        self.mask_embed = Sam2MaskEmbedding(config)
        self.no_mask_embed = nn.Embedding(1, config.hidden_size)

        self.image_embedding_size = (config.image_size // config.patch_size, config.image_size // config.patch_size)
        self.mask_input_size = (4 * config.image_size // config.patch_size, 4 * config.image_size // config.patch_size)
        self.input_image_size = config.image_size

        self.point_embed = nn.Embedding(config.num_point_embeddings, config.hidden_size)
        self.hidden_size = config.hidden_size
        self.not_a_point_embed = nn.Embedding(1, config.hidden_size)

    def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
        """
        Embed point prompts.

        Per-label behavior:
            - `-1`: replaced by the learned `not_a_point_embed` vector.
            - `-10`: zeroed out (padding points).
            - `>= 0`: positional encoding plus `point_embed[label]`.
            - any other negative label: positional encoding only.

        When `pad` is true, one extra point at `(0, 0)` with label `-1` is appended along the points axis.
        """
        points = points + 0.5  # Shift to center of pixel
        if pad:
            points = torch.nn.functional.pad(points, (0, 0, 0, 1), mode="constant", value=0)
            labels = torch.nn.functional.pad(labels, (0, 1), mode="constant", value=-1)
        input_shape = (self.input_image_size, self.input_image_size)
        point_embedding = self.shared_embedding(points, input_shape)

        # torch.where and expanding the labels tensor is required by the ONNX export
        point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding)

        # Zero out padding points (label -10). torch.where keeps this export-friendly, and zeros_like inherits the
        # dtype/device of point_embedding.
        point_embedding = torch.where(
            labels[..., None] != -10,
            point_embedding,
            torch.zeros_like(point_embedding),
        )

        # Add point embeddings for labels >= 0
        point_embedding = point_embedding + self.point_embed(labels.clamp(min=0)) * (labels >= 0).unsqueeze(-1)

        return point_embedding

    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """
        Embed box prompts.

        Each box's 4 values are reshaped into two corner points; the first corner gets `point_embed[2]`, the second
        `point_embed[3]`. A third padding point is appended and set to `not_a_point_embed`.
        """
        boxes = boxes + 0.5  # Shift to center of pixel
        coords = boxes.view(*boxes.shape[:2], 2, 2)
        # add padding point for consistency with the original implementation
        coords = torch.nn.functional.pad(coords, (0, 0, 0, 1), mode="constant", value=0)
        corner_embedding = self.shared_embedding(coords, (self.input_image_size, self.input_image_size))
        corner_embedding[:, :, 0, :] += self.point_embed.weight[2]
        corner_embedding[:, :, 1, :] += self.point_embed.weight[3]
        corner_embedding[:, :, 2, :] = self.not_a_point_embed.weight.expand_as(corner_embedding[:, :, 2, :])
        return corner_embedding

    def forward(
        self,
        input_points: torch.Tensor | None,
        input_labels: torch.Tensor | None,
        input_boxes: torch.Tensor | None,
        input_masks: torch.Tensor | None,
    ) -> tuple[torch.Tensor | None, torch.Tensor]:
        """
        Embeds different types of prompts, returning both sparse and dense embeddings.

        Args:
            input_points (`torch.Tensor`, *optional*):
                Point coordinates to embed, presumably of shape `(batch_size, point_batch_size, num_points, 2)`.
                When no boxes are given, a padding point is appended (see `_embed_points`).
            input_labels (`torch.Tensor`, *optional*):
                One label per point; required whenever `input_points` is given.
            input_boxes (`torch.Tensor`, *optional*):
                Boxes to embed, with 4 coordinates on the last axis (reshaped into two corners).
            input_masks (`torch.Tensor`, *optional*):
                Mask prompts to embed with `mask_embed`. When omitted, the learned `no_mask_embed` vector is
                broadcast to `(batch_size, hidden_size, *image_embedding_size)`.

        Returns:
            `tuple(sparse_embeddings, dense_embeddings)`: `sparse_embeddings` concatenates the point and box
            embeddings along dim 2, or is `None` when neither points nor boxes are given.

        Raises:
            ValueError: if `input_points` is given without `input_labels`.
        """
        sparse_embeddings = None
        batch_size = 1
        if input_points is not None:
            batch_size = input_points.shape[0]
            if input_labels is None:
                raise ValueError("If points are provided, labels must also be provided.")
            point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
            sparse_embeddings = point_embeddings
        if input_boxes is not None:
            batch_size = input_boxes.shape[0]
            box_embeddings = self._embed_boxes(input_boxes)
            if sparse_embeddings is None:
                sparse_embeddings = box_embeddings
            else:
                sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=2)
        if input_masks is not None:
            dense_embeddings = self.mask_embed(input_masks)
        else:
            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
                batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]
            )

        return sparse_embeddings, dense_embeddings
class Sam2Attention(nn.Module):
    """
    SAM2's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
    values.
    """

    def __init__(self, config, downsample_rate=None):
        super().__init__()
        if downsample_rate is None:
            downsample_rate = config.attention_downsample_rate
        self.config = config
        self.hidden_size = config.hidden_size
        # Inner width shared by the q/k/v projections (hidden size reduced by the downsample rate).
        self.internal_dim = config.hidden_size // downsample_rate
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = self.internal_dim // self.num_attention_heads
        self.scaling = self.head_dim**-0.5
        self.is_causal = False

        # Creation order (q, k, v, o) is kept stable for parameter registration.
        self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
        self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
        self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
        self.o_proj = nn.Linear(self.internal_dim, self.hidden_size)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_similarity: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Multi-head attention in which the first two dimensions of `query` are folded into one attention batch.

        `attention_similarity`, when given, is passed to the attention backend as `attention_mask`. Returns the
        output projection (first two dims restored) and whatever attention weights the backend returns.
        """
        outer, inner = query.shape[:2]
        heads, head_dim = self.num_attention_heads, self.head_dim

        def _project_and_split(projection, states):
            # Project, fold the two leading dims together, split channels into heads, then move heads before tokens.
            return projection(states).view(outer * inner, -1, heads, head_dim).transpose(1, 2)

        query = _project_and_split(self.q_proj, query)
        key = _project_and_split(self.k_proj, key)
        value = _project_and_split(self.v_proj, value)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        needs_sdpa_fallback = is_flash_attention_requested(self.config) and attention_similarity is not None
        if needs_sdpa_fallback:
            # Target guided masks are represented as float masks and are incompatible with Flash Attention
            # Fallback to SDPA for this call only so the rest of the model can still benefit from FA
            attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"]
            logger.warning_once(
                "Falling back to SDPA for target-guided attention because "
                "Flash Attention does not support additive bias masks."
            )

        attn_output, attn_weights = attention_interface(
            self,
            query,
            key,
            value,
            attention_mask=attention_similarity,
            dropout=0.0,
            scaling=self.scaling,
            is_causal=self.is_causal,
            **kwargs,
        )

        merged = attn_output.reshape(outer, inner, -1, heads * head_dim).contiguous()
        return self.o_proj(merged), attn_weights
class Sam2TwoWayAttentionBlock(GradientCheckpointingLayer):
    def __init__(self, config: Sam2MaskDecoderConfig, skip_first_layer_pe: bool = False):
        """
        A transformer block with four layers:
            (1) self-attention of sparse inputs
            (2) cross attention of sparse inputs -> dense inputs
            (3) mlp block on sparse inputs
            (4) cross attention of dense inputs -> sparse inputs

        Arguments:
            config (`Sam2MaskDecoderConfig`):
                The configuration used to instantiate the block. The two cross-attention layers reduce their inner
                dimension by `config.attention_downsample_rate`; the self-attention layer does not downsample.
            skip_first_layer_pe (`bool`, *optional*, defaults to `False`):
                When `True`, the self-attention layer runs on the raw queries (no `query_point_embedding` added)
                and its output replaces the queries instead of being added as a residual.
        """
        super().__init__()
        self.self_attn = Sam2Attention(config, downsample_rate=1)
        self.layer_norm1 = nn.LayerNorm(config.hidden_size)

        self.cross_attn_token_to_image = Sam2Attention(config)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size)

        # NOTE(review): the MLP depth is tied to `config.num_hidden_layers` (also the number of two-way blocks);
        # confirm this coupling is intended before changing either value.
        self.mlp = Sam2FeedForward(
            config.hidden_size, config.mlp_dim, config.hidden_size, num_layers=config.num_hidden_layers
        )
        self.layer_norm3 = nn.LayerNorm(config.hidden_size)

        self.layer_norm4 = nn.LayerNorm(config.hidden_size)
        self.cross_attn_image_to_token = Sam2Attention(config)

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self,
        queries: Tensor,
        keys: Tensor,
        query_point_embedding: Tensor,
        key_point_embedding: Tensor,
        attention_similarity: Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """
        Run the four sub-layers and return `(queries, keys, attn_out)`.

        `attn_out` is the output of the final image-to-token cross attention, before it is added to `keys`.
        `attention_similarity` is only passed to the token-to-image cross attention. `kwargs` are accepted but not
        forwarded to the attention layers.
        """
        # Self attention block
        if self.skip_first_layer_pe:
            queries, _ = self.self_attn(query=queries, key=queries, value=queries)
        else:
            query = queries + query_point_embedding
            attn_out, _ = self.self_attn(query=query, key=query, value=queries)
            queries = queries + attn_out
        queries = self.layer_norm1(queries)

        # Cross attention block, tokens attending to image embedding
        query = queries + query_point_embedding
        key = keys + key_point_embedding

        attn_out, _ = self.cross_attn_token_to_image(
            query=query, key=key, value=keys, attention_similarity=attention_similarity
        )
        queries = queries + attn_out

        queries = self.layer_norm2(queries)

        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.layer_norm3(queries)

        # Cross attention block, image embedding attending to tokens
        query = queries + query_point_embedding
        key = keys + key_point_embedding

        attn_out, _ = self.cross_attn_image_to_token(query=key, key=query, value=queries)
        keys = keys + attn_out

        keys = self.layer_norm4(keys)
        return queries, keys, attn_out
class Sam2TwoWayTransformer(nn.Module):
    """
    Stack of `Sam2TwoWayAttentionBlock`s followed by a final token-to-image cross attention and layer norm.
    """

    def __init__(self, config: Sam2MaskDecoderConfig):
        super().__init__()
        self.config = config

        self.num_hidden_layers = config.num_hidden_layers
        self.layers = nn.ModuleList()

        for i in range(self.num_hidden_layers):
            # Only the first block runs self-attention without the point positional embedding.
            self.layers.append(Sam2TwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))

        self.final_attn_token_to_image = Sam2Attention(config)
        self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)

    def forward(
        self,
        point_embeddings: Tensor,
        image_embeddings: Tensor,
        image_positional_embeddings: Tensor,
        attention_similarity: Tensor | None,
        target_embedding: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[Tensor, Tensor]:
        """
        Run the two-way transformer over prompt tokens and image embeddings.

        Args:
            point_embeddings (`Tensor`): prompt/output tokens used as the initial queries and as the query
                positional embedding of every block.
            image_embeddings (`Tensor`): image features; flattened over spatial dims and permuted below (assumes
                `(N, C, H, W)` -> `(N, 1, H*W, C)` — confirm against the caller).
            image_positional_embeddings (`Tensor`): positional encoding, reshaped the same way as
                `image_embeddings`.
            attention_similarity (`Tensor`, *optional*): forwarded to each block's token-to-image cross attention.
            target_embedding (`Tensor`, *optional*): added to the queries before every block.

        Returns:
            `tuple(queries, keys)`: the updated tokens and the updated (flattened) image embeddings.

        Raises:
            ValueError: if `image_embeddings` is `None`.
        """
        if image_embeddings is None:
            raise ValueError("You have to specify an image_embedding")

        image_embeddings = image_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
        image_positional_embeddings = image_positional_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)

        # Prepare queries
        queries = point_embeddings
        keys = image_embeddings

        # Apply transformer blocks and final layernorm
        for layer in self.layers:
            if target_embedding is not None:
                # NOTE(review): on the first iteration `queries` is the same tensor object as `point_embeddings`, so
                # this in-place add also modifies `point_embeddings` (including the caller's tensor). The modified
                # tensor is then used as `query_point_embedding` and in the final attention below. Kept as-is;
                # confirm the aliasing is intended before switching to an out-of-place add.
                queries += target_embedding

            queries, keys, _ = layer(
                queries=queries,
                keys=keys,
                query_point_embedding=point_embeddings,
                key_point_embedding=image_positional_embeddings,
                attention_similarity=attention_similarity,
                **kwargs,
            )
        # Apply the final attention layer from the points to the image
        query = queries + point_embeddings
        key = keys + image_positional_embeddings

        attn_out, _ = self.final_attn_token_to_image(query=query, key=key, value=keys)

        queries = queries + attn_out
        queries = self.layer_norm_final_attn(queries)
        return queries, keys
class Sam2LayerNorm(nn.LayerNorm):
    r"""LayerNorm that accepts either `channels_last` (default) or `channels_first` inputs.

    `channels_last` normalizes the trailing dimension directly, e.g. inputs of shape
    `(batch_size, height, width, channels)`. `channels_first` expects 4D inputs of shape
    `(batch_size, channels, height, width)`; they are moved to channels-last, normalized, and moved back.
    """

    def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs):
        super().__init__(normalized_shape, eps=eps, **kwargs)
        supported = ("channels_last", "channels_first")
        if data_format in supported:
            self.data_format = data_format
        else:
            raise NotImplementedError(f"Unsupported data format: {data_format}")

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels)
        """
        if self.data_format != "channels_first":
            return super().forward(features)
        # NCHW -> NHWC so nn.LayerNorm normalizes over channels, then back to NCHW.
        normalized = super().forward(features.permute(0, 2, 3, 1))
        return normalized.permute(0, 3, 1, 2)
class Sam2MaskDecoder(nn.Module):
    """
    SAM2 mask decoder.

    Runs a two-way transformer over the output tokens (object-score, IoU and mask tokens) concatenated with the
    sparse prompt embeddings. It then predicts mask logits (via per-token hypernetwork MLPs applied to an upscaled
    image embedding), IoU scores and an object-score logit.
    """

    def __init__(self, config: Sam2MaskDecoderConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size

        self.num_multimask_outputs = config.num_multimask_outputs
        # Token 0 is the single-mask output; tokens 1.. are the multimask outputs.
        self.num_mask_tokens = config.num_multimask_outputs + 1

        self.iou_token = nn.Embedding(1, self.hidden_size)
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size)

        self.transformer = Sam2TwoWayTransformer(config)

        # Two stride-2 transposed convolutions upscale the image embedding 4x (hidden -> hidden/4 -> hidden/8).
        self.upscale_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2)
        self.upscale_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2)
        self.upscale_layer_norm = Sam2LayerNorm(self.hidden_size // 4, data_format="channels_first")
        self.activation = nn.GELU()

        # One hypernetwork MLP per mask token, mapping that token to a weight vector over the upscaled channels.
        self.output_hypernetworks_mlps = nn.ModuleList(
            [
                Sam2FeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)
                for _ in range(self.num_mask_tokens)
            ]
        )
        self.iou_prediction_head = Sam2FeedForward(
            self.hidden_size,
            config.iou_head_hidden_dim,
            self.num_mask_tokens,
            config.iou_head_depth,
            sigmoid_output=True,
        )

        # Not used in `forward` below; presumably applied by the caller to the high-resolution features (confirm).
        self.conv_s0 = nn.Conv2d(config.hidden_size, config.hidden_size // 8, kernel_size=1, stride=1)
        self.conv_s1 = nn.Conv2d(config.hidden_size, config.hidden_size // 4, kernel_size=1, stride=1)

        self.obj_score_token = nn.Embedding(1, self.hidden_size)
        self.pred_obj_score_head = Sam2FeedForward(self.hidden_size, self.hidden_size, 1, 3)

        self.dynamic_multimask_via_stability = config.dynamic_multimask_via_stability
        self.dynamic_multimask_stability_delta = config.dynamic_multimask_stability_delta
        self.dynamic_multimask_stability_thresh = config.dynamic_multimask_stability_thresh

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_positional_embeddings: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
        high_resolution_features: list[torch.Tensor],
        attention_similarity: torch.Tensor | None = None,
        target_embedding: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.

        Args:
            image_embeddings (`torch.Tensor`):
                The embeddings from the image encoder, of shape `(batch_size, channels, height, width)`.
            image_positional_embeddings (`torch.Tensor`):
                Positional encoding with the shape of image_embeddings.
            sparse_prompt_embeddings (`torch.Tensor`):
                The embeddings of the points and boxes; dim 1 is the point batch size.
            dense_prompt_embeddings (`torch.Tensor`):
                The embeddings of the mask inputs, added to `image_embeddings`.
            multimask_output (`bool`):
                Whether to return multiple masks or a single mask.
            high_resolution_features (`list[torch.Tensor]`):
                Exactly two feature maps `(feat_s0, feat_s1)`. `feat_s1` is added after `upscale_conv1` and
                `feat_s0` after `upscale_conv2`, so their shapes must match those outputs.
            attention_similarity (`torch.Tensor`, *optional*):
                The attention similarity tensor, forwarded to the two-way transformer.
            target_embedding (`torch.Tensor`, *optional*):
                The target embedding, forwarded to the two-way transformer.

        Returns:
            `tuple(masks, iou_pred, sam_tokens_out, object_score_logits)`. `masks` has shape
            `(batch_size, point_batch_size, num_selected_masks, mask_height, mask_width)`.
        """
        batch_size, num_channels, height, width = image_embeddings.shape
        point_batch_size = sparse_prompt_embeddings.shape[1]
        # Output tokens, in order: object-score token (0), IoU token (1), mask tokens (2 .. 2 + num_mask_tokens).
        output_tokens = torch.cat(
            [
                self.obj_score_token.weight,
                self.iou_token.weight,
                self.mask_tokens.weight,
            ],
            dim=0,
        )
        output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)

        if sparse_prompt_embeddings.shape[0] != 0:
            tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
        else:
            tokens = output_tokens
        point_embeddings = tokens.to(self.iou_token.weight.dtype)

        # Expand per-image data in batch direction to be per-mask
        image_embeddings = image_embeddings + dense_prompt_embeddings
        image_embeddings = image_embeddings.repeat_interleave(point_batch_size, dim=0)
        image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
        # Run the transformer
        point_embeddings, image_embeddings = self.transformer(
            point_embeddings=point_embeddings,
            image_embeddings=image_embeddings,
            image_positional_embeddings=image_positional_embeddings,
            attention_similarity=attention_similarity,
            target_embedding=target_embedding,
            **kwargs,
        )
        iou_token_out = point_embeddings[:, :, 1, :]
        mask_tokens_out = point_embeddings[:, :, 2 : (2 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        image_embeddings = image_embeddings.transpose(2, 3).view(
            batch_size * point_batch_size, num_channels, height, width
        )

        feat_s0, feat_s1 = high_resolution_features
        feat_s0 = feat_s0.repeat_interleave(point_batch_size, dim=0)
        feat_s1 = feat_s1.repeat_interleave(point_batch_size, dim=0)
        upscaled_embedding = self.upscale_conv1(image_embeddings) + feat_s1
        upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
        upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding) + feat_s0)

        # Each mask token goes through its own hypernetwork MLP.
        hyper_in = torch.stack(
            [mlp(mask_tokens_out[:, :, i, :]) for i, mlp in enumerate(self.output_hypernetworks_mlps)],
            dim=2,
        )
        _, num_channels, height, width = upscaled_embedding.shape
        upscaled_embedding = upscaled_embedding.view(batch_size, point_batch_size, num_channels, height * width)
        masks = (hyper_in @ upscaled_embedding).view(batch_size, point_batch_size, -1, height, width)

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)
        object_score_logits = self.pred_obj_score_head(point_embeddings[:, :, 0, :])

        # Select the correct mask or masks for output
        if multimask_output:
            mask_slice = slice(1, None)
            masks = masks[:, :, mask_slice, :, :]
            iou_pred = iou_pred[:, :, mask_slice]
        elif self.dynamic_multimask_via_stability and not self.training:
            mask_slice = slice(0, 1)
            masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
        else:
            mask_slice = slice(0, 1)
            masks = masks[:, :, mask_slice, :, :]
            iou_pred = iou_pred[:, :, mask_slice]

        # (batch_size, point_batch_size, num_selected_tokens, hidden_size). This is the multimask tokens when
        # `multimask_output`, otherwise token 0 (also in the dynamic-stability branch).
        sam_tokens_out = mask_tokens_out[:, :, mask_slice]

        return masks, iou_pred, sam_tokens_out, object_score_logits

    def _get_stability_scores(self, mask_logits):
        """
        Compute stability scores of the mask logits based on the IoU between upper and
        lower thresholds.

        The score is the number of pixels with logit > +delta divided by the number with logit > -delta, computed
        over the last two (spatial) dims. It is 1.0 when no pixel exceeds -delta.
        """
        mask_logits = mask_logits.flatten(-2)
        stability_delta = self.dynamic_multimask_stability_delta
        area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
        area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
        stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
        return stability_scores

    def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
        """
        When outputting a single mask, if the stability score from the current single-mask
        output (based on output token 0) falls below a threshold, we instead select from
        multi-mask outputs (based on output token 1~3) the mask with the highest predicted
        IoU score. This is intended to ensure a valid mask for both clicking and tracking.

        Returns `(mask_logits_out, iou_scores_out)` with shapes `(B, P, 1, H, W)` and `(B, P, 1)`.
        """
        # The best mask from multimask output tokens (1~3)
        multimask_logits = all_mask_logits[:, :, 1:, :, :]
        multimask_iou_scores = all_iou_scores[:, :, 1:]
        best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)  # [B, P]
        best_scores_inds_expanded = best_scores_inds.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        best_scores_inds_expanded = best_scores_inds_expanded.expand(
            -1, -1, 1, multimask_logits.size(-2), multimask_logits.size(-1)
        )
        best_multimask_logits = torch.gather(multimask_logits, 2, best_scores_inds_expanded)  # [B, P, 1, H, W]
        best_multimask_iou_scores = torch.gather(multimask_iou_scores, 2, best_scores_inds.unsqueeze(-1))  # [B, P, 1]

        # The mask from singlemask output token 0 and its stability score
        singlemask_logits = all_mask_logits[:, :, 0:1, :, :]
        singlemask_iou_scores = all_iou_scores[:, :, 0:1]
        stability_scores = self._get_stability_scores(singlemask_logits)
        is_stable = stability_scores >= self.dynamic_multimask_stability_thresh

        # Dynamically fall back to best multimask output upon low stability scores.
        mask_logits_out = torch.where(
            is_stable[..., None, None].expand_as(singlemask_logits),
            singlemask_logits,
            best_multimask_logits,
        )
        iou_scores_out = torch.where(
            is_stable.expand_as(singlemask_iou_scores),
            singlemask_iou_scores,
            best_multimask_iou_scores,
        )
        return mask_logits_out, iou_scores_out
- @auto_docstring(
- custom_intro="""
- Segment Anything Model 2 (SAM 2) for generating segmentation masks, given an input image and
- input points and labels, boxes, or masks.
- """
- )
- class Sam2Model(Sam2PreTrainedModel):
- input_modalities = ("image", "text")
- _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(Sam2TwoWayAttentionBlock, index=2)}
- _tied_weights_keys = {}
def __init__(self, config: Sam2Config):
    super().__init__(config)
    vision_config = config.vision_config
    prompt_config = config.prompt_encoder_config

    # Submodules are created in a fixed order so parameter registration stays stable.
    self.shared_image_embedding = Sam2PositionalEmbedding(prompt_config)
    self.vision_encoder = AutoModel.from_config(vision_config)
    self.prompt_encoder = Sam2PromptEncoder(prompt_config)
    # The module using it is not a PreTrainedModel subclass so we need this
    config.mask_decoder_config._attn_implementation = config._attn_implementation
    self.mask_decoder = Sam2MaskDecoder(config.mask_decoder_config)

    self.num_feature_levels = vision_config.num_feature_levels
    self.backbone_feature_sizes = vision_config.backbone_feature_sizes
    # a single token to indicate no memory embedding from previous frames
    self.hidden_dim = vision_config.fpn_hidden_size
    self.no_memory_embedding = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))

    self.post_init()
def get_input_embeddings(self):
    """Delegate to the vision encoder's `get_input_embeddings`."""
    encoder = self.vision_encoder
    return encoder.get_input_embeddings()
def get_image_wide_positional_embeddings(self) -> torch.Tensor:
    """
    Positional encoding evaluated at the center of every cell of the image-embedding grid.

    Returns:
        `torch.Tensor` of shape `(1, channels, height, width)`, where `(height, width)` is
        `self.prompt_encoder.image_embedding_size`.
    """
    height, width = self.prompt_encoder.image_embedding_size
    reference = self.shared_image_embedding.positional_embedding
    ones = torch.ones((height, width), device=reference.device, dtype=reference.dtype)

    # cumsum of ones counts 1..N along each axis; subtracting 0.5 moves to cell centers, then normalize to [0, 1].
    y_centers = (ones.cumsum(dim=0) - 0.5) / height
    x_centers = (ones.cumsum(dim=1) - 0.5) / width

    encoding = self.shared_image_embedding(torch.stack([x_centers, y_centers], dim=-1))  # height x width x channel
    return encoding.permute(2, 0, 1).unsqueeze(0)
@torch.no_grad()
def get_image_embeddings(
    self,
    pixel_values: torch.FloatTensor,
    **kwargs: Unpack[TransformersKwargs],
) -> list[torch.Tensor]:
    r"""
    Returns the image embeddings by passing the pixel values through the vision encoder.

    The last FPN feature map gets `no_memory_embedding` added, and every map is reshaped to
    `(batch_size, -1, *backbone_feature_size)` for its level.

    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Input pixel values
    """
    batch_size = pixel_values.shape[0]
    feature_maps = self.get_image_features(pixel_values, return_dict=True, **kwargs).fpn_hidden_states

    # add no memory embedding to the last feature map
    feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding

    # permute(1, 2, 0) suggests the maps arrive as (H*W, batch, channels) — confirm against get_image_features
    return [
        level_map.permute(1, 2, 0).view(batch_size, -1, *level_size)
        for level_map, level_size in zip(feature_maps, self.backbone_feature_sizes)
    ]
@torch.no_grad()
def get_prompt_embeddings(
    self,
    input_points: torch.FloatTensor | None = None,
    input_labels: torch.LongTensor | None = None,
    input_boxes: torch.FloatTensor | None = None,
    input_masks: torch.LongTensor | None = None,
):
    r"""
    Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.

    The result is whatever `self.prompt_encoder` returns for these keyword arguments (its forward returns a
    `(sparse_embeddings, dense_embeddings)` tuple).

    Args:
        input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
            Optional input points for the prompt encoder. The padding of the point is automatically done by the
            processor. `point_batch_size` refers to the number of masks that we want the model to predict per
            point. The model will output `point_batch_size` times 3 masks in total.
        input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
            Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
            processor, or can be fed by the user.
        input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`):
            Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
            processor. Users can also pass the input boxes manually.
        input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`):
            Optional input masks for the prompt encoder.
    """
    return self.prompt_encoder(
        input_points=input_points,
        input_labels=input_labels,
        input_boxes=input_boxes,
        input_masks=input_masks,
    )
- @merge_with_config_defaults
- @capture_outputs
- @auto_docstring
- def forward(
- self,
- pixel_values: torch.FloatTensor | None = None,
- input_points: torch.FloatTensor | None = None,
- input_labels: torch.LongTensor | None = None,
- input_boxes: torch.FloatTensor | None = None,
- input_masks: torch.LongTensor | None = None,
- image_embeddings: torch.FloatTensor | None = None,
- multimask_output: bool = True,
- attention_similarity: torch.FloatTensor | None = None,
- target_embedding: torch.FloatTensor | None = None,
- **kwargs: Unpack[TransformersKwargs],
- ) -> Sam2ImageSegmentationOutput:
- r"""
- input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`):
- Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields to much
- better results. The points can be obtained by passing a list of list of list to the processor that will
- create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the
- second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict
- per input point), the third dimension is the number of points per segmentation mask (it is possible to pass
- multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal)
- coordinates of the point. If a different number of points is passed either for each image, or for each
- mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
- computation of the embedding will be skipped for these points using the labels.
- input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`):
- Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the
- official implementation, there are 3 types of labels
- - `1`: the point is a point that contains the object of interest
- - `0`: the point is a point that does not contain the object of interest
- - `-1`: the point corresponds to the background
- We added the label:
- - `-10`: the point is a padding point, thus should be ignored by the prompt encoder
- The padding labels should be automatically done by the processor.
- input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
- Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields to
- much better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
- that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch
- size, the number of boxes per image and the coordinates of the top left and bottom right point of the box.
- In the order (`x1`, `y1`, `x2`, `y2`):
- - `x1`: the x coordinate of the top left point of the input box
- - `y1`: the y coordinate of the top left point of the input box
- - `x2`: the x coordinate of the bottom right point of the input box
- - `y2`: the y coordinate of the bottom right point of the input box
- input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`):
- SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
- generate a corresponding embedding, that will be fed later on to the mask decoder. These masks needs to be
- manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
- image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`):
- Image embeddings, this is used by the mask decoder to generate masks and iou scores. For more memory
- efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
- method, and then feed them to the `forward` method instead of feeding the `pixel_values`.
- multimask_output (`bool`, *optional*):
- In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
- bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
- "best" mask, by specifying `multimask_output=False`.
- attention_similarity (`torch.FloatTensor`, *optional*):
- Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
- model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
- target_embedding (`torch.FloatTensor`, *optional*):
- Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
- the model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
- Example:
- ```python
- >>> from PIL import Image
- >>> import httpx
- >>> from io import BytesIO
- >>> from transformers import AutoModel, AutoProcessor
- >>> model = AutoModel.from_pretrained("danelcsb/sam2.1_hiera_tiny")
- >>> processor = AutoProcessor.from_pretrained("danelcsb/sam2.1_hiera_tiny")
- >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
- >>> with httpx.stream("GET", url) as response:
- ... raw_image = Image.open(BytesIO(response.read())).convert("RGB")
- >>> input_points = [[[400, 650]]] # 2D location of a window on the car
- >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")
- >>> # Get segmentation mask
- >>> outputs = model(**inputs)
- >>> # Postprocess masks
- >>> masks = processor.post_process_masks(
- ... outputs.pred_masks, inputs["original_sizes"]
- ... )
- ```
- """
- if not ((pixel_values is None) ^ (image_embeddings is None)):
- raise ValueError("Exactly one of pixel_values or image_embeddings must be provided.")
- if input_points is not None and input_boxes is not None:
- if input_points.shape[1] != input_boxes.shape[1]:
- raise ValueError(
- f"You should provide as many bounding boxes as input points per box. Got {input_points.shape[1]} and {input_boxes.shape[1]}."
- )
- image_positional_embeddings = self.get_image_wide_positional_embeddings()
- # repeat with batch size
- batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings[-1].shape[0]
- image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
- vision_attentions = None
- vision_hidden_states = None
- if pixel_values is not None:
- image_outputs: Sam2VisionEncoderOutput = self.get_image_features(pixel_values, return_dict=True, **kwargs)
- feature_maps = image_outputs.fpn_hidden_states
- vision_hidden_states = image_outputs.hidden_states
- vision_attentions = image_outputs.attentions
- # add no memory embedding to the last feature map
- feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding
- # reshape feature maps to the same shape as the backbone feature sizes
- image_embeddings = [
- feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
- for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes)
- ]
- if input_points is not None and input_labels is None:
- input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)
- if input_points is None and input_boxes is None:
- # If no points are provide, pad with an empty point (with label -1)
- input_points = torch.zeros(
- batch_size, 1, 1, 2, dtype=image_embeddings[-1].dtype, device=image_embeddings[-1].device
- )
- input_labels = -torch.ones(batch_size, 1, 1, dtype=torch.int32, device=image_embeddings[-1].device)
- if input_masks is not None:
- # If mask_inputs is provided, downsize it into low-res mask input if needed
- # and feed it as a dense mask prompt into the SAM mask encoder
- if input_masks.shape[-2:] != self.prompt_encoder.mask_input_size:
- input_masks = F.interpolate(
- input_masks.float(),
- size=self.prompt_encoder.mask_input_size,
- align_corners=False,
- mode="bilinear",
- antialias=True, # use antialias for downsampling
- ).to(input_masks.dtype)
- sparse_embeddings, dense_embeddings = self.prompt_encoder(
- input_points=input_points,
- input_labels=input_labels,
- input_boxes=input_boxes,
- input_masks=input_masks,
- )
- low_res_multimasks, iou_scores, _, object_score_logits = self.mask_decoder(
- image_embeddings=image_embeddings[-1],
- image_positional_embeddings=image_positional_embeddings,
- sparse_prompt_embeddings=sparse_embeddings,
- dense_prompt_embeddings=dense_embeddings,
- multimask_output=multimask_output,
- high_resolution_features=image_embeddings[:-1],
- attention_similarity=attention_similarity,
- target_embedding=target_embedding,
- **kwargs,
- )
- return Sam2ImageSegmentationOutput(
- iou_scores=iou_scores,
- pred_masks=low_res_multimasks,
- object_score_logits=object_score_logits,
- image_embeddings=image_embeddings,
- vision_hidden_states=vision_hidden_states,
- vision_attentions=vision_attentions,
- )
@can_return_tuple
@auto_docstring
def get_image_features(
    self,
    pixel_values: torch.FloatTensor,
    **kwargs: Unpack[TransformersKwargs],
) -> tuple | Sam2VisionEncoderOutput:
    r"""
    pixel_values (`torch.FloatTensor`):
        Input pixel values of shape `(batch_size, num_channels, height, width)`.
    """
    # Ask the encoder for a structured output so its FPN fields can be read by name.
    encoder_output: Sam2VisionEncoderOutput = self.vision_encoder(pixel_values, return_dict=True, **kwargs)
    fpn_levels = encoder_output.fpn_hidden_states
    fpn_positions = encoder_output.fpn_position_encoding

    # Run the mask decoder's level-0 / level-1 projections once, up front, so that
    # repeated prompt-only decoding (e.g. every SAM click) does not recompute them.
    fpn_levels = list(fpn_levels)
    fpn_levels[0] = self.mask_decoder.conv_s0(fpn_levels[0])
    fpn_levels[1] = self.mask_decoder.conv_s1(fpn_levels[1])

    def _to_sequence_first(tensor):
        # Flatten the spatial dims and move them to the front: NxCxHxW -> HWxNxC.
        return tensor.flatten(2).permute(2, 0, 1)

    flattened_levels = [_to_sequence_first(level) for level in fpn_levels]
    flattened_positions = [_to_sequence_first(position) for position in fpn_positions]

    # Overwrite the encoder output in place with the projected, flattened tensors.
    encoder_output.fpn_hidden_states = flattened_levels
    encoder_output.fpn_position_encoding = flattened_positions
    return encoder_output
# Public API of this module: the names exported by `from <module> import *`.
__all__ = [
    "Sam2Model",
    "Sam2VisionModel",
    "Sam2PreTrainedModel",
    "Sam2HieraDetModel",
]
|