| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726 |
- # Copyright 2024 Microsoft Research and HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch UDOP model."""
- import collections
- import logging
- import math
- import random
- from abc import ABC, abstractmethod
- from collections.abc import Sequence
- from copy import deepcopy
- from dataclasses import dataclass
- from typing import Any
- import torch
- from torch import Tensor, nn
- from torch.nn import CrossEntropyLoss
- from transformers import UdopConfig
- from transformers.modeling_outputs import (
- Seq2SeqLMOutput,
- Seq2SeqModelOutput,
- )
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
- from ...generation import GenerationMixin
- from ...masking_utils import create_causal_mask
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_utils import PreTrainedModel
- from ...utils import (
- ModelOutput,
- auto_docstring,
- is_torchdynamo_compiling,
- )
# Module-level logger obtained from the standard-library `logging` package (imported above as `import logging`).
# A stdlib `logging.Logger` only offers the usual methods (`debug`, `info`, `warning`, ...); it has no `warning_once`.
logger: logging.Logger = logging.getLogger(__name__)
# Output container that, besides the usual hidden states / cache / attentions, also returns the attention mask
# used during the forward pass.
@dataclass
@auto_docstring(
    custom_intro="""
    Class for the model's outputs that may also contain a past key/values (to speed up sequential decoding). Includes
    an additional attention mask.
    """
)
class BaseModelOutputWithAttentionMask(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the model. If `past_key_values` is used only
        the last hidden-state of the sequences of shape `(batch_size, 1, hidden_size)` is output.
    attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Attention mask used in the model's forward pass to avoid performing attention on padding token indices.
        Mask values selected in `[0, 1]`:
        - 1 for tokens that are **not masked**,
        - 0 for tokens that are **masked**.
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
        Contains pre-computed hidden-states (key and values in the
        self-attention blocks and optionally if `config.is_encoder_decoder=True` in the cross-attention blocks)
        that can be used (see `past_key_values` input) to speed up sequential decoding.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
        one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
        the model at the output of each layer plus the optional initial embedding outputs.
    attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
        the self-attention heads.
    cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
        used to compute the weighted average in the cross-attention heads.
    """

    # Hidden states produced by the last layer.
    last_hidden_state: torch.FloatTensor | None = None
    # Attention mask used in the forward pass (1 = attend, 0 = masked).
    attention_mask: torch.FloatTensor | None = None
    # Key/value cache for faster sequential decoding; returned when caching is enabled.
    past_key_values: Cache | None = None
    # Embedding output plus one entry per layer; returned when hidden states are requested.
    hidden_states: tuple[torch.FloatTensor] | None = None
    # Per-layer self-attention weights; returned when attentions are requested.
    attentions: tuple[torch.FloatTensor] | None = None
    # Per-layer cross-attention weights; returned when attentions are requested.
    cross_attentions: tuple[torch.FloatTensor] | None = None
def get_visual_bbox(image_size=224, patch_size=16):
    """
    Build the normalized bounding box of every patch in a square patch grid.

    The grid has `image_size // patch_size` cells per side. Each box is `(x0, y0, x1, y1)` with coordinates in
    `[0, 1]`. Boxes are listed row by row (top to bottom, left to right within a row).

    Returns:
        `torch.Tensor` of shape `(grid * grid, 4)` (float32).
    """
    grid = image_size // patch_size
    # Normalized cell boundaries 0, 1/grid, 2/grid, ..., 1.
    edges = torch.arange(0, 1.0 * (grid + 1), 1.0) / grid
    lower, upper = edges[:-1], edges[1:]

    # x varies along columns; y varies along rows (hence the transpose).
    left = lower.repeat(grid, 1)
    top = lower.repeat(grid, 1).transpose(0, 1)
    right = upper.repeat(grid, 1)
    bottom = upper.repeat(grid, 1).transpose(0, 1)

    return torch.stack([left, top, right, bottom], dim=-1).view(-1, 4)
def pad_sequence(seq, target_len, pad_value=0):
    """
    Right-pad (or truncate) `seq` along its first dimension to exactly `target_len` entries.

    Args:
        seq (`torch.Tensor` or sequence):
            Sequence to pad. Non-tensor inputs are converted with `torch.tensor`.
        target_len (`int`):
            Desired size of the first dimension.
        pad_value (scalar or `torch.Tensor`, *optional*, defaults to 0):
            Value of one padding entry. A tensor is used as one padding row, so its shape must equal
            `seq.shape[1:]`. It is cast to the dtype and device of `seq`.

    Returns:
        `torch.Tensor`: `seq` extended with copies of `pad_value` when shorter, then sliced to `target_len`.
    """
    if not isinstance(seq, torch.Tensor):
        seq = torch.tensor(seq)
    missing = target_len - seq.shape[0]
    if missing > 0:
        # `torch.stack` only accepts tensors; `as_tensor` makes plain scalars (such as the default 0) work too.
        pad_row = torch.as_tensor(pad_value).to(seq)
        padding = torch.stack([pad_row] * missing)
        seq = torch.cat([seq, padding], dim=0)
    return seq[:target_len]
def combine_image_text_embeddings(
    image_embeddings,
    inputs_embeds,
    bbox,
    visual_bbox,
    attention_mask=None,
    num_patches=14,
    max_len=0,
    image_size=224,
    patch_size=16,
):
    """
    Combine the image and text embeddings for the input to the encoder/decoder of UDOP.

    Each text token is mapped to the visual patch that contains the centre of its bounding box. That patch's
    embedding is added to the token embedding; tokens whose box coordinates average exactly 0.0 or 1.0 receive no
    visual feature. Patches not hit by any token are appended after the text tokens, together with their bounding
    boxes and, when `attention_mask` is given, an all-ones mask. This visual part is zero-padded or truncated to a
    common length.

    Args:
        image_embeddings (`torch.Tensor`):
            Patch embeddings indexed as `(batch, patch, hidden)`. Presumably `num_patches**2` patches in row-major
            order — TODO confirm against the patch embedder.
        inputs_embeds (`torch.Tensor`):
            Text embeddings of shape `(batch, text_len, hidden)`. The caller's tensor is not modified.
        bbox (`torch.Tensor`):
            Token boxes of shape `(batch, text_len, 4)` as `(x0, y0, x1, y1)`. Assumed to be normalized to `[0, 1]`.
        visual_bbox (`torch.Tensor`, *optional*):
            Patch boxes of shape `(batch, patch, 4)`. Built with `get_visual_bbox` when `None`.
        attention_mask (`torch.Tensor`, *optional*):
            Text attention mask of shape `(batch, text_len)`.
        num_patches (`int`, *optional*, defaults to 14):
            Number of patches per side of the patch grid.
        max_len (`int`, *optional*, defaults to 0):
            If 0, the visual part has `image_embeddings.size(1)` positions. Otherwise it has `max_len - text_len`
            positions, so the combined sequence has `max_len` positions.
        image_size (`int`, *optional*, defaults to 224):
            Forwarded to `get_visual_bbox` when `visual_bbox` is `None`.
        patch_size (`int`, *optional*, defaults to 16):
            Forwarded to `get_visual_bbox` when `visual_bbox` is `None`.

    Returns:
        `tuple`: `(inputs_embeds, bbox, attention_mask)`, with the visual part concatenated along dim 1. `bbox` is
        cast to float64 before concatenation. `attention_mask` stays `None` when it was not given.
    """
    sequence_length = num_patches
    # Row-major index of the patch containing the centre of each token box.
    ocr_points_x = torch.clip(
        torch.floor((bbox[:, :, 0] + bbox[:, :, 2]) / 2.0 * sequence_length).long(), 0, sequence_length - 1
    )
    ocr_points_y = (
        torch.clip(torch.floor((bbox[:, :, 1] + bbox[:, :, 3]) / 2.0 * sequence_length).long(), 0, sequence_length - 1)
        * sequence_length
    )
    ocr_points = ocr_points_x + ocr_points_y
    # make sure bounding boxes are of type float to calculate means
    bbox = bbox.to(torch.float64)
    # Boxes whose coordinates average exactly 0 or 1 do not receive a visual feature.
    target_seg = (bbox.mean(-1) == 0.0) | (bbox.mean(-1) == 1.0)
    repeated_vision_embeds = torch.gather(
        image_embeddings, 1, ocr_points.unsqueeze(-1).repeat(1, 1, image_embeddings.size(-1))
    )
    repeated_vision_embeds[target_seg] = 0.0
    # Out-of-place add: never mutate the caller's tensor (an in-place add also fails on leaves requiring grad).
    # The cast reproduces the dtype that an in-place `+=` would have kept.
    inputs_embeds = (inputs_embeds + repeated_vision_embeds).to(inputs_embeds.dtype)

    # Mark patches hit by at least one token; only the remaining patches are appended as visual tokens.
    # Vectorized advanced indexing replaces a per-element Python loop and also handles an empty text sequence.
    batch_index = torch.arange(ocr_points.size(0), device=ocr_points.device)[:, None].expand_as(ocr_points)
    patch_inds = torch.full_like(image_embeddings[:, :, 0], True).bool()
    patch_inds[batch_index, ocr_points] = False

    input_vision_patches = [patches[keep] for patches, keep in zip(image_embeddings, patch_inds)]

    if visual_bbox is None:
        visual_bbox = get_visual_bbox(image_size=image_size, patch_size=patch_size)
        visual_bbox = visual_bbox.unsqueeze(0).repeat(image_embeddings.size(0), 1, 1)
        visual_bbox = visual_bbox.to(image_embeddings.device)

    visual_bbox = [boxes[keep] for boxes, keep in zip(visual_bbox, patch_inds)]
    if attention_mask is not None:
        visual_attention_mask = [
            torch.ones(item.size(0), dtype=attention_mask.dtype, device=attention_mask.device) for item in visual_bbox
        ]

    if max_len == 0:
        max_len = image_embeddings.size(1)
    else:
        max_len = max_len - inputs_embeds.size(1)
    inputs_vision_patches = torch.stack(
        [pad_sequence(item, max_len, torch.zeros_like(image_embeddings[0, 0])) for item in input_vision_patches]
    )
    visual_bbox = torch.stack([pad_sequence(item, max_len, torch.zeros_like(bbox[0, 0])) for item in visual_bbox])
    if attention_mask is not None:
        visual_attention_mask = torch.stack(
            [pad_sequence(item, max_len, torch.zeros_like(attention_mask[0, 0])) for item in visual_attention_mask]
        )

    inputs_embeds = torch.cat([inputs_embeds, inputs_vision_patches], 1)
    bbox = torch.cat([bbox, visual_bbox], 1)
    if attention_mask is not None:
        attention_mask = torch.cat([attention_mask, visual_attention_mask], 1)
    return inputs_embeds, bbox, attention_mask
class UdopPatchEmbeddings(nn.Module):
    """2D Image to Patch Embeddings"""

    def __init__(self, config):
        super().__init__()

        def as_pair(value):
            # Scalars become (value, value); iterables are kept as given.
            return value if isinstance(value, collections.abc.Iterable) else (value, value)

        image_size = as_pair(config.image_size)
        patch_size = as_pair(config.patch_size)
        rows = image_size[0] // patch_size[0]
        cols = image_size[1] // patch_size[1]

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = config.num_channels
        self.num_patches = cols * rows
        # Non-overlapping patches: kernel size equals stride.
        self.proj = nn.Conv2d(config.num_channels, config.hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values):
        """
        Embed `pixel_values` of shape `(batch, channels, height, width)` into `(batch, num_patches, hidden_size)`.

        Raises:
            ValueError: if the spatial size differs from the configured image size.
        """
        _, _, height, width = pixel_values.shape
        expected_height, expected_width = self.image_size[0], self.image_size[1]
        if not (height == expected_height and width == expected_width):
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
            )
        return self.proj(pixel_values).flatten(2).transpose(1, 2)
@auto_docstring
class UdopPreTrainedModel(PreTrainedModel):
    config: UdopConfig
    base_model_prefix = "transformer"
    input_modalities = ("image", "text")
    supports_gradient_checkpointing = True
    _can_compile_fullgraph = False
    # Feed-forward output projections (`wo`) are kept in float32; the dense blocks cast their input to `wo`'s dtype.
    _keep_in_fp32_modules = ["wo"]

    @torch.no_grad()
    def _init_weights(self, module):
        """
        Initialize the weights of `module` according to its type.

        Every standard deviation is scaled by `config.initializer_factor`.
        """
        factor = self.config.initializer_factor  # Used for testing weights initialization
        d_model = self.config.d_model

        def init_linear(linear, fan_in):
            # Normal init with std = factor / sqrt(fan_in); zero the bias if the layer has one.
            init.normal_(linear.weight, mean=0.0, std=factor * (fan_in**-0.5))
            if getattr(linear, "bias", None) is not None:
                init.zeros_(linear.bias)

        if isinstance(module, UdopLayerNorm):
            init.constant_(module.weight, factor * 1.0)
        elif isinstance(module, nn.Embedding):
            init.normal_(module.weight, mean=0.0, std=factor)
            # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag
            if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
                init.zeros_(module.weight[module.padding_idx])
        elif isinstance(module, nn.Conv2d):
            init.trunc_normal_(module.weight, mean=0.0, std=factor)
            if module.bias is not None:
                init.zeros_(module.bias)
        elif isinstance(module, RelativePositionBiasBase):
            init.normal_(module.relative_attention_bias.weight, mean=0.0, std=factor * (d_model**-0.5))
        elif isinstance(module, UdopModel):
            init.normal_(module.shared.weight, mean=0.0, std=factor * 1.0)
        elif isinstance(module, UdopForConditionalGeneration):
            if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
                init.normal_(module.lm_head.weight, mean=0.0, std=factor * 1.0)
        elif isinstance(module, UdopDenseActDense):
            init_linear(module.wi, d_model)
            init_linear(module.wo, self.config.d_ff)
        elif isinstance(module, UdopDenseGatedActDense):
            init_linear(module.wi_0, d_model)
            init_linear(module.wi_1, d_model)
            init_linear(module.wo, self.config.d_ff)
        elif isinstance(module, UdopAttention):
            key_value_proj_dim = self.config.d_kv
            n_heads = self.config.num_heads
            init.normal_(module.q.weight, mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
            init.normal_(module.k.weight, mean=0.0, std=factor * (d_model**-0.5))
            init.normal_(module.v.weight, mean=0.0, std=factor * (d_model**-0.5))
            init.normal_(module.o.weight, mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
            if module.has_relative_attention_bias:
                init.normal_(module.relative_attention_bias.weight, mean=0.0, std=factor * (d_model**-0.5))

    # Adapted from transformers.models.prophetnet.modeling_prophetnet.ProphetNetPreTrainedModel._shift_right,
    # raising ValueError (as T5's `_shift_right` does) instead of using `assert`, which is stripped under `python -O`.
    def _shift_right(self, input_ids):
        """
        Shift `input_ids` one position to the right to build decoder inputs.

        Position 0 is filled with `config.decoder_start_token_id`, and `-100` values (as found in labels) are
        replaced by `config.pad_token_id`.

        Raises:
            ValueError: if `decoder_start_token_id` or `pad_token_id` is not configured, or if the shifted ids still
                contain negative values.
        """
        decoder_start_token_id = self.config.decoder_start_token_id
        pad_token_id = self.config.pad_token_id
        if decoder_start_token_id is None:
            raise ValueError(
                "self.model.config.decoder_start_token_id has to be defined. In Udop it is usually set to the"
                " pad_token_id. See Udop docs for more information"
            )
        if pad_token_id is None:
            raise ValueError("self.model.config.pad_token_id has to be defined.")

        # shift inputs to the right
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
        shifted_input_ids[..., 0] = decoder_start_token_id

        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
        if not torch.all(shifted_input_ids >= 0).item():
            raise ValueError("Verify that `shifted_input_ids` has only non-negative values")
        return shifted_input_ids
# Adapted from transformers.models.t5.modeling_t5.T5LayerNorm with T5->Udop
class UdopLayerNorm(nn.Module):
    """Scale-only RMS normalization over the last dimension: no mean subtraction and no bias."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        Construct a layernorm module in the Udop style. No bias and no subtraction of mean.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Root Mean Square Layer Normalization (https://huggingface.co/papers/1910.07467): the statistic is the mean
        # of squares, with no mean subtraction. It is accumulated in fp32 so half-precision inputs stay accurate.
        fp32_states = hidden_states.to(torch.float32)
        inv_rms = torch.rsqrt(fp32_states.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
        normalized = hidden_states * inv_rms

        # When the weight is half precision, bring the activations back down before scaling.
        weight_dtype = self.weight.dtype
        if weight_dtype == torch.float16 or weight_dtype == torch.bfloat16:
            normalized = normalized.to(weight_dtype)
        return self.weight * normalized
# Adapted from transformers.models.t5.modeling_t5.T5DenseActDense with T5->Udop
class UdopDenseActDense(nn.Module):
    """Bias-free feed-forward block computing `wo(dropout(act(wi(x))))`."""

    def __init__(self, config: UdopConfig):
        super().__init__()
        self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.act = ACT2FN[config.dense_act_fn]

    def forward(self, hidden_states):
        projected = self.dropout(self.act(self.wi(hidden_states)))

        # `wo` may have a different dtype than the activations (e.g. kept in fp32); cast the input to match it,
        # except when the weight is int8-quantized.
        out_weight = self.wo.weight
        needs_cast = (
            isinstance(out_weight, torch.Tensor)
            and projected.dtype != out_weight.dtype
            and out_weight.dtype != torch.int8
        )
        if needs_cast:
            projected = projected.to(out_weight.dtype)
        return self.wo(projected)
# Adapted from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->Udop
class UdopDenseGatedActDense(nn.Module):
    """Bias-free gated feed-forward block computing `wo(dropout(act(wi_0(x)) * wi_1(x)))`."""

    def __init__(self, config: UdopConfig):
        super().__init__()
        self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.act = ACT2FN[config.dense_act_fn]

    def forward(self, hidden_states):
        gate = self.act(self.wi_0(hidden_states))
        value = self.wi_1(hidden_states)
        mixed = self.dropout(gate * value)

        # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
        # See https://github.com/huggingface/transformers/issues/20287
        # The input is cast to `wo`'s dtype unless that weight is int8 (in case users force
        # `_keep_in_fp32_modules` to be `None`).
        out_weight = self.wo.weight
        needs_cast = (
            isinstance(out_weight, torch.Tensor)
            and mixed.dtype != out_weight.dtype
            and out_weight.dtype != torch.int8
        )
        if needs_cast:
            mixed = mixed.to(out_weight.dtype)
        return self.wo(mixed)
# Adapted from transformers.models.t5.modeling_t5.T5LayerFF with T5->Udop
class UdopLayerFF(nn.Module):
    """Pre-norm feed-forward sub-layer with a residual connection."""

    def __init__(self, config: UdopConfig):
        super().__init__()
        # The gated variant is selected by `config.is_gated_act`; the attribute name is kept for checkpoint keys.
        dense_cls = UdopDenseGatedActDense if config.is_gated_act else UdopDenseActDense
        self.DenseReluDense = dense_cls(config)
        self.layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        update = self.DenseReluDense(self.layer_norm(hidden_states))
        return hidden_states + self.dropout(update)
# Adapted from transformers.models.t5.modeling_t5.T5Attention with T5->Udop
class UdopAttention(nn.Module):
    """
    Multi-head attention with an optional bucketed relative position bias.

    Works as self-attention when `key_value_states` is `None`, or as cross-attention over `key_value_states`
    otherwise. Scores are the raw `q @ k^T` products plus the position bias (and mask); no `1/sqrt(d_kv)` scaling
    is applied here.
    """

    def __init__(
        self,
        config: UdopConfig,
        has_relative_attention_bias=False,
        layer_idx: int | None = None,
    ):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.has_relative_attention_bias = has_relative_attention_bias
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.relative_attention_max_distance = config.relative_attention_max_distance
        self.d_model = config.d_model
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_heads
        self.dropout = config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim
        self.layer_idx = layer_idx
        if layer_idx is None and self.is_decoder:
            # In this file `logger` comes from stdlib `logging.getLogger`, whose Logger has no `warning_once`.
            # Calling it directly raised AttributeError here; fall back to `warning` when it is unavailable.
            warn = getattr(logger, "warning_once", logger.warning)
            warn(
                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
                "will lead to errors during the forward call, if caching is used. Please make sure to provide a "
                "`layer_idx` when creating this class."
            )

        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
        self.gradient_checkpointing = False

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    def compute_bias(self, query_length, key_length, device=None, past_seen_tokens=0):
        """
        Compute binned relative position bias.

        Query positions are offset by `past_seen_tokens`, so cached decoding sees the correct distances.

        Returns:
            `torch.Tensor` of shape `(1, num_heads, query_length, key_length)`.
        """
        if device is None:
            device = self.relative_attention_bias.weight.device
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + past_seen_tokens
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values

    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_values=None,
        output_attentions=False,
        **kwargs,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).

        Returns:
            `tuple`: `(attn_output, position_bias)`, plus `attn_weights` when `output_attentions` is True. When the
            bias is computed here, the returned `position_bias` already includes `mask`, so later layers can reuse it.
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.key_value_proj_dim)

        past_seen_tokens = past_key_values.get_seq_length(self.layer_idx) if past_key_values is not None else 0
        # We clone here for StaticCache, as we get the value before updating it, but use it after and it's the same ref
        past_seen_tokens = past_seen_tokens.clone() if isinstance(past_seen_tokens, torch.Tensor) else past_seen_tokens

        # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
        is_cross_attention = key_value_states is not None

        query_states = self.q(hidden_states).view(hidden_shape).transpose(1, 2)

        # Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache`
        is_updated = False
        if isinstance(past_key_values, EncoderDecoderCache):
            is_updated = past_key_values.is_updated.get(self.layer_idx)
            if is_cross_attention:
                # after the first generated id, we can subsequently re-use all key/value_states from cache
                curr_past_key_values = past_key_values.cross_attention_cache
            else:
                curr_past_key_values = past_key_values.self_attention_cache
        else:
            curr_past_key_values = past_key_values

        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_values is not None and is_updated:
            # reuse k,v, cross_attentions
            key_states = curr_past_key_values.layers[self.layer_idx].keys
            value_states = curr_past_key_values.layers[self.layer_idx].values
        else:
            kv_shape = (*current_states.shape[:-1], -1, self.key_value_proj_dim)
            key_states = self.k(current_states).view(kv_shape).transpose(1, 2)
            value_states = self.v(current_states).view(kv_shape).transpose(1, 2)

            if past_key_values is not None:
                key_states, value_states = curr_past_key_values.update(key_states, value_states, self.layer_idx)
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
                    past_key_values.is_updated[self.layer_idx] = True

        # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
        scores = torch.matmul(query_states, key_states.transpose(3, 2))

        if position_bias is None:
            key_length = key_states.shape[-2]
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, query_states.shape[1], input_shape[1], key_length), device=scores.device, dtype=scores.dtype
                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
                position_bias = self.compute_bias(
                    input_shape[1], key_length, device=scores.device, past_seen_tokens=past_seen_tokens
                )
            # Fold the mask into the bias once, so layers that receive this bias get the mask for free.
            if mask is not None:
                causal_mask = mask[:, :, :, : key_states.shape[-2]]
                position_bias = position_bias + causal_mask

        scores += position_bias
        # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(*input_shape, -1)
        attn_output = self.o(attn_output)

        outputs = (attn_output, position_bias)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs
- # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->Udop
class UdopLayerSelfAttention(nn.Module):
    """
    Pre-norm self-attention sub-layer with a residual connection.

    The input is layer-normalized, fed to `UdopAttention`, and the dropped-out attention output is added back onto the
    un-normalized input.
    """

    def __init__(self, config, has_relative_attention_bias=False, layer_idx: int | None = None):
        super().__init__()
        self.SelfAttention = UdopAttention(
            config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
        )
        self.layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        past_key_values=None,
        use_cache=False,
        output_attentions=False,
        **kwargs,
    ):
        """
        Returns a tuple whose first element is the residual-updated hidden states, followed by every extra output of
        `UdopAttention` (position bias, and attention weights when `output_attentions` is set).
        """
        normed = self.layer_norm(hidden_states)
        attn_out, *extras = self.SelfAttention(
            normed,
            mask=attention_mask,
            position_bias=position_bias,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        residual_sum = hidden_states + self.dropout(attn_out)
        return (residual_sum, *extras)
- # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->Udop
class UdopLayerCrossAttention(nn.Module):
    """
    Pre-norm encoder-decoder (cross) attention sub-layer with a residual connection.

    Queries come from the layer-normalized `hidden_states`; keys/values come from `key_value_states`.
    """

    def __init__(self, config, layer_idx: int | None = None):
        super().__init__()
        self.EncDecAttention = UdopAttention(config, has_relative_attention_bias=False, layer_idx=layer_idx)
        self.layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        key_value_states,
        attention_mask=None,
        position_bias=None,
        past_key_values=None,
        output_attentions=False,
        **kwargs,
    ):
        """
        Returns a tuple whose first element is the residual-updated hidden states, followed by every extra output of
        `UdopAttention` (position bias, and attention weights when `output_attentions` is set).
        """
        normed = self.layer_norm(hidden_states)
        attn_out, *extras = self.EncDecAttention(
            normed,
            mask=attention_mask,
            key_value_states=key_value_states,
            position_bias=position_bias,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
        )
        return (hidden_states + self.dropout(attn_out), *extras)
