modular_sam2.py 62 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438
  1. # Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch SAM 2 model."""
  15. from collections.abc import Callable
  16. from dataclasses import dataclass
  17. from typing import Union
  18. import numpy as np
  19. import torch
  20. import torch.nn as nn
  21. import torch.nn.functional as F
  22. from ... import initialization as init
  23. from ...activations import ACT2FN
  24. from ...image_processing_backends import TorchvisionBackend
  25. from ...image_processing_utils import BatchFeature, get_size_dict
  26. from ...image_utils import (
  27. IMAGENET_DEFAULT_MEAN,
  28. IMAGENET_DEFAULT_STD,
  29. ChannelDimension,
  30. ImageInput,
  31. PILImageResampling,
  32. SizeDict,
  33. )
  34. from ...modeling_layers import GradientCheckpointingLayer
  35. from ...modeling_outputs import BaseModelOutputWithPooling
  36. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  37. from ...processing_utils import ImagesKwargs, Unpack
  38. from ...utils import ModelOutput, TensorType, auto_docstring, can_return_tuple, logging
  39. from ...utils.generic import (
  40. TransformersKwargs,
  41. is_flash_attention_requested,
  42. merge_with_config_defaults,
  43. )
  44. from ...utils.output_capturing import capture_outputs
  45. from ..auto import AutoModel
  46. from ..maskformer.modeling_maskformer import MaskFormerSinePositionEmbedding
  47. from ..sam.image_processing_sam import SamImageProcessor
  48. from ..sam.modeling_sam import (
  49. SamLayerNorm,
  50. SamMaskDecoder,
  51. SamMaskEmbedding,
  52. SamModel,
  53. SamPromptEncoder,
  54. SamTwoWayAttentionBlock,
  55. SamTwoWayTransformer,
  56. eager_attention_forward,
  57. )
  58. from ..vitdet.modeling_vitdet import window_partition, window_unpartition
  59. from .configuration_sam2 import (
  60. Sam2Config,
  61. Sam2HieraDetConfig,
  62. Sam2MaskDecoderConfig,
  63. Sam2PromptEncoderConfig,
  64. Sam2VisionConfig,
  65. )
# Module-level logger, named after this module via `__name__`.
logger = logging.get_logger(__name__)
class Sam2ImageProcessorKwargs(ImagesKwargs, total=False):
    r"""
    mask_size (`dict[str, int]`, *optional*):
        The size `{"height": int, "width": int}` to resize the segmentation maps to.
    """

    # Extra key declared on top of the inherited `ImagesKwargs` entries;
    # the class is declared with `total=False`.
    mask_size: dict[str, int]
  73. @auto_docstring
  74. class Sam2ImageProcessor(SamImageProcessor):
  75. resample = PILImageResampling.BILINEAR
  76. image_mean = IMAGENET_DEFAULT_MEAN
  77. image_std = IMAGENET_DEFAULT_STD
  78. size = {"height": 1024, "width": 1024}
  79. mask_size = {"height": 256, "width": 256}
  80. do_resize = True
  81. do_rescale = True
  82. do_normalize = True
  83. do_convert_rgb = True
  84. valid_kwargs = Sam2ImageProcessorKwargs
  85. # disable SAM padding logic
  86. do_pad = None
  87. pad_size = None
  88. mask_pad_size = None
  89. def __init__(self, **kwargs: Unpack[Sam2ImageProcessorKwargs]):
  90. TorchvisionBackend.__init__(self, **kwargs)
  91. def _preprocess(
  92. self,
  93. images: list["torch.Tensor"],
  94. return_tensors: str | TensorType | None,
  95. **kwargs,
  96. ) -> "torch.Tensor":
  97. return TorchvisionBackend._preprocess(self, images, return_tensors=return_tensors, **kwargs).pixel_values
  98. @auto_docstring
  99. def preprocess(
  100. self,
  101. images: ImageInput,
  102. segmentation_maps: ImageInput | None = None,
  103. **kwargs: Unpack[Sam2ImageProcessorKwargs],
  104. ) -> BatchFeature:
  105. r"""
  106. segmentation_maps (`ImageInput`, *optional*):
  107. The segmentation maps to preprocess.
  108. """
  109. return super().preprocess(images, segmentation_maps, **kwargs)
  110. def _preprocess_image_like_inputs(
  111. self,
  112. images: ImageInput,
  113. segmentation_maps: ImageInput | None,
  114. do_convert_rgb: bool,
  115. input_data_format: ChannelDimension,
  116. device: Union[str, "torch.device"] | None = None,
  117. **kwargs: Unpack[Sam2ImageProcessorKwargs],
  118. ) -> BatchFeature:
  119. """
  120. Preprocess image-like inputs.
  121. """
  122. images = self._prepare_image_like_inputs(
  123. images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
  124. )
  125. original_sizes = [image.shape[-2:] for image in images]
  126. images_kwargs = kwargs.copy()
  127. pixel_values = self._preprocess(images, **images_kwargs)
  128. data = {
  129. "pixel_values": pixel_values,
  130. "original_sizes": original_sizes,
  131. }
  132. if segmentation_maps is not None:
  133. processed_segmentation_maps = self._prepare_image_like_inputs(
  134. images=segmentation_maps,
  135. expected_ndims=2,
  136. do_convert_rgb=False,
  137. input_data_format=ChannelDimension.FIRST,
  138. )
  139. segmentation_maps_kwargs = kwargs.copy()
  140. segmentation_maps_kwargs.update(
  141. {
  142. "do_normalize": False,
  143. "do_rescale": False,
  144. "resample": PILImageResampling.NEAREST,
  145. "size": segmentation_maps_kwargs.pop("mask_size"),
  146. }
  147. )
  148. processed_segmentation_maps = self._preprocess(
  149. images=processed_segmentation_maps, **segmentation_maps_kwargs
  150. )
  151. data["labels"] = processed_segmentation_maps.squeeze(1).to(torch.int64)
  152. return BatchFeature(data=data, tensor_type=kwargs["return_tensors"])
  153. def _standardize_kwargs(
  154. self,
  155. mask_size: SizeDict | None = None,
  156. **kwargs,
  157. ) -> dict:
  158. """
  159. Update kwargs that need further processing before being validated.
  160. """
  161. if mask_size is not None and not isinstance(mask_size, SizeDict):
  162. mask_size = SizeDict(**get_size_dict(mask_size, param_name="mask_size"))
  163. kwargs["mask_size"] = mask_size
  164. return TorchvisionBackend._standardize_kwargs(self, **kwargs)
  165. def _apply_non_overlapping_constraints(self, pred_masks: torch.Tensor) -> torch.Tensor:
  166. """
  167. Apply non-overlapping constraints to the object scores in pred_masks. Here we
  168. keep only the highest scoring object at each spatial location in pred_masks.
  169. """
  170. batch_size = pred_masks.size(0)
  171. if batch_size == 1:
  172. return pred_masks
  173. device = pred_masks.device
  174. # "max_obj_inds": object index of the object with the highest score at each location
  175. max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
  176. # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks`
  177. batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]
  178. keep = max_obj_inds == batch_obj_inds
  179. # suppress overlapping regions' scores below -10.0 so that the foreground regions
  180. # don't overlap (here sigmoid(-10.0)=4.5398e-05)
  181. pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
  182. return pred_masks
  183. def post_process_masks(
  184. self,
  185. masks,
  186. original_sizes,
  187. mask_threshold=0.0,
  188. binarize=True,
  189. max_hole_area=0.0,
  190. max_sprinkle_area=0.0,
  191. apply_non_overlapping_constraints=False,
  192. **kwargs,
  193. ):
  194. """
  195. Remove padding and upscale masks to the original image size.
  196. Args:
  197. masks (`Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]]`):
  198. Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
  199. original_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
  200. The original sizes of each image before it was resized to the model's expected input shape, in (height,
  201. width) format.
  202. mask_threshold (`float`, *optional*, defaults to 0.0):
  203. Threshold for binarization and post-processing operations.
  204. binarize (`bool`, *optional*, defaults to `True`):
  205. Whether to binarize the masks.
  206. max_hole_area (`float`, *optional*, defaults to 0.0):
  207. The maximum area of a hole to fill.
  208. max_sprinkle_area (`float`, *optional*, defaults to 0.0):
  209. The maximum area of a sprinkle to fill.
  210. apply_non_overlapping_constraints (`bool`, *optional*, defaults to `False`):
  211. Whether to apply non-overlapping constraints to the masks.
  212. Returns:
  213. (`torch.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width)
  214. is given by original_size.
  215. """
  216. if isinstance(original_sizes, (torch.Tensor, np.ndarray)):
  217. original_sizes = original_sizes.tolist()
  218. # TODO: add connected components kernel for postprocessing
  219. output_masks = []
  220. for i, original_size in enumerate(original_sizes):
  221. if isinstance(masks[i], np.ndarray):
  222. masks[i] = torch.from_numpy(masks[i])
  223. elif not isinstance(masks[i], torch.Tensor):
  224. raise TypeError("Input masks should be a list of `torch.tensors` or a list of `np.ndarray`")
  225. interpolated_mask = F.interpolate(masks[i], original_size, mode="bilinear", align_corners=False)
  226. if apply_non_overlapping_constraints:
  227. interpolated_mask = self._apply_non_overlapping_constraints(interpolated_mask)
  228. if binarize:
  229. interpolated_mask = interpolated_mask > mask_threshold
  230. output_masks.append(interpolated_mask)
  231. return output_masks
  232. def _get_preprocess_shape(self):
  233. raise NotImplementedError("No _get_preprocess_shape for SAM 2.")
  234. def resize(self):
  235. raise NotImplementedError("No need to override resize for SAM 2.")
