modeling_pix2struct.py 58 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350
  1. # Copyright 2023 The HuggingFace Inc. & Google team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Pix2Struct modeling file"""
  15. import math
  16. import torch
  17. from torch import nn
  18. from ... import initialization as init
  19. from ...activations import ACT2FN
  20. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  21. from ...generation import GenerationMixin
  22. from ...masking_utils import create_causal_mask
  23. from ...modeling_layers import GradientCheckpointingLayer
  24. from ...modeling_outputs import (
  25. BaseModelOutput,
  26. BaseModelOutputWithPooling,
  27. CausalLMOutputWithCrossAttentions,
  28. Seq2SeqLMOutput,
  29. Seq2SeqModelOutput,
  30. )
  31. from ...modeling_utils import PreTrainedModel
  32. from ...utils import (
  33. DUMMY_INPUTS,
  34. DUMMY_MASK,
  35. auto_docstring,
  36. is_torchdynamo_compiling,
  37. logging,
  38. )
  39. from .configuration_pix2struct import Pix2StructConfig, Pix2StructTextConfig, Pix2StructVisionConfig
# Module-level logger; used below to report whether apex's FusedRMSNorm replaces Pix2StructLayerNorm.
logger = logging.get_logger(__name__)
  42. # Adapted from transformers.models.t5.modeling_t5.T5LayerNorm with T5->Pix2Struct
  43. class Pix2StructLayerNorm(nn.Module):
  44. def __init__(self, hidden_size, eps=1e-6):
  45. """
  46. Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
  47. """
  48. super().__init__()
  49. self.weight = nn.Parameter(torch.ones(hidden_size))
  50. self.variance_epsilon = eps
  51. def forward(self, hidden_states):
  52. # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
  53. # Square Layer Normalization https://huggingface.co/papers/1910.07467 thus variance is calculated
  54. # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
  55. # half-precision inputs is done in fp32
  56. variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
  57. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  58. # convert into half-precision if necessary
  59. if self.weight.dtype in [torch.float16, torch.bfloat16]:
  60. hidden_states = hidden_states.to(self.weight.dtype)
  61. return self.weight * hidden_states
# Optionally swap in apex's fused RMSNorm kernel. On success the module-level name Pix2StructLayerNorm is
# rebound to FusedRMSNorm, so every later reference (layer construction, isinstance checks) uses apex.
try:
    from apex.normalization import FusedRMSNorm

    Pix2StructLayerNorm = FusedRMSNorm

    logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of Pix2StructLayerNorm")
except ImportError:
    # apex is not installed: keep the pure-PyTorch Pix2StructLayerNorm defined above.
    pass
except Exception:
    # apex is present but importing it failed for another reason: warn and keep the PyTorch implementation.
    logger.warning("Discovered apex but it failed to load, falling back to Pix2StructLayerNorm")
  71. class Pix2StructVisionEmbeddings(nn.Module):
  72. r"""
  73. Construct the embeddings from patch. In `Pix2Struct` the input is different from classic Vision-transformer models.
  74. Here the input is a sequence of `seq_len` flattened patches that also combines padding patches (tokens). Each patch
  75. is represented by a vector of `hidden_size` values.
  76. """
  77. def __init__(self, config: Pix2StructConfig) -> None:
  78. super().__init__()
  79. self.patch_projection = nn.Linear(config.patch_embed_hidden_size, config.hidden_size)
  80. self.row_embedder = nn.Embedding(config.seq_len, config.hidden_size)
  81. self.column_embedder = nn.Embedding(config.seq_len, config.hidden_size)
  82. self.dropout = nn.Dropout(config.dropout_rate)
  83. def forward(self, flattened_patches: torch.Tensor) -> torch.Tensor:
  84. # the row and column indices are stored in the first and second position of the flattened_patches
  85. # flattened_patches: `batch_size`, `seq_len`, `hidden_size` + 2
  86. row_indices = flattened_patches[:, :, 0].long()
  87. col_indices = flattened_patches[:, :, 1].long()
  88. flattened_patches = flattened_patches[:, :, 2:]
  89. embeddings = self.patch_projection(flattened_patches)
  90. row_embeddings = self.row_embedder(row_indices)
  91. col_embeddings = self.column_embedder(col_indices)
  92. # sum all embeddings together
  93. embeddings = embeddings + row_embeddings + col_embeddings
  94. embeddings = self.dropout(embeddings)
  95. return embeddings
  96. class Pix2StructVisionAttention(nn.Module):
  97. def __init__(self, config):
  98. super().__init__()
  99. self.hidden_size = config.hidden_size
  100. self.key_value_proj_dim = config.d_kv
  101. self.n_heads = config.num_attention_heads
  102. self.dropout = config.attention_dropout
  103. self.inner_dim = self.n_heads * self.key_value_proj_dim
  104. self.query = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
  105. self.key = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
  106. self.value = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
  107. self.output = nn.Linear(self.inner_dim, self.hidden_size, bias=False)
  108. self.gradient_checkpointing = False
  109. def forward(
  110. self,
  111. hidden_states,
  112. attention_mask=None,
  113. position_bias=None,
  114. output_attentions=False,
  115. ):
  116. """
  117. Self-attention block
  118. """
  119. # Input is (batch_size, seq_length, dim)
  120. # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
  121. batch_size, seq_length = hidden_states.shape[:2]
  122. def to_projection_shape(states):
  123. """projection"""
  124. return states.contiguous().view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
  125. # get query states
  126. # (batch_size, n_heads, seq_length, dim_per_head)
  127. query_states = to_projection_shape(self.query(hidden_states))
  128. # get key/value states
  129. key_states = to_projection_shape(self.key(hidden_states))
  130. value_states = to_projection_shape(self.value(hidden_states))
  131. # compute scores
  132. # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
  133. scores = torch.matmul(query_states, key_states.transpose(3, 2))
  134. if position_bias is None:
  135. position_bias = torch.zeros(
  136. (1, self.n_heads, seq_length, seq_length), device=scores.device, dtype=scores.dtype
  137. )
  138. if self.gradient_checkpointing and self.training:
  139. position_bias.requires_grad = True
  140. if attention_mask.dim() == 2:
  141. position_bias = position_bias + attention_mask[:, None, None, :].to(position_bias.device)
  142. elif attention_mask is not None:
  143. # (batch_size, n_heads, seq_length, key_length)
  144. position_bias = position_bias + attention_mask.to(position_bias.device)
  145. elif not is_torchdynamo_compiling():
  146. attention_mask = torch.ones(
  147. (batch_size, seq_length), device=position_bias.device, dtype=position_bias.dtype
  148. )
  149. position_bias = position_bias + attention_mask.to(position_bias.device)
  150. position_bias = 1 - position_bias
  151. position_bias_masked = position_bias.masked_fill(position_bias == 1, torch.finfo(scores.dtype).min)
  152. scores += position_bias_masked
  153. scores = torch.max(scores, torch.tensor(torch.finfo(scores.dtype).min))
  154. # (batch_size, n_heads, seq_length, key_length)
  155. attn_weights = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).type_as(scores)
  156. # (batch_size, n_heads, seq_length, key_length)
  157. attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
  158. attn_output = torch.matmul(attn_weights, value_states)
  159. # (batch_size, seq_length, dim)
  160. attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
  161. attn_output = self.output(attn_output)
  162. outputs = (attn_output,) + (position_bias,)
  163. if output_attentions:
  164. outputs = outputs + (attn_weights,)
  165. return outputs
# Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5DenseGatedActDense->Pix2StructVisionMlp,T5Config->Pix2StructVisionConfig,config.d_model->config.hidden_size,dropout_rate->dropout_rate
class Pix2StructVisionMlp(nn.Module):
    # Gated feed-forward block: wo(dropout(act(wi_0(x)) * wi_1(x))), with all projections bias-free.
    def __init__(self, config: Pix2StructVisionConfig):
        super().__init__()
        self.wi_0 = nn.Linear(config.hidden_size, config.d_ff, bias=False)
        self.wi_1 = nn.Linear(config.hidden_size, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.hidden_size, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.act = ACT2FN[config.dense_act_fn]

    def forward(self, hidden_states):
        # Activated branch (named "gelu" after the usual config choice; the actual function is config.dense_act_fn).
        hidden_gelu = self.act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        hidden_states = self.dropout(hidden_states)

        # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
        # See https://github.com/huggingface/transformers/issues/20287
        # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None``
        if (
            isinstance(self.wo.weight, torch.Tensor)
            and hidden_states.dtype != self.wo.weight.dtype
            and self.wo.weight.dtype != torch.int8
        ):
            hidden_states = hidden_states.to(self.wo.weight.dtype)

        hidden_states = self.wo(hidden_states)
        return hidden_states
  191. class Pix2StructVisionLayer(GradientCheckpointingLayer):
  192. def __init__(self, config: Pix2StructConfig) -> None:
  193. super().__init__()
  194. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  195. self.seq_len_dim = 1
  196. self.attention = Pix2StructVisionAttention(config)
  197. self.mlp = Pix2StructVisionMlp(config)
  198. self.pre_mlp_layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  199. self.pre_attention_layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  200. def forward(
  201. self,
  202. hidden_states: torch.Tensor,
  203. attention_mask: torch.Tensor | None = None,
  204. output_attentions: bool = False,
  205. ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor]:
  206. residual = hidden_states
  207. # in Pix2StructVision, layernorm is applied before self-attention
  208. hidden_states = self.pre_attention_layer_norm(hidden_states)
  209. self_attention_outputs = self.attention(
  210. hidden_states,
  211. attention_mask=attention_mask,
  212. output_attentions=output_attentions,
  213. )
  214. attention_output = self_attention_outputs[0]
  215. outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
  216. # first residual connection
  217. hidden_states = attention_output + residual
  218. # in Pix2StructVision, layernorm is also applied after self-attention
  219. layer_output = self.pre_mlp_layer_norm(hidden_states)
  220. layer_output = self.mlp(layer_output) + hidden_states # second residual connection
  221. outputs = (layer_output,) + outputs
  222. return outputs
  223. class Pix2StructVisionEncoder(nn.Module):
  224. def __init__(self, config: Pix2StructVisionConfig) -> None:
  225. super().__init__()
  226. self.config = config
  227. self.layer = nn.ModuleList([Pix2StructVisionLayer(config) for _ in range(config.num_hidden_layers)])
  228. self.gradient_checkpointing = False
  229. def forward(
  230. self,
  231. hidden_states: torch.Tensor,
  232. attention_mask: torch.Tensor | None = None,
  233. output_attentions: bool = False,
  234. output_hidden_states: bool = False,
  235. return_dict: bool = True,
  236. ) -> tuple | BaseModelOutput:
  237. all_hidden_states = () if output_hidden_states else None
  238. all_self_attentions = () if output_attentions else None
  239. for i, layer_module in enumerate(self.layer):
  240. if output_hidden_states:
  241. all_hidden_states = all_hidden_states + (hidden_states,)
  242. layer_outputs = layer_module(hidden_states, attention_mask, output_attentions)
  243. hidden_states = layer_outputs[0]
  244. if output_attentions:
  245. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  246. if output_hidden_states:
  247. all_hidden_states = all_hidden_states + (hidden_states,)
  248. if not return_dict:
  249. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  250. return BaseModelOutput(
  251. last_hidden_state=hidden_states,
  252. hidden_states=all_hidden_states,
  253. attentions=all_self_attentions,
  254. )
  255. @auto_docstring
  256. class Pix2StructPreTrainedModel(PreTrainedModel):
  257. config: Pix2StructConfig
  258. input_modalities = ("image", "text")
  259. _can_compile_fullgraph = False
  260. @property
  261. def dummy_inputs(self):
  262. input_ids = torch.tensor(DUMMY_INPUTS)
  263. input_mask = torch.tensor(DUMMY_MASK)
  264. dummy_inputs = {
  265. "decoder_input_ids": input_ids,
  266. "input_ids": input_ids,
  267. "decoder_attention_mask": input_mask,
  268. }
  269. return dummy_inputs
  270. @torch.no_grad()
  271. def _init_weights(self, module):
  272. """Initialize the weights"""
  273. factor = self.config.initializer_factor # Used for testing weights initialization
  274. if isinstance(module, Pix2StructLayerNorm):
  275. init.constant_(module.weight, factor * 1.0)
  276. elif isinstance(module, Pix2StructTextDenseGatedActDense):
  277. hidden_size = (
  278. self.config.text_config.hidden_size
  279. if isinstance(self.config, Pix2StructConfig)
  280. else self.config.hidden_size
  281. )
  282. d_ff = self.config.text_config.d_ff if isinstance(self.config, Pix2StructConfig) else self.config.d_ff
  283. init.normal_(module.wi_0.weight, mean=0.0, std=factor * ((hidden_size) ** -0.5))
  284. if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
  285. init.zeros_(module.wi_0.bias)
  286. init.normal_(module.wi_1.weight, mean=0.0, std=factor * ((hidden_size) ** -0.5))
  287. if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
  288. init.zeros_(module.wi_1.bias)
  289. init.normal_(module.wo.weight, mean=0.0, std=factor * ((d_ff) ** -0.5))
  290. if hasattr(module.wo, "bias") and module.wo.bias is not None:
  291. init.zeros_(module.wo.bias)
  292. elif isinstance(module, Pix2StructTextAttention):
  293. hidden_size = (
  294. self.config.text_config.hidden_size
  295. if isinstance(self.config, Pix2StructConfig)
  296. else self.config.hidden_size
  297. )
  298. key_value_proj_dim = (
  299. self.config.text_config.d_kv if isinstance(self.config, Pix2StructConfig) else self.config.hidden_size
  300. )
  301. n_heads = (
  302. self.config.text_config.num_heads
  303. if isinstance(self.config, Pix2StructConfig)
  304. else self.config.num_heads
  305. )
  306. init.normal_(module.query.weight, mean=0.0, std=factor * ((hidden_size * key_value_proj_dim) ** -0.5))
  307. init.normal_(module.key.weight, mean=0.0, std=factor * (hidden_size**-0.5))
  308. init.normal_(module.value.weight, mean=0.0, std=factor * (hidden_size**-0.5))
  309. init.normal_(module.output.weight, mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
  310. if module.has_relative_attention_bias:
  311. init.normal_(module.relative_attention_bias.weight, mean=0.0, std=factor * ((hidden_size) ** -0.5))
  312. elif isinstance(module, nn.Embedding):
  313. hidden_size = (
  314. self.config.text_config.hidden_size
  315. if isinstance(self.config, Pix2StructConfig)
  316. else self.config.hidden_size
  317. )
  318. init.normal_(module.weight, mean=0.0, std=factor * ((hidden_size) ** -0.5))
  319. # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
  320. if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
  321. init.zeros_(module.weight[module.padding_idx])
  322. elif isinstance(module, Pix2StructTextModel):
  323. hidden_size = (
  324. self.config.text_config.hidden_size
  325. if isinstance(self.config, Pix2StructConfig)
  326. else self.config.hidden_size
  327. )
  328. init.normal_(module.lm_head.weight, mean=0.0, std=factor * ((hidden_size) ** -0.5))
  329. elif isinstance(module, (nn.Linear, nn.Conv2d)):
  330. init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  331. if module.bias is not None:
  332. init.zeros_(module.bias)
  333. elif isinstance(module, Pix2StructLayerNorm):
  334. if module.weight is not None:
  335. init.ones_(module.weight)
  336. elif isinstance(module, nn.Embedding):
  337. init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
  338. # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
  339. if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
  340. init.zeros_(module.weight[module.padding_idx])
  341. # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right with T5->Pix2Struct
  342. def _shift_right(self, input_ids):
  343. decoder_start_token_id = self.config.decoder_start_token_id
  344. pad_token_id = self.config.pad_token_id
  345. if decoder_start_token_id is None:
  346. raise ValueError(
  347. "self.model.config.decoder_start_token_id has to be defined. In Pix2Struct it is usually set to the pad_token_id. "
  348. "See Pix2Struct docs for more information."
  349. )
  350. shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  351. shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
  352. shifted_input_ids[..., 0] = decoder_start_token_id
  353. if pad_token_id is None:
  354. raise ValueError("self.model.config.pad_token_id has to be defined.")
  355. # replace possible -100 values in labels by `pad_token_id`
  356. shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
  357. return shifted_input_ids
  358. @auto_docstring
  359. class Pix2StructVisionModel(Pix2StructPreTrainedModel):
  360. config: Pix2StructVisionConfig
  361. main_input_name = "flattened_patches"
  362. input_modalities = ("image",)
  363. supports_gradient_checkpointing = True
  364. _no_split_modules = ["Pix2StructVisionLayer"]
  365. def __init__(self, config: Pix2StructVisionConfig):
  366. super().__init__(config)
  367. self.config = config
  368. self.embeddings = Pix2StructVisionEmbeddings(config)
  369. self.encoder = Pix2StructVisionEncoder(config)
  370. self.layernorm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  371. # Initialize weights and apply final processing
  372. self.post_init()
  373. def get_input_embeddings(self):
  374. return self.embeddings.patch_projection
  375. @auto_docstring
  376. def forward(
  377. self,
  378. flattened_patches: torch.Tensor | None = None,
  379. attention_mask: torch.Tensor | None = None,
  380. output_attentions: bool | None = None,
  381. output_hidden_states: bool | None = None,
  382. return_dict: bool | None = None,
  383. **kwargs,
  384. ) -> tuple | BaseModelOutputWithPooling:
  385. r"""
  386. flattened_patches (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_channels x patch_height x patch_width)`):
  387. Flattened and padded pixel values. These values can be obtained using [`AutoImageProcessor`]. See
  388. [`Pix2StructVisionImageProcessor.__call__`] for details. Check the [original
  389. paper](https://huggingface.co/papers/2210.03347) (figure 5) for more details.
  390. Example:
  391. ```python
  392. >>> import httpx
  393. >>> from io import BytesIO
  394. >>> from PIL import Image
  395. >>> from transformers import AutoProcessor, Pix2StructVisionModel
  396. >>> image_processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
  397. >>> model = Pix2StructVisionModel.from_pretrained("google/pix2struct-textcaps-base")
  398. >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
  399. >>> with httpx.stream("GET", url) as response:
  400. ... image = Image.open(BytesIO(response.read()))
  401. >>> inputs = image_processor(images=image, return_tensors="pt")
  402. >>> with torch.no_grad():
  403. ... outputs = model(**inputs)
  404. >>> last_hidden_states = outputs.last_hidden_state
  405. >>> list(last_hidden_states.shape)
  406. [1, 2048, 768]
  407. ```
  408. """
  409. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  410. output_hidden_states = (
  411. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  412. )
  413. return_dict = return_dict if return_dict is not None else self.config.return_dict
  414. if flattened_patches is None:
  415. raise ValueError("You have to specify flattened_patches")
  416. if attention_mask is None:
  417. # check where `flattened_patches` is not 0
  418. attention_mask = (flattened_patches.sum(dim=-1) != 0).float()
  419. embedding_output = self.embeddings(flattened_patches)
  420. encoder_outputs = self.encoder(
  421. embedding_output,
  422. attention_mask=attention_mask,
  423. output_attentions=output_attentions,
  424. output_hidden_states=output_hidden_states,
  425. return_dict=return_dict,
  426. )
  427. sequence_output = encoder_outputs[0]
  428. sequence_output = self.layernorm(sequence_output)
  429. if not return_dict:
  430. head_outputs = (sequence_output,)
  431. return head_outputs + encoder_outputs[1:]
  432. return BaseModelOutput(
  433. last_hidden_state=sequence_output,
  434. hidden_states=encoder_outputs.hidden_states,
  435. attentions=encoder_outputs.attentions,
  436. )
# Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->Pix2StructText,d_model->hidden_size
class Pix2StructTextDenseGatedActDense(nn.Module):
    # Gated feed-forward block for the text decoder: wo(dropout(act(wi_0(x)) * wi_1(x))), bias-free projections.
    def __init__(self, config: Pix2StructTextConfig):
        super().__init__()
        self.wi_0 = nn.Linear(config.hidden_size, config.d_ff, bias=False)
        self.wi_1 = nn.Linear(config.hidden_size, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.hidden_size, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.act = ACT2FN[config.dense_act_fn]

    def forward(self, hidden_states):
        # Activated branch; the actual activation is config.dense_act_fn (the "gelu" name is only a label).
        hidden_gelu = self.act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        hidden_states = self.dropout(hidden_states)

        # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
        # See https://github.com/huggingface/transformers/issues/20287
        # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None``
        if (
            isinstance(self.wo.weight, torch.Tensor)
            and hidden_states.dtype != self.wo.weight.dtype
            and self.wo.weight.dtype != torch.int8
        ):
            hidden_states = hidden_states.to(self.wo.weight.dtype)

        hidden_states = self.wo(hidden_states)
        return hidden_states
class Pix2StructTextLayerFF(nn.Module):
    # Pre-norm feed-forward sub-layer of the text decoder: x + dropout(DenseReluDense(layer_norm(x))).
    def __init__(self, config: Pix2StructTextConfig):
        super().__init__()
        # The attribute is named `DenseReluDense` although it holds the gated-activation dense block.
        # The name presumably must stay stable so checkpoint parameter names keep matching.
        self.DenseReluDense = Pix2StructTextDenseGatedActDense(config)

        self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    # Copied from transformers.models.t5.modeling_t5.T5LayerFF.forward
    def forward(self, hidden_states):
        forwarded_states = self.layer_norm(hidden_states)
        forwarded_states = self.DenseReluDense(forwarded_states)
        # Residual connection around the normalised feed-forward branch.
        hidden_states = hidden_states + self.dropout(forwarded_states)
        return hidden_states
  474. class Pix2StructTextAttention(nn.Module):
  475. def __init__(self, config: Pix2StructTextConfig, has_relative_attention_bias=False, layer_idx: int | None = None):
  476. super().__init__()
  477. self.has_relative_attention_bias = has_relative_attention_bias
  478. self.relative_attention_num_buckets = config.relative_attention_num_buckets
  479. self.relative_attention_max_distance = config.relative_attention_max_distance
  480. self.hidden_size = config.hidden_size
  481. self.key_value_proj_dim = config.d_kv
  482. self.n_heads = config.num_heads
  483. self.dropout = config.dropout_rate
  484. self.inner_dim = self.n_heads * self.key_value_proj_dim
  485. self.layer_idx = layer_idx
  486. if layer_idx is None:
  487. logger.warning_once(
  488. f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
  489. "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
  490. "when creating this class."
  491. )
  492. self.query = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
  493. self.key = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
  494. self.value = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
  495. self.output = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
  496. if self.has_relative_attention_bias:
  497. self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
  498. self.gradient_checkpointing = False
    @staticmethod
    # Copied from transformers.models.t5.modeling_t5.T5Attention._relative_position_bucket
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            # Half of the buckets encode positive (future) offsets; those start at index num_buckets // 2.
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            # Unidirectional: positive offsets are clamped to 0 and past offsets are made positive.
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance.
        # For entries with relative_position < max_exact this expression is meaningless (log of 0 is -inf),
        # but those entries are discarded by the torch.where below.
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        # Distances beyond max_distance all share the last bucket.
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        # NOTE: the returned buckets are torch.long (int64), despite the "int32" wording in the docstring above.
        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets
  541. # Adapted from transformers.models.t5.modeling_t5.T5Attention.compute_bias
  542. def compute_bias(self, query_length, key_length, device=None, past_seen_tokens=0):
  543. """Compute binned relative position bias"""
  544. if device is None:
  545. device = self.relative_attention_bias.weight.device
  546. context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + past_seen_tokens
  547. memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
  548. relative_position = memory_position - context_position # shape (query_length, key_length)
  549. relative_position_bucket = self._relative_position_bucket(
  550. relative_position, # shape (query_length, key_length)
  551. bidirectional=False,
  552. num_buckets=self.relative_attention_num_buckets,
  553. max_distance=self.relative_attention_max_distance,
  554. )
  555. values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
  556. values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
  557. return values
  558. # Adapted from transformers.models.t5.modeling_t5.T5Attention.forward
  559. def forward(
  560. self,
  561. hidden_states,
  562. mask=None,
  563. key_value_states=None,
  564. position_bias=None,
  565. past_key_values=None,
  566. output_attentions=False,
  567. **kwargs,
  568. ):
  569. """
  570. Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
  571. """
  572. # Input is (batch_size, seq_length, dim)
  573. # Mask is (batch_size, 1, 1, key_length) (non-causal) or (batch_size, 1, seq_length, key_length) (causal decoder)
  574. batch_size, seq_length = hidden_states.shape[:2]
  575. past_seen_tokens = past_key_values.get_seq_length(self.layer_idx) if past_key_values is not None else 0
  576. # We clone here for StaticCache, as we get the value before updating it, but use it after and it's the same ref
  577. past_seen_tokens = past_seen_tokens.clone() if isinstance(past_seen_tokens, torch.Tensor) else past_seen_tokens
  578. # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
  579. is_cross_attention = key_value_states is not None
  580. query_states = self.query(hidden_states)
  581. query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
  582. # Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache`
  583. if past_key_values is not None and isinstance(past_key_values, EncoderDecoderCache):
  584. is_updated = past_key_values.is_updated.get(self.layer_idx)
  585. if is_cross_attention:
  586. # after the first generated id, we can subsequently re-use all key/value_states from cache
  587. curr_past_key_values = past_key_values.cross_attention_cache
  588. else:
  589. curr_past_key_values = past_key_values.self_attention_cache
  590. else:
  591. curr_past_key_values = past_key_values
  592. current_states = key_value_states if is_cross_attention else hidden_states
  593. if is_cross_attention and past_key_values and is_updated:
  594. # reuse k,v, cross_attentions
  595. key_states = curr_past_key_values.layers[self.layer_idx].keys
  596. value_states = curr_past_key_values.layers[self.layer_idx].values
  597. else:
  598. key_states = self.key(current_states)
  599. value_states = self.value(current_states)
  600. key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
  601. value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
  602. if past_key_values is not None:
  603. key_states, value_states = curr_past_key_values.update(key_states, value_states, self.layer_idx)
  604. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  605. if is_cross_attention:
  606. past_key_values.is_updated[self.layer_idx] = True
  607. # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
  608. scores = torch.matmul(query_states, key_states.transpose(3, 2))
  609. if position_bias is None:
  610. key_length = key_states.shape[-2]
  611. if not self.has_relative_attention_bias:
  612. position_bias = torch.zeros(
  613. (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
  614. )
  615. if self.gradient_checkpointing and self.training:
  616. position_bias.requires_grad = True
  617. else:
  618. position_bias = self.compute_bias(
  619. seq_length, key_length, device=scores.device, past_seen_tokens=past_seen_tokens
  620. )
  621. if mask is not None:
  622. causal_mask = mask[:, :, :, : key_states.shape[-2]]
  623. position_bias = position_bias + causal_mask
  624. position_bias_masked = position_bias
  625. scores += position_bias_masked
  626. # (batch_size, n_heads, seq_length, key_length)
  627. attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
  628. attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
  629. attn_output = torch.matmul(attn_weights, value_states)
  630. attn_output = attn_output.transpose(1, 2).contiguous()
  631. attn_output = attn_output.view(batch_size, -1, self.inner_dim)
  632. attn_output = self.output(attn_output)
  633. outputs = (attn_output, position_bias)
  634. if output_attentions:
  635. outputs = outputs + (attn_weights,)
  636. return outputs
  637. # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5LayerNorm->Pix2StructLayerNorm,T5Attention->Pix2StructTextAttention,T5LayerSelfAttention->Pix2StructTextLayerSelfAttention,self.SelfAttention->self.attention,config.d_model->config.hidden_size
  638. class Pix2StructTextLayerSelfAttention(nn.Module):
  639. def __init__(self, config, has_relative_attention_bias=False, layer_idx: int | None = None):
  640. super().__init__()
  641. self.attention = Pix2StructTextAttention(
  642. config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
  643. )
  644. self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
  645. self.dropout = nn.Dropout(config.dropout_rate)
  646. def forward(
  647. self,
  648. hidden_states,
  649. attention_mask=None,
  650. position_bias=None,
  651. past_key_values=None,
  652. use_cache=False,
  653. output_attentions=False,
  654. **kwargs,
  655. ):
  656. normed_hidden_states = self.layer_norm(hidden_states)
  657. attention_output = self.attention(
  658. normed_hidden_states,
  659. mask=attention_mask,
  660. position_bias=position_bias,
  661. past_key_values=past_key_values,
  662. use_cache=use_cache,
  663. output_attentions=output_attentions,
  664. )
  665. hidden_states = hidden_states + self.dropout(attention_output[0])
  666. outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
  667. return outputs
  668. # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5LayerNorm->Pix2StructLayerNorm,T5Attention->Pix2StructTextAttention,T5LayerCrossAttention->Pix2StructTextLayerCrossAttention,self.EncDecAttention->self.attention,config.d_model->config.hidden_size
  669. class Pix2StructTextLayerCrossAttention(nn.Module):
  670. def __init__(self, config, layer_idx: int | None = None):
  671. super().__init__()
  672. self.attention = Pix2StructTextAttention(config, has_relative_attention_bias=False, layer_idx=layer_idx)
  673. self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
  674. self.dropout = nn.Dropout(config.dropout_rate)
  675. def forward(
  676. self,
  677. hidden_states,
  678. key_value_states,
  679. attention_mask=None,
  680. position_bias=None,
  681. past_key_values=None,
  682. output_attentions=False,
  683. **kwargs,
  684. ):
  685. normed_hidden_states = self.layer_norm(hidden_states)
  686. attention_output = self.attention(
  687. normed_hidden_states,
  688. mask=attention_mask,
  689. key_value_states=key_value_states,
  690. position_bias=position_bias,
  691. past_key_values=past_key_values,
  692. output_attentions=output_attentions,
  693. )
  694. layer_output = hidden_states + self.dropout(attention_output[0])
  695. outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
  696. return outputs
  697. class Pix2StructTextBlock(GradientCheckpointingLayer):
  698. def __init__(self, config, has_relative_attention_bias=False, layer_idx: int | None = None):
  699. super().__init__()
  700. self.self_attention = Pix2StructTextLayerSelfAttention(
  701. config,
  702. has_relative_attention_bias=has_relative_attention_bias,
  703. layer_idx=layer_idx,
  704. )
  705. self.encoder_decoder_attention = Pix2StructTextLayerCrossAttention(
  706. config,
  707. layer_idx=layer_idx,
  708. )
  709. self.mlp = Pix2StructTextLayerFF(config)
  710. def forward(
  711. self,
  712. hidden_states,
  713. attention_mask=None,
  714. position_bias=None,
  715. encoder_hidden_states=None,
  716. encoder_attention_mask=None,
  717. encoder_decoder_position_bias=None,
  718. past_key_values=None,
  719. use_cache=False,
  720. output_attentions=False,
  721. return_dict=True,
  722. **kwargs,
  723. ):
  724. self_attention_outputs = self.self_attention(
  725. hidden_states,
  726. attention_mask=attention_mask,
  727. position_bias=position_bias,
  728. past_key_values=past_key_values,
  729. use_cache=use_cache,
  730. output_attentions=output_attentions,
  731. )
  732. hidden_states = self_attention_outputs[0]
  733. attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights
  734. # clamp inf values to enable fp16 training
  735. if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
  736. clamp_value = torch.finfo(hidden_states.dtype).max - 1000
  737. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  738. do_cross_attention = encoder_hidden_states is not None
  739. if do_cross_attention:
  740. cross_attention_outputs = self.encoder_decoder_attention(
  741. hidden_states,
  742. key_value_states=encoder_hidden_states,
  743. attention_mask=encoder_attention_mask,
  744. position_bias=encoder_decoder_position_bias,
  745. past_key_values=past_key_values,
  746. output_attentions=output_attentions,
  747. )
  748. hidden_states = cross_attention_outputs[0]
  749. # clamp inf values to enable fp16 training
  750. if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
  751. clamp_value = torch.finfo(hidden_states.dtype).max - 1000
  752. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  753. # Keep cross-attention outputs and relative position weights
  754. attention_outputs = attention_outputs + cross_attention_outputs[1:]
  755. # Apply Feed Forward layer
  756. hidden_states = self.mlp(hidden_states)
  757. # clamp inf values to enable fp16 training
  758. if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
  759. clamp_value = torch.finfo(hidden_states.dtype).max - 1000
  760. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  761. outputs = (hidden_states,)
  762. return outputs + attention_outputs
  763. @auto_docstring(
  764. custom_intro="""
  765. The standalone text decoder of Pix2Struct
  766. """
  767. )
  768. class Pix2StructTextModel(Pix2StructPreTrainedModel):
  769. config: Pix2StructTextConfig
  770. input_modalities = ("text",)
  771. _no_split_modules = ["Pix2StructTextBlock"]
  772. _tied_weights_keys = {"lm_head.weight": "embed_tokens.weight"}
  773. supports_gradient_checkpointing = True
  774. def __init__(self, config):
  775. super().__init__(config)
  776. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
  777. self.layer = nn.ModuleList(
  778. [
  779. Pix2StructTextBlock(config, has_relative_attention_bias=bool(i == 0), layer_idx=i)
  780. for i in range(config.num_layers)
  781. ]
  782. )
  783. self.final_layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
  784. self.dropout = nn.Dropout(config.dropout_rate)
  785. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  786. # Initialize weights and apply final processing
  787. self.post_init()
  788. self.gradient_checkpointing = False
  789. def set_input_embeddings(self, new_embeddings):
  790. self.embed_tokens = new_embeddings
  791. @auto_docstring
  792. def forward(
  793. self,
  794. input_ids: torch.LongTensor | None = None,
  795. attention_mask: torch.FloatTensor | None = None,
  796. encoder_hidden_states: torch.FloatTensor | None = None,
  797. encoder_attention_mask: torch.FloatTensor | None = None,
  798. inputs_embeds: torch.LongTensor | None = None,
  799. past_key_values: Cache | None = None,
  800. use_cache: bool | None = None,
  801. output_attentions: bool | None = None,
  802. output_hidden_states: bool | None = None,
  803. labels: torch.LongTensor | None = None,
  804. return_dict: bool | None = None,
  805. **kwargs,
  806. ) -> tuple[torch.FloatTensor, ...] | CausalLMOutputWithCrossAttentions:
  807. r"""
  808. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  809. Indices of input sequence tokens in the vocabulary. Pix2StructText is a model with relative position
  810. embeddings so you should be able to pad the inputs on both the right and the left.
  811. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  812. [`PreTrainedTokenizer.__call__`] for detail.
  813. [What are input IDs?](../glossary#input-ids)
  814. To know more on how to prepare `input_ids` for pretraining take a look a [Pix2StructText
  815. Training](./t5#training).
  816. Example:
  817. ```python
  818. >>> from transformers import AutoProcessor, Pix2StructTextModel
  819. >>> processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
  820. >>> model = Pix2StructTextModel.from_pretrained("google/pix2struct-textcaps-base")
  821. >>> inputs = processor(text="Hello, my dog is cute", return_tensors="pt")
  822. >>> outputs = model(**inputs)
  823. >>> loss = outputs.loss
  824. ```
  825. """
  826. use_cache = use_cache if use_cache is not None else self.config.use_cache
  827. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  828. output_hidden_states = (
  829. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  830. )
  831. return_dict = return_dict if return_dict is not None else self.config.return_dict
  832. if self.gradient_checkpointing and self.training and use_cache:
  833. logger.warning(
  834. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  835. )
  836. use_cache = False
  837. if input_ids is not None and inputs_embeds is not None:
  838. raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
  839. elif input_ids is not None:
  840. input_shape = input_ids.size()
  841. input_ids = input_ids.view(-1, input_shape[-1])
  842. elif inputs_embeds is not None:
  843. input_shape = inputs_embeds.size()[:-1]
  844. else:
  845. raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
  846. if inputs_embeds is None:
  847. assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
  848. inputs_embeds = self.embed_tokens(input_ids)
  849. batch_size, seq_length = input_shape
  850. if use_cache and past_key_values is None:
  851. if self.config.is_encoder_decoder:
  852. past_key_values = EncoderDecoderCache(
  853. DynamicCache(config=self.config), DynamicCache(config=self.config)
  854. )
  855. else:
  856. past_key_values = DynamicCache(config=self.config)
  857. if attention_mask is None:
  858. # required mask seq length can be calculated via length of past
  859. mask_seq_length = (
  860. past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length
  861. )
  862. attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
  863. if self.config.is_decoder:
  864. causal_mask = create_causal_mask(
  865. config=self.config,
  866. inputs_embeds=inputs_embeds,
  867. attention_mask=attention_mask,
  868. past_key_values=past_key_values,
  869. )
  870. else:
  871. causal_mask = attention_mask[:, None, None, :]
  872. causal_mask = causal_mask.to(dtype=inputs_embeds.dtype)
  873. causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min
  874. # If a 2D or 3D attention mask is provided for the cross-attention
  875. # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
  876. if encoder_hidden_states is not None:
  877. encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
  878. encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
  879. if encoder_attention_mask is None:
  880. encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
  881. encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
  882. else:
  883. encoder_extended_attention_mask = None
  884. all_hidden_states = () if output_hidden_states else None
  885. all_attentions = () if output_attentions else None
  886. all_cross_attentions = () if (output_attentions) else None
  887. position_bias = None
  888. encoder_decoder_position_bias = None
  889. hidden_states = self.dropout(inputs_embeds)
  890. for i, layer_module in enumerate(self.layer):
  891. if output_hidden_states:
  892. all_hidden_states = all_hidden_states + (hidden_states,)
  893. layer_outputs = layer_module(
  894. hidden_states,
  895. causal_mask,
  896. position_bias,
  897. encoder_hidden_states,
  898. encoder_extended_attention_mask,
  899. encoder_decoder_position_bias, # as a positional argument for gradient checkpointing
  900. past_key_values=past_key_values,
  901. use_cache=use_cache,
  902. output_attentions=output_attentions,
  903. )
  904. hidden_states = layer_outputs[0]
  905. # We share the position biases between the layers - the first layer store them
  906. # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
  907. # (cross-attention position bias), (cross-attention weights)
  908. position_bias = layer_outputs[1]
  909. if encoder_hidden_states is not None:
  910. encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2]
  911. if output_attentions:
  912. all_attentions = all_attentions + (layer_outputs[2],)
  913. if encoder_hidden_states is not None:
  914. all_cross_attentions = all_cross_attentions + (layer_outputs[4],)
  915. hidden_states = self.final_layer_norm(hidden_states)
  916. hidden_states = self.dropout(hidden_states)
  917. logits = self.lm_head(hidden_states)
  918. # Add last layer
  919. if output_hidden_states:
  920. all_hidden_states = all_hidden_states + (hidden_states,)
  921. loss = None
  922. if labels is not None:
  923. # move labels to correct device
  924. labels = labels.to(logits.device)
  925. loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction="mean")
  926. loss = loss_fct(logits.contiguous().view(-1, logits.size(-1)), labels.contiguous().view(-1))
  927. if not return_dict:
  928. return tuple(
  929. v
  930. for v in [
  931. loss,
  932. logits,
  933. past_key_values,
  934. all_hidden_states,
  935. all_attentions,
  936. all_cross_attentions,
  937. ]
  938. if v is not None
  939. )
  940. return CausalLMOutputWithCrossAttentions(
  941. loss=loss,
  942. logits=logits,
  943. past_key_values=past_key_values,
  944. hidden_states=all_hidden_states,
  945. attentions=all_attentions,
  946. cross_attentions=all_cross_attentions,
  947. )
  948. @auto_docstring(
  949. custom_intro="""
  950. A conditional generation model with a language modeling head. Can be used for sequence generation tasks.
  951. """
  952. )
  953. class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel, GenerationMixin):
  954. config: Pix2StructConfig
  955. main_input_name = "flattened_patches"
  956. def __init__(self, config: Pix2StructConfig):
  957. super().__init__(config)
  958. self.encoder = Pix2StructVisionModel(config.vision_config)
  959. self.decoder = Pix2StructTextModel(config.text_config)
  960. self.is_vqa = config.is_vqa
  961. # Initialize weights and apply final processing
  962. self.post_init()
  963. def get_input_embeddings(self):
  964. return self.decoder.get_input_embeddings()
  965. def set_input_embeddings(self, new_embeddings):
  966. self.decoder.set_input_embeddings(new_embeddings)
  967. def get_output_embeddings(self) -> nn.Module:
  968. return self.decoder.get_output_embeddings()
  969. def set_output_embeddings(self, new_embeddings):
  970. self.decoder.set_output_embeddings(new_embeddings)
  971. @auto_docstring
  972. def forward(
  973. self,
  974. flattened_patches: torch.FloatTensor | None = None,
  975. attention_mask: torch.FloatTensor | None = None,
  976. decoder_input_ids: torch.LongTensor | None = None,
  977. decoder_attention_mask: torch.BoolTensor | None = None,
  978. encoder_outputs: tuple[tuple[torch.FloatTensor]] | None = None,
  979. past_key_values: Cache | None = None,
  980. labels: torch.LongTensor | None = None,
  981. decoder_inputs_embeds: torch.Tensor | None = None,
  982. use_cache: bool | None = None,
  983. output_attentions: bool | None = None,
  984. output_hidden_states: bool | None = None,
  985. return_dict: bool | None = None,
  986. **kwargs,
  987. ) -> tuple[torch.FloatTensor] | Seq2SeqModelOutput:
  988. r"""
  989. flattened_patches (`torch.FloatTensor` of shape `(batch_size, seq_length, hidden_size)`):
  990. Flattened pixel patches. the `hidden_size` is obtained by the following formula: `hidden_size` =
  991. `num_channels` * `patch_size` * `patch_size`
  992. The process of flattening the pixel patches is done by `Pix2StructProcessor`.
  993. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  994. Indices of decoder input sequence tokens in the vocabulary.
  995. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  996. [`PreTrainedTokenizer.__call__`] for details.
  997. [What are decoder input IDs?](../glossary#decoder-input-ids)
  998. Pix2StructText uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
  999. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  1000. `past_key_values`).
  1001. To know more on how to prepare `decoder_input_ids` for pretraining take a look at [Pix2StructText
  1002. Training](./t5#training).
  1003. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1004. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1005. be used by default.
  1006. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1007. Labels for computing the masked language modeling loss for the decoder.
  1008. Example:
  1009. Inference:
  1010. ```python
  1011. >>> from PIL import Image
  1012. >>> import httpx
  1013. >>> from io import BytesIO
  1014. >>> from transformers import AutoProcessor, Pix2StructForConditionalGeneration
  1015. >>> processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
  1016. >>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
  1017. >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
  1018. >>> with httpx.stream("GET", url) as response:
  1019. ... image = Image.open(BytesIO(response.read()))
  1020. >>> inputs = processor(images=image, return_tensors="pt")
  1021. >>> # autoregressive generation
  1022. >>> generated_ids = model.generate(**inputs, max_new_tokens=50)
  1023. >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
  1024. >>> print(generated_text)
  1025. A stop sign is on a street corner.
  1026. >>> # conditional generation
  1027. >>> text = "A picture of"
  1028. >>> inputs = processor(text=text, images=image, return_tensors="pt", add_special_tokens=False)
  1029. >>> generated_ids = model.generate(**inputs, max_new_tokens=50)
  1030. >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
  1031. >>> print(generated_text)
  1032. A picture of a stop sign with a red stop sign
  1033. ```
  1034. Training:
  1035. ```python
  1036. >>> from PIL import Image
  1037. >>> import httpx
  1038. >>> from io import BytesIO
  1039. >>> from transformers import AutoProcessor, Pix2StructForConditionalGeneration
  1040. >>> processor = AutoProcessor.from_pretrained("google/pix2struct-base")
  1041. >>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-base")
  1042. >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
  1043. >>> with httpx.stream("GET", url) as response:
  1044. ... image = Image.open(BytesIO(response.read()))
  1045. >>> text = "A stop sign is on the street corner."
  1046. >>> inputs = processor(images=image, return_tensors="pt")
  1047. >>> labels = processor(text=text, return_tensors="pt").input_ids
  1048. >>> # forward pass
  1049. >>> outputs = model(**inputs, labels=labels)
  1050. >>> loss = outputs.loss
  1051. >>> print(f"{loss.item():.5f}")
  1052. 5.94282
  1053. ```"""
  1054. use_cache = use_cache if use_cache is not None else self.config.text_config.use_cache
  1055. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1056. # Encode if needed (training, first prediction pass)
  1057. if encoder_outputs is None:
  1058. encoder_outputs = self.encoder(
  1059. flattened_patches=flattened_patches,
  1060. attention_mask=attention_mask,
  1061. output_attentions=output_attentions,
  1062. output_hidden_states=output_hidden_states,
  1063. return_dict=return_dict,
  1064. )
  1065. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  1066. encoder_outputs = BaseModelOutput(
  1067. last_hidden_state=encoder_outputs[0],
  1068. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  1069. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  1070. )
  1071. hidden_states = encoder_outputs[0]
  1072. if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
  1073. # get decoder inputs from shifting lm labels to the right
  1074. decoder_input_ids = self._shift_right(labels)
  1075. decoder_attention_mask = (
  1076. decoder_attention_mask
  1077. if decoder_attention_mask is not None
  1078. else decoder_input_ids.ne(self.config.pad_token_id).float()
  1079. )
  1080. # Always attend to the first token
  1081. decoder_attention_mask[:, 0] = 1
  1082. # Decode
  1083. decoder_outputs = self.decoder(
  1084. input_ids=decoder_input_ids,
  1085. attention_mask=decoder_attention_mask,
  1086. inputs_embeds=decoder_inputs_embeds,
  1087. past_key_values=past_key_values,
  1088. encoder_hidden_states=hidden_states,
  1089. encoder_attention_mask=attention_mask,
  1090. use_cache=use_cache,
  1091. output_attentions=output_attentions,
  1092. output_hidden_states=output_hidden_states,
  1093. labels=labels,
  1094. return_dict=return_dict,
  1095. )
  1096. if not return_dict:
  1097. return decoder_outputs + encoder_outputs
  1098. return Seq2SeqLMOutput(
  1099. loss=decoder_outputs.loss,
  1100. logits=decoder_outputs.logits,
  1101. past_key_values=decoder_outputs.past_key_values,
  1102. decoder_hidden_states=decoder_outputs.hidden_states,
  1103. decoder_attentions=decoder_outputs.attentions,
  1104. cross_attentions=decoder_outputs.cross_attentions,
  1105. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1106. encoder_hidden_states=encoder_outputs.hidden_states,
  1107. encoder_attentions=encoder_outputs.attentions,
  1108. )
  1109. __all__ = [
  1110. "Pix2StructPreTrainedModel",
  1111. "Pix2StructForConditionalGeneration",
  1112. "Pix2StructVisionModel",
  1113. "Pix2StructTextModel",
  1114. ]