modeling_imagegpt.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831
  1. # Copyright 2021 The OpenAI Team Authors and HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch OpenAI ImageGPT model."""
  15. import math
  16. from typing import Any
  17. import torch
  18. from torch import nn
  19. from torch.nn import CrossEntropyLoss
  20. from ... import initialization as init
  21. from ...activations import ACT2FN
  22. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  23. from ...generation import GenerationMixin
  24. from ...modeling_layers import GradientCheckpointingLayer
  25. from ...modeling_outputs import (
  26. BaseModelOutputWithPastAndCrossAttentions,
  27. CausalLMOutputWithCrossAttentions,
  28. SequenceClassifierOutputWithPast,
  29. )
  30. from ...modeling_utils import PreTrainedModel
  31. from ...pytorch_utils import Conv1D
  32. from ...utils import (
  33. auto_docstring,
  34. logging,
  35. torch_float,
  36. )
  37. from ...utils.generic import maybe_autocast
  38. from .configuration_imagegpt import ImageGPTConfig
  39. logger = logging.get_logger(__name__)
  40. class ImageGPTLayerNorm(nn.Module):
  41. def __init__(self, hidden_size: tuple[int], eps: float = 1e-5):
  42. super().__init__()
  43. self.eps = eps
  44. self.weight = nn.Parameter(torch.Tensor(hidden_size))
  45. def forward(self, tensor: torch.Tensor) -> torch.Tensor:
  46. # input is not mean centered
  47. tensor = tensor / torch.sqrt(torch.mean(torch.square(tensor), axis=-1, keepdim=True) + self.eps)
  48. tensor = tensor * self.weight
  49. return tensor
  50. class ImageGPTAttention(nn.Module):
  51. def __init__(self, config, is_cross_attention: bool | None = False, layer_idx: int | None = None):
  52. super().__init__()
  53. self.config = config
  54. max_positions = config.max_position_embeddings
  55. self.register_buffer(
  56. "bias",
  57. torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
  58. 1, 1, max_positions, max_positions
  59. ),
  60. persistent=False,
  61. )
  62. self.embed_dim = config.hidden_size
  63. self.num_heads = config.num_attention_heads
  64. self.head_dim = self.embed_dim // self.num_heads
  65. self.split_size = self.embed_dim
  66. if self.head_dim * self.num_heads != self.embed_dim:
  67. raise ValueError(
  68. f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  69. f" {self.num_heads})."
  70. )
  71. self.scale_attn_weights = config.scale_attn_weights
  72. self.is_cross_attention = is_cross_attention
  73. # Layer-wise attention scaling, reordering, and upcasting
  74. self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
  75. self.layer_idx = layer_idx
  76. self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
  77. if self.is_cross_attention:
  78. self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
  79. self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
  80. else:
  81. self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
  82. self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
  83. self.attn_dropout = nn.Dropout(config.attn_pdrop)
  84. self.resid_dropout = nn.Dropout(config.resid_pdrop)
  85. def _attn(self, query, key, value, attention_mask=None):
  86. attn_weights = torch.matmul(query, key.transpose(-1, -2))
  87. if self.scale_attn_weights:
  88. attn_weights = attn_weights / torch_float(value.size(-1) ** 0.5)
  89. # Layer-wise attention scaling
  90. if self.scale_attn_by_inverse_layer_idx:
  91. attn_weights = attn_weights / float(self.layer_idx + 1)
  92. if not self.is_cross_attention:
  93. # if only "normal" attention layer implements causal mask
  94. query_length, key_length = query.size(-2), key.size(-2)
  95. causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
  96. mask_value = torch.finfo(attn_weights.dtype).min
  97. # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
  98. # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
  99. mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
  100. attn_weights = torch.where(causal_mask, attn_weights, mask_value)
  101. if attention_mask is not None:
  102. # Apply the attention mask
  103. attn_weights = attn_weights + attention_mask
  104. attn_weights = nn.Softmax(dim=-1)(attn_weights)
  105. # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
  106. attn_weights = attn_weights.type(value.dtype)
  107. attn_weights = self.attn_dropout(attn_weights)
  108. attn_output = torch.matmul(attn_weights, value)
  109. return attn_output, attn_weights
  110. def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None):
  111. # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
  112. bsz, num_heads, q_seq_len, dk = query.size()
  113. _, _, k_seq_len, _ = key.size()
  114. # Preallocate attn_weights for `baddbmm`
  115. attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
  116. # Compute Scale Factor
  117. scale_factor = 1.0
  118. if self.scale_attn_weights:
  119. scale_factor /= float(value.size(-1)) ** 0.5
  120. if self.scale_attn_by_inverse_layer_idx:
  121. scale_factor /= float(self.layer_idx + 1)
  122. # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
  123. with maybe_autocast(query.device.type, enabled=False):
  124. q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
  125. attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
  126. attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
  127. if not self.is_cross_attention:
  128. # if only "normal" attention layer implements causal mask
  129. query_length, key_length = query.size(-2), key.size(-2)
  130. causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
  131. mask_value = torch.finfo(attn_weights.dtype).min
  132. # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
  133. # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
  134. mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
  135. attn_weights = torch.where(causal_mask, attn_weights, mask_value)
  136. if attention_mask is not None:
  137. # Apply the attention mask
  138. attn_weights = attn_weights + attention_mask
  139. attn_weights = nn.Softmax(dim=-1)(attn_weights)
  140. # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
  141. if attn_weights.dtype != torch.float32:
  142. raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
  143. attn_weights = attn_weights.type(value.dtype)
  144. attn_weights = self.attn_dropout(attn_weights)
  145. attn_output = torch.matmul(attn_weights, value)
  146. return attn_output, attn_weights
  147. def _split_heads(self, tensor, num_heads, attn_head_size):
  148. """
  149. Splits hidden_size dim into attn_head_size and num_heads
  150. """
  151. new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
  152. tensor = tensor.view(*new_shape)
  153. return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
  154. def _merge_heads(self, tensor, num_heads, attn_head_size):
  155. """
  156. Merges attn_head_size dim and num_attn_heads dim into hidden_size
  157. """
  158. tensor = tensor.permute(0, 2, 1, 3).contiguous()
  159. new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
  160. return tensor.view(new_shape)
  161. def forward(
  162. self,
  163. hidden_states: torch.Tensor,
  164. layer_past: Cache | None = None,
  165. attention_mask: torch.Tensor | None = None,
  166. encoder_hidden_states: torch.Tensor | None = None,
  167. encoder_attention_mask: torch.Tensor | None = None,
  168. use_cache: bool | None = False,
  169. output_attentions: bool | None = False,
  170. **kwargs,
  171. ) -> tuple:
  172. is_cross_attention = encoder_hidden_states is not None
  173. bsz, seq_len, _ = hidden_states.shape
  174. if layer_past is not None:
  175. if isinstance(layer_past, EncoderDecoderCache):
  176. is_updated = layer_past.is_updated.get(self.layer_idx)
  177. if is_cross_attention:
  178. # after the first generated id, we can subsequently re-use all key/value_states from cache
  179. curr_past_key_values = layer_past.cross_attention_cache
  180. else:
  181. curr_past_key_values = layer_past.self_attention_cache
  182. else:
  183. curr_past_key_values = layer_past
  184. current_states = encoder_hidden_states if is_cross_attention else hidden_states
  185. if is_cross_attention:
  186. if not hasattr(self, "q_attn"):
  187. raise ValueError(
  188. "If class is used as cross attention, the weights `q_attn` have to be defined. "
  189. "Please make sure to instantiate class with `ImageGPTAttention(..., is_cross_attention=True)`."
  190. )
  191. if layer_past is not None and is_updated:
  192. # reuse k,v, cross_attentions, and compute only q
  193. query = self.q_attn(hidden_states)
  194. key = curr_past_key_values.layers[self.layer_idx].keys
  195. value = curr_past_key_values.layers[self.layer_idx].values
  196. else:
  197. query = self.q_attn(hidden_states)
  198. key, value = self.c_attn(current_states).split(self.split_size, dim=2)
  199. key = key.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
  200. value = value.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
  201. else:
  202. query, key, value = self.c_attn(current_states).split(self.split_size, dim=2)
  203. key = key.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
  204. value = value.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
  205. if layer_past is not None:
  206. # save all key/value_states to cache to be re-used for fast auto-regressive generation
  207. key, value = curr_past_key_values.update(key, value, self.layer_idx)
  208. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  209. if is_cross_attention:
  210. layer_past.is_updated[self.layer_idx] = True
  211. query = query.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
  212. if self.reorder_and_upcast_attn:
  213. attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask)
  214. else:
  215. attn_output, attn_weights = self._attn(query, key, value, attention_mask)
  216. attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
  217. attn_output = self.c_proj(attn_output)
  218. attn_output = self.resid_dropout(attn_output)
  219. return attn_output, attn_weights
  220. class ImageGPTMLP(nn.Module):
  221. def __init__(self, intermediate_size, config):
  222. super().__init__()
  223. embed_dim = config.hidden_size
  224. self.c_fc = Conv1D(intermediate_size, embed_dim)
  225. self.c_proj = Conv1D(embed_dim, intermediate_size)
  226. self.act = ACT2FN[config.activation_function]
  227. self.dropout = nn.Dropout(config.resid_pdrop)
  228. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  229. hidden_states = self.c_fc(hidden_states)
  230. hidden_states = self.act(hidden_states)
  231. hidden_states = self.c_proj(hidden_states)
  232. hidden_states = self.dropout(hidden_states)
  233. return hidden_states
  234. class ImageGPTBlock(GradientCheckpointingLayer):
  235. def __init__(self, config, layer_idx=None):
  236. super().__init__()
  237. hidden_size = config.hidden_size
  238. inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
  239. self.ln_1 = ImageGPTLayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  240. self.attn = ImageGPTAttention(config, layer_idx=layer_idx)
  241. self.ln_2 = ImageGPTLayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  242. if config.add_cross_attention:
  243. self.crossattention = ImageGPTAttention(config, is_cross_attention=True, layer_idx=layer_idx)
  244. self.ln_cross_attn = ImageGPTLayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  245. self.mlp = ImageGPTMLP(inner_dim, config)
  246. def forward(
  247. self,
  248. hidden_states: torch.Tensor,
  249. layer_past: Cache | None = None,
  250. attention_mask: torch.Tensor | None = None,
  251. encoder_hidden_states: torch.Tensor | None = None,
  252. encoder_attention_mask: torch.Tensor | None = None,
  253. use_cache: bool | None = False,
  254. output_attentions: bool | None = False,
  255. **kwargs,
  256. ) -> tuple:
  257. residual = hidden_states
  258. hidden_states = self.ln_1(hidden_states)
  259. attn_outputs = self.attn(
  260. hidden_states,
  261. layer_past=layer_past,
  262. attention_mask=attention_mask,
  263. use_cache=use_cache,
  264. output_attentions=output_attentions,
  265. )
  266. attn_output = attn_outputs[0]
  267. outputs = attn_outputs[1:]
  268. # residual connection
  269. hidden_states = attn_output + residual
  270. if encoder_hidden_states is not None:
  271. # add one self-attention block for cross-attention
  272. if not hasattr(self, "crossattention"):
  273. raise ValueError(
  274. f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
  275. "cross-attention layers by setting `config.add_cross_attention=True`"
  276. )
  277. residual = hidden_states
  278. hidden_states = self.ln_cross_attn(hidden_states)
  279. cross_attn_outputs = self.crossattention(
  280. hidden_states,
  281. layer_past=layer_past,
  282. attention_mask=attention_mask,
  283. encoder_hidden_states=encoder_hidden_states,
  284. encoder_attention_mask=encoder_attention_mask,
  285. output_attentions=output_attentions,
  286. )
  287. attn_output = cross_attn_outputs[0]
  288. # residual connection
  289. hidden_states = residual + attn_output
  290. outputs = outputs + cross_attn_outputs[1:] # add cross attentions if we output attention weights
  291. residual = hidden_states
  292. hidden_states = self.ln_2(hidden_states)
  293. feed_forward_hidden_states = self.mlp(hidden_states)
  294. # residual connection
  295. hidden_states = residual + feed_forward_hidden_states
  296. return (hidden_states,) + outputs
  297. @auto_docstring
  298. class ImageGPTPreTrainedModel(PreTrainedModel):
  299. config: ImageGPTConfig
  300. base_model_prefix = "transformer"
  301. main_input_name = "input_ids"
  302. input_modalities = ("image",)
  303. supports_gradient_checkpointing = True
  304. _no_split_modules = ["ImageGPTBlock"]
  305. @torch.no_grad()
  306. def _init_weights(self, module):
  307. """Initialize the weights."""
  308. super()._init_weights(module)
  309. # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
  310. # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
  311. # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
  312. # > -- GPT-2 :: https://openai.com/blog/better-language-models/
  313. #
  314. # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
  315. if isinstance(module, PreTrainedModel):
  316. for name, p in module.named_parameters():
  317. if "c_proj" in name and "weight" in name:
  318. # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
  319. init.normal_(p, mean=0.0, std=self.config.initializer_range / math.sqrt(2 * self.config.n_layer))
  320. elif isinstance(module, ImageGPTAttention):
  321. max_positions = module.config.max_position_embeddings
  322. init.copy_(
  323. module.bias,
  324. torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
  325. 1, 1, max_positions, max_positions
  326. ),
  327. )
  328. @auto_docstring
  329. class ImageGPTModel(ImageGPTPreTrainedModel):
  330. def __init__(self, config: ImageGPTConfig):
  331. super().__init__(config)
  332. self.embed_dim = config.hidden_size
  333. self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
  334. self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
  335. self.drop = nn.Dropout(config.embd_pdrop)
  336. self.h = nn.ModuleList([ImageGPTBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
  337. self.ln_f = ImageGPTLayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
  338. self.gradient_checkpointing = False
  339. # Initialize weights and apply final processing
  340. self.post_init()
  341. def get_input_embeddings(self):
  342. return self.wte
  343. def set_input_embeddings(self, new_embeddings):
  344. self.wte = new_embeddings
  345. @auto_docstring
  346. def forward(
  347. self,
  348. input_ids: torch.Tensor | None = None,
  349. past_key_values: Cache | None = None,
  350. attention_mask: torch.Tensor | None = None,
  351. token_type_ids: torch.Tensor | None = None,
  352. position_ids: torch.Tensor | None = None,
  353. inputs_embeds: torch.Tensor | None = None,
  354. encoder_hidden_states: torch.Tensor | None = None,
  355. encoder_attention_mask: torch.Tensor | None = None,
  356. use_cache: bool | None = None,
  357. output_attentions: bool | None = None,
  358. output_hidden_states: bool | None = None,
  359. return_dict: bool | None = None,
  360. **kwargs: Any,
  361. ) -> tuple | BaseModelOutputWithPastAndCrossAttentions:
  362. r"""
  363. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  364. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
  365. `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
  366. sequence tokens in the vocabulary.
  367. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  368. `input_ids`.
  369. Indices can be obtained using [`AutoImageProcessor`]. See [`ImageGPTImageProcessor.__call__`] for details.
  370. Examples:
  371. ```python
  372. >>> from transformers import AutoImageProcessor, ImageGPTModel
  373. >>> from PIL import Image
  374. >>> import httpx
  375. >>> from io import BytesIO
  376. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  377. >>> with httpx.stream("GET", url) as response:
  378. ... image = Image.open(BytesIO(response.read()))
  379. >>> image_processor = AutoImageProcessor.from_pretrained("openai/imagegpt-small")
  380. >>> model = ImageGPTModel.from_pretrained("openai/imagegpt-small")
  381. >>> inputs = image_processor(images=image, return_tensors="pt")
  382. >>> outputs = model(**inputs)
  383. >>> last_hidden_states = outputs.last_hidden_state
  384. ```"""
  385. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  386. output_hidden_states = (
  387. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  388. )
  389. use_cache = use_cache if use_cache is not None else self.config.use_cache
  390. return_dict = return_dict if return_dict is not None else self.config.return_dict
  391. if input_ids is not None and inputs_embeds is not None:
  392. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  393. elif input_ids is not None:
  394. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  395. input_shape = input_ids.size()
  396. input_ids = input_ids.view(-1, input_shape[-1])
  397. batch_size = input_ids.shape[0]
  398. elif inputs_embeds is not None:
  399. input_shape = inputs_embeds.size()[:-1]
  400. batch_size = inputs_embeds.shape[0]
  401. else:
  402. raise ValueError("You have to specify either input_ids or inputs_embeds")
  403. device = input_ids.device if input_ids is not None else inputs_embeds.device
  404. if self.gradient_checkpointing and self.training:
  405. if use_cache:
  406. logger.warning_once(
  407. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  408. )
  409. use_cache = False
  410. if token_type_ids is not None:
  411. token_type_ids = token_type_ids.view(-1, input_shape[-1])
  412. if use_cache and past_key_values is None:
  413. past_key_values = DynamicCache(config=self.config)
  414. if position_ids is None:
  415. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  416. position_ids = torch.arange(input_shape[-1], device=device) + past_seen_tokens
  417. position_ids = position_ids.unsqueeze(0)
  418. # ImageGPTAttention mask.
  419. if attention_mask is not None:
  420. if batch_size <= 0:
  421. raise ValueError("batch_size has to be defined and > 0")
  422. attention_mask = attention_mask.view(batch_size, -1)
  423. # We create a 3D attention mask from a 2D tensor mask.
  424. # Sizes are [batch_size, 1, 1, to_seq_length]
  425. # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
  426. # this attention mask is more simple than the triangular masking of causal attention
  427. # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
  428. attention_mask = attention_mask[:, None, None, :]
  429. # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
  430. # masked positions, this operation will create a tensor which is 0.0 for
  431. # positions we want to attend and the dtype's smallest value for masked positions.
  432. # Since we are adding it to the raw scores before the softmax, this is
  433. # effectively the same as removing these entirely.
  434. attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
  435. attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
  436. # If a 2D or 3D attention mask is provided for the cross-attention
  437. # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
  438. if self.config.add_cross_attention and encoder_hidden_states is not None:
  439. encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
  440. encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
  441. if encoder_attention_mask is None:
  442. encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
  443. encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
  444. else:
  445. encoder_attention_mask = None
  446. if inputs_embeds is None:
  447. inputs_embeds = self.wte(input_ids)
  448. position_embeds = self.wpe(position_ids)
  449. hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)
  450. if token_type_ids is not None:
  451. token_type_embeds = self.wte(token_type_ids)
  452. hidden_states = hidden_states + token_type_embeds
  453. hidden_states = self.drop(hidden_states)
  454. output_shape = input_shape + (hidden_states.size(-1),)
  455. all_self_attentions = () if output_attentions else None
  456. all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
  457. all_hidden_states = () if output_hidden_states else None
  458. for i, block in enumerate(self.h):
  459. if output_hidden_states:
  460. all_hidden_states = all_hidden_states + (hidden_states,)
  461. outputs = block(
  462. hidden_states,
  463. past_key_values,
  464. attention_mask,
  465. encoder_hidden_states, # as a positional argument for gradient checkpointing
  466. encoder_attention_mask=encoder_attention_mask,
  467. use_cache=use_cache,
  468. output_attentions=output_attentions,
  469. )
  470. hidden_states = outputs[0]
  471. if output_attentions:
  472. all_self_attentions = all_self_attentions + (outputs[1],)
  473. if self.config.add_cross_attention:
  474. all_cross_attentions = all_cross_attentions + (outputs[2],)
  475. hidden_states = self.ln_f(hidden_states)
  476. hidden_states = hidden_states.view(*output_shape)
  477. # Add last hidden state
  478. if output_hidden_states:
  479. all_hidden_states = all_hidden_states + (hidden_states,)
  480. if not return_dict:
  481. return tuple(
  482. v
  483. for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions]
  484. if v is not None
  485. )
  486. return BaseModelOutputWithPastAndCrossAttentions(
  487. last_hidden_state=hidden_states,
  488. past_key_values=past_key_values,
  489. hidden_states=all_hidden_states,
  490. attentions=all_self_attentions,
  491. cross_attentions=all_cross_attentions,
  492. )
  493. @auto_docstring(
  494. custom_intro="""
  495. The ImageGPT Model transformer with a language modeling head on top (linear layer with weights tied to the input
  496. embeddings).
  497. """
  498. )
  499. class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel, GenerationMixin):
  500. _tied_weights_keys = {"lm_head.weight": "transformer.wte.weight"}
  501. def __init__(self, config: ImageGPTConfig):
  502. super().__init__(config)
  503. self.transformer = ImageGPTModel(config)
  504. self.lm_head = nn.Linear(config.n_embd, config.vocab_size - 1, bias=False)
  505. # Initialize weights and apply final processing
  506. self.post_init()
  507. @auto_docstring
  508. def forward(
  509. self,
  510. input_ids: torch.Tensor | None = None,
  511. past_key_values: Cache | None = None,
  512. attention_mask: torch.Tensor | None = None,
  513. token_type_ids: torch.Tensor | None = None,
  514. position_ids: torch.Tensor | None = None,
  515. inputs_embeds: torch.Tensor | None = None,
  516. encoder_hidden_states: torch.Tensor | None = None,
  517. encoder_attention_mask: torch.Tensor | None = None,
  518. labels: torch.Tensor | None = None,
  519. use_cache: bool | None = None,
  520. output_attentions: bool | None = None,
  521. output_hidden_states: bool | None = None,
  522. return_dict: bool | None = None,
  523. **kwargs: Any,
  524. ) -> tuple | CausalLMOutputWithCrossAttentions:
  525. r"""
  526. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  527. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
  528. `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
  529. sequence tokens in the vocabulary.
  530. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  531. `input_ids`.
  532. Indices can be obtained using [`AutoImageProcessor`]. See [`ImageGPTImageProcessor.__call__`] for details.
  533. labels (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
  534. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  535. `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  536. are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  537. Examples:
  538. ```python
  539. >>> from transformers import AutoImageProcessor, ImageGPTForCausalImageModeling
  540. >>> import torch
  541. >>> import matplotlib.pyplot as plt
  542. >>> import numpy as np
  543. >>> image_processor = AutoImageProcessor.from_pretrained("openai/imagegpt-small")
  544. >>> model = ImageGPTForCausalImageModeling.from_pretrained("openai/imagegpt-small")
  545. >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  546. >>> model.to(device) # doctest: +IGNORE_RESULT
  547. >>> # unconditional generation of 8 images
  548. >>> batch_size = 4
  549. >>> context = torch.full((batch_size, 1), model.config.vocab_size - 1) # initialize with SOS token
  550. >>> context = context.to(device)
  551. >>> output = model.generate(
  552. ... input_ids=context, max_length=model.config.n_positions + 1, temperature=1.0, do_sample=True, top_k=40
  553. ... )
  554. >>> clusters = image_processor.clusters
  555. >>> height = image_processor.size["height"]
  556. >>> width = image_processor.size["width"]
  557. >>> samples = output[:, 1:].detach().cpu().numpy()
  558. >>> samples_img = [
  559. ... np.reshape(np.rint(127.5 * (clusters[s] + 1.0)), [height, width, 3]).astype(np.uint8) for s in samples
  560. ... ] # convert color cluster tokens back to pixels
  561. >>> f, axes = plt.subplots(1, batch_size, dpi=300)
  562. >>> for img, ax in zip(samples_img, axes): # doctest: +IGNORE_RESULT
  563. ... ax.axis("off")
  564. ... ax.imshow(img)
  565. ```"""
  566. return_dict = return_dict if return_dict is not None else self.config.return_dict
  567. transformer_outputs = self.transformer(
  568. input_ids,
  569. past_key_values=past_key_values,
  570. attention_mask=attention_mask,
  571. token_type_ids=token_type_ids,
  572. position_ids=position_ids,
  573. inputs_embeds=inputs_embeds,
  574. encoder_hidden_states=encoder_hidden_states,
  575. encoder_attention_mask=encoder_attention_mask,
  576. use_cache=use_cache,
  577. output_attentions=output_attentions,
  578. output_hidden_states=output_hidden_states,
  579. return_dict=return_dict,
  580. )
  581. hidden_states = transformer_outputs[0]
  582. lm_logits = self.lm_head(hidden_states)
  583. loss = None
  584. if labels is not None:
  585. # Shift so that tokens < n predict n
  586. shift_logits = lm_logits[..., :-1, :].contiguous()
  587. shift_labels = labels[..., 1:].contiguous()
  588. # Flatten the tokens
  589. loss_fct = CrossEntropyLoss()
  590. loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
  591. if not return_dict:
  592. output = (lm_logits,) + transformer_outputs[1:]
  593. return ((loss,) + output) if loss is not None else output
  594. return CausalLMOutputWithCrossAttentions(
  595. loss=loss,
  596. logits=lm_logits,
  597. past_key_values=transformer_outputs.past_key_values,
  598. hidden_states=transformer_outputs.hidden_states,
  599. attentions=transformer_outputs.attentions,
  600. cross_attentions=transformer_outputs.cross_attentions,
  601. )
  602. @auto_docstring(
  603. custom_intro="""
  604. The ImageGPT Model transformer with an image classification head on top (linear layer).
  605. [`ImageGPTForImageClassification`] average-pools the hidden states in order to do the classification.
  606. """
  607. )
  608. class ImageGPTForImageClassification(ImageGPTPreTrainedModel):
  609. def __init__(self, config: ImageGPTConfig):
  610. super().__init__(config)
  611. self.num_labels = config.num_labels
  612. self.transformer = ImageGPTModel(config)
  613. self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
  614. # Initialize weights and apply final processing
  615. self.post_init()
  616. @auto_docstring
  617. def forward(
  618. self,
  619. input_ids: torch.Tensor | None = None,
  620. past_key_values: Cache | None = None,
  621. attention_mask: torch.Tensor | None = None,
  622. token_type_ids: torch.Tensor | None = None,
  623. position_ids: torch.Tensor | None = None,
  624. inputs_embeds: torch.Tensor | None = None,
  625. labels: torch.Tensor | None = None,
  626. use_cache: bool | None = None,
  627. output_attentions: bool | None = None,
  628. output_hidden_states: bool | None = None,
  629. return_dict: bool | None = None,
  630. **kwargs: Any,
  631. ) -> tuple | SequenceClassifierOutputWithPast:
  632. r"""
  633. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  634. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
  635. `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
  636. sequence tokens in the vocabulary.
  637. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  638. `input_ids`.
  639. Indices can be obtained using [`AutoImageProcessor`]. See [`ImageGPTImageProcessor.__call__`] for details.
  640. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  641. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  642. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  643. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  644. Examples:
  645. ```python
  646. >>> from transformers import AutoImageProcessor, ImageGPTForImageClassification
  647. >>> from PIL import Image
  648. >>> import httpx
  649. >>> from io import BytesIO
  650. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  651. >>> with httpx.stream("GET", url) as response:
  652. ... image = Image.open(BytesIO(response.read()))
  653. >>> image_processor = AutoImageProcessor.from_pretrained("openai/imagegpt-small")
  654. >>> model = ImageGPTForImageClassification.from_pretrained("openai/imagegpt-small")
  655. >>> inputs = image_processor(images=image, return_tensors="pt")
  656. >>> outputs = model(**inputs)
  657. >>> logits = outputs.logits
  658. ```"""
  659. return_dict = return_dict if return_dict is not None else self.config.return_dict
  660. transformer_outputs = self.transformer(
  661. input_ids,
  662. past_key_values=past_key_values,
  663. attention_mask=attention_mask,
  664. token_type_ids=token_type_ids,
  665. position_ids=position_ids,
  666. inputs_embeds=inputs_embeds,
  667. use_cache=use_cache,
  668. output_attentions=output_attentions,
  669. output_hidden_states=output_hidden_states,
  670. return_dict=return_dict,
  671. )
  672. hidden_states = transformer_outputs[0]
  673. # average-pool the hidden states along the sequence dimension
  674. pooled_hidden_states = hidden_states.mean(dim=1)
  675. # project from (batch_size, hidden_size) to (batch_size, num_labels)
  676. logits = self.score(pooled_hidden_states)
  677. loss = None
  678. if labels is not None:
  679. loss = self.loss_function(labels, logits, self.config)
  680. if not return_dict:
  681. output = (logits,) + transformer_outputs[1:]
  682. return ((loss,) + output) if loss is not None else output
  683. return SequenceClassifierOutputWithPast(
  684. loss=loss,
  685. logits=logits,
  686. past_key_values=transformer_outputs.past_key_values,
  687. hidden_states=transformer_outputs.hidden_states,
  688. attentions=transformer_outputs.attentions,
  689. )
# Public API of this module: the names exported by a star-import
# (`from <module> import *`).
__all__ = [
    "ImageGPTForCausalImageModeling",
    "ImageGPTForImageClassification",
    "ImageGPTModel",
    "ImageGPTPreTrainedModel",
]