@dataclass
@auto_docstring(custom_intro="Base class for the vision encoder's outputs.")
class Sam2VisionEncoderOutput(BaseModelOutputWithPooling):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, height, width, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the model.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
        one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`. Hidden-states of the
        model at the output of each stage.
    attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
        the self-attention heads.
    fpn_hidden_states (`tuple(torch.FloatTensor)`):
        Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
        `(batch_size, hidden_size, height, width)`. Feature maps from the Feature Pyramid Network neck.
    fpn_position_encoding (`tuple(torch.FloatTensor)`):
        Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
        `(batch_size, hidden_size, height, width)`. Positional encodings corresponding to the `fpn_hidden_states`.
    """

    # NOTE(review): the docstring above describes both fields as tuples of tensors, while the
    # annotations below name a single tensor — confirm which one is intended.
    fpn_hidden_states: torch.FloatTensor | None = None
    fpn_position_encoding: torch.FloatTensor | None = None
@dataclass
@auto_docstring(custom_intro="Base class for the Sam2 model's output.")
class Sam2ImageSegmentationOutput(ModelOutput):
    r"""
    iou_scores (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks)`):
        The Intersection over Union (IoU) scores of the predicted masks.
    pred_masks (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks, height, width)`):
        The predicted low-resolution masks. This is an alias for `low_res_masks`. These masks need to be post-processed
        by the processor to be brought to the original image size.
    object_score_logits (`torch.FloatTensor` of shape `(batch_size, point_batch_size, 1)`):
        Logits for the object score, indicating if an object is present.
    image_embeddings (`tuple(torch.FloatTensor)`):
        The features from the FPN, which are used by the mask decoder. This is a tuple of `torch.FloatTensor` where each
        tensor has shape `(batch_size, channels, height, width)`.
    vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`.
        Hidden-states of the vision model at the output of each stage.
    vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
        Attentions weights of the vision model.
    mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
        Attentions weights of the mask decoder.
    """

    # Every field defaults to None.
    iou_scores: torch.FloatTensor | None = None
    pred_masks: torch.FloatTensor | None = None
    object_score_logits: torch.FloatTensor | None = None
    # Annotated as optional to match its None default and the sibling fields.
    image_embeddings: tuple[torch.FloatTensor, ...] | None = None
    vision_hidden_states: tuple[torch.FloatTensor, ...] | None = None
    vision_attentions: tuple[torch.FloatTensor, ...] | None = None
    mask_decoder_attentions: tuple[torch.FloatTensor, ...] | None = None
  290. class Sam2PatchEmbeddings(nn.Module):
  291. r"""
  292. Turns pixel values into patch embeddings for transformer consumption.
  293. Args:
  294. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  295. Pixel values. Pixel values can be obtained using
  296. [`AutoImageProcessor`]. See [`Sam2ImageProcessor.__call__`] for details.
  297. Returns:
  298. embeddings (`torch.FloatTensor`):
  299. Patch embeddings depend on image_size, patch_kernel_size, patch_stride and patch_padding
  300. """
  301. def __init__(self, config: Sam2HieraDetConfig):
  302. super().__init__()
  303. num_channels = config.num_channels
  304. hidden_size = config.hidden_size
  305. self.projection = nn.Conv2d(
  306. num_channels,
  307. hidden_size,
  308. kernel_size=config.patch_kernel_size,
  309. stride=config.patch_stride,
  310. padding=config.patch_padding,
  311. )
  312. def forward(self, pixel_values):
  313. _, num_channels, height, width = pixel_values.shape
  314. embeddings = self.projection(pixel_values.to(self.projection.weight.dtype)).permute(0, 2, 3, 1)
  315. return embeddings
# Empty subclass of `MaskFormerSinePositionEmbedding`: nothing is added or overridden here.
class Sam2SinePositionEmbedding(MaskFormerSinePositionEmbedding):
    pass
  318. class Sam2VisionNeck(nn.Module):
  319. def __init__(self, config: Sam2VisionConfig):
  320. super().__init__()
  321. self.config = config
  322. self.position_encoding = Sam2SinePositionEmbedding(num_pos_feats=config.fpn_hidden_size // 2, normalize=True)
  323. self.convs = nn.ModuleList()
  324. for in_channels in config.backbone_channel_list:
  325. self.convs.append(
  326. nn.Conv2d(
  327. in_channels=in_channels,
  328. out_channels=config.fpn_hidden_size,
  329. kernel_size=config.fpn_kernel_size,
  330. stride=config.fpn_stride,
  331. padding=config.fpn_padding,
  332. ),
  333. )
  334. self.fpn_top_down_levels = config.fpn_top_down_levels
  335. def forward(self, hidden_states: torch.Tensor) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]:
  336. fpn_hidden_states = ()
  337. fpn_position_encoding = ()
  338. # forward in top-down order (from low to high resolution)
  339. n = len(self.convs) - 1
  340. for i in range(n, -1, -1):
  341. lateral_features = hidden_states[i].permute(0, 3, 1, 2)
  342. lateral_features = self.convs[n - i](lateral_features.to(self.convs[i].weight.dtype))
  343. if i not in self.fpn_top_down_levels or i == n:
  344. prev_features = lateral_features
  345. else:
  346. top_down_features = F.interpolate(
  347. prev_features.to(dtype=torch.float32),
  348. scale_factor=2.0,
  349. mode="nearest",
  350. align_corners=None,
  351. antialias=False,
  352. ).to(lateral_features.dtype)
  353. prev_features = lateral_features + top_down_features
  354. prev_position_encoding = self.position_encoding(
  355. prev_features.shape, prev_features.device, prev_features.dtype
  356. ).to(prev_features.dtype)
  357. fpn_hidden_states += (prev_features,)
  358. fpn_position_encoding += (prev_position_encoding,)
  359. return fpn_hidden_states, fpn_position_encoding
  360. def do_pool(x: torch.Tensor, query_stride: int | None = None) -> torch.Tensor:
  361. if query_stride is None:
  362. return x
  363. # (B, H, W, C) -> (B, C, H, W)
  364. x = x.permute(0, 3, 1, 2)
  365. x = nn.functional.max_pool2d(x, kernel_size=query_stride, stride=query_stride, ceil_mode=False)
  366. # (B, C, H', W') -> (B, H', W', C)
  367. x = x.permute(0, 2, 3, 1)
  368. return x
  369. class Sam2MultiScaleAttention(nn.Module):
  370. def __init__(
  371. self,
  372. config: Sam2HieraDetConfig,
  373. dim: int,
  374. dim_out: int,
  375. num_attention_heads: int,
  376. query_stride: tuple[int, int] | None = None,
  377. ):
  378. super().__init__()
  379. self.config = config
  380. self.dim = dim
  381. self.dim_out = dim_out
  382. self.query_stride = query_stride
  383. self.num_attention_heads = num_attention_heads
  384. head_dim = dim_out // num_attention_heads
  385. self.scale = head_dim**-0.5
  386. self.qkv = nn.Linear(dim, dim_out * 3)
  387. self.proj = nn.Linear(dim_out, dim_out)
  388. self.is_causal = False
  389. def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
  390. batch_size, height, width, _ = hidden_states.shape
  391. # qkv with shape (B, H * W, 3, nHead, C)
  392. qkv = self.qkv(hidden_states).reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
  393. # q, k, v with shape (B, H * W, nheads, C)
  394. query, key, value = torch.unbind(qkv, 2)
  395. attn_weights = (query * self.scale) @ key.transpose(-2, -1)
  396. attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
  397. # Q pooling (for downsample at stage changes)
  398. if self.query_stride:
  399. query = do_pool(query.reshape(batch_size, height, width, -1), self.query_stride)
  400. height, width = query.shape[1:3] # downsampled shape
  401. query = query.reshape(batch_size, height * width, self.num_attention_heads, -1)
  402. # transpose query, key, value to (B, nHead, H * W, C)
  403. query = query.transpose(1, 2)
  404. key = key.transpose(1, 2)
  405. value = value.transpose(1, 2)
  406. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  407. self.config._attn_implementation, eager_attention_forward
  408. )
  409. attn_output, _ = attention_interface(
  410. self,
  411. query,
  412. key,
  413. value,
  414. attention_mask=None,
  415. is_causal=self.is_causal,
  416. scaling=self.scale,
  417. **kwargs,
  418. )
  419. attn_output = attn_output.reshape(batch_size, height, width, -1)
  420. attn_output = self.proj(attn_output)
  421. return attn_output
  422. class Sam2FeedForward(nn.Module):
  423. def __init__(
  424. self,
  425. input_dim: int,
  426. hidden_dim: int,
  427. output_dim: int,
  428. num_layers: int,
  429. activation: str = "relu",
  430. sigmoid_output: bool = False,
  431. ):
  432. super().__init__()
  433. self.num_layers = num_layers
  434. self.activation = ACT2FN[activation]
  435. self.proj_in = nn.Linear(input_dim, hidden_dim)
  436. self.proj_out = nn.Linear(hidden_dim, output_dim)
  437. self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)])
  438. self.sigmoid_output = sigmoid_output
  439. def forward(self, hidden_states):
  440. hidden_states = self.proj_in(hidden_states)
  441. hidden_states = self.activation(hidden_states)
  442. for layer in self.layers:
  443. hidden_states = self.activation(layer(hidden_states))
  444. hidden_states = self.proj_out(hidden_states)
  445. if self.sigmoid_output:
  446. hidden_states = F.sigmoid(hidden_states)
  447. return hidden_states
  448. class Sam2MultiScaleBlock(GradientCheckpointingLayer):
  449. def __init__(
  450. self,
  451. config: Sam2HieraDetConfig,
  452. stage_idx: int,
  453. block_idx: int,
  454. total_block_idx: int,
  455. ):
  456. super().__init__()
  457. # take embed dim from previous stage if first block of stage
  458. self.dim = (
  459. config.embed_dim_per_stage[stage_idx - 1]
  460. if stage_idx > 0 and block_idx == 0
  461. else config.embed_dim_per_stage[stage_idx]
  462. )
  463. self.dim_out = config.embed_dim_per_stage[stage_idx]
  464. self.layer_norm1 = nn.LayerNorm(self.dim, eps=config.layer_norm_eps)
  465. # take window size from previous stage if first block of stage
  466. self.window_size = (
  467. config.window_size_per_stage[stage_idx - 1]
  468. if stage_idx > 0 and block_idx == 0
  469. else config.window_size_per_stage[stage_idx]
  470. )
  471. self.window_size = 0 if total_block_idx in config.global_attention_blocks else self.window_size
  472. # use query stride for first block of stage if stage is a query pool stage
  473. self.query_stride = (
  474. config.query_stride if 0 < stage_idx <= config.num_query_pool_stages and block_idx == 0 else None
  475. )
  476. self.attn = Sam2MultiScaleAttention(
  477. config,
  478. self.dim,
  479. self.dim_out,
  480. num_attention_heads=config.num_attention_heads_per_stage[stage_idx],
  481. query_stride=self.query_stride,
  482. )
  483. self.layer_norm2 = nn.LayerNorm(self.dim_out, eps=config.layer_norm_eps)
  484. self.mlp = Sam2FeedForward(
  485. self.dim_out,
  486. int(self.dim_out * config.mlp_ratio),
  487. self.dim_out,
  488. num_layers=2,
  489. activation=config.hidden_act,
  490. )
  491. if self.dim != self.dim_out:
  492. self.proj = nn.Linear(self.dim, self.dim_out)
  493. def forward(
  494. self,
  495. hidden_states: torch.Tensor,
  496. **kwargs: Unpack[TransformersKwargs],
  497. ) -> torch.FloatTensor:
  498. residual = hidden_states # batch_size, height, width, channel
  499. hidden_states = self.layer_norm1(hidden_states)
  500. # Skip connection
  501. if self.dim != self.dim_out:
  502. residual = do_pool(self.proj(hidden_states), self.query_stride)
  503. # Window partition
  504. window_size = self.window_size
  505. if self.window_size > 0:
  506. H, W = hidden_states.shape[1], hidden_states.shape[2]
  507. hidden_states, pad_hw = window_partition(hidden_states, window_size)
  508. # Window Attention + Q Pooling (if stage change)
  509. attn_output = self.attn(
  510. hidden_states=hidden_states,
  511. **kwargs,
  512. )
  513. hidden_states = attn_output
  514. if self.query_stride:
  515. # Shapes have changed due to Q pooling
  516. window_size = self.window_size // self.query_stride[0]
  517. H, W = residual.shape[1:3]
  518. pad_h = (window_size - H % window_size) % window_size
  519. pad_w = (window_size - W % window_size) % window_size
  520. pad_hw = (H + pad_h, W + pad_w)
  521. # Reverse window partition
  522. if self.window_size > 0:
  523. hidden_states = window_unpartition(hidden_states, window_size, pad_hw, (H, W))
  524. hidden_states = residual + hidden_states
  525. layernorm_output = self.layer_norm2(hidden_states)
  526. hidden_states = hidden_states + self.mlp(layernorm_output)
  527. return hidden_states
# NOTE(review): the intro string below mentions "a pooling of the last hidden states", but no pooled
# field is declared on this class — confirm whether the wording is stale.
@dataclass
@auto_docstring(
    custom_intro="""
    Hiera model's outputs that also contains a pooling of the last hidden states.
    """
)
class Sam2HieraDetModelOutput(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, height, width, hidden_size)`):
        hidden-states at the output of the last layer of the model.
    intermediate_hidden_states (`tuple[torch.FloatTensor]` of shape `(batch_size, height, width, hidden_size)`):
        Sequence of hidden-states at the output of the intermediate layers of the model.
    """

    # Every field defaults to None.
    last_hidden_state: torch.FloatTensor | None = None
    intermediate_hidden_states: tuple[torch.FloatTensor, ...] | None = None
    # `hidden_states` and `attentions` are not described in the docstring above.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
