modeling_udop.py 77 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726
  1. # Copyright 2024 Microsoft Research and HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch UDOP model."""
  15. import collections
  16. import logging
  17. import math
  18. import random
  19. from abc import ABC, abstractmethod
  20. from collections.abc import Sequence
  21. from copy import deepcopy
  22. from dataclasses import dataclass
  23. from typing import Any
  24. import torch
  25. from torch import Tensor, nn
  26. from torch.nn import CrossEntropyLoss
  27. from transformers import UdopConfig
  28. from transformers.modeling_outputs import (
  29. Seq2SeqLMOutput,
  30. Seq2SeqModelOutput,
  31. )
  32. from ... import initialization as init
  33. from ...activations import ACT2FN
  34. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  35. from ...generation import GenerationMixin
  36. from ...masking_utils import create_causal_mask
  37. from ...modeling_layers import GradientCheckpointingLayer
  38. from ...modeling_utils import PreTrainedModel
  39. from ...utils import (
  40. ModelOutput,
  41. auto_docstring,
  42. is_torchdynamo_compiling,
  43. )
  44. logger = logging.getLogger(__name__)
  @dataclass
  @auto_docstring(
      custom_intro="""
      Class for the model's outputs that may also contain a past key/values (to speed up sequential decoding). Includes
      an additional attention mask.
      """
  )
  class BaseModelOutputWithAttentionMask(ModelOutput):
      r"""
      last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
          Sequence of hidden-states at the output of the last layer of the model. If `past_key_values` is used only
          the last hidden-state of the sequences of shape `(batch_size, 1, hidden_size)` is output.
      attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
          Attention mask used in the model's forward pass to avoid performing attention on padding token indices.
          Mask values selected in `[0, 1]`:
          - 1 for tokens that are **not masked**,
          - 0 for tokens that are **masked**.
      past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
          It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
          Contains pre-computed hidden-states (key and values in the
          self-attention blocks and optionally if `config.is_encoder_decoder=True` in the cross-attention blocks)
          that can be used (see `past_key_values` input) to speed up sequential decoding.
      hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
          Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
          one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
          the model at the output of each layer plus the optional initial embedding outputs.
      attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
          Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
          sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
          the self-attention heads.
      cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
          Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
          sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
          used to compute the weighted average in the cross-attention heads.
      """

      # Every field defaults to None, so an instance can be built with any subset of outputs.
      # NOTE(review): field order is presumably significant to `ModelOutput` (tuple-style access) - confirm
      # before reordering.

      # Output of the last layer.
      last_hidden_state: torch.FloatTensor | None = None
      # Mask that was used in the forward pass; this is the field that distinguishes this output class.
      attention_mask: torch.FloatTensor | None = None
      # Key/value cache for sequential decoding.
      past_key_values: Cache | None = None
      # Per-layer hidden states.
      hidden_states: tuple[torch.FloatTensor] | None = None
      # Per-layer self-attention weights.
      attentions: tuple[torch.FloatTensor] | None = None
      # Per-layer cross-attention weights.
      cross_attentions: tuple[torch.FloatTensor] | None = None
  def get_visual_bbox(image_size=224, patch_size=16):
      """
      Build the normalized bounding box of every patch in the image grid.

      The image is cut into `rows x cols` patches (`rows = height // patch_height`, `cols = width // patch_width`)
      and each patch gets the box `(x0, y0, x1, y1)` with all coordinates expressed as fractions of the grid, i.e.
      in the range `[0, 1]`.

      Args:
          image_size (`int` or `tuple[int, int]`, *optional*, defaults to 224):
              Side length of a square image, or a `(height, width)` pair.
          patch_size (`int` or `tuple[int, int]`, *optional*, defaults to 16):
              Side length of a square patch, or a `(height, width)` pair.

      Returns:
          `torch.Tensor` of shape `(rows * cols, 4)`: one `(x0, y0, x1, y1)` box per patch, in row-major order.
      """

      def _as_pair(value):
          # A bare number describes a square; a tuple/list is read as (height, width).
          if isinstance(value, (tuple, list)):
              first, second = value
              return first, second
          return value, value

      image_height, image_width = _as_pair(image_size)
      patch_height, patch_width = _as_pair(patch_size)

      # [rows, cols] of the patch grid
      image_feature_pool_shape = [image_height // patch_height, image_width // patch_width]

      # Grid line positions along each axis, normalized to [0, 1].
      visual_bbox_x = torch.arange(0, 1.0 * (image_feature_pool_shape[1] + 1), 1.0)
      visual_bbox_x /= image_feature_pool_shape[1]
      visual_bbox_y = torch.arange(0, 1.0 * (image_feature_pool_shape[0] + 1), 1.0)
      visual_bbox_y /= image_feature_pool_shape[0]

      # Each stacked plane has shape (rows, cols): left edge, top edge, right edge, bottom edge.
      visual_bbox_input = torch.stack(
          [
              visual_bbox_x[:-1].repeat(image_feature_pool_shape[0], 1),
              visual_bbox_y[:-1].repeat(image_feature_pool_shape[1], 1).transpose(0, 1),
              visual_bbox_x[1:].repeat(image_feature_pool_shape[0], 1),
              visual_bbox_y[1:].repeat(image_feature_pool_shape[1], 1).transpose(0, 1),
          ],
          dim=-1,
      )
      return visual_bbox_input.view(-1, 4)
  def pad_sequence(seq, target_len, pad_value=0):
      """
      Pad or truncate `seq` along its first dimension to exactly `target_len` entries.

      Args:
          seq (`torch.Tensor` or nested list):
              Sequence to resize. Non-tensor input is converted with `torch.tensor`.
          target_len (`int`):
              Length of the first dimension of the result.
          pad_value (scalar or `torch.Tensor`, *optional*, defaults to 0):
              Value appended for every missing entry. It must be broadcastable to the shape of one entry
              (`seq.shape[1:]`); it is cast to the dtype and device of `seq`.

      Returns:
          `torch.Tensor` whose first dimension has size `target_len`.
      """
      if not isinstance(seq, torch.Tensor):
          seq = torch.tensor(seq)

      missing = target_len - seq.shape[0]
      if missing > 0:
          # `as_tensor` lets plain Python scalars (including the default 0) act as padding;
          # `expand` repeats the pad entry `missing` times without copying it first.
          filler = torch.as_tensor(pad_value).to(seq)
          filler = filler.expand(missing, *seq.shape[1:])
          seq = torch.cat([seq, filler], dim=0)

      # Slicing also truncates sequences that are longer than requested.
      return seq[:target_len]
  def combine_image_text_embeddings(
      image_embeddings,
      inputs_embeds,
      bbox,
      visual_bbox,
      attention_mask=None,
      num_patches=14,
      max_len=0,
      image_size=224,
      patch_size=16,
  ):
      """
      Combine the image and text embeddings for the input to the encoder/decoder of UDOP.

      Each token is mapped to the visual patch that contains the center of its bounding box, and that patch
      embedding is added to the token embedding (tokens whose box averages to exactly 0 or 1 receive nothing).
      The patches that were not claimed by any token are then appended after the text, together with their
      bounding boxes and - if a mask is given - attention-mask entries of 1. Per-sample patch lists are zero-padded
      to a common length.

      Args:
          image_embeddings (`torch.Tensor` of shape `(batch_size, num_patches ** 2, hidden_size)`):
              Patch embeddings of the image.
          inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
              Token embeddings. The tensor passed in is not modified.
          bbox (`torch.Tensor` of shape `(batch_size, sequence_length, 4)`):
              Token boxes as `(x0, y0, x1, y1)`; the patch lookup assumes coordinates normalized to `[0, 1]`.
          visual_bbox (`torch.Tensor` of shape `(num_patches ** 2, 4)` or `None`):
              Boxes of the patches. Computed with `get_visual_bbox` when `None`.
          attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
              Text attention mask to extend with the visual part.
          num_patches (`int`, *optional*, defaults to 14):
              Number of patches along one side of the (square) patch grid.
          max_len (`int`, *optional*, defaults to 0):
              Total output length (text + patches). `0` means "append as many slots as there are patches".
          image_size (`int`, *optional*, defaults to 224):
              Forwarded to `get_visual_bbox` when `visual_bbox` is `None`.
          patch_size (`int`, *optional*, defaults to 16):
              Forwarded to `get_visual_bbox` when `visual_bbox` is `None`.

      Returns:
          `tuple`: `(inputs_embeds, bbox, attention_mask)` extended along dimension 1. `bbox` is returned as
          `torch.float64`; `attention_mask` is `None` when none was passed in.
      """
      sequence_length = num_patches

      # Flat index of the patch that contains each token's box center (column + row * grid_width).
      ocr_points_x = torch.clip(
          torch.floor((bbox[:, :, 0] + bbox[:, :, 2]) / 2.0 * sequence_length).long(), 0, sequence_length - 1
      )
      ocr_points_y = (
          torch.clip(torch.floor((bbox[:, :, 1] + bbox[:, :, 3]) / 2.0 * sequence_length).long(), 0, sequence_length - 1)
          * sequence_length
      )
      ocr_points = ocr_points_x + ocr_points_y

      # make sure bounding boxes are of type float to calculate means
      bbox = bbox.to(torch.float64)
      bbox_mean = bbox.mean(-1)
      # Boxes that average to exactly 0 or 1 get no visual contribution.
      target_seg = (bbox_mean == 0.0) | (bbox_mean == 1.0)

      repeated_vision_embeds = torch.gather(
          image_embeddings, 1, ocr_points.unsqueeze(-1).repeat(1, 1, image_embeddings.size(-1))
      )
      repeated_vision_embeds[target_seg] = 0.0
      # Out-of-place add so the caller's tensor is left untouched; the cast keeps the dtype the in-place
      # version would have produced.
      inputs_embeds = (inputs_embeds + repeated_vision_embeds).to(inputs_embeds.dtype)

      # True for patches no token points at. One broadcast index assignment marks every used patch.
      patch_inds = torch.ones(image_embeddings.shape[:2], dtype=torch.bool, device=image_embeddings.device)
      batch_index = torch.arange(ocr_points.size(0), device=ocr_points.device).unsqueeze(-1)
      patch_inds[batch_index, ocr_points] = False

      input_vision_patches = [embeds[keep] for embeds, keep in zip(image_embeddings, patch_inds)]

      if visual_bbox is None:
          visual_bbox = get_visual_bbox(image_size=image_size, patch_size=patch_size)
      visual_bbox = visual_bbox.unsqueeze(0).repeat(image_embeddings.size(0), 1, 1)
      visual_bbox = visual_bbox.to(image_embeddings.device)
      visual_bbox_list = [boxes[keep] for boxes, keep in zip(visual_bbox, patch_inds)]

      # Number of visual slots appended to every sample.
      if max_len == 0:
          max_len = image_embeddings.size(1)
      else:
          max_len = max_len - inputs_embeds.size(1)

      embed_pad = torch.zeros_like(image_embeddings[0, 0])
      inputs_vision_patches = torch.stack(
          [pad_sequence(item, max_len, embed_pad) for item in input_vision_patches]
      )
      bbox_pad = torch.zeros_like(bbox[0, 0])
      visual_bbox = torch.stack([pad_sequence(item, max_len, bbox_pad) for item in visual_bbox_list])

      inputs_embeds = torch.cat([inputs_embeds, inputs_vision_patches], 1)
      bbox = torch.cat([bbox, visual_bbox], 1)

      if attention_mask is not None:
          # Real patches are attended to (1); the zero padding added above is masked out (0).
          mask_pad = torch.zeros_like(attention_mask[0, 0])
          visual_attention_mask = torch.stack(
              [
                  pad_sequence(
                      torch.ones(item.size(0), dtype=attention_mask.dtype, device=attention_mask.device),
                      max_len,
                      mask_pad,
                  )
                  for item in visual_bbox_list
              ]
          )
          attention_mask = torch.cat([attention_mask, visual_attention_mask], 1)

      return inputs_embeds, bbox, attention_mask
  class UdopPatchEmbeddings(nn.Module):
      """Project an image batch to a sequence of patch embeddings using a single strided convolution."""

      def __init__(self, config):
          super().__init__()

          def as_pair(value):
              # Iterables are kept as given; a scalar describes a square.
              return value if isinstance(value, collections.abc.Iterable) else (value, value)

          image_size = as_pair(config.image_size)
          patch_size = as_pair(config.patch_size)

          self.image_size = image_size
          self.patch_size = patch_size
          self.num_channels = config.num_channels
          self.num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])

          # Kernel size equal to stride: every patch is projected independently, without overlap.
          self.proj = nn.Conv2d(config.num_channels, config.hidden_size, kernel_size=patch_size, stride=patch_size)

      def forward(self, pixel_values):
          """
          Args:
              pixel_values (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
                  Images whose height and width must equal the configured `image_size`.

          Returns:
              `torch.Tensor` of shape `(batch_size, num_patches, hidden_size)`.

          Raises:
              ValueError: If the spatial size of `pixel_values` differs from the configured one.
          """
          _, _, height, width = pixel_values.shape
          if (height, width) != (self.image_size[0], self.image_size[1]):
              raise ValueError(
                  f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
              )
          patch_grid = self.proj(pixel_values)
          # (batch, hidden, rows, cols) -> (batch, rows * cols, hidden)
          return patch_grid.flatten(2).transpose(1, 2)
  @auto_docstring
  class UdopPreTrainedModel(PreTrainedModel):
      # Common base class of the UDOP models in this file: it declares the class-level settings below and
      # provides weight initialization (`_init_weights`) and the decoder-input helper (`_shift_right`).
      config: UdopConfig
      base_model_prefix = "transformer"
      input_modalities = ("image", "text")
      supports_gradient_checkpointing = True
      _can_compile_fullgraph = False
      # "wo" is the output projection of the feed-forward blocks defined in this file.
      _keep_in_fp32_modules = ["wo"]

      @torch.no_grad()
      def _init_weights(self, module):
          """
          Initialize the weights of `module`.

          The branch is chosen by the type of `module`; the first matching `isinstance` test wins, so the order
          of the checks matters. Every standard deviation is scaled by `config.initializer_factor`.
          """
          factor = self.config.initializer_factor  # Used for testing weights initialization
          if isinstance(module, UdopLayerNorm):
              # scale-only norm: weight starts at `factor`
              init.constant_(module.weight, factor * 1.0)
          elif isinstance(module, nn.Embedding):
              init.normal_(module.weight, mean=0.0, std=factor)
              # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag
              if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
                  init.zeros_(module.weight[module.padding_idx])
          elif isinstance(module, nn.Conv2d):
              init.trunc_normal_(module.weight, mean=0.0, std=factor)
              if module.bias is not None:
                  init.zeros_(module.bias)
          elif isinstance(module, RelativePositionBiasBase):
              # NOTE(review): `factor` was already read at the top of the method; this re-read is redundant.
              factor = self.config.initializer_factor
              d_model = self.config.d_model
              init.normal_(module.relative_attention_bias.weight, mean=0.0, std=factor * ((d_model) ** -0.5))
          elif isinstance(module, UdopModel):
              init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0)
          elif isinstance(module, UdopForConditionalGeneration):
              # the head is only initialized on its own when it is not tied to the input embeddings
              if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
                  init.normal_(module.lm_head.weight, mean=0.0, std=factor * 1.0)
          elif isinstance(module, UdopDenseActDense):
              # input projection scaled by d_model, output projection by d_ff
              init.normal_(module.wi.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
              if hasattr(module.wi, "bias") and module.wi.bias is not None:
                  init.zeros_(module.wi.bias)
              init.normal_(module.wo.weight, mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
              if hasattr(module.wo, "bias") and module.wo.bias is not None:
                  init.zeros_(module.wo.bias)
          elif isinstance(module, UdopDenseGatedActDense):
              # same scheme as above, with two input projections (gate and linear branch)
              init.normal_(module.wi_0.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
              if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
                  init.zeros_(module.wi_0.bias)
              init.normal_(module.wi_1.weight, mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
              if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
                  init.zeros_(module.wi_1.bias)
              init.normal_(module.wo.weight, mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
              if hasattr(module.wo, "bias") and module.wo.bias is not None:
                  init.zeros_(module.wo.bias)
          elif isinstance(module, UdopAttention):
              d_model = self.config.d_model
              key_value_proj_dim = self.config.d_kv
              n_heads = self.config.num_heads
              # the query projection additionally folds in the per-head dimension
              init.normal_(module.q.weight, mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
              init.normal_(module.k.weight, mean=0.0, std=factor * (d_model**-0.5))
              init.normal_(module.v.weight, mean=0.0, std=factor * (d_model**-0.5))
              init.normal_(module.o.weight, mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
              if module.has_relative_attention_bias:
                  init.normal_(module.relative_attention_bias.weight, mean=0.0, std=factor * ((d_model) ** -0.5))

      # Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetPreTrainedModel._shift_right with ProphetNet->Udop
      def _shift_right(self, input_ids):
          """
          Shift `input_ids` one position to the right along the last dimension.

          The first position is filled with `config.decoder_start_token_id`, the last input token is dropped and
          any `-100` left in the result is replaced by `config.pad_token_id`.

          Raises:
              AssertionError: If `decoder_start_token_id` or `pad_token_id` is not set in the config, or if the
                  shifted ids still contain negative values.
          """
          # NOTE(review): validation relies on `assert`, which is stripped when Python runs with `-O`.
          decoder_start_token_id = self.config.decoder_start_token_id
          pad_token_id = self.config.pad_token_id
          assert decoder_start_token_id is not None, (
              "self.model.config.decoder_start_token_id has to be defined. In Udop it is usually set to the"
              " pad_token_id. See Udop docs for more information"
          )
          # shift inputs to the right
          shifted_input_ids = input_ids.new_zeros(input_ids.shape)
          shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
          shifted_input_ids[..., 0] = decoder_start_token_id
          assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
          # replace possible -100 values in labels by `pad_token_id`
          shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
          assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"
          return shifted_input_ids
  285. # Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->Udop
  class UdopLayerNorm(nn.Module):
      """Scale-only layer normalization (RMS norm): no mean subtraction and no bias."""

      def __init__(self, hidden_size, eps=1e-6):
          """
          Construct a layernorm module in the Udop style. No bias and no subtraction of mean.
          """
          super().__init__()
          self.weight = nn.Parameter(torch.ones(hidden_size))
          self.variance_epsilon = eps

      def forward(self, hidden_states):
          # Root Mean Square Layer Normalization (https://huggingface.co/papers/1910.07467): the statistic is
          # the mean of squares over the last dimension, taken without centering. It is accumulated in fp32 so
          # that half-precision inputs do not lose accuracy.
          mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
          normalized = hidden_states * torch.rsqrt(mean_square + self.variance_epsilon)

          # Go back to half precision only when the learned scale itself is stored in half precision.
          if self.weight.dtype in (torch.float16, torch.bfloat16):
              normalized = normalized.to(self.weight.dtype)

          return self.weight * normalized
  305. # Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->Udop
  class UdopDenseActDense(nn.Module):
      """Feed-forward block: linear -> activation -> dropout -> linear, all without bias."""

      def __init__(self, config: UdopConfig):
          super().__init__()
          self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
          self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
          self.dropout = nn.Dropout(config.dropout_rate)
          self.act = ACT2FN[config.dense_act_fn]

      def forward(self, hidden_states):
          hidden_states = self.dropout(self.act(self.wi(hidden_states)))

          # Align the activations with the dtype of the output projection, unless that projection holds
          # int8 weights.
          out_weight = self.wo.weight
          needs_cast = (
              isinstance(out_weight, torch.Tensor)
              and out_weight.dtype != torch.int8
              and hidden_states.dtype != out_weight.dtype
          )
          if needs_cast:
              hidden_states = hidden_states.to(out_weight.dtype)

          return self.wo(hidden_states)
  325. # Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->Udop
  class UdopDenseGatedActDense(nn.Module):
      """Gated feed-forward block: an activated branch multiplied by a linear branch, then projected back."""

      def __init__(self, config: UdopConfig):
          super().__init__()
          self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
          self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
          self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
          self.dropout = nn.Dropout(config.dropout_rate)
          self.act = ACT2FN[config.dense_act_fn]

      def forward(self, hidden_states):
          gate = self.act(self.wi_0(hidden_states))
          linear_branch = self.wi_1(hidden_states)
          hidden_states = self.dropout(gate * linear_branch)

          # `self.wo` is kept in float32 so that 8bit quantization works for google/flan-t5-xxl
          # (see https://github.com/huggingface/transformers/issues/20287), hence the activations are cast to
          # its dtype. The int8 check covers users who force `_keep_in_fp32_modules` to `None`.
          out_weight = self.wo.weight
          needs_cast = (
              isinstance(out_weight, torch.Tensor)
              and out_weight.dtype != torch.int8
              and hidden_states.dtype != out_weight.dtype
          )
          if needs_cast:
              hidden_states = hidden_states.to(out_weight.dtype)

          return self.wo(hidden_states)
  350. # Copied from transformers.models.t5.modeling_t5.T5LayerFF with T5->Udop
  class UdopLayerFF(nn.Module):
      """Feed-forward sub-layer: layer norm, dense block and dropout, added back onto the input."""

      def __init__(self, config: UdopConfig):
          super().__init__()
          # The config decides between the gated and the plain dense block.
          dense_cls = UdopDenseGatedActDense if config.is_gated_act else UdopDenseActDense
          self.DenseReluDense = dense_cls(config)
          self.layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
          self.dropout = nn.Dropout(config.dropout_rate)

      def forward(self, hidden_states):
          # Normalize first, transform, then add the result to the untouched input (residual connection).
          update = self.DenseReluDense(self.layer_norm(hidden_states))
          return hidden_states + self.dropout(update)
  365. # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->Udop
  366. class UdopAttention(nn.Module):
  367. def __init__(
  368. self,
  369. config: UdopConfig,
  370. has_relative_attention_bias=False,
  371. layer_idx: int | None = None,
  372. ):
  373. super().__init__()
  374. self.is_decoder = config.is_decoder
  375. self.has_relative_attention_bias = has_relative_attention_bias
  376. self.relative_attention_num_buckets = config.relative_attention_num_buckets
  377. self.relative_attention_max_distance = config.relative_attention_max_distance
  378. self.d_model = config.d_model
  379. self.key_value_proj_dim = config.d_kv
  380. self.n_heads = config.num_heads
  381. self.dropout = config.dropout_rate
  382. self.inner_dim = self.n_heads * self.key_value_proj_dim
  383. self.layer_idx = layer_idx
  384. if layer_idx is None and self.is_decoder:
  385. logger.warning_once(
  386. f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
  387. "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
  388. "when creating this class."
  389. )
  390. self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
  391. self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
  392. self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
  393. self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
  394. if self.has_relative_attention_bias:
  395. self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
  396. self.gradient_checkpointing = False
  397. @staticmethod
  398. def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
  399. """
  400. Adapted from Mesh Tensorflow:
  401. https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
  402. Translate relative position to a bucket number for relative attention. The relative position is defined as
  403. memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
  404. position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
  405. small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
  406. positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
  407. This should allow for more graceful generalization to longer sequences than the model has been trained on
  408. Args:
  409. relative_position: an int32 Tensor
  410. bidirectional: a boolean - whether the attention is bidirectional
  411. num_buckets: an integer
  412. max_distance: an integer
  413. Returns:
  414. a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
  415. """
  416. relative_buckets = 0
  417. if bidirectional:
  418. num_buckets //= 2
  419. relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
  420. relative_position = torch.abs(relative_position)
  421. else:
  422. relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
  423. # now relative_position is in the range [0, inf)
  424. # half of the buckets are for exact increments in positions
  425. max_exact = num_buckets // 2
  426. is_small = relative_position < max_exact
  427. # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
  428. relative_position_if_large = max_exact + (
  429. torch.log(relative_position.float() / max_exact)
  430. / math.log(max_distance / max_exact)
  431. * (num_buckets - max_exact)
  432. ).to(torch.long)
  433. relative_position_if_large = torch.min(
  434. relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
  435. )
  436. relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
  437. return relative_buckets
  438. def compute_bias(self, query_length, key_length, device=None, past_seen_tokens=0):
  439. """Compute binned relative position bias"""
  440. if device is None:
  441. device = self.relative_attention_bias.weight.device
  442. context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + past_seen_tokens
  443. memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
  444. relative_position = memory_position - context_position # shape (query_length, key_length)
  445. relative_position_bucket = self._relative_position_bucket(
  446. relative_position, # shape (query_length, key_length)
  447. bidirectional=(not self.is_decoder),
  448. num_buckets=self.relative_attention_num_buckets,
  449. max_distance=self.relative_attention_max_distance,
  450. )
  451. values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
  452. values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
  453. return values
    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_values=None,
        output_attentions=False,
        **kwargs,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).

        Args:
            hidden_states: Input the queries are projected from (and the keys/values too, for self-attention).
            mask: Optional additive mask. It is only consumed when `position_bias` is computed in this call, in which
                case it is sliced to the key length and folded into the returned `position_bias`.
            key_value_states: When given, keys/values are projected from it and the layer acts as cross-attention.
            position_bias: Optional precomputed bias; when given it is added to the scores as-is.
            past_key_values: Optional cache. An `EncoderDecoderCache` is split into its self-/cross-attention part,
                any other cache object is used directly.
            output_attentions: Whether to append the attention probabilities to the returned tuple.
            **kwargs: Accepted and ignored.

        Returns:
            `(attn_output, position_bias)`, followed by `attn_weights` when `output_attentions` is set.
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.key_value_proj_dim)

        # Must be read before the cache is updated further down: it is handed to `compute_bias` afterwards.
        past_seen_tokens = past_key_values.get_seq_length(self.layer_idx) if past_key_values is not None else 0
        # We clone here for StaticCache, as we get the value before updating it, but use it after and it's the same ref
        past_seen_tokens = past_seen_tokens.clone() if isinstance(past_seen_tokens, torch.Tensor) else past_seen_tokens

        # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
        is_cross_attention = key_value_states is not None

        query_states = self.q(hidden_states).view(hidden_shape).transpose(1, 2)

        # Check if encoder-decoder model is being used. Otherwise we'll get `DynamicCache`
        is_updated = False
        if isinstance(past_key_values, EncoderDecoderCache):
            is_updated = past_key_values.is_updated.get(self.layer_idx)
            if is_cross_attention:
                # after the first generated id, we can subsequently re-use all key/value_states from cache
                curr_past_key_values = past_key_values.cross_attention_cache
            else:
                curr_past_key_values = past_key_values.self_attention_cache
        else:
            curr_past_key_values = past_key_values

        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_values is not None and is_updated:
            # reuse k,v, cross_attentions
            key_states = curr_past_key_values.layers[self.layer_idx].keys
            value_states = curr_past_key_values.layers[self.layer_idx].values
        else:
            kv_shape = (*current_states.shape[:-1], -1, self.key_value_proj_dim)
            key_states = self.k(current_states).view(kv_shape).transpose(1, 2)
            value_states = self.v(current_states).view(kv_shape).transpose(1, 2)

            if past_key_values is not None:
                # `update` returns the key/value states to attend over from now on
                key_states, value_states = curr_past_key_values.update(key_states, value_states, self.layer_idx)
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
                    past_key_values.is_updated[self.layer_idx] = True

        # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
        scores = torch.matmul(query_states, key_states.transpose(3, 2))

        if position_bias is None:
            key_length = key_states.shape[-2]
            if not self.has_relative_attention_bias:
                # layers without a learned bias fall back to an all-zero bias
                position_bias = torch.zeros(
                    (1, query_states.shape[1], input_shape[1], key_length), device=scores.device, dtype=scores.dtype
                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
                position_bias = self.compute_bias(
                    input_shape[1], key_length, device=scores.device, past_seen_tokens=past_seen_tokens
                )

            if mask is not None:
                # the mask may be longer than the keys; trim it before merging it into the bias
                causal_mask = mask[:, :, :, : key_states.shape[-2]]
                position_bias = position_bias + causal_mask

        position_bias_masked = position_bias
        scores += position_bias_masked

        # (batch_size, n_heads, seq_length, key_length)
        # softmax is evaluated in float32 and cast back to the dtype of `scores`
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.matmul(attn_weights, value_states)

        # merge the heads back into the last dimension before the output projection
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(*input_shape, -1)
        attn_output = self.o(attn_output)

        outputs = (attn_output, position_bias)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs
  532. # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->Udop
  533. class UdopLayerSelfAttention(nn.Module):
  534. def __init__(self, config, has_relative_attention_bias=False, layer_idx: int | None = None):
  535. super().__init__()
  536. self.SelfAttention = UdopAttention(
  537. config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
  538. )
  539. self.layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
  540. self.dropout = nn.Dropout(config.dropout_rate)
  541. def forward(
  542. self,
  543. hidden_states,
  544. attention_mask=None,
  545. position_bias=None,
  546. past_key_values=None,
  547. use_cache=False,
  548. output_attentions=False,
  549. **kwargs,
  550. ):
  551. normed_hidden_states = self.layer_norm(hidden_states)
  552. attention_output = self.SelfAttention(
  553. normed_hidden_states,
  554. mask=attention_mask,
  555. position_bias=position_bias,
  556. past_key_values=past_key_values,
  557. use_cache=use_cache,
  558. output_attentions=output_attentions,
  559. )
  560. hidden_states = hidden_states + self.dropout(attention_output[0])
  561. outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
  562. return outputs
  563. # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->Udop
  564. class UdopLayerCrossAttention(nn.Module):
  565. def __init__(self, config, layer_idx: int | None = None):
  566. super().__init__()
  567. self.EncDecAttention = UdopAttention(config, has_relative_attention_bias=False, layer_idx=layer_idx)
  568. self.layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
  569. self.dropout = nn.Dropout(config.dropout_rate)
  570. def forward(
  571. self,
  572. hidden_states,
  573. key_value_states,
  574. attention_mask=None,
  575. position_bias=None,
  576. past_key_values=None,
  577. output_attentions=False,
  578. **kwargs,
  579. ):
  580. normed_hidden_states = self.layer_norm(hidden_states)
  581. attention_output = self.EncDecAttention(
  582. normed_hidden_states,
  583. mask=attention_mask,
  584. key_value_states=key_value_states,
  585. position_bias=position_bias,
  586. past_key_values=past_key_values,
  587. output_attentions=output_attentions,
  588. )
  589. layer_output = hidden_states + self.dropout(attention_output[0])
  590. outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
  591. return outputs
  592. # Copied from transformers.models.t5.modeling_t5.T5Block with T5->Udop
  593. class UdopBlock(GradientCheckpointingLayer):
  594. def __init__(self, config, has_relative_attention_bias=False, layer_idx: int | None = None):
  595. super().__init__()
  596. self.is_decoder = config.is_decoder
  597. self.layer = nn.ModuleList()
  598. self.layer.append(
  599. UdopLayerSelfAttention(
  600. config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
  601. )
  602. )
  603. if self.is_decoder:
  604. self.layer.append(UdopLayerCrossAttention(config, layer_idx=layer_idx))
  605. self.layer.append(UdopLayerFF(config))
  606. def forward(
  607. self,
  608. hidden_states,
  609. attention_mask=None,
  610. position_bias=None,
  611. encoder_hidden_states=None,
  612. encoder_attention_mask=None,
  613. encoder_decoder_position_bias=None,
  614. past_key_values=None,
  615. use_cache=False,
  616. output_attentions=False,
  617. return_dict=True,
  618. **kwargs,
  619. ):
  620. self_attention_outputs = self.layer[0](
  621. hidden_states,
  622. attention_mask=attention_mask,
  623. position_bias=position_bias,
  624. past_key_values=past_key_values,
  625. use_cache=use_cache,
  626. output_attentions=output_attentions,
  627. )
  628. hidden_states = self_attention_outputs[0]
  629. attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights
  630. # clamp inf values to enable fp16 training
  631. if hidden_states.dtype == torch.float16:
  632. clamp_value = torch.where(
  633. torch.isinf(hidden_states).any(),
  634. torch.finfo(hidden_states.dtype).max - 1000,
  635. torch.finfo(hidden_states.dtype).max,
  636. )
  637. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  638. do_cross_attention = self.is_decoder and encoder_hidden_states is not None
  639. if do_cross_attention:
  640. cross_attention_outputs = self.layer[1](
  641. hidden_states,
  642. key_value_states=encoder_hidden_states,
  643. attention_mask=encoder_attention_mask,
  644. position_bias=encoder_decoder_position_bias,
  645. past_key_values=past_key_values,
  646. output_attentions=output_attentions,
  647. )
  648. hidden_states = cross_attention_outputs[0]
  649. # clamp inf values to enable fp16 training
  650. if hidden_states.dtype == torch.float16:
  651. clamp_value = torch.where(
  652. torch.isinf(hidden_states).any(),
  653. torch.finfo(hidden_states.dtype).max - 1000,
  654. torch.finfo(hidden_states.dtype).max,
  655. )
  656. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  657. # Keep cross-attention outputs and relative position weights
  658. attention_outputs = attention_outputs + cross_attention_outputs[1:]
  659. # Apply Feed Forward layer
  660. hidden_states = self.layer[-1](hidden_states)
  661. # clamp inf values to enable fp16 training
  662. if hidden_states.dtype == torch.float16:
  663. clamp_value = torch.where(
  664. torch.isinf(hidden_states).any(),
  665. torch.finfo(hidden_states.dtype).max - 1000,
  666. torch.finfo(hidden_states.dtype).max,
  667. )
  668. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  669. outputs = (hidden_states,)
  670. return (
  671. outputs + attention_outputs
  672. ) # hidden-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
  673. class UdopCellEmbeddings(nn.Module):
  674. def __init__(self, max_2d_position_embeddings=501, hidden_size=1024):
  675. super().__init__()
  676. self.max_2d_position_embeddings = max_2d_position_embeddings
  677. self.x_position_embeddings = nn.Embedding(max_2d_position_embeddings, hidden_size)
  678. self.y_position_embeddings = nn.Embedding(max_2d_position_embeddings, hidden_size)
  679. def forward(self, bbox):
  680. bbox = torch.clip(bbox, 0.0, 1.0)
  681. bbox = (bbox * (self.max_2d_position_embeddings - 1)).long()
  682. left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
  683. upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])
  684. right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])
  685. lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])
  686. embeddings = (
  687. left_position_embeddings
  688. + upper_position_embeddings
  689. + right_position_embeddings
  690. + lower_position_embeddings
  691. )
  692. return embeddings
