modeling_sam2.py 71 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/sam2/modular_sam2.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_sam2.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. import math
  21. from collections.abc import Callable
  22. from dataclasses import dataclass
  23. import numpy as np
  24. import torch
  25. import torch.nn as nn
  26. import torch.nn.functional as F
  27. from torch import Tensor
  28. from ... import initialization as init
  29. from ...activations import ACT2FN
  30. from ...modeling_layers import GradientCheckpointingLayer
  31. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
  32. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  33. from ...processing_utils import Unpack
  34. from ...pytorch_utils import compile_compatible_method_lru_cache
  35. from ...utils import (
  36. ModelOutput,
  37. auto_docstring,
  38. can_return_tuple,
  39. logging,
  40. )
  41. from ...utils.generic import TransformersKwargs, is_flash_attention_requested, merge_with_config_defaults
  42. from ...utils.output_capturing import OutputRecorder, capture_outputs
  43. from ..auto import AutoModel
  44. from .configuration_sam2 import (
  45. Sam2Config,
  46. Sam2HieraDetConfig,
  47. Sam2MaskDecoderConfig,
  48. Sam2PromptEncoderConfig,
  49. Sam2VisionConfig,
  50. )
# Module-level logger created through the library's logging utilities, keyed by this module's name.
logger = logging.get_logger(__name__)
  52. @dataclass
  53. @auto_docstring(custom_intro="Base class for the vision encoder's outputs.")
  54. class Sam2VisionEncoderOutput(BaseModelOutputWithPooling):
  55. r"""
  56. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, height, width, hidden_size)`):
  57. Sequence of hidden-states at the output of the last layer of the model.
  58. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  59. Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
  60. one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`. Hidden-states of the
  61. model at the output of each stage.
  62. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  63. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  64. sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
  65. the self-attention heads.
  66. fpn_hidden_states (`tuple(torch.FloatTensor)`):
  67. Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
  68. `(batch_size, hidden_size, height, width)`. Feature maps from the Feature Pyramid Network neck.
  69. fpn_position_encoding (`tuple(torch.FloatTensor)`):
  70. Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
  71. `(batch_size, hidden_size, height, width)`. Positional encodings corresponding to the `fpn_hidden_states`.
  72. """
  73. fpn_hidden_states: torch.FloatTensor | None = None
  74. fpn_position_encoding: torch.FloatTensor | None = None
  75. @dataclass
  76. @auto_docstring(custom_intro="Base class for the Sam2 model's output.")
  77. class Sam2ImageSegmentationOutput(ModelOutput):
  78. r"""
  79. iou_scores (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks)`):
  80. The Intersection over Union (IoU) scores of the predicted masks.
  81. pred_masks (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks, height, width)`):
  82. The predicted low-resolution masks. This is an alias for `low_res_masks`. These masks need to be post-processed
  83. by the processor to be brought to the original image size.
  84. object_score_logits (`torch.FloatTensor` of shape `(batch_size, point_batch_size, 1)`):
  85. Logits for the object score, indicating if an object is present.
  86. image_embeddings (`tuple(torch.FloatTensor)`):
  87. The features from the FPN, which are used by the mask decoder. This is a tuple of `torch.FloatTensor` where each
  88. tensor has shape `(batch_size, channels, height, width)`.
  89. vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
  90. Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`.
  91. Hidden-states of the vision model at the output of each stage.
  92. vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
  93. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
  94. Attentions weights of the vision model.
  95. mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
  96. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
  97. Attentions weights of the mask decoder.
  98. """
  99. iou_scores: torch.FloatTensor | None = None
  100. pred_masks: torch.FloatTensor | None = None
  101. object_score_logits: torch.FloatTensor | None = None
  102. image_embeddings: tuple[torch.FloatTensor, ...] = None
  103. vision_hidden_states: tuple[torch.FloatTensor, ...] | None = None
  104. vision_attentions: tuple[torch.FloatTensor, ...] | None = None
  105. mask_decoder_attentions: tuple[torch.FloatTensor, ...] | None = None
  106. class Sam2PatchEmbeddings(nn.Module):
  107. r"""
  108. Turns pixel values into patch embeddings for transformer consumption.
  109. Args:
  110. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  111. Pixel values. Pixel values can be obtained using
  112. [`AutoImageProcessor`]. See [`Sam2ImageProcessor.__call__`] for details.
  113. Returns:
  114. embeddings (`torch.FloatTensor`):
  115. Patch embeddings depend on image_size, patch_kernel_size, patch_stride and patch_padding
  116. """
  117. def __init__(self, config: Sam2HieraDetConfig):
  118. super().__init__()
  119. num_channels = config.num_channels
  120. hidden_size = config.hidden_size
  121. self.projection = nn.Conv2d(
  122. num_channels,
  123. hidden_size,
  124. kernel_size=config.patch_kernel_size,
  125. stride=config.patch_stride,
  126. padding=config.patch_padding,
  127. )
  128. def forward(self, pixel_values):
  129. _, num_channels, height, width = pixel_values.shape
  130. embeddings = self.projection(pixel_values.to(self.projection.weight.dtype)).permute(0, 2, 3, 1)
  131. return embeddings
  132. # copied and adapted from original implementation, also practically equal to DetrSinePositionEmbedding
  133. class Sam2SinePositionEmbedding(nn.Module):
  134. """
  135. This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
  136. need paper, generalized to work on images.
  137. """
  138. def __init__(
  139. self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: float | None = None
  140. ):
  141. super().__init__()
  142. if scale is not None and normalize is False:
  143. raise ValueError("normalize should be True if scale is passed")
  144. self.num_pos_feats = num_pos_feats
  145. self.temperature = temperature
  146. self.normalize = normalize
  147. self.scale = 2 * math.pi if scale is None else scale
  148. @compile_compatible_method_lru_cache(maxsize=1)
  149. def forward(
  150. self,
  151. shape: torch.Size,
  152. device: torch.device | str,
  153. dtype: torch.dtype,
  154. mask: Tensor | None = None,
  155. ) -> Tensor:
  156. if mask is None:
  157. mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool)
  158. not_mask = (~mask).to(dtype)
  159. y_embed = not_mask.cumsum(1)
  160. x_embed = not_mask.cumsum(2)
  161. if self.normalize:
  162. eps = 1e-6
  163. y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
  164. x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
  165. dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=device).to(dtype)
  166. dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)
  167. pos_x = x_embed[:, :, :, None] / dim_t
  168. pos_y = y_embed[:, :, :, None] / dim_t
  169. pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
  170. pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
  171. pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
  172. return pos
  173. class Sam2VisionNeck(nn.Module):
  174. def __init__(self, config: Sam2VisionConfig):
  175. super().__init__()
  176. self.config = config
  177. self.position_encoding = Sam2SinePositionEmbedding(num_pos_feats=config.fpn_hidden_size // 2, normalize=True)
  178. self.convs = nn.ModuleList()
  179. for in_channels in config.backbone_channel_list:
  180. self.convs.append(
  181. nn.Conv2d(
  182. in_channels=in_channels,
  183. out_channels=config.fpn_hidden_size,
  184. kernel_size=config.fpn_kernel_size,
  185. stride=config.fpn_stride,
  186. padding=config.fpn_padding,
  187. ),
  188. )
  189. self.fpn_top_down_levels = config.fpn_top_down_levels
  190. def forward(self, hidden_states: torch.Tensor) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]:
  191. fpn_hidden_states = ()
  192. fpn_position_encoding = ()
  193. # forward in top-down order (from low to high resolution)
  194. n = len(self.convs) - 1
  195. for i in range(n, -1, -1):
  196. lateral_features = hidden_states[i].permute(0, 3, 1, 2)
  197. lateral_features = self.convs[n - i](lateral_features.to(self.convs[i].weight.dtype))
  198. if i not in self.fpn_top_down_levels or i == n:
  199. prev_features = lateral_features
  200. else:
  201. top_down_features = F.interpolate(
  202. prev_features.to(dtype=torch.float32),
  203. scale_factor=2.0,
  204. mode="nearest",
  205. align_corners=None,
  206. antialias=False,
  207. ).to(lateral_features.dtype)
  208. prev_features = lateral_features + top_down_features
  209. prev_position_encoding = self.position_encoding(
  210. prev_features.shape, prev_features.device, prev_features.dtype
  211. ).to(prev_features.dtype)
  212. fpn_hidden_states += (prev_features,)
  213. fpn_position_encoding += (prev_position_encoding,)
  214. return fpn_hidden_states, fpn_position_encoding
  215. def eager_attention_forward(
  216. module: nn.Module,
  217. query: torch.Tensor,
  218. key: torch.Tensor,
  219. value: torch.Tensor,
  220. attention_mask: torch.Tensor | None,
  221. scaling: float,
  222. dropout: float = 0.0,
  223. **kwargs,
  224. ):
  225. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  226. if attention_mask is not None:
  227. attn_weights = attn_weights + attention_mask
  228. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  229. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  230. attn_output = torch.matmul(attn_weights, value)
  231. attn_output = attn_output.transpose(1, 2).contiguous()
  232. return attn_output, attn_weights
  233. def do_pool(x: torch.Tensor, query_stride: int | None = None) -> torch.Tensor:
  234. if query_stride is None:
  235. return x
  236. # (B, H, W, C) -> (B, C, H, W)
  237. x = x.permute(0, 3, 1, 2)
  238. x = nn.functional.max_pool2d(x, kernel_size=query_stride, stride=query_stride, ceil_mode=False)
  239. # (B, C, H', W') -> (B, H', W', C)
  240. x = x.permute(0, 2, 3, 1)
  241. return x
  242. class Sam2MultiScaleAttention(nn.Module):
  243. def __init__(
  244. self,
  245. config: Sam2HieraDetConfig,
  246. dim: int,
  247. dim_out: int,
  248. num_attention_heads: int,
  249. query_stride: tuple[int, int] | None = None,
  250. ):
  251. super().__init__()
  252. self.config = config
  253. self.dim = dim
  254. self.dim_out = dim_out
  255. self.query_stride = query_stride
  256. self.num_attention_heads = num_attention_heads
  257. head_dim = dim_out // num_attention_heads
  258. self.scale = head_dim**-0.5
  259. self.qkv = nn.Linear(dim, dim_out * 3)
  260. self.proj = nn.Linear(dim_out, dim_out)
  261. self.is_causal = False
  262. def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
  263. batch_size, height, width, _ = hidden_states.shape
  264. # qkv with shape (B, H * W, 3, nHead, C)
  265. qkv = self.qkv(hidden_states).reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
  266. # q, k, v with shape (B, H * W, nheads, C)
  267. query, key, value = torch.unbind(qkv, 2)
  268. attn_weights = (query * self.scale) @ key.transpose(-2, -1)
  269. attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
  270. # Q pooling (for downsample at stage changes)
  271. if self.query_stride:
  272. query = do_pool(query.reshape(batch_size, height, width, -1), self.query_stride)
  273. height, width = query.shape[1:3] # downsampled shape
  274. query = query.reshape(batch_size, height * width, self.num_attention_heads, -1)
  275. # transpose query, key, value to (B, nHead, H * W, C)
  276. query = query.transpose(1, 2)
  277. key = key.transpose(1, 2)
  278. value = value.transpose(1, 2)
  279. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  280. self.config._attn_implementation, eager_attention_forward
  281. )
  282. attn_output, _ = attention_interface(
  283. self,
  284. query,
  285. key,
  286. value,
  287. attention_mask=None,
  288. is_causal=self.is_causal,
  289. scaling=self.scale,
  290. **kwargs,
  291. )
  292. attn_output = attn_output.reshape(batch_size, height, width, -1)
  293. attn_output = self.proj(attn_output)
  294. return attn_output
  295. class Sam2FeedForward(nn.Module):
  296. def __init__(
  297. self,
  298. input_dim: int,
  299. hidden_dim: int,
  300. output_dim: int,
  301. num_layers: int,
  302. activation: str = "relu",
  303. sigmoid_output: bool = False,
  304. ):
  305. super().__init__()
  306. self.num_layers = num_layers
  307. self.activation = ACT2FN[activation]
  308. self.proj_in = nn.Linear(input_dim, hidden_dim)
  309. self.proj_out = nn.Linear(hidden_dim, output_dim)
  310. self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)])
  311. self.sigmoid_output = sigmoid_output
  312. def forward(self, hidden_states):
  313. hidden_states = self.proj_in(hidden_states)
  314. hidden_states = self.activation(hidden_states)
  315. for layer in self.layers:
  316. hidden_states = self.activation(layer(hidden_states))
  317. hidden_states = self.proj_out(hidden_states)
  318. if self.sigmoid_output:
  319. hidden_states = F.sigmoid(hidden_states)
  320. return hidden_states
  321. def window_partition(hidden_state, window_size):
  322. """
  323. Partition into non-overlapping windows with padding if needed.
  324. Args:
  325. hidden_state (`torch.Tensor`):
  326. Input tokens with [batch_size, height, width, num_channels].
  327. window_size (`int`):
  328. Window size.
  329. Returns:
  330. `tuple(torch.FloatTensor)` comprising various elements:
  331. - windows: windows after partition with [batch_size * num_windows, window_size, window_size, num_channels].
  332. - (padded_height, padded_width): padded height and width before partition
  333. """
  334. batch_size, height, width, num_channels = hidden_state.shape
  335. pad_height = (window_size - height % window_size) % window_size
  336. pad_width = (window_size - width % window_size) % window_size
  337. # Noop in case pad_width == 0 and pad_height == 0.
  338. hidden_state = nn.functional.pad(hidden_state, (0, 0, 0, pad_width, 0, pad_height))
  339. padded_height, padded_width = height + pad_height, width + pad_width
  340. hidden_state = hidden_state.view(
  341. batch_size, padded_height // window_size, window_size, padded_width // window_size, window_size, num_channels
  342. )
  343. windows = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
  344. return windows, (padded_height, padded_width)
  345. def window_unpartition(windows, window_size, pad_height_width, height_width):
  346. """
  347. Window unpartition into original sequences and removing padding.
  348. Args:
  349. windows (`torch.Tensor`):
  350. Input tokens with [batch_size * num_windows, window_size, window_size, num_channels].
  351. window_size (`int`):
  352. Window size.
  353. pad_height_width (`tuple[int]`):
  354. Padded height and width (padded_height, padded_width).
  355. height_width (`tuple[int]`):
  356. Original height and width before padding.
  357. Returns:
  358. hidden_state: unpartitioned sequences with [batch_size, height, width, num_channels].
  359. """
  360. padded_height, padded_width = pad_height_width
  361. height, width = height_width
  362. batch_size = windows.shape[0] // (padded_height * padded_width // window_size // window_size)
  363. hidden_state = windows.view(
  364. batch_size, padded_height // window_size, padded_width // window_size, window_size, window_size, -1
  365. )
  366. hidden_state = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous()
  367. hidden_state = hidden_state.view(batch_size, padded_height, padded_width, -1)
  368. # We always have height <= padded_height and width <= padded_width
  369. hidden_state = hidden_state[:, :height, :width, :].contiguous()
  370. return hidden_state
  371. class Sam2MultiScaleBlock(GradientCheckpointingLayer):
  372. def __init__(
  373. self,
  374. config: Sam2HieraDetConfig,
  375. stage_idx: int,
  376. block_idx: int,
  377. total_block_idx: int,
  378. ):
  379. super().__init__()
  380. # take embed dim from previous stage if first block of stage
  381. self.dim = (
  382. config.embed_dim_per_stage[stage_idx - 1]
  383. if stage_idx > 0 and block_idx == 0
  384. else config.embed_dim_per_stage[stage_idx]
  385. )
  386. self.dim_out = config.embed_dim_per_stage[stage_idx]
  387. self.layer_norm1 = nn.LayerNorm(self.dim, eps=config.layer_norm_eps)
  388. # take window size from previous stage if first block of stage
  389. self.window_size = (
  390. config.window_size_per_stage[stage_idx - 1]
  391. if stage_idx > 0 and block_idx == 0
  392. else config.window_size_per_stage[stage_idx]
  393. )
  394. self.window_size = 0 if total_block_idx in config.global_attention_blocks else self.window_size
  395. # use query stride for first block of stage if stage is a query pool stage
  396. self.query_stride = (
  397. config.query_stride if 0 < stage_idx <= config.num_query_pool_stages and block_idx == 0 else None
  398. )
  399. self.attn = Sam2MultiScaleAttention(
  400. config,
  401. self.dim,
  402. self.dim_out,
  403. num_attention_heads=config.num_attention_heads_per_stage[stage_idx],
  404. query_stride=self.query_stride,
  405. )
  406. self.layer_norm2 = nn.LayerNorm(self.dim_out, eps=config.layer_norm_eps)
  407. self.mlp = Sam2FeedForward(
  408. self.dim_out,
  409. int(self.dim_out * config.mlp_ratio),
  410. self.dim_out,
  411. num_layers=2,
  412. activation=config.hidden_act,
  413. )
  414. if self.dim != self.dim_out:
  415. self.proj = nn.Linear(self.dim, self.dim_out)
  416. def forward(
  417. self,
  418. hidden_states: torch.Tensor,
  419. **kwargs: Unpack[TransformersKwargs],
  420. ) -> torch.FloatTensor:
  421. residual = hidden_states # batch_size, height, width, channel
  422. hidden_states = self.layer_norm1(hidden_states)
  423. # Skip connection
  424. if self.dim != self.dim_out:
  425. residual = do_pool(self.proj(hidden_states), self.query_stride)
  426. # Window partition
  427. window_size = self.window_size
  428. if self.window_size > 0:
  429. H, W = hidden_states.shape[1], hidden_states.shape[2]
  430. hidden_states, pad_hw = window_partition(hidden_states, window_size)
  431. # Window Attention + Q Pooling (if stage change)
  432. attn_output = self.attn(
  433. hidden_states=hidden_states,
  434. **kwargs,
  435. )
  436. hidden_states = attn_output
  437. if self.query_stride:
  438. # Shapes have changed due to Q pooling
  439. window_size = self.window_size // self.query_stride[0]
  440. H, W = residual.shape[1:3]
  441. pad_h = (window_size - H % window_size) % window_size
  442. pad_w = (window_size - W % window_size) % window_size
  443. pad_hw = (H + pad_h, W + pad_w)
  444. # Reverse window partition
  445. if self.window_size > 0:
  446. hidden_states = window_unpartition(hidden_states, window_size, pad_hw, (H, W))
  447. hidden_states = residual + hidden_states
  448. layernorm_output = self.layer_norm2(hidden_states)
  449. hidden_states = hidden_states + self.mlp(layernorm_output)
  450. return hidden_states
  451. @dataclass
  452. @auto_docstring(
  453. custom_intro="""
  454. Hiera model's outputs that also contains a pooling of the last hidden states.
  455. """
  456. )
  457. class Sam2HieraDetModelOutput(ModelOutput):
  458. r"""
  459. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, height, width, hidden_size)`):
  460. hidden-states at the output of the last layer of the model.
  461. intermediate_hidden_states (`tuple[torch.FloatTensor]` of shape `(batch_size, height, width, hidden_size)`):
  462. Sequence of hidden-states at the output of the intermediate layers of the model.
  463. """
  464. last_hidden_state: torch.FloatTensor | None = None
  465. intermediate_hidden_states: tuple[torch.FloatTensor, ...] | None = None
  466. hidden_states: tuple[torch.FloatTensor, ...] | None = None
  467. attentions: tuple[torch.FloatTensor, ...] | None = None
  468. @auto_docstring
  469. class Sam2PreTrainedModel(PreTrainedModel):
  470. config_class = Sam2Config
  471. base_model_prefix = "sam2"
  472. main_input_name = "pixel_values"
  473. input_modalities = ("image",)
  474. _supports_sdpa = True
  475. _supports_flash_attn = True
  476. _supports_attention_backend = True
  477. _keys_to_ignore_on_load_unexpected = [
  478. r"^memory_.*",
  479. r"^mask_downsample.*",
  480. r"^object_pointer_proj.*",
  481. r"^temporal_positional_encoding_projection_layer.*",
  482. "no_memory_positional_encoding",
  483. "no_object_pointer",
  484. "occlusion_spatial_embedding_parameter",
  485. ]
  486. @torch.no_grad()
  487. def _init_weights(self, module):
  488. super()._init_weights(module)
  489. if isinstance(module, Sam2HieraDetModel):
  490. if module.pos_embed is not None:
  491. init.zeros_(module.pos_embed)
  492. if module.pos_embed_window is not None:
  493. init.zeros_(module.pos_embed_window)
  494. elif isinstance(module, Sam2PositionalEmbedding):
  495. init.normal_(module.positional_embedding, std=module.scale)
  496. elif isinstance(module, Sam2Model):
  497. if module.no_memory_embedding is not None:
  498. init.zeros_(module.no_memory_embedding)
  499. class Sam2HieraDetModel(Sam2PreTrainedModel):
  500. config_class = Sam2HieraDetConfig
  501. main_input_name = "pixel_values"
  502. _can_record_outputs = {
  503. "hidden_states": Sam2MultiScaleBlock,
  504. "attentions": Sam2MultiScaleAttention,
  505. }
  506. def __init__(self, config: Sam2HieraDetConfig):
  507. super().__init__(config)
  508. self.patch_embed = Sam2PatchEmbeddings(config)
  509. # Windowed positional embedding (https://huggingface.co/papers/2311.05613)
  510. self.pos_embed = nn.Parameter(
  511. torch.zeros(1, config.hidden_size, *config.window_positional_embedding_background_size)
  512. )
  513. self.pos_embed_window = nn.Parameter(
  514. torch.zeros(1, config.hidden_size, config.window_size_per_stage[0], config.window_size_per_stage[0])
  515. )
  516. self.stage_ends = (np.cumsum(config.blocks_per_stage) - 1).tolist()
  517. self.blocks = nn.ModuleList()
  518. total_block_idx = 0
  519. for stage_idx, blocks_per_stage in enumerate(config.blocks_per_stage):
  520. for block_idx in range(blocks_per_stage):
  521. block = Sam2MultiScaleBlock(
  522. config=config, stage_idx=stage_idx, block_idx=block_idx, total_block_idx=total_block_idx
  523. )
  524. self.blocks.append(block)
  525. total_block_idx += 1
  526. self.post_init()
  527. def get_input_embeddings(self):
  528. return self.patch_embed
  529. def _get_pos_embed(self, hw: tuple[int, int]) -> torch.Tensor:
  530. h, w = hw
  531. window_embed = self.pos_embed_window
  532. pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
  533. pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)])
  534. pos_embed = pos_embed.permute(0, 2, 3, 1)
  535. return pos_embed
  536. @merge_with_config_defaults
  537. @capture_outputs
  538. def forward(
  539. self,
  540. pixel_values: torch.FloatTensor | None = None,
  541. **kwargs: Unpack[TransformersKwargs],
  542. ) -> tuple | Sam2HieraDetModelOutput:
  543. if pixel_values is None:
  544. raise ValueError("You have to specify pixel_values")
  545. hidden_states = self.patch_embed(pixel_values)
  546. hidden_states = hidden_states + self._get_pos_embed(hidden_states.shape[1:3])
  547. intermediate_hidden_states = ()
  548. for i, block_module in enumerate(self.blocks):
  549. hidden_states = block_module(hidden_states, **kwargs)
  550. if i in self.stage_ends:
  551. intermediate_hidden_states = intermediate_hidden_states + (hidden_states,)
  552. return Sam2HieraDetModelOutput(
  553. last_hidden_state=hidden_states,
  554. intermediate_hidden_states=intermediate_hidden_states,
  555. )
  556. @auto_docstring(
  557. custom_intro="""
  558. The vision model from Sam without any head or projection on top.
  559. """
  560. )
  561. class Sam2VisionModel(Sam2PreTrainedModel):
  562. config_class = Sam2VisionConfig
  563. main_input_name = "pixel_values"
  564. _can_record_outputs = {
  565. "hidden_states": Sam2MultiScaleBlock,
  566. "attentions": Sam2MultiScaleAttention,
  567. }
  568. def __init__(self, config: Sam2VisionConfig):
  569. super().__init__(config)
  570. self.config = config
  571. self.backbone = AutoModel.from_config(config.backbone_config)
  572. self.neck = Sam2VisionNeck(config)
  573. self.num_feature_levels = config.num_feature_levels
  574. self.post_init()
  575. def get_input_embeddings(self):
  576. return self.backbone.get_input_embeddings()
  577. @can_return_tuple
  578. def forward(
  579. self,
  580. pixel_values: torch.FloatTensor | None = None,
  581. **kwargs: Unpack[TransformersKwargs],
  582. ) -> tuple | Sam2VisionEncoderOutput:
  583. if pixel_values is None:
  584. raise ValueError("You have to specify pixel_values")
  585. # Forward through backbone
  586. backbone_output = self.backbone(pixel_values, **kwargs)
  587. hidden_states = backbone_output.last_hidden_state
  588. intermediate_hidden_states = backbone_output.intermediate_hidden_states
  589. fpn_hidden_states, fpn_position_encoding = self.neck(intermediate_hidden_states)
  590. # Select last `num_feature_levels` feature levels from FPN and reverse order to get features from high to low resolution
  591. fpn_hidden_states = fpn_hidden_states[-self.num_feature_levels :][::-1]
  592. fpn_position_encoding = fpn_position_encoding[-self.num_feature_levels :][::-1]
  593. return Sam2VisionEncoderOutput(
  594. last_hidden_state=hidden_states,
  595. fpn_hidden_states=fpn_hidden_states,
  596. fpn_position_encoding=fpn_position_encoding,
  597. hidden_states=backbone_output.hidden_states,
  598. attentions=backbone_output.attentions,
  599. )
  600. class Sam2PositionalEmbedding(nn.Module):
  601. def __init__(self, config: Sam2PromptEncoderConfig):
  602. super().__init__()
  603. self.scale = config.scale
  604. positional_embedding = self.scale * torch.randn((2, config.hidden_size // 2))
  605. self.register_buffer("positional_embedding", positional_embedding)
  606. def forward(self, input_coords, input_shape=None):
  607. """Positionally encode points that are normalized to [0,1]."""
  608. coordinates = input_coords.clone()
  609. if input_shape is not None:
  610. coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
  611. coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
  612. coordinates.to(torch.float32)
  613. # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
  614. coordinates = 2 * coordinates - 1
  615. coordinates = coordinates.to(self.positional_embedding.dtype)
  616. coordinates = coordinates @ self.positional_embedding
  617. coordinates = 2 * np.pi * coordinates
  618. # outputs d_1 x ... x d_n x channel shape
  619. return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
  620. class Sam2MaskEmbedding(nn.Module):
  621. def __init__(self, config: Sam2PromptEncoderConfig):
  622. super().__init__()
  623. self.mask_input_channels = config.mask_input_channels // 4
  624. self.activation = ACT2FN[config.hidden_act]
  625. self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2)
  626. self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2)
  627. self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1)
  628. self.layer_norm1 = Sam2LayerNorm(
  629. self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first"
  630. )
  631. self.layer_norm2 = Sam2LayerNorm(
  632. self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first"
  633. )
  634. def forward(self, masks):
  635. hidden_states = self.conv1(masks)
  636. hidden_states = self.layer_norm1(hidden_states)
  637. hidden_states = self.activation(hidden_states)
  638. hidden_states = self.conv2(hidden_states)
  639. hidden_states = self.layer_norm2(hidden_states)
  640. hidden_states = self.activation(hidden_states)
  641. dense_embeddings = self.conv3(hidden_states)
  642. return dense_embeddings
  643. class Sam2PromptEncoder(nn.Module):
  644. def __init__(self, config: Sam2PromptEncoderConfig):
  645. super().__init__()
  646. self.shared_embedding = Sam2PositionalEmbedding(config)
  647. self.mask_embed = Sam2MaskEmbedding(config)
  648. self.no_mask_embed = nn.Embedding(1, config.hidden_size)
  649. self.image_embedding_size = (config.image_size // config.patch_size, config.image_size // config.patch_size)
  650. self.mask_input_size = (4 * config.image_size // config.patch_size, 4 * config.image_size // config.patch_size)
  651. self.input_image_size = config.image_size
  652. self.point_embed = nn.Embedding(config.num_point_embeddings, config.hidden_size)
  653. self.hidden_size = config.hidden_size
  654. self.not_a_point_embed = nn.Embedding(1, config.hidden_size)
  655. def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
  656. """Embeds point prompts."""
  657. points = points + 0.5 # Shift to center of pixel
  658. if pad:
  659. points = torch.nn.functional.pad(points, (0, 0, 0, 1), mode="constant", value=0)
  660. labels = torch.nn.functional.pad(labels, (0, 1), mode="constant", value=-1)
  661. input_shape = (self.input_image_size, self.input_image_size)
  662. point_embedding = self.shared_embedding(points, input_shape)
  663. # torch.where and expanding the labels tensor is required by the ONNX export
  664. point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding)
  665. # This is required for the ONNX export. The dtype, device need to be explicitly
  666. # specified as otherwise torch.onnx.export interprets as double
  667. point_embedding = torch.where(
  668. labels[..., None] != -10,
  669. point_embedding,
  670. torch.zeros_like(point_embedding),
  671. )
  672. # Add point embeddings for labels >= 0
  673. point_embedding = point_embedding + self.point_embed(labels.clamp(min=0)) * (labels >= 0).unsqueeze(-1)
  674. return point_embedding
  675. def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
  676. """Embeds box prompts."""
  677. boxes = boxes + 0.5 # Shift to center of pixel
  678. coords = boxes.view(*boxes.shape[:2], 2, 2)
  679. # add padding point for consistency with the original implementation
  680. coords = torch.nn.functional.pad(coords, (0, 0, 0, 1), mode="constant", value=0)
  681. corner_embedding = self.shared_embedding(coords, (self.input_image_size, self.input_image_size))
  682. corner_embedding[:, :, 0, :] += self.point_embed.weight[2]
  683. corner_embedding[:, :, 1, :] += self.point_embed.weight[3]
  684. corner_embedding[:, :, 2, :] = self.not_a_point_embed.weight.expand_as(corner_embedding[:, :, 2, :])
  685. return corner_embedding
  686. def forward(
  687. self,
  688. input_points: tuple[torch.Tensor, torch.Tensor] | None,
  689. input_labels: torch.Tensor | None,
  690. input_boxes: torch.Tensor | None,
  691. input_masks: torch.Tensor | None,
  692. ) -> tuple[torch.Tensor, torch.Tensor]:
  693. """
  694. Embeds different types of prompts, returning both sparse and dense embeddings.
  695. Args:
  696. points (`torch.Tensor`, *optional*):
  697. point coordinates and labels to embed.
  698. boxes (`torch.Tensor`, *optional*):
  699. boxes to embed
  700. masks (`torch.Tensor`, *optional*):
  701. masks to embed
  702. """
  703. sparse_embeddings = None
  704. batch_size = 1
  705. if input_points is not None:
  706. batch_size = input_points.shape[0]
  707. if input_labels is None:
  708. raise ValueError("If points are provided, labels must also be provided.")
  709. point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
  710. sparse_embeddings = point_embeddings
  711. if input_boxes is not None:
  712. batch_size = input_boxes.shape[0]
  713. box_embeddings = self._embed_boxes(input_boxes)
  714. if sparse_embeddings is None:
  715. sparse_embeddings = box_embeddings
  716. else:
  717. sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=2)
  718. if input_masks is not None:
  719. dense_embeddings = self.mask_embed(input_masks)
  720. else:
  721. dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
  722. batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]
  723. )
  724. return sparse_embeddings, dense_embeddings
  725. class Sam2Attention(nn.Module):
  726. """
  727. SAM2's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
  728. values.
  729. """
  730. def __init__(self, config, downsample_rate=None):
  731. super().__init__()
  732. downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
  733. self.config = config
  734. self.hidden_size = config.hidden_size
  735. self.internal_dim = config.hidden_size // downsample_rate
  736. self.num_attention_heads = config.num_attention_heads
  737. self.head_dim = self.internal_dim // config.num_attention_heads
  738. self.scaling = self.head_dim**-0.5
  739. self.is_causal = False
  740. self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
  741. self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
  742. self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
  743. self.o_proj = nn.Linear(self.internal_dim, self.hidden_size)
  744. def forward(
  745. self,
  746. query: torch.Tensor,
  747. key: torch.Tensor,
  748. value: torch.Tensor,
  749. attention_similarity: torch.Tensor | None = None,
  750. **kwargs: Unpack[TransformersKwargs],
  751. ) -> tuple[torch.Tensor, torch.Tensor]:
  752. # Input projections
  753. batch_size, point_batch_size = query.shape[:2]
  754. new_shape = (batch_size * point_batch_size, -1, self.num_attention_heads, self.head_dim)
  755. query = self.q_proj(query).view(*new_shape).transpose(1, 2)
  756. key = self.k_proj(key).view(*new_shape).transpose(1, 2)
  757. value = self.v_proj(value).view(*new_shape).transpose(1, 2)
  758. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  759. self.config._attn_implementation, eager_attention_forward
  760. )
  761. if is_flash_attention_requested(self.config) and attention_similarity is not None:
  762. # Target guided masks are represented as float masks and are incompatible with Flash Attention
  763. # Fallback to SDPA for this call only so the rest of the model can still benefit from FA
  764. attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"]
  765. logger.warning_once(
  766. "Falling back to SDPA for target-guided attention because "
  767. "Flash Attention does not support additive bias masks."
  768. )
  769. attn_output, attn_weights = attention_interface(
  770. self,
  771. query,
  772. key,
  773. value,
  774. attention_mask=attention_similarity,
  775. dropout=0.0,
  776. scaling=self.scaling,
  777. is_causal=self.is_causal,
  778. **kwargs,
  779. )
  780. attn_output = attn_output.reshape(
  781. batch_size, point_batch_size, -1, self.num_attention_heads * self.head_dim
  782. ).contiguous()
  783. attn_output = self.o_proj(attn_output)
  784. return attn_output, attn_weights
  785. class Sam2TwoWayAttentionBlock(GradientCheckpointingLayer):
  786. def __init__(self, config: Sam2MaskDecoderConfig, skip_first_layer_pe: bool = False):
  787. """
  788. A transformer block with four layers:
  789. (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
  790. sparse inputs (4) cross attention of dense inputs -> sparse inputs
  791. Arguments:
  792. config (`Sam2MaskDecoderConfig`):
  793. The configuration file used to instantiate the block
  794. attention_downsample_rate (*optionalk*, int, defaults to 2):
  795. The downsample ratio of the block used to reduce the inner dim of the attention.
  796. skip_first_layer_pe (*optional*, bool, defaults to `False`):
  797. Whether or not to skip the addition of the query_point_embedding on the first layer.
  798. """
  799. super().__init__()
  800. self.self_attn = Sam2Attention(config, downsample_rate=1)
  801. self.layer_norm1 = nn.LayerNorm(config.hidden_size)
  802. self.cross_attn_token_to_image = Sam2Attention(config)
  803. self.layer_norm2 = nn.LayerNorm(config.hidden_size)
  804. self.mlp = Sam2FeedForward(
  805. config.hidden_size, config.mlp_dim, config.hidden_size, num_layers=config.num_hidden_layers
  806. )
  807. self.layer_norm3 = nn.LayerNorm(config.hidden_size)
  808. self.layer_norm4 = nn.LayerNorm(config.hidden_size)
  809. self.cross_attn_image_to_token = Sam2Attention(config)
  810. self.skip_first_layer_pe = skip_first_layer_pe
  811. def forward(
  812. self,
  813. queries: Tensor,
  814. keys: Tensor,
  815. query_point_embedding: Tensor,
  816. key_point_embedding: Tensor,
  817. attention_similarity: Tensor,
  818. **kwargs: Unpack[TransformersKwargs],
  819. ):
  820. # Self attention block
  821. if self.skip_first_layer_pe:
  822. queries, _ = self.self_attn(query=queries, key=queries, value=queries)
  823. else:
  824. query = queries + query_point_embedding
  825. attn_out, _ = self.self_attn(query=query, key=query, value=queries)
  826. queries = queries + attn_out
  827. queries = self.layer_norm1(queries)
  828. # Cross attention block, tokens attending to image embedding
  829. query = queries + query_point_embedding
  830. key = keys + key_point_embedding
  831. attn_out, _ = self.cross_attn_token_to_image(
  832. query=query, key=key, value=keys, attention_similarity=attention_similarity
  833. )
  834. queries = queries + attn_out
  835. queries = self.layer_norm2(queries)
  836. # MLP block
  837. mlp_out = self.mlp(queries)
  838. queries = queries + mlp_out
  839. queries = self.layer_norm3(queries)
  840. # Cross attention block, image embedding attending to tokens
  841. query = queries + query_point_embedding
  842. key = keys + key_point_embedding
  843. attn_out, _ = self.cross_attn_image_to_token(query=key, key=query, value=queries)
  844. keys = keys + attn_out
  845. keys = self.layer_norm4(keys)
  846. return queries, keys, attn_out
