modeling_edgetam.py 57 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/edgetam/modular_edgetam.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_edgetam.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. import math
  21. from collections.abc import Callable
  22. from dataclasses import dataclass
  23. import numpy as np
  24. import torch
  25. import torch.nn as nn
  26. import torch.nn.functional as F
  27. from torch import Tensor
  28. from ... import initialization as init
  29. from ...activations import ACT2FN
  30. from ...modeling_layers import GradientCheckpointingLayer
  31. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
  32. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  33. from ...processing_utils import Unpack
  34. from ...pytorch_utils import compile_compatible_method_lru_cache
  35. from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
  36. from ...utils.generic import TransformersKwargs, is_flash_attention_requested, merge_with_config_defaults
  37. from ...utils.output_capturing import OutputRecorder, capture_outputs
  38. from ..auto import AutoModel
  39. from .configuration_edgetam import (
  40. EdgeTamConfig,
  41. EdgeTamMaskDecoderConfig,
  42. EdgeTamPromptEncoderConfig,
  43. EdgeTamVisionConfig,
  44. )
# Module-level logger; used below for one-time warnings such as the SDPA fallback in EdgeTamAttention.
logger = logging.get_logger(__name__)
  46. class EdgeTamLayerNorm(nn.LayerNorm):
  47. r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
  48. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
  49. width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
  50. """
  51. def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs):
  52. super().__init__(normalized_shape, eps=eps, **kwargs)
  53. if data_format not in ["channels_last", "channels_first"]:
  54. raise NotImplementedError(f"Unsupported data format: {data_format}")
  55. self.data_format = data_format
  56. def forward(self, features: torch.Tensor) -> torch.Tensor:
  57. """
  58. Args:
  59. features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels)
  60. """
  61. if self.data_format == "channels_first":
  62. features = features.permute(0, 2, 3, 1)
  63. features = super().forward(features)
  64. features = features.permute(0, 3, 1, 2)
  65. else:
  66. features = super().forward(features)
  67. return features
@dataclass
@auto_docstring(custom_intro="Base class for the vision encoder's outputs.")
class EdgeTamVisionEncoderOutput(BaseModelOutputWithPooling):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, height, width, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the model.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
        one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`. Hidden-states of the
        model at the output of each stage.
    attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
        the self-attention heads.
    fpn_hidden_states (`tuple(torch.FloatTensor)`):
        Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
        `(batch_size, hidden_size, height, width)`. Feature maps from the Feature Pyramid Network neck.
    fpn_position_encoding (`tuple(torch.FloatTensor)`):
        Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
        `(batch_size, hidden_size, height, width)`. Positional encodings corresponding to the `fpn_hidden_states`.
    """

    # One tensor per FPN level; EdgeTamVisionNeck builds these as tuples and EdgeTamVisionModel slices them.
    fpn_hidden_states: tuple[torch.FloatTensor, ...] | None = None
    # Positional encodings aligned level-by-level with `fpn_hidden_states`.
    fpn_position_encoding: tuple[torch.FloatTensor, ...] | None = None
  91. def eager_attention_forward(
  92. module: nn.Module,
  93. query: torch.Tensor,
  94. key: torch.Tensor,
  95. value: torch.Tensor,
  96. attention_mask: torch.Tensor | None,
  97. scaling: float,
  98. dropout: float = 0.0,
  99. **kwargs,
  100. ):
  101. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  102. if attention_mask is not None:
  103. attn_weights = attn_weights + attention_mask
  104. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  105. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  106. attn_output = torch.matmul(attn_weights, value)
  107. attn_output = attn_output.transpose(1, 2).contiguous()
  108. return attn_output, attn_weights
  109. class EdgeTamAttention(nn.Module):
  110. """
  111. EDGETAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
  112. values.
  113. """
  114. def __init__(self, config, downsample_rate=None):
  115. super().__init__()
  116. downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
  117. self.config = config
  118. self.hidden_size = config.hidden_size
  119. self.internal_dim = config.hidden_size // downsample_rate
  120. self.num_attention_heads = config.num_attention_heads
  121. self.head_dim = self.internal_dim // config.num_attention_heads
  122. self.scaling = self.head_dim**-0.5
  123. self.is_causal = False
  124. self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
  125. self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
  126. self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
  127. self.o_proj = nn.Linear(self.internal_dim, self.hidden_size)
  128. def forward(
  129. self,
  130. query: torch.Tensor,
  131. key: torch.Tensor,
  132. value: torch.Tensor,
  133. attention_similarity: torch.Tensor | None = None,
  134. **kwargs: Unpack[TransformersKwargs],
  135. ) -> tuple[torch.Tensor, torch.Tensor]:
  136. # Input projections
  137. batch_size, point_batch_size = query.shape[:2]
  138. new_shape = (batch_size * point_batch_size, -1, self.num_attention_heads, self.head_dim)
  139. query = self.q_proj(query).view(*new_shape).transpose(1, 2)
  140. key = self.k_proj(key).view(*new_shape).transpose(1, 2)
  141. value = self.v_proj(value).view(*new_shape).transpose(1, 2)
  142. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  143. self.config._attn_implementation, eager_attention_forward
  144. )
  145. if is_flash_attention_requested(self.config) and attention_similarity is not None:
  146. # Target guided masks are represented as float masks and are incompatible with Flash Attention
  147. # Fallback to SDPA for this call only so the rest of the model can still benefit from FA
  148. attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"]
  149. logger.warning_once(
  150. "Falling back to SDPA for target-guided attention because "
  151. "Flash Attention does not support additive bias masks."
  152. )
  153. attn_output, attn_weights = attention_interface(
  154. self,
  155. query,
  156. key,
  157. value,
  158. attention_mask=attention_similarity,
  159. dropout=0.0,
  160. scaling=self.scaling,
  161. is_causal=self.is_causal,
  162. **kwargs,
  163. )
  164. attn_output = attn_output.reshape(
  165. batch_size, point_batch_size, -1, self.num_attention_heads * self.head_dim
  166. ).contiguous()
  167. attn_output = self.o_proj(attn_output)
  168. return attn_output, attn_weights
