| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971 |
- # Copyright 2023-present NAVER Corp, The Microsoft Research Asia LayoutLM Team Authors and the HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch Bros model."""
- import math
- from dataclasses import dataclass
- import torch
- from torch import nn
- from torch.nn import CrossEntropyLoss
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import (
- BaseModelOutputWithCrossAttentions,
- BaseModelOutputWithPoolingAndCrossAttentions,
- TokenClassifierOutput,
- )
- from ...modeling_utils import PreTrainedModel
- from ...processing_utils import Unpack
- from ...pytorch_utils import apply_chunking_to_forward
- from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging
- from ...utils.generic import merge_with_config_defaults
- from ...utils.output_capturing import OutputRecorder, capture_outputs
- from .configuration_bros import BrosConfig
# Module-level logger, named after this module, obtained from the library's logging utilities.
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for outputs of token classification models.
    """
)
class BrosSpadeOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when both `initial_token_labels` and `subsequent_token_labels` are provided):
        Classification loss: the sum of the initial-token and the subsequent-token cross-entropy losses.
    initial_token_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`):
        Classification scores for entity initial tokens (before SoftMax).
    subsequent_token_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, sequence_length+1)`):
        Classification scores for entity sequence tokens (before SoftMax). The extra last position along the final
        axis corresponds to the learned dummy node that the relation extractor appends to the keys.
    """

    # NOTE(review): field order likely determines the tuple order produced by `ModelOutput`; confirm before reordering.
    loss: torch.FloatTensor | None = None
    initial_token_logits: torch.FloatTensor | None = None
    subsequent_token_logits: torch.FloatTensor | None = None
    # Optional per-layer hidden states / attention weights. They are not described in the docstring above;
    # presumably `@auto_docstring` supplies the standard descriptions (TODO confirm).
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
class BrosPositionalEmbedding1D(nn.Module):
    """Sinusoidal (transformer-XL style) embedding of scalar positions.

    Every input value ``p`` is mapped to ``[sin(p * f_0), ..., sin(p * f_{k-1}), cos(p * f_0), ..., cos(p * f_{k-1})]``
    with ``f_i = 1 / 10000 ** (2 * i / dim_bbox_sinusoid_emb_1d)`` and ``k = len(inv_freq)``
    (``dim_bbox_sinusoid_emb_1d // 2`` for an even dimension).
    """

    # Reference: https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py#L15
    def __init__(self, config):
        super().__init__()
        self.dim_bbox_sinusoid_emb_1d = config.dim_bbox_sinusoid_emb_1d
        inv_freq = 1 / (
            10000 ** (torch.arange(0.0, self.dim_bbox_sinusoid_emb_1d, 2.0) / self.dim_bbox_sinusoid_emb_1d)
        )
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, pos_seq: torch.Tensor) -> torch.Tensor:
        """Embed every element of ``pos_seq``.

        Args:
            pos_seq: Tensor of positions with arbitrary shape ``(*)`` (any number of dimensions).

        Returns:
            Tensor of shape ``(*, 2 * len(inv_freq))``: sine features followed by cosine features.
        """
        # Broadcasting against a new trailing axis works for inputs of any rank, including the
        # non-contiguous ``bbox[..., i]`` slices passed in by BrosPositionalEmbedding2D.
        sinusoid_inp = pos_seq.unsqueeze(-1) * self.inv_freq
        return torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
class BrosPositionalEmbedding2D(nn.Module):
    """Concatenate the 1D sinusoidal embeddings of every bounding-box coordinate.

    Coordinates at even indices go through ``x_pos_emb`` and those at odd indices through ``y_pos_emb``;
    the per-coordinate embeddings are concatenated along the last axis.
    """

    def __init__(self, config):
        super().__init__()
        self.dim_bbox = config.dim_bbox
        self.x_pos_emb = BrosPositionalEmbedding1D(config)
        self.y_pos_emb = BrosPositionalEmbedding1D(config)

    def forward(self, bbox: torch.Tensor) -> torch.Tensor:
        per_coordinate = [
            (self.y_pos_emb if coord % 2 else self.x_pos_emb)(bbox[..., coord])
            for coord in range(self.dim_bbox)
        ]
        return torch.cat(per_coordinate, dim=-1)