# Re-use the bucket computation of `UdopAttention` for the bias modules defined below.
# Accessing a protected member is the lesser evil compared to copy-pasting the whole function.
get_relative_position_bucket = UdopAttention._relative_position_bucket
# Bounds passed to `random.uniform` when relative distances are randomly rescaled (augmentation during training).
AUGMENTATION_RANGE = (0.80, 1.25)
  697. class RelativePositionBiasBase(nn.Module, ABC):
  698. """
  699. Base class of relative biases.
  700. Args:
  701. num_heads (`int`):
  702. Number of attention heads in the model, it will create embeddings of size `num_heads`, which will be added to the scores of each token pair.
  703. relative_attention_num_buckets (`int`, *optional*, defaults to 32):
  704. Pair token metric (distance in the sequence, distance in pixels etc.) will be bucketed, parameter is defining number of such
  705. buckets.
  706. bidirectional (`bool`, *optional*, defaults to `True`):
  707. Whether the distance should be bidirectional for a pair of tokens. If `False`, then distance(tok1, tok2) == distance(tok2, tok1).
  708. scaling_factor (`int`, *optional*, defaults to 1):
  709. Defining factor which will be used to scale relative distance.
  710. max_distance (`int`, *optional*, defaults to 128):
  711. All distances above this value will end up in the one/same bucket.
  712. augmentation (`bool`, *optional*, defaults to `False`):
  713. Whether to multiply relative distances by a random scalar.
  714. expand (`bool`, *optional*, defaults to `False`):
  715. Whether to expand an existing pretrained model with subsequent additions of prefix_bucket.
  716. """
  717. def __init__(
  718. self,
  719. num_heads=None,
  720. relative_attention_num_buckets=32,
  721. bidirectional=True,
  722. scaling_factor=1,
  723. max_distance=128,
  724. level="tokens",
  725. augmentation=False,
  726. prefix_bucket=False,
  727. expand=False,
  728. ):
  729. super().__init__()
  730. self.prefix_bucket = prefix_bucket
  731. self.augmentation = augmentation
  732. self.level = level
  733. self.max_distance = max_distance
  734. self.scaling_factor = scaling_factor
  735. self.bidirectional = bidirectional
  736. self.num_heads = num_heads
  737. self.expand = expand
  738. self.relative_attention_num_buckets = relative_attention_num_buckets
  739. extra_head = 2 if prefix_bucket and not self.expand else 0
  740. self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets + extra_head, self.num_heads)
  741. @abstractmethod
  742. def prepare_input(
  743. self,
  744. attention_mask: Tensor | None = None,
  745. bbox: dict[str, Any] | None = None,
  746. ) -> Tensor:
  747. pass
  748. def get_bucket(self, attention_mask: Tensor | None = None, bbox: dict[str, Any] | None = None) -> Tensor:
  749. relative_position = self.prepare_input(attention_mask, bbox)
  750. rp_bucket: Tensor = get_relative_position_bucket(
  751. relative_position,
  752. bidirectional=self.bidirectional,
  753. num_buckets=self.relative_attention_num_buckets,
  754. max_distance=self.max_distance,
  755. )
  756. return rp_bucket
  757. def get_relative_position(self, positions):
  758. context_position = positions[:, :, None]
  759. memory_position = positions[:, None, :]
  760. relative_position = memory_position - context_position
  761. if self.augmentation and self.training:
  762. relative_position *= random.uniform(*AUGMENTATION_RANGE)
  763. relative_position *= self.scaling_factor
  764. return relative_position.to(torch.long)
  765. def forward(self, attention_mask: Tensor | None = None, bbox: dict[str, Any] | None = None) -> Tensor:
  766. # re-using pretrained model with subsequent addition of prefix_bucket
  767. if self.expand and self.prefix_bucket:
  768. new_bias = nn.Embedding(self.relative_attention_num_buckets + 2, self.num_heads)
  769. new_bias.weight.data[: self.relative_attention_num_buckets] = self.relative_attention_bias.weight.data
  770. new_bias.weight.data[self.relative_attention_num_buckets :] = 0.1
  771. self.relative_attention_bias = new_bias
  772. self.expand = False
  773. rp_bucket = self.get_bucket(attention_mask, bbox)
  774. if self.prefix_bucket:
  775. if rp_bucket.size(0) == 1 and attention_mask.size(0) > 1:
  776. rp_bucket = rp_bucket.repeat(attention_mask.size(0), 1, 1)
  777. # based on assumption that prefix bboxes are negative
  778. is_prefix = bbox[:, :, 1] < 0
  779. num_prefix = is_prefix.sum(-1)
  780. for idx, num_prefix_row in enumerate(num_prefix.cpu().numpy()):
  781. rp_bucket[idx, :num_prefix_row, num_prefix_row:] = self.relative_attention_num_buckets
  782. rp_bucket[idx, num_prefix_row:, :num_prefix_row] = self.relative_attention_num_buckets + 1
  783. values: Tensor = self.relative_attention_bias(rp_bucket)
  784. if values.dim() != 4:
  785. raise ValueError("Wrong dimension of values tensor")
  786. values = values.permute([0, 3, 1, 2])
  787. return values
  788. class RelativePositionBias1D(RelativePositionBiasBase):
  789. def __init__(self, scaling_factor=1, max_distance=128, **kwargs):
  790. """
  791. Reimplementation of T5 relative position bias. Distance between given tokens is their distance in the sequence.
  792. Parameters are the same as in base class
  793. """
  794. super().__init__(scaling_factor=scaling_factor, max_distance=max_distance, **kwargs)
  795. def prepare_input(self, attention_mask: Tensor | None = None, bbox: dict[str, Any] | None = None) -> Tensor:
  796. if self.scaling_factor != 1:
  797. raise ValueError("No need to scale 1d features")
  798. relative_position = self.get_relative_position(
  799. torch.arange(attention_mask.size(1), dtype=torch.long, device=attention_mask.device)[None, :]
  800. )
  801. return relative_position
  802. class RelativePositionBiasHorizontal(RelativePositionBiasBase):
  803. def __init__(self, scaling_factor=100, max_distance=100, **kwargs):
  804. """
  805. Represents in the bucket embeddings horizontal distance between two tokens. Parameters are the same as in base
  806. class
  807. """
  808. super().__init__(scaling_factor=scaling_factor, max_distance=max_distance, **kwargs)
  809. def prepare_input(self, attention_mask: Tensor | None = None, bbox: dict[str, Any] | None = None) -> Tensor:
  810. if not self.scaling_factor > 1.0:
  811. raise ValueError("Need to scale the values of bboxes, as there are in small (0,1) range")
  812. if bbox is None:
  813. raise ValueError("Bbox is required for horizontal relative position bias")
  814. # get x positions of left point of bbox
  815. horizontal_position: Tensor = bbox[:, :, [0, 2]].mean(dim=-1)
  816. return self.get_relative_position(horizontal_position)
  817. class RelativePositionBiasVertical(RelativePositionBiasBase):
  818. def __init__(self, scaling_factor=100, max_distance=100, **kwargs):
  819. """
  820. Represents in the bucket embeddings vertical distance between two tokens. Parameters are the same as in base
  821. class
  822. """
  823. super().__init__(scaling_factor=scaling_factor, max_distance=max_distance, **kwargs)
  824. def prepare_input(self, attention_mask: Tensor | None = None, bbox: dict[str, Any] | None = None) -> Tensor:
  825. if not self.scaling_factor > 1.0:
  826. raise ValueError("Need to scale the values of bboxes, as there are in small (0,1) range")
  827. if bbox is None:
  828. raise ValueError("Bbox is required for vertical relative position bias")
  829. # get y positions of middle of bbox
  830. vertical_position: Tensor = bbox[:, :, [1, 3]].mean(dim=-1)
  831. return self.get_relative_position(vertical_position)
  832. class RelativePositionBiasAggregated(nn.Module):
  833. def __init__(self, modules: Sequence[RelativePositionBiasBase]):
  834. """
  835. Class which sums up various computed biases.
  836. Args:
  837. modules (Sequence[RelativePositionBiasBase]):
  838. List of relative bias modules.
  839. """
  840. super().__init__()
  841. self.biases = nn.ModuleList(modules)
  842. def forward(self, attention_mask: Tensor | None = None, bbox: dict[str, Any] | None = None) -> float | Tensor:
  843. output = 0.0
  844. for bias in self.biases: # type: ignore
  845. output = bias(attention_mask, bbox) + output
  846. return output