class EdgeTamTwoWayAttentionBlock(GradientCheckpointingLayer):
    def __init__(self, config: EdgeTamMaskDecoderConfig, skip_first_layer_pe: bool = False):
        """
        A transformer block with four layers:
        (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
        sparse inputs (4) cross attention of dense inputs -> sparse inputs

        Arguments:
            config (`EdgeTamMaskDecoderConfig`):
                The configuration used to instantiate the block. Both cross-attention layers shrink their internal
                dimension by `config.attention_downsample_rate`; the self-attention layer does not downsample.
            skip_first_layer_pe (*optional*, bool, defaults to `False`):
                Whether or not to skip the addition of the query_point_embedding on the first layer.
        """
        super().__init__()
        # downsample_rate=1: self-attention keeps the full hidden size internally.
        self.self_attn = EdgeTamAttention(config, downsample_rate=1)
        self.layer_norm1 = nn.LayerNorm(config.hidden_size)
        self.cross_attn_token_to_image = EdgeTamAttention(config)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size)
        # NOTE(review): the MLP depth is taken from `config.num_hidden_layers` (the same value that sets the
        # number of two-way blocks); kept as-is to match existing checkpoints.
        self.mlp = EdgeTamFeedForward(
            config.hidden_size, config.mlp_dim, config.hidden_size, num_layers=config.num_hidden_layers
        )
        self.layer_norm3 = nn.LayerNorm(config.hidden_size)
        self.layer_norm4 = nn.LayerNorm(config.hidden_size)
        self.cross_attn_image_to_token = EdgeTamAttention(config)
        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self,
        queries: Tensor,
        keys: Tensor,
        query_point_embedding: Tensor,
        key_point_embedding: Tensor,
        attention_similarity: Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """
        Run one two-way block over sparse tokens (`queries`) and dense image tokens (`keys`).

        Returns:
            `(queries, keys, attn_out)`: updated sparse tokens, updated dense tokens, and the raw output of the
            final image-to-token cross-attention (before its residual add and `layer_norm4`).
        """
        # Self attention block
        if self.skip_first_layer_pe:
            # No positional embedding and no residual: the attention output replaces the queries.
            queries, _ = self.self_attn(query=queries, key=queries, value=queries)
        else:
            query = queries + query_point_embedding
            attn_out, _ = self.self_attn(query=query, key=query, value=queries)
            queries = queries + attn_out
        queries = self.layer_norm1(queries)
        # Cross attention block, tokens attending to image embedding
        query = queries + query_point_embedding
        key = keys + key_point_embedding
        attn_out, _ = self.cross_attn_token_to_image(
            query=query, key=key, value=keys, attention_similarity=attention_similarity
        )
        queries = queries + attn_out
        queries = self.layer_norm2(queries)
        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.layer_norm3(queries)
        # Cross attention block, image embedding attending to tokens
        query = queries + query_point_embedding
        key = keys + key_point_embedding
        # Roles swap here: image tokens (with PE) are the queries, sparse tokens the keys/values.
        attn_out, _ = self.cross_attn_image_to_token(query=key, key=query, value=queries)
        keys = keys + attn_out
        keys = self.layer_norm4(keys)
        return queries, keys, attn_out
  231. class EdgeTamFeedForward(nn.Module):
  232. def __init__(
  233. self,
  234. input_dim: int,
  235. hidden_dim: int,
  236. output_dim: int,
  237. num_layers: int,
  238. activation: str = "relu",
  239. sigmoid_output: bool = False,
  240. ):
  241. super().__init__()
  242. self.num_layers = num_layers
  243. self.activation = ACT2FN[activation]
  244. self.proj_in = nn.Linear(input_dim, hidden_dim)
  245. self.proj_out = nn.Linear(hidden_dim, output_dim)
  246. self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)])
  247. self.sigmoid_output = sigmoid_output
  248. def forward(self, hidden_states):
  249. hidden_states = self.proj_in(hidden_states)
  250. hidden_states = self.activation(hidden_states)
  251. for layer in self.layers:
  252. hidden_states = self.activation(layer(hidden_states))
  253. hidden_states = self.proj_out(hidden_states)
  254. if self.sigmoid_output:
  255. hidden_states = F.sigmoid(hidden_states)
  256. return hidden_states
  257. @auto_docstring
  258. class EdgeTamPreTrainedModel(PreTrainedModel):
  259. config_class = EdgeTamConfig
  260. base_model_prefix = "edgetam"
  261. main_input_name = "pixel_values"
  262. input_modalities = ("image",)
  263. _supports_sdpa = True
  264. _supports_flash_attn = True
  265. _supports_attention_backend = True
  266. _keys_to_ignore_on_load_unexpected = None
  267. @torch.no_grad()
  268. def _init_weights(self, module):
  269. super()._init_weights(module)
  270. if isinstance(module, EdgeTamModel):
  271. if module.no_memory_embedding is not None:
  272. init.zeros_(module.no_memory_embedding)
  273. elif hasattr(module, "positional_embedding"):
  274. init.normal_(module.positional_embedding, std=module.scale)
  275. # copied and adapted from original implementation, also practically equal to DetrSinePositionEmbedding
  276. class EdgeTamSinePositionEmbedding(nn.Module):
  277. """
  278. This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
  279. need paper, generalized to work on images.
  280. """
  281. def __init__(
  282. self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: float | None = None
  283. ):
  284. super().__init__()
  285. if scale is not None and normalize is False:
  286. raise ValueError("normalize should be True if scale is passed")
  287. self.num_pos_feats = num_pos_feats
  288. self.temperature = temperature
  289. self.normalize = normalize
  290. self.scale = 2 * math.pi if scale is None else scale
  291. @compile_compatible_method_lru_cache(maxsize=1)
  292. def forward(
  293. self,
  294. shape: torch.Size,
  295. device: torch.device | str,
  296. dtype: torch.dtype,
  297. mask: Tensor | None = None,
  298. ) -> Tensor:
  299. if mask is None:
  300. mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool)
  301. not_mask = (~mask).to(dtype)
  302. y_embed = not_mask.cumsum(1)
  303. x_embed = not_mask.cumsum(2)
  304. if self.normalize:
  305. eps = 1e-6
  306. y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
  307. x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
  308. dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=device).to(dtype)
  309. dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)
  310. pos_x = x_embed[:, :, :, None] / dim_t
  311. pos_y = y_embed[:, :, :, None] / dim_t
  312. pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
  313. pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
  314. pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
  315. return pos
  316. class EdgeTamVisionNeck(nn.Module):
  317. def __init__(self, config: EdgeTamVisionConfig):
  318. super().__init__()
  319. self.config = config
  320. self.position_encoding = EdgeTamSinePositionEmbedding(
  321. num_pos_feats=config.fpn_hidden_size // 2, normalize=True
  322. )
  323. self.convs = nn.ModuleList()
  324. for in_channels in config.backbone_channel_list:
  325. self.convs.append(
  326. nn.Conv2d(
  327. in_channels=in_channels,
  328. out_channels=config.fpn_hidden_size,
  329. kernel_size=config.fpn_kernel_size,
  330. stride=config.fpn_stride,
  331. padding=config.fpn_padding,
  332. ),
  333. )
  334. self.fpn_top_down_levels = config.fpn_top_down_levels
  335. def forward(self, hidden_states: torch.Tensor) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]:
  336. fpn_hidden_states = ()
  337. fpn_position_encoding = ()
  338. # forward in top-down order (from low to high resolution)
  339. n = len(self.convs) - 1
  340. for i in range(n, -1, -1):
  341. lateral_features = hidden_states[i].permute(0, 3, 1, 2)
  342. lateral_features = self.convs[n - i](lateral_features.to(self.convs[i].weight.dtype))
  343. if i not in self.fpn_top_down_levels or i == n:
  344. prev_features = lateral_features
  345. else:
  346. top_down_features = F.interpolate(
  347. prev_features.to(dtype=torch.float32),
  348. scale_factor=2.0,
  349. mode="nearest",
  350. align_corners=None,
  351. antialias=False,
  352. ).to(lateral_features.dtype)
  353. prev_features = lateral_features + top_down_features
  354. prev_position_encoding = self.position_encoding(
  355. prev_features.shape, prev_features.device, prev_features.dtype
  356. ).to(prev_features.dtype)
  357. fpn_hidden_states += (prev_features,)
  358. fpn_position_encoding += (prev_position_encoding,)
  359. return fpn_hidden_states, fpn_position_encoding
  360. @auto_docstring(
  361. custom_intro="""
  362. The vision model from EdgeTAM without any head or projection on top.
  363. """
  364. )
  365. class EdgeTamVisionModel(EdgeTamPreTrainedModel):
  366. config_class = EdgeTamVisionConfig
  367. main_input_name = "pixel_values"
  368. # TODO: TimmWrapper models aren't compatible with _can_record_outputs yet. We specifically set this to
  369. # an empty dict to avoid the _can_record_outputs from Sam2VisionModel being inherited here.
  370. _can_record_outputs = {}
  371. def __init__(self, config: EdgeTamVisionConfig):
  372. super().__init__(config)
  373. self.config = config
  374. self.backbone = AutoModel.from_config(config.backbone_config)
  375. self.neck = EdgeTamVisionNeck(config)
  376. self.num_feature_levels = config.num_feature_levels
  377. self.post_init()
  378. @merge_with_config_defaults
  379. @capture_outputs
  380. def forward(
  381. self,
  382. pixel_values: torch.FloatTensor | None = None,
  383. **kwargs: Unpack[TransformersKwargs],
  384. ) -> tuple | EdgeTamVisionEncoderOutput:
  385. if pixel_values is None:
  386. raise ValueError("You have to specify pixel_values")
  387. # Forward through backbone
  388. backbone_output = self.backbone(pixel_values, **kwargs)
  389. intermediate_hidden_states = backbone_output.last_hidden_state
  390. intermediate_hidden_states = [hidden_state.permute(0, 2, 3, 1) for hidden_state in intermediate_hidden_states]
  391. fpn_hidden_states, fpn_position_encoding = self.neck(intermediate_hidden_states)
  392. # Select last `num_feature_levels` feature levels from FPN and reverse order to get features from high to low resolution
  393. fpn_hidden_states = fpn_hidden_states[-self.num_feature_levels :][::-1]
  394. fpn_position_encoding = fpn_position_encoding[-self.num_feature_levels :][::-1]
  395. return EdgeTamVisionEncoderOutput(
  396. last_hidden_state=intermediate_hidden_states[-1],
  397. fpn_hidden_states=fpn_hidden_states,
  398. fpn_position_encoding=fpn_position_encoding,
  399. hidden_states=backbone_output.hidden_states,
  400. )