- # Copied from transformers.models.t5.modeling_t5.T5Block with T5->Udop
class UdopBlock(GradientCheckpointingLayer):
    """
    One UDOP transformer block: self-attention, then cross-attention (decoder only), then feed-forward.

    `self.layer` holds the sub-layers in that order; the feed-forward sub-layer is always the last entry.
    """

    def __init__(self, config, has_relative_attention_bias=False, layer_idx: int | None = None):
        super().__init__()
        self.is_decoder = config.is_decoder
        sub_layers = [
            UdopLayerSelfAttention(
                config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
            )
        ]
        if self.is_decoder:
            sub_layers.append(UdopLayerCrossAttention(config, layer_idx=layer_idx))
        sub_layers.append(UdopLayerFF(config))
        self.layer = nn.ModuleList(sub_layers)

    @staticmethod
    def _clamp_fp16(hidden_states):
        """Clamp fp16 activations into the finite range (1000 tighter when an inf is present); other dtypes pass through."""
        if hidden_states.dtype != torch.float16:
            return hidden_states
        fp16_max = torch.finfo(hidden_states.dtype).max
        clamp_value = torch.where(torch.isinf(hidden_states).any(), fp16_max - 1000, fp16_max)
        return torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        past_key_values=None,
        use_cache=False,
        output_attentions=False,
        return_dict=True,
        **kwargs,
    ):
        """
        Returns `(hidden_states, self-attn position bias, [self-attn weights], [cross-attn position bias],
        [cross-attn weights])`; bracketed entries appear only when attention weights are requested / cross-attention runs.
        """
        self_hidden, *extra_outputs = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        # clamp inf values to enable fp16 training
        hidden_states = self._clamp_fp16(self_hidden)
        extra_outputs = tuple(extra_outputs)

        if self.is_decoder and encoder_hidden_states is not None:
            cross_hidden, *cross_extras = self.layer[1](
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
            )
            hidden_states = self._clamp_fp16(cross_hidden)
            # Keep cross-attention outputs and relative position weights
            extra_outputs += tuple(cross_extras)

        # Apply the feed-forward sub-layer (always last in `self.layer`)
        hidden_states = self._clamp_fp16(self.layer[-1](hidden_states))
        return (hidden_states, *extra_outputs)
