| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429 |
- # Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import math
- from collections.abc import Callable, Iterable
- from dataclasses import dataclass
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torch import Tensor
- from ...utils import is_torchvision_available
- if is_torchvision_available():
- import torchvision
- from transformers import CLIPTextModelWithProjection
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...masking_utils import create_bidirectional_mask
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import (
- BaseModelOutput,
- BaseModelOutputWithPooling,
- ModelOutput,
- )
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
- from ...processing_utils import Unpack
- from ...pytorch_utils import compile_compatible_method_lru_cache
- from ...utils import auto_docstring, can_return_tuple, logging
- from ...utils.generic import (
- TransformersKwargs,
- is_flash_attention_requested,
- merge_with_config_defaults,
- )
- from ...utils.import_utils import requires
- from ...utils.output_capturing import capture_outputs
- from ..auto import AutoModel
- from .configuration_sam3 import (
- Sam3Config,
- Sam3DETRDecoderConfig,
- Sam3DETREncoderConfig,
- Sam3GeometryEncoderConfig,
- Sam3MaskDecoderConfig,
- Sam3VisionConfig,
- Sam3ViTConfig,
- )
- logger = logging.get_logger(__name__)
# Output container that extends `BaseModelOutputWithPooling` with the two FPN fields declared below.
# The class docstring only lists the fields added here; `@auto_docstring` is applied on top of it.
@dataclass
@auto_docstring
class Sam3VisionEncoderOutput(BaseModelOutputWithPooling):
    r"""
    fpn_hidden_states (`tuple[torch.FloatTensor]`):
        Tuple of multi-level FPN feature maps.
    fpn_position_encoding (`tuple[torch.FloatTensor]`):
        Tuple of position encodings for each FPN level.
    """

    # Both fields default to None even though their annotations are not optional.
    fpn_hidden_states: tuple[torch.FloatTensor, ...] = None
    fpn_position_encoding: tuple[torch.FloatTensor, ...] = None