class BrosBboxEmbeddings(nn.Module):
    """Embed pairwise bounding-box offsets sinusoidally, then project them linearly (no bias)."""

    def __init__(self, config):
        super().__init__()
        self.bbox_sinusoid_emb = BrosPositionalEmbedding2D(config)
        self.bbox_projection = nn.Linear(config.dim_bbox_sinusoid_emb_2d, config.dim_bbox_projection, bias=False)

    def forward(self, bbox: torch.Tensor):
        # Swap the first two axes so the token axis comes first
        # (presumably (batch, seq, coords) -> (seq, batch, coords)).
        tokens_first = bbox.transpose(0, 1)
        # relative[i, j] = coordinates of token j minus coordinates of token i.
        relative = tokens_first.unsqueeze(0) - tokens_first.unsqueeze(1)
        return self.bbox_projection(self.bbox_sinusoid_emb(relative))
- class BrosTextEmbeddings(nn.Module):
- """Construct the embeddings from word, position and token_type embeddings."""
- def __init__(self, config):
- super().__init__()
- self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
- self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
- self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
- self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
- # position_ids (1, len position emb) is contiguous in memory and exported when serialized
- self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
- self.register_buffer(
- "token_type_ids",
- torch.zeros(
- self.position_ids.size(),
- dtype=torch.long,
- device=self.position_ids.device,
- ),
- persistent=False,
- )
- def forward(
- self,
- input_ids: torch.Tensor | None = None,
- token_type_ids: torch.Tensor | None = None,
- position_ids: torch.Tensor | None = None,
- inputs_embeds: torch.Tensor | None = None,
- ) -> torch.Tensor:
- if input_ids is not None:
- input_shape = input_ids.size()
- else:
- input_shape = inputs_embeds.size()[:-1]
- seq_length = input_shape[1]
- if position_ids is None:
- position_ids = self.position_ids[:, :seq_length]
- if token_type_ids is None:
- if hasattr(self, "token_type_ids"):
- buffered_token_type_ids = self.token_type_ids[:, :seq_length]
- buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
- token_type_ids = buffered_token_type_ids_expanded
- else:
- token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
- if inputs_embeds is None:
- inputs_embeds = self.word_embeddings(input_ids)
- token_type_embeddings = self.token_type_embeddings(token_type_ids)
- embeddings = inputs_embeds + token_type_embeddings
- position_embeddings = self.position_embeddings(position_ids)
- embeddings += position_embeddings
- embeddings = self.LayerNorm(embeddings)
- embeddings = self.dropout(embeddings)
- return embeddings
- class BrosSelfAttention(nn.Module):
- def __init__(self, config):
- super().__init__()
- if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
- raise ValueError(
- f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
- f"heads ({config.num_attention_heads})"
- )
- self.num_attention_heads = config.num_attention_heads
- self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
- self.all_head_size = self.num_attention_heads * self.attention_head_size
- self.query = nn.Linear(config.hidden_size, self.all_head_size)
- self.key = nn.Linear(config.hidden_size, self.all_head_size)
- self.value = nn.Linear(config.hidden_size, self.all_head_size)
- self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
- self.is_decoder = config.is_decoder
- def forward(
- self,
- hidden_states: torch.Tensor,
- bbox_pos_emb: torch.Tensor,
- attention_mask: torch.Tensor | None = None,
- encoder_hidden_states: torch.Tensor | None = None,
- encoder_attention_mask: torch.Tensor | None = None,
- ) -> tuple[torch.Tensor]:
- hidden_shape = (hidden_states.shape[0], -1, self.num_attention_heads, self.attention_head_size)
- query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
- # If this is instantiated as a cross-attention module, the keys
- # and values come from an encoder; the attention mask needs to be
- # such that the encoder's padding tokens are not attended to.
- is_cross_attention = encoder_hidden_states is not None
- if is_cross_attention:
- key_layer = self.key(encoder_hidden_states).view(hidden_shape).transpose(1, 2)
- value_layer = self.value(encoder_hidden_states).view(hidden_shape).transpose(1, 2)
- attention_mask = encoder_attention_mask
- else:
- key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
- value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
- # Take the dot product between "query" and "key" to get the raw attention scores.
- attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
- # bbox positional encoding
- batch_size, n_head, seq_length, d_head = query_layer.shape
- bbox_pos_emb = bbox_pos_emb.view(seq_length, seq_length, batch_size, d_head)
- bbox_pos_emb = bbox_pos_emb.permute([2, 0, 1, 3])
- bbox_pos_scores = torch.einsum("bnid,bijd->bnij", (query_layer, bbox_pos_emb))
- attention_scores = attention_scores + bbox_pos_scores
- attention_scores = attention_scores / math.sqrt(self.attention_head_size)
- if attention_mask is not None:
- # Apply the attention mask is (precomputed for all layers in BrosModel forward() function)
- attention_scores = attention_scores + attention_mask
- # Normalize the attention scores to probabilities.
- attention_probs = nn.Softmax(dim=-1)(attention_scores)
- # This is actually dropping out entire tokens to attend to, which might
- # seem a bit unusual, but is taken from the original Transformer paper.
- attention_probs = self.dropout(attention_probs)
- context_layer = torch.matmul(attention_probs, value_layer)
- context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
- new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
- context_layer = context_layer.view(*new_context_layer_shape)
- return context_layer, attention_probs
# Adapted from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->Bros
class BrosSelfOutput(nn.Module):
    """Project the attention output, apply dropout, add the residual, then layer-normalize."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
- class BrosAttention(nn.Module):
- def __init__(self, config):
- super().__init__()
- self.self = BrosSelfAttention(config)
- self.output = BrosSelfOutput(config)
- def forward(
- self,
- hidden_states: torch.Tensor,
- bbox_pos_emb: torch.Tensor,
- attention_mask: torch.Tensor | None = None,
- encoder_hidden_states: torch.Tensor | None = None,
- encoder_attention_mask: torch.Tensor | None = None,
- ) -> torch.Tensor:
- residual = hidden_states
- hidden_states, _ = self.self(
- hidden_states,
- bbox_pos_emb=bbox_pos_emb,
- attention_mask=attention_mask,
- encoder_hidden_states=encoder_hidden_states,
- encoder_attention_mask=encoder_attention_mask,
- )
- hidden_states = self.output(hidden_states, residual)
- return hidden_states
# Adapted from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Bros
class BrosIntermediate(nn.Module):
    """Feed-forward expansion ``hidden_size -> intermediate_size`` followed by the configured activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        activation = config.hidden_act
        # A string names an entry in ACT2FN; any other value is used directly as the activation callable.
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.intermediate_act_fn(self.dense(hidden_states))
class BrosOutput(nn.Module):
    """Project ``intermediate_size -> hidden_size``, apply dropout, add the residual, then layer-normalize."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        shrunk = self.dense(hidden_states)
        shrunk = self.dropout(shrunk)
        return self.LayerNorm(shrunk + input_tensor)
class BrosLayer(GradientCheckpointingLayer):
    """One BROS transformer layer.

    The layer runs bbox-aware self-attention, then optional cross-attention (decoder only), then a
    chunked feed-forward block.
    """

    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BrosAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if not self.is_decoder:
                raise Exception(f"{self} should be used as a decoder model if cross attention is added")
            self.crossattention = BrosAttention(config)
        self.intermediate = BrosIntermediate(config)
        self.output = BrosOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        bbox_pos_emb: torch.Tensor,
        attention_mask: torch.FloatTensor | None = None,
        encoder_hidden_states: torch.FloatTensor | None = None,
        encoder_attention_mask: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Apply the layer and return the new hidden states.

        Args:
            hidden_states: Layer input.
            bbox_pos_emb: Pairwise bbox positional embeddings consumed by BrosSelfAttention.
            attention_mask: Additive self-attention mask (optional).
            encoder_hidden_states: Encoder outputs for cross-attention; only used when `config.is_decoder`.
            encoder_attention_mask: Additive mask over the encoder positions (optional).
            **kwargs: Accepted because BrosEncoder forwards its extra keyword arguments; not used in this method.

        Raises:
            Exception: if `encoder_hidden_states` is given to a decoder layer built without cross-attention.
        """
        hidden_states = self.attention(
            hidden_states,
            bbox_pos_emb=bbox_pos_emb,
            attention_mask=attention_mask,
        )

        if self.is_decoder and encoder_hidden_states is not None:
            # The cross-attention sub-module only exists when built with `config.add_cross_attention=True`.
            if not hasattr(self, "crossattention"):
                raise Exception(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
                )
            # BrosAttention returns a single tensor, requires `bbox_pos_emb`, and takes no extra kwargs.
            # NOTE(review): BrosSelfAttention adds a (query_len x query_len) bbox term to (query_len x key_len)
            # scores, so this path appears to work only when encoder and decoder lengths match.
            hidden_states = self.crossattention(
                hidden_states,
                bbox_pos_emb=bbox_pos_emb,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
            )

        hidden_states = apply_chunking_to_forward(
            self.feed_forward_chunk,
            self.chunk_size_feed_forward,
            self.seq_len_dim,
            hidden_states,
        )
        return hidden_states

    def feed_forward_chunk(self, attention_output):
        """Feed-forward block on one chunk: expand + activate, then project back with residual + LayerNorm."""
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
# Adapted from transformers.models.bert.modeling_bert.BertPooler with Bert->Bros
class BrosPooler(nn.Module):
    """Pool a sequence into one vector: ``tanh(dense(hidden_states[:, 0]))``, using the first token's state."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.activation(self.dense(hidden_states[:, 0]))
class BrosRelationExtractor(nn.Module):
    """Score every (query token, key token) pair for each of ``config.n_relations`` relation types.

    A learned ``dummy_node`` is appended after the last key, so the output has one extra key position.
    Inputs are laid out as ``(seq, batch, hidden)``: the dummy node is concatenated along dim 0 and
    broadcast over dim 1. The output has shape ``(n_relations, batch, seq_q, seq_k + 1)``.
    """

    def __init__(self, config):
        super().__init__()
        self.n_relations = config.n_relations
        self.backbone_hidden_size = config.hidden_size
        self.head_hidden_size = config.hidden_size
        self.classifier_dropout_prob = config.classifier_dropout_prob

        self.drop = nn.Dropout(self.classifier_dropout_prob)
        projected_size = self.n_relations * self.head_hidden_size
        self.query = nn.Linear(self.backbone_hidden_size, projected_size)
        self.key = nn.Linear(self.backbone_hidden_size, projected_size)

        self.dummy_node = nn.Parameter(torch.zeros(1, self.backbone_hidden_size))

    def forward(self, query_layer: torch.Tensor, key_layer: torch.Tensor):
        projected_q = self.query(self.drop(query_layer))

        dummy_keys = self.dummy_node.unsqueeze(0).repeat(1, key_layer.size(1), 1)
        projected_k = self.key(self.drop(torch.cat([key_layer, dummy_keys], dim=0)))

        q_len, q_batch = projected_q.size(0), projected_q.size(1)
        k_len, k_batch = projected_k.size(0), projected_k.size(1)
        q_heads = projected_q.view(q_len, q_batch, self.n_relations, self.head_hidden_size).permute(2, 1, 0, 3)
        k_heads = projected_k.view(k_len, k_batch, self.n_relations, self.head_hidden_size).permute(2, 1, 3, 0)

        # (n_rel, batch, seq_q, d) @ (n_rel, batch, d, seq_k + 1)
        # equivalent to torch.einsum("ibnd,jbnd->nbij", (query_layer, key_layer))
        return torch.matmul(q_heads, k_heads)
@auto_docstring
class BrosPreTrainedModel(PreTrainedModel):
    config: BrosConfig
    base_model_prefix = "bros"
    # Sub-module outputs collected when hidden states / attentions are requested.
    # BrosSelfAttention returns (context, attention_probs), so index 1 is the attention-probability tensor.
    _can_record_outputs = {
        "hidden_states": BrosLayer,
        "attentions": OutputRecorder(BrosSelfAttention, index=1, layer_name="attention"),
        "cross_attentions": OutputRecorder(BrosSelfAttention, index=1, layer_name="crossattention"),
    }

    @torch.no_grad()
    def _init_weights(self, module: nn.Module):
        """Run the base initialization, then (re)initialize the BROS-specific parameters and buffers."""
        super()._init_weights(module)
        std = self.config.initializer_range

        if isinstance(module, BrosRelationExtractor):
            init.normal_(module.dummy_node, std=std)
            return

        if isinstance(module, BrosTextEmbeddings):
            # Deterministic buffers: positions 0..max_position-1 and all-zero token types.
            positions = torch.arange(module.position_ids.shape[-1]).expand((1, -1))
            init.copy_(module.position_ids, positions)
            init.zeros_(module.token_type_ids)
            return

        if isinstance(module, BrosPositionalEmbedding1D):
            # Same inverse frequencies as computed in BrosPositionalEmbedding1D.__init__.
            dim = module.dim_bbox_sinusoid_emb_1d
            init.copy_(module.inv_freq, 1 / (10000 ** (torch.arange(0.0, dim, 2.0) / dim)))
class BrosEncoder(BrosPreTrainedModel):
    """A stack of ``config.num_hidden_layers`` BrosLayer modules applied in sequence."""

    def __init__(self, config):
        super().__init__(config)
        self.layer = nn.ModuleList(BrosLayer(config) for _ in range(config.num_hidden_layers))
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        hidden_states: torch.Tensor,
        bbox_pos_emb: torch.Tensor,
        attention_mask: torch.FloatTensor | None = None,
        encoder_hidden_states: torch.FloatTensor | None = None,
        encoder_attention_mask: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | BaseModelOutputWithCrossAttentions:
        """Feed ``hidden_states`` through every layer and wrap the final state in the output dataclass.

        Per-layer hidden states and attentions are not collected here; presumably `@capture_outputs`
        gathers them according to `_can_record_outputs` (TODO confirm).
        """
        shared_layer_args = {
            "bbox_pos_emb": bbox_pos_emb,
            "attention_mask": attention_mask,
            "encoder_hidden_states": encoder_hidden_states,
            "encoder_attention_mask": encoder_attention_mask,
        }
        for bros_layer in self.layer:
            hidden_states = bros_layer(hidden_states, **shared_layer_args, **kwargs)
        return BaseModelOutputWithCrossAttentions(last_hidden_state=hidden_states)
@auto_docstring
class BrosModel(BrosPreTrainedModel):
    def __init__(self, config, add_pooling_layer=True):
        r"""
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        """
        super().__init__(config)
        self.config = config

        self.embeddings = BrosTextEmbeddings(config)
        self.bbox_embeddings = BrosBboxEmbeddings(config)
        self.encoder = BrosEncoder(config)
        self.pooler = None if not add_pooling_layer else BrosPooler(config)

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        bbox: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | BaseModelOutputWithPoolingAndCrossAttentions:
        r"""
        bbox ('torch.FloatTensor' of shape '(batch_size, num_boxes, 4)'):
            Bounding box coordinates for each token in the input sequence. Each bounding box is a list of four values
            (x1, y1, x2, y2), where (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner of the
            bounding box. Boxes given as 4 values are expanded to their four corner points (8 values); boxes that
            already have 8 values are used as-is.

        Examples:

        ```python
        >>> import torch
        >>> from transformers import BrosProcessor, BrosModel

        >>> processor = BrosProcessor.from_pretrained("jinho8345/bros-base-uncased")

        >>> model = BrosModel.from_pretrained("jinho8345/bros-base-uncased")

        >>> encoding = processor("Hello, my dog is cute", add_special_tokens=False, return_tensors="pt")
        >>> bbox = torch.tensor([[[0, 0, 1, 1]]]).repeat(1, encoding["input_ids"].shape[-1], 1)
        >>> encoding["bbox"] = bbox

        >>> outputs = model(**encoding)
        >>> last_hidden_states = outputs.last_hidden_state
        ```"""
        # Exactly one of `input_ids` / `inputs_embeds` must be supplied.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
        if bbox is None:
            raise ValueError("You have to specify bbox")

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )
        input_shape = embedding_output.shape[:-1]
        device = embedding_output.device

        # With no mask, every token is attended to. A [batch_size, from_seq_length, to_seq_length] mask may
        # also be given; either way it is made broadcastable to all heads.
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

        # The cross-attention mask is only needed for decoders that actually receive encoder states.
        encoder_extended_attention_mask = None
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones((encoder_batch_size, encoder_sequence_length), device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)

        # (x1, y1, x2, y2) -> corners (x1, y1), (x2, y1), (x2, y2), (x1, y2)
        if bbox.shape[-1] == 4:
            bbox = bbox[:, :, [0, 1, 2, 1, 2, 3, 0, 3]]
        bbox_position_embeddings = self.bbox_embeddings(bbox * self.config.bbox_scale)

        encoder_outputs: BaseModelOutputWithCrossAttentions = self.encoder(
            embedding_output,
            bbox_pos_emb=bbox_position_embeddings,
            attention_mask=extended_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            **kwargs,
        )

        sequence_output = encoder_outputs[0]
        pooled_output = None if self.pooler is None else self.pooler(sequence_output)

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
@auto_docstring
class BrosForTokenClassification(BrosPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bros = BrosModel(config)
        # `classifier_dropout` may be missing or explicitly None; fall back to the encoder dropout in both
        # cases, since nn.Dropout cannot take None.
        classifier_dropout = getattr(config, "classifier_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        bbox: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        bbox_first_token_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | TokenClassifierOutput:
        r"""
        bbox ('torch.FloatTensor' of shape '(batch_size, num_boxes, 4)'):
            Bounding box coordinates for each token in the input sequence. Each bounding box is a list of four values
            (x1, y1, x2, y2), where (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner of the
            bounding box.
        bbox_first_token_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to indicate the first token of each bounding box. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            Only unmasked tokens contribute to the loss. The mask is converted to boolean, so float, integer and
            boolean masks are all accepted.

        Examples:

        ```python
        >>> import torch
        >>> from transformers import BrosProcessor, BrosForTokenClassification

        >>> processor = BrosProcessor.from_pretrained("jinho8345/bros-base-uncased")

        >>> model = BrosForTokenClassification.from_pretrained("jinho8345/bros-base-uncased")

        >>> encoding = processor("Hello, my dog is cute", add_special_tokens=False, return_tensors="pt")
        >>> bbox = torch.tensor([[[0, 0, 1, 1]]]).repeat(1, encoding["input_ids"].shape[-1], 1)
        >>> encoding["bbox"] = bbox

        >>> outputs = model(**encoding)
        ```"""
        outputs: BaseModelOutputWithPoolingAndCrossAttentions = self.bros(
            input_ids,
            bbox=bbox,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

        sequence_output = self.dropout(outputs[0])
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            flat_logits = logits.view(-1, self.num_labels)
            # reshape (not view) so non-contiguous label tensors are accepted as well
            flat_labels = labels.reshape(-1)
            if bbox_first_token_mask is not None:
                # Boolean masking selects the unmasked tokens. A float mask would fail as an index, and an
                # integer 0/1 mask would be treated as row indices 0 and 1.
                keep = bbox_first_token_mask.reshape(-1).bool()
                flat_logits = flat_logits[keep]
                flat_labels = flat_labels[keep]
            loss = loss_fct(flat_logits, flat_labels)

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
- @auto_docstring(
- custom_intro="""
- Bros Model with a token classification head on top (initial_token_layers and subsequent_token_layer on top of the
- hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. The initial_token_classifier is used to
- predict the first token of each entity, and the subsequent_token_classifier is used to predict the subsequent
- tokens within an entity. Compared to BrosForTokenClassification, this model is more robust to serialization errors
- since it predicts next token from one token.
- """
- )
- class BrosSpadeEEForTokenClassification(BrosPreTrainedModel):
- _keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config):
    """Build the BROS backbone plus the two entity-extraction heads.

    `initial_token_classifier` is Dropout -> Linear -> Dropout -> Linear, ending in `config.num_labels` outputs.
    `subsequent_token_classifier` is a BrosRelationExtractor.
    """
    super().__init__(config)
    self.config = config
    self.num_labels = config.num_labels
    self.n_relations = config.n_relations
    self.backbone_hidden_size = config.hidden_size

    self.bros = BrosModel(config)
    # `classifier_dropout` may be missing or explicitly None; fall back to the encoder dropout in both
    # cases, since nn.Dropout cannot take None.
    classifier_dropout = getattr(config, "classifier_dropout", None)
    if classifier_dropout is None:
        classifier_dropout = config.hidden_dropout_prob

    # Initial token classification for Entity Extraction (NER)
    self.initial_token_classifier = nn.Sequential(
        nn.Dropout(classifier_dropout),
        nn.Linear(config.hidden_size, config.hidden_size),
        nn.Dropout(classifier_dropout),
        nn.Linear(config.hidden_size, config.num_labels),
    )

    # Subsequent token classification for Entity Extraction (NER)
    self.subsequent_token_classifier = BrosRelationExtractor(config)

    self.post_init()
- @can_return_tuple
- @auto_docstring
- def forward(
- self,
- input_ids: torch.Tensor | None = None,
- bbox: torch.Tensor | None = None,
- attention_mask: torch.Tensor | None = None,
- bbox_first_token_mask: torch.Tensor | None = None,
- token_type_ids: torch.Tensor | None = None,
- position_ids: torch.Tensor | None = None,
- inputs_embeds: torch.Tensor | None = None,
- initial_token_labels: torch.Tensor | None = None,
- subsequent_token_labels: torch.Tensor | None = None,
- **kwargs: Unpack[TransformersKwargs],
- ) -> tuple[torch.Tensor] | BrosSpadeOutput:
- r"""
- bbox ('torch.FloatTensor' of shape '(batch_size, num_boxes, 4)'):
- Bounding box coordinates for each token in the input sequence. Each bounding box is a list of four values
- (x1, y1, x2, y2), where (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner of the
- bounding box.
- bbox_first_token_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Mask to indicate the first token of each bounding box. Mask values selected in `[0, 1]`:
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
- initial_token_labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for the initial token classification.
- subsequent_token_labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for the subsequent token classification.
- Examples:
- ```python
- >>> import torch
- >>> from transformers import BrosProcessor, BrosSpadeEEForTokenClassification
- >>> processor = BrosProcessor.from_pretrained("jinho8345/bros-base-uncased")
- >>> model = BrosSpadeEEForTokenClassification.from_pretrained("jinho8345/bros-base-uncased")
- >>> encoding = processor("Hello, my dog is cute", add_special_tokens=False, return_tensors="pt")
- >>> bbox = torch.tensor([[[0, 0, 1, 1]]]).repeat(1, encoding["input_ids"].shape[-1], 1)
- >>> encoding["bbox"] = bbox
- >>> outputs = model(**encoding)
- ```"""
- outputs: BaseModelOutputWithPoolingAndCrossAttentions = self.bros(
- input_ids=input_ids,
- bbox=bbox,
- attention_mask=attention_mask,
- token_type_ids=token_type_ids,
- position_ids=position_ids,
- inputs_embeds=inputs_embeds,
- **kwargs,
- )
- last_hidden_states = outputs[0]
- last_hidden_states = last_hidden_states.transpose(0, 1).contiguous()
- initial_token_logits = self.initial_token_classifier(last_hidden_states).transpose(0, 1).contiguous()
- subsequent_token_logits = self.subsequent_token_classifier(last_hidden_states, last_hidden_states).squeeze(0)
- # make subsequent token (sequence token classification) mask
- inv_attention_mask = 1 - attention_mask
- batch_size, max_seq_length = inv_attention_mask.shape
- device = inv_attention_mask.device
- invalid_token_mask = torch.cat([inv_attention_mask, torch.zeros([batch_size, 1]).to(device)], axis=1).bool()
- subsequent_token_logits = subsequent_token_logits.masked_fill(
- invalid_token_mask[:, None, :], torch.finfo(subsequent_token_logits.dtype).min
- )
- self_token_mask = torch.eye(max_seq_length, max_seq_length + 1).to(device=device, dtype=torch.bool)
- subsequent_token_logits = subsequent_token_logits.masked_fill(
- self_token_mask[None, :, :], torch.finfo(subsequent_token_logits.dtype).min
- )
- subsequent_token_mask = attention_mask.view(-1).bool()
- loss = None
- if initial_token_labels is not None and subsequent_token_labels is not None:
- loss_fct = CrossEntropyLoss()
- # get initial token loss
- initial_token_labels = initial_token_labels.view(-1)
- if bbox_first_token_mask is not None:
- bbox_first_token_mask = bbox_first_token_mask.view(-1)
- initial_token_loss = loss_fct(
- initial_token_logits.view(-1, self.num_labels)[bbox_first_token_mask],
- initial_token_labels[bbox_first_token_mask],
- )
- else:
- initial_token_loss = loss_fct(initial_token_logits.view(-1, self.num_labels), initial_token_labels)
- subsequent_token_labels = subsequent_token_labels.view(-1)
- subsequent_token_loss = loss_fct(
- subsequent_token_logits.view(-1, max_seq_length + 1)[subsequent_token_mask],
- subsequent_token_labels[subsequent_token_mask],
- )
- loss = initial_token_loss + subsequent_token_loss
- return BrosSpadeOutput(
- loss=loss,
- initial_token_logits=initial_token_logits,
- subsequent_token_logits=subsequent_token_logits,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
@auto_docstring(
    custom_intro="""
    Bros Model with a token classification head on top (a entity_linker layer on top of the hidden-states output) e.g.
    for Entity-Linking. The entity_linker is used to predict intra-entity links (one entity to another entity).
    """
)
class BrosSpadeELForTokenClassification(BrosPreTrainedModel):
    # Unexpected checkpoint keys matching "pooler" are ignored when loading weights.
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.num_labels = config.num_labels
        self.n_relations = config.n_relations
        self.backbone_hidden_size = config.hidden_size

        self.bros = BrosModel(config)
        self.entity_linker = BrosRelationExtractor(config)

        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        bbox: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        bbox_first_token_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | TokenClassifierOutput:
        r"""
        bbox (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
            Bounding box coordinates for each token in the input sequence. Each bounding box is a list of four values
            (x1, y1, x2, y2), where (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner of the
            bounding box.
        bbox_first_token_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to indicate the first token of each bounding box. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            Non-boolean masks are cast with `.bool()`. Required when `labels` is provided: only positions marked 1
            contribute to the loss, and positions marked 0 are excluded as link targets.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            For each token, the index of the position it links to, in `[0, sequence_length]`. The value
            `sequence_length` selects the extra (last) column of the logits. Only positions where
            `bbox_first_token_mask` is 1 are used to compute the loss.

        Examples:

        ```python
        >>> import torch
        >>> from transformers import BrosProcessor, BrosSpadeELForTokenClassification

        >>> processor = BrosProcessor.from_pretrained("jinho8345/bros-base-uncased")

        >>> model = BrosSpadeELForTokenClassification.from_pretrained("jinho8345/bros-base-uncased")

        >>> encoding = processor("Hello, my dog is cute", add_special_tokens=False, return_tensors="pt")
        >>> bbox = torch.tensor([[[0, 0, 1, 1]]]).repeat(1, encoding["input_ids"].shape[-1], 1)
        >>> encoding["bbox"] = bbox

        >>> outputs = model(**encoding)
        ```"""
        outputs: BaseModelOutputWithPoolingAndCrossAttentions = self.bros(
            input_ids=input_ids,
            bbox=bbox,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

        # Swap the first two axes before the relation head.
        # NOTE(review): presumably (batch, seq, hidden) -> (seq, batch, hidden); confirm against BrosRelationExtractor.
        last_hidden_states = outputs[0].transpose(0, 1).contiguous()
        # NOTE(review): squeeze(0) assumes a leading relation axis of size 1; the loss below reshapes the result to
        # rows of `sequence_length + 1` scores.
        logits = self.entity_linker(last_hidden_states, last_hidden_states).squeeze(0)

        # The returned logits are only masked (invalid targets / self links set to finfo.min) when `labels` is given.
        loss = None
        if labels is not None:
            if bbox_first_token_mask is None:
                raise ValueError(
                    "`bbox_first_token_mask` must be provided together with `labels` to compute the entity linking loss."
                )
            device = logits.device
            # `~` and boolean indexing below require a bool tensor; the argument is documented as a 0/1 mask.
            bbox_first_token_mask = bbox_first_token_mask.to(device=device).bool()
            # Sizes come from the mask itself, so the loss does not depend on `attention_mask` being passed.
            batch_size, max_seq_length = bbox_first_token_mask.shape

            # Target columns that are not a box's first token are invalid; the appended extra last column is
            # always kept as a candidate.
            invalid_target_mask = torch.cat(
                [
                    ~bbox_first_token_mask,
                    torch.zeros([batch_size, 1], dtype=torch.bool, device=device),
                ],
                dim=1,
            )
            # Forbid row i from selecting column i (a token linking to itself).
            self_token_mask = torch.eye(max_seq_length, max_seq_length + 1).to(device=device, dtype=torch.bool)

            min_value = torch.finfo(logits.dtype).min
            logits = logits.masked_fill(invalid_target_mask[:, None, :], min_value)
            logits = logits.masked_fill(self_token_mask[None, :, :], min_value)

            # Only rows belonging to the first token of a box contribute to the loss.
            row_mask = bbox_first_token_mask.reshape(-1)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                logits.reshape(-1, max_seq_length + 1)[row_mask],
                labels.to(device).reshape(-1)[row_mask],
            )

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Names exported by `from <this module> import *`; keep this a plain literal list so static tooling can read it.
__all__ = [
    "BrosPreTrainedModel",
    "BrosModel",
    "BrosForTokenClassification",
    "BrosSpadeEEForTokenClassification",
    "BrosSpadeELForTokenClassification",
]
|