| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358 |
- # Copyright 2021 Google AI The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch CANINE model."""
- import copy
- import math
- from dataclasses import dataclass
- import torch
- from torch import nn
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import (
- BaseModelOutput,
- ModelOutput,
- MultipleChoiceModelOutput,
- QuestionAnsweringModelOutput,
- SequenceClassifierOutput,
- TokenClassifierOutput,
- )
- from ...modeling_utils import PreTrainedModel
- from ...pytorch_utils import apply_chunking_to_forward
- from ...utils import auto_docstring, logging
- from .configuration_canine import CanineConfig
logger = logging.get_logger(__name__)

# Multipliers for the multiplicative hashes computed in `CanineEmbeddings._hash_bucket_tensors`: one prime per
# hash function, so at most 16 hash functions are supported.
_PRIMES = [31, 43, 59, 61, 73, 97, 103, 113, 137, 149, 157, 173, 181, 193, 211, 223]
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`CanineModel`]. Based on [`~modeling_outputs.BaseModelOutputWithPooling`], but with slightly
    different `hidden_states` and `attentions`, as these also include the hidden states and attentions of the shallow
    Transformer encoders.
    """
)
class CanineModelOutputWithPooling(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the model (i.e. the output of the final
        shallow Transformer encoder).
    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
        Hidden-state of the first token of the sequence (classification token) at the last layer of the deep
        Transformer encoder, further processed by a Linear layer and a Tanh activation function. The Linear layer
        weights are trained from the next sentence prediction (classification) objective during pretraining.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the input to each encoder + one for the output of each layer of each
        encoder) of shape `(batch_size, sequence_length, hidden_size)` and `(batch_size, sequence_length //
        config.downsampling_rate, hidden_size)`. Hidden-states of the model at the output of each layer plus the
        initial input to each Transformer encoder. The hidden states of the shallow encoders have length
        `sequence_length`, but the hidden states of the deep encoder have length `sequence_length` //
        `config.downsampling_rate`.
    attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of the 3 Transformer encoders of shape `(batch_size,
        num_heads, sequence_length, sequence_length)` and `(batch_size, num_heads, sequence_length //
        config.downsampling_rate, sequence_length // config.downsampling_rate)`. Attention weights after the
        attention softmax, used to compute the weighted average in the self-attention heads.
    """

    # Character-level output of the final shallow encoder.
    last_hidden_state: torch.FloatTensor | None = None
    # Tanh-pooled first position of the deep encoder's output.
    pooler_output: torch.FloatTensor | None = None
    # Per-layer hidden states of all three encoders (shallow ones at character length, deep one downsampled).
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    # Per-layer attention probabilities of all three encoders.
    attentions: tuple[torch.FloatTensor, ...] | None = None
- class CanineEmbeddings(nn.Module):
- """Construct the character, position and token_type embeddings."""
- def __init__(self, config):
- super().__init__()
- self.config = config
- # character embeddings
- shard_embedding_size = config.hidden_size // config.num_hash_functions
- for i in range(config.num_hash_functions):
- name = f"HashBucketCodepointEmbedder_{i}"
- setattr(self, name, nn.Embedding(config.num_hash_buckets, shard_embedding_size))
- self.char_position_embeddings = nn.Embedding(config.num_hash_buckets, config.hidden_size)
- self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
- self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
- # position_ids (1, len position emb) is contiguous in memory and exported when serialized
- self.register_buffer(
- "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
- )
- def _hash_bucket_tensors(self, input_ids, num_hashes: int, num_buckets: int):
- """
- Converts ids to hash bucket ids via multiple hashing.
- Args:
- input_ids: The codepoints or other IDs to be hashed.
- num_hashes: The number of hash functions to use.
- num_buckets: The number of hash buckets (i.e. embeddings in each table).
- Returns:
- A list of tensors, each of which is the hash bucket IDs from one hash function.
- """
- if num_hashes > len(_PRIMES):
- raise ValueError(f"`num_hashes` must be <= {len(_PRIMES)}")
- primes = _PRIMES[:num_hashes]
- result_tensors = []
- for prime in primes:
- hashed = ((input_ids + 1) * prime) % num_buckets
- result_tensors.append(hashed)
- return result_tensors
- def _embed_hash_buckets(self, input_ids, embedding_size: int, num_hashes: int, num_buckets: int):
- """Converts IDs (e.g. codepoints) into embeddings via multiple hashing."""
- if embedding_size % num_hashes != 0:
- raise ValueError(f"Expected `embedding_size` ({embedding_size}) % `num_hashes` ({num_hashes}) == 0")
- hash_bucket_tensors = self._hash_bucket_tensors(input_ids, num_hashes=num_hashes, num_buckets=num_buckets)
- embedding_shards = []
- for i, hash_bucket_ids in enumerate(hash_bucket_tensors):
- name = f"HashBucketCodepointEmbedder_{i}"
- shard_embeddings = getattr(self, name)(hash_bucket_ids)
- embedding_shards.append(shard_embeddings)
- return torch.cat(embedding_shards, dim=-1)
- def forward(
- self,
- input_ids: torch.LongTensor | None = None,
- token_type_ids: torch.LongTensor | None = None,
- position_ids: torch.LongTensor | None = None,
- inputs_embeds: torch.FloatTensor | None = None,
- ) -> torch.FloatTensor:
- if input_ids is not None:
- input_shape = input_ids.size()
- else:
- input_shape = inputs_embeds.size()[:-1]
- seq_length = input_shape[1]
- if position_ids is None:
- position_ids = self.position_ids[:, :seq_length]
- if token_type_ids is None:
- token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
- if inputs_embeds is None:
- inputs_embeds = self._embed_hash_buckets(
- input_ids, self.config.hidden_size, self.config.num_hash_functions, self.config.num_hash_buckets
- )
- token_type_embeddings = self.token_type_embeddings(token_type_ids)
- embeddings = inputs_embeds + token_type_embeddings
- position_embeddings = self.char_position_embeddings(position_ids)
- embeddings += position_embeddings
- embeddings = self.LayerNorm(embeddings)
- embeddings = self.dropout(embeddings)
- return embeddings
class CharactersToMolecules(nn.Module):
    """Downsample a character sequence into an initial molecule sequence using a strided convolution."""

    def __init__(self, config):
        super().__init__()
        rate = config.downsampling_rate
        # kernel_size == stride, so each output molecule summarizes a disjoint window of `rate` characters.
        self.conv = nn.Conv1d(
            in_channels=config.hidden_size,
            out_channels=config.hidden_size,
            kernel_size=rate,
            stride=rate,
        )
        self.activation = ACT2FN[config.hidden_act]
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, char_encoding: torch.Tensor) -> torch.Tensor:
        """
        Downsample `char_encoding` ([batch, char_seq, hidden_size]) into molecules.

        The first character position ([CLS]) is kept separately and prepended to the downsampled sequence.
        The last downsampled molecule is dropped to make room for it.
        """
        # [batch, 1, hidden_size]
        cls_encoding = char_encoding[:, 0:1, :]

        # Conv1d is channels-first: [batch, char_seq, hidden] -> [batch, hidden, char_seq], then back again.
        molecules = self.conv(char_encoding.transpose(1, 2)).transpose(1, 2)
        molecules = self.activation(molecules)

        # Dropping the last molecule reserves a slot for [CLS], so the output keeps the conv's sequence length.
        # [CLS] stays its own position so the deep stack always has a dedicated slot for it.
        result = torch.cat([cls_encoding, molecules[:, :-1, :]], dim=1)
        return self.LayerNorm(result)