@dataclass
@auto_docstring(custom_intro="Base class for the EdgeTam model's output.")
class EdgeTamImageSegmentationOutput(ModelOutput):
    r"""
    iou_scores (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks)`):
        The Intersection over Union (IoU) scores of the predicted masks.
    pred_masks (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks, height, width)`):
        The predicted low-resolution masks. This is an alias for `low_res_masks`. These masks need to be post-processed
        by the processor to be brought to the original image size.
    object_score_logits (`torch.FloatTensor` of shape `(batch_size, point_batch_size, 1)`):
        Logits for the object score, indicating if an object is present.
    image_embeddings (`tuple(torch.FloatTensor)`):
        The features from the FPN, which are used by the mask decoder. This is a tuple of `torch.FloatTensor` where each
        tensor has shape `(batch_size, channels, height, width)`.
    vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`.
        Hidden-states of the vision model at the output of each stage.
    vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
        Attentions weights of the vision model.
    mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
        Attentions weights of the mask decoder.
    """

    # NOTE(review): the docstring calls `pred_masks` an alias for `low_res_masks`, but this class has no such
    # field — confirm whether the alias lives elsewhere.
    iou_scores: torch.FloatTensor | None = None
    pred_masks: torch.FloatTensor | None = None
    object_score_logits: torch.FloatTensor | None = None
    # Defaults to None like every other field, so the annotation admits None.
    image_embeddings: tuple[torch.FloatTensor, ...] | None = None
    vision_hidden_states: tuple[torch.FloatTensor, ...] | None = None
    vision_attentions: tuple[torch.FloatTensor, ...] | None = None
    mask_decoder_attentions: tuple[torch.FloatTensor, ...] | None = None
  432. class EdgeTamPositionalEmbedding(nn.Module):
  433. def __init__(self, config: EdgeTamPromptEncoderConfig):
  434. super().__init__()
  435. self.scale = config.scale
  436. positional_embedding = self.scale * torch.randn((2, config.hidden_size // 2))
  437. self.register_buffer("positional_embedding", positional_embedding)
  438. def forward(self, input_coords, input_shape=None):
  439. """Positionally encode points that are normalized to [0,1]."""
  440. coordinates = input_coords.clone()
  441. if input_shape is not None:
  442. coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
  443. coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
  444. coordinates.to(torch.float32)
  445. # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
  446. coordinates = 2 * coordinates - 1
  447. coordinates = coordinates.to(self.positional_embedding.dtype)
  448. coordinates = coordinates @ self.positional_embedding
  449. coordinates = 2 * np.pi * coordinates
  450. # outputs d_1 x ... x d_n x channel shape
  451. return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
  452. class EdgeTamMaskEmbedding(nn.Module):
  453. def __init__(self, config: EdgeTamPromptEncoderConfig):
  454. super().__init__()
  455. self.mask_input_channels = config.mask_input_channels // 4
  456. self.activation = ACT2FN[config.hidden_act]
  457. self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2)
  458. self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2)
  459. self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1)
  460. self.layer_norm1 = EdgeTamLayerNorm(
  461. self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first"
  462. )
  463. self.layer_norm2 = EdgeTamLayerNorm(
  464. self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first"
  465. )
  466. def forward(self, masks):
  467. hidden_states = self.conv1(masks)
  468. hidden_states = self.layer_norm1(hidden_states)
  469. hidden_states = self.activation(hidden_states)
  470. hidden_states = self.conv2(hidden_states)
  471. hidden_states = self.layer_norm2(hidden_states)
  472. hidden_states = self.activation(hidden_states)
  473. dense_embeddings = self.conv3(hidden_states)
  474. return dense_embeddings