@auto_docstring
class Sam2PreTrainedModel(PreTrainedModel):
    # Shared base class for the SAM 2 models: declares the config class, the main input and the
    # supported attention backends, and adds SAM 2 specific weight initialization.
    config_class = Sam2Config
    base_model_prefix = "sam2"
    main_input_name = "pixel_values"
    input_modalities = ("image",)
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_attention_backend = True
    # Checkpoint keys matching these patterns are ignored when they are unexpected at load time.
    # NOTE(review): presumably weights of components this model does not instantiate — confirm.
    _keys_to_ignore_on_load_unexpected = [
        r"^memory_.*",
        r"^mask_downsample.*",
        r"^object_pointer_proj.*",
        r"^temporal_positional_encoding_projection_layer.*",
        "no_memory_positional_encoding",
        "no_object_pointer",
        "occlusion_spatial_embedding_parameter",
    ]

    @torch.no_grad()
    def _init_weights(self, module):
        # Run the parent initialization first, then handle the parameters specific to SAM 2 modules.
        super()._init_weights(module)
        if isinstance(module, Sam2HieraDetModel):
            # Both positional embedding parameters are zero-initialized.
            if module.pos_embed is not None:
                init.zeros_(module.pos_embed)
            if module.pos_embed_window is not None:
                init.zeros_(module.pos_embed_window)
        elif isinstance(module, Sam2PositionalEmbedding):
            # Normal initialization with the module's own `scale` as standard deviation.
            init.normal_(module.positional_embedding, std=module.scale)
        elif isinstance(module, Sam2Model):
            if module.no_memory_embedding is not None:
                init.zeros_(module.no_memory_embedding)
  576. class Sam2HieraDetModel(Sam2PreTrainedModel):
  577. config_class = Sam2HieraDetConfig
  578. main_input_name = "pixel_values"
  579. _can_record_outputs = {
  580. "hidden_states": Sam2MultiScaleBlock,
  581. "attentions": Sam2MultiScaleAttention,
  582. }
  583. def __init__(self, config: Sam2HieraDetConfig):
  584. super().__init__(config)
  585. self.patch_embed = Sam2PatchEmbeddings(config)
  586. # Windowed positional embedding (https://huggingface.co/papers/2311.05613)
  587. self.pos_embed = nn.Parameter(
  588. torch.zeros(1, config.hidden_size, *config.window_positional_embedding_background_size)
  589. )
  590. self.pos_embed_window = nn.Parameter(
  591. torch.zeros(1, config.hidden_size, config.window_size_per_stage[0], config.window_size_per_stage[0])
  592. )
  593. self.stage_ends = (np.cumsum(config.blocks_per_stage) - 1).tolist()
  594. self.blocks = nn.ModuleList()
  595. total_block_idx = 0
  596. for stage_idx, blocks_per_stage in enumerate(config.blocks_per_stage):
  597. for block_idx in range(blocks_per_stage):
  598. block = Sam2MultiScaleBlock(
  599. config=config, stage_idx=stage_idx, block_idx=block_idx, total_block_idx=total_block_idx
  600. )
  601. self.blocks.append(block)
  602. total_block_idx += 1
  603. self.post_init()
  604. def get_input_embeddings(self):
  605. return self.patch_embed
  606. def _get_pos_embed(self, hw: tuple[int, int]) -> torch.Tensor:
  607. h, w = hw
  608. window_embed = self.pos_embed_window
  609. pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
  610. pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)])
  611. pos_embed = pos_embed.permute(0, 2, 3, 1)
  612. return pos_embed
  613. @merge_with_config_defaults
  614. @capture_outputs
  615. def forward(
  616. self,
  617. pixel_values: torch.FloatTensor | None = None,
  618. **kwargs: Unpack[TransformersKwargs],
  619. ) -> tuple | Sam2HieraDetModelOutput:
  620. if pixel_values is None:
  621. raise ValueError("You have to specify pixel_values")
  622. hidden_states = self.patch_embed(pixel_values)
  623. hidden_states = hidden_states + self._get_pos_embed(hidden_states.shape[1:3])
  624. intermediate_hidden_states = ()
  625. for i, block_module in enumerate(self.blocks):
  626. hidden_states = block_module(hidden_states, **kwargs)
  627. if i in self.stage_ends:
  628. intermediate_hidden_states = intermediate_hidden_states + (hidden_states,)
  629. return Sam2HieraDetModelOutput(
  630. last_hidden_state=hidden_states,
  631. intermediate_hidden_states=intermediate_hidden_states,
  632. )
  633. @auto_docstring(
  634. custom_intro="""
  635. The vision model from Sam without any head or projection on top.
  636. """
  637. )
  638. class Sam2VisionModel(Sam2PreTrainedModel):
  639. config_class = Sam2VisionConfig
  640. main_input_name = "pixel_values"
  641. _can_record_outputs = {
  642. "hidden_states": Sam2MultiScaleBlock,
  643. "attentions": Sam2MultiScaleAttention,
  644. }
  645. def __init__(self, config: Sam2VisionConfig):
  646. super().__init__(config)
  647. self.config = config
  648. self.backbone = AutoModel.from_config(config.backbone_config)
  649. self.neck = Sam2VisionNeck(config)
  650. self.num_feature_levels = config.num_feature_levels
  651. self.post_init()
  652. def get_input_embeddings(self):
  653. return self.backbone.get_input_embeddings()
  654. @can_return_tuple
  655. def forward(
  656. self,
  657. pixel_values: torch.FloatTensor | None = None,
  658. **kwargs: Unpack[TransformersKwargs],
  659. ) -> tuple | Sam2VisionEncoderOutput:
  660. if pixel_values is None:
  661. raise ValueError("You have to specify pixel_values")
  662. # Forward through backbone
  663. backbone_output = self.backbone(pixel_values, **kwargs)
  664. hidden_states = backbone_output.last_hidden_state
  665. intermediate_hidden_states = backbone_output.intermediate_hidden_states
  666. fpn_hidden_states, fpn_position_encoding = self.neck(intermediate_hidden_states)
  667. # Select last `num_feature_levels` feature levels from FPN and reverse order to get features from high to low resolution
  668. fpn_hidden_states = fpn_hidden_states[-self.num_feature_levels :][::-1]
  669. fpn_position_encoding = fpn_position_encoding[-self.num_feature_levels :][::-1]
  670. return Sam2VisionEncoderOutput(
  671. last_hidden_state=hidden_states,
  672. fpn_hidden_states=fpn_hidden_states,
  673. fpn_position_encoding=fpn_position_encoding,
  674. hidden_states=backbone_output.hidden_states,
  675. attentions=backbone_output.attentions,
  676. )
  677. class Sam2PositionalEmbedding(nn.Module):
  678. def __init__(self, config: Sam2PromptEncoderConfig):
  679. super().__init__()
  680. self.scale = config.scale
  681. positional_embedding = self.scale * torch.randn((2, config.hidden_size // 2))
  682. self.register_buffer("positional_embedding", positional_embedding)
  683. def forward(self, input_coords, input_shape=None):
  684. """Positionally encode points that are normalized to [0,1]."""
  685. coordinates = input_coords.clone()
  686. if input_shape is not None:
  687. coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
  688. coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
  689. coordinates.to(torch.float32)
  690. # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
  691. coordinates = 2 * coordinates - 1
  692. coordinates = coordinates.to(self.positional_embedding.dtype)
  693. coordinates = coordinates @ self.positional_embedding
  694. coordinates = 2 * np.pi * coordinates
  695. # outputs d_1 x ... x d_n x channel shape
  696. return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
class Sam2MaskEmbedding(SamMaskEmbedding):
    # Empty subclass: everything is inherited from `SamMaskEmbedding`; only the class name differs.
    pass
  699. class Sam2PromptEncoder(SamPromptEncoder):
  700. def __init__(self, config: Sam2PromptEncoderConfig):
  701. nn.Module.__init__(self)
  702. self.shared_embedding = Sam2PositionalEmbedding(config)
  703. self.mask_embed = Sam2MaskEmbedding(config)
  704. self.no_mask_embed = nn.Embedding(1, config.hidden_size)
  705. self.image_embedding_size = (config.image_size // config.patch_size, config.image_size // config.patch_size)
  706. self.mask_input_size = (4 * config.image_size // config.patch_size, 4 * config.image_size // config.patch_size)
  707. self.input_image_size = config.image_size
  708. self.point_embed = nn.Embedding(config.num_point_embeddings, config.hidden_size)
  709. self.hidden_size = config.hidden_size
  710. self.not_a_point_embed = nn.Embedding(1, config.hidden_size)
  711. def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
  712. """Embeds point prompts."""
  713. points = points + 0.5 # Shift to center of pixel
  714. if pad:
  715. points = torch.nn.functional.pad(points, (0, 0, 0, 1), mode="constant", value=0)
  716. labels = torch.nn.functional.pad(labels, (0, 1), mode="constant", value=-1)
  717. input_shape = (self.input_image_size, self.input_image_size)
  718. point_embedding = self.shared_embedding(points, input_shape)
  719. # torch.where and expanding the labels tensor is required by the ONNX export
  720. point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding)
  721. # This is required for the ONNX export. The dtype, device need to be explicitly
  722. # specified as otherwise torch.onnx.export interprets as double
  723. point_embedding = torch.where(
  724. labels[..., None] != -10,
  725. point_embedding,
  726. torch.zeros_like(point_embedding),
  727. )
  728. # Add point embeddings for labels >= 0
  729. point_embedding = point_embedding + self.point_embed(labels.clamp(min=0)) * (labels >= 0).unsqueeze(-1)
  730. return point_embedding
  731. def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
  732. """Embeds box prompts."""
  733. boxes = boxes + 0.5 # Shift to center of pixel
  734. coords = boxes.view(*boxes.shape[:2], 2, 2)
  735. # add padding point for consistency with the original implementation
  736. coords = torch.nn.functional.pad(coords, (0, 0, 0, 1), mode="constant", value=0)
  737. corner_embedding = self.shared_embedding(coords, (self.input_image_size, self.input_image_size))
  738. corner_embedding[:, :, 0, :] += self.point_embed.weight[2]
  739. corner_embedding[:, :, 1, :] += self.point_embed.weight[3]
  740. corner_embedding[:, :, 2, :] = self.not_a_point_embed.weight.expand_as(corner_embedding[:, :, 2, :])
  741. return corner_embedding
  742. class Sam2Attention(nn.Module):
  743. """
  744. SAM2's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
  745. values.
  746. """
  747. def __init__(self, config, downsample_rate=None):
  748. super().__init__()
  749. downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
  750. self.config = config
  751. self.hidden_size = config.hidden_size
  752. self.internal_dim = config.hidden_size // downsample_rate
  753. self.num_attention_heads = config.num_attention_heads
  754. self.head_dim = self.internal_dim // config.num_attention_heads
  755. self.scaling = self.head_dim**-0.5
  756. self.is_causal = False
  757. self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
  758. self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
  759. self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
  760. self.o_proj = nn.Linear(self.internal_dim, self.hidden_size)
  761. def forward(
  762. self,
  763. query: torch.Tensor,
  764. key: torch.Tensor,
  765. value: torch.Tensor,
  766. attention_similarity: torch.Tensor | None = None,
  767. **kwargs: Unpack[TransformersKwargs],
  768. ) -> tuple[torch.Tensor, torch.Tensor]:
  769. # Input projections
  770. batch_size, point_batch_size = query.shape[:2]
  771. new_shape = (batch_size * point_batch_size, -1, self.num_attention_heads, self.head_dim)
  772. query = self.q_proj(query).view(*new_shape).transpose(1, 2)
  773. key = self.k_proj(key).view(*new_shape).transpose(1, 2)
  774. value = self.v_proj(value).view(*new_shape).transpose(1, 2)
  775. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  776. self.config._attn_implementation, eager_attention_forward
  777. )
  778. if is_flash_attention_requested(self.config) and attention_similarity is not None:
  779. # Target guided masks are represented as float masks and are incompatible with Flash Attention
  780. # Fallback to SDPA for this call only so the rest of the model can still benefit from FA
  781. attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"]
  782. logger.warning_once(
  783. "Falling back to SDPA for target-guided attention because "
  784. "Flash Attention does not support additive bias masks."
  785. )
  786. attn_output, attn_weights = attention_interface(
  787. self,
  788. query,
  789. key,
  790. value,
  791. attention_mask=attention_similarity,
  792. dropout=0.0,
  793. scaling=self.scaling,
  794. is_causal=self.is_causal,
  795. **kwargs,
  796. )
  797. attn_output = attn_output.reshape(
  798. batch_size, point_batch_size, -1, self.num_attention_heads * self.head_dim
  799. ).contiguous()
  800. attn_output = self.o_proj(attn_output)
  801. return attn_output, attn_weights
  802. class Sam2TwoWayAttentionBlock(SamTwoWayAttentionBlock, GradientCheckpointingLayer):
  803. def __init__(self, config: Sam2MaskDecoderConfig, skip_first_layer_pe: bool = False):
  804. nn.Module.__init__(self)
  805. self.self_attn = Sam2Attention(config, downsample_rate=1)
  806. self.layer_norm1 = nn.LayerNorm(config.hidden_size)
  807. self.cross_attn_token_to_image = Sam2Attention(config)
  808. self.layer_norm2 = nn.LayerNorm(config.hidden_size)
  809. self.mlp = Sam2FeedForward(
  810. config.hidden_size, config.mlp_dim, config.hidden_size, num_layers=config.num_hidden_layers
  811. )
  812. self.layer_norm3 = nn.LayerNorm(config.hidden_size)
  813. self.layer_norm4 = nn.LayerNorm(config.hidden_size)
  814. self.cross_attn_image_to_token = Sam2Attention(config)
  815. self.skip_first_layer_pe = skip_first_layer_pe
class Sam2TwoWayTransformer(SamTwoWayTransformer):
    # Empty subclass: everything is inherited from `SamTwoWayTransformer`; only the class name differs.
    pass
class Sam2LayerNorm(SamLayerNorm):
    # Empty subclass: everything is inherited from `SamLayerNorm`; only the class name differs.
    pass
  820. class Sam2MaskDecoder(SamMaskDecoder):
  821. def __init__(self, config: Sam2MaskDecoderConfig):
  822. super().__init__(config)
  823. del self.iou_prediction_head
  824. self.iou_prediction_head = Sam2FeedForward(
  825. self.hidden_size,
  826. config.iou_head_hidden_dim,
  827. self.num_mask_tokens,
  828. config.iou_head_depth,
  829. sigmoid_output=True,
  830. )
  831. self.conv_s0 = nn.Conv2d(config.hidden_size, config.hidden_size // 8, kernel_size=1, stride=1)
  832. self.conv_s1 = nn.Conv2d(config.hidden_size, config.hidden_size // 4, kernel_size=1, stride=1)
  833. self.obj_score_token = nn.Embedding(1, self.hidden_size)
  834. self.pred_obj_score_head = Sam2FeedForward(self.hidden_size, self.hidden_size, 1, 3)
  835. self.dynamic_multimask_via_stability = config.dynamic_multimask_via_stability
  836. self.dynamic_multimask_stability_delta = config.dynamic_multimask_stability_delta
  837. self.dynamic_multimask_stability_thresh = config.dynamic_multimask_stability_thresh
  838. def _get_stability_scores(self, mask_logits):
  839. """
  840. Compute stability scores of the mask logits based on the IoU between upper and
  841. lower thresholds.
  842. """
  843. mask_logits = mask_logits.flatten(-2)
  844. stability_delta = self.dynamic_multimask_stability_delta
  845. area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
  846. area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
  847. stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
  848. return stability_scores
  849. def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
  850. """
  851. When outputting a single mask, if the stability score from the current single-mask
  852. output (based on output token 0) falls below a threshold, we instead select from
  853. multi-mask outputs (based on output token 1~3) the mask with the highest predicted
  854. IoU score. This is intended to ensure a valid mask for both clicking and tracking.
  855. """
  856. # The best mask from multimask output tokens (1~3)
  857. multimask_logits = all_mask_logits[:, :, 1:, :, :]
  858. multimask_iou_scores = all_iou_scores[:, :, 1:]
  859. best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) # [B, P]
  860. best_scores_inds_expanded = best_scores_inds.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
  861. best_scores_inds_expanded = best_scores_inds_expanded.expand(
  862. -1, -1, 1, multimask_logits.size(-2), multimask_logits.size(-1)
  863. )
  864. best_multimask_logits = torch.gather(multimask_logits, 2, best_scores_inds_expanded) # [B, P, 1, H, W]
  865. best_multimask_iou_scores = torch.gather(multimask_iou_scores, 2, best_scores_inds.unsqueeze(-1)) # [B, P, 1]
  866. # The mask from singlemask output token 0 and its stability score
  867. singlemask_logits = all_mask_logits[:, :, 0:1, :, :]
  868. singlemask_iou_scores = all_iou_scores[:, :, 0:1]
  869. stability_scores = self._get_stability_scores(singlemask_logits)
  870. is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
  871. # Dynamically fall back to best multimask output upon low stability scores.
  872. mask_logits_out = torch.where(
  873. is_stable[..., None, None].expand_as(singlemask_logits),
  874. singlemask_logits,
  875. best_multimask_logits,
  876. )
  877. iou_scores_out = torch.where(
  878. is_stable.expand_as(singlemask_iou_scores),
  879. singlemask_iou_scores,
  880. best_multimask_iou_scores,
  881. )
  882. return mask_logits_out, iou_scores_out
  883. def forward(
  884. self,
  885. image_embeddings: torch.Tensor,
  886. image_positional_embeddings: torch.Tensor,
  887. sparse_prompt_embeddings: torch.Tensor,
  888. dense_prompt_embeddings: torch.Tensor,
  889. multimask_output: bool,
  890. high_resolution_features: list[torch.Tensor],
  891. attention_similarity: torch.Tensor | None = None,
  892. target_embedding: torch.Tensor | None = None,
  893. **kwargs: Unpack[TransformersKwargs],
  894. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
  895. """
  896. Predict masks given image and prompt embeddings.
  897. Args:
  898. image_embeddings (`torch.Tensor`):
  899. The embeddings from the image encoder.
  900. image_positional_embeddings (`torch.Tensor`):
  901. Positional encoding with the shape of image_embeddings.
  902. sparse_prompt_embeddings (`torch.Tensor`):
  903. The embeddings of the points and boxes.
  904. dense_prompt_embeddings (`torch.Tensor`):
  905. The embeddings of the mask inputs.
  906. multimask_output (`bool`):
  907. Whether to return multiple masks or a single mask.
  908. high_resolution_features (`list[torch.Tensor]`, *optional*):
  909. The high-resolution features from the vision encoder.
  910. attention_similarity (`torch.Tensor`, *optional*):
  911. The attention similarity tensor.
  912. target_embedding (`torch.Tensor`, *optional*):
  913. The target embedding.
  914. """
  915. batch_size, num_channels, height, width = image_embeddings.shape
  916. point_batch_size = sparse_prompt_embeddings.shape[1]
  917. # Concatenate output tokens
  918. output_tokens = torch.cat(
  919. [
  920. self.obj_score_token.weight,
  921. self.iou_token.weight,
  922. self.mask_tokens.weight,
  923. ],
  924. dim=0,
  925. )
  926. output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
  927. if sparse_prompt_embeddings.shape[0] != 0:
  928. tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
  929. else:
  930. tokens = output_tokens
  931. point_embeddings = tokens.to(self.iou_token.weight.dtype)
  932. # Expand per-image data in batch direction to be per-mask
  933. image_embeddings = image_embeddings + dense_prompt_embeddings
  934. image_embeddings = image_embeddings.repeat_interleave(point_batch_size, dim=0)
  935. image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
  936. # Run the transformer
  937. point_embeddings, image_embeddings = self.transformer(
  938. point_embeddings=point_embeddings,
  939. image_embeddings=image_embeddings,
  940. image_positional_embeddings=image_positional_embeddings,
  941. attention_similarity=attention_similarity,
  942. target_embedding=target_embedding,
  943. **kwargs,
  944. )
  945. iou_token_out = point_embeddings[:, :, 1, :]
  946. mask_tokens_out = point_embeddings[:, :, 2 : (2 + self.num_mask_tokens), :]
  947. # Upscale mask embeddings and predict masks using the mask tokens
  948. image_embeddings = image_embeddings.transpose(2, 3).view(
  949. batch_size * point_batch_size, num_channels, height, width
  950. )
  951. feat_s0, feat_s1 = high_resolution_features
  952. feat_s0 = feat_s0.repeat_interleave(point_batch_size, dim=0)
  953. feat_s1 = feat_s1.repeat_interleave(point_batch_size, dim=0)
  954. upscaled_embedding = self.upscale_conv1(image_embeddings) + feat_s1
  955. upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
  956. upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding) + feat_s0)
  957. hyper_in_list: list[torch.Tensor] = []
  958. for i in range(self.num_mask_tokens):
  959. current_mlp = self.output_hypernetworks_mlps[i]
  960. hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
  961. hyper_in = torch.stack(hyper_in_list, dim=2)
  962. _, num_channels, height, width = upscaled_embedding.shape
  963. upscaled_embedding = upscaled_embedding.view(batch_size, point_batch_size, num_channels, height * width)
  964. masks = (hyper_in @ upscaled_embedding).view(batch_size, point_batch_size, -1, height, width)
  965. # Generate mask quality predictions
  966. iou_pred = self.iou_prediction_head(iou_token_out)
  967. object_score_logits = self.pred_obj_score_head(point_embeddings[:, :, 0, :])
  968. # Select the correct mask or masks for output
  969. if multimask_output:
  970. mask_slice = slice(1, None)
  971. masks = masks[:, :, mask_slice, :, :]
  972. iou_pred = iou_pred[:, :, mask_slice]
  973. elif self.dynamic_multimask_via_stability and not self.training:
  974. mask_slice = slice(0, 1)
  975. masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
  976. else:
  977. mask_slice = slice(0, 1)
  978. masks = masks[:, :, mask_slice, :, :]
  979. iou_pred = iou_pred[:, :, mask_slice]
  980. sam_tokens_out = mask_tokens_out[:, :, mask_slice] # [b, 3, c] shape
  981. return masks, iou_pred, sam_tokens_out, object_score_logits
  982. @auto_docstring(
  983. custom_intro="""
  984. Segment Anything Model 2 (SAM 2) for generating segmentation masks, given an input image and
  985. input points and labels, boxes, or masks.
  986. """
  987. )
  988. class Sam2Model(SamModel):
  989. _tied_weights_keys = {}
  990. def __init__(self, config: Sam2Config):
  991. PreTrainedModel.__init__(self, config)
  992. self.shared_image_embedding = Sam2PositionalEmbedding(config.prompt_encoder_config)
  993. self.vision_encoder = AutoModel.from_config(config.vision_config)
  994. self.prompt_encoder = Sam2PromptEncoder(config.prompt_encoder_config)
  995. # The module using it is not a PreTrainedModel subclass so we need this
  996. config.mask_decoder_config._attn_implementation = config._attn_implementation
  997. self.mask_decoder = Sam2MaskDecoder(config.mask_decoder_config)
  998. self.num_feature_levels = config.vision_config.num_feature_levels
  999. self.backbone_feature_sizes = config.vision_config.backbone_feature_sizes
  1000. # a single token to indicate no memory embedding from previous frames
  1001. self.hidden_dim = config.vision_config.fpn_hidden_size
  1002. self.no_memory_embedding = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
  1003. self.post_init()
  1004. def get_image_wide_positional_embeddings(self) -> torch.Tensor:
  1005. size = self.prompt_encoder.image_embedding_size
  1006. target_device = self.shared_image_embedding.positional_embedding.device
  1007. target_dtype = self.shared_image_embedding.positional_embedding.dtype
  1008. grid = torch.ones(size, device=target_device, dtype=target_dtype)
  1009. y_embed = grid.cumsum(dim=0) - 0.5
  1010. x_embed = grid.cumsum(dim=1) - 0.5
  1011. y_embed = y_embed / size[0]
  1012. x_embed = x_embed / size[1]
  1013. positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1))
  1014. return positional_embedding.permute(2, 0, 1).unsqueeze(0) # channel x height x width
  1015. @torch.no_grad()
  1016. def get_image_embeddings(
  1017. self,
  1018. pixel_values: torch.FloatTensor,
  1019. **kwargs: Unpack[TransformersKwargs],
  1020. ) -> list[torch.Tensor]:
  1021. r"""
  1022. Returns the image embeddings by passing the pixel values through the vision encoder.
  1023. Args:
  1024. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  1025. Input pixel values
  1026. """
  1027. batch_size = pixel_values.shape[0]
  1028. image_outputs = self.get_image_features(pixel_values, return_dict=True, **kwargs)
  1029. feature_maps = image_outputs.fpn_hidden_states
  1030. # add no memory embedding to the last feature map
  1031. feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding
  1032. # reshape feature maps to the same shape as the backbone feature sizes
  1033. image_embeddings = [
  1034. feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
  1035. for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes)
  1036. ]
  1037. return image_embeddings
  1038. @can_return_tuple
  1039. @auto_docstring
  1040. def get_image_features(
  1041. self,
  1042. pixel_values: torch.FloatTensor,
  1043. **kwargs: Unpack[TransformersKwargs],
  1044. ) -> tuple | Sam2VisionEncoderOutput:
  1045. r"""
  1046. pixel_values (`torch.FloatTensor`):
  1047. Input pixel values of shape `(batch_size, num_channels, height, width)`.
  1048. """
  1049. vision_outputs: Sam2VisionEncoderOutput = self.vision_encoder(pixel_values, return_dict=True, **kwargs)
  1050. feature_maps = vision_outputs.fpn_hidden_states
  1051. feature_maps_position_embeddings = vision_outputs.fpn_position_encoding
  1052. # precompute projected level 0 and level 1 features in SAM decoder
  1053. # to avoid running it again on every SAM click
  1054. feature_maps = list(feature_maps)
  1055. feature_maps[0] = self.mask_decoder.conv_s0(feature_maps[0])
  1056. feature_maps[1] = self.mask_decoder.conv_s1(feature_maps[1])
  1057. # flatten NxCxHxW to HWxNxC
  1058. feature_maps = [feature_map.flatten(2).permute(2, 0, 1) for feature_map in feature_maps]
  1059. feature_maps_position_embeddings = [
  1060. feature_map_position_embedding.flatten(2).permute(2, 0, 1)
  1061. for feature_map_position_embedding in feature_maps_position_embeddings
  1062. ]
  1063. vision_outputs.fpn_hidden_states = feature_maps
  1064. vision_outputs.fpn_position_encoding = feature_maps_position_embeddings
  1065. return vision_outputs
  1066. @merge_with_config_defaults
  1067. @capture_outputs
  1068. @auto_docstring
  1069. def forward(
  1070. self,
  1071. pixel_values: torch.FloatTensor | None = None,
  1072. input_points: torch.FloatTensor | None = None,
  1073. input_labels: torch.LongTensor | None = None,
  1074. input_boxes: torch.FloatTensor | None = None,
  1075. input_masks: torch.LongTensor | None = None,
  1076. image_embeddings: torch.FloatTensor | None = None,
  1077. multimask_output: bool = True,
  1078. attention_similarity: torch.FloatTensor | None = None,
  1079. target_embedding: torch.FloatTensor | None = None,
  1080. **kwargs: Unpack[TransformersKwargs],
  1081. ) -> Sam2ImageSegmentationOutput:
  1082. r"""
  1083. input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`):
  1084. Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields to much
  1085. better results. The points can be obtained by passing a list of list of list to the processor that will
  1086. create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the
  1087. second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict
  1088. per input point), the third dimension is the number of points per segmentation mask (it is possible to pass
  1089. multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal)
  1090. coordinates of the point. If a different number of points is passed either for each image, or for each
  1091. mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
  1092. computation of the embedding will be skipped for these points using the labels.
  1093. input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`):
  1094. Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the
  1095. official implementation, there are 3 types of labels
  1096. - `1`: the point is a point that contains the object of interest
  1097. - `0`: the point is a point that does not contain the object of interest
  1098. - `-1`: the point corresponds to the background
  1099. We added the label:
  1100. - `-10`: the point is a padding point, thus should be ignored by the prompt encoder
  1101. The padding labels should be automatically done by the processor.
  1102. input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
  1103. Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields to
  1104. much better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
  1105. that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch
  1106. size, the number of boxes per image and the coordinates of the top left and bottom right point of the box.
  1107. In the order (`x1`, `y1`, `x2`, `y2`):
  1108. - `x1`: the x coordinate of the top left point of the input box
  1109. - `y1`: the y coordinate of the top left point of the input box
  1110. - `x2`: the x coordinate of the bottom right point of the input box
  1111. - `y2`: the y coordinate of the bottom right point of the input box
  1112. input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`):
  1113. SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
  1114. generate a corresponding embedding, that will be fed later on to the mask decoder. These masks needs to be
  1115. manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
  1116. image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`):
  1117. Image embeddings, this is used by the mask decoder to generate masks and iou scores. For more memory
  1118. efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
  1119. method, and then feed them to the `forward` method instead of feeding the `pixel_values`.
  1120. multimask_output (`bool`, *optional*):
  1121. In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
  1122. bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
  1123. "best" mask, by specifying `multimask_output=False`.
  1124. attention_similarity (`torch.FloatTensor`, *optional*):
  1125. Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
  1126. model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
  1127. target_embedding (`torch.FloatTensor`, *optional*):
  1128. Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
  1129. the model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
  1130. Example:
  1131. ```python
  1132. >>> from PIL import Image
  1133. >>> import httpx
  1134. >>> from io import BytesIO
  1135. >>> from transformers import AutoModel, AutoProcessor
  1136. >>> model = AutoModel.from_pretrained("danelcsb/sam2.1_hiera_tiny")
  1137. >>> processor = AutoProcessor.from_pretrained("danelcsb/sam2.1_hiera_tiny")
  1138. >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
  1139. >>> with httpx.stream("GET", url) as response:
  1140. ... raw_image = Image.open(BytesIO(response.read())).convert("RGB")
  1141. >>> input_points = [[[400, 650]]] # 2D location of a window on the car
  1142. >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")
  1143. >>> # Get segmentation mask
  1144. >>> outputs = model(**inputs)
  1145. >>> # Postprocess masks
  1146. >>> masks = processor.post_process_masks(
  1147. ... outputs.pred_masks, inputs["original_sizes"]
  1148. ... )
  1149. ```
  1150. """
  1151. if not ((pixel_values is None) ^ (image_embeddings is None)):
  1152. raise ValueError("Exactly one of pixel_values or image_embeddings must be provided.")
  1153. if input_points is not None and input_boxes is not None:
  1154. if input_points.shape[1] != input_boxes.shape[1]:
  1155. raise ValueError(
  1156. f"You should provide as many bounding boxes as input points per box. Got {input_points.shape[1]} and {input_boxes.shape[1]}."
  1157. )
  1158. image_positional_embeddings = self.get_image_wide_positional_embeddings()
  1159. # repeat with batch size
  1160. batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings[-1].shape[0]
  1161. image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
  1162. vision_attentions = None
  1163. vision_hidden_states = None
  1164. if pixel_values is not None:
  1165. image_outputs: Sam2VisionEncoderOutput = self.get_image_features(pixel_values, return_dict=True, **kwargs)
  1166. feature_maps = image_outputs.fpn_hidden_states
  1167. vision_hidden_states = image_outputs.hidden_states
  1168. vision_attentions = image_outputs.attentions
  1169. # add no memory embedding to the last feature map
  1170. feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding
  1171. # reshape feature maps to the same shape as the backbone feature sizes
  1172. image_embeddings = [
  1173. feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
  1174. for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes)
  1175. ]
  1176. if input_points is not None and input_labels is None:
  1177. input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)
  1178. if input_points is None and input_boxes is None:
  1179. # If neither points nor boxes are provided, pad with an empty point (with label -1)
  1180. input_points = torch.zeros(
  1181. batch_size, 1, 1, 2, dtype=image_embeddings[-1].dtype, device=image_embeddings[-1].device
  1182. )
  1183. input_labels = -torch.ones(batch_size, 1, 1, dtype=torch.int32, device=image_embeddings[-1].device)
  1184. if input_masks is not None:
  1185. # If input_masks is provided, downsize it into a low-res mask input if needed
  1186. # and feed it as a dense mask prompt into the SAM mask encoder
  1187. if input_masks.shape[-2:] != self.prompt_encoder.mask_input_size:
  1188. input_masks = F.interpolate(
  1189. input_masks.float(),
  1190. size=self.prompt_encoder.mask_input_size,
  1191. align_corners=False,
  1192. mode="bilinear",
  1193. antialias=True, # use antialias for downsampling
  1194. ).to(input_masks.dtype)
  1195. sparse_embeddings, dense_embeddings = self.prompt_encoder(
  1196. input_points=input_points,
  1197. input_labels=input_labels,
  1198. input_boxes=input_boxes,
  1199. input_masks=input_masks,
  1200. )
  1201. low_res_multimasks, iou_scores, _, object_score_logits = self.mask_decoder(
  1202. image_embeddings=image_embeddings[-1],
  1203. image_positional_embeddings=image_positional_embeddings,
  1204. sparse_prompt_embeddings=sparse_embeddings,
  1205. dense_prompt_embeddings=dense_embeddings,
  1206. multimask_output=multimask_output,
  1207. high_resolution_features=image_embeddings[:-1],
  1208. attention_similarity=attention_similarity,
  1209. target_embedding=target_embedding,
  1210. **kwargs,
  1211. )
  1212. return Sam2ImageSegmentationOutput(
  1213. iou_scores=iou_scores,
  1214. pred_masks=low_res_multimasks,
  1215. object_score_logits=object_score_logits,
  1216. image_embeddings=image_embeddings,
  1217. vision_hidden_states=vision_hidden_states,
  1218. vision_attentions=vision_attentions,
  1219. )
  1220. __all__ = [
  1221. "Sam2Model",
  1222. "Sam2VisionModel",
  1223. "Sam2PreTrainedModel",
  1224. "Sam2ImageProcessor",
  1225. "Sam2HieraDetModel",
  1226. ]