class ConvProjection(nn.Module):
    """
    Project representations from hidden_size * 2 back to hidden_size with a stride-1 convolution over a window of
    `config.upsampling_kernel_size` characters.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.conv = nn.Conv1d(
            in_channels=config.hidden_size * 2,
            out_channels=config.hidden_size,
            kernel_size=config.upsampling_kernel_size,
            stride=1,
        )
        self.activation = ACT2FN[config.hidden_act]
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        inputs: torch.Tensor,
        final_seq_char_positions: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Project `inputs` ([batch, seq, 2 * hidden_size]) to [batch, seq, hidden_size].

        Raises:
            NotImplementedError: if `final_seq_char_positions` is given (restricting the output to selected
                character positions, needed for masked-LM training, is not implemented).
        """
        # Emulate padding="same" for a stride-1 convolution: split kernel_size - 1 zeros across both ends,
        # putting the extra zero (for even kernels) at the end.
        # Based on https://github.com/google-research/big_transfer/blob/49afe42338b62af9fbe18f0258197a33ee578a6b/bit_tf2/models.py#L36-L38
        total_pad = self.config.upsampling_kernel_size - 1
        left_pad = total_pad // 2
        right_pad = total_pad - left_pad

        # Conv1d is channels-first: [batch, seq, channels] -> [batch, channels, seq].
        channels_first = inputs.transpose(1, 2)
        padded = nn.functional.pad(channels_first, (left_pad, right_pad), mode="constant", value=0)
        projected = self.conv(padded).transpose(1, 2)
        projected = self.dropout(self.LayerNorm(self.activation(projected)))

        if final_seq_char_positions is not None:
            # Limiting the query sequence to specific character positions (the MLM task) is not implemented.
            # TODO add support for MLM
            raise NotImplementedError("CanineForMaskedLM is currently not supported")
        return projected