class UdopCellEmbeddings(nn.Module):
    """
    Embed 2D layout boxes by summing learned embeddings of each box coordinate.

    Args:
        max_2d_position_embeddings (`int`, defaults to 501): number of discrete positions per axis.
        hidden_size (`int`, defaults to 1024): embedding dimension.
    """

    def __init__(self, max_2d_position_embeddings=501, hidden_size=1024):
        super().__init__()
        self.max_2d_position_embeddings = max_2d_position_embeddings
        self.x_position_embeddings = nn.Embedding(max_2d_position_embeddings, hidden_size)
        self.y_position_embeddings = nn.Embedding(max_2d_position_embeddings, hidden_size)

    def forward(self, bbox):
        """
        `bbox` is a 3-D tensor whose last axis holds (x0, y0, x1, y1). Coordinates are clipped to [0, 1], scaled to
        `max_2d_position_embeddings - 1` and truncated to integer indices before lookup.
        """
        last_index = self.max_2d_position_embeddings - 1
        grid = (torch.clip(bbox, 0.0, 1.0) * last_index).long()
        x_embed = self.x_position_embeddings
        y_embed = self.y_position_embeddings
        total = x_embed(grid[:, :, 0]) + y_embed(grid[:, :, 1])
        total = total + x_embed(grid[:, :, 2])
        total = total + y_embed(grid[:, :, 3])
        return total
