modeling_seggpt.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962
  1. # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch SegGpt model."""
  15. import collections.abc
  16. from dataclasses import dataclass
  17. import torch
  18. from torch import nn
  19. from torch.nn import functional as F
  20. from ... import initialization as init
  21. from ...activations import ACT2FN
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_utils import PreTrainedModel
  24. from ...utils import ModelOutput, auto_docstring, logging, torch_int
  25. from .configuration_seggpt import SegGptConfig
# Module-level logger for this file, obtained from the library's `logging` utilities.
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
Output type of [`SegGptEncoderOutput`].
"""
)
class SegGptEncoderOutput(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, patch_height, patch_width, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the model.
    hidden_states (`tuple[torch.FloatTensor]`, *optional*, returned when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
        of shape `(batch_size, patch_height, patch_width, hidden_size)`.
    attentions (`tuple[torch.FloatTensor]`, *optional*, returned when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape
        `(batch_size, num_heads, seq_len, seq_len)`.
    intermediate_hidden_states (`tuple[torch.FloatTensor]`, *optional*, returned when `config.intermediate_hidden_state_indices` is set):
        Tuple of `torch.FloatTensor` of shape `(batch_size, patch_height, patch_width, hidden_size)`.
        Each element in the tuple corresponds to the output of the layer specified in
        `config.intermediate_hidden_state_indices`. Additionally, each feature passes through a LayerNorm.
    """

    # Output of the final encoder layer.
    last_hidden_state: torch.FloatTensor
    # Embedding output followed by every layer's output; None unless hidden states were requested.
    hidden_states: tuple[torch.FloatTensor] | None = None
    # Per-layer attention probabilities; None unless attentions were requested.
    attentions: tuple[torch.FloatTensor] | None = None
    # Layer-normalized outputs of the selected layers. NOTE(review): `SegGptEncoder` in this file fills this
    # with a Python list rather than a tuple.
    intermediate_hidden_states: tuple[torch.FloatTensor] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
