| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
7037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159 |
- # Copyright 2022 Microsoft Research and The HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch LayoutLMv3 model."""
- import collections
- import math
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import (
- BaseModelOutput,
- QuestionAnsweringModelOutput,
- SequenceClassifierOutput,
- TokenClassifierOutput,
- )
- from ...modeling_utils import PreTrainedModel
- from ...processing_utils import Unpack
- from ...pytorch_utils import apply_chunking_to_forward
- from ...utils import (
- auto_docstring,
- can_return_tuple,
- logging,
- torch_int,
- )
- from ...utils.generic import TransformersKwargs, merge_with_config_defaults
- from ...utils.output_capturing import capture_outputs
- from .configuration_layoutlmv3 import LayoutLMv3Config
# Module-level logger obtained from the library's own `logging` utilities (imported from `...utils`).
logger = logging.get_logger(__name__)
class LayoutLMv3PatchEmbeddings(nn.Module):
    """
    Cut an image into non-overlapping patches and project every patch to `hidden_size`.

    If `forward` receives a `position_embedding`, it is laid out on the configured patch grid and resized (bicubic)
    to the patch grid actually produced for the input, then added before the patches are flattened into a sequence.
    """

    def __init__(self, config):
        super().__init__()
        image_size = self._as_pair(config.input_size)
        patch_size = self._as_pair(config.patch_size)
        # Number of patches along (height, width) for the configured input size.
        grid_rows = image_size[0] // patch_size[0]
        grid_cols = image_size[1] // patch_size[1]
        self.patch_shape = (grid_rows, grid_cols)
        # kernel_size == stride, so each output location sees exactly one patch.
        self.proj = nn.Conv2d(config.num_channels, config.hidden_size, kernel_size=patch_size, stride=patch_size)

    @staticmethod
    def _as_pair(value):
        """Return `value` unchanged when it is iterable, otherwise the tuple `(value, value)`."""
        return value if isinstance(value, collections.abc.Iterable) else (value, value)

    def forward(self, pixel_values, position_embedding=None):
        """
        Args:
            pixel_values: image batch fed to the patch projection convolution.
            position_embedding: optional position vectors; assumed to hold `patch_shape[0] * patch_shape[1]`
                vectors (TODO confirm with callers) so they can be viewed as a grid.

        Returns:
            Patch embeddings flattened to `(batch_size, num_patches, hidden_size)`.
        """
        patches = self.proj(pixel_values)
        if position_embedding is None:
            return patches.flatten(2).transpose(1, 2)

        grid_rows, grid_cols = self.patch_shape
        # Lay the vectors out as a (1, dim, rows, cols) "image" so they can be resized spatially.
        grid_pos = position_embedding.view(1, grid_rows, grid_cols, -1).permute(0, 3, 1, 2)
        target_size = (patches.shape[2], patches.shape[3])
        grid_pos = F.interpolate(grid_pos, size=target_size, mode="bicubic")
        return (patches + grid_pos).flatten(2).transpose(1, 2)