# Reuse the attention layer's relative-position bucketing function for the layout biases defined below.
# Accessing a protected member is considered a lesser evil than copy-pasting the whole function.
get_relative_position_bucket = UdopAttention._relative_position_bucket
# (low, high) bounds of the random scalar that relative distances are multiplied by when a relative bias module
# has `augmentation=True` and is in training mode.
AUGMENTATION_RANGE = (0.80, 1.25)
- class RelativePositionBiasBase(nn.Module, ABC):
- """
- Base class of relative biases.
- Args:
- num_heads (`int`):
- Number of attention heads in the model, it will create embeddings of size `num_heads`, which will be added to the scores of each token pair.
- relative_attention_num_buckets (`int`, *optional*, defaults to 32):
- Pair token metric (distance in the sequence, distance in pixels etc.) will be bucketed, parameter is defining number of such
- buckets.
- bidirectional (`bool`, *optional*, defaults to `True`):
- Whether the distance should be bidirectional for a pair of tokens. If `False`, then distance(tok1, tok2) == distance(tok2, tok1).
- scaling_factor (`int`, *optional*, defaults to 1):
- Defining factor which will be used to scale relative distance.
- max_distance (`int`, *optional*, defaults to 128):
- All distances above this value will end up in the one/same bucket.
- augmentation (`bool`, *optional*, defaults to `False`):
- Whether to multiply relative distances by a random scalar.
- expand (`bool`, *optional*, defaults to `False`):
- Whether to expand an existing pretrained model with subsequent additions of prefix_bucket.
- """
- def __init__(
- self,
- num_heads=None,
- relative_attention_num_buckets=32,
- bidirectional=True,
- scaling_factor=1,
- max_distance=128,
- level="tokens",
- augmentation=False,
- prefix_bucket=False,
- expand=False,
- ):
- super().__init__()
- self.prefix_bucket = prefix_bucket
- self.augmentation = augmentation
- self.level = level
- self.max_distance = max_distance
- self.scaling_factor = scaling_factor
- self.bidirectional = bidirectional
- self.num_heads = num_heads
- self.expand = expand
- self.relative_attention_num_buckets = relative_attention_num_buckets
- extra_head = 2 if prefix_bucket and not self.expand else 0
- self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets + extra_head, self.num_heads)
- @abstractmethod
- def prepare_input(
- self,
- attention_mask: Tensor | None = None,
- bbox: dict[str, Any] | None = None,
- ) -> Tensor:
- pass
- def get_bucket(self, attention_mask: Tensor | None = None, bbox: dict[str, Any] | None = None) -> Tensor:
- relative_position = self.prepare_input(attention_mask, bbox)
- rp_bucket: Tensor = get_relative_position_bucket(
- relative_position,
- bidirectional=self.bidirectional,
- num_buckets=self.relative_attention_num_buckets,
- max_distance=self.max_distance,
- )
- return rp_bucket
- def get_relative_position(self, positions):
- context_position = positions[:, :, None]
- memory_position = positions[:, None, :]
- relative_position = memory_position - context_position
- if self.augmentation and self.training:
- relative_position *= random.uniform(*AUGMENTATION_RANGE)
- relative_position *= self.scaling_factor
- return relative_position.to(torch.long)
- def forward(self, attention_mask: Tensor | None = None, bbox: dict[str, Any] | None = None) -> Tensor:
- # re-using pretrained model with subsequent addition of prefix_bucket
- if self.expand and self.prefix_bucket:
- new_bias = nn.Embedding(self.relative_attention_num_buckets + 2, self.num_heads)
- new_bias.weight.data[: self.relative_attention_num_buckets] = self.relative_attention_bias.weight.data
- new_bias.weight.data[self.relative_attention_num_buckets :] = 0.1
- self.relative_attention_bias = new_bias
- self.expand = False
- rp_bucket = self.get_bucket(attention_mask, bbox)
- if self.prefix_bucket:
- if rp_bucket.size(0) == 1 and attention_mask.size(0) > 1:
- rp_bucket = rp_bucket.repeat(attention_mask.size(0), 1, 1)
- # based on assumption that prefix bboxes are negative
- is_prefix = bbox[:, :, 1] < 0
- num_prefix = is_prefix.sum(-1)
- for idx, num_prefix_row in enumerate(num_prefix.cpu().numpy()):
- rp_bucket[idx, :num_prefix_row, num_prefix_row:] = self.relative_attention_num_buckets
- rp_bucket[idx, num_prefix_row:, :num_prefix_row] = self.relative_attention_num_buckets + 1
- values: Tensor = self.relative_attention_bias(rp_bucket)
- if values.dim() != 4:
- raise ValueError("Wrong dimension of values tensor")
- values = values.permute([0, 3, 1, 2])
- return values
class RelativePositionBias1D(RelativePositionBiasBase):
    def __init__(self, scaling_factor=1, max_distance=128, **kwargs):
        """
        Reimplementation of T5 relative position bias: the distance between two tokens is their offset in the
        sequence. Accepts the same parameters as the base class.
        """
        super().__init__(scaling_factor=scaling_factor, max_distance=max_distance, **kwargs)

    def prepare_input(self, attention_mask: Tensor | None = None, bbox: dict[str, Any] | None = None) -> Tensor:
        """Relative sequence offsets built from `arange(attention_mask.size(1))`; `bbox` is ignored."""
        # Sequence indices have no spatial scale, so only a scaling factor of exactly 1 is accepted.
        if self.scaling_factor != 1:
            raise ValueError("No need to scale 1d features")
        seq_len = attention_mask.size(1)
        token_index = torch.arange(seq_len, dtype=torch.long, device=attention_mask.device)
        return self.get_relative_position(token_index.unsqueeze(0))
class RelativePositionBiasHorizontal(RelativePositionBiasBase):
    def __init__(self, scaling_factor=100, max_distance=100, **kwargs):
        """
        Represents in the bucket embeddings the horizontal distance between two tokens, measured between the
        horizontal centers of their bounding boxes. Parameters are the same as in the base class.
        """
        super().__init__(scaling_factor=scaling_factor, max_distance=max_distance, **kwargs)

    def prepare_input(self, attention_mask: Tensor | None = None, bbox: dict[str, Any] | None = None) -> Tensor:
        """
        Return relative horizontal distances between bbox centers.

        Raises:
            ValueError: if `scaling_factor` is not greater than 1, or if `bbox` is missing.
        """
        if not self.scaling_factor > 1.0:
            raise ValueError("Need to scale the values of bboxes, as there are in small (0,1) range")
        if bbox is None:
            raise ValueError("Bbox is required for horizontal relative position bias")
        # x position of the horizontal center of each bbox: mean of x0 (index 0) and x1 (index 2)
        horizontal_position: Tensor = bbox[:, :, [0, 2]].mean(dim=-1)

        return self.get_relative_position(horizontal_position)
class RelativePositionBiasVertical(RelativePositionBiasBase):
    def __init__(self, scaling_factor=100, max_distance=100, **kwargs):
        """
        Represents in the bucket embeddings the vertical distance between two tokens, measured between the vertical
        centers of their bounding boxes. Parameters are the same as in the base class.
        """
        super().__init__(scaling_factor=scaling_factor, max_distance=max_distance, **kwargs)

    def prepare_input(self, attention_mask: Tensor | None = None, bbox: dict[str, Any] | None = None) -> Tensor:
        """Relative vertical distances between bbox centers; requires `bbox` and a scaling factor above 1."""
        needs_scaling = not self.scaling_factor > 1.0
        if needs_scaling:
            raise ValueError("Need to scale the values of bboxes, as there are in small (0,1) range")
        if bbox is None:
            raise ValueError("Bbox is required for vertical relative position bias")
        # y0 (index 1) and y1 (index 3) averaged -> vertical middle of each bbox
        y_coords = bbox[:, :, [1, 3]]
        vertical_center: Tensor = y_coords.mean(dim=-1)
        return self.get_relative_position(vertical_center)
class RelativePositionBiasAggregated(nn.Module):
    def __init__(self, modules: Sequence[RelativePositionBiasBase]):
        """
        Module that adds up the outputs of several relative bias modules.

        Args:
            modules (Sequence[RelativePositionBiasBase]):
                The bias modules whose outputs are summed.
        """
        super().__init__()
        self.biases = nn.ModuleList(modules)

    def forward(self, attention_mask: Tensor | None = None, bbox: dict[str, Any] | None = None) -> float | Tensor:
        """Sum of every bias module's output; the Python float `0.0` when there are no modules."""
        return sum((bias(attention_mask, bbox) for bias in self.biases), 0.0)  # type: ignore
