modeling_canine.py 57 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358
  1. # Copyright 2021 Google AI The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch CANINE model."""
  15. import copy
  16. import math
  17. from dataclasses import dataclass
  18. import torch
  19. from torch import nn
  20. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  21. from ... import initialization as init
  22. from ...activations import ACT2FN
  23. from ...modeling_layers import GradientCheckpointingLayer
  24. from ...modeling_outputs import (
  25. BaseModelOutput,
  26. ModelOutput,
  27. MultipleChoiceModelOutput,
  28. QuestionAnsweringModelOutput,
  29. SequenceClassifierOutput,
  30. TokenClassifierOutput,
  31. )
  32. from ...modeling_utils import PreTrainedModel
  33. from ...pytorch_utils import apply_chunking_to_forward
  34. from ...utils import auto_docstring, logging
  35. from .configuration_canine import CanineConfig
  36. logger = logging.get_logger(__name__)
  37. # Support up to 16 hash functions.
  38. _PRIMES = [31, 43, 59, 61, 73, 97, 103, 113, 137, 149, 157, 173, 181, 193, 211, 223]
# Output container for `CanineModel`. The `custom_intro` text below is passed to `auto_docstring`, and the class
# docstring is presumably read by that decorator as well, so both are kept verbatim.
@dataclass
@auto_docstring(
    custom_intro="""
Output type of [`CanineModel`]. Based on [`~modeling_outputs.BaseModelOutputWithPooling`], but with slightly
different `hidden_states` and `attentions`, as these also include the hidden states and attentions of the shallow
Transformer encoders.
"""
)
class CanineModelOutputWithPooling(ModelOutput):
    r"""
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model (i.e. the output of the final
shallow Transformer encoder).
pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
Hidden-state of the first token of the sequence (classification token) at the last layer of the deep
Transformer encoder, further processed by a Linear layer and a Tanh activation function. The Linear layer
weights are trained from the next sentence prediction (classification) objective during pretraining.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the input to each encoder + one for the output of each layer of each
encoder) of shape `(batch_size, sequence_length, hidden_size)` and `(batch_size, sequence_length //
config.downsampling_rate, hidden_size)`. Hidden-states of the model at the output of each layer plus the
initial input to each Transformer encoder. The hidden states of the shallow encoders have length
`sequence_length`, but the hidden states of the deep encoder have length `sequence_length` //
`config.downsampling_rate`.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of the 3 Transformer encoders of shape `(batch_size,
num_heads, sequence_length, sequence_length)` and `(batch_size, num_heads, sequence_length //
config.downsampling_rate, sequence_length // config.downsampling_rate)`. Attentions weights after the
attention softmax, used to compute the weighted average in the self-attention heads.
"""

    # All fields default to None. The field order below is the order in which the outputs are declared.
    # Output of the final shallow encoder, `(batch_size, sequence_length, hidden_size)` per the docstring.
    last_hidden_state: torch.FloatTensor | None = None
    # Dense + tanh projection of the first position after the deep encoder.
    pooler_output: torch.FloatTensor | None = None
    # Hidden states of all encoders; only populated when hidden states are requested.
    hidden_states: tuple[torch.FloatTensor] | None = None
    # Attention weights of all three encoders; only populated when attentions are requested.
    attentions: tuple[torch.FloatTensor] | None = None
  73. class CanineEmbeddings(nn.Module):
  74. """Construct the character, position and token_type embeddings."""
  75. def __init__(self, config):
  76. super().__init__()
  77. self.config = config
  78. # character embeddings
  79. shard_embedding_size = config.hidden_size // config.num_hash_functions
  80. for i in range(config.num_hash_functions):
  81. name = f"HashBucketCodepointEmbedder_{i}"
  82. setattr(self, name, nn.Embedding(config.num_hash_buckets, shard_embedding_size))
  83. self.char_position_embeddings = nn.Embedding(config.num_hash_buckets, config.hidden_size)
  84. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  85. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  86. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  87. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  88. self.register_buffer(
  89. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  90. )
  91. def _hash_bucket_tensors(self, input_ids, num_hashes: int, num_buckets: int):
  92. """
  93. Converts ids to hash bucket ids via multiple hashing.
  94. Args:
  95. input_ids: The codepoints or other IDs to be hashed.
  96. num_hashes: The number of hash functions to use.
  97. num_buckets: The number of hash buckets (i.e. embeddings in each table).
  98. Returns:
  99. A list of tensors, each of which is the hash bucket IDs from one hash function.
  100. """
  101. if num_hashes > len(_PRIMES):
  102. raise ValueError(f"`num_hashes` must be <= {len(_PRIMES)}")
  103. primes = _PRIMES[:num_hashes]
  104. result_tensors = []
  105. for prime in primes:
  106. hashed = ((input_ids + 1) * prime) % num_buckets
  107. result_tensors.append(hashed)
  108. return result_tensors
  109. def _embed_hash_buckets(self, input_ids, embedding_size: int, num_hashes: int, num_buckets: int):
  110. """Converts IDs (e.g. codepoints) into embeddings via multiple hashing."""
  111. if embedding_size % num_hashes != 0:
  112. raise ValueError(f"Expected `embedding_size` ({embedding_size}) % `num_hashes` ({num_hashes}) == 0")
  113. hash_bucket_tensors = self._hash_bucket_tensors(input_ids, num_hashes=num_hashes, num_buckets=num_buckets)
  114. embedding_shards = []
  115. for i, hash_bucket_ids in enumerate(hash_bucket_tensors):
  116. name = f"HashBucketCodepointEmbedder_{i}"
  117. shard_embeddings = getattr(self, name)(hash_bucket_ids)
  118. embedding_shards.append(shard_embeddings)
  119. return torch.cat(embedding_shards, dim=-1)
  120. def forward(
  121. self,
  122. input_ids: torch.LongTensor | None = None,
  123. token_type_ids: torch.LongTensor | None = None,
  124. position_ids: torch.LongTensor | None = None,
  125. inputs_embeds: torch.FloatTensor | None = None,
  126. ) -> torch.FloatTensor:
  127. if input_ids is not None:
  128. input_shape = input_ids.size()
  129. else:
  130. input_shape = inputs_embeds.size()[:-1]
  131. seq_length = input_shape[1]
  132. if position_ids is None:
  133. position_ids = self.position_ids[:, :seq_length]
  134. if token_type_ids is None:
  135. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  136. if inputs_embeds is None:
  137. inputs_embeds = self._embed_hash_buckets(
  138. input_ids, self.config.hidden_size, self.config.num_hash_functions, self.config.num_hash_buckets
  139. )
  140. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  141. embeddings = inputs_embeds + token_type_embeddings
  142. position_embeddings = self.char_position_embeddings(position_ids)
  143. embeddings += position_embeddings
  144. embeddings = self.LayerNorm(embeddings)
  145. embeddings = self.dropout(embeddings)
  146. return embeddings
  147. class CharactersToMolecules(nn.Module):
  148. """Convert character sequence to initial molecule sequence (i.e. downsample) using strided convolutions."""
  149. def __init__(self, config):
  150. super().__init__()
  151. self.conv = nn.Conv1d(
  152. in_channels=config.hidden_size,
  153. out_channels=config.hidden_size,
  154. kernel_size=config.downsampling_rate,
  155. stride=config.downsampling_rate,
  156. )
  157. self.activation = ACT2FN[config.hidden_act]
  158. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  159. def forward(self, char_encoding: torch.Tensor) -> torch.Tensor:
  160. # `cls_encoding`: [batch, 1, hidden_size]
  161. cls_encoding = char_encoding[:, 0:1, :]
  162. # char_encoding has shape [batch, char_seq, hidden_size]
  163. # We transpose it to be [batch, hidden_size, char_seq]
  164. char_encoding = torch.transpose(char_encoding, 1, 2)
  165. downsampled = self.conv(char_encoding)
  166. downsampled = torch.transpose(downsampled, 1, 2)
  167. downsampled = self.activation(downsampled)
  168. # Truncate the last molecule in order to reserve a position for [CLS].
  169. # Often, the last position is never used (unless we completely fill the
  170. # text buffer). This is important in order to maintain alignment on TPUs
  171. # (i.e. a multiple of 128).
  172. downsampled_truncated = downsampled[:, 0:-1, :]
  173. # We also keep [CLS] as a separate sequence position since we always
  174. # want to reserve a position (and the model capacity that goes along
  175. # with that) in the deep BERT stack.
  176. # `result`: [batch, molecule_seq, molecule_dim]
  177. result = torch.cat([cls_encoding, downsampled_truncated], dim=1)
  178. result = self.LayerNorm(result)
  179. return result
  180. class ConvProjection(nn.Module):
  181. """
  182. Project representations from hidden_size*2 back to hidden_size across a window of w = config.upsampling_kernel_size
  183. characters.
  184. """
  185. def __init__(self, config):
  186. super().__init__()
  187. self.config = config
  188. self.conv = nn.Conv1d(
  189. in_channels=config.hidden_size * 2,
  190. out_channels=config.hidden_size,
  191. kernel_size=config.upsampling_kernel_size,
  192. stride=1,
  193. )
  194. self.activation = ACT2FN[config.hidden_act]
  195. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  196. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  197. def forward(
  198. self,
  199. inputs: torch.Tensor,
  200. final_seq_char_positions: torch.Tensor | None = None,
  201. ) -> torch.Tensor:
  202. # inputs has shape [batch, mol_seq, molecule_hidden_size+char_hidden_final]
  203. # we transpose it to be [batch, molecule_hidden_size+char_hidden_final, mol_seq]
  204. inputs = torch.transpose(inputs, 1, 2)
  205. # PyTorch < 1.9 does not support padding="same" (which is used in the original implementation),
  206. # so we pad the tensor manually before passing it to the conv layer
  207. # based on https://github.com/google-research/big_transfer/blob/49afe42338b62af9fbe18f0258197a33ee578a6b/bit_tf2/models.py#L36-L38
  208. pad_total = self.config.upsampling_kernel_size - 1
  209. pad_beg = pad_total // 2
  210. pad_end = pad_total - pad_beg
  211. pad = nn.ConstantPad1d((pad_beg, pad_end), 0)
  212. # `result`: shape (batch_size, char_seq_len, hidden_size)
  213. result = self.conv(pad(inputs))
  214. result = torch.transpose(result, 1, 2)
  215. result = self.activation(result)
  216. result = self.LayerNorm(result)
  217. result = self.dropout(result)
  218. final_char_seq = result
  219. if final_seq_char_positions is not None:
  220. # Limit transformer query seq and attention mask to these character
  221. # positions to greatly reduce the compute cost. Typically, this is just
  222. # done for the MLM training task.
  223. # TODO add support for MLM
  224. raise NotImplementedError("CanineForMaskedLM is currently not supported")
  225. else:
  226. query_seq = final_char_seq
  227. return query_seq
  228. class CanineSelfAttention(nn.Module):
  229. def __init__(self, config):
  230. super().__init__()
  231. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  232. raise ValueError(
  233. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  234. f"heads ({config.num_attention_heads})"
  235. )
  236. self.num_attention_heads = config.num_attention_heads
  237. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  238. self.all_head_size = self.num_attention_heads * self.attention_head_size
  239. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  240. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  241. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  242. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  243. def forward(
  244. self,
  245. from_tensor: torch.Tensor,
  246. to_tensor: torch.Tensor,
  247. attention_mask: torch.FloatTensor | None = None,
  248. output_attentions: bool | None = False,
  249. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  250. batch_size, seq_length, _ = from_tensor.shape
  251. # If this is instantiated as a cross-attention module, the keys
  252. # and values come from an encoder; the attention mask needs to be
  253. # such that the encoder's padding tokens are not attended to.
  254. key_layer = (
  255. self.key(to_tensor)
  256. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  257. .transpose(1, 2)
  258. )
  259. value_layer = (
  260. self.value(to_tensor)
  261. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  262. .transpose(1, 2)
  263. )
  264. query_layer = (
  265. self.query(from_tensor)
  266. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  267. .transpose(1, 2)
  268. )
  269. # Take the dot product between "query" and "key" to get the raw attention scores.
  270. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  271. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  272. if attention_mask is not None:
  273. if attention_mask.ndim == 3:
  274. # if attention_mask is 3D, do the following:
  275. attention_mask = torch.unsqueeze(attention_mask, dim=1)
  276. # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
  277. # masked positions, this operation will create a tensor which is 0.0 for
  278. # positions we want to attend and the dtype's smallest value for masked positions.
  279. attention_mask = (1.0 - attention_mask.float()) * torch.finfo(attention_scores.dtype).min
  280. # Apply the attention mask (precomputed for all layers in CanineModel forward() function)
  281. attention_scores = attention_scores + attention_mask
  282. # Normalize the attention scores to probabilities.
  283. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  284. # This is actually dropping out entire tokens to attend to, which might
  285. # seem a bit unusual, but is taken from the original Transformer paper.
  286. attention_probs = self.dropout(attention_probs)
  287. context_layer = torch.matmul(attention_probs, value_layer)
  288. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  289. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  290. context_layer = context_layer.view(*new_context_layer_shape)
  291. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  292. return outputs
  293. class CanineSelfOutput(nn.Module):
  294. def __init__(self, config):
  295. super().__init__()
  296. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  297. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  298. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  299. def forward(
  300. self, hidden_states: tuple[torch.FloatTensor], input_tensor: torch.FloatTensor
  301. ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
  302. hidden_states = self.dense(hidden_states)
  303. hidden_states = self.dropout(hidden_states)
  304. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  305. return hidden_states
  306. class CanineAttention(nn.Module):
  307. """
  308. Additional arguments related to local attention:
  309. - **local** (`bool`, *optional*, defaults to `False`) -- Whether to apply local attention.
  310. - **always_attend_to_first_position** (`bool`, *optional*, defaults to `False`) -- Should all blocks be able to
  311. attend
  312. to the `to_tensor`'s first position (e.g. a [CLS] position)? - **first_position_attends_to_all** (`bool`,
  313. *optional*, defaults to `False`) -- Should the *from_tensor*'s first position be able to attend to all
  314. positions within the *from_tensor*? - **attend_from_chunk_width** (`int`, *optional*, defaults to 128) -- The
  315. width of each block-wise chunk in `from_tensor`. - **attend_from_chunk_stride** (`int`, *optional*, defaults to
  316. 128) -- The number of elements to skip when moving to the next block in `from_tensor`. -
  317. **attend_to_chunk_width** (`int`, *optional*, defaults to 128) -- The width of each block-wise chunk in
  318. *to_tensor*. - **attend_to_chunk_stride** (`int`, *optional*, defaults to 128) -- The number of elements to
  319. skip when moving to the next block in `to_tensor`.
  320. """
  321. def __init__(
  322. self,
  323. config,
  324. local=False,
  325. always_attend_to_first_position: bool = False,
  326. first_position_attends_to_all: bool = False,
  327. attend_from_chunk_width: int = 128,
  328. attend_from_chunk_stride: int = 128,
  329. attend_to_chunk_width: int = 128,
  330. attend_to_chunk_stride: int = 128,
  331. ):
  332. super().__init__()
  333. self.self = CanineSelfAttention(config)
  334. self.output = CanineSelfOutput(config)
  335. # additional arguments related to local attention
  336. self.local = local
  337. if attend_from_chunk_width < attend_from_chunk_stride:
  338. raise ValueError(
  339. "`attend_from_chunk_width` < `attend_from_chunk_stride` would cause sequence positions to get skipped."
  340. )
  341. if attend_to_chunk_width < attend_to_chunk_stride:
  342. raise ValueError(
  343. "`attend_to_chunk_width` < `attend_to_chunk_stride`would cause sequence positions to get skipped."
  344. )
  345. self.always_attend_to_first_position = always_attend_to_first_position
  346. self.first_position_attends_to_all = first_position_attends_to_all
  347. self.attend_from_chunk_width = attend_from_chunk_width
  348. self.attend_from_chunk_stride = attend_from_chunk_stride
  349. self.attend_to_chunk_width = attend_to_chunk_width
  350. self.attend_to_chunk_stride = attend_to_chunk_stride
  351. def forward(
  352. self,
  353. hidden_states: tuple[torch.FloatTensor],
  354. attention_mask: torch.FloatTensor | None = None,
  355. output_attentions: bool | None = False,
  356. ) -> tuple[torch.FloatTensor, torch.FloatTensor | None]:
  357. if not self.local:
  358. self_outputs = self.self(hidden_states, hidden_states, attention_mask, output_attentions)
  359. attention_output = self_outputs[0]
  360. else:
  361. from_seq_length = to_seq_length = hidden_states.shape[1]
  362. from_tensor = to_tensor = hidden_states
  363. # Create chunks (windows) that we will attend *from* and then concatenate them.
  364. from_chunks = []
  365. if self.first_position_attends_to_all:
  366. from_chunks.append((0, 1))
  367. # We must skip this first position so that our output sequence is the
  368. # correct length (this matters in the *from* sequence only).
  369. from_start = 1
  370. else:
  371. from_start = 0
  372. for chunk_start in range(from_start, from_seq_length, self.attend_from_chunk_stride):
  373. chunk_end = min(from_seq_length, chunk_start + self.attend_from_chunk_width)
  374. from_chunks.append((chunk_start, chunk_end))
  375. # Determine the chunks (windows) that will attend *to*.
  376. to_chunks = []
  377. if self.first_position_attends_to_all:
  378. to_chunks.append((0, to_seq_length))
  379. for chunk_start in range(0, to_seq_length, self.attend_to_chunk_stride):
  380. chunk_end = min(to_seq_length, chunk_start + self.attend_to_chunk_width)
  381. to_chunks.append((chunk_start, chunk_end))
  382. if len(from_chunks) != len(to_chunks):
  383. raise ValueError(
  384. f"Expected to have same number of `from_chunks` ({from_chunks}) and "
  385. f"`to_chunks` ({from_chunks}). Check strides."
  386. )
  387. # next, compute attention scores for each pair of windows and concatenate
  388. attention_output_chunks = []
  389. attention_probs_chunks = []
  390. for (from_start, from_end), (to_start, to_end) in zip(from_chunks, to_chunks):
  391. from_tensor_chunk = from_tensor[:, from_start:from_end, :]
  392. to_tensor_chunk = to_tensor[:, to_start:to_end, :]
  393. # `attention_mask`: <float>[batch_size, from_seq, to_seq]
  394. # `attention_mask_chunk`: <float>[batch_size, from_seq_chunk, to_seq_chunk]
  395. attention_mask_chunk = attention_mask[:, from_start:from_end, to_start:to_end]
  396. if self.always_attend_to_first_position:
  397. cls_attention_mask = attention_mask[:, from_start:from_end, 0:1]
  398. attention_mask_chunk = torch.cat([cls_attention_mask, attention_mask_chunk], dim=2)
  399. cls_position = to_tensor[:, 0:1, :]
  400. to_tensor_chunk = torch.cat([cls_position, to_tensor_chunk], dim=1)
  401. attention_outputs_chunk = self.self(
  402. from_tensor_chunk, to_tensor_chunk, attention_mask_chunk, output_attentions
  403. )
  404. attention_output_chunks.append(attention_outputs_chunk[0])
  405. if output_attentions:
  406. attention_probs_chunks.append(attention_outputs_chunk[1])
  407. attention_output = torch.cat(attention_output_chunks, dim=1)
  408. attention_output = self.output(attention_output, hidden_states)
  409. outputs = (attention_output,)
  410. if not self.local:
  411. outputs = outputs + self_outputs[1:] # add attentions if we output them
  412. else:
  413. outputs = outputs + tuple(attention_probs_chunks) # add attentions if we output them
  414. return outputs
  415. class CanineIntermediate(nn.Module):
  416. def __init__(self, config):
  417. super().__init__()
  418. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  419. if isinstance(config.hidden_act, str):
  420. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  421. else:
  422. self.intermediate_act_fn = config.hidden_act
  423. def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
  424. hidden_states = self.dense(hidden_states)
  425. hidden_states = self.intermediate_act_fn(hidden_states)
  426. return hidden_states
  427. class CanineOutput(nn.Module):
  428. def __init__(self, config):
  429. super().__init__()
  430. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  431. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  432. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  433. def forward(self, hidden_states: tuple[torch.FloatTensor], input_tensor: torch.FloatTensor) -> torch.FloatTensor:
  434. hidden_states = self.dense(hidden_states)
  435. hidden_states = self.dropout(hidden_states)
  436. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  437. return hidden_states
  438. class CanineLayer(GradientCheckpointingLayer):
  439. def __init__(
  440. self,
  441. config,
  442. local,
  443. always_attend_to_first_position,
  444. first_position_attends_to_all,
  445. attend_from_chunk_width,
  446. attend_from_chunk_stride,
  447. attend_to_chunk_width,
  448. attend_to_chunk_stride,
  449. ):
  450. super().__init__()
  451. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  452. self.seq_len_dim = 1
  453. self.attention = CanineAttention(
  454. config,
  455. local,
  456. always_attend_to_first_position,
  457. first_position_attends_to_all,
  458. attend_from_chunk_width,
  459. attend_from_chunk_stride,
  460. attend_to_chunk_width,
  461. attend_to_chunk_stride,
  462. )
  463. self.intermediate = CanineIntermediate(config)
  464. self.output = CanineOutput(config)
  465. def forward(
  466. self,
  467. hidden_states: tuple[torch.FloatTensor],
  468. attention_mask: torch.FloatTensor | None = None,
  469. output_attentions: bool | None = False,
  470. ) -> tuple[torch.FloatTensor, torch.FloatTensor | None]:
  471. self_attention_outputs = self.attention(
  472. hidden_states,
  473. attention_mask,
  474. output_attentions=output_attentions,
  475. )
  476. attention_output = self_attention_outputs[0]
  477. outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
  478. layer_output = apply_chunking_to_forward(
  479. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  480. )
  481. outputs = (layer_output,) + outputs
  482. return outputs
  483. def feed_forward_chunk(self, attention_output):
  484. intermediate_output = self.intermediate(attention_output)
  485. layer_output = self.output(intermediate_output, attention_output)
  486. return layer_output
  487. class CanineEncoder(nn.Module):
  488. def __init__(
  489. self,
  490. config,
  491. local=False,
  492. always_attend_to_first_position=False,
  493. first_position_attends_to_all=False,
  494. attend_from_chunk_width=128,
  495. attend_from_chunk_stride=128,
  496. attend_to_chunk_width=128,
  497. attend_to_chunk_stride=128,
  498. ):
  499. super().__init__()
  500. self.config = config
  501. self.layer = nn.ModuleList(
  502. [
  503. CanineLayer(
  504. config,
  505. local,
  506. always_attend_to_first_position,
  507. first_position_attends_to_all,
  508. attend_from_chunk_width,
  509. attend_from_chunk_stride,
  510. attend_to_chunk_width,
  511. attend_to_chunk_stride,
  512. )
  513. for _ in range(config.num_hidden_layers)
  514. ]
  515. )
  516. self.gradient_checkpointing = False
  517. def forward(
  518. self,
  519. hidden_states: tuple[torch.FloatTensor],
  520. attention_mask: torch.FloatTensor | None = None,
  521. output_attentions: bool | None = False,
  522. output_hidden_states: bool | None = False,
  523. return_dict: bool | None = True,
  524. ) -> tuple | BaseModelOutput:
  525. all_hidden_states = () if output_hidden_states else None
  526. all_self_attentions = () if output_attentions else None
  527. for i, layer_module in enumerate(self.layer):
  528. if output_hidden_states:
  529. all_hidden_states = all_hidden_states + (hidden_states,)
  530. layer_outputs = layer_module(hidden_states, attention_mask, output_attentions)
  531. hidden_states = layer_outputs[0]
  532. if output_attentions:
  533. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  534. if output_hidden_states:
  535. all_hidden_states = all_hidden_states + (hidden_states,)
  536. if not return_dict:
  537. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  538. return BaseModelOutput(
  539. last_hidden_state=hidden_states,
  540. hidden_states=all_hidden_states,
  541. attentions=all_self_attentions,
  542. )
  543. class CaninePooler(nn.Module):
  544. def __init__(self, config):
  545. super().__init__()
  546. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  547. self.activation = nn.Tanh()
  548. def forward(self, hidden_states: tuple[torch.FloatTensor]) -> torch.FloatTensor:
  549. # We "pool" the model by simply taking the hidden state corresponding
  550. # to the first token.
  551. first_token_tensor = hidden_states[:, 0]
  552. pooled_output = self.dense(first_token_tensor)
  553. pooled_output = self.activation(pooled_output)
  554. return pooled_output
  555. class CaninePredictionHeadTransform(nn.Module):
  556. def __init__(self, config):
  557. super().__init__()
  558. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  559. if isinstance(config.hidden_act, str):
  560. self.transform_act_fn = ACT2FN[config.hidden_act]
  561. else:
  562. self.transform_act_fn = config.hidden_act
  563. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  564. def forward(self, hidden_states: tuple[torch.FloatTensor]) -> torch.FloatTensor:
  565. hidden_states = self.dense(hidden_states)
  566. hidden_states = self.transform_act_fn(hidden_states)
  567. hidden_states = self.LayerNorm(hidden_states)
  568. return hidden_states
  569. class CanineLMPredictionHead(nn.Module):
  570. def __init__(self, config):
  571. super().__init__()
  572. self.transform = CaninePredictionHeadTransform(config)
  573. # The output weights are the same as the input embeddings, but there is
  574. # an output-only bias for each token.
  575. self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
  576. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  577. # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
  578. def forward(self, hidden_states: tuple[torch.FloatTensor]) -> torch.FloatTensor:
  579. hidden_states = self.transform(hidden_states)
  580. hidden_states = self.decoder(hidden_states)
  581. return hidden_states
  582. class CanineOnlyMLMHead(nn.Module):
  583. def __init__(self, config):
  584. super().__init__()
  585. self.predictions = CanineLMPredictionHead(config)
  586. def forward(
  587. self,
  588. sequence_output: tuple[torch.Tensor],
  589. ) -> tuple[torch.Tensor]:
  590. prediction_scores = self.predictions(sequence_output)
  591. return prediction_scores
  592. @auto_docstring
  593. class CaninePreTrainedModel(PreTrainedModel):
  594. config: CanineConfig
  595. base_model_prefix = "canine"
  596. supports_gradient_checkpointing = True
  597. def _init_weights(self, module):
  598. super()._init_weights(module)
  599. if isinstance(module, CanineEmbeddings):
  600. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
@auto_docstring
class CanineModel(CaninePreTrainedModel):
    # Pipeline implemented by `forward`:
    #   1. embed characters;
    #   2. run a shallow local-attention encoder over the characters;
    #   3. downsample characters to molecules;
    #   4. run the deep encoder over the molecules;
    #   5. repeat molecules back to character length and concatenate them with the
    #      character encoding;
    #   6. project back to hidden_size and run a final shallow encoder.
    def __init__(self, config, add_pooling_layer=True):
        r"""
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        """
        super().__init__(config)
        self.config = config
        # Copy of the config with exactly one hidden layer, shared by both shallow char encoders.
        shallow_config = copy.deepcopy(config)
        shallow_config.num_hidden_layers = 1

        self.char_embeddings = CanineEmbeddings(config)
        # shallow/low-dim transformer encoder to get an initial character encoding.
        # Local attention: every from/to chunk width and stride is `config.local_transformer_stride`.
        self.initial_char_encoder = CanineEncoder(
            shallow_config,
            local=True,
            always_attend_to_first_position=False,
            first_position_attends_to_all=False,
            attend_from_chunk_width=config.local_transformer_stride,
            attend_from_chunk_stride=config.local_transformer_stride,
            attend_to_chunk_width=config.local_transformer_stride,
            attend_to_chunk_stride=config.local_transformer_stride,
        )
        # Downsamples the character encoding to the molecule sequence (used in `forward`).
        self.chars_to_molecules = CharactersToMolecules(config)
        # deep transformer encoder (full config, i.e. config.num_hidden_layers layers) over molecules
        self.encoder = CanineEncoder(config)
        # Projects [char encoding ; repeated molecules] back to hidden_size (used in `forward`).
        self.projection = ConvProjection(config)
        # shallow/low-dim transformer encoder to get a final character encoding
        self.final_char_encoder = CanineEncoder(shallow_config)
        # Optional pooler; `forward` applies it to the molecule sequence, not the char sequence.
        self.pooler = CaninePooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def _create_3d_attention_mask_from_input_mask(self, from_tensor, to_mask):
        """
        Create a 3D attention mask from a 2D "to" mask.

        Args:
            from_tensor: Tensor whose first two dims are [batch_size, from_seq_length]; only its
                shape is read (2D input_ids or 3D inputs_embeds both work).
            to_mask: Tensor of shape [batch_size, to_seq_length]; it is cast to float here.

        Returns:
            float32 Tensor of shape [batch_size, from_seq_length, to_seq_length] in which every
            "from" row equals `to_mask`.
        """
        batch_size, from_seq_length = from_tensor.shape[0], from_tensor.shape[1]

        to_seq_length = to_mask.shape[1]

        to_mask = torch.reshape(to_mask, (batch_size, 1, to_seq_length)).float()

        # We don't assume that `from_tensor` is a mask (although it could be). We
        # don't actually care if we attend *from* padding tokens (only *to* padding)
        # tokens so we create a tensor of all ones.
        broadcast_ones = torch.ones(size=(batch_size, from_seq_length, 1), dtype=torch.float32, device=to_mask.device)

        # Here we broadcast along two dimensions to create the mask.
        mask = broadcast_ones * to_mask

        return mask

    def _downsample_attention_mask(self, char_attention_mask: torch.Tensor, downsampling_rate: int):
        """Downsample a 2D character attention mask to a molecule attention mask using MaxPool1d.

        A molecule is attendable if any character in its window of `downsampling_rate`
        characters is attendable. Trailing characters that do not fill a full window are
        dropped by the pooling.
        """

        # first, make char_attention_mask 3D by adding a channel dim
        batch_size, char_seq_len = char_attention_mask.shape
        poolable_char_mask = torch.reshape(char_attention_mask, (batch_size, 1, char_seq_len))

        # next, apply MaxPool1d to get pooled_molecule_mask of shape (batch_size, 1, mol_seq_len)
        pooled_molecule_mask = torch.nn.MaxPool1d(kernel_size=downsampling_rate, stride=downsampling_rate)(
            poolable_char_mask.float()
        )

        # finally, squeeze the last dim.
        # NOTE(review): `squeeze(dim=-1)` only removes the last dim when mol_seq_len == 1. In the
        # usual case the result stays (batch_size, 1, mol_seq_len), not (batch_size, mol_seq_len).
        # `forward` reads `.shape[-1]` as mol_seq_len and passes this mask to
        # get_extended_attention_mask, so the 3D form is what reaches it — confirm this is
        # intended before changing the squeeze dim.
        molecule_attention_mask = torch.squeeze(pooled_molecule_mask, dim=-1)

        return molecule_attention_mask

    def _repeat_molecules(self, molecules: torch.Tensor, char_seq_length: int) -> torch.Tensor:
        """Repeat molecules along dim -2 so the result has `char_seq_length` positions.

        Position 0 is dropped. Every remaining molecule is repeated `downsampling_rate` times,
        and the last molecule is repeated an extra `char_seq_length % downsampling_rate` times.
        NOTE(review): the output length equals `char_seq_length` only if
        `molecules.shape[1] == char_seq_length // downsampling_rate` — presumably guaranteed
        by CharactersToMolecules; confirm.
        """

        rate = self.config.downsampling_rate

        # Drop position 0, which the code treats as an extra CLS molecule.
        molecules_without_extra_cls = molecules[:, 1:, :]
        # `repeated`: [batch_size, almost_char_seq_len, molecule_hidden_size]
        repeated = torch.repeat_interleave(molecules_without_extra_cls, repeats=rate, dim=-2)

        # So far, we've repeated the elements sufficient for any `char_seq_length`
        # that's a multiple of `downsampling_rate`. Now we account for the last
        # n elements (n < `downsampling_rate`), i.e. the remainder of floor
        # division. We do this by repeating the last molecule a few extra times.
        last_molecule = molecules[:, -1:, :]
        remainder_length = char_seq_length % rate
        remainder_repeated = torch.repeat_interleave(
            last_molecule,
            # +1 molecule to compensate for truncation.
            repeats=remainder_length + rate,
            dim=-2,
        )

        # `repeated`: [batch_size, char_seq_len, molecule_hidden_size]
        return torch.cat([repeated, remainder_repeated], dim=-2)

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | CanineModelOutputWithPooling:
        # Resolve output flags from the config when the caller did not set them.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # Exactly one of input_ids / inputs_embeds must be provided.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # Default: attend to every character; all tokens in segment 0.
        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length)), device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        # Character-level extended mask: used by the final char encoder below.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
        # Molecule-level mask (see `_downsample_attention_mask` for its exact shape): used by the deep encoder.
        molecule_attention_mask = self._downsample_attention_mask(
            attention_mask, downsampling_rate=self.config.downsampling_rate
        )
        extended_molecule_attention_mask: torch.Tensor = self.get_extended_attention_mask(
            molecule_attention_mask, (batch_size, molecule_attention_mask.shape[-1])
        )

        # `input_char_embeddings`: shape (batch_size, char_seq, char_dim)
        input_char_embeddings = self.char_embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )

        # Contextualize character embeddings using shallow Transformer.
        # We use a 3D attention mask for the local attention.
        # `input_char_encoding`: shape (batch_size, char_seq_len, char_dim)
        # Note: this 3D mask is the raw 0/1 float mask, not the extended form.
        char_attention_mask = self._create_3d_attention_mask_from_input_mask(
            input_ids if input_ids is not None else inputs_embeds, attention_mask
        )
        # return_dict is not passed, so CanineEncoder.forward's default (True) applies and
        # `.last_hidden_state` is available regardless of this call's return_dict.
        init_chars_encoder_outputs = self.initial_char_encoder(
            input_char_embeddings,
            attention_mask=char_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        input_char_encoding = init_chars_encoder_outputs.last_hidden_state

        # Downsample chars to molecules.
        # The following lines have dimensions: [batch, molecule_seq, molecule_dim].
        # In this transformation, we change the dimensionality from `char_dim` to
        # `molecule_dim`, but do *NOT* add a resnet connection. Instead, we rely on
        # the resnet connections (a) from the final char transformer stack back into
        # the original char transformer stack and (b) the resnet connections from
        # the final char transformer stack back into the deep BERT stack of
        # molecules.
        #
        # Empirically, it is critical to use a powerful enough transformation here:
        # mean pooling causes training to diverge with huge gradient norms in this
        # region of the model; using a convolution here resolves this issue. From
        # this, it seems that molecules and characters require a very different
        # feature space; intuitively, this makes sense.
        init_molecule_encoding = self.chars_to_molecules(input_char_encoding)

        # Deep BERT encoder
        # `molecule_sequence_output`: shape (batch_size, mol_seq_len, mol_dim)
        # This call honours return_dict, so its outputs are indexed positionally below.
        encoder_outputs = self.encoder(
            init_molecule_encoding,
            attention_mask=extended_molecule_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        molecule_sequence_output = encoder_outputs[0]
        # The pooled output is computed from the molecule sequence.
        pooled_output = self.pooler(molecule_sequence_output) if self.pooler is not None else None

        # Upsample molecules back to characters.
        # `repeated_molecules`: shape (batch_size, char_seq_len, mol_hidden_size)
        repeated_molecules = self._repeat_molecules(molecule_sequence_output, char_seq_length=input_shape[-1])

        # Concatenate representations (contextualized char embeddings and repeated molecules):
        # `concat`: shape [batch_size, char_seq_len, molecule_hidden_size+char_hidden_final]
        concat = torch.cat([input_char_encoding, repeated_molecules], dim=-1)

        # Project representation dimension back to hidden_size
        # `sequence_output`: shape (batch_size, char_seq_len, hidden_size])
        sequence_output = self.projection(concat)

        # Apply final shallow Transformer
        # `sequence_output`: shape (batch_size, char_seq_len, hidden_size])
        final_chars_encoder_outputs = self.final_char_encoder(
            sequence_output,
            attention_mask=extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        sequence_output = final_chars_encoder_outputs.last_hidden_state

        # Hidden states / attentions are concatenated in pipeline order:
        # initial char encoder, deep molecule encoder, final char encoder.
        if output_hidden_states:
            deep_encoder_hidden_states = encoder_outputs.hidden_states if return_dict else encoder_outputs[1]
            all_hidden_states = (
                all_hidden_states
                + init_chars_encoder_outputs.hidden_states
                + deep_encoder_hidden_states
                + final_chars_encoder_outputs.hidden_states
            )

        if output_attentions:
            # In tuple form the attentions are always the last element.
            deep_encoder_self_attentions = encoder_outputs.attentions if return_dict else encoder_outputs[-1]
            all_self_attentions = (
                all_self_attentions
                + init_chars_encoder_outputs.attentions
                + deep_encoder_self_attentions
                + final_chars_encoder_outputs.attentions
            )

        if not return_dict:
            # pooled_output stays at index 1 even when it is None (no pooler).
            output = (sequence_output, pooled_output)
            output += tuple(v for v in [all_hidden_states, all_self_attentions] if v is not None)
            return output

        return CanineModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
  818. @auto_docstring(
  819. custom_intro="""
  820. CANINE Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
  821. output) e.g. for GLUE tasks.
  822. """
  823. )
  824. class CanineForSequenceClassification(CaninePreTrainedModel):
  825. def __init__(self, config):
  826. super().__init__(config)
  827. self.num_labels = config.num_labels
  828. self.canine = CanineModel(config)
  829. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  830. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  831. # Initialize weights and apply final processing
  832. self.post_init()
  833. @auto_docstring
  834. def forward(
  835. self,
  836. input_ids: torch.LongTensor | None = None,
  837. attention_mask: torch.FloatTensor | None = None,
  838. token_type_ids: torch.LongTensor | None = None,
  839. position_ids: torch.LongTensor | None = None,
  840. inputs_embeds: torch.FloatTensor | None = None,
  841. labels: torch.LongTensor | None = None,
  842. output_attentions: bool | None = None,
  843. output_hidden_states: bool | None = None,
  844. return_dict: bool | None = None,
  845. **kwargs,
  846. ) -> tuple | SequenceClassifierOutput:
  847. r"""
  848. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  849. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  850. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  851. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  852. """
  853. return_dict = return_dict if return_dict is not None else self.config.return_dict
  854. outputs = self.canine(
  855. input_ids,
  856. attention_mask=attention_mask,
  857. token_type_ids=token_type_ids,
  858. position_ids=position_ids,
  859. inputs_embeds=inputs_embeds,
  860. output_attentions=output_attentions,
  861. output_hidden_states=output_hidden_states,
  862. return_dict=return_dict,
  863. )
  864. pooled_output = outputs[1]
  865. pooled_output = self.dropout(pooled_output)
  866. logits = self.classifier(pooled_output)
  867. loss = None
  868. if labels is not None:
  869. if self.config.problem_type is None:
  870. if self.num_labels == 1:
  871. self.config.problem_type = "regression"
  872. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  873. self.config.problem_type = "single_label_classification"
  874. else:
  875. self.config.problem_type = "multi_label_classification"
  876. if self.config.problem_type == "regression":
  877. loss_fct = MSELoss()
  878. if self.num_labels == 1:
  879. loss = loss_fct(logits.squeeze(), labels.squeeze())
  880. else:
  881. loss = loss_fct(logits, labels)
  882. elif self.config.problem_type == "single_label_classification":
  883. loss_fct = CrossEntropyLoss()
  884. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  885. elif self.config.problem_type == "multi_label_classification":
  886. loss_fct = BCEWithLogitsLoss()
  887. loss = loss_fct(logits, labels)
  888. if not return_dict:
  889. output = (logits,) + outputs[2:]
  890. return ((loss,) + output) if loss is not None else output
  891. return SequenceClassifierOutput(
  892. loss=loss,
  893. logits=logits,
  894. hidden_states=outputs.hidden_states,
  895. attentions=outputs.attentions,
  896. )
  897. @auto_docstring
  898. class CanineForMultipleChoice(CaninePreTrainedModel):
  899. def __init__(self, config):
  900. super().__init__(config)
  901. self.canine = CanineModel(config)
  902. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  903. self.classifier = nn.Linear(config.hidden_size, 1)
  904. # Initialize weights and apply final processing
  905. self.post_init()
  906. @auto_docstring
  907. def forward(
  908. self,
  909. input_ids: torch.LongTensor | None = None,
  910. attention_mask: torch.FloatTensor | None = None,
  911. token_type_ids: torch.LongTensor | None = None,
  912. position_ids: torch.LongTensor | None = None,
  913. inputs_embeds: torch.FloatTensor | None = None,
  914. labels: torch.LongTensor | None = None,
  915. output_attentions: bool | None = None,
  916. output_hidden_states: bool | None = None,
  917. return_dict: bool | None = None,
  918. **kwargs,
  919. ) -> tuple | MultipleChoiceModelOutput:
  920. r"""
  921. input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
  922. Indices of input sequence tokens in the vocabulary.
  923. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  924. [`PreTrainedTokenizer.__call__`] for details.
  925. [What are input IDs?](../glossary#input-ids)
  926. token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  927. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  928. 1]`:
  929. - 0 corresponds to a *sentence A* token,
  930. - 1 corresponds to a *sentence B* token.
  931. [What are token type IDs?](../glossary#token-type-ids)
  932. position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  933. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  934. config.max_position_embeddings - 1]`.
  935. [What are position IDs?](../glossary#position-ids)
  936. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
  937. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  938. is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
  939. model's internal embedding lookup matrix.
  940. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  941. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  942. num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
  943. `input_ids` above)
  944. """
  945. return_dict = return_dict if return_dict is not None else self.config.return_dict
  946. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  947. input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  948. attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
  949. token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
  950. position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
  951. inputs_embeds = (
  952. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  953. if inputs_embeds is not None
  954. else None
  955. )
  956. outputs = self.canine(
  957. input_ids,
  958. attention_mask=attention_mask,
  959. token_type_ids=token_type_ids,
  960. position_ids=position_ids,
  961. inputs_embeds=inputs_embeds,
  962. output_attentions=output_attentions,
  963. output_hidden_states=output_hidden_states,
  964. return_dict=return_dict,
  965. )
  966. pooled_output = outputs[1]
  967. pooled_output = self.dropout(pooled_output)
  968. logits = self.classifier(pooled_output)
  969. reshaped_logits = logits.view(-1, num_choices)
  970. loss = None
  971. if labels is not None:
  972. loss_fct = CrossEntropyLoss()
  973. loss = loss_fct(reshaped_logits, labels)
  974. if not return_dict:
  975. output = (reshaped_logits,) + outputs[2:]
  976. return ((loss,) + output) if loss is not None else output
  977. return MultipleChoiceModelOutput(
  978. loss=loss,
  979. logits=reshaped_logits,
  980. hidden_states=outputs.hidden_states,
  981. attentions=outputs.attentions,
  982. )
  983. @auto_docstring
  984. class CanineForTokenClassification(CaninePreTrainedModel):
  985. def __init__(self, config):
  986. super().__init__(config)
  987. self.num_labels = config.num_labels
  988. self.canine = CanineModel(config)
  989. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  990. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  991. # Initialize weights and apply final processing
  992. self.post_init()
  993. @auto_docstring
  994. def forward(
  995. self,
  996. input_ids: torch.LongTensor | None = None,
  997. attention_mask: torch.FloatTensor | None = None,
  998. token_type_ids: torch.LongTensor | None = None,
  999. position_ids: torch.LongTensor | None = None,
  1000. inputs_embeds: torch.FloatTensor | None = None,
  1001. labels: torch.LongTensor | None = None,
  1002. output_attentions: bool | None = None,
  1003. output_hidden_states: bool | None = None,
  1004. return_dict: bool | None = None,
  1005. **kwargs,
  1006. ) -> tuple | TokenClassifierOutput:
  1007. r"""
  1008. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1009. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  1010. Example:
  1011. ```python
  1012. >>> from transformers import AutoTokenizer, CanineForTokenClassification
  1013. >>> import torch
  1014. >>> tokenizer = AutoTokenizer.from_pretrained("google/canine-s")
  1015. >>> model = CanineForTokenClassification.from_pretrained("google/canine-s")
  1016. >>> inputs = tokenizer(
  1017. ... "HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="pt"
  1018. ... )
  1019. >>> with torch.no_grad():
  1020. ... logits = model(**inputs).logits
  1021. >>> predicted_token_class_ids = logits.argmax(-1)
  1022. >>> # Note that tokens are classified rather then input words which means that
  1023. >>> # there might be more predicted token classes than words.
  1024. >>> # Multiple token classes might account for the same word
  1025. >>> predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids[0]]
  1026. >>> predicted_tokens_classes # doctest: +SKIP
  1027. ```
  1028. ```python
  1029. >>> labels = predicted_token_class_ids
  1030. >>> loss = model(**inputs, labels=labels).loss
  1031. >>> round(loss.item(), 2) # doctest: +SKIP
  1032. ```"""
  1033. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1034. outputs = self.canine(
  1035. input_ids,
  1036. attention_mask=attention_mask,
  1037. token_type_ids=token_type_ids,
  1038. position_ids=position_ids,
  1039. inputs_embeds=inputs_embeds,
  1040. output_attentions=output_attentions,
  1041. output_hidden_states=output_hidden_states,
  1042. return_dict=return_dict,
  1043. )
  1044. sequence_output = outputs[0]
  1045. sequence_output = self.dropout(sequence_output)
  1046. logits = self.classifier(sequence_output)
  1047. loss = None
  1048. if labels is not None:
  1049. loss_fct = CrossEntropyLoss()
  1050. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1051. if not return_dict:
  1052. output = (logits,) + outputs[2:]
  1053. return ((loss,) + output) if loss is not None else output
  1054. return TokenClassifierOutput(
  1055. loss=loss,
  1056. logits=logits,
  1057. hidden_states=outputs.hidden_states,
  1058. attentions=outputs.attentions,
  1059. )
  1060. @auto_docstring
  1061. class CanineForQuestionAnswering(CaninePreTrainedModel):
  1062. def __init__(self, config):
  1063. super().__init__(config)
  1064. self.num_labels = config.num_labels
  1065. self.canine = CanineModel(config)
  1066. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  1067. # Initialize weights and apply final processing
  1068. self.post_init()
  1069. @auto_docstring
  1070. def forward(
  1071. self,
  1072. input_ids: torch.LongTensor | None = None,
  1073. attention_mask: torch.FloatTensor | None = None,
  1074. token_type_ids: torch.LongTensor | None = None,
  1075. position_ids: torch.LongTensor | None = None,
  1076. inputs_embeds: torch.FloatTensor | None = None,
  1077. start_positions: torch.LongTensor | None = None,
  1078. end_positions: torch.LongTensor | None = None,
  1079. output_attentions: bool | None = None,
  1080. output_hidden_states: bool | None = None,
  1081. return_dict: bool | None = None,
  1082. **kwargs,
  1083. ) -> tuple | QuestionAnsweringModelOutput:
  1084. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1085. outputs = self.canine(
  1086. input_ids,
  1087. attention_mask=attention_mask,
  1088. token_type_ids=token_type_ids,
  1089. position_ids=position_ids,
  1090. inputs_embeds=inputs_embeds,
  1091. output_attentions=output_attentions,
  1092. output_hidden_states=output_hidden_states,
  1093. return_dict=return_dict,
  1094. )
  1095. sequence_output = outputs[0]
  1096. logits = self.qa_outputs(sequence_output)
  1097. start_logits, end_logits = logits.split(1, dim=-1)
  1098. start_logits = start_logits.squeeze(-1)
  1099. end_logits = end_logits.squeeze(-1)
  1100. total_loss = None
  1101. if start_positions is not None and end_positions is not None:
  1102. # If we are on multi-GPU, split add a dimension
  1103. if len(start_positions.size()) > 1:
  1104. start_positions = start_positions.squeeze(-1)
  1105. if len(end_positions.size()) > 1:
  1106. end_positions = end_positions.squeeze(-1)
  1107. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  1108. ignored_index = start_logits.size(1)
  1109. start_positions.clamp_(0, ignored_index)
  1110. end_positions.clamp_(0, ignored_index)
  1111. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1112. start_loss = loss_fct(start_logits, start_positions)
  1113. end_loss = loss_fct(end_logits, end_positions)
  1114. total_loss = (start_loss + end_loss) / 2
  1115. if not return_dict:
  1116. output = (start_logits, end_logits) + outputs[2:]
  1117. return ((total_loss,) + output) if total_loss is not None else output
  1118. return QuestionAnsweringModelOutput(
  1119. loss=total_loss,
  1120. start_logits=start_logits,
  1121. end_logits=end_logits,
  1122. hidden_states=outputs.hidden_states,
  1123. attentions=outputs.attentions,
  1124. )
# Public API of this module: the names bound by `from <this module> import *`.
# NOTE(review): kept as a plain list literal — presumably because tooling reads `__all__`
# statically; confirm before restructuring it.
__all__ = [
    "CanineForMultipleChoice",
    "CanineForQuestionAnswering",
    "CanineForSequenceClassification",
    "CanineForTokenClassification",
    "CanineLayer",
    "CanineModel",
    "CaninePreTrainedModel",
]