modeling_clipseg.py 49 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242
  1. # Copyright 2022 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch CLIPSeg model."""
  15. import copy
  16. import math
  17. from collections.abc import Callable
  18. from dataclasses import dataclass
  19. from typing import Any
  20. import torch
  21. from torch import nn
  22. from ... import initialization as init
  23. from ...activations import ACT2FN
  24. from ...masking_utils import create_causal_mask
  25. from ...modeling_layers import GradientCheckpointingLayer
  26. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
  27. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  28. from ...processing_utils import Unpack
  29. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, logging, torch_int
  30. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  31. from ...utils.output_capturing import capture_outputs
  32. from .configuration_clipseg import CLIPSegConfig, CLIPSegTextConfig, CLIPSegVisionConfig
  33. logger = logging.get_logger(__name__)
  34. # contrastive loss function, adapted from
  35. # https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
  36. def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
  37. return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
  38. # Copied from transformers.models.clip.modeling_clip.clip_loss with clip->clipseg
  39. def clipseg_loss(similarity: torch.Tensor) -> torch.Tensor:
  40. caption_loss = contrastive_loss(similarity)
  41. image_loss = contrastive_loss(similarity.t())
  42. return (caption_loss + image_loss) / 2.0
  43. @dataclass
  44. @auto_docstring
  45. # Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->CLIPSeg
  46. class CLIPSegOutput(ModelOutput):
  47. r"""
  48. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
  49. Contrastive loss for image-text similarity.
  50. logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
  51. The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
  52. similarity scores.
  53. logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
  54. The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
  55. similarity scores.
  56. text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
  57. The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPSegTextModel`].
  58. image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
  59. The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPSegVisionModel`].
  60. text_model_output (`BaseModelOutputWithPooling`):
  61. The output of the [`CLIPSegTextModel`].
  62. vision_model_output (`BaseModelOutputWithPooling`):
  63. The output of the [`CLIPSegVisionModel`].
  64. """
  65. loss: torch.FloatTensor | None = None
  66. logits_per_image: torch.FloatTensor | None = None
  67. logits_per_text: torch.FloatTensor | None = None
  68. text_embeds: torch.FloatTensor | None = None
  69. image_embeds: torch.FloatTensor | None = None
  70. text_model_output: BaseModelOutputWithPooling = None
  71. vision_model_output: BaseModelOutputWithPooling = None
  72. def to_tuple(self) -> tuple[Any]:
  73. return tuple(v.to_tuple() if isinstance(v, ModelOutput) else v for v in self.values())
