modeling_sam.py 60 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361
  1. # Copyright 2023 The Meta AI Authors and The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch SAM model."""
  15. import collections
  16. from collections.abc import Callable
  17. from dataclasses import dataclass
  18. import numpy as np
  19. import torch
  20. import torch.nn.functional as F
  21. from torch import Tensor, nn
  22. from ... import initialization as init
  23. from ...activations import ACT2FN
  24. from ...modeling_layers import GradientCheckpointingLayer
  25. from ...modeling_outputs import BaseModelOutput
  26. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  27. from ...processing_utils import Unpack
  28. from ...utils import ModelOutput, auto_docstring, logging
  29. from ...utils.generic import TransformersKwargs, merge_with_config_defaults
  30. from ...utils.output_capturing import OutputRecorder, capture_outputs
  31. from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
# Module-level logger, created through the library's `logging` helper and named after this module.
logger = logging.get_logger(__name__)
# Output container for the SAM vision encoder. The class docstring below is handed to
# `auto_docstring`, so its text is kept exactly as written.
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for sam vision model's outputs that also contains image embeddings obtained by applying the projection
    layer to the pooler_output.
    """
)
class SamVisionEncoderOutput(ModelOutput):
    r"""
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
        The image embeddings obtained by applying the projection layer to the pooler_output.
    """

    # Every field is optional and defaults to `None`.
    image_embeds: torch.FloatTensor | None = None
    last_hidden_state: torch.FloatTensor | None = None
    # Tuples of per-layer tensors (presumably one entry per encoder layer; confirm against the encoder).
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
# Output container for the full segmentation model. The class docstring below is handed to
# `auto_docstring`, so its text is kept exactly as written.
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Segment-Anything model's output
    """
)
class SamImageSegmentationOutput(ModelOutput):
    r"""
    iou_scores (`torch.FloatTensor` of shape `(batch_size, num_masks)`):
        The iou scores of the predicted masks.
    pred_masks (`torch.FloatTensor` of shape `(batch_size, num_masks, height, width)`):
        The predicted low resolutions masks. Needs to be post-processed by the processor
    vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
        one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
        Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs.
    vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.
        Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
        heads.
    mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.
        Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
        heads.
    """

    # Mask-decoder results; every field is optional and defaults to `None`.
    iou_scores: torch.FloatTensor | None = None
    pred_masks: torch.FloatTensor | None = None
    # Optional per-layer diagnostics from the vision encoder and the mask decoder.
    vision_hidden_states: tuple[torch.FloatTensor, ...] | None = None
    vision_attentions: tuple[torch.FloatTensor, ...] | None = None
    mask_decoder_attentions: tuple[torch.FloatTensor, ...] | None = None
  81. class SamPatchEmbeddings(nn.Module):
  82. """
  83. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  84. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  85. Transformer.
  86. """
  87. def __init__(self, config):
  88. super().__init__()
  89. image_size, patch_size = config.image_size, config.patch_size
  90. num_channels, hidden_size = config.num_channels, config.hidden_size
  91. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  92. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  93. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  94. self.image_size = image_size
  95. self.patch_size = patch_size
  96. self.num_channels = num_channels
  97. self.num_patches = num_patches
  98. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
  99. def forward(self, pixel_values):
  100. batch_size, num_channels, height, width = pixel_values.shape
  101. if num_channels != self.num_channels:
  102. raise ValueError(
  103. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  104. )
  105. if height != self.image_size[0] or width != self.image_size[1]:
  106. raise ValueError(
  107. f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
  108. )
  109. embeddings = self.projection(pixel_values).permute(0, 2, 3, 1)
  110. return embeddings
  111. class SamMLPBlock(nn.Module):
  112. def __init__(self, config):
  113. super().__init__()
  114. self.lin1 = nn.Linear(config.hidden_size, config.mlp_dim)
  115. self.lin2 = nn.Linear(config.mlp_dim, config.hidden_size)
  116. self.act = ACT2FN[config.hidden_act]
  117. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  118. hidden_states = self.lin1(hidden_states)
  119. hidden_states = self.act(hidden_states)
  120. hidden_states = self.lin2(hidden_states)
  121. return hidden_states
  122. # Copied from transformers.models.convnext.modeling_convnext.ConvNextLayerNorm with ConvNext->Sam
  123. class SamLayerNorm(nn.LayerNorm):
  124. r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
  125. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
  126. width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
  127. """
  128. def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs):
  129. super().__init__(normalized_shape, eps=eps, **kwargs)
  130. if data_format not in ["channels_last", "channels_first"]:
  131. raise NotImplementedError(f"Unsupported data format: {data_format}")
  132. self.data_format = data_format
  133. def forward(self, features: torch.Tensor) -> torch.Tensor:
  134. """
  135. Args:
  136. features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels)
  137. """
  138. if self.data_format == "channels_first":
  139. features = features.permute(0, 2, 3, 1)
  140. features = super().forward(features)
  141. features = features.permute(0, 3, 1, 2)
  142. else:
  143. features = super().forward(features)
  144. return features
  145. def eager_attention_forward(
  146. module: nn.Module,
  147. query: torch.Tensor,
  148. key: torch.Tensor,
  149. value: torch.Tensor,
  150. attention_mask: torch.Tensor | None,
  151. scaling: float,
  152. dropout: float = 0.0,
  153. **kwargs,
  154. ):
  155. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  156. if attention_mask is not None:
  157. attn_weights = attn_weights + attention_mask
  158. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  159. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  160. attn_output = torch.matmul(attn_weights, value)
  161. attn_output = attn_output.transpose(1, 2).contiguous()
  162. return attn_output, attn_weights
  163. class SamAttention(nn.Module):
  164. """
  165. SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
  166. values.
  167. """
  168. def __init__(self, config, downsample_rate=None):
  169. super().__init__()
  170. self.config = config
  171. self.hidden_size = config.hidden_size
  172. downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
  173. self.internal_dim = config.hidden_size // downsample_rate
  174. self.num_attention_heads = config.num_attention_heads
  175. if self.internal_dim % config.num_attention_heads != 0:
  176. raise ValueError("num_attention_heads must divide hidden_size.")
  177. self.scaling = (self.internal_dim // config.num_attention_heads) ** -0.5
  178. self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
  179. self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
  180. self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
  181. self.out_proj = nn.Linear(self.internal_dim, self.hidden_size)
  182. self.is_causal = False
  183. def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor:
  184. batch, point_batch_size, n_tokens, channel = hidden_states.shape
  185. c_per_head = channel // num_attention_heads
  186. hidden_states = hidden_states.reshape(batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)
  187. return hidden_states.transpose(1, 2)
  188. def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor:
  189. batch, n_tokens, n_heads, c_per_head = hidden_states.shape
  190. return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
  191. def forward(
  192. self,
  193. query: Tensor,
  194. key: Tensor,
  195. value: Tensor,
  196. attention_similarity: Tensor | None = None,
  197. **kwargs: Unpack[TransformersKwargs],
  198. ) -> Tensor:
  199. # Input projections
  200. query = self.q_proj(query)
  201. key = self.k_proj(key)
  202. value = self.v_proj(value)
  203. point_batch_size = query.shape[1]
  204. # Separate into heads
  205. query = self._separate_heads(query, self.num_attention_heads)
  206. key = self._separate_heads(key, self.num_attention_heads)
  207. value = self._separate_heads(value, self.num_attention_heads)
  208. # SamAttention
  209. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  210. self.config._attn_implementation, eager_attention_forward
  211. )
  212. attn_output, attn_weights = attention_interface(
  213. self,
  214. query,
  215. key,
  216. value,
  217. attention_mask=attention_similarity,
  218. dropout=0.0,
  219. scaling=self.scaling,
  220. is_causal=self.is_causal,
  221. **kwargs,
  222. )
  223. attn_output = self._recombine_heads(attn_output, point_batch_size)
  224. attn_output = self.out_proj(attn_output)
  225. return attn_output, attn_weights
  226. class SamTwoWayAttentionBlock(nn.Module):
  227. def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False):
  228. """
  229. A transformer block with four layers:
  230. (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
  231. sparse inputs (4) cross attention of dense inputs -> sparse inputs
  232. Arguments:
  233. config (`SamMaskDecoderConfig`):
  234. The configuration file used to instantiate the block
  235. attention_downsample_rate (*optionalk*, int, defaults to 2):
  236. The downsample ratio of the block used to reduce the inner dim of the attention.
  237. skip_first_layer_pe (*optional*, bool, defaults to `False`):
  238. Whether or not to skip the addition of the query_point_embedding on the first layer.
  239. """
  240. super().__init__()
  241. self.hidden_size = config.hidden_size
  242. self.layer_norm_eps = config.layer_norm_eps
  243. self.self_attn = SamAttention(config, downsample_rate=1)
  244. self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
  245. self.cross_attn_token_to_image = SamAttention(config, downsample_rate=attention_downsample_rate)
  246. self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
  247. self.mlp = SamMLPBlock(config)
  248. self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
  249. self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
  250. self.cross_attn_image_to_token = SamAttention(config, downsample_rate=attention_downsample_rate)
  251. self.skip_first_layer_pe = skip_first_layer_pe
  252. def forward(
  253. self,
  254. queries: Tensor,
  255. keys: Tensor,
  256. query_point_embedding: Tensor,
  257. key_point_embedding: Tensor,
  258. attention_similarity: Tensor,
  259. **kwargs: Unpack[TransformersKwargs],
  260. ):
  261. # Self attention block
  262. if self.skip_first_layer_pe:
  263. queries, _ = self.self_attn(query=queries, key=queries, value=queries)
  264. else:
  265. query = queries + query_point_embedding
  266. attn_out, _ = self.self_attn(query=query, key=query, value=queries)
  267. queries = queries + attn_out
  268. queries = self.layer_norm1(queries)
  269. # Cross attention block, tokens attending to image embedding
  270. query = queries + query_point_embedding
  271. key = keys + key_point_embedding
  272. attn_out, _ = self.cross_attn_token_to_image(
  273. query=query, key=key, value=keys, attention_similarity=attention_similarity
  274. )
  275. queries = queries + attn_out
  276. queries = self.layer_norm2(queries)
  277. # MLP block
  278. mlp_out = self.mlp(queries)
  279. queries = queries + mlp_out
  280. queries = self.layer_norm3(queries)
  281. # Cross attention block, image embedding attending to tokens
  282. query = queries + query_point_embedding
  283. key = keys + key_point_embedding
  284. attn_out, _ = self.cross_attn_image_to_token(query=key, key=query, value=queries)
  285. keys = keys + attn_out
  286. keys = self.layer_norm4(keys)
  287. return queries, keys, attn_out
  288. class SamTwoWayTransformer(nn.Module):
  289. def __init__(self, config: SamMaskDecoderConfig):
  290. super().__init__()
  291. self.config = config
  292. self.num_hidden_layers = config.num_hidden_layers
  293. self.layers = nn.ModuleList()
  294. for i in range(self.num_hidden_layers):
  295. self.layers.append(SamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))
  296. self.final_attn_token_to_image = SamAttention(config)
  297. self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)
  298. def forward(
  299. self,
  300. point_embeddings: Tensor,
  301. image_embeddings: Tensor,
  302. image_positional_embeddings: Tensor,
  303. attention_similarity: Tensor,
  304. target_embedding=None,
  305. **kwargs: Unpack[TransformersKwargs],
  306. ) -> tuple | BaseModelOutput:
  307. if image_embeddings is None:
  308. raise ValueError("You have to specify an image_embedding")
  309. image_embeddings = image_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
  310. image_positional_embeddings = image_positional_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
  311. # Prepare queries
  312. queries = point_embeddings
  313. keys = image_embeddings
  314. # Apply transformer blocks and final layernorm
  315. for layer in self.layers:
  316. if target_embedding is not None:
  317. queries += target_embedding
  318. queries, keys, _ = layer(
  319. queries=queries,
  320. keys=keys,
  321. query_point_embedding=point_embeddings,
  322. key_point_embedding=image_positional_embeddings,
  323. attention_similarity=attention_similarity,
  324. **kwargs,
  325. )
  326. # Apply the final attention layer from the points to the image
  327. query = queries + point_embeddings
  328. key = keys + image_positional_embeddings
  329. attn_out, _ = self.final_attn_token_to_image(query=query, key=key, value=keys)
  330. queries = queries + attn_out
  331. queries = self.layer_norm_final_attn(queries)
  332. return queries, keys
  333. class SamFeedForward(nn.Module):
  334. def __init__(
  335. self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False
  336. ):
  337. super().__init__()
  338. self.num_layers = num_layers
  339. self.activation = nn.ReLU()
  340. self.proj_in = nn.Linear(input_dim, hidden_dim)
  341. self.proj_out = nn.Linear(hidden_dim, output_dim)
  342. self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)])
  343. self.sigmoid_output = sigmoid_output
  344. def forward(self, hidden_states):
  345. hidden_states = self.proj_in(hidden_states)
  346. hidden_states = self.activation(hidden_states)
  347. for layer in self.layers:
  348. hidden_states = self.activation(layer(hidden_states))
  349. hidden_states = self.proj_out(hidden_states)
  350. if self.sigmoid_output:
  351. hidden_states = F.sigmoid(hidden_states)
  352. return hidden_states
  353. class SamMaskDecoder(nn.Module):
  354. def __init__(self, config: SamMaskDecoderConfig):
  355. super().__init__()
  356. self.config = config
  357. self.hidden_size = config.hidden_size
  358. self.num_multimask_outputs = config.num_multimask_outputs
  359. self.num_mask_tokens = config.num_multimask_outputs + 1
  360. self.iou_token = nn.Embedding(1, self.hidden_size)
  361. self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size)
  362. self.transformer = SamTwoWayTransformer(config)
  363. # should we create a new class for this?
  364. self.upscale_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2)
  365. self.upscale_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2)
  366. self.upscale_layer_norm = SamLayerNorm(self.hidden_size // 4, data_format="channels_first")
  367. self.activation = nn.GELU()
  368. mlps_list = []
  369. for _ in range(self.num_mask_tokens):
  370. mlps_list += [SamFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)]
  371. self.output_hypernetworks_mlps = nn.ModuleList(mlps_list)
  372. self.iou_prediction_head = SamFeedForward(
  373. self.hidden_size, config.iou_head_hidden_dim, self.num_mask_tokens, config.iou_head_depth
  374. )
  375. def forward(
  376. self,
  377. image_embeddings: torch.Tensor,
  378. image_positional_embeddings: torch.Tensor,
  379. sparse_prompt_embeddings: torch.Tensor,
  380. dense_prompt_embeddings: torch.Tensor,
  381. multimask_output: bool,
  382. attention_similarity: torch.Tensor | None = None,
  383. target_embedding: torch.Tensor | None = None,
  384. ) -> tuple[torch.Tensor, torch.Tensor]:
  385. """
  386. Predict masks given image and prompt embeddings.
  387. Args:
  388. image_embeddings (`torch.Tensor`):
  389. the embeddings from the image encoder
  390. image_positional_embedding (`torch.Tensor`):
  391. positional encoding with the shape of image_embeddings
  392. sparse_prompt_embeddings (`torch.Tensor`):
  393. The embeddings of the points and boxes
  394. dense_prompt_embeddings (`torch.Tensor`):
  395. the embeddings of the mask inputs
  396. multimask_output (bool):
  397. Whether to return multiple masks or a single mask.
  398. """
  399. batch_size, num_channels, height, width = image_embeddings.shape
  400. point_batch_size = sparse_prompt_embeddings.shape[1] if sparse_prompt_embeddings is not None else 1
  401. # Concatenate output tokens
  402. output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
  403. output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
  404. if sparse_prompt_embeddings is not None:
  405. tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
  406. else:
  407. tokens = output_tokens
  408. point_embeddings = tokens.to(self.iou_token.weight.dtype)
  409. # Expand per-image data in batch direction to be per-point
  410. image_embeddings = image_embeddings + dense_prompt_embeddings
  411. image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
  412. image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
  413. # Run the transformer, image_positional_embedding are consumed
  414. point_embedding, image_embeddings = self.transformer(
  415. point_embeddings=point_embeddings,
  416. image_embeddings=image_embeddings,
  417. image_positional_embeddings=image_positional_embeddings,
  418. attention_similarity=attention_similarity,
  419. target_embedding=target_embedding,
  420. )
  421. iou_token_out = point_embedding[:, :, 0, :]
  422. mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :]
  423. # Upscale mask embeddings and predict masks using the mask tokens
  424. image_embeddings = image_embeddings.transpose(2, 3).reshape(
  425. batch_size * point_batch_size, num_channels, height, width
  426. )
  427. upscaled_embedding = self.upscale_conv1(image_embeddings)
  428. upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
  429. upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding))
  430. hyper_in_list = []
  431. for i in range(self.num_mask_tokens):
  432. current_mlp = self.output_hypernetworks_mlps[i]
  433. hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
  434. hyper_in = torch.stack(hyper_in_list, dim=2)
  435. _, num_channels, height, width = upscaled_embedding.shape
  436. upscaled_embedding = upscaled_embedding.reshape(batch_size, point_batch_size, num_channels, height * width)
  437. masks = (hyper_in @ upscaled_embedding).reshape(batch_size, point_batch_size, -1, height, width)
  438. # Generate mask quality predictions
  439. iou_pred = self.iou_prediction_head(iou_token_out)
  440. # Select the correct mask or masks for output
  441. if multimask_output:
  442. mask_slice = slice(1, None)
  443. else:
  444. mask_slice = slice(0, 1)
  445. masks = masks[:, :, mask_slice, :, :]
  446. iou_pred = iou_pred[:, :, mask_slice]
  447. return masks, iou_pred
  448. class SamPositionalEmbedding(nn.Module):
  449. def __init__(self, config):
  450. super().__init__()
  451. self.scale = config.scale
  452. self.positional_embedding = nn.Parameter(self.scale * torch.randn((2, config.num_pos_feats)))
  453. def forward(self, input_coords, input_shape=None):
  454. """Positionally encode points that are normalized to [0,1]."""
  455. coordinates = input_coords.clone()
  456. if input_shape is not None:
  457. coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
  458. coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
  459. # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
  460. coordinates = 2 * coordinates - 1
  461. coordinates = coordinates.to(self.positional_embedding.dtype)
  462. coordinates = coordinates @ self.positional_embedding
  463. coordinates = 2 * np.pi * coordinates
  464. # outputs d_1 x ... x d_n x channel shape
  465. return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
  466. class SamMaskEmbedding(nn.Module):
  467. def __init__(self, config: SamPromptEncoderConfig):
  468. super().__init__()
  469. self.mask_input_channels = config.mask_input_channels // 4
  470. self.activation = ACT2FN[config.hidden_act]
  471. self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2)
  472. self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2)
  473. self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1)
  474. self.layer_norm1 = SamLayerNorm(
  475. self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first"
  476. )
  477. self.layer_norm2 = SamLayerNorm(
  478. self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first"
  479. )
  480. def forward(self, masks):
  481. hidden_states = self.conv1(masks)
  482. hidden_states = self.layer_norm1(hidden_states)
  483. hidden_states = self.activation(hidden_states)
  484. hidden_states = self.conv2(hidden_states)
  485. hidden_states = self.layer_norm2(hidden_states)
  486. hidden_states = self.activation(hidden_states)
  487. dense_embeddings = self.conv3(hidden_states)
  488. return dense_embeddings
  489. class SamPromptEncoder(nn.Module):
  490. def __init__(self, config: SamConfig):
  491. super().__init__()
  492. self.shared_embedding = SamPositionalEmbedding(config.vision_config)
  493. config = config.prompt_encoder_config
  494. self.mask_embed = SamMaskEmbedding(config)
  495. self.no_mask_embed = nn.Embedding(1, config.hidden_size)
  496. self.image_embedding_size = (config.image_embedding_size, config.image_embedding_size)
  497. self.input_image_size = config.image_size
  498. self.point_embed = nn.ModuleList(
  499. [nn.Embedding(1, config.hidden_size) for i in range(config.num_point_embeddings)]
  500. )
  501. self.hidden_size = config.hidden_size
  502. self.not_a_point_embed = nn.Embedding(1, config.hidden_size)
  503. def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
  504. """Embeds point prompts."""
  505. points = points + 0.5 # Shift to center of pixel
  506. if pad:
  507. target_point_shape = (points.shape[0], points.shape[1], 1, points.shape[-1])
  508. target_labels_shape = (points.shape[0], points.shape[1], 1)
  509. padding_point = torch.zeros(target_point_shape, device=points.device)
  510. padding_label = -torch.ones(target_labels_shape, device=labels.device)
  511. points = torch.cat([points, padding_point], dim=2)
  512. labels = torch.cat([labels, padding_label], dim=2)
  513. input_shape = (self.input_image_size, self.input_image_size)
  514. point_embedding = self.shared_embedding(points, input_shape)
  515. # torch.where and expanding the labels tensor is required by the ONNX export
  516. point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding)
  517. # This is required for the ONNX export. The dtype, device need to be explicitly
  518. # specified as otherwise torch.onnx.export interprets as double
  519. point_embedding = torch.where(labels[..., None] != -10, point_embedding, torch.zeros_like(point_embedding))
  520. point_embedding = torch.where(
  521. (labels == 0)[:, :, :, None],
  522. point_embedding + self.point_embed[0].weight[None, None, :, :],
  523. point_embedding,
  524. )
  525. point_embedding = torch.where(
  526. (labels == 1)[:, :, :, None],
  527. point_embedding + self.point_embed[1].weight[None, None, :, :],
  528. point_embedding,
  529. )
  530. return point_embedding
  531. def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
  532. """Embeds box prompts."""
  533. boxes = boxes + 0.5 # Shift to center of pixel
  534. batch_size, nb_boxes = boxes.shape[:2]
  535. coords = boxes.reshape(batch_size, nb_boxes, 2, 2)
  536. input_shape = (self.input_image_size, self.input_image_size)
  537. corner_embedding = self.shared_embedding(coords, input_shape)
  538. corner_embedding[:, :, 0, :] += self.point_embed[2].weight
  539. corner_embedding[:, :, 1, :] += self.point_embed[3].weight
  540. return corner_embedding
  541. def forward(
  542. self,
  543. input_points: tuple[torch.Tensor, torch.Tensor] | None,
  544. input_labels: torch.Tensor | None,
  545. input_boxes: torch.Tensor | None,
  546. input_masks: torch.Tensor | None,
  547. ) -> tuple[torch.Tensor, torch.Tensor]:
  548. """
  549. Embeds different types of prompts, returning both sparse and dense embeddings.
  550. Args:
  551. points (`torch.Tensor`, *optional*):
  552. point coordinates and labels to embed.
  553. boxes (`torch.Tensor`, *optional*):
  554. boxes to embed
  555. masks (`torch.Tensor`, *optional*):
  556. masks to embed
  557. """
  558. sparse_embeddings = None
  559. batch_size = 1
  560. if input_points is not None:
  561. batch_size = input_points.shape[0]
  562. if input_labels is None:
  563. raise ValueError("If points are provided, labels must also be provided.")
  564. point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
  565. sparse_embeddings = point_embeddings
  566. if input_boxes is not None:
  567. batch_size = input_boxes.shape[0]
  568. box_embeddings = self._embed_boxes(input_boxes)
  569. if sparse_embeddings is None:
  570. sparse_embeddings = box_embeddings
  571. else:
  572. sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=2)
  573. if input_masks is not None:
  574. dense_embeddings = self.mask_embed(input_masks)
  575. else:
  576. dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
  577. batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]
  578. )
  579. return sparse_embeddings, dense_embeddings
  580. class SamVisionAttention(nn.Module):
  581. """Multi-head Attention block with relative position embeddings."""
  582. def __init__(self, config, window_size):
  583. super().__init__()
  584. input_size = (
  585. (config.image_size // config.patch_size, config.image_size // config.patch_size)
  586. if window_size == 0
  587. else (window_size, window_size)
  588. )
  589. self.num_attention_heads = config.num_attention_heads
  590. head_dim = config.hidden_size // config.num_attention_heads
  591. self.scale = head_dim**-0.5
  592. self.dropout = config.attention_dropout
  593. self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.qkv_bias)
  594. self.proj = nn.Linear(config.hidden_size, config.hidden_size)
  595. self.use_rel_pos = config.use_rel_pos
  596. if self.use_rel_pos:
  597. if input_size is None:
  598. raise ValueError("Input size must be provided if using relative positional encoding.")
  599. # initialize relative positional embeddings
  600. self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
  601. self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
  602. def get_rel_pos(self, q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
  603. """
  604. Get relative positional embeddings according to the relative positions of
  605. query and key sizes.
  606. Args:
  607. q_size (int):
  608. size of the query.
  609. k_size (int):
  610. size of key k.
  611. rel_pos (`torch.Tensor`):
  612. relative position embeddings (L, channel).
  613. Returns:
  614. Extracted positional embeddings according to relative positions.
  615. """
  616. max_rel_dist = int(2 * max(q_size, k_size) - 1)
  617. # Interpolate rel pos.
  618. rel_pos_resized = F.interpolate(
  619. rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
  620. size=max_rel_dist,
  621. mode="linear",
  622. )
  623. rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
  624. # Scale the coords with short length if shapes for q and k are different.
  625. q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
  626. k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
  627. relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
  628. return rel_pos_resized[relative_coords.long()]
  629. def get_decomposed_rel_pos(
  630. self,
  631. query: torch.Tensor,
  632. rel_pos_h: torch.Tensor,
  633. rel_pos_w: torch.Tensor,
  634. q_size: tuple[int, int],
  635. k_size: tuple[int, int],
  636. ) -> torch.Tensor:
  637. """
  638. Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
  639. https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
  640. Args:
  641. query (`torch.Tensor`):
  642. query q in the attention layer with shape (batch_size, query_height * query_width, channel).
  643. rel_pos_h (`torch.Tensor`):
  644. relative position embeddings (Lh, channel) for height axis.
  645. rel_pos_w (`torch.Tensor`):
  646. relative position embeddings (Lw, channel) for width axis.
  647. q_size (tuple):
  648. spatial sequence size of query q with (query_height, query_width).
  649. k_size (tuple):
  650. spatial sequence size of key k with (key_height, key_width).
  651. Returns:
  652. decomposed_rel_pos (`torch.Tensor`):
  653. decomposed relative position embeddings.
  654. """
  655. query_height, query_width = q_size
  656. key_height, key_width = k_size
  657. relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
  658. relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)
  659. batch_size, _, dim = query.shape
  660. reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
  661. rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
  662. rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
  663. decomposed_rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
  664. return decomposed_rel_pos
  665. def forward(self, hidden_states: torch.Tensor, output_attentions=None) -> tuple[torch.Tensor, torch.Tensor]:
  666. batch_size, height, width, _ = hidden_states.shape
  667. # qkv with shape (3, batch_size, nHead, height * width, channel)
  668. qkv = (
  669. self.qkv(hidden_states)
  670. .reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
  671. .permute(2, 0, 3, 1, 4)
  672. )
  673. # q, k, v with shape (batch_size * nHead, height * width, channel)
  674. query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)
  675. attn_weights = (query * self.scale) @ key.transpose(-2, -1)
  676. if self.use_rel_pos:
  677. decomposed_rel_pos = self.get_decomposed_rel_pos(
  678. query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
  679. )
  680. decomposed_rel_pos = decomposed_rel_pos.reshape_as(attn_weights)
  681. attn_weights = attn_weights + decomposed_rel_pos
  682. attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
  683. attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
  684. attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
  685. attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
  686. attn_output = self.proj(attn_output)
  687. return attn_output, attn_weights
  688. class SamVisionSdpaAttention(SamVisionAttention):
  689. """
  690. Multi-head Attention block with relative position embeddings.
  691. Using SDPA instead of the default attention.
  692. """
  693. def __init__(self, config, window_size):
  694. super().__init__(config, window_size)
  695. def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
  696. if output_attentions:
  697. logger.warning_once(
  698. f"{self.__class__.__name__} does not support `output_attentions=True`. The returned attention weights will "
  699. "be `None`. If you want to get attention weights, please set `attn_implementation='eager'` when loading the model."
  700. )
  701. batch_size, height, width, _ = hidden_states.shape
  702. # qkv with shape (3, B, nHead, H * W, C)
  703. qkv = (
  704. self.qkv(hidden_states)
  705. .reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
  706. .permute(2, 0, 3, 1, 4)
  707. )
  708. # q, k, v with shape (B * nHead, H * W, C)
  709. query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)
  710. attn_bias = None
  711. if self.use_rel_pos:
  712. decomposed_rel_pos = self.get_decomposed_rel_pos(
  713. query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
  714. )
  715. decomposed_rel_pos = decomposed_rel_pos.reshape(
  716. batch_size, self.num_attention_heads, height * width, height * width
  717. )
  718. attn_bias = decomposed_rel_pos
  719. query = query.view(batch_size, self.num_attention_heads, height * width, -1)
  720. key = key.view(batch_size, self.num_attention_heads, height * width, -1)
  721. value = value.view(batch_size, self.num_attention_heads, height * width, -1)
  722. attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias)
  723. attn_output = (
  724. attn_output.view(batch_size, self.num_attention_heads, height, width, -1)
  725. .permute(0, 2, 3, 1, 4)
  726. .reshape(batch_size, height, width, -1)
  727. )
  728. attn_output = self.proj(attn_output)
  729. return attn_output, None
# Maps an attention implementation name to the vision attention class that implements it.
# `SamVisionLayer` indexes this mapping with `config._attn_implementation`.
SAM_VISION_ATTENTION_CLASSES = {
    "eager": SamVisionAttention,
    "sdpa": SamVisionSdpaAttention,
}
  734. class SamVisionLayer(GradientCheckpointingLayer):
  735. def __init__(self, config, window_size):
  736. super().__init__()
  737. self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  738. self.attn = SAM_VISION_ATTENTION_CLASSES[config._attn_implementation](config, window_size)
  739. self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  740. self.mlp = SamMLPBlock(config)
  741. self.window_size = window_size
  742. def window_partition(self, hidden_states: torch.Tensor, window_size: int) -> tuple[torch.Tensor, tuple[int, int]]:
  743. """
  744. Args:
  745. Partition into non-overlapping windows with padding if needed.
  746. hidden_states (tensor): input tokens with [batch_size, height, width, channel]. window_size (int): window
  747. size.
  748. Returns:
  749. windows: windows after partition with [batch_size * num_windows, window_size, window_size, channel].
  750. (pad_height, pad_width): padded height and width before partition
  751. """
  752. batch_size, height, width, channel = hidden_states.shape
  753. pad_h = (window_size - height % window_size) % window_size
  754. pad_w = (window_size - width % window_size) % window_size
  755. hidden_states = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h))
  756. pad_height, pad_width = height + pad_h, width + pad_w
  757. hidden_states = hidden_states.reshape(
  758. batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel
  759. )
  760. windows = hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(-1, window_size, window_size, channel)
  761. return windows, (pad_height, pad_width)
  762. def window_unpartition(
  763. self, windows: torch.Tensor, window_size: int, padding_shape: tuple[int, int], original_shape: tuple[int, int]
  764. ) -> torch.Tensor:
  765. """
  766. Args:
  767. Window unpartition into original sequences and removing padding.
  768. hidden_states (tensor):
  769. input tokens with [batch_size * num_windows, window_size, window_size, channel].
  770. window_size (int):
  771. window size.
  772. padding_shape (Tuple):
  773. padded height and width (pad_height, pad_width).
  774. original_shape (Tuple): original height and width (height, width) before padding.
  775. Returns:
  776. hidden_states: unpartitioned sequences with [batch_size, height, width, channel].
  777. """
  778. pad_height, pad_width = padding_shape
  779. height, width = original_shape
  780. batch_size = windows.shape[0] // (pad_height * pad_width // window_size // window_size)
  781. hidden_states = windows.reshape(
  782. batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1
  783. )
  784. hidden_states = (
  785. hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(batch_size, pad_height, pad_width, -1)
  786. )
  787. hidden_states = hidden_states[:, :height, :width, :].contiguous()
  788. return hidden_states
  789. def forward(self, hidden_states: torch.Tensor) -> tuple[torch.FloatTensor]:
  790. residual = hidden_states
  791. hidden_states = self.layer_norm1(hidden_states)
  792. # Window partition
  793. if self.window_size > 0:
  794. height, width = hidden_states.shape[1], hidden_states.shape[2]
  795. hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size)
  796. hidden_states, attn_weights = self.attn(
  797. hidden_states=hidden_states,
  798. )
  799. # Reverse window partition
  800. if self.window_size > 0:
  801. hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width))
  802. hidden_states = residual + hidden_states
  803. layernorm_output = self.layer_norm2(hidden_states)
  804. hidden_states = hidden_states + self.mlp(layernorm_output)
  805. return hidden_states
  806. class SamVisionNeck(nn.Module):
  807. def __init__(self, config: SamVisionConfig):
  808. super().__init__()
  809. self.config = config
  810. self.conv1 = nn.Conv2d(config.hidden_size, config.output_channels, kernel_size=1, bias=False)
  811. self.layer_norm1 = SamLayerNorm(config.output_channels, data_format="channels_first")
  812. self.conv2 = nn.Conv2d(config.output_channels, config.output_channels, kernel_size=3, padding=1, bias=False)
  813. self.layer_norm2 = SamLayerNorm(config.output_channels, data_format="channels_first")
  814. def forward(self, hidden_states):
  815. hidden_states = hidden_states.permute(0, 3, 1, 2)
  816. hidden_states = self.conv1(hidden_states)
  817. hidden_states = self.layer_norm1(hidden_states)
  818. hidden_states = self.conv2(hidden_states)
  819. hidden_states = self.layer_norm2(hidden_states)
  820. return hidden_states
  821. @auto_docstring
  822. class SamPreTrainedModel(PreTrainedModel):
  823. config: SamConfig
  824. base_model_prefix = "sam"
  825. main_input_name = "pixel_values"
  826. input_modalities = ("image",)
  827. _no_split_modules = ["SamVisionAttention"]
  828. supports_gradient_checkpointing = True
  829. _supports_sdpa = True
  830. @torch.no_grad()
  831. def _init_weights(self, module: nn.Module):
  832. super()._init_weights(module)
  833. if isinstance(module, SamVisionAttention):
  834. if module.use_rel_pos:
  835. init.zeros_(module.rel_pos_h)
  836. init.zeros_(module.rel_pos_w)
  837. elif isinstance(module, SamVisionEncoder):
  838. if self.config.use_abs_pos:
  839. init.zeros_(module.pos_embed)
  840. elif isinstance(module, SamPositionalEmbedding):
  841. init.normal_(module.positional_embedding, std=module.scale)
  842. class SamVisionEncoder(SamPreTrainedModel):
  843. _can_record_outputs = {"hidden_states": SamVisionLayer, "attentions": SamVisionAttention}
  844. def __init__(self, config: SamVisionConfig):
  845. super().__init__(config)
  846. self.config = config
  847. self.image_size = config.image_size
  848. self.patch_embed = SamPatchEmbeddings(config)
  849. self.pos_embed = None
  850. if config.use_abs_pos:
  851. # Initialize absolute positional embedding with pretrain image size.
  852. self.pos_embed = nn.Parameter(
  853. torch.zeros(
  854. 1,
  855. config.image_size // config.patch_size,
  856. config.image_size // config.patch_size,
  857. config.hidden_size,
  858. )
  859. )
  860. self.layers = nn.ModuleList()
  861. for i in range(config.num_hidden_layers):
  862. layer = SamVisionLayer(
  863. config,
  864. window_size=config.window_size if i not in config.global_attn_indexes else 0,
  865. )
  866. self.layers.append(layer)
  867. self.neck = SamVisionNeck(config)
  868. self.gradient_checkpointing = False
  869. self.post_init()
  870. def get_input_embeddings(self):
  871. return self.patch_embed
  872. @merge_with_config_defaults
  873. @capture_outputs(tie_last_hidden_states=False)
  874. def forward(
  875. self, pixel_values: torch.FloatTensor | None = None, **kwargs: Unpack[TransformersKwargs]
  876. ) -> tuple | SamVisionEncoderOutput:
  877. if pixel_values is None:
  878. raise ValueError("You have to specify pixel_values")
  879. hidden_states = self.patch_embed(pixel_values)
  880. if self.pos_embed is not None:
  881. hidden_states = hidden_states + self.pos_embed
  882. for layer_module in self.layers:
  883. hidden_states = layer_module(hidden_states)
  884. hidden_states = self.neck(hidden_states)
  885. return SamVisionEncoderOutput(
  886. last_hidden_state=hidden_states,
  887. )
@auto_docstring(
    custom_intro="""
    The vision model from Sam without any head or projection on top.
    """
)
class SamVisionModel(SamPreTrainedModel):
    # Thin wrapper: builds a `SamVisionEncoder` from the vision config and delegates every call to it.
    config: SamVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: SamVisionConfig):
        super().__init__(config)
        self.vision_encoder = SamVisionEncoder(config)
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        # The input embedding is the encoder's patch embedding module.
        return self.vision_encoder.patch_embed

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | SamVisionEncoderOutput:
        # Pure delegation: the encoder's output is returned unchanged.
        return self.vision_encoder(pixel_values, **kwargs)
  909. @auto_docstring(
  910. custom_intro="""
  911. Segment Anything Model (SAM) for generating segmentation masks, given an input image and
  912. input points and labels, boxes, or masks.
  913. """
  914. )
  915. class SamModel(SamPreTrainedModel):
  916. input_modalities = ("image", "text")
  917. _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(SamTwoWayAttentionBlock, index=2)}
  918. _tied_weights_keys = {
  919. "prompt_encoder.shared_embedding.positional_embedding": "shared_image_embedding.positional_embedding"
  920. }
  921. def __init__(self, config: SamConfig):
  922. super().__init__(config)
  923. self.shared_image_embedding = SamPositionalEmbedding(config.vision_config)
  924. self.vision_encoder = SamVisionEncoder(config.vision_config)
  925. self.prompt_encoder = SamPromptEncoder(config)
  926. # The module using it is not a PreTrainedModel subclass so we need this
  927. config.mask_decoder_config._attn_implementation = config._attn_implementation
  928. self.mask_decoder = SamMaskDecoder(config.mask_decoder_config)
  929. self.post_init()
  930. def get_input_embeddings(self):
  931. return self.vision_encoder.get_input_embeddings()
  932. def get_image_wide_positional_embeddings(self):
  933. size = self.config.prompt_encoder_config.image_embedding_size
  934. target_device = self.shared_image_embedding.positional_embedding.device
  935. target_dtype = self.shared_image_embedding.positional_embedding.dtype
  936. grid = torch.ones((size, size), device=target_device, dtype=target_dtype)
  937. y_embed = grid.cumsum(dim=0) - 0.5
  938. x_embed = grid.cumsum(dim=1) - 0.5
  939. y_embed = y_embed / size
  940. x_embed = x_embed / size
  941. positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1))
  942. return positional_embedding.permute(2, 0, 1).unsqueeze(0) # channel x height x width
  943. @torch.no_grad()
  944. def get_image_embeddings(self, pixel_values, **kwargs: Unpack[TransformersKwargs]):
  945. r"""
  946. Returns the image embeddings by passing the pixel values through the vision encoder.
  947. Args:
  948. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  949. Input pixel values
  950. """
  951. vision_output = self.vision_encoder(
  952. pixel_values,
  953. **kwargs,
  954. )
  955. image_embeddings = vision_output[0]
  956. return image_embeddings
  957. @torch.no_grad()
  958. def get_prompt_embeddings(
  959. self,
  960. input_points: torch.FloatTensor | None = None,
  961. input_labels: torch.LongTensor | None = None,
  962. input_boxes: torch.FloatTensor | None = None,
  963. input_masks: torch.LongTensor | None = None,
  964. ):
  965. r"""
  966. Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.
  967. Args:
  968. input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
  969. Optional input points for the prompt encoder. The padding of the point is automatically done by the
  970. processor. `point_batch_size` refers to the number of masks that we want the model to predict per
  971. point. The model will output `point_batch_size` times 3 masks in total.
  972. input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
  973. Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
  974. processor, or can be fed by the user.
  975. input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`):
  976. Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
  977. processor. users can also pass manually the input boxes.
  978. input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`):
  979. Optional input masks for the prompt encoder.
  980. """
  981. prompt_output = self.prompt_encoder(
  982. input_points=input_points,
  983. input_labels=input_labels,
  984. input_boxes=input_boxes,
  985. input_masks=input_masks,
  986. )
  987. return prompt_output
  988. @merge_with_config_defaults
  989. @capture_outputs
  990. @auto_docstring
  991. def forward(
  992. self,
  993. pixel_values: torch.FloatTensor | None = None,
  994. input_points: torch.FloatTensor | None = None,
  995. input_labels: torch.LongTensor | None = None,
  996. input_boxes: torch.FloatTensor | None = None,
  997. input_masks: torch.LongTensor | None = None,
  998. image_embeddings: torch.FloatTensor | None = None,
  999. multimask_output: bool = True,
  1000. attention_similarity: torch.FloatTensor | None = None,
  1001. target_embedding: torch.FloatTensor | None = None,
  1002. **kwargs: Unpack[TransformersKwargs],
  1003. ) -> SamImageSegmentationOutput:
  1004. r"""
  1005. input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`):
  1006. Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields to much
  1007. better results. The points can be obtained by passing a list of list of list to the processor that will
  1008. create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the
  1009. second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict
  1010. per input point), the third dimension is the number of points per segmentation mask (it is possible to pass
  1011. multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal)
  1012. coordinates of the point. If a different number of points is passed either for each image, or for each
  1013. mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
  1014. computation of the embedding will be skipped for these points using the labels.
  1015. input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`):
  1016. Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the
  1017. official implementation, there are 3 types of labels
  1018. - `1`: the point is a point that contains the object of interest
  1019. - `0`: the point is a point that does not contain the object of interest
  1020. - `-1`: the point corresponds to the background
  1021. We added the label:
  1022. - `-10`: the point is a padding point, thus should be ignored by the prompt encoder
  1023. The padding labels should be automatically done by the processor.
  1024. input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
  1025. Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields to
  1026. much better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
  1027. that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch
  1028. size, the number of boxes per image and the coordinates of the top left and bottom right point of the box.
  1029. In the order (`x1`, `y1`, `x2`, `y2`):
  1030. - `x1`: the x coordinate of the top left point of the input box
  1031. - `y1`: the y coordinate of the top left point of the input box
  1032. - `x2`: the x coordinate of the bottom right point of the input box
  1033. - `y2`: the y coordinate of the bottom right point of the input box
  1034. input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`):
  1035. SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
  1036. generate a corresponding embedding, that will be fed later on to the mask decoder. These masks needs to be
  1037. manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
  1038. image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`):
  1039. Image embeddings, this is used by the mask decder to generate masks and iou scores. For more memory
  1040. efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
  1041. method, and then feed them to the `forward` method instead of feeding the `pixel_values`.
  1042. multimask_output (`bool`, *optional*):
  1043. In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
  1044. bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
  1045. "best" mask, by specifying `multimask_output=False`.
  1046. attention_similarity (`torch.FloatTensor`, *optional*):
  1047. Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
  1048. model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
  1049. target_embedding (`torch.FloatTensor`, *optional*):
  1050. Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
  1051. the model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
  1052. Example:
  1053. ```python
  1054. >>> from PIL import Image
  1055. >>> import httpx
  1056. >>> from io import BytesIO
  1057. >>> from transformers import AutoModel, AutoProcessor
  1058. >>> model = AutoModel.from_pretrained("facebook/sam-vit-base")
  1059. >>> processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")
  1060. >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
  1061. >>> with httpx.stream("GET", url) as response:
  1062. ... raw_image = Image.open(BytesIO(response.read())).convert("RGB")
  1063. >>> input_points = [[[400, 650]]] # 2D location of a window on the car
  1064. >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")
  1065. >>> # Get segmentation mask
  1066. >>> outputs = model(**inputs)
  1067. >>> # Postprocess masks
  1068. >>> masks = processor.post_process_masks(
  1069. ... outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
  1070. ... )
  1071. ```
  1072. """
  1073. if pixel_values is None and image_embeddings is None:
  1074. raise ValueError("Either pixel_values or image_embeddings must be provided.")
  1075. if pixel_values is not None and image_embeddings is not None:
  1076. raise ValueError("Only one of pixel_values and image_embeddings can be provided.")
  1077. if input_points is not None and len(input_points.shape) != 4:
  1078. raise ValueError(
  1079. "The input_points must be a 4D tensor. Of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`.",
  1080. f" got {input_points.shape}.",
  1081. )
  1082. if input_boxes is not None and len(input_boxes.shape) != 3:
  1083. raise ValueError(
  1084. "The input_points must be a 3D tensor. Of shape `batch_size`, `nb_boxes`, `4`.",
  1085. f" got {input_boxes.shape}.",
  1086. )
  1087. if input_points is not None and input_boxes is not None:
  1088. point_batch_size = input_points.shape[1]
  1089. box_batch_size = input_boxes.shape[1]
  1090. if point_batch_size != box_batch_size:
  1091. raise ValueError(
  1092. f"You should provide as many bounding boxes as input points per box. Got {point_batch_size} and {box_batch_size}."
  1093. )
  1094. image_positional_embeddings = self.get_image_wide_positional_embeddings()
  1095. # repeat with batch size
  1096. batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings.shape[0]
  1097. image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
  1098. vision_attentions = None
  1099. vision_hidden_states = None
  1100. if pixel_values is not None:
  1101. vision_outputs: SamVisionEncoderOutput = self.vision_encoder(pixel_values, **kwargs)
  1102. image_embeddings = vision_outputs.last_hidden_state
  1103. vision_hidden_states = vision_outputs.hidden_states
  1104. vision_attentions = vision_outputs.attentions
  1105. if input_points is not None and input_labels is None:
  1106. input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)
  1107. if input_points is not None and image_embeddings.shape[0] != input_points.shape[0]:
  1108. raise ValueError(
  1109. "The batch size of the image embeddings and the input points must be the same. ",
  1110. f"Got {image_embeddings.shape[0]} and {input_points.shape[0]} respectively.",
  1111. " if you want to pass multiple points for the same image, make sure that you passed ",
  1112. " input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ",
  1113. " input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
  1114. )
  1115. sparse_embeddings, dense_embeddings = self.prompt_encoder(
  1116. input_points=input_points,
  1117. input_labels=input_labels,
  1118. input_boxes=input_boxes,
  1119. input_masks=input_masks,
  1120. )
  1121. low_res_masks, iou_predictions = self.mask_decoder(
  1122. image_embeddings=image_embeddings,
  1123. image_positional_embeddings=image_positional_embeddings,
  1124. sparse_prompt_embeddings=sparse_embeddings,
  1125. dense_prompt_embeddings=dense_embeddings,
  1126. multimask_output=multimask_output,
  1127. attention_similarity=attention_similarity,
  1128. target_embedding=target_embedding,
  1129. )
  1130. return SamImageSegmentationOutput(
  1131. iou_scores=iou_predictions,
  1132. pred_masks=low_res_masks,
  1133. vision_hidden_states=vision_hidden_states,
  1134. vision_attentions=vision_attentions,
  1135. )
# Names exported by a star-import of this module.
__all__ = ["SamVisionModel", "SamModel", "SamPreTrainedModel"]