Output type of [`SegGptImageSegmentationOutput`].
"""
)
class SegGptImageSegmentationOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
        The loss value.
    pred_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
        The predicted masks.
    hidden_states (`tuple[torch.FloatTensor]`, *optional*, returned when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
        of shape `(batch_size, patch_height, patch_width, hidden_size)`.
    attentions (`tuple[torch.FloatTensor]`, *optional*, returned when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape
        `(batch_size, num_heads, seq_len, seq_len)`.
    """

    # Training loss; None when no labels were given.
    loss: torch.FloatTensor | None = None
    # Predicted segmentation masks.
    pred_masks: torch.FloatTensor | None = None
    # Embedding output followed by every layer's output; None unless hidden states were requested.
    hidden_states: tuple[torch.FloatTensor] | None = None
    # Per-layer attention probabilities; None unless attentions were requested.
    attentions: tuple[torch.FloatTensor] | None = None
  75. # Copied from transformers.models.sam.modeling_sam.SamPatchEmbeddings with Sam->SegGpt
  76. class SegGptPatchEmbeddings(nn.Module):
  77. """
  78. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  79. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  80. Transformer.
  81. """
  82. def __init__(self, config):
  83. super().__init__()
  84. image_size, patch_size = config.image_size, config.patch_size
  85. num_channels, hidden_size = config.num_channels, config.hidden_size
  86. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  87. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  88. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  89. self.image_size = image_size
  90. self.patch_size = patch_size
  91. self.num_channels = num_channels
  92. self.num_patches = num_patches
  93. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
  94. def forward(self, pixel_values):
  95. batch_size, num_channels, height, width = pixel_values.shape
  96. if num_channels != self.num_channels:
  97. raise ValueError(
  98. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  99. )
  100. if height != self.image_size[0] or width != self.image_size[1]:
  101. raise ValueError(
  102. f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
  103. )
  104. embeddings = self.projection(pixel_values).permute(0, 2, 3, 1)
  105. return embeddings
  106. class SegGptEmbeddings(nn.Module):
  107. """
  108. Construct the embeddings from patch, position embeddings for input and prompt.
  109. """
  110. def __init__(self, config: SegGptConfig) -> None:
  111. super().__init__()
  112. self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, config.hidden_size))
  113. self.segment_token_input = nn.Parameter(torch.zeros(1, 1, 1, config.hidden_size))
  114. self.segment_token_prompt = nn.Parameter(torch.zeros(1, 1, 1, config.hidden_size))
  115. # token for seg types
  116. self.type_token_semantic = nn.Parameter(torch.zeros(1, 1, 1, config.hidden_size))
  117. self.type_token_instance = nn.Parameter(torch.zeros(1, 1, 1, config.hidden_size))
  118. self.patch_embeddings = SegGptPatchEmbeddings(config)
  119. num_positions = (config.pretrain_image_size // config.patch_size) ** 2 + 1
  120. self.position_embeddings = nn.Parameter(torch.randn(1, num_positions, config.hidden_size))
  121. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  122. def interpolate_pos_encoding(self, height: int, width: int) -> torch.Tensor:
  123. patch_pos_embed = self.position_embeddings[:, 1:]
  124. num_patches = patch_pos_embed.shape[1]
  125. pretrain_patch_size = torch_int(num_patches**0.5)
  126. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  127. if torch.jit.is_tracing() or pretrain_patch_size != height or pretrain_patch_size != width:
  128. patch_pos_embed = F.interpolate(
  129. patch_pos_embed.reshape(1, pretrain_patch_size, pretrain_patch_size, -1).permute(0, 3, 1, 2),
  130. size=(height, width),
  131. mode="bicubic",
  132. align_corners=False,
  133. )
  134. return patch_pos_embed.permute(0, 2, 3, 1)
  135. else:
  136. return patch_pos_embed.reshape(1, height, width, -1)
  137. def forward(
  138. self,
  139. pixel_values: torch.Tensor,
  140. prompt_pixel_values: torch.Tensor,
  141. bool_masked_pos: torch.BoolTensor | None = None,
  142. embedding_type: str | None = None,
  143. ) -> torch.Tensor:
  144. input_embeddings = self.patch_embeddings(pixel_values)
  145. prompt_embeddings = self.patch_embeddings(prompt_pixel_values)
  146. batch_size, patch_height, patch_width, _ = input_embeddings.shape
  147. mask_token = self.mask_token.expand(batch_size, patch_height, patch_width, -1)
  148. # replace the masked visual tokens by mask_token
  149. w = bool_masked_pos.unsqueeze(-1).type_as(mask_token).reshape(-1, patch_height, patch_width, 1)
  150. prompt_embeddings = prompt_embeddings * (1 - w) + mask_token * w
  151. embedding_type = embedding_type if embedding_type is not None else "instance"
  152. # add positional encoding to each token
  153. pos_embed = self.interpolate_pos_encoding(patch_height, patch_width)
  154. # add segment token
  155. input_embeddings = input_embeddings + self.segment_token_input
  156. prompt_embeddings = prompt_embeddings + self.segment_token_prompt
  157. # add position embedding skipping CLS
  158. input_embeddings = input_embeddings + pos_embed
  159. prompt_embeddings = prompt_embeddings + pos_embed
  160. # add type embedding to each token
  161. if embedding_type == "semantic":
  162. type_embedding = self.type_token_semantic
  163. elif embedding_type == "instance":
  164. type_embedding = self.type_token_instance
  165. else:
  166. raise ValueError(f"Embedding type should be either 'semantic' or 'instance', but got {embedding_type}")
  167. input_embeddings = input_embeddings + type_embedding
  168. prompt_embeddings = prompt_embeddings + type_embedding
  169. embeddings = torch.cat((input_embeddings, prompt_embeddings), dim=0)
  170. return embeddings
  171. class SegGptAttention(nn.Module):
  172. """Multi-head Attention block with relative position embeddings."""
  173. def __init__(self, config):
  174. super().__init__()
  175. image_size, patch_size = config.image_size, config.patch_size
  176. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  177. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  178. input_size = (image_size[0] // config.patch_size, image_size[1] // config.patch_size)
  179. head_dim = config.hidden_size // config.num_attention_heads
  180. self.num_attention_heads = config.num_attention_heads
  181. self.scale = head_dim**-0.5
  182. self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.qkv_bias)
  183. self.proj = nn.Linear(config.hidden_size, config.hidden_size)
  184. self.use_relative_position_embeddings = config.use_relative_position_embeddings
  185. if self.use_relative_position_embeddings:
  186. if input_size is None:
  187. raise ValueError("Input size must be provided if using relative positional encoding.")
  188. # initialize relative positional embeddings
  189. self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
  190. self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
  191. def get_rel_pos(self, q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
  192. """
  193. Get relative positional embeddings according to the relative positions of
  194. query and key sizes.
  195. Args:
  196. q_size (int):
  197. size of the query.
  198. k_size (int):
  199. size of key k.
  200. rel_pos (`torch.Tensor`):
  201. relative position embeddings (L, channel).
  202. Returns:
  203. Extracted positional embeddings according to relative positions.
  204. """
  205. max_rel_dist = int(2 * max(q_size, k_size) - 1)
  206. # Interpolate rel pos.
  207. rel_pos_resized = F.interpolate(
  208. rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
  209. size=max_rel_dist,
  210. mode="linear",
  211. )
  212. rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
  213. # Scale the coords with short length if shapes for q and k are different.
  214. q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
  215. k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
  216. relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
  217. return rel_pos_resized[relative_coords.long()]
  218. def add_decomposed_rel_pos(
  219. self,
  220. attn: torch.Tensor,
  221. query: torch.Tensor,
  222. rel_pos_h: torch.Tensor,
  223. rel_pos_w: torch.Tensor,
  224. q_size: tuple[int, int],
  225. k_size: tuple[int, int],
  226. ) -> torch.Tensor:
  227. """
  228. Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
  229. https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
  230. Args:
  231. attn (`torch.Tensor`):
  232. attention map.
  233. query (`torch.Tensor`):
  234. query q in the attention layer with shape (batch_size, query_height * query_width, channel).
  235. rel_pos_h (`torch.Tensor`):
  236. relative position embeddings (Lh, channel) for height axis.
  237. rel_pos_w (`torch.Tensor`):
  238. relative position embeddings (Lw, channel) for width axis.
  239. q_size (tuple):
  240. spatial sequence size of query q with (query_height, query_width).
  241. k_size (tuple):
  242. spatial sequence size of key k with (key_height, key_width).
  243. Returns:
  244. attn (`torch.Tensor`):
  245. attention map with added relative positional embeddings.
  246. """
  247. query_height, query_width = q_size
  248. key_height, key_width = k_size
  249. relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
  250. relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)
  251. batch_size, _, dim = query.shape
  252. reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
  253. rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
  254. rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
  255. attn = attn.reshape(batch_size, query_height, query_width, key_height, key_width)
  256. attn = attn + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
  257. attn = attn.reshape(batch_size, query_height * query_width, key_height * key_width)
  258. return attn
  259. def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
  260. batch_size, height, width, _ = hidden_states.shape
  261. # qkv with shape (3, batch_size, nHead, height * width, channel)
  262. qkv = (
  263. self.qkv(hidden_states)
  264. .reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
  265. .permute(2, 0, 3, 1, 4)
  266. )
  267. # q, k, v with shape (batch_size * nHead, height * width, channel)
  268. query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)
  269. attn_weights = (query * self.scale) @ key.transpose(-2, -1)
  270. if self.use_relative_position_embeddings:
  271. attn_weights = self.add_decomposed_rel_pos(
  272. attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
  273. )
  274. attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
  275. if output_attentions:
  276. # this operation is a bit awkward, but it's required to
  277. # make sure that attn_weights keeps its gradient.
  278. # In order to do so, attn_weights have to reshaped
  279. # twice and have to be reused in the following
  280. attn_weights_reshaped = attn_weights.view(batch_size, self.num_attention_heads, height * width, -1)
  281. attn_weights = attn_weights_reshaped.view(batch_size * self.num_attention_heads, height * width, -1)
  282. else:
  283. attn_weights_reshaped = None
  284. attn_output = (attn_weights @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
  285. attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
  286. attn_output = self.proj(attn_output)
  287. return (attn_output, attn_weights_reshaped)
  288. # Copied from transformers.models.sam.modeling_sam.SamMLPBlock with SamMLPBlock->SegGptMlp
  289. class SegGptMlp(nn.Module):
  290. def __init__(self, config):
  291. super().__init__()
  292. self.lin1 = nn.Linear(config.hidden_size, config.mlp_dim)
  293. self.lin2 = nn.Linear(config.mlp_dim, config.hidden_size)
  294. self.act = ACT2FN[config.hidden_act]
  295. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  296. hidden_states = self.lin1(hidden_states)
  297. hidden_states = self.act(hidden_states)
  298. hidden_states = self.lin2(hidden_states)
  299. return hidden_states
  300. # Copied from transformers.models.beit.modeling_beit.drop_path
  301. def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
  302. """
  303. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  304. """
  305. if drop_prob == 0.0 or not training:
  306. return input
  307. keep_prob = 1 - drop_prob
  308. shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  309. random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
  310. random_tensor.floor_() # binarize
  311. output = input.div(keep_prob) * random_tensor
  312. return output
  313. # Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->SegGpt
  314. class SegGptDropPath(nn.Module):
  315. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  316. def __init__(self, drop_prob: float | None = None) -> None:
  317. super().__init__()
  318. self.drop_prob = drop_prob
  319. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  320. return drop_path(hidden_states, self.drop_prob, self.training)
  321. def extra_repr(self) -> str:
  322. return f"p={self.drop_prob}"
class SegGptLayer(GradientCheckpointingLayer):
    """
    One pre-norm transformer block of the SegGPT encoder.

    Data flow: `layernorm_before -> attention -> (optional feature ensemble) -> drop_path + residual ->
    layernorm_after -> mlp -> drop_path + residual`.
    """

    def __init__(self, config: SegGptConfig, drop_path_rate: float) -> None:
        super().__init__()
        self.attention = SegGptAttention(config)
        self.mlp = SegGptMlp(config)
        # Stochastic depth only for a positive rate; otherwise a no-op identity.
        self.drop_path = SegGptDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        ensemble_cond: int,
        feature_ensemble: bool = False,
        output_attentions: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        Args:
            hidden_states: layer input; `SegGptAttention` unpacks it as `(batch_size, height, width, hidden_size)`.
            ensemble_cond: `2` averages the second half (along dim 1) within each of two batch groups.
                Any other value averages it over the whole batch. Ensembling is only applied when
                `feature_ensemble` is True and `batch_size // 2 >= ensemble_cond`.
            feature_ensemble: enable the averaging described above.
            output_attentions: also return the attention probabilities.

        Returns:
            `(hidden_states, attention_weights)`. The second element is whatever `SegGptAttention` returned
            (None when `output_attentions` is False).
        """
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # in SegGpt, layernorm is applied before self-attention
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        if feature_ensemble and attention_output.shape[0] // 2 >= ensemble_cond:
            # Split dim 1 into two equal halves; only the second half is averaged.
            # NOTE(review): assumes dim 1 is even, otherwise `split` yields three chunks and unpacking fails.
            prompt, inputs = attention_output.split(attention_output.shape[1] // 2, dim=1)
            if ensemble_cond == 2:
                # View the batch as 2 groups of `num_prompts` samples and average within each group.
                # NOTE(review): assumes an even batch size.
                num_prompts = attention_output.shape[0] // 2
                inputs = inputs.reshape(2, num_prompts, -1)
                inputs = inputs.mean(dim=1, keepdim=True).expand_as(inputs)
                inputs = inputs.reshape(*prompt.shape)
            else:
                # Average over the whole batch.
                inputs = inputs.mean(dim=0, keepdim=True).expand_as(inputs)

            attention_output = torch.cat([prompt, inputs], dim=1)

        # first residual connection
        hidden_states = self.drop_path(attention_output) + hidden_states
        residual = hidden_states

        # second (MLP) sub-block, also pre-norm with a residual connection
        hidden_states = self.layernorm_after(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + self.drop_path(hidden_states)

        outputs = (hidden_states,) + outputs

        return outputs
  362. class SegGptEncoder(nn.Module):
  363. def __init__(self, config: SegGptConfig) -> None:
  364. super().__init__()
  365. self.config = config
  366. dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers, device="cpu")]
  367. self.layers = nn.ModuleList([SegGptLayer(config, dpr[i]) for i in range(config.num_hidden_layers)])
  368. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  369. self.gradient_checkpointing = False
  370. def forward(
  371. self,
  372. hidden_states: torch.Tensor,
  373. feature_ensemble: bool = False,
  374. output_attentions: bool = False,
  375. output_hidden_states: bool = False,
  376. return_dict: bool = True,
  377. ) -> tuple | SegGptEncoderOutput:
  378. all_hidden_states = () if output_hidden_states else None
  379. all_self_attentions = () if output_attentions else None
  380. intermediate_hidden_states = []
  381. for i, layer_module in enumerate(self.layers):
  382. if output_hidden_states:
  383. all_hidden_states = all_hidden_states + (hidden_states,)
  384. # Condition to check if we have the appropriate number of prompts to ensemble
  385. ensemble_cond = 2 if self.config.merge_index > i else 1
  386. layer_outputs = layer_module(hidden_states, ensemble_cond, feature_ensemble, output_attentions)
  387. hidden_states = layer_outputs[0]
  388. if i == self.config.merge_index:
  389. hidden_states = (
  390. hidden_states[: hidden_states.shape[0] // 2] + hidden_states[hidden_states.shape[0] // 2 :]
  391. ) * 0.5
  392. if i in self.config.intermediate_hidden_state_indices:
  393. intermediate_hidden_states.append(self.layernorm(hidden_states))
  394. if output_attentions:
  395. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  396. if output_hidden_states:
  397. all_hidden_states = all_hidden_states + (hidden_states,)
  398. if not return_dict:
  399. return tuple(
  400. v
  401. for v in [hidden_states, all_hidden_states, all_self_attentions, intermediate_hidden_states]
  402. if v is not None
  403. )
  404. return SegGptEncoderOutput(
  405. last_hidden_state=hidden_states,
  406. hidden_states=all_hidden_states,
  407. attentions=all_self_attentions,
  408. intermediate_hidden_states=intermediate_hidden_states,
  409. )
  410. # Copied from transformers.models.convnext.modeling_convnext.ConvNextLayerNorm with ConvNext->SegGpt
  411. class SegGptLayerNorm(nn.LayerNorm):
  412. r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
  413. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
  414. width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
  415. """
  416. def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs):
  417. super().__init__(normalized_shape, eps=eps, **kwargs)
  418. if data_format not in ["channels_last", "channels_first"]:
  419. raise NotImplementedError(f"Unsupported data format: {data_format}")
  420. self.data_format = data_format
  421. def forward(self, features: torch.Tensor) -> torch.Tensor:
  422. """
  423. Args:
  424. features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels)
  425. """
  426. if self.data_format == "channels_first":
  427. features = features.permute(0, 2, 3, 1)
  428. features = super().forward(features)
  429. features = features.permute(0, 3, 1, 2)
  430. else:
  431. features = super().forward(features)
  432. return features
  433. class SegGptDecoderHead(nn.Module):
  434. def __init__(self, config):
  435. super().__init__()
  436. self.conv = nn.Conv2d(
  437. config.decoder_hidden_size,
  438. config.decoder_hidden_size,
  439. kernel_size=3,
  440. padding=1,
  441. )
  442. self.layernorm = SegGptLayerNorm(
  443. normalized_shape=config.decoder_hidden_size, eps=config.layer_norm_eps, data_format="channels_first"
  444. )
  445. self.act_fct = ACT2FN[config.hidden_act]
  446. self.head = nn.Conv2d(config.decoder_hidden_size, 3, kernel_size=1, bias=True) # decoder to patch
  447. def forward(self, hidden_states: torch.FloatTensor):
  448. hidden_states = self.conv(hidden_states)
  449. hidden_states = self.layernorm(hidden_states)
  450. hidden_states = self.act_fct(hidden_states)
  451. hidden_states = self.head(hidden_states)
  452. return hidden_states
  453. class SegGptDecoder(nn.Module):
  454. def __init__(self, config):
  455. super().__init__()
  456. self.decoder_embed = nn.Linear(
  457. config.hidden_size * len(config.intermediate_hidden_state_indices),
  458. config.patch_size**2 * config.decoder_hidden_size,
  459. bias=True,
  460. )
  461. self.decoder_pred = SegGptDecoderHead(config)
  462. self.patch_size = config.patch_size
  463. self.decoder_hidden_size = config.decoder_hidden_size
  464. self.config = config
  465. def _reshape_hidden_states(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
  466. batch_size, patch_height, patch_width, _ = hidden_states.shape
  467. hidden_states = hidden_states.reshape(
  468. batch_size, patch_height, patch_width, self.patch_size, self.patch_size, self.decoder_hidden_size
  469. )
  470. hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4)
  471. hidden_states = hidden_states.reshape(
  472. shape=(batch_size, -1, patch_height * self.patch_size, patch_width * self.patch_size)
  473. )
  474. return hidden_states
  475. def forward(self, hidden_states: torch.FloatTensor):
  476. hidden_states = self.decoder_embed(hidden_states)
  477. hidden_states = self._reshape_hidden_states(hidden_states)
  478. hidden_states = self.decoder_pred(hidden_states)
  479. return hidden_states
  480. @auto_docstring
  481. class SegGptPreTrainedModel(PreTrainedModel):
  482. config: SegGptConfig
  483. base_model_prefix = "model"
  484. main_input_name = "pixel_values"
  485. input_modalities = ("image",)
  486. supports_gradient_checkpointing = True
  487. _no_split_modules = ["SegGptEmbeddings", "SegGptLayer"]
  488. @torch.no_grad()
  489. def _init_weights(self, module: nn.Module) -> None:
  490. """Initialize the weights"""
  491. std = self.config.initializer_range
  492. if isinstance(module, (nn.Linear, nn.Conv2d)):
  493. init.trunc_normal_(module.weight, mean=0.0, std=std)
  494. if module.bias is not None:
  495. init.zeros_(module.bias)
  496. elif isinstance(module, (nn.LayerNorm, SegGptLayerNorm)):
  497. init.zeros_(module.bias)
  498. init.ones_(module.weight)
  499. elif isinstance(module, SegGptAttention):
  500. init.trunc_normal_(module.rel_pos_h, mean=0.0, std=std)
  501. init.trunc_normal_(module.rel_pos_w, mean=0.0, std=std)
  502. elif isinstance(module, SegGptEmbeddings):
  503. init.trunc_normal_(module.position_embeddings, mean=0.0, std=std)
  504. init.normal_(module.mask_token, std=std)
  505. init.normal_(module.segment_token_input, std=std)
  506. init.normal_(module.segment_token_prompt, std=std)
  507. init.normal_(module.type_token_semantic, std=std)
  508. init.normal_(module.type_token_instance, std=std)
  509. @auto_docstring
  510. class SegGptModel(SegGptPreTrainedModel):
  511. def __init__(self, config: SegGptConfig):
  512. super().__init__(config)
  513. self.config = config
  514. self.embeddings = SegGptEmbeddings(config)
  515. self.encoder = SegGptEncoder(config)
  516. # Initialize weights and apply final processing
  517. self.post_init()
  518. def get_input_embeddings(self) -> SegGptPatchEmbeddings:
  519. return self.embeddings.patch_embeddings
  520. @auto_docstring
  521. def forward(
  522. self,
  523. pixel_values: torch.Tensor,
  524. prompt_pixel_values: torch.Tensor,
  525. prompt_masks: torch.Tensor,
  526. bool_masked_pos: torch.BoolTensor | None = None,
  527. feature_ensemble: bool | None = None,
  528. embedding_type: str | None = None,
  529. labels: torch.FloatTensor | None = None,
  530. output_attentions: bool | None = None,
  531. output_hidden_states: bool | None = None,
  532. return_dict: bool | None = None,
  533. **kwargs,
  534. ) -> tuple | SegGptEncoderOutput:
  535. r"""
  536. prompt_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  537. Prompt pixel values. Prompt pixel values can be obtained using [`AutoImageProcessor`]. See
  538. [`SegGptImageProcessor.__call__`] for details.
  539. prompt_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  540. Prompt mask. Prompt mask can be obtained using [`AutoImageProcessor`]. See [`SegGptImageProcessor.__call__`] for
  541. details.
  542. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
  543. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  544. feature_ensemble (`bool`, *optional*):
  545. Boolean indicating whether to use feature ensemble or not. If `True`, the model will use feature ensemble
  546. if we have at least two prompts. If `False`, the model will not use feature ensemble. This argument should
  547. be considered when doing few-shot inference on an input image i.e. more than one prompt for the same image.
  548. embedding_type (`str`, *optional*):
  549. Embedding type. Indicates whether the prompt is a semantic or instance embedding. Can be either
  550. instance or semantic.
  551. labels (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, `optional`):
  552. Ground truth mask for input images.
  553. Examples:
  554. ```python
  555. >>> from transformers import SegGptImageProcessor, SegGptModel
  556. >>> from PIL import Image
  557. >>> import httpx
  558. >>> from io import BytesIO
  559. >>> image_input_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_2.jpg"
  560. >>> image_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1.jpg"
  561. >>> mask_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1_target.png"
  562. >>> with httpx.stream("GET", image_input_url) as response:
  563. ... image_input = Image.open(BytesIO(response.read()))
  564. >>> with httpx.stream("GET", image_prompt_url) as response:
  565. ... image_prompt = Image.open(BytesIO(response.read()))
  566. >>> with httpx.stream("GET", mask_prompt_url) as response:
  567. ... mask_prompt = Image.open(BytesIO(response.read())).convert("L")
  568. >>> checkpoint = "BAAI/seggpt-vit-large"
  569. >>> model = SegGptModel.from_pretrained(checkpoint)
  570. >>> image_processor = SegGptImageProcessor.from_pretrained(checkpoint)
  571. >>> inputs = image_processor(images=image_input, prompt_images=image_prompt, prompt_masks=mask_prompt, return_tensors="pt")
  572. >>> outputs = model(**inputs)
  573. >>> list(outputs.last_hidden_state.shape)
  574. [1, 56, 28, 1024]
  575. ```
  576. """
  577. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  578. output_hidden_states = (
  579. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  580. )
  581. return_dict = return_dict if return_dict is not None else self.config.return_dict
  582. feature_ensemble = feature_ensemble if feature_ensemble is not None else False
  583. expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
  584. pixel_values = pixel_values.to(expected_dtype)
  585. prompt_pixel_values = prompt_pixel_values.to(expected_dtype)
  586. # Prepare inputs
  587. pixel_values = torch.cat((prompt_pixel_values, pixel_values), dim=2)
  588. prompt_pixel_values = (
  589. torch.cat((prompt_masks, prompt_masks), dim=2)
  590. if labels is None
  591. else torch.cat((prompt_masks, labels), dim=2)
  592. )
  593. if bool_masked_pos is None and labels is not None:
  594. logger.warning_once(
  595. "Labels were provided, but bool_masked_pos were not. It will be set to default value. If you're training the model, make sure to provide a bool_masked_pos."
  596. )
  597. # We concat on height axis so SegGPT can handle as a single image, hence we need to mask the portion
  598. # of the mask prompt pixels that will be destinated to the prediction as they don't add any information.
  599. # This is only the case for inference. In training, the model concat of prompt mask and label is masked
  600. # and reconstructed together (In-Context Painting).
  601. if bool_masked_pos is None:
  602. num_patches = self.embeddings.patch_embeddings.num_patches
  603. bool_masked_pos_zeros = torch.zeros(num_patches // 2, dtype=torch.bool, device=pixel_values.device)
  604. bool_masked_pos_ones = torch.ones(
  605. num_patches - num_patches // 2, dtype=torch.bool, device=pixel_values.device
  606. )
  607. bool_masked_pos = torch.cat([bool_masked_pos_zeros, bool_masked_pos_ones])
  608. bool_masked_pos = bool_masked_pos.unsqueeze(0)
  609. embedding_output = self.embeddings(
  610. pixel_values, prompt_pixel_values, embedding_type=embedding_type, bool_masked_pos=bool_masked_pos
  611. )
  612. encoder_outputs = self.encoder(
  613. embedding_output,
  614. feature_ensemble=feature_ensemble,
  615. output_attentions=output_attentions,
  616. output_hidden_states=output_hidden_states,
  617. return_dict=return_dict,
  618. )
  619. return encoder_outputs
  620. def patchify(tensor: torch.Tensor, patch_size: int) -> torch.Tensor:
  621. batch_size, num_channels, height, width = tensor.shape
  622. patch_height = height // patch_size
  623. patch_width = width // patch_size
  624. tensor = tensor.reshape(shape=(batch_size, num_channels, patch_height, patch_size, patch_width, patch_size))
  625. tensor = tensor.permute(0, 2, 4, 3, 5, 1)
  626. tensor = tensor.reshape(shape=(batch_size, patch_height * patch_width, patch_size**2 * 3))
  627. return tensor
  628. def unpatchify(tensor: torch.Tensor, patch_height: int, patch_width: int) -> torch.Tensor:
  629. batch_size = tensor.shape[0]
  630. patch_size = int((tensor.shape[-1] / 3) ** 0.5)
  631. if patch_height * patch_width != tensor.shape[1]:
  632. raise ValueError(
  633. f"Number of patches {tensor.shape[1]} does not match patch height ({patch_height}) and width ({patch_width})."
  634. )
  635. tensor = tensor.reshape(shape=(batch_size, patch_height, patch_width, patch_size, patch_size, 3))
  636. tensor = tensor.permute(0, 5, 1, 3, 2, 4)
  637. tensor = tensor.reshape(shape=(batch_size, 3, patch_height * patch_size, patch_width * patch_size))
  638. return tensor
  639. class SegGptLoss(nn.Module):
  640. def __init__(self, config):
  641. super().__init__()
  642. self.beta = config.beta
  643. self.patch_size = config.patch_size
  644. def forward(
  645. self,
  646. prompt_masks: torch.FloatTensor,
  647. pred_masks: torch.FloatTensor,
  648. labels: torch.FloatTensor,
  649. bool_masked_pos: torch.BoolTensor,
  650. ):
  651. """Computes the L1 loss between the predicted masks and the ground truth masks.
  652. Args:
  653. prompt_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  654. Pixel values from mask prompt.
  655. pred_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, 2*height, width)`):
  656. Predicted masks.
  657. labels (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  658. Ground truth mask for input images.
  659. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
  660. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  661. Returns:
  662. `torch.FloatTensor`: The mean L1 loss between the predicted masks and the ground truth masks.
  663. """
  664. ground_truth = torch.cat((prompt_masks, labels), dim=2)
  665. mask = bool_masked_pos[:, :, None].repeat(1, 1, self.patch_size**2 * 3)
  666. mask = unpatchify(mask, ground_truth.shape[2] // self.patch_size, ground_truth.shape[3] // self.patch_size)
  667. loss = F.smooth_l1_loss(pred_masks, ground_truth, reduction="none", beta=self.beta)
  668. loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
  669. return loss
  670. @auto_docstring(
  671. custom_intro="""
  672. SegGpt model with a decoder on top for one-shot image segmentation.
  673. """
  674. )
  675. class SegGptForImageSegmentation(SegGptPreTrainedModel):
  676. def __init__(self, config: SegGptConfig):
  677. super().__init__(config)
  678. self.config = config
  679. self.model = SegGptModel(config)
  680. self.decoder = SegGptDecoder(config)
  681. # Initialize weights and apply final processing
  682. self.post_init()
  683. @auto_docstring
  684. def forward(
  685. self,
  686. pixel_values: torch.Tensor,
  687. prompt_pixel_values: torch.Tensor,
  688. prompt_masks: torch.Tensor,
  689. bool_masked_pos: torch.BoolTensor | None = None,
  690. feature_ensemble: bool | None = None,
  691. embedding_type: str | None = None,
  692. labels: torch.FloatTensor | None = None,
  693. output_attentions: bool | None = None,
  694. output_hidden_states: bool | None = None,
  695. return_dict: bool | None = None,
  696. **kwargs,
  697. ) -> tuple | SegGptImageSegmentationOutput:
  698. r"""
  699. prompt_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  700. Prompt pixel values. Prompt pixel values can be obtained using [`AutoImageProcessor`]. See
  701. [`SegGptImageProcessor.__call__`] for details.
  702. prompt_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  703. Prompt mask. Prompt mask can be obtained using [`AutoImageProcessor`]. See [`SegGptImageProcessor.__call__`] for
  704. details.
  705. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
  706. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  707. feature_ensemble (`bool`, *optional*):
  708. Boolean indicating whether to use feature ensemble or not. If `True`, the model will use feature ensemble
  709. if we have at least two prompts. If `False`, the model will not use feature ensemble. This argument should
  710. be considered when doing few-shot inference on an input image i.e. more than one prompt for the same image.
  711. embedding_type (`str`, *optional*):
  712. Embedding type. Indicates whether the prompt is a semantic or instance embedding. Can be either
  713. instance or semantic.
  714. labels (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, `optional`):
  715. Ground truth mask for input images.
  716. Examples:
  717. ```python
  718. >>> from transformers import SegGptImageProcessor, SegGptForImageSegmentation
  719. >>> from PIL import Image
  720. >>> import httpx
  721. >>> from io import BytesIO
  722. >>> image_input_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_2.jpg"
  723. >>> image_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1.jpg"
  724. >>> mask_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1_target.png"
  725. >>> with httpx.stream("GET", image_input_url) as response:
  726. ... image_input = Image.open(BytesIO(response.read()))
  727. >>> with httpx.stream("GET", image_prompt_url) as response:
  728. ... image_prompt = Image.open(BytesIO(response.read()))
  729. >>> with httpx.stream("GET", mask_prompt_url) as response:
  730. ... mask_prompt = Image.open(BytesIO(response.read())).convert("L")
  731. >>> checkpoint = "BAAI/seggpt-vit-large"
  732. >>> model = SegGptForImageSegmentation.from_pretrained(checkpoint)
  733. >>> image_processor = SegGptImageProcessor.from_pretrained(checkpoint)
  734. >>> inputs = image_processor(images=image_input, prompt_images=image_prompt, prompt_masks=mask_prompt, return_tensors="pt")
  735. >>> outputs = model(**inputs)
  736. >>> result = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[(image_input.height, image_input.width)])[0]
  737. >>> print(list(result.shape))
  738. [170, 297]
  739. ```
  740. """
  741. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  742. output_hidden_states = (
  743. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  744. )
  745. return_dict = return_dict if return_dict is not None else self.config.return_dict
  746. if bool_masked_pos is None:
  747. num_patches = self.model.embeddings.patch_embeddings.num_patches
  748. bool_masked_pos_zeros = torch.zeros(num_patches // 2, dtype=torch.bool, device=pixel_values.device)
  749. bool_masked_pos_ones = torch.ones(
  750. num_patches - num_patches // 2, dtype=torch.bool, device=pixel_values.device
  751. )
  752. bool_masked_pos = torch.cat([bool_masked_pos_zeros, bool_masked_pos_ones])
  753. bool_masked_pos = bool_masked_pos.unsqueeze(0)
  754. outputs = self.model(
  755. pixel_values=pixel_values,
  756. prompt_pixel_values=prompt_pixel_values,
  757. prompt_masks=prompt_masks,
  758. bool_masked_pos=bool_masked_pos,
  759. feature_ensemble=feature_ensemble,
  760. embedding_type=embedding_type,
  761. labels=labels,
  762. output_attentions=output_attentions,
  763. output_hidden_states=output_hidden_states,
  764. return_dict=return_dict,
  765. )
  766. intermediate_hidden_states = outputs.intermediate_hidden_states if return_dict else outputs[-1]
  767. intermediate_hidden_states = torch.cat(intermediate_hidden_states, dim=-1)
  768. pred_masks = self.decoder(intermediate_hidden_states)
  769. loss = None
  770. if labels is not None:
  771. loss_fn = SegGptLoss(self.config)
  772. loss = loss_fn(prompt_masks, pred_masks, labels, bool_masked_pos)
  773. if not return_dict:
  774. output = (pred_masks,)
  775. if output_hidden_states:
  776. output = output + (outputs[1],)
  777. if output_attentions:
  778. idx = 2 if output_hidden_states else 1
  779. output = output + (outputs[idx],)
  780. if loss is not None:
  781. output = (loss,) + output
  782. return output
  783. return SegGptImageSegmentationOutput(
  784. loss=loss,
  785. pred_masks=pred_masks,
  786. hidden_states=outputs.hidden_states,
  787. attentions=outputs.attentions,
  788. )
  789. __all__ = ["SegGptModel", "SegGptPreTrainedModel", "SegGptForImageSegmentation"]