class EdgeTamPromptEncoder(nn.Module):
    """
    Encodes point, box and mask prompts into sparse (points/boxes) and dense (mask) embeddings for the mask decoder.
    """

    def __init__(self, config: EdgeTamPromptEncoderConfig):
        super().__init__()
        self.shared_embedding = EdgeTamPositionalEmbedding(config)
        self.mask_embed = EdgeTamMaskEmbedding(config)
        # Learned dense embedding broadcast over the image grid when no mask prompt is given.
        self.no_mask_embed = nn.Embedding(1, config.hidden_size)
        # Spatial size of the image-embedding grid (image_size // patch_size per side).
        self.image_embedding_size = (config.image_size // config.patch_size, config.image_size // config.patch_size)
        # NOTE(review): not read in this class; presumably the expected mask-prompt resolution — confirm.
        self.mask_input_size = (4 * config.image_size // config.patch_size, 4 * config.image_size // config.patch_size)
        self.input_image_size = config.image_size
        # Indexed by point label (labels >= 0); rows 2 and 3 are also used for the two box corners.
        self.point_embed = nn.Embedding(config.num_point_embeddings, config.hidden_size)
        self.hidden_size = config.hidden_size
        # Embedding for padding entries (label -1) and the padding slot appended to boxes.
        self.not_a_point_embed = nn.Embedding(1, config.hidden_size)

    def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
        """Embeds point prompts.

        Args:
            points: point pixel coordinates; the last axis holds the two coordinates.
            labels: one label per point. Label -1 yields `not_a_point_embed`, label -10 yields zeros, and labels
                >= 0 add the matching `point_embed` row on top of the positional encoding.
            pad: if True, one extra padding point (coordinates 0, label -1) is appended to every prompt.
        """
        points = points + 0.5 # Shift to center of pixel
        if pad:
            points = torch.nn.functional.pad(points, (0, 0, 0, 1), mode="constant", value=0)
            labels = torch.nn.functional.pad(labels, (0, 1), mode="constant", value=-1)
        input_shape = (self.input_image_size, self.input_image_size)
        point_embedding = self.shared_embedding(points, input_shape)
        # torch.where and expanding the labels tensor is required by the ONNX export
        point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding)
        # This is required for the ONNX export. The dtype, device need to be explicitly
        # specified as otherwise torch.onnx.export interprets as double
        # (Points labelled -10 are zeroed out here.)
        point_embedding = torch.where(
            labels[..., None] != -10,
            point_embedding,
            torch.zeros_like(point_embedding),
        )
        # Add point embeddings for labels >= 0
        point_embedding = point_embedding + self.point_embed(labels.clamp(min=0)) * (labels >= 0).unsqueeze(-1)
        return point_embedding

    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """Embeds box prompts.

        Each box's 4 values are viewed as two corner points; a third padding point is appended. Corner 0 gets
        `point_embed.weight[2]`, corner 1 gets `point_embed.weight[3]`, and the padding slot is overwritten with
        `not_a_point_embed`.
        """
        boxes = boxes + 0.5 # Shift to center of pixel
        coords = boxes.view(*boxes.shape[:2], 2, 2)
        # add padding point for consistency with the original implementation
        coords = torch.nn.functional.pad(coords, (0, 0, 0, 1), mode="constant", value=0)
        corner_embedding = self.shared_embedding(coords, (self.input_image_size, self.input_image_size))
        corner_embedding[:, :, 0, :] += self.point_embed.weight[2]
        corner_embedding[:, :, 1, :] += self.point_embed.weight[3]
        corner_embedding[:, :, 2, :] = self.not_a_point_embed.weight.expand_as(corner_embedding[:, :, 2, :])
        return corner_embedding

    def forward(
        self,
        input_points: torch.Tensor | None,
        input_labels: torch.Tensor | None,
        input_boxes: torch.Tensor | None,
        input_masks: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Embeds different types of prompts, returning both sparse and dense embeddings.

        Args:
            input_points (`torch.Tensor`, *optional*):
                Point coordinates to embed. Requires `input_labels`. A padding point is appended only when no
                boxes are given.
            input_labels (`torch.Tensor`, *optional*):
                Labels for `input_points`.
            input_boxes (`torch.Tensor`, *optional*):
                Boxes to embed; concatenated after the point embeddings along dim 2.
            input_masks (`torch.Tensor`, *optional*):
                Masks to embed. When omitted, `no_mask_embed` is broadcast over the image-embedding grid.

        Returns:
            `(sparse_embeddings, dense_embeddings)`. `sparse_embeddings` is `None` when neither points nor boxes
            are given.

        Raises:
            ValueError: if `input_points` is given without `input_labels`.
        """
        sparse_embeddings = None
        # Batch size for the broadcast no-mask embedding; boxes take precedence over points if both are given.
        batch_size = 1
        if input_points is not None:
            batch_size = input_points.shape[0]
            if input_labels is None:
                raise ValueError("If points are provided, labels must also be provided.")
            point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
            sparse_embeddings = point_embeddings
        if input_boxes is not None:
            batch_size = input_boxes.shape[0]
            box_embeddings = self._embed_boxes(input_boxes)
            if sparse_embeddings is None:
                sparse_embeddings = box_embeddings
            else:
                sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=2)
        if input_masks is not None:
            dense_embeddings = self.mask_embed(input_masks)
        else:
            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
                batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]
            )
        return sparse_embeddings, dense_embeddings
class EdgeTamTwoWayTransformer(nn.Module):
    """
    Two-way transformer used by the mask decoder.

    Runs `config.num_hidden_layers` `EdgeTamTwoWayAttentionBlock` layers that update both the token queries and
    the image keys. It then applies a final token-to-image attention with a residual connection, followed by a
    LayerNorm on the queries.
    """

    def __init__(self, config: EdgeTamMaskDecoderConfig):
        super().__init__()
        self.config = config
        self.num_hidden_layers = config.num_hidden_layers
        self.layers = nn.ModuleList()
        for i in range(self.num_hidden_layers):
            # Only the first block is built with `skip_first_layer_pe=True`.
            self.layers.append(EdgeTamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))
        self.final_attn_token_to_image = EdgeTamAttention(config)
        self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)

    def forward(
        self,
        point_embeddings: Tensor,
        image_embeddings: Tensor,
        image_positional_embeddings: Tensor,
        attention_similarity: Tensor | None,
        target_embedding: Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[Tensor, Tensor]:
        """
        Attend between the token embeddings and the image embeddings.

        Args:
            point_embeddings (`Tensor`):
                Token embeddings (output tokens plus sparse prompt embeddings). They are used both as the
                initial queries and as the query positional embedding of every layer.
            image_embeddings (`Tensor`):
                Image feature map. Dims from 2 onward are flattened into one token axis, and a singleton axis is
                inserted at dim 1. This presumably expects `(batch, channels, height, width)`, as the mask decoder
                passes — TODO confirm for other callers.
            image_positional_embeddings (`Tensor`):
                Positional encoding reshaped the same way as `image_embeddings`.
            attention_similarity (`Tensor`, *optional*):
                Forwarded unchanged to every attention block.
            target_embedding (`Tensor`, *optional*):
                If given, added to the queries before every attention block.

        Returns:
            `tuple[Tensor, Tensor]`: the updated `(queries, keys)`. The queries have had the final attention and
            LayerNorm applied.

        Raises:
            ValueError: if `image_embeddings` is `None`.
        """
        if image_embeddings is None:
            raise ValueError("You have to specify an image_embedding")
        # (N, C, H, W) -> (N, 1, H*W, C), assuming a 4-D input.
        image_embeddings = image_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
        image_positional_embeddings = image_positional_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
        # Prepare queries
        queries = point_embeddings
        keys = image_embeddings
        # Apply transformer blocks and final layernorm
        for layer in self.layers:
            if target_embedding is not None:
                # NOTE(review): on the first iteration `queries` IS `point_embeddings` (same tensor object).
                # This in-place add therefore also adds `target_embedding` to `point_embeddings`, which is
                # reused below as the query positional embedding. That may be intentional (PerSAM-style
                # target prompting); confirm before switching to an out-of-place add.
                queries += target_embedding
            queries, keys, _ = layer(
                queries=queries,
                keys=keys,
                query_point_embedding=point_embeddings,
                key_point_embedding=image_positional_embeddings,
                attention_similarity=attention_similarity,
                **kwargs,
            )
        # Apply the final attention layer from the points to the image
        query = queries + point_embeddings
        key = keys + image_positional_embeddings
        # Positional embeddings are added to query/key only; the values are the raw image keys.
        attn_out, _ = self.final_attn_token_to_image(query=query, key=key, value=keys)
        queries = queries + attn_out
        queries = self.layer_norm_final_attn(queries)
        return queries, keys
class EdgeTamMaskDecoder(nn.Module):
    """
    Predict mask logits, IoU scores, mask tokens and object-score logits from image and prompt embeddings.

    Output tokens are laid out as `[obj_score_token, iou_token, mask_tokens...]`:
    - index 0 feeds the object-score head;
    - index 1 feeds the IoU head;
    - indices `2 .. 1 + num_mask_tokens` feed one hypernetwork MLP each.

    Mask token 0 is the single-mask output, and mask tokens `1..` are the multimask outputs.
    """

    def __init__(self, config: EdgeTamMaskDecoderConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_multimask_outputs = config.num_multimask_outputs
        # One single-mask token plus `num_multimask_outputs` multimask tokens.
        self.num_mask_tokens = config.num_multimask_outputs + 1
        self.iou_token = nn.Embedding(1, self.hidden_size)
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size)
        self.transformer = EdgeTamTwoWayTransformer(config)
        # Upscaling path: two kernel-2/stride-2 transposed convolutions (2x spatial upsampling each), going
        # from hidden_size to hidden_size // 4 and then hidden_size // 8 channels.
        # should we create a new class for this?
        self.upscale_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2)
        self.upscale_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2)
        self.upscale_layer_norm = EdgeTamLayerNorm(self.hidden_size // 4, data_format="channels_first")
        self.activation = nn.GELU()
        # One 3-layer MLP per mask token, mapping the token to hidden_size // 8 channels (the channel count of
        # the upscaled embedding it is multiplied with).
        mlps_list = []
        for _ in range(self.num_mask_tokens):
            mlps_list += [EdgeTamFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)]
        self.output_hypernetworks_mlps = nn.ModuleList(mlps_list)
        self.iou_prediction_head = EdgeTamFeedForward(
            self.hidden_size,
            config.iou_head_hidden_dim,
            self.num_mask_tokens,
            config.iou_head_depth,
            sigmoid_output=True,
        )
        # 1x1 projections of the two high-resolution backbone levels. EdgeTamModel.get_image_features applies
        # these ahead of time so the outputs match the upscaling path (hidden_size // 8 and hidden_size // 4).
        self.conv_s0 = nn.Conv2d(config.hidden_size, config.hidden_size // 8, kernel_size=1, stride=1)
        self.conv_s1 = nn.Conv2d(config.hidden_size, config.hidden_size // 4, kernel_size=1, stride=1)
        self.obj_score_token = nn.Embedding(1, self.hidden_size)
        self.pred_obj_score_head = EdgeTamFeedForward(self.hidden_size, self.hidden_size, 1, 3)
        self.dynamic_multimask_via_stability = config.dynamic_multimask_via_stability
        self.dynamic_multimask_stability_delta = config.dynamic_multimask_stability_delta
        self.dynamic_multimask_stability_thresh = config.dynamic_multimask_stability_thresh

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_positional_embeddings: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
        high_resolution_features: list[torch.Tensor],
        attention_similarity: torch.Tensor | None = None,
        target_embedding: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.

        Args:
            image_embeddings (`torch.Tensor` of shape `(batch_size, channels, height, width)`):
                The embeddings from the image encoder.
            image_positional_embeddings (`torch.Tensor`):
                Positional encoding with the shape of image_embeddings.
            sparse_prompt_embeddings (`torch.Tensor` of shape `(batch_size, point_batch_size, num_prompt_tokens, channels)`):
                The embeddings of the points and boxes. A tensor whose first dimension is 0 means "no sparse
                prompt".
            dense_prompt_embeddings (`torch.Tensor`):
                The embeddings of the mask inputs. They are added to `image_embeddings`.
            multimask_output (`bool`):
                Whether to return multiple masks or a single mask.
            high_resolution_features (`list[torch.Tensor]`):
                The two high-resolution features `[feat_s0, feat_s1]` from the vision encoder, already
                projected by `conv_s0` / `conv_s1`. Required.
            attention_similarity (`torch.Tensor`, *optional*):
                The attention similarity tensor.
            target_embedding (`torch.Tensor`, *optional*):
                The target embedding.

        Returns:
            A tuple `(masks, iou_pred, sam_tokens_out, object_score_logits)`:
            - `masks`: mask logits of shape `(batch_size, point_batch_size, k, 4 * height, 4 * width)`;
            - `iou_pred`: IoU predictions with `k` entries per mask;
            - `sam_tokens_out`: the corresponding mask tokens;
            - `object_score_logits`: from the object-score token.

            `k = num_multimask_outputs` when `multimask_output` is set, else 1.
        """
        batch_size, num_channels, height, width = image_embeddings.shape
        point_batch_size = sparse_prompt_embeddings.shape[1]
        # Concatenate output tokens
        output_tokens = torch.cat(
            [
                self.obj_score_token.weight,
                self.iou_token.weight,
                self.mask_tokens.weight,
            ],
            dim=0,
        )
        # (num_output_tokens, C) -> (batch_size, point_batch_size, num_output_tokens, C)
        output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
        if sparse_prompt_embeddings.shape[0] != 0:
            tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
        else:
            tokens = output_tokens
        point_embeddings = tokens.to(self.iou_token.weight.dtype)
        # Expand per-image data in batch direction to be per-mask
        image_embeddings = image_embeddings + dense_prompt_embeddings
        image_embeddings = image_embeddings.repeat_interleave(point_batch_size, dim=0)
        image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
        # Run the transformer
        point_embeddings, image_embeddings = self.transformer(
            point_embeddings=point_embeddings,
            image_embeddings=image_embeddings,
            image_positional_embeddings=image_positional_embeddings,
            attention_similarity=attention_similarity,
            target_embedding=target_embedding,
            **kwargs,
        )
        # Token indices follow the output-token layout built above.
        iou_token_out = point_embeddings[:, :, 1, :]
        mask_tokens_out = point_embeddings[:, :, 2 : (2 + self.num_mask_tokens), :]
        # Upscale mask embeddings and predict masks using the mask tokens
        # Restore the transformer's image tokens to a (batch_size * point_batch_size, C, H, W) feature map.
        image_embeddings = image_embeddings.transpose(2, 3).view(
            batch_size * point_batch_size, num_channels, height, width
        )
        feat_s0, feat_s1 = high_resolution_features
        feat_s0 = feat_s0.repeat_interleave(point_batch_size, dim=0)
        feat_s1 = feat_s1.repeat_interleave(point_batch_size, dim=0)
        # Each upscaling step adds the matching high-resolution skip feature.
        upscaled_embedding = self.upscale_conv1(image_embeddings) + feat_s1
        upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
        upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding) + feat_s0)
        hyper_in_list: list[torch.Tensor] = []
        for i in range(self.num_mask_tokens):
            current_mlp = self.output_hypernetworks_mlps[i]
            hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
        # (batch_size, point_batch_size, num_mask_tokens, hidden_size // 8)
        hyper_in = torch.stack(hyper_in_list, dim=2)
        _, num_channels, height, width = upscaled_embedding.shape
        upscaled_embedding = upscaled_embedding.view(batch_size, point_batch_size, num_channels, height * width)
        # Per-token dot product with every upscaled pixel gives one mask per mask token.
        masks = (hyper_in @ upscaled_embedding).view(batch_size, point_batch_size, -1, height, width)
        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)
        object_score_logits = self.pred_obj_score_head(point_embeddings[:, :, 0, :])
        # Select the correct mask or masks for output
        if multimask_output:
            mask_slice = slice(1, None)
            masks = masks[:, :, mask_slice, :, :]
            iou_pred = iou_pred[:, :, mask_slice]
        elif self.dynamic_multimask_via_stability and not self.training:
            mask_slice = slice(0, 1)
            masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
        else:
            mask_slice = slice(0, 1)
            masks = masks[:, :, mask_slice, :, :]
            iou_pred = iou_pred[:, :, mask_slice]
        # Shape [b, p, k, c]: k = num_mask_tokens - 1 with multimask_output, else 1. In the non-multimask
        # branches this is always the single-mask token 0, even when the stability fallback returned a mask
        # produced by a multimask token.
        sam_tokens_out = mask_tokens_out[:, :, mask_slice]
        return masks, iou_pred, sam_tokens_out, object_score_logits

    def _get_stability_scores(self, mask_logits: torch.Tensor) -> torch.Tensor:
        """
        Compute stability scores of the mask logits based on the IoU between upper and
        lower thresholds.

        The score is `count(logits > delta) / count(logits > -delta)` over the flattened last two dims. It is
        1.0 where the lower-threshold area is empty.
        """
        mask_logits = mask_logits.flatten(-2)
        stability_delta = self.dynamic_multimask_stability_delta
        area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
        area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
        stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
        return stability_scores

    def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
        """
        When outputting a single mask, check the stability score of the single-mask output (output token 0).
        If it falls below `dynamic_multimask_stability_thresh`, select instead the multimask output (tokens
        `1..num_mask_tokens - 1`) with the highest predicted IoU score. This is intended to ensure a valid mask
        for both clicking and tracking.

        Returns:
            `(mask_logits, iou_scores)` with a single mask per prompt, shapes `[B, P, 1, H, W]` and `[B, P, 1]`.
        """
        # The best mask from multimask output tokens (1..num_mask_tokens - 1)
        multimask_logits = all_mask_logits[:, :, 1:, :, :]
        multimask_iou_scores = all_iou_scores[:, :, 1:]
        best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)  # [B, P]
        best_scores_inds_expanded = best_scores_inds.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        best_scores_inds_expanded = best_scores_inds_expanded.expand(
            -1, -1, 1, multimask_logits.size(-2), multimask_logits.size(-1)
        )
        best_multimask_logits = torch.gather(multimask_logits, 2, best_scores_inds_expanded)  # [B, P, 1, H, W]
        best_multimask_iou_scores = torch.gather(multimask_iou_scores, 2, best_scores_inds.unsqueeze(-1))  # [B, P, 1]
        # The mask from singlemask output token 0 and its stability score
        singlemask_logits = all_mask_logits[:, :, 0:1, :, :]
        singlemask_iou_scores = all_iou_scores[:, :, 0:1]
        stability_scores = self._get_stability_scores(singlemask_logits)  # [B, P, 1]
        is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
        # Dynamically fall back to best multimask output upon low stability scores.
        mask_logits_out = torch.where(
            is_stable[..., None, None].expand_as(singlemask_logits),
            singlemask_logits,
            best_multimask_logits,
        )
        iou_scores_out = torch.where(
            is_stable.expand_as(singlemask_iou_scores),
            singlemask_iou_scores,
            best_multimask_iou_scores,
        )
        return mask_logits_out, iou_scores_out
  779. @auto_docstring(
  780. custom_intro="""
  781. Segment Anything Model 2 (SAM 2) for generating segmentation masks, given an input image and
  782. input points and labels, boxes, or masks.
  783. """
  784. )
  785. class EdgeTamModel(EdgeTamPreTrainedModel):
  786. input_modalities = ("image", "text")
  787. _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(EdgeTamTwoWayAttentionBlock, index=2)}
  788. _tied_weights_keys = {}
  789. _keys_to_ignore_on_load_unexpected = [
  790. r"^memory_.*",
  791. r"^mask_downsample.*",
  792. r"spatial_perceiver.*",
  793. r"^object_pointer_proj.*",
  794. r"^temporal_positional_encoding_projection_layer.*",
  795. "no_memory_positional_encoding",
  796. "no_object_pointer",
  797. "occlusion_spatial_embedding_parameter",
  798. ]
  799. def __init__(self, config: EdgeTamConfig):
  800. super().__init__(config)
  801. self.shared_image_embedding = EdgeTamPositionalEmbedding(config.prompt_encoder_config)
  802. self.vision_encoder = AutoModel.from_config(config.vision_config)
  803. self.prompt_encoder = EdgeTamPromptEncoder(config.prompt_encoder_config)
  804. # The module using it is not a PreTrainedModel subclass so we need this
  805. config.mask_decoder_config._attn_implementation = config._attn_implementation
  806. self.mask_decoder = EdgeTamMaskDecoder(config.mask_decoder_config)
  807. self.num_feature_levels = config.vision_config.num_feature_levels
  808. self.backbone_feature_sizes = config.vision_config.backbone_feature_sizes
  809. # a single token to indicate no memory embedding from previous frames
  810. self.hidden_dim = config.vision_config.fpn_hidden_size
  811. self.no_memory_embedding = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
  812. self.post_init()
  813. def get_image_wide_positional_embeddings(self) -> torch.Tensor:
  814. size = self.prompt_encoder.image_embedding_size
  815. target_device = self.shared_image_embedding.positional_embedding.device
  816. target_dtype = self.shared_image_embedding.positional_embedding.dtype
  817. grid = torch.ones(size, device=target_device, dtype=target_dtype)
  818. y_embed = grid.cumsum(dim=0) - 0.5
  819. x_embed = grid.cumsum(dim=1) - 0.5
  820. y_embed = y_embed / size[0]
  821. x_embed = x_embed / size[1]
  822. positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1))
  823. return positional_embedding.permute(2, 0, 1).unsqueeze(0) # channel x height x width
  824. @torch.no_grad()
  825. def get_image_embeddings(
  826. self,
  827. pixel_values: torch.FloatTensor,
  828. **kwargs: Unpack[TransformersKwargs],
  829. ) -> list[torch.Tensor]:
  830. r"""
  831. Returns the image embeddings by passing the pixel values through the vision encoder.
  832. Args:
  833. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  834. Input pixel values
  835. """
  836. batch_size = pixel_values.shape[0]
  837. image_outputs = self.get_image_features(pixel_values, return_dict=True, **kwargs)
  838. feature_maps = image_outputs.fpn_hidden_states
  839. # add no memory embedding to the last feature map
  840. feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding
  841. # reshape feature maps to the same shape as the backbone feature sizes
  842. image_embeddings = [
  843. feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
  844. for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes)
  845. ]
  846. return image_embeddings
  847. @torch.no_grad()
  848. def get_prompt_embeddings(
  849. self,
  850. input_points: torch.FloatTensor | None = None,
  851. input_labels: torch.LongTensor | None = None,
  852. input_boxes: torch.FloatTensor | None = None,
  853. input_masks: torch.LongTensor | None = None,
  854. ):
  855. r"""
  856. Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.
  857. Args:
  858. input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
  859. Optional input points for the prompt encoder. The padding of the point is automatically done by the
  860. processor. `point_batch_size` refers to the number of masks that we want the model to predict per
  861. point. The model will output `point_batch_size` times 3 masks in total.
  862. input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
  863. Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
  864. processor, or can be fed by the user.
  865. input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`):
  866. Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
  867. processor. users can also pass manually the input boxes.
  868. input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`):
  869. Optional input masks for the prompt encoder.
  870. """
  871. prompt_output = self.prompt_encoder(
  872. input_points=input_points,
  873. input_labels=input_labels,
  874. input_boxes=input_boxes,
  875. input_masks=input_masks,
  876. )
  877. return prompt_output
  878. @merge_with_config_defaults
  879. @capture_outputs
  880. @auto_docstring
  881. def forward(
  882. self,
  883. pixel_values: torch.FloatTensor | None = None,
  884. input_points: torch.FloatTensor | None = None,
  885. input_labels: torch.LongTensor | None = None,
  886. input_boxes: torch.FloatTensor | None = None,
  887. input_masks: torch.LongTensor | None = None,
  888. image_embeddings: torch.FloatTensor | None = None,
  889. multimask_output: bool = True,
  890. attention_similarity: torch.FloatTensor | None = None,
  891. target_embedding: torch.FloatTensor | None = None,
  892. **kwargs: Unpack[TransformersKwargs],
  893. ) -> EdgeTamImageSegmentationOutput:
  894. r"""
  895. input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`):
  896. Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields to much
  897. better results. The points can be obtained by passing a list of list of list to the processor that will
  898. create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the
  899. second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict
  900. per input point), the third dimension is the number of points per segmentation mask (it is possible to pass
  901. multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal)
  902. coordinates of the point. If a different number of points is passed either for each image, or for each
  903. mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
  904. computation of the embedding will be skipped for these points using the labels.
  905. input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`):
  906. Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the
  907. official implementation, there are 3 types of labels
  908. - `1`: the point is a point that contains the object of interest
  909. - `0`: the point is a point that does not contain the object of interest
  910. - `-1`: the point corresponds to the background
  911. We added the label:
  912. - `-10`: the point is a padding point, thus should be ignored by the prompt encoder
  913. The padding labels should be automatically done by the processor.
  914. input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
  915. Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields to
  916. much better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
  917. that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch
  918. size, the number of boxes per image and the coordinates of the top left and bottom right point of the box.
  919. In the order (`x1`, `y1`, `x2`, `y2`):
  920. - `x1`: the x coordinate of the top left point of the input box
  921. - `y1`: the y coordinate of the top left point of the input box
  922. - `x2`: the x coordinate of the bottom right point of the input box
  923. - `y2`: the y coordinate of the bottom right point of the input box
  924. input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`):
  925. SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
  926. generate a corresponding embedding, that will be fed later on to the mask decoder. These masks needs to be
  927. manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
  928. image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`):
  929. Image embeddings, this is used by the mask decoder to generate masks and iou scores. For more memory
  930. efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
  931. method, and then feed them to the `forward` method instead of feeding the `pixel_values`.
  932. multimask_output (`bool`, *optional*):
  933. In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
  934. bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
  935. "best" mask, by specifying `multimask_output=False`.
  936. attention_similarity (`torch.FloatTensor`, *optional*):
  937. Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
  938. model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
  939. target_embedding (`torch.FloatTensor`, *optional*):
  940. Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
  941. the model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
  942. Example:
  943. ```python
  944. >>> from PIL import Image
  945. >>> import httpx
  946. >>> from io import BytesIO
  947. >>> from transformers import AutoModel, AutoProcessor
  948. >>> model = AutoModel.from_pretrained("danelcsb/edgetam.1_hiera_tiny")
  949. >>> processor = AutoProcessor.from_pretrained("danelcsb/edgetam.1_hiera_tiny")
  950. >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
  951. >>> with httpx.stream("GET", url) as response:
  952. ... raw_image = Image.open(BytesIO(response.read())).convert("RGB")
  953. >>> input_points = [[[400, 650]]] # 2D location of a window on the car
  954. >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")
  955. >>> # Get segmentation mask
  956. >>> outputs = model(**inputs)
  957. >>> # Postprocess masks
  958. >>> masks = processor.post_process_masks(
  959. ... outputs.pred_masks, inputs["original_sizes"]
  960. ... )
  961. ```
  962. """
  963. if not ((pixel_values is None) ^ (image_embeddings is None)):
  964. raise ValueError("Exactly one of pixel_values or image_embeddings must be provided.")
  965. if input_points is not None and input_boxes is not None:
  966. if input_points.shape[1] != input_boxes.shape[1]:
  967. raise ValueError(
  968. f"You should provide as many bounding boxes as input points per box. Got {input_points.shape[1]} and {input_boxes.shape[1]}."
  969. )
  970. image_positional_embeddings = self.get_image_wide_positional_embeddings()
  971. # repeat with batch size
  972. batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings[-1].shape[0]
  973. image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
  974. vision_attentions = None
  975. vision_hidden_states = None
  976. if pixel_values is not None:
  977. image_outputs: EdgeTamVisionEncoderOutput = self.get_image_features(
  978. pixel_values, return_dict=True, **kwargs
  979. )
  980. feature_maps = image_outputs.fpn_hidden_states
  981. vision_hidden_states = image_outputs.hidden_states
  982. vision_attentions = image_outputs.attentions
  983. # add no memory embedding to the last feature map
  984. feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding
  985. # reshape feature maps to the same shape as the backbone feature sizes
  986. image_embeddings = [
  987. feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
  988. for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes)
  989. ]
  990. if input_points is not None and input_labels is None:
  991. input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)
  992. if input_points is None and input_boxes is None:
  993. # If no points are provide, pad with an empty point (with label -1)
  994. input_points = torch.zeros(
  995. batch_size, 1, 1, 2, dtype=image_embeddings[-1].dtype, device=image_embeddings[-1].device
  996. )
  997. input_labels = -torch.ones(batch_size, 1, 1, dtype=torch.int32, device=image_embeddings[-1].device)
  998. if input_masks is not None:
  999. # If mask_inputs is provided, downsize it into low-res mask input if needed
  1000. # and feed it as a dense mask prompt into the SAM mask encoder
  1001. if input_masks.shape[-2:] != self.prompt_encoder.mask_input_size:
  1002. input_masks = F.interpolate(
  1003. input_masks.float(),
  1004. size=self.prompt_encoder.mask_input_size,
  1005. align_corners=False,
  1006. mode="bilinear",
  1007. antialias=True, # use antialias for downsampling
  1008. ).to(input_masks.dtype)
  1009. sparse_embeddings, dense_embeddings = self.prompt_encoder(
  1010. input_points=input_points,
  1011. input_labels=input_labels,
  1012. input_boxes=input_boxes,
  1013. input_masks=input_masks,
  1014. )
  1015. low_res_multimasks, iou_scores, _, object_score_logits = self.mask_decoder(
  1016. image_embeddings=image_embeddings[-1],
  1017. image_positional_embeddings=image_positional_embeddings,
  1018. sparse_prompt_embeddings=sparse_embeddings,
  1019. dense_prompt_embeddings=dense_embeddings,
  1020. multimask_output=multimask_output,
  1021. high_resolution_features=image_embeddings[:-1],
  1022. attention_similarity=attention_similarity,
  1023. target_embedding=target_embedding,
  1024. **kwargs,
  1025. )
  1026. return EdgeTamImageSegmentationOutput(
  1027. iou_scores=iou_scores,
  1028. pred_masks=low_res_multimasks,
  1029. object_score_logits=object_score_logits,
  1030. image_embeddings=image_embeddings,
  1031. vision_hidden_states=vision_hidden_states,
  1032. vision_attentions=vision_attentions,
  1033. )
  1034. @can_return_tuple
  1035. @auto_docstring
  1036. def get_image_features(
  1037. self,
  1038. pixel_values: torch.FloatTensor,
  1039. **kwargs: Unpack[TransformersKwargs],
  1040. ) -> tuple | EdgeTamVisionEncoderOutput:
  1041. r"""
  1042. pixel_values (`torch.FloatTensor`):
  1043. Input pixel values of shape `(batch_size, num_channels, height, width)`.
  1044. """
  1045. vision_outputs: EdgeTamVisionEncoderOutput = self.vision_encoder(pixel_values, return_dict=True, **kwargs)
  1046. feature_maps = vision_outputs.fpn_hidden_states
  1047. feature_maps_position_embeddings = vision_outputs.fpn_position_encoding
  1048. # precompute projected level 0 and level 1 features in SAM decoder
  1049. # to avoid running it again on every SAM click
  1050. feature_maps = list(feature_maps)
  1051. feature_maps[0] = self.mask_decoder.conv_s0(feature_maps[0])
  1052. feature_maps[1] = self.mask_decoder.conv_s1(feature_maps[1])
  1053. # flatten NxCxHxW to HWxNxC
  1054. feature_maps = [feature_map.flatten(2).permute(2, 0, 1) for feature_map in feature_maps]
  1055. feature_maps_position_embeddings = [
  1056. feature_map_position_embedding.flatten(2).permute(2, 0, 1)
  1057. for feature_map_position_embedding in feature_maps_position_embeddings
  1058. ]
  1059. vision_outputs.fpn_hidden_states = feature_maps
  1060. vision_outputs.fpn_position_encoding = feature_maps_position_embeddings
  1061. return vision_outputs
# Public API of this module (names exported by `from ... import *`).
__all__ = ["EdgeTamModel", "EdgeTamVisionModel", "EdgeTamPreTrainedModel"]