modeling_trocr.py 33 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777
  1. # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch TrOCR decoder model (based on RoBERTa)."""
  15. import math
  16. import torch
  17. from torch import nn
  18. from torch.nn import CrossEntropyLoss
  19. from ...activations import ACT2FN
  20. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  21. from ...generation import GenerationMixin
  22. from ...masking_utils import create_bidirectional_mask, create_causal_mask
  23. from ...modeling_layers import GradientCheckpointingLayer
  24. from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
  25. from ...modeling_utils import PreTrainedModel
  26. from ...utils import auto_docstring, logging
  27. from .configuration_trocr import TrOCRConfig
  28. logger = logging.get_logger(__name__)
  29. # Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->TrOCR
  30. class TrOCRLearnedPositionalEmbedding(nn.Embedding):
  31. """
  32. This module learns positional embeddings up to a fixed maximum size.
  33. """
  34. def __init__(self, num_embeddings: int, embedding_dim: int):
  35. # TrOCR is set up so that if padding_idx is specified then offset the embedding ids by 2
  36. # and adjust num_embeddings appropriately. Other models don't have this hack
  37. self.offset = 2
  38. super().__init__(num_embeddings + self.offset, embedding_dim)
  39. def forward(
  40. self, input_ids: torch.Tensor, past_key_values_length: int = 0, position_ids: torch.Tensor | None = None
  41. ):
  42. """`input_ids' shape is expected to be [bsz x seqlen]."""
  43. if position_ids is None:
  44. bsz, seq_len = input_ids.shape[:2]
  45. position_ids = torch.arange(
  46. past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
  47. ).expand(bsz, -1)
  48. else:
  49. position_ids = position_ids.unsqueeze(0)
  50. return super().forward(position_ids + self.offset)
  51. # Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->TrOCR
  52. class TrOCRScaledWordEmbedding(nn.Embedding):
  53. """
  54. This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
  55. """
  56. def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float | None = 1.0):
  57. super().__init__(num_embeddings, embedding_dim, padding_idx)
  58. self.embed_scale = embed_scale
  59. def forward(self, input_ids: torch.Tensor):
  60. return super().forward(input_ids) * self.embed_scale
  61. class TrOCRSinusoidalPositionalEmbedding(nn.Module):
  62. """This module produces sinusoidal positional embeddings of any length."""
  63. def __init__(self, num_positions: int, embedding_dim: int, padding_idx: int | None = None):
  64. super().__init__()
  65. self.offset = 2
  66. self.embedding_dim = embedding_dim
  67. self.padding_idx = padding_idx
  68. self.weights = self.get_embedding(num_positions, embedding_dim, padding_idx)
  69. @staticmethod
  70. def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: int | None = None):
  71. """
  72. Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the
  73. description in Section 3.5 of "Attention Is All You Need".
  74. """
  75. half_dim = embedding_dim // 2
  76. emb = math.log(10000) / (half_dim - 1)
  77. emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
  78. emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
  79. emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
  80. if embedding_dim % 2 == 1:
  81. # zero pad
  82. emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
  83. if padding_idx is not None:
  84. emb[padding_idx, :] = 0
  85. return emb.to(torch.get_default_dtype())
  86. @torch.no_grad()
  87. def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
  88. bsz, seq_len = input_ids.size()
  89. # Create the position ids from the input token ids. Any padded tokens remain padded.
  90. position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to(
  91. input_ids.device
  92. )
  93. # expand embeddings if needed
  94. max_pos = self.padding_idx + 1 + seq_len
  95. if self.weights is None or max_pos > self.weights.size(0):
  96. # recompute/expand embeddings if needed
  97. self.weights = self.get_embedding(max_pos, self.embedding_dim, self.padding_idx)
  98. x = self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
  99. return x
  100. def create_position_ids_from_input_ids(
  101. self, input_ids: torch.Tensor, padding_idx: int, past_key_values_length: int | None = 0
  102. ):
  103. """
  104. Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
  105. symbols are ignored. This is modified from fairseq's `utils.make_positions`.
  106. """
  107. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  108. mask = input_ids.ne(padding_idx).int()
  109. incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
  110. return incremental_indices.long() + padding_idx
  111. class TrOCRAttention(nn.Module):
  112. """Multi-headed attention from 'Attention Is All You Need' paper."""
  113. def __init__(
  114. self,
  115. config,
  116. embed_dim: int,
  117. num_heads: int,
  118. kdim: int | None = None,
  119. vdim: int | None = None,
  120. dropout: float | None = 0.0,
  121. is_decoder: bool | None = False,
  122. bias: bool | None = True,
  123. is_cross_attention: bool | None = False,
  124. layer_idx: bool | None = None,
  125. ):
  126. super().__init__()
  127. self.embed_dim = embed_dim
  128. self.kdim = kdim if kdim is not None else embed_dim
  129. self.vdim = vdim if vdim is not None else embed_dim
  130. self.num_heads = num_heads
  131. self.dropout = dropout
  132. self.head_dim = embed_dim // num_heads
  133. if not (self.head_dim * num_heads == self.embed_dim):
  134. raise ValueError(
  135. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  136. f" {num_heads})."
  137. )
  138. self.scaling = self.head_dim**-0.5
  139. self.is_decoder = is_decoder
  140. self.layer_idx = layer_idx
  141. self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
  142. self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
  143. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  144. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  145. def forward(
  146. self,
  147. hidden_states: torch.Tensor,
  148. key_value_states: torch.Tensor | None = None,
  149. past_key_values: Cache | None = None,
  150. attention_mask: torch.Tensor | None = None,
  151. output_attentions: bool | None = False,
  152. **kwargs,
  153. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  154. """Input shape: Batch x Time x Channel"""
  155. # if key_value_states are provided this layer is used as a cross-attention layer
  156. # for the decoder
  157. is_cross_attention = key_value_states is not None
  158. bsz, tgt_len, embed_dim = hidden_states.size()
  159. # get query proj
  160. query_states = self.q_proj(hidden_states) * self.scaling
  161. is_updated = False
  162. if past_key_values is not None:
  163. if isinstance(past_key_values, EncoderDecoderCache):
  164. is_updated = past_key_values.is_updated.get(self.layer_idx)
  165. if is_cross_attention:
  166. # after the first generated id, we can subsequently re-use all key/value_states from cache
  167. curr_past_key_values = past_key_values.cross_attention_cache
  168. else:
  169. curr_past_key_values = past_key_values.self_attention_cache
  170. else:
  171. curr_past_key_values = past_key_values
  172. current_states = key_value_states if is_cross_attention else hidden_states
  173. if is_cross_attention and past_key_values is not None and is_updated:
  174. # reuse k,v, cross_attentions
  175. key_states = curr_past_key_values.layers[self.layer_idx].keys
  176. value_states = curr_past_key_values.layers[self.layer_idx].values
  177. else:
  178. key_states = self.k_proj(current_states)
  179. value_states = self.v_proj(current_states)
  180. key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
  181. value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
  182. if past_key_values is not None:
  183. # save all key/value_states to cache to be re-used for fast auto-regressive generation
  184. key_states, value_states = curr_past_key_values.update(key_states, value_states, self.layer_idx)
  185. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  186. if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
  187. past_key_values.is_updated[self.layer_idx] = True
  188. proj_shape = (bsz * self.num_heads, -1, self.head_dim)
  189. query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
  190. query_states = query_states.reshape(*proj_shape)
  191. key_states = key_states.reshape(*proj_shape)
  192. value_states = value_states.reshape(*proj_shape)
  193. src_len = key_states.size(1)
  194. attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
  195. if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
  196. raise ValueError(
  197. f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
  198. f" {attn_weights.size()}"
  199. )
  200. if attention_mask is not None:
  201. if attention_mask.size() != (bsz, 1, tgt_len, src_len):
  202. raise ValueError(
  203. f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
  204. )
  205. attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
  206. attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
  207. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  208. if output_attentions:
  209. # this operation is a bit awkward, but it's required to
  210. # make sure that attn_weights keeps its gradient.
  211. # In order to do so, attn_weights have to be reshaped
  212. # twice and have to be reused in the following
  213. attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
  214. attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
  215. else:
  216. attn_weights_reshaped = None
  217. attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
  218. attn_output = torch.bmm(attn_probs, value_states)
  219. if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
  220. raise ValueError(
  221. f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
  222. f" {attn_output.size()}"
  223. )
  224. attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
  225. attn_output = attn_output.transpose(1, 2)
  226. attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
  227. attn_output = self.out_proj(attn_output)
  228. return attn_output, attn_weights_reshaped
  229. class TrOCRDecoderLayer(GradientCheckpointingLayer):
  230. def __init__(self, config: TrOCRConfig, layer_idx=None):
  231. super().__init__()
  232. self.embed_dim = config.hidden_size
  233. self.self_attn = TrOCRAttention(
  234. config,
  235. embed_dim=self.embed_dim,
  236. num_heads=config.decoder_attention_heads,
  237. dropout=config.attention_dropout,
  238. is_decoder=True,
  239. layer_idx=layer_idx,
  240. )
  241. self.dropout = config.dropout
  242. self.activation_fn = ACT2FN[config.activation_function]
  243. self.activation_dropout = config.activation_dropout
  244. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  245. if config.is_decoder:
  246. self.encoder_attn = TrOCRAttention(
  247. config,
  248. embed_dim=self.embed_dim,
  249. num_heads=config.decoder_attention_heads,
  250. kdim=config.cross_attention_hidden_size,
  251. vdim=config.cross_attention_hidden_size,
  252. dropout=config.attention_dropout,
  253. is_decoder=True,
  254. is_cross_attention=True,
  255. layer_idx=layer_idx,
  256. )
  257. self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  258. self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
  259. self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
  260. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  261. def forward(
  262. self,
  263. hidden_states: torch.Tensor,
  264. attention_mask: torch.Tensor | None = None,
  265. encoder_hidden_states: torch.Tensor | None = None,
  266. encoder_attention_mask: torch.Tensor | None = None,
  267. past_key_values: Cache | None = None,
  268. output_attentions: bool | None = False,
  269. use_cache: bool | None = True,
  270. **kwargs,
  271. ):
  272. """
  273. Args:
  274. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  275. attention_mask (`torch.FloatTensor`): attention mask of size
  276. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  277. encoder_hidden_states (`torch.FloatTensor`):
  278. cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
  279. encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
  280. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  281. past_key_values (`Cache`): cached past key and value projection states
  282. output_attentions (`bool`, *optional*):
  283. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  284. returned tensors for more detail.
  285. """
  286. residual = hidden_states
  287. # Self Attention
  288. hidden_states, self_attn_weights = self.self_attn(
  289. hidden_states=hidden_states,
  290. past_key_values=past_key_values,
  291. attention_mask=attention_mask,
  292. output_attentions=output_attentions,
  293. )
  294. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  295. hidden_states = residual + hidden_states
  296. hidden_states = self.self_attn_layer_norm(hidden_states)
  297. # Cross-Attention Block
  298. cross_attn_weights = None
  299. if encoder_hidden_states is not None:
  300. residual = hidden_states
  301. hidden_states, cross_attn_weights = self.encoder_attn(
  302. hidden_states=hidden_states,
  303. key_value_states=encoder_hidden_states,
  304. attention_mask=encoder_attention_mask,
  305. past_key_values=past_key_values,
  306. output_attentions=output_attentions,
  307. )
  308. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  309. hidden_states = residual + hidden_states
  310. hidden_states = self.encoder_attn_layer_norm(hidden_states)
  311. # Fully Connected
  312. residual = hidden_states
  313. hidden_states = self.activation_fn(self.fc1(hidden_states))
  314. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  315. hidden_states = self.fc2(hidden_states)
  316. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  317. hidden_states = residual + hidden_states
  318. hidden_states = self.final_layer_norm(hidden_states)
  319. outputs = (hidden_states,)
  320. if output_attentions:
  321. outputs += (self_attn_weights, cross_attn_weights)
  322. return outputs
  323. @auto_docstring
  324. class TrOCRPreTrainedModel(PreTrainedModel):
  325. config: TrOCRConfig
  326. base_model_prefix = "model"
  327. supports_gradient_checkpointing = True
  328. _no_split_modules = ["TrOCRDecoderLayer"]
  329. class TrOCRDecoder(TrOCRPreTrainedModel):
  330. """
  331. Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TrOCRDecoderLayer`]
  332. Args:
  333. config: TrOCRConfig
  334. """
  335. def __init__(self, config: TrOCRConfig):
  336. super().__init__(config)
  337. self.dropout = config.dropout
  338. self.layerdrop = config.decoder_layerdrop
  339. self.padding_idx = config.pad_token_id
  340. embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
  341. self.embed_tokens = TrOCRScaledWordEmbedding(
  342. config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=embed_scale
  343. )
  344. if config.use_learned_position_embeddings:
  345. self.embed_positions = TrOCRLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size)
  346. else:
  347. self.embed_positions = TrOCRSinusoidalPositionalEmbedding(
  348. config.max_position_embeddings + self.padding_idx + 1,
  349. config.hidden_size,
  350. self.padding_idx,
  351. )
  352. if config.layernorm_embedding:
  353. self.layernorm_embedding = nn.LayerNorm(config.hidden_size)
  354. else:
  355. self.layernorm_embedding = None
  356. self.layers = nn.ModuleList([TrOCRDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)])
  357. self.gradient_checkpointing = False
  358. # Initialize weights and apply final processing
  359. self.post_init()
  360. def forward(
  361. self,
  362. input_ids=None,
  363. attention_mask=None,
  364. encoder_hidden_states=None,
  365. encoder_attention_mask=None,
  366. past_key_values=None,
  367. inputs_embeds=None,
  368. use_cache=None,
  369. output_attentions=None,
  370. output_hidden_states=None,
  371. return_dict=None,
  372. **kwargs,
  373. ):
  374. r"""
  375. Args:
  376. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  377. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
  378. provide it.
  379. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  380. [`PreTrainedTokenizer.__call__`] for details.
  381. [What are input IDs?](../glossary#input-ids)
  382. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  383. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  384. - 1 for tokens that are **not masked**,
  385. - 0 for tokens that are **masked**.
  386. [What are attention masks?](../glossary#attention-mask)
  387. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
  388. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
  389. of the decoder.
  390. encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
  391. Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
  392. selected in `[0, 1]`:
  393. - 1 for tokens that are **not masked**,
  394. - 0 for tokens that are **masked**.
  395. [What are attention masks?](../glossary#attention-mask)
  396. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  397. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  398. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
  399. cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
  400. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
  401. that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
  402. all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
  403. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  404. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  405. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  406. than the model's internal embedding lookup matrix.
  407. output_attentions (`bool`, *optional*):
  408. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  409. returned tensors for more detail.
  410. output_hidden_states (`bool`, *optional*):
  411. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  412. for more detail.
  413. return_dict (`bool`, *optional*):
  414. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  415. """
  416. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  417. output_hidden_states = (
  418. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  419. )
  420. use_cache = use_cache if use_cache is not None else self.config.use_cache
  421. return_dict = return_dict if return_dict is not None else self.config.return_dict
  422. # retrieve input_ids and inputs_embeds
  423. if input_ids is not None and inputs_embeds is not None:
  424. raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
  425. elif input_ids is not None:
  426. input = input_ids
  427. input_ids = input_ids.view(-1, input.shape[-1])
  428. elif inputs_embeds is not None:
  429. input = inputs_embeds[:, :, -1]
  430. else:
  431. raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
  432. if self.gradient_checkpointing and self.training:
  433. if use_cache:
  434. logger.warning_once(
  435. "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
  436. )
  437. use_cache = False
  438. if use_cache and past_key_values is None:
  439. past_key_values = (
  440. EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  441. if encoder_hidden_states is not None or self.config.is_encoder_decoder
  442. else DynamicCache(config=self.config)
  443. )
  444. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  445. if inputs_embeds is None:
  446. inputs_embeds = self.embed_tokens(input_ids)
  447. if self.config.use_learned_position_embeddings:
  448. embed_pos = self.embed_positions(input, past_key_values_length=past_key_values_length)
  449. else:
  450. embed_pos = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
  451. hidden_states = inputs_embeds + embed_pos
  452. if self.layernorm_embedding is not None:
  453. hidden_states = self.layernorm_embedding(hidden_states)
  454. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  455. attention_mask = create_causal_mask(
  456. config=self.config,
  457. inputs_embeds=inputs_embeds,
  458. attention_mask=attention_mask,
  459. past_key_values=past_key_values,
  460. )
  461. # expand encoder attention mask
  462. if encoder_hidden_states is not None and encoder_attention_mask is not None:
  463. encoder_attention_mask = create_bidirectional_mask(
  464. config=self.config,
  465. inputs_embeds=inputs_embeds,
  466. attention_mask=encoder_attention_mask,
  467. encoder_hidden_states=encoder_hidden_states,
  468. )
  469. # decoder layers
  470. all_hidden_states = () if output_hidden_states else None
  471. all_self_attns = () if output_attentions else None
  472. all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
  473. for idx, decoder_layer in enumerate(self.layers):
  474. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  475. if output_hidden_states:
  476. all_hidden_states += (hidden_states,)
  477. if self.training:
  478. dropout_probability = torch.rand([])
  479. if dropout_probability < self.layerdrop:
  480. continue
  481. layer_outputs = decoder_layer(
  482. hidden_states,
  483. attention_mask,
  484. encoder_hidden_states, # as a positional argument for gradient checkpointing
  485. encoder_attention_mask=encoder_attention_mask,
  486. past_key_values=past_key_values,
  487. output_attentions=output_attentions,
  488. use_cache=use_cache,
  489. )
  490. hidden_states = layer_outputs[0]
  491. if output_attentions:
  492. all_self_attns += (layer_outputs[1],)
  493. if encoder_hidden_states is not None:
  494. all_cross_attentions += (layer_outputs[2],)
  495. # add hidden states from the last decoder layer
  496. if output_hidden_states:
  497. all_hidden_states += (hidden_states,)
  498. if not return_dict:
  499. return tuple(
  500. v
  501. for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
  502. if v is not None
  503. )
  504. return BaseModelOutputWithPastAndCrossAttentions(
  505. last_hidden_state=hidden_states,
  506. past_key_values=past_key_values,
  507. hidden_states=all_hidden_states,
  508. attentions=all_self_attns,
  509. cross_attentions=all_cross_attentions,
  510. )
  511. @auto_docstring(
  512. custom_intro="""
  513. The TrOCR Model with a language modeling head. Can be used for summarization.
  514. This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
  515. used in combination with the [`EncoderDecoderModel`] framework.
  516. """
  517. )
  518. class TrOCRDecoderWrapper(TrOCRPreTrainedModel):
  519. def __init__(self, config):
  520. super().__init__(config)
  521. self.decoder = TrOCRDecoder(config)
  522. self.post_init()
  523. def forward(self, *args, **kwargs):
  524. return self.decoder(*args, **kwargs)
  525. @auto_docstring(
  526. custom_intro="""
  527. The TrOCR Decoder with a language modeling head. Can be used as the decoder part of [`EncoderDecoderModel`] and
  528. """
  529. )
  530. class TrOCRForCausalLM(TrOCRPreTrainedModel, GenerationMixin):
  531. _tied_weights_keys = {"output_projection.weight": "model.decoder.embed_tokens.weight"}
  532. def __init__(self, config):
  533. config.is_decoder = True
  534. config.is_encoder_decoder = False
  535. super().__init__(config)
  536. self.model = TrOCRDecoderWrapper(config)
  537. self.output_projection = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  538. # Initialize weights and apply final processing
  539. self.post_init()
  540. def get_input_embeddings(self):
  541. return self.model.decoder.embed_tokens
  542. def set_input_embeddings(self, value):
  543. self.model.decoder.embed_tokens = value
  544. def get_output_embeddings(self):
  545. return self.output_projection
  546. def set_output_embeddings(self, new_embeddings):
  547. self.output_projection = new_embeddings
  548. @auto_docstring
  549. def forward(
  550. self,
  551. input_ids: torch.LongTensor | None = None,
  552. attention_mask: torch.Tensor | None = None,
  553. encoder_hidden_states: torch.FloatTensor | None = None,
  554. encoder_attention_mask: torch.LongTensor | None = None,
  555. past_key_values: Cache | None = None,
  556. inputs_embeds: torch.FloatTensor | None = None,
  557. labels: torch.LongTensor | None = None,
  558. use_cache: bool | None = None,
  559. output_attentions: bool | None = None,
  560. output_hidden_states: bool | None = None,
  561. return_dict: bool | None = None,
  562. **kwargs,
  563. ) -> tuple | CausalLMOutputWithCrossAttentions:
  564. r"""
  565. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  566. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  567. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  568. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  569. Example:
  570. ```python
  571. >>> from transformers import (
  572. ... TrOCRConfig,
  573. ... TrOCRProcessor,
  574. ... TrOCRForCausalLM,
  575. ... ViTConfig,
  576. ... ViTModel,
  577. ... VisionEncoderDecoderModel,
  578. ... )
  579. >>> import httpx
  580. >>> from io import BytesIO
  581. >>> from PIL import Image
  582. >>> # TrOCR is a decoder model and should be used within a VisionEncoderDecoderModel
  583. >>> # init vision2text model with random weights
  584. >>> encoder = ViTModel(ViTConfig())
  585. >>> decoder = TrOCRForCausalLM(TrOCRConfig())
  586. >>> model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder)
  587. >>> # If you want to start from the pretrained model, load the checkpoint with `VisionEncoderDecoderModel`
  588. >>> processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
  589. >>> model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
  590. >>> # load image from the IAM dataset
  591. >>> url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg"
  592. >>> with httpx.stream("GET", url) as response:
  593. ... image = Image.open(BytesIO(response.read())).convert("RGB")
  594. >>> pixel_values = processor(image, return_tensors="pt").pixel_values
  595. >>> text = "industry, ' Mr. Brown commented icily. ' Let us have a"
  596. >>> # training
  597. >>> model.config.decoder_start_token_id = processor.tokenizer.eos_token_id
  598. >>> model.config.pad_token_id = processor.tokenizer.pad_token_id
  599. >>> model.config.vocab_size = model.config.decoder.vocab_size
  600. >>> labels = processor.tokenizer(text, return_tensors="pt").input_ids
  601. >>> outputs = model(pixel_values, labels=labels)
  602. >>> loss = outputs.loss
  603. >>> round(loss.item(), 2)
  604. 5.30
  605. >>> # inference
  606. >>> generated_ids = model.generate(pixel_values)
  607. >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
  608. >>> generated_text
  609. 'industry, " Mr. Brown commented icily. " Let us have a'
  610. ```"""
  611. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  612. output_hidden_states = (
  613. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  614. )
  615. return_dict = return_dict if return_dict is not None else self.config.return_dict
  616. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  617. outputs = self.model.decoder(
  618. input_ids=input_ids,
  619. attention_mask=attention_mask,
  620. encoder_hidden_states=encoder_hidden_states,
  621. encoder_attention_mask=encoder_attention_mask,
  622. past_key_values=past_key_values,
  623. inputs_embeds=inputs_embeds,
  624. use_cache=use_cache,
  625. output_attentions=output_attentions,
  626. output_hidden_states=output_hidden_states,
  627. return_dict=return_dict,
  628. )
  629. logits = self.output_projection(outputs[0])
  630. loss = None
  631. if labels is not None:
  632. loss_fct = CrossEntropyLoss()
  633. loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
  634. if not return_dict:
  635. output = (logits,) + outputs[1:]
  636. return (loss,) + output if loss is not None else output
  637. return CausalLMOutputWithCrossAttentions(
  638. loss=loss,
  639. logits=logits,
  640. past_key_values=outputs.past_key_values,
  641. hidden_states=outputs.hidden_states,
  642. attentions=outputs.attentions,
  643. cross_attentions=outputs.cross_attentions,
  644. )
  645. __all__ = ["TrOCRForCausalLM", "TrOCRPreTrainedModel"]