# Registry used by `create_relative_bias`: maps the `"type"` value of each `config.relative_bias_args` entry to the
# bias class that implements it.
BIAS_CLASSES = {
    "1d": RelativePositionBias1D,
    "horizontal": RelativePositionBiasHorizontal,
    "vertical": RelativePositionBiasVertical,
}
def create_relative_bias(config: UdopConfig) -> Sequence[RelativePositionBiasBase]:
    """
    Create the relative position bias modules described by `config.relative_bias_args`.

    Each entry of `config.relative_bias_args` is a dict whose `"type"` key selects a class from `BIAS_CLASSES`; the
    remaining keys are passed to that class as keyword arguments, with `num_heads` filled in from the model config
    when absent.

    Args:
        config (`UdopConfig`): Model configuration.

    Returns:
        `list[RelativePositionBiasBase]`: The created bias modules; empty when the config defines no bias arguments.

    Raises:
        ValueError: If an entry's `"type"` is missing or unknown, or its `num_heads` differs from the model's.
    """
    bias_args = getattr(config, "relative_bias_args", None) or []
    if not bias_args:
        return []

    # The model-wide head count does not depend on the individual entry, so look it up once.
    model_num_heads = config.num_heads if hasattr(config, "num_heads") else config.num_attention_heads

    bias_list = []
    for bias_kwargs_org in bias_args:
        # Work on a copy so popping "type" / filling "num_heads" never mutates the config.
        bias_kwargs = deepcopy(bias_kwargs_org)
        bias_type = bias_kwargs.pop("type", None)
        if bias_type not in BIAS_CLASSES:
            raise ValueError(
                f"Unknown relative bias type {bias_type!r}; expected one of {sorted(BIAS_CLASSES)}"
            )
        if "num_heads" in bias_kwargs:
            if bias_kwargs["num_heads"] != model_num_heads:
                raise ValueError("Number of heads must match num of heads in the model")
        else:
            bias_kwargs["num_heads"] = model_num_heads
        bias_list.append(BIAS_CLASSES[bias_type](**bias_kwargs))  # type: ignore

    return bias_list