# Output container pairing encoded geometry prompt features with their validity mask.
@dataclass
@auto_docstring
class Sam3GeometryEncoderOutput(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_prompts, hidden_size)`):
        Encoded geometry prompt features (boxes).
    attention_mask (`torch.BoolTensor` of shape `(batch_size, num_prompts)`, *optional*):
        Attention mask for geometry prompts where True indicates valid positions and False indicates padding.
    """

    # Defaults to None although the annotation is not optional.
    last_hidden_state: torch.FloatTensor = None
    attention_mask: torch.BoolTensor | None = None
# Output container for the DETR encoder: fused vision features plus the auxiliary tensors
# (position embeddings, text features, per-level spatial shapes) listed in the docstring.
@dataclass
@auto_docstring
class Sam3DETREncoderOutput(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        Encoded vision features (flattened from multi-level features).
    pos_embeds_flattened (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
        Flattened position embeddings for the vision features.
    text_features (`torch.FloatTensor` of shape `(batch_size, text_seq_len, hidden_size)`, *optional*):
        Text features (may be pooled after encoder processing).
    spatial_shapes (`torch.LongTensor` of shape `(num_levels, 2)`, *optional*):
        Spatial shapes (height, width) for each feature pyramid level.
    hidden_states (`tuple[torch.FloatTensor]`, *optional*):
        Tuple of hidden states from all encoder layers.
    attentions (`tuple[torch.FloatTensor]`, *optional*):
        Tuple of attention weights from all encoder layers.
    """

    # Defaults to None although the annotation is not optional.
    last_hidden_state: torch.FloatTensor = None
    pos_embeds_flattened: torch.FloatTensor | None = None
    text_features: torch.FloatTensor | None = None
    spatial_shapes: torch.LongTensor | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
# Output container for the DETR decoder: per-layer query states, reference boxes and presence logits,
# plus the optional hidden-state / attention tuples.
@dataclass
@auto_docstring
class Sam3DETRDecoderOutput(ModelOutput):
    r"""
    intermediate_hidden_states (`torch.FloatTensor` of shape `(num_layers, batch_size, num_queries, hidden_size)`):
        Decoder hidden states from all layers.
    reference_boxes (`torch.FloatTensor` of shape `(num_layers, batch_size, num_queries, 4)`):
        Predicted reference boxes from all decoder layers in (cx, cy, w, h) format.
    presence_logits (`torch.FloatTensor` of shape `(num_layers, batch_size, 1)`):
        Presence logits from all decoder layers indicating object presence confidence.
    hidden_states (`tuple[torch.FloatTensor]`, *optional*):
        Tuple of hidden states from all decoder layers.
    attentions (`tuple[torch.FloatTensor]`, *optional*):
        Tuple of attention weights from all decoder layers (self-attention and cross-attention).
    """

    # The first three fields default to None although their annotations are not optional.
    intermediate_hidden_states: torch.FloatTensor = None
    reference_boxes: torch.FloatTensor = None
    presence_logits: torch.FloatTensor = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
# Output container for the mask decoder: per-query masks, an optional semantic segmentation map
# and optional attention weights.
@dataclass
@auto_docstring
class Sam3MaskDecoderOutput(ModelOutput):
    r"""
    pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height, width)`):
        Predicted segmentation masks for each query.
    semantic_seg (`torch.FloatTensor` of shape `(batch_size, 1, height, width)`, *optional*):
        Semantic segmentation output.
    attentions (`tuple[torch.FloatTensor]`, *optional*):
        Tuple of attention weights from mask decoder cross-attention layers.
    """

    # Defaults to None although the annotation is not optional.
    pred_masks: torch.FloatTensor = None
    semantic_seg: torch.FloatTensor | None = None
    attentions: tuple[torch.FloatTensor] | None = None
# Top-level output container for image segmentation. It aggregates the final predictions
# (masks, boxes, logits) with optional intermediate states and attention weights from the
# individual sub-modules named in the docstring.
@dataclass
@auto_docstring
class Sam3ImageSegmentationOutput(ModelOutput):
    r"""
    pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height, width)`):
        Predicted segmentation masks for each query.
    pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
        Predicted bounding boxes in (x1, y1, x2, y2) format.
    pred_logits (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
        Classification confidence scores for each query, computed via dot product between
        decoder query features and text features.
    presence_logits (`torch.FloatTensor` of shape `(batch_size, 1)`, *optional*):
        Presence logits from the DETR decoder presence token (last layer only). These indicate whether objects
        are present in the scene. Can be used to compute final scores by multiplying with pred_logits:
        `final_scores = pred_logits.sigmoid() * presence_logits.sigmoid()`.
    semantic_seg (`torch.FloatTensor` of shape `(batch_size, 1, height, width)`, *optional*):
        Semantic segmentation output.
    decoder_hidden_states (`tuple[torch.FloatTensor]`, *optional*):
        Tuple of hidden states from all DETR decoder layers. Each tensor has shape `(batch_size, num_queries, hidden_size)`.
    decoder_reference_boxes (`torch.FloatTensor` of shape `(num_layers, batch_size, num_queries, 4)`, *optional*):
        Reference boxes from all DETR decoder layers.
    encoder_hidden_states (`tuple[torch.FloatTensor]`, *optional*):
        Tuple of hidden states from all DETR encoder layers.
    vision_hidden_states (`tuple[torch.FloatTensor]`, *optional*):
        Tuple of hidden states from all vision encoder (ViT) layers.
    vision_attentions (`tuple[torch.FloatTensor]`, *optional*):
        Attention weights from vision encoder (ViT) layers.
    detr_encoder_attentions (`tuple[torch.FloatTensor]`, *optional*):
        Attention weights from DETR encoder layers.
    detr_decoder_attentions (`tuple[torch.FloatTensor]`, *optional*):
        Attention weights from DETR decoder layers (self-attention and cross-attention).
    mask_decoder_attentions (`tuple[torch.FloatTensor]`, *optional*):
        Attention weights from mask decoder layers.
    """

    # Final predictions; `pred_masks` and `pred_boxes` default to None although not annotated as optional.
    pred_masks: torch.FloatTensor = None
    pred_boxes: torch.FloatTensor = None
    pred_logits: torch.FloatTensor | None = None
    presence_logits: torch.FloatTensor | None = None
    semantic_seg: torch.FloatTensor | None = None
    # Optional intermediate states.
    decoder_hidden_states: tuple[torch.FloatTensor] | None = None
    decoder_reference_boxes: torch.FloatTensor | None = None
    encoder_hidden_states: tuple[torch.FloatTensor] | None = None
    vision_hidden_states: tuple[torch.FloatTensor] | None = None
    # Optional attention weights, one field per sub-module.
    vision_attentions: tuple[torch.FloatTensor] | None = None
    detr_encoder_attentions: tuple[torch.FloatTensor] | None = None
    detr_decoder_attentions: tuple[torch.FloatTensor] | None = None
    mask_decoder_attentions: tuple[torch.FloatTensor] | None = None
def inverse_sigmoid(x: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
    """
    Compute the logit (inverse of the sigmoid) of `x` with clamping for numerical stability.

    Args:
        x: Input tensor. Values are clamped to `[0, 1]` before the logit is taken.
        eps: Lower bound applied to both `x` and `1 - x`, which keeps the result within
            `[log(eps), -log(eps)]` and avoids `log(0)` / division by zero.

    Returns:
        Tensor with the same shape as `x` containing `log(x / (1 - x))` of the clamped values.
    """
    probabilities = x.clamp(min=0, max=1)
    numerator = probabilities.clamp(min=eps)
    denominator = (1 - probabilities).clamp(min=eps)
    return (numerator / denominator).log()
def concat_padded_sequences(seq1, mask1, seq2, mask2, return_index: bool = False):
    """
    Join two right-padded, batch-first sequences into one contiguous right-padded sequence.

    For every batch element the valid entries of `seq2` are placed directly after the valid
    entries of `seq1`. Masks follow the True=valid, False=padding convention.

    Args:
        seq1: Tensor of shape (batch_size, seq1_length, hidden_size).
        mask1: Tensor of shape (batch_size, seq1_length), True=valid, False=padding.
        seq2: Tensor of shape (batch_size, seq2_length, hidden_size).
        mask2: Tensor of shape (batch_size, seq2_length), True=valid, False=padding.
        return_index: When True, additionally return the positions that the elements of `seq2`
            occupy in the concatenated sequence, so they can be gathered back later.

    Returns:
        `(concatenated_sequence, concatenated_mask)` when `return_index` is False, otherwise
        `(concatenated_sequence, concatenated_mask, index)`. The sequence has length
        `seq1_length + seq2_length` and takes the dtype and device of `seq2`; the mask uses
        True=valid, False=padding.
    """
    batch_size, seq1_length, hidden_size = seq1.shape
    batch_size2, seq2_length, hidden_size2 = seq2.shape
    assert batch_size == batch_size2 == mask1.size(0) == mask2.size(0)
    assert hidden_size == hidden_size2
    assert seq1_length == mask1.size(1)
    assert seq2_length == mask2.size(1)

    device = seq2.device
    valid_len1 = mask1.sum(dim=-1)
    valid_len2 = mask2.sum(dim=-1)
    total_length = seq1_length + seq2_length

    # A position is valid when it lies before the combined number of valid tokens.
    positions = torch.arange(total_length, device=device)[None].repeat(batch_size, 1)
    concatenated_mask = positions < (valid_len1 + valid_len2)[:, None]

    # Start from seq1 (padding included); seq2 overwrites everything past seq1's valid prefix.
    merged = torch.zeros((batch_size, total_length, hidden_size), device=device, dtype=seq2.dtype)
    merged[:, :seq1_length, :] = seq1

    # Target slot of every seq2 element: its own offset shifted by the valid length of seq1.
    index = torch.arange(seq2_length, device=device)[None].repeat(batch_size, 1) + valid_len1[:, None]
    merged = merged.scatter(1, index[:, :, None].expand(-1, -1, hidden_size), seq2)

    if not return_index:
        return merged, concatenated_mask
    return merged, concatenated_mask, index
def box_cxcywh_to_xyxy(x):
    """
    Convert boxes from center format (cx, cy, w, h) to corner format (x1, y1, x2, y2).

    Args:
        x: Tensor whose last dimension holds the four values (cx, cy, w, h).

    Returns:
        Tensor of the same shape with the last dimension holding (x1, y1, x2, y2).
    """
    center_x, center_y, width, height = x.unbind(-1)
    half_width = 0.5 * width
    half_height = 0.5 * height
    corners = (
        center_x - half_width,
        center_y - half_height,
        center_x + half_width,
        center_y + half_height,
    )
    return torch.stack(corners, dim=-1)
class Sam3MLP(nn.Module):
    """
    Two-layer feed-forward block: `fc1 -> activation -> dropout -> fc2`.

    The layer sizes, the activation (looked up in `ACT2FN` via `config.hidden_act`) and the
    dropout probability (`config.hidden_dropout`) all come from the config.
    """

    def __init__(self, config: Sam3ViTConfig):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: Tensor whose last dimension is `config.hidden_size`.

        Returns:
            Tensor of the same shape as the input.
        """
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        # Dropout is applied to the activations, not to the pre-activation values: dropping before
        # the non-linearity would let any activation with f(0) != 0 resurrect the dropped units and
        # would push the 1 / (1 - p) rescaling through the non-linearity. Eval mode is unaffected.
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float | None = None,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Plain PyTorch scaled dot-product attention.

    Args:
        module: Module whose `training` flag decides whether dropout is active.
        query: Tensor of shape (batch, heads, query_len, head_dim).
        key: Tensor of shape (batch, heads, key_len, head_dim).
        value: Tensor of shape (batch, heads, key_len, head_dim).
        attention_mask: Optional tensor added to the raw attention scores before the softmax.
        scaling: Factor applied to the scores; defaults to `head_dim ** -0.5` when None.
        dropout: Dropout probability applied to the attention probabilities.
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        Tuple `(attn_output, attn_weights)` where `attn_output` has shape
        (batch, query_len, heads, head_dim) and `attn_weights` are the post-dropout
        probabilities of shape (batch, heads, query_len, key_len).
    """
    scale = query.size(-1) ** -0.5 if scaling is None else scaling

    # Raw attention scores: scaled dot product between every query and every key.
    scores = torch.matmul(query, key.transpose(2, 3)) * scale
    if attention_mask is not None:
        scores = scores + attention_mask

    probabilities = nn.functional.softmax(scores, dim=-1)
    probabilities = nn.functional.dropout(probabilities, p=dropout, training=module.training)

    # Move the head axis behind the sequence axis so callers can merge heads with a reshape.
    context = torch.matmul(probabilities, value).transpose(1, 2).contiguous()
    return context, probabilities
class Sam3Attention(nn.Module):
    """
    Multi-head attention over standard [batch_size, seq_len, hidden_size] tensors.

    Query, key and value are projected independently, so the layer serves both
    self-attention and cross-attention.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // config.num_attention_heads
        self.scaling = self.head_dim**-0.5
        self.is_causal = False

        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            query: [batch_size, query_len, hidden_size]
            key: [batch_size, key_len, hidden_size]
            value: [batch_size, value_len, hidden_size] (reshaped using `key_len`)
            attention_mask: [batch_size, num_heads, query_len, key_len] or broadcastable

        Returns:
            Tuple of (output, attention_weights)
            output: [batch_size, query_len, hidden_size]
            attention_weights: [batch_size, num_heads, query_len, key_len]
        """
        batch_size, query_len = query.shape[:2]
        key_len = key.shape[1]
        per_head = (self.num_attention_heads, self.head_dim)

        # Project, then split the hidden dimension into heads: [batch, heads, seq, head_dim].
        query = self.q_proj(query).view(batch_size, query_len, *per_head).transpose(1, 2)
        key = self.k_proj(key).view(batch_size, key_len, *per_head).transpose(1, 2)
        value = self.v_proj(value).view(batch_size, key_len, *per_head).transpose(1, 2)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        has_additive_mask = attention_mask is not None and attention_mask.dtype != torch.bool
        if is_flash_attention_requested(self.config) and has_additive_mask:
            # Relative position bias tensors are represented as float masks and are incompatible with
            # Flash Attention. Fall back to SDPA for this call only so the rest of the model can still
            # benefit from FA.
            attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"]
            logger.warning_once(
                "Sam3Attention: falling back to SDPA for relative-position cross-attention because "
                "Flash Attention does not support additive bias masks."
            )

        attn_output, attn_weights = attention_interface(
            self,
            query,
            key,
            value,
            attention_mask=attention_mask,
            dropout=0.0,
            scaling=self.scaling,
            is_causal=self.is_causal,
            **kwargs,
        )

        # Merge the heads back into a single hidden dimension before the output projection.
        merged_dim = self.num_attention_heads * self.head_dim
        attn_output = attn_output.reshape(batch_size, query_len, merged_dim).contiguous()
        return self.o_proj(attn_output), attn_weights
class Sam3ViTRotaryEmbedding(nn.Module):
    """
    2D (axial) rotary position embedding for the SAM3 ViT.

    The cos/sin tables for an `end_x` x `end_y` grid are computed once at construction and
    stored as non-persistent buffers; `forward` simply returns them.
    """

    def __init__(self, config: Sam3ViTConfig, end_x: int, end_y: int, scale: float = 1.0):
        """
        Args:
            config: Provides `hidden_size`, `num_attention_heads` and `rope_theta`.
            end_x: Grid width in tokens.
            end_y: Grid height in tokens.
            scale: Multiplier applied to the integer grid coordinates.

        Raises:
            ValueError: If the per-head dimension is not divisible by 4.
        """
        super().__init__()
        head_dim = config.hidden_size // config.num_attention_heads
        # Each axis gets half of the head dimension, and every frequency covers a pair of channels.
        if head_dim % 4 != 0:
            raise ValueError("Dimension must be divisible by 4 for axial RoPE")

        self.end_x, self.end_y = end_x, end_y
        self.dim = head_dim
        self.rope_theta = config.rope_theta
        self.scale = scale

        exponents = torch.arange(0, head_dim, 4)[: (head_dim // 4)].float() / head_dim
        freqs = 1.0 / (config.rope_theta**exponents)

        # Row-major token index -> (x, y) grid coordinates.
        token_ids = torch.arange(end_x * end_y, dtype=torch.long)
        x_positions = (token_ids % end_x) * scale
        y_positions = torch.div(token_ids, end_x, rounding_mode="floor") * scale

        angles_x = torch.outer(x_positions, freqs).float()
        angles_y = torch.outer(y_positions, freqs).float()
        # x-angles fill the first half of the head dimension, y-angles the second half; every
        # angle is duplicated so that adjacent channel pairs share the same rotation.
        angles = torch.cat([angles_x, angles_y], dim=-1).repeat_interleave(2, dim=-1)

        # The feature-map size is fixed, so the cos and sin tables are registered directly.
        self.register_buffer("rope_embeddings_cos", angles.cos(), persistent=False)
        self.register_buffer("rope_embeddings_sin", angles.sin(), persistent=False)

    @torch.no_grad()
    def forward(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the pre-computed `(cos, sin)` tables, each of shape (end_x * end_y, head_dim)."""
        return self.rope_embeddings_cos, self.rope_embeddings_sin
def rotate_pairwise(x):
    """
    Rotate adjacent channel pairs of the last dimension: `(a, b) -> (-b, a)`.

    This differs from the Llama-style rotation, which swaps the two halves of the tensor.
    It is an optimized version of the following more explicit implementation:
    ```python
    x_rotated = torch.zeros_like(x, dtype=x.dtype, device=x.device)
    x_rotated[..., ::2] = -x[..., 1::2]
    x_rotated[..., 1::2] = x[..., ::2]
    return x_rotated
    ```
    """
    pairs = x.view(*x.shape[:-1], -1, 2)
    even, odd = pairs.unbind(dim=-1)
    rotated = torch.stack((-odd, even), dim=-1)
    return rotated.flatten(start_dim=-2)
def apply_rotary_pos_emb_2d(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary position embedding to query and key tensors for self-attention.

    The rotation is computed in float32 and the results are cast back to the dtype of the
    respective input.

    Args:
        q: Query tensor whose trailing dimensions broadcast against `cos`/`sin`, i.e. end in
            (seq_len, head_dim).
        k: Key tensor with the same trailing dimensions as `q`.
        cos: Cosine position embedding of shape (seq_len, head_dim)
        sin: Sine position embedding of shape (seq_len, head_dim)

    Returns:
        Rotated (q, k) tensors
    """

    def _rotate(tensor: torch.Tensor) -> torch.Tensor:
        upcast = tensor.float()
        rotated = (upcast * cos) + (rotate_pairwise(upcast) * sin)
        return rotated.type_as(tensor)

    return _rotate(q), _rotate(k)
class Sam3ViTRoPEAttention(nn.Module):
    """
    Self-attention with rotary position encoding.

    Operates on spatial feature maps of shape [batch_size, height, width, hidden_size]; the two
    spatial axes are flattened into one sequence for the attention computation.
    """

    def __init__(self, config: Sam3ViTConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // config.num_attention_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = False

        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states: [batch_size, height, width, hidden_size]
            position_embeddings: `(cos, sin)` pair applied to the query and key heads; it must
                broadcast against [batch_size, num_heads, height * width, head_dim].

        Returns:
            Tuple of (output, attention_weights)
            output: [batch_size, height, width, hidden_size]
            attention_weights: second value returned by the selected attention interface.
        """
        batch_size, height, width, _ = hidden_states.shape
        seq_len = height * width
        new_shape = (batch_size, seq_len, self.num_attention_heads, self.head_dim)

        # Flatten the spatial grid and split heads: [batch, heads, seq_len, head_dim].
        query = self.q_proj(hidden_states).view(*new_shape).transpose(1, 2)
        key = self.k_proj(hidden_states).view(*new_shape).transpose(1, 2)
        value = self.v_proj(hidden_states).view(*new_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query, key = apply_rotary_pos_emb_2d(query, key, cos=cos, sin=sin)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        attn_output, attn_weights = attention_interface(
            self,
            query,
            key,
            value,
            attention_mask=None,
            # Attention dropout is only active while training.
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            is_causal=self.is_causal,
            **kwargs,
        )

        # Restore the spatial layout before the output projection.
        attn_output = attn_output.reshape(batch_size, height, width, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
class Sam3ViTPatchEmbeddings(nn.Module):
    """
    Turn `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` consumed by a
    Transformer. Patches are extracted by a bias-free strided convolution.
    """

    def __init__(self, config: Sam3ViTConfig):
        super().__init__()

        def _as_pair(size):
            # Scalars describe square sizes; iterables are used as given.
            return size if isinstance(size, Iterable) else (size, size)

        image_size = _as_pair(config.pretrain_image_size)
        patch_size = _as_pair(config.patch_size)

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = config.num_channels
        # Number of whole patches that fit into the pretraining image.
        self.num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])

        self.projection = nn.Conv2d(
            config.num_channels,
            config.hidden_size,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # Cast the input to the dtype of the convolution weights before projecting.
        pixel_values = pixel_values.to(self.projection.weight.dtype)
        patches = self.projection(pixel_values)
        # (batch, hidden, h, w) -> (batch, h * w, hidden)
        return patches.flatten(2).transpose(1, 2)
class Sam3ViTEmbeddings(nn.Module):
    """
    Construct the patch embeddings and position embeddings for SAM3 ViT.

    Position embeddings are tiled (not interpolated) when resizing to match different input sizes.
    """

    def __init__(self, config: Sam3ViTConfig):
        super().__init__()
        self.patch_embeddings = Sam3ViTPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.position_embeddings = nn.Parameter(
            torch.randn(1, num_patches, config.hidden_size)
        )  # !Remove cls token in convert weights!
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.patch_size = config.patch_size

    def _tile_position_embeddings(
        self,
        position_embeddings: torch.Tensor,
        height: int,
        width: int,
    ) -> torch.Tensor:
        """
        Tile position embeddings to match target spatial dimensions.

        The pretraining grid is treated as square: its side is the integer square root of the
        number of pretraining patches.

        Args:
            position_embeddings: Shape [1, num_pretrain_patches, hidden_size]
            height: Target height in patches
            width: Target width in patches

        Returns:
            Shape [1, height * width, hidden_size]
        """
        pretrain_size = int(position_embeddings.shape[1] ** 0.5)

        # Skip tiling if sizes match (but always tile during tracing for consistent graph)
        if not torch.jit.is_tracing() and pretrain_size == height and pretrain_size == width:
            return position_embeddings.reshape(1, height * width, -1)

        # Tile position embeddings to match target spatial dimensions
        hidden_size = position_embeddings.shape[-1]
        pos_embed = position_embeddings.reshape(1, pretrain_size, pretrain_size, hidden_size).permute(0, 3, 1, 2)
        # One repetition more than strictly needed, then crop to the exact target size.
        repeat_h = height // pretrain_size + 1
        repeat_w = width // pretrain_size + 1
        pos_embed = pos_embed.tile([1, 1, repeat_h, repeat_w])[:, :, :height, :width]
        return pos_embed.permute(0, 2, 3, 1).reshape(1, height * width, hidden_size)

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        """
        Args:
            pixel_values: Images of shape [batch_size, num_channels, height, width].
            interpolate_pos_encoding: Accepted for interface compatibility; not used, because
                position embeddings are always tiled to the input size.

        Returns:
            Patch embeddings with position information, shape [batch_size, num_patches, hidden_size].
        """
        height, width = pixel_values.shape[-2:]
        embeddings = self.patch_embeddings(pixel_values)

        # Use the per-axis patch size normalised by the patch-embedding layer so that both an int
        # and a (height, width) pair in `config.patch_size` are supported.
        patch_height, patch_width = self.patch_embeddings.patch_size
        height_patches = height // patch_height
        width_patches = width // patch_width

        position_embeddings = self._tile_position_embeddings(
            self.position_embeddings,
            height_patches,
            width_patches,
        )
        embeddings = embeddings + position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings
def window_partition(hidden_state, window_size):
    """
    Split a feature map into non-overlapping square windows, padding at the bottom/right when required.

    Args:
        hidden_state (`torch.Tensor`):
            Input tokens with [batch_size, height, width, num_channels].
        window_size (`int`):
            Window size.

    Returns:
        `tuple(torch.FloatTensor)` comprising various elements:
        - windows: windows after partition with [batch_size * num_windows, window_size, window_size, num_channels].
        - (padded_height, padded_width): padded height and width before partition
    """
    batch_size, height, width, num_channels = hidden_state.shape

    # Rows/columns missing to reach the next multiple of window_size (0 when already aligned).
    extra_rows = (window_size - height % window_size) % window_size
    extra_cols = (window_size - width % window_size) % window_size
    padded_height = height + extra_rows
    padded_width = width + extra_cols

    # Pad spec runs from the last dim backwards: channels, width, height. No-op when both extras are 0.
    padded = nn.functional.pad(hidden_state, (0, 0, 0, extra_cols, 0, extra_rows))

    window_rows = padded_height // window_size
    window_cols = padded_width // window_size
    grid = padded.view(batch_size, window_rows, window_size, window_cols, window_size, num_channels)

    # Bring the two window-index axes together, then fold them into the batch axis.
    windows = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    windows = windows.view(-1, window_size, window_size, num_channels)
    return windows, (padded_height, padded_width)
def window_unpartition(windows, window_size, pad_height_width, height_width):
    """
    Reassemble windows into the full feature map and crop away any padding.

    Args:
        windows (`torch.Tensor`):
            Input tokens with [batch_size * num_windows, window_size, window_size, num_channels].
        window_size (`int`):
            Window size.
        pad_height_width (`tuple[int]`):
            Padded height and width (padded_height, padded_width).
        height_width (`tuple[int]`):
            Original height and width before padding.

    Returns:
        hidden_state: unpartitioned sequences with [batch_size, height, width, num_channels].
    """
    padded_height, padded_width = pad_height_width
    height, width = height_width

    # The leading axis holds batch_size * windows_per_image entries.
    windows_per_image = padded_height * padded_width // window_size // window_size
    batch_size = windows.shape[0] // windows_per_image

    window_rows = padded_height // window_size
    window_cols = padded_width // window_size
    grid = windows.view(batch_size, window_rows, window_cols, window_size, window_size, -1)

    # Interleave window index and in-window offset per spatial axis, then merge each pair.
    merged = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    merged = merged.view(batch_size, padded_height, padded_width, -1)

    # We always have height <= padded_height and width <= padded_width, so slicing drops only padding.
    return merged[:, :height, :width, :].contiguous()
class Sam3ViTLayerScale(nn.Module):
    """Multiply the input by a learnable vector of length `config.hidden_size`."""

    def __init__(self, config) -> None:
        super().__init__()
        # Every entry starts at `config.layer_scale_init_value`.
        initial_scale = torch.ones(config.hidden_size) * config.layer_scale_init_value
        self.lambda1 = nn.Parameter(initial_scale)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Return `hidden_state` scaled element-wise (broadcast) by `lambda1`."""
        return torch.mul(hidden_state, self.lambda1)
class Sam3ViTLayer(GradientCheckpointingLayer):
    """Vision Transformer layer with rotary position embeddings and optional windowed attention."""

    def __init__(self, config: Sam3ViTConfig, window_size: int = 0) -> None:
        super().__init__()

        def as_pair(value):
            # Scalars are duplicated; lists/tuples are used as given.
            return value if isinstance(value, (list, tuple)) else (value, value)

        image_size = as_pair(config.image_size)
        patch_size = as_pair(config.patch_size)
        full_grid = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])

        # With window_size == 0 the rotary table spans the whole patch grid, otherwise a single window.
        rope_grid = (window_size, window_size) if window_size != 0 else full_grid

        self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.rotary_emb = Sam3ViTRotaryEmbedding(
            config,
            end_x=rope_grid[0],
            end_y=rope_grid[1],
            scale=config.window_size / rope_grid[0],
        )
        self.attention = Sam3ViTRoPEAttention(config)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = Sam3MLP(config)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.window_size = window_size

    def forward(
        self,
        hidden_states: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Apply pre-norm attention (windowed when `window_size > 0`) and a pre-norm MLP, each with a residual
        connection. Dims 1 and 2 of `hidden_states` are used as height and width when windowing.
        """
        windowed = self.window_size > 0

        attn_input = self.layer_norm1(hidden_states)
        if windowed:
            grid_hw = (attn_input.shape[1], attn_input.shape[2])
            # Partition into non-overlapping windows for efficient attention
            attn_input, padded_hw = window_partition(attn_input, self.window_size)

        attn_output, _ = self.attention(attn_input, self.rotary_emb(), **kwargs)

        if windowed:
            # Reverse window partition to restore original spatial layout
            attn_output = window_unpartition(attn_output, self.window_size, padded_hw, grid_hw)

        hidden_states = hidden_states + attn_output

        mlp_output = self.mlp(self.layer_norm2(hidden_states))
        return hidden_states + self.dropout(mlp_output)
@auto_docstring
@requires(backends=("torch", "torchvision"))
class Sam3PreTrainedModel(PreTrainedModel):
    config_class = Sam3Config
    base_model_prefix = "sam3"
    main_input_name = "pixel_values"
    input_modalities = ["image", "text"]
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_flex_attn = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """
        Run the parent initialisation, then handle the two SAM3-specific cases: the ViT position table
        (normal init) and the rotary embedding's cos/sin tables (recomputed from the module's settings).
        """
        super()._init_weights(module)

        if isinstance(module, Sam3ViTEmbeddings):
            init.normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range)
            return

        if not isinstance(module, Sam3ViTRotaryEmbedding):
            return

        dim = module.dim
        end_x, end_y = module.end_x, module.end_y

        # Base frequencies: one per group of four channels.
        exponents = torch.arange(0, dim, 4)[: (dim // 4)].float() / dim
        base_freqs = 1.0 / (module.rope_theta**exponents)

        # Scaled x/y coordinate of every cell of an end_x-wide grid, in row-major order.
        cell_index = torch.arange(end_x * end_y, dtype=torch.long)
        x_positions = (cell_index % end_x) * module.scale
        y_positions = torch.div(cell_index, end_x, rounding_mode="floor") * module.scale

        x_angles = torch.outer(x_positions, base_freqs).float()
        y_angles = torch.outer(y_positions, base_freqs).float()
        # x block followed by y block, with every angle duplicated in adjacent positions.
        angles = torch.cat([x_angles, y_angles], dim=-1).repeat_interleave(2, dim=-1)

        init.copy_(module.rope_embeddings_cos, angles.cos())
        init.copy_(module.rope_embeddings_sin, angles.sin())
@auto_docstring
class Sam3ViTModel(Sam3PreTrainedModel):
    config: Sam3ViTConfig
    # Maps output names to the module classes they come from (presumably read by `capture_outputs`).
    _can_record_outputs = {
        "hidden_states": Sam3ViTLayer,
        "attentions": Sam3ViTRoPEAttention,
    }

    def __init__(self, config: Sam3ViTConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = Sam3ViTEmbeddings(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Layers listed in `global_attn_indexes` get window_size=0; all others use `config.window_size`.
        blocks = []
        for index in range(config.num_hidden_layers):
            layer_window = 0 if index in config.global_attn_indexes else config.window_size
            blocks.append(Sam3ViTLayer(config, window_size=layer_window))
        self.layers = nn.ModuleList(blocks)

        self.post_init()

    def get_input_embeddings(self) -> Sam3ViTPatchEmbeddings:
        return self.embeddings.patch_embeddings

    @merge_with_config_defaults
    @capture_outputs(tie_last_hidden_states=False)
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        patch_size = self.config.patch_size
        tokens = self.embeddings(pixel_values)  # [batch_size, seq_len, hidden_size]

        batch_size, hidden_size = tokens.shape[0], tokens.shape[-1]
        height = pixel_values.shape[-2] // patch_size
        width = pixel_values.shape[-1] // patch_size

        # The layers work on a spatial layout: [batch_size, height, width, hidden_size]
        hidden_states = self.layer_norm(tokens.view(batch_size, height, width, hidden_size))
        for layer in self.layers:
            hidden_states = layer(hidden_states, **kwargs)

        # Flatten the grid back into a sequence: [batch_size, height*width, hidden_size]
        sequence = hidden_states.view(batch_size, height * width, hidden_size)
        return BaseModelOutput(last_hidden_state=sequence)
class Sam3SinePositionEmbedding(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
    need paper, generalized to work on images.
    """

    def __init__(
        self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: float | None = None
    ):
        super().__init__()
        if normalize is False and scale is not None:
            raise ValueError("normalize should be True if scale is passed")
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        # Falls back to 2*pi when no explicit scale is given.
        self.scale = scale if scale is not None else 2 * math.pi

    def _frequency_divisors(self, device, dtype) -> torch.Tensor:
        """Return `temperature ** (2 * floor(i / 2) / num_pos_feats)` for i in [0, num_pos_feats)."""
        index = torch.arange(self.num_pos_feats, dtype=torch.int64, device=device).to(dtype)
        return self.temperature ** (2 * torch.div(index, 2, rounding_mode="floor") / self.num_pos_feats)

    @staticmethod
    def _interleave_sin_cos(angles: torch.Tensor) -> torch.Tensor:
        """Take sin of even and cos of odd last-dim entries and interleave them back along the last dim."""
        paired = torch.stack((angles[..., 0::2].sin(), angles[..., 1::2].cos()), dim=-1)
        return paired.flatten(-2)

    def encode_1d_positions(self, x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Encode 1D coordinate pairs using sine/cosine positional embeddings.

        Args:
            x: 1D tensor of x coordinates (flattened)
            y: 1D tensor of y coordinates (flattened)

        Returns:
            Tuple of (pos_x, pos_y) positional embeddings
        """
        divisors = self._frequency_divisors(x.device, x.dtype)
        angles_x = (x * self.scale)[:, None] / divisors
        angles_y = (y * self.scale)[:, None] / divisors
        return self._interleave_sin_cos(angles_x), self._interleave_sin_cos(angles_y)

    def encode_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """
        Encode 4D box coordinates (x, y, w, h) for decoder conditioning using sine/cosine embeddings.

        Args:
            boxes: Box coordinates [batch_size, num_queries, 4] in (x, y, w, h) format

        Returns:
            Position embeddings [batch_size, num_queries, num_pos_feats*4]
        """
        assert boxes.size(-1) == 4, f"Expected 4D box coordinates (x, y, w, h), got shape {boxes.shape}"
        divisors = self._frequency_divisors(boxes.device, boxes.dtype)

        # One encoding per coordinate channel, in (x, y, w, h) order.
        pos_x, pos_y, pos_w, pos_h = (
            self._interleave_sin_cos((boxes[:, :, channel] * self.scale)[:, :, None] / divisors)
            for channel in range(4)
        )
        # The y encoding is placed ahead of the x encoding in the output.
        return torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)

    @compile_compatible_method_lru_cache(maxsize=4)
    def forward(
        self,
        shape: torch.Size,
        device: torch.device | str,
        dtype: torch.dtype,
        mask: Tensor | None = None,
    ) -> Tensor:
        """
        Build the 2D sine/cosine encoding for a feature map of the given `shape`.

        `shape` is indexed at 0, 2 and 3 for the default mask, so a 4D [batch, channels, height, width] shape is
        expected. Positions where `mask` is True are excluded from the running position count. The result has the
        encoding channels at dim 1.
        """
        if mask is None:
            mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool)
        valid = (~mask).to(dtype)

        # Running counts of unmasked cells along dim 1 (rows) and dim 2 (columns).
        row_pos = valid.cumsum(1)
        col_pos = valid.cumsum(2)
        if self.normalize:
            eps = 1e-6
            # Divide by the last (largest) count along each axis, then rescale.
            row_pos = row_pos / (row_pos[:, -1:, :] + eps) * self.scale
            col_pos = col_pos / (col_pos[:, :, -1:] + eps) * self.scale

        divisors = self._frequency_divisors(device, dtype)
        enc_x = self._interleave_sin_cos(col_pos[:, :, :, None] / divisors)
        enc_y = self._interleave_sin_cos(row_pos[:, :, :, None] / divisors)

        # y features first, then x; move the feature axis to dim 1.
        return torch.cat((enc_y, enc_x), dim=3).permute(0, 3, 1, 2)
class Sam3FPNLayer(nn.Module):
    """
    Rescale a feature map by a fixed factor (4.0, 2.0, 1.0 or 0.5) and project it to `fpn_dim` channels.

    Upscaling uses stride-2 transposed convolutions (each halving the channel count), downscaling uses a
    2x2 max pool. Any other factor raises `NotImplementedError`.
    """

    def __init__(self, in_channels: int, fpn_dim: int, scale_factor: float):
        super().__init__()
        self.scale_factor = scale_factor

        # Collect the resizing stages and record how many channels leave them.
        stages = []
        if scale_factor == 4.0:
            half, quarter = in_channels // 2, in_channels // 4
            stages.append(nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2))
            stages.append(nn.GELU())
            stages.append(nn.ConvTranspose2d(half, quarter, kernel_size=2, stride=2))
            resized_channels = quarter
        elif scale_factor == 2.0:
            half = in_channels // 2
            stages.append(nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2))
            resized_channels = half
        elif scale_factor == 1.0:
            resized_channels = in_channels
        elif scale_factor == 0.5:
            stages.append(nn.MaxPool2d(kernel_size=2, stride=2))
            resized_channels = in_channels
        else:
            raise NotImplementedError(f"scale_factor={scale_factor} is not supported yet.")
        self.scale_layers = nn.ModuleList(stages)

        self.proj1 = nn.Conv2d(in_channels=resized_channels, out_channels=fpn_dim, kernel_size=1)
        self.proj2 = nn.Conv2d(in_channels=fpn_dim, out_channels=fpn_dim, kernel_size=3, padding=1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Cast the input to the dtype of `proj1`'s weights, resize it, then apply both projections."""
        features = hidden_states.to(self.proj1.weight.dtype)
        for stage in self.scale_layers:
            features = stage(features)
        return self.proj2(self.proj1(features))
class Sam3VisionNeck(nn.Module):
    """Run one `Sam3FPNLayer` per configured scale factor and pair each output with a sine position encoding."""

    def __init__(self, config: Sam3VisionConfig):
        super().__init__()
        self.config = config

        self.position_encoding = Sam3SinePositionEmbedding(num_pos_feats=config.fpn_hidden_size // 2, normalize=True)

        # Create one FPN layer per scale factor
        backbone_channels = config.backbone_config.hidden_size
        levels = []
        for scale in config.scale_factors:
            levels.append(
                Sam3FPNLayer(in_channels=backbone_channels, fpn_dim=config.fpn_hidden_size, scale_factor=scale)
            )
        self.fpn_layers = nn.ModuleList(levels)

    def forward(self, hidden_states: torch.Tensor) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]:
        """
        Returns a pair of tuples, one entry per FPN layer and in the same order: the layer outputs, and a
        position encoding generated from each output's shape, device and dtype.
        """
        level_features = []
        level_encodings = []
        for fpn_layer in self.fpn_layers:
            level_output = fpn_layer(hidden_states)
            level_features.append(level_output)
            # Generate position encoding for this FPN level
            level_encodings.append(
                self.position_encoding(level_output.shape, level_output.device, level_output.dtype)
            )
        return tuple(level_features), tuple(level_encodings)
@auto_docstring(
    custom_intro="""
    The vision model from Sam without any head or projection on top.
    """
)
class Sam3VisionModel(Sam3PreTrainedModel):
    config_class = Sam3VisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: Sam3VisionConfig):
        super().__init__(config)
        self.config = config
        # The backbone class is resolved from the backbone sub-config.
        self.backbone = AutoModel.from_config(config.backbone_config)
        self.neck = Sam3VisionNeck(config)
        self.post_init()

    def get_input_embeddings(self):
        return self.backbone.get_input_embeddings()

    @can_return_tuple
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Sam3VisionEncoderOutput:
        # Run the backbone, reshape its token sequence into a feature map and feed that to the FPN neck.
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        backbone_output = self.backbone(pixel_values, **kwargs)
        tokens = backbone_output.last_hidden_state  # [batch_size, seq_len, hidden_size]

        # Patch-grid size implied by the input resolution.
        patch_size = self.config.backbone_config.patch_size
        grid_height = pixel_values.shape[-2] // patch_size
        grid_width = pixel_values.shape[-1] // patch_size

        # [batch_size, seq_len, hidden_size] -> [batch_size, hidden_size, height, width]
        feature_map = tokens.view(tokens.shape[0], grid_height, grid_width, -1).permute(0, 3, 1, 2)
        fpn_hidden_states, fpn_position_encoding = self.neck(feature_map)

        return Sam3VisionEncoderOutput(
            last_hidden_state=tokens,
            fpn_hidden_states=fpn_hidden_states,
            fpn_position_encoding=fpn_position_encoding,
            hidden_states=backbone_output.hidden_states,
            attentions=backbone_output.attentions,
        )
class Sam3GeometryEncoderLayer(nn.Module):
    """Pre-norm layer: prompt self-attention, cross-attention to vision features, then an MLP."""

    def __init__(self, config: Sam3GeometryEncoderConfig):
        super().__init__()
        hidden_size = config.hidden_size
        self.layer_norm1 = nn.LayerNorm(hidden_size)
        self.self_attn = Sam3Attention(config)
        # A single dropout module is reused after all three sub-blocks.
        self.dropout = nn.Dropout(config.dropout)
        self.cross_attn = Sam3Attention(config)
        self.layer_norm2 = nn.LayerNorm(hidden_size)
        self.mlp = Sam3MLP(config)
        self.layer_norm3 = nn.LayerNorm(hidden_size)

    def forward(
        self,
        prompt_feats: Tensor,
        vision_feats: Tensor,
        vision_pos_encoding: Tensor,
        prompt_mask: Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """
        Update `prompt_feats`. `prompt_mask` is passed as the self-attention mask; the cross-attention uses
        `vision_feats + vision_pos_encoding` as keys and plain `vision_feats` as values, without a mask.
        """
        # Self-attention among the prompt tokens
        normed = self.layer_norm1(prompt_feats)
        attn_out, _ = self.self_attn(
            query=normed, key=normed, value=normed, attention_mask=prompt_mask, **kwargs
        )
        hidden_states = self.dropout(attn_out) + prompt_feats

        # Cross-attention: prompt tokens query the vision features
        normed = self.layer_norm2(hidden_states)
        cross_out, _ = self.cross_attn(
            query=normed, key=vision_feats + vision_pos_encoding, value=vision_feats, **kwargs
        )
        hidden_states = self.dropout(cross_out) + hidden_states

        # MLP
        mlp_out = self.mlp(self.layer_norm3(hidden_states))
        return self.dropout(mlp_out) + hidden_states
class Sam3GeometryEncoder(nn.Module):
    """
    Encoder for geometric prompts (boxes).

    Each box contributes three encodings that are summed:
    - Direct projection: linear projection from coordinate space to hidden_size
    - Pooling: pool features from the backbone at the specified location (ROI align for boxes)
    - Position encoding: use position encoding of the box center

    The summed encodings are then processed with transformer layers.
    """

    def __init__(self, config: Sam3GeometryEncoderConfig):
        super().__init__()
        hidden_size = config.hidden_size
        self.config = config
        self.hidden_size = hidden_size
        self.roi_size = config.roi_size

        self.position_encoding = Sam3SinePositionEmbedding(num_pos_feats=hidden_size // 2, normalize=True)
        self.label_embed = nn.Embedding(2, hidden_size)
        self.cls_embed = nn.Embedding(1, hidden_size)

        # Box encoding layers
        self.boxes_direct_project = nn.Linear(4, hidden_size)
        self.boxes_pool_project = nn.Conv2d(hidden_size, hidden_size, self.roi_size)
        # Input is the center position encoding plus the two raw size values.
        self.boxes_pos_enc_project = nn.Linear(hidden_size + 2, hidden_size)

        # Image feature normalization
        self.vision_layer_norm = nn.LayerNorm(hidden_size)

        # Prompt projection and normalization
        self.final_proj = nn.Linear(hidden_size, hidden_size)
        self.prompt_layer_norm = nn.LayerNorm(hidden_size)

        # Transformer layers
        self.layers = nn.ModuleList([Sam3GeometryEncoderLayer(config) for _ in range(config.num_layers)])
        self.output_layer_norm = nn.LayerNorm(hidden_size)

    def _encode_box_coordinates(
        self, center_x: torch.Tensor, center_y: torch.Tensor, width: torch.Tensor, height: torch.Tensor
    ) -> torch.Tensor:
        """
        Encode box coordinates by combining position-encoded centers with raw width/height.

        Args:
            center_x: 1D tensor of box center x coordinates
            center_y: 1D tensor of box center y coordinates
            width: 1D tensor of box widths
            height: 1D tensor of box heights

        Returns:
            Encoded box coordinates [N, embedding_dim]
        """
        enc_x, enc_y = self.position_encoding.encode_1d_positions(center_x, center_y)
        # Column order: y encoding, x encoding, raw height, raw width.
        return torch.cat((enc_y, enc_x, height[:, None], width[:, None]), dim=1)

    def _encode_boxes(self, boxes, boxes_mask, boxes_labels, vision_features):
        """Encode box prompts. Mask convention: True=valid, False=padding."""
        batch_size, num_boxes = boxes.shape[:2]
        feat_height, feat_width = vision_features.shape[-2:]

        direct = self.boxes_direct_project(boxes)

        # Pool features using ROI align
        # Convert boxes from CxCyWH to xyxy format and scale them to feature-map coordinates
        corners = box_cxcywh_to_xyxy(boxes)
        to_feature_coords = torch.tensor(
            [feat_width, feat_height, feat_width, feat_height], dtype=corners.dtype, device=corners.device
        ).view(1, 1, 4)
        corners = corners * to_feature_coords

        # ROI align expects list of boxes per batch element,
        # convert from bfloat16 to float16 as roi_align only supports float16 and float32
        feature_dtype = vision_features.dtype
        roi_dtype = torch.float16 if feature_dtype == torch.bfloat16 else feature_dtype
        pooled = torchvision.ops.roi_align(
            vision_features.to(roi_dtype), corners.to(roi_dtype).unbind(0), self.roi_size
        )
        pooled = pooled.to(feature_dtype)
        pooled = self.boxes_pool_project(pooled).view(batch_size, num_boxes, self.hidden_size)
        combined = direct + pooled

        # Add position encoding
        center_x, center_y, box_width, box_height = boxes.unbind(-1)
        coord_enc = self._encode_box_coordinates(
            center_x.flatten(), center_y.flatten(), box_width.flatten(), box_height.flatten()
        )
        coord_enc = coord_enc.view(batch_size, num_boxes, coord_enc.shape[-1])
        combined = combined + self.boxes_pos_enc_project(coord_enc)

        # Add label embeddings (positive/negative)
        labels = self.label_embed(boxes_labels.long())
        return labels + combined, boxes_mask

    def forward(
        self,
        box_embeddings: torch.Tensor,
        box_mask: torch.Tensor,
        box_labels: torch.Tensor,
        img_feats: tuple[torch.Tensor, ...],
        img_pos_embeds: tuple[torch.Tensor, ...] | None = None,
    ):
        """
        Forward pass for encoding geometric prompts.

        Args:
            box_embeddings: Box coordinates in CxCyWH format [batch_size, num_boxes, 4]
            box_mask: Attention mask for boxes [batch_size, num_boxes]
            box_labels: Labels for boxes (positive/negative) [batch_size, num_boxes]
            img_feats: Image features from vision encoder
            img_pos_embeds: Optional position embeddings for image features

        Returns:
            Sam3GeometryEncoderOutput containing encoded geometry features and attention mask.
        """
        batch_size = box_embeddings.shape[0]

        # Only the last entry of `img_feats` is used: [B, C, H, W]
        last_feats = img_feats[-1]
        if img_pos_embeds is None:
            last_pos = torch.zeros_like(last_feats)
        else:
            last_pos = img_pos_embeds[-1]

        # Prepare vision features for cross-attention: flatten spatial dimensions
        vision_feats_flat = last_feats.flatten(2).transpose(1, 2)  # [B, H*W, C]
        vision_pos_embeds_flat = last_pos.flatten(2).transpose(1, 2)  # [B, H*W, C]

        # Normalize image features for pooling operations (LayerNorm over channels, so go channels-last)
        channels_last = last_feats.permute(0, 2, 3, 1)  # [B, H, W, C]
        normalized_img_feats = self.vision_layer_norm(channels_last).permute(0, 3, 1, 2)  # [B, C, H, W]

        prompt_embeds, prompt_mask = self._encode_boxes(box_embeddings, box_mask, box_labels, normalized_img_feats)

        # Add CLS token (always valid)
        cls_embed = self.cls_embed.weight.view(1, 1, self.hidden_size).expand(batch_size, -1, -1)
        cls_mask = torch.ones(batch_size, 1, dtype=prompt_mask.dtype, device=prompt_mask.device)
        prompt_embeds, prompt_mask = concat_padded_sequences(prompt_embeds, prompt_mask, cls_embed, cls_mask)

        prompt_embeds = self.prompt_layer_norm(self.final_proj(prompt_embeds))

        # Create bidirectional attention mask for transformer layers
        if prompt_mask is None:
            prompt_attention_mask = None
        else:
            prompt_attention_mask = create_bidirectional_mask(
                config=self.config,
                inputs_embeds=prompt_embeds,
                attention_mask=prompt_mask,
            )

        # Apply transformer layers with cross-attention to vision features
        for layer in self.layers:
            prompt_embeds = layer(
                prompt_feats=prompt_embeds,
                vision_feats=vision_feats_flat,
                vision_pos_encoding=vision_pos_embeds_flat,
                prompt_mask=prompt_attention_mask,
            )

        # Final output normalization
        return Sam3GeometryEncoderOutput(
            last_hidden_state=self.output_layer_norm(prompt_embeds),
            attention_mask=prompt_mask,
        )
class Sam3DetrEncoderLayer(nn.Module):
    """DETR encoder layer with self-attention and cross-attention."""

    def __init__(self, config: Sam3DETREncoderConfig):
        super().__init__()
        hidden_size = config.hidden_size
        self.config = config
        self.layer_norm1 = nn.LayerNorm(hidden_size)
        self.self_attn = Sam3Attention(config)
        # A single dropout module is reused after all three sub-blocks.
        self.dropout = nn.Dropout(config.dropout)
        self.cross_attn = Sam3Attention(config)
        self.layer_norm2 = nn.LayerNorm(hidden_size)
        self.mlp = Sam3MLP(config)
        self.layer_norm3 = nn.LayerNorm(hidden_size)

    def forward(
        self,
        vision_feats: Tensor,
        prompt_feats: Tensor,
        vision_pos_encoding: Tensor,
        prompt_cross_attn_mask: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """
        Forward pass for DETR encoder layer.

        Args:
            vision_feats: Vision features [batch_size, vision_len, hidden_size] (main hidden states)
            prompt_feats: Text prompt features [batch_size, text_len, hidden_size]
            vision_pos_encoding: Position encoding for vision [batch_size, vision_len, hidden_size]
            prompt_cross_attn_mask: Cross-attention mask for prompt features

        Returns:
            Updated vision features [batch_size, vision_len, hidden_size]
        """
        # Self-attention on vision features: the position encoding is added to queries and keys only,
        # the values stay position-free.
        normed = self.layer_norm1(vision_feats)
        normed_with_pos = normed + vision_pos_encoding
        attn_out, _ = self.self_attn(
            query=normed_with_pos,
            key=normed_with_pos,
            value=normed,
            **kwargs,
        )
        hidden_states = self.dropout(attn_out) + vision_feats

        # Cross-attention: vision queries attend to text/prompt features
        cross_out, _ = self.cross_attn(
            query=self.layer_norm2(hidden_states),
            key=prompt_feats,
            value=prompt_feats,
            attention_mask=prompt_cross_attn_mask,
            **kwargs,
        )
        hidden_states = self.dropout(cross_out) + hidden_states

        # MLP
        mlp_out = self.mlp(self.layer_norm3(hidden_states))
        return self.dropout(mlp_out) + hidden_states
class Sam3DetrEncoder(Sam3PreTrainedModel):
    """
    DETR-style encoder that processes multi-level vision features with text fusion.

    This encoder processes vision features from multiple levels (e.g., FPN features at different
    resolutions) and fuses them with text prompts through a stack of transformer encoder layers.
    """

    _can_record_outputs = {
        "hidden_states": Sam3DetrEncoderLayer,
        "attentions": Sam3Attention,
    }

    def __init__(self, config: Sam3DETREncoderConfig):
        super().__init__(config)
        self.config = config
        self.hidden_size = config.hidden_size

        self.layers = nn.ModuleList([Sam3DetrEncoderLayer(config) for _ in range(config.num_layers)])

        self.post_init()

    def _prepare_multilevel_features(
        self,
        vision_features: list[torch.Tensor],
        vision_pos_embeds: list[torch.Tensor],
    ):
        """
        Flatten the spatial dimensions of every level and concatenate all levels into one sequence.

        Args:
            vision_features: List of vision features at different levels [batch_size, channels, height, width]
            vision_pos_embeds: List of position embeddings for each level [batch_size, channels, height, width]

        Returns:
            Tuple of (flattened features, flattened position embeddings, spatial_shapes), where the first two
            are [batch_size, sum(height*width), channels] and `spatial_shapes` is a long tensor holding the
            (height, width) of each level.
        """
        features_flattened = []
        pos_embeds_flattened = []
        spatial_shapes = []

        for features, pos_embed in zip(vision_features, vision_pos_embeds):
            height, width = features.shape[-2:]
            spatial_shapes.append((height, width))

            # Flatten spatial dimensions: [batch_size, channels, height, width] -> [batch_size, height*width, channels]
            features_flattened.append(features.flatten(2).transpose(1, 2))
            pos_embeds_flattened.append(pos_embed.flatten(2).transpose(1, 2))

        # Concatenate all levels into single sequence
        features_flattened = torch.cat(features_flattened, dim=1)
        pos_embeds_flattened = torch.cat(pos_embeds_flattened, dim=1)

        spatial_shapes = torch.tensor(spatial_shapes, dtype=torch.long, device=features_flattened.device)

        return (
            features_flattened,
            pos_embeds_flattened,
            spatial_shapes,
        )

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        vision_features: list[torch.Tensor],
        text_features: torch.Tensor,
        vision_pos_embeds: list[torch.Tensor] | None = None,
        text_mask: torch.Tensor | None = None,
        spatial_sizes: list[tuple[int, int]] | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Sam3DETREncoderOutput:
        """
        Forward pass for the DETR encoder.

        Args:
            vision_features: List of vision features at different levels
            text_features: Text prompt features [batch_size, seq_len, hidden_size]
            vision_pos_embeds: Optional list of position embeddings for each level. When `None`, zeros with the
                same shape as the (reshaped) vision features are used.
            text_mask: Optional text padding mask [batch_size, seq_len]
            spatial_sizes: Optional list of (height, width) tuples for reshaping

        Returns:
            Sam3DETREncoderOutput containing encoded features and metadata.
        """
        # 4D inputs are batch-first; otherwise the batch size is read from dim 1.
        batch_size = vision_features[0].shape[0] if vision_features[0].dim() == 4 else vision_features[0].shape[1]

        # Work on shallow copies so the reshaping below never overwrites entries of the caller's lists.
        vision_features = list(vision_features)
        if vision_pos_embeds is not None:
            vision_pos_embeds = list(vision_pos_embeds)

        # TODO: See if we can remove that reshaping and just use the features as is.
        if spatial_sizes is not None:
            for i, (height, width) in enumerate(spatial_sizes):
                # Reshape from [height*width, batch_size, channels] to [batch_size, channels, height, width]
                vision_features[i] = vision_features[i].reshape(height, width, batch_size, -1).permute(2, 3, 0, 1)
                if vision_pos_embeds is not None:
                    vision_pos_embeds[i] = (
                        vision_pos_embeds[i].reshape(height, width, batch_size, -1).permute(2, 3, 0, 1)
                    )

        # Position embeddings are optional: fall back to zeros (built after reshaping so the shapes match).
        if vision_pos_embeds is None:
            vision_pos_embeds = [torch.zeros_like(features) for features in vision_features]

        # Flatten multi-level features for encoder processing
        (
            features_flattened,
            pos_embeds_flattened,
            spatial_shapes,
        ) = self._prepare_multilevel_features(vision_features, vision_pos_embeds)

        prompt_cross_attn_mask = None
        if text_mask is not None:
            prompt_cross_attn_mask = create_bidirectional_mask(
                config=self.config,
                inputs_embeds=features_flattened,
                attention_mask=text_mask,
                encoder_hidden_states=text_features,
            )

        hidden_states = features_flattened
        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                prompt_feats=text_features,
                vision_pos_encoding=pos_embeds_flattened,
                prompt_cross_attn_mask=prompt_cross_attn_mask,
                **kwargs,
            )
        return Sam3DETREncoderOutput(
            last_hidden_state=hidden_states,
            pos_embeds_flattened=pos_embeds_flattened,
            text_features=text_features,
            spatial_shapes=spatial_shapes,
        )
class Sam3DecoderMLP(nn.Module):
    """Simple 2 or 3-layer MLP for decoder components."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int = 2):
        super().__init__()
        if num_layers not in (2, 3):
            raise ValueError(f"Only 2 or 3 layers supported, got {num_layers}")

        three_layers = num_layers == 3
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        # In the 2-layer variant the second layer is already the output layer.
        self.layer2 = nn.Linear(hidden_dim, hidden_dim if three_layers else output_dim)
        self.layer3 = nn.Linear(hidden_dim, output_dim) if three_layers else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layers with ReLU between them; the final layer has no activation."""
        hidden = F.relu(self.layer1(x))
        if self.layer3 is None:
            return self.layer2(hidden)
        return self.layer3(F.relu(self.layer2(hidden)))
class Sam3DetrDecoderLayer(nn.Module):
    """
    Single DETR decoder layer.

    Runs four sub-blocks in order: self-attention over the query sequence, cross-attention to text features,
    cross-attention to vision features, and an MLP. Each sub-block is post-norm: its output goes through dropout,
    is added to the sub-block input, and the sum is layer-normalized.
    """

    def __init__(self, config: Sam3DETRDecoderConfig):
        super().__init__()
        self.config = config
        hidden_size = config.hidden_size
        dropout_prob = config.dropout

        # Query self-attention
        self.self_attn = Sam3Attention(config)
        self.self_attn_dropout = nn.Dropout(dropout_prob)
        self.self_attn_layer_norm = nn.LayerNorm(hidden_size)

        # Queries -> text cross-attention
        self.text_cross_attn = Sam3Attention(config)
        self.text_cross_attn_dropout = nn.Dropout(dropout_prob)
        self.text_cross_attn_layer_norm = nn.LayerNorm(hidden_size)

        # Queries -> vision cross-attention
        self.vision_cross_attn = Sam3Attention(config)
        self.vision_cross_attn_dropout = nn.Dropout(dropout_prob)
        self.vision_cross_attn_layer_norm = nn.LayerNorm(hidden_size)

        # Feed-forward
        self.mlp = Sam3MLP(config)
        self.mlp_layer_norm = nn.LayerNorm(hidden_size)
        self.mlp_dropout = nn.Dropout(dropout_prob)

    def forward(
        self,
        hidden_states: torch.Tensor,
        query_pos: torch.Tensor,
        text_features: torch.Tensor,
        vision_features: torch.Tensor,
        vision_pos_encoding: torch.Tensor,
        text_cross_attn_mask: torch.Tensor | None = None,
        vision_cross_attn_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Run the decoder layer on the query sequence.

        Args:
            hidden_states: Query features [batch_size, num_queries + 1, hidden_size]; position 0 is the presence
                token.
            query_pos: Query position embeddings [batch_size, num_queries, hidden_size] (without the presence token).
            text_features: Text features [batch_size, seq_len, hidden_size].
            vision_features: Vision features [batch_size, height*width, hidden_size].
            vision_pos_encoding: Vision position encoding [batch_size, height*width, hidden_size].
            text_cross_attn_mask: Attention mask forwarded to the text cross-attention.
            vision_cross_attn_mask: Attention mask forwarded to the vision cross-attention; expected to already
                include a row for the presence token.

        Returns:
            Updated hidden states, same layout as `hidden_states` (presence token at position 0).
        """
        # Left-pad the query axis with one zero row so the presence token receives no positional signal.
        query_pos = F.pad(query_pos, (0, 0, 1, 0), mode="constant", value=0)

        # 1) Self-attention: positions are added to queries and keys, values stay position-free.
        positioned_queries = hidden_states + query_pos
        self_attn_out, _ = self.self_attn(
            query=positioned_queries,
            key=positioned_queries,
            value=hidden_states,
            attention_mask=None,
            **kwargs,
        )
        hidden_states = self.self_attn_layer_norm(hidden_states + self.self_attn_dropout(self_attn_out))

        # 2) Text cross-attention: positioned queries attend to the text features.
        text_attn_out, _ = self.text_cross_attn(
            query=hidden_states + query_pos,
            key=text_features,
            value=text_features,
            attention_mask=text_cross_attn_mask,
            **kwargs,
        )
        hidden_states = self.text_cross_attn_layer_norm(hidden_states + self.text_cross_attn_dropout(text_attn_out))

        # 3) Vision cross-attention: keys carry the vision position encoding, values do not.
        vision_attn_out, _ = self.vision_cross_attn(
            query=hidden_states + query_pos,
            key=vision_features + vision_pos_encoding,
            value=vision_features,
            attention_mask=vision_cross_attn_mask,
            **kwargs,
        )
        hidden_states = self.vision_cross_attn_layer_norm(
            hidden_states + self.vision_cross_attn_dropout(vision_attn_out)
        )

        # 4) Feed-forward
        mlp_out = self.mlp(hidden_states)
        return self.mlp_layer_norm(hidden_states + self.mlp_dropout(mlp_out))
class Sam3DetrDecoder(Sam3PreTrainedModel):
    """
    DETR-style decoder with box refinement and presence token.

    Simplified version that assumes:
    - Box refinement is always enabled
    - Intermediate outputs are always returned
    - BoxRPB (relative position bias) with log-scale encoding
    - Presence token is used
    """

    _can_record_outputs = {
        "hidden_states": Sam3DetrDecoderLayer,
        "attentions": Sam3Attention,
    }

    def __init__(
        self,
        config: Sam3DETRDecoderConfig,
    ):
        super().__init__(config)
        self.config = config
        self.hidden_size = config.hidden_size

        self.layers = nn.ModuleList([Sam3DetrDecoderLayer(config) for _ in range(config.num_layers)])
        self.output_layer_norm = nn.LayerNorm(config.hidden_size)

        # Box refinement head (predicts a 4-dim delta) and learned queries / initial reference boxes
        self.box_head = Sam3DecoderMLP(config.hidden_size, config.hidden_size, 4, 3)
        self.query_embed = nn.Embedding(config.num_queries, config.hidden_size)
        self.reference_points = nn.Embedding(config.num_queries, 4)

        # Presence token and its scalar head
        self.presence_token = nn.Embedding(1, config.hidden_size)
        self.presence_head = Sam3DecoderMLP(config.hidden_size, config.hidden_size, 1, 3)
        self.presence_layer_norm = nn.LayerNorm(config.hidden_size)
        self.clamp_presence_logit_max_val = 10.0

        # Maps sine-encoded reference boxes to query position embeddings
        self.ref_point_head = Sam3DecoderMLP(2 * config.hidden_size, config.hidden_size, config.hidden_size, 2)

        # Per-axis embeddings of box-relative offsets; one bias value per attention head
        self.box_rpb_embed_x = Sam3DecoderMLP(2, config.hidden_size, config.num_attention_heads, 2)
        self.box_rpb_embed_y = Sam3DecoderMLP(2, config.hidden_size, config.num_attention_heads, 2)

        self.position_encoding = Sam3SinePositionEmbedding(num_pos_feats=config.hidden_size // 2, normalize=False)

        self.post_init()

    @compile_compatible_method_lru_cache(maxsize=1)
    def _get_coords(
        self, height: torch.Tensor, width: torch.Tensor, dtype: torch.dtype, device: torch.device
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return 1D coordinate vectors `arange(height) / height` and `arange(width) / width`."""
        coords_h = torch.arange(0, height, device=device, dtype=dtype) / height
        coords_w = torch.arange(0, width, device=device, dtype=dtype) / width
        return coords_h, coords_w

    @staticmethod
    def _log_scale_encode(deltas: torch.Tensor) -> torch.Tensor:
        """Signed log compression of offsets: `sign(8d) * log2(|8d| + 1) / log2(8)`."""
        scaled = deltas * 8
        return torch.sign(scaled) * torch.log2(torch.abs(scaled) + 1.0) / math.log2(8)

    def _get_rpb_matrix(
        self, reference_boxes: torch.Tensor, spatial_shape: tuple[torch.Tensor, torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute box relative position bias (RPB) matrix using log-scale encoding.

        RPB helps the decoder attend to relevant spatial locations based on predicted box positions.

        Args:
            reference_boxes: Reference boxes [batch_size, num_queries, 4] in sigmoid space
            spatial_shape: (height, width) of the vision features as tensors

        Returns:
            RPB matrix [batch_size, num_heads, num_queries, height*width]
        """
        height, width = spatial_shape
        boxes_xyxy = box_cxcywh_to_xyxy(reference_boxes)
        batch_size, num_queries, _ = boxes_xyxy.shape

        # Generate coordinate grids
        coords_h, coords_w = self._get_coords(
            height, width, dtype=reference_boxes.dtype, device=reference_boxes.device
        )

        # Offsets between every grid coordinate and the two box edges on that axis.
        # Channels 1 and 3 are used for the vertical axis, channels 0 and 2 for the horizontal axis
        # (presumably (x1, y1, x2, y2) ordering from `box_cxcywh_to_xyxy` - confirm against that helper).
        flat_boxes = boxes_xyxy.reshape(-1, 1, 4)
        deltas_y = coords_h.view(1, -1, 1) - flat_boxes[:, :, 1:4:2]
        deltas_y = deltas_y.view(batch_size, num_queries, -1, 2)
        deltas_x = coords_w.view(1, -1, 1) - flat_boxes[:, :, 0:3:2]
        deltas_x = deltas_x.view(batch_size, num_queries, -1, 2)

        # Apply log-scale encoding, then embed each axis separately
        deltas_x = self.box_rpb_embed_x(self._log_scale_encode(deltas_x))  # [batch_size, num_queries, width, num_heads]
        deltas_y = self.box_rpb_embed_y(self._log_scale_encode(deltas_y))  # [batch_size, num_queries, height, num_heads]

        # Combine into 2D bias matrix by broadcasting the two axes against each other
        rpb_matrix = deltas_y.unsqueeze(3) + deltas_x.unsqueeze(
            2
        )  # [batch_size, num_queries, height, width, num_heads]
        rpb_matrix = rpb_matrix.flatten(2, 3)  # [batch_size, num_queries, height*width, num_heads]
        rpb_matrix = rpb_matrix.permute(0, 3, 1, 2).contiguous()  # [batch_size, num_heads, num_queries, height*width]
        return rpb_matrix

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        vision_features: torch.Tensor,
        text_features: torch.Tensor,
        vision_pos_encoding: torch.Tensor,
        text_mask: torch.Tensor | None = None,
        spatial_shapes: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Sam3DETRDecoderOutput:
        """
        Forward pass for the DETR decoder.

        Args:
            vision_features: Vision features [batch_size, height*width, hidden_size]
            text_features: Text features [batch_size, seq_len, hidden_size]
            vision_pos_encoding: Vision position encoding [batch_size, height*width, hidden_size]
            text_mask: Text padding mask [batch_size, seq_len] where True=valid, False=padding
            spatial_shapes: Spatial shapes [num_levels, 2]; the RPB bias is only built when exactly one level is
                given.

        Returns:
            Sam3DETRDecoderOutput with, stacked over layers: the layer-normalized query hidden states, the
            reference boxes each layer started from, and the clamped presence logits.
        """
        batch_size = vision_features.shape[0]

        query_embeds = self.query_embed.weight.unsqueeze(0).expand(batch_size, -1, -1)
        reference_boxes = self.reference_points.weight.unsqueeze(0).expand(batch_size, -1, -1)
        reference_boxes = reference_boxes.sigmoid()
        presence_token = self.presence_token.weight.unsqueeze(0).expand(batch_size, -1, -1)

        # Concatenate presence token with query embeddings (presence token sits at position 0)
        hidden_states = torch.cat([presence_token, query_embeds], dim=1)

        text_cross_attn_mask = None
        if text_mask is not None:
            text_cross_attn_mask = create_bidirectional_mask(
                config=self.config,
                inputs_embeds=hidden_states,
                attention_mask=text_mask,
                encoder_hidden_states=text_features,
            )

        intermediate_outputs = []
        intermediate_boxes = [reference_boxes]
        intermediate_presence_logits = []

        for layer in self.layers:
            # Generate sine embeddings for conditional queries from the current reference boxes
            query_sine_embed = self.position_encoding.encode_boxes(reference_boxes)
            query_pos = self.ref_point_head(query_sine_embed)

            # Compute box relative position bias (RPB) attention mask
            vision_cross_attn_mask = None
            if spatial_shapes is not None and spatial_shapes.shape[0] == 1:
                spatial_shape = (spatial_shapes[0, 0], spatial_shapes[0, 1])
                rpb_matrix = self._get_rpb_matrix(reference_boxes, spatial_shape)
                # Prepend zeros row for presence token (it attends to all vision tokens equally)
                vision_cross_attn_mask = F.pad(rpb_matrix, (0, 0, 1, 0), mode="constant", value=0)

            hidden_states = layer(
                hidden_states,
                query_pos=query_pos,
                text_features=text_features,
                vision_features=vision_features,
                vision_pos_encoding=vision_pos_encoding,
                text_cross_attn_mask=text_cross_attn_mask,
                vision_cross_attn_mask=vision_cross_attn_mask,
                **kwargs,
            )

            # Query hidden states (without presence token), normalized once and shared by the box head
            # and the per-layer output.
            normed_queries = self.output_layer_norm(hidden_states[:, 1:])

            # Box refinement: predict delta in logit space and update reference boxes
            reference_boxes_before_sigmoid = inverse_sigmoid(reference_boxes)
            delta_boxes = self.box_head(normed_queries)
            new_reference_boxes = (delta_boxes + reference_boxes_before_sigmoid).sigmoid()
            # The next layer starts from a detached copy; the attached tensor is what gets recorded.
            reference_boxes = new_reference_boxes.detach()

            intermediate_outputs.append(normed_queries)
            intermediate_boxes.append(new_reference_boxes)

            # Process presence token
            presence_hidden = hidden_states[:, :1]
            presence_logits = self.presence_head(self.presence_layer_norm(presence_hidden)).squeeze(-1)
            presence_logits = presence_logits.clamp(
                min=-self.clamp_presence_logit_max_val, max=self.clamp_presence_logit_max_val
            )
            intermediate_presence_logits.append(presence_logits)

        # Stack outputs from all layers. `intermediate_boxes` holds the initial boxes plus one entry per layer;
        # dropping the last entry keeps, for each layer, the boxes that layer started from.
        intermediate_outputs = torch.stack(intermediate_outputs)
        intermediate_boxes = torch.stack(intermediate_boxes[:-1])
        intermediate_presence_logits = torch.stack(intermediate_presence_logits)

        return Sam3DETRDecoderOutput(
            intermediate_hidden_states=intermediate_outputs,
            reference_boxes=intermediate_boxes,
            presence_logits=intermediate_presence_logits,
        )
class Sam3DotProductScoring(nn.Module):
    """
    Scores each decoder query against the prompt by a scaled dot product.

    Text features are refined with a residual MLP and layer norm, mean-pooled over valid tokens, projected, and
    compared with the projected decoder queries. The resulting logits are optionally clamped.
    """

    def __init__(self, config: Sam3Config):
        super().__init__()
        self.config = config
        decoder_config = config.detr_decoder_config
        hidden_size = decoder_config.hidden_size
        projection_dim = decoder_config.hidden_size

        # Residual refinement of the text features
        self.text_mlp = Sam3DecoderMLP(
            input_dim=hidden_size,
            hidden_dim=decoder_config.intermediate_size,
            output_dim=hidden_size,
            num_layers=2,
        )
        self.text_mlp_dropout = nn.Dropout(decoder_config.dropout)
        self.text_mlp_out_norm = nn.LayerNorm(hidden_size)

        # Linear maps into the shared scoring space
        self.text_proj = nn.Linear(hidden_size, projection_dim)
        self.query_proj = nn.Linear(hidden_size, projection_dim)

        # 1 / sqrt(d) scaling of the dot product
        self.scale = float(1.0 / np.sqrt(projection_dim))

        # Logits are clamped to [-clamp_max_val, clamp_max_val] to avoid numerical issues
        self.clamp_logits = True
        self.clamp_max_val = 12.0

    def _pool_text_features(self, text_features: torch.Tensor, text_mask: torch.Tensor | None) -> torch.Tensor:
        """
        Average text features over the sequence axis, ignoring padded positions.

        Args:
            text_features: [batch_size, seq_len, hidden_size]
            text_mask: [batch_size, seq_len] where True indicates valid tokens, False indicates padding. When
                `None`, a plain mean over the sequence is returned.

        Returns:
            pooled_text: [batch_size, hidden_size]
        """
        if text_mask is None:
            return text_features.mean(dim=1)

        token_weights = text_mask.to(text_features.dtype).unsqueeze(-1)  # [batch_size, seq_len, 1]
        # Clamp the count to at least 1 so a fully padded row divides by 1 instead of 0.
        valid_count = token_weights.sum(dim=1).clamp(min=1.0)  # [batch_size, 1]
        masked_sum = (text_features * token_weights).sum(dim=1)  # [batch_size, hidden_size]
        return masked_sum / valid_count

    def forward(
        self,
        decoder_hidden_states: torch.Tensor,
        text_features: torch.Tensor,
        text_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Compute per-query classification logits via dot product with the pooled text.

        Args:
            decoder_hidden_states: [num_layers, batch_size, num_queries, hidden_size]
            text_features: [batch_size, seq_len, hidden_size]
            text_mask: [batch_size, seq_len] where True=valid, False=padding

        Returns:
            scores: [num_layers, batch_size, num_queries, 1]
        """
        # Residual MLP followed by layer norm
        refined_text = self.text_mlp_dropout(self.text_mlp(text_features)) + text_features
        refined_text = self.text_mlp_out_norm(refined_text)

        pooled_text = self._pool_text_features(refined_text, text_mask)

        text_vector = self.text_proj(pooled_text)
        query_vectors = self.query_proj(decoder_hidden_states)

        # [batch_size, dim] -> [1, batch_size, dim, 1] so it broadcasts over the layer axis in the matmul
        text_column = text_vector.unsqueeze(-1).unsqueeze(0)
        scores = torch.matmul(query_vectors, text_column) * self.scale

        if self.clamp_logits:
            scores = scores.clamp(min=-self.clamp_max_val, max=self.clamp_max_val)
        return scores
class Sam3MaskEmbedder(nn.Module):
    """
    Three-layer MLP that turns object queries into mask embeddings.

    All layers keep the hidden size; ReLU follows every layer except the last. Similar to MaskFormer's mask
    embedder.
    """

    def __init__(self, config: Sam3MaskDecoderConfig):
        super().__init__()
        self.config = config
        width = config.hidden_size
        self.layers = nn.ModuleList([nn.Linear(width, width) for _ in range(3)])
        self.activation = nn.ReLU()

    def forward(self, queries: torch.Tensor) -> torch.Tensor:
        """
        Args:
            queries: Query embeddings [batch_size, num_queries, hidden_size]

        Returns:
            Mask embeddings [batch_size, num_queries, hidden_size]
        """
        *inner_layers, output_layer = self.layers
        embedded = queries
        for linear in inner_layers:
            embedded = self.activation(linear(embedded))
        # The final projection is left without an activation.
        return output_layer(embedded)
class Sam3PixelDecoder(nn.Module):
    """
    Feature Pyramid Network (FPN) style decoder that fuses multi-scale features into one pixel-level feature map.

    Each fusion stage is: nearest-neighbour resize, skip addition, 3x3 conv, GroupNorm (8 groups), ReLU.
    Inspired by MaskFormer's pixel decoder.
    """

    def __init__(self, config: Sam3MaskDecoderConfig):
        super().__init__()
        self.config = config
        channels = config.hidden_size
        num_stages = config.num_upsampling_stages

        # One conv + norm pair per fusion stage
        self.conv_layers = nn.ModuleList(
            [nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1) for _ in range(num_stages)]
        )
        self.norms = nn.ModuleList([nn.GroupNorm(8, channels) for _ in range(num_stages)])

        self.out_channels = channels

    def forward(self, backbone_features: list[torch.Tensor]) -> torch.Tensor:
        """
        Args:
            backbone_features: List of feature maps [batch_size, hidden_size, H_i, W_i], already projected to
                hidden_size. The last entry seeds the running map; the remaining entries are visited from
                last to first, and the running map is resized to each entry's spatial size before the skip addition.
                NOTE(review): this implies the list is ordered finest first, coarsest last - confirm with callers.

        Returns:
            Pixel embeddings [batch_size, hidden_size, H, W] at the spatial size of the first list entry
            (or the single entry unchanged when the list has one element).
        """
        fused = backbone_features[-1]
        for stage, skip in enumerate(reversed(backbone_features[:-1])):
            # Bring the running map to the skip feature's resolution, then add the skip connection.
            fused = F.interpolate(fused, size=skip.shape[-2:], mode="nearest") + skip
            fused = F.relu(self.norms[stage](self.conv_layers[stage](fused)))
        return fused
class Sam3MaskDecoder(Sam3PreTrainedModel):
    """
    Mask decoder that combines object queries with pixel-level features to predict instance masks.

    Also produces a semantic segmentation output and supports cross-attention to prompts.
    """

    _can_record_outputs = {
        "attentions": Sam3Attention,
    }

    def __init__(self, config: Sam3MaskDecoderConfig):
        super().__init__(config)
        self.config = config
        hidden_size = config.hidden_size

        # Pixel decoder (FPN)
        self.pixel_decoder = Sam3PixelDecoder(config)

        # Mask embedder (MLP to transform queries)
        self.mask_embedder = Sam3MaskEmbedder(config)

        # Projection from pixel decoder output to mask embedding space
        self.instance_projection = nn.Conv2d(self.pixel_decoder.out_channels, hidden_size, kernel_size=1)

        # Semantic segmentation head: single-channel 1x1 projection of the pixel embeddings
        self.semantic_projection = nn.Conv2d(self.pixel_decoder.out_channels, 1, kernel_size=1)

        # Pre-norm cross-attention from encoder features to prompt features
        self.prompt_cross_attn = Sam3Attention(config)
        self.prompt_cross_attn_norm = nn.LayerNorm(hidden_size)
        self.prompt_cross_attn_dropout = nn.Dropout(config.dropout)

        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        decoder_queries: torch.Tensor,
        backbone_features: list[torch.Tensor],
        encoder_hidden_states: torch.Tensor,
        prompt_features: torch.Tensor | None = None,
        prompt_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Sam3MaskDecoderOutput:
        """
        Args:
            decoder_queries: Decoder output queries [batch_size, num_queries, hidden_size]
            backbone_features: List of backbone features to process through FPN
            encoder_hidden_states: Encoder outputs [batch_size, seq_len, hidden_size]
            prompt_features: Prompt features (text + geometry) for cross-attention [batch_size, prompt_len, hidden_size]
            prompt_mask: Padding mask [batch_size, prompt_len] where True=valid, False=padding

        Returns:
            Sam3MaskDecoderOutput containing predicted masks and semantic segmentation.
        """
        if prompt_features is not None:
            # Cross-attention: encoder features attend to prompt features (pre-norm, residual added afterwards)
            residual = encoder_hidden_states
            normed_hidden_states = self.prompt_cross_attn_norm(encoder_hidden_states)

            cross_attn_mask = None
            if prompt_mask is not None:
                cross_attn_mask = create_bidirectional_mask(
                    config=self.config,
                    inputs_embeds=normed_hidden_states,
                    encoder_hidden_states=prompt_features,
                    attention_mask=prompt_mask,
                )

            attn_output, _ = self.prompt_cross_attn(
                query=normed_hidden_states,
                key=prompt_features,
                value=prompt_features,
                attention_mask=cross_attn_mask,
                **kwargs,
            )
            encoder_hidden_states = residual + self.prompt_cross_attn_dropout(attn_output)

        # Process backbone features through FPN to get pixel embeddings
        pixel_embed = self._embed_pixels(
            backbone_features=backbone_features,
            encoder_hidden_states=encoder_hidden_states,
        )

        # Predict instance masks via dot product between query embeddings and pixel embeddings
        instance_embeds = self.instance_projection(pixel_embed)
        mask_embeddings = self.mask_embedder(decoder_queries)
        pred_masks = torch.einsum("bqc,bchw->bqhw", mask_embeddings, instance_embeds)

        # Generate semantic segmentation
        semantic_seg = self.semantic_projection(pixel_embed)

        return Sam3MaskDecoderOutput(
            pred_masks=pred_masks,
            semantic_seg=semantic_seg,
        )

    def _embed_pixels(
        self,
        backbone_features: list[torch.Tensor],
        encoder_hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """
        Embed pixels by combining backbone FPN features with encoder vision features.

        The first `H * W` encoder tokens (with `H, W` taken from the last backbone feature) are reshaped to a
        feature map and used in place of the last backbone feature, i.e. the level the pixel decoder starts from.

        Args:
            backbone_features: List of backbone features [batch_size, C, H_i, W_i]
            encoder_hidden_states: Encoder outputs [batch_size, seq_len, hidden_size]

        Returns:
            Pixel embeddings [batch_size, hidden_size, H, W]
        """
        # Extract vision features from encoder output and reshape to spatial format
        height, width = backbone_features[-1].shape[-2:]
        encoder_visual_embed = encoder_hidden_states[:, : height * width, :]
        batch_size, _, hidden_size = encoder_visual_embed.shape
        encoder_visual_embed = encoder_visual_embed.transpose(1, 2).reshape(batch_size, hidden_size, height, width)

        # Swap the last level for the encoder features. A new list is built so the caller's list is untouched;
        # the tensors themselves are not copied because the pixel decoder only reads its inputs.
        pixel_decoder_inputs = [*backbone_features[:-1], encoder_visual_embed]

        # Process through FPN decoder
        return self.pixel_decoder(pixel_decoder_inputs)
class Sam3Model(Sam3PreTrainedModel):
    """
    Promptable detection and segmentation model.

    Wires together a vision encoder, a CLIP text encoder with a projection to the DETR hidden size, a geometry
    (box prompt) encoder, a DETR encoder/decoder pair, dot-product query scoring, and a mask decoder.
    """

    input_modalities = ["image", "text"]
    base_model_prefix = "detector_model"
    _keys_to_ignore_on_load_unexpected = [
        r"^tracker_model.",
        r"^tracker_neck.",
    ]

    def __init__(self, config: Sam3Config):
        # loading from a sam3_video config: unwrap the nested detector config (possibly given as a dict)
        if hasattr(config, "detector_config") and config.detector_config is not None:
            detector_config = config.detector_config
            if isinstance(detector_config, dict):
                detector_config = Sam3Config(**detector_config)
            config = detector_config
        super().__init__(config)
        self.vision_encoder = Sam3VisionModel(config.vision_config)
        self.text_encoder = CLIPTextModelWithProjection(config.text_config)
        self.vocab_size = config.text_config.vocab_size

        # Project text features from text encoder hidden size to model hidden size
        # CLIP text encoder outputs 1024-dim features, but we need 256-dim for DETR
        self.text_projection = nn.Linear(config.text_config.hidden_size, config.detr_encoder_config.hidden_size)

        # Pass _attn_implementation to subconfigs BEFORE creating modules
        config.geometry_encoder_config._attn_implementation = config._attn_implementation
        config.detr_encoder_config._attn_implementation = config._attn_implementation
        config.detr_decoder_config._attn_implementation = config._attn_implementation
        config.mask_decoder_config._attn_implementation = config._attn_implementation

        self.geometry_encoder = Sam3GeometryEncoder(config.geometry_encoder_config)
        self.detr_encoder = Sam3DetrEncoder(config.detr_encoder_config)
        self.detr_decoder = Sam3DetrDecoder(config.detr_decoder_config)
        self.mask_decoder = Sam3MaskDecoder(config.mask_decoder_config)

        # Dot product scoring to compute classification scores
        self.dot_product_scoring = Sam3DotProductScoring(config)

        self.post_init()

    @can_return_tuple
    @auto_docstring
    def get_text_features(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        Example:

        ```python
        >>> from transformers import Sam3Model, Sam3Processor
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> model = Sam3Model.from_pretrained("facebook/sam3")
        >>> processor = Sam3Processor.from_pretrained("facebook/sam3")

        >>> # Pre-compute text embeddings
        >>> text_inputs = processor(text="cat", return_tensors="pt")
        >>> text_embeds = model.get_text_features(**text_inputs).pooler_output

        >>> # Reuse text embeddings for multiple images
        >>> url = "http://images.cocodataset.org/val2017/000000077595.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))
        >>> img_inputs = processor(images=image, return_tensors="pt")
        >>> outputs = model(pixel_values=img_inputs.pixel_values, text_embeds=text_embeds)
        ```
        """
        text_outputs = self.text_encoder(
            input_ids=input_ids, attention_mask=attention_mask, return_dict=True, **kwargs
        )
        # `pooler_output` is overwritten with the per-token hidden states projected to the DETR hidden size.
        last_hidden_state = text_outputs.last_hidden_state
        text_outputs.pooler_output = self.text_projection(last_hidden_state)

        return text_outputs

    @auto_docstring
    def get_vision_features(
        self,
        pixel_values: torch.FloatTensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Sam3VisionEncoderOutput:
        r"""
        Example:

        ```python
        >>> from transformers import Sam3Model, Sam3Processor
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO

        >>> model = Sam3Model.from_pretrained("facebook/sam3")
        >>> processor = Sam3Processor.from_pretrained("facebook/sam3")

        >>> # Pre-compute vision embeddings
        >>> url = "http://images.cocodataset.org/val2017/000000077595.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))
        >>> img_inputs = processor(images=image, return_tensors="pt")
        >>> vision_embeds = model.get_vision_features(pixel_values=img_inputs.pixel_values)

        >>> # Reuse vision embeddings for multiple text prompts
        >>> text_inputs = processor(text="cat", return_tensors="pt")
        >>> outputs = model(vision_embeds=vision_embeds, input_ids=text_inputs.input_ids)
        ```
        """
        vision_outputs = self.vision_encoder(pixel_values, **kwargs)
        return vision_outputs

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        vision_embeds: Sam3VisionEncoderOutput | None = None,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        text_embeds: torch.FloatTensor | None = None,
        input_boxes: torch.FloatTensor | None = None,
        input_boxes_labels: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Sam3ImageSegmentationOutput:
        r"""
        vision_embeds (`Sam3VisionEncoderOutput`, *optional*):
            Pre-computed vision embeddings. Can be used to easily reuse vision embeddings. If provided, `pixel_values`
            should not be passed. Mutually exclusive with `pixel_values`.
        text_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Pre-computed text embeddings. Can be used to easily reuse text embeddings. If provided, `input_ids`
            should not be passed. Mutually exclusive with `input_ids`.
        input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`, *optional*):
            Normalized box coordinates in [0, 1] range, in (cx, cy, w, h) format.
        input_boxes_labels (`torch.LongTensor` of shape `(batch_size, num_boxes)`, *optional*):
            Labels for boxes: 1 (positive), 0 (negative).

        Example:

        ```python
        >>> from PIL import Image
        >>> import httpx
        >>> from io import BytesIO
        >>> from transformers import AutoModel, AutoProcessor

        >>> model = AutoModel.from_pretrained("facebook/sam3")
        >>> processor = AutoProcessor.from_pretrained("facebook/sam3")

        >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read())).convert("RGB")
        >>> text = "car"
        >>> inputs = processor(images=image, text=text, return_tensors="pt")

        >>> # Get segmentation output
        >>> outputs = model(**inputs)
        >>> pred_masks = outputs.pred_masks
        >>> pred_boxes = outputs.pred_boxes
        ```
        """
        # Exactly one source per modality: raw inputs or pre-computed embeddings.
        if (pixel_values is None) == (vision_embeds is None):
            raise ValueError("You must specify exactly one of pixel_values or vision_embeds")

        if (input_ids is None) == (text_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or text_embeds")

        if pixel_values is not None:
            batch_size = pixel_values.shape[0]
            device = pixel_values.device
        else:
            batch_size = vision_embeds.fpn_hidden_states[0].shape[0]
            device = vision_embeds.fpn_hidden_states[0].device

        if vision_embeds is None:
            vision_outputs = self.vision_encoder(pixel_values, **kwargs)
        else:
            vision_outputs = vision_embeds

        # The last FPN level is not used by the detector heads below.
        fpn_hidden_states = vision_outputs.fpn_hidden_states[:-1]
        fpn_position_encoding = vision_outputs.fpn_position_encoding[:-1]

        if text_embeds is None:
            text_features = self.get_text_features(
                input_ids=input_ids, attention_mask=attention_mask, return_dict=True
            ).pooler_output
        else:
            text_features = text_embeds

        text_mask = attention_mask.bool() if attention_mask is not None else None
        has_geometry_prompts = input_boxes is not None and input_boxes.numel() > 0

        geometry_prompt_features = None
        geometry_prompt_mask = None

        if has_geometry_prompts:
            box_embeddings = input_boxes  # [batch_size, num_boxes, 4]
            if input_boxes_labels is not None:
                box_labels = input_boxes_labels
                # Label -10 marks a padded box: masked out here and remapped to label 0 below.
                box_mask = input_boxes_labels != -10
            else:
                # Without labels, every box is treated as a valid positive box.
                box_labels = torch.ones_like(box_embeddings[..., 0], dtype=torch.long)
                box_mask = torch.ones(batch_size, input_boxes.shape[1], dtype=torch.bool, device=device)
            box_labels = torch.where(box_labels == -10, 0, box_labels)

            geometry_outputs = self.geometry_encoder(
                box_embeddings=box_embeddings,
                box_mask=box_mask,
                box_labels=box_labels,
                img_feats=fpn_hidden_states,
                img_pos_embeds=fpn_position_encoding,
            )

            geometry_prompt_features = geometry_outputs.last_hidden_state
            geometry_prompt_mask = geometry_outputs.attention_mask

        if geometry_prompt_features is not None:
            # Repeat a single text prompt across the geometry-prompt batch. The batch size is read from the
            # geometry features (always available here) rather than from the optional geometry mask.
            num_prompt_batches = geometry_prompt_features.shape[0]
            if text_features.shape[0] == 1 and num_prompt_batches > 1:
                text_features = text_features.repeat(num_prompt_batches, 1, 1)
            combined_prompt_features = torch.cat([text_features, geometry_prompt_features], dim=1)

            if text_mask is not None and text_mask.shape[0] == 1 and num_prompt_batches > 1:
                text_mask = text_mask.repeat(num_prompt_batches, 1)

            # Build the combined padding mask; a missing side is treated as fully valid.
            if text_mask is not None and geometry_prompt_mask is not None:
                combined_prompt_mask = torch.cat([text_mask, geometry_prompt_mask], dim=1)
            elif text_mask is not None:
                geo_valid_mask = torch.ones(
                    batch_size, geometry_prompt_features.shape[1], dtype=torch.bool, device=device
                )
                combined_prompt_mask = torch.cat([text_mask, geo_valid_mask], dim=1)
            elif geometry_prompt_mask is not None:
                text_valid_mask = torch.ones(batch_size, text_features.shape[1], dtype=torch.bool, device=device)
                combined_prompt_mask = torch.cat([text_valid_mask, geometry_prompt_mask], dim=1)
            else:
                combined_prompt_mask = None
        else:
            combined_prompt_features = text_features
            combined_prompt_mask = text_mask

        # The DETR encoder only sees the last remaining FPN level.
        encoder_outputs = self.detr_encoder(
            vision_features=[fpn_hidden_states[-1]],
            text_features=combined_prompt_features,
            vision_pos_embeds=[fpn_position_encoding[-1]],
            text_mask=combined_prompt_mask,
            **kwargs,
        )

        decoder_outputs = self.detr_decoder(
            vision_features=encoder_outputs.last_hidden_state,
            text_features=encoder_outputs.text_features,
            vision_pos_encoding=encoder_outputs.pos_embeds_flattened,
            text_mask=combined_prompt_mask,
            spatial_shapes=encoder_outputs.spatial_shapes,
            **kwargs,
        )

        # Refine boxes from decoder: add predicted offsets to the reference boxes in logit space
        all_box_offsets = self.detr_decoder.box_head(decoder_outputs.intermediate_hidden_states)
        reference_boxes_inv_sig = inverse_sigmoid(decoder_outputs.reference_boxes)
        all_pred_boxes_cxcywh = (reference_boxes_inv_sig + all_box_offsets).sigmoid()
        all_pred_boxes = box_cxcywh_to_xyxy(all_pred_boxes_cxcywh)

        all_pred_logits = self.dot_product_scoring(
            decoder_hidden_states=decoder_outputs.intermediate_hidden_states,
            text_features=encoder_outputs.text_features,
            text_mask=combined_prompt_mask,
        ).squeeze(-1)

        # Only the last decoder layer's predictions are returned.
        pred_logits = all_pred_logits[-1]
        pred_boxes = all_pred_boxes[-1]
        decoder_hidden_states = decoder_outputs.intermediate_hidden_states[-1]
        presence_logits = decoder_outputs.presence_logits[-1]

        mask_outputs = self.mask_decoder(
            decoder_queries=decoder_hidden_states,
            backbone_features=list(fpn_hidden_states),
            encoder_hidden_states=encoder_outputs.last_hidden_state,
            prompt_features=combined_prompt_features,
            prompt_mask=combined_prompt_mask,
            **kwargs,
        )

        return Sam3ImageSegmentationOutput(
            pred_masks=mask_outputs.pred_masks,
            pred_boxes=pred_boxes,
            pred_logits=pred_logits,
            presence_logits=presence_logits,
            semantic_seg=mask_outputs.semantic_seg,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_reference_boxes=decoder_outputs.reference_boxes,
            encoder_hidden_states=encoder_outputs.hidden_states,
            vision_hidden_states=vision_outputs.hidden_states,
            vision_attentions=vision_outputs.attentions,
            detr_encoder_attentions=encoder_outputs.attentions,
            detr_decoder_attentions=decoder_outputs.attentions,
            mask_decoder_attentions=mask_outputs.attentions,
        )
# Names exported by a star-import of this module.
__all__ = [
    "Sam3Model",
    "Sam3VisionModel",
    "Sam3ViTModel",
    "Sam3PreTrainedModel",
]
|