modeling_clip.py 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149
  1. # Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch CLIP model."""
  15. from collections.abc import Callable
  16. from dataclasses import dataclass
  17. from typing import Any
  18. import torch
  19. from torch import nn
  20. from ... import initialization as init
  21. from ...activations import ACT2FN
  22. from ...masking_utils import create_causal_mask
  23. from ...modeling_layers import GradientCheckpointingLayer
  24. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
  25. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  26. from ...processing_utils import Unpack
  27. from ...utils import (
  28. ModelOutput,
  29. TransformersKwargs,
  30. auto_docstring,
  31. logging,
  32. torch_int,
  33. )
  34. from ...utils.generic import can_return_tuple, merge_with_config_defaults
  35. from ...utils.output_capturing import capture_outputs
  36. from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
  37. logger = logging.get_logger(__name__)
  38. # contrastive loss function, adapted from
  39. # https://sachinruk.github.io/blog/2021-03-07-clip.html
  40. def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
  41. return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
  42. def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
  43. caption_loss = contrastive_loss(similarity)
  44. image_loss = contrastive_loss(similarity.t())
  45. return (caption_loss + image_loss) / 2.0
  46. def _get_vector_norm(tensor: torch.Tensor) -> torch.Tensor:
  47. """
  48. This method is equivalent to tensor.norm(p=2, dim=-1, keepdim=True) and used to make
  49. model `executorch` exportable. See issue https://github.com/pytorch/executorch/issues/3566
  50. """
  51. square_tensor = torch.pow(tensor, 2)
  52. sum_tensor = torch.sum(square_tensor, dim=-1, keepdim=True)
  53. normed_tensor = torch.pow(sum_tensor, 0.5)
  54. return normed_tensor
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
    """
)
class CLIPVisionModelOutput(ModelOutput):
    r"""
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
        The image embeddings obtained by applying the projection layer to the pooler_output.
    """

    # Projected image embeddings; defaults to None (the docstring marks it optional).
    image_embeds: torch.FloatTensor | None = None
    # The remaining fields are not described in the docstring above; presumably `auto_docstring`
    # supplies its standard descriptions for these common output names.
    last_hidden_state: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for text model's outputs that also contains a pooling of the last hidden states.
    """
)
class CLIPTextModelOutput(ModelOutput):
    r"""
    text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
        The text embeddings obtained by applying the projection layer to the pooler_output.
    """

    # Projected text embeddings; defaults to None (the docstring marks it optional).
    text_embeds: torch.FloatTensor | None = None
    # The remaining fields are not described in the docstring above; presumably `auto_docstring`
    # supplies its standard descriptions for these common output names.
    last_hidden_state: torch.FloatTensor | None = None
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
@dataclass
@auto_docstring
class CLIPOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
        Contrastive loss for image-text similarity.
    logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
        The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
        similarity scores.
    logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
        The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
        similarity scores.
    text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
        The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPTextModel`].
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
        The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPVisionModel`].
    text_model_output (`BaseModelOutputWithPooling`):
        The output of the [`CLIPTextModel`].
    vision_model_output (`BaseModelOutputWithPooling`):
        The output of the [`CLIPVisionModel`].
    """

    loss: torch.FloatTensor | None = None
    logits_per_image: torch.FloatTensor | None = None
    logits_per_text: torch.FloatTensor | None = None
    text_embeds: torch.FloatTensor | None = None
    image_embeds: torch.FloatTensor | None = None
    # Nested sub-model outputs; these default to None, so the annotations say so.
    text_model_output: BaseModelOutputWithPooling | None = None
    vision_model_output: BaseModelOutputWithPooling | None = None

    def to_tuple(self) -> tuple[Any, ...]:
        """
        Convert to a plain tuple of the values in `self.values()` order.

        Any value that is itself a `ModelOutput` (e.g. `text_model_output` / `vision_model_output`) is
        converted with its own `to_tuple()`, so the result contains no `ModelOutput` instances at the top level.
        """
        return tuple(v.to_tuple() if isinstance(v, ModelOutput) else v for v in self.values())
  115. class CLIPVisionEmbeddings(nn.Module):
  116. def __init__(self, config: CLIPVisionConfig):
  117. super().__init__()
  118. self.config = config
  119. self.embed_dim = config.hidden_size
  120. self.image_size = config.image_size
  121. self.patch_size = config.patch_size
  122. self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
  123. self.patch_embedding = nn.Conv2d(
  124. in_channels=config.num_channels,
  125. out_channels=self.embed_dim,
  126. kernel_size=self.patch_size,
  127. stride=self.patch_size,
  128. bias=False,
  129. )
  130. self.num_patches = (self.image_size // self.patch_size) ** 2
  131. self.num_positions = self.num_patches + 1
  132. self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
  133. self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
  134. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  135. """
  136. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  137. images. This method is also adapted to support torch.jit tracing.
  138. Adapted from:
  139. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  140. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  141. """
  142. num_patches = embeddings.shape[1] - 1
  143. position_embedding = self.position_embedding.weight.unsqueeze(0)
  144. num_positions = position_embedding.shape[1] - 1
  145. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  146. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  147. return self.position_embedding(self.position_ids)
  148. class_pos_embed = position_embedding[:, :1]
  149. patch_pos_embed = position_embedding[:, 1:]
  150. dim = embeddings.shape[-1]
  151. new_height = height // self.patch_size
  152. new_width = width // self.patch_size
  153. sqrt_num_positions = torch_int(num_positions**0.5)
  154. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  155. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  156. patch_pos_embed = nn.functional.interpolate(
  157. patch_pos_embed,
  158. size=(new_height, new_width),
  159. mode="bicubic",
  160. align_corners=False,
  161. )
  162. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  163. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  164. def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
  165. batch_size, _, height, width = pixel_values.shape
  166. if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
  167. raise ValueError(
  168. f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})."
  169. )
  170. target_dtype = self.patch_embedding.weight.dtype
  171. patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
  172. patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
  173. class_embeds = self.class_embedding.expand(batch_size, 1, -1)
  174. embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
  175. if interpolate_pos_encoding:
  176. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  177. else:
  178. embeddings = embeddings + self.position_embedding(self.position_ids)
  179. return embeddings
  180. class CLIPTextEmbeddings(nn.Module):
  181. def __init__(self, config: CLIPTextConfig):
  182. super().__init__()
  183. embed_dim = config.hidden_size
  184. self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
  185. self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
  186. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  187. self.register_buffer(
  188. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  189. )
  190. def forward(
  191. self,
  192. input_ids: torch.LongTensor | None = None,
  193. position_ids: torch.LongTensor | None = None,
  194. inputs_embeds: torch.FloatTensor | None = None,
  195. ) -> torch.Tensor:
  196. seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
  197. max_position_embedding = self.position_embedding.weight.shape[0]
  198. if seq_length > max_position_embedding:
  199. raise ValueError(
  200. f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
  201. f"{seq_length} and max_position_embeddings: {max_position_embedding}"
  202. )
  203. if position_ids is None:
  204. position_ids = self.position_ids[:, :seq_length]
  205. if inputs_embeds is None:
  206. inputs_embeds = self.token_embedding(input_ids)
  207. position_embeddings = self.position_embedding(position_ids)
  208. embeddings = inputs_embeds + position_embeddings
  209. return embeddings
  210. def eager_attention_forward(
  211. module: nn.Module,
  212. query: torch.Tensor,
  213. key: torch.Tensor,
  214. value: torch.Tensor,
  215. attention_mask: torch.Tensor | None,
  216. scaling: float,
  217. dropout: float = 0.0,
  218. **kwargs: Unpack[TransformersKwargs],
  219. ):
  220. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  221. if attention_mask is not None:
  222. attn_weights = attn_weights + attention_mask
  223. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  224. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  225. attn_output = torch.matmul(attn_weights, value)
  226. attn_output = attn_output.transpose(1, 2).contiguous()
  227. return attn_output, attn_weights
  228. class CLIPAttention(nn.Module):
  229. """Multi-headed attention from 'Attention Is All You Need' paper"""
  230. def __init__(self, config: CLIPVisionConfig | CLIPTextConfig):
  231. super().__init__()
  232. self.config = config
  233. self.embed_dim = config.hidden_size
  234. self.num_heads = config.num_attention_heads
  235. self.head_dim = self.embed_dim // self.num_heads
  236. self.scale = self.head_dim**-0.5
  237. self.dropout = config.attention_dropout
  238. self.is_causal = False
  239. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  240. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  241. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  242. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  243. def forward(
  244. self,
  245. hidden_states: torch.Tensor,
  246. attention_mask: torch.Tensor | None = None,
  247. **kwargs: Unpack[TransformersKwargs],
  248. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  249. """Input shape: Batch x Time x Channel"""
  250. input_shape = hidden_states.shape[:-1]
  251. hidden_shape = (*input_shape, -1, self.head_dim)
  252. queries = self.q_proj(hidden_states)
  253. keys = self.k_proj(hidden_states)
  254. values = self.v_proj(hidden_states)
  255. queries = queries.view(hidden_shape).transpose(1, 2)
  256. keys = keys.view(hidden_shape).transpose(1, 2)
  257. values = values.view(hidden_shape).transpose(1, 2)
  258. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  259. self.config._attn_implementation, eager_attention_forward
  260. )
  261. attn_output, attn_weights = attention_interface(
  262. self,
  263. queries,
  264. keys,
  265. values,
  266. attention_mask,
  267. scaling=self.scale,
  268. dropout=0.0 if not self.training else self.dropout,
  269. **kwargs,
  270. )
  271. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  272. attn_output = self.out_proj(attn_output)
  273. return attn_output, attn_weights
  274. class CLIPMLP(nn.Module):
  275. def __init__(self, config):
  276. super().__init__()
  277. self.config = config
  278. self.activation_fn = ACT2FN[config.hidden_act]
  279. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  280. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  281. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  282. hidden_states = self.fc1(hidden_states)
  283. hidden_states = self.activation_fn(hidden_states)
  284. hidden_states = self.fc2(hidden_states)
  285. return hidden_states
  286. class CLIPEncoderLayer(GradientCheckpointingLayer):
  287. def __init__(self, config: CLIPVisionConfig | CLIPTextConfig):
  288. super().__init__()
  289. self.embed_dim = config.hidden_size
  290. self.self_attn = CLIPAttention(config)
  291. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  292. self.mlp = CLIPMLP(config)
  293. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  294. def forward(
  295. self,
  296. hidden_states: torch.Tensor,
  297. attention_mask: torch.Tensor,
  298. **kwargs: Unpack[TransformersKwargs],
  299. ) -> torch.FloatTensor:
  300. residual = hidden_states
  301. hidden_states = self.layer_norm1(hidden_states)
  302. hidden_states, _ = self.self_attn(
  303. hidden_states=hidden_states,
  304. attention_mask=attention_mask,
  305. **kwargs,
  306. )
  307. hidden_states = residual + hidden_states
  308. residual = hidden_states
  309. hidden_states = self.layer_norm2(hidden_states)
  310. hidden_states = self.mlp(hidden_states)
  311. hidden_states = residual + hidden_states
  312. return hidden_states
  313. @auto_docstring
  314. class CLIPPreTrainedModel(PreTrainedModel):
  315. config: CLIPConfig
  316. base_model_prefix = "clip"
  317. input_modalities = ("image", "text")
  318. supports_gradient_checkpointing = True
  319. _supports_sdpa = True
  320. _supports_flash_attn = True
  321. _supports_flex_attn = True
  322. _supports_attention_backend = True
  323. _can_record_outputs = {
  324. "hidden_states": CLIPEncoderLayer,
  325. "attentions": CLIPAttention,
  326. }
  327. @torch.no_grad()
  328. def _init_weights(self, module):
  329. """Initialize the weights"""
  330. factor = self.config.initializer_factor
  331. if isinstance(module, CLIPTextEmbeddings):
  332. init.normal_(module.token_embedding.weight, mean=0.0, std=factor * 0.02)
  333. init.normal_(module.position_embedding.weight, mean=0.0, std=factor * 0.02)
  334. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  335. elif isinstance(module, CLIPVisionEmbeddings):
  336. factor = self.config.initializer_factor
  337. init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
  338. init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
  339. init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
  340. init.copy_(module.position_ids, torch.arange(module.num_positions).expand((1, -1)))
  341. elif isinstance(module, CLIPAttention):
  342. factor = self.config.initializer_factor
  343. in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
  344. out_proj_std = (module.embed_dim**-0.5) * factor
  345. init.normal_(module.q_proj.weight, std=in_proj_std)
  346. init.normal_(module.k_proj.weight, std=in_proj_std)
  347. init.normal_(module.v_proj.weight, std=in_proj_std)
  348. init.normal_(module.out_proj.weight, std=out_proj_std)
  349. elif isinstance(module, CLIPMLP):
  350. factor = self.config.initializer_factor
  351. in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
  352. fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
  353. init.normal_(module.fc1.weight, std=fc_std)
  354. init.normal_(module.fc2.weight, std=in_proj_std)
  355. elif isinstance(module, CLIPModel):
  356. init.normal_(
  357. module.text_projection.weight,
  358. std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
  359. )
  360. init.normal_(
  361. module.visual_projection.weight,
  362. std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
  363. )
  364. elif isinstance(module, CLIPVisionModelWithProjection):
  365. init.normal_(
  366. module.visual_projection.weight,
  367. std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
  368. )
  369. elif isinstance(module, CLIPTextModelWithProjection):
  370. init.normal_(
  371. module.text_projection.weight,
  372. std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
  373. )
  374. elif isinstance(module, CLIPForImageClassification):
  375. init.normal_(
  376. module.classifier.weight,
  377. std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor,
  378. )
  379. if isinstance(module, nn.LayerNorm):
  380. init.zeros_(module.bias)
  381. init.ones_(module.weight)
  382. if isinstance(module, nn.Linear) and module.bias is not None:
  383. init.zeros_(module.bias)
  384. class CLIPEncoder(nn.Module):
  385. """
  386. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  387. [`CLIPEncoderLayer`].
  388. Args:
  389. config: CLIPConfig
  390. """
  391. def __init__(self, config: CLIPConfig):
  392. super().__init__()
  393. self.config = config
  394. self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  395. self.gradient_checkpointing = False
  396. def forward(
  397. self,
  398. inputs_embeds,
  399. attention_mask: torch.Tensor | None = None,
  400. **kwargs: Unpack[TransformersKwargs],
  401. ) -> BaseModelOutput:
  402. r"""
  403. Args:
  404. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  405. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  406. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  407. than the model's internal embedding lookup matrix.
  408. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  409. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  410. - 1 for tokens that are **not masked**,
  411. - 0 for tokens that are **masked**.
  412. [What are attention masks?](../glossary#attention-mask)
  413. """
  414. hidden_states = inputs_embeds
  415. for encoder_layer in self.layers:
  416. hidden_states = encoder_layer(
  417. hidden_states,
  418. attention_mask,
  419. **kwargs,
  420. )
  421. return BaseModelOutput(
  422. last_hidden_state=hidden_states,
  423. )
  424. class CLIPTextTransformer(CLIPPreTrainedModel):
  425. config: CLIPTextConfig
  426. input_modalities = ("text",)
  427. _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
  428. def __init__(self, config: CLIPTextConfig):
  429. super().__init__(config)
  430. self.config = config
  431. embed_dim = config.hidden_size
  432. self.embeddings = CLIPTextEmbeddings(config)
  433. self.encoder = CLIPEncoder(config)
  434. self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  435. # For `pooled_output` computation
  436. self.eos_token_id = config.eos_token_id
  437. self.post_init()
  438. @merge_with_config_defaults
  439. @capture_outputs(tie_last_hidden_states=False)
  440. @auto_docstring
  441. def forward(
  442. self,
  443. input_ids: torch.Tensor | None = None,
  444. attention_mask: torch.Tensor | None = None,
  445. position_ids: torch.Tensor | None = None,
  446. **kwargs: Unpack[TransformersKwargs],
  447. ) -> BaseModelOutputWithPooling:
  448. if input_ids is None:
  449. raise ValueError("You have to specify input_ids")
  450. input_shape = input_ids.size()
  451. input_ids = input_ids.view(-1, input_shape[-1])
  452. hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
  453. attention_mask = create_causal_mask(
  454. config=self.config,
  455. inputs_embeds=hidden_states,
  456. attention_mask=attention_mask,
  457. past_key_values=None,
  458. )
  459. kwargs.pop("is_causal", None)
  460. encoder_outputs: BaseModelOutput = self.encoder(
  461. inputs_embeds=hidden_states,
  462. attention_mask=attention_mask,
  463. is_causal=True,
  464. **kwargs,
  465. )
  466. last_hidden_state = encoder_outputs.last_hidden_state
  467. last_hidden_state = self.final_layer_norm(last_hidden_state)
  468. if self.eos_token_id == 2:
  469. # The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here.
  470. # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
  471. # ------------------------------------------------------------
  472. # text_embeds.shape = [batch_size, sequence_length, transformer.width]
  473. # take features from the eot embedding (eot_token is the highest number in each sequence)
  474. # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
  475. pooled_output = last_hidden_state[
  476. torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
  477. input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
  478. ]
  479. else:
  480. # The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
  481. pooled_output = last_hidden_state[
  482. torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
  483. # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`)
  484. # Note: we assume each sequence (along batch dim.) contains an `eos_token_id` (e.g. prepared by the tokenizer)
  485. (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id)
  486. .int()
  487. .argmax(dim=-1),
  488. ]
  489. return BaseModelOutputWithPooling(
  490. last_hidden_state=last_hidden_state,
  491. pooler_output=pooled_output,
  492. )
  493. @auto_docstring(
  494. custom_intro="""
  495. The text model from CLIP without any head or projection on top.
  496. """
  497. )
  498. class CLIPTextModel(CLIPPreTrainedModel):
  499. config: CLIPTextConfig
  500. input_modalities = ("text",)
  501. _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
  502. def __init__(self, config: CLIPTextConfig):
  503. super().__init__(config)
  504. self.text_model = CLIPTextTransformer(config)
  505. # Initialize weights and apply final processing
  506. self.post_init()
  507. def get_input_embeddings(self) -> nn.Module:
  508. return self.text_model.embeddings.token_embedding
  509. def set_input_embeddings(self, value):
  510. self.text_model.embeddings.token_embedding = value
  511. @auto_docstring
  512. def forward(
  513. self,
  514. input_ids: torch.Tensor | None = None,
  515. attention_mask: torch.Tensor | None = None,
  516. position_ids: torch.Tensor | None = None,
  517. **kwargs: Unpack[TransformersKwargs],
  518. ) -> BaseModelOutputWithPooling:
  519. r"""
  520. Examples:
  521. ```python
  522. >>> from transformers import AutoTokenizer, CLIPTextModel
  523. >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
  524. >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
  525. >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
  526. >>> outputs = model(**inputs)
  527. >>> last_hidden_state = outputs.last_hidden_state
  528. >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
  529. ```"""
  530. return self.text_model(
  531. input_ids=input_ids,
  532. attention_mask=attention_mask,
  533. position_ids=position_ids,
  534. **kwargs,
  535. )
  536. class CLIPVisionTransformer(CLIPPreTrainedModel):
  537. config: CLIPVisionConfig
  538. main_input_name = "pixel_values"
  539. input_modalities = ("image",)
  540. _no_split_modules = ["CLIPEncoderLayer"]
  541. def __init__(self, config: CLIPVisionConfig):
  542. super().__init__(config)
  543. self.config = config
  544. embed_dim = config.hidden_size
  545. self.embeddings = CLIPVisionEmbeddings(config)
  546. self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  547. self.encoder = CLIPEncoder(config)
  548. self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  549. self.post_init()
  550. @merge_with_config_defaults
  551. @capture_outputs(tie_last_hidden_states=False)
  552. @auto_docstring
  553. def forward(
  554. self,
  555. pixel_values: torch.FloatTensor | None = None,
  556. interpolate_pos_encoding: bool | None = False,
  557. **kwargs: Unpack[TransformersKwargs],
  558. ) -> BaseModelOutputWithPooling:
  559. if pixel_values is None:
  560. raise ValueError("You have to specify pixel_values")
  561. hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  562. hidden_states = self.pre_layrnorm(hidden_states)
  563. encoder_outputs: BaseModelOutput = self.encoder(
  564. inputs_embeds=hidden_states,
  565. **kwargs,
  566. )
  567. last_hidden_state = encoder_outputs.last_hidden_state
  568. pooled_output = last_hidden_state[:, 0, :]
  569. pooled_output = self.post_layernorm(pooled_output)
  570. return BaseModelOutputWithPooling(
  571. last_hidden_state=last_hidden_state,
  572. pooler_output=pooled_output,
  573. )
  574. @auto_docstring(
  575. custom_intro="""
  576. The vision model from CLIP without any head or projection on top.
  577. """
  578. )
  579. class CLIPVisionModel(CLIPPreTrainedModel):
  580. config: CLIPVisionConfig
  581. main_input_name = "pixel_values"
  582. input_modalities = ("image",)
  583. _no_split_modules = ["CLIPEncoderLayer"]
  584. def __init__(self, config: CLIPVisionConfig):
  585. super().__init__(config)
  586. self.vision_model = CLIPVisionTransformer(config)
  587. # Initialize weights and apply final processing
  588. self.post_init()
  589. def get_input_embeddings(self) -> nn.Module:
  590. return self.vision_model.embeddings.patch_embedding
  591. @auto_docstring
  592. def forward(
  593. self,
  594. pixel_values: torch.FloatTensor | None = None,
  595. interpolate_pos_encoding: bool = False,
  596. **kwargs: Unpack[TransformersKwargs],
  597. ) -> BaseModelOutputWithPooling:
  598. r"""
  599. Example:
  600. ```python
  601. >>> from PIL import Image
  602. >>> import httpx
  603. >>> from io import BytesIO
  604. >>> from transformers import AutoProcessor, CLIPVisionModel
  605. >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
  606. >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
  607. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  608. >>> with httpx.stream("GET", url) as response:
  609. ... image = Image.open(BytesIO(response.read()))
  610. >>> inputs = processor(images=image, return_tensors="pt")
  611. >>> outputs = model(**inputs)
  612. >>> last_hidden_state = outputs.last_hidden_state
  613. >>> pooled_output = outputs.pooler_output # pooled CLS states
  614. ```"""
  615. return self.vision_model(
  616. pixel_values=pixel_values,
  617. interpolate_pos_encoding=interpolate_pos_encoding,
  618. **kwargs,
  619. )
  620. @auto_docstring
  621. class CLIPModel(CLIPPreTrainedModel):
  622. config: CLIPConfig
  623. _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer", "CLIPVisionEmbeddings"]
  624. def __init__(self, config: CLIPConfig):
  625. super().__init__(config)
  626. if not isinstance(config.text_config, CLIPTextConfig):
  627. raise TypeError(
  628. "config.text_config is expected to be of type CLIPTextConfig but is of type"
  629. f" {type(config.text_config)}."
  630. )
  631. if not isinstance(config.vision_config, CLIPVisionConfig):
  632. raise TypeError(
  633. "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
  634. f" {type(config.vision_config)}."
  635. )
  636. text_config = config.text_config
  637. vision_config = config.vision_config
  638. self.projection_dim = config.projection_dim
  639. self.text_embed_dim = text_config.hidden_size
  640. self.vision_embed_dim = vision_config.hidden_size
  641. text_model = CLIPTextModel._from_config(text_config)
  642. self.text_model = text_model.text_model
  643. vision_model = CLIPVisionModel._from_config(vision_config)
  644. self.vision_model = vision_model.vision_model
  645. self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
  646. self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
  647. self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
  648. # Initialize weights and apply final processing
  649. self.post_init()
  650. @can_return_tuple
  651. @auto_docstring
  652. def get_text_features(
  653. self,
  654. input_ids: torch.Tensor,
  655. attention_mask: torch.Tensor | None = None,
  656. position_ids: torch.Tensor | None = None,
  657. **kwargs: Unpack[TransformersKwargs],
  658. ) -> tuple | BaseModelOutputWithPooling:
  659. r"""
  660. Examples:
  661. ```python
  662. >>> import torch
  663. >>> from transformers import AutoTokenizer, CLIPModel
  664. >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
  665. >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
  666. >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
  667. >>> with torch.inference_mode():
  668. ... text_features = model.get_text_features(**inputs)
  669. ```"""
  670. text_outputs: BaseModelOutputWithPooling = self.text_model(
  671. input_ids=input_ids,
  672. attention_mask=attention_mask,
  673. position_ids=position_ids,
  674. return_dict=True,
  675. **kwargs,
  676. )
  677. pooled_output = text_outputs.pooler_output
  678. text_outputs.pooler_output = self.text_projection(pooled_output)
  679. return text_outputs
  680. @can_return_tuple
  681. @auto_docstring
  682. def get_image_features(
  683. self,
  684. pixel_values: torch.FloatTensor,
  685. interpolate_pos_encoding: bool = False,
  686. **kwargs: Unpack[TransformersKwargs],
  687. ) -> tuple | BaseModelOutputWithPooling:
  688. r"""
  689. Examples:
  690. ```python
  691. >>> import torch
  692. >>> from transformers import AutoProcessor, CLIPModel
  693. >>> from transformers.image_utils import load_image
  694. >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
  695. >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
  696. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  697. >>> image = load_image(url)
  698. >>> inputs = processor(images=image, return_tensors="pt")
  699. >>> with torch.inference_mode():
  700. ... image_features = model.get_image_features(**inputs)
  701. ```"""
  702. vision_outputs: BaseModelOutputWithPooling = self.vision_model(
  703. pixel_values=pixel_values,
  704. interpolate_pos_encoding=interpolate_pos_encoding,
  705. return_dict=True,
  706. **kwargs,
  707. )
  708. pooled_output = vision_outputs.pooler_output
  709. vision_outputs.pooler_output = self.visual_projection(pooled_output)
  710. return vision_outputs
  711. @can_return_tuple
  712. @auto_docstring
  713. def forward(
  714. self,
  715. input_ids: torch.LongTensor | None = None,
  716. pixel_values: torch.FloatTensor | None = None,
  717. attention_mask: torch.Tensor | None = None,
  718. position_ids: torch.LongTensor | None = None,
  719. return_loss: bool | None = None,
  720. interpolate_pos_encoding: bool = False,
  721. **kwargs: Unpack[TransformersKwargs],
  722. ) -> CLIPOutput:
  723. r"""
  724. return_loss (`bool`, *optional*):
  725. Whether or not to return the contrastive loss.
  726. Examples:
  727. ```python
  728. >>> import torch
  729. >>> from transformers import AutoProcessor, CLIPModel
  730. >>> from transformers.image_utils import load_image
  731. >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
  732. >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
  733. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  734. >>> image = load_image(url)
  735. >>> inputs = processor(
  736. ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
  737. ... )
  738. >>> with torch.inference_mode():
  739. ... outputs = model(**inputs)
  740. >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
  741. >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
  742. ```"""
  743. vision_outputs: BaseModelOutputWithPooling = self.vision_model(
  744. pixel_values=pixel_values,
  745. interpolate_pos_encoding=interpolate_pos_encoding,
  746. **kwargs,
  747. )
  748. text_outputs: BaseModelOutputWithPooling = self.text_model(
  749. input_ids=input_ids,
  750. attention_mask=attention_mask,
  751. position_ids=position_ids,
  752. **kwargs,
  753. )
  754. image_embeds = vision_outputs.pooler_output
  755. image_embeds = self.visual_projection(image_embeds)
  756. text_embeds = text_outputs.pooler_output
  757. text_embeds = self.text_projection(text_embeds)
  758. # normalized features
  759. image_embeds = image_embeds / _get_vector_norm(image_embeds)
  760. text_embeds = text_embeds / _get_vector_norm(text_embeds)
  761. # cosine similarity as logits
  762. logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device))
  763. logits_per_text = logits_per_text * self.logit_scale.exp().to(text_embeds.device)
  764. logits_per_image = logits_per_text.t()
  765. loss = None
  766. if return_loss:
  767. loss = clip_loss(logits_per_text)
  768. return CLIPOutput(
  769. loss=loss,
  770. logits_per_image=logits_per_image,
  771. logits_per_text=logits_per_text,
  772. text_embeds=text_embeds,
  773. image_embeds=image_embeds,
  774. text_model_output=text_outputs,
  775. vision_model_output=vision_outputs,
  776. )
  777. @auto_docstring
  778. class CLIPTextModelWithProjection(CLIPPreTrainedModel):
  779. config: CLIPTextConfig
  780. input_modalities = ("text",)
  781. _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
  782. def __init__(self, config: CLIPTextConfig):
  783. super().__init__(config)
  784. text_model = CLIPTextModel._from_config(config)
  785. self.text_model = text_model.text_model
  786. self.text_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
  787. # Initialize weights and apply final processing
  788. self.post_init()
  789. def get_input_embeddings(self) -> nn.Module:
  790. return self.text_model.embeddings.token_embedding
  791. def set_input_embeddings(self, value):
  792. self.text_model.embeddings.token_embedding = value
  793. @can_return_tuple
  794. @auto_docstring
  795. def forward(
  796. self,
  797. input_ids: torch.Tensor | None = None,
  798. attention_mask: torch.Tensor | None = None,
  799. position_ids: torch.Tensor | None = None,
  800. **kwargs: Unpack[TransformersKwargs],
  801. ) -> CLIPTextModelOutput:
  802. r"""
  803. Examples:
  804. ```python
  805. >>> import torch
  806. >>> from transformers import AutoTokenizer, CLIPTextModelWithProjection
  807. >>> model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
  808. >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
  809. >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
  810. >>> with torch.inference_mode():
  811. ... outputs = model(**inputs)
  812. >>> text_embeds = outputs.text_embeds
  813. ```"""
  814. text_outputs: BaseModelOutputWithPooling = self.text_model(
  815. input_ids=input_ids,
  816. attention_mask=attention_mask,
  817. position_ids=position_ids,
  818. **kwargs,
  819. )
  820. pooled_output = text_outputs.pooler_output
  821. text_embeds = self.text_projection(pooled_output)
  822. return CLIPTextModelOutput(
  823. text_embeds=text_embeds,
  824. last_hidden_state=text_outputs.last_hidden_state,
  825. hidden_states=text_outputs.hidden_states,
  826. attentions=text_outputs.attentions,
  827. )
  828. @auto_docstring
  829. class CLIPVisionModelWithProjection(CLIPPreTrainedModel):
  830. config: CLIPVisionConfig
  831. main_input_name = "pixel_values"
  832. input_modalities = ("image",)
  833. def __init__(self, config: CLIPVisionConfig):
  834. super().__init__(config)
  835. vision_model = CLIPVisionModel._from_config(config)
  836. self.vision_model = vision_model.vision_model
  837. self.visual_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
  838. # Initialize weights and apply final processing
  839. self.post_init()
  840. def get_input_embeddings(self) -> nn.Module:
  841. return self.vision_model.embeddings.patch_embedding
  842. @can_return_tuple
  843. @auto_docstring
  844. def forward(
  845. self,
  846. pixel_values: torch.FloatTensor | None = None,
  847. interpolate_pos_encoding: bool = False,
  848. **kwargs: Unpack[TransformersKwargs],
  849. ) -> CLIPVisionModelOutput:
  850. r"""
  851. Examples:
  852. ```python
  853. >>> import torch
  854. >>> from transformers import AutoProcessor, CLIPVisionModelWithProjection
  855. >>> from transformers.image_utils import load_image
  856. >>> model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
  857. >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
  858. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  859. >>> image = load_image(url)
  860. >>> inputs = processor(images=image, return_tensors="pt")
  861. >>> with torch.inference_mode():
  862. ... outputs = model(**inputs)
  863. >>> image_embeds = outputs.image_embeds
  864. ```"""
  865. vision_outputs: BaseModelOutputWithPooling = self.vision_model(
  866. pixel_values=pixel_values,
  867. interpolate_pos_encoding=interpolate_pos_encoding,
  868. **kwargs,
  869. )
  870. pooled_output = vision_outputs.pooler_output
  871. image_embeds = self.visual_projection(pooled_output)
  872. return CLIPVisionModelOutput(
  873. image_embeds=image_embeds,
  874. last_hidden_state=vision_outputs.last_hidden_state,
  875. hidden_states=vision_outputs.hidden_states,
  876. attentions=vision_outputs.attentions,
  877. )
  878. @auto_docstring(
  879. custom_intro="""
  880. CLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of
  881. the patch tokens) e.g. for ImageNet.
  882. """
  883. )
  884. class CLIPForImageClassification(CLIPPreTrainedModel):
  885. main_input_name = "pixel_values"
  886. input_modalities = ("image",)
  887. def __init__(self, config: CLIPConfig) -> None:
  888. super().__init__(config)
  889. self.num_labels = config.num_labels
  890. vision_model = CLIPVisionModel._from_config(config.vision_config)
  891. self.vision_model = vision_model.vision_model
  892. # Classifier head
  893. self.classifier = (
  894. nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
  895. )
  896. # Initialize weights and apply final processing
  897. self.post_init()
  898. @can_return_tuple
  899. @auto_docstring
  900. def forward(
  901. self,
  902. pixel_values: torch.Tensor | None = None,
  903. labels: torch.Tensor | None = None,
  904. **kwargs: Unpack[TransformersKwargs],
  905. ) -> ImageClassifierOutput:
  906. r"""
  907. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  908. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  909. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  910. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  911. """
  912. outputs: BaseModelOutputWithPooling = self.vision_model(
  913. pixel_values,
  914. **kwargs,
  915. )
  916. sequence_output = outputs.last_hidden_state
  917. sequence_output = torch.mean(sequence_output[:, 1:, :], dim=1)
  918. logits = self.classifier(sequence_output)
  919. loss = None
  920. if labels is not None:
  921. loss = self.loss_function(labels, logits, self.config)
  922. return ImageClassifierOutput(
  923. loss=loss,
  924. logits=logits,
  925. hidden_states=outputs.hidden_states,
  926. attentions=outputs.attentions,
  927. )
# Public API of this module. Keep this a plain multi-line list of quoted names, one per line,
# with no comments inside the brackets: Transformers' import-structure tooling reads this
# literal textually rather than importing the module.
__all__ = [
    "CLIPModel",
    "CLIPPreTrainedModel",
    "CLIPTextModel",
    "CLIPTextModelWithProjection",
    "CLIPVisionModel",
    "CLIPVisionModelWithProjection",
    "CLIPForImageClassification",
]