- class CanineSelfAttention(nn.Module):
- def __init__(self, config):
- super().__init__()
- if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
- raise ValueError(
- f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
- f"heads ({config.num_attention_heads})"
- )
- self.num_attention_heads = config.num_attention_heads
- self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
- self.all_head_size = self.num_attention_heads * self.attention_head_size
- self.query = nn.Linear(config.hidden_size, self.all_head_size)
- self.key = nn.Linear(config.hidden_size, self.all_head_size)
- self.value = nn.Linear(config.hidden_size, self.all_head_size)
- self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
- def forward(
- self,
- from_tensor: torch.Tensor,
- to_tensor: torch.Tensor,
- attention_mask: torch.FloatTensor | None = None,
- output_attentions: bool | None = False,
- ) -> tuple[torch.Tensor, torch.Tensor | None]:
- batch_size, seq_length, _ = from_tensor.shape
- # If this is instantiated as a cross-attention module, the keys
- # and values come from an encoder; the attention mask needs to be
- # such that the encoder's padding tokens are not attended to.
- key_layer = (
- self.key(to_tensor)
- .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
- .transpose(1, 2)
- )
- value_layer = (
- self.value(to_tensor)
- .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
- .transpose(1, 2)
- )
- query_layer = (
- self.query(from_tensor)
- .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
- .transpose(1, 2)
- )
- # Take the dot product between "query" and "key" to get the raw attention scores.
- attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
- attention_scores = attention_scores / math.sqrt(self.attention_head_size)
- if attention_mask is not None:
- if attention_mask.ndim == 3:
- # if attention_mask is 3D, do the following:
- attention_mask = torch.unsqueeze(attention_mask, dim=1)
- # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
- # masked positions, this operation will create a tensor which is 0.0 for
- # positions we want to attend and the dtype's smallest value for masked positions.
- attention_mask = (1.0 - attention_mask.float()) * torch.finfo(attention_scores.dtype).min
- # Apply the attention mask (precomputed for all layers in CanineModel forward() function)
- attention_scores = attention_scores + attention_mask
- # Normalize the attention scores to probabilities.
- attention_probs = nn.functional.softmax(attention_scores, dim=-1)
- # This is actually dropping out entire tokens to attend to, which might
- # seem a bit unusual, but is taken from the original Transformer paper.
- attention_probs = self.dropout(attention_probs)
- context_layer = torch.matmul(attention_probs, value_layer)
- context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
- new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
- context_layer = context_layer.view(*new_context_layer_shape)
- outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
- return outputs
class CanineSelfOutput(nn.Module):
    """Project the attention context, apply dropout, then add the residual input and apply LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.FloatTensor, input_tensor: torch.FloatTensor) -> torch.FloatTensor:
        """
        Args:
            hidden_states: Attention context to project (last dimension hidden_size).
            input_tensor: Residual input; must match the projected shape.

        Returns:
            A single tensor, `LayerNorm(dropout(dense(hidden_states)) + input_tensor)`.
        """
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return self.LayerNorm(hidden_states + input_tensor)
class CanineAttention(nn.Module):
    """
    Attention block: multi-head attention (`self.self`) followed by an output projection with a residual
    connection (`self.output`). Optionally the attention is computed locally, over aligned chunks of the sequence.

    Additional arguments related to local attention:

    - **local** (`bool`, *optional*, defaults to `False`) -- Whether to apply local attention.
    - **always_attend_to_first_position** (`bool`, *optional*, defaults to `False`) -- Should all blocks be able to
      attend to the `to_tensor`'s first position (e.g. a [CLS] position)?
    - **first_position_attends_to_all** (`bool`, *optional*, defaults to `False`) -- Should the `from_tensor`'s
      first position be able to attend to all positions of the sequence?
    - **attend_from_chunk_width** (`int`, *optional*, defaults to 128) -- The width of each block-wise chunk in
      `from_tensor`.
    - **attend_from_chunk_stride** (`int`, *optional*, defaults to 128) -- The number of elements to skip when
      moving to the next block in `from_tensor`.
    - **attend_to_chunk_width** (`int`, *optional*, defaults to 128) -- The width of each block-wise chunk in
      `to_tensor`.
    - **attend_to_chunk_stride** (`int`, *optional*, defaults to 128) -- The number of elements to skip when
      moving to the next block in `to_tensor`.
    """

    def __init__(
        self,
        config,
        local=False,
        always_attend_to_first_position: bool = False,
        first_position_attends_to_all: bool = False,
        attend_from_chunk_width: int = 128,
        attend_from_chunk_stride: int = 128,
        attend_to_chunk_width: int = 128,
        attend_to_chunk_stride: int = 128,
    ):
        super().__init__()
        self.self = CanineSelfAttention(config)
        self.output = CanineSelfOutput(config)

        # Additional arguments related to local attention.
        self.local = local
        if attend_from_chunk_width < attend_from_chunk_stride:
            raise ValueError(
                "`attend_from_chunk_width` < `attend_from_chunk_stride` would cause sequence positions to get skipped."
            )
        if attend_to_chunk_width < attend_to_chunk_stride:
            raise ValueError(
                "`attend_to_chunk_width` < `attend_to_chunk_stride` would cause sequence positions to get skipped."
            )
        self.always_attend_to_first_position = always_attend_to_first_position
        self.first_position_attends_to_all = first_position_attends_to_all
        self.attend_from_chunk_width = attend_from_chunk_width
        self.attend_from_chunk_stride = attend_from_chunk_stride
        self.attend_to_chunk_width = attend_to_chunk_width
        self.attend_to_chunk_stride = attend_to_chunk_stride

    def _local_attention(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None,
        output_attentions: bool | None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]:
        """
        Attend within aligned (from-chunk, to-chunk) windows of `hidden_states` and concatenate the results.

        `attention_mask`, when given, is sliced as a 3D (batch, from_seq, to_seq) mask.

        Returns:
            - the concatenated attention context, before `self.output`;
            - a tuple with every window's attention probabilities (empty unless `output_attentions`).

        Raises:
            ValueError: if the from/to chunkings yield different numbers of windows.
        """
        seq_length = hidden_states.shape[1]

        # Windows we attend *from*; their outputs are concatenated along the sequence axis.
        from_chunks = []
        if self.first_position_attends_to_all:
            from_chunks.append((0, 1))
            # Skip the first position below so the output keeps the input's sequence length
            # (this matters in the *from* sequence only).
            from_start = 1
        else:
            from_start = 0
        for chunk_start in range(from_start, seq_length, self.attend_from_chunk_stride):
            chunk_end = min(seq_length, chunk_start + self.attend_from_chunk_width)
            from_chunks.append((chunk_start, chunk_end))

        # Windows we attend *to*, paired one-to-one with `from_chunks`.
        to_chunks = []
        if self.first_position_attends_to_all:
            to_chunks.append((0, seq_length))
        for chunk_start in range(0, seq_length, self.attend_to_chunk_stride):
            chunk_end = min(seq_length, chunk_start + self.attend_to_chunk_width)
            to_chunks.append((chunk_start, chunk_end))

        if len(from_chunks) != len(to_chunks):
            raise ValueError(
                f"Expected to have same number of `from_chunks` ({from_chunks}) and "
                f"`to_chunks` ({to_chunks}). Check strides."
            )

        context_chunks = []
        probs_chunks = []
        for (f_start, f_end), (t_start, t_end) in zip(from_chunks, to_chunks):
            from_tensor_chunk = hidden_states[:, f_start:f_end, :]
            to_tensor_chunk = hidden_states[:, t_start:t_end, :]
            # [batch, from_seq, to_seq] -> [batch, from_chunk, to_chunk]; no mask means attend everywhere.
            if attention_mask is None:
                attention_mask_chunk = None
            else:
                attention_mask_chunk = attention_mask[:, f_start:f_end, t_start:t_end]
            if self.always_attend_to_first_position:
                # Prepend the first position (e.g. [CLS]) to this window's keys/values and mask columns.
                if attention_mask_chunk is not None:
                    cls_attention_mask = attention_mask[:, f_start:f_end, 0:1]
                    attention_mask_chunk = torch.cat([cls_attention_mask, attention_mask_chunk], dim=2)
                to_tensor_chunk = torch.cat([hidden_states[:, 0:1, :], to_tensor_chunk], dim=1)

            chunk_outputs = self.self(from_tensor_chunk, to_tensor_chunk, attention_mask_chunk, output_attentions)
            context_chunks.append(chunk_outputs[0])
            if output_attentions:
                probs_chunks.append(chunk_outputs[1])

        return torch.cat(context_chunks, dim=1), tuple(probs_chunks)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: torch.FloatTensor | None = None,
        output_attentions: bool | None = False,
    ) -> tuple[torch.FloatTensor, torch.FloatTensor | None]:
        """
        Returns:
            `(attention_output, *attention_probs)`.
            - Global mode adds at most one probability tensor.
            - Local mode adds one tensor per window.
            - Nothing is added unless `output_attentions`.
        """
        if self.local:
            attention_output, attention_probs = self._local_attention(
                hidden_states, attention_mask, output_attentions
            )
        else:
            self_outputs = self.self(hidden_states, hidden_states, attention_mask, output_attentions)
            attention_output, attention_probs = self_outputs[0], self_outputs[1:]
        attention_output = self.output(attention_output, hidden_states)
        return (attention_output,) + attention_probs
class CanineIntermediate(nn.Module):
    """Feed-forward expansion: a hidden_size -> intermediate_size projection followed by the activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        act = config.hidden_act
        # `hidden_act` is either an activation name (looked up in ACT2FN) or a callable used as-is.
        self.intermediate_act_fn = ACT2FN[act] if isinstance(act, str) else act

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        return self.intermediate_act_fn(self.dense(hidden_states))
class CanineOutput(nn.Module):
    """Project the feed-forward output back to hidden_size, apply dropout, then residual add + LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.FloatTensor, input_tensor: torch.FloatTensor) -> torch.FloatTensor:
        """
        Args:
            hidden_states: Intermediate activations (last dimension intermediate_size).
            input_tensor: Residual input; must match the projected shape.

        Returns:
            `LayerNorm(dropout(dense(hidden_states)) + input_tensor)`.
        """
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return self.LayerNorm(hidden_states + input_tensor)
class CanineLayer(GradientCheckpointingLayer):
    """One Transformer layer: (optionally local) attention followed by a feed-forward network."""

    def __init__(
        self,
        config,
        local,
        always_attend_to_first_position,
        first_position_attends_to_all,
        attend_from_chunk_width,
        attend_from_chunk_stride,
        attend_to_chunk_width,
        attend_to_chunk_stride,
    ):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # Dimension passed to `apply_chunking_to_forward` when the feed-forward is chunked.
        self.seq_len_dim = 1
        self.attention = CanineAttention(
            config,
            local=local,
            always_attend_to_first_position=always_attend_to_first_position,
            first_position_attends_to_all=first_position_attends_to_all,
            attend_from_chunk_width=attend_from_chunk_width,
            attend_from_chunk_stride=attend_from_chunk_stride,
            attend_to_chunk_width=attend_to_chunk_width,
            attend_to_chunk_stride=attend_to_chunk_stride,
        )
        self.intermediate = CanineIntermediate(config)
        self.output = CanineOutput(config)

    def forward(
        self,
        hidden_states: tuple[torch.FloatTensor],
        attention_mask: torch.FloatTensor | None = None,
        output_attentions: bool | None = False,
    ) -> tuple[torch.FloatTensor, torch.FloatTensor | None]:
        """Return `(layer_output, *attention_extras)`, where the extras are whatever the attention block appended."""
        attention_output, *attention_extras = self.attention(
            hidden_states,
            attention_mask,
            output_attentions=output_attentions,
        )
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        return (layer_output, *attention_extras)

    def feed_forward_chunk(self, attention_output):
        """Feed-forward sub-layer (expansion, then projection with residual) applied to one chunk."""
        return self.output(self.intermediate(attention_output), attention_output)
class CanineEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` `CanineLayer`s that all share the same (local) attention settings."""

    def __init__(
        self,
        config,
        local=False,
        always_attend_to_first_position=False,
        first_position_attends_to_all=False,
        attend_from_chunk_width=128,
        attend_from_chunk_stride=128,
        attend_to_chunk_width=128,
        attend_to_chunk_stride=128,
    ):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [
                CanineLayer(
                    config,
                    local,
                    always_attend_to_first_position,
                    first_position_attends_to_all,
                    attend_from_chunk_width,
                    attend_from_chunk_stride,
                    attend_to_chunk_width,
                    attend_to_chunk_stride,
                )
                for _ in range(config.num_hidden_layers)
            ]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: tuple[torch.FloatTensor],
        attention_mask: torch.FloatTensor | None = None,
        output_attentions: bool | None = False,
        output_hidden_states: bool | None = False,
        return_dict: bool | None = True,
    ) -> tuple | BaseModelOutput:
        """
        Run every layer in order.

        Returns:
            A `BaseModelOutput` holding:
            - the last hidden state;
            - optionally, the input of every layer plus the final output (`output_hidden_states`);
            - optionally, one attention entry per layer (`output_attentions`).
            With `return_dict=False`, a tuple of the non-None fields is returned instead.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for layer_module in self.layer:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(hidden_states, attention_mask, output_attentions)
            hidden_states = layer_outputs[0]
            if output_attentions:
                # A local layer returns one probability tensor per window; only the first one is kept here.
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
class CaninePooler(nn.Module):
    """Pool a sequence by passing its first position through a dense layer and a tanh."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        """
        Pool `hidden_states` (position axis at dim 1, e.g. (batch, seq, hidden_size)) into (batch, hidden_size).
        """
        # We "pool" the model by simply taking the hidden state corresponding to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        return self.activation(pooled_output)
class CaninePredictionHeadTransform(nn.Module):
    """Dense projection, activation and LayerNorm applied before the LM decoder."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # `hidden_act` is either an activation name (looked up in ACT2FN) or a callable used as-is.
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        """Return `LayerNorm(act(dense(hidden_states)))`, same shape as the input."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        return self.LayerNorm(hidden_states)
class CanineLMPredictionHead(nn.Module):
    """Transform hidden states and project them to `config.vocab_size` logits."""

    def __init__(self, config):
        super().__init__()
        self.transform = CaninePredictionHeadTransform(config)
        # The decoder carries its own bias (`bias=True`).
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
        # NOTE(review): `self.bias` is not used in `forward` and is not linked to `self.decoder.bias` in this class.
        # Presumably it is tied or resized elsewhere (e.g. by the model's tied-weights handling); confirm.
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return logits with last dimension `config.vocab_size`."""
        hidden_states = self.transform(hidden_states)
        return self.decoder(hidden_states)