class LayoutLMv3TextEmbeddings(nn.Module):
    """
    LayoutLMv3 text embeddings. Same as `RobertaEmbeddings` but with added spatial (layout) embeddings.

    The output is `dropout(LayerNorm(word + token_type + position + spatial))`. The spatial term concatenates the
    embeddings of the four box coordinates with the embeddings of the box height and width.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # position_ids (1, max_position_embeddings) lookup buffer; non-persistent, so it is NOT saved in the state dict.
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

        self.padding_idx = config.pad_token_id
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
        )

        # Layout tables: the x table is shared by left/right coordinates, the y table by upper/lower coordinates.
        self.x_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size)
        self.y_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size)
        self.h_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size)
        self.w_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size)

    def calculate_spatial_position_embeddings(self, bbox):
        """
        Embed bounding boxes.

        Args:
            bbox (`torch.LongTensor`): boxes indexed as `bbox[:, :, k]`, with `k` = 0..3 for left x, upper y,
                right x, lower y.

        Returns:
            Concatenation along the last dim of [left, upper, right, lower] (each `coordinate_size`) and
            [height, width] (each `shape_size`) embeddings.

        Raises:
            IndexError: if a coordinate is outside the x/y embedding tables.
        """
        try:
            left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
            upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])
            right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])
            lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])
        except IndexError as e:
            raise IndexError("The `bbox` coordinate values should be within 0-1000 range.") from e

        # Clamp height/width into the valid index range of their own tables. The bound comes from the table size
        # rather than a hard-coded 1023, so larger `max_2d_position_embeddings` values are not truncated.
        max_height = self.h_position_embeddings.num_embeddings - 1
        max_width = self.w_position_embeddings.num_embeddings - 1
        h_position_embeddings = self.h_position_embeddings(torch.clip(bbox[:, :, 3] - bbox[:, :, 1], 0, max_height))
        w_position_embeddings = self.w_position_embeddings(torch.clip(bbox[:, :, 2] - bbox[:, :, 0], 0, max_width))

        # below is the difference between LayoutLMEmbeddingsV2 (torch.cat) and LayoutLMEmbeddingsV1 (add)
        spatial_position_embeddings = torch.cat(
            [
                left_position_embeddings,
                upper_position_embeddings,
                right_position_embeddings,
                lower_position_embeddings,
                h_position_embeddings,
                w_position_embeddings,
            ],
            dim=-1,
        )
        return spatial_position_embeddings

    def create_position_ids_from_input_ids(self, input_ids, padding_idx):
        """
        Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
        symbols are ignored (they keep position `padding_idx`). This is modified from fairseq's `utils.make_positions`.
        """
        # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
        mask = input_ids.ne(padding_idx).int()
        incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask)) * mask
        return incremental_indices.long() + padding_idx

    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
        """
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids
        starting at `padding_idx + 1`.
        """
        input_shape = inputs_embeds.size()[:-1]
        sequence_length = input_shape[1]

        position_ids = torch.arange(
            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
        )
        return position_ids.unsqueeze(0).expand(input_shape)

    def forward(
        self,
        input_ids=None,
        bbox=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
    ):
        """
        Args:
            input_ids: token ids; either this or `inputs_embeds` must be given.
            bbox: per-token boxes; always passed to `calculate_spatial_position_embeddings`, so it must be provided.
            token_type_ids: defaults to all zeros.
            position_ids: defaults to padding-aware positions from `input_ids`, or sequential positions when only
                `inputs_embeds` is given.
            inputs_embeds: precomputed word embeddings used instead of `input_ids`.

        Returns:
            Normalized embeddings. Presumably `4 * coordinate_size + 2 * shape_size == hidden_size` so that the
            spatial term adds cleanly (TODO confirm against the config).
        """
        if position_ids is None:
            if input_ids is not None:
                # Create the position ids from the input token ids. Any padded tokens remain padded.
                position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx).to(
                    input_ids.device
                )
            else:
                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)

        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        position_embeddings = self.position_embeddings(position_ids)
        embeddings += position_embeddings

        spatial_position_embeddings = self.calculate_spatial_position_embeddings(bbox)

        embeddings = embeddings + spatial_position_embeddings

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
class LayoutLMv3SelfAttention(nn.Module):
    """
    Multi-head self-attention with optional additive 1D relative-position (`rel_pos`) and 2D spatial (`rel_2d_pos`)
    biases. Probabilities are computed with the CogView PB-Relax softmax (`cogview_attention`).
    """

    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.has_relative_attention_bias = config.has_relative_attention_bias
        self.has_spatial_attention_bias = config.has_spatial_attention_bias

    def cogview_attention(self, attention_scores, alpha=32):
        """
        https://huggingface.co/papers/2105.13290 Section 2.4 Stabilization of training: Precision Bottleneck Relaxation
        (PB-Relax). A replacement of the original nn.Softmax(dim=-1)(attention_scores). Seems the new attention_probs
        will result in a slower speed and a little bias. Can use torch.allclose(standard_attention_probs,
        cogview_attention_probs, atol=1e-08) for comparison. The smaller atol (e.g., 1e-08), the better.
        """
        scaled_attention_scores = attention_scores / alpha
        max_value = scaled_attention_scores.amax(dim=(-1)).unsqueeze(-1)
        new_attention_scores = (scaled_attention_scores - max_value) * alpha
        return nn.Softmax(dim=-1)(new_attention_scores)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        rel_pos=None,
        rel_2d_pos=None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """
        Args:
            hidden_states: `(batch_size, seq_len, hidden_size)` input.
            attention_mask: additive mask added to the raw scores, if given.
            rel_pos: 1D relative-position bias; added when `has_relative_attention_bias` is set.
            rel_2d_pos: 2D spatial bias; added when `has_spatial_attention_bias` is set.

        Returns:
            `(context_layer, attention_probs)`: context of shape `(batch_size, seq_len, all_head_size)` and
            probabilities of shape `(batch_size, num_heads, seq_len, seq_len)`.
        """
        batch_size = hidden_states.shape[0]
        query_layer = (
            self.query(hidden_states)
            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )
        key_layer = (
            self.key(hidden_states)
            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )
        value_layer = (
            self.value(hidden_states)
            .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
            .transpose(1, 2)
        )

        # Take the dot product between "query" and "key" to get the raw attention scores.
        # The attention scores QT K/√d could be significantly larger than input elements, and result in overflow.
        # Changing the computational order into QT(K/√d) alleviates the problem. (https://huggingface.co/papers/2105.13290)
        attention_scores = torch.matmul(query_layer / math.sqrt(self.attention_head_size), key_layer.transpose(-1, -2))

        # Each enabled bias is added, including the spatial-only configuration, whose `rel_2d_pos` the encoder
        # computes whenever `has_spatial_attention_bias` is set.
        position_bias = None
        if self.has_relative_attention_bias:
            position_bias = rel_pos
        if self.has_spatial_attention_bias:
            position_bias = rel_2d_pos if position_bias is None else position_bias + rel_2d_pos
        if position_bias is not None:
            attention_scores += position_bias / math.sqrt(self.attention_head_size)

        if attention_mask is not None:
            # Apply the additive attention mask (precomputed once in the model's forward).
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        # Use the trick of the CogView paper to stabilize training
        attention_probs = self.cogview_attention(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer, attention_probs
# Adapted from transformers.models.roberta.modeling_roberta.RobertaSelfOutput
class LayoutLMv3SelfOutput(nn.Module):
    """Project the attention context, apply dropout, then add the residual and normalize."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        projected = self.dropout(self.dense(hidden_states))
        # `input_tensor` is the residual branch (the attention block's input).
        return self.LayerNorm(projected + input_tensor)