class UdopStack(UdopPreTrainedModel):
    """
    This class is based on `T5Stack`, but modified to take into account the image modality as well as 2D position
    embeddings.
    """

    def __init__(self, config):
        super().__init__(config)
        # text and image embeddings
        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        self.embed_patches = UdopPatchEmbeddings(config)

        self.is_decoder = config.is_decoder
        self.num_layers = config.num_layers

        # Only the first block is built with `has_relative_attention_bias=True`.
        self.block = nn.ModuleList(
            [UdopBlock(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) for i in range(self.num_layers)]
        )

        self.final_layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

        if not self.is_decoder:
            # 2D layout (bbox) embeddings are only added on the encoder side.
            self.cell_2d_embedding = UdopCellEmbeddings(config.max_2d_position_embeddings, config.hidden_size)

        # get weights from encoder position bias
        self.relative_bias = self._get_relative_bias(config)

        self.post_init()

    @staticmethod
    def _get_relative_bias(config: UdopConfig) -> RelativePositionBiasAggregated:
        """Build the aggregated layout relative bias described by the config."""
        relative_bias_list = create_relative_bias(config)
        return RelativePositionBiasAggregated(relative_bias_list)

    def get_output_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, new_embeddings):
        self.embed_tokens = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        bbox=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        pixel_values=None,
        visual_bbox=None,
        image_embeddings=None,
        position_bias=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ) -> tuple | BaseModelOutputWithAttentionMask:
        """
        Run the stack over text (and, for the encoder, image + layout) inputs.

        The returned attention mask is the one actually used, which may have been extended by
        `combine_image_text_embeddings` when image inputs are given.

        Raises:
            ValueError: if both or neither of `input_ids` / `inputs_embeds` are given, if token embeddings are
                missing, or if `use_cache=True` is requested on a non-decoder stack.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # input embeddings processing
        if input_ids is not None and inputs_embeds is not None:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You cannot specify both {err_msg_prefix}inputs and {err_msg_prefix}inputs_embeds at the same time"
            )
        elif input_ids is not None and torch.numel(input_ids) > 0:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is None and input_ids is not None and torch.numel(input_ids) == 0:
            input_ids = torch.full((4, 1024), self.config.pad_token_id, device=input_ids.device, dtype=input_ids.dtype)
            attention_mask = torch.zeros((4, 1024), device=input_ids.device, dtype=input_ids.dtype)
            bbox = torch.zeros((4, 1024, 4), device=input_ids.device, dtype=input_ids.dtype)
            input_shape = input_ids.size()
            position_bias = torch.zeros_like(self.get_extended_attention_mask(attention_mask, input_shape))
            # encoder_attention_mask = attention_mask
            logger.warning("Empty batch")
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(f"You have to specify either {err_msg_prefix}inputs or {err_msg_prefix}inputs_embeds")

        if inputs_embeds is None:
            if self.embed_tokens is None:
                raise ValueError("You have to initialize the model with valid token embeddings")
            inputs_embeds = self.embed_tokens(input_ids)

        if pixel_values is not None:
            image_embeddings = self.embed_patches(pixel_values)

        if image_embeddings is not None:
            # combine visual and OCR text embeddings
            num_patches = self.config.image_size // self.config.patch_size
            inputs_embeds, bbox, attention_mask = combine_image_text_embeddings(
                image_embeddings,
                inputs_embeds,
                bbox,
                visual_bbox,
                attention_mask,
                num_patches,
                0,
                self.config.image_size,
                self.config.patch_size,
            )
            input_shape = inputs_embeds.size()[:-1]

        if not self.is_decoder and bbox is not None:
            # Out-of-place: `inputs_embeds` may be a caller-supplied tensor (possibly a leaf requiring grad),
            # which an in-place `+=` would mutate or reject.
            inputs_embeds = inputs_embeds + self.cell_2d_embedding(bbox)

        batch_size, seq_length = input_shape

        if use_cache is True and not self.is_decoder:
            raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")

        if self.is_decoder:
            if use_cache and past_key_values is None:
                if self.config.is_encoder_decoder:
                    past_key_values = EncoderDecoderCache(
                        DynamicCache(config=self.config), DynamicCache(config=self.config)
                    )
                else:
                    past_key_values = DynamicCache(config=self.config)
        else:
            # do not pass cache object down the line for encoder stack
            # it messes indexing later in decoder-stack because cache object is modified in-place
            past_key_values = None

        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
        if attention_mask is None and not is_torchdynamo_compiling():
            # required mask seq length can be calculated via length of past cache
            mask_seq_length = past_key_values_length + seq_length
            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)

        if self.config.is_decoder:
            causal_mask = create_causal_mask(
                config=self.config,
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
            )
        else:
            # Encoder: turn the 0/1 padding mask into an additive mask (0 for keep, dtype-min for masked).
            causal_mask = attention_mask[:, None, None, :]
            causal_mask = causal_mask.to(dtype=inputs_embeds.dtype)
            causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min

        if self.is_decoder and encoder_attention_mask is not None:
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and self.is_decoder) else None

        if self.is_decoder:  # modified lines
            position_bias = None
        else:
            # Encoder: layout relative bias with the padding mask already folded in.
            position_bias = self.relative_bias(attention_mask=attention_mask, bbox=bbox)
            position_bias = position_bias + causal_mask
        encoder_decoder_position_bias = None

        hidden_states = inputs_embeds

        hidden_states = self.dropout(hidden_states)

        for layer_module in self.block:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                hidden_states,
                causal_mask,
                position_bias,
                encoder_hidden_states,
                encoder_extended_attention_mask,
                encoder_decoder_position_bias,  # as a positional argument for gradient checkpointing
                past_key_values=past_key_values,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )

            hidden_states = layer_outputs[0]

            # We share the position biases between the layers - the first layer store them
            # layer_outputs = hidden-states, (self-attention position bias), (self-attention weights),
            # (cross-attention position bias), (cross-attention weights); weights only with output_attentions
            position_bias = layer_outputs[1]
            if self.is_decoder and encoder_hidden_states is not None:
                encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[2],)  # We keep only self-attention weights for now
                if self.is_decoder:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[4],)

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    attention_mask,
                    past_key_values,
                    all_hidden_states,
                    all_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )

        return BaseModelOutputWithAttentionMask(
            last_hidden_state=hidden_states,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )
@auto_docstring
class UdopModel(UdopPreTrainedModel):
    _tied_weights_keys = {
        "encoder.embed_tokens.weight": "shared.weight",
        "decoder.embed_tokens.weight": "shared.weight",
        "encoder.embed_patches.proj.weight": "patch_embed.proj.weight",
        "encoder.embed_patches.proj.bias": "patch_embed.proj.bias",
    }

    def __init__(self, config):
        super().__init__(config)

        # text and image embeddings
        self.shared = nn.Embedding(config.vocab_size, config.d_model)
        self.patch_embed = UdopPatchEmbeddings(config)

        # Encoder and decoder get independent config copies so the is_decoder/use_cache/num_layers overrides
        # do not leak into `config`.
        encoder_config = deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        self.encoder = UdopStack(encoder_config)

        decoder_config = deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = UdopStack(decoder_config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    @auto_docstring
    def forward(
        self,
        input_ids: Tensor | None = None,
        attention_mask: Tensor | None = None,
        bbox: dict[str, Any] | None = None,
        pixel_values: Tensor | None = None,
        visual_bbox: dict[str, Any] | None = None,
        decoder_input_ids: Tensor | None = None,
        decoder_attention_mask: Tensor | None = None,
        inputs_embeds: Tensor | None = None,
        encoder_outputs: Tensor | None = None,
        past_key_values: Cache | None = None,
        decoder_inputs_embeds: Tensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | Seq2SeqModelOutput:
        r"""
        bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
            Bounding boxes of each input sequence tokens. Each bounding box should be a normalized version in (x0, y0,
            x1, y1) format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box,
            and (x1, y1) represents the position of the lower right corner. The encoder clips coordinates to `[0, 1]`
            before quantizing them into `config.max_2d_position_embeddings` positions, so boxes are expected to be
            normalized to that range.

            Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS]
            token. See `pixel_values` for `patch_sequence_length`.
        visual_bbox (`torch.LongTensor` of shape `(batch_size, patch_sequence_length, 4)`, *optional*):
            Bounding boxes of each patch in the image. If not provided, bounding boxes are created in the model.
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using
            [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
            [What are decoder input IDs?](../glossary#decoder-input-ids) T5 uses the `pad_token_id` as the starting
            token for `decoder_input_ids` generation. If `past_key_values` is used, optionally only the last
            `decoder_input_ids` have to be input (see `past_key_values`). To know more on how to prepare
            `decoder_input_ids` for pretraining take a look at [T5 Training](./t5#training).
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.

        Example:

        ```python
        >>> from transformers import AutoProcessor, AutoModel
        >>> from datasets import load_dataset
        >>> import torch

        >>> # load model and processor
        >>> # in this case, we already have performed OCR ourselves
        >>> # so we initialize the processor with `apply_ocr=False`
        >>> processor = AutoProcessor.from_pretrained("microsoft/udop-large", apply_ocr=False)
        >>> model = AutoModel.from_pretrained("microsoft/udop-large")

        >>> # load an example image, along with the words and coordinates
        >>> # which were extracted using an OCR engine
        >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
        >>> example = dataset[0]
        >>> image = example["image"]
        >>> words = example["tokens"]
        >>> boxes = example["bboxes"]
        >>> inputs = processor(image, words, boxes=boxes, return_tensors="pt")

        >>> decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

        >>> # forward pass
        >>> outputs = model(**inputs, decoder_input_ids=decoder_input_ids)
        >>> last_hidden_states = outputs.last_hidden_state
        >>> list(last_hidden_states.shape)
        [1, 1, 1024]
        ```"""
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                bbox=bbox,
                pixel_values=pixel_values,
                visual_bbox=visual_bbox,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        hidden_states = encoder_outputs[0]
        # The encoder may extend the mask with image patches, so the decoder uses the mask it returns.
        encoder_attention_mask = encoder_outputs.attention_mask if return_dict else encoder_outputs[1]

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            # we filter out the attention mask
            decoder_outputs = tuple(value for idx, value in enumerate(decoder_outputs) if idx != 1)
            encoder_outputs = tuple(value for idx, value in enumerate(encoder_outputs) if idx != 1)
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
- @auto_docstring(
- custom_intro="""
- The UDOP encoder-decoder Transformer with a language modeling head on top, enabling to generate text given document
- images and an optional prompt.
- This class is based on [`T5ForConditionalGeneration`], extended to deal with images and layout (2D) data.
- """
- )
- class UdopForConditionalGeneration(UdopPreTrainedModel, GenerationMixin):
- _tied_weights_keys = {
- "encoder.embed_tokens.weight": "shared.weight",
- "decoder.embed_tokens.weight": "shared.weight",
- "encoder.embed_patches.proj.weight": "patch_embed.proj.weight",
- "encoder.embed_patches.proj.bias": "patch_embed.proj.bias",
- "encoder.relative_bias.biases.0.relative_attention_bias.weight": "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight",
- "decoder.relative_bias.biases.0.relative_attention_bias.weight": "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight",
- "lm_head.weight": "shared.weight",
- }
- def __init__(self, config):
- super().__init__(config)
- # text and image embeddings
- self.shared = nn.Embedding(config.vocab_size, config.d_model)
- self.patch_embed = UdopPatchEmbeddings(config)
- encoder_config = deepcopy(config)
- encoder_config.is_decoder = False
- encoder_config.use_cache = False
- self.encoder = UdopStack(encoder_config)
- decoder_config = deepcopy(config)
- decoder_config.is_decoder = True
- decoder_config.num_layers = config.num_decoder_layers
- self.decoder = UdopStack(decoder_config)
- # The weights of the language modeling head are shared with those of the encoder and decoder
- self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
- # Initialize weights and apply final processing
- self.post_init()
- def get_input_embeddings(self):
- return self.shared
- def set_input_embeddings(self, new_embeddings):
- self.shared = new_embeddings
- self.encoder.set_input_embeddings(new_embeddings)
- self.decoder.set_input_embeddings(new_embeddings)
@auto_docstring
def forward(
    self,
    input_ids: Tensor | None = None,
    attention_mask: Tensor | None = None,
    bbox: Tensor | None = None,
    pixel_values: Tensor | None = None,
    visual_bbox: Tensor | None = None,
    decoder_input_ids: Tensor | None = None,
    decoder_attention_mask: Tensor | None = None,
    inputs_embeds: Tensor | None = None,
    encoder_outputs: Tensor | None = None,
    past_key_values: Cache | None = None,
    decoder_inputs_embeds: Tensor | None = None,
    use_cache: bool | None = None,
    output_attentions: bool | None = None,
    output_hidden_states: bool | None = None,
    return_dict: bool | None = None,
    labels: Tensor | None = None,
    **kwargs,
) -> tuple | Seq2SeqLMOutput:
    r"""
    bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
        Bounding boxes of each input sequence tokens. Selected in the range `[0,
        config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
        format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
        y1) represents the position of the lower right corner.

        Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS]
        token. See `pixel_values` for `patch_sequence_length`.
    visual_bbox (`torch.LongTensor` of shape `(batch_size, patch_sequence_length, 4)`, *optional*):
        Bounding boxes of each patch in the image. If not provided, bounding boxes are created in the model.
    decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
        Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using
        [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
        [What are decoder input IDs?](../glossary#decoder-input-ids) T5 uses the `pad_token_id` as the starting
        token for `decoder_input_ids` generation. If `past_key_values` is used, optionally only the last
        `decoder_input_ids` have to be input (see `past_key_values`). To know more on how to prepare
        `decoder_input_ids` for pretraining take a look at [T5 Training](./t5#training). If not provided while
        `labels` is given, they are built by shifting `labels` one position to the right.
    decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
        Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
        be used by default.
    labels (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
        Labels for computing the language modeling loss. Indices should be in `[-100, 0, ..., config.vocab_size -
        1]`. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
        config.vocab_size - 1]`.

    Examples:

    ```python
    >>> from transformers import AutoProcessor, UdopForConditionalGeneration
    >>> from datasets import load_dataset

    >>> # load model and processor
    >>> # in this case, we already have performed OCR ourselves
    >>> # so we initialize the processor with `apply_ocr=False`
    >>> processor = AutoProcessor.from_pretrained("microsoft/udop-large", apply_ocr=False)
    >>> model = UdopForConditionalGeneration.from_pretrained("microsoft/udop-large")

    >>> # load an example image, along with the words and coordinates
    >>> # which were extracted using an OCR engine
    >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
    >>> example = dataset[0]
    >>> image = example["image"]
    >>> words = example["tokens"]
    >>> boxes = example["bboxes"]

    >>> # one can use the various task prefixes (prompts) used during pre-training
    >>> # e.g. the task prefix for DocVQA is "Question answering. "
    >>> question = "Question answering. What is the date on the form?"
    >>> encoding = processor(image, question, text_pair=words, boxes=boxes, return_tensors="pt")

    >>> # autoregressive generation
    >>> predicted_ids = model.generate(**encoding)
    >>> print(processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])
    9/30/92
    ```"""
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.return_dict

    # Teacher forcing: derive the decoder inputs from the labels when they are not given explicitly.
    if decoder_input_ids is None and labels is not None:
        decoder_input_ids = self._shift_right(labels)

    # Encode if needed (training, first prediction pass)
    if encoder_outputs is None:
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            bbox=bbox,
            visual_bbox=visual_bbox,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    hidden_states = encoder_outputs[0]
    # The encoder output carries the attention mask the decoder must use for cross-attention:
    # `.attention_mask` in dict mode, index 1 in tuple mode.
    encoder_attention_mask = encoder_outputs.attention_mask if return_dict else encoder_outputs[1]

    # Decode
    decoder_outputs = self.decoder(
        input_ids=decoder_input_ids,
        attention_mask=decoder_attention_mask,
        inputs_embeds=decoder_inputs_embeds,
        past_key_values=past_key_values,
        encoder_hidden_states=hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    sequence_output = decoder_outputs[0]

    # With tied input/output embeddings, rescale the decoder output by d_model**-0.5 before
    # projecting it onto the vocabulary.
    if self.config.tie_word_embeddings:
        sequence_output = sequence_output * (self.config.d_model**-0.5)

    lm_logits = self.lm_head(sequence_output)

    loss = None
    if labels is not None:
        # Move labels to the logits' device so the loss also works when the model is spread over
        # several devices; `reshape` (unlike `view`) also accepts non-contiguous label tensors.
        labels = labels.to(lm_logits.device)
        loss_fct = CrossEntropyLoss(ignore_index=-100)
        loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.reshape(-1))

    if not return_dict:
        # Index 1 of each stack's tuple is skipped; for the encoder it is the attention mask (read above).
        # NOTE(review): the decoder tuple is assumed to share that layout — confirm against UdopStack.
        output = (lm_logits,) + decoder_outputs[2:] + (encoder_outputs[0],) + encoder_outputs[2:]
        return ((loss,) + output) if loss is not None else output

    return Seq2SeqLMOutput(
        loss=loss,
        logits=lm_logits,
        past_key_values=decoder_outputs.past_key_values,
        decoder_hidden_states=decoder_outputs.hidden_states,
        decoder_attentions=decoder_outputs.attentions,
        cross_attentions=decoder_outputs.cross_attentions,
        encoder_last_hidden_state=encoder_outputs.last_hidden_state,
        encoder_hidden_states=encoder_outputs.hidden_states,
        encoder_attentions=encoder_outputs.attentions,
    )
@auto_docstring
class UdopEncoderModel(UdopPreTrainedModel):
    # Parameter aliases: each key parameter is tied to (shares storage with) the value parameter.
    _tied_weights_keys = {
        "encoder.embed_tokens.weight": "shared.weight",
        "encoder.embed_patches.proj.weight": "patch_embed.proj.weight",
        "encoder.embed_patches.proj.bias": "patch_embed.proj.bias",
        "encoder.relative_bias.biases.0.relative_attention_bias.weight": "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight",
    }

    def __init__(self, config: UdopConfig):
        super().__init__(config)

        # text and image embeddings
        self.shared = nn.Embedding(config.vocab_size, config.d_model)
        self.patch_embed = UdopPatchEmbeddings(config)

        # Encoder-only copy of the config: not a decoder, caching disabled, not an encoder-decoder model.
        encoder_config = deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = UdopStack(encoder_config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the shared token embedding module (`self.shared`)."""
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        """Install `new_embeddings` as `self.shared` and hand it to the encoder stack."""
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)

    @auto_docstring
    def forward(
        self,
        input_ids: Tensor | None = None,
        bbox: Tensor | None = None,
        attention_mask: Tensor | None = None,
        pixel_values: Tensor | None = None,
        visual_bbox: Tensor | None = None,
        inputs_embeds: Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor] | BaseModelOutputWithAttentionMask:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
            should be able to pad the inputs on both the right and the left.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            To know more on how to prepare `input_ids` for pretraining take a look at [T5 Training](./t5#training).
        bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
            Bounding boxes of each input sequence tokens. Selected in the range `[0,
            config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
            format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
            y1) represents the position of the lower right corner.

            Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS]
            token. See `pixel_values` for `patch_sequence_length`.
        visual_bbox (`torch.LongTensor` of shape `(batch_size, patch_sequence_length, 4)`, *optional*):
            Bounding boxes of each patch in the image. If not provided, bounding boxes are created in the model.

        Example:

        ```python
        >>> from transformers import AutoProcessor, UdopEncoderModel
        >>> from datasets import load_dataset

        >>> # load model and processor
        >>> # in this case, we already have performed OCR ourselves
        >>> # so we initialize the processor with `apply_ocr=False`
        >>> processor = AutoProcessor.from_pretrained("microsoft/udop-large", apply_ocr=False)
        >>> model = UdopEncoderModel.from_pretrained("microsoft/udop-large")

        >>> # load an example image, along with the words and coordinates
        >>> # which were extracted using an OCR engine
        >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
        >>> example = dataset[0]
        >>> image = example["image"]
        >>> words = example["tokens"]
        >>> boxes = example["bboxes"]
        >>> encoding = processor(image, words, boxes=boxes, return_tensors="pt")

        >>> outputs = model(**encoding)
        >>> last_hidden_states = outputs.last_hidden_state
        ```"""
        # Fall back to the config for any output switch the caller left unset.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # The encoder stack's output (tuple or model output, depending on `return_dict`) is returned as-is.
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            bbox=bbox,
            visual_bbox=visual_bbox,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return encoder_outputs
# Names exported by `from <module> import *`; this is the module's public API.
__all__ = [
    "UdopForConditionalGeneration",
    "UdopPreTrainedModel",
    "UdopModel",
    "UdopEncoderModel",
]
|