class CanineOnlyMLMHead(nn.Module):
    """Masked-language-modeling head: wraps `CanineLMPredictionHead`."""

    def __init__(self, config):
        super().__init__()
        self.predictions = CanineLMPredictionHead(config)

    def forward(
        self,
        sequence_output: torch.Tensor,
    ) -> torch.Tensor:
        """Return the prediction scores (a single tensor) for `sequence_output`."""
        return self.predictions(sequence_output)
@auto_docstring
class CaninePreTrainedModel(PreTrainedModel):
    config: CanineConfig
    base_model_prefix = "canine"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Apply the default initialization, then refill the `position_ids` buffer of `CanineEmbeddings`."""
        super()._init_weights(module)
        if not isinstance(module, CanineEmbeddings):
            return
        # `position_ids` is a non-persistent buffer, so it is reset here to 0..n-1 with shape (1, n).
        num_positions = module.position_ids.shape[-1]
        init.copy_(module.position_ids, torch.arange(num_positions).expand((1, -1)))
@auto_docstring
class CanineModel(CaninePreTrainedModel):
    def __init__(self, config, add_pooling_layer=True):
        r"""
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        """
        super().__init__(config)
        self.config = config
        # The two character-level encoders use a copy of the config reduced to a single layer.
        shallow_config = copy.deepcopy(config)
        shallow_config.num_hidden_layers = 1

        self.char_embeddings = CanineEmbeddings(config)
        # shallow/low-dim transformer encoder to get an initial character encoding; it uses local attention with
        # every from/to chunk width and stride set to `config.local_transformer_stride`
        self.initial_char_encoder = CanineEncoder(
            shallow_config,
            local=True,
            always_attend_to_first_position=False,
            first_position_attends_to_all=False,
            attend_from_chunk_width=config.local_transformer_stride,
            attend_from_chunk_stride=config.local_transformer_stride,
            attend_to_chunk_width=config.local_transformer_stride,
            attend_to_chunk_stride=config.local_transformer_stride,
        )
        self.chars_to_molecules = CharactersToMolecules(config)
        # deep transformer encoder (runs over the downsampled molecule sequence in `forward`)
        self.encoder = CanineEncoder(config)
        self.projection = ConvProjection(config)
        # shallow/low-dim transformer encoder to get a final character encoding
        self.final_char_encoder = CanineEncoder(shallow_config)

        self.pooler = CaninePooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def _create_3d_attention_mask_from_input_mask(self, from_tensor, to_mask):
        """
        Create 3D attention mask from a 2D tensor mask.

        Args:
            from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...]; only its first two
                dimensions are read.
            to_mask: Tensor of shape [batch_size, to_seq_length]; it is converted to float here.

        Returns:
            float Tensor of shape [batch_size, from_seq_length, to_seq_length].
        """
        batch_size, from_seq_length = from_tensor.shape[0], from_tensor.shape[1]
        to_seq_length = to_mask.shape[1]

        to_mask = torch.reshape(to_mask, (batch_size, 1, to_seq_length)).float()

        # We don't assume that `from_tensor` is a mask (although it could be). We
        # don't actually care if we attend *from* padding tokens (only *to* padding)
        # tokens so we create a tensor of all ones.
        broadcast_ones = torch.ones(size=(batch_size, from_seq_length, 1), dtype=torch.float32, device=to_mask.device)

        # Here we broadcast along two dimensions to create the mask.
        mask = broadcast_ones * to_mask

        return mask

    def _downsample_attention_mask(self, char_attention_mask: torch.Tensor, downsampling_rate: int):
        """
        Downsample a 2D character attention mask to a 2D molecule attention mask with max pooling.

        Each molecule entry is the maximum of the character mask over a non-overlapping window of
        `downsampling_rate` characters.

        Args:
            char_attention_mask: Tensor of shape [batch_size, char_seq_len].
            downsampling_rate: pooling kernel size and stride.

        Returns:
            float Tensor of shape [batch_size, char_seq_len // downsampling_rate].
        """
        # first, make char_attention_mask 3D by adding a channel dim
        batch_size, char_seq_len = char_attention_mask.shape
        poolable_char_mask = torch.reshape(char_attention_mask, (batch_size, 1, char_seq_len))

        # next, max-pool to get pooled_molecule_mask of shape (batch_size, 1, mol_seq_len); the functional form
        # avoids constructing a pooling module on every call
        pooled_molecule_mask = torch.nn.functional.max_pool1d(
            poolable_char_mask.float(), kernel_size=downsampling_rate, stride=downsampling_rate
        )

        # finally, drop the channel dim (dim=1) to get a tensor of shape (batch_size, mol_seq_len).
        # Squeezing dim=-1 would be a no-op whenever mol_seq_len != 1 and leave the mask 3D.
        molecule_attention_mask = torch.squeeze(pooled_molecule_mask, dim=1)

        return molecule_attention_mask

    def _repeat_molecules(self, molecules: torch.Tensor, char_seq_length: int) -> torch.Tensor:
        """
        Repeats molecules to make them the same length as the char sequence.

        The first molecule is dropped (the extra CLS position, per the variable name below), every remaining
        molecule is repeated `config.downsampling_rate` times, and the last molecule is appended
        `char_seq_length % downsampling_rate + downsampling_rate` more times. The output therefore has
        `char_seq_length` positions when `molecules` has `char_seq_length // downsampling_rate` positions.
        """
        rate = self.config.downsampling_rate

        molecules_without_extra_cls = molecules[:, 1:, :]
        # `repeated`: [batch_size, almost_char_seq_len, molecule_hidden_size]
        repeated = torch.repeat_interleave(molecules_without_extra_cls, repeats=rate, dim=-2)

        # So far, we've repeated the elements sufficient for any `char_seq_length`
        # that's a multiple of `downsampling_rate`. Now we account for the last
        # n elements (n < `downsampling_rate`), i.e. the remainder of floor
        # division. We do this by repeating the last molecule a few extra times.
        last_molecule = molecules[:, -1:, :]
        remainder_length = char_seq_length % rate
        remainder_repeated = torch.repeat_interleave(
            last_molecule,
            # +1 molecule to compensate for truncation.
            repeats=remainder_length + rate,
            dim=-2,
        )

        # `repeated`: [batch_size, char_seq_len, molecule_hidden_size]
        return torch.cat([repeated, remainder_repeated], dim=-2)

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | CanineModelOutputWithPooling:
        # Pipeline: char embeddings -> shallow local char encoder -> downsample to molecules -> deep encoder
        # -> upsample back to chars, concatenate with the char encoding -> projection -> shallow final encoder.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length)), device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
        # Character mask max-pooled to one entry per molecule: [batch_size, mol_seq_len].
        molecule_attention_mask = self._downsample_attention_mask(
            attention_mask, downsampling_rate=self.config.downsampling_rate
        )
        extended_molecule_attention_mask: torch.Tensor = self.get_extended_attention_mask(
            molecule_attention_mask, (batch_size, molecule_attention_mask.shape[-1])
        )

        # `input_char_embeddings`: shape (batch_size, char_seq, char_dim)
        input_char_embeddings = self.char_embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )

        # Contextualize character embeddings using shallow Transformer.
        # We use a 3D attention mask for the local attention.
        # `input_char_encoding`: shape (batch_size, char_seq_len, char_dim)
        char_attention_mask = self._create_3d_attention_mask_from_input_mask(
            input_ids if input_ids is not None else inputs_embeds, attention_mask
        )
        init_chars_encoder_outputs = self.initial_char_encoder(
            input_char_embeddings,
            attention_mask=char_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        input_char_encoding = init_chars_encoder_outputs.last_hidden_state

        # Downsample chars to molecules.
        # The following lines have dimensions: [batch, molecule_seq, molecule_dim].
        # In this transformation, we change the dimensionality from `char_dim` to
        # `molecule_dim`, but do *NOT* add a resnet connection. Instead, we rely on
        # the resnet connections (a) from the final char transformer stack back into
        # the original char transformer stack and (b) the resnet connections from
        # the final char transformer stack back into the deep BERT stack of
        # molecules.
        #
        # Empirically, it is critical to use a powerful enough transformation here:
        # mean pooling causes training to diverge with huge gradient norms in this
        # region of the model; using a convolution here resolves this issue. From
        # this, it seems that molecules and characters require a very different
        # feature space; intuitively, this makes sense.
        init_molecule_encoding = self.chars_to_molecules(input_char_encoding)

        # Deep BERT encoder
        # `molecule_sequence_output`: shape (batch_size, mol_seq_len, mol_dim)
        encoder_outputs = self.encoder(
            init_molecule_encoding,
            attention_mask=extended_molecule_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        molecule_sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(molecule_sequence_output) if self.pooler is not None else None

        # Upsample molecules back to characters.
        # `repeated_molecules`: shape (batch_size, char_seq_len, mol_hidden_size)
        repeated_molecules = self._repeat_molecules(molecule_sequence_output, char_seq_length=input_shape[-1])

        # Concatenate representations (contextualized char embeddings and repeated molecules):
        # `concat`: shape [batch_size, char_seq_len, molecule_hidden_size+char_hidden_final]
        concat = torch.cat([input_char_encoding, repeated_molecules], dim=-1)

        # Project representation dimension back to hidden_size
        # `sequence_output`: shape (batch_size, char_seq_len, hidden_size])
        sequence_output = self.projection(concat)

        # Apply final shallow Transformer
        # `sequence_output`: shape (batch_size, char_seq_len, hidden_size])
        final_chars_encoder_outputs = self.final_char_encoder(
            sequence_output,
            attention_mask=extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        sequence_output = final_chars_encoder_outputs.last_hidden_state

        # Hidden states / attentions are reported as: initial char encoder, then deep encoder, then final char encoder.
        if output_hidden_states:
            deep_encoder_hidden_states = encoder_outputs.hidden_states if return_dict else encoder_outputs[1]
            all_hidden_states = (
                all_hidden_states
                + init_chars_encoder_outputs.hidden_states
                + deep_encoder_hidden_states
                + final_chars_encoder_outputs.hidden_states
            )

        if output_attentions:
            deep_encoder_self_attentions = encoder_outputs.attentions if return_dict else encoder_outputs[-1]
            all_self_attentions = (
                all_self_attentions
                + init_chars_encoder_outputs.attentions
                + deep_encoder_self_attentions
                + final_chars_encoder_outputs.attentions
            )

        if not return_dict:
            output = (sequence_output, pooled_output)
            output += tuple(v for v in [all_hidden_states, all_self_attentions] if v is not None)
            return output

        return CanineModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
- @auto_docstring(
- custom_intro="""
- CANINE Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
- output) e.g. for GLUE tasks.
- """
- )
- class CanineForSequenceClassification(CaninePreTrainedModel):
- def __init__(self, config):
- super().__init__(config)
- self.num_labels = config.num_labels
- self.canine = CanineModel(config)
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
- self.classifier = nn.Linear(config.hidden_size, config.num_labels)
- # Initialize weights and apply final processing
- self.post_init()
- @auto_docstring
- def forward(
- self,
- input_ids: torch.LongTensor | None = None,
- attention_mask: torch.FloatTensor | None = None,
- token_type_ids: torch.LongTensor | None = None,
- position_ids: torch.LongTensor | None = None,
- inputs_embeds: torch.FloatTensor | None = None,
- labels: torch.LongTensor | None = None,
- output_attentions: bool | None = None,
- output_hidden_states: bool | None = None,
- return_dict: bool | None = None,
- **kwargs,
- ) -> tuple | SequenceClassifierOutput:
- r"""
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
- """
- return_dict = return_dict if return_dict is not None else self.config.return_dict
- outputs = self.canine(
- input_ids,
- attention_mask=attention_mask,
- token_type_ids=token_type_ids,
- position_ids=position_ids,
- inputs_embeds=inputs_embeds,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- pooled_output = outputs[1]
- pooled_output = self.dropout(pooled_output)
- logits = self.classifier(pooled_output)
- loss = None
- if labels is not None:
- if self.config.problem_type is None:
- if self.num_labels == 1:
- self.config.problem_type = "regression"
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
- self.config.problem_type = "single_label_classification"
- else:
- self.config.problem_type = "multi_label_classification"
- if self.config.problem_type == "regression":
- loss_fct = MSELoss()
- if self.num_labels == 1:
- loss = loss_fct(logits.squeeze(), labels.squeeze())
- else:
- loss = loss_fct(logits, labels)
- elif self.config.problem_type == "single_label_classification":
- loss_fct = CrossEntropyLoss()
- loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
- elif self.config.problem_type == "multi_label_classification":
- loss_fct = BCEWithLogitsLoss()
- loss = loss_fct(logits, labels)
- if not return_dict:
- output = (logits,) + outputs[2:]
- return ((loss,) + output) if loss is not None else output
- return SequenceClassifierOutput(
- loss=loss,
- logits=logits,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
@auto_docstring
class CanineForMultipleChoice(CaninePreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.canine = CanineModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # A single score per (example, choice) pair, computed from the pooled output.
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | MultipleChoiceModelOutput:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        if return_dict is None:
            return_dict = self.config.return_dict

        reference = input_ids if input_ids is not None else inputs_embeds
        num_choices = reference.shape[1]

        def fold_choices(tensor):
            # (batch_size, num_choices, seq_len) -> (batch_size * num_choices, seq_len)
            if tensor is None:
                return None
            return tensor.view(-1, tensor.size(-1))

        input_ids = fold_choices(input_ids)
        attention_mask = fold_choices(attention_mask)
        token_type_ids = fold_choices(token_type_ids)
        position_ids = fold_choices(position_ids)
        if inputs_embeds is not None:
            # (batch_size, num_choices, seq_len, dim) -> (batch_size * num_choices, seq_len, dim)
            inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))

        outputs = self.canine(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # One logit per flattened (example, choice) row, regrouped into (batch_size, num_choices).
        pooled_output = self.dropout(outputs[1])
        reshaped_logits = self.classifier(pooled_output).view(-1, num_choices)

        loss = None if labels is None else CrossEntropyLoss()(reshaped_logits, labels)

        if return_dict:
            return MultipleChoiceModelOutput(
                loss=loss,
                logits=reshaped_logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        output = (reshaped_logits,) + outputs[2:]
        return output if loss is None else (loss,) + output
@auto_docstring
class CanineForTokenClassification(CaninePreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.canine = CanineModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Linear classifier applied to every position of the sequence output.
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | TokenClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, CanineForTokenClassification
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google/canine-s")
        >>> model = CanineForTokenClassification.from_pretrained("google/canine-s")

        >>> inputs = tokenizer(
        ...     "HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="pt"
        ... )

        >>> with torch.no_grad():
        ...     logits = model(**inputs).logits

        >>> predicted_token_class_ids = logits.argmax(-1)

        >>> # Note that tokens are classified rather than input words which means that
        >>> # there might be more predicted token classes than words.
        >>> # Multiple token classes might account for the same word
        >>> predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids[0]]
        >>> predicted_tokens_classes  # doctest: +SKIP
        ```

        ```python
        >>> labels = predicted_token_class_ids
        >>> loss = model(**inputs, labels=labels).loss
        >>> round(loss.item(), 2)  # doctest: +SKIP
        ```"""
        if return_dict is None:
            return_dict = self.config.return_dict

        outputs = self.canine(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Per-position logits from the dropout-regularized sequence output (first element of the base output).
        logits = self.classifier(self.dropout(outputs[0]))

        loss = None
        if labels is not None:
            flat_logits = logits.view(-1, self.num_labels)
            loss = CrossEntropyLoss()(flat_logits, labels.view(-1))

        if return_dict:
            return TokenClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        output = (logits,) + outputs[2:]
        return output if loss is None else (loss,) + output
@auto_docstring
class CanineForQuestionAnswering(CaninePreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels

        self.canine = CanineModel(config)
        # One logit per label at every position; `forward` splits them into start and end logits.
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        start_positions: torch.LongTensor | None = None,
        end_positions: torch.LongTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | QuestionAnsweringModelOutput:
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        outputs = self.canine(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # Split the last dimension into chunks of size 1; unpacking into exactly two tensors requires
        # `config.num_labels == 2`.
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, the positions may carry an extra trailing dimension; drop it.
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs: clamp them to `ignored_index`,
            # which the loss then ignores. Out-of-place `clamp` leaves the caller's label tensors untouched
            # (an in-place `clamp_` would also write through the `squeeze` views above).
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Public names of this module (kept as a literal, alphabetically sorted list).
__all__ = [
    "CanineForMultipleChoice",
    "CanineForQuestionAnswering",
    "CanineForSequenceClassification",
    "CanineForTokenClassification",
    "CanineLayer",
    "CanineModel",
    "CaninePreTrainedModel",
]
|