class Sam2TwoWayTransformer(nn.Module):
    """
    A stack of `Sam2TwoWayAttentionBlock`s followed by a final token-to-image attention and layer norm.

    The prompt tokens (queries) and flattened image tokens (keys) are updated jointly. At every layer, the incoming
    `point_embeddings` serve as the queries' positional embedding and `image_positional_embeddings` as the keys'.
    """

    def __init__(self, config: Sam2MaskDecoderConfig):
        super().__init__()
        self.config = config

        self.num_hidden_layers = config.num_hidden_layers
        self.layers = nn.ModuleList()

        for i in range(self.num_hidden_layers):
            # Only the first block skips adding the query positional embedding in its self-attention.
            self.layers.append(Sam2TwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))

        self.final_attn_token_to_image = Sam2Attention(config)
        self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)

    def forward(
        self,
        point_embeddings: Tensor,
        image_embeddings: Tensor,
        image_positional_embeddings: Tensor,
        attention_similarity: Tensor | None,
        target_embedding: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[Tensor, Tensor]:
        """
        Run the two-way transformer.

        Args:
            point_embeddings (`Tensor`):
                Prompt tokens. Used both as the initial queries and as the queries' positional embedding.
            image_embeddings (`Tensor`):
                Image features. Flattened with `flatten(2).permute(0, 2, 1).unsqueeze(1)`, so they are presumably
                `(batch, channels, height, width)` — TODO confirm.
            image_positional_embeddings (`Tensor`):
                Positional embedding for `image_embeddings`, flattened the same way.
            attention_similarity (`Tensor`, *optional*):
                Forwarded to every block's token-to-image cross-attention. It is not used by the final attention.
            target_embedding (`Tensor`, *optional*):
                Added to the queries before every block.

        Returns:
            `(queries, keys)`: the updated prompt tokens and image tokens.

        Raises:
            ValueError: if `image_embeddings` is None.
        """
        if image_embeddings is None:
            raise ValueError("You have to specify an image_embedding")

        image_embeddings = image_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
        image_positional_embeddings = image_positional_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)

        # Prepare queries
        queries = point_embeddings
        keys = image_embeddings

        # Apply transformer blocks and final layernorm
        for layer in self.layers:
            if target_embedding is not None:
                # NOTE(review): on the first iteration `queries` is the same tensor object as `point_embeddings`, so
                # this in-place add also modifies `point_embeddings`, and therefore the `query_point_embedding` passed
                # below, the final `query` and the caller's tensor. Changing it to an out-of-place add would alter
                # numerical outputs, so confirm against the reference implementation before touching it.
                queries += target_embedding
            queries, keys, _ = layer(
                queries=queries,
                keys=keys,
                query_point_embedding=point_embeddings,
                key_point_embedding=image_positional_embeddings,
                attention_similarity=attention_similarity,
                **kwargs,
            )
        # Apply the final attention layer from the points to the image
        query = queries + point_embeddings
        key = keys + image_positional_embeddings

        attn_out, _ = self.final_attn_token_to_image(query=query, key=key, value=keys)

        queries = queries + attn_out
        queries = self.layer_norm_final_attn(queries)
        return queries, keys
  892. class Sam2LayerNorm(nn.LayerNorm):
  893. r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
  894. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
  895. width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
  896. """
  897. def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs):
  898. super().__init__(normalized_shape, eps=eps, **kwargs)
  899. if data_format not in ["channels_last", "channels_first"]:
  900. raise NotImplementedError(f"Unsupported data format: {data_format}")
  901. self.data_format = data_format
  902. def forward(self, features: torch.Tensor) -> torch.Tensor:
  903. """
  904. Args:
  905. features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels)
  906. """
  907. if self.data_format == "channels_first":
  908. features = features.permute(0, 2, 3, 1)
  909. features = super().forward(features)
  910. features = features.permute(0, 3, 1, 2)
  911. else:
  912. features = super().forward(features)
  913. return features
class Sam2MaskDecoder(nn.Module):
    """
    Predict masks, IoU (mask-quality) scores and object-score logits from image and prompt embeddings.

    Learned output tokens are prepended to the sparse prompt embeddings: the object-score token (index 0), the IoU
    token (index 1) and `num_multimask_outputs + 1` mask tokens. These tokens and the image embedding go through a
    two-way transformer. Each mask token is then mapped by its own hypernetwork MLP to weights that are multiplied
    with an upscaled image embedding to produce mask logits.
    """

    def __init__(self, config: Sam2MaskDecoderConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size

        self.num_multimask_outputs = config.num_multimask_outputs
        # One single-mask token (index 0 among the mask tokens) plus `num_multimask_outputs` multimask tokens.
        self.num_mask_tokens = config.num_multimask_outputs + 1

        self.iou_token = nn.Embedding(1, self.hidden_size)
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size)

        self.transformer = Sam2TwoWayTransformer(config)

        # Two kernel-2 / stride-2 transposed convolutions upsample the image embedding 4x spatially. They reduce
        # channels to hidden_size // 4 and then hidden_size // 8.
        self.upscale_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2)
        self.upscale_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2)
        self.upscale_layer_norm = Sam2LayerNorm(self.hidden_size // 4, data_format="channels_first")
        self.activation = nn.GELU()

        # One hypernetwork MLP per mask token. Each maps the token to a hidden_size // 8 vector that is matmul'ed with
        # the hidden_size // 8 channels of the upscaled embedding.
        mlps_list = []
        for _ in range(self.num_mask_tokens):
            mlps_list += [Sam2FeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)]
        self.output_hypernetworks_mlps = nn.ModuleList(mlps_list)
        self.iou_prediction_head = Sam2FeedForward(
            self.hidden_size,
            config.iou_head_hidden_dim,
            self.num_mask_tokens,
            config.iou_head_depth,
            sigmoid_output=True,
        )

        # NOTE(review): conv_s0 / conv_s1 are not applied in `forward`. Their output channels (hidden_size // 8 and
        # hidden_size // 4) match what `forward` adds `high_resolution_features` to, so they are presumably applied
        # by the caller beforehand — confirm.
        self.conv_s0 = nn.Conv2d(config.hidden_size, config.hidden_size // 8, kernel_size=1, stride=1)
        self.conv_s1 = nn.Conv2d(config.hidden_size, config.hidden_size // 4, kernel_size=1, stride=1)

        self.obj_score_token = nn.Embedding(1, self.hidden_size)
        self.pred_obj_score_head = Sam2FeedForward(self.hidden_size, self.hidden_size, 1, 3)

        # Settings for `_dynamic_multimask_via_stability` (single-mask fallback at inference time).
        self.dynamic_multimask_via_stability = config.dynamic_multimask_via_stability
        self.dynamic_multimask_stability_delta = config.dynamic_multimask_stability_delta
        self.dynamic_multimask_stability_thresh = config.dynamic_multimask_stability_thresh

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_positional_embeddings: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
        high_resolution_features: list[torch.Tensor],
        attention_similarity: torch.Tensor | None = None,
        target_embedding: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.

        Args:
            image_embeddings (`torch.Tensor`):
                The embeddings from the image encoder, shape `(batch_size, num_channels, height, width)`.
            image_positional_embeddings (`torch.Tensor`):
                Positional encoding with the shape of image_embeddings.
            sparse_prompt_embeddings (`torch.Tensor`):
                The embeddings of the points and boxes. Dim 1 is the point batch.
            dense_prompt_embeddings (`torch.Tensor`):
                The embeddings of the mask inputs. Added to `image_embeddings`.
            multimask_output (`bool`):
                If True, return the multimask outputs (mask tokens 1..). Otherwise return the single-mask output
                (token 0), possibly replaced via `_dynamic_multimask_via_stability` in eval mode.
            high_resolution_features (`list[torch.Tensor]`):
                Exactly two high-resolution feature maps `(feat_s0, feat_s1)` from the vision encoder. `feat_s1` is
                added after the first upscaling and `feat_s0` after the second. Required.
            attention_similarity (`torch.Tensor`, *optional*):
                The attention similarity tensor, forwarded to the transformer.
            target_embedding (`torch.Tensor`, *optional*):
                The target embedding, forwarded to the transformer.

        Returns:
            `(masks, iou_pred, sam_tokens_out, object_score_logits)`:
                - `masks` has shape `(batch_size, point_batch_size, num_selected_masks, H, W)`.
                - `iou_pred` holds the matching IoU predictions.
                - `sam_tokens_out` holds the selected mask tokens.
                - `object_score_logits` comes from the object-score token.
        """
        batch_size, num_channels, height, width = image_embeddings.shape
        # NOTE(review): shape[1] is read here, so `sparse_prompt_embeddings` cannot be None.
        point_batch_size = sparse_prompt_embeddings.shape[1]
        # Concatenate output tokens
        output_tokens = torch.cat(
            [
                self.obj_score_token.weight,
                self.iou_token.weight,
                self.mask_tokens.weight,
            ],
            dim=0,
        )
        output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)

        if sparse_prompt_embeddings.shape[0] != 0:
            tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
        else:
            tokens = output_tokens
        point_embeddings = tokens.to(self.iou_token.weight.dtype)

        # Expand per-image data in batch direction to be per-mask
        image_embeddings = image_embeddings + dense_prompt_embeddings
        image_embeddings = image_embeddings.repeat_interleave(point_batch_size, dim=0)
        image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
        # Run the transformer
        point_embeddings, image_embeddings = self.transformer(
            point_embeddings=point_embeddings,
            image_embeddings=image_embeddings,
            image_positional_embeddings=image_positional_embeddings,
            attention_similarity=attention_similarity,
            target_embedding=target_embedding,
            **kwargs,
        )
        # Token layout after the transformer: 0 = object score, 1 = IoU, 2.. = mask tokens.
        iou_token_out = point_embeddings[:, :, 1, :]
        mask_tokens_out = point_embeddings[:, :, 2 : (2 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        # Undo the transformer's flattening: tokens back to a (batch * point_batch, C, H, W) feature map.
        image_embeddings = image_embeddings.transpose(2, 3).view(
            batch_size * point_batch_size, num_channels, height, width
        )

        feat_s0, feat_s1 = high_resolution_features
        feat_s0 = feat_s0.repeat_interleave(point_batch_size, dim=0)
        feat_s1 = feat_s1.repeat_interleave(point_batch_size, dim=0)
        upscaled_embedding = self.upscale_conv1(image_embeddings) + feat_s1
        upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
        upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding) + feat_s0)

        hyper_in_list: list[torch.Tensor] = []
        for i in range(self.num_mask_tokens):
            current_mlp = self.output_hypernetworks_mlps[i]
            hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
        hyper_in = torch.stack(hyper_in_list, dim=2)

        # Each mask is the per-token weight vector dotted with the upscaled embedding's channels at every pixel.
        _, num_channels, height, width = upscaled_embedding.shape
        upscaled_embedding = upscaled_embedding.view(batch_size, point_batch_size, num_channels, height * width)
        masks = (hyper_in @ upscaled_embedding).view(batch_size, point_batch_size, -1, height, width)

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)
        object_score_logits = self.pred_obj_score_head(point_embeddings[:, :, 0, :])

        # Select the correct mask or masks for output
        if multimask_output:
            mask_slice = slice(1, None)
            masks = masks[:, :, mask_slice, :, :]
            iou_pred = iou_pred[:, :, mask_slice]
        elif self.dynamic_multimask_via_stability and not self.training:
            mask_slice = slice(0, 1)
            masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
        else:
            mask_slice = slice(0, 1)
            masks = masks[:, :, mask_slice, :, :]
            iou_pred = iou_pred[:, :, mask_slice]

        sam_tokens_out = mask_tokens_out[:, :, mask_slice]  # (batch_size, point_batch_size, num_selected, hidden_size)

        return masks, iou_pred, sam_tokens_out, object_score_logits

    def _get_stability_scores(self, mask_logits):
        """
        Compute stability scores of the mask logits based on the IoU between upper and
        lower thresholds.

        The score is the ratio of the pixel count with logits > +delta to the count with logits > -delta, over the
        last two (spatial) dims. It is 1.0 when the lower-threshold area is empty.
        """
        mask_logits = mask_logits.flatten(-2)
        stability_delta = self.dynamic_multimask_stability_delta
        area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
        area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
        stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
        return stability_scores

    def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
        """
        When outputting a single mask, if the stability score from the current single-mask
        output (based on output token 0) falls below a threshold, we instead select from
        multi-mask outputs (based on output token 1~3) the mask with the highest predicted
        IoU score. This is intended to ensure a valid mask for both clicking and tracking.

        Returns `(mask_logits_out, iou_scores_out)` with one mask per prompt, of shapes `[B, P, 1, H, W]` and
        `[B, P, 1]`.
        """
        # The best mask from multimask output tokens (1~3)
        multimask_logits = all_mask_logits[:, :, 1:, :, :]
        multimask_iou_scores = all_iou_scores[:, :, 1:]
        best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)  # [B, P]
        best_scores_inds_expanded = best_scores_inds.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        best_scores_inds_expanded = best_scores_inds_expanded.expand(
            -1, -1, 1, multimask_logits.size(-2), multimask_logits.size(-1)
        )
        best_multimask_logits = torch.gather(multimask_logits, 2, best_scores_inds_expanded)  # [B, P, 1, H, W]
        best_multimask_iou_scores = torch.gather(multimask_iou_scores, 2, best_scores_inds.unsqueeze(-1))  # [B, P, 1]

        # The mask from singlemask output token 0 and its stability score
        singlemask_logits = all_mask_logits[:, :, 0:1, :, :]
        singlemask_iou_scores = all_iou_scores[:, :, 0:1]
        stability_scores = self._get_stability_scores(singlemask_logits)
        is_stable = stability_scores >= self.dynamic_multimask_stability_thresh

        # Dynamically fall back to best multimask output upon low stability scores.
        mask_logits_out = torch.where(
            is_stable[..., None, None].expand_as(singlemask_logits),
            singlemask_logits,
            best_multimask_logits,
        )
        iou_scores_out = torch.where(
            is_stable.expand_as(singlemask_iou_scores),
            singlemask_iou_scores,
            best_multimask_iou_scores,
        )
        return mask_logits_out, iou_scores_out
  1091. @auto_docstring(
  1092. custom_intro="""
  1093. Segment Anything Model 2 (SAM 2) for generating segmentation masks, given an input image and
  1094. input points and labels, boxes, or masks.
  1095. """
  1096. )
  1097. class Sam2Model(Sam2PreTrainedModel):
  1098. input_modalities = ("image", "text")
  1099. _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(Sam2TwoWayAttentionBlock, index=2)}
  1100. _tied_weights_keys = {}
    def __init__(self, config: Sam2Config):
        """
        Build the SAM2 image model from its sub-configs.

        The model is made of:
            - a shared positional embedding;
            - the vision encoder (built via `AutoModel`);
            - the prompt encoder;
            - the mask decoder;
            - a learned "no memory" embedding.
        Submodule creation order is significant for parameter registration, so it should not be rearranged.
        """
        super().__init__(config)
        # Positional encoding used by `get_image_wide_positional_embeddings`.
        self.shared_image_embedding = Sam2PositionalEmbedding(config.prompt_encoder_config)
        self.vision_encoder = AutoModel.from_config(config.vision_config)
        self.prompt_encoder = Sam2PromptEncoder(config.prompt_encoder_config)
        # The module using it is not a PreTrainedModel subclass so we need this
        # (copies the chosen attention implementation onto the mask-decoder sub-config before building it).
        config.mask_decoder_config._attn_implementation = config._attn_implementation
        self.mask_decoder = Sam2MaskDecoder(config.mask_decoder_config)

        self.num_feature_levels = config.vision_config.num_feature_levels
        # Spatial sizes used by `get_image_embeddings` to reshape each flattened feature map.
        self.backbone_feature_sizes = config.vision_config.backbone_feature_sizes
        # a single token to indicate no memory embedding from previous frames
        self.hidden_dim = config.vision_config.fpn_hidden_size
        self.no_memory_embedding = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))

        self.post_init()
  1115. def get_input_embeddings(self):
  1116. return self.vision_encoder.get_input_embeddings()
  1117. def get_image_wide_positional_embeddings(self) -> torch.Tensor:
  1118. size = self.prompt_encoder.image_embedding_size
  1119. target_device = self.shared_image_embedding.positional_embedding.device
  1120. target_dtype = self.shared_image_embedding.positional_embedding.dtype
  1121. grid = torch.ones(size, device=target_device, dtype=target_dtype)
  1122. y_embed = grid.cumsum(dim=0) - 0.5
  1123. x_embed = grid.cumsum(dim=1) - 0.5
  1124. y_embed = y_embed / size[0]
  1125. x_embed = x_embed / size[1]
  1126. positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1))
  1127. return positional_embedding.permute(2, 0, 1).unsqueeze(0) # channel x height x width
  1128. @torch.no_grad()
  1129. def get_image_embeddings(
  1130. self,
  1131. pixel_values: torch.FloatTensor,
  1132. **kwargs: Unpack[TransformersKwargs],
  1133. ) -> list[torch.Tensor]:
  1134. r"""
  1135. Returns the image embeddings by passing the pixel values through the vision encoder.
  1136. Args:
  1137. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  1138. Input pixel values
  1139. """
  1140. batch_size = pixel_values.shape[0]
  1141. image_outputs = self.get_image_features(pixel_values, return_dict=True, **kwargs)
  1142. feature_maps = image_outputs.fpn_hidden_states
  1143. # add no memory embedding to the last feature map
  1144. feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding
  1145. # reshape feature maps to the same shape as the backbone feature sizes
  1146. image_embeddings = [
  1147. feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
  1148. for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes)
  1149. ]
  1150. return image_embeddings
  1151. @torch.no_grad()
  1152. def get_prompt_embeddings(
  1153. self,
  1154. input_points: torch.FloatTensor | None = None,
  1155. input_labels: torch.LongTensor | None = None,
  1156. input_boxes: torch.FloatTensor | None = None,
  1157. input_masks: torch.LongTensor | None = None,
  1158. ):
  1159. r"""
  1160. Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.
  1161. Args:
  1162. input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
  1163. Optional input points for the prompt encoder. The padding of the point is automatically done by the
  1164. processor. `point_batch_size` refers to the number of masks that we want the model to predict per
  1165. point. The model will output `point_batch_size` times 3 masks in total.
  1166. input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
  1167. Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
  1168. processor, or can be fed by the user.
  1169. input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`):
  1170. Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
  1171. processor. users can also pass manually the input boxes.
  1172. input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`):
  1173. Optional input masks for the prompt encoder.
  1174. """
  1175. prompt_output = self.prompt_encoder(
  1176. input_points=input_points,
  1177. input_labels=input_labels,
  1178. input_boxes=input_boxes,
  1179. input_masks=input_masks,
  1180. )
  1181. return prompt_output
  @merge_with_config_defaults
  @capture_outputs
  @auto_docstring
  def forward(
      self,
      pixel_values: torch.FloatTensor | None = None,
      input_points: torch.FloatTensor | None = None,
      input_labels: torch.LongTensor | None = None,
      input_boxes: torch.FloatTensor | None = None,
      input_masks: torch.LongTensor | None = None,
      image_embeddings: torch.FloatTensor | None = None,
      multimask_output: bool = True,
      attention_similarity: torch.FloatTensor | None = None,
      target_embedding: torch.FloatTensor | None = None,
      **kwargs: Unpack[TransformersKwargs],
  ) -> Sam2ImageSegmentationOutput:
      r"""
      input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points, 2)`):
          Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields much
          better results. The points can be obtained by passing a list of list of list to the processor that will
          create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the
          second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict
          per input point), the third dimension is the number of points per segmentation mask (it is possible to pass
          multiple points for a single mask), and the last dimension is the x (horizontal) and y (vertical)
          coordinates of the point. If a different number of points is passed either for each image, or for each
          mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
          computation of the embedding will be skipped for these points using the labels. If `input_labels` is not
          provided, every point is given the label `1`.
      input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`):
          Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the
          official implementation, there are 3 types of labels

          - `1`: the point is a point that contains the object of interest
          - `0`: the point is a point that does not contain the object of interest
          - `-1`: the point corresponds to the background

          We added the label:

          - `-10`: the point is a padding point, thus should be ignored by the prompt encoder

          The padding labels should be automatically done by the processor.
      input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
          Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields
          much better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
          that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch
          size, the number of boxes per image and the coordinates of the top left and bottom right point of the box.
          In the order (`x1`, `y1`, `x2`, `y2`):

          - `x1`: the x coordinate of the top left point of the input box
          - `y1`: the y coordinate of the top left point of the input box
          - `x2`: the x coordinate of the bottom right point of the input box
          - `y2`: the y coordinate of the bottom right point of the input box

          When both `input_points` and `input_boxes` are given, `num_boxes` must equal `point_batch_size`.
      input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`):
          SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
          generate a corresponding embedding, that will be fed later on to the mask decoder. These masks need to be
          manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`). If the
          last two dimensions differ from the prompt encoder's `mask_input_size`, the masks are first resized to it
          with antialiased bilinear interpolation.
      image_embeddings (`list[torch.FloatTensor]`):
          Multi-level image embeddings, one feature map per backbone level, in the format returned by the
          `get_image_embeddings` method. The last entry is used by the mask decoder as the main image embedding and
          the preceding entries as high-resolution features. For more memory efficient computation, users can first
          retrieve the image embeddings using the `get_image_embeddings` method, and then feed them to the `forward`
          method instead of feeding the `pixel_values`. Exactly one of `pixel_values` and `image_embeddings` must be
          provided.
      multimask_output (`bool`, *optional*):
          In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
          bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
          "best" mask, by specifying `multimask_output=False`.
      attention_similarity (`torch.FloatTensor`, *optional*):
          Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
          model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
      target_embedding (`torch.FloatTensor`, *optional*):
          Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
          the model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).

      Example:

      ```python
      >>> from PIL import Image
      >>> import httpx
      >>> from io import BytesIO
      >>> from transformers import AutoModel, AutoProcessor

      >>> model = AutoModel.from_pretrained("danelcsb/sam2.1_hiera_tiny")
      >>> processor = AutoProcessor.from_pretrained("danelcsb/sam2.1_hiera_tiny")

      >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
      >>> with httpx.stream("GET", url) as response:
      ...     raw_image = Image.open(BytesIO(response.read())).convert("RGB")
      >>> input_points = [[[400, 650]]]  # 2D location of a window on the car
      >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")

      >>> # Get segmentation mask
      >>> outputs = model(**inputs)

      >>> # Postprocess masks
      >>> masks = processor.post_process_masks(
      ...     outputs.pred_masks, inputs["original_sizes"]
      ... )
      ```
      """
      # The image can come either as raw pixels or as precomputed embeddings, never both / neither.
      if not ((pixel_values is None) ^ (image_embeddings is None)):
          raise ValueError("Exactly one of pixel_values or image_embeddings must be provided.")
      # Points and boxes are paired along dim 1 (point batch / box count), so those sizes must agree.
      if input_points is not None and input_boxes is not None and input_points.shape[1] != input_boxes.shape[1]:
          raise ValueError(
              f"You should provide as many bounding boxes as input points per box. Got {input_points.shape[1]} and {input_boxes.shape[1]}."
          )

      image_positional_embeddings = self.get_image_wide_positional_embeddings()
      # repeat with batch size; with precomputed embeddings the batch size is read from the last level
      batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings[-1].shape[0]
      image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)

      # Vision-encoder diagnostics are only available when the encoder actually runs.
      vision_attentions = None
      vision_hidden_states = None

      if pixel_values is not None:
          image_outputs: Sam2VisionEncoderOutput = self.get_image_features(pixel_values, return_dict=True, **kwargs)
          feature_maps = image_outputs.fpn_hidden_states
          vision_hidden_states = image_outputs.hidden_states
          vision_attentions = image_outputs.attentions

          # add no memory embedding to the last feature map
          feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding

          # get_image_features returns flattened (HW, N, C) maps; reshape each level back to
          # (N, C, *feat_size) using the configured backbone feature sizes
          image_embeddings = [
              feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
              for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes)
          ]

      # Points without labels are all treated as label 1.
      if input_points is not None and input_labels is None:
          input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)

      if input_points is None and input_boxes is None:
          # If no points or boxes are provided, pad with an empty point (with label -1)
          input_points = torch.zeros(
              batch_size, 1, 1, 2, dtype=image_embeddings[-1].dtype, device=image_embeddings[-1].device
          )
          input_labels = -torch.ones(batch_size, 1, 1, dtype=torch.int32, device=image_embeddings[-1].device)

      if input_masks is not None:
          # If mask_inputs is provided, downsize it into low-res mask input if needed
          # and feed it as a dense mask prompt into the SAM mask encoder
          # NOTE(review): bilinear F.interpolate requires a 4-D (N, C, H, W) input, so masks in the
          # documented 3-D layout would fail here unless they already match mask_input_size -- confirm
          # the expected input layout.
          if input_masks.shape[-2:] != self.prompt_encoder.mask_input_size:
              input_masks = F.interpolate(
                  input_masks.float(),
                  size=self.prompt_encoder.mask_input_size,
                  align_corners=False,
                  mode="bilinear",
                  antialias=True,  # use antialias for downsampling
              ).to(input_masks.dtype)

      sparse_embeddings, dense_embeddings = self.prompt_encoder(
          input_points=input_points,
          input_labels=input_labels,
          input_boxes=input_boxes,
          input_masks=input_masks,
      )
      # The last level is the main image embedding; the earlier levels are passed as high-resolution features.
      low_res_multimasks, iou_scores, _, object_score_logits = self.mask_decoder(
          image_embeddings=image_embeddings[-1],
          image_positional_embeddings=image_positional_embeddings,
          sparse_prompt_embeddings=sparse_embeddings,
          dense_prompt_embeddings=dense_embeddings,
          multimask_output=multimask_output,
          high_resolution_features=image_embeddings[:-1],
          attention_similarity=attention_similarity,
          target_embedding=target_embedding,
          **kwargs,
      )

      return Sam2ImageSegmentationOutput(
          iou_scores=iou_scores,
          pred_masks=low_res_multimasks,
          object_score_logits=object_score_logits,
          image_embeddings=image_embeddings,
          vision_hidden_states=vision_hidden_states,
          vision_attentions=vision_attentions,
      )
  @can_return_tuple
  @auto_docstring
  def get_image_features(
      self,
      pixel_values: torch.FloatTensor,
      **kwargs: Unpack[TransformersKwargs],
  ) -> tuple | Sam2VisionEncoderOutput:
      r"""
      pixel_values (`torch.FloatTensor`):
          Input pixel values of shape `(batch_size, num_channels, height, width)`.
      """
      encoder_out: Sam2VisionEncoderOutput = self.vision_encoder(pixel_values, return_dict=True, **kwargs)
      position_levels = encoder_out.fpn_position_encoding

      def _sequence_first(t):
          # flatten NxCxHxW to HWxNxC
          return t.flatten(2).permute(2, 0, 1)

      levels = list(encoder_out.fpn_hidden_states)
      # Apply the mask decoder's conv_s0 / conv_s1 to the first two levels once, here, so the
      # projection is not recomputed on every prompt (SAM click) against the same image.
      levels[0] = self.mask_decoder.conv_s0(levels[0])
      levels[1] = self.mask_decoder.conv_s1(levels[1])

      encoder_out.fpn_hidden_states = [_sequence_first(level) for level in levels]
      encoder_out.fpn_position_encoding = [_sequence_first(pos) for pos in position_levels]
      return encoder_out
  # Public classes exported by this module.
  __all__ = [
      "Sam2Model",
      "Sam2VisionModel",
      "Sam2PreTrainedModel",
      "Sam2HieraDetModel",
  ]