# modeling_sam3.py (99 KB)
  1. # Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. from collections.abc import Callable, Iterable
  16. from dataclasses import dataclass
  17. import numpy as np
  18. import torch
  19. import torch.nn as nn
  20. import torch.nn.functional as F
  21. from torch import Tensor
  22. from ...utils import is_torchvision_available
  23. if is_torchvision_available():
  24. import torchvision
  25. from transformers import CLIPTextModelWithProjection
  26. from ... import initialization as init
  27. from ...activations import ACT2FN
  28. from ...masking_utils import create_bidirectional_mask
  29. from ...modeling_layers import GradientCheckpointingLayer
  30. from ...modeling_outputs import (
  31. BaseModelOutput,
  32. BaseModelOutputWithPooling,
  33. ModelOutput,
  34. )
  35. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  36. from ...processing_utils import Unpack
  37. from ...pytorch_utils import compile_compatible_method_lru_cache
  38. from ...utils import auto_docstring, can_return_tuple, logging
  39. from ...utils.generic import (
  40. TransformersKwargs,
  41. is_flash_attention_requested,
  42. merge_with_config_defaults,
  43. )
  44. from ...utils.import_utils import requires
  45. from ...utils.output_capturing import capture_outputs
  46. from ..auto import AutoModel
  47. from .configuration_sam3 import (
  48. Sam3Config,
  49. Sam3DETRDecoderConfig,
  50. Sam3DETREncoderConfig,
  51. Sam3GeometryEncoderConfig,
  52. Sam3MaskDecoderConfig,
  53. Sam3VisionConfig,
  54. Sam3ViTConfig,
  55. )
  56. logger = logging.get_logger(__name__)
  57. @dataclass
  58. @auto_docstring
  59. class Sam3VisionEncoderOutput(BaseModelOutputWithPooling):
  60. r"""
  61. fpn_hidden_states (`tuple[torch.FloatTensor]`):
  62. Tuple of multi-level FPN feature maps.
  63. fpn_position_encoding (`tuple[torch.FloatTensor]`):
  64. Tuple of position encodings for each FPN level.
  65. """
  66. fpn_hidden_states: tuple[torch.FloatTensor, ...] = None
  67. fpn_position_encoding: tuple[torch.FloatTensor, ...] = None
  68. @dataclass
  69. @auto_docstring
  70. class Sam3GeometryEncoderOutput(ModelOutput):
  71. r"""
  72. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_prompts, hidden_size)`):
  73. Encoded geometry prompt features (boxes).
  74. attention_mask (`torch.BoolTensor` of shape `(batch_size, num_prompts)`, *optional*):
  75. Attention mask for geometry prompts where True indicates valid positions and False indicates padding.
  76. """
  77. last_hidden_state: torch.FloatTensor = None
  78. attention_mask: torch.BoolTensor | None = None
  79. @dataclass
  80. @auto_docstring
  81. class Sam3DETREncoderOutput(ModelOutput):
  82. r"""
  83. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  84. Encoded vision features (flattened from multi-level features).
  85. pos_embeds_flattened (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  86. Flattened position embeddings for the vision features.
  87. text_features (`torch.FloatTensor` of shape `(batch_size, text_seq_len, hidden_size)`, *optional*):
  88. Text features (may be pooled after encoder processing).
  89. spatial_shapes (`torch.LongTensor` of shape `(num_levels, 2)`, *optional*):
  90. Spatial shapes (height, width) for each feature pyramid level.
  91. hidden_states (`tuple[torch.FloatTensor]`, *optional*):
  92. Tuple of hidden states from all encoder layers.
  93. attentions (`tuple[torch.FloatTensor]`, *optional*):
  94. Tuple of attention weights from all encoder layers.
  95. """
  96. last_hidden_state: torch.FloatTensor = None
  97. pos_embeds_flattened: torch.FloatTensor | None = None
  98. text_features: torch.FloatTensor | None = None
  99. spatial_shapes: torch.LongTensor | None = None
  100. hidden_states: tuple[torch.FloatTensor] | None = None
  101. attentions: tuple[torch.FloatTensor] | None = None
  102. @dataclass
  103. @auto_docstring
  104. class Sam3DETRDecoderOutput(ModelOutput):
  105. r"""
  106. intermediate_hidden_states (`torch.FloatTensor` of shape `(num_layers, batch_size, num_queries, hidden_size)`):
  107. Decoder hidden states from all layers.
  108. reference_boxes (`torch.FloatTensor` of shape `(num_layers, batch_size, num_queries, 4)`):
  109. Predicted reference boxes from all decoder layers in (cx, cy, w, h) format.
  110. presence_logits (`torch.FloatTensor` of shape `(num_layers, batch_size, 1)`):
  111. Presence logits from all decoder layers indicating object presence confidence.
  112. hidden_states (`tuple[torch.FloatTensor]`, *optional*):
  113. Tuple of hidden states from all decoder layers.
  114. attentions (`tuple[torch.FloatTensor]`, *optional*):
  115. Tuple of attention weights from all decoder layers (self-attention and cross-attention).
  116. """
  117. intermediate_hidden_states: torch.FloatTensor = None
  118. reference_boxes: torch.FloatTensor = None
  119. presence_logits: torch.FloatTensor = None
  120. hidden_states: tuple[torch.FloatTensor] | None = None
  121. attentions: tuple[torch.FloatTensor] | None = None
  122. @dataclass
  123. @auto_docstring
  124. class Sam3MaskDecoderOutput(ModelOutput):
  125. r"""
  126. pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height, width)`):
  127. Predicted segmentation masks for each query.
  128. semantic_seg (`torch.FloatTensor` of shape `(batch_size, 1, height, width)`, *optional*):
  129. Semantic segmentation output.
  130. attentions (`tuple[torch.FloatTensor]`, *optional*):
  131. Tuple of attention weights from mask decoder cross-attention layers.
  132. """
  133. pred_masks: torch.FloatTensor = None
  134. semantic_seg: torch.FloatTensor | None = None
  135. attentions: tuple[torch.FloatTensor] | None = None
  136. @dataclass
  137. @auto_docstring
  138. class Sam3ImageSegmentationOutput(ModelOutput):
  139. r"""
  140. pred_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height, width)`):
  141. Predicted segmentation masks for each query.
  142. pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
  143. Predicted bounding boxes in (x1, y1, x2, y2) format.
  144. pred_logits (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
  145. Classification confidence scores for each query, computed via dot product between
  146. decoder query features and text features.
  147. presence_logits (`torch.FloatTensor` of shape `(batch_size, 1)`, *optional*):
  148. Presence logits from the DETR decoder presence token (last layer only). These indicate whether objects
  149. are present in the scene. Can be used to compute final scores by multiplying with pred_logits:
  150. `final_scores = pred_logits.sigmoid() * presence_logits.sigmoid()`.
  151. semantic_seg (`torch.FloatTensor` of shape `(batch_size, 1, height, width)`, *optional*):
  152. Semantic segmentation output.
  153. decoder_hidden_states (`tuple[torch.FloatTensor]`, *optional*):
  154. Tuple of hidden states from all DETR decoder layers. Each tensor has shape `(batch_size, num_queries, hidden_size)`.
  155. decoder_reference_boxes (`torch.FloatTensor` of shape `(num_layers, batch_size, num_queries, 4)`, *optional*):
  156. Reference boxes from all DETR decoder layers.
  157. encoder_hidden_states (`tuple[torch.FloatTensor]`, *optional*):
  158. Tuple of hidden states from all DETR encoder layers.
  159. vision_hidden_states (`tuple[torch.FloatTensor]`, *optional*):
  160. Tuple of hidden states from all vision encoder (ViT) layers.
  161. vision_attentions (`tuple[torch.FloatTensor]`, *optional*):
  162. Attention weights from vision encoder (ViT) layers.
  163. detr_encoder_attentions (`tuple[torch.FloatTensor]`, *optional*):
  164. Attention weights from DETR encoder layers.
  165. detr_decoder_attentions (`tuple[torch.FloatTensor]`, *optional*):
  166. Attention weights from DETR decoder layers (self-attention and cross-attention).
  167. mask_decoder_attentions (`tuple[torch.FloatTensor]`, *optional*):
  168. Attention weights from mask decoder layers.
  169. """
  170. pred_masks: torch.FloatTensor = None
  171. pred_boxes: torch.FloatTensor = None
  172. pred_logits: torch.FloatTensor | None = None
  173. presence_logits: torch.FloatTensor | None = None
  174. semantic_seg: torch.FloatTensor | None = None
  175. decoder_hidden_states: tuple[torch.FloatTensor] | None = None
  176. decoder_reference_boxes: torch.FloatTensor | None = None
  177. encoder_hidden_states: tuple[torch.FloatTensor] | None = None
  178. vision_hidden_states: tuple[torch.FloatTensor] | None = None
  179. vision_attentions: tuple[torch.FloatTensor] | None = None
  180. detr_encoder_attentions: tuple[torch.FloatTensor] | None = None
  181. detr_decoder_attentions: tuple[torch.FloatTensor] | None = None
  182. mask_decoder_attentions: tuple[torch.FloatTensor] | None = None
  183. def inverse_sigmoid(x: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
  184. """The inverse function for sigmoid activation function."""
  185. x = x.clamp(min=0, max=1)
  186. x1 = x.clamp(min=eps)
  187. x2 = (1 - x).clamp(min=eps)
  188. return torch.log(x1 / x2)
  189. def concat_padded_sequences(seq1, mask1, seq2, mask2, return_index: bool = False):
  190. """
  191. Concatenates two right-padded sequences, such that the resulting sequence
  192. is contiguous and also right-padded.
  193. Tensors are batch-first, masks are batch-first with True=valid, False=padding.
  194. Args:
  195. seq1: A tensor of shape (batch_size, seq1_length, hidden_size).
  196. mask1: A tensor of shape (batch_size, seq1_length) with True=valid, False=padding.
  197. seq2: A tensor of shape (batch_size, seq2_length, hidden_size).
  198. mask2: A tensor of shape (batch_size, seq2_length) with True=valid, False=padding.
  199. return_index: If True, also returns the index of the ids of the element of seq2
  200. in the concatenated sequence. This can be used to retrieve the elements of seq2.
  201. Returns:
  202. A tuple (concatenated_sequence, concatenated_mask) if return_index is False,
  203. otherwise (concatenated_sequence, concatenated_mask, index).
  204. The concatenated_mask uses True=valid, False=padding convention.
  205. """
  206. batch_size, seq1_length, hidden_size = seq1.shape
  207. batch_size2, seq2_length, hidden_size2 = seq2.shape
  208. assert batch_size == batch_size2 == mask1.size(0) == mask2.size(0)
  209. assert hidden_size == hidden_size2
  210. assert seq1_length == mask1.size(1)
  211. assert seq2_length == mask2.size(1)
  212. actual_seq1_lengths = mask1.sum(dim=-1)
  213. actual_seq2_lengths = mask2.sum(dim=-1)
  214. final_lengths = actual_seq1_lengths + actual_seq2_lengths
  215. max_length = seq1_length + seq2_length
  216. concatenated_mask = (
  217. torch.arange(max_length, device=seq2.device)[None].repeat(batch_size, 1) < final_lengths[:, None]
  218. )
  219. concatenated_sequence = torch.zeros((batch_size, max_length, hidden_size), device=seq2.device, dtype=seq2.dtype)
  220. concatenated_sequence[:, :seq1_length, :] = seq1
  221. # Shift seq2 elements to start at the end of valid seq1
  222. index = torch.arange(seq2_length, device=seq2.device)[None].repeat(batch_size, 1)
  223. index = index + actual_seq1_lengths[:, None]
  224. # Scatter seq2 into the right positions
  225. concatenated_sequence = concatenated_sequence.scatter(1, index[:, :, None].expand(-1, -1, hidden_size), seq2)
  226. if return_index:
  227. return concatenated_sequence, concatenated_mask, index
  228. return concatenated_sequence, concatenated_mask
  229. def box_cxcywh_to_xyxy(x):
  230. """Convert boxes from (cx, cy, w, h) format to (x1, y1, x2, y2) format."""
  231. x_c, y_c, w, h = x.unbind(-1)
  232. b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
  233. return torch.stack(b, dim=-1)
  234. class Sam3MLP(nn.Module):
  235. def __init__(self, config: Sam3ViTConfig):
  236. super().__init__()
  237. self.config = config
  238. self.activation_fn = ACT2FN[config.hidden_act]
  239. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  240. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  241. self.dropout = nn.Dropout(config.hidden_dropout)
  242. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  243. hidden_states = self.fc1(hidden_states)
  244. hidden_states = self.dropout(hidden_states)
  245. hidden_states = self.activation_fn(hidden_states)
  246. hidden_states = self.fc2(hidden_states)
  247. return hidden_states
  248. def eager_attention_forward(
  249. module: nn.Module,
  250. query: torch.Tensor,
  251. key: torch.Tensor,
  252. value: torch.Tensor,
  253. attention_mask: torch.Tensor | None,
  254. scaling: float | None = None,
  255. dropout: float = 0.0,
  256. **kwargs: Unpack[TransformersKwargs],
  257. ):
  258. if scaling is None:
  259. scaling = query.size(-1) ** -0.5
  260. # Take the dot product between "query" and "key" to get the raw attention scores.
  261. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  262. if attention_mask is not None:
  263. attn_weights = attn_weights + attention_mask
  264. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  265. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  266. attn_output = torch.matmul(attn_weights, value)
  267. attn_output = attn_output.transpose(1, 2).contiguous()
  268. return attn_output, attn_weights
  269. class Sam3Attention(nn.Module):
  270. """
  271. Multi-head attention.
  272. Handles standard [batch_size, seq_len, hidden_size] tensors.
  273. """
  274. def __init__(self, config):
  275. super().__init__()
  276. self.config = config
  277. self.hidden_size = config.hidden_size
  278. self.num_attention_heads = config.num_attention_heads
  279. self.head_dim = self.hidden_size // config.num_attention_heads
  280. self.scaling = self.head_dim**-0.5
  281. self.is_causal = False
  282. self.q_proj = nn.Linear(self.hidden_size, self.hidden_size)
  283. self.k_proj = nn.Linear(self.hidden_size, self.hidden_size)
  284. self.v_proj = nn.Linear(self.hidden_size, self.hidden_size)
  285. self.o_proj = nn.Linear(self.hidden_size, self.hidden_size)
  286. def forward(
  287. self,
  288. query: torch.Tensor,
  289. key: torch.Tensor,
  290. value: torch.Tensor,
  291. attention_mask: torch.Tensor | None = None,
  292. **kwargs: Unpack[TransformersKwargs],
  293. ) -> tuple[torch.Tensor, torch.Tensor]:
  294. """
  295. Args:
  296. query: [batch_size, query_len, hidden_size]
  297. key: [batch_size, key_len, hidden_size]
  298. value: [batch_size, value_len, hidden_size]
  299. attention_mask: [batch_size, num_heads, query_len, key_len] or broadcastable
  300. Returns:
  301. Tuple of (output, attention_weights)
  302. output: [batch_size, query_len, hidden_size]
  303. attention_weights: [batch_size, num_heads, query_len, key_len]
  304. """
  305. batch_size = query.shape[0]
  306. query_len = query.shape[1]
  307. key_len = key.shape[1]
  308. query = self.q_proj(query).view(batch_size, query_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
  309. key = self.k_proj(key).view(batch_size, key_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
  310. value = self.v_proj(value).view(batch_size, key_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
  311. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  312. self.config._attn_implementation, eager_attention_forward
  313. )
  314. if (
  315. is_flash_attention_requested(self.config)
  316. and attention_mask is not None
  317. and attention_mask.dtype != torch.bool
  318. ):
  319. # Relative position bias tensors are represented as float masks and are incompatible with Flash Attention
  320. # Fallback to SDPA for this call only so the rest of the model can still benefit from FA
  321. attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"]
  322. logger.warning_once(
  323. "Sam3Attention: falling back to SDPA for relative-position cross-attention because "
  324. "Flash Attention does not support additive bias masks."
  325. )
  326. attn_output, attn_weights = attention_interface(
  327. self,
  328. query,
  329. key,
  330. value,
  331. attention_mask=attention_mask,
  332. dropout=0.0,
  333. scaling=self.scaling,
  334. is_causal=self.is_causal,
  335. **kwargs,
  336. )
  337. attn_output = attn_output.reshape(batch_size, query_len, self.num_attention_heads * self.head_dim).contiguous()
  338. attn_output = self.o_proj(attn_output)
  339. return attn_output, attn_weights
  340. class Sam3ViTRotaryEmbedding(nn.Module):
  341. """
  342. Vision Rotary Position Embedding for SAM3, following transformers library standards.
  343. Supports 2D (axial) rotary embeddings for spatial dimensions.
  344. """
  345. def __init__(self, config: Sam3ViTConfig, end_x: int, end_y: int, scale: float = 1.0):
  346. super().__init__()
  347. dim = config.hidden_size // config.num_attention_heads
  348. # Ensure even dimension for proper axial splitting
  349. if dim % 4 != 0:
  350. raise ValueError("Dimension must be divisible by 4 for axial RoPE")
  351. self.end_x, self.end_y = end_x, end_y
  352. self.dim = dim
  353. self.rope_theta = config.rope_theta
  354. self.scale = scale
  355. freqs = 1.0 / (config.rope_theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
  356. flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
  357. x_positions = (flattened_indices % end_x) * scale
  358. y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor") * scale
  359. freqs_x = torch.outer(x_positions, freqs).float()
  360. freqs_y = torch.outer(y_positions, freqs).float()
  361. inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
  362. inv_freq = inv_freq.repeat_interleave(2, dim=-1)
  363. # directly register the cos and sin embeddings as we have a fixed feature shape
  364. self.register_buffer("rope_embeddings_cos", inv_freq.cos(), persistent=False)
  365. self.register_buffer("rope_embeddings_sin", inv_freq.sin(), persistent=False)
  366. @torch.no_grad()
  367. def forward(self) -> tuple[torch.Tensor, torch.Tensor]:
  368. # As the feature map size is fixed for each stage, we can just return the pre-computed embeddings.
  369. return self.rope_embeddings_cos, self.rope_embeddings_sin
  370. def rotate_pairwise(x):
  371. """
  372. pairwise rotation of the hidden dims of the input. Differerent from Llama Half-Tensor Rotation.
  373. This is an optimized version of the following more explicit implementation:
  374. ```python
  375. x_rotated = torch.zeros_like(x, dtype=x.dtype, device=x.device)
  376. x_rotated[..., ::2] = -x[..., 1::2]
  377. x_rotated[..., 1::2] = x[..., ::2]
  378. return x_rotated
  379. ```
  380. """
  381. x = x.view(*x.shape[:-1], -1, 2)
  382. x1, x2 = x.unbind(dim=-1)
  383. x = torch.stack((-x2, x1), dim=-1)
  384. return x.flatten(start_dim=-2)
  385. def apply_rotary_pos_emb_2d(
  386. q: torch.Tensor,
  387. k: torch.Tensor,
  388. cos: torch.Tensor,
  389. sin: torch.Tensor,
  390. ) -> tuple[torch.Tensor, torch.Tensor]:
  391. """
  392. Apply rotary position embedding to query and key tensors for self-attention.
  393. Args:
  394. q: Query tensor of shape (batch_size, num_windows, seq_len, num_heads, head_dim)
  395. k: Key tensor of shape (batch_size, num_windows, seq_len, num_heads, head_dim)
  396. cos: Cosine position embedding of shape (seq_len, head_dim)
  397. sin: Sine position embedding of shape (seq_len, head_dim)
  398. Returns:
  399. Rotated (q, k) tensors
  400. """
  401. q_embed = q.float()
  402. q_embed = (q_embed * cos) + (rotate_pairwise(q_embed) * sin)
  403. k_embed = k.float()
  404. k_embed = (k_embed * cos) + (rotate_pairwise(k_embed) * sin)
  405. return q_embed.type_as(q), k_embed.type_as(k)
  406. class Sam3ViTRoPEAttention(nn.Module):
  407. """Self-attention with rotary position encoding."""
  408. def __init__(self, config: Sam3ViTConfig):
  409. super().__init__()
  410. self.config = config
  411. self.hidden_size = config.hidden_size
  412. self.num_attention_heads = config.num_attention_heads
  413. self.head_dim = self.hidden_size // config.num_attention_heads
  414. self.scaling = self.head_dim**-0.5
  415. self.attention_dropout = config.attention_dropout
  416. self.is_causal = False
  417. self.q_proj = nn.Linear(self.hidden_size, self.hidden_size)
  418. self.k_proj = nn.Linear(self.hidden_size, self.hidden_size)
  419. self.v_proj = nn.Linear(self.hidden_size, self.hidden_size)
  420. self.o_proj = nn.Linear(self.hidden_size, self.hidden_size)
  421. def forward(
  422. self,
  423. hidden_states: torch.Tensor,
  424. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  425. **kwargs: Unpack[TransformersKwargs],
  426. ) -> Tensor:
  427. batch_size, height, width, _ = hidden_states.shape
  428. seq_len = height * width
  429. new_shape = (batch_size, seq_len, self.num_attention_heads, self.head_dim)
  430. query = self.q_proj(hidden_states).view(*new_shape).transpose(1, 2)
  431. key = self.k_proj(hidden_states).view(*new_shape).transpose(1, 2)
  432. value = self.v_proj(hidden_states).view(*new_shape).transpose(1, 2)
  433. cos, sin = position_embeddings
  434. query, key = apply_rotary_pos_emb_2d(query, key, cos=cos, sin=sin)
  435. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  436. self.config._attn_implementation, eager_attention_forward
  437. )
  438. attn_output, attn_weights = attention_interface(
  439. self,
  440. query,
  441. key,
  442. value,
  443. attention_mask=None,
  444. dropout=0.0 if not self.training else self.attention_dropout,
  445. scaling=self.scaling,
  446. is_causal=self.is_causal,
  447. **kwargs,
  448. )
  449. attn_output = attn_output.reshape(batch_size, height, width, -1).contiguous()
  450. attn_output = self.o_proj(attn_output)
  451. return attn_output, attn_weights
  452. class Sam3ViTPatchEmbeddings(nn.Module):
  453. """
  454. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  455. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  456. Transformer.
  457. """
  458. def __init__(self, config: Sam3ViTConfig):
  459. super().__init__()
  460. image_size, patch_size = config.pretrain_image_size, config.patch_size
  461. num_channels, hidden_size = config.num_channels, config.hidden_size
  462. image_size = image_size if isinstance(image_size, Iterable) else (image_size, image_size)
  463. patch_size = patch_size if isinstance(patch_size, Iterable) else (patch_size, patch_size)
  464. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  465. self.image_size = image_size
  466. self.patch_size = patch_size
  467. self.num_channels = num_channels
  468. self.num_patches = num_patches
  469. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, bias=False)
  470. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  471. embeddings = self.projection(pixel_values.to(self.projection.weight.dtype)).flatten(2).transpose(1, 2)
  472. return embeddings
  473. class Sam3ViTEmbeddings(nn.Module):
  474. """
  475. Construct the patch embeddings and position embeddings for SAM3 ViT.
  476. Position embeddings are tiled (not interpolated) when resizing to match different input sizes.
  477. """
  478. def __init__(self, config: Sam3ViTConfig):
  479. super().__init__()
  480. self.patch_embeddings = Sam3ViTPatchEmbeddings(config)
  481. num_patches = self.patch_embeddings.num_patches
  482. self.position_embeddings = nn.Parameter(
  483. torch.randn(1, num_patches, config.hidden_size)
  484. ) # !Remove cls token in convert weights!
  485. self.dropout = nn.Dropout(config.hidden_dropout)
  486. self.patch_size = config.patch_size
  487. def _tile_position_embeddings(
  488. self,
  489. position_embeddings: torch.Tensor,
  490. height: int,
  491. width: int,
  492. ) -> torch.Tensor:
  493. """
  494. Tile position embeddings to match target spatial dimensions.
  495. Args:
  496. position_embeddings: Shape [1, num_pretrain_patches, hidden_size]
  497. height: Target height in patches
  498. width: Target width in patches
  499. Returns:
  500. Shape [1, height * width, hidden_size]
  501. """
  502. pretrain_size = int(position_embeddings.shape[1] ** 0.5)
  503. # Skip tiling if sizes match (but always tile during tracing for consistent graph)
  504. if not torch.jit.is_tracing() and pretrain_size == height and pretrain_size == width:
  505. return position_embeddings.reshape(1, height * width, -1)
  506. # Tile position embeddings to match target spatial dimensions
  507. hidden_size = position_embeddings.shape[-1]
  508. pos_embed = position_embeddings.reshape(1, pretrain_size, pretrain_size, hidden_size).permute(0, 3, 1, 2)
  509. repeat_h = height // pretrain_size + 1
  510. repeat_w = width // pretrain_size + 1
  511. pos_embed = pos_embed.tile([1, 1, repeat_h, repeat_w])[:, :, :height, :width]
  512. return pos_embed.permute(0, 2, 3, 1).reshape(1, height * width, hidden_size)
  513. def forward(
  514. self,
  515. pixel_values: torch.Tensor,
  516. interpolate_pos_encoding: bool = False,
  517. ) -> torch.Tensor:
  518. height, width = pixel_values.shape[-2:]
  519. embeddings = self.patch_embeddings(pixel_values)
  520. # Calculate spatial dimensions in patches
  521. height_patches = height // self.patch_size
  522. width_patches = width // self.patch_size
  523. position_embeddings = self._tile_position_embeddings(
  524. self.position_embeddings,
  525. height_patches,
  526. width_patches,
  527. )
  528. embeddings = embeddings + position_embeddings
  529. embeddings = self.dropout(embeddings)
  530. return embeddings
  531. def window_partition(hidden_state, window_size):
  532. """
  533. Partition into non-overlapping windows with padding if needed.
  534. Args:
  535. hidden_state (`torch.Tensor`):
  536. Input tokens with [batch_size, height, width, num_channels].
  537. window_size (`int`):
  538. Window size.
  539. Returns:
  540. `tuple(torch.FloatTensor)` comprising various elements:
  541. - windows: windows after partition with [batch_size * num_windows, window_size, window_size, num_channels].
  542. - (padded_height, padded_width): padded height and width before partition
  543. """
  544. batch_size, height, width, num_channels = hidden_state.shape
  545. pad_height = (window_size - height % window_size) % window_size
  546. pad_width = (window_size - width % window_size) % window_size
  547. # Noop in case pad_width == 0 and pad_height == 0.
  548. hidden_state = nn.functional.pad(hidden_state, (0, 0, 0, pad_width, 0, pad_height))
  549. padded_height, padded_width = height + pad_height, width + pad_width
  550. hidden_state = hidden_state.view(
  551. batch_size, padded_height // window_size, window_size, padded_width // window_size, window_size, num_channels
  552. )
  553. windows = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
  554. return windows, (padded_height, padded_width)
  555. def window_unpartition(windows, window_size, pad_height_width, height_width):
  556. """
  557. Window unpartition into original sequences and removing padding.
  558. Args:
  559. windows (`torch.Tensor`):
  560. Input tokens with [batch_size * num_windows, window_size, window_size, num_channels].
  561. window_size (`int`):
  562. Window size.
  563. pad_height_width (`tuple[int]`):
  564. Padded height and width (padded_height, padded_width).
  565. height_width (`tuple[int]`):
  566. Original height and width before padding.
  567. Returns:
  568. hidden_state: unpartitioned sequences with [batch_size, height, width, num_channels].
  569. """
  570. padded_height, padded_width = pad_height_width
  571. height, width = height_width
  572. batch_size = windows.shape[0] // (padded_height * padded_width // window_size // window_size)
  573. hidden_state = windows.view(
  574. batch_size, padded_height // window_size, padded_width // window_size, window_size, window_size, -1
  575. )
  576. hidden_state = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous()
  577. hidden_state = hidden_state.view(batch_size, padded_height, padded_width, -1)
  578. # We always have height <= padded_height and width <= padded_width
  579. hidden_state = hidden_state[:, :height, :width, :].contiguous()
  580. return hidden_state
  581. class Sam3ViTLayerScale(nn.Module):
  582. def __init__(self, config) -> None:
  583. super().__init__()
  584. self.lambda1 = nn.Parameter(config.layer_scale_init_value * torch.ones(config.hidden_size))
  585. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  586. return hidden_state * self.lambda1
  587. class Sam3ViTLayer(GradientCheckpointingLayer):
  588. """Vision Transformer layer with rotary position embeddings and optional windowed attention."""
  589. def __init__(self, config: Sam3ViTConfig, window_size: int = 0) -> None:
  590. super().__init__()
  591. hidden_size = config.hidden_size
  592. image_size = config.image_size
  593. image_size = image_size if isinstance(image_size, (list, tuple)) else (image_size, image_size)
  594. patch_size = config.patch_size
  595. patch_size = patch_size if isinstance(patch_size, (list, tuple)) else (patch_size, patch_size)
  596. input_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
  597. self.layer_norm1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
  598. rotary_input_size = input_size if window_size == 0 else (window_size, window_size)
  599. rotary_scale = config.window_size / rotary_input_size[0]
  600. self.rotary_emb = Sam3ViTRotaryEmbedding(
  601. config, end_x=rotary_input_size[0], end_y=rotary_input_size[1], scale=rotary_scale
  602. )
  603. self.attention = Sam3ViTRoPEAttention(config)
  604. self.layer_norm2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
  605. self.mlp = Sam3MLP(config)
  606. self.dropout = nn.Dropout(config.hidden_dropout)
  607. self.window_size = window_size
  608. def forward(
  609. self,
  610. hidden_states: torch.Tensor,
  611. **kwargs: Unpack[TransformersKwargs],
  612. ) -> torch.Tensor:
  613. residual = hidden_states
  614. hidden_states = self.layer_norm1(hidden_states)
  615. if self.window_size > 0:
  616. height, width = hidden_states.shape[1], hidden_states.shape[2]
  617. # Partition into non-overlapping windows for efficient attention
  618. hidden_states, pad_height_width = window_partition(hidden_states, self.window_size)
  619. position_embeddings = self.rotary_emb()
  620. hidden_states, _ = self.attention(hidden_states, position_embeddings, **kwargs)
  621. if self.window_size > 0:
  622. # Reverse window partition to restore original spatial layout
  623. hidden_states = window_unpartition(hidden_states, self.window_size, pad_height_width, (height, width))
  624. hidden_states = residual + hidden_states
  625. residual = hidden_states
  626. hidden_states = self.layer_norm2(hidden_states)
  627. hidden_states = self.mlp(hidden_states)
  628. hidden_states = residual + self.dropout(hidden_states)
  629. return hidden_states
  630. @auto_docstring
  631. @requires(backends=("torch", "torchvision"))
  632. class Sam3PreTrainedModel(PreTrainedModel):
  633. config_class = Sam3Config
  634. base_model_prefix = "sam3"
  635. main_input_name = "pixel_values"
  636. input_modalities = ["image", "text"]
  637. _supports_sdpa = True
  638. _supports_flash_attn = True
  639. _supports_flex_attn = True
  640. _supports_attention_backend = True
  641. def _init_weights(self, module):
  642. super()._init_weights(module)
  643. if isinstance(module, Sam3ViTEmbeddings):
  644. init.normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range)
  645. elif isinstance(module, Sam3ViTRotaryEmbedding):
  646. end_x, end_y = module.end_x, module.end_y
  647. dim = module.dim
  648. freqs = 1.0 / (module.rope_theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
  649. flattened_indices = torch.arange(end_x * end_y, dtype=torch.long)
  650. x_positions = (flattened_indices % end_x) * module.scale
  651. y_positions = torch.div(flattened_indices, end_x, rounding_mode="floor") * module.scale
  652. freqs_x = torch.outer(x_positions, freqs).float()
  653. freqs_y = torch.outer(y_positions, freqs).float()
  654. inv_freq = torch.cat([freqs_x, freqs_y], dim=-1)
  655. inv_freq = inv_freq.repeat_interleave(2, dim=-1)
  656. init.copy_(module.rope_embeddings_cos, inv_freq.cos())
  657. init.copy_(module.rope_embeddings_sin, inv_freq.sin())
  658. @auto_docstring
  659. class Sam3ViTModel(Sam3PreTrainedModel):
  660. config: Sam3ViTConfig
  661. _can_record_outputs = {
  662. "hidden_states": Sam3ViTLayer,
  663. "attentions": Sam3ViTRoPEAttention,
  664. }
  665. def __init__(self, config: Sam3ViTConfig):
  666. super().__init__(config)
  667. self.config = config
  668. self.embeddings = Sam3ViTEmbeddings(config)
  669. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  670. self.layers = nn.ModuleList(
  671. [
  672. Sam3ViTLayer(config, window_size=config.window_size if i not in config.global_attn_indexes else 0)
  673. for i in range(config.num_hidden_layers)
  674. ]
  675. )
  676. self.post_init()
  677. def get_input_embeddings(self) -> Sam3ViTPatchEmbeddings:
  678. return self.embeddings.patch_embeddings
  679. @merge_with_config_defaults
  680. @capture_outputs(tie_last_hidden_states=False)
  681. @auto_docstring
  682. def forward(
  683. self,
  684. pixel_values: torch.Tensor,
  685. **kwargs: Unpack[TransformersKwargs],
  686. ) -> BaseModelOutput:
  687. hidden_states = self.embeddings(pixel_values) # [batch_size, seq_len, hidden_size]
  688. batch_size = hidden_states.shape[0]
  689. height = pixel_values.shape[-2] // self.config.patch_size
  690. width = pixel_values.shape[-1] // self.config.patch_size
  691. hidden_size = hidden_states.shape[-1]
  692. # Reshape to spatial format for windowed attention: [batch_size, height, width, hidden_size]
  693. hidden_states = hidden_states.view(batch_size, height, width, hidden_size)
  694. hidden_states = self.layer_norm(hidden_states)
  695. for layer in self.layers:
  696. hidden_states = layer(hidden_states, **kwargs)
  697. # Reshape back to sequence format: [batch_size, height*width, hidden_size]
  698. hidden_states = hidden_states.view(batch_size, height * width, hidden_size)
  699. return BaseModelOutput(last_hidden_state=hidden_states)
  700. class Sam3SinePositionEmbedding(nn.Module):
  701. """
  702. This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
  703. need paper, generalized to work on images.
  704. """
  705. def __init__(
  706. self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: float | None = None
  707. ):
  708. super().__init__()
  709. if scale is not None and normalize is False:
  710. raise ValueError("normalize should be True if scale is passed")
  711. self.num_pos_feats = num_pos_feats
  712. self.temperature = temperature
  713. self.normalize = normalize
  714. self.scale = 2 * math.pi if scale is None else scale
  715. def encode_1d_positions(self, x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  716. """
  717. Encode 1D coordinate pairs using sine/cosine positional embeddings.
  718. Args:
  719. x: 1D tensor of x coordinates (flattened)
  720. y: 1D tensor of y coordinates (flattened)
  721. Returns:
  722. Tuple of (pos_x, pos_y) positional embeddings
  723. """
  724. x_embed = x * self.scale
  725. y_embed = y * self.scale
  726. dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=x.device).to(x.dtype)
  727. dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
  728. pos_x = x_embed[:, None] / dim_t
  729. pos_y = y_embed[:, None] / dim_t
  730. pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1)
  731. pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1)
  732. return pos_x, pos_y
  733. def encode_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
  734. """
  735. Encode 4D box coordinates (x, y, w, h) for decoder conditioning using sine/cosine embeddings.
  736. Args:
  737. boxes: Box coordinates [batch_size, num_queries, 4] in (x, y, w, h) format
  738. Returns:
  739. Position embeddings [batch_size, num_queries, num_pos_feats*4]
  740. """
  741. assert boxes.size(-1) == 4, f"Expected 4D box coordinates (x, y, w, h), got shape {boxes.shape}"
  742. dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=boxes.device).to(boxes.dtype)
  743. dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)
  744. x_embed = boxes[:, :, 0] * self.scale
  745. y_embed = boxes[:, :, 1] * self.scale
  746. w_embed = boxes[:, :, 2] * self.scale
  747. h_embed = boxes[:, :, 3] * self.scale
  748. pos_x = x_embed[:, :, None] / dim_t
  749. pos_y = y_embed[:, :, None] / dim_t
  750. pos_w = w_embed[:, :, None] / dim_t
  751. pos_h = h_embed[:, :, None] / dim_t
  752. pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
  753. pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
  754. pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
  755. pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)
  756. pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
  757. return pos
  758. @compile_compatible_method_lru_cache(maxsize=4)
  759. def forward(
  760. self,
  761. shape: torch.Size,
  762. device: torch.device | str,
  763. dtype: torch.dtype,
  764. mask: Tensor | None = None,
  765. ) -> Tensor:
  766. if mask is None:
  767. mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool)
  768. not_mask = (~mask).to(dtype)
  769. y_embed = not_mask.cumsum(1)
  770. x_embed = not_mask.cumsum(2)
  771. if self.normalize:
  772. eps = 1e-6
  773. y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
  774. x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
  775. dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=device).to(dtype)
  776. dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)
  777. pos_x = x_embed[:, :, :, None] / dim_t
  778. pos_y = y_embed[:, :, :, None] / dim_t
  779. pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
  780. pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
  781. pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
  782. return pos
  783. class Sam3FPNLayer(nn.Module):
  784. def __init__(self, in_channels: int, fpn_dim: int, scale_factor: float):
  785. super().__init__()
  786. self.scale_factor = scale_factor
  787. # Build the upsampling/downsampling layers based on scale factor
  788. self.scale_layers = nn.ModuleList()
  789. if scale_factor == 4.0:
  790. self.scale_layers.append(nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2))
  791. self.scale_layers.append(nn.GELU())
  792. self.scale_layers.append(nn.ConvTranspose2d(in_channels // 2, in_channels // 4, kernel_size=2, stride=2))
  793. intermediate_channels = in_channels // 4
  794. elif scale_factor == 2.0:
  795. self.scale_layers.append(nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2))
  796. intermediate_channels = in_channels // 2
  797. elif scale_factor == 1.0:
  798. intermediate_channels = in_channels
  799. elif scale_factor == 0.5:
  800. self.scale_layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
  801. intermediate_channels = in_channels
  802. else:
  803. raise NotImplementedError(f"scale_factor={scale_factor} is not supported yet.")
  804. self.proj1 = nn.Conv2d(in_channels=intermediate_channels, out_channels=fpn_dim, kernel_size=1)
  805. self.proj2 = nn.Conv2d(in_channels=fpn_dim, out_channels=fpn_dim, kernel_size=3, padding=1)
  806. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  807. hidden_states = hidden_states.to(self.proj1.weight.dtype)
  808. for layer in self.scale_layers:
  809. hidden_states = layer(hidden_states)
  810. hidden_states = self.proj1(hidden_states)
  811. hidden_states = self.proj2(hidden_states)
  812. return hidden_states
  813. class Sam3VisionNeck(nn.Module):
  814. def __init__(self, config: Sam3VisionConfig):
  815. super().__init__()
  816. self.config = config
  817. self.position_encoding = Sam3SinePositionEmbedding(num_pos_feats=config.fpn_hidden_size // 2, normalize=True)
  818. # Create one FPN layer per scale factor
  819. self.fpn_layers = nn.ModuleList(
  820. [
  821. Sam3FPNLayer(
  822. in_channels=config.backbone_config.hidden_size, fpn_dim=config.fpn_hidden_size, scale_factor=scale
  823. )
  824. for scale in config.scale_factors
  825. ]
  826. )
  827. def forward(self, hidden_states: torch.Tensor) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]:
  828. fpn_hidden_states = ()
  829. fpn_position_encoding = ()
  830. for fpn_layer in self.fpn_layers:
  831. fpn_output = fpn_layer(hidden_states)
  832. fpn_hidden_states += (fpn_output,)
  833. # Generate position encoding for this FPN level
  834. pos_enc = self.position_encoding(fpn_output.shape, fpn_output.device, fpn_output.dtype)
  835. fpn_position_encoding += (pos_enc,)
  836. return fpn_hidden_states, fpn_position_encoding
  837. @auto_docstring(
  838. custom_intro="""
  839. The vision model from Sam without any head or projection on top.
  840. """
  841. )
  842. class Sam3VisionModel(Sam3PreTrainedModel):
  843. config_class = Sam3VisionConfig
  844. main_input_name = "pixel_values"
  845. def __init__(self, config: Sam3VisionConfig):
  846. super().__init__(config)
  847. self.config = config
  848. self.backbone = AutoModel.from_config(config.backbone_config)
  849. self.neck = Sam3VisionNeck(config)
  850. self.post_init()
  851. def get_input_embeddings(self):
  852. return self.backbone.get_input_embeddings()
  853. @can_return_tuple
  854. def forward(
  855. self,
  856. pixel_values: torch.FloatTensor | None = None,
  857. **kwargs: Unpack[TransformersKwargs],
  858. ) -> tuple | Sam3VisionEncoderOutput:
  859. if pixel_values is None:
  860. raise ValueError("You have to specify pixel_values")
  861. backbone_output = self.backbone(pixel_values, **kwargs)
  862. hidden_states = backbone_output.last_hidden_state # [batch_size, seq_len, hidden_size]
  863. # Reshape for FPN neck: [batch_size, seq_len, hidden_size] -> [batch_size, hidden_size, height, width]
  864. batch_size = hidden_states.shape[0]
  865. height = pixel_values.shape[-2] // self.config.backbone_config.patch_size
  866. width = pixel_values.shape[-1] // self.config.backbone_config.patch_size
  867. hidden_states_spatial = hidden_states.view(batch_size, height, width, -1).permute(0, 3, 1, 2)
  868. fpn_hidden_states, fpn_position_encoding = self.neck(hidden_states_spatial)
  869. return Sam3VisionEncoderOutput(
  870. last_hidden_state=hidden_states,
  871. fpn_hidden_states=fpn_hidden_states,
  872. fpn_position_encoding=fpn_position_encoding,
  873. hidden_states=backbone_output.hidden_states,
  874. attentions=backbone_output.attentions,
  875. )
  876. class Sam3GeometryEncoderLayer(nn.Module):
  877. def __init__(self, config: Sam3GeometryEncoderConfig):
  878. super().__init__()
  879. self.layer_norm1 = nn.LayerNorm(config.hidden_size)
  880. self.self_attn = Sam3Attention(config)
  881. self.dropout = nn.Dropout(config.dropout)
  882. self.cross_attn = Sam3Attention(config)
  883. self.layer_norm2 = nn.LayerNorm(config.hidden_size)
  884. self.mlp = Sam3MLP(config)
  885. self.layer_norm3 = nn.LayerNorm(config.hidden_size)
  886. def forward(
  887. self,
  888. prompt_feats: Tensor,
  889. vision_feats: Tensor,
  890. vision_pos_encoding: Tensor,
  891. prompt_mask: Tensor,
  892. **kwargs: Unpack[TransformersKwargs],
  893. ):
  894. residual = prompt_feats
  895. hidden_states = self.layer_norm1(prompt_feats)
  896. hidden_states, _ = self.self_attn(
  897. query=hidden_states, key=hidden_states, value=hidden_states, attention_mask=prompt_mask, **kwargs
  898. )
  899. hidden_states = self.dropout(hidden_states) + residual
  900. residual = hidden_states
  901. hidden_states = self.layer_norm2(hidden_states)
  902. key = vision_feats + vision_pos_encoding
  903. hidden_states, _ = self.cross_attn(query=hidden_states, key=key, value=vision_feats, **kwargs)
  904. hidden_states = self.dropout(hidden_states) + residual
  905. residual = hidden_states
  906. hidden_states = self.layer_norm3(hidden_states)
  907. hidden_states = self.mlp(hidden_states)
  908. hidden_states = self.dropout(hidden_states) + residual
  909. return hidden_states
  910. class Sam3GeometryEncoder(nn.Module):
  911. """
  912. Encoder for geometric prompts (boxes).
  913. Boxes are encoded using three approaches:
  914. - Direct projection: linear projection from coordinate space to hidden_size
  915. - Pooling: pool features from the backbone at the specified location (ROI align for boxes)
  916. - Position encoding: use position encoding of the box center
  917. These encodings are combined additively and further processed with transformer layers.
  918. """
  919. def __init__(self, config: Sam3GeometryEncoderConfig):
  920. super().__init__()
  921. self.config = config
  922. self.hidden_size = config.hidden_size
  923. self.roi_size = config.roi_size
  924. self.position_encoding = Sam3SinePositionEmbedding(num_pos_feats=config.hidden_size // 2, normalize=True)
  925. self.label_embed = nn.Embedding(2, self.hidden_size)
  926. self.cls_embed = nn.Embedding(1, self.hidden_size)
  927. # Box encoding layers
  928. self.boxes_direct_project = nn.Linear(4, self.hidden_size)
  929. self.boxes_pool_project = nn.Conv2d(self.hidden_size, self.hidden_size, self.roi_size)
  930. self.boxes_pos_enc_project = nn.Linear(self.hidden_size + 2, self.hidden_size)
  931. # Image feature normalization
  932. self.vision_layer_norm = nn.LayerNorm(self.hidden_size)
  933. # Prompt projection and normalization
  934. self.final_proj = nn.Linear(self.hidden_size, self.hidden_size)
  935. self.prompt_layer_norm = nn.LayerNorm(self.hidden_size)
  936. # Transformer layers
  937. self.layers = nn.ModuleList([Sam3GeometryEncoderLayer(config) for _ in range(config.num_layers)])
  938. self.output_layer_norm = nn.LayerNorm(self.hidden_size)
  939. def _encode_box_coordinates(
  940. self, center_x: torch.Tensor, center_y: torch.Tensor, width: torch.Tensor, height: torch.Tensor
  941. ) -> torch.Tensor:
  942. """
  943. Encode box coordinates by combining position-encoded centers with raw width/height.
  944. Args:
  945. center_x: 1D tensor of box center x coordinates
  946. center_y: 1D tensor of box center y coordinates
  947. width: 1D tensor of box widths
  948. height: 1D tensor of box heights
  949. Returns:
  950. Encoded box coordinates [N, embedding_dim]
  951. """
  952. pos_x, pos_y = self.position_encoding.encode_1d_positions(center_x, center_y)
  953. pos = torch.cat((pos_y, pos_x, height[:, None], width[:, None]), dim=1)
  954. return pos
  955. def _encode_boxes(self, boxes, boxes_mask, boxes_labels, vision_features):
  956. """Encode box prompts. Mask convention: True=valid, False=padding."""
  957. batch_size, num_boxes = boxes.shape[:2]
  958. height, width = vision_features.shape[-2:]
  959. boxes_embed = self.boxes_direct_project(boxes)
  960. # Pool features using ROI align
  961. # Convert boxes from CxCyWH to xyxy format and denormalize
  962. boxes_xyxy = box_cxcywh_to_xyxy(boxes)
  963. scale = torch.tensor([width, height, width, height], dtype=boxes_xyxy.dtype, device=boxes_xyxy.device)
  964. scale = scale.view(1, 1, 4)
  965. boxes_xyxy = boxes_xyxy * scale
  966. # ROI align expects list of boxes per batch element,
  967. # convert from bfloat16 to float16 as roi_align only supports float16 and float32
  968. dtype = torch.float16 if vision_features.dtype == torch.bfloat16 else vision_features.dtype
  969. sampled_features = torchvision.ops.roi_align(
  970. vision_features.to(dtype), boxes_xyxy.to(dtype).unbind(0), self.roi_size
  971. ).to(vision_features.dtype)
  972. pooled_projection = self.boxes_pool_project(sampled_features)
  973. pooled_projection = pooled_projection.view(batch_size, num_boxes, self.hidden_size)
  974. boxes_embed = boxes_embed + pooled_projection
  975. # Add position encoding
  976. center_x, center_y, box_width, box_height = boxes.unbind(-1)
  977. pos_enc = self._encode_box_coordinates(
  978. center_x.flatten(), center_y.flatten(), box_width.flatten(), box_height.flatten()
  979. )
  980. pos_enc = pos_enc.view(batch_size, num_boxes, pos_enc.shape[-1])
  981. pos_projection = self.boxes_pos_enc_project(pos_enc)
  982. boxes_embed = boxes_embed + pos_projection
  983. # Add label embeddings (positive/negative)
  984. label_embed = self.label_embed(boxes_labels.long())
  985. return label_embed + boxes_embed, boxes_mask
  986. def forward(
  987. self,
  988. box_embeddings: torch.Tensor,
  989. box_mask: torch.Tensor,
  990. box_labels: torch.Tensor,
  991. img_feats: tuple[torch.Tensor, ...],
  992. img_pos_embeds: tuple[torch.Tensor, ...] | None = None,
  993. ):
  994. """
  995. Forward pass for encoding geometric prompts.
  996. Args:
  997. box_embeddings: Box coordinates in CxCyWH format [batch_size, num_boxes, 4]
  998. box_mask: Attention mask for boxes [batch_size, num_boxes]
  999. box_labels: Labels for boxes (positive/negative) [batch_size, num_boxes]
  1000. img_feats: Image features from vision encoder
  1001. img_pos_embeds: Optional position embeddings for image features
  1002. Returns:
  1003. Sam3GeometryEncoderOutput containing encoded geometry features and attention mask.
  1004. """
  1005. batch_size = box_embeddings.shape[0]
  1006. # Prepare vision features for cross-attention: flatten spatial dimensions
  1007. vision_feats = img_feats[-1] # [B, C, H, W]
  1008. vision_pos_embeds = img_pos_embeds[-1] if img_pos_embeds is not None else torch.zeros_like(vision_feats)
  1009. vision_feats_flat = vision_feats.flatten(2).transpose(1, 2) # [B, H*W, C]
  1010. vision_pos_embeds_flat = vision_pos_embeds.flatten(2).transpose(1, 2) # [B, H*W, C]
  1011. # Normalize image features for pooling operations
  1012. img_feats_last = img_feats[-1] # [B, C, H, W]
  1013. img_feats_last = img_feats_last.permute(0, 2, 3, 1) # [B, H, W, C]
  1014. normalized_img_feats = self.vision_layer_norm(img_feats_last)
  1015. normalized_img_feats = normalized_img_feats.permute(0, 3, 1, 2) # [B, C, H, W]
  1016. prompt_embeds, prompt_mask = self._encode_boxes(box_embeddings, box_mask, box_labels, normalized_img_feats)
  1017. # Add CLS token (always valid)
  1018. cls_embed = self.cls_embed.weight.view(1, self.hidden_size).unsqueeze(0).expand(batch_size, -1, -1)
  1019. cls_mask = torch.ones(batch_size, 1, dtype=prompt_mask.dtype, device=prompt_mask.device)
  1020. prompt_embeds, prompt_mask = concat_padded_sequences(prompt_embeds, prompt_mask, cls_embed, cls_mask)
  1021. prompt_embeds = self.prompt_layer_norm(self.final_proj(prompt_embeds))
  1022. # Create bidirectional attention mask for transformer layers
  1023. prompt_attention_mask = None
  1024. if prompt_mask is not None:
  1025. prompt_attention_mask = create_bidirectional_mask(
  1026. config=self.config,
  1027. inputs_embeds=prompt_embeds,
  1028. attention_mask=prompt_mask,
  1029. )
  1030. # Apply transformer layers with cross-attention to vision features
  1031. for layer in self.layers:
  1032. prompt_embeds = layer(
  1033. prompt_feats=prompt_embeds,
  1034. vision_feats=vision_feats_flat,
  1035. vision_pos_encoding=vision_pos_embeds_flat,
  1036. prompt_mask=prompt_attention_mask,
  1037. )
  1038. # Final output normalization
  1039. prompt_embeds = self.output_layer_norm(prompt_embeds)
  1040. return Sam3GeometryEncoderOutput(
  1041. last_hidden_state=prompt_embeds,
  1042. attention_mask=prompt_mask,
  1043. )
  1044. class Sam3DetrEncoderLayer(nn.Module):
  1045. """DETR encoder layer with self-attention and cross-attention."""
  1046. def __init__(self, config: Sam3DETREncoderConfig):
  1047. super().__init__()
  1048. self.config = config
  1049. self.layer_norm1 = nn.LayerNorm(config.hidden_size)
  1050. self.self_attn = Sam3Attention(config)
  1051. self.dropout = nn.Dropout(config.dropout)
  1052. self.cross_attn = Sam3Attention(config)
  1053. self.layer_norm2 = nn.LayerNorm(config.hidden_size)
  1054. self.mlp = Sam3MLP(config)
  1055. self.layer_norm3 = nn.LayerNorm(config.hidden_size)
  1056. def forward(
  1057. self,
  1058. vision_feats: Tensor,
  1059. prompt_feats: Tensor,
  1060. vision_pos_encoding: Tensor,
  1061. prompt_cross_attn_mask: Tensor | None = None,
  1062. **kwargs: Unpack[TransformersKwargs],
  1063. ):
  1064. """
  1065. Forward pass for DETR encoder layer.
  1066. Args:
  1067. vision_feats: Vision features [batch_size, vision_len, hidden_size] (main hidden states)
  1068. prompt_feats: Text prompt features [batch_size, text_len, hidden_size]
  1069. vision_pos_encoding: Position encoding for vision [batch_size, vision_len, hidden_size]
  1070. prompt_cross_attn_mask: Cross-attention mask for prompt features
  1071. Returns:
  1072. Updated vision features [batch_size, vision_len, hidden_size]
  1073. """
  1074. # Self-attention on vision features with position encoding
  1075. residual = vision_feats
  1076. hidden_states = self.layer_norm1(vision_feats)
  1077. hidden_states_with_pos = hidden_states + vision_pos_encoding
  1078. hidden_states, _ = self.self_attn(
  1079. query=hidden_states_with_pos,
  1080. key=hidden_states_with_pos,
  1081. value=hidden_states,
  1082. **kwargs,
  1083. )
  1084. hidden_states = self.dropout(hidden_states) + residual
  1085. # Cross-attention: vision queries attend to text/prompt features
  1086. residual = hidden_states
  1087. hidden_states = self.layer_norm2(hidden_states)
  1088. hidden_states, _ = self.cross_attn(
  1089. query=hidden_states,
  1090. key=prompt_feats,
  1091. value=prompt_feats,
  1092. attention_mask=prompt_cross_attn_mask,
  1093. **kwargs,
  1094. )
  1095. hidden_states = self.dropout(hidden_states) + residual
  1096. # MLP
  1097. residual = hidden_states
  1098. hidden_states = self.layer_norm3(hidden_states)
  1099. hidden_states = self.mlp(hidden_states)
  1100. hidden_states = self.dropout(hidden_states) + residual
  1101. return hidden_states
  1102. class Sam3DetrEncoder(Sam3PreTrainedModel):
  1103. """
  1104. DETR-style encoder that processes multi-level vision features with text fusion.
  1105. This encoder processes vision features from multiple levels (e.g., FPN features at different
  1106. resolutions) and fuses them with text prompts through a stack of transformer encoder layers.
  1107. """
  1108. _can_record_outputs = {
  1109. "hidden_states": Sam3DetrEncoderLayer,
  1110. "attentions": Sam3Attention,
  1111. }
  1112. def __init__(self, config: Sam3DETREncoderConfig):
  1113. super().__init__(config)
  1114. self.config = config
  1115. self.hidden_size = config.hidden_size
  1116. self.layers = nn.ModuleList([Sam3DetrEncoderLayer(config) for _ in range(config.num_layers)])
  1117. self.post_init()
  1118. def _prepare_multilevel_features(
  1119. self,
  1120. vision_features: list[torch.Tensor],
  1121. vision_pos_embeds: list[torch.Tensor],
  1122. ):
  1123. """
  1124. Prepare multi-level vision features by flattening spatial dimensions and adding level embeddings.
  1125. Args:
  1126. vision_features: List of vision features at different levels [batch_size, channels, height, width]
  1127. vision_pos_embeds: List of position embeddings for each level [batch_size, channels, height, width]
  1128. Returns:
  1129. Tuple containing flattened features, position embeddings, and spatial metadata
  1130. """
  1131. features_flattened = []
  1132. pos_embeds_flattened = []
  1133. spatial_shapes = []
  1134. for features, pos_embed in zip(vision_features, vision_pos_embeds):
  1135. height, width = features.shape[-2:]
  1136. spatial_shapes.append((height, width))
  1137. # Flatten spatial dimensions: [batch_size, channels, height, width] -> [batch_size, height*width, channels]
  1138. features = features.flatten(2).transpose(1, 2)
  1139. pos_embed = pos_embed.flatten(2).transpose(1, 2)
  1140. features_flattened.append(features)
  1141. pos_embeds_flattened.append(pos_embed)
  1142. # Concatenate all levels into single sequence
  1143. features_flattened = torch.cat(features_flattened, dim=1)
  1144. pos_embeds_flattened = torch.cat(pos_embeds_flattened, dim=1)
  1145. spatial_shapes = torch.tensor(spatial_shapes, dtype=torch.long, device=features_flattened.device)
  1146. return (
  1147. features_flattened,
  1148. pos_embeds_flattened,
  1149. spatial_shapes,
  1150. )
  1151. @merge_with_config_defaults
  1152. @capture_outputs
  1153. def forward(
  1154. self,
  1155. vision_features: list[torch.Tensor],
  1156. text_features: torch.Tensor,
  1157. vision_pos_embeds: list[torch.Tensor] | None = None,
  1158. text_mask: torch.Tensor | None = None,
  1159. spatial_sizes: list[tuple[int, int]] | None = None,
  1160. **kwargs: Unpack[TransformersKwargs],
  1161. ) -> tuple | Sam3DETREncoderOutput:
  1162. """
  1163. Forward pass for the DETR encoder.
  1164. Args:
  1165. vision_features: List of vision features at different levels
  1166. text_features: Text prompt features [batch_size, seq_len, hidden_size]
  1167. vision_pos_embeds: Optional list of position embeddings for each level
  1168. text_mask: Optional text padding mask [batch_size, seq_len]
  1169. spatial_sizes: Optional list of (height, width) tuples for reshaping
  1170. Returns:
  1171. Sam3DETREncoderOutput containing encoded features and metadata.
  1172. """
  1173. batch_size = vision_features[0].shape[0] if vision_features[0].dim() == 4 else vision_features[0].shape[1]
  1174. # TODO: See if we can remove that reshaping and just use the features as is.
  1175. if spatial_sizes is not None:
  1176. for i, (height, width) in enumerate(spatial_sizes):
  1177. # Reshape from [height*width, batch_size, channels] to [batch_size, channels, height, width]
  1178. vision_features[i] = vision_features[i].reshape(height, width, batch_size, -1).permute(2, 3, 0, 1)
  1179. vision_pos_embeds[i] = vision_pos_embeds[i].reshape(height, width, batch_size, -1).permute(2, 3, 0, 1)
  1180. # Flatten multi-level features for encoder processing
  1181. (
  1182. features_flattened,
  1183. pos_embeds_flattened,
  1184. spatial_shapes,
  1185. ) = self._prepare_multilevel_features(vision_features, vision_pos_embeds)
  1186. prompt_cross_attn_mask = None
  1187. if text_mask is not None:
  1188. prompt_cross_attn_mask = create_bidirectional_mask(
  1189. config=self.config,
  1190. inputs_embeds=features_flattened,
  1191. attention_mask=text_mask,
  1192. encoder_hidden_states=text_features,
  1193. )
  1194. hidden_states = features_flattened
  1195. for layer in self.layers:
  1196. hidden_states = layer(
  1197. hidden_states,
  1198. prompt_feats=text_features,
  1199. vision_pos_encoding=pos_embeds_flattened,
  1200. prompt_cross_attn_mask=prompt_cross_attn_mask,
  1201. **kwargs,
  1202. )
  1203. return Sam3DETREncoderOutput(
  1204. last_hidden_state=hidden_states,
  1205. pos_embeds_flattened=pos_embeds_flattened,
  1206. text_features=text_features,
  1207. spatial_shapes=spatial_shapes,
  1208. )
  1209. class Sam3DecoderMLP(nn.Module):
  1210. """Simple 2 or 3-layer MLP for decoder components."""
  1211. def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int = 2):
  1212. super().__init__()
  1213. if num_layers == 2:
  1214. self.layer1 = nn.Linear(input_dim, hidden_dim)
  1215. self.layer2 = nn.Linear(hidden_dim, output_dim)
  1216. self.layer3 = None
  1217. elif num_layers == 3:
  1218. self.layer1 = nn.Linear(input_dim, hidden_dim)
  1219. self.layer2 = nn.Linear(hidden_dim, hidden_dim)
  1220. self.layer3 = nn.Linear(hidden_dim, output_dim)
  1221. else:
  1222. raise ValueError(f"Only 2 or 3 layers supported, got {num_layers}")
  1223. def forward(self, x: torch.Tensor) -> torch.Tensor:
  1224. x = F.relu(self.layer1(x))
  1225. if self.layer3 is not None:
  1226. x = F.relu(self.layer2(x))
  1227. x = self.layer3(x)
  1228. else:
  1229. x = self.layer2(x)
  1230. return x
  1231. class Sam3DetrDecoderLayer(nn.Module):
  1232. """DETR decoder layer with self-attention, text cross-attention, and vision cross-attention."""
  1233. def __init__(self, config: Sam3DETRDecoderConfig):
  1234. super().__init__()
  1235. self.config = config
  1236. self.self_attn = Sam3Attention(config)
  1237. self.self_attn_dropout = nn.Dropout(config.dropout)
  1238. self.self_attn_layer_norm = nn.LayerNorm(config.hidden_size)
  1239. self.text_cross_attn = Sam3Attention(config)
  1240. self.text_cross_attn_dropout = nn.Dropout(config.dropout)
  1241. self.text_cross_attn_layer_norm = nn.LayerNorm(config.hidden_size)
  1242. self.vision_cross_attn = Sam3Attention(config)
  1243. self.vision_cross_attn_dropout = nn.Dropout(config.dropout)
  1244. self.vision_cross_attn_layer_norm = nn.LayerNorm(config.hidden_size)
  1245. self.mlp = Sam3MLP(config)
  1246. self.mlp_layer_norm = nn.LayerNorm(config.hidden_size)
  1247. self.mlp_dropout = nn.Dropout(config.dropout)
  1248. def forward(
  1249. self,
  1250. hidden_states: torch.Tensor,
  1251. query_pos: torch.Tensor,
  1252. text_features: torch.Tensor,
  1253. vision_features: torch.Tensor,
  1254. vision_pos_encoding: torch.Tensor,
  1255. text_cross_attn_mask: torch.Tensor | None = None,
  1256. vision_cross_attn_mask: torch.Tensor | None = None,
  1257. **kwargs: Unpack[TransformersKwargs],
  1258. ) -> torch.Tensor:
  1259. """
  1260. Forward pass for decoder layer.
  1261. Args:
  1262. hidden_states: Query features [batch_size, num_queries + 1, hidden_size] (includes presence token at position 0)
  1263. query_pos: Query position embeddings [batch_size, num_queries, hidden_size]
  1264. text_features: Text features [batch_size, seq_len, hidden_size]
  1265. vision_features: Vision features [batch_size, height*width, hidden_size]
  1266. vision_pos_encoding: Vision position encoding [batch_size, height*width, hidden_size]
  1267. text_cross_attn_mask: Text cross-attention mask
  1268. vision_cross_attn_mask: Vision cross-attention mask, already expanded for presence token
  1269. Returns:
  1270. Updated hidden states (including presence token at position 0)
  1271. """
  1272. # Prepend zeros to query_pos for presence token
  1273. query_pos = F.pad(query_pos, (0, 0, 1, 0), mode="constant", value=0)
  1274. # Self-attention with query position encoding
  1275. residual = hidden_states
  1276. query_with_pos = hidden_states + query_pos
  1277. attn_output, _ = self.self_attn(
  1278. query=query_with_pos,
  1279. key=query_with_pos,
  1280. value=hidden_states,
  1281. attention_mask=None,
  1282. **kwargs,
  1283. )
  1284. hidden_states = residual + self.self_attn_dropout(attn_output)
  1285. hidden_states = self.self_attn_layer_norm(hidden_states)
  1286. # Text cross-attention: queries attend to text features
  1287. residual = hidden_states
  1288. query_with_pos = hidden_states + query_pos
  1289. attn_output, _ = self.text_cross_attn(
  1290. query=query_with_pos,
  1291. key=text_features,
  1292. value=text_features,
  1293. attention_mask=text_cross_attn_mask,
  1294. **kwargs,
  1295. )
  1296. hidden_states = residual + self.text_cross_attn_dropout(attn_output)
  1297. hidden_states = self.text_cross_attn_layer_norm(hidden_states)
  1298. # Vision cross-attention: queries attend to vision features (with RPB)
  1299. residual = hidden_states
  1300. query_with_pos = hidden_states + query_pos
  1301. key_with_pos = vision_features + vision_pos_encoding
  1302. attn_output, _ = self.vision_cross_attn(
  1303. query=query_with_pos,
  1304. key=key_with_pos,
  1305. value=vision_features,
  1306. attention_mask=vision_cross_attn_mask,
  1307. **kwargs,
  1308. )
  1309. hidden_states = residual + self.vision_cross_attn_dropout(attn_output)
  1310. hidden_states = self.vision_cross_attn_layer_norm(hidden_states)
  1311. # MLP
  1312. residual = hidden_states
  1313. hidden_states = self.mlp(hidden_states)
  1314. hidden_states = residual + self.mlp_dropout(hidden_states)
  1315. hidden_states = self.mlp_layer_norm(hidden_states)
  1316. return hidden_states
  1317. class Sam3DetrDecoder(Sam3PreTrainedModel):
  1318. """
  1319. DETR-style decoder with box refinement and presence token.
  1320. Simplified version that assumes:
  1321. - Box refinement is always enabled
  1322. - Intermediate outputs are always returned
  1323. - BoxRPB (relative position bias) with log-scale encoding
  1324. - Presence token is used
  1325. """
  1326. _can_record_outputs = {
  1327. "hidden_states": Sam3DetrDecoderLayer,
  1328. "attentions": Sam3Attention,
  1329. }
  1330. def __init__(
  1331. self,
  1332. config: Sam3DETRDecoderConfig,
  1333. ):
  1334. super().__init__(config)
  1335. self.config = config
  1336. self.hidden_size = config.hidden_size
  1337. self.layers = nn.ModuleList([Sam3DetrDecoderLayer(config) for _ in range(config.num_layers)])
  1338. self.output_layer_norm = nn.LayerNorm(config.hidden_size)
  1339. self.box_head = Sam3DecoderMLP(config.hidden_size, config.hidden_size, 4, 3)
  1340. self.query_embed = nn.Embedding(config.num_queries, config.hidden_size)
  1341. self.reference_points = nn.Embedding(config.num_queries, 4)
  1342. self.presence_token = nn.Embedding(1, config.hidden_size)
  1343. self.presence_head = Sam3DecoderMLP(config.hidden_size, config.hidden_size, 1, 3)
  1344. self.presence_layer_norm = nn.LayerNorm(config.hidden_size)
  1345. self.clamp_presence_logit_max_val = 10.0
  1346. self.ref_point_head = Sam3DecoderMLP(2 * config.hidden_size, config.hidden_size, config.hidden_size, 2)
  1347. self.box_rpb_embed_x = Sam3DecoderMLP(2, config.hidden_size, config.num_attention_heads, 2)
  1348. self.box_rpb_embed_y = Sam3DecoderMLP(2, config.hidden_size, config.num_attention_heads, 2)
  1349. self.position_encoding = Sam3SinePositionEmbedding(num_pos_feats=config.hidden_size // 2, normalize=False)
  1350. self.post_init()
  1351. @compile_compatible_method_lru_cache(maxsize=1)
  1352. def _get_coords(
  1353. self, height: torch.Tensor, width: torch.Tensor, dtype: torch.dtype, device: torch.device
  1354. ) -> tuple[torch.Tensor, torch.Tensor]:
  1355. """Generate normalized coordinate grids."""
  1356. coords_h = torch.arange(0, height, device=device, dtype=dtype) / height
  1357. coords_w = torch.arange(0, width, device=device, dtype=dtype) / width
  1358. return coords_h, coords_w
  1359. def _get_rpb_matrix(
  1360. self, reference_boxes: torch.Tensor, spatial_shape: tuple[torch.Tensor, torch.Tensor]
  1361. ) -> torch.Tensor:
  1362. """
  1363. Compute box relative position bias (RPB) matrix using log-scale encoding.
  1364. RPB helps the decoder attend to relevant spatial locations based on predicted box positions.
  1365. Args:
  1366. reference_boxes: Reference boxes [batch_size, num_queries, 4] in sigmoid space
  1367. spatial_shape: (height, width) of the vision features as tensors
  1368. Returns:
  1369. RPB matrix [batch_size, num_heads, num_queries, height*width]
  1370. """
  1371. height, width = spatial_shape
  1372. boxes_xyxy = box_cxcywh_to_xyxy(reference_boxes)
  1373. batch_size, num_queries, _ = boxes_xyxy.shape
  1374. # Generate coordinate grids
  1375. coords_h, coords_w = self._get_coords(
  1376. height, width, dtype=reference_boxes.dtype, device=reference_boxes.device
  1377. )
  1378. # Compute deltas between coordinates and box boundaries
  1379. deltas_y = coords_h.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 1:4:2]
  1380. deltas_y = deltas_y.view(batch_size, num_queries, -1, 2)
  1381. deltas_x = coords_w.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 0:3:2]
  1382. deltas_x = deltas_x.view(batch_size, num_queries, -1, 2)
  1383. # Apply log-scale encoding
  1384. deltas_x_log = deltas_x * 8
  1385. deltas_x_log = torch.sign(deltas_x_log) * torch.log2(torch.abs(deltas_x_log) + 1.0) / math.log2(8)
  1386. deltas_y_log = deltas_y * 8
  1387. deltas_y_log = torch.sign(deltas_y_log) * torch.log2(torch.abs(deltas_y_log) + 1.0) / math.log2(8)
  1388. # Embed deltas
  1389. deltas_x = self.box_rpb_embed_x(deltas_x_log) # [batch_size, num_queries, width, num_heads]
  1390. deltas_y = self.box_rpb_embed_y(deltas_y_log) # [batch_size, num_queries, height, num_heads]
  1391. # Combine into 2D bias matrix
  1392. rpb_matrix = deltas_y.unsqueeze(3) + deltas_x.unsqueeze(
  1393. 2
  1394. ) # [batch_size, num_queries, height, width, num_heads]
  1395. rpb_matrix = rpb_matrix.flatten(2, 3) # [batch_size, num_queries, height*width, num_heads]
  1396. rpb_matrix = rpb_matrix.permute(0, 3, 1, 2).contiguous() # [batch_size, num_heads, num_queries, height*width]
  1397. return rpb_matrix
  1398. @merge_with_config_defaults
  1399. @capture_outputs
  1400. def forward(
  1401. self,
  1402. vision_features: torch.Tensor,
  1403. text_features: torch.Tensor,
  1404. vision_pos_encoding: torch.Tensor,
  1405. text_mask: torch.Tensor | None = None,
  1406. spatial_shapes: torch.Tensor | None = None,
  1407. **kwargs: Unpack[TransformersKwargs],
  1408. ) -> tuple | Sam3DETRDecoderOutput:
  1409. """
  1410. Forward pass for the DETR decoder.
  1411. Args:
  1412. vision_features: Vision features [batch_size, height*width, hidden_size]
  1413. text_features: Text features [batch_size, seq_len, hidden_size]
  1414. vision_pos_encoding: Vision position encoding [batch_size, height*width, hidden_size]
  1415. text_mask: Text padding mask [batch_size, seq_len] where True=valid, False=padding
  1416. spatial_shapes: Spatial shapes [num_levels, 2]
  1417. Returns:
  1418. Sam3DETRDecoderOutput containing decoder outputs from all layers.
  1419. """
  1420. batch_size = vision_features.shape[0]
  1421. query_embeds = self.query_embed.weight.unsqueeze(0).expand(batch_size, -1, -1)
  1422. reference_boxes = self.reference_points.weight.unsqueeze(0).expand(batch_size, -1, -1)
  1423. reference_boxes = reference_boxes.sigmoid()
  1424. presence_token = self.presence_token.weight.unsqueeze(0).expand(batch_size, -1, -1)
  1425. # Concatenate presence token with query embeddings
  1426. hidden_states = torch.cat([presence_token, query_embeds], dim=1)
  1427. text_cross_attn_mask = None
  1428. if text_mask is not None:
  1429. text_cross_attn_mask = create_bidirectional_mask(
  1430. config=self.config,
  1431. inputs_embeds=hidden_states,
  1432. attention_mask=text_mask,
  1433. encoder_hidden_states=text_features,
  1434. )
  1435. intermediate_outputs = []
  1436. intermediate_boxes = [reference_boxes]
  1437. intermediate_presence_logits = []
  1438. for layer in self.layers:
  1439. # Generate sine embeddings for conditional queries
  1440. reference_points_input = reference_boxes.unsqueeze(2)
  1441. query_sine_embed = self.position_encoding.encode_boxes(reference_points_input[:, :, 0, :])
  1442. query_pos = self.ref_point_head(query_sine_embed)
  1443. # Compute box relative position bias (RPB) attention mask
  1444. vision_cross_attn_mask = None
  1445. if spatial_shapes is not None and spatial_shapes.shape[0] == 1:
  1446. spatial_shape = (spatial_shapes[0, 0], spatial_shapes[0, 1])
  1447. rpb_matrix = self._get_rpb_matrix(reference_boxes, spatial_shape)
  1448. # Prepend zeros row for presence token (it attends to all vision tokens equally)
  1449. vision_cross_attn_mask = F.pad(rpb_matrix, (0, 0, 1, 0), mode="constant", value=0)
  1450. hidden_states = layer(
  1451. hidden_states,
  1452. query_pos=query_pos,
  1453. text_features=text_features,
  1454. vision_features=vision_features,
  1455. vision_pos_encoding=vision_pos_encoding,
  1456. text_cross_attn_mask=text_cross_attn_mask,
  1457. vision_cross_attn_mask=vision_cross_attn_mask,
  1458. **kwargs,
  1459. )
  1460. # Extract query hidden states (without presence token) for box refinement
  1461. query_hidden_states = hidden_states[:, 1:]
  1462. # Box refinement: predict delta and update reference boxes
  1463. reference_boxes_before_sigmoid = inverse_sigmoid(reference_boxes)
  1464. delta_boxes = self.box_head(self.output_layer_norm(query_hidden_states))
  1465. new_reference_boxes = (delta_boxes + reference_boxes_before_sigmoid).sigmoid()
  1466. reference_boxes = new_reference_boxes.detach()
  1467. intermediate_outputs.append(self.output_layer_norm(query_hidden_states))
  1468. intermediate_boxes.append(new_reference_boxes)
  1469. # Process presence token
  1470. presence_hidden = hidden_states[:, :1]
  1471. presence_logits = self.presence_head(self.presence_layer_norm(presence_hidden)).squeeze(-1)
  1472. presence_logits = presence_logits.clamp(
  1473. min=-self.clamp_presence_logit_max_val, max=self.clamp_presence_logit_max_val
  1474. )
  1475. intermediate_presence_logits.append(presence_logits)
  1476. # Stack outputs from all layers
  1477. intermediate_outputs = torch.stack(intermediate_outputs)
  1478. intermediate_boxes = torch.stack(intermediate_boxes[:-1])
  1479. intermediate_presence_logits = torch.stack(intermediate_presence_logits)
  1480. return Sam3DETRDecoderOutput(
  1481. intermediate_hidden_states=intermediate_outputs,
  1482. reference_boxes=intermediate_boxes,
  1483. presence_logits=intermediate_presence_logits,
  1484. )
  1485. class Sam3DotProductScoring(nn.Module):
  1486. """
  1487. Computes classification scores by computing dot product between projected decoder queries and pooled text features.
  1488. This is used to determine confidence/presence scores for each query.
  1489. """
  1490. def __init__(self, config: Sam3Config):
  1491. super().__init__()
  1492. self.config = config
  1493. hidden_size = config.detr_decoder_config.hidden_size
  1494. projection_dim = config.detr_decoder_config.hidden_size
  1495. self.text_mlp = Sam3DecoderMLP(
  1496. input_dim=hidden_size,
  1497. hidden_dim=config.detr_decoder_config.intermediate_size,
  1498. output_dim=hidden_size,
  1499. num_layers=2,
  1500. )
  1501. self.text_mlp_dropout = nn.Dropout(config.detr_decoder_config.dropout)
  1502. self.text_mlp_out_norm = nn.LayerNorm(hidden_size)
  1503. # Projections for text and query features
  1504. self.text_proj = nn.Linear(hidden_size, projection_dim)
  1505. self.query_proj = nn.Linear(hidden_size, projection_dim)
  1506. # Scale factor for dot product
  1507. self.scale = float(1.0 / np.sqrt(projection_dim))
  1508. # Clamping to avoid numerical issues
  1509. self.clamp_logits = True
  1510. self.clamp_max_val = 12.0
  1511. def _pool_text_features(self, text_features: torch.Tensor, text_mask: torch.Tensor | None) -> torch.Tensor:
  1512. """
  1513. Mean pool text features, accounting for padding.
  1514. Args:
  1515. text_features: [batch_size, seq_len, hidden_size]
  1516. text_mask: [batch_size, seq_len] where True indicates valid tokens, False indicates padding
  1517. Returns:
  1518. pooled_text: [batch_size, hidden_size]
  1519. """
  1520. if text_mask is None:
  1521. # No padding, simple mean
  1522. return text_features.mean(dim=1)
  1523. is_valid = text_mask.to(text_features.dtype).unsqueeze(-1) # [batch_size, seq_len, 1]
  1524. # Count valid tokens per batch
  1525. num_valid = is_valid.sum(dim=1).clamp(min=1.0) # [batch_size, 1]
  1526. # Mean pool only over valid tokens
  1527. pooled_text = (text_features * is_valid).sum(dim=1) / num_valid # [batch_size, hidden_size]
  1528. return pooled_text
  1529. def forward(
  1530. self,
  1531. decoder_hidden_states: torch.Tensor,
  1532. text_features: torch.Tensor,
  1533. text_mask: torch.Tensor | None = None,
  1534. ) -> torch.Tensor:
  1535. """
  1536. Compute classification scores via dot product.
  1537. Args:
  1538. decoder_hidden_states: [num_layers, batch_size, num_queries, hidden_size]
  1539. text_features: [batch_size, seq_len, hidden_size]
  1540. text_mask: [batch_size, seq_len] where True=valid, False=padding
  1541. Returns:
  1542. scores: [num_layers, batch_size, num_queries, 1]
  1543. """
  1544. orig_text_features = text_features
  1545. text_features = self.text_mlp(text_features)
  1546. text_features = self.text_mlp_dropout(text_features)
  1547. text_features = text_features + orig_text_features
  1548. text_features = self.text_mlp_out_norm(text_features)
  1549. pooled_text = self._pool_text_features(text_features, text_mask)
  1550. proj_text = self.text_proj(pooled_text)
  1551. proj_queries = self.query_proj(decoder_hidden_states)
  1552. proj_text = proj_text.unsqueeze(-1)
  1553. scores = torch.matmul(proj_queries, proj_text.unsqueeze(0))
  1554. scores = scores * self.scale
  1555. if self.clamp_logits:
  1556. scores = scores.clamp(min=-self.clamp_max_val, max=self.clamp_max_val)
  1557. return scores
  1558. class Sam3MaskEmbedder(nn.Module):
  1559. """
  1560. MLP that embeds object queries for mask prediction.
  1561. Similar to MaskFormer's mask embedder.
  1562. """
  1563. def __init__(self, config: Sam3MaskDecoderConfig):
  1564. super().__init__()
  1565. self.config = config
  1566. hidden_size = config.hidden_size
  1567. self.layers = nn.ModuleList(
  1568. [
  1569. nn.Linear(hidden_size, hidden_size),
  1570. nn.Linear(hidden_size, hidden_size),
  1571. nn.Linear(hidden_size, hidden_size),
  1572. ]
  1573. )
  1574. self.activation = nn.ReLU()
  1575. def forward(self, queries: torch.Tensor) -> torch.Tensor:
  1576. """
  1577. Args:
  1578. queries: Query embeddings [batch_size, num_queries, hidden_size]
  1579. Returns:
  1580. Mask embeddings [batch_size, num_queries, hidden_size]
  1581. """
  1582. hidden_states = queries
  1583. for i, layer in enumerate(self.layers):
  1584. hidden_states = layer(hidden_states)
  1585. if i < len(self.layers) - 1:
  1586. hidden_states = self.activation(hidden_states)
  1587. return hidden_states
  1588. class Sam3PixelDecoder(nn.Module):
  1589. """
  1590. Feature Pyramid Network (FPN) decoder that generates pixel-level features.
  1591. Inspired by MaskFormer's pixel decoder.
  1592. """
  1593. def __init__(self, config: Sam3MaskDecoderConfig):
  1594. super().__init__()
  1595. self.config = config
  1596. hidden_size = config.hidden_size
  1597. num_upsampling_stages = config.num_upsampling_stages
  1598. # Create conv layers and norms for FPN
  1599. self.conv_layers = nn.ModuleList(
  1600. [
  1601. nn.Conv2d(hidden_size, hidden_size, kernel_size=3, stride=1, padding=1)
  1602. for _ in range(num_upsampling_stages)
  1603. ]
  1604. )
  1605. self.norms = nn.ModuleList([nn.GroupNorm(8, hidden_size) for _ in range(num_upsampling_stages)])
  1606. self.out_channels = hidden_size
  1607. def forward(self, backbone_features: list[torch.Tensor]) -> torch.Tensor:
  1608. """
  1609. Args:
  1610. backbone_features: List of backbone features [batch_size, hidden_size, H_i, W_i]
  1611. from low to high resolution (assumes already projected to hidden_size)
  1612. Returns:
  1613. Pixel embeddings [batch_size, hidden_size, H, W] at the finest resolution
  1614. """
  1615. # Start from the coarsest feature (last in list)
  1616. prev_fpn = backbone_features[-1]
  1617. # Iterate through features from coarse to fine (excluding the last which we started with)
  1618. for layer_idx, backbone_feat in enumerate(reversed(backbone_features[:-1])):
  1619. # Upsample previous FPN output to match current backbone feature size
  1620. prev_fpn = F.interpolate(prev_fpn, size=backbone_feat.shape[-2:], mode="nearest")
  1621. # Add skip connection
  1622. prev_fpn = prev_fpn + backbone_feat
  1623. # Apply conv and norm
  1624. prev_fpn = self.conv_layers[layer_idx](prev_fpn)
  1625. prev_fpn = self.norms[layer_idx](prev_fpn)
  1626. prev_fpn = F.relu(prev_fpn)
  1627. return prev_fpn
  1628. class Sam3MaskDecoder(Sam3PreTrainedModel):
  1629. """
  1630. Mask decoder that combines object queries with pixel-level features to predict instance masks.
  1631. Also produces a semantic segmentation output and supports cross-attention to prompts.
  1632. """
  1633. _can_record_outputs = {
  1634. "attentions": Sam3Attention,
  1635. }
  1636. def __init__(self, config: Sam3MaskDecoderConfig):
  1637. super().__init__(config)
  1638. self.config = config
  1639. hidden_size = config.hidden_size
  1640. # Pixel decoder (FPN)
  1641. self.pixel_decoder = Sam3PixelDecoder(config)
  1642. # Mask embedder (MLP to transform queries)
  1643. self.mask_embedder = Sam3MaskEmbedder(config)
  1644. # Projection from pixel decoder output to mask embedding space
  1645. self.instance_projection = nn.Conv2d(self.pixel_decoder.out_channels, hidden_size, kernel_size=1)
  1646. # Semantic segmentation head (always present in UniversalSegmentationHead)
  1647. self.semantic_projection = nn.Conv2d(self.pixel_decoder.out_channels, 1, kernel_size=1)
  1648. self.prompt_cross_attn = Sam3Attention(config)
  1649. self.prompt_cross_attn_norm = nn.LayerNorm(hidden_size)
  1650. self.prompt_cross_attn_dropout = nn.Dropout(config.dropout)
  1651. self.post_init()
  1652. @merge_with_config_defaults
  1653. @capture_outputs
  1654. def forward(
  1655. self,
  1656. decoder_queries: torch.Tensor,
  1657. backbone_features: list[torch.Tensor],
  1658. encoder_hidden_states: torch.Tensor,
  1659. prompt_features: torch.Tensor | None = None,
  1660. prompt_mask: torch.Tensor | None = None,
  1661. **kwargs: Unpack[TransformersKwargs],
  1662. ) -> tuple | Sam3MaskDecoderOutput:
  1663. """
  1664. Args:
  1665. decoder_queries: Decoder output queries [batch_size, num_queries, hidden_size]
  1666. backbone_features: List of backbone features to process through FPN
  1667. encoder_hidden_states: Encoder outputs [batch_size, seq_len, hidden_size]
  1668. prompt_features: Prompt features (text + geometry) for cross-attention [batch_size, prompt_len, hidden_size]
  1669. prompt_mask: Padding mask [batch_size, prompt_len] where True=valid, False=padding
  1670. Returns:
  1671. Sam3MaskDecoderOutput containing predicted masks and semantic segmentation.
  1672. """
  1673. if prompt_features is not None:
  1674. # Cross-attention: encoder features attend to prompt features
  1675. residual = encoder_hidden_states
  1676. normed_hidden_states = self.prompt_cross_attn_norm(encoder_hidden_states)
  1677. cross_attn_mask = None
  1678. if prompt_mask is not None:
  1679. cross_attn_mask = create_bidirectional_mask(
  1680. config=self.config,
  1681. inputs_embeds=normed_hidden_states,
  1682. encoder_hidden_states=prompt_features,
  1683. attention_mask=prompt_mask,
  1684. )
  1685. attn_output, _ = self.prompt_cross_attn(
  1686. query=normed_hidden_states,
  1687. key=prompt_features,
  1688. value=prompt_features,
  1689. attention_mask=cross_attn_mask,
  1690. **kwargs,
  1691. )
  1692. encoder_hidden_states = residual + self.prompt_cross_attn_dropout(attn_output)
  1693. # Process backbone features through FPN to get pixel embeddings
  1694. pixel_embed = self._embed_pixels(
  1695. backbone_features=backbone_features,
  1696. encoder_hidden_states=encoder_hidden_states,
  1697. )
  1698. # Predict instance masks via dot product between query embeddings and pixel embeddings
  1699. instance_embeds = self.instance_projection(pixel_embed)
  1700. mask_embeddings = self.mask_embedder(decoder_queries)
  1701. pred_masks = torch.einsum("bqc,bchw->bqhw", mask_embeddings, instance_embeds)
  1702. # Generate semantic segmentation
  1703. semantic_seg = self.semantic_projection(pixel_embed)
  1704. return Sam3MaskDecoderOutput(
  1705. pred_masks=pred_masks,
  1706. semantic_seg=semantic_seg,
  1707. )
  1708. def _embed_pixels(
  1709. self,
  1710. backbone_features: list[torch.Tensor],
  1711. encoder_hidden_states: torch.Tensor,
  1712. ) -> torch.Tensor:
  1713. """
  1714. Embed pixels by combining backbone FPN features with encoder vision features.
  1715. The encoder vision features replace the finest-resolution backbone feature.
  1716. Args:
  1717. backbone_features: List of backbone features [batch_size, C, H_i, W_i]
  1718. encoder_hidden_states: Encoder outputs [batch_size, seq_len, hidden_size]
  1719. Returns:
  1720. Pixel embeddings [batch_size, hidden_size, H, W]
  1721. """
  1722. backbone_visual_feats = [feat.clone() for feat in backbone_features]
  1723. # Extract vision features from encoder output and reshape to spatial format
  1724. spatial_dim = backbone_features[-1].shape[-2] * backbone_features[-1].shape[-1]
  1725. encoder_visual_embed = encoder_hidden_states[:, :spatial_dim, :]
  1726. batch_size, _, hidden_size = encoder_visual_embed.shape
  1727. height, width = backbone_features[-1].shape[-2:]
  1728. encoder_visual_embed = encoder_visual_embed.transpose(1, 2).reshape(batch_size, hidden_size, height, width)
  1729. # Replace finest backbone feature with encoder vision features
  1730. backbone_visual_feats[-1] = encoder_visual_embed
  1731. # Process through FPN decoder
  1732. pixel_embed = self.pixel_decoder(backbone_visual_feats)
  1733. return pixel_embed
  1734. class Sam3Model(Sam3PreTrainedModel):
  1735. input_modalities = ["image", "text"]
  1736. base_model_prefix = "detector_model"
  1737. _keys_to_ignore_on_load_unexpected = [
  1738. r"^tracker_model.",
  1739. r"^tracker_neck.",
  1740. ]
  1741. def __init__(self, config: Sam3Config):
  1742. # loading from a sam3_video config
  1743. if hasattr(config, "detector_config") and config.detector_config is not None:
  1744. detector_config = config.detector_config
  1745. if isinstance(detector_config, dict):
  1746. detector_config = Sam3Config(**detector_config)
  1747. config = detector_config
  1748. super().__init__(config)
  1749. self.vision_encoder = Sam3VisionModel(config.vision_config)
  1750. self.text_encoder = CLIPTextModelWithProjection(config.text_config)
  1751. self.vocab_size = config.text_config.vocab_size
  1752. # Project text features from text encoder hidden size to model hidden size
  1753. # CLIP text encoder outputs 1024-dim features, but we need 256-dim for DETR
  1754. self.text_projection = nn.Linear(config.text_config.hidden_size, config.detr_encoder_config.hidden_size)
  1755. # Pass _attn_implementation to subconfigs BEFORE creating modules
  1756. config.geometry_encoder_config._attn_implementation = config._attn_implementation
  1757. config.detr_encoder_config._attn_implementation = config._attn_implementation
  1758. config.detr_decoder_config._attn_implementation = config._attn_implementation
  1759. config.mask_decoder_config._attn_implementation = config._attn_implementation
  1760. self.geometry_encoder = Sam3GeometryEncoder(config.geometry_encoder_config)
  1761. self.detr_encoder = Sam3DetrEncoder(config.detr_encoder_config)
  1762. self.detr_decoder = Sam3DetrDecoder(config.detr_decoder_config)
  1763. self.mask_decoder = Sam3MaskDecoder(config.mask_decoder_config)
  1764. # Dot product scoring to compute classification scores
  1765. self.dot_product_scoring = Sam3DotProductScoring(config)
  1766. self.post_init()
  1767. @can_return_tuple
  1768. @auto_docstring
  1769. def get_text_features(
  1770. self,
  1771. input_ids: torch.LongTensor,
  1772. attention_mask: torch.Tensor | None = None,
  1773. **kwargs: Unpack[TransformersKwargs],
  1774. ) -> tuple | BaseModelOutputWithPooling:
  1775. r"""
  1776. Example:
  1777. ```python
  1778. >>> from transformers import Sam3Model, Sam3Processor
  1779. >>> from PIL import Image
  1780. >>> import httpx
  1781. >>> from io import BytesIO
  1782. >>> model = Sam3Model.from_pretrained("facebook/sam3")
  1783. >>> processor = Sam3Processor.from_pretrained("facebook/sam3")
  1784. >>> # Pre-compute text embeddings
  1785. >>> text_inputs = processor(text="cat", return_tensors="pt")
  1786. >>> text_embeds = model.get_text_features(**text_inputs).pooler_output
  1787. >>> # Reuse text embeddings for multiple images
  1788. >>> url = "http://images.cocodataset.org/val2017/000000077595.jpg"
  1789. >>> with httpx.stream("GET", url) as response:
  1790. ... image = Image.open(BytesIO(response.read()))
  1791. >>> img_inputs = processor(images=image, return_tensors="pt")
  1792. >>> outputs = model(pixel_values=img_inputs.pixel_values, text_embeds=text_embeds)
  1793. ```
  1794. """
  1795. text_outputs = self.text_encoder(
  1796. input_ids=input_ids, attention_mask=attention_mask, return_dict=True, **kwargs
  1797. )
  1798. last_hidden_state = text_outputs.last_hidden_state
  1799. text_outputs.pooler_output = self.text_projection(last_hidden_state)
  1800. return text_outputs
  1801. @auto_docstring
  1802. def get_vision_features(
  1803. self,
  1804. pixel_values: torch.FloatTensor,
  1805. **kwargs: Unpack[TransformersKwargs],
  1806. ) -> Sam3VisionEncoderOutput:
  1807. r"""
  1808. Example:
  1809. ```python
  1810. >>> from transformers import Sam3Model, Sam3Processor
  1811. >>> from PIL import Image
  1812. >>> import httpx
  1813. >>> from io import BytesIO
  1814. >>> model = Sam3Model.from_pretrained("facebook/sam3")
  1815. >>> processor = Sam3Processor.from_pretrained("facebook/sam3")
  1816. >>> # Pre-compute vision embeddings
  1817. >>> url = "http://images.cocodataset.org/val2017/000000077595.jpg"
  1818. >>> with httpx.stream("GET", url) as response:
  1819. ... image = Image.open(BytesIO(response.read()))
  1820. >>> img_inputs = processor(images=image, return_tensors="pt")
  1821. >>> vision_embeds = model.get_vision_features(pixel_values=img_inputs.pixel_values)
  1822. >>> # Reuse vision embeddings for multiple text prompts
  1823. >>> text_inputs = processor(text="cat", return_tensors="pt")
  1824. >>> outputs = model(vision_embeds=vision_embeds, input_ids=text_inputs.input_ids)
  1825. ```
  1826. """
  1827. vision_outputs = self.vision_encoder(pixel_values, **kwargs)
  1828. return vision_outputs
  1829. @can_return_tuple
  1830. @auto_docstring
  1831. def forward(
  1832. self,
  1833. pixel_values: torch.FloatTensor | None = None,
  1834. vision_embeds: Sam3VisionEncoderOutput | None = None,
  1835. input_ids: torch.LongTensor | None = None,
  1836. attention_mask: torch.Tensor | None = None,
  1837. text_embeds: torch.FloatTensor | None = None,
  1838. input_boxes: torch.FloatTensor | None = None,
  1839. input_boxes_labels: torch.LongTensor | None = None,
  1840. **kwargs: Unpack[TransformersKwargs],
  1841. ) -> Sam3ImageSegmentationOutput:
  1842. r"""
  1843. vision_embeds (`Sam3VisionEncoderOutput`, *optional*):
  1844. Pre-computed vision embeddings. Can be used to easily reuse vision embeddings. If provided, `pixel_values`
  1845. should not be passed. Mutually exclusive with `pixel_values`.
  1846. text_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  1847. Pre-computed text embeddings. Can be used to easily reuse text embeddings. If provided, `input_ids`
  1848. should not be passed. Mutually exclusive with `input_ids`.
  1849. input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`, *optional*):
  1850. Normalized box coordinates in [0, 1] range, in (cx, cy, w, h) format.
  1851. input_boxes_labels (`torch.LongTensor` of shape `(batch_size, num_boxes)`, *optional*):
  1852. Labels for boxes: 1 (positive), 0 (negative).
  1853. Example:
  1854. ```python
  1855. >>> from PIL import Image
  1856. >>> import httpx
  1857. >>> from io import BytesIO
  1858. >>> from transformers import AutoModel, AutoProcessor
  1859. >>> model = AutoModel.from_pretrained("facebook/sam3")
  1860. >>> processor = AutoProcessor.from_pretrained("facebook/sam3")
  1861. >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
  1862. >>> with httpx.stream("GET", url) as response:
  1863. ... image = Image.open(BytesIO(response.read())).convert("RGB")
  1864. >>> text = "car"
  1865. >>> inputs = processor(images=image, text=text, return_tensors="pt")
  1866. >>> # Get segmentation output
  1867. >>> outputs = model(**inputs)
  1868. >>> pred_masks = outputs.pred_masks
  1869. >>> pred_boxes = outputs.pred_boxes
  1870. ```
  1871. """
  1872. if (pixel_values is None) == (vision_embeds is None):
  1873. raise ValueError("You must specify exactly one of pixel_values or vision_embeds")
  1874. if (input_ids is None) == (text_embeds is None):
  1875. raise ValueError("You must specify exactly one of input_ids or text_embeds")
  1876. if pixel_values is not None:
  1877. batch_size = pixel_values.shape[0]
  1878. device = pixel_values.device
  1879. else:
  1880. batch_size = vision_embeds.fpn_hidden_states[0].shape[0]
  1881. device = vision_embeds.fpn_hidden_states[0].device
  1882. if vision_embeds is None:
  1883. vision_outputs = self.vision_encoder(pixel_values, **kwargs)
  1884. else:
  1885. vision_outputs = vision_embeds
  1886. fpn_hidden_states = vision_outputs.fpn_hidden_states[:-1]
  1887. fpn_position_encoding = vision_outputs.fpn_position_encoding[:-1]
  1888. if text_embeds is None:
  1889. text_features = self.get_text_features(
  1890. input_ids=input_ids, attention_mask=attention_mask, return_dict=True
  1891. ).pooler_output
  1892. else:
  1893. text_features = text_embeds
  1894. text_mask = attention_mask.bool() if attention_mask is not None else None
  1895. has_geometry_prompts = input_boxes is not None and input_boxes.numel() > 0
  1896. geometry_prompt_features = None
  1897. geometry_prompt_mask = None
  1898. if has_geometry_prompts:
  1899. if input_boxes is not None and input_boxes.numel() > 0:
  1900. box_embeddings = input_boxes # [batch_size, num_boxes, 4]
  1901. box_labels = (
  1902. input_boxes_labels
  1903. if input_boxes_labels is not None
  1904. else torch.ones_like(box_embeddings[..., 0], dtype=torch.long)
  1905. )
  1906. box_mask = (
  1907. (input_boxes_labels != -10)
  1908. if input_boxes_labels is not None
  1909. else torch.ones(batch_size, input_boxes.shape[1], dtype=torch.bool, device=device)
  1910. )
  1911. box_labels = torch.where(box_labels == -10, 0, box_labels)
  1912. else:
  1913. box_embeddings = torch.zeros(batch_size, 0, 4, dtype=text_features.dtype, device=device)
  1914. box_labels = torch.zeros(batch_size, 0, dtype=torch.long, device=device)
  1915. box_mask = torch.zeros(batch_size, 0, dtype=torch.bool, device=device)
  1916. geometry_outputs = self.geometry_encoder(
  1917. box_embeddings=box_embeddings,
  1918. box_mask=box_mask,
  1919. box_labels=box_labels,
  1920. img_feats=fpn_hidden_states,
  1921. img_pos_embeds=fpn_position_encoding,
  1922. )
  1923. geometry_prompt_features = geometry_outputs.last_hidden_state
  1924. geometry_prompt_mask = geometry_outputs.attention_mask
  1925. if geometry_prompt_features is not None:
  1926. # Repeat text_features for all geometry prompts
  1927. if text_features.shape[0] == 1 and geometry_prompt_features.shape[0] > 1:
  1928. text_features = text_features.repeat(geometry_prompt_features.shape[0], 1, 1)
  1929. combined_prompt_features = torch.cat([text_features, geometry_prompt_features], dim=1)
  1930. if text_mask is not None and text_mask.shape[0] == 1 and geometry_prompt_mask.shape[0] > 1:
  1931. text_mask = text_mask.repeat(geometry_prompt_mask.shape[0], 1)
  1932. if text_mask is not None and geometry_prompt_mask is not None:
  1933. combined_prompt_mask = torch.cat([text_mask, geometry_prompt_mask], dim=1)
  1934. elif text_mask is not None:
  1935. geo_valid_mask = torch.ones(
  1936. batch_size, geometry_prompt_features.shape[1], dtype=torch.bool, device=device
  1937. )
  1938. combined_prompt_mask = torch.cat([text_mask, geo_valid_mask], dim=1)
  1939. elif geometry_prompt_mask is not None:
  1940. text_valid_mask = torch.ones(batch_size, text_features.shape[1], dtype=torch.bool, device=device)
  1941. combined_prompt_mask = torch.cat([text_valid_mask, geometry_prompt_mask], dim=1)
  1942. else:
  1943. combined_prompt_mask = None
  1944. else:
  1945. combined_prompt_features = text_features
  1946. combined_prompt_mask = text_mask
  1947. encoder_outputs = self.detr_encoder(
  1948. vision_features=[fpn_hidden_states[-1]],
  1949. text_features=combined_prompt_features,
  1950. vision_pos_embeds=[fpn_position_encoding[-1]],
  1951. text_mask=combined_prompt_mask,
  1952. **kwargs,
  1953. )
  1954. decoder_outputs = self.detr_decoder(
  1955. vision_features=encoder_outputs.last_hidden_state,
  1956. text_features=encoder_outputs.text_features,
  1957. vision_pos_encoding=encoder_outputs.pos_embeds_flattened,
  1958. text_mask=combined_prompt_mask,
  1959. spatial_shapes=encoder_outputs.spatial_shapes,
  1960. **kwargs,
  1961. )
  1962. # Refine boxes from decoder
  1963. all_box_offsets = self.detr_decoder.box_head(decoder_outputs.intermediate_hidden_states)
  1964. reference_boxes_inv_sig = inverse_sigmoid(decoder_outputs.reference_boxes)
  1965. all_pred_boxes_cxcywh = (reference_boxes_inv_sig + all_box_offsets).sigmoid()
  1966. all_pred_boxes = box_cxcywh_to_xyxy(all_pred_boxes_cxcywh)
  1967. all_pred_logits = self.dot_product_scoring(
  1968. decoder_hidden_states=decoder_outputs.intermediate_hidden_states,
  1969. text_features=encoder_outputs.text_features,
  1970. text_mask=combined_prompt_mask,
  1971. ).squeeze(-1)
  1972. pred_logits = all_pred_logits[-1]
  1973. pred_boxes = all_pred_boxes[-1]
  1974. decoder_hidden_states = decoder_outputs.intermediate_hidden_states[-1]
  1975. presence_logits = decoder_outputs.presence_logits[-1]
  1976. mask_outputs = self.mask_decoder(
  1977. decoder_queries=decoder_hidden_states,
  1978. backbone_features=list(fpn_hidden_states),
  1979. encoder_hidden_states=encoder_outputs.last_hidden_state,
  1980. prompt_features=combined_prompt_features,
  1981. prompt_mask=combined_prompt_mask,
  1982. **kwargs,
  1983. )
  1984. return Sam3ImageSegmentationOutput(
  1985. pred_masks=mask_outputs.pred_masks,
  1986. pred_boxes=pred_boxes,
  1987. pred_logits=pred_logits,
  1988. presence_logits=presence_logits,
  1989. semantic_seg=mask_outputs.semantic_seg,
  1990. decoder_hidden_states=decoder_outputs.hidden_states,
  1991. decoder_reference_boxes=decoder_outputs.reference_boxes,
  1992. encoder_hidden_states=encoder_outputs.hidden_states,
  1993. vision_hidden_states=vision_outputs.hidden_states,
  1994. vision_attentions=vision_outputs.attentions,
  1995. detr_encoder_attentions=encoder_outputs.attentions,
  1996. detr_decoder_attentions=decoder_outputs.attentions,
  1997. mask_decoder_attentions=mask_outputs.attentions,
  1998. )
  1999. __all__ = ["Sam3Model", "Sam3VisionModel", "Sam3ViTModel", "Sam3PreTrainedModel"]