@dataclass
@auto_docstring
class CLIPSegDecoderOutput(ModelOutput):
    r"""
    logits (`torch.FloatTensor` of shape `(batch_size, height, width)`):
        Classification scores for each pixel.
    """

    logits: torch.FloatTensor | None = None
    # NOTE(review): the two fields below are not described in the docstring above. Presumably they are
    # filled in by the output-capturing machinery rather than passed explicitly — TODO confirm.
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
  84. @dataclass
  85. @auto_docstring
  86. class CLIPSegImageSegmentationOutput(ModelOutput):
  87. r"""
  88. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  89. Binary cross entropy loss for segmentation.
  90. logits (`torch.FloatTensor` of shape `(batch_size, height, width)`):
  91. Classification scores for each pixel.
  92. conditional_embeddings (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
  93. Conditional embeddings used for segmentation.
  94. pooled_output (`torch.FloatTensor` of shape `(batch_size, embed_dim)`):
  95. Pooled output of the [`CLIPSegVisionModel`].
  96. vision_model_output (`BaseModelOutputWithPooling`):
  97. The output of the [`CLIPSegVisionModel`].
  98. decoder_output (`CLIPSegDecoderOutput`):
  99. The output of the [`CLIPSegDecoder`].
  100. """
  101. loss: torch.FloatTensor | None = None
  102. logits: torch.FloatTensor | None = None
  103. conditional_embeddings: torch.FloatTensor | None = None
  104. pooled_output: torch.FloatTensor | None = None
  105. vision_model_output: BaseModelOutputWithPooling = None
  106. decoder_output: CLIPSegDecoderOutput = None
  107. def to_tuple(self) -> tuple[Any]:
  108. return tuple(v.to_tuple() if isinstance(v, ModelOutput) else v for v in self.values())
  109. class CLIPSegVisionEmbeddings(nn.Module):
  110. # Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings.__init__ with CLIP->CLIPSeg
  111. def __init__(self, config: CLIPSegVisionConfig):
  112. super().__init__()
  113. self.config = config
  114. self.embed_dim = config.hidden_size
  115. self.image_size = config.image_size
  116. self.patch_size = config.patch_size
  117. self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
  118. self.patch_embedding = nn.Conv2d(
  119. in_channels=config.num_channels,
  120. out_channels=self.embed_dim,
  121. kernel_size=self.patch_size,
  122. stride=self.patch_size,
  123. bias=False,
  124. )
  125. self.num_patches = (self.image_size // self.patch_size) ** 2
  126. self.num_positions = self.num_patches + 1
  127. self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
  128. self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
  129. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  130. """
  131. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  132. images. This method is also adapted to support torch.jit tracing.
  133. Adapted from:
  134. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  135. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  136. """
  137. num_patches = embeddings.shape[1] - 1
  138. position_embedding = self.position_embedding.weight.unsqueeze(0)
  139. num_positions = position_embedding.shape[1] - 1
  140. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  141. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  142. return self.position_embedding(self.position_ids)
  143. class_pos_embed = position_embedding[:, :1]
  144. patch_pos_embed = position_embedding[:, 1:]
  145. dim = embeddings.shape[-1]
  146. new_height = height // self.patch_size
  147. new_width = width // self.patch_size
  148. sqrt_num_positions = torch_int(num_positions**0.5)
  149. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  150. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  151. patch_pos_embed = nn.functional.interpolate(
  152. patch_pos_embed,
  153. size=(new_height, new_width),
  154. mode="bicubic",
  155. align_corners=False,
  156. )
  157. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  158. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  159. def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=True) -> torch.Tensor:
  160. batch_size, _, height, width = pixel_values.shape
  161. if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
  162. raise ValueError(
  163. f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})."
  164. )
  165. patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
  166. patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
  167. class_embeds = self.class_embedding.expand(batch_size, 1, -1)
  168. embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
  169. if interpolate_pos_encoding:
  170. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  171. else:
  172. embeddings = embeddings + self.position_embedding(self.position_ids)
  173. return embeddings
  174. # Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->CLIPSeg
  175. class CLIPSegTextEmbeddings(nn.Module):
  176. def __init__(self, config: CLIPSegTextConfig):
  177. super().__init__()
  178. embed_dim = config.hidden_size
  179. self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
  180. self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
  181. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  182. self.register_buffer(
  183. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  184. )
  185. def forward(
  186. self,
  187. input_ids: torch.LongTensor | None = None,
  188. position_ids: torch.LongTensor | None = None,
  189. inputs_embeds: torch.FloatTensor | None = None,
  190. ) -> torch.Tensor:
  191. seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
  192. max_position_embedding = self.position_embedding.weight.shape[0]
  193. if seq_length > max_position_embedding:
  194. raise ValueError(
  195. f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
  196. f"{seq_length} and max_position_embeddings: {max_position_embedding}"
  197. )
  198. if position_ids is None:
  199. position_ids = self.position_ids[:, :seq_length]
  200. if inputs_embeds is None:
  201. inputs_embeds = self.token_embedding(input_ids)
  202. position_embeddings = self.position_embedding(position_ids)
  203. embeddings = inputs_embeds + position_embeddings
  204. return embeddings
  205. # Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward
  206. def eager_attention_forward(
  207. module: nn.Module,
  208. query: torch.Tensor,
  209. key: torch.Tensor,
  210. value: torch.Tensor,
  211. attention_mask: torch.Tensor | None,
  212. scaling: float,
  213. dropout: float = 0.0,
  214. **kwargs,
  215. ):
  216. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  217. if attention_mask is not None:
  218. attn_weights = attn_weights + attention_mask
  219. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  220. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  221. attn_output = torch.matmul(attn_weights, value)
  222. attn_output = attn_output.transpose(1, 2).contiguous()
  223. return attn_output, attn_weights
  224. class CLIPSegAttention(nn.Module):
  225. """Multi-headed attention from 'Attention Is All You Need' paper"""
  226. def __init__(self, config: CLIPSegVisionConfig | CLIPSegTextConfig):
  227. super().__init__()
  228. self.config = config
  229. self.embed_dim = config.hidden_size
  230. self.num_heads = config.num_attention_heads
  231. self.head_dim = self.embed_dim // self.num_heads
  232. if self.head_dim * self.num_heads != self.embed_dim:
  233. raise ValueError(
  234. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  235. f" {self.num_heads})."
  236. )
  237. self.scale = self.head_dim**-0.5
  238. self.dropout = config.attention_dropout
  239. self.is_causal = False
  240. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  241. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  242. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  243. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  244. def forward(
  245. self,
  246. hidden_states: torch.Tensor,
  247. attention_mask: torch.Tensor | None = None,
  248. **kwargs: Unpack[TransformersKwargs],
  249. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  250. """Input shape: Batch x Time x Channel"""
  251. input_shape = hidden_states.shape[:-1]
  252. hidden_shape = (*input_shape, -1, self.head_dim)
  253. queries = self.q_proj(hidden_states)
  254. keys = self.k_proj(hidden_states)
  255. values = self.v_proj(hidden_states)
  256. queries = queries.view(hidden_shape).transpose(1, 2)
  257. keys = keys.view(hidden_shape).transpose(1, 2)
  258. values = values.view(hidden_shape).transpose(1, 2)
  259. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  260. self.config._attn_implementation, eager_attention_forward
  261. )
  262. attn_output, attn_weights = attention_interface(
  263. self,
  264. queries,
  265. keys,
  266. values,
  267. attention_mask,
  268. scaling=self.scale,
  269. dropout=0.0 if not self.training else self.dropout,
  270. **kwargs,
  271. )
  272. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  273. attn_output = self.out_proj(attn_output)
  274. return attn_output, attn_weights
  275. # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->CLIPSeg
  276. class CLIPSegMLP(nn.Module):
  277. def __init__(self, config):
  278. super().__init__()
  279. self.config = config
  280. self.activation_fn = ACT2FN[config.hidden_act]
  281. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  282. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  283. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  284. hidden_states = self.fc1(hidden_states)
  285. hidden_states = self.activation_fn(hidden_states)
  286. hidden_states = self.fc2(hidden_states)
  287. return hidden_states
  288. # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->CLIPSeg
  289. class CLIPSegEncoderLayer(GradientCheckpointingLayer):
  290. def __init__(self, config: CLIPSegConfig):
  291. super().__init__()
  292. self.embed_dim = config.hidden_size
  293. self.self_attn = CLIPSegAttention(config)
  294. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  295. self.mlp = CLIPSegMLP(config)
  296. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  297. def forward(
  298. self,
  299. hidden_states: torch.Tensor,
  300. attention_mask: torch.Tensor,
  301. **kwargs: Unpack[TransformersKwargs],
  302. ) -> tuple[torch.FloatTensor, torch.Tensor | None]:
  303. residual = hidden_states
  304. hidden_states = self.layer_norm1(hidden_states)
  305. hidden_states, _ = self.self_attn(
  306. hidden_states=hidden_states,
  307. attention_mask=attention_mask,
  308. **kwargs,
  309. )
  310. hidden_states = residual + hidden_states
  311. residual = hidden_states
  312. hidden_states = self.layer_norm2(hidden_states)
  313. hidden_states = self.mlp(hidden_states)
  314. hidden_states = residual + hidden_states
  315. return hidden_states
  316. class CLIPSegDecoderLayer(nn.Module):
  317. """
  318. CLIPSeg decoder layer, which is identical to `CLIPSegEncoderLayer`, except that normalization is applied after
  319. self-attention/MLP, rather than before.
  320. """
  321. # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer.__init__ with AltCLIP->CLIPSeg
  322. def __init__(self, config: CLIPSegConfig):
  323. super().__init__()
  324. self.embed_dim = config.hidden_size
  325. self.self_attn = CLIPSegAttention(config)
  326. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  327. self.mlp = CLIPSegMLP(config)
  328. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  329. def forward(
  330. self,
  331. hidden_states: torch.Tensor,
  332. attention_mask: torch.Tensor,
  333. **kwargs,
  334. ) -> tuple[torch.FloatTensor]:
  335. """
  336. Args:
  337. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  338. attention_mask (`torch.FloatTensor`): attention mask of size
  339. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  340. `(config.encoder_attention_heads,)`.
  341. output_attentions (`bool`, *optional*):
  342. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  343. returned tensors for more detail.
  344. """
  345. residual = hidden_states
  346. hidden_states, _ = self.self_attn(
  347. hidden_states=hidden_states,
  348. attention_mask=attention_mask,
  349. **kwargs,
  350. )
  351. hidden_states = residual + hidden_states
  352. hidden_states = self.layer_norm1(hidden_states)
  353. residual = hidden_states
  354. hidden_states = self.mlp(hidden_states)
  355. hidden_states = residual + hidden_states
  356. hidden_states = self.layer_norm2(hidden_states)
  357. return hidden_states
  358. @auto_docstring
  359. class CLIPSegPreTrainedModel(PreTrainedModel):
  360. config: CLIPSegConfig
  361. base_model_prefix = "clip"
  362. input_modalities = ("image", "text")
  363. supports_gradient_checkpointing = True
  364. _can_record_outputs = {
  365. "hidden_states": [CLIPSegEncoderLayer, CLIPSegDecoderLayer],
  366. "attentions": CLIPSegAttention,
  367. }
  368. @torch.no_grad()
  369. def _init_weights(self, module):
  370. """Initialize the weights"""
  371. factor = self.config.initializer_factor
  372. if isinstance(module, CLIPSegTextEmbeddings):
  373. init.normal_(module.token_embedding.weight, mean=0.0, std=factor * 0.02)
  374. init.normal_(module.position_embedding.weight, mean=0.0, std=factor * 0.02)
  375. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  376. elif isinstance(module, CLIPSegVisionEmbeddings):
  377. factor = self.config.initializer_factor
  378. init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
  379. init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
  380. init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
  381. init.copy_(module.position_ids, torch.arange(module.num_positions).expand((1, -1)))
  382. elif isinstance(module, CLIPSegAttention):
  383. factor = self.config.initializer_factor
  384. in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
  385. out_proj_std = (module.embed_dim**-0.5) * factor
  386. init.normal_(module.q_proj.weight, std=in_proj_std)
  387. init.normal_(module.k_proj.weight, std=in_proj_std)
  388. init.normal_(module.v_proj.weight, std=in_proj_std)
  389. init.normal_(module.out_proj.weight, std=out_proj_std)
  390. elif isinstance(module, CLIPSegMLP):
  391. factor = self.config.initializer_factor
  392. in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
  393. fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
  394. init.normal_(module.fc1.weight, std=fc_std)
  395. init.normal_(module.fc2.weight, std=in_proj_std)
  396. elif isinstance(module, CLIPSegModel):
  397. init.normal_(
  398. module.text_projection.weight,
  399. std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
  400. )
  401. init.normal_(
  402. module.visual_projection.weight,
  403. std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
  404. )
  405. if isinstance(module, nn.LayerNorm):
  406. init.zeros_(module.bias)
  407. init.ones_(module.weight)
  408. if isinstance(module, nn.Linear) and module.bias is not None:
  409. init.zeros_(module.bias)
  410. # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->CLIPSeg
  411. class CLIPSegEncoder(nn.Module):
  412. """
  413. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  414. [`CLIPSegEncoderLayer`].
  415. Args:
  416. config: CLIPSegConfig
  417. """
  418. def __init__(self, config: CLIPSegConfig):
  419. super().__init__()
  420. self.config = config
  421. self.layers = nn.ModuleList([CLIPSegEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  422. self.gradient_checkpointing = False
  423. def forward(
  424. self,
  425. inputs_embeds,
  426. attention_mask: torch.Tensor | None = None,
  427. **kwargs: Unpack[TransformersKwargs],
  428. ) -> tuple | BaseModelOutput:
  429. r"""
  430. Args:
  431. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  432. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  433. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  434. than the model's internal embedding lookup matrix.
  435. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  436. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  437. - 1 for tokens that are **not masked**,
  438. - 0 for tokens that are **masked**.
  439. [What are attention masks?](../glossary#attention-mask)
  440. """
  441. hidden_states = inputs_embeds
  442. for encoder_layer in self.layers:
  443. hidden_states = encoder_layer(
  444. hidden_states,
  445. attention_mask,
  446. **kwargs,
  447. )
  448. return BaseModelOutput(
  449. last_hidden_state=hidden_states,
  450. )
  451. class CLIPSegDecoder(CLIPSegPreTrainedModel):
  452. def __init__(self, config: CLIPSegConfig):
  453. super().__init__(config)
  454. self.conditional_layer = config.conditional_layer
  455. self.film_mul = nn.Linear(config.projection_dim, config.reduce_dim)
  456. self.film_add = nn.Linear(config.projection_dim, config.reduce_dim)
  457. if config.use_complex_transposed_convolution:
  458. transposed_kernels = (config.vision_config.patch_size // 4, config.vision_config.patch_size // 4)
  459. self.transposed_convolution = nn.Sequential(
  460. nn.Conv2d(config.reduce_dim, config.reduce_dim, kernel_size=3, padding=1),
  461. nn.ReLU(),
  462. nn.ConvTranspose2d(
  463. config.reduce_dim,
  464. config.reduce_dim // 2,
  465. kernel_size=transposed_kernels[0],
  466. stride=transposed_kernels[0],
  467. ),
  468. nn.ReLU(),
  469. nn.ConvTranspose2d(
  470. config.reduce_dim // 2, 1, kernel_size=transposed_kernels[1], stride=transposed_kernels[1]
  471. ),
  472. )
  473. else:
  474. self.transposed_convolution = nn.ConvTranspose2d(
  475. config.reduce_dim, 1, config.vision_config.patch_size, stride=config.vision_config.patch_size
  476. )
  477. depth = len(config.extract_layers)
  478. self.reduces = nn.ModuleList(
  479. [nn.Linear(config.vision_config.hidden_size, config.reduce_dim) for _ in range(depth)]
  480. )
  481. decoder_config = copy.deepcopy(config.vision_config)
  482. decoder_config.hidden_size = config.reduce_dim
  483. decoder_config.num_attention_heads = config.decoder_num_attention_heads
  484. decoder_config.intermediate_size = config.decoder_intermediate_size
  485. decoder_config.hidden_act = "relu"
  486. self.layers = nn.ModuleList([CLIPSegDecoderLayer(decoder_config) for _ in range(len(config.extract_layers))])
  487. self.post_init()
  488. @merge_with_config_defaults
  489. @capture_outputs
  490. def forward(
  491. self,
  492. hidden_states: tuple[torch.Tensor],
  493. conditional_embeddings: torch.Tensor,
  494. output_attentions: bool | None = None,
  495. **kwargs: Unpack[TransformersKwargs],
  496. ):
  497. activations = hidden_states[::-1]
  498. output = None
  499. for i, (activation, layer, reduce) in enumerate(zip(activations, self.layers, self.reduces)):
  500. if output is not None:
  501. output = reduce(activation) + output
  502. else:
  503. output = reduce(activation)
  504. if i == self.conditional_layer:
  505. output = self.film_mul(conditional_embeddings) * output.permute(1, 0, 2) + self.film_add(
  506. conditional_embeddings
  507. )
  508. output = output.permute(1, 0, 2)
  509. output = layer(output, attention_mask=None, causal_attention_mask=None, **kwargs)
  510. output = output[:, 1:, :].permute(0, 2, 1) # remove cls token and reshape to [batch_size, reduce_dim, seq_len]
  511. size = int(math.sqrt(output.shape[2]))
  512. batch_size = conditional_embeddings.shape[0]
  513. output = output.view(batch_size, output.shape[1], size, size)
  514. logits = self.transposed_convolution(output).squeeze(1)
  515. return CLIPSegDecoderOutput(logits=logits)
  516. class CLIPSegTextTransformer(CLIPSegPreTrainedModel):
  517. def __init__(self, config: CLIPSegTextConfig):
  518. super().__init__(config)
  519. embed_dim = config.hidden_size
  520. self.embeddings = CLIPSegTextEmbeddings(config)
  521. self.encoder = CLIPSegEncoder(config)
  522. self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  523. # For `pooled_output` computation
  524. self.eos_token_id = config.eos_token_id
  525. self.post_init()
  526. @auto_docstring
  527. def forward(
  528. self,
  529. input_ids: torch.Tensor | None = None,
  530. attention_mask: torch.Tensor | None = None,
  531. position_ids: torch.Tensor | None = None,
  532. output_attentions: bool | None = None,
  533. output_hidden_states: bool | None = None,
  534. return_dict: bool | None = None,
  535. **kwargs,
  536. ) -> tuple | BaseModelOutputWithPooling:
  537. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  538. output_hidden_states = (
  539. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  540. )
  541. return_dict = return_dict if return_dict is not None else self.config.return_dict
  542. if input_ids is None:
  543. raise ValueError("You have to specify input_ids")
  544. input_shape = input_ids.size()
  545. input_ids = input_ids.view(-1, input_shape[-1])
  546. hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
  547. attention_mask = create_causal_mask(
  548. config=self.config,
  549. inputs_embeds=hidden_states,
  550. attention_mask=attention_mask,
  551. past_key_values=None,
  552. )
  553. kwargs.pop("is_causal", None)
  554. encoder_outputs = self.encoder(
  555. inputs_embeds=hidden_states,
  556. attention_mask=attention_mask,
  557. output_attentions=output_attentions,
  558. output_hidden_states=output_hidden_states,
  559. return_dict=return_dict,
  560. is_causal=True,
  561. **kwargs,
  562. )
  563. last_hidden_state = encoder_outputs[0]
  564. last_hidden_state = self.final_layer_norm(last_hidden_state)
  565. if self.eos_token_id == 2:
  566. # The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here.
  567. # A CLIPSeg model with such `eos_token_id` in the config can't work correctly with extra new tokens added
  568. # ------------------------------------------------------------
  569. # text_embeds.shape = [batch_size, sequence_length, transformer.width]
  570. # take features from the eot embedding (eot_token is the highest number in each sequence)
  571. # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
  572. pooled_output = last_hidden_state[
  573. torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
  574. input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
  575. ]
  576. else:
  577. # The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
  578. pooled_output = last_hidden_state[
  579. torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
  580. # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`)
  581. # Note: we assume each sequence (along batch dim.) contains an `eos_token_id` (e.g. prepared by the tokenizer)
  582. (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id)
  583. .int()
  584. .argmax(dim=-1),
  585. ]
  586. if not return_dict:
  587. return (last_hidden_state, pooled_output) + encoder_outputs[1:]
  588. return BaseModelOutputWithPooling(
  589. last_hidden_state=last_hidden_state,
  590. pooler_output=pooled_output,
  591. hidden_states=encoder_outputs.hidden_states,
  592. attentions=encoder_outputs.attentions,
  593. )
  594. class CLIPSegTextModel(CLIPSegPreTrainedModel):
  595. config: CLIPSegTextConfig
  596. input_modalities = ("text",)
  597. _no_split_modules = ["CLIPSegTextEmbeddings", "CLIPSegEncoderLayer"]
  598. def __init__(self, config: CLIPSegTextConfig):
  599. super().__init__(config)
  600. self.text_model = CLIPSegTextTransformer(config)
  601. # Initialize weights and apply final processing
  602. self.post_init()
  603. def get_input_embeddings(self) -> nn.Module:
  604. return self.text_model.embeddings.token_embedding
  605. def set_input_embeddings(self, value):
  606. self.text_model.embeddings.token_embedding = value
  607. @merge_with_config_defaults
  608. @capture_outputs
  609. @auto_docstring
  610. def forward(
  611. self,
  612. input_ids: torch.Tensor | None = None,
  613. attention_mask: torch.Tensor | None = None,
  614. position_ids: torch.Tensor | None = None,
  615. **kwargs: Unpack[TransformersKwargs],
  616. ) -> tuple | BaseModelOutputWithPooling:
  617. r"""
  618. Examples:
  619. ```python
  620. >>> from transformers import AutoTokenizer, CLIPSegTextModel
  621. >>> tokenizer = AutoTokenizer.from_pretrained("CIDAS/clipseg-rd64-refined")
  622. >>> model = CLIPSegTextModel.from_pretrained("CIDAS/clipseg-rd64-refined")
  623. >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
  624. >>> outputs = model(**inputs)
  625. >>> last_hidden_state = outputs.last_hidden_state
  626. >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
  627. ```"""
  628. return self.text_model(
  629. input_ids=input_ids,
  630. attention_mask=attention_mask,
  631. position_ids=position_ids,
  632. return_dict=True,
  633. **kwargs,
  634. )
  635. class CLIPSegVisionTransformer(nn.Module):
  636. # Copied from transformers.models.altclip.modeling_altclip.AltCLIPVisionTransformer.__init__ with AltCLIP->CLIPSeg
  637. def __init__(self, config: CLIPSegVisionConfig):
  638. super().__init__()
  639. self.config = config
  640. embed_dim = config.hidden_size
  641. self.embeddings = CLIPSegVisionEmbeddings(config)
  642. self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  643. self.encoder = CLIPSegEncoder(config)
  644. self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  645. @auto_docstring
  646. def forward(
  647. self,
  648. pixel_values: torch.FloatTensor | None,
  649. output_attentions: bool | None = None,
  650. output_hidden_states: bool | None = None,
  651. return_dict: bool | None = None,
  652. interpolate_pos_encoding: bool | None = True,
  653. ) -> tuple | BaseModelOutputWithPooling:
  654. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  655. output_hidden_states = (
  656. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  657. )
  658. return_dict = return_dict if return_dict is not None else self.config.return_dict
  659. hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  660. hidden_states = self.pre_layrnorm(hidden_states)
  661. encoder_outputs = self.encoder(
  662. inputs_embeds=hidden_states,
  663. output_attentions=output_attentions,
  664. output_hidden_states=output_hidden_states,
  665. return_dict=return_dict,
  666. )
  667. last_hidden_state = encoder_outputs[0]
  668. pooled_output = last_hidden_state[:, 0, :]
  669. pooled_output = self.post_layernorm(pooled_output)
  670. if not return_dict:
  671. return (last_hidden_state, pooled_output) + encoder_outputs[1:]
  672. return BaseModelOutputWithPooling(
  673. last_hidden_state=last_hidden_state,
  674. pooler_output=pooled_output,
  675. hidden_states=encoder_outputs.hidden_states,
  676. attentions=encoder_outputs.attentions,
  677. )
  678. class CLIPSegVisionModel(CLIPSegPreTrainedModel):
  679. config: CLIPSegVisionConfig
  680. main_input_name = "pixel_values"
  681. input_modalities = ("image",)
  682. def __init__(self, config: CLIPSegVisionConfig):
  683. super().__init__(config)
  684. self.vision_model = CLIPSegVisionTransformer(config)
  685. # Initialize weights and apply final processing
  686. self.post_init()
  687. def get_input_embeddings(self) -> nn.Module:
  688. return self.vision_model.embeddings.patch_embedding
  689. @merge_with_config_defaults
  690. @capture_outputs(tie_last_hidden_states=False)
  691. @auto_docstring
  692. def forward(
  693. self,
  694. pixel_values: torch.FloatTensor | None = None,
  695. interpolate_pos_encoding: bool | None = True,
  696. **kwargs: Unpack[TransformersKwargs],
  697. ) -> tuple | BaseModelOutputWithPooling:
  698. r"""
  699. Examples:
  700. ```python
  701. >>> from PIL import Image
  702. >>> import httpx
  703. >>> from io import BytesIO
  704. >>> from transformers import AutoProcessor, CLIPSegVisionModel
  705. >>> processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
  706. >>> model = CLIPSegVisionModel.from_pretrained("CIDAS/clipseg-rd64-refined")
  707. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  708. >>> with httpx.stream("GET", url) as response:
  709. ... image = Image.open(BytesIO(response.read()))
  710. >>> inputs = processor(images=image, return_tensors="pt")
  711. >>> outputs = model(**inputs)
  712. >>> last_hidden_state = outputs.last_hidden_state
  713. >>> pooled_output = outputs.pooler_output # pooled CLS states
  714. ```"""
  715. return self.vision_model(
  716. pixel_values=pixel_values,
  717. interpolate_pos_encoding=interpolate_pos_encoding,
  718. **kwargs,
  719. )
  720. @auto_docstring
  721. class CLIPSegModel(CLIPSegPreTrainedModel):
  722. config: CLIPSegConfig
  723. def __init__(self, config: CLIPSegConfig):
  724. super().__init__(config)
  725. if not isinstance(config.text_config, CLIPSegTextConfig):
  726. raise TypeError(
  727. "config.text_config is expected to be of type CLIPSegTextConfig but is of type"
  728. f" {type(config.text_config)}."
  729. )
  730. if not isinstance(config.vision_config, CLIPSegVisionConfig):
  731. raise TypeError(
  732. "config.vision_config is expected to be of type CLIPSegVisionConfig but is of type"
  733. f" {type(config.vision_config)}."
  734. )
  735. text_config = config.text_config
  736. vision_config = config.vision_config
  737. # The module using it is not a PreTrainedModel subclass so we need this
  738. text_config._attn_implementation = config._attn_implementation
  739. # The module using it is not a PreTrainedModel subclass so we need this
  740. vision_config._attn_implementation = config._attn_implementation
  741. self.projection_dim = config.projection_dim
  742. self.text_embed_dim = text_config.hidden_size
  743. self.vision_embed_dim = vision_config.hidden_size
  744. self.text_model = CLIPSegTextTransformer(text_config)
  745. self.vision_model = CLIPSegVisionTransformer(vision_config)
  746. self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
  747. self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
  748. self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
  749. # Initialize weights and apply final processing
  750. self.post_init()
  751. @merge_with_config_defaults
  752. @capture_outputs
  753. @auto_docstring
  754. def get_text_features(
  755. self,
  756. input_ids: torch.Tensor,
  757. attention_mask: torch.Tensor | None = None,
  758. position_ids: torch.Tensor | None = None,
  759. **kwargs: Unpack[TransformersKwargs],
  760. ) -> tuple | BaseModelOutputWithPooling:
  761. r"""
  762. Examples:
  763. ```python
  764. >>> import torch
  765. >>> from transformers import AutoTokenizer, CLIPSegModel
  766. >>> tokenizer = AutoTokenizer.from_pretrained("CIDAS/clipseg-rd64-refined")
  767. >>> model = CLIPSegModel.from_pretrained("CIDAS/clipseg-rd64-refined")
  768. >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
  769. >>> with torch.inference_mode():
  770. ... text_features = model.get_text_features(**inputs)
  771. ```"""
  772. text_outputs: BaseModelOutputWithPooling = self.text_model(
  773. input_ids=input_ids,
  774. attention_mask=attention_mask,
  775. position_ids=position_ids,
  776. **kwargs,
  777. )
  778. pooled_output = text_outputs.pooler_output
  779. text_outputs.pooler_output = self.text_projection(pooled_output)
  780. return text_outputs
  781. @merge_with_config_defaults
  782. @capture_outputs(tie_last_hidden_states=False)
  783. @auto_docstring
  784. def get_image_features(
  785. self,
  786. pixel_values: torch.FloatTensor,
  787. interpolate_pos_encoding: bool = True,
  788. **kwargs: Unpack[TransformersKwargs],
  789. ) -> tuple | BaseModelOutputWithPooling:
  790. r"""
  791. Examples:
  792. ```python
  793. >>> import torch
  794. >>> from transformers import AutoProcessor, CLIPSegModel
  795. >>> from transformers.image_utils import load_image
  796. >>> processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
  797. >>> model = CLIPSegModel.from_pretrained("CIDAS/clipseg-rd64-refined")
  798. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  799. >>> image = load_image(url)
  800. >>> inputs = processor(images=image, return_tensors="pt")
  801. >>> with torch.inference_mode():
  802. ... image_features = model.get_image_features(**inputs)
  803. ```"""
  804. vision_outputs: BaseModelOutputWithPooling = self.vision_model(
  805. pixel_values=pixel_values,
  806. interpolate_pos_encoding=interpolate_pos_encoding,
  807. **kwargs,
  808. )
  809. pooled_output = vision_outputs.pooler_output
  810. vision_outputs.pooler_output = self.visual_projection(pooled_output)
  811. return vision_outputs
  812. @can_return_tuple
  813. @auto_docstring
  814. def forward(
  815. self,
  816. input_ids: torch.LongTensor | None = None,
  817. pixel_values: torch.FloatTensor | None = None,
  818. attention_mask: torch.Tensor | None = None,
  819. position_ids: torch.LongTensor | None = None,
  820. return_loss: bool | None = None,
  821. interpolate_pos_encoding: bool = True,
  822. **kwargs: Unpack[TransformersKwargs],
  823. ) -> tuple | CLIPSegOutput:
  824. r"""
  825. return_loss (`bool`, *optional*):
  826. Whether or not to return the contrastive loss.
  827. Examples:
  828. ```python
  829. >>> import torch
  830. >>> from transformers import AutoProcessor, CLIPSegModel
  831. >>> from transformers.image_utils import load_image
  832. >>> processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
  833. >>> model = CLIPSegModel.from_pretrained("CIDAS/clipseg-rd64-refined")
  834. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  835. >>> image = load_image(url)
  836. >>> inputs = processor(
  837. ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
  838. ... )
  839. >>> with torch.inference_mode():
  840. ... outputs = model(**inputs)
  841. >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
  842. >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
  843. ```"""
  844. vision_outputs = self.get_image_features(
  845. pixel_values=pixel_values,
  846. interpolate_pos_encoding=interpolate_pos_encoding,
  847. **kwargs,
  848. )
  849. text_outputs = self.get_text_features(
  850. input_ids=input_ids,
  851. attention_mask=attention_mask,
  852. position_ids=position_ids,
  853. **kwargs,
  854. )
  855. image_embeds = vision_outputs.pooler_output
  856. text_embeds = text_outputs.pooler_output
  857. # normalized features
  858. image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
  859. text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
  860. # cosine similarity as logits
  861. logit_scale = self.logit_scale.exp()
  862. logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
  863. logits_per_image = logits_per_text.t()
  864. loss = None
  865. if return_loss:
  866. loss = clipseg_loss(logits_per_text)
  867. return CLIPSegOutput(
  868. loss=loss,
  869. logits_per_image=logits_per_image,
  870. logits_per_text=logits_per_text,
  871. text_embeds=text_embeds,
  872. image_embeds=image_embeds,
  873. text_model_output=text_outputs,
  874. vision_model_output=vision_outputs,
  875. )
  876. @auto_docstring(
  877. custom_intro="""
  878. CLIPSeg model with a Transformer-based decoder on top for zero-shot and one-shot image segmentation.
  879. """
  880. )
  881. class CLIPSegForImageSegmentation(CLIPSegPreTrainedModel):
  882. config: CLIPSegConfig
  883. def __init__(self, config: CLIPSegConfig):
  884. super().__init__(config)
  885. self.config = config
  886. self.clip = CLIPSegModel(config)
  887. self.extract_layers = config.extract_layers
  888. self.decoder = CLIPSegDecoder(config)
  889. # Initialize weights and apply final processing
  890. self.post_init()
  891. def get_conditional_embeddings(
  892. self,
  893. batch_size: int | None = None,
  894. input_ids: torch.Tensor | None = None,
  895. attention_mask: torch.Tensor | None = None,
  896. position_ids: torch.Tensor | None = None,
  897. conditional_pixel_values: torch.Tensor | None = None,
  898. ):
  899. if input_ids is not None:
  900. # compute conditional embeddings from texts
  901. if len(input_ids) != batch_size:
  902. raise ValueError("Make sure to pass as many prompt texts as there are query images")
  903. with torch.no_grad():
  904. conditional_embeddings = self.clip.get_text_features(
  905. input_ids, attention_mask=attention_mask, position_ids=position_ids
  906. ).pooler_output
  907. elif conditional_pixel_values is not None:
  908. # compute conditional embeddings from images
  909. if len(conditional_pixel_values) != batch_size:
  910. raise ValueError("Make sure to pass as many prompt images as there are query images")
  911. with torch.no_grad():
  912. conditional_embeddings = self.clip.get_image_features(conditional_pixel_values).pooler_output
  913. else:
  914. raise ValueError(
  915. "Invalid conditional, should be either provided as `input_ids` or `conditional_pixel_values`"
  916. )
  917. return conditional_embeddings
  918. @can_return_tuple
  919. @auto_docstring
  920. def forward(
  921. self,
  922. input_ids: torch.FloatTensor | None = None,
  923. pixel_values: torch.FloatTensor | None = None,
  924. conditional_pixel_values: torch.FloatTensor | None = None,
  925. conditional_embeddings: torch.FloatTensor | None = None,
  926. attention_mask: torch.Tensor | None = None,
  927. position_ids: torch.LongTensor | None = None,
  928. labels: torch.LongTensor | None = None,
  929. interpolate_pos_encoding: bool = True,
  930. **kwargs: Unpack[TransformersKwargs],
  931. ) -> tuple | CLIPSegOutput:
  932. r"""
  933. conditional_pixel_values (`torch.FloatTensor`, *optional*):
  934. The pixel values of the conditional images.
  935. conditional_embeddings (`torch.FloatTensor` of shape `(batch_size, config.projection_dim)`, *optional*):
  936. The conditional embeddings for the query images. If provided, the model will use this instead of computing
  937. the embeddings from the conditional_pixel_values.
  938. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  939. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  940. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  941. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  942. Examples:
  943. ```python
  944. >>> import torch
  945. >>> from transformers import AutoProcessor, CLIPSegForImageSegmentation
  946. >>> from transformers.image_utils import load_image
  947. >>> processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
  948. >>> model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
  949. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  950. >>> image = load_image(url)
  951. >>> texts = ["a cat", "a remote", "a blanket"]
  952. >>> inputs = processor(text=texts, images=[image] * len(texts), padding=True, return_tensors="pt")
  953. >>> with torch.inference_mode():
  954. ... outputs = model(**inputs)
  955. >>> logits = outputs.logits
  956. >>> print(logits.shape)
  957. torch.Size([3, 352, 352])
  958. ```"""
  959. # step 1: forward the query images through the frozen CLIP vision encoder
  960. with torch.no_grad():
  961. kwargs["output_hidden_states"] = True # required to extract layers for the stages
  962. vision_outputs = self.clip.get_image_features(
  963. pixel_values=pixel_values,
  964. interpolate_pos_encoding=interpolate_pos_encoding,
  965. **kwargs,
  966. )
  967. pooled_output = vision_outputs.pooler_output
  968. hidden_states = vision_outputs.hidden_states
  969. # we add +1 here as the hidden states also include the initial embeddings
  970. activations = [hidden_states[i + 1] for i in self.extract_layers]
  971. # update vision_outputs
  972. vision_outputs = BaseModelOutputWithPooling(
  973. last_hidden_state=vision_outputs.last_hidden_state,
  974. pooler_output=vision_outputs.pooler_output,
  975. hidden_states=vision_outputs.hidden_states,
  976. attentions=vision_outputs.attentions,
  977. )
  978. # step 2: compute conditional embeddings, either from text, images or an own provided embedding
  979. if conditional_embeddings is None:
  980. conditional_embeddings = self.get_conditional_embeddings(
  981. batch_size=pixel_values.shape[0],
  982. input_ids=input_ids,
  983. attention_mask=attention_mask,
  984. position_ids=position_ids,
  985. conditional_pixel_values=conditional_pixel_values,
  986. )
  987. else:
  988. if conditional_embeddings.shape[0] != pixel_values.shape[0]:
  989. raise ValueError(
  990. "Make sure to pass as many conditional embeddings as there are query images in the batch"
  991. )
  992. if conditional_embeddings.shape[1] != self.config.projection_dim:
  993. raise ValueError(
  994. "Make sure that the feature dimension of the conditional embeddings matches"
  995. " `config.projection_dim`."
  996. )
  997. # step 3: forward both the pooled output and the activations through the lightweight decoder to predict masks
  998. decoder_outputs = self.decoder(
  999. activations,
  1000. conditional_embeddings,
  1001. **kwargs,
  1002. )
  1003. logits = decoder_outputs.logits
  1004. loss = None
  1005. if labels is not None:
  1006. # move labels to the correct device to enable PP
  1007. labels = labels.to(logits.device)
  1008. loss_fn = nn.BCEWithLogitsLoss()
  1009. loss = loss_fn(logits, labels)
  1010. return CLIPSegImageSegmentationOutput(
  1011. loss=loss,
  1012. logits=logits,
  1013. conditional_embeddings=conditional_embeddings,
  1014. pooled_output=pooled_output,
  1015. vision_model_output=vision_outputs,
  1016. decoder_output=decoder_outputs,
  1017. )
# Public API of this module: the names exported by a wildcard import.
__all__ = [
    "CLIPSegModel",
    "CLIPSegPreTrainedModel",
    "CLIPSegTextModel",
    "CLIPSegVisionModel",
    "CLIPSegForImageSegmentation",
]