# Maps the "type" entry of a relative-bias configuration dict to the module class that implements it
# (looked up in `create_relative_bias`).
BIAS_CLASSES = {
    "1d": RelativePositionBias1D,
    "horizontal": RelativePositionBiasHorizontal,
    "vertical": RelativePositionBiasVertical,
}
  852. def create_relative_bias(config: UdopConfig) -> Sequence[RelativePositionBiasBase]:
  853. """
  854. Creates empty list or one/multiple relative biases.
  855. :param config: Model's configuration :return: Sequence with created bias modules.
  856. """
  857. bias_list = []
  858. if hasattr(config, "relative_bias_args"):
  859. for bias_kwargs_org in config.relative_bias_args:
  860. bias_kwargs = deepcopy(bias_kwargs_org)
  861. bias_type = bias_kwargs.pop("type")
  862. model_num_heads = config.num_heads if hasattr(config, "num_heads") else config.num_attention_heads
  863. if "num_heads" in bias_kwargs:
  864. if bias_kwargs["num_heads"] != model_num_heads:
  865. raise ValueError("Number of heads must match num of heads in the model")
  866. else:
  867. bias_kwargs["num_heads"] = model_num_heads
  868. bias_list.append(BIAS_CLASSES[bias_type](**bias_kwargs)) # type: ignore
  869. return bias_list
  870. class UdopStack(UdopPreTrainedModel):
  871. """
  872. This class is based on `T5Stack`, but modified to take into account the image modality as well as 2D position
  873. embeddings.
  874. """
  875. def __init__(self, config):
  876. super().__init__(config)
  877. # text and image embeddings
  878. self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
  879. self.embed_patches = UdopPatchEmbeddings(config)
  880. self.is_decoder = config.is_decoder
  881. self.num_layers = config.num_layers
  882. self.block = nn.ModuleList(
  883. [UdopBlock(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) for i in range(self.num_layers)]
  884. )
  885. self.final_layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
  886. self.dropout = nn.Dropout(config.dropout_rate)
  887. if not self.is_decoder:
  888. self.cell_2d_embedding = UdopCellEmbeddings(config.max_2d_position_embeddings, config.hidden_size)
  889. # get weights from encoder position bias
  890. self.relative_bias = self._get_relative_bias(config)
  891. self.post_init()
  892. @staticmethod
  893. def _get_relative_bias(config: UdopConfig) -> RelativePositionBiasAggregated:
  894. relative_bias_list = create_relative_bias(config)
  895. return RelativePositionBiasAggregated(relative_bias_list)
  896. def get_output_embeddings(self):
  897. return self.embed_tokens
  898. def set_input_embeddings(self, new_embeddings):
  899. self.embed_tokens = new_embeddings
  900. def forward(
  901. self,
  902. input_ids=None,
  903. attention_mask=None,
  904. bbox=None,
  905. encoder_hidden_states=None,
  906. encoder_attention_mask=None,
  907. inputs_embeds=None,
  908. pixel_values=None,
  909. visual_bbox=None,
  910. image_embeddings=None,
  911. position_bias=None,
  912. past_key_values=None,
  913. use_cache=None,
  914. output_attentions=None,
  915. output_hidden_states=None,
  916. return_dict=None,
  917. **kwargs,
  918. ) -> tuple | BaseModelOutputWithAttentionMask:
  919. use_cache = use_cache if use_cache is not None else self.config.use_cache
  920. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  921. output_hidden_states = (
  922. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  923. )
  924. return_dict = return_dict if return_dict is not None else self.config.return_dict
  925. # input embeddings processing
  926. if input_ids is not None and inputs_embeds is not None:
  927. err_msg_prefix = "decoder_" if self.is_decoder else ""
  928. raise ValueError(
  929. f"You cannot specify both {err_msg_prefix}inputs and {err_msg_prefix}inputs_embeds at the same time"
  930. )
  931. elif input_ids is not None and torch.numel(input_ids) > 0:
  932. input_shape = input_ids.size()
  933. input_ids = input_ids.view(-1, input_shape[-1])
  934. elif inputs_embeds is None and input_ids is not None and torch.numel(input_ids) == 0:
  935. input_ids = torch.full((4, 1024), self.config.pad_token_id, device=input_ids.device, dtype=input_ids.dtype)
  936. attention_mask = torch.zeros((4, 1024), device=input_ids.device, dtype=input_ids.dtype)
  937. bbox = torch.zeros((4, 1024, 4), device=input_ids.device, dtype=input_ids.dtype)
  938. input_shape = input_ids.size()
  939. position_bias = torch.zeros_like(self.get_extended_attention_mask(attention_mask, input_shape))
  940. # encoder_attention_mask = attention_mask
  941. logger.warning("Empty batch")
  942. elif inputs_embeds is not None:
  943. input_shape = inputs_embeds.size()[:-1]
  944. else:
  945. err_msg_prefix = "decoder_" if self.is_decoder else ""
  946. raise ValueError(f"You have to specify either {err_msg_prefix}inputs or {err_msg_prefix}inputs_embeds")
  947. if inputs_embeds is None:
  948. if self.embed_tokens is None:
  949. raise ValueError("You have to initialize the model with valid token embeddings")
  950. inputs_embeds = self.embed_tokens(input_ids)
  951. if pixel_values is not None:
  952. image_embeddings = self.embed_patches(pixel_values)
  953. if image_embeddings is not None:
  954. # combine visual and OCR text embeddings
  955. num_patches = self.config.image_size // self.config.patch_size
  956. inputs_embeds, bbox, attention_mask = combine_image_text_embeddings(
  957. image_embeddings,
  958. inputs_embeds,
  959. bbox,
  960. visual_bbox,
  961. attention_mask,
  962. num_patches,
  963. 0,
  964. self.config.image_size,
  965. self.config.patch_size,
  966. )
  967. input_shape = inputs_embeds.size()[:-1]
  968. if not self.is_decoder and bbox is not None:
  969. inputs_embeds += self.cell_2d_embedding(bbox)
  970. batch_size, seq_length = input_shape
  971. if use_cache is True:
  972. assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"
  973. if self.is_decoder:
  974. if use_cache and past_key_values is None:
  975. if self.config.is_encoder_decoder:
  976. past_key_values = EncoderDecoderCache(
  977. DynamicCache(config=self.config), DynamicCache(config=self.config)
  978. )
  979. else:
  980. past_key_values = DynamicCache(config=self.config)
  981. elif not self.is_decoder:
  982. # do not pass cache object down the line for encoder stack
  983. # it messes indexing later in decoder-stack because cache object is modified in-place
  984. past_key_values = None
  985. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  986. if attention_mask is None and not is_torchdynamo_compiling():
  987. # required mask seq length can be calculated via length of past cache
  988. mask_seq_length = past_key_values_length + seq_length
  989. attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
  990. if self.config.is_decoder:
  991. causal_mask = create_causal_mask(
  992. config=self.config,
  993. inputs_embeds=inputs_embeds,
  994. attention_mask=attention_mask,
  995. past_key_values=past_key_values,
  996. )
  997. else:
  998. causal_mask = attention_mask[:, None, None, :]
  999. causal_mask = causal_mask.to(dtype=inputs_embeds.dtype)
  1000. causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min
  1001. if self.is_decoder and encoder_attention_mask is not None:
  1002. encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
  1003. else:
  1004. encoder_extended_attention_mask = None
  1005. all_hidden_states = () if output_hidden_states else None
  1006. all_attentions = () if output_attentions else None
  1007. all_cross_attentions = () if (output_attentions and self.is_decoder) else None
  1008. if self.is_decoder: # modified lines
  1009. position_bias = None
  1010. else:
  1011. position_bias = self.relative_bias(attention_mask=attention_mask, bbox=bbox)
  1012. position_bias = position_bias + causal_mask
  1013. encoder_decoder_position_bias = None
  1014. hidden_states = inputs_embeds
  1015. hidden_states = self.dropout(hidden_states)
  1016. for i, layer_module in enumerate(self.block):
  1017. if output_hidden_states:
  1018. all_hidden_states = all_hidden_states + (hidden_states,)
  1019. layer_outputs = layer_module(
  1020. hidden_states,
  1021. causal_mask,
  1022. position_bias,
  1023. encoder_hidden_states,
  1024. encoder_extended_attention_mask,
  1025. encoder_decoder_position_bias, # as a positional argument for gradient checkpointing
  1026. past_key_values=past_key_values,
  1027. use_cache=use_cache,
  1028. output_attentions=output_attentions,
  1029. )
  1030. hidden_states = layer_outputs[0]
  1031. # We share the position biases between the layers - the first layer store them
  1032. # layer_outputs = hidden-states, key-value-states (self-attention weights),
  1033. # (self-attention position bias), (cross-attention weights), (cross-attention position bias)
  1034. position_bias = layer_outputs[1]
  1035. if self.is_decoder and encoder_hidden_states is not None:
  1036. encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2]
  1037. if output_attentions:
  1038. all_attentions = all_attentions + (layer_outputs[2],) # We keep only self-attention weights for now
  1039. if self.is_decoder:
  1040. all_cross_attentions = all_cross_attentions + (layer_outputs[4],)
  1041. hidden_states = self.final_layer_norm(hidden_states)
  1042. hidden_states = self.dropout(hidden_states)
  1043. # Add last layer
  1044. if output_hidden_states:
  1045. all_hidden_states = all_hidden_states + (hidden_states,)
  1046. if not return_dict:
  1047. return tuple(
  1048. v
  1049. for v in [
  1050. hidden_states,
  1051. attention_mask,
  1052. past_key_values,
  1053. all_hidden_states,
  1054. all_attentions,
  1055. all_cross_attentions,
  1056. ]
  1057. if v is not None
  1058. )
  1059. return BaseModelOutputWithAttentionMask(
  1060. last_hidden_state=hidden_states,
  1061. attention_mask=attention_mask,
  1062. past_key_values=past_key_values,
  1063. hidden_states=all_hidden_states,
  1064. attentions=all_attentions,
  1065. cross_attentions=all_cross_attentions,
  1066. )
  1067. @auto_docstring
  1068. class UdopModel(UdopPreTrainedModel):
  1069. _tied_weights_keys = {
  1070. "encoder.embed_tokens.weight": "shared.weight",
  1071. "decoder.embed_tokens.weight": "shared.weight",
  1072. "encoder.embed_patches.proj.weight": "patch_embed.proj.weight",
  1073. "encoder.embed_patches.proj.bias": "patch_embed.proj.bias",
  1074. }
  1075. def __init__(self, config):
  1076. super().__init__(config)
  1077. # text and image embeddings
  1078. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  1079. self.patch_embed = UdopPatchEmbeddings(config)
  1080. encoder_config = deepcopy(config)
  1081. encoder_config.is_decoder = False
  1082. encoder_config.use_cache = False
  1083. self.encoder = UdopStack(encoder_config)
  1084. decoder_config = deepcopy(config)
  1085. decoder_config.is_decoder = True
  1086. decoder_config.num_layers = config.num_decoder_layers
  1087. self.decoder = UdopStack(decoder_config)
  1088. # Initialize weights and apply final processing
  1089. self.post_init()
  1090. def get_input_embeddings(self):
  1091. return self.shared
  1092. def set_input_embeddings(self, new_embeddings):
  1093. self.shared = new_embeddings
  1094. self.encoder.set_input_embeddings(new_embeddings)
  1095. self.decoder.set_input_embeddings(new_embeddings)
  1096. @auto_docstring
  1097. def forward(
  1098. self,
  1099. input_ids: Tensor | None = None,
  1100. attention_mask: Tensor | None = None,
  1101. bbox: dict[str, Any] | None = None,
  1102. pixel_values: Tensor | None = None,
  1103. visual_bbox: dict[str, Any] | None = None,
  1104. decoder_input_ids: Tensor | None = None,
  1105. decoder_attention_mask: Tensor | None = None,
  1106. inputs_embeds: Tensor | None = None,
  1107. encoder_outputs: Tensor | None = None,
  1108. past_key_values: Cache | None = None,
  1109. decoder_inputs_embeds: Tensor | None = None,
  1110. use_cache: bool | None = None,
  1111. output_attentions: bool | None = None,
  1112. output_hidden_states: bool | None = None,
  1113. return_dict: bool | None = None,
  1114. **kwargs,
  1115. ) -> tuple | Seq2SeqModelOutput:
  1116. r"""
  1117. bbox (`torch.LongTensor` of shape `({0}, 4)`, *optional*):
  1118. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  1119. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  1120. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  1121. y1) represents the position of the lower right corner.
  1122. Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS]
  1123. token. See `pixel_values` for `patch_sequence_length`.
  1124. visual_bbox (`torch.LongTensor` of shape `(batch_size, patch_sequence_length, 4)`, *optional*):
  1125. Bounding boxes of each patch in the image. If not provided, bounding boxes are created in the model.
  1126. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1127. Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using
  1128. [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
  1129. [What are decoder input IDs?](../glossary#decoder-input-ids) T5 uses the `pad_token_id` as the starting
  1130. token for `decoder_input_ids` generation. If `past_key_values` is used, optionally only the last
  1131. `decoder_input_ids` have to be input (see `past_key_values`). To know more on how to prepare
  1132. `decoder_input_ids` for pretraining take a look at [T5 Training](./t5#training).
  1133. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1134. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1135. be used by default.
  1136. Example:
  1137. ```python
  1138. >>> from transformers import AutoProcessor, AutoModel
  1139. >>> from datasets import load_dataset
  1140. >>> import torch
  1141. >>> # load model and processor
  1142. >>> # in this case, we already have performed OCR ourselves
  1143. >>> # so we initialize the processor with `apply_ocr=False`
  1144. >>> processor = AutoProcessor.from_pretrained("microsoft/udop-large", apply_ocr=False)
  1145. >>> model = AutoModel.from_pretrained("microsoft/udop-large")
  1146. >>> # load an example image, along with the words and coordinates
  1147. >>> # which were extracted using an OCR engine
  1148. >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
  1149. >>> example = dataset[0]
  1150. >>> image = example["image"]
  1151. >>> words = example["tokens"]
  1152. >>> boxes = example["bboxes"]
  1153. >>> inputs = processor(image, words, boxes=boxes, return_tensors="pt")
  1154. >>> decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])
  1155. >>> # forward pass
  1156. >>> outputs = model(**inputs, decoder_input_ids=decoder_input_ids)
  1157. >>> last_hidden_states = outputs.last_hidden_state
  1158. >>> list(last_hidden_states.shape)
  1159. [1, 1, 1024]
  1160. ```"""
  1161. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1162. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1163. # Encode if needed (training, first prediction pass)
  1164. if encoder_outputs is None:
  1165. encoder_outputs = self.encoder(
  1166. input_ids=input_ids,
  1167. attention_mask=attention_mask,
  1168. bbox=bbox,
  1169. pixel_values=pixel_values,
  1170. visual_bbox=visual_bbox,
  1171. inputs_embeds=inputs_embeds,
  1172. output_attentions=output_attentions,
  1173. output_hidden_states=output_hidden_states,
  1174. return_dict=return_dict,
  1175. )
  1176. hidden_states = encoder_outputs[0]
  1177. encoder_attention_mask = encoder_outputs.attention_mask if return_dict else encoder_outputs[1]
  1178. # Decode
  1179. decoder_outputs = self.decoder(
  1180. input_ids=decoder_input_ids,
  1181. attention_mask=decoder_attention_mask,
  1182. inputs_embeds=decoder_inputs_embeds,
  1183. past_key_values=past_key_values,
  1184. encoder_hidden_states=hidden_states,
  1185. encoder_attention_mask=encoder_attention_mask,
  1186. use_cache=use_cache,
  1187. output_attentions=output_attentions,
  1188. output_hidden_states=output_hidden_states,
  1189. return_dict=return_dict,
  1190. )
  1191. if not return_dict:
  1192. # we filter out the attention mask
  1193. decoder_outputs = tuple(value for idx, value in enumerate(decoder_outputs) if idx != 1)
  1194. encoder_outputs = tuple(value for idx, value in enumerate(encoder_outputs) if idx != 1)
  1195. return decoder_outputs + encoder_outputs
  1196. return Seq2SeqModelOutput(
  1197. last_hidden_state=decoder_outputs.last_hidden_state,
  1198. past_key_values=decoder_outputs.past_key_values,
  1199. decoder_hidden_states=decoder_outputs.hidden_states,
  1200. decoder_attentions=decoder_outputs.attentions,
  1201. cross_attentions=decoder_outputs.cross_attentions,
  1202. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1203. encoder_hidden_states=encoder_outputs.hidden_states,
  1204. encoder_attentions=encoder_outputs.attentions,
  1205. )
  1206. @auto_docstring(
  1207. custom_intro="""
  1208. The UDOP encoder-decoder Transformer with a language modeling head on top, enabling to generate text given document
  1209. images and an optional prompt.
  1210. This class is based on [`T5ForConditionalGeneration`], extended to deal with images and layout (2D) data.
  1211. """
  1212. )
  1213. class UdopForConditionalGeneration(UdopPreTrainedModel, GenerationMixin):
  1214. _tied_weights_keys = {
  1215. "encoder.embed_tokens.weight": "shared.weight",
  1216. "decoder.embed_tokens.weight": "shared.weight",
  1217. "encoder.embed_patches.proj.weight": "patch_embed.proj.weight",
  1218. "encoder.embed_patches.proj.bias": "patch_embed.proj.bias",
  1219. "encoder.relative_bias.biases.0.relative_attention_bias.weight": "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight",
  1220. "decoder.relative_bias.biases.0.relative_attention_bias.weight": "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight",
  1221. "lm_head.weight": "shared.weight",
  1222. }
  1223. def __init__(self, config):
  1224. super().__init__(config)
  1225. # text and image embeddings
  1226. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  1227. self.patch_embed = UdopPatchEmbeddings(config)
  1228. encoder_config = deepcopy(config)
  1229. encoder_config.is_decoder = False
  1230. encoder_config.use_cache = False
  1231. self.encoder = UdopStack(encoder_config)
  1232. decoder_config = deepcopy(config)
  1233. decoder_config.is_decoder = True
  1234. decoder_config.num_layers = config.num_decoder_layers
  1235. self.decoder = UdopStack(decoder_config)
  1236. # The weights of the language modeling head are shared with those of the encoder and decoder
  1237. self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
  1238. # Initialize weights and apply final processing
  1239. self.post_init()
  1240. def get_input_embeddings(self):
  1241. return self.shared
  1242. def set_input_embeddings(self, new_embeddings):
  1243. self.shared = new_embeddings
  1244. self.encoder.set_input_embeddings(new_embeddings)
  1245. self.decoder.set_input_embeddings(new_embeddings)
  1246. @auto_docstring
  1247. def forward(
  1248. self,
  1249. input_ids: Tensor | None = None,
  1250. attention_mask: Tensor | None = None,
  1251. bbox: dict[str, Any] | None = None,
  1252. pixel_values: Tensor | None = None,
  1253. visual_bbox: dict[str, Any] | None = None,
  1254. decoder_input_ids: Tensor | None = None,
  1255. decoder_attention_mask: Tensor | None = None,
  1256. inputs_embeds: Tensor | None = None,
  1257. encoder_outputs: Tensor | None = None,
  1258. past_key_values: Cache | None = None,
  1259. decoder_inputs_embeds: Tensor | None = None,
  1260. use_cache: bool | None = None,
  1261. output_attentions: bool | None = None,
  1262. output_hidden_states: bool | None = None,
  1263. return_dict: bool | None = None,
  1264. labels: Tensor | None = None,
  1265. **kwargs,
  1266. ) -> tuple | Seq2SeqLMOutput:
  1267. r"""
  1268. bbox (`torch.LongTensor` of shape `({0}, 4)`, *optional*):
  1269. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  1270. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  1271. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  1272. y1) represents the position of the lower right corner.
  1273. Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS]
  1274. token. See `pixel_values` for `patch_sequence_length`.
  1275. visual_bbox (`torch.LongTensor` of shape `(batch_size, patch_sequence_length, 4)`, *optional*):
  1276. Bounding boxes of each patch in the image. If not provided, bounding boxes are created in the model.
  1277. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1278. Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using
  1279. [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
  1280. [What are decoder input IDs?](../glossary#decoder-input-ids) T5 uses the `pad_token_id` as the starting
  1281. token for `decoder_input_ids` generation. If `past_key_values` is used, optionally only the last
  1282. `decoder_input_ids` have to be input (see `past_key_values`). To know more on how to prepare
  1283. `decoder_input_ids` for pretraining take a look at [T5 Training](./t5#training).
  1284. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1285. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1286. be used by default.
  1287. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1288. Labels for computing the language modeling loss. Indices should be in `[-100, 0, ..., config.vocab_size -
  1289. 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
  1290. config.vocab_size]`.
  1291. Examples:
  1292. ```python
  1293. >>> from transformers import AutoProcessor, UdopForConditionalGeneration
  1294. >>> from datasets import load_dataset
  1295. >>> # load model and processor
  1296. >>> # in this case, we already have performed OCR ourselves
  1297. >>> # so we initialize the processor with `apply_ocr=False`
  1298. >>> processor = AutoProcessor.from_pretrained("microsoft/udop-large", apply_ocr=False)
  1299. >>> model = UdopForConditionalGeneration.from_pretrained("microsoft/udop-large")
  1300. >>> # load an example image, along with the words and coordinates
  1301. >>> # which were extracted using an OCR engine
  1302. >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
  1303. >>> example = dataset[0]
  1304. >>> image = example["image"]
  1305. >>> words = example["tokens"]
  1306. >>> boxes = example["bboxes"]
  1307. >>> # one can use the various task prefixes (prompts) used during pre-training
  1308. >>> # e.g. the task prefix for DocVQA is "Question answering. "
  1309. >>> question = "Question answering. What is the date on the form?"
  1310. >>> encoding = processor(image, question, text_pair=words, boxes=boxes, return_tensors="pt")
  1311. >>> # autoregressive generation
  1312. >>> predicted_ids = model.generate(**encoding)
  1313. >>> print(processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])
  1314. 9/30/92
  1315. ```"""
  1316. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1317. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1318. if decoder_input_ids is None and labels is not None:
  1319. decoder_input_ids = self._shift_right(labels)
  1320. # Encode if needed (training, first prediction pass)
  1321. if encoder_outputs is None:
  1322. encoder_outputs = self.encoder(
  1323. input_ids=input_ids,
  1324. bbox=bbox,
  1325. visual_bbox=visual_bbox,
  1326. pixel_values=pixel_values,
  1327. attention_mask=attention_mask,
  1328. inputs_embeds=inputs_embeds,
  1329. output_attentions=output_attentions,
  1330. output_hidden_states=output_hidden_states,
  1331. return_dict=return_dict,
  1332. )
  1333. hidden_states = encoder_outputs[0]
  1334. encoder_attention_mask = encoder_outputs.attention_mask if return_dict else encoder_outputs[1]
  1335. # Decode
  1336. decoder_outputs = self.decoder(
  1337. input_ids=decoder_input_ids,
  1338. attention_mask=decoder_attention_mask,
  1339. inputs_embeds=decoder_inputs_embeds,
  1340. past_key_values=past_key_values,
  1341. encoder_hidden_states=hidden_states,
  1342. encoder_attention_mask=encoder_attention_mask,
  1343. use_cache=use_cache,
  1344. output_attentions=output_attentions,
  1345. output_hidden_states=output_hidden_states,
  1346. return_dict=return_dict,
  1347. )
  1348. sequence_output = decoder_outputs[0]
  1349. if self.config.tie_word_embeddings:
  1350. sequence_output = sequence_output * (self.config.d_model**-0.5)
  1351. lm_logits = self.lm_head(sequence_output)
  1352. loss = None
  1353. if labels is not None:
  1354. loss_fct = CrossEntropyLoss(ignore_index=-100)
  1355. loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
  1356. if not return_dict:
  1357. output = (lm_logits,) + decoder_outputs[2:] + (encoder_outputs[0],) + encoder_outputs[2:]
  1358. return ((loss,) + output) if loss is not None else output
  1359. return Seq2SeqLMOutput(
  1360. loss=loss,
  1361. logits=lm_logits,
  1362. past_key_values=decoder_outputs.past_key_values,
  1363. decoder_hidden_states=decoder_outputs.hidden_states,
  1364. decoder_attentions=decoder_outputs.attentions,
  1365. cross_attentions=decoder_outputs.cross_attentions,
  1366. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1367. encoder_hidden_states=encoder_outputs.hidden_states,
  1368. encoder_attentions=encoder_outputs.attentions,
  1369. )
  1370. @auto_docstring
  1371. class UdopEncoderModel(UdopPreTrainedModel):
  1372. _tied_weights_keys = {
  1373. "encoder.embed_tokens.weight": "shared.weight",
  1374. "encoder.embed_patches.proj.weight": "patch_embed.proj.weight",
  1375. "encoder.embed_patches.proj.bias": "patch_embed.proj.bias",
  1376. "encoder.relative_bias.biases.0.relative_attention_bias.weight": "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight",
  1377. }
  1378. def __init__(self, config: UdopConfig):
  1379. super().__init__(config)
  1380. # text and image embeddings
  1381. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  1382. self.patch_embed = UdopPatchEmbeddings(config)
  1383. encoder_config = deepcopy(config)
  1384. encoder_config.is_decoder = False
  1385. encoder_config.use_cache = False
  1386. encoder_config.is_encoder_decoder = False
  1387. self.encoder = UdopStack(encoder_config)
  1388. # Initialize weights and apply final processing
  1389. self.post_init()
  1390. def get_input_embeddings(self):
  1391. return self.shared
  1392. def set_input_embeddings(self, new_embeddings):
  1393. self.shared = new_embeddings
  1394. self.encoder.set_input_embeddings(new_embeddings)
  1395. @auto_docstring
  1396. def forward(
  1397. self,
  1398. input_ids: Tensor | None = None,
  1399. bbox: dict[str, Any] | None = None,
  1400. attention_mask: Tensor | None = None,
  1401. pixel_values: Tensor | None = None,
  1402. visual_bbox: dict[str, Any] | None = None,
  1403. inputs_embeds: Tensor | None = None,
  1404. output_attentions: bool | None = None,
  1405. output_hidden_states: bool | None = None,
  1406. return_dict: bool | None = None,
  1407. **kwargs,
  1408. ) -> tuple[torch.FloatTensor] | BaseModelOutputWithAttentionMask:
  1409. r"""
  1410. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1411. Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
  1412. should be able to pad the inputs on both the right and the left.
  1413. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1414. [`PreTrainedTokenizer.__call__`] for detail.
  1415. To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training).
  1416. bbox (`torch.LongTensor` of shape `({0}, 4)`, *optional*):
  1417. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  1418. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  1419. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  1420. y1) represents the position of the lower right corner.
  1421. Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS]
  1422. token. See `pixel_values` for `patch_sequence_length`.
  1423. visual_bbox (`torch.LongTensor` of shape `(batch_size, patch_sequence_length, 4)`, *optional*):
  1424. Bounding boxes of each patch in the image. If not provided, bounding boxes are created in the model.
  1425. Example:
  1426. ```python
  1427. >>> from transformers import AutoProcessor, UdopEncoderModel
  1428. >>> from huggingface_hub import hf_hub_download
  1429. >>> from datasets import load_dataset
  1430. >>> # load model and processor
  1431. >>> # in this case, we already have performed OCR ourselves
  1432. >>> # so we initialize the processor with `apply_ocr=False`
  1433. >>> processor = AutoProcessor.from_pretrained("microsoft/udop-large", apply_ocr=False)
  1434. >>> model = UdopEncoderModel.from_pretrained("microsoft/udop-large")
  1435. >>> # load an example image, along with the words and coordinates
  1436. >>> # which were extracted using an OCR engine
  1437. >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
  1438. >>> example = dataset[0]
  1439. >>> image = example["image"]
  1440. >>> words = example["tokens"]
  1441. >>> boxes = example["bboxes"]
  1442. >>> encoding = processor(image, words, boxes=boxes, return_tensors="pt")
  1443. >>> outputs = model(**encoding)
  1444. >>> last_hidden_states = outputs.last_hidden_state
  1445. ```"""
  1446. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1447. output_hidden_states = (
  1448. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1449. )
  1450. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1451. encoder_outputs = self.encoder(
  1452. input_ids=input_ids,
  1453. bbox=bbox,
  1454. visual_bbox=visual_bbox,
  1455. pixel_values=pixel_values,
  1456. attention_mask=attention_mask,
  1457. inputs_embeds=inputs_embeds,
  1458. output_attentions=output_attentions,
  1459. output_hidden_states=output_hidden_states,
  1460. return_dict=return_dict,
  1461. )
  1462. return encoder_outputs
  1463. __all__ = ["UdopForConditionalGeneration", "UdopPreTrainedModel", "UdopModel", "UdopEncoderModel"]