# Adapted from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Attention with LayoutLMv2->LayoutLMv3
class LayoutLMv3Attention(nn.Module):
    """Self-attention followed by the output projection with residual connection and LayerNorm."""

    def __init__(self, config):
        super().__init__()
        # The attribute names `self` / `output` determine the parameter keys in the state dict; keep them stable.
        self.self = LayoutLMv3SelfAttention(config)
        self.output = LayoutLMv3SelfOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        rel_pos=None,
        rel_2d_pos=None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        # Only the context is needed here; the attention probabilities (second element) are discarded.
        context, _ = self.self(hidden_states, attention_mask, rel_pos=rel_pos, rel_2d_pos=rel_2d_pos, **kwargs)
        # The block input doubles as the residual for the output sub-layer.
        return self.output(context, hidden_states)
# Adapted from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Layer with LayoutLMv2->LayoutLMv3
class LayoutLMv3Layer(GradientCheckpointingLayer):
    """One encoder block: self-attention (with add & norm) followed by the feed-forward sub-layer."""

    def __init__(self, config):
        super().__init__()
        # Passed to `apply_chunking_to_forward`, which runs the feed-forward over `seq_len_dim`.
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = LayoutLMv3Attention(config)
        self.intermediate = LayoutLMv3Intermediate(config)
        self.output = LayoutLMv3Output(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        output_attentions=False,
        rel_pos=None,
        rel_2d_pos=None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """
        Args:
            hidden_states: block input.
            attention_mask: additive mask forwarded to self-attention.
            output_attentions: kept for signature compatibility; not used by this method.
            rel_pos / rel_2d_pos: precomputed attention biases shared by all layers.
            **kwargs: forwarded to the attention sub-module, matching how `LayoutLMv3Attention` and the encoder
                propagate their kwargs.

        Returns:
            The block output tensor.
        """
        attention_output = self.attention(
            hidden_states,
            attention_mask,
            rel_pos=rel_pos,
            rel_2d_pos=rel_2d_pos,
            **kwargs,
        )
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        return layer_output

    def feed_forward_chunk(self, attention_output):
        """Apply intermediate expansion and output projection (with residual + LayerNorm) to one chunk."""
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
class LayoutLMv3Encoder(nn.Module):
    """
    Stack of `LayoutLMv3Layer`s that also owns the lookup tables for the 1D relative-position and 2D spatial
    attention biases. The biases are computed once per forward pass and passed unchanged to every layer.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([LayoutLMv3Layer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

        self.has_relative_attention_bias = config.has_relative_attention_bias
        self.has_spatial_attention_bias = config.has_spatial_attention_bias

        if self.has_relative_attention_bias:
            self.rel_pos_bins = config.rel_pos_bins
            self.max_rel_pos = config.max_rel_pos
            # Used as a lookup table (see `_cal_1d_pos_emb`): row `b` of `weight.t()` is the per-head bias for bucket `b`.
            self.rel_pos_bias = nn.Linear(self.rel_pos_bins, config.num_attention_heads, bias=False)

        if self.has_spatial_attention_bias:
            self.max_rel_2d_pos = config.max_rel_2d_pos
            self.rel_2d_pos_bins = config.rel_2d_pos_bins
            # Same lookup-table usage as `rel_pos_bias`, one table per axis.
            self.rel_pos_x_bias = nn.Linear(self.rel_2d_pos_bins, config.num_attention_heads, bias=False)
            self.rel_pos_y_bias = nn.Linear(self.rel_2d_pos_bins, config.num_attention_heads, bias=False)

    def relative_position_bucket(self, relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Map signed relative distances to integer bucket ids.

        Args:
            relative_position (`torch.Tensor`): integer tensor of signed distances.
            bidirectional (`bool`): if True, the buckets are split in half; positive distances are offset into the
                upper half and the magnitude selects the bucket inside a half. If False, only
                `max(-relative_position, 0)` is bucketed, so positive distances share the bucket of distance 0.
            num_buckets (`int`): total number of buckets.
            max_distance (`int`): distance at which the logarithmic bins reach the last bucket of a half.

        Returns:
            Tensor of bucket ids with the same shape as `relative_position`. Within a half, magnitudes below
            `max_exact` (half of that half's buckets) get one bucket each; larger ones are binned logarithmically
            and capped at the half's last bucket.
        """
        ret = 0
        if bidirectional:
            num_buckets //= 2
            ret += (relative_position > 0).long() * num_buckets
            n = torch.abs(relative_position)
        else:
            n = torch.max(-relative_position, torch.zeros_like(relative_position))
        # now n is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = n < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance.
        # For n == 0 this evaluates log(0); those entries are discarded by `torch.where` below because `is_small` holds.
        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).to(torch.long)
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def _cal_1d_pos_emb(self, position_ids):
        """
        Compute the 1D relative-position attention bias.

        Assumes `position_ids` is `(batch_size, seq_len)` (TODO confirm for all callers). Entry `[b, i, j]` of the
        relative matrix is `position_ids[b, j] - position_ids[b, i]`; it is bucketed and looked up in
        `rel_pos_bias`, giving a bias of shape `(batch_size, num_heads, seq_len, seq_len)`.
        """
        rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1)
        rel_pos = self.relative_position_bucket(
            rel_pos_mat,
            num_buckets=self.rel_pos_bins,
            max_distance=self.max_rel_pos,
        )
        # Since this is a simple indexing operation that is independent of the input,
        # no need to track gradients for this operation
        #
        # Without this no_grad context, training speed slows down significantly
        #
        # NOTE(review): because the lookup into `rel_pos_bias.weight` runs under `torch.no_grad()`, the resulting bias
        # is detached, so these weights receive no gradient through this path. Confirm this is intended.
        with torch.no_grad():
            rel_pos = self.rel_pos_bias.weight.t()[rel_pos].permute(0, 3, 1, 2)
        rel_pos = rel_pos.contiguous()
        return rel_pos

    def _cal_2d_pos_emb(self, bbox):
        """
        Compute the 2D spatial attention bias from token boxes.

        Uses the left x (`bbox[:, :, 0]`) and lower y (`bbox[:, :, 3]`) coordinate of each box. The pairwise
        differences along each axis are bucketed, looked up in `rel_pos_x_bias` / `rel_pos_y_bias`, and summed.
        """
        position_coord_x = bbox[:, :, 0]
        position_coord_y = bbox[:, :, 3]
        rel_pos_x_2d_mat = position_coord_x.unsqueeze(-2) - position_coord_x.unsqueeze(-1)
        rel_pos_y_2d_mat = position_coord_y.unsqueeze(-2) - position_coord_y.unsqueeze(-1)
        rel_pos_x = self.relative_position_bucket(
            rel_pos_x_2d_mat,
            num_buckets=self.rel_2d_pos_bins,
            max_distance=self.max_rel_2d_pos,
        )
        rel_pos_y = self.relative_position_bucket(
            rel_pos_y_2d_mat,
            num_buckets=self.rel_2d_pos_bins,
            max_distance=self.max_rel_2d_pos,
        )
        # Since this is a simple indexing operation that is independent of the input,
        # no need to track gradients for this operation
        #
        # Without this no_grad context, training speed slows down significantly
        #
        # NOTE(review): as in `_cal_1d_pos_emb`, the x/y bias tables get no gradient through this no_grad lookup.
        with torch.no_grad():
            rel_pos_x = self.rel_pos_x_bias.weight.t()[rel_pos_x].permute(0, 3, 1, 2)
            rel_pos_y = self.rel_pos_y_bias.weight.t()[rel_pos_y].permute(0, 3, 1, 2)
        rel_pos_x = rel_pos_x.contiguous()
        rel_pos_y = rel_pos_y.contiguous()
        rel_2d_pos = rel_pos_x + rel_pos_y
        return rel_2d_pos

    def forward(
        self,
        hidden_states,
        bbox=None,
        attention_mask=None,
        position_ids=None,
        patch_height=None,
        patch_width=None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """
        Run all layers.

        Args:
            hidden_states: embedded input sequence.
            bbox: boxes for the spatial bias; only read when `has_spatial_attention_bias` is set.
            attention_mask: additive mask passed to every layer.
            position_ids: positions for the 1D bias; only read when `has_relative_attention_bias` is set.
            patch_height / patch_width: accepted but not used by this method.
            **kwargs: forwarded to every layer.

        Returns:
            `BaseModelOutput` carrying only `last_hidden_state`.
        """
        rel_pos = self._cal_1d_pos_emb(position_ids) if self.has_relative_attention_bias else None
        rel_2d_pos = self._cal_2d_pos_emb(bbox) if self.has_spatial_attention_bias else None

        for layer_module in self.layer:
            hidden_states = layer_module(
                hidden_states,
                attention_mask,
                rel_pos=rel_pos,
                rel_2d_pos=rel_2d_pos,
                **kwargs,
            )

        return BaseModelOutput(last_hidden_state=hidden_states)
# Adapted from transformers.models.roberta.modeling_roberta.RobertaIntermediate
class LayoutLMv3Intermediate(nn.Module):
    """Expand hidden states to `intermediate_size` and apply the configured activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        activation = config.hidden_act
        # `hidden_act` is either a registry name (resolved through ACT2FN) or an already-built callable.
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.intermediate_act_fn(self.dense(hidden_states))
# Adapted from transformers.models.roberta.modeling_roberta.RobertaOutput
class LayoutLMv3Output(nn.Module):
    """Project feed-forward activations back to `hidden_size`, apply dropout, then add the residual and normalize."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        reduced = self.dropout(self.dense(hidden_states))
        # `input_tensor` is the residual branch (the feed-forward sub-layer's input).
        return self.LayerNorm(reduced + input_tensor)
@auto_docstring
class LayoutLMv3PreTrainedModel(PreTrainedModel):
    config: LayoutLMv3Config
    base_model_prefix = "layoutlmv3"
    input_modalities = ("image", "text")
    # Maps recorded output names to the module classes that produce them.
    _can_record_outputs = {"hidden_states": LayoutLMv3Layer, "attentions": LayoutLMv3SelfAttention}

    @torch.no_grad()
    def _init_weights(self, module):
        """Run the base initialization, then reset LayoutLMv3-specific parameters and buffers."""
        super()._init_weights(module)

        if isinstance(module, LayoutLMv3TextEmbeddings):
            # Restore the position-id lookup buffer to 0..N-1.
            num_positions = module.position_ids.shape[-1]
            init.copy_(module.position_ids, torch.arange(num_positions).expand((1, -1)))
            return

        if not isinstance(module, LayoutLMv3Model):
            return

        if self.config.visual_embed:
            init.zeros_(module.cls_token)
            init.zeros_(module.pos_embed)
        if hasattr(module, "visual_bbox"):
            grid = (module.size, module.size)
            init.copy_(module.visual_bbox, module.create_visual_bbox(image_size=grid))
@auto_docstring
class LayoutLMv3Model(LayoutLMv3PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        if config.text_embed:
            self.embeddings = LayoutLMv3TextEmbeddings(config)

        if config.visual_embed:
            # Visual parameters are sized for `config.input_size`.
            # NOTE(review): `forward_image` adds `pos_embed` without interpolation, so pixel inputs are expected to
            # produce the configured patch grid.
            self.patch_embed = LayoutLMv3PatchEmbeddings(config)

            # Side length of the patch grid (assumes scalar `input_size` / `patch_size` here).
            self.size = int(config.input_size / config.patch_size)
            self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
            # One position vector per patch plus one for the [CLS] token.
            self.pos_embed = nn.Parameter(torch.zeros(1, self.size * self.size + 1, config.hidden_size))
            self.pos_drop = nn.Dropout(p=0.0)

            self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
            self.dropout = nn.Dropout(config.hidden_dropout_prob)

            if self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias:
                self.register_buffer(
                    "visual_bbox", self.create_visual_bbox(image_size=(self.size, self.size)), persistent=False
                )

            self.norm = nn.LayerNorm(config.hidden_size, eps=1e-6)

        self.encoder = LayoutLMv3Encoder(config)

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def create_visual_bbox(self, image_size=(14, 14), max_len=1000):
        """
        Create the bounding boxes for the visual (patch) tokens.

        The `[0, max_len]` range is split into `image_size[1]` columns and `image_size[0]` rows. Each grid cell
        yields an (x0, y0, x1, y1) box in row-major order, and a [CLS] box `[1, 1, max_len - 1, max_len - 1]` is
        prepended, giving `(image_size[0] * image_size[1] + 1, 4)` boxes.
        """
        visual_bbox_x = torch.div(
            torch.arange(0, max_len * (image_size[1] + 1), max_len), image_size[1], rounding_mode="trunc"
        )
        visual_bbox_y = torch.div(
            torch.arange(0, max_len * (image_size[0] + 1), max_len), image_size[0], rounding_mode="trunc"
        )
        visual_bbox = torch.stack(
            [
                visual_bbox_x[:-1].repeat(image_size[0], 1),
                visual_bbox_y[:-1].repeat(image_size[1], 1).transpose(0, 1),
                visual_bbox_x[1:].repeat(image_size[0], 1),
                visual_bbox_y[1:].repeat(image_size[1], 1).transpose(0, 1),
            ],
            dim=-1,
        ).view(-1, 4)

        cls_token_box = torch.tensor([[0 + 1, 0 + 1, max_len - 1, max_len - 1]])
        return torch.cat([cls_token_box, visual_bbox], dim=0)

    def calculate_visual_bbox(self, device, dtype, batch_size):
        """Tile the precomputed `visual_bbox` to `(batch_size, num_boxes, 4)` on `device` with `dtype`."""
        visual_bbox = self.visual_bbox.repeat(batch_size, 1, 1)
        visual_bbox = visual_bbox.to(device).type(dtype)
        return visual_bbox

    def forward_image(self, pixel_values):
        """Embed images as `[CLS] + patches`, add `pos_embed` (no interpolation), then dropout and LayerNorm."""
        embeddings = self.patch_embed(pixel_values)

        # add [CLS] token
        batch_size, seq_len, _ = embeddings.size()
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # add position embeddings
        if self.pos_embed is not None:
            embeddings = embeddings + self.pos_embed

        embeddings = self.pos_drop(embeddings)
        embeddings = self.norm(embeddings)

        return embeddings

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        bbox: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutput:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, token_sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS]
            token. See `pixel_values` for `patch_sequence_length`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        bbox (`torch.LongTensor` of shape `(batch_size, token_sequence_length, 4)`, *optional*):
            Bounding boxes of each input sequence tokens. Selected in the range `[0,
            config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
            format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
            y1) represents the position of the lower right corner.

            Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS]
            token. See `pixel_values` for `patch_sequence_length`.
        token_type_ids (`torch.LongTensor` of shape `(batch_size, token_sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS]
            token. See `pixel_values` for `patch_sequence_length`.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, token_sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS]
            token. See `pixel_values` for `patch_sequence_length`.

            [What are position IDs?](../glossary#position-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, token_sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.

        Examples:

        ```python
        >>> from transformers import AutoProcessor, AutoModel
        >>> from datasets import load_dataset

        >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
        >>> model = AutoModel.from_pretrained("microsoft/layoutlmv3-base")

        >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
        >>> example = dataset[0]
        >>> image = example["image"]
        >>> words = example["tokens"]
        >>> boxes = example["bboxes"]

        >>> encoding = processor(image, words, boxes=boxes, return_tensors="pt")

        >>> outputs = model(**encoding)
        >>> last_hidden_states = outputs.last_hidden_state
        ```"""
        # Infer batch size and device from whichever input is present (text inputs take precedence).
        if input_ids is not None:
            input_shape = input_ids.size()
            batch_size, seq_length = input_shape
            device = input_ids.device
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size, seq_length = input_shape
            device = inputs_embeds.device
        elif pixel_values is not None:
            batch_size = len(pixel_values)
            device = pixel_values.device
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds or pixel_values")

        # Text branch: fill in default mask / token types / boxes and embed the tokens.
        if input_ids is not None or inputs_embeds is not None:
            if attention_mask is None:
                attention_mask = torch.ones(((batch_size, seq_length)), device=device)
            if token_type_ids is None:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
            if bbox is None:
                bbox = torch.zeros(tuple(list(input_shape) + [4]), dtype=torch.long, device=device)

            embedding_output = self.embeddings(
                input_ids=input_ids,
                bbox=bbox,
                position_ids=position_ids,
                token_type_ids=token_type_ids,
                inputs_embeds=inputs_embeds,
            )

        final_bbox = final_position_ids = None
        patch_height = patch_width = None
        if pixel_values is not None:
            # Visual branch: append [CLS] + patch embeddings (and their mask, boxes and positions) after the text.
            patch_height, patch_width = (
                torch_int(pixel_values.shape[2] / self.config.patch_size),
                torch_int(pixel_values.shape[3] / self.config.patch_size),
            )
            visual_embeddings = self.forward_image(pixel_values)
            visual_attention_mask = torch.ones(
                (batch_size, visual_embeddings.shape[1]), dtype=torch.long, device=device
            )
            if attention_mask is not None:
                attention_mask = torch.cat([attention_mask, visual_attention_mask], dim=1)
            else:
                attention_mask = visual_attention_mask

            if self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias:
                if self.config.has_spatial_attention_bias:
                    visual_bbox = self.calculate_visual_bbox(device, dtype=torch.long, batch_size=batch_size)
                    if bbox is not None:
                        final_bbox = torch.cat([bbox, visual_bbox], dim=1)
                    else:
                        final_bbox = visual_bbox

                # Text and visual tokens each get their own 0-based positions for the 1D relative bias.
                visual_position_ids = torch.arange(
                    0, visual_embeddings.shape[1], dtype=torch.long, device=device
                ).repeat(batch_size, 1)
                if input_ids is not None or inputs_embeds is not None:
                    position_ids = torch.arange(0, input_shape[1], device=device).unsqueeze(0)
                    position_ids = position_ids.expand(input_shape)
                    final_position_ids = torch.cat([position_ids, visual_position_ids], dim=1)
                else:
                    final_position_ids = visual_position_ids

            if input_ids is not None or inputs_embeds is not None:
                embedding_output = torch.cat([embedding_output, visual_embeddings], dim=1)
            else:
                embedding_output = visual_embeddings

            embedding_output = self.LayerNorm(embedding_output)
            embedding_output = self.dropout(embedding_output)
        elif self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias:
            # Text-only: the encoder biases use the text boxes and 0-based sequential positions.
            if self.config.has_spatial_attention_bias:
                final_bbox = bbox
            if self.config.has_relative_attention_bias:
                position_ids = self.embeddings.position_ids[:, : input_shape[1]]
                # Expand to `input_shape` (not `input_ids`) so this also works when only `inputs_embeds` is given.
                position_ids = position_ids.expand(input_shape)
                final_position_ids = position_ids

        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
            attention_mask, None, dtype=embedding_output.dtype
        )

        encoder_outputs = self.encoder(
            embedding_output,
            bbox=final_bbox,
            position_ids=final_position_ids,
            attention_mask=extended_attention_mask,
            patch_height=patch_height,
            patch_width=patch_width,
            **kwargs,
        )

        sequence_output = encoder_outputs.last_hidden_state

        return BaseModelOutput(
            last_hidden_state=sequence_output,
        )
class LayoutLMv3ClassificationHead(nn.Module):
    """
    Small MLP head applied over the last dimension of its input: dropout -> dense -> tanh -> dropout -> projection.

    Used in this file both on a single pooled vector (sequence classification) and on every position of a sequence
    (token classification, question answering). Reference: RobertaClassificationHead

    Args:
        config: model configuration; reads `hidden_size`, `classifier_dropout`, `hidden_dropout_prob` and
            `num_labels`.
        pool_feature (`bool`, *optional*, defaults to `False`):
            When truthy, the dense layer expects features of width `3 * hidden_size` instead of `hidden_size`.
    """

    def __init__(self, config, pool_feature=False):
        super().__init__()
        self.pool_feature = pool_feature

        width_factor = 3 if pool_feature else 1
        self.dense = nn.Linear(config.hidden_size * width_factor, config.hidden_size)

        # Fall back to the generic hidden dropout when no classifier-specific rate is configured.
        drop_prob = config.classifier_dropout
        if drop_prob is None:
            drop_prob = config.hidden_dropout_prob
        self.dropout = nn.Dropout(drop_prob)

        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, x):
        # The same dropout module is applied before the dense layer and again before the projection.
        hidden = self.dense(self.dropout(x))
        hidden = self.dropout(torch.tanh(hidden))
        return self.out_proj(hidden)
@auto_docstring(
    custom_intro="""
    LayoutLMv3 Model with a token classification head on top (a linear layer on top of the final hidden states) e.g.
    for sequence labeling (information extraction) tasks such as [FUNSD](https://guillaumejaume.github.io/FUNSD/),
    [SROIE](https://rrc.cvc.uab.es/?ch=13), [CORD](https://github.com/clovaai/cord) and
    [Kleister-NDA](https://github.com/applicaai/kleister-nda).
    """
)
class LayoutLMv3ForTokenClassification(LayoutLMv3PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.layoutlmv3 = LayoutLMv3Model(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Fewer than 10 labels: a single linear layer; otherwise the two-layer LayoutLMv3ClassificationHead.
        if config.num_labels < 10:
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        else:
            self.classifier = LayoutLMv3ClassificationHead(config, pool_feature=False)

        self.post_init()

    def get_input_embeddings(self):
        return self.layoutlmv3.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.layoutlmv3.set_input_embeddings(value)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        bbox: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        pixel_values: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | TokenClassifierOutput:
        r"""
        bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
            Bounding boxes of each input sequence tokens. Selected in the range `[0,
            config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
            format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
            y1) represents the position of the lower right corner.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.

        Examples:

        ```python
        >>> from transformers import AutoProcessor, AutoModelForTokenClassification
        >>> from datasets import load_dataset

        >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
        >>> model = AutoModelForTokenClassification.from_pretrained("microsoft/layoutlmv3-base", num_labels=7)

        >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
        >>> example = dataset[0]
        >>> image = example["image"]
        >>> words = example["tokens"]
        >>> boxes = example["bboxes"]
        >>> word_labels = example["ner_tags"]

        >>> encoding = processor(image, words, boxes=boxes, word_labels=word_labels, return_tensors="pt")

        >>> outputs = model(**encoding)
        >>> loss = outputs.loss
        >>> logits = outputs.logits
        ```"""
        # Predictions are made per text token, so text input is mandatory. The backbone itself accepts a
        # pixel-only call, which previously ran the whole encoder and then crashed with an AttributeError when
        # reading the text length below; fail fast with an explicit message instead.
        if input_ids is None and inputs_embeds is None:
            raise ValueError(
                "LayoutLMv3ForTokenClassification requires text input: "
                "you have to specify either input_ids or inputs_embeds"
            )

        outputs = self.layoutlmv3(
            input_ids,
            bbox=bbox,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            **kwargs,
        )

        # The backbone concatenates any visual patch tokens after the text tokens along dim 1, so the first
        # `seq_length` positions are exactly the text part of the output representations.
        seq_length = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
        sequence_output = outputs[0][:, :seq_length]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@auto_docstring
class LayoutLMv3ForQuestionAnswering(LayoutLMv3PreTrainedModel):
    # Extractive QA: every position of the backbone output gets a "span start" and a "span end" score.

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.layoutlmv3 = LayoutLMv3Model(config)
        # Emits `config.num_labels` scores per position; forward() unpacks them into exactly two columns
        # (start, end), so this presumably relies on num_labels == 2 — TODO confirm against the config.
        self.qa_outputs = LayoutLMv3ClassificationHead(config, pool_feature=False)

        self.post_init()

    def get_input_embeddings(self):
        backbone = self.layoutlmv3
        return backbone.get_input_embeddings()

    def set_input_embeddings(self, value):
        backbone = self.layoutlmv3
        backbone.set_input_embeddings(value)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        start_positions: torch.LongTensor | None = None,
        end_positions: torch.LongTensor | None = None,
        bbox: torch.LongTensor | None = None,
        pixel_values: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | QuestionAnsweringModelOutput:
        r"""
        bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
            Bounding boxes of each input sequence tokens. Selected in the range `[0,
            config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
            format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
            y1) represents the position of the lower right corner.

        Examples:

        ```python
        >>> from transformers import AutoProcessor, AutoModelForQuestionAnswering
        >>> from datasets import load_dataset
        >>> import torch

        >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
        >>> model = AutoModelForQuestionAnswering.from_pretrained("microsoft/layoutlmv3-base")

        >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
        >>> example = dataset[0]
        >>> image = example["image"]
        >>> question = "what's his name?"
        >>> words = example["tokens"]
        >>> boxes = example["bboxes"]

        >>> encoding = processor(image, question, words, boxes=boxes, return_tensors="pt")
        >>> start_positions = torch.tensor([1])
        >>> end_positions = torch.tensor([3])

        >>> outputs = model(**encoding, start_positions=start_positions, end_positions=end_positions)
        >>> loss = outputs.loss
        >>> start_scores = outputs.start_logits
        >>> end_scores = outputs.end_logits
        ```"""
        outputs: BaseModelOutput = self.layoutlmv3(
            input_ids,
            bbox=bbox,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

        # Scores cover the full backbone output: text tokens followed by any visual patch tokens.
        span_scores = self.qa_outputs(outputs[0])
        start_column, end_column = span_scores.split(1, dim=-1)
        start_logits = start_column.squeeze(-1).contiguous()
        end_logits = end_column.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # Targets may arrive with a trailing singleton dimension (the upstream comment attributes this to
            # multi-GPU runs); flatten it away so they are 1-D per batch.
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)

            # Positions beyond the sequence are clamped onto `num_positions` (one past the last valid index),
            # which the loss ignores; negative positions are clamped to 0.
            num_positions = start_logits.size(1)
            span_loss = CrossEntropyLoss(ignore_index=num_positions)
            start_loss = span_loss(start_logits, start_positions.clamp(0, num_positions))
            end_loss = span_loss(end_logits, end_positions.clamp(0, num_positions))
            total_loss = (start_loss + end_loss) / 2

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@auto_docstring(
    custom_intro="""
    LayoutLMv3 Model with a sequence classification head on top (a linear layer on top of the final hidden state of the
    [CLS] token) e.g. for document image classification tasks such as the
    [RVL-CDIP](https://www.cs.cmu.edu/~aharley/rvl-cdip/) dataset.
    """
)
class LayoutLMv3ForSequenceClassification(LayoutLMv3PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.layoutlmv3 = LayoutLMv3Model(config)
        # Applied to the hidden state at position 0 of the backbone output (see forward()).
        self.classifier = LayoutLMv3ClassificationHead(config, pool_feature=False)

        self.post_init()

    def get_input_embeddings(self):
        backbone = self.layoutlmv3
        return backbone.get_input_embeddings()

    def set_input_embeddings(self, value):
        backbone = self.layoutlmv3
        backbone.set_input_embeddings(value)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        bbox: torch.LongTensor | None = None,
        pixel_values: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | SequenceClassifierOutput:
        r"""
        bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
            Bounding boxes of each input sequence tokens. Selected in the range `[0,
            config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
            format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
            y1) represents the position of the lower right corner.

        Examples:

        ```python
        >>> from transformers import AutoProcessor, AutoModelForSequenceClassification
        >>> from datasets import load_dataset
        >>> import torch

        >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
        >>> model = AutoModelForSequenceClassification.from_pretrained("microsoft/layoutlmv3-base")

        >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
        >>> example = dataset[0]
        >>> image = example["image"]
        >>> words = example["tokens"]
        >>> boxes = example["bboxes"]

        >>> encoding = processor(image, words, boxes=boxes, return_tensors="pt")
        >>> sequence_label = torch.tensor([1])

        >>> outputs = model(**encoding, labels=sequence_label)
        >>> loss = outputs.loss
        >>> logits = outputs.logits
        ```"""
        outputs: BaseModelOutput = self.layoutlmv3(
            input_ids,
            bbox=bbox,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

        # Classify from the hidden state at the first position of the sequence.
        first_token_state = outputs[0][:, 0, :]
        logits = self.classifier(first_token_state)

        loss = None
        if labels is not None:
            # On the first labelled call, infer and persist the problem type on the config (mutates self.config).
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    inferred = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    inferred = "single_label_classification"
                else:
                    inferred = "multi_label_classification"
                self.config.problem_type = inferred

            problem_type = self.config.problem_type
            if problem_type == "regression":
                mse = MSELoss()
                if self.num_labels == 1:
                    loss = mse(logits.squeeze(), labels.squeeze())
                else:
                    loss = mse(logits, labels)
            elif problem_type == "single_label_classification":
                cross_entropy = CrossEntropyLoss()
                loss = cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                bce = BCEWithLogitsLoss()
                loss = bce(logits, labels)
            # Any other problem_type leaves `loss` as None.

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Public API of this module: the names exported by `from <module> import *`.
__all__ = [
    "LayoutLMv3ForQuestionAnswering",
    "LayoutLMv3ForSequenceClassification",
    "LayoutLMv3ForTokenClassification",
    "LayoutLMv3Model",
    "LayoutLMv3PreTrainedModel",
]
|